1use std::cmp::Ordering;
25
26use arrow::datatypes::{
27 DataType, MAX_DECIMAL32_FOR_EACH_PRECISION, MAX_DECIMAL64_FOR_EACH_PRECISION,
28 MAX_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL32_FOR_EACH_PRECISION,
29 MIN_DECIMAL64_FOR_EACH_PRECISION, MIN_DECIMAL128_FOR_EACH_PRECISION, TimeUnit,
30};
31use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS};
32use datafusion_common::ScalarValue;
33
34pub fn try_cast_literal_to_type(
36 lit_value: &ScalarValue,
37 target_type: &DataType,
38) -> Option<ScalarValue> {
39 let lit_data_type = lit_value.data_type();
40 if !is_supported_type(&lit_data_type) || !is_supported_type(target_type) {
41 return None;
42 }
43 if lit_value.is_null() {
44 return ScalarValue::try_from(target_type).ok();
46 }
47 try_cast_numeric_literal(lit_value, target_type)
48 .or_else(|| try_cast_string_literal(lit_value, target_type))
49 .or_else(|| try_cast_dictionary(lit_value, target_type))
50 .or_else(|| try_cast_binary(lit_value, target_type))
51}
52
53pub fn is_supported_type(data_type: &DataType) -> bool {
55 is_supported_numeric_type(data_type)
56 || is_supported_string_type(data_type)
57 || is_supported_dictionary_type(data_type)
58 || is_supported_binary_type(data_type)
59}
60
61fn is_date_type(data_type: &DataType) -> bool {
62 matches!(data_type, DataType::Date32 | DataType::Date64)
63}
64
65fn is_lossy_temporal_cast(from_type: &DataType, to_type: &DataType) -> bool {
78 (is_date_type(from_type) && to_type.is_temporal())
79 || (is_date_type(to_type) && from_type.is_temporal())
80}
81
82fn is_supported_numeric_type(data_type: &DataType) -> bool {
84 matches!(
85 data_type,
86 DataType::UInt8
87 | DataType::UInt16
88 | DataType::UInt32
89 | DataType::UInt64
90 | DataType::Int8
91 | DataType::Int16
92 | DataType::Int32
93 | DataType::Int64
94 | DataType::Date32
95 | DataType::Date64
96 | DataType::Decimal32(_, _)
97 | DataType::Decimal64(_, _)
98 | DataType::Decimal128(_, _)
99 | DataType::Timestamp(_, _)
100 )
101}
102
103fn is_supported_string_type(data_type: &DataType) -> bool {
105 matches!(
106 data_type,
107 DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View
108 )
109}
110
111fn is_supported_dictionary_type(data_type: &DataType) -> bool {
113 matches!(data_type,
114 DataType::Dictionary(_, inner) if is_supported_type(inner))
115}
116
117fn is_supported_binary_type(data_type: &DataType) -> bool {
118 matches!(data_type, DataType::Binary | DataType::FixedSizeBinary(_))
119}
120
121fn try_cast_numeric_literal(
123 lit_value: &ScalarValue,
124 target_type: &DataType,
125) -> Option<ScalarValue> {
126 let lit_data_type = lit_value.data_type();
127 if !is_supported_numeric_type(&lit_data_type)
128 || !is_supported_numeric_type(target_type)
129 {
130 return None;
131 }
132
133 if is_lossy_temporal_cast(&lit_data_type, target_type) {
134 return None;
135 }
136
137 let mul = match target_type {
138 DataType::UInt8
139 | DataType::UInt16
140 | DataType::UInt32
141 | DataType::UInt64
142 | DataType::Int8
143 | DataType::Int16
144 | DataType::Int32
145 | DataType::Int64
146 | DataType::Date32
147 | DataType::Date64 => 1_i128,
148 DataType::Timestamp(_, _) => 1_i128,
149 DataType::Decimal32(_, scale) => 10_i128.pow(*scale as u32),
150 DataType::Decimal64(_, scale) => 10_i128.pow(*scale as u32),
151 DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32),
152 _ => return None,
153 };
154 let (target_min, target_max) = match target_type {
155 DataType::UInt8 => (u8::MIN as i128, u8::MAX as i128),
156 DataType::UInt16 => (u16::MIN as i128, u16::MAX as i128),
157 DataType::UInt32 => (u32::MIN as i128, u32::MAX as i128),
158 DataType::UInt64 => (u64::MIN as i128, u64::MAX as i128),
159 DataType::Int8 => (i8::MIN as i128, i8::MAX as i128),
160 DataType::Int16 => (i16::MIN as i128, i16::MAX as i128),
161 DataType::Int32 | DataType::Date32 => (i32::MIN as i128, i32::MAX as i128),
162 DataType::Int64 | DataType::Date64 => (i64::MIN as i128, i64::MAX as i128),
163 DataType::Timestamp(_, _) => (i64::MIN as i128, i64::MAX as i128),
164 DataType::Decimal32(precision, _) => (
165 MIN_DECIMAL32_FOR_EACH_PRECISION[*precision as usize] as i128,
169 MAX_DECIMAL32_FOR_EACH_PRECISION[*precision as usize] as i128,
170 ),
171 DataType::Decimal64(precision, _) => (
172 MIN_DECIMAL64_FOR_EACH_PRECISION[*precision as usize] as i128,
176 MAX_DECIMAL64_FOR_EACH_PRECISION[*precision as usize] as i128,
177 ),
178 DataType::Decimal128(precision, _) => (
179 MIN_DECIMAL128_FOR_EACH_PRECISION[*precision as usize],
183 MAX_DECIMAL128_FOR_EACH_PRECISION[*precision as usize],
184 ),
185 _ => return None,
186 };
187 let lit_value_target_type = match lit_value {
188 ScalarValue::Int8(Some(v)) => (*v as i128).checked_mul(mul),
189 ScalarValue::Int16(Some(v)) => (*v as i128).checked_mul(mul),
190 ScalarValue::Int32(Some(v)) => (*v as i128).checked_mul(mul),
191 ScalarValue::Int64(Some(v)) => (*v as i128).checked_mul(mul),
192 ScalarValue::UInt8(Some(v)) => (*v as i128).checked_mul(mul),
193 ScalarValue::UInt16(Some(v)) => (*v as i128).checked_mul(mul),
194 ScalarValue::UInt32(Some(v)) => (*v as i128).checked_mul(mul),
195 ScalarValue::UInt64(Some(v)) => (*v as i128).checked_mul(mul),
196 ScalarValue::Date32(Some(v)) => (*v as i128).checked_mul(mul),
197 ScalarValue::Date64(Some(v)) => (*v as i128).checked_mul(mul),
198 ScalarValue::TimestampSecond(Some(v), _) => (*v as i128).checked_mul(mul),
199 ScalarValue::TimestampMillisecond(Some(v), _) => (*v as i128).checked_mul(mul),
200 ScalarValue::TimestampMicrosecond(Some(v), _) => (*v as i128).checked_mul(mul),
201 ScalarValue::TimestampNanosecond(Some(v), _) => (*v as i128).checked_mul(mul),
202 ScalarValue::Decimal32(Some(v), _, scale) => {
203 let v = *v as i128;
204 let lit_scale_mul = 10_i128.pow(*scale as u32);
205 if mul >= lit_scale_mul {
206 v.checked_mul(mul / lit_scale_mul)
211 } else if v % (lit_scale_mul / mul) == 0 {
212 Some(v / (lit_scale_mul / mul))
217 } else {
218 None
220 }
221 }
222 ScalarValue::Decimal64(Some(v), _, scale) => {
223 let v = *v as i128;
224 let lit_scale_mul = 10_i128.pow(*scale as u32);
225 if mul >= lit_scale_mul {
226 v.checked_mul(mul / lit_scale_mul)
231 } else if v % (lit_scale_mul / mul) == 0 {
232 Some(v / (lit_scale_mul / mul))
237 } else {
238 None
240 }
241 }
242 ScalarValue::Decimal128(Some(v), _, scale) => {
243 let lit_scale_mul = 10_i128.pow(*scale as u32);
244 if mul >= lit_scale_mul {
245 (*v).checked_mul(mul / lit_scale_mul)
250 } else if (*v) % (lit_scale_mul / mul) == 0 {
251 Some(*v / (lit_scale_mul / mul))
256 } else {
257 None
259 }
260 }
261 _ => None,
262 };
263
264 match lit_value_target_type {
265 None => None,
266 Some(value) => {
267 if value >= target_min && value <= target_max {
268 let result_scalar = match target_type {
271 DataType::Int8 => ScalarValue::Int8(Some(value as i8)),
272 DataType::Int16 => ScalarValue::Int16(Some(value as i16)),
273 DataType::Int32 => ScalarValue::Int32(Some(value as i32)),
274 DataType::Int64 => ScalarValue::Int64(Some(value as i64)),
275 DataType::Date32 => ScalarValue::Date32(Some(value as i32)),
276 DataType::Date64 => ScalarValue::Date64(Some(value as i64)),
277 DataType::UInt8 => ScalarValue::UInt8(Some(value as u8)),
278 DataType::UInt16 => ScalarValue::UInt16(Some(value as u16)),
279 DataType::UInt32 => ScalarValue::UInt32(Some(value as u32)),
280 DataType::UInt64 => ScalarValue::UInt64(Some(value as u64)),
281 DataType::Timestamp(TimeUnit::Second, tz) => {
282 let value = cast_between_timestamp(
283 &lit_data_type,
284 &DataType::Timestamp(TimeUnit::Second, tz.clone()),
285 value,
286 );
287 ScalarValue::TimestampSecond(value, tz.clone())
288 }
289 DataType::Timestamp(TimeUnit::Millisecond, tz) => {
290 let value = cast_between_timestamp(
291 &lit_data_type,
292 &DataType::Timestamp(TimeUnit::Millisecond, tz.clone()),
293 value,
294 );
295 ScalarValue::TimestampMillisecond(value, tz.clone())
296 }
297 DataType::Timestamp(TimeUnit::Microsecond, tz) => {
298 let value = cast_between_timestamp(
299 &lit_data_type,
300 &DataType::Timestamp(TimeUnit::Microsecond, tz.clone()),
301 value,
302 );
303 ScalarValue::TimestampMicrosecond(value, tz.clone())
304 }
305 DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
306 let value = cast_between_timestamp(
307 &lit_data_type,
308 &DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()),
309 value,
310 );
311 ScalarValue::TimestampNanosecond(value, tz.clone())
312 }
313 DataType::Decimal32(p, s) => {
314 ScalarValue::Decimal32(Some(value as i32), *p, *s)
315 }
316 DataType::Decimal64(p, s) => {
317 ScalarValue::Decimal64(Some(value as i64), *p, *s)
318 }
319 DataType::Decimal128(p, s) => {
320 ScalarValue::Decimal128(Some(value), *p, *s)
321 }
322 _ => {
323 return None;
324 }
325 };
326 Some(result_scalar)
327 } else {
328 None
329 }
330 }
331 }
332}
333
334fn try_cast_string_literal(
335 lit_value: &ScalarValue,
336 target_type: &DataType,
337) -> Option<ScalarValue> {
338 let string_value = lit_value.try_as_str()?.map(|s| s.to_string());
339 let scalar_value = match target_type {
340 DataType::Utf8 => ScalarValue::Utf8(string_value),
341 DataType::LargeUtf8 => ScalarValue::LargeUtf8(string_value),
342 DataType::Utf8View => ScalarValue::Utf8View(string_value),
343 _ => return None,
344 };
345 Some(scalar_value)
346}
347
348fn try_cast_dictionary(
350 lit_value: &ScalarValue,
351 target_type: &DataType,
352) -> Option<ScalarValue> {
353 let lit_value_type = lit_value.data_type();
354 let result_scalar = match (lit_value, target_type) {
355 (ScalarValue::Dictionary(_, inner_value), _)
357 if inner_value.data_type() == *target_type =>
358 {
359 (**inner_value).clone()
360 }
361 (_, DataType::Dictionary(index_type, inner_type))
363 if **inner_type == lit_value_type =>
364 {
365 ScalarValue::Dictionary(index_type.clone(), Box::new(lit_value.clone()))
366 }
367 _ => {
368 return None;
369 }
370 };
371 Some(result_scalar)
372}
373
374fn cast_between_timestamp(from: &DataType, to: &DataType, value: i128) -> Option<i64> {
376 let value = value as i64;
377 let from_scale = match from {
378 DataType::Timestamp(TimeUnit::Second, _) => 1,
379 DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS,
380 DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS,
381 DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS,
382 _ => return Some(value),
383 };
384
385 let to_scale = match to {
386 DataType::Timestamp(TimeUnit::Second, _) => 1,
387 DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS,
388 DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS,
389 DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS,
390 _ => return Some(value),
391 };
392
393 match from_scale.cmp(&to_scale) {
394 Ordering::Less => value.checked_mul(to_scale / from_scale),
395 Ordering::Greater => Some(value / (from_scale / to_scale)),
396 Ordering::Equal => Some(value),
397 }
398}
399
400fn try_cast_binary(
401 lit_value: &ScalarValue,
402 target_type: &DataType,
403) -> Option<ScalarValue> {
404 match (lit_value, target_type) {
405 (ScalarValue::Binary(Some(v)), DataType::FixedSizeBinary(n))
406 if v.len() == *n as usize =>
407 {
408 Some(ScalarValue::FixedSizeBinary(*n, Some(v.clone())))
409 }
410 _ => None,
411 }
412}
413
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::compute::{CastOptions, cast_with_options};
    use arrow::datatypes::{Field, Fields};
    use std::sync::Arc;

    /// Expected outcome of a `try_cast_literal_to_type` call.
    #[derive(Debug, Clone)]
    enum ExpectedCast {
        /// The cast should succeed and produce exactly this value.
        Value(ScalarValue),
        /// The cast should return `None`.
        NoValue,
    }

    /// Runs `try_cast_literal_to_type(&literal, &target_type)` and checks the
    /// result against `expected_result`.
    ///
    /// When a value is expected, it also casts `literal` with the arrow cast
    /// kernel (`cast_with_options`, default options). It asserts that the kernel
    /// produces the same single-element array, and that timestamp results carry
    /// the expected unit and timezone.
    fn expect_cast(
        literal: ScalarValue,
        target_type: DataType,
        expected_result: ExpectedCast,
    ) {
        let actual_value = try_cast_literal_to_type(&literal, &target_type);

        println!("expect_cast: ");
        println!(" {literal:?} --> {target_type}");
        println!(" expected_result: {expected_result:?}");
        println!(" actual_result: {actual_value:?}");

        match expected_result {
            ExpectedCast::Value(expected_value) => {
                let actual_value =
                    actual_value.expect("Expected cast value but got None");

                assert_eq!(actual_value, expected_value);

                // Cross-check against arrow's cast kernel so the literal cast
                // agrees with what a full array cast would produce.
                let literal_array = literal
                    .to_array_of_size(1)
                    .expect("Failed to convert to array of size");
                let expected_array = expected_value
                    .to_array_of_size(1)
                    .expect("Failed to convert to array of size");
                let cast_array = cast_with_options(
                    &literal_array,
                    &target_type,
                    &CastOptions::default(),
                )
                .expect("Expected to be cast array with arrow cast kernel");

                assert_eq!(
                    &expected_array, &cast_array,
                    "Result of casting {literal:?} with arrow was\n {cast_array:#?}\nbut expected\n{expected_array:#?}"
                );

                // Check the timestamp unit and timezone explicitly as well.
                if let (
                    DataType::Timestamp(left_unit, left_tz),
                    DataType::Timestamp(right_unit, right_tz),
                ) = (actual_value.data_type(), expected_value.data_type())
                {
                    assert_eq!(left_unit, right_unit);
                    assert_eq!(left_tz, right_tz);
                }
            }
            ExpectedCast::NoValue => {
                assert!(
                    actual_value.is_none(),
                    "Expected no cast value, but got {actual_value:?}"
                );
            }
        }
    }

    /// Every NULL literal casts to the NULL of every other supported type.
    #[test]
    fn test_try_cast_to_type_nulls() {
        let scalars = vec![
            ScalarValue::Int8(None),
            ScalarValue::Int16(None),
            ScalarValue::Int32(None),
            ScalarValue::Int64(None),
            ScalarValue::UInt8(None),
            ScalarValue::UInt16(None),
            ScalarValue::UInt32(None),
            ScalarValue::UInt64(None),
            ScalarValue::Decimal128(None, 3, 0),
            ScalarValue::Decimal128(None, 8, 2),
            ScalarValue::Utf8(None),
            ScalarValue::LargeUtf8(None),
        ];

        for s1 in &scalars {
            for s2 in &scalars {
                let expected_value = ExpectedCast::Value(s2.clone());

                expect_cast(s1.clone(), s2.data_type(), expected_value);
            }
        }
    }

    /// The value 123, in every integer and decimal type, converts losslessly
    /// into every other one; extreme values are checked for widening casts.
    #[test]
    fn test_try_cast_to_type_int_in_range() {
        let scalars = vec![
            ScalarValue::Int8(Some(123)),
            ScalarValue::Int16(Some(123)),
            ScalarValue::Int32(Some(123)),
            ScalarValue::Int64(Some(123)),
            ScalarValue::UInt8(Some(123)),
            ScalarValue::UInt16(Some(123)),
            ScalarValue::UInt32(Some(123)),
            ScalarValue::UInt64(Some(123)),
            ScalarValue::Decimal128(Some(123), 3, 0),
            ScalarValue::Decimal128(Some(12300), 8, 2),
        ];

        for s1 in &scalars {
            for s2 in &scalars {
                let expected_value = ExpectedCast::Value(s2.clone());

                expect_cast(s1.clone(), s2.data_type(), expected_value);
            }
        }

        let max_i32 = ScalarValue::Int32(Some(i32::MAX));
        expect_cast(
            max_i32,
            DataType::UInt64,
            ExpectedCast::Value(ScalarValue::UInt64(Some(i32::MAX as u64))),
        );

        let min_i32 = ScalarValue::Int32(Some(i32::MIN));
        expect_cast(
            min_i32,
            DataType::Int64,
            ExpectedCast::Value(ScalarValue::Int64(Some(i32::MIN as i64))),
        );

        let max_i64 = ScalarValue::Int64(Some(i64::MAX));
        expect_cast(
            max_i64,
            DataType::UInt64,
            ExpectedCast::Value(ScalarValue::UInt64(Some(i64::MAX as u64))),
        );
    }

    /// Values outside the target type's range are not cast.
    #[test]
    fn test_try_cast_to_type_int_out_of_range() {
        let min_i32 = ScalarValue::Int32(Some(i32::MIN));
        let min_i64 = ScalarValue::Int64(Some(i64::MIN));
        let max_i64 = ScalarValue::Int64(Some(i64::MAX));
        let max_u64 = ScalarValue::UInt64(Some(u64::MAX));

        expect_cast(max_i64.clone(), DataType::Int8, ExpectedCast::NoValue);

        expect_cast(max_i64.clone(), DataType::Int16, ExpectedCast::NoValue);

        expect_cast(max_i64, DataType::Int32, ExpectedCast::NoValue);

        expect_cast(max_u64, DataType::Int64, ExpectedCast::NoValue);

        expect_cast(min_i64, DataType::UInt64, ExpectedCast::NoValue);

        expect_cast(min_i32, DataType::UInt64, ExpectedCast::NoValue);

        // Integral decimal far larger than i64::MAX.
        expect_cast(
            ScalarValue::Decimal128(Some(99999999999999999999999999999999999900), 38, 0),
            DataType::Int64,
            ExpectedCast::NoValue,
        );

        // Large negative decimal with scale 1.
        expect_cast(
            ScalarValue::Decimal128(Some(-9999999999999999999999999999999999), 37, 1),
            DataType::Int64,
            ExpectedCast::NoValue,
        );
    }

    /// Decimal rescaling that is exact and fits the target precision succeeds.
    #[test]
    fn test_try_decimal_cast_in_range() {
        expect_cast(
            ScalarValue::Decimal128(Some(12300), 5, 2),
            DataType::Decimal128(3, 0),
            ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 3, 0)),
        );

        expect_cast(
            ScalarValue::Decimal128(Some(12300), 5, 2),
            DataType::Decimal128(8, 0),
            ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 8, 0)),
        );

        expect_cast(
            ScalarValue::Decimal128(Some(12300), 5, 2),
            DataType::Decimal128(8, 5),
            ExpectedCast::Value(ScalarValue::Decimal128(Some(12300000), 8, 5)),
        );
    }

    /// Decimal rescaling that loses digits or exceeds precision is rejected.
    #[test]
    fn test_try_decimal_cast_out_of_range() {
        // 123.45 has a fractional part, so it cannot become scale 0.
        expect_cast(
            ScalarValue::Decimal128(Some(12345), 5, 2),
            DataType::Decimal128(3, 0),
            ExpectedCast::NoValue,
        );

        // 123 needs three digits, but the target precision is 2.
        expect_cast(
            ScalarValue::Decimal128(Some(12300), 5, 2),
            DataType::Decimal128(2, 0),
            ExpectedCast::NoValue,
        );
    }

    /// For each time unit: timezone changes, timestamp <-> Int64 casts, and
    /// rejection of a timestamp -> string cast.
    #[test]
    fn test_try_cast_to_type_timestamps() {
        for time_unit in [
            TimeUnit::Second,
            TimeUnit::Millisecond,
            TimeUnit::Microsecond,
            TimeUnit::Nanosecond,
        ] {
            let utc = Some("+00:00".into());
            let (lit_tz_none, lit_tz_utc) = match time_unit {
                TimeUnit::Second => (
                    ScalarValue::TimestampSecond(Some(12345), None),
                    ScalarValue::TimestampSecond(Some(12345), utc),
                ),

                TimeUnit::Millisecond => (
                    ScalarValue::TimestampMillisecond(Some(12345), None),
                    ScalarValue::TimestampMillisecond(Some(12345), utc),
                ),

                TimeUnit::Microsecond => (
                    ScalarValue::TimestampMicrosecond(Some(12345), None),
                    ScalarValue::TimestampMicrosecond(Some(12345), utc),
                ),

                TimeUnit::Nanosecond => (
                    ScalarValue::TimestampNanosecond(Some(12345), None),
                    ScalarValue::TimestampNanosecond(Some(12345), utc),
                ),
            };

            // The two literals compare equal even though their timezones differ,
            // which is why `expect_cast` checks unit and timezone separately.
            assert_eq!(lit_tz_none, lit_tz_utc);

            let dt_tz_none = lit_tz_none.data_type();

            let dt_tz_utc = lit_tz_utc.data_type();

            expect_cast(
                lit_tz_none.clone(),
                dt_tz_none.clone(),
                ExpectedCast::Value(lit_tz_none.clone()),
            );

            expect_cast(
                lit_tz_none.clone(),
                dt_tz_utc.clone(),
                ExpectedCast::Value(lit_tz_utc.clone()),
            );

            expect_cast(
                lit_tz_utc.clone(),
                dt_tz_none.clone(),
                ExpectedCast::Value(lit_tz_none.clone()),
            );

            expect_cast(
                lit_tz_utc.clone(),
                dt_tz_utc.clone(),
                ExpectedCast::Value(lit_tz_utc.clone()),
            );

            expect_cast(
                lit_tz_utc.clone(),
                DataType::Int64,
                ExpectedCast::Value(ScalarValue::Int64(Some(12345))),
            );

            expect_cast(
                ScalarValue::Int64(Some(12345)),
                dt_tz_none.clone(),
                ExpectedCast::Value(lit_tz_none.clone()),
            );

            expect_cast(
                ScalarValue::Int64(Some(12345)),
                dt_tz_utc.clone(),
                ExpectedCast::Value(lit_tz_utc.clone()),
            );

            expect_cast(
                lit_tz_utc.clone(),
                DataType::LargeUtf8,
                ExpectedCast::NoValue,
            );
        }
    }

    /// Raw values are not reinterpreted between date and timestamp types.
    #[test]
    fn test_try_cast_to_type_date_timestamp_lossy_not_allowed() {
        expect_cast(
            ScalarValue::Date32(Some(1)),
            DataType::Timestamp(TimeUnit::Second, None),
            ExpectedCast::NoValue,
        );

        expect_cast(
            ScalarValue::Date64(Some(86_400_000)),
            DataType::Timestamp(TimeUnit::Millisecond, None),
            ExpectedCast::NoValue,
        );

        expect_cast(
            ScalarValue::TimestampSecond(Some(86_400), None),
            DataType::Date32,
            ExpectedCast::NoValue,
        );

        expect_cast(
            ScalarValue::TimestampMillisecond(Some(86_400_000), None),
            DataType::Date64,
            ExpectedCast::NoValue,
        );
    }

    /// Unsupported target types (here a list) yield no value.
    #[test]
    fn test_try_cast_to_type_unsupported() {
        expect_cast(
            ScalarValue::Int64(Some(12345)),
            DataType::List(Arc::new(Field::new("f", DataType::Int32, true))),
            ExpectedCast::NoValue,
        );
    }

    /// Timestamp unit conversions: scaling down truncates, scaling up
    /// multiplies, and an overflow produces a NULL timestamp (not `None`).
    #[test]
    fn test_try_cast_literal_to_timestamp() {
        let new_scalar = try_cast_literal_to_type(
            &ScalarValue::TimestampNanosecond(Some(123456), None),
            &DataType::Timestamp(TimeUnit::Nanosecond, None),
        )
        .unwrap();

        assert_eq!(
            new_scalar,
            ScalarValue::TimestampNanosecond(Some(123456), None)
        );

        let new_scalar = try_cast_literal_to_type(
            &ScalarValue::TimestampNanosecond(Some(123456), None),
            &DataType::Timestamp(TimeUnit::Microsecond, None),
        )
        .unwrap();

        assert_eq!(
            new_scalar,
            ScalarValue::TimestampMicrosecond(Some(123), None)
        );

        let new_scalar = try_cast_literal_to_type(
            &ScalarValue::TimestampNanosecond(Some(123456), None),
            &DataType::Timestamp(TimeUnit::Millisecond, None),
        )
        .unwrap();

        assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None));

        let new_scalar = try_cast_literal_to_type(
            &ScalarValue::TimestampNanosecond(Some(123456), None),
            &DataType::Timestamp(TimeUnit::Second, None),
        )
        .unwrap();

        assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(0), None));

        let new_scalar = try_cast_literal_to_type(
            &ScalarValue::TimestampMicrosecond(Some(123), None),
            &DataType::Timestamp(TimeUnit::Nanosecond, None),
        )
        .unwrap();

        assert_eq!(
            new_scalar,
            ScalarValue::TimestampNanosecond(Some(123000), None)
        );

        let new_scalar = try_cast_literal_to_type(
            &ScalarValue::TimestampMicrosecond(Some(123), None),
            &DataType::Timestamp(TimeUnit::Millisecond, None),
        )
        .unwrap();

        assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None));

        let new_scalar = try_cast_literal_to_type(
            &ScalarValue::TimestampMicrosecond(Some(123456789), None),
            &DataType::Timestamp(TimeUnit::Second, None),
        )
        .unwrap();
        assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123), None));

        let new_scalar = try_cast_literal_to_type(
            &ScalarValue::TimestampMillisecond(Some(123), None),
            &DataType::Timestamp(TimeUnit::Nanosecond, None),
        )
        .unwrap();
        assert_eq!(
            new_scalar,
            ScalarValue::TimestampNanosecond(Some(123000000), None)
        );

        let new_scalar = try_cast_literal_to_type(
            &ScalarValue::TimestampMillisecond(Some(123), None),
            &DataType::Timestamp(TimeUnit::Microsecond, None),
        )
        .unwrap();
        assert_eq!(
            new_scalar,
            ScalarValue::TimestampMicrosecond(Some(123000), None)
        );
        let new_scalar = try_cast_literal_to_type(
            &ScalarValue::TimestampMillisecond(Some(123456789), None),
            &DataType::Timestamp(TimeUnit::Second, None),
        )
        .unwrap();
        assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123456), None));

        let new_scalar = try_cast_literal_to_type(
            &ScalarValue::TimestampSecond(Some(123), None),
            &DataType::Timestamp(TimeUnit::Nanosecond, None),
        )
        .unwrap();
        assert_eq!(
            new_scalar,
            ScalarValue::TimestampNanosecond(Some(123000000000), None)
        );

        let new_scalar = try_cast_literal_to_type(
            &ScalarValue::TimestampSecond(Some(123), None),
            &DataType::Timestamp(TimeUnit::Microsecond, None),
        )
        .unwrap();
        assert_eq!(
            new_scalar,
            ScalarValue::TimestampMicrosecond(Some(123000000), None)
        );

        let new_scalar = try_cast_literal_to_type(
            &ScalarValue::TimestampSecond(Some(123), None),
            &DataType::Timestamp(TimeUnit::Millisecond, None),
        )
        .unwrap();
        assert_eq!(
            new_scalar,
            ScalarValue::TimestampMillisecond(Some(123000), None)
        );

        // Seconds -> milliseconds overflows i64, giving a NULL timestamp.
        let new_scalar = try_cast_literal_to_type(
            &ScalarValue::TimestampSecond(Some(i64::MAX), None),
            &DataType::Timestamp(TimeUnit::Millisecond, None),
        )
        .unwrap();
        assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(None, None));
    }

    /// `Utf8` and `LargeUtf8` literals convert into each other.
    #[test]
    fn test_try_cast_to_string_type() {
        let scalars = vec![
            ScalarValue::from("string"),
            ScalarValue::LargeUtf8(Some("string".to_owned())),
        ];

        for s1 in &scalars {
            for s2 in &scalars {
                let expected_value = ExpectedCast::Value(s2.clone());

                expect_cast(s1.clone(), s2.data_type(), expected_value);
            }
        }
    }

    /// String literals can be wrapped into, and unwrapped from, dictionaries.
    #[test]
    fn test_try_cast_to_dictionary_type() {
        fn dictionary_type(t: DataType) -> DataType {
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(t))
        }
        fn dictionary_value(value: ScalarValue) -> ScalarValue {
            ScalarValue::Dictionary(Box::new(DataType::Int32), Box::new(value))
        }
        let scalars = vec![
            ScalarValue::from("string"),
            ScalarValue::LargeUtf8(Some("string".to_owned())),
        ];
        for s in &scalars {
            expect_cast(
                s.clone(),
                dictionary_type(s.data_type()),
                ExpectedCast::Value(dictionary_value(s.clone())),
            );
            expect_cast(
                dictionary_value(s.clone()),
                s.data_type(),
                ExpectedCast::Value(s.clone()),
            )
        }
    }

    /// A `Binary` literal of the right length becomes `FixedSizeBinary`.
    #[test]
    fn test_try_cast_to_fixed_size_binary() {
        expect_cast(
            ScalarValue::Binary(Some(vec![1, 2, 3])),
            DataType::FixedSizeBinary(3),
            ExpectedCast::Value(ScalarValue::FixedSizeBinary(3, Some(vec![1, 2, 3]))),
        )
    }

    /// Integer casts at the edges of the source and target ranges.
    #[test]
    fn test_numeric_boundary_values() {
        expect_cast(
            ScalarValue::Int8(Some(i8::MAX)),
            DataType::UInt8,
            ExpectedCast::Value(ScalarValue::UInt8(Some(i8::MAX as u8))),
        );

        expect_cast(
            ScalarValue::Int8(Some(i8::MIN)),
            DataType::UInt8,
            ExpectedCast::NoValue,
        );

        expect_cast(
            ScalarValue::UInt8(Some(u8::MAX)),
            DataType::Int8,
            ExpectedCast::NoValue,
        );

        expect_cast(
            ScalarValue::Int32(Some(i32::MAX)),
            DataType::Int64,
            ExpectedCast::Value(ScalarValue::Int64(Some(i32::MAX as i64))),
        );

        expect_cast(
            ScalarValue::Int64(Some(i64::MIN)),
            DataType::UInt64,
            ExpectedCast::NoValue,
        );

        expect_cast(
            ScalarValue::UInt32(Some(u32::MAX)),
            DataType::Int32,
            ExpectedCast::NoValue,
        );

        expect_cast(
            ScalarValue::UInt64(Some(u64::MAX)),
            DataType::Int64,
            ExpectedCast::NoValue,
        );
    }

    /// Decimal casts at the per-precision min/max values, plus rescaling.
    #[test]
    fn test_decimal_precision_limits() {
        use arrow::datatypes::{
            MAX_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL128_FOR_EACH_PRECISION,
        };

        // Largest precision-3 value widens to precision 5.
        expect_cast(
            ScalarValue::Decimal128(Some(MAX_DECIMAL128_FOR_EACH_PRECISION[3]), 3, 0),
            DataType::Decimal128(5, 0),
            ExpectedCast::Value(ScalarValue::Decimal128(
                Some(MAX_DECIMAL128_FOR_EACH_PRECISION[3]),
                5,
                0,
            )),
        );

        // Smallest precision-3 value widens to precision 5.
        expect_cast(
            ScalarValue::Decimal128(Some(MIN_DECIMAL128_FOR_EACH_PRECISION[3]), 3, 0),
            DataType::Decimal128(5, 0),
            ExpectedCast::Value(ScalarValue::Decimal128(
                Some(MIN_DECIMAL128_FOR_EACH_PRECISION[3]),
                5,
                0,
            )),
        );

        // Increasing the scale multiplies the unscaled value.
        expect_cast(
            ScalarValue::Decimal128(Some(123), 3, 0),
            DataType::Decimal128(5, 2),
            ExpectedCast::Value(ScalarValue::Decimal128(Some(12300), 5, 2)),
        );

        // A precision-10 maximum does not fit precision 3.
        expect_cast(
            ScalarValue::Decimal128(Some(MAX_DECIMAL128_FOR_EACH_PRECISION[10]), 10, 0),
            DataType::Decimal128(3, 0),
            ExpectedCast::NoValue,
        );

        // 12.345 has a fractional part, so it is not an Int32.
        expect_cast(
            ScalarValue::Decimal128(Some(12345), 5, 3),
            DataType::Int32,
            ExpectedCast::NoValue,
        );

        // 123.45 has a fractional part, so it cannot become scale 0.
        expect_cast(
            ScalarValue::Decimal128(Some(12345), 5, 2),
            DataType::Decimal128(3, 0),
            ExpectedCast::NoValue,
        );
    }

    /// Timestamp unit conversions near the i64 limit, and truncation of
    /// sub-unit values.
    #[test]
    fn test_timestamp_overflow_scenarios() {
        // Largest whole number of seconds that still fits in i64 nanoseconds.
        let max_seconds = i64::MAX / 1_000_000_000;
        expect_cast(
            ScalarValue::TimestampSecond(Some(max_seconds), None),
            DataType::Timestamp(TimeUnit::Nanosecond, None),
            ExpectedCast::Value(ScalarValue::TimestampNanosecond(
                Some(max_seconds * 1_000_000_000),
                None,
            )),
        );

        expect_cast(
            ScalarValue::TimestampNanosecond(Some(i64::MAX), None),
            DataType::Timestamp(TimeUnit::Second, None),
            ExpectedCast::Value(ScalarValue::TimestampSecond(
                Some(i64::MAX / 1_000_000_000),
                None,
            )),
        );

        expect_cast(
            ScalarValue::TimestampNanosecond(Some(1), None),
            DataType::Timestamp(TimeUnit::Second, None),
            ExpectedCast::Value(ScalarValue::TimestampSecond(Some(0), None)),
        );

        expect_cast(
            ScalarValue::TimestampMicrosecond(Some(999), None),
            DataType::Timestamp(TimeUnit::Millisecond, None),
            ExpectedCast::Value(ScalarValue::TimestampMillisecond(Some(0), None)),
        );
    }

    /// `Utf8View` converts to and from the other string types, including empty
    /// and longer strings.
    #[test]
    fn test_string_view() {
        expect_cast(
            ScalarValue::Utf8View(Some("test".to_string())),
            DataType::Utf8,
            ExpectedCast::Value(ScalarValue::Utf8(Some("test".to_string()))),
        );

        expect_cast(
            ScalarValue::Utf8View(Some("test".to_string())),
            DataType::LargeUtf8,
            ExpectedCast::Value(ScalarValue::LargeUtf8(Some("test".to_string()))),
        );

        expect_cast(
            ScalarValue::Utf8(Some("hello".to_string())),
            DataType::Utf8View,
            ExpectedCast::Value(ScalarValue::Utf8View(Some("hello".to_string()))),
        );

        expect_cast(
            ScalarValue::LargeUtf8(Some("world".to_string())),
            DataType::Utf8View,
            ExpectedCast::Value(ScalarValue::Utf8View(Some("world".to_string()))),
        );

        expect_cast(
            ScalarValue::Utf8(Some("".to_string())),
            DataType::Utf8View,
            ExpectedCast::Value(ScalarValue::Utf8View(Some("".to_string()))),
        );

        let large_string = "x".repeat(1000);
        expect_cast(
            ScalarValue::LargeUtf8(Some(large_string.clone())),
            DataType::Utf8View,
            ExpectedCast::Value(ScalarValue::Utf8View(Some(large_string))),
        );
    }

    /// `Binary` -> `FixedSizeBinary` requires an exact length match.
    #[test]
    fn test_binary_size_edge_cases() {
        // Too short.
        expect_cast(
            ScalarValue::Binary(Some(vec![1, 2])),
            DataType::FixedSizeBinary(3),
            ExpectedCast::NoValue,
        );

        // Too long.
        expect_cast(
            ScalarValue::Binary(Some(vec![1, 2, 3, 4])),
            DataType::FixedSizeBinary(3),
            ExpectedCast::NoValue,
        );

        // Zero-width.
        expect_cast(
            ScalarValue::Binary(Some(vec![])),
            DataType::FixedSizeBinary(0),
            ExpectedCast::Value(ScalarValue::FixedSizeBinary(0, Some(vec![]))),
        );

        expect_cast(
            ScalarValue::Binary(Some(vec![1, 2, 3])),
            DataType::FixedSizeBinary(3),
            ExpectedCast::Value(ScalarValue::FixedSizeBinary(3, Some(vec![1, 2, 3]))),
        );

        expect_cast(
            ScalarValue::Binary(Some(vec![42])),
            DataType::FixedSizeBinary(1),
            ExpectedCast::Value(ScalarValue::FixedSizeBinary(1, Some(vec![42]))),
        );
    }

    /// Wrapping uses the target dictionary's key type, and a dictionary literal
    /// unwraps to its value.
    #[test]
    fn test_dictionary_index_types() {
        let string_value = ScalarValue::Utf8(Some("test".to_string()));

        let dict_int8 =
            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8));
        expect_cast(
            string_value.clone(),
            dict_int8,
            ExpectedCast::Value(ScalarValue::Dictionary(
                Box::new(DataType::Int8),
                Box::new(string_value.clone()),
            )),
        );

        let dict_int16 =
            DataType::Dictionary(Box::new(DataType::Int16), Box::new(DataType::Utf8));
        expect_cast(
            string_value.clone(),
            dict_int16,
            ExpectedCast::Value(ScalarValue::Dictionary(
                Box::new(DataType::Int16),
                Box::new(string_value.clone()),
            )),
        );

        let dict_int64 =
            DataType::Dictionary(Box::new(DataType::Int64), Box::new(DataType::Utf8));
        expect_cast(
            string_value.clone(),
            dict_int64,
            ExpectedCast::Value(ScalarValue::Dictionary(
                Box::new(DataType::Int64),
                Box::new(string_value.clone()),
            )),
        );

        let dict_value = ScalarValue::Dictionary(
            Box::new(DataType::Int32),
            Box::new(ScalarValue::LargeUtf8(Some("unwrap_test".to_string()))),
        );
        expect_cast(
            dict_value,
            DataType::LargeUtf8,
            ExpectedCast::Value(ScalarValue::LargeUtf8(Some("unwrap_test".to_string()))),
        );
    }

    /// Direct checks of the `is_supported_*` predicates.
    #[test]
    fn test_type_support_functions() {
        // Numeric predicate.
        assert!(is_supported_numeric_type(&DataType::Int8));
        assert!(is_supported_numeric_type(&DataType::UInt64));
        assert!(is_supported_numeric_type(&DataType::Decimal128(10, 2)));
        assert!(is_supported_numeric_type(&DataType::Timestamp(
            TimeUnit::Nanosecond,
            None
        )));
        assert!(!is_supported_numeric_type(&DataType::Float32));
        assert!(!is_supported_numeric_type(&DataType::Float64));

        // String predicate.
        assert!(is_supported_string_type(&DataType::Utf8));
        assert!(is_supported_string_type(&DataType::LargeUtf8));
        assert!(is_supported_string_type(&DataType::Utf8View));
        assert!(!is_supported_string_type(&DataType::Binary));

        // Binary predicate.
        assert!(is_supported_binary_type(&DataType::Binary));
        assert!(is_supported_binary_type(&DataType::FixedSizeBinary(10)));
        assert!(!is_supported_binary_type(&DataType::Utf8));

        // Dictionary predicate: depends on the value type only.
        assert!(is_supported_dictionary_type(&DataType::Dictionary(
            Box::new(DataType::Int32),
            Box::new(DataType::Utf8)
        )));
        assert!(is_supported_dictionary_type(&DataType::Dictionary(
            Box::new(DataType::Int32),
            Box::new(DataType::Int64)
        )));
        assert!(!is_supported_dictionary_type(&DataType::Dictionary(
            Box::new(DataType::Int32),
            Box::new(DataType::List(Arc::new(Field::new(
                "item",
                DataType::Int32,
                true
            ))))
        )));

        // Top-level predicate.
        assert!(is_supported_type(&DataType::Int32));
        assert!(is_supported_type(&DataType::Utf8));
        assert!(is_supported_type(&DataType::Binary));
        assert!(is_supported_type(&DataType::Dictionary(
            Box::new(DataType::Int32),
            Box::new(DataType::Utf8)
        )));
        assert!(!is_supported_type(&DataType::List(Arc::new(Field::new(
            "item",
            DataType::Int32,
            true
        )))));
        assert!(!is_supported_type(&DataType::Struct(Fields::empty())));
    }

    /// Floats, lists and dictionaries of unsupported types never cast.
    #[test]
    fn test_error_conditions() {
        expect_cast(
            ScalarValue::Float32(Some(1.5)),
            DataType::Int32,
            ExpectedCast::NoValue,
        );

        expect_cast(
            ScalarValue::Int32(Some(123)),
            DataType::Float64,
            ExpectedCast::NoValue,
        );

        expect_cast(
            ScalarValue::Float64(Some(1.5)),
            DataType::Float32,
            ExpectedCast::NoValue,
        );

        let list_type =
            DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
        expect_cast(
            ScalarValue::Int32(Some(123)),
            list_type,
            ExpectedCast::NoValue,
        );

        let bad_dict = DataType::Dictionary(
            Box::new(DataType::Int32),
            Box::new(DataType::List(Arc::new(Field::new(
                "item",
                DataType::Int32,
                true,
            )))),
        );
        expect_cast(
            ScalarValue::Int32(Some(123)),
            bad_dict,
            ExpectedCast::NoValue,
        );
    }
}