1use crate::{DiskANN, DiskAnnError, DiskAnnParams};
46use anndists::prelude::Distance;
47use rayon::prelude::*;
48use serde::{Deserialize, Serialize};
49use std::collections::{BinaryHeap, HashSet};
50use std::cmp::{Ordering, Reverse};
51use std::fs::{File, OpenOptions};
52use std::io::{BufReader, BufWriter, Read, Write};
53use std::sync::Arc;
54
/// Predicate evaluated against one vector's label row.
///
/// Every `field` is an index into the label slice passed to
/// [`Filter::matches`]. A row with no value at that index never satisfies a
/// field predicate.
#[derive(Clone, Debug)]
pub enum Filter {
    /// The label at `field` equals `value`.
    LabelEq { field: usize, value: u64 },
    /// The label at `field` is one of `values`.
    LabelIn { field: usize, values: HashSet<u64> },
    /// The label at `field` is strictly less than `value`.
    LabelLt { field: usize, value: u64 },
    /// The label at `field` is strictly greater than `value`.
    LabelGt { field: usize, value: u64 },
    /// The label at `field` lies in the inclusive range `min..=max`.
    LabelRange { field: usize, min: u64, max: u64 },
    /// Every sub-filter matches (an empty list matches everything).
    And(Vec<Filter>),
    /// At least one sub-filter matches (an empty list matches nothing).
    Or(Vec<Filter>),
    /// No restriction: matches every row.
    None,
}

impl Filter {
    /// Builds a [`Filter::LabelEq`] predicate.
    pub fn label_eq(field: usize, value: u64) -> Self {
        Self::LabelEq { field, value }
    }

    /// Builds a [`Filter::LabelIn`] predicate from any iterator of accepted values.
    pub fn label_in(field: usize, values: impl IntoIterator<Item = u64>) -> Self {
        let values: HashSet<u64> = values.into_iter().collect();
        Self::LabelIn { field, values }
    }

    /// Builds a [`Filter::LabelLt`] predicate.
    pub fn label_lt(field: usize, value: u64) -> Self {
        Self::LabelLt { field, value }
    }

    /// Builds a [`Filter::LabelGt`] predicate.
    pub fn label_gt(field: usize, value: u64) -> Self {
        Self::LabelGt { field, value }
    }

    /// Builds a [`Filter::LabelRange`] predicate (both bounds inclusive).
    pub fn label_range(field: usize, min: u64, max: u64) -> Self {
        Self::LabelRange { field, min, max }
    }

    /// Builds a conjunction of `filters`.
    pub fn and(filters: Vec<Filter>) -> Self {
        Self::And(filters)
    }

    /// Builds a disjunction of `filters`.
    pub fn or(filters: Vec<Filter>) -> Self {
        Self::Or(filters)
    }

