1use std::any::Any;
19use std::fmt;
20use std::hash::Hash;
21use std::sync::Arc;
22
23use crate::physical_expr::PhysicalExpr;
24
25use arrow::compute::{can_cast_types, CastOptions};
26use arrow::datatypes::{DataType, DataType::*, FieldRef, Schema};
27use arrow::record_batch::RecordBatch;
28use datafusion_common::format::DEFAULT_FORMAT_OPTIONS;
29use datafusion_common::{not_impl_err, Result};
30use datafusion_expr_common::columnar_value::ColumnarValue;
31use datafusion_expr_common::interval_arithmetic::Interval;
32use datafusion_expr_common::sort_properties::ExprProperties;
33
/// Cast options used by [`CastExpr::new`] when the caller passes `None`.
///
/// `safe` is `false`; the overflow test in this file shows that with these
/// options a value that does not fit the target type makes evaluation return
/// an error.
const DEFAULT_CAST_OPTIONS: CastOptions<'static> = CastOptions {
    safe: false,
    format_options: DEFAULT_FORMAT_OPTIONS,
};
38
/// Cast options with `safe` set to `true`.
///
/// Used when propagating interval constraints back to a child expression; the
/// overflow test in this file shows that with these options a value that does
/// not fit the target type becomes null instead of raising an error.
const DEFAULT_SAFE_CAST_OPTIONS: CastOptions<'static> = CastOptions {
    safe: true,
    format_options: DEFAULT_FORMAT_OPTIONS,
};
43
/// Physical expression that casts the result of a child expression to
/// `cast_type`, with `cast_options` controlling how the conversion is done.
// `PartialEq` and `Hash` are implemented by hand in this file rather than derived.
#[derive(Debug, Clone, Eq)]
pub struct CastExpr {
    /// The expression whose result is cast.
    pub expr: Arc<dyn PhysicalExpr>,
    /// The data type to cast to.
    cast_type: DataType,
    /// Options handed to the cast when the expression is evaluated.
    cast_options: CastOptions<'static>,
}
54
55impl PartialEq for CastExpr {
57 fn eq(&self, other: &Self) -> bool {
58 self.expr.eq(&other.expr)
59 && self.cast_type.eq(&other.cast_type)
60 && self.cast_options.eq(&other.cast_options)
61 }
62}
63
64impl Hash for CastExpr {
65 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
66 self.expr.hash(state);
67 self.cast_type.hash(state);
68 self.cast_options.hash(state);
69 }
70}
71
72impl CastExpr {
73 pub fn new(
75 expr: Arc<dyn PhysicalExpr>,
76 cast_type: DataType,
77 cast_options: Option<CastOptions<'static>>,
78 ) -> Self {
79 Self {
80 expr,
81 cast_type,
82 cast_options: cast_options.unwrap_or(DEFAULT_CAST_OPTIONS),
83 }
84 }
85
86 pub fn expr(&self) -> &Arc<dyn PhysicalExpr> {
88 &self.expr
89 }
90
91 pub fn cast_type(&self) -> &DataType {
93 &self.cast_type
94 }
95
96 pub fn cast_options(&self) -> &CastOptions<'static> {
98 &self.cast_options
99 }
100
101 pub fn is_bigger_cast(&self, src: &DataType) -> bool {
103 if self.cast_type.eq(src) {
104 return true;
105 }
106 matches!(
107 (src, &self.cast_type),
108 (Int8, Int16 | Int32 | Int64)
109 | (Int16, Int32 | Int64)
110 | (Int32, Int64)
111 | (UInt8, UInt16 | UInt32 | UInt64)
112 | (UInt16, UInt32 | UInt64)
113 | (UInt32, UInt64)
114 | (
115 Int8 | Int16 | Int32 | UInt8 | UInt16 | UInt32,
116 Float32 | Float64
117 )
118 | (Int64 | UInt64, Float64)
119 | (Utf8, LargeUtf8)
120 )
121 }
122}
123
124impl fmt::Display for CastExpr {
125 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
126 write!(f, "CAST({} AS {:?})", self.expr, self.cast_type)
127 }
128}
129
130impl PhysicalExpr for CastExpr {
131 fn as_any(&self) -> &dyn Any {
133 self
134 }
135
136 fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
137 Ok(self.cast_type.clone())
138 }
139
140 fn nullable(&self, input_schema: &Schema) -> Result<bool> {
141 self.expr.nullable(input_schema)
142 }
143
144 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
145 let value = self.expr.evaluate(batch)?;
146 value.cast_to(&self.cast_type, Some(&self.cast_options))
147 }
148
149 fn return_field(&self, input_schema: &Schema) -> Result<FieldRef> {
150 Ok(self
151 .expr
152 .return_field(input_schema)?
153 .as_ref()
154 .clone()
155 .with_data_type(self.cast_type.clone())
156 .into())
157 }
158
159 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
160 vec![&self.expr]
161 }
162
163 fn with_new_children(
164 self: Arc<Self>,
165 children: Vec<Arc<dyn PhysicalExpr>>,
166 ) -> Result<Arc<dyn PhysicalExpr>> {
167 Ok(Arc::new(CastExpr::new(
168 Arc::clone(&children[0]),
169 self.cast_type.clone(),
170 Some(self.cast_options.clone()),
171 )))
172 }
173
174 fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
175 children[0].cast_to(&self.cast_type, &self.cast_options)
177 }
178
179 fn propagate_constraints(
180 &self,
181 interval: &Interval,
182 children: &[&Interval],
183 ) -> Result<Option<Vec<Interval>>> {
184 let child_interval = children[0];
185 let cast_type = child_interval.data_type();
187 Ok(Some(vec![
188 interval.cast_to(&cast_type, &DEFAULT_SAFE_CAST_OPTIONS)?
189 ]))
190 }
191
192 fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
195 let source_datatype = children[0].range.data_type();
196 let target_type = &self.cast_type;
197
198 let unbounded = Interval::make_unbounded(target_type)?;
199 if (source_datatype.is_numeric() || source_datatype == Boolean)
200 && target_type.is_numeric()
201 || source_datatype.is_temporal() && target_type.is_temporal()
202 || source_datatype.eq(target_type)
203 {
204 Ok(children[0].clone().with_range(unbounded))
205 } else {
206 Ok(ExprProperties::new_unknown().with_range(unbounded))
207 }
208 }
209
210 fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
211 write!(f, "CAST(")?;
212 self.expr.fmt_sql(f)?;
213 write!(f, " AS {:?}", self.cast_type)?;
214
215 write!(f, ")")
216 }
217}
218
219pub fn cast_with_options(
224 expr: Arc<dyn PhysicalExpr>,
225 input_schema: &Schema,
226 cast_type: DataType,
227 cast_options: Option<CastOptions<'static>>,
228) -> Result<Arc<dyn PhysicalExpr>> {
229 let expr_type = expr.data_type(input_schema)?;
230 if expr_type == cast_type {
231 Ok(Arc::clone(&expr))
232 } else if can_cast_types(&expr_type, &cast_type) {
233 Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options)))
234 } else {
235 not_impl_err!("Unsupported CAST from {expr_type:?} to {cast_type:?}")
236 }
237}
238
/// Builds a physical expression that yields `expr` as `cast_type`.
///
/// Delegates to [`cast_with_options`] with no explicit cast options, so the
/// defaults chosen by [`CastExpr::new`] apply.
pub fn cast(
    expr: Arc<dyn PhysicalExpr>,
    input_schema: &Schema,
    cast_type: DataType,
) -> Result<Arc<dyn PhysicalExpr>> {
    cast_with_options(expr, input_schema, cast_type, None)
}
250
251#[cfg(test)]
252mod tests {
253 use super::*;
254
255 use crate::expressions::column::col;
256
257 use arrow::{
258 array::{
259 Array, Decimal128Array, Float32Array, Float64Array, Int16Array, Int32Array,
260 Int64Array, Int8Array, StringArray, Time64NanosecondArray,
261 TimestampNanosecondArray, UInt32Array,
262 },
263 datatypes::*,
264 };
265 use datafusion_common::assert_contains;
266 use datafusion_physical_expr_common::physical_expr::fmt_sql;
267
    // Runs an end-to-end cast test starting from an already-built array:
    // 1. build a record batch with a single nullable column "a" of type
    //    `$A_TYPE` holding `$DECIMAL_ARRAY`
    // 2. build the physical expression CAST(a AS `$TYPE`) with `$CAST_OPTIONS`
    // 3. check the expression's display string and reported data type
    // 4. evaluate it and check the result array's data type
    // 5. downcast the result to `$TYPEARRAY` and compare it with `$VEC`
    //
    // The body uses `?`, so it must be expanded inside a function returning
    // `Result`.
    macro_rules! generic_decimal_to_other_test_cast {
        ($DECIMAL_ARRAY:ident, $A_TYPE:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr,$CAST_OPTIONS:expr) => {{
            let schema = Schema::new(vec![Field::new("a", $A_TYPE, true)]);
            let batch = RecordBatch::try_new(
                Arc::new(schema.clone()),
                vec![Arc::new($DECIMAL_ARRAY)],
            )?;
            // verify that we can construct the expression
            let expression =
                cast_with_options(col("a", &schema)?, &schema, $TYPE, $CAST_OPTIONS)?;

            // verify that its display is correct
            assert_eq!(
                format!("CAST(a@0 AS {:?})", $TYPE),
                format!("{}", expression)
            );

            // verify that the expression's type is correct
            assert_eq!(expression.data_type(&schema)?, $TYPE);

            // compute
            let result = expression
                .evaluate(&batch)?
                .into_array(batch.num_rows())
                .expect("Failed to convert to array");

            // verify that the array's data_type is correct
            assert_eq!(*result.data_type(), $TYPE);

            // verify that the data itself is downcastable
            let result = result
                .as_any()
                .downcast_ref::<$TYPEARRAY>()
                .expect("failed to downcast");

            // verify that the result itself is correct; `None` entries in
            // `$VEC` mean the corresponding slot must be null
            for (i, x) in $VEC.iter().enumerate() {
                match x {
                    Some(x) => assert_eq!(result.value(i), *x),
                    None => assert!(!result.is_valid(i)),
                }
            }
        }};
    }
318
    // Runs an end-to-end cast test starting from a vector of native values:
    // 1. build a record batch with a single nullable column "a" of type
    //    `$A_TYPE`, created via `$A_ARRAY::from($A_VEC)`
    // 2. build the physical expression CAST(a AS `$TYPE`) with `$CAST_OPTIONS`
    // 3. check the expression's display string and reported data type
    // 4. evaluate it and check the result array's data type and length
    // 5. downcast the result to `$TYPEARRAY` and compare it with `$VEC`
    //
    // `$A_VEC` is expanded twice (once for `.len()`, once for the array), and
    // the body uses `?`, so it must be expanded inside a function returning
    // `Result`.
    macro_rules! generic_test_cast {
        ($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr, $CAST_OPTIONS:expr) => {{
            let schema = Schema::new(vec![Field::new("a", $A_TYPE, true)]);
            let a_vec_len = $A_VEC.len();
            let a = $A_ARRAY::from($A_VEC);
            let batch =
                RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

            // verify that we can construct the expression
            let expression =
                cast_with_options(col("a", &schema)?, &schema, $TYPE, $CAST_OPTIONS)?;

            // verify that its display is correct
            assert_eq!(
                format!("CAST(a@0 AS {:?})", $TYPE),
                format!("{}", expression)
            );

            // verify that the expression's type is correct
            assert_eq!(expression.data_type(&schema)?, $TYPE);

            // compute
            let result = expression
                .evaluate(&batch)?
                .into_array(batch.num_rows())
                .expect("Failed to convert to array");

            // verify that the array's data_type is correct
            assert_eq!(*result.data_type(), $TYPE);

            // verify that the len is correct
            assert_eq!(result.len(), a_vec_len);

            // verify that the data itself is downcastable
            let result = result
                .as_any()
                .downcast_ref::<$TYPEARRAY>()
                .expect("failed to downcast");

            // verify that the result itself is correct; `None` entries in
            // `$VEC` mean the corresponding slot must be null
            for (i, x) in $VEC.iter().enumerate() {
                match x {
                    Some(x) => assert_eq!(result.value(i), *x),
                    None => assert!(!result.is_valid(i)),
                }
            }
        }};
    }
373
374 #[test]
375 fn test_cast_decimal_to_decimal() -> Result<()> {
376 let array = vec![
377 Some(1234),
378 Some(2222),
379 Some(3),
380 Some(4000),
381 Some(5000),
382 None,
383 ];
384
385 let decimal_array = array
386 .clone()
387 .into_iter()
388 .collect::<Decimal128Array>()
389 .with_precision_and_scale(10, 3)?;
390
391 generic_decimal_to_other_test_cast!(
392 decimal_array,
393 Decimal128(10, 3),
394 Decimal128Array,
395 Decimal128(20, 6),
396 [
397 Some(1_234_000),
398 Some(2_222_000),
399 Some(3_000),
400 Some(4_000_000),
401 Some(5_000_000),
402 None
403 ],
404 None
405 );
406
407 let decimal_array = array
408 .into_iter()
409 .collect::<Decimal128Array>()
410 .with_precision_and_scale(10, 3)?;
411
412 generic_decimal_to_other_test_cast!(
413 decimal_array,
414 Decimal128(10, 3),
415 Decimal128Array,
416 Decimal128(10, 2),
417 [Some(123), Some(222), Some(0), Some(400), Some(500), None],
418 None
419 );
420
421 Ok(())
422 }
423
424 #[test]
425 fn test_cast_decimal_to_decimal_overflow() -> Result<()> {
426 let array = vec![Some(123456789)];
427
428 let decimal_array = array
429 .clone()
430 .into_iter()
431 .collect::<Decimal128Array>()
432 .with_precision_and_scale(10, 3)?;
433
434 let schema = Schema::new(vec![Field::new("a", Decimal128(10, 3), false)]);
435 let batch = RecordBatch::try_new(
436 Arc::new(schema.clone()),
437 vec![Arc::new(decimal_array)],
438 )?;
439 let expression =
440 cast_with_options(col("a", &schema)?, &schema, Decimal128(6, 2), None)?;
441 let e = expression.evaluate(&batch).unwrap_err(); assert_contains!(
443 e.to_string(),
444 "Arrow error: Invalid argument error: 12345679 is too large to store in a Decimal128 of precision 6. Max is 999999"
445 );
446
447 let expression_safe = cast_with_options(
448 col("a", &schema)?,
449 &schema,
450 Decimal128(6, 2),
451 Some(DEFAULT_SAFE_CAST_OPTIONS),
452 )?;
453 let result_safe = expression_safe
454 .evaluate(&batch)?
455 .into_array(batch.num_rows())
456 .expect("failed to convert to array");
457
458 assert!(result_safe.is_null(0));
459
460 Ok(())
461 }
462
    /// Casts decimal columns to each signed integer width and to both float
    /// widths, checking the resulting values (including a trailing null).
    #[test]
    fn test_cast_decimal_to_numeric() -> Result<()> {
        let array = vec![Some(1), Some(2), Some(3), Some(4), Some(5), None];
        // decimal to i8
        let decimal_array = array
            .clone()
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 0)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 0),
            Int8Array,
            Int8,
            [
                Some(1_i8),
                Some(2_i8),
                Some(3_i8),
                Some(4_i8),
                Some(5_i8),
                None
            ],
            None
        );

        // decimal to i16
        let decimal_array = array
            .clone()
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 0)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 0),
            Int16Array,
            Int16,
            [
                Some(1_i16),
                Some(2_i16),
                Some(3_i16),
                Some(4_i16),
                Some(5_i16),
                None
            ],
            None
        );

        // decimal to i32
        let decimal_array = array
            .clone()
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 0)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 0),
            Int32Array,
            Int32,
            [
                Some(1_i32),
                Some(2_i32),
                Some(3_i32),
                Some(4_i32),
                Some(5_i32),
                None
            ],
            None
        );

        // decimal to i64
        let decimal_array = array
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 0)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 0),
            Int64Array,
            Int64,
            [
                Some(1_i64),
                Some(2_i64),
                Some(3_i64),
                Some(4_i64),
                Some(5_i64),
                None
            ],
            None
        );

        // decimal to float32: the same raw values read with scale 3
        let array = vec![
            Some(1234),
            Some(2222),
            Some(3),
            Some(4000),
            Some(5000),
            None,
        ];
        let decimal_array = array
            .clone()
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 3)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 3),
            Float32Array,
            Float32,
            [
                Some(1.234_f32),
                Some(2.222_f32),
                Some(0.003_f32),
                Some(4.0_f32),
                Some(5.0_f32),
                None
            ],
            None
        );

        // decimal to float64: the same raw values read with scale 6
        let decimal_array = array
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(20, 6)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(20, 6),
            Float64Array,
            Float64,
            [
                Some(0.001234_f64),
                Some(0.002222_f64),
                Some(0.000003_f64),
                Some(0.004_f64),
                Some(0.005_f64),
                None
            ],
            None
        );
        Ok(())
    }
