agent_chain_core/
structured_query.rs

1//! Internal representation of a structured query language.
2//!
3//! This module provides types for building structured queries that can be
4//! translated to different query languages using the visitor pattern.
5
6use serde::{Deserialize, Serialize};
7use std::fmt;
8
9use crate::error::{Error, Result};
10
/// Enumerator of the logical operators.
///
/// Variants are serialized and deserialized under their lowercase names
/// (`"and"`, `"or"`, `"not"`) because of `#[serde(rename_all = "lowercase")]`.
/// The type is `Copy`, so it can be passed around by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Operator {
    /// Logical AND operator.
    And,
    /// Logical OR operator.
    Or,
    /// Logical NOT operator.
    Not,
}
22
23impl Operator {
24    /// Returns the string representation of the operator.
25    pub fn as_str(&self) -> &'static str {
26        match self {
27            Operator::And => "and",
28            Operator::Or => "or",
29            Operator::Not => "not",
30        }
31    }
32}
33
34impl fmt::Display for Operator {
35    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36        write!(f, "{}", self.as_str())
37    }
38}
39
/// Enumerator of the comparison operators.
///
/// Variants are serialized and deserialized under their lowercase names
/// (for example `"gte"` or `"nin"`) because of
/// `#[serde(rename_all = "lowercase")]`. How each comparator is evaluated is
/// decided by the `Visitor` that translates the query, not by this type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Comparator {
    /// Equal to.
    Eq,
    /// Not equal to.
    Ne,
    /// Greater than.
    Gt,
    /// Greater than or equal to.
    Gte,
    /// Less than.
    Lt,
    /// Less than or equal to.
    Lte,
    /// Contains.
    Contain,
    /// Like (pattern matching).
    Like,
    /// In a set of values.
    In,
    /// Not in a set of values.
    Nin,
}
65
66impl Comparator {
67    /// Returns the string representation of the comparator.
68    pub fn as_str(&self) -> &'static str {
69        match self {
70            Comparator::Eq => "eq",
71            Comparator::Ne => "ne",
72            Comparator::Gt => "gt",
73            Comparator::Gte => "gte",
74            Comparator::Lt => "lt",
75            Comparator::Lte => "lte",
76            Comparator::Contain => "contain",
77            Comparator::Like => "like",
78            Comparator::In => "in",
79            Comparator::Nin => "nin",
80        }
81    }
82}
83
84impl fmt::Display for Comparator {
85    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
86        write!(f, "{}", self.as_str())
87    }
88}
89
/// Either an Operator or a Comparator for validation purposes.
///
/// This is the argument type of `Visitor::validate_func`, which lets a single
/// validation entry point handle both kinds of function. `From` impls exist
/// for both wrapped types, so callers can write `operator.into()` or
/// `comparator.into()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OperatorOrComparator {
    /// An operator variant.
    Operator(Operator),
    /// A comparator variant.
    Comparator(Comparator),
}
98
99impl From<Operator> for OperatorOrComparator {
100    fn from(op: Operator) -> Self {
101        OperatorOrComparator::Operator(op)
102    }
103}
104
105impl From<Comparator> for OperatorOrComparator {
106    fn from(comp: Comparator) -> Self {
107        OperatorOrComparator::Comparator(comp)
108    }
109}
110
111impl fmt::Display for OperatorOrComparator {
112    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
113        match self {
114            OperatorOrComparator::Operator(op) => write!(f, "{}", op),
115            OperatorOrComparator::Comparator(comp) => write!(f, "{}", comp),
116        }
117    }
118}
119
120/// Defines interface for IR translation using a visitor pattern.
121///
122/// Implementations of this trait translate structured query expressions
123/// into target-specific query formats.
124pub trait Visitor {
125    /// The output type produced by visiting expressions.
126    type Output;
127
128    /// Allowed comparators for this visitor, if restricted.
129    fn allowed_comparators(&self) -> Option<&[Comparator]> {
130        None
131    }
132
133    /// Allowed operators for this visitor, if restricted.
134    fn allowed_operators(&self) -> Option<&[Operator]> {
135        None
136    }
137
138    /// Validates that a function (operator or comparator) is allowed.
139    fn validate_func(&self, func: OperatorOrComparator) -> Result<()> {
140        match func {
141            OperatorOrComparator::Operator(op) => {
142                if let Some(allowed) = self.allowed_operators()
143                    && !allowed.contains(&op)
144                {
145                    return Err(Error::Other(format!(
146                        "Received disallowed operator {}. Allowed operators are {:?}",
147                        op, allowed
148                    )));
149                }
150            }
151            OperatorOrComparator::Comparator(comp) => {
152                if let Some(allowed) = self.allowed_comparators()
153                    && !allowed.contains(&comp)
154                {
155                    return Err(Error::Other(format!(
156                        "Received disallowed comparator {}. Allowed comparators are {:?}",
157                        comp, allowed
158                    )));
159                }
160            }
161        }
162        Ok(())
163    }
164
165    /// Translate an Operation.
166    fn visit_operation(&self, operation: &Operation) -> Result<Self::Output>;
167
168    /// Translate a Comparison.
169    fn visit_comparison(&self, comparison: &Comparison) -> Result<Self::Output>;
170
171    /// Translate a StructuredQuery.
172    fn visit_structured_query(&self, structured_query: &StructuredQuery) -> Result<Self::Output>;
173}
174
/// Base trait for all expressions.
///
/// All expression types implement this trait and can accept visitors
/// for translation to target-specific formats.
///
/// `accept` is generic over the visitor type, so this trait cannot be used
/// as a trait object (`dyn Expr`).
pub trait Expr: fmt::Debug {
    /// Returns the name of this expression type in snake_case.
    fn expr_name(&self) -> &'static str;

