oxirs_vec/hnsw/
parallel_construction.rs1use super::{HnswConfig, HnswIndex};
7use crate::Vector;
8use anyhow::Result;
9use parking_lot::RwLock;
10use std::sync::Arc;
11use std::time::Instant;
12
/// Tuning knobs for [`ParallelHnswBuilder`].
#[derive(Debug, Clone)]
pub struct ParallelConstructionConfig {
    /// Number of worker threads; `0` means auto-detect via `num_cpus::get()`.
    pub num_threads: usize,
    /// Number of vectors handled per batch during insertion. Should be greater than zero.
    pub batch_size: usize,
    /// Whether the builder runs its connection-building pass after inserting vectors.
    pub parallel_connections: bool,
    /// NOTE(review): not read by any code in this module; presumably intended for
    /// lock striping — confirm before relying on it.
    pub lock_granularity: usize,
}

impl Default for ParallelConstructionConfig {
    /// Auto-detected thread count, batches of 1000, connection pass enabled,
    /// lock granularity 64.
    fn default() -> Self {
        let num_threads = 0;
        let batch_size = 1000;
        let parallel_connections = true;
        let lock_granularity = 64;
        Self {
            num_threads,
            batch_size,
            parallel_connections,
            lock_granularity,
        }
    }
}
36
/// Timing summary produced by [`ParallelHnswBuilder::build`].
#[derive(Debug, Clone)]
pub struct ParallelConstructionStats {
    /// Wall-clock duration of the whole `build` call, in milliseconds.
    pub total_time_ms: f64,
    /// Number of input vectors handed to `build`.
    pub vectors_processed: usize,
    /// Thread count resolved from the configuration (auto-detected when it was 0).
    pub threads_used: usize,
    /// Total build time divided by the vector count, in microseconds per vector.
    pub avg_insertion_time_us: f64,
    /// Vectors per second over the whole build.
    pub throughput: f64,
}
51
52pub struct ParallelHnswBuilder {
54 config: ParallelConstructionConfig,
55 hnsw_config: HnswConfig,
56}
57
58impl ParallelHnswBuilder {
59 pub fn new(hnsw_config: HnswConfig, parallel_config: ParallelConstructionConfig) -> Self {
61 Self {
62 config: parallel_config,
63 hnsw_config,
64 }
65 }
66
67 pub fn build(
69 &self,
70 vectors: Vec<(String, Vector)>,
71 ) -> Result<(HnswIndex, ParallelConstructionStats)> {
72 let start = Instant::now();
73 let num_threads = if self.config.num_threads == 0 {
74 num_cpus::get()
75 } else {
76 self.config.num_threads
77 };
78
79 tracing::info!(
80 "Building HNSW index with {} threads for {} vectors",
81 num_threads,
82 vectors.len()
83 );
84
85 let hnsw_index = HnswIndex::new(self.hnsw_config.clone())?;
87 let index = Arc::new(RwLock::new(hnsw_index));
88
89 let vectors_arc = Arc::new(vectors);
91 let batch_size = self.config.batch_size;
92
93 for batch_start in (0..vectors_arc.len()).step_by(batch_size) {
95 let batch_end = (batch_start + batch_size).min(vectors_arc.len());
96 let batch_vectors = &vectors_arc[batch_start..batch_end];
97
98 for (uri, vector) in batch_vectors {
100 let mut idx = index.write();
101 idx.add_vector(uri.clone(), vector.clone())?;
102 }
103 }
104
105 if self.config.parallel_connections {
107 self.build_connections_parallel(&index, num_threads)?;
108 }
109
110 let elapsed = start.elapsed();
111 let total_time_ms = elapsed.as_secs_f64() * 1000.0;
112
113 let stats = ParallelConstructionStats {
114 total_time_ms,
115 vectors_processed: vectors_arc.len(),
116 threads_used: num_threads,
117 avg_insertion_time_us: (total_time_ms * 1000.0) / vectors_arc.len() as f64,
118 throughput: vectors_arc.len() as f64 / elapsed.as_secs_f64(),
119 };
120
121 let final_index = Arc::try_unwrap(index)
123 .map_err(|_| anyhow::anyhow!("Failed to extract index from Arc"))?
124 .into_inner();
125
126 Ok((final_index, stats))
127 }
128
129 fn build_connections_parallel(
131 &self,
132 _index: &Arc<RwLock<HnswIndex>>,
133 num_threads: usize,
134 ) -> Result<()> {
135 tracing::debug!("Building connections with {} threads", num_threads);
143
144 Ok(())
145 }
146}
147
148pub struct ParallelHnswIndexBuilder {
150 hnsw_config: HnswConfig,
151 parallel_config: ParallelConstructionConfig,
152 vectors: Vec<(String, Vector)>,
153}
154
155impl ParallelHnswIndexBuilder {
156 pub fn new() -> Self {
158 Self {
159 hnsw_config: HnswConfig::default(),
160 parallel_config: ParallelConstructionConfig::default(),
161 vectors: Vec::new(),
162 }
163 }
164
165 pub fn with_hnsw_config(mut self, config: HnswConfig) -> Self {
167 self.hnsw_config = config;
168 self
169 }
170
171 pub fn with_parallel_config(mut self, config: ParallelConstructionConfig) -> Self {
173 self.parallel_config = config;
174 self
175 }
176
177 pub fn with_threads(mut self, num_threads: usize) -> Self {
179 self.parallel_config.num_threads = num_threads;
180 self
181 }
182
183 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
185 self.parallel_config.batch_size = batch_size;
186 self
187 }
188
189 pub fn add_vectors(mut self, vectors: Vec<(String, Vector)>) -> Self {
191 self.vectors = vectors;
192 self
193 }
194
195 pub fn build(self) -> Result<(HnswIndex, ParallelConstructionStats)> {
197 let builder = ParallelHnswBuilder::new(self.hnsw_config, self.parallel_config);
198 builder.build(self.vectors)
199 }
200}
201
202impl Default for ParallelHnswIndexBuilder {
203 fn default() -> Self {
204 Self::new()
205 }
206}
207
#[cfg(test)]
mod tests {
    use super::*;

