// ankurah_storage_postgres/predicate.rs

use std::fmt::Write;

use ankql::ast::{ComparisonOperator, Expr, Identifier, Literal, Predicate};
use tokio_postgres::types::ToSql;

4pub enum SqlExpr {
5    Sql(String),
6    Argument(Box<dyn ToSql + Send + Sync>),
7}
8
9pub struct Sql(Vec<SqlExpr>);
10
11impl Sql {
12    pub fn new() -> Self { Self(Vec::new()) }
13
14    pub fn push(&mut self, expr: SqlExpr) { self.0.push(expr); }
15
16    pub fn arg(&mut self, arg: impl ToSql + Send + Sync + 'static) {
17        self.push(SqlExpr::Argument(Box::new(arg) as Box<dyn ToSql + Send + Sync>));
18    }
19
20    pub fn sql(&mut self, s: impl AsRef<str>) { self.push(SqlExpr::Sql(s.as_ref().to_owned())); }
21
22    pub fn collapse(self) -> (String, Vec<Box<dyn ToSql + Send + Sync>>) {
23        let mut counter = 1;
24        let mut sql = String::new();
25        let mut args = Vec::new();
26
27        for expr in self.0 {
28            match expr {
29                SqlExpr::Argument(arg) => {
30                    sql += &format!("${}", counter);
31                    args.push(arg);
32                    counter += 1;
33                }
34                SqlExpr::Sql(s) => {
35                    sql += &s;
36                }
37            }
38        }
39
40        (sql, args)
41    }
42
43    // --- AST flattening ---
44    pub fn expr(&mut self, expr: &Expr) {
45        match expr {
46            Expr::Literal(lit) => match lit {
47                Literal::String(s) => self.arg(s.to_owned()),
48                Literal::Integer(int) => self.arg(*int),
49                Literal::Float(float) => self.arg(*float),
50                Literal::Boolean(bool) => self.arg(*bool),
51            },
52            Expr::Identifier(id) => match id {
53                Identifier::Property(name) => self.sql(format!(r#""{}""#, name)),
54                Identifier::CollectionProperty(collection, name) => {
55                    self.sql(format!(r#""{}"."{}""#, collection, name));
56                }
57            },
58            _ => unimplemented!("Only literal and identifier expressions are supported"),
59        }
60    }
61
62    pub fn comparison_op(&mut self, op: &ComparisonOperator) { self.sql(comparison_op_to_sql(op)); }
63
64    pub fn predicate(&mut self, predicate: &Predicate) {
65        match predicate {
66            Predicate::Comparison { left, operator, right } => {
67                self.expr(left);
68                self.sql(" ");
69                self.comparison_op(operator);
70                self.sql(" ");
71                self.expr(right);
72            }
73            Predicate::And(left, right) => {
74                self.predicate(left);
75                self.sql(" AND ");
76                self.predicate(right);
77            }
78            Predicate::Or(left, right) => {
79                self.sql("(");
80                self.predicate(left);
81                self.sql(" OR ");
82                self.predicate(right);
83                self.sql(")");
84            }
85            Predicate::Not(pred) => {
86                self.sql("NOT (");
87                self.predicate(pred);
88                self.sql("NOT )");
89            }
90            Predicate::IsNull(expr) => {
91                self.expr(expr);
92                self.sql(" IS NULL");
93            }
94            Predicate::True => {
95                self.sql("TRUE");
96            }
97            Predicate::False => {
98                self.sql("FALSE");
99            }
100        }
101    }
102}
103
104fn comparison_op_to_sql(op: &ComparisonOperator) -> &'static str {
105    match op {
106        ComparisonOperator::Equal => "=",
107        ComparisonOperator::NotEqual => "<>",
108        ComparisonOperator::GreaterThan => ">",
109        ComparisonOperator::GreaterThanOrEqual => ">=",
110        ComparisonOperator::LessThan => "<",
111        ComparisonOperator::LessThanOrEqual => "<=",
112        _ => unimplemented!("Only basic comparison operators are supported"),
113    }
114}
115
#[cfg(test)]
mod tests {
    use super::*;
    use ankql::parser::parse_selection;

    /// Compares argument lists by their `Debug` rendering, since boxed `ToSql`
    /// values have no `PartialEq`.
    fn assert_args(args: &[Box<dyn ToSql + Send + Sync>], expected: &[Box<dyn ToSql + Send + Sync>]) {
        // TODO: Maybe actually encoding these and comparing bytes?
        assert_eq!(format!("{:?}", args), format!("{:?}", expected));
    }

    #[test]
    fn test_simple_equality() {
        let predicate = parse_selection("name = 'Alice'").unwrap();
        let mut sql = Sql::new();
        sql.predicate(&predicate);

        let (sql_string, args) = sql.collapse();
        assert_eq!(sql_string, r#""name" = $1"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice")];
        assert_args(&args, &expected);
    }

    #[test]
    fn test_and_condition() {
        let predicate = parse_selection("name = 'Alice' AND age = 30").unwrap();
        let mut sql = Sql::new();
        sql.predicate(&predicate);
        let (sql_string, args) = sql.collapse();

        assert_eq!(sql_string, r#""name" = $1 AND "age" = $2"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice"), Box::new(30)];
        assert_args(&args, &expected);
    }

    #[test]
    fn test_complex_condition() {
        let predicate = parse_selection("(name = 'Alice' OR name = 'Charlie') AND age >= 30 AND age <= 40").unwrap();

        let mut sql = Sql::new();
        sql.predicate(&predicate);
        let (sql_string, args) = sql.collapse();

        assert_eq!(sql_string, r#"("name" = $1 OR "name" = $2) AND "age" >= $3 AND "age" <= $4"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice"), Box::new("Charlie"), Box::new(30), Box::new(40)];
        assert_args(&args, &expected);
    }

    #[test]
    fn test_including_collection_identifier() {
        let predicate = parse_selection("person.name = 'Alice'").unwrap();

        let mut sql = Sql::new();
        sql.predicate(&predicate);
        let (sql_string, args) = sql.collapse();

        assert_eq!(sql_string, r#""person"."name" = $1"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice")];
        assert_args(&args, &expected);
    }

    #[test]
    fn test_false_predicate() {
        let mut sql = Sql::new();
        sql.predicate(&Predicate::False);
        let (sql_string, args) = sql.collapse();

        assert_eq!(sql_string, "FALSE");
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![];
        assert_args(&args, &expected);
    }

    #[test]
    fn test_true_predicate() {
        let mut sql = Sql::new();
        sql.predicate(&Predicate::True);
        let (sql_string, args) = sql.collapse();

        assert_eq!(sql_string, "TRUE");
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![];
        assert_args(&args, &expected);
    }
}