datafusion_physical_expr/expressions/
cast_column.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Physical expression for struct-aware casting of columns.
19
20use crate::physical_expr::PhysicalExpr;
21use arrow::{
22    compute::CastOptions,
23    datatypes::{DataType, FieldRef, Schema},
24    record_batch::RecordBatch,
25};
26use datafusion_common::{
27    format::DEFAULT_CAST_OPTIONS, nested_struct::cast_column, Result, ScalarValue,
28};
29use datafusion_expr_common::columnar_value::ColumnarValue;
30use std::{
31    any::Any,
32    fmt::{self, Display},
33    hash::Hash,
34    sync::Arc,
35};
36/// A physical expression that applies [`cast_column`] to its input.
37///
38/// [`CastColumnExpr`] extends the regular [`CastExpr`](super::CastExpr) by
39/// retaining schema metadata for both the input and output fields. This allows
40/// the evaluator to perform struct-aware casts that honour nested field
41/// ordering, preserve nullability, and fill missing fields with null values.
42///
43/// This expression is intended for schema rewriting scenarios where the
44/// planner already resolved the input column but needs to adapt its physical
45/// representation to a new [`arrow::datatypes::Field`]. It mirrors the behaviour of the
46/// [`datafusion_common::nested_struct::cast_column`] helper while integrating
47/// with the `PhysicalExpr` trait so it can participate in the execution plan
48/// like any other column expression.
49#[derive(Debug, Clone, Eq)]
50pub struct CastColumnExpr {
51    /// The physical expression producing the value to cast.
52    expr: Arc<dyn PhysicalExpr>,
53    /// The logical field of the input column.
54    input_field: FieldRef,
55    /// The field metadata describing the desired output column.
56    target_field: FieldRef,
57    /// Options forwarded to [`cast_column`].
58    cast_options: CastOptions<'static>,
59}
60
61// Manually derive `PartialEq`/`Hash` as `Arc<dyn PhysicalExpr>` does not
62// implement these traits by default for the trait object.
63impl PartialEq for CastColumnExpr {
64    fn eq(&self, other: &Self) -> bool {
65        self.expr.eq(&other.expr)
66            && self.input_field.eq(&other.input_field)
67            && self.target_field.eq(&other.target_field)
68            && self.cast_options.eq(&other.cast_options)
69    }
70}
71
72impl Hash for CastColumnExpr {
73    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
74        self.expr.hash(state);
75        self.input_field.hash(state);
76        self.target_field.hash(state);
77        self.cast_options.hash(state);
78    }
79}
80
81impl CastColumnExpr {
82    /// Create a new [`CastColumnExpr`].
83    pub fn new(
84        expr: Arc<dyn PhysicalExpr>,
85        input_field: FieldRef,
86        target_field: FieldRef,
87        cast_options: Option<CastOptions<'static>>,
88    ) -> Self {
89        Self {
90            expr,
91            input_field,
92            target_field,
93            cast_options: cast_options.unwrap_or(DEFAULT_CAST_OPTIONS),
94        }
95    }
96
97    /// The expression that produces the value to be cast.
98    pub fn expr(&self) -> &Arc<dyn PhysicalExpr> {
99        &self.expr
100    }
101
102    /// Field metadata describing the resolved input column.
103    pub fn input_field(&self) -> &FieldRef {
104        &self.input_field
105    }
106
107    /// Field metadata describing the output column after casting.
108    pub fn target_field(&self) -> &FieldRef {
109        &self.target_field
110    }
111}
112
113impl Display for CastColumnExpr {
114    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
115        write!(
116            f,
117            "CAST_COLUMN({} AS {:?})",
118            self.expr,
119            self.target_field.data_type()
120        )
121    }
122}
123
124impl PhysicalExpr for CastColumnExpr {
125    fn as_any(&self) -> &dyn Any {
126        self
127    }
128
129    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
130        Ok(self.target_field.data_type().clone())
131    }
132
133    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
134        Ok(self.target_field.is_nullable())
135    }
136
137    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
138        let value = self.expr.evaluate(batch)?;
139        match value {
140            ColumnarValue::Array(array) => {
141                let casted =
142                    cast_column(&array, self.target_field.as_ref(), &self.cast_options)?;
143                Ok(ColumnarValue::Array(casted))
144            }
145            ColumnarValue::Scalar(scalar) => {
146                let as_array = scalar.to_array_of_size(1)?;
147                let casted = cast_column(
148                    &as_array,
149                    self.target_field.as_ref(),
150                    &self.cast_options,
151                )?;
152                let result = ScalarValue::try_from_array(casted.as_ref(), 0)?;
153                Ok(ColumnarValue::Scalar(result))
154            }
155        }
156    }
157
158    fn return_field(&self, _input_schema: &Schema) -> Result<FieldRef> {
159        Ok(Arc::clone(&self.target_field))
160    }
161
162    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
163        vec![&self.expr]
164    }
165
166    fn with_new_children(
167        self: Arc<Self>,
168        mut children: Vec<Arc<dyn PhysicalExpr>>,
169    ) -> Result<Arc<dyn PhysicalExpr>> {
170        assert_eq!(children.len(), 1);
171        let child = children.pop().expect("CastColumnExpr child");
172        Ok(Arc::new(Self::new(
173            child,
174            Arc::clone(&self.input_field),
175            Arc::clone(&self.target_field),
176            Some(self.cast_options.clone()),
177        )))
178    }
179
180    fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
181        Display::fmt(self, f)
182    }
183}
184
#[cfg(test)]
mod tests {
    use super::*;

