diskann_rs/
filtered.rs

1//! # Filtered DiskANN Search
2//!
3//! Enables searching with metadata predicates, e.g.:
4//! "Find 10 nearest neighbors WHERE category = 'electronics' AND price < 100"
5//!
6//! ## Architecture
7//!
8//! ```text
9//! ┌────────────────────────────────────────────────────────────┐
10//! │                    FilteredDiskANN                         │
11//! ├────────────────────────────────────────────────────────────┤
12//! │  ┌──────────────────┐    ┌─────────────────────────────┐   │
13//! │  │   DiskANN Index  │    │     Metadata Store          │   │
14//! │  │   (vectors +     │    │   (labels per vector)       │   │
15//! │  │    graph)        │    │   - numeric fields          │   │
16//! │  │                  │    │   - string labels           │   │
17//! │  └──────────────────┘    └─────────────────────────────┘   │
18//! └────────────────────────────────────────────────────────────┘
19//!
//! Search: expand the beam through all nodes, but return only candidates that match the filter
21//! ```
22//!
23//! ## Usage
24//!
25//! ```ignore
26//! use anndists::dist::DistL2;
27//! use diskann_rs::{FilteredDiskANN, Filter};
28//!
29//! // Build index with metadata
30//! let vectors = vec![vec![0.0; 128]; 1000];
31//! let labels: Vec<Vec<u64>> = (0..1000).map(|i| vec![i % 10]).collect(); // 10 categories
32//!
33//! let index = FilteredDiskANN::<DistL2>::build(
34//!     &vectors,
35//!     &labels,
36//!     "filtered.db"
37//! ).unwrap();
38//!
39//! // Search with filter: only category 5
40//! let query = vec![0.0f32; 128];
41//! let filter = Filter::label_eq(0, 5); // field 0 == 5
42//! let results = index.search_filtered(&query, 10, 128, &filter);
43//! ```
44
45use crate::{DiskANN, DiskAnnError, DiskAnnParams};
46use anndists::prelude::Distance;
47use rayon::prelude::*;
48use serde::{Deserialize, Serialize};
49use std::collections::{BinaryHeap, HashSet};
50use std::cmp::{Ordering, Reverse};
51use std::fs::{File, OpenOptions};
52use std::io::{BufReader, BufWriter, Read, Write};
53
/// A predicate over a vector's label fields.
///
/// Field predicates read the `u64` stored at position `field` of a label row.
/// A row too short to contain that field never matches. Predicates can be
/// combined with [`Filter::And`] and [`Filter::Or`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Filter {
    /// Label at `field` equals `value`.
    LabelEq { field: usize, value: u64 },
    /// Label at `field` is one of `values`.
    LabelIn { field: usize, values: HashSet<u64> },
    /// Label at `field` is strictly less than `value`.
    LabelLt { field: usize, value: u64 },
    /// Label at `field` is strictly greater than `value`.
    LabelGt { field: usize, value: u64 },
    /// Label at `field` lies in the inclusive range `[min, max]` (matches nothing if `min > max`).
    LabelRange { field: usize, min: u64, max: u64 },
    /// Every sub-filter matches (an empty list matches everything).
    And(Vec<Filter>),
    /// At least one sub-filter matches (an empty list matches nothing).
    Or(Vec<Filter>),
    /// No filter: every label row matches.
    None,
}

impl Filter {
    /// Build a [`Filter::LabelEq`]: `labels[field] == value`.
    pub fn label_eq(field: usize, value: u64) -> Self {
        Self::LabelEq { field, value }
    }

    /// Build a [`Filter::LabelIn`]: `labels[field]` is any of `values`.
    pub fn label_in(field: usize, values: impl IntoIterator<Item = u64>) -> Self {
        let values: HashSet<u64> = values.into_iter().collect();
        Self::LabelIn { field, values }
    }

    /// Build a [`Filter::LabelLt`]: `labels[field] < value`.
    pub fn label_lt(field: usize, value: u64) -> Self {
        Self::LabelLt { field, value }
    }

