// datafusion_physical_expr_adapter/schema_rewriter.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Physical expression schema rewriting utilities: [`PhysicalExprAdapter`],
19//! [`PhysicalExprAdapterFactory`], default implementations,
20//! and [`replace_columns_with_literals`].
21
22use std::borrow::Borrow;
23use std::collections::HashMap;
24use std::hash::Hash;
25use std::sync::Arc;
26
27use arrow::array::RecordBatch;
28use arrow::compute::can_cast_types;
29use arrow::datatypes::{DataType, Field, SchemaRef};
30use datafusion_common::{
31    Result, ScalarValue, exec_err,
32    nested_struct::validate_struct_compatibility,
33    tree_node::{Transformed, TransformedResult, TreeNode},
34};
35use datafusion_functions::core::getfield::GetFieldFunc;
36use datafusion_physical_expr::PhysicalExprSimplifier;
37use datafusion_physical_expr::expressions::CastColumnExpr;
38use datafusion_physical_expr::projection::{ProjectionExprs, Projector};
39use datafusion_physical_expr::{
40    ScalarFunctionExpr,
41    expressions::{self, Column},
42};
43use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
44use itertools::Itertools;
45
46/// Replace column references in the given physical expression with literal values.
47///
48/// Some use cases for this include:
49/// - Partition column pruning: When scanning partitioned data, partition column references
50///   can be replaced with their literal values for the specific partition being scanned.
51/// - Constant folding: In some cases, columns that can be proven to be constant
52///   from statistical analysis may be replaced with their literal values to optimize expression evaluation.
53/// - Filling in non-null default values: in a custom [`PhysicalExprAdapter`] implementation,
54///   column references can be replaced with default literal values instead of nulls.
55///
56/// # Arguments
57/// - `expr`: The physical expression in which to replace column references.
58/// - `replacements`: A mapping from column names to their corresponding literal `ScalarValue`s.
59///   Accepts various HashMap types including `HashMap<&str, &ScalarValue>`,
60///   `HashMap<String, ScalarValue>`, `HashMap<String, &ScalarValue>`, etc.
61///
62/// # Returns
63/// - `Result<Arc<dyn PhysicalExpr>>`: The rewritten physical expression with columns replaced by literals.
64pub fn replace_columns_with_literals<K, V>(
65    expr: Arc<dyn PhysicalExpr>,
66    replacements: &HashMap<K, V>,
67) -> Result<Arc<dyn PhysicalExpr>>
68where
69    K: Borrow<str> + Eq + Hash,
70    V: Borrow<ScalarValue>,
71{
72    expr.transform_down(|expr| {
73        if let Some(column) = expr.as_any().downcast_ref::<Column>()
74            && let Some(replacement_value) = replacements.get(column.name())
75        {
76            return Ok(Transformed::yes(expressions::lit(
77                replacement_value.borrow().clone(),
78            )));
79        }
80        Ok(Transformed::no(expr))
81    })
82    .data()
83}
84
85/// Trait for adapting [`PhysicalExpr`] expressions to match a target schema.
86///
87/// This is used in file scans to rewrite expressions so that they can be
88/// evaluated against the physical schema of the file being scanned. It allows
89/// for handling differences between logical and physical schemas, such as type
90/// mismatches or missing columns common in [Schema evolution] scenarios.
91///
92/// [Schema evolution]: https://www.dremio.com/wiki/schema-evolution/
93///
94/// ## Default Implementations
95///
96/// The default implementation [`DefaultPhysicalExprAdapter`]  handles common
97/// cases.
98///
99/// ## Custom Implementations
100///
101/// You can create a custom implementation of this trait to handle specific rewriting logic.
102/// For example, to fill in missing columns with default values instead of nulls:
103///
104/// ```rust
105/// use datafusion_physical_expr_adapter::{PhysicalExprAdapter, PhysicalExprAdapterFactory};
106/// use arrow::datatypes::{Schema, Field, DataType, FieldRef, SchemaRef};
107/// use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
108/// use datafusion_common::{Result, ScalarValue, tree_node::{Transformed, TransformedResult, TreeNode}};
109/// use datafusion_physical_expr::expressions::{self, Column};
110/// use std::sync::Arc;
111///
112/// #[derive(Debug)]
113/// pub struct CustomPhysicalExprAdapter {
114///     logical_file_schema: SchemaRef,
115///     physical_file_schema: SchemaRef,
116/// }
117///
118/// impl PhysicalExprAdapter for CustomPhysicalExprAdapter {
119///     fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
120///         expr.transform(|expr| {
121///             if let Some(column) = expr.as_any().downcast_ref::<Column>() {
122///                 // Check if the column exists in the physical schema
123///                 if self.physical_file_schema.index_of(column.name()).is_err() {
124///                     // If the column is missing, fill it with a default value instead of null
125///                     // The default value could be stored in the table schema's column metadata for example.
126///                     let default_value = ScalarValue::Int32(Some(0));
127///                     return Ok(Transformed::yes(expressions::lit(default_value)));
128///                 }
129///             }
130///             // If the column exists, return it as is
131///             Ok(Transformed::no(expr))
132///         }).data()
133///     }
134/// }
135///
136/// #[derive(Debug)]
137/// pub struct CustomPhysicalExprAdapterFactory;
138///
139/// impl PhysicalExprAdapterFactory for CustomPhysicalExprAdapterFactory {
140///     fn create(
141///         &self,
142///         logical_file_schema: SchemaRef,
143///         physical_file_schema: SchemaRef,
144///     ) -> Result<Arc<dyn PhysicalExprAdapter>> {
145///         Ok(Arc::new(CustomPhysicalExprAdapter {
146///             logical_file_schema,
147///             physical_file_schema,
148///         }))
149///     }
150/// }
151/// ```
pub trait PhysicalExprAdapter: Send + Sync + std::fmt::Debug {
    /// Rewrite a physical expression to match the target schema.
    ///
    /// This method should return a transformed expression that can be
    /// evaluated against the physical schema the adapter was created for
    /// (see [`PhysicalExprAdapterFactory::create`]).
    ///
    /// Arguments:
    /// - `expr`: The physical expression to rewrite, phrased in terms of the
    ///   logical schema the adapter was created with.
    ///
    /// Returns:
    /// - `Arc<dyn PhysicalExpr>`: The rewritten physical expression that can be evaluated against the physical schema.
    ///
    /// See Also:
    /// - [`replace_columns_with_literals`]: for replacing partition column references with their literal values.
    fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>>;
}
171
172/// Creates instances of [`PhysicalExprAdapter`] for given logical and physical schemas.
173///
174/// See [`DefaultPhysicalExprAdapterFactory`] for the default implementation.
pub trait PhysicalExprAdapterFactory: Send + Sync + std::fmt::Debug {
    /// Create a new instance of the physical expression adapter.
    ///
    /// Arguments:
    /// - `logical_file_schema`: schema of the table as seen by the query,
    ///   excluding any partition columns.
    /// - `physical_file_schema`: schema of the actual file being scanned.
    fn create(
        &self,
        logical_file_schema: SchemaRef,
        physical_file_schema: SchemaRef,
    ) -> Result<Arc<dyn PhysicalExprAdapter>>;
}
183
/// Factory producing [`DefaultPhysicalExprAdapter`] instances, the default
/// [`PhysicalExprAdapter`] implementation.
#[derive(Debug, Clone)]
pub struct DefaultPhysicalExprAdapterFactory;