    /// Accept a visitor and return the result of visiting this expression.
    ///
    /// Implementations in this file forward to the `visit_*` method of
    /// `visitor` that matches the expression type, and return its result
    /// (including any error) unchanged.
    fn accept<V: Visitor>(&self, visitor: &V) -> Result<V::Output>;
}
186
/// A filtering expression.
///
/// This is a marker trait for expressions that can be used as filters
/// in a structured query. It adds no methods of its own; in this file it is
/// implemented for `Comparison`, `Operation` and `FilterDirectiveEnum`.
pub trait FilterDirective: Expr {}
192
/// Comparison to a value.
///
/// Pairs an attribute name with a comparator and a JSON value. The struct
/// itself performs no validation; whether a comparator is permitted is
/// decided by the `Visitor` that translates it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Comparison {
    /// The comparator to use.
    pub comparator: Comparator,
    /// The attribute to compare.
    pub attribute: String,
    /// The value to compare to, stored as an arbitrary JSON value.
    pub value: serde_json::Value,
}
203
204impl Comparison {
205    /// Create a new Comparison.
206    pub fn new(
207        comparator: Comparator,
208        attribute: impl Into<String>,
209        value: impl Into<serde_json::Value>,
210    ) -> Self {
211        Comparison {
212            comparator,
213            attribute: attribute.into(),
214            value: value.into(),
215        }
216    }
217}
218
219impl Expr for Comparison {
220    fn expr_name(&self) -> &'static str {
221        "comparison"
222    }
223
224    fn accept<V: Visitor>(&self, visitor: &V) -> Result<V::Output> {
225        visitor.visit_comparison(self)
226    }
227}
228
229impl FilterDirective for Comparison {}
230
/// Logical operation over other directives.
///
/// Arguments are stored as `FilterDirectiveEnum`, so an operation can nest
/// comparisons and further operations. The struct does not check how many
/// arguments an operator receives.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Operation {
    /// The operator to use.
    pub operator: Operator,
    /// The arguments to the operator.
    pub arguments: Vec<FilterDirectiveEnum>,
}
239
240impl Operation {
241    /// Create a new Operation.
242    pub fn new(operator: Operator, arguments: Vec<FilterDirectiveEnum>) -> Self {
243        Operation {
244            operator,
245            arguments,
246        }
247    }
248
249    /// Create an AND operation.
250    pub fn and(arguments: Vec<FilterDirectiveEnum>) -> Self {
251        Self::new(Operator::And, arguments)
252    }
253
254    /// Create an OR operation.
255    pub fn or(arguments: Vec<FilterDirectiveEnum>) -> Self {
256        Self::new(Operator::Or, arguments)
257    }
258
259    /// Create a NOT operation.
260    pub fn not(argument: FilterDirectiveEnum) -> Self {
261        Self::new(Operator::Not, vec![argument])
262    }
263}
264
265impl Expr for Operation {
266    fn expr_name(&self) -> &'static str {
267        "operation"
268    }
269
270    fn accept<V: Visitor>(&self, visitor: &V) -> Result<V::Output> {
271        visitor.visit_operation(self)
272    }
273}
274
275impl FilterDirective for Operation {}
276
/// Enum wrapper for filter directives to allow recursive structures.
///
/// `Operation::arguments` holds values of this type, which is how an
/// operation can contain both comparisons and nested operations.
///
/// With `#[serde(tag = "type", rename_all = "snake_case")]` the variant is
/// recorded in a `"type"` field (`"comparison"` or `"operation"`) alongside
/// the wrapped struct's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum FilterDirectiveEnum {
    /// A comparison directive.
    Comparison(Comparison),
    /// An operation directive.
    Operation(Operation),
}
286
287impl FilterDirectiveEnum {
288    /// Accept a visitor based on the variant.
289    pub fn accept<V: Visitor>(&self, visitor: &V) -> Result<V::Output> {
290        match self {
291            FilterDirectiveEnum::Comparison(c) => visitor.visit_comparison(c),
292            FilterDirectiveEnum::Operation(o) => visitor.visit_operation(o),
293        }
294    }
295}
296
297impl From<Comparison> for FilterDirectiveEnum {
298    fn from(comparison: Comparison) -> Self {
299        FilterDirectiveEnum::Comparison(comparison)
300    }
301}
302
303impl From<Operation> for FilterDirectiveEnum {
304    fn from(operation: Operation) -> Self {
305        FilterDirectiveEnum::Operation(operation)
306    }
307}
308
309impl Expr for FilterDirectiveEnum {
310    fn expr_name(&self) -> &'static str {
311        match self {
312            FilterDirectiveEnum::Comparison(_) => "comparison",
313            FilterDirectiveEnum::Operation(_) => "operation",
314        }
315    }
316
317    fn accept<V: Visitor>(&self, visitor: &V) -> Result<V::Output> {
318        match self {
319            FilterDirectiveEnum::Comparison(c) => visitor.visit_comparison(c),
320            FilterDirectiveEnum::Operation(o) => visitor.visit_operation(o),
321        }
322    }
323}
324
325impl FilterDirective for FilterDirectiveEnum {}
326
/// Structured query.
///
/// Bundles a free-text query string with an optional filter expression and
/// an optional result limit. `None` in `filter` or `limit` means that part
/// was not specified.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StructuredQuery {
    /// Query string.
    pub query: String,
    /// Filtering expression.
    pub filter: Option<FilterDirectiveEnum>,
    /// Limit on the number of results.
    pub limit: Option<usize>,
}
337
338impl StructuredQuery {
339    /// Create a new StructuredQuery.
340    pub fn new(
341        query: impl Into<String>,
342        filter: Option<FilterDirectiveEnum>,
343        limit: Option<usize>,
344    ) -> Self {
345        StructuredQuery {
346            query: query.into(),
347            filter,
348            limit,
349        }
350    }
351
352    /// Create a StructuredQuery with only a query string.
353    pub fn query_only(query: impl Into<String>) -> Self {
354        Self::new(query, None, None)
355    }
356
357    /// Create a StructuredQuery with a query and filter.
358    pub fn with_filter(query: impl Into<String>, filter: impl Into<FilterDirectiveEnum>) -> Self {
359        Self::new(query, Some(filter.into()), None)
360    }
361}
362
363impl Expr for StructuredQuery {
364    fn expr_name(&self) -> &'static str {
365        "structured_query"
366    }
367
368    fn accept<V: Visitor>(&self, visitor: &V) -> Result<V::Output> {
369        visitor.visit_structured_query(self)
370    }
371}
372
#[cfg(test)]
mod tests {
    use super::*;