    /// Build a [`Filter::LabelGt`]: `labels[field] > value`.
    pub fn label_gt(field: usize, value: u64) -> Self {
        Self::LabelGt { field, value }
    }

    /// Build a [`Filter::LabelRange`]: `min <= labels[field] <= max`.
    pub fn label_range(field: usize, min: u64, max: u64) -> Self {
        Self::LabelRange { field, min, max }
    }

    /// Conjunction of `filters`.
    pub fn and(filters: Vec<Filter>) -> Self {
        Self::And(filters)
    }

    /// Disjunction of `filters`.
    pub fn or(filters: Vec<Filter>) -> Self {
        Self::Or(filters)
    }

    /// Return `true` if the label row `labels` satisfies this filter.
    pub fn matches(&self, labels: &[u64]) -> bool {
        // The value stored at `field`, or `None` when the row is too short;
        // every field predicate is `false` for a missing field.
        let at = |field: &usize| labels.get(*field).copied();

        match self {
            Self::None => true,
            Self::And(children) => children.iter().all(|child| child.matches(labels)),
            Self::Or(children) => children.iter().any(|child| child.matches(labels)),
            Self::LabelEq { field, value } => at(field) == Some(*value),
            Self::LabelIn { field, values } => at(field).map_or(false, |v| values.contains(&v)),
            Self::LabelLt { field, value } => matches!(at(field), Some(v) if v < *value),
            Self::LabelGt { field, value } => matches!(at(field), Some(v) if v > *value),
            Self::LabelRange { field, min, max } => {
                matches!(at(field), Some(v) if (*min..=*max).contains(&v))
            }
        }
    }
}
138
139/// Metadata for filtered index
140#[derive(Serialize, Deserialize, Debug)]
141struct FilteredMetadata {
142    num_vectors: usize,
143    num_fields: usize,
144}
145
/// A graph node considered during filtered search, with its distance to the query.
///
/// Ordering is by distance using [`f32::total_cmp`], so NaN cannot corrupt heap
/// invariants. Ties are broken by `id`, which keeps `Ord` consistent with `Eq`
/// as the trait contracts require.
#[derive(Clone, Copy, Debug)]
struct Candidate {
    dist: f32,
    id: u32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        self.dist
            .total_cmp(&other.dist)
            .then_with(|| self.id.cmp(&other.id))
    }
}
169
170/// DiskANN index with metadata filtering support
171pub struct FilteredDiskANN<D>
172where
173    D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
174{
175    /// The underlying vector index
176    index: DiskANN<D>,
177    /// Labels for each vector: labels[vector_id] = [field0, field1, ...]
178    labels: Vec<Vec<u64>>,
179    /// Number of label fields per vector
180    num_fields: usize,
181    /// Path to labels file (kept for potential future persistence)
182    #[allow(dead_code)]
183    labels_path: String,
184}
185
186impl<D> FilteredDiskANN<D>
187where
188    D: Distance<f32> + Send + Sync + Copy + Clone + Default + 'static,
189{
190    /// Build a new filtered index
191    pub fn build(
192        vectors: &[Vec<f32>],
193        labels: &[Vec<u64>],
194        base_path: &str,
195    ) -> Result<Self, DiskAnnError> {
196        Self::build_with_params(vectors, labels, base_path, DiskAnnParams::default())
197    }
198
199    /// Build with custom parameters
200    pub fn build_with_params(
201        vectors: &[Vec<f32>],
202        labels: &[Vec<u64>],
203        base_path: &str,
204        params: DiskAnnParams,
205    ) -> Result<Self, DiskAnnError> {
206        if vectors.len() != labels.len() {
207            return Err(DiskAnnError::IndexError(format!(
208                "vectors.len() ({}) != labels.len() ({})",
209                vectors.len(),
210                labels.len()
211            )));
212        }
213
214        let num_fields = labels.first().map(|l| l.len()).unwrap_or(0);
215        for (i, l) in labels.iter().enumerate() {
216            if l.len() != num_fields {
217                return Err(DiskAnnError::IndexError(format!(
218                    "Label {} has {} fields, expected {}",
219                    i,
220                    l.len(),
221                    num_fields
222                )));
223            }
224        }
225
226        // Build the vector index
227        let index_path = format!("{}.idx", base_path);
228        let index = DiskANN::<D>::build_index_with_params(
229            vectors,
230            D::default(),
231            &index_path,
232            params,
233        )?;
234
235        // Save labels
236        let labels_path = format!("{}.labels", base_path);
237        Self::save_labels(&labels_path, labels, num_fields)?;
238
239        Ok(Self {
240            index,
241            labels: labels.to_vec(),
242            num_fields,
243            labels_path,
244        })
245    }
246
247    /// Open an existing filtered index
248    pub fn open(base_path: &str) -> Result<Self, DiskAnnError> {
249        let index_path = format!("{}.idx", base_path);
250        let labels_path = format!("{}.labels", base_path);
251
252        let index = DiskANN::<D>::open_index_default_metric(&index_path)?;
253        let (labels, num_fields) = Self::load_labels(&labels_path)?;
254
255        if labels.len() != index.num_vectors {
256            return Err(DiskAnnError::IndexError(format!(
257                "Labels count ({}) != index vectors ({})",
258                labels.len(),
259                index.num_vectors
260            )));
261        }
262
263        Ok(Self {
264            index,
265            labels,
266            num_fields,
267            labels_path,
268        })
269    }
270
271    fn save_labels(path: &str, labels: &[Vec<u64>], num_fields: usize) -> Result<(), DiskAnnError> {
272        let file = OpenOptions::new()
273            .create(true)
274            .write(true)
275            .truncate(true)
276            .open(path)?;
277        let mut writer = BufWriter::new(file);
278
279        let meta = FilteredMetadata {
280            num_vectors: labels.len(),
281            num_fields,
282        };
283        let meta_bytes = bincode::serialize(&meta)?;
284        writer.write_all(&(meta_bytes.len() as u64).to_le_bytes())?;
285        writer.write_all(&meta_bytes)?;
286
287        // Write labels as flat array
288        for label_vec in labels {
289            for &val in label_vec {
290                writer.write_all(&val.to_le_bytes())?;
291            }
292        }
293
294        writer.flush()?;
295        Ok(())
296    }
297
298    fn load_labels(path: &str) -> Result<(Vec<Vec<u64>>, usize), DiskAnnError> {
299        let file = File::open(path)?;
300        let mut reader = BufReader::new(file);
301
302        // Read metadata
303        let mut len_buf = [0u8; 8];
304        reader.read_exact(&mut len_buf)?;
305        let meta_len = u64::from_le_bytes(len_buf) as usize;
306
307        let mut meta_bytes = vec![0u8; meta_len];
308        reader.read_exact(&mut meta_bytes)?;
309        let meta: FilteredMetadata = bincode::deserialize(&meta_bytes)?;
310
311        // Read labels
312        let mut labels = Vec::with_capacity(meta.num_vectors);
313        let mut val_buf = [0u8; 8];
314
315        for _ in 0..meta.num_vectors {
316            let mut label_vec = Vec::with_capacity(meta.num_fields);
317            for _ in 0..meta.num_fields {
318                reader.read_exact(&mut val_buf)?;
319                label_vec.push(u64::from_le_bytes(val_buf));
320            }
321            labels.push(label_vec);
322        }
323
324        Ok((labels, meta.num_fields))
325    }
326}
327
328impl<D> FilteredDiskANN<D>
329where
330    D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
331{
332    /// Search with a filter predicate
333    ///
334    /// Uses an expanded beam search that skips non-matching candidates
335    /// but continues exploring the graph to find matches.
336    pub fn search_filtered(
337        &self,
338        query: &[f32],
339        k: usize,
340        beam_width: usize,
341        filter: &Filter,
342    ) -> Vec<u32> {
343        self.search_filtered_with_dists(query, k, beam_width, filter)
344            .into_iter()
345            .map(|(id, _)| id)
346            .collect()
347    }
348
349    /// Search with filter, returning distances
350    pub fn search_filtered_with_dists(
351        &self,
352        query: &[f32],
353        k: usize,
354        beam_width: usize,
355        filter: &Filter,
356    ) -> Vec<(u32, f32)> {
357        // For unfiltered search, use the fast path
358        if matches!(filter, Filter::None) {
359            return self.index.search_with_dists(query, k, beam_width);
360        }
361
362        // Filtered search: we need to explore more of the graph
363        // Use larger internal beam to find enough matching candidates
364        let expanded_beam = (beam_width * 4).max(k * 10);
365
366        let mut visited = HashSet::new();
367        let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new();
368        let mut working_set: BinaryHeap<Candidate> = BinaryHeap::new();
369        let mut results: Vec<(u32, f32)> = Vec::with_capacity(k);
370
371        // Seed from medoid
372        let start_dist = self.distance_to(query, self.index.medoid_id as usize);
373        let start = Candidate {
374            dist: start_dist,
375            id: self.index.medoid_id,
376        };
377        frontier.push(Reverse(start));
378        working_set.push(start);
379        visited.insert(self.index.medoid_id);
380
381        // Check if medoid matches filter
382        if filter.matches(&self.labels[self.index.medoid_id as usize]) {
383            results.push((self.index.medoid_id, start_dist));
384        }
385
386        // Expand until we have k results or exhausted search
387        let mut iterations = 0;
388        let max_iterations = expanded_beam * 2;
389
390        while let Some(Reverse(best)) = frontier.peek().copied() {
391            iterations += 1;
392            if iterations > max_iterations {
393                break;
394            }
395
396            // Early termination if we have enough results and best candidate
397            // can't improve our worst result
398            if results.len() >= k {
399                if let Some((_, worst_dist)) = results.last() {
400                    if best.dist > *worst_dist * 1.5 {
401                        break;
402                    }
403                }
404            }
405
406            if working_set.len() >= expanded_beam {
407                if let Some(worst) = working_set.peek() {
408                    if best.dist >= worst.dist {
409                        break;
410                    }
411                }
412            }
413
414            let Reverse(current) = frontier.pop().unwrap();
415
416            // Explore neighbors
417            for &nb in self.get_neighbors(current.id) {
418                if nb == u32::MAX {
419                    continue;
420                }
421                if !visited.insert(nb) {
422                    continue;
423                }
424
425                let d = self.distance_to(query, nb as usize);
426                let cand = Candidate { dist: d, id: nb };
427
428                // Always add to working set for graph exploration
429                if working_set.len() < expanded_beam {
430                    working_set.push(cand);
431                    frontier.push(Reverse(cand));
432                } else if d < working_set.peek().unwrap().dist {
433                    working_set.pop();
434                    working_set.push(cand);
435                    frontier.push(Reverse(cand));
436                }
437
438                // Check filter for results
439                if filter.matches(&self.labels[nb as usize]) {
440                    // Insert into results maintaining sorted order
441                    let pos = results
442                        .iter()
443                        .position(|(_, dist)| d < *dist)
444                        .unwrap_or(results.len());
445
446                    if pos < k {
447                        results.insert(pos, (nb, d));
448                        if results.len() > k {
449                            results.pop();
450                        }
451                    }
452                }
453            }
454        }
455
456        results
457    }
458
459    /// Parallel batch filtered search
460    pub fn search_filtered_batch(
461        &self,
462        queries: &[Vec<f32>],
463        k: usize,
464        beam_width: usize,
465        filter: &Filter,
466    ) -> Vec<Vec<u32>> {
467        queries
468            .par_iter()
469            .map(|q| self.search_filtered(q, k, beam_width, filter))
470            .collect()
471    }
472
473    /// Unfiltered search (delegates to base index)
474    pub fn search(&self, query: &[f32], k: usize, beam_width: usize) -> Vec<u32> {
475        self.index.search(query, k, beam_width)
476    }
477
478    /// Get labels for a vector
479    pub fn get_labels(&self, id: usize) -> Option<&[u64]> {
480        self.labels.get(id).map(|v| v.as_slice())
481    }
482
483    /// Get the underlying index
484    pub fn inner(&self) -> &DiskANN<D> {
485        &self.index
486    }
487
488    /// Number of vectors in the index
489    pub fn num_vectors(&self) -> usize {
490        self.index.num_vectors
491    }
492
493    /// Number of label fields per vector
494    pub fn num_fields(&self) -> usize {
495        self.num_fields
496    }
497
498    /// Count vectors matching a filter (useful for selectivity estimation)
499    pub fn count_matching(&self, filter: &Filter) -> usize {
500        self.labels.iter().filter(|l| filter.matches(l)).count()
501    }
502
503    fn get_neighbors(&self, node_id: u32) -> &[u32] {
504        // Access internal neighbors through the index
505        let offset = self.index.adjacency_offset
506            + (node_id as u64 * self.index.max_degree as u64 * 4);
507        let start = offset as usize;
508        let end = start + (self.index.max_degree * 4);
509        let bytes = &self.index.mmap[start..end];
510        bytemuck::cast_slice(bytes)
511    }
512
513    fn distance_to(&self, query: &[f32], idx: usize) -> f32 {
514        let offset = self.index.vectors_offset + (idx as u64 * self.index.dim as u64 * 4);
515        let start = offset as usize;
516        let end = start + (self.index.dim * 4);
517        let bytes = &self.index.mmap[start..end];
518        let vector: &[f32] = bytemuck::cast_slice(bytes);
519        self.index.dist.eval(query, vector)
520    }
521}
522
#[cfg(test)]
mod tests {
    use super::*;
    use anndists::dist::DistL2;
    use std::fs;

