use crate::error::ParseError;
use crate::selection::sql::generate_selection_sql;
use serde::{Deserialize, Serialize};
use ulid::Ulid;
mod json_as_bytes {
use serde::{Deserialize, Deserializer, Serialize, Serializer};
pub fn serialize<S>(value: &serde_json::Value, serializer: S) -> Result<S::Ok, S::Error>
where S: Serializer {
let bytes = serde_json::to_vec(value).map_err(serde::ser::Error::custom)?;
bytes.serialize(serializer)
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<serde_json::Value, D::Error>
where D: Deserializer<'de> {
let bytes: Vec<u8> = Vec::deserialize(deserializer)?;
serde_json::from_slice(&bytes).map_err(serde::de::Error::custom)
}
}
/// An expression node in a parsed selection.
///
/// serde's derive passes each variant's index to the serializer, so reordering
/// variants can change the encoding in index-based formats — append, don't reorder.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Expr {
/// A constant value.
Literal(Literal),
/// A path of one or more steps, e.g. `name`.
Path(PathExpr),
/// A predicate used in expression position.
Predicate(Predicate),
/// A binary arithmetic expression: `left <operator> right`.
InfixExpr { left: Box<Expr>, operator: InfixOperator, right: Box<Expr> },
// `ExprList`: a list of expressions, e.g. the right-hand side of `IN (?, ?, ?)`.
// `Placeholder`: a `?` parameter, replaced by a concrete value in `Predicate::populate`.
ExprList(Vec<Expr>), Placeholder,
}
/// A constant value appearing in an expression.
///
/// serde's derive passes each variant's index to the serializer, so reordering
/// variants can change the encoding in index-based formats — append, don't reorder.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Literal {
/// 16-bit signed integer.
I16(i16),
/// 32-bit signed integer.
I32(i32),
/// 64-bit signed integer (the variant `From<i64> for Expr` produces).
I64(i64),
/// 64-bit float.
F64(f64),
/// Boolean.
Bool(bool),
/// Text string.
String(String),
/// Entity identifier stored as a ULID.
EntityId(Ulid),
/// Opaque bytes. NOTE(review): the payload encoding is not visible here — confirm with its producer.
Object(Vec<u8>),
/// Raw binary data.
Binary(Vec<u8>),
/// A JSON value. Serialized via `json_as_bytes`, i.e. as the bytes of its JSON
/// text rather than as a nested serde structure.
#[serde(with = "json_as_bytes")]
Json(serde_json::Value),
}
/// A path of one or more steps, rendered dot-separated (e.g. `a.b` for two steps).
///
/// An empty `steps` vector is not supported: the accessors in `impl PathExpr`
/// panic on it.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PathExpr {
/// Path components in order; expected to be non-empty.
pub steps: Vec<String>,
}
impl PathExpr {
    /// Builds a single-step path such as `name`.
    pub fn simple(name: impl Into<String>) -> Self {
        Self { steps: vec![name.into()] }
    }

    /// Returns `true` when the path consists of exactly one step.
    pub fn is_simple(&self) -> bool {
        self.steps.len() == 1
    }

    /// Returns the first step of the path.
    ///
    /// This is the key `referenced_columns` reports for the path.
    ///
    /// # Panics
    /// Panics if the path has no steps.
    pub fn first(&self) -> &str {
        // Same invariant and message as `property`, instead of a bare index panic.
        self.steps.first().expect("PathExpr must have at least one step")
    }

