ankurah_storage_postgres/
predicate.rs

1use ankql::ast::{ComparisonOperator, Expr, Identifier, Literal, Predicate};
2use thiserror::Error;
3use tokio_postgres::types::ToSql;
4
5#[derive(Debug, Error, Clone)]
6pub enum SqlGenerationError {
7    #[error("Placeholder found in predicate - placeholders should be replaced before predicate processing")]
8    PlaceholderFound,
9    #[error("Unsupported expression type: {0}")]
10    UnsupportedExpression(&'static str),
11    #[error("Unsupported operator: {0}")]
12    UnsupportedOperator(&'static str),
13}
14
15pub enum SqlExpr {
16    Sql(String),
17    Argument(Box<dyn ToSql + Send + Sync>),
18}
19
20pub struct Sql(Vec<SqlExpr>);
21
22impl Default for Sql {
23    fn default() -> Self { Self::new() }
24}
25
26impl Sql {
27    pub fn new() -> Self { Self(Vec::new()) }
28
29    pub fn push(&mut self, expr: SqlExpr) { self.0.push(expr); }
30
31    pub fn arg(&mut self, arg: impl ToSql + Send + Sync + 'static) {
32        self.push(SqlExpr::Argument(Box::new(arg) as Box<dyn ToSql + Send + Sync>));
33    }
34
35    pub fn sql(&mut self, s: impl AsRef<str>) { self.push(SqlExpr::Sql(s.as_ref().to_owned())); }
36
37    pub fn collapse(self) -> (String, Vec<Box<dyn ToSql + Send + Sync>>) {
38        let mut counter = 1;
39        let mut sql = String::new();
40        let mut args = Vec::new();
41
42        for expr in self.0 {
43            match expr {
44                SqlExpr::Argument(arg) => {
45                    sql += &format!("${}", counter);
46                    args.push(arg);
47                    counter += 1;
48                }
49                SqlExpr::Sql(s) => {
50                    sql += &s;
51                }
52            }
53        }
54
55        (sql, args)
56    }
57
58    // --- AST flattening ---
59    pub fn expr(&mut self, expr: &Expr) -> Result<(), SqlGenerationError> {
60        match expr {
61            Expr::Placeholder => return Err(SqlGenerationError::PlaceholderFound),
62            Expr::Literal(lit) => match lit {
63                Literal::String(s) => self.arg(s.to_owned()),
64                Literal::Integer(int) => self.arg(*int),
65                Literal::Float(float) => self.arg(*float),
66                Literal::Boolean(bool) => self.arg(*bool),
67            },
68            Expr::Identifier(id) => match id {
69                Identifier::Property(name) => self.sql(format!(r#""{}""#, name)),
70                Identifier::CollectionProperty(collection, name) => {
71                    self.sql(format!(r#""{}"."{}""#, collection, name));
72                }
73            },
74            Expr::ExprList(exprs) => {
75                self.sql("(");
76                for (i, expr) in exprs.iter().enumerate() {
77                    if i > 0 {
78                        self.sql(", ");
79                    }
80                    match expr {
81                        Expr::Placeholder => return Err(SqlGenerationError::PlaceholderFound),
82                        Expr::Literal(lit) => match lit {
83                            Literal::String(s) => self.arg(s.to_owned()),
84                            Literal::Integer(int) => self.arg(*int),
85                            Literal::Float(float) => self.arg(*float),
86                            Literal::Boolean(bool) => self.arg(*bool),
87                        },
88                        _ => {
89                            return Err(SqlGenerationError::UnsupportedExpression(
90                                "Only literal expressions and placeholders are supported in IN lists",
91                            ))
92                        }
93                    }
94                }
95                self.sql(")");
96            }
97            _ => return Err(SqlGenerationError::UnsupportedExpression("Only literal, identifier, and list expressions are supported")),
98        }
99        Ok(())
100    }
101
102    pub fn comparison_op(&mut self, op: &ComparisonOperator) -> Result<(), SqlGenerationError> { Ok(self.sql(comparison_op_to_sql(op)?)) }
103
104    pub fn predicate(&mut self, predicate: &Predicate) -> Result<(), SqlGenerationError> {
105        match predicate {
106            Predicate::Comparison { left, operator, right } => {
107                self.expr(left)?;
108                self.sql(" ");
109                self.comparison_op(operator);
110                self.sql(" ");
111                self.expr(right)?;
112            }
113            Predicate::And(left, right) => {
114                self.predicate(left)?;
115                self.sql(" AND ");
116                self.predicate(right)?;
117            }
118            Predicate::Or(left, right) => {
119                self.sql("(");
120                self.predicate(left)?;
121                self.sql(" OR ");
122                self.predicate(right)?;
123                self.sql(")");
124            }
125            Predicate::Not(pred) => {
126                self.sql("NOT (");
127                self.predicate(pred)?;
128                self.sql(")");
129            }
130            Predicate::IsNull(expr) => {
131                self.expr(expr)?;
132                self.sql(" IS NULL");
133            }
134            Predicate::True => {
135                self.sql("TRUE");
136            }
137            Predicate::False => {
138                self.sql("FALSE");
139            }
140            Predicate::Placeholder => {
141                return Err(SqlGenerationError::PlaceholderFound);
142            }
143        }
144        Ok(())
145    }
146}
147
148fn comparison_op_to_sql(op: &ComparisonOperator) -> Result<&'static str, SqlGenerationError> {
149    Ok(match op {
150        ComparisonOperator::Equal => "=",
151        ComparisonOperator::NotEqual => "<>",
152        ComparisonOperator::GreaterThan => ">",
153        ComparisonOperator::GreaterThanOrEqual => ">=",
154        ComparisonOperator::LessThan => "<",
155        ComparisonOperator::LessThanOrEqual => "<=",
156        ComparisonOperator::In => "IN",
157        ComparisonOperator::Between => return Err(SqlGenerationError::UnsupportedOperator("BETWEEN operator is not yet supported")),
158    })
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use ankql::parser::parse_selection;
    use anyhow::Result;

    /// Asserts that two argument lists match, comparing them by their `Debug` rendering.
    fn assert_args(args: &[Box<dyn ToSql + Send + Sync>], expected: &[Box<dyn ToSql + Send + Sync>]) {
        // TODO: Maybe actually encoding these and comparing bytes?
        assert_eq!(format!("{:?}", args), format!("{:?}", expected));
    }

    #[test]
    fn test_simple_equality() -> Result<()> {
        let predicate = parse_selection("name = 'Alice'").unwrap();
        let mut sql = Sql::new();
        sql.predicate(&predicate)?;

        let (sql_string, args) = sql.collapse();
        assert_eq!(sql_string, r#""name" = $1"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice")];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_and_condition() -> Result<()> {
        let predicate = parse_selection("name = 'Alice' AND age = 30").unwrap();
        let mut sql = Sql::new();
        sql.predicate(&predicate)?;
        let (sql_string, args) = sql.collapse();

        assert_eq!(sql_string, r#""name" = $1 AND "age" = $2"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice"), Box::new(30)];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_or_condition() -> Result<()> {
        let predicate = parse_selection("name = 'Alice' OR name = 'Bob'").unwrap();
        let mut sql = Sql::new();
        sql.predicate(&predicate)?;
        let (sql_string, args) = sql.collapse();

        // OR is always parenthesized so it composes safely under AND / NOT.
        assert_eq!(sql_string, r#"("name" = $1 OR "name" = $2)"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice"), Box::new("Bob")];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_complex_condition() -> Result<()> {
        let predicate = parse_selection("(name = 'Alice' OR name = 'Charlie') AND age >= 30 AND age <= 40").unwrap();

        let mut sql = Sql::new();
        sql.predicate(&predicate)?;
        let (sql_string, args) = sql.collapse();

        assert_eq!(sql_string, r#"("name" = $1 OR "name" = $2) AND "age" >= $3 AND "age" <= $4"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice"), Box::new("Charlie"), Box::new(30), Box::new(40)];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_including_collection_identifier() -> Result<()> {
        let predicate = parse_selection("person.name = 'Alice'").unwrap();

        let mut sql = Sql::new();
        sql.predicate(&predicate)?;
        let (sql_string, args) = sql.collapse();

        assert_eq!(sql_string, r#""person"."name" = $1"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice")];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_true_predicate() -> Result<()> {
        let mut sql = Sql::new();
        sql.predicate(&Predicate::True)?;
        let (sql_string, args) = sql.collapse();

        assert_eq!(sql_string, "TRUE");
        assert!(args.is_empty());
        Ok(())
    }

    #[test]
    fn test_false_predicate() -> Result<()> {
        let mut sql = Sql::new();
        sql.predicate(&Predicate::False)?;
        let (sql_string, args) = sql.collapse();

        assert_eq!(sql_string, "FALSE");
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_in_operator() -> Result<()> {
        let predicate = parse_selection("name IN ('Alice', 'Bob', 'Charlie')").unwrap();
        let mut sql = Sql::new();
        sql.predicate(&predicate)?;
        let (sql_string, args) = sql.collapse();

        assert_eq!(sql_string, r#""name" IN ($1, $2, $3)"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice"), Box::new("Bob"), Box::new("Charlie")];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_placeholder_error() {
        let mut sql = Sql::new();
        let err = sql.predicate(&Predicate::Placeholder).expect_err("Expected an error");
        assert!(matches!(err, SqlGenerationError::PlaceholderFound));
    }
}