    /// Scratch base path for an on-disk filtered index.
    ///
    /// It deletes `<base>.idx` and `<base>.labels` on creation and again on
    /// drop, so artifacts are removed even when an assertion panics. The files
    /// live in the system temp dir, suffixed with the process id. They stay out
    /// of the working tree and do not collide with a concurrent test run.
    struct TempBase(String);

    impl TempBase {
        fn new(name: &str) -> Self {
            let path = std::env::temp_dir().join(format!("{}_{}", name, std::process::id()));
            let base = TempBase(
                path.to_str()
                    .expect("temp dir path is valid UTF-8")
                    .to_owned(),
            );
            base.remove_files();
            base
        }

        fn path(&self) -> &str {
            &self.0
        }

        fn remove_files(&self) {
            let _ = fs::remove_file(format!("{}.idx", self.0));
            let _ = fs::remove_file(format!("{}.labels", self.0));
        }
    }

    impl Drop for TempBase {
        fn drop(&mut self) {
            self.remove_files();
        }
    }

    #[test]
    fn test_filter_eq() {
        let filter = Filter::label_eq(0, 5);
        assert!(filter.matches(&[5, 10]));
        assert!(!filter.matches(&[4, 10]));
        assert!(!filter.matches(&[]));
    }

    #[test]
    fn test_filter_in() {
        let filter = Filter::label_in(0, vec![1, 3, 5]);
        assert!(filter.matches(&[1]));
        assert!(filter.matches(&[3]));
        assert!(filter.matches(&[5]));
        assert!(!filter.matches(&[2]));
    }

