use std::collections::HashMap;
use std::ops::Range;
use tracing::debug;
use crate::expressions::ArrayData;
use crate::log_replay::HasSelectionVector;
use crate::schema::{ColumnName, DataType, SchemaRef};
use crate::{AsAny, DeltaResult, Error};
/// Engine data bundled with a selection vector that marks which rows are still of interest.
///
/// A row at index `i` counts as selected when `selection_vector[i]` is `true`, or when `i` lies
/// beyond the end of the vector. Consequently a vector shorter than the data leaves the trailing
/// rows selected, and an empty vector selects every row.
pub struct FilteredEngineData {
    // The underlying, unfiltered batch.
    data: Box<dyn EngineData>,
    // Per-row selection flags; `try_new` rejects vectors longer than `data`.
    selection_vector: Vec<bool>,
}

impl FilteredEngineData {
    /// Pairs `data` with `selection_vector`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidSelectionVector`] when the selection vector has more entries than
    /// `data` has rows. A shorter (or empty) vector is accepted.
    pub fn try_new(data: Box<dyn EngineData>, selection_vector: Vec<bool>) -> DeltaResult<Self> {
        let (sv_len, data_len) = (selection_vector.len(), data.len());
        if sv_len > data_len {
            return Err(Error::InvalidSelectionVector(format!(
                "Selection vector is larger than data length: {sv_len} > {data_len}"
            )));
        }
        Ok(Self {
            data,
            selection_vector,
        })
    }

    /// Borrows the underlying, unfiltered engine data.
    pub fn data(&self) -> &dyn EngineData {
        &*self.data
    }

    /// Borrows the selection vector exactly as stored (it may be shorter than the data).
    pub fn selection_vector(&self) -> &[bool] {
        &self.selection_vector
    }

    /// Splits `self` into the underlying data and the selection vector.
    pub fn into_parts(self) -> (Box<dyn EngineData>, Vec<bool>) {
        (self.data, self.selection_vector)
    }

    /// Wraps `data` with an empty selection vector, which selects every row.
    pub fn with_all_rows_selected(data: Box<dyn EngineData>) -> Self {
        Self {
            data,
            selection_vector: vec![],
        }
    }

    /// Consumes `self` and asks the underlying data to apply the selection vector.
    ///
    /// The result is whatever [`EngineData::apply_selection_vector`] returns for this engine.
    pub fn apply_selection_vector(self) -> DeltaResult<Box<dyn EngineData>> {
        // `self` is consumed here, so the vector can be moved instead of cloned.
        let Self {
            data,
            selection_vector,
        } = self;
        data.apply_selection_vector(selection_vector)
    }
}
impl HasSelectionVector for FilteredEngineData {
    /// Reports whether at least one row of the underlying data is selected.
    ///
    /// Rows past the end of the selection vector count as selected, so a vector shorter than the
    /// data is sufficient on its own; otherwise at least one flag has to be `true`.
    fn has_selected_rows(&self) -> bool {
        let flags = &self.selection_vector;
        flags.len() < self.data.len() || flags.iter().any(|&selected| selected)
    }
}
impl From<Box<dyn EngineData>> for FilteredEngineData {
fn from(data: Box<dyn EngineData>) -> Self {
Self::with_all_rows_selected(data)
}
}
/// Read-only, index-based access to an array of strings.
///
/// `ListItem` and `MapItem` read their elements through this trait.
pub trait StringArrayAccessor {
    /// Number of slots in the array.
    fn len(&self) -> usize;

    /// `true` when the array has no slots at all.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// The string stored at `index`.
    ///
    /// NOTE(review): the result for an index where `is_valid` is `false`, or for an
    /// out-of-range index, is up to the implementor -- confirm before relying on it.
    fn value(&self, index: usize) -> &str;

