// datafusion_physical_expr/schema_rewriter.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Physical expression schema rewriting utilities
19
20use std::sync::Arc;
21
22use arrow::compute::can_cast_types;
23use arrow::datatypes::{FieldRef, Schema, SchemaRef};
24use datafusion_common::{
25    exec_err,
26    tree_node::{Transformed, TransformedResult, TreeNode},
27    Result, ScalarValue,
28};
29use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
30
31use crate::expressions::{self, CastExpr, Column};
32
33/// Trait for adapting physical expressions to match a target schema.
34///
35/// This is used in file scans to rewrite expressions so that they can be evaluated
36/// against the physical schema of the file being scanned. It allows for handling
37/// differences between logical and physical schemas, such as type mismatches or missing columns.
38///
/// You can create a custom implementation of this trait to handle specific rewriting logic.
40/// For example, to fill in missing columns with default values instead of nulls:
41///
42/// ```rust
43/// use datafusion_physical_expr::schema_rewriter::{PhysicalExprAdapter, PhysicalExprAdapterFactory};
44/// use arrow::datatypes::{Schema, Field, DataType, FieldRef, SchemaRef};
45/// use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
46/// use datafusion_common::{Result, ScalarValue, tree_node::{Transformed, TransformedResult, TreeNode}};
47/// use datafusion_physical_expr::expressions::{self, Column};
48/// use std::sync::Arc;
49///
50/// #[derive(Debug)]
51/// pub struct CustomPhysicalExprAdapter {
52///     logical_file_schema: SchemaRef,
53///     physical_file_schema: SchemaRef,
54/// }
55///
56/// impl PhysicalExprAdapter for CustomPhysicalExprAdapter {
57///     fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
58///         expr.transform(|expr| {
59///             if let Some(column) = expr.as_any().downcast_ref::<Column>() {
60///                 // Check if the column exists in the physical schema
61///                 if self.physical_file_schema.index_of(column.name()).is_err() {
62///                     // If the column is missing, fill it with a default value instead of null
63///                     // The default value could be stored in the table schema's column metadata for example.
64///                     let default_value = ScalarValue::Int32(Some(0));
65///                     return Ok(Transformed::yes(expressions::lit(default_value)));
66///                 }
67///             }
68///             // If the column exists, return it as is
69///             Ok(Transformed::no(expr))
70///         }).data()
71///     }
72///
73///     fn with_partition_values(
74///         &self,
75///         partition_values: Vec<(FieldRef, ScalarValue)>,
76///     ) -> Arc<dyn PhysicalExprAdapter> {
77///         // For simplicity, this example ignores partition values
78///         Arc::new(CustomPhysicalExprAdapter {
79///             logical_file_schema: self.logical_file_schema.clone(),
80///             physical_file_schema: self.physical_file_schema.clone(),
81///         })
82///     }
83/// }
84///
85/// #[derive(Debug)]
86/// pub struct CustomPhysicalExprAdapterFactory;
87///
88/// impl PhysicalExprAdapterFactory for CustomPhysicalExprAdapterFactory {
89///     fn create(
90///         &self,
91///         logical_file_schema: SchemaRef,
92///         physical_file_schema: SchemaRef,
93///     ) -> Arc<dyn PhysicalExprAdapter> {
94///         Arc::new(CustomPhysicalExprAdapter {
95///             logical_file_schema,
96///             physical_file_schema,
97///         })
98///     }
99/// }
100/// ```
pub trait PhysicalExprAdapter: Send + Sync + std::fmt::Debug {
    /// Rewrite a physical expression to match the target schema.
    ///
    /// This method should return a transformed expression that can be
    /// evaluated against the physical file schema. The logical and physical
    /// schemas (and any partition values) are supplied when the adapter is
    /// built, see [`PhysicalExprAdapterFactory::create`] and
    /// [`PhysicalExprAdapter::with_partition_values`].
    ///
    /// Arguments:
    /// - `expr`: The physical expression to rewrite.
    ///
    /// Returns:
    /// - `Arc<dyn PhysicalExpr>`: The rewritten physical expression that can be evaluated against the physical schema.
    fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>>;

