1use std::sync::Arc;
21
22use arrow::compute::can_cast_types;
23use arrow::datatypes::{DataType, FieldRef, Schema, SchemaRef};
24use datafusion_common::{
25 exec_err,
26 tree_node::{Transformed, TransformedResult, TreeNode},
27 Result, ScalarValue,
28};
29use datafusion_functions::core::getfield::GetFieldFunc;
30use datafusion_physical_expr::{
31 expressions::{self, CastExpr, Column},
32 ScalarFunctionExpr,
33};
34use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
35
36pub trait PhysicalExprAdapter: Send + Sync + std::fmt::Debug {
125 fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>>;
139
140 fn with_partition_values(
141 &self,
142 partition_values: Vec<(FieldRef, ScalarValue)>,
143 ) -> Arc<dyn PhysicalExprAdapter>;
144}
145
146pub trait PhysicalExprAdapterFactory: Send + Sync + std::fmt::Debug {
147 fn create(
149 &self,
150 logical_file_schema: SchemaRef,
151 physical_file_schema: SchemaRef,
152 ) -> Arc<dyn PhysicalExprAdapter>;
153}
154
155#[derive(Debug, Clone)]
156pub struct DefaultPhysicalExprAdapterFactory;
157
158impl PhysicalExprAdapterFactory for DefaultPhysicalExprAdapterFactory {
159 fn create(
160 &self,
161 logical_file_schema: SchemaRef,
162 physical_file_schema: SchemaRef,
163 ) -> Arc<dyn PhysicalExprAdapter> {
164 Arc::new(DefaultPhysicalExprAdapter {
165 logical_file_schema,
166 physical_file_schema,
167 partition_values: Vec::new(),
168 })
169 }
170}
171
172#[derive(Debug, Clone)]
193pub struct DefaultPhysicalExprAdapter {
194 logical_file_schema: SchemaRef,
195 physical_file_schema: SchemaRef,
196 partition_values: Vec<(FieldRef, ScalarValue)>,
197}
198
199impl DefaultPhysicalExprAdapter {
200 pub fn new(logical_file_schema: SchemaRef, physical_file_schema: SchemaRef) -> Self {
205 Self {
206 logical_file_schema,
207 physical_file_schema,
208 partition_values: Vec::new(),
209 }
210 }
211}
212
213impl PhysicalExprAdapter for DefaultPhysicalExprAdapter {
214 fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
215 let rewriter = DefaultPhysicalExprAdapterRewriter {
216 logical_file_schema: &self.logical_file_schema,
217 physical_file_schema: &self.physical_file_schema,
218 partition_fields: &self.partition_values,
219 };
220 expr.transform(|expr| rewriter.rewrite_expr(Arc::clone(&expr)))
221 .data()
222 }
223
224 fn with_partition_values(
225 &self,
226 partition_values: Vec<(FieldRef, ScalarValue)>,
227 ) -> Arc<dyn PhysicalExprAdapter> {
228 Arc::new(DefaultPhysicalExprAdapter {
229 partition_values,
230 ..self.clone()
231 })
232 }
233}
234
235struct DefaultPhysicalExprAdapterRewriter<'a> {
236 logical_file_schema: &'a Schema,
237 physical_file_schema: &'a Schema,
238 partition_fields: &'a [(FieldRef, ScalarValue)],
239}
240
241impl<'a> DefaultPhysicalExprAdapterRewriter<'a> {
242 fn rewrite_expr(
243 &self,
244 expr: Arc<dyn PhysicalExpr>,
245 ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
246 if let Some(transformed) = self.try_rewrite_struct_field_access(&expr)? {
247 return Ok(Transformed::yes(transformed));
248 }
249
250 if let Some(column) = expr.as_any().downcast_ref::<Column>() {
251 return self.rewrite_column(Arc::clone(&expr), column);
252 }
253
254 Ok(Transformed::no(expr))
255 }
256
257 fn try_rewrite_struct_field_access(
261 &self,
262 expr: &Arc<dyn PhysicalExpr>,
263 ) -> Result<Option<Arc<dyn PhysicalExpr>>> {
264 let get_field_expr =
265 match ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(expr.as_ref()) {
266 Some(expr) => expr,
267 None => return Ok(None),
268 };
269
270 let source_expr = match get_field_expr.args().first() {
271 Some(expr) => expr,
272 None => return Ok(None),
273 };
274
275 let field_name_expr = match get_field_expr.args().get(1) {
276 Some(expr) => expr,
277 None => return Ok(None),
278 };
279
280 let lit = match field_name_expr
281 .as_any()
282 .downcast_ref::<expressions::Literal>()
283 {
284 Some(lit) => lit,
285 None => return Ok(None),
286 };
287
288 let field_name = match lit.value().try_as_str().flatten() {
289 Some(name) => name,
290 None => return Ok(None),
291 };
292
293 let column = match source_expr.as_any().downcast_ref::<Column>() {
294 Some(column) => column,
295 None => return Ok(None),
296 };
297
298 let physical_field =
299 match self.physical_file_schema.field_with_name(column.name()) {
300 Ok(field) => field,
301 Err(_) => return Ok(None),
302 };
303
304 let physical_struct_fields = match physical_field.data_type() {
305 DataType::Struct(fields) => fields,
306 _ => return Ok(None),
307 };
308
309 if physical_struct_fields
310 .iter()
311 .any(|f| f.name() == field_name)
312 {
313 return Ok(None);
314 }
315
316 let logical_field = match self.logical_file_schema.field_with_name(column.name())
317 {
318 Ok(field) => field,
319 Err(_) => return Ok(None),
320 };
321
322 let logical_struct_fields = match logical_field.data_type() {
323 DataType::Struct(fields) => fields,
324 _ => return Ok(None),
325 };
326
327 let logical_struct_field = match logical_struct_fields
328 .iter()
329 .find(|f| f.name() == field_name)
330 {
331 Some(field) => field,
332 None => return Ok(None),
333 };
334
335 let null_value = ScalarValue::Null.cast_to(logical_struct_field.data_type())?;
336 Ok(Some(expressions::lit(null_value)))
337 }
338
339 fn rewrite_column(
340 &self,
341 expr: Arc<dyn PhysicalExpr>,
342 column: &Column,
343 ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
344 let logical_field = match self.logical_file_schema.field_with_name(column.name())
346 {
347 Ok(field) => field,
348 Err(e) => {
349 if let Some(partition_value) = self.get_partition_value(column.name()) {
351 return Ok(Transformed::yes(expressions::lit(partition_value)));
352 }
353 if let Ok(physical_field) =
357 self.physical_file_schema.field_with_name(column.name())
358 {
359 physical_field
363 } else {
364 return Err(e.into());
368 }
369 }
370 };
371
372 let physical_column_index =
374 match self.physical_file_schema.index_of(column.name()) {
375 Ok(index) => index,
376 Err(_) => {
377 if !logical_field.is_nullable() {
378 return exec_err!(
379 "Non-nullable column '{}' is missing from the physical schema",
380 column.name()
381 );
382 }
383 let null_value =
388 ScalarValue::Null.cast_to(logical_field.data_type())?;
389 return Ok(Transformed::yes(expressions::lit(null_value)));
390 }
391 };
392 let physical_field = self.physical_file_schema.field(physical_column_index);
393
394 let column = match (
395 column.index() == physical_column_index,
396 logical_field.data_type() == physical_field.data_type(),
397 ) {
398 (true, true) => return Ok(Transformed::no(expr)),
400 (true, _) => column.clone(),
402 (false, _) => {
403 Column::new_with_schema(logical_field.name(), self.physical_file_schema)?
404 }
405 };
406
407 if logical_field.data_type() == physical_field.data_type() {
408 return Ok(Transformed::yes(Arc::new(column)));
410 }
411
412 let is_compatible =
417 can_cast_types(physical_field.data_type(), logical_field.data_type());
418 if !is_compatible {
419 return exec_err!(
420 "Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type)",
421 column.name(),
422 physical_field.data_type(),
423 logical_field.data_type()
424 );
425 }
426
427 let cast_expr = Arc::new(CastExpr::new(
428 Arc::new(column),
429 logical_field.data_type().clone(),
430 None,
431 ));
432
433 Ok(Transformed::yes(cast_expr))
434 }
435
436 fn get_partition_value(&self, column_name: &str) -> Option<ScalarValue> {
437 self.partition_fields
438 .iter()
439 .find(|(field, _)| field.name() == column_name)
440 .map(|(_, value)| value.clone())
441 }
442}
443
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{RecordBatch, RecordBatchOptions};
    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
    use datafusion_common::{assert_contains, record_batch, Result, ScalarValue};
    use datafusion_expr::Operator;
    use datafusion_physical_expr::expressions::{col, lit, CastExpr, Column, Literal};
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
    use itertools::Itertools;
    use std::sync::Arc;

    /// Builds an adapter from the default factory for the given schema pair.
    fn make_adapter(
        logical_schema: Schema,
        physical_schema: Schema,
    ) -> Arc<dyn PhysicalExprAdapter> {
        DefaultPhysicalExprAdapterFactory
            .create(Arc::new(logical_schema), Arc::new(physical_schema))
    }

    /// Returns `(physical, logical)` schemas with three differences:
    /// * `a` is widened from Int32 to Int64;
    /// * `b` is unchanged;
    /// * `c` exists only logically.
    fn create_test_schema() -> (Schema, Schema) {
        let physical = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, true),
        ]);
        let logical = Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Utf8, true),
            Field::new("c", DataType::Float64, true),
        ]);
        (physical, logical)
    }

    #[test]
    fn test_rewrite_column_with_type_cast() {
        let (physical_schema, logical_schema) = create_test_schema();
        let adapter = make_adapter(logical_schema, physical_schema);

        let rewritten = adapter.rewrite(Arc::new(Column::new("a", 0))).unwrap();

        assert!(rewritten.as_any().is::<CastExpr>());
    }

    #[test]
    fn test_rewrite_multi_column_expr_with_type_cast() {
        let (physical_schema, logical_schema) = create_test_schema();
        let adapter = make_adapter(logical_schema, physical_schema);

        // Input: (a + 5) OR (c > 0.0)
        let a_plus_five = expressions::BinaryExpr::new(
            Arc::new(Column::new("a", 0)),
            Operator::Plus,
            Arc::new(Literal::new(ScalarValue::Int64(Some(5)))),
        );
        let c_positive = expressions::BinaryExpr::new(
            Arc::new(Column::new("c", 2)),
            Operator::Gt,
            Arc::new(Literal::new(ScalarValue::Float64(Some(0.0)))),
        );
        let input = expressions::BinaryExpr::new(
            Arc::new(a_plus_five),
            Operator::Or,
            Arc::new(c_positive),
        );

        let result = adapter.rewrite(Arc::new(input)).unwrap();
        println!("Rewritten expression: {result}");

        // `a` gains a cast to Int64; the missing `c` turns into a typed NULL.
        let expected: Arc<dyn PhysicalExpr> = Arc::new(expressions::BinaryExpr::new(
            Arc::new(expressions::BinaryExpr::new(
                Arc::new(CastExpr::new(
                    Arc::new(Column::new("a", 0)),
                    DataType::Int64,
                    None,
                )),
                Operator::Plus,
                Arc::new(Literal::new(ScalarValue::Int64(Some(5)))),
            )),
            Operator::Or,
            Arc::new(expressions::BinaryExpr::new(
                lit(ScalarValue::Float64(None)),
                Operator::Gt,
                Arc::new(Literal::new(ScalarValue::Float64(Some(0.0)))),
            )),
        ));

        assert_eq!(
            result.to_string(),
            expected.to_string(),
            "The rewritten expression did not match the expected output"
        );
    }

    #[test]
    fn test_rewrite_struct_column_incompatible() {
        // Same struct layout on both sides; only the inner field type differs.
        let struct_schema = |inner: DataType| {
            Schema::new(vec![Field::new(
                "data",
                DataType::Struct(vec![Field::new("field1", inner, true)].into()),
                true,
            )])
        };
        let adapter = make_adapter(
            struct_schema(DataType::Int32),
            struct_schema(DataType::Binary),
        );

        let err = adapter
            .rewrite(Arc::new(Column::new("data", 0)))
            .unwrap_err();
        assert_contains!(err.to_string(), "Cannot cast column 'data'");
    }

    #[test]
    fn test_rewrite_struct_compatible_cast() {
        let struct_of = |id: DataType, name: DataType| {
            DataType::Struct(
                vec![
                    Field::new("id", id, false),
                    Field::new("name", name, true),
                ]
                .into(),
            )
        };
        let logical_type = struct_of(DataType::Int64, DataType::Utf8View);
        let physical_schema = Schema::new(vec![Field::new(
            "data",
            struct_of(DataType::Int32, DataType::Utf8),
            false,
        )]);
        let logical_schema =
            Schema::new(vec![Field::new("data", logical_type.clone(), false)]);
        let adapter = make_adapter(logical_schema, physical_schema);

        let result = adapter.rewrite(Arc::new(Column::new("data", 0))).unwrap();

        let expected: Arc<dyn PhysicalExpr> = Arc::new(CastExpr::new(
            Arc::new(Column::new("data", 0)),
            logical_type,
            None,
        ));
        assert_eq!(result.to_string(), expected.to_string());
    }

    #[test]
    fn test_rewrite_missing_column() -> Result<()> {
        let (physical_schema, logical_schema) = create_test_schema();
        let adapter = make_adapter(logical_schema, physical_schema);

        let result = adapter.rewrite(Arc::new(Column::new("c", 2)))?;

        let literal = result
            .as_any()
            .downcast_ref::<Literal>()
            .expect("Expected literal expression");
        assert_eq!(*literal.value(), ScalarValue::Float64(None));
        Ok(())
    }

    #[test]
    fn test_rewrite_missing_column_non_nullable_error() {
        let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let logical_schema = Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            // Required logically, yet absent from the file.
            Field::new("b", DataType::Utf8, false),
        ]);
        let adapter = make_adapter(logical_schema, physical_schema);

        let message = adapter
            .rewrite(Arc::new(Column::new("b", 1)))
            .unwrap_err()
            .to_string();
        assert_contains!(message, "Non-nullable column 'b' is missing");
    }

    #[test]
    fn test_rewrite_missing_column_nullable() {
        let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let logical_schema = Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            // Optional and absent from the file, so it can be filled with NULLs.
            Field::new("b", DataType::Utf8, true),
        ]);
        let adapter = make_adapter(logical_schema, physical_schema);

        let result = adapter.rewrite(Arc::new(Column::new("b", 1))).unwrap();

        let expected: Arc<dyn PhysicalExpr> =
            Arc::new(Literal::new(ScalarValue::Utf8(None)));
        assert_eq!(result.to_string(), expected.to_string());
    }

    #[test]
    fn test_rewrite_partition_column() -> Result<()> {
        let (physical_schema, logical_schema) = create_test_schema();
        let partition_values = vec![(
            Arc::new(Field::new("partition_col", DataType::Utf8, false)),
            ScalarValue::Utf8(Some("test_value".to_string())),
        )];
        let adapter = make_adapter(logical_schema, physical_schema)
            .with_partition_values(partition_values);

        let result = adapter.rewrite(Arc::new(Column::new("partition_col", 0)))?;

        let literal = result
            .as_any()
            .downcast_ref::<Literal>()
            .expect("Expected literal expression");
        assert_eq!(
            *literal.value(),
            ScalarValue::Utf8(Some("test_value".to_string()))
        );
        Ok(())
    }

    #[test]
    fn test_rewrite_no_change_needed() -> Result<()> {
        let (physical_schema, logical_schema) = create_test_schema();
        let adapter = make_adapter(logical_schema, physical_schema);
        let input: Arc<dyn PhysicalExpr> = Arc::new(Column::new("b", 1));

        let output = adapter.rewrite(Arc::clone(&input))?;

        // Nothing to adapt: the very same node must come back.
        assert!(std::ptr::eq(
            input.as_ref() as *const dyn PhysicalExpr,
            output.as_ref() as *const dyn PhysicalExpr
        ));
        Ok(())
    }

    #[test]
    fn test_non_nullable_missing_column_error() {
        let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let logical_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
        ]);
        let adapter = make_adapter(logical_schema, physical_schema);

        let outcome = adapter.rewrite(Arc::new(Column::new("b", 1)));

        assert!(outcome.is_err());
        assert_contains!(
            outcome.unwrap_err().to_string(),
            "Non-nullable column 'b' is missing from the physical schema"
        );
    }

    /// Evaluates each expression in `expr` against `batch` and packs the
    /// resulting arrays into a new batch described by `schema`.
    fn batch_project(
        expr: Vec<Arc<dyn PhysicalExpr>>,
        batch: &RecordBatch,
        schema: SchemaRef,
    ) -> Result<RecordBatch> {
        let num_rows = batch.num_rows();
        let columns = expr
            .iter()
            .map(|e| e.evaluate(batch)?.into_array(num_rows))
            .collect::<Result<Vec<_>>>()?;

        let projected = if columns.is_empty() {
            // A batch without columns must be told its row count explicitly.
            let options = RecordBatchOptions::new().with_row_count(Some(num_rows));
            RecordBatch::try_new_with_options(schema, columns, &options)
        } else {
            RecordBatch::try_new(schema, columns)
        };
        projected.map_err(Into::into)
    }

    #[test]
    fn test_adapt_batches() {
        let physical_batch = record_batch!(
            ("a", Int32, vec![Some(1), None, Some(3)]),
            ("extra", Utf8, vec![Some("x"), Some("y"), None])
        )
        .unwrap();
        let physical_schema = physical_batch.schema();

        let logical_schema = Arc::new(Schema::new(vec![
            // Widened from Int32.
            Field::new("a", DataType::Int64, true),
            // Not present in the batch at all.
            Field::new("b", DataType::Utf8, true),
        ]));

        let projection = vec![
            col("b", &logical_schema).unwrap(),
            col("a", &logical_schema).unwrap(),
        ];

        let adapter = DefaultPhysicalExprAdapterFactory
            .create(Arc::clone(&logical_schema), Arc::clone(&physical_schema));
        let adapted_projection = projection
            .into_iter()
            .map(|e| adapter.rewrite(e).unwrap())
            .collect_vec();

        let adapted_fields = adapted_projection
            .iter()
            .map(|e| e.return_field(&physical_schema).unwrap())
            .collect_vec();
        let adapted_schema = Arc::new(Schema::new(adapted_fields));

        let res = batch_project(
            adapted_projection,
            &physical_batch,
            Arc::clone(&adapted_schema),
        )
        .unwrap();

        assert_eq!(res.num_columns(), 2);
        assert_eq!(res.column(0).data_type(), &DataType::Utf8);
        assert_eq!(res.column(1).data_type(), &DataType::Int64);
        assert_eq!(
            res.column(0)
                .as_any()
                .downcast_ref::<arrow::array::StringArray>()
                .unwrap()
                .iter()
                .collect_vec(),
            vec![None, None, None]
        );
        assert_eq!(
            res.column(1)
                .as_any()
                .downcast_ref::<arrow::array::Int64Array>()
                .unwrap()
                .iter()
                .collect_vec(),
            vec![Some(1), None, Some(3)]
        );
    }

    #[test]
    fn test_try_rewrite_struct_field_access() {
        let physical_schema = Schema::new(vec![Field::new(
            "struct_col",
            DataType::Struct(
                vec![Field::new("existing_field", DataType::Int32, true)].into(),
            ),
            true,
        )]);
        let logical_schema = Schema::new(vec![Field::new(
            "struct_col",
            DataType::Struct(
                vec![
                    Field::new("existing_field", DataType::Int32, true),
                    Field::new("missing_field", DataType::Utf8, true),
                ]
                .into(),
            ),
            true,
        )]);

        let rewriter = DefaultPhysicalExprAdapterRewriter {
            logical_file_schema: &logical_schema,
            physical_file_schema: &physical_schema,
            partition_fields: &[],
        };

        // A bare column is not a `get_field` call, so there is nothing to rewrite.
        let column: Arc<dyn PhysicalExpr> = Arc::new(Column::new("struct_col", 0));
        let result = rewriter.try_rewrite_struct_field_access(&column).unwrap();
        assert!(result.is_none());
    }
}