ankurah_storage_postgres/predicate.rs

1use ankql::ast::{ComparisonOperator, Expr, Identifier, Literal, Predicate};
2use tokio_postgres::types::ToSql;
3
4pub enum SqlExpr {
5    Sql(String),
6    Argument(Box<dyn ToSql + Send + Sync>),
7}
8
9pub struct Sql(Vec<SqlExpr>);
10
11impl Default for Sql {
12    fn default() -> Self { Self::new() }
13}
14
15impl Sql {
16    pub fn new() -> Self { Self(Vec::new()) }
17
18    pub fn push(&mut self, expr: SqlExpr) { self.0.push(expr); }
19
20    pub fn arg(&mut self, arg: impl ToSql + Send + Sync + 'static) {
21        self.push(SqlExpr::Argument(Box::new(arg) as Box<dyn ToSql + Send + Sync>));
22    }
23
24    pub fn sql(&mut self, s: impl AsRef<str>) { self.push(SqlExpr::Sql(s.as_ref().to_owned())); }
25
26    pub fn collapse(self) -> (String, Vec<Box<dyn ToSql + Send + Sync>>) {
27        let mut counter = 1;
28        let mut sql = String::new();
29        let mut args = Vec::new();
30
31        for expr in self.0 {
32            match expr {
33                SqlExpr::Argument(arg) => {
34                    sql += &format!("${}", counter);
35                    args.push(arg);
36                    counter += 1;
37                }
38                SqlExpr::Sql(s) => {
39                    sql += &s;
40                }
41            }
42        }
43
44        (sql, args)
45    }
46
47    // --- AST flattening ---
48    pub fn expr(&mut self, expr: &Expr) {
49        match expr {
50            Expr::Literal(lit) => match lit {
51                Literal::String(s) => self.arg(s.to_owned()),
52                Literal::Integer(int) => self.arg(*int),
53                Literal::Float(float) => self.arg(*float),
54                Literal::Boolean(bool) => self.arg(*bool),
55            },
56            Expr::Identifier(id) => match id {
57                Identifier::Property(name) => self.sql(format!(r#""{}""#, name)),
58                Identifier::CollectionProperty(collection, name) => {
59                    self.sql(format!(r#""{}"."{}""#, collection, name));
60                }
61            },
62            Expr::ExprList(exprs) => {
63                self.sql("(");
64                for (i, expr) in exprs.iter().enumerate() {
65                    if i > 0 {
66                        self.sql(", ");
67                    }
68                    match expr {
69                        Expr::Literal(lit) => match lit {
70                            Literal::String(s) => self.arg(s.to_owned()),
71                            Literal::Integer(int) => self.arg(*int),
72                            Literal::Float(float) => self.arg(*float),
73                            Literal::Boolean(bool) => self.arg(*bool),
74                        },
75                        _ => unimplemented!("Only literal expressions are supported in IN lists"),
76                    }
77                }
78                self.sql(")");
79            }
80            _ => unimplemented!("Only literal, identifier, and list expressions are supported"),
81        }
82    }
83
84    pub fn comparison_op(&mut self, op: &ComparisonOperator) { self.sql(comparison_op_to_sql(op)); }
85
86    pub fn predicate(&mut self, predicate: &Predicate) {
87        match predicate {
88            Predicate::Comparison { left, operator, right } => {
89                self.expr(left);
90                self.sql(" ");
91                self.comparison_op(operator);
92                self.sql(" ");
93                self.expr(right);
94            }
95            Predicate::And(left, right) => {
96                self.predicate(left);
97                self.sql(" AND ");
98                self.predicate(right);
99            }
100            Predicate::Or(left, right) => {
101                self.sql("(");
102                self.predicate(left);
103                self.sql(" OR ");
104                self.predicate(right);
105                self.sql(")");
106            }
107            Predicate::Not(pred) => {
108                self.sql("NOT (");
109                self.predicate(pred);
110                self.sql("NOT )");
111            }
112            Predicate::IsNull(expr) => {
113                self.expr(expr);
114                self.sql(" IS NULL");
115            }
116            Predicate::True => {
117                self.sql("TRUE");
118            }
119            Predicate::False => {
120                self.sql("FALSE");
121            }
122        }
123    }
124}
125
126fn comparison_op_to_sql(op: &ComparisonOperator) -> &'static str {
127    match op {
128        ComparisonOperator::Equal => "=",
129        ComparisonOperator::NotEqual => "<>",
130        ComparisonOperator::GreaterThan => ">",
131        ComparisonOperator::GreaterThanOrEqual => ">=",
132        ComparisonOperator::LessThan => "<",
133        ComparisonOperator::LessThanOrEqual => "<=",
134        ComparisonOperator::In => "IN",
135        ComparisonOperator::Between => unimplemented!("BETWEEN operator is not yet supported"),
136    }
137}
138
#[cfg(test)]
mod tests {
    //! Tests for predicate → parameterized SQL translation.

    use super::*;
    use ankql::parser::parse_selection;

    /// Asserts two argument lists are equal by comparing their `Debug` renderings.
    fn assert_args(args: &[Box<dyn ToSql + Send + Sync>], expected: &[Box<dyn ToSql + Send + Sync>]) {
        // TODO: Maybe actually encoding these and comparing bytes?
        assert_eq!(format!("{:?}", args), format!("{:?}", expected));
    }

    #[test]
    fn test_simple_equality() {
        let predicate = parse_selection("name = 'Alice'").unwrap();
        let mut sql = Sql::new();
        sql.predicate(&predicate);

        let (sql_string, args) = sql.collapse();
        assert_eq!(sql_string, r#""name" = $1"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice")];
        assert_args(&args, &expected);
    }

    #[test]
    fn test_and_condition() {
        let predicate = parse_selection("name = 'Alice' AND age = 30").unwrap();
        let mut sql = Sql::new();
        sql.predicate(&predicate);
        let (sql_string, args) = sql.collapse();

        assert_eq!(sql_string, r#""name" = $1 AND "age" = $2"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice"), Box::new(30)];
        assert_args(&args, &expected);
    }

    #[test]
    fn test_complex_condition() {
        let predicate = parse_selection("(name = 'Alice' OR name = 'Charlie') AND age >= 30 AND age <= 40").unwrap();

        let mut sql = Sql::new();
        sql.predicate(&predicate);
        let (sql_string, args) = sql.collapse();

        assert_eq!(sql_string, r#"("name" = $1 OR "name" = $2) AND "age" >= $3 AND "age" <= $4"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice"), Box::new("Charlie"), Box::new(30), Box::new(40)];
        assert_args(&args, &expected);
    }

    #[test]
    fn test_including_collection_identifier() {
        let predicate = parse_selection("person.name = 'Alice'").unwrap();

        let mut sql = Sql::new();
        sql.predicate(&predicate);
        let (sql_string, args) = sql.collapse();

        assert_eq!(sql_string, r#""person"."name" = $1"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice")];
        assert_args(&args, &expected);
    }

    #[test]
    fn test_false_predicate() {
        let mut sql = Sql::new();
        sql.predicate(&Predicate::False);
        let (sql_string, args) = sql.collapse();

        assert_eq!(sql_string, "FALSE");
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![];
        assert_args(&args, &expected);
    }

    #[test]
    fn test_in_operator() {
        let predicate = parse_selection("name IN ('Alice', 'Bob', 'Charlie')").unwrap();
        let mut sql = Sql::new();
        sql.predicate(&predicate);
        let (sql_string, args) = sql.collapse();

        assert_eq!(sql_string, r#""name" IN ($1, $2, $3)"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice"), Box::new("Bob"), Box::new("Charlie")];
        assert_args(&args, &expected);
    }
}