1use crate::timezone;
19use crate::utils::array_with_timezone;
20use crate::{EvalMode, SparkError, SparkResult};
21use arrow::array::builder::StringBuilder;
22use arrow::array::{DictionaryArray, StringArray, StructArray};
23use arrow::datatypes::{DataType, Schema};
24use arrow::{
25 array::{
26 cast::AsArray,
27 types::{Date32Type, Int16Type, Int32Type, Int8Type},
28 Array, ArrayRef, BooleanArray, Decimal128Array, Float32Array, Float64Array,
29 GenericStringArray, Int16Array, Int32Array, Int64Array, Int8Array, OffsetSizeTrait,
30 PrimitiveArray,
31 },
32 compute::{cast_with_options, take, unary, CastOptions},
33 datatypes::{
34 is_validate_decimal_precision, ArrowPrimitiveType, Decimal128Type, Float32Type,
35 Float64Type, Int64Type, TimestampMicrosecondType,
36 },
37 error::ArrowError,
38 record_batch::RecordBatch,
39 util::display::FormatOptions,
40};
41use chrono::{DateTime, NaiveDate, TimeZone, Timelike};
42use datafusion::common::{
43 cast::as_generic_string_array, internal_err, Result as DataFusionResult, ScalarValue,
44};
45use datafusion::physical_expr::PhysicalExpr;
46use datafusion::physical_plan::ColumnarValue;
47use num::{
48 cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, Num,
49 ToPrimitive,
50};
51use regex::Regex;
52use std::str::FromStr;
53use std::{
54 any::Any,
55 fmt::{Debug, Display, Formatter},
56 hash::Hash,
57 num::Wrapping,
58 sync::Arc,
59};
60
// Spark-compatible timestamp rendering: "YYYY-MM-DD HH:MM:SS.ffffff" (space
// separator, fractional seconds), applied to both tz and non-tz timestamps.
static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f");

// Microseconds per second, for second-to-microsecond conversions.
const MICROS_PER_SECOND: i64 = 1000000;

// Default Arrow cast options: `safe: true` turns failing casts into nulls
// rather than errors.
static CAST_OPTIONS: CastOptions = CastOptions {
    safe: true,
    format_options: FormatOptions::new()
        .with_timestamp_tz_format(TIMESTAMP_FORMAT)
        .with_timestamp_format(TIMESTAMP_FORMAT),
};
71
/// Broken-down timestamp components accumulated while parsing timestamp strings.
struct TimeStampInfo {
    year: i32,
    month: u32,
    day: u32,
    hour: u32,
    minute: u32,
    second: u32,
    microsecond: u32,
}
81
82impl Default for TimeStampInfo {
83 fn default() -> Self {
84 TimeStampInfo {
85 year: 1,
86 month: 1,
87 day: 1,
88 hour: 0,
89 minute: 0,
90 second: 0,
91 microsecond: 0,
92 }
93 }
94}
95
impl TimeStampInfo {
    /// Set the year component.
    pub fn with_year(&mut self, year: i32) -> &mut Self {
        self.year = year;
        self
    }

    /// Set the month component (1-based).
    pub fn with_month(&mut self, month: u32) -> &mut Self {
        self.month = month;
        self
    }

    /// Set the day-of-month component (1-based).
    pub fn with_day(&mut self, day: u32) -> &mut Self {
        self.day = day;
        self
    }

    /// Set the hour component.
    pub fn with_hour(&mut self, hour: u32) -> &mut Self {
        self.hour = hour;
        self
    }

    /// Set the minute component.
    pub fn with_minute(&mut self, minute: u32) -> &mut Self {
        self.minute = minute;
        self
    }

    /// Set the second component.
    pub fn with_second(&mut self, second: u32) -> &mut Self {
        self.second = second;
        self
    }

