1use serde::{Deserialize, Serialize};
7use std::fmt;
8
9use crate::error::{Error, Result};
10
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
13#[serde(rename_all = "lowercase")]
14pub enum Operator {
15 And,
17 Or,
19 Not,
21}
22
23impl Operator {
24 pub fn as_str(&self) -> &'static str {
26 match self {
27 Operator::And => "and",
28 Operator::Or => "or",
29 Operator::Not => "not",
30 }
31 }
32}
33
34impl fmt::Display for Operator {
35 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36 write!(f, "{}", self.as_str())
37 }
38}
39
40#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
42#[serde(rename_all = "lowercase")]
43pub enum Comparator {
44 Eq,
46 Ne,
48 Gt,
50 Gte,
52 Lt,
54 Lte,
56 Contain,
58 Like,
60 In,
62 Nin,
64}
65
66impl Comparator {
67 pub fn as_str(&self) -> &'static str {
69 match self {
70 Comparator::Eq => "eq",
71 Comparator::Ne => "ne",
72 Comparator::Gt => "gt",
73 Comparator::Gte => "gte",
74 Comparator::Lt => "lt",
75 Comparator::Lte => "lte",
76 Comparator::Contain => "contain",
77 Comparator::Like => "like",
78 Comparator::In => "in",
79 Comparator::Nin => "nin",
80 }
81 }
82}
83
84impl fmt::Display for Comparator {
85 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
86 write!(f, "{}", self.as_str())
87 }
88}
89
90#[derive(Debug, Clone, Copy, PartialEq, Eq)]
92pub enum OperatorOrComparator {
93 Operator(Operator),
95 Comparator(Comparator),
97}
98
99impl From<Operator> for OperatorOrComparator {
100 fn from(op: Operator) -> Self {
101 OperatorOrComparator::Operator(op)
102 }
103}
104
105impl From<Comparator> for OperatorOrComparator {
106 fn from(comp: Comparator) -> Self {
107 OperatorOrComparator::Comparator(comp)
108 }
109}
110
111impl fmt::Display for OperatorOrComparator {
112 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
113 match self {
114 OperatorOrComparator::Operator(op) => write!(f, "{}", op),
115 OperatorOrComparator::Comparator(comp) => write!(f, "{}", comp),
116 }
117 }
118}
119
120pub trait Visitor {
125 type Output;
127
128 fn allowed_comparators(&self) -> Option<&[Comparator]> {
130 None
131 }
132
133 fn allowed_operators(&self) -> Option<&[Operator]> {
135 None
136 }
137
138 fn validate_func(&self, func: OperatorOrComparator) -> Result<()> {
140 match func {
141 OperatorOrComparator::Operator(op) => {
142 if let Some(allowed) = self.allowed_operators()
143 && !allowed.contains(&op)
144 {
145 return Err(Error::Other(format!(
146 "Received disallowed operator {}. Allowed operators are {:?}",
147 op, allowed
148 )));
149 }
150 }
151 OperatorOrComparator::Comparator(comp) => {
152 if let Some(allowed) = self.allowed_comparators()
153 && !allowed.contains(&comp)
154 {
155 return Err(Error::Other(format!(
156 "Received disallowed comparator {}. Allowed comparators are {:?}",
157 comp, allowed
158 )));
159 }
160 }
161 }
162 Ok(())
163 }
164
165 fn visit_operation(&self, operation: &Operation) -> Result<Self::Output>;
167
168 fn visit_comparison(&self, comparison: &Comparison) -> Result<Self::Output>;
170
171 fn visit_structured_query(&self, structured_query: &StructuredQuery) -> Result<Self::Output>;
173}
174
175pub trait Expr: fmt::Debug {
180 fn expr_name(&self) -> &'static str;
182
183 fn accept<V: Visitor>(&self, visitor: &V) -> Result<V::Output>;
185}
186
187pub trait FilterDirective: Expr {}
192
193#[derive(Debug, Clone, Serialize, Deserialize)]
195pub struct Comparison {
196 pub comparator: Comparator,
198 pub attribute: String,
200 pub value: serde_json::Value,
202}
203
204impl Comparison {
205 pub fn new(
207 comparator: Comparator,
208 attribute: impl Into<String>,
209 value: impl Into<serde_json::Value>,
210 ) -> Self {
211 Comparison {
212 comparator,
213 attribute: attribute.into(),
214 value: value.into(),
215 }
216 }
217}
218
219impl Expr for Comparison {
220 fn expr_name(&self) -> &'static str {
221 "comparison"
222 }
223
224 fn accept<V: Visitor>(&self, visitor: &V) -> Result<V::Output> {
225 visitor.visit_comparison(self)
226 }
227}
228
229impl FilterDirective for Comparison {}
230
231#[derive(Debug, Clone, Serialize, Deserialize)]
233pub struct Operation {
234 pub operator: Operator,
236 pub arguments: Vec<FilterDirectiveEnum>,
238}
239
240impl Operation {
241 pub fn new(operator: Operator, arguments: Vec<FilterDirectiveEnum>) -> Self {
243 Operation {
244 operator,
245 arguments,
246 }
247 }
248
249 pub fn and(arguments: Vec<FilterDirectiveEnum>) -> Self {
251 Self::new(Operator::And, arguments)
252 }
253
254 pub fn or(arguments: Vec<FilterDirectiveEnum>) -> Self {
256 Self::new(Operator::Or, arguments)
257 }
258
259 pub fn not(argument: FilterDirectiveEnum) -> Self {
261 Self::new(Operator::Not, vec![argument])
262 }
263}
264
265impl Expr for Operation {
266 fn expr_name(&self) -> &'static str {
267 "operation"
268 }
269
270 fn accept<V: Visitor>(&self, visitor: &V) -> Result<V::Output> {
271 visitor.visit_operation(self)
272 }
273}
274
275impl FilterDirective for Operation {}
276
277#[derive(Debug, Clone, Serialize, Deserialize)]
279#[serde(tag = "type", rename_all = "snake_case")]
280pub enum FilterDirectiveEnum {
281 Comparison(Comparison),
283 Operation(Operation),
285}
286
287impl FilterDirectiveEnum {
288 pub fn accept<V: Visitor>(&self, visitor: &V) -> Result<V::Output> {
290 match self {
291 FilterDirectiveEnum::Comparison(c) => visitor.visit_comparison(c),
292 FilterDirectiveEnum::Operation(o) => visitor.visit_operation(o),
293 }
294 }
295}
296
297impl From<Comparison> for FilterDirectiveEnum {
298 fn from(comparison: Comparison) -> Self {
299 FilterDirectiveEnum::Comparison(comparison)
300 }
301}
302
303impl From<Operation> for FilterDirectiveEnum {
304 fn from(operation: Operation) -> Self {
305 FilterDirectiveEnum::Operation(operation)
306 }
307}
308
309impl Expr for FilterDirectiveEnum {
310 fn expr_name(&self) -> &'static str {
311 match self {
312 FilterDirectiveEnum::Comparison(_) => "comparison",
313 FilterDirectiveEnum::Operation(_) => "operation",
314 }
315 }
316
317 fn accept<V: Visitor>(&self, visitor: &V) -> Result<V::Output> {
318 match self {
319 FilterDirectiveEnum::Comparison(c) => visitor.visit_comparison(c),
320 FilterDirectiveEnum::Operation(o) => visitor.visit_operation(o),
321 }
322 }
323}
324
325impl FilterDirective for FilterDirectiveEnum {}
326
327#[derive(Debug, Clone, Serialize, Deserialize)]
329pub struct StructuredQuery {
330 pub query: String,
332 pub filter: Option<FilterDirectiveEnum>,
334 pub limit: Option<usize>,
336}
337
338impl StructuredQuery {
339 pub fn new(
341 query: impl Into<String>,
342 filter: Option<FilterDirectiveEnum>,
343 limit: Option<usize>,
344 ) -> Self {
345 StructuredQuery {
346 query: query.into(),
347 filter,
348 limit,
349 }
350 }
351
352 pub fn query_only(query: impl Into<String>) -> Self {
354 Self::new(query, None, None)
355 }
356
357 pub fn with_filter(query: impl Into<String>, filter: impl Into<FilterDirectiveEnum>) -> Self {
359 Self::new(query, Some(filter.into()), None)
360 }
361}
362
363impl Expr for StructuredQuery {
364 fn expr_name(&self) -> &'static str {
365 "structured_query"
366 }
367
368 fn accept<V: Visitor>(&self, visitor: &V) -> Result<V::Output> {
369 visitor.visit_structured_query(self)
370 }
371}
372
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts an `UpperCamelCase` identifier to `snake_case`.
    ///
    /// Uses full Unicode lowercasing. Pairing `is_uppercase` with
    /// `to_ascii_lowercase` would prefix non-ASCII capitals such as `É` with `_`
    /// but leave them uppercase.
    fn to_snake_case(name: &str) -> String {
        let mut snake_case = String::with_capacity(name.len());
        for (i, ch) in name.chars().enumerate() {
            if ch.is_uppercase() {
                if i != 0 {
                    snake_case.push('_');
                }
                snake_case.extend(ch.to_lowercase());
            } else {
                snake_case.push(ch);
            }
        }
        snake_case
    }

    #[test]
    fn test_operator_display() {
        assert_eq!(Operator::And.to_string(), "and");
        assert_eq!(Operator::Or.to_string(), "or");
        assert_eq!(Operator::Not.to_string(), "not");
    }

    #[test]
    fn test_comparator_display() {
        assert_eq!(Comparator::Eq.to_string(), "eq");
        assert_eq!(Comparator::Ne.to_string(), "ne");
        assert_eq!(Comparator::Gt.to_string(), "gt");
        assert_eq!(Comparator::Gte.to_string(), "gte");
        assert_eq!(Comparator::Lt.to_string(), "lt");
        assert_eq!(Comparator::Lte.to_string(), "lte");
        assert_eq!(Comparator::Contain.to_string(), "contain");
        assert_eq!(Comparator::Like.to_string(), "like");
        assert_eq!(Comparator::In.to_string(), "in");
        assert_eq!(Comparator::Nin.to_string(), "nin");
    }

    #[test]
    fn test_to_snake_case() {
        assert_eq!(to_snake_case("Comparison"), "comparison");
        assert_eq!(to_snake_case("Operation"), "operation");
        assert_eq!(to_snake_case("StructuredQuery"), "structured_query");
        assert_eq!(to_snake_case("FilterDirective"), "filter_directive");
        // Non-ASCII capitals are lowercased too, not just prefixed.
        assert_eq!(to_snake_case("ÉtatCivil"), "état_civil");
        assert_eq!(to_snake_case(""), "");
    }

    #[test]
    fn test_comparison_creation() {
        let comparison = Comparison::new(Comparator::Eq, "field", "value");
        assert_eq!(comparison.comparator, Comparator::Eq);
        assert_eq!(comparison.attribute, "field");
        assert_eq!(comparison.value, serde_json::json!("value"));
    }

    #[test]
    fn test_operation_creation() {
        let comparison = Comparison::new(Comparator::Gt, "age", 18);
        let operation = Operation::and(vec![comparison.into()]);
        assert_eq!(operation.operator, Operator::And);
        assert_eq!(operation.arguments.len(), 1);

        let negated = Operation::not(Comparison::new(Comparator::Eq, "a", 1).into());
        assert_eq!(negated.operator, Operator::Not);
        assert_eq!(negated.arguments.len(), 1);
    }

    #[test]
    fn test_structured_query_creation() {
        let filter = Comparison::new(Comparator::Eq, "status", "active");
        let query = StructuredQuery::with_filter("search term", filter);
        assert_eq!(query.query, "search term");
        assert!(query.filter.is_some());
        assert!(query.limit.is_none());

        let bare = StructuredQuery::query_only("plain");
        assert_eq!(bare.query, "plain");
        assert!(bare.filter.is_none());
        assert!(bare.limit.is_none());
    }

    /// Visitor that only allows `and`/`or` and `eq`/`ne`, and renders nodes as
    /// short strings so dispatch can be asserted.
    struct TestVisitor {
        allowed_operators: Vec<Operator>,
        allowed_comparators: Vec<Comparator>,
    }

    impl TestVisitor {
        fn new() -> Self {
            TestVisitor {
                allowed_operators: vec![Operator::And, Operator::Or],
                allowed_comparators: vec![Comparator::Eq, Comparator::Ne],
            }
        }
    }

    impl Visitor for TestVisitor {
        type Output = String;

        fn allowed_operators(&self) -> Option<&[Operator]> {
            Some(&self.allowed_operators)
        }

        fn allowed_comparators(&self) -> Option<&[Comparator]> {
            Some(&self.allowed_comparators)
        }

        fn visit_operation(&self, operation: &Operation) -> Result<Self::Output> {
            self.validate_func(operation.operator.into())?;
            Ok(format!("operation:{}", operation.operator))
        }

        fn visit_comparison(&self, comparison: &Comparison) -> Result<Self::Output> {
            self.validate_func(comparison.comparator.into())?;
            Ok(format!(
                "comparison:{}:{}",
                comparison.attribute, comparison.comparator
            ))
        }

        fn visit_structured_query(
            &self,
            structured_query: &StructuredQuery,
        ) -> Result<Self::Output> {
            Ok(format!("query:{}", structured_query.query))
        }
    }

    #[test]
    fn test_visitor_validation() {
        let visitor = TestVisitor::new();

        // Allowed operators pass.
        assert!(visitor.validate_func(Operator::And.into()).is_ok());
        assert!(visitor.validate_func(Operator::Or.into()).is_ok());
        // `not` is not in the allow-list.
        assert!(visitor.validate_func(Operator::Not.into()).is_err());

        // Allowed comparators pass.
        assert!(visitor.validate_func(Comparator::Eq.into()).is_ok());
        assert!(visitor.validate_func(Comparator::Ne.into()).is_ok());
        // `gt` is not in the allow-list.
        assert!(visitor.validate_func(Comparator::Gt.into()).is_err());
    }

    #[test]
    fn test_visitor_accept() {
        let visitor = TestVisitor::new();

        let comparison = Comparison::new(Comparator::Eq, "field", "value");
        let result = comparison.accept(&visitor).unwrap();
        assert_eq!(result, "comparison:field:eq");

        let operation = Operation::and(vec![comparison.into()]);
        let result = operation.accept(&visitor).unwrap();
        assert_eq!(result, "operation:and");
    }

    #[test]
    fn test_visitor_rejects_disallowed_via_accept() {
        let visitor = TestVisitor::new();

        let comparison = Comparison::new(Comparator::Gt, "age", 18);
        assert!(comparison.accept(&visitor).is_err());

        let operation = Operation::not(Comparison::new(Comparator::Eq, "a", 1).into());
        assert!(operation.accept(&visitor).is_err());
    }

    #[test]
    fn test_filter_directive_enum_dispatch() {
        let visitor = TestVisitor::new();

        let leaf: FilterDirectiveEnum = Comparison::new(Comparator::Ne, "x", 1).into();
        assert_eq!(leaf.expr_name(), "comparison");
        assert_eq!(leaf.accept(&visitor).unwrap(), "comparison:x:ne");

        let node: FilterDirectiveEnum = Operation::or(vec![leaf]).into();
        assert_eq!(node.expr_name(), "operation");
        assert_eq!(node.accept(&visitor).unwrap(), "operation:or");
    }

    #[test]
    fn test_serialization() {
        let comparison = Comparison::new(Comparator::Eq, "field", "value");
        let json = serde_json::to_string(&comparison).unwrap();
        let deserialized: Comparison = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.comparator, comparison.comparator);
        assert_eq!(deserialized.attribute, comparison.attribute);
        assert_eq!(deserialized.value, comparison.value);

        let operation = Operation::and(vec![comparison.into()]);
        let json = serde_json::to_string(&operation).unwrap();
        let deserialized: Operation = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.operator, operation.operator);
        assert_eq!(deserialized.arguments.len(), 1);
    }

    #[test]
    fn test_filter_directive_enum_is_internally_tagged() {
        let negated: FilterDirectiveEnum =
            Operation::not(Comparison::new(Comparator::Gt, "age", 18).into()).into();
        let json = serde_json::to_value(&negated).unwrap();
        assert_eq!(json["type"], "operation");
        assert_eq!(json["operator"], "not");
        assert_eq!(json["arguments"][0]["type"], "comparison");
        assert_eq!(json["arguments"][0]["comparator"], "gt");
        assert_eq!(json["arguments"][0]["attribute"], "age");
        assert_eq!(json["arguments"][0]["value"], serde_json::json!(18));

        let back: FilterDirectiveEnum = serde_json::from_value(json).unwrap();
        let FilterDirectiveEnum::Operation(op) = back else {
            panic!("expected an operation after round-trip");
        };
        assert_eq!(op.operator, Operator::Not);
        assert_eq!(op.arguments.len(), 1);

        let raw = serde_json::json!({
            "type": "comparison",
            "comparator": "nin",
            "attribute": "tag",
            "value": ["a", "b"],
        });
        let parsed: FilterDirectiveEnum = serde_json::from_value(raw).unwrap();
        let FilterDirectiveEnum::Comparison(c) = parsed else {
            panic!("expected a comparison");
        };
        assert_eq!(c.comparator, Comparator::Nin);
        assert_eq!(c.value, serde_json::json!(["a", "b"]));
    }

    #[test]
    fn test_structured_query_round_trip_and_accept() {
        let visitor = TestVisitor::new();
        let query = StructuredQuery::new(
            "q",
            Some(Comparison::new(Comparator::Eq, "k", "v").into()),
            Some(5),
        );
        assert_eq!(query.expr_name(), "structured_query");
        assert_eq!(query.accept(&visitor).unwrap(), "query:q");

        let json = serde_json::to_string(&query).unwrap();
        let back: StructuredQuery = serde_json::from_str(&json).unwrap();
        assert_eq!(back.query, "q");
        assert_eq!(back.limit, Some(5));
        assert!(back.filter.is_some());
    }
}