ankurah_storage_postgres/
predicate.rs1use ankql::ast::{ComparisonOperator, Expr, Identifier, Literal, Predicate};
2use tokio_postgres::types::ToSql;
3
4pub enum SqlExpr {
5 Sql(String),
6 Argument(Box<dyn ToSql + Send + Sync>),
7}
8
9pub struct Sql(Vec<SqlExpr>);
10
11impl Default for Sql {
12 fn default() -> Self { Self::new() }
13}
14
15impl Sql {
16 pub fn new() -> Self { Self(Vec::new()) }
17
18 pub fn push(&mut self, expr: SqlExpr) { self.0.push(expr); }
19
20 pub fn arg(&mut self, arg: impl ToSql + Send + Sync + 'static) {
21 self.push(SqlExpr::Argument(Box::new(arg) as Box<dyn ToSql + Send + Sync>));
22 }
23
24 pub fn sql(&mut self, s: impl AsRef<str>) { self.push(SqlExpr::Sql(s.as_ref().to_owned())); }
25
26 pub fn collapse(self) -> (String, Vec<Box<dyn ToSql + Send + Sync>>) {
27 let mut counter = 1;
28 let mut sql = String::new();
29 let mut args = Vec::new();
30
31 for expr in self.0 {
32 match expr {
33 SqlExpr::Argument(arg) => {
34 sql += &format!("${}", counter);
35 args.push(arg);
36 counter += 1;
37 }
38 SqlExpr::Sql(s) => {
39 sql += &s;
40 }
41 }
42 }
43
44 (sql, args)
45 }
46
47 pub fn expr(&mut self, expr: &Expr) {
49 match expr {
50 Expr::Literal(lit) => match lit {
51 Literal::String(s) => self.arg(s.to_owned()),
52 Literal::Integer(int) => self.arg(*int),
53 Literal::Float(float) => self.arg(*float),
54 Literal::Boolean(bool) => self.arg(*bool),
55 },
56 Expr::Identifier(id) => match id {
57 Identifier::Property(name) => self.sql(format!(r#""{}""#, name)),
58 Identifier::CollectionProperty(collection, name) => {
59 self.sql(format!(r#""{}"."{}""#, collection, name));
60 }
61 },
62 Expr::ExprList(exprs) => {
63 self.sql("(");
64 for (i, expr) in exprs.iter().enumerate() {
65 if i > 0 {
66 self.sql(", ");
67 }
68 match expr {
69 Expr::Literal(lit) => match lit {
70 Literal::String(s) => self.arg(s.to_owned()),
71 Literal::Integer(int) => self.arg(*int),
72 Literal::Float(float) => self.arg(*float),
73 Literal::Boolean(bool) => self.arg(*bool),
74 },
75 _ => unimplemented!("Only literal expressions are supported in IN lists"),
76 }
77 }
78 self.sql(")");
79 }
80 _ => unimplemented!("Only literal, identifier, and list expressions are supported"),
81 }
82 }
83
84 pub fn comparison_op(&mut self, op: &ComparisonOperator) { self.sql(comparison_op_to_sql(op)); }
85
86 pub fn predicate(&mut self, predicate: &Predicate) {
87 match predicate {
88 Predicate::Comparison { left, operator, right } => {
89 self.expr(left);
90 self.sql(" ");
91 self.comparison_op(operator);
92 self.sql(" ");
93 self.expr(right);
94 }
95 Predicate::And(left, right) => {
96 self.predicate(left);
97 self.sql(" AND ");
98 self.predicate(right);
99 }
100 Predicate::Or(left, right) => {
101 self.sql("(");
102 self.predicate(left);
103 self.sql(" OR ");
104 self.predicate(right);
105 self.sql(")");
106 }
107 Predicate::Not(pred) => {
108 self.sql("NOT (");
109 self.predicate(pred);
110 self.sql("NOT )");
111 }
112 Predicate::IsNull(expr) => {
113 self.expr(expr);
114 self.sql(" IS NULL");
115 }
116 Predicate::True => {
117 self.sql("TRUE");
118 }
119 Predicate::False => {
120 self.sql("FALSE");
121 }
122 }
123 }
124}
125
126fn comparison_op_to_sql(op: &ComparisonOperator) -> &'static str {
127 match op {
128 ComparisonOperator::Equal => "=",
129 ComparisonOperator::NotEqual => "<>",
130 ComparisonOperator::GreaterThan => ">",
131 ComparisonOperator::GreaterThanOrEqual => ">=",
132 ComparisonOperator::LessThan => "<",
133 ComparisonOperator::LessThanOrEqual => "<=",
134 ComparisonOperator::In => "IN",
135 ComparisonOperator::Between => unimplemented!("BETWEEN operator is not yet supported"),
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use ankql::parser::parse_selection;

    /// Asserts that two bound-argument lists match.
    ///
    /// `ToSql` values cannot be compared directly, so their `Debug` renderings are compared
    /// instead. For example, a `String` "Alice" and a `&str` "Alice" both render as `"Alice"`.
    fn assert_args(args: &[Box<dyn ToSql + Send + Sync>], expected: &[Box<dyn ToSql + Send + Sync>]) {
        assert_eq!(format!("{:?}", args), format!("{:?}", expected));
    }

    #[test]
    fn test_simple_equality() {
        let predicate = parse_selection("name = 'Alice'").unwrap();
        let mut sql = Sql::new();
        sql.predicate(&predicate);

        let (sql_string, args) = sql.collapse();
        assert_eq!(sql_string, r#""name" = $1"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice")];
        assert_args(&args, &expected);
    }

    #[test]
    fn test_and_condition() {
        let predicate = parse_selection("name = 'Alice' AND age = 30").unwrap();
        let mut sql = Sql::new();
        sql.predicate(&predicate);
        let (sql_string, args) = sql.collapse();

        assert_eq!(sql_string, r#""name" = $1 AND "age" = $2"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice"), Box::new(30)];
        assert_args(&args, &expected);
    }

    #[test]
    fn test_complex_condition() {
        let predicate = parse_selection("(name = 'Alice' OR name = 'Charlie') AND age >= 30 AND age <= 40").unwrap();

        let mut sql = Sql::new();
        sql.predicate(&predicate);
        let (sql_string, args) = sql.collapse();

        assert_eq!(sql_string, r#"("name" = $1 OR "name" = $2) AND "age" >= $3 AND "age" <= $4"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice"), Box::new("Charlie"), Box::new(30), Box::new(40)];
        assert_args(&args, &expected);
    }

    #[test]
    fn test_including_collection_identifier() {
        let predicate = parse_selection("person.name = 'Alice'").unwrap();

        let mut sql = Sql::new();
        sql.predicate(&predicate);
        let (sql_string, args) = sql.collapse();

        assert_eq!(sql_string, r#""person"."name" = $1"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice")];
        assert_args(&args, &expected);
    }

    #[test]
    fn test_false_predicate() {
        let mut sql = Sql::new();
        sql.predicate(&Predicate::False);
        let (sql_string, args) = sql.collapse();

        assert_eq!(sql_string, "FALSE");
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![];
        assert_args(&args, &expected);
    }

    #[test]
    fn test_true_predicate() {
        let mut sql = Sql::new();
        sql.predicate(&Predicate::True);
        let (sql_string, args) = sql.collapse();

        assert_eq!(sql_string, "TRUE");
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![];
        assert_args(&args, &expected);
    }

    #[test]
    fn test_in_operator() {
        let predicate = parse_selection("name IN ('Alice', 'Bob', 'Charlie')").unwrap();
        let mut sql = Sql::new();
        sql.predicate(&predicate);
        let (sql_string, args) = sql.collapse();

        assert_eq!(sql_string, r#""name" IN ($1, $2, $3)"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice"), Box::new("Bob"), Box::new("Charlie")];
        assert_args(&args, &expected);
    }
}