oxify_vector/
filter.rs

1//! Metadata filtering for vector search
2//!
3//! Provides pre-filtering and post-filtering capabilities to constrain
4//! search results based on metadata attributes.
5//!
6//! ## Example
7//!
8//! ```rust
9//! use oxify_vector::filter::{Filter, FilterCondition, FilterValue};
10//!
11//! // Filter by document type
12//! let filter = Filter::new()
13//!     .eq("type", "article")
14//!     .gte("year", 2020);
15//!
16//! // Filter with OR conditions
17//! let filter = Filter::any(vec![
18//!     Filter::new().eq("category", "tech"),
19//!     Filter::new().eq("category", "science"),
20//! ]);
21//! ```
22
23use serde::{Deserialize, Serialize};
24use std::collections::HashMap;
25
/// Value types for metadata filtering.
///
/// The same enum is used for the values stored in [`Metadata`] and for the
/// operands carried by a [`FilterCondition`].
///
/// Equality is the derived `PartialEq`, so two values are equal only when they
/// are the same variant with equal payloads (`Int(1)` is not equal to
/// `Float(1.0)`). Ordering comparisons are handled separately by
/// `compare_values`, which can order `Int` and `Float` against each other.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum FilterValue {
    /// String value; the only variant the `Contains` and `StartsWith`
    /// conditions can match.
    String(String),
    /// Integer value
    Int(i64),
    /// Float value
    Float(f64),
    /// Boolean value; supports equality checks only, it has no ordering.
    Bool(bool),
    /// List of strings (for IN operator); only `String` values are looked up in it.
    StringList(Vec<String>),
    /// List of integers (for IN operator); only `Int` values are looked up in it.
    IntList(Vec<i64>),
}
42
43impl From<&str> for FilterValue {
44    fn from(s: &str) -> Self {
45        FilterValue::String(s.to_string())
46    }
47}
48
49impl From<String> for FilterValue {
50    fn from(s: String) -> Self {
51        FilterValue::String(s)
52    }
53}
54
55impl From<i64> for FilterValue {
56    fn from(v: i64) -> Self {
57        FilterValue::Int(v)
58    }
59}
60
61impl From<i32> for FilterValue {
62    fn from(v: i32) -> Self {
63        FilterValue::Int(v as i64)
64    }
65}
66
67impl From<f64> for FilterValue {
68    fn from(v: f64) -> Self {
69        FilterValue::Float(v)
70    }
71}
72
73impl From<bool> for FilterValue {
74    fn from(v: bool) -> Self {
75        FilterValue::Bool(v)
76    }
77}
78
79impl From<Vec<String>> for FilterValue {
80    fn from(v: Vec<String>) -> Self {
81        FilterValue::StringList(v)
82    }
83}
84
85impl From<Vec<&str>> for FilterValue {
86    fn from(v: Vec<&str>) -> Self {
87        FilterValue::StringList(v.into_iter().map(|s| s.to_string()).collect())
88    }
89}
90
91impl From<Vec<i64>> for FilterValue {
92    fn from(v: Vec<i64>) -> Self {
93        FilterValue::IntList(v)
94    }
95}
96
/// Filter condition operators.
///
/// Leaf variants carry the metadata key as their first field and the operand
/// as their second. `All`, `Any` and `Not` combine other conditions.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum FilterCondition {
    /// Equal to. Requires the key to be present and the stored value to be
    /// equal to the operand (same variant, same payload).
    Eq(String, FilterValue),
    /// Not equal to. Also matches when the key is absent.
    Ne(String, FilterValue),
    /// Greater than. Never matches when the key is absent or the two values
    /// cannot be ordered against each other.
    Gt(String, FilterValue),
    /// Greater than or equal to. Never matches when the key is absent or the
    /// two values cannot be ordered against each other.
    Gte(String, FilterValue),
    /// Less than. Never matches when the key is absent or the two values
    /// cannot be ordered against each other.
    Lt(String, FilterValue),
    /// Less than or equal to. Never matches when the key is absent or the two
    /// values cannot be ordered against each other.
    Lte(String, FilterValue),
    /// Value is in list. Matches a `String` value against a `StringList`
    /// operand or an `Int` value against an `IntList` operand; an absent key
    /// or any other pairing never matches.
    In(String, FilterValue),
    /// Value is not in list. The exact negation of the `In` lookup for a
    /// present key, and also matches when the key is absent.
    NotIn(String, FilterValue),
    /// Field contains substring. Only `String` values can match.
    Contains(String, String),
    /// Field starts with prefix. Only `String` values can match.
    StartsWith(String, String),
    /// All conditions must match (AND). An empty list always matches.
    All(Vec<FilterCondition>),
    /// Any condition must match (OR). An empty list never matches.
    Any(Vec<FilterCondition>),
    /// Negate condition (NOT)
    Not(Box<FilterCondition>),
}
127
/// Metadata filter builder.
///
/// Holds a list of conditions that [`Filter::matches`] combines with AND; a
/// filter with no conditions matches every metadata map.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Filter {
    /// Top-level conditions; every one of them must hold for a match.
    conditions: Vec<FilterCondition>,
}
133
134impl Filter {
135    /// Create a new empty filter
136    pub fn new() -> Self {
137        Self {
138            conditions: Vec::new(),
139        }
140    }
141
142    /// Create a filter that matches ALL conditions (AND)
143    pub fn all(filters: Vec<Filter>) -> Self {
144        let conditions: Vec<FilterCondition> =
145            filters.into_iter().flat_map(|f| f.conditions).collect();
146        Self { conditions }
147    }
148
149    /// Create a filter that matches ANY condition (OR)
150    pub fn any(filters: Vec<Filter>) -> Self {
151        let inner: Vec<FilterCondition> = filters
152            .into_iter()
153            .map(|f| {
154                if f.conditions.len() == 1 {
155                    f.conditions.into_iter().next().unwrap()
156                } else {
157                    FilterCondition::All(f.conditions)
158                }
159            })
160            .collect();
161        Self {
162            conditions: vec![FilterCondition::Any(inner)],
163        }
164    }
165
166    /// Add equality condition
167    pub fn eq<K: Into<String>, V: Into<FilterValue>>(mut self, key: K, value: V) -> Self {
168        self.conditions
169            .push(FilterCondition::Eq(key.into(), value.into()));
170        self
171    }
172
173    /// Add not-equal condition
174    pub fn ne<K: Into<String>, V: Into<FilterValue>>(mut self, key: K, value: V) -> Self {
175        self.conditions
176            .push(FilterCondition::Ne(key.into(), value.into()));
177        self
178    }
179
180    /// Add greater-than condition
181    pub fn gt<K: Into<String>, V: Into<FilterValue>>(mut self, key: K, value: V) -> Self {
182        self.conditions
183            .push(FilterCondition::Gt(key.into(), value.into()));
184        self
185    }
186
187    /// Add greater-than-or-equal condition
188    pub fn gte<K: Into<String>, V: Into<FilterValue>>(mut self, key: K, value: V) -> Self {
189        self.conditions
190            .push(FilterCondition::Gte(key.into(), value.into()));
191        self
192    }
193
194    /// Add less-than condition
195    pub fn lt<K: Into<String>, V: Into<FilterValue>>(mut self, key: K, value: V) -> Self {
196        self.conditions
197            .push(FilterCondition::Lt(key.into(), value.into()));
198        self
199    }
200
201    /// Add less-than-or-equal condition
202    pub fn lte<K: Into<String>, V: Into<FilterValue>>(mut self, key: K, value: V) -> Self {
203        self.conditions
204            .push(FilterCondition::Lte(key.into(), value.into()));
205        self
206    }
207
208    /// Add IN condition (value must be in list)
209    pub fn in_list<K: Into<String>, V: Into<FilterValue>>(mut self, key: K, values: V) -> Self {
210        self.conditions
211            .push(FilterCondition::In(key.into(), values.into()));
212        self
213    }
214
215    /// Add NOT IN condition (value must not be in list)
216    pub fn not_in<K: Into<String>, V: Into<FilterValue>>(mut self, key: K, values: V) -> Self {
217        self.conditions
218            .push(FilterCondition::NotIn(key.into(), values.into()));
219        self
220    }
221
222    /// Add contains condition (string contains substring)
223    pub fn contains<K: Into<String>, V: Into<String>>(mut self, key: K, substring: V) -> Self {
224        self.conditions
225            .push(FilterCondition::Contains(key.into(), substring.into()));
226        self
227    }
228
229    /// Add starts_with condition
230    pub fn starts_with<K: Into<String>, V: Into<String>>(mut self, key: K, prefix: V) -> Self {
231        self.conditions
232            .push(FilterCondition::StartsWith(key.into(), prefix.into()));
233        self
234    }
235
236    /// Get the conditions
237    pub fn conditions(&self) -> &[FilterCondition] {
238        &self.conditions
239    }
240
241    /// Check if filter is empty
242    pub fn is_empty(&self) -> bool {
243        self.conditions.is_empty()
244    }
245
246    /// Evaluate filter against metadata
247    pub fn matches(&self, metadata: &Metadata) -> bool {
248        self.conditions
249            .iter()
250            .all(|c| evaluate_condition(c, metadata))
251    }
252}
253
/// Metadata storage for a single entity: maps a field name to its value.
pub type Metadata = HashMap<String, FilterValue>;
256
257/// Evaluate a single condition against metadata
258fn evaluate_condition(condition: &FilterCondition, metadata: &Metadata) -> bool {
259    match condition {
260        FilterCondition::Eq(key, expected) => metadata.get(key) == Some(expected),
261        FilterCondition::Ne(key, expected) => metadata.get(key) != Some(expected),
262        FilterCondition::Gt(key, expected) => metadata
263            .get(key)
264            .is_some_and(|v| compare_values(v, expected) == Some(std::cmp::Ordering::Greater)),
265        FilterCondition::Gte(key, expected) => metadata.get(key).is_some_and(|v| {
266            matches!(
267                compare_values(v, expected),
268                Some(std::cmp::Ordering::Greater | std::cmp::Ordering::Equal)
269            )
270        }),
271        FilterCondition::Lt(key, expected) => metadata
272            .get(key)
273            .is_some_and(|v| compare_values(v, expected) == Some(std::cmp::Ordering::Less)),
274        FilterCondition::Lte(key, expected) => metadata.get(key).is_some_and(|v| {
275            matches!(
276                compare_values(v, expected),
277                Some(std::cmp::Ordering::Less | std::cmp::Ordering::Equal)
278            )
279        }),
280        FilterCondition::In(key, values) => {
281            metadata.get(key).is_some_and(|v| value_in_list(v, values))
282        }
283        FilterCondition::NotIn(key, values) => {
284            metadata.get(key).is_none_or(|v| !value_in_list(v, values))
285        }
286        FilterCondition::Contains(key, substring) => metadata.get(key).is_some_and(|v| {
287            if let FilterValue::String(s) = v {
288                s.contains(substring)
289            } else {
290                false
291            }
292        }),
293        FilterCondition::StartsWith(key, prefix) => metadata.get(key).is_some_and(|v| {
294            if let FilterValue::String(s) = v {
295                s.starts_with(prefix)
296            } else {
297                false
298            }
299        }),
300        FilterCondition::All(conditions) => {
301            conditions.iter().all(|c| evaluate_condition(c, metadata))
302        }
303        FilterCondition::Any(conditions) => {
304            conditions.iter().any(|c| evaluate_condition(c, metadata))
305        }
306        FilterCondition::Not(condition) => !evaluate_condition(condition, metadata),
307    }
308}
309
310/// Compare two filter values
311fn compare_values(a: &FilterValue, b: &FilterValue) -> Option<std::cmp::Ordering> {
312    match (a, b) {
313        (FilterValue::Int(a), FilterValue::Int(b)) => Some(a.cmp(b)),
314        (FilterValue::Float(a), FilterValue::Float(b)) => a.partial_cmp(b),
315        (FilterValue::Int(a), FilterValue::Float(b)) => (*a as f64).partial_cmp(b),
316        (FilterValue::Float(a), FilterValue::Int(b)) => a.partial_cmp(&(*b as f64)),
317        (FilterValue::String(a), FilterValue::String(b)) => Some(a.cmp(b)),
318        _ => None,
319    }
320}
321
322/// Check if value is in list
323fn value_in_list(value: &FilterValue, list: &FilterValue) -> bool {
324    match (value, list) {
325        (FilterValue::String(v), FilterValue::StringList(l)) => l.contains(v),
326        (FilterValue::Int(v), FilterValue::IntList(l)) => l.contains(v),
327        _ => false,
328    }
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture shared by the tests: one entity with a string, integer, float
    /// and boolean field plus a category.
    fn test_metadata() -> Metadata {
        HashMap::from([
            ("type".to_string(), FilterValue::from("article")),
            ("year".to_string(), FilterValue::Int(2023)),
            ("score".to_string(), FilterValue::Float(0.95)),
            ("published".to_string(), FilterValue::Bool(true)),
            ("category".to_string(), FilterValue::from("tech")),
        ])
    }

    /// Evaluate `filter` against the shared fixture.
    fn accepts(filter: Filter) -> bool {
        filter.matches(&test_metadata())
    }

    #[test]
    fn test_eq_filter() {
        assert!(accepts(Filter::new().eq("type", "article")));
        assert!(!accepts(Filter::new().eq("type", "book")));
    }

    #[test]
    fn test_ne_filter() {
        assert!(accepts(Filter::new().ne("type", "book")));
        assert!(!accepts(Filter::new().ne("type", "article")));
    }

    #[test]
    fn test_gt_filter() {
        assert!(accepts(Filter::new().gt("year", 2020i64)));
        assert!(!accepts(Filter::new().gt("year", 2023i64)));
    }

    #[test]
    fn test_gte_filter() {
        assert!(accepts(Filter::new().gte("year", 2023i64)));
        assert!(!accepts(Filter::new().gte("year", 2024i64)));
    }

    #[test]
    fn test_lt_filter() {
        assert!(accepts(Filter::new().lt("year", 2026i64)));
        assert!(!accepts(Filter::new().lt("year", 2023i64)));
    }

    #[test]
    fn test_lte_filter() {
        assert!(accepts(Filter::new().lte("year", 2023i64)));
        assert!(!accepts(Filter::new().lte("year", 2022i64)));
    }

    #[test]
    fn test_in_filter() {
        assert!(accepts(
            Filter::new().in_list("type", vec!["article", "book"])
        ));
        assert!(!accepts(
            Filter::new().in_list("type", vec!["book", "journal"])
        ));
    }

    #[test]
    fn test_not_in_filter() {
        assert!(accepts(
            Filter::new().not_in("type", vec!["book", "journal"])
        ));
        assert!(!accepts(
            Filter::new().not_in("type", vec!["article", "journal"])
        ));
    }

    #[test]
    fn test_contains_filter() {
        assert!(accepts(Filter::new().contains("type", "art")));
        assert!(!accepts(Filter::new().contains("type", "xyz")));
    }

    #[test]
    fn test_starts_with_filter() {
        assert!(accepts(Filter::new().starts_with("type", "art")));
        assert!(!accepts(Filter::new().starts_with("type", "icle")));
    }

    #[test]
    fn test_combined_filters() {
        // Chained conditions are ANDed: both must hold.
        let both_hold = Filter::new().eq("type", "article").gte("year", 2020i64);
        assert!(accepts(both_hold));

        let second_fails = Filter::new().eq("type", "article").gt("year", 2026i64);
        assert!(!accepts(second_fails));
    }

    #[test]
    fn test_or_filters() {
        // One matching alternative is enough for an OR filter.
        let one_alternative_hits = Filter::any(vec![
            Filter::new().eq("category", "tech"),
            Filter::new().eq("category", "science"),
        ]);
        assert!(accepts(one_alternative_hits));

        let no_alternative_hits = Filter::any(vec![
            Filter::new().eq("category", "sports"),
            Filter::new().eq("category", "music"),
        ]);
        assert!(!accepts(no_alternative_hits));
    }

    #[test]
    fn test_missing_field() {
        // An absent key can never be equal to anything...
        assert!(!accepts(Filter::new().eq("nonexistent", "value")));
        // ...and is therefore always "not equal".
        assert!(accepts(Filter::new().ne("nonexistent", "value")));
    }

    #[test]
    fn test_float_comparison() {
        assert!(accepts(Filter::new().gt("score", 0.9)));
        assert!(accepts(Filter::new().lt("score", 1.0)));
    }

    #[test]
    fn test_filter_is_empty() {
        assert!(Filter::new().is_empty());
        assert!(!Filter::new().eq("key", "value").is_empty());
    }
}