use crate::distance::cosine_distance_normalized;
use crate::distance::FloatOrd;
use crate::RetrieveError;
use std::collections::{BinaryHeap, HashMap};
/// Tuning knobs for [`CuratorIndex`] tree construction and search.
#[derive(Clone, Debug)]
pub struct CuratorParams {
    /// Number of k-means clusters an oversized node is split into
    /// (capped at the number of vectors in that node).
    pub branching_factor: usize,
    /// Nodes holding at most this many vectors are kept as leaves.
    pub max_leaf_size: usize,
    /// Maximum number of tree nodes a single search expands.
    pub ef_search: usize,
    /// Reserved; not read anywhere in this module at present.
    pub beam_width: usize,
}

impl Default for CuratorParams {
    /// 16-way splits, leaves of up to 128 vectors, a 256-node search budget
    /// and a beam width of 4.
    fn default() -> Self {
        let branching_factor = 16;
        let max_leaf_size = 128;
        let ef_search = 256;
        let beam_width = 4;
        CuratorParams {
            branching_factor,
            max_leaf_size,
            ef_search,
            beam_width,
        }
    }
}
/// Fixed-size Bloom filter over pre-hashed `u64` keys.
///
/// Each key sets two bits, so `may_contain` can return false positives but
/// never false negatives.
#[derive(Debug)]
struct BloomFilter {
    /// Bit array packed into 64-bit words.
    bits: Vec<u64>,
    /// Number of addressable bits; always at least 1.
    num_bits: usize,
}

impl BloomFilter {
    /// Creates an empty filter with `num_bits` bits.
    ///
    /// A request for 0 bits is rounded up to 1 so the modulo in
    /// [`Self::positions`] can never divide by zero.
    fn new(num_bits: usize) -> Self {
        let num_bits = num_bits.max(1);
        Self {
            bits: vec![0u64; num_bits.div_ceil(64)],
            num_bits,
        }
    }

    /// The two bit positions that represent `hash`.
    ///
    /// The first comes from the low bits of `hash` (truncated to `usize`). The
    /// second comes from the high 32 bits multiplied by the 32-bit golden-ratio
    /// constant `0x9e3779b9`.
    fn positions(&self, hash: u64) -> [usize; 2] {
        let h1 = hash as usize % self.num_bits;
        let h2 = (hash.wrapping_shr(32) as usize).wrapping_mul(0x9e3779b9) % self.num_bits;
        [h1, h2]
    }

    /// Records `hash` in the filter.
    fn insert(&mut self, hash: u64) {
        for bit in self.positions(hash) {
            self.bits[bit / 64] |= 1u64 << (bit % 64);
        }
    }

