datafusion_comet_spark_expr/conversion_funcs/cast.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::timezone;
use crate::utils::array_with_timezone;
use crate::{EvalMode, SparkError, SparkResult};
use arrow::{
    array::{
        cast::AsArray,
        types::{Date32Type, Int16Type, Int32Type, Int8Type},
        Array, ArrayRef, BooleanArray, Decimal128Array, Float32Array, Float64Array,
        GenericStringArray, Int16Array, Int32Array, Int64Array, Int8Array, OffsetSizeTrait,
        PrimitiveArray,
    },
    compute::{cast_with_options, take, unary, CastOptions},
    datatypes::{
        ArrowPrimitiveType, Decimal128Type, DecimalType, Float32Type, Float64Type, Int64Type,
        TimestampMicrosecondType,
    },
    error::ArrowError,
    record_batch::RecordBatch,
    util::display::FormatOptions,
};
use arrow_array::builder::StringBuilder;
use arrow_array::{DictionaryArray, StringArray, StructArray};
use arrow_schema::{DataType, Schema};
use chrono::{NaiveDate, NaiveDateTime, TimeZone, Timelike};
use datafusion_common::{
    cast::as_generic_string_array, internal_err, Result as DataFusionResult, ScalarValue,
};
use datafusion_expr::ColumnarValue;
use datafusion_physical_expr::PhysicalExpr;
use num::{
    cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, Num,
    ToPrimitive,
};
use regex::Regex;
use std::collections::HashMap;
use std::str::FromStr;
use std::{
    any::Any,
    fmt::{Debug, Display, Formatter},
    hash::Hash,
    num::Wrapping,
    sync::Arc,
};

static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f");

const MICROS_PER_SECOND: i64 = 1_000_000;

static CAST_OPTIONS: CastOptions = CastOptions {
    safe: true,
    format_options: FormatOptions::new()
        .with_timestamp_tz_format(TIMESTAMP_FORMAT)
        .with_timestamp_format(TIMESTAMP_FORMAT),
};

struct TimeStampInfo {
    year: i32,
    month: u32,
    day: u32,
    hour: u32,
    minute: u32,
    second: u32,
    microsecond: u32,
}

impl Default for TimeStampInfo {
    fn default() -> Self {
        TimeStampInfo {
            year: 1,
            month: 1,
            day: 1,
            hour: 0,
            minute: 0,
            second: 0,
            microsecond: 0,
        }
    }
}

impl TimeStampInfo {
    pub fn with_year(&mut self, year: i32) -> &mut Self {
        self.year = year;
        self
    }

    pub fn with_month(&mut self, month: u32) -> &mut Self {
        self.month = month;
        self
    }

    pub fn with_day(&mut self, day: u32) -> &mut Self {
        self.day = day;
        self
    }

    pub fn with_hour(&mut self, hour: u32) -> &mut Self {
        self.hour = hour;
        self
    }

    pub fn with_minute(&mut self, minute: u32) -> &mut Self {
        self.minute = minute;
        self
    }

    pub fn with_second(&mut self, second: u32) -> &mut Self {
        self.second = second;
        self
    }

    pub fn with_microsecond(&mut self, microsecond: u32) -> &mut Self {
        self.microsecond = microsecond;
        self
    }
}
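
// Illustrative usage of the builder above (a sketch, not part of the original
// source): describe the timestamp 2024-03-15 10:30:00, leaving seconds and
// microseconds at their defaults:
//
//     let mut info = TimeStampInfo::default();
//     info.with_year(2024)
//         .with_month(3)
//         .with_day(15)
//         .with_hour(10)
//         .with_minute(30);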

#[derive(Debug, Eq)]
pub struct Cast {
    pub child: Arc<dyn PhysicalExpr>,
    pub data_type: DataType,
    pub cast_options: SparkCastOptions,
}

impl PartialEq for Cast {
    fn eq(&self, other: &Self) -> bool {
        self.child.eq(&other.child)
            && self.data_type.eq(&other.data_type)
            && self.cast_options.eq(&other.cast_options)
    }
}

impl Hash for Cast {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.child.hash(state);
        self.data_type.hash(state);
        self.cast_options.hash(state);
    }
}

/// Determine whether Comet supports a cast, taking options such as EvalMode and timezone
/// into account.
pub fn cast_supported(
    from_type: &DataType,
    to_type: &DataType,
    options: &SparkCastOptions,
) -> bool {
    use DataType::*;

    let from_type = if let Dictionary(_, dt) = from_type {
        dt
    } else {
        from_type
    };

    let to_type = if let Dictionary(_, dt) = to_type {
        dt
    } else {
        to_type
    };

    if from_type == to_type {
        return true;
    }

    match (from_type, to_type) {
        (Boolean, _) => can_cast_from_boolean(to_type, options),
        (UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
            if options.allow_cast_unsigned_ints =>
        {
            true
        }
        (Int8, _) => can_cast_from_byte(to_type, options),
        (Int16, _) => can_cast_from_short(to_type, options),
        (Int32, _) => can_cast_from_int(to_type, options),
        (Int64, _) => can_cast_from_long(to_type, options),
        (Float32, _) => can_cast_from_float(to_type, options),
        (Float64, _) => can_cast_from_double(to_type, options),
        (Decimal128(p, s), _) => can_cast_from_decimal(p, s, to_type, options),
        (Timestamp(_, None), _) => can_cast_from_timestamp_ntz(to_type, options),
        (Timestamp(_, Some(_)), _) => can_cast_from_timestamp(to_type, options),
        (Utf8 | LargeUtf8, _) => can_cast_from_string(to_type, options),
        (_, Utf8 | LargeUtf8) => can_cast_to_string(from_type, options),
        (Struct(from_fields), Struct(to_fields)) => from_fields
            .iter()
            .zip(to_fields.iter())
            .all(|(a, b)| cast_supported(a.data_type(), b.data_type(), options)),
        _ => false,
    }
}
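
// Illustrative usage of `cast_supported` (a sketch, not from the original
// source): check whether a cast can run natively before delegating to Comet.
//
//     let options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
//     // string to int is fully supported:
//     assert!(cast_supported(&DataType::Utf8, &DataType::Int32, &options));
//     // string to float is supported only when incompatible casts are allowed:
//     assert!(!cast_supported(&DataType::Utf8, &DataType::Float64, &options));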

fn can_cast_from_string(to_type: &DataType, options: &SparkCastOptions) -> bool {
    use DataType::*;
    match to_type {
        Boolean | Int8 | Int16 | Int32 | Int64 | Binary => true,
        Float32 | Float64 => {
            // https://github.com/apache/datafusion-comet/issues/326
            // Does not support inputs ending with 'd' or 'f'. Does not support 'inf'.
            // Does not support ANSI mode.
            options.allow_incompat
        }
        Decimal128(_, _) => {
            // https://github.com/apache/datafusion-comet/issues/325
            // Does not support inputs ending with 'd' or 'f'. Does not support 'inf'.
            // Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits.
            options.allow_incompat
        }
        Date32 | Date64 => {
            // https://github.com/apache/datafusion-comet/issues/327
            // Only supports years between 262143 BC and 262142 AD
            options.allow_incompat
        }
        Timestamp(_, _) if options.eval_mode == EvalMode::Ansi => {
            // ANSI mode not supported
            false
        }
        Timestamp(_, Some(tz)) if tz.as_ref() != "UTC" => {
            // Cast will use UTC instead of $timeZoneId
            options.allow_incompat
        }
        Timestamp(_, _) => {
            // https://github.com/apache/datafusion-comet/issues/328
            // Not all valid formats are supported
            options.allow_incompat
        }
        _ => false,
    }
}

fn can_cast_to_string(from_type: &DataType, options: &SparkCastOptions) -> bool {
    use DataType::*;
    match from_type {
        Boolean | Int8 | Int16 | Int32 | Int64 | Date32 | Date64 | Timestamp(_, _) => true,
        Float32 | Float64 => {
            // There can be differences in precision.
            // For example, the input "1.4E-45" will produce 1.0E-45
            // instead of 1.4E-45
            true
        }
        Decimal128(_, _) => {
            // https://github.com/apache/datafusion-comet/issues/1068
            // There can be formatting differences in some cases due to Spark using
            // scientific notation where Comet does not
            true
        }
        Binary => {
            // https://github.com/apache/datafusion-comet/issues/377
            // Only works for binary data representing valid UTF-8 strings
            options.allow_incompat
        }
        Struct(fields) => fields
            .iter()
            .all(|f| can_cast_to_string(f.data_type(), options)),
        _ => false,
    }
}

fn can_cast_from_timestamp_ntz(to_type: &DataType, options: &SparkCastOptions) -> bool {
    use DataType::*;
    match to_type {
        Timestamp(_, _) | Date32 | Date64 | Utf8 => {
            // incompatible
            options.allow_incompat
        }
        _ => {
            // unsupported
            false
        }
    }
}

fn can_cast_from_timestamp(to_type: &DataType, _options: &SparkCastOptions) -> bool {
    use DataType::*;
    match to_type {
        Boolean | Int8 | Int16 => {
            // https://github.com/apache/datafusion-comet/issues/352
            // this seems like an edge case that isn't important for us to support
            false
        }
        Int64 => {
            // https://github.com/apache/datafusion-comet/issues/352
            true
        }
        Date32 | Date64 | Utf8 | Decimal128(_, _) => true,
        _ => {
            // unsupported
            false
        }
    }
}

fn can_cast_from_boolean(to_type: &DataType, _: &SparkCastOptions) -> bool {
    use DataType::*;
    matches!(to_type, Int8 | Int16 | Int32 | Int64 | Float32 | Float64)
}

fn can_cast_from_byte(to_type: &DataType, _: &SparkCastOptions) -> bool {
    use DataType::*;
    matches!(
        to_type,
        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
    )
}

fn can_cast_from_short(to_type: &DataType, _: &SparkCastOptions) -> bool {
    use DataType::*;
    matches!(
        to_type,
        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
    )
}

fn can_cast_from_int(to_type: &DataType, options: &SparkCastOptions) -> bool {
    use DataType::*;
    match to_type {
        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Utf8 => true,
        Decimal128(_, _) => {
            // incompatible: no overflow check
            options.allow_incompat
        }
        _ => false,
    }
}

fn can_cast_from_long(to_type: &DataType, options: &SparkCastOptions) -> bool {
    use DataType::*;
    match to_type {
        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 => true,
        Decimal128(_, _) => {
            // incompatible: no overflow check
            options.allow_incompat
        }
        _ => false,
    }
}

fn can_cast_from_float(to_type: &DataType, _: &SparkCastOptions) -> bool {
    use DataType::*;
    matches!(
        to_type,
        Boolean | Int8 | Int16 | Int32 | Int64 | Float64 | Decimal128(_, _)
    )
}

fn can_cast_from_double(to_type: &DataType, _: &SparkCastOptions) -> bool {
    use DataType::*;
    matches!(
        to_type,
        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Decimal128(_, _)
    )
}

fn can_cast_from_decimal(
    p1: &u8,
    _s1: &i8,
    to_type: &DataType,
    options: &SparkCastOptions,
) -> bool {
    use DataType::*;
    match to_type {
        Int8 | Int16 | Int32 | Int64 | Float32 | Float64 => true,
        Decimal128(p2, _) => {
            if p2 < p1 {
                // https://github.com/apache/datafusion/issues/13492
                // Incompatible(Some("Casting to smaller precision is not supported"))
                options.allow_incompat
            } else {
                true
            }
        }
        _ => false,
    }
}

macro_rules! cast_utf8_to_int {
    ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{
        let len = $array.len();
        let mut cast_array = PrimitiveArray::<$array_type>::builder(len);
        for i in 0..len {
            if $array.is_null(i) {
                cast_array.append_null()
            } else if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? {
                cast_array.append_value(cast_value);
            } else {
                cast_array.append_null()
            }
        }
        let result: SparkResult<ArrayRef> = Ok(Arc::new(cast_array.finish()) as ArrayRef);
        result
    }};
}

macro_rules! cast_utf8_to_timestamp {
    ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident, $tz:expr) => {{
        let len = $array.len();
        let mut cast_array = PrimitiveArray::<$array_type>::builder(len).with_timezone("UTC");
        for i in 0..len {
            if $array.is_null(i) {
                cast_array.append_null()
            } else if let Ok(Some(cast_value)) =
                $cast_method($array.value(i).trim(), $eval_mode, $tz)
            {
                cast_array.append_value(cast_value);
            } else {
                cast_array.append_null()
            }
        }
        let result: ArrayRef = Arc::new(cast_array.finish()) as ArrayRef;
        result
    }};
}

macro_rules! cast_float_to_string {
    ($from:expr, $eval_mode:expr, $type:ty, $output_type:ty, $offset_type:ty) => {{
        fn cast<OffsetSize>(
            from: &dyn Array,
            _eval_mode: EvalMode,
        ) -> SparkResult<ArrayRef>
        where
            OffsetSize: OffsetSizeTrait, {
                let array = from.as_any().downcast_ref::<$output_type>().unwrap();

                // If the absolute number is less than 10,000,000 and greater than or equal to
                // 0.001, the result is expressed without scientific notation, with at least one
                // digit on either side of the decimal point. Otherwise, Spark uses a mantissa
                // followed by E and an exponent. The mantissa has an optional leading minus sign
                // followed by one digit to the left of the decimal point, and the minimal number
                // of digits greater than zero to the right. The exponent has an optional leading
                // minus sign.
                // source: https://docs.databricks.com/en/sql/language-manual/functions/cast.html
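                //
                // Illustrative examples of these rules (not from the original source):
                // 12345.678 stays plain ("12345.678"), 2.0 keeps one fractional digit
                // ("2.0"), 0.0001 becomes "1.0E-4", and 100000000.0 becomes "1.0E8".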

                const LOWER_SCIENTIFIC_BOUND: $type = 0.001;
                const UPPER_SCIENTIFIC_BOUND: $type = 10000000.0;

                let output_array = array
                    .iter()
                    .map(|value| match value {
                        Some(value) if value == <$type>::INFINITY => Ok(Some("Infinity".to_string())),
                        Some(value) if value == <$type>::NEG_INFINITY => Ok(Some("-Infinity".to_string())),
                        Some(value)
                            if (value.abs() < UPPER_SCIENTIFIC_BOUND
                                && value.abs() >= LOWER_SCIENTIFIC_BOUND)
                                || value.abs() == 0.0 =>
                        {
                            let trailing_zero = if value.fract() == 0.0 { ".0" } else { "" };

                            Ok(Some(format!("{value}{trailing_zero}")))
                        }
                        Some(value)
                            if value.abs() >= UPPER_SCIENTIFIC_BOUND
                                || value.abs() < LOWER_SCIENTIFIC_BOUND =>
                        {
                            let formatted = format!("{value:E}");

                            if formatted.contains(".") {
                                Ok(Some(formatted))
                            } else {
                                // `formatted` is already in scientific notation and can be split up by E
                                // in order to add the missing trailing 0 which gets removed for numbers with a fraction of 0.0
                                let prepare_number: Vec<&str> = formatted.split("E").collect();

                                let coefficient = prepare_number[0];

                                let exponent = prepare_number[1];

                                Ok(Some(format!("{coefficient}.0E{exponent}")))
                            }
                        }
                        Some(value) => Ok(Some(value.to_string())),
                        _ => Ok(None),
                    })
                    .collect::<Result<GenericStringArray<OffsetSize>, SparkError>>()?;

                Ok(Arc::new(output_array))
            }

        cast::<$offset_type>($from, $eval_mode)
    }};
}

macro_rules! cast_int_to_int_macro {
    (
        $array: expr,
        $eval_mode:expr,
        $from_arrow_primitive_type: ty,
        $to_arrow_primitive_type: ty,
        $from_data_type: expr,
        $to_native_type: ty,
        $spark_from_data_type_name: expr,
        $spark_to_data_type_name: expr
    ) => {{
        let cast_array = $array
            .as_any()
            .downcast_ref::<PrimitiveArray<$from_arrow_primitive_type>>()
            .unwrap();
        let spark_int_literal_suffix = match $from_data_type {
            &DataType::Int64 => "L",
            &DataType::Int16 => "S",
            &DataType::Int8 => "T",
            _ => "",
        };

        let output_array = match $eval_mode {
            EvalMode::Legacy => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        Ok::<Option<$to_native_type>, SparkError>(Some(value as $to_native_type))
                    }
                    _ => Ok(None),
                })
                .collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, _>>(),
            _ => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let res = <$to_native_type>::try_from(value);
                        if res.is_err() {
                            Err(cast_overflow(
                                &(value.to_string() + spark_int_literal_suffix),
                                $spark_from_data_type_name,
                                $spark_to_data_type_name,
                            ))
                        } else {
                            Ok::<Option<$to_native_type>, SparkError>(Some(res.unwrap()))
                        }
                    }
                    _ => Ok(None),
                })
                .collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, _>>(),
        }?;
        let result: SparkResult<ArrayRef> = Ok(Arc::new(output_array) as ArrayRef);
        result
    }};
}
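
// Illustrative behavior of the macro above (a sketch, not from the original
// source): in ANSI mode, narrowing 300_i64 to TINYINT fails with an overflow
// error that renders the value using Spark's literal suffix ("300L"), while
// legacy mode wraps the value (300 as i8 == 44).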

// When Spark casts to Byte/Short types, it does not cast directly to Byte/Short.
// It casts to Int first and then to Byte/Short. Because of potential overflows in the Int cast,
// this can cause unexpected Short/Byte cast results. Replicate this behavior.
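//
// Illustrative example (not from the original source): in legacy mode,
// 128.0_f32 cast to TINYINT first narrows to i32 (yielding 128) and then to
// i8, wrapping to -128, whereas a direct saturating cast would give 127.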
macro_rules! cast_float_to_int16_down {
    (
        $array:expr,
        $eval_mode:expr,
        $src_array_type:ty,
        $dest_array_type:ty,
        $rust_src_type:ty,
        $rust_dest_type:ty,
        $src_type_str:expr,
        $dest_type_str:expr,
        $format_str:expr
    ) => {{
        let cast_array = $array
            .as_any()
            .downcast_ref::<$src_array_type>()
            .expect(concat!("Expected a ", stringify!($src_array_type)));

        let output_array = match $eval_mode {
            EvalMode::Ansi => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let is_overflow = value.is_nan() || value.abs() as i32 == i32::MAX;
                        if is_overflow {
                            return Err(cast_overflow(
                                &format!($format_str, value).replace("e", "E"),
                                $src_type_str,
                                $dest_type_str,
                            ));
                        }
                        let i32_value = value as i32;
                        <$rust_dest_type>::try_from(i32_value)
                            .map_err(|_| {
                                cast_overflow(
                                    &format!($format_str, value).replace("e", "E"),
                                    $src_type_str,
                                    $dest_type_str,
                                )
                            })
                            .map(Some)
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
            _ => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let i32_value = value as i32;
                        Ok::<Option<$rust_dest_type>, SparkError>(Some(
                            i32_value as $rust_dest_type,
                        ))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
        };
        Ok(Arc::new(output_array) as ArrayRef)
    }};
}

macro_rules! cast_float_to_int32_up {
    (
        $array:expr,
        $eval_mode:expr,
        $src_array_type:ty,
        $dest_array_type:ty,
        $rust_src_type:ty,
        $rust_dest_type:ty,
        $src_type_str:expr,
        $dest_type_str:expr,
        $max_dest_val:expr,
        $format_str:expr
    ) => {{
        let cast_array = $array
            .as_any()
            .downcast_ref::<$src_array_type>()
            .expect(concat!("Expected a ", stringify!($src_array_type)));

        let output_array = match $eval_mode {
            EvalMode::Ansi => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let is_overflow =
                            value.is_nan() || value.abs() as $rust_dest_type == $max_dest_val;
                        if is_overflow {
                            return Err(cast_overflow(
                                &format!($format_str, value).replace("e", "E"),
                                $src_type_str,
                                $dest_type_str,
                            ));
                        }
                        Ok(Some(value as $rust_dest_type))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
            _ => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        Ok::<Option<$rust_dest_type>, SparkError>(Some(value as $rust_dest_type))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
        };
        Ok(Arc::new(output_array) as ArrayRef)
    }};
}

// When Spark casts to Byte/Short types, it does not cast directly to Byte/Short.
// It casts to Int first and then to Byte/Short. Because of potential overflows in the Int cast,
// this can cause unexpected Short/Byte cast results. Replicate this behavior.
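//
// Illustrative example (not from the original source): with DECIMAL(5,2), the
// stored integer 12850 (i.e. 128.50) truncates to 128, which then wraps to
// -128 when narrowed from i32 to TINYINT in legacy mode.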
macro_rules! cast_decimal_to_int16_down {
    (
        $array:expr,
        $eval_mode:expr,
        $dest_array_type:ty,
        $rust_dest_type:ty,
        $dest_type_str:expr,
        $precision:expr,
        $scale:expr
    ) => {{
        let cast_array = $array
            .as_any()
            .downcast_ref::<Decimal128Array>()
            .expect("Expected a Decimal128Array");

        let output_array = match $eval_mode {
            EvalMode::Ansi => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let divisor = 10_i128.pow($scale as u32);
                        let (truncated, decimal) = (value / divisor, (value % divisor).abs());
                        let is_overflow = truncated.abs() > i32::MAX.into();
                        if is_overflow {
                            return Err(cast_overflow(
                                &format!("{}.{}BD", truncated, decimal),
                                &format!("DECIMAL({},{})", $precision, $scale),
                                $dest_type_str,
                            ));
                        }
                        let i32_value = truncated as i32;
                        <$rust_dest_type>::try_from(i32_value)
                            .map_err(|_| {
                                cast_overflow(
                                    &format!("{}.{}BD", truncated, decimal),
                                    &format!("DECIMAL({},{})", $precision, $scale),
                                    $dest_type_str,
                                )
                            })
                            .map(Some)
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
            _ => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let divisor = 10_i128.pow($scale as u32);
                        let i32_value = (value / divisor) as i32;
                        Ok::<Option<$rust_dest_type>, SparkError>(Some(
                            i32_value as $rust_dest_type,
                        ))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
        };
        Ok(Arc::new(output_array) as ArrayRef)
    }};
}

macro_rules! cast_decimal_to_int32_up {
    (
        $array:expr,
        $eval_mode:expr,
        $dest_array_type:ty,
        $rust_dest_type:ty,
        $dest_type_str:expr,
        $max_dest_val:expr,
        $precision:expr,
        $scale:expr
    ) => {{
        let cast_array = $array
            .as_any()
            .downcast_ref::<Decimal128Array>()
            .expect("Expected a Decimal128Array");

        let output_array = match $eval_mode {
            EvalMode::Ansi => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let divisor = 10_i128.pow($scale as u32);
                        let (truncated, decimal) = (value / divisor, (value % divisor).abs());
                        let is_overflow = truncated.abs() > $max_dest_val.into();
                        if is_overflow {
                            return Err(cast_overflow(
                                &format!("{}.{}BD", truncated, decimal),
                                &format!("DECIMAL({},{})", $precision, $scale),
                                $dest_type_str,
                            ));
                        }
                        Ok(Some(truncated as $rust_dest_type))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
            _ => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let divisor = 10_i128.pow($scale as u32);
                        let truncated = value / divisor;
                        Ok::<Option<$rust_dest_type>, SparkError>(Some(
                            truncated as $rust_dest_type,
                        ))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
        };
        Ok(Arc::new(output_array) as ArrayRef)
    }};
}

impl Cast {
    pub fn new(
        child: Arc<dyn PhysicalExpr>,
        data_type: DataType,
        cast_options: SparkCastOptions,
    ) -> Self {
        Self {
            child,
            data_type,
            cast_options,
        }
    }
}

/// Spark cast options
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct SparkCastOptions {
    /// Spark evaluation mode
    pub eval_mode: EvalMode,
    /// When casting from/to timezone-related types, we need a timezone, which will be
    /// resolved with the session local timezone by an analyzer in Spark.
    // TODO we should change timezone to Tz to avoid repeated parsing
    pub timezone: String,
    /// Allow casts that are supported but not guaranteed to be 100% compatible
    pub allow_incompat: bool,
    /// Support casting unsigned ints to signed ints (used by Parquet SchemaAdapter)
    pub allow_cast_unsigned_ints: bool,
    /// We also use the cast logic for adapting Parquet schemas, so this flag is used
    /// for that use case
    pub is_adapting_schema: bool,
}

impl SparkCastOptions {
    pub fn new(eval_mode: EvalMode, timezone: &str, allow_incompat: bool) -> Self {
        Self {
            eval_mode,
            timezone: timezone.to_string(),
            allow_incompat,
            allow_cast_unsigned_ints: false,
            is_adapting_schema: false,
        }
    }

    pub fn new_without_timezone(eval_mode: EvalMode, allow_incompat: bool) -> Self {
        Self {
            eval_mode,
            timezone: "".to_string(),
            allow_incompat,
            allow_cast_unsigned_ints: false,
            is_adapting_schema: false,
        }
    }
}

/// Spark-compatible cast implementation. Defers to DataFusion's cast where that is known
/// to be compatible, and returns an error when an unsupported cast that is also not
/// DataFusion-compatible is requested.
pub fn spark_cast(
    arg: ColumnarValue,
    data_type: &DataType,
    cast_options: &SparkCastOptions,
) -> DataFusionResult<ColumnarValue> {
    match arg {
        ColumnarValue::Array(array) => Ok(ColumnarValue::Array(cast_array(
            array,
            data_type,
            cast_options,
        )?)),
        ColumnarValue::Scalar(scalar) => {
            // Note that normally CAST(scalar) should be folded on the Spark JVM side.
            // However, for some cases e.g., scalar subquery, Spark will not fold it,
            // so we need to handle it here.
            let array = scalar.to_array()?;
            let scalar =
                ScalarValue::try_from_array(&cast_array(array, data_type, cast_options)?, 0)?;
            Ok(ColumnarValue::Scalar(scalar))
        }
    }
}
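
// Illustrative usage of `spark_cast` (a sketch under assumed inputs, not part
// of the original source): cast an Int64 column to Int32 with legacy
// (non-ANSI) semantics.
//
//     let options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
//     let input = ColumnarValue::Array(Arc::new(Int64Array::from(vec![1, 2, 3])));
//     let result = spark_cast(input, &DataType::Int32, &options)?;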

fn cast_array(
    array: ArrayRef,
    to_type: &DataType,
    cast_options: &SparkCastOptions,
) -> DataFusionResult<ArrayRef> {
    use DataType::*;
    let array = array_with_timezone(array, cast_options.timezone.clone(), Some(to_type))?;
    let from_type = array.data_type().clone();

    let native_cast_options: CastOptions = CastOptions {
        // in ANSI mode, invalid values must raise errors rather than silently become nulls
        safe: !matches!(cast_options.eval_mode, EvalMode::Ansi),
        format_options: FormatOptions::new()
            .with_timestamp_tz_format(TIMESTAMP_FORMAT)
            .with_timestamp_format(TIMESTAMP_FORMAT),
    };

    let array = match &from_type {
        Dictionary(key_type, value_type)
            if key_type.as_ref() == &Int32
                && (value_type.as_ref() == &Utf8 || value_type.as_ref() == &LargeUtf8) =>
        {
            let dict_array = array
                .as_any()
                .downcast_ref::<DictionaryArray<Int32Type>>()
                .expect("Expected a dictionary array");

            let casted_dictionary = DictionaryArray::<Int32Type>::new(
                dict_array.keys().clone(),
                cast_array(Arc::clone(dict_array.values()), to_type, cast_options)?,
            );

            let casted_result = match to_type {
                Dictionary(_, _) => Arc::new(casted_dictionary.clone()),
                _ => take(casted_dictionary.values().as_ref(), dict_array.keys(), None)?,
            };
            return Ok(spark_cast_postprocess(casted_result, &from_type, to_type));
        }
        _ => array,
    };
    let from_type = array.data_type();
    let eval_mode = cast_options.eval_mode;

    let cast_result = match (from_type, to_type) {
        (Utf8, Boolean) => spark_cast_utf8_to_boolean::<i32>(&array, eval_mode),
        (LargeUtf8, Boolean) => spark_cast_utf8_to_boolean::<i64>(&array, eval_mode),
        (Utf8, Timestamp(_, _)) => {
            cast_string_to_timestamp(&array, to_type, eval_mode, &cast_options.timezone)
        }
        (Utf8, Date32) => cast_string_to_date(&array, to_type, eval_mode),
        (Int64, Int32)
        | (Int64, Int16)
        | (Int64, Int8)
        | (Int32, Int16)
        | (Int32, Int8)
        | (Int16, Int8)
            if eval_mode != EvalMode::Try =>
        {
            spark_cast_int_to_int(&array, eval_mode, from_type, to_type)
        }
        (Utf8, Int8 | Int16 | Int32 | Int64) => {
            cast_string_to_int::<i32>(to_type, &array, eval_mode)
        }
        (LargeUtf8, Int8 | Int16 | Int32 | Int64) => {
            cast_string_to_int::<i64>(to_type, &array, eval_mode)
        }
        (Float64, Utf8) => spark_cast_float64_to_utf8::<i32>(&array, eval_mode),
        (Float64, LargeUtf8) => spark_cast_float64_to_utf8::<i64>(&array, eval_mode),
        (Float32, Utf8) => spark_cast_float32_to_utf8::<i32>(&array, eval_mode),
        (Float32, LargeUtf8) => spark_cast_float32_to_utf8::<i64>(&array, eval_mode),
        (Float32, Decimal128(precision, scale)) => {
            cast_float32_to_decimal128(&array, *precision, *scale, eval_mode)
        }
        (Float64, Decimal128(precision, scale)) => {
            cast_float64_to_decimal128(&array, *precision, *scale, eval_mode)
        }
        (Float32, Int8)
        | (Float32, Int16)
        | (Float32, Int32)
        | (Float32, Int64)
        | (Float64, Int8)
        | (Float64, Int16)
        | (Float64, Int32)
        | (Float64, Int64)
        | (Decimal128(_, _), Int8)
        | (Decimal128(_, _), Int16)
        | (Decimal128(_, _), Int32)
        | (Decimal128(_, _), Int64)
            if eval_mode != EvalMode::Try =>
        {
            spark_cast_nonintegral_numeric_to_integral(&array, eval_mode, from_type, to_type)
        }
        (Struct(_), Utf8) => Ok(casts_struct_to_string(array.as_struct(), cast_options)?),
        (Struct(_), Struct(_)) => Ok(cast_struct_to_struct(
            array.as_struct(),
            from_type,
            to_type,
            cast_options,
        )?),
        (UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
            if cast_options.allow_cast_unsigned_ints =>
        {
            Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
        }
        _ if cast_options.is_adapting_schema
            || is_datafusion_spark_compatible(from_type, to_type, cast_options.allow_incompat) =>
        {
            // use DataFusion cast only when we know that it is compatible with Spark
            Ok(cast_with_options(&array, to_type, &native_cast_options)?)
        }
        _ => {
            // we should never reach this code because the Scala code should be checking
            // for supported cast operations and falling back to Spark for anything that
            // is not yet supported
            Err(SparkError::Internal(format!(
                "Native cast invoked for unsupported cast from {from_type:?} to {to_type:?}"
            )))
        }
    };
    Ok(spark_cast_postprocess(cast_result?, from_type, to_type))
}

/// Determines if DataFusion supports the given cast in a way that is
/// compatible with Spark
fn is_datafusion_spark_compatible(
    from_type: &DataType,
    to_type: &DataType,
    allow_incompat: bool,
) -> bool {
    if from_type == to_type {
        return true;
    }
    match from_type {
        DataType::Null => {
            matches!(to_type, DataType::List(_))
        }
        DataType::Boolean => matches!(
            to_type,
            DataType::Int8
                | DataType::Int16
                | DataType::Int32
                | DataType::Int64
                | DataType::Float32
                | DataType::Float64
                | DataType::Utf8
        ),
        DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
            // note that the cast from Int32/Int64 -> Decimal128 here is actually
            // not compatible with Spark (no overflow checks) but we have tests that
            // rely on this cast working, so we have to leave it here for now
            matches!(
                to_type,
                DataType::Boolean
                    | DataType::Int8
                    | DataType::Int16
                    | DataType::Int32
                    | DataType::Int64
                    | DataType::Float32
                    | DataType::Float64
                    | DataType::Decimal128(_, _)
                    | DataType::Utf8
            )
        }
        DataType::Float32 | DataType::Float64 => matches!(
            to_type,
            DataType::Boolean
                | DataType::Int8
                | DataType::Int16
                | DataType::Int32
                | DataType::Int64
                | DataType::Float32
                | DataType::Float64
        ),
        DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => matches!(
            to_type,
            DataType::Int8
                | DataType::Int16
                | DataType::Int32
                | DataType::Int64
                | DataType::Float32
                | DataType::Float64
                | DataType::Decimal128(_, _)
                | DataType::Decimal256(_, _)
                | DataType::Utf8 // note that there can be formatting differences
        ),
        DataType::Utf8 if allow_incompat => matches!(
            to_type,
            DataType::Binary | DataType::Float32 | DataType::Float64 | DataType::Decimal128(_, _)
        ),
        DataType::Utf8 => matches!(to_type, DataType::Binary),
        DataType::Date32 => matches!(to_type, DataType::Utf8),
        DataType::Timestamp(_, _) => {
            matches!(
                to_type,
                DataType::Int64 | DataType::Date32 | DataType::Utf8 | DataType::Timestamp(_, _)
            )
        }
        DataType::Binary => {
            // note that this is not completely Spark compatible because
            // DataFusion only supports binary data containing valid UTF-8 strings
            matches!(to_type, DataType::Utf8)
        }
        _ => false,
    }
}

/// Cast between struct types based on logic in
/// `org.apache.spark.sql.catalyst.expressions.Cast#castStruct`.
fn cast_struct_to_struct(
    array: &StructArray,
    from_type: &DataType,
    to_type: &DataType,
    cast_options: &SparkCastOptions,
) -> DataFusionResult<ArrayRef> {
    match (from_type, to_type) {
        (DataType::Struct(from_fields), DataType::Struct(to_fields)) => {
            // TODO some of this logic may be specific to converting Parquet to Spark
            let mut field_name_to_index_map = HashMap::new();
            for (i, field) in from_fields.iter().enumerate() {
                field_name_to_index_map.insert(field.name(), i);
            }
            assert_eq!(field_name_to_index_map.len(), from_fields.len());
            let mut cast_fields: Vec<ArrayRef> = Vec::with_capacity(to_fields.len());
            for i in 0..to_fields.len() {
                let from_index = field_name_to_index_map[to_fields[i].name()];
                let cast_field = cast_array(
                    Arc::clone(array.column(from_index)),
                    to_fields[i].data_type(),
                    cast_options,
                )?;
                cast_fields.push(cast_field);
            }
            Ok(Arc::new(StructArray::new(
                to_fields.clone(),
                cast_fields,
                array.nulls().cloned(),
            )))
        }
        _ => unreachable!(),
    }
}
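
// Note on `cast_struct_to_struct` (illustrative, not from the original
// source): fields are matched by name rather than by position, so casting
// struct<a: int, b: string> to struct<b: string, a: bigint> reorders the
// columns and casts field `a` from Int32 to Int64.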

fn casts_struct_to_string(
    array: &StructArray,
    spark_cast_options: &SparkCastOptions,
) -> DataFusionResult<ArrayRef> {
    // cast each field to a string
    let string_arrays: Vec<ArrayRef> = array
        .columns()
        .iter()
        .map(|arr| {
            spark_cast(
                ColumnarValue::Array(Arc::clone(arr)),
                &DataType::Utf8,
                spark_cast_options,
            )
            .and_then(|cv| cv.into_array(arr.len()))
        })
        .collect::<DataFusionResult<Vec<_>>>()?;
    let string_arrays: Vec<&StringArray> =
        string_arrays.iter().map(|arr| arr.as_string()).collect();
    // build the struct string containing the field values separated by `, `,
    // with null fields rendered as the literal `null`
    let mut builder = StringBuilder::with_capacity(array.len(), array.len() * 16);
    let mut str = String::with_capacity(array.len() * 16);
    for row_index in 0..array.len() {
        if array.is_null(row_index) {
            builder.append_null();
        } else {
            str.clear();
            let mut any_fields_written = false;
            str.push('{');
            for field in &string_arrays {
                if any_fields_written {
                    str.push_str(", ");
                }
                if field.is_null(row_index) {
                    str.push_str("null");
                } else {
                    str.push_str(field.value(row_index));
                }
                any_fields_written = true;
            }
            str.push('}');
            builder.append_value(&str);
        }
    }
    Ok(Arc::new(builder.finish()))
}
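
// Illustrative output of `casts_struct_to_string` (not from the original
// source): a row with field values (1, "a") renders as "{1, a}", and a row
// with a null second field renders as "{1, null}".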

fn cast_string_to_int<OffsetSize: OffsetSizeTrait>(
    to_type: &DataType,
    array: &ArrayRef,
    eval_mode: EvalMode,
) -> SparkResult<ArrayRef> {
    let string_array = array
        .as_any()
        .downcast_ref::<GenericStringArray<OffsetSize>>()
        .expect("cast_string_to_int expected a string array");

    let cast_array: ArrayRef = match to_type {
        DataType::Int8 => cast_utf8_to_int!(string_array, eval_mode, Int8Type, cast_string_to_i8)?,
        DataType::Int16 => {
            cast_utf8_to_int!(string_array, eval_mode, Int16Type, cast_string_to_i16)?
        }
        DataType::Int32 => {
            cast_utf8_to_int!(string_array, eval_mode, Int32Type, cast_string_to_i32)?
        }
        DataType::Int64 => {
            cast_utf8_to_int!(string_array, eval_mode, Int64Type, cast_string_to_i64)?
        }
        dt => unreachable!("invalid integer type {dt} in cast from string"),
    };
    Ok(cast_array)
}

fn cast_string_to_date(
    array: &ArrayRef,
    to_type: &DataType,
    eval_mode: EvalMode,
) -> SparkResult<ArrayRef> {
    let string_array = array
        .as_any()
        .downcast_ref::<GenericStringArray<i32>>()
        .expect("Expected a string array");

    if to_type != &DataType::Date32 {
        unreachable!("Invalid data type {:?} in cast from string", to_type);
    }

    let len = string_array.len();
    let mut cast_array = PrimitiveArray::<Date32Type>::builder(len);

    for i in 0..len {
        let value = if string_array.is_null(i) {
            None
        } else {
            date_parser(string_array.value(i), eval_mode)?
        };

        match value {
            Some(cast_value) => cast_array.append_value(cast_value),
            None => cast_array.append_null(),
        }
    }

    Ok(Arc::new(cast_array.finish()) as ArrayRef)
}

fn cast_string_to_timestamp(
    array: &ArrayRef,
    to_type: &DataType,
    eval_mode: EvalMode,
    timezone_str: &str,
) -> SparkResult<ArrayRef> {
    let string_array = array
        .as_any()
        .downcast_ref::<GenericStringArray<i32>>()
        .expect("Expected a string array");

    let tz = &timezone::Tz::from_str(timezone_str).unwrap();

    let cast_array: ArrayRef = match to_type {
        DataType::Timestamp(_, _) => {
            cast_utf8_to_timestamp!(
                string_array,
                eval_mode,
                TimestampMicrosecondType,
                timestamp_parser,
                tz
            )
        }
        _ => unreachable!("Invalid data type {:?} in cast from string", to_type),
    };
    Ok(cast_array)
}

fn cast_float64_to_decimal128(
    array: &dyn Array,
    precision: u8,
    scale: i8,
    eval_mode: EvalMode,
) -> SparkResult<ArrayRef> {
    cast_floating_point_to_decimal128::<Float64Type>(array, precision, scale, eval_mode)
}

fn cast_float32_to_decimal128(
    array: &dyn Array,
    precision: u8,
    scale: i8,
    eval_mode: EvalMode,
) -> SparkResult<ArrayRef> {
    cast_floating_point_to_decimal128::<Float32Type>(array, precision, scale, eval_mode)
}

fn cast_floating_point_to_decimal128<T: ArrowPrimitiveType>(
    array: &dyn Array,
    precision: u8,
    scale: i8,
    eval_mode: EvalMode,
) -> SparkResult<ArrayRef>
where
    <T as ArrowPrimitiveType>::Native: AsPrimitive<f64>,
{
    let input = array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
    let mut cast_array = PrimitiveArray::<Decimal128Type>::builder(input.len());

    let mul = 10_f64.powi(scale as i32);

    for i in 0..input.len() {
        if input.is_null(i) {
            cast_array.append_null();
        } else {
            let input_value = input.value(i).as_();
            let value = (input_value * mul).round().to_i128();

            match value {
                Some(v) => {
                    if Decimal128Type::validate_decimal_precision(v, precision).is_err() {
                        if eval_mode == EvalMode::Ansi {
                            return Err(SparkError::NumericValueOutOfRange {
                                value: input_value.to_string(),
                                precision,
                                scale,
                            });
                        }
                        // the value does not fit the target precision: append null only,
                        // rather than appending both a null and the out-of-range value
                        cast_array.append_null();
                    } else {
                        cast_array.append_value(v);
                    }
                }
                None => {
                    if eval_mode == EvalMode::Ansi {
                        return Err(SparkError::NumericValueOutOfRange {
                            value: input_value.to_string(),
                            precision,
                            scale,
                        });
                    } else {
                        cast_array.append_null();
                    }
                }
            }
        }
    }

    let res = Arc::new(
        cast_array
            .with_precision_and_scale(precision, scale)?
            .finish(),
    ) as ArrayRef;
    Ok(res)
}
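
// Illustrative example of the scaling above (not from the original source):
// casting 1.2345_f64 to DECIMAL(4,2) multiplies by 10^2 and rounds, storing
// 123 (i.e. 1.23), while 123.45 fails the precision check for DECIMAL(4,2)
// and yields null in legacy mode or an error in ANSI mode.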

fn spark_cast_float64_to_utf8<OffsetSize>(
    from: &dyn Array,
    _eval_mode: EvalMode,
) -> SparkResult<ArrayRef>
where
    OffsetSize: OffsetSizeTrait,
{
    cast_float_to_string!(from, _eval_mode, f64, Float64Array, OffsetSize)
}

fn spark_cast_float32_to_utf8<OffsetSize>(
    from: &dyn Array,
    _eval_mode: EvalMode,
) -> SparkResult<ArrayRef>
where
    OffsetSize: OffsetSizeTrait,
{
    cast_float_to_string!(from, _eval_mode, f32, Float32Array, OffsetSize)
}

fn spark_cast_int_to_int(
    array: &dyn Array,
    eval_mode: EvalMode,
    from_type: &DataType,
    to_type: &DataType,
) -> SparkResult<ArrayRef> {
    match (from_type, to_type) {
        (DataType::Int64, DataType::Int32) => cast_int_to_int_macro!(
            array, eval_mode, Int64Type, Int32Type, from_type, i32, "BIGINT", "INT"
        ),
        (DataType::Int64, DataType::Int16) => cast_int_to_int_macro!(
            array, eval_mode, Int64Type, Int16Type, from_type, i16, "BIGINT", "SMALLINT"
        ),
        (DataType::Int64, DataType::Int8) => cast_int_to_int_macro!(
            array, eval_mode, Int64Type, Int8Type, from_type, i8, "BIGINT", "TINYINT"
        ),
        (DataType::Int32, DataType::Int16) => cast_int_to_int_macro!(
            array, eval_mode, Int32Type, Int16Type, from_type, i16, "INT", "SMALLINT"
        ),
        (DataType::Int32, DataType::Int8) => cast_int_to_int_macro!(
            array, eval_mode, Int32Type, Int8Type, from_type, i8, "INT", "TINYINT"
        ),
        (DataType::Int16, DataType::Int8) => cast_int_to_int_macro!(
            array, eval_mode, Int16Type, Int8Type, from_type, i8, "SMALLINT", "TINYINT"
        ),
        _ => unreachable!("invalid integer type {to_type} in cast from {from_type}"),
    }
}

fn spark_cast_utf8_to_boolean<OffsetSize>(
    from: &dyn Array,
    eval_mode: EvalMode,
) -> SparkResult<ArrayRef>
where
    OffsetSize: OffsetSizeTrait,
{
    let array = from
        .as_any()
        .downcast_ref::<GenericStringArray<OffsetSize>>()
        .unwrap();

    let output_array = array
        .iter()
        .map(|value| match value {
            Some(value) => match value.to_ascii_lowercase().trim() {
                "t" | "true" | "y" | "yes" | "1" => Ok(Some(true)),
                "f" | "false" | "n" | "no" | "0" => Ok(Some(false)),
                _ if eval_mode == EvalMode::Ansi => Err(SparkError::CastInvalidValue {
                    value: value.to_string(),
                    from_type: "STRING".to_string(),
                    to_type: "BOOLEAN".to_string(),
                }),
                _ => Ok(None),
            },
            _ => Ok(None),
        })
        .collect::<Result<BooleanArray, _>>()?;

    Ok(Arc::new(output_array))
}
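
// Illustrative behavior of `spark_cast_utf8_to_boolean` (not from the original
// source): " TRUE " and "1" cast to true, "no" casts to false, and "maybe"
// yields null in legacy mode or a CastInvalidValue error in ANSI mode.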

fn spark_cast_nonintegral_numeric_to_integral(
    array: &dyn Array,
    eval_mode: EvalMode,
    from_type: &DataType,
    to_type: &DataType,
) -> SparkResult<ArrayRef> {
    match (from_type, to_type) {
        (DataType::Float32, DataType::Int8) => cast_float_to_int16_down!(
            array,
            eval_mode,
            Float32Array,
            Int8Array,
            f32,
            i8,
            "FLOAT",
            "TINYINT",
            "{:e}"
        ),
        (DataType::Float32, DataType::Int16) => cast_float_to_int16_down!(
            array,
            eval_mode,
            Float32Array,
            Int16Array,
            f32,
            i16,
            "FLOAT",
            "SMALLINT",
            "{:e}"
        ),
        (DataType::Float32, DataType::Int32) => cast_float_to_int32_up!(
            array,
            eval_mode,
            Float32Array,
            Int32Array,
            f32,
            i32,
            "FLOAT",
            "INT",
            i32::MAX,
            "{:e}"
        ),
        (DataType::Float32, DataType::Int64) => cast_float_to_int32_up!(
            array,
            eval_mode,
            Float32Array,
            Int64Array,
            f32,
            i64,
            "FLOAT",
            "BIGINT",
            i64::MAX,
            "{:e}"
        ),
        (DataType::Float64, DataType::Int8) => cast_float_to_int16_down!(
            array,
            eval_mode,
            Float64Array,
            Int8Array,
            f64,
            i8,
            "DOUBLE",
            "TINYINT",
            "{:e}D"
        ),
        (DataType::Float64, DataType::Int16) => cast_float_to_int16_down!(
            array,
            eval_mode,
            Float64Array,
            Int16Array,
            f64,
            i16,
            "DOUBLE",
            "SMALLINT",
            "{:e}D"
        ),
        (DataType::Float64, DataType::Int32) => cast_float_to_int32_up!(
            array,
            eval_mode,
            Float64Array,
            Int32Array,
            f64,
            i32,
            "DOUBLE",
            "INT",
            i32::MAX,
            "{:e}D"
        ),
        (DataType::Float64, DataType::Int64) => cast_float_to_int32_up!(
            array,
            eval_mode,
            Float64Array,
            Int64Array,
            f64,
            i64,
            "DOUBLE",
            "BIGINT",
            i64::MAX,
            "{:e}D"
        ),
        (DataType::Decimal128(precision, scale), DataType::Int8) => {
            cast_decimal_to_int16_down!(
                array, eval_mode, Int8Array, i8, "TINYINT", precision, *scale
            )
        }
        (DataType::Decimal128(precision, scale), DataType::Int16) => {
            cast_decimal_to_int16_down!(
                array, eval_mode, Int16Array, i16, "SMALLINT", precision, *scale
            )
        }
        (DataType::Decimal128(precision, scale), DataType::Int32) => {
            cast_decimal_to_int32_up!(
                array,
                eval_mode,
                Int32Array,
                i32,
                "INT",
                i32::MAX,
                *precision,
                *scale
            )
        }
        (DataType::Decimal128(precision, scale), DataType::Int64) => {
            cast_decimal_to_int32_up!(
                array,
                eval_mode,
                Int64Array,
                i64,
                "BIGINT",
                i64::MAX,
                *precision,
                *scale
            )
        }
        _ => unreachable!(
            "invalid cast from non-integral numeric type: {from_type} to integral numeric type: {to_type}"
        ),
    }
}

/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toByte
fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i8>> {
    Ok(cast_string_to_int_with_range_check(
        str,
        eval_mode,
        "TINYINT",
        i8::MIN as i32,
        i8::MAX as i32,
    )?
    .map(|v| v as i8))
}

/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toShort
fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i16>> {
    Ok(cast_string_to_int_with_range_check(
        str,
        eval_mode,
        "SMALLINT",
        i16::MIN as i32,
        i16::MAX as i32,
    )?
    .map(|v| v as i16))
}

/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper)
fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>> {
    do_cast_string_to_int::<i32>(str, eval_mode, "INT", i32::MIN)
}

/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper longWrapper)
fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i64>> {
    do_cast_string_to_int::<i64>(str, eval_mode, "BIGINT", i64::MIN)
}

fn cast_string_to_int_with_range_check(
    str: &str,
    eval_mode: EvalMode,
    type_name: &str,
    min: i32,
    max: i32,
) -> SparkResult<Option<i32>> {
    match do_cast_string_to_int(str, eval_mode, type_name, i32::MIN)? {
        None => Ok(None),
        Some(v) if v >= min && v <= max => Ok(Some(v)),
        _ if eval_mode == EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)),
        _ => Ok(None),
    }
}

/// Equivalent to
/// - org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper, boolean allowDecimal)
/// - org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper longWrapper, boolean allowDecimal)
fn do_cast_string_to_int<
    T: Num + PartialOrd + Integer + CheckedSub + CheckedNeg + From<i32> + Copy,
>(
    str: &str,
    eval_mode: EvalMode,
    type_name: &str,
    min_value: T,
) -> SparkResult<Option<T>> {
    let trimmed_str = str.trim();
    if trimmed_str.is_empty() {
        return none_or_err(eval_mode, type_name, str);
    }
    let len = trimmed_str.len();
    let mut result: T = T::zero();
    let mut negative = false;
    let radix = T::from(10);
    let stop_value = min_value / radix;
    let mut parse_sign_and_digits = true;

    for (i, ch) in trimmed_str.char_indices() {
        if parse_sign_and_digits {
            if i == 0 {
                negative = ch == '-';
                let positive = ch == '+';
                if negative || positive {
                    if i + 1 == len {
                        // input string is just "+" or "-"
                        return none_or_err(eval_mode, type_name, str);
                    }
                    // consume this char
                    continue;
                }
            }

            if ch == '.' {
                if eval_mode == EvalMode::Legacy {
                    // truncate decimal in legacy mode
                    parse_sign_and_digits = false;
                    continue;
                } else {
                    return none_or_err(eval_mode, type_name, str);
                }
            }

            let digit = if ch.is_ascii_digit() {
                (ch as u32) - ('0' as u32)
            } else {
                return none_or_err(eval_mode, type_name, str);
            };

            // We are going to process the new digit and accumulate the result. However, before
            // doing this, if the result is already smaller than the
            // stopValue(Integer.MIN_VALUE / radix), then result * 10 will definitely be
            // smaller than minValue, and we can stop
            if result < stop_value {
                return none_or_err(eval_mode, type_name, str);
            }

            // Since the previous result is greater than or equal to
            // stopValue(Integer.MIN_VALUE / radix), `result * radix` cannot
            // underflow. If subtracting the next digit moves the accumulated
            // (negative) result past zero, the value has overflowed and we stop
            let v = result * radix;
            let digit = (digit as i32).into();
            match v.checked_sub(&digit) {
                Some(x) if x <= T::zero() => result = x,
                _ => {
                    return none_or_err(eval_mode, type_name, str);
                }
            }
        } else {
            // make sure fractional digits are valid digits but ignore them
            if !ch.is_ascii_digit() {
                return none_or_err(eval_mode, type_name, str);
            }
        }
    }

    if !negative {
        if let Some(neg) = result.checked_neg() {
            if neg < T::zero() {
                return none_or_err(eval_mode, type_name, str);
            }
            result = neg;
        } else {
            return none_or_err(eval_mode, type_name, str);
        }
    }

    Ok(Some(result))
}

/// Either return Ok(None) or Err(SparkError::CastInvalidValue) depending on the evaluation mode
#[inline]
fn none_or_err<T>(eval_mode: EvalMode, type_name: &str, str: &str) -> SparkResult<Option<T>> {
    match eval_mode {
        EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)),
        _ => Ok(None),
    }
}

#[inline]
fn invalid_value(value: &str, from_type: &str, to_type: &str) -> SparkError {
    SparkError::CastInvalidValue {
        value: value.to_string(),
        from_type: from_type.to_string(),
        to_type: to_type.to_string(),
    }
}

#[inline]
fn cast_overflow(value: &str, from_type: &str, to_type: &str) -> SparkError {
    SparkError::CastOverFlow {
        value: value.to_string(),
        from_type: from_type.to_string(),
        to_type: to_type.to_string(),
    }
}

impl Display for Cast {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Cast [data_type: {}, timezone: {}, child: {}, eval_mode: {:?}]",
            self.data_type, self.cast_options.timezone, self.child, &self.cast_options.eval_mode
        )
    }
}

impl PhysicalExpr for Cast {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn data_type(&self, _: &Schema) -> DataFusionResult<DataType> {
        Ok(self.data_type.clone())
    }

    fn nullable(&self, _: &Schema) -> DataFusionResult<bool> {
        Ok(true)
    }

    fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
        let arg = self.child.evaluate(batch)?;
        spark_cast(arg, &self.data_type, &self.cast_options)
    }

    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        vec![&self.child]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
        match children.len() {
            1 => Ok(Arc::new(Cast::new(
                Arc::clone(&children[0]),
                self.data_type.clone(),
                self.cast_options.clone(),
            ))),
            _ => internal_err!("Cast should have exactly one child"),
        }
    }
}

fn timestamp_parser<T: TimeZone>(
    value: &str,
    eval_mode: EvalMode,
    tz: &T,
) -> SparkResult<Option<i64>> {
    let value = value.trim();
    if value.is_empty() {
        return Ok(None);
    }
    // Define regex patterns and corresponding parsing functions
    let patterns = &[
        (
            Regex::new(r"^\d{4,5}$").unwrap(),
            parse_str_to_year_timestamp as fn(&str, &T) -> SparkResult<Option<i64>>,
        ),
        (
            Regex::new(r"^\d{4,5}-\d{2}$").unwrap(),
            parse_str_to_month_timestamp,
        ),
        (
            Regex::new(r"^\d{4,5}-\d{2}-\d{2}$").unwrap(),
            parse_str_to_day_timestamp,
        ),
        (
            Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{1,2}$").unwrap(),
            parse_str_to_hour_timestamp,
        ),
        (
            Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}$").unwrap(),
            parse_str_to_minute_timestamp,
        ),
        (
            Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$").unwrap(),
            parse_str_to_second_timestamp,
        ),
        (
            Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,6}$").unwrap(),
            parse_str_to_microsecond_timestamp,
        ),
        (
            Regex::new(r"^T\d{1,2}$").unwrap(),
            parse_str_to_time_only_timestamp,
        ),
    ];

    let mut timestamp = None;

    // Iterate through patterns and try matching
    for (pattern, parse_func) in patterns {
        if pattern.is_match(value) {
            timestamp = parse_func(value, tz)?;
            break;
        }
    }

    if timestamp.is_none() {
        return if eval_mode == EvalMode::Ansi {
            Err(SparkError::CastInvalidValue {
                value: value.to_string(),
                from_type: "STRING".to_string(),
                to_type: "TIMESTAMP".to_string(),
            })
        } else {
            Ok(None)
        };
    }

    match timestamp {
        Some(ts) => Ok(Some(ts)),
        None => Err(SparkError::Internal(
            "Failed to parse timestamp".to_string(),
        )),
    }
}

fn parse_timestamp_to_micros<T: TimeZone>(
    timestamp_info: &TimeStampInfo,
    tz: &T,
) -> SparkResult<Option<i64>> {
    let datetime = tz.with_ymd_and_hms(
        timestamp_info.year,
        timestamp_info.month,
        timestamp_info.day,
        timestamp_info.hour,
        timestamp_info.minute,
        timestamp_info.second,
    );

    // Ensure the datetime resolves to a single valid instant in this timezone
    let tz_datetime = match datetime.single() {
        Some(dt) => dt
            .with_timezone(tz)
            .with_nanosecond(timestamp_info.microsecond * 1000),
        None => {
            return Err(SparkError::Internal(
                "Failed to parse timestamp".to_string(),
            ));
        }
    };

    let result = match tz_datetime {
        Some(dt) => dt.timestamp_micros(),
        None => {
            return Err(SparkError::Internal(
                "Failed to parse timestamp".to_string(),
            ));
        }
    };

    Ok(Some(result))
}

fn get_timestamp_values<T: TimeZone>(
    value: &str,
    timestamp_type: &str,
    tz: &T,
) -> SparkResult<Option<i64>> {
    let values: Vec<_> = value.split(['T', '-', ':', '.']).collect();
    let year = values[0].parse::<i32>().unwrap_or_default();
    let month = values.get(1).map_or(1, |m| m.parse::<u32>().unwrap_or(1));
    let day = values.get(2).map_or(1, |d| d.parse::<u32>().unwrap_or(1));
    let hour = values.get(3).map_or(0, |h| h.parse::<u32>().unwrap_or(0));
    let minute = values.get(4).map_or(0, |m| m.parse::<u32>().unwrap_or(0));
    let second = values.get(5).map_or(0, |s| s.parse::<u32>().unwrap_or(0));
    let microsecond = values.get(6).map_or(0, |ms| ms.parse::<u32>().unwrap_or(0));

    let mut timestamp_info = TimeStampInfo::default();

    let timestamp_info = match timestamp_type {
        "year" => timestamp_info.with_year(year),
        "month" => timestamp_info.with_year(year).with_month(month),
        "day" => timestamp_info
            .with_year(year)
            .with_month(month)
            .with_day(day),
        "hour" => timestamp_info
            .with_year(year)
            .with_month(month)
            .with_day(day)
            .with_hour(hour),
        "minute" => timestamp_info
            .with_year(year)
            .with_month(month)
            .with_day(day)
            .with_hour(hour)
            .with_minute(minute),
        "second" => timestamp_info
            .with_year(year)
            .with_month(month)
            .with_day(day)
            .with_hour(hour)
            .with_minute(minute)
            .with_second(second),
        "microsecond" => timestamp_info
            .with_year(year)
            .with_month(month)
            .with_day(day)
            .with_hour(hour)
            .with_minute(minute)
            .with_second(second)
            .with_microsecond(microsecond),
        _ => {
            return Err(SparkError::CastInvalidValue {
                value: value.to_string(),
                from_type: "STRING".to_string(),
                to_type: "TIMESTAMP".to_string(),
            })
        }
    };

    parse_timestamp_to_micros(timestamp_info, tz)
}

fn parse_str_to_year_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
    get_timestamp_values(value, "year", tz)
}

fn parse_str_to_month_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
    get_timestamp_values(value, "month", tz)
}

fn parse_str_to_day_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
    get_timestamp_values(value, "day", tz)
}

fn parse_str_to_hour_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
    get_timestamp_values(value, "hour", tz)
}

fn parse_str_to_minute_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
    get_timestamp_values(value, "minute", tz)
}

fn parse_str_to_second_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
    get_timestamp_values(value, "second", tz)
}

fn parse_str_to_microsecond_timestamp<T: TimeZone>(
    value: &str,
    tz: &T,
) -> SparkResult<Option<i64>> {
    get_timestamp_values(value, "microsecond", tz)
}

fn parse_str_to_time_only_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
    let values: Vec<&str> = value.split('T').collect();
    let time_values: Vec<u32> = values[1]
        .split(':')
        .map(|v| v.parse::<u32>().unwrap_or(0))
        .collect();

    let datetime = tz.from_utc_datetime(&chrono::Utc::now().naive_utc());
    let timestamp = datetime
        .with_timezone(tz)
        .with_hour(time_values.first().copied().unwrap_or_default())
        .and_then(|dt| dt.with_minute(*time_values.get(1).unwrap_or(&0)))
        .and_then(|dt| dt.with_second(*time_values.get(2).unwrap_or(&0)))
        .and_then(|dt| dt.with_nanosecond(*time_values.get(3).unwrap_or(&0) * 1_000))
        .map(|dt| dt.timestamp_micros())
        .unwrap_or_default();

    Ok(Some(timestamp))
}

// A string-to-date parser - a port of Spark's SparkDateTimeUtils#stringToDate.
fn date_parser(date_str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>> {
    // local functions
    fn get_trimmed_start(bytes: &[u8]) -> usize {
        let mut start = 0;
        while start < bytes.len() && is_whitespace_or_iso_control(bytes[start]) {
            start += 1;
        }
        start
    }

    fn get_trimmed_end(start: usize, bytes: &[u8]) -> usize {
        let mut end = bytes.len() - 1;
        while end > start && is_whitespace_or_iso_control(bytes[end]) {
            end -= 1;
        }
        end + 1
    }

    fn is_whitespace_or_iso_control(byte: u8) -> bool {
        byte.is_ascii_whitespace() || byte.is_ascii_control()
    }

    fn is_valid_digits(segment: i32, digits: usize) -> bool {
        // An integer is able to represent a date within [+-]5 million years.
        let max_digits_year = 7;
        // year (segment 0) can be between 4 and 7 digits,
        // month and day (segments 1 and 2) can be between 1 and 2 digits
        (segment == 0 && digits >= 4 && digits <= max_digits_year)
            || (segment != 0 && digits > 0 && digits <= 2)
    }

    fn return_result(date_str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>> {
        if eval_mode == EvalMode::Ansi {
            Err(SparkError::CastInvalidValue {
                value: date_str.to_string(),
                from_type: "STRING".to_string(),
                to_type: "DATE".to_string(),
            })
        } else {
            Ok(None)
        }
    }
    // end local functions

    if date_str.is_empty() {
        return return_result(date_str, eval_mode);
    }

    // values of date segments year, month and day, defaulting to 1
    let mut date_segments = [1, 1, 1];
    let mut sign = 1;
    let mut current_segment = 0;
    let mut current_segment_value = Wrapping(0);
    let mut current_segment_digits = 0;
    let bytes = date_str.as_bytes();

    let mut j = get_trimmed_start(bytes);
    let str_end_trimmed = get_trimmed_end(j, bytes);

    if j == str_end_trimmed {
        return return_result(date_str, eval_mode);
    }

    // assign a sign to the date
    if bytes[j] == b'-' || bytes[j] == b'+' {
        sign = if bytes[j] == b'-' { -1 } else { 1 };
        j += 1;
    }

    // loop to the end of string until we have processed 3 segments,
    // exit loop on encountering any space ' ' or 'T' after the 3rd segment
    while j < str_end_trimmed && (current_segment < 3 && !(bytes[j] == b' ' || bytes[j] == b'T')) {
        let b = bytes[j];
        if current_segment < 2 && b == b'-' {
            // check for validity of year and month segments if current byte is separator
            if !is_valid_digits(current_segment, current_segment_digits) {
                return return_result(date_str, eval_mode);
            }
            // if valid update corresponding segment with the current segment value.
            date_segments[current_segment as usize] = current_segment_value.0;
            current_segment_value = Wrapping(0);
            current_segment_digits = 0;
            current_segment += 1;
        } else if !b.is_ascii_digit() {
            return return_result(date_str, eval_mode);
        } else {
            // increment value of current segment by the next digit
            let parsed_value = Wrapping((b - b'0') as i32);
            current_segment_value = current_segment_value * Wrapping(10) + parsed_value;
            current_segment_digits += 1;
        }
        j += 1;
    }

    // check for validity of last segment
    if !is_valid_digits(current_segment, current_segment_digits) {
        return return_result(date_str, eval_mode);
    }

    if current_segment < 2 && j < str_end_trimmed {
        // For the `yyyy` and `yyyy-[m]m` formats, entire input must be consumed.
        return return_result(date_str, eval_mode);
    }

    date_segments[current_segment as usize] = current_segment_value.0;

    match NaiveDate::from_ymd_opt(
        sign * date_segments[0],
        date_segments[1] as u32,
        date_segments[2] as u32,
    ) {
        Some(date) => {
            let duration_since_epoch = date
                .signed_duration_since(NaiveDateTime::UNIX_EPOCH.date())
                .num_days();
            Ok(Some(duration_since_epoch.to_i32().unwrap()))
        }
        None => Ok(None),
    }
}

/// This takes care of special casting cases of Spark, e.g., Timestamp to Long.
/// This function runs as a post process of the DataFusion cast(). By the time it arrives here,
/// Dictionary arrays are already unpacked by the DataFusion cast() since Spark cannot specify
/// Dictionary as to_type. The from_type is taken before the DataFusion cast() runs in
/// expressions/cast.rs, so it can still be Dictionary.
fn spark_cast_postprocess(array: ArrayRef, from_type: &DataType, to_type: &DataType) -> ArrayRef {
    match (from_type, to_type) {
        (DataType::Timestamp(_, _), DataType::Int64) => {
            // See Spark's `Cast` expression
            unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, MICROS_PER_SECOND)).unwrap()
        }
        (DataType::Dictionary(_, value_type), DataType::Int64)
            if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) =>
        {
            // See Spark's `Cast` expression
            unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, MICROS_PER_SECOND)).unwrap()
        }
        (DataType::Timestamp(_, _), DataType::Utf8) => remove_trailing_zeroes(array),
        (DataType::Dictionary(_, value_type), DataType::Utf8)
            if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) =>
        {
            remove_trailing_zeroes(array)
        }
        _ => array,
    }
}

/// A forked and modified version of Arrow's `unary_dyn`, which is being deprecated
fn unary_dyn<F, T>(array: &ArrayRef, op: F) -> Result<ArrayRef, ArrowError>
where
    T: ArrowPrimitiveType,
    F: Fn(T::Native) -> T::Native,
{
    if let Some(d) = array.as_any_dictionary_opt() {
        let new_values = unary_dyn::<F, T>(d.values(), op)?;
        return Ok(Arc::new(d.with_values(Arc::new(new_values))));
    }

    match array.as_primitive_opt::<T>() {
        Some(a) if PrimitiveArray::<T>::is_compatible(a.data_type()) => {
            Ok(Arc::new(unary::<T, F, T>(
                array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
                op,
            )))
        }
        _ => Err(ArrowError::NotYetImplemented(format!(
            "Cannot perform unary operation of type {} on array of type {}",
            T::DATA_TYPE,
            array.data_type()
        ))),
    }
}

/// Remove any trailing zeroes in the string if they occur in the fractional seconds,
/// to match Spark behavior
/// example:
/// "1970-01-01 05:29:59.900" => "1970-01-01 05:29:59.9"
/// "1970-01-01 05:29:59.990" => "1970-01-01 05:29:59.99"
/// "1970-01-01 05:29:59.999" => "1970-01-01 05:29:59.999"
/// "1970-01-01 05:30:00"     => "1970-01-01 05:30:00"
/// "1970-01-01 05:30:00.001" => "1970-01-01 05:30:00.001"
fn remove_trailing_zeroes(array: ArrayRef) -> ArrayRef {
    let string_array = as_generic_string_array::<i32>(&array).unwrap();
    let result = string_array
        .iter()
        .map(|s| s.map(trim_end))
        .collect::<GenericStringArray<i32>>();
    Arc::new(result) as ArrayRef
}

fn trim_end(s: &str) -> &str {
    if s.rfind('.').is_some() {
        s.trim_end_matches('0')
    } else {
        s
    }
}

#[cfg(test)]
mod tests {
    use arrow::datatypes::TimestampMicrosecondType;
    use arrow_array::StringArray;
    use arrow_schema::{Field, Fields, TimeUnit};
    use std::str::FromStr;

    use super::*;

    #[test]
    #[cfg_attr(miri, ignore)] // test takes too long with miri
    fn timestamp_parser_test() {
        let tz = &timezone::Tz::from_str("UTC").unwrap();
        // test all supported formats
        assert_eq!(
            timestamp_parser("2020", EvalMode::Legacy, tz).unwrap(),
            Some(1577836800000000) // this is in microseconds
        );
        assert_eq!(
            timestamp_parser("2020-01", EvalMode::Legacy, tz).unwrap(),
            Some(1577836800000000)
        );
        assert_eq!(
            timestamp_parser("2020-01-01", EvalMode::Legacy, tz).unwrap(),
            Some(1577836800000000)
        );
        assert_eq!(
            timestamp_parser("2020-01-01T12", EvalMode::Legacy, tz).unwrap(),
            Some(1577880000000000)
        );
        assert_eq!(
            timestamp_parser("2020-01-01T12:34", EvalMode::Legacy, tz).unwrap(),
            Some(1577882040000000)
        );
        assert_eq!(
            timestamp_parser("2020-01-01T12:34:56", EvalMode::Legacy, tz).unwrap(),
            Some(1577882096000000)
        );
        assert_eq!(
            timestamp_parser("2020-01-01T12:34:56.123456", EvalMode::Legacy, tz).unwrap(),
            Some(1577882096123456)
        );
        assert_eq!(
            timestamp_parser("0100", EvalMode::Legacy, tz).unwrap(),
            Some(-59011459200000000)
        );
        assert_eq!(
            timestamp_parser("0100-01", EvalMode::Legacy, tz).unwrap(),
            Some(-59011459200000000)
        );
        assert_eq!(
            timestamp_parser("0100-01-01", EvalMode::Legacy, tz).unwrap(),
            Some(-59011459200000000)
        );
        assert_eq!(
            timestamp_parser("0100-01-01T12", EvalMode::Legacy, tz).unwrap(),
            Some(-59011416000000000)
        );
        assert_eq!(
            timestamp_parser("0100-01-01T12:34", EvalMode::Legacy, tz).unwrap(),
            Some(-59011413960000000)
        );
        assert_eq!(
            timestamp_parser("0100-01-01T12:34:56", EvalMode::Legacy, tz).unwrap(),
            Some(-59011413904000000)
        );
        assert_eq!(
            timestamp_parser("0100-01-01T12:34:56.123456", EvalMode::Legacy, tz).unwrap(),
            Some(-59011413903876544)
        );
        assert_eq!(
            timestamp_parser("10000", EvalMode::Legacy, tz).unwrap(),
            Some(253402300800000000)
        );
        assert_eq!(
            timestamp_parser("10000-01", EvalMode::Legacy, tz).unwrap(),
            Some(253402300800000000)
        );
        assert_eq!(
            timestamp_parser("10000-01-01", EvalMode::Legacy, tz).unwrap(),
            Some(253402300800000000)
        );
        assert_eq!(
            timestamp_parser("10000-01-01T12", EvalMode::Legacy, tz).unwrap(),
            Some(253402344000000000)
        );
        assert_eq!(
            timestamp_parser("10000-01-01T12:34", EvalMode::Legacy, tz).unwrap(),
            Some(253402346040000000)
        );
        assert_eq!(
            timestamp_parser("10000-01-01T12:34:56", EvalMode::Legacy, tz).unwrap(),
            Some(253402346096000000)
        );
        assert_eq!(
            timestamp_parser("10000-01-01T12:34:56.123456", EvalMode::Legacy, tz).unwrap(),
            Some(253402346096123456)
        );
        // assert_eq!(
        //     timestamp_parser("T2",  EvalMode::Legacy).unwrap(),
        //     Some(1714356000000000) // this value needs to change every day.
        // );
    }
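
    #[test]
    #[cfg_attr(miri, ignore)] // test takes too long with miri
    fn timestamp_parser_invalid_input_test() {
        // A minimal sketch of the error paths above (test name is ours): inputs
        // matching none of the supported patterns yield None in legacy mode and
        // an error in ANSI mode, while all-whitespace input is None in any mode.
        let tz = &timezone::Tz::from_str("UTC").unwrap();
        assert_eq!(
            timestamp_parser("not-a-timestamp", EvalMode::Legacy, tz).unwrap(),
            None
        );
        assert!(timestamp_parser("not-a-timestamp", EvalMode::Ansi, tz).is_err());
        assert_eq!(timestamp_parser("  ", EvalMode::Ansi, tz).unwrap(), None);
    }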

    #[test]
    #[cfg_attr(miri, ignore)] // test takes too long with miri
    fn test_cast_string_to_timestamp() {
        let array: ArrayRef = Arc::new(StringArray::from(vec![
            Some("2020-01-01T12:34:56.123456"),
            Some("T2"),
            Some("0100-01-01T12:34:56.123456"),
            Some("10000-01-01T12:34:56.123456"),
        ]));
        let tz = &timezone::Tz::from_str("UTC").unwrap();

        let string_array = array
            .as_any()
            .downcast_ref::<GenericStringArray<i32>>()
            .expect("Expected a string array");

        let eval_mode = EvalMode::Legacy;
        let result = cast_utf8_to_timestamp!(
            &string_array,
            eval_mode,
            TimestampMicrosecondType,
            timestamp_parser,
            tz
        );

        assert_eq!(
            result.data_type(),
            &DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into()))
        );
        assert_eq!(result.len(), 4);
    }
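
    #[test]
    fn test_trim_end() {
        // A minimal sketch of the trailing-zero trimming applied when casting
        // timestamps to strings (test name is ours): zeroes are only stripped
        // when the value contains a fractional part.
        assert_eq!(trim_end("1970-01-01 05:29:59.900"), "1970-01-01 05:29:59.9");
        assert_eq!(trim_end("1970-01-01 05:29:59.999"), "1970-01-01 05:29:59.999");
        assert_eq!(trim_end("1970-01-01 05:30:00"), "1970-01-01 05:30:00");
        assert_eq!(trim_end("1970-01-01 05:30:00.001"), "1970-01-01 05:30:00.001");
    }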

    #[test]
    fn test_cast_dict_string_to_timestamp() -> DataFusionResult<()> {
        // prepare input data
        let keys = Int32Array::from(vec![0, 1]);
        let values: ArrayRef = Arc::new(StringArray::from(vec![
            Some("2020-01-01T12:34:56.123456"),
            Some("T2"),
        ]));
        let dict_array = Arc::new(DictionaryArray::new(keys, values));

        let timezone = "UTC".to_string();
        // test casting string dictionary array to timestamp array
        let cast_options = SparkCastOptions::new(EvalMode::Legacy, &timezone, false);
        let result = cast_array(
            dict_array,
            &DataType::Timestamp(TimeUnit::Microsecond, Some(timezone.clone().into())),
            &cast_options,
        )?;
        assert_eq!(
            *result.data_type(),
            DataType::Timestamp(TimeUnit::Microsecond, Some(timezone.into()))
        );
        assert_eq!(result.len(), 2);

        Ok(())
    }

    #[test]
    fn date_parser_test() {
        for date in &[
            "2020",
            "2020-01",
            "2020-01-01",
            "02020-01-01",
            "002020-01-01",
            "0002020-01-01",
            "2020-1-1",
            "2020-01-01 ",
            "2020-01-01T",
        ] {
            for eval_mode in &[EvalMode::Legacy, EvalMode::Ansi, EvalMode::Try] {
                assert_eq!(date_parser(date, *eval_mode).unwrap(), Some(18262));
            }
        }

        // dates in invalid formats
        for date in &[
            "abc",
            "",
            "not_a_date",
            "3/",
            "3/12",
            "3/12/2020",
            "3/12/2002 T",
            "202",
            "2020-010-01",
            "2020-10-010",
            "2020-10-010T",
            "--262143-12-31",
            "--262143-12-31 ",
        ] {
            for eval_mode in &[EvalMode::Legacy, EvalMode::Try] {
                assert_eq!(date_parser(date, *eval_mode).unwrap(), None);
            }
            assert!(date_parser(date, EvalMode::Ansi).is_err());
        }

        for date in &["-3638-5"] {
            for eval_mode in &[EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] {
                assert_eq!(date_parser(date, *eval_mode).unwrap(), Some(-2048160));
            }
        }

        // NaiveDate only supports years from 262143 BC to 262142 AD;
        // date_parser returns None for dates out of the range supported by NaiveDate.
        for date in &[
            "-262144-1-1",
            "262143-01-1",
            "262143-1-1",
            "262143-01-1 ",
            "262143-01-01T ",
            "262143-1-01T 1234",
            "-0973250",
        ] {
            for eval_mode in &[EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] {
                assert_eq!(date_parser(date, *eval_mode).unwrap(), None);
            }
        }
    }
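
    #[test]
    fn date_parser_ignores_time_part_test() {
        // A small sketch of the early-exit behavior in date_parser (test name is
        // ours): once three segments are consumed, anything after a ' ' or 'T'
        // separator is ignored, so a full timestamp string parses as its date.
        for date in &["2020-01-01 12:34:56", "2020-01-01T12:34:56"] {
            for eval_mode in &[EvalMode::Legacy, EvalMode::Ansi, EvalMode::Try] {
                assert_eq!(date_parser(date, *eval_mode).unwrap(), Some(18262));
            }
        }
    }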

    #[test]
    fn test_cast_string_to_date() {
        let array: ArrayRef = Arc::new(StringArray::from(vec![
            Some("2020"),
            Some("2020-01"),
            Some("2020-01-01"),
            Some("2020-01-01T"),
        ]));

        let result = cast_string_to_date(&array, &DataType::Date32, EvalMode::Legacy).unwrap();

        let date32_array = result
            .as_any()
            .downcast_ref::<arrow::array::Date32Array>()
            .unwrap();
        assert_eq!(date32_array.len(), 4);
        date32_array
            .iter()
            .for_each(|v| assert_eq!(v.unwrap(), 18262));
    }

    #[test]
    fn test_cast_string_array_with_valid_dates() {
        let array_with_valid_dates: ArrayRef = Arc::new(StringArray::from(vec![
            Some("-262143-12-31"),
            Some("\n -262143-12-31 "),
            Some("-262143-12-31T \t\n"),
            Some("\n\t-262143-12-31T\r"),
            Some("-262143-12-31T 123123123"),
            Some("\r\n-262143-12-31T \r123123123"),
            Some("\n -262143-12-31T \n\t"),
        ]));

        for eval_mode in &[EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] {
            let result =
                cast_string_to_date(&array_with_valid_dates, &DataType::Date32, *eval_mode)
                    .unwrap();

            let date32_array = result
                .as_any()
                .downcast_ref::<arrow::array::Date32Array>()
                .unwrap();
            assert_eq!(result.len(), 7);
            date32_array
                .iter()
                .for_each(|v| assert_eq!(v.unwrap(), -96464928));
        }
    }

    #[test]
    fn test_cast_string_array_with_invalid_dates() {
        let array_with_invalid_date: ArrayRef = Arc::new(StringArray::from(vec![
            Some("2020"),
            Some("2020-01"),
            Some("2020-01-01"),
            // 4 invalid dates
            Some("2020-010-01T"),
            Some("202"),
            Some(" 202 "),
            Some("\n 2020-\r8 "),
            Some("2020-01-01T"),
            // Overflows i32
            Some("-4607172990231812908"),
        ]));

        for eval_mode in &[EvalMode::Legacy, EvalMode::Try] {
            let result =
                cast_string_to_date(&array_with_invalid_date, &DataType::Date32, *eval_mode)
                    .unwrap();

            let date32_array = result
                .as_any()
                .downcast_ref::<arrow::array::Date32Array>()
                .unwrap();
            assert_eq!(
                date32_array.iter().collect::<Vec<_>>(),
                vec![
                    Some(18262),
                    Some(18262),
                    Some(18262),
                    None,
                    None,
                    None,
                    None,
                    Some(18262),
                    None
                ]
            );
        }

        let result =
            cast_string_to_date(&array_with_invalid_date, &DataType::Date32, EvalMode::Ansi);
        match result {
            Err(e) => assert!(
                e.to_string().contains(
                    "[CAST_INVALID_INPUT] The value '2020-010-01T' of the type \"STRING\" cannot be cast to \"DATE\" because it is malformed")
            ),
            _ => panic!("Expected error"),
        }
    }

    #[test]
    fn test_cast_string_as_i8() {
        // basic
        assert_eq!(
            cast_string_to_i8("127", EvalMode::Legacy).unwrap(),
            Some(127_i8)
        );
        assert_eq!(cast_string_to_i8("128", EvalMode::Legacy).unwrap(), None);
        assert!(cast_string_to_i8("128", EvalMode::Ansi).is_err());
        // decimals
        assert_eq!(
            cast_string_to_i8("0.2", EvalMode::Legacy).unwrap(),
            Some(0_i8)
        );
        assert_eq!(
            cast_string_to_i8(".", EvalMode::Legacy).unwrap(),
            Some(0_i8)
        );
        // TRY should always return null for decimals
        assert_eq!(cast_string_to_i8("0.2", EvalMode::Try).unwrap(), None);
        assert_eq!(cast_string_to_i8(".", EvalMode::Try).unwrap(), None);
        // ANSI mode should throw error on decimal
        assert!(cast_string_to_i8("0.2", EvalMode::Ansi).is_err());
        assert!(cast_string_to_i8(".", EvalMode::Ansi).is_err());
    }
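
    #[test]
    fn test_cast_string_as_i32() {
        // A minimal sketch of the overflow handling in do_cast_string_to_int
        // (test name is ours): digits accumulate in the negative domain, so
        // i32::MIN parses exactly while its positive counterpart fails the
        // final checked negation.
        assert_eq!(
            cast_string_to_i32("2147483647", EvalMode::Legacy).unwrap(),
            Some(i32::MAX)
        );
        assert_eq!(
            cast_string_to_i32("-2147483648", EvalMode::Legacy).unwrap(),
            Some(i32::MIN)
        );
        assert_eq!(
            cast_string_to_i32("2147483648", EvalMode::Legacy).unwrap(),
            None
        );
        // legacy mode truncates the fractional part; ANSI rejects it
        assert_eq!(cast_string_to_i32("3.7", EvalMode::Legacy).unwrap(), Some(3));
        assert!(cast_string_to_i32("3.7", EvalMode::Ansi).is_err());
        // a bare sign is never a valid integer
        assert_eq!(cast_string_to_i32("+", EvalMode::Legacy).unwrap(), None);
    }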

    #[test]
    fn test_cast_unsupported_timestamp_to_date() {
        // Since DataFusion uses chrono::DateTime internally, not all dates representable by
        // TimestampMicrosecondType are supported
        let timestamps: PrimitiveArray<TimestampMicrosecondType> = vec![i64::MAX].into();
        let cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
        let result = cast_array(
            Arc::new(timestamps.with_timezone("Europe/Copenhagen")),
            &DataType::Date32,
            &cast_options,
        );
        assert!(result.is_err())
    }

    #[test]
    fn test_cast_invalid_timezone() {
        let timestamps: PrimitiveArray<TimestampMicrosecondType> = vec![i64::MAX].into();
        let cast_options = SparkCastOptions::new(EvalMode::Legacy, "Not a valid timezone", false);
        let result = cast_array(
            Arc::new(timestamps.with_timezone("Europe/Copenhagen")),
            &DataType::Date32,
            &cast_options,
        );
        assert!(result.is_err())
    }
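
    #[test]
    fn test_unary_dyn_floor_division() {
        // A minimal sketch of the floor-division semantics that
        // spark_cast_postprocess relies on when casting timestamps to LONG
        // (test name is ours): negative microsecond values round toward
        // negative infinity rather than toward zero.
        let arr: ArrayRef = Arc::new(Int64Array::from(vec![1_500_000_i64, -1_500_000]));
        let result = unary_dyn::<_, Int64Type>(&arr, |v| div_floor(v, MICROS_PER_SECOND)).unwrap();
        let result = result.as_primitive::<Int64Type>();
        assert_eq!(result.value(0), 1);
        assert_eq!(result.value(1), -2);
    }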

    #[test]
    fn test_cast_struct_to_utf8() {
        let a: ArrayRef = Arc::new(Int32Array::from(vec![
            Some(1),
            Some(2),
            None,
            Some(4),
            Some(5),
        ]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
        let c: ArrayRef = Arc::new(StructArray::from(vec![
            (Arc::new(Field::new("a", DataType::Int32, true)), a),
            (Arc::new(Field::new("b", DataType::Utf8, true)), b),
        ]));
        let string_array = cast_array(
            c,
            &DataType::Utf8,
            &SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
        )
        .unwrap();
        let string_array = string_array.as_string::<i32>();
        assert_eq!(5, string_array.len());
        assert_eq!(r#"{1, a}"#, string_array.value(0));
        assert_eq!(r#"{2, b}"#, string_array.value(1));
        assert_eq!(r#"{null, c}"#, string_array.value(2));
        assert_eq!(r#"{4, d}"#, string_array.value(3));
        assert_eq!(r#"{5, e}"#, string_array.value(4));
    }

    #[test]
    fn test_cast_struct_to_struct() {
        let a: ArrayRef = Arc::new(Int32Array::from(vec![
            Some(1),
            Some(2),
            None,
            Some(4),
            Some(5),
        ]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
        let c: ArrayRef = Arc::new(StructArray::from(vec![
            (Arc::new(Field::new("a", DataType::Int32, true)), a),
            (Arc::new(Field::new("b", DataType::Utf8, true)), b),
        ]));
        // change type of "a" from Int32 to Utf8
        let fields = Fields::from(vec![
            Field::new("a", DataType::Utf8, true),
            Field::new("b", DataType::Utf8, true),
        ]);
        let cast_array = spark_cast(
            ColumnarValue::Array(c),
            &DataType::Struct(fields),
            &SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
        )
        .unwrap();
        if let ColumnarValue::Array(cast_array) = cast_array {
            assert_eq!(5, cast_array.len());
            let a = cast_array.as_struct().column(0).as_string::<i32>();
            assert_eq!("1", a.value(0));
        } else {
            unreachable!()
        }
    }

    #[test]
    fn test_cast_struct_to_struct_drop_column() {
        let a: ArrayRef = Arc::new(Int32Array::from(vec![
            Some(1),
            Some(2),
            None,
            Some(4),
            Some(5),
        ]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
        let c: ArrayRef = Arc::new(StructArray::from(vec![
            (Arc::new(Field::new("a", DataType::Int32, true)), a),
            (Arc::new(Field::new("b", DataType::Utf8, true)), b),
        ]));
        // change type of "a" from Int32 to Utf8 and drop "b"
        let fields = Fields::from(vec![Field::new("a", DataType::Utf8, true)]);
        let cast_array = spark_cast(
            ColumnarValue::Array(c),
            &DataType::Struct(fields),
            &SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
        )
        .unwrap();
        if let ColumnarValue::Array(cast_array) = cast_array {
            assert_eq!(5, cast_array.len());
            let struct_array = cast_array.as_struct();
            assert_eq!(1, struct_array.columns().len());
            let a = struct_array.column(0).as_string::<i32>();
            assert_eq!("1", a.value(0));
        } else {
            unreachable!()
        }
    }
}