1use crate::timezone;
19use crate::utils::array_with_timezone;
20use crate::{EvalMode, SparkError, SparkResult};
21use arrow::array::builder::StringBuilder;
22use arrow::array::{DictionaryArray, StringArray, StructArray};
23use arrow::compute::can_cast_types;
24use arrow::datatypes::{ArrowDictionaryKeyType, ArrowNativeType, DataType, Schema};
25use arrow::{
26 array::{
27 cast::AsArray,
28 types::{Date32Type, Int16Type, Int32Type, Int8Type},
29 Array, ArrayRef, BooleanArray, Decimal128Array, Float32Array, Float64Array,
30 GenericStringArray, Int16Array, Int32Array, Int64Array, Int8Array, OffsetSizeTrait,
31 PrimitiveArray,
32 },
33 compute::{cast_with_options, take, unary, CastOptions},
34 datatypes::{
35 is_validate_decimal_precision, ArrowPrimitiveType, Decimal128Type, Float32Type,
36 Float64Type, Int64Type, TimestampMicrosecondType,
37 },
38 error::ArrowError,
39 record_batch::RecordBatch,
40 util::display::FormatOptions,
41};
42use chrono::{DateTime, NaiveDate, TimeZone, Timelike};
43use datafusion::common::{
44 cast::as_generic_string_array, internal_err, DataFusionError, Result as DataFusionResult,
45 ScalarValue,
46};
47use datafusion::physical_expr::PhysicalExpr;
48use datafusion::physical_plan::ColumnarValue;
49use num::{
50 cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, Num,
51 ToPrimitive,
52};
53use regex::Regex;
54use std::str::FromStr;
55use std::{
56 any::Any,
57 fmt::{Debug, Display, Formatter},
58 hash::Hash,
59 num::Wrapping,
60 sync::Arc,
61};
62
// Timestamp format used when rendering timestamps as strings
// (microsecond precision, no timezone suffix).
static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f");

// Number of microseconds in one second.
const MICROS_PER_SECOND: i64 = 1000000;

// Default "safe" cast options: invalid values become null instead of
// raising an error; timestamps are formatted with TIMESTAMP_FORMAT.
static CAST_OPTIONS: CastOptions = CastOptions {
    safe: true,
    format_options: FormatOptions::new()
        .with_timestamp_tz_format(TIMESTAMP_FORMAT)
        .with_timestamp_format(TIMESTAMP_FORMAT),
};
73
/// Mutable holder for the individual components of a timestamp while it is
/// being parsed field-by-field (see the `with_*` builder methods below).
struct TimeStampInfo {
    // Calendar year (can be any i32; defaults to 1).
    year: i32,
    // Month of year, 1-based.
    month: u32,
    // Day of month, 1-based.
    day: u32,
    hour: u32,
    minute: u32,
    second: u32,
    // Sub-second component in microseconds.
    microsecond: u32,
}
83
84impl Default for TimeStampInfo {
85 fn default() -> Self {
86 TimeStampInfo {
87 year: 1,
88 month: 1,
89 day: 1,
90 hour: 0,
91 minute: 0,
92 second: 0,
93 microsecond: 0,
94 }
95 }
96}
97
impl TimeStampInfo {
    /// Set the year component, returning `&mut Self` for chaining.
    pub fn with_year(&mut self, year: i32) -> &mut Self {
        self.year = year;
        self
    }

    /// Set the month component (1-based).
    pub fn with_month(&mut self, month: u32) -> &mut Self {
        self.month = month;
        self
    }

    /// Set the day-of-month component (1-based).
    pub fn with_day(&mut self, day: u32) -> &mut Self {
        self.day = day;
        self
    }

    /// Set the hour component.
    pub fn with_hour(&mut self, hour: u32) -> &mut Self {
        self.hour = hour;
        self
    }

    /// Set the minute component.
    pub fn with_minute(&mut self, minute: u32) -> &mut Self {
        self.minute = minute;
        self
    }

    /// Set the second component.
    pub fn with_second(&mut self, second: u32) -> &mut Self {
        self.second = second;
        self
    }

