1use std::cmp::Ordering;
2
3use crate::routing::{ComparisonOp, FieldPath, FieldValue, Predicate};
4
/// A borrowed view of one field value, as returned by a `FieldAccessor`.
///
/// Text is borrowed for `'a`; numeric and boolean payloads are stored inline,
/// which is why the whole enum is `Copy`. Equality is derived field-by-field,
/// so `Float(f64::NAN)` is not equal to itself.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FieldValueRef<'a> {
    /// A string value borrowed from the underlying record.
    Text(&'a str),
    /// A signed 64-bit integer value.
    Integer(i64),
    /// A 64-bit floating-point value.
    Float(f64),
    /// A boolean value.
    Boolean(bool),
    /// An explicit null value (distinct from an absent field, which is `None`).
    Null,
}
19
20pub trait FieldAccessor: std::fmt::Debug {
22 fn field(&self, path: &FieldPath) -> Option<FieldValueRef<'_>>;
24}
25
26#[must_use]
28pub fn evaluate(predicate: &Predicate, accessor: &dyn FieldAccessor) -> bool {
29 match predicate {
30 Predicate::Comparison { field, op, value } => accessor
31 .field(field)
32 .is_some_and(|field_value| compare_values(field_value, *op, value)),
33 Predicate::And(children) => children.iter().all(|child| evaluate(child, accessor)),
34 Predicate::Or(children) => children.iter().any(|child| evaluate(child, accessor)),
35 Predicate::Not(child) => !evaluate(child, accessor),
36 Predicate::Range {
37 field,
38 lower,
39 upper,
40 } => accessor.field(field).is_some_and(|field_value| {
41 compare_values(field_value, ComparisonOp::Gte, lower)
42 && compare_values(field_value, ComparisonOp::Lte, upper)
43 }),
44 Predicate::Exists { field } => accessor.field(field).is_some(),
45 }
46}
47
48pub(crate) fn compare_values(
49 field_value: FieldValueRef<'_>,
50 op: ComparisonOp,
51 literal: &FieldValue,
52) -> bool {
53 match (field_value, literal) {
54 (FieldValueRef::Text(left), FieldValue::Text(right)) => {
55 compare_ordering(left.cmp(right.as_str()), op)
56 }
57 (FieldValueRef::Integer(left), FieldValue::Integer(right)) => {
58 compare_ordering(left.cmp(right), op)
59 }
60 (FieldValueRef::Float(left), FieldValue::Float(right)) => left
61 .partial_cmp(right)
62 .is_some_and(|ordering| compare_ordering(ordering, op)),
63 (FieldValueRef::Boolean(left), FieldValue::Boolean(right)) => {
64 compare_equality(left == *right, op)
65 }
66 (FieldValueRef::Null, FieldValue::Null) => compare_equality(true, op),
67 _ => false,
68 }
69}
70
71const fn compare_ordering(ordering: Ordering, op: ComparisonOp) -> bool {
72 match op {
73 ComparisonOp::Eq => ordering.is_eq(),
74 ComparisonOp::Ne => !ordering.is_eq(),
75 ComparisonOp::Gt => ordering.is_gt(),
76 ComparisonOp::Lt => ordering.is_lt(),
77 ComparisonOp::Gte => ordering.is_ge(),
78 ComparisonOp::Lte => ordering.is_le(),
79 }
80}
81
82const fn compare_equality(equal: bool, op: ComparisonOp) -> bool {
83 match op {
84 ComparisonOp::Eq => equal,
85 ComparisonOp::Ne => !equal,
86 ComparisonOp::Gt | ComparisonOp::Lt | ComparisonOp::Gte | ComparisonOp::Lte => false,
87 }
88}
89
#[cfg(test)]
mod tests {
    use std::cell::Cell;

    use super::{FieldAccessor, FieldValueRef, evaluate};
    use crate::routing::{ComparisonOp, FieldPath, FieldValue, Predicate};

