1use std::any::Any;
19use std::fmt;
20use std::hash::Hash;
21use std::sync::Arc;
22
23use crate::physical_expr::PhysicalExpr;
24
25use arrow::compute::{can_cast_types, CastOptions};
26use arrow::datatypes::{DataType, DataType::*, FieldRef, Schema};
27use arrow::record_batch::RecordBatch;
28use datafusion_common::format::DEFAULT_FORMAT_OPTIONS;
29use datafusion_common::{not_impl_err, Result};
30use datafusion_expr_common::columnar_value::ColumnarValue;
31use datafusion_expr_common::interval_arithmetic::Interval;
32use datafusion_expr_common::sort_properties::ExprProperties;
33
34const DEFAULT_CAST_OPTIONS: CastOptions<'static> = CastOptions {
35 safe: false,
36 format_options: DEFAULT_FORMAT_OPTIONS,
37};
38
39const DEFAULT_SAFE_CAST_OPTIONS: CastOptions<'static> = CastOptions {
40 safe: true,
41 format_options: DEFAULT_FORMAT_OPTIONS,
42};
43
44#[derive(Debug, Clone, Eq)]
46pub struct CastExpr {
47 pub expr: Arc<dyn PhysicalExpr>,
49 cast_type: DataType,
51 cast_options: CastOptions<'static>,
53}
54
55impl PartialEq for CastExpr {
57 fn eq(&self, other: &Self) -> bool {
58 self.expr.eq(&other.expr)
59 && self.cast_type.eq(&other.cast_type)
60 && self.cast_options.eq(&other.cast_options)
61 }
62}
63
64impl Hash for CastExpr {
65 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
66 self.expr.hash(state);
67 self.cast_type.hash(state);
68 self.cast_options.hash(state);
69 }
70}
71
72impl CastExpr {
73 pub fn new(
75 expr: Arc<dyn PhysicalExpr>,
76 cast_type: DataType,
77 cast_options: Option<CastOptions<'static>>,
78 ) -> Self {
79 Self {
80 expr,
81 cast_type,
82 cast_options: cast_options.unwrap_or(DEFAULT_CAST_OPTIONS),
83 }
84 }
85
86 pub fn expr(&self) -> &Arc<dyn PhysicalExpr> {
88 &self.expr
89 }
90
91 pub fn cast_type(&self) -> &DataType {
93 &self.cast_type
94 }
95
96 pub fn cast_options(&self) -> &CastOptions<'static> {
98 &self.cast_options
99 }
100 pub fn is_bigger_cast(&self, src: DataType) -> bool {
101 if src == self.cast_type {
102 return true;
103 }
104 matches!(
105 (src, &self.cast_type),
106 (Int8, Int16 | Int32 | Int64)
107 | (Int16, Int32 | Int64)
108 | (Int32, Int64)
109 | (UInt8, UInt16 | UInt32 | UInt64)
110 | (UInt16, UInt32 | UInt64)
111 | (UInt32, UInt64)
112 | (
113 Int8 | Int16 | Int32 | UInt8 | UInt16 | UInt32,
114 Float32 | Float64
115 )
116 | (Int64 | UInt64, Float64)
117 | (Utf8, LargeUtf8)
118 )
119 }
120}
121
122impl fmt::Display for CastExpr {
123 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
124 write!(f, "CAST({} AS {:?})", self.expr, self.cast_type)
125 }
126}
127
128impl PhysicalExpr for CastExpr {
129 fn as_any(&self) -> &dyn Any {
131 self
132 }
133
134 fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
135 Ok(self.cast_type.clone())
136 }
137
138 fn nullable(&self, input_schema: &Schema) -> Result<bool> {
139 self.expr.nullable(input_schema)
140 }
141
142 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
143 let value = self.expr.evaluate(batch)?;
144 value.cast_to(&self.cast_type, Some(&self.cast_options))
145 }
146
147 fn return_field(&self, input_schema: &Schema) -> Result<FieldRef> {
148 Ok(self
149 .expr
150 .return_field(input_schema)?
151 .as_ref()
152 .clone()
153 .with_data_type(self.cast_type.clone())
154 .into())
155 }
156
157 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
158 vec![&self.expr]
159 }
160
161 fn with_new_children(
162 self: Arc<Self>,
163 children: Vec<Arc<dyn PhysicalExpr>>,
164 ) -> Result<Arc<dyn PhysicalExpr>> {
165 Ok(Arc::new(CastExpr::new(
166 Arc::clone(&children[0]),
167 self.cast_type.clone(),
168 Some(self.cast_options.clone()),
169 )))
170 }
171
172 fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
173 children[0].cast_to(&self.cast_type, &self.cast_options)
175 }
176
177 fn propagate_constraints(
178 &self,
179 interval: &Interval,
180 children: &[&Interval],
181 ) -> Result<Option<Vec<Interval>>> {
182 let child_interval = children[0];
183 let cast_type = child_interval.data_type();
185 Ok(Some(vec![
186 interval.cast_to(&cast_type, &DEFAULT_SAFE_CAST_OPTIONS)?
187 ]))
188 }
189
190 fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
193 let source_datatype = children[0].range.data_type();
194 let target_type = &self.cast_type;
195
196 let unbounded = Interval::make_unbounded(target_type)?;
197 if (source_datatype.is_numeric() || source_datatype == Boolean)
198 && target_type.is_numeric()
199 || source_datatype.is_temporal() && target_type.is_temporal()
200 || source_datatype.eq(target_type)
201 {
202 Ok(children[0].clone().with_range(unbounded))
203 } else {
204 Ok(ExprProperties::new_unknown().with_range(unbounded))
205 }
206 }
207
208 fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
209 write!(f, "CAST(")?;
210 self.expr.fmt_sql(f)?;
211 write!(f, " AS {:?}", self.cast_type)?;
212
213 write!(f, ")")
214 }
215}
216
217pub fn cast_with_options(
222 expr: Arc<dyn PhysicalExpr>,
223 input_schema: &Schema,
224 cast_type: DataType,
225 cast_options: Option<CastOptions<'static>>,
226) -> Result<Arc<dyn PhysicalExpr>> {
227 let expr_type = expr.data_type(input_schema)?;
228 if expr_type == cast_type {
229 Ok(Arc::clone(&expr))
230 } else if can_cast_types(&expr_type, &cast_type) {
231 Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options)))
232 } else {
233 not_impl_err!("Unsupported CAST from {expr_type:?} to {cast_type:?}")
234 }
235}
236
237pub fn cast(
242 expr: Arc<dyn PhysicalExpr>,
243 input_schema: &Schema,
244 cast_type: DataType,
245) -> Result<Arc<dyn PhysicalExpr>> {
246 cast_with_options(expr, input_schema, cast_type, None)
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    use crate::expressions::column::col;

    use arrow::{
        array::{
            Array, Decimal128Array, Float32Array, Float64Array, Int16Array, Int32Array,
            Int64Array, Int8Array, StringArray, Time64NanosecondArray,
            TimestampNanosecondArray, UInt32Array,
        },
        datatypes::*,
    };
    use datafusion_common::assert_contains;
    use datafusion_physical_expr_common::physical_expr::fmt_sql;

    // Runs a cast over an already-built decimal array.
    //
    // Arguments:
    // 1. the decimal array to cast (an identifier, moved into the batch)
    // 2. the data type of that array
    // 3. the concrete array type of the cast result
    // 4. the data type to cast to
    // 5. the expected values, as `Option`s (`None` means NULL)
    // 6. the cast options (`None` selects the defaults)
    macro_rules! generic_decimal_to_other_test_cast {
        ($DECIMAL_ARRAY:ident, $A_TYPE:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr,$CAST_OPTIONS:expr) => {{
            let schema = Schema::new(vec![Field::new("a", $A_TYPE, true)]);
            let batch = RecordBatch::try_new(
                Arc::new(schema.clone()),
                vec![Arc::new($DECIMAL_ARRAY)],
            )?;
            // Evaluate the expected values once so the length check and the
            // value check below look at the same data.
            let expected = $VEC;

            let expression =
                cast_with_options(col("a", &schema)?, &schema, $TYPE, $CAST_OPTIONS)?;

            // The display form names the column and the target type.
            assert_eq!(
                format!("CAST(a@0 AS {:?})", $TYPE),
                expression.to_string()
            );

            // The expression reports the target type ...
            assert_eq!(expression.data_type(&schema)?, $TYPE);

            let result = expression
                .evaluate(&batch)?
                .into_array(batch.num_rows())
                .expect("Failed to convert to array");

            // ... and the evaluated array actually has it.
            assert_eq!(*result.data_type(), $TYPE);

            // The cast must produce exactly one output per expected value.
            assert_eq!(result.len(), expected.len());

            let result = result
                .as_any()
                .downcast_ref::<$TYPEARRAY>()
                .expect("failed to downcast");

            // Compare value by value; `None` stands for a NULL slot.
            for (i, x) in expected.iter().enumerate() {
                match x {
                    Some(x) => assert_eq!(result.value(i), *x),
                    None => assert!(!result.is_valid(i)),
                }
            }
        }};
    }

    // Builds an array from a vector of values, casts it and checks the result.
    //
    // Arguments:
    // 1. the array type built from the input values
    // 2. the data type of that array
    // 3. the input values
    // 4. the concrete array type of the cast result
    // 5. the data type to cast to
    // 6. the expected values, as `Option`s (`None` means NULL)
    // 7. the cast options (`None` selects the defaults)
    macro_rules! generic_test_cast {
        ($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr, $CAST_OPTIONS:expr) => {{
            let schema = Schema::new(vec![Field::new("a", $A_TYPE, true)]);
            let a_vec_len = $A_VEC.len();
            let a = $A_ARRAY::from($A_VEC);
            let batch =
                RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

            let expression =
                cast_with_options(col("a", &schema)?, &schema, $TYPE, $CAST_OPTIONS)?;

            // The display form names the column and the target type.
            assert_eq!(
                format!("CAST(a@0 AS {:?})", $TYPE),
                expression.to_string()
            );

            // The expression reports the target type ...
            assert_eq!(expression.data_type(&schema)?, $TYPE);

            let result = expression
                .evaluate(&batch)?
                .into_array(batch.num_rows())
                .expect("Failed to convert to array");

            // ... and the evaluated array actually has it.
            assert_eq!(*result.data_type(), $TYPE);

            // The cast must produce one output per input value.
            assert_eq!(result.len(), a_vec_len);

            let result = result
                .as_any()
                .downcast_ref::<$TYPEARRAY>()
                .expect("failed to downcast");

            // Compare value by value; `None` stands for a NULL slot.
            for (i, x) in $VEC.iter().enumerate() {
                match x {
                    Some(x) => assert_eq!(result.value(i), *x),
                    None => assert!(!result.is_valid(i)),
                }
            }
        }};
    }

    #[test]
    fn test_cast_decimal_to_decimal() -> Result<()> {
        let array = vec![
            Some(1234),
            Some(2222),
            Some(3),
            Some(4000),
            Some(5000),
            None,
        ];

        // Widening precision and scale multiplies the stored values.
        let decimal_array = array
            .clone()
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 3)?;

        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 3),
            Decimal128Array,
            Decimal128(20, 6),
            [
                Some(1_234_000),
                Some(2_222_000),
                Some(3_000),
                Some(4_000_000),
                Some(5_000_000),
                None
            ],
            None
        );

        // Reducing the scale drops the last stored digit.
        let decimal_array = array
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 3)?;

        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 3),
            Decimal128Array,
            Decimal128(10, 2),
            [Some(123), Some(222), Some(0), Some(400), Some(500), None],
            None
        );

        Ok(())
    }

    #[test]
    fn test_cast_decimal_to_decimal_overflow() -> Result<()> {
        let array = vec![Some(123456789)];

        let decimal_array = array
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 3)?;

        let schema = Schema::new(vec![Field::new("a", Decimal128(10, 3), false)]);
        let batch = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![Arc::new(decimal_array)],
        )?;

        // With the default (strict) options the overflow is an error.
        let expression =
            cast_with_options(col("a", &schema)?, &schema, Decimal128(6, 2), None)?;
        let e = expression.evaluate(&batch).unwrap_err();
        assert_contains!(
            e.to_string(),
            "Arrow error: Invalid argument error: 12345679 is too large to store in a Decimal128 of precision 6. Max is 999999"
        );

        // With the safe options the overflowing value becomes NULL instead.
        let expression_safe = cast_with_options(
            col("a", &schema)?,
            &schema,
            Decimal128(6, 2),
            Some(DEFAULT_SAFE_CAST_OPTIONS),
        )?;
        let result_safe = expression_safe
            .evaluate(&batch)?
            .into_array(batch.num_rows())
            .expect("failed to convert to array");

        assert!(result_safe.is_null(0));

        Ok(())
    }

    #[test]
    fn test_cast_decimal_to_numeric() -> Result<()> {
        let array = vec![Some(1), Some(2), Some(3), Some(4), Some(5), None];

        // decimal -> i8
        let decimal_array = array
            .clone()
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 0)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 0),
            Int8Array,
            Int8,
            [
                Some(1_i8),
                Some(2_i8),
                Some(3_i8),
                Some(4_i8),
                Some(5_i8),
                None
            ],
            None
        );

        // decimal -> i16
        let decimal_array = array
            .clone()
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 0)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 0),
            Int16Array,
            Int16,
            [
                Some(1_i16),
                Some(2_i16),
                Some(3_i16),
                Some(4_i16),
                Some(5_i16),
                None
            ],
            None
        );

        // decimal -> i32
        let decimal_array = array
            .clone()
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 0)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 0),
            Int32Array,
            Int32,
            [
                Some(1_i32),
                Some(2_i32),
                Some(3_i32),
                Some(4_i32),
                Some(5_i32),
                None
            ],
            None
        );

        // decimal -> i64
        let decimal_array = array
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 0)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 0),
            Int64Array,
            Int64,
            [
                Some(1_i64),
                Some(2_i64),
                Some(3_i64),
                Some(4_i64),
                Some(5_i64),
                None
            ],
            None
        );

        // decimal -> f32
        let array = vec![
            Some(1234),
            Some(2222),
            Some(3),
            Some(4000),
            Some(5000),
            None,
        ];
        let decimal_array = array
            .clone()
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 3)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 3),
            Float32Array,
            Float32,
            [
                Some(1.234_f32),
                Some(2.222_f32),
                Some(0.003_f32),
                Some(4.0_f32),
                Some(5.0_f32),
                None
            ],
            None
        );

        // decimal -> f64
        let decimal_array = array
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(20, 6)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(20, 6),
            Float64Array,
            Float64,
            [
                Some(0.001234_f64),
                Some(0.002222_f64),
                Some(0.000003_f64),
                Some(0.004_f64),
                Some(0.005_f64),
                None
            ],
            None
        );
        Ok(())
    }

    #[test]
    fn test_cast_numeric_to_decimal() -> Result<()> {
        // i8 -> decimal
        generic_test_cast!(
            Int8Array,
            Int8,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            Decimal128(3, 0),
            [Some(1), Some(2), Some(3), Some(4), Some(5)],
            None
        );

        // i16 -> decimal
        generic_test_cast!(
            Int16Array,
            Int16,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            Decimal128(5, 0),
            [Some(1), Some(2), Some(3), Some(4), Some(5)],
            None
        );

        // i32 -> decimal
        generic_test_cast!(
            Int32Array,
            Int32,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            Decimal128(10, 0),
            [Some(1), Some(2), Some(3), Some(4), Some(5)],
            None
        );

        // i64 -> decimal
        generic_test_cast!(
            Int64Array,
            Int64,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            Decimal128(20, 0),
            [Some(1), Some(2), Some(3), Some(4), Some(5)],
            None
        );

        // i64 -> decimal with a non-zero scale
        generic_test_cast!(
            Int64Array,
            Int64,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            Decimal128(20, 2),
            [Some(100), Some(200), Some(300), Some(400), Some(500)],
            None
        );

        // f32 -> decimal
        generic_test_cast!(
            Float32Array,
            Float32,
            vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50],
            Decimal128Array,
            Decimal128(10, 2),
            [Some(150), Some(250), Some(300), Some(112), Some(550)],
            None
        );

        // f64 -> decimal
        generic_test_cast!(
            Float64Array,
            Float64,
            vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50],
            Decimal128Array,
            Decimal128(20, 4),
            [
                Some(15000),
                Some(25000),
                Some(30000),
                Some(11235),
                Some(55000)
            ],
            None
        );
        Ok(())
    }

    #[test]
    fn test_cast_i32_u32() -> Result<()> {
        generic_test_cast!(
            Int32Array,
            Int32,
            vec![1, 2, 3, 4, 5],
            UInt32Array,
            UInt32,
            [
                Some(1_u32),
                Some(2_u32),
                Some(3_u32),
                Some(4_u32),
                Some(5_u32)
            ],
            None
        );
        Ok(())
    }

    #[test]
    fn test_cast_i32_utf8() -> Result<()> {
        generic_test_cast!(
            Int32Array,
            Int32,
            vec![1, 2, 3, 4, 5],
            StringArray,
            Utf8,
            [Some("1"), Some("2"), Some("3"), Some("4"), Some("5")],
            None
        );
        Ok(())
    }

    #[test]
    fn test_cast_i64_t64() -> Result<()> {
        let original = vec![1, 2, 3, 4, 5];
        // Each expected value is read back from a one-element time array
        // built from the corresponding input.
        let expected: Vec<Option<i64>> = original
            .iter()
            .map(|i| Some(Time64NanosecondArray::from(vec![*i]).value(0)))
            .collect();
        generic_test_cast!(
            Int64Array,
            Int64,
            original,
            TimestampNanosecondArray,
            Timestamp(TimeUnit::Nanosecond, None),
            expected,
            None
        );
        Ok(())
    }

    #[test]
    fn invalid_cast() {
        // Building a cast from Int32 to an interval type must be rejected.
        let schema = Schema::new(vec![Field::new("a", Int32, false)]);

        let result = cast(
            col("a", &schema).unwrap(),
            &schema,
            Interval(IntervalUnit::MonthDayNano),
        );
        result.expect_err("expected Invalid CAST");
    }

    #[test]
    fn invalid_cast_with_options_error() -> Result<()> {
        // The cast itself is supported, but the value cannot be parsed, so
        // the failure only shows up at evaluation time.
        let schema = Schema::new(vec![Field::new("a", Utf8, false)]);
        let a = StringArray::from(vec!["9.1"]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
        let expression = cast_with_options(col("a", &schema)?, &schema, Int32, None)?;

        let e = expression.evaluate(&batch).unwrap_err();
        assert_contains!(
            e.to_string(),
            "Cannot cast string '9.1' to value of Int32 type"
        );
        Ok(())
    }

    // NOTE(review): this test is ignored and no reason is recorded in this
    // file — confirm why before re-enabling it.
    #[test]
    #[ignore]
    fn test_cast_decimal() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", Int64, false)]);
        let a = Int64Array::from(vec![100]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
        let expression =
            cast_with_options(col("a", &schema)?, &schema, Decimal128(38, 38), None)?;
        expression.evaluate(&batch)?;
        Ok(())
    }

    #[test]
    fn test_fmt_sql() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", Int32, true)]);

        // Cast to a wider integer type.
        let expr = cast(col("a", &schema)?, &schema, Int64)?;
        let display_string = expr.to_string();
        assert_eq!(display_string, "CAST(a@0 AS Int64)");
        let sql_string = fmt_sql(expr.as_ref()).to_string();
        assert_eq!(sql_string, "CAST(a AS Int64)");

        // Cast from a string column to an integer type.
        let schema = Schema::new(vec![Field::new("b", Utf8, true)]);
        let expr = cast(col("b", &schema)?, &schema, Int32)?;
        let display_string = expr.to_string();
        assert_eq!(display_string, "CAST(b@0 AS Int32)");
        let sql_string = fmt_sql(expr.as_ref()).to_string();
        assert_eq!(sql_string, "CAST(b AS Int32)");

        Ok(())
    }
}