    /// Set the microsecond (sub-second) component.
    pub fn with_microsecond(&mut self, microsecond: u32) -> &mut Self {
        self.microsecond = microsecond;
        self
    }
}
134
/// Physical expression that casts `child` to `data_type` with
/// Spark-compatible semantics controlled by `cast_options`.
/// `PartialEq` and `Hash` are implemented manually below because
/// `Arc<dyn PhysicalExpr>` cannot be derived.
#[derive(Debug, Eq)]
pub struct Cast {
    pub child: Arc<dyn PhysicalExpr>,
    pub data_type: DataType,
    pub cast_options: SparkCastOptions,
}
141
impl PartialEq for Cast {
    /// Two casts are equal when the child expression, target type, and
    /// cast options all compare equal.
    fn eq(&self, other: &Self) -> bool {
        self.child.eq(&other.child)
            && self.data_type.eq(&other.data_type)
            && self.cast_options.eq(&other.cast_options)
    }
}
149
150impl Hash for Cast {
151 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
152 self.child.hash(state);
153 self.data_type.hash(state);
154 self.cast_options.hash(state);
155 }
156}
157
158pub fn cast_supported(
160 from_type: &DataType,
161 to_type: &DataType,
162 options: &SparkCastOptions,
163) -> bool {
164 use DataType::*;
165
166 let from_type = if let Dictionary(_, dt) = from_type {
167 dt
168 } else {
169 from_type
170 };
171
172 let to_type = if let Dictionary(_, dt) = to_type {
173 dt
174 } else {
175 to_type
176 };
177
178 if from_type == to_type {
179 return true;
180 }
181
182 match (from_type, to_type) {
183 (Boolean, _) => can_cast_from_boolean(to_type, options),
184 (UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
185 if options.allow_cast_unsigned_ints =>
186 {
187 true
188 }
189 (Int8, _) => can_cast_from_byte(to_type, options),
190 (Int16, _) => can_cast_from_short(to_type, options),
191 (Int32, _) => can_cast_from_int(to_type, options),
192 (Int64, _) => can_cast_from_long(to_type, options),
193 (Float32, _) => can_cast_from_float(to_type, options),
194 (Float64, _) => can_cast_from_double(to_type, options),
195 (Decimal128(p, s), _) => can_cast_from_decimal(p, s, to_type, options),
196 (Timestamp(_, None), _) => can_cast_from_timestamp_ntz(to_type, options),
197 (Timestamp(_, Some(_)), _) => can_cast_from_timestamp(to_type, options),
198 (Utf8 | LargeUtf8, _) => can_cast_from_string(to_type, options),
199 (_, Utf8 | LargeUtf8) => can_cast_to_string(from_type, options),
200 (Struct(from_fields), Struct(to_fields)) => from_fields
201 .iter()
202 .zip(to_fields.iter())
203 .all(|(a, b)| cast_supported(a.data_type(), b.data_type(), options)),
204 _ => false,
205 }
206}
207
208fn can_cast_from_string(to_type: &DataType, options: &SparkCastOptions) -> bool {
209 use DataType::*;
210 match to_type {
211 Boolean | Int8 | Int16 | Int32 | Int64 | Binary => true,
212 Float32 | Float64 => {
213 options.allow_incompat
217 }
218 Decimal128(_, _) => {
219 options.allow_incompat
224 }
225 Date32 | Date64 => {
226 options.allow_incompat
229 }
230 Timestamp(_, _) if options.eval_mode == EvalMode::Ansi => {
231 false
233 }
234 Timestamp(_, Some(tz)) if tz.as_ref() != "UTC" => {
235 options.allow_incompat
237 }
238 Timestamp(_, _) => {
239 options.allow_incompat
242 }
243 _ => false,
244 }
245}
246
247fn can_cast_to_string(from_type: &DataType, options: &SparkCastOptions) -> bool {
248 use DataType::*;
249 match from_type {
250 Boolean | Int8 | Int16 | Int32 | Int64 | Date32 | Date64 | Timestamp(_, _) => true,
251 Float32 | Float64 => {
252 true
256 }
257 Decimal128(_, _) => {
258 true
262 }
263 Binary => {
264 options.allow_incompat
267 }
268 Struct(fields) => fields
269 .iter()
270 .all(|f| can_cast_to_string(f.data_type(), options)),
271 _ => false,
272 }
273}
274
275fn can_cast_from_timestamp_ntz(to_type: &DataType, options: &SparkCastOptions) -> bool {
276 use DataType::*;
277 match to_type {
278 Timestamp(_, _) | Date32 | Date64 | Utf8 => {
279 options.allow_incompat
281 }
282 _ => {
283 false
285 }
286 }
287}
288
289fn can_cast_from_timestamp(to_type: &DataType, _options: &SparkCastOptions) -> bool {
290 use DataType::*;
291 match to_type {
292 Boolean | Int8 | Int16 => {
293 false
296 }
297 Int64 => {
298 true
300 }
301 Date32 | Date64 | Utf8 | Decimal128(_, _) => true,
302 _ => {
303 false
305 }
306 }
307}
308
/// Booleans cast to any integer or floating-point type unconditionally.
fn can_cast_from_boolean(to_type: &DataType, _: &SparkCastOptions) -> bool {
    use DataType::*;
    matches!(to_type, Int8 | Int16 | Int32 | Int64 | Float32 | Float64)
}
313
/// Int8 casts to boolean, any numeric type, or decimal unconditionally.
fn can_cast_from_byte(to_type: &DataType, _: &SparkCastOptions) -> bool {
    use DataType::*;
    matches!(
        to_type,
        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
    )
}
321
/// Int16 casts to boolean, any numeric type, or decimal unconditionally.
fn can_cast_from_short(to_type: &DataType, _: &SparkCastOptions) -> bool {
    use DataType::*;
    matches!(
        to_type,
        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
    )
}
329
330fn can_cast_from_int(to_type: &DataType, options: &SparkCastOptions) -> bool {
331 use DataType::*;
332 match to_type {
333 Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Utf8 => true,
334 Decimal128(_, _) => {
335 options.allow_incompat
337 }
338 _ => false,
339 }
340}
341
342fn can_cast_from_long(to_type: &DataType, options: &SparkCastOptions) -> bool {
343 use DataType::*;
344 match to_type {
345 Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 => true,
346 Decimal128(_, _) => {
347 options.allow_incompat
349 }
350 _ => false,
351 }
352}
353
/// Float32 casts to boolean, integers, Float64, or decimal unconditionally.
fn can_cast_from_float(to_type: &DataType, _: &SparkCastOptions) -> bool {
    use DataType::*;
    matches!(
        to_type,
        Boolean | Int8 | Int16 | Int32 | Int64 | Float64 | Decimal128(_, _)
    )
}
361
/// Float64 casts to boolean, integers, Float32, or decimal unconditionally.
fn can_cast_from_double(to_type: &DataType, _: &SparkCastOptions) -> bool {
    use DataType::*;
    matches!(
        to_type,
        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Decimal128(_, _)
    )
}
369
370fn can_cast_from_decimal(
371 p1: &u8,
372 _s1: &i8,
373 to_type: &DataType,
374 options: &SparkCastOptions,
375) -> bool {
376 use DataType::*;
377 match to_type {
378 Int8 | Int16 | Int32 | Int64 | Float32 | Float64 => true,
379 Decimal128(p2, _) => {
380 if p2 < p1 {
381 options.allow_incompat
384 } else {
385 true
386 }
387 }
388 _ => false,
389 }
390}
391
/// Casts a string array to an integer array by applying `$cast_method` to
/// each non-null element. Parse failures surface as `Err` from
/// `$cast_method` (propagated via `?`), while `Ok(None)` results become
/// nulls in the output.
macro_rules! cast_utf8_to_int {
    ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{
        let len = $array.len();
        let mut cast_array = PrimitiveArray::<$array_type>::builder(len);
        for i in 0..len {
            if $array.is_null(i) {
                cast_array.append_null()
            } else if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? {
                cast_array.append_value(cast_value);
            } else {
                // Parser returned Ok(None): value is unparsable but the
                // eval mode tolerates it, so emit null.
                cast_array.append_null()
            }
        }
        let result: SparkResult<ArrayRef> = Ok(Arc::new(cast_array.finish()) as ArrayRef);
        result
    }};
}
/// Casts a string array to a UTC timestamp array by applying `$cast_method`
/// to each trimmed, non-null element. Note that any parse error is silently
/// converted to null here (the `Ok(Some(_))` pattern swallows `Err`), unlike
/// `cast_utf8_to_int!` which propagates errors.
macro_rules! cast_utf8_to_timestamp {
    ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident, $tz:expr) => {{
        let len = $array.len();
        // The output array is always labeled with the UTC timezone.
        let mut cast_array = PrimitiveArray::<$array_type>::builder(len).with_timezone("UTC");
        for i in 0..len {
            if $array.is_null(i) {
                cast_array.append_null()
            } else if let Ok(Some(cast_value)) =
                $cast_method($array.value(i).trim(), $eval_mode, $tz)
            {
                cast_array.append_value(cast_value);
            } else {
                cast_array.append_null()
            }
        }
        let result: ArrayRef = Arc::new(cast_array.finish()) as ArrayRef;
        result
    }};
}
428
/// Formats a float array as strings using Spark-compatible rules:
/// infinities render as "Infinity"/"-Infinity", values in
/// [0.001, 1e7) (and zero) use plain decimal notation with a forced ".0"
/// for whole numbers, and everything else uses scientific notation with a
/// coefficient that always contains a decimal point.
macro_rules! cast_float_to_string {
    ($from:expr, $eval_mode:expr, $type:ty, $output_type:ty, $offset_type:ty) => {{

        fn cast<OffsetSize>(
            from: &dyn Array,
            _eval_mode: EvalMode,
        ) -> SparkResult<ArrayRef>
        where
            OffsetSize: OffsetSizeTrait, {
            let array = from.as_any().downcast_ref::<$output_type>().unwrap();

            // Outside this range Spark switches to scientific notation.
            const LOWER_SCIENTIFIC_BOUND: $type = 0.001;
            const UPPER_SCIENTIFIC_BOUND: $type = 10000000.0;

            let output_array = array
                .iter()
                .map(|value| match value {
                    Some(value) if value == <$type>::INFINITY => Ok(Some("Infinity".to_string())),
                    Some(value) if value == <$type>::NEG_INFINITY => Ok(Some("-Infinity".to_string())),
                    // Plain decimal notation, appending ".0" to integral values.
                    Some(value)
                        if (value.abs() < UPPER_SCIENTIFIC_BOUND
                            && value.abs() >= LOWER_SCIENTIFIC_BOUND)
                            || value.abs() == 0.0 =>
                    {
                        let trailing_zero = if value.fract() == 0.0 { ".0" } else { "" };

                        Ok(Some(format!("{value}{trailing_zero}")))
                    }
                    // Scientific notation; ensure the coefficient has a
                    // decimal point (Rust's {:E} omits it for e.g. 1E9).
                    Some(value)
                        if value.abs() >= UPPER_SCIENTIFIC_BOUND
                            || value.abs() < LOWER_SCIENTIFIC_BOUND =>
                    {
                        let formatted = format!("{value:E}");

                        if formatted.contains(".") {
                            Ok(Some(formatted))
                        } else {
                            let prepare_number: Vec<&str> = formatted.split("E").collect();

                            let coefficient = prepare_number[0];

                            let exponent = prepare_number[1];

                            Ok(Some(format!("{coefficient}.0E{exponent}")))
                        }
                    }
                    // NaN and any remaining finite values fall through to
                    // the default Display formatting.
                    Some(value) => Ok(Some(value.to_string())),
                    _ => Ok(None),
                })
                .collect::<Result<GenericStringArray<OffsetSize>, SparkError>>()?;

            Ok(Arc::new(output_array))
        }

        cast::<$offset_type>($from, $eval_mode)
    }};
}
496
/// Narrows one integer array type to another. In Legacy mode the value is
/// truncated with an `as` cast (Spark legacy overflow semantics); in other
/// modes `try_from` is used and overflow produces a `cast_overflow` error
/// whose message carries the Spark integer-literal suffix (L/S/T) for the
/// source type.
macro_rules! cast_int_to_int_macro {
    (
        $array: expr,
        $eval_mode:expr,
        $from_arrow_primitive_type: ty,
        $to_arrow_primitive_type: ty,
        $from_data_type: expr,
        $to_native_type: ty,
        $spark_from_data_type_name: expr,
        $spark_to_data_type_name: expr
    ) => {{
        let cast_array = $array
            .as_any()
            .downcast_ref::<PrimitiveArray<$from_arrow_primitive_type>>()
            .unwrap();
        // Suffix used in Spark error messages: 1L (long), 1S (short), 1T (tinyint).
        let spark_int_literal_suffix = match $from_data_type {
            &DataType::Int64 => "L",
            &DataType::Int16 => "S",
            &DataType::Int8 => "T",
            _ => "",
        };

        let output_array = match $eval_mode {
            // Legacy mode: silent wrap-around truncation.
            EvalMode::Legacy => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        Ok::<Option<$to_native_type>, SparkError>(Some(value as $to_native_type))
                    }
                    _ => Ok(None),
                })
                .collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, _>>(),
            // ANSI/Try modes: checked conversion, error on overflow.
            _ => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let res = <$to_native_type>::try_from(value);
                        if res.is_err() {
                            Err(cast_overflow(
                                &(value.to_string() + spark_int_literal_suffix),
                                $spark_from_data_type_name,
                                $spark_to_data_type_name,
                            ))
                        } else {
                            Ok::<Option<$to_native_type>, SparkError>(Some(res.unwrap()))
                        }
                    }
                    _ => Ok(None),
                })
                .collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, _>>(),
        }?;
        let result: SparkResult<ArrayRef> = Ok(Arc::new(output_array) as ArrayRef);
        result
    }};
}
552
/// Casts a float array to Int8/Int16 ("down" through an intermediate i32).
/// In ANSI mode, NaN or out-of-i32-range values (detected via the
/// saturating `as i32` cast landing on i32::MAX) raise a `cast_overflow`
/// error; the i32 -> i8/i16 step is then checked with `try_from`.
/// In other modes the value is truncated silently.
macro_rules! cast_float_to_int16_down {
    (
        $array:expr,
        $eval_mode:expr,
        $src_array_type:ty,
        $dest_array_type:ty,
        $rust_src_type:ty,
        $rust_dest_type:ty,
        $src_type_str:expr,
        $dest_type_str:expr,
        $format_str:expr
    ) => {{
        let cast_array = $array
            .as_any()
            .downcast_ref::<$src_array_type>()
            .expect(concat!("Expected a ", stringify!($src_array_type)));

        let output_array = match $eval_mode {
            EvalMode::Ansi => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        // `as i32` saturates, so hitting i32::MAX signals
                        // the float was outside the i32 range.
                        let is_overflow = value.is_nan() || value.abs() as i32 == i32::MAX;
                        if is_overflow {
                            return Err(cast_overflow(
                                &format!($format_str, value).replace("e", "E"),
                                $src_type_str,
                                $dest_type_str,
                            ));
                        }
                        let i32_value = value as i32;
                        <$rust_dest_type>::try_from(i32_value)
                            .map_err(|_| {
                                cast_overflow(
                                    &format!($format_str, value).replace("e", "E"),
                                    $src_type_str,
                                    $dest_type_str,
                                )
                            })
                            .map(Some)
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
            // Legacy/Try modes: silent truncation through i32.
            _ => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let i32_value = value as i32;
                        Ok::<Option<$rust_dest_type>, SparkError>(Some(
                            i32_value as $rust_dest_type,
                        ))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
        };
        Ok(Arc::new(output_array) as ArrayRef)
    }};
}
616
/// Casts a float array to Int32/Int64 directly (no intermediate step).
/// In ANSI mode, NaN or values whose saturating cast lands on
/// `$max_dest_val` raise a `cast_overflow` error; other modes truncate
/// silently with `as`.
macro_rules! cast_float_to_int32_up {
    (
        $array:expr,
        $eval_mode:expr,
        $src_array_type:ty,
        $dest_array_type:ty,
        $rust_src_type:ty,
        $rust_dest_type:ty,
        $src_type_str:expr,
        $dest_type_str:expr,
        $max_dest_val:expr,
        $format_str:expr
    ) => {{
        let cast_array = $array
            .as_any()
            .downcast_ref::<$src_array_type>()
            .expect(concat!("Expected a ", stringify!($src_array_type)));

        let output_array = match $eval_mode {
            EvalMode::Ansi => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        // Saturating cast reaching the max value marks overflow.
                        let is_overflow =
                            value.is_nan() || value.abs() as $rust_dest_type == $max_dest_val;
                        if is_overflow {
                            return Err(cast_overflow(
                                &format!($format_str, value).replace("e", "E"),
                                $src_type_str,
                                $dest_type_str,
                            ));
                        }
                        Ok(Some(value as $rust_dest_type))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
            _ => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        Ok::<Option<$rust_dest_type>, SparkError>(Some(value as $rust_dest_type))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
        };
        Ok(Arc::new(output_array) as ArrayRef)
    }};
}
667
/// Casts a Decimal128 array to Int8/Int16. The value is split into its
/// integral part (`value / 10^scale`) and fractional digits; the fraction
/// is dropped (truncation toward zero). In ANSI mode an integral part
/// outside the i32 range, or one rejected by `try_from` for the narrower
/// destination, raises a `cast_overflow` error formatted like a Spark
/// decimal literal ("<int>.<frac>BD").
macro_rules! cast_decimal_to_int16_down {
    (
        $array:expr,
        $eval_mode:expr,
        $dest_array_type:ty,
        $rust_dest_type:ty,
        $dest_type_str:expr,
        $precision:expr,
        $scale:expr
    ) => {{
        let cast_array = $array
            .as_any()
            .downcast_ref::<Decimal128Array>()
            .expect("Expected a Decimal128ArrayType");

        let output_array = match $eval_mode {
            EvalMode::Ansi => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        // Split into integral and (absolute) fractional digits.
                        let divisor = 10_i128.pow($scale as u32);
                        let (truncated, decimal) = (value / divisor, (value % divisor).abs());
                        let is_overflow = truncated.abs() > i32::MAX.into();
                        if is_overflow {
                            return Err(cast_overflow(
                                &format!("{}.{}BD", truncated, decimal),
                                &format!("DECIMAL({},{})", $precision, $scale),
                                $dest_type_str,
                            ));
                        }
                        let i32_value = truncated as i32;
                        <$rust_dest_type>::try_from(i32_value)
                            .map_err(|_| {
                                cast_overflow(
                                    &format!("{}.{}BD", truncated, decimal),
                                    &format!("DECIMAL({},{})", $precision, $scale),
                                    $dest_type_str,
                                )
                            })
                            .map(Some)
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
            // Legacy/Try modes: silent truncation of both fraction and width.
            _ => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let divisor = 10_i128.pow($scale as u32);
                        let i32_value = (value / divisor) as i32;
                        Ok::<Option<$rust_dest_type>, SparkError>(Some(
                            i32_value as $rust_dest_type,
                        ))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
        };
        Ok(Arc::new(output_array) as ArrayRef)
    }};
}
732
/// Casts a Decimal128 array to Int32/Int64. Like
/// `cast_decimal_to_int16_down!` but compares the integral part directly
/// against `$max_dest_val` instead of going through an intermediate i32.
macro_rules! cast_decimal_to_int32_up {
    (
        $array:expr,
        $eval_mode:expr,
        $dest_array_type:ty,
        $rust_dest_type:ty,
        $dest_type_str:expr,
        $max_dest_val:expr,
        $precision:expr,
        $scale:expr
    ) => {{
        let cast_array = $array
            .as_any()
            .downcast_ref::<Decimal128Array>()
            .expect("Expected a Decimal128ArrayType");

        let output_array = match $eval_mode {
            EvalMode::Ansi => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        // Split into integral and (absolute) fractional digits.
                        let divisor = 10_i128.pow($scale as u32);
                        let (truncated, decimal) = (value / divisor, (value % divisor).abs());
                        let is_overflow = truncated.abs() > $max_dest_val.into();
                        if is_overflow {
                            return Err(cast_overflow(
                                &format!("{}.{}BD", truncated, decimal),
                                &format!("DECIMAL({},{})", $precision, $scale),
                                $dest_type_str,
                            ));
                        }
                        Ok(Some(truncated as $rust_dest_type))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
            // Legacy/Try modes: silent truncation toward zero.
            _ => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let divisor = 10_i128.pow($scale as u32);
                        let truncated = value / divisor;
                        Ok::<Option<$rust_dest_type>, SparkError>(Some(
                            truncated as $rust_dest_type,
                        ))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
        };
        Ok(Arc::new(output_array) as ArrayRef)
    }};
}
786
impl Cast {
    /// Create a new cast expression that evaluates `child` and converts the
    /// result to `data_type` according to `cast_options`.
    pub fn new(
        child: Arc<dyn PhysicalExpr>,
        data_type: DataType,
        cast_options: SparkCastOptions,
    ) -> Self {
        Self {
            child,
            data_type,
            cast_options,
        }
    }
}
800
/// Options controlling Spark-compatible cast behavior.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct SparkCastOptions {
    // Legacy / Ansi / Try evaluation mode.
    pub eval_mode: EvalMode,
    // Session timezone string; may be empty for timezone-free casts.
    pub timezone: String,
    // Permit casts whose behavior is known to differ from Spark.
    pub allow_incompat: bool,
    // Permit casting unsigned integers to signed integers.
    pub allow_cast_unsigned_ints: bool,
    // True when this cast is used to adapt a file schema rather than to
    // evaluate a user-visible CAST expression.
    pub is_adapting_schema: bool,
    // Text emitted for null fields when stringifying values.
    pub null_string: String,
}
820
821impl SparkCastOptions {
822 pub fn new(eval_mode: EvalMode, timezone: &str, allow_incompat: bool) -> Self {
823 Self {
824 eval_mode,
825 timezone: timezone.to_string(),
826 allow_incompat,
827 allow_cast_unsigned_ints: false,
828 is_adapting_schema: false,
829 null_string: "null".to_string(),
830 }
831 }
832
833 pub fn new_without_timezone(eval_mode: EvalMode, allow_incompat: bool) -> Self {
834 Self {
835 eval_mode,
836 timezone: "".to_string(),
837 allow_incompat,
838 allow_cast_unsigned_ints: false,
839 is_adapting_schema: false,
840 null_string: "null".to_string(),
841 }
842 }
843}
844
845pub fn spark_cast(
849 arg: ColumnarValue,
850 data_type: &DataType,
851 cast_options: &SparkCastOptions,
852) -> DataFusionResult<ColumnarValue> {
853 match arg {
854 ColumnarValue::Array(array) => Ok(ColumnarValue::Array(cast_array(
855 array,
856 data_type,
857 cast_options,
858 )?)),
859 ColumnarValue::Scalar(scalar) => {
860 let array = scalar.to_array()?;
864 let scalar =
865 ScalarValue::try_from_array(&cast_array(array, data_type, cast_options)?, 0)?;
866 Ok(ColumnarValue::Scalar(scalar))
867 }
868 }
869}
870
871fn dict_from_values<K: ArrowDictionaryKeyType>(
873 values_array: ArrayRef,
874) -> datafusion::common::Result<ArrayRef> {
875 let key_array: PrimitiveArray<K> = (0..values_array.len())
878 .map(|index| {
879 if values_array.is_valid(index) {
880 let native_index = K::Native::from_usize(index).ok_or_else(|| {
881 DataFusionError::Internal(format!(
882 "Can not create index of type {} from value {}",
883 K::DATA_TYPE,
884 index
885 ))
886 })?;
887 Ok(Some(native_index))
888 } else {
889 Ok(None)
890 }
891 })
892 .collect::<datafusion::common::Result<Vec<_>>>()?
893 .into_iter()
894 .collect();
895
896 let dict_array = DictionaryArray::<K>::try_new(key_array, values_array)?;
902 Ok(Arc::new(dict_array))
903}
904
/// Core cast dispatch: converts `array` to `to_type` using Spark-specific
/// kernels where Spark and Arrow semantics differ, and falls back to
/// Arrow's `cast_with_options` for compatible combinations.
fn cast_array(
    array: ArrayRef,
    to_type: &DataType,
    cast_options: &SparkCastOptions,
) -> DataFusionResult<ArrayRef> {
    use DataType::*;
    // Normalize the timezone on timestamp inputs before casting.
    let array = array_with_timezone(array, cast_options.timezone.clone(), Some(to_type))?;
    let from_type = array.data_type().clone();

    // Arrow-level options: in ANSI mode, invalid values error instead of
    // becoming null.
    let native_cast_options: CastOptions = CastOptions {
        safe: !matches!(cast_options.eval_mode, EvalMode::Ansi),
        format_options: FormatOptions::new()
            .with_timestamp_tz_format(TIMESTAMP_FORMAT)
            .with_timestamp_format(TIMESTAMP_FORMAT),
    };

    // Handle dictionary-encoded inputs/outputs up front.
    let array = match &from_type {
        // Int32-keyed string/binary dictionaries: cast the values child and
        // either keep the dictionary encoding (dictionary target) or
        // materialize via `take` (flat target).
        Dictionary(key_type, value_type)
            if key_type.as_ref() == &Int32
                && (value_type.as_ref() == &Utf8
                    || value_type.as_ref() == &LargeUtf8
                    || value_type.as_ref() == &Binary
                    || value_type.as_ref() == &LargeBinary) =>
        {
            let dict_array = array
                .as_any()
                .downcast_ref::<DictionaryArray<Int32Type>>()
                .expect("Expected a dictionary array");

            let casted_result = match to_type {
                Dictionary(_, to_value_type) => {
                    let casted_dictionary = DictionaryArray::<Int32Type>::new(
                        dict_array.keys().clone(),
                        cast_array(Arc::clone(dict_array.values()), to_value_type, cast_options)?,
                    );
                    Arc::new(casted_dictionary.clone())
                }
                _ => {
                    let casted_dictionary = DictionaryArray::<Int32Type>::new(
                        dict_array.keys().clone(),
                        cast_array(Arc::clone(dict_array.values()), to_type, cast_options)?,
                    );
                    take(casted_dictionary.values().as_ref(), dict_array.keys(), None)?
                }
            };
            return Ok(spark_cast_postprocess(casted_result, &from_type, to_type));
        }
        _ => {
            // Flat input to dictionary output: dictionary-encode first,
            // then recurse (hits the branch above).
            if let Dictionary(_, _) = to_type {
                let dict_array = dict_from_values::<Int32Type>(array)?;
                let casted_result = cast_array(dict_array, to_type, cast_options)?;
                return Ok(spark_cast_postprocess(casted_result, &from_type, to_type));
            } else {
                array
            }
        }
    };
    let from_type = array.data_type();
    let eval_mode = cast_options.eval_mode;

    // Dispatch to a Spark-specific kernel when one exists for this pair;
    // otherwise fall through to Arrow's cast when it is known-compatible.
    let cast_result = match (from_type, to_type) {
        (Utf8, Boolean) => spark_cast_utf8_to_boolean::<i32>(&array, eval_mode),
        (LargeUtf8, Boolean) => spark_cast_utf8_to_boolean::<i64>(&array, eval_mode),
        (Utf8, Timestamp(_, _)) => {
            cast_string_to_timestamp(&array, to_type, eval_mode, &cast_options.timezone)
        }
        (Utf8, Date32) => cast_string_to_date(&array, to_type, eval_mode),
        // Narrowing integer casts (Try mode falls through to Arrow).
        (Int64, Int32)
        | (Int64, Int16)
        | (Int64, Int8)
        | (Int32, Int16)
        | (Int32, Int8)
        | (Int16, Int8)
            if eval_mode != EvalMode::Try =>
        {
            spark_cast_int_to_int(&array, eval_mode, from_type, to_type)
        }
        (Utf8, Int8 | Int16 | Int32 | Int64) => {
            cast_string_to_int::<i32>(to_type, &array, eval_mode)
        }
        (LargeUtf8, Int8 | Int16 | Int32 | Int64) => {
            cast_string_to_int::<i64>(to_type, &array, eval_mode)
        }
        (Float64, Utf8) => spark_cast_float64_to_utf8::<i32>(&array, eval_mode),
        (Float64, LargeUtf8) => spark_cast_float64_to_utf8::<i64>(&array, eval_mode),
        (Float32, Utf8) => spark_cast_float32_to_utf8::<i32>(&array, eval_mode),
        (Float32, LargeUtf8) => spark_cast_float32_to_utf8::<i64>(&array, eval_mode),
        (Float32, Decimal128(precision, scale)) => {
            cast_float32_to_decimal128(&array, *precision, *scale, eval_mode)
        }
        (Float64, Decimal128(precision, scale)) => {
            cast_float64_to_decimal128(&array, *precision, *scale, eval_mode)
        }
        // Float/decimal to integer (Try mode falls through to Arrow).
        (Float32, Int8)
        | (Float32, Int16)
        | (Float32, Int32)
        | (Float32, Int64)
        | (Float64, Int8)
        | (Float64, Int16)
        | (Float64, Int32)
        | (Float64, Int64)
        | (Decimal128(_, _), Int8)
        | (Decimal128(_, _), Int16)
        | (Decimal128(_, _), Int32)
        | (Decimal128(_, _), Int64)
            if eval_mode != EvalMode::Try =>
        {
            spark_cast_nonintegral_numeric_to_integral(&array, eval_mode, from_type, to_type)
        }
        (Utf8View, Utf8) => Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?),
        (Struct(_), Utf8) => Ok(casts_struct_to_string(array.as_struct(), cast_options)?),
        (Struct(_), Struct(_)) => Ok(cast_struct_to_struct(
            array.as_struct(),
            from_type,
            to_type,
            cast_options,
        )?),
        (List(_), List(_)) if can_cast_types(from_type, to_type) => {
            Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
        }
        (UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
            if cast_options.allow_cast_unsigned_ints =>
        {
            Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
        }
        // Arrow's cast is Spark-compatible for this pair (or we are merely
        // adapting a file schema): delegate to Arrow.
        _ if cast_options.is_adapting_schema
            || is_datafusion_spark_compatible(from_type, to_type, cast_options.allow_incompat) =>
        {
            Ok(cast_with_options(&array, to_type, &native_cast_options)?)
        }
        _ => {
            // The JVM side should have rejected this combination already.
            Err(SparkError::Internal(format!(
                "Native cast invoked for unsupported cast from {from_type:?} to {to_type:?}"
            )))
        }
    };
    Ok(spark_cast_postprocess(cast_result?, from_type, to_type))
}
1047
/// Returns true when Arrow/DataFusion's built-in cast for this type pair
/// produces the same results Spark would, so `cast_array` can delegate to
/// `cast_with_options` instead of a custom kernel.
fn is_datafusion_spark_compatible(
    from_type: &DataType,
    to_type: &DataType,
    allow_incompat: bool,
) -> bool {
    if from_type == to_type {
        return true;
    }
    match from_type {
        DataType::Null => {
            matches!(to_type, DataType::List(_))
        }
        DataType::Boolean => matches!(
            to_type,
            DataType::Int8
                | DataType::Int16
                | DataType::Int32
                | DataType::Int64
                | DataType::Float32
                | DataType::Float64
                | DataType::Utf8
        ),
        DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
            matches!(
                to_type,
                DataType::Boolean
                    | DataType::Int8
                    | DataType::Int16
                    | DataType::Int32
                    | DataType::Int64
                    | DataType::Float32
                    | DataType::Float64
                    | DataType::Decimal128(_, _)
                    | DataType::Utf8
            )
        }
        DataType::Float32 | DataType::Float64 => matches!(
            to_type,
            DataType::Boolean
                | DataType::Int8
                | DataType::Int16
                | DataType::Int32
                | DataType::Int64
                | DataType::Float32
                | DataType::Float64
        ),
        DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => matches!(
            to_type,
            DataType::Int8
                | DataType::Int16
                | DataType::Int32
                | DataType::Int64
                | DataType::Float32
                | DataType::Float64
                | DataType::Decimal128(_, _)
                | DataType::Decimal256(_, _)
                | DataType::Utf8
        ),
        // String-to-numeric via Arrow differs from Spark in edge cases, so
        // it is only delegated when incompatibility is accepted.
        DataType::Utf8 if allow_incompat => matches!(
            to_type,
            DataType::Binary | DataType::Float32 | DataType::Float64 | DataType::Decimal128(_, _)
        ),
        DataType::Utf8 => matches!(to_type, DataType::Binary),
        DataType::Date32 => matches!(to_type, DataType::Utf8),
        DataType::Timestamp(_, _) => {
            matches!(
                to_type,
                DataType::Int64 | DataType::Date32 | DataType::Utf8 | DataType::Timestamp(_, _)
            )
        }
        DataType::Binary => {
            matches!(to_type, DataType::Utf8)
        }
        _ => false,
    }
}
1131
1132fn cast_struct_to_struct(
1135 array: &StructArray,
1136 from_type: &DataType,
1137 to_type: &DataType,
1138 cast_options: &SparkCastOptions,
1139) -> DataFusionResult<ArrayRef> {
1140 match (from_type, to_type) {
1141 (DataType::Struct(from_fields), DataType::Struct(to_fields)) => {
1142 let cast_fields: Vec<ArrayRef> = from_fields
1143 .iter()
1144 .enumerate()
1145 .zip(to_fields.iter())
1146 .map(|((idx, _from), to)| {
1147 let from_field = Arc::clone(array.column(idx));
1148 let array_length = from_field.len();
1149 let cast_result = spark_cast(
1150 ColumnarValue::from(from_field),
1151 to.data_type(),
1152 cast_options,
1153 )
1154 .unwrap();
1155 cast_result.to_array(array_length).unwrap()
1156 })
1157 .collect();
1158
1159 Ok(Arc::new(StructArray::new(
1160 to_fields.clone(),
1161 cast_fields,
1162 array.nulls().cloned(),
1163 )))
1164 }
1165 _ => unreachable!(),
1166 }
1167}
1168
/// Render a struct array as strings of the form `{v1, v2, ...}`.
/// Each field column is first cast to Utf8 via `spark_cast`; null fields
/// render as `spark_cast_options.null_string`, while a null struct row
/// produces a null output row.
fn casts_struct_to_string(
    array: &StructArray,
    spark_cast_options: &SparkCastOptions,
) -> DataFusionResult<ArrayRef> {
    // Cast each field column to a string array up front.
    let string_arrays: Vec<ArrayRef> = array
        .columns()
        .iter()
        .map(|arr| {
            spark_cast(
                ColumnarValue::Array(Arc::clone(arr)),
                &DataType::Utf8,
                spark_cast_options,
            )
            .and_then(|cv| cv.into_array(arr.len()))
        })
        .collect::<DataFusionResult<Vec<_>>>()?;
    let string_arrays: Vec<&StringArray> =
        string_arrays.iter().map(|arr| arr.as_string()).collect();
    let mut builder = StringBuilder::with_capacity(array.len(), array.len() * 16);
    // Reuse one scratch buffer across rows to avoid per-row allocations.
    let mut str = String::with_capacity(array.len() * 16);
    for row_index in 0..array.len() {
        if array.is_null(row_index) {
            builder.append_null();
        } else {
            str.clear();
            let mut any_fields_written = false;
            str.push('{');
            for field in &string_arrays {
                // Comma-separate all fields after the first.
                if any_fields_written {
                    str.push_str(", ");
                }
                if field.is_null(row_index) {
                    str.push_str(&spark_cast_options.null_string);
                } else {
                    str.push_str(field.value(row_index));
                }
                any_fields_written = true;
            }
            str.push('}');
            builder.append_value(&str);
        }
    }
    Ok(Arc::new(builder.finish()))
}
1215
/// Cast a (Large)String array to the requested integer type using the
/// Spark-compatible per-width parsers via `cast_utf8_to_int!`.
///
/// Panics (`unreachable!`) if `to_type` is not one of Int8/16/32/64;
/// callers dispatch only those types here.
fn cast_string_to_int<OffsetSize: OffsetSizeTrait>(
    to_type: &DataType,
    array: &ArrayRef,
    eval_mode: EvalMode,
) -> SparkResult<ArrayRef> {
    let string_array = array
        .as_any()
        .downcast_ref::<GenericStringArray<OffsetSize>>()
        .expect("cast_string_to_int expected a string array");

    let cast_array: ArrayRef = match to_type {
        DataType::Int8 => cast_utf8_to_int!(string_array, eval_mode, Int8Type, cast_string_to_i8)?,
        DataType::Int16 => {
            cast_utf8_to_int!(string_array, eval_mode, Int16Type, cast_string_to_i16)?
        }
        DataType::Int32 => {
            cast_utf8_to_int!(string_array, eval_mode, Int32Type, cast_string_to_i32)?
        }
        DataType::Int64 => {
            cast_utf8_to_int!(string_array, eval_mode, Int64Type, cast_string_to_i64)?
        }
        dt => unreachable!(
            "{}",
            format!("invalid integer type {dt} in cast from string")
        ),
    };
    Ok(cast_array)
}
1244
1245fn cast_string_to_date(
1246 array: &ArrayRef,
1247 to_type: &DataType,
1248 eval_mode: EvalMode,
1249) -> SparkResult<ArrayRef> {
1250 let string_array = array
1251 .as_any()
1252 .downcast_ref::<GenericStringArray<i32>>()
1253 .expect("Expected a string array");
1254
1255 if to_type != &DataType::Date32 {
1256 unreachable!("Invalid data type {:?} in cast from string", to_type);
1257 }
1258
1259 let len = string_array.len();
1260 let mut cast_array = PrimitiveArray::<Date32Type>::builder(len);
1261
1262 for i in 0..len {
1263 let value = if string_array.is_null(i) {
1264 None
1265 } else {
1266 match date_parser(string_array.value(i), eval_mode) {
1267 Ok(Some(cast_value)) => Some(cast_value),
1268 Ok(None) => None,
1269 Err(e) => return Err(e),
1270 }
1271 };
1272
1273 match value {
1274 Some(cast_value) => cast_array.append_value(cast_value),
1275 None => cast_array.append_null(),
1276 }
1277 }
1278
1279 Ok(Arc::new(cast_array.finish()) as ArrayRef)
1280}
1281
/// Casts a UTF-8 string array to a microsecond-precision timestamp array in
/// the given timezone, honoring Spark cast semantics for `eval_mode`.
///
/// Panics if `to_type` is not a timestamp type (planner invariant).
fn cast_string_to_timestamp(
    array: &ArrayRef,
    to_type: &DataType,
    eval_mode: EvalMode,
    timezone_str: &str,
) -> SparkResult<ArrayRef> {
    let string_array = array
        .as_any()
        .downcast_ref::<GenericStringArray<i32>>()
        .expect("Expected a string array");

    // NOTE(review): assumes `timezone_str` was validated upstream — an invalid
    // timezone string panics here. TODO confirm callers guarantee this.
    let tz = &timezone::Tz::from_str(timezone_str).unwrap();

    let cast_array: ArrayRef = match to_type {
        DataType::Timestamp(_, _) => {
            cast_utf8_to_timestamp!(
                string_array,
                eval_mode,
                TimestampMicrosecondType,
                timestamp_parser,
                tz
            )
        }
        _ => unreachable!("Invalid data type {:?} in cast from string", to_type),
    };
    Ok(cast_array)
}
1309
/// Casts a `Float64` array to `Decimal128(precision, scale)`; see
/// `cast_floating_point_to_decimal128` for the shared implementation.
fn cast_float64_to_decimal128(
    array: &dyn Array,
    precision: u8,
    scale: i8,
    eval_mode: EvalMode,
) -> SparkResult<ArrayRef> {
    cast_floating_point_to_decimal128::<Float64Type>(array, precision, scale, eval_mode)
}
1318
/// Casts a `Float32` array to `Decimal128(precision, scale)`; see
/// `cast_floating_point_to_decimal128` for the shared implementation.
fn cast_float32_to_decimal128(
    array: &dyn Array,
    precision: u8,
    scale: i8,
    eval_mode: EvalMode,
) -> SparkResult<ArrayRef> {
    cast_floating_point_to_decimal128::<Float32Type>(array, precision, scale, eval_mode)
}
1327
/// Shared implementation for casting floating-point arrays to `Decimal128`.
///
/// Each value is widened to `f64`, scaled by `10^scale`, rounded with
/// `f64::round` (half away from zero), and checked against the target
/// precision. A value that does not fit becomes null, except in ANSI mode
/// where it raises `NumericValueOutOfRange`.
fn cast_floating_point_to_decimal128<T: ArrowPrimitiveType>(
    array: &dyn Array,
    precision: u8,
    scale: i8,
    eval_mode: EvalMode,
) -> SparkResult<ArrayRef>
where
    <T as ArrowPrimitiveType>::Native: AsPrimitive<f64>,
{
    let input = array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
    let mut cast_array = PrimitiveArray::<Decimal128Type>::builder(input.len());

    // Scale factor; a negative scale yields a fractional multiplier.
    let mul = 10_f64.powi(scale as i32);

    for i in 0..input.len() {
        if input.is_null(i) {
            cast_array.append_null();
            continue;
        }

        // Widen to f64 (lossless for f32) before scaling.
        let input_value = input.value(i).as_();
        if let Some(v) = (input_value * mul).round().to_i128() {
            if is_validate_decimal_precision(v, precision) {
                cast_array.append_value(v);
                continue;
            }
        };

        // Value is out of range for (precision, scale).
        if eval_mode == EvalMode::Ansi {
            return Err(SparkError::NumericValueOutOfRange {
                value: input_value.to_string(),
                precision,
                scale,
            });
        }
        cast_array.append_null();
    }

    let res = Arc::new(
        cast_array
            .with_precision_and_scale(precision, scale)?
            .finish(),
    ) as ArrayRef;
    Ok(res)
}
1373
/// Casts a `Float64` array to strings with Spark's float formatting
/// (delegated to the `cast_float_to_string!` macro).
fn spark_cast_float64_to_utf8<OffsetSize>(
    from: &dyn Array,
    _eval_mode: EvalMode,
) -> SparkResult<ArrayRef>
where
    OffsetSize: OffsetSizeTrait,
{
    cast_float_to_string!(from, _eval_mode, f64, Float64Array, OffsetSize)
}
1383
/// Casts a `Float32` array to strings with Spark's float formatting
/// (delegated to the `cast_float_to_string!` macro).
fn spark_cast_float32_to_utf8<OffsetSize>(
    from: &dyn Array,
    _eval_mode: EvalMode,
) -> SparkResult<ArrayRef>
where
    OffsetSize: OffsetSizeTrait,
{
    cast_float_to_string!(from, _eval_mode, f32, Float32Array, OffsetSize)
}
1393
1394fn spark_cast_int_to_int(
1395 array: &dyn Array,
1396 eval_mode: EvalMode,
1397 from_type: &DataType,
1398 to_type: &DataType,
1399) -> SparkResult<ArrayRef> {
1400 match (from_type, to_type) {
1401 (DataType::Int64, DataType::Int32) => cast_int_to_int_macro!(
1402 array, eval_mode, Int64Type, Int32Type, from_type, i32, "BIGINT", "INT"
1403 ),
1404 (DataType::Int64, DataType::Int16) => cast_int_to_int_macro!(
1405 array, eval_mode, Int64Type, Int16Type, from_type, i16, "BIGINT", "SMALLINT"
1406 ),
1407 (DataType::Int64, DataType::Int8) => cast_int_to_int_macro!(
1408 array, eval_mode, Int64Type, Int8Type, from_type, i8, "BIGINT", "TINYINT"
1409 ),
1410 (DataType::Int32, DataType::Int16) => cast_int_to_int_macro!(
1411 array, eval_mode, Int32Type, Int16Type, from_type, i16, "INT", "SMALLINT"
1412 ),
1413 (DataType::Int32, DataType::Int8) => cast_int_to_int_macro!(
1414 array, eval_mode, Int32Type, Int8Type, from_type, i8, "INT", "TINYINT"
1415 ),
1416 (DataType::Int16, DataType::Int8) => cast_int_to_int_macro!(
1417 array, eval_mode, Int16Type, Int8Type, from_type, i8, "SMALLINT", "TINYINT"
1418 ),
1419 _ => unreachable!(
1420 "{}",
1421 format!("invalid integer type {to_type} in cast from {from_type}")
1422 ),
1423 }
1424}
1425
1426fn spark_cast_utf8_to_boolean<OffsetSize>(
1427 from: &dyn Array,
1428 eval_mode: EvalMode,
1429) -> SparkResult<ArrayRef>
1430where
1431 OffsetSize: OffsetSizeTrait,
1432{
1433 let array = from
1434 .as_any()
1435 .downcast_ref::<GenericStringArray<OffsetSize>>()
1436 .unwrap();
1437
1438 let output_array = array
1439 .iter()
1440 .map(|value| match value {
1441 Some(value) => match value.to_ascii_lowercase().trim() {
1442 "t" | "true" | "y" | "yes" | "1" => Ok(Some(true)),
1443 "f" | "false" | "n" | "no" | "0" => Ok(Some(false)),
1444 _ if eval_mode == EvalMode::Ansi => Err(SparkError::CastInvalidValue {
1445 value: value.to_string(),
1446 from_type: "STRING".to_string(),
1447 to_type: "BOOLEAN".to_string(),
1448 }),
1449 _ => Ok(None),
1450 },
1451 _ => Ok(None),
1452 })
1453 .collect::<Result<BooleanArray, _>>()?;
1454
1455 Ok(Arc::new(output_array))
1456}
1457
/// Casts floating-point or decimal arrays to integral types with Spark
/// semantics; the per-value range checks and ANSI overflow errors are
/// implemented by the `cast_float_to_*` / `cast_decimal_to_*` macros
/// (defined elsewhere in this file).
///
/// The trailing format-string argument ("{:e}" vs "{:e}D") controls how the
/// offending value is rendered in overflow error messages (Spark prints
/// doubles with a `D` suffix).
fn spark_cast_nonintegral_numeric_to_integral(
    array: &dyn Array,
    eval_mode: EvalMode,
    from_type: &DataType,
    to_type: &DataType,
) -> SparkResult<ArrayRef> {
    match (from_type, to_type) {
        // Float32 -> integral types.
        (DataType::Float32, DataType::Int8) => cast_float_to_int16_down!(
            array, eval_mode, Float32Array, Int8Array, f32, i8, "FLOAT", "TINYINT", "{:e}"
        ),
        (DataType::Float32, DataType::Int16) => cast_float_to_int16_down!(
            array, eval_mode, Float32Array, Int16Array, f32, i16, "FLOAT", "SMALLINT", "{:e}"
        ),
        (DataType::Float32, DataType::Int32) => cast_float_to_int32_up!(
            array, eval_mode, Float32Array, Int32Array, f32, i32, "FLOAT", "INT", i32::MAX, "{:e}"
        ),
        (DataType::Float32, DataType::Int64) => cast_float_to_int32_up!(
            array, eval_mode, Float32Array, Int64Array, f32, i64, "FLOAT", "BIGINT", i64::MAX,
            "{:e}"
        ),
        // Float64 -> integral types.
        (DataType::Float64, DataType::Int8) => cast_float_to_int16_down!(
            array, eval_mode, Float64Array, Int8Array, f64, i8, "DOUBLE", "TINYINT", "{:e}D"
        ),
        (DataType::Float64, DataType::Int16) => cast_float_to_int16_down!(
            array, eval_mode, Float64Array, Int16Array, f64, i16, "DOUBLE", "SMALLINT", "{:e}D"
        ),
        (DataType::Float64, DataType::Int32) => cast_float_to_int32_up!(
            array, eval_mode, Float64Array, Int32Array, f64, i32, "DOUBLE", "INT", i32::MAX,
            "{:e}D"
        ),
        (DataType::Float64, DataType::Int64) => cast_float_to_int32_up!(
            array, eval_mode, Float64Array, Int64Array, f64, i64, "DOUBLE", "BIGINT", i64::MAX,
            "{:e}D"
        ),
        // Decimal128 -> integral types.
        (DataType::Decimal128(precision, scale), DataType::Int8) => {
            cast_decimal_to_int16_down!(
                array, eval_mode, Int8Array, i8, "TINYINT", precision, *scale
            )
        }
        (DataType::Decimal128(precision, scale), DataType::Int16) => {
            cast_decimal_to_int16_down!(
                array, eval_mode, Int16Array, i16, "SMALLINT", precision, *scale
            )
        }
        (DataType::Decimal128(precision, scale), DataType::Int32) => {
            cast_decimal_to_int32_up!(
                array, eval_mode, Int32Array, i32, "INT", i32::MAX, *precision, *scale
            )
        }
        (DataType::Decimal128(precision, scale), DataType::Int64) => {
            cast_decimal_to_int32_up!(
                array, eval_mode, Int64Array, i64, "BIGINT", i64::MAX, *precision, *scale
            )
        }
        _ => unreachable!(
            "{}",
            format!("invalid cast from non-integral numeric type: {from_type} to integral numeric type: {to_type}")
        ),
    }
}
1597
1598fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i8>> {
1600 Ok(cast_string_to_int_with_range_check(
1601 str,
1602 eval_mode,
1603 "TINYINT",
1604 i8::MIN as i32,
1605 i8::MAX as i32,
1606 )?
1607 .map(|v| v as i8))
1608}
1609
1610fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i16>> {
1612 Ok(cast_string_to_int_with_range_check(
1613 str,
1614 eval_mode,
1615 "SMALLINT",
1616 i16::MIN as i32,
1617 i16::MAX as i32,
1618 )?
1619 .map(|v| v as i16))
1620}
1621
/// Casts a string to `i32` (Spark INT) using Spark's parsing semantics.
fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>> {
    do_cast_string_to_int::<i32>(str, eval_mode, "INT", i32::MIN)
}
1626
/// Casts a string to `i64` (Spark BIGINT) using Spark's parsing semantics.
fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i64>> {
    do_cast_string_to_int::<i64>(str, eval_mode, "BIGINT", i64::MIN)
}
1631
1632fn cast_string_to_int_with_range_check(
1633 str: &str,
1634 eval_mode: EvalMode,
1635 type_name: &str,
1636 min: i32,
1637 max: i32,
1638) -> SparkResult<Option<i32>> {
1639 match do_cast_string_to_int(str, eval_mode, type_name, i32::MIN)? {
1640 None => Ok(None),
1641 Some(v) if v >= min && v <= max => Ok(Some(v)),
1642 _ if eval_mode == EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)),
1643 _ => Ok(None),
1644 }
1645}
1646
/// Spark-compatible string-to-integer parsing.
///
/// The value is accumulated as a *negative* number and negated at the end:
/// two's-complement integers have one more negative value than positive, so
/// this is the only way to parse `T::MIN` itself without overflow.
///
/// In `Legacy` mode a fractional part is tolerated (digits after '.' are
/// validated but discarded); other modes reject it. Invalid input yields
/// `Ok(None)`, or a `CastInvalidValue` error in ANSI mode (see `none_or_err`).
fn do_cast_string_to_int<
    T: Num + PartialOrd + Integer + CheckedSub + CheckedNeg + From<i32> + Copy,