    /// Returns `true` when `labels` satisfies this filter.
    ///
    /// Field predicates evaluate to `false` when `labels` has no entry at the
    /// requested index.
    pub fn matches(&self, labels: &[u64]) -> bool {
        match self {
            Self::None => true,
            Self::And(parts) => parts.iter().all(|part| part.matches(labels)),
            Self::Or(parts) => parts.iter().any(|part| part.matches(labels)),
            Self::LabelEq { field, value } => {
                matches!(labels.get(*field), Some(found) if found == value)
            }
            Self::LabelIn { field, values } => {
                matches!(labels.get(*field), Some(found) if values.contains(found))
            }
            Self::LabelLt { field, value } => {
                matches!(labels.get(*field), Some(found) if found < value)
            }
            Self::LabelGt { field, value } => {
                matches!(labels.get(*field), Some(found) if found > value)
            }
            Self::LabelRange { field, min, max } => {
                matches!(labels.get(*field), Some(found) if min <= found && found <= max)
            }
        }
    }
}
139
140#[derive(Serialize, Deserialize, Debug)]
142struct FilteredMetadata {
143 num_vectors: usize,
144 num_fields: usize,
145}
146
/// A node reached during graph traversal together with its distance to the query.
///
/// Candidates are kept in `BinaryHeap`s, which require a total order that is
/// consistent with `Eq`. The ordering is therefore: ascending distance, NaN
/// distances after every real distance, ties broken by `id`. Equality is
/// defined through that same ordering.
#[derive(Clone, Copy, Debug)]
struct Candidate {
    dist: f32,
    id: u32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        self.dist
            .partial_cmp(&other.dist)
            // `partial_cmp` only fails when a NaN is involved; rank NaN as the
            // worst possible distance so it can never look like a good match.
            .unwrap_or_else(|| self.dist.is_nan().cmp(&other.dist.is_nan()))
            // Distinct nodes at the same distance must not compare equal.
            .then_with(|| self.id.cmp(&other.id))
    }
}
170
171pub struct FilteredDiskANN<D>
173where
174 D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
175{
176 index: DiskANN<D>,
178 labels: Vec<Vec<u64>>,
180 num_fields: usize,
182 #[allow(dead_code)]
184 labels_path: String,
185}
186
187impl<D> FilteredDiskANN<D>
188where
189 D: Distance<f32> + Send + Sync + Copy + Clone + Default + 'static,
190{
191 pub fn build(
193 vectors: &[Vec<f32>],
194 labels: &[Vec<u64>],
195 base_path: &str,
196 ) -> Result<Self, DiskAnnError> {
197 Self::build_with_params(vectors, labels, base_path, DiskAnnParams::default())
198 }
199
200 pub fn build_with_params(
202 vectors: &[Vec<f32>],
203 labels: &[Vec<u64>],
204 base_path: &str,
205 params: DiskAnnParams,
206 ) -> Result<Self, DiskAnnError> {
207 if vectors.len() != labels.len() {
208 return Err(DiskAnnError::IndexError(format!(
209 "vectors.len() ({}) != labels.len() ({})",
210 vectors.len(),
211 labels.len()
212 )));
213 }
214
215 let num_fields = labels.first().map(|l| l.len()).unwrap_or(0);
216 for (i, l) in labels.iter().enumerate() {
217 if l.len() != num_fields {
218 return Err(DiskAnnError::IndexError(format!(
219 "Label {} has {} fields, expected {}",
220 i,
221 l.len(),
222 num_fields
223 )));
224 }
225 }
226
227 let index_path = format!("{}.idx", base_path);
229 let index = DiskANN::<D>::build_index_with_params(
230 vectors,
231 D::default(),
232 &index_path,
233 params,
234 )?;
235
236 let labels_path = format!("{}.labels", base_path);
238 Self::save_labels(&labels_path, labels, num_fields)?;
239
240 Ok(Self {
241 index,
242 labels: labels.to_vec(),
243 num_fields,
244 labels_path,
245 })
246 }
247
248 pub fn open(base_path: &str) -> Result<Self, DiskAnnError> {
250 let index_path = format!("{}.idx", base_path);
251 let labels_path = format!("{}.labels", base_path);
252
253 let index = DiskANN::<D>::open_index_default_metric(&index_path)?;
254 let (labels, num_fields) = Self::load_labels(&labels_path)?;
255
256 if labels.len() != index.num_vectors {
257 return Err(DiskAnnError::IndexError(format!(
258 "Labels count ({}) != index vectors ({})",
259 labels.len(),
260 index.num_vectors
261 )));
262 }
263
264 Ok(Self {
265 index,
266 labels,
267 num_fields,
268 labels_path,
269 })
270 }
271
272 pub fn to_bytes(&self) -> Vec<u8> {
276 let index_bytes = self.index.to_bytes();
277 let labels_bytes = Self::serialize_labels(&self.labels, self.num_fields);
278 let mut out = Vec::with_capacity(8 + index_bytes.len() + labels_bytes.len());
279 out.extend_from_slice(&(index_bytes.len() as u64).to_le_bytes());
280 out.extend_from_slice(&index_bytes);
281 out.extend_from_slice(&labels_bytes);
282 out
283 }
284
285 pub fn from_bytes(bytes: Vec<u8>, dist: D) -> Result<Self, DiskAnnError> {
287 if bytes.len() < 8 {
288 return Err(DiskAnnError::IndexError("Buffer too small".into()));
289 }
290 let index_len = u64::from_le_bytes(bytes[0..8].try_into().unwrap()) as usize;
291 if bytes.len() < 8 + index_len {
292 return Err(DiskAnnError::IndexError("Buffer too small for index data".into()));
293 }
294 let index_bytes = bytes[8..8 + index_len].to_vec();
295 let labels_bytes = &bytes[8 + index_len..];
296
297 let index = DiskANN::<D>::from_bytes(index_bytes, dist)?;
298 let (labels, num_fields) = Self::deserialize_labels(labels_bytes)?;
299
300 if labels.len() != index.num_vectors {
301 return Err(DiskAnnError::IndexError(format!(
302 "Labels count ({}) != index vectors ({})",
303 labels.len(),
304 index.num_vectors
305 )));
306 }
307
308 Ok(Self {
309 index,
310 labels,
311 num_fields,
312 labels_path: String::new(),
313 })
314 }
315
316 pub fn from_shared_bytes(bytes: Arc<[u8]>, dist: D) -> Result<Self, DiskAnnError> {
318 if bytes.len() < 8 {
320 return Err(DiskAnnError::IndexError("Buffer too small".into()));
321 }
322 let index_len = u64::from_le_bytes(bytes[0..8].try_into().unwrap()) as usize;
323 if bytes.len() < 8 + index_len {
324 return Err(DiskAnnError::IndexError("Buffer too small for index data".into()));
325 }
326 let index_bytes = bytes[8..8 + index_len].to_vec();
327 let labels_bytes = &bytes[8 + index_len..];
328
329 let index = DiskANN::<D>::from_bytes(index_bytes, dist)?;
330 let (labels, num_fields) = Self::deserialize_labels(labels_bytes)?;
331
332 if labels.len() != index.num_vectors {
333 return Err(DiskAnnError::IndexError(format!(
334 "Labels count ({}) != index vectors ({})",
335 labels.len(),
336 index.num_vectors
337 )));
338 }
339
340 Ok(Self {
341 index,
342 labels,
343 num_fields,
344 labels_path: String::new(),
345 })
346 }
347
348 fn serialize_labels(labels: &[Vec<u64>], num_fields: usize) -> Vec<u8> {
349 let meta = FilteredMetadata {
350 num_vectors: labels.len(),
351 num_fields,
352 };
353 let meta_bytes = bincode::serialize(&meta).unwrap();
354 let mut out = Vec::new();
355 out.extend_from_slice(&(meta_bytes.len() as u64).to_le_bytes());
356 out.extend_from_slice(&meta_bytes);
357 for label_vec in labels {
358 for &val in label_vec {
359 out.extend_from_slice(&val.to_le_bytes());
360 }
361 }
362 out
363 }
364
365 fn deserialize_labels(bytes: &[u8]) -> Result<(Vec<Vec<u64>>, usize), DiskAnnError> {
366 if bytes.len() < 8 {
367 return Err(DiskAnnError::IndexError("Labels buffer too small".into()));
368 }
369 let meta_len = u64::from_le_bytes(bytes[0..8].try_into().unwrap()) as usize;
370 if bytes.len() < 8 + meta_len {
371 return Err(DiskAnnError::IndexError("Labels buffer too small for metadata".into()));
372 }
373 let meta: FilteredMetadata = bincode::deserialize(&bytes[8..8 + meta_len])?;
374
375 let data = &bytes[8 + meta_len..];
376 let mut labels = Vec::with_capacity(meta.num_vectors);
377 let mut offset = 0;
378 for _ in 0..meta.num_vectors {
379 let mut label_vec = Vec::with_capacity(meta.num_fields);
380 for _ in 0..meta.num_fields {
381 if offset + 8 > data.len() {
382 return Err(DiskAnnError::IndexError("Labels data truncated".into()));
383 }
384 label_vec.push(u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()));
385 offset += 8;
386 }
387 labels.push(label_vec);
388 }
389
390 Ok((labels, meta.num_fields))
391 }
392
393 fn save_labels(path: &str, labels: &[Vec<u64>], num_fields: usize) -> Result<(), DiskAnnError> {
394 let file = OpenOptions::new()
395 .create(true)
396 .write(true)
397 .truncate(true)
398 .open(path)?;
399 let mut writer = BufWriter::new(file);
400
401 let meta = FilteredMetadata {
402 num_vectors: labels.len(),
403 num_fields,
404 };
405 let meta_bytes = bincode::serialize(&meta)?;
406 writer.write_all(&(meta_bytes.len() as u64).to_le_bytes())?;
407 writer.write_all(&meta_bytes)?;
408
409 for label_vec in labels {
411 for &val in label_vec {
412 writer.write_all(&val.to_le_bytes())?;
413 }
414 }
415
416 writer.flush()?;
417 Ok(())
418 }
419
420 fn load_labels(path: &str) -> Result<(Vec<Vec<u64>>, usize), DiskAnnError> {
421 let file = File::open(path)?;
422 let mut reader = BufReader::new(file);
423
424 let mut len_buf = [0u8; 8];
426 reader.read_exact(&mut len_buf)?;
427 let meta_len = u64::from_le_bytes(len_buf) as usize;
428
429 let mut meta_bytes = vec![0u8; meta_len];
430 reader.read_exact(&mut meta_bytes)?;
431 let meta: FilteredMetadata = bincode::deserialize(&meta_bytes)?;
432
433 let mut labels = Vec::with_capacity(meta.num_vectors);
435 let mut val_buf = [0u8; 8];
436
437 for _ in 0..meta.num_vectors {
438 let mut label_vec = Vec::with_capacity(meta.num_fields);
439 for _ in 0..meta.num_fields {
440 reader.read_exact(&mut val_buf)?;
441 label_vec.push(u64::from_le_bytes(val_buf));
442 }
443 labels.push(label_vec);
444 }
445
446 Ok((labels, meta.num_fields))
447 }
448}
449
450impl<D> FilteredDiskANN<D>
451where
452 D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
453{
454 pub fn search_filtered(
459 &self,
460 query: &[f32],
461 k: usize,
462 beam_width: usize,
463 filter: &Filter,
464 ) -> Vec<u32> {
465 self.search_filtered_with_dists(query, k, beam_width, filter)
466 .into_iter()
467 .map(|(id, _)| id)
468 .collect()
469 }
470
471 pub fn search_filtered_with_dists(
473 &self,
474 query: &[f32],
475 k: usize,
476 beam_width: usize,
477 filter: &Filter,
478 ) -> Vec<(u32, f32)> {
479 if matches!(filter, Filter::None) {
481 return self.index.search_with_dists(query, k, beam_width);
482 }
483
484 let expanded_beam = (beam_width * 4).max(k * 10);
487
488 let mut visited = HashSet::new();
489 let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new();
490 let mut working_set: BinaryHeap<Candidate> = BinaryHeap::new();
491 let mut results: Vec<(u32, f32)> = Vec::with_capacity(k);
492
493 let start_dist = self.distance_to(query, self.index.medoid_id as usize);
495 let start = Candidate {
496 dist: start_dist,
497 id: self.index.medoid_id,
498 };
499 frontier.push(Reverse(start));
500 working_set.push(start);
501 visited.insert(self.index.medoid_id);
502
503 if filter.matches(&self.labels[self.index.medoid_id as usize]) {
505 results.push((self.index.medoid_id, start_dist));
506 }
507
508 let mut iterations = 0;
510 let max_iterations = expanded_beam * 2;
511
512 while let Some(Reverse(best)) = frontier.peek().copied() {
513 iterations += 1;
514 if iterations > max_iterations {
515 break;
516 }
517
518 if results.len() >= k {
521 if let Some((_, worst_dist)) = results.last() {
522 if best.dist > *worst_dist * 1.5 {
523 break;
524 }
525 }
526 }
527
528 if working_set.len() >= expanded_beam {
529 if let Some(worst) = working_set.peek() {
530 if best.dist >= worst.dist {
531 break;
532 }
533 }
534 }
535
536 let Reverse(current) = frontier.pop().unwrap();
537
538 for &nb in self.get_neighbors(current.id) {
540 if nb == u32::MAX {
541 continue;
542 }
543 if !visited.insert(nb) {
544 continue;
545 }
546
547 let d = self.distance_to(query, nb as usize);
548 let cand = Candidate { dist: d, id: nb };
549
550 if working_set.len() < expanded_beam {
552 working_set.push(cand);
553 frontier.push(Reverse(cand));
554 } else if d < working_set.peek().unwrap().dist {
555 working_set.pop();
556 working_set.push(cand);
557 frontier.push(Reverse(cand));
558 }
559
560 if filter.matches(&self.labels[nb as usize]) {
562 let pos = results
564 .iter()
565 .position(|(_, dist)| d < *dist)
566 .unwrap_or(results.len());
567
568 if pos < k {
569 results.insert(pos, (nb, d));
570 if results.len() > k {
571 results.pop();
572 }
573 }
574 }
575 }
576 }
577
578 results
579 }
580
581 pub fn search_filtered_batch(
583 &self,
584 queries: &[Vec<f32>],
585 k: usize,
586 beam_width: usize,
587 filter: &Filter,
588 ) -> Vec<Vec<u32>> {
589 queries
590 .par_iter()
591 .map(|q| self.search_filtered(q, k, beam_width, filter))
592 .collect()
593 }
594
595 pub fn search(&self, query: &[f32], k: usize, beam_width: usize) -> Vec<u32> {
597 self.index.search(query, k, beam_width)
598 }
599
600 pub fn get_labels(&self, id: usize) -> Option<&[u64]> {
602 self.labels.get(id).map(|v| v.as_slice())
603 }
604
605 pub fn inner(&self) -> &DiskANN<D> {
607 &self.index
608 }
609
610 pub fn num_vectors(&self) -> usize {
612 self.index.num_vectors
613 }
614
615 pub fn num_fields(&self) -> usize {
617 self.num_fields
618 }
619
620 pub fn count_matching(&self, filter: &Filter) -> usize {
622 self.labels.iter().filter(|l| filter.matches(l)).count()
623 }
624
625 fn get_neighbors(&self, node_id: u32) -> &[u32] {
626 let offset = self.index.adjacency_offset
628 + (node_id as u64 * self.index.max_degree as u64 * 4);
629 let start = offset as usize;
630 let end = start + (self.index.max_degree * 4);
631 let bytes = &self.index.storage[start..end];
632 bytemuck::cast_slice(bytes)
633 }
634
635 fn distance_to(&self, query: &[f32], idx: usize) -> f32 {
636 let offset = self.index.vectors_offset + (idx as u64 * self.index.dim as u64 * 4);
637 let start = offset as usize;
638 let end = start + (self.index.dim * 4);
639 let bytes = &self.index.storage[start..end];
640 let vector: &[f32] = bytemuck::cast_slice(bytes);
641 self.index.dist.eval(query, vector)
642 }
643}
644
#[cfg(test)]
mod tests {
    use super::*;
    use anndists::dist::DistL2;
    use std::fs;

