datafusion_expr_common/
casts.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Utilities for casting scalar literals to different data types
19//!
20//! This module contains functions for casting ScalarValue literals
21//! to different data types, originally extracted from the optimizer's
22//! unwrap_cast module to be shared between logical and physical layers.
23
24use std::cmp::Ordering;
25
26use arrow::datatypes::{
27    DataType, TimeUnit, MAX_DECIMAL128_FOR_EACH_PRECISION,
28    MIN_DECIMAL128_FOR_EACH_PRECISION,
29};
30use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS};
31use datafusion_common::ScalarValue;
32
33/// Convert a literal value from one data type to another
34pub fn try_cast_literal_to_type(
35    lit_value: &ScalarValue,
36    target_type: &DataType,
37) -> Option<ScalarValue> {
38    let lit_data_type = lit_value.data_type();
39    if !is_supported_type(&lit_data_type) || !is_supported_type(target_type) {
40        return None;
41    }
42    if lit_value.is_null() {
43        // null value can be cast to any type of null value
44        return ScalarValue::try_from(target_type).ok();
45    }
46    try_cast_numeric_literal(lit_value, target_type)
47        .or_else(|| try_cast_string_literal(lit_value, target_type))
48        .or_else(|| try_cast_dictionary(lit_value, target_type))
49        .or_else(|| try_cast_binary(lit_value, target_type))
50}
51
52/// Returns true if unwrap_cast_in_comparison supports this data type
53pub fn is_supported_type(data_type: &DataType) -> bool {
54    is_supported_numeric_type(data_type)
55        || is_supported_string_type(data_type)
56        || is_supported_dictionary_type(data_type)
57        || is_supported_binary_type(data_type)
58}
59
60/// Returns true if unwrap_cast_in_comparison support this numeric type
61fn is_supported_numeric_type(data_type: &DataType) -> bool {
62    matches!(
63        data_type,
64        DataType::UInt8
65            | DataType::UInt16
66            | DataType::UInt32
67            | DataType::UInt64
68            | DataType::Int8
69            | DataType::Int16
70            | DataType::Int32
71            | DataType::Int64
72            | DataType::Decimal128(_, _)
73            | DataType::Timestamp(_, _)
74    )
75}
76
77/// Returns true if unwrap_cast_in_comparison supports casting this value as a string
78fn is_supported_string_type(data_type: &DataType) -> bool {
79    matches!(
80        data_type,
81        DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View
82    )
83}
84
85/// Returns true if unwrap_cast_in_comparison supports casting this value as a dictionary
86fn is_supported_dictionary_type(data_type: &DataType) -> bool {
87    matches!(data_type,
88                    DataType::Dictionary(_, inner) if is_supported_type(inner))
89}
90
91fn is_supported_binary_type(data_type: &DataType) -> bool {
92    matches!(data_type, DataType::Binary | DataType::FixedSizeBinary(_))
93}
94
95/// Convert a numeric value from one numeric data type to another
96fn try_cast_numeric_literal(
97    lit_value: &ScalarValue,
98    target_type: &DataType,
99) -> Option<ScalarValue> {
100    let lit_data_type = lit_value.data_type();
101    if !is_supported_numeric_type(&lit_data_type)
102        || !is_supported_numeric_type(target_type)
103    {
104        return None;
105    }
106
107    let mul = match target_type {
108        DataType::UInt8
109        | DataType::UInt16
110        | DataType::UInt32
111        | DataType::UInt64
112        | DataType::Int8
113        | DataType::Int16
114        | DataType::Int32
115        | DataType::Int64 => 1_i128,
116        DataType::Timestamp(_, _) => 1_i128,
117        DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32),
118        _ => return None,
119    };
120    let (target_min, target_max) = match target_type {
121        DataType::UInt8 => (u8::MIN as i128, u8::MAX as i128),
122        DataType::UInt16 => (u16::MIN as i128, u16::MAX as i128),
123        DataType::UInt32 => (u32::MIN as i128, u32::MAX as i128),
124        DataType::UInt64 => (u64::MIN as i128, u64::MAX as i128),
125        DataType::Int8 => (i8::MIN as i128, i8::MAX as i128),
126        DataType::Int16 => (i16::MIN as i128, i16::MAX as i128),
127        DataType::Int32 => (i32::MIN as i128, i32::MAX as i128),
128        DataType::Int64 => (i64::MIN as i128, i64::MAX as i128),
129        DataType::Timestamp(_, _) => (i64::MIN as i128, i64::MAX as i128),
130        DataType::Decimal128(precision, _) => (
131            // Different precision for decimal128 can store different range of value.
132            // For example, the precision is 3, the max of value is `999` and the min
133            // value is `-999`
134            MIN_DECIMAL128_FOR_EACH_PRECISION[*precision as usize],
135            MAX_DECIMAL128_FOR_EACH_PRECISION[*precision as usize],
136        ),
137        _ => return None,
138    };
139    let lit_value_target_type = match lit_value {
140        ScalarValue::Int8(Some(v)) => (*v as i128).checked_mul(mul),
141        ScalarValue::Int16(Some(v)) => (*v as i128).checked_mul(mul),
142        ScalarValue::Int32(Some(v)) => (*v as i128).checked_mul(mul),
143        ScalarValue::Int64(Some(v)) => (*v as i128).checked_mul(mul),
144        ScalarValue::UInt8(Some(v)) => (*v as i128).checked_mul(mul),
145        ScalarValue::UInt16(Some(v)) => (*v as i128).checked_mul(mul),
146        ScalarValue::UInt32(Some(v)) => (*v as i128).checked_mul(mul),
147        ScalarValue::UInt64(Some(v)) => (*v as i128).checked_mul(mul),
148        ScalarValue::TimestampSecond(Some(v), _) => (*v as i128).checked_mul(mul),
149        ScalarValue::TimestampMillisecond(Some(v), _) => (*v as i128).checked_mul(mul),
150        ScalarValue::TimestampMicrosecond(Some(v), _) => (*v as i128).checked_mul(mul),
151        ScalarValue::TimestampNanosecond(Some(v), _) => (*v as i128).checked_mul(mul),
152        ScalarValue::Decimal128(Some(v), _, scale) => {
153            let lit_scale_mul = 10_i128.pow(*scale as u32);
154            if mul >= lit_scale_mul {
155                // Example:
156                // lit is decimal(123,3,2)
157                // target type is decimal(5,3)
158                // the lit can be converted to the decimal(1230,5,3)
159                (*v).checked_mul(mul / lit_scale_mul)
160            } else if (*v) % (lit_scale_mul / mul) == 0 {
161                // Example:
162                // lit is decimal(123000,10,3)
163                // target type is int32: the lit can be converted to INT32(123)
164                // target type is decimal(10,2): the lit can be converted to decimal(12300,10,2)
165                Some(*v / (lit_scale_mul / mul))
166            } else {
167                // can't convert the lit decimal to the target data type
168                None
169            }
170        }
171        _ => None,
172    };
173
174    match lit_value_target_type {
175        None => None,
176        Some(value) => {
177            if value >= target_min && value <= target_max {
178                // the value casted from lit to the target type is in the range of target type.
179                // return the target type of scalar value
180                let result_scalar = match target_type {
181                    DataType::Int8 => ScalarValue::Int8(Some(value as i8)),
182                    DataType::Int16 => ScalarValue::Int16(Some(value as i16)),
183                    DataType::Int32 => ScalarValue::Int32(Some(value as i32)),
184                    DataType::Int64 => ScalarValue::Int64(Some(value as i64)),
185                    DataType::UInt8 => ScalarValue::UInt8(Some(value as u8)),
186                    DataType::UInt16 => ScalarValue::UInt16(Some(value as u16)),
187                    DataType::UInt32 => ScalarValue::UInt32(Some(value as u32)),
188                    DataType::UInt64 => ScalarValue::UInt64(Some(value as u64)),
189                    DataType::Timestamp(TimeUnit::Second, tz) => {
190                        let value = cast_between_timestamp(
191                            &lit_data_type,
192                            &DataType::Timestamp(TimeUnit::Second, tz.clone()),
193                            value,
194                        );
195                        ScalarValue::TimestampSecond(value, tz.clone())
196                    }
197                    DataType::Timestamp(TimeUnit::Millisecond, tz) => {
198                        let value = cast_between_timestamp(
199                            &lit_data_type,
200                            &DataType::Timestamp(TimeUnit::Millisecond, tz.clone()),
201                            value,
202                        );
203                        ScalarValue::TimestampMillisecond(value, tz.clone())
204                    }
205                    DataType::Timestamp(TimeUnit::Microsecond, tz) => {
206                        let value = cast_between_timestamp(
207                            &lit_data_type,
208                            &DataType::Timestamp(TimeUnit::Microsecond, tz.clone()),
209                            value,
210                        );
211                        ScalarValue::TimestampMicrosecond(value, tz.clone())
212                    }
213                    DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
214                        let value = cast_between_timestamp(
215                            &lit_data_type,
216                            &DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()),
217                            value,
218                        );
219                        ScalarValue::TimestampNanosecond(value, tz.clone())
220                    }
221                    DataType::Decimal128(p, s) => {
222                        ScalarValue::Decimal128(Some(value), *p, *s)
223                    }
224                    _ => {
225                        return None;
226                    }
227                };
228                Some(result_scalar)
229            } else {
230                None
231            }
232        }
233    }
234}
235
236fn try_cast_string_literal(
237    lit_value: &ScalarValue,
238    target_type: &DataType,
239) -> Option<ScalarValue> {
240    let string_value = lit_value.try_as_str()?.map(|s| s.to_string());
241    let scalar_value = match target_type {
242        DataType::Utf8 => ScalarValue::Utf8(string_value),
243        DataType::LargeUtf8 => ScalarValue::LargeUtf8(string_value),
244        DataType::Utf8View => ScalarValue::Utf8View(string_value),
245        _ => return None,
246    };
247    Some(scalar_value)
248}
249
250/// Attempt to cast to/from a dictionary type by wrapping/unwrapping the dictionary
251fn try_cast_dictionary(
252    lit_value: &ScalarValue,
253    target_type: &DataType,
254) -> Option<ScalarValue> {
255    let lit_value_type = lit_value.data_type();
256    let result_scalar = match (lit_value, target_type) {
257        // Unwrap dictionary when inner type matches target type
258        (ScalarValue::Dictionary(_, inner_value), _)
259            if inner_value.data_type() == *target_type =>
260        {
261            (**inner_value).clone()
262        }
263        // Wrap type when target type is dictionary
264        (_, DataType::Dictionary(index_type, inner_type))
265            if **inner_type == lit_value_type =>
266        {
267            ScalarValue::Dictionary(index_type.clone(), Box::new(lit_value.clone()))
268        }
269        _ => {
270            return None;
271        }
272    };
273    Some(result_scalar)
274}
275
276/// Cast a timestamp value from one unit to another
277fn cast_between_timestamp(from: &DataType, to: &DataType, value: i128) -> Option<i64> {
278    let value = value as i64;
279    let from_scale = match from {
280        DataType::Timestamp(TimeUnit::Second, _) => 1,
281        DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS,
282        DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS,
283        DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS,
284        _ => return Some(value),
285    };
286
287    let to_scale = match to {
288        DataType::Timestamp(TimeUnit::Second, _) => 1,
289        DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS,
290        DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS,
291        DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS,
292        _ => return Some(value),
293    };
294
295    match from_scale.cmp(&to_scale) {
296        Ordering::Less => value.checked_mul(to_scale / from_scale),
297        Ordering::Greater => Some(value / (from_scale / to_scale)),
298        Ordering::Equal => Some(value),
299    }
300}
301
302fn try_cast_binary(
303    lit_value: &ScalarValue,
304    target_type: &DataType,
305) -> Option<ScalarValue> {
306    match (lit_value, target_type) {
307        (ScalarValue::Binary(Some(v)), DataType::FixedSizeBinary(n))
308            if v.len() == *n as usize =>
309        {
310            Some(ScalarValue::FixedSizeBinary(*n, Some(v.clone())))
311        }
312        _ => None,
313    }
314}
315
316#[cfg(test)]
317mod tests {
318    use super::*;
319    use arrow::compute::{cast_with_options, CastOptions};
320    use arrow::datatypes::{Field, Fields, TimeUnit};
321    use std::sync::Arc;
322
323    #[derive(Debug, Clone)]
324    enum ExpectedCast {
325        /// test successfully cast value and it is as specified
326        Value(ScalarValue),
327        /// test returned OK, but could not cast the value
328        NoValue,
329    }
330
331    /// Runs try_cast_literal_to_type with the specified inputs and
332    /// ensure it computes the expected output, and ensures the
333    /// casting is consistent with the Arrow kernels
334    fn expect_cast(
335        literal: ScalarValue,
336        target_type: DataType,
337        expected_result: ExpectedCast,
338    ) {
339        let actual_value = try_cast_literal_to_type(&literal, &target_type);
340
341        println!("expect_cast: ");
342        println!("  {literal:?} --> {target_type:?}");
343        println!("  expected_result: {expected_result:?}");
344        println!("  actual_result:   {actual_value:?}");
345
346        match expected_result {
347            ExpectedCast::Value(expected_value) => {
348                let actual_value =
349                    actual_value.expect("Expected cast value but got None");
350
351                assert_eq!(actual_value, expected_value);
352
353                // Verify that calling the arrow
354                // cast kernel yields the same results
355                // input array
356                let literal_array = literal
357                    .to_array_of_size(1)
358                    .expect("Failed to convert to array of size");
359                let expected_array = expected_value
360                    .to_array_of_size(1)
361                    .expect("Failed to convert to array of size");
362                let cast_array = cast_with_options(
363                    &literal_array,
364                    &target_type,
365                    &CastOptions::default(),
366                )
367                .expect("Expected to be cast array with arrow cast kernel");
368
369                assert_eq!(
370                    &expected_array, &cast_array,
371                    "Result of casting {literal:?} with arrow was\n {cast_array:#?}\nbut expected\n{expected_array:#?}"
372                );
373
374                // Verify that for timestamp types the timezones are the same
375                // (ScalarValue::cmp doesn't account for timezones);
376                if let (
377                    DataType::Timestamp(left_unit, left_tz),
378                    DataType::Timestamp(right_unit, right_tz),
379                ) = (actual_value.data_type(), expected_value.data_type())
380                {
381                    assert_eq!(left_unit, right_unit);
382                    assert_eq!(left_tz, right_tz);
383                }
384            }
385            ExpectedCast::NoValue => {
386                assert!(
387                    actual_value.is_none(),
388                    "Expected no cast value, but got {actual_value:?}"
389                );
390            }
391        }
392    }
393
394    #[test]
395    fn test_try_cast_to_type_nulls() {
396        // test that nulls can be cast to/from all integer types
397        let scalars = vec![
398            ScalarValue::Int8(None),
399            ScalarValue::Int16(None),
400            ScalarValue::Int32(None),
401            ScalarValue::Int64(None),
402            ScalarValue::UInt8(None),
403            ScalarValue::UInt16(None),
404            ScalarValue::UInt32(None),
405            ScalarValue::UInt64(None),
406            ScalarValue::Decimal128(None, 3, 0),
407            ScalarValue::Decimal128(None, 8, 2),
408            ScalarValue::Utf8(None),
409            ScalarValue::LargeUtf8(None),
410        ];
411
412        for s1 in &scalars {
413            for s2 in &scalars {
414                let expected_value = ExpectedCast::Value(s2.clone());
415
416                expect_cast(s1.clone(), s2.data_type(), expected_value);
417            }
418        }
419    }
420
421    #[test]
422    fn test_try_cast_to_type_int_in_range() {
423        // test values that can be cast to/from all integer types
424        let scalars = vec![
425            ScalarValue::Int8(Some(123)),
426            ScalarValue::Int16(Some(123)),
427            ScalarValue::Int32(Some(123)),
428            ScalarValue::Int64(Some(123)),
429            ScalarValue::UInt8(Some(123)),
430            ScalarValue::UInt16(Some(123)),
431            ScalarValue::UInt32(Some(123)),
432            ScalarValue::UInt64(Some(123)),
433            ScalarValue::Decimal128(Some(123), 3, 0),
434            ScalarValue::Decimal128(Some(12300), 8, 2),
435        ];
436
437        for s1 in &scalars {
438            for s2 in &scalars {
439                let expected_value = ExpectedCast::Value(s2.clone());
440
441                expect_cast(s1.clone(), s2.data_type(), expected_value);
442            }
443        }
444
445        let max_i32 = ScalarValue::Int32(Some(i32::MAX));
446        expect_cast(
447            max_i32,
448            DataType::UInt64,
449            ExpectedCast::Value(ScalarValue::UInt64(Some(i32::MAX as u64))),
450        );
451
452        let min_i32 = ScalarValue::Int32(Some(i32::MIN));
453        expect_cast(
454            min_i32,
455            DataType::Int64,
456            ExpectedCast::Value(ScalarValue::Int64(Some(i32::MIN as i64))),
457        );
458
459        let max_i64 = ScalarValue::Int64(Some(i64::MAX));
460        expect_cast(
461            max_i64,
462            DataType::UInt64,
463            ExpectedCast::Value(ScalarValue::UInt64(Some(i64::MAX as u64))),
464        );
465    }
466
467    #[test]
468    fn test_try_cast_to_type_int_out_of_range() {
469        let min_i32 = ScalarValue::Int32(Some(i32::MIN));
470        let min_i64 = ScalarValue::Int64(Some(i64::MIN));
471        let max_i64 = ScalarValue::Int64(Some(i64::MAX));
472        let max_u64 = ScalarValue::UInt64(Some(u64::MAX));
473
474        expect_cast(max_i64.clone(), DataType::Int8, ExpectedCast::NoValue);
475
476        expect_cast(max_i64.clone(), DataType::Int16, ExpectedCast::NoValue);
477
478        expect_cast(max_i64, DataType::Int32, ExpectedCast::NoValue);
479
480        expect_cast(max_u64, DataType::Int64, ExpectedCast::NoValue);
481
482        expect_cast(min_i64, DataType::UInt64, ExpectedCast::NoValue);
483
484        expect_cast(min_i32, DataType::UInt64, ExpectedCast::NoValue);
485
486        // decimal out of range
487        expect_cast(
488            ScalarValue::Decimal128(Some(99999999999999999999999999999999999900), 38, 0),
489            DataType::Int64,
490            ExpectedCast::NoValue,
491        );
492
493        expect_cast(
494            ScalarValue::Decimal128(Some(-9999999999999999999999999999999999), 37, 1),
495            DataType::Int64,
496            ExpectedCast::NoValue,
497        );
498    }
499
500    #[test]
501    fn test_try_decimal_cast_in_range() {
502        expect_cast(
503            ScalarValue::Decimal128(Some(12300), 5, 2),
504            DataType::Decimal128(3, 0),
505            ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 3, 0)),
506        );
507
508        expect_cast(
509            ScalarValue::Decimal128(Some(12300), 5, 2),
510            DataType::Decimal128(8, 0),
511            ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 8, 0)),
512        );
513
514        expect_cast(
515            ScalarValue::Decimal128(Some(12300), 5, 2),
516            DataType::Decimal128(8, 5),
517            ExpectedCast::Value(ScalarValue::Decimal128(Some(12300000), 8, 5)),
518        );
519    }
520
521    #[test]
522    fn test_try_decimal_cast_out_of_range() {
523        // decimal would lose precision
524        expect_cast(
525            ScalarValue::Decimal128(Some(12345), 5, 2),
526            DataType::Decimal128(3, 0),
527            ExpectedCast::NoValue,
528        );
529
530        // decimal would lose precision
531        expect_cast(
532            ScalarValue::Decimal128(Some(12300), 5, 2),
533            DataType::Decimal128(2, 0),
534            ExpectedCast::NoValue,
535        );
536    }
537
538    #[test]
539    fn test_try_cast_to_type_timestamps() {
540        for time_unit in [
541            TimeUnit::Second,
542            TimeUnit::Millisecond,
543            TimeUnit::Microsecond,
544            TimeUnit::Nanosecond,
545        ] {
546            let utc = Some("+00:00".into());
547            // No timezone, utc timezone
548            let (lit_tz_none, lit_tz_utc) = match time_unit {
549                TimeUnit::Second => (
550                    ScalarValue::TimestampSecond(Some(12345), None),
551                    ScalarValue::TimestampSecond(Some(12345), utc),
552                ),
553
554                TimeUnit::Millisecond => (
555                    ScalarValue::TimestampMillisecond(Some(12345), None),
556                    ScalarValue::TimestampMillisecond(Some(12345), utc),
557                ),
558
559                TimeUnit::Microsecond => (
560                    ScalarValue::TimestampMicrosecond(Some(12345), None),
561                    ScalarValue::TimestampMicrosecond(Some(12345), utc),
562                ),
563
564                TimeUnit::Nanosecond => (
565                    ScalarValue::TimestampNanosecond(Some(12345), None),
566                    ScalarValue::TimestampNanosecond(Some(12345), utc),
567                ),
568            };
569
570            // DataFusion ignores timezones for comparisons of ScalarValue
571            // so double check it here
572            assert_eq!(lit_tz_none, lit_tz_utc);
573
574            // e.g. DataType::Timestamp(_, None)
575            let dt_tz_none = lit_tz_none.data_type();
576
577            // e.g. DataType::Timestamp(_, Some(utc))
578            let dt_tz_utc = lit_tz_utc.data_type();
579
580            // None <--> None
581            expect_cast(
582                lit_tz_none.clone(),
583                dt_tz_none.clone(),
584                ExpectedCast::Value(lit_tz_none.clone()),
585            );
586
587            // None <--> Utc
588            expect_cast(
589                lit_tz_none.clone(),
590                dt_tz_utc.clone(),
591                ExpectedCast::Value(lit_tz_utc.clone()),
592            );
593
594            // Utc <--> None
595            expect_cast(
596                lit_tz_utc.clone(),
597                dt_tz_none.clone(),
598                ExpectedCast::Value(lit_tz_none.clone()),
599            );
600
601            // Utc <--> Utc
602            expect_cast(
603                lit_tz_utc.clone(),
604                dt_tz_utc.clone(),
605                ExpectedCast::Value(lit_tz_utc.clone()),
606            );
607
608            // timestamp to int64
609            expect_cast(
610                lit_tz_utc.clone(),
611                DataType::Int64,
612                ExpectedCast::Value(ScalarValue::Int64(Some(12345))),
613            );
614
615            // int64 to timestamp
616            expect_cast(
617                ScalarValue::Int64(Some(12345)),
618                dt_tz_none.clone(),
619                ExpectedCast::Value(lit_tz_none.clone()),
620            );
621
622            // int64 to timestamp
623            expect_cast(
624                ScalarValue::Int64(Some(12345)),
625                dt_tz_utc.clone(),
626                ExpectedCast::Value(lit_tz_utc.clone()),
627            );
628
629            // timestamp to string (not supported yet)
630            expect_cast(
631                lit_tz_utc.clone(),
632                DataType::LargeUtf8,
633                ExpectedCast::NoValue,
634            );
635        }
636    }
637
638    #[test]
639    fn test_try_cast_to_type_unsupported() {
640        // int64 to list
641        expect_cast(
642            ScalarValue::Int64(Some(12345)),
643            DataType::List(Arc::new(Field::new("f", DataType::Int32, true))),
644            ExpectedCast::NoValue,
645        );
646    }
647
648    #[test]
649    fn test_try_cast_literal_to_timestamp() {
650        // same timestamp
651        let new_scalar = try_cast_literal_to_type(
652            &ScalarValue::TimestampNanosecond(Some(123456), None),
653            &DataType::Timestamp(TimeUnit::Nanosecond, None),
654        )
655        .unwrap();
656
657        assert_eq!(
658            new_scalar,
659            ScalarValue::TimestampNanosecond(Some(123456), None)
660        );
661
662        // TimestampNanosecond to TimestampMicrosecond
663        let new_scalar = try_cast_literal_to_type(
664            &ScalarValue::TimestampNanosecond(Some(123456), None),
665            &DataType::Timestamp(TimeUnit::Microsecond, None),
666        )
667        .unwrap();
668
669        assert_eq!(
670            new_scalar,
671            ScalarValue::TimestampMicrosecond(Some(123), None)
672        );
673
674        // TimestampNanosecond to TimestampMillisecond
675        let new_scalar = try_cast_literal_to_type(
676            &ScalarValue::TimestampNanosecond(Some(123456), None),
677            &DataType::Timestamp(TimeUnit::Millisecond, None),
678        )
679        .unwrap();
680
681        assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None));
682
683        // TimestampNanosecond to TimestampSecond
684        let new_scalar = try_cast_literal_to_type(
685            &ScalarValue::TimestampNanosecond(Some(123456), None),
686            &DataType::Timestamp(TimeUnit::Second, None),
687        )
688        .unwrap();
689
690        assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(0), None));
691
692        // TimestampMicrosecond to TimestampNanosecond
693        let new_scalar = try_cast_literal_to_type(
694            &ScalarValue::TimestampMicrosecond(Some(123), None),
695            &DataType::Timestamp(TimeUnit::Nanosecond, None),
696        )
697        .unwrap();
698
699        assert_eq!(
700            new_scalar,
701            ScalarValue::TimestampNanosecond(Some(123000), None)
702        );
703
704        // TimestampMicrosecond to TimestampMillisecond
705        let new_scalar = try_cast_literal_to_type(
706            &ScalarValue::TimestampMicrosecond(Some(123), None),
707            &DataType::Timestamp(TimeUnit::Millisecond, None),
708        )
709        .unwrap();
710
711        assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None));
712
713        // TimestampMicrosecond to TimestampSecond
714        let new_scalar = try_cast_literal_to_type(
715            &ScalarValue::TimestampMicrosecond(Some(123456789), None),
716            &DataType::Timestamp(TimeUnit::Second, None),
717        )
718        .unwrap();
719        assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123), None));
720
721        // TimestampMillisecond to TimestampNanosecond
722        let new_scalar = try_cast_literal_to_type(
723            &ScalarValue::TimestampMillisecond(Some(123), None),
724            &DataType::Timestamp(TimeUnit::Nanosecond, None),
725        )
726        .unwrap();
727        assert_eq!(
728            new_scalar,
729            ScalarValue::TimestampNanosecond(Some(123000000), None)
730        );
731
732        // TimestampMillisecond to TimestampMicrosecond
733        let new_scalar = try_cast_literal_to_type(
734            &ScalarValue::TimestampMillisecond(Some(123), None),
735            &DataType::Timestamp(TimeUnit::Microsecond, None),
736        )
737        .unwrap();
738        assert_eq!(
739            new_scalar,
740            ScalarValue::TimestampMicrosecond(Some(123000), None)
741        );
742        // TimestampMillisecond to TimestampSecond
743        let new_scalar = try_cast_literal_to_type(
744            &ScalarValue::TimestampMillisecond(Some(123456789), None),
745            &DataType::Timestamp(TimeUnit::Second, None),
746        )
747        .unwrap();
748        assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123456), None));
749
750        // TimestampSecond to TimestampNanosecond
751        let new_scalar = try_cast_literal_to_type(
752            &ScalarValue::TimestampSecond(Some(123), None),
753            &DataType::Timestamp(TimeUnit::Nanosecond, None),
754        )
755        .unwrap();
756        assert_eq!(
757            new_scalar,
758            ScalarValue::TimestampNanosecond(Some(123000000000), None)
759        );
760
761        // TimestampSecond to TimestampMicrosecond
762        let new_scalar = try_cast_literal_to_type(
763            &ScalarValue::TimestampSecond(Some(123), None),
764            &DataType::Timestamp(TimeUnit::Microsecond, None),
765        )
766        .unwrap();
767        assert_eq!(
768            new_scalar,
769            ScalarValue::TimestampMicrosecond(Some(123000000), None)
770        );
771
772        // TimestampSecond to TimestampMillisecond
773        let new_scalar = try_cast_literal_to_type(
774            &ScalarValue::TimestampSecond(Some(123), None),
775            &DataType::Timestamp(TimeUnit::Millisecond, None),
776        )
777        .unwrap();
778        assert_eq!(
779            new_scalar,
780            ScalarValue::TimestampMillisecond(Some(123000), None)
781        );
782
783        // overflow
784        let new_scalar = try_cast_literal_to_type(
785            &ScalarValue::TimestampSecond(Some(i64::MAX), None),
786            &DataType::Timestamp(TimeUnit::Millisecond, None),
787        )
788        .unwrap();
789        assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(None, None));
790    }
791
792    #[test]
793    fn test_try_cast_to_string_type() {
794        let scalars = vec![
795            ScalarValue::from("string"),
796            ScalarValue::LargeUtf8(Some("string".to_owned())),
797        ];
798
799        for s1 in &scalars {
800            for s2 in &scalars {
801                let expected_value = ExpectedCast::Value(s2.clone());
802
803                expect_cast(s1.clone(), s2.data_type(), expected_value);
804            }
805        }
806    }
807
808    #[test]
809    fn test_try_cast_to_dictionary_type() {
810        fn dictionary_type(t: DataType) -> DataType {
811            DataType::Dictionary(Box::new(DataType::Int32), Box::new(t))
812        }
813        fn dictionary_value(value: ScalarValue) -> ScalarValue {
814            ScalarValue::Dictionary(Box::new(DataType::Int32), Box::new(value))
815        }
816        let scalars = vec![
817            ScalarValue::from("string"),
818            ScalarValue::LargeUtf8(Some("string".to_owned())),
819        ];
820        for s in &scalars {
821            expect_cast(
822                s.clone(),
823                dictionary_type(s.data_type()),
824                ExpectedCast::Value(dictionary_value(s.clone())),
825            );
826            expect_cast(
827                dictionary_value(s.clone()),
828                s.data_type(),
829                ExpectedCast::Value(s.clone()),
830            )
831        }
832    }
833
834    #[test]
835    fn test_try_cast_to_fixed_size_binary() {
836        expect_cast(
837            ScalarValue::Binary(Some(vec![1, 2, 3])),
838            DataType::FixedSizeBinary(3),
839            ExpectedCast::Value(ScalarValue::FixedSizeBinary(3, Some(vec![1, 2, 3]))),
840        )
841    }
842
843    #[test]
844    fn test_numeric_boundary_values() {
845        // Test exact boundary values for signed integers
846        expect_cast(
847            ScalarValue::Int8(Some(i8::MAX)),
848            DataType::UInt8,
849            ExpectedCast::Value(ScalarValue::UInt8(Some(i8::MAX as u8))),
850        );
851
852        expect_cast(
853            ScalarValue::Int8(Some(i8::MIN)),
854            DataType::UInt8,
855            ExpectedCast::NoValue,
856        );
857
858        expect_cast(
859            ScalarValue::UInt8(Some(u8::MAX)),
860            DataType::Int8,
861            ExpectedCast::NoValue,
862        );
863
864        // Test cross-type boundary scenarios
865        expect_cast(
866            ScalarValue::Int32(Some(i32::MAX)),
867            DataType::Int64,
868            ExpectedCast::Value(ScalarValue::Int64(Some(i32::MAX as i64))),
869        );
870
871        expect_cast(
872            ScalarValue::Int64(Some(i64::MIN)),
873            DataType::UInt64,
874            ExpectedCast::NoValue,
875        );
876
877        // Test unsigned to signed edge cases
878        expect_cast(
879            ScalarValue::UInt32(Some(u32::MAX)),
880            DataType::Int32,
881            ExpectedCast::NoValue,
882        );
883
884        expect_cast(
885            ScalarValue::UInt64(Some(u64::MAX)),
886            DataType::Int64,
887            ExpectedCast::NoValue,
888        );
889    }
890
891    #[test]
892    fn test_decimal_precision_limits() {
893        use arrow::datatypes::{
894            MAX_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL128_FOR_EACH_PRECISION,
895        };
896
897        // Test maximum precision values
898        expect_cast(
899            ScalarValue::Decimal128(Some(MAX_DECIMAL128_FOR_EACH_PRECISION[3]), 3, 0),
900            DataType::Decimal128(5, 0),
901            ExpectedCast::Value(ScalarValue::Decimal128(
902                Some(MAX_DECIMAL128_FOR_EACH_PRECISION[3]),
903                5,
904                0,
905            )),
906        );
907
908        // Test minimum precision values
909        expect_cast(
910            ScalarValue::Decimal128(Some(MIN_DECIMAL128_FOR_EACH_PRECISION[3]), 3, 0),
911            DataType::Decimal128(5, 0),
912            ExpectedCast::Value(ScalarValue::Decimal128(
913                Some(MIN_DECIMAL128_FOR_EACH_PRECISION[3]),
914                5,
915                0,
916            )),
917        );
918
919        // Test scale increase
920        expect_cast(
921            ScalarValue::Decimal128(Some(123), 3, 0),
922            DataType::Decimal128(5, 2),
923            ExpectedCast::Value(ScalarValue::Decimal128(Some(12300), 5, 2)),
924        );
925
926        // Test precision overflow (value too large for target precision)
927        expect_cast(
928            ScalarValue::Decimal128(Some(MAX_DECIMAL128_FOR_EACH_PRECISION[10]), 10, 0),
929            DataType::Decimal128(3, 0),
930            ExpectedCast::NoValue,
931        );
932
933        // Test non-divisible decimal conversion (should fail)
934        expect_cast(
935            ScalarValue::Decimal128(Some(12345), 5, 3), // 12.345
936            DataType::Int32,
937            ExpectedCast::NoValue, // Can't convert 12.345 to integer without loss
938        );
939
940        // Test edge case: scale reduction with precision loss
941        expect_cast(
942            ScalarValue::Decimal128(Some(12345), 5, 2), // 123.45
943            DataType::Decimal128(3, 0),                 // Can only hold up to 999
944            ExpectedCast::NoValue,
945        );
946    }
947
948    #[test]
949    fn test_timestamp_overflow_scenarios() {
950        // Test overflow in timestamp conversions
951        let max_seconds = i64::MAX / 1_000_000_000; // Avoid overflow when converting to nanos
952
953        // This should work - within safe range
954        expect_cast(
955            ScalarValue::TimestampSecond(Some(max_seconds), None),
956            DataType::Timestamp(TimeUnit::Nanosecond, None),
957            ExpectedCast::Value(ScalarValue::TimestampNanosecond(
958                Some(max_seconds * 1_000_000_000),
959                None,
960            )),
961        );
962
963        // Test very large nanosecond value conversion to smaller units
964        expect_cast(
965            ScalarValue::TimestampNanosecond(Some(i64::MAX), None),
966            DataType::Timestamp(TimeUnit::Second, None),
967            ExpectedCast::Value(ScalarValue::TimestampSecond(
968                Some(i64::MAX / 1_000_000_000),
969                None,
970            )),
971        );
972
973        // Test precision loss in downscaling
974        expect_cast(
975            ScalarValue::TimestampNanosecond(Some(1), None),
976            DataType::Timestamp(TimeUnit::Second, None),
977            ExpectedCast::Value(ScalarValue::TimestampSecond(Some(0), None)),
978        );
979
980        expect_cast(
981            ScalarValue::TimestampMicrosecond(Some(999), None),
982            DataType::Timestamp(TimeUnit::Millisecond, None),
983            ExpectedCast::Value(ScalarValue::TimestampMillisecond(Some(0), None)),
984        );
985    }
986
987    #[test]
988    fn test_string_view() {
989        // Test Utf8View to other string types
990        expect_cast(
991            ScalarValue::Utf8View(Some("test".to_string())),
992            DataType::Utf8,
993            ExpectedCast::Value(ScalarValue::Utf8(Some("test".to_string()))),
994        );
995
996        expect_cast(
997            ScalarValue::Utf8View(Some("test".to_string())),
998            DataType::LargeUtf8,
999            ExpectedCast::Value(ScalarValue::LargeUtf8(Some("test".to_string()))),
1000        );
1001
1002        // Test other string types to Utf8View
1003        expect_cast(
1004            ScalarValue::Utf8(Some("hello".to_string())),
1005            DataType::Utf8View,
1006            ExpectedCast::Value(ScalarValue::Utf8View(Some("hello".to_string()))),
1007        );
1008
1009        expect_cast(
1010            ScalarValue::LargeUtf8(Some("world".to_string())),
1011            DataType::Utf8View,
1012            ExpectedCast::Value(ScalarValue::Utf8View(Some("world".to_string()))),
1013        );
1014
1015        // Test empty string
1016        expect_cast(
1017            ScalarValue::Utf8(Some("".to_string())),
1018            DataType::Utf8View,
1019            ExpectedCast::Value(ScalarValue::Utf8View(Some("".to_string()))),
1020        );
1021
1022        // Test large string
1023        let large_string = "x".repeat(1000);
1024        expect_cast(
1025            ScalarValue::LargeUtf8(Some(large_string.clone())),
1026            DataType::Utf8View,
1027            ExpectedCast::Value(ScalarValue::Utf8View(Some(large_string))),
1028        );
1029    }
1030
1031    #[test]
1032    fn test_binary_size_edge_cases() {
1033        // Test size mismatch - too small
1034        expect_cast(
1035            ScalarValue::Binary(Some(vec![1, 2])),
1036            DataType::FixedSizeBinary(3),
1037            ExpectedCast::NoValue,
1038        );
1039
1040        // Test size mismatch - too large
1041        expect_cast(
1042            ScalarValue::Binary(Some(vec![1, 2, 3, 4])),
1043            DataType::FixedSizeBinary(3),
1044            ExpectedCast::NoValue,
1045        );
1046
1047        // Test empty binary
1048        expect_cast(
1049            ScalarValue::Binary(Some(vec![])),
1050            DataType::FixedSizeBinary(0),
1051            ExpectedCast::Value(ScalarValue::FixedSizeBinary(0, Some(vec![]))),
1052        );
1053
1054        // Test exact size match
1055        expect_cast(
1056            ScalarValue::Binary(Some(vec![1, 2, 3])),
1057            DataType::FixedSizeBinary(3),
1058            ExpectedCast::Value(ScalarValue::FixedSizeBinary(3, Some(vec![1, 2, 3]))),
1059        );
1060
1061        // Test single byte
1062        expect_cast(
1063            ScalarValue::Binary(Some(vec![42])),
1064            DataType::FixedSizeBinary(1),
1065            ExpectedCast::Value(ScalarValue::FixedSizeBinary(1, Some(vec![42]))),
1066        );
1067    }
1068
1069    #[test]
1070    fn test_dictionary_index_types() {
1071        // Test different dictionary index types
1072        let string_value = ScalarValue::Utf8(Some("test".to_string()));
1073
1074        // Int8 index dictionary
1075        let dict_int8 =
1076            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8));
1077        expect_cast(
1078            string_value.clone(),
1079            dict_int8,
1080            ExpectedCast::Value(ScalarValue::Dictionary(
1081                Box::new(DataType::Int8),
1082                Box::new(string_value.clone()),
1083            )),
1084        );
1085
1086        // Int16 index dictionary
1087        let dict_int16 =
1088            DataType::Dictionary(Box::new(DataType::Int16), Box::new(DataType::Utf8));
1089        expect_cast(
1090            string_value.clone(),
1091            dict_int16,
1092            ExpectedCast::Value(ScalarValue::Dictionary(
1093                Box::new(DataType::Int16),
1094                Box::new(string_value.clone()),
1095            )),
1096        );
1097
1098        // Int64 index dictionary
1099        let dict_int64 =
1100            DataType::Dictionary(Box::new(DataType::Int64), Box::new(DataType::Utf8));
1101        expect_cast(
1102            string_value.clone(),
1103            dict_int64,
1104            ExpectedCast::Value(ScalarValue::Dictionary(
1105                Box::new(DataType::Int64),
1106                Box::new(string_value.clone()),
1107            )),
1108        );
1109
1110        // Test dictionary unwrapping
1111        let dict_value = ScalarValue::Dictionary(
1112            Box::new(DataType::Int32),
1113            Box::new(ScalarValue::LargeUtf8(Some("unwrap_test".to_string()))),
1114        );
1115        expect_cast(
1116            dict_value,
1117            DataType::LargeUtf8,
1118            ExpectedCast::Value(ScalarValue::LargeUtf8(Some("unwrap_test".to_string()))),
1119        );
1120    }
1121
1122    #[test]
1123    fn test_type_support_functions() {
1124        // Test numeric type support
1125        assert!(is_supported_numeric_type(&DataType::Int8));
1126        assert!(is_supported_numeric_type(&DataType::UInt64));
1127        assert!(is_supported_numeric_type(&DataType::Decimal128(10, 2)));
1128        assert!(is_supported_numeric_type(&DataType::Timestamp(
1129            TimeUnit::Nanosecond,
1130            None
1131        )));
1132        assert!(!is_supported_numeric_type(&DataType::Float32));
1133        assert!(!is_supported_numeric_type(&DataType::Float64));
1134
1135        // Test string type support
1136        assert!(is_supported_string_type(&DataType::Utf8));
1137        assert!(is_supported_string_type(&DataType::LargeUtf8));
1138        assert!(is_supported_string_type(&DataType::Utf8View));
1139        assert!(!is_supported_string_type(&DataType::Binary));
1140
1141        // Test binary type support
1142        assert!(is_supported_binary_type(&DataType::Binary));
1143        assert!(is_supported_binary_type(&DataType::FixedSizeBinary(10)));
1144        assert!(!is_supported_binary_type(&DataType::Utf8));
1145
1146        // Test dictionary type support with nested types
1147        assert!(is_supported_dictionary_type(&DataType::Dictionary(
1148            Box::new(DataType::Int32),
1149            Box::new(DataType::Utf8)
1150        )));
1151        assert!(is_supported_dictionary_type(&DataType::Dictionary(
1152            Box::new(DataType::Int32),
1153            Box::new(DataType::Int64)
1154        )));
1155        assert!(!is_supported_dictionary_type(&DataType::Dictionary(
1156            Box::new(DataType::Int32),
1157            Box::new(DataType::List(Arc::new(Field::new(
1158                "item",
1159                DataType::Int32,
1160                true
1161            ))))
1162        )));
1163
1164        // Test overall type support
1165        assert!(is_supported_type(&DataType::Int32));
1166        assert!(is_supported_type(&DataType::Utf8));
1167        assert!(is_supported_type(&DataType::Binary));
1168        assert!(is_supported_type(&DataType::Dictionary(
1169            Box::new(DataType::Int32),
1170            Box::new(DataType::Utf8)
1171        )));
1172        assert!(!is_supported_type(&DataType::List(Arc::new(Field::new(
1173            "item",
1174            DataType::Int32,
1175            true
1176        )))));
1177        assert!(!is_supported_type(&DataType::Struct(Fields::empty())));
1178    }
1179
1180    #[test]
1181    fn test_error_conditions() {
1182        // Test unsupported source type
1183        expect_cast(
1184            ScalarValue::Float32(Some(1.5)),
1185            DataType::Int32,
1186            ExpectedCast::NoValue,
1187        );
1188
1189        // Test unsupported target type
1190        expect_cast(
1191            ScalarValue::Int32(Some(123)),
1192            DataType::Float64,
1193            ExpectedCast::NoValue,
1194        );
1195
1196        // Test both types unsupported
1197        expect_cast(
1198            ScalarValue::Float64(Some(1.5)),
1199            DataType::Float32,
1200            ExpectedCast::NoValue,
1201        );
1202
1203        // Test complex unsupported types
1204        let list_type =
1205            DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
1206        expect_cast(
1207            ScalarValue::Int32(Some(123)),
1208            list_type,
1209            ExpectedCast::NoValue,
1210        );
1211
1212        // Test dictionary with unsupported inner type
1213        let bad_dict = DataType::Dictionary(
1214            Box::new(DataType::Int32),
1215            Box::new(DataType::List(Arc::new(Field::new(
1216                "item",
1217                DataType::Int32,
1218                true,
1219            )))),
1220        );
1221        expect_cast(
1222            ScalarValue::Int32(Some(123)),
1223            bad_dict,
1224            ExpectedCast::NoValue,
1225        );
1226    }
1227}