    /// Set the microsecond component.
    pub fn with_microsecond(&mut self, microsecond: u32) -> &mut Self {
        self.microsecond = microsecond;
        self
    }
}
132
/// Physical expression that casts `child` to `data_type` with Spark semantics.
#[derive(Debug, Eq)]
pub struct Cast {
    /// Expression producing the values to be cast.
    pub child: Arc<dyn PhysicalExpr>,
    /// Target Arrow data type.
    pub data_type: DataType,
    /// Options controlling Spark-compatible cast behavior.
    pub cast_options: SparkCastOptions,
}
139
140impl PartialEq for Cast {
141 fn eq(&self, other: &Self) -> bool {
142 self.child.eq(&other.child)
143 && self.data_type.eq(&other.data_type)
144 && self.cast_options.eq(&other.cast_options)
145 }
146}
147
impl Hash for Cast {
    /// Hash the same fields that participate in `PartialEq`, in the same order.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.child.hash(state);
        self.data_type.hash(state);
        self.cast_options.hash(state);
    }
}
155
156pub fn cast_supported(
158 from_type: &DataType,
159 to_type: &DataType,
160 options: &SparkCastOptions,
161) -> bool {
162 use DataType::*;
163
164 let from_type = if let Dictionary(_, dt) = from_type {
165 dt
166 } else {
167 from_type
168 };
169
170 let to_type = if let Dictionary(_, dt) = to_type {
171 dt
172 } else {
173 to_type
174 };
175
176 if from_type == to_type {
177 return true;
178 }
179
180 match (from_type, to_type) {
181 (Boolean, _) => can_cast_from_boolean(to_type, options),
182 (UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
183 if options.allow_cast_unsigned_ints =>
184 {
185 true
186 }
187 (Int8, _) => can_cast_from_byte(to_type, options),
188 (Int16, _) => can_cast_from_short(to_type, options),
189 (Int32, _) => can_cast_from_int(to_type, options),
190 (Int64, _) => can_cast_from_long(to_type, options),
191 (Float32, _) => can_cast_from_float(to_type, options),
192 (Float64, _) => can_cast_from_double(to_type, options),
193 (Decimal128(p, s), _) => can_cast_from_decimal(p, s, to_type, options),
194 (Timestamp(_, None), _) => can_cast_from_timestamp_ntz(to_type, options),
195 (Timestamp(_, Some(_)), _) => can_cast_from_timestamp(to_type, options),
196 (Utf8 | LargeUtf8, _) => can_cast_from_string(to_type, options),
197 (_, Utf8 | LargeUtf8) => can_cast_to_string(from_type, options),
198 (Struct(from_fields), Struct(to_fields)) => from_fields
199 .iter()
200 .zip(to_fields.iter())
201 .all(|(a, b)| cast_supported(a.data_type(), b.data_type(), options)),
202 _ => false,
203 }
204}
205
206fn can_cast_from_string(to_type: &DataType, options: &SparkCastOptions) -> bool {
207 use DataType::*;
208 match to_type {
209 Boolean | Int8 | Int16 | Int32 | Int64 | Binary => true,
210 Float32 | Float64 => {
211 options.allow_incompat
215 }
216 Decimal128(_, _) => {
217 options.allow_incompat
222 }
223 Date32 | Date64 => {
224 options.allow_incompat
227 }
228 Timestamp(_, _) if options.eval_mode == EvalMode::Ansi => {
229 false
231 }
232 Timestamp(_, Some(tz)) if tz.as_ref() != "UTC" => {
233 options.allow_incompat
235 }
236 Timestamp(_, _) => {
237 options.allow_incompat
240 }
241 _ => false,
242 }
243}
244
245fn can_cast_to_string(from_type: &DataType, options: &SparkCastOptions) -> bool {
246 use DataType::*;
247 match from_type {
248 Boolean | Int8 | Int16 | Int32 | Int64 | Date32 | Date64 | Timestamp(_, _) => true,
249 Float32 | Float64 => {
250 true
254 }
255 Decimal128(_, _) => {
256 true
260 }
261 Binary => {
262 options.allow_incompat
265 }
266 Struct(fields) => fields
267 .iter()
268 .all(|f| can_cast_to_string(f.data_type(), options)),
269 _ => false,
270 }
271}
272
273fn can_cast_from_timestamp_ntz(to_type: &DataType, options: &SparkCastOptions) -> bool {
274 use DataType::*;
275 match to_type {
276 Timestamp(_, _) | Date32 | Date64 | Utf8 => {
277 options.allow_incompat
279 }
280 _ => {
281 false
283 }
284 }
285}
286
287fn can_cast_from_timestamp(to_type: &DataType, _options: &SparkCastOptions) -> bool {
288 use DataType::*;
289 match to_type {
290 Boolean | Int8 | Int16 => {
291 false
294 }
295 Int64 => {
296 true
298 }
299 Date32 | Date64 | Utf8 | Decimal128(_, _) => true,
300 _ => {
301 false
303 }
304 }
305}
306
/// Boolean can be cast to any signed integer or float type.
fn can_cast_from_boolean(to_type: &DataType, _: &SparkCastOptions) -> bool {
    use DataType::*;
    matches!(to_type, Int8 | Int16 | Int32 | Int64 | Float32 | Float64)
}
311
/// Int8 can be cast to boolean, any signed integer/float type, or decimal.
fn can_cast_from_byte(to_type: &DataType, _: &SparkCastOptions) -> bool {
    use DataType::*;
    matches!(
        to_type,
        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
    )
}
319
/// Int16 can be cast to boolean, any signed integer/float type, or decimal.
fn can_cast_from_short(to_type: &DataType, _: &SparkCastOptions) -> bool {
    use DataType::*;
    matches!(
        to_type,
        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
    )
}
327
328fn can_cast_from_int(to_type: &DataType, options: &SparkCastOptions) -> bool {
329 use DataType::*;
330 match to_type {
331 Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Utf8 => true,
332 Decimal128(_, _) => {
333 options.allow_incompat
335 }
336 _ => false,
337 }
338}
339
340fn can_cast_from_long(to_type: &DataType, options: &SparkCastOptions) -> bool {
341 use DataType::*;
342 match to_type {
343 Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 => true,
344 Decimal128(_, _) => {
345 options.allow_incompat
347 }
348 _ => false,
349 }
350}
351
/// Float32 can be cast to boolean, signed integers, Float64, or decimal.
fn can_cast_from_float(to_type: &DataType, _: &SparkCastOptions) -> bool {
    use DataType::*;
    matches!(
        to_type,
        Boolean | Int8 | Int16 | Int32 | Int64 | Float64 | Decimal128(_, _)
    )
}
359
/// Float64 can be cast to boolean, signed integers, Float32, or decimal.
fn can_cast_from_double(to_type: &DataType, _: &SparkCastOptions) -> bool {
    use DataType::*;
    matches!(
        to_type,
        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Decimal128(_, _)
    )
}
367
368fn can_cast_from_decimal(
369 p1: &u8,
370 _s1: &i8,
371 to_type: &DataType,
372 options: &SparkCastOptions,
373) -> bool {
374 use DataType::*;
375 match to_type {
376 Int8 | Int16 | Int32 | Int64 | Float32 | Float64 => true,
377 Decimal128(p2, _) => {
378 if p2 < p1 {
379 options.allow_incompat
382 } else {
383 true
384 }
385 }
386 _ => false,
387 }
388}
389
/// Cast a UTF-8 string array to an integer array by applying `$cast_method`
/// to each non-null element. Parse errors propagate via `?`; a parse that
/// yields `Ok(None)` produces a null in the output.
macro_rules! cast_utf8_to_int {
    ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{
        let len = $array.len();
        let mut cast_array = PrimitiveArray::<$array_type>::builder(len);
        for i in 0..len {
            if $array.is_null(i) {
                cast_array.append_null()
            } else if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? {
                cast_array.append_value(cast_value);
            } else {
                cast_array.append_null()
            }
        }
        let result: SparkResult<ArrayRef> = Ok(Arc::new(cast_array.finish()) as ArrayRef);
        result
    }};
}
/// Cast a UTF-8 string array to a UTC timestamp array by applying
/// `$cast_method` to each trimmed, non-null element.
///
/// NOTE(review): the `if let Ok(Some(..))` pattern turns parser *errors* into
/// nulls as well, in every eval mode — confirm that swallowing errors here
/// (rather than raising in ANSI mode) is intended.
macro_rules! cast_utf8_to_timestamp {
    ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident, $tz:expr) => {{
        let len = $array.len();
        let mut cast_array = PrimitiveArray::<$array_type>::builder(len).with_timezone("UTC");
        for i in 0..len {
            if $array.is_null(i) {
                cast_array.append_null()
            } else if let Ok(Some(cast_value)) =
                $cast_method($array.value(i).trim(), $eval_mode, $tz)
            {
                cast_array.append_value(cast_value);
            } else {
                cast_array.append_null()
            }
        }
        let result: ArrayRef = Arc::new(cast_array.finish()) as ArrayRef;
        result
    }};
}
426
/// Format a float array as strings following Spark's rules: `Infinity` /
/// `-Infinity` for infinities, plain decimal notation (with a trailing `.0`
/// for integral values) for magnitudes in [0.001, 1e7), and scientific `E`
/// notation outside that range.
macro_rules! cast_float_to_string {
    ($from:expr, $eval_mode:expr, $type:ty, $output_type:ty, $offset_type:ty) => {{
        fn cast<OffsetSize>(
            from: &dyn Array,
            _eval_mode: EvalMode,
        ) -> SparkResult<ArrayRef>
        where
            OffsetSize: OffsetSizeTrait, {
            let array = from.as_any().downcast_ref::<$output_type>().unwrap();

            // Magnitudes outside [lower, upper) are printed in scientific notation.
            const LOWER_SCIENTIFIC_BOUND: $type = 0.001;
            const UPPER_SCIENTIFIC_BOUND: $type = 10000000.0;

            let output_array = array
                .iter()
                .map(|value| match value {
                    Some(value) if value == <$type>::INFINITY => Ok(Some("Infinity".to_string())),
                    Some(value) if value == <$type>::NEG_INFINITY => Ok(Some("-Infinity".to_string())),
                    // Plain notation; integral values get a ".0" suffix.
                    Some(value)
                        if (value.abs() < UPPER_SCIENTIFIC_BOUND
                            && value.abs() >= LOWER_SCIENTIFIC_BOUND)
                            || value.abs() == 0.0 =>
                    {
                        let trailing_zero = if value.fract() == 0.0 { ".0" } else { "" };

                        Ok(Some(format!("{value}{trailing_zero}")))
                    }
                    // Scientific notation; ensure the coefficient contains a
                    // decimal point (Rust prints 1e8 as "1E8", Spark expects "1.0E8").
                    Some(value)
                        if value.abs() >= UPPER_SCIENTIFIC_BOUND
                            || value.abs() < LOWER_SCIENTIFIC_BOUND =>
                    {
                        let formatted = format!("{value:E}");

                        if formatted.contains(".") {
                            Ok(Some(formatted))
                        } else {
                            let prepare_number: Vec<&str> = formatted.split("E").collect();

                            let coefficient = prepare_number[0];

                            let exponent = prepare_number[1];

                            Ok(Some(format!("{coefficient}.0E{exponent}")))
                        }
                    }
                    // Remaining values (e.g. NaN) use default formatting.
                    Some(value) => Ok(Some(value.to_string())),
                    _ => Ok(None),
                })
                .collect::<Result<GenericStringArray<OffsetSize>, SparkError>>()?;

            Ok(Arc::new(output_array))
        }

        cast::<$offset_type>($from, $eval_mode)
    }};
}
494
/// Cast between signed integer widths. `Legacy` mode truncates with a plain
/// `as` cast; every other mode checks for overflow via `try_from` and raises
/// `cast_overflow`, quoting the value with its Spark literal suffix
/// (`L`/`S`/`T` for long/short/byte).
macro_rules! cast_int_to_int_macro {
    (
        $array: expr,
        $eval_mode:expr,
        $from_arrow_primitive_type: ty,
        $to_arrow_primitive_type: ty,
        $from_data_type: expr,
        $to_native_type: ty,
        $spark_from_data_type_name: expr,
        $spark_to_data_type_name: expr
    ) => {{
        let cast_array = $array
            .as_any()
            .downcast_ref::<PrimitiveArray<$from_arrow_primitive_type>>()
            .unwrap();
        // Suffix Spark uses when printing integer literals in error messages.
        let spark_int_literal_suffix = match $from_data_type {
            &DataType::Int64 => "L",
            &DataType::Int16 => "S",
            &DataType::Int8 => "T",
            _ => "",
        };

        let output_array = match $eval_mode {
            // Legacy: silently truncate like a plain `as` cast.
            EvalMode::Legacy => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        Ok::<Option<$to_native_type>, SparkError>(Some(value as $to_native_type))
                    }
                    _ => Ok(None),
                })
                .collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, _>>(),
            // Non-legacy: detect overflow and raise a descriptive error.
            _ => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let res = <$to_native_type>::try_from(value);
                        if res.is_err() {
                            Err(cast_overflow(
                                &(value.to_string() + spark_int_literal_suffix),
                                $spark_from_data_type_name,
                                $spark_to_data_type_name,
                            ))
                        } else {
                            Ok::<Option<$to_native_type>, SparkError>(Some(res.unwrap()))
                        }
                    }
                    _ => Ok(None),
                })
                .collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, _>>(),
        }?;
        let result: SparkResult<ArrayRef> = Ok(Arc::new(output_array) as ArrayRef);
        result
    }};
}
550
/// Cast a float array down to Int8/Int16 by truncating through i32.
///
/// In ANSI mode, NaN and out-of-range inputs raise `cast_overflow`; the
/// `value.abs() as i32 == i32::MAX` test catches magnitudes that saturate the
/// float-to-i32 `as` cast (Rust float-to-int casts saturate, and NaN maps to
/// 0, hence the separate `is_nan` check). Other modes truncate silently.
macro_rules! cast_float_to_int16_down {
    (
        $array:expr,
        $eval_mode:expr,
        $src_array_type:ty,
        $dest_array_type:ty,
        $rust_src_type:ty,
        $rust_dest_type:ty,
        $src_type_str:expr,
        $dest_type_str:expr,
        $format_str:expr
    ) => {{
        let cast_array = $array
            .as_any()
            .downcast_ref::<$src_array_type>()
            .expect(concat!("Expected a ", stringify!($src_array_type)));

        let output_array = match $eval_mode {
            EvalMode::Ansi => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let is_overflow = value.is_nan() || value.abs() as i32 == i32::MAX;
                        if is_overflow {
                            return Err(cast_overflow(
                                // Spark prints the value with an uppercase exponent.
                                &format!($format_str, value).replace("e", "E"),
                                $src_type_str,
                                $dest_type_str,
                            ));
                        }
                        let i32_value = value as i32;
                        <$rust_dest_type>::try_from(i32_value)
                            .map_err(|_| {
                                cast_overflow(
                                    &format!($format_str, value).replace("e", "E"),
                                    $src_type_str,
                                    $dest_type_str,
                                )
                            })
                            .map(Some)
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
            _ => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let i32_value = value as i32;
                        Ok::<Option<$rust_dest_type>, SparkError>(Some(
                            i32_value as $rust_dest_type,
                        ))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
        };
        Ok(Arc::new(output_array) as ArrayRef)
    }};
}
614
/// Cast a float array to Int32/Int64 directly.
///
/// In ANSI mode, NaN and values whose saturating `as` cast pins at
/// `$max_dest_val` raise `cast_overflow`; other modes cast (and saturate)
/// silently.
macro_rules! cast_float_to_int32_up {
    (
        $array:expr,
        $eval_mode:expr,
        $src_array_type:ty,
        $dest_array_type:ty,
        $rust_src_type:ty,
        $rust_dest_type:ty,
        $src_type_str:expr,
        $dest_type_str:expr,
        $max_dest_val:expr,
        $format_str:expr
    ) => {{
        let cast_array = $array
            .as_any()
            .downcast_ref::<$src_array_type>()
            .expect(concat!("Expected a ", stringify!($src_array_type)));

        let output_array = match $eval_mode {
            EvalMode::Ansi => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let is_overflow =
                            value.is_nan() || value.abs() as $rust_dest_type == $max_dest_val;
                        if is_overflow {
                            return Err(cast_overflow(
                                // Spark prints the value with an uppercase exponent.
                                &format!($format_str, value).replace("e", "E"),
                                $src_type_str,
                                $dest_type_str,
                            ));
                        }
                        Ok(Some(value as $rust_dest_type))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
            _ => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        Ok::<Option<$rust_dest_type>, SparkError>(Some(value as $rust_dest_type))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
        };
        Ok(Arc::new(output_array) as ArrayRef)
    }};
}
665
/// Cast a Decimal128 array down to Int8/Int16 by dividing out the scale and
/// narrowing through i32.
///
/// In ANSI mode an overflowing value raises `cast_overflow`, formatted as a
/// Spark decimal literal (`<integral>.<fraction>BD`); other modes truncate
/// silently.
macro_rules! cast_decimal_to_int16_down {
    (
        $array:expr,
        $eval_mode:expr,
        $dest_array_type:ty,
        $rust_dest_type:ty,
        $dest_type_str:expr,
        $precision:expr,
        $scale:expr
    ) => {{
        let cast_array = $array
            .as_any()
            .downcast_ref::<Decimal128Array>()
            .expect(concat!("Expected a Decimal128ArrayType"));

        let output_array = match $eval_mode {
            EvalMode::Ansi => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        // Split into integral part and (absolute) fraction digits.
                        let divisor = 10_i128.pow($scale as u32);
                        let (truncated, decimal) = (value / divisor, (value % divisor).abs());
                        let is_overflow = truncated.abs() > i32::MAX.into();
                        if is_overflow {
                            return Err(cast_overflow(
                                &format!("{}.{}BD", truncated, decimal),
                                &format!("DECIMAL({},{})", $precision, $scale),
                                $dest_type_str,
                            ));
                        }
                        let i32_value = truncated as i32;
                        <$rust_dest_type>::try_from(i32_value)
                            .map_err(|_| {
                                cast_overflow(
                                    &format!("{}.{}BD", truncated, decimal),
                                    &format!("DECIMAL({},{})", $precision, $scale),
                                    $dest_type_str,
                                )
                            })
                            .map(Some)
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
            _ => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let divisor = 10_i128.pow($scale as u32);
                        let i32_value = (value / divisor) as i32;
                        Ok::<Option<$rust_dest_type>, SparkError>(Some(
                            i32_value as $rust_dest_type,
                        ))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
        };
        Ok(Arc::new(output_array) as ArrayRef)
    }};
}
730
/// Cast a Decimal128 array to Int32/Int64 by dividing out the scale.
///
/// In ANSI mode, a truncated magnitude greater than `$max_dest_val` raises
/// `cast_overflow` (formatted as a Spark `...BD` decimal literal); other
/// modes truncate silently.
macro_rules! cast_decimal_to_int32_up {
    (
        $array:expr,
        $eval_mode:expr,
        $dest_array_type:ty,
        $rust_dest_type:ty,
        $dest_type_str:expr,
        $max_dest_val:expr,
        $precision:expr,
        $scale:expr
    ) => {{
        let cast_array = $array
            .as_any()
            .downcast_ref::<Decimal128Array>()
            .expect(concat!("Expected a Decimal128ArrayType"));

        let output_array = match $eval_mode {
            EvalMode::Ansi => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        // Split into integral part and (absolute) fraction digits.
                        let divisor = 10_i128.pow($scale as u32);
                        let (truncated, decimal) = (value / divisor, (value % divisor).abs());
                        let is_overflow = truncated.abs() > $max_dest_val.into();
                        if is_overflow {
                            return Err(cast_overflow(
                                &format!("{}.{}BD", truncated, decimal),
                                &format!("DECIMAL({},{})", $precision, $scale),
                                $dest_type_str,
                            ));
                        }
                        Ok(Some(truncated as $rust_dest_type))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
            _ => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let divisor = 10_i128.pow($scale as u32);
                        let truncated = value / divisor;
                        Ok::<Option<$rust_dest_type>, SparkError>(Some(
                            truncated as $rust_dest_type,
                        ))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
        };
        Ok(Arc::new(output_array) as ArrayRef)
    }};
}
784
impl Cast {
    /// Create a cast expression that converts `child` to `data_type` using
    /// the given Spark cast options.
    pub fn new(
        child: Arc<dyn PhysicalExpr>,
        data_type: DataType,
        cast_options: SparkCastOptions,
    ) -> Self {
        Self {
            child,
            data_type,
            cast_options,
        }
    }
}
798
/// Options controlling Spark-compatible cast behavior.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct SparkCastOptions {
    /// Spark evaluation mode (e.g. Legacy, Ansi, Try).
    pub eval_mode: EvalMode,
    /// Session timezone string used for timestamp conversions.
    pub timezone: String,
    /// Permit casts that are known to be incompatible with Spark's results.
    pub allow_incompat: bool,
    /// Permit casting unsigned integer types to signed integer types.
    pub allow_cast_unsigned_ints: bool,
    /// True when the cast is used for schema adaptation rather than for a
    /// user-visible CAST expression.
    pub is_adapting_schema: bool,
    /// Text emitted for null fields when rendering values as strings.
    pub null_string: String,
}
818
819impl SparkCastOptions {
820 pub fn new(eval_mode: EvalMode, timezone: &str, allow_incompat: bool) -> Self {
821 Self {
822 eval_mode,
823 timezone: timezone.to_string(),
824 allow_incompat,
825 allow_cast_unsigned_ints: false,
826 is_adapting_schema: false,
827 null_string: "null".to_string(),
828 }
829 }
830
831 pub fn new_without_timezone(eval_mode: EvalMode, allow_incompat: bool) -> Self {
832 Self {
833 eval_mode,
834 timezone: "".to_string(),
835 allow_incompat,
836 allow_cast_unsigned_ints: false,
837 is_adapting_schema: false,
838 null_string: "null".to_string(),
839 }
840 }
841}
842
843pub fn spark_cast(
847 arg: ColumnarValue,
848 data_type: &DataType,
849 cast_options: &SparkCastOptions,
850) -> DataFusionResult<ColumnarValue> {
851 match arg {
852 ColumnarValue::Array(array) => Ok(ColumnarValue::Array(cast_array(
853 array,
854 data_type,
855 cast_options,
856 )?)),
857 ColumnarValue::Scalar(scalar) => {
858 let array = scalar.to_array()?;
862 let scalar =
863 ScalarValue::try_from_array(&cast_array(array, data_type, cast_options)?, 0)?;
864 Ok(ColumnarValue::Scalar(scalar))
865 }
866 }
867}
868
/// Cast `array` to `to_type` with Spark-compatible semantics.
///
/// Dispatches Spark-specific conversions (string parsing, float formatting,
/// overflow-checked narrowing, struct handling) to dedicated implementations,
/// and falls back to Arrow's `cast_with_options` only for combinations known
/// (or declared via options) to match Spark's behavior.
fn cast_array(
    array: ArrayRef,
    to_type: &DataType,
    cast_options: &SparkCastOptions,
) -> DataFusionResult<ArrayRef> {
    use DataType::*;
    // Normalize the input's timezone before casting.
    let array = array_with_timezone(array, cast_options.timezone.clone(), Some(to_type))?;
    let from_type = array.data_type().clone();

    // ANSI mode uses `safe: false` so Arrow reports cast failures as errors
    // instead of silently producing nulls.
    let native_cast_options: CastOptions = CastOptions {
        safe: !matches!(cast_options.eval_mode, EvalMode::Ansi),
        format_options: FormatOptions::new()
            .with_timestamp_tz_format(TIMESTAMP_FORMAT)
            .with_timestamp_format(TIMESTAMP_FORMAT),
    };

    // Int32-keyed dictionaries of string/binary values: cast the dictionary
    // values, then either keep the dictionary encoding (dictionary target) or
    // expand via `take`, and return early.
    let array = match &from_type {
        Dictionary(key_type, value_type)
            if key_type.as_ref() == &Int32
                && (value_type.as_ref() == &Utf8
                    || value_type.as_ref() == &LargeUtf8
                    || value_type.as_ref() == &Binary
                    || value_type.as_ref() == &LargeBinary) =>
        {
            let dict_array = array
                .as_any()
                .downcast_ref::<DictionaryArray<Int32Type>>()
                .expect("Expected a dictionary array");

            let casted_dictionary = DictionaryArray::<Int32Type>::new(
                dict_array.keys().clone(),
                cast_array(Arc::clone(dict_array.values()), to_type, cast_options)?,
            );

            let casted_result = match to_type {
                Dictionary(_, _) => Arc::new(casted_dictionary.clone()),
                _ => take(casted_dictionary.values().as_ref(), dict_array.keys(), None)?,
            };
            return Ok(spark_cast_postprocess(casted_result, &from_type, to_type));
        }
        _ => array,
    };
    let from_type = array.data_type();
    let eval_mode = cast_options.eval_mode;

    // Dispatch table for Spark-specific cast implementations.
    let cast_result = match (from_type, to_type) {
        (Utf8, Boolean) => spark_cast_utf8_to_boolean::<i32>(&array, eval_mode),
        (LargeUtf8, Boolean) => spark_cast_utf8_to_boolean::<i64>(&array, eval_mode),
        (Utf8, Timestamp(_, _)) => {
            cast_string_to_timestamp(&array, to_type, eval_mode, &cast_options.timezone)
        }
        (Utf8, Date32) => cast_string_to_date(&array, to_type, eval_mode),
        // Integer narrowing requires overflow handling except in Try mode.
        (Int64, Int32)
        | (Int64, Int16)
        | (Int64, Int8)
        | (Int32, Int16)
        | (Int32, Int8)
        | (Int16, Int8)
            if eval_mode != EvalMode::Try =>
        {
            spark_cast_int_to_int(&array, eval_mode, from_type, to_type)
        }
        (Utf8, Int8 | Int16 | Int32 | Int64) => {
            cast_string_to_int::<i32>(to_type, &array, eval_mode)
        }
        (LargeUtf8, Int8 | Int16 | Int32 | Int64) => {
            cast_string_to_int::<i64>(to_type, &array, eval_mode)
        }
        (Float64, Utf8) => spark_cast_float64_to_utf8::<i32>(&array, eval_mode),
        (Float64, LargeUtf8) => spark_cast_float64_to_utf8::<i64>(&array, eval_mode),
        (Float32, Utf8) => spark_cast_float32_to_utf8::<i32>(&array, eval_mode),
        (Float32, LargeUtf8) => spark_cast_float32_to_utf8::<i64>(&array, eval_mode),
        (Float32, Decimal128(precision, scale)) => {
            cast_float32_to_decimal128(&array, *precision, *scale, eval_mode)
        }
        (Float64, Decimal128(precision, scale)) => {
            cast_float64_to_decimal128(&array, *precision, *scale, eval_mode)
        }
        // Float/decimal to integer narrowing also requires overflow handling
        // except in Try mode.
        (Float32, Int8)
        | (Float32, Int16)
        | (Float32, Int32)
        | (Float32, Int64)
        | (Float64, Int8)
        | (Float64, Int16)
        | (Float64, Int32)
        | (Float64, Int64)
        | (Decimal128(_, _), Int8)
        | (Decimal128(_, _), Int16)
        | (Decimal128(_, _), Int32)
        | (Decimal128(_, _), Int64)
            if eval_mode != EvalMode::Try =>
        {
            spark_cast_nonintegral_numeric_to_integral(&array, eval_mode, from_type, to_type)
        }
        (Struct(_), Utf8) => Ok(casts_struct_to_string(array.as_struct(), cast_options)?),
        (Struct(_), Struct(_)) => Ok(cast_struct_to_struct(
            array.as_struct(),
            from_type,
            to_type,
            cast_options,
        )?),
        (UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
            if cast_options.allow_cast_unsigned_ints =>
        {
            Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
        }
        // Fall back to Arrow's cast when it is known to match Spark, or when
        // this cast is only adapting a schema.
        _ if cast_options.is_adapting_schema
            || is_datafusion_spark_compatible(from_type, to_type, cast_options.allow_incompat) =>
        {
            Ok(cast_with_options(&array, to_type, &native_cast_options)?)
        }
        _ => {
            Err(SparkError::Internal(format!(
                "Native cast invoked for unsupported cast from {from_type:?} to {to_type:?}"
            )))
        }
    };
    Ok(spark_cast_postprocess(cast_result?, from_type, to_type))
}
992
993fn is_datafusion_spark_compatible(
996 from_type: &DataType,
997 to_type: &DataType,
998 allow_incompat: bool,
999) -> bool {
1000 if from_type == to_type {
1001 return true;
1002 }
1003 match from_type {
1004 DataType::Null => {
1005 matches!(to_type, DataType::List(_))
1006 }
1007 DataType::Boolean => matches!(
1008 to_type,
1009 DataType::Int8
1010 | DataType::Int16
1011 | DataType::Int32
1012 | DataType::Int64
1013 | DataType::Float32
1014 | DataType::Float64
1015 | DataType::Utf8
1016 ),
1017 DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
1018 matches!(
1022 to_type,
1023 DataType::Boolean
1024 | DataType::Int8
1025 | DataType::Int16
1026 | DataType::Int32
1027 | DataType::Int64
1028 | DataType::Float32
1029 | DataType::Float64
1030 | DataType::Decimal128(_, _)
1031 | DataType::Utf8
1032 )
1033 }
1034 DataType::Float32 | DataType::Float64 => matches!(
1035 to_type,
1036 DataType::Boolean
1037 | DataType::Int8
1038 | DataType::Int16
1039 | DataType::Int32
1040 | DataType::Int64
1041 | DataType::Float32
1042 | DataType::Float64
1043 ),
1044 DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => matches!(
1045 to_type,
1046 DataType::Int8
1047 | DataType::Int16
1048 | DataType::Int32
1049 | DataType::Int64
1050 | DataType::Float32
1051 | DataType::Float64
1052 | DataType::Decimal128(_, _)
1053 | DataType::Decimal256(_, _)
1054 | DataType::Utf8 ),
1056 DataType::Utf8 if allow_incompat => matches!(
1057 to_type,
1058 DataType::Binary | DataType::Float32 | DataType::Float64 | DataType::Decimal128(_, _)
1059 ),
1060 DataType::Utf8 => matches!(to_type, DataType::Binary),
1061 DataType::Date32 => matches!(to_type, DataType::Utf8),
1062 DataType::Timestamp(_, _) => {
1063 matches!(
1064 to_type,
1065 DataType::Int64 | DataType::Date32 | DataType::Utf8 | DataType::Timestamp(_, _)
1066 )
1067 }
1068 DataType::Binary => {
1069 matches!(to_type, DataType::Utf8)
1072 }
1073 _ => false,
1074 }
1075}
1076
1077fn cast_struct_to_struct(
1080 array: &StructArray,
1081 from_type: &DataType,
1082 to_type: &DataType,
1083 cast_options: &SparkCastOptions,
1084) -> DataFusionResult<ArrayRef> {
1085 match (from_type, to_type) {
1086 (DataType::Struct(from_fields), DataType::Struct(to_fields)) => {
1087 let cast_fields: Vec<ArrayRef> = from_fields
1088 .iter()
1089 .enumerate()
1090 .zip(to_fields.iter())
1091 .map(|((idx, _from), to)| {
1092 let from_field = Arc::clone(array.column(idx));
1093 let array_length = from_field.len();
1094 let cast_result = spark_cast(
1095 ColumnarValue::from(from_field),
1096 to.data_type(),
1097 cast_options,
1098 )
1099 .unwrap();
1100 cast_result.to_array(array_length).unwrap()
1101 })
1102 .collect();
1103
1104 Ok(Arc::new(StructArray::new(
1105 to_fields.clone(),
1106 cast_fields,
1107 array.nulls().cloned(),
1108 )))
1109 }
1110 _ => unreachable!(),
1111 }
1112}
1113
1114fn casts_struct_to_string(
1115 array: &StructArray,
1116 spark_cast_options: &SparkCastOptions,
1117) -> DataFusionResult<ArrayRef> {
1118 let string_arrays: Vec<ArrayRef> = array
1120 .columns()
1121 .iter()
1122 .map(|arr| {
1123 spark_cast(
1124 ColumnarValue::Array(Arc::clone(arr)),
1125 &DataType::Utf8,
1126 spark_cast_options,
1127 )
1128 .and_then(|cv| cv.into_array(arr.len()))
1129 })
1130 .collect::<DataFusionResult<Vec<_>>>()?;
1131 let string_arrays: Vec<&StringArray> =
1132 string_arrays.iter().map(|arr| arr.as_string()).collect();
1133 let mut builder = StringBuilder::with_capacity(array.len(), array.len() * 16);
1135 let mut str = String::with_capacity(array.len() * 16);
1136 for row_index in 0..array.len() {
1137 if array.is_null(row_index) {
1138 builder.append_null();
1139 } else {
1140 str.clear();
1141 let mut any_fields_written = false;
1142 str.push('{');
1143 for field in &string_arrays {
1144 if any_fields_written {
1145 str.push_str(", ");
1146 }
1147 if field.is_null(row_index) {
1148 str.push_str(&spark_cast_options.null_string);
1149 } else {
1150 str.push_str(field.value(row_index));
1151 }
1152 any_fields_written = true;
1153 }
1154 str.push('}');
1155 builder.append_value(&str);
1156 }
1157 }
1158 Ok(Arc::new(builder.finish()))
1159}
1160
1161fn cast_string_to_int<OffsetSize: OffsetSizeTrait>(
1162 to_type: &DataType,
1163 array: &ArrayRef,
1164 eval_mode: EvalMode,
1165) -> SparkResult<ArrayRef> {
1166 let string_array = array
1167 .as_any()
1168 .downcast_ref::<GenericStringArray<OffsetSize>>()
1169 .expect("cast_string_to_int expected a string array");
1170
1171 let cast_array: ArrayRef = match to_type {
1172 DataType::Int8 => cast_utf8_to_int!(string_array, eval_mode, Int8Type, cast_string_to_i8)?,
1173 DataType::Int16 => {
1174 cast_utf8_to_int!(string_array, eval_mode, Int16Type, cast_string_to_i16)?
1175 }
1176 DataType::Int32 => {
1177 cast_utf8_to_int!(string_array, eval_mode, Int32Type, cast_string_to_i32)?
1178 }
1179 DataType::Int64 => {
1180 cast_utf8_to_int!(string_array, eval_mode, Int64Type, cast_string_to_i64)?
1181 }
1182 dt => unreachable!(
1183 "{}",
1184 format!("invalid integer type {dt} in cast from string")
1185 ),
1186 };
1187 Ok(cast_array)
1188}
1189
1190fn cast_string_to_date(
1191 array: &ArrayRef,
1192 to_type: &DataType,
1193 eval_mode: EvalMode,
1194) -> SparkResult<ArrayRef> {
1195 let string_array = array
1196 .as_any()
1197 .downcast_ref::<GenericStringArray<i32>>()
1198 .expect("Expected a string array");
1199
1200 if to_type != &DataType::Date32 {
1201 unreachable!("Invalid data type {:?} in cast from string", to_type);
1202 }
1203
1204 let len = string_array.len();
1205 let mut cast_array = PrimitiveArray::<Date32Type>::builder(len);
1206
1207 for i in 0..len {
1208 let value = if string_array.is_null(i) {
1209 None
1210 } else {
1211 match date_parser(string_array.value(i), eval_mode) {
1212 Ok(Some(cast_value)) => Some(cast_value),
1213 Ok(None) => None,
1214 Err(e) => return Err(e),
1215 }
1216 };
1217
1218 match value {
1219 Some(cast_value) => cast_array.append_value(cast_value),
1220 None => cast_array.append_null(),
1221 }
1222 }
1223
1224 Ok(Arc::new(cast_array.finish()) as ArrayRef)
1225}
1226
/// Cast a Utf8 array to a microsecond timestamp by parsing each value with
/// `timestamp_parser` in the given timezone.
fn cast_string_to_timestamp(
    array: &ArrayRef,
    to_type: &DataType,
    eval_mode: EvalMode,
    timezone_str: &str,
) -> SparkResult<ArrayRef> {
    let string_array = array
        .as_any()
        .downcast_ref::<GenericStringArray<i32>>()
        .expect("Expected a string array");

    // NOTE(review): an invalid timezone string panics here instead of
    // returning an error — confirm `timezone_str` is validated upstream.
    let tz = &timezone::Tz::from_str(timezone_str).unwrap();

    let cast_array: ArrayRef = match to_type {
        DataType::Timestamp(_, _) => {
            cast_utf8_to_timestamp!(
                string_array,
                eval_mode,
                TimestampMicrosecondType,
                timestamp_parser,
                tz
            )
        }
        _ => unreachable!("Invalid data type {:?} in cast from string", to_type),
    };
    Ok(cast_array)
}
1254
/// Cast a Float64 array to `Decimal128(precision, scale)`.
fn cast_float64_to_decimal128(
    array: &dyn Array,
    precision: u8,
    scale: i8,
    eval_mode: EvalMode,
) -> SparkResult<ArrayRef> {
    cast_floating_point_to_decimal128::<Float64Type>(array, precision, scale, eval_mode)
}
1263
/// Cast a Float32 array to `Decimal128(precision, scale)`.
fn cast_float32_to_decimal128(
    array: &dyn Array,
    precision: u8,
    scale: i8,
    eval_mode: EvalMode,
) -> SparkResult<ArrayRef> {
    cast_floating_point_to_decimal128::<Float32Type>(array, precision, scale, eval_mode)
}
1272
/// Cast a floating-point array to `Decimal128(precision, scale)`.
///
/// Each value is scaled by `10^scale`, rounded, and validated against the
/// target precision. Values that cannot be represented become nulls, except
/// in ANSI mode where they raise `NumericValueOutOfRange`.
fn cast_floating_point_to_decimal128<T: ArrowPrimitiveType>(
    array: &dyn Array,
    precision: u8,
    scale: i8,
    eval_mode: EvalMode,
) -> SparkResult<ArrayRef>
where
    <T as ArrowPrimitiveType>::Native: AsPrimitive<f64>,
{
    let input = array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
    let mut cast_array = PrimitiveArray::<Decimal128Type>::builder(input.len());

    // Scale factor shifting the decimal point `scale` digits to the right.
    let mul = 10_f64.powi(scale as i32);

    for i in 0..input.len() {
        if input.is_null(i) {
            cast_array.append_null();
            continue;
        }

        let input_value = input.value(i).as_();
        // Accept the value only if it fits in i128 AND the target precision.
        if let Some(v) = (input_value * mul).round().to_i128() {
            if is_validate_decimal_precision(v, precision) {
                cast_array.append_value(v);
                continue;
            }
        };

        // Out of range: error in ANSI mode, null otherwise.
        if eval_mode == EvalMode::Ansi {
            return Err(SparkError::NumericValueOutOfRange {
                value: input_value.to_string(),
                precision,
                scale,
            });
        }
        cast_array.append_null();
    }

    let res = Arc::new(
        cast_array
            .with_precision_and_scale(precision, scale)?
            .finish(),
    ) as ArrayRef;
    Ok(res)
}
1318
/// Formats a `Float64` array as strings using Spark's floating-point
/// formatting rules (delegated to `cast_float_to_string!`).
fn spark_cast_float64_to_utf8<OffsetSize>(
    from: &dyn Array,
    _eval_mode: EvalMode,
) -> SparkResult<ArrayRef>
where
    OffsetSize: OffsetSizeTrait,
{
    cast_float_to_string!(from, _eval_mode, f64, Float64Array, OffsetSize)
}
1328
/// Formats a `Float32` array as strings using Spark's floating-point
/// formatting rules (delegated to `cast_float_to_string!`).
fn spark_cast_float32_to_utf8<OffsetSize>(
    from: &dyn Array,
    _eval_mode: EvalMode,
) -> SparkResult<ArrayRef>
where
    OffsetSize: OffsetSizeTrait,
{
    cast_float_to_string!(from, _eval_mode, f32, Float32Array, OffsetSize)
}
1338
1339fn spark_cast_int_to_int(
1340 array: &dyn Array,
1341 eval_mode: EvalMode,
1342 from_type: &DataType,
1343 to_type: &DataType,
1344) -> SparkResult<ArrayRef> {
1345 match (from_type, to_type) {
1346 (DataType::Int64, DataType::Int32) => cast_int_to_int_macro!(
1347 array, eval_mode, Int64Type, Int32Type, from_type, i32, "BIGINT", "INT"
1348 ),
1349 (DataType::Int64, DataType::Int16) => cast_int_to_int_macro!(
1350 array, eval_mode, Int64Type, Int16Type, from_type, i16, "BIGINT", "SMALLINT"
1351 ),
1352 (DataType::Int64, DataType::Int8) => cast_int_to_int_macro!(
1353 array, eval_mode, Int64Type, Int8Type, from_type, i8, "BIGINT", "TINYINT"
1354 ),
1355 (DataType::Int32, DataType::Int16) => cast_int_to_int_macro!(
1356 array, eval_mode, Int32Type, Int16Type, from_type, i16, "INT", "SMALLINT"
1357 ),
1358 (DataType::Int32, DataType::Int8) => cast_int_to_int_macro!(
1359 array, eval_mode, Int32Type, Int8Type, from_type, i8, "INT", "TINYINT"
1360 ),
1361 (DataType::Int16, DataType::Int8) => cast_int_to_int_macro!(
1362 array, eval_mode, Int16Type, Int8Type, from_type, i8, "SMALLINT", "TINYINT"
1363 ),
1364 _ => unreachable!(
1365 "{}",
1366 format!("invalid integer type {to_type} in cast from {from_type}")
1367 ),
1368 }
1369}
1370
1371fn spark_cast_utf8_to_boolean<OffsetSize>(
1372 from: &dyn Array,
1373 eval_mode: EvalMode,
1374) -> SparkResult<ArrayRef>
1375where
1376 OffsetSize: OffsetSizeTrait,
1377{
1378 let array = from
1379 .as_any()
1380 .downcast_ref::<GenericStringArray<OffsetSize>>()
1381 .unwrap();
1382
1383 let output_array = array
1384 .iter()
1385 .map(|value| match value {
1386 Some(value) => match value.to_ascii_lowercase().trim() {
1387 "t" | "true" | "y" | "yes" | "1" => Ok(Some(true)),
1388 "f" | "false" | "n" | "no" | "0" => Ok(Some(false)),
1389 _ if eval_mode == EvalMode::Ansi => Err(SparkError::CastInvalidValue {
1390 value: value.to_string(),
1391 from_type: "STRING".to_string(),
1392 to_type: "BOOLEAN".to_string(),
1393 }),
1394 _ => Ok(None),
1395 },
1396 _ => Ok(None),
1397 })
1398 .collect::<Result<BooleanArray, _>>()?;
1399
1400 Ok(Arc::new(output_array))
1401}
1402
1403fn spark_cast_nonintegral_numeric_to_integral(
1404 array: &dyn Array,
1405 eval_mode: EvalMode,
1406 from_type: &DataType,
1407 to_type: &DataType,
1408) -> SparkResult<ArrayRef> {
1409 match (from_type, to_type) {
1410 (DataType::Float32, DataType::Int8) => cast_float_to_int16_down!(
1411 array,
1412 eval_mode,
1413 Float32Array,
1414 Int8Array,
1415 f32,
1416 i8,
1417 "FLOAT",
1418 "TINYINT",
1419 "{:e}"
1420 ),
1421 (DataType::Float32, DataType::Int16) => cast_float_to_int16_down!(
1422 array,
1423 eval_mode,
1424 Float32Array,
1425 Int16Array,
1426 f32,
1427 i16,
1428 "FLOAT",
1429 "SMALLINT",
1430 "{:e}"
1431 ),
1432 (DataType::Float32, DataType::Int32) => cast_float_to_int32_up!(
1433 array,
1434 eval_mode,
1435 Float32Array,
1436 Int32Array,
1437 f32,
1438 i32,
1439 "FLOAT",
1440 "INT",
1441 i32::MAX,
1442 "{:e}"
1443 ),
1444 (DataType::Float32, DataType::Int64) => cast_float_to_int32_up!(
1445 array,
1446 eval_mode,
1447 Float32Array,
1448 Int64Array,
1449 f32,
1450 i64,
1451 "FLOAT",
1452 "BIGINT",
1453 i64::MAX,
1454 "{:e}"
1455 ),
1456 (DataType::Float64, DataType::Int8) => cast_float_to_int16_down!(
1457 array,
1458 eval_mode,
1459 Float64Array,
1460 Int8Array,
1461 f64,
1462 i8,
1463 "DOUBLE",
1464 "TINYINT",
1465 "{:e}D"
1466 ),
1467 (DataType::Float64, DataType::Int16) => cast_float_to_int16_down!(
1468 array,
1469 eval_mode,
1470 Float64Array,
1471 Int16Array,
1472 f64,
1473 i16,
1474 "DOUBLE",
1475 "SMALLINT",
1476 "{:e}D"
1477 ),
1478 (DataType::Float64, DataType::Int32) => cast_float_to_int32_up!(
1479 array,
1480 eval_mode,
1481 Float64Array,
1482 Int32Array,
1483 f64,
1484 i32,
1485 "DOUBLE",
1486 "INT",
1487 i32::MAX,
1488 "{:e}D"
1489 ),
1490 (DataType::Float64, DataType::Int64) => cast_float_to_int32_up!(
1491 array,
1492 eval_mode,
1493 Float64Array,
1494 Int64Array,
1495 f64,
1496 i64,
1497 "DOUBLE",
1498 "BIGINT",
1499 i64::MAX,
1500 "{:e}D"
1501 ),
1502 (DataType::Decimal128(precision, scale), DataType::Int8) => {
1503 cast_decimal_to_int16_down!(
1504 array, eval_mode, Int8Array, i8, "TINYINT", precision, *scale
1505 )
1506 }
1507 (DataType::Decimal128(precision, scale), DataType::Int16) => {
1508 cast_decimal_to_int16_down!(
1509 array, eval_mode, Int16Array, i16, "SMALLINT", precision, *scale
1510 )
1511 }
1512 (DataType::Decimal128(precision, scale), DataType::Int32) => {
1513 cast_decimal_to_int32_up!(
1514 array,
1515 eval_mode,
1516 Int32Array,
1517 i32,
1518 "INT",
1519 i32::MAX,
1520 *precision,
1521 *scale
1522 )
1523 }
1524 (DataType::Decimal128(precision, scale), DataType::Int64) => {
1525 cast_decimal_to_int32_up!(
1526 array,
1527 eval_mode,
1528 Int64Array,
1529 i64,
1530 "BIGINT",
1531 i64::MAX,
1532 *precision,
1533 *scale
1534 )
1535 }
1536 _ => unreachable!(
1537 "{}",
1538 format!("invalid cast from non-integral numeric type: {from_type} to integral numeric type: {to_type}")
1539 ),
1540 }
1541}
1542
1543fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i8>> {
1545 Ok(cast_string_to_int_with_range_check(
1546 str,
1547 eval_mode,
1548 "TINYINT",
1549 i8::MIN as i32,
1550 i8::MAX as i32,
1551 )?
1552 .map(|v| v as i8))
1553}
1554
1555fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i16>> {
1557 Ok(cast_string_to_int_with_range_check(
1558 str,
1559 eval_mode,
1560 "SMALLINT",
1561 i16::MIN as i32,
1562 i16::MAX as i32,
1563 )?
1564 .map(|v| v as i16))
1565}
1566
/// Casts a string to `i32` (Spark INT); `i32::MIN` lets the shared helper
/// accumulate negatively without overflow (see `do_cast_string_to_int`).
fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>> {
    do_cast_string_to_int::<i32>(str, eval_mode, "INT", i32::MIN)
}
1571
/// Casts a string to `i64` (Spark BIGINT); `i64::MIN` lets the shared helper
/// accumulate negatively without overflow (see `do_cast_string_to_int`).
fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i64>> {
    do_cast_string_to_int::<i64>(str, eval_mode, "BIGINT", i64::MIN)
}
1576
1577fn cast_string_to_int_with_range_check(
1578 str: &str,
1579 eval_mode: EvalMode,
1580 type_name: &str,
1581 min: i32,
1582 max: i32,
1583) -> SparkResult<Option<i32>> {
1584 match do_cast_string_to_int(str, eval_mode, type_name, i32::MIN)? {
1585 None => Ok(None),
1586 Some(v) if v >= min && v <= max => Ok(Some(v)),
1587 _ if eval_mode == EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)),
1588 _ => Ok(None),
1589 }
1590}
1591
/// Parses a string into a signed integer of type `T`, mirroring Spark's
/// `UTF8String.toInt`/`toLong`: whitespace is trimmed, an optional leading
/// sign is accepted, and in Legacy mode a fractional part after `.` is
/// discarded (its characters must still be digits).
///
/// The magnitude is accumulated as a NEGATIVE value because |T::MIN| >
/// |T::MAX|, so every representable magnitude fits without overflow; the
/// result is negated at the end for non-negative inputs.
fn do_cast_string_to_int<
    T: Num + PartialOrd + Integer + CheckedSub + CheckedNeg + From<i32> + Copy,
