// liminal/routing/evaluate.rs

1use std::cmp::Ordering;
2
3use crate::routing::{ComparisonOp, FieldPath, FieldValue, Predicate};
4
/// A cheap, copyable snapshot of one message field as seen by the routing
/// predicate evaluator.
///
/// Text is borrowed from the message for `'a`; every other variant carries its
/// scalar inline. Only `PartialEq` is derived (not `Eq`) because the `Float`
/// variant holds an `f64`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FieldValueRef<'a> {
    /// UTF-8 text borrowed from the underlying message.
    Text(&'a str),
    /// A signed 64-bit integer.
    Integer(i64),
    /// A 64-bit floating-point number.
    Float(f64),
    /// A boolean flag.
    Boolean(bool),
    /// A field that is present but explicitly null.
    Null,
}
19
20/// Provides borrowed access to fields from a message being routed.
21pub trait FieldAccessor: std::fmt::Debug {
22    /// Returns a borrowed field value for `path`, or `None` when absent.
23    fn field(&self, path: &FieldPath) -> Option<FieldValueRef<'_>>;
24}
25
26/// Evaluates a predicate against borrowed message fields.
27#[must_use]
28pub fn evaluate(predicate: &Predicate, accessor: &dyn FieldAccessor) -> bool {
29    match predicate {
30        Predicate::Comparison { field, op, value } => accessor
31            .field(field)
32            .is_some_and(|field_value| compare_values(field_value, *op, value)),
33        Predicate::And(children) => children.iter().all(|child| evaluate(child, accessor)),
34        Predicate::Or(children) => children.iter().any(|child| evaluate(child, accessor)),
35        Predicate::Not(child) => !evaluate(child, accessor),
36        Predicate::Range {
37            field,
38            lower,
39            upper,
40        } => accessor.field(field).is_some_and(|field_value| {
41            compare_values(field_value, ComparisonOp::Gte, lower)
42                && compare_values(field_value, ComparisonOp::Lte, upper)
43        }),
44        Predicate::Exists { field } => accessor.field(field).is_some(),
45    }
46}
47
48pub(crate) fn compare_values(
49    field_value: FieldValueRef<'_>,
50    op: ComparisonOp,
51    literal: &FieldValue,
52) -> bool {
53    match (field_value, literal) {
54        (FieldValueRef::Text(left), FieldValue::Text(right)) => {
55            compare_ordering(left.cmp(right.as_str()), op)
56        }
57        (FieldValueRef::Integer(left), FieldValue::Integer(right)) => {
58            compare_ordering(left.cmp(right), op)
59        }
60        (FieldValueRef::Float(left), FieldValue::Float(right)) => left
61            .partial_cmp(right)
62            .is_some_and(|ordering| compare_ordering(ordering, op)),
63        (FieldValueRef::Boolean(left), FieldValue::Boolean(right)) => {
64            compare_equality(left == *right, op)
65        }
66        (FieldValueRef::Null, FieldValue::Null) => compare_equality(true, op),
67        _ => false,
68    }
69}
70
71const fn compare_ordering(ordering: Ordering, op: ComparisonOp) -> bool {
72    match op {
73        ComparisonOp::Eq => ordering.is_eq(),
74        ComparisonOp::Ne => !ordering.is_eq(),
75        ComparisonOp::Gt => ordering.is_gt(),
76        ComparisonOp::Lt => ordering.is_lt(),
77        ComparisonOp::Gte => ordering.is_ge(),
78        ComparisonOp::Lte => ordering.is_le(),
79    }
80}
81
82const fn compare_equality(equal: bool, op: ComparisonOp) -> bool {
83    match op {
84        ComparisonOp::Eq => equal,
85        ComparisonOp::Ne => !equal,
86        ComparisonOp::Gt | ComparisonOp::Lt | ComparisonOp::Gte | ComparisonOp::Lte => false,
87    }
88}
89
#[cfg(test)]
mod tests {
    use std::cell::Cell;

    use super::{FieldAccessor, FieldValueRef, evaluate};
    use crate::routing::{ComparisonOp, FieldPath, FieldValue, Predicate};

