1use crate::{beam_search, BeamSearchConfig, GraphIndex, DiskANN, DiskAnnError, DiskAnnParams};
46use anndists::prelude::Distance;
47use rayon::prelude::*;
48use serde::{Deserialize, Serialize};
49use std::collections::HashSet;
50use std::fs::{File, OpenOptions};
51use std::io::{BufReader, BufWriter, Read, Write};
52use std::sync::Arc;
53
54pub(crate) fn filtered_search(
59 graph: &dyn GraphIndex,
60 labels: &[Vec<u64>],
61 start_ids: &[u32],
62 query: &[f32],
63 k: usize,
64 beam_width: usize,
65 filter: &Filter,
66) -> Vec<(u32, f32)> {
67 if matches!(filter, Filter::None) {
68 return beam_search(
69 start_ids,
70 beam_width,
71 k,
72 |id| graph.distance_to(query, id),
73 |id| graph.get_neighbors(id),
74 |_| true,
75 BeamSearchConfig::default(),
76 );
77 }
78
79 let expanded_beam = (beam_width * 4).max(k * 10);
80
81 beam_search(
82 start_ids,
83 beam_width,
84 k,
85 |id| graph.distance_to(query, id),
86 |id| graph.get_neighbors(id),
87 |id| {
88 let idx = id as usize;
89 if idx < labels.len() {
90 filter.matches(&labels[idx])
91 } else {
92 false
93 }
94 },
95 BeamSearchConfig {
96 expanded_beam: Some(expanded_beam),
97 max_iterations: Some(expanded_beam * 2),
98 early_term_factor: Some(1.5),
99 },
100 )
101}
102
/// Predicate over a vector's label row (a slice of `u64` values addressed by field index).
///
/// Every field-based variant evaluates to `false` when the row has no value at
/// the requested field index.
#[derive(Clone, Debug)]
pub enum Filter {
    /// The value at `field` equals `value`.
    LabelEq { field: usize, value: u64 },
    /// The value at `field` is a member of `values`.
    LabelIn { field: usize, values: HashSet<u64> },
    /// The value at `field` is strictly less than `value`.
    LabelLt { field: usize, value: u64 },
    /// The value at `field` is strictly greater than `value`.
    LabelGt { field: usize, value: u64 },
    /// The value at `field` lies in the inclusive range `min..=max`.
    LabelRange { field: usize, min: u64, max: u64 },
    /// Every inner filter matches (an empty list matches everything).
    And(Vec<Filter>),
    /// At least one inner filter matches (an empty list matches nothing).
    Or(Vec<Filter>),
    /// No restriction: matches every row.
    None,
}

impl Filter {
    /// Creates a [`Filter::LabelEq`] on `field`.
    pub fn label_eq(field: usize, value: u64) -> Self {
        Self::LabelEq { field, value }
    }

    /// Creates a [`Filter::LabelIn`] on `field`, collecting `values` into a set.
    pub fn label_in(field: usize, values: impl IntoIterator<Item = u64>) -> Self {
        let values: HashSet<u64> = values.into_iter().collect();
        Self::LabelIn { field, values }
    }

    /// Creates a [`Filter::LabelLt`] on `field`.
    pub fn label_lt(field: usize, value: u64) -> Self {
        Self::LabelLt { field, value }
    }

    /// Creates a [`Filter::LabelGt`] on `field`.
    pub fn label_gt(field: usize, value: u64) -> Self {
        Self::LabelGt { field, value }
    }

    /// Creates a [`Filter::LabelRange`] on `field` covering `min..=max`.
    pub fn label_range(field: usize, min: u64, max: u64) -> Self {
        Self::LabelRange { field, min, max }
    }

    /// Creates a [`Filter::And`] over `filters`.
    pub fn and(filters: Vec<Filter>) -> Self {
        Self::And(filters)
    }

    /// Creates a [`Filter::Or`] over `filters`.
    pub fn or(filters: Vec<Filter>) -> Self {
        Self::Or(filters)
    }

