datafusion_comet_spark_expr/conversion_funcs/cast.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::timezone;
use crate::utils::array_with_timezone;
use crate::{EvalMode, SparkError, SparkResult};
use arrow::array::builder::StringBuilder;
use arrow::array::{DictionaryArray, StringArray, StructArray};
use arrow::compute::can_cast_types;
use arrow::datatypes::{ArrowDictionaryKeyType, ArrowNativeType, DataType, Schema};
use arrow::{
    array::{
        cast::AsArray,
        types::{Date32Type, Int16Type, Int32Type, Int8Type},
        Array, ArrayRef, BooleanArray, Decimal128Array, Float32Array, Float64Array,
        GenericStringArray, Int16Array, Int32Array, Int64Array, Int8Array, OffsetSizeTrait,
        PrimitiveArray,
    },
    compute::{cast_with_options, take, unary, CastOptions},
    datatypes::{
        is_validate_decimal_precision, ArrowPrimitiveType, Decimal128Type, Float32Type,
        Float64Type, Int64Type, TimestampMicrosecondType,
    },
    error::ArrowError,
    record_batch::RecordBatch,
    util::display::FormatOptions,
};
use chrono::{DateTime, NaiveDate, TimeZone, Timelike};
use datafusion::common::{
    cast::as_generic_string_array, internal_err, DataFusionError, Result as DataFusionResult,
    ScalarValue,
};
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_plan::ColumnarValue;
use num::{
    cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, Num,
    ToPrimitive,
};
use regex::Regex;
use std::str::FromStr;
use std::{
    any::Any,
    fmt::{Debug, Display, Formatter},
    hash::Hash,
    num::Wrapping,
    sync::Arc,
};

static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f");

const MICROS_PER_SECOND: i64 = 1_000_000;

static CAST_OPTIONS: CastOptions = CastOptions {
    safe: true,
    format_options: FormatOptions::new()
        .with_timestamp_tz_format(TIMESTAMP_FORMAT)
        .with_timestamp_format(TIMESTAMP_FORMAT),
};

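/// Broken-down timestamp components used while parsing timestamp strings.
/// The default corresponds to 0001-01-01T00:00:00.000000.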
struct TimeStampInfo {
    year: i32,
    month: u32,
    day: u32,
    hour: u32,
    minute: u32,
    second: u32,
    microsecond: u32,
}

impl Default for TimeStampInfo {
    fn default() -> Self {
        TimeStampInfo {
            year: 1,
            month: 1,
            day: 1,
            hour: 0,
            minute: 0,
            second: 0,
            microsecond: 0,
        }
    }
}

impl TimeStampInfo {
    pub fn with_year(&mut self, year: i32) -> &mut Self {
        self.year = year;
        self
    }

    pub fn with_month(&mut self, month: u32) -> &mut Self {
        self.month = month;
        self
    }

    pub fn with_day(&mut self, day: u32) -> &mut Self {
        self.day = day;
        self
    }

    pub fn with_hour(&mut self, hour: u32) -> &mut Self {
        self.hour = hour;
        self
    }

    pub fn with_minute(&mut self, minute: u32) -> &mut Self {
        self.minute = minute;
        self
    }

    pub fn with_second(&mut self, second: u32) -> &mut Self {
        self.second = second;
        self
    }

    pub fn with_microsecond(&mut self, microsecond: u32) -> &mut Self {
        self.microsecond = microsecond;
        self
    }
}

#[derive(Debug, Eq)]
pub struct Cast {
    pub child: Arc<dyn PhysicalExpr>,
    pub data_type: DataType,
    pub cast_options: SparkCastOptions,
}

impl PartialEq for Cast {
    fn eq(&self, other: &Self) -> bool {
        self.child.eq(&other.child)
            && self.data_type.eq(&other.data_type)
            && self.cast_options.eq(&other.cast_options)
    }
}

impl Hash for Cast {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.child.hash(state);
        self.data_type.hash(state);
        self.cast_options.hash(state);
    }
}

/// Determine if Comet supports a cast, taking options such as EvalMode and Timezone into account.
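///
/// For example (an illustrative sketch, assuming a UTC session timezone):
///
/// ```ignore
/// let options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
/// assert!(cast_supported(&DataType::Int32, &DataType::Int64, &options));
/// assert!(!cast_supported(&DataType::Int32, &DataType::Binary, &options));
/// ```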
pub fn cast_supported(
    from_type: &DataType,
    to_type: &DataType,
    options: &SparkCastOptions,
) -> bool {
    use DataType::*;

    let from_type = if let Dictionary(_, dt) = from_type {
        dt
    } else {
        from_type
    };

    let to_type = if let Dictionary(_, dt) = to_type {
        dt
    } else {
        to_type
    };

    if from_type == to_type {
        return true;
    }

    match (from_type, to_type) {
        (Boolean, _) => can_cast_from_boolean(to_type, options),
        (UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
            if options.allow_cast_unsigned_ints =>
        {
            true
        }
        (Int8, _) => can_cast_from_byte(to_type, options),
        (Int16, _) => can_cast_from_short(to_type, options),
        (Int32, _) => can_cast_from_int(to_type, options),
        (Int64, _) => can_cast_from_long(to_type, options),
        (Float32, _) => can_cast_from_float(to_type, options),
        (Float64, _) => can_cast_from_double(to_type, options),
        (Decimal128(p, s), _) => can_cast_from_decimal(p, s, to_type, options),
        (Timestamp(_, None), _) => can_cast_from_timestamp_ntz(to_type, options),
        (Timestamp(_, Some(_)), _) => can_cast_from_timestamp(to_type, options),
        (Utf8 | LargeUtf8, _) => can_cast_from_string(to_type, options),
        (_, Utf8 | LargeUtf8) => can_cast_to_string(from_type, options),
        (Struct(from_fields), Struct(to_fields)) => from_fields
            .iter()
            .zip(to_fields.iter())
            .all(|(a, b)| cast_supported(a.data_type(), b.data_type(), options)),
        _ => false,
    }
}

fn can_cast_from_string(to_type: &DataType, options: &SparkCastOptions) -> bool {
    use DataType::*;
    match to_type {
        Boolean | Int8 | Int16 | Int32 | Int64 | Binary => true,
        Float32 | Float64 => {
            // https://github.com/apache/datafusion-comet/issues/326
            // Does not support inputs ending with 'd' or 'f'. Does not support 'inf'.
            // Does not support ANSI mode.
            options.allow_incompat
        }
        Decimal128(_, _) => {
            // https://github.com/apache/datafusion-comet/issues/325
            // Does not support inputs ending with 'd' or 'f'. Does not support 'inf'.
            // Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits.
            options.allow_incompat
        }
        Date32 | Date64 => {
            // https://github.com/apache/datafusion-comet/issues/327
            // Only supports years between 262143 BC and 262142 AD
            options.allow_incompat
        }
        Timestamp(_, _) if options.eval_mode == EvalMode::Ansi => {
            // ANSI mode not supported
            false
        }
        Timestamp(_, Some(tz)) if tz.as_ref() != "UTC" => {
            // Cast will use UTC instead of $timeZoneId
            options.allow_incompat
        }
        Timestamp(_, _) => {
            // https://github.com/apache/datafusion-comet/issues/328
            // Not all valid formats are supported
            options.allow_incompat
        }
        _ => false,
    }
}

fn can_cast_to_string(from_type: &DataType, options: &SparkCastOptions) -> bool {
    use DataType::*;
    match from_type {
        Boolean | Int8 | Int16 | Int32 | Int64 | Date32 | Date64 | Timestamp(_, _) => true,
        Float32 | Float64 => {
            // There can be differences in precision. For example, the input "1.4E-45"
            // will produce 1.0E-45 instead of 1.4E-45
            true
        }
        Decimal128(_, _) => {
            // https://github.com/apache/datafusion-comet/issues/1068
            // There can be formatting differences in some cases due to Spark using
            // scientific notation where Comet does not
            true
        }
        Binary => {
            // https://github.com/apache/datafusion-comet/issues/377
            // Only works for binary data representing valid UTF-8 strings
            options.allow_incompat
        }
        Struct(fields) => fields
            .iter()
            .all(|f| can_cast_to_string(f.data_type(), options)),
        _ => false,
    }
}

fn can_cast_from_timestamp_ntz(to_type: &DataType, options: &SparkCastOptions) -> bool {
    use DataType::*;
    match to_type {
        Timestamp(_, _) | Date32 | Date64 | Utf8 => {
            // incompatible
            options.allow_incompat
        }
        _ => {
            // unsupported
            false
        }
    }
}

fn can_cast_from_timestamp(to_type: &DataType, _options: &SparkCastOptions) -> bool {
    use DataType::*;
    match to_type {
        Boolean | Int8 | Int16 => {
            // https://github.com/apache/datafusion-comet/issues/352
            // this seems like an edge case that isn't important for us to support
            false
        }
        Int64 => {
            // https://github.com/apache/datafusion-comet/issues/352
            true
        }
        Date32 | Date64 | Utf8 | Decimal128(_, _) => true,
        _ => {
            // unsupported
            false
        }
    }
}

fn can_cast_from_boolean(to_type: &DataType, _: &SparkCastOptions) -> bool {
    use DataType::*;
    matches!(to_type, Int8 | Int16 | Int32 | Int64 | Float32 | Float64)
}

fn can_cast_from_byte(to_type: &DataType, _: &SparkCastOptions) -> bool {
    use DataType::*;
    matches!(
        to_type,
        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
    )
}

fn can_cast_from_short(to_type: &DataType, _: &SparkCastOptions) -> bool {
    use DataType::*;
    matches!(
        to_type,
        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
    )
}

fn can_cast_from_int(to_type: &DataType, options: &SparkCastOptions) -> bool {
    use DataType::*;
    match to_type {
        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Utf8 => true,
        Decimal128(_, _) => {
            // incompatible: no overflow check
            options.allow_incompat
        }
        _ => false,
    }
}

fn can_cast_from_long(to_type: &DataType, options: &SparkCastOptions) -> bool {
    use DataType::*;
    match to_type {
        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 => true,
        Decimal128(_, _) => {
            // incompatible: no overflow check
            options.allow_incompat
        }
        _ => false,
    }
}

fn can_cast_from_float(to_type: &DataType, _: &SparkCastOptions) -> bool {
    use DataType::*;
    matches!(
        to_type,
        Boolean | Int8 | Int16 | Int32 | Int64 | Float64 | Decimal128(_, _)
    )
}

fn can_cast_from_double(to_type: &DataType, _: &SparkCastOptions) -> bool {
    use DataType::*;
    matches!(
        to_type,
        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Decimal128(_, _)
    )
}

fn can_cast_from_decimal(
    p1: &u8,
    _s1: &i8,
    to_type: &DataType,
    options: &SparkCastOptions,
) -> bool {
    use DataType::*;
    match to_type {
        Int8 | Int16 | Int32 | Int64 | Float32 | Float64 => true,
        Decimal128(p2, _) => {
            if p2 < p1 {
                // https://github.com/apache/datafusion/issues/13492
                // Incompatible(Some("Casting to smaller precision is not supported"))
                options.allow_incompat
            } else {
                true
            }
        }
        _ => false,
    }
}

macro_rules! cast_utf8_to_int {
    ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{
        let len = $array.len();
        let mut cast_array = PrimitiveArray::<$array_type>::builder(len);
        for i in 0..len {
            if $array.is_null(i) {
                cast_array.append_null()
            } else if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? {
                cast_array.append_value(cast_value);
            } else {
                cast_array.append_null()
            }
        }
        let result: SparkResult<ArrayRef> = Ok(Arc::new(cast_array.finish()) as ArrayRef);
        result
    }};
}

macro_rules! cast_utf8_to_timestamp {
    ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident, $tz:expr) => {{
        let len = $array.len();
        let mut cast_array = PrimitiveArray::<$array_type>::builder(len).with_timezone("UTC");
        for i in 0..len {
            if $array.is_null(i) {
                cast_array.append_null()
            } else if let Ok(Some(cast_value)) =
                $cast_method($array.value(i).trim(), $eval_mode, $tz)
            {
                cast_array.append_value(cast_value);
            } else {
                cast_array.append_null()
            }
        }
        let result: ArrayRef = Arc::new(cast_array.finish()) as ArrayRef;
        result
    }};
}

macro_rules! cast_float_to_string {
    ($from:expr, $eval_mode:expr, $type:ty, $output_type:ty, $offset_type:ty) => {{

        fn cast<OffsetSize>(
            from: &dyn Array,
            _eval_mode: EvalMode,
        ) -> SparkResult<ArrayRef>
        where
            OffsetSize: OffsetSizeTrait, {
                let array = from.as_any().downcast_ref::<$output_type>().unwrap();

                // If the absolute number is less than 10,000,000 and greater than or equal
                // to 0.001, the result is expressed without scientific notation, with at
                // least one digit on either side of the decimal point. Otherwise, Spark uses
                // a mantissa followed by E and an exponent. The mantissa has an optional
                // leading minus sign followed by one digit to the left of the decimal point,
                // and the minimal number of digits greater than zero to the right. The
                // exponent has an optional leading minus sign.
                // source: https://docs.databricks.com/en/sql/language-manual/functions/cast.html
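                // For example (illustrative values, per the rules above):
                //   1234.5f32      -> "1234.5"  (plain notation, fraction != 0)
                //   0.0005f32      -> "5.0E-4"  (below 0.001 -> scientific notation)
                //   100000000.0f32 -> "1.0E8"   (at/above 1e7 -> scientific notation)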

                const LOWER_SCIENTIFIC_BOUND: $type = 0.001;
                const UPPER_SCIENTIFIC_BOUND: $type = 10000000.0;

                let output_array = array
                    .iter()
                    .map(|value| match value {
                        Some(value) if value == <$type>::INFINITY => Ok(Some("Infinity".to_string())),
                        Some(value) if value == <$type>::NEG_INFINITY => Ok(Some("-Infinity".to_string())),
                        Some(value)
                            if (value.abs() < UPPER_SCIENTIFIC_BOUND
                                && value.abs() >= LOWER_SCIENTIFIC_BOUND)
                                || value.abs() == 0.0 =>
                        {
                            let trailing_zero = if value.fract() == 0.0 { ".0" } else { "" };

                            Ok(Some(format!("{value}{trailing_zero}")))
                        }
                        Some(value)
                            if value.abs() >= UPPER_SCIENTIFIC_BOUND
                                || value.abs() < LOWER_SCIENTIFIC_BOUND =>
                        {
                            let formatted = format!("{value:E}");

                            if formatted.contains(".") {
                                Ok(Some(formatted))
                            } else {
                                // `formatted` is already in scientific notation and can be split up by E
                                // in order to add the missing trailing 0 which gets removed for numbers with a fraction of 0.0
                                let prepare_number: Vec<&str> = formatted.split("E").collect();

                                let coefficient = prepare_number[0];

                                let exponent = prepare_number[1];

                                Ok(Some(format!("{coefficient}.0E{exponent}")))
                            }
                        }
                        Some(value) => Ok(Some(value.to_string())),
                        _ => Ok(None),
                    })
                    .collect::<Result<GenericStringArray<OffsetSize>, SparkError>>()?;

                Ok(Arc::new(output_array))
            }

        cast::<$offset_type>($from, $eval_mode)
    }};
}

macro_rules! cast_int_to_int_macro {
    (
        $array: expr,
        $eval_mode:expr,
        $from_arrow_primitive_type: ty,
        $to_arrow_primitive_type: ty,
        $from_data_type: expr,
        $to_native_type: ty,
        $spark_from_data_type_name: expr,
        $spark_to_data_type_name: expr
    ) => {{
        let cast_array = $array
            .as_any()
            .downcast_ref::<PrimitiveArray<$from_arrow_primitive_type>>()
            .unwrap();
        let spark_int_literal_suffix = match $from_data_type {
            &DataType::Int64 => "L",
            &DataType::Int16 => "S",
            &DataType::Int8 => "T",
            _ => "",
        };

        let output_array = match $eval_mode {
            EvalMode::Legacy => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        Ok::<Option<$to_native_type>, SparkError>(Some(value as $to_native_type))
                    }
                    _ => Ok(None),
                })
                .collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, _>>(),
            _ => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let res = <$to_native_type>::try_from(value);
                        if res.is_err() {
                            Err(cast_overflow(
                                &(value.to_string() + spark_int_literal_suffix),
                                $spark_from_data_type_name,
                                $spark_to_data_type_name,
                            ))
                        } else {
                            Ok::<Option<$to_native_type>, SparkError>(Some(res.unwrap()))
                        }
                    }
                    _ => Ok(None),
                })
                .collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, _>>(),
        }?;
        let result: SparkResult<ArrayRef> = Ok(Arc::new(output_array) as ArrayRef);
        result
    }};
}

// When Spark casts to Byte/Short types, it does not cast directly to Byte/Short.
// It casts to Int first and then to Byte/Short. Because of potential overflows in the Int cast,
// this can cause unexpected Short/Byte cast results. Replicate this behavior.
macro_rules! cast_float_to_int16_down {
    (
        $array:expr,
        $eval_mode:expr,
        $src_array_type:ty,
        $dest_array_type:ty,
        $rust_src_type:ty,
        $rust_dest_type:ty,
        $src_type_str:expr,
        $dest_type_str:expr,
        $format_str:expr
    ) => {{
        let cast_array = $array
            .as_any()
            .downcast_ref::<$src_array_type>()
            .expect(concat!("Expected a ", stringify!($src_array_type)));

        let output_array = match $eval_mode {
            EvalMode::Ansi => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
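                        // Note: Rust float-to-int `as` casts saturate at the integer
                        // bounds, so NaN or a magnitude that saturates to i32::MAX is
                        // treated here as the overflow signal, mirroring the
                        // intermediate Int cast described above.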
                        let is_overflow = value.is_nan() || value.abs() as i32 == i32::MAX;
                        if is_overflow {
                            return Err(cast_overflow(
                                &format!($format_str, value).replace("e", "E"),
                                $src_type_str,
                                $dest_type_str,
                            ));
                        }
                        let i32_value = value as i32;
                        <$rust_dest_type>::try_from(i32_value)
                            .map_err(|_| {
                                cast_overflow(
                                    &format!($format_str, value).replace("e", "E"),
                                    $src_type_str,
                                    $dest_type_str,
                                )
                            })
                            .map(Some)
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
            _ => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let i32_value = value as i32;
                        Ok::<Option<$rust_dest_type>, SparkError>(Some(
                            i32_value as $rust_dest_type,
                        ))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
        };
        Ok(Arc::new(output_array) as ArrayRef)
    }};
}

macro_rules! cast_float_to_int32_up {
    (
        $array:expr,
        $eval_mode:expr,
        $src_array_type:ty,
        $dest_array_type:ty,
        $rust_src_type:ty,
        $rust_dest_type:ty,
        $src_type_str:expr,
        $dest_type_str:expr,
        $max_dest_val:expr,
        $format_str:expr
    ) => {{
        let cast_array = $array
            .as_any()
            .downcast_ref::<$src_array_type>()
            .expect(concat!("Expected a ", stringify!($src_array_type)));

        let output_array = match $eval_mode {
            EvalMode::Ansi => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let is_overflow =
                            value.is_nan() || value.abs() as $rust_dest_type == $max_dest_val;
                        if is_overflow {
                            return Err(cast_overflow(
                                &format!($format_str, value).replace("e", "E"),
                                $src_type_str,
                                $dest_type_str,
                            ));
                        }
                        Ok(Some(value as $rust_dest_type))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
            _ => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        Ok::<Option<$rust_dest_type>, SparkError>(Some(value as $rust_dest_type))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
        };
        Ok(Arc::new(output_array) as ArrayRef)
    }};
}

// When Spark casts to Byte/Short types, it does not cast directly to Byte/Short.
// It casts to Int first and then to Byte/Short. Because of potential overflows in the Int cast,
// this can cause unexpected Short/Byte cast results. Replicate this behavior.
macro_rules! cast_decimal_to_int16_down {
    (
        $array:expr,
        $eval_mode:expr,
        $dest_array_type:ty,
        $rust_dest_type:ty,
        $dest_type_str:expr,
        $precision:expr,
        $scale:expr
    ) => {{
        let cast_array = $array
            .as_any()
            .downcast_ref::<Decimal128Array>()
            .expect("Expected a Decimal128ArrayType");

        let output_array = match $eval_mode {
            EvalMode::Ansi => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let divisor = 10_i128.pow($scale as u32);
                        let (truncated, decimal) = (value / divisor, (value % divisor).abs());
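                        // e.g. with scale 2, the raw value 12345 (i.e. 123.45) yields
                        // truncated = 123 and decimal = 45, rendered as "123.45BD" in
                        // the overflow message below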
                        let is_overflow = truncated.abs() > i32::MAX.into();
                        if is_overflow {
                            return Err(cast_overflow(
                                &format!("{}.{}BD", truncated, decimal),
                                &format!("DECIMAL({},{})", $precision, $scale),
                                $dest_type_str,
                            ));
                        }
                        let i32_value = truncated as i32;
                        <$rust_dest_type>::try_from(i32_value)
                            .map_err(|_| {
                                cast_overflow(
                                    &format!("{}.{}BD", truncated, decimal),
                                    &format!("DECIMAL({},{})", $precision, $scale),
                                    $dest_type_str,
                                )
                            })
                            .map(Some)
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
            _ => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let divisor = 10_i128.pow($scale as u32);
                        let i32_value = (value / divisor) as i32;
                        Ok::<Option<$rust_dest_type>, SparkError>(Some(
                            i32_value as $rust_dest_type,
                        ))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
        };
        Ok(Arc::new(output_array) as ArrayRef)
    }};
}

macro_rules! cast_decimal_to_int32_up {
    (
        $array:expr,
        $eval_mode:expr,
        $dest_array_type:ty,
        $rust_dest_type:ty,
        $dest_type_str:expr,
        $max_dest_val:expr,
        $precision:expr,
        $scale:expr
    ) => {{
        let cast_array = $array
            .as_any()
            .downcast_ref::<Decimal128Array>()
            .expect("Expected a Decimal128ArrayType");

        let output_array = match $eval_mode {
            EvalMode::Ansi => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let divisor = 10_i128.pow($scale as u32);
                        let (truncated, decimal) = (value / divisor, (value % divisor).abs());
                        let is_overflow = truncated.abs() > $max_dest_val.into();
                        if is_overflow {
                            return Err(cast_overflow(
                                &format!("{}.{}BD", truncated, decimal),
                                &format!("DECIMAL({},{})", $precision, $scale),
                                $dest_type_str,
                            ));
                        }
                        Ok(Some(truncated as $rust_dest_type))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
            _ => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let divisor = 10_i128.pow($scale as u32);
                        let truncated = value / divisor;
                        Ok::<Option<$rust_dest_type>, SparkError>(Some(
                            truncated as $rust_dest_type,
                        ))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
        };
        Ok(Arc::new(output_array) as ArrayRef)
    }};
}

impl Cast {
    pub fn new(
        child: Arc<dyn PhysicalExpr>,
        data_type: DataType,
        cast_options: SparkCastOptions,
    ) -> Self {
        Self {
            child,
            data_type,
            cast_options,
        }
    }
}

/// Spark cast options
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct SparkCastOptions {
    /// Spark evaluation mode
    pub eval_mode: EvalMode,
    /// When casting from/to timezone-related types, we need a timezone, which will be
    /// resolved with the session local timezone by an analyzer in Spark.
    // TODO we should change timezone to Tz to avoid repeated parsing
    pub timezone: String,
    /// Allow casts that are supported but not guaranteed to be 100% compatible
    pub allow_incompat: bool,
    /// Support casting unsigned ints to signed ints (used by Parquet SchemaAdapter)
    pub allow_cast_unsigned_ints: bool,
    /// We also use the cast logic for adapting Parquet schemas, so this flag is used
    /// for that use case
    pub is_adapting_schema: bool,
    /// String to use to represent null values
    pub null_string: String,
}

impl SparkCastOptions {
    pub fn new(eval_mode: EvalMode, timezone: &str, allow_incompat: bool) -> Self {
        Self {
            eval_mode,
            timezone: timezone.to_string(),
            allow_incompat,
            allow_cast_unsigned_ints: false,
            is_adapting_schema: false,
            null_string: "null".to_string(),
        }
    }

    pub fn new_without_timezone(eval_mode: EvalMode, allow_incompat: bool) -> Self {
        Self {
            eval_mode,
            timezone: "".to_string(),
            allow_incompat,
            allow_cast_unsigned_ints: false,
            is_adapting_schema: false,
            null_string: "null".to_string(),
        }
    }
}

/// Spark-compatible cast implementation. Defers to DataFusion's cast where that is known
/// to be compatible with Spark, and returns an error when a cast is requested that is
/// neither supported natively nor DataFusion-compatible.
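///
/// A minimal usage sketch (illustrative only, assuming a UTC session timezone):
///
/// ```ignore
/// let input: ArrayRef = Arc::new(StringArray::from(vec![Some("123"), None]));
/// let options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
/// let result = spark_cast(ColumnarValue::Array(input), &DataType::Int64, &options)?;
/// ```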
pub fn spark_cast(
    arg: ColumnarValue,
    data_type: &DataType,
    cast_options: &SparkCastOptions,
) -> DataFusionResult<ColumnarValue> {
    match arg {
        ColumnarValue::Array(array) => Ok(ColumnarValue::Array(cast_array(
            array,
            data_type,
            cast_options,
        )?)),
        ColumnarValue::Scalar(scalar) => {
            // Note that normally CAST(scalar) should be folded on the Spark JVM side.
            // However, in some cases, e.g. a scalar subquery, Spark will not fold it,
            // so we need to handle it here.
            let array = scalar.to_array()?;
            let scalar =
                ScalarValue::try_from_array(&cast_array(array, data_type, cast_options)?, 0)?;
            Ok(ColumnarValue::Scalar(scalar))
        }
    }
}

// copied from datafusion common scalar/mod.rs
fn dict_from_values<K: ArrowDictionaryKeyType>(
    values_array: ArrayRef,
) -> datafusion::common::Result<ArrayRef> {
    // Create a key array with one entry per value: 0..len for the non-null
    // value elements and null otherwise
    let key_array: PrimitiveArray<K> = (0..values_array.len())
        .map(|index| {
            if values_array.is_valid(index) {
                let native_index = K::Native::from_usize(index).ok_or_else(|| {
                    DataFusionError::Internal(format!(
                        "Can not create index of type {} from value {}",
                        K::DATA_TYPE,
                        index
                    ))
                })?;
                Ok(Some(native_index))
            } else {
                Ok(None)
            }
        })
        .collect::<datafusion::common::Result<Vec<_>>>()?
        .into_iter()
        .collect();

    // create a new DictionaryArray
    //
    // Note: this path could be made faster by using the ArrayData
    // APIs and skipping validation, if it ever comes up in
    // performance traces.
    let dict_array = DictionaryArray::<K>::try_new(key_array, values_array)?;
    Ok(Arc::new(dict_array))
}

fn cast_array(
    array: ArrayRef,
    to_type: &DataType,
    cast_options: &SparkCastOptions,
) -> DataFusionResult<ArrayRef> {
    use DataType::*;
    let array = array_with_timezone(array, cast_options.timezone.clone(), Some(to_type))?;
    let from_type = array.data_type().clone();

    let native_cast_options: CastOptions = CastOptions {
        safe: !matches!(cast_options.eval_mode, EvalMode::Ansi), // ANSI mode errors on invalid input instead of yielding null
        format_options: FormatOptions::new()
            .with_timestamp_tz_format(TIMESTAMP_FORMAT)
            .with_timestamp_format(TIMESTAMP_FORMAT),
    };

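    // Dictionary-encoded string/binary inputs get special handling: cast the
    // dictionary values once, then either rebuild a dictionary (when the target
    // is a dictionary type) or materialize the result with `take`.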
    let array = match &from_type {
        Dictionary(key_type, value_type)
            if key_type.as_ref() == &Int32
                && (value_type.as_ref() == &Utf8
                    || value_type.as_ref() == &LargeUtf8
                    || value_type.as_ref() == &Binary
                    || value_type.as_ref() == &LargeBinary) =>
        {
            let dict_array = array
                .as_any()
                .downcast_ref::<DictionaryArray<Int32Type>>()
                .expect("Expected a dictionary array");

            let casted_result = match to_type {
                Dictionary(_, to_value_type) => {
                    let casted_dictionary = DictionaryArray::<Int32Type>::new(
                        dict_array.keys().clone(),
                        cast_array(Arc::clone(dict_array.values()), to_value_type, cast_options)?,
                    );
                    Arc::new(casted_dictionary.clone())
                }
                _ => {
                    let casted_dictionary = DictionaryArray::<Int32Type>::new(
                        dict_array.keys().clone(),
                        cast_array(Arc::clone(dict_array.values()), to_type, cast_options)?,
                    );
                    take(casted_dictionary.values().as_ref(), dict_array.keys(), None)?
                }
            };
            return Ok(spark_cast_postprocess(casted_result, &from_type, to_type));
        }
        _ => {
            if let Dictionary(_, _) = to_type {
                let dict_array = dict_from_values::<Int32Type>(array)?;
                let casted_result = cast_array(dict_array, to_type, cast_options)?;
                return Ok(spark_cast_postprocess(casted_result, &from_type, to_type));
            } else {
                array
            }
        }
    };
    let from_type = array.data_type();
    let eval_mode = cast_options.eval_mode;

    let cast_result = match (from_type, to_type) {
        (Utf8, Boolean) => spark_cast_utf8_to_boolean::<i32>(&array, eval_mode),
        (LargeUtf8, Boolean) => spark_cast_utf8_to_boolean::<i64>(&array, eval_mode),
        (Utf8, Timestamp(_, _)) => {
            cast_string_to_timestamp(&array, to_type, eval_mode, &cast_options.timezone)
        }
        (Utf8, Date32) => cast_string_to_date(&array, to_type, eval_mode),
        (Int64, Int32)
        | (Int64, Int16)
        | (Int64, Int8)
        | (Int32, Int16)
        | (Int32, Int8)
        | (Int16, Int8)
            if eval_mode != EvalMode::Try =>
        {
            spark_cast_int_to_int(&array, eval_mode, from_type, to_type)
        }
        (Utf8, Int8 | Int16 | Int32 | Int64) => {
            cast_string_to_int::<i32>(to_type, &array, eval_mode)
        }
        (LargeUtf8, Int8 | Int16 | Int32 | Int64) => {
            cast_string_to_int::<i64>(to_type, &array, eval_mode)
        }
        (Float64, Utf8) => spark_cast_float64_to_utf8::<i32>(&array, eval_mode),
        (Float64, LargeUtf8) => spark_cast_float64_to_utf8::<i64>(&array, eval_mode),
        (Float32, Utf8) => spark_cast_float32_to_utf8::<i32>(&array, eval_mode),
        (Float32, LargeUtf8) => spark_cast_float32_to_utf8::<i64>(&array, eval_mode),
        (Float32, Decimal128(precision, scale)) => {
            cast_float32_to_decimal128(&array, *precision, *scale, eval_mode)
        }
        (Float64, Decimal128(precision, scale)) => {
            cast_float64_to_decimal128(&array, *precision, *scale, eval_mode)
        }
        (Float32, Int8)
        | (Float32, Int16)
        | (Float32, Int32)
        | (Float32, Int64)
        | (Float64, Int8)
        | (Float64, Int16)
        | (Float64, Int32)
        | (Float64, Int64)
        | (Decimal128(_, _), Int8)
        | (Decimal128(_, _), Int16)
        | (Decimal128(_, _), Int32)
        | (Decimal128(_, _), Int64)
            if eval_mode != EvalMode::Try =>
        {
            spark_cast_nonintegral_numeric_to_integral(&array, eval_mode, from_type, to_type)
        }
        (Utf8View, Utf8) => Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?),
        (Struct(_), Utf8) => Ok(casts_struct_to_string(array.as_struct(), cast_options)?),
        (Struct(_), Struct(_)) => Ok(cast_struct_to_struct(
            array.as_struct(),
            from_type,
            to_type,
            cast_options,
        )?),
        (List(_), List(_)) if can_cast_types(from_type, to_type) => {
            Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
        }
        (UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
            if cast_options.allow_cast_unsigned_ints =>
        {
            Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
        }
        _ if cast_options.is_adapting_schema
            || is_datafusion_spark_compatible(from_type, to_type, cast_options.allow_incompat) =>
        {
            // use DataFusion cast only when we know that it is compatible with Spark
            Ok(cast_with_options(&array, to_type, &native_cast_options)?)
        }
        _ => {
            // we should never reach this code because the Scala code should be checking
            // for supported cast operations and falling back to Spark for anything that
            // is not yet supported
            Err(SparkError::Internal(format!(
                "Native cast invoked for unsupported cast from {from_type:?} to {to_type:?}"
            )))
        }
    };
    Ok(spark_cast_postprocess(cast_result?, from_type, to_type))
}

/// Determines if DataFusion supports the given cast in a way that is
/// compatible with Spark
fn is_datafusion_spark_compatible(
    from_type: &DataType,
    to_type: &DataType,
    allow_incompat: bool,
) -> bool {
    if from_type == to_type {
        return true;
    }
    match from_type {
        DataType::Null => {
            matches!(to_type, DataType::List(_))
        }
        DataType::Boolean => matches!(
            to_type,
            DataType::Int8
                | DataType::Int16
                | DataType::Int32
                | DataType::Int64
                | DataType::Float32
                | DataType::Float64
                | DataType::Utf8
        ),
        DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
            // note that the cast from Int32/Int64 -> Decimal128 here is actually
            // not compatible with Spark (no overflow checks) but we have tests that
            // rely on this cast working, so we have to leave it here for now
            matches!(
                to_type,
                DataType::Boolean
                    | DataType::Int8
                    | DataType::Int16
                    | DataType::Int32
                    | DataType::Int64
                    | DataType::Float32
                    | DataType::Float64
                    | DataType::Decimal128(_, _)
                    | DataType::Utf8
            )
        }
        DataType::Float32 | DataType::Float64 => matches!(
            to_type,
            DataType::Boolean
                | DataType::Int8
                | DataType::Int16
                | DataType::Int32
                | DataType::Int64
                | DataType::Float32
                | DataType::Float64
        ),
        DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => matches!(
            to_type,
            DataType::Int8
                | DataType::Int16
                | DataType::Int32
                | DataType::Int64
                | DataType::Float32
                | DataType::Float64
                | DataType::Decimal128(_, _)
                | DataType::Decimal256(_, _)
                | DataType::Utf8 // note that there can be formatting differences
        ),
        DataType::Utf8 if allow_incompat => matches!(
            to_type,
            DataType::Binary | DataType::Float32 | DataType::Float64 | DataType::Decimal128(_, _)
        ),
        DataType::Utf8 => matches!(to_type, DataType::Binary),
        DataType::Date32 => matches!(to_type, DataType::Utf8),
        DataType::Timestamp(_, _) => {
            matches!(
                to_type,
                DataType::Int64 | DataType::Date32 | DataType::Utf8 | DataType::Timestamp(_, _)
            )
        }
        DataType::Binary => {
            // note that this is not completely Spark compatible because
            // DataFusion only supports binary data containing valid UTF-8 strings
            matches!(to_type, DataType::Utf8)
        }
        _ => false,
    }
}

/// Cast between struct types based on logic in
/// `org.apache.spark.sql.catalyst.expressions.Cast#castStruct`.
fn cast_struct_to_struct(
    array: &StructArray,
    from_type: &DataType,
    to_type: &DataType,
    cast_options: &SparkCastOptions,
) -> DataFusionResult<ArrayRef> {
    match (from_type, to_type) {
        (DataType::Struct(from_fields), DataType::Struct(to_fields)) => {
            let cast_fields: Vec<ArrayRef> = from_fields
                .iter()
                .enumerate()
                .zip(to_fields.iter())
                .map(|((idx, _from), to)| {
                    let from_field = Arc::clone(array.column(idx));
                    let array_length = from_field.len();
                    let cast_result = spark_cast(
                        ColumnarValue::from(from_field),
                        to.data_type(),
                        cast_options,
                    )
                    .unwrap();
                    cast_result.to_array(array_length).unwrap()
                })
                .collect();

            Ok(Arc::new(StructArray::new(
                to_fields.clone(),
                cast_fields,
                array.nulls().cloned(),
            )))
        }
        _ => unreachable!(),
    }
}

fn casts_struct_to_string(
    array: &StructArray,
    spark_cast_options: &SparkCastOptions,
) -> DataFusionResult<ArrayRef> {
    // cast each field to a string
    let string_arrays: Vec<ArrayRef> = array
        .columns()
        .iter()
        .map(|arr| {
            spark_cast(
                ColumnarValue::Array(Arc::clone(arr)),
                &DataType::Utf8,
                spark_cast_options,
            )
            .and_then(|cv| cv.into_array(arr.len()))
        })
        .collect::<DataFusionResult<Vec<_>>>()?;
    let string_arrays: Vec<&StringArray> =
        string_arrays.iter().map(|arr| arr.as_string()).collect();
    // build the struct string; Spark formats a struct as `{value1, value2, ...}`,
    // with only the field values (no field names)
    let mut builder = StringBuilder::with_capacity(array.len(), array.len() * 16);
    let mut str = String::with_capacity(array.len() * 16);
    for row_index in 0..array.len() {
        if array.is_null(row_index) {
            builder.append_null();
        } else {
            str.clear();
            let mut any_fields_written = false;
            str.push('{');
            for field in &string_arrays {
                if any_fields_written {
                    str.push_str(", ");
                }
                if field.is_null(row_index) {
                    str.push_str(&spark_cast_options.null_string);
                } else {
                    str.push_str(field.value(row_index));
                }
                any_fields_written = true;
            }
            str.push('}');
            builder.append_value(&str);
        }
    }
    Ok(Arc::new(builder.finish()))
}

fn cast_string_to_int<OffsetSize: OffsetSizeTrait>(
    to_type: &DataType,
    array: &ArrayRef,
    eval_mode: EvalMode,
) -> SparkResult<ArrayRef> {
    let string_array = array
        .as_any()
        .downcast_ref::<GenericStringArray<OffsetSize>>()
        .expect("cast_string_to_int expected a string array");

    let cast_array: ArrayRef = match to_type {
        DataType::Int8 => cast_utf8_to_int!(string_array, eval_mode, Int8Type, cast_string_to_i8)?,
        DataType::Int16 => {
            cast_utf8_to_int!(string_array, eval_mode, Int16Type, cast_string_to_i16)?
        }
        DataType::Int32 => {
            cast_utf8_to_int!(string_array, eval_mode, Int32Type, cast_string_to_i32)?
        }
        DataType::Int64 => {
            cast_utf8_to_int!(string_array, eval_mode, Int64Type, cast_string_to_i64)?
        }
        dt => unreachable!(
            "{}",
            format!("invalid integer type {dt} in cast from string")
        ),
    };
    Ok(cast_array)
}

fn cast_string_to_date(
    array: &ArrayRef,
    to_type: &DataType,
    eval_mode: EvalMode,
) -> SparkResult<ArrayRef> {
    let string_array = array
        .as_any()
        .downcast_ref::<GenericStringArray<i32>>()
        .expect("Expected a string array");

    if to_type != &DataType::Date32 {
        unreachable!("Invalid data type {:?} in cast from string", to_type);
    }

    let len = string_array.len();
    let mut cast_array = PrimitiveArray::<Date32Type>::builder(len);

    for i in 0..len {
        let value = if string_array.is_null(i) {
            None
        } else {
            match date_parser(string_array.value(i), eval_mode) {
                Ok(Some(cast_value)) => Some(cast_value),
                Ok(None) => None,
                Err(e) => return Err(e),
            }
        };

        match value {
            Some(cast_value) => cast_array.append_value(cast_value),
            None => cast_array.append_null(),
        }
    }

    Ok(Arc::new(cast_array.finish()) as ArrayRef)
}

fn cast_string_to_timestamp(
    array: &ArrayRef,
    to_type: &DataType,
    eval_mode: EvalMode,
    timezone_str: &str,
) -> SparkResult<ArrayRef> {
    let string_array = array
        .as_any()
        .downcast_ref::<GenericStringArray<i32>>()
        .expect("Expected a string array");

    let tz = &timezone::Tz::from_str(timezone_str).unwrap();

    let cast_array: ArrayRef = match to_type {
        DataType::Timestamp(_, _) => {
            cast_utf8_to_timestamp!(
                string_array,
                eval_mode,
                TimestampMicrosecondType,
                timestamp_parser,
                tz
            )
        }
        _ => unreachable!("Invalid data type {:?} in cast from string", to_type),
    };
    Ok(cast_array)
}

fn cast_float64_to_decimal128(
    array: &dyn Array,
    precision: u8,
    scale: i8,
    eval_mode: EvalMode,
) -> SparkResult<ArrayRef> {
    cast_floating_point_to_decimal128::<Float64Type>(array, precision, scale, eval_mode)
}

fn cast_float32_to_decimal128(
    array: &dyn Array,
    precision: u8,
    scale: i8,
    eval_mode: EvalMode,
) -> SparkResult<ArrayRef> {
    cast_floating_point_to_decimal128::<Float32Type>(array, precision, scale, eval_mode)
}

fn cast_floating_point_to_decimal128<T: ArrowPrimitiveType>(
    array: &dyn Array,
    precision: u8,
    scale: i8,
    eval_mode: EvalMode,
) -> SparkResult<ArrayRef>
where
    <T as ArrowPrimitiveType>::Native: AsPrimitive<f64>,
{
    let input = array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
    let mut cast_array = PrimitiveArray::<Decimal128Type>::builder(input.len());

    let mul = 10_f64.powi(scale as i32);

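    // Scale each value by 10^scale and round half away from zero (`f64::round`),
    // then check that the scaled value fits within the target precision; values
    // that do not fit become null (or an error in ANSI mode).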
    for i in 0..input.len() {
        if input.is_null(i) {
            cast_array.append_null();
            continue;
        }

        let input_value = input.value(i).as_();
        if let Some(v) = (input_value * mul).round().to_i128() {
            if is_validate_decimal_precision(v, precision) {
                cast_array.append_value(v);
                continue;
            }
        };

        if eval_mode == EvalMode::Ansi {
            return Err(SparkError::NumericValueOutOfRange {
                value: input_value.to_string(),
                precision,
                scale,
            });
        }
        cast_array.append_null();
    }

    let res = Arc::new(
        cast_array
            .with_precision_and_scale(precision, scale)?
            .finish(),
    ) as ArrayRef;
    Ok(res)
}

fn spark_cast_float64_to_utf8<OffsetSize>(
    from: &dyn Array,
    _eval_mode: EvalMode,
) -> SparkResult<ArrayRef>
where
    OffsetSize: OffsetSizeTrait,
{
    cast_float_to_string!(from, _eval_mode, f64, Float64Array, OffsetSize)
}

fn spark_cast_float32_to_utf8<OffsetSize>(
    from: &dyn Array,
    _eval_mode: EvalMode,
) -> SparkResult<ArrayRef>
where
    OffsetSize: OffsetSizeTrait,
{
    cast_float_to_string!(from, _eval_mode, f32, Float32Array, OffsetSize)
}

fn spark_cast_int_to_int(
    array: &dyn Array,
    eval_mode: EvalMode,
    from_type: &DataType,
    to_type: &DataType,
) -> SparkResult<ArrayRef> {
    match (from_type, to_type) {
        (DataType::Int64, DataType::Int32) => cast_int_to_int_macro!(
            array, eval_mode, Int64Type, Int32Type, from_type, i32, "BIGINT", "INT"
        ),
        (DataType::Int64, DataType::Int16) => cast_int_to_int_macro!(
            array, eval_mode, Int64Type, Int16Type, from_type, i16, "BIGINT", "SMALLINT"
        ),
        (DataType::Int64, DataType::Int8) => cast_int_to_int_macro!(
            array, eval_mode, Int64Type, Int8Type, from_type, i8, "BIGINT", "TINYINT"
        ),
        (DataType::Int32, DataType::Int16) => cast_int_to_int_macro!(
            array, eval_mode, Int32Type, Int16Type, from_type, i16, "INT", "SMALLINT"
        ),
        (DataType::Int32, DataType::Int8) => cast_int_to_int_macro!(
            array, eval_mode, Int32Type, Int8Type, from_type, i8, "INT", "TINYINT"
        ),
        (DataType::Int16, DataType::Int8) => cast_int_to_int_macro!(
            array, eval_mode, Int16Type, Int8Type, from_type, i8, "SMALLINT", "TINYINT"
        ),
        _ => unreachable!(
            "{}",
            format!("invalid integer type {to_type} in cast from {from_type}")
        ),
    }
}

fn spark_cast_utf8_to_boolean<OffsetSize>(
    from: &dyn Array,
    eval_mode: EvalMode,
) -> SparkResult<ArrayRef>
where
    OffsetSize: OffsetSizeTrait,
{
    let array = from
        .as_any()
        .downcast_ref::<GenericStringArray<OffsetSize>>()
        .unwrap();

    let output_array = array
        .iter()
        .map(|value| match value {
            Some(value) => match value.to_ascii_lowercase().trim() {
                "t" | "true" | "y" | "yes" | "1" => Ok(Some(true)),
                "f" | "false" | "n" | "no" | "0" => Ok(Some(false)),
                _ if eval_mode == EvalMode::Ansi => Err(SparkError::CastInvalidValue {
                    value: value.to_string(),
                    from_type: "STRING".to_string(),
                    to_type: "BOOLEAN".to_string(),
                }),
                _ => Ok(None),
            },
            _ => Ok(None),
        })
        .collect::<Result<BooleanArray, _>>()?;

    Ok(Arc::new(output_array))
}

1458fn spark_cast_nonintegral_numeric_to_integral(
1459    array: &dyn Array,
1460    eval_mode: EvalMode,
1461    from_type: &DataType,
1462    to_type: &DataType,
1463) -> SparkResult<ArrayRef> {
1464    match (from_type, to_type) {
1465        (DataType::Float32, DataType::Int8) => cast_float_to_int16_down!(
1466            array,
1467            eval_mode,
1468            Float32Array,
1469            Int8Array,
1470            f32,
1471            i8,
1472            "FLOAT",
1473            "TINYINT",
1474            "{:e}"
1475        ),
1476        (DataType::Float32, DataType::Int16) => cast_float_to_int16_down!(
1477            array,
1478            eval_mode,
1479            Float32Array,
1480            Int16Array,
1481            f32,
1482            i16,
1483            "FLOAT",
1484            "SMALLINT",
1485            "{:e}"
1486        ),
1487        (DataType::Float32, DataType::Int32) => cast_float_to_int32_up!(
1488            array,
1489            eval_mode,
1490            Float32Array,
1491            Int32Array,
1492            f32,
1493            i32,
1494            "FLOAT",
1495            "INT",
1496            i32::MAX,
1497            "{:e}"
1498        ),
1499        (DataType::Float32, DataType::Int64) => cast_float_to_int32_up!(
1500            array,
1501            eval_mode,
1502            Float32Array,
1503            Int64Array,
1504            f32,
1505            i64,
1506            "FLOAT",
1507            "BIGINT",
1508            i64::MAX,
1509            "{:e}"
1510        ),
1511        (DataType::Float64, DataType::Int8) => cast_float_to_int16_down!(
1512            array,
1513            eval_mode,
1514            Float64Array,
1515            Int8Array,
1516            f64,
1517            i8,
1518            "DOUBLE",
1519            "TINYINT",
1520            "{:e}D"
1521        ),
1522        (DataType::Float64, DataType::Int16) => cast_float_to_int16_down!(
1523            array,
1524            eval_mode,
1525            Float64Array,
1526            Int16Array,
1527            f64,
1528            i16,
1529            "DOUBLE",
1530            "SMALLINT",
1531            "{:e}D"
1532        ),
1533        (DataType::Float64, DataType::Int32) => cast_float_to_int32_up!(
1534            array,
1535            eval_mode,
1536            Float64Array,
1537            Int32Array,
1538            f64,
1539            i32,
1540            "DOUBLE",
1541            "INT",
1542            i32::MAX,
1543            "{:e}D"
1544        ),
1545        (DataType::Float64, DataType::Int64) => cast_float_to_int32_up!(
1546            array,
1547            eval_mode,
1548            Float64Array,
1549            Int64Array,
1550            f64,
1551            i64,
1552            "DOUBLE",
1553            "BIGINT",
1554            i64::MAX,
1555            "{:e}D"
1556        ),
1557        (DataType::Decimal128(precision, scale), DataType::Int8) => {
1558            cast_decimal_to_int16_down!(
1559                array, eval_mode, Int8Array, i8, "TINYINT", precision, *scale
1560            )
1561        }
1562        (DataType::Decimal128(precision, scale), DataType::Int16) => {
1563            cast_decimal_to_int16_down!(
1564                array, eval_mode, Int16Array, i16, "SMALLINT", precision, *scale
1565            )
1566        }
1567        (DataType::Decimal128(precision, scale), DataType::Int32) => {
1568            cast_decimal_to_int32_up!(
1569                array,
1570                eval_mode,
1571                Int32Array,
1572                i32,
1573                "INT",
1574                i32::MAX,
1575                *precision,
1576                *scale
1577            )
1578        }
1579        (DataType::Decimal128(precision, scale), DataType::Int64) => {
1580            cast_decimal_to_int32_up!(
1581                array,
1582                eval_mode,
1583                Int64Array,
1584                i64,
1585                "BIGINT",
1586                i64::MAX,
1587                *precision,
1588                *scale
1589            )
1590        }
1591        _ => unreachable!(
1592            "invalid cast from non-integral numeric type: {from_type} to integral numeric type: {to_type}"
1593        ),
1595    }
1596}
1597
1598/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toByte
1599fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i8>> {
1600    Ok(cast_string_to_int_with_range_check(
1601        str,
1602        eval_mode,
1603        "TINYINT",
1604        i8::MIN as i32,
1605        i8::MAX as i32,
1606    )?
1607    .map(|v| v as i8))
1608}
1609
1610/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toShort
1611fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i16>> {
1612    Ok(cast_string_to_int_with_range_check(
1613        str,
1614        eval_mode,
1615        "SMALLINT",
1616        i16::MIN as i32,
1617        i16::MAX as i32,
1618    )?
1619    .map(|v| v as i16))
1620}
1621
1622/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper)
1623fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>> {
1624    do_cast_string_to_int::<i32>(str, eval_mode, "INT", i32::MIN)
1625}
1626
1627/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper longWrapper)
1628fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i64>> {
1629    do_cast_string_to_int::<i64>(str, eval_mode, "BIGINT", i64::MIN)
1630}
1631
1632fn cast_string_to_int_with_range_check(
1633    str: &str,
1634    eval_mode: EvalMode,
1635    type_name: &str,
1636    min: i32,
1637    max: i32,
1638) -> SparkResult<Option<i32>> {
1639    match do_cast_string_to_int(str, eval_mode, type_name, i32::MIN)? {
1640        None => Ok(None),
1641        Some(v) if v >= min && v <= max => Ok(Some(v)),
1642        _ if eval_mode == EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)),
1643        _ => Ok(None),
1644    }
1645}
1646
1647/// Equivalent to
1648/// - org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper, boolean allowDecimal)
1649/// - org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper longWrapper, boolean allowDecimal)
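///
/// Like the Spark implementation, digits are folded into the negative range
/// (e.g. for `i32`, "2147483647" accumulates as 0 -> -2 -> -21 -> ... ->
/// -2147483647 and is negated at the end), so that `T::MIN`, whose magnitude
/// has no positive counterpart, can be parsed without overflow.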
1650fn do_cast_string_to_int<
1651    T: Num + PartialOrd + Integer + CheckedSub + CheckedNeg + From<i32> + Copy,
1652>(
1653    str: &str,
1654    eval_mode: EvalMode,
1655    type_name: &str,
1656    min_value: T,
1657) -> SparkResult<Option<T>> {
1658    let trimmed_str = str.trim();
1659    if trimmed_str.is_empty() {
1660        return none_or_err(eval_mode, type_name, str);
1661    }
1662    let len = trimmed_str.len();
1663    let mut result: T = T::zero();
1664    let mut negative = false;
1665    let radix = T::from(10);
1666    let stop_value = min_value / radix;
1667    let mut parse_sign_and_digits = true;
1668
1669    for (i, ch) in trimmed_str.char_indices() {
1670        if parse_sign_and_digits {
1671            if i == 0 {
1672                negative = ch == '-';
1673                let positive = ch == '+';
1674                if negative || positive {
1675                    if i + 1 == len {
1676                        // input string is just "+" or "-"
1677                        return none_or_err(eval_mode, type_name, str);
1678                    }
1679                    // consume this char
1680                    continue;
1681                }
1682            }
1683
1684            if ch == '.' {
1685                if eval_mode == EvalMode::Legacy {
1686                    // truncate decimal in legacy mode
1687                    parse_sign_and_digits = false;
1688                    continue;
1689                } else {
1690                    return none_or_err(eval_mode, type_name, str);
1691                }
1692            }
1693
1694            let digit = if ch.is_ascii_digit() {
1695                (ch as u32) - ('0' as u32)
1696            } else {
1697                return none_or_err(eval_mode, type_name, str);
1698            };
1699
1700            // We are going to process the new digit and accumulate the result. However, before
1701            // doing this, if the result is already smaller than the
1702            // stopValue(Integer.MIN_VALUE / radix), then result * 10 will definitely be
1703            // smaller than minValue, and we can stop
1704            if result < stop_value {
1705                return none_or_err(eval_mode, type_name, str);
1706            }
1707
1708            // Since the previous result is greater than or equal to stopValue(Integer.MIN_VALUE /
1709            // radix), we can just use `result > 0` to check overflow. If result
1710            // overflows, we should stop
1711            let v = result * radix;
1712            let digit = (digit as i32).into();
1713            match v.checked_sub(&digit) {
1714                Some(x) if x <= T::zero() => result = x,
1715                _ => {
1716                    return none_or_err(eval_mode, type_name, str);
1717                }
1718            }
1719        } else {
1720            // make sure fractional digits are valid digits but ignore them
1721            if !ch.is_ascii_digit() {
1722                return none_or_err(eval_mode, type_name, str);
1723            }
1724        }
1725    }
1726
1727    if !negative {
1728        if let Some(neg) = result.checked_neg() {
1729            if neg < T::zero() {
1730                return none_or_err(eval_mode, type_name, str);
1731            }
1732            result = neg;
1733        } else {
1734            return none_or_err(eval_mode, type_name, str);
1735        }
1736    }
1737
1738    Ok(Some(result))
1739}
1740
1741/// Either return Ok(None) or Err(SparkError::CastInvalidValue) depending on the evaluation mode
1742#[inline]
1743fn none_or_err<T>(eval_mode: EvalMode, type_name: &str, str: &str) -> SparkResult<Option<T>> {
1744    match eval_mode {
1745        EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)),
1746        _ => Ok(None),
1747    }
1748}
1749
1750#[inline]
1751fn invalid_value(value: &str, from_type: &str, to_type: &str) -> SparkError {
1752    SparkError::CastInvalidValue {
1753        value: value.to_string(),
1754        from_type: from_type.to_string(),
1755        to_type: to_type.to_string(),
1756    }
1757}
1758
1759#[inline]
1760fn cast_overflow(value: &str, from_type: &str, to_type: &str) -> SparkError {
1761    SparkError::CastOverFlow {
1762        value: value.to_string(),
1763        from_type: from_type.to_string(),
1764        to_type: to_type.to_string(),
1765    }
1766}
1767
1768impl Display for Cast {
1769    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1770        write!(
1771            f,
1772            "Cast [data_type: {}, timezone: {}, child: {}, eval_mode: {:?}]",
1773            self.data_type, self.cast_options.timezone, self.child, &self.cast_options.eval_mode
1774        )
1775    }
1776}
1777
1778impl PhysicalExpr for Cast {
1779    fn as_any(&self) -> &dyn Any {
1780        self
1781    }
1782
1783    fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
1784        unimplemented!()
1785    }
1786
1787    fn data_type(&self, _: &Schema) -> DataFusionResult<DataType> {
1788        Ok(self.data_type.clone())
1789    }
1790
1791    fn nullable(&self, _: &Schema) -> DataFusionResult<bool> {
1792        Ok(true)
1793    }
1794
1795    fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
1796        let arg = self.child.evaluate(batch)?;
1797        spark_cast(arg, &self.data_type, &self.cast_options)
1798    }
1799
1800    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
1801        vec![&self.child]
1802    }
1803
1804    fn with_new_children(
1805        self: Arc<Self>,
1806        children: Vec<Arc<dyn PhysicalExpr>>,
1807    ) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
1808        match children.len() {
1809            1 => Ok(Arc::new(Cast::new(
1810                Arc::clone(&children[0]),
1811                self.data_type.clone(),
1812                self.cast_options.clone(),
1813            ))),
1814            _ => internal_err!("Cast should have exactly one child"),
1815        }
1816    }
1817}
1818
1819fn timestamp_parser<T: TimeZone>(
1820    value: &str,
1821    eval_mode: EvalMode,
1822    tz: &T,
1823) -> SparkResult<Option<i64>> {
1824    let value = value.trim();
1825    if value.is_empty() {
1826        return Ok(None);
1827    }
1828    // Define regex patterns and corresponding parsing functions
1829    let patterns = &[
1830        (
1831            Regex::new(r"^\d{4,5}$").unwrap(),
1832            parse_str_to_year_timestamp as fn(&str, &T) -> SparkResult<Option<i64>>,
1833        ),
1834        (
1835            Regex::new(r"^\d{4,5}-\d{2}$").unwrap(),
1836            parse_str_to_month_timestamp,
1837        ),
1838        (
1839            Regex::new(r"^\d{4,5}-\d{2}-\d{2}$").unwrap(),
1840            parse_str_to_day_timestamp,
1841        ),
1842        (
1843            Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{1,2}$").unwrap(),
1844            parse_str_to_hour_timestamp,
1845        ),
1846        (
1847            Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}$").unwrap(),
1848            parse_str_to_minute_timestamp,
1849        ),
1850        (
1851            Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$").unwrap(),
1852            parse_str_to_second_timestamp,
1853        ),
1854        (
1855            Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,6}$").unwrap(),
1856            parse_str_to_microsecond_timestamp,
1857        ),
1858        (
1859            Regex::new(r"^T\d{1,2}$").unwrap(),
1860            parse_str_to_time_only_timestamp,
1861        ),
1862    ];
1863
1864    let mut timestamp = None;
1865
1866    // Iterate through patterns and try matching
1867    for (pattern, parse_func) in patterns {
1868        if pattern.is_match(value) {
1869            timestamp = parse_func(value, tz)?;
1870            break;
1871        }
1872    }
1873
1874    // A `None` timestamp means that no pattern matched or that parsing failed;
1875    // ANSI mode treats this as a cast error, while other modes return NULL
1876    match timestamp {
1877        Some(ts) => Ok(Some(ts)),
1878        None if eval_mode == EvalMode::Ansi => Err(SparkError::CastInvalidValue {
1879            value: value.to_string(),
1880            from_type: "STRING".to_string(),
1881            to_type: "TIMESTAMP".to_string(),
1882        }),
1883        None => Ok(None),
1884    }
1892}
1893
1894fn parse_timestamp_to_micros<T: TimeZone>(
1895    timestamp_info: &TimeStampInfo,
1896    tz: &T,
1897) -> SparkResult<Option<i64>> {
1898    let datetime = tz.with_ymd_and_hms(
1899        timestamp_info.year,
1900        timestamp_info.month,
1901        timestamp_info.day,
1902        timestamp_info.hour,
1903        timestamp_info.minute,
1904        timestamp_info.second,
1905    );
1906
1907    // Check if datetime is not None
1908    let tz_datetime = match datetime.single() {
1909        Some(dt) => dt
1910            .with_timezone(tz)
1911            .with_nanosecond(timestamp_info.microsecond * 1000),
1912        None => {
1913            return Err(SparkError::Internal(
1914                "Failed to parse timestamp".to_string(),
1915            ));
1916        }
1917    };
1918
1919    let result = match tz_datetime {
1920        Some(dt) => dt.timestamp_micros(),
1921        None => {
1922            return Err(SparkError::Internal(
1923                "Failed to parse timestamp".to_string(),
1924            ));
1925        }
1926    };
1927
1928    Ok(Some(result))
1929}
1930
1931fn get_timestamp_values<T: TimeZone>(
1932    value: &str,
1933    timestamp_type: &str,
1934    tz: &T,
1935) -> SparkResult<Option<i64>> {
1936    let values: Vec<_> = value.split(['T', '-', ':', '.']).collect();
1937    let year = values[0].parse::<i32>().unwrap_or_default();
1938    let month = values.get(1).map_or(1, |m| m.parse::<u32>().unwrap_or(1));
1939    let day = values.get(2).map_or(1, |d| d.parse::<u32>().unwrap_or(1));
1940    let hour = values.get(3).map_or(0, |h| h.parse::<u32>().unwrap_or(0));
1941    let minute = values.get(4).map_or(0, |m| m.parse::<u32>().unwrap_or(0));
1942    let second = values.get(5).map_or(0, |s| s.parse::<u32>().unwrap_or(0));
1943    let microsecond = values.get(6).map_or(0, |ms| ms.parse::<u32>().unwrap_or(0));
1944
1945    let mut timestamp_info = TimeStampInfo::default();
1946
1947    let timestamp_info = match timestamp_type {
1948        "year" => timestamp_info.with_year(year),
1949        "month" => timestamp_info.with_year(year).with_month(month),
1950        "day" => timestamp_info
1951            .with_year(year)
1952            .with_month(month)
1953            .with_day(day),
1954        "hour" => timestamp_info
1955            .with_year(year)
1956            .with_month(month)
1957            .with_day(day)
1958            .with_hour(hour),
1959        "minute" => timestamp_info
1960            .with_year(year)
1961            .with_month(month)
1962            .with_day(day)
1963            .with_hour(hour)
1964            .with_minute(minute),
1965        "second" => timestamp_info
1966            .with_year(year)
1967            .with_month(month)
1968            .with_day(day)
1969            .with_hour(hour)
1970            .with_minute(minute)
1971            .with_second(second),
1972        "microsecond" => timestamp_info
1973            .with_year(year)
1974            .with_month(month)
1975            .with_day(day)
1976            .with_hour(hour)
1977            .with_minute(minute)
1978            .with_second(second)
1979            .with_microsecond(microsecond),
1980        _ => {
1981            return Err(SparkError::CastInvalidValue {
1982                value: value.to_string(),
1983                from_type: "STRING".to_string(),
1984                to_type: "TIMESTAMP".to_string(),
1985            })
1986        }
1987    };
1988
1989    parse_timestamp_to_micros(timestamp_info, tz)
1990}
1991
1992fn parse_str_to_year_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
1993    get_timestamp_values(value, "year", tz)
1994}
1995
1996fn parse_str_to_month_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
1997    get_timestamp_values(value, "month", tz)
1998}
1999
2000fn parse_str_to_day_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
2001    get_timestamp_values(value, "day", tz)
2002}
2003
2004fn parse_str_to_hour_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
2005    get_timestamp_values(value, "hour", tz)
2006}
2007
2008fn parse_str_to_minute_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
2009    get_timestamp_values(value, "minute", tz)
2010}
2011
2012fn parse_str_to_second_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
2013    get_timestamp_values(value, "second", tz)
2014}
2015
2016fn parse_str_to_microsecond_timestamp<T: TimeZone>(
2017    value: &str,
2018    tz: &T,
2019) -> SparkResult<Option<i64>> {
2020    get_timestamp_values(value, "microsecond", tz)
2021}
2022
2023fn parse_str_to_time_only_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
2024    let values: Vec<&str> = value.split('T').collect();
2025    let time_values: Vec<u32> = values[1]
2026        .split(':')
2027        .map(|v| v.parse::<u32>().unwrap_or(0))
2028        .collect();
2029
2030    let datetime = tz.from_utc_datetime(&chrono::Utc::now().naive_utc());
2031    let timestamp = datetime
2032        .with_timezone(tz)
2033        .with_hour(time_values.first().copied().unwrap_or_default())
2034        .and_then(|dt| dt.with_minute(*time_values.get(1).unwrap_or(&0)))
2035        .and_then(|dt| dt.with_second(*time_values.get(2).unwrap_or(&0)))
2036        .and_then(|dt| dt.with_nanosecond(*time_values.get(3).unwrap_or(&0) * 1_000))
2037        .map(|dt| dt.timestamp_micros())
2038        .unwrap_or_default();
2039
2040    Ok(Some(timestamp))
2041}
2042
2043/// A string-to-date parser; a port of Spark's SparkDateTimeUtils#stringToDate.
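/// Returns the date as days since the Unix epoch (the `Date32` representation).
/// Malformed input yields `None` in Legacy/Try mode and an error in ANSI mode;
/// dates outside the range supported by chrono's `NaiveDate` yield `None` in
/// all eval modes.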
2044fn date_parser(date_str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>> {
2045    // local functions
2046    fn get_trimmed_start(bytes: &[u8]) -> usize {
2047        let mut start = 0;
2048        while start < bytes.len() && is_whitespace_or_iso_control(bytes[start]) {
2049            start += 1;
2050        }
2051        start
2052    }
2053
2054    fn get_trimmed_end(start: usize, bytes: &[u8]) -> usize {
2055        let mut end = bytes.len() - 1;
2056        while end > start && is_whitespace_or_iso_control(bytes[end]) {
2057            end -= 1;
2058        }
2059        end + 1
2060    }
2061
2062    fn is_whitespace_or_iso_control(byte: u8) -> bool {
2063        byte.is_ascii_whitespace() || byte.is_ascii_control()
2064    }
2065
2066    fn is_valid_digits(segment: i32, digits: usize) -> bool {
2067        // An integer is able to represent a date within [+-]5 million years.
2068        let max_digits_year = 7;
2069        // year (segment 0) can have 4 to 7 digits,
2070        // month and day (segments 1 and 2) can have 1 to 2 digits
2071        (segment == 0 && digits >= 4 && digits <= max_digits_year)
2072            || (segment != 0 && digits > 0 && digits <= 2)
2073    }
2074
2075    fn return_result(date_str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>> {
2076        if eval_mode == EvalMode::Ansi {
2077            Err(SparkError::CastInvalidValue {
2078                value: date_str.to_string(),
2079                from_type: "STRING".to_string(),
2080                to_type: "DATE".to_string(),
2081            })
2082        } else {
2083            Ok(None)
2084        }
2085    }
2086    // end local functions
2087
2088    if date_str.is_empty() {
2089        return return_result(date_str, eval_mode);
2090    }
2091
2092    //values of date segments year, month and day defaulting to 1
2093    let mut date_segments = [1, 1, 1];
2094    let mut sign = 1;
2095    let mut current_segment = 0;
2096    let mut current_segment_value = Wrapping(0);
2097    let mut current_segment_digits = 0;
2098    let bytes = date_str.as_bytes();
2099
2100    let mut j = get_trimmed_start(bytes);
2101    let str_end_trimmed = get_trimmed_end(j, bytes);
2102
2103    if j == str_end_trimmed {
2104        return return_result(date_str, eval_mode);
2105    }
2106
2107    //assign a sign to the date
2108    if bytes[j] == b'-' || bytes[j] == b'+' {
2109        sign = if bytes[j] == b'-' { -1 } else { 1 };
2110        j += 1;
2111    }
2112
2113    // loop through the string until we have processed 3 segments or until we
2114    // encounter a space ' ' or 'T', which ends the date portion of the input
2115    while j < str_end_trimmed && (current_segment < 3 && !(bytes[j] == b' ' || bytes[j] == b'T')) {
2116        let b = bytes[j];
2117        if current_segment < 2 && b == b'-' {
2118            //check for validity of year and month segments if current byte is separator
2119            if !is_valid_digits(current_segment, current_segment_digits) {
2120                return return_result(date_str, eval_mode);
2121            }
2122            //if valid update corresponding segment with the current segment value.
2123            date_segments[current_segment as usize] = current_segment_value.0;
2124            current_segment_value = Wrapping(0);
2125            current_segment_digits = 0;
2126            current_segment += 1;
2127        } else if !b.is_ascii_digit() {
2128            return return_result(date_str, eval_mode);
2129        } else {
2130            //increment value of current segment by the next digit
2131            let parsed_value = Wrapping((b - b'0') as i32);
2132            current_segment_value = current_segment_value * Wrapping(10) + parsed_value;
2133            current_segment_digits += 1;
2134        }
2135        j += 1;
2136    }
2137
2138    //check for validity of last segment
2139    if !is_valid_digits(current_segment, current_segment_digits) {
2140        return return_result(date_str, eval_mode);
2141    }
2142
2143    if current_segment < 2 && j < str_end_trimmed {
2144        // For the `yyyy` and `yyyy-[m]m` formats, entire input must be consumed.
2145        return return_result(date_str, eval_mode);
2146    }
2147
2148    date_segments[current_segment as usize] = current_segment_value.0;
2149
2150    match NaiveDate::from_ymd_opt(
2151        sign * date_segments[0],
2152        date_segments[1] as u32,
2153        date_segments[2] as u32,
2154    ) {
2155        Some(date) => {
2156            let duration_since_epoch = date
2157                .signed_duration_since(DateTime::UNIX_EPOCH.naive_utc().date())
2158                .num_days();
2159            Ok(Some(duration_since_epoch.to_i32().unwrap()))
2160        }
2161        None => Ok(None),
2162    }
2163}
2164
2165/// Handles special casting cases for Spark, e.g., Timestamp to Long.
2166/// This function runs as a post-process of the DataFusion cast(). By the time values arrive here,
2167/// Dictionary arrays have already been unpacked by the DataFusion cast(), since Spark cannot specify
2168/// Dictionary as the to_type. The from_type is captured before the DataFusion cast() runs in
2169/// expressions/cast.rs, so it can still be Dictionary.
2170fn spark_cast_postprocess(array: ArrayRef, from_type: &DataType, to_type: &DataType) -> ArrayRef {
2171    match (from_type, to_type) {
2172        (DataType::Timestamp(_, _), DataType::Int64) => {
2173            // See Spark's `Cast` expression
2174            unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, MICROS_PER_SECOND)).unwrap()
2175        }
2176        (DataType::Dictionary(_, value_type), DataType::Int64)
2177            if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) =>
2178        {
2179            // See Spark's `Cast` expression
2180            unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, MICROS_PER_SECOND)).unwrap()
2181        }
2182        (DataType::Timestamp(_, _), DataType::Utf8) => remove_trailing_zeroes(array),
2183        (DataType::Dictionary(_, value_type), DataType::Utf8)
2184            if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) =>
2185        {
2186            remove_trailing_zeroes(array)
2187        }
2188        _ => array,
2189    }
2190}
2191
2192/// A fork & modified version of Arrow's `unary_dyn` which is being deprecated
2193fn unary_dyn<F, T>(array: &ArrayRef, op: F) -> Result<ArrayRef, ArrowError>
2194where
2195    T: ArrowPrimitiveType,
2196    F: Fn(T::Native) -> T::Native,
2197{
2198    if let Some(d) = array.as_any_dictionary_opt() {
2199        let new_values = unary_dyn::<F, T>(d.values(), op)?;
2200        return Ok(Arc::new(d.with_values(Arc::new(new_values))));
2201    }
2202
2203    match array.as_primitive_opt::<T>() {
2204        Some(a) if PrimitiveArray::<T>::is_compatible(a.data_type()) => {
2205            Ok(Arc::new(unary::<T, F, T>(
2206                array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
2207                op,
2208            )))
2209        }
2210        _ => Err(ArrowError::NotYetImplemented(format!(
2211            "Cannot perform unary operation of type {} on array of type {}",
2212            T::DATA_TYPE,
2213            array.data_type()
2214        ))),
2215    }
2216}
2217
2218/// Remove any trailing zeroes in the string if they occur in the fractional seconds,
2219/// to match Spark behavior
2220/// example:
2221/// "1970-01-01 05:29:59.900" => "1970-01-01 05:29:59.9"
2222/// "1970-01-01 05:29:59.990" => "1970-01-01 05:29:59.99"
2223/// "1970-01-01 05:29:59.999" => "1970-01-01 05:29:59.999"
2224/// "1970-01-01 05:30:00"     => "1970-01-01 05:30:00"
2225/// "1970-01-01 05:30:00.001" => "1970-01-01 05:30:00.001"
2226fn remove_trailing_zeroes(array: ArrayRef) -> ArrayRef {
2227    let string_array = as_generic_string_array::<i32>(&array).unwrap();
2228    let result = string_array
2229        .iter()
2230        .map(|s| s.map(trim_end))
2231        .collect::<GenericStringArray<i32>>();
2232    Arc::new(result) as ArrayRef
2233}
2234
2235fn trim_end(s: &str) -> &str {
2236    if s.rfind('.').is_some() {
2237        s.trim_end_matches('0')
2238    } else {
2239        s
2240    }
2241}
2242
2243#[cfg(test)]
2244mod tests {
2245    use arrow::array::StringArray;
2246    use arrow::datatypes::TimestampMicrosecondType;
2247    use arrow::datatypes::{Field, Fields, TimeUnit};
2248    use core::f64;
2249    use std::str::FromStr;
2250
2251    use super::*;
2252
2253    #[test]
2254    #[cfg_attr(miri, ignore)] // test takes too long with miri
2255    fn timestamp_parser_test() {
2256        let tz = &timezone::Tz::from_str("UTC").unwrap();
2257        // exercise all supported formats
2258        assert_eq!(
2259            timestamp_parser("2020", EvalMode::Legacy, tz).unwrap(),
2260            Some(1577836800000000) // microseconds since epoch
2261        );
2262        assert_eq!(
2263            timestamp_parser("2020-01", EvalMode::Legacy, tz).unwrap(),
2264            Some(1577836800000000)
2265        );
2266        assert_eq!(
2267            timestamp_parser("2020-01-01", EvalMode::Legacy, tz).unwrap(),
2268            Some(1577836800000000)
2269        );
2270        assert_eq!(
2271            timestamp_parser("2020-01-01T12", EvalMode::Legacy, tz).unwrap(),
2272            Some(1577880000000000)
2273        );
2274        assert_eq!(
2275            timestamp_parser("2020-01-01T12:34", EvalMode::Legacy, tz).unwrap(),
2276            Some(1577882040000000)
2277        );
2278        assert_eq!(
2279            timestamp_parser("2020-01-01T12:34:56", EvalMode::Legacy, tz).unwrap(),
2280            Some(1577882096000000)
2281        );
2282        assert_eq!(
2283            timestamp_parser("2020-01-01T12:34:56.123456", EvalMode::Legacy, tz).unwrap(),
2284            Some(1577882096123456)
2285        );
2286        assert_eq!(
2287            timestamp_parser("0100", EvalMode::Legacy, tz).unwrap(),
2288            Some(-59011459200000000)
2289        );
2290        assert_eq!(
2291            timestamp_parser("0100-01", EvalMode::Legacy, tz).unwrap(),
2292            Some(-59011459200000000)
2293        );
2294        assert_eq!(
2295            timestamp_parser("0100-01-01", EvalMode::Legacy, tz).unwrap(),
2296            Some(-59011459200000000)
2297        );
2298        assert_eq!(
2299            timestamp_parser("0100-01-01T12", EvalMode::Legacy, tz).unwrap(),
2300            Some(-59011416000000000)
2301        );
2302        assert_eq!(
2303            timestamp_parser("0100-01-01T12:34", EvalMode::Legacy, tz).unwrap(),
2304            Some(-59011413960000000)
2305        );
2306        assert_eq!(
2307            timestamp_parser("0100-01-01T12:34:56", EvalMode::Legacy, tz).unwrap(),
2308            Some(-59011413904000000)
2309        );
2310        assert_eq!(
2311            timestamp_parser("0100-01-01T12:34:56.123456", EvalMode::Legacy, tz).unwrap(),
2312            Some(-59011413903876544)
2313        );
2314        assert_eq!(
2315            timestamp_parser("10000", EvalMode::Legacy, tz).unwrap(),
2316            Some(253402300800000000)
2317        );
2318        assert_eq!(
2319            timestamp_parser("10000-01", EvalMode::Legacy, tz).unwrap(),
2320            Some(253402300800000000)
2321        );
2322        assert_eq!(
2323            timestamp_parser("10000-01-01", EvalMode::Legacy, tz).unwrap(),
2324            Some(253402300800000000)
2325        );
2326        assert_eq!(
2327            timestamp_parser("10000-01-01T12", EvalMode::Legacy, tz).unwrap(),
2328            Some(253402344000000000)
2329        );
2330        assert_eq!(
2331            timestamp_parser("10000-01-01T12:34", EvalMode::Legacy, tz).unwrap(),
2332            Some(253402346040000000)
2333        );
2334        assert_eq!(
2335            timestamp_parser("10000-01-01T12:34:56", EvalMode::Legacy, tz).unwrap(),
2336            Some(253402346096000000)
2337        );
2338        assert_eq!(
2339            timestamp_parser("10000-01-01T12:34:56.123456", EvalMode::Legacy, tz).unwrap(),
2340            Some(253402346096123456)
2341        );
2342        // assert_eq!(
2343        //     timestamp_parser("T2", EvalMode::Legacy, tz).unwrap(),
2344        //     Some(1714356000000000) // this value would need to change every day
2345        // );
2346    }
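
    // A companion sketch to the test above: inputs that match none of the
    // supported patterns yield `None` in Legacy/Try mode and an error in ANSI
    // mode, while all-whitespace input is `None` in every mode.
    #[test]
    #[cfg_attr(miri, ignore)] // test takes too long with miri
    fn timestamp_parser_invalid_input_test() {
        let tz = &timezone::Tz::from_str("UTC").unwrap();
        assert_eq!(
            timestamp_parser("not a timestamp", EvalMode::Legacy, tz).unwrap(),
            None
        );
        assert!(timestamp_parser("not a timestamp", EvalMode::Ansi, tz).is_err());
        assert_eq!(timestamp_parser("   ", EvalMode::Ansi, tz).unwrap(), None);
    }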
2347
2348    #[test]
2349    #[cfg_attr(miri, ignore)] // test takes too long with miri
2350    fn test_cast_string_to_timestamp() {
2351        let array: ArrayRef = Arc::new(StringArray::from(vec![
2352            Some("2020-01-01T12:34:56.123456"),
2353            Some("T2"),
2354            Some("0100-01-01T12:34:56.123456"),
2355            Some("10000-01-01T12:34:56.123456"),
2356        ]));
2357        let tz = &timezone::Tz::from_str("UTC").unwrap();
2358
2359        let string_array = array
2360            .as_any()
2361            .downcast_ref::<GenericStringArray<i32>>()
2362            .expect("Expected a string array");
2363
2364        let eval_mode = EvalMode::Legacy;
2365        let result = cast_utf8_to_timestamp!(
2366            &string_array,
2367            eval_mode,
2368            TimestampMicrosecondType,
2369            timestamp_parser,
2370            tz
2371        );
2372
2373        assert_eq!(
2374            result.data_type(),
2375            &DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into()))
2376        );
2377        assert_eq!(result.len(), 4);
2378    }
2379
2380    #[test]
2381    fn test_cast_dict_string_to_timestamp() -> DataFusionResult<()> {
2382        // prepare input data
2383        let keys = Int32Array::from(vec![0, 1]);
2384        let values: ArrayRef = Arc::new(StringArray::from(vec![
2385            Some("2020-01-01T12:34:56.123456"),
2386            Some("T2"),
2387        ]));
2388        let dict_array = Arc::new(DictionaryArray::new(keys, values));
2389
2390        let timezone = "UTC".to_string();
2391        // test casting string dictionary array to timestamp array
2392        let cast_options = SparkCastOptions::new(EvalMode::Legacy, &timezone, false);
2393        let result = cast_array(
2394            dict_array,
2395            &DataType::Timestamp(TimeUnit::Microsecond, Some(timezone.clone().into())),
2396            &cast_options,
2397        )?;
2398        assert_eq!(
2399            *result.data_type(),
2400            DataType::Timestamp(TimeUnit::Microsecond, Some(timezone.into()))
2401        );
2402        assert_eq!(result.len(), 2);
2403
2404        Ok(())
2405    }
2406
2407    #[test]
2408    fn date_parser_test() {
2409        for date in &[
2410            "2020",
2411            "2020-01",
2412            "2020-01-01",
2413            "02020-01-01",
2414            "002020-01-01",
2415            "0002020-01-01",
2416            "2020-1-1",
2417            "2020-01-01 ",
2418            "2020-01-01T",
2419        ] {
2420            for eval_mode in &[EvalMode::Legacy, EvalMode::Ansi, EvalMode::Try] {
2421                assert_eq!(date_parser(date, *eval_mode).unwrap(), Some(18262));
2422            }
2423        }
2424
2425        //dates in invalid formats
2426        for date in &[
2427            "abc",
2428            "",
2429            "not_a_date",
2430            "3/",
2431            "3/12",
2432            "3/12/2020",
2433            "3/12/2002 T",
2434            "202",
2435            "2020-010-01",
2436            "2020-10-010",
2437            "2020-10-010T",
2438            "--262143-12-31",
2439            "--262143-12-31 ",
2440        ] {
2441            for eval_mode in &[EvalMode::Legacy, EvalMode::Try] {
2442                assert_eq!(date_parser(date, *eval_mode).unwrap(), None);
2443            }
2444            assert!(date_parser(date, EvalMode::Ansi).is_err());
2445        }
2446
2447        for date in &["-3638-5"] {
2448            for eval_mode in &[EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] {
2449                assert_eq!(date_parser(date, *eval_mode).unwrap(), Some(-2048160));
2450            }
2451        }
2452
2453        //Naive Date only supports years 262142 AD to 262143 BC
2454        //returns None for dates out of range supported by Naive Date.
2455        for date in &[
2456            "-262144-1-1",
2457            "262143-01-1",
2458            "262143-1-1",
2459            "262143-01-1 ",
2460            "262143-01-01T ",
2461            "262143-1-01T 1234",
2462            "-0973250",
2463        ] {
2464            for eval_mode in &[EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] {
2465                assert_eq!(date_parser(date, *eval_mode).unwrap(), None);
2466            }
2467        }
2468    }
2469
2470    #[test]
2471    fn test_cast_string_to_date() {
2472        let array: ArrayRef = Arc::new(StringArray::from(vec![
2473            Some("2020"),
2474            Some("2020-01"),
2475            Some("2020-01-01"),
2476            Some("2020-01-01T"),
2477        ]));
2478
2479        let result = cast_string_to_date(&array, &DataType::Date32, EvalMode::Legacy).unwrap();
2480
2481        let date32_array = result
2482            .as_any()
2483            .downcast_ref::<arrow::array::Date32Array>()
2484            .unwrap();
2485        assert_eq!(date32_array.len(), 4);
2486        date32_array
2487            .iter()
2488            .for_each(|v| assert_eq!(v.unwrap(), 18262));
2489    }
2490
2491    #[test]
2492    fn test_cast_string_array_with_valid_dates() {
2493        let array_with_valid_dates: ArrayRef = Arc::new(StringArray::from(vec![
2494            Some("-262143-12-31"),
2495            Some("\n -262143-12-31 "),
2496            Some("-262143-12-31T \t\n"),
2497            Some("\n\t-262143-12-31T\r"),
2498            Some("-262143-12-31T 123123123"),
2499            Some("\r\n-262143-12-31T \r123123123"),
2500            Some("\n -262143-12-31T \n\t"),
2501        ]));
2502
2503        for eval_mode in &[EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] {
2504            let result =
2505                cast_string_to_date(&array_with_valid_dates, &DataType::Date32, *eval_mode)
2506                    .unwrap();
2507
2508            let date32_array = result
2509                .as_any()
2510                .downcast_ref::<arrow::array::Date32Array>()
2511                .unwrap();
2512            assert_eq!(result.len(), 7);
2513            date32_array
2514                .iter()
2515                .for_each(|v| assert_eq!(v.unwrap(), -96464928));
2516        }
2517    }
2518
2519    #[test]
2520    fn test_cast_string_array_with_invalid_dates() {
2521        let array_with_invalid_date: ArrayRef = Arc::new(StringArray::from(vec![
2522            Some("2020"),
2523            Some("2020-01"),
2524            Some("2020-01-01"),
2525            //4 invalid dates
2526            Some("2020-010-01T"),
2527            Some("202"),
2528            Some(" 202 "),
2529            Some("\n 2020-\r8 "),
2530            Some("2020-01-01T"),
2531            // Overflows i32
2532            Some("-4607172990231812908"),
2533        ]));
2534
2535        for eval_mode in &[EvalMode::Legacy, EvalMode::Try] {
2536            let result =
2537                cast_string_to_date(&array_with_invalid_date, &DataType::Date32, *eval_mode)
2538                    .unwrap();
2539
2540            let date32_array = result
2541                .as_any()
2542                .downcast_ref::<arrow::array::Date32Array>()
2543                .unwrap();
2544            assert_eq!(
2545                date32_array.iter().collect::<Vec<_>>(),
2546                vec![
2547                    Some(18262),
2548                    Some(18262),
2549                    Some(18262),
2550                    None,
2551                    None,
2552                    None,
2553                    None,
2554                    Some(18262),
2555                    None
2556                ]
2557            );
2558        }
2559
2560        let result =
2561            cast_string_to_date(&array_with_invalid_date, &DataType::Date32, EvalMode::Ansi);
2562        match result {
2563            Err(e) => assert!(
2564                e.to_string().contains(
2565                    "[CAST_INVALID_INPUT] The value '2020-010-01T' of the type \"STRING\" cannot be cast to \"DATE\" because it is malformed")
2566            ),
2567            _ => panic!("Expected error"),
2568        }
2569    }
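
    // A sketch of the string-to-boolean rules in `spark_cast_utf8_to_boolean`:
    // matching is case-insensitive after trimming, and unrecognized values become
    // NULL in Legacy/Try mode but raise an error in ANSI mode.
    #[test]
    fn test_cast_utf8_to_boolean() {
        let array: ArrayRef = Arc::new(StringArray::from(vec![
            Some("TRUE"),
            Some(" t "),
            Some("no"),
            Some("2"),
            None,
        ]));
        let result = spark_cast_utf8_to_boolean::<i32>(&array, EvalMode::Legacy).unwrap();
        assert_eq!(
            result.as_boolean().iter().collect::<Vec<_>>(),
            vec![Some(true), Some(true), Some(false), None, None]
        );
        assert!(spark_cast_utf8_to_boolean::<i32>(&array, EvalMode::Ansi).is_err());
    }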
2570
2571    #[test]
2572    fn test_cast_string_as_i8() {
2573        // basic
2574        assert_eq!(
2575            cast_string_to_i8("127", EvalMode::Legacy).unwrap(),
2576            Some(127_i8)
2577        );
2578        assert_eq!(cast_string_to_i8("128", EvalMode::Legacy).unwrap(), None);
2579        assert!(cast_string_to_i8("128", EvalMode::Ansi).is_err());
2580        // decimals
2581        assert_eq!(
2582            cast_string_to_i8("0.2", EvalMode::Legacy).unwrap(),
2583            Some(0_i8)
2584        );
2585        assert_eq!(
2586            cast_string_to_i8(".", EvalMode::Legacy).unwrap(),
2587            Some(0_i8)
2588        );
2589        // TRY should always return null for decimals
2590        assert_eq!(cast_string_to_i8("0.2", EvalMode::Try).unwrap(), None);
2591        assert_eq!(cast_string_to_i8(".", EvalMode::Try).unwrap(), None);
2592        // ANSI mode should throw error on decimal
2593        assert!(cast_string_to_i8("0.2", EvalMode::Ansi).is_err());
2594        assert!(cast_string_to_i8(".", EvalMode::Ansi).is_err());
2595    }
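
    // A minimal sketch of the shared `do_cast_string_to_int` behavior via the
    // i32 wrapper: Legacy mode truncates a fractional part, Try mode maps it to
    // None, and out-of-range values are None (or an error in ANSI mode).
    #[test]
    fn test_cast_string_as_i32() {
        assert_eq!(
            cast_string_to_i32(" 123 ", EvalMode::Legacy).unwrap(),
            Some(123)
        );
        assert_eq!(cast_string_to_i32("1.5", EvalMode::Legacy).unwrap(), Some(1));
        assert_eq!(cast_string_to_i32("1.5", EvalMode::Try).unwrap(), None);
        assert_eq!(
            cast_string_to_i32("-2147483648", EvalMode::Legacy).unwrap(),
            Some(i32::MIN)
        );
        assert_eq!(
            cast_string_to_i32("2147483648", EvalMode::Legacy).unwrap(),
            None
        );
        assert!(cast_string_to_i32("2147483648", EvalMode::Ansi).is_err());
    }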
2596
2597    #[test]
2598    fn test_cast_unsupported_timestamp_to_date() {
2599        // Since DataFusion uses chrono::DateTime internally, not all dates representable by TimestampMicrosecondType are supported
2600        let timestamps: PrimitiveArray<TimestampMicrosecondType> = vec![i64::MAX].into();
2601        let cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
2602        let result = cast_array(
2603            Arc::new(timestamps.with_timezone("Europe/Copenhagen")),
2604            &DataType::Date32,
2605            &cast_options,
2606        );
2607        assert!(result.is_err())
2608    }
2609
2610    #[test]
2611    fn test_cast_invalid_timezone() {
2612        let timestamps: PrimitiveArray<TimestampMicrosecondType> = vec![i64::MAX].into();
2613        let cast_options = SparkCastOptions::new(EvalMode::Legacy, "Not a valid timezone", false);
2614        let result = cast_array(
2615            Arc::new(timestamps.with_timezone("Europe/Copenhagen")),
2616            &DataType::Date32,
2617            &cast_options,
2618        );
2619        assert!(result.is_err())
2620    }
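
    // An in-range sanity sketch for `spark_cast_int_to_int`, assuming only that
    // values which fit the target type pass through unchanged in every eval mode;
    // the overflow behavior lives in `cast_int_to_int_macro`, defined earlier in
    // this file, and is not exercised here.
    #[test]
    fn test_cast_int_to_int_in_range() {
        let arr: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), None, Some(-42)]));
        for eval_mode in [EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] {
            let result =
                spark_cast_int_to_int(&arr, eval_mode, &DataType::Int64, &DataType::Int32)
                    .unwrap();
            let ints = result.as_primitive::<Int32Type>();
            assert_eq!(
                ints.iter().collect::<Vec<_>>(),
                vec![Some(1), None, Some(-42)]
            );
        }
    }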
2621
2622    #[test]
2623    fn test_cast_struct_to_utf8() {
2624        let a: ArrayRef = Arc::new(Int32Array::from(vec![
2625            Some(1),
2626            Some(2),
2627            None,
2628            Some(4),
2629            Some(5),
2630        ]));
2631        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
2632        let c: ArrayRef = Arc::new(StructArray::from(vec![
2633            (Arc::new(Field::new("a", DataType::Int32, true)), a),
2634            (Arc::new(Field::new("b", DataType::Utf8, true)), b),
2635        ]));
2636        let string_array = cast_array(
2637            c,
2638            &DataType::Utf8,
2639            &SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
2640        )
2641        .unwrap();
2642        let string_array = string_array.as_string::<i32>();
2643        assert_eq!(5, string_array.len());
2644        assert_eq!(r#"{1, a}"#, string_array.value(0));
2645        assert_eq!(r#"{2, b}"#, string_array.value(1));
2646        assert_eq!(r#"{null, c}"#, string_array.value(2));
2647        assert_eq!(r#"{4, d}"#, string_array.value(3));
2648        assert_eq!(r#"{5, e}"#, string_array.value(4));
2649    }
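
    // A sketch of the Timestamp -> Long post-processing in
    // `spark_cast_postprocess`: Spark's cast to BIGINT floor-divides microseconds
    // by 1_000_000 to produce seconds. Feeding an Int64 array with a Timestamp
    // `from_type` is just a convenient way to exercise the arithmetic.
    #[test]
    fn test_postprocess_timestamp_to_long() {
        let micros: ArrayRef = Arc::new(Int64Array::from(vec![1_500_000_i64, -1_500_000]));
        let from_type = DataType::Timestamp(TimeUnit::Microsecond, None);
        let result = spark_cast_postprocess(micros, &from_type, &DataType::Int64);
        let secs = result.as_primitive::<Int64Type>();
        assert_eq!(secs.value(0), 1);
        assert_eq!(secs.value(1), -2); // floor division, not truncation toward zero
    }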
2650
2651    #[test]
2652    fn test_cast_struct_to_struct() {
2653        let a: ArrayRef = Arc::new(Int32Array::from(vec![
2654            Some(1),
2655            Some(2),
2656            None,
2657            Some(4),
2658            Some(5),
2659        ]));
2660        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
2661        let c: ArrayRef = Arc::new(StructArray::from(vec![
2662            (Arc::new(Field::new("a", DataType::Int32, true)), a),
2663            (Arc::new(Field::new("b", DataType::Utf8, true)), b),
2664        ]));
2665        // change type of "a" from Int32 to Utf8
2666        let fields = Fields::from(vec![
2667            Field::new("a", DataType::Utf8, true),
2668            Field::new("b", DataType::Utf8, true),
2669        ]);
2670        let cast_array = spark_cast(
2671            ColumnarValue::Array(c),
2672            &DataType::Struct(fields),
2673            &SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
2674        )
2675        .unwrap();
2676        if let ColumnarValue::Array(cast_array) = cast_array {
2677            assert_eq!(5, cast_array.len());
2678            let a = cast_array.as_struct().column(0).as_string::<i32>();
2679            assert_eq!("1", a.value(0));
2680        } else {
2681            unreachable!()
2682        }
2683    }
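
    // A direct sketch of `trim_end` using the examples from its doc comment:
    // trailing zeroes are stripped only when the string contains fractional
    // seconds (i.e., a '.').
    #[test]
    fn test_trim_end_fractional_zeroes() {
        assert_eq!(trim_end("1970-01-01 05:29:59.900"), "1970-01-01 05:29:59.9");
        assert_eq!(trim_end("1970-01-01 05:29:59.990"), "1970-01-01 05:29:59.99");
        assert_eq!(trim_end("1970-01-01 05:29:59.999"), "1970-01-01 05:29:59.999");
        assert_eq!(trim_end("1970-01-01 05:30:00"), "1970-01-01 05:30:00");
        assert_eq!(trim_end("1970-01-01 05:30:00.001"), "1970-01-01 05:30:00.001");
    }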
2684
2685    #[test]
2686    fn test_cast_struct_to_struct_drop_column() {
2687        let a: ArrayRef = Arc::new(Int32Array::from(vec![
2688            Some(1),
2689            Some(2),
2690            None,
2691            Some(4),
2692            Some(5),
2693        ]));
2694        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
2695        let c: ArrayRef = Arc::new(StructArray::from(vec![
2696            (Arc::new(Field::new("a", DataType::Int32, true)), a),
2697            (Arc::new(Field::new("b", DataType::Utf8, true)), b),
2698        ]));
2699        // change type of "a" from Int32 to Utf8 and drop "b"
2700        let fields = Fields::from(vec![Field::new("a", DataType::Utf8, true)]);
2701        let cast_array = spark_cast(
2702            ColumnarValue::Array(c),
2703            &DataType::Struct(fields),
2704            &SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
2705        )
2706        .unwrap();
2707        if let ColumnarValue::Array(cast_array) = cast_array {
2708            assert_eq!(5, cast_array.len());
2709            let struct_array = cast_array.as_struct();
2710            assert_eq!(1, struct_array.columns().len());
2711            let a = struct_array.column(0).as_string::<i32>();
2712            assert_eq!("1", a.value(0));
2713        } else {
2714            unreachable!()
2715        }
2716    }
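
    // A sketch of the forked `unary_dyn`: the op applies directly to a primitive
    // array, and to the values of a dictionary array while the dictionary
    // encoding (the keys) is kept intact.
    #[test]
    fn test_unary_dyn_on_primitive_and_dictionary() {
        let arr: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), None, Some(3)]));
        let doubled = unary_dyn::<_, Int64Type>(&arr, |v| v * 2).unwrap();
        assert_eq!(
            doubled.as_primitive::<Int64Type>().iter().collect::<Vec<_>>(),
            vec![Some(2), None, Some(6)]
        );

        let keys = Int32Array::from(vec![0, 1, 0]);
        let values: ArrayRef = Arc::new(Int64Array::from(vec![10, 20]));
        let dict: ArrayRef = Arc::new(DictionaryArray::new(keys, values));
        let incremented = unary_dyn::<_, Int64Type>(&dict, |v| v + 1).unwrap();
        let dict_result = incremented
            .as_any()
            .downcast_ref::<DictionaryArray<Int32Type>>()
            .unwrap();
        let new_values = dict_result.values().as_primitive::<Int64Type>();
        assert_eq!(new_values.iter().collect::<Vec<_>>(), vec![Some(11), Some(21)]);
    }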
2717
2718    #[test]
2719    // Currently the cast function depends on `f64::powi`, which has unspecified precision according to the docs:
2720    // https://doc.rust-lang.org/std/primitive.f64.html#unspecified-precision.
2721    // Miri deliberately applies random floating-point errors to these operations to expose bugs:
2722    // https://github.com/rust-lang/miri/issues/4395.
2723    // The random errors may interfere with test cases at rounding edges, so we ignore this test under Miri for now.
2724    // Once https://github.com/apache/datafusion-comet/issues/1371 is fixed, this should no longer be an issue.
2725    #[cfg_attr(miri, ignore)]
2726    fn test_cast_float_to_decimal() {
2727        let a: ArrayRef = Arc::new(Float64Array::from(vec![
2728            Some(42.),
2729            Some(0.5153125),
2730            Some(-42.4242415),
2731            Some(42e-314),
2732            Some(0.),
2733            Some(-4242.424242),
2734            Some(f64::INFINITY),
2735            Some(f64::NEG_INFINITY),
2736            Some(f64::NAN),
2737            None,
2738        ]));
2739        let b =
2740            cast_floating_point_to_decimal128::<Float64Type>(&a, 8, 6, EvalMode::Legacy).unwrap();
2741        assert_eq!(b.len(), a.len());
2742        let casted = b.as_primitive::<Decimal128Type>();
2743        assert_eq!(casted.value(0), 42000000);
2744        // https://github.com/apache/datafusion-comet/issues/1371
2745        // assert_eq!(casted.value(1), 515313);
2746        assert_eq!(casted.value(2), -42424242);
2747        assert_eq!(casted.value(3), 0);
2748        assert_eq!(casted.value(4), 0);
2749        assert!(casted.is_null(5));
2750        assert!(casted.is_null(6));
2751        assert!(casted.is_null(7));
2752        assert!(casted.is_null(8));
2753        assert!(casted.is_null(9));
2754    }
2755}