Skip to main content

datafusion_physical_expr/expressions/
cast.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::fmt;
19use std::hash::Hash;
20use std::sync::Arc;
21
22use crate::physical_expr::PhysicalExpr;
23
24use arrow::compute::{CastOptions, can_cast_types};
25use arrow::datatypes::{DataType, DataType::*, FieldRef, Schema};
26use arrow::record_batch::RecordBatch;
27use datafusion_common::datatype::DataTypeExt;
28use datafusion_common::format::DEFAULT_FORMAT_OPTIONS;
29use datafusion_common::nested_struct::{
30    requires_nested_struct_cast, validate_data_type_compatibility,
31};
32use datafusion_common::{Result, not_impl_err};
33use datafusion_expr_common::columnar_value::ColumnarValue;
34use datafusion_expr_common::interval_arithmetic::Interval;
35use datafusion_expr_common::sort_properties::ExprProperties;
36
37const DEFAULT_CAST_OPTIONS: CastOptions<'static> = CastOptions {
38    safe: false,
39    format_options: DEFAULT_FORMAT_OPTIONS,
40};
41
42const DEFAULT_SAFE_CAST_OPTIONS: CastOptions<'static> = CastOptions {
43    safe: true,
44    format_options: DEFAULT_FORMAT_OPTIONS,
45};
46
47/// Check if name-based struct casting is allowed by validating field compatibility.
48///
49/// This function applies the same validation rules as execution time to ensure
50/// planning-time validation matches runtime validation, enabling fail-fast behavior
51/// instead of deferring errors to execution. Handles structs at any nesting level
52/// (e.g., `List<Struct>`, `Dictionary<_, Struct>`).
53fn can_cast_named_struct_types(source: &DataType, target: &DataType) -> bool {
54    validate_data_type_compatibility("", source, target).is_ok()
55}
56
/// CAST expression: casts the value produced by a child expression to the data
/// type of `target_field` and returns a runtime error on invalid cast.
///
/// `PartialEq` and `Hash` are implemented by hand for this type rather than
/// derived.
#[derive(Debug, Clone, Eq)]
pub struct CastExpr {
    /// The expression to cast
    pub expr: Arc<dyn PhysicalExpr>,
    /// Field metadata describing the desired output after casting; its data
    /// type is the cast's target type
    target_field: FieldRef,
    /// Cast options passed to the underlying cast at evaluation time
    cast_options: CastOptions<'static>,
}
67
68// Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808
69impl PartialEq for CastExpr {
70    fn eq(&self, other: &Self) -> bool {
71        self.expr.eq(&other.expr)
72            && self.target_field.eq(&other.target_field)
73            && self.cast_options.eq(&other.cast_options)
74    }
75}
76
77impl Hash for CastExpr {
78    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
79        self.expr.hash(state);
80        self.target_field.hash(state);
81        self.cast_options.hash(state);
82    }
83}
84
85impl CastExpr {
86    /// Create a new `CastExpr` using only a `DataType`.
87    ///
88    /// This constructor is provided for compatibility with existing call sites
89    /// that only know the target type.  It synthesizes a ``Field`` with the
90    /// given type (**nullable by default**) and no name metadata.  Callers that
91    /// already have a `FieldRef` (for example, coming from schema inference or a
92    /// resolved column) should prefer [`CastExpr::new_with_target_field`], which
93    /// preserves the field's name, nullability, and other metadata.  In other
94    /// words:
95    ///
96    /// * use `new()` when only a `DataType` is available and you want the legacy
97    ///   semantics of a type-only cast
98    /// * use `new_with_target_field()` when you need explicit field
99    ///   metadata/name/nullability preserved
100    pub fn new(
101        expr: Arc<dyn PhysicalExpr>,
102        cast_type: DataType,
103        cast_options: Option<CastOptions<'static>>,
104    ) -> Self {
105        Self::new_with_target_field(
106            expr,
107            cast_type.into_nullable_field_ref(),
108            cast_options,
109        )
110    }
111
112    /// Create a new `CastExpr` with an explicit target `FieldRef`.
113    ///
114    /// The provided `target_field` is used verbatim for the expression's
115    /// return schema, so the field's name, nullability, and other metadata are
116    /// preserved.  This is the preferred constructor when the caller already
117    /// has field information (for example, during logical-to-physical planning).
118    ///
119    /// See [`CastExpr::new`] for the compatibility constructor that only accepts
120    /// a `DataType`.
121    pub fn new_with_target_field(
122        expr: Arc<dyn PhysicalExpr>,
123        target_field: FieldRef,
124        cast_options: Option<CastOptions<'static>>,
125    ) -> Self {
126        Self {
127            expr,
128            target_field,
129            cast_options: cast_options.unwrap_or(DEFAULT_CAST_OPTIONS),
130        }
131    }
132
133    /// The expression to cast
134    pub fn expr(&self) -> &Arc<dyn PhysicalExpr> {
135        &self.expr
136    }
137
138    /// The data type to cast to
139    pub fn cast_type(&self) -> &DataType {
140        self.target_field.data_type()
141    }
142
143    /// Field metadata describing the output column after casting.
144    pub fn target_field(&self) -> &FieldRef {
145        &self.target_field
146    }
147
148    /// The cast options
149    pub fn cast_options(&self) -> &CastOptions<'static> {
150        &self.cast_options
151    }
152
153    fn resolved_target_field(&self, input_schema: &Schema) -> Result<FieldRef> {
154        if is_default_target_field(&self.target_field) {
155            self.expr.return_field(input_schema).map(|field| {
156                Arc::new(
157                    field
158                        .as_ref()
159                        .clone()
160                        .with_data_type(self.cast_type().clone()),
161                )
162            })
163        } else {
164            Ok(Arc::clone(&self.target_field))
165        }
166    }
167
168    /// Check if casting from the specified source type to the target type is a
169    /// widening cast (e.g. from `Int8` to `Int16`).
170    pub fn check_bigger_cast(cast_type: &DataType, src: &DataType) -> bool {
171        if cast_type.eq(src) {
172            return true;
173        }
174        matches!(
175            (src, cast_type),
176            (Int8, Int16 | Int32 | Int64)
177                | (Int16, Int32 | Int64)
178                | (Int32, Int64)
179                | (UInt8, UInt16 | UInt32 | UInt64)
180                | (UInt16, UInt32 | UInt64)
181                | (UInt32, UInt64)
182                | (
183                    Int8 | Int16 | Int32 | UInt8 | UInt16 | UInt32,
184                    Float32 | Float64
185                )
186                | (Int64 | UInt64, Float64)
187                | (Utf8, LargeUtf8)
188        )
189    }
190
191    /// Check if the cast is a widening cast (e.g. from `Int8` to `Int16`).
192    pub fn is_bigger_cast(&self, src: &DataType) -> bool {
193        Self::check_bigger_cast(self.cast_type(), src)
194    }
195}
196
197fn is_default_target_field(target_field: &FieldRef) -> bool {
198    target_field.name().is_empty()
199        && target_field.is_nullable()
200        && target_field.metadata().is_empty()
201}
202
203pub(crate) fn is_order_preserving_cast_family(
204    source_type: &DataType,
205    target_type: &DataType,
206) -> bool {
207    (source_type.is_numeric() || *source_type == Boolean) && target_type.is_numeric()
208        || source_type.is_temporal() && target_type.is_temporal()
209        || source_type.eq(target_type)
210}
211
212pub(crate) fn cast_expr_properties(
213    child: &ExprProperties,
214    target_type: &DataType,
215) -> Result<ExprProperties> {
216    let unbounded = Interval::make_unbounded(target_type)?;
217    if is_order_preserving_cast_family(&child.range.data_type(), target_type) {
218        Ok(child.clone().with_range(unbounded))
219    } else {
220        Ok(ExprProperties::new_unknown().with_range(unbounded))
221    }
222}
223
224impl fmt::Display for CastExpr {
225    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
226        write!(f, "CAST({} AS {})", self.expr, self.cast_type())
227    }
228}
229
230impl PhysicalExpr for CastExpr {
231    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
232        Ok(self.cast_type().clone())
233    }
234
235    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
236        // A cast is nullable if **either** the child is nullable or the
237        // target field allows nulls.  This conservative rule prevents
238        // optimizers from assuming a non-null result when a null input could
239        // still propagate.  `return_field()` continues to expose the exact
240        // target metadata separately.
241        let child_nullable = self.expr.nullable(input_schema)?;
242        let target_nullable = self.resolved_target_field(input_schema)?.is_nullable();
243        Ok(child_nullable || target_nullable)
244    }
245
246    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
247        let value = self.expr.evaluate(batch)?;
248        value.cast_to(self.cast_type(), Some(&self.cast_options))
249    }
250
251    fn return_field(&self, input_schema: &Schema) -> Result<FieldRef> {
252        self.resolved_target_field(input_schema)
253    }
254
255    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
256        vec![&self.expr]
257    }
258
259    fn with_new_children(
260        self: Arc<Self>,
261        children: Vec<Arc<dyn PhysicalExpr>>,
262    ) -> Result<Arc<dyn PhysicalExpr>> {
263        Ok(Arc::new(CastExpr::new_with_target_field(
264            Arc::clone(&children[0]),
265            Arc::clone(&self.target_field),
266            Some(self.cast_options.clone()),
267        )))
268    }
269
270    fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
271        // Cast current node's interval to the right type:
272        children[0].cast_to(self.cast_type(), &self.cast_options)
273    }
274
275    fn propagate_constraints(
276        &self,
277        interval: &Interval,
278        children: &[&Interval],
279    ) -> Result<Option<Vec<Interval>>> {
280        let child_interval = children[0];
281        // Get child's datatype:
282        let cast_type = child_interval.data_type();
283        Ok(Some(vec![
284            interval.cast_to(&cast_type, &DEFAULT_SAFE_CAST_OPTIONS)?,
285        ]))
286    }
287
288    /// A [`CastExpr`] preserves the ordering of its child if the cast is done
289    /// under the same datatype family.
290    fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
291        cast_expr_properties(&children[0], self.cast_type())
292    }
293
294    fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
295        write!(f, "CAST(")?;
296        self.expr.fmt_sql(f)?;
297        write!(f, " AS {:?}", self.cast_type())?;
298
299        write!(f, ")")
300    }
301}
302
303/// Return a PhysicalExpression representing `expr` casted to
304/// `cast_type`, if any casting is needed.
305///
306/// Note that such casts may lose type information
307pub fn cast_with_options(
308    expr: Arc<dyn PhysicalExpr>,
309    input_schema: &Schema,
310    cast_type: DataType,
311    cast_options: Option<CastOptions<'static>>,
312) -> Result<Arc<dyn PhysicalExpr>> {
313    cast_with_target_field(
314        expr,
315        input_schema,
316        cast_type.into_nullable_field_ref(),
317        cast_options,
318    )
319}
320
321/// Return a PhysicalExpression representing `expr` casted to `target_field`,
322/// preserving any explicit field semantics such as name, nullability, and
323/// metadata.
324///
325/// If the input expression already has the same data type, this helper still
326/// preserves an explicit `target_field` by constructing a field-aware
327/// [`CastExpr`]. Only the default synthesized field created by the legacy
328/// type-only API is elided back to the original child expression.
329pub fn cast_with_target_field(
330    expr: Arc<dyn PhysicalExpr>,
331    input_schema: &Schema,
332    target_field: FieldRef,
333    cast_options: Option<CastOptions<'static>>,
334) -> Result<Arc<dyn PhysicalExpr>> {
335    let expr_type = expr.data_type(input_schema)?;
336    let cast_type = target_field.data_type();
337    if expr_type == *cast_type && is_default_target_field(&target_field) {
338        return Ok(Arc::clone(&expr));
339    }
340
341    let can_build_cast = if requires_nested_struct_cast(&expr_type, cast_type) {
342        // Allow casts involving structs (including nested inside Lists, Dictionaries,
343        // etc.) that pass name-based compatibility validation. This validation is
344        // applied at planning time (now) to fail fast, rather than deferring errors
345        // to execution time. The name-based casting logic will be executed at runtime
346        // via ColumnarValue::cast_to.
347        can_cast_named_struct_types(&expr_type, cast_type)
348    } else {
349        can_cast_types(&expr_type, cast_type)
350    };
351
352    if !can_build_cast {
353        return not_impl_err!("Unsupported CAST from {expr_type} to {cast_type}");
354    }
355
356    Ok(Arc::new(CastExpr::new_with_target_field(
357        expr,
358        target_field,
359        cast_options,
360    )))
361}
362
363/// Return a PhysicalExpression representing `expr` casted to
364/// `cast_type`, if any casting is needed.
365///
366/// Note that such casts may lose type information
367pub fn cast(
368    expr: Arc<dyn PhysicalExpr>,
369    input_schema: &Schema,
370    cast_type: DataType,
371) -> Result<Arc<dyn PhysicalExpr>> {
372    cast_with_options(expr, input_schema, cast_type, None)
373}
374
375#[cfg(test)]
376mod tests {
377    use super::*;
378
379    use crate::expressions::column::col;
380
381    use arrow::{
382        array::{
383            Array, ArrayRef, Decimal128Array, Float32Array, Float64Array, Int8Array,
384            Int16Array, Int32Array, Int64Array, StringArray, StructArray,
385            Time64NanosecondArray, TimestampNanosecondArray, UInt32Array,
386        },
387        datatypes::*,
388    };
389    use datafusion_common::ScalarValue;
390    use datafusion_common::cast::{
391        as_boolean_array, as_int64_array, as_string_array, as_struct_array,
392        as_uint8_array,
393    };
394    use datafusion_physical_expr_common::physical_expr::fmt_sql;
395    use insta::assert_snapshot;
396    use std::collections::HashMap;
397
    /// Test helper: builds a `StructArray` from `fields` and their child
    /// `arrays`, passing `None` as the third (null buffer) argument.
    fn make_struct_array(fields: Fields, arrays: Vec<ArrayRef>) -> StructArray {
        StructArray::new(fields, arrays, None)
    }
401
402    fn cast_struct_array(
403        column: &str,
404        input_field: Field,
405        target_field: Field,
406        input_array: StructArray,
407    ) -> Result<StructArray> {
408        let schema = Arc::new(Schema::new(vec![input_field]));
409        let batch = RecordBatch::try_new(
410            Arc::clone(&schema),
411            vec![Arc::new(input_array) as ArrayRef],
412        )?;
413        let expr = CastExpr::new_with_target_field(
414            col(column, schema.as_ref())?,
415            Arc::new(target_field),
416            None,
417        );
418
419        let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
420        Ok(as_struct_array(result.as_ref())?.clone())
421    }
422
    // Runs an end-to-end test of a physical type cast starting from an
    // already-built array (used by the decimal tests):
    // 1. construct a record batch with a nullable column "a" of type
    //    `$A_TYPE` backed by `$DECIMAL_ARRAY`
    // 2. construct a physical expression of CAST(a AS `$TYPE`) using
    //    `$CAST_OPTIONS`
    // 3. evaluate the expression
    // 4. verify that the resulting expression is of type `$TYPE`
    // 5. verify that the resulting values downcast to `$TYPEARRAY` and match
    //    `$VEC` (`None` entries must be null)
    // The enclosing function must return `Result`, since the body uses `?`.
    macro_rules! generic_decimal_to_other_test_cast {
        ($DECIMAL_ARRAY:ident, $A_TYPE:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr,$CAST_OPTIONS:expr) => {{
            let schema = Schema::new(vec![Field::new("a", $A_TYPE, true)]);
            let batch = RecordBatch::try_new(
                Arc::new(schema.clone()),
                vec![Arc::new($DECIMAL_ARRAY)],
            )?;
            // verify that we can construct the expression
            let expression =
                cast_with_options(col("a", &schema)?, &schema, $TYPE, $CAST_OPTIONS)?;

            // verify that its display is correct
            assert_eq!(format!("CAST(a@0 AS {})", $TYPE), format!("{}", expression));

            // verify that the expression's type is correct
            assert_eq!(expression.data_type(&schema)?, $TYPE);

            // compute
            let result = expression
                .evaluate(&batch)?
                .into_array(batch.num_rows())
                .expect("Failed to convert to array");

            // verify that the array's data_type is correct
            assert_eq!(*result.data_type(), $TYPE);

            // verify that the data itself is downcastable
            let result = result
                .as_any()
                .downcast_ref::<$TYPEARRAY>()
                .expect("failed to downcast");

            // verify that the result itself is correct
            for (i, x) in $VEC.iter().enumerate() {
                match x {
                    Some(x) => assert_eq!(result.value(i), *x),
                    None => assert!(result.is_null(i)),
                }
            }
        }};
    }
470
    // Runs an end-to-end test of a physical type cast starting from a vector
    // of input values:
    // 1. construct a record batch with a nullable column "a" of type
    //    `$A_TYPE`, built as `$A_ARRAY::from($A_VEC)`
    // 2. construct a physical expression of CAST(a AS `$TYPE`) using
    //    `$CAST_OPTIONS`
    // 3. evaluate the expression
    // 4. verify that the resulting expression is of type `$TYPE` and that the
    //    output has as many rows as `$A_VEC`
    // 5. verify that the resulting values downcast to `$TYPEARRAY` and match
    //    `$VEC` (`None` entries must be null)
    // The enclosing function must return `Result`, since the body uses `?`.
    macro_rules! generic_test_cast {
        ($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr, $CAST_OPTIONS:expr) => {{
            let schema = Schema::new(vec![Field::new("a", $A_TYPE, true)]);
            let a_vec_len = $A_VEC.len();
            let a = $A_ARRAY::from($A_VEC);
            let batch =
                RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

            // verify that we can construct the expression
            let expression =
                cast_with_options(col("a", &schema)?, &schema, $TYPE, $CAST_OPTIONS)?;

            // verify that its display is correct
            assert_eq!(format!("CAST(a@0 AS {})", $TYPE), format!("{}", expression));

            // verify that the expression's type is correct
            assert_eq!(expression.data_type(&schema)?, $TYPE);

            // compute
            let result = expression
                .evaluate(&batch)?
                .into_array(batch.num_rows())
                .expect("Failed to convert to array");

            // verify that the array's data_type is correct
            assert_eq!(*result.data_type(), $TYPE);

            // verify that the len is correct
            assert_eq!(result.len(), a_vec_len);

            // verify that the data itself is downcastable
            let result = result
                .as_any()
                .downcast_ref::<$TYPEARRAY>()
                .expect("failed to downcast");

            // verify that the result itself is correct
            for (i, x) in $VEC.iter().enumerate() {
                match x {
                    Some(x) => assert_eq!(result.value(i), *x),
                    None => assert!(result.is_null(i)),
                }
            }
        }};
    }
522
    /// Casting `Decimal128(10, 3)` to a larger scale (`Decimal128(20, 6)`)
    /// and to a smaller scale (`Decimal128(10, 2)`) yields the expected raw
    /// values and keeps nulls.
    #[test]
    fn test_cast_decimal_to_decimal() -> Result<()> {
        let array = vec![
            Some(1234),
            Some(2222),
            Some(3),
            Some(4000),
            Some(5000),
            None,
        ];

        let decimal_array = array
            .clone()
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 3)?;

        // scale 3 -> scale 6: expected raw values are multiplied by 1000
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 3),
            Decimal128Array,
            Decimal128(20, 6),
            [
                Some(1_234_000),
                Some(2_222_000),
                Some(3_000),
                Some(4_000_000),
                Some(5_000_000),
                None
            ],
            None
        );

        let decimal_array = array
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 3)?;

        // scale 3 -> scale 2: the last digit is dropped from the raw values
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 3),
            Decimal128Array,
            Decimal128(10, 2),
            [Some(123), Some(222), Some(0), Some(400), Some(500), None],
            None
        );

        Ok(())
    }