    /// Convert a name from PascalCase to snake_case.
    ///
    /// An underscore is inserted before every uppercase character except the
    /// first, and every character is ASCII-lowercased.
    fn to_snake_case(name: &str) -> String {
        let mut snake_case = String::with_capacity(name.len());
        for (index, ch) in name.chars().enumerate() {
            if ch.is_uppercase() && index != 0 {
                snake_case.push('_');
            }
            snake_case.push(ch.to_ascii_lowercase());
        }
        snake_case
    }

    #[test]
    fn test_operator_display() {
        assert_eq!(Operator::And.to_string(), "and");
        assert_eq!(Operator::Or.to_string(), "or");
        assert_eq!(Operator::Not.to_string(), "not");
    }

    #[test]
    fn test_comparator_display() {
        assert_eq!(Comparator::Eq.to_string(), "eq");
        assert_eq!(Comparator::Ne.to_string(), "ne");
        assert_eq!(Comparator::Gt.to_string(), "gt");
        assert_eq!(Comparator::Gte.to_string(), "gte");
        assert_eq!(Comparator::Lt.to_string(), "lt");
        assert_eq!(Comparator::Lte.to_string(), "lte");
        assert_eq!(Comparator::Contain.to_string(), "contain");
        assert_eq!(Comparator::Like.to_string(), "like");
        assert_eq!(Comparator::In.to_string(), "in");
        assert_eq!(Comparator::Nin.to_string(), "nin");
    }

