1use crate::error::ParseError;
2use crate::selection::sql::generate_selection_sql;
3use serde::{Deserialize, Serialize};
4use ulid::Ulid;
5
/// A value-producing node of a selection expression tree.
///
/// NOTE: variant order is left untouched because index-based serialization formats would be
/// affected by reordering.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Expr {
    /// A constant value.
    Literal(Literal),
    /// A reference to a property, optionally qualified by a collection.
    Identifier(Identifier),
    /// A predicate used in expression position.
    Predicate(Predicate),
    /// A binary arithmetic expression: `left operator right`.
    InfixExpr { left: Box<Expr>, operator: InfixOperator, right: Box<Expr> },
    /// A list of expressions, e.g. the right-hand side of an `IN (...)` comparison.
    ExprList(Vec<Expr>),
    /// A `?` placeholder; `Predicate::populate` substitutes the next supplied value for it.
    Placeholder,
}
15
/// A constant value that can appear in an expression.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Literal {
    /// 16-bit signed integer.
    I16(i16),
    /// 32-bit signed integer.
    I32(i32),
    /// 64-bit signed integer.
    I64(i64),
    /// 64-bit floating point number.
    F64(f64),
    /// Boolean value.
    Bool(bool),
    /// Owned UTF-8 string.
    String(String),
    /// Identifier carried as a ULID.
    EntityId(Ulid),
    /// Raw bytes. NOTE(review): how `Object` differs from `Binary` is not visible in this file;
    /// presumably an encoded structured value — confirm against the producers.
    Object(Vec<u8>),
    /// Raw bytes.
    Binary(Vec<u8>),
}
28
/// A reference to a named property.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Identifier {
    /// An unqualified property name.
    Property(String),
    /// A property qualified by a collection: `(collection, property)`.
    /// `OrderByItem`'s `Display` renders it as `collection.property`.
    CollectionProperty(String, String),
}
34
/// A complete selection: a filter predicate plus optional ordering and row limit.
///
/// Its `Display` impl renders `<predicate>[ ORDER BY <items>][ LIMIT <n>]`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Selection {
    /// The filter condition.
    pub predicate: Predicate,
    /// Ordering terms in priority order; `None` when no `ORDER BY` is present.
    pub order_by: Option<Vec<OrderByItem>>,
    /// Value rendered after `LIMIT`; `None` when no limit is present.
    pub limit: Option<u64>,
}
41
/// A single `ORDER BY` term: which identifier to sort on and in which direction.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OrderByItem {
    /// The property to sort by.
    pub identifier: Identifier,
    /// Ascending or descending.
    pub direction: OrderDirection,
}
47
48impl std::fmt::Display for OrderByItem {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 let field = match &self.identifier {
51 Identifier::Property(prop) => prop.clone(),
52 Identifier::CollectionProperty(coll, prop) => format!("{}.{}", coll, prop),
53 };
54 write!(
55 f,
56 "{} {}",
57 field,
58 match self.direction {
59 OrderDirection::Asc => "ASC",
60 OrderDirection::Desc => "DESC",
61 }
62 )
63 }
64}
65
/// Sort direction of an `OrderByItem`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum OrderDirection {
    /// Rendered as `ASC`.
    Asc,
    /// Rendered as `DESC`.
    Desc,
}
71
72impl std::fmt::Display for Selection {
73 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74 write!(f, "{}", self.predicate)?;
75 if let Some(order_by) = &self.order_by {
76 write!(f, " ORDER BY ")?;
77 for (i, item) in order_by.iter().enumerate() {
78 if i > 0 {
79 write!(f, ", ")?;
80 }
81 write!(f, "{}", item)?;
82 }
83 }
84 if let Some(limit) = self.limit {
85 write!(f, " LIMIT {}", limit)?;
86 }
87 Ok(())
88 }
89}
90
91impl From<Predicate> for Selection {
93 fn from(predicate: Predicate) -> Self { Selection { predicate, order_by: None, limit: None } }
94}
95
/// A boolean condition tree.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Predicate {
    /// `left operator right`.
    Comparison { left: Box<Expr>, operator: ComparisonOperator, right: Box<Expr> },
    /// `expr IS NULL`.
    IsNull(Box<Expr>),
    /// Conjunction of two predicates.
    And(Box<Predicate>, Box<Predicate>),
    /// Disjunction of two predicates.
    Or(Box<Predicate>, Box<Predicate>),
    /// Negation of a predicate.
    Not(Box<Predicate>),
    /// Constant true.
    True,
    /// Constant false.
    False,
    /// A predicate-level placeholder. `Predicate::populate` rejects it with an error: it must be
    /// transformed into a concrete predicate before values are populated.
    Placeholder,
}
107
108impl std::fmt::Display for Predicate {
109 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110 match generate_selection_sql(self, None) {
111 Ok(sql) => write!(f, "{}", sql),
112 Err(e) => write!(f, "SQL Error: {}", e),
113 }
114 }
115}
116
117impl Selection {
118 pub fn assume_null(&self, columns: &[String]) -> Self {
119 Self { predicate: self.predicate.assume_null(columns), order_by: self.order_by.clone(), limit: self.limit }
120 }
121}
122
123impl Predicate {
124 pub fn walk<T, F>(&self, accumulator: T, visitor: &mut F) -> T
126 where F: FnMut(T, &Predicate) -> T {
127 let accumulator = visitor(accumulator, self);
128 match self {
129 Predicate::And(left, right) | Predicate::Or(left, right) => {
130 let accumulator = left.walk(accumulator, visitor);
131 right.walk(accumulator, visitor)
132 }
133 Predicate::Not(inner) => inner.walk(accumulator, visitor),
134 _ => accumulator,
135 }
136 }
137
138 pub fn assume_null(&self, columns: &[String]) -> Self {
140 match self {
141 Predicate::Comparison { left, operator, right } => {
142 let has_null_identifier = match (&**left, &**right) {
144 (Expr::Identifier(id), _) | (_, Expr::Identifier(id)) => match id {
145 Identifier::Property(name) => columns.contains(name),
146 Identifier::CollectionProperty(_, name) => columns.contains(name),
147 },
148 _ => false,
149 };
150
151 if has_null_identifier {
152 match operator {
153 ComparisonOperator::Equal => Predicate::False,
155 ComparisonOperator::NotEqual => Predicate::False,
157 ComparisonOperator::GreaterThan => Predicate::False,
159 ComparisonOperator::GreaterThanOrEqual => Predicate::False,
161 ComparisonOperator::LessThan => Predicate::False,
163 ComparisonOperator::LessThanOrEqual => Predicate::False,
165 ComparisonOperator::In => Predicate::False,
167 ComparisonOperator::Between => Predicate::False,
169 }
170 } else {
171 Predicate::Comparison { left: left.clone(), operator: operator.clone(), right: right.clone() }
173 }
174 }
175 Predicate::IsNull(expr) => {
176 match &**expr {
179 Expr::Identifier(id) => {
180 let is_null = match id {
181 Identifier::Property(name) => columns.contains(name),
182 Identifier::CollectionProperty(_, name) => columns.contains(name),
183 };
184 if is_null {
185 Predicate::True
186 } else {
187 Predicate::IsNull(expr.clone())
188 }
189 }
190 _ => Predicate::IsNull(expr.clone()),
191 }
192 }
193 Predicate::And(left, right) => {
194 let left = left.assume_null(columns);
195 let right = right.assume_null(columns);
196
197 match (&left, &right) {
199 (Predicate::False, _) | (_, Predicate::False) => Predicate::False,
201 (Predicate::True, Predicate::True) => Predicate::True,
203 (Predicate::True, p) | (p, Predicate::True) => p.clone(),
205 _ => Predicate::And(Box::new(left), Box::new(right)),
206 }
207 }
208 Predicate::Or(left, right) => {
209 let left = left.assume_null(columns);
210 let right = right.assume_null(columns);
211
212 match (&left, &right) {
214 (Predicate::True, _) | (_, Predicate::True) => Predicate::True,
216 (Predicate::False, Predicate::False) => Predicate::False,
218 (Predicate::False, p) | (p, Predicate::False) => p.clone(),
220 _ => Predicate::Or(Box::new(left), Box::new(right)),
222 }
223 }
224 Predicate::Not(pred) => {
225 let inner = pred.assume_null(columns);
226 match inner {
227 Predicate::True => Predicate::False,
228 Predicate::False => Predicate::True,
229 _ => Predicate::Not(Box::new(inner)),
230 }
231 }
232 Predicate::True => Predicate::True,
234 Predicate::False => Predicate::False,
235 Predicate::Placeholder => Predicate::Placeholder,
236 }
237 }
238
239 pub fn populate<I, V, E>(self, values: I) -> Result<Predicate, ParseError>
241 where
242 I: IntoIterator<Item = V>,
243 V: TryInto<Expr, Error = E>,
244 E: Into<ParseError>,
245 {
246 let mut values_iter = values.into_iter();
247 let result = self.populate_recursive(&mut values_iter)?;
248
249 if values_iter.next().is_some() {
251 return Err(ParseError::InvalidPredicate("Too many values provided for placeholders".to_string()));
252 }
253
254 Ok(result)
255 }
256
257 fn populate_recursive<I, V, E>(self, values: &mut I) -> Result<Predicate, ParseError>
258 where
259 I: Iterator<Item = V>,
260 V: TryInto<Expr, Error = E>,
261 E: Into<ParseError>,
262 {
263 match self {
264 Predicate::Comparison { left, operator, right } => Ok(Predicate::Comparison {
265 left: Box::new(left.populate_recursive(values)?),
266 operator,
267 right: Box::new(right.populate_recursive(values)?),
268 }),
269 Predicate::And(left, right) => {
270 Ok(Predicate::And(Box::new(left.populate_recursive(values)?), Box::new(right.populate_recursive(values)?)))
271 }
272 Predicate::Or(left, right) => {
273 Ok(Predicate::Or(Box::new(left.populate_recursive(values)?), Box::new(right.populate_recursive(values)?)))
274 }
275 Predicate::Not(pred) => Ok(Predicate::Not(Box::new(pred.populate_recursive(values)?))),
276 Predicate::IsNull(expr) => Ok(Predicate::IsNull(Box::new(expr.populate_recursive(values)?))),
277 Predicate::True => Ok(Predicate::True),
278 Predicate::False => Ok(Predicate::False),
279 Predicate::Placeholder => Err(ParseError::InvalidPredicate("Placeholder must be transformed before population".to_string())),
281 }
282 }
283}
284
285impl Expr {
286 fn populate_recursive<I, V, E>(self, values: &mut I) -> Result<Expr, ParseError>
287 where
288 I: Iterator<Item = V>,
289 V: TryInto<Expr, Error = E>,
290 E: Into<ParseError>,
291 {
292 match self {
293 Expr::Placeholder => match values.next() {
294 Some(value) => Ok(value.try_into().map_err(|e| e.into())?),
295 None => Err(ParseError::InvalidPredicate("Not enough values provided for placeholders".to_string())),
296 },
297 Expr::Literal(lit) => Ok(Expr::Literal(lit)),
298 Expr::Identifier(id) => Ok(Expr::Identifier(id)),
299 Expr::Predicate(pred) => Ok(Expr::Predicate(pred.populate_recursive(values)?)),
300 Expr::InfixExpr { left, operator, right } => Ok(Expr::InfixExpr {
301 left: Box::new(left.populate_recursive(values)?),
302 operator,
303 right: Box::new(right.populate_recursive(values)?),
304 }),
305 Expr::ExprList(exprs) => {
306 let mut populated_exprs = Vec::new();
307 for expr in exprs {
308 populated_exprs.push(expr.populate_recursive(values)?);
309 }
310 Ok(Expr::ExprList(populated_exprs))
311 }
312 }
313 }
314}
315
/// Operator of a `Predicate::Comparison`.
///
/// `Predicate::assume_null` folds a comparison with any of these operators to `False` when one
/// of its operands is a column assumed to be NULL.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ComparisonOperator {
    Equal,
    NotEqual,
    GreaterThan,
    GreaterThanOrEqual,
    LessThan,
    LessThanOrEqual,
    /// Membership test; the tests in this file pair it with an `Expr::ExprList` right-hand side.
    In,
    Between,
}
327
/// Arithmetic operator of an `Expr::InfixExpr`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum InfixOperator {
    Add,
    Subtract,
    Multiply,
    Divide,
}
335
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_selection;

    /// Parses `input`, simplifies its predicate assuming `null_columns` are NULL, and renders
    /// the result back to SQL.
    fn nullify_columns(input: &str, null_columns: &[&str]) -> Result<String, ParseError> {
        let selection = parse_selection(input)?;
        let result = selection.predicate.assume_null(&null_columns.iter().map(|s| s.to_string()).collect::<Vec<_>>());
        generate_selection_sql(&result, None).map_err(|_| ParseError::InvalidPredicate("SQL generation failed".to_string()))
    }

    /// Splits an `And` node into its operands, failing the test on any other shape.
    fn and_operands(predicate: Predicate) -> (Predicate, Predicate) {
        match predicate {
            Predicate::And(left, right) => (*left, *right),
            other => panic!("expected an AND predicate, got {:?}", other),
        }
    }

    /// Returns the right-hand operand of a comparison, failing the test on any other shape.
    fn comparison_rhs(predicate: Predicate) -> Expr {
        match predicate {
            Predicate::Comparison { right, .. } => *right,
            other => panic!("expected a comparison predicate, got {:?}", other),
        }
    }

    #[test]
    fn test_single_comparison_null_handling() {
        assert_eq!(nullify_columns("status = 'active'", &["status"]).unwrap(), "FALSE");
        assert_eq!(nullify_columns("age > 30", &["age"]).unwrap(), "FALSE");
        assert_eq!(nullify_columns("count >= 100", &["count"]).unwrap(), "FALSE");
        assert_eq!(nullify_columns("name < 'Z'", &["name"]).unwrap(), "FALSE");
        assert_eq!(nullify_columns("score <= 90", &["score"]).unwrap(), "FALSE");
        assert_eq!(nullify_columns("status IS NULL", &["status"]).unwrap(), "TRUE");
        assert_eq!(nullify_columns("role = 'admin'", &["other"]).unwrap(), r#""role" = 'admin'"#);
    }

    #[test]
    fn nested_predicate_null_handling() {
        let input = "alpha = 1 AND (beta = 2 OR charlie = 3)";
        assert_eq!(nullify_columns(input, &["charlie"]).unwrap(), r#""alpha" = 1 AND "beta" = 2"#);
        assert_eq!(nullify_columns(input, &["beta", "charlie"]).unwrap(), r#"FALSE"#);
        assert_eq!(nullify_columns(input, &["alpha"]).unwrap(), r#"FALSE"#);
        assert_eq!(nullify_columns(input, &["other"]).unwrap(), r#""alpha" = 1 AND ("beta" = 2 OR "charlie" = 3)"#);
    }

    #[test]
    fn test_populate_single_placeholder() {
        let selection = parse_selection("name = ?").unwrap();
        let populated = selection.predicate.populate(vec!["Alice"]).unwrap();

        let expected = Predicate::Comparison {
            left: Box::new(Expr::Identifier(Identifier::Property("name".to_string()))),
            operator: ComparisonOperator::Equal,
            right: Box::new(Expr::Literal(Literal::String("Alice".to_string()))),
        };

        assert_eq!(populated, expected);
    }

    #[test]
    fn test_populate_multiple_placeholders() {
        let selection = parse_selection("age > ? AND name = ?").unwrap();
        let values: Vec<Expr> = vec![25i64.into(), "Bob".into()];
        let populated = selection.predicate.populate(values).unwrap();

        let expected = Predicate::And(
            Box::new(Predicate::Comparison {
                left: Box::new(Expr::Identifier(Identifier::Property("age".to_string()))),
                operator: ComparisonOperator::GreaterThan,
                right: Box::new(Expr::Literal(Literal::I64(25))),
            }),
            Box::new(Predicate::Comparison {
                left: Box::new(Expr::Identifier(Identifier::Property("name".to_string()))),
                operator: ComparisonOperator::Equal,
                right: Box::new(Expr::Literal(Literal::String("Bob".to_string()))),
            }),
        );

        assert_eq!(populated, expected);
    }

    #[test]
    fn test_populate_in_clause() {
        let selection = parse_selection("status IN (?, ?, ?)").unwrap();
        let populated = selection.predicate.populate(vec!["active", "pending", "review"]).unwrap();

        let expected = Predicate::Comparison {
            left: Box::new(Expr::Identifier(Identifier::Property("status".to_string()))),
            operator: ComparisonOperator::In,
            right: Box::new(Expr::ExprList(vec![
                Expr::Literal(Literal::String("active".to_string())),
                Expr::Literal(Literal::String("pending".to_string())),
                Expr::Literal(Literal::String("review".to_string())),
            ])),
        };

        assert_eq!(populated, expected);
    }

    #[test]
    fn test_populate_mixed_types() {
        let selection = parse_selection("active = ? AND score > ? AND name = ?").unwrap();
        let values: Vec<Expr> = vec![true.into(), 95.5f64.into(), "Charlie".into()];
        let populated = selection.predicate.populate(values).unwrap();

        // Expected shape: (active = ? AND score > ?) AND name = ?
        // Any other shape fails loudly instead of silently skipping the assertions.
        let (conjunction, name_cmp) = and_operands(populated);
        let (active_cmp, score_cmp) = and_operands(conjunction);

        assert_eq!(comparison_rhs(active_cmp), Expr::Literal(Literal::Bool(true)));
        assert_eq!(comparison_rhs(score_cmp), Expr::Literal(Literal::F64(95.5)));
        assert_eq!(comparison_rhs(name_cmp), Expr::Literal(Literal::String("Charlie".to_string())));
    }

    #[test]
    fn test_populate_too_few_values() {
        let selection = parse_selection("name = ? AND age = ?").unwrap();
        let result = selection.predicate.populate(vec!["Alice"]);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Not enough values"));
    }

    #[test]
    fn test_populate_too_many_values() {
        let selection = parse_selection("name = ?").unwrap();
        let result = selection.predicate.populate(vec!["Alice", "Bob"]);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Too many values"));
    }

    #[test]
    fn test_populate_no_placeholders() {
        let selection = parse_selection("name = 'Alice'").unwrap();
        let populated = selection.clone().predicate.populate(Vec::<String>::new()).unwrap();

        // With no placeholders the predicate must come back unchanged.
        assert_eq!(populated, selection.predicate);
    }
}
473
474impl From<String> for Expr {
476 fn from(s: String) -> Expr { Expr::Literal(Literal::String(s)) }
477}
478
479impl From<&str> for Expr {
480 fn from(s: &str) -> Expr { Expr::Literal(Literal::String(s.to_string())) }
481}
482
483impl From<i64> for Expr {
484 fn from(i: i64) -> Expr { Expr::Literal(Literal::I64(i)) }
485}
486
487impl From<f64> for Expr {
488 fn from(f: f64) -> Expr { Expr::Literal(Literal::F64(f)) }
489}
490
491impl From<bool> for Expr {
492 fn from(b: bool) -> Expr { Expr::Literal(Literal::Bool(b)) }
493}
494
495impl From<Literal> for Expr {
496 fn from(lit: Literal) -> Expr { Expr::Literal(lit) }
497}
498
499impl<T> From<Vec<T>> for Expr
501where T: Into<Expr>
502{
503 fn from(vec: Vec<T>) -> Self { Expr::ExprList(vec.into_iter().map(|item| item.into()).collect()) }
504}
505
506impl<T, const N: usize> From<[T; N]> for Expr
507where T: Into<Expr>
508{
509 fn from(arr: [T; N]) -> Self { Expr::ExprList(arr.into_iter().map(|item| item.into()).collect()) }
510}
511
512impl<T> From<&[T]> for Expr
513where T: Into<Expr> + Clone
514{
515 fn from(slice: &[T]) -> Self { Expr::ExprList(slice.iter().map(|item| item.clone().into()).collect()) }
516}
517
518impl<T, const N: usize> From<&[T; N]> for Expr
519where T: Into<Expr> + Clone
520{
521 fn from(arr: &[T; N]) -> Self { Expr::ExprList(arr.iter().map(|item| item.clone().into()).collect()) }
522}