1use crate::{DiskANN, DiskAnnError, DiskAnnParams};
46use anndists::prelude::Distance;
47use rayon::prelude::*;
48use serde::{Deserialize, Serialize};
49use std::collections::{BinaryHeap, HashSet};
50use std::cmp::{Ordering, Reverse};
51use std::fs::{File, OpenOptions};
52use std::io::{BufReader, BufWriter, Read, Write};
53
/// A predicate over a node's label vector (`labels[field]` is the value of
/// label field `field`).
///
/// A label vector that is too short to contain the referenced `field` never
/// satisfies a field predicate.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Filter {
    /// The label at `field` equals `value`.
    LabelEq { field: usize, value: u64 },
    /// The label at `field` is one of `values`.
    LabelIn { field: usize, values: HashSet<u64> },
    /// The label at `field` is strictly less than `value`.
    LabelLt { field: usize, value: u64 },
    /// The label at `field` is strictly greater than `value`.
    LabelGt { field: usize, value: u64 },
    /// The label at `field` lies in the inclusive range `min..=max`
    /// (never matches when `min > max`).
    LabelRange { field: usize, min: u64, max: u64 },
    /// Every sub-filter matches; an empty list matches everything.
    And(Vec<Filter>),
    /// At least one sub-filter matches; an empty list matches nothing.
    Or(Vec<Filter>),
    /// Matches every node. `FilteredDiskANN::search_filtered_with_dists`
    /// delegates this variant straight to the unfiltered index search.
    None,
}

impl Filter {
    /// Builds [`Filter::LabelEq`].
    pub fn label_eq(field: usize, value: u64) -> Self {
        Filter::LabelEq { field, value }
    }

    /// Builds [`Filter::LabelIn`]; duplicate values are collapsed.
    pub fn label_in(field: usize, values: impl IntoIterator<Item = u64>) -> Self {
        Filter::LabelIn {
            field,
            values: values.into_iter().collect(),
        }
    }

    /// Builds [`Filter::LabelLt`].
    pub fn label_lt(field: usize, value: u64) -> Self {
        Filter::LabelLt { field, value }
    }

    /// Builds [`Filter::LabelGt`].
    pub fn label_gt(field: usize, value: u64) -> Self {
        Filter::LabelGt { field, value }
    }

    /// Builds [`Filter::LabelRange`] (inclusive on both ends).
    pub fn label_range(field: usize, min: u64, max: u64) -> Self {
        Filter::LabelRange { field, min, max }
    }

    /// Builds [`Filter::And`].
    pub fn and(filters: Vec<Filter>) -> Self {
        Filter::And(filters)
    }

    /// Builds [`Filter::Or`].
    pub fn or(filters: Vec<Filter>) -> Self {
        Filter::Or(filters)
    }