impl PhysicalExprAdapterFactory for DefaultPhysicalExprAdapterFactory {
    fn create(
        &self,
        logical_file_schema: SchemaRef,
        physical_file_schema: SchemaRef,
    ) -> Result<Arc<dyn PhysicalExprAdapter>> {
        Ok(Arc::new(DefaultPhysicalExprAdapter {
            logical_file_schema,
            physical_file_schema,
        }))
    }
}
199
200/// Default implementation of [`PhysicalExprAdapter`] for rewriting physical
201/// expressions to match different schemas.
202///
203/// ## Overview
204///
205///  [`DefaultPhysicalExprAdapter`] rewrites physical expressions to match
206///  different schemas, including:
207///
208/// - **Type casting**: When logical and physical schemas have different types, expressions are
209///   automatically wrapped with cast operations. For example, `lit(ScalarValue::Int32(123)) = int64_column`
210///   gets rewritten to `lit(ScalarValue::Int32(123)) = cast(int64_column, 'Int32')`.
211///   Note that this does not attempt to simplify such expressions - that is done by shared simplifiers.
212///
213/// - **Missing columns**: When a column exists in the logical schema but not in the physical schema,
214///   references to it are replaced with null literals.
215///
216/// - **Struct field access**: Expressions like `struct_column.field_that_is_missing_in_schema` are
217///   rewritten to `null` when the field doesn't exist in the physical schema.
218///
219/// - **Default column values**: Partition column references can be replaced with their literal values
220///   when scanning specific partitions. See [`replace_columns_with_literals`] for more details.
221///
222/// # Example
223///
224/// ```rust
225/// # use datafusion_physical_expr_adapter::{DefaultPhysicalExprAdapterFactory, PhysicalExprAdapterFactory};
226/// # use arrow::datatypes::Schema;
227/// # use std::sync::Arc;
228/// #
229/// # fn example(
230/// #     predicate: std::sync::Arc<dyn datafusion_physical_expr_common::physical_expr::PhysicalExpr>,
231/// #     physical_file_schema: &Schema,
232/// #     logical_file_schema: &Schema,
233/// # ) -> datafusion_common::Result<()> {
234/// let factory = DefaultPhysicalExprAdapterFactory;
235/// let adapter =
236///     factory.create(Arc::new(logical_file_schema.clone()), Arc::new(physical_file_schema.clone()))?;
237/// let adapted_predicate = adapter.rewrite(predicate)?;
238/// # Ok(())
239/// # }
240/// ```
#[derive(Debug, Clone)]
pub struct DefaultPhysicalExprAdapter {
    // Schema of the table as the query sees it (without partition columns).
    logical_file_schema: SchemaRef,
    // Schema of the actual file on disk being scanned.
    physical_file_schema: SchemaRef,
}

