Skip to main content

diskann_label_filter/parser/
ast.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::fmt;
7
8use serde_json::Value;
9
/// AST for query filters (<https://en.wikipedia.org/wiki/Abstract_syntax_tree>).
///
/// `And`, `Or` and `Not` are interior nodes that combine child expressions; `Compare` is the
/// only leaf. Trees are walked with an [`ASTVisitor`] via [`ASTExpr::accept`].
#[derive(Debug, Clone, PartialEq)]
pub enum ASTExpr {
    /// Logical AND: all sub-expressions must be true.
    ///
    /// NOTE(review): an empty list is presumably vacuously true (the printer in this file
    /// renders it as `true`) — confirm against the evaluator.
    And(Vec<ASTExpr>),
    /// Logical OR: at least one sub-expression must be true.
    ///
    /// NOTE(review): an empty list is presumably false (the printer in this file renders it
    /// as `false`) — confirm against the evaluator.
    Or(Vec<ASTExpr>),
    /// Logical NOT: negates the sub-expression.
    Not(Box<ASTExpr>),
    /// Comparison on a field (supports dot notation).
    ///
    /// `field` names the attribute being tested; `op` carries both the operator and its operand.
    Compare { field: String, op: CompareOp },
}
22
/// Supported comparison operators, each bundled with its operand.
///
/// Equality operators accept any JSON value. Ordering operators carry an `f64`, so a
/// non-numeric operand for them cannot be represented at all. The trailing `// $xx` comments
/// presumably name the keyword used in the source query syntax (parsing is not shown here).
#[derive(Debug, Clone, PartialEq)]
pub enum CompareOp {
    /// Equal comparison, can be used with any value type
    Eq(Value), // $eq
    /// Not equal comparison, can be used with any value type
    Ne(Value), // $ne
    /// Less than comparison, only valid for numeric values
    Lt(f64), // $lt
    /// Less than or equal comparison, only valid for numeric values
    Lte(f64), // $lte
    /// Greater than comparison, only valid for numeric values
    Gt(f64), // $gt
    /// Greater than or equal comparison, only valid for numeric values
    Gte(f64), // $gte
}
39
40impl fmt::Display for CompareOp {
41    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42        match self {
43            CompareOp::Eq(_) => write!(f, "=="),
44            CompareOp::Ne(_) => write!(f, "!="),
45            CompareOp::Lt(_) => write!(f, "<"),
46            CompareOp::Lte(_) => write!(f, "<="),
47            CompareOp::Gt(_) => write!(f, ">"),
48            CompareOp::Gte(_) => write!(f, ">="),
49        }
50    }
51}
52
53/// Trait for visiting AST expressions
54pub trait ASTVisitor {
55    type Output;
56
57    /// Visit an AST expression
58    fn visit(&mut self, expr: &ASTExpr) -> Self::Output {
59        match expr {
60            ASTExpr::And(exprs) => self.visit_and(exprs),
61            ASTExpr::Or(exprs) => self.visit_or(exprs),
62            ASTExpr::Not(expr) => self.visit_not(expr),
63            ASTExpr::Compare { field, op } => self.visit_compare(field, op),
64        }
65    }
66
67    /// Visit an AND expression
68    fn visit_and(&mut self, exprs: &[ASTExpr]) -> Self::Output;
69
70    /// Visit an OR expression
71    fn visit_or(&mut self, exprs: &[ASTExpr]) -> Self::Output;
72
73    /// Visit a NOT expression
74    fn visit_not(&mut self, expr: &ASTExpr) -> Self::Output;
75
76    /// Visit a comparison expression
77    fn visit_compare(&mut self, field: &str, op: &CompareOp) -> Self::Output;
78}
79
80/// Implementation of the visitor pattern for ASTExpr
81impl ASTExpr {
82    /// Accept a visitor and return its output
83    pub fn accept<V: ASTVisitor>(&self, visitor: &mut V) -> V::Output {
84        visitor.visit(self)
85    }
86}
87
/// A visitor that renders an [`ASTExpr`] as a human-readable, indented string.
///
/// `AND`/`OR` groups with two or more operands are spread over several lines. Each nesting level
/// is indented by one copy of `indent_str`.
///
/// Derives `Debug` (all public types should be debuggable) and `Clone` (the state is two plain
/// owned fields, so a configured visitor can be reused as a template).
#[derive(Debug, Clone)]
pub struct PrintVisitor {
    /// Current nesting depth of multi-line `AND`/`OR` groups while rendering.
    indent_level: usize,
    /// Text repeated `indent_level` times to build a line's indentation prefix.
    indent_str: String,
}
93
94impl Default for PrintVisitor {
95    fn default() -> Self {
96        Self::new()
97    }
98}
99
100impl PrintVisitor {
101    /// Create a new PrintVisitor with default settings
102    pub fn new() -> Self {
103        Self {
104            indent_level: 0,
105            indent_str: " ".to_string(),
106        }
107    }
108
109    /// Create a new PrintVisitor with custom indentation
110    pub fn with_indent(indent_str: &str) -> Self {
111        Self {
112            indent_level: 0,
113            indent_str: indent_str.to_string(),
114        }
115    }
116
117    fn indent(&self) -> String {
118        self.indent_str.repeat(self.indent_level)
119    }
120
121    fn value_to_string(value: &Value) -> String {
122        match value {
123            Value::String(s) => format!("\"{}\"", s.replace('\"', "\\\"")),
124            Value::Array(arr) => {
125                let items: Vec<String> = arr.iter().map(Self::value_to_string).collect();
126                format!("[{}]", items.join(", "))
127            }
128            _ => value.to_string(),
129        }
130    }
131}
132
133impl ASTVisitor for PrintVisitor {
134    type Output = String;
135
136    fn visit_and(&mut self, exprs: &[ASTExpr]) -> Self::Output {
137        if exprs.is_empty() {
138            return "true".to_string();
139        }
140
141        if exprs.len() == 1 {
142            return self.visit(&exprs[0]);
143        }
144
145        let current_indent = self.indent();
146        self.indent_level += 1;
147
148        let inner: Vec<String> = exprs
149            .iter()
150            .map(|expr| format!("\n{}{}", self.indent(), self.visit(expr)))
151            .collect();
152
153        self.indent_level -= 1;
154
155        format!("AND({}\n{})", inner.join(","), current_indent)
156    }
157
158    fn visit_or(&mut self, exprs: &[ASTExpr]) -> Self::Output {
159        if exprs.is_empty() {
160            return "false".to_string();
161        }
162
163        if exprs.len() == 1 {
164            return self.visit(&exprs[0]);
165        }
166
167        let current_indent = self.indent();
168        self.indent_level += 1;
169
170        let inner: Vec<String> = exprs
171            .iter()
172            .map(|expr| format!("\n{}{}", self.indent(), self.visit(expr)))
173            .collect();
174
175        self.indent_level -= 1;
176
177        format!("OR({}\n{})", inner.join(","), current_indent)
178    }
179
180    fn visit_not(&mut self, expr: &ASTExpr) -> Self::Output {
181        format!("NOT({})", self.visit(expr))
182    }
183
184    fn visit_compare(&mut self, field: &str, op: &CompareOp) -> Self::Output {
185        let value_str = match op {
186            CompareOp::Eq(value) => Self::value_to_string(value),
187            CompareOp::Ne(value) => Self::value_to_string(value),
188            CompareOp::Lt(num) => num.to_string(),
189            CompareOp::Lte(num) => num.to_string(),
190            CompareOp::Gt(num) => num.to_string(),
191            CompareOp::Gte(num) => num.to_string(),
192        };
193
194        format!("{}{}{}", field, op, value_str)
195    }
196}
197
198/// Display implementation for ASTExpr to easily print as string
199impl fmt::Display for ASTExpr {
200    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201        let mut visitor = PrintVisitor::new();
202        write!(f, "{}", self.accept(&mut visitor))
203    }
204}
205
206/// Extension methods for ASTExpr for custom string representation
207impl ASTExpr {
208    /// Convert the AST expression to a human-readable string with custom indentation
209    pub fn to_string_with_indent(&self, indent: &str) -> String {
210        let mut visitor = PrintVisitor::with_indent(indent);
211        self.accept(&mut visitor)
212    }
213}
214
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    #[test]
    fn test_ast_visitor() {
        // Test simple comparison
        let expr = ASTExpr::Compare {
            field: "age".to_string(),
            op: CompareOp::Gt(30.0),
        };

        assert_eq!(expr.to_string(), "age>30");

        // Test AND expression
        let and_expr = ASTExpr::And(vec![
            ASTExpr::Compare {
                field: "age".to_string(),
                op: CompareOp::Gt(30.0),
            },
            ASTExpr::Compare {
                field: "name".to_string(),
                op: CompareOp::Eq(json!("John")),
            },
        ]);

        let expected_and = "AND(\n age>30,\n name==\"John\"\n)";
        assert_eq!(and_expr.to_string(), expected_and);

        // Test OR expression
        let or_expr = ASTExpr::Or(vec![
            ASTExpr::Compare {
                field: "age".to_string(),
                op: CompareOp::Gt(30.0),
            },
            ASTExpr::Compare {
                field: "name".to_string(),
                op: CompareOp::Eq(json!("John")),
            },
        ]);

        let expected_or = "OR(\n age>30,\n name==\"John\"\n)";
        assert_eq!(or_expr.to_string(), expected_or);

        // Test NOT expression
        let not_expr = ASTExpr::Not(Box::new(ASTExpr::Compare {
            field: "age".to_string(),
            op: CompareOp::Lt(18.0),
        }));

        assert_eq!(not_expr.to_string(), "NOT(age<18)");

        // Test nested expressions
        let nested_expr = ASTExpr::And(vec![
            ASTExpr::Or(vec![
                ASTExpr::Compare {
                    field: "age".to_string(),
                    op: CompareOp::Gt(30.0),
                },
                ASTExpr::Compare {
                    field: "age".to_string(),
                    op: CompareOp::Lt(20.0),
                },
            ]),
            ASTExpr::Not(Box::new(ASTExpr::Compare {
                field: "name".to_string(),
                op: CompareOp::Eq(json!("Admin")),
            })),
        ]);

        let expected_nested = "AND(\n OR(\n  age>30,\n  age<20\n ),\n NOT(name==\"Admin\")\n)";
        assert_eq!(nested_expr.to_string(), expected_nested);
    }

    /// Empty groups render as the identity of their operator.
    #[test]
    fn test_empty_groups_render_as_identity() {
        assert_eq!(ASTExpr::And(vec![]).to_string(), "true");
        assert_eq!(ASTExpr::Or(vec![]).to_string(), "false");
    }

    /// A group with one operand is printed as that operand, without the wrapper.
    #[test]
    fn test_single_element_group_is_collapsed() {
        let leaf = ASTExpr::Compare {
            field: "x".to_string(),
            op: CompareOp::Ne(json!(1)),
        };
        assert_eq!(ASTExpr::And(vec![leaf.clone()]).to_string(), "x!=1");
        assert_eq!(ASTExpr::Or(vec![leaf]).to_string(), "x!=1");
    }

    /// Covers the operators and operand kinds not exercised by `test_ast_visitor`.
    #[test]
    fn test_remaining_operators_and_values() {
        let render = |op: CompareOp| {
            ASTExpr::Compare {
                field: "v".to_string(),
                op,
            }
            .to_string()
        };
        assert_eq!(render(CompareOp::Lte(2.5)), "v<=2.5");
        assert_eq!(render(CompareOp::Gte(-1.0)), "v>=-1");
        assert_eq!(render(CompareOp::Eq(json!(true))), "v==true");
        assert_eq!(render(CompareOp::Eq(json!(null))), "v==null");
        assert_eq!(render(CompareOp::Ne(json!(["a", 1]))), "v!=[\"a\", 1]");
        assert_eq!(
            render(CompareOp::Eq(json!("say \"hi\""))),
            r#"v=="say \"hi\"""#
        );
    }

    /// Custom indentation, and a multi-line group nested under NOT.
    #[test]
    fn test_custom_indent_and_nested_not() {
        let group = ASTExpr::And(vec![
            ASTExpr::Compare {
                field: "a".to_string(),
                op: CompareOp::Gt(1.0),
            },
            ASTExpr::Compare {
                field: "b".to_string(),
                op: CompareOp::Lt(2.0),
            },
        ]);
        assert_eq!(group.to_string_with_indent("\t"), "AND(\n\ta>1,\n\tb<2\n)");

        let negated = ASTExpr::Not(Box::new(group));
        assert_eq!(negated.to_string(), "NOT(AND(\n a>1,\n b<2\n))");
    }
}