    #[test]
    fn test_filter_range() {
        let filter = Filter::label_range(0, 10, 20);
        assert!(filter.matches(&[10]));
        assert!(filter.matches(&[15]));
        assert!(filter.matches(&[20]));
        assert!(!filter.matches(&[9]));
        assert!(!filter.matches(&[21]));
    }

    #[test]
    fn test_filter_lt_gt() {
        let lt = Filter::label_lt(0, 10);
        assert!(lt.matches(&[9]));
        assert!(!lt.matches(&[10]));

        let gt = Filter::label_gt(0, 10);
        assert!(gt.matches(&[11]));
        assert!(!gt.matches(&[10]));
    }

    #[test]
    fn test_filter_missing_field_never_matches() {
        assert!(!Filter::label_lt(3, u64::MAX).matches(&[0, 0]));
        assert!(!Filter::label_gt(3, 0).matches(&[1, 1]));
        assert!(!Filter::label_in(1, vec![0]).matches(&[0]));
        assert!(!Filter::label_range(1, 0, u64::MAX).matches(&[0]));
    }

    #[test]
    fn test_filter_none_and_empty_combinators() {
        assert!(Filter::None.matches(&[]));
        // An empty conjunction is vacuously true; an empty disjunction is false.
        assert!(Filter::and(vec![]).matches(&[1]));
        assert!(!Filter::or(vec![]).matches(&[1]));
    }

