1use ankql::ast::{ComparisonOperator, Expr, Identifier, Literal, OrderByItem, OrderDirection, Predicate, Selection};
2use ankurah_core::{error::RetrievalError, EntityId};
3use thiserror::Error;
4use tokio_postgres::types::ToSql;
5
6#[derive(Debug, Error, Clone)]
7pub enum SqlGenerationError {
8 #[error("Placeholder found in predicate - placeholders should be replaced before predicate processing")]
9 PlaceholderFound,
10 #[error("Unsupported expression type: {0}")]
11 UnsupportedExpression(&'static str),
12 #[error("Unsupported operator: {0}")]
13 UnsupportedOperator(&'static str),
14 #[error("SqlBuilder requires both fields and table_name to be set for complete SELECT generation, or neither for WHERE-only mode")]
15 IncompleteConfiguration,
16}
17
18impl From<SqlGenerationError> for RetrievalError {
19 fn from(err: SqlGenerationError) -> Self { RetrievalError::StorageError(Box::new(err)) }
20}
21
22pub enum SqlExpr {
23 Sql(String),
24 Argument(Box<dyn ToSql + Send + Sync>),
25}
26
27pub struct SqlBuilder {
28 expressions: Vec<SqlExpr>,
29 fields: Vec<String>,
30 table_name: Option<String>,
31}
32
33impl Default for SqlBuilder {
34 fn default() -> Self { Self::new() }
35}
36
37impl SqlBuilder {
38 pub fn new() -> Self { Self { expressions: Vec::new(), fields: Vec::new(), table_name: None } }
39
40 pub fn with_fields<T: Into<String>>(fields: Vec<T>) -> Self {
41 Self { expressions: Vec::new(), fields: fields.into_iter().map(|f| f.into()).collect(), table_name: None }
42 }
43
44 pub fn table_name(&mut self, name: impl Into<String>) -> &mut Self {
45 self.table_name = Some(name.into());
46 self
47 }
48
49 pub fn push(&mut self, expr: SqlExpr) { self.expressions.push(expr); }
50
51 pub fn arg(&mut self, arg: impl ToSql + Send + Sync + 'static) {
52 self.push(SqlExpr::Argument(Box::new(arg) as Box<dyn ToSql + Send + Sync>));
53 }
54
55 pub fn sql(&mut self, s: impl AsRef<str>) { self.push(SqlExpr::Sql(s.as_ref().to_owned())); }
56
57 pub fn build(self) -> Result<(String, Vec<Box<dyn ToSql + Send + Sync>>), SqlGenerationError> {
58 let mut counter = 1;
59 let mut where_clause = String::new();
60 let mut args = Vec::new();
61
62 for expr in self.expressions {
64 match expr {
65 SqlExpr::Argument(arg) => {
66 where_clause += &format!("${}", counter);
67 args.push(arg);
68 counter += 1;
69 }
70 SqlExpr::Sql(s) => {
71 where_clause += &s;
72 }
73 }
74 }
75
76 if self.fields.is_empty() || self.table_name.is_none() {
78 return Err(SqlGenerationError::IncompleteConfiguration);
79 }
80
81 let fields_clause = self.fields.iter().map(|field| format!(r#""{}""#, field.replace('"', "\"\""))).collect::<Vec<_>>().join(", ");
82 let table = self.table_name.unwrap();
83 let sql = format!(r#"SELECT {} FROM "{}" WHERE {}"#, fields_clause, table.replace('"', "\"\""), where_clause);
84
85 Ok((sql, args))
86 }
87
88 pub fn build_where_clause(self) -> (String, Vec<Box<dyn ToSql + Send + Sync>>) {
89 let mut counter = 1;
90 let mut where_clause = String::new();
91 let mut args = Vec::new();
92
93 for expr in self.expressions {
95 match expr {
96 SqlExpr::Argument(arg) => {
97 where_clause += &format!("${}", counter);
98 args.push(arg);
99 counter += 1;
100 }
101 SqlExpr::Sql(s) => {
102 where_clause += &s;
103 }
104 }
105 }
106
107 (where_clause, args)
108 }
109
110 pub fn expr(&mut self, expr: &Expr) -> Result<(), SqlGenerationError> {
112 match expr {
113 Expr::Placeholder => return Err(SqlGenerationError::PlaceholderFound),
114 Expr::Literal(lit) => match lit {
115 Literal::String(s) => self.arg(s.to_owned()),
116 Literal::I64(int) => self.arg(*int),
117 Literal::F64(float) => self.arg(*float),
118 Literal::Bool(bool) => self.arg(*bool),
119 Literal::I16(i) => self.arg(*i),
120 Literal::I32(i) => self.arg(*i),
121 Literal::EntityId(ulid) => self.arg(EntityId::from_ulid(*ulid).to_base64()),
122 Literal::Object(bytes) => self.arg(bytes.clone()),
123 Literal::Binary(bytes) => self.arg(bytes.clone()),
124 },
125 Expr::Identifier(id) => match id {
126 Identifier::Property(name) => {
127 let escaped_name = name.replace('"', "\"\"");
129 self.sql(format!(r#""{}""#, escaped_name));
130 }
131 Identifier::CollectionProperty(collection, name) => {
132 let escaped_collection = collection.replace('"', "\"\"");
134 let escaped_name = name.replace('"', "\"\"");
135 self.sql(format!(r#""{}"."{}""#, escaped_collection, escaped_name));
136 }
137 },
138 Expr::ExprList(exprs) => {
139 self.sql("(");
140 for (i, expr) in exprs.iter().enumerate() {
141 if i > 0 {
142 self.sql(", ");
143 }
144 match expr {
145 Expr::Placeholder => return Err(SqlGenerationError::PlaceholderFound),
146 Expr::Literal(lit) => match lit {
147 Literal::String(s) => self.arg(s.to_owned()),
148 Literal::I64(int) => self.arg(*int),
149 Literal::F64(float) => self.arg(*float),
150 Literal::Bool(bool) => self.arg(*bool),
151 Literal::I16(i) => self.arg(*i),
152 Literal::I32(i) => self.arg(*i),
153 Literal::EntityId(ulid) => self.arg(EntityId::from_ulid(*ulid).to_base64()),
154 Literal::Object(bytes) => self.arg(bytes.clone()),
155 Literal::Binary(bytes) => self.arg(bytes.clone()),
156 },
157 _ => {
158 return Err(SqlGenerationError::UnsupportedExpression(
159 "Only literal expressions and placeholders are supported in IN lists",
160 ))
161 }
162 }
163 }
164 self.sql(")");
165 }
166 _ => return Err(SqlGenerationError::UnsupportedExpression("Only literal, identifier, and list expressions are supported")),
167 }
168 Ok(())
169 }
170
171 pub fn comparison_op(&mut self, op: &ComparisonOperator) -> Result<(), SqlGenerationError> {
172 self.sql(comparison_op_to_sql(op)?);
173 Ok(())
174 }
175
176 pub fn predicate(&mut self, predicate: &Predicate) -> Result<(), SqlGenerationError> {
177 match predicate {
178 Predicate::Comparison { left, operator, right } => {
179 self.expr(left)?;
180 self.sql(" ");
181 self.comparison_op(operator)?;
182 self.sql(" ");
183 self.expr(right)?;
184 }
185 Predicate::And(left, right) => {
186 self.predicate(left)?;
187 self.sql(" AND ");
188 self.predicate(right)?;
189 }
190 Predicate::Or(left, right) => {
191 self.sql("(");
192 self.predicate(left)?;
193 self.sql(" OR ");
194 self.predicate(right)?;
195 self.sql(")");
196 }
197 Predicate::Not(pred) => {
198 self.sql("NOT (");
199 self.predicate(pred)?;
200 self.sql(")");
201 }
202 Predicate::IsNull(expr) => {
203 self.expr(expr)?;
204 self.sql(" IS NULL");
205 }
206 Predicate::True => {
207 self.sql("TRUE");
208 }
209 Predicate::False => {
210 self.sql("FALSE");
211 }
212 Predicate::Placeholder => {
213 return Err(SqlGenerationError::PlaceholderFound);
214 }
215 }
216 Ok(())
217 }
218
219 pub fn selection(&mut self, selection: &Selection) -> Result<(), SqlGenerationError> {
220 self.predicate(&selection.predicate)?;
222
223 if let Some(order_by_items) = &selection.order_by {
225 self.sql(" ORDER BY ");
226 for (i, order_by) in order_by_items.iter().enumerate() {
227 if i > 0 {
228 self.sql(", ");
229 }
230 self.order_by_item(order_by)?;
231 }
232 }
233
234 if let Some(limit) = selection.limit {
236 self.sql(" LIMIT ");
237 self.arg(limit as i64); }
239
240 Ok(())
241 }
242
243 pub fn order_by_item(&mut self, order_by: &OrderByItem) -> Result<(), SqlGenerationError> {
244 match &order_by.identifier {
246 Identifier::Property(name) => {
247 let escaped_name = name.replace('"', "\"\"");
249 self.sql(format!(r#""{}""#, escaped_name));
250 }
251 Identifier::CollectionProperty(collection, name) => {
252 let escaped_collection = collection.replace('"', "\"\"");
254 let escaped_name = name.replace('"', "\"\"");
255 self.sql(format!(r#""{}"."{}""#, escaped_collection, escaped_name));
256 }
257 }
258
259 match order_by.direction {
261 OrderDirection::Asc => self.sql(" ASC"),
262 OrderDirection::Desc => self.sql(" DESC"),
263 }
264
265 Ok(())
266 }
267}
268
269fn comparison_op_to_sql(op: &ComparisonOperator) -> Result<&'static str, SqlGenerationError> {
270 Ok(match op {
271 ComparisonOperator::Equal => "=",
272 ComparisonOperator::NotEqual => "<>",
273 ComparisonOperator::GreaterThan => ">",
274 ComparisonOperator::GreaterThanOrEqual => ">=",
275 ComparisonOperator::LessThan => "<",
276 ComparisonOperator::LessThanOrEqual => "<=",
277 ComparisonOperator::In => "IN",
278 ComparisonOperator::Between => return Err(SqlGenerationError::UnsupportedOperator("BETWEEN operator is not yet supported")),
279 })
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use ankql::parser::parse_selection;
    use anyhow::Result;

    /// Compares bound arguments by their `Debug` rendering, since `dyn ToSql` has no `PartialEq`.
    fn assert_args(args: &[Box<dyn ToSql + Send + Sync>], expected: &[Box<dyn ToSql + Send + Sync>]) {
        assert_eq!(format!("{:?}", args), format!("{:?}", expected));
    }

    #[test]
    fn test_simple_equality() -> Result<()> {
        let selection = parse_selection("name = 'Alice'").unwrap();
        let mut sql = SqlBuilder::new();
        sql.selection(&selection)?;

        let (sql_string, args) = sql.build_where_clause();
        assert_eq!(sql_string, r#""name" = $1"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice")];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_and_condition() -> Result<()> {
        let selection = parse_selection("name = 'Alice' AND age = 30").unwrap();
        let mut sql = SqlBuilder::with_fields(vec!["id", "name", "age"]);
        sql.table_name("users");
        sql.selection(&selection)?;
        let (sql_string, args) = sql.build()?;

        assert_eq!(sql_string, r#"SELECT "id", "name", "age" FROM "users" WHERE "name" = $1 AND "age" = $2"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice"), Box::new(30)];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_complex_condition() -> Result<()> {
        let selection = parse_selection("(name = 'Alice' OR name = 'Charlie') AND age >= 30 AND age <= 40").unwrap();

        let mut sql = SqlBuilder::with_fields(vec!["id", "name", "age"]);
        sql.table_name("users");
        sql.selection(&selection)?;
        let (sql_string, args) = sql.build()?;

        assert_eq!(
            sql_string,
            r#"SELECT "id", "name", "age" FROM "users" WHERE ("name" = $1 OR "name" = $2) AND "age" >= $3 AND "age" <= $4"#
        );
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice"), Box::new("Charlie"), Box::new(30), Box::new(40)];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_including_collection_identifier() -> Result<()> {
        let selection = parse_selection("person.name = 'Alice'").unwrap();

        let mut sql = SqlBuilder::with_fields(vec!["id", "name"]);
        sql.table_name("people");
        sql.selection(&selection)?;
        let (sql_string, args) = sql.build()?;

        assert_eq!(sql_string, r#"SELECT "id", "name" FROM "people" WHERE "person"."name" = $1"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice")];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_false_predicate() -> Result<()> {
        let mut sql = SqlBuilder::with_fields(vec!["id"]);
        sql.table_name("test");
        sql.predicate(&Predicate::False)?;
        let (sql_string, args) = sql.build()?;

        assert_eq!(sql_string, r#"SELECT "id" FROM "test" WHERE FALSE"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_true_predicate_where_clause() -> Result<()> {
        let mut sql = SqlBuilder::new();
        sql.predicate(&Predicate::True)?;
        let (sql_string, args) = sql.build_where_clause();

        assert_eq!(sql_string, "TRUE");
        assert!(args.is_empty());
        Ok(())
    }

    #[test]
    fn test_in_operator() -> Result<()> {
        let selection = parse_selection("name IN ('Alice', 'Bob', 'Charlie')").unwrap();
        let mut sql = SqlBuilder::with_fields(vec!["id", "name"]);
        sql.table_name("users");
        sql.selection(&selection)?;
        let (sql_string, args) = sql.build()?;

        assert_eq!(sql_string, r#"SELECT "id", "name" FROM "users" WHERE "name" IN ($1, $2, $3)"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice"), Box::new("Bob"), Box::new("Charlie")];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_placeholder_error() {
        let mut sql = SqlBuilder::with_fields(vec!["id"]);
        sql.table_name("test");
        let err = sql.predicate(&Predicate::Placeholder).expect_err("Expected an error");
        assert!(matches!(err, SqlGenerationError::PlaceholderFound));
    }

    #[test]
    fn test_build_requires_fields_and_table() {
        let mut no_table = SqlBuilder::with_fields(vec!["id"]);
        no_table.predicate(&Predicate::True).unwrap();
        assert!(matches!(no_table.build(), Err(SqlGenerationError::IncompleteConfiguration)));

        let mut no_fields = SqlBuilder::new();
        no_fields.table_name("users");
        no_fields.predicate(&Predicate::True).unwrap();
        assert!(matches!(no_fields.build(), Err(SqlGenerationError::IncompleteConfiguration)));
    }

    #[test]
    fn test_identifiers_are_quoted_and_escaped() -> Result<()> {
        let mut sql = SqlBuilder::with_fields(vec![r#"we"ird"#]);
        sql.table_name(r#"my"table"#);
        sql.predicate(&Predicate::True)?;
        let (sql_string, args) = sql.build()?;

        assert_eq!(sql_string, r#"SELECT "we""ird" FROM "my""table" WHERE TRUE"#);
        assert!(args.is_empty());
        Ok(())
    }

    #[test]
    fn test_comparison_operator_mapping() {
        assert_eq!(comparison_op_to_sql(&ComparisonOperator::Equal).unwrap(), "=");
        assert_eq!(comparison_op_to_sql(&ComparisonOperator::NotEqual).unwrap(), "<>");
        assert_eq!(comparison_op_to_sql(&ComparisonOperator::In).unwrap(), "IN");
        assert!(matches!(
            comparison_op_to_sql(&ComparisonOperator::Between),
            Err(SqlGenerationError::UnsupportedOperator(_))
        ));
    }

    #[test]
    fn test_selection_with_order_by() -> Result<()> {
        let base_selection = parse_selection("name = 'Alice'").unwrap();
        let selection = Selection {
            predicate: base_selection.predicate,
            order_by: Some(vec![OrderByItem {
                identifier: Identifier::Property("created_at".to_string()),
                direction: OrderDirection::Desc,
            }]),
            limit: None,
        };

        let mut sql = SqlBuilder::with_fields(vec!["id", "name", "created_at"]);
        sql.table_name("users");
        sql.selection(&selection)?;
        let (sql_string, args) = sql.build()?;

        assert_eq!(sql_string, r#"SELECT "id", "name", "created_at" FROM "users" WHERE "name" = $1 ORDER BY "created_at" DESC"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice")];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_selection_with_limit() -> Result<()> {
        let base_selection = parse_selection("age > 18").unwrap();
        let selection = Selection { predicate: base_selection.predicate, order_by: None, limit: Some(10) };

        let mut sql = SqlBuilder::with_fields(vec!["id", "name", "age"]);
        sql.table_name("users");
        sql.selection(&selection)?;
        let (sql_string, args) = sql.build()?;

        assert_eq!(sql_string, r#"SELECT "id", "name", "age" FROM "users" WHERE "age" > $1 LIMIT $2"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new(18i64), Box::new(10i64)];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_selection_with_order_by_and_limit() -> Result<()> {
        let base_selection = parse_selection("status = 'active'").unwrap();
        let selection = Selection {
            predicate: base_selection.predicate,
            order_by: Some(vec![
                OrderByItem { identifier: Identifier::Property("priority".to_string()), direction: OrderDirection::Desc },
                OrderByItem { identifier: Identifier::Property("created_at".to_string()), direction: OrderDirection::Asc },
            ]),
            limit: Some(5),
        };

        let mut sql = SqlBuilder::with_fields(vec!["id", "status", "priority", "created_at"]);
        sql.table_name("tasks");
        sql.selection(&selection)?;
        let (sql_string, args) = sql.build()?;

        assert_eq!(
            sql_string,
            r#"SELECT "id", "status", "priority", "created_at" FROM "tasks" WHERE "status" = $1 ORDER BY "priority" DESC, "created_at" ASC LIMIT $2"#
        );
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("active"), Box::new(5i64)];
        assert_args(&args, &expected);
        Ok(())
    }
}