use crate::{beam_search, BeamSearchConfig, GraphIndex, DiskANN, DiskAnnError, DiskAnnParams};
use anndists::prelude::Distance;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::fs::{File, OpenOptions};
use std::io::{BufReader, BufWriter, Read, Write};
use std::sync::Arc;
/// Runs a beam search over `graph` for `query`, restricted by `filter`.
///
/// * `graph` supplies distances (`distance_to`) and adjacency (`get_neighbors`).
/// * `labels` holds one label row per vector, indexed by vector id.
/// * `start_ids` are the entry points handed to `beam_search`.
/// * `k` and `beam_width` are forwarded to `beam_search` unchanged.
///
/// With [`Filter::None`] the search uses `BeamSearchConfig::default()` and
/// accepts every id. For any other filter the search is given a wider budget:
/// an expanded beam of `max(4 * beam_width, 10 * k)`, an iteration cap of twice
/// that, and an early-termination factor of `1.5`. Ids are accepted only when
/// their label row satisfies `filter`; an id with no label row is rejected.
///
/// Returns `(id, distance)` pairs exactly as produced by `beam_search`.
pub(crate) fn filtered_search(
    graph: &dyn GraphIndex,
    labels: &[Vec<u64>],
    start_ids: &[u32],
    query: &[f32],
    k: usize,
    beam_width: usize,
    filter: &Filter,
) -> Vec<(u32, f32)> {
    if matches!(filter, Filter::None) {
        return beam_search(
            start_ids,
            beam_width,
            k,
            |id| graph.distance_to(query, id),
            |id| graph.get_neighbors(id),
            |_| true,
            BeamSearchConfig::default(),
        );
    }

    // Saturating arithmetic: an extreme `k` or `beam_width` must clamp the
    // budget instead of overflowing (panic in debug, wrap-around in release).
    let expanded_beam = beam_width.saturating_mul(4).max(k.saturating_mul(10));

    beam_search(
        start_ids,
        beam_width,
        k,
        |id| graph.distance_to(query, id),
        |id| graph.get_neighbors(id),
        // Ids without a label row can never satisfy a label predicate.
        |id| labels.get(id as usize).map_or(false, |row| filter.matches(row)),
        BeamSearchConfig {
            expanded_beam: Some(expanded_beam),
            max_iterations: Some(expanded_beam.saturating_mul(2)),
            early_term_factor: Some(1.5),
        },
    )
}
/// Predicate over one vector's label row.
///
/// A label row is a slice of `u64` values and `field` is an index into it.
/// Every field-based variant evaluates to `false` when `field` is out of
/// range for the row being tested.
#[derive(Clone, Debug)]
pub enum Filter {
    /// The value at `field` equals `value`.
    LabelEq { field: usize, value: u64 },
    /// The value at `field` is a member of `values`.
    LabelIn { field: usize, values: HashSet<u64> },
    /// The value at `field` is strictly less than `value`.
    LabelLt { field: usize, value: u64 },
    /// The value at `field` is strictly greater than `value`.
    LabelGt { field: usize, value: u64 },
    /// The value at `field` lies in the inclusive range `min..=max`.
    LabelRange { field: usize, min: u64, max: u64 },
    /// Every sub-filter matches (an empty list matches everything).
    And(Vec<Filter>),
    /// At least one sub-filter matches (an empty list matches nothing).
    Or(Vec<Filter>),
    /// Matches every label row.
    None,
}

impl Filter {
    /// Creates a [`Filter::LabelEq`] for `field` and `value`.
    pub fn label_eq(field: usize, value: u64) -> Self {
        Self::LabelEq { field, value }
    }

    /// Creates a [`Filter::LabelIn`], collecting `values` into a set.
    pub fn label_in(field: usize, values: impl IntoIterator<Item = u64>) -> Self {
        let values: HashSet<u64> = values.into_iter().collect();
        Self::LabelIn { field, values }
    }

    /// Creates a [`Filter::LabelLt`] for `field` and `value`.
    pub fn label_lt(field: usize, value: u64) -> Self {
        Self::LabelLt { field, value }
    }

