Skip to main content

ankql/
ast.rs

1use crate::error::ParseError;
2use crate::selection::sql::generate_selection_sql;
3use serde::{Deserialize, Serialize};
4use ulid::Ulid;
5
6/// Custom serialization for serde_json::Value that stores as bytes.
7/// This is needed because bincode doesn't support deserialize_any.
8mod json_as_bytes {
9    use serde::{Deserialize, Deserializer, Serialize, Serializer};
10
11    pub fn serialize<S>(value: &serde_json::Value, serializer: S) -> Result<S::Ok, S::Error>
12    where S: Serializer {
13        let bytes = serde_json::to_vec(value).map_err(serde::ser::Error::custom)?;
14        bytes.serialize(serializer)
15    }
16
17    pub fn deserialize<'de, D>(deserializer: D) -> Result<serde_json::Value, D::Error>
18    where D: Deserializer<'de> {
19        let bytes: Vec<u8> = Vec::deserialize(deserializer)?;
20        serde_json::from_slice(&bytes).map_err(serde::de::Error::custom)
21    }
22}
23
24#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
25pub enum Expr {
26    Literal(Literal),
27    Path(PathExpr),
28    Predicate(Predicate),
29    InfixExpr { left: Box<Expr>, operator: InfixOperator, right: Box<Expr> },
30    ExprList(Vec<Expr>), // New variant for handling lists like (1,2,3) in IN clauses
31    Placeholder,
32}
33
34#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
35pub enum Literal {
36    I16(i16),
37    I32(i32),
38    I64(i64),
39    F64(f64),
40    Bool(bool),
41    String(String),
42    EntityId(Ulid),
43    Object(Vec<u8>),
44    Binary(Vec<u8>),
45    /// JSON value - used for comparisons against JSON property subfields.
46    /// Not parsed from query syntax; created by AST preparation pass.
47    /// Serialized as bytes for bincode compatibility.
48    #[serde(with = "json_as_bytes")]
49    Json(serde_json::Value),
50}
51
52/// A dot-separated path like `name` or `licensing.territory`
53#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
54pub struct PathExpr {
55    pub steps: Vec<String>,
56}
57
58impl PathExpr {
59    /// Create a single-step path
60    pub fn simple(name: impl Into<String>) -> Self { Self { steps: vec![name.into()] } }
61
62    /// Check if this is a single-step path
63    pub fn is_simple(&self) -> bool { self.steps.len() == 1 }
64
65    /// Get the first step (always exists)
66    pub fn first(&self) -> &str { &self.steps[0] }
67
68    /// Get the property name (last step)
69    pub fn property(&self) -> &str { self.steps.last().expect("PathExpr must have at least one step") }
70}
71
72impl std::fmt::Display for PathExpr {
73    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.steps.join(".")) }
74}
75
76#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
77pub struct Selection {
78    pub predicate: Predicate,
79    pub order_by: Option<Vec<OrderByItem>>,
80    pub limit: Option<u64>,
81}
82
83#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
84pub struct OrderByItem {
85    pub path: PathExpr,
86    pub direction: OrderDirection,
87}
88
89impl std::fmt::Display for OrderByItem {
90    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91        write!(
92            f,
93            "{} {}",
94            self.path,
95            match self.direction {
96                OrderDirection::Asc => "ASC",
97                OrderDirection::Desc => "DESC",
98            }
99        )
100    }
101}
102
103#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
104pub enum OrderDirection {
105    Asc,
106    Desc,
107}
108
109impl std::fmt::Display for Selection {
110    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
111        write!(f, "{}", self.predicate)?;
112        if let Some(order_by) = &self.order_by {
113            write!(f, " ORDER BY ")?;
114            for (i, item) in order_by.iter().enumerate() {
115                if i > 0 {
116                    write!(f, ", ")?;
117                }
118                write!(f, "{}", item)?;
119            }
120        }
121        if let Some(limit) = self.limit {
122            write!(f, " LIMIT {}", limit)?;
123        }
124        Ok(())
125    }
126}
127
128// Backward compatibility
129impl From<Predicate> for Selection {
130    fn from(predicate: Predicate) -> Self { Selection { predicate, order_by: None, limit: None } }
131}
132
133#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
134pub enum Predicate {
135    Comparison { left: Box<Expr>, operator: ComparisonOperator, right: Box<Expr> },
136    IsNull(Box<Expr>),
137    And(Box<Predicate>, Box<Predicate>),
138    Or(Box<Predicate>, Box<Predicate>),
139    Not(Box<Predicate>),
140    True,
141    False,
142    Placeholder,
143}
144
145impl std::fmt::Display for Predicate {
146    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147        match generate_selection_sql(self, None) {
148            Ok(sql) => write!(f, "{}", sql),
149            Err(e) => write!(f, "SQL Error: {}", e),
150        }
151    }
152}
153
154impl Selection {
155    /// Transform the selection to assume the given columns are NULL.
156    /// This filters out ORDER BY items that reference missing columns.
157    pub fn assume_null(&self, columns: &[String]) -> Self {
158        let order_by = self.order_by.as_ref().map(|items| {
159            items
160                .iter()
161                .filter(|item| {
162                    // Use the property name (last step) for column matching
163                    let col_name = item.path.property();
164                    !columns.contains(&col_name.to_string())
165                })
166                .cloned()
167                .collect::<Vec<_>>()
168        });
169        // If all ORDER BY items were filtered out, set to None
170        let order_by = order_by.and_then(|v| if v.is_empty() { None } else { Some(v) });
171
172        Self { predicate: self.predicate.assume_null(columns), order_by, limit: self.limit }
173    }
174
175    /// Collect all column names referenced in this selection (WHERE + ORDER BY).
176    /// For JSON paths like `licensing.territory`, returns the column name (`licensing`),
177    /// not the JSON path step (`territory`).
178    pub fn referenced_columns(&self) -> Vec<String> {
179        let mut columns = self.predicate.referenced_columns();
180        if let Some(order_by) = &self.order_by {
181            for item in order_by {
182                // Use first step for column reference (actual PostgreSQL column name)
183                let col = item.path.first().to_string();
184                if !columns.contains(&col) {
185                    columns.push(col);
186                }
187            }
188        }
189        columns
190    }
191}
192
193impl Predicate {
194    /// Recursively walk a predicate tree and accumulate results using a closure
195    pub fn walk<T, F>(&self, accumulator: T, visitor: &mut F) -> T
196    where F: FnMut(T, &Predicate) -> T {
197        let accumulator = visitor(accumulator, self);
198        match self {
199            Predicate::And(left, right) | Predicate::Or(left, right) => {
200                let accumulator = left.walk(accumulator, visitor);
201                right.walk(accumulator, visitor)
202            }
203            Predicate::Not(inner) => inner.walk(accumulator, visitor),
204            _ => accumulator,
205        }
206    }
207
208    /// Collect all column names referenced in this predicate.
209    /// For JSON paths like `licensing.territory`, returns the column name (`licensing`),
210    /// not the JSON path step (`territory`).
211    pub fn referenced_columns(&self) -> Vec<String> {
212        self.walk(Vec::new(), &mut |mut cols, pred| {
213            match pred {
214                Predicate::Comparison { left, right, .. } => {
215                    for expr in [&**left, &**right] {
216                        if let Expr::Path(path) = expr {
217                            // Use first step for column reference (actual PostgreSQL column name)
218                            // For `licensing.territory`, the column is `licensing`
219                            let col = path.first().to_string();
220                            if !cols.contains(&col) {
221                                cols.push(col);
222                            }
223                        }
224                    }
225                }
226                Predicate::IsNull(expr) => {
227                    if let Expr::Path(path) = &**expr {
228                        let col = path.first().to_string();
229                        if !cols.contains(&col) {
230                            cols.push(col);
231                        }
232                    }
233                }
234                _ => {}
235            }
236            cols
237        })
238    }
239
240    /// Clones the predicate tree and evaluates comparisons involving missing columns as if they were NULL
241    pub fn assume_null(&self, columns: &[String]) -> Self {
242        match self {
243            Predicate::Comparison { left, operator, right } => {
244                // Check if either side is a path whose property is in our null list
245                let has_null_path = match (&**left, &**right) {
246                    (Expr::Path(path), _) | (_, Expr::Path(path)) => columns.contains(&path.property().to_string()),
247                    _ => false,
248                };
249
250                if has_null_path {
251                    match operator {
252                        // NULL = anything is false
253                        ComparisonOperator::Equal => Predicate::False,
254                        // NULL != anything is false (NULL comparisons always return NULL in SQL)
255                        ComparisonOperator::NotEqual => Predicate::False,
256                        // NULL > anything is false
257                        ComparisonOperator::GreaterThan => Predicate::False,
258                        // NULL >= anything is false
259                        ComparisonOperator::GreaterThanOrEqual => Predicate::False,
260                        // NULL < anything is false
261                        ComparisonOperator::LessThan => Predicate::False,
262                        // NULL <= anything is false
263                        ComparisonOperator::LessThanOrEqual => Predicate::False,
264                        // NULL IN (...) is false
265                        ComparisonOperator::In => Predicate::False,
266                        // NULL BETWEEN ... is false
267                        ComparisonOperator::Between => Predicate::False,
268                    }
269                } else {
270                    // No NULL paths, keep the comparison as is
271                    Predicate::Comparison { left: left.clone(), operator: operator.clone(), right: right.clone() }
272                }
273            }
274            Predicate::IsNull(expr) => {
275                // If we're explicitly checking for NULL and the path's property is in our null list,
276                // then this evaluates to true
277                match &**expr {
278                    Expr::Path(path) => {
279                        let is_null = columns.contains(&path.property().to_string());
280                        if is_null {
281                            Predicate::True
282                        } else {
283                            Predicate::IsNull(expr.clone())
284                        }
285                    }
286                    _ => Predicate::IsNull(expr.clone()),
287                }
288            }
289            Predicate::And(left, right) => {
290                let left = left.assume_null(columns);
291                let right = right.assume_null(columns);
292
293                // Optimize
294                match (&left, &right) {
295                    // if either side is false, the whole thing is false
296                    (Predicate::False, _) | (_, Predicate::False) => Predicate::False,
297                    // if both sides are true, the whole thing is true
298                    (Predicate::True, Predicate::True) => Predicate::True,
299                    // if one side is true, the whole thing is the other side
300                    (Predicate::True, p) | (p, Predicate::True) => p.clone(),
301                    _ => Predicate::And(Box::new(left), Box::new(right)),
302                }
303            }
304            Predicate::Or(left, right) => {
305                let left = left.assume_null(columns);
306                let right = right.assume_null(columns);
307
308                // Optimize
309                match (&left, &right) {
310                    // if either side is true, the whole thing is true
311                    (Predicate::True, _) | (_, Predicate::True) => Predicate::True,
312                    // if both sides are false, the whole thing is false
313                    (Predicate::False, Predicate::False) => Predicate::False,
314                    // if one side is false, the whole thing is the other side
315                    (Predicate::False, p) | (p, Predicate::False) => p.clone(),
316                    // otherwise, keep the original
317                    _ => Predicate::Or(Box::new(left), Box::new(right)),
318                }
319            }
320            Predicate::Not(pred) => {
321                let inner = pred.assume_null(columns);
322                match inner {
323                    Predicate::True => Predicate::False,
324                    Predicate::False => Predicate::True,
325                    _ => Predicate::Not(Box::new(inner)),
326                }
327            }
328            // These are constants, just clone them
329            Predicate::True => Predicate::True,
330            Predicate::False => Predicate::False,
331            Predicate::Placeholder => Predicate::Placeholder,
332        }
333    }
334
335    /// Populate placeholders in the predicate with actual values
336    pub fn populate<I, V, E>(self, values: I) -> Result<Predicate, ParseError>
337    where
338        I: IntoIterator<Item = V>,
339        V: TryInto<Expr, Error = E>,
340        E: Into<ParseError>,
341    {
342        let mut values_iter = values.into_iter();
343        let result = self.populate_recursive(&mut values_iter)?;
344
345        // Check if there are any unused values
346        if values_iter.next().is_some() {
347            return Err(ParseError::InvalidPredicate("Too many values provided for placeholders".to_string()));
348        }
349
350        Ok(result)
351    }
352
353    fn populate_recursive<I, V, E>(self, values: &mut I) -> Result<Predicate, ParseError>
354    where
355        I: Iterator<Item = V>,
356        V: TryInto<Expr, Error = E>,
357        E: Into<ParseError>,
358    {
359        match self {
360            Predicate::Comparison { left, operator, right } => Ok(Predicate::Comparison {
361                left: Box::new(left.populate_recursive(values)?),
362                operator,
363                right: Box::new(right.populate_recursive(values)?),
364            }),
365            Predicate::And(left, right) => {
366                Ok(Predicate::And(Box::new(left.populate_recursive(values)?), Box::new(right.populate_recursive(values)?)))
367            }
368            Predicate::Or(left, right) => {
369                Ok(Predicate::Or(Box::new(left.populate_recursive(values)?), Box::new(right.populate_recursive(values)?)))
370            }
371            Predicate::Not(pred) => Ok(Predicate::Not(Box::new(pred.populate_recursive(values)?))),
372            Predicate::IsNull(expr) => Ok(Predicate::IsNull(Box::new(expr.populate_recursive(values)?))),
373            Predicate::True => Ok(Predicate::True),
374            Predicate::False => Ok(Predicate::False),
375            // Placeholder should be transformed to a comparison before population
376            Predicate::Placeholder => Err(ParseError::InvalidPredicate("Placeholder must be transformed before population".to_string())),
377        }
378    }
379}
380
381impl Expr {
382    fn populate_recursive<I, V, E>(self, values: &mut I) -> Result<Expr, ParseError>
383    where
384        I: Iterator<Item = V>,
385        V: TryInto<Expr, Error = E>,
386        E: Into<ParseError>,
387    {
388        match self {
389            Expr::Placeholder => match values.next() {
390                Some(value) => Ok(value.try_into().map_err(|e| e.into())?),
391                None => Err(ParseError::InvalidPredicate("Not enough values provided for placeholders".to_string())),
392            },
393            Expr::Literal(lit) => Ok(Expr::Literal(lit)),
394            Expr::Path(path) => Ok(Expr::Path(path)),
395            Expr::Predicate(pred) => Ok(Expr::Predicate(pred.populate_recursive(values)?)),
396            Expr::InfixExpr { left, operator, right } => Ok(Expr::InfixExpr {
397                left: Box::new(left.populate_recursive(values)?),
398                operator,
399                right: Box::new(right.populate_recursive(values)?),
400            }),
401            Expr::ExprList(exprs) => {
402                let mut populated_exprs = Vec::new();
403                for expr in exprs {
404                    populated_exprs.push(expr.populate_recursive(values)?);
405                }
406                Ok(Expr::ExprList(populated_exprs))
407            }
408        }
409    }
410}
411
412#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
413pub enum ComparisonOperator {
414    Equal,              // =
415    NotEqual,           // <> or !=
416    GreaterThan,        // >
417    GreaterThanOrEqual, // >=
418    LessThan,           // <
419    LessThanOrEqual,    // <=
420    In,                 // IN
421    Between,            // BETWEEN
422}
423
424#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
425pub enum InfixOperator {
426    Add,
427    Subtract,
428    Multiply,
429    Divide,
430}
431
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_selection;

    /// Parse `input`, apply `assume_null` for `null_columns`, and render the result as SQL.
    fn nullify_columns(input: &str, null_columns: &[&str]) -> Result<String, ParseError> {
        let selection = parse_selection(input)?;
        let result = selection.predicate.assume_null(&null_columns.iter().map(|s| s.to_string()).collect::<Vec<_>>());
        generate_selection_sql(&result, None).map_err(|_| ParseError::InvalidPredicate("SQL generation failed".to_string()))
    }

    #[test]
    fn test_single_comparison_null_handling() {
        assert_eq!(nullify_columns("status = 'active'", &["status"]).unwrap(), "FALSE");
        assert_eq!(nullify_columns("age > 30", &["age"]).unwrap(), "FALSE");
        assert_eq!(nullify_columns("count >= 100", &["count"]).unwrap(), "FALSE");
        assert_eq!(nullify_columns("name < 'Z'", &["name"]).unwrap(), "FALSE");
        assert_eq!(nullify_columns("score <= 90", &["score"]).unwrap(), "FALSE");
        assert_eq!(nullify_columns("status IS NULL", &["status"]).unwrap(), "TRUE");
        assert_eq!(nullify_columns("role = 'admin'", &["other"]).unwrap(), r#""role" = 'admin'"#);
    }

    #[test]
    fn nested_predicate_null_handling() {
        let input = "alpha = 1 AND (beta = 2 OR charlie = 3)";
        assert_eq!(nullify_columns(input, &["charlie"]).unwrap(), r#""alpha" = 1 AND "beta" = 2"#);
        assert_eq!(nullify_columns(input, &["beta", "charlie"]).unwrap(), r#"FALSE"#);
        assert_eq!(nullify_columns(input, &["alpha"]).unwrap(), r#"FALSE"#);
        assert_eq!(nullify_columns(input, &["other"]).unwrap(), r#""alpha" = 1 AND ("beta" = 2 OR "charlie" = 3)"#);
    }

    #[test]
    fn test_populate_single_placeholder() {
        let selection = parse_selection("name = ?").unwrap();
        let populated = selection.predicate.populate(vec!["Alice"]).unwrap();

        let expected = Predicate::Comparison {
            left: Box::new(Expr::Path(PathExpr::simple("name".to_string()))),
            operator: ComparisonOperator::Equal,
            right: Box::new(Expr::Literal(Literal::String("Alice".to_string()))),
        };

        assert_eq!(populated, expected);
    }

    #[test]
    fn test_populate_multiple_placeholders() {
        let selection = parse_selection("age > ? AND name = ?").unwrap();
        let values: Vec<Expr> = vec![25i64.into(), "Bob".into()];
        let populated = selection.predicate.populate(values).unwrap();

        let expected = Predicate::And(
            Box::new(Predicate::Comparison {
                left: Box::new(Expr::Path(PathExpr::simple("age".to_string()))),
                operator: ComparisonOperator::GreaterThan,
                right: Box::new(Expr::Literal(Literal::I64(25))),
            }),
            Box::new(Predicate::Comparison {
                left: Box::new(Expr::Path(PathExpr::simple("name".to_string()))),
                operator: ComparisonOperator::Equal,
                right: Box::new(Expr::Literal(Literal::String("Bob".to_string()))),
            }),
        );

        assert_eq!(populated, expected);
    }

    #[test]
    fn test_populate_in_clause() {
        let selection = parse_selection("status IN (?, ?, ?)").unwrap();
        let populated = selection.predicate.populate(vec!["active", "pending", "review"]).unwrap();

        let expected = Predicate::Comparison {
            left: Box::new(Expr::Path(PathExpr::simple("status".to_string()))),
            operator: ComparisonOperator::In,
            right: Box::new(Expr::ExprList(vec![
                Expr::Literal(Literal::String("active".to_string())),
                Expr::Literal(Literal::String("pending".to_string())),
                Expr::Literal(Literal::String("review".to_string())),
            ])),
        };

        assert_eq!(populated, expected);
    }

    #[test]
    fn test_populate_mixed_types() {
        let selection = parse_selection("active = ? AND score > ? AND name = ?").unwrap();
        let values: Vec<Expr> = vec![true.into(), 95.5f64.into(), "Charlie".into()];
        let populated = selection.predicate.populate(values).unwrap();

        // Collect the right-hand operand of every comparison in left-to-right order.
        // Unlike nested `if let`s, this cannot pass vacuously, and it does not depend on
        // how the parser associates the chained ANDs.
        let right_operands = populated.walk(Vec::new(), &mut |mut acc: Vec<Expr>, pred: &Predicate| {
            if let Predicate::Comparison { right, .. } = pred {
                acc.push((**right).clone());
            }
            acc
        });

        assert_eq!(
            right_operands,
            vec![
                Expr::Literal(Literal::Bool(true)),
                Expr::Literal(Literal::F64(95.5)),
                Expr::Literal(Literal::String("Charlie".to_string())),
            ]
        );
    }

    #[test]
    fn test_populate_too_few_values() {
        let selection = parse_selection("name = ? AND age = ?").unwrap();
        let result = selection.predicate.populate(vec!["Alice"]);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Not enough values"));
    }

    #[test]
    fn test_populate_too_many_values() {
        let selection = parse_selection("name = ?").unwrap();
        let result = selection.predicate.populate(vec!["Alice", "Bob"]);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Too many values"));
    }

    #[test]
    fn test_populate_no_placeholders() {
        let selection = parse_selection("name = 'Alice'").unwrap();
        let populated = selection.clone().predicate.populate(Vec::<String>::new()).unwrap();

        // Should be unchanged
        assert_eq!(populated, selection.predicate);
    }
}
569
570// From implementations for single values that wrap them in Expr::Literal
571impl From<String> for Expr {
572    fn from(s: String) -> Expr { Expr::Literal(Literal::String(s)) }
573}
574
575impl From<&str> for Expr {
576    fn from(s: &str) -> Expr { Expr::Literal(Literal::String(s.to_string())) }
577}
578
579impl From<i64> for Expr {
580    fn from(i: i64) -> Expr { Expr::Literal(Literal::I64(i)) }
581}
582
583impl From<f64> for Expr {
584    fn from(f: f64) -> Expr { Expr::Literal(Literal::F64(f)) }
585}
586
587impl From<bool> for Expr {
588    fn from(b: bool) -> Expr { Expr::Literal(Literal::Bool(b)) }
589}
590
591impl From<Literal> for Expr {
592    fn from(lit: Literal) -> Expr { Expr::Literal(lit) }
593}
594
595// These create Expr::ExprList for use in IN clauses
596impl<T> From<Vec<T>> for Expr
597where T: Into<Expr>
598{
599    fn from(vec: Vec<T>) -> Self { Expr::ExprList(vec.into_iter().map(|item| item.into()).collect()) }
600}
601
602impl<T, const N: usize> From<[T; N]> for Expr
603where T: Into<Expr>
604{
605    fn from(arr: [T; N]) -> Self { Expr::ExprList(arr.into_iter().map(|item| item.into()).collect()) }
606}
607
608impl<T> From<&[T]> for Expr
609where T: Into<Expr> + Clone
610{
611    fn from(slice: &[T]) -> Self { Expr::ExprList(slice.iter().map(|item| item.clone().into()).collect()) }
612}
613
614impl<T, const N: usize> From<&[T; N]> for Expr
615where T: Into<Expr> + Clone
616{
617    fn from(arr: &[T; N]) -> Self { Expr::ExprList(arr.iter().map(|item| item.clone().into()).collect()) }
618}