1use std::borrow::Borrow;
23use std::collections::HashMap;
24use std::hash::Hash;
25use std::sync::Arc;
26
27use arrow::array::RecordBatch;
28use arrow::compute::can_cast_types;
29use arrow::datatypes::{DataType, Schema, SchemaRef};
30use datafusion_common::{
31 Result, ScalarValue, exec_err,
32 nested_struct::validate_struct_compatibility,
33 tree_node::{Transformed, TransformedResult, TreeNode},
34};
35use datafusion_functions::core::getfield::GetFieldFunc;
36use datafusion_physical_expr::PhysicalExprSimplifier;
37use datafusion_physical_expr::expressions::CastColumnExpr;
38use datafusion_physical_expr::projection::{ProjectionExprs, Projector};
39use datafusion_physical_expr::{
40 ScalarFunctionExpr,
41 expressions::{self, Column},
42};
43use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
44use itertools::Itertools;
45
46pub fn replace_columns_with_literals<K, V>(
65 expr: Arc<dyn PhysicalExpr>,
66 replacements: &HashMap<K, V>,
67) -> Result<Arc<dyn PhysicalExpr>>
68where
69 K: Borrow<str> + Eq + Hash,
70 V: Borrow<ScalarValue>,
71{
72 expr.transform_down(|expr| {
73 if let Some(column) = expr.as_any().downcast_ref::<Column>()
74 && let Some(replacement_value) = replacements.get(column.name())
75 {
76 return Ok(Transformed::yes(expressions::lit(
77 replacement_value.borrow().clone(),
78 )));
79 }
80 Ok(Transformed::no(expr))
81 })
82 .data()
83}
84
/// Rewrites physical expressions written against a "logical" file schema so
/// that they can be evaluated against the actual ("physical") schema of a
/// specific file.
pub trait PhysicalExprAdapter: Send + Sync + std::fmt::Debug {
    /// Rewrite `expr` so it can be evaluated against the physical file schema.
    fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>>;
}
171
/// Creates [`PhysicalExprAdapter`]s for a given pair of schemas.
pub trait PhysicalExprAdapterFactory: Send + Sync + std::fmt::Debug {
    /// Build an adapter that rewrites expressions written against
    /// `logical_file_schema` so they run against `physical_file_schema`.
    fn create(
        &self,
        logical_file_schema: SchemaRef,
        physical_file_schema: SchemaRef,
    ) -> Arc<dyn PhysicalExprAdapter>;
}
183
/// Factory producing [`DefaultPhysicalExprAdapter`]s.
#[derive(Debug, Clone)]
pub struct DefaultPhysicalExprAdapterFactory;
186
187impl PhysicalExprAdapterFactory for DefaultPhysicalExprAdapterFactory {
188 fn create(
189 &self,
190 logical_file_schema: SchemaRef,
191 physical_file_schema: SchemaRef,
192 ) -> Arc<dyn PhysicalExprAdapter> {
193 Arc::new(DefaultPhysicalExprAdapter {
194 logical_file_schema,
195 physical_file_schema,
196 })
197 }
198}
199
/// Default [`PhysicalExprAdapter`]: rewrites column references from the
/// logical file schema onto the physical file schema, inserting casts and
/// typed NULL literals as needed.
#[derive(Debug, Clone)]
pub struct DefaultPhysicalExprAdapter {
    // Schema the expressions were written against.
    logical_file_schema: SchemaRef,
    // Schema of the actual file being read.
    physical_file_schema: SchemaRef,
}
245
impl DefaultPhysicalExprAdapter {
    /// Create an adapter that rewrites expressions from `logical_file_schema`
    /// onto `physical_file_schema`.
    pub fn new(logical_file_schema: SchemaRef, physical_file_schema: SchemaRef) -> Self {
        Self {
            logical_file_schema,
            physical_file_schema,
        }
    }
}
258
259impl PhysicalExprAdapter for DefaultPhysicalExprAdapter {
260 fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
261 let rewriter = DefaultPhysicalExprAdapterRewriter {
262 logical_file_schema: &self.logical_file_schema,
263 physical_file_schema: &self.physical_file_schema,
264 };
265 expr.transform(|expr| rewriter.rewrite_expr(Arc::clone(&expr)))
266 .data()
267 }
268}
269
/// Borrowing helper that performs the actual per-node rewrites for
/// [`DefaultPhysicalExprAdapter`].
struct DefaultPhysicalExprAdapterRewriter<'a> {
    logical_file_schema: &'a Schema,
    physical_file_schema: &'a Schema,
}
274
275impl<'a> DefaultPhysicalExprAdapterRewriter<'a> {
276 fn rewrite_expr(
277 &self,
278 expr: Arc<dyn PhysicalExpr>,
279 ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
280 if let Some(transformed) = self.try_rewrite_struct_field_access(&expr)? {
281 return Ok(Transformed::yes(transformed));
282 }
283
284 if let Some(column) = expr.as_any().downcast_ref::<Column>() {
285 return self.rewrite_column(Arc::clone(&expr), column);
286 }
287
288 Ok(Transformed::no(expr))
289 }
290
291 fn try_rewrite_struct_field_access(
295 &self,
296 expr: &Arc<dyn PhysicalExpr>,
297 ) -> Result<Option<Arc<dyn PhysicalExpr>>> {
298 let get_field_expr =
299 match ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(expr.as_ref()) {
300 Some(expr) => expr,
301 None => return Ok(None),
302 };
303
304 let source_expr = match get_field_expr.args().first() {
305 Some(expr) => expr,
306 None => return Ok(None),
307 };
308
309 let field_name_expr = match get_field_expr.args().get(1) {
310 Some(expr) => expr,
311 None => return Ok(None),
312 };
313
314 let lit = match field_name_expr
315 .as_any()
316 .downcast_ref::<expressions::Literal>()
317 {
318 Some(lit) => lit,
319 None => return Ok(None),
320 };
321
322 let field_name = match lit.value().try_as_str().flatten() {
323 Some(name) => name,
324 None => return Ok(None),
325 };
326
327 let column = match source_expr.as_any().downcast_ref::<Column>() {
328 Some(column) => column,
329 None => return Ok(None),
330 };
331
332 let physical_field =
333 match self.physical_file_schema.field_with_name(column.name()) {
334 Ok(field) => field,
335 Err(_) => return Ok(None),
336 };
337
338 let physical_struct_fields = match physical_field.data_type() {
339 DataType::Struct(fields) => fields,
340 _ => return Ok(None),
341 };
342
343 if physical_struct_fields
344 .iter()
345 .any(|f| f.name() == field_name)
346 {
347 return Ok(None);
348 }
349
350 let logical_field = match self.logical_file_schema.field_with_name(column.name())
351 {
352 Ok(field) => field,
353 Err(_) => return Ok(None),
354 };
355
356 let logical_struct_fields = match logical_field.data_type() {
357 DataType::Struct(fields) => fields,
358 _ => return Ok(None),
359 };
360
361 let logical_struct_field = match logical_struct_fields
362 .iter()
363 .find(|f| f.name() == field_name)
364 {
365 Some(field) => field,
366 None => return Ok(None),
367 };
368
369 let null_value = ScalarValue::Null.cast_to(logical_struct_field.data_type())?;
370 Ok(Some(expressions::lit(null_value)))
371 }
372
373 fn rewrite_column(
374 &self,
375 expr: Arc<dyn PhysicalExpr>,
376 column: &Column,
377 ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
378 let logical_field = match self.logical_file_schema.field_with_name(column.name())
380 {
381 Ok(field) => field,
382 Err(e) => {
383 if let Ok(physical_field) =
387 self.physical_file_schema.field_with_name(column.name())
388 {
389 physical_field
393 } else {
394 return Err(e.into());
398 }
399 }
400 };
401
402 let physical_column_index = match self
404 .physical_file_schema
405 .index_of(column.name())
406 {
407 Ok(index) => index,
408 Err(_) => {
409 if !logical_field.is_nullable() {
410 return exec_err!(
411 "Non-nullable column '{}' is missing from the physical schema",
412 column.name()
413 );
414 }
415 let null_value = ScalarValue::Null.cast_to(logical_field.data_type())?;
418 return Ok(Transformed::yes(expressions::lit(null_value)));
419 }
420 };
421 let physical_field = self.physical_file_schema.field(physical_column_index);
422
423 let column = match (
424 column.index() == physical_column_index,
425 logical_field.data_type() == physical_field.data_type(),
426 ) {
427 (true, true) => return Ok(Transformed::no(expr)),
429 (true, _) => column.clone(),
431 (false, _) => {
432 Column::new_with_schema(logical_field.name(), self.physical_file_schema)?
433 }
434 };
435
436 if logical_field.data_type() == physical_field.data_type() {
437 return Ok(Transformed::yes(Arc::new(column)));
439 }
440
441 match (physical_field.data_type(), logical_field.data_type()) {
452 (DataType::Struct(physical_fields), DataType::Struct(logical_fields)) => {
453 validate_struct_compatibility(physical_fields, logical_fields)?;
454 }
455 _ => {
456 let is_compatible =
457 can_cast_types(physical_field.data_type(), logical_field.data_type());
458 if !is_compatible {
459 return exec_err!(
460 "Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type)",
461 column.name(),
462 physical_field.data_type(),
463 logical_field.data_type()
464 );
465 }
466 }
467 }
468
469 let cast_expr = Arc::new(CastColumnExpr::new(
470 Arc::new(column),
471 Arc::new(physical_field.clone()),
472 Arc::new(logical_field.clone()),
473 None,
474 ));
475
476 Ok(Transformed::yes(cast_expr))
477 }
478}
479
/// Factory for [`BatchAdapter`]s that map record batches from arbitrary
/// source schemas into a fixed target schema.
#[derive(Debug)]
pub struct BatchAdapterFactory {
    // Schema every adapted batch must conform to.
    target_schema: SchemaRef,
    // Pluggable expression rewriter (defaults to the DataFusion one).
    expr_adapter_factory: Arc<dyn PhysicalExprAdapterFactory>,
}
540
541impl BatchAdapterFactory {
542 pub fn new(target_schema: SchemaRef) -> Self {
544 let expr_adapter_factory = Arc::new(DefaultPhysicalExprAdapterFactory);
545 Self {
546 target_schema,
547 expr_adapter_factory,
548 }
549 }
550
551 pub fn with_adapter_factory(
558 self,
559 factory: Arc<dyn PhysicalExprAdapterFactory>,
560 ) -> Self {
561 Self {
562 expr_adapter_factory: factory,
563 ..self
564 }
565 }
566
567 pub fn make_adapter(&self, source_schema: SchemaRef) -> Result<BatchAdapter> {
572 let expr_adapter = self
573 .expr_adapter_factory
574 .create(Arc::clone(&self.target_schema), Arc::clone(&source_schema));
575
576 let simplifier = PhysicalExprSimplifier::new(&self.target_schema);
577
578 let projection = ProjectionExprs::from_indices(
579 &(0..self.target_schema.fields().len()).collect_vec(),
580 &self.target_schema,
581 );
582
583 let adapted = projection
584 .try_map_exprs(|e| simplifier.simplify(expr_adapter.rewrite(e)?))?;
585 let projector = adapted.make_projector(&source_schema)?;
586
587 Ok(BatchAdapter { projector })
588 }
589}
590
/// Adapts record batches from one schema to another using a pre-built
/// projection (see [`BatchAdapterFactory::make_adapter`]).
#[derive(Debug)]
pub struct BatchAdapter {
    projector: Projector,
}
604
impl BatchAdapter {
    /// Project `batch` into the target schema this adapter was built for.
    pub fn adapt_batch(&self, batch: &RecordBatch) -> Result<RecordBatch> {
        self.projector.project_batch(batch)
    }
}
614
615#[cfg(test)]
616mod tests {
617 use super::*;
618 use arrow::array::{
619 BooleanArray, Int32Array, Int64Array, RecordBatch, RecordBatchOptions,
620 StringArray, StringViewArray, StructArray,
621 };
622 use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef};
623 use datafusion_common::{Result, ScalarValue, assert_contains, record_batch};
624 use datafusion_expr::Operator;
625 use datafusion_physical_expr::expressions::{Column, Literal, col, lit};
626 use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
627 use itertools::Itertools;
628 use std::sync::Arc;
629
630 fn create_test_schema() -> (Schema, Schema) {
631 let physical_schema = Schema::new(vec![
632 Field::new("a", DataType::Int32, false),
633 Field::new("b", DataType::Utf8, true),
634 ]);
635
636 let logical_schema = Schema::new(vec![
637 Field::new("a", DataType::Int64, false), Field::new("b", DataType::Utf8, true),
639 Field::new("c", DataType::Float64, true), ]);
641
642 (physical_schema, logical_schema)
643 }
644
645 #[test]
646 fn test_rewrite_column_with_type_cast() {
647 let (physical_schema, logical_schema) = create_test_schema();
648
649 let factory = DefaultPhysicalExprAdapterFactory;
650 let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
651 let column_expr = Arc::new(Column::new("a", 0));
652
653 let result = adapter.rewrite(column_expr).unwrap();
654
655 assert!(result.as_any().downcast_ref::<CastColumnExpr>().is_some());
657 }
658
    #[test]
    fn test_rewrite_multi_column_expr_with_type_cast() {
        let (physical_schema, logical_schema) = create_test_schema();
        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));

        // Build `(a + 5) OR (c > 0.0)` against the logical schema.
        let column_a = Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>;
        let column_c = Arc::new(Column::new("c", 2)) as Arc<dyn PhysicalExpr>;
        let expr = expressions::BinaryExpr::new(
            Arc::clone(&column_a),
            Operator::Plus,
            Arc::new(expressions::Literal::new(ScalarValue::Int64(Some(5)))),
        );
        let expr = expressions::BinaryExpr::new(
            Arc::new(expr),
            Operator::Or,
            Arc::new(expressions::BinaryExpr::new(
                Arc::clone(&column_c),
                Operator::Gt,
                Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.0)))),
            )),
        );

        let result = adapter.rewrite(Arc::new(expr)).unwrap();
        println!("Rewritten expression: {result}");

        // Expected: `a` is cast Int32 -> Int64, and `c` (missing from the
        // physical file but nullable) collapses to a NULL Float64 literal.
        let expected = expressions::BinaryExpr::new(
            Arc::new(CastColumnExpr::new(
                Arc::new(Column::new("a", 0)),
                Arc::new(Field::new("a", DataType::Int32, false)),
                Arc::new(Field::new("a", DataType::Int64, false)),
                None,
            )),
            Operator::Plus,
            Arc::new(expressions::Literal::new(ScalarValue::Int64(Some(5)))),
        );
        let expected = Arc::new(expressions::BinaryExpr::new(
            Arc::new(expected),
            Operator::Or,
            Arc::new(expressions::BinaryExpr::new(
                lit(ScalarValue::Float64(None)), Operator::Gt,
                Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.0)))),
            )),
        )) as Arc<dyn PhysicalExpr>;

        // Compare rendered forms rather than structural equality.
        assert_eq!(
            result.to_string(),
            expected.to_string(),
            "The rewritten expression did not match the expected output"
        );
    }
