ankurah_storage_postgres/
predicate.rs1use ankql::ast::{ComparisonOperator, Expr, Identifier, Literal, Predicate};
2use thiserror::Error;
3use tokio_postgres::types::ToSql;
4
5#[derive(Debug, Error, Clone)]
6pub enum SqlGenerationError {
7 #[error("Placeholder found in predicate - placeholders should be replaced before predicate processing")]
8 PlaceholderFound,
9 #[error("Unsupported expression type: {0}")]
10 UnsupportedExpression(&'static str),
11 #[error("Unsupported operator: {0}")]
12 UnsupportedOperator(&'static str),
13}
14
15pub enum SqlExpr {
16 Sql(String),
17 Argument(Box<dyn ToSql + Send + Sync>),
18}
19
20pub struct Sql(Vec<SqlExpr>);
21
22impl Default for Sql {
23 fn default() -> Self { Self::new() }
24}
25
26impl Sql {
27 pub fn new() -> Self { Self(Vec::new()) }
28
29 pub fn push(&mut self, expr: SqlExpr) { self.0.push(expr); }
30
31 pub fn arg(&mut self, arg: impl ToSql + Send + Sync + 'static) {
32 self.push(SqlExpr::Argument(Box::new(arg) as Box<dyn ToSql + Send + Sync>));
33 }
34
35 pub fn sql(&mut self, s: impl AsRef<str>) { self.push(SqlExpr::Sql(s.as_ref().to_owned())); }
36
37 pub fn collapse(self) -> (String, Vec<Box<dyn ToSql + Send + Sync>>) {
38 let mut counter = 1;
39 let mut sql = String::new();
40 let mut args = Vec::new();
41
42 for expr in self.0 {
43 match expr {
44 SqlExpr::Argument(arg) => {
45 sql += &format!("${}", counter);
46 args.push(arg);
47 counter += 1;
48 }
49 SqlExpr::Sql(s) => {
50 sql += &s;
51 }
52 }
53 }
54
55 (sql, args)
56 }
57
58 pub fn expr(&mut self, expr: &Expr) -> Result<(), SqlGenerationError> {
60 match expr {
61 Expr::Placeholder => return Err(SqlGenerationError::PlaceholderFound),
62 Expr::Literal(lit) => match lit {
63 Literal::String(s) => self.arg(s.to_owned()),
64 Literal::Integer(int) => self.arg(*int),
65 Literal::Float(float) => self.arg(*float),
66 Literal::Boolean(bool) => self.arg(*bool),
67 },
68 Expr::Identifier(id) => match id {
69 Identifier::Property(name) => self.sql(format!(r#""{}""#, name)),
70 Identifier::CollectionProperty(collection, name) => {
71 self.sql(format!(r#""{}"."{}""#, collection, name));
72 }
73 },
74 Expr::ExprList(exprs) => {
75 self.sql("(");
76 for (i, expr) in exprs.iter().enumerate() {
77 if i > 0 {
78 self.sql(", ");
79 }
80 match expr {
81 Expr::Placeholder => return Err(SqlGenerationError::PlaceholderFound),
82 Expr::Literal(lit) => match lit {
83 Literal::String(s) => self.arg(s.to_owned()),
84 Literal::Integer(int) => self.arg(*int),
85 Literal::Float(float) => self.arg(*float),
86 Literal::Boolean(bool) => self.arg(*bool),
87 },
88 _ => {
89 return Err(SqlGenerationError::UnsupportedExpression(
90 "Only literal expressions and placeholders are supported in IN lists",
91 ))
92 }
93 }
94 }
95 self.sql(")");
96 }
97 _ => return Err(SqlGenerationError::UnsupportedExpression("Only literal, identifier, and list expressions are supported")),
98 }
99 Ok(())
100 }
101
102 pub fn comparison_op(&mut self, op: &ComparisonOperator) -> Result<(), SqlGenerationError> { Ok(self.sql(comparison_op_to_sql(op)?)) }
103
104 pub fn predicate(&mut self, predicate: &Predicate) -> Result<(), SqlGenerationError> {
105 match predicate {
106 Predicate::Comparison { left, operator, right } => {
107 self.expr(left)?;
108 self.sql(" ");
109 self.comparison_op(operator);
110 self.sql(" ");
111 self.expr(right)?;
112 }
113 Predicate::And(left, right) => {
114 self.predicate(left)?;
115 self.sql(" AND ");
116 self.predicate(right)?;
117 }
118 Predicate::Or(left, right) => {
119 self.sql("(");
120 self.predicate(left)?;
121 self.sql(" OR ");
122 self.predicate(right)?;
123 self.sql(")");
124 }
125 Predicate::Not(pred) => {
126 self.sql("NOT (");
127 self.predicate(pred)?;
128 self.sql(")");
129 }
130 Predicate::IsNull(expr) => {
131 self.expr(expr)?;
132 self.sql(" IS NULL");
133 }
134 Predicate::True => {
135 self.sql("TRUE");
136 }
137 Predicate::False => {
138 self.sql("FALSE");
139 }
140 Predicate::Placeholder => {
141 return Err(SqlGenerationError::PlaceholderFound);
142 }
143 }
144 Ok(())
145 }
146}
147
148fn comparison_op_to_sql(op: &ComparisonOperator) -> Result<&'static str, SqlGenerationError> {
149 Ok(match op {
150 ComparisonOperator::Equal => "=",
151 ComparisonOperator::NotEqual => "<>",
152 ComparisonOperator::GreaterThan => ">",
153 ComparisonOperator::GreaterThanOrEqual => ">=",
154 ComparisonOperator::LessThan => "<",
155 ComparisonOperator::LessThanOrEqual => "<=",
156 ComparisonOperator::In => "IN",
157 ComparisonOperator::Between => return Err(SqlGenerationError::UnsupportedOperator("BETWEEN operator is not yet supported")),
158 })
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use ankql::parser::parse_selection;
    use anyhow::Result;

    /// Compares bound arguments by their `Debug` output, since `dyn ToSql` has no `PartialEq`.
    fn assert_args(args: &[Box<dyn ToSql + Send + Sync>], expected: &[Box<dyn ToSql + Send + Sync>]) {
        assert_eq!(format!("{:?}", args), format!("{:?}", expected));
    }

    #[test]
    fn test_simple_equality() -> Result<()> {
        let predicate = parse_selection("name = 'Alice'").unwrap();
        let mut sql = Sql::new();
        sql.predicate(&predicate)?;

        let (sql_string, args) = sql.collapse();
        assert_eq!(sql_string, r#""name" = $1"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice")];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_and_condition() -> Result<()> {
        let predicate = parse_selection("name = 'Alice' AND age = 30").unwrap();
        let mut sql = Sql::new();
        sql.predicate(&predicate)?;
        let (sql_string, args) = sql.collapse();

        assert_eq!(sql_string, r#""name" = $1 AND "age" = $2"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice"), Box::new(30)];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_complex_condition() -> Result<()> {
        let predicate = parse_selection("(name = 'Alice' OR name = 'Charlie') AND age >= 30 AND age <= 40").unwrap();

        let mut sql = Sql::new();
        sql.predicate(&predicate)?;
        let (sql_string, args) = sql.collapse();

        assert_eq!(sql_string, r#"("name" = $1 OR "name" = $2) AND "age" >= $3 AND "age" <= $4"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice"), Box::new("Charlie"), Box::new(30), Box::new(40)];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_including_collection_identifier() -> Result<()> {
        let predicate = parse_selection("person.name = 'Alice'").unwrap();

        let mut sql = Sql::new();
        sql.predicate(&predicate)?;
        let (sql_string, args) = sql.collapse();

        assert_eq!(sql_string, r#""person"."name" = $1"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice")];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_false_predicate() -> Result<()> {
        let mut sql = Sql::new();
        sql.predicate(&Predicate::False)?;
        let (sql_string, args) = sql.collapse();

        assert_eq!(sql_string, "FALSE");
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_true_predicate() -> Result<()> {
        let mut sql = Sql::new();
        sql.predicate(&Predicate::True)?;
        let (sql_string, args) = sql.collapse();

        assert_eq!(sql_string, "TRUE");
        assert!(args.is_empty());
        Ok(())
    }

    #[test]
    fn test_in_operator() -> Result<()> {
        let predicate = parse_selection("name IN ('Alice', 'Bob', 'Charlie')").unwrap();
        let mut sql = Sql::new();
        sql.predicate(&predicate)?;
        let (sql_string, args) = sql.collapse();

        assert_eq!(sql_string, r#""name" IN ($1, $2, $3)"#);
        let expected: Vec<Box<dyn ToSql + Send + Sync>> = vec![Box::new("Alice"), Box::new("Bob"), Box::new("Charlie")];
        assert_args(&args, &expected);
        Ok(())
    }

    #[test]
    fn test_placeholder_error() {
        let mut sql = Sql::new();
        let err = sql.predicate(&Predicate::Placeholder).expect_err("Expected an error");
        assert!(matches!(err, SqlGenerationError::PlaceholderFound));
    }
}