use arrow::datatypes::ArrowPrimitiveType;
use std::ops::Bound;
/// A dynamically-typed constant value.
///
/// Numeric payloads are stored at their widest width (`i128` / `f64`), so the
/// narrower Rust primitives can be converted in through `From` without loss.
#[derive(Clone, Debug, PartialEq)]
pub enum Literal {
    /// A null literal.
    Null,
    /// An integer, stored widened to `i128`.
    Integer(i128),
    /// A floating-point number, stored widened to `f64`.
    Float(f64),
    /// An owned string.
    String(String),
    /// A boolean.
    Boolean(bool),
    /// A sequence of `(field name, nested literal)` pairs, kept in order.
    Struct(Vec<(String, Box<Literal>)>),
}
macro_rules! impl_from_for_literal {
($variant:ident, $($t:ty),*) => {
$(
impl From<$t> for Literal {
fn from(v: $t) -> Self {
Literal::$variant(v.into())
}
}
)*
};
}
impl_from_for_literal!(Integer, i8, i16, i32, i64, i128, u8, u16, u32, u64);
impl_from_for_literal!(Float, f32, f64);
/// Copies a borrowed string slice into [`Literal::String`].
impl From<&str> for Literal {
    fn from(v: &str) -> Self {
        Literal::String(v.to_string())
    }
}