572
    /// A decimal value that does not fit the target precision makes the
    /// default cast fail at evaluation time, while the safe cast options turn
    /// the same value into a null.
    #[test]
    fn test_cast_decimal_to_decimal_overflow() -> Result<()> {
        let array = vec![Some(123456789)];

        let decimal_array = array
            .clone()
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 3)?;

        let schema = Schema::new(vec![Field::new("a", Decimal128(10, 3), false)]);
        let batch = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![Arc::new(decimal_array)],
        )?;
        // default (non-safe) options: building the cast succeeds, evaluating fails
        let expression =
            cast_with_options(col("a", &schema)?, &schema, Decimal128(6, 2), None)?;
        let e = expression.evaluate(&batch).unwrap_err().strip_backtrace(); // panics on OK
        assert_snapshot!(e, @"Arrow error: Invalid argument error: 123456.79 is too large to store in a Decimal128 of precision 6. Max is 9999.99");
        // safe cast should return null
        let expression_safe = cast_with_options(
            col("a", &schema)?,
            &schema,
            Decimal128(6, 2),
            Some(DEFAULT_SAFE_CAST_OPTIONS),
        )?;
        let result_safe = expression_safe
            .evaluate(&batch)?
            .into_array(batch.num_rows())
            .expect("failed to convert to array");

        assert!(result_safe.is_null(0));

        Ok(())
    }
