datafusion_physical_expr_adapter/schema_rewriter.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Physical expression schema rewriting utilities

use std::sync::Arc;

use arrow::compute::can_cast_types;
use arrow::datatypes::{DataType, FieldRef, Schema, SchemaRef};
use datafusion_common::{
    exec_err,
    tree_node::{Transformed, TransformedResult, TreeNode},
    Result, ScalarValue,
};
use datafusion_functions::core::getfield::GetFieldFunc;
use datafusion_physical_expr::{
    expressions::{self, CastExpr, Column},
    ScalarFunctionExpr,
};
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

/// Trait for adapting physical expressions to match a target schema.
///
/// This is used in file scans to rewrite expressions so that they can be evaluated
/// against the physical schema of the file being scanned. It allows for handling
/// differences between logical and physical schemas, such as type mismatches or missing columns.
///
/// ## Overview
///
/// The `PhysicalExprAdapter` allows rewriting physical expressions to match different schemas, including:
///
/// - **Type casting**: When logical and physical schemas have different types, expressions are
///   automatically wrapped with cast operations. For example, `lit(ScalarValue::Int32(123)) = int64_column`
///   gets rewritten to `lit(ScalarValue::Int32(123)) = cast(int64_column, 'Int32')`.
///   Note that this does not attempt to simplify such expressions - that is done by shared simplifiers.
///
/// - **Missing columns**: When a column exists in the logical schema but not in the physical schema,
///   references to it are replaced with null literals.
///
/// - **Struct field access**: Expressions like `struct_column.field_that_is_missing_in_schema` are
///   rewritten to `null` when the field doesn't exist in the physical schema.
///
/// - **Partition columns**: Partition column references can be replaced with their literal values
///   when scanning specific partitions.
///
/// ## Custom Implementations
///
/// You can create a custom implementation of this trait to handle specific rewriting logic.
/// For example, to fill in missing columns with default values instead of nulls:
///
/// ```rust
/// use datafusion_physical_expr_adapter::{PhysicalExprAdapter, PhysicalExprAdapterFactory};
/// use arrow::datatypes::{Schema, Field, DataType, FieldRef, SchemaRef};
/// use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
/// use datafusion_common::{Result, ScalarValue, tree_node::{Transformed, TransformedResult, TreeNode}};
/// use datafusion_physical_expr::expressions::{self, Column};
/// use std::sync::Arc;
///
/// #[derive(Debug)]
/// pub struct CustomPhysicalExprAdapter {
///     logical_file_schema: SchemaRef,
///     physical_file_schema: SchemaRef,
/// }
///
/// impl PhysicalExprAdapter for CustomPhysicalExprAdapter {
///     fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
///         expr.transform(|expr| {
///             if let Some(column) = expr.as_any().downcast_ref::<Column>() {
///                 // Check if the column exists in the physical schema
///                 if self.physical_file_schema.index_of(column.name()).is_err() {
///                     // If the column is missing, fill it with a default value instead of null
///                     // The default value could be stored in the table schema's column metadata for example.
///                     let default_value = ScalarValue::Int32(Some(0));
///                     return Ok(Transformed::yes(expressions::lit(default_value)));
///                 }
///             }
///             // If the column exists, return it as is
///             Ok(Transformed::no(expr))
///         }).data()
///     }
///
///     fn with_partition_values(
///         &self,
///         partition_values: Vec<(FieldRef, ScalarValue)>,
///     ) -> Arc<dyn PhysicalExprAdapter> {
///         // For simplicity, this example ignores partition values
///         Arc::new(CustomPhysicalExprAdapter {
///             logical_file_schema: self.logical_file_schema.clone(),
///             physical_file_schema: self.physical_file_schema.clone(),
///         })
///     }
/// }
///
/// #[derive(Debug)]
/// pub struct CustomPhysicalExprAdapterFactory;
///
/// impl PhysicalExprAdapterFactory for CustomPhysicalExprAdapterFactory {
///     fn create(
///         &self,
///         logical_file_schema: SchemaRef,
///         physical_file_schema: SchemaRef,
///     ) -> Arc<dyn PhysicalExprAdapter> {
///         Arc::new(CustomPhysicalExprAdapter {
///             logical_file_schema,
///             physical_file_schema,
///         })
///     }
/// }
/// ```
pub trait PhysicalExprAdapter: Send + Sync + std::fmt::Debug {
    /// Rewrite a physical expression to match the target schema.
    ///
    /// This method should return a transformed expression that can be evaluated against the
    /// physical schema of the file being scanned.
    ///
    /// The schemas involved are supplied when the adapter is created via
    /// [`PhysicalExprAdapterFactory::create`]:
    /// - the logical file schema is the schema of the table being queried, excluding any partition columns, and
    /// - the physical file schema is the schema of the file being scanned.
    ///
    /// Partition values attached via [`PhysicalExprAdapter::with_partition_values`] are handled
    /// as if they were columns appended onto the logical file schema.
    ///
    /// Arguments:
    /// - `expr`: The physical expression to rewrite.
    ///
    /// Returns:
    /// - `Arc<dyn PhysicalExpr>`: The rewritten physical expression that can be evaluated against the physical schema.
    fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>>;

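    /// Attach partition column values to this adapter.
    ///
    /// Returns a new adapter that rewrites references to the given partition columns into
    /// their literal values. The partition columns are treated as if they were appended
    /// onto the logical file schema.
    ///
    /// A minimal sketch using the default adapter factory (the partition column name and
    /// value are illustrative):
    ///
    /// ```rust
    /// use std::sync::Arc;
    /// use arrow::datatypes::{DataType, Field, Schema};
    /// use datafusion_common::ScalarValue;
    /// use datafusion_physical_expr_adapter::{DefaultPhysicalExprAdapterFactory, PhysicalExprAdapterFactory};
    ///
    /// let logical_file_schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
    /// let physical_file_schema = Arc::clone(&logical_file_schema);
    /// let factory = DefaultPhysicalExprAdapterFactory;
    /// let adapter = factory.create(logical_file_schema, physical_file_schema);
    /// // After this call, rewritten expressions see `year` as the literal "2024"
    /// let adapter = adapter.with_partition_values(vec![(
    ///     Arc::new(Field::new("year", DataType::Utf8, false)),
    ///     ScalarValue::Utf8(Some("2024".to_string())),
    /// )]);
    /// ```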
    fn with_partition_values(
        &self,
        partition_values: Vec<(FieldRef, ScalarValue)>,
    ) -> Arc<dyn PhysicalExprAdapter>;
}

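/// Factory for creating [`PhysicalExprAdapter`] instances.
///
/// A factory is given the logical file schema and the physical file schema and returns an
/// adapter that can rewrite expressions to be evaluated against that pair of schemas.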
pub trait PhysicalExprAdapterFactory: Send + Sync + std::fmt::Debug {
    /// Create a new instance of the physical expression adapter.
    fn create(
        &self,
        logical_file_schema: SchemaRef,
        physical_file_schema: SchemaRef,
    ) -> Arc<dyn PhysicalExprAdapter>;
}

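/// The default [`PhysicalExprAdapterFactory`], which creates [`DefaultPhysicalExprAdapter`] instances.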
#[derive(Debug, Clone)]
pub struct DefaultPhysicalExprAdapterFactory;

impl PhysicalExprAdapterFactory for DefaultPhysicalExprAdapterFactory {
    fn create(
        &self,
        logical_file_schema: SchemaRef,
        physical_file_schema: SchemaRef,
    ) -> Arc<dyn PhysicalExprAdapter> {
        Arc::new(DefaultPhysicalExprAdapter {
            logical_file_schema,
            physical_file_schema,
            partition_values: Vec::new(),
        })
    }
}

/// Default implementation for rewriting physical expressions to match different schemas.
///
/// # Example
///
/// ```rust
/// use datafusion_physical_expr_adapter::{DefaultPhysicalExprAdapterFactory, PhysicalExprAdapterFactory};
/// use arrow::datatypes::Schema;
/// use std::sync::Arc;
///
/// # fn example(
/// #     predicate: std::sync::Arc<dyn datafusion_physical_expr_common::physical_expr::PhysicalExpr>,
/// #     physical_file_schema: &Schema,
/// #     logical_file_schema: &Schema,
/// # ) -> datafusion_common::Result<()> {
/// let factory = DefaultPhysicalExprAdapterFactory;
/// let adapter = factory.create(Arc::new(logical_file_schema.clone()), Arc::new(physical_file_schema.clone()));
/// let adapted_predicate = adapter.rewrite(predicate)?;
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone)]
pub struct DefaultPhysicalExprAdapter {
    logical_file_schema: SchemaRef,
    physical_file_schema: SchemaRef,
    partition_values: Vec<(FieldRef, ScalarValue)>,
}

impl DefaultPhysicalExprAdapter {
    /// Create a new instance of the default physical expression adapter.
    ///
    /// This adapter rewrites expressions so they can be evaluated against the physical schema of
    /// the file being scanned, handling type mismatches by inserting casts and filling in missing
    /// columns with null literals.
    pub fn new(logical_file_schema: SchemaRef, physical_file_schema: SchemaRef) -> Self {
        Self {
            logical_file_schema,
            physical_file_schema,
            partition_values: Vec::new(),
        }
    }
}

impl PhysicalExprAdapter for DefaultPhysicalExprAdapter {
    fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
        let rewriter = DefaultPhysicalExprAdapterRewriter {
            logical_file_schema: &self.logical_file_schema,
            physical_file_schema: &self.physical_file_schema,
            partition_fields: &self.partition_values,
        };
        expr.transform(|expr| rewriter.rewrite_expr(Arc::clone(&expr)))
            .data()
    }

    fn with_partition_values(
        &self,
        partition_values: Vec<(FieldRef, ScalarValue)>,
    ) -> Arc<dyn PhysicalExprAdapter> {
        Arc::new(DefaultPhysicalExprAdapter {
            partition_values,
            ..self.clone()
        })
    }
}

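/// Internal helper that performs the actual rewriting for [`DefaultPhysicalExprAdapter`],
/// borrowing the schemas and partition values for the duration of a single rewrite.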
struct DefaultPhysicalExprAdapterRewriter<'a> {
    logical_file_schema: &'a Schema,
    physical_file_schema: &'a Schema,
    partition_fields: &'a [(FieldRef, ScalarValue)],
}

impl<'a> DefaultPhysicalExprAdapterRewriter<'a> {
    fn rewrite_expr(
        &self,
        expr: Arc<dyn PhysicalExpr>,
    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
        if let Some(transformed) = self.try_rewrite_struct_field_access(&expr)? {
            return Ok(Transformed::yes(transformed));
        }

        if let Some(column) = expr.as_any().downcast_ref::<Column>() {
            return self.rewrite_column(Arc::clone(&expr), column);
        }

        Ok(Transformed::no(expr))
    }

    /// Attempt to rewrite struct field access expressions to return null if the field does not exist in the physical schema.
    /// Note that this does *not* handle nested struct fields, only top-level struct field access.
    /// See <https://github.com/apache/datafusion/issues/17114> for more details.
    fn try_rewrite_struct_field_access(
        &self,
        expr: &Arc<dyn PhysicalExpr>,
    ) -> Result<Option<Arc<dyn PhysicalExpr>>> {
        let get_field_expr =
            match ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(expr.as_ref()) {
                Some(expr) => expr,
                None => return Ok(None),
            };

        let source_expr = match get_field_expr.args().first() {
            Some(expr) => expr,
            None => return Ok(None),
        };

        let field_name_expr = match get_field_expr.args().get(1) {
            Some(expr) => expr,
            None => return Ok(None),
        };

        let lit = match field_name_expr
            .as_any()
            .downcast_ref::<expressions::Literal>()
        {
            Some(lit) => lit,
            None => return Ok(None),
        };

        let field_name = match lit.value().try_as_str().flatten() {
            Some(name) => name,
            None => return Ok(None),
        };

        let column = match source_expr.as_any().downcast_ref::<Column>() {
            Some(column) => column,
            None => return Ok(None),
        };

        let physical_field =
            match self.physical_file_schema.field_with_name(column.name()) {
                Ok(field) => field,
                Err(_) => return Ok(None),
            };

        let physical_struct_fields = match physical_field.data_type() {
            DataType::Struct(fields) => fields,
            _ => return Ok(None),
        };

        if physical_struct_fields
            .iter()
            .any(|f| f.name() == field_name)
        {
            return Ok(None);
        }

        let logical_field = match self.logical_file_schema.field_with_name(column.name())
        {
            Ok(field) => field,
            Err(_) => return Ok(None),
        };

        let logical_struct_fields = match logical_field.data_type() {
            DataType::Struct(fields) => fields,
            _ => return Ok(None),
        };

        let logical_struct_field = match logical_struct_fields
            .iter()
            .find(|f| f.name() == field_name)
        {
            Some(field) => field,
            None => return Ok(None),
        };

        let null_value = ScalarValue::Null.cast_to(logical_struct_field.data_type())?;
        Ok(Some(expressions::lit(null_value)))
    }

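    /// Rewrite a column reference so that it can be evaluated against the physical schema:
    /// partition columns are replaced with their literal values, missing nullable columns
    /// become typed null literals, columns whose index differs are re-resolved against the
    /// physical schema, and columns whose physical type differs from the logical type are
    /// wrapped in a cast to the logical type.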
    fn rewrite_column(
        &self,
        expr: Arc<dyn PhysicalExpr>,
        column: &Column,
    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
        // Get the logical field for this column if it exists in the logical schema
        let logical_field = match self.logical_file_schema.field_with_name(column.name())
        {
            Ok(field) => field,
            Err(e) => {
                // If the column is a partition field, we can use the partition value
                if let Some(partition_value) = self.get_partition_value(column.name()) {
                    return Ok(Transformed::yes(expressions::lit(partition_value)));
                }
                // This can be hit if a custom rewrite injected a reference to a column that doesn't exist in the logical schema.
                // For example, a pre-computed column that is kept only in the physical schema.
                // If the column exists in the physical schema, we can still use it.
                if let Ok(physical_field) =
                    self.physical_file_schema.field_with_name(column.name())
                {
                    // If the column exists in the physical schema, we can use it in place of the logical column.
                    // This is convenient for users: if a rewrite results in something like `physical_int32_col = 123u64`
                    // we'll at least handle the casts for them.
                    physical_field
                } else {
                    // A completely unknown column that doesn't exist in either schema!
                    // This should probably never be hit unless something upstream broke, but nonetheless it's better
                    // for us to return a handleable error than to panic / do something unexpected.
                    return Err(e.into());
                }
            }
        };

        // Check if the column exists in the physical schema
        let physical_column_index =
            match self.physical_file_schema.index_of(column.name()) {
                Ok(index) => index,
                Err(_) => {
                    if !logical_field.is_nullable() {
                        return exec_err!(
                            "Non-nullable column '{}' is missing from the physical schema",
                            column.name()
                        );
                    }
                    // If the column is missing from the physical schema fill it in with nulls as `SchemaAdapter` would do.
                    // TODO: do we need to sync this with what the `SchemaAdapter` actually does?
                    // While the default implementation fills in nulls, in theory a custom `SchemaAdapter` could do something else!
                    // See https://github.com/apache/datafusion/issues/16527
                    let null_value =
                        ScalarValue::Null.cast_to(logical_field.data_type())?;
                    return Ok(Transformed::yes(expressions::lit(null_value)));
                }
            };
        let physical_field = self.physical_file_schema.field(physical_column_index);

        let column = match (
            column.index() == physical_column_index,
            logical_field.data_type() == physical_field.data_type(),
        ) {
            // If the column index matches and the data types match, we can use the column as is
            (true, true) => return Ok(Transformed::no(expr)),
            // If the indexes or data types do not match, we need to create a new column expression
            (true, _) => column.clone(),
            (false, _) => {
                Column::new_with_schema(logical_field.name(), self.physical_file_schema)?
            }
        };

        if logical_field.data_type() == physical_field.data_type() {
            // If the data types match, we can use the column as is
            return Ok(Transformed::yes(Arc::new(column)));
        }

        // We need to cast the column to the logical data type
        // TODO: add optimization to move the cast from the column to literal expressions in the case of `col = 123`
        // since that's much cheaper to evaluate.
        // See https://github.com/apache/datafusion/issues/15780#issuecomment-2824716928
        let is_compatible =
            can_cast_types(physical_field.data_type(), logical_field.data_type());
        if !is_compatible {
            return exec_err!(
                "Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type)",
                column.name(),
                physical_field.data_type(),
                logical_field.data_type()
            );
        }

        let cast_expr = Arc::new(CastExpr::new(
            Arc::new(column),
            logical_field.data_type().clone(),
            None,
        ));

        Ok(Transformed::yes(cast_expr))
    }

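    /// Look up the partition value for the given column name, if one was supplied
    /// via [`PhysicalExprAdapter::with_partition_values`].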
    fn get_partition_value(&self, column_name: &str) -> Option<ScalarValue> {
        self.partition_fields
            .iter()
            .find(|(field, _)| field.name() == column_name)
            .map(|(_, value)| value.clone())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{RecordBatch, RecordBatchOptions};
    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
    use datafusion_common::{assert_contains, record_batch, Result, ScalarValue};
    use datafusion_expr::Operator;
    use datafusion_physical_expr::expressions::{col, lit, CastExpr, Column, Literal};
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
    use itertools::Itertools;
    use std::sync::Arc;

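    /// Create a (physical, logical) schema pair where column `a` requires a cast
    /// (Int32 vs. Int64) and column `c` is missing from the physical schema.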
    fn create_test_schema() -> (Schema, Schema) {
        let physical_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, true),
        ]);

        let logical_schema = Schema::new(vec![
            Field::new("a", DataType::Int64, false), // Different type
            Field::new("b", DataType::Utf8, true),
            Field::new("c", DataType::Float64, true), // Missing from physical
        ]);

        (physical_schema, logical_schema)
    }

    #[test]
    fn test_rewrite_column_with_type_cast() {
        let (physical_schema, logical_schema) = create_test_schema();

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
        let column_expr = Arc::new(Column::new("a", 0));

        let result = adapter.rewrite(column_expr).unwrap();

        // Should be wrapped in a cast expression
        assert!(result.as_any().downcast_ref::<CastExpr>().is_some());
    }

    #[test]
    fn test_rewrite_multi_column_expr_with_type_cast() {
        let (physical_schema, logical_schema) = create_test_schema();
        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));

        // Create a complex expression: (a + 5) OR (c > 0.0) that tests the recursive case of the rewriter
        let column_a = Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>;
        let column_c = Arc::new(Column::new("c", 2)) as Arc<dyn PhysicalExpr>;
        let expr = expressions::BinaryExpr::new(
            Arc::clone(&column_a),
            Operator::Plus,
            Arc::new(expressions::Literal::new(ScalarValue::Int64(Some(5)))),
        );
        let expr = expressions::BinaryExpr::new(
            Arc::new(expr),
            Operator::Or,
            Arc::new(expressions::BinaryExpr::new(
                Arc::clone(&column_c),
                Operator::Gt,
                Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.0)))),
            )),
        );

        let result = adapter.rewrite(Arc::new(expr)).unwrap();
        println!("Rewritten expression: {result}");

        let expected = expressions::BinaryExpr::new(
            Arc::new(CastExpr::new(
                Arc::new(Column::new("a", 0)),
                DataType::Int64,
                None,
            )),
            Operator::Plus,
            Arc::new(expressions::Literal::new(ScalarValue::Int64(Some(5)))),
        );
        let expected = Arc::new(expressions::BinaryExpr::new(
            Arc::new(expected),
            Operator::Or,
            Arc::new(expressions::BinaryExpr::new(
                lit(ScalarValue::Float64(None)), // c is missing, so it becomes null
                Operator::Gt,
                Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.0)))),
            )),
        )) as Arc<dyn PhysicalExpr>;

        assert_eq!(
            result.to_string(),
            expected.to_string(),
            "The rewritten expression did not match the expected output"
        );
    }

    #[test]
    fn test_rewrite_struct_column_incompatible() {
        let physical_schema = Schema::new(vec![Field::new(
            "data",
            DataType::Struct(vec![Field::new("field1", DataType::Binary, true)].into()),
            true,
        )]);

        let logical_schema = Schema::new(vec![Field::new(
            "data",
            DataType::Struct(vec![Field::new("field1", DataType::Int32, true)].into()),
            true,
        )]);

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
        let column_expr = Arc::new(Column::new("data", 0));

        let error_msg = adapter.rewrite(column_expr).unwrap_err().to_string();
        assert_contains!(error_msg, "Cannot cast column 'data'");
    }

    #[test]
    fn test_rewrite_struct_compatible_cast() {
        let physical_schema = Schema::new(vec![Field::new(
            "data",
            DataType::Struct(
                vec![
                    Field::new("id", DataType::Int32, false),
                    Field::new("name", DataType::Utf8, true),
                ]
                .into(),
            ),
            false,
        )]);

        let logical_schema = Schema::new(vec![Field::new(
            "data",
            DataType::Struct(
                vec![
                    Field::new("id", DataType::Int64, false),
                    Field::new("name", DataType::Utf8View, true),
                ]
                .into(),
            ),
            false,
        )]);

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
        let column_expr = Arc::new(Column::new("data", 0));

        let result = adapter.rewrite(column_expr).unwrap();

        let expected = Arc::new(CastExpr::new(
            Arc::new(Column::new("data", 0)),
            DataType::Struct(
                vec![
                    Field::new("id", DataType::Int64, false),
                    Field::new("name", DataType::Utf8View, true),
                ]
                .into(),
            ),
            None,
        )) as Arc<dyn PhysicalExpr>;

        assert_eq!(result.to_string(), expected.to_string());
    }

    #[test]
    fn test_rewrite_missing_column() -> Result<()> {
        let (physical_schema, logical_schema) = create_test_schema();

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
        let column_expr = Arc::new(Column::new("c", 2));

        let result = adapter.rewrite(column_expr)?;

        // Should be replaced with a literal null
        if let Some(literal) = result.as_any().downcast_ref::<expressions::Literal>() {
            assert_eq!(*literal.value(), ScalarValue::Float64(None));
        } else {
            panic!("Expected literal expression");
        }

        Ok(())
    }

    #[test]
    fn test_rewrite_missing_column_non_nullable_error() {
        let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let logical_schema = Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Utf8, false), // Missing and non-nullable
        ]);

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
        let column_expr = Arc::new(Column::new("b", 1));

        let error_msg = adapter.rewrite(column_expr).unwrap_err().to_string();
        assert_contains!(error_msg, "Non-nullable column 'b' is missing");
    }

    #[test]
    fn test_rewrite_missing_column_nullable() {
        let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let logical_schema = Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Utf8, true), // Missing but nullable
        ]);

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
        let column_expr = Arc::new(Column::new("b", 1));

        let result = adapter.rewrite(column_expr).unwrap();

        let expected =
            Arc::new(Literal::new(ScalarValue::Utf8(None))) as Arc<dyn PhysicalExpr>;

        assert_eq!(result.to_string(), expected.to_string());
    }

    #[test]
    fn test_rewrite_partition_column() -> Result<()> {
        let (physical_schema, logical_schema) = create_test_schema();

        let partition_field =
            Arc::new(Field::new("partition_col", DataType::Utf8, false));
        let partition_value = ScalarValue::Utf8(Some("test_value".to_string()));
        let partition_values = vec![(partition_field, partition_value)];

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
        let adapter = adapter.with_partition_values(partition_values);

        let column_expr = Arc::new(Column::new("partition_col", 0));
        let result = adapter.rewrite(column_expr)?;

        // Should be replaced with the partition value
        if let Some(literal) = result.as_any().downcast_ref::<expressions::Literal>() {
            assert_eq!(
                *literal.value(),
                ScalarValue::Utf8(Some("test_value".to_string()))
            );
        } else {
            panic!("Expected literal expression");
        }

        Ok(())
    }

    #[test]
    fn test_rewrite_no_change_needed() -> Result<()> {
        let (physical_schema, logical_schema) = create_test_schema();

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
        let column_expr = Arc::new(Column::new("b", 1)) as Arc<dyn PhysicalExpr>;

        let result = adapter.rewrite(Arc::clone(&column_expr))?;

        // Should be the same expression (no transformation needed)
        // We compare the underlying pointer through the trait object
        assert!(std::ptr::eq(
            column_expr.as_ref() as *const dyn PhysicalExpr,
            result.as_ref() as *const dyn PhysicalExpr
        ));

        Ok(())
    }

    #[test]
    fn test_non_nullable_missing_column_error() {
        let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let logical_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false), // Non-nullable missing column
        ]);

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
        let column_expr = Arc::new(Column::new("b", 1));

        let result = adapter.rewrite(column_expr);
        assert!(result.is_err());
        assert_contains!(
            result.unwrap_err().to_string(),
            "Non-nullable column 'b' is missing from the physical schema"
        );
    }

    /// Helper function to project expressions onto a RecordBatch
    fn batch_project(
        expr: Vec<Arc<dyn PhysicalExpr>>,
        batch: &RecordBatch,
        schema: SchemaRef,
    ) -> Result<RecordBatch> {
        let arrays = expr
            .iter()
            .map(|expr| {
                expr.evaluate(batch)
                    .and_then(|v| v.into_array(batch.num_rows()))
            })
            .collect::<Result<Vec<_>>>()?;

        if arrays.is_empty() {
            let options =
                RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
            RecordBatch::try_new_with_options(Arc::clone(&schema), arrays, &options)
                .map_err(Into::into)
        } else {
            RecordBatch::try_new(Arc::clone(&schema), arrays).map_err(Into::into)
        }
    }

    /// Example showing how we can use the `DefaultPhysicalExprAdapter` to adapt RecordBatches during a scan
    /// to apply projections, type conversions and handling of missing columns all at once.
    #[test]
    fn test_adapt_batches() {
        let physical_batch = record_batch!(
            ("a", Int32, vec![Some(1), None, Some(3)]),
            ("extra", Utf8, vec![Some("x"), Some("y"), None])
        )
        .unwrap();

        let physical_schema = physical_batch.schema();

        let logical_schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, true), // Different type
            Field::new("b", DataType::Utf8, true),  // Missing from physical
        ]));

        let projection = vec![
            col("b", &logical_schema).unwrap(),
            col("a", &logical_schema).unwrap(),
        ];

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter =
            factory.create(Arc::clone(&logical_schema), Arc::clone(&physical_schema));

        let adapted_projection = projection
            .into_iter()
            .map(|expr| adapter.rewrite(expr).unwrap())
            .collect_vec();

        let adapted_schema = Arc::new(Schema::new(
            adapted_projection
                .iter()
                .map(|expr| expr.return_field(&physical_schema).unwrap())
                .collect_vec(),
        ));

        let res = batch_project(
            adapted_projection,
            &physical_batch,
            Arc::clone(&adapted_schema),
        )
        .unwrap();

        assert_eq!(res.num_columns(), 2);
        assert_eq!(res.column(0).data_type(), &DataType::Utf8);
        assert_eq!(res.column(1).data_type(), &DataType::Int64);
        assert_eq!(
            res.column(0)
                .as_any()
                .downcast_ref::<arrow::array::StringArray>()
                .unwrap()
                .iter()
                .collect_vec(),
            vec![None, None, None]
        );
        assert_eq!(
            res.column(1)
                .as_any()
                .downcast_ref::<arrow::array::Int64Array>()
                .unwrap()
                .iter()
                .collect_vec(),
            vec![Some(1), None, Some(3)]
        );
    }

    #[test]
    fn test_try_rewrite_struct_field_access() {
        // Test the core logic of try_rewrite_struct_field_access
        let physical_schema = Schema::new(vec![Field::new(
            "struct_col",
            DataType::Struct(
                vec![Field::new("existing_field", DataType::Int32, true)].into(),
            ),
            true,
        )]);

        let logical_schema = Schema::new(vec![Field::new(
            "struct_col",
            DataType::Struct(
                vec![
                    Field::new("existing_field", DataType::Int32, true),
                    Field::new("missing_field", DataType::Utf8, true),
                ]
                .into(),
            ),
            true,
        )]);

        let rewriter = DefaultPhysicalExprAdapterRewriter {
            logical_file_schema: &logical_schema,
            physical_file_schema: &physical_schema,
            partition_fields: &[],
        };

        // A plain column reference is not a `get_field` call, so the rewriter returns None and leaves it alone
        let column = Arc::new(Column::new("struct_col", 0)) as Arc<dyn PhysicalExpr>;
        let result = rewriter.try_rewrite_struct_field_access(&column).unwrap();
        assert!(result.is_none());

        // The actual test for the get_field expression would require creating a proper ScalarFunctionExpr
        // with ScalarUDF, which is complex to set up in a unit test. The integration tests in
        // datafusion/core/tests/parquet/schema_adapter.rs provide better coverage for this functionality.
    }
}