>(
    str: &str,
    eval_mode: EvalMode,
    type_name: &str,
    min_value: T,
) -> SparkResult<Option<T>> {
    let trimmed_str = str.trim();
    if trimmed_str.is_empty() {
        return none_or_err(eval_mode, type_name, str);
    }
    let len = trimmed_str.len();
    let mut result: T = T::zero();
    let mut negative = false;
    let radix = T::from(10);
    // Any accumulator below this would overflow when multiplied by the radix.
    let stop_value = min_value / radix;
    let mut parse_sign_and_digits = true;

    for (i, ch) in trimmed_str.char_indices() {
        if parse_sign_and_digits {
            if i == 0 {
                negative = ch == '-';
                let positive = ch == '+';
                if negative || positive {
                    // A bare sign with no digits is invalid.
                    if i + 1 == len {
                        return none_or_err(eval_mode, type_name, str);
                    }
                    continue;
                }
            }

            if ch == '.' {
                if eval_mode == EvalMode::Legacy {
                    // Legacy mode truncates the fractional part; the remaining
                    // characters are only validated, not accumulated.
                    parse_sign_and_digits = false;
                    continue;
                } else {
                    return none_or_err(eval_mode, type_name, str);
                }
            }

            let digit = if ch.is_ascii_digit() {
                (ch as u32) - ('0' as u32)
            } else {
                return none_or_err(eval_mode, type_name, str);
            };

            // If the accumulator is already past stop_value, multiplying by
            // the radix would overflow — report out-of-range.
            if result < stop_value {
                return none_or_err(eval_mode, type_name, str);
            }

            let v = result * radix;
            let digit = (digit as i32).into();
            // Subtracting keeps the accumulator negative; a failed
            // checked_sub (or a positive result) signals overflow.
            match v.checked_sub(&digit) {
                Some(x) if x <= T::zero() => result = x,
                _ => {
                    return none_or_err(eval_mode, type_name, str);
                }
            }
        } else {
            // After the '.', only digits are allowed (and ignored).
            if !ch.is_ascii_digit() {
                return none_or_err(eval_mode, type_name, str);
            }
        }
    }

    if !negative {
        // Flip the accumulated negative magnitude; T::MIN cannot be negated,
        // so checked_neg catches that overflow case.
        if let Some(neg) = result.checked_neg() {
            if neg < T::zero() {
                return none_or_err(eval_mode, type_name, str);
            }
            result = neg;
        } else {
            return none_or_err(eval_mode, type_name, str);
        }
    }

    Ok(Some(result))
}
1685
1686#[inline]
1688fn none_or_err<T>(eval_mode: EvalMode, type_name: &str, str: &str) -> SparkResult<Option<T>> {
1689 match eval_mode {
1690 EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)),
1691 _ => Ok(None),
1692 }
1693}
1694
/// Builds the Spark-compatible "cannot be cast" error for an invalid value.
#[inline]
fn invalid_value(value: &str, from_type: &str, to_type: &str) -> SparkError {
    SparkError::CastInvalidValue {
        value: value.to_string(),
        from_type: from_type.to_string(),
        to_type: to_type.to_string(),
    }
}
1703
/// Builds the Spark-compatible overflow error for a value that does not fit
/// in the target type.
#[inline]
fn cast_overflow(value: &str, from_type: &str, to_type: &str) -> SparkError {
    SparkError::CastOverFlow {
        value: value.to_string(),
        from_type: from_type.to_string(),
        to_type: to_type.to_string(),
    }
}
1712
impl Display for Cast {
    // Human-readable form used in plan/debug output.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Cast [data_type: {}, timezone: {}, child: {}, eval_mode: {:?}]",
            self.data_type, self.cast_options.timezone, self.child, &self.cast_options.eval_mode
        )
    }
}
1722
impl PhysicalExpr for Cast {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
        // SQL rendering is not required for this expression.
        unimplemented!()
    }

    // The output type is fixed at construction time, independent of the schema.
    fn data_type(&self, _: &Schema) -> DataFusionResult<DataType> {
        Ok(self.data_type.clone())
    }

    // Conservatively nullable: casts can introduce nulls (e.g. failed parses
    // in Legacy/Try mode) even for non-null input.
    fn nullable(&self, _: &Schema) -> DataFusionResult<bool> {
        Ok(true)
    }

    // Evaluate the child expression, then apply the Spark-semantics cast.
    fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
        let arg = self.child.evaluate(batch)?;
        spark_cast(arg, &self.data_type, &self.cast_options)
    }

    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        vec![&self.child]
    }

    // Rebuild with a replacement child; Cast always has exactly one.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
        match children.len() {
            1 => Ok(Arc::new(Cast::new(
                Arc::clone(&children[0]),
                self.data_type.clone(),
                self.cast_options.clone(),
            ))),
            _ => internal_err!("Cast should have exactly one child"),
        }
    }
}
1763
1764fn timestamp_parser<T: TimeZone>(
1765 value: &str,
1766 eval_mode: EvalMode,
1767 tz: &T,
1768) -> SparkResult<Option<i64>> {
1769 let value = value.trim();
1770 if value.is_empty() {
1771 return Ok(None);
1772 }
1773 let patterns = &[
1775 (
1776 Regex::new(r"^\d{4,5}$").unwrap(),
1777 parse_str_to_year_timestamp as fn(&str, &T) -> SparkResult<Option<i64>>,
1778 ),
1779 (
1780 Regex::new(r"^\d{4,5}-\d{2}$").unwrap(),
1781 parse_str_to_month_timestamp,
1782 ),
1783 (
1784 Regex::new(r"^\d{4,5}-\d{2}-\d{2}$").unwrap(),
1785 parse_str_to_day_timestamp,
1786 ),
1787 (
1788 Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{1,2}$").unwrap(),
1789 parse_str_to_hour_timestamp,
1790 ),
1791 (
1792 Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}$").unwrap(),
1793 parse_str_to_minute_timestamp,
1794 ),
1795 (
1796 Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$").unwrap(),
1797 parse_str_to_second_timestamp,
1798 ),
1799 (
1800 Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,6}$").unwrap(),
1801 parse_str_to_microsecond_timestamp,
1802 ),
1803 (
1804 Regex::new(r"^T\d{1,2}$").unwrap(),
1805 parse_str_to_time_only_timestamp,
1806 ),
1807 ];
1808
1809 let mut timestamp = None;
1810
1811 for (pattern, parse_func) in patterns {
1813 if pattern.is_match(value) {
1814 timestamp = parse_func(value, tz)?;
1815 break;
1816 }
1817 }
1818
1819 if timestamp.is_none() {
1820 return if eval_mode == EvalMode::Ansi {
1821 Err(SparkError::CastInvalidValue {
1822 value: value.to_string(),
1823 from_type: "STRING".to_string(),
1824 to_type: "TIMESTAMP".to_string(),
1825 })
1826 } else {
1827 Ok(None)
1828 };
1829 }
1830
1831 match timestamp {
1832 Some(ts) => Ok(Some(ts)),
1833 None => Err(SparkError::Internal(
1834 "Failed to parse timestamp".to_string(),
1835 )),
1836 }
1837}
1838
1839fn parse_timestamp_to_micros<T: TimeZone>(
1840 timestamp_info: &TimeStampInfo,
1841 tz: &T,
1842) -> SparkResult<Option<i64>> {
1843 let datetime = tz.with_ymd_and_hms(
1844 timestamp_info.year,
1845 timestamp_info.month,
1846 timestamp_info.day,
1847 timestamp_info.hour,
1848 timestamp_info.minute,
1849 timestamp_info.second,
1850 );
1851
1852 let tz_datetime = match datetime.single() {
1854 Some(dt) => dt
1855 .with_timezone(tz)
1856 .with_nanosecond(timestamp_info.microsecond * 1000),
1857 None => {
1858 return Err(SparkError::Internal(
1859 "Failed to parse timestamp".to_string(),
1860 ));
1861 }
1862 };
1863
1864 let result = match tz_datetime {
1865 Some(dt) => dt.timestamp_micros(),
1866 None => {
1867 return Err(SparkError::Internal(
1868 "Failed to parse timestamp".to_string(),
1869 ));
1870 }
1871 };
1872
1873 Ok(Some(result))
1874}
1875
/// Splits a timestamp string (already pattern-validated by `timestamp_parser`)
/// on `T`, `-`, `:` and `.`, then builds a `TimeStampInfo` populated down to
/// the granularity named by `timestamp_type` ("year" .. "microsecond").
/// Missing or unparsable components fall back to defaults: 1 for month/day,
/// 0 for the time fields.
fn get_timestamp_values<T: TimeZone>(
    value: &str,
    timestamp_type: &str,
    tz: &T,
) -> SparkResult<Option<i64>> {
    let values: Vec<_> = value.split(['T', '-', ':', '.']).collect();
    let year = values[0].parse::<i32>().unwrap_or_default();
    let month = values.get(1).map_or(1, |m| m.parse::<u32>().unwrap_or(1));
    let day = values.get(2).map_or(1, |d| d.parse::<u32>().unwrap_or(1));
    let hour = values.get(3).map_or(0, |h| h.parse::<u32>().unwrap_or(0));
    let minute = values.get(4).map_or(0, |m| m.parse::<u32>().unwrap_or(0));
    let second = values.get(5).map_or(0, |s| s.parse::<u32>().unwrap_or(0));
    let microsecond = values.get(6).map_or(0, |ms| ms.parse::<u32>().unwrap_or(0));

    let mut timestamp_info = TimeStampInfo::default();

    // Each arm fills in exactly the fields present at the given granularity;
    // shadowing rebinds to the builder's return value for the final call.
    let timestamp_info = match timestamp_type {
        "year" => timestamp_info.with_year(year),
        "month" => timestamp_info.with_year(year).with_month(month),
        "day" => timestamp_info
            .with_year(year)
            .with_month(month)
            .with_day(day),
        "hour" => timestamp_info
            .with_year(year)
            .with_month(month)
            .with_day(day)
            .with_hour(hour),
        "minute" => timestamp_info
            .with_year(year)
            .with_month(month)
            .with_day(day)
            .with_hour(hour)
            .with_minute(minute),
        "second" => timestamp_info
            .with_year(year)
            .with_month(month)
            .with_day(day)
            .with_hour(hour)
            .with_minute(minute)
            .with_second(second),
        "microsecond" => timestamp_info
            .with_year(year)
            .with_month(month)
            .with_day(day)
            .with_hour(hour)
            .with_minute(minute)
            .with_second(second)
            .with_microsecond(microsecond),
        _ => {
            return Err(SparkError::CastInvalidValue {
                value: value.to_string(),
                from_type: "STRING".to_string(),
                to_type: "TIMESTAMP".to_string(),
            })
        }
    };

    parse_timestamp_to_micros(timestamp_info, tz)
}
1936
// Each wrapper below handles one of the patterns recognized by
// `timestamp_parser`, delegating to `get_timestamp_values` with the matching
// granularity; omitted finer-grained fields take their defaults there.
fn parse_str_to_year_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
    get_timestamp_values(value, "year", tz)
}