712
713 #[test]
714 fn test_rewrite_struct_column_incompatible() {
715 let physical_schema = Schema::new(vec![Field::new(
716 "data",
717 DataType::Struct(vec![Field::new("field1", DataType::Binary, true)].into()),
718 true,
719 )]);
720
721 let logical_schema = Schema::new(vec![Field::new(
722 "data",
723 DataType::Struct(vec![Field::new("field1", DataType::Int32, true)].into()),
724 true,
725 )]);
726
727 let factory = DefaultPhysicalExprAdapterFactory;
728 let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
729 let column_expr = Arc::new(Column::new("data", 0));
730
731 let error_msg = adapter.rewrite(column_expr).unwrap_err().to_string();
732 assert_contains!(
734 error_msg,
735 "Cannot cast struct field 'field1' from type Binary to type Int32"
736 );
737 }
738
739 #[test]
740 fn test_rewrite_struct_compatible_cast() {
741 let physical_schema = Schema::new(vec![Field::new(
742 "data",
743 DataType::Struct(
744 vec![
745 Field::new("id", DataType::Int32, false),
746 Field::new("name", DataType::Utf8, true),
747 ]
748 .into(),
749 ),
750 false,
751 )]);
752
753 let logical_schema = Schema::new(vec![Field::new(
754 "data",
755 DataType::Struct(
756 vec![
757 Field::new("id", DataType::Int64, false),
758 Field::new("name", DataType::Utf8View, true),
759 ]
760 .into(),
761 ),
762 false,
763 )]);
764
765 let factory = DefaultPhysicalExprAdapterFactory;
766 let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
767 let column_expr = Arc::new(Column::new("data", 0));
768
769 let result = adapter.rewrite(column_expr).unwrap();
770
771 let expected = Arc::new(CastColumnExpr::new(
772 Arc::new(Column::new("data", 0)),
773 Arc::new(Field::new(
774 "data",
775 DataType::Struct(
776 vec![
777 Field::new("id", DataType::Int32, false),
778 Field::new("name", DataType::Utf8, true),
779 ]
780 .into(),
781 ),
782 false,
783 )),
784 Arc::new(Field::new(
785 "data",
786 DataType::Struct(
787 vec![
788 Field::new("id", DataType::Int64, false),
789 Field::new("name", DataType::Utf8View, true),
790 ]
791 .into(),
792 ),
793 false,
794 )),
795 None,
796 )) as Arc<dyn PhysicalExpr>;
797
798 assert_eq!(result.to_string(), expected.to_string());
799 }
800
801 #[test]
802 fn test_rewrite_missing_column() -> Result<()> {
803 let (physical_schema, logical_schema) = create_test_schema();
804
805 let factory = DefaultPhysicalExprAdapterFactory;
806 let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
807 let column_expr = Arc::new(Column::new("c", 2));
808
809 let result = adapter.rewrite(column_expr)?;
810
811 if let Some(literal) = result.as_any().downcast_ref::<expressions::Literal>() {
813 assert_eq!(*literal.value(), ScalarValue::Float64(None));
814 } else {
815 panic!("Expected literal expression");
816 }
817
818 Ok(())
819 }
820
821 #[test]
822 fn test_rewrite_missing_column_non_nullable_error() {
823 let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
824 let logical_schema = Schema::new(vec![
825 Field::new("a", DataType::Int64, false),
826 Field::new("b", DataType::Utf8, false), ]);
828
829 let factory = DefaultPhysicalExprAdapterFactory;
830 let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
831 let column_expr = Arc::new(Column::new("b", 1));
832
833 let error_msg = adapter.rewrite(column_expr).unwrap_err().to_string();
834 assert_contains!(error_msg, "Non-nullable column 'b' is missing");
835 }
836
837 #[test]
838 fn test_rewrite_missing_column_nullable() {
839 let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
840 let logical_schema = Schema::new(vec![
841 Field::new("a", DataType::Int64, false),
842 Field::new("b", DataType::Utf8, true), ]);
844
845 let factory = DefaultPhysicalExprAdapterFactory;
846 let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
847 let column_expr = Arc::new(Column::new("b", 1));
848
849 let result = adapter.rewrite(column_expr).unwrap();
850
851 let expected =
852 Arc::new(Literal::new(ScalarValue::Utf8(None))) as Arc<dyn PhysicalExpr>;
853
854 assert_eq!(result.to_string(), expected.to_string());
855 }
856
857 #[test]
858 fn test_replace_columns_with_literals() -> Result<()> {
859 let partition_value = ScalarValue::Utf8(Some("test_value".to_string()));
860 let replacements = HashMap::from([("partition_col", &partition_value)]);
861
862 let column_expr =
863 Arc::new(Column::new("partition_col", 0)) as Arc<dyn PhysicalExpr>;
864 let result = replace_columns_with_literals(column_expr, &replacements)?;
865
866 let literal = result
868 .as_any()
869 .downcast_ref::<expressions::Literal>()
870 .expect("Expected literal expression");
871 assert_eq!(*literal.value(), partition_value);
872
873 Ok(())
874 }
875
876 #[test]
877 fn test_replace_columns_with_literals_no_match() -> Result<()> {
878 let value = ScalarValue::Utf8(Some("test_value".to_string()));
879 let replacements = HashMap::from([("other_col", &value)]);
880
881 let column_expr =
882 Arc::new(Column::new("partition_col", 0)) as Arc<dyn PhysicalExpr>;
883 let result = replace_columns_with_literals(column_expr, &replacements)?;
884
885 assert!(result.as_any().downcast_ref::<Column>().is_some());
886 Ok(())
887 }
888
889 #[test]
890 fn test_replace_columns_with_literals_nested_expr() -> Result<()> {
891 let value_a = ScalarValue::Int64(Some(10));
892 let value_b = ScalarValue::Int64(Some(20));
893 let replacements = HashMap::from([("a", &value_a), ("b", &value_b)]);
894
895 let expr = Arc::new(expressions::BinaryExpr::new(
896 Arc::new(Column::new("a", 0)),
897 Operator::Plus,
898 Arc::new(Column::new("b", 1)),
899 )) as Arc<dyn PhysicalExpr>;
900
901 let result = replace_columns_with_literals(expr, &replacements)?;
902 assert_eq!(result.to_string(), "10 + 20");
903
904 Ok(())
905 }
906
    #[test]
    fn test_rewrite_no_change_needed() -> Result<()> {
        let (physical_schema, logical_schema) = create_test_schema();

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
        // `b` has the same index and type in both schemas.
        let column_expr = Arc::new(Column::new("b", 1)) as Arc<dyn PhysicalExpr>;

        let result = adapter.rewrite(Arc::clone(&column_expr))?;

        // The rewrite must return the very same expression object (pointer
        // equality), not merely an equal one — proving nothing was rebuilt.
        assert!(std::ptr::eq(
            column_expr.as_ref() as *const dyn PhysicalExpr,
            result.as_ref() as *const dyn PhysicalExpr
        ));

        Ok(())
    }