    /// Produces `count` vectors of dimension `dim`; vector `i` has URI `vec_{i}` and
    /// every component equal to `i / count`.
    fn create_test_vectors(count: usize, dim: usize) -> Vec<(String, Vector)> {
        let mut generated = Vec::with_capacity(count);
        for i in 0..count {
            let fill = i as f32 / count as f32;
            generated.push((format!("vec_{}", i), Vector::new(vec![fill; dim])));
        }
        generated
    }

    #[test]
    fn test_parallel_construction_config() {
        let defaults = ParallelConstructionConfig::default();
        assert_eq!(defaults.num_threads, 0);
        assert!(defaults.batch_size > 0);
    }

    #[test]
    fn test_parallel_builder_creation() {
        let _builder = ParallelHnswBuilder::new(
            HnswConfig::default(),
            ParallelConstructionConfig::default(),
        );
    }

    #[test]
    fn test_parallel_index_builder() {
        let input = create_test_vectors(100, 64);

        let (index, stats) = ParallelHnswIndexBuilder::new()
            .with_threads(2)
            .with_batch_size(50)
            .add_vectors(input)
            .build()
            .expect("building a 100-vector index should succeed");

        assert_eq!(index.len(), 100);
        assert_eq!(stats.vectors_processed, 100);
        assert!(stats.throughput > 0.0);
    }

    #[test]
    fn test_different_batch_sizes() {
        let input = create_test_vectors(200, 32);

        for batch_size in [10usize, 200] {
            let outcome = ParallelHnswIndexBuilder::new()
                .with_batch_size(batch_size)
                .add_vectors(input.clone())
                .build();
            assert!(outcome.is_ok(), "build failed for batch size {}", batch_size);
        }
    }

    #[test]
    fn test_multi_threaded_build() {
        let input = create_test_vectors(500, 128);

        let (_index, stats) = ParallelHnswIndexBuilder::new()
            .with_threads(4)
            .add_vectors(input)
            .build()
            .expect("building a 500-vector index should succeed");

        assert_eq!(stats.vectors_processed, 500);
        assert_eq!(stats.threads_used, 4);
    }
}