608
    /// Casting `Decimal128` columns to each signed integer width and to both
    /// float widths yields the expected values and keeps nulls.
    #[test]
    fn test_cast_decimal_to_numeric() -> Result<()> {
        let array = vec![Some(1), Some(2), Some(3), Some(4), Some(5), None];
        // decimal to i8
        let decimal_array = array
            .clone()
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 0)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 0),
            Int8Array,
            Int8,
            [
                Some(1_i8),
                Some(2_i8),
                Some(3_i8),
                Some(4_i8),
                Some(5_i8),
                None
            ],
            None
        );

        // decimal to i16
        let decimal_array = array
            .clone()
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 0)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 0),
            Int16Array,
            Int16,
            [
                Some(1_i16),
                Some(2_i16),
                Some(3_i16),
                Some(4_i16),
                Some(5_i16),
                None
            ],
            None
        );

        // decimal to i32
        let decimal_array = array
            .clone()
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 0)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 0),
            Int32Array,
            Int32,
            [
                Some(1_i32),
                Some(2_i32),
                Some(3_i32),
                Some(4_i32),
                Some(5_i32),
                None
            ],
            None
        );

        // decimal to i64
        let decimal_array = array
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 0)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 0),
            Int64Array,
            Int64,
            [
                Some(1_i64),
                Some(2_i64),
                Some(3_i64),
                Some(4_i64),
                Some(5_i64),
                None
            ],
            None
        );

        // decimal to float32
        let array = vec![
            Some(1234),
            Some(2222),
            Some(3),
            Some(4000),
            Some(5000),
            None,
        ];
        let decimal_array = array
            .clone()
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 3)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 3),
            Float32Array,
            Float32,
            [
                Some(1.234_f32),
                Some(2.222_f32),
                Some(0.003_f32),
                Some(4.0_f32),
                Some(5.0_f32),
                None
            ],
            None
        );

        // decimal to float64 (same raw values, reinterpreted at scale 6)
        let decimal_array = array
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(20, 6)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(20, 6),
            Float64Array,
            Float64,
            [
                Some(0.001234_f64),
                Some(0.002222_f64),
                Some(0.000003_f64),
                Some(0.004_f64),
                Some(0.005_f64),
                None
            ],
            None
        );
        Ok(())
    }