605
    /// Casts each signed integer width and both float widths to decimal
    /// types, checking the stored (scaled) decimal values.
    #[test]
    fn test_cast_numeric_to_decimal() -> Result<()> {
        // int8
        generic_test_cast!(
            Int8Array,
            Int8,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            Decimal128(3, 0),
            [Some(1), Some(2), Some(3), Some(4), Some(5)],
            None
        );

        // int16
        generic_test_cast!(
            Int16Array,
            Int16,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            Decimal128(5, 0),
            [Some(1), Some(2), Some(3), Some(4), Some(5)],
            None
        );

        // int32
        generic_test_cast!(
            Int32Array,
            Int32,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            Decimal128(10, 0),
            [Some(1), Some(2), Some(3), Some(4), Some(5)],
            None
        );

        // int64
        generic_test_cast!(
            Int64Array,
            Int64,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            Decimal128(20, 0),
            [Some(1), Some(2), Some(3), Some(4), Some(5)],
            None
        );

        // int64 to a decimal with a non-zero scale
        generic_test_cast!(
            Int64Array,
            Int64,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            Decimal128(20, 2),
            [Some(100), Some(200), Some(300), Some(400), Some(500)],
            None
        );

        // float32
        generic_test_cast!(
            Float32Array,
            Float32,
            vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50],
            Decimal128Array,
            Decimal128(10, 2),
            [Some(150), Some(250), Some(300), Some(112), Some(550)],
            None
        );

        // float64
        generic_test_cast!(
            Float64Array,
            Float64,
            vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50],
            Decimal128Array,
            Decimal128(20, 4),
            [
                Some(15000),
                Some(25000),
                Some(30000),
                Some(11235),
                Some(55000)
            ],
            None
        );
        Ok(())
    }