    /// Return a new adapter that also carries the given partition values.
    ///
    /// Partition columns are handled as if they were columns appended onto the
    /// logical file schema: references to them can be rewritten into literals.
    fn with_partition_values(
        &self,
        partition_values: Vec<(FieldRef, ScalarValue)>,
    ) -> Arc<dyn PhysicalExprAdapter>;
}
122
/// Factory for building [`PhysicalExprAdapter`] instances bound to a pair of schemas.
pub trait PhysicalExprAdapterFactory: Send + Sync + std::fmt::Debug {
    /// Create a new instance of the physical expression adapter.
    ///
    /// Arguments:
    /// - `logical_file_schema`: The logical schema of the table being queried,
    ///   excluding any partition columns.
    /// - `physical_file_schema`: The schema of the file actually being scanned.
    fn create(
        &self,
        logical_file_schema: SchemaRef,
        physical_file_schema: SchemaRef,
    ) -> Arc<dyn PhysicalExprAdapter>;
}
131
/// Factory that creates [`DefaultPhysicalExprAdapter`] instances.
#[derive(Debug, Clone)]
pub struct DefaultPhysicalExprAdapterFactory;
134
135impl PhysicalExprAdapterFactory for DefaultPhysicalExprAdapterFactory {
136    fn create(
137        &self,
138        logical_file_schema: SchemaRef,
139        physical_file_schema: SchemaRef,
140    ) -> Arc<dyn PhysicalExprAdapter> {
141        Arc::new(DefaultPhysicalExprAdapter {
142            logical_file_schema,
143            physical_file_schema,
144            partition_values: Vec::new(),
145        })
146    }
147}
148
149/// Default implementation for rewriting physical expressions to match different schemas.
150///
151/// # Example
152///
153/// ```rust
154/// use datafusion_physical_expr::schema_rewriter::{DefaultPhysicalExprAdapterFactory, PhysicalExprAdapterFactory};
155/// use arrow::datatypes::Schema;
156/// use std::sync::Arc;
157///
158/// # fn example(
159/// #     predicate: std::sync::Arc<dyn datafusion_physical_expr_common::physical_expr::PhysicalExpr>,
160/// #     physical_file_schema: &Schema,
161/// #     logical_file_schema: &Schema,
162/// # ) -> datafusion_common::Result<()> {
163/// let factory = DefaultPhysicalExprAdapterFactory;
164/// let adapter = factory.create(Arc::new(logical_file_schema.clone()), Arc::new(physical_file_schema.clone()));
165/// let adapted_predicate = adapter.rewrite(predicate)?;
166/// # Ok(())
167/// # }
168/// ```
#[derive(Debug, Clone)]
pub struct DefaultPhysicalExprAdapter {
    // Logical (table) schema of the file, excluding partition columns.
    logical_file_schema: SchemaRef,
    // Schema of the file actually being scanned.
    physical_file_schema: SchemaRef,
    // Partition column (field, value) pairs; column references matching a
    // partition field are rewritten into literals. Empty until
    // `with_partition_values` is called.
    partition_values: Vec<(FieldRef, ScalarValue)>,
}
175
176impl DefaultPhysicalExprAdapter {
177    /// Create a new instance of the default physical expression adapter.
178    ///
179    /// This adapter rewrites expressions to match the physical schema of the file being scanned,
180    /// handling type mismatches and missing columns by filling them with default values.
181    pub fn new(logical_file_schema: SchemaRef, physical_file_schema: SchemaRef) -> Self {
182        Self {
183            logical_file_schema,
184            physical_file_schema,
185            partition_values: Vec::new(),
186        }
187    }
188}
189
190impl PhysicalExprAdapter for DefaultPhysicalExprAdapter {
191    fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
192        let rewriter = DefaultPhysicalExprAdapterRewriter {
193            logical_file_schema: &self.logical_file_schema,
194            physical_file_schema: &self.physical_file_schema,
195            partition_fields: &self.partition_values,
196        };
197        expr.transform(|expr| rewriter.rewrite_expr(Arc::clone(&expr)))
198            .data()
199    }
200
201    fn with_partition_values(
202        &self,
203        partition_values: Vec<(FieldRef, ScalarValue)>,
204    ) -> Arc<dyn PhysicalExprAdapter> {
205        Arc::new(DefaultPhysicalExprAdapter {
206            partition_values,
207            ..self.clone()
208        })
209    }
210}
211
/// Borrowed view over a [`DefaultPhysicalExprAdapter`]'s state used while
/// walking an expression tree.
struct DefaultPhysicalExprAdapterRewriter<'a> {
    // Logical (table) schema, excluding partition columns.
    logical_file_schema: &'a Schema,
    // Schema of the file actually being scanned.
    physical_file_schema: &'a Schema,
    // Partition column (field, value) pairs used to replace partition
    // column references with literal values.
    partition_fields: &'a [(FieldRef, ScalarValue)],
}
217
impl<'a> DefaultPhysicalExprAdapterRewriter<'a> {
    /// Rewrite a single expression node. Only `Column` nodes are transformed;
    /// every other node is passed through unchanged (the tree walk in
    /// `DefaultPhysicalExprAdapter::rewrite` handles recursion).
    fn rewrite_expr(
        &self,
        expr: Arc<dyn PhysicalExpr>,
    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
        if let Some(column) = expr.as_any().downcast_ref::<Column>() {
            return self.rewrite_column(Arc::clone(&expr), column);
        }

        Ok(Transformed::no(expr))
    }