926
927 #[test]
928 fn test_non_nullable_missing_column_error() {
929 let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
930 let logical_schema = Schema::new(vec![
931 Field::new("a", DataType::Int32, false),
932 Field::new("b", DataType::Utf8, false), ]);
934
935 let factory = DefaultPhysicalExprAdapterFactory;
936 let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
937 let column_expr = Arc::new(Column::new("b", 1));
938
939 let result = adapter.rewrite(column_expr);
940 assert!(result.is_err());
941 assert_contains!(
942 result.unwrap_err().to_string(),
943 "Non-nullable column 'b' is missing from the physical schema"
944 );
945 }
946
947 fn batch_project(
949 expr: Vec<Arc<dyn PhysicalExpr>>,
950 batch: &RecordBatch,
951 schema: SchemaRef,
952 ) -> Result<RecordBatch> {
953 let arrays = expr
954 .iter()
955 .map(|expr| {
956 expr.evaluate(batch)
957 .and_then(|v| v.into_array(batch.num_rows()))
958 })
959 .collect::<Result<Vec<_>>>()?;
960
961 if arrays.is_empty() {
962 let options =
963 RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
964 RecordBatch::try_new_with_options(Arc::clone(&schema), arrays, &options)
965 .map_err(Into::into)
966 } else {
967 RecordBatch::try_new(Arc::clone(&schema), arrays).map_err(Into::into)
968 }
969 }
970
    /// End-to-end: rewrite a projection against a file that is missing `b`
    /// and stores `a` as Int32, then evaluate it on a real batch.
    #[test]
    fn test_adapt_batches() {
        let physical_batch = record_batch!(
            ("a", Int32, vec![Some(1), None, Some(3)]),
            ("extra", Utf8, vec![Some("x"), Some("y"), None])
        )
        .unwrap();

        let physical_schema = physical_batch.schema();

        let logical_schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, true), Field::new("b", DataType::Utf8, true), ]));

        // Project (b, a) — note the order differs from the file layout.
        let projection = vec![
            col("b", &logical_schema).unwrap(),
            col("a", &logical_schema).unwrap(),
        ];

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter =
            factory.create(Arc::clone(&logical_schema), Arc::clone(&physical_schema));

        let adapted_projection = projection
            .into_iter()
            .map(|expr| adapter.rewrite(expr).unwrap())
            .collect_vec();

        // Derive the output schema from the rewritten expressions.
        let adapted_schema = Arc::new(Schema::new(
            adapted_projection
                .iter()
                .map(|expr| expr.return_field(&physical_schema).unwrap())
                .collect_vec(),
        ));

        let res = batch_project(
            adapted_projection,
            &physical_batch,
            Arc::clone(&adapted_schema),
        )
        .unwrap();

        // `b` is missing from the file -> all NULL; `a` was cast to Int64.
        assert_eq!(res.num_columns(), 2);
        assert_eq!(res.column(0).data_type(), &DataType::Utf8);
        assert_eq!(res.column(1).data_type(), &DataType::Int64);
        assert_eq!(
            res.column(0)
                .as_any()
                .downcast_ref::<arrow::array::StringArray>()
                .unwrap()
                .iter()
                .collect_vec(),
            vec![None, None, None]
        );
        assert_eq!(
            res.column(1)
                .as_any()
                .downcast_ref::<arrow::array::Int64Array>()
                .unwrap()
                .iter()
                .collect_vec(),
            vec![Some(1), None, Some(3)]
        );
    }
