use crate::timezone;
use crate::utils::array_with_timezone;
use crate::{EvalMode, SparkError, SparkResult};
use arrow::{
    array::{
        cast::AsArray,
        types::{Date32Type, Int16Type, Int32Type, Int8Type},
        Array, ArrayRef, BooleanArray, Decimal128Array, Float32Array, Float64Array,
        GenericStringArray, Int16Array, Int32Array, Int64Array, Int8Array, OffsetSizeTrait,
        PrimitiveArray,
    },
    compute::{cast_with_options, take, unary, CastOptions},
    datatypes::{
        ArrowPrimitiveType, Decimal128Type, DecimalType, Float32Type, Float64Type, Int64Type,
        TimestampMicrosecondType,
    },
    error::ArrowError,
    record_batch::RecordBatch,
    util::display::FormatOptions,
};
use arrow_array::builder::StringBuilder;
use arrow_array::{DictionaryArray, StringArray, StructArray};
use arrow_schema::{DataType, Schema};
use chrono::{NaiveDate, NaiveDateTime, TimeZone, Timelike};
use datafusion_common::{
    cast::as_generic_string_array, internal_err, Result as DataFusionResult, ScalarValue,
};
use datafusion_expr::ColumnarValue;
use datafusion_physical_expr::PhysicalExpr;
use num::{
    cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, Num,
    ToPrimitive,
};
use regex::Regex;
use std::collections::HashMap;
use std::str::FromStr;
use std::{
    any::Any,
    fmt::{Debug, Display, Formatter},
    hash::Hash,
    num::Wrapping,
    sync::Arc,
};

static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f");

const MICROS_PER_SECOND: i64 = 1_000_000;

static CAST_OPTIONS: CastOptions = CastOptions {
    safe: true,
    format_options: FormatOptions::new()
        .with_timestamp_tz_format(TIMESTAMP_FORMAT)
        .with_timestamp_format(TIMESTAMP_FORMAT),
};

struct TimeStampInfo {
    year: i32,
    month: u32,
    day: u32,
    hour: u32,
    minute: u32,
    second: u32,
    microsecond: u32,
}

impl Default for TimeStampInfo {
    fn default() -> Self {
        TimeStampInfo {
            year: 1,
            month: 1,
            day: 1,
            hour: 0,
            minute: 0,
            second: 0,
            microsecond: 0,
        }
    }
}

impl TimeStampInfo {
    pub fn with_year(&mut self, year: i32) -> &mut Self {
        self.year = year;
        self
    }

    pub fn with_month(&mut self, month: u32) -> &mut Self {
        self.month = month;
        self
    }

    pub fn with_day(&mut self, day: u32) -> &mut Self {
        self.day = day;
        self
    }

    pub fn with_hour(&mut self, hour: u32) -> &mut Self {
        self.hour = hour;
        self
    }

    pub fn with_minute(&mut self, minute: u32) -> &mut Self {
        self.minute = minute;
        self
    }

    pub fn with_second(&mut self, second: u32) -> &mut Self {
        self.second = second;
        self
    }

    pub fn with_microsecond(&mut self, microsecond: u32) -> &mut Self {
        self.microsecond = microsecond;
        self
    }
}

#[derive(Debug, Eq)]
pub struct Cast {
    pub child: Arc<dyn PhysicalExpr>,
    pub data_type: DataType,
    pub cast_options: SparkCastOptions,
}

impl PartialEq for Cast {
    fn eq(&self, other: &Self) -> bool {
        self.child.eq(&other.child)
            && self.data_type.eq(&other.data_type)
            && self.cast_options.eq(&other.cast_options)
    }
}

impl Hash for Cast {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.child.hash(state);
        self.data_type.hash(state);
        self.cast_options.hash(state);
    }
}

/// Returns true if a cast from `from_type` to `to_type` is supported with the given
/// options. Dictionary-encoded types are checked against their value types, and struct
/// casts are supported when every field-to-field cast is supported.
pub fn cast_supported(
    from_type: &DataType,
    to_type: &DataType,
    options: &SparkCastOptions,
) -> bool {
    use DataType::*;

    let from_type = if let Dictionary(_, dt) = from_type {
        dt
    } else {
        from_type
    };

    let to_type = if let Dictionary(_, dt) = to_type {
        dt
    } else {
        to_type
    };

    if from_type == to_type {
        return true;
    }

    match (from_type, to_type) {
        (Boolean, _) => can_cast_from_boolean(to_type, options),
        (UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
            if options.allow_cast_unsigned_ints =>
        {
            true
        }
        (Int8, _) => can_cast_from_byte(to_type, options),
        (Int16, _) => can_cast_from_short(to_type, options),
        (Int32, _) => can_cast_from_int(to_type, options),
        (Int64, _) => can_cast_from_long(to_type, options),
        (Float32, _) => can_cast_from_float(to_type, options),
        (Float64, _) => can_cast_from_double(to_type, options),
        (Decimal128(p, s), _) => can_cast_from_decimal(p, s, to_type, options),
        (Timestamp(_, None), _) => can_cast_from_timestamp_ntz(to_type, options),
        (Timestamp(_, Some(_)), _) => can_cast_from_timestamp(to_type, options),
        (Utf8 | LargeUtf8, _) => can_cast_from_string(to_type, options),
        (_, Utf8 | LargeUtf8) => can_cast_to_string(from_type, options),
        (Struct(from_fields), Struct(to_fields)) => from_fields
            .iter()
            .zip(to_fields.iter())
            .all(|(a, b)| cast_supported(a.data_type(), b.data_type(), options)),
        _ => false,
    }
}
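
// Illustrative sketch (not part of the original file): how `cast_supported` can be
// queried before planning a cast. The module name is an assumption added for this
// example.
#[cfg(test)]
mod cast_supported_examples {
    use super::*;

    #[test]
    fn string_casts() {
        let opts = SparkCastOptions::new_without_timezone(EvalMode::Legacy, false);
        // String-to-integer casts are fully supported.
        assert!(cast_supported(&DataType::Utf8, &DataType::Int32, &opts));
        // String-to-float casts are only supported when incompatible behavior is allowed.
        assert!(!cast_supported(&DataType::Utf8, &DataType::Float64, &opts));
        let lenient = SparkCastOptions::new_without_timezone(EvalMode::Legacy, true);
        assert!(cast_supported(&DataType::Utf8, &DataType::Float64, &lenient));
    }
}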

fn can_cast_from_string(to_type: &DataType, options: &SparkCastOptions) -> bool {
    use DataType::*;
    match to_type {
        Boolean | Int8 | Int16 | Int32 | Int64 | Binary => true,
        // These casts are not guaranteed to match Spark exactly, so they require
        // `allow_incompat`.
        Float32 | Float64 => options.allow_incompat,
        Decimal128(_, _) => options.allow_incompat,
        Date32 | Date64 => options.allow_incompat,
        // String-to-timestamp casts are not supported in ANSI mode.
        Timestamp(_, _) if options.eval_mode == EvalMode::Ansi => false,
        // Non-UTC timezones are only supported when incompatible behavior is allowed.
        Timestamp(_, Some(tz)) if tz.as_ref() != "UTC" => options.allow_incompat,
        Timestamp(_, _) => options.allow_incompat,
        _ => false,
    }
}

fn can_cast_to_string(from_type: &DataType, options: &SparkCastOptions) -> bool {
    use DataType::*;
    match from_type {
        Boolean | Int8 | Int16 | Int32 | Int64 | Date32 | Date64 | Timestamp(_, _) => true,
        Float32 | Float64 => true,
        Decimal128(_, _) => true,
        Binary => options.allow_incompat,
        Struct(fields) => fields
            .iter()
            .all(|f| can_cast_to_string(f.data_type(), options)),
        _ => false,
    }
}

fn can_cast_from_timestamp_ntz(to_type: &DataType, options: &SparkCastOptions) -> bool {
    use DataType::*;
    match to_type {
        Timestamp(_, _) | Date32 | Date64 | Utf8 => options.allow_incompat,
        _ => false,
    }
}

fn can_cast_from_timestamp(to_type: &DataType, _options: &SparkCastOptions) -> bool {
    use DataType::*;
    match to_type {
        // Not supported.
        Boolean | Int8 | Int16 => false,
        Int64 => true,
        Date32 | Date64 | Utf8 | Decimal128(_, _) => true,
        _ => false,
    }
}

fn can_cast_from_boolean(to_type: &DataType, _: &SparkCastOptions) -> bool {
    use DataType::*;
    matches!(to_type, Int8 | Int16 | Int32 | Int64 | Float32 | Float64)
}

fn can_cast_from_byte(to_type: &DataType, _: &SparkCastOptions) -> bool {
    use DataType::*;
    matches!(
        to_type,
        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
    )
}

fn can_cast_from_short(to_type: &DataType, _: &SparkCastOptions) -> bool {
    use DataType::*;
    matches!(
        to_type,
        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
    )
}

fn can_cast_from_int(to_type: &DataType, options: &SparkCastOptions) -> bool {
    use DataType::*;
    match to_type {
        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Utf8 => true,
        Decimal128(_, _) => options.allow_incompat,
        _ => false,
    }
}

fn can_cast_from_long(to_type: &DataType, options: &SparkCastOptions) -> bool {
    use DataType::*;
    match to_type {
        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 => true,
        Decimal128(_, _) => options.allow_incompat,
        _ => false,
    }
}

fn can_cast_from_float(to_type: &DataType, _: &SparkCastOptions) -> bool {
    use DataType::*;
    matches!(
        to_type,
        Boolean | Int8 | Int16 | Int32 | Int64 | Float64 | Decimal128(_, _)
    )
}

fn can_cast_from_double(to_type: &DataType, _: &SparkCastOptions) -> bool {
    use DataType::*;
    matches!(
        to_type,
        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Decimal128(_, _)
    )
}

fn can_cast_from_decimal(
    p1: &u8,
    _s1: &i8,
    to_type: &DataType,
    options: &SparkCastOptions,
) -> bool {
    use DataType::*;
    match to_type {
        Int8 | Int16 | Int32 | Int64 | Float32 | Float64 => true,
        Decimal128(p2, _) => {
            if p2 < p1 {
                // Casting to a lower precision may truncate.
                options.allow_incompat
            } else {
                true
            }
        }
        _ => false,
    }
}

macro_rules! cast_utf8_to_int {
    ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{
        let len = $array.len();
        let mut cast_array = PrimitiveArray::<$array_type>::builder(len);
        for i in 0..len {
            if $array.is_null(i) {
                cast_array.append_null()
            } else if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? {
                cast_array.append_value(cast_value);
            } else {
                cast_array.append_null()
            }
        }
        let result: SparkResult<ArrayRef> = Ok(Arc::new(cast_array.finish()) as ArrayRef);
        result
    }};
}

macro_rules! cast_utf8_to_timestamp {
    ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident, $tz:expr) => {{
        let len = $array.len();
        let mut cast_array = PrimitiveArray::<$array_type>::builder(len).with_timezone("UTC");
        for i in 0..len {
            if $array.is_null(i) {
                cast_array.append_null()
            } else if let Ok(Some(cast_value)) =
                $cast_method($array.value(i).trim(), $eval_mode, $tz)
            {
                cast_array.append_value(cast_value);
            } else {
                // Note: parser errors are swallowed here, so invalid inputs become null.
                cast_array.append_null()
            }
        }
        let result: ArrayRef = Arc::new(cast_array.finish()) as ArrayRef;
        result
    }};
}

macro_rules! cast_float_to_string {
    ($from:expr, $eval_mode:expr, $type:ty, $output_type:ty, $offset_type:ty) => {{
        fn cast<OffsetSize>(
            from: &dyn Array,
            _eval_mode: EvalMode,
        ) -> SparkResult<ArrayRef>
        where
            OffsetSize: OffsetSizeTrait,
        {
            let array = from.as_any().downcast_ref::<$output_type>().unwrap();

            // Values outside this range are formatted in scientific notation.
            const LOWER_SCIENTIFIC_BOUND: $type = 0.001;
            const UPPER_SCIENTIFIC_BOUND: $type = 10000000.0;

            let output_array = array
                .iter()
                .map(|value| match value {
                    Some(value) if value == <$type>::INFINITY => Ok(Some("Infinity".to_string())),
                    Some(value) if value == <$type>::NEG_INFINITY => Ok(Some("-Infinity".to_string())),
                    Some(value)
                        if (value.abs() < UPPER_SCIENTIFIC_BOUND
                            && value.abs() >= LOWER_SCIENTIFIC_BOUND)
                            || value.abs() == 0.0 =>
                    {
                        let trailing_zero = if value.fract() == 0.0 { ".0" } else { "" };

                        Ok(Some(format!("{value}{trailing_zero}")))
                    }
                    Some(value)
                        if value.abs() >= UPPER_SCIENTIFIC_BOUND
                            || value.abs() < LOWER_SCIENTIFIC_BOUND =>
                    {
                        let formatted = format!("{value:E}");

                        if formatted.contains(".") {
                            Ok(Some(formatted))
                        } else {
                            // Rust omits the ".0" that Spark emits for a whole-number
                            // coefficient, e.g. 1E7 instead of 1.0E7, so insert it.
                            let prepare_number: Vec<&str> = formatted.split("E").collect();
                            let coefficient = prepare_number[0];
                            let exponent = prepare_number[1];

                            Ok(Some(format!("{coefficient}.0E{exponent}")))
                        }
                    }
                    Some(value) => Ok(Some(value.to_string())),
                    _ => Ok(None),
                })
                .collect::<Result<GenericStringArray<OffsetSize>, SparkError>>()?;

            Ok(Arc::new(output_array))
        }

        cast::<$offset_type>($from, $eval_mode)
    }};
}

macro_rules! cast_int_to_int_macro {
    (
        $array: expr,
        $eval_mode:expr,
        $from_arrow_primitive_type: ty,
        $to_arrow_primitive_type: ty,
        $from_data_type: expr,
        $to_native_type: ty,
        $spark_from_data_type_name: expr,
        $spark_to_data_type_name: expr
    ) => {{
        let cast_array = $array
            .as_any()
            .downcast_ref::<PrimitiveArray<$from_arrow_primitive_type>>()
            .unwrap();
        // Suffix used when formatting overflow errors, matching Spark's integer literals.
        let spark_int_literal_suffix = match $from_data_type {
            &DataType::Int64 => "L",
            &DataType::Int16 => "S",
            &DataType::Int8 => "T",
            _ => "",
        };

        let output_array = match $eval_mode {
            EvalMode::Legacy => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        Ok::<Option<$to_native_type>, SparkError>(Some(value as $to_native_type))
                    }
                    _ => Ok(None),
                })
                .collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, _>>(),
            _ => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let res = <$to_native_type>::try_from(value);
                        if res.is_err() {
                            Err(cast_overflow(
                                &(value.to_string() + spark_int_literal_suffix),
                                $spark_from_data_type_name,
                                $spark_to_data_type_name,
                            ))
                        } else {
                            Ok::<Option<$to_native_type>, SparkError>(Some(res.unwrap()))
                        }
                    }
                    _ => Ok(None),
                })
                .collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, _>>(),
        }?;
        let result: SparkResult<ArrayRef> = Ok(Arc::new(output_array) as ArrayRef);
        result
    }};
}

// Casts a floating-point array down to i8/i16, going through an intermediate i32
// value and checking for overflow in ANSI mode.
macro_rules! cast_float_to_int16_down {
    (
        $array:expr,
        $eval_mode:expr,
        $src_array_type:ty,
        $dest_array_type:ty,
        $rust_src_type:ty,
        $rust_dest_type:ty,
        $src_type_str:expr,
        $dest_type_str:expr,
        $format_str:expr
    ) => {{
        let cast_array = $array
            .as_any()
            .downcast_ref::<$src_array_type>()
            .expect(concat!("Expected a ", stringify!($src_array_type)));

        let output_array = match $eval_mode {
            EvalMode::Ansi => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let is_overflow = value.is_nan() || value.abs() as i32 == i32::MAX;
                        if is_overflow {
                            return Err(cast_overflow(
                                &format!($format_str, value).replace("e", "E"),
                                $src_type_str,
                                $dest_type_str,
                            ));
                        }
                        let i32_value = value as i32;
                        <$rust_dest_type>::try_from(i32_value)
                            .map_err(|_| {
                                cast_overflow(
                                    &format!($format_str, value).replace("e", "E"),
                                    $src_type_str,
                                    $dest_type_str,
                                )
                            })
                            .map(Some)
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
            _ => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let i32_value = value as i32;
                        Ok::<Option<$rust_dest_type>, SparkError>(Some(
                            i32_value as $rust_dest_type,
                        ))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
        };
        Ok(Arc::new(output_array) as ArrayRef)
    }};
}

macro_rules! cast_float_to_int32_up {
    (
        $array:expr,
        $eval_mode:expr,
        $src_array_type:ty,
        $dest_array_type:ty,
        $rust_src_type:ty,
        $rust_dest_type:ty,
        $src_type_str:expr,
        $dest_type_str:expr,
        $max_dest_val:expr,
        $format_str:expr
    ) => {{
        let cast_array = $array
            .as_any()
            .downcast_ref::<$src_array_type>()
            .expect(concat!("Expected a ", stringify!($src_array_type)));

        let output_array = match $eval_mode {
            EvalMode::Ansi => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let is_overflow =
                            value.is_nan() || value.abs() as $rust_dest_type == $max_dest_val;
                        if is_overflow {
                            return Err(cast_overflow(
                                &format!($format_str, value).replace("e", "E"),
                                $src_type_str,
                                $dest_type_str,
                            ));
                        }
                        Ok(Some(value as $rust_dest_type))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
            _ => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        Ok::<Option<$rust_dest_type>, SparkError>(Some(value as $rust_dest_type))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
        };
        Ok(Arc::new(output_array) as ArrayRef)
    }};
}

// Casts a decimal array down to i8/i16 by truncating the fractional part and
// checking the integral part for overflow in ANSI mode.
macro_rules! cast_decimal_to_int16_down {
    (
        $array:expr,
        $eval_mode:expr,
        $dest_array_type:ty,
        $rust_dest_type:ty,
        $dest_type_str:expr,
        $precision:expr,
        $scale:expr
    ) => {{
        let cast_array = $array
            .as_any()
            .downcast_ref::<Decimal128Array>()
            .expect(concat!("Expected a Decimal128ArrayType"));

        let output_array = match $eval_mode {
            EvalMode::Ansi => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let divisor = 10_i128.pow($scale as u32);
                        let (truncated, decimal) = (value / divisor, (value % divisor).abs());
                        let is_overflow = truncated.abs() > i32::MAX.into();
                        if is_overflow {
                            return Err(cast_overflow(
                                &format!("{}.{}BD", truncated, decimal),
                                &format!("DECIMAL({},{})", $precision, $scale),
                                $dest_type_str,
                            ));
                        }
                        let i32_value = truncated as i32;
                        <$rust_dest_type>::try_from(i32_value)
                            .map_err(|_| {
                                cast_overflow(
                                    &format!("{}.{}BD", truncated, decimal),
                                    &format!("DECIMAL({},{})", $precision, $scale),
                                    $dest_type_str,
                                )
                            })
                            .map(Some)
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
            _ => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let divisor = 10_i128.pow($scale as u32);
                        let i32_value = (value / divisor) as i32;
                        Ok::<Option<$rust_dest_type>, SparkError>(Some(
                            i32_value as $rust_dest_type,
                        ))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
        };
        Ok(Arc::new(output_array) as ArrayRef)
    }};
}

macro_rules! cast_decimal_to_int32_up {
    (
        $array:expr,
        $eval_mode:expr,
        $dest_array_type:ty,
        $rust_dest_type:ty,
        $dest_type_str:expr,
        $max_dest_val:expr,
        $precision:expr,
        $scale:expr
    ) => {{
        let cast_array = $array
            .as_any()
            .downcast_ref::<Decimal128Array>()
            .expect(concat!("Expected a Decimal128ArrayType"));

        let output_array = match $eval_mode {
            EvalMode::Ansi => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let divisor = 10_i128.pow($scale as u32);
                        let (truncated, decimal) = (value / divisor, (value % divisor).abs());
                        let is_overflow = truncated.abs() > $max_dest_val.into();
                        if is_overflow {
                            return Err(cast_overflow(
                                &format!("{}.{}BD", truncated, decimal),
                                &format!("DECIMAL({},{})", $precision, $scale),
                                $dest_type_str,
                            ));
                        }
                        Ok(Some(truncated as $rust_dest_type))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
            _ => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let divisor = 10_i128.pow($scale as u32);
                        let truncated = value / divisor;
                        Ok::<Option<$rust_dest_type>, SparkError>(Some(
                            truncated as $rust_dest_type,
                        ))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
        };
        Ok(Arc::new(output_array) as ArrayRef)
    }};
}

impl Cast {
    pub fn new(
        child: Arc<dyn PhysicalExpr>,
        data_type: DataType,
        cast_options: SparkCastOptions,
    ) -> Self {
        Self {
            child,
            data_type,
            cast_options,
        }
    }
}

#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct SparkCastOptions {
    /// Spark evaluation mode (LEGACY, ANSI, or TRY).
    pub eval_mode: EvalMode,
    /// Timezone to use for timezone-aware casts (typically the Spark session-local
    /// timezone).
    pub timezone: String,
    /// Allow casts that are supported but not guaranteed to match Spark's behavior
    /// exactly.
    pub allow_incompat: bool,
    /// Allow casting unsigned integer types to their signed counterparts.
    pub allow_cast_unsigned_ints: bool,
    /// Whether the cast is being used to adapt a schema rather than to evaluate a
    /// user-specified cast expression.
    pub is_adapting_schema: bool,
}

impl SparkCastOptions {
    pub fn new(eval_mode: EvalMode, timezone: &str, allow_incompat: bool) -> Self {
        Self {
            eval_mode,
            timezone: timezone.to_string(),
            allow_incompat,
            allow_cast_unsigned_ints: false,
            is_adapting_schema: false,
        }
    }

    pub fn new_without_timezone(eval_mode: EvalMode, allow_incompat: bool) -> Self {
        Self {
            eval_mode,
            timezone: "".to_string(),
            allow_incompat,
            allow_cast_unsigned_ints: false,
            is_adapting_schema: false,
        }
    }
}

/// Spark-compatible cast. Delegates to DataFusion's cast where the two agree and
/// overrides the behavior where Spark's semantics differ.
pub fn spark_cast(
    arg: ColumnarValue,
    data_type: &DataType,
    cast_options: &SparkCastOptions,
) -> DataFusionResult<ColumnarValue> {
    match arg {
        ColumnarValue::Array(array) => Ok(ColumnarValue::Array(cast_array(
            array,
            data_type,
            cast_options,
        )?)),
        ColumnarValue::Scalar(scalar) => {
            // Convert the scalar to a single-element array, cast it, and extract the
            // result back out as a scalar.
            let array = scalar.to_array()?;
            let scalar =
                ScalarValue::try_from_array(&cast_array(array, data_type, cast_options)?, 0)?;
            Ok(ColumnarValue::Scalar(scalar))
        }
    }
}
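
// Illustrative sketch (not part of the original file): casting a string array to Int32
// with Spark LEGACY semantics via `spark_cast`. In LEGACY mode, values that fail to
// parse become null instead of raising an error. The module name is an assumption
// added for this example.
#[cfg(test)]
mod spark_cast_example {
    use super::*;

    #[test]
    fn cast_string_array_to_int32() -> DataFusionResult<()> {
        let options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
        let input: ArrayRef = Arc::new(StringArray::from(vec![Some("123"), Some("abc"), None]));
        let result = spark_cast(ColumnarValue::Array(input), &DataType::Int32, &options)?;
        let ColumnarValue::Array(array) = result else {
            unreachable!("array input produces array output")
        };
        let ints = array.as_primitive::<Int32Type>();
        assert_eq!(ints.value(0), 123);
        assert!(ints.is_null(1)); // unparseable input becomes null in LEGACY mode
        assert!(ints.is_null(2)); // null input stays null
        Ok(())
    }
}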

fn cast_array(
    array: ArrayRef,
    to_type: &DataType,
    cast_options: &SparkCastOptions,
) -> DataFusionResult<ArrayRef> {
    use DataType::*;
    let array = array_with_timezone(array, cast_options.timezone.clone(), Some(to_type))?;
    let from_type = array.data_type().clone();

    let native_cast_options: CastOptions = CastOptions {
        // In ANSI mode the cast must error out on invalid input rather than return null.
        safe: !matches!(cast_options.eval_mode, EvalMode::Ansi),
        format_options: FormatOptions::new()
            .with_timestamp_tz_format(TIMESTAMP_FORMAT)
            .with_timestamp_format(TIMESTAMP_FORMAT),
    };

    let array = match &from_type {
        Dictionary(key_type, value_type)
            if key_type.as_ref() == &Int32
                && (value_type.as_ref() == &Utf8 || value_type.as_ref() == &LargeUtf8) =>
        {
            let dict_array = array
                .as_any()
                .downcast_ref::<DictionaryArray<Int32Type>>()
                .expect("Expected a dictionary array");

            let casted_dictionary = DictionaryArray::<Int32Type>::new(
                dict_array.keys().clone(),
                cast_array(Arc::clone(dict_array.values()), to_type, cast_options)?,
            );

            let casted_result = match to_type {
                Dictionary(_, _) => Arc::new(casted_dictionary.clone()),
                _ => take(casted_dictionary.values().as_ref(), dict_array.keys(), None)?,
            };
            return Ok(spark_cast_postprocess(casted_result, &from_type, to_type));
        }
        _ => array,
    };
    let from_type = array.data_type();
    let eval_mode = cast_options.eval_mode;

    let cast_result = match (from_type, to_type) {
        (Utf8, Boolean) => spark_cast_utf8_to_boolean::<i32>(&array, eval_mode),
        (LargeUtf8, Boolean) => spark_cast_utf8_to_boolean::<i64>(&array, eval_mode),
        (Utf8, Timestamp(_, _)) => {
            cast_string_to_timestamp(&array, to_type, eval_mode, &cast_options.timezone)
        }
        (Utf8, Date32) => cast_string_to_date(&array, to_type, eval_mode),
        (Int64, Int32)
        | (Int64, Int16)
        | (Int64, Int8)
        | (Int32, Int16)
        | (Int32, Int8)
        | (Int16, Int8)
            if eval_mode != EvalMode::Try =>
        {
            spark_cast_int_to_int(&array, eval_mode, from_type, to_type)
        }
        (Utf8, Int8 | Int16 | Int32 | Int64) => {
            cast_string_to_int::<i32>(to_type, &array, eval_mode)
        }
        (LargeUtf8, Int8 | Int16 | Int32 | Int64) => {
            cast_string_to_int::<i64>(to_type, &array, eval_mode)
        }
        (Float64, Utf8) => spark_cast_float64_to_utf8::<i32>(&array, eval_mode),
        (Float64, LargeUtf8) => spark_cast_float64_to_utf8::<i64>(&array, eval_mode),
        (Float32, Utf8) => spark_cast_float32_to_utf8::<i32>(&array, eval_mode),
        (Float32, LargeUtf8) => spark_cast_float32_to_utf8::<i64>(&array, eval_mode),
        (Float32, Decimal128(precision, scale)) => {
            cast_float32_to_decimal128(&array, *precision, *scale, eval_mode)
        }
        (Float64, Decimal128(precision, scale)) => {
            cast_float64_to_decimal128(&array, *precision, *scale, eval_mode)
        }
        (Float32, Int8)
        | (Float32, Int16)
        | (Float32, Int32)
        | (Float32, Int64)
        | (Float64, Int8)
        | (Float64, Int16)
        | (Float64, Int32)
        | (Float64, Int64)
        | (Decimal128(_, _), Int8)
        | (Decimal128(_, _), Int16)
        | (Decimal128(_, _), Int32)
        | (Decimal128(_, _), Int64)
            if eval_mode != EvalMode::Try =>
        {
            spark_cast_nonintegral_numeric_to_integral(&array, eval_mode, from_type, to_type)
        }
        (Struct(_), Utf8) => Ok(casts_struct_to_string(array.as_struct(), cast_options)?),
        (Struct(_), Struct(_)) => Ok(cast_struct_to_struct(
            array.as_struct(),
            from_type,
            to_type,
            cast_options,
        )?),
        (UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
            if cast_options.allow_cast_unsigned_ints =>
        {
            Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
        }
        _ if cast_options.is_adapting_schema
            || is_datafusion_spark_compatible(from_type, to_type, cast_options.allow_incompat) =>
        {
            // Delegate to DataFusion's cast where its behavior matches Spark's.
            Ok(cast_with_options(&array, to_type, &native_cast_options)?)
        }
        _ => {
            // Callers are expected to check `cast_supported` before invoking the
            // native cast, so this is unreachable in practice.
            Err(SparkError::Internal(format!(
                "Native cast invoked for unsupported cast from {from_type:?} to {to_type:?}"
            )))
        }
    };
    Ok(spark_cast_postprocess(cast_result?, from_type, to_type))
}

/// Returns true if DataFusion's built-in cast from `from_type` to `to_type` behaves
/// consistently with Spark, so that it can be used directly.
fn is_datafusion_spark_compatible(
    from_type: &DataType,
    to_type: &DataType,
    allow_incompat: bool,
) -> bool {
    if from_type == to_type {
        return true;
    }
    match from_type {
        DataType::Null => {
            matches!(to_type, DataType::List(_))
        }
        DataType::Boolean => matches!(
            to_type,
            DataType::Int8
                | DataType::Int16
                | DataType::Int32
                | DataType::Int64
                | DataType::Float32
                | DataType::Float64
                | DataType::Utf8
        ),
        DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
            matches!(
                to_type,
                DataType::Boolean
                    | DataType::Int8
                    | DataType::Int16
                    | DataType::Int32
                    | DataType::Int64
                    | DataType::Float32
                    | DataType::Float64
                    | DataType::Decimal128(_, _)
                    | DataType::Utf8
            )
        }
        DataType::Float32 | DataType::Float64 => matches!(
            to_type,
            DataType::Boolean
                | DataType::Int8
                | DataType::Int16
                | DataType::Int32
                | DataType::Int64
                | DataType::Float32
                | DataType::Float64
        ),
        DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => matches!(
            to_type,
            DataType::Int8
                | DataType::Int16
                | DataType::Int32
                | DataType::Int64
                | DataType::Float32
                | DataType::Float64
                | DataType::Decimal128(_, _)
                | DataType::Decimal256(_, _)
                | DataType::Utf8
        ),
        DataType::Utf8 if allow_incompat => matches!(
            to_type,
            DataType::Binary | DataType::Float32 | DataType::Float64 | DataType::Decimal128(_, _)
        ),
        DataType::Utf8 => matches!(to_type, DataType::Binary),
        DataType::Date32 => matches!(to_type, DataType::Utf8),
        DataType::Timestamp(_, _) => {
            matches!(
                to_type,
                DataType::Int64 | DataType::Date32 | DataType::Utf8 | DataType::Timestamp(_, _)
            )
        }
        DataType::Binary => {
            matches!(to_type, DataType::Utf8)
        }
        _ => false,
    }
}

/// Cast between struct types, matching fields in the target struct to fields in the
/// source struct by name.
fn cast_struct_to_struct(
    array: &StructArray,
    from_type: &DataType,
    to_type: &DataType,
    cast_options: &SparkCastOptions,
) -> DataFusionResult<ArrayRef> {
    match (from_type, to_type) {
        (DataType::Struct(from_fields), DataType::Struct(to_fields)) => {
            let mut field_name_to_index_map = HashMap::new();
            for (i, field) in from_fields.iter().enumerate() {
                field_name_to_index_map.insert(field.name(), i);
            }
            // Field names must be unique for the name-based lookup to be sound.
            assert_eq!(field_name_to_index_map.len(), from_fields.len());
            let mut cast_fields: Vec<ArrayRef> = Vec::with_capacity(to_fields.len());
            for i in 0..to_fields.len() {
                let from_index = field_name_to_index_map[to_fields[i].name()];
                let cast_field = cast_array(
                    Arc::clone(array.column(from_index)),
                    to_fields[i].data_type(),
                    cast_options,
                )?;
                cast_fields.push(cast_field);
            }
            Ok(Arc::new(StructArray::new(
                to_fields.clone(),
                cast_fields,
                array.nulls().cloned(),
            )))
        }
        _ => unreachable!(),
    }
}

/// Casts a struct to its Spark string representation, e.g. `{1, foo}`, by first
/// casting each field to a string.
fn casts_struct_to_string(
    array: &StructArray,
    spark_cast_options: &SparkCastOptions,
) -> DataFusionResult<ArrayRef> {
    // Cast each field to a string first.
    let string_arrays: Vec<ArrayRef> = array
        .columns()
        .iter()
        .map(|arr| {
            spark_cast(
                ColumnarValue::Array(Arc::clone(arr)),
                &DataType::Utf8,
                spark_cast_options,
            )
            .and_then(|cv| cv.into_array(arr.len()))
        })
        .collect::<DataFusionResult<Vec<_>>>()?;
    let string_arrays: Vec<&StringArray> =
        string_arrays.iter().map(|arr| arr.as_string()).collect();
    // Build the output strings row by row, reusing a scratch buffer.
    let mut builder = StringBuilder::with_capacity(array.len(), array.len() * 16);
    let mut str = String::with_capacity(array.len() * 16);
    for row_index in 0..array.len() {
        if array.is_null(row_index) {
            builder.append_null();
        } else {
            str.clear();
            let mut any_fields_written = false;
            str.push('{');
            for field in &string_arrays {
                if any_fields_written {
                    str.push_str(", ");
                }
                if field.is_null(row_index) {
                    str.push_str("null");
                } else {
                    str.push_str(field.value(row_index));
                }
                any_fields_written = true;
            }
            str.push('}');
            builder.append_value(&str);
        }
    }
    Ok(Arc::new(builder.finish()))
}

fn cast_string_to_int<OffsetSize: OffsetSizeTrait>(
    to_type: &DataType,
    array: &ArrayRef,
    eval_mode: EvalMode,
) -> SparkResult<ArrayRef> {
    let string_array = array
        .as_any()
        .downcast_ref::<GenericStringArray<OffsetSize>>()
        .expect("cast_string_to_int expected a string array");

    let cast_array: ArrayRef = match to_type {
        DataType::Int8 => cast_utf8_to_int!(string_array, eval_mode, Int8Type, cast_string_to_i8)?,
        DataType::Int16 => {
            cast_utf8_to_int!(string_array, eval_mode, Int16Type, cast_string_to_i16)?
        }
        DataType::Int32 => {
            cast_utf8_to_int!(string_array, eval_mode, Int32Type, cast_string_to_i32)?
        }
        DataType::Int64 => {
            cast_utf8_to_int!(string_array, eval_mode, Int64Type, cast_string_to_i64)?
        }
        dt => unreachable!(
            "{}",
            format!("invalid integer type {dt} in cast from string")
        ),
    };
    Ok(cast_array)
}

fn cast_string_to_date(
    array: &ArrayRef,
    to_type: &DataType,
    eval_mode: EvalMode,
) -> SparkResult<ArrayRef> {
    let string_array = array
        .as_any()
        .downcast_ref::<GenericStringArray<i32>>()
        .expect("Expected a string array");

    if to_type != &DataType::Date32 {
        unreachable!("Invalid data type {:?} in cast from string", to_type);
    }

    let len = string_array.len();
    let mut cast_array = PrimitiveArray::<Date32Type>::builder(len);

    for i in 0..len {
        let value = if string_array.is_null(i) {
            None
        } else {
            match date_parser(string_array.value(i), eval_mode) {
                Ok(Some(cast_value)) => Some(cast_value),
                Ok(None) => None,
                Err(e) => return Err(e),
            }
        };

        match value {
            Some(cast_value) => cast_array.append_value(cast_value),
            None => cast_array.append_null(),
        }
    }

    Ok(Arc::new(cast_array.finish()) as ArrayRef)
}

fn cast_string_to_timestamp(
    array: &ArrayRef,
    to_type: &DataType,
    eval_mode: EvalMode,
    timezone_str: &str,
) -> SparkResult<ArrayRef> {
    let string_array = array
        .as_any()
        .downcast_ref::<GenericStringArray<i32>>()
        .expect("Expected a string array");

    let tz = &timezone::Tz::from_str(timezone_str).unwrap();

    let cast_array: ArrayRef = match to_type {
        DataType::Timestamp(_, _) => {
            cast_utf8_to_timestamp!(
                string_array,
                eval_mode,
                TimestampMicrosecondType,
                timestamp_parser,
                tz
            )
        }
        _ => unreachable!("Invalid data type {:?} in cast from string", to_type),
    };
    Ok(cast_array)
}

fn cast_float64_to_decimal128(
    array: &dyn Array,
    precision: u8,
    scale: i8,
    eval_mode: EvalMode,
) -> SparkResult<ArrayRef> {
    cast_floating_point_to_decimal128::<Float64Type>(array, precision, scale, eval_mode)
}

fn cast_float32_to_decimal128(
    array: &dyn Array,
    precision: u8,
    scale: i8,
    eval_mode: EvalMode,
) -> SparkResult<ArrayRef> {
    cast_floating_point_to_decimal128::<Float32Type>(array, precision, scale, eval_mode)
}

fn cast_floating_point_to_decimal128<T: ArrowPrimitiveType>(
    array: &dyn Array,
    precision: u8,
    scale: i8,
    eval_mode: EvalMode,
) -> SparkResult<ArrayRef>
where
    <T as ArrowPrimitiveType>::Native: AsPrimitive<f64>,
{
    let input = array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
    let mut cast_array = PrimitiveArray::<Decimal128Type>::builder(input.len());

    let mul = 10_f64.powi(scale as i32);

    for i in 0..input.len() {
        if input.is_null(i) {
            cast_array.append_null();
        } else {
            let input_value = input.value(i).as_();
            let value = (input_value * mul).round().to_i128();

            match value {
                Some(v) => {
                    if Decimal128Type::validate_decimal_precision(v, precision).is_err() {
                        if eval_mode == EvalMode::Ansi {
                            return Err(SparkError::NumericValueOutOfRange {
                                value: input_value.to_string(),
                                precision,
                                scale,
                            });
                        }
                        // Out of range: produce null in non-ANSI modes. (The original
                        // code fell through and appended the value as well, producing
                        // two entries for this row.)
                        cast_array.append_null();
                    } else {
                        cast_array.append_value(v);
                    }
                }
                None => {
                    if eval_mode == EvalMode::Ansi {
                        return Err(SparkError::NumericValueOutOfRange {
                            value: input_value.to_string(),
                            precision,
                            scale,
                        });
                    } else {
                        cast_array.append_null();
                    }
                }
            }
        }
    }

    let res = Arc::new(
        cast_array
            .with_precision_and_scale(precision, scale)?
            .finish(),
    ) as ArrayRef;
    Ok(res)
}

fn spark_cast_float64_to_utf8<OffsetSize>(
    from: &dyn Array,
    _eval_mode: EvalMode,
) -> SparkResult<ArrayRef>
where
    OffsetSize: OffsetSizeTrait,
{
    cast_float_to_string!(from, _eval_mode, f64, Float64Array, OffsetSize)
}

fn spark_cast_float32_to_utf8<OffsetSize>(
    from: &dyn Array,
    _eval_mode: EvalMode,
) -> SparkResult<ArrayRef>
where
    OffsetSize: OffsetSizeTrait,
{
    cast_float_to_string!(from, _eval_mode, f32, Float32Array, OffsetSize)
}

fn spark_cast_int_to_int(
    array: &dyn Array,
    eval_mode: EvalMode,
    from_type: &DataType,
    to_type: &DataType,
) -> SparkResult<ArrayRef> {
    match (from_type, to_type) {
        (DataType::Int64, DataType::Int32) => cast_int_to_int_macro!(
            array, eval_mode, Int64Type, Int32Type, from_type, i32, "BIGINT", "INT"
        ),
        (DataType::Int64, DataType::Int16) => cast_int_to_int_macro!(
            array, eval_mode, Int64Type, Int16Type, from_type, i16, "BIGINT", "SMALLINT"
        ),
        (DataType::Int64, DataType::Int8) => cast_int_to_int_macro!(
            array, eval_mode, Int64Type, Int8Type, from_type, i8, "BIGINT", "TINYINT"
        ),
        (DataType::Int32, DataType::Int16) => cast_int_to_int_macro!(
            array, eval_mode, Int32Type, Int16Type, from_type, i16, "INT", "SMALLINT"
        ),
        (DataType::Int32, DataType::Int8) => cast_int_to_int_macro!(
            array, eval_mode, Int32Type, Int8Type, from_type, i8, "INT", "TINYINT"
        ),
        (DataType::Int16, DataType::Int8) => cast_int_to_int_macro!(
            array, eval_mode, Int16Type, Int8Type, from_type, i8, "SMALLINT", "TINYINT"
        ),
        _ => unreachable!(
            "{}",
            format!("invalid integer type {to_type} in cast from {from_type}")
        ),
    }
}

fn spark_cast_utf8_to_boolean<OffsetSize>(
    from: &dyn Array,
    eval_mode: EvalMode,
) -> SparkResult<ArrayRef>
where
    OffsetSize: OffsetSizeTrait,
{
    let array = from
        .as_any()
        .downcast_ref::<GenericStringArray<OffsetSize>>()
        .unwrap();

    let output_array = array
        .iter()
        .map(|value| match value {
            Some(value) => match value.to_ascii_lowercase().trim() {
                "t" | "true" | "y" | "yes" | "1" => Ok(Some(true)),
                "f" | "false" | "n" | "no" | "0" => Ok(Some(false)),
                _ if eval_mode == EvalMode::Ansi => Err(SparkError::CastInvalidValue {
                    value: value.to_string(),
                    from_type: "STRING".to_string(),
                    to_type: "BOOLEAN".to_string(),
                }),
                _ => Ok(None),
            },
            _ => Ok(None),
        })
        .collect::<Result<BooleanArray, _>>()?;

    Ok(Arc::new(output_array))
}
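
// Illustrative sketch (not part of the original file): the Spark-style boolean
// literals accepted by `spark_cast_utf8_to_boolean`, and the null-on-failure behavior
// of LEGACY mode. The module name is an assumption added for this example.
#[cfg(test)]
mod utf8_to_boolean_example {
    use super::*;

    #[test]
    fn parses_spark_boolean_literals() {
        let input: ArrayRef = Arc::new(StringArray::from(vec!["t", "YES", "0", "maybe"]));
        let result = spark_cast_utf8_to_boolean::<i32>(&input, EvalMode::Legacy).unwrap();
        let booleans = result.as_boolean();
        assert!(booleans.value(0)); // "t" => true
        assert!(booleans.value(1)); // matching is case-insensitive
        assert!(!booleans.value(2)); // "0" => false
        assert!(booleans.is_null(3)); // unrecognized => null in LEGACY mode
    }
}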

fn spark_cast_nonintegral_numeric_to_integral(
    array: &dyn Array,
    eval_mode: EvalMode,
    from_type: &DataType,
    to_type: &DataType,
) -> SparkResult<ArrayRef> {
    match (from_type, to_type) {
        (DataType::Float32, DataType::Int8) => cast_float_to_int16_down!(
            array,
            eval_mode,
            Float32Array,
            Int8Array,
            f32,
            i8,
            "FLOAT",
            "TINYINT",
            "{:e}"
        ),
        (DataType::Float32, DataType::Int16) => cast_float_to_int16_down!(
            array,
            eval_mode,
            Float32Array,
            Int16Array,
            f32,
            i16,
            "FLOAT",
            "SMALLINT",
            "{:e}"
        ),
        (DataType::Float32, DataType::Int32) => cast_float_to_int32_up!(
            array,
            eval_mode,
            Float32Array,
            Int32Array,
            f32,
            i32,
            "FLOAT",
            "INT",
            i32::MAX,
            "{:e}"
        ),
        (DataType::Float32, DataType::Int64) => cast_float_to_int32_up!(
            array,
            eval_mode,
            Float32Array,
            Int64Array,
            f32,
            i64,
            "FLOAT",
            "BIGINT",
            i64::MAX,
            "{:e}"
        ),
        (DataType::Float64, DataType::Int8) => cast_float_to_int16_down!(
            array,
            eval_mode,
            Float64Array,
            Int8Array,
            f64,
            i8,
            "DOUBLE",
            "TINYINT",
            "{:e}D"
        ),
        (DataType::Float64, DataType::Int16) => cast_float_to_int16_down!(
            array,
            eval_mode,
            Float64Array,
            Int16Array,
            f64,
            i16,
            "DOUBLE",
            "SMALLINT",
            "{:e}D"
        ),
        (DataType::Float64, DataType::Int32) => cast_float_to_int32_up!(
            array,
            eval_mode,
            Float64Array,
            Int32Array,
            f64,
            i32,
            "DOUBLE",
            "INT",
            i32::MAX,
            "{:e}D"
        ),
        (DataType::Float64, DataType::Int64) => cast_float_to_int32_up!(
            array,
            eval_mode,
            Float64Array,
            Int64Array,
            f64,
            i64,
            "DOUBLE",
            "BIGINT",
            i64::MAX,
            "{:e}D"
        ),
        (DataType::Decimal128(precision, scale), DataType::Int8) => {
            cast_decimal_to_int16_down!(
                array, eval_mode, Int8Array, i8, "TINYINT", precision, *scale
            )
        }
        (DataType::Decimal128(precision, scale), DataType::Int16) => {
            cast_decimal_to_int16_down!(
                array, eval_mode, Int16Array, i16, "SMALLINT", precision, *scale
            )
        }
        (DataType::Decimal128(precision, scale), DataType::Int32) => {
            cast_decimal_to_int32_up!(
                array,
                eval_mode,
                Int32Array,
                i32,
                "INT",
                i32::MAX,
                *precision,
                *scale
            )
        }
        (DataType::Decimal128(precision, scale), DataType::Int64) => {
            cast_decimal_to_int32_up!(
                array,
                eval_mode,
                Int64Array,
                i64,
                "BIGINT",
                i64::MAX,
                *precision,
                *scale
            )
        }
        _ => unreachable!(
            "{}",
            format!("invalid cast from non-integral numeric type: {from_type} to integral numeric type: {to_type}")
        ),
    }
}
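
// Illustrative sketch (not part of the original file): decimal-to-integer casts
// truncate toward zero, e.g. 12.99 becomes 12. The module name is an assumption added
// for this example.
#[cfg(test)]
mod decimal_to_int_example {
    use super::*;

    #[test]
    fn truncates_decimal_to_int() {
        // 12.99 and -12.99 as DECIMAL(10, 2).
        let decimals = Decimal128Array::from(vec![Some(1299_i128), Some(-1299_i128), None])
            .with_precision_and_scale(10, 2)
            .unwrap();
        let result = spark_cast_nonintegral_numeric_to_integral(
            &decimals,
            EvalMode::Legacy,
            &DataType::Decimal128(10, 2),
            &DataType::Int32,
        )
        .unwrap();
        let ints = result.as_primitive::<Int32Type>();
        assert_eq!(ints.value(0), 12);
        assert_eq!(ints.value(1), -12);
        assert!(ints.is_null(2));
    }
}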

fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i8>> {
    Ok(cast_string_to_int_with_range_check(
        str,
        eval_mode,
        "TINYINT",
        i8::MIN as i32,
        i8::MAX as i32,
    )?
    .map(|v| v as i8))
}

fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i16>> {
    Ok(cast_string_to_int_with_range_check(
        str,
        eval_mode,
        "SMALLINT",
        i16::MIN as i32,
        i16::MAX as i32,
    )?
    .map(|v| v as i16))
}

fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>> {
    do_cast_string_to_int::<i32>(str, eval_mode, "INT", i32::MIN)
}

fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i64>> {
    do_cast_string_to_int::<i64>(str, eval_mode, "BIGINT", i64::MIN)
}

fn cast_string_to_int_with_range_check(
    str: &str,
    eval_mode: EvalMode,
    type_name: &str,
    min: i32,
    max: i32,
) -> SparkResult<Option<i32>> {
    match do_cast_string_to_int(str, eval_mode, type_name, i32::MIN)? {
        None => Ok(None),
        Some(v) if v >= min && v <= max => Ok(Some(v)),
        _ if eval_mode == EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)),
        _ => Ok(None),
    }
}

/// Parses a string to an integer with Spark-compatible semantics. The value is
/// accumulated as a negative number so that the target type's minimum value (whose
/// magnitude exceeds that of the maximum value) can be parsed without overflow.
fn do_cast_string_to_int<
    T: Num + PartialOrd + Integer + CheckedSub + CheckedNeg + From<i32> + Copy,
>(
    str: &str,
    eval_mode: EvalMode,
    type_name: &str,
    min_value: T,
) -> SparkResult<Option<T>> {
    let trimmed_str = str.trim();
    if trimmed_str.is_empty() {
        return none_or_err(eval_mode, type_name, str);
    }
    let len = trimmed_str.len();
    let mut result: T = T::zero();
    let mut negative = false;
    let radix = T::from(10);
    let stop_value = min_value / radix;
    let mut parse_sign_and_digits = true;

    for (i, ch) in trimmed_str.char_indices() {
        if parse_sign_and_digits {
            if i == 0 {
                negative = ch == '-';
                let positive = ch == '+';
                if negative || positive {
                    if i + 1 == len {
                        // The input is just "+" or "-".
                        return none_or_err(eval_mode, type_name, str);
                    }
                    // Consume the sign character.
                    continue;
                }
            }

            if ch == '.' {
                if eval_mode == EvalMode::Legacy {
                    // Truncate the fractional part in legacy mode.
                    parse_sign_and_digits = false;
                    continue;
                } else {
                    return none_or_err(eval_mode, type_name, str);
                }
            }

            let digit = if ch.is_ascii_digit() {
                (ch as u32) - ('0' as u32)
            } else {
                return none_or_err(eval_mode, type_name, str);
            };

            // If the accumulated value is already below `min_value / radix`, then
            // multiplying it by the radix would overflow.
            if result < stop_value {
                return none_or_err(eval_mode, type_name, str);
            }

            // Accumulate in the negative domain: result = result * 10 - digit. A
            // failed subtraction or a positive intermediate indicates overflow.
            let v = result * radix;
            let digit = (digit as i32).into();
            match v.checked_sub(&digit) {
                Some(x) if x <= T::zero() => result = x,
                _ => {
                    return none_or_err(eval_mode, type_name, str);
                }
            }
        } else if !ch.is_ascii_digit() {
            // Fractional digits are ignored, but must still be valid digits.
            return none_or_err(eval_mode, type_name, str);
        }
    }

    if !negative {
        // For positive inputs, negate the (negative) accumulator. Negation fails only
        // for `min_value`, which has no positive counterpart.
        if let Some(neg) = result.checked_neg() {
            if neg < T::zero() {
                return none_or_err(eval_mode, type_name, str);
            }
            result = neg;
        } else {
            return none_or_err(eval_mode, type_name, str);
        }
    }

    Ok(Some(result))
}
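
// Illustrative sketch (not part of the original file): a few behaviors of the
// string-to-integer parser above. In LEGACY mode a fractional suffix is truncated,
// while ANSI mode rejects it; the accumulate-as-negative trick lets i32::MIN parse.
// The module name is an assumption added for this example.
#[cfg(test)]
mod string_to_int_example {
    use super::*;

    #[test]
    fn parse_examples() {
        assert_eq!(cast_string_to_i32("42", EvalMode::Legacy).unwrap(), Some(42));
        assert_eq!(
            cast_string_to_i32("-2147483648", EvalMode::Legacy).unwrap(),
            Some(i32::MIN)
        );
        // LEGACY truncates the fractional part; ANSI raises an error instead.
        assert_eq!(cast_string_to_i32("3.7", EvalMode::Legacy).unwrap(), Some(3));
        assert!(cast_string_to_i32("3.7", EvalMode::Ansi).is_err());
        // Overflow yields None in LEGACY mode.
        assert_eq!(cast_string_to_i32("2147483648", EvalMode::Legacy).unwrap(), None);
    }
}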

#[inline]
fn none_or_err<T>(eval_mode: EvalMode, type_name: &str, str: &str) -> SparkResult<Option<T>> {
    match eval_mode {
        EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)),
        _ => Ok(None),
    }
}

#[inline]
fn invalid_value(value: &str, from_type: &str, to_type: &str) -> SparkError {
    SparkError::CastInvalidValue {
        value: value.to_string(),
        from_type: from_type.to_string(),
        to_type: to_type.to_string(),
    }
}

#[inline]
fn cast_overflow(value: &str, from_type: &str, to_type: &str) -> SparkError {
    SparkError::CastOverFlow {
        value: value.to_string(),
        from_type: from_type.to_string(),
        to_type: to_type.to_string(),
    }
}

impl Display for Cast {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Cast [data_type: {}, timezone: {}, child: {}, eval_mode: {:?}]",
            self.data_type, self.cast_options.timezone, self.child, &self.cast_options.eval_mode
        )
    }
}

impl PhysicalExpr for Cast {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn data_type(&self, _: &Schema) -> DataFusionResult<DataType> {
        Ok(self.data_type.clone())
    }

    fn nullable(&self, _: &Schema) -> DataFusionResult<bool> {
        Ok(true)
    }

    fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
        let arg = self.child.evaluate(batch)?;
        spark_cast(arg, &self.data_type, &self.cast_options)
    }

    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        vec![&self.child]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
        match children.len() {
            1 => Ok(Arc::new(Cast::new(
                Arc::clone(&children[0]),
                self.data_type.clone(),
                self.cast_options.clone(),
            ))),
            _ => internal_err!("Cast should have exactly one child"),
        }
    }
}

fn timestamp_parser<T: TimeZone>(
    value: &str,
    eval_mode: EvalMode,
    tz: &T,
) -> SparkResult<Option<i64>> {
    let value = value.trim();
    if value.is_empty() {
        return Ok(None);
    }
    let patterns = &[
        (
            Regex::new(r"^\d{4,5}$").unwrap(),
            parse_str_to_year_timestamp as fn(&str, &T) -> SparkResult<Option<i64>>,
        ),
        (
            Regex::new(r"^\d{4,5}-\d{2}$").unwrap(),
            parse_str_to_month_timestamp,
        ),
        (
            Regex::new(r"^\d{4,5}-\d{2}-\d{2}$").unwrap(),
            parse_str_to_day_timestamp,
        ),
        (
            Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{1,2}$").unwrap(),
            parse_str_to_hour_timestamp,
        ),
        (
            Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}$").unwrap(),
            parse_str_to_minute_timestamp,
        ),
        (
            Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$").unwrap(),
            parse_str_to_second_timestamp,
        ),
        (
            Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,6}$").unwrap(),
            parse_str_to_microsecond_timestamp,
        ),
        (
            Regex::new(r"^T\d{1,2}$").unwrap(),
            parse_str_to_time_only_timestamp,
        ),
    ];

    let mut timestamp = None;

    for (pattern, parse_func) in patterns {
        if pattern.is_match(value) {
            timestamp = parse_func(value, tz)?;
            break;
        }
    }

    if timestamp.is_none() {
        return if eval_mode == EvalMode::Ansi {
            Err(SparkError::CastInvalidValue {
                value: value.to_string(),
                from_type: "STRING".to_string(),
                to_type: "TIMESTAMP".to_string(),
            })
        } else {
            Ok(None)
        };
    }

    match timestamp {
        Some(ts) => Ok(Some(ts)),
        None => Err(SparkError::Internal(
            "Failed to parse timestamp".to_string(),
        )),
    }
}

fn parse_timestamp_to_micros<T: TimeZone>(
    timestamp_info: &TimeStampInfo,
    tz: &T,
) -> SparkResult<Option<i64>> {
    let datetime = tz.with_ymd_and_hms(
        timestamp_info.year,
        timestamp_info.month,
        timestamp_info.day,
        timestamp_info.hour,
        timestamp_info.minute,
        timestamp_info.second,
    );

    let tz_datetime = match datetime.single() {
        Some(dt) => dt
            .with_timezone(tz)
            .with_nanosecond(timestamp_info.microsecond * 1000),
        None => {
            return Err(SparkError::Internal(
                "Failed to parse timestamp".to_string(),
            ));
        }
    };

    let result = match tz_datetime {
        Some(dt) => dt.timestamp_micros(),
        None => {
            return Err(SparkError::Internal(
                "Failed to parse timestamp".to_string(),
            ));
        }
    };

    Ok(Some(result))
}

fn get_timestamp_values<T: TimeZone>(
    value: &str,
    timestamp_type: &str,
    tz: &T,
) -> SparkResult<Option<i64>> {
    let values: Vec<_> = value.split(['T', '-', ':', '.']).collect();
    let year = values[0].parse::<i32>().unwrap_or_default();
    let month = values.get(1).map_or(1, |m| m.parse::<u32>().unwrap_or(1));
    let day = values.get(2).map_or(1, |d| d.parse::<u32>().unwrap_or(1));
    let hour = values.get(3).map_or(0, |h| h.parse::<u32>().unwrap_or(0));
    let minute = values.get(4).map_or(0, |m| m.parse::<u32>().unwrap_or(0));
    let second = values.get(5).map_or(0, |s| s.parse::<u32>().unwrap_or(0));
    let microsecond = values.get(6).map_or(0, |ms| ms.parse::<u32>().unwrap_or(0));

    let mut timestamp_info = TimeStampInfo::default();

    let timestamp_info = match timestamp_type {
        "year" => timestamp_info.with_year(year),
        "month" => timestamp_info.with_year(year).with_month(month),
        "day" => timestamp_info
            .with_year(year)
            .with_month(month)
            .with_day(day),
        "hour" => timestamp_info
            .with_year(year)
            .with_month(month)
            .with_day(day)
            .with_hour(hour),
        "minute" => timestamp_info
            .with_year(year)
            .with_month(month)
            .with_day(day)
            .with_hour(hour)
            .with_minute(minute),
        "second" => timestamp_info
            .with_year(year)
            .with_month(month)
            .with_day(day)
            .with_hour(hour)
            .with_minute(minute)
            .with_second(second),
        "microsecond" => timestamp_info
            .with_year(year)
            .with_month(month)
            .with_day(day)
            .with_hour(hour)
            .with_minute(minute)
            .with_second(second)
            .with_microsecond(microsecond),
        _ => {
            return Err(SparkError::CastInvalidValue {
                value: value.to_string(),
                from_type: "STRING".to_string(),
                to_type: "TIMESTAMP".to_string(),
            })
        }
    };

    parse_timestamp_to_micros(timestamp_info, tz)
}

fn parse_str_to_year_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
    get_timestamp_values(value, "year", tz)
}

fn parse_str_to_month_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
    get_timestamp_values(value, "month", tz)
}

fn parse_str_to_day_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
    get_timestamp_values(value, "day", tz)
}

fn parse_str_to_hour_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
    get_timestamp_values(value, "hour", tz)
}

fn parse_str_to_minute_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
    get_timestamp_values(value, "minute", tz)
}

fn parse_str_to_second_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
    get_timestamp_values(value, "second", tz)
}

fn parse_str_to_microsecond_timestamp<T: TimeZone>(
    value: &str,
    tz: &T,
) -> SparkResult<Option<i64>> {
    get_timestamp_values(value, "microsecond", tz)
}

fn parse_str_to_time_only_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
    let values: Vec<&str> = value.split('T').collect();
    let time_values: Vec<u32> = values[1]
        .split(':')
        .map(|v| v.parse::<u32>().unwrap_or(0))
        .collect();

    let datetime = tz.from_utc_datetime(&chrono::Utc::now().naive_utc());
    let timestamp = datetime
        .with_timezone(tz)
        .with_hour(time_values.first().copied().unwrap_or_default())
        .and_then(|dt| dt.with_minute(*time_values.get(1).unwrap_or(&0)))
        .and_then(|dt| dt.with_second(*time_values.get(2).unwrap_or(&0)))
        .and_then(|dt| dt.with_nanosecond(*time_values.get(3).unwrap_or(&0) * 1_000))
        .map(|dt| dt.timestamp_micros())
        .unwrap_or_default();

    Ok(Some(timestamp))
}

/// Parses a date string into days since the Unix epoch. Accepted forms are
/// `[+-]yyyy*`, `[+-]yyyy*-[m]m`, and `[+-]yyyy*-[m]m-[d]d`, optionally followed by
/// `T` or a space and an arbitrary (ignored) suffix; missing month/day segments
/// default to 1.
fn date_parser(date_str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>> {
    fn get_trimmed_start(bytes: &[u8]) -> usize {
        let mut start = 0;
        while start < bytes.len() && is_whitespace_or_iso_control(bytes[start]) {
            start += 1;
        }
        start
    }

    fn get_trimmed_end(start: usize, bytes: &[u8]) -> usize {
        let mut end = bytes.len() - 1;
        while end > start && is_whitespace_or_iso_control(bytes[end]) {
            end -= 1;
        }
        end + 1
    }

    fn is_whitespace_or_iso_control(byte: u8) -> bool {
        byte.is_ascii_whitespace() || byte.is_ascii_control()
    }

    fn is_valid_digits(segment: i32, digits: usize) -> bool {
        // The year segment may have 4 to 7 digits; month and day take 1 or 2.
        let max_digits_year = 7;
        (segment == 0 && digits >= 4 && digits <= max_digits_year)
            || (segment != 0 && digits > 0 && digits <= 2)
    }

    fn return_result(date_str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>> {
        if eval_mode == EvalMode::Ansi {
            Err(SparkError::CastInvalidValue {
                value: date_str.to_string(),
                from_type: "STRING".to_string(),
                to_type: "DATE".to_string(),
            })
        } else {
            Ok(None)
        }
    }

    if date_str.is_empty() {
        return return_result(date_str, eval_mode);
    }

    // Year, month, and day segments, each defaulting to 1.
    let mut date_segments = [1, 1, 1];
    let mut sign = 1;
    let mut current_segment = 0;
    let mut current_segment_value = Wrapping(0);
    let mut current_segment_digits = 0;
    let bytes = date_str.as_bytes();

    let mut j = get_trimmed_start(bytes);
    let str_end_trimmed = get_trimmed_end(j, bytes);

    if j == str_end_trimmed {
        return return_result(date_str, eval_mode);
    }

    // An optional leading sign applies to the year.
    if bytes[j] == b'-' || bytes[j] == b'+' {
        sign = if bytes[j] == b'-' { -1 } else { 1 };
        j += 1;
    }

    while j < str_end_trimmed && (current_segment < 3 && !(bytes[j] == b' ' || bytes[j] == b'T')) {
        let b = bytes[j];
        if current_segment < 2 && b == b'-' {
            // A dash ends the current segment.
            if !is_valid_digits(current_segment, current_segment_digits) {
                return return_result(date_str, eval_mode);
            }
            date_segments[current_segment as usize] = current_segment_value.0;
            current_segment_value = Wrapping(0);
            current_segment_digits = 0;
            current_segment += 1;
        } else if !b.is_ascii_digit() {
            return return_result(date_str, eval_mode);
        } else {
            let parsed_value = Wrapping((b - b'0') as i32);
            current_segment_value = current_segment_value * Wrapping(10) + parsed_value;
            current_segment_digits += 1;
        }
        j += 1;
    }

    if !is_valid_digits(current_segment, current_segment_digits) {
        return return_result(date_str, eval_mode);
    }

    // A trailing 'T' or space is only allowed once all three segments are present.
    if current_segment < 2 && j < str_end_trimmed {
        return return_result(date_str, eval_mode);
    }

    date_segments[current_segment as usize] = current_segment_value.0;

    match NaiveDate::from_ymd_opt(
        sign * date_segments[0],
        date_segments[1] as u32,
        date_segments[2] as u32,
    ) {
        Some(date) => {
            let duration_since_epoch = date
                .signed_duration_since(NaiveDateTime::UNIX_EPOCH.date())
                .num_days();
            Ok(Some(duration_since_epoch.to_i32().unwrap()))
        }
        None => Ok(None),
    }
}
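
// Illustrative sketch (not part of the original file): `date_parser` returns days
// since the Unix epoch, trims whitespace, and defaults missing month/day segments to
// 1. The module name is an assumption added for this example.
#[cfg(test)]
mod date_parser_example {
    use super::*;

    #[test]
    fn parse_examples() {
        // 1970-01-01 is day zero of the epoch.
        assert_eq!(date_parser("1970-01-01", EvalMode::Legacy).unwrap(), Some(0));
        assert_eq!(date_parser(" 1970-01-02 ", EvalMode::Legacy).unwrap(), Some(1));
        // A year-only string is interpreted as January 1 of that year.
        assert_eq!(
            date_parser("1970", EvalMode::Legacy).unwrap(),
            date_parser("1970-01-01", EvalMode::Legacy).unwrap()
        );
        // Invalid input is None in LEGACY mode and an error in ANSI mode.
        assert_eq!(date_parser("not a date", EvalMode::Legacy).unwrap(), None);
        assert!(date_parser("not a date", EvalMode::Ansi).is_err());
    }
}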

/// Spark-specific post-processing of cast results:
/// - Timestamp to Int64 yields seconds rather than microseconds, so the value is
///   floor-divided by 1,000,000.
/// - Timestamp to string drops trailing zeros in the fractional seconds.
fn spark_cast_postprocess(array: ArrayRef, from_type: &DataType, to_type: &DataType) -> ArrayRef {
    match (from_type, to_type) {
        (DataType::Timestamp(_, _), DataType::Int64) => {
            unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, MICROS_PER_SECOND)).unwrap()
        }
        (DataType::Dictionary(_, value_type), DataType::Int64)
            if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) =>
        {
            unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, MICROS_PER_SECOND)).unwrap()
        }
        (DataType::Timestamp(_, _), DataType::Utf8) => remove_trailing_zeroes(array),
        (DataType::Dictionary(_, value_type), DataType::Utf8)
            if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) =>
        {
            remove_trailing_zeroes(array)
        }
        _ => array,
    }
}

/// Applies a unary operation to each element of a primitive array, recursing into the
/// values of dictionary-encoded arrays.
fn unary_dyn<F, T>(array: &ArrayRef, op: F) -> Result<ArrayRef, ArrowError>
where
    T: ArrowPrimitiveType,
    F: Fn(T::Native) -> T::Native,
{
    if let Some(d) = array.as_any_dictionary_opt() {
        let new_values = unary_dyn::<F, T>(d.values(), op)?;
        return Ok(Arc::new(d.with_values(Arc::new(new_values))));
    }

    match array.as_primitive_opt::<T>() {
        Some(a) if PrimitiveArray::<T>::is_compatible(a.data_type()) => {
            Ok(Arc::new(unary::<T, F, T>(
                array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
                op,
            )))
        }
        _ => Err(ArrowError::NotYetImplemented(format!(
            "Cannot perform unary operation of type {} on array of type {}",
            T::DATA_TYPE,
            array.data_type()
        ))),
    }
}

/// Removes trailing zeros from the fractional seconds of formatted timestamp strings.
fn remove_trailing_zeroes(array: ArrayRef) -> ArrayRef {
    let string_array = as_generic_string_array::<i32>(&array).unwrap();
    let result = string_array
        .iter()
        .map(|s| s.map(trim_end))
        .collect::<GenericStringArray<i32>>();
    Arc::new(result) as ArrayRef
}

fn trim_end(s: &str) -> &str {
    if s.rfind('.').is_some() {
        s.trim_end_matches('0')
    } else {
        s
    }
}
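
// Illustrative sketch (not part of the original file): `trim_end` only strips zeros
// when the string contains a decimal point, leaving integral strings untouched. The
// module name is an assumption added for this example.
#[cfg(test)]
mod trim_end_example {
    use super::*;

    #[test]
    fn trims_fractional_zeros_only() {
        assert_eq!(trim_end("2020-01-01 12:34:56.123000"), "2020-01-01 12:34:56.123");
        // No '.' present, so trailing zeros are significant and preserved.
        assert_eq!(trim_end("1200"), "1200");
    }
}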
2189
2190#[cfg(test)]
2191mod tests {
2192 use arrow::datatypes::TimestampMicrosecondType;
2193 use arrow_array::StringArray;
2194 use arrow_schema::{Field, Fields, TimeUnit};
2195 use std::str::FromStr;
2196
2197 use super::*;
2198
    #[test]
    #[cfg_attr(miri, ignore)]
    fn timestamp_parser_test() {
        let tz = &timezone::Tz::from_str("UTC").unwrap();
        assert_eq!(
            timestamp_parser("2020", EvalMode::Legacy, tz).unwrap(),
            Some(1577836800000000)
        );
        assert_eq!(
            timestamp_parser("2020-01", EvalMode::Legacy, tz).unwrap(),
            Some(1577836800000000)
        );
        assert_eq!(
            timestamp_parser("2020-01-01", EvalMode::Legacy, tz).unwrap(),
            Some(1577836800000000)
        );
        assert_eq!(
            timestamp_parser("2020-01-01T12", EvalMode::Legacy, tz).unwrap(),
            Some(1577880000000000)
        );
        assert_eq!(
            timestamp_parser("2020-01-01T12:34", EvalMode::Legacy, tz).unwrap(),
            Some(1577882040000000)
        );
        assert_eq!(
            timestamp_parser("2020-01-01T12:34:56", EvalMode::Legacy, tz).unwrap(),
            Some(1577882096000000)
        );
        assert_eq!(
            timestamp_parser("2020-01-01T12:34:56.123456", EvalMode::Legacy, tz).unwrap(),
            Some(1577882096123456)
        );
        assert_eq!(
            timestamp_parser("0100", EvalMode::Legacy, tz).unwrap(),
            Some(-59011459200000000)
        );
        assert_eq!(
            timestamp_parser("0100-01", EvalMode::Legacy, tz).unwrap(),
            Some(-59011459200000000)
        );
        assert_eq!(
            timestamp_parser("0100-01-01", EvalMode::Legacy, tz).unwrap(),
            Some(-59011459200000000)
        );
        assert_eq!(
            timestamp_parser("0100-01-01T12", EvalMode::Legacy, tz).unwrap(),
            Some(-59011416000000000)
        );
        assert_eq!(
            timestamp_parser("0100-01-01T12:34", EvalMode::Legacy, tz).unwrap(),
            Some(-59011413960000000)
        );
        assert_eq!(
            timestamp_parser("0100-01-01T12:34:56", EvalMode::Legacy, tz).unwrap(),
            Some(-59011413904000000)
        );
        assert_eq!(
            timestamp_parser("0100-01-01T12:34:56.123456", EvalMode::Legacy, tz).unwrap(),
            Some(-59011413903876544)
        );
        assert_eq!(
            timestamp_parser("10000", EvalMode::Legacy, tz).unwrap(),
            Some(253402300800000000)
        );
        assert_eq!(
            timestamp_parser("10000-01", EvalMode::Legacy, tz).unwrap(),
            Some(253402300800000000)
        );
        assert_eq!(
            timestamp_parser("10000-01-01", EvalMode::Legacy, tz).unwrap(),
            Some(253402300800000000)
        );
        assert_eq!(
            timestamp_parser("10000-01-01T12", EvalMode::Legacy, tz).unwrap(),
            Some(253402344000000000)
        );
        assert_eq!(
            timestamp_parser("10000-01-01T12:34", EvalMode::Legacy, tz).unwrap(),
            Some(253402346040000000)
        );
        assert_eq!(
            timestamp_parser("10000-01-01T12:34:56", EvalMode::Legacy, tz).unwrap(),
            Some(253402346096000000)
        );
        assert_eq!(
            timestamp_parser("10000-01-01T12:34:56.123456", EvalMode::Legacy, tz).unwrap(),
            Some(253402346096123456)
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_cast_string_to_timestamp() {
        let array: ArrayRef = Arc::new(StringArray::from(vec![
            Some("2020-01-01T12:34:56.123456"),
            Some("T2"),
            Some("0100-01-01T12:34:56.123456"),
            Some("10000-01-01T12:34:56.123456"),
        ]));
        let tz = &timezone::Tz::from_str("UTC").unwrap();

        let string_array = array
            .as_any()
            .downcast_ref::<GenericStringArray<i32>>()
            .expect("Expected a string array");

        let eval_mode = EvalMode::Legacy;
        let result = cast_utf8_to_timestamp!(
            &string_array,
            eval_mode,
            TimestampMicrosecondType,
            timestamp_parser,
            tz
        );

        assert_eq!(
            result.data_type(),
            &DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into()))
        );
        assert_eq!(result.len(), 4);
    }

    #[test]
    fn test_cast_dict_string_to_timestamp() -> DataFusionResult<()> {
        let keys = Int32Array::from(vec![0, 1]);
        let values: ArrayRef = Arc::new(StringArray::from(vec![
            Some("2020-01-01T12:34:56.123456"),
            Some("T2"),
        ]));
        let dict_array = Arc::new(DictionaryArray::new(keys, values));

        let timezone = "UTC".to_string();
        let cast_options = SparkCastOptions::new(EvalMode::Legacy, &timezone, false);
        let result = cast_array(
            dict_array,
            &DataType::Timestamp(TimeUnit::Microsecond, Some(timezone.clone().into())),
            &cast_options,
        )?;
        assert_eq!(
            *result.data_type(),
            DataType::Timestamp(TimeUnit::Microsecond, Some(timezone.into()))
        );
        assert_eq!(result.len(), 2);

        Ok(())
    }

    #[test]
    fn date_parser_test() {
        // all of these should parse to 2020-01-01 (18262 days after the Unix epoch)
        for date in &[
            "2020",
            "2020-01",
            "2020-01-01",
            "02020-01-01",
            "002020-01-01",
            "0002020-01-01",
            "2020-1-1",
            "2020-01-01 ",
            "2020-01-01T",
        ] {
            for eval_mode in &[EvalMode::Legacy, EvalMode::Ansi, EvalMode::Try] {
                assert_eq!(date_parser(date, *eval_mode).unwrap(), Some(18262));
            }
        }

        // malformed inputs parse to null in Legacy and Try and raise an error in Ansi
        for date in &[
            "abc",
            "",
            "not_a_date",
            "3/",
            "3/12",
            "3/12/2020",
            "3/12/2002 T",
            "202",
            "2020-010-01",
            "2020-10-010",
            "2020-10-010T",
            "--262143-12-31",
            "--262143-12-31 ",
        ] {
            for eval_mode in &[EvalMode::Legacy, EvalMode::Try] {
                assert_eq!(date_parser(date, *eval_mode).unwrap(), None);
            }
            assert!(date_parser(date, EvalMode::Ansi).is_err());
        }

        // negative years are supported
        for date in &["-3638-5"] {
            for eval_mode in &[EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] {
                assert_eq!(date_parser(date, *eval_mode).unwrap(), Some(-2048160));
            }
        }

        // these parse to null in every mode, with no error even in Ansi
        for date in &[
            "-262144-1-1",
            "262143-01-1",
            "262143-1-1",
            "262143-01-1 ",
            "262143-01-01T ",
            "262143-1-01T 1234",
            "-0973250",
        ] {
            for eval_mode in &[EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] {
                assert_eq!(date_parser(date, *eval_mode).unwrap(), None);
            }
        }
    }

    #[test]
    fn test_cast_string_to_date() {
        let array: ArrayRef = Arc::new(StringArray::from(vec![
            Some("2020"),
            Some("2020-01"),
            Some("2020-01-01"),
            Some("2020-01-01T"),
        ]));

        let result = cast_string_to_date(&array, &DataType::Date32, EvalMode::Legacy).unwrap();

        let date32_array = result
            .as_any()
            .downcast_ref::<arrow::array::Date32Array>()
            .unwrap();
        assert_eq!(date32_array.len(), 4);
        date32_array
            .iter()
            .for_each(|v| assert_eq!(v.unwrap(), 18262));
    }

    #[test]
    fn test_cast_string_array_with_valid_dates() {
        // all of these are valid dates surrounded by whitespace and/or
        // trailing characters that the parser tolerates
        let array_with_valid_dates: ArrayRef = Arc::new(StringArray::from(vec![
            Some("-262143-12-31"),
            Some("\n -262143-12-31 "),
            Some("-262143-12-31T \t\n"),
            Some("\n\t-262143-12-31T\r"),
            Some("-262143-12-31T 123123123"),
            Some("\r\n-262143-12-31T \r123123123"),
            Some("\n -262143-12-31T \n\t"),
        ]));

        for eval_mode in &[EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] {
            let result =
                cast_string_to_date(&array_with_valid_dates, &DataType::Date32, *eval_mode)
                    .unwrap();

            let date32_array = result
                .as_any()
                .downcast_ref::<arrow::array::Date32Array>()
                .unwrap();
            assert_eq!(result.len(), 7);
            date32_array
                .iter()
                .for_each(|v| assert_eq!(v.unwrap(), -96464928));
        }
    }

    #[test]
    fn test_cast_string_array_with_invalid_dates() {
        let array_with_invalid_date: ArrayRef = Arc::new(StringArray::from(vec![
            Some("2020"),
            Some("2020-01"),
            Some("2020-01-01"),
            Some("2020-010-01T"),
            Some("202"),
            Some(" 202 "),
            Some("\n 2020-\r8 "),
            Some("2020-01-01T"),
            Some("-4607172990231812908"),
        ]));

        for eval_mode in &[EvalMode::Legacy, EvalMode::Try] {
            let result =
                cast_string_to_date(&array_with_invalid_date, &DataType::Date32, *eval_mode)
                    .unwrap();

            let date32_array = result
                .as_any()
                .downcast_ref::<arrow::array::Date32Array>()
                .unwrap();
            assert_eq!(
                date32_array.iter().collect::<Vec<_>>(),
                vec![
                    Some(18262),
                    Some(18262),
                    Some(18262),
                    None,
                    None,
                    None,
                    None,
                    Some(18262),
                    None
                ]
            );
        }

        let result =
            cast_string_to_date(&array_with_invalid_date, &DataType::Date32, EvalMode::Ansi);
        match result {
            Err(e) => assert!(
                e.to_string().contains(
                    "[CAST_INVALID_INPUT] The value '2020-010-01T' of the type \"STRING\" cannot be cast to \"DATE\" because it is malformed")
            ),
            _ => panic!("Expected error"),
        }
    }

    #[test]
    fn test_cast_string_as_i8() {
        // basic values and overflow
        assert_eq!(
            cast_string_to_i8("127", EvalMode::Legacy).unwrap(),
            Some(127_i8)
        );
        assert_eq!(cast_string_to_i8("128", EvalMode::Legacy).unwrap(), None);
        assert!(cast_string_to_i8("128", EvalMode::Ansi).is_err());
        // decimal strings are truncated toward zero in LEGACY mode
        assert_eq!(
            cast_string_to_i8("0.2", EvalMode::Legacy).unwrap(),
            Some(0_i8)
        );
        assert_eq!(
            cast_string_to_i8(".", EvalMode::Legacy).unwrap(),
            Some(0_i8)
        );
        // TRY mode returns null for decimal strings
        assert_eq!(cast_string_to_i8("0.2", EvalMode::Try).unwrap(), None);
        assert_eq!(cast_string_to_i8(".", EvalMode::Try).unwrap(), None);
        // ANSI mode errors on decimal strings
        assert!(cast_string_to_i8("0.2", EvalMode::Ansi).is_err());
        assert!(cast_string_to_i8(".", EvalMode::Ansi).is_err());
    }
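
    // added sketch exercising `unary_dyn` directly with the same floor
    // division that `spark_cast_postprocess` applies for timestamp-to-long
    // casts; the input values are illustrative
    #[test]
    fn test_unary_dyn_floor_div() {
        let array: ArrayRef = Arc::new(Int64Array::from(vec![MICROS_PER_SECOND, -1]));
        let result =
            unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, MICROS_PER_SECOND)).unwrap();
        let result = result.as_primitive::<Int64Type>();
        assert_eq!(result.value(0), 1);
        // div_floor rounds toward negative infinity, so -1 microsecond maps
        // to -1 seconds rather than 0
        assert_eq!(result.value(1), -1);
    }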

    #[test]
    fn test_cast_unsupported_timestamp_to_date() {
        // `i64::MAX` microseconds is far outside the supported date range, so
        // this cast should return an error rather than panic
        let timestamps: PrimitiveArray<TimestampMicrosecondType> = vec![i64::MAX].into();
        let cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
        let result = cast_array(
            Arc::new(timestamps.with_timezone("Europe/Copenhagen")),
            &DataType::Date32,
            &cast_options,
        );
        assert!(result.is_err())
    }

    #[test]
    fn test_cast_invalid_timezone() {
        let timestamps: PrimitiveArray<TimestampMicrosecondType> = vec![i64::MAX].into();
        let cast_options = SparkCastOptions::new(EvalMode::Legacy, "Not a valid timezone", false);
        let result = cast_array(
            Arc::new(timestamps.with_timezone("Europe/Copenhagen")),
            &DataType::Date32,
            &cast_options,
        );
        assert!(result.is_err())
    }

    #[test]
    fn test_cast_struct_to_utf8() {
        let a: ArrayRef = Arc::new(Int32Array::from(vec![
            Some(1),
            Some(2),
            None,
            Some(4),
            Some(5),
        ]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
        let c: ArrayRef = Arc::new(StructArray::from(vec![
            (Arc::new(Field::new("a", DataType::Int32, true)), a),
            (Arc::new(Field::new("b", DataType::Utf8, true)), b),
        ]));
        let string_array = cast_array(
            c,
            &DataType::Utf8,
            &SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
        )
        .unwrap();
        let string_array = string_array.as_string::<i32>();
        assert_eq!(5, string_array.len());
        assert_eq!(r#"{1, a}"#, string_array.value(0));
        assert_eq!(r#"{2, b}"#, string_array.value(1));
        assert_eq!(r#"{null, c}"#, string_array.value(2));
        assert_eq!(r#"{4, d}"#, string_array.value(3));
        assert_eq!(r#"{5, e}"#, string_array.value(4));
    }

    #[test]
    fn test_cast_struct_to_struct() {
        let a: ArrayRef = Arc::new(Int32Array::from(vec![
            Some(1),
            Some(2),
            None,
            Some(4),
            Some(5),
        ]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
        let c: ArrayRef = Arc::new(StructArray::from(vec![
            (Arc::new(Field::new("a", DataType::Int32, true)), a),
            (Arc::new(Field::new("b", DataType::Utf8, true)), b),
        ]));
        // cast field "a" from Int32 to Utf8, keeping "b" as Utf8
        let fields = Fields::from(vec![
            Field::new("a", DataType::Utf8, true),
            Field::new("b", DataType::Utf8, true),
        ]);
        let cast_array = spark_cast(
            ColumnarValue::Array(c),
            &DataType::Struct(fields),
            &SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
        )
        .unwrap();
        if let ColumnarValue::Array(cast_array) = cast_array {
            assert_eq!(5, cast_array.len());
            let a = cast_array.as_struct().column(0).as_string::<i32>();
            assert_eq!("1", a.value(0));
        } else {
            unreachable!()
        }
    }

    #[test]
    fn test_cast_struct_to_struct_drop_column() {
        let a: ArrayRef = Arc::new(Int32Array::from(vec![
            Some(1),
            Some(2),
            None,
            Some(4),
            Some(5),
        ]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
        let c: ArrayRef = Arc::new(StructArray::from(vec![
            (Arc::new(Field::new("a", DataType::Int32, true)), a),
            (Arc::new(Field::new("b", DataType::Utf8, true)), b),
        ]));
        // cast to a struct containing only field "a" (as Utf8), dropping "b"
        let fields = Fields::from(vec![Field::new("a", DataType::Utf8, true)]);
        let cast_array = spark_cast(
            ColumnarValue::Array(c),
            &DataType::Struct(fields),
            &SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
        )
        .unwrap();
        if let ColumnarValue::Array(cast_array) = cast_array {
            assert_eq!(5, cast_array.len());
            let struct_array = cast_array.as_struct();
            assert_eq!(1, struct_array.columns().len());
            let a = struct_array.column(0).as_string::<i32>();
            assert_eq!("1", a.value(0));
        } else {
            unreachable!()
        }
    }
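
    // added check for the private `trim_end` helper; the expected strings
    // below simply restate the `trim_end_matches('0')` logic and are not
    // derived from Spark output
    #[test]
    fn test_trim_end() {
        assert_eq!(
            trim_end("2020-01-01 12:34:56.123000"),
            "2020-01-01 12:34:56.123"
        );
        // no '.' in the input: returned unchanged even with trailing zeroes
        assert_eq!(trim_end("123000"), "123000");
        // an all-zero fraction keeps its trailing '.'
        assert_eq!(trim_end("12:34:56.000000"), "12:34:56.");
    }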
}