fn parse_str_to_month_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
    get_timestamp_values(value, "month", tz)
}

fn parse_str_to_day_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
    get_timestamp_values(value, "day", tz)
}

fn parse_str_to_hour_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
    get_timestamp_values(value, "hour", tz)
}

fn parse_str_to_minute_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
    get_timestamp_values(value, "minute", tz)
}

fn parse_str_to_second_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
    get_timestamp_values(value, "second", tz)
}

fn parse_str_to_microsecond_timestamp<T: TimeZone>(
    value: &str,
    tz: &T,
) -> SparkResult<Option<i64>> {
    get_timestamp_values(value, "microsecond", tz)
}
1967
/// Parses a time-only literal such as `T12`, combining it with the CURRENT
/// date in `tz` — the result therefore depends on the wall clock at call time.
fn parse_str_to_time_only_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
    // values[1] is safe: the caller only routes strings matching `^T\d{1,2}$`,
    // which always have content after the 'T'.
    let values: Vec<&str> = value.split('T').collect();
    let time_values: Vec<u32> = values[1]
        .split(':')
        .map(|v| v.parse::<u32>().unwrap_or(0))
        .collect();

    // Today's date in the target timezone; missing components default to 0.
    let datetime = tz.from_utc_datetime(&chrono::Utc::now().naive_utc());
    let timestamp = datetime
        .with_timezone(tz)
        .with_hour(time_values.first().copied().unwrap_or_default())
        .and_then(|dt| dt.with_minute(*time_values.get(1).unwrap_or(&0)))
        .and_then(|dt| dt.with_second(*time_values.get(2).unwrap_or(&0)))
        .and_then(|dt| dt.with_nanosecond(*time_values.get(3).unwrap_or(&0) * 1_000))
        .map(|dt| dt.timestamp_micros())
        .unwrap_or_default();

    Ok(Some(timestamp))
}
1987
/// Parses a date string into days since the Unix epoch, mimicking Spark's
/// `stringToDate`: optional sign, 4-7 digit year, optional 1-2 digit month
/// and day, an optional ` `/`T` suffix after a full date, with surrounding
/// whitespace/control characters trimmed. Returns `Ok(None)` for invalid
/// input (or `Err` in ANSI mode).
fn date_parser(date_str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>> {
    // Index of the first byte that is not whitespace/control.
    fn get_trimmed_start(bytes: &[u8]) -> usize {
        let mut start = 0;
        while start < bytes.len() && is_whitespace_or_iso_control(bytes[start]) {
            start += 1;
        }
        start
    }

    // One past the index of the last byte that is not whitespace/control.
    fn get_trimmed_end(start: usize, bytes: &[u8]) -> usize {
        let mut end = bytes.len() - 1;
        while end > start && is_whitespace_or_iso_control(bytes[end]) {
            end -= 1;
        }
        end + 1
    }

    fn is_whitespace_or_iso_control(byte: u8) -> bool {
        byte.is_ascii_whitespace() || byte.is_ascii_control()
    }

    // Segment 0 (year) allows 4..=7 digits; month/day allow 1..=2.
    fn is_valid_digits(segment: i32, digits: usize) -> bool {
        let max_digits_year = 7;
        (segment == 0 && digits >= 4 && digits <= max_digits_year)
            || (segment != 0 && digits > 0 && digits <= 2)
    }

    // Shared failure path: error in ANSI mode, null otherwise.
    fn return_result(date_str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>> {
        if eval_mode == EvalMode::Ansi {
            Err(SparkError::CastInvalidValue {
                value: date_str.to_string(),
                from_type: "STRING".to_string(),
                to_type: "DATE".to_string(),
            })
        } else {
            Ok(None)
        }
    }
    // Guard before get_trimmed_end, which would underflow on empty input.
    if date_str.is_empty() {
        return return_result(date_str, eval_mode);
    }

    // Missing segments default to 1 (e.g. "2020" parses as 2020-01-01).
    let mut date_segments = [1, 1, 1];
    let mut sign = 1;
    let mut current_segment = 0;
    let mut current_segment_value = Wrapping(0);
    let mut current_segment_digits = 0;
    let bytes = date_str.as_bytes();

    let mut j = get_trimmed_start(bytes);
    let str_end_trimmed = get_trimmed_end(j, bytes);

    if j == str_end_trimmed {
        return return_result(date_str, eval_mode);
    }

    // Optional leading sign applies to the year.
    if bytes[j] == b'-' || bytes[j] == b'+' {
        sign = if bytes[j] == b'-' { -1 } else { 1 };
        j += 1;
    }

    // Scan year-month-day; a space or 'T' terminates the date part.
    while j < str_end_trimmed && (current_segment < 3 && !(bytes[j] == b' ' || bytes[j] == b'T')) {
        let b = bytes[j];
        if current_segment < 2 && b == b'-' {
            if !is_valid_digits(current_segment, current_segment_digits) {
                return return_result(date_str, eval_mode);
            }
            date_segments[current_segment as usize] = current_segment_value.0;
            current_segment_value = Wrapping(0);
            current_segment_digits = 0;
            current_segment += 1;
        } else if !b.is_ascii_digit() {
            return return_result(date_str, eval_mode);
        } else {
            let parsed_value = Wrapping((b - b'0') as i32);
            // Wrapping arithmetic is intentional: overlong digit runs are
            // rejected by the digit-count check, not by overflow.
            current_segment_value = current_segment_value * Wrapping(10) + parsed_value;
            current_segment_digits += 1;
        }
        j += 1;
    }

    if !is_valid_digits(current_segment, current_segment_digits) {
        return return_result(date_str, eval_mode);
    }

    // A ' '/'T' suffix is only allowed after a full y-m-d date; stopping
    // early with only year or year-month plus trailing content is malformed.
    if current_segment < 2 && j < str_end_trimmed {
        return return_result(date_str, eval_mode);
    }

    date_segments[current_segment as usize] = current_segment_value.0;

    // Out-of-range dates (e.g. month 13) yield None from from_ymd_opt and
    // are treated as null rather than an error, even in ANSI mode.
    match NaiveDate::from_ymd_opt(
        sign * date_segments[0],
        date_segments[1] as u32,
        date_segments[2] as u32,
    ) {
        Some(date) => {
            let duration_since_epoch = date
                .signed_duration_since(DateTime::UNIX_EPOCH.naive_utc().date())
                .num_days();
            Ok(Some(duration_since_epoch.to_i32().unwrap()))
        }
        None => Ok(None),
    }
}
2109
2110fn spark_cast_postprocess(array: ArrayRef, from_type: &DataType, to_type: &DataType) -> ArrayRef {
2116 match (from_type, to_type) {
2117 (DataType::Timestamp(_, _), DataType::Int64) => {
2118 unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, MICROS_PER_SECOND)).unwrap()
2120 }
2121 (DataType::Dictionary(_, value_type), DataType::Int64)
2122 if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) =>
2123 {
2124 unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, MICROS_PER_SECOND)).unwrap()
2126 }
2127 (DataType::Timestamp(_, _), DataType::Utf8) => remove_trailing_zeroes(array),
2128 (DataType::Dictionary(_, value_type), DataType::Utf8)
2129 if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) =>
2130 {
2131 remove_trailing_zeroes(array)
2132 }
2133 _ => array,
2134 }
2135}
2136
/// Applies an infallible unary op elementwise to a primitive array, recursing
/// through dictionary encoding (the op runs on the dictionary values; the
/// keys are reused unchanged).
fn unary_dyn<F, T>(array: &ArrayRef, op: F) -> Result<ArrayRef, ArrowError>
where
    T: ArrowPrimitiveType,
    F: Fn(T::Native) -> T::Native,
{
    if let Some(d) = array.as_any_dictionary_opt() {
        let new_values = unary_dyn::<F, T>(d.values(), op)?;
        return Ok(Arc::new(d.with_values(Arc::new(new_values))));
    }

    match array.as_primitive_opt::<T>() {
        Some(a) if PrimitiveArray::<T>::is_compatible(a.data_type()) => {
            Ok(Arc::new(unary::<T, F, T>(
                array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
                op,
            )))
        }
        _ => Err(ArrowError::NotYetImplemented(format!(
            "Cannot perform unary operation of type {} on array of type {}",
            T::DATA_TYPE,
            array.data_type()
        ))),
    }
}
2162
2163fn remove_trailing_zeroes(array: ArrayRef) -> ArrayRef {
2172 let string_array = as_generic_string_array::<i32>(&array).unwrap();
2173 let result = string_array
2174 .iter()
2175 .map(|s| s.map(trim_end))
2176 .collect::<GenericStringArray<i32>>();
2177 Arc::new(result) as ArrayRef
2178}
2179
/// Strips trailing `'0'` characters from a formatted timestamp string, but
/// only when it contains a `'.'` (i.e. has a fractional-seconds part);
/// strings without a decimal point are returned unchanged.
fn trim_end(s: &str) -> &str {
    // `contains` expresses the presence test directly; the previous
    // `rfind('.').is_some()` computed a position only to discard it.
    if s.contains('.') {
        s.trim_end_matches('0')
    } else {
        s
    }
}
2187
2188#[cfg(test)]
2189mod tests {
2190 use arrow::array::StringArray;
2191 use arrow::datatypes::TimestampMicrosecondType;
2192 use arrow::datatypes::{Field, Fields, TimeUnit};
2193 use core::f64;
2194 use std::str::FromStr;
2195
2196 use super::*;
2197
    // Exercises `timestamp_parser` across every supported pattern for a
    // modern year (2020), a historical year (0100), and a 5-digit year
    // (10000), all in UTC. Expected values are microseconds since the epoch.
    #[test]
    #[cfg_attr(miri, ignore)] fn timestamp_parser_test() {
        let tz = &timezone::Tz::from_str("UTC").unwrap();
        assert_eq!(
            timestamp_parser("2020", EvalMode::Legacy, tz).unwrap(),
            Some(1577836800000000) );
        assert_eq!(
            timestamp_parser("2020-01", EvalMode::Legacy, tz).unwrap(),
            Some(1577836800000000)
        );
        assert_eq!(
            timestamp_parser("2020-01-01", EvalMode::Legacy, tz).unwrap(),
            Some(1577836800000000)
        );
        assert_eq!(
            timestamp_parser("2020-01-01T12", EvalMode::Legacy, tz).unwrap(),
            Some(1577880000000000)
        );
        assert_eq!(
            timestamp_parser("2020-01-01T12:34", EvalMode::Legacy, tz).unwrap(),
            Some(1577882040000000)
        );
        assert_eq!(
            timestamp_parser("2020-01-01T12:34:56", EvalMode::Legacy, tz).unwrap(),
            Some(1577882096000000)
        );
        assert_eq!(
            timestamp_parser("2020-01-01T12:34:56.123456", EvalMode::Legacy, tz).unwrap(),
            Some(1577882096123456)
        );
        assert_eq!(
            timestamp_parser("0100", EvalMode::Legacy, tz).unwrap(),
            Some(-59011459200000000)
        );
        assert_eq!(
            timestamp_parser("0100-01", EvalMode::Legacy, tz).unwrap(),
            Some(-59011459200000000)
        );
        assert_eq!(
            timestamp_parser("0100-01-01", EvalMode::Legacy, tz).unwrap(),
            Some(-59011459200000000)
        );
        assert_eq!(
            timestamp_parser("0100-01-01T12", EvalMode::Legacy, tz).unwrap(),
            Some(-59011416000000000)
        );
        assert_eq!(
            timestamp_parser("0100-01-01T12:34", EvalMode::Legacy, tz).unwrap(),
            Some(-59011413960000000)
        );
        assert_eq!(
            timestamp_parser("0100-01-01T12:34:56", EvalMode::Legacy, tz).unwrap(),
            Some(-59011413904000000)
        );
        assert_eq!(
            timestamp_parser("0100-01-01T12:34:56.123456", EvalMode::Legacy, tz).unwrap(),
            Some(-59011413903876544)
        );
        assert_eq!(
            timestamp_parser("10000", EvalMode::Legacy, tz).unwrap(),
            Some(253402300800000000)
        );
        assert_eq!(
            timestamp_parser("10000-01", EvalMode::Legacy, tz).unwrap(),
            Some(253402300800000000)
        );
        assert_eq!(
            timestamp_parser("10000-01-01", EvalMode::Legacy, tz).unwrap(),
            Some(253402300800000000)
        );
        assert_eq!(
            timestamp_parser("10000-01-01T12", EvalMode::Legacy, tz).unwrap(),
            Some(253402344000000000)
        );
        assert_eq!(
            timestamp_parser("10000-01-01T12:34", EvalMode::Legacy, tz).unwrap(),
            Some(253402346040000000)
        );
        assert_eq!(
            timestamp_parser("10000-01-01T12:34:56", EvalMode::Legacy, tz).unwrap(),
            Some(253402346096000000)
        );
        assert_eq!(
            timestamp_parser("10000-01-01T12:34:56.123456", EvalMode::Legacy, tz).unwrap(),
            Some(253402346096123456)
        );
    }
