use crate::distance;
use crate::lid::{estimate_lid_for_hnsw, LidEstimate, LidStats};
use crate::RetrieveError;
use rand::prelude::*;
use std::collections::{BinaryHeap, HashSet};
/// Tunable parameters for building and querying a `DualBranchHNSW` index.
#[derive(Debug, Clone)]
pub struct DualBranchConfig {
    /// Number of nearest candidates a newly inserted node is linked to.
    /// A neighbor list that would grow past `2 * m` is pruned back to `m`.
    pub m: usize,
    /// Target neighbor count for the pass that adds extra links to high-LID
    /// nodes; that pass searches with a beam width of `2 * m_high_lid`.
    pub m_high_lid: usize,
    /// Beam width used while searching for link candidates during insertion.
    pub ef_construction: usize,
    /// Lower bound on the query-time beam width (`search` uses `max(ef_search, k)`).
    pub ef_search: usize,
    /// Forwarded as the second argument of `estimate_lid_for_hnsw`.
    pub lid_k: usize,
    /// A node counts as high-LID when its estimate exceeds
    /// `median + lid_threshold_sigma * std_dev` of all estimates.
    pub lid_threshold_sigma: f32,
    /// A skip-bridge attempt is abandoned when a uniform random `f32` draw is
    /// greater than this value.
    pub skip_bridge_probability: f32,
    /// Number of random-walk hops taken to pick a skip-bridge target.
    pub max_skip_length: usize,
    /// Seed for the build-time RNG; `None` falls back to `rand::rng()`.
    pub seed: Option<u64>,
}

