ankql/
ast.rs

1use crate::error::ParseError;
2use crate::selection::sql::generate_selection_sql;
3use serde::{Deserialize, Serialize};
4use ulid::Ulid;
5
6#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
7pub enum Expr {
8    Literal(Literal),
9    Identifier(Identifier),
10    Predicate(Predicate),
11    InfixExpr { left: Box<Expr>, operator: InfixOperator, right: Box<Expr> },
12    ExprList(Vec<Expr>), // New variant for handling lists like (1,2,3) in IN clauses
13    Placeholder,
14}
15
16#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
17pub enum Literal {
18    I16(i16),
19    I32(i32),
20    I64(i64),
21    F64(f64),
22    Bool(bool),
23    String(String),
24    EntityId(Ulid),
25    Object(Vec<u8>),
26    Binary(Vec<u8>),
27}
28
29#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
30pub enum Identifier {
31    Property(String),
32    CollectionProperty(String, String),
33}
34
35#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
36pub struct Selection {
37    pub predicate: Predicate,
38    pub order_by: Option<Vec<OrderByItem>>,
39    pub limit: Option<u64>,
40}
41
42#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
43pub struct OrderByItem {
44    pub identifier: Identifier,
45    pub direction: OrderDirection,
46}
47
48impl std::fmt::Display for OrderByItem {
49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50        let field = match &self.identifier {
51            Identifier::Property(prop) => prop.clone(),
52            Identifier::CollectionProperty(coll, prop) => format!("{}.{}", coll, prop),
53        };
54        write!(
55            f,
56            "{} {}",
57            field,
58            match self.direction {
59                OrderDirection::Asc => "ASC",
60                OrderDirection::Desc => "DESC",
61            }
62        )
63    }
64}
65
66#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
67pub enum OrderDirection {
68    Asc,
69    Desc,
70}
71
72impl std::fmt::Display for Selection {
73    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74        write!(f, "{}", self.predicate)?;
75        if let Some(order_by) = &self.order_by {
76            write!(f, " ORDER BY ")?;
77            for (i, item) in order_by.iter().enumerate() {
78                if i > 0 {
79                    write!(f, ", ")?;
80                }
81                write!(f, "{}", item)?;
82            }
83        }
84        if let Some(limit) = self.limit {
85            write!(f, " LIMIT {}", limit)?;
86        }
87        Ok(())
88    }
89}
90
91// Backward compatibility
92impl From<Predicate> for Selection {
93    fn from(predicate: Predicate) -> Self { Selection { predicate, order_by: None, limit: None } }
94}
95
96#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
97pub enum Predicate {
98    Comparison { left: Box<Expr>, operator: ComparisonOperator, right: Box<Expr> },
99    IsNull(Box<Expr>),
100    And(Box<Predicate>, Box<Predicate>),
101    Or(Box<Predicate>, Box<Predicate>),
102    Not(Box<Predicate>),
103    True,
104    False,
105    Placeholder,
106}
107
108impl std::fmt::Display for Predicate {
109    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110        match generate_selection_sql(self, None) {
111            Ok(sql) => write!(f, "{}", sql),
112            Err(e) => write!(f, "SQL Error: {}", e),
113        }
114    }
115}
116
117impl Selection {
118    pub fn assume_null(&self, columns: &[String]) -> Self {
119        Self { predicate: self.predicate.assume_null(columns), order_by: self.order_by.clone(), limit: self.limit }
120    }
121}
122
123impl Predicate {
124    /// Recursively walk a predicate tree and accumulate results using a closure
125    pub fn walk<T, F>(&self, accumulator: T, visitor: &mut F) -> T
126    where F: FnMut(T, &Predicate) -> T {
127        let accumulator = visitor(accumulator, self);
128        match self {
129            Predicate::And(left, right) | Predicate::Or(left, right) => {
130                let accumulator = left.walk(accumulator, visitor);
131                right.walk(accumulator, visitor)
132            }
133            Predicate::Not(inner) => inner.walk(accumulator, visitor),
134            _ => accumulator,
135        }
136    }
137
138    /// Clones the predicate tree and evaluates comparisons involving missing columns as if they were NULL
139    pub fn assume_null(&self, columns: &[String]) -> Self {
140        match self {
141            Predicate::Comparison { left, operator, right } => {
142                // Check if either side is an identifier that's in our null list
143                let has_null_identifier = match (&**left, &**right) {
144                    (Expr::Identifier(id), _) | (_, Expr::Identifier(id)) => match id {
145                        Identifier::Property(name) => columns.contains(name),
146                        Identifier::CollectionProperty(_, name) => columns.contains(name),
147                    },
148                    _ => false,
149                };
150
151                if has_null_identifier {
152                    match operator {
153                        // NULL = anything is false
154                        ComparisonOperator::Equal => Predicate::False,
155                        // NULL != anything is false (NULL comparisons always return NULL in SQL)
156                        ComparisonOperator::NotEqual => Predicate::False,
157                        // NULL > anything is false
158                        ComparisonOperator::GreaterThan => Predicate::False,
159                        // NULL >= anything is false
160                        ComparisonOperator::GreaterThanOrEqual => Predicate::False,
161                        // NULL < anything is false
162                        ComparisonOperator::LessThan => Predicate::False,
163                        // NULL <= anything is false
164                        ComparisonOperator::LessThanOrEqual => Predicate::False,
165                        // NULL IN (...) is false
166                        ComparisonOperator::In => Predicate::False,
167                        // NULL BETWEEN ... is false
168                        ComparisonOperator::Between => Predicate::False,
169                    }
170                } else {
171                    // No NULL identifiers, keep the comparison as is
172                    Predicate::Comparison { left: left.clone(), operator: operator.clone(), right: right.clone() }
173                }
174            }
175            Predicate::IsNull(expr) => {
176                // If we're explicitly checking for NULL and the identifier is in our null list,
177                // then this evaluates to true
178                match &**expr {
179                    Expr::Identifier(id) => {
180                        let is_null = match id {
181                            Identifier::Property(name) => columns.contains(name),
182                            Identifier::CollectionProperty(_, name) => columns.contains(name),
183                        };
184                        if is_null {
185                            Predicate::True
186                        } else {
187                            Predicate::IsNull(expr.clone())
188                        }
189                    }
190                    _ => Predicate::IsNull(expr.clone()),
191                }
192            }
193            Predicate::And(left, right) => {
194                let left = left.assume_null(columns);
195                let right = right.assume_null(columns);
196
197                // Optimize
198                match (&left, &right) {
199                    // if either side is false, the whole thing is false
200                    (Predicate::False, _) | (_, Predicate::False) => Predicate::False,
201                    // if both sides are true, the whole thing is true
202                    (Predicate::True, Predicate::True) => Predicate::True,
203                    // if one side is true, the whole thing is the other side
204                    (Predicate::True, p) | (p, Predicate::True) => p.clone(),
205                    _ => Predicate::And(Box::new(left), Box::new(right)),
206                }
207            }
208            Predicate::Or(left, right) => {
209                let left = left.assume_null(columns);
210                let right = right.assume_null(columns);
211
212                // Optimize
213                match (&left, &right) {
214                    // if either side is true, the whole thing is true
215                    (Predicate::True, _) | (_, Predicate::True) => Predicate::True,
216                    // if both sides are false, the whole thing is false
217                    (Predicate::False, Predicate::False) => Predicate::False,
218                    // if one side is false, the whole thing is the other side
219                    (Predicate::False, p) | (p, Predicate::False) => p.clone(),
220                    // otherwise, keep the original
221                    _ => Predicate::Or(Box::new(left), Box::new(right)),
222                }
223            }
224            Predicate::Not(pred) => {
225                let inner = pred.assume_null(columns);
226                match inner {
227                    Predicate::True => Predicate::False,
228                    Predicate::False => Predicate::True,
229                    _ => Predicate::Not(Box::new(inner)),
230                }
231            }
232            // These are constants, just clone them
233            Predicate::True => Predicate::True,
234            Predicate::False => Predicate::False,
235            Predicate::Placeholder => Predicate::Placeholder,
236        }
237    }
238
239    /// Populate placeholders in the predicate with actual values
240    pub fn populate<I, V, E>(self, values: I) -> Result<Predicate, ParseError>
241    where
242        I: IntoIterator<Item = V>,
243        V: TryInto<Expr, Error = E>,
244        E: Into<ParseError>,
245    {
246        let mut values_iter = values.into_iter();
247        let result = self.populate_recursive(&mut values_iter)?;
248
249        // Check if there are any unused values
250        if values_iter.next().is_some() {
251            return Err(ParseError::InvalidPredicate("Too many values provided for placeholders".to_string()));
252        }
253
254        Ok(result)
255    }
256
257    fn populate_recursive<I, V, E>(self, values: &mut I) -> Result<Predicate, ParseError>
258    where
259        I: Iterator<Item = V>,
260        V: TryInto<Expr, Error = E>,
261        E: Into<ParseError>,
262    {
263        match self {
264            Predicate::Comparison { left, operator, right } => Ok(Predicate::Comparison {
265                left: Box::new(left.populate_recursive(values)?),
266                operator,
267                right: Box::new(right.populate_recursive(values)?),
268            }),
269            Predicate::And(left, right) => {
270                Ok(Predicate::And(Box::new(left.populate_recursive(values)?), Box::new(right.populate_recursive(values)?)))
271            }
272            Predicate::Or(left, right) => {
273                Ok(Predicate::Or(Box::new(left.populate_recursive(values)?), Box::new(right.populate_recursive(values)?)))
274            }
275            Predicate::Not(pred) => Ok(Predicate::Not(Box::new(pred.populate_recursive(values)?))),
276            Predicate::IsNull(expr) => Ok(Predicate::IsNull(Box::new(expr.populate_recursive(values)?))),
277            Predicate::True => Ok(Predicate::True),
278            Predicate::False => Ok(Predicate::False),
279            // Placeholder should be transformed to a comparison before population
280            Predicate::Placeholder => Err(ParseError::InvalidPredicate("Placeholder must be transformed before population".to_string())),
281        }
282    }
283}
284
285impl Expr {
286    fn populate_recursive<I, V, E>(self, values: &mut I) -> Result<Expr, ParseError>
287    where
288        I: Iterator<Item = V>,
289        V: TryInto<Expr, Error = E>,
290        E: Into<ParseError>,
291    {
292        match self {
293            Expr::Placeholder => match values.next() {
294                Some(value) => Ok(value.try_into().map_err(|e| e.into())?),
295                None => Err(ParseError::InvalidPredicate("Not enough values provided for placeholders".to_string())),
296            },
297            Expr::Literal(lit) => Ok(Expr::Literal(lit)),
298            Expr::Identifier(id) => Ok(Expr::Identifier(id)),
299            Expr::Predicate(pred) => Ok(Expr::Predicate(pred.populate_recursive(values)?)),
300            Expr::InfixExpr { left, operator, right } => Ok(Expr::InfixExpr {
301                left: Box::new(left.populate_recursive(values)?),
302                operator,
303                right: Box::new(right.populate_recursive(values)?),
304            }),
305            Expr::ExprList(exprs) => {
306                let mut populated_exprs = Vec::new();
307                for expr in exprs {
308                    populated_exprs.push(expr.populate_recursive(values)?);
309                }
310                Ok(Expr::ExprList(populated_exprs))
311            }
312        }
313    }
314}
315
316#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
317pub enum ComparisonOperator {
318    Equal,              // =
319    NotEqual,           // <> or !=
320    GreaterThan,        // >
321    GreaterThanOrEqual, // >=
322    LessThan,           // <
323    LessThanOrEqual,    // <=
324    In,                 // IN
325    Between,            // BETWEEN
326}
327
328#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
329pub enum InfixOperator {
330    Add,
331    Subtract,
332    Multiply,
333    Divide,
334}
335
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_selection;

    /// Parses `input`, treats `null_columns` as NULL via `Predicate::assume_null`, and renders
    /// the rewritten predicate back to SQL.
    fn nullify_columns(input: &str, null_columns: &[&str]) -> Result<String, ParseError> {
        let selection = parse_selection(input)?;
        let columns: Vec<String> = null_columns.iter().map(|s| s.to_string()).collect();
        let result = selection.predicate.assume_null(&columns);
        generate_selection_sql(&result, None).map_err(|_| ParseError::InvalidPredicate("SQL generation failed".to_string()))
    }

    #[test]
    fn test_single_comparison_null_handling() {
        assert_eq!(nullify_columns("status = 'active'", &["status"]).unwrap(), "FALSE");
        assert_eq!(nullify_columns("age > 30", &["age"]).unwrap(), "FALSE");
        assert_eq!(nullify_columns("count >= 100", &["count"]).unwrap(), "FALSE");
        assert_eq!(nullify_columns("name < 'Z'", &["name"]).unwrap(), "FALSE");
        assert_eq!(nullify_columns("score <= 90", &["score"]).unwrap(), "FALSE");
        assert_eq!(nullify_columns("status IS NULL", &["status"]).unwrap(), "TRUE");
        assert_eq!(nullify_columns("role = 'admin'", &["other"]).unwrap(), r#""role" = 'admin'"#);
    }

    #[test]
    fn nested_predicate_null_handling() {
        let input = "alpha = 1 AND (beta = 2 OR charlie = 3)";
        assert_eq!(nullify_columns(input, &["charlie"]).unwrap(), r#""alpha" = 1 AND "beta" = 2"#);
        assert_eq!(nullify_columns(input, &["beta", "charlie"]).unwrap(), r#"FALSE"#);
        assert_eq!(nullify_columns(input, &["alpha"]).unwrap(), r#"FALSE"#);
        assert_eq!(nullify_columns(input, &["other"]).unwrap(), r#""alpha" = 1 AND ("beta" = 2 OR "charlie" = 3)"#);
    }

    #[test]
    fn test_populate_single_placeholder() {
        let selection = parse_selection("name = ?").unwrap();
        let populated = selection.predicate.populate(vec!["Alice"]).unwrap();

        let expected = Predicate::Comparison {
            left: Box::new(Expr::Identifier(Identifier::Property("name".to_string()))),
            operator: ComparisonOperator::Equal,
            right: Box::new(Expr::Literal(Literal::String("Alice".to_string()))),
        };

        assert_eq!(populated, expected);
    }

    #[test]
    fn test_populate_multiple_placeholders() {
        let selection = parse_selection("age > ? AND name = ?").unwrap();
        let values: Vec<Expr> = vec![25i64.into(), "Bob".into()];
        let populated = selection.predicate.populate(values).unwrap();

        let expected = Predicate::And(
            Box::new(Predicate::Comparison {
                left: Box::new(Expr::Identifier(Identifier::Property("age".to_string()))),
                operator: ComparisonOperator::GreaterThan,
                right: Box::new(Expr::Literal(Literal::I64(25))),
            }),
            Box::new(Predicate::Comparison {
                left: Box::new(Expr::Identifier(Identifier::Property("name".to_string()))),
                operator: ComparisonOperator::Equal,
                right: Box::new(Expr::Literal(Literal::String("Bob".to_string()))),
            }),
        );

        assert_eq!(populated, expected);
    }

    #[test]
    fn test_populate_in_clause() {
        let selection = parse_selection("status IN (?, ?, ?)").unwrap();
        let populated = selection.predicate.populate(vec!["active", "pending", "review"]).unwrap();

        let expected = Predicate::Comparison {
            left: Box::new(Expr::Identifier(Identifier::Property("status".to_string()))),
            operator: ComparisonOperator::In,
            right: Box::new(Expr::ExprList(vec![
                Expr::Literal(Literal::String("active".to_string())),
                Expr::Literal(Literal::String("pending".to_string())),
                Expr::Literal(Literal::String("review".to_string())),
            ])),
        };

        assert_eq!(populated, expected);
    }

    #[test]
    fn test_populate_mixed_types() {
        let selection = parse_selection("active = ? AND score > ? AND name = ?").unwrap();
        let values: Vec<Expr> = vec![true.into(), 95.5f64.into(), "Charlie".into()];
        let populated = selection.predicate.populate(values).unwrap();

        // Collect the right-hand side of every comparison in pre-order, i.e. left to right.
        // This holds whichever way the parser associates the ANDs. Unlike a chain of `if let`s
        // without `else`, it fails instead of silently passing when the tree shape is unexpected.
        let mut collect_rhs = |mut acc: Vec<Expr>, p: &Predicate| -> Vec<Expr> {
            if let Predicate::Comparison { right, .. } = p {
                acc.push((**right).clone());
            }
            acc
        };
        let right_hand_sides = populated.walk(Vec::new(), &mut collect_rhs);

        assert_eq!(
            right_hand_sides,
            vec![
                Expr::Literal(Literal::Bool(true)),
                Expr::Literal(Literal::F64(95.5)),
                Expr::Literal(Literal::String("Charlie".to_string())),
            ]
        );
    }

    #[test]
    fn test_populate_too_few_values() {
        let selection = parse_selection("name = ? AND age = ?").unwrap();
        let result = selection.predicate.populate(vec!["Alice"]);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Not enough values"));
    }

    #[test]
    fn test_populate_too_many_values() {
        let selection = parse_selection("name = ?").unwrap();
        let result = selection.predicate.populate(vec!["Alice", "Bob"]);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Too many values"));
    }

    #[test]
    fn test_populate_no_placeholders() {
        let selection = parse_selection("name = 'Alice'").unwrap();
        let populated = selection.clone().predicate.populate(Vec::<String>::new()).unwrap();

        // Should be unchanged
        assert_eq!(populated, selection.predicate);
    }
}
473
474// From implementations for single values that wrap them in Expr::Literal
475impl From<String> for Expr {
476    fn from(s: String) -> Expr { Expr::Literal(Literal::String(s)) }
477}
478
479impl From<&str> for Expr {
480    fn from(s: &str) -> Expr { Expr::Literal(Literal::String(s.to_string())) }
481}
482
483impl From<i64> for Expr {
484    fn from(i: i64) -> Expr { Expr::Literal(Literal::I64(i)) }
485}
486
487impl From<f64> for Expr {
488    fn from(f: f64) -> Expr { Expr::Literal(Literal::F64(f)) }
489}
490
491impl From<bool> for Expr {
492    fn from(b: bool) -> Expr { Expr::Literal(Literal::Bool(b)) }
493}
494
495impl From<Literal> for Expr {
496    fn from(lit: Literal) -> Expr { Expr::Literal(lit) }
497}
498
499// These create Expr::ExprList for use in IN clauses
500impl<T> From<Vec<T>> for Expr
501where T: Into<Expr>
502{
503    fn from(vec: Vec<T>) -> Self { Expr::ExprList(vec.into_iter().map(|item| item.into()).collect()) }
504}
505
506impl<T, const N: usize> From<[T; N]> for Expr
507where T: Into<Expr>
508{
509    fn from(arr: [T; N]) -> Self { Expr::ExprList(arr.into_iter().map(|item| item.into()).collect()) }
510}
511
512impl<T> From<&[T]> for Expr
513where T: Into<Expr> + Clone
514{
515    fn from(slice: &[T]) -> Self { Expr::ExprList(slice.iter().map(|item| item.clone().into()).collect()) }
516}
517
518impl<T, const N: usize> From<&[T; N]> for Expr
519where T: Into<Expr> + Clone
520{
521    fn from(arr: &[T; N]) -> Self { Expr::ExprList(arr.iter().map(|item| item.clone().into()).collect()) }
522}