>(
    str: &str,
    eval_mode: EvalMode,
    type_name: &str,
    min_value: T,
) -> SparkResult<Option<T>> {
    let trimmed_str = str.trim();
    if trimmed_str.is_empty() {
        return none_or_err(eval_mode, type_name, str);
    }
    let len = trimmed_str.len();
    let mut result: T = T::zero();
    let mut negative = false;
    let radix = T::from(10);
    // Once `result` drops below this threshold, multiplying by 10 is
    // guaranteed to overflow past `min_value`.
    let stop_value = min_value / radix;
    let mut parse_sign_and_digits = true;

    for (i, ch) in trimmed_str.char_indices() {
        if parse_sign_and_digits {
            if i == 0 {
                negative = ch == '-';
                let positive = ch == '+';
                if negative || positive {
                    if i + 1 == len {
                        // A bare sign with no digits is invalid.
                        return none_or_err(eval_mode, type_name, str);
                    }
                    continue;
                }
            }

            if ch == '.' {
                if eval_mode == EvalMode::Legacy {
                    // Legacy mode: discard the fractional part, but keep
                    // scanning so trailing garbage is still rejected.
                    parse_sign_and_digits = false;
                    continue;
                } else {
                    return none_or_err(eval_mode, type_name, str);
                }
            }

            let digit = if ch.is_ascii_digit() {
                (ch as u32) - ('0' as u32)
            } else {
                return none_or_err(eval_mode, type_name, str);
            };

            // Pre-check: if the accumulator is already below stop_value,
            // `result * radix` would overflow — bail out now.
            if result < stop_value {
                return none_or_err(eval_mode, type_name, str);
            }

            let v = result * radix;
            let digit = (digit as i32).into();
            // Subtracting keeps the accumulator in the (larger) negative
            // range; a positive result means we wrapped past `T::MIN`.
            match v.checked_sub(&digit) {
                Some(x) if x <= T::zero() => result = x,
                _ => {
                    return none_or_err(eval_mode, type_name, str);
                }
            }
        } else {
            // Legacy fractional part: any non-digit suffix is invalid.
            if !ch.is_ascii_digit() {
                return none_or_err(eval_mode, type_name, str);
            }
        }
    }

    if !negative {
        // Negate the negative accumulator to get the positive result;
        // `T::MIN` has no positive counterpart, which `checked_neg` (or the
        // sign check) detects.
        if let Some(neg) = result.checked_neg() {
            if neg < T::zero() {
                return none_or_err(eval_mode, type_name, str);
            }
            result = neg;
        } else {
            return none_or_err(eval_mode, type_name, str);
        }
    }

    Ok(Some(result))
}
1740
1741#[inline]
1743fn none_or_err<T>(eval_mode: EvalMode, type_name: &str, str: &str) -> SparkResult<Option<T>> {
1744 match eval_mode {
1745 EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)),
1746 _ => Ok(None),
1747 }
1748}
1749
/// Builds a `CastInvalidValue` error for `value` cast from `from_type` to
/// `to_type`.
#[inline]
fn invalid_value(value: &str, from_type: &str, to_type: &str) -> SparkError {
    SparkError::CastInvalidValue {
        value: value.to_string(),
        from_type: from_type.to_string(),
        to_type: to_type.to_string(),
    }
}
1758
/// Builds a `CastOverFlow` error for `value` cast from `from_type` to
/// `to_type`.
#[inline]
fn cast_overflow(value: &str, from_type: &str, to_type: &str) -> SparkError {
    SparkError::CastOverFlow {
        value: value.to_string(),
        from_type: from_type.to_string(),
        to_type: to_type.to_string(),
    }
}
1767
impl Display for Cast {
    /// Human-readable representation used in query-plan output.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Cast [data_type: {}, timezone: {}, child: {}, eval_mode: {:?}]",
            self.data_type, self.cast_options.timezone, self.child, &self.cast_options.eval_mode
        )
    }
}
1777
impl PhysicalExpr for Cast {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
        // SQL rendering of this expression is not supported.
        unimplemented!()
    }

    fn data_type(&self, _: &Schema) -> DataFusionResult<DataType> {
        Ok(self.data_type.clone())
    }

    fn nullable(&self, _: &Schema) -> DataFusionResult<bool> {
        // A cast can introduce nulls (e.g. failed parses in non-ANSI modes),
        // so conservatively report nullable regardless of the input.
        Ok(true)
    }

    /// Evaluates the child expression, then applies the Spark cast to its
    /// result.
    fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
        let arg = self.child.evaluate(batch)?;
        spark_cast(arg, &self.data_type, &self.cast_options)
    }

    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        vec![&self.child]
    }

    /// Rebuilds this cast over a replacement child; exactly one child is
    /// required.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
        match children.len() {
            1 => Ok(Arc::new(Cast::new(
                Arc::clone(&children[0]),
                self.data_type.clone(),
                self.cast_options.clone(),
            ))),
            _ => internal_err!("Cast should have exactly one child"),
        }
    }
}
1818
1819fn timestamp_parser<T: TimeZone>(
1820 value: &str,
1821 eval_mode: EvalMode,
1822 tz: &T,
1823) -> SparkResult<Option<i64>> {
1824 let value = value.trim();
1825 if value.is_empty() {
1826 return Ok(None);
1827 }
1828 let patterns = &[
1830 (
1831 Regex::new(r"^\d{4,5}$").unwrap(),
1832 parse_str_to_year_timestamp as fn(&str, &T) -> SparkResult<Option<i64>>,
1833 ),
1834 (
1835 Regex::new(r"^\d{4,5}-\d{2}$").unwrap(),
1836 parse_str_to_month_timestamp,
1837 ),
1838 (
1839 Regex::new(r"^\d{4,5}-\d{2}-\d{2}$").unwrap(),
1840 parse_str_to_day_timestamp,
1841 ),
1842 (
1843 Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{1,2}$").unwrap(),
1844 parse_str_to_hour_timestamp,
1845 ),
1846 (
1847 Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}$").unwrap(),
1848 parse_str_to_minute_timestamp,
1849 ),
1850 (
1851 Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$").unwrap(),
1852 parse_str_to_second_timestamp,
1853 ),
1854 (
1855 Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,6}$").unwrap(),
1856 parse_str_to_microsecond_timestamp,
1857 ),
1858 (
1859 Regex::new(r"^T\d{1,2}$").unwrap(),
1860 parse_str_to_time_only_timestamp,
1861 ),
1862 ];
1863
1864 let mut timestamp = None;
1865
1866 for (pattern, parse_func) in patterns {
1868 if pattern.is_match(value) {
1869 timestamp = parse_func(value, tz)?;
1870 break;
1871 }
1872 }
1873
1874 if timestamp.is_none() {
1875 return if eval_mode == EvalMode::Ansi {
1876 Err(SparkError::CastInvalidValue {
1877 value: value.to_string(),
1878 from_type: "STRING".to_string(),
1879 to_type: "TIMESTAMP".to_string(),
1880 })
1881 } else {
1882 Ok(None)
1883 };
1884 }
1885
1886 match timestamp {
1887 Some(ts) => Ok(Some(ts)),
1888 None => Err(SparkError::Internal(
1889 "Failed to parse timestamp".to_string(),
1890 )),
1891 }
1892}
1893
/// Converts parsed date/time components into microseconds since the Unix
/// epoch in timezone `tz`.
///
/// Returns an `Internal` error when the local datetime cannot be resolved to
/// a single instant (`single()` returns `None`, e.g. for an ambiguous or
/// nonexistent local time) or when setting the sub-second component fails.
fn parse_timestamp_to_micros<T: TimeZone>(
    timestamp_info: &TimeStampInfo,
    tz: &T,
) -> SparkResult<Option<i64>> {
    let datetime = tz.with_ymd_and_hms(
        timestamp_info.year,
        timestamp_info.month,
        timestamp_info.day,
        timestamp_info.hour,
        timestamp_info.minute,
        timestamp_info.second,
    );

    // `single()` yields None when the local time is ambiguous or invalid.
    let tz_datetime = match datetime.single() {
        Some(dt) => dt
            .with_timezone(tz)
            .with_nanosecond(timestamp_info.microsecond * 1000),
        None => {
            return Err(SparkError::Internal(
                "Failed to parse timestamp".to_string(),
            ));
        }
    };

    let result = match tz_datetime {
        Some(dt) => dt.timestamp_micros(),
        None => {
            return Err(SparkError::Internal(
                "Failed to parse timestamp".to_string(),
            ));
        }
    };

    Ok(Some(result))
}
1930
/// Splits a (regex-pre-validated) timestamp string on `T`, `-`, `:` and `.`
/// and builds a `TimeStampInfo` populated down to the granularity named by
/// `timestamp_type` ("year" .. "microsecond"); finer components keep their
/// defaults.
fn get_timestamp_values<T: TimeZone>(
    value: &str,
    timestamp_type: &str,
    tz: &T,
) -> SparkResult<Option<i64>> {
    let values: Vec<_> = value.split(['T', '-', ':', '.']).collect();
    // Missing or unparsable month/day default to 1; time components to 0.
    let year = values[0].parse::<i32>().unwrap_or_default();
    let month = values.get(1).map_or(1, |m| m.parse::<u32>().unwrap_or(1));
    let day = values.get(2).map_or(1, |d| d.parse::<u32>().unwrap_or(1));
    let hour = values.get(3).map_or(0, |h| h.parse::<u32>().unwrap_or(0));
    let minute = values.get(4).map_or(0, |m| m.parse::<u32>().unwrap_or(0));
    let second = values.get(5).map_or(0, |s| s.parse::<u32>().unwrap_or(0));
    let microsecond = values.get(6).map_or(0, |ms| ms.parse::<u32>().unwrap_or(0));

    let mut timestamp_info = TimeStampInfo::default();

    // Populate only up to the requested granularity.
    let timestamp_info = match timestamp_type {
        "year" => timestamp_info.with_year(year),
        "month" => timestamp_info.with_year(year).with_month(month),
        "day" => timestamp_info
            .with_year(year)
            .with_month(month)
            .with_day(day),
        "hour" => timestamp_info
            .with_year(year)
            .with_month(month)
            .with_day(day)
            .with_hour(hour),
        "minute" => timestamp_info
            .with_year(year)
            .with_month(month)
            .with_day(day)
            .with_hour(hour)
            .with_minute(minute),
        "second" => timestamp_info
            .with_year(year)
            .with_month(month)
            .with_day(day)
            .with_hour(hour)
            .with_minute(minute)
            .with_second(second),
        "microsecond" => timestamp_info
            .with_year(year)
            .with_month(month)
            .with_day(day)
            .with_hour(hour)
            .with_minute(minute)
            .with_second(second)
            .with_microsecond(microsecond),
        _ => {
            return Err(SparkError::CastInvalidValue {
                value: value.to_string(),
                from_type: "STRING".to_string(),
                to_type: "TIMESTAMP".to_string(),
            })
        }
    };

    parse_timestamp_to_micros(timestamp_info, tz)
}
1991
/// Parses a year-only string (e.g. "2020") to epoch microseconds in `tz`.
fn parse_str_to_year_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
    get_timestamp_values(value, "year", tz)
}
1995
/// Parses a year-month string (e.g. "2020-01") to epoch microseconds in `tz`.
fn parse_str_to_month_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
    get_timestamp_values(value, "month", tz)
}
1999
/// Parses a date string (e.g. "2020-01-01") to epoch microseconds in `tz`.
fn parse_str_to_day_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
    get_timestamp_values(value, "day", tz)
}
2003
/// Parses a date-plus-hour string (e.g. "2020-01-01T12") in `tz`.
fn parse_str_to_hour_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
    get_timestamp_values(value, "hour", tz)
}
2007
/// Parses a date-plus-minute string (e.g. "2020-01-01T12:34") in `tz`.
fn parse_str_to_minute_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
    get_timestamp_values(value, "minute", tz)
}
2011
/// Parses a date-plus-second string (e.g. "2020-01-01T12:34:56") in `tz`.
fn parse_str_to_second_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
    get_timestamp_values(value, "second", tz)
}
2015
2016fn parse_str_to_microsecond_timestamp<T: TimeZone>(
2017 value: &str,
2018 tz: &T,
2019) -> SparkResult<Option<i64>> {
2020 get_timestamp_values(value, "microsecond", tz)
2021}
2022
/// Parses a time-only string of the form `Thh` against the *current* date in
/// `tz` — the result therefore depends on when it runs. Missing components
/// default to 0, and any failure to set a component silently yields epoch 0
/// (`unwrap_or_default`).
fn parse_str_to_time_only_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
    // The `^T\d{1,2}$` pattern guarantees a non-empty component after 'T'.
    let values: Vec<&str> = value.split('T').collect();
    let time_values: Vec<u32> = values[1]
        .split(':')
        .map(|v| v.parse::<u32>().unwrap_or(0))
        .collect();

    let datetime = tz.from_utc_datetime(&chrono::Utc::now().naive_utc());
    let timestamp = datetime
        .with_timezone(tz)
        .with_hour(time_values.first().copied().unwrap_or_default())
        .and_then(|dt| dt.with_minute(*time_values.get(1).unwrap_or(&0)))
        .and_then(|dt| dt.with_second(*time_values.get(2).unwrap_or(&0)))
        .and_then(|dt| dt.with_nanosecond(*time_values.get(3).unwrap_or(&0) * 1_000))
        .map(|dt| dt.timestamp_micros())
        .unwrap_or_default();

    Ok(Some(timestamp))
}
2042
/// Parses a date string into days since the Unix epoch following Spark's
/// cast-to-date rules. Accepted shapes: `[+-]yyyy[y...]`, with optional
/// `-[m]m` and `-[d]d` segments, optionally followed by a ' '/'T' suffix
/// which is ignored. Invalid input yields `Ok(None)`, or a `CastInvalidValue`
/// error in ANSI mode.
fn date_parser(date_str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>> {
    // Index of the first byte that is not whitespace/control.
    fn get_trimmed_start(bytes: &[u8]) -> usize {
        let mut start = 0;
        while start < bytes.len() && is_whitespace_or_iso_control(bytes[start]) {
            start += 1;
        }
        start
    }

    // One past the last byte that is not whitespace/control.
    fn get_trimmed_end(start: usize, bytes: &[u8]) -> usize {
        let mut end = bytes.len() - 1;
        while end > start && is_whitespace_or_iso_control(bytes[end]) {
            end -= 1;
        }
        end + 1
    }

    fn is_whitespace_or_iso_control(byte: u8) -> bool {
        byte.is_ascii_whitespace() || byte.is_ascii_control()
    }

    // Segment 0 (year) accepts 4..=7 digits; month/day accept 1 or 2.
    fn is_valid_digits(segment: i32, digits: usize) -> bool {
        let max_digits_year = 7;
        (segment == 0 && digits >= 4 && digits <= max_digits_year)
            || (segment != 0 && digits > 0 && digits <= 2)
    }

    // Error policy: `Ok(None)` in Legacy/Try, `CastInvalidValue` in ANSI.
    fn return_result(date_str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>> {
        if eval_mode == EvalMode::Ansi {
            Err(SparkError::CastInvalidValue {
                value: date_str.to_string(),
                from_type: "STRING".to_string(),
                to_type: "DATE".to_string(),
            })
        } else {
            Ok(None)
        }
    }
    if date_str.is_empty() {
        return return_result(date_str, eval_mode);
    }

    // Missing month/day default to 1 (e.g. "2020" parses as 2020-01-01).
    let mut date_segments = [1, 1, 1];
    let mut sign = 1;
    let mut current_segment = 0;
    // `Wrapping` lets absurdly long digit runs wrap instead of panicking; the
    // wrapped value is then rejected by `NaiveDate::from_ymd_opt` below.
    let mut current_segment_value = Wrapping(0);
    let mut current_segment_digits = 0;
    let bytes = date_str.as_bytes();

    let mut j = get_trimmed_start(bytes);
    let str_end_trimmed = get_trimmed_end(j, bytes);

    if j == str_end_trimmed {
        // Input was all whitespace/control characters.
        return return_result(date_str, eval_mode);
    }

    // Optional leading sign; negative years are allowed.
    if bytes[j] == b'-' || bytes[j] == b'+' {
        sign = if bytes[j] == b'-' { -1 } else { 1 };
        j += 1;
    }

    // Scan up to three '-'-separated segments, stopping at ' ' or 'T' (the
    // start of a time portion, which is ignored for dates).
    while j < str_end_trimmed && (current_segment < 3 && !(bytes[j] == b' ' || bytes[j] == b'T')) {
        let b = bytes[j];
        if current_segment < 2 && b == b'-' {
            if !is_valid_digits(current_segment, current_segment_digits) {
                return return_result(date_str, eval_mode);
            }
            date_segments[current_segment as usize] = current_segment_value.0;
            current_segment_value = Wrapping(0);
            current_segment_digits = 0;
            current_segment += 1;
        } else if !b.is_ascii_digit() {
            return return_result(date_str, eval_mode);
        } else {
            let parsed_value = Wrapping((b - b'0') as i32);
            current_segment_value = current_segment_value * Wrapping(10) + parsed_value;
            current_segment_digits += 1;
        }
        j += 1;
    }

    if !is_valid_digits(current_segment, current_segment_digits) {
        return return_result(date_str, eval_mode);
    }

    // A ' '/'T' suffix is only valid once year, month and day are all present.
    if current_segment < 2 && j < str_end_trimmed {
        return return_result(date_str, eval_mode);
    }

    date_segments[current_segment as usize] = current_segment_value.0;

    // Out-of-range components (month 13, wrapped years, ...) yield None here.
    match NaiveDate::from_ymd_opt(
        sign * date_segments[0],
        date_segments[1] as u32,
        date_segments[2] as u32,
    ) {
        Some(date) => {
            let duration_since_epoch = date
                .signed_duration_since(DateTime::UNIX_EPOCH.naive_utc().date())
                .num_days();
            Ok(Some(duration_since_epoch.to_i32().unwrap()))
        }
        None => Ok(None),
    }
}
2164
2165fn spark_cast_postprocess(array: ArrayRef, from_type: &DataType, to_type: &DataType) -> ArrayRef {
2171 match (from_type, to_type) {
2172 (DataType::Timestamp(_, _), DataType::Int64) => {
2173 unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, MICROS_PER_SECOND)).unwrap()
2175 }
2176 (DataType::Dictionary(_, value_type), DataType::Int64)
2177 if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) =>
2178 {
2179 unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, MICROS_PER_SECOND)).unwrap()
2181 }
2182 (DataType::Timestamp(_, _), DataType::Utf8) => remove_trailing_zeroes(array),
2183 (DataType::Dictionary(_, value_type), DataType::Utf8)
2184 if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) =>
2185 {
2186 remove_trailing_zeroes(array)
2187 }
2188 _ => array,
2189 }
2190}
2191
/// Applies an infallible elementwise operation `op` to a primitive array,
/// recursing into dictionary arrays by rewriting their values (keys are kept
/// unchanged). Errors if the array is not a primitive array of type `T`.
fn unary_dyn<F, T>(array: &ArrayRef, op: F) -> Result<ArrayRef, ArrowError>
where
    T: ArrowPrimitiveType,
    F: Fn(T::Native) -> T::Native,
{
    // Dictionary case: transform the values array, reuse the keys.
    if let Some(d) = array.as_any_dictionary_opt() {
        let new_values = unary_dyn::<F, T>(d.values(), op)?;
        return Ok(Arc::new(d.with_values(Arc::new(new_values))));
    }

    match array.as_primitive_opt::<T>() {
        Some(a) if PrimitiveArray::<T>::is_compatible(a.data_type()) => {
            Ok(Arc::new(unary::<T, F, T>(
                array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
                op,
            )))
        }
        _ => Err(ArrowError::NotYetImplemented(format!(
            "Cannot perform unary operation of type {} on array of type {}",
            T::DATA_TYPE,
            array.data_type()
        ))),
    }
}
2217
/// Removes trailing fractional zeroes from every string in the array (see
/// `trim_end`), e.g. "...56.123000" becomes "...56.123". Expects a `Utf8`
/// (i32-offset) string array.
fn remove_trailing_zeroes(array: ArrayRef) -> ArrayRef {
    let string_array = as_generic_string_array::<i32>(&array).unwrap();
    let result = string_array
        .iter()
        .map(|s| s.map(trim_end))
        .collect::<GenericStringArray<i32>>();
    Arc::new(result) as ArrayRef
}
2234
/// Trims trailing '0' characters from `s`, but only when the string contains
/// a decimal point; an integral string such as "100" is returned unchanged.
/// Note "56.000" becomes "56." — the dot itself is preserved.
fn trim_end(s: &str) -> &str {
    match s.rfind('.') {
        Some(_) => s.trim_end_matches('0'),
        None => s,
    }
}
2242
2243#[cfg(test)]
2244mod tests {
2245 use arrow::array::StringArray;
2246 use arrow::datatypes::TimestampMicrosecondType;
2247 use arrow::datatypes::{Field, Fields, TimeUnit};
2248 use core::f64;
2249 use std::str::FromStr;
2250
2251 use super::*;
2252
    /// Checks `timestamp_parser` in UTC across every supported granularity
    /// (year through microsecond) for 4-digit years ("2020", "0100") and a
    /// 5-digit year ("10000"), comparing against epoch microseconds.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn timestamp_parser_test() {
        let tz = &timezone::Tz::from_str("UTC").unwrap();
        assert_eq!(
            timestamp_parser("2020", EvalMode::Legacy, tz).unwrap(),
            Some(1577836800000000)
        );
        assert_eq!(
            timestamp_parser("2020-01", EvalMode::Legacy, tz).unwrap(),
            Some(1577836800000000)
        );
        assert_eq!(
            timestamp_parser("2020-01-01", EvalMode::Legacy, tz).unwrap(),
            Some(1577836800000000)
        );
        assert_eq!(
            timestamp_parser("2020-01-01T12", EvalMode::Legacy, tz).unwrap(),
            Some(1577880000000000)
        );
        assert_eq!(
            timestamp_parser("2020-01-01T12:34", EvalMode::Legacy, tz).unwrap(),
            Some(1577882040000000)
        );
        assert_eq!(
            timestamp_parser("2020-01-01T12:34:56", EvalMode::Legacy, tz).unwrap(),
            Some(1577882096000000)
        );
        assert_eq!(
            timestamp_parser("2020-01-01T12:34:56.123456", EvalMode::Legacy, tz).unwrap(),
            Some(1577882096123456)
        );
        assert_eq!(
            timestamp_parser("0100", EvalMode::Legacy, tz).unwrap(),
            Some(-59011459200000000)
        );
        assert_eq!(
            timestamp_parser("0100-01", EvalMode::Legacy, tz).unwrap(),
            Some(-59011459200000000)
        );
        assert_eq!(
            timestamp_parser("0100-01-01", EvalMode::Legacy, tz).unwrap(),
            Some(-59011459200000000)
        );
        assert_eq!(
            timestamp_parser("0100-01-01T12", EvalMode::Legacy, tz).unwrap(),
            Some(-59011416000000000)
        );
        assert_eq!(
            timestamp_parser("0100-01-01T12:34", EvalMode::Legacy, tz).unwrap(),
            Some(-59011413960000000)
        );
        assert_eq!(
            timestamp_parser("0100-01-01T12:34:56", EvalMode::Legacy, tz).unwrap(),
            Some(-59011413904000000)
        );
        assert_eq!(
            timestamp_parser("0100-01-01T12:34:56.123456", EvalMode::Legacy, tz).unwrap(),
            Some(-59011413903876544)
        );
        assert_eq!(
            timestamp_parser("10000", EvalMode::Legacy, tz).unwrap(),
            Some(253402300800000000)
        );
        assert_eq!(
            timestamp_parser("10000-01", EvalMode::Legacy, tz).unwrap(),
            Some(253402300800000000)
        );
        assert_eq!(
            timestamp_parser("10000-01-01", EvalMode::Legacy, tz).unwrap(),
            Some(253402300800000000)
        );
        assert_eq!(
            timestamp_parser("10000-01-01T12", EvalMode::Legacy, tz).unwrap(),
            Some(253402344000000000)
        );
        assert_eq!(
            timestamp_parser("10000-01-01T12:34", EvalMode::Legacy, tz).unwrap(),
            Some(253402346040000000)
        );
        assert_eq!(
            timestamp_parser("10000-01-01T12:34:56", EvalMode::Legacy, tz).unwrap(),
            Some(253402346096000000)
        );
        assert_eq!(
            timestamp_parser("10000-01-01T12:34:56.123456", EvalMode::Legacy, tz).unwrap(),
            Some(253402346096123456)
        );
    }