1038
    /// End-to-end struct adaptation: widen `id`, change `name`'s string
    /// encoding, and synthesize the missing `extra` field as NULLs.
    #[test]
    fn test_adapt_struct_batches() {
        let physical_struct_fields: Fields = vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
        ]
        .into();

        let struct_array = StructArray::new(
            physical_struct_fields.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])) as _,
                Arc::new(StringArray::from(vec![
                    Some("alice"),
                    None,
                    Some("charlie"),
                ])) as _,
            ],
            None,
        );

        let physical_schema = Arc::new(Schema::new(vec![Field::new(
            "data",
            DataType::Struct(physical_struct_fields),
            false,
        )]));

        let physical_batch = RecordBatch::try_new(
            Arc::clone(&physical_schema),
            vec![Arc::new(struct_array)],
        )
        .unwrap();

        // Logical view of the same struct: wider `id`, Utf8View `name`, plus
        // an `extra` field the file does not contain.
        let logical_struct_fields: Fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8View, true),
            Field::new("extra", DataType::Boolean, true), ]
        .into();

        let logical_schema = Arc::new(Schema::new(vec![Field::new(
            "data",
            DataType::Struct(logical_struct_fields),
            false,
        )]));

        let projection = vec![col("data", &logical_schema).unwrap()];

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter =
            factory.create(Arc::clone(&logical_schema), Arc::clone(&physical_schema));

        let adapted_projection = projection
            .into_iter()
            .map(|expr| adapter.rewrite(expr).unwrap())
            .collect_vec();

        // Derive the output schema from the rewritten expressions.
        let adapted_schema = Arc::new(Schema::new(
            adapted_projection
                .iter()
                .map(|expr| expr.return_field(&physical_schema).unwrap())
                .collect_vec(),
        ));

        let res = batch_project(
            adapted_projection,
            &physical_batch,
            Arc::clone(&adapted_schema),
        )
        .unwrap();

        assert_eq!(res.num_columns(), 1);

        let result_struct = res
            .column(0)
            .as_any()
            .downcast_ref::<StructArray>()
            .unwrap();

        // `id` was cast Int32 -> Int64 with values preserved.
        let id_col = result_struct.column_by_name("id").unwrap();
        assert_eq!(id_col.data_type(), &DataType::Int64);
        let id_values = id_col.as_any().downcast_ref::<Int64Array>().unwrap();
        assert_eq!(
            id_values.iter().collect_vec(),
            vec![Some(1), Some(2), Some(3)]
        );

        // `name` was converted Utf8 -> Utf8View with values preserved.
        let name_col = result_struct.column_by_name("name").unwrap();
        assert_eq!(name_col.data_type(), &DataType::Utf8View);
        let name_values = name_col.as_any().downcast_ref::<StringViewArray>().unwrap();
        assert_eq!(
            name_values.iter().collect_vec(),
            vec![Some("alice"), None, Some("charlie")]
        );

        // `extra` was synthesized as an all-NULL Boolean column.
        let extra_col = result_struct.column_by_name("extra").unwrap();
        assert_eq!(extra_col.data_type(), &DataType::Boolean);
        let extra_values = extra_col.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert_eq!(extra_values.iter().collect_vec(), vec![None, None, None]);
    }
