1use arrow::datatypes::ArrowPrimitiveType;
2use std::ops::Bound;
3
/// A dynamically-typed literal value: a scalar, a null, or a struct of named
/// nested literals.
#[derive(Debug, Clone, PartialEq)]
pub enum Literal {
    /// The absence of a value.
    Null,
    /// Any integer. `i128` is wide enough to hold every `i64` and every `u64`.
    Integer(i128),
    /// Any floating-point number, stored at double precision.
    Float(f64),
    /// Owned UTF-8 text.
    String(String),
    /// A true/false value.
    Boolean(bool),
    /// An ordered list of `(field name, value)` pairs.
    Struct(Vec<(String, Box<Literal>)>),
}
18
19macro_rules! impl_from_for_literal {
20 ($variant:ident, $($t:ty),*) => {
21 $(
22 impl From<$t> for Literal {
23 fn from(v: $t) -> Self {
24 Literal::$variant(v.into())
25 }
26 }
27 )*
28 };
29}
30
31impl_from_for_literal!(Integer, i8, i16, i32, i64, i128, u8, u16, u32, u64);
32impl_from_for_literal!(Float, f32, f64);
33
34impl From<&str> for Literal {
35 fn from(v: &str) -> Self {
36 Literal::String(v.to_string())
37 }
38}
39
40impl From<bool> for Literal {
41 fn from(v: bool) -> Self {
42 Literal::Boolean(v)
43 }
44}
45
/// Reasons a [`Literal`] could not be converted to a requested native type.
#[derive(Debug, Clone, PartialEq)]
pub enum LiteralCastError {
    /// The literal's kind is not accepted by the conversion.
    TypeMismatch {
        /// Name of the kind the conversion wanted (e.g. `"integer"`).
        expected: &'static str,
        /// Name of the kind the literal actually had (e.g. `"float"`).
        got: &'static str,
    },
    /// An integer literal does not fit in the target type.
    OutOfRange {
        /// Name of the target type.
        target: &'static str,
        /// The integer that did not fit.
        value: i128,
    },
    /// A float literal does not fit in the floating-point target type.
    FloatOutOfRange {
        /// Name of the target type.
        target: &'static str,
        /// The float that did not fit.
        value: f64,
    },
}
59
60impl std::fmt::Display for LiteralCastError {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 match self {
63 LiteralCastError::TypeMismatch { expected, got } => {
64 write!(f, "expected {}, got {}", expected, got)
65 }
66 LiteralCastError::OutOfRange { target, value } => {
67 write!(f, "value {} out of range for {}", value, target)
68 }
69 LiteralCastError::FloatOutOfRange { target, value } => {
70 write!(f, "value {} out of range for {}", value, target)
71 }
72 }
73 }
74}
75
76impl std::error::Error for LiteralCastError {}
77
78pub trait FromLiteral: Sized {
80 fn from_literal(lit: &Literal) -> Result<Self, LiteralCastError>;
81}
82
83macro_rules! impl_from_literal_int {
84 ($($ty:ty),* $(,)?) => {
85 $(
86 impl FromLiteral for $ty {
87 fn from_literal(lit: &Literal) -> Result<Self, LiteralCastError> {
88 match lit {
89 Literal::Integer(i) => <$ty>::try_from(*i).map_err(|_| {
90 LiteralCastError::OutOfRange {
91 target: std::any::type_name::<$ty>(),
92 value: *i,
93 }
94 }),
95 Literal::Float(_) => Err(LiteralCastError::TypeMismatch {
96 expected: "integer",
97 got: "float",
98 }),
99 Literal::Boolean(_) => Err(LiteralCastError::TypeMismatch {
100 expected: "integer",
101 got: "boolean",
102 }),
103 Literal::String(_) => Err(LiteralCastError::TypeMismatch {
104 expected: "integer",
105 got: "string",
106 }),
107 Literal::Struct(_) => Err(LiteralCastError::TypeMismatch {
108 expected: "integer",
109 got: "struct",
110 }),
111 Literal::Null => Err(LiteralCastError::TypeMismatch {
112 expected: "integer",
113 got: "null",
114 }),
115 }
116 }
117 }
118 )*
119 };
120}
121
122impl_from_literal_int!(i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, usize);
123
124impl FromLiteral for f32 {
125 fn from_literal(lit: &Literal) -> Result<Self, LiteralCastError> {
126 let value = match lit {
127 Literal::Float(f) => *f,
128 Literal::Integer(i) => *i as f64,
129 Literal::Boolean(_) => {
130 return Err(LiteralCastError::TypeMismatch {
131 expected: "float",
132 got: "boolean",
133 });
134 }
135 Literal::String(_) => {
136 return Err(LiteralCastError::TypeMismatch {
137 expected: "float",
138 got: "string",
139 });
140 }
141 Literal::Struct(_) => {
142 return Err(LiteralCastError::TypeMismatch {
143 expected: "float",
144 got: "struct",
145 });
146 }
147 Literal::Null => {
148 return Err(LiteralCastError::TypeMismatch {
149 expected: "float",
150 got: "null",
151 });
152 }
153 };
154 let cast = value as f32;
155 if value.is_finite() && !cast.is_finite() {
156 return Err(LiteralCastError::FloatOutOfRange {
157 target: "f32",
158 value,
159 });
160 }
161 Ok(cast)
162 }
163}
164
165impl FromLiteral for bool {
166 fn from_literal(lit: &Literal) -> Result<Self, LiteralCastError> {
167 match lit {
168 Literal::Boolean(b) => Ok(*b),
169 Literal::Integer(i) => match *i {
170 0 => Ok(false),
171 1 => Ok(true),
172 value => Err(LiteralCastError::OutOfRange {
173 target: "bool",
174 value,
175 }),
176 },
177 Literal::Float(_) => Err(LiteralCastError::TypeMismatch {
178 expected: "bool",
179 got: "float",
180 }),
181 Literal::String(s) => {
182 let normalized = s.trim().to_ascii_lowercase();
183 match normalized.as_str() {
184 "true" | "t" | "1" => Ok(true),
185 "false" | "f" | "0" => Ok(false),
186 _ => Err(LiteralCastError::TypeMismatch {
187 expected: "bool",
188 got: "string",
189 }),
190 }
191 }
192 Literal::Struct(_) => Err(LiteralCastError::TypeMismatch {
193 expected: "bool",
194 got: "struct",
195 }),
196 Literal::Null => Err(LiteralCastError::TypeMismatch {
197 expected: "bool",
198 got: "null",
199 }),
200 }
201 }
202}
203
204impl FromLiteral for f64 {
205 fn from_literal(lit: &Literal) -> Result<Self, LiteralCastError> {
206 match lit {
207 Literal::Float(f) => Ok(*f),
208 Literal::Integer(i) => Ok(*i as f64),
209 Literal::Boolean(_) => Err(LiteralCastError::TypeMismatch {
210 expected: "float",
211 got: "boolean",
212 }),
213 Literal::String(_) => Err(LiteralCastError::TypeMismatch {
214 expected: "float",
215 got: "string",
216 }),
217 Literal::Struct(_) => Err(LiteralCastError::TypeMismatch {
218 expected: "float",
219 got: "struct",
220 }),
221 Literal::Null => Err(LiteralCastError::TypeMismatch {
222 expected: "float",
223 got: "null",
224 }),
225 }
226 }
227}
228
229fn literal_type_name(lit: &Literal) -> &'static str {
230 match lit {
231 Literal::Integer(_) => "integer",
232 Literal::Float(_) => "float",
233 Literal::String(_) => "string",
234 Literal::Boolean(_) => "boolean",
235 Literal::Null => "null",
236 Literal::Struct(_) => "struct",
237 }
238}
239
240pub fn literal_to_string(lit: &Literal) -> Result<String, LiteralCastError> {
242 match lit {
243 Literal::String(s) => Ok(s.clone()),
244 Literal::Null => Err(LiteralCastError::TypeMismatch {
245 expected: "string",
246 got: "null",
247 }),
248 _ => Err(LiteralCastError::TypeMismatch {
249 expected: "string",
250 got: literal_type_name(lit),
251 }),
252 }
253}
254
255pub fn literal_to_native<T>(lit: &Literal) -> Result<T, LiteralCastError>
257where
258 T: FromLiteral + Copy + 'static,
259{
260 T::from_literal(lit)
261}
262
263pub fn bound_to_native<T>(bound: &Bound<Literal>) -> Result<Bound<T::Native>, LiteralCastError>
270where
271 T: ArrowPrimitiveType,
272 T::Native: FromLiteral + Copy,
273{
274 Ok(match bound {
275 Bound::Unbounded => Bound::Unbounded,
276 Bound::Included(l) => Bound::Included(literal_to_native::<T::Native>(l)?),
277 Bound::Excluded(l) => Bound::Excluded(literal_to_native::<T::Native>(l)?),
278 })
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn boolean_literal_roundtrip() {
        let lit = Literal::from(true);
        assert_eq!(lit, Literal::Boolean(true));
        assert!(literal_to_native::<bool>(&lit).unwrap());
        assert!(!literal_to_native::<bool>(&Literal::Boolean(false)).unwrap());
    }

    #[test]
    fn boolean_literal_rejects_integer_cast() {
        let lit = Literal::Boolean(true);
        let err = literal_to_native::<i32>(&lit).unwrap_err();
        assert!(matches!(
            err,
            LiteralCastError::TypeMismatch {
                expected: "integer",
                got: "boolean",
            }
        ));
    }

    /// Integer targets are range-checked and never accept floats.
    #[test]
    fn integer_literal_is_range_checked() {
        assert_eq!(literal_to_native::<u8>(&Literal::from(255u8)), Ok(255));
        assert!(matches!(
            literal_to_native::<u8>(&Literal::Integer(256)),
            Err(LiteralCastError::OutOfRange { value: 256, .. })
        ));
        assert!(matches!(
            literal_to_native::<i32>(&Literal::Float(1.0)),
            Err(LiteralCastError::TypeMismatch {
                expected: "integer",
                got: "float",
            })
        ));
    }

    /// Narrowing to f32 rejects finite overflow but keeps infinities.
    #[test]
    fn f32_rejects_finite_overflow_but_keeps_infinity() {
        assert!(matches!(
            literal_to_native::<f32>(&Literal::Float(1e300)),
            Err(LiteralCastError::FloatOutOfRange { target: "f32", .. })
        ));
        assert_eq!(
            literal_to_native::<f32>(&Literal::Float(f64::INFINITY)),
            Ok(f32::INFINITY)
        );
        assert_eq!(literal_to_native::<f64>(&Literal::Integer(2)), Ok(2.0));
    }

    /// Strings are accepted as booleans only in the documented spellings.
    #[test]
    fn bool_accepts_common_string_spellings() {
        for s in ["true", " T ", "1"] {
            assert_eq!(literal_to_native::<bool>(&Literal::from(s)), Ok(true));
        }
        for s in ["FALSE", "f", "0"] {
            assert_eq!(literal_to_native::<bool>(&Literal::from(s)), Ok(false));
        }
        assert!(literal_to_native::<bool>(&Literal::from("yes")).is_err());
        assert!(matches!(
            literal_to_native::<bool>(&Literal::Integer(2)),
            Err(LiteralCastError::OutOfRange { target: "bool", value: 2 })
        ));
    }

    /// Only string literals can be extracted as text.
    #[test]
    fn literal_to_string_rejects_non_strings() {
        assert_eq!(literal_to_string(&Literal::from("abc")), Ok("abc".to_string()));
        assert_eq!(
            literal_to_string(&Literal::Null),
            Err(LiteralCastError::TypeMismatch {
                expected: "string",
                got: "null",
            })
        );
        assert_eq!(
            literal_to_string(&Literal::Integer(1)),
            Err(LiteralCastError::TypeMismatch {
                expected: "string",
                got: "integer",
            })
        );
    }
}