    #[test]
    fn test_operator_or_comparator_display() {
        assert_eq!(OperatorOrComparator::from(Operator::Not).to_string(), "not");
        assert_eq!(
            OperatorOrComparator::from(Comparator::Lte).to_string(),
            "lte"
        );
    }

    #[test]
    fn test_to_snake_case() {
        assert_eq!(to_snake_case("Comparison"), "comparison");
        assert_eq!(to_snake_case("Operation"), "operation");
        assert_eq!(to_snake_case("StructuredQuery"), "structured_query");
        assert_eq!(to_snake_case("FilterDirective"), "filter_directive");
    }

    /// `expr_name` must agree with the snake_case form of the type name.
    #[test]
    fn test_expr_name_matches_type_name() {
        let comparison = Comparison::new(Comparator::Eq, "field", "value");
        assert_eq!(comparison.expr_name(), to_snake_case("Comparison"));

        let directive = FilterDirectiveEnum::from(comparison);
        assert_eq!(directive.expr_name(), to_snake_case("Comparison"));

        let operation = Operation::and(vec![directive]);
        assert_eq!(operation.expr_name(), to_snake_case("Operation"));
        assert_eq!(
            FilterDirectiveEnum::from(operation).expr_name(),
            to_snake_case("Operation")
        );

        let query = StructuredQuery::query_only("q");
        assert_eq!(query.expr_name(), to_snake_case("StructuredQuery"));
    }

    #[test]
    fn test_comparison_creation() {
        let comparison = Comparison::new(Comparator::Eq, "field", "value");
        assert_eq!(comparison.comparator, Comparator::Eq);
        assert_eq!(comparison.attribute, "field");
        assert_eq!(comparison.value, serde_json::json!("value"));
    }

    #[test]
    fn test_operation_creation() {
        let comparison = Comparison::new(Comparator::Gt, "age", 18);
        let operation = Operation::and(vec![comparison.into()]);
        assert_eq!(operation.operator, Operator::And);
        assert_eq!(operation.arguments.len(), 1);

        let either = Operation::or(Vec::new());
        assert_eq!(either.operator, Operator::Or);
        assert!(either.arguments.is_empty());

        let negated = Operation::not(operation.into());
        assert_eq!(negated.operator, Operator::Not);
        assert_eq!(negated.arguments.len(), 1);
    }

    #[test]
    fn test_structured_query_creation() {
        let filter = Comparison::new(Comparator::Eq, "status", "active");
        let query = StructuredQuery::with_filter("search term", filter);
        assert_eq!(query.query, "search term");
        assert!(query.filter.is_some());
        assert!(query.limit.is_none());

        let bare = StructuredQuery::query_only("search term");
        assert_eq!(bare.query, "search term");
        assert!(bare.filter.is_none());
        assert!(bare.limit.is_none());

        let limited = StructuredQuery::new("search term", None, Some(5));
        assert_eq!(limited.limit, Some(5));
    }

    /// Visitor that only allows `And`/`Or` and `Eq`/`Ne`.
    struct TestVisitor {
        allowed_operators: Vec<Operator>,
        allowed_comparators: Vec<Comparator>,
    }

    impl TestVisitor {
        fn new() -> Self {
            TestVisitor {
                allowed_operators: vec![Operator::And, Operator::Or],
                allowed_comparators: vec![Comparator::Eq, Comparator::Ne],
            }
        }
    }

    impl Visitor for TestVisitor {
        type Output = String;

        fn allowed_operators(&self) -> Option<&[Operator]> {
            Some(&self.allowed_operators)
        }

        fn allowed_comparators(&self) -> Option<&[Comparator]> {
            Some(&self.allowed_comparators)
        }

        fn visit_operation(&self, operation: &Operation) -> Result<Self::Output> {
            self.validate_func(operation.operator.into())?;
            Ok(format!("operation:{}", operation.operator))
        }

