diskann_label_filter/parser/
ast.rs1use std::fmt;
7
8use serde_json::Value;
9
10#[derive(Debug, Clone, PartialEq)]
12pub enum ASTExpr {
13 And(Vec<ASTExpr>),
15 Or(Vec<ASTExpr>),
17 Not(Box<ASTExpr>),
19 Compare { field: String, op: CompareOp },
21}
22
23#[derive(Debug, Clone, PartialEq)]
25pub enum CompareOp {
26 Eq(Value), Ne(Value), Lt(f64), Lte(f64), Gt(f64), Gte(f64), }
39
40impl fmt::Display for CompareOp {
41 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42 match self {
43 CompareOp::Eq(_) => write!(f, "=="),
44 CompareOp::Ne(_) => write!(f, "!="),
45 CompareOp::Lt(_) => write!(f, "<"),
46 CompareOp::Lte(_) => write!(f, "<="),
47 CompareOp::Gt(_) => write!(f, ">"),
48 CompareOp::Gte(_) => write!(f, ">="),
49 }
50 }
51}
52
53pub trait ASTVisitor {
55 type Output;
56
57 fn visit(&mut self, expr: &ASTExpr) -> Self::Output {
59 match expr {
60 ASTExpr::And(exprs) => self.visit_and(exprs),
61 ASTExpr::Or(exprs) => self.visit_or(exprs),
62 ASTExpr::Not(expr) => self.visit_not(expr),
63 ASTExpr::Compare { field, op } => self.visit_compare(field, op),
64 }
65 }
66
67 fn visit_and(&mut self, exprs: &[ASTExpr]) -> Self::Output;
69
70 fn visit_or(&mut self, exprs: &[ASTExpr]) -> Self::Output;
72
73 fn visit_not(&mut self, expr: &ASTExpr) -> Self::Output;
75
76 fn visit_compare(&mut self, field: &str, op: &CompareOp) -> Self::Output;
78}
79
80impl ASTExpr {
82 pub fn accept<V: ASTVisitor>(&self, visitor: &mut V) -> V::Output {
84 visitor.visit(self)
85 }
86}
87
/// An [`ASTVisitor`] that renders an expression tree as indented, human-readable text.
pub struct PrintVisitor {
    /// Current nesting depth; raised while the children of an `AND`/`OR` are rendered.
    indent_level: usize,
    /// Indent unit, repeated once per nesting level.
    indent_str: String,
}
93
94impl Default for PrintVisitor {
95 fn default() -> Self {
96 Self::new()
97 }
98}
99
100impl PrintVisitor {
101 pub fn new() -> Self {
103 Self {
104 indent_level: 0,
105 indent_str: " ".to_string(),
106 }
107 }
108
109 pub fn with_indent(indent_str: &str) -> Self {
111 Self {
112 indent_level: 0,
113 indent_str: indent_str.to_string(),
114 }
115 }
116
117 fn indent(&self) -> String {
118 self.indent_str.repeat(self.indent_level)
119 }
120
121 fn value_to_string(value: &Value) -> String {
122 match value {
123 Value::String(s) => format!("\"{}\"", s.replace('\"', "\\\"")),
124 Value::Array(arr) => {
125 let items: Vec<String> = arr.iter().map(Self::value_to_string).collect();
126 format!("[{}]", items.join(", "))
127 }
128 _ => value.to_string(),
129 }
130 }
131}
132
133impl ASTVisitor for PrintVisitor {
134 type Output = String;
135
136 fn visit_and(&mut self, exprs: &[ASTExpr]) -> Self::Output {
137 if exprs.is_empty() {
138 return "true".to_string();
139 }
140
141 if exprs.len() == 1 {
142 return self.visit(&exprs[0]);
143 }
144
145 let current_indent = self.indent();
146 self.indent_level += 1;
147
148 let inner: Vec<String> = exprs
149 .iter()
150 .map(|expr| format!("\n{}{}", self.indent(), self.visit(expr)))
151 .collect();
152
153 self.indent_level -= 1;
154
155 format!("AND({}\n{})", inner.join(","), current_indent)
156 }
157
158 fn visit_or(&mut self, exprs: &[ASTExpr]) -> Self::Output {
159 if exprs.is_empty() {
160 return "false".to_string();
161 }
162
163 if exprs.len() == 1 {
164 return self.visit(&exprs[0]);
165 }
166
167 let current_indent = self.indent();
168 self.indent_level += 1;
169
170 let inner: Vec<String> = exprs
171 .iter()
172 .map(|expr| format!("\n{}{}", self.indent(), self.visit(expr)))
173 .collect();
174
175 self.indent_level -= 1;
176
177 format!("OR({}\n{})", inner.join(","), current_indent)
178 }
179
180 fn visit_not(&mut self, expr: &ASTExpr) -> Self::Output {
181 format!("NOT({})", self.visit(expr))
182 }
183
184 fn visit_compare(&mut self, field: &str, op: &CompareOp) -> Self::Output {
185 let value_str = match op {
186 CompareOp::Eq(value) => Self::value_to_string(value),
187 CompareOp::Ne(value) => Self::value_to_string(value),
188 CompareOp::Lt(num) => num.to_string(),
189 CompareOp::Lte(num) => num.to_string(),
190 CompareOp::Gt(num) => num.to_string(),
191 CompareOp::Gte(num) => num.to_string(),
192 };
193
194 format!("{}{}{}", field, op, value_str)
195 }
196}
197
198impl fmt::Display for ASTExpr {
200 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201 let mut visitor = PrintVisitor::new();
202 write!(f, "{}", self.accept(&mut visitor))
203 }
204}
205
206impl ASTExpr {
208 pub fn to_string_with_indent(&self, indent: &str) -> String {
210 let mut visitor = PrintVisitor::with_indent(indent);
211 self.accept(&mut visitor)
212 }
213}
214
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Shorthand for building a `Compare` leaf.
    fn cmp(field: &str, op: CompareOp) -> ASTExpr {
        ASTExpr::Compare {
            field: field.to_string(),
            op,
        }
    }

    #[test]
    fn test_ast_visitor() {
        // Leaves and NOT render on a single line regardless of the indent unit.
        let expr = cmp("age", CompareOp::Gt(30.0));
        assert_eq!(expr.to_string(), "age>30");

        let not_expr = ASTExpr::Not(Box::new(cmp("age", CompareOp::Lt(18.0))));
        assert_eq!(not_expr.to_string(), "NOT(age<18)");

        let and_expr = ASTExpr::And(vec![
            cmp("age", CompareOp::Gt(30.0)),
            cmp("name", CompareOp::Eq(json!("John"))),
        ]);
        let or_expr = ASTExpr::Or(vec![
            cmp("age", CompareOp::Gt(30.0)),
            cmp("name", CompareOp::Eq(json!("John"))),
        ]);
        let nested_expr = ASTExpr::And(vec![
            ASTExpr::Or(vec![
                cmp("age", CompareOp::Gt(30.0)),
                cmp("age", CompareOp::Lt(20.0)),
            ]),
            ASTExpr::Not(Box::new(cmp("name", CompareOp::Eq(json!("Admin"))))),
        ]);

        // With an explicit two-space unit, every nesting level adds exactly one unit, so
        // grandchildren sit two units deep.
        assert_eq!(
            and_expr.to_string_with_indent("  "),
            "AND(\n  age>30,\n  name==\"John\"\n)"
        );
        assert_eq!(
            or_expr.to_string_with_indent("  "),
            "OR(\n  age>30,\n  name==\"John\"\n)"
        );
        assert_eq!(
            nested_expr.to_string_with_indent("  "),
            "AND(\n  OR(\n    age>30,\n    age<20\n  ),\n  NOT(name==\"Admin\")\n)"
        );

        // `Display` follows the same layout using the default visitor's indent unit.
        let unit = PrintVisitor::new().indent_str;
        assert_eq!(
            and_expr.to_string(),
            format!("AND(\n{0}age>30,\n{0}name==\"John\"\n)", unit)
        );
        assert_eq!(
            or_expr.to_string(),
            format!("OR(\n{0}age>30,\n{0}name==\"John\"\n)", unit)
        );
        assert_eq!(
            nested_expr.to_string(),
            format!(
                "AND(\n{0}OR(\n{0}{0}age>30,\n{0}{0}age<20\n{0}),\n{0}NOT(name==\"Admin\")\n)",
                unit
            )
        );
    }

    #[test]
    fn test_degenerate_connectives() {
        // Empty connectives render as their neutral elements.
        assert_eq!(ASTExpr::And(vec![]).to_string(), "true");
        assert_eq!(ASTExpr::Or(vec![]).to_string(), "false");

        // A single child is rendered without the connective wrapper.
        let leaf = cmp("age", CompareOp::Gte(21.0));
        assert_eq!(ASTExpr::And(vec![leaf.clone()]).to_string(), "age>=21");
        assert_eq!(ASTExpr::Or(vec![leaf]).to_string(), "age>=21");
    }

    #[test]
    fn test_operand_rendering() {
        assert_eq!(cmp("score", CompareOp::Lte(2.5)).to_string(), "score<=2.5");
        assert_eq!(
            cmp("tags", CompareOp::Ne(json!(["a", 1]))).to_string(),
            "tags!=[\"a\", 1]"
        );
    }
}