Skip to main content

diskann_rs/
filtered.rs

1//! # Filtered DiskANN Search
2//!
3//! Enables searching with metadata predicates, e.g.:
4//! "Find 10 nearest neighbors WHERE category = 'electronics' AND price < 100"
5//!
6//! ## Architecture
7//!
8//! ```text
9//! ┌────────────────────────────────────────────────────────────┐
10//! │                    FilteredDiskANN                         │
11//! ├────────────────────────────────────────────────────────────┤
12//! │  ┌──────────────────┐    ┌─────────────────────────────┐   │
13//! │  │   DiskANN Index  │    │     Metadata Store          │   │
14//! │  │   (vectors +     │    │   (labels per vector)       │   │
15//! │  │    graph)        │    │   - numeric fields          │   │
16//! │  │                  │    │   - string labels           │   │
17//! │  └──────────────────┘    └─────────────────────────────┘   │
18//! └────────────────────────────────────────────────────────────┘
19//!
20//! Search: expand beam, but skip candidates that don't match filter
21//! ```
22//!
23//! ## Usage
24//!
25//! ```ignore
26//! use anndists::dist::DistL2;
27//! use diskann_rs::{FilteredDiskANN, Filter};
28//!
29//! // Build index with metadata
30//! let vectors = vec![vec![0.0; 128]; 1000];
31//! let labels: Vec<Vec<u64>> = (0..1000).map(|i| vec![i % 10]).collect(); // 10 categories
32//!
33//! let index = FilteredDiskANN::<DistL2>::build(
34//!     &vectors,
35//!     &labels,
36//!     "filtered.db"
37//! ).unwrap();
38//!
39//! // Search with filter: only category 5
40//! let query = vec![0.0f32; 128];
41//! let filter = Filter::label_eq(0, 5); // field 0 == 5
42//! let results = index.search_filtered(&query, 10, 128, &filter);
43//! ```
44
45use crate::{DiskANN, DiskAnnError, DiskAnnParams};
46use anndists::prelude::Distance;
47use rayon::prelude::*;
48use serde::{Deserialize, Serialize};
49use std::collections::{BinaryHeap, HashSet};
50use std::cmp::{Ordering, Reverse};
51use std::fs::{File, OpenOptions};
52use std::io::{BufReader, BufWriter, Read, Write};
53use std::sync::Arc;
54
/// A predicate over the label fields attached to a vector.
///
/// Leaf variants inspect the label stored at position `field`; `And` / `Or`
/// combine nested filters. A leaf never matches a label slice that is too
/// short to contain `field`.
#[derive(Clone, Debug)]
pub enum Filter {
    /// Matches when the label at `field` equals `value`.
    LabelEq { field: usize, value: u64 },
    /// Matches when the label at `field` is a member of `values`.
    LabelIn { field: usize, values: HashSet<u64> },
    /// Matches when the label at `field` is strictly less than `value`.
    LabelLt { field: usize, value: u64 },
    /// Matches when the label at `field` is strictly greater than `value`.
    LabelGt { field: usize, value: u64 },
    /// Matches when the label at `field` lies in the inclusive range `[min, max]`.
    LabelRange { field: usize, min: u64, max: u64 },
    /// Matches when every nested filter matches (an empty list matches everything).
    And(Vec<Filter>),
    /// Matches when at least one nested filter matches (an empty list matches nothing).
    Or(Vec<Filter>),
    /// Matches every label vector.
    None,
}

impl Filter {
    /// Builds a filter requiring `labels[field] == value`.
    pub fn label_eq(field: usize, value: u64) -> Self {
        Self::LabelEq { field, value }
    }

    /// Builds a filter requiring `labels[field]` to be one of `values`.
    pub fn label_in(field: usize, values: impl IntoIterator<Item = u64>) -> Self {
        let values: HashSet<u64> = values.into_iter().collect();
        Self::LabelIn { field, values }
    }

    /// Builds a filter requiring `labels[field] < value`.
    pub fn label_lt(field: usize, value: u64) -> Self {
        Self::LabelLt { field, value }
    }

    /// Builds a filter requiring `labels[field] > value`.
    pub fn label_gt(field: usize, value: u64) -> Self {
        Self::LabelGt { field, value }
    }