    /// Creates a [`Filter::LabelGt`] for `field` and `value`.
    pub fn label_gt(field: usize, value: u64) -> Self {
        Self::LabelGt { field, value }
    }

    /// Creates a [`Filter::LabelRange`] covering `min..=max` on `field`.
    pub fn label_range(field: usize, min: u64, max: u64) -> Self {
        Self::LabelRange { field, min, max }
    }

    /// Creates a [`Filter::And`] over `filters`.
    pub fn and(filters: Vec<Filter>) -> Self {
        Self::And(filters)
    }

    /// Creates a [`Filter::Or`] over `filters`.
    pub fn or(filters: Vec<Filter>) -> Self {
        Self::Or(filters)
    }

    /// Returns `true` when the label row `labels` satisfies this filter.
    pub fn matches(&self, labels: &[u64]) -> bool {
        /// Applies `test` to the value stored at `field`; a missing field fails.
        fn field_passes(row: &[u64], field: usize, test: impl FnOnce(u64) -> bool) -> bool {
            match row.get(field) {
                Some(&actual) => test(actual),
                None => false,
            }
        }

        match self {
            Self::None => true,
            Self::LabelEq { field, value } => field_passes(labels, *field, |v| v == *value),
            Self::LabelIn { field, values } => {
                field_passes(labels, *field, |v| values.contains(&v))
            }
            Self::LabelLt { field, value } => field_passes(labels, *field, |v| v < *value),
            Self::LabelGt { field, value } => field_passes(labels, *field, |v| v > *value),
            Self::LabelRange { field, min, max } => {
                field_passes(labels, *field, |v| (*min..=*max).contains(&v))
            }
            Self::And(clauses) => clauses.iter().all(|clause| clause.matches(labels)),
            Self::Or(clauses) => clauses.iter().any(|clause| clause.matches(labels)),
        }
    }
}
/// Header record written (bincode-encoded, behind a `u64` length prefix) ahead
/// of the raw label values in both the labels file and the in-memory blob.
#[derive(Serialize, Deserialize, Debug)]
struct FilteredMetadata {
/// Number of label rows that follow the header.
num_vectors: usize,
/// Number of `u64` values stored for each label row.
num_fields: usize,
}
/// A [`DiskANN`] index paired with one row of `u64` labels per vector, so that
/// searches can be restricted with a [`Filter`].
pub struct FilteredDiskANN<D>
where
D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
{
/// The wrapped vector index.
index: DiskANN<D>,
/// Label rows, indexed by vector id; the constructors keep the row count
/// equal to the number of vectors in `index`.
labels: Vec<Vec<u64>>,
/// Number of values in every label row.
num_fields: usize,
/// Path of the `.labels` file for indexes built on or opened from disk;
/// empty for indexes reconstructed from bytes. The field is stored but not
/// read anywhere in this file, hence the lint allowance below.
#[allow(dead_code)]
labels_path: String,
}
impl<D> FilteredDiskANN<D>
where
    D: Distance<f32> + Send + Sync + Copy + Clone + Default + 'static,
{
    /// Width in bytes of every length prefix and of every stored label value.
    const U64_BYTES: usize = 8;

    /// Builds a filtered index with `DiskAnnParams::default()`.
    ///
    /// See [`Self::build_with_params`] for the on-disk layout and error cases.
    pub fn build(
        vectors: &[Vec<f32>],
        labels: &[Vec<u64>],
        base_path: &str,
    ) -> Result<Self, DiskAnnError> {
        Self::build_with_params(vectors, labels, base_path, DiskAnnParams::default())
    }

    /// Builds the vector index at `{base_path}.idx` and stores the label rows
    /// at `{base_path}.labels`.
    ///
    /// # Errors
    ///
    /// Returns `DiskAnnError::IndexError` when `vectors` and `labels` differ in
    /// length or when the label rows do not all have the same number of
    /// fields (the first row defines the expected width). Errors from building
    /// the index or writing the labels file are propagated.
    pub fn build_with_params(
        vectors: &[Vec<f32>],
        labels: &[Vec<u64>],
        base_path: &str,
        params: DiskAnnParams,
    ) -> Result<Self, DiskAnnError> {
        if vectors.len() != labels.len() {
            return Err(DiskAnnError::IndexError(format!(
                "vectors.len() ({}) != labels.len() ({})",
                vectors.len(),
                labels.len()
            )));
        }

        let num_fields = labels.first().map_or(0, Vec::len);
        if let Some((i, l)) = labels
            .iter()
            .enumerate()
            .find(|(_, l)| l.len() != num_fields)
        {
            return Err(DiskAnnError::IndexError(format!(
                "Label {} has {} fields, expected {}",
                i,
                l.len(),
                num_fields
            )));
        }

        let index_path = format!("{}.idx", base_path);
        let index = DiskANN::<D>::build_index_with_params(
            vectors,
            D::default(),
            &index_path,
            params,
        )?;

        let labels_path = format!("{}.labels", base_path);
        Self::save_labels(&labels_path, labels, num_fields)?;

        Ok(Self {
            index,
            labels: labels.to_vec(),
            num_fields,
            labels_path,
        })
    }

    /// Opens an index previously written by [`Self::build`] /
    /// [`Self::build_with_params`] from `{base_path}.idx` and
    /// `{base_path}.labels`.
    ///
    /// # Errors
    ///
    /// Propagates failures from opening either file and returns
    /// `DiskAnnError::IndexError` when the number of label rows differs from
    /// the number of vectors in the index.
    pub fn open(base_path: &str) -> Result<Self, DiskAnnError> {
        let index_path = format!("{}.idx", base_path);
        let labels_path = format!("{}.labels", base_path);

        let index = DiskANN::<D>::open_index_default_metric(&index_path)?;
        let (labels, num_fields) = Self::load_labels(&labels_path)?;
        if labels.len() != index.num_vectors {
            return Err(DiskAnnError::IndexError(format!(
                "Labels count ({}) != index vectors ({})",
                labels.len(),
                index.num_vectors
            )));
        }

        Ok(Self {
            index,
            labels,
            num_fields,
            labels_path,
        })
    }

    /// Serializes the index and its labels into one buffer.
    ///
    /// Layout: `u64` little-endian index length, the index bytes, then the
    /// labels blob produced by `serialize_labels`.
    pub fn to_bytes(&self) -> Vec<u8> {
        let index_bytes = self.index.to_bytes();
        let labels_bytes = Self::serialize_labels(&self.labels, self.num_fields);

        let mut out =
            Vec::with_capacity(Self::U64_BYTES + index_bytes.len() + labels_bytes.len());
        out.extend_from_slice(&(index_bytes.len() as u64).to_le_bytes());
        out.extend_from_slice(&index_bytes);
        out.extend_from_slice(&labels_bytes);
        out
    }

    /// Reconstructs an index from a buffer produced by [`Self::to_bytes`].
    ///
    /// # Errors
    ///
    /// Returns `DiskAnnError::IndexError` when the buffer is too short for the
    /// lengths it declares or when the label count does not match the index;
    /// errors from decoding the index or the labels are propagated.
    pub fn from_bytes(bytes: Vec<u8>, dist: D) -> Result<Self, DiskAnnError> {
        Self::from_byte_slice(&bytes, dist)
    }

    /// Same as [`Self::from_bytes`], but reads from a shared buffer.
    pub fn from_shared_bytes(bytes: Arc<[u8]>, dist: D) -> Result<Self, DiskAnnError> {
        Self::from_byte_slice(&bytes, dist)
    }

    /// Shared decoder behind `from_bytes` and `from_shared_bytes`.
    fn from_byte_slice(bytes: &[u8], dist: D) -> Result<Self, DiskAnnError> {
        let index_len = Self::read_u64_prefix(bytes)
            .ok_or_else(|| DiskAnnError::IndexError("Buffer too small".into()))?;
        // The length prefix is untrusted: validate it without overflowing.
        let index_end = Self::checked_end(Self::U64_BYTES, index_len, bytes.len())
            .ok_or_else(|| DiskAnnError::IndexError("Buffer too small for index data".into()))?;

        let index_bytes = bytes[Self::U64_BYTES..index_end].to_vec();
        let index = DiskANN::<D>::from_bytes(index_bytes, dist)?;
        let (labels, num_fields) = Self::deserialize_labels(&bytes[index_end..])?;
        if labels.len() != index.num_vectors {
            return Err(DiskAnnError::IndexError(format!(
                "Labels count ({}) != index vectors ({})",
                labels.len(),
                index.num_vectors
            )));
        }

        Ok(Self {
            index,
            labels,
            num_fields,
            labels_path: String::new(),
        })
    }

    /// Reads the little-endian `u64` at the start of `bytes`, or `None` when
    /// fewer than eight bytes are available.
    fn read_u64_prefix(bytes: &[u8]) -> Option<u64> {
        let prefix: [u8; 8] = bytes.get(..Self::U64_BYTES)?.try_into().ok()?;
        Some(u64::from_le_bytes(prefix))
    }

    /// Returns `offset + len` when that end position lies within `available`
    /// bytes; `None` on overflow or when the range does not fit.
    fn checked_end(offset: usize, len: u64, available: usize) -> Option<usize> {
        let len = usize::try_from(len).ok()?;
        offset.checked_add(len).filter(|&end| end <= available)
    }

    /// Number of bytes occupied by the label values described by `meta`, or
    /// `None` when that size does not fit in `usize`.
    fn payload_len(meta: &FilteredMetadata) -> Option<usize> {
        meta.num_vectors
            .checked_mul(meta.num_fields)?
            .checked_mul(Self::U64_BYTES)
    }

    /// Encodes label rows as: `u64` LE metadata length, bincode-encoded
    /// `FilteredMetadata`, then every label value as `u64` LE in row order.
    fn serialize_labels(labels: &[Vec<u64>], num_fields: usize) -> Vec<u8> {
        let meta = FilteredMetadata {
            num_vectors: labels.len(),
            num_fields,
        };
        let meta_bytes = bincode::serialize(&meta)
            .expect("bincode-encoding a struct of two integers into memory cannot fail");

        let value_count: usize = labels.iter().map(Vec::len).sum();
        let mut out = Vec::with_capacity(
            Self::U64_BYTES + meta_bytes.len() + value_count * Self::U64_BYTES,
        );
        out.extend_from_slice(&(meta_bytes.len() as u64).to_le_bytes());
        out.extend_from_slice(&meta_bytes);
        for value in labels.iter().flatten() {
            out.extend_from_slice(&value.to_le_bytes());
        }
        out
    }

    /// Decodes a blob produced by `serialize_labels`, returning the label rows
    /// and the number of fields per row. Bytes after the declared payload are
    /// ignored.
    ///
    /// # Errors
    ///
    /// Returns `DiskAnnError::IndexError` when the blob is shorter than its
    /// header declares; metadata decoding errors are propagated.
    fn deserialize_labels(bytes: &[u8]) -> Result<(Vec<Vec<u64>>, usize), DiskAnnError> {
        let meta_len = Self::read_u64_prefix(bytes)
            .ok_or_else(|| DiskAnnError::IndexError("Labels buffer too small".into()))?;
        let meta_end = Self::checked_end(Self::U64_BYTES, meta_len, bytes.len()).ok_or_else(
            || DiskAnnError::IndexError("Labels buffer too small for metadata".into()),
        )?;
        let meta: FilteredMetadata = bincode::deserialize(&bytes[Self::U64_BYTES..meta_end])?;
        let data = &bytes[meta_end..];

        // Validate the whole payload up front (with overflow-checked sizes) so
        // corrupt counts are rejected before any row is parsed.
        let payload_len = Self::payload_len(&meta)
            .filter(|&needed| needed <= data.len())
            .ok_or_else(|| DiskAnnError::IndexError("Labels data truncated".into()))?;

        let mut values = data[..payload_len]
            .chunks_exact(Self::U64_BYTES)
            .map(|raw| {
                u64::from_le_bytes(raw.try_into().expect("chunks_exact yields 8-byte chunks"))
            });
        let labels: Vec<Vec<u64>> = (0..meta.num_vectors)
            .map(|_| values.by_ref().take(meta.num_fields).collect::<Vec<u64>>())
            .collect();

        Ok((labels, meta.num_fields))
    }

    /// Writes the label rows to `path` (created or truncated) in the same
    /// layout as `serialize_labels`.
    fn save_labels(path: &str, labels: &[Vec<u64>], num_fields: usize) -> Result<(), DiskAnnError> {
        let file = OpenOptions::new()
            .create(true)
            .write(true)
            .truncate(true)
            .open(path)?;
        let mut writer = BufWriter::new(file);

        let meta = FilteredMetadata {
            num_vectors: labels.len(),
            num_fields,
        };
        let meta_bytes = bincode::serialize(&meta)?;
        writer.write_all(&(meta_bytes.len() as u64).to_le_bytes())?;
        writer.write_all(&meta_bytes)?;
        for value in labels.iter().flatten() {
            writer.write_all(&value.to_le_bytes())?;
        }
        // Flush explicitly: errors during the implicit flush on drop are lost.
        writer.flush()?;
        Ok(())
    }

    /// Reads label rows from a file written by `save_labels`, returning the
    /// rows and the number of fields per row.
    ///
    /// # Errors
    ///
    /// I/O errors (including an unexpected end of file) and metadata decoding
    /// errors are propagated.
    fn load_labels(path: &str) -> Result<(Vec<Vec<u64>>, usize), DiskAnnError> {
        let file = File::open(path)?;
        let mut reader = BufReader::new(file);

        let mut len_buf = [0u8; 8];
        reader.read_exact(&mut len_buf)?;
        let meta_len = u64::from_le_bytes(len_buf);

        // Read through `take` instead of pre-allocating `meta_len` bytes, so a
        // corrupt length prefix cannot request an enormous allocation.
        let mut meta_bytes = Vec::new();
        reader.by_ref().take(meta_len).read_to_end(&mut meta_bytes)?;
        if meta_bytes.len() as u64 != meta_len {
            return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof).into());
        }
        let meta: FilteredMetadata = bincode::deserialize(&meta_bytes)?;

        // For regular files, reject a header that promises more label data
        // than the file holds before sizing any allocation from its counts.
        let source = reader.get_ref().metadata()?;
        if source.is_file() {
            let required = Self::payload_len(&meta)
                .map(|payload| payload as u64)
                .and_then(|payload| payload.checked_add(meta_len))
                .and_then(|total| total.checked_add(Self::U64_BYTES as u64));
            if required.map_or(true, |total| total > source.len()) {
                return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof).into());
            }
        }

        let mut labels = Vec::with_capacity(meta.num_vectors);
        let mut val_buf = [0u8; 8];
        for _ in 0..meta.num_vectors {
            let mut label_vec = Vec::with_capacity(meta.num_fields);
            for _ in 0..meta.num_fields {
                reader.read_exact(&mut val_buf)?;
                label_vec.push(u64::from_le_bytes(val_buf));
            }
            labels.push(label_vec);
        }
        Ok((labels, meta.num_fields))
    }
}
impl<D> FilteredDiskANN<D>
where
    D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
{
    /// Filtered search returning only the ids from
    /// [`Self::search_filtered_with_dists`], in the same order.
    pub fn search_filtered(
        &self,
        query: &[f32],
        k: usize,
        beam_width: usize,
        filter: &Filter,
    ) -> Vec<u32> {
        let scored = self.search_filtered_with_dists(query, k, beam_width, filter);
        let mut ids = Vec::with_capacity(scored.len());
        ids.extend(scored.into_iter().map(|hit| hit.0));
        ids
    }

    /// Filtered search returning `(id, distance)` pairs.
    ///
    /// Delegates to `filtered_search`, starting from the index's medoid.
    pub fn search_filtered_with_dists(
        &self,
        query: &[f32],
        k: usize,
        beam_width: usize,
        filter: &Filter,
    ) -> Vec<(u32, f32)> {
        let entry_points = [self.index.medoid_id];
        filtered_search(
            &self.index,
            &self.labels,
            &entry_points,
            query,
            k,
            beam_width,
            filter,
        )
    }

    /// Runs [`Self::search_filtered`] for every query in parallel; the output
    /// has one result list per query, in input order.
    pub fn search_filtered_batch(
        &self,
        queries: &[Vec<f32>],
        k: usize,
        beam_width: usize,
        filter: &Filter,
    ) -> Vec<Vec<u32>> {
        queries
            .par_iter()
            .map(|single| self.search_filtered(single.as_slice(), k, beam_width, filter))
            .collect()
    }

    /// Unfiltered search, forwarded directly to the wrapped index.
    pub fn search(&self, query: &[f32], k: usize, beam_width: usize) -> Vec<u32> {
        let inner = &self.index;
        inner.search(query, k, beam_width)
    }

    /// Returns the label row stored for vector `id`, or `None` when `id` is
    /// out of range.
    pub fn get_labels(&self, id: usize) -> Option<&[u64]> {
        self.labels.get(id).map(Vec::as_slice)
    }

    /// Borrows the wrapped index.
    pub fn inner(&self) -> &DiskANN<D> {
        &self.index
    }

    /// Number of vectors reported by the wrapped index.
    pub fn num_vectors(&self) -> usize {
        self.index.num_vectors
    }

    /// Number of values in every label row.
    pub fn num_fields(&self) -> usize {
        self.num_fields
    }

    /// Counts the label rows that satisfy `filter` by scanning all of them.
    pub fn count_matching(&self, filter: &Filter) -> usize {
        self.labels
            .iter()
            .map(|row| usize::from(filter.matches(row)))
            .sum()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use anndists::dist::DistL2;
    use std::fs;

    /// Owns the on-disk artifacts (`.idx` / `.labels`) of one test index.
    ///
    /// The files live in the system temp directory under a per-process name
    /// and are removed when the guard is dropped, so a failing assertion no
    /// longer leaves stray files behind. Declare the guard before the index so
    /// that the index is dropped first.
    struct IndexFiles {
        base_path: String,
    }

    impl IndexFiles {
        /// Reserves a fresh base path for `name`, clearing leftovers from
        /// earlier runs.
        fn new(name: &str) -> Self {
            let base_path = std::env::temp_dir()
                .join(format!("{}_{}", name, std::process::id()))
                .to_string_lossy()
                .into_owned();
            let files = Self { base_path };
            files.remove();
            files
        }

        /// Base path to hand to `FilteredDiskANN::build` / `open`.
        fn base_path(&self) -> &str {
            &self.base_path
        }

        /// Best-effort removal of both artifacts; missing files are fine.
        fn remove(&self) {
            let _ = fs::remove_file(format!("{}.idx", self.base_path));
            let _ = fs::remove_file(format!("{}.labels", self.base_path));
        }
    }

    impl Drop for IndexFiles {
        fn drop(&mut self) {
            self.remove();
        }
    }

    #[test]
    fn test_filter_eq() {
        let filter = Filter::label_eq(0, 5);
        assert!(filter.matches(&[5, 10]));
        assert!(!filter.matches(&[4, 10]));
        assert!(!filter.matches(&[]));
    }

    #[test]
    fn test_filter_in() {
        let filter = Filter::label_in(0, vec![1, 3, 5]);
        assert!(filter.matches(&[1]));
        assert!(filter.matches(&[3]));
        assert!(filter.matches(&[5]));
        assert!(!filter.matches(&[2]));
    }

    #[test]
    fn test_filter_range() {
        let filter = Filter::label_range(0, 10, 20);
        assert!(filter.matches(&[10]));
        assert!(filter.matches(&[15]));
        assert!(filter.matches(&[20]));
        assert!(!filter.matches(&[9]));
        assert!(!filter.matches(&[21]));
    }

    #[test]
    fn test_filter_and() {
        let filter = Filter::and(vec![
            Filter::label_eq(0, 5),
            Filter::label_gt(1, 10),
        ]);
        assert!(filter.matches(&[5, 15]));
        assert!(!filter.matches(&[5, 5]));
        assert!(!filter.matches(&[4, 15]));
    }

    #[test]
    fn test_filter_or() {
        let filter = Filter::or(vec![
            Filter::label_eq(0, 5),
            Filter::label_eq(0, 10),
        ]);
        assert!(filter.matches(&[5]));
        assert!(filter.matches(&[10]));
        assert!(!filter.matches(&[7]));
    }

    #[test]
    fn test_filtered_search_basic() {
        let files = IndexFiles::new("test_filtered");
        let base_path = files.base_path();

        let vectors: Vec<Vec<f32>> = (0..100)
            .map(|i| vec![i as f32, (i * 2) as f32])
            .collect();
        let labels: Vec<Vec<u64>> = (0..100).map(|i| vec![i % 5]).collect();

        let index = FilteredDiskANN::<DistL2>::build(&vectors, &labels, base_path).unwrap();

        let results = index.search(&[50.0, 100.0], 5, 32);
        assert_eq!(results.len(), 5);

        let filter = Filter::label_eq(0, 0);
        let results = index.search_filtered(&[50.0, 100.0], 5, 32, &filter);
        for id in &results {
            assert_eq!(labels[*id as usize][0], 0);
        }
    }

    #[test]
    fn test_filtered_search_selectivity() {
        let files = IndexFiles::new("test_filtered_sel");
        let base_path = files.base_path();

        let vectors: Vec<Vec<f32>> = (0..1000)
            .map(|i| vec![(i % 100) as f32, ((i / 100) * 10) as f32])
            .collect();
        let labels: Vec<Vec<u64>> = (0..1000).map(|i| vec![i % 10]).collect();

        let index = FilteredDiskANN::<DistL2>::build(&vectors, &labels, base_path).unwrap();

        let filter = Filter::label_eq(0, 3);
        assert_eq!(index.count_matching(&filter), 100);

        let results = index.search_filtered(&[50.0, 50.0], 10, 64, &filter);
        assert!(results.len() <= 10);
        for id in &results {
            assert_eq!(labels[*id as usize][0], 3);
        }
    }

    #[test]
    fn test_filtered_persistence() {
        let files = IndexFiles::new("test_filtered_persist");
        let base_path = files.base_path();

        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| vec![i as f32, i as f32])
            .collect();
        let labels: Vec<Vec<u64>> = (0..50).map(|i| vec![i % 3, i]).collect();

        {
            // Build and immediately drop, so `open` below reads only from disk.
            let _index = FilteredDiskANN::<DistL2>::build(&vectors, &labels, base_path).unwrap();
        }

        let index = FilteredDiskANN::<DistL2>::open(base_path).unwrap();
        assert_eq!(index.num_vectors(), 50);
        assert_eq!(index.num_fields(), 2);

        let filter = Filter::label_eq(0, 1);
        let results = index.search_filtered(&[25.0, 25.0], 5, 32, &filter);
        for id in &results {
            assert_eq!(index.get_labels(*id as usize).unwrap()[0], 1);
        }
    }

    #[test]
    fn test_filtered_to_bytes_from_bytes() {
        let files = IndexFiles::new("test_filtered_bytes_rt");
        let base_path = files.base_path();

        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| vec![i as f32, i as f32])
            .collect();
        let labels: Vec<Vec<u64>> = (0..50).map(|i| vec![i % 3]).collect();

        let index = FilteredDiskANN::<DistL2>::build(&vectors, &labels, base_path).unwrap();
        let bytes = index.to_bytes();

        let index2 = FilteredDiskANN::<DistL2>::from_bytes(bytes, DistL2 {}).unwrap();
        assert_eq!(index2.num_vectors(), 50);
        assert_eq!(index2.num_fields(), 1);

        let filter = Filter::label_eq(0, 1);
        let results = index2.search_filtered(&[25.0, 25.0], 5, 32, &filter);
        for id in &results {
            assert_eq!(index2.get_labels(*id as usize).unwrap()[0], 1);
        }
    }
}