impl Default for DualBranchConfig {
    /// Returns the stock configuration.
    fn default() -> Self {
        DualBranchConfig {
            // Graph degree.
            m: 16,
            m_high_lid: 24,
            // Beam widths.
            ef_construction: 200,
            ef_search: 50,
            // LID estimation and thresholding.
            lid_k: 20,
            lid_threshold_sigma: 1.5,
            // Skip bridges.
            skip_bridge_probability: 0.1,
            max_skip_length: 3,
            // Unseeded by default.
            seed: None,
        }
    }
}
/// A long-range shortcut edge stored separately from the regular neighbor lists.
///
/// Bridges are created for high-LID nodes during `build` and are followed by
/// `search` in addition to the ordinary neighbors of the source node.
#[derive(Debug, Clone, Copy)]
pub struct SkipBridge {
/// Node the bridge starts at (the bridge is indexed under this node).
pub from: u32,
/// Node the bridge leads to, chosen by a random walk from `from`.
pub to: u32,
/// Hop budget of the random walk that produced the bridge
/// (set to `max_skip_length` from the config when the bridge is created).
pub skip_length: u32,
}
/// Single-layer proximity graph with extra links and skip bridges for nodes
/// whose local intrinsic dimensionality (LID) estimate is unusually high.
///
/// Typical use: `new`, then `add_vectors` one or more times, then `build`,
/// then `search`.
#[derive(Debug)]
pub struct DualBranchHNSW {
/// All vectors stored back to back; vector `i` occupies
/// `vectors[i * dimension .. (i + 1) * dimension]`.
vectors: Vec<f32>,
/// Number of `f32` components per vector.
dimension: usize,
/// Number of vectors added so far.
num_vectors: usize,
/// Adjacency list per node (node ids index into this vector).
neighbors: Vec<Vec<u32>>,
/// All skip bridges created by the last build.
skip_bridges: Vec<SkipBridge>,
/// For each node, the indices into `skip_bridges` of the bridges that start there.
skip_adjacency: Vec<Vec<usize>>,
/// Per-node LID estimate; stays at the zeroed placeholder for nodes without neighbors.
lid_estimates: Vec<LidEstimate>,
/// Summary statistics over `lid_estimates`; `None` until they have been computed.
lid_stats: Option<LidStats>,
/// Build and search parameters.
config: DualBranchConfig,
/// Node from which graph searches start; `None` until the first node is inserted.
entry_point: Option<u32>,
/// `true` once `build` has completed; cleared again by `add_vectors`.
built: bool,
}
impl DualBranchHNSW {
pub fn new(dimension: usize, config: DualBranchConfig) -> Self {
Self {
vectors: Vec::new(),
dimension,
num_vectors: 0,
neighbors: Vec::new(),
skip_bridges: Vec::new(),
skip_adjacency: Vec::new(),
lid_estimates: Vec::new(),
lid_stats: None,
config,
entry_point: None,
built: false,
}
}
/// Appends vectors to the index and marks it as needing a (re)build.
///
/// `vectors` is a flat buffer holding zero or more vectors back to back, each
/// `dimension` components long.
///
/// # Errors
///
/// * `RetrieveError::InvalidParameter` if the index was created with a
///   dimension of zero (no buffer length can be split into zero-width vectors).
/// * `RetrieveError::DimensionMismatch` if `vectors.len()` is not a multiple
///   of the index dimension; `query_dim` carries the buffer length.
pub fn add_vectors(&mut self, vectors: &[f32]) -> Result<(), RetrieveError> {
    // `x % 0` panics, so reject a zero dimension before doing any arithmetic.
    if self.dimension == 0 {
        return Err(RetrieveError::InvalidParameter(
            "dimension must be greater than zero".into(),
        ));
    }
    if vectors.len() % self.dimension != 0 {
        return Err(RetrieveError::DimensionMismatch {
            query_dim: vectors.len(),
            doc_dim: self.dimension,
        });
    }
    self.vectors.extend_from_slice(vectors);
    self.num_vectors += vectors.len() / self.dimension;
    // The graph no longer covers every stored vector.
    self.built = false;
    Ok(())
}
/// Builds the graph from scratch over every vector added so far.
///
/// Steps: insert every node, estimate LID per node, add extra links to
/// high-LID nodes, then create skip bridges. All graph state from a previous
/// build (including the entry point) is discarded first, so calling `build`
/// again after `add_vectors` starts from a clean graph.
///
/// # Errors
///
/// * `RetrieveError::EmptyIndex` if no vectors have been added.
/// * `RetrieveError::InvalidParameter` if there are more vectors than `u32`
///   node ids can address.
/// * Any error returned by the individual build steps.
pub fn build(&mut self) -> Result<(), RetrieveError> {
    if self.num_vectors == 0 {
        return Err(RetrieveError::EmptyIndex);
    }
    // Node ids are stored as `u32`; refuse to truncate silently.
    if self.num_vectors > u32::MAX as usize {
        return Err(RetrieveError::InvalidParameter(
            "too many vectors for 32-bit node ids".into(),
        ));
    }
    let mut rng: Box<dyn RngCore> = match self.config.seed {
        Some(s) => Box::new(StdRng::seed_from_u64(s)),
        None => Box::new(rand::rng()),
    };
    // Drop everything left over from an earlier build. In particular a stale
    // entry point would otherwise be searched from while the adjacency lists
    // are empty, producing a different graph than a fresh build.
    self.built = false;
    self.entry_point = None;
    self.lid_stats = None;
    self.neighbors = vec![Vec::new(); self.num_vectors];
    self.lid_estimates = vec![
        LidEstimate {
            lid: 0.0,
            k: 0,
            max_dist: 0.0
        };
        self.num_vectors
    ];
    for i in 0..self.num_vectors {
        self.insert_node(i as u32, &mut rng)?;
    }
    self.compute_all_lid();
    self.enhance_high_lid_nodes(&mut rng)?;
    self.add_skip_bridges(&mut rng)?;
    self.built = true;
    Ok(())
}
/// Links `node_id` into the graph.
///
/// The first node ever inserted simply becomes the entry point. Every later
/// node is connected (in both directions) to up to `m` of the closest
/// candidates found by a beam search of width `ef_construction`. A neighbor
/// whose list would exceed `2 * m` entries is pruned back to its `m` closest.
/// Finally the entry point moves to the closest candidate if that candidate
/// is nearer to the new node than the current entry point is.
///
/// `_rng` is accepted for signature symmetry with the other build steps and
/// is not used. Always returns `Ok(())`.
fn insert_node(&mut self, node_id: u32, _rng: &mut dyn RngCore) -> Result<(), RetrieveError> {
    let entry = match self.entry_point {
        Some(existing) => existing,
        None => {
            self.entry_point = Some(node_id);
            return Ok(());
        }
    };
    let max_links = self.config.m;

    // Everything that reads the new node's vector happens before the graph is
    // mutated, so the vector can be borrowed instead of copied.
    let (candidates, entry_dist) = {
        let point = self.get_vector(node_id as usize);
        let found = self.search_layer(point, entry, self.config.ef_construction);
        let to_entry = distance::l2_distance(point, self.get_vector(entry as usize));
        (found, to_entry)
    };

    // Plan the edge insertions first, then apply them.
    let mut pending_edges: Vec<(u32, u32)> = Vec::new();
    let mut overfull: Vec<usize> = Vec::new();
    for &(peer, _) in candidates.iter().take(max_links) {
        if peer == node_id {
            continue;
        }
        pending_edges.push((node_id, peer));
        let peer_links = &self.neighbors[peer as usize];
        if peer_links.contains(&node_id) {
            continue;
        }
        pending_edges.push((peer, node_id));
        if peer_links.len() + 1 > max_links * 2 {
            overfull.push(peer as usize);
        }
    }
    for (src, dst) in pending_edges {
        let links = &mut self.neighbors[src as usize];
        if !links.contains(&dst) {
            links.push(dst);
        }
    }
    for peer in overfull {
        self.prune_neighbors(peer, max_links);
    }

    // Candidates are sorted by distance, so the first one is the closest.
    if let Some(&(closest, closest_dist)) = candidates.first() {
        if closest_dist < entry_dist {
            self.entry_point = Some(closest);
        }
    }
    Ok(())
}
/// Best-first beam search over the regular neighbor lists.
///
/// Starts at `entry`, keeps at most `ef` of the closest nodes seen so far and
/// returns them as `(node id, distance)` pairs sorted by ascending distance.
/// Skip bridges are not followed here.
fn search_layer(&self, query: &[f32], entry: u32, ef: usize) -> Vec<(u32, f32)> {
    let entry_dist = distance::l2_distance(query, self.get_vector(entry as usize));

    let mut seen: HashSet<u32> = HashSet::new();
    seen.insert(entry);
    // `frontier` yields the closest unexpanded node; `best` exposes the
    // farthest node currently kept so it can be evicted.
    let mut frontier = BinaryHeap::new();
    frontier.push(MinCandidate {
        id: entry,
        dist: entry_dist,
    });
    let mut best = BinaryHeap::new();
    best.push(MaxCandidate {
        id: entry,
        dist: entry_dist,
    });

    while let Some(MinCandidate {
        id: current,
        dist: current_dist,
    }) = frontier.pop()
    {
        // Stop once the beam is full and the closest open node is already
        // farther away than the worst kept result.
        let beam_full = best.len() >= ef;
        if beam_full && best.peek().map_or(false, |worst| current_dist > worst.dist) {
            break;
        }
        for &next in &self.neighbors[current as usize] {
            if !seen.insert(next) {
                continue;
            }
            let dist = distance::l2_distance(query, self.get_vector(next as usize));
            let admit = best.len() < ef || best.peek().map_or(true, |worst| dist < worst.dist);
            if !admit {
                continue;
            }
            frontier.push(MinCandidate { id: next, dist });
            best.push(MaxCandidate { id: next, dist });
            if best.len() > ef {
                best.pop();
            }
        }
    }

    let mut ranked: Vec<(u32, f32)> = best.into_iter().map(|c| (c.id, c.dist)).collect();
    ranked.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
    ranked
}
/// Shrinks the neighbor list of `node` to its `m` closest entries.
///
/// Closeness is the distance between the stored vector of `node` and the
/// stored vector of each neighbor; the surviving list is ordered by
/// ascending distance.
fn prune_neighbors(&mut self, node: usize, m: usize) {
    let origin = self.get_vector(node);
    let mut ranked: Vec<(u32, f32)> = Vec::with_capacity(self.neighbors[node].len());
    for &peer in &self.neighbors[node] {
        let dist = distance::l2_distance(origin, self.get_vector(peer as usize));
        ranked.push((peer, dist));
    }
    ranked.sort_unstable_by(|lhs, rhs| lhs.1.total_cmp(&rhs.1));
    ranked.truncate(m);
    self.neighbors[node] = ranked.into_iter().map(|(peer, _)| peer).collect();
}
/// Estimates LID for every node from the distances to its current neighbors
/// and then refreshes the summary statistics.
///
/// Nodes without neighbors keep whatever estimate they already had.
fn compute_all_lid(&mut self) {
    let lid_k = self.config.lid_k;
    for node in 0..self.num_vectors {
        if self.neighbors[node].is_empty() {
            continue;
        }
        let origin = self.get_vector(node);
        let dists: Vec<f32> = self.neighbors[node]
            .iter()
            .map(|&peer| distance::l2_distance(origin, self.get_vector(peer as usize)))
            .collect();
        self.lid_estimates[node] = estimate_lid_for_hnsw(&dists, lid_k);
    }
    self.lid_stats = Some(LidStats::from_estimates(&self.lid_estimates));
}
/// Gives high-LID nodes additional links, up to `m_high_lid` neighbors each.
///
/// A node is high-LID when its estimate exceeds
/// `median + lid_threshold_sigma * std_dev`. For each such node that has fewer
/// than `m_high_lid` neighbors, a beam search of width `2 * m_high_lid` is run
/// from the entry point and the closest results that are not yet neighbors
/// are linked in both directions until the node has `m_high_lid` neighbors.
/// Nodes that already have that many neighbors are left untouched.
///
/// `_rng` is unused.
///
/// # Errors
///
/// `RetrieveError::InvalidParameter` if LID statistics have not been computed.
fn enhance_high_lid_nodes(&mut self, _rng: &mut dyn RngCore) -> Result<(), RetrieveError> {
    let stats = self.lid_stats.as_ref().ok_or_else(|| {
        RetrieveError::InvalidParameter("LID stats not computed".into())
    })?;
    let threshold = stats.median + self.config.lid_threshold_sigma * stats.std_dev;
    let target_degree = self.config.m_high_lid;
    let entry = self.entry_point.unwrap_or(0);

    for i in 0..self.num_vectors {
        // Comparing lengths (instead of subtracting them) cannot underflow
        // when a node already has more than `m_high_lid` neighbors.
        if self.lid_estimates[i].lid > threshold && self.neighbors[i].len() < target_degree {
            let node = i as u32;
            let candidates = self.search_layer(self.get_vector(i), entry, target_degree * 2);
            for (neighbor, _) in candidates {
                if self.neighbors[i].len() >= target_degree {
                    break;
                }
                if neighbor == node || self.neighbors[i].contains(&neighbor) {
                    continue;
                }
                self.neighbors[i].push(neighbor);
                // The reverse edge may already exist (links are not always
                // reciprocated after pruning); do not duplicate it.
                let back_links = &mut self.neighbors[neighbor as usize];
                if !back_links.contains(&node) {
                    back_links.push(node);
                }
            }
        }
    }
    Ok(())
}
/// Rebuilds the skip-bridge set.
///
/// Existing bridges are discarded. Each high-LID node (estimate above
/// `median + lid_threshold_sigma * std_dev`) gets three attempts; an attempt
/// is abandoned when a random `f32` draw is greater than
/// `skip_bridge_probability`. Otherwise a random walk of `max_skip_length`
/// hops picks a target, and a bridge is recorded unless the target is the
/// node itself or already a direct neighbor.
///
/// # Errors
///
/// `RetrieveError::InvalidParameter` if LID statistics have not been computed
/// (the old bridges have already been cleared at that point).
fn add_skip_bridges(&mut self, rng: &mut dyn RngCore) -> Result<(), RetrieveError> {
    const ATTEMPTS_PER_NODE: usize = 3;

    self.skip_bridges.clear();
    self.skip_adjacency = vec![Vec::new(); self.num_vectors];

    let threshold = match self.lid_stats.as_ref() {
        Some(stats) => stats.median + self.config.lid_threshold_sigma * stats.std_dev,
        None => {
            return Err(RetrieveError::InvalidParameter(
                "LID stats not computed".into(),
            ))
        }
    };
    let hops = self.config.max_skip_length;

    for node in 0..self.num_vectors {
        if self.lid_estimates[node].lid > threshold {
            for _ in 0..ATTEMPTS_PER_NODE {
                if rng.random::<f32>() > self.config.skip_bridge_probability {
                    continue;
                }
                let target = match self.random_walk(node as u32, hops, rng) {
                    Some(reached) => reached,
                    None => continue,
                };
                if target as usize == node || self.neighbors[node].contains(&target) {
                    continue;
                }
                // The new bridge's position in `skip_bridges` is what the
                // per-node adjacency list stores.
                let bridge_idx = self.skip_bridges.len();
                self.skip_bridges.push(SkipBridge {
                    from: node as u32,
                    to: target,
                    skip_length: hops as u32,
                });
                self.skip_adjacency[node].push(bridge_idx);
            }
        }
    }
    Ok(())
}
/// Takes `hops` uniformly random steps along neighbor lists starting at `start`.
///
/// Returns the node reached, or `None` as soon as the walk lands on a node
/// without neighbors. With `hops == 0` the result is `Some(start)`.
fn random_walk(&self, start: u32, hops: usize, rng: &mut dyn RngCore) -> Option<u32> {
    (0..hops).try_fold(start, |node, _| {
        let adjacent = &self.neighbors[node as usize];
        if adjacent.is_empty() {
            None
        } else {
            Some(adjacent[rng.random_range(0..adjacent.len())])
        }
    })
}
/// Returns up to `k` stored vectors closest to `query` as
/// `(node id, distance)` pairs sorted by ascending distance.
///
/// The search is a best-first beam search of width `max(ef_search, k)` that
/// starts at the entry point and, for every expanded node, visits its regular
/// neighbors first and then the targets of its skip bridges.
///
/// # Errors
///
/// * `RetrieveError::InvalidParameter` if the index has not been built.
/// * `RetrieveError::DimensionMismatch` if `query.len()` differs from the
///   index dimension.
/// * `RetrieveError::EmptyIndex` if there is no entry point.
pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(u32, f32)>, RetrieveError> {
    if !self.built {
        return Err(RetrieveError::InvalidParameter("index not built".into()));
    }
    if query.len() != self.dimension {
        return Err(RetrieveError::DimensionMismatch {
            query_dim: query.len(),
            doc_dim: self.dimension,
        });
    }
    let entry = self.entry_point.ok_or(RetrieveError::EmptyIndex)?;
    let beam_width = self.config.ef_search.max(k);
    let entry_dist = distance::l2_distance(query, self.get_vector(entry as usize));

    let mut seen: HashSet<u32> = HashSet::new();
    seen.insert(entry);
    // `frontier` yields the closest unexpanded node; `best` exposes the
    // farthest node currently kept so it can be evicted.
    let mut frontier = BinaryHeap::new();
    frontier.push(MinCandidate {
        id: entry,
        dist: entry_dist,
    });
    let mut best = BinaryHeap::new();
    best.push(MaxCandidate {
        id: entry,
        dist: entry_dist,
    });

    while let Some(MinCandidate {
        id: current,
        dist: current_dist,
    }) = frontier.pop()
    {
        let beam_full = best.len() >= beam_width;
        if beam_full && best.peek().map_or(false, |worst| current_dist > worst.dist) {
            break;
        }
        // Regular neighbors first, then skip-bridge targets, handled by the
        // same admission logic.
        let direct = self.neighbors[current as usize].iter().copied();
        let bridged = self.skip_adjacency[current as usize]
            .iter()
            .map(|&bridge_idx| self.skip_bridges[bridge_idx].to);
        for next in direct.chain(bridged) {
            if !seen.insert(next) {
                continue;
            }
            let dist = distance::l2_distance(query, self.get_vector(next as usize));
            let admit =
                best.len() < beam_width || best.peek().map_or(true, |worst| dist < worst.dist);
            if !admit {
                continue;
            }
            frontier.push(MinCandidate { id: next, dist });
            best.push(MaxCandidate { id: next, dist });
            if best.len() > beam_width {
                best.pop();
            }
        }
    }

    let mut hits: Vec<(u32, f32)> = best.into_iter().map(|c| (c.id, c.dist)).collect();
    hits.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
    hits.truncate(k);
    Ok(hits)
}
/// Summarises the current state of the index.
///
/// `num_edges` is the total length of all neighbor lists halved (integer
/// division). `high_lid_nodes` counts estimates above
/// `median + lid_threshold_sigma * std_dev` and is `0` while no LID
/// statistics are available.
pub fn stats(&self) -> DualBranchStats {
    let high_lid_nodes = self.lid_stats.as_ref().map_or(0, |summary| {
        let threshold = summary.median + self.config.lid_threshold_sigma * summary.std_dev;
        self.lid_estimates
            .iter()
            .filter(|estimate| estimate.lid > threshold)
            .count()
    });
    let directed_links: usize = self.neighbors.iter().map(|links| links.len()).sum();
    DualBranchStats {
        num_vectors: self.num_vectors,
        num_edges: directed_links / 2,
        num_skip_bridges: self.skip_bridges.len(),
        high_lid_nodes,
        lid_stats: self.lid_stats.clone(),
    }
}
/// Borrows the stored vector with index `idx` out of the flat buffer.
///
/// Panics if the requested range lies outside the buffer.
#[inline]
fn get_vector(&self, idx: usize) -> &[f32] {
    let width = self.dimension;
    let offset = idx * width;
    &self.vectors[offset..offset + width]
}
}
/// Snapshot of index size and structure as reported by `DualBranchHNSW::stats`.
#[derive(Debug, Clone)]
pub struct DualBranchStats {
/// Number of vectors stored in the index.
pub num_vectors: usize,
/// Total length of all neighbor lists divided by two (integer division);
/// this equals the undirected edge count only if every link is reciprocated.
pub num_edges: usize,
/// Number of skip bridges currently stored.
pub num_skip_bridges: usize,
/// Number of nodes whose LID estimate exceeds the high-LID threshold
/// (`0` when no LID statistics are available).
pub high_lid_nodes: usize,
/// Copy of the LID summary statistics, if they have been computed.
pub lid_stats: Option<LidStats>,
}
/// Heap entry whose ordering is reversed on `dist`, so that a max-oriented
/// `BinaryHeap` pops the entry with the *smallest* distance first.
///
/// Equality and ordering look only at `dist` (using the IEEE total order) and
/// ignore `id`.
#[derive(Debug, Clone, Copy)]
struct MinCandidate {
    id: u32,
    dist: f32,
}