    /// Accessor exposing exactly one single-segment field with a fixed value.
    #[derive(Debug)]
    struct StaticAccessor<'a> {
        field: &'a str,
        value: FieldValueRef<'a>,
    }

    impl<'a> StaticAccessor<'a> {
        const fn new(field: &'a str, value: FieldValueRef<'a>) -> Self {
            Self { field, value }
        }
    }

    impl FieldAccessor for StaticAccessor<'_> {
        fn field(&self, path: &FieldPath) -> Option<FieldValueRef<'_>> {
            path.segments().eq([self.field]).then_some(self.value)
        }
    }

    /// Accessor that counts lookups so tests can observe short-circuiting.
    ///
    /// Knows the boolean fields `truth` (true), `falsehood` (false) and
    /// `third` (true); every other path is missing.
    #[derive(Debug)]
    struct CountingAccessor {
        count: Cell<usize>,
    }

    impl CountingAccessor {
        const fn new() -> Self {
            Self {
                count: Cell::new(0),
            }
        }

        fn count(&self) -> usize {
            self.count.get()
        }
    }

    impl FieldAccessor for CountingAccessor {
        fn field(&self, path: &FieldPath) -> Option<FieldValueRef<'_>> {
            self.count.set(self.count.get() + 1);

            if path.segments().eq(["truth"]) {
                Some(FieldValueRef::Boolean(true))
            } else if path.segments().eq(["falsehood"]) {
                Some(FieldValueRef::Boolean(false))
            } else if path.segments().eq(["third"]) {
                Some(FieldValueRef::Boolean(true))
            } else {
                None
            }
        }
    }

    /// Builds a `field <op> value` comparison predicate.
    fn comparison(field: &str, op: ComparisonOp, value: FieldValue) -> Predicate {
        Predicate::Comparison {
            field: FieldPath::new(field),
            op,
            value,
        }
    }

    fn integer_comparison(op: ComparisonOp, value: i64) -> Predicate {
        comparison("amount", op, FieldValue::Integer(value))
    }

    fn boolean_comparison(field: &str, value: bool) -> Predicate {
        comparison(field, ComparisonOp::Eq, FieldValue::Boolean(value))
    }

    /// Inclusive range `100 <= amount <= 200`.
    fn amount_range() -> Predicate {
        Predicate::Range {
            field: FieldPath::new("amount"),
            lower: FieldValue::Integer(100),
            upper: FieldValue::Integer(200),
        }
    }

    #[test]
    fn integer_greater_than_comparison_matches() {
        let predicate = integer_comparison(ComparisonOp::Gt, 1_000);
        let accessor = StaticAccessor::new("amount", FieldValueRef::Integer(1_500));

        assert!(evaluate(&predicate, &accessor));
    }

    #[test]
    fn integer_greater_than_comparison_rejects_lower_value() {
        let predicate = integer_comparison(ComparisonOp::Gt, 1_000);
        let accessor = StaticAccessor::new("amount", FieldValueRef::Integer(500));

        assert!(!evaluate(&predicate, &accessor));
    }

    #[test]
    fn text_comparison_is_lexicographic() {
        let accessor = StaticAccessor::new("region", FieldValueRef::Text("eu"));

        assert!(evaluate(
            &comparison("region", ComparisonOp::Lt, FieldValue::Text(String::from("us"))),
            &accessor
        ));
        assert!(!evaluate(
            &comparison("region", ComparisonOp::Gt, FieldValue::Text(String::from("us"))),
            &accessor
        ));
    }

    #[test]
    fn float_comparison_orders_finite_values() {
        let accessor = StaticAccessor::new("score", FieldValueRef::Float(1.5));

        assert!(evaluate(
            &comparison("score", ComparisonOp::Lt, FieldValue::Float(2.0)),
            &accessor
        ));
        assert!(!evaluate(
            &comparison("score", ComparisonOp::Gte, FieldValue::Float(2.0)),
            &accessor
        ));
    }

    #[test]
    fn boolean_fields_support_only_equality() {
        let accessor = StaticAccessor::new("flag", FieldValueRef::Boolean(true));

        assert!(evaluate(
            &comparison("flag", ComparisonOp::Eq, FieldValue::Boolean(true)),
            &accessor
        ));
        assert!(!evaluate(
            &comparison("flag", ComparisonOp::Gte, FieldValue::Boolean(true)),
            &accessor
        ));
    }

    #[test]
    fn null_field_equals_null_literal() {
        let accessor = StaticAccessor::new("note", FieldValueRef::Null);

        assert!(evaluate(
            &comparison("note", ComparisonOp::Eq, FieldValue::Null),
            &accessor
        ));
        assert!(!evaluate(
            &comparison("note", ComparisonOp::Ne, FieldValue::Null),
            &accessor
        ));
    }

    #[test]
    fn integer_field_is_not_coerced_to_float_literal() {
        let accessor = StaticAccessor::new("amount", FieldValueRef::Integer(5));

        assert!(!evaluate(
            &comparison("amount", ComparisonOp::Eq, FieldValue::Float(5.0)),
            &accessor
        ));
    }

    #[test]
    fn exists_returns_true_for_present_field() {
        let predicate = Predicate::Exists {
            field: FieldPath::new("region"),
        };
        let accessor = StaticAccessor::new("region", FieldValueRef::Text("eu"));

        assert!(evaluate(&predicate, &accessor));
    }

    #[test]
    fn exists_returns_true_for_null_value() {
        let predicate = Predicate::Exists {
            field: FieldPath::new("note"),
        };
        let accessor = StaticAccessor::new("note", FieldValueRef::Null);

        assert!(evaluate(&predicate, &accessor));
    }

    #[test]
    fn exists_returns_false_for_missing_field() {
        let predicate = Predicate::Exists {
            field: FieldPath::new("missing"),
        };
        let accessor = StaticAccessor::new("region", FieldValueRef::Text("eu"));

        assert!(!evaluate(&predicate, &accessor));
    }

    #[test]
    fn comparison_returns_false_for_missing_field() {
        let predicate = comparison(
            "missing",
            ComparisonOp::Eq,
            FieldValue::Text(String::from("x")),
        );
        let accessor = StaticAccessor::new("region", FieldValueRef::Text("eu"));

        assert!(!evaluate(&predicate, &accessor));
    }

    #[test]
    fn and_short_circuits_at_first_false() {
        let predicate = Predicate::And(vec![
            boolean_comparison("truth", true),
            boolean_comparison("falsehood", true),
            boolean_comparison("third", true),
        ]);
        let accessor = CountingAccessor::new();

        assert!(!evaluate(&predicate, &accessor));
        assert_eq!(accessor.count(), 2);
    }

    #[test]
    fn or_short_circuits_at_first_true() {
        let predicate = Predicate::Or(vec![
            boolean_comparison("falsehood", true),
            boolean_comparison("truth", true),
            boolean_comparison("third", true),
        ]);
        let accessor = CountingAccessor::new();

        assert!(evaluate(&predicate, &accessor));
        assert_eq!(accessor.count(), 2);
    }

    #[test]
    fn not_negates_child_predicate() {
        let true_predicate = Predicate::Not(Box::new(boolean_comparison("truth", true)));
        let false_predicate = Predicate::Not(Box::new(boolean_comparison("falsehood", true)));
        let accessor = CountingAccessor::new();

        assert!(!evaluate(&true_predicate, &accessor));
        assert!(evaluate(&false_predicate, &accessor));
    }

    #[test]
    fn not_of_missing_field_comparison_is_true() {
        let predicate = Predicate::Not(Box::new(boolean_comparison("missing", true)));
        let accessor = CountingAccessor::new();

        assert!(evaluate(&predicate, &accessor));
    }

    #[test]
    fn empty_boolean_combinators_have_vacuous_values() {
        assert!(evaluate(
            &Predicate::And(Vec::new()),
            &StaticAccessor::new("amount", FieldValueRef::Integer(1))
        ));
        assert!(!evaluate(
            &Predicate::Or(Vec::new()),
            &StaticAccessor::new("amount", FieldValueRef::Integer(1))
        ));
    }

    #[test]
    fn range_includes_middle_and_bounds() {
        for value in [150, 100, 200] {
            let accessor = StaticAccessor::new("amount", FieldValueRef::Integer(value));

            assert!(evaluate(&amount_range(), &accessor));
        }
    }

    #[test]
    fn range_rejects_value_below_lower_bound() {
        let accessor = StaticAccessor::new("amount", FieldValueRef::Integer(50));

        assert!(!evaluate(&amount_range(), &accessor));
    }

    #[test]
    fn range_rejects_value_above_upper_bound() {
        let accessor = StaticAccessor::new("amount", FieldValueRef::Integer(201));

        assert!(!evaluate(&amount_range(), &accessor));
    }

    #[test]
    fn range_rejects_missing_field() {
        let accessor = StaticAccessor::new("region", FieldValueRef::Text("eu"));

        assert!(!evaluate(&amount_range(), &accessor));
    }

    #[test]
    fn range_rejects_type_mismatch() {
        let accessor = StaticAccessor::new("amount", FieldValueRef::Text("150"));

        assert!(!evaluate(&amount_range(), &accessor));
    }
}