    /// Owns the on-disk files of one test index.
    ///
    /// Files live in the system temp directory under a per-process name, and
    /// are removed both up front and on drop, so a failing assertion does not
    /// leave artifacts behind. Declare the guard before the index so the
    /// index is dropped first.
    struct TempIndex {
        base: String,
    }

    impl TempIndex {
        fn new(name: &str) -> Self {
            let base = std::env::temp_dir()
                .join(format!("{}_{}", name, std::process::id()))
                .to_string_lossy()
                .into_owned();
            let guard = Self { base };
            guard.remove_files();
            guard
        }

        fn base(&self) -> &str {
            &self.base
        }

        fn remove_files(&self) {
            // Best effort: the files may legitimately not exist.
            let _ = fs::remove_file(format!("{}.idx", self.base));
            let _ = fs::remove_file(format!("{}.labels", self.base));
        }
    }

    impl Drop for TempIndex {
        fn drop(&mut self) {
            self.remove_files();
        }
    }

    #[test]
    fn test_filter_eq() {
        let filter = Filter::label_eq(0, 5);
        assert!(filter.matches(&[5, 10]));
        assert!(!filter.matches(&[4, 10]));
        assert!(!filter.matches(&[]));
    }

    #[test]
    fn test_filter_in() {
        let filter = Filter::label_in(0, vec![1, 3, 5]);
        assert!(filter.matches(&[1]));
        assert!(filter.matches(&[3]));
        assert!(filter.matches(&[5]));
        assert!(!filter.matches(&[2]));
    }