    /// Builds a filter requiring `min <= labels[field] <= max`.
    pub fn label_range(field: usize, min: u64, max: u64) -> Self {
        Self::LabelRange { field, min, max }
    }

    /// Builds the conjunction of `filters`.
    pub fn and(filters: Vec<Filter>) -> Self {
        Self::And(filters)
    }

    /// Builds the disjunction of `filters`.
    pub fn or(filters: Vec<Filter>) -> Self {
        Self::Or(filters)
    }

    /// Evaluates this filter against one vector's labels.
    ///
    /// Returns `true` when the labels satisfy the predicate. Leaf filters
    /// whose `field` is out of bounds for `labels` evaluate to `false`.
    pub fn matches(&self, labels: &[u64]) -> bool {
        match self {
            Self::None => true,
            Self::And(parts) => parts.iter().all(|part| part.matches(labels)),
            Self::Or(parts) => parts.iter().any(|part| part.matches(labels)),
            Self::LabelEq { field, value } => {
                matches!(labels.get(*field), Some(found) if found == value)
            }
            Self::LabelIn { field, values } => {
                matches!(labels.get(*field), Some(found) if values.contains(found))
            }
            Self::LabelLt { field, value } => {
                matches!(labels.get(*field), Some(found) if found < value)
            }
            Self::LabelGt { field, value } => {
                matches!(labels.get(*field), Some(found) if found > value)
            }
            Self::LabelRange { field, min, max } => {
                // An inverted range (min > max) is empty and therefore matches nothing.
                matches!(labels.get(*field), Some(found) if (*min..=*max).contains(found))
            }
        }
    }
}
139
140/// Metadata for filtered index
141#[derive(Serialize, Deserialize, Debug)]
142struct FilteredMetadata {
143    num_vectors: usize,
144    num_fields: usize,
145}
146
/// A node considered during filtered search: its id and distance to the query.
///
/// Candidates are ordered by distance using the IEEE-754 total order
/// (`f32::total_cmp`), with the node id as a tie-breaker. This keeps `Eq`,
/// `PartialOrd` and `Ord` mutually consistent and gives NaN distances a
/// well-defined position, so the `BinaryHeap`s that store candidates always
/// see a valid total order. Note that under the total order `-0.0` sorts
/// before `+0.0`.
#[derive(Clone, Copy)]
struct Candidate {
    /// Distance from the query to the node.
    dist: f32,
    /// Node identifier.
    id: u32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so that `a == b` holds exactly when
        // `a.cmp(b) == Equal`, as the `Ord` contract requires.
        self.cmp(other) == Ordering::Equal
    }
}
impl Eq for Candidate {}
impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        self.dist
            .total_cmp(&other.dist)
            .then_with(|| self.id.cmp(&other.id))
    }
}
170
171/// DiskANN index with metadata filtering support
172pub struct FilteredDiskANN<D>
173where
174    D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
175{
176    /// The underlying vector index
177    index: DiskANN<D>,
178    /// Labels for each vector: labels[vector_id] = [field0, field1, ...]
179    labels: Vec<Vec<u64>>,
180    /// Number of label fields per vector
181    num_fields: usize,
182    /// Path to labels file (kept for potential future persistence)
183    #[allow(dead_code)]
184    labels_path: String,
185}
186
187impl<D> FilteredDiskANN<D>
188where
189    D: Distance<f32> + Send + Sync + Copy + Clone + Default + 'static,
190{
191    /// Build a new filtered index
192    pub fn build(
193        vectors: &[Vec<f32>],
194        labels: &[Vec<u64>],
195        base_path: &str,
196    ) -> Result<Self, DiskAnnError> {
197        Self::build_with_params(vectors, labels, base_path, DiskAnnParams::default())
198    }
199
200    /// Build with custom parameters
201    pub fn build_with_params(
202        vectors: &[Vec<f32>],
203        labels: &[Vec<u64>],
204        base_path: &str,
205        params: DiskAnnParams,
206    ) -> Result<Self, DiskAnnError> {
207        if vectors.len() != labels.len() {
208            return Err(DiskAnnError::IndexError(format!(
209                "vectors.len() ({}) != labels.len() ({})",
210                vectors.len(),
211                labels.len()
212            )));
213        }
214
215        let num_fields = labels.first().map(|l| l.len()).unwrap_or(0);
216        for (i, l) in labels.iter().enumerate() {
217            if l.len() != num_fields {
218                return Err(DiskAnnError::IndexError(format!(
219                    "Label {} has {} fields, expected {}",
220                    i,
221                    l.len(),
222                    num_fields
223                )));
224            }
225        }
226
227        // Build the vector index
228        let index_path = format!("{}.idx", base_path);
229        let index = DiskANN::<D>::build_index_with_params(
230            vectors,
231            D::default(),
232            &index_path,
233            params,
234        )?;
235
236        // Save labels
237        let labels_path = format!("{}.labels", base_path);
238        Self::save_labels(&labels_path, labels, num_fields)?;
239
240        Ok(Self {
241            index,
242            labels: labels.to_vec(),
243            num_fields,
244            labels_path,
245        })
246    }
247
248    /// Open an existing filtered index
249    pub fn open(base_path: &str) -> Result<Self, DiskAnnError> {
250        let index_path = format!("{}.idx", base_path);
251        let labels_path = format!("{}.labels", base_path);
252
253        let index = DiskANN::<D>::open_index_default_metric(&index_path)?;
254        let (labels, num_fields) = Self::load_labels(&labels_path)?;
255
256        if labels.len() != index.num_vectors {
257            return Err(DiskAnnError::IndexError(format!(
258                "Labels count ({}) != index vectors ({})",
259                labels.len(),
260                index.num_vectors
261            )));
262        }
263
264        Ok(Self {
265            index,
266            labels,
267            num_fields,
268            labels_path,
269        })
270    }
271
272    /// Serialize the filtered index to bytes.
273    ///
274    /// Format: `[index_len:u64][index_bytes][labels_bytes]`
275    pub fn to_bytes(&self) -> Vec<u8> {
276        let index_bytes = self.index.to_bytes();
277        let labels_bytes = Self::serialize_labels(&self.labels, self.num_fields);
278        let mut out = Vec::with_capacity(8 + index_bytes.len() + labels_bytes.len());
279        out.extend_from_slice(&(index_bytes.len() as u64).to_le_bytes());
280        out.extend_from_slice(&index_bytes);
281        out.extend_from_slice(&labels_bytes);
282        out
283    }
284
285    /// Load a filtered index from bytes.
286    pub fn from_bytes(bytes: Vec<u8>, dist: D) -> Result<Self, DiskAnnError> {
287        if bytes.len() < 8 {
288            return Err(DiskAnnError::IndexError("Buffer too small".into()));
289        }
290        let index_len = u64::from_le_bytes(bytes[0..8].try_into().unwrap()) as usize;
291        if bytes.len() < 8 + index_len {
292            return Err(DiskAnnError::IndexError("Buffer too small for index data".into()));
293        }
294        let index_bytes = bytes[8..8 + index_len].to_vec();
295        let labels_bytes = &bytes[8 + index_len..];
296
297        let index = DiskANN::<D>::from_bytes(index_bytes, dist)?;
298        let (labels, num_fields) = Self::deserialize_labels(labels_bytes)?;
299
300        if labels.len() != index.num_vectors {
301            return Err(DiskAnnError::IndexError(format!(
302                "Labels count ({}) != index vectors ({})",
303                labels.len(),
304                index.num_vectors
305            )));
306        }
307
308        Ok(Self {
309            index,
310            labels,
311            num_fields,
312            labels_path: String::new(),
313        })
314    }
315
316    /// Load a filtered index from shared bytes.
317    pub fn from_shared_bytes(bytes: Arc<[u8]>, dist: D) -> Result<Self, DiskAnnError> {
318        // We need to split the buffer, so we parse the header and use owned for both parts
319        if bytes.len() < 8 {
320            return Err(DiskAnnError::IndexError("Buffer too small".into()));
321        }
322        let index_len = u64::from_le_bytes(bytes[0..8].try_into().unwrap()) as usize;
323        if bytes.len() < 8 + index_len {
324            return Err(DiskAnnError::IndexError("Buffer too small for index data".into()));
325        }
326        let index_bytes = bytes[8..8 + index_len].to_vec();
327        let labels_bytes = &bytes[8 + index_len..];
328
329        let index = DiskANN::<D>::from_bytes(index_bytes, dist)?;
330        let (labels, num_fields) = Self::deserialize_labels(labels_bytes)?;
331
332        if labels.len() != index.num_vectors {
333            return Err(DiskAnnError::IndexError(format!(
334                "Labels count ({}) != index vectors ({})",
335                labels.len(),
336                index.num_vectors
337            )));
338        }
339
340        Ok(Self {
341            index,
342            labels,
343            num_fields,
344            labels_path: String::new(),
345        })
346    }
347
348    fn serialize_labels(labels: &[Vec<u64>], num_fields: usize) -> Vec<u8> {
349        let meta = FilteredMetadata {
350            num_vectors: labels.len(),
351            num_fields,
352        };
353        let meta_bytes = bincode::serialize(&meta).unwrap();
354        let mut out = Vec::new();
355        out.extend_from_slice(&(meta_bytes.len() as u64).to_le_bytes());
356        out.extend_from_slice(&meta_bytes);
357        for label_vec in labels {
358            for &val in label_vec {
359                out.extend_from_slice(&val.to_le_bytes());
360            }
361        }
362        out
363    }
364
365    fn deserialize_labels(bytes: &[u8]) -> Result<(Vec<Vec<u64>>, usize), DiskAnnError> {
366        if bytes.len() < 8 {
367            return Err(DiskAnnError::IndexError("Labels buffer too small".into()));
368        }
369        let meta_len = u64::from_le_bytes(bytes[0..8].try_into().unwrap()) as usize;
370        if bytes.len() < 8 + meta_len {
371            return Err(DiskAnnError::IndexError("Labels buffer too small for metadata".into()));
372        }
373        let meta: FilteredMetadata = bincode::deserialize(&bytes[8..8 + meta_len])?;
374
375        let data = &bytes[8 + meta_len..];
376        let mut labels = Vec::with_capacity(meta.num_vectors);
377        let mut offset = 0;
378        for _ in 0..meta.num_vectors {
379            let mut label_vec = Vec::with_capacity(meta.num_fields);
380            for _ in 0..meta.num_fields {
381                if offset + 8 > data.len() {
382                    return Err(DiskAnnError::IndexError("Labels data truncated".into()));
383                }
384                label_vec.push(u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()));
385                offset += 8;
386            }
387            labels.push(label_vec);
388        }
389
390        Ok((labels, meta.num_fields))
391    }
392
393    fn save_labels(path: &str, labels: &[Vec<u64>], num_fields: usize) -> Result<(), DiskAnnError> {
394        let file = OpenOptions::new()
395            .create(true)
396            .write(true)
397            .truncate(true)
398            .open(path)?;
399        let mut writer = BufWriter::new(file);
400
401        let meta = FilteredMetadata {
402            num_vectors: labels.len(),
403            num_fields,
404        };
405        let meta_bytes = bincode::serialize(&meta)?;
406        writer.write_all(&(meta_bytes.len() as u64).to_le_bytes())?;
407        writer.write_all(&meta_bytes)?;
408
409        // Write labels as flat array
410        for label_vec in labels {
411            for &val in label_vec {
412                writer.write_all(&val.to_le_bytes())?;
413            }
414        }
415
416        writer.flush()?;
417        Ok(())
418    }
419
420    fn load_labels(path: &str) -> Result<(Vec<Vec<u64>>, usize), DiskAnnError> {
421        let file = File::open(path)?;
422        let mut reader = BufReader::new(file);
423
424        // Read metadata
425        let mut len_buf = [0u8; 8];
426        reader.read_exact(&mut len_buf)?;
427        let meta_len = u64::from_le_bytes(len_buf) as usize;
428
429        let mut meta_bytes = vec![0u8; meta_len];
430        reader.read_exact(&mut meta_bytes)?;
431        let meta: FilteredMetadata = bincode::deserialize(&meta_bytes)?;
432
433        // Read labels
434        let mut labels = Vec::with_capacity(meta.num_vectors);
435        let mut val_buf = [0u8; 8];
436
437        for _ in 0..meta.num_vectors {
438            let mut label_vec = Vec::with_capacity(meta.num_fields);
439            for _ in 0..meta.num_fields {
440                reader.read_exact(&mut val_buf)?;
441                label_vec.push(u64::from_le_bytes(val_buf));
442            }
443            labels.push(label_vec);
444        }
445
446        Ok((labels, meta.num_fields))
447    }
448}
449
450impl<D> FilteredDiskANN<D>
451where
452    D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
453{
454    /// Search with a filter predicate
455    ///
456    /// Uses an expanded beam search that skips non-matching candidates
457    /// but continues exploring the graph to find matches.
458    pub fn search_filtered(
459        &self,
460        query: &[f32],
461        k: usize,
462        beam_width: usize,
463        filter: &Filter,
464    ) -> Vec<u32> {
465        self.search_filtered_with_dists(query, k, beam_width, filter)
466            .into_iter()
467            .map(|(id, _)| id)
468            .collect()
469    }
470
471    /// Search with filter, returning distances
472    pub fn search_filtered_with_dists(
473        &self,
474        query: &[f32],
475        k: usize,
476        beam_width: usize,
477        filter: &Filter,
478    ) -> Vec<(u32, f32)> {
479        // For unfiltered search, use the fast path
480        if matches!(filter, Filter::None) {
481            return self.index.search_with_dists(query, k, beam_width);
482        }
483
484        // Filtered search: we need to explore more of the graph
485        // Use larger internal beam to find enough matching candidates
486        let expanded_beam = (beam_width * 4).max(k * 10);
487
488        let mut visited = HashSet::new();
489        let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new();
490        let mut working_set: BinaryHeap<Candidate> = BinaryHeap::new();
491        let mut results: Vec<(u32, f32)> = Vec::with_capacity(k);
492
493        // Seed from medoid
494        let start_dist = self.distance_to(query, self.index.medoid_id as usize);
495        let start = Candidate {
496            dist: start_dist,
497            id: self.index.medoid_id,
498        };
499        frontier.push(Reverse(start));
500        working_set.push(start);
501        visited.insert(self.index.medoid_id);
502
503        // Check if medoid matches filter
504        if filter.matches(&self.labels[self.index.medoid_id as usize]) {
505            results.push((self.index.medoid_id, start_dist));
506        }
507
508        // Expand until we have k results or exhausted search
509        let mut iterations = 0;
510        let max_iterations = expanded_beam * 2;
511
512        while let Some(Reverse(best)) = frontier.peek().copied() {
513            iterations += 1;
514            if iterations > max_iterations {
515                break;
516            }
517
518            // Early termination if we have enough results and best candidate
519            // can't improve our worst result
520            if results.len() >= k {
521                if let Some((_, worst_dist)) = results.last() {
522                    if best.dist > *worst_dist * 1.5 {
523                        break;
524                    }
525                }
526            }
527
528            if working_set.len() >= expanded_beam {
529                if let Some(worst) = working_set.peek() {
530                    if best.dist >= worst.dist {
531                        break;
532                    }
533                }
534            }
535
536            let Reverse(current) = frontier.pop().unwrap();
537
538            // Explore neighbors
539            for &nb in self.get_neighbors(current.id) {
540                if nb == u32::MAX {
541                    continue;
542                }
543                if !visited.insert(nb) {
544                    continue;
545                }
546
547                let d = self.distance_to(query, nb as usize);
548                let cand = Candidate { dist: d, id: nb };
549
550                // Always add to working set for graph exploration
551                if working_set.len() < expanded_beam {
552                    working_set.push(cand);
553                    frontier.push(Reverse(cand));
554                } else if d < working_set.peek().unwrap().dist {
555                    working_set.pop();
556                    working_set.push(cand);
557                    frontier.push(Reverse(cand));
558                }
559
560                // Check filter for results
561                if filter.matches(&self.labels[nb as usize]) {
562                    // Insert into results maintaining sorted order
563                    let pos = results
564                        .iter()
565                        .position(|(_, dist)| d < *dist)
566                        .unwrap_or(results.len());
567
568                    if pos < k {
569                        results.insert(pos, (nb, d));
570                        if results.len() > k {
571                            results.pop();
572                        }
573                    }
574                }
575            }
576        }
577
578        results
579    }
580
581    /// Parallel batch filtered search
582    pub fn search_filtered_batch(
583        &self,
584        queries: &[Vec<f32>],
585        k: usize,
586        beam_width: usize,
587        filter: &Filter,
588    ) -> Vec<Vec<u32>> {
589        queries
590            .par_iter()
591            .map(|q| self.search_filtered(q, k, beam_width, filter))
592            .collect()
593    }
594
595    /// Unfiltered search (delegates to base index)
596    pub fn search(&self, query: &[f32], k: usize, beam_width: usize) -> Vec<u32> {
597        self.index.search(query, k, beam_width)
598    }
599
600    /// Get labels for a vector
601    pub fn get_labels(&self, id: usize) -> Option<&[u64]> {
602        self.labels.get(id).map(|v| v.as_slice())
603    }
604
605    /// Get the underlying index
606    pub fn inner(&self) -> &DiskANN<D> {
607        &self.index
608    }
609
610    /// Number of vectors in the index
611    pub fn num_vectors(&self) -> usize {
612        self.index.num_vectors
613    }
614
615    /// Number of label fields per vector
616    pub fn num_fields(&self) -> usize {
617        self.num_fields
618    }
619
620    /// Count vectors matching a filter (useful for selectivity estimation)
621    pub fn count_matching(&self, filter: &Filter) -> usize {
622        self.labels.iter().filter(|l| filter.matches(l)).count()
623    }
624
625    fn get_neighbors(&self, node_id: u32) -> &[u32] {
626        // Access internal neighbors through the index
627        let offset = self.index.adjacency_offset
628            + (node_id as u64 * self.index.max_degree as u64 * 4);
629        let start = offset as usize;
630        let end = start + (self.index.max_degree * 4);
631        let bytes = &self.index.storage[start..end];
632        bytemuck::cast_slice(bytes)
633    }
634
635    fn distance_to(&self, query: &[f32], idx: usize) -> f32 {
636        let offset = self.index.vectors_offset + (idx as u64 * self.index.dim as u64 * 4);
637        let start = offset as usize;
638        let end = start + (self.index.dim * 4);
639        let bytes = &self.index.storage[start..end];
640        let vector: &[f32] = bytemuck::cast_slice(bytes);
641        self.index.dist.eval(query, vector)
642    }
643}
644
#[cfg(test)]
mod tests {
    use super::*;
    use anndists::dist::DistL2;
    use std::fs;