    /// Returns the last step of the path (equal to `first()` for a simple path).
    ///
    /// # Panics
    /// Panics if the path has no steps.
    pub fn property(&self) -> &str {
        self.steps.last().expect("PathExpr must have at least one step")
    }
}
impl std::fmt::Display for PathExpr {
    /// Writes the steps separated by `.` (e.g. `data.x`); an empty path writes nothing.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut segments = self.steps.iter();
        if let Some(head) = segments.next() {
            f.write_str(head)?;
            for segment in segments {
                f.write_str(".")?;
                f.write_str(segment)?;
            }
        }
        Ok(())
    }
}
/// A query selection: a filter predicate plus optional ordering and row limit.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Selection {
/// Filter condition.
pub predicate: Predicate,
/// `ORDER BY` items in priority order; `None` means no `ORDER BY` clause is rendered.
pub order_by: Option<Vec<OrderByItem>>,
/// Maximum number of rows; `None` means no `LIMIT` clause is rendered.
pub limit: Option<u64>,
}
/// One `ORDER BY` term: the path to sort on and its direction.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OrderByItem {
/// Path whose value is sorted on.
pub path: PathExpr,
/// Ascending or descending.
pub direction: OrderDirection,
}
impl std::fmt::Display for OrderByItem {
    /// Renders the item as `<path> ASC` or `<path> DESC`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let keyword = match self.direction {
            OrderDirection::Asc => "ASC",
            OrderDirection::Desc => "DESC",
        };
        write!(f, "{} {}", self.path, keyword)
    }
}
/// Sort direction of an `ORDER BY` item; `OrderByItem`'s `Display` renders it as `ASC` / `DESC`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum OrderDirection {
/// Ascending (`ASC`).
Asc,
/// Descending (`DESC`).
Desc,
}
impl std::fmt::Display for Selection {
    /// Renders `<predicate>[ ORDER BY a, b][ LIMIT n]`.
    ///
    /// An `order_by` of `Some(vec![])` still emits the ` ORDER BY ` keyword.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.predicate)?;
        if let Some(items) = &self.order_by {
            f.write_str(" ORDER BY ")?;
            let mut separator = "";
            for item in items {
                write!(f, "{}{}", separator, item)?;
                separator = ", ";
            }
        }
        match self.limit {
            Some(n) => write!(f, " LIMIT {}", n),
            None => Ok(()),
        }
    }
}
impl From<Predicate> for Selection {
fn from(predicate: Predicate) -> Self { Selection { predicate, order_by: None, limit: None } }
}
/// A boolean filter condition.
///
/// `Display` renders it as SQL via `generate_selection_sql`. serde's derive passes
/// each variant's index to the serializer, so reordering variants can change the
/// encoding in index-based formats — append, don't reorder.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Predicate {
/// `left <operator> right`.
Comparison { left: Box<Expr>, operator: ComparisonOperator, right: Box<Expr> },
/// `<expr> IS NULL`.
IsNull(Box<Expr>),
/// Logical conjunction.
And(Box<Predicate>, Box<Predicate>),
/// Logical disjunction.
Or(Box<Predicate>, Box<Predicate>),
/// Logical negation.
Not(Box<Predicate>),
/// Constant true (rendered as `TRUE` by the SQL generator, per the tests).
True,
/// Constant false (rendered as `FALSE` by the SQL generator, per the tests).
False,
/// A predicate-level placeholder. `populate` rejects it with an error, so it must be
/// rewritten before values are substituted.
Placeholder,
}
impl std::fmt::Display for Predicate {
    /// Renders the predicate as SQL text.
    ///
    /// If SQL generation fails, writes `SQL Error: <error>` instead of propagating a
    /// `fmt::Error`, so formatting a predicate never fails.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let sql = match generate_selection_sql(self, None) {
            Ok(sql) => sql,
            Err(e) => return write!(f, "SQL Error: {}", e),
        };
        write!(f, "{}", sql)
    }
}
impl Selection {
pub fn assume_null(&self, columns: &[String]) -> Self {
let order_by = self.order_by.as_ref().map(|items| {
items
.iter()
.filter(|item| {
let col_name = item.path.property();
!columns.contains(&col_name.to_string())
})
.cloned()
.collect::<Vec<_>>()
});
let order_by = order_by.and_then(|v| if v.is_empty() { None } else { Some(v) });
Self { predicate: self.predicate.assume_null(columns), order_by, limit: self.limit }
}
pub fn referenced_columns(&self) -> Vec<String> {
let mut columns = self.predicate.referenced_columns();
if let Some(order_by) = &self.order_by {
for item in order_by {
let col = item.path.first().to_string();
if !columns.contains(&col) {
columns.push(col);
}
}
}
columns
}
}
impl Predicate {
    /// Folds `visitor` over this predicate tree in pre-order.
    ///
    /// The visitor sees `self` first. For `And`/`Or` it then sees the left subtree
    /// followed by the right subtree; for `Not` it sees the inner predicate. Predicates
    /// nested inside expressions (`Expr::Predicate`) are not visited.
    pub fn walk<T, F>(&self, accumulator: T, visitor: &mut F) -> T
    where
        F: FnMut(T, &Predicate) -> T,
    {
        let accumulator = visitor(accumulator, self);
        match self {
            Predicate::And(left, right) | Predicate::Or(left, right) => {
                let accumulator = left.walk(accumulator, visitor);
                right.walk(accumulator, visitor)
            }
            Predicate::Not(inner) => inner.walk(accumulator, visitor),
            _ => accumulator,
        }
    }