    /// Returns `true` if `labels` satisfies this filter.
    pub fn matches(&self, labels: &[u64]) -> bool {
        // `None` when the label vector is too short: such nodes never match a field predicate.
        let label_at = |field: usize| labels.get(field).copied();
        match self {
            Filter::None => true,
            Filter::LabelEq { field, value } => label_at(*field) == Some(*value),
            Filter::LabelIn { field, values } => {
                matches!(label_at(*field), Some(v) if values.contains(&v))
            }
            Filter::LabelLt { field, value } => matches!(label_at(*field), Some(v) if v < *value),
            Filter::LabelGt { field, value } => matches!(label_at(*field), Some(v) if v > *value),
            Filter::LabelRange { field, min, max } => {
                matches!(label_at(*field), Some(v) if (*min..=*max).contains(&v))
            }
            Filter::And(filters) => filters.iter().all(|f| f.matches(labels)),
            Filter::Or(filters) => filters.iter().any(|f| f.matches(labels)),
        }
    }
}
138
139#[derive(Serialize, Deserialize, Debug)]
141struct FilteredMetadata {
142 num_vectors: usize,
143 num_fields: usize,
144}
145
/// A graph node paired with its distance to the current query.
///
/// Candidates are ordered by distance using [`f32::total_cmp`], and ties are
/// broken by id. This makes the ordering total (a positive NaN sorts above every
/// other distance) and consistent with `PartialEq`/`Eq`. `BinaryHeap` relies on
/// both properties. The previous implementation compared only `dist` in
/// `Ord` but both fields in `PartialEq`, and it mapped NaN comparisons to
/// `Equal`.
#[derive(Clone, Copy, Debug)]
struct Candidate {
    dist: f32,
    id: u32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        self.dist
            .total_cmp(&other.dist)
            .then_with(|| self.id.cmp(&other.id))
    }
}
169
170pub struct FilteredDiskANN<D>
172where
173 D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
174{
175 index: DiskANN<D>,
177 labels: Vec<Vec<u64>>,
179 num_fields: usize,
181 #[allow(dead_code)]
183 labels_path: String,
184}
185
186impl<D> FilteredDiskANN<D>
187where
188 D: Distance<f32> + Send + Sync + Copy + Clone + Default + 'static,
189{
190 pub fn build(
192 vectors: &[Vec<f32>],
193 labels: &[Vec<u64>],
194 base_path: &str,
195 ) -> Result<Self, DiskAnnError> {
196 Self::build_with_params(vectors, labels, base_path, DiskAnnParams::default())
197 }
198
199 pub fn build_with_params(
201 vectors: &[Vec<f32>],
202 labels: &[Vec<u64>],
203 base_path: &str,
204 params: DiskAnnParams,
205 ) -> Result<Self, DiskAnnError> {
206 if vectors.len() != labels.len() {
207 return Err(DiskAnnError::IndexError(format!(
208 "vectors.len() ({}) != labels.len() ({})",
209 vectors.len(),
210 labels.len()
211 )));
212 }
213
214 let num_fields = labels.first().map(|l| l.len()).unwrap_or(0);
215 for (i, l) in labels.iter().enumerate() {
216 if l.len() != num_fields {
217 return Err(DiskAnnError::IndexError(format!(
218 "Label {} has {} fields, expected {}",
219 i,
220 l.len(),
221 num_fields
222 )));
223 }
224 }
225
226 let index_path = format!("{}.idx", base_path);
228 let index = DiskANN::<D>::build_index_with_params(
229 vectors,
230 D::default(),
231 &index_path,
232 params,
233 )?;
234
235 let labels_path = format!("{}.labels", base_path);
237 Self::save_labels(&labels_path, labels, num_fields)?;
238
239 Ok(Self {
240 index,
241 labels: labels.to_vec(),
242 num_fields,
243 labels_path,
244 })
245 }
246
247 pub fn open(base_path: &str) -> Result<Self, DiskAnnError> {
249 let index_path = format!("{}.idx", base_path);
250 let labels_path = format!("{}.labels", base_path);
251
252 let index = DiskANN::<D>::open_index_default_metric(&index_path)?;
253 let (labels, num_fields) = Self::load_labels(&labels_path)?;
254
255 if labels.len() != index.num_vectors {
256 return Err(DiskAnnError::IndexError(format!(
257 "Labels count ({}) != index vectors ({})",
258 labels.len(),
259 index.num_vectors
260 )));
261 }
262
263 Ok(Self {
264 index,
265 labels,
266 num_fields,
267 labels_path,
268 })
269 }
270
271 fn save_labels(path: &str, labels: &[Vec<u64>], num_fields: usize) -> Result<(), DiskAnnError> {
272 let file = OpenOptions::new()
273 .create(true)
274 .write(true)
275 .truncate(true)
276 .open(path)?;
277 let mut writer = BufWriter::new(file);
278
279 let meta = FilteredMetadata {
280 num_vectors: labels.len(),
281 num_fields,
282 };
283 let meta_bytes = bincode::serialize(&meta)?;
284 writer.write_all(&(meta_bytes.len() as u64).to_le_bytes())?;
285 writer.write_all(&meta_bytes)?;
286
287 for label_vec in labels {
289 for &val in label_vec {
290 writer.write_all(&val.to_le_bytes())?;
291 }
292 }
293
294 writer.flush()?;
295 Ok(())
296 }
297
298 fn load_labels(path: &str) -> Result<(Vec<Vec<u64>>, usize), DiskAnnError> {
299 let file = File::open(path)?;
300 let mut reader = BufReader::new(file);
301
302 let mut len_buf = [0u8; 8];
304 reader.read_exact(&mut len_buf)?;
305 let meta_len = u64::from_le_bytes(len_buf) as usize;
306
307 let mut meta_bytes = vec![0u8; meta_len];
308 reader.read_exact(&mut meta_bytes)?;
309 let meta: FilteredMetadata = bincode::deserialize(&meta_bytes)?;
310
311 let mut labels = Vec::with_capacity(meta.num_vectors);
313 let mut val_buf = [0u8; 8];
314
315 for _ in 0..meta.num_vectors {
316 let mut label_vec = Vec::with_capacity(meta.num_fields);
317 for _ in 0..meta.num_fields {
318 reader.read_exact(&mut val_buf)?;
319 label_vec.push(u64::from_le_bytes(val_buf));
320 }
321 labels.push(label_vec);
322 }
323
324 Ok((labels, meta.num_fields))
325 }
326}
327
328impl<D> FilteredDiskANN<D>
329where
330 D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
331{
332 pub fn search_filtered(
337 &self,
338 query: &[f32],
339 k: usize,
340 beam_width: usize,
341 filter: &Filter,
342 ) -> Vec<u32> {
343 self.search_filtered_with_dists(query, k, beam_width, filter)
344 .into_iter()
345 .map(|(id, _)| id)
346 .collect()
347 }
348
349 pub fn search_filtered_with_dists(
351 &self,
352 query: &[f32],
353 k: usize,
354 beam_width: usize,
355 filter: &Filter,
356 ) -> Vec<(u32, f32)> {
357 if matches!(filter, Filter::None) {
359 return self.index.search_with_dists(query, k, beam_width);
360 }
361
362 let expanded_beam = (beam_width * 4).max(k * 10);
365
366 let mut visited = HashSet::new();
367 let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new();
368 let mut working_set: BinaryHeap<Candidate> = BinaryHeap::new();
369 let mut results: Vec<(u32, f32)> = Vec::with_capacity(k);
370
371 let start_dist = self.distance_to(query, self.index.medoid_id as usize);
373 let start = Candidate {
374 dist: start_dist,
375 id: self.index.medoid_id,
376 };
377 frontier.push(Reverse(start));
378 working_set.push(start);
379 visited.insert(self.index.medoid_id);
380
381 if filter.matches(&self.labels[self.index.medoid_id as usize]) {
383 results.push((self.index.medoid_id, start_dist));
384 }
385
386 let mut iterations = 0;
388 let max_iterations = expanded_beam * 2;
389
390 while let Some(Reverse(best)) = frontier.peek().copied() {
391 iterations += 1;
392 if iterations > max_iterations {
393 break;
394 }
395
396 if results.len() >= k {
399 if let Some((_, worst_dist)) = results.last() {
400 if best.dist > *worst_dist * 1.5 {
401 break;
402 }
403 }
404 }
405
406 if working_set.len() >= expanded_beam {
407 if let Some(worst) = working_set.peek() {
408 if best.dist >= worst.dist {
409 break;
410 }
411 }
412 }
413
414 let Reverse(current) = frontier.pop().unwrap();
415
416 for &nb in self.get_neighbors(current.id) {
418 if nb == u32::MAX {
419 continue;
420 }
421 if !visited.insert(nb) {
422 continue;
423 }
424
425 let d = self.distance_to(query, nb as usize);
426 let cand = Candidate { dist: d, id: nb };
427
428 if working_set.len() < expanded_beam {
430 working_set.push(cand);
431 frontier.push(Reverse(cand));
432 } else if d < working_set.peek().unwrap().dist {
433 working_set.pop();
434 working_set.push(cand);
435 frontier.push(Reverse(cand));
436 }
437
438 if filter.matches(&self.labels[nb as usize]) {
440 let pos = results
442 .iter()
443 .position(|(_, dist)| d < *dist)
444 .unwrap_or(results.len());
445
446 if pos < k {
447 results.insert(pos, (nb, d));
448 if results.len() > k {
449 results.pop();
450 }
451 }
452 }
453 }
454 }
455
456 results
457 }
458
459 pub fn search_filtered_batch(
461 &self,
462 queries: &[Vec<f32>],
463 k: usize,
464 beam_width: usize,
465 filter: &Filter,
466 ) -> Vec<Vec<u32>> {
467 queries
468 .par_iter()
469 .map(|q| self.search_filtered(q, k, beam_width, filter))
470 .collect()
471 }
472
473 pub fn search(&self, query: &[f32], k: usize, beam_width: usize) -> Vec<u32> {
475 self.index.search(query, k, beam_width)
476 }
477
478 pub fn get_labels(&self, id: usize) -> Option<&[u64]> {
480 self.labels.get(id).map(|v| v.as_slice())
481 }
482
483 pub fn inner(&self) -> &DiskANN<D> {
485 &self.index
486 }
487
488 pub fn num_vectors(&self) -> usize {
490 self.index.num_vectors
491 }
492
493 pub fn num_fields(&self) -> usize {
495 self.num_fields
496 }
497
498 pub fn count_matching(&self, filter: &Filter) -> usize {
500 self.labels.iter().filter(|l| filter.matches(l)).count()
501 }
502
503 fn get_neighbors(&self, node_id: u32) -> &[u32] {
504 let offset = self.index.adjacency_offset
506 + (node_id as u64 * self.index.max_degree as u64 * 4);
507 let start = offset as usize;
508 let end = start + (self.index.max_degree * 4);
509 let bytes = &self.index.mmap[start..end];
510 bytemuck::cast_slice(bytes)
511 }
512
513 fn distance_to(&self, query: &[f32], idx: usize) -> f32 {
514 let offset = self.index.vectors_offset + (idx as u64 * self.index.dim as u64 * 4);
515 let start = offset as usize;
516 let end = start + (self.index.dim * 4);
517 let bytes = &self.index.mmap[start..end];
518 let vector: &[f32] = bytemuck::cast_slice(bytes);
519 self.index.dist.eval(query, vector)
520 }
521}
522
#[cfg(test)]
mod tests {
    use super::*;
    use anndists::dist::DistL2;
    use std::fs;

