1use crate::error::ParseError;
2use crate::selection::sql::generate_selection_sql;
3use serde::{Deserialize, Serialize};
4use ulid::Ulid;
5
6mod json_as_bytes {
9 use serde::{Deserialize, Deserializer, Serialize, Serializer};
10
11 pub fn serialize<S>(value: &serde_json::Value, serializer: S) -> Result<S::Ok, S::Error>
12 where S: Serializer {
13 let bytes = serde_json::to_vec(value).map_err(serde::ser::Error::custom)?;
14 bytes.serialize(serializer)
15 }
16
17 pub fn deserialize<'de, D>(deserializer: D) -> Result<serde_json::Value, D::Error>
18 where D: Deserializer<'de> {
19 let bytes: Vec<u8> = Vec::deserialize(deserializer)?;
20 serde_json::from_slice(&bytes).map_err(serde::de::Error::custom)
21 }
22}
23
24#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
25pub enum Expr {
26 Literal(Literal),
27 Path(PathExpr),
28 Predicate(Predicate),
29 InfixExpr { left: Box<Expr>, operator: InfixOperator, right: Box<Expr> },
30 ExprList(Vec<Expr>), Placeholder,
32}
33
34#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
35pub enum Literal {
36 I16(i16),
37 I32(i32),
38 I64(i64),
39 F64(f64),
40 Bool(bool),
41 String(String),
42 EntityId(Ulid),
43 Object(Vec<u8>),
44 Binary(Vec<u8>),
45 #[serde(with = "json_as_bytes")]
49 Json(serde_json::Value),
50}
51
52#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
54pub struct PathExpr {
55 pub steps: Vec<String>,
56}
57
58impl PathExpr {
59 pub fn simple(name: impl Into<String>) -> Self { Self { steps: vec![name.into()] } }
61
62 pub fn is_simple(&self) -> bool { self.steps.len() == 1 }
64
65 pub fn first(&self) -> &str { &self.steps[0] }
67
68 pub fn property(&self) -> &str { self.steps.last().expect("PathExpr must have at least one step") }
70}
71
72impl std::fmt::Display for PathExpr {
73 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.steps.join(".")) }
74}
75
76#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
77pub struct Selection {
78 pub predicate: Predicate,
79 pub order_by: Option<Vec<OrderByItem>>,
80 pub limit: Option<u64>,
81}
82
83#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
84pub struct OrderByItem {
85 pub path: PathExpr,
86 pub direction: OrderDirection,
87}
88
89impl std::fmt::Display for OrderByItem {
90 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91 write!(
92 f,
93 "{} {}",
94 self.path,
95 match self.direction {
96 OrderDirection::Asc => "ASC",
97 OrderDirection::Desc => "DESC",
98 }
99 )
100 }
101}
102
103#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
104pub enum OrderDirection {
105 Asc,
106 Desc,
107}
108
109impl std::fmt::Display for Selection {
110 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
111 write!(f, "{}", self.predicate)?;
112 if let Some(order_by) = &self.order_by {
113 write!(f, " ORDER BY ")?;
114 for (i, item) in order_by.iter().enumerate() {
115 if i > 0 {
116 write!(f, ", ")?;
117 }
118 write!(f, "{}", item)?;
119 }
120 }
121 if let Some(limit) = self.limit {
122 write!(f, " LIMIT {}", limit)?;
123 }
124 Ok(())
125 }
126}
127
128impl From<Predicate> for Selection {
130 fn from(predicate: Predicate) -> Self { Selection { predicate, order_by: None, limit: None } }
131}
132
133#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
134pub enum Predicate {
135 Comparison { left: Box<Expr>, operator: ComparisonOperator, right: Box<Expr> },
136 IsNull(Box<Expr>),
137 And(Box<Predicate>, Box<Predicate>),
138 Or(Box<Predicate>, Box<Predicate>),
139 Not(Box<Predicate>),
140 True,
141 False,
142 Placeholder,
143}
144
145impl std::fmt::Display for Predicate {
146 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147 match generate_selection_sql(self, None) {
148 Ok(sql) => write!(f, "{}", sql),
149 Err(e) => write!(f, "SQL Error: {}", e),
150 }
151 }
152}
153
154impl Selection {
155 pub fn assume_null(&self, columns: &[String]) -> Self {
158 let order_by = self.order_by.as_ref().map(|items| {
159 items
160 .iter()
161 .filter(|item| {
162 let col_name = item.path.property();
164 !columns.contains(&col_name.to_string())
165 })
166 .cloned()
167 .collect::<Vec<_>>()
168 });
169 let order_by = order_by.and_then(|v| if v.is_empty() { None } else { Some(v) });
171
172 Self { predicate: self.predicate.assume_null(columns), order_by, limit: self.limit }
173 }
174
175 pub fn referenced_columns(&self) -> Vec<String> {
179 let mut columns = self.predicate.referenced_columns();
180 if let Some(order_by) = &self.order_by {
181 for item in order_by {
182 let col = item.path.first().to_string();
184 if !columns.contains(&col) {
185 columns.push(col);
186 }
187 }
188 }
189 columns
190 }
191}
192
193impl Predicate {
194 pub fn walk<T, F>(&self, accumulator: T, visitor: &mut F) -> T
196 where F: FnMut(T, &Predicate) -> T {
197 let accumulator = visitor(accumulator, self);
198 match self {
199 Predicate::And(left, right) | Predicate::Or(left, right) => {
200 let accumulator = left.walk(accumulator, visitor);
201 right.walk(accumulator, visitor)
202 }
203 Predicate::Not(inner) => inner.walk(accumulator, visitor),
204 _ => accumulator,
205 }
206 }
207
208 pub fn referenced_columns(&self) -> Vec<String> {
212 self.walk(Vec::new(), &mut |mut cols, pred| {
213 match pred {
214 Predicate::Comparison { left, right, .. } => {
215 for expr in [&**left, &**right] {
216 if let Expr::Path(path) = expr {
217 let col = path.first().to_string();
220 if !cols.contains(&col) {
221 cols.push(col);
222 }
223 }
224 }
225 }
226 Predicate::IsNull(expr) => {
227 if let Expr::Path(path) = &**expr {
228 let col = path.first().to_string();
229 if !cols.contains(&col) {
230 cols.push(col);
231 }
232 }
233 }
234 _ => {}
235 }
236 cols
237 })
238 }
239
240 pub fn assume_null(&self, columns: &[String]) -> Self {
242 match self {
243 Predicate::Comparison { left, operator, right } => {
244 let has_null_path = match (&**left, &**right) {
246 (Expr::Path(path), _) | (_, Expr::Path(path)) => columns.contains(&path.property().to_string()),
247 _ => false,
248 };
249
250 if has_null_path {
251 match operator {
252 ComparisonOperator::Equal => Predicate::False,
254 ComparisonOperator::NotEqual => Predicate::False,
256 ComparisonOperator::GreaterThan => Predicate::False,
258 ComparisonOperator::GreaterThanOrEqual => Predicate::False,
260 ComparisonOperator::LessThan => Predicate::False,
262 ComparisonOperator::LessThanOrEqual => Predicate::False,
264 ComparisonOperator::In => Predicate::False,
266 ComparisonOperator::Between => Predicate::False,
268 }
269 } else {
270 Predicate::Comparison { left: left.clone(), operator: operator.clone(), right: right.clone() }
272 }
273 }
274 Predicate::IsNull(expr) => {
275 match &**expr {
278 Expr::Path(path) => {
279 let is_null = columns.contains(&path.property().to_string());
280 if is_null {
281 Predicate::True
282 } else {
283 Predicate::IsNull(expr.clone())
284 }
285 }
286 _ => Predicate::IsNull(expr.clone()),
287 }
288 }
289 Predicate::And(left, right) => {
290 let left = left.assume_null(columns);
291 let right = right.assume_null(columns);
292
293 match (&left, &right) {
295 (Predicate::False, _) | (_, Predicate::False) => Predicate::False,
297 (Predicate::True, Predicate::True) => Predicate::True,
299 (Predicate::True, p) | (p, Predicate::True) => p.clone(),
301 _ => Predicate::And(Box::new(left), Box::new(right)),
302 }
303 }
304 Predicate::Or(left, right) => {
305 let left = left.assume_null(columns);
306 let right = right.assume_null(columns);
307
308 match (&left, &right) {
310 (Predicate::True, _) | (_, Predicate::True) => Predicate::True,
312 (Predicate::False, Predicate::False) => Predicate::False,
314 (Predicate::False, p) | (p, Predicate::False) => p.clone(),
316 _ => Predicate::Or(Box::new(left), Box::new(right)),
318 }
319 }
320 Predicate::Not(pred) => {
321 let inner = pred.assume_null(columns);
322 match inner {
323 Predicate::True => Predicate::False,
324 Predicate::False => Predicate::True,
325 _ => Predicate::Not(Box::new(inner)),
326 }
327 }
328 Predicate::True => Predicate::True,
330 Predicate::False => Predicate::False,
331 Predicate::Placeholder => Predicate::Placeholder,
332 }
333 }
334
335 pub fn populate<I, V, E>(self, values: I) -> Result<Predicate, ParseError>
337 where
338 I: IntoIterator<Item = V>,
339 V: TryInto<Expr, Error = E>,
340 E: Into<ParseError>,
341 {
342 let mut values_iter = values.into_iter();
343 let result = self.populate_recursive(&mut values_iter)?;
344
345 if values_iter.next().is_some() {
347 return Err(ParseError::InvalidPredicate("Too many values provided for placeholders".to_string()));
348 }
349
350 Ok(result)
351 }
352
353 fn populate_recursive<I, V, E>(self, values: &mut I) -> Result<Predicate, ParseError>
354 where
355 I: Iterator<Item = V>,
356 V: TryInto<Expr, Error = E>,
357 E: Into<ParseError>,
358 {
359 match self {
360 Predicate::Comparison { left, operator, right } => Ok(Predicate::Comparison {
361 left: Box::new(left.populate_recursive(values)?),
362 operator,
363 right: Box::new(right.populate_recursive(values)?),
364 }),
365 Predicate::And(left, right) => {
366 Ok(Predicate::And(Box::new(left.populate_recursive(values)?), Box::new(right.populate_recursive(values)?)))
367 }
368 Predicate::Or(left, right) => {
369 Ok(Predicate::Or(Box::new(left.populate_recursive(values)?), Box::new(right.populate_recursive(values)?)))
370 }
371 Predicate::Not(pred) => Ok(Predicate::Not(Box::new(pred.populate_recursive(values)?))),
372 Predicate::IsNull(expr) => Ok(Predicate::IsNull(Box::new(expr.populate_recursive(values)?))),
373 Predicate::True => Ok(Predicate::True),
374 Predicate::False => Ok(Predicate::False),
375 Predicate::Placeholder => Err(ParseError::InvalidPredicate("Placeholder must be transformed before population".to_string())),
377 }
378 }
379}
380
381impl Expr {
382 fn populate_recursive<I, V, E>(self, values: &mut I) -> Result<Expr, ParseError>
383 where
384 I: Iterator<Item = V>,
385 V: TryInto<Expr, Error = E>,
386 E: Into<ParseError>,
387 {
388 match self {
389 Expr::Placeholder => match values.next() {
390 Some(value) => Ok(value.try_into().map_err(|e| e.into())?),
391 None => Err(ParseError::InvalidPredicate("Not enough values provided for placeholders".to_string())),
392 },
393 Expr::Literal(lit) => Ok(Expr::Literal(lit)),
394 Expr::Path(path) => Ok(Expr::Path(path)),
395 Expr::Predicate(pred) => Ok(Expr::Predicate(pred.populate_recursive(values)?)),
396 Expr::InfixExpr { left, operator, right } => Ok(Expr::InfixExpr {
397 left: Box::new(left.populate_recursive(values)?),
398 operator,
399 right: Box::new(right.populate_recursive(values)?),
400 }),
401 Expr::ExprList(exprs) => {
402 let mut populated_exprs = Vec::new();
403 for expr in exprs {
404 populated_exprs.push(expr.populate_recursive(values)?);
405 }
406 Ok(Expr::ExprList(populated_exprs))
407 }
408 }
409 }
410}
411
412#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
413pub enum ComparisonOperator {
414 Equal, NotEqual, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, In, Between, }
423
424#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
425pub enum InfixOperator {
426 Add,
427 Subtract,
428 Multiply,
429 Divide,
430}
431
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_selection;

    /// Parses `input`, applies `Predicate::assume_null` for `null_columns`, and renders the result as SQL.
    fn nullify_columns(input: &str, null_columns: &[&str]) -> Result<String, ParseError> {
        let selection = parse_selection(input)?;
        let result = selection.predicate.assume_null(&null_columns.iter().map(|s| s.to_string()).collect::<Vec<_>>());
        generate_selection_sql(&result, None).map_err(|_| ParseError::InvalidPredicate("SQL generation failed".to_string()))
    }

    #[test]
    fn test_single_comparison_null_handling() {
        assert_eq!(nullify_columns("status = 'active'", &["status"]).unwrap(), "FALSE");
        assert_eq!(nullify_columns("age > 30", &["age"]).unwrap(), "FALSE");
        assert_eq!(nullify_columns("count >= 100", &["count"]).unwrap(), "FALSE");
        assert_eq!(nullify_columns("name < 'Z'", &["name"]).unwrap(), "FALSE");
        assert_eq!(nullify_columns("score <= 90", &["score"]).unwrap(), "FALSE");
        assert_eq!(nullify_columns("status IS NULL", &["status"]).unwrap(), "TRUE");
        assert_eq!(nullify_columns("role = 'admin'", &["other"]).unwrap(), r#""role" = 'admin'"#);
    }

    #[test]
    fn nested_predicate_null_handling() {
        let input = "alpha = 1 AND (beta = 2 OR charlie = 3)";
        assert_eq!(nullify_columns(input, &["charlie"]).unwrap(), r#""alpha" = 1 AND "beta" = 2"#);
        assert_eq!(nullify_columns(input, &["beta", "charlie"]).unwrap(), r#"FALSE"#);
        assert_eq!(nullify_columns(input, &["alpha"]).unwrap(), r#"FALSE"#);
        assert_eq!(nullify_columns(input, &["other"]).unwrap(), r#""alpha" = 1 AND ("beta" = 2 OR "charlie" = 3)"#);
    }

    #[test]
    fn test_populate_single_placeholder() {
        let selection = parse_selection("name = ?").unwrap();
        let populated = selection.predicate.populate(vec!["Alice"]).unwrap();

        let expected = Predicate::Comparison {
            left: Box::new(Expr::Path(PathExpr::simple("name"))),
            operator: ComparisonOperator::Equal,
            right: Box::new(Expr::Literal(Literal::String("Alice".to_string()))),
        };

        assert_eq!(populated, expected);
    }

    #[test]
    fn test_populate_multiple_placeholders() {
        let selection = parse_selection("age > ? AND name = ?").unwrap();
        let values: Vec<Expr> = vec![25i64.into(), "Bob".into()];
        let populated = selection.predicate.populate(values).unwrap();

        let expected = Predicate::And(
            Box::new(Predicate::Comparison {
                left: Box::new(Expr::Path(PathExpr::simple("age"))),
                operator: ComparisonOperator::GreaterThan,
                right: Box::new(Expr::Literal(Literal::I64(25))),
            }),
            Box::new(Predicate::Comparison {
                left: Box::new(Expr::Path(PathExpr::simple("name"))),
                operator: ComparisonOperator::Equal,
                right: Box::new(Expr::Literal(Literal::String("Bob".to_string()))),
            }),
        );

        assert_eq!(populated, expected);
    }

    #[test]
    fn test_populate_in_clause() {
        let selection = parse_selection("status IN (?, ?, ?)").unwrap();
        let populated = selection.predicate.populate(vec!["active", "pending", "review"]).unwrap();

        let expected = Predicate::Comparison {
            left: Box::new(Expr::Path(PathExpr::simple("status"))),
            operator: ComparisonOperator::In,
            right: Box::new(Expr::ExprList(vec![
                Expr::Literal(Literal::String("active".to_string())),
                Expr::Literal(Literal::String("pending".to_string())),
                Expr::Literal(Literal::String("review".to_string())),
            ])),
        };

        assert_eq!(populated, expected);
    }

    #[test]
    fn test_populate_mixed_types() {
        let selection = parse_selection("active = ? AND score > ? AND name = ?").unwrap();
        let values: Vec<Expr> = vec![true.into(), 95.5f64.into(), "Charlie".into()];
        let populated = selection.predicate.populate(values).unwrap();

        // Collect the right-hand side of every comparison. `walk` visits left subtrees before right
        // ones, so this is textual order regardless of how the parser associates AND. Unlike
        // nested `if let`s, it fails loudly when the tree has an unexpected shape.
        let right_hand_sides = populated.walk(Vec::new(), &mut |mut acc: Vec<Expr>, pred: &Predicate| {
            if let Predicate::Comparison { right, .. } = pred {
                acc.push((**right).clone());
            }
            acc
        });

        assert_eq!(
            right_hand_sides,
            vec![
                Expr::Literal(Literal::Bool(true)),
                Expr::Literal(Literal::F64(95.5)),
                Expr::Literal(Literal::String("Charlie".to_string())),
            ]
        );
    }

    #[test]
    fn test_populate_too_few_values() {
        let selection = parse_selection("name = ? AND age = ?").unwrap();
        let result = selection.predicate.populate(vec!["Alice"]);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Not enough values"));
    }

    #[test]
    fn test_populate_too_many_values() {
        let selection = parse_selection("name = ?").unwrap();
        let result = selection.predicate.populate(vec!["Alice", "Bob"]);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Too many values"));
    }

    #[test]
    fn test_populate_no_placeholders() {
        let selection = parse_selection("name = 'Alice'").unwrap();
        let populated = selection.clone().predicate.populate(Vec::<String>::new()).unwrap();

        assert_eq!(populated, selection.predicate);
    }
}
569
570impl From<String> for Expr {
572 fn from(s: String) -> Expr { Expr::Literal(Literal::String(s)) }
573}
574
575impl From<&str> for Expr {
576 fn from(s: &str) -> Expr { Expr::Literal(Literal::String(s.to_string())) }
577}
578
579impl From<i64> for Expr {
580 fn from(i: i64) -> Expr { Expr::Literal(Literal::I64(i)) }
581}
582
583impl From<f64> for Expr {
584 fn from(f: f64) -> Expr { Expr::Literal(Literal::F64(f)) }
585}
586
587impl From<bool> for Expr {
588 fn from(b: bool) -> Expr { Expr::Literal(Literal::Bool(b)) }
589}
590
591impl From<Literal> for Expr {
592 fn from(lit: Literal) -> Expr { Expr::Literal(lit) }
593}
594
595impl<T> From<Vec<T>> for Expr
597where T: Into<Expr>
598{
599 fn from(vec: Vec<T>) -> Self { Expr::ExprList(vec.into_iter().map(|item| item.into()).collect()) }
600}
601
602impl<T, const N: usize> From<[T; N]> for Expr
603where T: Into<Expr>
604{
605 fn from(arr: [T; N]) -> Self { Expr::ExprList(arr.into_iter().map(|item| item.into()).collect()) }
606}
607
608impl<T> From<&[T]> for Expr
609where T: Into<Expr> + Clone
610{
611 fn from(slice: &[T]) -> Self { Expr::ExprList(slice.iter().map(|item| item.clone().into()).collect()) }
612}
613
614impl<T, const N: usize> From<&[T; N]> for Expr
615where T: Into<Expr> + Clone
616{
617 fn from(arr: &[T; N]) -> Self { Expr::ExprList(arr.iter().map(|item| item.clone().into()).collect()) }
618}