    /// Returns `true` when the label row `labels` satisfies this filter.
    pub fn matches(&self, labels: &[u64]) -> bool {
        // A field index past the end of the row yields `None`, which no
        // field-based variant accepts.
        let value_at = |field: &usize| labels.get(*field).copied();

        match self {
            Self::None => true,
            Self::LabelEq { field, value } => value_at(field) == Some(*value),
            Self::LabelIn { field, values } => match value_at(field) {
                Some(v) => values.contains(&v),
                None => false,
            },
            Self::LabelLt { field, value } => matches!(value_at(field), Some(v) if v < *value),
            Self::LabelGt { field, value } => matches!(value_at(field), Some(v) if v > *value),
            Self::LabelRange { field, min, max } => {
                matches!(value_at(field), Some(v) if *min <= v && v <= *max)
            }
            Self::And(parts) => parts.iter().all(|part| part.matches(labels)),
            Self::Or(parts) => parts.iter().any(|part| part.matches(labels)),
        }
    }
}
187
188#[derive(Serialize, Deserialize, Debug)]
190struct FilteredMetadata {
191 num_vectors: usize,
192 num_fields: usize,
193}
194
195pub struct FilteredDiskANN<D>
197where
198 D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
199{
200 index: DiskANN<D>,
202 labels: Vec<Vec<u64>>,
204 num_fields: usize,
206 #[allow(dead_code)]
208 labels_path: String,
209}
210
211impl<D> FilteredDiskANN<D>
212where
213 D: Distance<f32> + Send + Sync + Copy + Clone + Default + 'static,
214{
215 pub fn build(
217 vectors: &[Vec<f32>],
218 labels: &[Vec<u64>],
219 base_path: &str,
220 ) -> Result<Self, DiskAnnError> {
221 Self::build_with_params(vectors, labels, base_path, DiskAnnParams::default())
222 }
223
224 pub fn build_with_params(
226 vectors: &[Vec<f32>],
227 labels: &[Vec<u64>],
228 base_path: &str,
229 params: DiskAnnParams,
230 ) -> Result<Self, DiskAnnError> {
231 if vectors.len() != labels.len() {
232 return Err(DiskAnnError::IndexError(format!(
233 "vectors.len() ({}) != labels.len() ({})",
234 vectors.len(),
235 labels.len()
236 )));
237 }
238
239 let num_fields = labels.first().map(|l| l.len()).unwrap_or(0);
240 for (i, l) in labels.iter().enumerate() {
241 if l.len() != num_fields {
242 return Err(DiskAnnError::IndexError(format!(
243 "Label {} has {} fields, expected {}",
244 i,
245 l.len(),
246 num_fields
247 )));
248 }
249 }
250
251 let index_path = format!("{}.idx", base_path);
253 let index = DiskANN::<D>::build_index_with_params(
254 vectors,
255 D::default(),
256 &index_path,
257 params,
258 )?;
259
260 let labels_path = format!("{}.labels", base_path);
262 Self::save_labels(&labels_path, labels, num_fields)?;
263
264 Ok(Self {
265 index,
266 labels: labels.to_vec(),
267 num_fields,
268 labels_path,
269 })
270 }
271
272 pub fn open(base_path: &str) -> Result<Self, DiskAnnError> {
274 let index_path = format!("{}.idx", base_path);
275 let labels_path = format!("{}.labels", base_path);
276
277 let index = DiskANN::<D>::open_index_default_metric(&index_path)?;
278 let (labels, num_fields) = Self::load_labels(&labels_path)?;
279
280 if labels.len() != index.num_vectors {
281 return Err(DiskAnnError::IndexError(format!(
282 "Labels count ({}) != index vectors ({})",
283 labels.len(),
284 index.num_vectors
285 )));
286 }
287
288 Ok(Self {
289 index,
290 labels,
291 num_fields,
292 labels_path,
293 })
294 }
295
296 pub fn to_bytes(&self) -> Vec<u8> {
300 let index_bytes = self.index.to_bytes();
301 let labels_bytes = Self::serialize_labels(&self.labels, self.num_fields);
302 let mut out = Vec::with_capacity(8 + index_bytes.len() + labels_bytes.len());
303 out.extend_from_slice(&(index_bytes.len() as u64).to_le_bytes());
304 out.extend_from_slice(&index_bytes);
305 out.extend_from_slice(&labels_bytes);
306 out
307 }
308
309 pub fn from_bytes(bytes: Vec<u8>, dist: D) -> Result<Self, DiskAnnError> {
311 if bytes.len() < 8 {
312 return Err(DiskAnnError::IndexError("Buffer too small".into()));
313 }
314 let index_len = u64::from_le_bytes(bytes[0..8].try_into().unwrap()) as usize;
315 if bytes.len() < 8 + index_len {
316 return Err(DiskAnnError::IndexError("Buffer too small for index data".into()));
317 }
318 let index_bytes = bytes[8..8 + index_len].to_vec();
319 let labels_bytes = &bytes[8 + index_len..];
320
321 let index = DiskANN::<D>::from_bytes(index_bytes, dist)?;
322 let (labels, num_fields) = Self::deserialize_labels(labels_bytes)?;
323
324 if labels.len() != index.num_vectors {
325 return Err(DiskAnnError::IndexError(format!(
326 "Labels count ({}) != index vectors ({})",
327 labels.len(),
328 index.num_vectors
329 )));
330 }
331
332 Ok(Self {
333 index,
334 labels,
335 num_fields,
336 labels_path: String::new(),
337 })
338 }
339
340 pub fn from_shared_bytes(bytes: Arc<[u8]>, dist: D) -> Result<Self, DiskAnnError> {
342 if bytes.len() < 8 {
344 return Err(DiskAnnError::IndexError("Buffer too small".into()));
345 }
346 let index_len = u64::from_le_bytes(bytes[0..8].try_into().unwrap()) as usize;
347 if bytes.len() < 8 + index_len {
348 return Err(DiskAnnError::IndexError("Buffer too small for index data".into()));
349 }
350 let index_bytes = bytes[8..8 + index_len].to_vec();
351 let labels_bytes = &bytes[8 + index_len..];
352
353 let index = DiskANN::<D>::from_bytes(index_bytes, dist)?;
354 let (labels, num_fields) = Self::deserialize_labels(labels_bytes)?;
355
356 if labels.len() != index.num_vectors {
357 return Err(DiskAnnError::IndexError(format!(
358 "Labels count ({}) != index vectors ({})",
359 labels.len(),
360 index.num_vectors
361 )));
362 }
363
364 Ok(Self {
365 index,
366 labels,
367 num_fields,
368 labels_path: String::new(),
369 })
370 }
371
372 fn serialize_labels(labels: &[Vec<u64>], num_fields: usize) -> Vec<u8> {
373 let meta = FilteredMetadata {
374 num_vectors: labels.len(),
375 num_fields,
376 };
377 let meta_bytes = bincode::serialize(&meta).unwrap();
378 let mut out = Vec::new();
379 out.extend_from_slice(&(meta_bytes.len() as u64).to_le_bytes());
380 out.extend_from_slice(&meta_bytes);
381 for label_vec in labels {
382 for &val in label_vec {
383 out.extend_from_slice(&val.to_le_bytes());
384 }
385 }
386 out
387 }
388
389 fn deserialize_labels(bytes: &[u8]) -> Result<(Vec<Vec<u64>>, usize), DiskAnnError> {
390 if bytes.len() < 8 {
391 return Err(DiskAnnError::IndexError("Labels buffer too small".into()));
392 }
393 let meta_len = u64::from_le_bytes(bytes[0..8].try_into().unwrap()) as usize;
394 if bytes.len() < 8 + meta_len {
395 return Err(DiskAnnError::IndexError("Labels buffer too small for metadata".into()));
396 }
397 let meta: FilteredMetadata = bincode::deserialize(&bytes[8..8 + meta_len])?;
398
399 let data = &bytes[8 + meta_len..];
400 let mut labels = Vec::with_capacity(meta.num_vectors);
401 let mut offset = 0;
402 for _ in 0..meta.num_vectors {
403 let mut label_vec = Vec::with_capacity(meta.num_fields);
404 for _ in 0..meta.num_fields {
405 if offset + 8 > data.len() {
406 return Err(DiskAnnError::IndexError("Labels data truncated".into()));
407 }
408 label_vec.push(u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()));
409 offset += 8;
410 }
411 labels.push(label_vec);
412 }
413
414 Ok((labels, meta.num_fields))
415 }
416
417 fn save_labels(path: &str, labels: &[Vec<u64>], num_fields: usize) -> Result<(), DiskAnnError> {
418 let file = OpenOptions::new()
419 .create(true)
420 .write(true)
421 .truncate(true)
422 .open(path)?;
423 let mut writer = BufWriter::new(file);
424
425 let meta = FilteredMetadata {
426 num_vectors: labels.len(),
427 num_fields,
428 };
429 let meta_bytes = bincode::serialize(&meta)?;
430 writer.write_all(&(meta_bytes.len() as u64).to_le_bytes())?;
431 writer.write_all(&meta_bytes)?;
432
433 for label_vec in labels {
435 for &val in label_vec {
436 writer.write_all(&val.to_le_bytes())?;
437 }
438 }
439
440 writer.flush()?;
441 Ok(())
442 }
443
444 fn load_labels(path: &str) -> Result<(Vec<Vec<u64>>, usize), DiskAnnError> {
445 let file = File::open(path)?;
446 let mut reader = BufReader::new(file);
447
448 let mut len_buf = [0u8; 8];
450 reader.read_exact(&mut len_buf)?;
451 let meta_len = u64::from_le_bytes(len_buf) as usize;
452
453 let mut meta_bytes = vec![0u8; meta_len];
454 reader.read_exact(&mut meta_bytes)?;
455 let meta: FilteredMetadata = bincode::deserialize(&meta_bytes)?;
456
457 let mut labels = Vec::with_capacity(meta.num_vectors);
459 let mut val_buf = [0u8; 8];
460
461 for _ in 0..meta.num_vectors {
462 let mut label_vec = Vec::with_capacity(meta.num_fields);
463 for _ in 0..meta.num_fields {
464 reader.read_exact(&mut val_buf)?;
465 label_vec.push(u64::from_le_bytes(val_buf));
466 }
467 labels.push(label_vec);
468 }
469
470 Ok((labels, meta.num_fields))
471 }
472}
473
474impl<D> FilteredDiskANN<D>
475where
476 D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
477{
478 pub fn search_filtered(
483 &self,
484 query: &[f32],
485 k: usize,
486 beam_width: usize,
487 filter: &Filter,
488 ) -> Vec<u32> {
489 self.search_filtered_with_dists(query, k, beam_width, filter)
490 .into_iter()
491 .map(|(id, _)| id)
492 .collect()
493 }
494
495 pub fn search_filtered_with_dists(
497 &self,
498 query: &[f32],
499 k: usize,
500 beam_width: usize,
501 filter: &Filter,
502 ) -> Vec<(u32, f32)> {
503 filtered_search(
504 &self.index,
505 &self.labels,
506 &[self.index.medoid_id],
507 query,
508 k,
509 beam_width,
510 filter,
511 )
512 }
513
514 pub fn search_filtered_batch(
516 &self,
517 queries: &[Vec<f32>],
518 k: usize,
519 beam_width: usize,
520 filter: &Filter,
521 ) -> Vec<Vec<u32>> {
522 queries
523 .par_iter()
524 .map(|q| self.search_filtered(q, k, beam_width, filter))
525 .collect()
526 }
527
528 pub fn search(&self, query: &[f32], k: usize, beam_width: usize) -> Vec<u32> {
530 self.index.search(query, k, beam_width)
531 }
532
533 pub fn get_labels(&self, id: usize) -> Option<&[u64]> {
535 self.labels.get(id).map(|v| v.as_slice())
536 }
537
538 pub fn inner(&self) -> &DiskANN<D> {
540 &self.index
541 }
542
543 pub fn num_vectors(&self) -> usize {
545 self.index.num_vectors
546 }
547
548 pub fn num_fields(&self) -> usize {
550 self.num_fields
551 }
552
553 pub fn count_matching(&self, filter: &Filter) -> usize {
555 self.labels.iter().filter(|l| filter.matches(l)).count()
556 }
557
558}
559
#[cfg(test)]
mod tests {
    use super::*;
    use anndists::dist::DistL2;
    use std::fs;