    /// Best-effort removal of the two files a filtered index writes for
    /// `base_path`; missing files are ignored.
    fn remove_index_files(base_path: &str) {
        for ext in ["idx", "labels"] {
            let _ = fs::remove_file(format!("{}.{}", base_path, ext));
        }
    }

    #[test]
    fn test_filter_eq() {
        let eq_five = Filter::label_eq(0, 5);
        assert!(eq_five.matches(&[5, 10]));
        assert!(!eq_five.matches(&[4, 10]));
        // A missing field never matches.
        assert!(!eq_five.matches(&[]));
    }

    #[test]
    fn test_filter_in() {
        let members = Filter::label_in(0, vec![1, 3, 5]);
        for accepted in [1u64, 3, 5] {
            assert!(members.matches(&[accepted]));
        }
        assert!(!members.matches(&[2]));
    }

    #[test]
    fn test_filter_range() {
        let in_range = Filter::label_range(0, 10, 20);
        // Both bounds are inclusive.
        for accepted in [10u64, 15, 20] {
            assert!(in_range.matches(&[accepted]));
        }
        for rejected in [9u64, 21] {
            assert!(!in_range.matches(&[rejected]));
        }
    }

    #[test]
    fn test_filter_and() {
        let both = Filter::and(vec![Filter::label_eq(0, 5), Filter::label_gt(1, 10)]);
        assert!(both.matches(&[5, 15]));
        assert!(!both.matches(&[5, 5]));
        assert!(!both.matches(&[4, 15]));
    }

