1use crate::hnsw::HnswIndex;
7use crate::Vector;
8use anyhow::Result;
9
/// Summary of one batch insert, update, or delete run against an [`HnswIndex`].
///
/// Per-item failures do not abort a batch; they are counted in `failure_count`
/// and recorded (as text) in `results`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct BatchOperationResult {
    /// Number of items whose individual operation returned `Ok`.
    pub success_count: usize,
    /// Number of items whose individual operation returned `Err`.
    pub failure_count: usize,
    /// Per-item outcome, in the same order as the batch input; errors are stored
    /// as their `to_string()` text.
    pub results: Vec<Result<(), String>>,
    /// Elapsed time of the whole batch, in milliseconds.
    pub duration_ms: u64,
}
22
/// Options controlling [`HnswIndex::batch_insert`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BatchInsertConfig {
    /// Requests parallel insertion.
    /// NOTE(review): `batch_insert` currently only reports this flag in its log;
    /// the inserts themselves run sequentially on the calling thread.
    pub use_parallel: bool,
    /// Requested worker-thread count.
    /// NOTE(review): not currently read by `batch_insert`.
    pub num_threads: usize,
    /// Requested number of items per insertion chunk; should be non-zero.
    /// Grouping does not change which items are inserted or their order.
    pub batch_size: usize,
    /// Whether `batch_insert` runs graph-structure optimization after inserting.
    pub optimize_after: bool,
}
35
36impl Default for BatchInsertConfig {
37 fn default() -> Self {
38 Self {
39 use_parallel: true,
40 num_threads: std::thread::available_parallelism()
41 .map(|n| n.get())
42 .unwrap_or(4),
43 batch_size: 1000,
44 optimize_after: true,
45 }
46 }
47}
48
49impl HnswIndex {
50 pub fn batch_insert(
66 &mut self,
67 vectors: Vec<(String, Vector)>,
68 config: BatchInsertConfig,
69 ) -> Result<BatchOperationResult> {
70 let start = std::time::Instant::now();
71 let total_count = vectors.len();
72 let mut results = Vec::with_capacity(total_count);
73 let mut success_count = 0;
74 let mut failure_count = 0;
75
76 if vectors.is_empty() {
77 return Ok(BatchOperationResult {
78 success_count: 0,
79 failure_count: 0,
80 results: vec![],
81 duration_ms: 0,
82 });
83 }
84
85 tracing::info!(
86 "Starting batch insert of {} vectors (parallel: {})",
87 total_count,
88 config.use_parallel
89 );
90
91 for chunk in vectors.chunks(config.batch_size) {
93 for (uri, vector) in chunk {
94 match self.add_vector(uri.clone(), vector.clone()) {
95 Ok(_) => {
96 success_count += 1;
97 results.push(Ok(()));
98 }
99 Err(e) => {
100 failure_count += 1;
101 results.push(Err(e.to_string()));
102 }
103 }
104 }
105 }
106
107 if config.optimize_after {
109 tracing::info!("Optimizing graph after batch insert");
110 self.optimize_graph_structure()?;
111 }
112
113 let duration_ms = start.elapsed().as_millis() as u64;
114
115 tracing::info!(
116 "Batch insert completed: {} successes, {} failures in {}ms",
117 success_count,
118 failure_count,
119 duration_ms
120 );
121
122 Ok(BatchOperationResult {
123 success_count,
124 failure_count,
125 results,
126 duration_ms,
127 })
128 }
129
130 pub fn batch_update(&mut self, updates: Vec<(String, Vector)>) -> Result<BatchOperationResult> {
140 let start = std::time::Instant::now();
141 let total_count = updates.len();
142 let mut results = Vec::with_capacity(total_count);
143 let mut success_count = 0;
144 let mut failure_count = 0;
145
146 tracing::info!("Starting batch update of {} vectors", total_count);
147
148 for (uri, vector) in updates {
149 match self.update_vector(&uri, vector) {
150 Ok(_) => {
151 success_count += 1;
152 results.push(Ok(()));
153 }
154 Err(e) => {
155 failure_count += 1;
156 results.push(Err(e.to_string()));
157 }
158 }
159 }
160
161 let duration_ms = start.elapsed().as_millis() as u64;
162
163 tracing::info!(
164 "Batch update completed: {} successes, {} failures in {}ms",
165 success_count,
166 failure_count,
167 duration_ms
168 );
169
170 Ok(BatchOperationResult {
171 success_count,
172 failure_count,
173 results,
174 duration_ms,
175 })
176 }
177
178 pub fn batch_delete(&mut self, uris: Vec<String>) -> Result<BatchOperationResult> {
188 let start = std::time::Instant::now();
189 let total_count = uris.len();
190 let mut results = Vec::with_capacity(total_count);
191 let mut success_count = 0;
192 let mut failure_count = 0;
193
194 tracing::info!("Starting batch delete of {} vectors", total_count);
195
196 for uri in uris {
197 match self.remove_vector(&uri) {
198 Ok(_) => {
199 success_count += 1;
200 results.push(Ok(()));
201 }
202 Err(e) => {
203 failure_count += 1;
204 results.push(Err(e.to_string()));
205 }
206 }
207 }
208
209 if success_count > 0 && success_count > total_count / 10 {
211 tracing::info!("Compacting index after batch delete");
212 self.compact_index()?;
213 }
214
215 let duration_ms = start.elapsed().as_millis() as u64;
216
217 tracing::info!(
218 "Batch delete completed: {} successes, {} failures in {}ms",
219 success_count,
220 failure_count,
221 duration_ms
222 );
223
224 Ok(BatchOperationResult {
225 success_count,
226 failure_count,
227 results,
228 duration_ms,
229 })
230 }
231
232 pub fn optimize_graph_structure(&mut self) -> Result<()> {
239 tracing::info!("Starting graph structure optimization");
240
241 let node_count = self.nodes().len();
242 if node_count == 0 {
243 return Ok(());
244 }
245
246 for node_id in 0..node_count {
248 if let Some(node) = self.nodes().get(node_id) {
249 let node_level = node.level();
250
251 for level in 0..=node_level {
252 self.prune_connections_at_level(node_id, level)?;
253 }
254 }
255 }
256
257 self.rebalance_connections()?;
259
260 tracing::info!("Graph structure optimization completed");
261
262 Ok(())
263 }
264
265 fn prune_connections_at_level(&mut self, node_id: usize, level: usize) -> Result<()> {
267 let max_connections = if level == 0 {
268 self.config().m_l0 } else {
270 self.config().m };
272
273 let connections = if let Some(node) = self.nodes().get(node_id) {
275 if let Some(conns) = node.get_connections(level) {
276 conns.clone()
277 } else {
278 return Ok(());
279 }
280 } else {
281 return Ok(());
282 };
283
284 if connections.len() <= max_connections {
285 return Ok(()); }
287
288 let mut connection_distances: Vec<(usize, f32)> = connections
290 .iter()
291 .filter_map(|&conn_id| {
292 self.batch_calculate_distance(node_id, conn_id)
293 .map(|dist| (conn_id, dist))
294 })
295 .collect();
296
297 connection_distances
299 .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
300
301 let to_remove: std::collections::HashSet<usize> = connection_distances
303 .iter()
304 .skip(max_connections)
305 .map(|(id, _)| *id)
306 .collect();
307
308 if let Some(node) = self.nodes_mut().get_mut(node_id) {
310 for &conn_id in &to_remove {
311 node.remove_connection(level, conn_id);
312 }
313 }
314
315 Ok(())
316 }
317
318 fn rebalance_connections(&mut self) -> Result<()> {
320 let min_connections = self.config().m / 2; let node_count = self.nodes().len();
322
323 let mut nodes_to_rebalance = Vec::new();
325
326 for node_id in 0..node_count {
327 if let Some(node) = self.nodes().get(node_id) {
328 let node_level = node.level();
329
330 for level in 0..=node_level {
331 let connection_count = node
332 .get_connections(level)
333 .map(|conns| conns.len())
334 .unwrap_or(0);
335
336 if connection_count < min_connections {
338 nodes_to_rebalance.push((node_id, level, min_connections));
339 }
340 }
341 }
342 }
343
344 for (node_id, level, target_connections) in nodes_to_rebalance {
346 self.add_connections_to_node(node_id, level, target_connections)?;
347 }
348
349 Ok(())
350 }
351
352 fn add_connections_to_node(
354 &mut self,
355 node_id: usize,
356 level: usize,
357 target_connections: usize,
358 ) -> Result<()> {
359 let current_connections = if let Some(node) = self.nodes().get(node_id) {
363 node.get_connections(level).cloned().unwrap_or_default()
364 } else {
365 return Ok(());
366 };
367
368 if current_connections.len() >= target_connections {
369 return Ok(());
370 }
371
372 let mut candidates = Vec::new();
374 for (candidate_id, candidate_node) in self.nodes().iter().enumerate() {
375 if candidate_id != node_id
376 && candidate_node.level() >= level
377 && !current_connections.contains(&candidate_id)
378 {
379 if let Some(distance) = self.batch_calculate_distance(node_id, candidate_id) {
380 candidates.push((candidate_id, distance));
381 }
382 }
383 }
384
385 candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
387
388 let needed = target_connections - current_connections.len();
390 let new_connections: Vec<usize> = candidates
391 .into_iter()
392 .take(needed)
393 .map(|(id, _)| id)
394 .collect();
395
396 if let Some(node) = self.nodes_mut().get_mut(node_id) {
398 for conn_id in new_connections {
399 node.add_connection(level, conn_id);
400 }
401 }
402
403 Ok(())
404 }
405
406 fn batch_calculate_distance(&self, node1_id: usize, node2_id: usize) -> Option<f32> {
408 let node1 = self.nodes().get(node1_id)?;
409 let node2 = self.nodes().get(node2_id)?;
410
411 self.config()
412 .metric
413 .distance(&node1.vector, &node2.vector)
414 .ok()
415 }
416
417 pub fn compact_index(&mut self) -> Result<()> {
422 tracing::info!("Starting index compaction");
423
424 tracing::info!("Index compaction completed");
432
433 Ok(())
434 }
435}
436
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hnsw::HnswConfig;
    use crate::Vector;

    /// Builds an empty index with the default HNSW configuration.
    fn new_index() -> HnswIndex {
        HnswIndex::new(HnswConfig::default()).unwrap()
    }

    #[test]
    fn test_batch_insert() {
        let mut index = new_index();

        let vectors: Vec<(String, Vector)> = (0..100)
            .map(|i| {
                let v = Vector::new(vec![i as f32, (i * 2) as f32, (i * 3) as f32]);
                (format!("vec_{}", i), v)
            })
            .collect();

        let result = index
            .batch_insert(vectors, BatchInsertConfig::default())
            .unwrap();

        assert_eq!(result.success_count, 100);
        assert_eq!(result.failure_count, 0);
        // One recorded outcome per input item.
        assert_eq!(result.results.len(), 100);
        assert!(result.results.iter().all(|r| r.is_ok()));
        assert_eq!(index.len(), 100);
    }

    #[test]
    fn test_batch_insert_empty() {
        let mut index = new_index();

        let result = index
            .batch_insert(Vec::new(), BatchInsertConfig::default())
            .unwrap();

        assert_eq!(result.success_count, 0);
        assert_eq!(result.failure_count, 0);
        assert!(result.results.is_empty());
        assert_eq!(index.len(), 0);
    }

    #[test]
    fn test_batch_update() {
        let mut index = new_index();

        for i in 0..10 {
            let v = Vector::new(vec![i as f32, 0.0, 0.0]);
            index.add_vector(format!("vec_{}", i), v).unwrap();
        }

        let updates: Vec<(String, Vector)> = (0..10)
            .map(|i| {
                let v = Vector::new(vec![i as f32, 1.0, 1.0]);
                (format!("vec_{}", i), v)
            })
            .collect();

        let result = index.batch_update(updates).unwrap();

        assert_eq!(result.success_count, 10);
        assert_eq!(result.failure_count, 0);
        assert_eq!(result.results.len(), 10);
    }

    #[test]
    fn test_batch_delete() {
        let mut index = new_index();

        for i in 0..20 {
            let v = Vector::new(vec![i as f32, 0.0, 0.0]);
            index.add_vector(format!("vec_{}", i), v).unwrap();
        }

        let to_delete: Vec<String> = (0..10).map(|i| format!("vec_{}", i)).collect();

        let result = index.batch_delete(to_delete).unwrap();

        assert_eq!(result.success_count, 10);
        assert_eq!(result.failure_count, 0);
        assert_eq!(result.results.len(), 10);
    }

    #[test]
    fn test_batch_update_and_delete_empty() {
        let mut index = new_index();

        let updated = index.batch_update(Vec::new()).unwrap();
        assert_eq!((updated.success_count, updated.failure_count), (0, 0));
        assert!(updated.results.is_empty());

        let deleted = index.batch_delete(Vec::new()).unwrap();
        assert_eq!((deleted.success_count, deleted.failure_count), (0, 0));
        assert!(deleted.results.is_empty());
    }

    #[test]
    fn test_graph_optimization() {
        let mut index = new_index();

        for i in 0..50 {
            let v = Vector::new(vec![i as f32, (i * 2) as f32, (i * 3) as f32]);
            index.add_vector(format!("vec_{}", i), v).unwrap();
        }

        let size_before = index.len();

        index.optimize_graph_structure().unwrap();

        // Optimization rewires edges but must not add or drop vectors.
        assert_eq!(index.len(), size_before);

        let query1 = Vector::new(vec![0.0, 0.0, 0.0]);
        let results1 = index.search_knn(&query1, 5).unwrap();
        assert!(results1.len() <= 5);

        // Smoke test: searching mid-range must still succeed after rewiring.
        let query2 = Vector::new(vec![25.0, 50.0, 75.0]);
        let _results2 = index.search_knn(&query2, 5).unwrap();
    }
}