    #[test]
    fn test_filter_range() {
        let filter = Filter::label_range(0, 10, 20);
        assert!(filter.matches(&[10]));
        assert!(filter.matches(&[15]));
        assert!(filter.matches(&[20]));
        assert!(!filter.matches(&[9]));
        assert!(!filter.matches(&[21]));
    }

    #[test]
    fn test_filter_and() {
        let filter = Filter::and(vec![Filter::label_eq(0, 5), Filter::label_gt(1, 10)]);
        assert!(filter.matches(&[5, 15]));
        assert!(!filter.matches(&[5, 5]));
        assert!(!filter.matches(&[4, 15]));
    }

    #[test]
    fn test_filter_or() {
        let filter = Filter::or(vec![Filter::label_eq(0, 5), Filter::label_eq(0, 10)]);
        assert!(filter.matches(&[5]));
        assert!(filter.matches(&[10]));
        assert!(!filter.matches(&[7]));
    }

    #[test]
    fn test_filtered_search_basic() {
        let tmp = TempIndex::new("test_filtered");

        let vectors: Vec<Vec<f32>> = (0..100)
            .map(|i| vec![i as f32, (i * 2) as f32])
            .collect();
        let labels: Vec<Vec<u64>> = (0..100).map(|i| vec![i % 5]).collect();

        let index = FilteredDiskANN::<DistL2>::build(&vectors, &labels, tmp.base()).unwrap();

        let results = index.search(&[50.0, 100.0], 5, 32);
        assert_eq!(results.len(), 5);

        let filter = Filter::label_eq(0, 0);
        let results = index.search_filtered(&[50.0, 100.0], 5, 32, &filter);

        for id in &results {
            assert_eq!(labels[*id as usize][0], 0);
        }
    }

