1use std::collections::HashMap;
4use std::ops::Range;
5
6use tracing::debug;
7
8use crate::actions::visitors::SelectionVectorVisitor;
9use crate::expressions::ArrayData;
10use crate::log_replay::HasSelectionVector;
11use crate::schema::{ColumnName, DataType, SchemaRef};
12use crate::utils::require;
13use crate::{AsAny, DeltaResult, Error};
14
15pub struct FilteredEngineData {
24 data: Box<dyn EngineData>,
26 selection_vector: Vec<bool>,
30}
31
32impl FilteredEngineData {
33 pub fn try_new(data: Box<dyn EngineData>, selection_vector: Vec<bool>) -> DeltaResult<Self> {
34 if selection_vector.len() > data.len() {
35 return Err(Error::InvalidSelectionVector(format!(
36 "Selection vector is larger than data length: {} > {}",
37 selection_vector.len(),
38 data.len()
39 )));
40 }
41 Ok(Self {
42 data,
43 selection_vector,
44 })
45 }
46
47 pub fn data(&self) -> &dyn EngineData {
49 &*self.data
50 }
51
52 pub fn selection_vector(&self) -> &[bool] {
54 &self.selection_vector
55 }
56
57 pub fn into_parts(self) -> (Box<dyn EngineData>, Vec<bool>) {
59 (self.data, self.selection_vector)
60 }
61
62 pub fn with_all_rows_selected(data: Box<dyn EngineData>) -> Self {
67 Self {
68 data,
69 selection_vector: vec![],
70 }
71 }
72
73 pub fn apply_selection_vector(self) -> DeltaResult<Box<dyn EngineData>> {
76 self.data
77 .apply_selection_vector(self.selection_vector.clone())
78 }
79}
80
81impl HasSelectionVector for FilteredEngineData {
82 fn has_selected_rows(&self) -> bool {
84 if self.selection_vector.len() < self.data.len() {
86 return true;
87 }
88
89 self.selection_vector.contains(&true)
90 }
91}
92
93impl From<Box<dyn EngineData>> for FilteredEngineData {
94 fn from(data: Box<dyn EngineData>) -> Self {
106 Self::with_all_rows_selected(data)
107 }
108}
109
/// Read-only, index-based access to a flat array of strings.
///
/// [`ListItem`] and [`MapItem`] use this trait to read the elements that belong to a
/// single list or map row.
pub trait StringArrayAccessor {
    /// Number of entries in the array.
    fn len(&self) -> usize;

    /// `true` when the array holds no entries.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// The string stored at `index`.
    ///
    /// NOTE(review): behavior for an out-of-range or invalid `index` is up to the
    /// implementation; confirm before relying on it.
    fn value(&self, index: usize) -> &str;

