oxirs_vec/
filtered_search.rs

1//! Filtered search capabilities for vector indices
2//!
3//! This module provides advanced filtering capabilities for vector search,
4//! allowing searches to be constrained by metadata predicates, value ranges,
5//! and complex logical conditions.
6
7use crate::{Vector, VectorId};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11/// Metadata filter for search operations
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub enum MetadataFilter {
14    /// Exact match on a metadata field
15    Equals { field: String, value: FilterValue },
16    /// Field value is not equal to the given value
17    NotEquals { field: String, value: FilterValue },
18    /// Field value is greater than the given value
19    GreaterThan { field: String, value: FilterValue },
20    /// Field value is greater than or equal to the given value
21    GreaterThanOrEqual { field: String, value: FilterValue },
22    /// Field value is less than the given value
23    LessThan { field: String, value: FilterValue },
24    /// Field value is less than or equal to the given value
25    LessThanOrEqual { field: String, value: FilterValue },
26    /// Field value is in the given set
27    In {
28        field: String,
29        values: Vec<FilterValue>,
30    },
31    /// Field value is not in the given set
32    NotIn {
33        field: String,
34        values: Vec<FilterValue>,
35    },
36    /// Field value contains the given substring
37    Contains { field: String, substring: String },
38    /// Field value matches the given regex pattern
39    Regex { field: String, pattern: String },
40    /// Field exists (has any value)
41    Exists { field: String },
42    /// Field does not exist or is null
43    NotExists { field: String },
44    /// Logical AND of multiple filters
45    And(Vec<MetadataFilter>),
46    /// Logical OR of multiple filters
47    Or(Vec<MetadataFilter>),
48    /// Logical NOT of a filter
49    Not(Box<MetadataFilter>),
50}
51
52/// Value type for filter predicates
53#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
54pub enum FilterValue {
55    String(String),
56    Integer(i64),
57    Float(f64),
58    Boolean(bool),
59    Null,
60}
61
62impl FilterValue {
63    /// Compare two filter values
64    fn compare(&self, other: &FilterValue) -> std::cmp::Ordering {
65        match (self, other) {
66            (FilterValue::String(a), FilterValue::String(b)) => a.cmp(b),
67            (FilterValue::Integer(a), FilterValue::Integer(b)) => a.cmp(b),
68            (FilterValue::Float(a), FilterValue::Float(b)) => {
69                a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
70            }
71            (FilterValue::Boolean(a), FilterValue::Boolean(b)) => a.cmp(b),
72            _ => std::cmp::Ordering::Equal,
73        }
74    }
75}
76
77impl MetadataFilter {
78    /// Evaluate the filter against a metadata map
79    pub fn evaluate(&self, metadata: &HashMap<String, String>) -> bool {
80        match self {
81            MetadataFilter::Equals { field, value } => {
82                if let Some(field_value) = metadata.get(field) {
83                    let parsed_value = Self::parse_value(field_value);
84                    &parsed_value == value
85                } else {
86                    false
87                }
88            }
89            MetadataFilter::NotEquals { field, value } => {
90                if let Some(field_value) = metadata.get(field) {
91                    let parsed_value = Self::parse_value(field_value);
92                    &parsed_value != value
93                } else {
94                    true
95                }
96            }
97            MetadataFilter::GreaterThan { field, value } => {
98                if let Some(field_value) = metadata.get(field) {
99                    let parsed_value = Self::parse_value(field_value);
100                    parsed_value.compare(value) == std::cmp::Ordering::Greater
101                } else {
102                    false
103                }
104            }
105            MetadataFilter::GreaterThanOrEqual { field, value } => {
106                if let Some(field_value) = metadata.get(field) {
107                    let parsed_value = Self::parse_value(field_value);
108                    matches!(
109                        parsed_value.compare(value),
110                        std::cmp::Ordering::Greater | std::cmp::Ordering::Equal
111                    )
112                } else {
113                    false
114                }
115            }
116            MetadataFilter::LessThan { field, value } => {
117                if let Some(field_value) = metadata.get(field) {
118                    let parsed_value = Self::parse_value(field_value);
119                    parsed_value.compare(value) == std::cmp::Ordering::Less
120                } else {
121                    false
122                }
123            }
124            MetadataFilter::LessThanOrEqual { field, value } => {
125                if let Some(field_value) = metadata.get(field) {
126                    let parsed_value = Self::parse_value(field_value);
127                    matches!(
128                        parsed_value.compare(value),
129                        std::cmp::Ordering::Less | std::cmp::Ordering::Equal
130                    )
131                } else {
132                    false
133                }
134            }
135            MetadataFilter::In { field, values } => {
136                if let Some(field_value) = metadata.get(field) {
137                    let parsed_value = Self::parse_value(field_value);
138                    values.contains(&parsed_value)
139                } else {
140                    false
141                }
142            }
143            MetadataFilter::NotIn { field, values } => {
144                if let Some(field_value) = metadata.get(field) {
145                    let parsed_value = Self::parse_value(field_value);
146                    !values.contains(&parsed_value)
147                } else {
148                    true
149                }
150            }
151            MetadataFilter::Contains { field, substring } => {
152                if let Some(field_value) = metadata.get(field) {
153                    field_value.contains(substring)
154                } else {
155                    false
156                }
157            }
158            MetadataFilter::Regex { field, pattern } => {
159                if let Some(field_value) = metadata.get(field) {
160                    if let Ok(regex) = regex::Regex::new(pattern) {
161                        regex.is_match(field_value)
162                    } else {
163                        false
164                    }
165                } else {
166                    false
167                }
168            }
169            MetadataFilter::Exists { field } => metadata.contains_key(field),
170            MetadataFilter::NotExists { field } => !metadata.contains_key(field),
171            MetadataFilter::And(filters) => filters.iter().all(|f| f.evaluate(metadata)),
172            MetadataFilter::Or(filters) => filters.iter().any(|f| f.evaluate(metadata)),
173            MetadataFilter::Not(filter) => !filter.evaluate(metadata),
174        }
175    }
176
177    /// Parse a string value into a FilterValue
178    fn parse_value(s: &str) -> FilterValue {
179        // Try to parse as integer
180        if let Ok(i) = s.parse::<i64>() {
181            return FilterValue::Integer(i);
182        }
183
184        // Try to parse as float
185        if let Ok(f) = s.parse::<f64>() {
186            return FilterValue::Float(f);
187        }
188
189        // Try to parse as boolean
190        if let Ok(b) = s.parse::<bool>() {
191            return FilterValue::Boolean(b);
192        }
193
194        // Check for null
195        if s == "null" || s.is_empty() {
196            return FilterValue::Null;
197        }
198
199        // Default to string
200        FilterValue::String(s.to_string())
201    }
202}
203
204/// Search filter combining distance and metadata constraints
205#[derive(Debug, Clone, Serialize, Deserialize)]
206pub struct SearchFilter {
207    /// Maximum distance threshold
208    pub max_distance: Option<f32>,
209    /// Minimum distance threshold
210    pub min_distance: Option<f32>,
211    /// Metadata filter predicates
212    pub metadata_filter: Option<MetadataFilter>,
213    /// Vector dimension constraints
214    pub dimension_constraints: Option<Vec<DimensionConstraint>>,
215}
216
217/// Constraint on specific vector dimensions
218#[derive(Debug, Clone, Serialize, Deserialize)]
219pub struct DimensionConstraint {
220    /// Dimension index
221    pub dimension: usize,
222    /// Minimum value for this dimension
223    pub min_value: Option<f32>,
224    /// Maximum value for this dimension
225    pub max_value: Option<f32>,
226}
227
228impl DimensionConstraint {
229    /// Check if a vector satisfies this dimension constraint
230    pub fn satisfies(&self, vector: &Vector) -> bool {
231        let values = vector.as_f32();
232
233        if self.dimension >= values.len() {
234            return false;
235        }
236
237        let value = values[self.dimension];
238
239        if let Some(min) = self.min_value {
240            if value < min {
241                return false;
242            }
243        }
244
245        if let Some(max) = self.max_value {
246            if value > max {
247                return false;
248            }
249        }
250
251        true
252    }
253}
254
255impl SearchFilter {
256    /// Create a new empty search filter
257    pub fn new() -> Self {
258        Self {
259            max_distance: None,
260            min_distance: None,
261            metadata_filter: None,
262            dimension_constraints: None,
263        }
264    }
265
266    /// Set maximum distance threshold
267    pub fn with_max_distance(mut self, max_distance: f32) -> Self {
268        self.max_distance = Some(max_distance);
269        self
270    }
271
272    /// Set minimum distance threshold
273    pub fn with_min_distance(mut self, min_distance: f32) -> Self {
274        self.min_distance = Some(min_distance);
275        self
276    }
277
278    /// Set metadata filter
279    pub fn with_metadata_filter(mut self, filter: MetadataFilter) -> Self {
280        self.metadata_filter = Some(filter);
281        self
282    }
283
284    /// Set dimension constraints
285    pub fn with_dimension_constraints(mut self, constraints: Vec<DimensionConstraint>) -> Self {
286        self.dimension_constraints = Some(constraints);
287        self
288    }
289
290    /// Check if a search result satisfies this filter
291    pub fn satisfies(
292        &self,
293        distance: f32,
294        vector: &Vector,
295        metadata: &HashMap<String, String>,
296    ) -> bool {
297        // Check distance constraints
298        if let Some(max) = self.max_distance {
299            if distance > max {
300                return false;
301            }
302        }
303
304        if let Some(min) = self.min_distance {
305            if distance < min {
306                return false;
307            }
308        }
309
310        // Check metadata filter
311        if let Some(ref filter) = self.metadata_filter {
312            if !filter.evaluate(metadata) {
313                return false;
314            }
315        }
316
317        // Check dimension constraints
318        if let Some(ref constraints) = self.dimension_constraints {
319            for constraint in constraints {
320                if !constraint.satisfies(vector) {
321                    return false;
322                }
323            }
324        }
325
326        true
327    }
328
329    /// Filter a list of search results
330    pub fn filter_results(
331        &self,
332        results: Vec<(VectorId, f32, Vector, HashMap<String, String>)>,
333    ) -> Vec<(VectorId, f32)> {
334        results
335            .into_iter()
336            .filter(|(_, distance, vector, metadata)| self.satisfies(*distance, vector, metadata))
337            .map(|(id, distance, _, _)| (id, distance))
338            .collect()
339    }
340}
341
342impl Default for SearchFilter {
343    fn default() -> Self {
344        Self::new()
345    }
346}
347
348/// Builder for complex filter expressions
349pub struct FilterBuilder {
350    filters: Vec<MetadataFilter>,
351}
352
353impl FilterBuilder {
354    pub fn new() -> Self {
355        Self {
356            filters: Vec::new(),
357        }
358    }
359
360    pub fn equals(mut self, field: impl Into<String>, value: FilterValue) -> Self {
361        self.filters.push(MetadataFilter::Equals {
362            field: field.into(),
363            value,
364        });
365        self
366    }
367
368    pub fn not_equals(mut self, field: impl Into<String>, value: FilterValue) -> Self {
369        self.filters.push(MetadataFilter::NotEquals {
370            field: field.into(),
371            value,
372        });
373        self
374    }
375
376    pub fn greater_than(mut self, field: impl Into<String>, value: FilterValue) -> Self {
377        self.filters.push(MetadataFilter::GreaterThan {
378            field: field.into(),
379            value,
380        });
381        self
382    }
383
384    pub fn less_than(mut self, field: impl Into<String>, value: FilterValue) -> Self {
385        self.filters.push(MetadataFilter::LessThan {
386            field: field.into(),
387            value,
388        });
389        self
390    }
391
392    pub fn contains(mut self, field: impl Into<String>, substring: impl Into<String>) -> Self {
393        self.filters.push(MetadataFilter::Contains {
394            field: field.into(),
395            substring: substring.into(),
396        });
397        self
398    }
399
400    pub fn regex(mut self, field: impl Into<String>, pattern: impl Into<String>) -> Self {
401        self.filters.push(MetadataFilter::Regex {
402            field: field.into(),
403            pattern: pattern.into(),
404        });
405        self
406    }
407
408    pub fn exists(mut self, field: impl Into<String>) -> Self {
409        self.filters.push(MetadataFilter::Exists {
410            field: field.into(),
411        });
412        self
413    }
414
415    pub fn build_and(self) -> MetadataFilter {
416        if self.filters.len() == 1 {
417            self.filters.into_iter().next().unwrap()
418        } else {
419            MetadataFilter::And(self.filters)
420        }
421    }
422
423    pub fn build_or(self) -> MetadataFilter {
424        if self.filters.len() == 1 {
425            self.filters.into_iter().next().unwrap()
426        } else {
427            MetadataFilter::Or(self.filters)
428        }
429    }
430}
431
432impl Default for FilterBuilder {
433    fn default() -> Self {
434        Self::new()
435    }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a metadata map from `(key, value)` string pairs.
    fn metadata_of(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_owned(), value.to_owned()))
            .collect()
    }

    /// Shorthand for an `Equals` filter with a string operand.
    fn equals_text(field: &str, expected: &str) -> MetadataFilter {
        MetadataFilter::Equals {
            field: field.to_owned(),
            value: FilterValue::String(expected.to_owned()),
        }
    }

    /// Shorthand for a `GreaterThan` filter with an integer operand.
    fn greater_than_int(field: &str, threshold: i64) -> MetadataFilter {
        MetadataFilter::GreaterThan {
            field: field.to_owned(),
            value: FilterValue::Integer(threshold),
        }
    }

    #[test]
    fn test_equals_filter() {
        let filter = equals_text("category", "news");

        assert!(filter.evaluate(&metadata_of(&[("category", "news")])));
        assert!(!filter.evaluate(&metadata_of(&[("category", "sports")])));
    }

    #[test]
    fn test_greater_than_filter() {
        let filter = greater_than_int("score", 50);

        assert!(filter.evaluate(&metadata_of(&[("score", "75")])));
        assert!(!filter.evaluate(&metadata_of(&[("score", "25")])));
    }

    #[test]
    fn test_and_filter() {
        let filter = MetadataFilter::And(vec![
            equals_text("status", "active"),
            greater_than_int("priority", 5),
        ]);

        let high_priority = metadata_of(&[("status", "active"), ("priority", "8")]);
        assert!(filter.evaluate(&high_priority));

        let low_priority = metadata_of(&[("status", "active"), ("priority", "3")]);
        assert!(!filter.evaluate(&low_priority));
    }

    #[test]
    fn test_or_filter() {
        let filter = MetadataFilter::Or(vec![
            equals_text("type", "urgent"),
            equals_text("type", "critical"),
        ]);

        assert!(filter.evaluate(&metadata_of(&[("type", "urgent")])));
        assert!(filter.evaluate(&metadata_of(&[("type", "critical")])));
        assert!(!filter.evaluate(&metadata_of(&[("type", "normal")])));
    }

    #[test]
    fn test_contains_filter() {
        let filter = MetadataFilter::Contains {
            field: "description".to_owned(),
            substring: "important".to_owned(),
        };

        let matching = metadata_of(&[("description", "This is an important message")]);
        assert!(filter.evaluate(&matching));

        let other = metadata_of(&[("description", "Regular message")]);
        assert!(!filter.evaluate(&other));
    }

    #[test]
    fn test_filter_builder() {
        let filter = FilterBuilder::new()
            .equals("category", FilterValue::String("tech".to_owned()))
            .greater_than("score", FilterValue::Integer(70))
            .build_and();

        let metadata = metadata_of(&[("category", "tech"), ("score", "85")]);
        assert!(filter.evaluate(&metadata));
    }

    #[test]
    fn test_dimension_constraint() {
        let constraint = DimensionConstraint {
            dimension: 0,
            min_value: Some(0.0),
            max_value: Some(1.0),
        };

        let inside = Vector::new(vec![0.5, 0.3, 0.7]);
        assert!(constraint.satisfies(&inside));

        let outside = Vector::new(vec![1.5, 0.3, 0.7]);
        assert!(!constraint.satisfies(&outside));
    }

    #[test]
    fn test_search_filter() {
        let filter = SearchFilter::new()
            .with_max_distance(0.5)
            .with_metadata_filter(equals_text("category", "approved"));

        let vector = Vector::new(vec![1.0, 2.0, 3.0]);
        let approved = metadata_of(&[("category", "approved")]);

        assert!(filter.satisfies(0.3, &vector, &approved));
        // Distance above the configured maximum.
        assert!(!filter.satisfies(0.7, &vector, &approved));

        // Distance is fine but the metadata predicate fails.
        let pending = metadata_of(&[("category", "pending")]);
        assert!(!filter.satisfies(0.3, &vector, &pending));
    }
}