    /// Whether the slot at `index` holds a usable value (`MapItem` skips slots where this is
    /// `false`).
    fn is_valid(&self, index: usize) -> bool;
}
/// A view of one list value: the window `offsets` into a shared string array.
pub struct ListItem<'a> {
    // Backing array shared by every list in the column.
    values: &'a dyn StringArrayAccessor,
    // Half-open range of positions in `values` that belong to this list.
    offsets: Range<usize>,
}

impl<'a> ListItem<'a> {
    /// Creates a view over the positions `offsets` of `values`.
    pub fn new(values: &'a dyn StringArrayAccessor, offsets: Range<usize>) -> ListItem<'a> {
        Self { values, offsets }
    }

    /// Number of elements in this list.
    pub fn len(&self) -> usize {
        self.offsets.len()
    }

    /// `true` when this list has no elements.
    pub fn is_empty(&self) -> bool {
        self.offsets.start >= self.offsets.end
    }

    /// Returns an owned copy of the element at `list_index` (relative to the start of this list).
    ///
    /// NOTE(review): `list_index` is not checked against `len()`; an index past the end of this
    /// list is forwarded to the backing array as-is.
    pub fn get(&self, list_index: usize) -> String {
        let position = self.offsets.start + list_index;
        self.values.value(position).to_owned()
    }

    /// Copies every element of this list, in order, into a `Vec<String>`.
    pub fn materialize(&self) -> Vec<String> {
        let mut items = Vec::with_capacity(self.offsets.len());
        for position in self.offsets.clone() {
            items.push(self.values.value(position).to_owned());
        }
        items
    }
}
/// A view of one map value: the window `offsets` into parallel key and value string arrays.
pub struct MapItem<'a> {
    // Keys of every map in the column; position `i` pairs with `values` position `i`.
    keys: &'a dyn StringArrayAccessor,
    // Values of every map in the column.
    values: &'a dyn StringArrayAccessor,
    // Half-open range of positions that belong to this map.
    offsets: Range<usize>,
}