    /// RAII guard for the `{base}.idx` / `{base}.labels` files that a test creates.
    ///
    /// Leftovers from an earlier aborted run are deleted on construction. The
    /// files are deleted again on drop, so cleanup also happens when an
    /// assertion panics. Create the guard before the index: locals drop in
    /// reverse declaration order, so the index (and its mmap) goes first.
    struct IndexFiles {
        base_path: &'static str,
    }

    impl IndexFiles {
        fn new(base_path: &'static str) -> Self {
            let files = Self { base_path };
            files.remove();
            files
        }

        fn remove(&self) {
            let _ = fs::remove_file(format!("{}.idx", self.base_path));
            let _ = fs::remove_file(format!("{}.labels", self.base_path));
        }
    }

    impl Drop for IndexFiles {
        fn drop(&mut self) {
            self.remove();
        }
    }

    #[test]
    fn test_filter_eq() {
        let filter = Filter::label_eq(0, 5);
        assert!(filter.matches(&[5, 10]));
        assert!(!filter.matches(&[4, 10]));
        assert!(!filter.matches(&[]));
    }

    #[test]
    fn test_filter_in() {
        let filter = Filter::label_in(0, vec![1, 3, 5]);
        assert!(filter.matches(&[1]));
        assert!(filter.matches(&[3]));
        assert!(filter.matches(&[5]));
        assert!(!filter.matches(&[2]));
    }

