1use std::ops::Bound;
8
9use arrow::array::{
10 ArrayRef, BooleanArray, Date32Array, Decimal128Array, Float32Array, Float64Array, Int8Array,
11 Int16Array, Int32Array, Int64Array, LargeStringArray, StringArray, StructArray, UInt8Array,
12 UInt16Array, UInt32Array, UInt64Array,
13};
14use arrow::datatypes::{ArrowPrimitiveType, DataType};
15
16use llkv_result::Error;
17use time::{Date, Month};
18
19use crate::decimal::DecimalValue;
20use crate::interval::IntervalValue;
21
/// A constant value: a scalar, a date, an interval, or a struct of named literals.
#[derive(Debug, Clone, PartialEq)]
pub enum Literal {
    /// The null value; rendered as `NULL` by `format_display`.
    Null,
    /// Integer literal. All built-in signed/unsigned integer widths up to 64 bits (and `i128`)
    /// are widened into this variant.
    Int128(i128),
    /// Floating-point literal; `f32` inputs are widened into `f64`.
    Float64(f64),
    /// Fixed-point decimal value.
    Decimal128(DecimalValue),
    /// Owned UTF-8 string.
    String(String),
    /// Boolean literal.
    Boolean(bool),
    /// Date stored as a signed day count relative to 1970-01-01.
    Date32(i32),
    /// Ordered list of named fields; each field value is boxed.
    Struct(Vec<(String, Box<Literal>)>),
    /// Interval value (exposes `months`, `days` and `nanos` components).
    Interval(IntervalValue),
}
42
43macro_rules! impl_from_for_literal {
44 ($variant:ident, $($t:ty),*) => {
45 $(
46 impl From<$t> for Literal {
47 fn from(v: $t) -> Self {
48 Literal::$variant(v.into())
49 }
50 }
51 )*
52 };
53}
54
55impl_from_for_literal!(Int128, i8, i16, i32, i64, i128, u8, u16, u32, u64);
56impl_from_for_literal!(Float64, f32, f64);
57impl_from_for_literal!(String, String);
58impl_from_for_literal!(Boolean, bool);
59impl_from_for_literal!(Decimal128, DecimalValue);
60impl_from_for_literal!(Interval, IntervalValue);
61
62impl From<&str> for Literal {
63 fn from(v: &str) -> Self {
64 Literal::String(v.to_string())
65 }
66}
67
68impl From<Vec<(String, Literal)>> for Literal {
69 fn from(fields: Vec<(String, Literal)>) -> Self {
70 let boxed_fields = fields
71 .into_iter()
72 .map(|(name, lit)| (name, Box::new(lit)))
73 .collect();
74 Literal::Struct(boxed_fields)
75 }
76}
77
78impl Literal {
79 pub fn format_display(&self) -> String {
81 match self {
82 Literal::Int128(i) => i.to_string(),
83 Literal::Float64(f) => f.to_string(),
84 Literal::Decimal128(d) => d.to_string(),
85 Literal::Boolean(b) => b.to_string(),
86 Literal::String(s) => format!("\"{}\"", escape_string(s)),
87 Literal::Date32(days) => format!("DATE '{}'", format_date32(*days)),
88 Literal::Interval(interval) => format!(
89 "INTERVAL {{ months: {}, days: {}, nanos: {} }}",
90 interval.months, interval.days, interval.nanos
91 ),
92 Literal::Null => "NULL".to_string(),
93 Literal::Struct(fields) => {
94 let field_strs: Vec<_> = fields
95 .iter()
96 .map(|(name, lit)| format!("{}: {}", name, lit.format_display()))
97 .collect();
98 format!("{{{}}}", field_strs.join(", "))
99 }
100 }
101 }
102}
103
104fn format_date32(days: i32) -> String {
105 let julian = match epoch_julian_day().checked_add(days) {
106 Some(value) => value,
107 None => return days.to_string(),
108 };
109
110 match Date::from_julian_day(julian) {
111 Ok(date) => {
112 let (year, month, day) = date.to_calendar_date();
113 let month_number = month as u8;
114 format!("{:04}-{:02}-{:02}", year, month_number, day)
115 }
116 Err(_) => days.to_string(),
117 }
118}
119
120fn epoch_julian_day() -> i32 {
121 Date::from_calendar_date(1970, Month::January, 1)
122 .expect("1970-01-01 is a valid date")
123 .to_julian_day()
124}
125
/// Escapes every character of `value` with `char::escape_default`.
///
/// Quotes, backslashes and control characters get backslash escapes, and anything outside
/// printable ASCII becomes `\u{...}`.
fn escape_string(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        escaped.extend(ch.escape_default());
    }
    escaped
}
129
/// Error produced when a [`Literal`] cannot be converted to a requested native type.
#[derive(Debug, Clone, PartialEq)]
pub enum LiteralCastError {
    /// The literal's kind (`got`) is not accepted where `expected` was required.
    TypeMismatch {
        expected: &'static str,
        got: &'static str,
    },
    /// An integer value does not fit in the `target` type.
    OutOfRange { target: &'static str, value: i128 },
    /// A floating-point value cannot be represented in the `target` type.
    FloatOutOfRange { target: &'static str, value: f64 },
}

impl std::fmt::Display for LiteralCastError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // All fields are `Copy`, so they can be bound by value straight out of `*self`.
        match *self {
            Self::TypeMismatch { expected, got } => write!(f, "expected {expected}, got {got}"),
            Self::OutOfRange { target, value } => {
                write!(f, "value {value} out of range for {target}")
            }
            Self::FloatOutOfRange { target, value } => {
                write!(f, "value {value} out of range for {target}")
            }
        }
    }
}