    /// Rewrite one column reference so it can be evaluated against the
    /// physical schema: partition columns become literals, missing nullable
    /// columns become null literals, mismatched indexes are re-resolved, and
    /// mismatched data types are bridged with a cast to the logical type.
    fn rewrite_column(
        &self,
        expr: Arc<dyn PhysicalExpr>,
        column: &Column,
    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
        // Get the logical field for this column if it exists in the logical schema
        let logical_field = match self.logical_file_schema.field_with_name(column.name())
        {
            Ok(field) => field,
            Err(e) => {
                // If the column is a partition field, we can use the partition value
                if let Some(partition_value) = self.get_partition_value(column.name()) {
                    return Ok(Transformed::yes(expressions::lit(partition_value)));
                }
                // This can be hit if a custom rewrite injected a reference to a column that doesn't exist in the logical schema.
                // For example, a pre-computed column that is kept only in the physical schema.
                // If the column exists in the physical schema, we can still use it.
                if let Ok(physical_field) =
                    self.physical_file_schema.field_with_name(column.name())
                {
                    // If the column exists in the physical schema, we can use it in place of the logical column.
                    // This is nice to users because if they do a rewrite that results in something like `physical_int32_col = 123u64`
                    // we'll at least handle the casts for them.
                    physical_field
                } else {
                    // A completely unknown column that doesn't exist in either schema!
                    // This should probably never be hit unless something upstream broke, but nonetheless it's better
                    // for us to return a handleable error than to panic / do something unexpected.
                    return Err(e.into());
                }
            }
        };

        // Check if the column exists in the physical schema
        let physical_column_index =
            match self.physical_file_schema.index_of(column.name()) {
                Ok(index) => index,
                Err(_) => {
                    // A missing non-nullable column cannot be filled with nulls,
                    // so this is an error rather than a silent rewrite.
                    if !logical_field.is_nullable() {
                        return exec_err!(
                        "Non-nullable column '{}' is missing from the physical schema",
                        column.name()
                    );
                    }
                    // If the column is missing from the physical schema fill it in with nulls as `SchemaAdapter` would do.
                    // TODO: do we need to sync this with what the `SchemaAdapter` actually does?
                    // While the default implementation fills in nulls in theory a custom `SchemaAdapter` could do something else!
                    // See https://github.com/apache/datafusion/issues/16527
                    let null_value =
                        ScalarValue::Null.cast_to(logical_field.data_type())?;
                    return Ok(Transformed::yes(expressions::lit(null_value)));
                }
            };
        let physical_field = self.physical_file_schema.field(physical_column_index);

        let column = match (
            column.index() == physical_column_index,
            logical_field.data_type() == physical_field.data_type(),
        ) {
            // If the column index matches and the data types match, we can use the column as is
            (true, true) => return Ok(Transformed::no(expr)),
            // If the indexes or data types do not match, we need to create a new column expression
            (true, _) => column.clone(),
            (false, _) => {
                // Re-resolve the column index against the physical schema.
                Column::new_with_schema(logical_field.name(), self.physical_file_schema)?
            }
        };

        if logical_field.data_type() == physical_field.data_type() {
            // If the data types match, we can use the column as is
            return Ok(Transformed::yes(Arc::new(column)));
        }

        // We need to cast the column to the logical data type
        // TODO: add optimization to move the cast from the column to literal expressions in the case of `col = 123`
        // since that's much cheaper to evaluate.
        // See https://github.com/apache/datafusion/issues/15780#issuecomment-2824716928
        if !can_cast_types(physical_field.data_type(), logical_field.data_type()) {
            return exec_err!(
                "Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type)",
                column.name(),
                physical_field.data_type(),
                logical_field.data_type()
            );
        }

        let cast_expr = Arc::new(CastExpr::new(
            Arc::new(column),
            logical_field.data_type().clone(),
            None,
        ));

        Ok(Transformed::yes(cast_expr))
    }