    #[test]
    fn test_filter_range() {
        let filter = Filter::label_range(0, 10, 20);
        assert!(filter.matches(&[10]));
        assert!(filter.matches(&[15]));
        assert!(filter.matches(&[20]));
        assert!(!filter.matches(&[9]));
        assert!(!filter.matches(&[21]));
    }

    #[test]
    fn test_filter_and() {
        let filter = Filter::and(vec![Filter::label_eq(0, 5), Filter::label_gt(1, 10)]);
        assert!(filter.matches(&[5, 15]));
        assert!(!filter.matches(&[5, 5]));
        assert!(!filter.matches(&[4, 15]));
    }

    #[test]
    fn test_filter_or() {
        let filter = Filter::or(vec![Filter::label_eq(0, 5), Filter::label_eq(0, 10)]);
        assert!(filter.matches(&[5]));
        assert!(filter.matches(&[10]));
        assert!(!filter.matches(&[7]));
    }

    #[test]
    fn test_filtered_search_basic() {
        let files = IndexFiles::new("test_filtered");

        let vectors: Vec<Vec<f32>> = (0..100)
            .map(|i| vec![i as f32, (i * 2) as f32])
            .collect();
        let labels: Vec<Vec<u64>> = (0..100).map(|i| vec![i % 5]).collect();

        let index =
            FilteredDiskANN::<DistL2>::build(&vectors, &labels, files.base_path).unwrap();

        // Unfiltered search over 100 points must fill all k slots.
        let results = index.search(&[50.0, 100.0], 5, 32);
        assert_eq!(results.len(), 5);

        // Every filtered hit must carry the requested label.
        let filter = Filter::label_eq(0, 0);
        let results = index.search_filtered(&[50.0, 100.0], 5, 32, &filter);
        for id in &results {
            assert_eq!(labels[*id as usize][0], 0);
        }
    }

