llkv_types/
literal.rs

1//! Untyped literal values plus helpers for converting them into native types.
2//!
3//! Literals capture query parameters before a table knows the concrete Arrow
4//! type of each column. Conversion helpers here defer type checking until the
5//! caller can perform schema-aware coercion.
6
7use std::ops::Bound;
8
9use arrow::array::{
10    ArrayRef, BooleanArray, Date32Array, Decimal128Array, Float32Array, Float64Array, Int8Array,
11    Int16Array, Int32Array, Int64Array, LargeStringArray, StringArray, StructArray, UInt8Array,
12    UInt16Array, UInt32Array, UInt64Array,
13};
14use arrow::datatypes::{ArrowPrimitiveType, DataType};
15
16use llkv_result::Error;
17use time::{Date, Month};
18
19use crate::decimal::DecimalValue;
20use crate::interval::IntervalValue;
21
22/// A literal value that has not yet been coerced into a specific native
23/// type. This allows for type inference to be deferred until the column
24/// type is known.
25#[derive(Debug, Clone, PartialEq)]
26pub enum Literal {
27    Null,
28    Int128(i128),
29    Float64(f64),
30    /// Decimal literal stored as scaled integer with fixed precision.
31    Decimal128(DecimalValue),
32    String(String),
33    Boolean(bool),
34    /// Date literal stored as days since the Unix epoch (1970-01-01).
35    Date32(i32),
36    /// Struct literal with field names and nested literals
37    Struct(Vec<(String, Box<Literal>)>),
38    /// Interval literal with mixed calendar and sub-day precision.
39    Interval(IntervalValue),
40    // Other types like Bytes, etc. can be added here.
41}
42
43macro_rules! impl_from_for_literal {
44    ($variant:ident, $($t:ty),*) => {
45        $(
46            impl From<$t> for Literal {
47                fn from(v: $t) -> Self {
48                    Literal::$variant(v.into())
49                }
50            }
51        )*
52    };
53}
54
55impl_from_for_literal!(Int128, i8, i16, i32, i64, i128, u8, u16, u32, u64);
56impl_from_for_literal!(Float64, f32, f64);
57impl_from_for_literal!(String, String);
58impl_from_for_literal!(Boolean, bool);
59impl_from_for_literal!(Decimal128, DecimalValue);
60impl_from_for_literal!(Interval, IntervalValue);
61
62impl From<&str> for Literal {
63    fn from(v: &str) -> Self {
64        Literal::String(v.to_string())
65    }
66}
67
68impl From<Vec<(String, Literal)>> for Literal {
69    fn from(fields: Vec<(String, Literal)>) -> Self {
70        let boxed_fields = fields
71            .into_iter()
72            .map(|(name, lit)| (name, Box::new(lit)))
73            .collect();
74        Literal::Struct(boxed_fields)
75    }
76}
77
78impl Literal {
79    /// Human-friendly rendering used in plan/debug output.
80    pub fn format_display(&self) -> String {
81        match self {
82            Literal::Int128(i) => i.to_string(),
83            Literal::Float64(f) => f.to_string(),
84            Literal::Decimal128(d) => d.to_string(),
85            Literal::Boolean(b) => b.to_string(),
86            Literal::String(s) => format!("\"{}\"", escape_string(s)),
87            Literal::Date32(days) => format!("DATE '{}'", format_date32(*days)),
88            Literal::Interval(interval) => format!(
89                "INTERVAL {{ months: {}, days: {}, nanos: {} }}",
90                interval.months, interval.days, interval.nanos
91            ),
92            Literal::Null => "NULL".to_string(),
93            Literal::Struct(fields) => {
94                let field_strs: Vec<_> = fields
95                    .iter()
96                    .map(|(name, lit)| format!("{}: {}", name, lit.format_display()))
97                    .collect();
98                format!("{{{}}}", field_strs.join(", "))
99            }
100        }
101    }
102}
103
104fn format_date32(days: i32) -> String {
105    let julian = match epoch_julian_day().checked_add(days) {
106        Some(value) => value,
107        None => return days.to_string(),
108    };
109
110    match Date::from_julian_day(julian) {
111        Ok(date) => {
112            let (year, month, day) = date.to_calendar_date();
113            let month_number = month as u8;
114            format!("{:04}-{:02}-{:02}", year, month_number, day)
115        }
116        Err(_) => days.to_string(),
117    }
118}
119
120fn epoch_julian_day() -> i32 {
121    Date::from_calendar_date(1970, Month::January, 1)
122        .expect("1970-01-01 is a valid date")
123        .to_julian_day()
124}
125
/// Escapes every character of `value` with `char::escape_default`, so quotes,
/// backslashes, control characters and non-ASCII text become escape sequences.
fn escape_string(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        escaped.extend(ch.escape_default());
    }
    escaped
}
129
/// Failure modes when converting a `Literal` into a concrete native type.
#[derive(Debug, Clone, PartialEq)]
pub enum LiteralCastError {
    /// The literal's kind cannot be coerced into the requested kind
    /// (for example a string literal where an integer was expected).
    TypeMismatch {
        /// Name of the kind the caller asked for.
        expected: &'static str,
        /// Name of the kind the literal actually holds.
        got: &'static str,
    },
    /// Integer value does not fit in the destination type.
    OutOfRange {
        /// Name of the destination type.
        target: &'static str,
        /// The offending value.
        value: i128,
    },
    /// Float value does not fit in the destination type.
    FloatOutOfRange {
        /// Name of the destination type.
        target: &'static str,
        /// The offending value.
        value: f64,
    },
}
143
144impl std::fmt::Display for LiteralCastError {
145    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146        match self {
147            LiteralCastError::TypeMismatch { expected, got } => {
148                write!(f, "expected {}, got {}", expected, got)
149            }
150            LiteralCastError::OutOfRange { target, value } => {
151                write!(f, "value {} out of range for {}", value, target)
152            }
153            LiteralCastError::FloatOutOfRange { target, value } => {
154                write!(f, "value {} out of range for {}", value, target)
155            }
156        }
157    }
158}
159
160impl std::error::Error for LiteralCastError {}
161
162/// Extension methods for working with `Literal`.
163pub trait LiteralExt {
164    fn type_name(&self) -> &'static str;
165    fn to_string_owned(&self) -> Result<String, LiteralCastError>;
166    fn to_native<T>(&self) -> Result<T, LiteralCastError>
167    where
168        T: FromLiteral + Copy + 'static;
169    fn from_array_ref(array: &ArrayRef, index: usize) -> llkv_result::Result<Literal>;
170    fn bound_to_native<T>(bound: &Bound<Literal>) -> Result<Bound<T::Native>, LiteralCastError>
171    where
172        T: ArrowPrimitiveType,
173        T::Native: FromLiteral + Copy;
174}
175
176impl LiteralExt for Literal {
177    fn type_name(&self) -> &'static str {
178        match self {
179            Literal::Int128(_) => "integer",
180            Literal::Float64(_) => "float",
181            Literal::Decimal128(_) => "decimal",
182            Literal::String(_) => "string",
183            Literal::Boolean(_) => "boolean",
184            Literal::Date32(_) => "date",
185            Literal::Null => "null",
186            Literal::Struct(_) => "struct",
187            Literal::Interval(_) => "interval",
188        }
189    }
190
191    fn to_string_owned(&self) -> Result<String, LiteralCastError> {
192        match self {
193            Literal::String(s) => Ok(s.clone()),
194            Literal::Null => Err(LiteralCastError::TypeMismatch {
195                expected: "string",
196                got: "null",
197            }),
198            Literal::Date32(_) => Err(LiteralCastError::TypeMismatch {
199                expected: "string",
200                got: "date",
201            }),
202            other => Err(LiteralCastError::TypeMismatch {
203                expected: "string",
204                got: other.type_name(),
205            }),
206        }
207    }
208
209    fn to_native<T>(&self) -> Result<T, LiteralCastError>
210    where
211        T: FromLiteral + Copy + 'static,
212    {
213        T::from_literal(self)
214    }
215
216    fn from_array_ref(array: &ArrayRef, index: usize) -> llkv_result::Result<Literal> {
217        if array.is_null(index) {
218            return Ok(Literal::Null);
219        }
220
221        match array.data_type() {
222            DataType::Int8 => {
223                let arr = array.as_any().downcast_ref::<Int8Array>().unwrap();
224                Ok(Literal::Int128(arr.value(index) as i128))
225            }
226            DataType::Int16 => {
227                let arr = array.as_any().downcast_ref::<Int16Array>().unwrap();
228                Ok(Literal::Int128(arr.value(index) as i128))
229            }
230            DataType::Int32 => {
231                let arr = array.as_any().downcast_ref::<Int32Array>().unwrap();
232                Ok(Literal::Int128(arr.value(index) as i128))
233            }
234            DataType::Int64 => {
235                let arr = array.as_any().downcast_ref::<Int64Array>().unwrap();
236                Ok(Literal::Int128(arr.value(index) as i128))
237            }
238            DataType::UInt8 => {
239                let arr = array.as_any().downcast_ref::<UInt8Array>().unwrap();
240                Ok(Literal::Int128(arr.value(index) as i128))
241            }
242            DataType::UInt16 => {
243                let arr = array.as_any().downcast_ref::<UInt16Array>().unwrap();
244                Ok(Literal::Int128(arr.value(index) as i128))
245            }
246            DataType::UInt32 => {
247                let arr = array.as_any().downcast_ref::<UInt32Array>().unwrap();
248                Ok(Literal::Int128(arr.value(index) as i128))
249            }
250            DataType::UInt64 => {
251                let arr = array.as_any().downcast_ref::<UInt64Array>().unwrap();
252                Ok(Literal::Int128(arr.value(index) as i128))
253            }
254            DataType::Float32 => {
255                let arr = array.as_any().downcast_ref::<Float32Array>().unwrap();
256                Ok(Literal::Float64(arr.value(index) as f64))
257            }
258            DataType::Float64 => {
259                let arr = array.as_any().downcast_ref::<Float64Array>().unwrap();
260                Ok(Literal::Float64(arr.value(index)))
261            }
262            DataType::Utf8 => {
263                let arr = array.as_any().downcast_ref::<StringArray>().unwrap();
264                Ok(Literal::String(arr.value(index).to_string()))
265            }
266            DataType::LargeUtf8 => {
267                let arr = array.as_any().downcast_ref::<LargeStringArray>().unwrap();
268                Ok(Literal::String(arr.value(index).to_string()))
269            }
270            DataType::Boolean => {
271                let arr = array.as_any().downcast_ref::<BooleanArray>().unwrap();
272                Ok(Literal::Boolean(arr.value(index)))
273            }
274            DataType::Date32 => {
275                let arr = array.as_any().downcast_ref::<Date32Array>().unwrap();
276                Ok(Literal::Date32(arr.value(index)))
277            }
278            DataType::Decimal128(_, scale) => {
279                let arr = array.as_any().downcast_ref::<Decimal128Array>().unwrap();
280                let val = arr.value(index);
281                let decimal = DecimalValue::new(val, *scale).map_err(|err| {
282                    Error::InvalidArgumentError(format!(
283                        "invalid decimal value for literal conversion: {err}"
284                    ))
285                })?;
286                Ok(Literal::Decimal128(decimal))
287            }
288            DataType::Struct(fields) => {
289                let struct_array =
290                    array
291                        .as_any()
292                        .downcast_ref::<StructArray>()
293                        .ok_or_else(|| {
294                            Error::InvalidArgumentError("failed to downcast struct array".into())
295                        })?;
296                let mut members = Vec::with_capacity(fields.len());
297                for (idx, field) in fields.iter().enumerate() {
298                    let child = struct_array.column(idx);
299                    let literal = Literal::from_array_ref(child, index)?;
300                    members.push((field.name().clone(), Box::new(literal)));
301                }
302                Ok(Literal::Struct(members))
303            }
304            other => Err(Error::InvalidArgumentError(format!(
305                "unsupported type for literal conversion: {other:?}"
306            ))),
307        }
308    }
309
310    fn bound_to_native<T>(bound: &Bound<Literal>) -> Result<Bound<T::Native>, LiteralCastError>
311    where
312        T: ArrowPrimitiveType,
313        T::Native: FromLiteral + Copy,
314    {
315        Ok(match bound {
316            Bound::Unbounded => Bound::Unbounded,
317            Bound::Included(l) => Bound::Included(T::Native::from_literal(l)?),
318            Bound::Excluded(l) => Bound::Excluded(T::Native::from_literal(l)?),
319        })
320    }
321}
322
323/// Helper trait implemented for primitive types that can be produced from a `Literal`.
324pub trait FromLiteral: Sized {
325    fn from_literal(lit: &Literal) -> Result<Self, LiteralCastError>;
326}
327
328macro_rules! impl_from_literal_int {
329    ($($ty:ty),* $(,)?) => {
330        $(
331            impl FromLiteral for $ty {
332                fn from_literal(lit: &Literal) -> Result<Self, LiteralCastError> {
333                    match lit {
334                        Literal::Int128(i) => <$ty>::try_from(*i).map_err(|_| {
335                            LiteralCastError::OutOfRange {
336                                target: std::any::type_name::<$ty>(),
337                                value: *i,
338                            }
339                        }),
340                        Literal::Float64(_) => Err(LiteralCastError::TypeMismatch {
341                            expected: "integer",
342                            got: "float",
343                        }),
344                        Literal::Boolean(_) => Err(LiteralCastError::TypeMismatch {
345                            expected: "integer",
346                            got: "boolean",
347                        }),
348                        Literal::String(_) => Err(LiteralCastError::TypeMismatch {
349                            expected: "integer",
350                            got: "string",
351                        }),
352                        Literal::Date32(_) => Err(LiteralCastError::TypeMismatch {
353                            expected: "integer",
354                            got: "date",
355                        }),
356                        Literal::Struct(_) => Err(LiteralCastError::TypeMismatch {
357                            expected: "integer",
358                            got: "struct",
359                        }),
360                        Literal::Interval(_) => Err(LiteralCastError::TypeMismatch {
361                            expected: "integer",
362                            got: "interval",
363                        }),
364                        Literal::Decimal128(decimal) => {
365                            if decimal.scale() == 0 {
366                                let raw = decimal.raw_value();
367                                <$ty>::try_from(raw).map_err(|_| LiteralCastError::OutOfRange {
368                                    target: std::any::type_name::<$ty>(),
369                                    value: raw,
370                                })
371                            } else {
372                                Err(LiteralCastError::TypeMismatch {
373                                    expected: "integer",
374                                    got: "decimal",
375                                })
376                            }
377                        }
378                        Literal::Null => Err(LiteralCastError::TypeMismatch {
379                            expected: "integer",
380                            got: "null",
381                        }),
382                    }
383                }
384            }
385        )*
386    };
387}
388
389impl_from_literal_int!(i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, usize);
390
391impl FromLiteral for f32 {
392    fn from_literal(lit: &Literal) -> Result<Self, LiteralCastError> {
393        let value = match lit {
394            Literal::Float64(f) => *f,
395            Literal::Int128(i) => *i as f64,
396            Literal::Decimal128(d) => d.to_f64(),
397            Literal::Boolean(_) => {
398                return Err(LiteralCastError::TypeMismatch {
399                    expected: "float",
400                    got: "boolean",
401                });
402            }
403            Literal::String(_) => {
404                return Err(LiteralCastError::TypeMismatch {
405                    expected: "float",
406                    got: "string",
407                });
408            }
409            Literal::Struct(_) => {
410                return Err(LiteralCastError::TypeMismatch {
411                    expected: "float",
412                    got: "struct",
413                });
414            }
415            Literal::Interval(_) => {
416                return Err(LiteralCastError::TypeMismatch {
417                    expected: "float",
418                    got: "interval",
419                });
420            }
421            Literal::Null => {
422                return Err(LiteralCastError::TypeMismatch {
423                    expected: "float",
424                    got: "null",
425                });
426            }
427            Literal::Date32(_) => {
428                return Err(LiteralCastError::TypeMismatch {
429                    expected: "float",
430                    got: "date",
431                });
432            }
433        };
434
435        let casted = value as f32;
436        if casted.is_finite() {
437            Ok(casted)
438        } else {
439            Err(LiteralCastError::FloatOutOfRange {
440                target: "f32",
441                value,
442            })
443        }
444    }
445}
446
447impl FromLiteral for f64 {
448    fn from_literal(lit: &Literal) -> Result<Self, LiteralCastError> {
449        match lit {
450            Literal::Float64(f) => Ok(*f),
451            Literal::Int128(i) => Ok(*i as f64),
452            Literal::Decimal128(d) => Ok(d.to_f64()),
453            Literal::Boolean(_) => Err(LiteralCastError::TypeMismatch {
454                expected: "float",
455                got: "boolean",
456            }),
457            Literal::String(_) => Err(LiteralCastError::TypeMismatch {
458                expected: "float",
459                got: "string",
460            }),
461            Literal::Struct(_) => Err(LiteralCastError::TypeMismatch {
462                expected: "float",
463                got: "struct",
464            }),
465            Literal::Interval(_) => Err(LiteralCastError::TypeMismatch {
466                expected: "float",
467                got: "interval",
468            }),
469            Literal::Null => Err(LiteralCastError::TypeMismatch {
470                expected: "float",
471                got: "null",
472            }),
473            Literal::Date32(_) => Err(LiteralCastError::TypeMismatch {
474                expected: "float",
475                got: "date",
476            }),
477        }
478    }
479}
480
481impl FromLiteral for bool {
482    fn from_literal(lit: &Literal) -> Result<Self, LiteralCastError> {
483        match lit {
484            Literal::Boolean(b) => Ok(*b),
485            Literal::Int128(i) => match *i {
486                0 => Ok(false),
487                1 => Ok(true),
488                value => Err(LiteralCastError::OutOfRange {
489                    target: "bool",
490                    value,
491                }),
492            },
493            Literal::Float64(_) => Err(LiteralCastError::TypeMismatch {
494                expected: "bool",
495                got: "float",
496            }),
497            Literal::String(s) => {
498                let normalized = s.trim().to_ascii_lowercase();
499                match normalized.as_str() {
500                    "true" | "t" | "1" => Ok(true),
501                    "false" | "f" | "0" => Ok(false),
502                    _ => Err(LiteralCastError::TypeMismatch {
503                        expected: "bool",
504                        got: "string",
505                    }),
506                }
507            }
508            Literal::Date32(_) => Err(LiteralCastError::TypeMismatch {
509                expected: "bool",
510                got: "date",
511            }),
512            Literal::Struct(_) => Err(LiteralCastError::TypeMismatch {
513                expected: "bool",
514                got: "struct",
515            }),
516            Literal::Interval(_) => Err(LiteralCastError::TypeMismatch {
517                expected: "bool",
518                got: "interval",
519            }),
520            Literal::Decimal128(_) => Err(LiteralCastError::TypeMismatch {
521                expected: "bool",
522                got: "decimal",
523            }),
524            Literal::Null => Err(LiteralCastError::TypeMismatch {
525                expected: "bool",
526                got: "null",
527            }),
528        }
529    }
530}
531
#[cfg(test)]
mod tests {
    use super::*;

    /// `bool` converts into `Literal::Boolean` and back without loss.
    #[test]
    fn boolean_literal_roundtrip() {
        let truthy = Literal::from(true);
        assert_eq!(truthy, Literal::Boolean(true));
        assert_eq!(truthy.to_native::<bool>(), Ok(true));

        let falsy = Literal::Boolean(false);
        assert_eq!(falsy.to_native::<bool>(), Ok(false));
    }

    /// A boolean literal must not silently coerce into an integer.
    #[test]
    fn boolean_literal_rejects_integer_cast() {
        let outcome = Literal::Boolean(true).to_native::<i32>();
        assert_eq!(
            outcome,
            Err(LiteralCastError::TypeMismatch {
                expected: "integer",
                got: "boolean",
            })
        );
    }
}
556}