// ankurah_storage_postgres/predicate.rs
use std::fmt::Write;

use ankql::ast::{ComparisonOperator, Expr, Identifier, Literal, Predicate};
use tokio_postgres::types::ToSql;

4pub enum SqlExpr {
5 Sql(String),
6 Argument(Box<dyn ToSql + Send + Sync>),
7}
8
9pub struct Sql(Vec<SqlExpr>);
10
11impl Sql {
12 pub fn new() -> Self { Self(Vec::new()) }
13
14 pub fn push(&mut self, expr: SqlExpr) { self.0.push(expr); }
15
16 pub fn arg(&mut self, arg: impl ToSql + Send + Sync + 'static) {
17 self.push(SqlExpr::Argument(Box::new(arg) as Box<dyn ToSql + Send + Sync>));
18 }
19
20 pub fn sql(&mut self, s: impl AsRef<str>) { self.push(SqlExpr::Sql(s.as_ref().to_owned())); }
21
22 pub fn collapse(self) -> (String, Vec<Box<dyn ToSql + Send + Sync>>) {
23 let mut counter = 1;
24 let mut sql = String::new();
25 let mut args = Vec::new();
26
27 for expr in self.0 {
28 match expr {
29 SqlExpr::Argument(arg) => {
30 sql += &format!("${}", counter);
31 args.push(arg);
32 counter += 1;
33 }
34 SqlExpr::Sql(s) => {
35 sql += &s;
36 }
37 }
38 }
39
40 (sql, args)
41 }
42
43 pub fn expr(&mut self, expr: &Expr) {
45 match expr {
46 Expr::Literal(lit) => match lit {
47 Literal::String(s) => self.arg(s.to_owned()),
48 Literal::Integer(int) => self.arg(*int),
49 Literal::Float(float) => self.arg(*float),
50 Literal::Boolean(bool) => self.arg(*bool),
51 },
52 Expr::Identifier(id) => match id {
53 Identifier::Property(name) => self.sql(format!(r#""{}""#, name)),
54 Identifier::CollectionProperty(collection, name) => {
55 self.sql(format!(r#""{}"."{}""#, collection, name));
56 }
57 },
58 _ => unimplemented!("Only literal and identifier expressions are supported"),
59 }
60 }
61
62 pub fn comparison_op(&mut self, op: &ComparisonOperator) { self.sql(comparison_op_to_sql(op)); }
63
64 pub fn predicate(&mut self, predicate: &Predicate) {
65 match predicate {
66 Predicate::Comparison { left, operator, right } => {
67 self.expr(left);
68 self.sql(" ");
69 self.comparison_op(operator);
70 self.sql(" ");
71 self.expr(right);
72 }
73 Predicate::And(left, right) => {
74 self.predicate(left);
75 self.sql(" AND ");
76 self.predicate(right);
77 }
78 Predicate::Or(left, right) => {
79 self.sql("(");
80 self.predicate(left);
81 self.sql(" OR ");
82 self.predicate(right);
83 self.sql(")");
84 }
85 Predicate::Not(pred) => {
86 self.sql("NOT (");
87 self.predicate(pred);
88 self.sql("NOT )");
89 }
90 Predicate::IsNull(expr) => {
91 self.expr(expr);
92 self.sql(" IS NULL");
93 }
94 Predicate::True => {
95 self.sql("TRUE");
96 }
97 Predicate::False => {
98 self.sql("FALSE");
99 }
100 }
101 }
102}
103
104fn comparison_op_to_sql(op: &ComparisonOperator) -> &'static str {
105 match op {
106 ComparisonOperator::Equal => "=",
107 ComparisonOperator::NotEqual => "<>",
108 ComparisonOperator::GreaterThan => ">",
109 ComparisonOperator::GreaterThanOrEqual => ">=",
110 ComparisonOperator::LessThan => "<",
111 ComparisonOperator::LessThanOrEqual => "<=",
112 _ => unimplemented!("Only basic comparison operators are supported"),
113 }
114}
115
116#[cfg(test)]
117mod tests {
118 use super::*;
119 use ankql::parser::parse_selection;
120
121 fn assert_args<'a, 'b>(args: &Vec<Box<dyn ToSql + Send + Sync>>, expected: &Vec<Box<dyn ToSql + Send + Sync>>) {
122 assert_eq!(format!("{:?}", args), format!("{:?}", expected));
124 }
125
126 #[test]
127 fn test_simple_equality() {
128 let predicate = parse_selection("name = 'Alice'").unwrap();
129 let mut sql = Sql::new();
130 sql.predicate(&predicate);
131
132 let (sql_string, args) = sql.collapse();
133 assert_eq!(sql_string, r#""name" = $1"#);
134 let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice")];
135 assert_args(&args, &expected);
136 }
137
138 #[test]
139 fn test_and_condition() {
140 let predicate = parse_selection("name = 'Alice' AND age = 30").unwrap();
141 let mut sql = Sql::new();
142 sql.predicate(&predicate);
143 let (sql_string, args) = sql.collapse();
144
145 assert_eq!(sql_string, r#""name" = $1 AND "age" = $2"#);
146 let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice"), Box::new(30)];
147 assert_args(&args, &expected);
148 }
149
150 #[test]
151 fn test_complex_condition() {
152 let predicate = parse_selection("(name = 'Alice' OR name = 'Charlie') AND age >= 30 AND age <= 40").unwrap();
153
154 let mut sql = Sql::new();
155 sql.predicate(&predicate);
156 let (sql_string, args) = sql.collapse();
157
158 assert_eq!(sql_string, r#"("name" = $1 OR "name" = $2) AND "age" >= $3 AND "age" <= $4"#);
159 let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice"), Box::new("Charlie"), Box::new(30), Box::new(40)];
160 assert_args(&args, &expected);
161 }
162
163 #[test]
164 fn test_including_collection_identifier() {
165 let predicate = parse_selection("person.name = 'Alice'").unwrap();
166
167 let mut sql = Sql::new();
168 sql.predicate(&predicate);
169 let (sql_string, args) = sql.collapse();
170
171 assert_eq!(sql_string, r#""person"."name" = $1"#);
172 let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice")];
173 assert_args(&args, &expected);
174 }
175
176 #[test]
177 fn test_false_predicate() {
178 let mut sql = Sql::new();
179 sql.predicate(&Predicate::False);
180 let (sql_string, args) = sql.collapse();
181
182 assert_eq!(sql_string, "FALSE");
183 let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![];
184 assert_args(&args, &expected);
185 }
186}