        fn visit_comparison(&self, comparison: &Comparison) -> Result<Self::Output> {
            self.validate_func(comparison.comparator.into())?;
            Ok(format!(
                "comparison:{}:{}",
                comparison.attribute, comparison.comparator
            ))
        }

        fn visit_structured_query(
            &self,
            structured_query: &StructuredQuery,
        ) -> Result<Self::Output> {
            Ok(format!("query:{}", structured_query.query))
        }
    }

    /// Visitor that keeps the trait's default (unrestricted) allow-lists.
    struct UnrestrictedVisitor;

    impl Visitor for UnrestrictedVisitor {
        type Output = ();

        fn visit_operation(&self, operation: &Operation) -> Result<Self::Output> {
            self.validate_func(operation.operator.into())
        }

        fn visit_comparison(&self, comparison: &Comparison) -> Result<Self::Output> {
            self.validate_func(comparison.comparator.into())
        }

        fn visit_structured_query(
            &self,
            _structured_query: &StructuredQuery,
        ) -> Result<Self::Output> {
            Ok(())
        }
    }

    #[test]
    fn test_visitor_validation() {
        let visitor = TestVisitor::new();

        // Valid operator
        assert!(visitor.validate_func(Operator::And.into()).is_ok());
        assert!(visitor.validate_func(Operator::Or.into()).is_ok());

        // Invalid operator
        assert!(visitor.validate_func(Operator::Not.into()).is_err());

        // Valid comparator
        assert!(visitor.validate_func(Comparator::Eq.into()).is_ok());
        assert!(visitor.validate_func(Comparator::Ne.into()).is_ok());

        // Invalid comparator
        assert!(visitor.validate_func(Comparator::Gt.into()).is_err());
    }

    /// With the default `None` allow-lists every function is accepted.
    #[test]
    fn test_visitor_validation_unrestricted() {
        let visitor = UnrestrictedVisitor;

        assert!(visitor.validate_func(Operator::Not.into()).is_ok());
        assert!(visitor.validate_func(Comparator::Like.into()).is_ok());

        let comparison = Comparison::new(Comparator::Nin, "field", "value");
        assert!(comparison.accept(&visitor).is_ok());
        assert!(Operation::not(comparison.into()).accept(&visitor).is_ok());
    }

    #[test]
    fn test_visitor_accept() {
        let visitor = TestVisitor::new();

        let comparison = Comparison::new(Comparator::Eq, "field", "value");
        let result = comparison.accept(&visitor).unwrap();
        assert_eq!(result, "comparison:field:eq");

        // The enum wrapper must dispatch exactly like the wrapped value.
        let directive = FilterDirectiveEnum::from(comparison);
        let result = directive.accept(&visitor).unwrap();
        assert_eq!(result, "comparison:field:eq");

        let operation = Operation::and(vec![directive]);
        let result = operation.accept(&visitor).unwrap();
        assert_eq!(result, "operation:and");

        let query = StructuredQuery::query_only("search term");
        let result = query.accept(&visitor).unwrap();
        assert_eq!(result, "query:search term");
    }

    /// Disallowed functions must surface as errors through `accept`.
    #[test]
    fn test_visitor_accept_rejects_disallowed() {
        let visitor = TestVisitor::new();

        let comparison = Comparison::new(Comparator::Gt, "age", 18);
        assert!(comparison.accept(&visitor).is_err());

        let allowed = Comparison::new(Comparator::Eq, "field", "value");
        let negation = Operation::not(allowed.into());
        assert!(negation.accept(&visitor).is_err());
    }

    #[test]
    fn test_serialization() {
        let comparison = Comparison::new(Comparator::Eq, "field", "value");
        let json = serde_json::to_string(&comparison).unwrap();
        let deserialized: Comparison = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.comparator, comparison.comparator);
        assert_eq!(deserialized.attribute, comparison.attribute);
        assert_eq!(deserialized.value, comparison.value);

        let operation = Operation::and(vec![comparison.into()]);
        let json = serde_json::to_string(&operation).unwrap();
        let deserialized: Operation = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.operator, operation.operator);
        assert_eq!(deserialized.arguments.len(), operation.arguments.len());

        // Enum names are lowercased and nested directives carry a "type" tag.
        let as_value = serde_json::to_value(&operation).unwrap();
        assert_eq!(as_value["operator"], "and");
        assert_eq!(as_value["arguments"][0]["type"], "comparison");
        assert_eq!(as_value["arguments"][0]["comparator"], "eq");
    }
}