use crate::simd;
use crate::RetrieveError;
/// Approximate nearest-neighbour index backed by a forest of random-projection trees.
///
/// Vectors are appended with `add`, the forest is constructed once with `build`,
/// and queries are answered by `search`.
pub struct RpForestIndex {
    /// Flat row-major storage: vector `i` occupies
    /// `vectors[i * dimension..(i + 1) * dimension]`.
    pub(crate) vectors: Vec<f32>,

    /// Number of components in every stored vector (always > 0).
    pub(crate) dimension: usize,

    /// How many vectors have been appended so far.
    pub(crate) num_vectors: usize,

    /// Construction parameters supplied at creation time.
    params: RpForestParams,

    /// Becomes `true` once the forest has been built; no vectors may be added afterwards.
    built: bool,

    /// The trees of the forest, populated by `build`.
    pub(crate) trees: Vec<RPTree>,
}
/// Configuration for an `RpForestIndex`.
#[derive(Clone, Debug)]
pub struct RpForestParams {
    /// Number of independent trees to build; at query time the candidate
    /// leaves of every tree are merged before exact re-scoring.
    pub num_trees: usize,

    /// Parameters shared by every individual tree.
    pub tree_params: RPTreeParams,
}
impl Default for RpForestParams {
    /// Ten trees, each built with `RPTreeParams::default()`.
    fn default() -> Self {
        let tree_params = RPTreeParams::default();
        Self {
            num_trees: 10,
            tree_params,
        }
    }
}
/// One random-projection tree of the forest.
pub(crate) struct RPTree {
    /// Root of the tree; may be `None` for a tree built over no points,
    /// in which case search skips the tree entirely.
    root: Option<TreeNode>,
}
/// A node of a random-projection tree.
enum TreeNode {
    /// Terminal bucket listing the indices of the stored vectors routed here.
    Leaf { indices: Vec<u32> },

    /// Binary split on the projection of a vector onto `hyperplane`.
    Internal {
        /// Normal vector of the splitting hyperplane.
        hyperplane: Vec<f32>,

        /// Split offset: a vector whose projection onto `hyperplane` is below this
        /// value belongs to `left`, otherwise to `right`.
        // The field may be written without ever being read back, depending on
        // how traversal is implemented, so the lint is silenced here.
        #[allow(dead_code)]
        threshold: f32,

        /// Subtree for projections below `threshold`.
        left: Box<TreeNode>,

        /// Subtree for projections at or above `threshold`.
        right: Box<TreeNode>,
    },
}
/// Construction parameters for a single random-projection tree.
#[derive(Clone, Debug)]
pub struct RPTreeParams {
    /// A node holding at most this many points becomes a leaf instead of
    /// being split further.
    pub max_leaf_size: usize,
}
impl Default for RPTreeParams {
fn default() -> Self {
Self { max_leaf_size: 10 }
}
}
impl RpForestIndex {
    /// Creates an empty, unbuilt index for vectors with `dimension` components.
    ///
    /// # Errors
    ///
    /// Returns `RetrieveError::InvalidParameter` if `dimension` is zero, or if
    /// `params.num_trees` is zero (a forest without trees can never produce a
    /// candidate, so every search would silently return nothing).
    pub fn new(dimension: usize, params: RpForestParams) -> Result<Self, RetrieveError> {
        if dimension == 0 {
            return Err(RetrieveError::InvalidParameter(
                "dimension must be > 0".into(),
            ));
        }
        if params.num_trees == 0 {
            return Err(RetrieveError::InvalidParameter(
                "num_trees must be > 0".into(),
            ));
        }
        Ok(Self {
            vectors: Vec::new(),
            dimension,
            num_vectors: 0,
            params,
            built: false,
            trees: Vec::new(),
        })
    }

    /// Appends `vector` to the index.
    ///
    /// `_doc_id` is currently ignored: search results identify vectors by their
    /// zero-based insertion position, not by the id passed here.
    ///
    /// # Errors
    ///
    /// - `RetrieveError::InvalidParameter` if the index has already been built.
    /// - `RetrieveError::DimensionMismatch` if `vector.len() != dimension`.
    pub fn add(&mut self, _doc_id: u32, vector: Vec<f32>) -> Result<(), RetrieveError> {
        if self.built {
            return Err(RetrieveError::InvalidParameter(
                "Cannot add vectors after index is built".to_string(),
            ));
        }
        if vector.len() != self.dimension {
            return Err(RetrieveError::DimensionMismatch {
                query_dim: vector.len(),
                doc_dim: self.dimension,
            });
        }
        self.vectors.extend_from_slice(&vector);
        self.num_vectors += 1;
        Ok(())
    }