    /// Best-effort removal of the `.idx` / `.labels` pair belonging to `base_path`.
    fn remove_index_files(base_path: &str) {
        for ext in &["idx", "labels"] {
            let _ = fs::remove_file(format!("{}.{}", base_path, ext));
        }
    }

    #[test]
    fn test_filter_eq() {
        let eq = Filter::label_eq(0, 5);
        assert!(eq.matches(&[5, 10]));
        assert!(!eq.matches(&[4, 10]));
        assert!(!eq.matches(&[]));
    }

    #[test]
    fn test_filter_in() {
        let member = Filter::label_in(0, vec![1, 3, 5]);
        assert!(member.matches(&[1]));
        assert!(member.matches(&[3]));
        assert!(member.matches(&[5]));
        assert!(!member.matches(&[2]));
    }

    #[test]
    fn test_filter_range() {
        let range = Filter::label_range(0, 10, 20);
        assert!(range.matches(&[10]));
        assert!(range.matches(&[15]));
        assert!(range.matches(&[20]));
        assert!(!range.matches(&[9]));
        assert!(!range.matches(&[21]));
    }

    #[test]
    fn test_filter_and() {
        let both = Filter::and(vec![Filter::label_eq(0, 5), Filter::label_gt(1, 10)]);
        assert!(both.matches(&[5, 15]));
        assert!(!both.matches(&[5, 5]));
        assert!(!both.matches(&[4, 15]));
    }