impl std::error::Error for LiteralCastError {}
161
/// Conversion and inspection helpers for [`Literal`] values.
pub trait LiteralExt {
    /// Short lowercase name of the literal's kind (e.g. `"integer"`, `"string"`, `"null"`),
    /// used as the `got` field of [`LiteralCastError::TypeMismatch`].
    fn type_name(&self) -> &'static str;
    /// Returns an owned copy of a string literal; as implemented for `Literal`, every other
    /// kind yields a `TypeMismatch` error.
    fn to_string_owned(&self) -> Result<String, LiteralCastError>;
    /// Converts the literal into a native value through `T`'s [`FromLiteral`] implementation.
    fn to_native<T>(&self) -> Result<T, LiteralCastError>
    where
        T: FromLiteral + Copy + 'static;
    /// Reads the element at `index` of an Arrow array as a [`Literal`]; null slots become
    /// [`Literal::Null`].
    fn from_array_ref(array: &ArrayRef, index: usize) -> llkv_result::Result<Literal>;
    /// Converts the endpoint of a range bound into the native type of Arrow primitive `T`,
    /// keeping the inclusive/exclusive/unbounded shape.
    fn bound_to_native<T>(bound: &Bound<Literal>) -> Result<Bound<T::Native>, LiteralCastError>
    where
        T: ArrowPrimitiveType,
        T::Native: FromLiteral + Copy;
}
175
176impl LiteralExt for Literal {
177 fn type_name(&self) -> &'static str {
178 match self {
179 Literal::Int128(_) => "integer",
180 Literal::Float64(_) => "float",
181 Literal::Decimal128(_) => "decimal",
182 Literal::String(_) => "string",
183 Literal::Boolean(_) => "boolean",
184 Literal::Date32(_) => "date",
185 Literal::Null => "null",
186 Literal::Struct(_) => "struct",
187 Literal::Interval(_) => "interval",
188 }
189 }
190
191 fn to_string_owned(&self) -> Result<String, LiteralCastError> {
192 match self {
193 Literal::String(s) => Ok(s.clone()),
194 Literal::Null => Err(LiteralCastError::TypeMismatch {
195 expected: "string",
196 got: "null",
197 }),
198 Literal::Date32(_) => Err(LiteralCastError::TypeMismatch {
199 expected: "string",
200 got: "date",
201 }),
202 other => Err(LiteralCastError::TypeMismatch {
203 expected: "string",
204 got: other.type_name(),
205 }),
206 }
207 }
208
209 fn to_native<T>(&self) -> Result<T, LiteralCastError>
210 where
211 T: FromLiteral + Copy + 'static,
212 {
213 T::from_literal(self)
214 }
215
216 fn from_array_ref(array: &ArrayRef, index: usize) -> llkv_result::Result<Literal> {
217 if array.is_null(index) {
218 return Ok(Literal::Null);
219 }
220
221 match array.data_type() {
222 DataType::Int8 => {
223 let arr = array.as_any().downcast_ref::<Int8Array>().unwrap();
224 Ok(Literal::Int128(arr.value(index) as i128))
225 }
226 DataType::Int16 => {
227 let arr = array.as_any().downcast_ref::<Int16Array>().unwrap();
228 Ok(Literal::Int128(arr.value(index) as i128))
229 }
230 DataType::Int32 => {
231 let arr = array.as_any().downcast_ref::<Int32Array>().unwrap();
232 Ok(Literal::Int128(arr.value(index) as i128))
233 }
234 DataType::Int64 => {
235 let arr = array.as_any().downcast_ref::<Int64Array>().unwrap();
236 Ok(Literal::Int128(arr.value(index) as i128))
237 }
238 DataType::UInt8 => {
239 let arr = array.as_any().downcast_ref::<UInt8Array>().unwrap();
240 Ok(Literal::Int128(arr.value(index) as i128))
241 }
242 DataType::UInt16 => {
243 let arr = array.as_any().downcast_ref::<UInt16Array>().unwrap();
244 Ok(Literal::Int128(arr.value(index) as i128))
245 }
246 DataType::UInt32 => {
247 let arr = array.as_any().downcast_ref::<UInt32Array>().unwrap();
248 Ok(Literal::Int128(arr.value(index) as i128))
249 }
250 DataType::UInt64 => {
251 let arr = array.as_any().downcast_ref::<UInt64Array>().unwrap();
252 Ok(Literal::Int128(arr.value(index) as i128))
253 }
254 DataType::Float32 => {
255 let arr = array.as_any().downcast_ref::<Float32Array>().unwrap();
256 Ok(Literal::Float64(arr.value(index) as f64))
257 }
258 DataType::Float64 => {
259 let arr = array.as_any().downcast_ref::<Float64Array>().unwrap();
260 Ok(Literal::Float64(arr.value(index)))
261 }
262 DataType::Utf8 => {
263 let arr = array.as_any().downcast_ref::<StringArray>().unwrap();
264 Ok(Literal::String(arr.value(index).to_string()))
265 }
266 DataType::LargeUtf8 => {
267 let arr = array.as_any().downcast_ref::<LargeStringArray>().unwrap();
268 Ok(Literal::String(arr.value(index).to_string()))
269 }
270 DataType::Boolean => {
271 let arr = array.as_any().downcast_ref::<BooleanArray>().unwrap();
272 Ok(Literal::Boolean(arr.value(index)))
273 }
274 DataType::Date32 => {
275 let arr = array.as_any().downcast_ref::<Date32Array>().unwrap();
276 Ok(Literal::Date32(arr.value(index)))
277 }
278 DataType::Decimal128(_, scale) => {
279 let arr = array.as_any().downcast_ref::<Decimal128Array>().unwrap();
280 let val = arr.value(index);
281 let decimal = DecimalValue::new(val, *scale).map_err(|err| {
282 Error::InvalidArgumentError(format!(
283 "invalid decimal value for literal conversion: {err}"
284 ))
285 })?;
286 Ok(Literal::Decimal128(decimal))
287 }
288 DataType::Struct(fields) => {
289 let struct_array =
290 array
291 .as_any()
292 .downcast_ref::<StructArray>()
293 .ok_or_else(|| {
294 Error::InvalidArgumentError("failed to downcast struct array".into())
295 })?;
296 let mut members = Vec::with_capacity(fields.len());
297 for (idx, field) in fields.iter().enumerate() {
298 let child = struct_array.column(idx);
299 let literal = Literal::from_array_ref(child, index)?;
300 members.push((field.name().clone(), Box::new(literal)));
301 }
302 Ok(Literal::Struct(members))
303 }
304 other => Err(Error::InvalidArgumentError(format!(
305 "unsupported type for literal conversion: {other:?}"
306 ))),
307 }
308 }
309
310 fn bound_to_native<T>(bound: &Bound<Literal>) -> Result<Bound<T::Native>, LiteralCastError>
311 where
312 T: ArrowPrimitiveType,
313 T::Native: FromLiteral + Copy,
314 {
315 Ok(match bound {
316 Bound::Unbounded => Bound::Unbounded,
317 Bound::Included(l) => Bound::Included(T::Native::from_literal(l)?),
318 Bound::Excluded(l) => Bound::Excluded(T::Native::from_literal(l)?),
319 })
320 }
321}
322
/// Fallible conversion from a [`Literal`] into a native Rust value.
pub trait FromLiteral: Sized {
    /// Converts `lit` into `Self`.
    ///
    /// Returns a [`LiteralCastError`] when the literal's kind is not accepted or its value does
    /// not fit in `Self`.
    fn from_literal(lit: &Literal) -> Result<Self, LiteralCastError>;
}
327
328macro_rules! impl_from_literal_int {
329 ($($ty:ty),* $(,)?) => {
330 $(
331 impl FromLiteral for $ty {
332 fn from_literal(lit: &Literal) -> Result<Self, LiteralCastError> {
333 match lit {
334 Literal::Int128(i) => <$ty>::try_from(*i).map_err(|_| {
335 LiteralCastError::OutOfRange {
336 target: std::any::type_name::<$ty>(),
337 value: *i,
338 }
339 }),
340 Literal::Float64(_) => Err(LiteralCastError::TypeMismatch {
341 expected: "integer",
342 got: "float",
343 }),
344 Literal::Boolean(_) => Err(LiteralCastError::TypeMismatch {
345 expected: "integer",
346 got: "boolean",
347 }),
348 Literal::String(_) => Err(LiteralCastError::TypeMismatch {
349 expected: "integer",
350 got: "string",
351 }),
352 Literal::Date32(_) => Err(LiteralCastError::TypeMismatch {
353 expected: "integer",
354 got: "date",
355 }),
356 Literal::Struct(_) => Err(LiteralCastError::TypeMismatch {
357 expected: "integer",
358 got: "struct",
359 }),
360 Literal::Interval(_) => Err(LiteralCastError::TypeMismatch {
361 expected: "integer",
362 got: "interval",
363 }),
364 Literal::Decimal128(decimal) => {
365 if decimal.scale() == 0 {
366 let raw = decimal.raw_value();
367 <$ty>::try_from(raw).map_err(|_| LiteralCastError::OutOfRange {
368 target: std::any::type_name::<$ty>(),
369 value: raw,
370 })
371 } else {
372 Err(LiteralCastError::TypeMismatch {
373 expected: "integer",
374 got: "decimal",
375 })
376 }
377 }
378 Literal::Null => Err(LiteralCastError::TypeMismatch {
379 expected: "integer",
380 got: "null",
381 }),
382 }
383 }
384 }
385 )*
386 };
387}
388
389impl_from_literal_int!(i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, usize);
390
391impl FromLiteral for f32 {
392 fn from_literal(lit: &Literal) -> Result<Self, LiteralCastError> {
393 let value = match lit {
394 Literal::Float64(f) => *f,
395 Literal::Int128(i) => *i as f64,
396 Literal::Decimal128(d) => d.to_f64(),
397 Literal::Boolean(_) => {
398 return Err(LiteralCastError::TypeMismatch {
399 expected: "float",
400 got: "boolean",
401 });
402 }
403 Literal::String(_) => {
404 return Err(LiteralCastError::TypeMismatch {
405 expected: "float",
406 got: "string",
407 });
408 }
409 Literal::Struct(_) => {
410 return Err(LiteralCastError::TypeMismatch {
411 expected: "float",
412 got: "struct",
413 });
414 }
415 Literal::Interval(_) => {
416 return Err(LiteralCastError::TypeMismatch {
417 expected: "float",
418 got: "interval",
419 });
420 }
421 Literal::Null => {
422 return Err(LiteralCastError::TypeMismatch {
423 expected: "float",
424 got: "null",
425 });
426 }
427 Literal::Date32(_) => {
428 return Err(LiteralCastError::TypeMismatch {
429 expected: "float",
430 got: "date",
431 });
432 }
433 };
434
435 let casted = value as f32;
436 if casted.is_finite() {
437 Ok(casted)
438 } else {
439 Err(LiteralCastError::FloatOutOfRange {
440 target: "f32",
441 value,
442 })
443 }
444 }
445}
446
447impl FromLiteral for f64 {
448 fn from_literal(lit: &Literal) -> Result<Self, LiteralCastError> {
449 match lit {
450 Literal::Float64(f) => Ok(*f),
451 Literal::Int128(i) => Ok(*i as f64),
452 Literal::Decimal128(d) => Ok(d.to_f64()),
453 Literal::Boolean(_) => Err(LiteralCastError::TypeMismatch {
454 expected: "float",
455 got: "boolean",
456 }),
457 Literal::String(_) => Err(LiteralCastError::TypeMismatch {
458 expected: "float",
459 got: "string",
460 }),
461 Literal::Struct(_) => Err(LiteralCastError::TypeMismatch {
462 expected: "float",
463 got: "struct",
464 }),
465 Literal::Interval(_) => Err(LiteralCastError::TypeMismatch {
466 expected: "float",
467 got: "interval",
468 }),
469 Literal::Null => Err(LiteralCastError::TypeMismatch {
470 expected: "float",
471 got: "null",
472 }),
473 Literal::Date32(_) => Err(LiteralCastError::TypeMismatch {
474 expected: "float",
475 got: "date",
476 }),
477 }
478 }
479}
480
481impl FromLiteral for bool {
482 fn from_literal(lit: &Literal) -> Result<Self, LiteralCastError> {
483 match lit {
484 Literal::Boolean(b) => Ok(*b),
485 Literal::Int128(i) => match *i {
486 0 => Ok(false),
487 1 => Ok(true),
488 value => Err(LiteralCastError::OutOfRange {
489 target: "bool",
490 value,
491 }),
492 },
493 Literal::Float64(_) => Err(LiteralCastError::TypeMismatch {
494 expected: "bool",
495 got: "float",
496 }),
497 Literal::String(s) => {
498 let normalized = s.trim().to_ascii_lowercase();
499 match normalized.as_str() {
500 "true" | "t" | "1" => Ok(true),
501 "false" | "f" | "0" => Ok(false),
502 _ => Err(LiteralCastError::TypeMismatch {
503 expected: "bool",
504 got: "string",
505 }),
506 }
507 }
508 Literal::Date32(_) => Err(LiteralCastError::TypeMismatch {
509 expected: "bool",
510 got: "date",
511 }),
512 Literal::Struct(_) => Err(LiteralCastError::TypeMismatch {
513 expected: "bool",
514 got: "struct",
515 }),
516 Literal::Interval(_) => Err(LiteralCastError::TypeMismatch {
517 expected: "bool",
518 got: "interval",
519 }),
520 Literal::Decimal128(_) => Err(LiteralCastError::TypeMismatch {
521 expected: "bool",
522 got: "decimal",
523 }),
524 Literal::Null => Err(LiteralCastError::TypeMismatch {
525 expected: "bool",
526 got: "null",
527 }),
528 }
529 }
530}
531
#[cfg(test)]
mod tests {
    use super::*;

    /// A `bool` converts into `Literal::Boolean` and back without loss.
    #[test]
    fn boolean_literal_roundtrip() {
        let literal: Literal = true.into();
        assert_eq!(literal, Literal::Boolean(true));
        assert_eq!(literal.to_native::<bool>(), Ok(true));
        assert_eq!(Literal::Boolean(false).to_native::<bool>(), Ok(false));
    }

    /// Boolean literals are not implicitly coerced to integers.
    #[test]
    fn boolean_literal_rejects_integer_cast() {
        let outcome = Literal::Boolean(true).to_native::<i32>();
        assert_eq!(
            outcome,
            Err(LiteralCastError::TypeMismatch {
                expected: "integer",
                got: "boolean",
            })
        );
    }
}