751
    /// Casting integer and float columns to `Decimal128` yields the expected
    /// raw (scaled) decimal values.
    #[test]
    fn test_cast_numeric_to_decimal() -> Result<()> {
        // int8
        generic_test_cast!(
            Int8Array,
            Int8,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            Decimal128(3, 0),
            [Some(1), Some(2), Some(3), Some(4), Some(5)],
            None
        );

        // int16
        generic_test_cast!(
            Int16Array,
            Int16,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            Decimal128(5, 0),
            [Some(1), Some(2), Some(3), Some(4), Some(5)],
            None
        );

        // int32
        generic_test_cast!(
            Int32Array,
            Int32,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            Decimal128(10, 0),
            [Some(1), Some(2), Some(3), Some(4), Some(5)],
            None
        );

        // int64
        generic_test_cast!(
            Int64Array,
            Int64,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            Decimal128(20, 0),
            [Some(1), Some(2), Some(3), Some(4), Some(5)],
            None
        );

        // int64 to different scale
        generic_test_cast!(
            Int64Array,
            Int64,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            Decimal128(20, 2),
            [Some(100), Some(200), Some(300), Some(400), Some(500)],
            None
        );

        // float32
        generic_test_cast!(
            Float32Array,
            Float32,
            vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50],
            Decimal128Array,
            Decimal128(10, 2),
            [Some(150), Some(250), Some(300), Some(112), Some(550)],
            None
        );

        // float64
        generic_test_cast!(
            Float64Array,
            Float64,
            vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50],
            Decimal128Array,
            Decimal128(20, 4),
            [
                Some(15000),
                Some(25000),
                Some(30000),
                Some(11235),
                Some(55000)
            ],
            None
        );
        Ok(())
    }
