1use std::any::Any;
19use std::fmt;
20use std::hash::Hash;
21use std::sync::Arc;
22
23use crate::PhysicalExpr;
24use arrow::compute;
25use arrow::compute::{cast_with_options, CastOptions};
26use arrow::datatypes::{DataType, Schema};
27use arrow::record_batch::RecordBatch;
28use compute::can_cast_types;
29use datafusion_common::format::DEFAULT_FORMAT_OPTIONS;
30use datafusion_common::{not_impl_err, Result, ScalarValue};
31use datafusion_expr::ColumnarValue;
32
/// Physical expression implementing SQL `TRY_CAST`: casts the value produced by
/// `expr` to `cast_type`, yielding null for values that cannot be converted
/// instead of failing the whole evaluation.
#[derive(Debug, Eq)]
pub struct TryCastExpr {
    /// The expression whose result is cast.
    expr: Arc<dyn PhysicalExpr>,
    /// The data type the result is cast to.
    cast_type: DataType,
}
41
42impl PartialEq for TryCastExpr {
44 fn eq(&self, other: &Self) -> bool {
45 self.expr.eq(&other.expr) && self.cast_type == other.cast_type
46 }
47}
48
49impl Hash for TryCastExpr {
50 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
51 self.expr.hash(state);
52 self.cast_type.hash(state);
53 }
54}
55
56impl TryCastExpr {
57 pub fn new(expr: Arc<dyn PhysicalExpr>, cast_type: DataType) -> Self {
59 Self { expr, cast_type }
60 }
61
62 pub fn expr(&self) -> &Arc<dyn PhysicalExpr> {
64 &self.expr
65 }
66
67 pub fn cast_type(&self) -> &DataType {
69 &self.cast_type
70 }
71}
72
73impl fmt::Display for TryCastExpr {
74 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
75 write!(f, "TRY_CAST({} AS {:?})", self.expr, self.cast_type)
76 }
77}
78
79impl PhysicalExpr for TryCastExpr {
80 fn as_any(&self) -> &dyn Any {
82 self
83 }
84
85 fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
86 Ok(self.cast_type.clone())
87 }
88
89 fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
90 Ok(true)
91 }
92
93 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
94 let value = self.expr.evaluate(batch)?;
95 let options = CastOptions {
96 safe: true,
97 format_options: DEFAULT_FORMAT_OPTIONS,
98 };
99 match value {
100 ColumnarValue::Array(array) => {
101 let cast = cast_with_options(&array, &self.cast_type, &options)?;
102 Ok(ColumnarValue::Array(cast))
103 }
104 ColumnarValue::Scalar(scalar) => {
105 let array = scalar.to_array()?;
106 let cast_array = cast_with_options(&array, &self.cast_type, &options)?;
107 let cast_scalar = ScalarValue::try_from_array(&cast_array, 0)?;
108 Ok(ColumnarValue::Scalar(cast_scalar))
109 }
110 }
111 }
112
113 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
114 vec![&self.expr]
115 }
116
117 fn with_new_children(
118 self: Arc<Self>,
119 children: Vec<Arc<dyn PhysicalExpr>>,
120 ) -> Result<Arc<dyn PhysicalExpr>> {
121 Ok(Arc::new(TryCastExpr::new(
122 Arc::clone(&children[0]),
123 self.cast_type.clone(),
124 )))
125 }
126
127 fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128 write!(f, "TRY_CAST(")?;
129 self.expr.fmt_sql(f)?;
130 write!(f, " AS {:?})", self.cast_type)
131 }
132}
133
134pub fn try_cast(
139 expr: Arc<dyn PhysicalExpr>,
140 input_schema: &Schema,
141 cast_type: DataType,
142) -> Result<Arc<dyn PhysicalExpr>> {
143 let expr_type = expr.data_type(input_schema)?;
144 if expr_type == cast_type {
145 Ok(Arc::clone(&expr))
146 } else if can_cast_types(&expr_type, &cast_type) {
147 Ok(Arc::new(TryCastExpr::new(expr, cast_type)))
148 } else {
149 not_impl_err!("Unsupported TRY_CAST from {expr_type:?} to {cast_type:?}")
150 }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::col;
    use arrow::array::{Decimal128Array, Decimal128Builder, StringArray};
    use arrow::{
        array::{
            Array, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
            Int8Array, TimestampNanosecondArray, UInt32Array,
        },
        datatypes::*,
    };
    use datafusion_physical_expr_common::physical_expr::fmt_sql;

    // Builds a batch with one decimal column `a` (`$DECIMAL_ARRAY` of type
    // `$A_TYPE`), applies `TRY_CAST(a AS $TYPE)`, downcasts the result to
    // `$TYPEARRAY`, and compares it row-by-row with `$VEC`
    // (`Some(v)` = expected value, `None` = expected null).
    macro_rules! generic_decimal_to_other_test_cast {
        ($DECIMAL_ARRAY:ident, $A_TYPE:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr) => {{
            let schema = Schema::new(vec![Field::new("a", $A_TYPE, true)]);
            let batch = RecordBatch::try_new(
                Arc::new(schema.clone()),
                vec![Arc::new($DECIMAL_ARRAY)],
            )?;
            let expression = try_cast(col("a", &schema)?, &schema, $TYPE)?;

            // Display rendering of the expression
            assert_eq!(
                format!("TRY_CAST(a@0 AS {:?})", $TYPE),
                format!("{}", expression)
            );

            // the expression reports the target type
            assert_eq!(expression.data_type(&schema)?, $TYPE);

            let result = expression
                .evaluate(&batch)?
                .into_array(batch.num_rows())
                .expect("Failed to convert to array");

            assert_eq!(*result.data_type(), $TYPE);

            // every output row must be covered by an expectation, so that
            // unexpected extra rows cannot slip through unchecked
            let expected_rows = $VEC;
            assert_eq!(result.len(), expected_rows.len());

            let result = result
                .as_any()
                .downcast_ref::<$TYPEARRAY>()
                .expect("failed to downcast");

            for (i, x) in expected_rows.iter().enumerate() {
                match x {
                    Some(x) => assert_eq!(result.value(i), *x),
                    None => assert!(!result.is_valid(i)),
                }
            }
        }};
    }

    // Builds a batch with one column `a` of type `$A_TYPE` from `$A_VEC` (via
    // `$A_ARRAY::from`), applies `TRY_CAST(a AS $TYPE)`, downcasts the result to
    // `$TYPEARRAY`, and compares it row-by-row with `$VEC`
    // (`Some(v)` = expected value, `None` = expected null).
    macro_rules! generic_test_cast {
        ($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr) => {{
            let schema = Schema::new(vec![Field::new("a", $A_TYPE, true)]);
            let a_vec_len = $A_VEC.len();
            let a = $A_ARRAY::from($A_VEC);
            let batch =
                RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

            let expression = try_cast(col("a", &schema)?, &schema, $TYPE)?;

            // Display rendering of the expression
            assert_eq!(
                format!("TRY_CAST(a@0 AS {:?})", $TYPE),
                format!("{}", expression)
            );

            // the expression reports the target type
            assert_eq!(expression.data_type(&schema)?, $TYPE);

            let result = expression
                .evaluate(&batch)?
                .into_array(batch.num_rows())
                .expect("Failed to convert to array");

            assert_eq!(*result.data_type(), $TYPE);

            // casting preserves the row count
            assert_eq!(result.len(), a_vec_len);

            // every input row must have an expectation
            let expected_rows = $VEC;
            assert_eq!(expected_rows.len(), a_vec_len);

            let result = result
                .as_any()
                .downcast_ref::<$TYPEARRAY>()
                .expect("failed to downcast");

            for (i, x) in expected_rows.iter().enumerate() {
                match x {
                    Some(x) => assert_eq!(result.value(i), *x),
                    None => assert!(!result.is_valid(i)),
                }
            }
        }};
    }

    #[test]
    fn test_try_cast_decimal_to_decimal() -> Result<()> {
        // `create_decimal_array` appends a trailing null, hence the final `None`
        let array: Vec<i128> = vec![1234, 2222, 3, 4000, 5000];
        let decimal_array = create_decimal_array(&array, 10, 3);
        generic_decimal_to_other_test_cast!(
            decimal_array,
            DataType::Decimal128(10, 3),
            Decimal128Array,
            DataType::Decimal128(20, 6),
            [
                Some(1_234_000),
                Some(2_222_000),
                Some(3_000),
                Some(4_000_000),
                Some(5_000_000),
                None
            ]
        );

        // narrowing the scale truncates the dropped digits
        let decimal_array = create_decimal_array(&array, 10, 3);
        generic_decimal_to_other_test_cast!(
            decimal_array,
            DataType::Decimal128(10, 3),
            Decimal128Array,
            DataType::Decimal128(10, 2),
            [Some(123), Some(222), Some(0), Some(400), Some(500), None]
        );

        Ok(())
    }

    #[test]
    fn test_try_cast_decimal_to_numeric() -> Result<()> {
        let array: Vec<i128> = vec![1, 2, 3, 4, 5];

        // decimal -> i8
        let decimal_array = create_decimal_array(&array, 10, 0);
        generic_decimal_to_other_test_cast!(
            decimal_array,
            DataType::Decimal128(10, 0),
            Int8Array,
            DataType::Int8,
            [
                Some(1_i8),
                Some(2_i8),
                Some(3_i8),
                Some(4_i8),
                Some(5_i8),
                None
            ]
        );

        // decimal -> i16
        let decimal_array = create_decimal_array(&array, 10, 0);
        generic_decimal_to_other_test_cast!(
            decimal_array,
            DataType::Decimal128(10, 0),
            Int16Array,
            DataType::Int16,
            [
                Some(1_i16),
                Some(2_i16),
                Some(3_i16),
                Some(4_i16),
                Some(5_i16),
                None
            ]
        );

        // decimal -> i32
        let decimal_array = create_decimal_array(&array, 10, 0);
        generic_decimal_to_other_test_cast!(
            decimal_array,
            DataType::Decimal128(10, 0),
            Int32Array,
            DataType::Int32,
            [
                Some(1_i32),
                Some(2_i32),
                Some(3_i32),
                Some(4_i32),
                Some(5_i32),
                None
            ]
        );

        // decimal -> i64
        let decimal_array = create_decimal_array(&array, 10, 0);
        generic_decimal_to_other_test_cast!(
            decimal_array,
            DataType::Decimal128(10, 0),
            Int64Array,
            DataType::Int64,
            [
                Some(1_i64),
                Some(2_i64),
                Some(3_i64),
                Some(4_i64),
                Some(5_i64),
                None
            ]
        );

        // decimal -> f32
        let array: Vec<i128> = vec![1234, 2222, 3, 4000, 5000];
        let decimal_array = create_decimal_array(&array, 10, 3);
        generic_decimal_to_other_test_cast!(
            decimal_array,
            DataType::Decimal128(10, 3),
            Float32Array,
            DataType::Float32,
            [
                Some(1.234_f32),
                Some(2.222_f32),
                Some(0.003_f32),
                Some(4.0_f32),
                Some(5.0_f32),
                None
            ]
        );

        // decimal -> f64
        let decimal_array = create_decimal_array(&array, 20, 6);
        generic_decimal_to_other_test_cast!(
            decimal_array,
            DataType::Decimal128(20, 6),
            Float64Array,
            DataType::Float64,
            [
                Some(0.001234_f64),
                Some(0.002222_f64),
                Some(0.000003_f64),
                Some(0.004_f64),
                Some(0.005_f64),
                None
            ]
        );

        Ok(())
    }

    #[test]
    fn test_try_cast_numeric_to_decimal() -> Result<()> {
        // i8 -> decimal
        generic_test_cast!(
            Int8Array,
            DataType::Int8,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            DataType::Decimal128(3, 0),
            [Some(1), Some(2), Some(3), Some(4), Some(5)]
        );

        // i16 -> decimal
        generic_test_cast!(
            Int16Array,
            DataType::Int16,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            DataType::Decimal128(5, 0),
            [Some(1), Some(2), Some(3), Some(4), Some(5)]
        );

        // i32 -> decimal
        generic_test_cast!(
            Int32Array,
            DataType::Int32,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            DataType::Decimal128(10, 0),
            [Some(1), Some(2), Some(3), Some(4), Some(5)]
        );

        // i64 -> decimal
        generic_test_cast!(
            Int64Array,
            DataType::Int64,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            DataType::Decimal128(20, 0),
            [Some(1), Some(2), Some(3), Some(4), Some(5)]
        );

        // i64 -> decimal with a non-zero scale
        generic_test_cast!(
            Int64Array,
            DataType::Int64,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            DataType::Decimal128(20, 2),
            [Some(100), Some(200), Some(300), Some(400), Some(500)]
        );

        // f32 -> decimal
        generic_test_cast!(
            Float32Array,
            DataType::Float32,
            vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50],
            Decimal128Array,
            DataType::Decimal128(10, 2),
            [Some(150), Some(250), Some(300), Some(112), Some(550)]
        );

        // f64 -> decimal
        generic_test_cast!(
            Float64Array,
            DataType::Float64,
            vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50],
            Decimal128Array,
            DataType::Decimal128(20, 4),
            [
                Some(15000),
                Some(25000),
                Some(30000),
                Some(11235),
                Some(55000)
            ]
        );
        Ok(())
    }

    #[test]
    fn test_cast_i32_u32() -> Result<()> {
        generic_test_cast!(
            Int32Array,
            DataType::Int32,
            vec![1, 2, 3, 4, 5],
            UInt32Array,
            DataType::UInt32,
            [
                Some(1_u32),
                Some(2_u32),
                Some(3_u32),
                Some(4_u32),
                Some(5_u32)
            ]
        );
        Ok(())
    }

    #[test]
    fn test_cast_i32_utf8() -> Result<()> {
        generic_test_cast!(
            Int32Array,
            DataType::Int32,
            vec![1, 2, 3, 4, 5],
            StringArray,
            DataType::Utf8,
            [Some("1"), Some("2"), Some("3"), Some("4"), Some("5")]
        );
        Ok(())
    }

    #[test]
    fn test_try_cast_utf8_i32() -> Result<()> {
        // non-numeric strings become null instead of failing the cast
        generic_test_cast!(
            StringArray,
            DataType::Utf8,
            vec!["a", "2", "3", "b", "5"],
            Int32Array,
            DataType::Int32,
            [None, Some(2), Some(3), None, Some(5)]
        );
        Ok(())
    }

    #[test]
    fn test_cast_i64_t64() -> Result<()> {
        let original = vec![1, 2, 3, 4, 5];
        // Int64 -> Timestamp(Nanosecond) keeps each raw i64 value unchanged.
        let expected: Vec<Option<i64>> = original.iter().copied().map(Some).collect();
        generic_test_cast!(
            Int64Array,
            DataType::Int64,
            original,
            TimestampNanosecondArray,
            DataType::Timestamp(TimeUnit::Nanosecond, None),
            expected
        );
        Ok(())
    }

    #[test]
    fn invalid_cast() {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

        // Arrow cannot cast Int32 to an interval type, so construction must fail
        let result = try_cast(
            col("a", &schema).unwrap(),
            &schema,
            DataType::Interval(IntervalUnit::MonthDayNano),
        );
        result.expect_err("expected Invalid TRY_CAST");
    }

    // Builds a Decimal128Array holding `array` followed by ONE trailing null,
    // with the given precision and scale.
    fn create_decimal_array(array: &[i128], precision: u8, scale: i8) -> Decimal128Array {
        let mut decimal_builder = Decimal128Builder::with_capacity(array.len());
        for value in array {
            decimal_builder.append_value(*value);
        }
        decimal_builder.append_null();
        decimal_builder
            .finish()
            .with_precision_and_scale(precision, scale)
            .unwrap()
    }

    #[test]
    fn test_fmt_sql() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);

        // Int32 -> Int64
        let expr = try_cast(col("a", &schema)?, &schema, DataType::Int64)?;
        let display_string = expr.to_string();
        assert_eq!(display_string, "TRY_CAST(a@0 AS Int64)");
        let sql_string = fmt_sql(expr.as_ref()).to_string();
        assert_eq!(sql_string, "TRY_CAST(a AS Int64)");

        // Utf8 -> Int32
        let schema = Schema::new(vec![Field::new("b", DataType::Utf8, true)]);
        let expr = try_cast(col("b", &schema)?, &schema, DataType::Int32)?;
        let display_string = expr.to_string();
        assert_eq!(display_string, "TRY_CAST(b@0 AS Int32)");
        let sql_string = fmt_sql(expr.as_ref()).to_string();
        assert_eq!(sql_string, "TRY_CAST(b AS Int32)");

        Ok(())
    }
}