datafusion_expr_common/
casts.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Utilities for casting scalar literals to different data types
19//!
20//! This module contains functions for casting ScalarValue literals
21//! to different data types, originally extracted from the optimizer's
22//! unwrap_cast module to be shared between logical and physical layers.
23
24use std::cmp::Ordering;
25
26use arrow::datatypes::{
27    DataType, MAX_DECIMAL32_FOR_EACH_PRECISION, MAX_DECIMAL64_FOR_EACH_PRECISION,
28    MAX_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL32_FOR_EACH_PRECISION,
29    MIN_DECIMAL64_FOR_EACH_PRECISION, MIN_DECIMAL128_FOR_EACH_PRECISION, TimeUnit,
30};
31use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS};
32use datafusion_common::ScalarValue;
33
34/// Convert a literal value from one data type to another
35pub fn try_cast_literal_to_type(
36    lit_value: &ScalarValue,
37    target_type: &DataType,
38) -> Option<ScalarValue> {
39    let lit_data_type = lit_value.data_type();
40    if !is_supported_type(&lit_data_type) || !is_supported_type(target_type) {
41        return None;
42    }
43    if lit_value.is_null() {
44        // null value can be cast to any type of null value
45        return ScalarValue::try_from(target_type).ok();
46    }
47    try_cast_numeric_literal(lit_value, target_type)
48        .or_else(|| try_cast_string_literal(lit_value, target_type))
49        .or_else(|| try_cast_dictionary(lit_value, target_type))
50        .or_else(|| try_cast_binary(lit_value, target_type))
51}
52
53/// Returns true if unwrap_cast_in_comparison supports this data type
54pub fn is_supported_type(data_type: &DataType) -> bool {
55    is_supported_numeric_type(data_type)
56        || is_supported_string_type(data_type)
57        || is_supported_dictionary_type(data_type)
58        || is_supported_binary_type(data_type)
59}
60
61/// Returns true if unwrap_cast_in_comparison support this numeric type
62fn is_supported_numeric_type(data_type: &DataType) -> bool {
63    matches!(
64        data_type,
65        DataType::UInt8
66            | DataType::UInt16
67            | DataType::UInt32
68            | DataType::UInt64
69            | DataType::Int8
70            | DataType::Int16
71            | DataType::Int32
72            | DataType::Int64
73            | DataType::Decimal32(_, _)
74            | DataType::Decimal64(_, _)
75            | DataType::Decimal128(_, _)
76            | DataType::Timestamp(_, _)
77    )
78}
79
80/// Returns true if unwrap_cast_in_comparison supports casting this value as a string
81fn is_supported_string_type(data_type: &DataType) -> bool {
82    matches!(
83        data_type,
84        DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View
85    )
86}
87
88/// Returns true if unwrap_cast_in_comparison supports casting this value as a dictionary
89fn is_supported_dictionary_type(data_type: &DataType) -> bool {
90    matches!(data_type,
91                    DataType::Dictionary(_, inner) if is_supported_type(inner))
92}
93
94fn is_supported_binary_type(data_type: &DataType) -> bool {
95    matches!(data_type, DataType::Binary | DataType::FixedSizeBinary(_))
96}
97
98/// Convert a numeric value from one numeric data type to another
99fn try_cast_numeric_literal(
100    lit_value: &ScalarValue,
101    target_type: &DataType,
102) -> Option<ScalarValue> {
103    let lit_data_type = lit_value.data_type();
104    if !is_supported_numeric_type(&lit_data_type)
105        || !is_supported_numeric_type(target_type)
106    {
107        return None;
108    }
109
110    let mul = match target_type {
111        DataType::UInt8
112        | DataType::UInt16
113        | DataType::UInt32
114        | DataType::UInt64
115        | DataType::Int8
116        | DataType::Int16
117        | DataType::Int32
118        | DataType::Int64 => 1_i128,
119        DataType::Timestamp(_, _) => 1_i128,
120        DataType::Decimal32(_, scale) => 10_i128.pow(*scale as u32),
121        DataType::Decimal64(_, scale) => 10_i128.pow(*scale as u32),
122        DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32),
123        _ => return None,
124    };
125    let (target_min, target_max) = match target_type {
126        DataType::UInt8 => (u8::MIN as i128, u8::MAX as i128),
127        DataType::UInt16 => (u16::MIN as i128, u16::MAX as i128),
128        DataType::UInt32 => (u32::MIN as i128, u32::MAX as i128),
129        DataType::UInt64 => (u64::MIN as i128, u64::MAX as i128),
130        DataType::Int8 => (i8::MIN as i128, i8::MAX as i128),
131        DataType::Int16 => (i16::MIN as i128, i16::MAX as i128),
132        DataType::Int32 => (i32::MIN as i128, i32::MAX as i128),
133        DataType::Int64 => (i64::MIN as i128, i64::MAX as i128),
134        DataType::Timestamp(_, _) => (i64::MIN as i128, i64::MAX as i128),
135        DataType::Decimal32(precision, _) => (
136            // Different precision for decimal32 can store different range of value.
137            // For example, the precision is 3, the max of value is `999` and the min
138            // value is `-999`
139            MIN_DECIMAL32_FOR_EACH_PRECISION[*precision as usize] as i128,
140            MAX_DECIMAL32_FOR_EACH_PRECISION[*precision as usize] as i128,
141        ),
142        DataType::Decimal64(precision, _) => (
143            // Different precision for decimal64 can store different range of value.
144            // For example, the precision is 3, the max of value is `999` and the min
145            // value is `-999`
146            MIN_DECIMAL64_FOR_EACH_PRECISION[*precision as usize] as i128,
147            MAX_DECIMAL64_FOR_EACH_PRECISION[*precision as usize] as i128,
148        ),
149        DataType::Decimal128(precision, _) => (
150            // Different precision for decimal128 can store different range of value.
151            // For example, the precision is 3, the max of value is `999` and the min
152            // value is `-999`
153            MIN_DECIMAL128_FOR_EACH_PRECISION[*precision as usize],
154            MAX_DECIMAL128_FOR_EACH_PRECISION[*precision as usize],
155        ),
156        _ => return None,
157    };
158    let lit_value_target_type = match lit_value {
159        ScalarValue::Int8(Some(v)) => (*v as i128).checked_mul(mul),
160        ScalarValue::Int16(Some(v)) => (*v as i128).checked_mul(mul),
161        ScalarValue::Int32(Some(v)) => (*v as i128).checked_mul(mul),
162        ScalarValue::Int64(Some(v)) => (*v as i128).checked_mul(mul),
163        ScalarValue::UInt8(Some(v)) => (*v as i128).checked_mul(mul),
164        ScalarValue::UInt16(Some(v)) => (*v as i128).checked_mul(mul),
165        ScalarValue::UInt32(Some(v)) => (*v as i128).checked_mul(mul),
166        ScalarValue::UInt64(Some(v)) => (*v as i128).checked_mul(mul),
167        ScalarValue::TimestampSecond(Some(v), _) => (*v as i128).checked_mul(mul),
168        ScalarValue::TimestampMillisecond(Some(v), _) => (*v as i128).checked_mul(mul),
169        ScalarValue::TimestampMicrosecond(Some(v), _) => (*v as i128).checked_mul(mul),
170        ScalarValue::TimestampNanosecond(Some(v), _) => (*v as i128).checked_mul(mul),
171        ScalarValue::Decimal32(Some(v), _, scale) => {
172            let v = *v as i128;
173            let lit_scale_mul = 10_i128.pow(*scale as u32);
174            if mul >= lit_scale_mul {
175                // Example:
176                // lit is decimal(123,3,2)
177                // target type is decimal(5,3)
178                // the lit can be converted to the decimal(1230,5,3)
179                v.checked_mul(mul / lit_scale_mul)
180            } else if v % (lit_scale_mul / mul) == 0 {
181                // Example:
182                // lit is decimal(123000,10,3)
183                // target type is int32: the lit can be converted to INT32(123)
184                // target type is decimal(10,2): the lit can be converted to decimal(12300,10,2)
185                Some(v / (lit_scale_mul / mul))
186            } else {
187                // can't convert the lit decimal to the target data type
188                None
189            }
190        }
191        ScalarValue::Decimal64(Some(v), _, scale) => {
192            let v = *v as i128;
193            let lit_scale_mul = 10_i128.pow(*scale as u32);
194            if mul >= lit_scale_mul {
195                // Example:
196                // lit is decimal(123,3,2)
197                // target type is decimal(5,3)
198                // the lit can be converted to the decimal(1230,5,3)
199                v.checked_mul(mul / lit_scale_mul)
200            } else if v % (lit_scale_mul / mul) == 0 {
201                // Example:
202                // lit is decimal(123000,10,3)
203                // target type is int32: the lit can be converted to INT32(123)
204                // target type is decimal(10,2): the lit can be converted to decimal(12300,10,2)
205                Some(v / (lit_scale_mul / mul))
206            } else {
207                // can't convert the lit decimal to the target data type
208                None
209            }
210        }
211        ScalarValue::Decimal128(Some(v), _, scale) => {
212            let lit_scale_mul = 10_i128.pow(*scale as u32);
213            if mul >= lit_scale_mul {
214                // Example:
215                // lit is decimal(123,3,2)
216                // target type is decimal(5,3)
217                // the lit can be converted to the decimal(1230,5,3)
218                (*v).checked_mul(mul / lit_scale_mul)
219            } else if (*v) % (lit_scale_mul / mul) == 0 {
220                // Example:
221                // lit is decimal(123000,10,3)
222                // target type is int32: the lit can be converted to INT32(123)
223                // target type is decimal(10,2): the lit can be converted to decimal(12300,10,2)
224                Some(*v / (lit_scale_mul / mul))
225            } else {
226                // can't convert the lit decimal to the target data type
227                None
228            }
229        }
230        _ => None,
231    };
232
233    match lit_value_target_type {
234        None => None,
235        Some(value) => {
236            if value >= target_min && value <= target_max {
237                // the value casted from lit to the target type is in the range of target type.
238                // return the target type of scalar value
239                let result_scalar = match target_type {
240                    DataType::Int8 => ScalarValue::Int8(Some(value as i8)),
241                    DataType::Int16 => ScalarValue::Int16(Some(value as i16)),
242                    DataType::Int32 => ScalarValue::Int32(Some(value as i32)),
243                    DataType::Int64 => ScalarValue::Int64(Some(value as i64)),
244                    DataType::UInt8 => ScalarValue::UInt8(Some(value as u8)),
245                    DataType::UInt16 => ScalarValue::UInt16(Some(value as u16)),
246                    DataType::UInt32 => ScalarValue::UInt32(Some(value as u32)),
247                    DataType::UInt64 => ScalarValue::UInt64(Some(value as u64)),
248                    DataType::Timestamp(TimeUnit::Second, tz) => {
249                        let value = cast_between_timestamp(
250                            &lit_data_type,
251                            &DataType::Timestamp(TimeUnit::Second, tz.clone()),
252                            value,
253                        );
254                        ScalarValue::TimestampSecond(value, tz.clone())
255                    }
256                    DataType::Timestamp(TimeUnit::Millisecond, tz) => {
257                        let value = cast_between_timestamp(
258                            &lit_data_type,
259                            &DataType::Timestamp(TimeUnit::Millisecond, tz.clone()),
260                            value,
261                        );
262                        ScalarValue::TimestampMillisecond(value, tz.clone())
263                    }
264                    DataType::Timestamp(TimeUnit::Microsecond, tz) => {
265                        let value = cast_between_timestamp(
266                            &lit_data_type,
267                            &DataType::Timestamp(TimeUnit::Microsecond, tz.clone()),
268                            value,
269                        );
270                        ScalarValue::TimestampMicrosecond(value, tz.clone())
271                    }
272                    DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
273                        let value = cast_between_timestamp(
274                            &lit_data_type,
275                            &DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()),
276                            value,
277                        );
278                        ScalarValue::TimestampNanosecond(value, tz.clone())
279                    }
280                    DataType::Decimal32(p, s) => {
281                        ScalarValue::Decimal32(Some(value as i32), *p, *s)
282                    }
283                    DataType::Decimal64(p, s) => {
284                        ScalarValue::Decimal64(Some(value as i64), *p, *s)
285                    }
286                    DataType::Decimal128(p, s) => {
287                        ScalarValue::Decimal128(Some(value), *p, *s)
288                    }
289                    _ => {
290                        return None;
291                    }
292                };
293                Some(result_scalar)
294            } else {
295                None
296            }
297        }
298    }
299}
300
301fn try_cast_string_literal(
302    lit_value: &ScalarValue,
303    target_type: &DataType,
304) -> Option<ScalarValue> {
305    let string_value = lit_value.try_as_str()?.map(|s| s.to_string());
306    let scalar_value = match target_type {
307        DataType::Utf8 => ScalarValue::Utf8(string_value),
308        DataType::LargeUtf8 => ScalarValue::LargeUtf8(string_value),
309        DataType::Utf8View => ScalarValue::Utf8View(string_value),
310        _ => return None,
311    };
312    Some(scalar_value)
313}
314
315/// Attempt to cast to/from a dictionary type by wrapping/unwrapping the dictionary
316fn try_cast_dictionary(
317    lit_value: &ScalarValue,
318    target_type: &DataType,
319) -> Option<ScalarValue> {
320    let lit_value_type = lit_value.data_type();
321    let result_scalar = match (lit_value, target_type) {
322        // Unwrap dictionary when inner type matches target type
323        (ScalarValue::Dictionary(_, inner_value), _)
324            if inner_value.data_type() == *target_type =>
325        {
326            (**inner_value).clone()
327        }
328        // Wrap type when target type is dictionary
329        (_, DataType::Dictionary(index_type, inner_type))
330            if **inner_type == lit_value_type =>
331        {
332            ScalarValue::Dictionary(index_type.clone(), Box::new(lit_value.clone()))
333        }
334        _ => {
335            return None;
336        }
337    };
338    Some(result_scalar)
339}
340
341/// Cast a timestamp value from one unit to another
342fn cast_between_timestamp(from: &DataType, to: &DataType, value: i128) -> Option<i64> {
343    let value = value as i64;
344    let from_scale = match from {
345        DataType::Timestamp(TimeUnit::Second, _) => 1,
346        DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS,
347        DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS,
348        DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS,
349        _ => return Some(value),
350    };
351
352    let to_scale = match to {
353        DataType::Timestamp(TimeUnit::Second, _) => 1,
354        DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS,
355        DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS,
356        DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS,
357        _ => return Some(value),
358    };
359
360    match from_scale.cmp(&to_scale) {
361        Ordering::Less => value.checked_mul(to_scale / from_scale),
362        Ordering::Greater => Some(value / (from_scale / to_scale)),
363        Ordering::Equal => Some(value),
364    }
365}
366
367fn try_cast_binary(
368    lit_value: &ScalarValue,
369    target_type: &DataType,
370) -> Option<ScalarValue> {
371    match (lit_value, target_type) {
372        (ScalarValue::Binary(Some(v)), DataType::FixedSizeBinary(n))
373            if v.len() == *n as usize =>
374        {
375            Some(ScalarValue::FixedSizeBinary(*n, Some(v.clone())))
376        }
377        _ => None,
378    }
379}
380
381#[cfg(test)]
382mod tests {
383    use super::*;
384    use arrow::compute::{CastOptions, cast_with_options};
385    use arrow::datatypes::{Field, Fields, TimeUnit};
386    use std::sync::Arc;
387
388    #[derive(Debug, Clone)]
389    enum ExpectedCast {
390        /// test successfully cast value and it is as specified
391        Value(ScalarValue),
392        /// test returned OK, but could not cast the value
393        NoValue,
394    }
395
396    /// Runs try_cast_literal_to_type with the specified inputs and
397    /// ensure it computes the expected output, and ensures the
398    /// casting is consistent with the Arrow kernels
399    fn expect_cast(
400        literal: ScalarValue,
401        target_type: DataType,
402        expected_result: ExpectedCast,
403    ) {
404        let actual_value = try_cast_literal_to_type(&literal, &target_type);
405
406        println!("expect_cast: ");
407        println!("  {literal:?} --> {target_type}");
408        println!("  expected_result: {expected_result:?}");
409        println!("  actual_result:   {actual_value:?}");
410
411        match expected_result {
412            ExpectedCast::Value(expected_value) => {
413                let actual_value =
414                    actual_value.expect("Expected cast value but got None");
415
416                assert_eq!(actual_value, expected_value);
417
418                // Verify that calling the arrow
419                // cast kernel yields the same results
420                // input array
421                let literal_array = literal
422                    .to_array_of_size(1)
423                    .expect("Failed to convert to array of size");
424                let expected_array = expected_value
425                    .to_array_of_size(1)
426                    .expect("Failed to convert to array of size");
427                let cast_array = cast_with_options(
428                    &literal_array,
429                    &target_type,
430                    &CastOptions::default(),
431                )
432                .expect("Expected to be cast array with arrow cast kernel");
433
434                assert_eq!(
435                    &expected_array, &cast_array,
436                    "Result of casting {literal:?} with arrow was\n {cast_array:#?}\nbut expected\n{expected_array:#?}"
437                );
438
439                // Verify that for timestamp types the timezones are the same
440                // (ScalarValue::cmp doesn't account for timezones);
441                if let (
442                    DataType::Timestamp(left_unit, left_tz),
443                    DataType::Timestamp(right_unit, right_tz),
444                ) = (actual_value.data_type(), expected_value.data_type())
445                {
446                    assert_eq!(left_unit, right_unit);
447                    assert_eq!(left_tz, right_tz);
448                }
449            }
450            ExpectedCast::NoValue => {
451                assert!(
452                    actual_value.is_none(),
453                    "Expected no cast value, but got {actual_value:?}"
454                );
455            }
456        }
457    }
458
459    #[test]
460    fn test_try_cast_to_type_nulls() {
461        // test that nulls can be cast to/from all integer types
462        let scalars = vec![
463            ScalarValue::Int8(None),
464            ScalarValue::Int16(None),
465            ScalarValue::Int32(None),
466            ScalarValue::Int64(None),
467            ScalarValue::UInt8(None),
468            ScalarValue::UInt16(None),
469            ScalarValue::UInt32(None),
470            ScalarValue::UInt64(None),
471            ScalarValue::Decimal128(None, 3, 0),
472            ScalarValue::Decimal128(None, 8, 2),
473            ScalarValue::Utf8(None),
474            ScalarValue::LargeUtf8(None),
475        ];
476
477        for s1 in &scalars {
478            for s2 in &scalars {
479                let expected_value = ExpectedCast::Value(s2.clone());
480
481                expect_cast(s1.clone(), s2.data_type(), expected_value);
482            }
483        }
484    }
485
486    #[test]
487    fn test_try_cast_to_type_int_in_range() {
488        // test values that can be cast to/from all integer types
489        let scalars = vec![
490            ScalarValue::Int8(Some(123)),
491            ScalarValue::Int16(Some(123)),
492            ScalarValue::Int32(Some(123)),
493            ScalarValue::Int64(Some(123)),
494            ScalarValue::UInt8(Some(123)),
495            ScalarValue::UInt16(Some(123)),
496            ScalarValue::UInt32(Some(123)),
497            ScalarValue::UInt64(Some(123)),
498            ScalarValue::Decimal128(Some(123), 3, 0),
499            ScalarValue::Decimal128(Some(12300), 8, 2),
500        ];
501
502        for s1 in &scalars {
503            for s2 in &scalars {
504                let expected_value = ExpectedCast::Value(s2.clone());
505
506                expect_cast(s1.clone(), s2.data_type(), expected_value);
507            }
508        }
509
510        let max_i32 = ScalarValue::Int32(Some(i32::MAX));
511        expect_cast(
512            max_i32,
513            DataType::UInt64,
514            ExpectedCast::Value(ScalarValue::UInt64(Some(i32::MAX as u64))),
515        );
516
517        let min_i32 = ScalarValue::Int32(Some(i32::MIN));
518        expect_cast(
519            min_i32,
520            DataType::Int64,
521            ExpectedCast::Value(ScalarValue::Int64(Some(i32::MIN as i64))),
522        );
523
524        let max_i64 = ScalarValue::Int64(Some(i64::MAX));
525        expect_cast(
526            max_i64,
527            DataType::UInt64,
528            ExpectedCast::Value(ScalarValue::UInt64(Some(i64::MAX as u64))),
529        );
530    }
531
532    #[test]
533    fn test_try_cast_to_type_int_out_of_range() {
534        let min_i32 = ScalarValue::Int32(Some(i32::MIN));
535        let min_i64 = ScalarValue::Int64(Some(i64::MIN));
536        let max_i64 = ScalarValue::Int64(Some(i64::MAX));
537        let max_u64 = ScalarValue::UInt64(Some(u64::MAX));
538
539        expect_cast(max_i64.clone(), DataType::Int8, ExpectedCast::NoValue);
540
541        expect_cast(max_i64.clone(), DataType::Int16, ExpectedCast::NoValue);
542
543        expect_cast(max_i64, DataType::Int32, ExpectedCast::NoValue);
544
545        expect_cast(max_u64, DataType::Int64, ExpectedCast::NoValue);
546
547        expect_cast(min_i64, DataType::UInt64, ExpectedCast::NoValue);
548
549        expect_cast(min_i32, DataType::UInt64, ExpectedCast::NoValue);
550
551        // decimal out of range
552        expect_cast(
553            ScalarValue::Decimal128(Some(99999999999999999999999999999999999900), 38, 0),
554            DataType::Int64,
555            ExpectedCast::NoValue,
556        );
557
558        expect_cast(
559            ScalarValue::Decimal128(Some(-9999999999999999999999999999999999), 37, 1),
560            DataType::Int64,
561            ExpectedCast::NoValue,
562        );
563    }
564
565    #[test]
566    fn test_try_decimal_cast_in_range() {
567        expect_cast(
568            ScalarValue::Decimal128(Some(12300), 5, 2),
569            DataType::Decimal128(3, 0),
570            ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 3, 0)),
571        );
572
573        expect_cast(
574            ScalarValue::Decimal128(Some(12300), 5, 2),
575            DataType::Decimal128(8, 0),
576            ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 8, 0)),
577        );
578
579        expect_cast(
580            ScalarValue::Decimal128(Some(12300), 5, 2),
581            DataType::Decimal128(8, 5),
582            ExpectedCast::Value(ScalarValue::Decimal128(Some(12300000), 8, 5)),
583        );
584    }
585
586    #[test]
587    fn test_try_decimal_cast_out_of_range() {
588        // decimal would lose precision
589        expect_cast(
590            ScalarValue::Decimal128(Some(12345), 5, 2),
591            DataType::Decimal128(3, 0),
592            ExpectedCast::NoValue,
593        );
594
595        // decimal would lose precision
596        expect_cast(
597            ScalarValue::Decimal128(Some(12300), 5, 2),
598            DataType::Decimal128(2, 0),
599            ExpectedCast::NoValue,
600        );
601    }
602
603    #[test]
604    fn test_try_cast_to_type_timestamps() {
605        for time_unit in [
606            TimeUnit::Second,
607            TimeUnit::Millisecond,
608            TimeUnit::Microsecond,
609            TimeUnit::Nanosecond,
610        ] {
611            let utc = Some("+00:00".into());
612            // No timezone, utc timezone
613            let (lit_tz_none, lit_tz_utc) = match time_unit {
614                TimeUnit::Second => (
615                    ScalarValue::TimestampSecond(Some(12345), None),
616                    ScalarValue::TimestampSecond(Some(12345), utc),
617                ),
618
619                TimeUnit::Millisecond => (
620                    ScalarValue::TimestampMillisecond(Some(12345), None),
621                    ScalarValue::TimestampMillisecond(Some(12345), utc),
622                ),
623
624                TimeUnit::Microsecond => (
625                    ScalarValue::TimestampMicrosecond(Some(12345), None),
626                    ScalarValue::TimestampMicrosecond(Some(12345), utc),
627                ),
628
629                TimeUnit::Nanosecond => (
630                    ScalarValue::TimestampNanosecond(Some(12345), None),
631                    ScalarValue::TimestampNanosecond(Some(12345), utc),
632                ),
633            };
634
635            // DataFusion ignores timezones for comparisons of ScalarValue
636            // so double check it here
637            assert_eq!(lit_tz_none, lit_tz_utc);
638
639            // e.g. DataType::Timestamp(_, None)
640            let dt_tz_none = lit_tz_none.data_type();
641
642            // e.g. DataType::Timestamp(_, Some(utc))
643            let dt_tz_utc = lit_tz_utc.data_type();
644
645            // None <--> None
646            expect_cast(
647                lit_tz_none.clone(),
648                dt_tz_none.clone(),
649                ExpectedCast::Value(lit_tz_none.clone()),
650            );
651
652            // None <--> Utc
653            expect_cast(
654                lit_tz_none.clone(),
655                dt_tz_utc.clone(),
656                ExpectedCast::Value(lit_tz_utc.clone()),
657            );
658
659            // Utc <--> None
660            expect_cast(
661                lit_tz_utc.clone(),
662                dt_tz_none.clone(),
663                ExpectedCast::Value(lit_tz_none.clone()),
664            );
665
666            // Utc <--> Utc
667            expect_cast(
668                lit_tz_utc.clone(),
669                dt_tz_utc.clone(),
670                ExpectedCast::Value(lit_tz_utc.clone()),
671            );
672
673            // timestamp to int64
674            expect_cast(
675                lit_tz_utc.clone(),
676                DataType::Int64,
677                ExpectedCast::Value(ScalarValue::Int64(Some(12345))),
678            );
679
680            // int64 to timestamp
681            expect_cast(
682                ScalarValue::Int64(Some(12345)),
683                dt_tz_none.clone(),
684                ExpectedCast::Value(lit_tz_none.clone()),
685            );
686
687            // int64 to timestamp
688            expect_cast(
689                ScalarValue::Int64(Some(12345)),
690                dt_tz_utc.clone(),
691                ExpectedCast::Value(lit_tz_utc.clone()),
692            );
693
694            // timestamp to string (not supported yet)
695            expect_cast(
696                lit_tz_utc.clone(),
697                DataType::LargeUtf8,
698                ExpectedCast::NoValue,
699            );
700        }
701    }
702
703    #[test]
704    fn test_try_cast_to_type_unsupported() {
705        // int64 to list
706        expect_cast(
707            ScalarValue::Int64(Some(12345)),
708            DataType::List(Arc::new(Field::new("f", DataType::Int32, true))),
709            ExpectedCast::NoValue,
710        );
711    }
712
713    #[test]
714    fn test_try_cast_literal_to_timestamp() {
715        // same timestamp
716        let new_scalar = try_cast_literal_to_type(
717            &ScalarValue::TimestampNanosecond(Some(123456), None),
718            &DataType::Timestamp(TimeUnit::Nanosecond, None),
719        )
720        .unwrap();
721
722        assert_eq!(
723            new_scalar,
724            ScalarValue::TimestampNanosecond(Some(123456), None)
725        );
726
727        // TimestampNanosecond to TimestampMicrosecond
728        let new_scalar = try_cast_literal_to_type(
729            &ScalarValue::TimestampNanosecond(Some(123456), None),
730            &DataType::Timestamp(TimeUnit::Microsecond, None),
731        )
732        .unwrap();
733
734        assert_eq!(
735            new_scalar,
736            ScalarValue::TimestampMicrosecond(Some(123), None)
737        );
738
739        // TimestampNanosecond to TimestampMillisecond
740        let new_scalar = try_cast_literal_to_type(
741            &ScalarValue::TimestampNanosecond(Some(123456), None),
742            &DataType::Timestamp(TimeUnit::Millisecond, None),
743        )
744        .unwrap();
745
746        assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None));
747
748        // TimestampNanosecond to TimestampSecond
749        let new_scalar = try_cast_literal_to_type(
750            &ScalarValue::TimestampNanosecond(Some(123456), None),
751            &DataType::Timestamp(TimeUnit::Second, None),
752        )
753        .unwrap();
754
755        assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(0), None));
756
757        // TimestampMicrosecond to TimestampNanosecond
758        let new_scalar = try_cast_literal_to_type(
759            &ScalarValue::TimestampMicrosecond(Some(123), None),
760            &DataType::Timestamp(TimeUnit::Nanosecond, None),
761        )
762        .unwrap();
763
764        assert_eq!(
765            new_scalar,
766            ScalarValue::TimestampNanosecond(Some(123000), None)
767        );
768
769        // TimestampMicrosecond to TimestampMillisecond
770        let new_scalar = try_cast_literal_to_type(
771            &ScalarValue::TimestampMicrosecond(Some(123), None),
772            &DataType::Timestamp(TimeUnit::Millisecond, None),
773        )
774        .unwrap();
775
776        assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None));
777
778        // TimestampMicrosecond to TimestampSecond
779        let new_scalar = try_cast_literal_to_type(
780            &ScalarValue::TimestampMicrosecond(Some(123456789), None),
781            &DataType::Timestamp(TimeUnit::Second, None),
782        )
783        .unwrap();
784        assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123), None));
785
786        // TimestampMillisecond to TimestampNanosecond
787        let new_scalar = try_cast_literal_to_type(
788            &ScalarValue::TimestampMillisecond(Some(123), None),
789            &DataType::Timestamp(TimeUnit::Nanosecond, None),
790        )
791        .unwrap();
792        assert_eq!(
793            new_scalar,
794            ScalarValue::TimestampNanosecond(Some(123000000), None)
795        );
796
797        // TimestampMillisecond to TimestampMicrosecond
798        let new_scalar = try_cast_literal_to_type(
799            &ScalarValue::TimestampMillisecond(Some(123), None),
800            &DataType::Timestamp(TimeUnit::Microsecond, None),
801        )
802        .unwrap();
803        assert_eq!(
804            new_scalar,
805            ScalarValue::TimestampMicrosecond(Some(123000), None)
806        );
807        // TimestampMillisecond to TimestampSecond
808        let new_scalar = try_cast_literal_to_type(
809            &ScalarValue::TimestampMillisecond(Some(123456789), None),
810            &DataType::Timestamp(TimeUnit::Second, None),
811        )
812        .unwrap();
813        assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123456), None));
814
815        // TimestampSecond to TimestampNanosecond
816        let new_scalar = try_cast_literal_to_type(
817            &ScalarValue::TimestampSecond(Some(123), None),
818            &DataType::Timestamp(TimeUnit::Nanosecond, None),
819        )
820        .unwrap();
821        assert_eq!(
822            new_scalar,
823            ScalarValue::TimestampNanosecond(Some(123000000000), None)
824        );
825
826        // TimestampSecond to TimestampMicrosecond
827        let new_scalar = try_cast_literal_to_type(
828            &ScalarValue::TimestampSecond(Some(123), None),
829            &DataType::Timestamp(TimeUnit::Microsecond, None),
830        )
831        .unwrap();
832        assert_eq!(
833            new_scalar,
834            ScalarValue::TimestampMicrosecond(Some(123000000), None)
835        );
836
837        // TimestampSecond to TimestampMillisecond
838        let new_scalar = try_cast_literal_to_type(
839            &ScalarValue::TimestampSecond(Some(123), None),
840            &DataType::Timestamp(TimeUnit::Millisecond, None),
841        )
842        .unwrap();
843        assert_eq!(
844            new_scalar,
845            ScalarValue::TimestampMillisecond(Some(123000), None)
846        );
847
848        // overflow
849        let new_scalar = try_cast_literal_to_type(
850            &ScalarValue::TimestampSecond(Some(i64::MAX), None),
851            &DataType::Timestamp(TimeUnit::Millisecond, None),
852        )
853        .unwrap();
854        assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(None, None));
855    }
856
857    #[test]
858    fn test_try_cast_to_string_type() {
859        let scalars = vec![
860            ScalarValue::from("string"),
861            ScalarValue::LargeUtf8(Some("string".to_owned())),
862        ];
863
864        for s1 in &scalars {
865            for s2 in &scalars {
866                let expected_value = ExpectedCast::Value(s2.clone());
867
868                expect_cast(s1.clone(), s2.data_type(), expected_value);
869            }
870        }
871    }
872
873    #[test]
874    fn test_try_cast_to_dictionary_type() {
875        fn dictionary_type(t: DataType) -> DataType {
876            DataType::Dictionary(Box::new(DataType::Int32), Box::new(t))
877        }
878        fn dictionary_value(value: ScalarValue) -> ScalarValue {
879            ScalarValue::Dictionary(Box::new(DataType::Int32), Box::new(value))
880        }
881        let scalars = vec![
882            ScalarValue::from("string"),
883            ScalarValue::LargeUtf8(Some("string".to_owned())),
884        ];
885        for s in &scalars {
886            expect_cast(
887                s.clone(),
888                dictionary_type(s.data_type()),
889                ExpectedCast::Value(dictionary_value(s.clone())),
890            );
891            expect_cast(
892                dictionary_value(s.clone()),
893                s.data_type(),
894                ExpectedCast::Value(s.clone()),
895            )
896        }
897    }
898
899    #[test]
900    fn test_try_cast_to_fixed_size_binary() {
901        expect_cast(
902            ScalarValue::Binary(Some(vec![1, 2, 3])),
903            DataType::FixedSizeBinary(3),
904            ExpectedCast::Value(ScalarValue::FixedSizeBinary(3, Some(vec![1, 2, 3]))),
905        )
906    }
907
908    #[test]
909    fn test_numeric_boundary_values() {
910        // Test exact boundary values for signed integers
911        expect_cast(
912            ScalarValue::Int8(Some(i8::MAX)),
913            DataType::UInt8,
914            ExpectedCast::Value(ScalarValue::UInt8(Some(i8::MAX as u8))),
915        );
916
917        expect_cast(
918            ScalarValue::Int8(Some(i8::MIN)),
919            DataType::UInt8,
920            ExpectedCast::NoValue,
921        );
922
923        expect_cast(
924            ScalarValue::UInt8(Some(u8::MAX)),
925            DataType::Int8,
926            ExpectedCast::NoValue,
927        );
928
929        // Test cross-type boundary scenarios
930        expect_cast(
931            ScalarValue::Int32(Some(i32::MAX)),
932            DataType::Int64,
933            ExpectedCast::Value(ScalarValue::Int64(Some(i32::MAX as i64))),
934        );
935
936        expect_cast(
937            ScalarValue::Int64(Some(i64::MIN)),
938            DataType::UInt64,
939            ExpectedCast::NoValue,
940        );
941
942        // Test unsigned to signed edge cases
943        expect_cast(
944            ScalarValue::UInt32(Some(u32::MAX)),
945            DataType::Int32,
946            ExpectedCast::NoValue,
947        );
948
949        expect_cast(
950            ScalarValue::UInt64(Some(u64::MAX)),
951            DataType::Int64,
952            ExpectedCast::NoValue,
953        );
954    }
955
956    #[test]
957    fn test_decimal_precision_limits() {
958        use arrow::datatypes::{
959            MAX_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL128_FOR_EACH_PRECISION,
960        };
961
962        // Test maximum precision values
963        expect_cast(
964            ScalarValue::Decimal128(Some(MAX_DECIMAL128_FOR_EACH_PRECISION[3]), 3, 0),
965            DataType::Decimal128(5, 0),
966            ExpectedCast::Value(ScalarValue::Decimal128(
967                Some(MAX_DECIMAL128_FOR_EACH_PRECISION[3]),
968                5,
969                0,
970            )),
971        );
972
973        // Test minimum precision values
974        expect_cast(
975            ScalarValue::Decimal128(Some(MIN_DECIMAL128_FOR_EACH_PRECISION[3]), 3, 0),
976            DataType::Decimal128(5, 0),
977            ExpectedCast::Value(ScalarValue::Decimal128(
978                Some(MIN_DECIMAL128_FOR_EACH_PRECISION[3]),
979                5,
980                0,
981            )),
982        );
983
984        // Test scale increase
985        expect_cast(
986            ScalarValue::Decimal128(Some(123), 3, 0),
987            DataType::Decimal128(5, 2),
988            ExpectedCast::Value(ScalarValue::Decimal128(Some(12300), 5, 2)),
989        );
990
991        // Test precision overflow (value too large for target precision)
992        expect_cast(
993            ScalarValue::Decimal128(Some(MAX_DECIMAL128_FOR_EACH_PRECISION[10]), 10, 0),
994            DataType::Decimal128(3, 0),
995            ExpectedCast::NoValue,
996        );
997
998        // Test non-divisible decimal conversion (should fail)
999        expect_cast(
1000            ScalarValue::Decimal128(Some(12345), 5, 3), // 12.345
1001            DataType::Int32,
1002            ExpectedCast::NoValue, // Can't convert 12.345 to integer without loss
1003        );
1004
1005        // Test edge case: scale reduction with precision loss
1006        expect_cast(
1007            ScalarValue::Decimal128(Some(12345), 5, 2), // 123.45
1008            DataType::Decimal128(3, 0),                 // Can only hold up to 999
1009            ExpectedCast::NoValue,
1010        );
1011    }
1012
1013    #[test]
1014    fn test_timestamp_overflow_scenarios() {
1015        // Test overflow in timestamp conversions
1016        let max_seconds = i64::MAX / 1_000_000_000; // Avoid overflow when converting to nanos
1017
1018        // This should work - within safe range
1019        expect_cast(
1020            ScalarValue::TimestampSecond(Some(max_seconds), None),
1021            DataType::Timestamp(TimeUnit::Nanosecond, None),
1022            ExpectedCast::Value(ScalarValue::TimestampNanosecond(
1023                Some(max_seconds * 1_000_000_000),
1024                None,
1025            )),
1026        );
1027
1028        // Test very large nanosecond value conversion to smaller units
1029        expect_cast(
1030            ScalarValue::TimestampNanosecond(Some(i64::MAX), None),
1031            DataType::Timestamp(TimeUnit::Second, None),
1032            ExpectedCast::Value(ScalarValue::TimestampSecond(
1033                Some(i64::MAX / 1_000_000_000),
1034                None,
1035            )),
1036        );
1037
1038        // Test precision loss in downscaling
1039        expect_cast(
1040            ScalarValue::TimestampNanosecond(Some(1), None),
1041            DataType::Timestamp(TimeUnit::Second, None),
1042            ExpectedCast::Value(ScalarValue::TimestampSecond(Some(0), None)),
1043        );
1044
1045        expect_cast(
1046            ScalarValue::TimestampMicrosecond(Some(999), None),
1047            DataType::Timestamp(TimeUnit::Millisecond, None),
1048            ExpectedCast::Value(ScalarValue::TimestampMillisecond(Some(0), None)),
1049        );
1050    }
1051
1052    #[test]
1053    fn test_string_view() {
1054        // Test Utf8View to other string types
1055        expect_cast(
1056            ScalarValue::Utf8View(Some("test".to_string())),
1057            DataType::Utf8,
1058            ExpectedCast::Value(ScalarValue::Utf8(Some("test".to_string()))),
1059        );
1060
1061        expect_cast(
1062            ScalarValue::Utf8View(Some("test".to_string())),
1063            DataType::LargeUtf8,
1064            ExpectedCast::Value(ScalarValue::LargeUtf8(Some("test".to_string()))),
1065        );
1066
1067        // Test other string types to Utf8View
1068        expect_cast(
1069            ScalarValue::Utf8(Some("hello".to_string())),
1070            DataType::Utf8View,
1071            ExpectedCast::Value(ScalarValue::Utf8View(Some("hello".to_string()))),
1072        );
1073
1074        expect_cast(
1075            ScalarValue::LargeUtf8(Some("world".to_string())),
1076            DataType::Utf8View,
1077            ExpectedCast::Value(ScalarValue::Utf8View(Some("world".to_string()))),
1078        );
1079
1080        // Test empty string
1081        expect_cast(
1082            ScalarValue::Utf8(Some("".to_string())),
1083            DataType::Utf8View,
1084            ExpectedCast::Value(ScalarValue::Utf8View(Some("".to_string()))),
1085        );
1086
1087        // Test large string
1088        let large_string = "x".repeat(1000);
1089        expect_cast(
1090            ScalarValue::LargeUtf8(Some(large_string.clone())),
1091            DataType::Utf8View,
1092            ExpectedCast::Value(ScalarValue::Utf8View(Some(large_string))),
1093        );
1094    }
1095
1096    #[test]
1097    fn test_binary_size_edge_cases() {
1098        // Test size mismatch - too small
1099        expect_cast(
1100            ScalarValue::Binary(Some(vec![1, 2])),
1101            DataType::FixedSizeBinary(3),
1102            ExpectedCast::NoValue,
1103        );
1104
1105        // Test size mismatch - too large
1106        expect_cast(
1107            ScalarValue::Binary(Some(vec![1, 2, 3, 4])),
1108            DataType::FixedSizeBinary(3),
1109            ExpectedCast::NoValue,
1110        );
1111
1112        // Test empty binary
1113        expect_cast(
1114            ScalarValue::Binary(Some(vec![])),
1115            DataType::FixedSizeBinary(0),
1116            ExpectedCast::Value(ScalarValue::FixedSizeBinary(0, Some(vec![]))),
1117        );
1118
1119        // Test exact size match
1120        expect_cast(
1121            ScalarValue::Binary(Some(vec![1, 2, 3])),
1122            DataType::FixedSizeBinary(3),
1123            ExpectedCast::Value(ScalarValue::FixedSizeBinary(3, Some(vec![1, 2, 3]))),
1124        );
1125
1126        // Test single byte
1127        expect_cast(
1128            ScalarValue::Binary(Some(vec![42])),
1129            DataType::FixedSizeBinary(1),
1130            ExpectedCast::Value(ScalarValue::FixedSizeBinary(1, Some(vec![42]))),
1131        );
1132    }
1133
1134    #[test]
1135    fn test_dictionary_index_types() {
1136        // Test different dictionary index types
1137        let string_value = ScalarValue::Utf8(Some("test".to_string()));
1138
1139        // Int8 index dictionary
1140        let dict_int8 =
1141            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8));
1142        expect_cast(
1143            string_value.clone(),
1144            dict_int8,
1145            ExpectedCast::Value(ScalarValue::Dictionary(
1146                Box::new(DataType::Int8),
1147                Box::new(string_value.clone()),
1148            )),
1149        );
1150
1151        // Int16 index dictionary
1152        let dict_int16 =
1153            DataType::Dictionary(Box::new(DataType::Int16), Box::new(DataType::Utf8));
1154        expect_cast(
1155            string_value.clone(),
1156            dict_int16,
1157            ExpectedCast::Value(ScalarValue::Dictionary(
1158                Box::new(DataType::Int16),
1159                Box::new(string_value.clone()),
1160            )),
1161        );
1162
1163        // Int64 index dictionary
1164        let dict_int64 =
1165            DataType::Dictionary(Box::new(DataType::Int64), Box::new(DataType::Utf8));
1166        expect_cast(
1167            string_value.clone(),
1168            dict_int64,
1169            ExpectedCast::Value(ScalarValue::Dictionary(
1170                Box::new(DataType::Int64),
1171                Box::new(string_value.clone()),
1172            )),
1173        );
1174
1175        // Test dictionary unwrapping
1176        let dict_value = ScalarValue::Dictionary(
1177            Box::new(DataType::Int32),
1178            Box::new(ScalarValue::LargeUtf8(Some("unwrap_test".to_string()))),
1179        );
1180        expect_cast(
1181            dict_value,
1182            DataType::LargeUtf8,
1183            ExpectedCast::Value(ScalarValue::LargeUtf8(Some("unwrap_test".to_string()))),
1184        );
1185    }
1186
1187    #[test]
1188    fn test_type_support_functions() {
1189        // Test numeric type support
1190        assert!(is_supported_numeric_type(&DataType::Int8));
1191        assert!(is_supported_numeric_type(&DataType::UInt64));
1192        assert!(is_supported_numeric_type(&DataType::Decimal128(10, 2)));
1193        assert!(is_supported_numeric_type(&DataType::Timestamp(
1194            TimeUnit::Nanosecond,
1195            None
1196        )));
1197        assert!(!is_supported_numeric_type(&DataType::Float32));
1198        assert!(!is_supported_numeric_type(&DataType::Float64));
1199
1200        // Test string type support
1201        assert!(is_supported_string_type(&DataType::Utf8));
1202        assert!(is_supported_string_type(&DataType::LargeUtf8));
1203        assert!(is_supported_string_type(&DataType::Utf8View));
1204        assert!(!is_supported_string_type(&DataType::Binary));
1205
1206        // Test binary type support
1207        assert!(is_supported_binary_type(&DataType::Binary));
1208        assert!(is_supported_binary_type(&DataType::FixedSizeBinary(10)));
1209        assert!(!is_supported_binary_type(&DataType::Utf8));
1210
1211        // Test dictionary type support with nested types
1212        assert!(is_supported_dictionary_type(&DataType::Dictionary(
1213            Box::new(DataType::Int32),
1214            Box::new(DataType::Utf8)
1215        )));
1216        assert!(is_supported_dictionary_type(&DataType::Dictionary(
1217            Box::new(DataType::Int32),
1218            Box::new(DataType::Int64)
1219        )));
1220        assert!(!is_supported_dictionary_type(&DataType::Dictionary(
1221            Box::new(DataType::Int32),
1222            Box::new(DataType::List(Arc::new(Field::new(
1223                "item",
1224                DataType::Int32,
1225                true
1226            ))))
1227        )));
1228
1229        // Test overall type support
1230        assert!(is_supported_type(&DataType::Int32));
1231        assert!(is_supported_type(&DataType::Utf8));
1232        assert!(is_supported_type(&DataType::Binary));
1233        assert!(is_supported_type(&DataType::Dictionary(
1234            Box::new(DataType::Int32),
1235            Box::new(DataType::Utf8)
1236        )));
1237        assert!(!is_supported_type(&DataType::List(Arc::new(Field::new(
1238            "item",
1239            DataType::Int32,
1240            true
1241        )))));
1242        assert!(!is_supported_type(&DataType::Struct(Fields::empty())));
1243    }
1244
1245    #[test]
1246    fn test_error_conditions() {
1247        // Test unsupported source type
1248        expect_cast(
1249            ScalarValue::Float32(Some(1.5)),
1250            DataType::Int32,
1251            ExpectedCast::NoValue,
1252        );
1253
1254        // Test unsupported target type
1255        expect_cast(
1256            ScalarValue::Int32(Some(123)),
1257            DataType::Float64,
1258            ExpectedCast::NoValue,
1259        );
1260
1261        // Test both types unsupported
1262        expect_cast(
1263            ScalarValue::Float64(Some(1.5)),
1264            DataType::Float32,
1265            ExpectedCast::NoValue,
1266        );
1267
1268        // Test complex unsupported types
1269        let list_type =
1270            DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
1271        expect_cast(
1272            ScalarValue::Int32(Some(123)),
1273            list_type,
1274            ExpectedCast::NoValue,
1275        );
1276
1277        // Test dictionary with unsupported inner type
1278        let bad_dict = DataType::Dictionary(
1279            Box::new(DataType::Int32),
1280            Box::new(DataType::List(Arc::new(Field::new(
1281                "item",
1282                DataType::Int32,
1283                true,
1284            )))),
1285        );
1286        expect_cast(
1287            ScalarValue::Int32(Some(123)),
1288            bad_dict,
1289            ExpectedCast::NoValue,
1290        );
1291    }
1292}