    #[test]
    fn test_filtered_search_selectivity() {
        let tmp = TempIndex::new("test_filtered_sel");

        let vectors: Vec<Vec<f32>> = (0..1000)
            .map(|i| vec![(i % 100) as f32, ((i / 100) * 10) as f32])
            .collect();
        let labels: Vec<Vec<u64>> = (0..1000).map(|i| vec![i % 10]).collect();

        let index = FilteredDiskANN::<DistL2>::build(&vectors, &labels, tmp.base()).unwrap();

        let filter = Filter::label_eq(0, 3);
        assert_eq!(index.count_matching(&filter), 100);

        let results = index.search_filtered(&[50.0, 50.0], 10, 64, &filter);
        assert!(results.len() <= 10);

        for id in &results {
            assert_eq!(labels[*id as usize][0], 3);
        }
    }

    #[test]
    fn test_filtered_persistence() {
        let tmp = TempIndex::new("test_filtered_persist");

        let vectors: Vec<Vec<f32>> = (0..50).map(|i| vec![i as f32, i as f32]).collect();
        let labels: Vec<Vec<u64>> = (0..50).map(|i| vec![i % 3, i]).collect();

        {
            // Build and drop, so `open` below reads purely from disk.
            let _index =
                FilteredDiskANN::<DistL2>::build(&vectors, &labels, tmp.base()).unwrap();
        }

        let index = FilteredDiskANN::<DistL2>::open(tmp.base()).unwrap();
        assert_eq!(index.num_vectors(), 50);
        assert_eq!(index.num_fields(), 2);

        let filter = Filter::label_eq(0, 1);
        let results = index.search_filtered(&[25.0, 25.0], 5, 32, &filter);
        for id in &results {
            assert_eq!(index.get_labels(*id as usize).unwrap()[0], 1);
        }
    }

    #[test]
    fn test_filtered_to_bytes_from_bytes() {
        let tmp = TempIndex::new("test_filtered_bytes_rt");

        let vectors: Vec<Vec<f32>> = (0..50).map(|i| vec![i as f32, i as f32]).collect();
        let labels: Vec<Vec<u64>> = (0..50).map(|i| vec![i % 3]).collect();

        let index = FilteredDiskANN::<DistL2>::build(&vectors, &labels, tmp.base()).unwrap();
        let bytes = index.to_bytes();

        let index2 = FilteredDiskANN::<DistL2>::from_bytes(bytes, DistL2 {}).unwrap();
        assert_eq!(index2.num_vectors(), 50);
        assert_eq!(index2.num_fields(), 1);

        let filter = Filter::label_eq(0, 1);
        let results = index2.search_filtered(&[25.0, 25.0], 5, 32, &filter);
        for id in &results {
            assert_eq!(index2.get_labels(*id as usize).unwrap()[0], 1);
        }
    }
}