Skip to main content

ankurah_storage_sqlite/
sql_builder.rs

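//! SQL builder for SQLite queries.
//!
//! Converts AnkQL predicates into SQLite-compatible SQL `WHERE` clauses, and splits predicates
//! into the part that can be pushed down to SQLite and the part that must be evaluated in Rust
//! after fetching.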
4
5use ankql::ast::{ComparisonOperator, Expr, Literal, OrderByItem, OrderDirection, Predicate, Selection};
6use ankurah_core::EntityId;
7use thiserror::Error;
8
9use crate::error::SqliteError;
10
11#[derive(Debug, Error, Clone)]
12pub enum SqlGenerationError {
13    #[error("Placeholder found in predicate - placeholders should be replaced before predicate processing")]
14    PlaceholderFound,
15    #[error("Unsupported expression type: {0}")]
16    UnsupportedExpression(&'static str),
17    #[error("Unsupported operator: {0}")]
18    UnsupportedOperator(&'static str),
19}
20
21impl From<SqlGenerationError> for SqliteError {
22    fn from(err: SqlGenerationError) -> Self { SqliteError::SqlGeneration(err.to_string()) }
23}
24
25/// Result of splitting a predicate for SQLite execution.
26#[derive(Debug, Clone)]
27pub struct SplitPredicate {
28    /// Predicate that can be pushed down to SQLite WHERE clause
29    pub sql_predicate: Predicate,
30    /// Predicate that must be evaluated in Rust after fetching (Predicate::True if nothing remains)
31    pub remaining_predicate: Predicate,
32}
33
34impl SplitPredicate {
35    /// Check if there's any remaining predicate that needs post-filtering
36    pub fn needs_post_filter(&self) -> bool { !matches!(self.remaining_predicate, Predicate::True) }
37}
38
39/// Split a predicate into parts that can be pushed down to SQLite vs evaluated post-fetch.
40pub fn split_predicate_for_sqlite(predicate: &Predicate) -> SplitPredicate {
41    let (sql_pred, remaining_pred) = split_predicate_recursive(predicate);
42    SplitPredicate { sql_predicate: sql_pred, remaining_predicate: remaining_pred }
43}
44
45fn split_predicate_recursive(predicate: &Predicate) -> (Predicate, Predicate) {
46    match predicate {
47        Predicate::Comparison { left, operator: _, right } => {
48            if can_pushdown_comparison(left, right) {
49                (predicate.clone(), Predicate::True)
50            } else {
51                (Predicate::True, predicate.clone())
52            }
53        }
54
55        Predicate::And(left, right) => {
56            let (left_sql, left_remaining) = split_predicate_recursive(left);
57            let (right_sql, right_remaining) = split_predicate_recursive(right);
58
59            let sql_pred = match (&left_sql, &right_sql) {
60                (Predicate::True, Predicate::True) => Predicate::True,
61                (Predicate::True, _) => right_sql,
62                (_, Predicate::True) => left_sql,
63                _ => Predicate::And(Box::new(left_sql), Box::new(right_sql)),
64            };
65
66            let remaining_pred = match (&left_remaining, &right_remaining) {
67                (Predicate::True, Predicate::True) => Predicate::True,
68                (Predicate::True, _) => right_remaining,
69                (_, Predicate::True) => left_remaining,
70                _ => Predicate::And(Box::new(left_remaining), Box::new(right_remaining)),
71            };
72
73            (sql_pred, remaining_pred)
74        }
75
76        Predicate::Or(left, right) => {
77            let (left_sql, left_remaining) = split_predicate_recursive(left);
78            let (right_sql, right_remaining) = split_predicate_recursive(right);
79
80            if matches!(left_remaining, Predicate::True) && matches!(right_remaining, Predicate::True) {
81                (predicate.clone(), Predicate::True)
82            } else {
83                let sql_pred = match (&left_sql, &right_sql) {
84                    (Predicate::True, Predicate::True) => Predicate::True,
85                    (Predicate::True, _) => right_sql,
86                    (_, Predicate::True) => left_sql,
87                    _ => Predicate::Or(Box::new(left_sql), Box::new(right_sql)),
88                };
89                (sql_pred, predicate.clone())
90            }
91        }
92
93        Predicate::Not(inner) => {
94            let (inner_sql, inner_remaining) = split_predicate_recursive(inner);
95            if matches!(inner_remaining, Predicate::True) {
96                (Predicate::Not(Box::new(inner_sql)), Predicate::True)
97            } else {
98                (Predicate::True, predicate.clone())
99            }
100        }
101
102        Predicate::IsNull(expr) => {
103            if can_pushdown_expr(expr) {
104                (predicate.clone(), Predicate::True)
105            } else {
106                (Predicate::True, predicate.clone())
107            }
108        }
109
110        Predicate::True => (Predicate::True, Predicate::True),
111        Predicate::False => (Predicate::False, Predicate::True),
112        Predicate::Placeholder => (Predicate::True, predicate.clone()),
113    }
114}
115
116fn can_pushdown_comparison(left: &Expr, right: &Expr) -> bool { can_pushdown_expr(left) && can_pushdown_expr(right) }
117
118fn can_pushdown_expr(expr: &Expr) -> bool {
119    match expr {
120        Expr::Literal(_) => true,
121        Expr::Path(path) => !path.steps.is_empty(),
122        Expr::ExprList(exprs) => exprs.iter().all(can_pushdown_expr),
123        Expr::Predicate(_) => false,
124        Expr::InfixExpr { .. } => false,
125        Expr::Placeholder => false,
126    }
127}
128
129/// SQL builder for SQLite queries
130pub struct SqlBuilder {
131    sql: String,
132    params: Vec<rusqlite::types::Value>,
133    fields: Vec<String>,
134    table_name: Option<String>,
135}
136
137impl Default for SqlBuilder {
138    fn default() -> Self { Self::new() }
139}
140
141impl SqlBuilder {
142    pub fn new() -> Self { Self { sql: String::new(), params: Vec::new(), fields: Vec::new(), table_name: None } }
143
144    pub fn with_fields<T: Into<String>>(fields: Vec<T>) -> Self {
145        Self { sql: String::new(), params: Vec::new(), fields: fields.into_iter().map(|f| f.into()).collect(), table_name: None }
146    }
147
148    pub fn table_name(&mut self, name: impl Into<String>) -> &mut Self {
149        self.table_name = Some(name.into());
150        self
151    }
152
153    fn push_sql(&mut self, s: &str) { self.sql.push_str(s); }
154
155    fn push_param(&mut self, value: rusqlite::types::Value) {
156        self.sql.push('?');
157        self.params.push(value);
158    }
159
160    pub fn build(self) -> Result<(String, Vec<rusqlite::types::Value>), SqlGenerationError> {
161        if self.fields.is_empty() || self.table_name.is_none() {
162            // Just return WHERE clause
163            return Ok((self.sql, self.params));
164        }
165
166        let fields_clause = self.fields.iter().map(|field| format!(r#""{}""#, field.replace('"', "\"\""))).collect::<Vec<_>>().join(", ");
167        let table = self.table_name.unwrap();
168        let sql = format!(r#"SELECT {} FROM "{}" WHERE {}"#, fields_clause, table.replace('"', "\"\""), self.sql);
169
170        Ok((sql, self.params))
171    }
172
173    #[allow(dead_code)]
174    pub fn build_where_clause(self) -> (String, Vec<rusqlite::types::Value>) { (self.sql, self.params) }
175
176    pub fn expr(&mut self, expr: &Expr) -> Result<(), SqlGenerationError> {
177        match expr {
178            Expr::Placeholder => return Err(SqlGenerationError::PlaceholderFound),
179            Expr::Literal(lit) => self.literal(lit),
180            Expr::Path(path) => {
181                if path.is_simple() {
182                    // Single-step path: regular column reference
183                    let escaped = path.first().replace('"', "\"\"");
184                    self.push_sql(&format!(r#""{}""#, escaped));
185                } else {
186                    // Multi-step path: JSONB traversal
187                    // SQLite's -> operator returns JSONB, but for comparisons we need to extract the value.
188                    // Use json_extract() with the full JSON path for reliable comparisons.
189                    let first = path.first().replace('"', "\"\"");
190                    // Build JSON path: $.step1.step2.step3
191                    let json_path = if path.steps.len() == 2 {
192                        format!("$.{}", path.steps[1].replace('\'', "''"))
193                    } else {
194                        format!("$.{}", path.steps.iter().skip(1).map(|s| s.replace('\'', "''")).collect::<Vec<_>>().join("."))
195                    };
196                    self.push_sql(&format!(r#"json_extract("{}", '{}')"#, first, json_path));
197                }
198            }
199            Expr::ExprList(exprs) => {
200                self.push_sql("(");
201                for (i, expr) in exprs.iter().enumerate() {
202                    if i > 0 {
203                        self.push_sql(", ");
204                    }
205                    self.expr(expr)?;
206                }
207                self.push_sql(")");
208            }
209            _ => return Err(SqlGenerationError::UnsupportedExpression("Only literal, path, and list expressions are supported")),
210        }
211        Ok(())
212    }
213
214    fn literal(&mut self, lit: &Literal) {
215        match lit {
216            Literal::String(s) => self.push_param(rusqlite::types::Value::Text(s.clone())),
217            Literal::I64(i) => self.push_param(rusqlite::types::Value::Integer(*i)),
218            Literal::F64(f) => self.push_param(rusqlite::types::Value::Real(*f)),
219            Literal::Bool(b) => self.push_param(rusqlite::types::Value::Integer(if *b { 1 } else { 0 })),
220            Literal::I16(i) => self.push_param(rusqlite::types::Value::Integer(*i as i64)),
221            Literal::I32(i) => self.push_param(rusqlite::types::Value::Integer(*i as i64)),
222            Literal::EntityId(ulid) => self.push_param(rusqlite::types::Value::Text(EntityId::from_ulid(*ulid).to_base64())),
223            Literal::Object(bytes) => self.push_param(rusqlite::types::Value::Blob(bytes.clone())),
224            Literal::Binary(bytes) => self.push_param(rusqlite::types::Value::Blob(bytes.clone())),
225            // For JSON literals, extract the raw SQL value since json_extract() returns SQL types.
226            // json.to_string() would produce "US" (with quotes) but we need just US.
227            Literal::Json(json) => match json {
228                serde_json::Value::String(s) => self.push_param(rusqlite::types::Value::Text(s.clone())),
229                serde_json::Value::Number(n) => {
230                    if let Some(i) = n.as_i64() {
231                        self.push_param(rusqlite::types::Value::Integer(i));
232                    } else if let Some(f) = n.as_f64() {
233                        self.push_param(rusqlite::types::Value::Real(f));
234                    } else {
235                        // Fallback: serialize as text
236                        self.push_param(rusqlite::types::Value::Text(n.to_string()));
237                    }
238                }
239                serde_json::Value::Bool(b) => self.push_param(rusqlite::types::Value::Integer(if *b { 1 } else { 0 })),
240                serde_json::Value::Null => self.push_param(rusqlite::types::Value::Null),
241                // For arrays and objects, serialize as JSON text
242                _ => self.push_param(rusqlite::types::Value::Text(json.to_string())),
243            },
244        }
245    }
246
247    pub fn comparison_op(&mut self, op: &ComparisonOperator) -> Result<(), SqlGenerationError> {
248        self.push_sql(comparison_op_to_sql(op)?);
249        Ok(())
250    }
251
252    pub fn predicate(&mut self, predicate: &Predicate) -> Result<(), SqlGenerationError> {
253        match predicate {
254            Predicate::Comparison { left, operator, right } => {
255                // Emit: left op right
256                // JSONB paths use json_extract() which returns SQL values, so direct comparison works
257                self.expr(left)?;
258                self.push_sql(" ");
259                self.comparison_op(operator)?;
260                self.push_sql(" ");
261                self.expr(right)?;
262            }
263            Predicate::And(left, right) => {
264                self.predicate(left)?;
265                self.push_sql(" AND ");
266                self.predicate(right)?;
267            }
268            Predicate::Or(left, right) => {
269                self.push_sql("(");
270                self.predicate(left)?;
271                self.push_sql(" OR ");
272                self.predicate(right)?;
273                self.push_sql(")");
274            }
275            Predicate::Not(pred) => {
276                self.push_sql("NOT (");
277                self.predicate(pred)?;
278                self.push_sql(")");
279            }
280            Predicate::IsNull(expr) => {
281                self.expr(expr)?;
282                self.push_sql(" IS NULL");
283            }
284            Predicate::True => {
285                self.push_sql("1=1");
286            }
287            Predicate::False => {
288                self.push_sql("1=0");
289            }
290            Predicate::Placeholder => {
291                return Err(SqlGenerationError::PlaceholderFound);
292            }
293        }
294        Ok(())
295    }
296
297    pub fn selection(&mut self, selection: &Selection) -> Result<(), SqlGenerationError> {
298        self.predicate(&selection.predicate)?;
299
300        if let Some(order_by_items) = &selection.order_by {
301            self.push_sql(" ORDER BY ");
302            for (i, order_by) in order_by_items.iter().enumerate() {
303                if i > 0 {
304                    self.push_sql(", ");
305                }
306                self.order_by_item(order_by)?;
307            }
308        }
309
310        if let Some(limit) = selection.limit {
311            self.push_sql(&format!(" LIMIT {}", limit));
312        }
313
314        Ok(())
315    }
316
317    pub fn order_by_item(&mut self, order_by: &OrderByItem) -> Result<(), SqlGenerationError> {
318        // Handle JSON paths the same way as in expr() - use -> operator for multi-step paths
319        if order_by.path.is_simple() {
320            // Single-step path: regular column reference
321            let escaped = order_by.path.first().replace('"', "\"\"");
322            self.push_sql(&format!(r#""{}""#, escaped));
323        } else {
324            // Multi-step path: JSONB traversal using -> operator
325            let first = order_by.path.first().replace('"', "\"\"");
326            self.push_sql(&format!(r#""{}""#, first));
327
328            for step in order_by.path.steps.iter().skip(1) {
329                let escaped = step.replace('\'', "''");
330                // Use -> to keep as JSONB (not ->> which extracts as text)
331                self.push_sql(&format!("->'{}'", escaped));
332            }
333        }
334
335        match order_by.direction {
336            OrderDirection::Asc => self.push_sql(" ASC"),
337            OrderDirection::Desc => self.push_sql(" DESC"),
338        }
339
340        Ok(())
341    }
342}
343
344fn comparison_op_to_sql(op: &ComparisonOperator) -> Result<&'static str, SqlGenerationError> {
345    Ok(match op {
346        ComparisonOperator::Equal => "=",
347        ComparisonOperator::NotEqual => "<>",
348        ComparisonOperator::GreaterThan => ">",
349        ComparisonOperator::GreaterThanOrEqual => ">=",
350        ComparisonOperator::LessThan => "<",
351        ComparisonOperator::LessThanOrEqual => "<=",
352        ComparisonOperator::In => "IN",
353        ComparisonOperator::Between => return Err(SqlGenerationError::UnsupportedOperator("BETWEEN operator is not yet supported")),
354    })
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use ankql::parser::parse_selection;

    /// Parse `query` and render it as a bare WHERE clause plus its bound parameters.
    fn render_where(query: &str) -> (String, Vec<rusqlite::types::Value>) {
        let selection = parse_selection(query).unwrap();
        let mut builder = SqlBuilder::new();
        builder.selection(&selection).unwrap();
        builder.build_where_clause()
    }

    #[test]
    fn test_simple_equality() {
        let (clause, params) = render_where("name = 'Alice'");

        assert_eq!(clause, r#""name" = ?"#);
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_and_condition() {
        let selection = parse_selection("name = 'Alice' AND age = 30").unwrap();
        let mut builder = SqlBuilder::with_fields(vec!["id", "name", "age"]);
        builder.table_name("users");
        builder.selection(&selection).unwrap();
        let (statement, params) = builder.build().unwrap();

        assert_eq!(statement, r#"SELECT "id", "name", "age" FROM "users" WHERE "name" = ? AND "age" = ?"#);
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_json_path() {
        let (clause, _) = render_where("data.status = 'active'");

        // Uses json_extract() for reliable comparisons with BLOB JSONB columns
        assert_eq!(clause, r#"json_extract("data", '$.status') = ?"#);
    }

    #[test]
    fn test_json_nested_path() {
        let (clause, _) = render_where("data.user.name = 'Alice'");

        // Uses json_extract() with nested path for reliable comparisons
        assert_eq!(clause, r#"json_extract("data", '$.user.name') = ?"#);
    }

    #[test]
    fn test_json_numeric_comparison() {
        let (clause, _) = render_where("data.count > 10");

        // Numeric comparison with json_extract() - SQLite handles numeric comparison correctly
        assert_eq!(clause, r#"json_extract("data", '$.count') > ?"#);
    }

    #[test]
    fn test_in_operator() {
        let (clause, params) = render_where("name IN ('Alice', 'Bob')");

        assert_eq!(clause, r#""name" IN (?, ?)"#);
        assert_eq!(params.len(), 2);
    }
}