1use std::cmp::Ordering;
25
26use arrow::datatypes::{
27 DataType, MAX_DECIMAL32_FOR_EACH_PRECISION, MAX_DECIMAL64_FOR_EACH_PRECISION,
28 MAX_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL32_FOR_EACH_PRECISION,
29 MIN_DECIMAL64_FOR_EACH_PRECISION, MIN_DECIMAL128_FOR_EACH_PRECISION, TimeUnit,
30};
31use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS};
32use datafusion_common::ScalarValue;
33
34pub fn try_cast_literal_to_type(
36 lit_value: &ScalarValue,
37 target_type: &DataType,
38) -> Option<ScalarValue> {
39 let lit_data_type = lit_value.data_type();
40 if !is_supported_type(&lit_data_type) || !is_supported_type(target_type) {
41 return None;
42 }
43 if lit_value.is_null() {
44 return ScalarValue::try_from(target_type).ok();
46 }
47 try_cast_numeric_literal(lit_value, target_type)
48 .or_else(|| try_cast_string_literal(lit_value, target_type))
49 .or_else(|| try_cast_dictionary(lit_value, target_type))
50 .or_else(|| try_cast_binary(lit_value, target_type))
51}
52
53pub fn is_supported_type(data_type: &DataType) -> bool {
55 is_supported_numeric_type(data_type)
56 || is_supported_string_type(data_type)
57 || is_supported_dictionary_type(data_type)
58 || is_supported_binary_type(data_type)
59}
60
61fn is_supported_numeric_type(data_type: &DataType) -> bool {
63 matches!(
64 data_type,
65 DataType::UInt8
66 | DataType::UInt16
67 | DataType::UInt32
68 | DataType::UInt64
69 | DataType::Int8
70 | DataType::Int16
71 | DataType::Int32
72 | DataType::Int64
73 | DataType::Decimal32(_, _)
74 | DataType::Decimal64(_, _)
75 | DataType::Decimal128(_, _)
76 | DataType::Timestamp(_, _)
77 )
78}
79
80fn is_supported_string_type(data_type: &DataType) -> bool {
82 matches!(
83 data_type,
84 DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View
85 )
86}
87
88fn is_supported_dictionary_type(data_type: &DataType) -> bool {
90 matches!(data_type,
91 DataType::Dictionary(_, inner) if is_supported_type(inner))
92}
93
94fn is_supported_binary_type(data_type: &DataType) -> bool {
95 matches!(data_type, DataType::Binary | DataType::FixedSizeBinary(_))
96}
97
98fn try_cast_numeric_literal(
100 lit_value: &ScalarValue,
101 target_type: &DataType,
102) -> Option<ScalarValue> {
103 let lit_data_type = lit_value.data_type();
104 if !is_supported_numeric_type(&lit_data_type)
105 || !is_supported_numeric_type(target_type)
106 {
107 return None;
108 }
109
110 let mul = match target_type {
111 DataType::UInt8
112 | DataType::UInt16
113 | DataType::UInt32
114 | DataType::UInt64
115 | DataType::Int8
116 | DataType::Int16
117 | DataType::Int32
118 | DataType::Int64 => 1_i128,
119 DataType::Timestamp(_, _) => 1_i128,
120 DataType::Decimal32(_, scale) => 10_i128.pow(*scale as u32),
121 DataType::Decimal64(_, scale) => 10_i128.pow(*scale as u32),
122 DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32),
123 _ => return None,
124 };
125 let (target_min, target_max) = match target_type {
126 DataType::UInt8 => (u8::MIN as i128, u8::MAX as i128),
127 DataType::UInt16 => (u16::MIN as i128, u16::MAX as i128),
128 DataType::UInt32 => (u32::MIN as i128, u32::MAX as i128),
129 DataType::UInt64 => (u64::MIN as i128, u64::MAX as i128),
130 DataType::Int8 => (i8::MIN as i128, i8::MAX as i128),
131 DataType::Int16 => (i16::MIN as i128, i16::MAX as i128),
132 DataType::Int32 => (i32::MIN as i128, i32::MAX as i128),
133 DataType::Int64 => (i64::MIN as i128, i64::MAX as i128),
134 DataType::Timestamp(_, _) => (i64::MIN as i128, i64::MAX as i128),
135 DataType::Decimal32(precision, _) => (
136 MIN_DECIMAL32_FOR_EACH_PRECISION[*precision as usize] as i128,
140 MAX_DECIMAL32_FOR_EACH_PRECISION[*precision as usize] as i128,
141 ),
142 DataType::Decimal64(precision, _) => (
143 MIN_DECIMAL64_FOR_EACH_PRECISION[*precision as usize] as i128,
147 MAX_DECIMAL64_FOR_EACH_PRECISION[*precision as usize] as i128,
148 ),
149 DataType::Decimal128(precision, _) => (
150 MIN_DECIMAL128_FOR_EACH_PRECISION[*precision as usize],
154 MAX_DECIMAL128_FOR_EACH_PRECISION[*precision as usize],
155 ),
156 _ => return None,
157 };
158 let lit_value_target_type = match lit_value {
159 ScalarValue::Int8(Some(v)) => (*v as i128).checked_mul(mul),
160 ScalarValue::Int16(Some(v)) => (*v as i128).checked_mul(mul),
161 ScalarValue::Int32(Some(v)) => (*v as i128).checked_mul(mul),
162 ScalarValue::Int64(Some(v)) => (*v as i128).checked_mul(mul),
163 ScalarValue::UInt8(Some(v)) => (*v as i128).checked_mul(mul),
164 ScalarValue::UInt16(Some(v)) => (*v as i128).checked_mul(mul),
165 ScalarValue::UInt32(Some(v)) => (*v as i128).checked_mul(mul),
166 ScalarValue::UInt64(Some(v)) => (*v as i128).checked_mul(mul),
167 ScalarValue::TimestampSecond(Some(v), _) => (*v as i128).checked_mul(mul),
168 ScalarValue::TimestampMillisecond(Some(v), _) => (*v as i128).checked_mul(mul),
169 ScalarValue::TimestampMicrosecond(Some(v), _) => (*v as i128).checked_mul(mul),
170 ScalarValue::TimestampNanosecond(Some(v), _) => (*v as i128).checked_mul(mul),
171 ScalarValue::Decimal32(Some(v), _, scale) => {
172 let v = *v as i128;
173 let lit_scale_mul = 10_i128.pow(*scale as u32);
174 if mul >= lit_scale_mul {
175 v.checked_mul(mul / lit_scale_mul)
180 } else if v % (lit_scale_mul / mul) == 0 {
181 Some(v / (lit_scale_mul / mul))
186 } else {
187 None
189 }
190 }
191 ScalarValue::Decimal64(Some(v), _, scale) => {
192 let v = *v as i128;
193 let lit_scale_mul = 10_i128.pow(*scale as u32);
194 if mul >= lit_scale_mul {
195 v.checked_mul(mul / lit_scale_mul)
200 } else if v % (lit_scale_mul / mul) == 0 {
201 Some(v / (lit_scale_mul / mul))
206 } else {
207 None
209 }
210 }
211 ScalarValue::Decimal128(Some(v), _, scale) => {
212 let lit_scale_mul = 10_i128.pow(*scale as u32);
213 if mul >= lit_scale_mul {
214 (*v).checked_mul(mul / lit_scale_mul)
219 } else if (*v) % (lit_scale_mul / mul) == 0 {
220 Some(*v / (lit_scale_mul / mul))
225 } else {
226 None
228 }
229 }
230 _ => None,
231 };
232
233 match lit_value_target_type {
234 None => None,
235 Some(value) => {
236 if value >= target_min && value <= target_max {
237 let result_scalar = match target_type {
240 DataType::Int8 => ScalarValue::Int8(Some(value as i8)),
241 DataType::Int16 => ScalarValue::Int16(Some(value as i16)),
242 DataType::Int32 => ScalarValue::Int32(Some(value as i32)),
243 DataType::Int64 => ScalarValue::Int64(Some(value as i64)),
244 DataType::UInt8 => ScalarValue::UInt8(Some(value as u8)),
245 DataType::UInt16 => ScalarValue::UInt16(Some(value as u16)),
246 DataType::UInt32 => ScalarValue::UInt32(Some(value as u32)),
247 DataType::UInt64 => ScalarValue::UInt64(Some(value as u64)),
248 DataType::Timestamp(TimeUnit::Second, tz) => {
249 let value = cast_between_timestamp(
250 &lit_data_type,
251 &DataType::Timestamp(TimeUnit::Second, tz.clone()),
252 value,
253 );
254 ScalarValue::TimestampSecond(value, tz.clone())
255 }
256 DataType::Timestamp(TimeUnit::Millisecond, tz) => {
257 let value = cast_between_timestamp(
258 &lit_data_type,
259 &DataType::Timestamp(TimeUnit::Millisecond, tz.clone()),
260 value,
261 );
262 ScalarValue::TimestampMillisecond(value, tz.clone())
263 }
264 DataType::Timestamp(TimeUnit::Microsecond, tz) => {
265 let value = cast_between_timestamp(
266 &lit_data_type,
267 &DataType::Timestamp(TimeUnit::Microsecond, tz.clone()),
268 value,
269 );
270 ScalarValue::TimestampMicrosecond(value, tz.clone())
271 }
272 DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
273 let value = cast_between_timestamp(
274 &lit_data_type,
275 &DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()),
276 value,
277 );
278 ScalarValue::TimestampNanosecond(value, tz.clone())
279 }
280 DataType::Decimal32(p, s) => {
281 ScalarValue::Decimal32(Some(value as i32), *p, *s)
282 }
283 DataType::Decimal64(p, s) => {
284 ScalarValue::Decimal64(Some(value as i64), *p, *s)
285 }
286 DataType::Decimal128(p, s) => {
287 ScalarValue::Decimal128(Some(value), *p, *s)
288 }
289 _ => {
290 return None;
291 }
292 };
293 Some(result_scalar)
294 } else {
295 None
296 }
297 }
298 }
299}
300
301fn try_cast_string_literal(
302 lit_value: &ScalarValue,
303 target_type: &DataType,
304) -> Option<ScalarValue> {
305 let string_value = lit_value.try_as_str()?.map(|s| s.to_string());
306 let scalar_value = match target_type {
307 DataType::Utf8 => ScalarValue::Utf8(string_value),
308 DataType::LargeUtf8 => ScalarValue::LargeUtf8(string_value),
309 DataType::Utf8View => ScalarValue::Utf8View(string_value),
310 _ => return None,
311 };
312 Some(scalar_value)
313}
314
315fn try_cast_dictionary(
317 lit_value: &ScalarValue,
318 target_type: &DataType,
319) -> Option<ScalarValue> {
320 let lit_value_type = lit_value.data_type();
321 let result_scalar = match (lit_value, target_type) {
322 (ScalarValue::Dictionary(_, inner_value), _)
324 if inner_value.data_type() == *target_type =>
325 {
326 (**inner_value).clone()
327 }
328 (_, DataType::Dictionary(index_type, inner_type))
330 if **inner_type == lit_value_type =>
331 {
332 ScalarValue::Dictionary(index_type.clone(), Box::new(lit_value.clone()))
333 }
334 _ => {
335 return None;
336 }
337 };
338 Some(result_scalar)
339}
340
341fn cast_between_timestamp(from: &DataType, to: &DataType, value: i128) -> Option<i64> {
343 let value = value as i64;
344 let from_scale = match from {
345 DataType::Timestamp(TimeUnit::Second, _) => 1,
346 DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS,
347 DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS,
348 DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS,
349 _ => return Some(value),
350 };
351
352 let to_scale = match to {
353 DataType::Timestamp(TimeUnit::Second, _) => 1,
354 DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS,
355 DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS,
356 DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS,
357 _ => return Some(value),
358 };
359
360 match from_scale.cmp(&to_scale) {
361 Ordering::Less => value.checked_mul(to_scale / from_scale),
362 Ordering::Greater => Some(value / (from_scale / to_scale)),
363 Ordering::Equal => Some(value),
364 }
365}
366
367fn try_cast_binary(
368 lit_value: &ScalarValue,
369 target_type: &DataType,
370) -> Option<ScalarValue> {
371 match (lit_value, target_type) {
372 (ScalarValue::Binary(Some(v)), DataType::FixedSizeBinary(n))
373 if v.len() == *n as usize =>
374 {
375 Some(ScalarValue::FixedSizeBinary(*n, Some(v.clone())))
376 }
377 _ => None,
378 }
379}
380
381#[cfg(test)]
382mod tests {
383 use super::*;
384 use arrow::compute::{CastOptions, cast_with_options};
385 use arrow::datatypes::{Field, Fields, TimeUnit};
386 use std::sync::Arc;
387
    /// Outcome a test expects from `try_cast_literal_to_type`.
    #[derive(Debug, Clone)]
    enum ExpectedCast {
        /// The cast must succeed and produce exactly this value.
        Value(ScalarValue),
        /// The cast must return `None`.
        NoValue,
    }
395
396 fn expect_cast(
400 literal: ScalarValue,
401 target_type: DataType,
402 expected_result: ExpectedCast,
403 ) {
404 let actual_value = try_cast_literal_to_type(&literal, &target_type);
405
406 println!("expect_cast: ");
407 println!(" {literal:?} --> {target_type}");
408 println!(" expected_result: {expected_result:?}");
409 println!(" actual_result: {actual_value:?}");
410
411 match expected_result {
412 ExpectedCast::Value(expected_value) => {
413 let actual_value =
414 actual_value.expect("Expected cast value but got None");
415
416 assert_eq!(actual_value, expected_value);
417
418 let literal_array = literal
422 .to_array_of_size(1)
423 .expect("Failed to convert to array of size");
424 let expected_array = expected_value
425 .to_array_of_size(1)
426 .expect("Failed to convert to array of size");
427 let cast_array = cast_with_options(
428 &literal_array,
429 &target_type,
430 &CastOptions::default(),
431 )
432 .expect("Expected to be cast array with arrow cast kernel");
433
434 assert_eq!(
435 &expected_array, &cast_array,
436 "Result of casting {literal:?} with arrow was\n {cast_array:#?}\nbut expected\n{expected_array:#?}"
437 );
438
439 if let (
442 DataType::Timestamp(left_unit, left_tz),
443 DataType::Timestamp(right_unit, right_tz),
444 ) = (actual_value.data_type(), expected_value.data_type())
445 {
446 assert_eq!(left_unit, right_unit);
447 assert_eq!(left_tz, right_tz);
448 }
449 }
450 ExpectedCast::NoValue => {
451 assert!(
452 actual_value.is_none(),
453 "Expected no cast value, but got {actual_value:?}"
454 );
455 }
456 }
457 }
458
459 #[test]
460 fn test_try_cast_to_type_nulls() {
461 let scalars = vec![
463 ScalarValue::Int8(None),
464 ScalarValue::Int16(None),
465 ScalarValue::Int32(None),
466 ScalarValue::Int64(None),
467 ScalarValue::UInt8(None),
468 ScalarValue::UInt16(None),
469 ScalarValue::UInt32(None),
470 ScalarValue::UInt64(None),
471 ScalarValue::Decimal128(None, 3, 0),
472 ScalarValue::Decimal128(None, 8, 2),
473 ScalarValue::Utf8(None),
474 ScalarValue::LargeUtf8(None),
475 ];
476
477 for s1 in &scalars {
478 for s2 in &scalars {
479 let expected_value = ExpectedCast::Value(s2.clone());
480
481 expect_cast(s1.clone(), s2.data_type(), expected_value);
482 }
483 }
484 }
485
486 #[test]
487 fn test_try_cast_to_type_int_in_range() {
488 let scalars = vec![
490 ScalarValue::Int8(Some(123)),
491 ScalarValue::Int16(Some(123)),
492 ScalarValue::Int32(Some(123)),
493 ScalarValue::Int64(Some(123)),
494 ScalarValue::UInt8(Some(123)),
495 ScalarValue::UInt16(Some(123)),
496 ScalarValue::UInt32(Some(123)),
497 ScalarValue::UInt64(Some(123)),
498 ScalarValue::Decimal128(Some(123), 3, 0),
499 ScalarValue::Decimal128(Some(12300), 8, 2),
500 ];
501
502 for s1 in &scalars {
503 for s2 in &scalars {
504 let expected_value = ExpectedCast::Value(s2.clone());
505
506 expect_cast(s1.clone(), s2.data_type(), expected_value);
507 }
508 }
509
510 let max_i32 = ScalarValue::Int32(Some(i32::MAX));
511 expect_cast(
512 max_i32,
513 DataType::UInt64,
514 ExpectedCast::Value(ScalarValue::UInt64(Some(i32::MAX as u64))),
515 );
516
517 let min_i32 = ScalarValue::Int32(Some(i32::MIN));
518 expect_cast(
519 min_i32,
520 DataType::Int64,
521 ExpectedCast::Value(ScalarValue::Int64(Some(i32::MIN as i64))),
522 );
523
524 let max_i64 = ScalarValue::Int64(Some(i64::MAX));
525 expect_cast(
526 max_i64,
527 DataType::UInt64,
528 ExpectedCast::Value(ScalarValue::UInt64(Some(i64::MAX as u64))),
529 );
530 }
531
532 #[test]
533 fn test_try_cast_to_type_int_out_of_range() {
534 let min_i32 = ScalarValue::Int32(Some(i32::MIN));
535 let min_i64 = ScalarValue::Int64(Some(i64::MIN));
536 let max_i64 = ScalarValue::Int64(Some(i64::MAX));
537 let max_u64 = ScalarValue::UInt64(Some(u64::MAX));
538
539 expect_cast(max_i64.clone(), DataType::Int8, ExpectedCast::NoValue);
540
541 expect_cast(max_i64.clone(), DataType::Int16, ExpectedCast::NoValue);
542
543 expect_cast(max_i64, DataType::Int32, ExpectedCast::NoValue);
544
545 expect_cast(max_u64, DataType::Int64, ExpectedCast::NoValue);
546
547 expect_cast(min_i64, DataType::UInt64, ExpectedCast::NoValue);
548
549 expect_cast(min_i32, DataType::UInt64, ExpectedCast::NoValue);
550
551 expect_cast(
553 ScalarValue::Decimal128(Some(99999999999999999999999999999999999900), 38, 0),
554 DataType::Int64,
555 ExpectedCast::NoValue,
556 );
557
558 expect_cast(
559 ScalarValue::Decimal128(Some(-9999999999999999999999999999999999), 37, 1),
560 DataType::Int64,
561 ExpectedCast::NoValue,
562 );
563 }
564
565 #[test]
566 fn test_try_decimal_cast_in_range() {
567 expect_cast(
568 ScalarValue::Decimal128(Some(12300), 5, 2),
569 DataType::Decimal128(3, 0),
570 ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 3, 0)),
571 );
572
573 expect_cast(
574 ScalarValue::Decimal128(Some(12300), 5, 2),
575 DataType::Decimal128(8, 0),
576 ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 8, 0)),
577 );
578
579 expect_cast(
580 ScalarValue::Decimal128(Some(12300), 5, 2),
581 DataType::Decimal128(8, 5),
582 ExpectedCast::Value(ScalarValue::Decimal128(Some(12300000), 8, 5)),
583 );
584 }
585
586 #[test]
587 fn test_try_decimal_cast_out_of_range() {
588 expect_cast(
590 ScalarValue::Decimal128(Some(12345), 5, 2),
591 DataType::Decimal128(3, 0),
592 ExpectedCast::NoValue,
593 );
594
595 expect_cast(
597 ScalarValue::Decimal128(Some(12300), 5, 2),
598 DataType::Decimal128(2, 0),
599 ExpectedCast::NoValue,
600 );
601 }
602
603 #[test]
604 fn test_try_cast_to_type_timestamps() {
605 for time_unit in [
606 TimeUnit::Second,
607 TimeUnit::Millisecond,
608 TimeUnit::Microsecond,
609 TimeUnit::Nanosecond,
610 ] {
611 let utc = Some("+00:00".into());
612 let (lit_tz_none, lit_tz_utc) = match time_unit {
614 TimeUnit::Second => (
615 ScalarValue::TimestampSecond(Some(12345), None),
616 ScalarValue::TimestampSecond(Some(12345), utc),
617 ),
618
619 TimeUnit::Millisecond => (
620 ScalarValue::TimestampMillisecond(Some(12345), None),
621 ScalarValue::TimestampMillisecond(Some(12345), utc),
622 ),
623
624 TimeUnit::Microsecond => (
625 ScalarValue::TimestampMicrosecond(Some(12345), None),
626 ScalarValue::TimestampMicrosecond(Some(12345), utc),
627 ),
628
629 TimeUnit::Nanosecond => (
630 ScalarValue::TimestampNanosecond(Some(12345), None),
631 ScalarValue::TimestampNanosecond(Some(12345), utc),
632 ),
633 };
634
635 assert_eq!(lit_tz_none, lit_tz_utc);
638
639 let dt_tz_none = lit_tz_none.data_type();
641
642 let dt_tz_utc = lit_tz_utc.data_type();
644
645 expect_cast(
647 lit_tz_none.clone(),
648 dt_tz_none.clone(),
649 ExpectedCast::Value(lit_tz_none.clone()),
650 );
651
652 expect_cast(
654 lit_tz_none.clone(),
655 dt_tz_utc.clone(),
656 ExpectedCast::Value(lit_tz_utc.clone()),
657 );
658
659 expect_cast(
661 lit_tz_utc.clone(),
662 dt_tz_none.clone(),
663 ExpectedCast::Value(lit_tz_none.clone()),
664 );
665
666 expect_cast(
668 lit_tz_utc.clone(),
669 dt_tz_utc.clone(),
670 ExpectedCast::Value(lit_tz_utc.clone()),
671 );
672
673 expect_cast(
675 lit_tz_utc.clone(),
676 DataType::Int64,
677 ExpectedCast::Value(ScalarValue::Int64(Some(12345))),
678 );
679
680 expect_cast(
682 ScalarValue::Int64(Some(12345)),
683 dt_tz_none.clone(),
684 ExpectedCast::Value(lit_tz_none.clone()),
685 );
686
687 expect_cast(
689 ScalarValue::Int64(Some(12345)),
690 dt_tz_utc.clone(),
691 ExpectedCast::Value(lit_tz_utc.clone()),
692 );
693
694 expect_cast(
696 lit_tz_utc.clone(),
697 DataType::LargeUtf8,
698 ExpectedCast::NoValue,
699 );
700 }
701 }
702
703 #[test]
704 fn test_try_cast_to_type_unsupported() {
705 expect_cast(
707 ScalarValue::Int64(Some(12345)),
708 DataType::List(Arc::new(Field::new("f", DataType::Int32, true))),
709 ExpectedCast::NoValue,
710 );
711 }
712
713 #[test]
714 fn test_try_cast_literal_to_timestamp() {
715 let new_scalar = try_cast_literal_to_type(
717 &ScalarValue::TimestampNanosecond(Some(123456), None),
718 &DataType::Timestamp(TimeUnit::Nanosecond, None),
719 )
720 .unwrap();
721
722 assert_eq!(
723 new_scalar,
724 ScalarValue::TimestampNanosecond(Some(123456), None)
725 );
726
727 let new_scalar = try_cast_literal_to_type(
729 &ScalarValue::TimestampNanosecond(Some(123456), None),
730 &DataType::Timestamp(TimeUnit::Microsecond, None),
731 )
732 .unwrap();
733
734 assert_eq!(
735 new_scalar,
736 ScalarValue::TimestampMicrosecond(Some(123), None)
737 );
738
739 let new_scalar = try_cast_literal_to_type(
741 &ScalarValue::TimestampNanosecond(Some(123456), None),
742 &DataType::Timestamp(TimeUnit::Millisecond, None),
743 )
744 .unwrap();
745
746 assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None));
747
748 let new_scalar = try_cast_literal_to_type(
750 &ScalarValue::TimestampNanosecond(Some(123456), None),
751 &DataType::Timestamp(TimeUnit::Second, None),
752 )
753 .unwrap();
754
755 assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(0), None));
756
757 let new_scalar = try_cast_literal_to_type(
759 &ScalarValue::TimestampMicrosecond(Some(123), None),
760 &DataType::Timestamp(TimeUnit::Nanosecond, None),
761 )
762 .unwrap();
763
764 assert_eq!(
765 new_scalar,
766 ScalarValue::TimestampNanosecond(Some(123000), None)
767 );
768
769 let new_scalar = try_cast_literal_to_type(
771 &ScalarValue::TimestampMicrosecond(Some(123), None),
772 &DataType::Timestamp(TimeUnit::Millisecond, None),
773 )
774 .unwrap();
775
776 assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None));
777
778 let new_scalar = try_cast_literal_to_type(
780 &ScalarValue::TimestampMicrosecond(Some(123456789), None),
781 &DataType::Timestamp(TimeUnit::Second, None),
782 )
783 .unwrap();
784 assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123), None));
785
786 let new_scalar = try_cast_literal_to_type(
788 &ScalarValue::TimestampMillisecond(Some(123), None),
789 &DataType::Timestamp(TimeUnit::Nanosecond, None),
790 )
791 .unwrap();
792 assert_eq!(
793 new_scalar,
794 ScalarValue::TimestampNanosecond(Some(123000000), None)
795 );
796
797 let new_scalar = try_cast_literal_to_type(
799 &ScalarValue::TimestampMillisecond(Some(123), None),
800 &DataType::Timestamp(TimeUnit::Microsecond, None),
801 )
802 .unwrap();
803 assert_eq!(
804 new_scalar,
805 ScalarValue::TimestampMicrosecond(Some(123000), None)
806 );
807 let new_scalar = try_cast_literal_to_type(
809 &ScalarValue::TimestampMillisecond(Some(123456789), None),
810 &DataType::Timestamp(TimeUnit::Second, None),
811 )
812 .unwrap();
813 assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123456), None));
814
815 let new_scalar = try_cast_literal_to_type(
817 &ScalarValue::TimestampSecond(Some(123), None),
818 &DataType::Timestamp(TimeUnit::Nanosecond, None),
819 )
820 .unwrap();
821 assert_eq!(
822 new_scalar,
823 ScalarValue::TimestampNanosecond(Some(123000000000), None)
824 );
825
826 let new_scalar = try_cast_literal_to_type(
828 &ScalarValue::TimestampSecond(Some(123), None),
829 &DataType::Timestamp(TimeUnit::Microsecond, None),
830 )
831 .unwrap();
832 assert_eq!(
833 new_scalar,
834 ScalarValue::TimestampMicrosecond(Some(123000000), None)
835 );
836
837 let new_scalar = try_cast_literal_to_type(
839 &ScalarValue::TimestampSecond(Some(123), None),
840 &DataType::Timestamp(TimeUnit::Millisecond, None),
841 )
842 .unwrap();
843 assert_eq!(
844 new_scalar,
845 ScalarValue::TimestampMillisecond(Some(123000), None)
846 );
847
848 let new_scalar = try_cast_literal_to_type(
850 &ScalarValue::TimestampSecond(Some(i64::MAX), None),
851 &DataType::Timestamp(TimeUnit::Millisecond, None),
852 )
853 .unwrap();
854 assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(None, None));
855 }
856
857 #[test]
858 fn test_try_cast_to_string_type() {
859 let scalars = vec![
860 ScalarValue::from("string"),
861 ScalarValue::LargeUtf8(Some("string".to_owned())),
862 ];
863
864 for s1 in &scalars {
865 for s2 in &scalars {
866 let expected_value = ExpectedCast::Value(s2.clone());
867
868 expect_cast(s1.clone(), s2.data_type(), expected_value);
869 }
870 }
871 }
872
873 #[test]
874 fn test_try_cast_to_dictionary_type() {
875 fn dictionary_type(t: DataType) -> DataType {
876 DataType::Dictionary(Box::new(DataType::Int32), Box::new(t))
877 }
878 fn dictionary_value(value: ScalarValue) -> ScalarValue {
879 ScalarValue::Dictionary(Box::new(DataType::Int32), Box::new(value))
880 }
881 let scalars = vec![
882 ScalarValue::from("string"),
883 ScalarValue::LargeUtf8(Some("string".to_owned())),
884 ];
885 for s in &scalars {
886 expect_cast(
887 s.clone(),
888 dictionary_type(s.data_type()),
889 ExpectedCast::Value(dictionary_value(s.clone())),
890 );
891 expect_cast(
892 dictionary_value(s.clone()),
893 s.data_type(),
894 ExpectedCast::Value(s.clone()),
895 )
896 }
897 }
898
899 #[test]
900 fn test_try_cast_to_fixed_size_binary() {
901 expect_cast(
902 ScalarValue::Binary(Some(vec![1, 2, 3])),
903 DataType::FixedSizeBinary(3),
904 ExpectedCast::Value(ScalarValue::FixedSizeBinary(3, Some(vec![1, 2, 3]))),
905 )
906 }
907
908 #[test]
909 fn test_numeric_boundary_values() {
910 expect_cast(
912 ScalarValue::Int8(Some(i8::MAX)),
913 DataType::UInt8,
914 ExpectedCast::Value(ScalarValue::UInt8(Some(i8::MAX as u8))),
915 );
916
917 expect_cast(
918 ScalarValue::Int8(Some(i8::MIN)),
919 DataType::UInt8,
920 ExpectedCast::NoValue,
921 );
922
923 expect_cast(
924 ScalarValue::UInt8(Some(u8::MAX)),
925 DataType::Int8,
926 ExpectedCast::NoValue,
927 );
928
929 expect_cast(
931 ScalarValue::Int32(Some(i32::MAX)),
932 DataType::Int64,
933 ExpectedCast::Value(ScalarValue::Int64(Some(i32::MAX as i64))),
934 );
935
936 expect_cast(
937 ScalarValue::Int64(Some(i64::MIN)),
938 DataType::UInt64,
939 ExpectedCast::NoValue,
940 );
941
942 expect_cast(
944 ScalarValue::UInt32(Some(u32::MAX)),
945 DataType::Int32,
946 ExpectedCast::NoValue,
947 );
948
949 expect_cast(
950 ScalarValue::UInt64(Some(u64::MAX)),
951 DataType::Int64,
952 ExpectedCast::NoValue,
953 );
954 }
955
956 #[test]
957 fn test_decimal_precision_limits() {
958 use arrow::datatypes::{
959 MAX_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL128_FOR_EACH_PRECISION,
960 };
961
962 expect_cast(
964 ScalarValue::Decimal128(Some(MAX_DECIMAL128_FOR_EACH_PRECISION[3]), 3, 0),
965 DataType::Decimal128(5, 0),
966 ExpectedCast::Value(ScalarValue::Decimal128(
967 Some(MAX_DECIMAL128_FOR_EACH_PRECISION[3]),
968 5,
969 0,
970 )),
971 );
972
973 expect_cast(
975 ScalarValue::Decimal128(Some(MIN_DECIMAL128_FOR_EACH_PRECISION[3]), 3, 0),
976 DataType::Decimal128(5, 0),
977 ExpectedCast::Value(ScalarValue::Decimal128(
978 Some(MIN_DECIMAL128_FOR_EACH_PRECISION[3]),
979 5,
980 0,
981 )),
982 );
983
984 expect_cast(
986 ScalarValue::Decimal128(Some(123), 3, 0),
987 DataType::Decimal128(5, 2),
988 ExpectedCast::Value(ScalarValue::Decimal128(Some(12300), 5, 2)),
989 );
990
991 expect_cast(
993 ScalarValue::Decimal128(Some(MAX_DECIMAL128_FOR_EACH_PRECISION[10]), 10, 0),
994 DataType::Decimal128(3, 0),
995 ExpectedCast::NoValue,
996 );
997
998 expect_cast(
1000 ScalarValue::Decimal128(Some(12345), 5, 3), DataType::Int32,
1002 ExpectedCast::NoValue, );
1004
1005 expect_cast(
1007 ScalarValue::Decimal128(Some(12345), 5, 2), DataType::Decimal128(3, 0), ExpectedCast::NoValue,
1010 );
1011 }
1012
1013 #[test]
1014 fn test_timestamp_overflow_scenarios() {
1015 let max_seconds = i64::MAX / 1_000_000_000; expect_cast(
1020 ScalarValue::TimestampSecond(Some(max_seconds), None),
1021 DataType::Timestamp(TimeUnit::Nanosecond, None),
1022 ExpectedCast::Value(ScalarValue::TimestampNanosecond(
1023 Some(max_seconds * 1_000_000_000),
1024 None,
1025 )),
1026 );
1027
1028 expect_cast(
1030 ScalarValue::TimestampNanosecond(Some(i64::MAX), None),
1031 DataType::Timestamp(TimeUnit::Second, None),
1032 ExpectedCast::Value(ScalarValue::TimestampSecond(
1033 Some(i64::MAX / 1_000_000_000),
1034 None,
1035 )),
1036 );
1037
1038 expect_cast(
1040 ScalarValue::TimestampNanosecond(Some(1), None),
1041 DataType::Timestamp(TimeUnit::Second, None),
1042 ExpectedCast::Value(ScalarValue::TimestampSecond(Some(0), None)),
1043 );
1044
1045 expect_cast(
1046 ScalarValue::TimestampMicrosecond(Some(999), None),
1047 DataType::Timestamp(TimeUnit::Millisecond, None),
1048 ExpectedCast::Value(ScalarValue::TimestampMillisecond(Some(0), None)),
1049 );
1050 }
1051
1052 #[test]
1053 fn test_string_view() {
1054 expect_cast(
1056 ScalarValue::Utf8View(Some("test".to_string())),
1057 DataType::Utf8,
1058 ExpectedCast::Value(ScalarValue::Utf8(Some("test".to_string()))),
1059 );
1060
1061 expect_cast(
1062 ScalarValue::Utf8View(Some("test".to_string())),
1063 DataType::LargeUtf8,
1064 ExpectedCast::Value(ScalarValue::LargeUtf8(Some("test".to_string()))),
1065 );
1066
1067 expect_cast(
1069 ScalarValue::Utf8(Some("hello".to_string())),
1070 DataType::Utf8View,
1071 ExpectedCast::Value(ScalarValue::Utf8View(Some("hello".to_string()))),
1072 );
1073
1074 expect_cast(
1075 ScalarValue::LargeUtf8(Some("world".to_string())),
1076 DataType::Utf8View,
1077 ExpectedCast::Value(ScalarValue::Utf8View(Some("world".to_string()))),
1078 );
1079
1080 expect_cast(
1082 ScalarValue::Utf8(Some("".to_string())),
1083 DataType::Utf8View,
1084 ExpectedCast::Value(ScalarValue::Utf8View(Some("".to_string()))),
1085 );
1086
1087 let large_string = "x".repeat(1000);
1089 expect_cast(
1090 ScalarValue::LargeUtf8(Some(large_string.clone())),
1091 DataType::Utf8View,
1092 ExpectedCast::Value(ScalarValue::Utf8View(Some(large_string))),
1093 );
1094 }
1095
1096 #[test]
1097 fn test_binary_size_edge_cases() {
1098 expect_cast(
1100 ScalarValue::Binary(Some(vec![1, 2])),
1101 DataType::FixedSizeBinary(3),
1102 ExpectedCast::NoValue,
1103 );
1104
1105 expect_cast(
1107 ScalarValue::Binary(Some(vec![1, 2, 3, 4])),
1108 DataType::FixedSizeBinary(3),
1109 ExpectedCast::NoValue,
1110 );
1111
1112 expect_cast(
1114 ScalarValue::Binary(Some(vec![])),
1115 DataType::FixedSizeBinary(0),
1116 ExpectedCast::Value(ScalarValue::FixedSizeBinary(0, Some(vec![]))),
1117 );
1118
1119 expect_cast(
1121 ScalarValue::Binary(Some(vec![1, 2, 3])),
1122 DataType::FixedSizeBinary(3),
1123 ExpectedCast::Value(ScalarValue::FixedSizeBinary(3, Some(vec![1, 2, 3]))),
1124 );
1125
1126 expect_cast(
1128 ScalarValue::Binary(Some(vec![42])),
1129 DataType::FixedSizeBinary(1),
1130 ExpectedCast::Value(ScalarValue::FixedSizeBinary(1, Some(vec![42]))),
1131 );
1132 }
1133
1134 #[test]
1135 fn test_dictionary_index_types() {
1136 let string_value = ScalarValue::Utf8(Some("test".to_string()));
1138
1139 let dict_int8 =
1141 DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8));
1142 expect_cast(
1143 string_value.clone(),
1144 dict_int8,
1145 ExpectedCast::Value(ScalarValue::Dictionary(
1146 Box::new(DataType::Int8),
1147 Box::new(string_value.clone()),
1148 )),
1149 );
1150
1151 let dict_int16 =
1153 DataType::Dictionary(Box::new(DataType::Int16), Box::new(DataType::Utf8));
1154 expect_cast(
1155 string_value.clone(),
1156 dict_int16,
1157 ExpectedCast::Value(ScalarValue::Dictionary(
1158 Box::new(DataType::Int16),
1159 Box::new(string_value.clone()),
1160 )),
1161 );
1162
1163 let dict_int64 =
1165 DataType::Dictionary(Box::new(DataType::Int64), Box::new(DataType::Utf8));
1166 expect_cast(
1167 string_value.clone(),
1168 dict_int64,
1169 ExpectedCast::Value(ScalarValue::Dictionary(
1170 Box::new(DataType::Int64),
1171 Box::new(string_value.clone()),
1172 )),
1173 );
1174
1175 let dict_value = ScalarValue::Dictionary(
1177 Box::new(DataType::Int32),
1178 Box::new(ScalarValue::LargeUtf8(Some("unwrap_test".to_string()))),
1179 );
1180 expect_cast(
1181 dict_value,
1182 DataType::LargeUtf8,
1183 ExpectedCast::Value(ScalarValue::LargeUtf8(Some("unwrap_test".to_string()))),
1184 );
1185 }
1186
    /// Exercises each `is_supported_*` predicate directly with accepted and
    /// rejected types.
    #[test]
    fn test_type_support_functions() {
        // Numeric predicate: integers, decimals and timestamps are accepted,
        // floats are rejected.
        assert!(is_supported_numeric_type(&DataType::Int8));
        assert!(is_supported_numeric_type(&DataType::UInt64));
        assert!(is_supported_numeric_type(&DataType::Decimal128(10, 2)));
        assert!(is_supported_numeric_type(&DataType::Timestamp(
            TimeUnit::Nanosecond,
            None
        )));
        assert!(!is_supported_numeric_type(&DataType::Float32));
        assert!(!is_supported_numeric_type(&DataType::Float64));

        // String predicate: all three UTF-8 flavours, but not binary.
        assert!(is_supported_string_type(&DataType::Utf8));
        assert!(is_supported_string_type(&DataType::LargeUtf8));
        assert!(is_supported_string_type(&DataType::Utf8View));
        assert!(!is_supported_string_type(&DataType::Binary));

        // Binary predicate: variable and fixed-size binary, but not strings.
        assert!(is_supported_binary_type(&DataType::Binary));
        assert!(is_supported_binary_type(&DataType::FixedSizeBinary(10)));
        assert!(!is_supported_binary_type(&DataType::Utf8));

        // Dictionary predicate: depends on whether the value type is
        // supported.
        assert!(is_supported_dictionary_type(&DataType::Dictionary(
            Box::new(DataType::Int32),
            Box::new(DataType::Utf8)
        )));
        assert!(is_supported_dictionary_type(&DataType::Dictionary(
            Box::new(DataType::Int32),
            Box::new(DataType::Int64)
        )));
        assert!(!is_supported_dictionary_type(&DataType::Dictionary(
            Box::new(DataType::Int32),
            Box::new(DataType::List(Arc::new(Field::new(
                "item",
                DataType::Int32,
                true
            ))))
        )));

        // Combined predicate: accepts each category above, rejects list and
        // struct types.
        assert!(is_supported_type(&DataType::Int32));
        assert!(is_supported_type(&DataType::Utf8));
        assert!(is_supported_type(&DataType::Binary));
        assert!(is_supported_type(&DataType::Dictionary(
            Box::new(DataType::Int32),
            Box::new(DataType::Utf8)
        )));
        assert!(!is_supported_type(&DataType::List(Arc::new(Field::new(
            "item",
            DataType::Int32,
            true
        )))));
        assert!(!is_supported_type(&DataType::Struct(Fields::empty())));
    }
1244
1245 #[test]
1246 fn test_error_conditions() {
1247 expect_cast(
1249 ScalarValue::Float32(Some(1.5)),
1250 DataType::Int32,
1251 ExpectedCast::NoValue,
1252 );
1253
1254 expect_cast(
1256 ScalarValue::Int32(Some(123)),
1257 DataType::Float64,
1258 ExpectedCast::NoValue,
1259 );
1260
1261 expect_cast(
1263 ScalarValue::Float64(Some(1.5)),
1264 DataType::Float32,
1265 ExpectedCast::NoValue,
1266 );
1267
1268 let list_type =
1270 DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
1271 expect_cast(
1272 ScalarValue::Int32(Some(123)),
1273 list_type,
1274 ExpectedCast::NoValue,
1275 );
1276
1277 let bad_dict = DataType::Dictionary(
1279 Box::new(DataType::Int32),
1280 Box::new(DataType::List(Arc::new(Field::new(
1281 "item",
1282 DataType::Int32,
1283 true,
1284 )))),
1285 );
1286 expect_cast(
1287 ScalarValue::Int32(Some(123)),
1288 bad_dict,
1289 ExpectedCast::NoValue,
1290 );
1291 }
1292}