2292
    // Checks that `cast_utf8_to_timestamp!` produces a UTC microsecond
    // timestamp array of the expected length (values include a time-only
    // literal, so only type and length are asserted).
    #[test]
    #[cfg_attr(miri, ignore)] fn test_cast_string_to_timestamp() {
        let array: ArrayRef = Arc::new(StringArray::from(vec![
            Some("2020-01-01T12:34:56.123456"),
            Some("T2"),
            Some("0100-01-01T12:34:56.123456"),
            Some("10000-01-01T12:34:56.123456"),
        ]));
        let tz = &timezone::Tz::from_str("UTC").unwrap();

        let string_array = array
            .as_any()
            .downcast_ref::<GenericStringArray<i32>>()
            .expect("Expected a string array");

        let eval_mode = EvalMode::Legacy;
        let result = cast_utf8_to_timestamp!(
            &string_array,
            eval_mode,
            TimestampMicrosecondType,
            timestamp_parser,
            tz
        );

        assert_eq!(
            result.data_type(),
            &DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into()))
        );
        assert_eq!(result.len(), 4);
    }

    // Dictionary-encoded strings must cast to timestamps just like plain
    // string arrays.
    #[test]
    fn test_cast_dict_string_to_timestamp() -> DataFusionResult<()> {
        let keys = Int32Array::from(vec![0, 1]);
        let values: ArrayRef = Arc::new(StringArray::from(vec![
            Some("2020-01-01T12:34:56.123456"),
            Some("T2"),
        ]));
        let dict_array = Arc::new(DictionaryArray::new(keys, values));

        let timezone = "UTC".to_string();
        let cast_options = SparkCastOptions::new(EvalMode::Legacy, &timezone, false);
        let result = cast_array(
            dict_array,
            &DataType::Timestamp(TimeUnit::Microsecond, Some(timezone.clone().into())),
            &cast_options,
        )?;
        assert_eq!(
            *result.data_type(),
            DataType::Timestamp(TimeUnit::Microsecond, Some(timezone.into()))
        );
        assert_eq!(result.len(), 2);

        Ok(())
    }