    use crate::expressions::{Column, Literal};
    use arrow::{
        array::{Array, ArrayRef, BooleanArray, Int32Array, StringArray, StructArray},
        datatypes::{DataType, Field, Fields, SchemaRef},
    };
    use datafusion_common::{
        cast::{as_int64_array, as_string_array, as_struct_array, as_uint8_array},
        Result as DFResult, ScalarValue,
    };

    /// Wraps a single field into a one-column schema.
    fn make_schema(field: &Field) -> SchemaRef {
        let columns = vec![field.clone()];
        Arc::new(Schema::new(columns))
    }

    /// Builds a struct array from child fields and arrays, with no
    /// struct-level null buffer.
    fn make_struct_array(fields: Fields, arrays: Vec<ArrayRef>) -> StructArray {
        StructArray::new(fields, arrays, None)
    }

    /// A primitive Int32 column is widened to Int64 and keeps its nulls.
    #[test]
    fn cast_primitive_array() -> DFResult<()> {
        let source = Field::new("a", DataType::Int32, true);
        let target = Field::new("a", DataType::Int64, true);
        let schema = make_schema(&source);

        let data: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![data])?;

        let column = Arc::new(Column::new_with_schema("a", schema.as_ref())?);
        let cast_expr =
            CastColumnExpr::new(column, Arc::new(source), Arc::new(target), None);