2347
    /// Exercises the `cast_utf8_to_timestamp!` macro end-to-end, including a
    /// time-only value ("T2") and 4/5-digit years, and checks the output type
    /// and row count.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_cast_string_to_timestamp() {
        let array: ArrayRef = Arc::new(StringArray::from(vec![
            Some("2020-01-01T12:34:56.123456"),
            Some("T2"),
            Some("0100-01-01T12:34:56.123456"),
            Some("10000-01-01T12:34:56.123456"),
        ]));
        let tz = &timezone::Tz::from_str("UTC").unwrap();

        let string_array = array
            .as_any()
            .downcast_ref::<GenericStringArray<i32>>()
            .expect("Expected a string array");

        let eval_mode = EvalMode::Legacy;
        let result = cast_utf8_to_timestamp!(
            &string_array,
            eval_mode,
            TimestampMicrosecondType,
            timestamp_parser,
            tz
        );

        assert_eq!(
            result.data_type(),
            &DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into()))
        );
        assert_eq!(result.len(), 4);
    }
2379
    /// Verifies that a dictionary-encoded string array can be cast to a
    /// timestamp and that the result carries the requested timezone.
    #[test]
    fn test_cast_dict_string_to_timestamp() -> DataFusionResult<()> {
        let keys = Int32Array::from(vec![0, 1]);
        let values: ArrayRef = Arc::new(StringArray::from(vec![
            Some("2020-01-01T12:34:56.123456"),
            Some("T2"),
        ]));
        let dict_array = Arc::new(DictionaryArray::new(keys, values));

        let timezone = "UTC".to_string();
        let cast_options = SparkCastOptions::new(EvalMode::Legacy, &timezone, false);
        let result = cast_array(
            dict_array,
            &DataType::Timestamp(TimeUnit::Microsecond, Some(timezone.clone().into())),
            &cast_options,
        )?;
        assert_eq!(
            *result.data_type(),
            DataType::Timestamp(TimeUnit::Microsecond, Some(timezone.into()))
        );
        assert_eq!(result.len(), 2);

        Ok(())
    }