    #[test]
    fn test_filter_or() {
        let either = Filter::or(vec![Filter::label_eq(0, 5), Filter::label_eq(0, 10)]);
        assert!(either.matches(&[5]));
        assert!(either.matches(&[10]));
        assert!(!either.matches(&[7]));
    }

    #[test]
    fn test_filtered_search_basic() {
        let base_path = "test_filtered";
        remove_index_files(base_path);

        let vectors: Vec<Vec<f32>> = (0..100)
            .map(|i| vec![i as f32, (i * 2) as f32])
            .collect();
        let labels: Vec<Vec<u64>> = (0..100).map(|i| vec![i % 5]).collect();

        let index = FilteredDiskANN::<DistL2>::build(&vectors, &labels, base_path).unwrap();

        // Unfiltered search returns the requested number of hits.
        let unfiltered = index.search(&[50.0, 100.0], 5, 32);
        assert_eq!(unfiltered.len(), 5);

        // Every filtered hit must carry the requested label.
        let filter = Filter::label_eq(0, 0);
        let filtered = index.search_filtered(&[50.0, 100.0], 5, 32, &filter);
        for id in &filtered {
            assert_eq!(labels[*id as usize][0], 0);
        }

        remove_index_files(base_path);
    }

    #[test]
    fn test_filtered_search_selectivity() {
        let base_path = "test_filtered_sel";
        remove_index_files(base_path);

        let vectors: Vec<Vec<f32>> = (0..1000)
            .map(|i| vec![(i % 100) as f32, ((i / 100) * 10) as f32])
            .collect();
        let labels: Vec<Vec<u64>> = (0..1000).map(|i| vec![i % 10]).collect();

        let index = FilteredDiskANN::<DistL2>::build(&vectors, &labels, base_path).unwrap();

        let filter = Filter::label_eq(0, 3);
        assert_eq!(index.count_matching(&filter), 100);

        let hits = index.search_filtered(&[50.0, 50.0], 10, 64, &filter);
        assert!(hits.len() <= 10);
        for id in &hits {
            assert_eq!(labels[*id as usize][0], 3);
        }

        remove_index_files(base_path);
    }

