datafusion_physical_expr/expressions/cast.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::fmt;
20use std::hash::Hash;
21use std::sync::Arc;
22
23use crate::physical_expr::PhysicalExpr;
24
25use arrow::compute::{can_cast_types, CastOptions};
26use arrow::datatypes::{DataType, DataType::*, FieldRef, Schema};
27use arrow::record_batch::RecordBatch;
28use datafusion_common::format::DEFAULT_FORMAT_OPTIONS;
29use datafusion_common::{not_impl_err, Result};
30use datafusion_expr_common::columnar_value::ColumnarValue;
31use datafusion_expr_common::interval_arithmetic::Interval;
32use datafusion_expr_common::sort_properties::ExprProperties;
33
34const DEFAULT_CAST_OPTIONS: CastOptions<'static> = CastOptions {
35    safe: false,
36    format_options: DEFAULT_FORMAT_OPTIONS,
37};
38
39const DEFAULT_SAFE_CAST_OPTIONS: CastOptions<'static> = CastOptions {
40    safe: true,
41    format_options: DEFAULT_FORMAT_OPTIONS,
42};
43
/// Physical CAST expression: evaluates `expr` and casts the result to `cast_type`.
///
/// What happens to a value that cannot be cast depends on `cast_options.safe`:
/// with `safe: false` (the default chosen by [`CastExpr::new`] when given
/// `None`) evaluation returns a runtime error; with `safe: true` the value
/// becomes NULL.
#[derive(Debug, Clone, Eq)]
// `PartialEq` and `Hash` are implemented manually for this type (see the
// rust-lang/rust#78808 workaround); the derived `Eq` relies on that `PartialEq`.
pub struct CastExpr {
    /// The expression to cast (also available through [`CastExpr::expr`])
    pub expr: Arc<dyn PhysicalExpr>,
    /// The data type to cast to
    cast_type: DataType,
    /// Options applied when casting evaluated values and intervals
    /// (e.g. the `safe` flag and format options)
    cast_options: CastOptions<'static>,
}
54
55// Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808
56impl PartialEq for CastExpr {
57    fn eq(&self, other: &Self) -> bool {
58        self.expr.eq(&other.expr)
59            && self.cast_type.eq(&other.cast_type)
60            && self.cast_options.eq(&other.cast_options)
61    }
62}
63
64impl Hash for CastExpr {
65    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
66        self.expr.hash(state);
67        self.cast_type.hash(state);
68        self.cast_options.hash(state);
69    }
70}
71
72impl CastExpr {
73    /// Create a new CastExpr
74    pub fn new(
75        expr: Arc<dyn PhysicalExpr>,
76        cast_type: DataType,
77        cast_options: Option<CastOptions<'static>>,
78    ) -> Self {
79        Self {
80            expr,
81            cast_type,
82            cast_options: cast_options.unwrap_or(DEFAULT_CAST_OPTIONS),
83        }
84    }
85
86    /// The expression to cast
87    pub fn expr(&self) -> &Arc<dyn PhysicalExpr> {
88        &self.expr
89    }
90
91    /// The data type to cast to
92    pub fn cast_type(&self) -> &DataType {
93        &self.cast_type
94    }
95
96    /// The cast options
97    pub fn cast_options(&self) -> &CastOptions<'static> {
98        &self.cast_options
99    }
100    pub fn is_bigger_cast(&self, src: DataType) -> bool {
101        if src == self.cast_type {
102            return true;
103        }
104        matches!(
105            (src, &self.cast_type),
106            (Int8, Int16 | Int32 | Int64)
107                | (Int16, Int32 | Int64)
108                | (Int32, Int64)
109                | (UInt8, UInt16 | UInt32 | UInt64)
110                | (UInt16, UInt32 | UInt64)
111                | (UInt32, UInt64)
112                | (
113                    Int8 | Int16 | Int32 | UInt8 | UInt16 | UInt32,
114                    Float32 | Float64
115                )
116                | (Int64 | UInt64, Float64)
117                | (Utf8, LargeUtf8)
118        )
119    }
120}
121
122impl fmt::Display for CastExpr {
123    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
124        write!(f, "CAST({} AS {:?})", self.expr, self.cast_type)
125    }
126}
127
128impl PhysicalExpr for CastExpr {
129    /// Return a reference to Any that can be used for downcasting
130    fn as_any(&self) -> &dyn Any {
131        self
132    }
133
134    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
135        Ok(self.cast_type.clone())
136    }
137
138    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
139        self.expr.nullable(input_schema)
140    }
141
142    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
143        let value = self.expr.evaluate(batch)?;
144        value.cast_to(&self.cast_type, Some(&self.cast_options))
145    }
146
147    fn return_field(&self, input_schema: &Schema) -> Result<FieldRef> {
148        Ok(self
149            .expr
150            .return_field(input_schema)?
151            .as_ref()
152            .clone()
153            .with_data_type(self.cast_type.clone())
154            .into())
155    }
156
157    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
158        vec![&self.expr]
159    }
160
161    fn with_new_children(
162        self: Arc<Self>,
163        children: Vec<Arc<dyn PhysicalExpr>>,
164    ) -> Result<Arc<dyn PhysicalExpr>> {
165        Ok(Arc::new(CastExpr::new(
166            Arc::clone(&children[0]),
167            self.cast_type.clone(),
168            Some(self.cast_options.clone()),
169        )))
170    }
171
172    fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
173        // Cast current node's interval to the right type:
174        children[0].cast_to(&self.cast_type, &self.cast_options)
175    }
176
177    fn propagate_constraints(
178        &self,
179        interval: &Interval,
180        children: &[&Interval],
181    ) -> Result<Option<Vec<Interval>>> {
182        let child_interval = children[0];
183        // Get child's datatype:
184        let cast_type = child_interval.data_type();
185        Ok(Some(vec![
186            interval.cast_to(&cast_type, &DEFAULT_SAFE_CAST_OPTIONS)?
187        ]))
188    }
189
190    /// A [`CastExpr`] preserves the ordering of its child if the cast is done
191    /// under the same datatype family.
192    fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
193        let source_datatype = children[0].range.data_type();
194        let target_type = &self.cast_type;
195
196        let unbounded = Interval::make_unbounded(target_type)?;
197        if (source_datatype.is_numeric() || source_datatype == Boolean)
198            && target_type.is_numeric()
199            || source_datatype.is_temporal() && target_type.is_temporal()
200            || source_datatype.eq(target_type)
201        {
202            Ok(children[0].clone().with_range(unbounded))
203        } else {
204            Ok(ExprProperties::new_unknown().with_range(unbounded))
205        }
206    }
207
208    fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
209        write!(f, "CAST(")?;
210        self.expr.fmt_sql(f)?;
211        write!(f, " AS {:?}", self.cast_type)?;
212
213        write!(f, ")")
214    }
215}
216
217/// Return a PhysicalExpression representing `expr` casted to
218/// `cast_type`, if any casting is needed.
219///
220/// Note that such casts may lose type information
221pub fn cast_with_options(
222    expr: Arc<dyn PhysicalExpr>,
223    input_schema: &Schema,
224    cast_type: DataType,
225    cast_options: Option<CastOptions<'static>>,
226) -> Result<Arc<dyn PhysicalExpr>> {
227    let expr_type = expr.data_type(input_schema)?;
228    if expr_type == cast_type {
229        Ok(Arc::clone(&expr))
230    } else if can_cast_types(&expr_type, &cast_type) {
231        Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options)))
232    } else {
233        not_impl_err!("Unsupported CAST from {expr_type:?} to {cast_type:?}")
234    }
235}
236
237/// Return a PhysicalExpression representing `expr` casted to
238/// `cast_type`, if any casting is needed.
239///
240/// Note that such casts may lose type information
241pub fn cast(
242    expr: Arc<dyn PhysicalExpr>,
243    input_schema: &Schema,
244    cast_type: DataType,
245) -> Result<Arc<dyn PhysicalExpr>> {
246    cast_with_options(expr, input_schema, cast_type, None)
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    use crate::expressions::column::col;

    use arrow::{
        array::{
            Array, Decimal128Array, Float32Array, Float64Array, Int16Array, Int32Array,
            Int64Array, Int8Array, StringArray, TimestampNanosecondArray, UInt32Array,
        },
        datatypes::*,
    };
    use datafusion_common::assert_contains;
    use datafusion_physical_expr_common::physical_expr::fmt_sql;

    // runs an end-to-end test of physical type cast
    // 1. construct a record batch with a column "a" of type A
    // 2. construct a physical expression of CAST(a AS B)
    // 3. evaluate the expression
    // 4. verify that the resulting expression is of type B
    // 5. verify that the resulting values are downcastable and correct
    macro_rules! generic_decimal_to_other_test_cast {
        ($DECIMAL_ARRAY:ident, $A_TYPE:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr,$CAST_OPTIONS:expr) => {{
            let schema = Schema::new(vec![Field::new("a", $A_TYPE, true)]);
            let batch = RecordBatch::try_new(
                Arc::new(schema.clone()),
                vec![Arc::new($DECIMAL_ARRAY)],
            )?;
            // expected values, one per input row (`None` means NULL)
            let expected = $VEC;

            // verify that we can construct the expression
            let expression =
                cast_with_options(col("a", &schema)?, &schema, $TYPE, $CAST_OPTIONS)?;

            // verify that its display is correct
            assert_eq!(
                format!("CAST(a@0 AS {:?})", $TYPE),
                format!("{}", expression)
            );

            // verify that the expression's type is correct
            assert_eq!(expression.data_type(&schema)?, $TYPE);

            // compute
            let result = expression
                .evaluate(&batch)?
                .into_array(batch.num_rows())
                .expect("Failed to convert to array");

            // verify that the array's data_type is correct
            assert_eq!(*result.data_type(), $TYPE);

            // verify that the len is correct, so that every row is checked below
            assert_eq!(result.len(), expected.len());

            // verify that the data itself is downcastable
            let result = result
                .as_any()
                .downcast_ref::<$TYPEARRAY>()
                .expect("failed to downcast");

            // verify that the result itself is correct
            for (i, x) in expected.iter().enumerate() {
                match x {
                    Some(x) => assert_eq!(result.value(i), *x),
                    None => assert!(!result.is_valid(i)),
                }
            }
        }};
    }

    // runs an end-to-end test of physical type cast
    // 1. construct a record batch with a column "a" of type A
    // 2. construct a physical expression of CAST(a AS B)
    // 3. evaluate the expression
    // 4. verify that the resulting expression is of type B
    // 5. verify that the resulting values are downcastable and correct
    macro_rules! generic_test_cast {
        ($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr, $CAST_OPTIONS:expr) => {{
            let schema = Schema::new(vec![Field::new("a", $A_TYPE, true)]);
            let a_vec_len = $A_VEC.len();
            let a = $A_ARRAY::from($A_VEC);
            let batch =
                RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

            // verify that we can construct the expression
            let expression =
                cast_with_options(col("a", &schema)?, &schema, $TYPE, $CAST_OPTIONS)?;

            // verify that its display is correct
            assert_eq!(
                format!("CAST(a@0 AS {:?})", $TYPE),
                format!("{}", expression)
            );

            // verify that the expression's type is correct
            assert_eq!(expression.data_type(&schema)?, $TYPE);

            // compute
            let result = expression
                .evaluate(&batch)?
                .into_array(batch.num_rows())
                .expect("Failed to convert to array");

            // verify that the array's data_type is correct
            assert_eq!(*result.data_type(), $TYPE);

            // verify that the len is correct
            assert_eq!(result.len(), a_vec_len);

            // verify that the data itself is downcastable
            let result = result
                .as_any()
                .downcast_ref::<$TYPEARRAY>()
                .expect("failed to downcast");

            // verify that the result itself is correct
            for (i, x) in $VEC.iter().enumerate() {
                match x {
                    Some(x) => assert_eq!(result.value(i), *x),
                    None => assert!(!result.is_valid(i)),
                }
            }
        }};
    }

    #[test]
    fn test_cast_decimal_to_decimal() -> Result<()> {
        let array = vec![
            Some(1234),
            Some(2222),
            Some(3),
            Some(4000),
            Some(5000),
            None,
        ];

        let decimal_array = array
            .clone()
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 3)?;

        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 3),
            Decimal128Array,
            Decimal128(20, 6),
            [
                Some(1_234_000),
                Some(2_222_000),
                Some(3_000),
                Some(4_000_000),
                Some(5_000_000),
                None
            ],
            None
        );

        let decimal_array = array
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 3)?;

        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 3),
            Decimal128Array,
            Decimal128(10, 2),
            [Some(123), Some(222), Some(0), Some(400), Some(500), None],
            None
        );

        Ok(())
    }

    #[test]
    fn test_cast_decimal_to_decimal_overflow() -> Result<()> {
        let array = vec![Some(123456789)];

        let decimal_array = array
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 3)?;

        let schema = Schema::new(vec![Field::new("a", Decimal128(10, 3), false)]);
        let batch = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![Arc::new(decimal_array)],
        )?;
        let expression =
            cast_with_options(col("a", &schema)?, &schema, Decimal128(6, 2), None)?;
        let e = expression.evaluate(&batch).unwrap_err(); // panics on OK
        assert_contains!(
            e.to_string(),
            "Arrow error: Invalid argument error: 12345679 is too large to store in a Decimal128 of precision 6. Max is 999999"
        );

        let expression_safe = cast_with_options(
            col("a", &schema)?,
            &schema,
            Decimal128(6, 2),
            Some(DEFAULT_SAFE_CAST_OPTIONS),
        )?;
        let result_safe = expression_safe
            .evaluate(&batch)?
            .into_array(batch.num_rows())
            .expect("failed to convert to array");

        assert!(result_safe.is_null(0));

        Ok(())
    }

    #[test]
    fn test_cast_decimal_to_numeric() -> Result<()> {
        let array = vec![Some(1), Some(2), Some(3), Some(4), Some(5), None];
        // decimal to i8
        let decimal_array = array
            .clone()
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 0)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 0),
            Int8Array,
            Int8,
            [
                Some(1_i8),
                Some(2_i8),
                Some(3_i8),
                Some(4_i8),
                Some(5_i8),
                None
            ],
            None
        );

        // decimal to i16
        let decimal_array = array
            .clone()
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 0)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 0),
            Int16Array,
            Int16,
            [
                Some(1_i16),
                Some(2_i16),
                Some(3_i16),
                Some(4_i16),
                Some(5_i16),
                None
            ],
            None
        );

        // decimal to i32
        let decimal_array = array
            .clone()
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 0)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 0),
            Int32Array,
            Int32,
            [
                Some(1_i32),
                Some(2_i32),
                Some(3_i32),
                Some(4_i32),
                Some(5_i32),
                None
            ],
            None
        );

        // decimal to i64
        let decimal_array = array
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 0)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 0),
            Int64Array,
            Int64,
            [
                Some(1_i64),
                Some(2_i64),
                Some(3_i64),
                Some(4_i64),
                Some(5_i64),
                None
            ],
            None
        );

        // decimal to float32
        let array = vec![
            Some(1234),
            Some(2222),
            Some(3),
            Some(4000),
            Some(5000),
            None,
        ];
        let decimal_array = array
            .clone()
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 3)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 3),
            Float32Array,
            Float32,
            [
                Some(1.234_f32),
                Some(2.222_f32),
                Some(0.003_f32),
                Some(4.0_f32),
                Some(5.0_f32),
                None
            ],
            None
        );

        // decimal to float64
        let decimal_array = array
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(20, 6)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(20, 6),
            Float64Array,
            Float64,
            [
                Some(0.001234_f64),
                Some(0.002222_f64),
                Some(0.000003_f64),
                Some(0.004_f64),
                Some(0.005_f64),
                None
            ],
            None
        );
        Ok(())
    }

    #[test]
    fn test_cast_numeric_to_decimal() -> Result<()> {
        // int8
        generic_test_cast!(
            Int8Array,
            Int8,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            Decimal128(3, 0),
            [Some(1), Some(2), Some(3), Some(4), Some(5)],
            None
        );

        // int16
        generic_test_cast!(
            Int16Array,
            Int16,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            Decimal128(5, 0),
            [Some(1), Some(2), Some(3), Some(4), Some(5)],
            None
        );

        // int32
        generic_test_cast!(
            Int32Array,
            Int32,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            Decimal128(10, 0),
            [Some(1), Some(2), Some(3), Some(4), Some(5)],
            None
        );

        // int64
        generic_test_cast!(
            Int64Array,
            Int64,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            Decimal128(20, 0),
            [Some(1), Some(2), Some(3), Some(4), Some(5)],
            None
        );

        // int64 to different scale
        generic_test_cast!(
            Int64Array,
            Int64,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            Decimal128(20, 2),
            [Some(100), Some(200), Some(300), Some(400), Some(500)],
            None
        );

        // float32
        generic_test_cast!(
            Float32Array,
            Float32,
            vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50],
            Decimal128Array,
            Decimal128(10, 2),
            [Some(150), Some(250), Some(300), Some(112), Some(550)],
            None
        );

        // float64
        generic_test_cast!(
            Float64Array,
            Float64,
            vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50],
            Decimal128Array,
            Decimal128(20, 4),
            [
                Some(15000),
                Some(25000),
                Some(30000),
                Some(11235),
                Some(55000)
            ],
            None
        );
        Ok(())
    }

    #[test]
    fn test_cast_i32_u32() -> Result<()> {
        generic_test_cast!(
            Int32Array,
            Int32,
            vec![1, 2, 3, 4, 5],
            UInt32Array,
            UInt32,
            [
                Some(1_u32),
                Some(2_u32),
                Some(3_u32),
                Some(4_u32),
                Some(5_u32)
            ],
            None
        );
        Ok(())
    }

    #[test]
    fn test_cast_i32_utf8() -> Result<()> {
        generic_test_cast!(
            Int32Array,
            Int32,
            vec![1, 2, 3, 4, 5],
            StringArray,
            Utf8,
            [Some("1"), Some("2"), Some("3"), Some("4"), Some("5")],
            None
        );
        Ok(())
    }

    #[test]
    fn test_cast_i64_t64() -> Result<()> {
        // Casting Int64 to a nanosecond timestamp keeps the raw i64 values.
        let original = vec![1, 2, 3, 4, 5];
        let expected: Vec<Option<i64>> = original.iter().map(|&v| Some(v)).collect();
        generic_test_cast!(
            Int64Array,
            Int64,
            original,
            TimestampNanosecondArray,
            Timestamp(TimeUnit::Nanosecond, None),
            expected,
            None
        );
        Ok(())
    }

    #[test]
    fn invalid_cast() {
        // Ensure a useful error happens at plan time if invalid casts are used
        let schema = Schema::new(vec![Field::new("a", Int32, false)]);

        let result = cast(
            col("a", &schema).unwrap(),
            &schema,
            Interval(IntervalUnit::MonthDayNano),
        );
        result.expect_err("expected Invalid CAST");
    }

    #[test]
    fn invalid_cast_with_options_error() -> Result<()> {
        // Utf8 -> Int32 is accepted at plan time, but a string that is not an
        // integer must produce a descriptive error when the cast is evaluated
        let schema = Schema::new(vec![Field::new("a", Utf8, false)]);
        let a = StringArray::from(vec!["9.1"]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
        let expression = cast_with_options(col("a", &schema)?, &schema, Int32, None)?;
        let err = expression
            .evaluate(&batch)
            .expect_err("expected casting '9.1' to Int32 to fail");

        assert_contains!(
            err.to_string(),
            "Cannot cast string '9.1' to value of Int32 type"
        );
        Ok(())
    }

    #[test]
    #[ignore] // TODO: https://github.com/apache/datafusion/issues/5396
    fn test_cast_decimal() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", Int64, false)]);
        let a = Int64Array::from(vec![100]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
        let expression =
            cast_with_options(col("a", &schema)?, &schema, Decimal128(38, 38), None)?;
        expression.evaluate(&batch)?;
        Ok(())
    }

    #[test]
    fn test_fmt_sql() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", Int32, true)]);

        // Test numeric casting
        let expr = cast(col("a", &schema)?, &schema, Int64)?;
        let display_string = expr.to_string();
        assert_eq!(display_string, "CAST(a@0 AS Int64)");
        let sql_string = fmt_sql(expr.as_ref()).to_string();
        assert_eq!(sql_string, "CAST(a AS Int64)");

        // Test string casting
        let schema = Schema::new(vec![Field::new("b", Utf8, true)]);
        let expr = cast(col("b", &schema)?, &schema, Int32)?;
        let display_string = expr.to_string();
        assert_eq!(display_string, "CAST(b@0 AS Int32)");
        let sql_string = fmt_sql(expr.as_ref()).to_string();
        assert_eq!(sql_string, "CAST(b AS Int32)");

        Ok(())
    }
}