    /// Returns the distinct columns referenced by this predicate, in first-seen order.
    ///
    /// Only comparison operands and `IS NULL` arguments that are bare `Expr::Path`s
    /// are considered. A path's column is its first step.
    pub fn referenced_columns(&self) -> Vec<String> {
        self.walk(Vec::new(), &mut |mut cols: Vec<String>, pred: &Predicate| {
            match pred {
                Predicate::Comparison { left, right, .. } => {
                    Self::push_column(&mut cols, left);
                    Self::push_column(&mut cols, right);
                }
                Predicate::IsNull(expr) => Self::push_column(&mut cols, expr),
                _ => {}
            }
            cols
        })
    }

    /// Appends `expr`'s column to `cols` if `expr` is a path whose column is not yet present.
    fn push_column(cols: &mut Vec<String>, expr: &Expr) {
        if let Expr::Path(path) = expr {
            let col = path.first();
            if !cols.iter().any(|c| c.as_str() == col) {
                cols.push(col.to_string());
            }
        }
    }

    /// Returns `true` if `expr` is a bare path whose column (first step) is in `columns`.
    ///
    /// This uses the same key `referenced_columns` reports, so callers can pass its
    /// output (or a subset of it) straight to `assume_null`.
    fn is_null_column(expr: &Expr, columns: &[String]) -> bool {
        matches!(expr, Expr::Path(path) if columns.iter().any(|c| c.as_str() == path.first()))
    }

    /// Rewrites the predicate as if every column in `columns` were NULL, then simplifies.
    ///
    /// - A comparison where EITHER operand is a path to a listed column becomes `False`.
    /// - `IS NULL` on such a path becomes `True`.
    /// - `And`, `Or` and `Not` are constant-folded around the rewritten children.
    /// - Everything else is returned unchanged.
    ///
    /// Only bare path operands are inspected; paths nested inside arithmetic or lists
    /// are left as-is.
    ///
    /// NOTE(review): folding a NULL comparison to `False` means `NOT (col = x)` becomes
    /// `True` here. SQL three-valued logic would evaluate it to NULL, and the row would
    /// be excluded. Confirm this is the intended semantics.
    pub fn assume_null(&self, columns: &[String]) -> Self {
        match self {
            Predicate::Comparison { left, operator, right } => {
                // Check both sides: an or-pattern over `(left, right)` would bind only
                // the first matching path and miss a null column on the right.
                if Self::is_null_column(left, columns) || Self::is_null_column(right, columns) {
                    // Kept exhaustive so a new operator forces a decision here.
                    match operator {
                        ComparisonOperator::Equal
                        | ComparisonOperator::NotEqual
                        | ComparisonOperator::GreaterThan
                        | ComparisonOperator::GreaterThanOrEqual
                        | ComparisonOperator::LessThan
                        | ComparisonOperator::LessThanOrEqual
                        | ComparisonOperator::In
                        | ComparisonOperator::Between => Predicate::False,
                    }
                } else {
                    self.clone()
                }
            }
            Predicate::IsNull(expr) => {
                if Self::is_null_column(expr, columns) {
                    Predicate::True
                } else {
                    self.clone()
                }
            }
            Predicate::And(left, right) => match (left.assume_null(columns), right.assume_null(columns)) {
                (Predicate::False, _) | (_, Predicate::False) => Predicate::False,
                (Predicate::True, other) | (other, Predicate::True) => other,
                (l, r) => Predicate::And(Box::new(l), Box::new(r)),
            },
            Predicate::Or(left, right) => match (left.assume_null(columns), right.assume_null(columns)) {
                (Predicate::True, _) | (_, Predicate::True) => Predicate::True,
                (Predicate::False, other) | (other, Predicate::False) => other,
                (l, r) => Predicate::Or(Box::new(l), Box::new(r)),
            },
            Predicate::Not(inner) => match inner.assume_null(columns) {
                Predicate::True => Predicate::False,
                Predicate::False => Predicate::True,
                other => Predicate::Not(Box::new(other)),
            },
            Predicate::True => Predicate::True,
            Predicate::False => Predicate::False,
            Predicate::Placeholder => Predicate::Placeholder,
        }
    }

