1use std::cmp::Ordering;
25
26use arrow::datatypes::{
27 DataType, TimeUnit, MAX_DECIMAL128_FOR_EACH_PRECISION,
28 MIN_DECIMAL128_FOR_EACH_PRECISION,
29};
30use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS};
31use datafusion_common::ScalarValue;
32
33pub fn try_cast_literal_to_type(
35 lit_value: &ScalarValue,
36 target_type: &DataType,
37) -> Option<ScalarValue> {
38 let lit_data_type = lit_value.data_type();
39 if !is_supported_type(&lit_data_type) || !is_supported_type(target_type) {
40 return None;
41 }
42 if lit_value.is_null() {
43 return ScalarValue::try_from(target_type).ok();
45 }
46 try_cast_numeric_literal(lit_value, target_type)
47 .or_else(|| try_cast_string_literal(lit_value, target_type))
48 .or_else(|| try_cast_dictionary(lit_value, target_type))
49 .or_else(|| try_cast_binary(lit_value, target_type))
50}
51
52pub fn is_supported_type(data_type: &DataType) -> bool {
54 is_supported_numeric_type(data_type)
55 || is_supported_string_type(data_type)
56 || is_supported_dictionary_type(data_type)
57 || is_supported_binary_type(data_type)
58}
59
60fn is_supported_numeric_type(data_type: &DataType) -> bool {
62 matches!(
63 data_type,
64 DataType::UInt8
65 | DataType::UInt16
66 | DataType::UInt32
67 | DataType::UInt64
68 | DataType::Int8
69 | DataType::Int16
70 | DataType::Int32
71 | DataType::Int64
72 | DataType::Decimal128(_, _)
73 | DataType::Timestamp(_, _)
74 )
75}
76
77fn is_supported_string_type(data_type: &DataType) -> bool {
79 matches!(
80 data_type,
81 DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View
82 )
83}
84
85fn is_supported_dictionary_type(data_type: &DataType) -> bool {
87 matches!(data_type,
88 DataType::Dictionary(_, inner) if is_supported_type(inner))
89}
90
91fn is_supported_binary_type(data_type: &DataType) -> bool {
92 matches!(data_type, DataType::Binary | DataType::FixedSizeBinary(_))
93}
94
95fn try_cast_numeric_literal(
97 lit_value: &ScalarValue,
98 target_type: &DataType,
99) -> Option<ScalarValue> {
100 let lit_data_type = lit_value.data_type();
101 if !is_supported_numeric_type(&lit_data_type)
102 || !is_supported_numeric_type(target_type)
103 {
104 return None;
105 }
106
107 let mul = match target_type {
108 DataType::UInt8
109 | DataType::UInt16
110 | DataType::UInt32
111 | DataType::UInt64
112 | DataType::Int8
113 | DataType::Int16
114 | DataType::Int32
115 | DataType::Int64 => 1_i128,
116 DataType::Timestamp(_, _) => 1_i128,
117 DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32),
118 _ => return None,
119 };
120 let (target_min, target_max) = match target_type {
121 DataType::UInt8 => (u8::MIN as i128, u8::MAX as i128),
122 DataType::UInt16 => (u16::MIN as i128, u16::MAX as i128),
123 DataType::UInt32 => (u32::MIN as i128, u32::MAX as i128),
124 DataType::UInt64 => (u64::MIN as i128, u64::MAX as i128),
125 DataType::Int8 => (i8::MIN as i128, i8::MAX as i128),
126 DataType::Int16 => (i16::MIN as i128, i16::MAX as i128),
127 DataType::Int32 => (i32::MIN as i128, i32::MAX as i128),
128 DataType::Int64 => (i64::MIN as i128, i64::MAX as i128),
129 DataType::Timestamp(_, _) => (i64::MIN as i128, i64::MAX as i128),
130 DataType::Decimal128(precision, _) => (
131 MIN_DECIMAL128_FOR_EACH_PRECISION[*precision as usize],
135 MAX_DECIMAL128_FOR_EACH_PRECISION[*precision as usize],
136 ),
137 _ => return None,
138 };
139 let lit_value_target_type = match lit_value {
140 ScalarValue::Int8(Some(v)) => (*v as i128).checked_mul(mul),
141 ScalarValue::Int16(Some(v)) => (*v as i128).checked_mul(mul),
142 ScalarValue::Int32(Some(v)) => (*v as i128).checked_mul(mul),
143 ScalarValue::Int64(Some(v)) => (*v as i128).checked_mul(mul),
144 ScalarValue::UInt8(Some(v)) => (*v as i128).checked_mul(mul),
145 ScalarValue::UInt16(Some(v)) => (*v as i128).checked_mul(mul),
146 ScalarValue::UInt32(Some(v)) => (*v as i128).checked_mul(mul),
147 ScalarValue::UInt64(Some(v)) => (*v as i128).checked_mul(mul),
148 ScalarValue::TimestampSecond(Some(v), _) => (*v as i128).checked_mul(mul),
149 ScalarValue::TimestampMillisecond(Some(v), _) => (*v as i128).checked_mul(mul),
150 ScalarValue::TimestampMicrosecond(Some(v), _) => (*v as i128).checked_mul(mul),
151 ScalarValue::TimestampNanosecond(Some(v), _) => (*v as i128).checked_mul(mul),
152 ScalarValue::Decimal128(Some(v), _, scale) => {
153 let lit_scale_mul = 10_i128.pow(*scale as u32);
154 if mul >= lit_scale_mul {
155 (*v).checked_mul(mul / lit_scale_mul)
160 } else if (*v) % (lit_scale_mul / mul) == 0 {
161 Some(*v / (lit_scale_mul / mul))
166 } else {
167 None
169 }
170 }
171 _ => None,
172 };
173
174 match lit_value_target_type {
175 None => None,
176 Some(value) => {
177 if value >= target_min && value <= target_max {
178 let result_scalar = match target_type {
181 DataType::Int8 => ScalarValue::Int8(Some(value as i8)),
182 DataType::Int16 => ScalarValue::Int16(Some(value as i16)),
183 DataType::Int32 => ScalarValue::Int32(Some(value as i32)),
184 DataType::Int64 => ScalarValue::Int64(Some(value as i64)),
185 DataType::UInt8 => ScalarValue::UInt8(Some(value as u8)),
186 DataType::UInt16 => ScalarValue::UInt16(Some(value as u16)),
187 DataType::UInt32 => ScalarValue::UInt32(Some(value as u32)),
188 DataType::UInt64 => ScalarValue::UInt64(Some(value as u64)),
189 DataType::Timestamp(TimeUnit::Second, tz) => {
190 let value = cast_between_timestamp(
191 &lit_data_type,
192 &DataType::Timestamp(TimeUnit::Second, tz.clone()),
193 value,
194 );
195 ScalarValue::TimestampSecond(value, tz.clone())
196 }
197 DataType::Timestamp(TimeUnit::Millisecond, tz) => {
198 let value = cast_between_timestamp(
199 &lit_data_type,
200 &DataType::Timestamp(TimeUnit::Millisecond, tz.clone()),
201 value,
202 );
203 ScalarValue::TimestampMillisecond(value, tz.clone())
204 }
205 DataType::Timestamp(TimeUnit::Microsecond, tz) => {
206 let value = cast_between_timestamp(
207 &lit_data_type,
208 &DataType::Timestamp(TimeUnit::Microsecond, tz.clone()),
209 value,
210 );
211 ScalarValue::TimestampMicrosecond(value, tz.clone())
212 }
213 DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
214 let value = cast_between_timestamp(
215 &lit_data_type,
216 &DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()),
217 value,
218 );
219 ScalarValue::TimestampNanosecond(value, tz.clone())
220 }
221 DataType::Decimal128(p, s) => {
222 ScalarValue::Decimal128(Some(value), *p, *s)
223 }
224 _ => {
225 return None;
226 }
227 };
228 Some(result_scalar)
229 } else {
230 None
231 }
232 }
233 }
234}
235
236fn try_cast_string_literal(
237 lit_value: &ScalarValue,
238 target_type: &DataType,
239) -> Option<ScalarValue> {
240 let string_value = lit_value.try_as_str()?.map(|s| s.to_string());
241 let scalar_value = match target_type {
242 DataType::Utf8 => ScalarValue::Utf8(string_value),
243 DataType::LargeUtf8 => ScalarValue::LargeUtf8(string_value),
244 DataType::Utf8View => ScalarValue::Utf8View(string_value),
245 _ => return None,
246 };
247 Some(scalar_value)
248}
249
250fn try_cast_dictionary(
252 lit_value: &ScalarValue,
253 target_type: &DataType,
254) -> Option<ScalarValue> {
255 let lit_value_type = lit_value.data_type();
256 let result_scalar = match (lit_value, target_type) {
257 (ScalarValue::Dictionary(_, inner_value), _)
259 if inner_value.data_type() == *target_type =>
260 {
261 (**inner_value).clone()
262 }
263 (_, DataType::Dictionary(index_type, inner_type))
265 if **inner_type == lit_value_type =>
266 {
267 ScalarValue::Dictionary(index_type.clone(), Box::new(lit_value.clone()))
268 }
269 _ => {
270 return None;
271 }
272 };
273 Some(result_scalar)
274}
275
276fn cast_between_timestamp(from: &DataType, to: &DataType, value: i128) -> Option<i64> {
278 let value = value as i64;
279 let from_scale = match from {
280 DataType::Timestamp(TimeUnit::Second, _) => 1,
281 DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS,
282 DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS,
283 DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS,
284 _ => return Some(value),
285 };
286
287 let to_scale = match to {
288 DataType::Timestamp(TimeUnit::Second, _) => 1,
289 DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS,
290 DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS,
291 DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS,
292 _ => return Some(value),
293 };
294
295 match from_scale.cmp(&to_scale) {
296 Ordering::Less => value.checked_mul(to_scale / from_scale),
297 Ordering::Greater => Some(value / (from_scale / to_scale)),
298 Ordering::Equal => Some(value),
299 }
300}
301
302fn try_cast_binary(
303 lit_value: &ScalarValue,
304 target_type: &DataType,
305) -> Option<ScalarValue> {
306 match (lit_value, target_type) {
307 (ScalarValue::Binary(Some(v)), DataType::FixedSizeBinary(n))
308 if v.len() == *n as usize =>
309 {
310 Some(ScalarValue::FixedSizeBinary(*n, Some(v.clone())))
311 }
312 _ => None,
313 }
314}
315
316#[cfg(test)]
317mod tests {
318 use super::*;
319 use arrow::compute::{cast_with_options, CastOptions};
320 use arrow::datatypes::{Field, Fields, TimeUnit};
321 use std::sync::Arc;
322
    /// Outcome a test expects from `try_cast_literal_to_type`.
    #[derive(Debug, Clone)]
    enum ExpectedCast {
        /// The cast must succeed and produce exactly this scalar.
        Value(ScalarValue),
        /// The cast must be rejected, i.e. return `None`.
        NoValue,
    }
330
331 fn expect_cast(
335 literal: ScalarValue,
336 target_type: DataType,
337 expected_result: ExpectedCast,
338 ) {
339 let actual_value = try_cast_literal_to_type(&literal, &target_type);
340
341 println!("expect_cast: ");
342 println!(" {literal:?} --> {target_type:?}");
343 println!(" expected_result: {expected_result:?}");
344 println!(" actual_result: {actual_value:?}");
345
346 match expected_result {
347 ExpectedCast::Value(expected_value) => {
348 let actual_value =
349 actual_value.expect("Expected cast value but got None");
350
351 assert_eq!(actual_value, expected_value);
352
353 let literal_array = literal
357 .to_array_of_size(1)
358 .expect("Failed to convert to array of size");
359 let expected_array = expected_value
360 .to_array_of_size(1)
361 .expect("Failed to convert to array of size");
362 let cast_array = cast_with_options(
363 &literal_array,
364 &target_type,
365 &CastOptions::default(),
366 )
367 .expect("Expected to be cast array with arrow cast kernel");
368
369 assert_eq!(
370 &expected_array, &cast_array,
371 "Result of casting {literal:?} with arrow was\n {cast_array:#?}\nbut expected\n{expected_array:#?}"
372 );
373
374 if let (
377 DataType::Timestamp(left_unit, left_tz),
378 DataType::Timestamp(right_unit, right_tz),
379 ) = (actual_value.data_type(), expected_value.data_type())
380 {
381 assert_eq!(left_unit, right_unit);
382 assert_eq!(left_tz, right_tz);
383 }
384 }
385 ExpectedCast::NoValue => {
386 assert!(
387 actual_value.is_none(),
388 "Expected no cast value, but got {actual_value:?}"
389 );
390 }
391 }
392 }
393
394 #[test]
395 fn test_try_cast_to_type_nulls() {
396 let scalars = vec![
398 ScalarValue::Int8(None),
399 ScalarValue::Int16(None),
400 ScalarValue::Int32(None),
401 ScalarValue::Int64(None),
402 ScalarValue::UInt8(None),
403 ScalarValue::UInt16(None),
404 ScalarValue::UInt32(None),
405 ScalarValue::UInt64(None),
406 ScalarValue::Decimal128(None, 3, 0),
407 ScalarValue::Decimal128(None, 8, 2),
408 ScalarValue::Utf8(None),
409 ScalarValue::LargeUtf8(None),
410 ];
411
412 for s1 in &scalars {
413 for s2 in &scalars {
414 let expected_value = ExpectedCast::Value(s2.clone());
415
416 expect_cast(s1.clone(), s2.data_type(), expected_value);
417 }
418 }
419 }
420
421 #[test]
422 fn test_try_cast_to_type_int_in_range() {
423 let scalars = vec![
425 ScalarValue::Int8(Some(123)),
426 ScalarValue::Int16(Some(123)),
427 ScalarValue::Int32(Some(123)),
428 ScalarValue::Int64(Some(123)),
429 ScalarValue::UInt8(Some(123)),
430 ScalarValue::UInt16(Some(123)),
431 ScalarValue::UInt32(Some(123)),
432 ScalarValue::UInt64(Some(123)),
433 ScalarValue::Decimal128(Some(123), 3, 0),
434 ScalarValue::Decimal128(Some(12300), 8, 2),
435 ];
436
437 for s1 in &scalars {
438 for s2 in &scalars {
439 let expected_value = ExpectedCast::Value(s2.clone());
440
441 expect_cast(s1.clone(), s2.data_type(), expected_value);
442 }
443 }
444
445 let max_i32 = ScalarValue::Int32(Some(i32::MAX));
446 expect_cast(
447 max_i32,
448 DataType::UInt64,
449 ExpectedCast::Value(ScalarValue::UInt64(Some(i32::MAX as u64))),
450 );
451
452 let min_i32 = ScalarValue::Int32(Some(i32::MIN));
453 expect_cast(
454 min_i32,
455 DataType::Int64,
456 ExpectedCast::Value(ScalarValue::Int64(Some(i32::MIN as i64))),
457 );
458
459 let max_i64 = ScalarValue::Int64(Some(i64::MAX));
460 expect_cast(
461 max_i64,
462 DataType::UInt64,
463 ExpectedCast::Value(ScalarValue::UInt64(Some(i64::MAX as u64))),
464 );
465 }
466
467 #[test]
468 fn test_try_cast_to_type_int_out_of_range() {
469 let min_i32 = ScalarValue::Int32(Some(i32::MIN));
470 let min_i64 = ScalarValue::Int64(Some(i64::MIN));
471 let max_i64 = ScalarValue::Int64(Some(i64::MAX));
472 let max_u64 = ScalarValue::UInt64(Some(u64::MAX));
473
474 expect_cast(max_i64.clone(), DataType::Int8, ExpectedCast::NoValue);
475
476 expect_cast(max_i64.clone(), DataType::Int16, ExpectedCast::NoValue);
477
478 expect_cast(max_i64, DataType::Int32, ExpectedCast::NoValue);
479
480 expect_cast(max_u64, DataType::Int64, ExpectedCast::NoValue);
481
482 expect_cast(min_i64, DataType::UInt64, ExpectedCast::NoValue);
483
484 expect_cast(min_i32, DataType::UInt64, ExpectedCast::NoValue);
485
486 expect_cast(
488 ScalarValue::Decimal128(Some(99999999999999999999999999999999999900), 38, 0),
489 DataType::Int64,
490 ExpectedCast::NoValue,
491 );
492
493 expect_cast(
494 ScalarValue::Decimal128(Some(-9999999999999999999999999999999999), 37, 1),
495 DataType::Int64,
496 ExpectedCast::NoValue,
497 );
498 }
499
500 #[test]
501 fn test_try_decimal_cast_in_range() {
502 expect_cast(
503 ScalarValue::Decimal128(Some(12300), 5, 2),
504 DataType::Decimal128(3, 0),
505 ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 3, 0)),
506 );
507
508 expect_cast(
509 ScalarValue::Decimal128(Some(12300), 5, 2),
510 DataType::Decimal128(8, 0),
511 ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 8, 0)),
512 );
513
514 expect_cast(
515 ScalarValue::Decimal128(Some(12300), 5, 2),
516 DataType::Decimal128(8, 5),
517 ExpectedCast::Value(ScalarValue::Decimal128(Some(12300000), 8, 5)),
518 );
519 }
520
521 #[test]
522 fn test_try_decimal_cast_out_of_range() {
523 expect_cast(
525 ScalarValue::Decimal128(Some(12345), 5, 2),
526 DataType::Decimal128(3, 0),
527 ExpectedCast::NoValue,
528 );
529
530 expect_cast(
532 ScalarValue::Decimal128(Some(12300), 5, 2),
533 DataType::Decimal128(2, 0),
534 ExpectedCast::NoValue,
535 );
536 }
537
538 #[test]
539 fn test_try_cast_to_type_timestamps() {
540 for time_unit in [
541 TimeUnit::Second,
542 TimeUnit::Millisecond,
543 TimeUnit::Microsecond,
544 TimeUnit::Nanosecond,
545 ] {
546 let utc = Some("+00:00".into());
547 let (lit_tz_none, lit_tz_utc) = match time_unit {
549 TimeUnit::Second => (
550 ScalarValue::TimestampSecond(Some(12345), None),
551 ScalarValue::TimestampSecond(Some(12345), utc),
552 ),
553
554 TimeUnit::Millisecond => (
555 ScalarValue::TimestampMillisecond(Some(12345), None),
556 ScalarValue::TimestampMillisecond(Some(12345), utc),
557 ),
558
559 TimeUnit::Microsecond => (
560 ScalarValue::TimestampMicrosecond(Some(12345), None),
561 ScalarValue::TimestampMicrosecond(Some(12345), utc),
562 ),
563
564 TimeUnit::Nanosecond => (
565 ScalarValue::TimestampNanosecond(Some(12345), None),
566 ScalarValue::TimestampNanosecond(Some(12345), utc),
567 ),
568 };
569
570 assert_eq!(lit_tz_none, lit_tz_utc);
573
574 let dt_tz_none = lit_tz_none.data_type();
576
577 let dt_tz_utc = lit_tz_utc.data_type();
579
580 expect_cast(
582 lit_tz_none.clone(),
583 dt_tz_none.clone(),
584 ExpectedCast::Value(lit_tz_none.clone()),
585 );
586
587 expect_cast(
589 lit_tz_none.clone(),
590 dt_tz_utc.clone(),
591 ExpectedCast::Value(lit_tz_utc.clone()),
592 );
593
594 expect_cast(
596 lit_tz_utc.clone(),
597 dt_tz_none.clone(),
598 ExpectedCast::Value(lit_tz_none.clone()),
599 );
600
601 expect_cast(
603 lit_tz_utc.clone(),
604 dt_tz_utc.clone(),
605 ExpectedCast::Value(lit_tz_utc.clone()),
606 );
607
608 expect_cast(
610 lit_tz_utc.clone(),
611 DataType::Int64,
612 ExpectedCast::Value(ScalarValue::Int64(Some(12345))),
613 );
614
615 expect_cast(
617 ScalarValue::Int64(Some(12345)),
618 dt_tz_none.clone(),
619 ExpectedCast::Value(lit_tz_none.clone()),
620 );
621
622 expect_cast(
624 ScalarValue::Int64(Some(12345)),
625 dt_tz_utc.clone(),
626 ExpectedCast::Value(lit_tz_utc.clone()),
627 );
628
629 expect_cast(
631 lit_tz_utc.clone(),
632 DataType::LargeUtf8,
633 ExpectedCast::NoValue,
634 );
635 }
636 }
637
638 #[test]
639 fn test_try_cast_to_type_unsupported() {
640 expect_cast(
642 ScalarValue::Int64(Some(12345)),
643 DataType::List(Arc::new(Field::new("f", DataType::Int32, true))),
644 ExpectedCast::NoValue,
645 );
646 }
647
648 #[test]
649 fn test_try_cast_literal_to_timestamp() {
650 let new_scalar = try_cast_literal_to_type(
652 &ScalarValue::TimestampNanosecond(Some(123456), None),
653 &DataType::Timestamp(TimeUnit::Nanosecond, None),
654 )
655 .unwrap();
656
657 assert_eq!(
658 new_scalar,
659 ScalarValue::TimestampNanosecond(Some(123456), None)
660 );
661
662 let new_scalar = try_cast_literal_to_type(
664 &ScalarValue::TimestampNanosecond(Some(123456), None),
665 &DataType::Timestamp(TimeUnit::Microsecond, None),
666 )
667 .unwrap();
668
669 assert_eq!(
670 new_scalar,
671 ScalarValue::TimestampMicrosecond(Some(123), None)
672 );
673
674 let new_scalar = try_cast_literal_to_type(
676 &ScalarValue::TimestampNanosecond(Some(123456), None),
677 &DataType::Timestamp(TimeUnit::Millisecond, None),
678 )
679 .unwrap();
680
681 assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None));
682
683 let new_scalar = try_cast_literal_to_type(
685 &ScalarValue::TimestampNanosecond(Some(123456), None),
686 &DataType::Timestamp(TimeUnit::Second, None),
687 )
688 .unwrap();
689
690 assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(0), None));
691
692 let new_scalar = try_cast_literal_to_type(
694 &ScalarValue::TimestampMicrosecond(Some(123), None),
695 &DataType::Timestamp(TimeUnit::Nanosecond, None),
696 )
697 .unwrap();
698
699 assert_eq!(
700 new_scalar,
701 ScalarValue::TimestampNanosecond(Some(123000), None)
702 );
703
704 let new_scalar = try_cast_literal_to_type(
706 &ScalarValue::TimestampMicrosecond(Some(123), None),
707 &DataType::Timestamp(TimeUnit::Millisecond, None),
708 )
709 .unwrap();
710
711 assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None));
712
713 let new_scalar = try_cast_literal_to_type(
715 &ScalarValue::TimestampMicrosecond(Some(123456789), None),
716 &DataType::Timestamp(TimeUnit::Second, None),
717 )
718 .unwrap();
719 assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123), None));
720
721 let new_scalar = try_cast_literal_to_type(
723 &ScalarValue::TimestampMillisecond(Some(123), None),
724 &DataType::Timestamp(TimeUnit::Nanosecond, None),
725 )
726 .unwrap();
727 assert_eq!(
728 new_scalar,
729 ScalarValue::TimestampNanosecond(Some(123000000), None)
730 );
731
732 let new_scalar = try_cast_literal_to_type(
734 &ScalarValue::TimestampMillisecond(Some(123), None),
735 &DataType::Timestamp(TimeUnit::Microsecond, None),
736 )
737 .unwrap();
738 assert_eq!(
739 new_scalar,
740 ScalarValue::TimestampMicrosecond(Some(123000), None)
741 );
742 let new_scalar = try_cast_literal_to_type(
744 &ScalarValue::TimestampMillisecond(Some(123456789), None),
745 &DataType::Timestamp(TimeUnit::Second, None),
746 )
747 .unwrap();
748 assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123456), None));
749
750 let new_scalar = try_cast_literal_to_type(
752 &ScalarValue::TimestampSecond(Some(123), None),
753 &DataType::Timestamp(TimeUnit::Nanosecond, None),
754 )
755 .unwrap();
756 assert_eq!(
757 new_scalar,
758 ScalarValue::TimestampNanosecond(Some(123000000000), None)
759 );
760
761 let new_scalar = try_cast_literal_to_type(
763 &ScalarValue::TimestampSecond(Some(123), None),
764 &DataType::Timestamp(TimeUnit::Microsecond, None),
765 )
766 .unwrap();
767 assert_eq!(
768 new_scalar,
769 ScalarValue::TimestampMicrosecond(Some(123000000), None)
770 );
771
772 let new_scalar = try_cast_literal_to_type(
774 &ScalarValue::TimestampSecond(Some(123), None),
775 &DataType::Timestamp(TimeUnit::Millisecond, None),
776 )
777 .unwrap();
778 assert_eq!(
779 new_scalar,
780 ScalarValue::TimestampMillisecond(Some(123000), None)
781 );
782
783 let new_scalar = try_cast_literal_to_type(
785 &ScalarValue::TimestampSecond(Some(i64::MAX), None),
786 &DataType::Timestamp(TimeUnit::Millisecond, None),
787 )
788 .unwrap();
789 assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(None, None));
790 }
791
792 #[test]
793 fn test_try_cast_to_string_type() {
794 let scalars = vec![
795 ScalarValue::from("string"),
796 ScalarValue::LargeUtf8(Some("string".to_owned())),
797 ];
798
799 for s1 in &scalars {
800 for s2 in &scalars {
801 let expected_value = ExpectedCast::Value(s2.clone());
802
803 expect_cast(s1.clone(), s2.data_type(), expected_value);
804 }
805 }
806 }
807
808 #[test]
809 fn test_try_cast_to_dictionary_type() {
810 fn dictionary_type(t: DataType) -> DataType {
811 DataType::Dictionary(Box::new(DataType::Int32), Box::new(t))
812 }
813 fn dictionary_value(value: ScalarValue) -> ScalarValue {
814 ScalarValue::Dictionary(Box::new(DataType::Int32), Box::new(value))
815 }
816 let scalars = vec![
817 ScalarValue::from("string"),
818 ScalarValue::LargeUtf8(Some("string".to_owned())),
819 ];
820 for s in &scalars {
821 expect_cast(
822 s.clone(),
823 dictionary_type(s.data_type()),
824 ExpectedCast::Value(dictionary_value(s.clone())),
825 );
826 expect_cast(
827 dictionary_value(s.clone()),
828 s.data_type(),
829 ExpectedCast::Value(s.clone()),
830 )
831 }
832 }
833
834 #[test]
835 fn test_try_cast_to_fixed_size_binary() {
836 expect_cast(
837 ScalarValue::Binary(Some(vec![1, 2, 3])),
838 DataType::FixedSizeBinary(3),
839 ExpectedCast::Value(ScalarValue::FixedSizeBinary(3, Some(vec![1, 2, 3]))),
840 )
841 }
842
843 #[test]
844 fn test_numeric_boundary_values() {
845 expect_cast(
847 ScalarValue::Int8(Some(i8::MAX)),
848 DataType::UInt8,
849 ExpectedCast::Value(ScalarValue::UInt8(Some(i8::MAX as u8))),
850 );
851
852 expect_cast(
853 ScalarValue::Int8(Some(i8::MIN)),
854 DataType::UInt8,
855 ExpectedCast::NoValue,
856 );
857
858 expect_cast(
859 ScalarValue::UInt8(Some(u8::MAX)),
860 DataType::Int8,
861 ExpectedCast::NoValue,
862 );
863
864 expect_cast(
866 ScalarValue::Int32(Some(i32::MAX)),
867 DataType::Int64,
868 ExpectedCast::Value(ScalarValue::Int64(Some(i32::MAX as i64))),
869 );
870
871 expect_cast(
872 ScalarValue::Int64(Some(i64::MIN)),
873 DataType::UInt64,
874 ExpectedCast::NoValue,
875 );
876
877 expect_cast(
879 ScalarValue::UInt32(Some(u32::MAX)),
880 DataType::Int32,
881 ExpectedCast::NoValue,
882 );
883
884 expect_cast(
885 ScalarValue::UInt64(Some(u64::MAX)),
886 DataType::Int64,
887 ExpectedCast::NoValue,
888 );
889 }
890
891 #[test]
892 fn test_decimal_precision_limits() {
893 use arrow::datatypes::{
894 MAX_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL128_FOR_EACH_PRECISION,
895 };
896
897 expect_cast(
899 ScalarValue::Decimal128(Some(MAX_DECIMAL128_FOR_EACH_PRECISION[3]), 3, 0),
900 DataType::Decimal128(5, 0),
901 ExpectedCast::Value(ScalarValue::Decimal128(
902 Some(MAX_DECIMAL128_FOR_EACH_PRECISION[3]),
903 5,
904 0,
905 )),
906 );
907
908 expect_cast(
910 ScalarValue::Decimal128(Some(MIN_DECIMAL128_FOR_EACH_PRECISION[3]), 3, 0),
911 DataType::Decimal128(5, 0),
912 ExpectedCast::Value(ScalarValue::Decimal128(
913 Some(MIN_DECIMAL128_FOR_EACH_PRECISION[3]),
914 5,
915 0,
916 )),
917 );
918
919 expect_cast(
921 ScalarValue::Decimal128(Some(123), 3, 0),
922 DataType::Decimal128(5, 2),
923 ExpectedCast::Value(ScalarValue::Decimal128(Some(12300), 5, 2)),
924 );
925
926 expect_cast(
928 ScalarValue::Decimal128(Some(MAX_DECIMAL128_FOR_EACH_PRECISION[10]), 10, 0),
929 DataType::Decimal128(3, 0),
930 ExpectedCast::NoValue,
931 );
932
933 expect_cast(
935 ScalarValue::Decimal128(Some(12345), 5, 3), DataType::Int32,
937 ExpectedCast::NoValue, );
939
940 expect_cast(
942 ScalarValue::Decimal128(Some(12345), 5, 2), DataType::Decimal128(3, 0), ExpectedCast::NoValue,
945 );
946 }
947
948 #[test]
949 fn test_timestamp_overflow_scenarios() {
950 let max_seconds = i64::MAX / 1_000_000_000; expect_cast(
955 ScalarValue::TimestampSecond(Some(max_seconds), None),
956 DataType::Timestamp(TimeUnit::Nanosecond, None),
957 ExpectedCast::Value(ScalarValue::TimestampNanosecond(
958 Some(max_seconds * 1_000_000_000),
959 None,
960 )),
961 );
962
963 expect_cast(
965 ScalarValue::TimestampNanosecond(Some(i64::MAX), None),
966 DataType::Timestamp(TimeUnit::Second, None),
967 ExpectedCast::Value(ScalarValue::TimestampSecond(
968 Some(i64::MAX / 1_000_000_000),
969 None,
970 )),
971 );
972
973 expect_cast(
975 ScalarValue::TimestampNanosecond(Some(1), None),
976 DataType::Timestamp(TimeUnit::Second, None),
977 ExpectedCast::Value(ScalarValue::TimestampSecond(Some(0), None)),
978 );
979
980 expect_cast(
981 ScalarValue::TimestampMicrosecond(Some(999), None),
982 DataType::Timestamp(TimeUnit::Millisecond, None),
983 ExpectedCast::Value(ScalarValue::TimestampMillisecond(Some(0), None)),
984 );
985 }
986
987 #[test]
988 fn test_string_view() {
989 expect_cast(
991 ScalarValue::Utf8View(Some("test".to_string())),
992 DataType::Utf8,
993 ExpectedCast::Value(ScalarValue::Utf8(Some("test".to_string()))),
994 );
995
996 expect_cast(
997 ScalarValue::Utf8View(Some("test".to_string())),
998 DataType::LargeUtf8,
999 ExpectedCast::Value(ScalarValue::LargeUtf8(Some("test".to_string()))),
1000 );
1001
1002 expect_cast(
1004 ScalarValue::Utf8(Some("hello".to_string())),
1005 DataType::Utf8View,
1006 ExpectedCast::Value(ScalarValue::Utf8View(Some("hello".to_string()))),
1007 );
1008
1009 expect_cast(
1010 ScalarValue::LargeUtf8(Some("world".to_string())),
1011 DataType::Utf8View,
1012 ExpectedCast::Value(ScalarValue::Utf8View(Some("world".to_string()))),
1013 );
1014
1015 expect_cast(
1017 ScalarValue::Utf8(Some("".to_string())),
1018 DataType::Utf8View,
1019 ExpectedCast::Value(ScalarValue::Utf8View(Some("".to_string()))),
1020 );
1021
1022 let large_string = "x".repeat(1000);
1024 expect_cast(
1025 ScalarValue::LargeUtf8(Some(large_string.clone())),
1026 DataType::Utf8View,
1027 ExpectedCast::Value(ScalarValue::Utf8View(Some(large_string))),
1028 );
1029 }
1030
1031 #[test]
1032 fn test_binary_size_edge_cases() {
1033 expect_cast(
1035 ScalarValue::Binary(Some(vec![1, 2])),
1036 DataType::FixedSizeBinary(3),
1037 ExpectedCast::NoValue,
1038 );
1039
1040 expect_cast(
1042 ScalarValue::Binary(Some(vec![1, 2, 3, 4])),
1043 DataType::FixedSizeBinary(3),
1044 ExpectedCast::NoValue,
1045 );
1046
1047 expect_cast(
1049 ScalarValue::Binary(Some(vec![])),
1050 DataType::FixedSizeBinary(0),
1051 ExpectedCast::Value(ScalarValue::FixedSizeBinary(0, Some(vec![]))),
1052 );
1053
1054 expect_cast(
1056 ScalarValue::Binary(Some(vec![1, 2, 3])),
1057 DataType::FixedSizeBinary(3),
1058 ExpectedCast::Value(ScalarValue::FixedSizeBinary(3, Some(vec![1, 2, 3]))),
1059 );
1060
1061 expect_cast(
1063 ScalarValue::Binary(Some(vec![42])),
1064 DataType::FixedSizeBinary(1),
1065 ExpectedCast::Value(ScalarValue::FixedSizeBinary(1, Some(vec![42]))),
1066 );
1067 }
1068
1069 #[test]
1070 fn test_dictionary_index_types() {
1071 let string_value = ScalarValue::Utf8(Some("test".to_string()));
1073
1074 let dict_int8 =
1076 DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8));
1077 expect_cast(
1078 string_value.clone(),
1079 dict_int8,
1080 ExpectedCast::Value(ScalarValue::Dictionary(
1081 Box::new(DataType::Int8),
1082 Box::new(string_value.clone()),
1083 )),
1084 );
1085
1086 let dict_int16 =
1088 DataType::Dictionary(Box::new(DataType::Int16), Box::new(DataType::Utf8));
1089 expect_cast(
1090 string_value.clone(),
1091 dict_int16,
1092 ExpectedCast::Value(ScalarValue::Dictionary(
1093 Box::new(DataType::Int16),
1094 Box::new(string_value.clone()),
1095 )),
1096 );
1097
1098 let dict_int64 =
1100 DataType::Dictionary(Box::new(DataType::Int64), Box::new(DataType::Utf8));
1101 expect_cast(
1102 string_value.clone(),
1103 dict_int64,
1104 ExpectedCast::Value(ScalarValue::Dictionary(
1105 Box::new(DataType::Int64),
1106 Box::new(string_value.clone()),
1107 )),
1108 );
1109
1110 let dict_value = ScalarValue::Dictionary(
1112 Box::new(DataType::Int32),
1113 Box::new(ScalarValue::LargeUtf8(Some("unwrap_test".to_string()))),
1114 );
1115 expect_cast(
1116 dict_value,
1117 DataType::LargeUtf8,
1118 ExpectedCast::Value(ScalarValue::LargeUtf8(Some("unwrap_test".to_string()))),
1119 );
1120 }
1121
1122 #[test]
1123 fn test_type_support_functions() {
1124 assert!(is_supported_numeric_type(&DataType::Int8));
1126 assert!(is_supported_numeric_type(&DataType::UInt64));
1127 assert!(is_supported_numeric_type(&DataType::Decimal128(10, 2)));
1128 assert!(is_supported_numeric_type(&DataType::Timestamp(
1129 TimeUnit::Nanosecond,
1130 None
1131 )));
1132 assert!(!is_supported_numeric_type(&DataType::Float32));
1133 assert!(!is_supported_numeric_type(&DataType::Float64));
1134
1135 assert!(is_supported_string_type(&DataType::Utf8));
1137 assert!(is_supported_string_type(&DataType::LargeUtf8));
1138 assert!(is_supported_string_type(&DataType::Utf8View));
1139 assert!(!is_supported_string_type(&DataType::Binary));
1140
1141 assert!(is_supported_binary_type(&DataType::Binary));
1143 assert!(is_supported_binary_type(&DataType::FixedSizeBinary(10)));
1144 assert!(!is_supported_binary_type(&DataType::Utf8));
1145
1146 assert!(is_supported_dictionary_type(&DataType::Dictionary(
1148 Box::new(DataType::Int32),
1149 Box::new(DataType::Utf8)
1150 )));
1151 assert!(is_supported_dictionary_type(&DataType::Dictionary(
1152 Box::new(DataType::Int32),
1153 Box::new(DataType::Int64)
1154 )));
1155 assert!(!is_supported_dictionary_type(&DataType::Dictionary(
1156 Box::new(DataType::Int32),
1157 Box::new(DataType::List(Arc::new(Field::new(
1158 "item",
1159 DataType::Int32,
1160 true
1161 ))))
1162 )));
1163
1164 assert!(is_supported_type(&DataType::Int32));
1166 assert!(is_supported_type(&DataType::Utf8));
1167 assert!(is_supported_type(&DataType::Binary));
1168 assert!(is_supported_type(&DataType::Dictionary(
1169 Box::new(DataType::Int32),
1170 Box::new(DataType::Utf8)
1171 )));
1172 assert!(!is_supported_type(&DataType::List(Arc::new(Field::new(
1173 "item",
1174 DataType::Int32,
1175 true
1176 )))));
1177 assert!(!is_supported_type(&DataType::Struct(Fields::empty())));
1178 }
1179
1180 #[test]
1181 fn test_error_conditions() {
1182 expect_cast(
1184 ScalarValue::Float32(Some(1.5)),
1185 DataType::Int32,
1186 ExpectedCast::NoValue,
1187 );
1188
1189 expect_cast(
1191 ScalarValue::Int32(Some(123)),
1192 DataType::Float64,
1193 ExpectedCast::NoValue,
1194 );
1195
1196 expect_cast(
1198 ScalarValue::Float64(Some(1.5)),
1199 DataType::Float32,
1200 ExpectedCast::NoValue,
1201 );
1202
1203 let list_type =
1205 DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
1206 expect_cast(
1207 ScalarValue::Int32(Some(123)),
1208 list_type,
1209 ExpectedCast::NoValue,
1210 );
1211
1212 let bad_dict = DataType::Dictionary(
1214 Box::new(DataType::Int32),
1215 Box::new(DataType::List(Arc::new(Field::new(
1216 "item",
1217 DataType::Int32,
1218 true,
1219 )))),
1220 );
1221 expect_cast(
1222 ScalarValue::Int32(Some(123)),
1223 bad_dict,
1224 ExpectedCast::NoValue,
1225 );
1226 }
1227}