1use ankql::ast::{ComparisonOperator, Expr, Literal, OrderByItem, OrderDirection, Predicate, Selection};
6use ankurah_core::EntityId;
7use thiserror::Error;
8
9use crate::error::SqliteError;
10
11#[derive(Debug, Error, Clone)]
12pub enum SqlGenerationError {
13 #[error("Placeholder found in predicate - placeholders should be replaced before predicate processing")]
14 PlaceholderFound,
15 #[error("Unsupported expression type: {0}")]
16 UnsupportedExpression(&'static str),
17 #[error("Unsupported operator: {0}")]
18 UnsupportedOperator(&'static str),
19}
20
21impl From<SqlGenerationError> for SqliteError {
22 fn from(err: SqlGenerationError) -> Self { SqliteError::SqlGeneration(err.to_string()) }
23}
24
25#[derive(Debug, Clone)]
27pub struct SplitPredicate {
28 pub sql_predicate: Predicate,
30 pub remaining_predicate: Predicate,
32}
33
34impl SplitPredicate {
35 pub fn needs_post_filter(&self) -> bool { !matches!(self.remaining_predicate, Predicate::True) }
37}
38
39pub fn split_predicate_for_sqlite(predicate: &Predicate) -> SplitPredicate {
41 let (sql_pred, remaining_pred) = split_predicate_recursive(predicate);
42 SplitPredicate { sql_predicate: sql_pred, remaining_predicate: remaining_pred }
43}
44
45fn split_predicate_recursive(predicate: &Predicate) -> (Predicate, Predicate) {
46 match predicate {
47 Predicate::Comparison { left, operator: _, right } => {
48 if can_pushdown_comparison(left, right) {
49 (predicate.clone(), Predicate::True)
50 } else {
51 (Predicate::True, predicate.clone())
52 }
53 }
54
55 Predicate::And(left, right) => {
56 let (left_sql, left_remaining) = split_predicate_recursive(left);
57 let (right_sql, right_remaining) = split_predicate_recursive(right);
58
59 let sql_pred = match (&left_sql, &right_sql) {
60 (Predicate::True, Predicate::True) => Predicate::True,
61 (Predicate::True, _) => right_sql,
62 (_, Predicate::True) => left_sql,
63 _ => Predicate::And(Box::new(left_sql), Box::new(right_sql)),
64 };
65
66 let remaining_pred = match (&left_remaining, &right_remaining) {
67 (Predicate::True, Predicate::True) => Predicate::True,
68 (Predicate::True, _) => right_remaining,
69 (_, Predicate::True) => left_remaining,
70 _ => Predicate::And(Box::new(left_remaining), Box::new(right_remaining)),
71 };
72
73 (sql_pred, remaining_pred)
74 }
75
76 Predicate::Or(left, right) => {
77 let (left_sql, left_remaining) = split_predicate_recursive(left);
78 let (right_sql, right_remaining) = split_predicate_recursive(right);
79
80 if matches!(left_remaining, Predicate::True) && matches!(right_remaining, Predicate::True) {
81 (predicate.clone(), Predicate::True)
82 } else {
83 let sql_pred = match (&left_sql, &right_sql) {
84 (Predicate::True, Predicate::True) => Predicate::True,
85 (Predicate::True, _) => right_sql,
86 (_, Predicate::True) => left_sql,
87 _ => Predicate::Or(Box::new(left_sql), Box::new(right_sql)),
88 };
89 (sql_pred, predicate.clone())
90 }
91 }
92
93 Predicate::Not(inner) => {
94 let (inner_sql, inner_remaining) = split_predicate_recursive(inner);
95 if matches!(inner_remaining, Predicate::True) {
96 (Predicate::Not(Box::new(inner_sql)), Predicate::True)
97 } else {
98 (Predicate::True, predicate.clone())
99 }
100 }
101
102 Predicate::IsNull(expr) => {
103 if can_pushdown_expr(expr) {
104 (predicate.clone(), Predicate::True)
105 } else {
106 (Predicate::True, predicate.clone())
107 }
108 }
109
110 Predicate::True => (Predicate::True, Predicate::True),
111 Predicate::False => (Predicate::False, Predicate::True),
112 Predicate::Placeholder => (Predicate::True, predicate.clone()),
113 }
114}
115
116fn can_pushdown_comparison(left: &Expr, right: &Expr) -> bool { can_pushdown_expr(left) && can_pushdown_expr(right) }
117
118fn can_pushdown_expr(expr: &Expr) -> bool {
119 match expr {
120 Expr::Literal(_) => true,
121 Expr::Path(path) => !path.steps.is_empty(),
122 Expr::ExprList(exprs) => exprs.iter().all(can_pushdown_expr),
123 Expr::Predicate(_) => false,
124 Expr::InfixExpr { .. } => false,
125 Expr::Placeholder => false,
126 }
127}
128
129pub struct SqlBuilder {
131 sql: String,
132 params: Vec<rusqlite::types::Value>,
133 fields: Vec<String>,
134 table_name: Option<String>,
135}
136
137impl Default for SqlBuilder {
138 fn default() -> Self { Self::new() }
139}
140
141impl SqlBuilder {
142 pub fn new() -> Self { Self { sql: String::new(), params: Vec::new(), fields: Vec::new(), table_name: None } }
143
144 pub fn with_fields<T: Into<String>>(fields: Vec<T>) -> Self {
145 Self { sql: String::new(), params: Vec::new(), fields: fields.into_iter().map(|f| f.into()).collect(), table_name: None }
146 }
147
148 pub fn table_name(&mut self, name: impl Into<String>) -> &mut Self {
149 self.table_name = Some(name.into());
150 self
151 }
152
153 fn push_sql(&mut self, s: &str) { self.sql.push_str(s); }
154
155 fn push_param(&mut self, value: rusqlite::types::Value) {
156 self.sql.push('?');
157 self.params.push(value);
158 }
159
160 pub fn build(self) -> Result<(String, Vec<rusqlite::types::Value>), SqlGenerationError> {
161 if self.fields.is_empty() || self.table_name.is_none() {
162 return Ok((self.sql, self.params));
164 }
165
166 let fields_clause = self.fields.iter().map(|field| format!(r#""{}""#, field.replace('"', "\"\""))).collect::<Vec<_>>().join(", ");
167 let table = self.table_name.unwrap();
168 let sql = format!(r#"SELECT {} FROM "{}" WHERE {}"#, fields_clause, table.replace('"', "\"\""), self.sql);
169
170 Ok((sql, self.params))
171 }
172
173 #[allow(dead_code)]
174 pub fn build_where_clause(self) -> (String, Vec<rusqlite::types::Value>) { (self.sql, self.params) }
175
176 pub fn expr(&mut self, expr: &Expr) -> Result<(), SqlGenerationError> {
177 match expr {
178 Expr::Placeholder => return Err(SqlGenerationError::PlaceholderFound),
179 Expr::Literal(lit) => self.literal(lit),
180 Expr::Path(path) => {
181 if path.is_simple() {
182 let escaped = path.first().replace('"', "\"\"");
184 self.push_sql(&format!(r#""{}""#, escaped));
185 } else {
186 let first = path.first().replace('"', "\"\"");
190 let json_path = if path.steps.len() == 2 {
192 format!("$.{}", path.steps[1].replace('\'', "''"))
193 } else {
194 format!("$.{}", path.steps.iter().skip(1).map(|s| s.replace('\'', "''")).collect::<Vec<_>>().join("."))
195 };
196 self.push_sql(&format!(r#"json_extract("{}", '{}')"#, first, json_path));
197 }
198 }
199 Expr::ExprList(exprs) => {
200 self.push_sql("(");
201 for (i, expr) in exprs.iter().enumerate() {
202 if i > 0 {
203 self.push_sql(", ");
204 }
205 self.expr(expr)?;
206 }
207 self.push_sql(")");
208 }
209 _ => return Err(SqlGenerationError::UnsupportedExpression("Only literal, path, and list expressions are supported")),
210 }
211 Ok(())
212 }
213
214 fn literal(&mut self, lit: &Literal) {
215 match lit {
216 Literal::String(s) => self.push_param(rusqlite::types::Value::Text(s.clone())),
217 Literal::I64(i) => self.push_param(rusqlite::types::Value::Integer(*i)),
218 Literal::F64(f) => self.push_param(rusqlite::types::Value::Real(*f)),
219 Literal::Bool(b) => self.push_param(rusqlite::types::Value::Integer(if *b { 1 } else { 0 })),
220 Literal::I16(i) => self.push_param(rusqlite::types::Value::Integer(*i as i64)),
221 Literal::I32(i) => self.push_param(rusqlite::types::Value::Integer(*i as i64)),
222 Literal::EntityId(ulid) => self.push_param(rusqlite::types::Value::Text(EntityId::from_ulid(*ulid).to_base64())),
223 Literal::Object(bytes) => self.push_param(rusqlite::types::Value::Blob(bytes.clone())),
224 Literal::Binary(bytes) => self.push_param(rusqlite::types::Value::Blob(bytes.clone())),
225 Literal::Json(json) => match json {
228 serde_json::Value::String(s) => self.push_param(rusqlite::types::Value::Text(s.clone())),
229 serde_json::Value::Number(n) => {
230 if let Some(i) = n.as_i64() {
231 self.push_param(rusqlite::types::Value::Integer(i));
232 } else if let Some(f) = n.as_f64() {
233 self.push_param(rusqlite::types::Value::Real(f));
234 } else {
235 self.push_param(rusqlite::types::Value::Text(n.to_string()));
237 }
238 }
239 serde_json::Value::Bool(b) => self.push_param(rusqlite::types::Value::Integer(if *b { 1 } else { 0 })),
240 serde_json::Value::Null => self.push_param(rusqlite::types::Value::Null),
241 _ => self.push_param(rusqlite::types::Value::Text(json.to_string())),
243 },
244 }
245 }
246
247 pub fn comparison_op(&mut self, op: &ComparisonOperator) -> Result<(), SqlGenerationError> {
248 self.push_sql(comparison_op_to_sql(op)?);
249 Ok(())
250 }
251
252 pub fn predicate(&mut self, predicate: &Predicate) -> Result<(), SqlGenerationError> {
253 match predicate {
254 Predicate::Comparison { left, operator, right } => {
255 self.expr(left)?;
258 self.push_sql(" ");
259 self.comparison_op(operator)?;
260 self.push_sql(" ");
261 self.expr(right)?;
262 }
263 Predicate::And(left, right) => {
264 self.predicate(left)?;
265 self.push_sql(" AND ");
266 self.predicate(right)?;
267 }
268 Predicate::Or(left, right) => {
269 self.push_sql("(");
270 self.predicate(left)?;
271 self.push_sql(" OR ");
272 self.predicate(right)?;
273 self.push_sql(")");
274 }
275 Predicate::Not(pred) => {
276 self.push_sql("NOT (");
277 self.predicate(pred)?;
278 self.push_sql(")");
279 }
280 Predicate::IsNull(expr) => {
281 self.expr(expr)?;
282 self.push_sql(" IS NULL");
283 }
284 Predicate::True => {
285 self.push_sql("1=1");
286 }
287 Predicate::False => {
288 self.push_sql("1=0");
289 }
290 Predicate::Placeholder => {
291 return Err(SqlGenerationError::PlaceholderFound);
292 }
293 }
294 Ok(())
295 }
296
297 pub fn selection(&mut self, selection: &Selection) -> Result<(), SqlGenerationError> {
298 self.predicate(&selection.predicate)?;
299
300 if let Some(order_by_items) = &selection.order_by {
301 self.push_sql(" ORDER BY ");
302 for (i, order_by) in order_by_items.iter().enumerate() {
303 if i > 0 {
304 self.push_sql(", ");
305 }
306 self.order_by_item(order_by)?;
307 }
308 }
309
310 if let Some(limit) = selection.limit {
311 self.push_sql(&format!(" LIMIT {}", limit));
312 }
313
314 Ok(())
315 }
316
317 pub fn order_by_item(&mut self, order_by: &OrderByItem) -> Result<(), SqlGenerationError> {
318 if order_by.path.is_simple() {
320 let escaped = order_by.path.first().replace('"', "\"\"");
322 self.push_sql(&format!(r#""{}""#, escaped));
323 } else {
324 let first = order_by.path.first().replace('"', "\"\"");
326 self.push_sql(&format!(r#""{}""#, first));
327
328 for step in order_by.path.steps.iter().skip(1) {
329 let escaped = step.replace('\'', "''");
330 self.push_sql(&format!("->'{}'", escaped));
332 }
333 }
334
335 match order_by.direction {
336 OrderDirection::Asc => self.push_sql(" ASC"),
337 OrderDirection::Desc => self.push_sql(" DESC"),
338 }
339
340 Ok(())
341 }
342}
343
344fn comparison_op_to_sql(op: &ComparisonOperator) -> Result<&'static str, SqlGenerationError> {
345 Ok(match op {
346 ComparisonOperator::Equal => "=",
347 ComparisonOperator::NotEqual => "<>",
348 ComparisonOperator::GreaterThan => ">",
349 ComparisonOperator::GreaterThanOrEqual => ">=",
350 ComparisonOperator::LessThan => "<",
351 ComparisonOperator::LessThanOrEqual => "<=",
352 ComparisonOperator::In => "IN",
353 ComparisonOperator::Between => return Err(SqlGenerationError::UnsupportedOperator("BETWEEN operator is not yet supported")),
354 })
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use ankql::parser::parse_selection;

    /// Parses `query` and renders it with a bare builder, returning the fragment and its
    /// bound parameters.
    fn render_where(query: &str) -> (String, Vec<rusqlite::types::Value>) {
        let selection = parse_selection(query).unwrap();
        let mut builder = SqlBuilder::new();
        builder.selection(&selection).unwrap();
        builder.build_where_clause()
    }

    #[test]
    fn test_simple_equality() {
        let (sql, params) = render_where("name = 'Alice'");
        assert_eq!(sql, r#""name" = ?"#);
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_and_condition() {
        let selection = parse_selection("name = 'Alice' AND age = 30").unwrap();
        let mut builder = SqlBuilder::with_fields(vec!["id", "name", "age"]);
        builder.table_name("users");
        builder.selection(&selection).unwrap();
        let (sql, params) = builder.build().unwrap();

        assert_eq!(sql, r#"SELECT "id", "name", "age" FROM "users" WHERE "name" = ? AND "age" = ?"#);
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_json_path() {
        let (sql, _) = render_where("data.status = 'active'");
        assert_eq!(sql, r#"json_extract("data", '$.status') = ?"#);
    }

    #[test]
    fn test_json_nested_path() {
        let (sql, _) = render_where("data.user.name = 'Alice'");
        assert_eq!(sql, r#"json_extract("data", '$.user.name') = ?"#);
    }

    #[test]
    fn test_json_numeric_comparison() {
        let (sql, _) = render_where("data.count > 10");
        assert_eq!(sql, r#"json_extract("data", '$.count') > ?"#);
    }

    #[test]
    fn test_in_operator() {
        let (sql, params) = render_where("name IN ('Alice', 'Bob')");
        assert_eq!(sql, r#""name" IN (?, ?)"#);
        assert_eq!(params.len(), 2);
    }
}