838
    /// Casting small non-negative `Int32` values to `UInt32` keeps the values.
    #[test]
    fn test_cast_i32_u32() -> Result<()> {
        generic_test_cast!(
            Int32Array,
            Int32,
            vec![1, 2, 3, 4, 5],
            UInt32Array,
            UInt32,
            [
                Some(1_u32),
                Some(2_u32),
                Some(3_u32),
                Some(4_u32),
                Some(5_u32)
            ],
            None
        );
        Ok(())
    }
858
    /// Casting `Int32` values to `Utf8` yields their decimal string form.
    #[test]
    fn test_cast_i32_utf8() -> Result<()> {
        generic_test_cast!(
            Int32Array,
            Int32,
            vec![1, 2, 3, 4, 5],
            StringArray,
            Utf8,
            [Some("1"), Some("2"), Some("3"), Some("4"), Some("5")],
            None
        );
        Ok(())
    }
872
    /// Casting `Int64` values to a nanosecond timestamp (no timezone) keeps
    /// the underlying integer values.
    #[test]
    fn test_cast_i64_t64() -> Result<()> {
        let original = vec![1, 2, 3, 4, 5];
        // Expected values are read back from single-element
        // `Time64NanosecondArray`s built from each input value.
        let expected: Vec<Option<i64>> = original
            .iter()
            .map(|i| Some(Time64NanosecondArray::from(vec![*i]).value(0)))
            .collect();
        generic_test_cast!(
            Int64Array,
            Int64,
            original,
            TimestampNanosecondArray,
            Timestamp(TimeUnit::Nanosecond, None),
            expected,
            None
        );
        Ok(())
    }