    /// Look up the partition value recorded for `column_name`, if any.
    fn get_partition_value(&self, column_name: &str) -> Option<ScalarValue> {
        self.partition_fields
            .iter()
            .find(|(field, _)| field.name() == column_name)
            .map(|(_, value)| value.clone())
    }
}
332
#[cfg(test)]
mod tests {
    use crate::expressions::{col, lit};

    use super::*;
    use arrow::{
        array::{RecordBatch, RecordBatchOptions},
        datatypes::{DataType, Field, Schema, SchemaRef},
    };
    use datafusion_common::{record_batch, ScalarValue};
    use datafusion_expr::Operator;
    use itertools::Itertools;
    use std::sync::Arc;

    /// Build a (physical, logical) schema pair exercising a type mismatch on
    /// column "a" (Int32 vs Int64) and a column "c" that exists only in the
    /// logical schema.
    fn create_test_schema() -> (Schema, Schema) {
        let physical_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, true),
        ]);

        let logical_schema = Schema::new(vec![
            Field::new("a", DataType::Int64, false), // Different type
            Field::new("b", DataType::Utf8, true),
            Field::new("c", DataType::Float64, true), // Missing from physical
        ]);

        (physical_schema, logical_schema)
    }

    #[test]
    fn test_rewrite_column_with_type_cast() {
        let (physical_schema, logical_schema) = create_test_schema();

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
        let column_expr = Arc::new(Column::new("a", 0));

        let result = adapter.rewrite(column_expr).unwrap();

        // Should be wrapped in a cast expression
        assert!(result.as_any().downcast_ref::<CastExpr>().is_some());
    }

    #[test]
    fn test_rewrite_multi_column_expr_with_type_cast() {
        let (physical_schema, logical_schema) = create_test_schema();
        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));

        // Create a complex expression: (a + 5) OR (c > 0.0) that tests the recursive case of the rewriter
        let column_a = Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>;
        let column_c = Arc::new(Column::new("c", 2)) as Arc<dyn PhysicalExpr>;
        let expr = expressions::BinaryExpr::new(
            Arc::clone(&column_a),
            Operator::Plus,
            Arc::new(expressions::Literal::new(ScalarValue::Int64(Some(5)))),
        );
        let expr = expressions::BinaryExpr::new(
            Arc::new(expr),
            Operator::Or,
            Arc::new(expressions::BinaryExpr::new(
                Arc::clone(&column_c),
                Operator::Gt,
                Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.0)))),
            )),
        );

        let result = adapter.rewrite(Arc::new(expr)).unwrap();
        println!("Rewritten expression: {result}");

        // "a" gets cast to Int64; "c" is missing from the physical schema and
        // becomes a null literal.
        let expected = expressions::BinaryExpr::new(
            Arc::new(CastExpr::new(
                Arc::new(Column::new("a", 0)),
                DataType::Int64,
                None,
            )),
            Operator::Plus,
            Arc::new(expressions::Literal::new(ScalarValue::Int64(Some(5)))),
        );
        let expected = Arc::new(expressions::BinaryExpr::new(
            Arc::new(expected),
            Operator::Or,
            Arc::new(expressions::BinaryExpr::new(
                lit(ScalarValue::Null),
                Operator::Gt,
                Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.0)))),
            )),
        )) as Arc<dyn PhysicalExpr>;

        assert_eq!(
            result.to_string(),
            expected.to_string(),
            "The rewritten expression did not match the expected output"
        );
    }

    #[test]
    fn test_rewrite_missing_column() -> Result<()> {
        let (physical_schema, logical_schema) = create_test_schema();

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
        let column_expr = Arc::new(Column::new("c", 2));

        let result = adapter.rewrite(column_expr)?;

        // Should be replaced with a literal null
        if let Some(literal) = result.as_any().downcast_ref::<expressions::Literal>() {
            assert_eq!(*literal.value(), ScalarValue::Float64(None));
        } else {
            panic!("Expected literal expression");
        }

        Ok(())
    }

    #[test]
    fn test_rewrite_partition_column() -> Result<()> {
        let (physical_schema, logical_schema) = create_test_schema();

        let partition_field =
            Arc::new(Field::new("partition_col", DataType::Utf8, false));
        let partition_value = ScalarValue::Utf8(Some("test_value".to_string()));
        let partition_values = vec![(partition_field, partition_value)];

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
        let adapter = adapter.with_partition_values(partition_values);

        let column_expr = Arc::new(Column::new("partition_col", 0));
        let result = adapter.rewrite(column_expr)?;

        // Should be replaced with the partition value
        if let Some(literal) = result.as_any().downcast_ref::<expressions::Literal>() {
            assert_eq!(
                *literal.value(),
                ScalarValue::Utf8(Some("test_value".to_string()))
            );
        } else {
            panic!("Expected literal expression");
        }

        Ok(())
    }

    #[test]
    fn test_rewrite_no_change_needed() -> Result<()> {
        let (physical_schema, logical_schema) = create_test_schema();

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
        let column_expr = Arc::new(Column::new("b", 1)) as Arc<dyn PhysicalExpr>;

        let result = adapter.rewrite(Arc::clone(&column_expr))?;

        // Should be the same expression (no transformation needed)
        // We compare the underlying pointer through the trait object
        assert!(std::ptr::eq(
            column_expr.as_ref() as *const dyn PhysicalExpr,
            result.as_ref() as *const dyn PhysicalExpr
        ));

        Ok(())
    }

    #[test]
    fn test_non_nullable_missing_column_error() {
        let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let logical_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false), // Non-nullable missing column
        ]);

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
        let column_expr = Arc::new(Column::new("b", 1));

        let result = adapter.rewrite(column_expr);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Non-nullable column 'b' is missing"));
    }

    /// Roughly stolen from ProjectionExec
    fn batch_project(
        expr: Vec<Arc<dyn PhysicalExpr>>,
        batch: &RecordBatch,
        schema: SchemaRef,
    ) -> Result<RecordBatch> {
        let arrays = expr
            .iter()
            .map(|expr| {
                expr.evaluate(batch)
                    .and_then(|v| v.into_array(batch.num_rows()))
            })
            .collect::<Result<Vec<_>>>()?;

        if arrays.is_empty() {
            let options =
                RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
            RecordBatch::try_new_with_options(Arc::clone(&schema), arrays, &options)
                .map_err(Into::into)
        } else {
            RecordBatch::try_new(Arc::clone(&schema), arrays).map_err(Into::into)
        }
    }

    /// Example showing how we can use the `DefaultPhysicalExprAdapter` to adapt RecordBatches during a scan
    /// to apply projections, type conversions and handling of missing columns all at once.
    #[test]
    fn test_adapt_batches() {
        let physical_batch = record_batch!(
            ("a", Int32, vec![Some(1), None, Some(3)]),
            ("extra", Utf8, vec![Some("x"), Some("y"), None])
        )
        .unwrap();

        let physical_schema = physical_batch.schema();

        let logical_schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, true), // Different type
            Field::new("b", DataType::Utf8, true),  // Missing from physical
        ]));

        let projection = vec![
            col("b", &logical_schema).unwrap(),
            col("a", &logical_schema).unwrap(),
        ];

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter =
            factory.create(Arc::clone(&logical_schema), Arc::clone(&physical_schema));

        let adapted_projection = projection
            .into_iter()
            .map(|expr| adapter.rewrite(expr).unwrap())
            .collect_vec();

        let adapted_schema = Arc::new(Schema::new(
            adapted_projection
                .iter()
                .map(|expr| expr.return_field(&physical_schema).unwrap())
                .collect_vec(),
        ));

        let res = batch_project(
            adapted_projection,
            &physical_batch,
            Arc::clone(&adapted_schema),
        )
        .unwrap();

        assert_eq!(res.num_columns(), 2);
        assert_eq!(res.column(0).data_type(), &DataType::Utf8);
        assert_eq!(res.column(1).data_type(), &DataType::Int64);
        assert_eq!(
            res.column(0)
                .as_any()
                .downcast_ref::<arrow::array::StringArray>()
                .unwrap()
                .iter()
                .collect_vec(),
            vec![None, None, None]
        );
        assert_eq!(
            res.column(1)
                .as_any()
                .downcast_ref::<arrow::array::Int64Array>()
                .unwrap()
                .iter()
                .collect_vec(),
            vec![Some(1), None, Some(3)]
        );
    }
}