ankurah_storage_postgres/
sql_builder.rs

1use ankql::ast::{ComparisonOperator, Expr, Identifier, Literal, OrderByItem, OrderDirection, Predicate, Selection};
2use ankurah_core::{error::RetrievalError, EntityId};
3use thiserror::Error;
4use tokio_postgres::types::ToSql;
5
6#[derive(Debug, Error, Clone)]
7pub enum SqlGenerationError {
8    #[error("Placeholder found in predicate - placeholders should be replaced before predicate processing")]
9    PlaceholderFound,
10    #[error("Unsupported expression type: {0}")]
11    UnsupportedExpression(&'static str),
12    #[error("Unsupported operator: {0}")]
13    UnsupportedOperator(&'static str),
14    #[error("SqlBuilder requires both fields and table_name to be set for complete SELECT generation, or neither for WHERE-only mode")]
15    IncompleteConfiguration,
16}
17
18impl From<SqlGenerationError> for RetrievalError {
19    fn from(err: SqlGenerationError) -> Self { RetrievalError::StorageError(Box::new(err)) }
20}
21
22pub enum SqlExpr {
23    Sql(String),
24    Argument(Box<dyn ToSql + Send + Sync>),
25}
26
27pub struct SqlBuilder {
28    expressions: Vec<SqlExpr>,
29    fields: Vec<String>,
30    table_name: Option<String>,
31}
32
33impl Default for SqlBuilder {
34    fn default() -> Self { Self::new() }
35}
36
37impl SqlBuilder {
38    pub fn new() -> Self { Self { expressions: Vec::new(), fields: Vec::new(), table_name: None } }
39
40    pub fn with_fields<T: Into<String>>(fields: Vec<T>) -> Self {
41        Self { expressions: Vec::new(), fields: fields.into_iter().map(|f| f.into()).collect(), table_name: None }
42    }
43
44    pub fn table_name(&mut self, name: impl Into<String>) -> &mut Self {
45        self.table_name = Some(name.into());
46        self
47    }
48
49    pub fn push(&mut self, expr: SqlExpr) { self.expressions.push(expr); }
50
51    pub fn arg(&mut self, arg: impl ToSql + Send + Sync + 'static) {
52        self.push(SqlExpr::Argument(Box::new(arg) as Box<dyn ToSql + Send + Sync>));
53    }
54
55    pub fn sql(&mut self, s: impl AsRef<str>) { self.push(SqlExpr::Sql(s.as_ref().to_owned())); }
56
57    pub fn build(self) -> Result<(String, Vec<Box<dyn ToSql + Send + Sync>>), SqlGenerationError> {
58        let mut counter = 1;
59        let mut where_clause = String::new();
60        let mut args = Vec::new();
61
62        // Build WHERE clause from expressions
63        for expr in self.expressions {
64            match expr {
65                SqlExpr::Argument(arg) => {
66                    where_clause += &format!("${}", counter);
67                    args.push(arg);
68                    counter += 1;
69                }
70                SqlExpr::Sql(s) => {
71                    where_clause += &s;
72                }
73            }
74        }
75
76        // Build complete SELECT statement - fields and table are required
77        if self.fields.is_empty() || self.table_name.is_none() {
78            return Err(SqlGenerationError::IncompleteConfiguration);
79        }
80
81        let fields_clause = self.fields.iter().map(|field| format!(r#""{}""#, field.replace('"', "\"\""))).collect::<Vec<_>>().join(", ");
82        let table = self.table_name.unwrap();
83        let sql = format!(r#"SELECT {} FROM "{}" WHERE {}"#, fields_clause, table.replace('"', "\"\""), where_clause);
84
85        Ok((sql, args))
86    }
87
88    pub fn build_where_clause(self) -> (String, Vec<Box<dyn ToSql + Send + Sync>>) {
89        let mut counter = 1;
90        let mut where_clause = String::new();
91        let mut args = Vec::new();
92
93        // Build WHERE clause from expressions
94        for expr in self.expressions {
95            match expr {
96                SqlExpr::Argument(arg) => {
97                    where_clause += &format!("${}", counter);
98                    args.push(arg);
99                    counter += 1;
100                }
101                SqlExpr::Sql(s) => {
102                    where_clause += &s;
103                }
104            }
105        }
106
107        (where_clause, args)
108    }
109
110    // --- AST flattening ---
111    pub fn expr(&mut self, expr: &Expr) -> Result<(), SqlGenerationError> {
112        match expr {
113            Expr::Placeholder => return Err(SqlGenerationError::PlaceholderFound),
114            Expr::Literal(lit) => match lit {
115                Literal::String(s) => self.arg(s.to_owned()),
116                Literal::I64(int) => self.arg(*int),
117                Literal::F64(float) => self.arg(*float),
118                Literal::Bool(bool) => self.arg(*bool),
119                Literal::I16(i) => self.arg(*i),
120                Literal::I32(i) => self.arg(*i),
121                Literal::EntityId(ulid) => self.arg(EntityId::from_ulid(*ulid).to_base64()),
122                Literal::Object(bytes) => self.arg(bytes.clone()),
123                Literal::Binary(bytes) => self.arg(bytes.clone()),
124            },
125            Expr::Identifier(id) => match id {
126                Identifier::Property(name) => {
127                    // Escape any existing quotes in the property name by doubling them
128                    let escaped_name = name.replace('"', "\"\"");
129                    self.sql(format!(r#""{}""#, escaped_name));
130                }
131                Identifier::CollectionProperty(collection, name) => {
132                    // Escape quotes in both collection and property names
133                    let escaped_collection = collection.replace('"', "\"\"");
134                    let escaped_name = name.replace('"', "\"\"");
135                    self.sql(format!(r#""{}"."{}""#, escaped_collection, escaped_name));
136                }
137            },
138            Expr::ExprList(exprs) => {
139                self.sql("(");
140                for (i, expr) in exprs.iter().enumerate() {
141                    if i > 0 {
142                        self.sql(", ");
143                    }
144                    match expr {
145                        Expr::Placeholder => return Err(SqlGenerationError::PlaceholderFound),
146                        Expr::Literal(lit) => match lit {
147                            Literal::String(s) => self.arg(s.to_owned()),
148                            Literal::I64(int) => self.arg(*int),
149                            Literal::F64(float) => self.arg(*float),
150                            Literal::Bool(bool) => self.arg(*bool),
151                            Literal::I16(i) => self.arg(*i),
152                            Literal::I32(i) => self.arg(*i),
153                            Literal::EntityId(ulid) => self.arg(EntityId::from_ulid(*ulid).to_base64()),
154                            Literal::Object(bytes) => self.arg(bytes.clone()),
155                            Literal::Binary(bytes) => self.arg(bytes.clone()),
156                        },
157                        _ => {
158                            return Err(SqlGenerationError::UnsupportedExpression(
159                                "Only literal expressions and placeholders are supported in IN lists",
160                            ))
161                        }
162                    }
163                }
164                self.sql(")");
165            }
166            _ => return Err(SqlGenerationError::UnsupportedExpression("Only literal, identifier, and list expressions are supported")),
167        }
168        Ok(())
169    }
170
171    pub fn comparison_op(&mut self, op: &ComparisonOperator) -> Result<(), SqlGenerationError> {
172        self.sql(comparison_op_to_sql(op)?);
173        Ok(())
174    }
175
176    pub fn predicate(&mut self, predicate: &Predicate) -> Result<(), SqlGenerationError> {
177        match predicate {
178            Predicate::Comparison { left, operator, right } => {
179                self.expr(left)?;
180                self.sql(" ");
181                self.comparison_op(operator)?;
182                self.sql(" ");
183                self.expr(right)?;
184            }
185            Predicate::And(left, right) => {
186                self.predicate(left)?;
187                self.sql(" AND ");
188                self.predicate(right)?;
189            }
190            Predicate::Or(left, right) => {
191                self.sql("(");
192                self.predicate(left)?;
193                self.sql(" OR ");
194                self.predicate(right)?;
195                self.sql(")");
196            }
197            Predicate::Not(pred) => {
198                self.sql("NOT (");
199                self.predicate(pred)?;
200                self.sql(")");
201            }
202            Predicate::IsNull(expr) => {
203                self.expr(expr)?;
204                self.sql(" IS NULL");
205            }
206            Predicate::True => {
207                self.sql("TRUE");
208            }
209            Predicate::False => {
210                self.sql("FALSE");
211            }
212            Predicate::Placeholder => {
213                return Err(SqlGenerationError::PlaceholderFound);
214            }
215        }
216        Ok(())
217    }
218
219    pub fn selection(&mut self, selection: &Selection) -> Result<(), SqlGenerationError> {
220        // Add the predicate (WHERE clause)
221        self.predicate(&selection.predicate)?;
222
223        // Add ORDER BY clause if present
224        if let Some(order_by_items) = &selection.order_by {
225            self.sql(" ORDER BY ");
226            for (i, order_by) in order_by_items.iter().enumerate() {
227                if i > 0 {
228                    self.sql(", ");
229                }
230                self.order_by_item(order_by)?;
231            }
232        }
233
234        // Add LIMIT clause if present
235        if let Some(limit) = selection.limit {
236            self.sql(" LIMIT ");
237            self.arg(limit as i64); // PostgreSQL expects i64 for LIMIT
238        }
239
240        Ok(())
241    }
242
243    pub fn order_by_item(&mut self, order_by: &OrderByItem) -> Result<(), SqlGenerationError> {
244        // Generate the identifier
245        match &order_by.identifier {
246            Identifier::Property(name) => {
247                // Escape any existing quotes in the property name by doubling them
248                let escaped_name = name.replace('"', "\"\"");
249                self.sql(format!(r#""{}""#, escaped_name));
250            }
251            Identifier::CollectionProperty(collection, name) => {
252                // Escape quotes in both collection and property names
253                let escaped_collection = collection.replace('"', "\"\"");
254                let escaped_name = name.replace('"', "\"\"");
255                self.sql(format!(r#""{}"."{}""#, escaped_collection, escaped_name));
256            }
257        }
258
259        // Add the direction
260        match order_by.direction {
261            OrderDirection::Asc => self.sql(" ASC"),
262            OrderDirection::Desc => self.sql(" DESC"),
263        }
264
265        Ok(())
266    }
267}
268
269fn comparison_op_to_sql(op: &ComparisonOperator) -> Result<&'static str, SqlGenerationError> {
270    Ok(match op {
271        ComparisonOperator::Equal => "=",
272        ComparisonOperator::NotEqual => "<>",
273        ComparisonOperator::GreaterThan => ">",
274        ComparisonOperator::GreaterThanOrEqual => ">=",
275        ComparisonOperator::LessThan => "<",
276        ComparisonOperator::LessThanOrEqual => "<=",
277        ComparisonOperator::In => "IN",
278        ComparisonOperator::Between => return Err(SqlGenerationError::UnsupportedOperator("BETWEEN operator is not yet supported")),
279    })
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use ankql::parser::parse_selection;
    use anyhow::Result;

    /// Compares two argument lists by their `Debug` rendering.
    ///
    /// Takes slices so any `Vec` (or array) of boxed arguments can be passed.
    fn assert_args(args: &[Box<dyn ToSql + Send + Sync>], expected: &[Box<dyn ToSql + Send + Sync>]) {
        // TODO: Maybe actually encoding these and comparing bytes?
        assert_eq!(format!("{:?}", args), format!("{:?}", expected));
    }

    #[test]
    fn test_simple_equality() -> Result<()> {
        let selection = parse_selection("name = 'Alice'").unwrap();
        let mut sql = SqlBuilder::new();
        sql.selection(&selection)?;

        let (sql_string, args) = sql.build_where_clause();
        assert_eq!(sql_string, r#""name" = $1"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice")];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_and_condition() -> Result<()> {
        let selection = parse_selection("name = 'Alice' AND age = 30").unwrap();
        let mut sql = SqlBuilder::with_fields(vec!["id", "name", "age"]);
        sql.table_name("users");
        sql.selection(&selection)?;
        let (sql_string, args) = sql.build()?;

        assert_eq!(sql_string, r#"SELECT "id", "name", "age" FROM "users" WHERE "name" = $1 AND "age" = $2"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice"), Box::new(30)];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_complex_condition() -> Result<()> {
        let selection = parse_selection("(name = 'Alice' OR name = 'Charlie') AND age >= 30 AND age <= 40").unwrap();

        let mut sql = SqlBuilder::with_fields(vec!["id", "name", "age"]);
        sql.table_name("users");
        sql.selection(&selection)?;
        let (sql_string, args) = sql.build()?;

        assert_eq!(
            sql_string,
            r#"SELECT "id", "name", "age" FROM "users" WHERE ("name" = $1 OR "name" = $2) AND "age" >= $3 AND "age" <= $4"#
        );
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice"), Box::new("Charlie"), Box::new(30), Box::new(40)];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_including_collection_identifier() -> Result<()> {
        let selection = parse_selection("person.name = 'Alice'").unwrap();

        let mut sql = SqlBuilder::with_fields(vec!["id", "name"]);
        sql.table_name("people");
        sql.selection(&selection)?;
        let (sql_string, args) = sql.build()?;

        assert_eq!(sql_string, r#"SELECT "id", "name" FROM "people" WHERE "person"."name" = $1"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice")];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_false_predicate() -> Result<()> {
        let mut sql = SqlBuilder::with_fields(vec!["id"]);
        sql.table_name("test");
        sql.predicate(&Predicate::False)?;
        let (sql_string, args) = sql.build()?;

        assert_eq!(sql_string, r#"SELECT "id" FROM "test" WHERE FALSE"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_in_operator() -> Result<()> {
        let selection = parse_selection("name IN ('Alice', 'Bob', 'Charlie')").unwrap();
        let mut sql = SqlBuilder::with_fields(vec!["id", "name"]);
        sql.table_name("users");
        sql.selection(&selection)?;
        let (sql_string, args) = sql.build()?;

        assert_eq!(sql_string, r#"SELECT "id", "name" FROM "users" WHERE "name" IN ($1, $2, $3)"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice"), Box::new("Bob"), Box::new("Charlie")];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_placeholder_error() {
        let mut sql = SqlBuilder::with_fields(vec!["id"]);
        sql.table_name("test");
        let err = sql.predicate(&Predicate::Placeholder).expect_err("Expected an error");
        assert!(matches!(err, SqlGenerationError::PlaceholderFound));
    }

    #[test]
    fn test_selection_with_order_by() -> Result<()> {
        let base_selection = parse_selection("name = 'Alice'").unwrap();
        let selection = Selection {
            predicate: base_selection.predicate,
            order_by: Some(vec![OrderByItem {
                identifier: Identifier::Property("created_at".to_string()),
                direction: OrderDirection::Desc,
            }]),
            limit: None,
        };

        let mut sql = SqlBuilder::with_fields(vec!["id", "name", "created_at"]);
        sql.table_name("users");
        sql.selection(&selection)?;
        let (sql_string, args) = sql.build()?;

        assert_eq!(sql_string, r#"SELECT "id", "name", "created_at" FROM "users" WHERE "name" = $1 ORDER BY "created_at" DESC"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice")];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_selection_with_limit() -> Result<()> {
        let base_selection = parse_selection("age > 18").unwrap();
        let selection = Selection { predicate: base_selection.predicate, order_by: None, limit: Some(10) };

        let mut sql = SqlBuilder::with_fields(vec!["id", "name", "age"]);
        sql.table_name("users");
        sql.selection(&selection)?;
        let (sql_string, args) = sql.build()?;

        assert_eq!(sql_string, r#"SELECT "id", "name", "age" FROM "users" WHERE "age" > $1 LIMIT $2"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new(18i64), Box::new(10i64)];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_selection_with_order_by_and_limit() -> Result<()> {
        let base_selection = parse_selection("status = 'active'").unwrap();
        let selection = Selection {
            predicate: base_selection.predicate,
            order_by: Some(vec![
                OrderByItem { identifier: Identifier::Property("priority".to_string()), direction: OrderDirection::Desc },
                OrderByItem { identifier: Identifier::Property("created_at".to_string()), direction: OrderDirection::Asc },
            ]),
            limit: Some(5),
        };

        let mut sql = SqlBuilder::with_fields(vec!["id", "status", "priority", "created_at"]);
        sql.table_name("tasks");
        sql.selection(&selection)?;
        let (sql_string, args) = sql.build()?;

        assert_eq!(
            sql_string,
            r#"SELECT "id", "status", "priority", "created_at" FROM "tasks" WHERE "status" = $1 ORDER BY "priority" DESC, "created_at" ASC LIMIT $2"#
        );
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("active"), Box::new(5i64)];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_build_requires_fields_and_table() -> Result<()> {
        // Neither fields nor table.
        let mut bare = SqlBuilder::new();
        bare.predicate(&Predicate::True)?;
        assert!(matches!(bare.build(), Err(SqlGenerationError::IncompleteConfiguration)));

        // Fields but no table.
        let mut without_table = SqlBuilder::with_fields(vec!["id"]);
        without_table.predicate(&Predicate::True)?;
        assert!(matches!(without_table.build(), Err(SqlGenerationError::IncompleteConfiguration)));
        Ok(())
    }

    #[test]
    fn test_identifier_quotes_are_escaped() -> Result<()> {
        let mut sql = SqlBuilder::with_fields(vec![r#"na"me"#]);
        sql.table_name(r#"us"ers"#);
        sql.predicate(&Predicate::True)?;
        let (sql_string, _args) = sql.build()?;

        assert_eq!(sql_string, r#"SELECT "na""me" FROM "us""ers" WHERE TRUE"#);
        Ok(())
    }

    #[test]
    fn test_between_is_unsupported() {
        let err = comparison_op_to_sql(&ComparisonOperator::Between).expect_err("BETWEEN should be rejected");
        assert!(matches!(err, SqlGenerationError::UnsupportedOperator(_)));
    }
}
456}