891
892    // Tests for timestamp timezone casting have been moved to timestamps.slt
893    // See the "Casting between timestamp with and without timezone" section
894
    /// Building a cast from `Int32` to a month-day-nano interval must be
    /// rejected when the expression is constructed.
    #[test]
    fn invalid_cast() {
        // Ensure a useful error happens at plan time if invalid casts are used
        let schema = Schema::new(vec![Field::new("a", Int32, false)]);

        let result = cast(
            col("a", &schema).unwrap(),
            &schema,
            Interval(IntervalUnit::MonthDayNano),
        );
        result.expect_err("expected Invalid CAST");
    }
907
    /// A `Utf8` -> `Int32` cast can be constructed, but evaluating it on a
    /// value that is not an integer ("9.1") must produce a descriptive error.
    #[test]
    fn invalid_cast_with_options_error() -> Result<()> {
        // Ensure a useful error happens at execution time when a value cannot
        // be cast (the cast expression itself is built successfully)
        let schema = Schema::new(vec![Field::new("a", Utf8, false)]);
        let a = StringArray::from(vec!["9.1"]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
        let expression = cast_with_options(col("a", &schema)?, &schema, Int32, None)?;
        let result = expression.evaluate(&batch);

        match result {
            Ok(_) => panic!("expected error"),
            Err(e) => {
                assert!(
                    e.to_string()
                        .contains("Cannot cast string '9.1' to value of Int32 type")
                )
            }
        }
        Ok(())
    }
928
929    #[test]
930    fn field_aware_cast_preserves_target_field_semantics() -> Result<()> {
931        let metadata = HashMap::from([("target_meta".to_string(), "1".to_string())]);
932
933        for (child_nullable, target_nullable) in [(true, false), (false, true)] {
934            let schema = Schema::new(vec![Field::new("a", Int32, child_nullable)]);
935            let expr = CastExpr::new_with_target_field(
936                col("a", &schema)?,
937                Arc::new(
938                    Field::new("cast_target", Int64, target_nullable)
939                        .with_metadata(metadata.clone()),
940                ),
941                None,
942            );
943
944            let field = expr.return_field(&schema)?;
945            assert_eq!(field.name(), "cast_target");
946            assert_eq!(field.data_type(), &Int64);
947            assert_eq!(field.is_nullable(), target_nullable);
948            assert_eq!(
949                field.metadata().get("target_meta").map(String::as_str),
950                Some("1")
951            );
952            assert_eq!(expr.nullable(&schema)?, child_nullable || target_nullable);
953        }
954
955        Ok(())
956    }
957
958    #[test]
959    fn type_only_cast_preserves_legacy_field_name_and_nullability() -> Result<()> {
960        let schema = Schema::new(vec![Field::new("a", Int32, false)]);
961        let expr = CastExpr::new(col("a", &schema)?, Int64, None);
962
963        let field = expr.return_field(&schema)?;
964
965        assert_eq!(field.name(), "a");
966        assert_eq!(field.data_type(), &Int64);
967        assert!(!field.is_nullable());
968        assert!(!expr.nullable(&schema)?);
969
970        Ok(())
971    }
972
973    #[test]
974    fn struct_cast_validation_uses_nested_target_fields() -> Result<()> {
975        let source_type = Struct(Fields::from(vec![
976            Arc::new(Field::new("x", Int32, true)),
977            Arc::new(Field::new("y", Utf8, true)),
978        ]));
979        let schema = Schema::new(vec![Field::new("a", source_type.clone(), true)]);
980
981        let valid_target = Struct(Fields::from(vec![
982            Arc::new(Field::new("y", Utf8, true)),
983            Arc::new(Field::new("x", Int64, true)),
984        ]));
985        cast_with_options(col("a", &schema)?, &schema, valid_target, None)?;
986
987        let invalid_target = Struct(Fields::from(vec![
988            Arc::new(Field::new("y", Utf8, true)),
989            Arc::new(Field::new("missing", Int64, false)),
990        ]));
991        let err = cast_with_options(col("a", &schema)?, &schema, invalid_target, None)
992            .expect_err("missing required struct field should fail");
993
994        assert!(err.to_string().contains("Unsupported CAST"));
995
996        Ok(())
997    }
998
999    #[test]
1000    fn field_aware_cast_struct_array_missing_child() -> Result<()> {
1001        let source_a = Field::new("a", Int32, true);
1002        let source_b = Field::new("b", Utf8, true);
1003        let target_field = Field::new(
1004            "s",
1005            Struct(
1006                vec![
1007                    Arc::new(Field::new("a", Int64, true)),
1008                    Arc::new(Field::new("c", Utf8, true)),
1009                ]
1010                .into(),
1011            ),
1012            true,
1013        );
1014
1015        let struct_array = cast_struct_array(
1016            "s",
1017            Field::new(
1018                "s",
1019                Struct(
1020                    vec![Arc::new(source_a.clone()), Arc::new(source_b.clone())].into(),
1021                ),
1022                true,
1023            ),
1024            target_field,
1025            make_struct_array(
1026                vec![Arc::new(source_a), Arc::new(source_b)].into(),
1027                vec![
1028                    Arc::new(Int32Array::from(vec![Some(1), None])) as ArrayRef,
1029                    Arc::new(StringArray::from(vec![Some("alpha"), Some("beta")]))
1030                        as ArrayRef,
1031                ],
1032            ),
1033        )?;
1034        let cast_a = as_int64_array(struct_array.column_by_name("a").unwrap().as_ref())?;
1035        assert_eq!(cast_a.value(0), 1);
1036        assert!(cast_a.is_null(1));
1037
1038        let cast_c = as_string_array(struct_array.column_by_name("c").unwrap().as_ref())?;
1039        assert!(cast_c.is_null(0));
1040        assert!(cast_c.is_null(1));
1041        Ok(())
1042    }
1043
1044    #[test]
1045    fn field_aware_cast_nested_struct_array() -> Result<()> {
1046        let inner_source = Field::new(
1047            "inner",
1048            Struct(vec![Arc::new(Field::new("x", Int32, true))].into()),
1049            true,
1050        );
1051        let inner_target = Field::new(
1052            "inner",
1053            Struct(
1054                vec![
1055                    Arc::new(Field::new("x", Int64, true)),
1056                    Arc::new(Field::new("y", Boolean, true)),
1057                ]
1058                .into(),
1059            ),
1060            true,
1061        );
1062        let target_field =
1063            Field::new("root", Struct(vec![Arc::new(inner_target)].into()), true);
1064
1065        let inner_struct = make_struct_array(
1066            vec![Arc::new(Field::new("x", Int32, true))].into(),
1067            vec![Arc::new(Int32Array::from(vec![Some(7), None])) as ArrayRef],
1068        );
1069        let outer_struct = make_struct_array(
1070            vec![Arc::new(inner_source.clone())].into(),
1071            vec![Arc::new(inner_struct) as ArrayRef],
1072        );
1073        let struct_array = cast_struct_array(
1074            "root",
1075            Field::new("root", Struct(vec![Arc::new(inner_source)].into()), true),
1076            target_field,
1077            outer_struct,
1078        )?;
1079        let inner =
1080            as_struct_array(struct_array.column_by_name("inner").unwrap().as_ref())?;
1081        let x = as_int64_array(inner.column_by_name("x").unwrap().as_ref())?;
1082        assert_eq!(x.value(0), 7);
1083        assert!(x.is_null(1));
1084        let y = as_boolean_array(inner.column_by_name("y").unwrap().as_ref())?;
1085        assert!(y.is_null(0));
1086        assert!(y.is_null(1));
1087        Ok(())
1088    }
1089
1090    #[test]
1091    fn field_aware_cast_struct_scalar() -> Result<()> {
1092        let source_field = Field::new("a", Int32, true);
1093        let target_field = Field::new(
1094            "s",
1095            Struct(vec![Arc::new(Field::new("a", UInt8, true))].into()),
1096            true,
1097        );
1098
1099        let schema = Arc::new(Schema::new(vec![Field::new(
1100            "s",
1101            Struct(vec![Arc::new(source_field.clone())].into()),
1102            true,
1103        )]));
1104        let scalar_struct = make_struct_array(
1105            vec![Arc::new(source_field)].into(),
1106            vec![Arc::new(Int32Array::from(vec![Some(9)])) as ArrayRef],
1107        );
1108        let literal = Arc::new(crate::expressions::Literal::new(ScalarValue::Struct(
1109            Arc::new(scalar_struct),
1110        )));
1111        let expr = CastExpr::new_with_target_field(literal, Arc::new(target_field), None);
1112
1113        let batch = RecordBatch::new_empty(schema);
1114        let result = expr.evaluate(&batch)?;
1115        let ColumnarValue::Scalar(ScalarValue::Struct(array)) = result else {
1116            panic!("expected struct scalar");
1117        };
1118        let casted = as_uint8_array(array.column_by_name("a").unwrap().as_ref())?;
1119        assert_eq!(casted.value(0), 9);
1120        Ok(())
1121    }
1122
1123    #[test]
1124    #[ignore] // TODO: https://github.com/apache/datafusion/issues/5396
1125    fn test_cast_decimal() -> Result<()> {
1126        let schema = Schema::new(vec![Field::new("a", Int64, false)]);
1127        let a = Int64Array::from(vec![100]);
1128        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1129        let expression =
1130            cast_with_options(col("a", &schema)?, &schema, Decimal128(38, 38), None)?;
1131        expression.evaluate(&batch)?;
1132        Ok(())
1133    }
1134
1135    #[test]
1136    fn test_fmt_sql() -> Result<()> {
1137        let schema = Schema::new(vec![Field::new("a", Int32, true)]);
1138
1139        // Test numeric casting
1140        let expr = cast(col("a", &schema)?, &schema, Int64)?;
1141        let display_string = expr.to_string();
1142        assert_eq!(display_string, "CAST(a@0 AS Int64)");
1143        let sql_string = fmt_sql(expr.as_ref()).to_string();
1144        assert_eq!(sql_string, "CAST(a AS Int64)");
1145
1146        // Test string casting
1147        let schema = Schema::new(vec![Field::new("b", Utf8, true)]);
1148        let expr = cast(col("b", &schema)?, &schema, Int32)?;
1149        let display_string = expr.to_string();
1150        assert_eq!(display_string, "CAST(b@0 AS Int32)");
1151        let sql_string = fmt_sql(expr.as_ref()).to_string();
1152        assert_eq!(sql_string, "CAST(b AS Int32)");
1153
1154        Ok(())
1155    }
1156}