1use crate::{Vector, VectorId};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
/// A predicate over a metadata map (`HashMap<String, String>`).
///
/// Leaf variants test a single metadata field; `And`, `Or` and `Not` combine
/// other filters. Filters are evaluated with `MetadataFilter::evaluate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MetadataFilter {
    /// Matches when `field` is present and its value equals `value`.
    /// A missing field does not match.
    Equals { field: String, value: FilterValue },
    /// Matches when `field` is absent or its value differs from `value`.
    NotEquals { field: String, value: FilterValue },
    /// Matches when `field` is present and its value orders strictly after `value`.
    GreaterThan { field: String, value: FilterValue },
    /// Matches when `field` is present and its value orders after or equal to `value`.
    GreaterThanOrEqual { field: String, value: FilterValue },
    /// Matches when `field` is present and its value orders strictly before `value`.
    LessThan { field: String, value: FilterValue },
    /// Matches when `field` is present and its value orders before or equal to `value`.
    LessThanOrEqual { field: String, value: FilterValue },
    /// Matches when `field` is present and its value equals at least one of `values`.
    In {
        field: String,
        values: Vec<FilterValue>,
    },
    /// Matches when `field` is absent or its value equals none of `values`.
    NotIn {
        field: String,
        values: Vec<FilterValue>,
    },
    /// Matches when `field` is present and its raw text contains `substring`.
    Contains { field: String, substring: String },
    /// Matches when `field` is present and its raw text matches the regular
    /// expression `pattern`. A pattern that fails to compile never matches.
    Regex { field: String, pattern: String },
    /// Matches when the metadata map contains `field`.
    Exists { field: String },
    /// Matches when the metadata map does not contain `field`.
    NotExists { field: String },
    /// Matches when every inner filter matches (an empty list matches).
    And(Vec<MetadataFilter>),
    /// Matches when at least one inner filter matches (an empty list never matches).
    Or(Vec<MetadataFilter>),
    /// Matches when the inner filter does not match.
    Not(Box<MetadataFilter>),
}
51
/// A typed value that a metadata field is compared against.
///
/// Metadata itself is stored as text; the variants here describe how a filter
/// wants that text to be interpreted.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum FilterValue {
    /// Arbitrary text.
    String(String),
    /// A signed 64-bit integer.
    Integer(i64),
    /// A 64-bit floating point number.
    Float(f64),
    /// A boolean.
    Boolean(bool),
    /// The absence of a value.
    Null,
}
61
62impl FilterValue {
63 fn compare(&self, other: &FilterValue) -> std::cmp::Ordering {
65 match (self, other) {
66 (FilterValue::String(a), FilterValue::String(b)) => a.cmp(b),
67 (FilterValue::Integer(a), FilterValue::Integer(b)) => a.cmp(b),
68 (FilterValue::Float(a), FilterValue::Float(b)) => {
69 a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
70 }
71 (FilterValue::Boolean(a), FilterValue::Boolean(b)) => a.cmp(b),
72 _ => std::cmp::Ordering::Equal,
73 }
74 }
75}
76
77impl MetadataFilter {
78 pub fn evaluate(&self, metadata: &HashMap<String, String>) -> bool {
80 match self {
81 MetadataFilter::Equals { field, value } => {
82 if let Some(field_value) = metadata.get(field) {
83 let parsed_value = Self::parse_value(field_value);
84 &parsed_value == value
85 } else {
86 false
87 }
88 }
89 MetadataFilter::NotEquals { field, value } => {
90 if let Some(field_value) = metadata.get(field) {
91 let parsed_value = Self::parse_value(field_value);
92 &parsed_value != value
93 } else {
94 true
95 }
96 }
97 MetadataFilter::GreaterThan { field, value } => {
98 if let Some(field_value) = metadata.get(field) {
99 let parsed_value = Self::parse_value(field_value);
100 parsed_value.compare(value) == std::cmp::Ordering::Greater
101 } else {
102 false
103 }
104 }
105 MetadataFilter::GreaterThanOrEqual { field, value } => {
106 if let Some(field_value) = metadata.get(field) {
107 let parsed_value = Self::parse_value(field_value);
108 matches!(
109 parsed_value.compare(value),
110 std::cmp::Ordering::Greater | std::cmp::Ordering::Equal
111 )
112 } else {
113 false
114 }
115 }
116 MetadataFilter::LessThan { field, value } => {
117 if let Some(field_value) = metadata.get(field) {
118 let parsed_value = Self::parse_value(field_value);
119 parsed_value.compare(value) == std::cmp::Ordering::Less
120 } else {
121 false
122 }
123 }
124 MetadataFilter::LessThanOrEqual { field, value } => {
125 if let Some(field_value) = metadata.get(field) {
126 let parsed_value = Self::parse_value(field_value);
127 matches!(
128 parsed_value.compare(value),
129 std::cmp::Ordering::Less | std::cmp::Ordering::Equal
130 )
131 } else {
132 false
133 }
134 }
135 MetadataFilter::In { field, values } => {
136 if let Some(field_value) = metadata.get(field) {
137 let parsed_value = Self::parse_value(field_value);
138 values.contains(&parsed_value)
139 } else {
140 false
141 }
142 }
143 MetadataFilter::NotIn { field, values } => {
144 if let Some(field_value) = metadata.get(field) {
145 let parsed_value = Self::parse_value(field_value);
146 !values.contains(&parsed_value)
147 } else {
148 true
149 }
150 }
151 MetadataFilter::Contains { field, substring } => {
152 if let Some(field_value) = metadata.get(field) {
153 field_value.contains(substring)
154 } else {
155 false
156 }
157 }
158 MetadataFilter::Regex { field, pattern } => {
159 if let Some(field_value) = metadata.get(field) {
160 if let Ok(regex) = regex::Regex::new(pattern) {
161 regex.is_match(field_value)
162 } else {
163 false
164 }
165 } else {
166 false
167 }
168 }
169 MetadataFilter::Exists { field } => metadata.contains_key(field),
170 MetadataFilter::NotExists { field } => !metadata.contains_key(field),
171 MetadataFilter::And(filters) => filters.iter().all(|f| f.evaluate(metadata)),
172 MetadataFilter::Or(filters) => filters.iter().any(|f| f.evaluate(metadata)),
173 MetadataFilter::Not(filter) => !filter.evaluate(metadata),
174 }
175 }
176
177 fn parse_value(s: &str) -> FilterValue {
179 if let Ok(i) = s.parse::<i64>() {
181 return FilterValue::Integer(i);
182 }
183
184 if let Ok(f) = s.parse::<f64>() {
186 return FilterValue::Float(f);
187 }
188
189 if let Ok(b) = s.parse::<bool>() {
191 return FilterValue::Boolean(b);
192 }
193
194 if s == "null" || s.is_empty() {
196 return FilterValue::Null;
197 }
198
199 FilterValue::String(s.to_string())
201 }
202}
203
/// Criteria applied to a search result: distance bounds, a metadata
/// predicate and per-dimension value constraints.
///
/// Every criterion that is `Some` must hold; a `None` criterion is not checked.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchFilter {
    /// Results whose distance is greater than this value are rejected.
    pub max_distance: Option<f32>,
    /// Results whose distance is less than this value are rejected.
    pub min_distance: Option<f32>,
    /// Predicate the result's metadata must satisfy.
    pub metadata_filter: Option<MetadataFilter>,
    /// Constraints the result's vector must satisfy; all of them must hold.
    pub dimension_constraints: Option<Vec<DimensionConstraint>>,
}
216
/// A range restriction on a single component of a vector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DimensionConstraint {
    /// Zero-based index of the component being constrained.
    pub dimension: usize,
    /// Inclusive lower bound; components below it fail the constraint.
    pub min_value: Option<f32>,
    /// Inclusive upper bound; components above it fail the constraint.
    pub max_value: Option<f32>,
}
227
228impl DimensionConstraint {
229 pub fn satisfies(&self, vector: &Vector) -> bool {
231 let values = vector.as_f32();
232
233 if self.dimension >= values.len() {
234 return false;
235 }
236
237 let value = values[self.dimension];
238
239 if let Some(min) = self.min_value {
240 if value < min {
241 return false;
242 }
243 }
244
245 if let Some(max) = self.max_value {
246 if value > max {
247 return false;
248 }
249 }
250
251 true
252 }
253}
254
255impl SearchFilter {
256 pub fn new() -> Self {
258 Self {
259 max_distance: None,
260 min_distance: None,
261 metadata_filter: None,
262 dimension_constraints: None,
263 }
264 }
265
266 pub fn with_max_distance(mut self, max_distance: f32) -> Self {
268 self.max_distance = Some(max_distance);
269 self
270 }
271
272 pub fn with_min_distance(mut self, min_distance: f32) -> Self {
274 self.min_distance = Some(min_distance);
275 self
276 }
277
278 pub fn with_metadata_filter(mut self, filter: MetadataFilter) -> Self {
280 self.metadata_filter = Some(filter);
281 self
282 }
283
284 pub fn with_dimension_constraints(mut self, constraints: Vec<DimensionConstraint>) -> Self {
286 self.dimension_constraints = Some(constraints);
287 self
288 }
289
290 pub fn satisfies(
292 &self,
293 distance: f32,
294 vector: &Vector,
295 metadata: &HashMap<String, String>,
296 ) -> bool {
297 if let Some(max) = self.max_distance {
299 if distance > max {
300 return false;
301 }
302 }
303
304 if let Some(min) = self.min_distance {
305 if distance < min {
306 return false;
307 }
308 }
309
310 if let Some(ref filter) = self.metadata_filter {
312 if !filter.evaluate(metadata) {
313 return false;
314 }
315 }
316
317 if let Some(ref constraints) = self.dimension_constraints {
319 for constraint in constraints {
320 if !constraint.satisfies(vector) {
321 return false;
322 }
323 }
324 }
325
326 true
327 }
328
329 pub fn filter_results(
331 &self,
332 results: Vec<(VectorId, f32, Vector, HashMap<String, String>)>,
333 ) -> Vec<(VectorId, f32)> {
334 results
335 .into_iter()
336 .filter(|(_, distance, vector, metadata)| self.satisfies(*distance, vector, metadata))
337 .map(|(id, distance, _, _)| (id, distance))
338 .collect()
339 }
340}
341
342impl Default for SearchFilter {
343 fn default() -> Self {
344 Self::new()
345 }
346}
347
348pub struct FilterBuilder {
350 filters: Vec<MetadataFilter>,
351}
352
353impl FilterBuilder {
354 pub fn new() -> Self {
355 Self {
356 filters: Vec::new(),
357 }
358 }
359
360 pub fn equals(mut self, field: impl Into<String>, value: FilterValue) -> Self {
361 self.filters.push(MetadataFilter::Equals {
362 field: field.into(),
363 value,
364 });
365 self
366 }
367
368 pub fn not_equals(mut self, field: impl Into<String>, value: FilterValue) -> Self {
369 self.filters.push(MetadataFilter::NotEquals {
370 field: field.into(),
371 value,
372 });
373 self
374 }
375
376 pub fn greater_than(mut self, field: impl Into<String>, value: FilterValue) -> Self {
377 self.filters.push(MetadataFilter::GreaterThan {
378 field: field.into(),
379 value,
380 });
381 self
382 }
383
384 pub fn less_than(mut self, field: impl Into<String>, value: FilterValue) -> Self {
385 self.filters.push(MetadataFilter::LessThan {
386 field: field.into(),
387 value,
388 });
389 self
390 }
391
392 pub fn contains(mut self, field: impl Into<String>, substring: impl Into<String>) -> Self {
393 self.filters.push(MetadataFilter::Contains {
394 field: field.into(),
395 substring: substring.into(),
396 });
397 self
398 }
399
400 pub fn regex(mut self, field: impl Into<String>, pattern: impl Into<String>) -> Self {
401 self.filters.push(MetadataFilter::Regex {
402 field: field.into(),
403 pattern: pattern.into(),
404 });
405 self
406 }
407
408 pub fn exists(mut self, field: impl Into<String>) -> Self {
409 self.filters.push(MetadataFilter::Exists {
410 field: field.into(),
411 });
412 self
413 }
414
415 pub fn build_and(self) -> MetadataFilter {
416 if self.filters.len() == 1 {
417 self.filters
418 .into_iter()
419 .next()
420 .expect("filters validated to have exactly one element")
421 } else {
422 MetadataFilter::And(self.filters)
423 }
424 }
425
426 pub fn build_or(self) -> MetadataFilter {
427 if self.filters.len() == 1 {
428 self.filters
429 .into_iter()
430 .next()
431 .expect("filters validated to have exactly one element")
432 } else {
433 MetadataFilter::Or(self.filters)
434 }
435 }
436}
437
438impl Default for FilterBuilder {
439 fn default() -> Self {
440 Self::new()
441 }
442}
443
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a metadata map from `(key, value)` string pairs.
    fn metadata_of(entries: &[(&str, &str)]) -> HashMap<String, String> {
        entries
            .iter()
            .map(|&(key, value)| (key.to_owned(), value.to_owned()))
            .collect()
    }

    /// Shorthand for an `Equals` filter against a string value.
    fn equals_text(field: &str, text: &str) -> MetadataFilter {
        MetadataFilter::Equals {
            field: field.to_owned(),
            value: FilterValue::String(text.to_owned()),
        }
    }

    /// Shorthand for a `GreaterThan` filter against an integer value.
    fn greater_than_int(field: &str, threshold: i64) -> MetadataFilter {
        MetadataFilter::GreaterThan {
            field: field.to_owned(),
            value: FilterValue::Integer(threshold),
        }
    }

    #[test]
    fn test_equals_filter() {
        let filter = equals_text("category", "news");

        assert!(filter.evaluate(&metadata_of(&[("category", "news")])));
        assert!(!filter.evaluate(&metadata_of(&[("category", "sports")])));
    }

    #[test]
    fn test_greater_than_filter() {
        let filter = greater_than_int("score", 50);

        assert!(filter.evaluate(&metadata_of(&[("score", "75")])));
        assert!(!filter.evaluate(&metadata_of(&[("score", "25")])));
    }

    #[test]
    fn test_and_filter() {
        let filter = MetadataFilter::And(vec![
            equals_text("status", "active"),
            greater_than_int("priority", 5),
        ]);

        let high_priority = metadata_of(&[("status", "active"), ("priority", "8")]);
        assert!(filter.evaluate(&high_priority));

        let low_priority = metadata_of(&[("status", "active"), ("priority", "3")]);
        assert!(!filter.evaluate(&low_priority));
    }

    #[test]
    fn test_or_filter() {
        let filter = MetadataFilter::Or(vec![
            equals_text("type", "urgent"),
            equals_text("type", "critical"),
        ]);

        assert!(filter.evaluate(&metadata_of(&[("type", "urgent")])));
        assert!(filter.evaluate(&metadata_of(&[("type", "critical")])));
        assert!(!filter.evaluate(&metadata_of(&[("type", "normal")])));
    }

    #[test]
    fn test_contains_filter() {
        let filter = MetadataFilter::Contains {
            field: "description".to_owned(),
            substring: "important".to_owned(),
        };

        let matching = metadata_of(&[("description", "This is an important message")]);
        assert!(filter.evaluate(&matching));

        let other = metadata_of(&[("description", "Regular message")]);
        assert!(!filter.evaluate(&other));
    }

    #[test]
    fn test_filter_builder() {
        let filter = FilterBuilder::new()
            .equals("category", FilterValue::String("tech".to_owned()))
            .greater_than("score", FilterValue::Integer(70))
            .build_and();

        let metadata = metadata_of(&[("category", "tech"), ("score", "85")]);
        assert!(filter.evaluate(&metadata));
    }

    #[test]
    fn test_dimension_constraint() {
        let constraint = DimensionConstraint {
            dimension: 0,
            min_value: Some(0.0),
            max_value: Some(1.0),
        };

        let inside = Vector::new(vec![0.5, 0.3, 0.7]);
        assert!(constraint.satisfies(&inside));

        let outside = Vector::new(vec![1.5, 0.3, 0.7]);
        assert!(!constraint.satisfies(&outside));
    }

    #[test]
    fn test_search_filter() {
        let filter = SearchFilter::new()
            .with_max_distance(0.5)
            .with_metadata_filter(equals_text("category", "approved"));

        let vector = Vector::new(vec![1.0, 2.0, 3.0]);

        // Distance is the only thing that changes between these two checks.
        let approved = metadata_of(&[("category", "approved")]);
        assert!(filter.satisfies(0.3, &vector, &approved));
        assert!(!filter.satisfies(0.7, &vector, &approved));

        // An acceptable distance is not enough when the metadata differs.
        let pending = metadata_of(&[("category", "pending")]);
        assert!(!filter.satisfies(0.3, &vector, &pending));
    }
}