        let ColumnarValue::Array(output) = cast_expr.evaluate(&batch)? else {
            panic!("expected array");
        };
        let ints = as_int64_array(output.as_ref())?;
        assert_eq!(ints.value(0), 1);
        assert!(ints.is_null(1));
        assert_eq!(ints.value(2), 3);
        Ok(())
    }

    /// A child present only in the target struct (`c`) comes back all-null,
    /// while a shared child (`a`) is cast.
    #[test]
    fn cast_struct_array_missing_child() -> DFResult<()> {
        // Source: s { a: Int32, b: Utf8 }
        let source_children: Fields = vec![
            Arc::new(Field::new("a", DataType::Int32, true)),
            Arc::new(Field::new("b", DataType::Utf8, true)),
        ]
        .into();
        let source = Field::new("s", DataType::Struct(source_children.clone()), true);

        // Target: s { a: Int64, c: Utf8 }
        let target_children: Fields = vec![
            Arc::new(Field::new("a", DataType::Int64, true)),
            Arc::new(Field::new("c", DataType::Utf8, true)),
        ]
        .into();
        let target = Field::new("s", DataType::Struct(target_children), true);

        let schema = make_schema(&source);
        let input = make_struct_array(
            source_children,
            vec![
                Arc::new(Int32Array::from(vec![Some(1), None])) as ArrayRef,
                Arc::new(StringArray::from(vec![Some("alpha"), Some("beta")]))
                    as ArrayRef,
            ],
        );
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(input) as ArrayRef],
        )?;

        let column = Arc::new(Column::new_with_schema("s", schema.as_ref())?);
        let cast_expr =
            CastColumnExpr::new(column, Arc::new(source), Arc::new(target), None);

        let ColumnarValue::Array(output) = cast_expr.evaluate(&batch)? else {
            panic!("expected array");
        };
        let output_struct = as_struct_array(output.as_ref())?;

        let a_values =
            as_int64_array(output_struct.column_by_name("a").unwrap().as_ref())?;
        assert_eq!(a_values.value(0), 1);
        assert!(a_values.is_null(1));

        let c_values =
            as_string_array(output_struct.column_by_name("c").unwrap().as_ref())?;
        assert!(c_values.is_null(0));
        assert!(c_values.is_null(1));
        Ok(())
    }

    /// The cast recurses into nested structs: `inner.x` is widened and the
    /// new `inner.y` child is all-null.
    #[test]
    fn cast_nested_struct_array() -> DFResult<()> {
        // Source: root { inner { x: Int32 } }
        let source_leaves: Fields =
            vec![Arc::new(Field::new("x", DataType::Int32, true))].into();
        let source_inner =
            Field::new("inner", DataType::Struct(source_leaves.clone()), true);
        let source_root_children: Fields = vec![Arc::new(source_inner)].into();
        let source_root = Field::new(
            "root",
            DataType::Struct(source_root_children.clone()),
            true,
        );

        // Target: root { inner { x: Int64, y: Boolean } }
        let target_leaves: Fields = vec![
            Arc::new(Field::new("x", DataType::Int64, true)),
            Arc::new(Field::new("y", DataType::Boolean, true)),
        ]
        .into();
        let target_inner = Field::new("inner", DataType::Struct(target_leaves), true);
        let target_root_children: Fields = vec![Arc::new(target_inner)].into();
        let target_root =
            Field::new("root", DataType::Struct(target_root_children), true);

        let schema = make_schema(&source_root);

        let inner_values = make_struct_array(
            source_leaves,
            vec![Arc::new(Int32Array::from(vec![Some(7), None])) as ArrayRef],
        );
        let root_values = make_struct_array(
            source_root_children,
            vec![Arc::new(inner_values) as ArrayRef],
        );
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(root_values) as ArrayRef],
        )?;

        let column = Arc::new(Column::new_with_schema("root", schema.as_ref())?);
        let cast_expr = CastColumnExpr::new(
            column,
            Arc::new(source_root),
            Arc::new(target_root),
            None,
        );

        let ColumnarValue::Array(output) = cast_expr.evaluate(&batch)? else {
            panic!("expected array");
        };
        let root_struct = as_struct_array(output.as_ref())?;
        let inner_struct =
            as_struct_array(root_struct.column_by_name("inner").unwrap().as_ref())?;

        let x_values =
            as_int64_array(inner_struct.column_by_name("x").unwrap().as_ref())?;
        assert_eq!(x_values.value(0), 7);
        assert!(x_values.is_null(1));

        let y_values = inner_struct
            .column_by_name("y")
            .unwrap()
            .as_any()
            .downcast_ref::<BooleanArray>()
            .expect("boolean array");
        assert!(y_values.is_null(0));
        assert!(y_values.is_null(1));
        Ok(())
    }

    /// A struct scalar input is cast and returned as a struct scalar.
    #[test]
    fn cast_struct_scalar() -> DFResult<()> {
        // Source: s { a: Int32 }, target: s { a: UInt8 }
        let source_children: Fields =
            vec![Arc::new(Field::new("a", DataType::Int32, true))].into();
        let source = Field::new("s", DataType::Struct(source_children.clone()), true);

        let target_children: Fields =
            vec![Arc::new(Field::new("a", DataType::UInt8, true))].into();
        let target = Field::new("s", DataType::Struct(target_children), true);

        let schema = make_schema(&source);
        let literal_value = make_struct_array(
            source_children,
            vec![Arc::new(Int32Array::from(vec![Some(9)])) as ArrayRef],
        );
        let literal =
            Arc::new(Literal::new(ScalarValue::Struct(Arc::new(literal_value))));
        let cast_expr =
            CastColumnExpr::new(literal, Arc::new(source), Arc::new(target), None);

        // The literal does not read from the batch, so an empty one suffices.
        let batch = RecordBatch::new_empty(Arc::clone(&schema));
        let ColumnarValue::Scalar(ScalarValue::Struct(output)) =
            cast_expr.evaluate(&batch)?
        else {
            panic!("expected struct scalar");
        };
        let a_values = as_uint8_array(output.column_by_name("a").unwrap().as_ref())?;
        assert_eq!(a_values.value(0), 9);
        Ok(())
    }
}