/// Moves an owned `String` into [`Literal::String`].
///
/// Without this impl, owned strings must be re-borrowed and copied through
/// `From<&str>`; taking ownership avoids that extra allocation.
impl From<String> for Literal {
    fn from(v: String) -> Self {
        Literal::String(v)
    }
}
impl From<bool> for Literal {
fn from(v: bool) -> Self {
Literal::Boolean(v)
}
}
/// Why a [`Literal`] could not be converted into a requested Rust type.
#[derive(Clone, Debug, PartialEq)]
pub enum LiteralCastError {
    /// The literal's variant cannot be converted to the requested kind.
    TypeMismatch {
        /// Name of the requested kind (e.g. `"integer"`).
        expected: &'static str,
        /// Name of the literal variant that was supplied (e.g. `"string"`).
        got: &'static str,
    },
    /// An integer value does not fit in the target type.
    OutOfRange {
        /// Name of the target type.
        target: &'static str,
        /// The offending integer.
        value: i128,
    },
    /// A finite float is too large in magnitude for the target float type.
    FloatOutOfRange {
        /// Name of the target type.
        target: &'static str,
        /// The offending float.
        value: f64,
    },
}
impl std::fmt::Display for LiteralCastError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
LiteralCastError::TypeMismatch { expected, got } => {
write!(f, "expected {}, got {}", expected, got)
}
LiteralCastError::OutOfRange { target, value } => {
write!(f, "value {} out of range for {}", value, target)
}
LiteralCastError::FloatOutOfRange { target, value } => {
write!(f, "value {} out of range for {}", value, target)
}
}
}
}
impl std::error::Error for LiteralCastError {}
/// Fallible conversion from a dynamically-typed [`Literal`] into `Self`.
pub trait FromLiteral: Sized {
    /// Attempts to convert `lit` into `Self`.
    ///
    /// # Errors
    ///
    /// Returns a [`LiteralCastError`] describing why `lit` could not be
    /// represented as `Self`.
    fn from_literal(lit: &Literal) -> Result<Self, LiteralCastError>;
}
macro_rules! impl_from_literal_int {
($($ty:ty),* $(,)?) => {
$(
impl FromLiteral for $ty {
fn from_literal(lit: &Literal) -> Result<Self, LiteralCastError> {
match lit {
Literal::Integer(i) => <$ty>::try_from(*i).map_err(|_| {
LiteralCastError::OutOfRange {
target: std::any::type_name::<$ty>(),
value: *i,
}
}),
Literal::Float(_) => Err(LiteralCastError::TypeMismatch {
expected: "integer",
got: "float",
}),
Literal::Boolean(_) => Err(LiteralCastError::TypeMismatch {
expected: "integer",
got: "boolean",
}),
Literal::String(_) => Err(LiteralCastError::TypeMismatch {
expected: "integer",
got: "string",
}),
Literal::Struct(_) => Err(LiteralCastError::TypeMismatch {
expected: "integer",
got: "struct",
}),
Literal::Null => Err(LiteralCastError::TypeMismatch {
expected: "integer",
got: "null",
}),
}
}
}
)*
};
}
impl_from_literal_int!(i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, usize);
impl FromLiteral for f32 {
    /// Converts through the `f64` conversion, then narrows to `f32`.
    ///
    /// Accepts the same variants as `f64` (`Float` and `Integer`) and reports
    /// the same `TypeMismatch` payloads for the rest.
    ///
    /// # Errors
    ///
    /// - Any `TypeMismatch` produced by the `f64` conversion.
    /// - [`LiteralCastError::FloatOutOfRange`] when a finite value becomes
    ///   infinite on narrowing (too large for `f32`). NaN and ±infinity are not
    ///   treated as overflow and pass through.
    fn from_literal(lit: &Literal) -> Result<Self, LiteralCastError> {
        // Delegating keeps the variant dispatch and error payloads in one
        // place (the f64 impl) instead of duplicating them here.
        let wide = <f64 as FromLiteral>::from_literal(lit)?;
        let narrow = wide as f32;
        // `as` saturates finite overflow to infinity rather than failing, so
        // detect that case explicitly.
        if wide.is_finite() && !narrow.is_finite() {
            return Err(LiteralCastError::FloatOutOfRange {
                target: "f32",
                value: wide,
            });
        }
        Ok(narrow)
    }
}
impl FromLiteral for bool {
    /// Accepts `Boolean` as-is, the integers `0`/`1`, and the strings
    /// `true`/`t`/`1` and `false`/`f`/`0` (trimmed, ASCII case-insensitive).
    ///
    /// # Errors
    ///
    /// Any other integer yields `OutOfRange { target: "bool", .. }`. An
    /// unrecognised string, `Float`, `Struct` or `Null` yields `TypeMismatch`.
    fn from_literal(lit: &Literal) -> Result<Self, LiteralCastError> {
        let mismatch = |got: &'static str| -> Result<Self, LiteralCastError> {
            Err(LiteralCastError::TypeMismatch { expected: "bool", got })
        };
        match lit {
            Literal::Boolean(flag) => Ok(*flag),
            Literal::Integer(0) => Ok(false),
            Literal::Integer(1) => Ok(true),
            Literal::Integer(other) => Err(LiteralCastError::OutOfRange {
                target: "bool",
                value: *other,
            }),
            Literal::String(text) => match text.trim().to_ascii_lowercase().as_str() {
                "true" | "t" | "1" => Ok(true),
                "false" | "f" | "0" => Ok(false),
                _ => mismatch("string"),
            },
            Literal::Float(_) => mismatch("float"),
            Literal::Struct(_) => mismatch("struct"),
            Literal::Null => mismatch("null"),
        }
    }
}
impl FromLiteral for f64 {
    /// Accepts `Float` unchanged and widens `Integer` with an `as` cast, which
    /// rounds to the nearest representable `f64` for very large magnitudes.
    ///
    /// # Errors
    ///
    /// Any other variant yields `TypeMismatch { expected: "float", .. }`.
    fn from_literal(lit: &Literal) -> Result<Self, LiteralCastError> {
        let got = match *lit {
            Literal::Float(value) => return Ok(value),
            Literal::Integer(value) => return Ok(value as f64),
            Literal::Boolean(_) => "boolean",
            Literal::String(_) => "string",
            Literal::Struct(_) => "struct",
            Literal::Null => "null",
        };
        Err(LiteralCastError::TypeMismatch {
            expected: "float",
            got,
        })
    }
}
/// Returns a lower-case name for `lit`'s variant, suitable for the `got`
/// field of a [`LiteralCastError::TypeMismatch`].
fn literal_type_name(lit: &Literal) -> &'static str {
    match *lit {
        Literal::Null => "null",
        Literal::Boolean(_) => "boolean",
        Literal::Integer(_) => "integer",
        Literal::Float(_) => "float",
        Literal::String(_) => "string",
        Literal::Struct(_) => "struct",
    }
}
/// Extracts an owned copy of the string held by a [`Literal::String`].
///
/// # Errors
///
/// Every other variant (including `Null`) yields
/// `TypeMismatch { expected: "string", got: <variant name> }`.
pub fn literal_to_string(lit: &Literal) -> Result<String, LiteralCastError> {
    match lit {
        Literal::String(s) => Ok(s.clone()),
        // `literal_type_name` already maps `Null` to "null", so no separate
        // arm is needed for it.
        other => Err(LiteralCastError::TypeMismatch {
            expected: "string",
            got: literal_type_name(other),
        }),
    }
}
/// Converts `lit` into any [`FromLiteral`] type, e.g.
/// `literal_to_native::<u32>(&lit)`.
///
/// This is a turbofish-friendly wrapper around [`FromLiteral::from_literal`].
/// It requires nothing beyond `FromLiteral`: the former `Copy + 'static`
/// bounds were unused by the body and only excluded non-`Copy` targets.
///
/// # Errors
///
/// Propagates the [`LiteralCastError`] returned by `T::from_literal`.
pub fn literal_to_native<T>(lit: &Literal) -> Result<T, LiteralCastError>
where
    T: FromLiteral,
{
    T::from_literal(lit)
}
/// Converts a range endpoint expressed as a [`Literal`] into the native value
/// type of the Arrow primitive type `T`.
///
/// `Unbounded` stays `Unbounded`. `Included`/`Excluded` keep their
/// inclusivity, and their payload is converted with [`literal_to_native`].
///
/// # Errors
///
/// Returns the [`LiteralCastError`] produced while converting the endpoint.
pub fn bound_to_native<T>(bound: &Bound<Literal>) -> Result<Bound<T::Native>, LiteralCastError>
where
    T: ArrowPrimitiveType,
    T::Native: FromLiteral + Copy,
{
    match bound {
        Bound::Included(lit) => literal_to_native::<T::Native>(lit).map(Bound::Included),
        Bound::Excluded(lit) => literal_to_native::<T::Native>(lit).map(Bound::Excluded),
        Bound::Unbounded => Ok(Bound::Unbounded),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn boolean_literal_roundtrip() {
        let lit = Literal::from(true);
        assert_eq!(lit, Literal::Boolean(true));
        assert!(literal_to_native::<bool>(&lit).unwrap());
        assert!(!literal_to_native::<bool>(&Literal::Boolean(false)).unwrap());
    }

    #[test]
    fn boolean_literal_rejects_integer_cast() {
        let lit = Literal::Boolean(true);
        let err = literal_to_native::<i32>(&lit).unwrap_err();
        assert!(matches!(
            err,
            LiteralCastError::TypeMismatch {
                expected: "integer",
                got: "boolean",
            }
        ));
    }

    /// Narrowing an out-of-range integer reports the value instead of wrapping.
    #[test]
    fn integer_narrowing_reports_out_of_range() {
        let err = literal_to_native::<u8>(&Literal::Integer(256)).unwrap_err();
        assert_eq!(
            err,
            LiteralCastError::OutOfRange {
                target: "u8",
                value: 256,
            }
        );
        assert_eq!(literal_to_native::<i64>(&Literal::from(-7i32)).unwrap(), -7);
    }

    /// Finite values that overflow f32 are rejected; infinities pass through.
    #[test]
    fn f32_rejects_finite_overflow_but_passes_infinity() {
        let err = literal_to_native::<f32>(&Literal::Float(1e300)).unwrap_err();
        assert_eq!(
            err,
            LiteralCastError::FloatOutOfRange {
                target: "f32",
                value: 1e300,
            }
        );
        assert_eq!(
            literal_to_native::<f32>(&Literal::Float(f64::INFINITY)).unwrap(),
            f32::INFINITY
        );
    }

    /// Integers 0/1 and common textual spellings coerce to bool.
    #[test]
    fn bool_parses_integers_and_strings() {
        assert!(literal_to_native::<bool>(&Literal::Integer(1)).unwrap());
        assert!(!literal_to_native::<bool>(&Literal::from(" FALSE ")).unwrap());
        assert!(matches!(
            literal_to_native::<bool>(&Literal::Integer(2)),
            Err(LiteralCastError::OutOfRange {
                target: "bool",
                value: 2,
            })
        ));
        assert!(matches!(
            literal_to_native::<bool>(&Literal::from("maybe")),
            Err(LiteralCastError::TypeMismatch {
                expected: "bool",
                got: "string",
            })
        ));
    }

    /// Only `Literal::String` converts to `String`; everything else names its variant.
    #[test]
    fn literal_to_string_only_accepts_strings() {
        assert_eq!(literal_to_string(&Literal::from("hi")).unwrap(), "hi");
        assert_eq!(
            literal_to_string(&Literal::Null).unwrap_err(),
            LiteralCastError::TypeMismatch {
                expected: "string",
                got: "null",
            }
        );
        assert_eq!(
            literal_to_string(&Literal::Integer(1)).unwrap_err(),
            LiteralCastError::TypeMismatch {
                expected: "string",
                got: "integer",
            }
        );
    }
}