    /// Accessor that answers for one configured field name with a fixed value
    /// and reports every other path as absent.
    #[derive(Debug)]
    struct StaticAccessor<'a> {
        field: &'a str,
        value: FieldValueRef<'a>,
    }

    impl<'a> StaticAccessor<'a> {
        const fn new(field: &'a str, value: FieldValueRef<'a>) -> Self {
            Self { field, value }
        }
    }

    impl FieldAccessor for StaticAccessor<'_> {
        fn field(&self, path: &FieldPath) -> Option<FieldValueRef<'_>> {
            if path.segments().eq([self.field]) {
                Some(self.value)
            } else {
                None
            }
        }
    }

    /// Accessor that counts lookups so tests can observe short-circuiting.
    ///
    /// It knows three boolean fields: `truth` (true), `falsehood` (false) and
    /// `third` (true).
    #[derive(Debug)]
    struct CountingAccessor {
        count: Cell<usize>,
    }

    impl CountingAccessor {
        const fn new() -> Self {
            Self {
                count: Cell::new(0),
            }
        }

        /// Number of `field` lookups performed so far.
        fn count(&self) -> usize {
            self.count.get()
        }
    }

    impl FieldAccessor for CountingAccessor {
        fn field(&self, path: &FieldPath) -> Option<FieldValueRef<'_>> {
            const KNOWN_FLAGS: [(&str, bool); 3] =
                [("truth", true), ("falsehood", false), ("third", true)];

            // Every lookup is counted, whether or not the field is known.
            let lookups_so_far = self.count.get();
            self.count.set(lookups_so_far + 1);

            KNOWN_FLAGS
                .iter()
                .find(|(name, _)| path.segments().eq([*name]))
                .map(|&(_, flag)| FieldValueRef::Boolean(flag))
        }
    }

    /// Builds `amount <op> <value>` for an integer literal.
    fn integer_comparison(op: ComparisonOp, value: i64) -> Predicate {
        let field = FieldPath::new("amount");
        let value = FieldValue::Integer(value);

        Predicate::Comparison { field, op, value }
    }

    /// Builds `<field> == <value>` for a boolean literal.
    fn boolean_comparison(field: &str, value: bool) -> Predicate {
        let field = FieldPath::new(field);
        let value = FieldValue::Boolean(value);

        Predicate::Comparison {
            field,
            op: ComparisonOp::Eq,
            value,
        }
    }

    /// Builds the inclusive range `100 <= amount <= 200`.
    fn amount_range() -> Predicate {
        let lower = FieldValue::Integer(100);
        let upper = FieldValue::Integer(200);

        Predicate::Range {
            field: FieldPath::new("amount"),
            lower,
            upper,
        }
    }

    #[test]
    fn integer_greater_than_comparison_matches() {
        let message = StaticAccessor::new("amount", FieldValueRef::Integer(1_500));
        let above_thousand = integer_comparison(ComparisonOp::Gt, 1_000);

        assert!(evaluate(&above_thousand, &message));
    }

    #[test]
    fn integer_greater_than_comparison_rejects_lower_value() {
        let message = StaticAccessor::new("amount", FieldValueRef::Integer(500));
        let above_thousand = integer_comparison(ComparisonOp::Gt, 1_000);

        assert!(!evaluate(&above_thousand, &message));
    }

    #[test]
    fn exists_returns_true_for_present_field() {
        let message = StaticAccessor::new("region", FieldValueRef::Text("eu"));
        let has_region = Predicate::Exists {
            field: FieldPath::new("region"),
        };

        assert!(evaluate(&has_region, &message));
    }

    #[test]
    fn exists_returns_false_for_missing_field() {
        let message = StaticAccessor::new("region", FieldValueRef::Text("eu"));
        let has_missing = Predicate::Exists {
            field: FieldPath::new("missing"),
        };

        assert!(!evaluate(&has_missing, &message));
    }

    #[test]
    fn comparison_returns_false_for_missing_field() {
        let message = StaticAccessor::new("region", FieldValueRef::Text("eu"));
        let missing_equals_x = Predicate::Comparison {
            field: FieldPath::new("missing"),
            op: ComparisonOp::Eq,
            value: FieldValue::Text(String::from("x")),
        };

        assert!(!evaluate(&missing_equals_x, &message));
    }

    #[test]
    fn and_short_circuits_at_first_false() {
        let children = vec![
            boolean_comparison("truth", true),
            boolean_comparison("falsehood", true),
            boolean_comparison("third", true),
        ];
        let conjunction = Predicate::And(children);
        let probe = CountingAccessor::new();

        assert!(!evaluate(&conjunction, &probe));
        // The third child must never be consulted.
        assert_eq!(probe.count(), 2);
    }

    #[test]
    fn or_short_circuits_at_first_true() {
        let children = vec![
            boolean_comparison("falsehood", true),
            boolean_comparison("truth", true),
            boolean_comparison("third", true),
        ];
        let disjunction = Predicate::Or(children);
        let probe = CountingAccessor::new();

        assert!(evaluate(&disjunction, &probe));
        // The third child must never be consulted.
        assert_eq!(probe.count(), 2);
    }

    #[test]
    fn not_negates_child_predicate() {
        let probe = CountingAccessor::new();
        let negated_true = Predicate::Not(Box::new(boolean_comparison("truth", true)));
        let negated_false = Predicate::Not(Box::new(boolean_comparison("falsehood", true)));

        assert!(!evaluate(&negated_true, &probe));
        assert!(evaluate(&negated_false, &probe));
    }

    #[test]
    fn empty_boolean_combinators_have_vacuous_values() {
        let message = StaticAccessor::new("amount", FieldValueRef::Integer(1));
        let empty_and = Predicate::And(Vec::new());
        let empty_or = Predicate::Or(Vec::new());

        assert!(evaluate(&empty_and, &message));
        assert!(!evaluate(&empty_or, &message));
    }

    #[test]
    fn range_includes_middle_and_bounds() {
        let range = amount_range();

        for amount in [150, 100, 200] {
            let message = StaticAccessor::new("amount", FieldValueRef::Integer(amount));

            assert!(evaluate(&range, &message));
        }
    }

    #[test]
    fn range_rejects_value_below_lower_bound() {
        let message = StaticAccessor::new("amount", FieldValueRef::Integer(50));
        let range = amount_range();

        assert!(!evaluate(&range, &message));
    }

    #[test]
    fn range_rejects_missing_field() {
        let message = StaticAccessor::new("region", FieldValueRef::Text("eu"));
        let range = amount_range();

        assert!(!evaluate(&range, &message));
    }

    #[test]
    fn range_rejects_type_mismatch() {
        // Text "150" is not an integer, so it cannot satisfy an integer range.
        let message = StaticAccessor::new("amount", FieldValueRef::Text("150"));
        let range = amount_range();

        assert!(!evaluate(&range, &message));
    }
}