    /// Substitutes `values`, in order, for every `Expr::Placeholder` in the predicate.
    ///
    /// Placeholders are filled depth-first, left operand before right operand.
    ///
    /// # Errors
    /// Returns `ParseError::InvalidPredicate` in any of these cases:
    /// - there are fewer values than placeholders;
    /// - there are more values than placeholders;
    /// - a `Predicate::Placeholder` is encountered.
    ///
    /// A value's conversion error is converted into `ParseError` and returned.
    pub fn populate<I, V, E>(self, values: I) -> Result<Predicate, ParseError>
    where
        I: IntoIterator<Item = V>,
        V: TryInto<Expr, Error = E>,
        E: Into<ParseError>,
    {
        let mut values_iter = values.into_iter();
        let result = self.populate_recursive(&mut values_iter)?;
        if values_iter.next().is_some() {
            return Err(ParseError::InvalidPredicate("Too many values provided for placeholders".to_string()));
        }
        Ok(result)
    }

    /// Recursive worker for `populate`; consumes values from the shared iterator.
    fn populate_recursive<I, V, E>(self, values: &mut I) -> Result<Predicate, ParseError>
    where
        I: Iterator<Item = V>,
        V: TryInto<Expr, Error = E>,
        E: Into<ParseError>,
    {
        match self {
            Predicate::Comparison { left, operator, right } => Ok(Predicate::Comparison {
                left: Box::new(left.populate_recursive(values)?),
                operator,
                right: Box::new(right.populate_recursive(values)?),
            }),
            Predicate::And(left, right) => {
                Ok(Predicate::And(Box::new(left.populate_recursive(values)?), Box::new(right.populate_recursive(values)?)))
            }
            Predicate::Or(left, right) => {
                Ok(Predicate::Or(Box::new(left.populate_recursive(values)?), Box::new(right.populate_recursive(values)?)))
            }
            Predicate::Not(pred) => Ok(Predicate::Not(Box::new(pred.populate_recursive(values)?))),
            Predicate::IsNull(expr) => Ok(Predicate::IsNull(Box::new(expr.populate_recursive(values)?))),
            Predicate::True => Ok(Predicate::True),
            Predicate::False => Ok(Predicate::False),
            Predicate::Placeholder => Err(ParseError::InvalidPredicate("Placeholder must be transformed before population".to_string())),
        }
    }
}
impl Expr {
    /// Replaces each `Expr::Placeholder` with the next value from `values`.
    ///
    /// Traversal is depth-first and left to right: infix operands are filled left
    /// before right, and list items in order. Predicates nested as `Expr::Predicate`
    /// are filled through `Predicate::populate_recursive`. Literals and paths are
    /// returned untouched.
    ///
    /// # Errors
    /// - Returns `ParseError::InvalidPredicate` when `values` runs out.
    /// - A value's conversion error is converted into `ParseError` and returned.
    fn populate_recursive<I, V, E>(self, values: &mut I) -> Result<Expr, ParseError>
    where
        I: Iterator<Item = V>,
        V: TryInto<Expr, Error = E>,
        E: Into<ParseError>,
    {
        let populated = match self {
            Expr::Placeholder => {
                let value = values.next().ok_or_else(|| {
                    ParseError::InvalidPredicate("Not enough values provided for placeholders".to_string())
                })?;
                let converted: Result<Expr, E> = value.try_into();
                converted.map_err(|e| -> ParseError { e.into() })?
            }
            Expr::Predicate(pred) => Expr::Predicate(pred.populate_recursive(values)?),
            Expr::InfixExpr { left, operator, right } => {
                let left = Box::new(left.populate_recursive(values)?);
                let right = Box::new(right.populate_recursive(values)?);
                Expr::InfixExpr { left, operator, right }
            }
            Expr::ExprList(exprs) => Expr::ExprList(
                exprs
                    .into_iter()
                    .map(|expr| expr.populate_recursive(values))
                    .collect::<Result<Vec<_>, _>>()?,
            ),
            leaf @ (Expr::Literal(_) | Expr::Path(_)) => leaf,
        };
        Ok(populated)
    }
}
/// Operator of a `Predicate::Comparison`.
///
/// The parser maps `=` to `Equal` and `>` to `GreaterThan` (per the tests). `In`
/// takes an `Expr::ExprList` on the right, as in `status IN (?, ?, ?)`.
/// NOTE(review): the operand shape expected for `Between` is not visible in this file.
/// Do not reorder variants: serde's derive passes variant indices to the serializer.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ComparisonOperator {
// Variants: Equal, NotEqual, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, In, Between.
Equal, NotEqual, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, In, Between, }
/// Arithmetic operator of an `Expr::InfixExpr`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum InfixOperator {
/// Addition.
Add,
/// Subtraction.
Subtract,
/// Multiplication.
Multiply,
/// Division.
Divide,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_selection;

    /// Parses `input`, nullifies `null_columns` with `Predicate::assume_null`, and renders SQL.
    ///
    /// SQL-generation failures are returned as `ParseError::InvalidPredicate`. The
    /// error text includes the underlying error so failing tests are diagnosable.
    fn nullify_columns(input: &str, null_columns: &[&str]) -> Result<String, ParseError> {
        let selection = parse_selection(input)?;
        let columns: Vec<String> = null_columns.iter().map(|s| s.to_string()).collect();
        let result = selection.predicate.assume_null(&columns);
        generate_selection_sql(&result, None)
            .map_err(|e| ParseError::InvalidPredicate(format!("SQL generation failed: {}", e)))
    }

    #[test]
    fn test_single_comparison_null_handling() {
        assert_eq!(nullify_columns("status = 'active'", &["status"]).unwrap(), "FALSE");
        assert_eq!(nullify_columns("age > 30", &["age"]).unwrap(), "FALSE");
        assert_eq!(nullify_columns("count >= 100", &["count"]).unwrap(), "FALSE");
        assert_eq!(nullify_columns("name < 'Z'", &["name"]).unwrap(), "FALSE");
        assert_eq!(nullify_columns("score <= 90", &["score"]).unwrap(), "FALSE");
        assert_eq!(nullify_columns("status IS NULL", &["status"]).unwrap(), "TRUE");
        assert_eq!(nullify_columns("role = 'admin'", &["other"]).unwrap(), r#""role" = 'admin'"#);
    }

    #[test]
    fn nested_predicate_null_handling() {
        let input = "alpha = 1 AND (beta = 2 OR charlie = 3)";
        assert_eq!(nullify_columns(input, &["charlie"]).unwrap(), r#""alpha" = 1 AND "beta" = 2"#);
        assert_eq!(nullify_columns(input, &["beta", "charlie"]).unwrap(), r#"FALSE"#);
        assert_eq!(nullify_columns(input, &["alpha"]).unwrap(), r#"FALSE"#);
        assert_eq!(nullify_columns(input, &["other"]).unwrap(), r#""alpha" = 1 AND ("beta" = 2 OR "charlie" = 3)"#);
    }

    #[test]
    fn test_populate_single_placeholder() {
        let selection = parse_selection("name = ?").unwrap();
        let populated = selection.predicate.populate(vec!["Alice"]).unwrap();
        let expected = Predicate::Comparison {
            left: Box::new(Expr::Path(PathExpr::simple("name".to_string()))),
            operator: ComparisonOperator::Equal,
            right: Box::new(Expr::Literal(Literal::String("Alice".to_string()))),
        };
        assert_eq!(populated, expected);
    }

    #[test]
    fn test_populate_multiple_placeholders() {
        let selection = parse_selection("age > ? AND name = ?").unwrap();
        let values: Vec<Expr> = vec![25i64.into(), "Bob".into()];
        let populated = selection.predicate.populate(values).unwrap();
        let expected = Predicate::And(
            Box::new(Predicate::Comparison {
                left: Box::new(Expr::Path(PathExpr::simple("age".to_string()))),
                operator: ComparisonOperator::GreaterThan,
                right: Box::new(Expr::Literal(Literal::I64(25))),
            }),
            Box::new(Predicate::Comparison {
                left: Box::new(Expr::Path(PathExpr::simple("name".to_string()))),
                operator: ComparisonOperator::Equal,
                right: Box::new(Expr::Literal(Literal::String("Bob".to_string()))),
            }),
        );
        assert_eq!(populated, expected);
    }

    #[test]
    fn test_populate_in_clause() {
        let selection = parse_selection("status IN (?, ?, ?)").unwrap();
        let populated = selection.predicate.populate(vec!["active", "pending", "review"]).unwrap();
        let expected = Predicate::Comparison {
            left: Box::new(Expr::Path(PathExpr::simple("status".to_string()))),
            operator: ComparisonOperator::In,
            right: Box::new(Expr::ExprList(vec![
                Expr::Literal(Literal::String("active".to_string())),
                Expr::Literal(Literal::String("pending".to_string())),
                Expr::Literal(Literal::String("review".to_string())),
            ])),
        };
        assert_eq!(populated, expected);
    }

    #[test]
    fn test_populate_mixed_types() {
        let selection = parse_selection("active = ? AND score > ? AND name = ?").unwrap();
        let values: Vec<Expr> = vec![true.into(), 95.5f64.into(), "Charlie".into()];
        let populated = selection.predicate.populate(values).unwrap();
        // Collect every comparison's right-hand side in source order. `walk` visits
        // left subtrees before right ones, so this does not depend on how the parser
        // associates chained ANDs. Unlike nested `if let`s, it cannot pass vacuously
        // when the tree has an unexpected shape.
        let rhs = populated.walk(Vec::new(), &mut |mut acc: Vec<Expr>, pred: &Predicate| {
            if let Predicate::Comparison { right, .. } = pred {
                acc.push((**right).clone());
            }
            acc
        });
        assert_eq!(
            rhs,
            vec![
                Expr::Literal(Literal::Bool(true)),
                Expr::Literal(Literal::F64(95.5)),
                Expr::Literal(Literal::String("Charlie".to_string())),
            ]
        );
    }

    #[test]
    fn test_populate_too_few_values() {
        let selection = parse_selection("name = ? AND age = ?").unwrap();
        let result = selection.predicate.populate(vec!["Alice"]);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Not enough values"));
    }

    #[test]
    fn test_populate_too_many_values() {
        let selection = parse_selection("name = ?").unwrap();
        let result = selection.predicate.populate(vec!["Alice", "Bob"]);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Too many values"));
    }

    #[test]
    fn test_populate_no_placeholders() {
        let selection = parse_selection("name = 'Alice'").unwrap();
        let populated = selection.clone().predicate.populate(Vec::<String>::new()).unwrap();
        assert_eq!(populated, selection.predicate);
    }
}
impl From<String> for Expr {
    /// Wraps an owned string as a `Literal::String` expression.
    fn from(value: String) -> Expr {
        let literal = Literal::String(value);
        Expr::Literal(literal)
    }
}
impl From<&str> for Expr {
    /// Copies a string slice into a `Literal::String` expression.
    fn from(text: &str) -> Expr {
        let owned = String::from(text);
        Expr::Literal(Literal::String(owned))
    }
}
impl From<i64> for Expr {
    /// Wraps a 64-bit integer as a `Literal::I64` expression.
    fn from(value: i64) -> Expr {
        let literal = Literal::I64(value);
        Expr::Literal(literal)
    }
}
impl From<f64> for Expr {
    /// Wraps a 64-bit float as a `Literal::F64` expression.
    fn from(value: f64) -> Expr {
        let literal = Literal::F64(value);
        Expr::Literal(literal)
    }
}
impl From<bool> for Expr {
    /// Wraps a boolean as a `Literal::Bool` expression.
    fn from(flag: bool) -> Expr {
        let literal = Literal::Bool(flag);
        Expr::Literal(literal)
    }
}
impl From<Literal> for Expr {
    /// Lifts a literal into expression position unchanged.
    fn from(literal: Literal) -> Expr {
        let expr = Expr::Literal(literal);
        expr
    }
}
impl<T> From<Vec<T>> for Expr
where
    T: Into<Expr>,
{
    /// Converts each element into an expression, preserving order, and wraps them in `Expr::ExprList`.
    fn from(items: Vec<T>) -> Self {
        let exprs: Vec<Expr> = items.into_iter().map(Into::into).collect();
        Expr::ExprList(exprs)
    }
}
impl<T, const N: usize> From<[T; N]> for Expr
where
    T: Into<Expr>,
{
    /// Converts each array element (by value, in order) into an expression and wraps them in `Expr::ExprList`.
    fn from(items: [T; N]) -> Self {
        let exprs: Vec<Expr> = IntoIterator::into_iter(items).map(Into::into).collect();
        Expr::ExprList(exprs)
    }
}
impl<T> From<&[T]> for Expr
where
    T: Into<Expr> + Clone,
{
    /// Clones each slice element, converts it into an expression, and wraps the results in `Expr::ExprList`.
    fn from(items: &[T]) -> Self {
        let exprs: Vec<Expr> = items.iter().cloned().map(Into::into).collect();
        Expr::ExprList(exprs)
    }
}
impl<T, const N: usize> From<&[T; N]> for Expr
where
    T: Into<Expr> + Clone,
{
    /// Clones each element of the borrowed array, converts it into an expression,
    /// and wraps the results in `Expr::ExprList`.
    fn from(items: &[T; N]) -> Self {
        let exprs: Vec<Expr> = items.iter().cloned().map(Into::into).collect();
        Expr::ExprList(exprs)
    }
}