    #[test]
    fn test_filtered_search_selectivity() {
        let files = IndexFiles::new("test_filtered_sel");

        let vectors: Vec<Vec<f32>> = (0..1000)
            .map(|i| vec![(i % 100) as f32, ((i / 100) * 10) as f32])
            .collect();
        let labels: Vec<Vec<u64>> = (0..1000).map(|i| vec![i % 10]).collect();

        let index =
            FilteredDiskANN::<DistL2>::build(&vectors, &labels, files.base_path).unwrap();

        // 10% selectivity: exactly 100 of 1000 vectors carry label 3.
        let filter = Filter::label_eq(0, 3);
        assert_eq!(index.count_matching(&filter), 100);

        let results = index.search_filtered(&[50.0, 50.0], 10, 64, &filter);
        assert!(results.len() <= 10);
        for id in &results {
            assert_eq!(labels[*id as usize][0], 3);
        }
    }

    #[test]
    fn test_filtered_persistence() {
        let files = IndexFiles::new("test_filtered_persist");

        let vectors: Vec<Vec<f32>> = (0..50).map(|i| vec![i as f32, i as f32]).collect();
        let labels: Vec<Vec<u64>> = (0..50).map(|i| vec![i % 3, i]).collect();

        // Build in its own scope so the writer-side index is dropped before reopening.
        {
            let _index =
                FilteredDiskANN::<DistL2>::build(&vectors, &labels, files.base_path).unwrap();
        }

        let index = FilteredDiskANN::<DistL2>::open(files.base_path).unwrap();
        assert_eq!(index.num_vectors(), 50);
        assert_eq!(index.num_fields(), 2);

        let filter = Filter::label_eq(0, 1);
        let results = index.search_filtered(&[25.0, 25.0], 5, 32, &filter);
        for id in &results {
            assert_eq!(index.get_labels(*id as usize).unwrap()[0], 1);
        }
    }
}