1150
    #[test]
    fn test_try_rewrite_struct_field_access() {
        // Physical file only has `existing_field`; the logical schema also
        // declares `missing_field`.
        let physical_schema = Schema::new(vec![Field::new(
            "struct_col",
            DataType::Struct(
                vec![Field::new("existing_field", DataType::Int32, true)].into(),
            ),
            true,
        )]);

        let logical_schema = Schema::new(vec![Field::new(
            "struct_col",
            DataType::Struct(
                vec![
                    Field::new("existing_field", DataType::Int32, true),
                    Field::new("missing_field", DataType::Utf8, true),
                ]
                .into(),
            ),
            true,
        )]);

        let rewriter = DefaultPhysicalExprAdapterRewriter {
            logical_file_schema: &logical_schema,
            physical_file_schema: &physical_schema,
        };

        // A bare column is not a `get_field` call, so this rewrite does not
        // apply and must return None.
        // NOTE(review): only the non-`get_field` path is exercised here; the
        // missing-field -> NULL replacement itself is not covered by a test.
        let column = Arc::new(Column::new("struct_col", 0)) as Arc<dyn PhysicalExpr>;
        let result = rewriter.try_rewrite_struct_field_access(&column).unwrap();
        assert!(result.is_none());

    }
1188
    /// Target wants (a: Int64, b: Utf8); the source delivers them reordered
    /// with `a` stored as Int32 — the adapter must reorder and cast.
    #[test]
    fn test_batch_adapter_factory_basic() {
        let target_schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Utf8, true),
        ]));

        let source_schema = Arc::new(Schema::new(vec![
            Field::new("b", DataType::Utf8, true),
            Field::new("a", DataType::Int32, false), ]));

        let factory = BatchAdapterFactory::new(Arc::clone(&target_schema));
        let adapter = factory.make_adapter(Arc::clone(&source_schema)).unwrap();

        let source_batch = RecordBatch::try_new(
            Arc::clone(&source_schema),
            vec![
                Arc::new(StringArray::from(vec![Some("hello"), None, Some("world")])),
                Arc::new(Int32Array::from(vec![1, 2, 3])),
            ],
        )
        .unwrap();

        let adapted = adapter.adapt_batch(&source_batch).unwrap();

        // Columns come back in target order with target types.
        assert_eq!(adapted.num_columns(), 2);
        assert_eq!(adapted.schema().field(0).name(), "a");
        assert_eq!(adapted.schema().field(0).data_type(), &DataType::Int64);
        assert_eq!(adapted.schema().field(1).name(), "b");
        assert_eq!(adapted.schema().field(1).data_type(), &DataType::Utf8);

        // Values survive the cast and the reordering.
        let col_a = adapted
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(col_a.iter().collect_vec(), vec![Some(1), Some(2), Some(3)]);

        let col_b = adapted
            .column(1)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(
            col_b.iter().collect_vec(),
            vec![Some("hello"), None, Some("world")]
        );
    }
