Skip to main content

lance_datafusion/
expr.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Utilities for working with datafusion expressions
5
6use std::sync::Arc;
7
8use arrow::compute::cast;
9use arrow_array::{ArrayRef, cast::AsArray};
10use arrow_schema::{DataType, TimeUnit};
11use datafusion_common::ScalarValue;
12
/// Milliseconds in one day (24 h * 60 min * 60 s * 1000 ms); used to convert between
/// `Date32` and `Date64` literals.
const MS_PER_DAY: i64 = 24 * 60 * 60 * 1000;
14
15// This is slightly tedious but when we convert expressions from SQL strings to logical
16// datafusion expressions there is no type coercion that happens.  In other words "x = 7"
17// will always yield "x = 7_u64" regardless of the type of the column "x".  As a result, we
18// need to do that literal coercion ourselves.
19pub fn safe_coerce_scalar(value: &ScalarValue, ty: &DataType) -> Option<ScalarValue> {
20    // A dictionary target coerces the value to the dictionary's value type and
21    // re-wraps it as a dictionary literal. Only an untyped `ScalarValue::Null`
22    // keeps its untyped form, matching the behavior for all other targets; a
23    // *typed* null (e.g. `Utf8(None)`) is coerced and wrapped like any other
24    // value so it produces a `Dictionary(..)` literal that matches the column.
25    if let DataType::Dictionary(key_type, value_type) = ty {
26        if matches!(value, ScalarValue::Null) {
27            return Some(value.clone());
28        }
29        let inner = safe_coerce_scalar(value, value_type)?;
30        return Some(ScalarValue::Dictionary(key_type.clone(), Box::new(inner)));
31    }
32    match value {
33        ScalarValue::Int8(val) => match ty {
34            DataType::Int8 => Some(value.clone()),
35            DataType::Int16 => val.map(|v| ScalarValue::Int16(Some(i16::from(v)))),
36            DataType::Int32 => val.map(|v| ScalarValue::Int32(Some(i32::from(v)))),
37            DataType::Int64 => val.map(|v| ScalarValue::Int64(Some(i64::from(v)))),
38            DataType::UInt8 => {
39                val.and_then(|v| u8::try_from(v).map(|v| ScalarValue::UInt8(Some(v))).ok())
40            }
41            DataType::UInt16 => {
42                val.and_then(|v| u16::try_from(v).map(|v| ScalarValue::UInt16(Some(v))).ok())
43            }
44            DataType::UInt32 => {
45                val.and_then(|v| u32::try_from(v).map(|v| ScalarValue::UInt32(Some(v))).ok())
46            }
47            DataType::UInt64 => {
48                val.and_then(|v| u64::try_from(v).map(|v| ScalarValue::UInt64(Some(v))).ok())
49            }
50            DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(f32::from(v)))),
51            DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(f64::from(v)))),
52            _ => None,
53        },
54        ScalarValue::Int16(val) => match ty {
55            DataType::Int8 => {
56                val.and_then(|v| i8::try_from(v).map(|v| ScalarValue::Int8(Some(v))).ok())
57            }
58            DataType::Int16 => Some(value.clone()),
59            DataType::Int32 => val.map(|v| ScalarValue::Int32(Some(i32::from(v)))),
60            DataType::Int64 => val.map(|v| ScalarValue::Int64(Some(i64::from(v)))),
61            DataType::UInt8 => {
62                val.and_then(|v| u8::try_from(v).map(|v| ScalarValue::UInt8(Some(v))).ok())
63            }
64            DataType::UInt16 => {
65                val.and_then(|v| u16::try_from(v).map(|v| ScalarValue::UInt16(Some(v))).ok())
66            }
67            DataType::UInt32 => {
68                val.and_then(|v| u32::try_from(v).map(|v| ScalarValue::UInt32(Some(v))).ok())
69            }
70            DataType::UInt64 => {
71                val.and_then(|v| u64::try_from(v).map(|v| ScalarValue::UInt64(Some(v))).ok())
72            }
73            DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(f32::from(v)))),
74            DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(f64::from(v)))),
75            _ => None,
76        },
77        ScalarValue::Int32(val) => match ty {
78            DataType::Int8 => {
79                val.and_then(|v| i8::try_from(v).map(|v| ScalarValue::Int8(Some(v))).ok())
80            }
81            DataType::Int16 => {
82                val.and_then(|v| i16::try_from(v).map(|v| ScalarValue::Int16(Some(v))).ok())
83            }
84            DataType::Int32 => Some(value.clone()),
85            DataType::Int64 => val.map(|v| ScalarValue::Int64(Some(i64::from(v)))),
86            DataType::UInt8 => {
87                val.and_then(|v| u8::try_from(v).map(|v| ScalarValue::UInt8(Some(v))).ok())
88            }
89            DataType::UInt16 => {
90                val.and_then(|v| u16::try_from(v).map(|v| ScalarValue::UInt16(Some(v))).ok())
91            }
92            DataType::UInt32 => {
93                val.and_then(|v| u32::try_from(v).map(|v| ScalarValue::UInt32(Some(v))).ok())
94            }
95            DataType::UInt64 => {
96                val.and_then(|v| u64::try_from(v).map(|v| ScalarValue::UInt64(Some(v))).ok())
97            }
98            // These conversions are inherently lossy as the full range of i32 cannot
99            // be represented in f32.  However, there is no f32::TryFrom(i32) and its not
100            // clear users would want that anyways
101            DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(v as f32))),
102            DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(v as f64))),
103            _ => None,
104        },
105        ScalarValue::Int64(val) => match ty {
106            DataType::Int8 => {
107                val.and_then(|v| i8::try_from(v).map(|v| ScalarValue::Int8(Some(v))).ok())
108            }
109            DataType::Int16 => {
110                val.and_then(|v| i16::try_from(v).map(|v| ScalarValue::Int16(Some(v))).ok())
111            }
112            DataType::Int32 => {
113                val.and_then(|v| i32::try_from(v).map(|v| ScalarValue::Int32(Some(v))).ok())
114            }
115            DataType::Int64 => Some(value.clone()),
116            DataType::UInt8 => {
117                val.and_then(|v| u8::try_from(v).map(|v| ScalarValue::UInt8(Some(v))).ok())
118            }
119            DataType::UInt16 => {
120                val.and_then(|v| u16::try_from(v).map(|v| ScalarValue::UInt16(Some(v))).ok())
121            }
122            DataType::UInt32 => {
123                val.and_then(|v| u32::try_from(v).map(|v| ScalarValue::UInt32(Some(v))).ok())
124            }
125            DataType::UInt64 => {
126                val.and_then(|v| u64::try_from(v).map(|v| ScalarValue::UInt64(Some(v))).ok())
127            }
128            // See above warning about lossy float conversion
129            DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(v as f32))),
130            DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(v as f64))),
131            DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => value.cast_to(ty).ok(),
132            _ => None,
133        },
134        ScalarValue::UInt8(val) => match ty {
135            DataType::Int8 => {
136                val.and_then(|v| i8::try_from(v).map(|v| ScalarValue::Int8(Some(v))).ok())
137            }
138            DataType::Int16 => val.map(|v| ScalarValue::Int16(Some(v.into()))),
139            DataType::Int32 => val.map(|v| ScalarValue::Int32(Some(v.into()))),
140            DataType::Int64 => val.map(|v| ScalarValue::Int64(Some(v.into()))),
141            DataType::UInt8 => Some(value.clone()),
142            DataType::UInt16 => val.map(|v| ScalarValue::UInt16(Some(u16::from(v)))),
143            DataType::UInt32 => val.map(|v| ScalarValue::UInt32(Some(u32::from(v)))),
144            DataType::UInt64 => val.map(|v| ScalarValue::UInt64(Some(u64::from(v)))),
145            DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(f32::from(v)))),
146            DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(f64::from(v)))),
147            _ => None,
148        },
149        ScalarValue::UInt16(val) => match ty {
150            DataType::Int8 => {
151                val.and_then(|v| i8::try_from(v).map(|v| ScalarValue::Int8(Some(v))).ok())
152            }
153            DataType::Int16 => {
154                val.and_then(|v| i16::try_from(v).map(|v| ScalarValue::Int16(Some(v))).ok())
155            }
156            DataType::Int32 => val.map(|v| ScalarValue::Int32(Some(v.into()))),
157            DataType::Int64 => val.map(|v| ScalarValue::Int64(Some(v.into()))),
158            DataType::UInt8 => {
159                val.and_then(|v| u8::try_from(v).map(|v| ScalarValue::UInt8(Some(v))).ok())
160            }
161            DataType::UInt16 => Some(value.clone()),
162            DataType::UInt32 => val.map(|v| ScalarValue::UInt32(Some(u32::from(v)))),
163            DataType::UInt64 => val.map(|v| ScalarValue::UInt64(Some(u64::from(v)))),
164            DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(f32::from(v)))),
165            DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(f64::from(v)))),
166            _ => None,
167        },
168        ScalarValue::UInt32(val) => match ty {
169            DataType::Int8 => {
170                val.and_then(|v| i8::try_from(v).map(|v| ScalarValue::Int8(Some(v))).ok())
171            }
172            DataType::Int16 => {
173                val.and_then(|v| i16::try_from(v).map(|v| ScalarValue::Int16(Some(v))).ok())
174            }
175            DataType::Int32 => {
176                val.and_then(|v| i32::try_from(v).map(|v| ScalarValue::Int32(Some(v))).ok())
177            }
178            DataType::Int64 => val.map(|v| ScalarValue::Int64(Some(v.into()))),
179            DataType::UInt8 => {
180                val.and_then(|v| u8::try_from(v).map(|v| ScalarValue::UInt8(Some(v))).ok())
181            }
182            DataType::UInt16 => {
183                val.and_then(|v| u16::try_from(v).map(|v| ScalarValue::UInt16(Some(v))).ok())
184            }
185            DataType::UInt32 => Some(value.clone()),
186            DataType::UInt64 => val.map(|v| ScalarValue::UInt64(Some(u64::from(v)))),
187            // See above warning about lossy float conversion
188            DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(v as f32))),
189            DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(v as f64))),
190            _ => None,
191        },
192        ScalarValue::UInt64(val) => match ty {
193            DataType::Int8 => {
194                val.and_then(|v| i8::try_from(v).map(|v| ScalarValue::Int8(Some(v))).ok())
195            }
196            DataType::Int16 => {
197                val.and_then(|v| i16::try_from(v).map(|v| ScalarValue::Int16(Some(v))).ok())
198            }
199            DataType::Int32 => {
200                val.and_then(|v| i32::try_from(v).map(|v| ScalarValue::Int32(Some(v))).ok())
201            }
202            DataType::Int64 => {
203                val.and_then(|v| i64::try_from(v).map(|v| ScalarValue::Int64(Some(v))).ok())
204            }
205            DataType::UInt8 => {
206                val.and_then(|v| u8::try_from(v).map(|v| ScalarValue::UInt8(Some(v))).ok())
207            }
208            DataType::UInt16 => {
209                val.and_then(|v| u16::try_from(v).map(|v| ScalarValue::UInt16(Some(v))).ok())
210            }
211            DataType::UInt32 => {
212                val.and_then(|v| u32::try_from(v).map(|v| ScalarValue::UInt32(Some(v))).ok())
213            }
214            DataType::UInt64 => Some(value.clone()),
215            // See above warning about lossy float conversion
216            DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(v as f32))),
217            DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(v as f64))),
218            _ => None,
219        },
220        ScalarValue::Float32(val) => match ty {
221            DataType::Float32 => Some(value.clone()),
222            DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(f64::from(v)))),
223            _ => None,
224        },
225        ScalarValue::Float64(val) => match ty {
226            DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(v as f32))),
227            DataType::Float64 => Some(value.clone()),
228            _ => None,
229        },
230        ScalarValue::Utf8(val) => match ty {
231            DataType::Utf8 => Some(value.clone()),
232            DataType::LargeUtf8 => Some(ScalarValue::LargeUtf8(val.clone())),
233            DataType::Utf8View => Some(ScalarValue::Utf8View(val.clone())),
234            _ => None,
235        },
236        ScalarValue::LargeUtf8(val) => match ty {
237            DataType::Utf8 => Some(ScalarValue::Utf8(val.clone())),
238            DataType::LargeUtf8 => Some(value.clone()),
239            DataType::Utf8View => Some(ScalarValue::Utf8View(val.clone())),
240            _ => None,
241        },
242        ScalarValue::Utf8View(val) => match ty {
243            DataType::Utf8 => Some(ScalarValue::Utf8(val.clone())),
244            DataType::LargeUtf8 => Some(ScalarValue::LargeUtf8(val.clone())),
245            DataType::Utf8View => Some(value.clone()),
246            _ => None,
247        },
248        ScalarValue::Boolean(_) => match ty {
249            DataType::Boolean => Some(value.clone()),
250            _ => None,
251        },
252        ScalarValue::Null => Some(value.clone()),
253        ScalarValue::List(values) => {
254            let values = values.clone() as ArrayRef;
255            let new_values = cast(&values, ty).ok()?;
256            match ty {
257                DataType::List(_) => {
258                    Some(ScalarValue::List(Arc::new(new_values.as_list().clone())))
259                }
260                DataType::LargeList(_) => Some(ScalarValue::LargeList(Arc::new(
261                    new_values.as_list().clone(),
262                ))),
263                DataType::FixedSizeList(_, _) => Some(ScalarValue::FixedSizeList(Arc::new(
264                    new_values.as_fixed_size_list().clone(),
265                ))),
266                _ => None,
267            }
268        }
269        ScalarValue::TimestampSecond(seconds, _) => match ty {
270            DataType::Timestamp(TimeUnit::Second, _) => Some(value.clone()),
271            DataType::Timestamp(TimeUnit::Millisecond, tz) => seconds
272                .and_then(|v| v.checked_mul(1000))
273                .map(|val| ScalarValue::TimestampMillisecond(Some(val), tz.clone())),
274            DataType::Timestamp(TimeUnit::Microsecond, tz) => seconds
275                .and_then(|v| v.checked_mul(1000000))
276                .map(|val| ScalarValue::TimestampMicrosecond(Some(val), tz.clone())),
277            DataType::Timestamp(TimeUnit::Nanosecond, tz) => seconds
278                .and_then(|v| v.checked_mul(1000000000))
279                .map(|val| ScalarValue::TimestampNanosecond(Some(val), tz.clone())),
280            _ => None,
281        },
282        ScalarValue::TimestampMillisecond(millis, _) => match ty {
283            DataType::Timestamp(TimeUnit::Second, tz) => {
284                millis.map(|val| ScalarValue::TimestampSecond(Some(val / 1000), tz.clone()))
285            }
286            DataType::Timestamp(TimeUnit::Millisecond, _) => Some(value.clone()),
287            DataType::Timestamp(TimeUnit::Microsecond, tz) => millis
288                .and_then(|v| v.checked_mul(1000))
289                .map(|val| ScalarValue::TimestampMicrosecond(Some(val), tz.clone())),
290            DataType::Timestamp(TimeUnit::Nanosecond, tz) => millis
291                .and_then(|v| v.checked_mul(1000000))
292                .map(|val| ScalarValue::TimestampNanosecond(Some(val), tz.clone())),
293            _ => None,
294        },
295        ScalarValue::TimestampMicrosecond(micros, _) => match ty {
296            DataType::Timestamp(TimeUnit::Second, tz) => {
297                micros.map(|val| ScalarValue::TimestampSecond(Some(val / 1000000), tz.clone()))
298            }
299            DataType::Timestamp(TimeUnit::Millisecond, tz) => {
300                micros.map(|val| ScalarValue::TimestampMillisecond(Some(val / 1000), tz.clone()))
301            }
302            DataType::Timestamp(TimeUnit::Microsecond, _) => Some(value.clone()),
303            DataType::Timestamp(TimeUnit::Nanosecond, tz) => micros
304                .and_then(|v| v.checked_mul(1000))
305                .map(|val| ScalarValue::TimestampNanosecond(Some(val), tz.clone())),
306            _ => None,
307        },
308        ScalarValue::TimestampNanosecond(nanos, _) => {
309            match ty {
310                DataType::Timestamp(TimeUnit::Second, tz) => nanos
311                    .map(|val| ScalarValue::TimestampSecond(Some(val / 1000000000), tz.clone())),
312                DataType::Timestamp(TimeUnit::Millisecond, tz) => nanos
313                    .map(|val| ScalarValue::TimestampMillisecond(Some(val / 1000000), tz.clone())),
314                DataType::Timestamp(TimeUnit::Microsecond, tz) => {
315                    nanos.map(|val| ScalarValue::TimestampMicrosecond(Some(val / 1000), tz.clone()))
316                }
317                DataType::Timestamp(TimeUnit::Nanosecond, _) => Some(value.clone()),
318                _ => None,
319            }
320        }
321        ScalarValue::Date32(ticks) => match ty {
322            DataType::Date32 => Some(value.clone()),
323            DataType::Date64 => Some(ScalarValue::Date64(
324                ticks.map(|v| i64::from(v) * MS_PER_DAY),
325            )),
326            _ => None,
327        },
328        ScalarValue::Date64(ticks) => match ty {
329            DataType::Date32 => Some(ScalarValue::Date32(ticks.map(|v| (v / MS_PER_DAY) as i32))),
330            DataType::Date64 => Some(value.clone()),
331            _ => None,
332        },
333        ScalarValue::Time32Second(seconds) => {
334            match ty {
335                DataType::Time32(TimeUnit::Second) => Some(value.clone()),
336                DataType::Time32(TimeUnit::Millisecond) => {
337                    seconds.map(|val| ScalarValue::Time32Millisecond(Some(val * 1000)))
338                }
339                DataType::Time64(TimeUnit::Microsecond) => seconds
340                    .map(|val| ScalarValue::Time64Microsecond(Some(i64::from(val) * 1000000))),
341                DataType::Time64(TimeUnit::Nanosecond) => seconds
342                    .map(|val| ScalarValue::Time64Nanosecond(Some(i64::from(val) * 1000000000))),
343                _ => None,
344            }
345        }
346        ScalarValue::Time32Millisecond(millis) => match ty {
347            DataType::Time32(TimeUnit::Second) => {
348                millis.map(|val| ScalarValue::Time32Second(Some(val / 1000)))
349            }
350            DataType::Time32(TimeUnit::Millisecond) => Some(value.clone()),
351            DataType::Time64(TimeUnit::Microsecond) => {
352                millis.map(|val| ScalarValue::Time64Microsecond(Some(i64::from(val) * 1000)))
353            }
354            DataType::Time64(TimeUnit::Nanosecond) => {
355                millis.map(|val| ScalarValue::Time64Nanosecond(Some(i64::from(val) * 1000000)))
356            }
357            _ => None,
358        },
359        ScalarValue::Time64Microsecond(micros) => match ty {
360            DataType::Time32(TimeUnit::Second) => {
361                micros.map(|val| ScalarValue::Time32Second(Some((val / 1000000) as i32)))
362            }
363            DataType::Time32(TimeUnit::Millisecond) => {
364                micros.map(|val| ScalarValue::Time32Millisecond(Some((val / 1000) as i32)))
365            }
366            DataType::Time64(TimeUnit::Microsecond) => Some(value.clone()),
367            DataType::Time64(TimeUnit::Nanosecond) => {
368                micros.map(|val| ScalarValue::Time64Nanosecond(Some(val * 1000)))
369            }
370            _ => None,
371        },
372        ScalarValue::Time64Nanosecond(nanos) => match ty {
373            DataType::Time32(TimeUnit::Second) => {
374                nanos.map(|val| ScalarValue::Time32Second(Some((val / 1000000000) as i32)))
375            }
376            DataType::Time32(TimeUnit::Millisecond) => {
377                nanos.map(|val| ScalarValue::Time32Millisecond(Some((val / 1000000) as i32)))
378            }
379            DataType::Time64(TimeUnit::Microsecond) => {
380                nanos.map(|val| ScalarValue::Time64Microsecond(Some(val / 1000)))
381            }
382            DataType::Time64(TimeUnit::Nanosecond) => Some(value.clone()),
383            _ => None,
384        },
385        ScalarValue::LargeList(values) => {
386            let values = values.clone() as ArrayRef;
387            let new_values = cast(&values, ty).ok()?;
388            match ty {
389                DataType::List(_) => {
390                    Some(ScalarValue::List(Arc::new(new_values.as_list().clone())))
391                }
392                DataType::LargeList(_) => Some(ScalarValue::LargeList(Arc::new(
393                    new_values.as_list().clone(),
394                ))),
395                DataType::FixedSizeList(_, _) => Some(ScalarValue::FixedSizeList(Arc::new(
396                    new_values.as_fixed_size_list().clone(),
397                ))),
398                _ => None,
399            }
400        }
401        ScalarValue::FixedSizeList(values) => {
402            let values = values.clone() as ArrayRef;
403            let new_values = cast(&values, ty).ok()?;
404            match ty {
405                DataType::List(_) => {
406                    Some(ScalarValue::List(Arc::new(new_values.as_list().clone())))
407                }
408                DataType::LargeList(_) => Some(ScalarValue::LargeList(Arc::new(
409                    new_values.as_list().clone(),
410                ))),
411                DataType::FixedSizeList(_, _) => Some(ScalarValue::FixedSizeList(Arc::new(
412                    new_values.as_fixed_size_list().clone(),
413                ))),
414                _ => None,
415            }
416        }
417        ScalarValue::FixedSizeBinary(len, value) => match ty {
418            DataType::FixedSizeBinary(len2) => {
419                if len == len2 {
420                    Some(ScalarValue::FixedSizeBinary(*len, value.clone()))
421                } else {
422                    None
423                }
424            }
425            DataType::Binary => Some(ScalarValue::Binary(value.clone())),
426            _ => None,
427        },
428        ScalarValue::Binary(value) => match ty {
429            DataType::Binary => Some(ScalarValue::Binary(value.clone())),
430            DataType::LargeBinary => Some(ScalarValue::LargeBinary(value.clone())),
431            DataType::BinaryView => Some(ScalarValue::BinaryView(value.clone())),
432            DataType::FixedSizeBinary(len) => {
433                if let Some(value) = value {
434                    if value.len() == *len as usize {
435                        Some(ScalarValue::FixedSizeBinary(*len, Some(value.clone())))
436                    } else {
437                        None
438                    }
439                } else {
440                    None
441                }
442            }
443            _ => None,
444        },
445        ScalarValue::BinaryView(val) => match ty {
446            DataType::Binary => Some(ScalarValue::Binary(val.clone())),
447            DataType::LargeBinary => Some(ScalarValue::LargeBinary(val.clone())),
448            DataType::BinaryView => Some(value.clone()),
449            _ => None,
450        },
451        // A dictionary-encoded literal (e.g. produced by DataFusion's dictionary
452        // cast in the scalar-index path) coerces by unwrapping its underlying value.
453        ScalarValue::Dictionary(_, inner) => safe_coerce_scalar(inner, ty),
454        _ => None,
455    }
456}
457
458#[cfg(test)]
459mod tests {
460    use super::*;
461
462    #[test]
463    fn test_temporal_coerce() {
464        // Conversion from timestamps in one resolution to timestamps in another resolution is allowed
465        // s->s
466        assert_eq!(
467            safe_coerce_scalar(
468                &ScalarValue::TimestampSecond(Some(5), None),
469                &DataType::Timestamp(TimeUnit::Second, None),
470            ),
471            Some(ScalarValue::TimestampSecond(Some(5), None))
472        );
473        // s->ms
474        assert_eq!(
475            safe_coerce_scalar(
476                &ScalarValue::TimestampSecond(Some(5), None),
477                &DataType::Timestamp(TimeUnit::Millisecond, None),
478            ),
479            Some(ScalarValue::TimestampMillisecond(Some(5000), None))
480        );
481        // s->us
482        assert_eq!(
483            safe_coerce_scalar(
484                &ScalarValue::TimestampSecond(Some(5), None),
485                &DataType::Timestamp(TimeUnit::Microsecond, None),
486            ),
487            Some(ScalarValue::TimestampMicrosecond(Some(5000000), None))
488        );
489        // s->ns
490        assert_eq!(
491            safe_coerce_scalar(
492                &ScalarValue::TimestampSecond(Some(5), None),
493                &DataType::Timestamp(TimeUnit::Nanosecond, None),
494            ),
495            Some(ScalarValue::TimestampNanosecond(Some(5000000000), None))
496        );
497        // ms->s
498        assert_eq!(
499            safe_coerce_scalar(
500                &ScalarValue::TimestampMillisecond(Some(5000), None),
501                &DataType::Timestamp(TimeUnit::Second, None),
502            ),
503            Some(ScalarValue::TimestampSecond(Some(5), None))
504        );
505        // ms->ms
506        assert_eq!(
507            safe_coerce_scalar(
508                &ScalarValue::TimestampMillisecond(Some(5000), None),
509                &DataType::Timestamp(TimeUnit::Millisecond, None),
510            ),
511            Some(ScalarValue::TimestampMillisecond(Some(5000), None))
512        );
513        // ms->us
514        assert_eq!(
515            safe_coerce_scalar(
516                &ScalarValue::TimestampMillisecond(Some(5000), None),
517                &DataType::Timestamp(TimeUnit::Microsecond, None),
518            ),
519            Some(ScalarValue::TimestampMicrosecond(Some(5000000), None))
520        );
521        // ms->ns
522        assert_eq!(
523            safe_coerce_scalar(
524                &ScalarValue::TimestampMillisecond(Some(5000), None),
525                &DataType::Timestamp(TimeUnit::Nanosecond, None),
526            ),
527            Some(ScalarValue::TimestampNanosecond(Some(5000000000), None))
528        );
529        // us->s
530        assert_eq!(
531            safe_coerce_scalar(
532                &ScalarValue::TimestampMicrosecond(Some(5000000), None),
533                &DataType::Timestamp(TimeUnit::Second, None),
534            ),
535            Some(ScalarValue::TimestampSecond(Some(5), None))
536        );
537        // us->ms
538        assert_eq!(
539            safe_coerce_scalar(
540                &ScalarValue::TimestampMicrosecond(Some(5000000), None),
541                &DataType::Timestamp(TimeUnit::Millisecond, None),
542            ),
543            Some(ScalarValue::TimestampMillisecond(Some(5000), None))
544        );
545        // us->us
546        assert_eq!(
547            safe_coerce_scalar(
548                &ScalarValue::TimestampMicrosecond(Some(5000000), None),
549                &DataType::Timestamp(TimeUnit::Microsecond, None),
550            ),
551            Some(ScalarValue::TimestampMicrosecond(Some(5000000), None))
552        );
553        // us->ns
554        assert_eq!(
555            safe_coerce_scalar(
556                &ScalarValue::TimestampMicrosecond(Some(5000000), None),
557                &DataType::Timestamp(TimeUnit::Nanosecond, None),
558            ),
559            Some(ScalarValue::TimestampNanosecond(Some(5000000000), None))
560        );
561        // ns->s
562        assert_eq!(
563            safe_coerce_scalar(
564                &ScalarValue::TimestampNanosecond(Some(5000000000), None),
565                &DataType::Timestamp(TimeUnit::Second, None),
566            ),
567            Some(ScalarValue::TimestampSecond(Some(5), None))
568        );
569        // ns->ms
570        assert_eq!(
571            safe_coerce_scalar(
572                &ScalarValue::TimestampNanosecond(Some(5000000000), None),
573                &DataType::Timestamp(TimeUnit::Millisecond, None),
574            ),
575            Some(ScalarValue::TimestampMillisecond(Some(5000), None))
576        );
577        // ns->us
578        assert_eq!(
579            safe_coerce_scalar(
580                &ScalarValue::TimestampNanosecond(Some(5000000000), None),
581                &DataType::Timestamp(TimeUnit::Microsecond, None),
582            ),
583            Some(ScalarValue::TimestampMicrosecond(Some(5000000), None))
584        );
585        // ns->ns
586        assert_eq!(
587            safe_coerce_scalar(
588                &ScalarValue::TimestampNanosecond(Some(5000000000), None),
589                &DataType::Timestamp(TimeUnit::Nanosecond, None),
590            ),
591            Some(ScalarValue::TimestampNanosecond(Some(5000000000), None))
592        );
593        // Precision loss on coercion is allowed (truncation)
594        // ns->s
595        assert_eq!(
596            safe_coerce_scalar(
597                &ScalarValue::TimestampNanosecond(Some(5987654321), None),
598                &DataType::Timestamp(TimeUnit::Second, None),
599            ),
600            Some(ScalarValue::TimestampSecond(Some(5), None))
601        );
602        // Conversions from date-32 to date-64 is allowed
603        assert_eq!(
604            safe_coerce_scalar(&ScalarValue::Date32(Some(5)), &DataType::Date32,),
605            Some(ScalarValue::Date32(Some(5)))
606        );
607        assert_eq!(
608            safe_coerce_scalar(&ScalarValue::Date32(Some(5)), &DataType::Date64,),
609            Some(ScalarValue::Date64(Some(5 * MS_PER_DAY)))
610        );
611        assert_eq!(
612            safe_coerce_scalar(
613                &ScalarValue::Date64(Some(5 * MS_PER_DAY)),
614                &DataType::Date32,
615            ),
616            Some(ScalarValue::Date32(Some(5)))
617        );
618        assert_eq!(
619            safe_coerce_scalar(&ScalarValue::Date64(Some(5)), &DataType::Date64,),
620            Some(ScalarValue::Date64(Some(5)))
621        );
622        // Time-32 to time-64 (and within time-32 and time-64) is allowed
623        assert_eq!(
624            safe_coerce_scalar(
625                &ScalarValue::Time32Second(Some(5)),
626                &DataType::Time32(TimeUnit::Second),
627            ),
628            Some(ScalarValue::Time32Second(Some(5)))
629        );
630        assert_eq!(
631            safe_coerce_scalar(
632                &ScalarValue::Time32Second(Some(5)),
633                &DataType::Time32(TimeUnit::Millisecond),
634            ),
635            Some(ScalarValue::Time32Millisecond(Some(5000)))
636        );
637        assert_eq!(
638            safe_coerce_scalar(
639                &ScalarValue::Time32Second(Some(5)),
640                &DataType::Time64(TimeUnit::Microsecond),
641            ),
642            Some(ScalarValue::Time64Microsecond(Some(5000000)))
643        );
644        assert_eq!(
645            safe_coerce_scalar(
646                &ScalarValue::Time32Second(Some(5)),
647                &DataType::Time64(TimeUnit::Nanosecond),
648            ),
649            Some(ScalarValue::Time64Nanosecond(Some(5000000000)))
650        );
651        assert_eq!(
652            safe_coerce_scalar(
653                &ScalarValue::Time32Millisecond(Some(5000)),
654                &DataType::Time32(TimeUnit::Second),
655            ),
656            Some(ScalarValue::Time32Second(Some(5)))
657        );
658        assert_eq!(
659            safe_coerce_scalar(
660                &ScalarValue::Time32Millisecond(Some(5000)),
661                &DataType::Time32(TimeUnit::Millisecond),
662            ),
663            Some(ScalarValue::Time32Millisecond(Some(5000)))
664        );
665        assert_eq!(
666            safe_coerce_scalar(
667                &ScalarValue::Time32Millisecond(Some(5000)),
668                &DataType::Time64(TimeUnit::Microsecond),
669            ),
670            Some(ScalarValue::Time64Microsecond(Some(5000000)))
671        );
672        assert_eq!(
673            safe_coerce_scalar(
674                &ScalarValue::Time32Millisecond(Some(5000)),
675                &DataType::Time64(TimeUnit::Nanosecond),
676            ),
677            Some(ScalarValue::Time64Nanosecond(Some(5000000000)))
678        );
679        assert_eq!(
680            safe_coerce_scalar(
681                &ScalarValue::Time64Microsecond(Some(5000000)),
682                &DataType::Time32(TimeUnit::Second),
683            ),
684            Some(ScalarValue::Time32Second(Some(5)))
685        );
686        assert_eq!(
687            safe_coerce_scalar(
688                &ScalarValue::Time64Microsecond(Some(5000000)),
689                &DataType::Time32(TimeUnit::Millisecond),
690            ),
691            Some(ScalarValue::Time32Millisecond(Some(5000)))
692        );
693        assert_eq!(
694            safe_coerce_scalar(
695                &ScalarValue::Time64Microsecond(Some(5000000)),
696                &DataType::Time64(TimeUnit::Microsecond),
697            ),
698            Some(ScalarValue::Time64Microsecond(Some(5000000)))
699        );
700        assert_eq!(
701            safe_coerce_scalar(
702                &ScalarValue::Time64Microsecond(Some(5000000)),
703                &DataType::Time64(TimeUnit::Nanosecond),
704            ),
705            Some(ScalarValue::Time64Nanosecond(Some(5000000000)))
706        );
707        assert_eq!(
708            safe_coerce_scalar(
709                &ScalarValue::Time64Nanosecond(Some(5000000000)),
710                &DataType::Time32(TimeUnit::Second),
711            ),
712            Some(ScalarValue::Time32Second(Some(5)))
713        );
714        assert_eq!(
715            safe_coerce_scalar(
716                &ScalarValue::Time64Nanosecond(Some(5000000000)),
717                &DataType::Time32(TimeUnit::Millisecond),
718            ),
719            Some(ScalarValue::Time32Millisecond(Some(5000)))
720        );
721        assert_eq!(
722            safe_coerce_scalar(
723                &ScalarValue::Time64Nanosecond(Some(5000000000)),
724                &DataType::Time64(TimeUnit::Microsecond),
725            ),
726            Some(ScalarValue::Time64Microsecond(Some(5000000)))
727        );
728        assert_eq!(
729            safe_coerce_scalar(
730                &ScalarValue::Time64Nanosecond(Some(5000000000)),
731                &DataType::Time64(TimeUnit::Nanosecond),
732            ),
733            Some(ScalarValue::Time64Nanosecond(Some(5000000000)))
734        );
735    }
736
737    #[test]
738    fn test_string_view_coerce() {
739        // Utf8 <-> Utf8View
740        assert_eq!(
741            safe_coerce_scalar(&ScalarValue::Utf8(Some("hi".into())), &DataType::Utf8View),
742            Some(ScalarValue::Utf8View(Some("hi".into())))
743        );
744        assert_eq!(
745            safe_coerce_scalar(&ScalarValue::Utf8View(Some("hi".into())), &DataType::Utf8),
746            Some(ScalarValue::Utf8(Some("hi".into())))
747        );
748        assert_eq!(
749            safe_coerce_scalar(
750                &ScalarValue::Utf8View(Some("hi".into())),
751                &DataType::LargeUtf8
752            ),
753            Some(ScalarValue::LargeUtf8(Some("hi".into())))
754        );
755        assert_eq!(
756            safe_coerce_scalar(
757                &ScalarValue::LargeUtf8(Some("hi".into())),
758                &DataType::Utf8View
759            ),
760            Some(ScalarValue::Utf8View(Some("hi".into())))
761        );
762        // identity
763        assert_eq!(
764            safe_coerce_scalar(
765                &ScalarValue::Utf8View(Some("hi".into())),
766                &DataType::Utf8View
767            ),
768            Some(ScalarValue::Utf8View(Some("hi".into())))
769        );
770        // Binary <-> BinaryView
771        assert_eq!(
772            safe_coerce_scalar(
773                &ScalarValue::Binary(Some(vec![1, 2, 3])),
774                &DataType::BinaryView
775            ),
776            Some(ScalarValue::BinaryView(Some(vec![1, 2, 3])))
777        );
778        assert_eq!(
779            safe_coerce_scalar(
780                &ScalarValue::BinaryView(Some(vec![1, 2, 3])),
781                &DataType::Binary
782            ),
783            Some(ScalarValue::Binary(Some(vec![1, 2, 3])))
784        );
785        assert_eq!(
786            safe_coerce_scalar(
787                &ScalarValue::BinaryView(Some(vec![1, 2, 3])),
788                &DataType::BinaryView
789            ),
790            Some(ScalarValue::BinaryView(Some(vec![1, 2, 3])))
791        );
792    }
793
794    #[test]
795    fn test_dictionary_coerce() {
796        let dict_ty = DataType::Dictionary(Box::new(DataType::Int16), Box::new(DataType::Utf8));
797
798        // A string literal coerces to a dictionary target by wrapping the
799        // coerced value in a dictionary scalar.
800        assert_eq!(
801            safe_coerce_scalar(&ScalarValue::Utf8(Some("com".to_string())), &dict_ty),
802            Some(ScalarValue::Dictionary(
803                Box::new(DataType::Int16),
804                Box::new(ScalarValue::Utf8(Some("com".to_string()))),
805            ))
806        );
807
808        // The inner value is coerced through to the dictionary value type, so a
809        // LargeUtf8 literal lands as a Utf8 value inside the dictionary.
810        assert_eq!(
811            safe_coerce_scalar(&ScalarValue::LargeUtf8(Some("com".to_string())), &dict_ty),
812            Some(ScalarValue::Dictionary(
813                Box::new(DataType::Int16),
814                Box::new(ScalarValue::Utf8(Some("com".to_string()))),
815            ))
816        );
817
818        // A dictionary literal round-trips back to its value type.
819        assert_eq!(
820            safe_coerce_scalar(
821                &ScalarValue::Dictionary(
822                    Box::new(DataType::Int16),
823                    Box::new(ScalarValue::Utf8(Some("com".to_string()))),
824                ),
825                &DataType::Utf8,
826            ),
827            Some(ScalarValue::Utf8(Some("com".to_string())))
828        );
829
830        // A dictionary literal coerces to a dictionary target, adopting the
831        // target's key type.
832        assert_eq!(
833            safe_coerce_scalar(
834                &ScalarValue::Dictionary(
835                    Box::new(DataType::Int32),
836                    Box::new(ScalarValue::Utf8(Some("com".to_string()))),
837                ),
838                &dict_ty,
839            ),
840            Some(ScalarValue::Dictionary(
841                Box::new(DataType::Int16),
842                Box::new(ScalarValue::Utf8(Some("com".to_string()))),
843            ))
844        );
845
846        // An untyped null keeps its untyped form for a dictionary target, just
847        // like for every other target type.
848        assert_eq!(
849            safe_coerce_scalar(&ScalarValue::Null, &dict_ty),
850            Some(ScalarValue::Null)
851        );
852
853        // A *typed* null (e.g. an API-built `Utf8(None)` literal, or an IN value
854        // already typed as Utf8) is still wrapped in the dictionary type so it
855        // matches the dictionary column. Returning a bare `Utf8(None)` here would
856        // leave `resolve_value` with a literal whose type does not line up with
857        // the column, breaking planning/evaluation the same way non-null strings
858        // used to break.
859        assert_eq!(
860            safe_coerce_scalar(&ScalarValue::Utf8(None), &dict_ty),
861            Some(ScalarValue::Dictionary(
862                Box::new(DataType::Int16),
863                Box::new(ScalarValue::Utf8(None)),
864            ))
865        );
866
867        // The inner null is coerced through to the dictionary value type as well,
868        // so a LargeUtf8 typed null lands as a Utf8 null inside the dictionary.
869        assert_eq!(
870            safe_coerce_scalar(&ScalarValue::LargeUtf8(None), &dict_ty),
871            Some(ScalarValue::Dictionary(
872                Box::new(DataType::Int16),
873                Box::new(ScalarValue::Utf8(None)),
874            ))
875        );
876
877        // A value that cannot be coerced to the dictionary value type fails.
878        assert_eq!(
879            safe_coerce_scalar(
880                &ScalarValue::Utf8(Some("com".to_string())),
881                &DataType::Dictionary(Box::new(DataType::Int16), Box::new(DataType::Int32)),
882            ),
883            None
884        );
885    }
886}