use crate::distance::cosine_distance_normalized;
use crate::RetrieveError;
use smallvec::SmallVec;
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet};
/// A typed attribute value attached to an indexed vector and tested by filter predicates.
///
/// Values are only comparable with values of the same variant: a predicate whose value is a
/// different variant from the stored one (e.g. `Int(1)` vs `Float(1.0)`) never matches.
/// `Float` values are compared with `partial_cmp`, so a NaN matches no predicate.
#[derive(Clone, Debug, PartialEq)]
pub enum AttrValue {
    /// Text value, ordered by `String`'s `Ord` (byte-wise lexicographic).
    String(String),
    /// Signed 64-bit integer.
    Int(i64),
    /// 64-bit float (NaN is incomparable).
    Float(f64),
    /// Boolean, with `false < true`.
    Bool(bool),
}
/// A single attribute test evaluated against the index's inverted indexes.
///
/// The `String` field names the attribute. Vectors that lack the attribute, or whose stored
/// value is a different [`AttrValue`] variant from the predicate's value, never match.
#[derive(Clone, Debug, PartialEq)]
pub enum Predicate {
    /// Attribute equals the value.
    Eq(String, AttrValue),
    /// Attribute is strictly less than the value.
    Lt(String, AttrValue),
    /// Attribute is less than or equal to the value.
    Le(String, AttrValue),
    /// Attribute is strictly greater than the value.
    Gt(String, AttrValue),
    /// Attribute is greater than or equal to the value.
    Ge(String, AttrValue),
    /// Attribute equals any of the listed values; an empty list matches nothing.
    In(String, Vec<AttrValue>),
}