    #[test]
    fn test_filter_or() {
        let either = Filter::or(vec![Filter::label_eq(0, 5), Filter::label_eq(0, 10)]);
        assert!(either.matches(&[5]));
        assert!(either.matches(&[10]));
        assert!(!either.matches(&[7]));
    }

    #[test]
    fn test_filtered_search_basic() {
        let base_path = "test_filtered";
        remove_index_files(base_path);

        // 100 points on a line, each tagged with category = i % 5.
        let vectors: Vec<Vec<f32>> = (0..100)
            .map(|i| vec![i as f32, (i * 2) as f32])
            .collect();
        let labels: Vec<Vec<u64>> = (0..100).map(|i| vec![i % 5]).collect();

        let index = FilteredDiskANN::<DistL2>::build(&vectors, &labels, base_path).unwrap();

        // Unfiltered search returns the requested number of hits.
        let unfiltered = index.search(&[50.0, 100.0], 5, 32);
        assert_eq!(unfiltered.len(), 5);

        // Filtered search must only return vectors from category 0.
        let only_category_zero = Filter::label_eq(0, 0);
        let hits = index.search_filtered(&[50.0, 100.0], 5, 32, &only_category_zero);
        for id in &hits {
            assert_eq!(labels[*id as usize][0], 0);
        }

        remove_index_files(base_path);
    }