2351
    // Covers `date_parser`: valid spellings of 2020-01-01, malformed inputs
    // (null in Legacy/Try, error in ANSI), a partial negative date, and
    // out-of-chrono-range dates that yield null in every mode.
    #[test]
    fn date_parser_test() {
        for date in &[
            "2020",
            "2020-01",
            "2020-01-01",
            "02020-01-01",
            "002020-01-01",
            "0002020-01-01",
            "2020-1-1",
            "2020-01-01 ",
            "2020-01-01T",
        ] {
            for eval_mode in &[EvalMode::Legacy, EvalMode::Ansi, EvalMode::Try] {
                assert_eq!(date_parser(date, *eval_mode).unwrap(), Some(18262));
            }
        }

        for date in &[
            "abc",
            "",
            "not_a_date",
            "3/",
            "3/12",
            "3/12/2020",
            "3/12/2002 T",
            "202",
            "2020-010-01",
            "2020-10-010",
            "2020-10-010T",
            "--262143-12-31",
            "--262143-12-31 ",
        ] {
            for eval_mode in &[EvalMode::Legacy, EvalMode::Try] {
                assert_eq!(date_parser(date, *eval_mode).unwrap(), None);
            }
            assert!(date_parser(date, EvalMode::Ansi).is_err());
        }

        for date in &["-3638-5"] {
            for eval_mode in &[EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] {
                assert_eq!(date_parser(date, *eval_mode).unwrap(), Some(-2048160));
            }
        }

        for date in &[
            "-262144-1-1",
            "262143-01-1",
            "262143-1-1",
            "262143-01-1 ",
            "262143-01-01T ",
            "262143-1-01T 1234",
            "-0973250",
        ] {
            for eval_mode in &[EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] {
                assert_eq!(date_parser(date, *eval_mode).unwrap(), None);
            }
        }
    }