    #[test]
    fn test_filter_and() {
        let filter = Filter::and(vec![
            Filter::label_eq(0, 5),
            Filter::label_gt(1, 10),
        ]);
        assert!(filter.matches(&[5, 15]));
        assert!(!filter.matches(&[5, 5]));
        assert!(!filter.matches(&[4, 15]));
    }

    #[test]
    fn test_filter_or() {
        let filter = Filter::or(vec![
            Filter::label_eq(0, 5),
            Filter::label_eq(0, 10),
        ]);
        assert!(filter.matches(&[5]));
        assert!(filter.matches(&[10]));
        assert!(!filter.matches(&[7]));
    }

    #[test]
    fn test_filtered_search_basic() {
        let tmp = TempBase::new("test_filtered");

        // Create vectors with categories.
        let vectors: Vec<Vec<f32>> = (0..100)
            .map(|i| vec![i as f32, (i * 2) as f32])
            .collect();

        // Labels: [category] where category = i % 5.
        let labels: Vec<Vec<u64>> = (0..100).map(|i| vec![i % 5]).collect();

        let index = FilteredDiskANN::<DistL2>::build(&vectors, &labels, tmp.path()).unwrap();

        // Search without filter.
        let results = index.search(&[50.0, 100.0], 5, 32);
        assert_eq!(results.len(), 5);

        // Search with filter: only category 0.
        let filter = Filter::label_eq(0, 0);
        let results = index.search_filtered(&[50.0, 100.0], 5, 32, &filter);
        assert!(results.len() <= 5);
        for id in &results {
            assert_eq!(labels[*id as usize][0], 0);
        }

        // Results with distances are reported closest first.
        let with_dists = index.search_filtered_with_dists(&[50.0, 100.0], 5, 32, &filter);
        assert!(with_dists.windows(2).all(|w| w[0].1 <= w[1].1));
    }