2406
    /// Covers `date_parser` on valid dates (including padded years and a
    /// trailing 'T'/space suffix), invalid dates (error in ANSI, None in
    /// Legacy/Try), a negative year, and boundary-exceeding years.
    #[test]
    fn date_parser_test() {
        // Valid forms that all resolve to 2020-01-01 (epoch day 18262).
        for date in &[
            "2020",
            "2020-01",
            "2020-01-01",
            "02020-01-01",
            "002020-01-01",
            "0002020-01-01",
            "2020-1-1",
            "2020-01-01 ",
            "2020-01-01T",
        ] {
            for eval_mode in &[EvalMode::Legacy, EvalMode::Ansi, EvalMode::Try] {
                assert_eq!(date_parser(date, *eval_mode).unwrap(), Some(18262));
            }
        }

        // Malformed dates: None in Legacy/Try, error in ANSI.
        for date in &[
            "abc",
            "",
            "not_a_date",
            "3/",
            "3/12",
            "3/12/2020",
            "3/12/2002 T",
            "202",
            "2020-010-01",
            "2020-10-010",
            "2020-10-010T",
            "--262143-12-31",
            "--262143-12-31 ",
        ] {
            for eval_mode in &[EvalMode::Legacy, EvalMode::Try] {
                assert_eq!(date_parser(date, *eval_mode).unwrap(), None);
            }
            assert!(date_parser(date, EvalMode::Ansi).is_err());
        }

        // Negative year with defaulted day.
        for date in &["-3638-5"] {
            for eval_mode in &[EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] {
                assert_eq!(date_parser(date, *eval_mode).unwrap(), Some(-2048160));
            }
        }

        // Out-of-range years: None in every mode (rejected by from_ymd_opt,
        // not by the ANSI error path).
        for date in &[
            "-262144-1-1",
            "262143-01-1",
            "262143-1-1",
            "262143-01-1 ",
            "262143-01-01T ",
            "262143-1-01T 1234",
            "-0973250",
        ] {
            for eval_mode in &[EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] {
                assert_eq!(date_parser(date, *eval_mode).unwrap(), None);
            }
        }
    }