1247
1248 #[test]
1249 fn test_batch_adapter_factory_missing_column() {
1250 let target_schema = Arc::new(Schema::new(vec![
1252 Field::new("a", DataType::Int32, false),
1253 Field::new("b", DataType::Utf8, true), Field::new("c", DataType::Float64, true), ]));
1256
1257 let source_schema = Arc::new(Schema::new(vec![
1258 Field::new("a", DataType::Int32, false),
1259 Field::new("b", DataType::Utf8, true),
1260 ]));
1261
1262 let factory = BatchAdapterFactory::new(Arc::clone(&target_schema));
1263 let adapter = factory.make_adapter(Arc::clone(&source_schema)).unwrap();
1264
1265 let source_batch = RecordBatch::try_new(
1266 Arc::clone(&source_schema),
1267 vec![
1268 Arc::new(Int32Array::from(vec![1, 2])),
1269 Arc::new(StringArray::from(vec!["x", "y"])),
1270 ],
1271 )
1272 .unwrap();
1273
1274 let adapted = adapter.adapt_batch(&source_batch).unwrap();
1275
1276 assert_eq!(adapted.num_columns(), 3);
1277
1278 let col_c = adapted.column(2);
1280 assert_eq!(col_c.data_type(), &DataType::Float64);
1281 assert_eq!(col_c.null_count(), 2); }
1283
    /// Struct column adaptation through the batch adapter: `id` must be
    /// widened Int32 -> Int64 inside the struct.
    #[test]
    fn test_batch_adapter_factory_with_struct() {
        let target_struct_fields: Fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]
        .into();
        let target_schema = Arc::new(Schema::new(vec![Field::new(
            "data",
            DataType::Struct(target_struct_fields),
            false,
        )]));

        let source_struct_fields: Fields = vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
        ]
        .into();
        let source_schema = Arc::new(Schema::new(vec![Field::new(
            "data",
            DataType::Struct(source_struct_fields.clone()),
            false,
        )]));

        let struct_array = StructArray::new(
            source_struct_fields,
            vec![
                Arc::new(Int32Array::from(vec![10, 20])) as _,
                Arc::new(StringArray::from(vec!["a", "b"])) as _,
            ],
            None,
        );

        let source_batch = RecordBatch::try_new(
            Arc::clone(&source_schema),
            vec![Arc::new(struct_array)],
        )
        .unwrap();

        let factory = BatchAdapterFactory::new(Arc::clone(&target_schema));
        let adapter = factory.make_adapter(source_schema).unwrap();
        let adapted = adapter.adapt_batch(&source_batch).unwrap();

        let result_struct = adapted
            .column(0)
            .as_any()
            .downcast_ref::<StructArray>()
            .unwrap();

        // `id` arrives as Int64 with its values intact.
        let id_col = result_struct.column_by_name("id").unwrap();
        assert_eq!(id_col.data_type(), &DataType::Int64);
        let id_values = id_col.as_any().downcast_ref::<Int64Array>().unwrap();
        assert_eq!(id_values.iter().collect_vec(), vec![Some(10), Some(20)]);
    }
