use std::sync::atomic::{AtomicUsize, Ordering};
use ahash::{AHashMap, AHashSet};
use ordered_float::OrderedFloat;
use parking_lot::RwLock;
use rand::{Rng, SeedableRng};
use rayon::prelude::*;
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use crate::distance;
use crate::error::{Error, Result};
use super::traits::{DistanceType, SearchResult, VectorIndex};
use super::hnsw::HNSWConfig;
#[derive(Debug)]
pub struct ParallelHNSWIndex {
config: HNSWConfig,
nodes: RwLock<Vec<ParallelHNSWNode>>,
id_to_idx: RwLock<AHashMap<String, usize>>,
entry_point: AtomicUsize,
max_level: AtomicUsize,
has_entry_point: std::sync::atomic::AtomicBool,
}
#[derive(Debug)]
struct ParallelHNSWNode {
id: String,
vector: Vec<f32>,
level: usize,
neighbors: Vec<RwLock<AHashSet<usize>>>,
}
pub struct ParallelHNSWBuilder {
config: HNSWConfig,
num_threads: usize,
batch_size: usize,
seed: Option<u64>,
}
impl Default for ParallelHNSWBuilder {
fn default() -> Self {
Self::new()
}
}
impl ParallelHNSWBuilder {
#[must_use]
pub fn new() -> Self {
Self {
config: HNSWConfig::new(0), num_threads: 0, batch_size: 256,
seed: None,
}
}
#[must_use]
pub fn with_config(mut self, config: HNSWConfig) -> Self {
self.config = config;
self
}
#[must_use]
pub const fn with_threads(mut self, threads: usize) -> Self {
self.num_threads = threads;
self
}
#[must_use]
pub const fn with_batch_size(mut self, size: usize) -> Self {
self.batch_size = size;
self
}
#[must_use]
pub const fn with_seed(mut self, seed: u64) -> Self {
self.seed = Some(seed);
self
}
pub fn build<I, S>(self, vectors: I) -> Result<ParallelHNSWIndex>
where
I: IntoIterator<Item = (S, Vec<f32>)>,
S: Into<String>,
{
let vectors: Vec<(String, Vec<f32>)> = vectors
.into_iter()
.map(|(id, v)| (id.into(), v))
.collect();
if vectors.is_empty() {
return Ok(ParallelHNSWIndex::new(self.config));
}
let dim = vectors[0].1.len();
for (id, v) in &vectors {
if v.len() != dim {
return Err(Error::InvalidQuery {
reason: format!(
"Dimension mismatch for '{}': expected {}, got {}",
id, dim, v.len()
),
});
}
}
let mut config = self.config.clone();
config.base.dimension = dim;
let index = ParallelHNSWIndex::new(config.clone());
let mut rng = match self.seed {
Some(s) => rand::rngs::SmallRng::seed_from_u64(s),
None => rand::rngs::SmallRng::from_entropy(),
};
let ml = config.ml;
let mut indexed_vectors: Vec<(usize, String, Vec<f32>, usize)> = vectors
.into_iter()
.enumerate()
.map(|(orig_idx, (id, vec))| {
let level = generate_level(&mut rng, ml);
(orig_idx, id, vec, level)
})
.collect();
indexed_vectors.sort_by(|a, b| b.3.cmp(&a.3));
if let Some((_, id, vector, level)) = indexed_vectors.first() {
index.insert_first(id.clone(), vector.clone(), *level);
}
let remaining: Vec<_> = indexed_vectors.into_iter().skip(1).collect();
if self.num_threads > 0 {
rayon::ThreadPoolBuilder::new()
.num_threads(self.num_threads)
.build()
.map_err(|e| Error::IndexBuild {
reason: format!("Failed to create thread pool: {}", e),
})?
.install(|| {
self.parallel_insert_batch(&index, remaining);
});
} else {
self.parallel_insert_batch(&index, remaining);
}
Ok(index)
}
fn parallel_insert_batch(&self, index: &ParallelHNSWIndex, vectors: Vec<(usize, String, Vec<f32>, usize)>) {
for batch in vectors.chunks(self.batch_size) {
batch.par_iter().for_each(|(_, id, vector, level)| {
index.insert_parallel(id.clone(), vector, *level);
});
}
}
}
fn generate_level(rng: &mut impl Rng, ml: f64) -> usize {
let mut level = 0;
while rng.gen::<f64>() < ml && level < 16 {
level += 1;
}
level
}
impl ParallelHNSWIndex {
#[must_use]
pub fn new(config: HNSWConfig) -> Self {
Self {
config,
nodes: RwLock::new(Vec::new()),
id_to_idx: RwLock::new(AHashMap::new()),
entry_point: AtomicUsize::new(0),
max_level: AtomicUsize::new(0),
has_entry_point: std::sync::atomic::AtomicBool::new(false),
}
}
fn insert_first(&self, id: String, vector: Vec<f32>, level: usize) {
let node = ParallelHNSWNode {
id: id.clone(),
vector,
level,
neighbors: (0..=level).map(|_| RwLock::new(AHashSet::new())).collect(),
};
let mut nodes = self.nodes.write();
let idx = nodes.len();
nodes.push(node);
self.id_to_idx.write().insert(id, idx);
self.entry_point.store(idx, Ordering::Release);
self.max_level.store(level, Ordering::Release);
self.has_entry_point.store(true, Ordering::Release);
}
fn insert_parallel(&self, id: String, vector: &[f32], level: usize) {
let new_idx = {
let mut nodes = self.nodes.write();
let idx = nodes.len();
let node = ParallelHNSWNode {
id: id.clone(),
vector: vector.to_vec(),
level,
neighbors: (0..=level).map(|_| RwLock::new(AHashSet::new())).collect(),
};
nodes.push(node);
idx
};
self.id_to_idx.write().insert(id, new_idx);
let entry_point = self.entry_point.load(Ordering::Acquire);
let current_max_level = self.max_level.load(Ordering::Acquire);
let mut curr_ep = vec![entry_point];
for lc in (level + 1..=current_max_level).rev() {
let nearest = self.search_layer_parallel(vector, curr_ep.clone(), 1, lc);
if !nearest.is_empty() {
curr_ep = vec![nearest[0].1];
}
}
for lc in (0..=level.min(current_max_level)).rev() {
let candidates = self.search_layer_parallel(
vector,
curr_ep.clone(),
self.config.ef_construction,
lc,
);
let m = self.get_max_connections(lc);
let neighbors: Vec<usize> = candidates.iter().take(m).map(|(_, idx)| *idx).collect();
{
let nodes = self.nodes.read();
if lc < nodes[new_idx].neighbors.len() {
let mut new_neighbors = nodes[new_idx].neighbors[lc].write();
for &n in &neighbors {
new_neighbors.insert(n);
}
}
}
{
let nodes = self.nodes.read();
for &neighbor_idx in &neighbors {
if lc < nodes[neighbor_idx].neighbors.len() {
let mut neighbor_set = nodes[neighbor_idx].neighbors[lc].write();
neighbor_set.insert(new_idx);
if neighbor_set.len() > m {
let neighbor_vec = &nodes[neighbor_idx].vector;
let mut scored: Vec<_> = neighbor_set
.iter()
.map(|&idx| {
let other_vec = if idx == new_idx {
vector
} else {
&nodes[idx].vector
};
(self.distance(neighbor_vec, other_vec), idx)
})
.collect();
scored.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
*neighbor_set = scored.into_iter().take(m).map(|(_, idx)| idx).collect();
}
}
}
}
if !candidates.is_empty() {
curr_ep = vec![candidates[0].1];
}
}
if level > current_max_level {
let _ = self.max_level.compare_exchange(
current_max_level,
level,
Ordering::AcqRel,
Ordering::Acquire,
);
let _ = self.entry_point.compare_exchange(
entry_point,
new_idx,
Ordering::AcqRel,
Ordering::Acquire,
);
}
}
fn search_layer_parallel(
&self,
query: &[f32],
entry_points: Vec<usize>,
ef: usize,
level: usize,
) -> Vec<(f32, usize)> {
let nodes = self.nodes.read();
let mut visited: AHashSet<usize> = entry_points.iter().copied().collect();
let mut candidates: BinaryHeap<Reverse<(OrderedFloat<f32>, usize)>> = BinaryHeap::new();
let mut results: BinaryHeap<(OrderedFloat<f32>, usize)> = BinaryHeap::new();
for &ep in &entry_points {
if ep < nodes.len() {
let dist = self.distance(query, &nodes[ep].vector);
candidates.push(Reverse((OrderedFloat(dist), ep)));
results.push((OrderedFloat(dist), ep));
}
}
while let Some(Reverse((OrderedFloat(c_dist), c_idx))) = candidates.pop() {
let f_dist = results.peek().map(|(d, _)| d.0).unwrap_or(f32::INFINITY);
if c_dist > f_dist && results.len() >= ef {
break;
}
if c_idx < nodes.len() && level < nodes[c_idx].neighbors.len() {
let neighbors = nodes[c_idx].neighbors[level].read();
for &neighbor_idx in neighbors.iter() {
if neighbor_idx < nodes.len() && visited.insert(neighbor_idx) {
let dist = self.distance(query, &nodes[neighbor_idx].vector);
let f_dist = results.peek().map(|(d, _)| d.0).unwrap_or(f32::INFINITY);
if dist < f_dist || results.len() < ef {
candidates.push(Reverse((OrderedFloat(dist), neighbor_idx)));
results.push((OrderedFloat(dist), neighbor_idx));
if results.len() > ef {
results.pop();
}
}
}
}
}
}
let mut result_vec: Vec<_> = results.into_iter().map(|(d, idx)| (d.0, idx)).collect();
result_vec.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
result_vec
}
#[inline]
fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
distance::compute_distance_for_heap(a, b, self.config.base.distance_type)
}
fn get_max_connections(&self, level: usize) -> usize {
if level == 0 {
self.config.m_max0
} else {
self.config.m
}
}
#[must_use]
pub fn config(&self) -> &HNSWConfig {
&self.config
}
}
impl VectorIndex for ParallelHNSWIndex {
fn add(&mut self, id: String, vector: &[f32]) -> Result<()> {
if vector.len() != self.config.base.dimension {
return Err(Error::InvalidQuery {
reason: format!(
"Dimension mismatch: expected {}, got {}",
self.config.base.dimension,
vector.len()
),
});
}
if self.id_to_idx.read().contains_key(&id) {
return Err(Error::DuplicateRecord { record_id: id });
}
let mut rng = rand::rngs::SmallRng::from_entropy();
let level = generate_level(&mut rng, self.config.ml);
if !self.has_entry_point.load(Ordering::Acquire) {
self.insert_first(id, vector.to_vec(), level);
} else {
self.insert_parallel(id, vector, level);
}
Ok(())
}
fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
if query.len() != self.config.base.dimension {
return Err(Error::InvalidQuery {
reason: format!(
"Query dimension mismatch: expected {}, got {}",
self.config.base.dimension,
query.len()
),
});
}
let nodes = self.nodes.read();
if nodes.is_empty() {
return Ok(vec![]);
}
let entry_point = self.entry_point.load(Ordering::Acquire);
let max_level = self.max_level.load(Ordering::Acquire);
drop(nodes);
let mut curr_ep = vec![entry_point];
for lc in (1..=max_level).rev() {
let nearest = self.search_layer_parallel(query, curr_ep.clone(), 1, lc);
if !nearest.is_empty() {
curr_ep = vec![nearest[0].1];
}
}
let results = self.search_layer_parallel(query, curr_ep, self.config.ef_search, 0);
let nodes = self.nodes.read();
let k = k.min(results.len());
Ok(results
.into_iter()
.take(k)
.filter_map(|(dist, idx)| {
if idx < nodes.len() && !nodes[idx].id.is_empty() {
let actual_dist = match self.config.base.distance_type {
DistanceType::InnerProduct => -dist,
DistanceType::Cosine => 1.0 - dist,
DistanceType::L2 => dist,
};
Some(SearchResult::new(
nodes[idx].id.clone(),
actual_dist,
self.config.base.distance_type,
))
} else {
None
}
})
.collect())
}
fn remove(&mut self, id: &str) -> Result<bool> {
let idx = {
let id_map = self.id_to_idx.read();
id_map.get(id).copied()
};
if let Some(idx) = idx {
let nodes = self.nodes.read();
for node in nodes.iter() {
for neighbors in &node.neighbors {
neighbors.write().remove(&idx);
}
}
drop(nodes);
self.id_to_idx.write().remove(id);
let mut nodes = self.nodes.write();
nodes[idx].id = String::new();
nodes[idx].vector.clear();
Ok(true)
} else {
Ok(false)
}
}
fn contains(&self, id: &str) -> bool {
self.id_to_idx.read().contains_key(id)
}
fn len(&self) -> usize {
self.id_to_idx.read().len()
}
fn dimension(&self) -> usize {
self.config.base.dimension
}
fn distance_type(&self) -> DistanceType {
self.config.base.distance_type
}
fn clear(&mut self) {
self.nodes.write().clear();
self.id_to_idx.write().clear();
self.entry_point.store(0, Ordering::Release);
self.max_level.store(0, Ordering::Release);
self.has_entry_point.store(false, Ordering::Release);
}
fn memory_usage(&self) -> usize {
let nodes = self.nodes.read();
let node_size = self.config.base.dimension * 4 + 64;
let neighbor_size = self.config.m * 8 * 2;
nodes.len() * (node_size + neighbor_size)
}
}
#[cfg(test)]
mod tests {
use super::*;
fn random_vectors(n: usize, dim: usize, seed: u64) -> Vec<(String, Vec<f32>)> {
use rand::SeedableRng;
let mut rng = rand::rngs::SmallRng::seed_from_u64(seed);
(0..n)
.map(|i| {
let v: Vec<f32> = (0..dim).map(|_| rng.gen::<f32>() - 0.5).collect();
(format!("vec_{}", i), v)
})
.collect()
}
#[test]
fn test_parallel_build_basic() {
let vectors = random_vectors(100, 64, 42);
let config = HNSWConfig::new(64)
.with_m(8)
.with_ef_construction(50)
.with_ef_search(32);
let index = ParallelHNSWBuilder::new()
.with_config(config)
.with_seed(42)
.with_batch_size(16)
.build(vectors)
.unwrap();
assert_eq!(index.len(), 100);
assert_eq!(index.dimension(), 64);
}
#[test]
fn test_parallel_search() {
let vectors = random_vectors(100, 64, 42);
let query = vectors[0].1.clone();
let config = HNSWConfig::new(64)
.with_m(8)
.with_ef_construction(50)
.with_ef_search(32);
let index = ParallelHNSWBuilder::new()
.with_config(config)
.with_seed(42)
.build(vectors)
.unwrap();
let results = index.search(&query, 5).unwrap();
assert!(!results.is_empty());
assert!(results.len() <= 5);
assert_eq!(results[0].id, "vec_0");
}
#[test]
fn test_parallel_vs_sequential_recall() {
let vectors = random_vectors(500, 64, 42);
let query = vectors[0].1.clone();
let config = HNSWConfig::new(64)
.with_m(16)
.with_ef_construction(100)
.with_ef_search(64);
let parallel_index = ParallelHNSWBuilder::new()
.with_config(config.clone())
.with_seed(42)
.build(vectors.clone())
.unwrap();
let results = parallel_index.search(&query, 10).unwrap();
assert!(!results.is_empty());
assert_eq!(results[0].id, "vec_0");
assert!(results[0].distance < 0.01);
}
#[test]
fn test_empty_build() {
let vectors: Vec<(String, Vec<f32>)> = vec![];
let config = HNSWConfig::new(64);
let index = ParallelHNSWBuilder::new()
.with_config(config)
.build(vectors)
.unwrap();
assert_eq!(index.len(), 0);
}
#[test]
fn test_single_vector() {
let vectors = vec![("single".to_string(), vec![1.0, 0.0, 0.0, 0.0])];
let config = HNSWConfig::new(4);
let index = ParallelHNSWBuilder::new()
.with_config(config)
.build(vectors)
.unwrap();
assert_eq!(index.len(), 1);
let results = index.search(&[1.0, 0.0, 0.0, 0.0], 1).unwrap();
assert_eq!(results.len(), 1);
assert_eq!(results[0].id, "single");
}
#[test]
fn test_add_after_build() {
let vectors = random_vectors(50, 64, 42);
let config = HNSWConfig::new(64)
.with_m(8)
.with_ef_construction(50);
let mut index = ParallelHNSWBuilder::new()
.with_config(config)
.with_seed(42)
.build(vectors)
.unwrap();
assert_eq!(index.len(), 50);
index.add("new_1".to_string(), &vec![0.5; 64]).unwrap();
index.add("new_2".to_string(), &vec![-0.5; 64]).unwrap();
assert_eq!(index.len(), 52);
assert!(index.contains("new_1"));
assert!(index.contains("new_2"));
}
#[test]
fn test_remove_after_build() {
let vectors = random_vectors(50, 64, 42);
let config = HNSWConfig::new(64).with_m(8);
let mut index = ParallelHNSWBuilder::new()
.with_config(config)
.with_seed(42)
.build(vectors)
.unwrap();
assert_eq!(index.len(), 50);
assert!(index.contains("vec_0"));
index.remove("vec_0").unwrap();
assert_eq!(index.len(), 49);
assert!(!index.contains("vec_0"));
}
}