    /// Returns `false` only if `hash` was definitely never inserted.
    fn may_contain(&self, hash: u64) -> bool {
        self.positions(hash)
            .iter()
            .all(|&bit| self.bits[bit / 64] & (1u64 << (bit % 64)) != 0)
    }
}
/// One node of the clustering tree owned by `CuratorIndex`.
///
/// Leaves have an empty `children` list and hold vector ids; internal nodes
/// hold child indices and an empty `vector_ids`.
#[derive(Debug)]
struct TreeNode {
/// Centroid of the vectors under this node; used to order nodes during search.
centroid: Vec<f32>,
/// Indices into `CuratorIndex::nodes` of this node's children (empty for leaves).
children: Vec<usize>,
/// Internal vector ids stored in this leaf (empty for internal nodes).
vector_ids: Vec<u32>,
/// Bloom filter over the label hashes of every vector in this subtree.
label_bloom: BloomFilter,
/// Leaf-only: label hash -> internal ids in this leaf carrying that label.
label_buffers: HashMap<u64, Vec<u32>>,
}
/// Label-filtered approximate nearest-neighbour index over cosine distance.
///
/// Usage is two-phase:
/// 1. Stage vectors with [`CuratorIndex::add`].
/// 2. Call [`CuratorIndex::build`] once to cluster them into a tree and index
///    their labels.
/// 3. Query with [`CuratorIndex::search`] or [`CuratorIndex::search_filtered`].
///
/// Adding after `build` is rejected.
#[derive(Debug)]
pub struct CuratorIndex {
    /// Length every stored vector and every query must have.
    dimension: usize,
    /// Tree-construction and search parameters.
    params: CuratorParams,
    /// Set once `build` succeeds. After that, `add` is rejected and searches are allowed.
    built: bool,
    /// All vectors stored contiguously, `dimension` floats per internal id.
    /// Each is L2-normalized on insert unless its norm is (near) zero.
    vectors: Vec<f32>,
    /// Number of vectors added; internal ids are `0..num_vectors`.
    num_vectors: usize,
    /// Caller-supplied doc id for each internal id.
    doc_ids: Vec<u32>,
    /// Labels passed to `add`, per internal id; indexed into the tree by `build`.
    staging_labels: Vec<Vec<String>>,
    /// Tree nodes; children reference other entries by index.
    nodes: Vec<TreeNode>,
    /// Index of the root node in `nodes` (meaningful only after `build`).
    root: usize,
}
impl CuratorIndex {
pub fn new(dimension: usize, params: CuratorParams) -> Result<Self, RetrieveError> {
if dimension == 0 {
return Err(RetrieveError::InvalidParameter(
"dimension must be > 0".into(),
));
}
Ok(Self {
dimension,
params,
built: false,
vectors: Vec::new(),
num_vectors: 0,
doc_ids: Vec::new(),
staging_labels: Vec::new(),
nodes: Vec::new(),
root: 0,
})
}
pub fn add(
&mut self,
doc_id: u32,
vector: Vec<f32>,
labels: Vec<String>,
) -> Result<(), RetrieveError> {
if self.built {
return Err(RetrieveError::InvalidParameter(
"cannot add after build".into(),
));
}
if vector.len() != self.dimension {
return Err(RetrieveError::DimensionMismatch {
query_dim: vector.len(),
doc_dim: self.dimension,
});
}
let norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 1e-10 {
self.vectors.extend(vector.iter().map(|x| x / norm));
} else {
self.vectors.extend_from_slice(&vector);
}
self.doc_ids.push(doc_id);
self.staging_labels.push(labels);
self.num_vectors += 1;
Ok(())
}
pub fn build(&mut self) -> Result<(), RetrieveError> {
if self.built {
return Ok(());
}
if self.num_vectors == 0 {
return Err(RetrieveError::EmptyIndex);
}
self.nodes.clear();
let all_ids: Vec<u32> = (0..self.num_vectors as u32).collect();
self.root = self.build_tree(&all_ids, 0);
let labels = std::mem::take(&mut self.staging_labels);
for (internal_id, doc_labels) in labels.iter().enumerate() {
for label in doc_labels {
let label_hash = hash_label(label);
self.insert_label_entry(self.root, internal_id as u32, label_hash);
}
}
self.staging_labels = labels;
self.built = true;
Ok(())
}
pub fn search_filtered(
&self,
query: &[f32],
k: usize,
label: &str,
) -> Result<Vec<(u32, f32)>, RetrieveError> {
if !self.built {
return Err(RetrieveError::InvalidParameter(
"index must be built before search".into(),
));
}
if query.len() != self.dimension {
return Err(RetrieveError::DimensionMismatch {
query_dim: query.len(),
doc_dim: self.dimension,
});
}
let query_norm: f32 = query.iter().map(|x| x * x).sum::<f32>().sqrt();
let query_normalized: Vec<f32> = if query_norm > 1e-10 {
query.iter().map(|x| x / query_norm).collect()
} else {
query.to_vec()
};
let label_hash = hash_label(label);
self.best_first_search(&query_normalized, k, label_hash)
}
pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(u32, f32)>, RetrieveError> {
if !self.built {
return Err(RetrieveError::InvalidParameter(
"index must be built before search".into(),
));
}
if query.len() != self.dimension {
return Err(RetrieveError::DimensionMismatch {
query_dim: query.len(),
doc_dim: self.dimension,
});
}
let query_norm: f32 = query.iter().map(|x| x * x).sum::<f32>().sqrt();
let query_normalized: Vec<f32> = if query_norm > 1e-10 {
query.iter().map(|x| x / query_norm).collect()
} else {
query.to_vec()
};
self.best_first_unfiltered(&query_normalized, k)
}
pub fn len(&self) -> usize {
self.num_vectors
}
pub fn is_empty(&self) -> bool {
self.num_vectors == 0
}
fn build_tree(&mut self, ids: &[u32], depth: usize) -> usize {
let dim = self.dimension;
let mut centroid = vec![0.0f32; dim];
for &id in ids {
let v = self.get_vector(id as usize);
for (j, &val) in v.iter().enumerate() {
centroid[j] += val;
}
}
let n = ids.len() as f32;
for c in &mut centroid {
*c /= n;
}
let node_idx = self.nodes.len();
if ids.len() <= self.params.max_leaf_size || depth > 15 {
self.nodes.push(TreeNode {
centroid,
children: Vec::new(),
vector_ids: ids.to_vec(),
label_bloom: BloomFilter::new(256),
label_buffers: HashMap::new(),
});
return node_idx;
}
let k = self.params.branching_factor.min(ids.len());
let assignments = self.simple_kmeans(ids, k);
self.nodes.push(TreeNode {
centroid,
children: Vec::new(),
vector_ids: Vec::new(),
label_bloom: BloomFilter::new(256),
label_buffers: HashMap::new(),
});
let mut children = Vec::with_capacity(k);
for cluster in &assignments {
if cluster.is_empty() {
continue;
}
let child_idx = self.build_tree(cluster, depth + 1);
children.push(child_idx);
}
self.nodes[node_idx].children = children;
node_idx
}
fn simple_kmeans(&self, ids: &[u32], k: usize) -> Vec<Vec<u32>> {
let dim = self.dimension;
let n = ids.len();
let step = n / k;
let mut centroids: Vec<Vec<f32>> = (0..k)
.map(|i| {
let idx = ids[(i * step).min(n - 1)] as usize;
self.get_vector(idx).to_vec()
})
.collect();
let mut assignments = vec![Vec::new(); k];
for _ in 0..3 {
for a in &mut assignments {
a.clear();
}
for &id in ids {
let v = self.get_vector(id as usize);
let mut best_c = 0;
let mut best_d = f32::INFINITY;
for (ci, c) in centroids.iter().enumerate() {
let d = cosine_distance_normalized(v, c);
if d < best_d {
best_d = d;
best_c = ci;
}
}
assignments[best_c].push(id);
}
for (ci, cluster) in assignments.iter().enumerate() {
if cluster.is_empty() {
continue;
}
let mut new_centroid = vec![0.0f32; dim];
for &id in cluster {
let v = self.get_vector(id as usize);
for (j, &val) in v.iter().enumerate() {
new_centroid[j] += val;
}
}
let cn = cluster.len() as f32;
for c in &mut new_centroid {
*c /= cn;
}
centroids[ci] = new_centroid;
}
}
assignments
}
fn insert_label_entry(&mut self, node_idx: usize, vector_id: u32, label_hash: u64) {
self.nodes[node_idx].label_bloom.insert(label_hash);
if self.nodes[node_idx].children.is_empty() {
self.nodes[node_idx]
.label_buffers
.entry(label_hash)
.or_default()
.push(vector_id);
} else {
let children = self.nodes[node_idx].children.clone();
for &child_idx in &children {
if self.subtree_contains(child_idx, vector_id) {
self.insert_label_entry(child_idx, vector_id, label_hash);
return;
}
}
}
}
fn subtree_contains(&self, node_idx: usize, vector_id: u32) -> bool {
let node = &self.nodes[node_idx];
if node.children.is_empty() {
return node.vector_ids.contains(&vector_id);
}
for &child in &node.children {
if self.subtree_contains(child, vector_id) {
return true;
}
}
false
}
fn best_first_search(
&self,
query: &[f32],
k: usize,
label_hash: u64,
) -> Result<Vec<(u32, f32)>, RetrieveError> {
let num_nodes = self.nodes.len();
thread_local! {
static VISITED: std::cell::RefCell<(Vec<u8>, u8)> =
const { std::cell::RefCell::new((Vec::new(), 1)) };
}
VISITED.with(|cell| {
let (marks, gen) = &mut *cell.borrow_mut();
if marks.len() < num_nodes {
marks.resize(num_nodes, 0);
}
if let Some(next) = gen.checked_add(1) {
*gen = next;
} else {
marks.fill(0);
*gen = 1;
}
let generation = *gen;
let mut visited_insert = |idx: usize| -> bool {
if idx < marks.len() && marks[idx] != generation {
marks[idx] = generation;
true
} else {
idx >= marks.len()
}
};
let mut heap: BinaryHeap<std::cmp::Reverse<(FloatOrd, usize)>> = BinaryHeap::new();
let mut results: Vec<(u32, f32)> = Vec::new();
let mut explored = 0usize;
let root_dist = cosine_distance_normalized(query, &self.nodes[self.root].centroid);
heap.push(std::cmp::Reverse((FloatOrd(root_dist), self.root)));
while let Some(std::cmp::Reverse((_, node_idx))) = heap.pop() {
if !visited_insert(node_idx) {
continue;
}
explored += 1;
if explored > self.params.ef_search {
break;
}
let node = &self.nodes[node_idx];
if !node.label_bloom.may_contain(label_hash) {
continue;
}
if node.children.is_empty() {
if let Some(buffer) = node.label_buffers.get(&label_hash) {
for &vid in buffer {
let v = self.get_vector(vid as usize);
let dist = cosine_distance_normalized(query, v);
let doc_id = self.doc_ids[vid as usize];
results.push((doc_id, dist));
}
}
} else {
for &child_idx in &node.children {
let child_dist =
cosine_distance_normalized(query, &self.nodes[child_idx].centroid);
heap.push(std::cmp::Reverse((FloatOrd(child_dist), child_idx)));
}
}
}
results.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
results.dedup_by_key(|c| c.0);
results.truncate(k);
Ok(results)
})
}
fn best_first_unfiltered(
&self,
query: &[f32],
k: usize,
) -> Result<Vec<(u32, f32)>, RetrieveError> {
let num_nodes = self.nodes.len();
thread_local! {
static VISITED_UF: std::cell::RefCell<(Vec<u8>, u8)> =
const { std::cell::RefCell::new((Vec::new(), 1)) };
}
VISITED_UF.with(|cell| {
let (marks, gen) = &mut *cell.borrow_mut();
if marks.len() < num_nodes {
marks.resize(num_nodes, 0);
}
if let Some(next) = gen.checked_add(1) {
*gen = next;
} else {
marks.fill(0);
*gen = 1;
}
let generation = *gen;
let mut visited_insert = |idx: usize| -> bool {
if idx < marks.len() && marks[idx] != generation {
marks[idx] = generation;
true
} else {
idx >= marks.len()
}
};
let mut heap: BinaryHeap<std::cmp::Reverse<(FloatOrd, usize)>> = BinaryHeap::new();
let mut results: Vec<(u32, f32)> = Vec::new();
let mut explored = 0usize;
let root_dist = cosine_distance_normalized(query, &self.nodes[self.root].centroid);
heap.push(std::cmp::Reverse((FloatOrd(root_dist), self.root)));
while let Some(std::cmp::Reverse((_, node_idx))) = heap.pop() {
if !visited_insert(node_idx) {
continue;
}
explored += 1;
if explored > self.params.ef_search {
break;
}
let node = &self.nodes[node_idx];
if node.children.is_empty() {
for &vid in &node.vector_ids {
let v = self.get_vector(vid as usize);
let dist = cosine_distance_normalized(query, v);
let doc_id = self.doc_ids[vid as usize];
results.push((doc_id, dist));
}
} else {
for &child_idx in &node.children {
let child_dist =
cosine_distance_normalized(query, &self.nodes[child_idx].centroid);
heap.push(std::cmp::Reverse((FloatOrd(child_dist), child_idx)));
}
}
}
results.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
results.dedup_by_key(|c| c.0);
results.truncate(k);
Ok(results)
})
}
#[inline]
fn get_vector(&self, idx: usize) -> &[f32] {
let start = idx * self.dimension;
&self.vectors[start..start + self.dimension]
}
}
/// 64-bit FNV-1a hash of the label's UTF-8 bytes.
fn hash_label(label: &str) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf29ce484222325;
    const FNV_PRIME: u64 = 0x100000001b3;
    label
        .bytes()
        .fold(FNV_OFFSET_BASIS, |acc, byte| {
            (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
        })
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Deterministic test vector; different seeds give different directions.
    fn make_vector(dim: usize, seed: u32) -> Vec<f32> {
        (0..dim)
            .map(|i| (seed as f32 * 0.1 + i as f32 * 0.01).sin())
            .collect()
    }

    #[test]
    fn build_and_filtered_search() {
        let dim = 16;
        let mut index = CuratorIndex::new(dim, CuratorParams::default()).unwrap();
        for i in 0..50u32 {
            let label = if i % 2 == 0 { "red" } else { "blue" };
            index
                .add(i, make_vector(dim, i), vec![label.into()])
                .unwrap();
        }
        index.build().unwrap();
        let query = make_vector(dim, 0);
        let results = index.search_filtered(&query, 5, "red").unwrap();
        assert!(!results.is_empty());
        for (doc_id, _) in &results {
            assert_eq!(doc_id % 2, 0, "expected even doc_id (red), got {}", doc_id);
        }
    }

    #[test]
    fn unfiltered_search() {
        let dim = 16;
        let mut index = CuratorIndex::new(dim, CuratorParams::default()).unwrap();
        for i in 0..30u32 {
            index
                .add(i, make_vector(dim, i), vec!["any".into()])
                .unwrap();
        }
        index.build().unwrap();
        let query = make_vector(dim, 0);
        let results = index.search(&query, 5).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].0, 0);
    }

    #[test]
    fn nonexistent_label_returns_empty() {
        let dim = 16;
        let mut index = CuratorIndex::new(dim, CuratorParams::default()).unwrap();
        for i in 0..20u32 {
            index
                .add(i, make_vector(dim, i), vec!["exists".into()])
                .unwrap();
        }
        index.build().unwrap();
        let query = make_vector(dim, 0);
        let results = index.search_filtered(&query, 5, "nonexistent").unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn multi_label_vectors() {
        let dim = 16;
        let mut index = CuratorIndex::new(dim, CuratorParams::default()).unwrap();
        for i in 0..20u32 {
            let mut labels = vec!["all".to_string()];
            if i % 3 == 0 {
                labels.push("mod3".into());
            }
            index.add(i, make_vector(dim, i), labels).unwrap();
        }
        index.build().unwrap();
        let query = make_vector(dim, 0);
        let results = index.search_filtered(&query, 10, "mod3").unwrap();
        for (doc_id, _) in &results {
            assert_eq!(doc_id % 3, 0, "expected mod3 doc_id, got {}", doc_id);
        }
        let results = index.search_filtered(&query, 5, "all").unwrap();
        assert!(!results.is_empty());
    }

    #[test]
    fn low_selectivity_recall() {
        let dim = 16;
        let mut index = CuratorIndex::new(
            dim,
            CuratorParams {
                max_leaf_size: 32,
                ef_search: 512,
                ..Default::default()
            },
        )
        .unwrap();
        for i in 0..100u32 {
            let label = if i < 5 { "rare" } else { "common" };
            index
                .add(i, make_vector(dim, i), vec![label.into()])
                .unwrap();
        }
        index.build().unwrap();
        let query = make_vector(dim, 2);
        let results = index.search_filtered(&query, 3, "rare").unwrap();
        assert!(!results.is_empty());
        for (doc_id, _) in &results {
            assert!(*doc_id < 5, "expected rare doc_id < 5, got {}", doc_id);
        }
    }

    #[test]
    fn empty_index_errors() {
        let mut index = CuratorIndex::new(8, CuratorParams::default()).unwrap();
        assert!(index.build().is_err());
    }

    #[test]
    fn zero_dimension_rejected() {
        assert!(CuratorIndex::new(0, CuratorParams::default()).is_err());
    }

    #[test]
    fn add_validates_dimension_and_build_state() {
        let mut index = CuratorIndex::new(4, CuratorParams::default()).unwrap();
        assert!(index.add(0, vec![1.0; 3], Vec::new()).is_err());
        assert!(index.is_empty());
        index.add(0, vec![1.0; 4], Vec::new()).unwrap();
        assert_eq!(index.len(), 1);
        index.build().unwrap();
        // A second build is a no-op.
        index.build().unwrap();
        assert!(index.add(1, vec![1.0; 4], Vec::new()).is_err());
        assert_eq!(index.len(), 1);
    }

    #[test]
    fn search_requires_build_and_matching_dimension() {
        let dim = 8;
        let mut index = CuratorIndex::new(dim, CuratorParams::default()).unwrap();
        index.add(0, make_vector(dim, 0), vec!["a".into()]).unwrap();
        assert!(index.search(&make_vector(dim, 0), 1).is_err());
        assert!(index.search_filtered(&make_vector(dim, 0), 1, "a").is_err());
        index.build().unwrap();
        assert!(index.search(&make_vector(dim + 1, 0), 1).is_err());
        assert!(index
            .search_filtered(&make_vector(dim + 1, 0), 1, "a")
            .is_err());
        assert!(index.search(&make_vector(dim, 0), 0).unwrap().is_empty());
    }

    #[test]
    fn duplicate_labels_give_unique_sorted_results() {
        let dim = 16;
        let mut index = CuratorIndex::new(dim, CuratorParams::default()).unwrap();
        for i in 0..12u32 {
            index
                .add(i, make_vector(dim, i), vec!["red".into(), "red".into()])
                .unwrap();
        }
        index.build().unwrap();
        let results = index
            .search_filtered(&make_vector(dim, 0), 5, "red")
            .unwrap();
        assert_eq!(results.len(), 5);
        assert!(results.windows(2).all(|w| w[0].1 <= w[1].1));
        let unique: std::collections::HashSet<u32> = results.iter().map(|r| r.0).collect();
        assert_eq!(unique.len(), results.len());
        assert_eq!(results[0].0, 0);
    }
}