1use std::collections::BTreeMap;
15
16use crate::routing::evaluate::compare_values;
17use crate::routing::{ComparisonOp, FieldAccessor, FieldPath, FieldValue, Predicate};
18
19#[derive(Clone, Debug, PartialEq)]
24pub struct CompiledFunction {
25 plan: Plan,
26}
27
28impl CompiledFunction {
29 #[must_use]
34 pub fn evaluate(&self, accessor: &dyn FieldAccessor) -> bool {
35 eval_plan(&self.plan, accessor)
36 }
37}
38
39#[must_use]
43pub fn compile(predicate: &Predicate) -> CompiledFunction {
44 CompiledFunction {
45 plan: compile_plan(predicate),
46 }
47}
48
49#[derive(Clone, Debug, PartialEq)]
54enum Plan {
55 Comparison {
56 field: FieldPath,
57 op: ComparisonOp,
58 value: FieldValue,
59 },
60 Range {
61 field: FieldPath,
62 lower: FieldValue,
63 upper: FieldValue,
64 },
65 Exists {
66 field: FieldPath,
67 },
68 All(Vec<Self>),
69 Any(Vec<Self>),
70 Not(Box<Self>),
71}
72
73fn compile_plan(predicate: &Predicate) -> Plan {
74 match predicate {
75 Predicate::Comparison { field, op, value } => Plan::Comparison {
76 field: field.clone(),
77 op: *op,
78 value: value.clone(),
79 },
80 Predicate::Range {
81 field,
82 lower,
83 upper,
84 } => Plan::Range {
85 field: field.clone(),
86 lower: lower.clone(),
87 upper: upper.clone(),
88 },
89 Predicate::Exists { field } => Plan::Exists {
90 field: field.clone(),
91 },
92 Predicate::And(children) => Plan::All(optimize_clauses(children)),
93 Predicate::Or(children) => Plan::Any(optimize_clauses(children)),
94 Predicate::Not(child) => Plan::Not(Box::new(compile_plan(child))),
95 }
96}
97
98fn optimize_clauses(children: &[Predicate]) -> Vec<Plan> {
106 let mut compiled: Vec<Plan> = children.iter().map(compile_plan).collect();
107 let frequency = field_frequency(&compiled);
108 compiled.sort_by(|left, right| {
109 cost(left)
110 .cmp(&cost(right))
111 .then_with(|| frequency_rank(right, &frequency).cmp(&frequency_rank(left, &frequency)))
112 });
113 compiled
114}
115
116fn cost(plan: &Plan) -> u32 {
118 match plan {
119 Plan::Exists { .. } => 1,
120 Plan::Comparison { .. } => 2,
121 Plan::Range { .. } => 3,
122 Plan::Not(child) => cost(child),
123 Plan::All(children) | Plan::Any(children) => {
124 children.iter().map(cost).sum::<u32>().saturating_add(1)
125 }
126 }
127}
128
129fn primary_field(plan: &Plan) -> Option<&FieldPath> {
131 match plan {
132 Plan::Comparison { field, .. } | Plan::Range { field, .. } | Plan::Exists { field } => {
133 Some(field)
134 }
135 Plan::Not(child) => primary_field(child),
136 Plan::All(_) | Plan::Any(_) => None,
137 }
138}
139
140fn field_key(field: &FieldPath) -> String {
141 field.segments().collect::<Vec<_>>().join(".")
142}
143
144fn field_frequency(clauses: &[Plan]) -> BTreeMap<String, u32> {
145 let mut counts = BTreeMap::new();
146 for clause in clauses {
147 if let Some(field) = primary_field(clause) {
148 *counts.entry(field_key(field)).or_insert(0) += 1;
149 }
150 }
151 counts
152}
153
154fn frequency_rank(plan: &Plan, frequency: &BTreeMap<String, u32>) -> u32 {
155 primary_field(plan)
156 .and_then(|field| frequency.get(&field_key(field)).copied())
157 .unwrap_or(0)
158}
159
160fn eval_plan(plan: &Plan, accessor: &dyn FieldAccessor) -> bool {
161 match plan {
162 Plan::Comparison { field, op, value } => accessor
163 .field(field)
164 .is_some_and(|field_value| compare_values(field_value, *op, value)),
165 Plan::Range {
166 field,
167 lower,
168 upper,
169 } => accessor.field(field).is_some_and(|field_value| {
170 compare_values(field_value, ComparisonOp::Gte, lower)
171 && compare_values(field_value, ComparisonOp::Lte, upper)
172 }),
173 Plan::Exists { field } => accessor.field(field).is_some(),
174 Plan::All(children) => children.iter().all(|child| eval_plan(child, accessor)),
175 Plan::Any(children) => children.iter().any(|child| eval_plan(child, accessor)),
176 Plan::Not(child) => !eval_plan(child, accessor),
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::{Plan, compile};
    use crate::routing::evaluate::evaluate;
    use crate::routing::{
        ComparisonOp, FieldAccessor, FieldPath, FieldValue, FieldValueRef, Predicate,
    };

    /// In-memory accessor keyed by dotted field paths, used to drive the
    /// compiled and direct evaluators with the same record.
    #[derive(Debug)]
    struct MapAccessor {
        entries: Vec<(&'static str, FieldValueRef<'static>)>,
    }

    impl MapAccessor {
        const fn new(entries: Vec<(&'static str, FieldValueRef<'static>)>) -> Self {
            Self { entries }
        }
    }

    impl FieldAccessor for MapAccessor {
        fn field(&self, path: &FieldPath) -> Option<FieldValueRef<'_>> {
            let key = path.segments().collect::<Vec<_>>().join(".");
            self.entries
                .iter()
                .find(|(name, _)| *name == key)
                .map(|(_, value)| *value)
        }
    }

    fn comparison(field: &str, op: ComparisonOp, value: FieldValue) -> Predicate {
        Predicate::Comparison {
            field: FieldPath::new(field),
            op,
            value,
        }
    }

    fn exists(field: &str) -> Predicate {
        Predicate::Exists {
            field: FieldPath::new(field),
        }
    }

    /// Predicates covering every variant, empty `And`/`Or`, nested `Not`, and
    /// type-mismatched comparison values.
    fn predicate_corpus() -> Vec<Predicate> {
        let amount_gt = comparison("amount", ComparisonOp::Gt, FieldValue::Integer(1_000));
        let amount_le = comparison("amount", ComparisonOp::Lte, FieldValue::Integer(100));
        let region_eq = comparison(
            "region",
            ComparisonOp::Eq,
            FieldValue::Text(String::from("eu")),
        );
        let flag_eq = comparison("flag", ComparisonOp::Eq, FieldValue::Boolean(true));
        let amount_range = Predicate::Range {
            field: FieldPath::new("amount"),
            lower: FieldValue::Integer(100),
            upper: FieldValue::Integer(1_000),
        };

        vec![
            amount_gt.clone(),
            region_eq.clone(),
            amount_range.clone(),
            exists("region"),
            exists("missing"),
            Predicate::And(vec![exists("region"), amount_gt.clone()]),
            Predicate::Or(vec![amount_le, region_eq.clone()]),
            Predicate::Not(Box::new(amount_gt)),
            Predicate::Not(Box::new(Predicate::Not(Box::new(Predicate::Not(
                Box::new(flag_eq.clone()),
            ))))),
            Predicate::And(Vec::new()),
            Predicate::Or(Vec::new()),
            Predicate::And(vec![
                exists("amount"),
                Predicate::Or(vec![region_eq, flag_eq]),
                amount_range,
            ]),
            comparison(
                "amount",
                ComparisonOp::Eq,
                FieldValue::Text(String::from("x")),
            ),
            comparison("amount", ComparisonOp::Eq, FieldValue::Null),
        ]
    }

    /// Records covering missing fields, boundary-ish amounts, mismatched types,
    /// and multi-field combinations.
    fn accessor_corpus() -> Vec<MapAccessor> {
        vec![
            MapAccessor::new(vec![("amount", FieldValueRef::Integer(1_500))]),
            MapAccessor::new(vec![("amount", FieldValueRef::Integer(50))]),
            MapAccessor::new(vec![("amount", FieldValueRef::Integer(500))]),
            MapAccessor::new(vec![("region", FieldValueRef::Text("eu"))]),
            MapAccessor::new(vec![("region", FieldValueRef::Text("us"))]),
            MapAccessor::new(vec![("flag", FieldValueRef::Boolean(true))]),
            MapAccessor::new(vec![("amount", FieldValueRef::Text("1500"))]),
            MapAccessor::new(Vec::new()),
            MapAccessor::new(vec![
                ("amount", FieldValueRef::Integer(750)),
                ("region", FieldValueRef::Text("eu")),
                ("flag", FieldValueRef::Boolean(false)),
            ]),
            MapAccessor::new(vec![
                ("amount", FieldValueRef::Integer(2_000)),
                ("region", FieldValueRef::Text("us")),
                ("flag", FieldValueRef::Boolean(true)),
            ]),
        ]
    }

    #[test]
    fn compiled_matches_direct_evaluation_for_all_combinations() {
        let predicates = predicate_corpus();
        let accessors = accessor_corpus();
        let mut combinations = 0_usize;

        for predicate in &predicates {
            let compiled = compile(predicate);
            for accessor in &accessors {
                assert_eq!(
                    compiled.evaluate(accessor),
                    evaluate(predicate, accessor),
                    "compiled diverged for {predicate:?}"
                );
                combinations += 1;
            }
        }

        assert!(
            combinations >= 100,
            "expected >=100 combinations, ran {combinations}"
        );
    }

    #[test]
    fn compile_borrows_predicate_unchanged() {
        let predicate = Predicate::And(vec![
            comparison("amount", ComparisonOp::Gt, FieldValue::Integer(10)),
            exists("region"),
        ]);
        let snapshot = predicate.clone();

        let _ = compile(&predicate);

        assert_eq!(predicate, snapshot);
    }

    #[test]
    fn and_places_existence_check_before_comparison() {
        let predicate = Predicate::And(vec![
            comparison("amount", ComparisonOp::Gt, FieldValue::Integer(10)),
            exists("region"),
        ]);

        let compiled = compile(&predicate);

        assert!(matches!(
            &compiled.plan,
            Plan::All(clauses) if matches!(clauses.as_slice(), [Plan::Exists { .. }, ..])
        ));
    }

    /// `Or` children are reordered by cost just like `And` children.
    #[test]
    fn or_places_cheaper_clause_first() {
        let predicate = Predicate::Or(vec![
            Predicate::Range {
                field: FieldPath::new("amount"),
                lower: FieldValue::Integer(1),
                upper: FieldValue::Integer(2),
            },
            exists("region"),
        ]);

        let compiled = compile(&predicate);

        assert!(matches!(
            &compiled.plan,
            Plan::Any(clauses)
                if matches!(clauses.as_slice(), [Plan::Exists { .. }, Plan::Range { .. }])
        ));
    }

    #[test]
    fn and_extracts_more_frequent_field_first_among_equal_cost_clauses() {
        let predicate = Predicate::And(vec![
            comparison(
                "region",
                ComparisonOp::Eq,
                FieldValue::Text(String::from("eu")),
            ),
            comparison("amount", ComparisonOp::Gt, FieldValue::Integer(10)),
            comparison("amount", ComparisonOp::Lt, FieldValue::Integer(100)),
        ]);

        let compiled = compile(&predicate);

        assert!(matches!(
            &compiled.plan,
            Plan::All(clauses)
                if matches!(
                    clauses.first(),
                    Some(Plan::Comparison { field, .. }) if field.segments().eq(["amount"])
                )
        ));
    }

    /// Clauses tied on both cost and field frequency keep their source order.
    #[test]
    fn equal_cost_and_frequency_keeps_source_order() {
        let predicate = Predicate::And(vec![
            comparison("beta", ComparisonOp::Eq, FieldValue::Integer(1)),
            comparison("alpha", ComparisonOp::Eq, FieldValue::Integer(2)),
        ]);

        let compiled = compile(&predicate);

        assert!(matches!(
            &compiled.plan,
            Plan::All(clauses)
                if matches!(
                    clauses.as_slice(),
                    [Plan::Comparison { field, .. }, Plan::Comparison { .. }]
                        if field.segments().eq(["beta"])
                )
        ));
    }

    /// Reordering is applied recursively, including beneath a `Not`.
    #[test]
    fn clauses_under_not_are_reordered_too() {
        let predicate = Predicate::Not(Box::new(Predicate::And(vec![
            comparison("amount", ComparisonOp::Gt, FieldValue::Integer(10)),
            exists("region"),
        ])));

        let compiled = compile(&predicate);

        assert!(matches!(
            &compiled.plan,
            Plan::Not(inner)
                if matches!(
                    &**inner,
                    Plan::All(clauses) if matches!(clauses.as_slice(), [Plan::Exists { .. }, ..])
                )
        ));
    }

    #[test]
    fn reordering_preserves_result_for_existence_and_comparison() {
        let ordered = Predicate::And(vec![
            comparison("amount", ComparisonOp::Gt, FieldValue::Integer(10)),
            exists("region"),
        ]);
        let compiled = compile(&ordered);

        let present = MapAccessor::new(vec![
            ("amount", FieldValueRef::Integer(50)),
            ("region", FieldValueRef::Text("eu")),
        ]);
        let missing_region = MapAccessor::new(vec![("amount", FieldValueRef::Integer(50))]);

        assert!(compiled.evaluate(&present));
        assert!(!compiled.evaluate(&missing_region));
    }
}