impl PartialEq for MinCandidate {
    /// Defined through `cmp` so that `a == b` holds exactly when
    /// `a.cmp(&b)` is `Equal`, as `Eq`/`Ord` require. Plain `==` on `f32`
    /// would make NaN unequal to itself and `0.0` equal to `-0.0`, both of
    /// which disagree with `total_cmp`.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for MinCandidate {}

impl PartialOrd for MinCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for MinCandidate {
    /// Reversed comparison: the smaller distance is the greater element.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        other.dist.total_cmp(&self.dist)
    }
}
/// Heap entry ordered by `dist`, so that a `BinaryHeap` pops (and peeks) the
/// entry with the *largest* distance first.
///
/// Equality and ordering look only at `dist` (using the IEEE total order) and
/// ignore `id`.
#[derive(Debug, Clone, Copy)]
struct MaxCandidate {
    id: u32,
    dist: f32,
}

impl PartialEq for MaxCandidate {
    /// Defined through `cmp` so that `a == b` holds exactly when
    /// `a.cmp(&b)` is `Equal`, as `Eq`/`Ord` require. Plain `==` on `f32`
    /// would make NaN unequal to itself and `0.0` equal to `-0.0`, both of
    /// which disagree with `total_cmp`.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for MaxCandidate {}

impl PartialOrd for MaxCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for MaxCandidate {
    /// Natural comparison: the larger distance is the greater element.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.dist.total_cmp(&other.dist)
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
use super::*;
/// Generates a deterministic (seed 42) flat buffer of `dim`-dimensional points:
/// `n_clusters` tight clusters of `points_per_cluster` points each, followed
/// by five widely scattered outlier points.
fn create_clustered_data(n_clusters: usize, points_per_cluster: usize, dim: usize) -> Vec<f32> {
    const OUTLIERS: usize = 5;

    let mut rng = StdRng::seed_from_u64(42);
    let mut data = Vec::with_capacity((n_clusters * points_per_cluster + OUTLIERS) * dim);

    for cluster in 0..n_clusters {
        // Cluster centers are spaced 10 apart along every axis, plus jitter.
        let offset = cluster as f32 * 10.0;
        let mut center = Vec::with_capacity(dim);
        for _ in 0..dim {
            center.push(offset + rng.random::<f32>());
        }
        for _ in 0..points_per_cluster {
            data.extend(center.iter().map(|coord| coord + rng.random::<f32>() * 0.5));
        }
    }

    // Outliers: every coordinate drawn independently from [0, 100).
    data.extend((0..OUTLIERS * dim).map(|_| rng.random::<f32>() * 100.0));
    data
}
/// Building over clustered data yields a non-empty graph.
#[test]
fn test_dual_branch_build() {
    const DIM: usize = 8;
    let points = create_clustered_data(3, 50, DIM);

    let mut index = DualBranchHNSW::new(
        DIM,
        DualBranchConfig {
            m: 8,
            m_high_lid: 12,
            ef_construction: 50,
            ef_search: 20,
            seed: Some(42),
            ..Default::default()
        },
    );
    index.add_vectors(&points).unwrap();
    index.build().unwrap();

    let summary = index.stats();
    println!("Stats: {:?}", summary);
    assert!(summary.num_edges > 0);
    assert!(summary.num_vectors > 0);
}
/// Querying with a stored vector returns at least one hit within distance 1.
#[test]
fn test_dual_branch_search() {
    const DIM: usize = 8;
    let points = create_clustered_data(3, 50, DIM);

    let mut index = DualBranchHNSW::new(
        DIM,
        DualBranchConfig {
            m: 8,
            m_high_lid: 12,
            ef_construction: 50,
            ef_search: 30,
            seed: Some(42),
            ..Default::default()
        },
    );
    index.add_vectors(&points).unwrap();
    index.build().unwrap();

    // The first stored vector doubles as the query.
    let hits = index.search(&points[..DIM], 5).unwrap();
    assert!(!hits.is_empty());
    let nearest_dist = hits[0].1;
    assert!(
        nearest_dist < 1.0,
        "First result distance: {}",
        nearest_dist
    );
}
/// Smoke test: building with a high skip-bridge probability succeeds and
/// reports bridge / high-LID counts.
#[test]
fn test_dual_branch_skip_bridges() {
    const DIM: usize = 8;
    let points = create_clustered_data(5, 30, DIM);

    let mut index = DualBranchHNSW::new(
        DIM,
        DualBranchConfig {
            m: 6,
            m_high_lid: 10,
            ef_construction: 50,
            ef_search: 20,
            skip_bridge_probability: 0.5,
            seed: Some(42),
            ..Default::default()
        },
    );
    index.add_vectors(&points).unwrap();
    index.build().unwrap();

    let summary = index.stats();
    println!("Skip bridges: {}", summary.num_skip_bridges);
    println!("High-LID nodes: {}", summary.high_lid_nodes);
    assert!(summary.num_vectors > 0);
}
/// Points on a line plus two far-away outliers: at least one node must be
/// classified as high-LID.
#[test]
fn test_dual_branch_lid_detection() {
    const DIM: usize = 4;

    // Fifty points spaced 0.1 apart along the first axis.
    let mut points: Vec<f32> = (0..50)
        .flat_map(|step| [0.1 * step as f32, 0.1, 0.1, 0.1])
        .collect();
    // Two outliers far from the line, in opposite directions.
    points.extend_from_slice(&[100.0; DIM]);
    points.extend_from_slice(&[-100.0; DIM]);

    let mut index = DualBranchHNSW::new(
        DIM,
        DualBranchConfig {
            m: 6,
            m_high_lid: 12,
            ef_construction: 30,
            seed: Some(42),
            ..Default::default()
        },
    );
    index.add_vectors(&points).unwrap();
    index.build().unwrap();

    let summary = index.stats();
    println!("LID stats: {:?}", summary.lid_stats);
    assert!(summary.high_lid_nodes > 0, "Should detect high-LID outliers");
}
/// Compares approximate top-k results against brute force over 20 random
/// stored queries and requires recall above 10%.
#[test]
fn test_dual_branch_recall() {
    let dim = 16;
    let k = 10;
    let num_queries = 20;

    let points = create_clustered_data(5, 100, dim);
    let total = points.len() / dim;

    let mut index = DualBranchHNSW::new(
        dim,
        DualBranchConfig {
            m: 12,
            m_high_lid: 18,
            ef_construction: 100,
            ef_search: 50,
            seed: Some(42),
            ..Default::default()
        },
    );
    index.add_vectors(&points).unwrap();
    index.build().unwrap();

    let mut rng = StdRng::seed_from_u64(123);
    let mut matched = 0;
    for _ in 0..num_queries {
        let pick = rng.random_range(0..total);
        let query = &points[pick * dim..(pick + 1) * dim];

        let found: HashSet<u32> = index
            .search(query, k)
            .unwrap()
            .iter()
            .map(|&(id, _)| id)
            .collect();

        // Brute-force ground truth: distance from the query to every point.
        let mut exact: Vec<(usize, f32)> = points
            .chunks_exact(dim)
            .enumerate()
            .map(|(id, candidate)| (id, distance::l2_distance(query, candidate)))
            .collect();
        exact.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
        let truth: HashSet<u32> = exact.iter().take(k).map(|&(id, _)| id as u32).collect();

        matched += truth.intersection(&found).count();
    }

    let recall = matched as f32 / (num_queries * k) as f32;
    println!("Recall@{}: {:.2}%", k, recall * 100.0);
    assert!(recall > 0.1, "Recall too low: {}", recall);
}
}