/// A boolean combination of [`Predicate`]s.
#[derive(Clone, Debug, PartialEq)]
pub enum Filter {
    /// A single predicate.
    Clause(Predicate),
    /// Intersection of the children; an empty `And` matches every vector.
    And(Vec<Filter>),
    /// Union of the children; an empty `Or` matches nothing.
    Or(Vec<Filter>),
}
/// Construction and search parameters for [`FilteredGraphIndex`].
#[derive(Clone, Debug)]
pub struct FilteredGraphParams {
    /// Maximum number of out-neighbours kept per node; also the `k` of the initial k-NN graph.
    pub max_degree: usize,
    /// Beam width of the per-node search used to refine neighbour lists during `build`.
    pub ef_construction: usize,
    /// Default beam width for searches (each search raises it to at least `k`).
    pub ef_search: usize,
    /// Pruning factor applied when selecting a node's neighbours from its candidates.
    pub alpha: f32,
}
impl Default for FilteredGraphParams {
    /// Defaults: `max_degree = 32`, `ef_construction = 200`, `ef_search = 100`, `alpha = 1.2`.
    fn default() -> Self {
        let (max_degree, ef_construction, ef_search) = (32, 200, 100);
        let alpha = 1.2;
        Self {
            max_degree,
            ef_construction,
            ef_search,
            alpha,
        }
    }
}
/// Graph-based approximate nearest-neighbour index over cosine distance, with attribute
/// filtering served by per-attribute inverted indexes.
///
/// Vectors are staged with `add`/`add_slice`; `build` then constructs the graph and the
/// inverted indexes, after which no more vectors can be added.
pub struct FilteredGraphIndex {
    /// Number of components in every stored vector.
    dimension: usize,
    params: FilteredGraphParams,
    /// Set by `build`; adding requires `false`, searching requires `true`.
    built: bool,
    /// All vectors flattened row-major (`num_vectors * dimension` floats), each L2-normalized
    /// on insertion unless its norm is <= 1e-10.
    vectors: Vec<f32>,
    num_vectors: usize,
    /// Caller-supplied document id for each internal id (row index).
    doc_ids: Vec<u32>,
    /// Per-vector attributes, held until `build` indexes them into `inverted` and clears them.
    staging_attrs: Vec<HashMap<String, AttrValue>>,
    /// Out-neighbour list (internal ids) for each node.
    neighbors: Vec<SmallVec<[u32; 16]>>,
    /// Internal id of the search entry point, chosen during `build`.
    medoid: u32,
    /// Attribute name -> `(value, internal ids)` entries, sorted by value.
    inverted: HashMap<String, Vec<(AttrValue, Vec<u32>)>>,
}
impl FilteredGraphIndex {
pub fn new(dimension: usize, params: FilteredGraphParams) -> Result<Self, RetrieveError> {
if dimension == 0 {
return Err(RetrieveError::InvalidParameter(
"dimension must be > 0".into(),
));
}
Ok(Self {
dimension,
params,
built: false,
vectors: Vec::new(),
num_vectors: 0,
doc_ids: Vec::new(),
staging_attrs: Vec::new(),
neighbors: Vec::new(),
medoid: 0,
inverted: HashMap::new(),
})
}
pub fn add(
&mut self,
doc_id: u32,
vector: Vec<f32>,
attrs: HashMap<String, AttrValue>,
) -> Result<(), RetrieveError> {
self.add_slice(doc_id, &vector, attrs)
}
pub fn add_slice(
&mut self,
doc_id: u32,
vector: &[f32],
attrs: HashMap<String, AttrValue>,
) -> Result<(), RetrieveError> {
if self.built {
return Err(RetrieveError::InvalidParameter(
"cannot add vectors after build".into(),
));
}
if vector.len() != self.dimension {
return Err(RetrieveError::DimensionMismatch {
query_dim: vector.len(),
doc_dim: self.dimension,
});
}
let norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 1e-10 {
self.vectors.extend(vector.iter().map(|x| x / norm));
} else {
self.vectors.extend_from_slice(vector);
}
self.doc_ids.push(doc_id);
self.staging_attrs.push(attrs);
self.num_vectors += 1;
Ok(())
}
pub fn build(&mut self) -> Result<(), RetrieveError> {
if self.built {
return Ok(());
}
if self.num_vectors == 0 {
return Err(RetrieveError::EmptyIndex);
}
self.medoid = self.compute_medoid();
self.build_knn_graph();
self.rng_refine();
self.ensure_connectivity();
self.build_inverted_indexes();
self.built = true;
Ok(())
}
pub fn search_filtered(
&self,
query: &[f32],
k: usize,
filter: &Filter,
) -> Result<Vec<(u32, f32)>, RetrieveError> {
if !self.built {
return Err(RetrieveError::InvalidParameter(
"index must be built before search".into(),
));
}
if query.len() != self.dimension {
return Err(RetrieveError::DimensionMismatch {
query_dim: query.len(),
doc_dim: self.dimension,
});
}
let qn = normalize(query);
let matching = self.evaluate_filter(filter);
if matching.is_empty() {
return Ok(Vec::new());
}
let selectivity = matching.len() as f32 / self.num_vectors as f32;
let mut results: Vec<(u32, f32)> = if selectivity >= 0.10 {
let ef = self.params.ef_search.max(k * 4);
let candidates = self.beam_search(&qn, ef);
candidates
.into_iter()
.filter(|(id, _)| matching.contains(id))
.take(k)
.map(|(id, d)| (self.doc_ids[id as usize], d))
.collect()
} else {
let mut scored: Vec<(u32, f32)> = matching
.iter()
.map(|&id| {
let d = cosine_distance_normalized(&qn, self.get_vector(id as usize));
(id, d)
})
.collect();
scored.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
scored.truncate(k);
scored
.into_iter()
.map(|(id, d)| (self.doc_ids[id as usize], d))
.collect()
};
results.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
Ok(results)
}
pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(u32, f32)>, RetrieveError> {
if !self.built {
return Err(RetrieveError::InvalidParameter(
"index must be built before search".into(),
));
}
if query.len() != self.dimension {
return Err(RetrieveError::DimensionMismatch {
query_dim: query.len(),
doc_dim: self.dimension,
});
}
let qn = normalize(query);
let ef = self.params.ef_search.max(k);
let candidates = self.beam_search(&qn, ef);
Ok(candidates
.into_iter()
.take(k)
.map(|(id, d)| (self.doc_ids[id as usize], d))
.collect())
}
pub fn search_with_ef(
&self,
query: &[f32],
k: usize,
ef_search: usize,
) -> Result<Vec<(u32, f32)>, RetrieveError> {
if !self.built {
return Err(RetrieveError::InvalidParameter(
"index must be built before search".into(),
));
}
if query.len() != self.dimension {
return Err(RetrieveError::DimensionMismatch {
query_dim: query.len(),
doc_dim: self.dimension,
});
}
let qn = normalize(query);
let ef = ef_search.max(k);
let candidates = self.beam_search(&qn, ef);
Ok(candidates
.into_iter()
.take(k)
.map(|(id, d)| (self.doc_ids[id as usize], d))
.collect())
}
pub fn len(&self) -> usize {
self.num_vectors
}
pub fn is_empty(&self) -> bool {
self.num_vectors == 0
}
fn evaluate_filter(&self, filter: &Filter) -> HashSet<u32> {
match filter {
Filter::Clause(pred) => self.evaluate_predicate(pred),
Filter::And(children) => {
if children.is_empty() {
return (0..self.num_vectors as u32).collect();
}
let mut result = self.evaluate_filter(&children[0]);
for child in &children[1..] {
if result.is_empty() {
break;
}
let child_set = self.evaluate_filter(child);
result.retain(|id| child_set.contains(id));
}
result
}
Filter::Or(children) => {
let mut result = HashSet::new();
for child in children {
result.extend(self.evaluate_filter(child));
}
result
}
}
}
fn evaluate_predicate(&self, pred: &Predicate) -> HashSet<u32> {
match pred {
Predicate::Eq(attr, val) => self.inverted_eq(attr, val),
Predicate::Lt(attr, val) => self.inverted_range(attr, None, Some(val), false, false),
Predicate::Le(attr, val) => self.inverted_range(attr, None, Some(val), false, true),
Predicate::Gt(attr, val) => self.inverted_range(attr, Some(val), None, false, false),
Predicate::Ge(attr, val) => self.inverted_range(attr, Some(val), None, true, false),
Predicate::In(attr, vals) => {
let mut result = HashSet::new();
for v in vals {
result.extend(self.inverted_eq(attr, v));
}
result
}
}
}
fn inverted_eq(&self, attr: &str, val: &AttrValue) -> HashSet<u32> {
let Some(entries) = self.inverted.get(attr) else {
return HashSet::new();
};
let pos = entries
.partition_point(|(v, _)| compare_attr(v, val).is_none_or(|o| o == Ordering::Less));
let mut result = HashSet::new();
for (v, ids) in &entries[pos..] {
if compare_attr(v, val) != Some(Ordering::Equal) {
break;
}
result.extend(ids.iter().copied());
}
result
}
fn inverted_range(
&self,
attr: &str,
lo: Option<&AttrValue>,
hi: Option<&AttrValue>,
lo_inclusive: bool,
hi_inclusive: bool,
) -> HashSet<u32> {
let Some(entries) = self.inverted.get(attr) else {
return HashSet::new();
};
let mut result = HashSet::new();
for (v, ids) in entries {
let lo_ok = match lo {
None => true,
Some(bound) => match compare_attr(v, bound) {
Some(Ordering::Greater) => true,
Some(Ordering::Equal) => lo_inclusive,
_ => false,
},
};
let hi_ok = match hi {
None => true,
Some(bound) => match compare_attr(v, bound) {
Some(Ordering::Less) => true,
Some(Ordering::Equal) => hi_inclusive,
_ => false,
},
};
if lo_ok && hi_ok {
result.extend(ids.iter().copied());
}
}
result
}
fn compute_medoid(&self) -> u32 {
let dim = self.dimension;
let n = self.num_vectors;
let mut centroid = vec![0.0f32; dim];
for i in 0..n {
let v = self.get_vector(i);
for (j, &val) in v.iter().enumerate() {
centroid[j] += val;
}
}
for c in &mut centroid {
*c /= n as f32;
}
let mut best = 0u32;
let mut best_d = f32::INFINITY;
for i in 0..n {
let d = cosine_distance_normalized(¢roid, self.get_vector(i));
if d < best_d {
best_d = d;
best = i as u32;
}
}
best
}
fn build_knn_graph(&mut self) {
let n = self.num_vectors;
if n <= 1000 {
self.build_knn_graph_bruteforce();
} else {
self.build_knn_graph_nndescent();
}
}
fn build_knn_graph_bruteforce(&mut self) {
let n = self.num_vectors;
let k = self.params.max_degree.min(n.saturating_sub(1));
self.neighbors = vec![SmallVec::new(); n];
for i in 0..n {
let vi = self.get_vector(i);
let mut dists: Vec<(u32, f32)> = (0..n)
.filter(|&j| j != i)
.map(|j| (j as u32, cosine_distance_normalized(vi, self.get_vector(j))))
.collect();
dists.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
dists.truncate(k);
self.neighbors[i] = dists.iter().map(|(id, _)| *id).collect();
}
}
fn build_knn_graph_nndescent(&mut self) {
let (n, k, dim) = (self.num_vectors, self.params.max_degree, self.dimension);
let vecs = &self.vectors;
self.neighbors = crate::graph_utils::build_knn_graph_nndescent(n, k, |i, j| {
cosine_distance_normalized(&vecs[i * dim..(i + 1) * dim], &vecs[j * dim..(j + 1) * dim])
});
}
fn rng_refine(&mut self) {
let n = self.num_vectors;
let ef = self.params.ef_construction;
for i in 0..n {
let vi = self.get_vector(i).to_vec();
let candidates = self.beam_search(&vi, ef);
let selected = self.rng_prune(&vi, &candidates);
let old = std::mem::replace(
&mut self.neighbors[i],
selected.iter().map(|&(id, _)| id).collect(),
);
let max_deg = self.params.max_degree;
for &(nb_id, _) in &selected {
let nid = nb_id as usize;
if !self.neighbors[nid].contains(&(i as u32)) {
if self.neighbors[nid].len() < max_deg {
self.neighbors[nid].push(i as u32);
} else {
let nv = self.get_vector(nid).to_vec();
let rev_cands: Vec<(u32, f32)> = self.neighbors[nid]
.iter()
.chain(std::iter::once(&(i as u32)))
.map(|&id| {
let d =
cosine_distance_normalized(&nv, self.get_vector(id as usize));
(id, d)
})
.collect();
let pruned = self.rng_prune(&nv, &rev_cands);
self.neighbors[nid] = pruned.iter().map(|&(id, _)| id).collect();
}
}
}
drop(old);
}
}
fn rng_prune(&self, _query_vec: &[f32], candidates: &[(u32, f32)]) -> Vec<(u32, f32)> {
let mut sorted: Vec<(u32, f32)> = candidates.to_vec();
sorted.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
sorted.dedup_by_key(|c| c.0);
let max_deg = self.params.max_degree;
let alpha = self.params.alpha;
let mut selected: Vec<(u32, f32)> = Vec::with_capacity(max_deg);
'outer: for &(cand_id, cand_dist) in &sorted {
if selected.len() >= max_deg {
break;
}
let cand_vec = self.get_vector(cand_id as usize);
for &(sel_id, _) in &selected {
let sel_vec = self.get_vector(sel_id as usize);
let inter = cosine_distance_normalized(sel_vec, cand_vec);
if alpha * cand_dist >= inter {
continue 'outer;
}
}
selected.push((cand_id, cand_dist));
}
selected
}
fn ensure_connectivity(&mut self) {
let (dim, vecs) = (self.dimension, &self.vectors);
crate::graph_utils::ensure_connectivity(&mut self.neighbors, self.medoid, |i, j| {
cosine_distance_normalized(&vecs[i * dim..(i + 1) * dim], &vecs[j * dim..(j + 1) * dim])
});
}
fn build_inverted_indexes(&mut self) {
let mut raw: HashMap<String, Vec<(AttrValue, u32)>> = HashMap::new();
for (internal_id, attrs) in self.staging_attrs.iter().enumerate() {
for (attr, val) in attrs {
raw.entry(attr.clone())
.or_default()
.push((val.clone(), internal_id as u32));
}
}
for (attr, mut pairs) in raw {
pairs.sort_unstable_by(|(a, _), (b, _)| compare_attr(a, b).unwrap_or(Ordering::Equal));
let mut entries: Vec<(AttrValue, Vec<u32>)> = Vec::new();
for (val, id) in pairs {
if let Some(last) = entries.last_mut() {
if last.0 == val {
last.1.push(id);
continue;
}
}
entries.push((val, vec![id]));
}
for (_, ids) in &mut entries {
ids.sort_unstable();
}
self.inverted.insert(attr, entries);
}
self.staging_attrs = Vec::new();
}
fn beam_search(&self, query: &[f32], ef: usize) -> Vec<(u32, f32)> {
let n = self.num_vectors;
if n == 0 {
return Vec::new();
}
let mut visited: HashSet<u32> = HashSet::new();
let mut frontier: BinaryHeap<std::cmp::Reverse<(FloatOrd, u32)>> = BinaryHeap::new();
let mut candidates: Vec<(u32, f32)> = Vec::new();
let entry = self.medoid;
let entry_dist = cosine_distance_normalized(query, self.get_vector(entry as usize));
visited.insert(entry);
frontier.push(std::cmp::Reverse((FloatOrd(entry_dist), entry)));
candidates.push((entry, entry_dist));
while let Some(std::cmp::Reverse((FloatOrd(current_dist), current_id))) = frontier.pop() {
if candidates.len() >= ef {
candidates.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
if current_dist > candidates[ef - 1].1 * 1.5 {
break;
}
}
for &nb in &self.neighbors[current_id as usize] {
if visited.insert(nb) {
let d = cosine_distance_normalized(query, self.get_vector(nb as usize));
candidates.push((nb, d));
frontier.push(std::cmp::Reverse((FloatOrd(d), nb)));
}
}
if visited.len() > ef * 10 {
break;
}
}
candidates.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
candidates.dedup_by_key(|c| c.0);
candidates
}
#[inline]
fn get_vector(&self, idx: usize) -> &[f32] {
let start = idx * self.dimension;
&self.vectors[start..start + self.dimension]
}
}
/// Orders two attribute values of the same variant.
///
/// Returns `None` when the variants differ, or when a `Float` comparison involves NaN.
/// Callers treat `None` as "predicate does not match".
fn compare_attr(a: &AttrValue, b: &AttrValue) -> Option<Ordering> {
    let ordering = match (a, b) {
        (AttrValue::Float(x), AttrValue::Float(y)) => return x.partial_cmp(y),
        (AttrValue::Int(x), AttrValue::Int(y)) => x.cmp(y),
        (AttrValue::String(x), AttrValue::String(y)) => x.cmp(y),
        (AttrValue::Bool(x), AttrValue::Bool(y)) => x.cmp(y),
        _ => return None,
    };
    Some(ordering)
}
/// Returns a copy of `v` scaled to unit L2 norm; vectors with norm <= 1e-10 are copied
/// unchanged.
fn normalize(v: &[f32]) -> Vec<f32> {
    let magnitude: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let mut unit = v.to_vec();
    if magnitude > 1e-10 {
        for x in &mut unit {
            *x /= magnitude;
        }
    }
    unit
}
use crate::distance::FloatOrd;
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Linear congruential generator returning values in roughly `[-1, 1)`.
    fn lcg(seed: &mut u64) -> f32 {
        *seed = seed.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
        ((*seed >> 33) as f32 / (1u64 << 31) as f32) - 1.0
    }

    /// `n * dim` deterministic pseudo-random floats, laid out row-major.
    fn make_vectors(n: usize, dim: usize, seed: u64) -> Vec<f32> {
        let mut state = seed;
        std::iter::repeat_with(|| lcg(&mut state))
            .take(n * dim)
            .collect()
    }

    /// Test parameters: the given degree, one beam width for construction and search, alpha 1.2.
    fn params(max_degree: usize, ef: usize) -> FilteredGraphParams {
        FilteredGraphParams {
            max_degree,
            ef_construction: ef,
            ef_search: ef,
            alpha: 1.2,
        }
    }

    /// Adds each `dim`-wide row of `data` with doc id = row index and `attrs_for(row)`, then builds.
    fn index_over(
        data: &[f32],
        dim: usize,
        params: FilteredGraphParams,
        attrs_for: impl Fn(usize) -> HashMap<String, AttrValue>,
    ) -> FilteredGraphIndex {
        let mut idx = FilteredGraphIndex::new(dim, params).unwrap();
        for (i, row) in data.chunks_exact(dim).enumerate() {
            idx.add_slice(i as u32, row, attrs_for(i)).unwrap();
        }
        idx.build().unwrap();
        idx
    }

    fn text(s: &str) -> AttrValue {
        AttrValue::String(s.to_string())
    }

    fn attrs(pairs: &[(&str, AttrValue)]) -> HashMap<String, AttrValue> {
        pairs
            .iter()
            .map(|(key, value)| ((*key).to_owned(), value.clone()))
            .collect()
    }

    fn eq_clause(attr: &str, value: AttrValue) -> Filter {
        Filter::Clause(Predicate::Eq(attr.to_string(), value))
    }

    fn build_index(n: usize, dim: usize, seed: u64) -> FilteredGraphIndex {
        index_over(&make_vectors(n, dim, seed), dim, params(16, 40), |_| HashMap::new())
    }

    #[test]
    fn build_and_search_unfiltered() {
        let (n, dim) = (50, 16);
        let idx = build_index(n, dim, 42);
        let data = make_vectors(n, dim, 42);
        let results = idx.search(&data[..dim], 5).unwrap();
        assert!(!results.is_empty());
        assert!(results.iter().any(|&(id, _)| id == 0));
    }

    #[test]
    fn eq_filter() {
        let (n, dim) = (30, 8);
        let data = make_vectors(n, dim, 7);
        let idx = index_over(&data, dim, params(12, 30), |i| {
            attrs(&[("parity", text(if i % 2 == 0 { "even" } else { "odd" }))])
        });
        let filter = eq_clause("parity", text("even"));
        let results = idx.search_filtered(&data[..dim], 20, &filter).unwrap();
        for (doc_id, _) in &results {
            assert_eq!(doc_id % 2, 0, "doc_id {doc_id} is not even");
        }
        assert!(!results.is_empty());
    }

    #[test]
    fn range_filter() {
        let (n, dim) = (40, 8);
        let data = make_vectors(n, dim, 13);
        let idx = index_over(&data, dim, params(12, 30), |i| {
            attrs(&[("score", AttrValue::Int(i as i64))])
        });
        let filter = Filter::And(vec![
            Filter::Clause(Predicate::Gt("score".to_string(), AttrValue::Int(9))),
            Filter::Clause(Predicate::Lt("score".to_string(), AttrValue::Int(21))),
        ]);
        let results = idx.search_filtered(&data[..dim], 20, &filter).unwrap();
        for &(doc_id, _) in &results {
            let score = i64::from(doc_id);
            assert!(score > 9 && score < 21, "score {score} outside (9, 21)");
        }
        assert!(!results.is_empty());
    }

    #[test]
    fn and_filter() {
        let (n, dim) = (40, 8);
        let data = make_vectors(n, dim, 17);
        let idx = index_over(&data, dim, params(12, 30), |i| {
            attrs(&[
                ("parity", text(if i % 2 == 0 { "even" } else { "odd" })),
                ("tier", text(if i < 20 { "low" } else { "high" })),
            ])
        });
        let filter = Filter::And(vec![
            eq_clause("parity", text("even")),
            eq_clause("tier", text("high")),
        ]);
        let results = idx.search_filtered(&data[..dim], 20, &filter).unwrap();
        for &(doc_id, _) in &results {
            assert_eq!(doc_id % 2, 0, "doc_id {doc_id} should be even");
            assert!(doc_id >= 20, "doc_id {doc_id} should be >= 20");
        }
    }

    #[test]
    fn or_filter() {
        let (n, dim) = (30, 8);
        let data = make_vectors(n, dim, 23);
        let groups = ["alpha", "beta", "gamma"];
        let idx = index_over(&data, dim, params(12, 30), |i| {
            attrs(&[("group", text(groups[i % 3]))])
        });
        let filter = Filter::Or(vec![
            eq_clause("group", text("alpha")),
            eq_clause("group", text("beta")),
        ]);
        let results = idx.search_filtered(&data[..dim], 20, &filter).unwrap();
        for (doc_id, _) in &results {
            assert_ne!(
                doc_id % 3,
                2,
                "doc_id {doc_id} should not be in group gamma"
            );
        }
        assert!(!results.is_empty());
    }

    #[test]
    fn no_match_returns_empty() {
        let (n, dim) = (20, 8);
        let data = make_vectors(n, dim, 99);
        let idx = index_over(&data, dim, params(8, 20), |_| {
            attrs(&[("tag", text("present"))])
        });
        let filter = eq_clause("tag", text("absent"));
        let results = idx.search_filtered(&data[..dim], 5, &filter).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn empty_index_errors() {
        let mut idx = FilteredGraphIndex::new(8, FilteredGraphParams::default()).unwrap();
        assert!(idx.build().is_err());
    }
}