Skip to main content

oxirs_vec/
filtered_search.rs

//! Filtered search for vector indices.
//!
//! Search results can be constrained by distance thresholds, by predicates over
//! string-valued metadata (comparisons, set membership, substring and regex
//! matching, field presence, and AND/OR/NOT combinations), and by per-dimension
//! value ranges on the result vector.
6
7use crate::{Vector, VectorId};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
/// Predicate over a search result's metadata, evaluated with `MetadataFilter::evaluate`.
///
/// Metadata is a map from field name to a text value. Field-based predicates look up
/// `field` in that map; logical variants combine other predicates.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MetadataFilter {
    /// Field is present and its value equals `value`; a missing field does not match.
    Equals { field: String, value: FilterValue },
    /// Field value differs from `value`; a missing field DOES match.
    NotEquals { field: String, value: FilterValue },
    /// Field is present and its value orders strictly after `value`.
    GreaterThan { field: String, value: FilterValue },
    /// Field is present and its value orders after or equal to `value`.
    GreaterThanOrEqual { field: String, value: FilterValue },
    /// Field is present and its value orders strictly before `value`.
    LessThan { field: String, value: FilterValue },
    /// Field is present and its value orders before or equal to `value`.
    LessThanOrEqual { field: String, value: FilterValue },
    /// Field is present and its value equals at least one entry of `values`.
    In {
        field: String,
        values: Vec<FilterValue>,
    },
    /// Field value equals none of `values`; a missing field DOES match.
    NotIn {
        field: String,
        values: Vec<FilterValue>,
    },
    /// Field is present and its raw text contains `substring` (case-sensitive).
    Contains { field: String, substring: String },
    /// Field is present and its raw text matches the regex `pattern`.
    /// An invalid pattern never matches.
    Regex { field: String, pattern: String },
    /// Field key is present in the metadata map, whatever its value (even empty).
    Exists { field: String },
    /// Field key is absent from the metadata map. A present key whose value is
    /// `null` or empty does NOT match.
    NotExists { field: String },
    /// Every inner filter matches; an empty list matches everything.
    And(Vec<MetadataFilter>),
    /// At least one inner filter matches; an empty list matches nothing.
    Or(Vec<MetadataFilter>),
    /// Negation of the inner filter.
    Not(Box<MetadataFilter>),
}
51
/// Typed operand for a [`MetadataFilter`] predicate.
///
/// Metadata values are stored as text. When a predicate is evaluated, the stored
/// text is parsed into one of these variants: integer first, then float, then
/// boolean, then `"null"`/empty as `Null`, otherwise `String`.
///
/// The derived `PartialEq` compares variants structurally, so `Integer(1)` is not
/// equal to `Float(1.0)`, and `Float(NaN)` is not equal to itself.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum FilterValue {
    String(String),
    Integer(i64),
    Float(f64),
    Boolean(bool),
    Null,
}
61
62impl FilterValue {
63    /// Compare two filter values
64    fn compare(&self, other: &FilterValue) -> std::cmp::Ordering {
65        match (self, other) {
66            (FilterValue::String(a), FilterValue::String(b)) => a.cmp(b),
67            (FilterValue::Integer(a), FilterValue::Integer(b)) => a.cmp(b),
68            (FilterValue::Float(a), FilterValue::Float(b)) => {
69                a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
70            }
71            (FilterValue::Boolean(a), FilterValue::Boolean(b)) => a.cmp(b),
72            _ => std::cmp::Ordering::Equal,
73        }
74    }
75}
76
77impl MetadataFilter {
78    /// Evaluate the filter against a metadata map
79    pub fn evaluate(&self, metadata: &HashMap<String, String>) -> bool {
80        match self {
81            MetadataFilter::Equals { field, value } => {
82                if let Some(field_value) = metadata.get(field) {
83                    let parsed_value = Self::parse_value(field_value);
84                    &parsed_value == value
85                } else {
86                    false
87                }
88            }
89            MetadataFilter::NotEquals { field, value } => {
90                if let Some(field_value) = metadata.get(field) {
91                    let parsed_value = Self::parse_value(field_value);
92                    &parsed_value != value
93                } else {
94                    true
95                }
96            }
97            MetadataFilter::GreaterThan { field, value } => {
98                if let Some(field_value) = metadata.get(field) {
99                    let parsed_value = Self::parse_value(field_value);
100                    parsed_value.compare(value) == std::cmp::Ordering::Greater
101                } else {
102                    false
103                }
104            }
105            MetadataFilter::GreaterThanOrEqual { field, value } => {
106                if let Some(field_value) = metadata.get(field) {
107                    let parsed_value = Self::parse_value(field_value);
108                    matches!(
109                        parsed_value.compare(value),
110                        std::cmp::Ordering::Greater | std::cmp::Ordering::Equal
111                    )
112                } else {
113                    false
114                }
115            }
116            MetadataFilter::LessThan { field, value } => {
117                if let Some(field_value) = metadata.get(field) {
118                    let parsed_value = Self::parse_value(field_value);
119                    parsed_value.compare(value) == std::cmp::Ordering::Less
120                } else {
121                    false
122                }
123            }
124            MetadataFilter::LessThanOrEqual { field, value } => {
125                if let Some(field_value) = metadata.get(field) {
126                    let parsed_value = Self::parse_value(field_value);
127                    matches!(
128                        parsed_value.compare(value),
129                        std::cmp::Ordering::Less | std::cmp::Ordering::Equal
130                    )
131                } else {
132                    false
133                }
134            }
135            MetadataFilter::In { field, values } => {
136                if let Some(field_value) = metadata.get(field) {
137                    let parsed_value = Self::parse_value(field_value);
138                    values.contains(&parsed_value)
139                } else {
140                    false
141                }
142            }
143            MetadataFilter::NotIn { field, values } => {
144                if let Some(field_value) = metadata.get(field) {
145                    let parsed_value = Self::parse_value(field_value);
146                    !values.contains(&parsed_value)
147                } else {
148                    true
149                }
150            }
151            MetadataFilter::Contains { field, substring } => {
152                if let Some(field_value) = metadata.get(field) {
153                    field_value.contains(substring)
154                } else {
155                    false
156                }
157            }
158            MetadataFilter::Regex { field, pattern } => {
159                if let Some(field_value) = metadata.get(field) {
160                    if let Ok(regex) = regex::Regex::new(pattern) {
161                        regex.is_match(field_value)
162                    } else {
163                        false
164                    }
165                } else {
166                    false
167                }
168            }
169            MetadataFilter::Exists { field } => metadata.contains_key(field),
170            MetadataFilter::NotExists { field } => !metadata.contains_key(field),
171            MetadataFilter::And(filters) => filters.iter().all(|f| f.evaluate(metadata)),
172            MetadataFilter::Or(filters) => filters.iter().any(|f| f.evaluate(metadata)),
173            MetadataFilter::Not(filter) => !filter.evaluate(metadata),
174        }
175    }
176
177    /// Parse a string value into a FilterValue
178    fn parse_value(s: &str) -> FilterValue {
179        // Try to parse as integer
180        if let Ok(i) = s.parse::<i64>() {
181            return FilterValue::Integer(i);
182        }
183
184        // Try to parse as float
185        if let Ok(f) = s.parse::<f64>() {
186            return FilterValue::Float(f);
187        }
188
189        // Try to parse as boolean
190        if let Ok(b) = s.parse::<bool>() {
191            return FilterValue::Boolean(b);
192        }
193
194        // Check for null
195        if s == "null" || s.is_empty() {
196            return FilterValue::Null;
197        }
198
199        // Default to string
200        FilterValue::String(s.to_string())
201    }
202}
203
/// Combined constraints on a search result: its distance, its metadata and its
/// vector components.
///
/// Every constraint that is set must hold; unset (`None`) constraints are ignored.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchFilter {
    /// Inclusive upper bound on the result distance (a distance equal to it passes).
    pub max_distance: Option<f32>,
    /// Inclusive lower bound on the result distance (a distance equal to it passes).
    pub min_distance: Option<f32>,
    /// Predicate the result's metadata must satisfy.
    pub metadata_filter: Option<MetadataFilter>,
    /// Per-dimension ranges; the result vector must satisfy every entry.
    pub dimension_constraints: Option<Vec<DimensionConstraint>>,
}
216
/// Range constraint on a single component of a vector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DimensionConstraint {
    /// Zero-based component index. A vector too short to have this index never
    /// satisfies the constraint.
    pub dimension: usize,
    /// Inclusive lower bound on the component, if any.
    pub min_value: Option<f32>,
    /// Inclusive upper bound on the component, if any.
    pub max_value: Option<f32>,
}
227
228impl DimensionConstraint {
229    /// Check if a vector satisfies this dimension constraint
230    pub fn satisfies(&self, vector: &Vector) -> bool {
231        let values = vector.as_f32();
232
233        if self.dimension >= values.len() {
234            return false;
235        }
236
237        let value = values[self.dimension];
238
239        if let Some(min) = self.min_value {
240            if value < min {
241                return false;
242            }
243        }
244
245        if let Some(max) = self.max_value {
246            if value > max {
247                return false;
248            }
249        }
250
251        true
252    }
253}
254
255impl SearchFilter {
256    /// Create a new empty search filter
257    pub fn new() -> Self {
258        Self {
259            max_distance: None,
260            min_distance: None,
261            metadata_filter: None,
262            dimension_constraints: None,
263        }
264    }
265
266    /// Set maximum distance threshold
267    pub fn with_max_distance(mut self, max_distance: f32) -> Self {
268        self.max_distance = Some(max_distance);
269        self
270    }
271
272    /// Set minimum distance threshold
273    pub fn with_min_distance(mut self, min_distance: f32) -> Self {
274        self.min_distance = Some(min_distance);
275        self
276    }
277
278    /// Set metadata filter
279    pub fn with_metadata_filter(mut self, filter: MetadataFilter) -> Self {
280        self.metadata_filter = Some(filter);
281        self
282    }
283
284    /// Set dimension constraints
285    pub fn with_dimension_constraints(mut self, constraints: Vec<DimensionConstraint>) -> Self {
286        self.dimension_constraints = Some(constraints);
287        self
288    }
289
290    /// Check if a search result satisfies this filter
291    pub fn satisfies(
292        &self,
293        distance: f32,
294        vector: &Vector,
295        metadata: &HashMap<String, String>,
296    ) -> bool {
297        // Check distance constraints
298        if let Some(max) = self.max_distance {
299            if distance > max {
300                return false;
301            }
302        }
303
304        if let Some(min) = self.min_distance {
305            if distance < min {
306                return false;
307            }
308        }
309
310        // Check metadata filter
311        if let Some(ref filter) = self.metadata_filter {
312            if !filter.evaluate(metadata) {
313                return false;
314            }
315        }
316
317        // Check dimension constraints
318        if let Some(ref constraints) = self.dimension_constraints {
319            for constraint in constraints {
320                if !constraint.satisfies(vector) {
321                    return false;
322                }
323            }
324        }
325
326        true
327    }
328
329    /// Filter a list of search results
330    pub fn filter_results(
331        &self,
332        results: Vec<(VectorId, f32, Vector, HashMap<String, String>)>,
333    ) -> Vec<(VectorId, f32)> {
334        results
335            .into_iter()
336            .filter(|(_, distance, vector, metadata)| self.satisfies(*distance, vector, metadata))
337            .map(|(id, distance, _, _)| (id, distance))
338            .collect()
339    }
340}
341
342impl Default for SearchFilter {
343    fn default() -> Self {
344        Self::new()
345    }
346}
347
348/// Builder for complex filter expressions
349pub struct FilterBuilder {
350    filters: Vec<MetadataFilter>,
351}
352
353impl FilterBuilder {
354    pub fn new() -> Self {
355        Self {
356            filters: Vec::new(),
357        }
358    }
359
360    pub fn equals(mut self, field: impl Into<String>, value: FilterValue) -> Self {
361        self.filters.push(MetadataFilter::Equals {
362            field: field.into(),
363            value,
364        });
365        self
366    }
367
368    pub fn not_equals(mut self, field: impl Into<String>, value: FilterValue) -> Self {
369        self.filters.push(MetadataFilter::NotEquals {
370            field: field.into(),
371            value,
372        });
373        self
374    }
375
376    pub fn greater_than(mut self, field: impl Into<String>, value: FilterValue) -> Self {
377        self.filters.push(MetadataFilter::GreaterThan {
378            field: field.into(),
379            value,
380        });
381        self
382    }
383
384    pub fn less_than(mut self, field: impl Into<String>, value: FilterValue) -> Self {
385        self.filters.push(MetadataFilter::LessThan {
386            field: field.into(),
387            value,
388        });
389        self
390    }
391
392    pub fn contains(mut self, field: impl Into<String>, substring: impl Into<String>) -> Self {
393        self.filters.push(MetadataFilter::Contains {
394            field: field.into(),
395            substring: substring.into(),
396        });
397        self
398    }
399
400    pub fn regex(mut self, field: impl Into<String>, pattern: impl Into<String>) -> Self {
401        self.filters.push(MetadataFilter::Regex {
402            field: field.into(),
403            pattern: pattern.into(),
404        });
405        self
406    }
407
408    pub fn exists(mut self, field: impl Into<String>) -> Self {
409        self.filters.push(MetadataFilter::Exists {
410            field: field.into(),
411        });
412        self
413    }
414
415    pub fn build_and(self) -> MetadataFilter {
416        if self.filters.len() == 1 {
417            self.filters
418                .into_iter()
419                .next()
420                .expect("filters validated to have exactly one element")
421        } else {
422            MetadataFilter::And(self.filters)
423        }
424    }
425
426    pub fn build_or(self) -> MetadataFilter {
427        if self.filters.len() == 1 {
428            self.filters
429                .into_iter()
430                .next()
431                .expect("filters validated to have exactly one element")
432        } else {
433            MetadataFilter::Or(self.filters)
434        }
435    }
436}
437
438impl Default for FilterBuilder {
439    fn default() -> Self {
440        Self::new()
441    }
442}
443
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a metadata map from `(key, value)` pairs.
    fn metadata_of(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(key, text)| (key.to_string(), text.to_string()))
            .collect()
    }

    #[test]
    fn test_equals_filter() {
        let news_only = MetadataFilter::Equals {
            field: "category".to_string(),
            value: FilterValue::String("news".to_string()),
        };

        assert!(news_only.evaluate(&metadata_of(&[("category", "news")])));
        assert!(!news_only.evaluate(&metadata_of(&[("category", "sports")])));
    }

    #[test]
    fn test_greater_than_filter() {
        let above_fifty = MetadataFilter::GreaterThan {
            field: "score".to_string(),
            value: FilterValue::Integer(50),
        };

        assert!(above_fifty.evaluate(&metadata_of(&[("score", "75")])));
        assert!(!above_fifty.evaluate(&metadata_of(&[("score", "25")])));
    }

    #[test]
    fn test_and_filter() {
        let active_and_urgent = MetadataFilter::And(vec![
            MetadataFilter::Equals {
                field: "status".to_string(),
                value: FilterValue::String("active".to_string()),
            },
            MetadataFilter::GreaterThan {
                field: "priority".to_string(),
                value: FilterValue::Integer(5),
            },
        ]);

        let high = metadata_of(&[("status", "active"), ("priority", "8")]);
        assert!(active_and_urgent.evaluate(&high));

        let low = metadata_of(&[("status", "active"), ("priority", "3")]);
        assert!(!active_and_urgent.evaluate(&low));
    }

    #[test]
    fn test_or_filter() {
        let urgent_or_critical = MetadataFilter::Or(vec![
            MetadataFilter::Equals {
                field: "type".to_string(),
                value: FilterValue::String("urgent".to_string()),
            },
            MetadataFilter::Equals {
                field: "type".to_string(),
                value: FilterValue::String("critical".to_string()),
            },
        ]);

        for accepted in ["urgent", "critical"].iter() {
            assert!(urgent_or_critical.evaluate(&metadata_of(&[("type", accepted)])));
        }
        assert!(!urgent_or_critical.evaluate(&metadata_of(&[("type", "normal")])));
    }

    #[test]
    fn test_contains_filter() {
        let mentions_important = MetadataFilter::Contains {
            field: "description".to_string(),
            substring: "important".to_string(),
        };

        let flagged = metadata_of(&[("description", "This is an important message")]);
        assert!(mentions_important.evaluate(&flagged));

        let plain = metadata_of(&[("description", "Regular message")]);
        assert!(!mentions_important.evaluate(&plain));
    }

    #[test]
    fn test_filter_builder() {
        let tech_high_score = FilterBuilder::new()
            .equals("category", FilterValue::String("tech".to_string()))
            .greater_than("score", FilterValue::Integer(70))
            .build_and();

        let candidate = metadata_of(&[("category", "tech"), ("score", "85")]);
        assert!(tech_high_score.evaluate(&candidate));
    }

    #[test]
    fn test_dimension_constraint() {
        let unit_range = DimensionConstraint {
            dimension: 0,
            min_value: Some(0.0),
            max_value: Some(1.0),
        };

        assert!(unit_range.satisfies(&Vector::new(vec![0.5, 0.3, 0.7])));
        assert!(!unit_range.satisfies(&Vector::new(vec![1.5, 0.3, 0.7])));
    }

    #[test]
    fn test_search_filter() {
        let close_and_approved = SearchFilter::new()
            .with_max_distance(0.5)
            .with_metadata_filter(MetadataFilter::Equals {
                field: "category".to_string(),
                value: FilterValue::String("approved".to_string()),
            });

        let point = Vector::new(vec![1.0, 2.0, 3.0]);
        let approved = metadata_of(&[("category", "approved")]);
        let pending = metadata_of(&[("category", "pending")]);

        assert!(close_and_approved.satisfies(0.3, &point, &approved));
        // Distance 0.7 exceeds the 0.5 ceiling.
        assert!(!close_and_approved.satisfies(0.7, &point, &approved));
        // Close enough, but the category does not match.
        assert!(!close_and_approved.satisfies(0.3, &point, &pending));
    }
}