1use std::borrow::Borrow;
23use std::collections::HashMap;
24use std::hash::Hash;
25use std::sync::Arc;
26
27use arrow::array::RecordBatch;
28use arrow::compute::can_cast_types;
29use arrow::datatypes::{DataType, Field, SchemaRef};
30use datafusion_common::{
31 Result, ScalarValue, exec_err,
32 nested_struct::validate_struct_compatibility,
33 tree_node::{Transformed, TransformedResult, TreeNode},
34};
35use datafusion_functions::core::getfield::GetFieldFunc;
36use datafusion_physical_expr::PhysicalExprSimplifier;
37use datafusion_physical_expr::expressions::CastColumnExpr;
38use datafusion_physical_expr::projection::{ProjectionExprs, Projector};
39use datafusion_physical_expr::{
40 ScalarFunctionExpr,
41 expressions::{self, Column},
42};
43use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
44use itertools::Itertools;
45
46pub fn replace_columns_with_literals<K, V>(
65 expr: Arc<dyn PhysicalExpr>,
66 replacements: &HashMap<K, V>,
67) -> Result<Arc<dyn PhysicalExpr>>
68where
69 K: Borrow<str> + Eq + Hash,
70 V: Borrow<ScalarValue>,
71{
72 expr.transform_down(|expr| {
73 if let Some(column) = expr.as_any().downcast_ref::<Column>()
74 && let Some(replacement_value) = replacements.get(column.name())
75 {
76 return Ok(Transformed::yes(expressions::lit(
77 replacement_value.borrow().clone(),
78 )));
79 }
80 Ok(Transformed::no(expr))
81 })
82 .data()
83}
84
/// Rewrites a [`PhysicalExpr`] so it can be evaluated against a physical
/// (file) schema that may differ from the logical (table) schema it was
/// originally built for.
pub trait PhysicalExprAdapter: Send + Sync + std::fmt::Debug {
    /// Returns an expression equivalent to `expr` that is evaluable against
    /// the physical file schema (columns re-resolved, casts inserted, etc.).
    fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>>;
}
171
/// Creates [`PhysicalExprAdapter`] instances for a given pair of logical
/// and physical file schemas.
pub trait PhysicalExprAdapterFactory: Send + Sync + std::fmt::Debug {
    /// Builds an adapter that rewrites expressions written against
    /// `logical_file_schema` into expressions evaluable against
    /// `physical_file_schema`.
    fn create(
        &self,
        logical_file_schema: SchemaRef,
        physical_file_schema: SchemaRef,
    ) -> Result<Arc<dyn PhysicalExprAdapter>>;
}
183
/// Factory producing [`DefaultPhysicalExprAdapter`] instances.
#[derive(Debug, Clone)]
pub struct DefaultPhysicalExprAdapterFactory;
186
187impl PhysicalExprAdapterFactory for DefaultPhysicalExprAdapterFactory {
188 fn create(
189 &self,
190 logical_file_schema: SchemaRef,
191 physical_file_schema: SchemaRef,
192 ) -> Result<Arc<dyn PhysicalExprAdapter>> {
193 Ok(Arc::new(DefaultPhysicalExprAdapter {
194 logical_file_schema,
195 physical_file_schema,
196 }))
197 }
198}
199
/// Default [`PhysicalExprAdapter`]: re-resolves column indices against the
/// physical schema, inserts casts where logical and physical types differ,
/// and substitutes typed NULL literals for nullable columns (or struct
/// fields) missing from the physical file.
#[derive(Debug, Clone)]
pub struct DefaultPhysicalExprAdapter {
    // Schema the incoming expressions were written against (table view).
    logical_file_schema: SchemaRef,
    // Schema of the actual file the expressions will be evaluated on.
    physical_file_schema: SchemaRef,
}
246
247impl DefaultPhysicalExprAdapter {
248 pub fn new(logical_file_schema: SchemaRef, physical_file_schema: SchemaRef) -> Self {
253 Self {
254 logical_file_schema,
255 physical_file_schema,
256 }
257 }
258}
259
260impl PhysicalExprAdapter for DefaultPhysicalExprAdapter {
261 fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
262 let rewriter = DefaultPhysicalExprAdapterRewriter {
263 logical_file_schema: Arc::clone(&self.logical_file_schema),
264 physical_file_schema: Arc::clone(&self.physical_file_schema),
265 };
266 expr.transform(|expr| rewriter.rewrite_expr(Arc::clone(&expr)))
267 .data()
268 }
269}
270
/// Internal worker holding the schema pair while rewriting expression nodes.
struct DefaultPhysicalExprAdapterRewriter {
    // Schema the expressions were originally written against.
    logical_file_schema: SchemaRef,
    // Schema of the file the rewritten expressions will run on.
    physical_file_schema: SchemaRef,
}
275
276impl DefaultPhysicalExprAdapterRewriter {
277 fn rewrite_expr(
278 &self,
279 expr: Arc<dyn PhysicalExpr>,
280 ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
281 if let Some(transformed) = self.try_rewrite_struct_field_access(&expr)? {
282 return Ok(Transformed::yes(transformed));
283 }
284
285 if let Some(column) = expr.as_any().downcast_ref::<Column>() {
286 return self.rewrite_column(Arc::clone(&expr), column);
287 }
288
289 Ok(Transformed::no(expr))
290 }
291
292 fn try_rewrite_struct_field_access(
296 &self,
297 expr: &Arc<dyn PhysicalExpr>,
298 ) -> Result<Option<Arc<dyn PhysicalExpr>>> {
299 let get_field_expr =
300 match ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(expr.as_ref()) {
301 Some(expr) => expr,
302 None => return Ok(None),
303 };
304
305 let source_expr = match get_field_expr.args().first() {
306 Some(expr) => expr,
307 None => return Ok(None),
308 };
309
310 let field_name_expr = match get_field_expr.args().get(1) {
311 Some(expr) => expr,
312 None => return Ok(None),
313 };
314
315 let lit = match field_name_expr
316 .as_any()
317 .downcast_ref::<expressions::Literal>()
318 {
319 Some(lit) => lit,
320 None => return Ok(None),
321 };
322
323 let field_name = match lit.value().try_as_str().flatten() {
324 Some(name) => name,
325 None => return Ok(None),
326 };
327
328 let column = match source_expr.as_any().downcast_ref::<Column>() {
329 Some(column) => column,
330 None => return Ok(None),
331 };
332
333 let physical_field =
334 match self.physical_file_schema.field_with_name(column.name()) {
335 Ok(field) => field,
336 Err(_) => return Ok(None),
337 };
338
339 let physical_struct_fields = match physical_field.data_type() {
340 DataType::Struct(fields) => fields,
341 _ => return Ok(None),
342 };
343
344 if physical_struct_fields
345 .iter()
346 .any(|f| f.name() == field_name)
347 {
348 return Ok(None);
349 }
350
351 let logical_field = match self.logical_file_schema.field_with_name(column.name())
352 {
353 Ok(field) => field,
354 Err(_) => return Ok(None),
355 };
356
357 let logical_struct_fields = match logical_field.data_type() {
358 DataType::Struct(fields) => fields,
359 _ => return Ok(None),
360 };
361
362 let logical_struct_field = match logical_struct_fields
363 .iter()
364 .find(|f| f.name() == field_name)
365 {
366 Some(field) => field,
367 None => return Ok(None),
368 };
369
370 let null_value = ScalarValue::Null.cast_to(logical_struct_field.data_type())?;
371 Ok(Some(expressions::lit(null_value)))
372 }
373
374 fn rewrite_column(
375 &self,
376 expr: Arc<dyn PhysicalExpr>,
377 column: &Column,
378 ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
379 let logical_field = match self.logical_file_schema.field_with_name(column.name())
381 {
382 Ok(field) => field,
383 Err(e) => {
384 if let Ok(physical_field) =
388 self.physical_file_schema.field_with_name(column.name())
389 {
390 physical_field
394 } else {
395 return Err(e.into());
399 }
400 }
401 };
402
403 let physical_column_index = match self
405 .physical_file_schema
406 .index_of(column.name())
407 {
408 Ok(index) => index,
409 Err(_) => {
410 if !logical_field.is_nullable() {
411 return exec_err!(
412 "Non-nullable column '{}' is missing from the physical schema",
413 column.name()
414 );
415 }
416 let null_value = ScalarValue::Null.cast_to(logical_field.data_type())?;
419 return Ok(Transformed::yes(expressions::lit(null_value)));
420 }
421 };
422 let physical_field = self.physical_file_schema.field(physical_column_index);
423
424 if column.index() == physical_column_index
425 && logical_field.data_type() == physical_field.data_type()
426 {
427 return Ok(Transformed::no(expr));
428 }
429
430 let column = self.resolve_column(column, physical_column_index)?;
431
432 if logical_field.data_type() == physical_field.data_type() {
433 return Ok(Transformed::yes(Arc::new(column)));
435 }
436
437 self.create_cast_column_expr(column, logical_field)
442 }
443
444 fn resolve_column(
450 &self,
451 column: &Column,
452 physical_column_index: usize,
453 ) -> Result<Column> {
454 if column.index() == physical_column_index {
455 Ok(column.clone())
456 } else {
457 Column::new_with_schema(column.name(), self.physical_file_schema.as_ref())
458 }
459 }
460
461 fn create_cast_column_expr(
467 &self,
468 column: Column,
469 logical_field: &Field,
470 ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
471 let physical_column_index = self.physical_file_schema.index_of(column.name())?;
473 let actual_physical_field =
474 self.physical_file_schema.field(physical_column_index);
475
476 match (actual_physical_field.data_type(), logical_field.data_type()) {
482 (DataType::Struct(physical_fields), DataType::Struct(logical_fields)) => {
483 validate_struct_compatibility(
484 physical_fields.as_ref(),
485 logical_fields.as_ref(),
486 )?;
487 }
488 _ => {
489 let is_compatible = can_cast_types(
490 actual_physical_field.data_type(),
491 logical_field.data_type(),
492 );
493 if !is_compatible {
494 return exec_err!(
495 "Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type)",
496 column.name(),
497 actual_physical_field.data_type(),
498 logical_field.data_type()
499 );
500 }
501 }
502 }
503
504 let cast_expr = Arc::new(CastColumnExpr::new(
505 Arc::new(column),
506 Arc::new(actual_physical_field.clone()),
507 Arc::new(logical_field.clone()),
508 None,
509 ));
510
511 Ok(Transformed::yes(cast_expr))
512 }
513}
514
/// Builds [`BatchAdapter`]s that map record batches from arbitrary source
/// schemas into a single target schema.
#[derive(Debug)]
pub struct BatchAdapterFactory {
    // Schema every adapted batch should conform to.
    target_schema: SchemaRef,
    // Factory used to rewrite per-column expressions for each source schema.
    expr_adapter_factory: Arc<dyn PhysicalExprAdapterFactory>,
}
575
576impl BatchAdapterFactory {
577 pub fn new(target_schema: SchemaRef) -> Self {
579 let expr_adapter_factory = Arc::new(DefaultPhysicalExprAdapterFactory);
580 Self {
581 target_schema,
582 expr_adapter_factory,
583 }
584 }
585
586 pub fn with_adapter_factory(
593 self,
594 factory: Arc<dyn PhysicalExprAdapterFactory>,
595 ) -> Self {
596 Self {
597 expr_adapter_factory: factory,
598 ..self
599 }
600 }
601
602 pub fn make_adapter(&self, source_schema: &SchemaRef) -> Result<BatchAdapter> {
607 let expr_adapter = self
608 .expr_adapter_factory
609 .create(Arc::clone(&self.target_schema), Arc::clone(source_schema))?;
610
611 let simplifier = PhysicalExprSimplifier::new(&self.target_schema);
612
613 let projection = ProjectionExprs::from_indices(
614 &(0..self.target_schema.fields().len()).collect_vec(),
615 &self.target_schema,
616 );
617
618 let adapted = projection
619 .try_map_exprs(|e| simplifier.simplify(expr_adapter.rewrite(e)?))?;
620 let projector = adapted.make_projector(source_schema)?;
621
622 Ok(BatchAdapter { projector })
623 }
624}
625
/// Projects record batches from a source schema into a target schema;
/// produced by [`BatchAdapterFactory::make_adapter`].
#[derive(Debug)]
pub struct BatchAdapter {
    // Pre-built projector holding the rewritten per-column expressions.
    projector: Projector,
}
639
impl BatchAdapter {
    /// Projects `batch` (which must match the source schema this adapter was
    /// built for) into the target schema.
    pub fn adapt_batch(&self, batch: &RecordBatch) -> Result<RecordBatch> {
        self.projector.project_batch(batch)
    }
}
649
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{
        BooleanArray, Int32Array, Int64Array, RecordBatch, RecordBatchOptions,
        StringArray, StringViewArray, StructArray,
    };
    use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef};
    use datafusion_common::{Result, ScalarValue, assert_contains, record_batch};
    use datafusion_expr::Operator;
    use datafusion_physical_expr::expressions::{Column, Literal, col, lit};
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
    use itertools::Itertools;
    use std::sync::Arc;

    // Schema pair used by several tests: the physical file has
    // (a: Int32, b: Utf8); the logical schema widens `a` to Int64 and adds a
    // nullable `c: Float64` that is missing from the physical file.
    fn create_test_schema() -> (Schema, Schema) {
        let physical_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, true),
        ]);

        let logical_schema = Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Utf8, true),
            Field::new("c", DataType::Float64, true),
        ]);

        (physical_schema, logical_schema)
    }

    // A column whose logical type differs from its physical type is wrapped
    // in a CastColumnExpr.
    #[test]
    fn test_rewrite_column_with_type_cast() {
        let (physical_schema, logical_schema) = create_test_schema();

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory
            .create(Arc::new(logical_schema), Arc::new(physical_schema))
            .unwrap();
        let column_expr = Arc::new(Column::new("a", 0));

        let result = adapter.rewrite(column_expr).unwrap();

        assert!(result.as_any().downcast_ref::<CastColumnExpr>().is_some());
    }

    // A compound expression gets all of its columns rewritten: `a` becomes a
    // cast, and the missing nullable `c` becomes a NULL literal.
    #[test]
    fn test_rewrite_multi_column_expr_with_type_cast() {
        let (physical_schema, logical_schema) = create_test_schema();
        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory
            .create(Arc::new(logical_schema), Arc::new(physical_schema))
            .unwrap();

        let column_a = Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>;
        let column_c = Arc::new(Column::new("c", 2)) as Arc<dyn PhysicalExpr>;
        let expr = expressions::BinaryExpr::new(
            Arc::clone(&column_a),
            Operator::Plus,
            Arc::new(Literal::new(ScalarValue::Int64(Some(5)))),
        );
        let expr = expressions::BinaryExpr::new(
            Arc::new(expr),
            Operator::Or,
            Arc::new(expressions::BinaryExpr::new(
                Arc::clone(&column_c),
                Operator::Gt,
                Arc::new(Literal::new(ScalarValue::Float64(Some(0.0)))),
            )),
        );

        let result = adapter.rewrite(Arc::new(expr)).unwrap();
        println!("Rewritten expression: {result}");

        // Expected: `a` cast Int32 -> Int64; `c` replaced by NULL Float64.
        let expected = expressions::BinaryExpr::new(
            Arc::new(CastColumnExpr::new(
                Arc::new(Column::new("a", 0)),
                Arc::new(Field::new("a", DataType::Int32, false)),
                Arc::new(Field::new("a", DataType::Int64, false)),
                None,
            )),
            Operator::Plus,
            Arc::new(Literal::new(ScalarValue::Int64(Some(5)))),
        );
        let expected = Arc::new(expressions::BinaryExpr::new(
            Arc::new(expected),
            Operator::Or,
            Arc::new(expressions::BinaryExpr::new(
                lit(ScalarValue::Float64(None)),
                Operator::Gt,
                Arc::new(Literal::new(ScalarValue::Float64(Some(0.0)))),
            )),
        )) as Arc<dyn PhysicalExpr>;

        assert_eq!(
            result.to_string(),
            expected.to_string(),
            "The rewritten expression did not match the expected output"
        );
    }

    // Incompatible struct-field types (Binary -> Int32) must be rejected.
    #[test]
    fn test_rewrite_struct_column_incompatible() {
        let physical_schema = Schema::new(vec![Field::new(
            "data",
            DataType::Struct(vec![Field::new("field1", DataType::Binary, true)].into()),
            true,
        )]);

        let logical_schema = Schema::new(vec![Field::new(
            "data",
            DataType::Struct(vec![Field::new("field1", DataType::Int32, true)].into()),
            true,
        )]);

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory
            .create(Arc::new(logical_schema), Arc::new(physical_schema))
            .unwrap();
        let column_expr = Arc::new(Column::new("data", 0));

        let error_msg = adapter.rewrite(column_expr).unwrap_err().to_string();
        assert_contains!(
            error_msg,
            "Cannot cast struct field 'field1' from type Binary to type Int32"
        );
    }

    // Compatible struct casts (field widening / string view conversion)
    // produce a CastColumnExpr over the whole struct column.
    #[test]
    fn test_rewrite_struct_compatible_cast() {
        let physical_schema = Schema::new(vec![Field::new(
            "data",
            DataType::Struct(
                vec![
                    Field::new("id", DataType::Int32, false),
                    Field::new("name", DataType::Utf8, true),
                ]
                .into(),
            ),
            false,
        )]);

        let logical_schema = Schema::new(vec![Field::new(
            "data",
            DataType::Struct(
                vec![
                    Field::new("id", DataType::Int64, false),
                    Field::new("name", DataType::Utf8View, true),
                ]
                .into(),
            ),
            false,
        )]);

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory
            .create(Arc::new(logical_schema), Arc::new(physical_schema))
            .unwrap();
        let column_expr = Arc::new(Column::new("data", 0));

        let result = adapter.rewrite(column_expr).unwrap();

        let physical_struct_fields: Fields = vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
        ]
        .into();
        let physical_field = Arc::new(Field::new(
            "data",
            DataType::Struct(physical_struct_fields),
            false,
        ));

        let logical_struct_fields: Fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8View, true),
        ]
        .into();
        let logical_field = Arc::new(Field::new(
            "data",
            DataType::Struct(logical_struct_fields),
            false,
        ));

        let expected = Arc::new(CastColumnExpr::new(
            Arc::new(Column::new("data", 0)),
            physical_field,
            logical_field,
            None,
        )) as Arc<dyn PhysicalExpr>;

        assert_eq!(result.to_string(), expected.to_string());
    }

    // A nullable column missing from the file becomes a typed NULL literal.
    #[test]
    fn test_rewrite_missing_column() -> Result<()> {
        let (physical_schema, logical_schema) = create_test_schema();

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory
            .create(Arc::new(logical_schema), Arc::new(physical_schema))
            .unwrap();
        let column_expr = Arc::new(Column::new("c", 2));

        let result = adapter.rewrite(column_expr)?;

        if let Some(literal) = result.as_any().downcast_ref::<Literal>() {
            assert_eq!(*literal.value(), ScalarValue::Float64(None));
        } else {
            panic!("Expected literal expression");
        }

        Ok(())
    }

    // A NON-nullable column missing from the file is an error.
    #[test]
    fn test_rewrite_missing_column_non_nullable_error() {
        let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let logical_schema = Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Utf8, false),
        ]);

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory
            .create(Arc::new(logical_schema), Arc::new(physical_schema))
            .unwrap();
        let column_expr = Arc::new(Column::new("b", 1));

        let error_msg = adapter.rewrite(column_expr).unwrap_err().to_string();
        assert_contains!(error_msg, "Non-nullable column 'b' is missing");
    }

    // A nullable missing column is rewritten to a NULL of its logical type.
    #[test]
    fn test_rewrite_missing_column_nullable() {
        let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let logical_schema = Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Utf8, true),
        ]);

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory
            .create(Arc::new(logical_schema), Arc::new(physical_schema))
            .unwrap();
        let column_expr = Arc::new(Column::new("b", 1));

        let result = adapter.rewrite(column_expr).unwrap();

        let expected =
            Arc::new(Literal::new(ScalarValue::Utf8(None))) as Arc<dyn PhysicalExpr>;

        assert_eq!(result.to_string(), expected.to_string());
    }

    // replace_columns_with_literals swaps a matching column for its literal.
    #[test]
    fn test_replace_columns_with_literals() -> Result<()> {
        let partition_value = ScalarValue::Utf8(Some("test_value".to_string()));
        let replacements = HashMap::from([("partition_col", &partition_value)]);

        let column_expr =
            Arc::new(Column::new("partition_col", 0)) as Arc<dyn PhysicalExpr>;
        let result = replace_columns_with_literals(column_expr, &replacements)?;

        let literal = result
            .as_any()
            .downcast_ref::<Literal>()
            .expect("Expected literal expression");
        assert_eq!(*literal.value(), partition_value);

        Ok(())
    }

    // Columns not present in the replacement map are left untouched.
    #[test]
    fn test_replace_columns_with_literals_no_match() -> Result<()> {
        let value = ScalarValue::Utf8(Some("test_value".to_string()));
        let replacements = HashMap::from([("other_col", &value)]);

        let column_expr =
            Arc::new(Column::new("partition_col", 0)) as Arc<dyn PhysicalExpr>;
        let result = replace_columns_with_literals(column_expr, &replacements)?;

        assert!(result.as_any().downcast_ref::<Column>().is_some());
        Ok(())
    }

    // Replacement also reaches columns nested inside compound expressions.
    #[test]
    fn test_replace_columns_with_literals_nested_expr() -> Result<()> {
        let value_a = ScalarValue::Int64(Some(10));
        let value_b = ScalarValue::Int64(Some(20));
        let replacements = HashMap::from([("a", &value_a), ("b", &value_b)]);

        let expr = Arc::new(expressions::BinaryExpr::new(
            Arc::new(Column::new("a", 0)),
            Operator::Plus,
            Arc::new(Column::new("b", 1)),
        )) as Arc<dyn PhysicalExpr>;

        let result = replace_columns_with_literals(expr, &replacements)?;
        assert_eq!(result.to_string(), "10 + 20");

        Ok(())
    }

    // When index and type already match, the original Arc is returned as-is
    // (pointer equality, not just structural equality).
    #[test]
    fn test_rewrite_no_change_needed() -> Result<()> {
        let (physical_schema, logical_schema) = create_test_schema();

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory
            .create(Arc::new(logical_schema), Arc::new(physical_schema))
            .unwrap();
        let column_expr = Arc::new(Column::new("b", 1)) as Arc<dyn PhysicalExpr>;

        let result = adapter.rewrite(Arc::clone(&column_expr))?;

        assert!(std::ptr::eq(
            column_expr.as_ref() as *const dyn PhysicalExpr,
            result.as_ref() as *const dyn PhysicalExpr
        ));

        Ok(())
    }

    // Full error message check for a missing non-nullable column.
    #[test]
    fn test_non_nullable_missing_column_error() {
        let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let logical_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
        ]);

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory
            .create(Arc::new(logical_schema), Arc::new(physical_schema))
            .unwrap();
        let column_expr = Arc::new(Column::new("b", 1));

        let result = adapter.rewrite(column_expr);
        assert!(result.is_err());
        assert_contains!(
            result.unwrap_err().to_string(),
            "Non-nullable column 'b' is missing from the physical schema"
        );
    }

    // Test helper: evaluates a list of expressions against `batch` and
    // assembles the results into a RecordBatch with the given schema.
    fn batch_project(
        expr: Vec<Arc<dyn PhysicalExpr>>,
        batch: &RecordBatch,
        schema: SchemaRef,
    ) -> Result<RecordBatch> {
        let arrays = expr
            .iter()
            .map(|expr| {
                expr.evaluate(batch)
                    .and_then(|v| v.into_array(batch.num_rows()))
            })
            .collect::<Result<Vec<_>>>()?;

        if arrays.is_empty() {
            // Zero-column projection still needs an explicit row count.
            let options =
                RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
            RecordBatch::try_new_with_options(Arc::clone(&schema), arrays, &options)
                .map_err(Into::into)
        } else {
            RecordBatch::try_new(Arc::clone(&schema), arrays).map_err(Into::into)
        }
    }

    // End-to-end: rewritten projections evaluate against a real batch,
    // casting `a` and filling the missing `b` with NULLs.
    #[test]
    fn test_adapt_batches() {
        let physical_batch = record_batch!(
            ("a", Int32, vec![Some(1), None, Some(3)]),
            ("extra", Utf8, vec![Some("x"), Some("y"), None])
        )
        .unwrap();

        let physical_schema = physical_batch.schema();

        let logical_schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, true),
            Field::new("b", DataType::Utf8, true),
        ]));

        let projection = vec![
            col("b", &logical_schema).unwrap(),
            col("a", &logical_schema).unwrap(),
        ];

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory
            .create(Arc::clone(&logical_schema), Arc::clone(&physical_schema))
            .unwrap();

        let adapted_projection = projection
            .into_iter()
            .map(|expr| adapter.rewrite(expr).unwrap())
            .collect_vec();

        let adapted_schema = Arc::new(Schema::new(
            adapted_projection
                .iter()
                .map(|expr| expr.return_field(&physical_schema).unwrap())
                .collect_vec(),
        ));

        let res = batch_project(
            adapted_projection,
            &physical_batch,
            Arc::clone(&adapted_schema),
        )
        .unwrap();

        assert_eq!(res.num_columns(), 2);
        assert_eq!(res.column(0).data_type(), &DataType::Utf8);
        assert_eq!(res.column(1).data_type(), &DataType::Int64);
        assert_eq!(
            res.column(0)
                .as_any()
                .downcast_ref::<StringArray>()
                .unwrap()
                .iter()
                .collect_vec(),
            vec![None, None, None]
        );
        assert_eq!(
            res.column(1)
                .as_any()
                .downcast_ref::<Int64Array>()
                .unwrap()
                .iter()
                .collect_vec(),
            vec![Some(1), None, Some(3)]
        );
    }

    // End-to-end for struct columns: child fields are cast and a child field
    // missing from the file is filled with NULLs.
    #[test]
    fn test_adapt_struct_batches() {
        let physical_struct_fields: Fields = vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
        ]
        .into();

        let struct_array = StructArray::new(
            physical_struct_fields.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])) as _,
                Arc::new(StringArray::from(vec![
                    Some("alice"),
                    None,
                    Some("charlie"),
                ])) as _,
            ],
            None,
        );

        let physical_schema = Arc::new(Schema::new(vec![Field::new(
            "data",
            DataType::Struct(physical_struct_fields),
            false,
        )]));

        let physical_batch = RecordBatch::try_new(
            Arc::clone(&physical_schema),
            vec![Arc::new(struct_array)],
        )
        .unwrap();

        let logical_struct_fields: Fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8View, true),
            Field::new("extra", DataType::Boolean, true),
        ]
        .into();

        let logical_schema = Arc::new(Schema::new(vec![Field::new(
            "data",
            DataType::Struct(logical_struct_fields),
            false,
        )]));

        let projection = vec![col("data", &logical_schema).unwrap()];

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory
            .create(Arc::clone(&logical_schema), Arc::clone(&physical_schema))
            .unwrap();

        let adapted_projection = projection
            .into_iter()
            .map(|expr| adapter.rewrite(expr).unwrap())
            .collect_vec();

        let adapted_schema = Arc::new(Schema::new(
            adapted_projection
                .iter()
                .map(|expr| expr.return_field(&physical_schema).unwrap())
                .collect_vec(),
        ));

        let res = batch_project(
            adapted_projection,
            &physical_batch,
            Arc::clone(&adapted_schema),
        )
        .unwrap();

        assert_eq!(res.num_columns(), 1);

        let result_struct = res
            .column(0)
            .as_any()
            .downcast_ref::<StructArray>()
            .unwrap();

        // `id` cast Int32 -> Int64 with values preserved.
        let id_col = result_struct.column_by_name("id").unwrap();
        assert_eq!(id_col.data_type(), &DataType::Int64);
        let id_values = id_col.as_any().downcast_ref::<Int64Array>().unwrap();
        assert_eq!(
            id_values.iter().collect_vec(),
            vec![Some(1), Some(2), Some(3)]
        );

        // `name` cast Utf8 -> Utf8View with values preserved.
        let name_col = result_struct.column_by_name("name").unwrap();
        assert_eq!(name_col.data_type(), &DataType::Utf8View);
        let name_values = name_col.as_any().downcast_ref::<StringViewArray>().unwrap();
        assert_eq!(
            name_values.iter().collect_vec(),
            vec![Some("alice"), None, Some("charlie")]
        );

        // `extra` was absent from the file: all NULLs.
        let extra_col = result_struct.column_by_name("extra").unwrap();
        assert_eq!(extra_col.data_type(), &DataType::Boolean);
        let extra_values = extra_col.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert_eq!(extra_values.iter().collect_vec(), vec![None, None, None]);
    }

    // A plain column (no get_field call) is not handled by the struct-field
    // rewrite path and yields None.
    #[test]
    fn test_try_rewrite_struct_field_access() {
        let physical_schema = Schema::new(vec![Field::new(
            "struct_col",
            DataType::Struct(
                vec![Field::new("existing_field", DataType::Int32, true)].into(),
            ),
            true,
        )]);

        let logical_schema = Schema::new(vec![Field::new(
            "struct_col",
            DataType::Struct(
                vec![
                    Field::new("existing_field", DataType::Int32, true),
                    Field::new("missing_field", DataType::Utf8, true),
                ]
                .into(),
            ),
            true,
        )]);

        let rewriter = DefaultPhysicalExprAdapterRewriter {
            logical_file_schema: Arc::new(logical_schema),
            physical_file_schema: Arc::new(physical_schema),
        };

        let column = Arc::new(Column::new("struct_col", 0)) as Arc<dyn PhysicalExpr>;
        let result = rewriter.try_rewrite_struct_field_access(&column).unwrap();
        assert!(result.is_none());

    }

    // BatchAdapter reorders columns and casts values to the target schema.
    #[test]
    fn test_batch_adapter_factory_basic() {
        let target_schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Utf8, true),
        ]));

        let source_schema = Arc::new(Schema::new(vec![
            Field::new("b", DataType::Utf8, true),
            Field::new("a", DataType::Int32, false),
        ]));

        let factory = BatchAdapterFactory::new(Arc::clone(&target_schema));
        let adapter = factory.make_adapter(&source_schema).unwrap();

        let source_batch = RecordBatch::try_new(
            Arc::clone(&source_schema),
            vec![
                Arc::new(StringArray::from(vec![Some("hello"), None, Some("world")])),
                Arc::new(Int32Array::from(vec![1, 2, 3])),
            ],
        )
        .unwrap();

        let adapted = adapter.adapt_batch(&source_batch).unwrap();

        assert_eq!(adapted.num_columns(), 2);
        assert_eq!(adapted.schema().field(0).name(), "a");
        assert_eq!(adapted.schema().field(0).data_type(), &DataType::Int64);
        assert_eq!(adapted.schema().field(1).name(), "b");
        assert_eq!(adapted.schema().field(1).data_type(), &DataType::Utf8);

        let col_a = adapted
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(col_a.iter().collect_vec(), vec![Some(1), Some(2), Some(3)]);

        let col_b = adapted
            .column(1)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(
            col_b.iter().collect_vec(),
            vec![Some("hello"), None, Some("world")]
        );
    }

    // A target column absent from the source is materialized as all-NULL.
    #[test]
    fn test_batch_adapter_factory_missing_column() {
        let target_schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, true),
            Field::new("c", DataType::Float64, true),
        ]));

        let source_schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, true),
        ]));

        let factory = BatchAdapterFactory::new(Arc::clone(&target_schema));
        let adapter = factory.make_adapter(&source_schema).unwrap();

        let source_batch = RecordBatch::try_new(
            Arc::clone(&source_schema),
            vec![
                Arc::new(Int32Array::from(vec![1, 2])),
                Arc::new(StringArray::from(vec!["x", "y"])),
            ],
        )
        .unwrap();

        let adapted = adapter.adapt_batch(&source_batch).unwrap();

        assert_eq!(adapted.num_columns(), 3);

        let col_c = adapted.column(2);
        assert_eq!(col_c.data_type(), &DataType::Float64);
        assert_eq!(col_c.null_count(), 2);
    }

    // BatchAdapter handles struct columns, casting child fields as needed.
    #[test]
    fn test_batch_adapter_factory_with_struct() {
        let target_struct_fields: Fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]
        .into();
        let target_schema = Arc::new(Schema::new(vec![Field::new(
            "data",
            DataType::Struct(target_struct_fields),
            false,
        )]));

        let source_struct_fields: Fields = vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
        ]
        .into();
        let source_schema = Arc::new(Schema::new(vec![Field::new(
            "data",
            DataType::Struct(source_struct_fields.clone()),
            false,
        )]));

        let struct_array = StructArray::new(
            source_struct_fields,
            vec![
                Arc::new(Int32Array::from(vec![10, 20])) as _,
                Arc::new(StringArray::from(vec!["a", "b"])) as _,
            ],
            None,
        );

        let source_batch = RecordBatch::try_new(
            Arc::clone(&source_schema),
            vec![Arc::new(struct_array)],
        )
        .unwrap();

        let factory = BatchAdapterFactory::new(Arc::clone(&target_schema));
        let adapter = factory.make_adapter(&source_schema).unwrap();
        let adapted = adapter.adapt_batch(&source_batch).unwrap();

        let result_struct = adapted
            .column(0)
            .as_any()
            .downcast_ref::<StructArray>()
            .unwrap();

        let id_col = result_struct.column_by_name("id").unwrap();
        assert_eq!(id_col.data_type(), &DataType::Int64);
        let id_values = id_col.as_any().downcast_ref::<Int64Array>().unwrap();
        assert_eq!(id_values.iter().collect_vec(), vec![Some(10), Some(20)]);
    }

    // Identical source/target schemas pass batches through unchanged.
    #[test]
    fn test_batch_adapter_factory_identity() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, true),
        ]));

        let factory = BatchAdapterFactory::new(Arc::clone(&schema));
        let adapter = factory.make_adapter(&schema).unwrap();

        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(StringArray::from(vec!["a", "b", "c"])),
            ],
        )
        .unwrap();

        let adapted = adapter.adapt_batch(&batch).unwrap();

        assert_eq!(adapted.num_columns(), 2);
        assert_eq!(adapted.schema().field(0).data_type(), &DataType::Int32);
        assert_eq!(adapted.schema().field(1).data_type(), &DataType::Utf8);
    }

    // One factory can build adapters for multiple distinct source schemas.
    #[test]
    fn test_batch_adapter_factory_reuse() {
        let target_schema = Arc::new(Schema::new(vec![
            Field::new("x", DataType::Int64, false),
            Field::new("y", DataType::Utf8, true),
        ]));

        let factory = BatchAdapterFactory::new(Arc::clone(&target_schema));

        let source1 = Arc::new(Schema::new(vec![
            Field::new("x", DataType::Int32, false),
            Field::new("y", DataType::Utf8, true),
        ]));
        let adapter1 = factory.make_adapter(&source1).unwrap();

        let source2 = Arc::new(Schema::new(vec![
            Field::new("y", DataType::Utf8, true),
            Field::new("x", DataType::Int64, false),
        ]));
        let adapter2 = factory.make_adapter(&source2).unwrap();

        assert!(format!("{adapter1:?}").contains("BatchAdapter"));
        assert!(format!("{adapter2:?}").contains("BatchAdapter"));
    }

    // A column that moved position AND changed type gets both a re-resolved
    // index and a cast.
    #[test]
    fn test_rewrite_column_index_and_type_mismatch() {
        let physical_schema = Schema::new(vec![
            Field::new("b", DataType::Utf8, true),
            Field::new("a", DataType::Int32, false),
        ]);

        let logical_schema = Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Utf8, true),
        ]);

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory
            .create(Arc::new(logical_schema), Arc::new(physical_schema))
            .unwrap();

        let column_expr = Arc::new(Column::new("a", 0));

        let result = adapter.rewrite(column_expr).unwrap();

        let cast_expr = result
            .as_any()
            .downcast_ref::<CastColumnExpr>()
            .expect("Expected CastColumnExpr");

        let inner_col = cast_expr
            .expr()
            .as_any()
            .downcast_ref::<Column>()
            .expect("Expected inner Column");
        assert_eq!(inner_col.name(), "a");
        assert_eq!(inner_col.index(), 1);
        assert_eq!(
            cast_expr.data_type(&Schema::empty()).unwrap(),
            DataType::Int64
        );
    }

    // create_cast_column_expr must look the physical field up by NAME: the
    // column's index (0) points at `b` here, but `a` is what must be cast.
    #[test]
    fn test_create_cast_column_expr_uses_name_lookup_not_column_index() {
        let physical_schema = Arc::new(Schema::new(vec![
            Field::new("b", DataType::Binary, true),
            Field::new("a", DataType::Int32, false),
        ]));

        let logical_schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Binary, true),
        ]));

        let rewriter = DefaultPhysicalExprAdapterRewriter {
            logical_file_schema: Arc::clone(&logical_schema),
            physical_file_schema: Arc::clone(&physical_schema),
        };

        let transformed = rewriter
            .create_cast_column_expr(
                Column::new("a", 0),
                logical_schema.field_with_name("a").unwrap(),
            )
            .unwrap();

        let cast_expr = transformed
            .data
            .as_any()
            .downcast_ref::<CastColumnExpr>()
            .expect("Expected CastColumnExpr");

        assert_eq!(cast_expr.input_field().name(), "a");
        assert_eq!(cast_expr.input_field().data_type(), &DataType::Int32);
        assert_eq!(cast_expr.target_field().data_type(), &DataType::Int64);
    }
}