Skip to main content

nodedb_query/
scan_filter.rs

1//! Post-scan filter evaluation.
2//!
3//! `ScanFilter` represents a single filter predicate. `compare_json_values`
4//! provides total ordering for JSON values used in sort and range comparisons.
5//!
6//! Shared between Origin (Control Plane + Data Plane) and Lite.
7
8use crate::json_ops::{coerced_eq, compare_json_optional as compare_json_values};
9
10/// A single filter predicate for document scan evaluation.
11///
12/// Supports simple comparison operators (eq, ne, gt, gte, lt, lte, contains,
13/// is_null, is_not_null) and disjunctive groups via the `"or"` operator.
14///
15/// OR representation: `{"op": "or", "clauses": [[filter1, filter2], [filter3]]}`
16/// means `(filter1 AND filter2) OR filter3`. Each clause is an AND-group;
17/// the document matches if ANY clause group fully matches.
18#[derive(Clone, serde::Serialize, serde::Deserialize, Default)]
19pub struct ScanFilter {
20    #[serde(default)]
21    pub field: String,
22    pub op: String,
23    #[serde(default)]
24    pub value: serde_json::Value,
25    /// Disjunctive clause groups for OR predicates.
26    /// Each inner Vec is an AND-group. The document matches if ANY group matches.
27    #[serde(default)]
28    pub clauses: Vec<Vec<ScanFilter>>,
29}
30
31impl ScanFilter {
32    /// Evaluate this filter against a JSON document.
33    pub fn matches(&self, doc: &serde_json::Value) -> bool {
34        if self.op == "match_all" {
35            return true;
36        }
37
38        if self.op == "exists" || self.op == "not_exists" {
39            return true;
40        }
41
42        if self.op == "or" {
43            return self
44                .clauses
45                .iter()
46                .any(|clause| clause.iter().all(|f| f.matches(doc)));
47        }
48
49        let field_val = match doc.get(&self.field) {
50            Some(v) => v,
51            None => return self.op == "is_null",
52        };
53
54        match self.op.as_str() {
55            "eq" => coerced_eq(field_val, &self.value),
56            "ne" | "neq" => !coerced_eq(field_val, &self.value),
57            "gt" => {
58                compare_json_values(Some(field_val), Some(&self.value))
59                    == std::cmp::Ordering::Greater
60            }
61            "gte" | "ge" => {
62                let cmp = compare_json_values(Some(field_val), Some(&self.value));
63                cmp == std::cmp::Ordering::Greater || cmp == std::cmp::Ordering::Equal
64            }
65            "lt" => {
66                compare_json_values(Some(field_val), Some(&self.value)) == std::cmp::Ordering::Less
67            }
68            "lte" | "le" => {
69                let cmp = compare_json_values(Some(field_val), Some(&self.value));
70                cmp == std::cmp::Ordering::Less || cmp == std::cmp::Ordering::Equal
71            }
72            "contains" => {
73                if let (Some(s), Some(pattern)) = (field_val.as_str(), self.value.as_str()) {
74                    s.contains(pattern)
75                } else {
76                    false
77                }
78            }
79            "like" => {
80                if let (Some(s), Some(pattern)) = (field_val.as_str(), self.value.as_str()) {
81                    sql_like_match(s, pattern, false)
82                } else {
83                    false
84                }
85            }
86            "not_like" => {
87                if let (Some(s), Some(pattern)) = (field_val.as_str(), self.value.as_str()) {
88                    !sql_like_match(s, pattern, false)
89                } else {
90                    false
91                }
92            }
93            "ilike" => {
94                if let (Some(s), Some(pattern)) = (field_val.as_str(), self.value.as_str()) {
95                    sql_like_match(s, pattern, true)
96                } else {
97                    false
98                }
99            }
100            "not_ilike" => {
101                if let (Some(s), Some(pattern)) = (field_val.as_str(), self.value.as_str()) {
102                    !sql_like_match(s, pattern, true)
103                } else {
104                    false
105                }
106            }
107            "in" => {
108                if let Some(arr) = self.value.as_array() {
109                    arr.iter().any(|v| field_val == v)
110                } else {
111                    false
112                }
113            }
114            "not_in" => {
115                if let Some(arr) = self.value.as_array() {
116                    !arr.iter().any(|v| field_val == v)
117                } else {
118                    true
119                }
120            }
121            "is_null" => field_val.is_null(),
122            "is_not_null" => !field_val.is_null(),
123            _ => false,
124        }
125    }
126}
127
/// SQL LIKE pattern matching.
///
/// Supports `%` (zero or more characters) and `_` (exactly one character);
/// every other pattern character must match literally. The pattern has to
/// cover the whole input. There is no escape syntax, so `%` and `_` in the
/// pattern are always wildcards.
///
/// When `case_insensitive` is true, both input and pattern are lowercased
/// (ILIKE) before matching.
///
/// Matching works on `char`s rather than bytes, so `_` consumes one full
/// character even when it is encoded as several UTF-8 bytes.
pub fn sql_like_match(input: &str, pattern: &str, case_insensitive: bool) -> bool {
    let (text, pat): (Vec<char>, Vec<char>) = if case_insensitive {
        (
            input.to_lowercase().chars().collect(),
            pattern.to_lowercase().chars().collect(),
        )
    } else {
        (input.chars().collect(), pattern.chars().collect())
    };

    let (mut i, mut j) = (0usize, 0usize);
    // Most recent `%` seen: (its index in `pat`, the index in `text` where the
    // characters it absorbs currently end). Used to retry with a longer run.
    let mut last_star: Option<(usize, usize)> = None;

    while i < text.len() {
        match pat.get(j).copied() {
            // The wildcard must be tested before literal equality: otherwise a
            // `%` in the pattern facing a literal `%` in the input would be
            // consumed as a plain character and lose its wildcard meaning.
            Some('%') => {
                last_star = Some((j, i));
                j += 1;
            }
            Some(p) if p == '_' || p == text[i] => {
                i += 1;
                j += 1;
            }
            // Mismatch (or pattern exhausted): let the last `%` absorb one
            // more input character and resume right after it.
            _ => match last_star.as_mut() {
                Some((star_j, star_i)) => {
                    *star_i += 1;
                    i = *star_i;
                    j = *star_j + 1;
                }
                None => return false,
            },
        }
    }

    // Input consumed: whatever is left of the pattern may only be `%`.
    pat[j..].iter().all(|&c| c == '%')
}
168
169/// Compute an aggregate function over a group of JSON documents.
170///
171/// Supported operations: count, sum, avg, min, max, count_distinct,
172/// stddev, variance, array_agg, string_agg, percentile_cont.
173pub fn compute_aggregate(op: &str, field: &str, docs: &[serde_json::Value]) -> serde_json::Value {
174    match op {
175        "count" => serde_json::json!(docs.len()),
176
177        "sum" => {
178            let total: f64 = docs
179                .iter()
180                .filter_map(|d| d.get(field).and_then(|v| v.as_f64()))
181                .sum();
182            serde_json::json!(total)
183        }
184
185        "avg" => {
186            let values: Vec<f64> = docs
187                .iter()
188                .filter_map(|d| d.get(field).and_then(|v| v.as_f64()))
189                .collect();
190            if values.is_empty() {
191                serde_json::Value::Null
192            } else {
193                let avg = values.iter().sum::<f64>() / values.len() as f64;
194                serde_json::json!(avg)
195            }
196        }
197
198        "min" => {
199            let min = docs
200                .iter()
201                .filter_map(|d| d.get(field))
202                .min_by(|a, b| compare_json_values(Some(a), Some(b)));
203            match min {
204                Some(v) => v.clone(),
205                None => serde_json::Value::Null,
206            }
207        }
208
209        "max" => {
210            let max = docs
211                .iter()
212                .filter_map(|d| d.get(field))
213                .max_by(|a, b| compare_json_values(Some(a), Some(b)));
214            match max {
215                Some(v) => v.clone(),
216                None => serde_json::Value::Null,
217            }
218        }
219
220        "count_distinct" => {
221            let mut seen = std::collections::HashSet::new();
222            for d in docs {
223                if let Some(v) = d.get(field) {
224                    seen.insert(v.to_string());
225                }
226            }
227            serde_json::json!(seen.len())
228        }
229
230        "stddev" | "stddev_pop" => {
231            let values: Vec<f64> = docs
232                .iter()
233                .filter_map(|d| d.get(field).and_then(|v| v.as_f64()))
234                .collect();
235            if values.len() < 2 {
236                return serde_json::Value::Null;
237            }
238            let mean = values.iter().sum::<f64>() / values.len() as f64;
239            let variance =
240                values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / values.len() as f64;
241            serde_json::json!(variance.sqrt())
242        }
243
244        "stddev_samp" => {
245            let values: Vec<f64> = docs
246                .iter()
247                .filter_map(|d| d.get(field).and_then(|v| v.as_f64()))
248                .collect();
249            if values.len() < 2 {
250                return serde_json::Value::Null;
251            }
252            let mean = values.iter().sum::<f64>() / values.len() as f64;
253            let variance =
254                values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
255            serde_json::json!(variance.sqrt())
256        }
257
258        "variance" | "var_pop" => {
259            let values: Vec<f64> = docs
260                .iter()
261                .filter_map(|d| d.get(field).and_then(|v| v.as_f64()))
262                .collect();
263            if values.len() < 2 {
264                return serde_json::Value::Null;
265            }
266            let mean = values.iter().sum::<f64>() / values.len() as f64;
267            let variance =
268                values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / values.len() as f64;
269            serde_json::json!(variance)
270        }
271
272        "var_samp" => {
273            let values: Vec<f64> = docs
274                .iter()
275                .filter_map(|d| d.get(field).and_then(|v| v.as_f64()))
276                .collect();
277            if values.len() < 2 {
278                return serde_json::Value::Null;
279            }
280            let mean = values.iter().sum::<f64>() / values.len() as f64;
281            let variance =
282                values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
283            serde_json::json!(variance)
284        }
285
286        "array_agg" => {
287            let values: Vec<serde_json::Value> =
288                docs.iter().filter_map(|d| d.get(field).cloned()).collect();
289            serde_json::Value::Array(values)
290        }
291
292        "string_agg" | "group_concat" => {
293            let values: Vec<String> = docs
294                .iter()
295                .filter_map(|d| d.get(field).and_then(|v| v.as_str()).map(String::from))
296                .collect();
297            serde_json::Value::String(values.join(","))
298        }
299
300        "percentile_cont" => {
301            let (pct, actual_field) = if let Some(idx) = field.find(':') {
302                let p: f64 = field[..idx].parse().unwrap_or(0.5);
303                (p, &field[idx + 1..])
304            } else {
305                (0.5, field)
306            };
307            let mut values: Vec<f64> = docs
308                .iter()
309                .filter_map(|d| d.get(actual_field).and_then(|v| v.as_f64()))
310                .collect();
311            if values.is_empty() {
312                return serde_json::Value::Null;
313            }
314            values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
315            let idx = (pct * (values.len() - 1) as f64).clamp(0.0, (values.len() - 1) as f64);
316            let lower = idx.floor() as usize;
317            let upper = idx.ceil() as usize;
318            let frac = idx - lower as f64;
319            let result = values[lower] * (1.0 - frac) + values[upper] * frac;
320            serde_json::json!(result)
321        }
322
323        _ => serde_json::Value::Null,
324    }
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Build a plain (non-OR) filter on `field` with the given operator and operand.
    fn leaf(field: &str, op: &str, value: serde_json::Value) -> ScanFilter {
        ScanFilter {
            field: field.to_owned(),
            op: op.to_owned(),
            value,
            clauses: Vec::new(),
        }
    }

    #[test]
    fn filter_eq_coercion() {
        // Numeric field compared against a numeric string.
        let matched = leaf("age", "eq", json!("25")).matches(&json!({"age": 25}));
        assert!(matched);
    }

    #[test]
    fn filter_gt_coercion() {
        // String field compared against a number.
        let matched = leaf("score", "gt", json!(80)).matches(&json!({"score": "90"}));
        assert!(matched);
    }

    #[test]
    fn like_basic() {
        let subject = "hello world";
        assert!(sql_like_match(subject, "%world", false));
        assert!(sql_like_match(subject, "hello%", false));
        assert!(!sql_like_match(subject, "xyz%", false));
    }

    #[test]
    fn ilike_case_insensitive() {
        assert!(sql_like_match("Hello", "hello", true));
        assert!(sql_like_match("WORLD", "%world%", true));
    }

    #[test]
    fn aggregate_count() {
        let group = [json!({"x": 1}), json!({"x": 2}), json!({"x": 3})];
        assert_eq!(compute_aggregate("count", "x", &group), json!(3));
    }

    #[test]
    fn aggregate_sum() {
        let group = [json!({"v": 10}), json!({"v": 20}), json!({"v": 30})];
        assert_eq!(compute_aggregate("sum", "v", &group), json!(60.0));
    }

    #[test]
    fn aggregate_min_max() {
        let group = [json!({"v": 5}), json!({"v": 1}), json!({"v": 9})];
        let smallest = compute_aggregate("min", "v", &group);
        let largest = compute_aggregate("max", "v", &group);
        assert_eq!(smallest, json!(1));
        assert_eq!(largest, json!(9));
    }
}