impl DefaultPhysicalExprAdapter {
    /// Create a new instance of the default physical expression adapter.
    ///
    /// This adapter rewrites expressions to match the physical schema of the file being scanned,
    /// handling type mismatches with casts and filling in missing nullable columns with nulls.
    pub fn new(logical_file_schema: SchemaRef, physical_file_schema: SchemaRef) -> Self {
        Self {
            logical_file_schema,
            physical_file_schema,
        }
    }
}
259
260impl PhysicalExprAdapter for DefaultPhysicalExprAdapter {
261    fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
262        let rewriter = DefaultPhysicalExprAdapterRewriter {
263            logical_file_schema: Arc::clone(&self.logical_file_schema),
264            physical_file_schema: Arc::clone(&self.physical_file_schema),
265        };
266        expr.transform(|expr| rewriter.rewrite_expr(Arc::clone(&expr)))
267            .data()
268    }
269}
270
/// Internal helper that performs the node-by-node rewrite for
/// [`DefaultPhysicalExprAdapter`].
struct DefaultPhysicalExprAdapterRewriter {
    // Schema the incoming expression is phrased against.
    logical_file_schema: SchemaRef,
    // Schema of the file the rewritten expression will be evaluated against.
    physical_file_schema: SchemaRef,
}
275
276impl DefaultPhysicalExprAdapterRewriter {
277    fn rewrite_expr(
278        &self,
279        expr: Arc<dyn PhysicalExpr>,
280    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
281        if let Some(transformed) = self.try_rewrite_struct_field_access(&expr)? {
282            return Ok(Transformed::yes(transformed));
283        }
284
285        if let Some(column) = expr.as_any().downcast_ref::<Column>() {
286            return self.rewrite_column(Arc::clone(&expr), column);
287        }
288
289        Ok(Transformed::no(expr))
290    }
291
292    /// Attempt to rewrite struct field access expressions to return null if the field does not exist in the physical schema.
293    /// Note that this does *not* handle nested struct fields, only top-level struct field access.
294    /// See <https://github.com/apache/datafusion/issues/17114> for more details.
295    fn try_rewrite_struct_field_access(
296        &self,
297        expr: &Arc<dyn PhysicalExpr>,
298    ) -> Result<Option<Arc<dyn PhysicalExpr>>> {
299        let get_field_expr =
300            match ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(expr.as_ref()) {
301                Some(expr) => expr,
302                None => return Ok(None),
303            };
304
305        let source_expr = match get_field_expr.args().first() {
306            Some(expr) => expr,
307            None => return Ok(None),
308        };
309
310        let field_name_expr = match get_field_expr.args().get(1) {
311            Some(expr) => expr,
312            None => return Ok(None),
313        };
314
315        let lit = match field_name_expr
316            .as_any()
317            .downcast_ref::<expressions::Literal>()
318        {
319            Some(lit) => lit,
320            None => return Ok(None),
321        };
322
323        let field_name = match lit.value().try_as_str().flatten() {
324            Some(name) => name,
325            None => return Ok(None),
326        };
327
328        let column = match source_expr.as_any().downcast_ref::<Column>() {
329            Some(column) => column,
330            None => return Ok(None),
331        };
332
333        let physical_field =
334            match self.physical_file_schema.field_with_name(column.name()) {
335                Ok(field) => field,
336                Err(_) => return Ok(None),
337            };
338
339        let physical_struct_fields = match physical_field.data_type() {
340            DataType::Struct(fields) => fields,
341            _ => return Ok(None),
342        };
343
344        if physical_struct_fields
345            .iter()
346            .any(|f| f.name() == field_name)
347        {
348            return Ok(None);
349        }
350
351        let logical_field = match self.logical_file_schema.field_with_name(column.name())
352        {
353            Ok(field) => field,
354            Err(_) => return Ok(None),
355        };
356
357        let logical_struct_fields = match logical_field.data_type() {
358            DataType::Struct(fields) => fields,
359            _ => return Ok(None),
360        };
361
362        let logical_struct_field = match logical_struct_fields
363            .iter()
364            .find(|f| f.name() == field_name)
365        {
366            Some(field) => field,
367            None => return Ok(None),
368        };
369
370        let null_value = ScalarValue::Null.cast_to(logical_struct_field.data_type())?;
371        Ok(Some(expressions::lit(null_value)))
372    }
373
374    fn rewrite_column(
375        &self,
376        expr: Arc<dyn PhysicalExpr>,
377        column: &Column,
378    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
379        // Get the logical field for this column if it exists in the logical schema
380        let logical_field = match self.logical_file_schema.field_with_name(column.name())
381        {
382            Ok(field) => field,
383            Err(e) => {
384                // This can be hit if a custom rewrite injected a reference to a column that doesn't exist in the logical schema.
385                // For example, a pre-computed column that is kept only in the physical schema.
386                // If the column exists in the physical schema, we can still use it.
387                if let Ok(physical_field) =
388                    self.physical_file_schema.field_with_name(column.name())
389                {
390                    // If the column exists in the physical schema, we can use it in place of the logical column.
391                    // This is nice to users because if they do a rewrite that results in something like `physical_int32_col = 123u64`
392                    // we'll at least handle the casts for them.
393                    physical_field
394                } else {
395                    // A completely unknown column that doesn't exist in either schema!
396                    // This should probably never be hit unless something upstream broke, but nonetheless it's better
397                    // for us to return a handleable error than to panic / do something unexpected.
398                    return Err(e.into());
399                }
400            }
401        };
402
403        // Check if the column exists in the physical schema
404        let physical_column_index = match self
405            .physical_file_schema
406            .index_of(column.name())
407        {
408            Ok(index) => index,
409            Err(_) => {
410                if !logical_field.is_nullable() {
411                    return exec_err!(
412                        "Non-nullable column '{}' is missing from the physical schema",
413                        column.name()
414                    );
415                }
416                // If the column is missing from the physical schema fill it in with nulls.
417                // For a different behavior, provide a custom `PhysicalExprAdapter` implementation.
418                let null_value = ScalarValue::Null.cast_to(logical_field.data_type())?;
419                return Ok(Transformed::yes(expressions::lit(null_value)));
420            }
421        };
422        let physical_field = self.physical_file_schema.field(physical_column_index);
423
424        if column.index() == physical_column_index
425            && logical_field.data_type() == physical_field.data_type()
426        {
427            return Ok(Transformed::no(expr));
428        }
429
430        let column = self.resolve_column(column, physical_column_index)?;
431
432        if logical_field.data_type() == physical_field.data_type() {
433            // If the data types match, we can use the column as is
434            return Ok(Transformed::yes(Arc::new(column)));
435        }
436
437        // We need to cast the column to the logical data type
438        // TODO: add optimization to move the cast from the column to literal expressions in the case of `col = 123`
439        // since that's much cheaper to evalaute.
440        // See https://github.com/apache/datafusion/issues/15780#issuecomment-2824716928
441        self.create_cast_column_expr(column, logical_field)
442    }
443
444    /// Resolves a column expression, handling index and type mismatches.
445    ///
446    /// Returns the appropriate Column expression when the column's index or data type
447    /// don't match the physical schema. Assumes that the early-exit case (both index
448    /// and type match) has already been checked by the caller.
449    fn resolve_column(
450        &self,
451        column: &Column,
452        physical_column_index: usize,
453    ) -> Result<Column> {
454        if column.index() == physical_column_index {
455            Ok(column.clone())
456        } else {
457            Column::new_with_schema(column.name(), self.physical_file_schema.as_ref())
458        }
459    }
460
461    /// Validates type compatibility and creates a CastColumnExpr if needed.
462    ///
463    /// Checks whether the physical field can be cast to the logical field type,
464    /// handling both struct and scalar types. Returns a CastColumnExpr with the
465    /// appropriate configuration.
466    fn create_cast_column_expr(
467        &self,
468        column: Column,
469        logical_field: &Field,
470    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
471        // Look up the column index in the physical schema by name to ensure correctness.
472        let physical_column_index = self.physical_file_schema.index_of(column.name())?;
473        let actual_physical_field =
474            self.physical_file_schema.field(physical_column_index);
475
476        // For struct types, use validate_struct_compatibility which handles:
477        // - Missing fields in source (filled with nulls)
478        // - Extra fields in source (ignored)
479        // - Recursive validation of nested structs
480        // For non-struct types, use Arrow's can_cast_types
481        match (actual_physical_field.data_type(), logical_field.data_type()) {
482            (DataType::Struct(physical_fields), DataType::Struct(logical_fields)) => {
483                validate_struct_compatibility(
484                    physical_fields.as_ref(),
485                    logical_fields.as_ref(),
486                )?;
487            }
488            _ => {
489                let is_compatible = can_cast_types(
490                    actual_physical_field.data_type(),
491                    logical_field.data_type(),
492                );
493                if !is_compatible {
494                    return exec_err!(
495                        "Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type)",
496                        column.name(),
497                        actual_physical_field.data_type(),
498                        logical_field.data_type()
499                    );
500                }
501            }
502        }
503
504        let cast_expr = Arc::new(CastColumnExpr::new(
505            Arc::new(column),
506            Arc::new(actual_physical_field.clone()),
507            Arc::new(logical_field.clone()),
508            None,
509        ));
510
511        Ok(Transformed::yes(cast_expr))
512    }
513}
514
515/// Factory for creating [`BatchAdapter`] instances to adapt record batches
516/// to a target schema.
517///
518/// This binds a target schema and allows creating adapters for different source schemas.
519/// It handles:
520/// - **Column reordering**: Columns are reordered to match the target schema
521/// - **Type casting**: Automatic type conversion (e.g., Int32 to Int64)
522/// - **Missing columns**: Nullable columns missing from source are filled with nulls
523/// - **Struct field adaptation**: Nested struct fields are recursively adapted
524///
525/// ## Examples
526///
527/// ```rust
528/// use arrow::array::{Int32Array, Int64Array, StringArray, RecordBatch};
529/// use arrow::datatypes::{DataType, Field, Schema};
530/// use datafusion_physical_expr_adapter::BatchAdapterFactory;
531/// use std::sync::Arc;
532///
533/// // Target schema has different column order and types
534/// let target_schema = Arc::new(Schema::new(vec![
535///     Field::new("name", DataType::Utf8, true),
536///     Field::new("id", DataType::Int64, false),    // Int64 in target
537///     Field::new("score", DataType::Float64, true), // Missing from source
538/// ]));
539///
540/// // Source schema has different column order and Int32 for id
541/// let source_schema = Arc::new(Schema::new(vec![
542///     Field::new("id", DataType::Int32, false),    // Int32 in source
543///     Field::new("name", DataType::Utf8, true),
544///     // Note: 'score' column is missing from source
545/// ]));
546///
547/// // Create factory with target schema
548/// let factory = BatchAdapterFactory::new(Arc::clone(&target_schema));
549///
550/// // Create adapter for this specific source schema
551/// let adapter = factory.make_adapter(&source_schema).unwrap();
552///
553/// // Create a source batch
554/// let source_batch = RecordBatch::try_new(
555///     source_schema,
556///     vec![
557///         Arc::new(Int32Array::from(vec![1, 2, 3])),
558///         Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol"])),
559///     ],
560/// ).unwrap();
561///
562/// // Adapt the batch to match target schema
563/// let adapted = adapter.adapt_batch(&source_batch).unwrap();
564///
565/// assert_eq!(adapted.num_columns(), 3);
566/// assert_eq!(adapted.column(0).data_type(), &DataType::Utf8);   // name
567/// assert_eq!(adapted.column(1).data_type(), &DataType::Int64);  // id (cast from Int32)
568/// assert_eq!(adapted.column(2).data_type(), &DataType::Float64); // score (filled with nulls)
569/// ```
#[derive(Debug)]
pub struct BatchAdapterFactory {
    // Schema that adapted batches are converted to.
    target_schema: SchemaRef,
    // Factory used to rewrite per-column expressions for each source schema.
    expr_adapter_factory: Arc<dyn PhysicalExprAdapterFactory>,
}
575
576impl BatchAdapterFactory {
577    /// Create a new [`BatchAdapterFactory`] with the given target schema.
578    pub fn new(target_schema: SchemaRef) -> Self {
579        let expr_adapter_factory = Arc::new(DefaultPhysicalExprAdapterFactory);
580        Self {
581            target_schema,
582            expr_adapter_factory,
583        }
584    }
585
586    /// Set a custom [`PhysicalExprAdapterFactory`] to use when adapting expressions.
587    ///
588    /// Use this to customize behavior when adapting batches, e.g. to fill in missing values
589    /// with defaults instead of nulls.
590    ///
591    /// See [`PhysicalExprAdapter`] for more details.
592    pub fn with_adapter_factory(
593        self,
594        factory: Arc<dyn PhysicalExprAdapterFactory>,
595    ) -> Self {
596        Self {
597            expr_adapter_factory: factory,
598            ..self
599        }
600    }
601
602    /// Create a new [`BatchAdapter`] for the given source schema.
603    ///
604    /// Batches fed into this [`BatchAdapter`] *must* conform to the source schema,
605    /// no validation is performed at runtime to minimize overheads.
606    pub fn make_adapter(&self, source_schema: &SchemaRef) -> Result<BatchAdapter> {
607        let expr_adapter = self
608            .expr_adapter_factory
609            .create(Arc::clone(&self.target_schema), Arc::clone(source_schema))?;
610
611        let simplifier = PhysicalExprSimplifier::new(&self.target_schema);
612
613        let projection = ProjectionExprs::from_indices(
614            &(0..self.target_schema.fields().len()).collect_vec(),
615            &self.target_schema,
616        );
617
618        let adapted = projection
619            .try_map_exprs(|e| simplifier.simplify(expr_adapter.rewrite(e)?))?;
620        let projector = adapted.make_projector(source_schema)?;
621
622        Ok(BatchAdapter { projector })
623    }
624}
625
626/// Adapter for transforming record batches to match a target schema.
627///
628/// Create instances via [`BatchAdapterFactory`].
629///
630/// ## Performance
631///
632/// The adapter pre-computes the projection expressions during creation,
633/// so the [`adapt_batch`](BatchAdapter::adapt_batch) call is efficient and suitable
634/// for use in hot paths like streaming file scans.
#[derive(Debug)]
pub struct BatchAdapter {
    // Pre-computed projection that performs the reorder/cast/null-fill work.
    projector: Projector,
}
639
impl BatchAdapter {
    /// Adapt the given record batch to match the target schema.
    ///
    /// The input batch *must* conform to the source schema used when
    /// creating this adapter.
    ///
    /// All reordering/casting logic was baked into the projector at
    /// construction time, so this is a thin, hot-path-friendly delegate.
    pub fn adapt_batch(&self, batch: &RecordBatch) -> Result<RecordBatch> {
        self.projector.project_batch(batch)
    }
}
649
650#[cfg(test)]
651mod tests {
652    use super::*;
653    use arrow::array::{
654        BooleanArray, Int32Array, Int64Array, RecordBatch, RecordBatchOptions,
655        StringArray, StringViewArray, StructArray,
656    };
657    use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef};
658    use datafusion_common::{Result, ScalarValue, assert_contains, record_batch};
659    use datafusion_expr::Operator;
660    use datafusion_physical_expr::expressions::{Column, Literal, col, lit};
661    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
662    use itertools::Itertools;
663    use std::sync::Arc;
664
665    fn create_test_schema() -> (Schema, Schema) {
666        let physical_schema = Schema::new(vec![
667            Field::new("a", DataType::Int32, false),
668            Field::new("b", DataType::Utf8, true),
669        ]);
670
671        let logical_schema = Schema::new(vec![
672            Field::new("a", DataType::Int64, false), // Different type
673            Field::new("b", DataType::Utf8, true),
674            Field::new("c", DataType::Float64, true), // Missing from physical
675        ]);
676
677        (physical_schema, logical_schema)
678    }
679
680    #[test]
681    fn test_rewrite_column_with_type_cast() {
682        let (physical_schema, logical_schema) = create_test_schema();
683
684        let factory = DefaultPhysicalExprAdapterFactory;
685        let adapter = factory
686            .create(Arc::new(logical_schema), Arc::new(physical_schema))
687            .unwrap();
688        let column_expr = Arc::new(Column::new("a", 0));
689
690        let result = adapter.rewrite(column_expr).unwrap();
691
692        // Should be wrapped in a cast expression
693        assert!(result.as_any().downcast_ref::<CastColumnExpr>().is_some());
694    }
695
    #[test]
    fn test_rewrite_multi_column_expr_with_type_cast() {
        let (physical_schema, logical_schema) = create_test_schema();
        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory
            .create(Arc::new(logical_schema), Arc::new(physical_schema))
            .unwrap();

        // Create a complex expression: (a + 5) OR (c > 0.0) that tests the recursive case of the rewriter
        let column_a = Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>;
        let column_c = Arc::new(Column::new("c", 2)) as Arc<dyn PhysicalExpr>;
        let expr = expressions::BinaryExpr::new(
            Arc::clone(&column_a),
            Operator::Plus,
            Arc::new(Literal::new(ScalarValue::Int64(Some(5)))),
        );
        let expr = expressions::BinaryExpr::new(
            Arc::new(expr),
            Operator::Or,
            Arc::new(expressions::BinaryExpr::new(
                Arc::clone(&column_c),
                Operator::Gt,
                Arc::new(Literal::new(ScalarValue::Float64(Some(0.0)))),
            )),
        );

        let result = adapter.rewrite(Arc::new(expr)).unwrap();
        println!("Rewritten expression: {result}");

        // Expected: `a` is wrapped in an Int32 -> Int64 cast, and `c` (absent
        // from the physical schema) is replaced by a Float64 null literal.
        let expected = expressions::BinaryExpr::new(
            Arc::new(CastColumnExpr::new(
                Arc::new(Column::new("a", 0)),
                Arc::new(Field::new("a", DataType::Int32, false)),
                Arc::new(Field::new("a", DataType::Int64, false)),
                None,
            )),
            Operator::Plus,
            Arc::new(Literal::new(ScalarValue::Int64(Some(5)))),
        );
        let expected = Arc::new(expressions::BinaryExpr::new(
            Arc::new(expected),
            Operator::Or,
            Arc::new(expressions::BinaryExpr::new(
                lit(ScalarValue::Float64(None)), // c is missing, so it becomes null
                Operator::Gt,
                Arc::new(Literal::new(ScalarValue::Float64(Some(0.0)))),
            )),
        )) as Arc<dyn PhysicalExpr>;

        // Compare display forms; structural equality is not implemented for
        // these expression trees.
        assert_eq!(
            result.to_string(),
            expected.to_string(),
            "The rewritten expression did not match the expected output"
        );
    }