2414
    // All supported spellings of 2020-01-01 should cast to day 18262.
    #[test]
    fn test_cast_string_to_date() {
        let array: ArrayRef = Arc::new(StringArray::from(vec![
            Some("2020"),
            Some("2020-01"),
            Some("2020-01-01"),
            Some("2020-01-01T"),
        ]));

        let result = cast_string_to_date(&array, &DataType::Date32, EvalMode::Legacy).unwrap();

        let date32_array = result
            .as_any()
            .downcast_ref::<arrow::array::Date32Array>()
            .unwrap();
        assert_eq!(date32_array.len(), 4);
        date32_array
            .iter()
            .for_each(|v| assert_eq!(v.unwrap(), 18262));
    }

    // Leading/trailing whitespace and control characters, plus junk after a
    // 'T' suffix, must all be tolerated; every entry is -262143-12-31.
    #[test]
    fn test_cast_string_array_with_valid_dates() {
        let array_with_invalid_date: ArrayRef = Arc::new(StringArray::from(vec![
            Some("-262143-12-31"),
            Some("\n -262143-12-31 "),
            Some("-262143-12-31T \t\n"),
            Some("\n\t-262143-12-31T\r"),
            Some("-262143-12-31T 123123123"),
            Some("\r\n-262143-12-31T \r123123123"),
            Some("\n -262143-12-31T \n\t"),
        ]));

        for eval_mode in &[EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] {
            let result =
                cast_string_to_date(&array_with_invalid_date, &DataType::Date32, *eval_mode)
                    .unwrap();

            let date32_array = result
                .as_any()
                .downcast_ref::<arrow::array::Date32Array>()
                .unwrap();
            assert_eq!(result.len(), 7);
            date32_array
                .iter()
                .for_each(|v| assert_eq!(v.unwrap(), -96464928));
        }
    }

    // Mixed valid/invalid dates: Legacy/Try produce nulls for the invalid
    // ones; ANSI raises CAST_INVALID_INPUT on the first invalid entry.
    #[test]
    fn test_cast_string_array_with_invalid_dates() {
        let array_with_invalid_date: ArrayRef = Arc::new(StringArray::from(vec![
            Some("2020"),
            Some("2020-01"),
            Some("2020-01-01"),
            Some("2020-010-01T"),
            Some("202"),
            Some(" 202 "),
            Some("\n 2020-\r8 "),
            Some("2020-01-01T"),
            Some("-4607172990231812908"),
        ]));

        for eval_mode in &[EvalMode::Legacy, EvalMode::Try] {
            let result =
                cast_string_to_date(&array_with_invalid_date, &DataType::Date32, *eval_mode)
                    .unwrap();

            let date32_array = result
                .as_any()
                .downcast_ref::<arrow::array::Date32Array>()
                .unwrap();
            assert_eq!(
                date32_array.iter().collect::<Vec<_>>(),
                vec![
                    Some(18262),
                    Some(18262),
                    Some(18262),
                    None,
                    None,
                    None,
                    None,
                    Some(18262),
                    None
                ]
            );
        }

        let result =
            cast_string_to_date(&array_with_invalid_date, &DataType::Date32, EvalMode::Ansi);
        match result {
            Err(e) => assert!(
                e.to_string().contains(
                    "[CAST_INVALID_INPUT] The value '2020-010-01T' of the type \"STRING\" cannot be cast to \"DATE\" because it is malformed")
            ),
            _ => panic!("Expected error"),
        }
    }
