datafusion_expr_common/
casts.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Utilities for casting scalar literals to different data types
19//!
20//! This module contains functions for casting ScalarValue literals
21//! to different data types, originally extracted from the optimizer's
22//! unwrap_cast module to be shared between logical and physical layers.
23
24use std::cmp::Ordering;
25
26use arrow::datatypes::{
27    DataType, TimeUnit, MAX_DECIMAL128_FOR_EACH_PRECISION,
28    MAX_DECIMAL32_FOR_EACH_PRECISION, MAX_DECIMAL64_FOR_EACH_PRECISION,
29    MIN_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL32_FOR_EACH_PRECISION,
30    MIN_DECIMAL64_FOR_EACH_PRECISION,
31};
32use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS};
33use datafusion_common::ScalarValue;
34
35/// Convert a literal value from one data type to another
36pub fn try_cast_literal_to_type(
37    lit_value: &ScalarValue,
38    target_type: &DataType,
39) -> Option<ScalarValue> {
40    let lit_data_type = lit_value.data_type();
41    if !is_supported_type(&lit_data_type) || !is_supported_type(target_type) {
42        return None;
43    }
44    if lit_value.is_null() {
45        // null value can be cast to any type of null value
46        return ScalarValue::try_from(target_type).ok();
47    }
48    try_cast_numeric_literal(lit_value, target_type)
49        .or_else(|| try_cast_string_literal(lit_value, target_type))
50        .or_else(|| try_cast_dictionary(lit_value, target_type))
51        .or_else(|| try_cast_binary(lit_value, target_type))
52}
53
54/// Returns true if unwrap_cast_in_comparison supports this data type
55pub fn is_supported_type(data_type: &DataType) -> bool {
56    is_supported_numeric_type(data_type)
57        || is_supported_string_type(data_type)
58        || is_supported_dictionary_type(data_type)
59        || is_supported_binary_type(data_type)
60}
61
62/// Returns true if unwrap_cast_in_comparison support this numeric type
63fn is_supported_numeric_type(data_type: &DataType) -> bool {
64    matches!(
65        data_type,
66        DataType::UInt8
67            | DataType::UInt16
68            | DataType::UInt32
69            | DataType::UInt64
70            | DataType::Int8
71            | DataType::Int16
72            | DataType::Int32
73            | DataType::Int64
74            | DataType::Decimal32(_, _)
75            | DataType::Decimal64(_, _)
76            | DataType::Decimal128(_, _)
77            | DataType::Timestamp(_, _)
78    )
79}
80
81/// Returns true if unwrap_cast_in_comparison supports casting this value as a string
82fn is_supported_string_type(data_type: &DataType) -> bool {
83    matches!(
84        data_type,
85        DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View
86    )
87}
88
89/// Returns true if unwrap_cast_in_comparison supports casting this value as a dictionary
90fn is_supported_dictionary_type(data_type: &DataType) -> bool {
91    matches!(data_type,
92                    DataType::Dictionary(_, inner) if is_supported_type(inner))
93}
94
95fn is_supported_binary_type(data_type: &DataType) -> bool {
96    matches!(data_type, DataType::Binary | DataType::FixedSizeBinary(_))
97}
98
99/// Convert a numeric value from one numeric data type to another
100fn try_cast_numeric_literal(
101    lit_value: &ScalarValue,
102    target_type: &DataType,
103) -> Option<ScalarValue> {
104    let lit_data_type = lit_value.data_type();
105    if !is_supported_numeric_type(&lit_data_type)
106        || !is_supported_numeric_type(target_type)
107    {
108        return None;
109    }
110
111    let mul = match target_type {
112        DataType::UInt8
113        | DataType::UInt16
114        | DataType::UInt32
115        | DataType::UInt64
116        | DataType::Int8
117        | DataType::Int16
118        | DataType::Int32
119        | DataType::Int64 => 1_i128,
120        DataType::Timestamp(_, _) => 1_i128,
121        DataType::Decimal32(_, scale) => 10_i128.pow(*scale as u32),
122        DataType::Decimal64(_, scale) => 10_i128.pow(*scale as u32),
123        DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32),
124        _ => return None,
125    };
126    let (target_min, target_max) = match target_type {
127        DataType::UInt8 => (u8::MIN as i128, u8::MAX as i128),
128        DataType::UInt16 => (u16::MIN as i128, u16::MAX as i128),
129        DataType::UInt32 => (u32::MIN as i128, u32::MAX as i128),
130        DataType::UInt64 => (u64::MIN as i128, u64::MAX as i128),
131        DataType::Int8 => (i8::MIN as i128, i8::MAX as i128),
132        DataType::Int16 => (i16::MIN as i128, i16::MAX as i128),
133        DataType::Int32 => (i32::MIN as i128, i32::MAX as i128),
134        DataType::Int64 => (i64::MIN as i128, i64::MAX as i128),
135        DataType::Timestamp(_, _) => (i64::MIN as i128, i64::MAX as i128),
136        DataType::Decimal32(precision, _) => (
137            // Different precision for decimal32 can store different range of value.
138            // For example, the precision is 3, the max of value is `999` and the min
139            // value is `-999`
140            MIN_DECIMAL32_FOR_EACH_PRECISION[*precision as usize] as i128,
141            MAX_DECIMAL32_FOR_EACH_PRECISION[*precision as usize] as i128,
142        ),
143        DataType::Decimal64(precision, _) => (
144            // Different precision for decimal64 can store different range of value.
145            // For example, the precision is 3, the max of value is `999` and the min
146            // value is `-999`
147            MIN_DECIMAL64_FOR_EACH_PRECISION[*precision as usize] as i128,
148            MAX_DECIMAL64_FOR_EACH_PRECISION[*precision as usize] as i128,
149        ),
150        DataType::Decimal128(precision, _) => (
151            // Different precision for decimal128 can store different range of value.
152            // For example, the precision is 3, the max of value is `999` and the min
153            // value is `-999`
154            MIN_DECIMAL128_FOR_EACH_PRECISION[*precision as usize],
155            MAX_DECIMAL128_FOR_EACH_PRECISION[*precision as usize],
156        ),
157        _ => return None,
158    };
159    let lit_value_target_type = match lit_value {
160        ScalarValue::Int8(Some(v)) => (*v as i128).checked_mul(mul),
161        ScalarValue::Int16(Some(v)) => (*v as i128).checked_mul(mul),
162        ScalarValue::Int32(Some(v)) => (*v as i128).checked_mul(mul),
163        ScalarValue::Int64(Some(v)) => (*v as i128).checked_mul(mul),
164        ScalarValue::UInt8(Some(v)) => (*v as i128).checked_mul(mul),
165        ScalarValue::UInt16(Some(v)) => (*v as i128).checked_mul(mul),
166        ScalarValue::UInt32(Some(v)) => (*v as i128).checked_mul(mul),
167        ScalarValue::UInt64(Some(v)) => (*v as i128).checked_mul(mul),
168        ScalarValue::TimestampSecond(Some(v), _) => (*v as i128).checked_mul(mul),
169        ScalarValue::TimestampMillisecond(Some(v), _) => (*v as i128).checked_mul(mul),
170        ScalarValue::TimestampMicrosecond(Some(v), _) => (*v as i128).checked_mul(mul),
171        ScalarValue::TimestampNanosecond(Some(v), _) => (*v as i128).checked_mul(mul),
172        ScalarValue::Decimal32(Some(v), _, scale) => {
173            let v = *v as i128;
174            let lit_scale_mul = 10_i128.pow(*scale as u32);
175            if mul >= lit_scale_mul {
176                // Example:
177                // lit is decimal(123,3,2)
178                // target type is decimal(5,3)
179                // the lit can be converted to the decimal(1230,5,3)
180                v.checked_mul(mul / lit_scale_mul)
181            } else if v % (lit_scale_mul / mul) == 0 {
182                // Example:
183                // lit is decimal(123000,10,3)
184                // target type is int32: the lit can be converted to INT32(123)
185                // target type is decimal(10,2): the lit can be converted to decimal(12300,10,2)
186                Some(v / (lit_scale_mul / mul))
187            } else {
188                // can't convert the lit decimal to the target data type
189                None
190            }
191        }
192        ScalarValue::Decimal64(Some(v), _, scale) => {
193            let v = *v as i128;
194            let lit_scale_mul = 10_i128.pow(*scale as u32);
195            if mul >= lit_scale_mul {
196                // Example:
197                // lit is decimal(123,3,2)
198                // target type is decimal(5,3)
199                // the lit can be converted to the decimal(1230,5,3)
200                v.checked_mul(mul / lit_scale_mul)
201            } else if v % (lit_scale_mul / mul) == 0 {
202                // Example:
203                // lit is decimal(123000,10,3)
204                // target type is int32: the lit can be converted to INT32(123)
205                // target type is decimal(10,2): the lit can be converted to decimal(12300,10,2)
206                Some(v / (lit_scale_mul / mul))
207            } else {
208                // can't convert the lit decimal to the target data type
209                None
210            }
211        }
212        ScalarValue::Decimal128(Some(v), _, scale) => {
213            let lit_scale_mul = 10_i128.pow(*scale as u32);
214            if mul >= lit_scale_mul {
215                // Example:
216                // lit is decimal(123,3,2)
217                // target type is decimal(5,3)
218                // the lit can be converted to the decimal(1230,5,3)
219                (*v).checked_mul(mul / lit_scale_mul)
220            } else if (*v) % (lit_scale_mul / mul) == 0 {
221                // Example:
222                // lit is decimal(123000,10,3)
223                // target type is int32: the lit can be converted to INT32(123)
224                // target type is decimal(10,2): the lit can be converted to decimal(12300,10,2)
225                Some(*v / (lit_scale_mul / mul))
226            } else {
227                // can't convert the lit decimal to the target data type
228                None
229            }
230        }
231        _ => None,
232    };
233
234    match lit_value_target_type {
235        None => None,
236        Some(value) => {
237            if value >= target_min && value <= target_max {
238                // the value casted from lit to the target type is in the range of target type.
239                // return the target type of scalar value
240                let result_scalar = match target_type {
241                    DataType::Int8 => ScalarValue::Int8(Some(value as i8)),
242                    DataType::Int16 => ScalarValue::Int16(Some(value as i16)),
243                    DataType::Int32 => ScalarValue::Int32(Some(value as i32)),
244                    DataType::Int64 => ScalarValue::Int64(Some(value as i64)),
245                    DataType::UInt8 => ScalarValue::UInt8(Some(value as u8)),
246                    DataType::UInt16 => ScalarValue::UInt16(Some(value as u16)),
247                    DataType::UInt32 => ScalarValue::UInt32(Some(value as u32)),
248                    DataType::UInt64 => ScalarValue::UInt64(Some(value as u64)),
249                    DataType::Timestamp(TimeUnit::Second, tz) => {
250                        let value = cast_between_timestamp(
251                            &lit_data_type,
252                            &DataType::Timestamp(TimeUnit::Second, tz.clone()),
253                            value,
254                        );
255                        ScalarValue::TimestampSecond(value, tz.clone())
256                    }
257                    DataType::Timestamp(TimeUnit::Millisecond, tz) => {
258                        let value = cast_between_timestamp(
259                            &lit_data_type,
260                            &DataType::Timestamp(TimeUnit::Millisecond, tz.clone()),
261                            value,
262                        );
263                        ScalarValue::TimestampMillisecond(value, tz.clone())
264                    }
265                    DataType::Timestamp(TimeUnit::Microsecond, tz) => {
266                        let value = cast_between_timestamp(
267                            &lit_data_type,
268                            &DataType::Timestamp(TimeUnit::Microsecond, tz.clone()),
269                            value,
270                        );
271                        ScalarValue::TimestampMicrosecond(value, tz.clone())
272                    }
273                    DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
274                        let value = cast_between_timestamp(
275                            &lit_data_type,
276                            &DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()),
277                            value,
278                        );
279                        ScalarValue::TimestampNanosecond(value, tz.clone())
280                    }
281                    DataType::Decimal32(p, s) => {
282                        ScalarValue::Decimal32(Some(value as i32), *p, *s)
283                    }
284                    DataType::Decimal64(p, s) => {
285                        ScalarValue::Decimal64(Some(value as i64), *p, *s)
286                    }
287                    DataType::Decimal128(p, s) => {
288                        ScalarValue::Decimal128(Some(value), *p, *s)
289                    }
290                    _ => {
291                        return None;
292                    }
293                };
294                Some(result_scalar)
295            } else {
296                None
297            }
298        }
299    }
300}
301
302fn try_cast_string_literal(
303    lit_value: &ScalarValue,
304    target_type: &DataType,
305) -> Option<ScalarValue> {
306    let string_value = lit_value.try_as_str()?.map(|s| s.to_string());
307    let scalar_value = match target_type {
308        DataType::Utf8 => ScalarValue::Utf8(string_value),
309        DataType::LargeUtf8 => ScalarValue::LargeUtf8(string_value),
310        DataType::Utf8View => ScalarValue::Utf8View(string_value),
311        _ => return None,
312    };
313    Some(scalar_value)
314}
315
316/// Attempt to cast to/from a dictionary type by wrapping/unwrapping the dictionary
317fn try_cast_dictionary(
318    lit_value: &ScalarValue,
319    target_type: &DataType,
320) -> Option<ScalarValue> {
321    let lit_value_type = lit_value.data_type();
322    let result_scalar = match (lit_value, target_type) {
323        // Unwrap dictionary when inner type matches target type
324        (ScalarValue::Dictionary(_, inner_value), _)
325            if inner_value.data_type() == *target_type =>
326        {
327            (**inner_value).clone()
328        }
329        // Wrap type when target type is dictionary
330        (_, DataType::Dictionary(index_type, inner_type))
331            if **inner_type == lit_value_type =>
332        {
333            ScalarValue::Dictionary(index_type.clone(), Box::new(lit_value.clone()))
334        }
335        _ => {
336            return None;
337        }
338    };
339    Some(result_scalar)
340}
341
342/// Cast a timestamp value from one unit to another
343fn cast_between_timestamp(from: &DataType, to: &DataType, value: i128) -> Option<i64> {
344    let value = value as i64;
345    let from_scale = match from {
346        DataType::Timestamp(TimeUnit::Second, _) => 1,
347        DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS,
348        DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS,
349        DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS,
350        _ => return Some(value),
351    };
352
353    let to_scale = match to {
354        DataType::Timestamp(TimeUnit::Second, _) => 1,
355        DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS,
356        DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS,
357        DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS,
358        _ => return Some(value),
359    };
360
361    match from_scale.cmp(&to_scale) {
362        Ordering::Less => value.checked_mul(to_scale / from_scale),
363        Ordering::Greater => Some(value / (from_scale / to_scale)),
364        Ordering::Equal => Some(value),
365    }
366}
367
368fn try_cast_binary(
369    lit_value: &ScalarValue,
370    target_type: &DataType,
371) -> Option<ScalarValue> {
372    match (lit_value, target_type) {
373        (ScalarValue::Binary(Some(v)), DataType::FixedSizeBinary(n))
374            if v.len() == *n as usize =>
375        {
376            Some(ScalarValue::FixedSizeBinary(*n, Some(v.clone())))
377        }
378        _ => None,
379    }
380}
381
382#[cfg(test)]
383mod tests {
384    use super::*;
385    use arrow::compute::{cast_with_options, CastOptions};
386    use arrow::datatypes::{Field, Fields, TimeUnit};
387    use std::sync::Arc;
388
    /// Outcome that a test expects from `try_cast_literal_to_type`
    #[derive(Debug, Clone)]
    enum ExpectedCast {
        /// The cast must succeed and produce exactly this value
        Value(ScalarValue),
        /// The cast must not produce a value (the function returns `None`)
        NoValue,
    }
396
397    /// Runs try_cast_literal_to_type with the specified inputs and
398    /// ensure it computes the expected output, and ensures the
399    /// casting is consistent with the Arrow kernels
400    fn expect_cast(
401        literal: ScalarValue,
402        target_type: DataType,
403        expected_result: ExpectedCast,
404    ) {
405        let actual_value = try_cast_literal_to_type(&literal, &target_type);
406
407        println!("expect_cast: ");
408        println!("  {literal:?} --> {target_type}");
409        println!("  expected_result: {expected_result:?}");
410        println!("  actual_result:   {actual_value:?}");
411
412        match expected_result {
413            ExpectedCast::Value(expected_value) => {
414                let actual_value =
415                    actual_value.expect("Expected cast value but got None");
416
417                assert_eq!(actual_value, expected_value);
418
419                // Verify that calling the arrow
420                // cast kernel yields the same results
421                // input array
422                let literal_array = literal
423                    .to_array_of_size(1)
424                    .expect("Failed to convert to array of size");
425                let expected_array = expected_value
426                    .to_array_of_size(1)
427                    .expect("Failed to convert to array of size");
428                let cast_array = cast_with_options(
429                    &literal_array,
430                    &target_type,
431                    &CastOptions::default(),
432                )
433                .expect("Expected to be cast array with arrow cast kernel");
434
435                assert_eq!(
436                    &expected_array, &cast_array,
437                    "Result of casting {literal:?} with arrow was\n {cast_array:#?}\nbut expected\n{expected_array:#?}"
438                );
439
440                // Verify that for timestamp types the timezones are the same
441                // (ScalarValue::cmp doesn't account for timezones);
442                if let (
443                    DataType::Timestamp(left_unit, left_tz),
444                    DataType::Timestamp(right_unit, right_tz),
445                ) = (actual_value.data_type(), expected_value.data_type())
446                {
447                    assert_eq!(left_unit, right_unit);
448                    assert_eq!(left_tz, right_tz);
449                }
450            }
451            ExpectedCast::NoValue => {
452                assert!(
453                    actual_value.is_none(),
454                    "Expected no cast value, but got {actual_value:?}"
455                );
456            }
457        }
458    }
459
460    #[test]
461    fn test_try_cast_to_type_nulls() {
462        // test that nulls can be cast to/from all integer types
463        let scalars = vec![
464            ScalarValue::Int8(None),
465            ScalarValue::Int16(None),
466            ScalarValue::Int32(None),
467            ScalarValue::Int64(None),
468            ScalarValue::UInt8(None),
469            ScalarValue::UInt16(None),
470            ScalarValue::UInt32(None),
471            ScalarValue::UInt64(None),
472            ScalarValue::Decimal128(None, 3, 0),
473            ScalarValue::Decimal128(None, 8, 2),
474            ScalarValue::Utf8(None),
475            ScalarValue::LargeUtf8(None),
476        ];
477
478        for s1 in &scalars {
479            for s2 in &scalars {
480                let expected_value = ExpectedCast::Value(s2.clone());
481
482                expect_cast(s1.clone(), s2.data_type(), expected_value);
483            }
484        }
485    }
486
487    #[test]
488    fn test_try_cast_to_type_int_in_range() {
489        // test values that can be cast to/from all integer types
490        let scalars = vec![
491            ScalarValue::Int8(Some(123)),
492            ScalarValue::Int16(Some(123)),
493            ScalarValue::Int32(Some(123)),
494            ScalarValue::Int64(Some(123)),
495            ScalarValue::UInt8(Some(123)),
496            ScalarValue::UInt16(Some(123)),
497            ScalarValue::UInt32(Some(123)),
498            ScalarValue::UInt64(Some(123)),
499            ScalarValue::Decimal128(Some(123), 3, 0),
500            ScalarValue::Decimal128(Some(12300), 8, 2),
501        ];
502
503        for s1 in &scalars {
504            for s2 in &scalars {
505                let expected_value = ExpectedCast::Value(s2.clone());
506
507                expect_cast(s1.clone(), s2.data_type(), expected_value);
508            }
509        }
510
511        let max_i32 = ScalarValue::Int32(Some(i32::MAX));
512        expect_cast(
513            max_i32,
514            DataType::UInt64,
515            ExpectedCast::Value(ScalarValue::UInt64(Some(i32::MAX as u64))),
516        );
517
518        let min_i32 = ScalarValue::Int32(Some(i32::MIN));
519        expect_cast(
520            min_i32,
521            DataType::Int64,
522            ExpectedCast::Value(ScalarValue::Int64(Some(i32::MIN as i64))),
523        );
524
525        let max_i64 = ScalarValue::Int64(Some(i64::MAX));
526        expect_cast(
527            max_i64,
528            DataType::UInt64,
529            ExpectedCast::Value(ScalarValue::UInt64(Some(i64::MAX as u64))),
530        );
531    }
532
533    #[test]
534    fn test_try_cast_to_type_int_out_of_range() {
535        let min_i32 = ScalarValue::Int32(Some(i32::MIN));
536        let min_i64 = ScalarValue::Int64(Some(i64::MIN));
537        let max_i64 = ScalarValue::Int64(Some(i64::MAX));
538        let max_u64 = ScalarValue::UInt64(Some(u64::MAX));
539
540        expect_cast(max_i64.clone(), DataType::Int8, ExpectedCast::NoValue);
541
542        expect_cast(max_i64.clone(), DataType::Int16, ExpectedCast::NoValue);
543
544        expect_cast(max_i64, DataType::Int32, ExpectedCast::NoValue);
545
546        expect_cast(max_u64, DataType::Int64, ExpectedCast::NoValue);
547
548        expect_cast(min_i64, DataType::UInt64, ExpectedCast::NoValue);
549
550        expect_cast(min_i32, DataType::UInt64, ExpectedCast::NoValue);
551
552        // decimal out of range
553        expect_cast(
554            ScalarValue::Decimal128(Some(99999999999999999999999999999999999900), 38, 0),
555            DataType::Int64,
556            ExpectedCast::NoValue,
557        );
558
559        expect_cast(
560            ScalarValue::Decimal128(Some(-9999999999999999999999999999999999), 37, 1),
561            DataType::Int64,
562            ExpectedCast::NoValue,
563        );
564    }
565
566    #[test]
567    fn test_try_decimal_cast_in_range() {
568        expect_cast(
569            ScalarValue::Decimal128(Some(12300), 5, 2),
570            DataType::Decimal128(3, 0),
571            ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 3, 0)),
572        );
573
574        expect_cast(
575            ScalarValue::Decimal128(Some(12300), 5, 2),
576            DataType::Decimal128(8, 0),
577            ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 8, 0)),
578        );
579
580        expect_cast(
581            ScalarValue::Decimal128(Some(12300), 5, 2),
582            DataType::Decimal128(8, 5),
583            ExpectedCast::Value(ScalarValue::Decimal128(Some(12300000), 8, 5)),
584        );
585    }
586
587    #[test]
588    fn test_try_decimal_cast_out_of_range() {
589        // decimal would lose precision
590        expect_cast(
591            ScalarValue::Decimal128(Some(12345), 5, 2),
592            DataType::Decimal128(3, 0),
593            ExpectedCast::NoValue,
594        );
595
596        // decimal would lose precision
597        expect_cast(
598            ScalarValue::Decimal128(Some(12300), 5, 2),
599            DataType::Decimal128(2, 0),
600            ExpectedCast::NoValue,
601        );
602    }
603
604    #[test]
605    fn test_try_cast_to_type_timestamps() {
606        for time_unit in [
607            TimeUnit::Second,
608            TimeUnit::Millisecond,
609            TimeUnit::Microsecond,
610            TimeUnit::Nanosecond,
611        ] {
612            let utc = Some("+00:00".into());
613            // No timezone, utc timezone
614            let (lit_tz_none, lit_tz_utc) = match time_unit {
615                TimeUnit::Second => (
616                    ScalarValue::TimestampSecond(Some(12345), None),
617                    ScalarValue::TimestampSecond(Some(12345), utc),
618                ),
619
620                TimeUnit::Millisecond => (
621                    ScalarValue::TimestampMillisecond(Some(12345), None),
622                    ScalarValue::TimestampMillisecond(Some(12345), utc),
623                ),
624
625                TimeUnit::Microsecond => (
626                    ScalarValue::TimestampMicrosecond(Some(12345), None),
627                    ScalarValue::TimestampMicrosecond(Some(12345), utc),
628                ),
629
630                TimeUnit::Nanosecond => (
631                    ScalarValue::TimestampNanosecond(Some(12345), None),
632                    ScalarValue::TimestampNanosecond(Some(12345), utc),
633                ),
634            };
635
636            // DataFusion ignores timezones for comparisons of ScalarValue
637            // so double check it here
638            assert_eq!(lit_tz_none, lit_tz_utc);
639
640            // e.g. DataType::Timestamp(_, None)
641            let dt_tz_none = lit_tz_none.data_type();
642
643            // e.g. DataType::Timestamp(_, Some(utc))
644            let dt_tz_utc = lit_tz_utc.data_type();
645
646            // None <--> None
647            expect_cast(
648                lit_tz_none.clone(),
649                dt_tz_none.clone(),
650                ExpectedCast::Value(lit_tz_none.clone()),
651            );
652
653            // None <--> Utc
654            expect_cast(
655                lit_tz_none.clone(),
656                dt_tz_utc.clone(),
657                ExpectedCast::Value(lit_tz_utc.clone()),
658            );
659
660            // Utc <--> None
661            expect_cast(
662                lit_tz_utc.clone(),
663                dt_tz_none.clone(),
664                ExpectedCast::Value(lit_tz_none.clone()),
665            );
666
667            // Utc <--> Utc
668            expect_cast(
669                lit_tz_utc.clone(),
670                dt_tz_utc.clone(),
671                ExpectedCast::Value(lit_tz_utc.clone()),
672            );
673
674            // timestamp to int64
675            expect_cast(
676                lit_tz_utc.clone(),
677                DataType::Int64,
678                ExpectedCast::Value(ScalarValue::Int64(Some(12345))),
679            );
680
681            // int64 to timestamp
682            expect_cast(
683                ScalarValue::Int64(Some(12345)),
684                dt_tz_none.clone(),
685                ExpectedCast::Value(lit_tz_none.clone()),
686            );
687
688            // int64 to timestamp
689            expect_cast(
690                ScalarValue::Int64(Some(12345)),
691                dt_tz_utc.clone(),
692                ExpectedCast::Value(lit_tz_utc.clone()),
693            );
694
695            // timestamp to string (not supported yet)
696            expect_cast(
697                lit_tz_utc.clone(),
698                DataType::LargeUtf8,
699                ExpectedCast::NoValue,
700            );
701        }
702    }
703
704    #[test]
705    fn test_try_cast_to_type_unsupported() {
706        // int64 to list
707        expect_cast(
708            ScalarValue::Int64(Some(12345)),
709            DataType::List(Arc::new(Field::new("f", DataType::Int32, true))),
710            ExpectedCast::NoValue,
711        );
712    }
713
714    #[test]
715    fn test_try_cast_literal_to_timestamp() {
716        // same timestamp
717        let new_scalar = try_cast_literal_to_type(
718            &ScalarValue::TimestampNanosecond(Some(123456), None),
719            &DataType::Timestamp(TimeUnit::Nanosecond, None),
720        )
721        .unwrap();
722
723        assert_eq!(
724            new_scalar,
725            ScalarValue::TimestampNanosecond(Some(123456), None)
726        );
727
728        // TimestampNanosecond to TimestampMicrosecond
729        let new_scalar = try_cast_literal_to_type(
730            &ScalarValue::TimestampNanosecond(Some(123456), None),
731            &DataType::Timestamp(TimeUnit::Microsecond, None),
732        )
733        .unwrap();
734
735        assert_eq!(
736            new_scalar,
737            ScalarValue::TimestampMicrosecond(Some(123), None)
738        );
739
740        // TimestampNanosecond to TimestampMillisecond
741        let new_scalar = try_cast_literal_to_type(
742            &ScalarValue::TimestampNanosecond(Some(123456), None),
743            &DataType::Timestamp(TimeUnit::Millisecond, None),
744        )
745        .unwrap();
746
747        assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None));
748
749        // TimestampNanosecond to TimestampSecond
750        let new_scalar = try_cast_literal_to_type(
751            &ScalarValue::TimestampNanosecond(Some(123456), None),
752            &DataType::Timestamp(TimeUnit::Second, None),
753        )
754        .unwrap();
755
756        assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(0), None));
757
758        // TimestampMicrosecond to TimestampNanosecond
759        let new_scalar = try_cast_literal_to_type(
760            &ScalarValue::TimestampMicrosecond(Some(123), None),
761            &DataType::Timestamp(TimeUnit::Nanosecond, None),
762        )
763        .unwrap();
764
765        assert_eq!(
766            new_scalar,
767            ScalarValue::TimestampNanosecond(Some(123000), None)
768        );
769
770        // TimestampMicrosecond to TimestampMillisecond
771        let new_scalar = try_cast_literal_to_type(
772            &ScalarValue::TimestampMicrosecond(Some(123), None),
773            &DataType::Timestamp(TimeUnit::Millisecond, None),
774        )
775        .unwrap();
776
777        assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None));
778
779        // TimestampMicrosecond to TimestampSecond
780        let new_scalar = try_cast_literal_to_type(
781            &ScalarValue::TimestampMicrosecond(Some(123456789), None),
782            &DataType::Timestamp(TimeUnit::Second, None),
783        )
784        .unwrap();
785        assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123), None));
786
787        // TimestampMillisecond to TimestampNanosecond
788        let new_scalar = try_cast_literal_to_type(
789            &ScalarValue::TimestampMillisecond(Some(123), None),
790            &DataType::Timestamp(TimeUnit::Nanosecond, None),
791        )
792        .unwrap();
793        assert_eq!(
794            new_scalar,
795            ScalarValue::TimestampNanosecond(Some(123000000), None)
796        );
797
798        // TimestampMillisecond to TimestampMicrosecond
799        let new_scalar = try_cast_literal_to_type(
800            &ScalarValue::TimestampMillisecond(Some(123), None),
801            &DataType::Timestamp(TimeUnit::Microsecond, None),
802        )
803        .unwrap();
804        assert_eq!(
805            new_scalar,
806            ScalarValue::TimestampMicrosecond(Some(123000), None)
807        );
808        // TimestampMillisecond to TimestampSecond
809        let new_scalar = try_cast_literal_to_type(
810            &ScalarValue::TimestampMillisecond(Some(123456789), None),
811            &DataType::Timestamp(TimeUnit::Second, None),
812        )
813        .unwrap();
814        assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123456), None));
815
816        // TimestampSecond to TimestampNanosecond
817        let new_scalar = try_cast_literal_to_type(
818            &ScalarValue::TimestampSecond(Some(123), None),
819            &DataType::Timestamp(TimeUnit::Nanosecond, None),
820        )
821        .unwrap();
822        assert_eq!(
823            new_scalar,
824            ScalarValue::TimestampNanosecond(Some(123000000000), None)
825        );
826
827        // TimestampSecond to TimestampMicrosecond
828        let new_scalar = try_cast_literal_to_type(
829            &ScalarValue::TimestampSecond(Some(123), None),
830            &DataType::Timestamp(TimeUnit::Microsecond, None),
831        )
832        .unwrap();
833        assert_eq!(
834            new_scalar,
835            ScalarValue::TimestampMicrosecond(Some(123000000), None)
836        );
837
838        // TimestampSecond to TimestampMillisecond
839        let new_scalar = try_cast_literal_to_type(
840            &ScalarValue::TimestampSecond(Some(123), None),
841            &DataType::Timestamp(TimeUnit::Millisecond, None),
842        )
843        .unwrap();
844        assert_eq!(
845            new_scalar,
846            ScalarValue::TimestampMillisecond(Some(123000), None)
847        );
848
849        // overflow
850        let new_scalar = try_cast_literal_to_type(
851            &ScalarValue::TimestampSecond(Some(i64::MAX), None),
852            &DataType::Timestamp(TimeUnit::Millisecond, None),
853        )
854        .unwrap();
855        assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(None, None));
856    }
857
858    #[test]
859    fn test_try_cast_to_string_type() {
860        let scalars = vec![
861            ScalarValue::from("string"),
862            ScalarValue::LargeUtf8(Some("string".to_owned())),
863        ];
864
865        for s1 in &scalars {
866            for s2 in &scalars {
867                let expected_value = ExpectedCast::Value(s2.clone());
868
869                expect_cast(s1.clone(), s2.data_type(), expected_value);
870            }
871        }
872    }
873
874    #[test]
875    fn test_try_cast_to_dictionary_type() {
876        fn dictionary_type(t: DataType) -> DataType {
877            DataType::Dictionary(Box::new(DataType::Int32), Box::new(t))
878        }
879        fn dictionary_value(value: ScalarValue) -> ScalarValue {
880            ScalarValue::Dictionary(Box::new(DataType::Int32), Box::new(value))
881        }
882        let scalars = vec![
883            ScalarValue::from("string"),
884            ScalarValue::LargeUtf8(Some("string".to_owned())),
885        ];
886        for s in &scalars {
887            expect_cast(
888                s.clone(),
889                dictionary_type(s.data_type()),
890                ExpectedCast::Value(dictionary_value(s.clone())),
891            );
892            expect_cast(
893                dictionary_value(s.clone()),
894                s.data_type(),
895                ExpectedCast::Value(s.clone()),
896            )
897        }
898    }
899
900    #[test]
901    fn test_try_cast_to_fixed_size_binary() {
902        expect_cast(
903            ScalarValue::Binary(Some(vec![1, 2, 3])),
904            DataType::FixedSizeBinary(3),
905            ExpectedCast::Value(ScalarValue::FixedSizeBinary(3, Some(vec![1, 2, 3]))),
906        )
907    }
908
909    #[test]
910    fn test_numeric_boundary_values() {
911        // Test exact boundary values for signed integers
912        expect_cast(
913            ScalarValue::Int8(Some(i8::MAX)),
914            DataType::UInt8,
915            ExpectedCast::Value(ScalarValue::UInt8(Some(i8::MAX as u8))),
916        );
917
918        expect_cast(
919            ScalarValue::Int8(Some(i8::MIN)),
920            DataType::UInt8,
921            ExpectedCast::NoValue,
922        );
923
924        expect_cast(
925            ScalarValue::UInt8(Some(u8::MAX)),
926            DataType::Int8,
927            ExpectedCast::NoValue,
928        );
929
930        // Test cross-type boundary scenarios
931        expect_cast(
932            ScalarValue::Int32(Some(i32::MAX)),
933            DataType::Int64,
934            ExpectedCast::Value(ScalarValue::Int64(Some(i32::MAX as i64))),
935        );
936
937        expect_cast(
938            ScalarValue::Int64(Some(i64::MIN)),
939            DataType::UInt64,
940            ExpectedCast::NoValue,
941        );
942
943        // Test unsigned to signed edge cases
944        expect_cast(
945            ScalarValue::UInt32(Some(u32::MAX)),
946            DataType::Int32,
947            ExpectedCast::NoValue,
948        );
949
950        expect_cast(
951            ScalarValue::UInt64(Some(u64::MAX)),
952            DataType::Int64,
953            ExpectedCast::NoValue,
954        );
955    }
956
957    #[test]
958    fn test_decimal_precision_limits() {
959        use arrow::datatypes::{
960            MAX_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL128_FOR_EACH_PRECISION,
961        };
962
963        // Test maximum precision values
964        expect_cast(
965            ScalarValue::Decimal128(Some(MAX_DECIMAL128_FOR_EACH_PRECISION[3]), 3, 0),
966            DataType::Decimal128(5, 0),
967            ExpectedCast::Value(ScalarValue::Decimal128(
968                Some(MAX_DECIMAL128_FOR_EACH_PRECISION[3]),
969                5,
970                0,
971            )),
972        );
973
974        // Test minimum precision values
975        expect_cast(
976            ScalarValue::Decimal128(Some(MIN_DECIMAL128_FOR_EACH_PRECISION[3]), 3, 0),
977            DataType::Decimal128(5, 0),
978            ExpectedCast::Value(ScalarValue::Decimal128(
979                Some(MIN_DECIMAL128_FOR_EACH_PRECISION[3]),
980                5,
981                0,
982            )),
983        );
984
985        // Test scale increase
986        expect_cast(
987            ScalarValue::Decimal128(Some(123), 3, 0),
988            DataType::Decimal128(5, 2),
989            ExpectedCast::Value(ScalarValue::Decimal128(Some(12300), 5, 2)),
990        );
991
992        // Test precision overflow (value too large for target precision)
993        expect_cast(
994            ScalarValue::Decimal128(Some(MAX_DECIMAL128_FOR_EACH_PRECISION[10]), 10, 0),
995            DataType::Decimal128(3, 0),
996            ExpectedCast::NoValue,
997        );
998
999        // Test non-divisible decimal conversion (should fail)
1000        expect_cast(
1001            ScalarValue::Decimal128(Some(12345), 5, 3), // 12.345
1002            DataType::Int32,
1003            ExpectedCast::NoValue, // Can't convert 12.345 to integer without loss
1004        );
1005
1006        // Test edge case: scale reduction with precision loss
1007        expect_cast(
1008            ScalarValue::Decimal128(Some(12345), 5, 2), // 123.45
1009            DataType::Decimal128(3, 0),                 // Can only hold up to 999
1010            ExpectedCast::NoValue,
1011        );
1012    }
1013
1014    #[test]
1015    fn test_timestamp_overflow_scenarios() {
1016        // Test overflow in timestamp conversions
1017        let max_seconds = i64::MAX / 1_000_000_000; // Avoid overflow when converting to nanos
1018
1019        // This should work - within safe range
1020        expect_cast(
1021            ScalarValue::TimestampSecond(Some(max_seconds), None),
1022            DataType::Timestamp(TimeUnit::Nanosecond, None),
1023            ExpectedCast::Value(ScalarValue::TimestampNanosecond(
1024                Some(max_seconds * 1_000_000_000),
1025                None,
1026            )),
1027        );
1028
1029        // Test very large nanosecond value conversion to smaller units
1030        expect_cast(
1031            ScalarValue::TimestampNanosecond(Some(i64::MAX), None),
1032            DataType::Timestamp(TimeUnit::Second, None),
1033            ExpectedCast::Value(ScalarValue::TimestampSecond(
1034                Some(i64::MAX / 1_000_000_000),
1035                None,
1036            )),
1037        );
1038
1039        // Test precision loss in downscaling
1040        expect_cast(
1041            ScalarValue::TimestampNanosecond(Some(1), None),
1042            DataType::Timestamp(TimeUnit::Second, None),
1043            ExpectedCast::Value(ScalarValue::TimestampSecond(Some(0), None)),
1044        );
1045
1046        expect_cast(
1047            ScalarValue::TimestampMicrosecond(Some(999), None),
1048            DataType::Timestamp(TimeUnit::Millisecond, None),
1049            ExpectedCast::Value(ScalarValue::TimestampMillisecond(Some(0), None)),
1050        );
1051    }
1052
1053    #[test]
1054    fn test_string_view() {
1055        // Test Utf8View to other string types
1056        expect_cast(
1057            ScalarValue::Utf8View(Some("test".to_string())),
1058            DataType::Utf8,
1059            ExpectedCast::Value(ScalarValue::Utf8(Some("test".to_string()))),
1060        );
1061
1062        expect_cast(
1063            ScalarValue::Utf8View(Some("test".to_string())),
1064            DataType::LargeUtf8,
1065            ExpectedCast::Value(ScalarValue::LargeUtf8(Some("test".to_string()))),
1066        );
1067
1068        // Test other string types to Utf8View
1069        expect_cast(
1070            ScalarValue::Utf8(Some("hello".to_string())),
1071            DataType::Utf8View,
1072            ExpectedCast::Value(ScalarValue::Utf8View(Some("hello".to_string()))),
1073        );
1074
1075        expect_cast(
1076            ScalarValue::LargeUtf8(Some("world".to_string())),
1077            DataType::Utf8View,
1078            ExpectedCast::Value(ScalarValue::Utf8View(Some("world".to_string()))),
1079        );
1080
1081        // Test empty string
1082        expect_cast(
1083            ScalarValue::Utf8(Some("".to_string())),
1084            DataType::Utf8View,
1085            ExpectedCast::Value(ScalarValue::Utf8View(Some("".to_string()))),
1086        );
1087
1088        // Test large string
1089        let large_string = "x".repeat(1000);
1090        expect_cast(
1091            ScalarValue::LargeUtf8(Some(large_string.clone())),
1092            DataType::Utf8View,
1093            ExpectedCast::Value(ScalarValue::Utf8View(Some(large_string))),
1094        );
1095    }
1096
1097    #[test]
1098    fn test_binary_size_edge_cases() {
1099        // Test size mismatch - too small
1100        expect_cast(
1101            ScalarValue::Binary(Some(vec![1, 2])),
1102            DataType::FixedSizeBinary(3),
1103            ExpectedCast::NoValue,
1104        );
1105
1106        // Test size mismatch - too large
1107        expect_cast(
1108            ScalarValue::Binary(Some(vec![1, 2, 3, 4])),
1109            DataType::FixedSizeBinary(3),
1110            ExpectedCast::NoValue,
1111        );
1112
1113        // Test empty binary
1114        expect_cast(
1115            ScalarValue::Binary(Some(vec![])),
1116            DataType::FixedSizeBinary(0),
1117            ExpectedCast::Value(ScalarValue::FixedSizeBinary(0, Some(vec![]))),
1118        );
1119
1120        // Test exact size match
1121        expect_cast(
1122            ScalarValue::Binary(Some(vec![1, 2, 3])),
1123            DataType::FixedSizeBinary(3),
1124            ExpectedCast::Value(ScalarValue::FixedSizeBinary(3, Some(vec![1, 2, 3]))),
1125        );
1126
1127        // Test single byte
1128        expect_cast(
1129            ScalarValue::Binary(Some(vec![42])),
1130            DataType::FixedSizeBinary(1),
1131            ExpectedCast::Value(ScalarValue::FixedSizeBinary(1, Some(vec![42]))),
1132        );
1133    }
1134
1135    #[test]
1136    fn test_dictionary_index_types() {
1137        // Test different dictionary index types
1138        let string_value = ScalarValue::Utf8(Some("test".to_string()));
1139
1140        // Int8 index dictionary
1141        let dict_int8 =
1142            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8));
1143        expect_cast(
1144            string_value.clone(),
1145            dict_int8,
1146            ExpectedCast::Value(ScalarValue::Dictionary(
1147                Box::new(DataType::Int8),
1148                Box::new(string_value.clone()),
1149            )),
1150        );
1151
1152        // Int16 index dictionary
1153        let dict_int16 =
1154            DataType::Dictionary(Box::new(DataType::Int16), Box::new(DataType::Utf8));
1155        expect_cast(
1156            string_value.clone(),
1157            dict_int16,
1158            ExpectedCast::Value(ScalarValue::Dictionary(
1159                Box::new(DataType::Int16),
1160                Box::new(string_value.clone()),
1161            )),
1162        );
1163
1164        // Int64 index dictionary
1165        let dict_int64 =
1166            DataType::Dictionary(Box::new(DataType::Int64), Box::new(DataType::Utf8));
1167        expect_cast(
1168            string_value.clone(),
1169            dict_int64,
1170            ExpectedCast::Value(ScalarValue::Dictionary(
1171                Box::new(DataType::Int64),
1172                Box::new(string_value.clone()),
1173            )),
1174        );
1175
1176        // Test dictionary unwrapping
1177        let dict_value = ScalarValue::Dictionary(
1178            Box::new(DataType::Int32),
1179            Box::new(ScalarValue::LargeUtf8(Some("unwrap_test".to_string()))),
1180        );
1181        expect_cast(
1182            dict_value,
1183            DataType::LargeUtf8,
1184            ExpectedCast::Value(ScalarValue::LargeUtf8(Some("unwrap_test".to_string()))),
1185        );
1186    }
1187
1188    #[test]
1189    fn test_type_support_functions() {
1190        // Test numeric type support
1191        assert!(is_supported_numeric_type(&DataType::Int8));
1192        assert!(is_supported_numeric_type(&DataType::UInt64));
1193        assert!(is_supported_numeric_type(&DataType::Decimal128(10, 2)));
1194        assert!(is_supported_numeric_type(&DataType::Timestamp(
1195            TimeUnit::Nanosecond,
1196            None
1197        )));
1198        assert!(!is_supported_numeric_type(&DataType::Float32));
1199        assert!(!is_supported_numeric_type(&DataType::Float64));
1200
1201        // Test string type support
1202        assert!(is_supported_string_type(&DataType::Utf8));
1203        assert!(is_supported_string_type(&DataType::LargeUtf8));
1204        assert!(is_supported_string_type(&DataType::Utf8View));
1205        assert!(!is_supported_string_type(&DataType::Binary));
1206
1207        // Test binary type support
1208        assert!(is_supported_binary_type(&DataType::Binary));
1209        assert!(is_supported_binary_type(&DataType::FixedSizeBinary(10)));
1210        assert!(!is_supported_binary_type(&DataType::Utf8));
1211
1212        // Test dictionary type support with nested types
1213        assert!(is_supported_dictionary_type(&DataType::Dictionary(
1214            Box::new(DataType::Int32),
1215            Box::new(DataType::Utf8)
1216        )));
1217        assert!(is_supported_dictionary_type(&DataType::Dictionary(
1218            Box::new(DataType::Int32),
1219            Box::new(DataType::Int64)
1220        )));
1221        assert!(!is_supported_dictionary_type(&DataType::Dictionary(
1222            Box::new(DataType::Int32),
1223            Box::new(DataType::List(Arc::new(Field::new(
1224                "item",
1225                DataType::Int32,
1226                true
1227            ))))
1228        )));
1229
1230        // Test overall type support
1231        assert!(is_supported_type(&DataType::Int32));
1232        assert!(is_supported_type(&DataType::Utf8));
1233        assert!(is_supported_type(&DataType::Binary));
1234        assert!(is_supported_type(&DataType::Dictionary(
1235            Box::new(DataType::Int32),
1236            Box::new(DataType::Utf8)
1237        )));
1238        assert!(!is_supported_type(&DataType::List(Arc::new(Field::new(
1239            "item",
1240            DataType::Int32,
1241            true
1242        )))));
1243        assert!(!is_supported_type(&DataType::Struct(Fields::empty())));
1244    }
1245
1246    #[test]
1247    fn test_error_conditions() {
1248        // Test unsupported source type
1249        expect_cast(
1250            ScalarValue::Float32(Some(1.5)),
1251            DataType::Int32,
1252            ExpectedCast::NoValue,
1253        );
1254
1255        // Test unsupported target type
1256        expect_cast(
1257            ScalarValue::Int32(Some(123)),
1258            DataType::Float64,
1259            ExpectedCast::NoValue,
1260        );
1261
1262        // Test both types unsupported
1263        expect_cast(
1264            ScalarValue::Float64(Some(1.5)),
1265            DataType::Float32,
1266            ExpectedCast::NoValue,
1267        );
1268
1269        // Test complex unsupported types
1270        let list_type =
1271            DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
1272        expect_cast(
1273            ScalarValue::Int32(Some(123)),
1274            list_type,
1275            ExpectedCast::NoValue,
1276        );
1277
1278        // Test dictionary with unsupported inner type
1279        let bad_dict = DataType::Dictionary(
1280            Box::new(DataType::Int32),
1281            Box::new(DataType::List(Arc::new(Field::new(
1282                "item",
1283                DataType::Int32,
1284                true,
1285            )))),
1286        );
1287        expect_cast(
1288            ScalarValue::Int32(Some(123)),
1289            bad_dict,
1290            ExpectedCast::NoValue,
1291        );
1292    }
1293}