751
752    #[test]
753    fn test_rewrite_struct_column_incompatible() {
754        let physical_schema = Schema::new(vec![Field::new(
755            "data",
756            DataType::Struct(vec![Field::new("field1", DataType::Binary, true)].into()),
757            true,
758        )]);
759
760        let logical_schema = Schema::new(vec![Field::new(
761            "data",
762            DataType::Struct(vec![Field::new("field1", DataType::Int32, true)].into()),
763            true,
764        )]);
765
766        let factory = DefaultPhysicalExprAdapterFactory;
767        let adapter = factory
768            .create(Arc::new(logical_schema), Arc::new(physical_schema))
769            .unwrap();
770        let column_expr = Arc::new(Column::new("data", 0));
771
772        let error_msg = adapter.rewrite(column_expr).unwrap_err().to_string();
773        // validate_struct_compatibility provides more specific error about which field can't be cast
774        assert_contains!(
775            error_msg,
776            "Cannot cast struct field 'field1' from type Binary to type Int32"
777        );
778    }
779
780    #[test]
781    fn test_rewrite_struct_compatible_cast() {
782        let physical_schema = Schema::new(vec![Field::new(
783            "data",
784            DataType::Struct(
785                vec![
786                    Field::new("id", DataType::Int32, false),
787                    Field::new("name", DataType::Utf8, true),
788                ]
789                .into(),
790            ),
791            false,
792        )]);
793
794        let logical_schema = Schema::new(vec![Field::new(
795            "data",
796            DataType::Struct(
797                vec![
798                    Field::new("id", DataType::Int64, false),
799                    Field::new("name", DataType::Utf8View, true),
800                ]
801                .into(),
802            ),
803            false,
804        )]);
805
806        let factory = DefaultPhysicalExprAdapterFactory;
807        let adapter = factory
808            .create(Arc::new(logical_schema), Arc::new(physical_schema))
809            .unwrap();
810        let column_expr = Arc::new(Column::new("data", 0));
811
812        let result = adapter.rewrite(column_expr).unwrap();
813
814        let physical_struct_fields: Fields = vec![
815            Field::new("id", DataType::Int32, false),
816            Field::new("name", DataType::Utf8, true),
817        ]
818        .into();
819        let physical_field = Arc::new(Field::new(
820            "data",
821            DataType::Struct(physical_struct_fields),
822            false,
823        ));
824
825        let logical_struct_fields: Fields = vec![
826            Field::new("id", DataType::Int64, false),
827            Field::new("name", DataType::Utf8View, true),
828        ]
829        .into();
830        let logical_field = Arc::new(Field::new(
831            "data",
832            DataType::Struct(logical_struct_fields),
833            false,
834        ));
835
836        let expected = Arc::new(CastColumnExpr::new(
837            Arc::new(Column::new("data", 0)),
838            physical_field,
839            logical_field,
840            None,
841        )) as Arc<dyn PhysicalExpr>;
842
843        assert_eq!(result.to_string(), expected.to_string());
844    }
845
846    #[test]
847    fn test_rewrite_missing_column() -> Result<()> {
848        let (physical_schema, logical_schema) = create_test_schema();
849
850        let factory = DefaultPhysicalExprAdapterFactory;
851        let adapter = factory
852            .create(Arc::new(logical_schema), Arc::new(physical_schema))
853            .unwrap();
854        let column_expr = Arc::new(Column::new("c", 2));
855
856        let result = adapter.rewrite(column_expr)?;
857
858        // Should be replaced with a literal null
859        if let Some(literal) = result.as_any().downcast_ref::<Literal>() {
860            assert_eq!(*literal.value(), ScalarValue::Float64(None));
861        } else {
862            panic!("Expected literal expression");
863        }
864
865        Ok(())
866    }
867
868    #[test]
869    fn test_rewrite_missing_column_non_nullable_error() {
870        let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
871        let logical_schema = Schema::new(vec![
872            Field::new("a", DataType::Int64, false),
873            Field::new("b", DataType::Utf8, false), // Missing and non-nullable
874        ]);
875
876        let factory = DefaultPhysicalExprAdapterFactory;
877        let adapter = factory
878            .create(Arc::new(logical_schema), Arc::new(physical_schema))
879            .unwrap();
880        let column_expr = Arc::new(Column::new("b", 1));
881
882        let error_msg = adapter.rewrite(column_expr).unwrap_err().to_string();
883        assert_contains!(error_msg, "Non-nullable column 'b' is missing");
884    }
885
886    #[test]
887    fn test_rewrite_missing_column_nullable() {
888        let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
889        let logical_schema = Schema::new(vec![
890            Field::new("a", DataType::Int64, false),
891            Field::new("b", DataType::Utf8, true), // Missing but nullable
892        ]);
893
894        let factory = DefaultPhysicalExprAdapterFactory;
895        let adapter = factory
896            .create(Arc::new(logical_schema), Arc::new(physical_schema))
897            .unwrap();
898        let column_expr = Arc::new(Column::new("b", 1));
899
900        let result = adapter.rewrite(column_expr).unwrap();
901
902        let expected =
903            Arc::new(Literal::new(ScalarValue::Utf8(None))) as Arc<dyn PhysicalExpr>;
904
905        assert_eq!(result.to_string(), expected.to_string());
906    }
907
908    #[test]
909    fn test_replace_columns_with_literals() -> Result<()> {
910        let partition_value = ScalarValue::Utf8(Some("test_value".to_string()));
911        let replacements = HashMap::from([("partition_col", &partition_value)]);
912
913        let column_expr =
914            Arc::new(Column::new("partition_col", 0)) as Arc<dyn PhysicalExpr>;
915        let result = replace_columns_with_literals(column_expr, &replacements)?;
916
917        // Should be replaced with the partition value
918        let literal = result
919            .as_any()
920            .downcast_ref::<Literal>()
921            .expect("Expected literal expression");
922        assert_eq!(*literal.value(), partition_value);
923
924        Ok(())
925    }
926
927    #[test]
928    fn test_replace_columns_with_literals_no_match() -> Result<()> {
929        let value = ScalarValue::Utf8(Some("test_value".to_string()));
930        let replacements = HashMap::from([("other_col", &value)]);
931
932        let column_expr =
933            Arc::new(Column::new("partition_col", 0)) as Arc<dyn PhysicalExpr>;
934        let result = replace_columns_with_literals(column_expr, &replacements)?;
935
936        assert!(result.as_any().downcast_ref::<Column>().is_some());
937        Ok(())
938    }
939
940    #[test]
941    fn test_replace_columns_with_literals_nested_expr() -> Result<()> {
942        let value_a = ScalarValue::Int64(Some(10));
943        let value_b = ScalarValue::Int64(Some(20));
944        let replacements = HashMap::from([("a", &value_a), ("b", &value_b)]);
945
946        let expr = Arc::new(expressions::BinaryExpr::new(
947            Arc::new(Column::new("a", 0)),
948            Operator::Plus,
949            Arc::new(Column::new("b", 1)),
950        )) as Arc<dyn PhysicalExpr>;
951
952        let result = replace_columns_with_literals(expr, &replacements)?;
953        assert_eq!(result.to_string(), "10 + 20");
954
955        Ok(())
956    }
957
958    #[test]
959    fn test_rewrite_no_change_needed() -> Result<()> {
960        let (physical_schema, logical_schema) = create_test_schema();
961
962        let factory = DefaultPhysicalExprAdapterFactory;
963        let adapter = factory
964            .create(Arc::new(logical_schema), Arc::new(physical_schema))
965            .unwrap();
966        let column_expr = Arc::new(Column::new("b", 1)) as Arc<dyn PhysicalExpr>;
967
968        let result = adapter.rewrite(Arc::clone(&column_expr))?;
969
970        // Should be the same expression (no transformation needed)
971        // We compare the underlying pointer through the trait object
972        assert!(std::ptr::eq(
973            column_expr.as_ref() as *const dyn PhysicalExpr,
974            result.as_ref() as *const dyn PhysicalExpr
975        ));
976
977        Ok(())
978    }
979
980    #[test]
981    fn test_non_nullable_missing_column_error() {
982        let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
983        let logical_schema = Schema::new(vec![
984            Field::new("a", DataType::Int32, false),
985            Field::new("b", DataType::Utf8, false), // Non-nullable missing column
986        ]);
987
988        let factory = DefaultPhysicalExprAdapterFactory;
989        let adapter = factory
990            .create(Arc::new(logical_schema), Arc::new(physical_schema))
991            .unwrap();
992        let column_expr = Arc::new(Column::new("b", 1));
993
994        let result = adapter.rewrite(column_expr);
995        assert!(result.is_err());
996        assert_contains!(
997            result.unwrap_err().to_string(),
998            "Non-nullable column 'b' is missing from the physical schema"
999        );
1000    }
1001
1002    /// Helper function to project expressions onto a RecordBatch
1003    fn batch_project(
1004        expr: Vec<Arc<dyn PhysicalExpr>>,
1005        batch: &RecordBatch,
1006        schema: SchemaRef,
1007    ) -> Result<RecordBatch> {
1008        let arrays = expr
1009            .iter()
1010            .map(|expr| {
1011                expr.evaluate(batch)
1012                    .and_then(|v| v.into_array(batch.num_rows()))
1013            })
1014            .collect::<Result<Vec<_>>>()?;
1015
1016        if arrays.is_empty() {
1017            let options =
1018                RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
1019            RecordBatch::try_new_with_options(Arc::clone(&schema), arrays, &options)
1020                .map_err(Into::into)
1021        } else {
1022            RecordBatch::try_new(Arc::clone(&schema), arrays).map_err(Into::into)
1023        }
1024    }
1025
1026    /// Example showing how we can use the `DefaultPhysicalExprAdapter` to adapt RecordBatches during a scan
1027    /// to apply projections, type conversions and handling of missing columns all at once.
1028    #[test]
1029    fn test_adapt_batches() {
1030        let physical_batch = record_batch!(
1031            ("a", Int32, vec![Some(1), None, Some(3)]),
1032            ("extra", Utf8, vec![Some("x"), Some("y"), None])
1033        )
1034        .unwrap();
1035
1036        let physical_schema = physical_batch.schema();
1037
1038        let logical_schema = Arc::new(Schema::new(vec![
1039            Field::new("a", DataType::Int64, true), // Different type
1040            Field::new("b", DataType::Utf8, true),  // Missing from physical
1041        ]));
1042
1043        let projection = vec![
1044            col("b", &logical_schema).unwrap(),
1045            col("a", &logical_schema).unwrap(),
1046        ];
1047
1048        let factory = DefaultPhysicalExprAdapterFactory;
1049        let adapter = factory
1050            .create(Arc::clone(&logical_schema), Arc::clone(&physical_schema))
1051            .unwrap();
1052
1053        let adapted_projection = projection
1054            .into_iter()
1055            .map(|expr| adapter.rewrite(expr).unwrap())
1056            .collect_vec();
1057
1058        let adapted_schema = Arc::new(Schema::new(
1059            adapted_projection
1060                .iter()
1061                .map(|expr| expr.return_field(&physical_schema).unwrap())
1062                .collect_vec(),
1063        ));
1064
1065        let res = batch_project(
1066            adapted_projection,
1067            &physical_batch,
1068            Arc::clone(&adapted_schema),
1069        )
1070        .unwrap();
1071
1072        assert_eq!(res.num_columns(), 2);
1073        assert_eq!(res.column(0).data_type(), &DataType::Utf8);
1074        assert_eq!(res.column(1).data_type(), &DataType::Int64);
1075        assert_eq!(
1076            res.column(0)
1077                .as_any()
1078                .downcast_ref::<StringArray>()
1079                .unwrap()
1080                .iter()
1081                .collect_vec(),
1082            vec![None, None, None]
1083        );
1084        assert_eq!(
1085            res.column(1)
1086                .as_any()
1087                .downcast_ref::<Int64Array>()
1088                .unwrap()
1089                .iter()
1090                .collect_vec(),
1091            vec![Some(1), None, Some(3)]
1092        );
1093    }
1094
    /// Test that struct columns are properly adapted including:
    /// - Type casting of subfields (Int32 -> Int64, Utf8 -> Utf8View)
    /// - Fields present in the logical schema but missing from the physical
    ///   data are filled with nulls
    #[test]
    fn test_adapt_struct_batches() {
        // Physical struct: {id: Int32, name: Utf8}
        let physical_struct_fields: Fields = vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
        ]
        .into();

        // Three rows of physical data; "name" contains a null to verify
        // nullability survives the cast.
        let struct_array = StructArray::new(
            physical_struct_fields.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])) as _,
                Arc::new(StringArray::from(vec![
                    Some("alice"),
                    None,
                    Some("charlie"),
                ])) as _,
            ],
            None,
        );

        let physical_schema = Arc::new(Schema::new(vec![Field::new(
            "data",
            DataType::Struct(physical_struct_fields),
            false,
        )]));

        let physical_batch = RecordBatch::try_new(
            Arc::clone(&physical_schema),
            vec![Arc::new(struct_array)],
        )
        .unwrap();

        // Logical struct: {id: Int64, name: Utf8View, extra: Boolean}
        // - id: cast from Int32 to Int64
        // - name: cast from Utf8 to Utf8View
        // - extra: missing from physical, should be filled with nulls
        let logical_struct_fields: Fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8View, true),
            Field::new("extra", DataType::Boolean, true), // New field, not in physical
        ]
        .into();

        let logical_schema = Arc::new(Schema::new(vec![Field::new(
            "data",
            DataType::Struct(logical_struct_fields),
            false,
        )]));

        let projection = vec![col("data", &logical_schema).unwrap()];

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory
            .create(Arc::clone(&logical_schema), Arc::clone(&physical_schema))
            .unwrap();

        // Rewrite the projection so it can evaluate against the physical batch.
        let adapted_projection = projection
            .into_iter()
            .map(|expr| adapter.rewrite(expr).unwrap())
            .collect_vec();

        // Output schema is derived from the rewritten expressions' return fields.
        let adapted_schema = Arc::new(Schema::new(
            adapted_projection
                .iter()
                .map(|expr| expr.return_field(&physical_schema).unwrap())
                .collect_vec(),
        ));

        let res = batch_project(
            adapted_projection,
            &physical_batch,
            Arc::clone(&adapted_schema),
        )
        .unwrap();

        assert_eq!(res.num_columns(), 1);

        let result_struct = res
            .column(0)
            .as_any()
            .downcast_ref::<StructArray>()
            .unwrap();

        // Verify id field is cast to Int64
        let id_col = result_struct.column_by_name("id").unwrap();
        assert_eq!(id_col.data_type(), &DataType::Int64);
        let id_values = id_col.as_any().downcast_ref::<Int64Array>().unwrap();
        assert_eq!(
            id_values.iter().collect_vec(),
            vec![Some(1), Some(2), Some(3)]
        );

        // Verify name field is cast to Utf8View
        let name_col = result_struct.column_by_name("name").unwrap();
        assert_eq!(name_col.data_type(), &DataType::Utf8View);
        let name_values = name_col.as_any().downcast_ref::<StringViewArray>().unwrap();
        assert_eq!(
            name_values.iter().collect_vec(),
            vec![Some("alice"), None, Some("charlie")]
        );

        // Verify extra field (missing from physical) is filled with nulls
        let extra_col = result_struct.column_by_name("extra").unwrap();
        assert_eq!(extra_col.data_type(), &DataType::Boolean);
        let extra_values = extra_col.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert_eq!(extra_values.iter().collect_vec(), vec![None, None, None]);
    }