    #[test]
    fn test_filtered_persistence() {
        let base_path = "test_filtered_persist";
        remove_index_files(base_path);

        let vectors: Vec<Vec<f32>> = (0..50).map(|i| vec![i as f32, i as f32]).collect();
        let labels: Vec<Vec<u64>> = (0..50).map(|i| vec![i % 3, i]).collect();

        // Build, then drop the instance before reopening from disk.
        drop(FilteredDiskANN::<DistL2>::build(&vectors, &labels, base_path).unwrap());

        let index = FilteredDiskANN::<DistL2>::open(base_path).unwrap();
        assert_eq!(index.num_vectors(), 50);
        assert_eq!(index.num_fields(), 2);

        let filter = Filter::label_eq(0, 1);
        let hits = index.search_filtered(&[25.0, 25.0], 5, 32, &filter);
        for id in &hits {
            assert_eq!(index.get_labels(*id as usize).unwrap()[0], 1);
        }

        remove_index_files(base_path);
    }

    #[test]
    fn test_filtered_to_bytes_from_bytes() {
        let base_path = "test_filtered_bytes_rt";
        remove_index_files(base_path);

        let vectors: Vec<Vec<f32>> = (0..50).map(|i| vec![i as f32, i as f32]).collect();
        let labels: Vec<Vec<u64>> = (0..50).map(|i| vec![i % 3]).collect();

        let index = FilteredDiskANN::<DistL2>::build(&vectors, &labels, base_path).unwrap();
        let bytes = index.to_bytes();

        let restored = FilteredDiskANN::<DistL2>::from_bytes(bytes, DistL2 {}).unwrap();
        assert_eq!(restored.num_vectors(), 50);
        assert_eq!(restored.num_fields(), 1);

        let filter = Filter::label_eq(0, 1);
        let hits = restored.search_filtered(&[25.0, 25.0], 5, 32, &filter);
        for id in &hits {
            assert_eq!(restored.get_labels(*id as usize).unwrap()[0], 1);
        }

        remove_index_files(base_path);
    }
}