use super::{HnswConfig, HnswIndex};
use crate::Vector;
use anyhow::Result;
use parking_lot::RwLock;
use std::sync::Arc;
use std::time::Instant;
/// Tuning knobs for building an HNSW index with `ParallelHnswBuilder`.
#[derive(Debug, Clone)]
pub struct ParallelConstructionConfig {
    /// Worker thread count; `0` is resolved to `num_cpus::get()` by `ParallelHnswBuilder::build`.
    pub num_threads: usize,
    /// How many vectors `ParallelHnswBuilder::build` inserts per batch.
    pub batch_size: usize,
    /// Whether `ParallelHnswBuilder::build` runs its connection-building pass after insertion.
    pub parallel_connections: bool,
    /// NOTE(review): not read by any code in this module; presumably meant for
    /// lock striping — confirm before relying on it.
    pub lock_granularity: usize,
}

impl Default for ParallelConstructionConfig {
    /// Auto-detected threads, batches of 1000, connection pass enabled, granularity 64.
    fn default() -> Self {
        // 0 is a sentinel meaning "pick the thread count at build time".
        let num_threads = 0;
        let batch_size = 1_000;
        let parallel_connections = true;
        let lock_granularity = 64;
        Self {
            num_threads,
            batch_size,
            parallel_connections,
            lock_granularity,
        }
    }
}
/// Timing summary returned alongside the index by `ParallelHnswBuilder::build`.
#[derive(Debug, Clone)]
pub struct ParallelConstructionStats {
    /// Wall-clock duration of the whole build, in milliseconds.
    pub total_time_ms: f64,
    /// Number of input vectors handed to the build.
    pub vectors_processed: usize,
    /// Thread count the build resolved from its configuration.
    pub threads_used: usize,
    /// Total build time divided by the number of vectors, in microseconds.
    /// This includes the connection pass, not only raw insertion.
    pub avg_insertion_time_us: f64,
    /// Vectors processed per second of total build time.
    pub throughput: f64,
}
/// Builds an [`HnswIndex`] from a list of `(uri, vector)` pairs and reports
/// timing statistics for the build.
///
/// NOTE(review): despite the name, insertion in this module is sequential;
/// the resolved thread count is only logged, reported in the stats, and passed
/// to the connection pass (currently a stub).
pub struct ParallelHnswBuilder {
    /// Batch size, thread count and connection-pass settings.
    config: ParallelConstructionConfig,
    /// Configuration handed to `HnswIndex::new` on every build.
    hnsw_config: HnswConfig,
}

impl ParallelHnswBuilder {
    /// Creates a builder from an HNSW configuration and construction settings.
    pub fn new(hnsw_config: HnswConfig, parallel_config: ParallelConstructionConfig) -> Self {
        Self {
            config: parallel_config,
            hnsw_config,
        }
    }

    /// Builds an index containing every pair in `vectors`.
    ///
    /// Vectors are moved (not cloned) into the index in input order, in batches
    /// of `config.batch_size`, taking the index write lock once per batch. A
    /// `batch_size` of `0` is treated as `1` rather than panicking. When
    /// `config.parallel_connections` is set, the connection pass runs after all
    /// insertions.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - `HnswIndex::new` fails.
    /// - An insertion fails; later vectors are not inserted.
    /// - The connection pass fails.
    /// - The shared index cannot be unwrapped at the end.
    pub fn build(
        &self,
        vectors: Vec<(String, Vector)>,
    ) -> Result<(HnswIndex, ParallelConstructionStats)> {
        let start = Instant::now();
        let num_threads = self.resolve_num_threads();
        let vector_count = vectors.len();
        tracing::info!(
            "Building HNSW index with {} threads for {} vectors",
            num_threads,
            vector_count
        );

        let index = Arc::new(RwLock::new(HnswIndex::new(self.hnsw_config.clone())?));
        Self::insert_in_batches(&index, vectors, self.config.batch_size)?;

        if self.config.parallel_connections {
            self.build_connections_parallel(&index, num_threads)?;
        }

        let stats = Self::collect_stats(start.elapsed(), vector_count, num_threads);

        let final_index = Arc::try_unwrap(index)
            .map_err(|_| anyhow::anyhow!("Failed to extract index from Arc"))?
            .into_inner();
        Ok((final_index, stats))
    }

    /// Returns the configured thread count, or `num_cpus::get()` when it is `0`.
    fn resolve_num_threads(&self) -> usize {
        match self.config.num_threads {
            0 => num_cpus::get(),
            n => n,
        }
    }

    /// Moves every pair from `vectors` into `index`, holding the write lock for one batch at a time.
    fn insert_in_batches(
        index: &RwLock<HnswIndex>,
        vectors: Vec<(String, Vector)>,
        batch_size: usize,
    ) -> Result<()> {
        // A zero step would panic (`step_by(0)`); degrade to one vector per batch instead.
        let batch_size = batch_size.max(1);
        let mut pending = vectors.into_iter().peekable();
        while pending.peek().is_some() {
            let mut guard = index.write();
            for (uri, vector) in pending.by_ref().take(batch_size) {
                guard.add_vector(uri, vector)?;
            }
        }
        Ok(())
    }