2515
2516 #[test]
2517 fn test_cast_string_as_i8() {
2518 assert_eq!(
2520 cast_string_to_i8("127", EvalMode::Legacy).unwrap(),
2521 Some(127_i8)
2522 );
2523 assert_eq!(cast_string_to_i8("128", EvalMode::Legacy).unwrap(), None);
2524 assert!(cast_string_to_i8("128", EvalMode::Ansi).is_err());
2525 assert_eq!(
2527 cast_string_to_i8("0.2", EvalMode::Legacy).unwrap(),
2528 Some(0_i8)
2529 );
2530 assert_eq!(
2531 cast_string_to_i8(".", EvalMode::Legacy).unwrap(),
2532 Some(0_i8)
2533 );
2534 assert_eq!(cast_string_to_i8("0.2", EvalMode::Try).unwrap(), None);
2536 assert_eq!(cast_string_to_i8(".", EvalMode::Try).unwrap(), None);
2537 assert!(cast_string_to_i8("0.2", EvalMode::Ansi).is_err());
2539 assert!(cast_string_to_i8(".", EvalMode::Ansi).is_err());
2540 }
2541
2542 #[test]
2543 fn test_cast_unsupported_timestamp_to_date() {
2544 let timestamps: PrimitiveArray<TimestampMicrosecondType> = vec![i64::MAX].into();
2546 let cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
2547 let result = cast_array(
2548 Arc::new(timestamps.with_timezone("Europe/Copenhagen")),
2549 &DataType::Date32,
2550 &cast_options,
2551 );
2552 assert!(result.is_err())
2553 }
2554
2555 #[test]
2556 fn test_cast_invalid_timezone() {
2557 let timestamps: PrimitiveArray<TimestampMicrosecondType> = vec![i64::MAX].into();
2558 let cast_options = SparkCastOptions::new(EvalMode::Legacy, "Not a valid timezone", false);
2559 let result = cast_array(
2560 Arc::new(timestamps.with_timezone("Europe/Copenhagen")),
2561 &DataType::Date32,
2562 &cast_options,
2563 );
2564 assert!(result.is_err())
2565 }
2566
2567 #[test]
2568 fn test_cast_struct_to_utf8() {
2569 let a: ArrayRef = Arc::new(Int32Array::from(vec![
2570 Some(1),
2571 Some(2),
2572 None,
2573 Some(4),
2574 Some(5),
2575 ]));
2576 let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
2577 let c: ArrayRef = Arc::new(StructArray::from(vec![
2578 (Arc::new(Field::new("a", DataType::Int32, true)), a),
2579 (Arc::new(Field::new("b", DataType::Utf8, true)), b),
2580 ]));
2581 let string_array = cast_array(
2582 c,
2583 &DataType::Utf8,
2584 &SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
2585 )
2586 .unwrap();
2587 let string_array = string_array.as_string::<i32>();
2588 assert_eq!(5, string_array.len());
2589 assert_eq!(r#"{1, a}"#, string_array.value(0));
2590 assert_eq!(r#"{2, b}"#, string_array.value(1));
2591 assert_eq!(r#"{null, c}"#, string_array.value(2));
2592 assert_eq!(r#"{4, d}"#, string_array.value(3));
2593 assert_eq!(r#"{5, e}"#, string_array.value(4));
2594 }
2595
2596 #[test]
2597 fn test_cast_struct_to_struct() {
2598 let a: ArrayRef = Arc::new(Int32Array::from(vec![
2599 Some(1),
2600 Some(2),
2601 None,
2602 Some(4),
2603 Some(5),
2604 ]));
2605 let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
2606 let c: ArrayRef = Arc::new(StructArray::from(vec![
2607 (Arc::new(Field::new("a", DataType::Int32, true)), a),
2608 (Arc::new(Field::new("b", DataType::Utf8, true)), b),
2609 ]));
2610 let fields = Fields::from(vec![
2612 Field::new("a", DataType::Utf8, true),
2613 Field::new("b", DataType::Utf8, true),
2614 ]);
2615 let cast_array = spark_cast(
2616 ColumnarValue::Array(c),
2617 &DataType::Struct(fields),
2618 &SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
2619 )
2620 .unwrap();
2621 if let ColumnarValue::Array(cast_array) = cast_array {
2622 assert_eq!(5, cast_array.len());
2623 let a = cast_array.as_struct().column(0).as_string::<i32>();
2624 assert_eq!("1", a.value(0));
2625 } else {
2626 unreachable!()
2627 }
2628 }
2629
2630 #[test]
2631 fn test_cast_struct_to_struct_drop_column() {
2632 let a: ArrayRef = Arc::new(Int32Array::from(vec![
2633 Some(1),
2634 Some(2),
2635 None,
2636 Some(4),
2637 Some(5),
2638 ]));
2639 let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
2640 let c: ArrayRef = Arc::new(StructArray::from(vec![
2641 (Arc::new(Field::new("a", DataType::Int32, true)), a),
2642 (Arc::new(Field::new("b", DataType::Utf8, true)), b),
2643 ]));
2644 let fields = Fields::from(vec![Field::new("a", DataType::Utf8, true)]);
2646 let cast_array = spark_cast(
2647 ColumnarValue::Array(c),
2648 &DataType::Struct(fields),
2649 &SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
2650 )
2651 .unwrap();
2652 if let ColumnarValue::Array(cast_array) = cast_array {
2653 assert_eq!(5, cast_array.len());
2654 let struct_array = cast_array.as_struct();
2655 assert_eq!(1, struct_array.columns().len());
2656 let a = struct_array.column(0).as_string::<i32>();
2657 assert_eq!("1", a.value(0));
2658 } else {
2659 unreachable!()
2660 }
2661 }
2662
2663 #[test]
2664 #[cfg_attr(miri, ignore)]
2671 fn test_cast_float_to_decimal() {
2672 let a: ArrayRef = Arc::new(Float64Array::from(vec![
2673 Some(42.),
2674 Some(0.5153125),
2675 Some(-42.4242415),
2676 Some(42e-314),
2677 Some(0.),
2678 Some(-4242.424242),
2679 Some(f64::INFINITY),
2680 Some(f64::NEG_INFINITY),
2681 Some(f64::NAN),
2682 None,
2683 ]));
2684 let b =
2685 cast_floating_point_to_decimal128::<Float64Type>(&a, 8, 6, EvalMode::Legacy).unwrap();
2686 assert_eq!(b.len(), a.len());
2687 let casted = b.as_primitive::<Decimal128Type>();
2688 assert_eq!(casted.value(0), 42000000);
2689 assert_eq!(casted.value(2), -42424242);
2692 assert_eq!(casted.value(3), 0);
2693 assert_eq!(casted.value(4), 0);
2694 assert!(casted.is_null(5));
2695 assert!(casted.is_null(6));
2696 assert!(casted.is_null(7));
2697 assert!(casted.is_null(8));
2698 assert!(casted.is_null(9));
2699 }
2700}