1use std::any::Any;
19use std::fmt;
20use std::hash::Hash;
21use std::sync::Arc;
22
23use crate::physical_expr::PhysicalExpr;
24
25use arrow::compute::{can_cast_types, CastOptions};
26use arrow::datatypes::{DataType, DataType::*, Schema};
27use arrow::record_batch::RecordBatch;
28use datafusion_common::format::DEFAULT_FORMAT_OPTIONS;
29use datafusion_common::{not_impl_err, Result};
30use datafusion_expr_common::columnar_value::ColumnarValue;
31use datafusion_expr_common::interval_arithmetic::Interval;
32use datafusion_expr_common::sort_properties::ExprProperties;
33
34const DEFAULT_CAST_OPTIONS: CastOptions<'static> = CastOptions {
35 safe: false,
36 format_options: DEFAULT_FORMAT_OPTIONS,
37};
38
39const DEFAULT_SAFE_CAST_OPTIONS: CastOptions<'static> = CastOptions {
40 safe: true,
41 format_options: DEFAULT_FORMAT_OPTIONS,
42};
43
/// Physical expression that casts the output of a child expression to a
/// target [`DataType`] using a fixed set of [`CastOptions`].
#[derive(Debug, Clone, Eq)]
pub struct CastExpr {
    /// The child expression whose output is cast.
    pub expr: Arc<dyn PhysicalExpr>,
    /// The data type the child's output is converted to.
    cast_type: DataType,
    /// Options controlling the conversion; `safe` selects whether a failed
    /// conversion yields null (`true`) or an error (`false`).
    cast_options: CastOptions<'static>,
}
54
55impl PartialEq for CastExpr {
57 fn eq(&self, other: &Self) -> bool {
58 self.expr.eq(&other.expr)
59 && self.cast_type.eq(&other.cast_type)
60 && self.cast_options.eq(&other.cast_options)
61 }
62}
63
64impl Hash for CastExpr {
65 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
66 self.expr.hash(state);
67 self.cast_type.hash(state);
68 self.cast_options.hash(state);
69 }
70}
71
72impl CastExpr {
73 pub fn new(
75 expr: Arc<dyn PhysicalExpr>,
76 cast_type: DataType,
77 cast_options: Option<CastOptions<'static>>,
78 ) -> Self {
79 Self {
80 expr,
81 cast_type,
82 cast_options: cast_options.unwrap_or(DEFAULT_CAST_OPTIONS),
83 }
84 }
85
86 pub fn expr(&self) -> &Arc<dyn PhysicalExpr> {
88 &self.expr
89 }
90
91 pub fn cast_type(&self) -> &DataType {
93 &self.cast_type
94 }
95
96 pub fn cast_options(&self) -> &CastOptions<'static> {
98 &self.cast_options
99 }
100 pub fn is_bigger_cast(&self, src: DataType) -> bool {
101 if src == self.cast_type {
102 return true;
103 }
104 matches!(
105 (src, &self.cast_type),
106 (Int8, Int16 | Int32 | Int64)
107 | (Int16, Int32 | Int64)
108 | (Int32, Int64)
109 | (UInt8, UInt16 | UInt32 | UInt64)
110 | (UInt16, UInt32 | UInt64)
111 | (UInt32, UInt64)
112 | (
113 Int8 | Int16 | Int32 | UInt8 | UInt16 | UInt32,
114 Float32 | Float64
115 )
116 | (Int64 | UInt64, Float64)
117 | (Utf8, LargeUtf8)
118 )
119 }
120}
121
122impl fmt::Display for CastExpr {
123 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
124 write!(f, "CAST({} AS {:?})", self.expr, self.cast_type)
125 }
126}
127
128impl PhysicalExpr for CastExpr {
129 fn as_any(&self) -> &dyn Any {
131 self
132 }
133
134 fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
135 Ok(self.cast_type.clone())
136 }
137
138 fn nullable(&self, input_schema: &Schema) -> Result<bool> {
139 self.expr.nullable(input_schema)
140 }
141
142 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
143 let value = self.expr.evaluate(batch)?;
144 value.cast_to(&self.cast_type, Some(&self.cast_options))
145 }
146
147 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
148 vec![&self.expr]
149 }
150
151 fn with_new_children(
152 self: Arc<Self>,
153 children: Vec<Arc<dyn PhysicalExpr>>,
154 ) -> Result<Arc<dyn PhysicalExpr>> {
155 Ok(Arc::new(CastExpr::new(
156 Arc::clone(&children[0]),
157 self.cast_type.clone(),
158 Some(self.cast_options.clone()),
159 )))
160 }
161
162 fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
163 children[0].cast_to(&self.cast_type, &self.cast_options)
165 }
166
167 fn propagate_constraints(
168 &self,
169 interval: &Interval,
170 children: &[&Interval],
171 ) -> Result<Option<Vec<Interval>>> {
172 let child_interval = children[0];
173 let cast_type = child_interval.data_type();
175 Ok(Some(vec![
176 interval.cast_to(&cast_type, &DEFAULT_SAFE_CAST_OPTIONS)?
177 ]))
178 }
179
180 fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
183 let source_datatype = children[0].range.data_type();
184 let target_type = &self.cast_type;
185
186 let unbounded = Interval::make_unbounded(target_type)?;
187 if (source_datatype.is_numeric() || source_datatype == Boolean)
188 && target_type.is_numeric()
189 || source_datatype.is_temporal() && target_type.is_temporal()
190 || source_datatype.eq(target_type)
191 {
192 Ok(children[0].clone().with_range(unbounded))
193 } else {
194 Ok(ExprProperties::new_unknown().with_range(unbounded))
195 }
196 }
197
198 fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
199 write!(f, "CAST(")?;
200 self.expr.fmt_sql(f)?;
201 write!(f, " AS {:?}", self.cast_type)?;
202
203 write!(f, ")")
204 }
205}
206
207pub fn cast_with_options(
212 expr: Arc<dyn PhysicalExpr>,
213 input_schema: &Schema,
214 cast_type: DataType,
215 cast_options: Option<CastOptions<'static>>,
216) -> Result<Arc<dyn PhysicalExpr>> {
217 let expr_type = expr.data_type(input_schema)?;
218 if expr_type == cast_type {
219 Ok(Arc::clone(&expr))
220 } else if can_cast_types(&expr_type, &cast_type) {
221 Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options)))
222 } else {
223 not_impl_err!("Unsupported CAST from {expr_type:?} to {cast_type:?}")
224 }
225}
226
227pub fn cast(
232 expr: Arc<dyn PhysicalExpr>,
233 input_schema: &Schema,
234 cast_type: DataType,
235) -> Result<Arc<dyn PhysicalExpr>> {
236 cast_with_options(expr, input_schema, cast_type, None)
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    use crate::expressions::column::col;

    use arrow::{
        array::{
            Array, Decimal128Array, Float32Array, Float64Array, Int16Array, Int32Array,
            Int64Array, Int8Array, StringArray, Time64NanosecondArray,
            TimestampNanosecondArray, UInt32Array,
        },
        datatypes::*,
    };
    use datafusion_common::assert_contains;
    use datafusion_physical_expr_common::physical_expr::fmt_sql;

    // Wraps an already-built array `$DECIMAL_ARRAY` (of type `$A_TYPE`) in a
    // one-column batch and casts column `a` to `$TYPE` with `$CAST_OPTIONS`.
    // It then checks the Display form, the reported and produced data types,
    // and each expected value in `$VEC`. `None` entries must be null in the
    // result.
    macro_rules! generic_decimal_to_other_test_cast {
        ($DECIMAL_ARRAY:ident, $A_TYPE:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr,$CAST_OPTIONS:expr) => {{
            let schema = Schema::new(vec![Field::new("a", $A_TYPE, true)]);
            let batch = RecordBatch::try_new(
                Arc::new(schema.clone()),
                vec![Arc::new($DECIMAL_ARRAY)],
            )?;
            // Build the cast expression over column `a`.
            let expression =
                cast_with_options(col("a", &schema)?, &schema, $TYPE, $CAST_OPTIONS)?;

            // Display shows the column index and the target type.
            assert_eq!(
                format!("CAST(a@0 AS {:?})", $TYPE),
                format!("{}", expression)
            );

            // The expression reports the target type.
            assert_eq!(expression.data_type(&schema)?, $TYPE);

            // Evaluate and materialize the result as an array.
            let result = expression
                .evaluate(&batch)?
                .into_array(batch.num_rows())
                .expect("Failed to convert to array");

            // The produced array really has the target type.
            assert_eq!(*result.data_type(), $TYPE);

            // Downcast to the concrete array type to read values.
            let result = result
                .as_any()
                .downcast_ref::<$TYPEARRAY>()
                .expect("failed to downcast");

            // Compare each slot; `None` means the slot must be null.
            for (i, x) in $VEC.iter().enumerate() {
                match x {
                    Some(x) => assert_eq!(result.value(i), *x),
                    None => assert!(!result.is_valid(i)),
                }
            }
        }};
    }

    // Builds a `$A_ARRAY` from `$A_VEC` (type `$A_TYPE`), casts column `a` to
    // `$TYPE` with `$CAST_OPTIONS`, and runs the same checks as
    // `generic_decimal_to_other_test_cast`. It also checks that the result
    // length equals the input length.
    macro_rules! generic_test_cast {
        ($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr, $CAST_OPTIONS:expr) => {{
            let schema = Schema::new(vec![Field::new("a", $A_TYPE, true)]);
            let a_vec_len = $A_VEC.len();
            let a = $A_ARRAY::from($A_VEC);
            let batch =
                RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

            // Build the cast expression over column `a`.
            let expression =
                cast_with_options(col("a", &schema)?, &schema, $TYPE, $CAST_OPTIONS)?;

            // Display shows the column index and the target type.
            assert_eq!(
                format!("CAST(a@0 AS {:?})", $TYPE),
                format!("{}", expression)
            );

            // The expression reports the target type.
            assert_eq!(expression.data_type(&schema)?, $TYPE);

            // Evaluate and materialize the result as an array.
            let result = expression
                .evaluate(&batch)?
                .into_array(batch.num_rows())
                .expect("Failed to convert to array");

            // The produced array really has the target type.
            assert_eq!(*result.data_type(), $TYPE);

            // A cast never changes the number of rows.
            assert_eq!(result.len(), a_vec_len);

            // Downcast to the concrete array type to read values.
            let result = result
                .as_any()
                .downcast_ref::<$TYPEARRAY>()
                .expect("failed to downcast");

            // Compare each slot; `None` means the slot must be null.
            for (i, x) in $VEC.iter().enumerate() {
                match x {
                    Some(x) => assert_eq!(result.value(i), *x),
                    None => assert!(!result.is_valid(i)),
                }
            }
        }};
    }

    // Decimal -> decimal with a larger and a smaller scale.
    #[test]
    fn test_cast_decimal_to_decimal() -> Result<()> {
        let array = vec![
            Some(1234),
            Some(2222),
            Some(3),
            Some(4000),
            Some(5000),
            None,
        ];

        let decimal_array = array
            .clone()
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 3)?;

        // Scale 3 -> 6 multiplies the raw values by 1000.
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 3),
            Decimal128Array,
            Decimal128(20, 6),
            [
                Some(1_234_000),
                Some(2_222_000),
                Some(3_000),
                Some(4_000_000),
                Some(5_000_000),
                None
            ],
            None
        );

        let decimal_array = array
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 3)?;

        // Scale 3 -> 2 loses the last fractional digit for these inputs
        // (e.g. 1.234 -> 1.23, 0.003 -> 0.00).
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 3),
            Decimal128Array,
            Decimal128(10, 2),
            [Some(123), Some(222), Some(0), Some(400), Some(500), None],
            None
        );

        Ok(())
    }

    // A value that does not fit the target precision errors with the default
    // options and becomes null with the safe options.
    #[test]
    fn test_cast_decimal_to_decimal_overflow() -> Result<()> {
        let array = vec![Some(123456789)];

        let decimal_array = array
            .clone()
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 3)?;

        // Note: column `a` is declared non-nullable.
        let schema = Schema::new(vec![Field::new("a", Decimal128(10, 3), false)]);
        let batch = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![Arc::new(decimal_array)],
        )?;
        // Default (non-safe) options: evaluation must fail.
        let expression =
            cast_with_options(col("a", &schema)?, &schema, Decimal128(6, 2), None)?;
        let e = expression.evaluate(&batch).unwrap_err();
        assert_contains!(
            e.to_string(),
            "Arrow error: Invalid argument error: 12345679 is too large to store in a Decimal128 of precision 6. Max is 999999"
        );

        // Safe options: the overflowing value becomes null, even though the
        // input column is non-nullable.
        let expression_safe = cast_with_options(
            col("a", &schema)?,
            &schema,
            Decimal128(6, 2),
            Some(DEFAULT_SAFE_CAST_OPTIONS),
        )?;
        let result_safe = expression_safe
            .evaluate(&batch)?
            .into_array(batch.num_rows())
            .expect("failed to convert to array");

        assert!(result_safe.is_null(0));

        Ok(())
    }

    // Decimal -> each integer width, then decimal -> Float32/Float64.
    #[test]
    fn test_cast_decimal_to_numeric() -> Result<()> {
        let array = vec![Some(1), Some(2), Some(3), Some(4), Some(5), None];
        // Decimal128(10, 0) -> Int8
        let decimal_array = array
            .clone()
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 0)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 0),
            Int8Array,
            Int8,
            [
                Some(1_i8),
                Some(2_i8),
                Some(3_i8),
                Some(4_i8),
                Some(5_i8),
                None
            ],
            None
        );

        // Decimal128(10, 0) -> Int16
        let decimal_array = array
            .clone()
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 0)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 0),
            Int16Array,
            Int16,
            [
                Some(1_i16),
                Some(2_i16),
                Some(3_i16),
                Some(4_i16),
                Some(5_i16),
                None
            ],
            None
        );

        // Decimal128(10, 0) -> Int32
        let decimal_array = array
            .clone()
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 0)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 0),
            Int32Array,
            Int32,
            [
                Some(1_i32),
                Some(2_i32),
                Some(3_i32),
                Some(4_i32),
                Some(5_i32),
                None
            ],
            None
        );

        // Decimal128(10, 0) -> Int64 (consumes `array`)
        let decimal_array = array
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 0)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 0),
            Int64Array,
            Int64,
            [
                Some(1_i64),
                Some(2_i64),
                Some(3_i64),
                Some(4_i64),
                Some(5_i64),
                None
            ],
            None
        );

        // Fractional inputs for the floating-point targets.
        let array = vec![
            Some(1234),
            Some(2222),
            Some(3),
            Some(4000),
            Some(5000),
            None,
        ];
        // Decimal128(10, 3) -> Float32
        let decimal_array = array
            .clone()
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 3)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 3),
            Float32Array,
            Float32,
            [
                Some(1.234_f32),
                Some(2.222_f32),
                Some(0.003_f32),
                Some(4.0_f32),
                Some(5.0_f32),
                None
            ],
            None
        );

        // Same raw values read at scale 6: Decimal128(20, 6) -> Float64
        let decimal_array = array
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(20, 6)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(20, 6),
            Float64Array,
            Float64,
            [
                Some(0.001234_f64),
                Some(0.002222_f64),
                Some(0.000003_f64),
                Some(0.004_f64),
                Some(0.005_f64),
                None
            ],
            None
        );
        Ok(())
    }

    // Integer and float inputs -> decimal; expected values are raw (unscaled)
    // decimal integers.
    #[test]
    fn test_cast_numeric_to_decimal() -> Result<()> {
        // Int8 -> Decimal128(3, 0)
        generic_test_cast!(
            Int8Array,
            Int8,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            Decimal128(3, 0),
            [Some(1), Some(2), Some(3), Some(4), Some(5)],
            None
        );

        // Int16 -> Decimal128(5, 0)
        generic_test_cast!(
            Int16Array,
            Int16,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            Decimal128(5, 0),
            [Some(1), Some(2), Some(3), Some(4), Some(5)],
            None
        );

        // Int32 -> Decimal128(10, 0)
        generic_test_cast!(
            Int32Array,
            Int32,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            Decimal128(10, 0),
            [Some(1), Some(2), Some(3), Some(4), Some(5)],
            None
        );

        // Int64 -> Decimal128(20, 0)
        generic_test_cast!(
            Int64Array,
            Int64,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            Decimal128(20, 0),
            [Some(1), Some(2), Some(3), Some(4), Some(5)],
            None
        );

        // Int64 -> Decimal128(20, 2): raw values are scaled by 100.
        generic_test_cast!(
            Int64Array,
            Int64,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            Decimal128(20, 2),
            [Some(100), Some(200), Some(300), Some(400), Some(500)],
            None
        );

        // Float32 -> Decimal128(10, 2): 1.123_456_8 becomes 1.12.
        generic_test_cast!(
            Float32Array,
            Float32,
            vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50],
            Decimal128Array,
            Decimal128(10, 2),
            [Some(150), Some(250), Some(300), Some(112), Some(550)],
            None
        );

        // Float64 -> Decimal128(20, 4): 1.123_456_8 becomes 1.1235.
        generic_test_cast!(
            Float64Array,
            Float64,
            vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50],
            Decimal128Array,
            Decimal128(20, 4),
            [
                Some(15000),
                Some(25000),
                Some(30000),
                Some(11235),
                Some(55000)
            ],
            None
        );
        Ok(())
    }

    // Signed -> unsigned 32-bit for non-negative inputs.
    #[test]
    fn test_cast_i32_u32() -> Result<()> {
        generic_test_cast!(
            Int32Array,
            Int32,
            vec![1, 2, 3, 4, 5],
            UInt32Array,
            UInt32,
            [
                Some(1_u32),
                Some(2_u32),
                Some(3_u32),
                Some(4_u32),
                Some(5_u32)
            ],
            None
        );
        Ok(())
    }

    // Int32 -> Utf8 produces the decimal string form.
    #[test]
    fn test_cast_i32_utf8() -> Result<()> {
        generic_test_cast!(
            Int32Array,
            Int32,
            vec![1, 2, 3, 4, 5],
            StringArray,
            Utf8,
            [Some("1"), Some("2"), Some("3"), Some("4"), Some("5")],
            None
        );
        Ok(())
    }

    // Int64 -> Timestamp(Nanosecond). The expected raw values come from
    // reading each integer back out of a one-element Time64NanosecondArray,
    // i.e. the underlying i64 is expected to be unchanged.
    #[test]
    fn test_cast_i64_t64() -> Result<()> {
        let original = vec![1, 2, 3, 4, 5];
        let expected: Vec<Option<i64>> = original
            .iter()
            .map(|i| Some(Time64NanosecondArray::from(vec![*i]).value(0)))
            .collect();
        generic_test_cast!(
            Int64Array,
            Int64,
            original,
            TimestampNanosecondArray,
            Timestamp(TimeUnit::Nanosecond, None),
            expected,
            None
        );
        Ok(())
    }

    // An unsupported conversion is rejected when the expression is built:
    // `cast` itself returns an error.
    #[test]
    fn invalid_cast() {
        let schema = Schema::new(vec![Field::new("a", Int32, false)]);

        let result = cast(
            col("a", &schema).unwrap(),
            &schema,
            Interval(IntervalUnit::MonthDayNano),
        );
        result.expect_err("expected Invalid CAST");
    }

    // A supported conversion can still fail per value at evaluation time with
    // the default (non-safe) options.
    #[test]
    fn invalid_cast_with_options_error() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", Utf8, false)]);
        let a = StringArray::from(vec!["9.1"]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
        let expression = cast_with_options(col("a", &schema)?, &schema, Int32, None)?;
        let result = expression.evaluate(&batch);

        match result {
            Ok(_) => panic!("expected error"),
            Err(e) => {
                assert!(e
                    .to_string()
                    .contains("Cannot cast string '9.1' to value of Int32 type"))
            }
        }
        Ok(())
    }

    #[test]
    #[ignore] // NOTE(review): ignored with no recorded reason; Int64 100 presumably cannot fit Decimal128(38, 38). Confirm and link a tracking issue.
    fn test_cast_decimal() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", Int64, false)]);
        let a = Int64Array::from(vec![100]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
        let expression =
            cast_with_options(col("a", &schema)?, &schema, Decimal128(38, 38), None)?;
        expression.evaluate(&batch)?;
        Ok(())
    }

    // Display includes the column index (`a@0`); the SQL form uses the bare
    // column name.
    #[test]
    fn test_fmt_sql() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", Int32, true)]);

        // Int32 -> Int64
        let expr = cast(col("a", &schema)?, &schema, Int64)?;
        let display_string = expr.to_string();
        assert_eq!(display_string, "CAST(a@0 AS Int64)");
        let sql_string = fmt_sql(expr.as_ref()).to_string();
        assert_eq!(sql_string, "CAST(a AS Int64)");

        // Utf8 -> Int32
        let schema = Schema::new(vec![Field::new("b", Utf8, true)]);
        let expr = cast(col("b", &schema)?, &schema, Int32)?;
        let display_string = expr.to_string();
        assert_eq!(display_string, "CAST(b@0 AS Int32)");
        let sql_string = fmt_sql(expr.as_ref()).to_string();
        assert_eq!(sql_string, "CAST(b AS Int32)");

        Ok(())
    }
}