datafusion_comet_spark_expr/conversion_funcs/cast.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::timezone;
use crate::utils::array_with_timezone;
use crate::{EvalMode, SparkError, SparkResult};
use arrow::array::builder::StringBuilder;
use arrow::array::{DictionaryArray, StringArray, StructArray};
use arrow::datatypes::{DataType, Schema};
use arrow::{
    array::{
        cast::AsArray,
        types::{Date32Type, Int16Type, Int32Type, Int8Type},
        Array, ArrayRef, BooleanArray, Decimal128Array, Float32Array, Float64Array,
        GenericStringArray, Int16Array, Int32Array, Int64Array, Int8Array, OffsetSizeTrait,
        PrimitiveArray,
    },
    compute::{cast_with_options, take, unary, CastOptions},
    datatypes::{
        is_validate_decimal_precision, ArrowPrimitiveType, Decimal128Type, Float32Type,
        Float64Type, Int64Type, TimestampMicrosecondType,
    },
    error::ArrowError,
    record_batch::RecordBatch,
    util::display::FormatOptions,
};
use chrono::{DateTime, NaiveDate, TimeZone, Timelike};
use datafusion::common::{
    cast::as_generic_string_array, internal_err, Result as DataFusionResult, ScalarValue,
};
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_plan::ColumnarValue;
use num::{
    cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, Num,
    ToPrimitive,
};
use regex::Regex;
use std::str::FromStr;
use std::{
    any::Any,
    fmt::{Debug, Display, Formatter},
    hash::Hash,
    num::Wrapping,
    sync::Arc,
};

static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f");

const MICROS_PER_SECOND: i64 = 1_000_000;

static CAST_OPTIONS: CastOptions = CastOptions {
    safe: true,
    format_options: FormatOptions::new()
        .with_timestamp_tz_format(TIMESTAMP_FORMAT)
        .with_timestamp_format(TIMESTAMP_FORMAT),
};

struct TimeStampInfo {
    year: i32,
    month: u32,
    day: u32,
    hour: u32,
    minute: u32,
    second: u32,
    microsecond: u32,
}

impl Default for TimeStampInfo {
    fn default() -> Self {
        TimeStampInfo {
            year: 1,
            month: 1,
            day: 1,
            hour: 0,
            minute: 0,
            second: 0,
            microsecond: 0,
        }
    }
}

impl TimeStampInfo {
    pub fn with_year(&mut self, year: i32) -> &mut Self {
        self.year = year;
        self
    }

    pub fn with_month(&mut self, month: u32) -> &mut Self {
        self.month = month;
        self
    }

    pub fn with_day(&mut self, day: u32) -> &mut Self {
        self.day = day;
        self
    }

    pub fn with_hour(&mut self, hour: u32) -> &mut Self {
        self.hour = hour;
        self
    }

    pub fn with_minute(&mut self, minute: u32) -> &mut Self {
        self.minute = minute;
        self
    }

    pub fn with_second(&mut self, second: u32) -> &mut Self {
        self.second = second;
        self
    }

    pub fn with_microsecond(&mut self, microsecond: u32) -> &mut Self {
        self.microsecond = microsecond;
        self
    }
}
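
// Illustrative sketch (hypothetical test, not from the original source): the
// chained setters above form a small builder that the timestamp-parsing code
// can use to accumulate fields as they are parsed.
#[cfg(test)]
mod timestamp_info_example {
    use super::TimeStampInfo;

    #[test]
    fn chained_setters() {
        let mut info = TimeStampInfo::default();
        info.with_year(2020).with_month(1).with_day(12).with_hour(7);
        assert_eq!(info.year, 2020);
        assert_eq!(info.month, 1);
        assert_eq!(info.day, 12);
        assert_eq!(info.hour, 7);
    }
}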

#[derive(Debug, Eq)]
pub struct Cast {
    pub child: Arc<dyn PhysicalExpr>,
    pub data_type: DataType,
    pub cast_options: SparkCastOptions,
}

impl PartialEq for Cast {
    fn eq(&self, other: &Self) -> bool {
        self.child.eq(&other.child)
            && self.data_type.eq(&other.data_type)
            && self.cast_options.eq(&other.cast_options)
    }
}

impl Hash for Cast {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.child.hash(state);
        self.data_type.hash(state);
        self.cast_options.hash(state);
    }
}

/// Determine if Comet supports a cast, taking options such as EvalMode and Timezone into account.
pub fn cast_supported(
    from_type: &DataType,
    to_type: &DataType,
    options: &SparkCastOptions,
) -> bool {
    use DataType::*;

    let from_type = if let Dictionary(_, dt) = from_type {
        dt
    } else {
        from_type
    };

    let to_type = if let Dictionary(_, dt) = to_type {
        dt
    } else {
        to_type
    };

    if from_type == to_type {
        return true;
    }

    match (from_type, to_type) {
        (Boolean, _) => can_cast_from_boolean(to_type, options),
        (UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
            if options.allow_cast_unsigned_ints =>
        {
            true
        }
        (Int8, _) => can_cast_from_byte(to_type, options),
        (Int16, _) => can_cast_from_short(to_type, options),
        (Int32, _) => can_cast_from_int(to_type, options),
        (Int64, _) => can_cast_from_long(to_type, options),
        (Float32, _) => can_cast_from_float(to_type, options),
        (Float64, _) => can_cast_from_double(to_type, options),
        (Decimal128(p, s), _) => can_cast_from_decimal(p, s, to_type, options),
        (Timestamp(_, None), _) => can_cast_from_timestamp_ntz(to_type, options),
        (Timestamp(_, Some(_)), _) => can_cast_from_timestamp(to_type, options),
        (Utf8 | LargeUtf8, _) => can_cast_from_string(to_type, options),
        (_, Utf8 | LargeUtf8) => can_cast_to_string(from_type, options),
        (Struct(from_fields), Struct(to_fields)) => from_fields
            .iter()
            .zip(to_fields.iter())
            .all(|(a, b)| cast_supported(a.data_type(), b.data_type(), options)),
        _ => false,
    }
}
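
// A minimal sketch (hypothetical test, not from the original source) of how
// `cast_supported` can be queried, using the `SparkCastOptions` constructor
// defined later in this file.
#[cfg(test)]
mod cast_supported_example {
    use super::*;

    #[test]
    fn widening_int_cast_is_supported() {
        let opts = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
        assert!(cast_supported(&DataType::Int32, &DataType::Int64, &opts));
        // unsigned inputs are only supported when explicitly allowed
        // (used by the Parquet SchemaAdapter)
        assert!(!cast_supported(&DataType::UInt32, &DataType::Int64, &opts));
    }
}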

fn can_cast_from_string(to_type: &DataType, options: &SparkCastOptions) -> bool {
    use DataType::*;
    match to_type {
        Boolean | Int8 | Int16 | Int32 | Int64 | Binary => true,
        Float32 | Float64 => {
            // https://github.com/apache/datafusion-comet/issues/326
            // Does not support inputs ending with 'd' or 'f'. Does not support 'inf'.
            // Does not support ANSI mode.
            options.allow_incompat
        }
        Decimal128(_, _) => {
            // https://github.com/apache/datafusion-comet/issues/325
            // Does not support inputs ending with 'd' or 'f'. Does not support 'inf'.
            // Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits.
            options.allow_incompat
        }
        Date32 | Date64 => {
            // https://github.com/apache/datafusion-comet/issues/327
            // Only supports years between 262143 BC and 262142 AD
            options.allow_incompat
        }
        Timestamp(_, _) if options.eval_mode == EvalMode::Ansi => {
            // ANSI mode not supported
            false
        }
        Timestamp(_, Some(tz)) if tz.as_ref() != "UTC" => {
            // Cast will use UTC instead of $timeZoneId
            options.allow_incompat
        }
        Timestamp(_, _) => {
            // https://github.com/apache/datafusion-comet/issues/328
            // Not all valid formats are supported
            options.allow_incompat
        }
        _ => false,
    }
}

fn can_cast_to_string(from_type: &DataType, options: &SparkCastOptions) -> bool {
    use DataType::*;
    match from_type {
        Boolean | Int8 | Int16 | Int32 | Int64 | Date32 | Date64 | Timestamp(_, _) => true,
        Float32 | Float64 => {
            // There can be differences in precision. For example, the input
            // "1.4E-45" will produce 1.0E-45 instead of 1.4E-45
            true
        }
        Decimal128(_, _) => {
            // https://github.com/apache/datafusion-comet/issues/1068
            // There can be formatting differences in some cases due to Spark using
            // scientific notation where Comet does not
            true
        }
        Binary => {
            // https://github.com/apache/datafusion-comet/issues/377
            // Only works for binary data representing valid UTF-8 strings
            options.allow_incompat
        }
        Struct(fields) => fields
            .iter()
            .all(|f| can_cast_to_string(f.data_type(), options)),
        _ => false,
    }
}

fn can_cast_from_timestamp_ntz(to_type: &DataType, options: &SparkCastOptions) -> bool {
    use DataType::*;
    match to_type {
        Timestamp(_, _) | Date32 | Date64 | Utf8 => {
            // incompatible
            options.allow_incompat
        }
        _ => {
            // unsupported
            false
        }
    }
}

fn can_cast_from_timestamp(to_type: &DataType, _options: &SparkCastOptions) -> bool {
    use DataType::*;
    match to_type {
        Boolean | Int8 | Int16 => {
            // https://github.com/apache/datafusion-comet/issues/352
            // this seems like an edge case that isn't important for us to support
            false
        }
        Int64 => {
            // https://github.com/apache/datafusion-comet/issues/352
            true
        }
        Date32 | Date64 | Utf8 | Decimal128(_, _) => true,
        _ => {
            // unsupported
            false
        }
    }
}

fn can_cast_from_boolean(to_type: &DataType, _: &SparkCastOptions) -> bool {
    use DataType::*;
    matches!(to_type, Int8 | Int16 | Int32 | Int64 | Float32 | Float64)
}

fn can_cast_from_byte(to_type: &DataType, _: &SparkCastOptions) -> bool {
    use DataType::*;
    matches!(
        to_type,
        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
    )
}

fn can_cast_from_short(to_type: &DataType, _: &SparkCastOptions) -> bool {
    use DataType::*;
    matches!(
        to_type,
        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
    )
}

fn can_cast_from_int(to_type: &DataType, options: &SparkCastOptions) -> bool {
    use DataType::*;
    match to_type {
        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Utf8 => true,
        Decimal128(_, _) => {
            // incompatible: no overflow check
            options.allow_incompat
        }
        _ => false,
    }
}

fn can_cast_from_long(to_type: &DataType, options: &SparkCastOptions) -> bool {
    use DataType::*;
    match to_type {
        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 => true,
        Decimal128(_, _) => {
            // incompatible: no overflow check
            options.allow_incompat
        }
        _ => false,
    }
}

fn can_cast_from_float(to_type: &DataType, _: &SparkCastOptions) -> bool {
    use DataType::*;
    matches!(
        to_type,
        Boolean | Int8 | Int16 | Int32 | Int64 | Float64 | Decimal128(_, _)
    )
}

fn can_cast_from_double(to_type: &DataType, _: &SparkCastOptions) -> bool {
    use DataType::*;
    matches!(
        to_type,
        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Decimal128(_, _)
    )
}

fn can_cast_from_decimal(
    p1: &u8,
    _s1: &i8,
    to_type: &DataType,
    options: &SparkCastOptions,
) -> bool {
    use DataType::*;
    match to_type {
        Int8 | Int16 | Int32 | Int64 | Float32 | Float64 => true,
        Decimal128(p2, _) => {
            if p2 < p1 {
                // https://github.com/apache/datafusion/issues/13492
                // Incompatible(Some("Casting to smaller precision is not supported"))
                options.allow_incompat
            } else {
                true
            }
        }
        _ => false,
    }
}

macro_rules! cast_utf8_to_int {
    ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{
        let len = $array.len();
        let mut cast_array = PrimitiveArray::<$array_type>::builder(len);
        for i in 0..len {
            if $array.is_null(i) {
                cast_array.append_null()
            } else if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? {
                cast_array.append_value(cast_value);
            } else {
                cast_array.append_null()
            }
        }
        let result: SparkResult<ArrayRef> = Ok(Arc::new(cast_array.finish()) as ArrayRef);
        result
    }};
}

macro_rules! cast_utf8_to_timestamp {
    ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident, $tz:expr) => {{
        let len = $array.len();
        let mut cast_array = PrimitiveArray::<$array_type>::builder(len).with_timezone("UTC");
        for i in 0..len {
            if $array.is_null(i) {
                cast_array.append_null()
            } else if let Ok(Some(cast_value)) =
                $cast_method($array.value(i).trim(), $eval_mode, $tz)
            {
                cast_array.append_value(cast_value);
            } else {
                cast_array.append_null()
            }
        }
        let result: ArrayRef = Arc::new(cast_array.finish()) as ArrayRef;
        result
    }};
}

macro_rules! cast_float_to_string {
    ($from:expr, $eval_mode:expr, $type:ty, $output_type:ty, $offset_type:ty) => {{
        fn cast<OffsetSize>(
            from: &dyn Array,
            _eval_mode: EvalMode,
        ) -> SparkResult<ArrayRef>
        where
            OffsetSize: OffsetSizeTrait, {
                let array = from.as_any().downcast_ref::<$output_type>().unwrap();

                // If the absolute number is less than 10,000,000 and greater than or equal to
                // 0.001, the result is expressed without scientific notation with at least one
                // digit on either side of the decimal point. Otherwise, Spark uses a mantissa
                // followed by E and an exponent. The mantissa has an optional leading minus sign
                // followed by one digit to the left of the decimal point, and the minimal number
                // of digits greater than zero to the right. The exponent has an optional leading
                // minus sign.
                // source: https://docs.databricks.com/en/sql/language-manual/functions/cast.html

                const LOWER_SCIENTIFIC_BOUND: $type = 0.001;
                const UPPER_SCIENTIFIC_BOUND: $type = 10000000.0;

                let output_array = array
                    .iter()
                    .map(|value| match value {
                        Some(value) if value == <$type>::INFINITY => Ok(Some("Infinity".to_string())),
                        Some(value) if value == <$type>::NEG_INFINITY => Ok(Some("-Infinity".to_string())),
                        Some(value)
                            if (value.abs() < UPPER_SCIENTIFIC_BOUND
                                && value.abs() >= LOWER_SCIENTIFIC_BOUND)
                                || value.abs() == 0.0 =>
                        {
                            let trailing_zero = if value.fract() == 0.0 { ".0" } else { "" };

                            Ok(Some(format!("{value}{trailing_zero}")))
                        }
                        Some(value)
                            if value.abs() >= UPPER_SCIENTIFIC_BOUND
                                || value.abs() < LOWER_SCIENTIFIC_BOUND =>
                        {
                            let formatted = format!("{value:E}");

                            if formatted.contains(".") {
                                Ok(Some(formatted))
                            } else {
                                // `formatted` is already in scientific notation and can be split up by E
                                // in order to add the missing trailing 0 which gets removed for numbers with a fraction of 0.0
                                let prepare_number: Vec<&str> = formatted.split("E").collect();

                                let coefficient = prepare_number[0];

                                let exponent = prepare_number[1];

                                Ok(Some(format!("{coefficient}.0E{exponent}")))
                            }
                        }
                        Some(value) => Ok(Some(value.to_string())),
                        _ => Ok(None),
                    })
                    .collect::<Result<GenericStringArray<OffsetSize>, SparkError>>()?;

                Ok(Arc::new(output_array))
            }

        cast::<$offset_type>($from, $eval_mode)
    }};
}
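
// A sketch of the formatting rule above as a test (hypothetical, not from the
// original source); `spark_cast_float64_to_utf8` is defined later in this file.
#[cfg(test)]
mod float_to_string_example {
    use super::*;

    #[test]
    fn spark_style_float_formatting() {
        let input = Float64Array::from(vec![1234.5_f64, 1.0e-4, 1.0e7, f64::INFINITY]);
        let result = spark_cast_float64_to_utf8::<i32>(&input, EvalMode::Legacy).unwrap();
        let strings = result.as_string::<i32>();
        assert_eq!(strings.value(0), "1234.5"); // within [0.001, 1e7): plain notation
        assert_eq!(strings.value(1), "1.0E-4"); // below the lower bound: scientific
        assert_eq!(strings.value(2), "1.0E7"); // at the upper bound: scientific
        assert_eq!(strings.value(3), "Infinity");
    }
}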

macro_rules! cast_int_to_int_macro {
    (
        $array: expr,
        $eval_mode:expr,
        $from_arrow_primitive_type: ty,
        $to_arrow_primitive_type: ty,
        $from_data_type: expr,
        $to_native_type: ty,
        $spark_from_data_type_name: expr,
        $spark_to_data_type_name: expr
    ) => {{
        let cast_array = $array
            .as_any()
            .downcast_ref::<PrimitiveArray<$from_arrow_primitive_type>>()
            .unwrap();
        let spark_int_literal_suffix = match $from_data_type {
            &DataType::Int64 => "L",
            &DataType::Int16 => "S",
            &DataType::Int8 => "T",
            _ => "",
        };

        let output_array = match $eval_mode {
            EvalMode::Legacy => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        Ok::<Option<$to_native_type>, SparkError>(Some(value as $to_native_type))
                    }
                    _ => Ok(None),
                })
                .collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, _>>(),
            _ => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let res = <$to_native_type>::try_from(value);
                        if res.is_err() {
                            Err(cast_overflow(
                                &(value.to_string() + spark_int_literal_suffix),
                                $spark_from_data_type_name,
                                $spark_to_data_type_name,
                            ))
                        } else {
                            Ok::<Option<$to_native_type>, SparkError>(Some(res.unwrap()))
                        }
                    }
                    _ => Ok(None),
                })
                .collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, _>>(),
        }?;
        let result: SparkResult<ArrayRef> = Ok(Arc::new(output_array) as ArrayRef);
        result
    }};
}

// When Spark casts to byte/short types, it does not cast directly to byte/short.
// It casts to int first and then to byte/short. Because of potential overflows in
// the int cast, this can produce unexpected byte/short results. We replicate that
// behavior here.
macro_rules! cast_float_to_int16_down {
    (
        $array:expr,
        $eval_mode:expr,
        $src_array_type:ty,
        $dest_array_type:ty,
        $rust_src_type:ty,
        $rust_dest_type:ty,
        $src_type_str:expr,
        $dest_type_str:expr,
        $format_str:expr
    ) => {{
        let cast_array = $array
            .as_any()
            .downcast_ref::<$src_array_type>()
            .expect(concat!("Expected a ", stringify!($src_array_type)));

        let output_array = match $eval_mode {
            EvalMode::Ansi => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let is_overflow = value.is_nan() || value.abs() as i32 == i32::MAX;
                        if is_overflow {
                            return Err(cast_overflow(
                                &format!($format_str, value).replace("e", "E"),
                                $src_type_str,
                                $dest_type_str,
                            ));
                        }
                        let i32_value = value as i32;
                        <$rust_dest_type>::try_from(i32_value)
                            .map_err(|_| {
                                cast_overflow(
                                    &format!($format_str, value).replace("e", "E"),
                                    $src_type_str,
                                    $dest_type_str,
                                )
                            })
                            .map(Some)
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
            _ => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let i32_value = value as i32;
                        Ok::<Option<$rust_dest_type>, SparkError>(Some(
                            i32_value as $rust_dest_type,
                        ))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
        };
        Ok(Arc::new(output_array) as ArrayRef)
    }};
}
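
// A worked example of the int-first narrowing above (hypothetical test, not
// from the original source): in legacy mode, 300.5_f32 truncates to the i32
// value 300, which then wraps to 44_i8 (300 - 256), matching Spark rather
// than saturating at i8::MAX.
#[cfg(test)]
mod float_to_byte_example {
    use super::*;

    #[test]
    fn legacy_mode_narrows_via_int() {
        let input = Float32Array::from(vec![300.5_f32]);
        let result = spark_cast_nonintegral_numeric_to_integral(
            &input,
            EvalMode::Legacy,
            &DataType::Float32,
            &DataType::Int8,
        )
        .unwrap();
        let bytes = result.as_any().downcast_ref::<Int8Array>().unwrap();
        assert_eq!(bytes.value(0), 44);
    }
}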

macro_rules! cast_float_to_int32_up {
    (
        $array:expr,
        $eval_mode:expr,
        $src_array_type:ty,
        $dest_array_type:ty,
        $rust_src_type:ty,
        $rust_dest_type:ty,
        $src_type_str:expr,
        $dest_type_str:expr,
        $max_dest_val:expr,
        $format_str:expr
    ) => {{
        let cast_array = $array
            .as_any()
            .downcast_ref::<$src_array_type>()
            .expect(concat!("Expected a ", stringify!($src_array_type)));

        let output_array = match $eval_mode {
            EvalMode::Ansi => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let is_overflow =
                            value.is_nan() || value.abs() as $rust_dest_type == $max_dest_val;
                        if is_overflow {
                            return Err(cast_overflow(
                                &format!($format_str, value).replace("e", "E"),
                                $src_type_str,
                                $dest_type_str,
                            ));
                        }
                        Ok(Some(value as $rust_dest_type))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
            _ => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        Ok::<Option<$rust_dest_type>, SparkError>(Some(value as $rust_dest_type))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
        };
        Ok(Arc::new(output_array) as ArrayRef)
    }};
}

// As above: when Spark casts to byte/short types it goes through int first, and
// overflows in the int step can produce unexpected byte/short results. We
// replicate that behavior here.
macro_rules! cast_decimal_to_int16_down {
    (
        $array:expr,
        $eval_mode:expr,
        $dest_array_type:ty,
        $rust_dest_type:ty,
        $dest_type_str:expr,
        $precision:expr,
        $scale:expr
    ) => {{
        let cast_array = $array
            .as_any()
            .downcast_ref::<Decimal128Array>()
            .expect(concat!("Expected a Decimal128ArrayType"));

        let output_array = match $eval_mode {
            EvalMode::Ansi => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let divisor = 10_i128.pow($scale as u32);
                        let (truncated, decimal) = (value / divisor, (value % divisor).abs());
                        let is_overflow = truncated.abs() > i32::MAX.into();
                        if is_overflow {
                            return Err(cast_overflow(
                                &format!("{}.{}BD", truncated, decimal),
                                &format!("DECIMAL({},{})", $precision, $scale),
                                $dest_type_str,
                            ));
                        }
                        let i32_value = truncated as i32;
                        <$rust_dest_type>::try_from(i32_value)
                            .map_err(|_| {
                                cast_overflow(
                                    &format!("{}.{}BD", truncated, decimal),
                                    &format!("DECIMAL({},{})", $precision, $scale),
                                    $dest_type_str,
                                )
                            })
                            .map(Some)
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
            _ => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let divisor = 10_i128.pow($scale as u32);
                        let i32_value = (value / divisor) as i32;
                        Ok::<Option<$rust_dest_type>, SparkError>(Some(
                            i32_value as $rust_dest_type,
                        ))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
        };
        Ok(Arc::new(output_array) as ArrayRef)
    }};
}
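
// A worked example of the decimal narrowing above (a sketch): a DECIMAL(5,2)
// value stored as the i128 12345 represents 123.45; with divisor 10^2 the
// truncated part is 123, which fits in i8, so casting to TINYINT yields 123.
// In ANSI mode, a truncated value outside the destination range is reported
// as a cast overflow with the value formatted like `123.45BD`.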

macro_rules! cast_decimal_to_int32_up {
    (
        $array:expr,
        $eval_mode:expr,
        $dest_array_type:ty,
        $rust_dest_type:ty,
        $dest_type_str:expr,
        $max_dest_val:expr,
        $precision:expr,
        $scale:expr
    ) => {{
        let cast_array = $array
            .as_any()
            .downcast_ref::<Decimal128Array>()
            .expect(concat!("Expected a Decimal128ArrayType"));

        let output_array = match $eval_mode {
            EvalMode::Ansi => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let divisor = 10_i128.pow($scale as u32);
                        let (truncated, decimal) = (value / divisor, (value % divisor).abs());
                        let is_overflow = truncated.abs() > $max_dest_val.into();
                        if is_overflow {
                            return Err(cast_overflow(
                                &format!("{}.{}BD", truncated, decimal),
                                &format!("DECIMAL({},{})", $precision, $scale),
                                $dest_type_str,
                            ));
                        }
                        Ok(Some(truncated as $rust_dest_type))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
            _ => cast_array
                .iter()
                .map(|value| match value {
                    Some(value) => {
                        let divisor = 10_i128.pow($scale as u32);
                        let truncated = value / divisor;
                        Ok::<Option<$rust_dest_type>, SparkError>(Some(
                            truncated as $rust_dest_type,
                        ))
                    }
                    None => Ok(None),
                })
                .collect::<Result<$dest_array_type, _>>()?,
        };
        Ok(Arc::new(output_array) as ArrayRef)
    }};
}

impl Cast {
    pub fn new(
        child: Arc<dyn PhysicalExpr>,
        data_type: DataType,
        cast_options: SparkCastOptions,
    ) -> Self {
        Self {
            child,
            data_type,
            cast_options,
        }
    }
}

/// Spark cast options
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct SparkCastOptions {
    /// Spark evaluation mode
    pub eval_mode: EvalMode,
    /// When casting from/to timezone-related types, we need the timezone, which is resolved
    /// to the session-local timezone by Spark's analyzer.
    // TODO we should change timezone to Tz to avoid repeated parsing
    pub timezone: String,
    /// Allow casts that are supported but not guaranteed to be 100% compatible
    pub allow_incompat: bool,
    /// Support casting unsigned ints to signed ints (used by the Parquet SchemaAdapter)
    pub allow_cast_unsigned_ints: bool,
    /// The cast logic is also used for adapting Parquet schemas, and this flag indicates
    /// that use case
    pub is_adapting_schema: bool,
    /// String to use to represent null values
    pub null_string: String,
}

impl SparkCastOptions {
    pub fn new(eval_mode: EvalMode, timezone: &str, allow_incompat: bool) -> Self {
        Self {
            eval_mode,
            timezone: timezone.to_string(),
            allow_incompat,
            allow_cast_unsigned_ints: false,
            is_adapting_schema: false,
            null_string: "null".to_string(),
        }
    }

    pub fn new_without_timezone(eval_mode: EvalMode, allow_incompat: bool) -> Self {
        Self {
            eval_mode,
            timezone: "".to_string(),
            allow_incompat,
            allow_cast_unsigned_ints: false,
            is_adapting_schema: false,
            null_string: "null".to_string(),
        }
    }
}
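
// A minimal sketch of constructing options (illustrative, not from the
// original source): `new` is used when the cast may involve timezone-aware
// types, while `new_without_timezone` suffices for purely numeric casts.
//
//     let ansi = SparkCastOptions::new(EvalMode::Ansi, "UTC", false);
//     let lax = SparkCastOptions::new_without_timezone(EvalMode::Legacy, true);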

/// Spark-compatible cast implementation. Defers to DataFusion's cast where that is known
/// to be compatible with Spark, and returns an error when an unsupported and
/// non-DataFusion-compatible cast is requested.
pub fn spark_cast(
    arg: ColumnarValue,
    data_type: &DataType,
    cast_options: &SparkCastOptions,
) -> DataFusionResult<ColumnarValue> {
    match arg {
        ColumnarValue::Array(array) => Ok(ColumnarValue::Array(cast_array(
            array,
            data_type,
            cast_options,
        )?)),
        ColumnarValue::Scalar(scalar) => {
            // Note that normally CAST(scalar) is constant-folded on the Spark JVM side. However,
            // in some cases, e.g., scalar subqueries, Spark does not fold it, so we need to
            // handle it here.
            let array = scalar.to_array()?;
            let scalar =
                ScalarValue::try_from_array(&cast_array(array, data_type, cast_options)?, 0)?;
            Ok(ColumnarValue::Scalar(scalar))
        }
    }
}
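
// A minimal sketch of the top-level entry point (hypothetical test, not from
// the original source), assuming a UTC session timezone.
#[cfg(test)]
mod spark_cast_example {
    use super::*;

    #[test]
    fn cast_int64_array_to_int32() {
        let input: ArrayRef = Arc::new(Int64Array::from(vec![1_i64, 2, 3]));
        let opts = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
        let result = spark_cast(ColumnarValue::Array(input), &DataType::Int32, &opts).unwrap();
        let ColumnarValue::Array(array) = result else {
            unreachable!()
        };
        assert_eq!(array.data_type(), &DataType::Int32);
    }
}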

fn cast_array(
    array: ArrayRef,
    to_type: &DataType,
    cast_options: &SparkCastOptions,
) -> DataFusionResult<ArrayRef> {
    use DataType::*;
    let array = array_with_timezone(array, cast_options.timezone.clone(), Some(to_type))?;
    let from_type = array.data_type().clone();

    let native_cast_options: CastOptions = CastOptions {
        // ANSI mode is "unsafe": invalid inputs raise an error instead of becoming null
        safe: !matches!(cast_options.eval_mode, EvalMode::Ansi),
        format_options: FormatOptions::new()
            .with_timestamp_tz_format(TIMESTAMP_FORMAT)
            .with_timestamp_format(TIMESTAMP_FORMAT),
    };

    let array = match &from_type {
        Dictionary(key_type, value_type)
            if key_type.as_ref() == &Int32
                && (value_type.as_ref() == &Utf8
                    || value_type.as_ref() == &LargeUtf8
                    || value_type.as_ref() == &Binary
                    || value_type.as_ref() == &LargeBinary) =>
        {
            let dict_array = array
                .as_any()
                .downcast_ref::<DictionaryArray<Int32Type>>()
                .expect("Expected a dictionary array");

            let casted_dictionary = DictionaryArray::<Int32Type>::new(
                dict_array.keys().clone(),
                cast_array(Arc::clone(dict_array.values()), to_type, cast_options)?,
            );

            let casted_result = match to_type {
                Dictionary(_, _) => Arc::new(casted_dictionary.clone()),
                _ => take(casted_dictionary.values().as_ref(), dict_array.keys(), None)?,
            };
            return Ok(spark_cast_postprocess(casted_result, &from_type, to_type));
        }
        _ => array,
    };
    let from_type = array.data_type();
    let eval_mode = cast_options.eval_mode;

    let cast_result = match (from_type, to_type) {
        (Utf8, Boolean) => spark_cast_utf8_to_boolean::<i32>(&array, eval_mode),
        (LargeUtf8, Boolean) => spark_cast_utf8_to_boolean::<i64>(&array, eval_mode),
        (Utf8, Timestamp(_, _)) => {
            cast_string_to_timestamp(&array, to_type, eval_mode, &cast_options.timezone)
        }
        (Utf8, Date32) => cast_string_to_date(&array, to_type, eval_mode),
        (Int64, Int32)
        | (Int64, Int16)
        | (Int64, Int8)
        | (Int32, Int16)
        | (Int32, Int8)
        | (Int16, Int8)
            if eval_mode != EvalMode::Try =>
        {
            spark_cast_int_to_int(&array, eval_mode, from_type, to_type)
        }
        (Utf8, Int8 | Int16 | Int32 | Int64) => {
            cast_string_to_int::<i32>(to_type, &array, eval_mode)
        }
        (LargeUtf8, Int8 | Int16 | Int32 | Int64) => {
            cast_string_to_int::<i64>(to_type, &array, eval_mode)
        }
        (Float64, Utf8) => spark_cast_float64_to_utf8::<i32>(&array, eval_mode),
        (Float64, LargeUtf8) => spark_cast_float64_to_utf8::<i64>(&array, eval_mode),
        (Float32, Utf8) => spark_cast_float32_to_utf8::<i32>(&array, eval_mode),
        (Float32, LargeUtf8) => spark_cast_float32_to_utf8::<i64>(&array, eval_mode),
        (Float32, Decimal128(precision, scale)) => {
            cast_float32_to_decimal128(&array, *precision, *scale, eval_mode)
        }
        (Float64, Decimal128(precision, scale)) => {
            cast_float64_to_decimal128(&array, *precision, *scale, eval_mode)
        }
        (Float32, Int8)
        | (Float32, Int16)
        | (Float32, Int32)
        | (Float32, Int64)
        | (Float64, Int8)
        | (Float64, Int16)
        | (Float64, Int32)
        | (Float64, Int64)
        | (Decimal128(_, _), Int8)
        | (Decimal128(_, _), Int16)
        | (Decimal128(_, _), Int32)
        | (Decimal128(_, _), Int64)
            if eval_mode != EvalMode::Try =>
        {
            spark_cast_nonintegral_numeric_to_integral(&array, eval_mode, from_type, to_type)
        }
        (Struct(_), Utf8) => Ok(casts_struct_to_string(array.as_struct(), cast_options)?),
        (Struct(_), Struct(_)) => Ok(cast_struct_to_struct(
            array.as_struct(),
            from_type,
            to_type,
            cast_options,
        )?),
        (UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
            if cast_options.allow_cast_unsigned_ints =>
        {
            Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
        }
        _ if cast_options.is_adapting_schema
            || is_datafusion_spark_compatible(from_type, to_type, cast_options.allow_incompat) =>
        {
            // use DataFusion cast only when we know that it is compatible with Spark
            Ok(cast_with_options(&array, to_type, &native_cast_options)?)
        }
        _ => {
            // we should never reach this code because the Scala code should be checking
            // for supported cast operations and falling back to Spark for anything that
            // is not yet supported
            Err(SparkError::Internal(format!(
                "Native cast invoked for unsupported cast from {from_type:?} to {to_type:?}"
            )))
        }
    };
    Ok(spark_cast_postprocess(cast_result?, from_type, to_type))
}

/// Determines if DataFusion supports the given cast in a way that is
/// compatible with Spark
fn is_datafusion_spark_compatible(
    from_type: &DataType,
    to_type: &DataType,
    allow_incompat: bool,
) -> bool {
    if from_type == to_type {
        return true;
    }
    match from_type {
        DataType::Null => {
            matches!(to_type, DataType::List(_))
        }
        DataType::Boolean => matches!(
            to_type,
            DataType::Int8
                | DataType::Int16
                | DataType::Int32
                | DataType::Int64
                | DataType::Float32
                | DataType::Float64
                | DataType::Utf8
        ),
        DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
            // note that the cast from Int32/Int64 -> Decimal128 here is actually
            // not compatible with Spark (no overflow checks) but we have tests that
            // rely on this cast working so we have to leave it here for now
            matches!(
                to_type,
                DataType::Boolean
                    | DataType::Int8
                    | DataType::Int16
                    | DataType::Int32
                    | DataType::Int64
                    | DataType::Float32
                    | DataType::Float64
                    | DataType::Decimal128(_, _)
                    | DataType::Utf8
            )
        }
        DataType::Float32 | DataType::Float64 => matches!(
            to_type,
            DataType::Boolean
                | DataType::Int8
                | DataType::Int16
                | DataType::Int32
                | DataType::Int64
                | DataType::Float32
                | DataType::Float64
        ),
        DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => matches!(
            to_type,
            DataType::Int8
                | DataType::Int16
                | DataType::Int32
                | DataType::Int64
                | DataType::Float32
                | DataType::Float64
                | DataType::Decimal128(_, _)
                | DataType::Decimal256(_, _)
                | DataType::Utf8 // note that there can be formatting differences
        ),
        DataType::Utf8 if allow_incompat => matches!(
            to_type,
            DataType::Binary | DataType::Float32 | DataType::Float64 | DataType::Decimal128(_, _)
        ),
        DataType::Utf8 => matches!(to_type, DataType::Binary),
        DataType::Date32 => matches!(to_type, DataType::Utf8),
        DataType::Timestamp(_, _) => {
            matches!(
                to_type,
                DataType::Int64 | DataType::Date32 | DataType::Utf8 | DataType::Timestamp(_, _)
            )
        }
        DataType::Binary => {
            // note that this is not completely Spark compatible because
            // DataFusion only supports binary data containing valid UTF-8 strings
            matches!(to_type, DataType::Utf8)
        }
        _ => false,
    }
}

/// Cast between struct types based on logic in
/// `org.apache.spark.sql.catalyst.expressions.Cast#castStruct`.
fn cast_struct_to_struct(
    array: &StructArray,
    from_type: &DataType,
    to_type: &DataType,
    cast_options: &SparkCastOptions,
) -> DataFusionResult<ArrayRef> {
    match (from_type, to_type) {
        (DataType::Struct(from_fields), DataType::Struct(to_fields)) => {
            let cast_fields: Vec<ArrayRef> = from_fields
                .iter()
                .enumerate()
                .zip(to_fields.iter())
                .map(|((idx, _from), to)| {
                    let from_field = Arc::clone(array.column(idx));
                    let array_length = from_field.len();
                    let cast_result = spark_cast(
                        ColumnarValue::from(from_field),
                        to.data_type(),
                        cast_options,
                    )
                    .unwrap();
                    cast_result.to_array(array_length).unwrap()
                })
                .collect();

            Ok(Arc::new(StructArray::new(
                to_fields.clone(),
                cast_fields,
                array.nulls().cloned(),
            )))
        }
        _ => unreachable!(),
    }
}
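
// For example (a sketch): casting STRUCT<a: INT, b: INT> to
// STRUCT<a: BIGINT, b: STRING> casts column 0 with Int32 -> Int64 and
// column 1 with Int32 -> Utf8, then rebuilds the struct with the target
// fields and the original null buffer.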

fn casts_struct_to_string(
    array: &StructArray,
    spark_cast_options: &SparkCastOptions,
) -> DataFusionResult<ArrayRef> {
    // cast each field to a string
    let string_arrays: Vec<ArrayRef> = array
        .columns()
        .iter()
        .map(|arr| {
            spark_cast(
                ColumnarValue::Array(Arc::clone(arr)),
                &DataType::Utf8,
                spark_cast_options,
            )
            .and_then(|cv| cv.into_array(arr.len()))
        })
        .collect::<DataFusionResult<Vec<_>>>()?;
    let string_arrays: Vec<&StringArray> =
        string_arrays.iter().map(|arr| arr.as_string()).collect();
    // build the struct string in the format `{value1, value2, ...}`, using the
    // configured null string for null fields (field names are not included)
    let mut builder = StringBuilder::with_capacity(array.len(), array.len() * 16);
    let mut str = String::with_capacity(array.len() * 16);
    for row_index in 0..array.len() {
        if array.is_null(row_index) {
            builder.append_null();
        } else {
            str.clear();
            let mut any_fields_written = false;
            str.push('{');
            for field in &string_arrays {
                if any_fields_written {
                    str.push_str(", ");
                }
                if field.is_null(row_index) {
                    str.push_str(&spark_cast_options.null_string);
                } else {
                    str.push_str(field.value(row_index));
                }
                any_fields_written = true;
            }
            str.push('}');
            builder.append_value(&str);
        }
    }
    Ok(Arc::new(builder.finish()))
}
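
// Example output (a sketch): a struct row with field values (1, "a") renders
// as `{1, a}`, and a null second field renders as `{1, null}` with the
// default `null_string`.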

fn cast_string_to_int<OffsetSize: OffsetSizeTrait>(
    to_type: &DataType,
    array: &ArrayRef,
    eval_mode: EvalMode,
) -> SparkResult<ArrayRef> {
    let string_array = array
        .as_any()
        .downcast_ref::<GenericStringArray<OffsetSize>>()
        .expect("cast_string_to_int expected a string array");

    let cast_array: ArrayRef = match to_type {
        DataType::Int8 => cast_utf8_to_int!(string_array, eval_mode, Int8Type, cast_string_to_i8)?,
        DataType::Int16 => {
            cast_utf8_to_int!(string_array, eval_mode, Int16Type, cast_string_to_i16)?
        }
        DataType::Int32 => {
            cast_utf8_to_int!(string_array, eval_mode, Int32Type, cast_string_to_i32)?
        }
        DataType::Int64 => {
            cast_utf8_to_int!(string_array, eval_mode, Int64Type, cast_string_to_i64)?
        }
        dt => unreachable!(
            "{}",
            format!("invalid integer type {dt} in cast from string")
        ),
    };
    Ok(cast_array)
}

fn cast_string_to_date(
    array: &ArrayRef,
    to_type: &DataType,
    eval_mode: EvalMode,
) -> SparkResult<ArrayRef> {
    let string_array = array
        .as_any()
        .downcast_ref::<GenericStringArray<i32>>()
        .expect("Expected a string array");

    if to_type != &DataType::Date32 {
        unreachable!("Invalid data type {:?} in cast from string", to_type);
    }

    let len = string_array.len();
    let mut cast_array = PrimitiveArray::<Date32Type>::builder(len);

    for i in 0..len {
        let value = if string_array.is_null(i) {
            None
        } else {
            match date_parser(string_array.value(i), eval_mode) {
                Ok(Some(cast_value)) => Some(cast_value),
                Ok(None) => None,
                Err(e) => return Err(e),
            }
        };

        match value {
            Some(cast_value) => cast_array.append_value(cast_value),
            None => cast_array.append_null(),
        }
    }

    Ok(Arc::new(cast_array.finish()) as ArrayRef)
}

fn cast_string_to_timestamp(
    array: &ArrayRef,
    to_type: &DataType,
    eval_mode: EvalMode,
    timezone_str: &str,
) -> SparkResult<ArrayRef> {
    let string_array = array
        .as_any()
        .downcast_ref::<GenericStringArray<i32>>()
        .expect("Expected a string array");

    let tz = &timezone::Tz::from_str(timezone_str).unwrap();

    let cast_array: ArrayRef = match to_type {
        DataType::Timestamp(_, _) => {
            cast_utf8_to_timestamp!(
                string_array,
                eval_mode,
                TimestampMicrosecondType,
                timestamp_parser,
                tz
            )
        }
        _ => unreachable!("Invalid data type {:?} in cast from string", to_type),
    };
    Ok(cast_array)
}

fn cast_float64_to_decimal128(
    array: &dyn Array,
    precision: u8,
    scale: i8,
    eval_mode: EvalMode,
) -> SparkResult<ArrayRef> {
    cast_floating_point_to_decimal128::<Float64Type>(array, precision, scale, eval_mode)
}

fn cast_float32_to_decimal128(
    array: &dyn Array,
    precision: u8,
    scale: i8,
    eval_mode: EvalMode,
) -> SparkResult<ArrayRef> {
    cast_floating_point_to_decimal128::<Float32Type>(array, precision, scale, eval_mode)
}

fn cast_floating_point_to_decimal128<T: ArrowPrimitiveType>(
    array: &dyn Array,
    precision: u8,
    scale: i8,
    eval_mode: EvalMode,
) -> SparkResult<ArrayRef>
where
    <T as ArrowPrimitiveType>::Native: AsPrimitive<f64>,
{
    let input = array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
    let mut cast_array = PrimitiveArray::<Decimal128Type>::builder(input.len());

    let mul = 10_f64.powi(scale as i32);

    for i in 0..input.len() {
        if input.is_null(i) {
            cast_array.append_null();
            continue;
        }

        let input_value = input.value(i).as_();
        if let Some(v) = (input_value * mul).round().to_i128() {
            if is_validate_decimal_precision(v, precision) {
                cast_array.append_value(v);
                continue;
            }
        };

        if eval_mode == EvalMode::Ansi {
            return Err(SparkError::NumericValueOutOfRange {
                value: input_value.to_string(),
                precision,
                scale,
            });
        }
        cast_array.append_null();
    }

    let res = Arc::new(
        cast_array
            .with_precision_and_scale(precision, scale)?
            .finish(),
    ) as ArrayRef;
    Ok(res)
}

fn spark_cast_float64_to_utf8<OffsetSize>(
    from: &dyn Array,
    _eval_mode: EvalMode,
) -> SparkResult<ArrayRef>
where
    OffsetSize: OffsetSizeTrait,
{
    cast_float_to_string!(from, _eval_mode, f64, Float64Array, OffsetSize)
}

fn spark_cast_float32_to_utf8<OffsetSize>(
    from: &dyn Array,
    _eval_mode: EvalMode,
) -> SparkResult<ArrayRef>
where
    OffsetSize: OffsetSizeTrait,
{
    cast_float_to_string!(from, _eval_mode, f32, Float32Array, OffsetSize)
}

fn spark_cast_int_to_int(
    array: &dyn Array,
    eval_mode: EvalMode,
    from_type: &DataType,
    to_type: &DataType,
) -> SparkResult<ArrayRef> {
    match (from_type, to_type) {
        (DataType::Int64, DataType::Int32) => cast_int_to_int_macro!(
            array, eval_mode, Int64Type, Int32Type, from_type, i32, "BIGINT", "INT"
        ),
        (DataType::Int64, DataType::Int16) => cast_int_to_int_macro!(
            array, eval_mode, Int64Type, Int16Type, from_type, i16, "BIGINT", "SMALLINT"
        ),
        (DataType::Int64, DataType::Int8) => cast_int_to_int_macro!(
            array, eval_mode, Int64Type, Int8Type, from_type, i8, "BIGINT", "TINYINT"
        ),
        (DataType::Int32, DataType::Int16) => cast_int_to_int_macro!(
            array, eval_mode, Int32Type, Int16Type, from_type, i16, "INT", "SMALLINT"
        ),
        (DataType::Int32, DataType::Int8) => cast_int_to_int_macro!(
            array, eval_mode, Int32Type, Int8Type, from_type, i8, "INT", "TINYINT"
        ),
        (DataType::Int16, DataType::Int8) => cast_int_to_int_macro!(
            array, eval_mode, Int16Type, Int8Type, from_type, i8, "SMALLINT", "TINYINT"
        ),
        _ => unreachable!(
            "{}",
            format!("invalid integer type {to_type} in cast from {from_type}")
        ),
    }
}

fn spark_cast_utf8_to_boolean<OffsetSize>(
    from: &dyn Array,
    eval_mode: EvalMode,
) -> SparkResult<ArrayRef>
where
    OffsetSize: OffsetSizeTrait,
{
    let array = from
        .as_any()
        .downcast_ref::<GenericStringArray<OffsetSize>>()
        .unwrap();

    let output_array = array
        .iter()
        .map(|value| match value {
            Some(value) => match value.to_ascii_lowercase().trim() {
                "t" | "true" | "y" | "yes" | "1" => Ok(Some(true)),
                "f" | "false" | "n" | "no" | "0" => Ok(Some(false)),
                _ if eval_mode == EvalMode::Ansi => Err(SparkError::CastInvalidValue {
                    value: value.to_string(),
                    from_type: "STRING".to_string(),
                    to_type: "BOOLEAN".to_string(),
                }),
                _ => Ok(None),
            },
            _ => Ok(None),
        })
        .collect::<Result<BooleanArray, _>>()?;

    Ok(Arc::new(output_array))
}
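
// A sketch of the boolean token handling above (hypothetical test, not from
// the original source): tokens are trimmed and matched case-insensitively,
// and unrecognized values become null in legacy mode.
#[cfg(test)]
mod utf8_to_boolean_example {
    use super::*;

    #[test]
    fn spark_truthy_and_falsy_tokens() {
        let input = StringArray::from(vec![Some("TRUE"), Some(" no "), Some("maybe")]);
        let result = spark_cast_utf8_to_boolean::<i32>(&input, EvalMode::Legacy).unwrap();
        let bools = result.as_boolean();
        assert!(bools.value(0));
        assert!(!bools.value(1));
        assert!(bools.is_null(2));
    }
}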
1402
1403fn spark_cast_nonintegral_numeric_to_integral(
1404    array: &dyn Array,
1405    eval_mode: EvalMode,
1406    from_type: &DataType,
1407    to_type: &DataType,
1408) -> SparkResult<ArrayRef> {
1409    match (from_type, to_type) {
1410        (DataType::Float32, DataType::Int8) => cast_float_to_int16_down!(
1411            array,
1412            eval_mode,
1413            Float32Array,
1414            Int8Array,
1415            f32,
1416            i8,
1417            "FLOAT",
1418            "TINYINT",
1419            "{:e}"
1420        ),
1421        (DataType::Float32, DataType::Int16) => cast_float_to_int16_down!(
1422            array,
1423            eval_mode,
1424            Float32Array,
1425            Int16Array,
1426            f32,
1427            i16,
1428            "FLOAT",
1429            "SMALLINT",
1430            "{:e}"
1431        ),
1432        (DataType::Float32, DataType::Int32) => cast_float_to_int32_up!(
1433            array,
1434            eval_mode,
1435            Float32Array,
1436            Int32Array,
1437            f32,
1438            i32,
1439            "FLOAT",
1440            "INT",
1441            i32::MAX,
1442            "{:e}"
1443        ),
1444        (DataType::Float32, DataType::Int64) => cast_float_to_int32_up!(
1445            array,
1446            eval_mode,
1447            Float32Array,
1448            Int64Array,
1449            f32,
1450            i64,
1451            "FLOAT",
1452            "BIGINT",
1453            i64::MAX,
1454            "{:e}"
1455        ),
1456        (DataType::Float64, DataType::Int8) => cast_float_to_int16_down!(
1457            array,
1458            eval_mode,
1459            Float64Array,
1460            Int8Array,
1461            f64,
1462            i8,
1463            "DOUBLE",
1464            "TINYINT",
1465            "{:e}D"
1466        ),
1467        (DataType::Float64, DataType::Int16) => cast_float_to_int16_down!(
1468            array,
1469            eval_mode,
1470            Float64Array,
1471            Int16Array,
1472            f64,
1473            i16,
1474            "DOUBLE",
1475            "SMALLINT",
1476            "{:e}D"
1477        ),
1478        (DataType::Float64, DataType::Int32) => cast_float_to_int32_up!(
1479            array,
1480            eval_mode,
1481            Float64Array,
1482            Int32Array,
1483            f64,
1484            i32,
1485            "DOUBLE",
1486            "INT",
1487            i32::MAX,
1488            "{:e}D"
1489        ),
1490        (DataType::Float64, DataType::Int64) => cast_float_to_int32_up!(
1491            array,
1492            eval_mode,
1493            Float64Array,
1494            Int64Array,
1495            f64,
1496            i64,
1497            "DOUBLE",
1498            "BIGINT",
1499            i64::MAX,
1500            "{:e}D"
1501        ),
1502        (DataType::Decimal128(precision, scale), DataType::Int8) => {
1503            cast_decimal_to_int16_down!(
1504                array, eval_mode, Int8Array, i8, "TINYINT", precision, *scale
1505            )
1506        }
1507        (DataType::Decimal128(precision, scale), DataType::Int16) => {
1508            cast_decimal_to_int16_down!(
1509                array, eval_mode, Int16Array, i16, "SMALLINT", precision, *scale
1510            )
1511        }
1512        (DataType::Decimal128(precision, scale), DataType::Int32) => {
1513            cast_decimal_to_int32_up!(
1514                array,
1515                eval_mode,
1516                Int32Array,
1517                i32,
1518                "INT",
1519                i32::MAX,
1520                *precision,
1521                *scale
1522            )
1523        }
1524        (DataType::Decimal128(precision, scale), DataType::Int64) => {
1525            cast_decimal_to_int32_up!(
1526                array,
1527                eval_mode,
1528                Int64Array,
1529                i64,
1530                "BIGINT",
1531                i64::MAX,
1532                *precision,
1533                *scale
1534            )
1535        }
1536        _ => unreachable!(
1537            "invalid cast from non-integral numeric type: {from_type} to integral numeric type: {to_type}"
1538        ),
1540    }
1541}
1542
1543/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toByte
1544fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i8>> {
1545    Ok(cast_string_to_int_with_range_check(
1546        str,
1547        eval_mode,
1548        "TINYINT",
1549        i8::MIN as i32,
1550        i8::MAX as i32,
1551    )?
1552    .map(|v| v as i8))
1553}
1554
1555/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toShort
1556fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i16>> {
1557    Ok(cast_string_to_int_with_range_check(
1558        str,
1559        eval_mode,
1560        "SMALLINT",
1561        i16::MIN as i32,
1562        i16::MAX as i32,
1563    )?
1564    .map(|v| v as i16))
1565}
1566
1567/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper)
1568fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>> {
1569    do_cast_string_to_int::<i32>(str, eval_mode, "INT", i32::MIN)
1570}
1571
1572/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper longWrapper)
1573fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i64>> {
1574    do_cast_string_to_int::<i64>(str, eval_mode, "BIGINT", i64::MIN)
1575}
1576
1577fn cast_string_to_int_with_range_check(
1578    str: &str,
1579    eval_mode: EvalMode,
1580    type_name: &str,
1581    min: i32,
1582    max: i32,
1583) -> SparkResult<Option<i32>> {
1584    match do_cast_string_to_int(str, eval_mode, type_name, i32::MIN)? {
1585        None => Ok(None),
1586        Some(v) if v >= min && v <= max => Ok(Some(v)),
1587        _ if eval_mode == EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)),
1588        _ => Ok(None),
1589    }
1590}
1591
1592/// Equivalent to
1593/// - org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper, boolean allowDecimal)
1594/// - org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper longWrapper, boolean allowDecimal)
1595fn do_cast_string_to_int<
1596    T: Num + PartialOrd + Integer + CheckedSub + CheckedNeg + From<i32> + Copy,
1597>(
1598    str: &str,
1599    eval_mode: EvalMode,
1600    type_name: &str,
1601    min_value: T,
1602) -> SparkResult<Option<T>> {
1603    let trimmed_str = str.trim();
1604    if trimmed_str.is_empty() {
1605        return none_or_err(eval_mode, type_name, str);
1606    }
1607    let len = trimmed_str.len();
1608    let mut result: T = T::zero();
1609    let mut negative = false;
1610    let radix = T::from(10);
1611    let stop_value = min_value / radix;
1612    let mut parse_sign_and_digits = true;
1613
1614    for (i, ch) in trimmed_str.char_indices() {
1615        if parse_sign_and_digits {
1616            if i == 0 {
1617                negative = ch == '-';
1618                let positive = ch == '+';
1619                if negative || positive {
1620                    if i + 1 == len {
1621                        // input string is just "+" or "-"
1622                        return none_or_err(eval_mode, type_name, str);
1623                    }
1624                    // consume this char
1625                    continue;
1626                }
1627            }
1628
1629            if ch == '.' {
1630                if eval_mode == EvalMode::Legacy {
1631                    // truncate decimal in legacy mode
1632                    parse_sign_and_digits = false;
1633                    continue;
1634                } else {
1635                    return none_or_err(eval_mode, type_name, str);
1636                }
1637            }
1638
1639            let digit = if ch.is_ascii_digit() {
1640                (ch as u32) - ('0' as u32)
1641            } else {
1642                return none_or_err(eval_mode, type_name, str);
1643            };
1644
1645            // We are going to process the new digit and accumulate the result. However, before
1646            // doing this, if the result is already smaller than the
1647            // stopValue(Integer.MIN_VALUE / radix), then result * 10 will definitely be
1648            // smaller than minValue, and we can stop
1649            if result < stop_value {
1650                return none_or_err(eval_mode, type_name, str);
1651            }
1652
1653            // Since the previous result is greater than or equal to stopValue
1654            // (min_value / radix), `result * radix` cannot wrap; overflow can only
1655            // occur in the subtraction, so stop if checked_sub fails
1656            let v = result * radix;
1657            let digit = (digit as i32).into();
1658            match v.checked_sub(&digit) {
1659                Some(x) if x <= T::zero() => result = x,
1660                _ => {
1661                    return none_or_err(eval_mode, type_name, str);
1662                }
1663            }
1664        } else {
1665            // make sure fractional digits are valid digits but ignore them
1666            if !ch.is_ascii_digit() {
1667                return none_or_err(eval_mode, type_name, str);
1668            }
1669        }
1670    }
1671
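    // The digits were accumulated as a negative number so that min_value itself
    // stays representable; for non-negative input, negate back here. checked_neg
    // fails exactly when the accumulated value is min_value, whose magnitude
    // does not fit in T.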
1672    if !negative {
1673        if let Some(neg) = result.checked_neg() {
1674            if neg < T::zero() {
1675                return none_or_err(eval_mode, type_name, str);
1676            }
1677            result = neg;
1678        } else {
1679            return none_or_err(eval_mode, type_name, str);
1680        }
1681    }
1682
1683    Ok(Some(result))
1684}
1685
1686/// Either return Ok(None) or Err(SparkError::CastInvalidValue) depending on the evaluation mode
1687#[inline]
1688fn none_or_err<T>(eval_mode: EvalMode, type_name: &str, str: &str) -> SparkResult<Option<T>> {
1689    match eval_mode {
1690        EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)),
1691        _ => Ok(None),
1692    }
1693}
1694
1695#[inline]
1696fn invalid_value(value: &str, from_type: &str, to_type: &str) -> SparkError {
1697    SparkError::CastInvalidValue {
1698        value: value.to_string(),
1699        from_type: from_type.to_string(),
1700        to_type: to_type.to_string(),
1701    }
1702}
1703
1704#[inline]
1705fn cast_overflow(value: &str, from_type: &str, to_type: &str) -> SparkError {
1706    SparkError::CastOverFlow {
1707        value: value.to_string(),
1708        from_type: from_type.to_string(),
1709        to_type: to_type.to_string(),
1710    }
1711}
1712
1713impl Display for Cast {
1714    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1715        write!(
1716            f,
1717            "Cast [data_type: {}, timezone: {}, child: {}, eval_mode: {:?}]",
1718            self.data_type, self.cast_options.timezone, self.child, &self.cast_options.eval_mode
1719        )
1720    }
1721}
1722
1723impl PhysicalExpr for Cast {
1724    fn as_any(&self) -> &dyn Any {
1725        self
1726    }
1727
1728    fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
1729        unimplemented!()
1730    }
1731
1732    fn data_type(&self, _: &Schema) -> DataFusionResult<DataType> {
1733        Ok(self.data_type.clone())
1734    }
1735
1736    fn nullable(&self, _: &Schema) -> DataFusionResult<bool> {
1737        Ok(true)
1738    }
1739
1740    fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
1741        let arg = self.child.evaluate(batch)?;
1742        spark_cast(arg, &self.data_type, &self.cast_options)
1743    }
1744
1745    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
1746        vec![&self.child]
1747    }
1748
1749    fn with_new_children(
1750        self: Arc<Self>,
1751        children: Vec<Arc<dyn PhysicalExpr>>,
1752    ) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
1753        match children.len() {
1754            1 => Ok(Arc::new(Cast::new(
1755                Arc::clone(&children[0]),
1756                self.data_type.clone(),
1757                self.cast_options.clone(),
1758            ))),
1759            _ => internal_err!("Cast should have exactly one child"),
1760        }
1761    }
1762}
1763
1764fn timestamp_parser<T: TimeZone>(
1765    value: &str,
1766    eval_mode: EvalMode,
1767    tz: &T,
1768) -> SparkResult<Option<i64>> {
1769    let value = value.trim();
1770    if value.is_empty() {
1771        return Ok(None);
1772    }
1773    // Define regex patterns and corresponding parsing functions
1774    let patterns = &[
1775        (
1776            Regex::new(r"^\d{4,5}$").unwrap(),
1777            parse_str_to_year_timestamp as fn(&str, &T) -> SparkResult<Option<i64>>,
1778        ),
1779        (
1780            Regex::new(r"^\d{4,5}-\d{2}$").unwrap(),
1781            parse_str_to_month_timestamp,
1782        ),
1783        (
1784            Regex::new(r"^\d{4,5}-\d{2}-\d{2}$").unwrap(),
1785            parse_str_to_day_timestamp,
1786        ),
1787        (
1788            Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{1,2}$").unwrap(),
1789            parse_str_to_hour_timestamp,
1790        ),
1791        (
1792            Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}$").unwrap(),
1793            parse_str_to_minute_timestamp,
1794        ),
1795        (
1796            Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$").unwrap(),
1797            parse_str_to_second_timestamp,
1798        ),
1799        (
1800            Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,6}$").unwrap(),
1801            parse_str_to_microsecond_timestamp,
1802        ),
1803        (
1804            Regex::new(r"^T\d{1,2}$").unwrap(),
1805            parse_str_to_time_only_timestamp,
1806        ),
1807    ];
1808
1809    let mut timestamp = None;
1810
1811    // Iterate through patterns and try matching
1812    for (pattern, parse_func) in patterns {
1813        if pattern.is_match(value) {
1814            timestamp = parse_func(value, tz)?;
1815            break;
1816        }
1817    }
1818
1819    // In ANSI mode an unparseable timestamp is an error; otherwise it becomes null
1820    match timestamp {
1821        Some(ts) => Ok(Some(ts)),
1822        None if eval_mode == EvalMode::Ansi => Err(SparkError::CastInvalidValue {
1823            value: value.to_string(),
1824            from_type: "STRING".to_string(),
1825            to_type: "TIMESTAMP".to_string(),
1826        }),
1827        None => Ok(None),
1828    }
1829}
1838
1839fn parse_timestamp_to_micros<T: TimeZone>(
1840    timestamp_info: &TimeStampInfo,
1841    tz: &T,
1842) -> SparkResult<Option<i64>> {
1843    let datetime = tz.with_ymd_and_hms(
1844        timestamp_info.year,
1845        timestamp_info.month,
1846        timestamp_info.day,
1847        timestamp_info.hour,
1848        timestamp_info.minute,
1849        timestamp_info.second,
1850    );
1851
1852    // `with_ymd_and_hms` returns a `LocalResult`; accept only a unique local time
1853    let tz_datetime = match datetime.single() {
1854        Some(dt) => dt
1855            .with_timezone(tz)
1856            .with_nanosecond(timestamp_info.microsecond * 1000),
1857        None => {
1858            return Err(SparkError::Internal(
1859                "Failed to parse timestamp".to_string(),
1860            ));
1861        }
1862    };
1863
1864    let result = match tz_datetime {
1865        Some(dt) => dt.timestamp_micros(),
1866        None => {
1867            return Err(SparkError::Internal(
1868                "Failed to parse timestamp".to_string(),
1869            ));
1870        }
1871    };
1872
1873    Ok(Some(result))
1874}
1875
1876fn get_timestamp_values<T: TimeZone>(
1877    value: &str,
1878    timestamp_type: &str,
1879    tz: &T,
1880) -> SparkResult<Option<i64>> {
1881    let values: Vec<_> = value.split(['T', '-', ':', '.']).collect();
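    // The caller's regex match guarantees that every segment present here is all
    // digits, so the unwrap_or defaults below are effectively unreachable.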
1882    let year = values[0].parse::<i32>().unwrap_or_default();
1883    let month = values.get(1).map_or(1, |m| m.parse::<u32>().unwrap_or(1));
1884    let day = values.get(2).map_or(1, |d| d.parse::<u32>().unwrap_or(1));
1885    let hour = values.get(3).map_or(0, |h| h.parse::<u32>().unwrap_or(0));
1886    let minute = values.get(4).map_or(0, |m| m.parse::<u32>().unwrap_or(0));
1887    let second = values.get(5).map_or(0, |s| s.parse::<u32>().unwrap_or(0));
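    // Note: the fractional part is interpreted as a raw microsecond count, which
    // is only exact when all six fractional digits are present (".123456");
    // shorter fractions such as ".5" are not scaled up to 500000 here.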
1888    let microsecond = values.get(6).map_or(0, |ms| ms.parse::<u32>().unwrap_or(0));
1889
1890    let mut timestamp_info = TimeStampInfo::default();
1891
1892    let timestamp_info = match timestamp_type {
1893        "year" => timestamp_info.with_year(year),
1894        "month" => timestamp_info.with_year(year).with_month(month),
1895        "day" => timestamp_info
1896            .with_year(year)
1897            .with_month(month)
1898            .with_day(day),
1899        "hour" => timestamp_info
1900            .with_year(year)
1901            .with_month(month)
1902            .with_day(day)
1903            .with_hour(hour),
1904        "minute" => timestamp_info
1905            .with_year(year)
1906            .with_month(month)
1907            .with_day(day)
1908            .with_hour(hour)
1909            .with_minute(minute),
1910        "second" => timestamp_info
1911            .with_year(year)
1912            .with_month(month)
1913            .with_day(day)
1914            .with_hour(hour)
1915            .with_minute(minute)
1916            .with_second(second),
1917        "microsecond" => timestamp_info
1918            .with_year(year)
1919            .with_month(month)
1920            .with_day(day)
1921            .with_hour(hour)
1922            .with_minute(minute)
1923            .with_second(second)
1924            .with_microsecond(microsecond),
1925        _ => {
1926            return Err(SparkError::CastInvalidValue {
1927                value: value.to_string(),
1928                from_type: "STRING".to_string(),
1929                to_type: "TIMESTAMP".to_string(),
1930            })
1931        }
1932    };
1933
1934    parse_timestamp_to_micros(timestamp_info, tz)
1935}
1936
1937fn parse_str_to_year_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
1938    get_timestamp_values(value, "year", tz)
1939}
1940
1941fn parse_str_to_month_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
1942    get_timestamp_values(value, "month", tz)
1943}
1944
1945fn parse_str_to_day_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
1946    get_timestamp_values(value, "day", tz)
1947}
1948
1949fn parse_str_to_hour_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
1950    get_timestamp_values(value, "hour", tz)
1951}
1952
1953fn parse_str_to_minute_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
1954    get_timestamp_values(value, "minute", tz)
1955}
1956
1957fn parse_str_to_second_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
1958    get_timestamp_values(value, "second", tz)
1959}
1960
1961fn parse_str_to_microsecond_timestamp<T: TimeZone>(
1962    value: &str,
1963    tz: &T,
1964) -> SparkResult<Option<i64>> {
1965    get_timestamp_values(value, "microsecond", tz)
1966}
1967
1968fn parse_str_to_time_only_timestamp<T: TimeZone>(value: &str, tz: &T) -> SparkResult<Option<i64>> {
1969    let values: Vec<&str> = value.split('T').collect();
1970    let time_values: Vec<u32> = values[1]
1971        .split(':')
1972        .map(|v| v.parse::<u32>().unwrap_or(0))
1973        .collect();
1974
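    // The `^T\d{1,2}$` pattern only ever supplies an hour, so the minute, second,
    // and fraction lookups below fall back to their defaults of 0.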
1975    let datetime = tz.from_utc_datetime(&chrono::Utc::now().naive_utc());
1976    let timestamp = datetime
1977        .with_timezone(tz)
1978        .with_hour(time_values.first().copied().unwrap_or_default())
1979        .and_then(|dt| dt.with_minute(*time_values.get(1).unwrap_or(&0)))
1980        .and_then(|dt| dt.with_second(*time_values.get(2).unwrap_or(&0)))
1981        .and_then(|dt| dt.with_nanosecond(*time_values.get(3).unwrap_or(&0) * 1_000))
1982        .map(|dt| dt.timestamp_micros())
1983        .unwrap_or_default();
1984
1985    Ok(Some(timestamp))
1986}
1987
1988/// A string-to-date parser - a port of Spark's SparkDateTimeUtils#stringToDate.
1989fn date_parser(date_str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>> {
1990    // local functions
1991    fn get_trimmed_start(bytes: &[u8]) -> usize {
1992        let mut start = 0;
1993        while start < bytes.len() && is_whitespace_or_iso_control(bytes[start]) {
1994            start += 1;
1995        }
1996        start
1997    }
1998
1999    fn get_trimmed_end(start: usize, bytes: &[u8]) -> usize {
2000        let mut end = bytes.len() - 1;
2001        while end > start && is_whitespace_or_iso_control(bytes[end]) {
2002            end -= 1;
2003        }
2004        end + 1
2005    }
2006
2007    fn is_whitespace_or_iso_control(byte: u8) -> bool {
2008        byte.is_ascii_whitespace() || byte.is_ascii_control()
2009    }
2010
2011    fn is_valid_digits(segment: i32, digits: usize) -> bool {
2012        // An integer is able to represent a date within [+-]5 million years.
2013        let max_digits_year = 7;
2014        //the year (segment 0) can have between 4 and 7 digits,
2015        //month and day (segments 1 and 2) can have 1 or 2 digits
2016        (segment == 0 && digits >= 4 && digits <= max_digits_year)
2017            || (segment != 0 && digits > 0 && digits <= 2)
2018    }
2019
2020    fn return_result(date_str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>> {
2021        if eval_mode == EvalMode::Ansi {
2022            Err(SparkError::CastInvalidValue {
2023                value: date_str.to_string(),
2024                from_type: "STRING".to_string(),
2025                to_type: "DATE".to_string(),
2026            })
2027        } else {
2028            Ok(None)
2029        }
2030    }
2031    // end local functions
2032
2033    if date_str.is_empty() {
2034        return return_result(date_str, eval_mode);
2035    }
2036
2037    //date segment values (year, month, day), each defaulting to 1
2038    let mut date_segments = [1, 1, 1];
2039    let mut sign = 1;
2040    let mut current_segment = 0;
2041    let mut current_segment_value = Wrapping(0);
2042    let mut current_segment_digits = 0;
2043    let bytes = date_str.as_bytes();
2044
2045    let mut j = get_trimmed_start(bytes);
2046    let str_end_trimmed = get_trimmed_end(j, bytes);
2047
2048    if j == str_end_trimmed {
2049        return return_result(date_str, eval_mode);
2050    }
2051
2052    //assign a sign to the date
2053    if bytes[j] == b'-' || bytes[j] == b'+' {
2054        sign = if bytes[j] == b'-' { -1 } else { 1 };
2055        j += 1;
2056    }
2057
2058    //loop to the end of the string until 3 segments have been processed,
2059    //or exit early on encountering a space ' ' or a 'T' separator
2060    while j < str_end_trimmed && (current_segment < 3 && !(bytes[j] == b' ' || bytes[j] == b'T')) {
2061        let b = bytes[j];
2062        if current_segment < 2 && b == b'-' {
2063            //check for validity of year and month segments if current byte is separator
2064            if !is_valid_digits(current_segment, current_segment_digits) {
2065                return return_result(date_str, eval_mode);
2066            }
2067            //if valid update corresponding segment with the current segment value.
2068            date_segments[current_segment as usize] = current_segment_value.0;
2069            current_segment_value = Wrapping(0);
2070            current_segment_digits = 0;
2071            current_segment += 1;
2072        } else if !b.is_ascii_digit() {
2073            return return_result(date_str, eval_mode);
2074        } else {
2075            //append the next digit to the current segment value
2076            let parsed_value = Wrapping((b - b'0') as i32);
2077            current_segment_value = current_segment_value * Wrapping(10) + parsed_value;
2078            current_segment_digits += 1;
2079        }
2080        j += 1;
2081    }
2082
2083    //check for validity of last segment
2084    if !is_valid_digits(current_segment, current_segment_digits) {
2085        return return_result(date_str, eval_mode);
2086    }
2087
2088    if current_segment < 2 && j < str_end_trimmed {
2089        // For the `yyyy` and `yyyy-[m]m` formats, entire input must be consumed.
2090        return return_result(date_str, eval_mode);
2091    }
2092
2093    date_segments[current_segment as usize] = current_segment_value.0;
2094
2095    match NaiveDate::from_ymd_opt(
2096        sign * date_segments[0],
2097        date_segments[1] as u32,
2098        date_segments[2] as u32,
2099    ) {
2100        Some(date) => {
2101            let duration_since_epoch = date
2102                .signed_duration_since(DateTime::UNIX_EPOCH.naive_utc().date())
2103                .num_days();
2104            Ok(Some(duration_since_epoch.to_i32().unwrap()))
2105        }
2106        None => Ok(None),
2107    }
2108}
2109
2110/// This takes care of special casting cases in Spark, e.g., Timestamp to Long.
2111/// This function runs as a post-process of the DataFusion cast(). By the time it arrives here,
2112/// Dictionary arrays have already been unpacked by the DataFusion cast(), since Spark cannot
2113/// specify Dictionary as to_type. The from_type is taken before the DataFusion cast() runs in
2114/// expressions/cast.rs, so it can still be a Dictionary.
2115fn spark_cast_postprocess(array: ArrayRef, from_type: &DataType, to_type: &DataType) -> ArrayRef {
2116    match (from_type, to_type) {
2117        (DataType::Timestamp(_, _), DataType::Int64) => {
2118            // See Spark's `Cast` expression
2119            unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, MICROS_PER_SECOND)).unwrap()
2120        }
2121        (DataType::Dictionary(_, value_type), DataType::Int64)
2122            if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) =>
2123        {
2124            // See Spark's `Cast` expression
2125            unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, MICROS_PER_SECOND)).unwrap()
2126        }
2127        (DataType::Timestamp(_, _), DataType::Utf8) => remove_trailing_zeroes(array),
2128        (DataType::Dictionary(_, value_type), DataType::Utf8)
2129            if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) =>
2130        {
2131            remove_trailing_zeroes(array)
2132        }
2133        _ => array,
2134    }
2135}
2136
2137/// A forked and modified version of Arrow's `unary_dyn`, which is being deprecated
2138fn unary_dyn<F, T>(array: &ArrayRef, op: F) -> Result<ArrayRef, ArrowError>
2139where
2140    T: ArrowPrimitiveType,
2141    F: Fn(T::Native) -> T::Native,
2142{
2143    if let Some(d) = array.as_any_dictionary_opt() {
2144        let new_values = unary_dyn::<F, T>(d.values(), op)?;
2145        return Ok(Arc::new(d.with_values(Arc::new(new_values))));
2146    }
2147
2148    match array.as_primitive_opt::<T>() {
2149        Some(a) if PrimitiveArray::<T>::is_compatible(a.data_type()) => {
2150            Ok(Arc::new(unary::<T, F, T>(
2151                array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
2152                op,
2153            )))
2154        }
2155        _ => Err(ArrowError::NotYetImplemented(format!(
2156            "Cannot perform unary operation of type {} on array of type {}",
2157            T::DATA_TYPE,
2158            array.data_type()
2159        ))),
2160    }
2161}
2162
2163/// Remove any trailing zeroes from the fractional seconds in the string,
2164/// to match Spark behavior.
2165/// Examples:
2166/// "1970-01-01 05:29:59.900" => "1970-01-01 05:29:59.9"
2167/// "1970-01-01 05:29:59.990" => "1970-01-01 05:29:59.99"
2168/// "1970-01-01 05:29:59.999" => "1970-01-01 05:29:59.999"
2169/// "1970-01-01 05:30:00"     => "1970-01-01 05:30:00"
2170/// "1970-01-01 05:30:00.001" => "1970-01-01 05:30:00.001"
2171fn remove_trailing_zeroes(array: ArrayRef) -> ArrayRef {
2172    let string_array = as_generic_string_array::<i32>(&array).unwrap();
2173    let result = string_array
2174        .iter()
2175        .map(|s| s.map(trim_end))
2176        .collect::<GenericStringArray<i32>>();
2177    Arc::new(result) as ArrayRef
2178}
2179
2180fn trim_end(s: &str) -> &str {
2181    if s.rfind('.').is_some() {
2182        s.trim_end_matches('0')
2183    } else {
2184        s
2185    }
2186}
2187
2188#[cfg(test)]
2189mod tests {
2190    use arrow::array::StringArray;
2191    use arrow::datatypes::TimestampMicrosecondType;
2192    use arrow::datatypes::{Field, Fields, TimeUnit};
2193    use core::f64;
2194    use std::str::FromStr;
2195
2196    use super::*;
2197
2198    #[test]
2199    #[cfg_attr(miri, ignore)] // test takes too long with miri
2200    fn timestamp_parser_test() {
2201        let tz = &timezone::Tz::from_str("UTC").unwrap();
2202        // exercise all supported input formats
2203        assert_eq!(
2204            timestamp_parser("2020", EvalMode::Legacy, tz).unwrap(),
2205            Some(1577836800000000) // this is in microseconds
2206        );
2207        assert_eq!(
2208            timestamp_parser("2020-01", EvalMode::Legacy, tz).unwrap(),
2209            Some(1577836800000000)
2210        );
2211        assert_eq!(
2212            timestamp_parser("2020-01-01", EvalMode::Legacy, tz).unwrap(),
2213            Some(1577836800000000)
2214        );
2215        assert_eq!(
2216            timestamp_parser("2020-01-01T12", EvalMode::Legacy, tz).unwrap(),
2217            Some(1577880000000000)
2218        );
2219        assert_eq!(
2220            timestamp_parser("2020-01-01T12:34", EvalMode::Legacy, tz).unwrap(),
2221            Some(1577882040000000)
2222        );
2223        assert_eq!(
2224            timestamp_parser("2020-01-01T12:34:56", EvalMode::Legacy, tz).unwrap(),
2225            Some(1577882096000000)
2226        );
2227        assert_eq!(
2228            timestamp_parser("2020-01-01T12:34:56.123456", EvalMode::Legacy, tz).unwrap(),
2229            Some(1577882096123456)
2230        );
2231        assert_eq!(
2232            timestamp_parser("0100", EvalMode::Legacy, tz).unwrap(),
2233            Some(-59011459200000000)
2234        );
2235        assert_eq!(
2236            timestamp_parser("0100-01", EvalMode::Legacy, tz).unwrap(),
2237            Some(-59011459200000000)
2238        );
2239        assert_eq!(
2240            timestamp_parser("0100-01-01", EvalMode::Legacy, tz).unwrap(),
2241            Some(-59011459200000000)
2242        );
2243        assert_eq!(
2244            timestamp_parser("0100-01-01T12", EvalMode::Legacy, tz).unwrap(),
2245            Some(-59011416000000000)
2246        );
2247        assert_eq!(
2248            timestamp_parser("0100-01-01T12:34", EvalMode::Legacy, tz).unwrap(),
2249            Some(-59011413960000000)
2250        );
2251        assert_eq!(
2252            timestamp_parser("0100-01-01T12:34:56", EvalMode::Legacy, tz).unwrap(),
2253            Some(-59011413904000000)
2254        );
2255        assert_eq!(
2256            timestamp_parser("0100-01-01T12:34:56.123456", EvalMode::Legacy, tz).unwrap(),
2257            Some(-59011413903876544)
2258        );
2259        assert_eq!(
2260            timestamp_parser("10000", EvalMode::Legacy, tz).unwrap(),
2261            Some(253402300800000000)
2262        );
2263        assert_eq!(
2264            timestamp_parser("10000-01", EvalMode::Legacy, tz).unwrap(),
2265            Some(253402300800000000)
2266        );
2267        assert_eq!(
2268            timestamp_parser("10000-01-01", EvalMode::Legacy, tz).unwrap(),
2269            Some(253402300800000000)
2270        );
2271        assert_eq!(
2272            timestamp_parser("10000-01-01T12", EvalMode::Legacy, tz).unwrap(),
2273            Some(253402344000000000)
2274        );
2275        assert_eq!(
2276            timestamp_parser("10000-01-01T12:34", EvalMode::Legacy, tz).unwrap(),
2277            Some(253402346040000000)
2278        );
2279        assert_eq!(
2280            timestamp_parser("10000-01-01T12:34:56", EvalMode::Legacy, tz).unwrap(),
2281            Some(253402346096000000)
2282        );
2283        assert_eq!(
2284            timestamp_parser("10000-01-01T12:34:56.123456", EvalMode::Legacy, tz).unwrap(),
2285            Some(253402346096123456)
2286        );
2287        // assert_eq!(
2288        //     timestamp_parser("T2", EvalMode::Legacy, tz).unwrap(),
2289        //     Some(1714356000000000) // this value would need to change every day
2290        // );
2291    }
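
    // A small sketch (not part of the upstream suite) of the trailing-zero
    // trimming that `spark_cast_postprocess` applies to TIMESTAMP -> STRING
    // output; the expected values mirror the examples on `remove_trailing_zeroes`.
    #[test]
    fn test_trim_end_fractional_seconds() {
        assert_eq!(trim_end("1970-01-01 05:29:59.900"), "1970-01-01 05:29:59.9");
        assert_eq!(trim_end("1970-01-01 05:29:59.990"), "1970-01-01 05:29:59.99");
        // no fractional dot: left untouched even though it ends in '0'
        assert_eq!(trim_end("1970-01-01 05:30:00"), "1970-01-01 05:30:00");
    }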
2292
2293    #[test]
2294    #[cfg_attr(miri, ignore)] // test takes too long with miri
2295    fn test_cast_string_to_timestamp() {
2296        let array: ArrayRef = Arc::new(StringArray::from(vec![
2297            Some("2020-01-01T12:34:56.123456"),
2298            Some("T2"),
2299            Some("0100-01-01T12:34:56.123456"),
2300            Some("10000-01-01T12:34:56.123456"),
2301        ]));
2302        let tz = &timezone::Tz::from_str("UTC").unwrap();
2303
2304        let string_array = array
2305            .as_any()
2306            .downcast_ref::<GenericStringArray<i32>>()
2307            .expect("Expected a string array");
2308
2309        let eval_mode = EvalMode::Legacy;
2310        let result = cast_utf8_to_timestamp!(
2311            &string_array,
2312            eval_mode,
2313            TimestampMicrosecondType,
2314            timestamp_parser,
2315            tz
2316        );
2317
2318        assert_eq!(
2319            result.data_type(),
2320            &DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into()))
2321        );
2322        assert_eq!(result.len(), 4);
2323    }
2324
2325    #[test]
2326    fn test_cast_dict_string_to_timestamp() -> DataFusionResult<()> {
2327        // prepare input data
2328        let keys = Int32Array::from(vec![0, 1]);
2329        let values: ArrayRef = Arc::new(StringArray::from(vec![
2330            Some("2020-01-01T12:34:56.123456"),
2331            Some("T2"),
2332        ]));
2333        let dict_array = Arc::new(DictionaryArray::new(keys, values));
2334
2335        let timezone = "UTC".to_string();
2336        // test casting string dictionary array to timestamp array
2337        let cast_options = SparkCastOptions::new(EvalMode::Legacy, &timezone, false);
2338        let result = cast_array(
2339            dict_array,
2340            &DataType::Timestamp(TimeUnit::Microsecond, Some(timezone.clone().into())),
2341            &cast_options,
2342        )?;
2343        assert_eq!(
2344            *result.data_type(),
2345            DataType::Timestamp(TimeUnit::Microsecond, Some(timezone.into()))
2346        );
2347        assert_eq!(result.len(), 2);
2348
2349        Ok(())
2350    }
2351
2352    #[test]
2353    fn date_parser_test() {
2354        for date in &[
2355            "2020",
2356            "2020-01",
2357            "2020-01-01",
2358            "02020-01-01",
2359            "002020-01-01",
2360            "0002020-01-01",
2361            "2020-1-1",
2362            "2020-01-01 ",
2363            "2020-01-01T",
2364        ] {
2365            for eval_mode in &[EvalMode::Legacy, EvalMode::Ansi, EvalMode::Try] {
2366                assert_eq!(date_parser(date, *eval_mode).unwrap(), Some(18262));
2367            }
2368        }
2369
2370        //dates in invalid formats
2371        for date in &[
2372            "abc",
2373            "",
2374            "not_a_date",
2375            "3/",
2376            "3/12",
2377            "3/12/2020",
2378            "3/12/2002 T",
2379            "202",
2380            "2020-010-01",
2381            "2020-10-010",
2382            "2020-10-010T",
2383            "--262143-12-31",
2384            "--262143-12-31 ",
2385        ] {
2386            for eval_mode in &[EvalMode::Legacy, EvalMode::Try] {
2387                assert_eq!(date_parser(date, *eval_mode).unwrap(), None);
2388            }
2389            assert!(date_parser(date, EvalMode::Ansi).is_err());
2390        }
2391
2392        for date in &["-3638-5"] {
2393            for eval_mode in &[EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] {
2394                assert_eq!(date_parser(date, *eval_mode).unwrap(), Some(-2048160));
2395            }
2396        }
2397
2398        //Naive Date only supports years 262142 AD to 262143 BC
2399        //returns None for dates out of range supported by Naive Date.
2400        for date in &[
2401            "-262144-1-1",
2402            "262143-01-1",
2403            "262143-1-1",
2404            "262143-01-1 ",
2405            "262143-01-01T ",
2406            "262143-1-01T 1234",
2407            "-0973250",
2408        ] {
2409            for eval_mode in &[EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] {
2410                assert_eq!(date_parser(date, *eval_mode).unwrap(), None);
2411            }
2412        }
2413    }
2414
2415    #[test]
2416    fn test_cast_string_to_date() {
2417        let array: ArrayRef = Arc::new(StringArray::from(vec![
2418            Some("2020"),
2419            Some("2020-01"),
2420            Some("2020-01-01"),
2421            Some("2020-01-01T"),
2422        ]));
2423
2424        let result = cast_string_to_date(&array, &DataType::Date32, EvalMode::Legacy).unwrap();
2425
2426        let date32_array = result
2427            .as_any()
2428            .downcast_ref::<arrow::array::Date32Array>()
2429            .unwrap();
2430        assert_eq!(date32_array.len(), 4);
2431        date32_array
2432            .iter()
2433            .for_each(|v| assert_eq!(v.unwrap(), 18262));
2434    }
2435
2436    #[test]
2437    fn test_cast_string_array_with_valid_dates() {
2438        let array_with_valid_dates: ArrayRef = Arc::new(StringArray::from(vec![
2439            Some("-262143-12-31"),
2440            Some("\n -262143-12-31 "),
2441            Some("-262143-12-31T \t\n"),
2442            Some("\n\t-262143-12-31T\r"),
2443            Some("-262143-12-31T 123123123"),
2444            Some("\r\n-262143-12-31T \r123123123"),
2445            Some("\n -262143-12-31T \n\t"),
2446        ]));
2447
2448        for eval_mode in &[EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] {
2449            let result =
2450                cast_string_to_date(&array_with_valid_dates, &DataType::Date32, *eval_mode)
2451                    .unwrap();
2452
2453            let date32_array = result
2454                .as_any()
2455                .downcast_ref::<arrow::array::Date32Array>()
2456                .unwrap();
2457            assert_eq!(result.len(), 7);
2458            date32_array
2459                .iter()
2460                .for_each(|v| assert_eq!(v.unwrap(), -96464928));
2461        }
2462    }
2463
2464    #[test]
2465    fn test_cast_string_array_with_invalid_dates() {
2466        let array_with_invalid_date: ArrayRef = Arc::new(StringArray::from(vec![
2467            Some("2020"),
2468            Some("2020-01"),
2469            Some("2020-01-01"),
2470            //4 invalid dates
2471            Some("2020-010-01T"),
2472            Some("202"),
2473            Some(" 202 "),
2474            Some("\n 2020-\r8 "),
2475            Some("2020-01-01T"),
2476            // Overflows i32
2477            Some("-4607172990231812908"),
2478        ]));
2479
2480        for eval_mode in &[EvalMode::Legacy, EvalMode::Try] {
2481            let result =
2482                cast_string_to_date(&array_with_invalid_date, &DataType::Date32, *eval_mode)
2483                    .unwrap();
2484
2485            let date32_array = result
2486                .as_any()
2487                .downcast_ref::<arrow::array::Date32Array>()
2488                .unwrap();
2489            assert_eq!(
2490                date32_array.iter().collect::<Vec<_>>(),
2491                vec![
2492                    Some(18262),
2493                    Some(18262),
2494                    Some(18262),
2495                    None,
2496                    None,
2497                    None,
2498                    None,
2499                    Some(18262),
2500                    None
2501                ]
2502            );
2503        }
2504
2505        let result =
2506            cast_string_to_date(&array_with_invalid_date, &DataType::Date32, EvalMode::Ansi);
2507        match result {
2508            Err(e) => assert!(
2509                e.to_string().contains(
2510                    "[CAST_INVALID_INPUT] The value '2020-010-01T' of the type \"STRING\" cannot be cast to \"DATE\" because it is malformed")
2511            ),
2512            _ => panic!("Expected error"),
2513        }
2514    }
2515
2516    #[test]
2517    fn test_cast_string_as_i8() {
2518        // basic
2519        assert_eq!(
2520            cast_string_to_i8("127", EvalMode::Legacy).unwrap(),
2521            Some(127_i8)
2522        );
2523        assert_eq!(cast_string_to_i8("128", EvalMode::Legacy).unwrap(), None);
2524        assert!(cast_string_to_i8("128", EvalMode::Ansi).is_err());
2525        // decimals
2526        assert_eq!(
2527            cast_string_to_i8("0.2", EvalMode::Legacy).unwrap(),
2528            Some(0_i8)
2529        );
2530        assert_eq!(
2531            cast_string_to_i8(".", EvalMode::Legacy).unwrap(),
2532            Some(0_i8)
2533        );
2534        // TRY should always return null for decimals
2535        assert_eq!(cast_string_to_i8("0.2", EvalMode::Try).unwrap(), None);
2536        assert_eq!(cast_string_to_i8(".", EvalMode::Try).unwrap(), None);
2537        // ANSI mode should throw error on decimal
2538        assert!(cast_string_to_i8("0.2", EvalMode::Ansi).is_err());
2539        assert!(cast_string_to_i8(".", EvalMode::Ansi).is_err());
2540    }
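
    // A minimal sketch (not from the upstream suite) exercising the BIGINT string
    // parser above: overflow past i64::MAX is null in Legacy mode and an error in
    // ANSI mode, and Legacy truncates a fractional part where TRY returns null.
    #[test]
    fn test_cast_string_as_i64_overflow_and_fraction() {
        assert_eq!(
            cast_string_to_i64(" 9223372036854775807 ", EvalMode::Legacy).unwrap(),
            Some(i64::MAX)
        );
        assert_eq!(
            cast_string_to_i64("9223372036854775808", EvalMode::Legacy).unwrap(),
            None
        );
        assert!(cast_string_to_i64("9223372036854775808", EvalMode::Ansi).is_err());
        assert_eq!(
            cast_string_to_i64("-12.9", EvalMode::Legacy).unwrap(),
            Some(-12)
        );
        assert_eq!(cast_string_to_i64("-12.9", EvalMode::Try).unwrap(), None);
    }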
2541
2542    #[test]
2543    fn test_cast_unsupported_timestamp_to_date() {
2544        // Since DataFusion uses chrono::DateTime internally, not all dates representable by TimestampMicrosecondType are supported
2545        let timestamps: PrimitiveArray<TimestampMicrosecondType> = vec![i64::MAX].into();
2546        let cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
2547        let result = cast_array(
2548            Arc::new(timestamps.with_timezone("Europe/Copenhagen")),
2549            &DataType::Date32,
2550            &cast_options,
2551        );
2552        assert!(result.is_err())
2553    }
2554
2555    #[test]
2556    fn test_cast_invalid_timezone() {
2557        let timestamps: PrimitiveArray<TimestampMicrosecondType> = vec![i64::MAX].into();
2558        let cast_options = SparkCastOptions::new(EvalMode::Legacy, "Not a valid timezone", false);
2559        let result = cast_array(
2560            Arc::new(timestamps.with_timezone("Europe/Copenhagen")),
2561            &DataType::Date32,
2562            &cast_options,
2563        );
2564        assert!(result.is_err())
2565    }
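
    // A minimal sketch (not from the upstream suite) of the TIMESTAMP -> LONG
    // post-processing above. By the time `spark_cast_postprocess` runs, the
    // DataFusion cast has already turned the timestamps into raw Int64
    // microseconds; the post-process floors them to seconds, and flooring
    // (rather than truncating toward zero) matters for pre-epoch values.
    #[test]
    fn test_timestamp_to_long_postprocess_floors() {
        let micros = Int64Array::from(vec![-1_i64, 0, 1_500_000]);
        let result = spark_cast_postprocess(
            Arc::new(micros),
            &DataType::Timestamp(TimeUnit::Microsecond, None),
            &DataType::Int64,
        );
        let result = result.as_primitive::<Int64Type>();
        assert_eq!(result.value(0), -1); // div_floor(-1, 1_000_000) == -1, not 0
        assert_eq!(result.value(1), 0);
        assert_eq!(result.value(2), 1);
    }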
2566
2567    #[test]
2568    fn test_cast_struct_to_utf8() {
2569        let a: ArrayRef = Arc::new(Int32Array::from(vec![
2570            Some(1),
2571            Some(2),
2572            None,
2573            Some(4),
2574            Some(5),
2575        ]));
2576        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
2577        let c: ArrayRef = Arc::new(StructArray::from(vec![
2578            (Arc::new(Field::new("a", DataType::Int32, true)), a),
2579            (Arc::new(Field::new("b", DataType::Utf8, true)), b),
2580        ]));
2581        let string_array = cast_array(
2582            c,
2583            &DataType::Utf8,
2584            &SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
2585        )
2586        .unwrap();
2587        let string_array = string_array.as_string::<i32>();
2588        assert_eq!(5, string_array.len());
2589        assert_eq!(r#"{1, a}"#, string_array.value(0));
2590        assert_eq!(r#"{2, b}"#, string_array.value(1));
2591        assert_eq!(r#"{null, c}"#, string_array.value(2));
2592        assert_eq!(r#"{4, d}"#, string_array.value(3));
2593        assert_eq!(r#"{5, e}"#, string_array.value(4));
2594    }
2595
2596    #[test]
2597    fn test_cast_struct_to_struct() {
2598        let a: ArrayRef = Arc::new(Int32Array::from(vec![
2599            Some(1),
2600            Some(2),
2601            None,
2602            Some(4),
2603            Some(5),
2604        ]));
2605        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
2606        let c: ArrayRef = Arc::new(StructArray::from(vec![
2607            (Arc::new(Field::new("a", DataType::Int32, true)), a),
2608            (Arc::new(Field::new("b", DataType::Utf8, true)), b),
2609        ]));
2610        // change type of "a" from Int32 to Utf8
2611        let fields = Fields::from(vec![
2612            Field::new("a", DataType::Utf8, true),
2613            Field::new("b", DataType::Utf8, true),
2614        ]);
2615        let cast_array = spark_cast(
2616            ColumnarValue::Array(c),
2617            &DataType::Struct(fields),
2618            &SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
2619        )
2620        .unwrap();
2621        if let ColumnarValue::Array(cast_array) = cast_array {
2622            assert_eq!(5, cast_array.len());
2623            let a = cast_array.as_struct().column(0).as_string::<i32>();
2624            assert_eq!("1", a.value(0));
2625        } else {
2626            unreachable!()
2627        }
2628    }
2629
2630    #[test]
2631    fn test_cast_struct_to_struct_drop_column() {
2632        let a: ArrayRef = Arc::new(Int32Array::from(vec![
2633            Some(1),
2634            Some(2),
2635            None,
2636            Some(4),
2637            Some(5),
2638        ]));
2639        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
2640        let c: ArrayRef = Arc::new(StructArray::from(vec![
2641            (Arc::new(Field::new("a", DataType::Int32, true)), a),
2642            (Arc::new(Field::new("b", DataType::Utf8, true)), b),
2643        ]));
2644        // change type of "a" from Int32 to Utf8 and drop "b"
2645        let fields = Fields::from(vec![Field::new("a", DataType::Utf8, true)]);
2646        let cast_array = spark_cast(
2647            ColumnarValue::Array(c),
2648            &DataType::Struct(fields),
2649            &SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
2650        )
2651        .unwrap();
2652        if let ColumnarValue::Array(cast_array) = cast_array {
2653            assert_eq!(5, cast_array.len());
2654            let struct_array = cast_array.as_struct();
2655            assert_eq!(1, struct_array.columns().len());
2656            let a = struct_array.column(0).as_string::<i32>();
2657            assert_eq!("1", a.value(0));
2658        } else {
2659            unreachable!()
2660        }
2661    }
2662
2663    #[test]
2664    // Currently the cast function depends on `f64::powi`, which has unspecified precision according to the doc:
2665    // https://doc.rust-lang.org/std/primitive.f64.html#unspecified-precision.
2666    // Miri deliberately applies random floating-point errors to these operations to expose bugs:
2667    // https://github.com/rust-lang/miri/issues/4395.
2668    // The random errors may interfere with test cases at the rounding edge, so we ignore this test under Miri for now.
2669    // Once https://github.com/apache/datafusion-comet/issues/1371 is fixed, this should no longer be an issue.
2670    #[cfg_attr(miri, ignore)]
2671    fn test_cast_float_to_decimal() {
2672        let a: ArrayRef = Arc::new(Float64Array::from(vec![
2673            Some(42.),
2674            Some(0.5153125),
2675            Some(-42.4242415),
2676            Some(42e-314),
2677            Some(0.),
2678            Some(-4242.424242),
2679            Some(f64::INFINITY),
2680            Some(f64::NEG_INFINITY),
2681            Some(f64::NAN),
2682            None,
2683        ]));
2684        let b =
2685            cast_floating_point_to_decimal128::<Float64Type>(&a, 8, 6, EvalMode::Legacy).unwrap();
2686        assert_eq!(b.len(), a.len());
2687        let casted = b.as_primitive::<Decimal128Type>();
2688        assert_eq!(casted.value(0), 42000000);
2689        // https://github.com/apache/datafusion-comet/issues/1371
2690        // assert_eq!(casted.value(1), 515313);
2691        assert_eq!(casted.value(2), -42424242);
2692        assert_eq!(casted.value(3), 0);
2693        assert_eq!(casted.value(4), 0);
2694        assert!(casted.is_null(5));
2695        assert!(casted.is_null(6));
2696        assert!(casted.is_null(7));
2697        assert!(casted.is_null(8));
2698        assert!(casted.is_null(9));
2699    }
2700}