    /// Whether the entry at `index` holds a value. `MapItem` treats `false` as "no value".
    fn is_valid(&self, index: usize) -> bool;
}
126
127pub struct ListItem<'a> {
131 values: &'a dyn StringArrayAccessor,
132 offsets: Range<usize>,
133}
134
135impl<'a> ListItem<'a> {
136 pub fn new(values: &'a dyn StringArrayAccessor, offsets: Range<usize>) -> ListItem<'a> {
137 ListItem { values, offsets }
138 }
139
140 pub fn len(&self) -> usize {
141 self.offsets.len()
142 }
143
144 pub fn is_empty(&self) -> bool {
145 self.offsets.is_empty()
146 }
147
148 pub fn get(&self, list_index: usize) -> String {
149 self.values
150 .value(self.offsets.start + list_index)
151 .to_string()
152 }
153
154 pub fn materialize(&self) -> Vec<String> {
155 self.offsets
156 .clone()
157 .map(|i| self.values.value(i).to_string())
158 .collect()
159 }
160}
161
162pub struct MapItem<'a> {
172 keys: &'a dyn StringArrayAccessor,
173 values: &'a dyn StringArrayAccessor,
174 offsets: Range<usize>,
175}
176
177impl<'a> MapItem<'a> {
178 pub fn new(
179 keys: &'a dyn StringArrayAccessor,
180 values: &'a dyn StringArrayAccessor,
181 offsets: Range<usize>,
182 ) -> MapItem<'a> {
183 MapItem {
184 keys,
185 values,
186 offsets,
187 }
188 }
189
190 pub fn get(&self, key: &str) -> Option<&'a str> {
191 let idx = self
192 .offsets
193 .clone()
194 .rev()
195 .find(|&idx| self.keys.value(idx) == key)?;
196 self.values.is_valid(idx).then(|| self.values.value(idx))
197 }
198
199 pub fn materialize(&self) -> HashMap<String, String> {
200 let mut ret = HashMap::with_capacity(self.offsets.len());
201 for idx in self.offsets.clone() {
202 if self.values.is_valid(idx) {
203 ret.insert(
204 self.keys.value(idx).to_string(),
205 self.values.value(idx).to_string(),
206 );
207 }
208 }
209 ret
210 }
211}
212
/// Generates default bodies for typed getter methods inside a trait with a lifetime `'a`.
///
/// Takes a comma-separated list of `(method_name, return_type)` pairs. Each generated
/// method ignores the row index, logs a debug message, and returns
/// `Error::UnexpectedColumnType` naming the field and the requested type.
macro_rules! impl_default_get {
    ( $( ($getter:ident, $ret:ty) ),* ) => {
        $(
            fn $getter(&'a self, _row_index: usize, field_name: &str) -> DeltaResult<Option<$ret>> {
                let requested = stringify!($ret);
                debug!("Asked for type {} on {field_name}, but using default error impl.", requested);
                let mismatch = Error::UnexpectedColumnType(format!("{field_name} is not of type {}", requested));
                Err(mismatch.with_backtrace())
            }
        )*
    };
}
223
224pub trait GetData<'a> {
230 impl_default_get!(
231 (get_bool, bool),
232 (get_byte, i8),
233 (get_short, i16),
234 (get_int, i32),
235 (get_long, i64),
236 (get_float, f32),
237 (get_double, f64),
238 (get_date, i32),
239 (get_timestamp, i64),
240 (get_decimal, i128),
241 (get_str, &'a str),
242 (get_binary, &'a [u8]),
243 (get_list, ListItem<'a>),
244 (get_map, MapItem<'a>)
245 );
246}
247
/// Generates getter methods that always report "no value".
///
/// Takes a comma-separated list of `(method_name, return_type)` pairs. Each generated
/// method ignores both arguments and returns `Ok(None)`.
macro_rules! impl_null_get {
    ( $( ($getter:ident, $ret:ty) ),* ) => {
        $(
            fn $getter(&'a self, _row_index: usize, _field_name: &str) -> DeltaResult<Option<$ret>> {
                Ok(None)
            }
        )*
    };
}
257
258impl<'a> GetData<'a> for () {
259 impl_null_get!(
260 (get_bool, bool),
261 (get_byte, i8),
262 (get_short, i16),
263 (get_int, i32),
264 (get_long, i64),
265 (get_float, f32),
266 (get_double, f64),
267 (get_date, i32),
268 (get_timestamp, i64),
269 (get_decimal, i128),
270 (get_str, &'a str),
271 (get_binary, &'a [u8]),
272 (get_list, ListItem<'a>),
273 (get_map, MapItem<'a>)
274 );
275}
276
277pub trait TypedGetData<'a, T> {
280 fn get_opt(&'a self, row_index: usize, field_name: &str) -> DeltaResult<Option<T>>;
281 fn get(&'a self, row_index: usize, field_name: &str) -> DeltaResult<T> {
282 let val = self.get_opt(row_index, field_name)?;
283 val.ok_or_else(|| {
284 Error::MissingData(format!("Data missing for field {field_name}")).with_backtrace()
285 })
286 }
287}
288
289macro_rules! impl_typed_get_data {
290 ( $(($name: ident, $typ: ty)), * ) => {
291 $(
292 impl<'a> TypedGetData<'a, $typ> for dyn GetData<'a> +'_ {
293 fn get_opt(&'a self, row_index: usize, field_name: &str) -> DeltaResult<Option<$typ>> {
294 self.$name(row_index, field_name)
295 }
296 }
297 )*
298 };
299}
300
301impl_typed_get_data!(
305 (get_bool, bool),
306 (get_byte, i8),
307 (get_short, i16),
308 (get_int, i32),
309 (get_long, i64),
310 (get_float, f32),
311 (get_double, f64),
312 (get_decimal, i128),
313 (get_str, &'a str),
314 (get_binary, &'a [u8]),
315 (get_list, ListItem<'a>),
316 (get_map, MapItem<'a>)
317);
318
319impl<'a> TypedGetData<'a, String> for dyn GetData<'a> + '_ {
320 fn get_opt(&'a self, row_index: usize, field_name: &str) -> DeltaResult<Option<String>> {
321 self.get_str(row_index, field_name)
322 .map(|s| s.map(|s| s.to_string()))
323 }
324}
325
326impl<'a> TypedGetData<'a, Vec<String>> for dyn GetData<'a> + '_ {
329 fn get_opt(&'a self, row_index: usize, field_name: &str) -> DeltaResult<Option<Vec<String>>> {
330 let list_opt: Option<ListItem<'_>> = self.get_opt(row_index, field_name)?;
331 Ok(list_opt.map(|list| list.materialize()))
332 }
333}
334
335impl<'a> TypedGetData<'a, HashMap<String, String>> for dyn GetData<'a> + '_ {
338 fn get_opt(
339 &'a self,
340 row_index: usize,
341 field_name: &str,
342 ) -> DeltaResult<Option<HashMap<String, String>>> {
343 let map_opt: Option<MapItem<'_>> = self.get_opt(row_index, field_name)?;
344 Ok(map_opt.map(|map| map.materialize()))
345 }
346}
347
/// Iterator over the indices of the selected rows of a batch.
///
/// Yields, in ascending order, every row index below `row_count` that is either marked
/// `true` in the selection vector or lies past the end of the selection vector.
pub struct RowIndexIterator<'sv> {
    /// Next row index to examine.
    sv_pos: usize,
    /// Per-row selection flags; may be shorter than `row_count`.
    selection_vector: &'sv [bool],
    /// Total number of rows in the batch, selected or not.
    row_count: usize,
}
359
360impl<'sv> RowIndexIterator<'sv> {
361 pub(crate) fn new(row_count: usize, selection_vector: &'sv [bool]) -> Self {
362 Self {
363 sv_pos: 0,
364 selection_vector,
365 row_count,
366 }
367 }
368
369 pub fn num_rows(&self) -> usize {
371 self.row_count
372 }
373}
374
375impl<'sv> Iterator for RowIndexIterator<'sv> {
376 type Item = usize;
377
378 fn next(&mut self) -> Option<usize> {
379 while self.sv_pos < self.row_count {
380 let pos = self.sv_pos;
381 self.sv_pos += 1;
382 if pos >= self.selection_vector.len() || self.selection_vector[pos] {
383 return Some(pos);
384 }
385 }
386 None
387 }
388}
389
390pub trait FilteredRowVisitor {
400 fn selected_column_names_and_types(&self) -> (&'static [ColumnName], &'static [DataType]);
401
402 fn visit_filtered<'a>(
407 &mut self,
408 getters: &[&'a dyn GetData<'a>],
409 rows: RowIndexIterator<'_>,
410 ) -> DeltaResult<()>;
411
412 fn visit_rows_of(&mut self, data: &FilteredEngineData) -> DeltaResult<()>
417 where
418 Self: Sized,
419 {
420 let column_names = self.selected_column_names_and_types().0;
422 let mut bridge = FilteredVisitorBridge {
423 visitor: self,
424 selection_vector: data.selection_vector(),
425 };
426 data.data().visit_rows(column_names, &mut bridge)
427 }
428}
429
430struct FilteredVisitorBridge<'bridge, V: FilteredRowVisitor> {
432 visitor: &'bridge mut V,
433 selection_vector: &'bridge [bool],
434}
435
436impl<V: FilteredRowVisitor> RowVisitor for FilteredVisitorBridge<'_, V> {
437 fn selected_column_names_and_types(&self) -> (&'static [ColumnName], &'static [DataType]) {
438 self.visitor.selected_column_names_and_types()
439 }
440
441 fn visit<'a>(&mut self, row_count: usize, getters: &[&'a dyn GetData<'a>]) -> DeltaResult<()> {
442 let rows = RowIndexIterator::new(row_count, self.selection_vector);
443 self.visitor.visit_filtered(getters, rows)
444 }
445}
446
447pub trait RowVisitor {
451 fn selected_column_names_and_types(&self) -> (&'static [ColumnName], &'static [DataType]);
457
458 fn visit<'a>(&mut self, row_count: usize, getters: &[&'a dyn GetData<'a>]) -> DeltaResult<()>;
465
466 fn visit_rows_of(&mut self, data: &dyn EngineData) -> DeltaResult<()>
470 where
471 Self: Sized,
472 {
473 data.visit_rows(self.selected_column_names_and_types().0, self)
474 }
475}
476
477pub trait EngineData: AsAny {
515 fn visit_rows(
519 &self,
520 column_names: &[ColumnName],
521 visitor: &mut dyn RowVisitor,
522 ) -> DeltaResult<()>;
523
524 fn len(&self) -> usize;
526
527 fn is_empty(&self) -> bool {
529 self.len() == 0
530 }
531
532 fn append_columns(
554 &self,
555 schema: SchemaRef,
556 columns: Vec<ArrayData>,
557 ) -> DeltaResult<Box<dyn EngineData>>;
558
559 fn apply_selection_vector(
563 self: Box<Self>,
564 selection_vector: Vec<bool>,
565 ) -> DeltaResult<Box<dyn EngineData>>;
566
567 fn has_field(&self, name: &ColumnName) -> bool;
572}
573
574pub(crate) fn filter_by_predicate(
577 filter: &dyn crate::PredicateEvaluator,
578 batch: Box<dyn EngineData>,
579) -> DeltaResult<Box<dyn EngineData>> {
580 let predicate_result = filter.evaluate(batch.as_ref())?;
581 let mut visitor = SelectionVectorVisitor::default();
582 visitor.visit_rows_of(predicate_result.as_ref())?;
583 require!(
584 visitor.selection_vector.len() == batch.len(),
585 Error::internal_error(format!(
586 "predicate output length {} != batch length {}",
587 visitor.selection_vector.len(),
588 batch.len()
589 ))
590 );
591 batch.apply_selection_vector(visitor.selection_vector)
592}
593
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rstest::rstest;

    use super::*;
    use crate::arrow::array::{RecordBatch, StringArray};
    use crate::arrow::datatypes::{
        DataType as ArrowDataType, Field as ArrowField, Schema as ArrowSchema,
    };
    use crate::engine::arrow_data::ArrowEngineData;

    /// Builds a batch with one nullable Utf8 column `value` holding `row0`, `row1`, ...
    fn get_engine_data(rows: usize) -> Box<dyn EngineData> {
        let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
            "value",
            ArrowDataType::Utf8,
            true,
        )]));
        let data: Vec<String> = (0..rows).map(|i| format!("row{i}")).collect();
        Box::new(ArrowEngineData::new(
            RecordBatch::try_new(schema, vec![Arc::new(StringArray::from(data))]).unwrap(),
        ))
    }

    #[test]
    fn test_with_all_rows_selected_empty_data() {
        let data = get_engine_data(0);
        let filtered_data = FilteredEngineData::with_all_rows_selected(data);

        assert_eq!(filtered_data.selection_vector().len(), 0);
        assert!(filtered_data.selection_vector().is_empty());
        assert_eq!(filtered_data.data().len(), 0);
    }

    #[test]
    fn test_with_all_rows_selected_single_row() {
        let data = get_engine_data(1);
        let filtered_data = FilteredEngineData::with_all_rows_selected(data);

        // An empty selection vector means every row is selected.
        assert!(filtered_data.selection_vector().is_empty());
        assert_eq!(filtered_data.data().len(), 1);
        assert!(filtered_data.has_selected_rows());
    }

    #[test]
    fn test_with_all_rows_selected_multiple_rows() {
        let data = get_engine_data(4);
        let filtered_data = FilteredEngineData::with_all_rows_selected(data);

        assert!(filtered_data.selection_vector().is_empty());
        assert_eq!(filtered_data.data().len(), 4);
        assert!(filtered_data.has_selected_rows());
    }

    #[test]
    fn test_has_selected_rows_empty_data() {
        let data = get_engine_data(0);
        let filtered_data = FilteredEngineData::try_new(data, vec![]).unwrap();

        // No rows at all, so nothing can be selected.
        assert!(!filtered_data.has_selected_rows());
    }

    #[test]
    fn test_has_selected_rows_selection_vector_shorter_than_data() {
        let data = get_engine_data(3);
        let filtered_data = FilteredEngineData::try_new(data, vec![false, false]).unwrap();

        // The third row is not covered by the vector and therefore counts as selected.
        assert!(filtered_data.has_selected_rows());
    }

    #[test]
    fn test_has_selected_rows_selection_vector_same_length_all_false() {
        let data = get_engine_data(2);
        let filtered_data = FilteredEngineData::try_new(data, vec![false, false]).unwrap();

        assert!(!filtered_data.has_selected_rows());
    }

    #[test]
    fn test_has_selected_rows_selection_vector_same_length_some_true() {
        let data = get_engine_data(3);
        let filtered_data = FilteredEngineData::try_new(data, vec![true, false, true]).unwrap();

        assert!(filtered_data.has_selected_rows());
    }

    #[test]
    fn test_try_new_selection_vector_larger_than_data() {
        let data = get_engine_data(2);
        let result = FilteredEngineData::try_new(data, vec![true, false, true]);

        // `FilteredEngineData` is not `Debug`, so go through `Option` to get the error.
        let message = result
            .err()
            .expect("a selection vector longer than the data must be rejected")
            .to_string();
        assert!(message.contains("Selection vector is larger than data length"));
        assert!(message.contains("3 > 2"));
    }

    #[test]
    fn test_get_binary_some_value() {
        use crate::arrow::array::BinaryArray;

        let binary_data: Vec<Option<&[u8]>> = vec![Some(b"hello"), Some(b"world"), None];
        let binary_array = BinaryArray::from(binary_data);

        let getter: &dyn GetData<'_> = &binary_array;

        let result: Option<&[u8]> = getter.get_opt(0, "binary_field").unwrap();
        assert_eq!(result, Some(b"hello".as_ref()));

        let result: Option<&[u8]> = getter.get_opt(1, "binary_field").unwrap();
        assert_eq!(result, Some(b"world".as_ref()));

        // A null entry reads back as `None`.
        let result: Option<&[u8]> = getter.get_opt(2, "binary_field").unwrap();
        assert_eq!(result, None);
    }

    #[test]
    fn test_get_binary_required() {
        use crate::arrow::array::BinaryArray;

        let binary_data: Vec<Option<&[u8]>> = vec![Some(b"hello")];
        let binary_array = BinaryArray::from(binary_data);

        let getter: &dyn GetData<'_> = &binary_array;

        let result: &[u8] = getter.get(0, "binary_field").unwrap();
        assert_eq!(result, b"hello");
    }

    #[test]
    fn test_get_binary_required_missing() {
        use crate::arrow::array::BinaryArray;

        let binary_data: Vec<Option<&[u8]>> = vec![None];
        let binary_array = BinaryArray::from(binary_data);

        let getter: &dyn GetData<'_> = &binary_array;

        let result: DeltaResult<&[u8]> = getter.get(0, "binary_field");
        let message = result
            .expect_err("a required read of a null entry must fail")
            .to_string();
        assert!(message.contains("Data missing for field"));
    }

    #[test]
    fn test_get_binary_empty_bytes() {
        use crate::arrow::array::BinaryArray;

        let binary_data: Vec<Option<&[u8]>> = vec![Some(b"")];
        let binary_array = BinaryArray::from(binary_data);

        let getter: &dyn GetData<'_> = &binary_array;

        // An empty byte string is a value, not a null.
        let result: Option<&[u8]> = getter.get_opt(0, "binary_field").unwrap();
        assert_eq!(result, Some([].as_ref()));
        assert_eq!(result.unwrap().len(), 0);
    }

    #[test]
    fn test_from_engine_data() {
        let data = get_engine_data(3);
        let data_len = data.len();
        let filtered_data: FilteredEngineData = data.into();

        assert!(filtered_data.selection_vector().is_empty());
        assert_eq!(filtered_data.data().len(), data_len);
        assert_eq!(filtered_data.data().len(), 3);
        assert!(filtered_data.has_selected_rows());
    }

    #[test]
    fn filtered_apply_selection_vector_full() {
        let data = get_engine_data(4);
        let filtered = FilteredEngineData::try_new(data, vec![true, false, true, false]).unwrap();
        let data = filtered.apply_selection_vector().unwrap();
        assert_eq!(data.len(), 2);
    }

    #[test]
    fn filtered_apply_selection_vector_partial() {
        let data = get_engine_data(4);
        let filtered = FilteredEngineData::try_new(data, vec![true, false]).unwrap();
        let data = filtered.apply_selection_vector().unwrap();
        // Rows 2 and 3 are not covered by the vector and are kept.
        assert_eq!(data.len(), 3);
    }

    /// Collects every index produced by a `RowIndexIterator` over the given inputs.
    fn collect_indices(row_count: usize, selection: &[bool]) -> Vec<usize> {
        RowIndexIterator::new(row_count, selection).collect()
    }

    #[rstest]
    #[case(0, &[], vec![])]
    #[case(3, &[], vec![0, 1, 2])]
    #[case(3, &[true, true, true], vec![0, 1, 2])]
    #[case(3, &[false, false, false], vec![])]
    #[case(5, &[true, false, false, true, true], vec![0, 3, 4])]
    #[case(4, &[false, false, true, true], vec![2, 3])]
    #[case(3, &[true, false, false], vec![0])]
    #[case(4, &[false, true], vec![1, 2, 3])]
    #[case(4, &[true, false], vec![0, 2, 3])]
    #[case(4, &[false, true, false, true], vec![1, 3])]
    fn row_index_iter(
        #[case] row_count: usize,
        #[case] selection: &[bool],
        #[case] expected: Vec<usize>,
    ) {
        assert_eq!(collect_indices(row_count, selection), expected);
    }
}