1use crate::{Vector, VectorId};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub enum MetadataFilter {
14 Equals { field: String, value: FilterValue },
16 NotEquals { field: String, value: FilterValue },
18 GreaterThan { field: String, value: FilterValue },
20 GreaterThanOrEqual { field: String, value: FilterValue },
22 LessThan { field: String, value: FilterValue },
24 LessThanOrEqual { field: String, value: FilterValue },
26 In {
28 field: String,
29 values: Vec<FilterValue>,
30 },
31 NotIn {
33 field: String,
34 values: Vec<FilterValue>,
35 },
36 Contains { field: String, substring: String },
38 Regex { field: String, pattern: String },
40 Exists { field: String },
42 NotExists { field: String },
44 And(Vec<MetadataFilter>),
46 Or(Vec<MetadataFilter>),
48 Not(Box<MetadataFilter>),
50}
51
52#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
54pub enum FilterValue {
55 String(String),
56 Integer(i64),
57 Float(f64),
58 Boolean(bool),
59 Null,
60}
61
62impl FilterValue {
63 fn compare(&self, other: &FilterValue) -> std::cmp::Ordering {
65 match (self, other) {
66 (FilterValue::String(a), FilterValue::String(b)) => a.cmp(b),
67 (FilterValue::Integer(a), FilterValue::Integer(b)) => a.cmp(b),
68 (FilterValue::Float(a), FilterValue::Float(b)) => {
69 a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
70 }
71 (FilterValue::Boolean(a), FilterValue::Boolean(b)) => a.cmp(b),
72 _ => std::cmp::Ordering::Equal,
73 }
74 }
75}
76
77impl MetadataFilter {
78 pub fn evaluate(&self, metadata: &HashMap<String, String>) -> bool {
80 match self {
81 MetadataFilter::Equals { field, value } => {
82 if let Some(field_value) = metadata.get(field) {
83 let parsed_value = Self::parse_value(field_value);
84 &parsed_value == value
85 } else {
86 false
87 }
88 }
89 MetadataFilter::NotEquals { field, value } => {
90 if let Some(field_value) = metadata.get(field) {
91 let parsed_value = Self::parse_value(field_value);
92 &parsed_value != value
93 } else {
94 true
95 }
96 }
97 MetadataFilter::GreaterThan { field, value } => {
98 if let Some(field_value) = metadata.get(field) {
99 let parsed_value = Self::parse_value(field_value);
100 parsed_value.compare(value) == std::cmp::Ordering::Greater
101 } else {
102 false
103 }
104 }
105 MetadataFilter::GreaterThanOrEqual { field, value } => {
106 if let Some(field_value) = metadata.get(field) {
107 let parsed_value = Self::parse_value(field_value);
108 matches!(
109 parsed_value.compare(value),
110 std::cmp::Ordering::Greater | std::cmp::Ordering::Equal
111 )
112 } else {
113 false
114 }
115 }
116 MetadataFilter::LessThan { field, value } => {
117 if let Some(field_value) = metadata.get(field) {
118 let parsed_value = Self::parse_value(field_value);
119 parsed_value.compare(value) == std::cmp::Ordering::Less
120 } else {
121 false
122 }
123 }
124 MetadataFilter::LessThanOrEqual { field, value } => {
125 if let Some(field_value) = metadata.get(field) {
126 let parsed_value = Self::parse_value(field_value);
127 matches!(
128 parsed_value.compare(value),
129 std::cmp::Ordering::Less | std::cmp::Ordering::Equal
130 )
131 } else {
132 false
133 }
134 }
135 MetadataFilter::In { field, values } => {
136 if let Some(field_value) = metadata.get(field) {
137 let parsed_value = Self::parse_value(field_value);
138 values.contains(&parsed_value)
139 } else {
140 false
141 }
142 }
143 MetadataFilter::NotIn { field, values } => {
144 if let Some(field_value) = metadata.get(field) {
145 let parsed_value = Self::parse_value(field_value);
146 !values.contains(&parsed_value)
147 } else {
148 true
149 }
150 }
151 MetadataFilter::Contains { field, substring } => {
152 if let Some(field_value) = metadata.get(field) {
153 field_value.contains(substring)
154 } else {
155 false
156 }
157 }
158 MetadataFilter::Regex { field, pattern } => {
159 if let Some(field_value) = metadata.get(field) {
160 if let Ok(regex) = regex::Regex::new(pattern) {
161 regex.is_match(field_value)
162 } else {
163 false
164 }
165 } else {
166 false
167 }
168 }
169 MetadataFilter::Exists { field } => metadata.contains_key(field),
170 MetadataFilter::NotExists { field } => !metadata.contains_key(field),
171 MetadataFilter::And(filters) => filters.iter().all(|f| f.evaluate(metadata)),
172 MetadataFilter::Or(filters) => filters.iter().any(|f| f.evaluate(metadata)),
173 MetadataFilter::Not(filter) => !filter.evaluate(metadata),
174 }
175 }
176
177 fn parse_value(s: &str) -> FilterValue {
179 if let Ok(i) = s.parse::<i64>() {
181 return FilterValue::Integer(i);
182 }
183
184 if let Ok(f) = s.parse::<f64>() {
186 return FilterValue::Float(f);
187 }
188
189 if let Ok(b) = s.parse::<bool>() {
191 return FilterValue::Boolean(b);
192 }
193
194 if s == "null" || s.is_empty() {
196 return FilterValue::Null;
197 }
198
199 FilterValue::String(s.to_string())
201 }
202}
203
204#[derive(Debug, Clone, Serialize, Deserialize)]
206pub struct SearchFilter {
207 pub max_distance: Option<f32>,
209 pub min_distance: Option<f32>,
211 pub metadata_filter: Option<MetadataFilter>,
213 pub dimension_constraints: Option<Vec<DimensionConstraint>>,
215}
216
217#[derive(Debug, Clone, Serialize, Deserialize)]
219pub struct DimensionConstraint {
220 pub dimension: usize,
222 pub min_value: Option<f32>,
224 pub max_value: Option<f32>,
226}
227
228impl DimensionConstraint {
229 pub fn satisfies(&self, vector: &Vector) -> bool {
231 let values = vector.as_f32();
232
233 if self.dimension >= values.len() {
234 return false;
235 }
236
237 let value = values[self.dimension];
238
239 if let Some(min) = self.min_value {
240 if value < min {
241 return false;
242 }
243 }
244
245 if let Some(max) = self.max_value {
246 if value > max {
247 return false;
248 }
249 }
250
251 true
252 }
253}
254
255impl SearchFilter {
256 pub fn new() -> Self {
258 Self {
259 max_distance: None,
260 min_distance: None,
261 metadata_filter: None,
262 dimension_constraints: None,
263 }
264 }
265
266 pub fn with_max_distance(mut self, max_distance: f32) -> Self {
268 self.max_distance = Some(max_distance);
269 self
270 }
271
272 pub fn with_min_distance(mut self, min_distance: f32) -> Self {
274 self.min_distance = Some(min_distance);
275 self
276 }
277
278 pub fn with_metadata_filter(mut self, filter: MetadataFilter) -> Self {
280 self.metadata_filter = Some(filter);
281 self
282 }
283
284 pub fn with_dimension_constraints(mut self, constraints: Vec<DimensionConstraint>) -> Self {
286 self.dimension_constraints = Some(constraints);
287 self
288 }
289
290 pub fn satisfies(
292 &self,
293 distance: f32,
294 vector: &Vector,
295 metadata: &HashMap<String, String>,
296 ) -> bool {
297 if let Some(max) = self.max_distance {
299 if distance > max {
300 return false;
301 }
302 }
303
304 if let Some(min) = self.min_distance {
305 if distance < min {
306 return false;
307 }
308 }
309
310 if let Some(ref filter) = self.metadata_filter {
312 if !filter.evaluate(metadata) {
313 return false;
314 }
315 }
316
317 if let Some(ref constraints) = self.dimension_constraints {
319 for constraint in constraints {
320 if !constraint.satisfies(vector) {
321 return false;
322 }
323 }
324 }
325
326 true
327 }
328
329 pub fn filter_results(
331 &self,
332 results: Vec<(VectorId, f32, Vector, HashMap<String, String>)>,
333 ) -> Vec<(VectorId, f32)> {
334 results
335 .into_iter()
336 .filter(|(_, distance, vector, metadata)| self.satisfies(*distance, vector, metadata))
337 .map(|(id, distance, _, _)| (id, distance))
338 .collect()
339 }
340}
341
342impl Default for SearchFilter {
343 fn default() -> Self {
344 Self::new()
345 }
346}
347
348pub struct FilterBuilder {
350 filters: Vec<MetadataFilter>,
351}
352
353impl FilterBuilder {
354 pub fn new() -> Self {
355 Self {
356 filters: Vec::new(),
357 }
358 }
359
360 pub fn equals(mut self, field: impl Into<String>, value: FilterValue) -> Self {
361 self.filters.push(MetadataFilter::Equals {
362 field: field.into(),
363 value,
364 });
365 self
366 }
367
368 pub fn not_equals(mut self, field: impl Into<String>, value: FilterValue) -> Self {
369 self.filters.push(MetadataFilter::NotEquals {
370 field: field.into(),
371 value,
372 });
373 self
374 }
375
376 pub fn greater_than(mut self, field: impl Into<String>, value: FilterValue) -> Self {
377 self.filters.push(MetadataFilter::GreaterThan {
378 field: field.into(),
379 value,
380 });
381 self
382 }
383
384 pub fn less_than(mut self, field: impl Into<String>, value: FilterValue) -> Self {
385 self.filters.push(MetadataFilter::LessThan {
386 field: field.into(),
387 value,
388 });
389 self
390 }
391
392 pub fn contains(mut self, field: impl Into<String>, substring: impl Into<String>) -> Self {
393 self.filters.push(MetadataFilter::Contains {
394 field: field.into(),
395 substring: substring.into(),
396 });
397 self
398 }
399
400 pub fn regex(mut self, field: impl Into<String>, pattern: impl Into<String>) -> Self {
401 self.filters.push(MetadataFilter::Regex {
402 field: field.into(),
403 pattern: pattern.into(),
404 });
405 self
406 }
407
408 pub fn exists(mut self, field: impl Into<String>) -> Self {
409 self.filters.push(MetadataFilter::Exists {
410 field: field.into(),
411 });
412 self
413 }
414
415 pub fn build_and(self) -> MetadataFilter {
416 if self.filters.len() == 1 {
417 self.filters.into_iter().next().unwrap()
418 } else {
419 MetadataFilter::And(self.filters)
420 }
421 }
422
423 pub fn build_or(self) -> MetadataFilter {
424 if self.filters.len() == 1 {
425 self.filters.into_iter().next().unwrap()
426 } else {
427 MetadataFilter::Or(self.filters)
428 }
429 }
430}
431
432impl Default for FilterBuilder {
433 fn default() -> Self {
434 Self::new()
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a metadata map from `(key, value)` pairs.
    fn metadata_of(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_string(), value.to_string()))
            .collect()
    }

    /// Shorthand for an `Equals` filter with a string operand.
    fn equals_text(field: &str, text: &str) -> MetadataFilter {
        MetadataFilter::Equals {
            field: field.to_string(),
            value: FilterValue::String(text.to_string()),
        }
    }

    /// Shorthand for a `GreaterThan` filter with an integer operand.
    fn greater_than_int(field: &str, bound: i64) -> MetadataFilter {
        MetadataFilter::GreaterThan {
            field: field.to_string(),
            value: FilterValue::Integer(bound),
        }
    }

    #[test]
    fn test_equals_filter() {
        let filter = equals_text("category", "news");

        assert!(filter.evaluate(&metadata_of(&[("category", "news")])));
        assert!(!filter.evaluate(&metadata_of(&[("category", "sports")])));
    }

    #[test]
    fn test_greater_than_filter() {
        let filter = greater_than_int("score", 50);

        assert!(filter.evaluate(&metadata_of(&[("score", "75")])));
        assert!(!filter.evaluate(&metadata_of(&[("score", "25")])));
    }

    #[test]
    fn test_and_filter() {
        let filter = MetadataFilter::And(vec![
            equals_text("status", "active"),
            greater_than_int("priority", 5),
        ]);

        let high_priority = metadata_of(&[("status", "active"), ("priority", "8")]);
        assert!(filter.evaluate(&high_priority));

        let low_priority = metadata_of(&[("status", "active"), ("priority", "3")]);
        assert!(!filter.evaluate(&low_priority));
    }

    #[test]
    fn test_or_filter() {
        let filter = MetadataFilter::Or(vec![
            equals_text("type", "urgent"),
            equals_text("type", "critical"),
        ]);

        assert!(filter.evaluate(&metadata_of(&[("type", "urgent")])));
        assert!(filter.evaluate(&metadata_of(&[("type", "critical")])));
        assert!(!filter.evaluate(&metadata_of(&[("type", "normal")])));
    }

    #[test]
    fn test_contains_filter() {
        let filter = MetadataFilter::Contains {
            field: "description".to_string(),
            substring: "important".to_string(),
        };

        let matching = metadata_of(&[("description", "This is an important message")]);
        assert!(filter.evaluate(&matching));

        let other = metadata_of(&[("description", "Regular message")]);
        assert!(!filter.evaluate(&other));
    }

    #[test]
    fn test_filter_builder() {
        let filter = FilterBuilder::new()
            .equals("category", FilterValue::String("tech".to_string()))
            .greater_than("score", FilterValue::Integer(70))
            .build_and();

        let metadata = metadata_of(&[("category", "tech"), ("score", "85")]);
        assert!(filter.evaluate(&metadata));
    }

    #[test]
    fn test_dimension_constraint() {
        let constraint = DimensionConstraint {
            dimension: 0,
            min_value: Some(0.0),
            max_value: Some(1.0),
        };

        let inside = Vector::new(vec![0.5, 0.3, 0.7]);
        assert!(constraint.satisfies(&inside));

        let outside = Vector::new(vec![1.5, 0.3, 0.7]);
        assert!(!constraint.satisfies(&outside));
    }

    #[test]
    fn test_search_filter() {
        let filter = SearchFilter::new()
            .with_max_distance(0.5)
            .with_metadata_filter(equals_text("category", "approved"));
        let vector = Vector::new(vec![1.0, 2.0, 3.0]);

        // Approved metadata: accepted only within the distance bound.
        let approved = metadata_of(&[("category", "approved")]);
        assert!(filter.satisfies(0.3, &vector, &approved));
        assert!(!filter.satisfies(0.7, &vector, &approved));

        // Within the distance bound, but the metadata no longer matches.
        let pending = metadata_of(&[("category", "pending")]);
        assert!(!filter.satisfies(0.3, &vector, &pending));
    }
}