1use ankql::ast::{ComparisonOperator, Expr, Literal, OrderByItem, OrderDirection, Predicate, Selection};
2use ankurah_core::{error::RetrievalError, EntityId};
3use thiserror::Error;
4use tokio_postgres::types::ToSql;
5
6#[derive(Debug, Error, Clone)]
7pub enum SqlGenerationError {
8 #[error("Placeholder found in predicate - placeholders should be replaced before predicate processing")]
9 PlaceholderFound,
10 #[error("Unsupported expression type: {0}")]
11 UnsupportedExpression(&'static str),
12 #[error("Unsupported operator: {0}")]
13 UnsupportedOperator(&'static str),
14 #[error("SqlBuilder requires both fields and table_name to be set for complete SELECT generation, or neither for WHERE-only mode")]
15 IncompleteConfiguration,
16}
17
18#[derive(Debug, Clone)]
25pub struct SplitPredicate {
26 pub sql_predicate: Predicate,
28 pub remaining_predicate: Predicate,
30}
31
32impl SplitPredicate {
33 pub fn needs_post_filter(&self) -> bool { !matches!(self.remaining_predicate, Predicate::True) }
35}
36
37pub fn split_predicate_for_postgres(predicate: &Predicate) -> SplitPredicate {
48 let (sql_pred, remaining_pred) = split_predicate_recursive(predicate);
53
54 SplitPredicate { sql_predicate: sql_pred, remaining_predicate: remaining_pred }
55}
56
57fn split_predicate_recursive(predicate: &Predicate) -> (Predicate, Predicate) {
59 match predicate {
60 Predicate::Comparison { left, operator: _, right } => {
62 if can_pushdown_comparison(left, right) {
63 (predicate.clone(), Predicate::True)
64 } else {
65 (Predicate::True, predicate.clone())
67 }
68 }
69
70 Predicate::And(left, right) => {
72 let (left_sql, left_remaining) = split_predicate_recursive(left);
73 let (right_sql, right_remaining) = split_predicate_recursive(right);
74
75 let sql_pred = match (&left_sql, &right_sql) {
76 (Predicate::True, Predicate::True) => Predicate::True,
77 (Predicate::True, _) => right_sql,
78 (_, Predicate::True) => left_sql,
79 _ => Predicate::And(Box::new(left_sql), Box::new(right_sql)),
80 };
81
82 let remaining_pred = match (&left_remaining, &right_remaining) {
83 (Predicate::True, Predicate::True) => Predicate::True,
84 (Predicate::True, _) => right_remaining,
85 (_, Predicate::True) => left_remaining,
86 _ => Predicate::And(Box::new(left_remaining), Box::new(right_remaining)),
87 };
88
89 (sql_pred, remaining_pred)
90 }
91
92 Predicate::Or(left, right) => {
95 let (left_sql, left_remaining) = split_predicate_recursive(left);
96 let (right_sql, right_remaining) = split_predicate_recursive(right);
97
98 if matches!(left_remaining, Predicate::True) && matches!(right_remaining, Predicate::True) {
100 (predicate.clone(), Predicate::True)
101 } else {
102 let sql_pred = match (&left_sql, &right_sql) {
105 (Predicate::True, Predicate::True) => Predicate::True,
106 (Predicate::True, _) => right_sql,
107 (_, Predicate::True) => left_sql,
108 _ => Predicate::Or(Box::new(left_sql), Box::new(right_sql)),
109 };
110 (sql_pred, predicate.clone())
111 }
112 }
113
114 Predicate::Not(inner) => {
116 let (inner_sql, inner_remaining) = split_predicate_recursive(inner);
117 if matches!(inner_remaining, Predicate::True) {
118 (Predicate::Not(Box::new(inner_sql)), Predicate::True)
119 } else {
120 (Predicate::True, predicate.clone())
122 }
123 }
124
125 Predicate::IsNull(expr) => {
127 if can_pushdown_expr(expr) {
128 (predicate.clone(), Predicate::True)
129 } else {
130 (Predicate::True, predicate.clone())
131 }
132 }
133
134 Predicate::True => (Predicate::True, Predicate::True),
135 Predicate::False => (Predicate::False, Predicate::True),
136 Predicate::Placeholder => (Predicate::True, predicate.clone()), }
138}
139
140fn can_pushdown_comparison(left: &Expr, right: &Expr) -> bool { can_pushdown_expr(left) && can_pushdown_expr(right) }
142
143fn can_pushdown_expr(expr: &Expr) -> bool {
165 match expr {
166 Expr::Literal(_) => true,
167 Expr::Path(path) => {
168 !path.steps.is_empty()
176 }
177 Expr::ExprList(exprs) => exprs.iter().all(can_pushdown_expr),
178 Expr::Predicate(_) => false, Expr::InfixExpr { .. } => false, Expr::Placeholder => false, }
182}
183
184impl From<SqlGenerationError> for RetrievalError {
185 fn from(err: SqlGenerationError) -> Self { RetrievalError::StorageError(Box::new(err)) }
186}
187
188pub enum SqlExpr {
189 Sql(String),
190 Argument(Box<dyn ToSql + Send + Sync>),
191}
192
193pub struct SqlBuilder {
194 expressions: Vec<SqlExpr>,
195 fields: Vec<String>,
196 table_name: Option<String>,
197}
198
199impl Default for SqlBuilder {
200 fn default() -> Self { Self::new() }
201}
202
203impl SqlBuilder {
204 pub fn new() -> Self { Self { expressions: Vec::new(), fields: Vec::new(), table_name: None } }
205
206 pub fn with_fields<T: Into<String>>(fields: Vec<T>) -> Self {
207 Self { expressions: Vec::new(), fields: fields.into_iter().map(|f| f.into()).collect(), table_name: None }
208 }
209
210 pub fn table_name(&mut self, name: impl Into<String>) -> &mut Self {
211 self.table_name = Some(name.into());
212 self
213 }
214
215 pub fn push(&mut self, expr: SqlExpr) { self.expressions.push(expr); }
216
217 pub fn arg(&mut self, arg: impl ToSql + Send + Sync + 'static) {
218 self.push(SqlExpr::Argument(Box::new(arg) as Box<dyn ToSql + Send + Sync>));
219 }
220
221 pub fn sql(&mut self, s: impl AsRef<str>) { self.push(SqlExpr::Sql(s.as_ref().to_owned())); }
222
223 pub fn build(self) -> Result<(String, Vec<Box<dyn ToSql + Send + Sync>>), SqlGenerationError> {
224 let mut counter = 1;
225 let mut where_clause = String::new();
226 let mut args = Vec::new();
227
228 for expr in self.expressions {
230 match expr {
231 SqlExpr::Argument(arg) => {
232 where_clause += &format!("${}", counter);
233 args.push(arg);
234 counter += 1;
235 }
236 SqlExpr::Sql(s) => {
237 where_clause += &s;
238 }
239 }
240 }
241
242 if self.fields.is_empty() || self.table_name.is_none() {
244 return Err(SqlGenerationError::IncompleteConfiguration);
245 }
246
247 let fields_clause = self.fields.iter().map(|field| format!(r#""{}""#, field.replace('"', "\"\""))).collect::<Vec<_>>().join(", ");
248 let table = self.table_name.unwrap();
249 let sql = format!(r#"SELECT {} FROM "{}" WHERE {}"#, fields_clause, table.replace('"', "\"\""), where_clause);
250
251 Ok((sql, args))
252 }
253
254 pub fn build_where_clause(self) -> (String, Vec<Box<dyn ToSql + Send + Sync>>) {
255 let mut counter = 1;
256 let mut where_clause = String::new();
257 let mut args = Vec::new();
258
259 for expr in self.expressions {
261 match expr {
262 SqlExpr::Argument(arg) => {
263 where_clause += &format!("${}", counter);
264 args.push(arg);
265 counter += 1;
266 }
267 SqlExpr::Sql(s) => {
268 where_clause += &s;
269 }
270 }
271 }
272
273 (where_clause, args)
274 }
275
276 pub fn expr(&mut self, expr: &Expr) -> Result<(), SqlGenerationError> {
278 match expr {
279 Expr::Placeholder => return Err(SqlGenerationError::PlaceholderFound),
280 Expr::Literal(lit) => match lit {
281 Literal::String(s) => self.arg(s.to_owned()),
282 Literal::I64(int) => self.arg(*int),
283 Literal::F64(float) => self.arg(*float),
284 Literal::Bool(bool) => self.arg(*bool),
285 Literal::I16(i) => self.arg(*i),
286 Literal::I32(i) => self.arg(*i),
287 Literal::EntityId(ulid) => self.arg(EntityId::from_ulid(*ulid).to_base64()),
288 Literal::Object(bytes) => self.arg(bytes.clone()),
289 Literal::Binary(bytes) => self.arg(bytes.clone()),
290 },
291 Expr::Path(path) => {
292 if path.is_simple() {
293 let escaped = path.first().replace('"', "\"\"");
295 self.sql(format!(r#""{}""#, escaped));
296 } else {
297 let first = path.first().replace('"', "\"\"");
301 self.sql(format!(r#""{}""#, first));
302
303 for step in path.steps.iter().skip(1) {
304 let escaped = step.replace('\'', "''");
305 self.sql(format!("->'{}'", escaped));
307 }
308 }
309 }
310 Expr::ExprList(exprs) => {
311 self.sql("(");
312 for (i, expr) in exprs.iter().enumerate() {
313 if i > 0 {
314 self.sql(", ");
315 }
316 match expr {
317 Expr::Placeholder => return Err(SqlGenerationError::PlaceholderFound),
318 Expr::Literal(lit) => match lit {
319 Literal::String(s) => self.arg(s.to_owned()),
320 Literal::I64(int) => self.arg(*int),
321 Literal::F64(float) => self.arg(*float),
322 Literal::Bool(bool) => self.arg(*bool),
323 Literal::I16(i) => self.arg(*i),
324 Literal::I32(i) => self.arg(*i),
325 Literal::EntityId(ulid) => self.arg(EntityId::from_ulid(*ulid).to_base64()),
326 Literal::Object(bytes) => self.arg(bytes.clone()),
327 Literal::Binary(bytes) => self.arg(bytes.clone()),
328 },
329 _ => {
330 return Err(SqlGenerationError::UnsupportedExpression(
331 "Only literal expressions and placeholders are supported in IN lists",
332 ))
333 }
334 }
335 }
336 self.sql(")");
337 }
338 _ => return Err(SqlGenerationError::UnsupportedExpression("Only literal, identifier, and list expressions are supported")),
339 }
340 Ok(())
341 }
342
343 pub fn expr_as_jsonb(&mut self, expr: &Expr) -> Result<(), SqlGenerationError> {
347 match expr {
348 Expr::Literal(lit) => {
349 match lit {
352 Literal::String(s) => {
353 let json_escaped = s.replace('\\', "\\\\").replace('"', "\\\"");
356 let sql_escaped = format!("\"{}\"", json_escaped).replace('\'', "''");
357 self.sql(format!("'{}'::jsonb", sql_escaped));
358 }
359 Literal::I64(n) => self.sql(format!("'{}'::jsonb", n)),
360 Literal::F64(n) => self.sql(format!("'{}'::jsonb", n)),
361 Literal::Bool(b) => self.sql(format!("'{}'::jsonb", b)),
362 Literal::I16(n) => self.sql(format!("'{}'::jsonb", n)),
363 Literal::I32(n) => self.sql(format!("'{}'::jsonb", n)),
364 Literal::EntityId(_) | Literal::Object(_) | Literal::Binary(_) => {
366 self.expr(expr)?;
368 }
369 }
370 Ok(())
371 }
372 _ => self.expr(expr),
374 }
375 }
376
377 pub fn comparison_op(&mut self, op: &ComparisonOperator) -> Result<(), SqlGenerationError> {
378 self.sql(comparison_op_to_sql(op)?);
379 Ok(())
380 }
381
382 pub fn predicate(&mut self, predicate: &Predicate) -> Result<(), SqlGenerationError> {
383 match predicate {
384 Predicate::Comparison { left, operator, right } => {
385 let left_is_jsonb = matches!(left.as_ref(), Expr::Path(p) if !p.is_simple());
387 let right_is_jsonb = matches!(right.as_ref(), Expr::Path(p) if !p.is_simple());
388
389 self.expr(left)?;
390 self.sql(" ");
391 self.comparison_op(operator)?;
392 self.sql(" ");
393
394 if left_is_jsonb && matches!(right.as_ref(), Expr::Literal(_)) {
395 self.expr_as_jsonb(right)?;
397 } else if right_is_jsonb && matches!(left.as_ref(), Expr::Literal(_)) {
398 self.expr_as_jsonb(right)?;
400 } else {
401 self.expr(right)?;
402 }
403 }
404 Predicate::And(left, right) => {
405 self.predicate(left)?;
406 self.sql(" AND ");
407 self.predicate(right)?;
408 }
409 Predicate::Or(left, right) => {
410 self.sql("(");
411 self.predicate(left)?;
412 self.sql(" OR ");
413 self.predicate(right)?;
414 self.sql(")");
415 }
416 Predicate::Not(pred) => {
417 self.sql("NOT (");
418 self.predicate(pred)?;
419 self.sql(")");
420 }
421 Predicate::IsNull(expr) => {
422 self.expr(expr)?;
423 self.sql(" IS NULL");
424 }
425 Predicate::True => {
426 self.sql("TRUE");
427 }
428 Predicate::False => {
429 self.sql("FALSE");
430 }
431 Predicate::Placeholder => {
432 return Err(SqlGenerationError::PlaceholderFound);
433 }
434 }
435 Ok(())
436 }
437
438 pub fn selection(&mut self, selection: &Selection) -> Result<(), SqlGenerationError> {
439 self.predicate(&selection.predicate)?;
441
442 if let Some(order_by_items) = &selection.order_by {
444 self.sql(" ORDER BY ");
445 for (i, order_by) in order_by_items.iter().enumerate() {
446 if i > 0 {
447 self.sql(", ");
448 }
449 self.order_by_item(order_by)?;
450 }
451 }
452
453 if let Some(limit) = selection.limit {
455 self.sql(" LIMIT ");
456 self.arg(limit as i64); }
458
459 Ok(())
460 }
461
462 pub fn order_by_item(&mut self, order_by: &OrderByItem) -> Result<(), SqlGenerationError> {
463 for (i, step) in order_by.path.steps.iter().enumerate() {
465 if i > 0 {
466 self.sql(".");
467 }
468 let escaped_step = step.replace('"', "\"\"");
470 self.sql(format!(r#""{}""#, escaped_step));
471 }
472
473 match order_by.direction {
475 OrderDirection::Asc => self.sql(" ASC"),
476 OrderDirection::Desc => self.sql(" DESC"),
477 }
478
479 Ok(())
480 }
481}
482
483fn comparison_op_to_sql(op: &ComparisonOperator) -> Result<&'static str, SqlGenerationError> {
484 Ok(match op {
485 ComparisonOperator::Equal => "=",
486 ComparisonOperator::NotEqual => "<>",
487 ComparisonOperator::GreaterThan => ">",
488 ComparisonOperator::GreaterThanOrEqual => ">=",
489 ComparisonOperator::LessThan => "<",
490 ComparisonOperator::LessThanOrEqual => "<=",
491 ComparisonOperator::In => "IN",
492 ComparisonOperator::Between => return Err(SqlGenerationError::UnsupportedOperator("BETWEEN operator is not yet supported")),
493 })
494}
495
#[cfg(test)]
mod tests {
    use super::*;
    use ankql::parser::parse_selection;
    use anyhow::Result;

    /// Compares bound arguments via their `Debug` output, since boxed `dyn ToSql`
    /// values cannot be compared with `==`.
    fn assert_args(args: &[Box<dyn ToSql + Send + Sync>], expected: &[Box<dyn ToSql + Send + Sync>]) {
        assert_eq!(format!("{:?}", args), format!("{:?}", expected));
    }

    #[test]
    fn test_simple_equality() -> Result<()> {
        let selection = parse_selection("name = 'Alice'").unwrap();
        let mut sql = SqlBuilder::new();
        sql.selection(&selection)?;

        let (sql_string, args) = sql.build_where_clause();
        assert_eq!(sql_string, r#""name" = $1"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice")];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_and_condition() -> Result<()> {
        let selection = parse_selection("name = 'Alice' AND age = 30").unwrap();
        let mut sql = SqlBuilder::with_fields(vec!["id", "name", "age"]);
        sql.table_name("users");
        sql.selection(&selection)?;
        let (sql_string, args) = sql.build()?;

        assert_eq!(sql_string, r#"SELECT "id", "name", "age" FROM "users" WHERE "name" = $1 AND "age" = $2"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice"), Box::new(30)];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_complex_condition() -> Result<()> {
        let selection = parse_selection("(name = 'Alice' OR name = 'Charlie') AND age >= 30 AND age <= 40").unwrap();

        let mut sql = SqlBuilder::with_fields(vec!["id", "name", "age"]);
        sql.table_name("users");
        sql.selection(&selection)?;
        let (sql_string, args) = sql.build()?;

        assert_eq!(
            sql_string,
            r#"SELECT "id", "name", "age" FROM "users" WHERE ("name" = $1 OR "name" = $2) AND "age" >= $3 AND "age" <= $4"#
        );
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice"), Box::new("Charlie"), Box::new(30), Box::new(40)];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_including_collection_identifier() -> Result<()> {
        // A multi-step path renders as a JSONB lookup on its first step, and the
        // compared literal is inlined as JSONB rather than bound.
        let selection = parse_selection("person.name = 'Alice'").unwrap();

        let mut sql = SqlBuilder::with_fields(vec!["id", "name"]);
        sql.table_name("people");
        sql.selection(&selection)?;
        let (sql_string, args) = sql.build()?;

        assert_eq!(sql_string, r#"SELECT "id", "name" FROM "people" WHERE "person"->'name' = '"Alice"'::jsonb"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_false_predicate() -> Result<()> {
        let mut sql = SqlBuilder::with_fields(vec!["id"]);
        sql.table_name("test");
        sql.predicate(&Predicate::False)?;
        let (sql_string, args) = sql.build()?;

        assert_eq!(sql_string, r#"SELECT "id" FROM "test" WHERE FALSE"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_in_operator() -> Result<()> {
        let selection = parse_selection("name IN ('Alice', 'Bob', 'Charlie')").unwrap();
        let mut sql = SqlBuilder::with_fields(vec!["id", "name"]);
        sql.table_name("users");
        sql.selection(&selection)?;
        let (sql_string, args) = sql.build()?;

        assert_eq!(sql_string, r#"SELECT "id", "name" FROM "users" WHERE "name" IN ($1, $2, $3)"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice"), Box::new("Bob"), Box::new("Charlie")];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_placeholder_error() {
        let mut sql = SqlBuilder::with_fields(vec!["id"]);
        sql.table_name("test");
        let err = sql.predicate(&Predicate::Placeholder).expect_err("Expected an error");
        assert!(matches!(err, SqlGenerationError::PlaceholderFound));
    }

    #[test]
    fn test_selection_with_order_by() -> Result<()> {
        use ankql::ast::{OrderByItem, OrderDirection, PathExpr, Selection};

        let base_selection = ankql::parser::parse_selection("name = 'Alice'").unwrap();
        let selection = Selection {
            predicate: base_selection.predicate,
            order_by: Some(vec![OrderByItem { path: PathExpr::simple("created_at"), direction: OrderDirection::Desc }]),
            limit: None,
        };

        let mut sql = SqlBuilder::with_fields(vec!["id", "name", "created_at"]);
        sql.table_name("users");
        sql.selection(&selection)?;
        let (sql_string, args) = sql.build()?;

        assert_eq!(sql_string, r#"SELECT "id", "name", "created_at" FROM "users" WHERE "name" = $1 ORDER BY "created_at" DESC"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice")];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_selection_with_limit() -> Result<()> {
        let base_selection = ankql::parser::parse_selection("age > 18").unwrap();
        let selection = Selection { predicate: base_selection.predicate, order_by: None, limit: Some(10) };

        let mut sql = SqlBuilder::with_fields(vec!["id", "name", "age"]);
        sql.table_name("users");
        sql.selection(&selection)?;
        let (sql_string, args) = sql.build()?;

        // The LIMIT is bound as an i64 argument after the predicate's arguments.
        assert_eq!(sql_string, r#"SELECT "id", "name", "age" FROM "users" WHERE "age" > $1 LIMIT $2"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new(18i64), Box::new(10i64)];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_selection_with_order_by_and_limit() -> Result<()> {
        use ankql::ast::{OrderByItem, OrderDirection, PathExpr, Selection};

        let base_selection = ankql::parser::parse_selection("status = 'active'").unwrap();
        let selection = Selection {
            predicate: base_selection.predicate,
            order_by: Some(vec![
                OrderByItem { path: PathExpr::simple("priority"), direction: OrderDirection::Desc },
                OrderByItem { path: PathExpr::simple("created_at"), direction: OrderDirection::Asc },
            ]),
            limit: Some(5),
        };

        let mut sql = SqlBuilder::with_fields(vec!["id", "status", "priority", "created_at"]);
        sql.table_name("tasks");
        sql.selection(&selection)?;
        let (sql_string, args) = sql.build()?;

        assert_eq!(
            sql_string,
            r#"SELECT "id", "status", "priority", "created_at" FROM "tasks" WHERE "status" = $1 ORDER BY "priority" DESC, "created_at" ASC LIMIT $2"#
        );
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("active"), Box::new(5i64)];
        assert_args(&args, &expected);
        Ok(())
    }

    /// Rendering of multi-step (JSONB) paths and JSONB-typed literals.
    mod jsonb_sql_tests {
        use super::*;
        use ankql::ast::PathExpr;

        #[test]
        fn test_two_step_json_path() -> Result<()> {
            let selection = parse_selection("licensing.territory = 'US'").unwrap();
            let mut sql = SqlBuilder::new();
            sql.selection(&selection)?;
            let (sql_string, _) = sql.build_where_clause();

            assert_eq!(sql_string, r#""licensing"->'territory' = '"US"'::jsonb"#);
            Ok(())
        }

        #[test]
        fn test_three_step_json_path() -> Result<()> {
            let selection = parse_selection("licensing.rights.holder = 'Label'").unwrap();
            let mut sql = SqlBuilder::new();
            sql.selection(&selection)?;
            let (sql_string, _) = sql.build_where_clause();

            assert_eq!(sql_string, r#""licensing"->'rights'->'holder' = '"Label"'::jsonb"#);
            Ok(())
        }

        #[test]
        fn test_four_step_json_path() -> Result<()> {
            let selection = parse_selection("a.b.c.d = 'value'").unwrap();
            let mut sql = SqlBuilder::new();
            sql.selection(&selection)?;
            let (sql_string, _) = sql.build_where_clause();

            assert_eq!(sql_string, r#""a"->'b'->'c'->'d' = '"value"'::jsonb"#);
            Ok(())
        }

        #[test]
        fn test_json_path_with_numeric_comparison() -> Result<()> {
            let selection = parse_selection("data.count > 10").unwrap();
            let mut sql = SqlBuilder::new();
            sql.selection(&selection)?;
            let (sql_string, _) = sql.build_where_clause();

            assert_eq!(sql_string, r#""data"->'count' > '10'::jsonb"#);
            Ok(())
        }

        #[test]
        fn test_mixed_simple_and_json_paths() -> Result<()> {
            let selection = parse_selection("name = 'test' AND data.status = 'active'").unwrap();
            let mut sql = SqlBuilder::new();
            sql.selection(&selection)?;
            let (sql_string, _) = sql.build_where_clause();

            // Simple paths bind the literal; JSONB paths inline it as JSONB.
            assert_eq!(sql_string, r#""name" = $1 AND "data"->'status' = '"active"'::jsonb"#);
            Ok(())
        }

        #[test]
        fn test_json_path_escaping() -> Result<()> {
            let mut sql = SqlBuilder::new();
            // JSONB keys are single-quoted SQL literals, so embedded quotes are doubled.
            let path = PathExpr { steps: vec!["data".to_string(), "it's".to_string()] };
            sql.expr(&Expr::Path(path))?;
            let (sql_string, _) = sql.build_where_clause();

            assert_eq!(sql_string, r#""data"->'it''s'"#);
            Ok(())
        }

        #[test]
        fn test_json_path_with_boolean() -> Result<()> {
            let selection = parse_selection("data.active = true").unwrap();
            let mut sql = SqlBuilder::new();
            sql.selection(&selection)?;
            let (sql_string, _) = sql.build_where_clause();

            assert_eq!(sql_string, r#""data"->'active' = 'true'::jsonb"#);
            Ok(())
        }

        // NOTE(review): `95` is parsed as an integer literal (compare `age > 18`
        // binding `18i64` in test_selection_with_limit), so despite its name this
        // test does not exercise `Literal::F64`.
        #[test]
        fn test_json_path_with_float() -> Result<()> {
            let selection = parse_selection("data.score >= 95").unwrap();
            let mut sql = SqlBuilder::new();
            sql.selection(&selection)?;
            let (sql_string, _) = sql.build_where_clause();

            assert_eq!(sql_string, r#""data"->'score' >= '95'::jsonb"#);
            Ok(())
        }
    }

    /// Which predicates `split_predicate_for_postgres` pushes fully into SQL.
    mod predicate_split_tests {
        use super::*;

        #[test]
        fn test_simple_predicate_fully_pushable() {
            let selection = parse_selection("name = 'Alice'").unwrap();
            let split = split_predicate_for_postgres(&selection.predicate);

            assert!(!split.needs_post_filter());
            assert!(matches!(split.remaining_predicate, Predicate::True));
        }

        #[test]
        fn test_json_path_predicate_pushable() {
            let selection = parse_selection("licensing.territory = 'US'").unwrap();
            let split = split_predicate_for_postgres(&selection.predicate);

            assert!(!split.needs_post_filter());
        }

        #[test]
        fn test_and_with_all_pushable() {
            let selection = parse_selection("name = 'test' AND licensing.status = 'active'").unwrap();
            let split = split_predicate_for_postgres(&selection.predicate);

            assert!(!split.needs_post_filter());
        }

        #[test]
        fn test_or_with_all_pushable() {
            let selection = parse_selection("name = 'a' OR name = 'b'").unwrap();
            let split = split_predicate_for_postgres(&selection.predicate);

            assert!(!split.needs_post_filter());
        }

        #[test]
        fn test_complex_nested_predicate() {
            let selection = parse_selection("(name = 'test' OR data.type = 'special') AND status = 'active'").unwrap();
            let split = split_predicate_for_postgres(&selection.predicate);

            assert!(!split.needs_post_filter());
        }

        #[test]
        fn test_not_predicate_pushable() {
            let selection = parse_selection("NOT (status = 'deleted')").unwrap();
            let split = split_predicate_for_postgres(&selection.predicate);

            assert!(!split.needs_post_filter());
        }

        #[test]
        fn test_is_null_pushable() {
            let selection = parse_selection("name IS NULL").unwrap();
            let split = split_predicate_for_postgres(&selection.predicate);

            assert!(!split.needs_post_filter());
        }
    }
}