    /// Builds `params.num_trees` random-projection trees over all added vectors.
    ///
    /// Calling this again after a successful build is a no-op.
    ///
    /// # Errors
    ///
    /// Returns `RetrieveError::EmptyIndex` if no vectors were added.
    pub fn build(&mut self) -> Result<(), RetrieveError> {
        if self.built {
            return Ok(());
        }
        if self.num_vectors == 0 {
            return Err(RetrieveError::EmptyIndex);
        }
        // One RNG for the whole forest: every tree (and every node) draws fresh,
        // independent directions from the same stream.
        let mut rng = rand::rng();
        let trees: Vec<RPTree> = (0..self.params.num_trees)
            .map(|_| self.build_tree(&mut rng))
            .collect();
        self.trees = trees;
        self.built = true;
        Ok(())
    }

    /// Builds one tree over every stored vector.
    fn build_tree<R: rand::Rng>(&self, rng: &mut R) -> RPTree {
        // NOTE(review): point ids are u32, so more than u32::MAX vectors would truncate here.
        let indices: Vec<u32> = (0..self.num_vectors as u32).collect();
        RPTree {
            root: Some(self.build_node(&indices, rng)),
        }
    }

    /// Recursively partitions `indices` into a subtree.
    ///
    /// Each internal node projects its points onto a random unit direction and
    /// splits at the median projection: points strictly below it go left, the
    /// rest go right (the same rule `search_tree` uses to route queries). A node
    /// becomes a leaf when it holds at most `max_leaf_size` points, or when the
    /// split cannot separate them (e.g. all projections are equal). Because every
    /// recursive call therefore receives strictly fewer points, recursion terminates.
    fn build_node<R: rand::Rng>(&self, indices: &[u32], rng: &mut R) -> TreeNode {
        if indices.len() <= self.params.tree_params.max_leaf_size {
            return TreeNode::Leaf {
                indices: indices.to_vec(),
            };
        }

        let hyperplane = Self::random_direction(self.dimension, rng);
        let projections: Vec<f32> = indices
            .iter()
            .map(|&idx| simd::dot(self.get_vector(idx as usize), &hyperplane))
            .collect();

        // Splitting at the median (rather than through the origin) keeps the tree
        // balanced even when the data is not centred around zero.
        let threshold = {
            let mut scratch = projections.clone();
            let mid = scratch.len() / 2;
            *scratch.select_nth_unstable_by(mid, f32::total_cmp).1
        };

        let mut left_indices = Vec::new();
        let mut right_indices = Vec::new();
        for (&idx, &projection) in indices.iter().zip(&projections) {
            if projection < threshold {
                left_indices.push(idx);
            } else {
                right_indices.push(idx);
            }
        }

        // Degenerate split (e.g. duplicate vectors): stop here instead of recursing forever.
        if left_indices.is_empty() || right_indices.is_empty() {
            return TreeNode::Leaf {
                indices: indices.to_vec(),
            };
        }

        let left = self.build_node(&left_indices, rng);
        let right = self.build_node(&right_indices, rng);
        TreeNode::Internal {
            hyperplane,
            threshold,
            left: Box::new(left),
            right: Box::new(right),
        }
    }