    #[test]
    fn test_filtered_search_selectivity() {
        let base_path = "test_filtered_sel";
        remove_index_files(base_path);

        // 1000 vectors spread over 10 categories (100 per category).
        let vectors: Vec<Vec<f32>> = (0..1000)
            .map(|i| vec![(i % 100) as f32, ((i / 100) * 10) as f32])
            .collect();
        let labels: Vec<Vec<u64>> = (0..1000).map(|i| vec![i % 10]).collect();

        let index = FilteredDiskANN::<DistL2>::build(&vectors, &labels, base_path).unwrap();

        let category_three = Filter::label_eq(0, 3);
        assert_eq!(index.count_matching(&category_three), 100);

        let hits = index.search_filtered(&[50.0, 50.0], 10, 64, &category_three);
        assert!(hits.len() <= 10);
        for id in &hits {
            assert_eq!(labels[*id as usize][0], 3);
        }

        remove_index_files(base_path);
    }

    #[test]
    fn test_filtered_persistence() {
        let base_path = "test_filtered_persist";
        remove_index_files(base_path);

        let vectors: Vec<Vec<f32>> = (0..50).map(|i| vec![i as f32, i as f32]).collect();
        let labels: Vec<Vec<u64>> = (0..50).map(|i| vec![i % 3, i]).collect();

        // Build in an inner scope so the index is dropped before reopening.
        {
            let _index = FilteredDiskANN::<DistL2>::build(&vectors, &labels, base_path).unwrap();
        }

        let reopened = FilteredDiskANN::<DistL2>::open(base_path).unwrap();
        assert_eq!(reopened.num_vectors(), 50);
        assert_eq!(reopened.num_fields(), 2);

        let category_one = Filter::label_eq(0, 1);
        let hits = reopened.search_filtered(&[25.0, 25.0], 5, 32, &category_one);
        for id in &hits {
            assert_eq!(reopened.get_labels(*id as usize).unwrap()[0], 1);
        }

        remove_index_files(base_path);
    }

    #[test]
    fn test_filtered_to_bytes_from_bytes() {
        let base_path = "test_filtered_bytes_rt";
        remove_index_files(base_path);

        let vectors: Vec<Vec<f32>> = (0..50).map(|i| vec![i as f32, i as f32]).collect();
        let labels: Vec<Vec<u64>> = (0..50).map(|i| vec![i % 3]).collect();

        let original = FilteredDiskANN::<DistL2>::build(&vectors, &labels, base_path).unwrap();
        let encoded = original.to_bytes();

        let decoded = FilteredDiskANN::<DistL2>::from_bytes(encoded, DistL2 {}).unwrap();
        assert_eq!(decoded.num_vectors(), 50);
        assert_eq!(decoded.num_fields(), 1);

        let category_one = Filter::label_eq(0, 1);
        let hits = decoded.search_filtered(&[25.0, 25.0], 5, 32, &category_one);
        for id in &hits {
            assert_eq!(decoded.get_labels(*id as usize).unwrap()[0], 1);
        }

        remove_index_files(base_path);
    }
}