1use std::cmp::Ordering;
25
26use arrow::datatypes::{
27 DataType, TimeUnit, MAX_DECIMAL128_FOR_EACH_PRECISION,
28 MAX_DECIMAL32_FOR_EACH_PRECISION, MAX_DECIMAL64_FOR_EACH_PRECISION,
29 MIN_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL32_FOR_EACH_PRECISION,
30 MIN_DECIMAL64_FOR_EACH_PRECISION,
31};
32use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS};
33use datafusion_common::ScalarValue;
34
35pub fn try_cast_literal_to_type(
37 lit_value: &ScalarValue,
38 target_type: &DataType,
39) -> Option<ScalarValue> {
40 let lit_data_type = lit_value.data_type();
41 if !is_supported_type(&lit_data_type) || !is_supported_type(target_type) {
42 return None;
43 }
44 if lit_value.is_null() {
45 return ScalarValue::try_from(target_type).ok();
47 }
48 try_cast_numeric_literal(lit_value, target_type)
49 .or_else(|| try_cast_string_literal(lit_value, target_type))
50 .or_else(|| try_cast_dictionary(lit_value, target_type))
51 .or_else(|| try_cast_binary(lit_value, target_type))
52}
53
54pub fn is_supported_type(data_type: &DataType) -> bool {
56 is_supported_numeric_type(data_type)
57 || is_supported_string_type(data_type)
58 || is_supported_dictionary_type(data_type)
59 || is_supported_binary_type(data_type)
60}
61
62fn is_supported_numeric_type(data_type: &DataType) -> bool {
64 matches!(
65 data_type,
66 DataType::UInt8
67 | DataType::UInt16
68 | DataType::UInt32
69 | DataType::UInt64
70 | DataType::Int8
71 | DataType::Int16
72 | DataType::Int32
73 | DataType::Int64
74 | DataType::Decimal32(_, _)
75 | DataType::Decimal64(_, _)
76 | DataType::Decimal128(_, _)
77 | DataType::Timestamp(_, _)
78 )
79}
80
81fn is_supported_string_type(data_type: &DataType) -> bool {
83 matches!(
84 data_type,
85 DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View
86 )
87}
88
89fn is_supported_dictionary_type(data_type: &DataType) -> bool {
91 matches!(data_type,
92 DataType::Dictionary(_, inner) if is_supported_type(inner))
93}
94
95fn is_supported_binary_type(data_type: &DataType) -> bool {
96 matches!(data_type, DataType::Binary | DataType::FixedSizeBinary(_))
97}
98
99fn try_cast_numeric_literal(
101 lit_value: &ScalarValue,
102 target_type: &DataType,
103) -> Option<ScalarValue> {
104 let lit_data_type = lit_value.data_type();
105 if !is_supported_numeric_type(&lit_data_type)
106 || !is_supported_numeric_type(target_type)
107 {
108 return None;
109 }
110
111 let mul = match target_type {
112 DataType::UInt8
113 | DataType::UInt16
114 | DataType::UInt32
115 | DataType::UInt64
116 | DataType::Int8
117 | DataType::Int16
118 | DataType::Int32
119 | DataType::Int64 => 1_i128,
120 DataType::Timestamp(_, _) => 1_i128,
121 DataType::Decimal32(_, scale) => 10_i128.pow(*scale as u32),
122 DataType::Decimal64(_, scale) => 10_i128.pow(*scale as u32),
123 DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32),
124 _ => return None,
125 };
126 let (target_min, target_max) = match target_type {
127 DataType::UInt8 => (u8::MIN as i128, u8::MAX as i128),
128 DataType::UInt16 => (u16::MIN as i128, u16::MAX as i128),
129 DataType::UInt32 => (u32::MIN as i128, u32::MAX as i128),
130 DataType::UInt64 => (u64::MIN as i128, u64::MAX as i128),
131 DataType::Int8 => (i8::MIN as i128, i8::MAX as i128),
132 DataType::Int16 => (i16::MIN as i128, i16::MAX as i128),
133 DataType::Int32 => (i32::MIN as i128, i32::MAX as i128),
134 DataType::Int64 => (i64::MIN as i128, i64::MAX as i128),
135 DataType::Timestamp(_, _) => (i64::MIN as i128, i64::MAX as i128),
136 DataType::Decimal32(precision, _) => (
137 MIN_DECIMAL32_FOR_EACH_PRECISION[*precision as usize] as i128,
141 MAX_DECIMAL32_FOR_EACH_PRECISION[*precision as usize] as i128,
142 ),
143 DataType::Decimal64(precision, _) => (
144 MIN_DECIMAL64_FOR_EACH_PRECISION[*precision as usize] as i128,
148 MAX_DECIMAL64_FOR_EACH_PRECISION[*precision as usize] as i128,
149 ),
150 DataType::Decimal128(precision, _) => (
151 MIN_DECIMAL128_FOR_EACH_PRECISION[*precision as usize],
155 MAX_DECIMAL128_FOR_EACH_PRECISION[*precision as usize],
156 ),
157 _ => return None,
158 };
159 let lit_value_target_type = match lit_value {
160 ScalarValue::Int8(Some(v)) => (*v as i128).checked_mul(mul),
161 ScalarValue::Int16(Some(v)) => (*v as i128).checked_mul(mul),
162 ScalarValue::Int32(Some(v)) => (*v as i128).checked_mul(mul),
163 ScalarValue::Int64(Some(v)) => (*v as i128).checked_mul(mul),
164 ScalarValue::UInt8(Some(v)) => (*v as i128).checked_mul(mul),
165 ScalarValue::UInt16(Some(v)) => (*v as i128).checked_mul(mul),
166 ScalarValue::UInt32(Some(v)) => (*v as i128).checked_mul(mul),
167 ScalarValue::UInt64(Some(v)) => (*v as i128).checked_mul(mul),
168 ScalarValue::TimestampSecond(Some(v), _) => (*v as i128).checked_mul(mul),
169 ScalarValue::TimestampMillisecond(Some(v), _) => (*v as i128).checked_mul(mul),
170 ScalarValue::TimestampMicrosecond(Some(v), _) => (*v as i128).checked_mul(mul),
171 ScalarValue::TimestampNanosecond(Some(v), _) => (*v as i128).checked_mul(mul),
172 ScalarValue::Decimal32(Some(v), _, scale) => {
173 let v = *v as i128;
174 let lit_scale_mul = 10_i128.pow(*scale as u32);
175 if mul >= lit_scale_mul {
176 v.checked_mul(mul / lit_scale_mul)
181 } else if v % (lit_scale_mul / mul) == 0 {
182 Some(v / (lit_scale_mul / mul))
187 } else {
188 None
190 }
191 }
192 ScalarValue::Decimal64(Some(v), _, scale) => {
193 let v = *v as i128;
194 let lit_scale_mul = 10_i128.pow(*scale as u32);
195 if mul >= lit_scale_mul {
196 v.checked_mul(mul / lit_scale_mul)
201 } else if v % (lit_scale_mul / mul) == 0 {
202 Some(v / (lit_scale_mul / mul))
207 } else {
208 None
210 }
211 }
212 ScalarValue::Decimal128(Some(v), _, scale) => {
213 let lit_scale_mul = 10_i128.pow(*scale as u32);
214 if mul >= lit_scale_mul {
215 (*v).checked_mul(mul / lit_scale_mul)
220 } else if (*v) % (lit_scale_mul / mul) == 0 {
221 Some(*v / (lit_scale_mul / mul))
226 } else {
227 None
229 }
230 }
231 _ => None,
232 };
233
234 match lit_value_target_type {
235 None => None,
236 Some(value) => {
237 if value >= target_min && value <= target_max {
238 let result_scalar = match target_type {
241 DataType::Int8 => ScalarValue::Int8(Some(value as i8)),
242 DataType::Int16 => ScalarValue::Int16(Some(value as i16)),
243 DataType::Int32 => ScalarValue::Int32(Some(value as i32)),
244 DataType::Int64 => ScalarValue::Int64(Some(value as i64)),
245 DataType::UInt8 => ScalarValue::UInt8(Some(value as u8)),
246 DataType::UInt16 => ScalarValue::UInt16(Some(value as u16)),
247 DataType::UInt32 => ScalarValue::UInt32(Some(value as u32)),
248 DataType::UInt64 => ScalarValue::UInt64(Some(value as u64)),
249 DataType::Timestamp(TimeUnit::Second, tz) => {
250 let value = cast_between_timestamp(
251 &lit_data_type,
252 &DataType::Timestamp(TimeUnit::Second, tz.clone()),
253 value,
254 );
255 ScalarValue::TimestampSecond(value, tz.clone())
256 }
257 DataType::Timestamp(TimeUnit::Millisecond, tz) => {
258 let value = cast_between_timestamp(
259 &lit_data_type,
260 &DataType::Timestamp(TimeUnit::Millisecond, tz.clone()),
261 value,
262 );
263 ScalarValue::TimestampMillisecond(value, tz.clone())
264 }
265 DataType::Timestamp(TimeUnit::Microsecond, tz) => {
266 let value = cast_between_timestamp(
267 &lit_data_type,
268 &DataType::Timestamp(TimeUnit::Microsecond, tz.clone()),
269 value,
270 );
271 ScalarValue::TimestampMicrosecond(value, tz.clone())
272 }
273 DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
274 let value = cast_between_timestamp(
275 &lit_data_type,
276 &DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()),
277 value,
278 );
279 ScalarValue::TimestampNanosecond(value, tz.clone())
280 }
281 DataType::Decimal32(p, s) => {
282 ScalarValue::Decimal32(Some(value as i32), *p, *s)
283 }
284 DataType::Decimal64(p, s) => {
285 ScalarValue::Decimal64(Some(value as i64), *p, *s)
286 }
287 DataType::Decimal128(p, s) => {
288 ScalarValue::Decimal128(Some(value), *p, *s)
289 }
290 _ => {
291 return None;
292 }
293 };
294 Some(result_scalar)
295 } else {
296 None
297 }
298 }
299 }
300}
301
302fn try_cast_string_literal(
303 lit_value: &ScalarValue,
304 target_type: &DataType,
305) -> Option<ScalarValue> {
306 let string_value = lit_value.try_as_str()?.map(|s| s.to_string());
307 let scalar_value = match target_type {
308 DataType::Utf8 => ScalarValue::Utf8(string_value),
309 DataType::LargeUtf8 => ScalarValue::LargeUtf8(string_value),
310 DataType::Utf8View => ScalarValue::Utf8View(string_value),
311 _ => return None,
312 };
313 Some(scalar_value)
314}
315
316fn try_cast_dictionary(
318 lit_value: &ScalarValue,
319 target_type: &DataType,
320) -> Option<ScalarValue> {
321 let lit_value_type = lit_value.data_type();
322 let result_scalar = match (lit_value, target_type) {
323 (ScalarValue::Dictionary(_, inner_value), _)
325 if inner_value.data_type() == *target_type =>
326 {
327 (**inner_value).clone()
328 }
329 (_, DataType::Dictionary(index_type, inner_type))
331 if **inner_type == lit_value_type =>
332 {
333 ScalarValue::Dictionary(index_type.clone(), Box::new(lit_value.clone()))
334 }
335 _ => {
336 return None;
337 }
338 };
339 Some(result_scalar)
340}
341
342fn cast_between_timestamp(from: &DataType, to: &DataType, value: i128) -> Option<i64> {
344 let value = value as i64;
345 let from_scale = match from {
346 DataType::Timestamp(TimeUnit::Second, _) => 1,
347 DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS,
348 DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS,
349 DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS,
350 _ => return Some(value),
351 };
352
353 let to_scale = match to {
354 DataType::Timestamp(TimeUnit::Second, _) => 1,
355 DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS,
356 DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS,
357 DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS,
358 _ => return Some(value),
359 };
360
361 match from_scale.cmp(&to_scale) {
362 Ordering::Less => value.checked_mul(to_scale / from_scale),
363 Ordering::Greater => Some(value / (from_scale / to_scale)),
364 Ordering::Equal => Some(value),
365 }
366}
367
368fn try_cast_binary(
369 lit_value: &ScalarValue,
370 target_type: &DataType,
371) -> Option<ScalarValue> {
372 match (lit_value, target_type) {
373 (ScalarValue::Binary(Some(v)), DataType::FixedSizeBinary(n))
374 if v.len() == *n as usize =>
375 {
376 Some(ScalarValue::FixedSizeBinary(*n, Some(v.clone())))
377 }
378 _ => None,
379 }
380}
381
382#[cfg(test)]
383mod tests {
384 use super::*;
385 use arrow::compute::{cast_with_options, CastOptions};
386 use arrow::datatypes::{Field, Fields, TimeUnit};
387 use std::sync::Arc;
388
389 #[derive(Debug, Clone)]
390 enum ExpectedCast {
391 Value(ScalarValue),
393 NoValue,
395 }
396
397 fn expect_cast(
401 literal: ScalarValue,
402 target_type: DataType,
403 expected_result: ExpectedCast,
404 ) {
405 let actual_value = try_cast_literal_to_type(&literal, &target_type);
406
407 println!("expect_cast: ");
408 println!(" {literal:?} --> {target_type}");
409 println!(" expected_result: {expected_result:?}");
410 println!(" actual_result: {actual_value:?}");
411
412 match expected_result {
413 ExpectedCast::Value(expected_value) => {
414 let actual_value =
415 actual_value.expect("Expected cast value but got None");
416
417 assert_eq!(actual_value, expected_value);
418
419 let literal_array = literal
423 .to_array_of_size(1)
424 .expect("Failed to convert to array of size");
425 let expected_array = expected_value
426 .to_array_of_size(1)
427 .expect("Failed to convert to array of size");
428 let cast_array = cast_with_options(
429 &literal_array,
430 &target_type,
431 &CastOptions::default(),
432 )
433 .expect("Expected to be cast array with arrow cast kernel");
434
435 assert_eq!(
436 &expected_array, &cast_array,
437 "Result of casting {literal:?} with arrow was\n {cast_array:#?}\nbut expected\n{expected_array:#?}"
438 );
439
440 if let (
443 DataType::Timestamp(left_unit, left_tz),
444 DataType::Timestamp(right_unit, right_tz),
445 ) = (actual_value.data_type(), expected_value.data_type())
446 {
447 assert_eq!(left_unit, right_unit);
448 assert_eq!(left_tz, right_tz);
449 }
450 }
451 ExpectedCast::NoValue => {
452 assert!(
453 actual_value.is_none(),
454 "Expected no cast value, but got {actual_value:?}"
455 );
456 }
457 }
458 }
459
460 #[test]
461 fn test_try_cast_to_type_nulls() {
462 let scalars = vec![
464 ScalarValue::Int8(None),
465 ScalarValue::Int16(None),
466 ScalarValue::Int32(None),
467 ScalarValue::Int64(None),
468 ScalarValue::UInt8(None),
469 ScalarValue::UInt16(None),
470 ScalarValue::UInt32(None),
471 ScalarValue::UInt64(None),
472 ScalarValue::Decimal128(None, 3, 0),
473 ScalarValue::Decimal128(None, 8, 2),
474 ScalarValue::Utf8(None),
475 ScalarValue::LargeUtf8(None),
476 ];
477
478 for s1 in &scalars {
479 for s2 in &scalars {
480 let expected_value = ExpectedCast::Value(s2.clone());
481
482 expect_cast(s1.clone(), s2.data_type(), expected_value);
483 }
484 }
485 }
486
487 #[test]
488 fn test_try_cast_to_type_int_in_range() {
489 let scalars = vec![
491 ScalarValue::Int8(Some(123)),
492 ScalarValue::Int16(Some(123)),
493 ScalarValue::Int32(Some(123)),
494 ScalarValue::Int64(Some(123)),
495 ScalarValue::UInt8(Some(123)),
496 ScalarValue::UInt16(Some(123)),
497 ScalarValue::UInt32(Some(123)),
498 ScalarValue::UInt64(Some(123)),
499 ScalarValue::Decimal128(Some(123), 3, 0),
500 ScalarValue::Decimal128(Some(12300), 8, 2),
501 ];
502
503 for s1 in &scalars {
504 for s2 in &scalars {
505 let expected_value = ExpectedCast::Value(s2.clone());
506
507 expect_cast(s1.clone(), s2.data_type(), expected_value);
508 }
509 }
510
511 let max_i32 = ScalarValue::Int32(Some(i32::MAX));
512 expect_cast(
513 max_i32,
514 DataType::UInt64,
515 ExpectedCast::Value(ScalarValue::UInt64(Some(i32::MAX as u64))),
516 );
517
518 let min_i32 = ScalarValue::Int32(Some(i32::MIN));
519 expect_cast(
520 min_i32,
521 DataType::Int64,
522 ExpectedCast::Value(ScalarValue::Int64(Some(i32::MIN as i64))),
523 );
524
525 let max_i64 = ScalarValue::Int64(Some(i64::MAX));
526 expect_cast(
527 max_i64,
528 DataType::UInt64,
529 ExpectedCast::Value(ScalarValue::UInt64(Some(i64::MAX as u64))),
530 );
531 }
532
533 #[test]
534 fn test_try_cast_to_type_int_out_of_range() {
535 let min_i32 = ScalarValue::Int32(Some(i32::MIN));
536 let min_i64 = ScalarValue::Int64(Some(i64::MIN));
537 let max_i64 = ScalarValue::Int64(Some(i64::MAX));
538 let max_u64 = ScalarValue::UInt64(Some(u64::MAX));
539
540 expect_cast(max_i64.clone(), DataType::Int8, ExpectedCast::NoValue);
541
542 expect_cast(max_i64.clone(), DataType::Int16, ExpectedCast::NoValue);
543
544 expect_cast(max_i64, DataType::Int32, ExpectedCast::NoValue);
545
546 expect_cast(max_u64, DataType::Int64, ExpectedCast::NoValue);
547
548 expect_cast(min_i64, DataType::UInt64, ExpectedCast::NoValue);
549
550 expect_cast(min_i32, DataType::UInt64, ExpectedCast::NoValue);
551
552 expect_cast(
554 ScalarValue::Decimal128(Some(99999999999999999999999999999999999900), 38, 0),
555 DataType::Int64,
556 ExpectedCast::NoValue,
557 );
558
559 expect_cast(
560 ScalarValue::Decimal128(Some(-9999999999999999999999999999999999), 37, 1),
561 DataType::Int64,
562 ExpectedCast::NoValue,
563 );
564 }
565
566 #[test]
567 fn test_try_decimal_cast_in_range() {
568 expect_cast(
569 ScalarValue::Decimal128(Some(12300), 5, 2),
570 DataType::Decimal128(3, 0),
571 ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 3, 0)),
572 );
573
574 expect_cast(
575 ScalarValue::Decimal128(Some(12300), 5, 2),
576 DataType::Decimal128(8, 0),
577 ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 8, 0)),
578 );
579
580 expect_cast(
581 ScalarValue::Decimal128(Some(12300), 5, 2),
582 DataType::Decimal128(8, 5),
583 ExpectedCast::Value(ScalarValue::Decimal128(Some(12300000), 8, 5)),
584 );
585 }
586
587 #[test]
588 fn test_try_decimal_cast_out_of_range() {
589 expect_cast(
591 ScalarValue::Decimal128(Some(12345), 5, 2),
592 DataType::Decimal128(3, 0),
593 ExpectedCast::NoValue,
594 );
595
596 expect_cast(
598 ScalarValue::Decimal128(Some(12300), 5, 2),
599 DataType::Decimal128(2, 0),
600 ExpectedCast::NoValue,
601 );
602 }
603
604 #[test]
605 fn test_try_cast_to_type_timestamps() {
606 for time_unit in [
607 TimeUnit::Second,
608 TimeUnit::Millisecond,
609 TimeUnit::Microsecond,
610 TimeUnit::Nanosecond,
611 ] {
612 let utc = Some("+00:00".into());
613 let (lit_tz_none, lit_tz_utc) = match time_unit {
615 TimeUnit::Second => (
616 ScalarValue::TimestampSecond(Some(12345), None),
617 ScalarValue::TimestampSecond(Some(12345), utc),
618 ),
619
620 TimeUnit::Millisecond => (
621 ScalarValue::TimestampMillisecond(Some(12345), None),
622 ScalarValue::TimestampMillisecond(Some(12345), utc),
623 ),
624
625 TimeUnit::Microsecond => (
626 ScalarValue::TimestampMicrosecond(Some(12345), None),
627 ScalarValue::TimestampMicrosecond(Some(12345), utc),
628 ),
629
630 TimeUnit::Nanosecond => (
631 ScalarValue::TimestampNanosecond(Some(12345), None),
632 ScalarValue::TimestampNanosecond(Some(12345), utc),
633 ),
634 };
635
636 assert_eq!(lit_tz_none, lit_tz_utc);
639
640 let dt_tz_none = lit_tz_none.data_type();
642
643 let dt_tz_utc = lit_tz_utc.data_type();
645
646 expect_cast(
648 lit_tz_none.clone(),
649 dt_tz_none.clone(),
650 ExpectedCast::Value(lit_tz_none.clone()),
651 );
652
653 expect_cast(
655 lit_tz_none.clone(),
656 dt_tz_utc.clone(),
657 ExpectedCast::Value(lit_tz_utc.clone()),
658 );
659
660 expect_cast(
662 lit_tz_utc.clone(),
663 dt_tz_none.clone(),
664 ExpectedCast::Value(lit_tz_none.clone()),
665 );
666
667 expect_cast(
669 lit_tz_utc.clone(),
670 dt_tz_utc.clone(),
671 ExpectedCast::Value(lit_tz_utc.clone()),
672 );
673
674 expect_cast(
676 lit_tz_utc.clone(),
677 DataType::Int64,
678 ExpectedCast::Value(ScalarValue::Int64(Some(12345))),
679 );
680
681 expect_cast(
683 ScalarValue::Int64(Some(12345)),
684 dt_tz_none.clone(),
685 ExpectedCast::Value(lit_tz_none.clone()),
686 );
687
688 expect_cast(
690 ScalarValue::Int64(Some(12345)),
691 dt_tz_utc.clone(),
692 ExpectedCast::Value(lit_tz_utc.clone()),
693 );
694
695 expect_cast(
697 lit_tz_utc.clone(),
698 DataType::LargeUtf8,
699 ExpectedCast::NoValue,
700 );
701 }
702 }
703
704 #[test]
705 fn test_try_cast_to_type_unsupported() {
706 expect_cast(
708 ScalarValue::Int64(Some(12345)),
709 DataType::List(Arc::new(Field::new("f", DataType::Int32, true))),
710 ExpectedCast::NoValue,
711 );
712 }
713
714 #[test]
715 fn test_try_cast_literal_to_timestamp() {
716 let new_scalar = try_cast_literal_to_type(
718 &ScalarValue::TimestampNanosecond(Some(123456), None),
719 &DataType::Timestamp(TimeUnit::Nanosecond, None),
720 )
721 .unwrap();
722
723 assert_eq!(
724 new_scalar,
725 ScalarValue::TimestampNanosecond(Some(123456), None)
726 );
727
728 let new_scalar = try_cast_literal_to_type(
730 &ScalarValue::TimestampNanosecond(Some(123456), None),
731 &DataType::Timestamp(TimeUnit::Microsecond, None),
732 )
733 .unwrap();
734
735 assert_eq!(
736 new_scalar,
737 ScalarValue::TimestampMicrosecond(Some(123), None)
738 );
739
740 let new_scalar = try_cast_literal_to_type(
742 &ScalarValue::TimestampNanosecond(Some(123456), None),
743 &DataType::Timestamp(TimeUnit::Millisecond, None),
744 )
745 .unwrap();
746
747 assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None));
748
749 let new_scalar = try_cast_literal_to_type(
751 &ScalarValue::TimestampNanosecond(Some(123456), None),
752 &DataType::Timestamp(TimeUnit::Second, None),
753 )
754 .unwrap();
755
756 assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(0), None));
757
758 let new_scalar = try_cast_literal_to_type(
760 &ScalarValue::TimestampMicrosecond(Some(123), None),
761 &DataType::Timestamp(TimeUnit::Nanosecond, None),
762 )
763 .unwrap();
764
765 assert_eq!(
766 new_scalar,
767 ScalarValue::TimestampNanosecond(Some(123000), None)
768 );
769
770 let new_scalar = try_cast_literal_to_type(
772 &ScalarValue::TimestampMicrosecond(Some(123), None),
773 &DataType::Timestamp(TimeUnit::Millisecond, None),
774 )
775 .unwrap();
776
777 assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None));
778
779 let new_scalar = try_cast_literal_to_type(
781 &ScalarValue::TimestampMicrosecond(Some(123456789), None),
782 &DataType::Timestamp(TimeUnit::Second, None),
783 )
784 .unwrap();
785 assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123), None));
786
787 let new_scalar = try_cast_literal_to_type(
789 &ScalarValue::TimestampMillisecond(Some(123), None),
790 &DataType::Timestamp(TimeUnit::Nanosecond, None),
791 )
792 .unwrap();
793 assert_eq!(
794 new_scalar,
795 ScalarValue::TimestampNanosecond(Some(123000000), None)
796 );
797
798 let new_scalar = try_cast_literal_to_type(
800 &ScalarValue::TimestampMillisecond(Some(123), None),
801 &DataType::Timestamp(TimeUnit::Microsecond, None),
802 )
803 .unwrap();
804 assert_eq!(
805 new_scalar,
806 ScalarValue::TimestampMicrosecond(Some(123000), None)
807 );
808 let new_scalar = try_cast_literal_to_type(
810 &ScalarValue::TimestampMillisecond(Some(123456789), None),
811 &DataType::Timestamp(TimeUnit::Second, None),
812 )
813 .unwrap();
814 assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123456), None));
815
816 let new_scalar = try_cast_literal_to_type(
818 &ScalarValue::TimestampSecond(Some(123), None),
819 &DataType::Timestamp(TimeUnit::Nanosecond, None),
820 )
821 .unwrap();
822 assert_eq!(
823 new_scalar,
824 ScalarValue::TimestampNanosecond(Some(123000000000), None)
825 );
826
827 let new_scalar = try_cast_literal_to_type(
829 &ScalarValue::TimestampSecond(Some(123), None),
830 &DataType::Timestamp(TimeUnit::Microsecond, None),
831 )
832 .unwrap();
833 assert_eq!(
834 new_scalar,
835 ScalarValue::TimestampMicrosecond(Some(123000000), None)
836 );
837
838 let new_scalar = try_cast_literal_to_type(
840 &ScalarValue::TimestampSecond(Some(123), None),
841 &DataType::Timestamp(TimeUnit::Millisecond, None),
842 )
843 .unwrap();
844 assert_eq!(
845 new_scalar,
846 ScalarValue::TimestampMillisecond(Some(123000), None)
847 );
848
849 let new_scalar = try_cast_literal_to_type(
851 &ScalarValue::TimestampSecond(Some(i64::MAX), None),
852 &DataType::Timestamp(TimeUnit::Millisecond, None),
853 )
854 .unwrap();
855 assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(None, None));
856 }
857
858 #[test]
859 fn test_try_cast_to_string_type() {
860 let scalars = vec![
861 ScalarValue::from("string"),
862 ScalarValue::LargeUtf8(Some("string".to_owned())),
863 ];
864
865 for s1 in &scalars {
866 for s2 in &scalars {
867 let expected_value = ExpectedCast::Value(s2.clone());
868
869 expect_cast(s1.clone(), s2.data_type(), expected_value);
870 }
871 }
872 }
873
874 #[test]
875 fn test_try_cast_to_dictionary_type() {
876 fn dictionary_type(t: DataType) -> DataType {
877 DataType::Dictionary(Box::new(DataType::Int32), Box::new(t))
878 }
879 fn dictionary_value(value: ScalarValue) -> ScalarValue {
880 ScalarValue::Dictionary(Box::new(DataType::Int32), Box::new(value))
881 }
882 let scalars = vec![
883 ScalarValue::from("string"),
884 ScalarValue::LargeUtf8(Some("string".to_owned())),
885 ];
886 for s in &scalars {
887 expect_cast(
888 s.clone(),
889 dictionary_type(s.data_type()),
890 ExpectedCast::Value(dictionary_value(s.clone())),
891 );
892 expect_cast(
893 dictionary_value(s.clone()),
894 s.data_type(),
895 ExpectedCast::Value(s.clone()),
896 )
897 }
898 }
899
900 #[test]
901 fn test_try_cast_to_fixed_size_binary() {
902 expect_cast(
903 ScalarValue::Binary(Some(vec![1, 2, 3])),
904 DataType::FixedSizeBinary(3),
905 ExpectedCast::Value(ScalarValue::FixedSizeBinary(3, Some(vec![1, 2, 3]))),
906 )
907 }
908
909 #[test]
910 fn test_numeric_boundary_values() {
911 expect_cast(
913 ScalarValue::Int8(Some(i8::MAX)),
914 DataType::UInt8,
915 ExpectedCast::Value(ScalarValue::UInt8(Some(i8::MAX as u8))),
916 );
917
918 expect_cast(
919 ScalarValue::Int8(Some(i8::MIN)),
920 DataType::UInt8,
921 ExpectedCast::NoValue,
922 );
923
924 expect_cast(
925 ScalarValue::UInt8(Some(u8::MAX)),
926 DataType::Int8,
927 ExpectedCast::NoValue,
928 );
929
930 expect_cast(
932 ScalarValue::Int32(Some(i32::MAX)),
933 DataType::Int64,
934 ExpectedCast::Value(ScalarValue::Int64(Some(i32::MAX as i64))),
935 );
936
937 expect_cast(
938 ScalarValue::Int64(Some(i64::MIN)),
939 DataType::UInt64,
940 ExpectedCast::NoValue,
941 );
942
943 expect_cast(
945 ScalarValue::UInt32(Some(u32::MAX)),
946 DataType::Int32,
947 ExpectedCast::NoValue,
948 );
949
950 expect_cast(
951 ScalarValue::UInt64(Some(u64::MAX)),
952 DataType::Int64,
953 ExpectedCast::NoValue,
954 );
955 }
956
957 #[test]
958 fn test_decimal_precision_limits() {
959 use arrow::datatypes::{
960 MAX_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL128_FOR_EACH_PRECISION,
961 };
962
963 expect_cast(
965 ScalarValue::Decimal128(Some(MAX_DECIMAL128_FOR_EACH_PRECISION[3]), 3, 0),
966 DataType::Decimal128(5, 0),
967 ExpectedCast::Value(ScalarValue::Decimal128(
968 Some(MAX_DECIMAL128_FOR_EACH_PRECISION[3]),
969 5,
970 0,
971 )),
972 );
973
974 expect_cast(
976 ScalarValue::Decimal128(Some(MIN_DECIMAL128_FOR_EACH_PRECISION[3]), 3, 0),
977 DataType::Decimal128(5, 0),
978 ExpectedCast::Value(ScalarValue::Decimal128(
979 Some(MIN_DECIMAL128_FOR_EACH_PRECISION[3]),
980 5,
981 0,
982 )),
983 );
984
985 expect_cast(
987 ScalarValue::Decimal128(Some(123), 3, 0),
988 DataType::Decimal128(5, 2),
989 ExpectedCast::Value(ScalarValue::Decimal128(Some(12300), 5, 2)),
990 );
991
992 expect_cast(
994 ScalarValue::Decimal128(Some(MAX_DECIMAL128_FOR_EACH_PRECISION[10]), 10, 0),
995 DataType::Decimal128(3, 0),
996 ExpectedCast::NoValue,
997 );
998
999 expect_cast(
1001 ScalarValue::Decimal128(Some(12345), 5, 3), DataType::Int32,
1003 ExpectedCast::NoValue, );
1005
1006 expect_cast(
1008 ScalarValue::Decimal128(Some(12345), 5, 2), DataType::Decimal128(3, 0), ExpectedCast::NoValue,
1011 );
1012 }
1013
1014 #[test]
1015 fn test_timestamp_overflow_scenarios() {
1016 let max_seconds = i64::MAX / 1_000_000_000; expect_cast(
1021 ScalarValue::TimestampSecond(Some(max_seconds), None),
1022 DataType::Timestamp(TimeUnit::Nanosecond, None),
1023 ExpectedCast::Value(ScalarValue::TimestampNanosecond(
1024 Some(max_seconds * 1_000_000_000),
1025 None,
1026 )),
1027 );
1028
1029 expect_cast(
1031 ScalarValue::TimestampNanosecond(Some(i64::MAX), None),
1032 DataType::Timestamp(TimeUnit::Second, None),
1033 ExpectedCast::Value(ScalarValue::TimestampSecond(
1034 Some(i64::MAX / 1_000_000_000),
1035 None,
1036 )),
1037 );
1038
1039 expect_cast(
1041 ScalarValue::TimestampNanosecond(Some(1), None),
1042 DataType::Timestamp(TimeUnit::Second, None),
1043 ExpectedCast::Value(ScalarValue::TimestampSecond(Some(0), None)),
1044 );
1045
1046 expect_cast(
1047 ScalarValue::TimestampMicrosecond(Some(999), None),
1048 DataType::Timestamp(TimeUnit::Millisecond, None),
1049 ExpectedCast::Value(ScalarValue::TimestampMillisecond(Some(0), None)),
1050 );
1051 }
1052
1053 #[test]
1054 fn test_string_view() {
1055 expect_cast(
1057 ScalarValue::Utf8View(Some("test".to_string())),
1058 DataType::Utf8,
1059 ExpectedCast::Value(ScalarValue::Utf8(Some("test".to_string()))),
1060 );
1061
1062 expect_cast(
1063 ScalarValue::Utf8View(Some("test".to_string())),
1064 DataType::LargeUtf8,
1065 ExpectedCast::Value(ScalarValue::LargeUtf8(Some("test".to_string()))),
1066 );
1067
1068 expect_cast(
1070 ScalarValue::Utf8(Some("hello".to_string())),
1071 DataType::Utf8View,
1072 ExpectedCast::Value(ScalarValue::Utf8View(Some("hello".to_string()))),
1073 );
1074
1075 expect_cast(
1076 ScalarValue::LargeUtf8(Some("world".to_string())),
1077 DataType::Utf8View,
1078 ExpectedCast::Value(ScalarValue::Utf8View(Some("world".to_string()))),
1079 );
1080
1081 expect_cast(
1083 ScalarValue::Utf8(Some("".to_string())),
1084 DataType::Utf8View,
1085 ExpectedCast::Value(ScalarValue::Utf8View(Some("".to_string()))),
1086 );
1087
1088 let large_string = "x".repeat(1000);
1090 expect_cast(
1091 ScalarValue::LargeUtf8(Some(large_string.clone())),
1092 DataType::Utf8View,
1093 ExpectedCast::Value(ScalarValue::Utf8View(Some(large_string))),
1094 );
1095 }
1096
1097 #[test]
1098 fn test_binary_size_edge_cases() {
1099 expect_cast(
1101 ScalarValue::Binary(Some(vec![1, 2])),
1102 DataType::FixedSizeBinary(3),
1103 ExpectedCast::NoValue,
1104 );
1105
1106 expect_cast(
1108 ScalarValue::Binary(Some(vec![1, 2, 3, 4])),
1109 DataType::FixedSizeBinary(3),
1110 ExpectedCast::NoValue,
1111 );
1112
1113 expect_cast(
1115 ScalarValue::Binary(Some(vec![])),
1116 DataType::FixedSizeBinary(0),
1117 ExpectedCast::Value(ScalarValue::FixedSizeBinary(0, Some(vec![]))),
1118 );
1119
1120 expect_cast(
1122 ScalarValue::Binary(Some(vec![1, 2, 3])),
1123 DataType::FixedSizeBinary(3),
1124 ExpectedCast::Value(ScalarValue::FixedSizeBinary(3, Some(vec![1, 2, 3]))),
1125 );
1126
1127 expect_cast(
1129 ScalarValue::Binary(Some(vec![42])),
1130 DataType::FixedSizeBinary(1),
1131 ExpectedCast::Value(ScalarValue::FixedSizeBinary(1, Some(vec![42]))),
1132 );
1133 }
1134
1135 #[test]
1136 fn test_dictionary_index_types() {
1137 let string_value = ScalarValue::Utf8(Some("test".to_string()));
1139
1140 let dict_int8 =
1142 DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8));
1143 expect_cast(
1144 string_value.clone(),
1145 dict_int8,
1146 ExpectedCast::Value(ScalarValue::Dictionary(
1147 Box::new(DataType::Int8),
1148 Box::new(string_value.clone()),
1149 )),
1150 );
1151
1152 let dict_int16 =
1154 DataType::Dictionary(Box::new(DataType::Int16), Box::new(DataType::Utf8));
1155 expect_cast(
1156 string_value.clone(),
1157 dict_int16,
1158 ExpectedCast::Value(ScalarValue::Dictionary(
1159 Box::new(DataType::Int16),
1160 Box::new(string_value.clone()),
1161 )),
1162 );
1163
1164 let dict_int64 =
1166 DataType::Dictionary(Box::new(DataType::Int64), Box::new(DataType::Utf8));
1167 expect_cast(
1168 string_value.clone(),
1169 dict_int64,
1170 ExpectedCast::Value(ScalarValue::Dictionary(
1171 Box::new(DataType::Int64),
1172 Box::new(string_value.clone()),
1173 )),
1174 );
1175
1176 let dict_value = ScalarValue::Dictionary(
1178 Box::new(DataType::Int32),
1179 Box::new(ScalarValue::LargeUtf8(Some("unwrap_test".to_string()))),
1180 );
1181 expect_cast(
1182 dict_value,
1183 DataType::LargeUtf8,
1184 ExpectedCast::Value(ScalarValue::LargeUtf8(Some("unwrap_test".to_string()))),
1185 );
1186 }
1187
1188 #[test]
1189 fn test_type_support_functions() {
1190 assert!(is_supported_numeric_type(&DataType::Int8));
1192 assert!(is_supported_numeric_type(&DataType::UInt64));
1193 assert!(is_supported_numeric_type(&DataType::Decimal128(10, 2)));
1194 assert!(is_supported_numeric_type(&DataType::Timestamp(
1195 TimeUnit::Nanosecond,
1196 None
1197 )));
1198 assert!(!is_supported_numeric_type(&DataType::Float32));
1199 assert!(!is_supported_numeric_type(&DataType::Float64));
1200
1201 assert!(is_supported_string_type(&DataType::Utf8));
1203 assert!(is_supported_string_type(&DataType::LargeUtf8));
1204 assert!(is_supported_string_type(&DataType::Utf8View));
1205 assert!(!is_supported_string_type(&DataType::Binary));
1206
1207 assert!(is_supported_binary_type(&DataType::Binary));
1209 assert!(is_supported_binary_type(&DataType::FixedSizeBinary(10)));
1210 assert!(!is_supported_binary_type(&DataType::Utf8));
1211
1212 assert!(is_supported_dictionary_type(&DataType::Dictionary(
1214 Box::new(DataType::Int32),
1215 Box::new(DataType::Utf8)
1216 )));
1217 assert!(is_supported_dictionary_type(&DataType::Dictionary(
1218 Box::new(DataType::Int32),
1219 Box::new(DataType::Int64)
1220 )));
1221 assert!(!is_supported_dictionary_type(&DataType::Dictionary(
1222 Box::new(DataType::Int32),
1223 Box::new(DataType::List(Arc::new(Field::new(
1224 "item",
1225 DataType::Int32,
1226 true
1227 ))))
1228 )));
1229
1230 assert!(is_supported_type(&DataType::Int32));
1232 assert!(is_supported_type(&DataType::Utf8));
1233 assert!(is_supported_type(&DataType::Binary));
1234 assert!(is_supported_type(&DataType::Dictionary(
1235 Box::new(DataType::Int32),
1236 Box::new(DataType::Utf8)
1237 )));
1238 assert!(!is_supported_type(&DataType::List(Arc::new(Field::new(
1239 "item",
1240 DataType::Int32,
1241 true
1242 )))));
1243 assert!(!is_supported_type(&DataType::Struct(Fields::empty())));
1244 }
1245
1246 #[test]
1247 fn test_error_conditions() {
1248 expect_cast(
1250 ScalarValue::Float32(Some(1.5)),
1251 DataType::Int32,
1252 ExpectedCast::NoValue,
1253 );
1254
1255 expect_cast(
1257 ScalarValue::Int32(Some(123)),
1258 DataType::Float64,
1259 ExpectedCast::NoValue,
1260 );
1261
1262 expect_cast(
1264 ScalarValue::Float64(Some(1.5)),
1265 DataType::Float32,
1266 ExpectedCast::NoValue,
1267 );
1268
1269 let list_type =
1271 DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
1272 expect_cast(
1273 ScalarValue::Int32(Some(123)),
1274 list_type,
1275 ExpectedCast::NoValue,
1276 );
1277
1278 let bad_dict = DataType::Dictionary(
1280 Box::new(DataType::Int32),
1281 Box::new(DataType::List(Arc::new(Field::new(
1282 "item",
1283 DataType::Int32,
1284 true,
1285 )))),
1286 );
1287 expect_cast(
1288 ScalarValue::Int32(Some(123)),
1289 bad_dict,
1290 ExpectedCast::NoValue,
1291 );
1292 }
1293}