    #[test]
    fn test_filtered_search_selectivity() {
        let tmp = TempBase::new("test_filtered_sel");

        // 1000 vectors, 10 categories.
        let vectors: Vec<Vec<f32>> = (0..1000)
            .map(|i| vec![(i % 100) as f32, ((i / 100) * 10) as f32])
            .collect();

        let labels: Vec<Vec<u64>> = (0..1000)
            .map(|i| vec![i % 10]) // ~100 per category
            .collect();

        let index = FilteredDiskANN::<DistL2>::build(&vectors, &labels, tmp.path()).unwrap();

        // Verify count.
        let filter = Filter::label_eq(0, 3);
        assert_eq!(index.count_matching(&filter), 100);

        // Search for category 3.
        let results = index.search_filtered(&[50.0, 50.0], 10, 64, &filter);
        assert!(results.len() <= 10);

        for id in &results {
            assert_eq!(labels[*id as usize][0], 3);
        }
    }

    #[test]
    fn test_filtered_persistence() {
        let tmp = TempBase::new("test_filtered_persist");

        let vectors: Vec<Vec<f32>> = (0..50).map(|i| vec![i as f32, i as f32]).collect();
        let labels: Vec<Vec<u64>> = (0..50).map(|i| vec![i % 3, i]).collect();

        {
            let _index = FilteredDiskANN::<DistL2>::build(&vectors, &labels, tmp.path()).unwrap();
        }

        // Reopen.
        let index = FilteredDiskANN::<DistL2>::open(tmp.path()).unwrap();
        assert_eq!(index.num_vectors(), 50);
        assert_eq!(index.num_fields(), 2);

        let filter = Filter::label_eq(0, 1);
        let results = index.search_filtered(&[25.0, 25.0], 5, 32, &filter);
        for id in &results {
            assert_eq!(index.get_labels(*id as usize).unwrap()[0], 1);
        }
    }
}