    /// Draws a direction uniformly from the unit sphere in `dimension` dimensions.
    ///
    /// Components are independent standard normals (Box–Muller transform), which
    /// makes the direction rotation-invariant; normalising uniform samples from a
    /// cube would bias directions toward the cube's diagonals.
    fn random_direction<R: rand::Rng>(dimension: usize, rng: &mut R) -> Vec<f32> {
        let mut direction = Vec::with_capacity(dimension);
        for _ in 0..dimension {
            // `1 - u` maps [0, 1) onto (0, 1], so the logarithm is always finite.
            let u1 = 1.0 - rng.random::<f32>();
            let u2 = rng.random::<f32>();
            direction.push((-2.0 * u1.ln()).sqrt() * (std::f32::consts::TAU * u2).cos());
        }
        let norm = direction.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for x in &mut direction {
                *x /= norm;
            }
        }
        direction
    }

    /// Returns up to `k` approximate nearest neighbours of `query`, nearest first.
    ///
    /// Every tree routes `query` to one leaf. The union of those leaves is then
    /// scored exhaustively. Each hit is `(index, distance)`, where:
    ///
    /// - `index` is the zero-based insertion position of the vector (the `doc_id`
    ///   passed to `add` is not stored).
    /// - `distance = 1 - dot(query, vector)`. This equals cosine distance only when
    ///   both vectors are unit-normalised, which this index does not enforce.
    ///
    /// Ties are broken by ascending index, so output is deterministic for a given forest.
    ///
    /// # Errors
    ///
    /// - `RetrieveError::InvalidParameter` if `build` has not run.
    /// - `RetrieveError::DimensionMismatch` if `query.len() != dimension`.
    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(u32, f32)>, RetrieveError> {
        if !self.built {
            return Err(RetrieveError::InvalidParameter(
                "Index must be built before search".to_string(),
            ));
        }
        if query.len() != self.dimension {
            return Err(RetrieveError::DimensionMismatch {
                query_dim: query.len(),
                doc_dim: self.dimension,
            });
        }

        let mut candidates = std::collections::HashSet::new();
        for root in self.trees.iter().filter_map(|tree| tree.root.as_ref()) {
            candidates.extend(self.search_tree(root, query).iter().copied());
        }

        let mut results: Vec<(u32, f32)> = candidates
            .into_iter()
            .map(|idx| (idx, 1.0 - simd::dot(query, self.get_vector(idx as usize))))
            .collect();
        results.sort_unstable_by(|a, b| a.1.total_cmp(&b.1).then_with(|| a.0.cmp(&b.0)));
        results.truncate(k);
        Ok(results)
    }

    /// Descends from `node` to the leaf that `query` falls into and borrows its indices.
    fn search_tree<'a>(&self, node: &'a TreeNode, query: &[f32]) -> &'a [u32] {
        let mut node = node;
        loop {
            match node {
                TreeNode::Leaf { indices } => return indices,
                TreeNode::Internal {
                    hyperplane,
                    threshold,
                    left,
                    right,
                } => {
                    // Same comparison as `build_node`, so queries follow the data split.
                    node = if simd::dot(query, hyperplane) < *threshold {
                        &**left
                    } else {
                        &**right
                    };
                }
            }
        }
    }

    /// Borrows the stored vector at insertion position `idx`.
    ///
    /// # Panics
    ///
    /// Panics if `idx >= num_vectors`.
    fn get_vector(&self, idx: usize) -> &[f32] {
        let start = idx * self.dimension;
        let end = start + self.dimension;
        &self.vectors[start..end]
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an index of `n` one-hot vectors: vector `i` is the basis vector `e_{i % dim}`.
    fn build_index(n: usize, dim: usize) -> RpForestIndex {
        let params = RpForestParams {
            num_trees: 5,
            tree_params: RPTreeParams { max_leaf_size: 10 },
        };
        let mut index = RpForestIndex::new(dim, params).unwrap();
        for i in 0..n {
            let mut v = vec![0.0f32; dim];
            v[i % dim] = 1.0;
            index.add(i as u32, v).unwrap();
        }
        index.build().unwrap();
        index
    }

    #[test]
    fn test_basic_search_returns_results() {
        let index = build_index(50, 4);
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let results = index.search(&query, 5).unwrap();
        assert!(!results.is_empty());
        assert!(results.len() <= 5);
    }

    #[test]
    fn test_search_returns_at_most_k() {
        let index = build_index(50, 4);
        let query = vec![1.0, 0.0, 0.0, 0.0];
        for k in [1, 3, 5, 10] {
            let results = index.search(&query, k).unwrap();
            assert!(results.len() <= k);
        }
    }

    #[test]
    fn test_search_with_k_zero_returns_empty() {
        let index = build_index(10, 4);
        assert!(index.search(&[1.0, 0.0, 0.0, 0.0], 0).unwrap().is_empty());
    }

    #[test]
    fn test_results_sorted_by_distance() {
        let index = build_index(50, 4);
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let results = index.search(&query, 10).unwrap();
        for w in results.windows(2) {
            assert!(w[0].1 <= w[1].1, "results not sorted: {:?}", results);
        }
    }

    #[test]
    fn test_exact_match_is_ranked_first() {
        // Vectors identical to the query follow the same path in every tree,
        // so they are always candidates and have the smallest distance (0).
        let index = build_index(50, 4);
        let results = index.search(&[1.0, 0.0, 0.0, 0.0], 3).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].0 % 4, 0, "top hit should be an e0 vector: {:?}", results);
        assert!(results[0].1.abs() < 1e-6, "top distance should be ~0: {:?}", results);
    }

    #[test]
    fn test_ids_in_bounds() {
        let n = 50usize;
        let index = build_index(n, 4);
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let results = index.search(&query, 10).unwrap();
        for (id, _) in results {
            assert!((id as usize) < n, "id {} out of bounds (n={})", id, n);
        }
    }

    #[test]
    fn test_multiple_trees_improve_coverage() {
        let dim = 8;
        let n = 100;
        let params = RpForestParams {
            num_trees: 20,
            tree_params: RPTreeParams { max_leaf_size: 5 },
        };
        let mut index = RpForestIndex::new(dim, params).unwrap();
        for i in 0..n {
            let mut v = vec![0.0f32; dim];
            if i < n / 2 {
                v[0] = 1.0;
                v[1] = 0.01 * (i as f32);
            } else {
                v[1] = 1.0;
                v[0] = 0.01 * (i as f32);
            }
            let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
            for x in &mut v {
                *x /= norm;
            }
            index.add(i as u32, v).unwrap();
        }
        index.build().unwrap();
        let query = vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
        let results = index.search(&query, 10).unwrap();
        let correct = results
            .iter()
            .filter(|(id, _)| *id < (n / 2) as u32)
            .count();
        assert!(
            correct >= 5,
            "only {correct}/10 results from correct cluster — hyperplane independence may be broken"
        );
    }

    #[test]
    fn test_new_rejects_zero_dimension() {
        assert!(RpForestIndex::new(0, RpForestParams::default()).is_err());
    }

    #[test]
    fn test_build_errors_on_empty_index() {
        let mut index = RpForestIndex::new(4, RpForestParams::default()).unwrap();
        assert!(index.build().is_err());
    }

    #[test]
    fn test_build_twice_is_noop() {
        let mut index = build_index(10, 4);
        assert!(index.build().is_ok());
        assert!(!index.search(&[0.0, 1.0, 0.0, 0.0], 2).unwrap().is_empty());
    }

    #[test]
    fn test_add_after_build_errors() {
        let mut index = RpForestIndex::new(4, RpForestParams::default()).unwrap();
        index.add(0, vec![1.0, 0.0, 0.0, 0.0]).unwrap();
        index.build().unwrap();
        assert!(index.add(1, vec![0.0, 1.0, 0.0, 0.0]).is_err());
    }

    #[test]
    fn test_dimension_mismatch_errors() {
        let mut index = RpForestIndex::new(4, RpForestParams::default()).unwrap();
        assert!(index.add(0, vec![1.0, 0.0]).is_err());
    }

    #[test]
    fn test_search_before_build_errors() {
        let mut index = RpForestIndex::new(4, RpForestParams::default()).unwrap();
        index.add(0, vec![1.0, 0.0, 0.0, 0.0]).unwrap();
        assert!(index.search(&[1.0, 0.0, 0.0, 0.0], 1).is_err());
    }

    #[test]
    fn test_search_query_dimension_mismatch_errors() {
        let index = build_index(10, 4);
        assert!(index.search(&[1.0, 0.0], 3).is_err());
    }

    #[test]
    fn test_degenerate_split_does_not_recurse_infinitely() {
        let params = RpForestParams {
            num_trees: 3,
            tree_params: RPTreeParams { max_leaf_size: 2 },
        };
        let mut index = RpForestIndex::new(4, params).unwrap();
        for i in 0..20u32 {
            index.add(i, vec![1.0, 0.0, 0.0, 0.0]).unwrap();
        }
        index.build().unwrap();
        let results = index.search(&[1.0, 0.0, 0.0, 0.0], 5).unwrap();
        assert!(!results.is_empty());
    }
}