2469
    /// Happy-path array cast: several valid spellings of 2020-01-01 all map
    /// to epoch day 18262.
    #[test]
    fn test_cast_string_to_date() {
        let array: ArrayRef = Arc::new(StringArray::from(vec![
            Some("2020"),
            Some("2020-01"),
            Some("2020-01-01"),
            Some("2020-01-01T"),
        ]));

        let result = cast_string_to_date(&array, &DataType::Date32, EvalMode::Legacy).unwrap();

        let date32_array = result
            .as_any()
            .downcast_ref::<arrow::array::Date32Array>()
            .unwrap();
        assert_eq!(date32_array.len(), 4);
        date32_array
            .iter()
            .for_each(|v| assert_eq!(v.unwrap(), 18262));
    }
2490
#[test]
fn test_cast_string_array_with_valid_dates() {
    // Each entry is the same extreme date (-262143-12-31) surrounded by
    // whitespace and trailing-section noise that the parser must tolerate.
    let input: ArrayRef = Arc::new(StringArray::from(vec![
        Some("-262143-12-31"),
        Some("\n -262143-12-31 "),
        Some("-262143-12-31T \t\n"),
        Some("\n\t-262143-12-31T\r"),
        Some("-262143-12-31T 123123123"),
        Some("\r\n-262143-12-31T \r123123123"),
        Some("\n -262143-12-31T \n\t"),
    ]));

    for eval_mode in [EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] {
        let casted = cast_string_to_date(&input, &DataType::Date32, eval_mode).unwrap();
        assert_eq!(casted.len(), 7);

        let dates = casted
            .as_any()
            .downcast_ref::<arrow::array::Date32Array>()
            .unwrap();
        for value in dates.iter() {
            assert_eq!(value.unwrap(), -96464928);
        }
    }
}
2518
#[test]
fn test_cast_string_array_with_invalid_dates() {
    // A mix of parseable and malformed date strings.
    let input: ArrayRef = Arc::new(StringArray::from(vec![
        Some("2020"),
        Some("2020-01"),
        Some("2020-01-01"),
        Some("2020-010-01T"),
        Some("202"),
        Some(" 202 "),
        Some("\n 2020-\r8 "),
        Some("2020-01-01T"),
        Some("-4607172990231812908"),
    ]));

    // Legacy and Try modes replace each malformed entry with null.
    let expected = vec![
        Some(18262),
        Some(18262),
        Some(18262),
        None,
        None,
        None,
        None,
        Some(18262),
        None,
    ];
    for eval_mode in [EvalMode::Legacy, EvalMode::Try] {
        let casted = cast_string_to_date(&input, &DataType::Date32, eval_mode).unwrap();
        let dates = casted
            .as_any()
            .downcast_ref::<arrow::array::Date32Array>()
            .unwrap();
        assert_eq!(dates.iter().collect::<Vec<_>>(), expected);
    }

    // Ansi mode must fail on the first malformed value.
    match cast_string_to_date(&input, &DataType::Date32, EvalMode::Ansi) {
        Err(e) => assert!(
            e.to_string().contains(
                "[CAST_INVALID_INPUT] The value '2020-010-01T' of the type \"STRING\" cannot be cast to \"DATE\" because it is malformed")
        ),
        _ => panic!("Expected error"),
    }
}
2570
#[test]
fn test_cast_string_as_i8() {
    // An in-range integer string parses normally.
    assert_eq!(
        cast_string_to_i8("127", EvalMode::Legacy).unwrap(),
        Some(127_i8)
    );
    // An out-of-range value: Legacy yields null, Ansi raises.
    assert_eq!(cast_string_to_i8("128", EvalMode::Legacy).unwrap(), None);
    assert!(cast_string_to_i8("128", EvalMode::Ansi).is_err());
    // Fractional / degenerate decimal strings: Legacy maps them to zero,
    // Try yields null, and Ansi raises.
    for input in ["0.2", "."] {
        assert_eq!(
            cast_string_to_i8(input, EvalMode::Legacy).unwrap(),
            Some(0_i8)
        );
        assert_eq!(cast_string_to_i8(input, EvalMode::Try).unwrap(), None);
        assert!(cast_string_to_i8(input, EvalMode::Ansi).is_err());
    }
}
2596
#[test]
fn test_cast_unsupported_timestamp_to_date() {
    // Casting an extreme timestamp (i64::MAX microseconds) to Date32 is
    // expected to fail rather than silently produce a bogus date.
    let timestamps: PrimitiveArray<TimestampMicrosecondType> = vec![i64::MAX].into();
    let tz_array = Arc::new(timestamps.with_timezone("Europe/Copenhagen"));
    let cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
    assert!(cast_array(tz_array, &DataType::Date32, &cast_options).is_err())
}
2609
#[test]
fn test_cast_invalid_timezone() {
    // An unparseable session timezone in the cast options must surface as
    // an error from the cast itself.
    let timestamps: PrimitiveArray<TimestampMicrosecondType> = vec![i64::MAX].into();
    let tz_array = Arc::new(timestamps.with_timezone("Europe/Copenhagen"));
    let cast_options = SparkCastOptions::new(EvalMode::Legacy, "Not a valid timezone", false);
    assert!(cast_array(tz_array, &DataType::Date32, &cast_options).is_err())
}
2621
#[test]
fn test_cast_struct_to_utf8() {
    // Build a struct<a: int32, b: utf8> column with one null in `a`.
    let ints: ArrayRef = Arc::new(Int32Array::from(vec![
        Some(1),
        Some(2),
        None,
        Some(4),
        Some(5),
    ]));
    let strings: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
    let structs: ArrayRef = Arc::new(StructArray::from(vec![
        (Arc::new(Field::new("a", DataType::Int32, true)), ints),
        (Arc::new(Field::new("b", DataType::Utf8, true)), strings),
    ]));

    let casted = cast_array(
        structs,
        &DataType::Utf8,
        &SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
    )
    .unwrap();

    // Each row renders as "{a, b}" with nulls spelled out literally.
    let rendered = casted.as_string::<i32>();
    assert_eq!(5, rendered.len());
    let expected = ["{1, a}", "{2, b}", "{null, c}", "{4, d}", "{5, e}"];
    for (row, want) in expected.iter().enumerate() {
        assert_eq!(*want, rendered.value(row));
    }
}
2650
#[test]
fn test_cast_struct_to_struct() {
    // Source column: struct<a: int32, b: utf8> with one null in `a`.
    let ints: ArrayRef = Arc::new(Int32Array::from(vec![
        Some(1),
        Some(2),
        None,
        Some(4),
        Some(5),
    ]));
    let strings: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
    let structs: ArrayRef = Arc::new(StructArray::from(vec![
        (Arc::new(Field::new("a", DataType::Int32, true)), ints),
        (Arc::new(Field::new("b", DataType::Utf8, true)), strings),
    ]));

    // Target type keeps both fields but retypes `a` from int32 to utf8.
    let target_fields = Fields::from(vec![
        Field::new("a", DataType::Utf8, true),
        Field::new("b", DataType::Utf8, true),
    ]);
    let result = spark_cast(
        ColumnarValue::Array(structs),
        &DataType::Struct(target_fields),
        &SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
    )
    .unwrap();

    match result {
        ColumnarValue::Array(array) => {
            assert_eq!(5, array.len());
            let a = array.as_struct().column(0).as_string::<i32>();
            assert_eq!("1", a.value(0));
        }
        _ => unreachable!(),
    }
}
2684
#[test]
fn test_cast_struct_to_struct_drop_column() {
    // Source column: struct<a: int32, b: utf8> with one null in `a`.
    let ints: ArrayRef = Arc::new(Int32Array::from(vec![
        Some(1),
        Some(2),
        None,
        Some(4),
        Some(5),
    ]));
    let strings: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
    let structs: ArrayRef = Arc::new(StructArray::from(vec![
        (Arc::new(Field::new("a", DataType::Int32, true)), ints),
        (Arc::new(Field::new("b", DataType::Utf8, true)), strings),
    ]));

    // The target struct only names field `a`, so `b` must be dropped.
    let target_fields = Fields::from(vec![Field::new("a", DataType::Utf8, true)]);
    let result = spark_cast(
        ColumnarValue::Array(structs),
        &DataType::Struct(target_fields),
        &SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
    )
    .unwrap();

    match result {
        ColumnarValue::Array(array) => {
            assert_eq!(5, array.len());
            let struct_array = array.as_struct();
            assert_eq!(1, struct_array.columns().len());
            let a = struct_array.column(0).as_string::<i32>();
            assert_eq!("1", a.value(0));
        }
        _ => unreachable!(),
    }
}
2717
#[test]
#[cfg_attr(miri, ignore)]
fn test_cast_float_to_decimal() {
    // Cast f64 values to decimal(8, 6). Values that do not fit the target
    // precision, non-finite values, and nulls are all expected to be null.
    let input: ArrayRef = Arc::new(Float64Array::from(vec![
        Some(42.),
        Some(0.5153125),
        Some(-42.4242415),
        Some(42e-314),
        Some(0.),
        Some(-4242.424242),
        Some(f64::INFINITY),
        Some(f64::NEG_INFINITY),
        Some(f64::NAN),
        None,
    ]));
    let result =
        cast_floating_point_to_decimal128::<Float64Type>(&input, 8, 6, EvalMode::Legacy).unwrap();
    assert_eq!(result.len(), input.len());

    let decimals = result.as_primitive::<Decimal128Type>();
    assert_eq!(decimals.value(0), 42000000);
    assert_eq!(decimals.value(2), -42424242);
    assert_eq!(decimals.value(3), 0); // 42e-314 is too small to register at scale 6
    assert_eq!(decimals.value(4), 0);
    // Out-of-precision, +/-inf, NaN, and null inputs all map to null.
    for row in 5..=9 {
        assert!(decimals.is_null(row));
    }
}
2755}