692
693 #[test]
694 fn test_cast_i32_u32() -> Result<()> {
695 generic_test_cast!(
696 Int32Array,
697 Int32,
698 vec![1, 2, 3, 4, 5],
699 UInt32Array,
700 UInt32,
701 [
702 Some(1_u32),
703 Some(2_u32),
704 Some(3_u32),
705 Some(4_u32),
706 Some(5_u32)
707 ],
708 None
709 );
710 Ok(())
711 }
712
713 #[test]
714 fn test_cast_i32_utf8() -> Result<()> {
715 generic_test_cast!(
716 Int32Array,
717 Int32,
718 vec![1, 2, 3, 4, 5],
719 StringArray,
720 Utf8,
721 [Some("1"), Some("2"), Some("3"), Some("4"), Some("5")],
722 None
723 );
724 Ok(())
725 }
726
727 #[test]
728 fn test_cast_i64_t64() -> Result<()> {
729 let original = vec![1, 2, 3, 4, 5];
730 let expected: Vec<Option<i64>> = original
731 .iter()
732 .map(|i| Some(Time64NanosecondArray::from(vec![*i]).value(0)))
733 .collect();
734 generic_test_cast!(
735 Int64Array,
736 Int64,
737 original,
738 TimestampNanosecondArray,
739 Timestamp(TimeUnit::Nanosecond, None),
740 expected,
741 None
742 );
743 Ok(())
744 }
745
746 #[test]
747 fn invalid_cast() {
748 let schema = Schema::new(vec![Field::new("a", Int32, false)]);
750
751 let result = cast(
752 col("a", &schema).unwrap(),
753 &schema,
754 Interval(IntervalUnit::MonthDayNano),
755 );
756 result.expect_err("expected Invalid CAST");
757 }
758
759 #[test]
760 fn invalid_cast_with_options_error() -> Result<()> {
761 let schema = Schema::new(vec![Field::new("a", Utf8, false)]);
763 let a = StringArray::from(vec!["9.1"]);
764 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
765 let expression = cast_with_options(col("a", &schema)?, &schema, Int32, None)?;
766 let result = expression.evaluate(&batch);
767
768 match result {
769 Ok(_) => panic!("expected error"),
770 Err(e) => {
771 assert!(e
772 .to_string()
773 .contains("Cannot cast string '9.1' to value of Int32 type"))
774 }
775 }
776 Ok(())
777 }
778
779 #[test]
780 #[ignore] fn test_cast_decimal() -> Result<()> {
782 let schema = Schema::new(vec![Field::new("a", Int64, false)]);
783 let a = Int64Array::from(vec![100]);
784 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
785 let expression =
786 cast_with_options(col("a", &schema)?, &schema, Decimal128(38, 38), None)?;
787 expression.evaluate(&batch)?;
788 Ok(())
789 }
790
791 #[test]
792 fn test_fmt_sql() -> Result<()> {
793 let schema = Schema::new(vec![Field::new("a", Int32, true)]);
794
795 let expr = cast(col("a", &schema)?, &schema, Int64)?;
797 let display_string = expr.to_string();
798 assert_eq!(display_string, "CAST(a@0 AS Int64)");
799 let sql_string = fmt_sql(expr.as_ref()).to_string();
800 assert_eq!(sql_string, "CAST(a AS Int64)");
801
802 let schema = Schema::new(vec![Field::new("b", Utf8, true)]);
804 let expr = cast(col("b", &schema)?, &schema, Int32)?;
805 let display_string = expr.to_string();
806 assert_eq!(display_string, "CAST(b@0 AS Int32)");
807 let sql_string = fmt_sql(expr.as_ref()).to_string();
808 assert_eq!(sql_string, "CAST(b AS Int32)");
809
810 Ok(())
811 }
812}