1341
1342 #[test]
1343 fn test_batch_adapter_factory_identity() {
1344 let schema = Arc::new(Schema::new(vec![
1346 Field::new("a", DataType::Int32, false),
1347 Field::new("b", DataType::Utf8, true),
1348 ]));
1349
1350 let factory = BatchAdapterFactory::new(Arc::clone(&schema));
1351 let adapter = factory.make_adapter(Arc::clone(&schema)).unwrap();
1352
1353 let batch = RecordBatch::try_new(
1354 Arc::clone(&schema),
1355 vec![
1356 Arc::new(Int32Array::from(vec![1, 2, 3])),
1357 Arc::new(StringArray::from(vec!["a", "b", "c"])),
1358 ],
1359 )
1360 .unwrap();
1361
1362 let adapted = adapter.adapt_batch(&batch).unwrap();
1363
1364 assert_eq!(adapted.num_columns(), 2);
1365 assert_eq!(adapted.schema().field(0).data_type(), &DataType::Int32);
1366 assert_eq!(adapted.schema().field(1).data_type(), &DataType::Utf8);
1367 }
1368
1369 #[test]
1370 fn test_batch_adapter_factory_reuse() {
1371 let target_schema = Arc::new(Schema::new(vec![
1373 Field::new("x", DataType::Int64, false),
1374 Field::new("y", DataType::Utf8, true),
1375 ]));
1376
1377 let factory = BatchAdapterFactory::new(Arc::clone(&target_schema));
1378
1379 let source1 = Arc::new(Schema::new(vec![
1381 Field::new("x", DataType::Int32, false),
1382 Field::new("y", DataType::Utf8, true),
1383 ]));
1384 let adapter1 = factory.make_adapter(source1).unwrap();
1385
1386 let source2 = Arc::new(Schema::new(vec![
1388 Field::new("y", DataType::Utf8, true),
1389 Field::new("x", DataType::Int64, false),
1390 ]));
1391 let adapter2 = factory.make_adapter(source2).unwrap();
1392
1393 assert!(format!("{:?}", adapter1).contains("BatchAdapter"));
1395 assert!(format!("{:?}", adapter2).contains("BatchAdapter"));
1396 }
1397}