1207
1208    #[test]
1209    fn test_try_rewrite_struct_field_access() {
1210        // Test the core logic of try_rewrite_struct_field_access
1211        let physical_schema = Schema::new(vec![Field::new(
1212            "struct_col",
1213            DataType::Struct(
1214                vec![Field::new("existing_field", DataType::Int32, true)].into(),
1215            ),
1216            true,
1217        )]);
1218
1219        let logical_schema = Schema::new(vec![Field::new(
1220            "struct_col",
1221            DataType::Struct(
1222                vec![
1223                    Field::new("existing_field", DataType::Int32, true),
1224                    Field::new("missing_field", DataType::Utf8, true),
1225                ]
1226                .into(),
1227            ),
1228            true,
1229        )]);
1230
1231        let rewriter = DefaultPhysicalExprAdapterRewriter {
1232            logical_file_schema: Arc::new(logical_schema),
1233            physical_file_schema: Arc::new(physical_schema),
1234        };
1235
1236        // Test that when a field exists in physical schema, it returns None
1237        let column = Arc::new(Column::new("struct_col", 0)) as Arc<dyn PhysicalExpr>;
1238        let result = rewriter.try_rewrite_struct_field_access(&column).unwrap();
1239        assert!(result.is_none());
1240
1241        // The actual test for the get_field expression would require creating a proper ScalarFunctionExpr
1242        // with ScalarUDF, which is complex to set up in a unit test. The integration tests in
1243        // datafusion/core/tests/parquet/schema_adapter.rs provide better coverage for this functionality.
1244    }
1245
1246    // ============================================================================
1247    // BatchAdapterFactory and BatchAdapter tests
1248    // ============================================================================
1249
1250    #[test]
1251    fn test_batch_adapter_factory_basic() {
1252        // Target schema
1253        let target_schema = Arc::new(Schema::new(vec![
1254            Field::new("a", DataType::Int64, false),
1255            Field::new("b", DataType::Utf8, true),
1256        ]));
1257
1258        // Source schema with different column order and type
1259        let source_schema = Arc::new(Schema::new(vec![
1260            Field::new("b", DataType::Utf8, true),
1261            Field::new("a", DataType::Int32, false), // Int32 -> Int64
1262        ]));
1263
1264        let factory = BatchAdapterFactory::new(Arc::clone(&target_schema));
1265        let adapter = factory.make_adapter(&source_schema).unwrap();
1266
1267        // Create source batch
1268        let source_batch = RecordBatch::try_new(
1269            Arc::clone(&source_schema),
1270            vec![
1271                Arc::new(StringArray::from(vec![Some("hello"), None, Some("world")])),
1272                Arc::new(Int32Array::from(vec![1, 2, 3])),
1273            ],
1274        )
1275        .unwrap();
1276
1277        let adapted = adapter.adapt_batch(&source_batch).unwrap();
1278
1279        // Verify schema matches target
1280        assert_eq!(adapted.num_columns(), 2);
1281        assert_eq!(adapted.schema().field(0).name(), "a");
1282        assert_eq!(adapted.schema().field(0).data_type(), &DataType::Int64);
1283        assert_eq!(adapted.schema().field(1).name(), "b");
1284        assert_eq!(adapted.schema().field(1).data_type(), &DataType::Utf8);
1285
1286        // Verify data
1287        let col_a = adapted
1288            .column(0)
1289            .as_any()
1290            .downcast_ref::<Int64Array>()
1291            .unwrap();
1292        assert_eq!(col_a.iter().collect_vec(), vec![Some(1), Some(2), Some(3)]);
1293
1294        let col_b = adapted
1295            .column(1)
1296            .as_any()
1297            .downcast_ref::<StringArray>()
1298            .unwrap();
1299        assert_eq!(
1300            col_b.iter().collect_vec(),
1301            vec![Some("hello"), None, Some("world")]
1302        );
1303    }
1304
1305    #[test]
1306    fn test_batch_adapter_factory_missing_column() {
1307        // Target schema with a column missing from source
1308        let target_schema = Arc::new(Schema::new(vec![
1309            Field::new("a", DataType::Int32, false),
1310            Field::new("b", DataType::Utf8, true), // exists in source
1311            Field::new("c", DataType::Float64, true), // missing from source
1312        ]));
1313
1314        let source_schema = Arc::new(Schema::new(vec![
1315            Field::new("a", DataType::Int32, false),
1316            Field::new("b", DataType::Utf8, true),
1317        ]));
1318
1319        let factory = BatchAdapterFactory::new(Arc::clone(&target_schema));
1320        let adapter = factory.make_adapter(&source_schema).unwrap();
1321
1322        let source_batch = RecordBatch::try_new(
1323            Arc::clone(&source_schema),
1324            vec![
1325                Arc::new(Int32Array::from(vec![1, 2])),
1326                Arc::new(StringArray::from(vec!["x", "y"])),
1327            ],
1328        )
1329        .unwrap();
1330
1331        let adapted = adapter.adapt_batch(&source_batch).unwrap();
1332
1333        assert_eq!(adapted.num_columns(), 3);
1334
1335        // Missing column should be filled with nulls
1336        let col_c = adapted.column(2);
1337        assert_eq!(col_c.data_type(), &DataType::Float64);
1338        assert_eq!(col_c.null_count(), 2); // All nulls
1339    }
1340
1341    #[test]
1342    fn test_batch_adapter_factory_with_struct() {
1343        // Target has struct with Int64 id
1344        let target_struct_fields: Fields = vec![
1345            Field::new("id", DataType::Int64, false),
1346            Field::new("name", DataType::Utf8, true),
1347        ]
1348        .into();
1349        let target_schema = Arc::new(Schema::new(vec![Field::new(
1350            "data",
1351            DataType::Struct(target_struct_fields),
1352            false,
1353        )]));
1354
1355        // Source has struct with Int32 id
1356        let source_struct_fields: Fields = vec![
1357            Field::new("id", DataType::Int32, false),
1358            Field::new("name", DataType::Utf8, true),
1359        ]
1360        .into();
1361        let source_schema = Arc::new(Schema::new(vec![Field::new(
1362            "data",
1363            DataType::Struct(source_struct_fields.clone()),
1364            false,
1365        )]));
1366
1367        let struct_array = StructArray::new(
1368            source_struct_fields,
1369            vec![
1370                Arc::new(Int32Array::from(vec![10, 20])) as _,
1371                Arc::new(StringArray::from(vec!["a", "b"])) as _,
1372            ],
1373            None,
1374        );
1375
1376        let source_batch = RecordBatch::try_new(
1377            Arc::clone(&source_schema),
1378            vec![Arc::new(struct_array)],
1379        )
1380        .unwrap();
1381
1382        let factory = BatchAdapterFactory::new(Arc::clone(&target_schema));
1383        let adapter = factory.make_adapter(&source_schema).unwrap();
1384        let adapted = adapter.adapt_batch(&source_batch).unwrap();
1385
1386        let result_struct = adapted
1387            .column(0)
1388            .as_any()
1389            .downcast_ref::<StructArray>()
1390            .unwrap();
1391
1392        // Verify id was cast to Int64
1393        let id_col = result_struct.column_by_name("id").unwrap();
1394        assert_eq!(id_col.data_type(), &DataType::Int64);
1395        let id_values = id_col.as_any().downcast_ref::<Int64Array>().unwrap();
1396        assert_eq!(id_values.iter().collect_vec(), vec![Some(10), Some(20)]);
1397    }
1398
1399    #[test]
1400    fn test_batch_adapter_factory_identity() {
1401        // When source and target schemas are identical, should pass through efficiently
1402        let schema = Arc::new(Schema::new(vec![
1403            Field::new("a", DataType::Int32, false),
1404            Field::new("b", DataType::Utf8, true),
1405        ]));
1406
1407        let factory = BatchAdapterFactory::new(Arc::clone(&schema));
1408        let adapter = factory.make_adapter(&schema).unwrap();
1409
1410        let batch = RecordBatch::try_new(
1411            Arc::clone(&schema),
1412            vec![
1413                Arc::new(Int32Array::from(vec![1, 2, 3])),
1414                Arc::new(StringArray::from(vec!["a", "b", "c"])),
1415            ],
1416        )
1417        .unwrap();
1418
1419        let adapted = adapter.adapt_batch(&batch).unwrap();
1420
1421        assert_eq!(adapted.num_columns(), 2);
1422        assert_eq!(adapted.schema().field(0).data_type(), &DataType::Int32);
1423        assert_eq!(adapted.schema().field(1).data_type(), &DataType::Utf8);
1424    }
1425
1426    #[test]
1427    fn test_batch_adapter_factory_reuse() {
1428        // Factory can create multiple adapters for different source schemas
1429        let target_schema = Arc::new(Schema::new(vec![
1430            Field::new("x", DataType::Int64, false),
1431            Field::new("y", DataType::Utf8, true),
1432        ]));
1433
1434        let factory = BatchAdapterFactory::new(Arc::clone(&target_schema));
1435
1436        // First source schema
1437        let source1 = Arc::new(Schema::new(vec![
1438            Field::new("x", DataType::Int32, false),
1439            Field::new("y", DataType::Utf8, true),
1440        ]));
1441        let adapter1 = factory.make_adapter(&source1).unwrap();
1442
1443        // Second source schema (different order)
1444        let source2 = Arc::new(Schema::new(vec![
1445            Field::new("y", DataType::Utf8, true),
1446            Field::new("x", DataType::Int64, false),
1447        ]));
1448        let adapter2 = factory.make_adapter(&source2).unwrap();
1449
1450        // Both should work correctly
1451        assert!(format!("{adapter1:?}").contains("BatchAdapter"));
1452        assert!(format!("{adapter2:?}").contains("BatchAdapter"));
1453    }
1454
1455    #[test]
1456    fn test_rewrite_column_index_and_type_mismatch() {
1457        let physical_schema = Schema::new(vec![
1458            Field::new("b", DataType::Utf8, true),
1459            Field::new("a", DataType::Int32, false), // Index 1
1460        ]);
1461
1462        let logical_schema = Schema::new(vec![
1463            Field::new("a", DataType::Int64, false), // Index 0, Different Type
1464            Field::new("b", DataType::Utf8, true),
1465        ]);
1466
1467        let factory = DefaultPhysicalExprAdapterFactory;
1468        let adapter = factory
1469            .create(Arc::new(logical_schema), Arc::new(physical_schema))
1470            .unwrap();
1471
1472        // Logical column "a" is at index 0
1473        let column_expr = Arc::new(Column::new("a", 0));
1474
1475        let result = adapter.rewrite(column_expr).unwrap();
1476
1477        // Should be a CastColumnExpr
1478        let cast_expr = result
1479            .as_any()
1480            .downcast_ref::<CastColumnExpr>()
1481            .expect("Expected CastColumnExpr");
1482
1483        // Verify the inner column points to the correct physical index (1)
1484        let inner_col = cast_expr
1485            .expr()
1486            .as_any()
1487            .downcast_ref::<Column>()
1488            .expect("Expected inner Column");
1489        assert_eq!(inner_col.name(), "a");
1490        assert_eq!(inner_col.index(), 1); // Physical index is 1
1491
1492        // Verify cast types
1493        assert_eq!(
1494            cast_expr.data_type(&Schema::empty()).unwrap(),
1495            DataType::Int64
1496        );
1497    }
1498
1499    #[test]
1500    fn test_create_cast_column_expr_uses_name_lookup_not_column_index() {
1501        // Physical schema has column `a` at index 1; index 0 is an incompatible type.
1502        let physical_schema = Arc::new(Schema::new(vec![
1503            Field::new("b", DataType::Binary, true),
1504            Field::new("a", DataType::Int32, false),
1505        ]));
1506
1507        let logical_schema = Arc::new(Schema::new(vec![
1508            Field::new("a", DataType::Int64, false),
1509            Field::new("b", DataType::Binary, true),
1510        ]));
1511
1512        let rewriter = DefaultPhysicalExprAdapterRewriter {
1513            logical_file_schema: Arc::clone(&logical_schema),
1514            physical_file_schema: Arc::clone(&physical_schema),
1515        };
1516
1517        // Deliberately provide the wrong index for column `a`.
1518        // Regression: this must still resolve against physical field `a` by name.
1519        let transformed = rewriter
1520            .create_cast_column_expr(
1521                Column::new("a", 0),
1522                logical_schema.field_with_name("a").unwrap(),
1523            )
1524            .unwrap();
1525
1526        let cast_expr = transformed
1527            .data
1528            .as_any()
1529            .downcast_ref::<CastColumnExpr>()
1530            .expect("Expected CastColumnExpr");
1531
1532        assert_eq!(cast_expr.input_field().name(), "a");
1533        assert_eq!(cast_expr.input_field().data_type(), &DataType::Int32);
1534        assert_eq!(cast_expr.target_field().data_type(), &DataType::Int64);
1535    }
1536}