    /// Derives per-vector time and throughput from the total elapsed time.
    fn collect_stats(
        elapsed: std::time::Duration,
        vectors_processed: usize,
        threads_used: usize,
    ) -> ParallelConstructionStats {
        let total_secs = elapsed.as_secs_f64();
        let total_time_ms = total_secs * 1000.0;
        // Dividing by a zero vector count would yield `inf`/`NaN`; report 0 for an empty build.
        let (avg_insertion_time_us, throughput) = if vectors_processed == 0 {
            (0.0, 0.0)
        } else {
            let n = vectors_processed as f64;
            (total_time_ms * 1000.0 / n, n / total_secs)
        };
        ParallelConstructionStats {
            total_time_ms,
            vectors_processed,
            threads_used,
            avg_insertion_time_us,
            throughput,
        }
    }

    /// Connection-building pass run by `build` when `parallel_connections` is set.
    ///
    /// NOTE(review): currently a placeholder — it only logs at debug level and
    /// returns `Ok(())` without touching the index.
    fn build_connections_parallel(
        &self,
        _index: &Arc<RwLock<HnswIndex>>,
        num_threads: usize,
    ) -> Result<()> {
        tracing::debug!("Building connections with {} threads", num_threads);
        Ok(())
    }
}
/// Fluent front-end for `ParallelHnswBuilder`. It collects configuration and
/// input vectors, then `build` consumes it to produce the index.
#[must_use = "the builder does nothing until `build` is called"]
pub struct ParallelHnswIndexBuilder {
    hnsw_config: HnswConfig,
    parallel_config: ParallelConstructionConfig,
    vectors: Vec<(String, Vector)>,
}

impl ParallelHnswIndexBuilder {
    /// Creates a builder with default HNSW and construction settings and no vectors.
    pub fn new() -> Self {
        Self {
            hnsw_config: HnswConfig::default(),
            parallel_config: ParallelConstructionConfig::default(),
            vectors: Vec::new(),
        }
    }

    /// Replaces the HNSW configuration.
    pub fn with_hnsw_config(mut self, config: HnswConfig) -> Self {
        self.hnsw_config = config;
        self
    }

    /// Replaces the entire construction configuration, including any thread
    /// count or batch size set earlier.
    pub fn with_parallel_config(mut self, config: ParallelConstructionConfig) -> Self {
        self.parallel_config = config;
        self
    }

    /// Sets the thread count; `0` means "use the number of CPUs".
    pub fn with_threads(mut self, num_threads: usize) -> Self {
        self.parallel_config.num_threads = num_threads;
        self
    }

    /// Sets how many vectors are inserted per batch.
    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.parallel_config.batch_size = batch_size;
        self
    }

    /// Appends `vectors` to those already queued for the build.
    ///
    /// Calling this more than once accumulates vectors. Previously, each call
    /// silently discarded the earlier ones.
    pub fn add_vectors(mut self, vectors: Vec<(String, Vector)>) -> Self {
        if self.vectors.is_empty() {
            // Common single-call path: adopt the caller's allocation as-is.
            self.vectors = vectors;
        } else {
            self.vectors.extend(vectors);
        }
        self
    }

    /// Consumes the builder and builds the index from the queued vectors.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `ParallelHnswBuilder::build`.
    pub fn build(self) -> Result<(HnswIndex, ParallelConstructionStats)> {
        ParallelHnswBuilder::new(self.hnsw_config, self.parallel_config).build(self.vectors)
    }
}

impl Default for ParallelHnswIndexBuilder {
    /// Same as [`ParallelHnswIndexBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Produces `count` labelled vectors of dimension `dim`; every component of
    /// vector `i` equals `i / count`.
    fn create_test_vectors(count: usize, dim: usize) -> Vec<(String, Vector)> {
        let denominator = count as f32;
        let mut out = Vec::with_capacity(count);
        for i in 0..count {
            let fill = i as f32 / denominator;
            out.push((format!("vec_{}", i), Vector::new(vec![fill; dim])));
        }
        out
    }

    #[test]
    fn test_parallel_construction_config() {
        let ParallelConstructionConfig {
            num_threads,
            batch_size,
            ..
        } = ParallelConstructionConfig::default();
        assert_eq!(num_threads, 0);
        assert!(batch_size > 0);
    }

    #[test]
    fn test_parallel_builder_creation() {
        let _builder =
            ParallelHnswBuilder::new(HnswConfig::default(), ParallelConstructionConfig::default());
    }

    #[test]
    fn test_parallel_index_builder() -> Result<()> {
        let outcome = ParallelHnswIndexBuilder::new()
            .with_threads(2)
            .with_batch_size(50)
            .add_vectors(create_test_vectors(100, 64))
            .build();
        assert!(outcome.is_ok());
        let (index, stats) = outcome?;
        assert_eq!(index.len(), 100);
        assert_eq!(stats.vectors_processed, 100);
        assert!(stats.throughput > 0.0);
        Ok(())
    }

    #[test]
    fn test_different_batch_sizes() {
        let vectors = create_test_vectors(200, 32);
        for &batch_size in &[10usize, 200] {
            let built = ParallelHnswIndexBuilder::new()
                .with_batch_size(batch_size)
                .add_vectors(vectors.clone())
                .build();
            assert!(built.is_ok(), "build failed with batch_size = {}", batch_size);
        }
    }

    #[test]
    fn test_multi_threaded_build() -> Result<()> {
        let outcome = ParallelHnswIndexBuilder::new()
            .with_threads(4)
            .add_vectors(create_test_vectors(500, 128))
            .build();
        assert!(outcome.is_ok());
        let (_index, stats) = outcome?;
        assert_eq!(stats.vectors_processed, 500);
        assert_eq!(stats.threads_used, 4);
        Ok(())
    }
}