impl<'a> MapItem<'a> {
    /// Creates a view over the positions `offsets` of the parallel `keys` / `values` arrays.
    pub fn new(
        keys: &'a dyn StringArrayAccessor,
        values: &'a dyn StringArrayAccessor,
        offsets: Range<usize>,
    ) -> MapItem<'a> {
        Self {
            keys,
            values,
            offsets,
        }
    }

    /// Looks up `key`, returning its value when the key is present and the value is valid.
    ///
    /// The window is scanned from the back, so for a duplicated key the last entry is the one
    /// consulted; if that entry's value is not valid the result is `None`.
    ///
    /// NOTE(review): `materialize` scans forward and skips invalid values, so for a duplicated
    /// key whose last entry is invalid it keeps the earlier value while `get` returns `None`.
    /// Confirm which behavior is intended.
    pub fn get(&self, key: &str) -> Option<&'a str> {
        let position = self
            .offsets
            .clone()
            .rfind(|&candidate| self.keys.value(candidate) == key)?;
        if self.values.is_valid(position) {
            Some(self.values.value(position))
        } else {
            None
        }
    }

    /// Copies the entries with valid values into a `HashMap`; later duplicates overwrite earlier
    /// ones.
    pub fn materialize(&self) -> HashMap<String, String> {
        let mut entries = HashMap::with_capacity(self.offsets.len());
        let valid_pairs = self
            .offsets
            .clone()
            .filter(|&position| self.values.is_valid(position))
            .map(|position| {
                (
                    self.keys.value(position).to_string(),
                    self.values.value(position).to_string(),
                )
            });
        entries.extend(valid_pairs);
        entries
    }
}
/// Expands a list of `(method_name, type)` pairs into trait methods of the shape
/// `fn method_name(&'a self, row_index, field_name) -> DeltaResult<Option<type>>`.
///
/// Each generated body logs at debug level and returns `Error::UnexpectedColumnType`, i.e. it is
/// the "this column is not of that type" fallback.
macro_rules! impl_default_get {
    ( $(($getter:ident, $out:ty)),* ) => {
        $(
            fn $getter(&'a self, _row_index: usize, field_name: &str) -> DeltaResult<Option<$out>> {
                let requested = stringify!($out);
                debug!("Asked for type {} on {field_name}, but using default error impl.", requested);
                let mismatch = Error::UnexpectedColumnType(format!("{field_name} is not of type {}", requested));
                Err(mismatch.with_backtrace())
            }
        )*
    };
}
/// Typed, row-indexed access to a single column of engine data.
///
/// The trait has one getter per supported type (`get_bool`, `get_int`, ... `get_map`), each
/// taking a row index and the field name and returning `DeltaResult<Option<T>>`. All of them are
/// generated by `impl_default_get!`, so every getter defaults to an
/// `Error::UnexpectedColumnType` error; an implementor is expected to override the getters for
/// the type(s) it actually holds.
///
/// Note that `get_date` shares `i32` with `get_int`, and `get_timestamp` shares `i64` with
/// `get_long`.
pub trait GetData<'a> {
impl_default_get!(
(get_bool, bool),
(get_int, i32),
(get_long, i64),
(get_float, f32),
(get_double, f64),
(get_date, i32),
(get_timestamp, i64),
(get_decimal, i128),
(get_str, &'a str),
(get_binary, &'a [u8]),
(get_list, ListItem<'a>),
(get_map, MapItem<'a>)
);
}
/// Expands a list of `(method_name, type)` pairs into getters that ignore their arguments and
/// always answer `Ok(None)`, i.e. "this value is null" for every row.
macro_rules! impl_null_get {
    ( $(($getter:ident, $out:ty)),* ) => {
        $(
            fn $getter(&'a self, _row_index: usize, _field_name: &str) -> DeltaResult<Option<$out>> {
                let absent: Option<$out> = None;
                Ok(absent)
            }
        )*
    };
}
/// `()` behaves as an all-null column: every getter is generated by `impl_null_get!` and returns
/// `Ok(None)` regardless of the row index or field name.
impl<'a> GetData<'a> for () {
impl_null_get!(
(get_bool, bool),
(get_int, i32),
(get_long, i64),
(get_float, f32),
(get_double, f64),
(get_date, i32),
(get_timestamp, i64),
(get_decimal, i128),
(get_str, &'a str),
(get_binary, &'a [u8]),
(get_list, ListItem<'a>),
(get_map, MapItem<'a>)
);
}
pub trait TypedGetData<'a, T> {
fn get_opt(&'a self, row_index: usize, field_name: &str) -> DeltaResult<Option<T>>;
fn get(&'a self, row_index: usize, field_name: &str) -> DeltaResult<T> {
let val = self.get_opt(row_index, field_name)?;
val.ok_or_else(|| {
Error::MissingData(format!("Data missing for field {field_name}")).with_backtrace()
})
}
}
/// For each `(getter, type)` pair, implements `TypedGetData<'a, type>` for `dyn GetData<'a>` by
/// forwarding `get_opt` to the named `GetData` getter.
macro_rules! impl_typed_get_data {
    ( $(($getter:ident, $out:ty)),* ) => {
        $(
            impl<'a> TypedGetData<'a, $out> for dyn GetData<'a> + '_ {
                fn get_opt(&'a self, row_index: usize, field_name: &str) -> DeltaResult<Option<$out>> {
                    self.$getter(row_index, field_name)
                }
            }
        )*
    };
}
// Generates `TypedGetData<T>` for `dyn GetData`, one impl per distinct target type.
// `get_date` and `get_timestamp` are not listed: their types (`i32`, `i64`) are already claimed
// by `get_int` and `get_long`, and a second `TypedGetData<i32>` / `TypedGetData<i64>` impl for
// the same type would be rejected by the compiler as conflicting.
impl_typed_get_data!(
(get_bool, bool),
(get_int, i32),
(get_long, i64),
(get_float, f32),
(get_double, f64),
(get_decimal, i128),
(get_str, &'a str),
(get_binary, &'a [u8]),
(get_list, ListItem<'a>),
(get_map, MapItem<'a>)
);
/// Owned-`String` access: reads the value through `get_str` and copies it.
impl<'a> TypedGetData<'a, String> for dyn GetData<'a> + '_ {
    fn get_opt(&'a self, row_index: usize, field_name: &str) -> DeltaResult<Option<String>> {
        let borrowed = self.get_str(row_index, field_name)?;
        Ok(borrowed.map(str::to_string))
    }
}
/// Owned-list access: fetches the value as a `ListItem` and materializes it into a `Vec<String>`.
impl<'a> TypedGetData<'a, Vec<String>> for dyn GetData<'a> + '_ {
    fn get_opt(&'a self, row_index: usize, field_name: &str) -> DeltaResult<Option<Vec<String>>> {
        // The annotation selects the `ListItem` flavor of `get_opt`.
        let item: Option<ListItem<'_>> = self.get_opt(row_index, field_name)?;
        let owned = item.map(|entries| entries.materialize());
        Ok(owned)
    }
}
/// Owned-map access: fetches the value as a `MapItem` and materializes it into a `HashMap`.
impl<'a> TypedGetData<'a, HashMap<String, String>> for dyn GetData<'a> + '_ {
    fn get_opt(
        &'a self,
        row_index: usize,
        field_name: &str,
    ) -> DeltaResult<Option<HashMap<String, String>>> {
        // The annotation selects the `MapItem` flavor of `get_opt`.
        let item: Option<MapItem<'_>> = self.get_opt(row_index, field_name)?;
        let owned = item.map(|entries| entries.materialize());
        Ok(owned)
    }
}
/// Iterator over the indices of the selected rows of a batch.
///
/// Row `i` (for `i < row_count`) is yielded when `selection_vector[i]` is `true` or when `i` is
/// past the end of the selection vector. Indices are produced in increasing order.
pub struct RowIndexIterator<'sv> {
    // Next row index to examine.
    sv_pos: usize,
    // Selection flags; may be shorter than `row_count`.
    selection_vector: &'sv [bool],
    // Total number of rows in the batch (selected or not).
    row_count: usize,
}

impl<'sv> RowIndexIterator<'sv> {
    /// Creates an iterator over the selected rows among `0..row_count`.
    pub(crate) fn new(row_count: usize, selection_vector: &'sv [bool]) -> Self {
        Self {
            sv_pos: 0,
            selection_vector,
            row_count,
        }
    }

    /// Total number of rows in the batch, including rows that are not selected.
    pub fn num_rows(&self) -> usize {
        self.row_count
    }
}

impl<'sv> Iterator for RowIndexIterator<'sv> {
    type Item = usize;

    fn next(&mut self) -> Option<usize> {
        while self.sv_pos < self.row_count {
            let pos = self.sv_pos;
            self.sv_pos += 1;
            // `get` yields `None` past the end of the vector, which means "selected".
            if self.selection_vector.get(pos).copied().unwrap_or(true) {
                return Some(pos);
            }
        }
        None
    }

    /// Bounds on the number of indices still to come.
    ///
    /// At most every unexamined row is yielded; at least every unexamined row that lies past
    /// the end of the selection vector is, since those rows are implicitly selected.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let upper = self.row_count.saturating_sub(self.sv_pos);
        let first_implicit = self.sv_pos.max(self.selection_vector.len());
        let lower = self.row_count.saturating_sub(first_implicit);
        (lower, Some(upper))
    }
}
/// A visitor that only sees the rows selected by a `FilteredEngineData`'s selection vector.
pub trait FilteredRowVisitor {
    /// The columns this visitor wants to read, together with their types.
    fn selected_column_names_and_types(&self) -> (&'static [ColumnName], &'static [DataType]);

    /// Visits the selected rows of one batch.
    ///
    /// `rows` yields the indices of the selected rows; `getters` gives access to the requested
    /// columns (presumably one getter per requested column, in the requested order -- confirm
    /// against the engine's `visit_rows` implementation).
    fn visit_filtered<'a>(
        &mut self,
        getters: &[&'a dyn GetData<'a>],
        rows: RowIndexIterator<'_>,
    ) -> DeltaResult<()>;

    /// Runs this visitor over `data`, honoring its selection vector.
    ///
    /// Wraps `self` in a bridge that implements `RowVisitor` and hands that bridge to the
    /// underlying data's `visit_rows`.
    fn visit_rows_of(&mut self, data: &FilteredEngineData) -> DeltaResult<()>
    where
        Self: Sized,
    {
        let (column_names, _) = self.selected_column_names_and_types();
        let selection_vector = data.selection_vector();
        let mut bridge = FilteredVisitorBridge {
            visitor: self,
            selection_vector,
        };
        data.data().visit_rows(column_names, &mut bridge)
    }
}
/// Adapter that lets a `FilteredRowVisitor` be driven through the plain `RowVisitor` interface.
struct FilteredVisitorBridge<'bridge, V: FilteredRowVisitor> {
    // The wrapped visitor that receives the filtered rows.
    visitor: &'bridge mut V,
    // Selection vector of the data being visited.
    selection_vector: &'bridge [bool],
}

impl<V: FilteredRowVisitor> RowVisitor for FilteredVisitorBridge<'_, V> {
    /// Forwards to the wrapped visitor.
    fn selected_column_names_and_types(&self) -> (&'static [ColumnName], &'static [DataType]) {
        self.visitor.selected_column_names_and_types()
    }

    /// Turns `row_count` plus the stored selection vector into a `RowIndexIterator` and passes
    /// it, with the getters, to the wrapped visitor.
    fn visit<'a>(&mut self, row_count: usize, getters: &[&'a dyn GetData<'a>]) -> DeltaResult<()> {
        self.visitor.visit_filtered(
            getters,
            RowIndexIterator::new(row_count, self.selection_vector),
        )
    }
}
/// A visitor over the rows of a batch of engine data.
pub trait RowVisitor {
    /// The columns this visitor wants to read, together with their types.
    fn selected_column_names_and_types(&self) -> (&'static [ColumnName], &'static [DataType]);

    /// Called with the number of rows in the batch and the getters for the requested columns.
    fn visit<'a>(&mut self, row_count: usize, getters: &[&'a dyn GetData<'a>]) -> DeltaResult<()>;

    /// Runs this visitor over `data` by passing the requested column names and `self` to
    /// `EngineData::visit_rows`.
    fn visit_rows_of(&mut self, data: &dyn EngineData) -> DeltaResult<()>
    where
        Self: Sized,
    {
        let (column_names, _) = self.selected_column_names_and_types();
        data.visit_rows(column_names, self)
    }
}
/// Opaque, engine-owned batch of rows.
pub trait EngineData: AsAny {
    /// Drives `visitor` over this batch, giving it access to the columns in `column_names`.
    fn visit_rows(
        &self,
        column_names: &[ColumnName],
        visitor: &mut dyn RowVisitor,
    ) -> DeltaResult<()>;

    /// Number of rows in this batch.
    fn len(&self) -> usize;

    /// `true` when this batch has no rows.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns new engine data built from this batch plus `columns`, described by `schema`.
    ///
    /// NOTE(review): whether `schema` describes only the appended columns or the full result is
    /// not visible from this trait -- confirm against the implementations.
    fn append_columns(
        &self,
        schema: SchemaRef,
        columns: Vec<ArrayData>,
    ) -> DeltaResult<Box<dyn EngineData>>;

    /// Consumes this batch and applies `selection_vector` to it, returning the resulting data.
    fn apply_selection_vector(
        self: Box<Self>,
        selection_vector: Vec<bool>,
    ) -> DeltaResult<Box<dyn EngineData>>;
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arrow::array::{RecordBatch, StringArray};
    use crate::arrow::datatypes::{
        DataType as ArrowDataType, Field as ArrowField, Schema as ArrowSchema,
    };
    use crate::engine::arrow_data::ArrowEngineData;
    use rstest::rstest;
    use std::sync::Arc;

    /// Builds an Arrow-backed batch with one nullable Utf8 column named `value` that holds
    /// `row0`, `row1`, ... for `rows` rows.
    fn get_engine_data(rows: usize) -> Box<dyn EngineData> {
        let values: Vec<String> = (0..rows).map(|i| format!("row{i}")).collect();
        let field = ArrowField::new("value", ArrowDataType::Utf8, true);
        let schema = Arc::new(ArrowSchema::new(vec![field]));
        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(StringArray::from(values))]).unwrap();
        Box::new(ArrowEngineData::new(batch))
    }

    // ---- with_all_rows_selected -------------------------------------------------------------

    #[test]
    fn test_with_all_rows_selected_empty_data() {
        let filtered = FilteredEngineData::with_all_rows_selected(get_engine_data(0));
        assert_eq!(filtered.selection_vector().len(), 0);
        assert!(filtered.selection_vector().is_empty());
        assert_eq!(filtered.data().len(), 0);
    }

    #[test]
    fn test_with_all_rows_selected_single_row() {
        let filtered = FilteredEngineData::with_all_rows_selected(get_engine_data(1));
        assert!(filtered.selection_vector().is_empty());
        assert_eq!(filtered.data().len(), 1);
        assert!(filtered.has_selected_rows());
    }

    #[test]
    fn test_with_all_rows_selected_multiple_rows() {
        let filtered = FilteredEngineData::with_all_rows_selected(get_engine_data(4));
        assert!(filtered.selection_vector().is_empty());
        assert_eq!(filtered.data().len(), 4);
        assert!(filtered.has_selected_rows());
    }

    // ---- has_selected_rows ------------------------------------------------------------------

    #[test]
    fn test_has_selected_rows_empty_data() {
        // No rows at all means nothing can be selected.
        let filtered = FilteredEngineData::try_new(get_engine_data(0), vec![]).unwrap();
        assert!(!filtered.has_selected_rows());
    }

    #[test]
    fn test_has_selected_rows_selection_vector_shorter_than_data() {
        // The third row is not covered by the vector and is therefore implicitly selected.
        let filtered =
            FilteredEngineData::try_new(get_engine_data(3), vec![false, false]).unwrap();
        assert!(filtered.has_selected_rows());
    }

    #[test]
    fn test_has_selected_rows_selection_vector_same_length_all_false() {
        let filtered =
            FilteredEngineData::try_new(get_engine_data(2), vec![false, false]).unwrap();
        assert!(!filtered.has_selected_rows());
    }

    #[test]
    fn test_has_selected_rows_selection_vector_same_length_some_true() {
        let filtered =
            FilteredEngineData::try_new(get_engine_data(3), vec![true, false, true]).unwrap();
        assert!(filtered.has_selected_rows());
    }

    // ---- try_new validation -----------------------------------------------------------------

    #[test]
    fn test_try_new_selection_vector_larger_than_data() {
        let outcome = FilteredEngineData::try_new(get_engine_data(2), vec![true, false, true]);
        assert!(outcome.is_err());
        if let Err(err) = outcome {
            let message = err.to_string();
            assert!(message.contains("Selection vector is larger than data length"));
            assert!(message.contains("3 > 2"));
        }
    }

    // ---- binary getters ---------------------------------------------------------------------

    #[test]
    fn test_get_binary_some_value() {
        use crate::arrow::array::BinaryArray;
        let cells: Vec<Option<&[u8]>> = vec![Some(b"hello"), Some(b"world"), None];
        let array = BinaryArray::from(cells);
        let getter: &dyn GetData<'_> = &array;

        let first: Option<&[u8]> = getter.get_opt(0, "binary_field").unwrap();
        assert_eq!(first, Some(b"hello".as_ref()));

        let second: Option<&[u8]> = getter.get_opt(1, "binary_field").unwrap();
        assert_eq!(second, Some(b"world".as_ref()));

        // The null slot comes back as `None`, not as an error.
        let third: Option<&[u8]> = getter.get_opt(2, "binary_field").unwrap();
        assert_eq!(third, None);
    }

    #[test]
    fn test_get_binary_required() {
        use crate::arrow::array::BinaryArray;
        let cells: Vec<Option<&[u8]>> = vec![Some(b"hello")];
        let array = BinaryArray::from(cells);
        let getter: &dyn GetData<'_> = &array;

        let bytes: &[u8] = getter.get(0, "binary_field").unwrap();
        assert_eq!(bytes, b"hello");
    }

    #[test]
    fn test_get_binary_required_missing() {
        use crate::arrow::array::BinaryArray;
        let cells: Vec<Option<&[u8]>> = vec![None];
        let array = BinaryArray::from(cells);
        let getter: &dyn GetData<'_> = &array;

        // `get` (unlike `get_opt`) turns a null into a "missing data" error.
        let outcome: DeltaResult<&[u8]> = getter.get(0, "binary_field");
        assert!(outcome.is_err());
        if let Err(err) = outcome {
            assert!(err.to_string().contains("Data missing for field"));
        }
    }

    #[test]
    fn test_get_binary_empty_bytes() {
        use crate::arrow::array::BinaryArray;
        let cells: Vec<Option<&[u8]>> = vec![Some(b"")];
        let array = BinaryArray::from(cells);
        let getter: &dyn GetData<'_> = &array;

        // An empty byte string is a present value, distinct from null.
        let result: Option<&[u8]> = getter.get_opt(0, "binary_field").unwrap();
        assert_eq!(result, Some([].as_ref()));
        assert_eq!(result.unwrap().len(), 0);
    }

    // ---- conversions and apply_selection_vector ---------------------------------------------

    #[test]
    fn test_from_engine_data() {
        let raw = get_engine_data(3);
        let expected_len = raw.len();

        let filtered: FilteredEngineData = raw.into();
        assert!(filtered.selection_vector().is_empty());
        assert_eq!(filtered.data().len(), expected_len);
        assert_eq!(filtered.data().len(), 3);
        assert!(filtered.has_selected_rows());
    }

    #[test]
    fn filtered_apply_seclection_vector_full() {
        let filtered =
            FilteredEngineData::try_new(get_engine_data(4), vec![true, false, true, false])
                .unwrap();
        let kept = filtered.apply_selection_vector().unwrap();
        assert_eq!(kept.len(), 2);
    }

    #[test]
    fn filtered_apply_seclection_vector_partial() {
        // Rows 2 and 3 are beyond the vector and stay selected, so only row 1 is dropped.
        let filtered = FilteredEngineData::try_new(get_engine_data(4), vec![true, false]).unwrap();
        let kept = filtered.apply_selection_vector().unwrap();
        assert_eq!(kept.len(), 3);
    }

    // ---- RowIndexIterator -------------------------------------------------------------------

    /// Collects every index the iterator yields for the given row count and selection vector.
    fn collect_indices(row_count: usize, selection: &[bool]) -> Vec<usize> {
        let rows = RowIndexIterator::new(row_count, selection);
        rows.collect()
    }

    #[rstest]
    #[case(0, &[], vec![])]
    #[case(3, &[], vec![0, 1, 2])]
    #[case(3, &[true, true, true], vec![0, 1, 2])]
    #[case(3, &[false, false, false], vec![])]
    #[case(5, &[true, false, false, true, true], vec![0, 3, 4])]
    #[case(4, &[false, false, true, true], vec![2, 3])]
    #[case(3, &[true, false, false], vec![0])]
    #[case(4, &[false, true], vec![1, 2, 3])]
    #[case(4, &[true, false], vec![0, 2, 3])]
    #[case(4, &[false, true, false, true], vec![1, 3])]
    fn row_index_iter(
        #[case] row_count: usize,
        #[case] selection: &[bool],
        #[case] expected: Vec<usize>,
    ) {
        let actual = collect_indices(row_count, selection);
        assert_eq!(actual, expected);
    }
}