1use std::collections::HashMap;
11
/// A single stored vector together with its identifier and free-form metadata.
#[derive(Clone, Debug, PartialEq)]
pub struct VectorEntry {
    /// Identifier; `IndexMerger` uses it as the deduplication key when merging.
    pub id: u64,
    /// The vector's components. Its length is not checked here; `FlatIndex::insert` checks it.
    pub vector: Vec<f32>,
    /// Arbitrary string key/value annotations carried alongside the vector.
    pub metadata: HashMap<String, String>,
}

impl VectorEntry {
    /// Creates an entry with an empty metadata map.
    pub fn new(id: u64, vector: Vec<f32>) -> Self {
        Self::with_metadata(id, vector, HashMap::new())
    }

    /// Creates an entry that carries the supplied `metadata` map.
    pub fn with_metadata(id: u64, vector: Vec<f32>, metadata: HashMap<String, String>) -> Self {
        Self {
            id,
            vector,
            metadata,
        }
    }
}
42
43#[derive(Debug, Clone, PartialEq)]
47pub struct FlatIndex {
48 pub entries: Vec<VectorEntry>,
50 pub dims: usize,
52}
53
54impl FlatIndex {
55 pub fn new(dims: usize) -> Self {
57 Self {
58 entries: Vec::new(),
59 dims,
60 }
61 }
62
63 pub fn insert(&mut self, entry: VectorEntry) -> Result<(), MergeError> {
66 if entry.vector.len() != self.dims {
67 return Err(MergeError::DimensionMismatch {
68 expected: self.dims,
69 got: entry.vector.len(),
70 });
71 }
72 self.entries.push(entry);
73 Ok(())
74 }
75
76 pub fn len(&self) -> usize {
78 self.entries.len()
79 }
80
81 pub fn is_empty(&self) -> bool {
83 self.entries.is_empty()
84 }
85}
86
/// Counters describing one `IndexMerger::merge_with_stats` run.
#[derive(Clone, Debug, PartialEq)]
pub struct MergeStats {
    /// Number of indices that were merged.
    pub input_count: usize,
    /// Total entries across all inputs, duplicates included.
    pub total_before: usize,
    /// `total_before - total_after`: entries collapsed because they shared an id.
    pub deduplicated: usize,
    /// Number of entries in the merged result.
    pub total_after: usize,
}
99
/// Errors produced while building, merging, or splitting indices.
#[derive(Clone, Debug, PartialEq)]
pub enum MergeError {
    /// A vector or index had `got` dimensions where `expected` were required.
    DimensionMismatch { expected: usize, got: usize },
    /// A merge was requested with no input indices.
    EmptyInput,
    /// The requested number of parts was zero.
    InvalidParts,
}

impl std::fmt::Display for MergeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::DimensionMismatch { expected, got } => {
                write!(f, "Dimension mismatch: expected {expected}, got {got}")
            }
            Self::EmptyInput => f.write_str("No input indices provided"),
            Self::InvalidParts => f.write_str("Number of parts must be greater than zero"),
        }
    }
}

impl std::error::Error for MergeError {}
126
127#[derive(Debug, Default)]
133pub struct IndexMerger {
134 indices: Vec<FlatIndex>,
135}
136
137impl IndexMerger {
138 pub fn new() -> Self {
140 Self {
141 indices: Vec::new(),
142 }
143 }
144
145 pub fn add_index(&mut self, idx: FlatIndex) {
147 self.indices.push(idx);
148 }
149
150 pub fn merge(&mut self) -> Result<FlatIndex, MergeError> {
158 if self.indices.is_empty() {
159 return Err(MergeError::EmptyInput);
160 }
161
162 let dims = self.indices[0].dims;
163
164 for idx in &self.indices {
166 if idx.dims != dims {
167 return Err(MergeError::DimensionMismatch {
168 expected: dims,
169 got: idx.dims,
170 });
171 }
172 }
173
174 let mut order: Vec<u64> = Vec::new();
178 let mut map: HashMap<u64, VectorEntry> = HashMap::new();
179
180 for idx in &self.indices {
181 for entry in &idx.entries {
182 if !map.contains_key(&entry.id) {
183 order.push(entry.id);
184 }
185 map.insert(entry.id, entry.clone());
186 }
187 }
188
189 let mut out = FlatIndex::new(dims);
190 for id in &order {
191 if let Some(entry) = map.remove(id) {
192 out.entries.push(entry);
193 }
194 }
195
196 Ok(out)
197 }
198
199 pub fn merge_with_filter<F>(&mut self, filter: F) -> Result<FlatIndex, MergeError>
202 where
203 F: Fn(&VectorEntry) -> bool,
204 {
205 let merged = self.merge()?;
206 let dims = merged.dims;
207 let mut out = FlatIndex::new(dims);
208 for entry in merged.entries {
209 if filter(&entry) {
210 out.entries.push(entry);
211 }
212 }
213 Ok(out)
214 }
215
216 pub fn merge_with_stats(&mut self) -> Result<(FlatIndex, MergeStats), MergeError> {
218 if self.indices.is_empty() {
219 return Err(MergeError::EmptyInput);
220 }
221
222 let input_count = self.indices.len();
223 let total_before: usize = self.indices.iter().map(|i| i.len()).sum();
224
225 let merged = self.merge()?;
226 let total_after = merged.len();
227 let deduplicated = total_before.saturating_sub(total_after);
228
229 let stats = MergeStats {
230 input_count,
231 total_before,
232 deduplicated,
233 total_after,
234 };
235 Ok((merged, stats))
236 }
237
238 pub fn split(idx: &FlatIndex, parts: usize) -> Vec<FlatIndex> {
245 if parts == 0 {
246 return vec![];
247 }
248 if idx.is_empty() {
249 return (0..parts).map(|_| FlatIndex::new(idx.dims)).collect();
250 }
251
252 let n = idx.entries.len();
253 let base = n / parts;
254 let remainder = n % parts;
255
256 let mut result = Vec::with_capacity(parts);
257 let mut offset = 0usize;
258
259 for i in 0..parts {
260 let chunk_size = base + if i < remainder { 1 } else { 0 };
261 let mut sub = FlatIndex::new(idx.dims);
262 sub.entries
263 .extend_from_slice(&idx.entries[offset..offset + chunk_size]);
264 offset += chunk_size;
265 result.push(sub);
266 }
267 result
268 }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    fn make_entry(id: u64, dims: usize, val: f32) -> VectorEntry {
        VectorEntry::new(id, vec![val; dims])
    }

    fn make_index(dims: usize, ids: &[(u64, f32)]) -> FlatIndex {
        let mut idx = FlatIndex::new(dims);
        for &(id, val) in ids {
            idx.insert(make_entry(id, dims, val)).expect("insert ok");
        }
        idx
    }

    /// Ids of the entries stored in `idx`, in stored order.
    fn ids_of(idx: &FlatIndex) -> Vec<u64> {
        idx.entries.iter().map(|e| e.id).collect()
    }

    /// Entry count of each part, in order.
    fn sizes(parts: &[FlatIndex]) -> Vec<usize> {
        parts.iter().map(FlatIndex::len).collect()
    }

    // ----- FlatIndex -----

    #[test]
    fn test_flat_index_new_is_empty() {
        let idx = FlatIndex::new(4);
        assert!(idx.is_empty());
        assert_eq!(idx.len(), 0);
        assert_eq!(idx.dims, 4);
    }

    #[test]
    fn test_flat_index_insert_valid() {
        let mut idx = FlatIndex::new(3);
        let entry = make_entry(1, 3, 0.5);
        assert!(idx.insert(entry.clone()).is_ok());
        assert_eq!(idx.len(), 1);
        assert_eq!(idx.entries[0], entry);
    }

    #[test]
    fn test_flat_index_insert_dimension_mismatch() {
        let mut idx = FlatIndex::new(3);
        let entry = make_entry(1, 4, 0.5);
        assert_eq!(
            idx.insert(entry),
            Err(MergeError::DimensionMismatch {
                expected: 3,
                got: 4
            })
        );
        // A rejected entry must not be stored.
        assert!(idx.is_empty());
    }

    #[test]
    fn test_flat_index_is_not_empty_after_insert() {
        let mut idx = FlatIndex::new(2);
        idx.insert(make_entry(1, 2, 1.0)).expect("ok");
        assert!(!idx.is_empty());
    }

    // ----- merge -----

    #[test]
    fn test_merge_empty_returns_error() {
        let mut merger = IndexMerger::new();
        assert_eq!(merger.merge(), Err(MergeError::EmptyInput));
    }

    #[test]
    fn test_merge_single_index() {
        let idx = make_index(2, &[(1, 1.0), (2, 2.0)]);
        let mut merger = IndexMerger::new();
        merger.add_index(idx);
        let out = merger.merge().expect("merge ok");
        assert_eq!(out.len(), 2);
        assert_eq!(ids_of(&out), vec![1, 2]);
    }

    #[test]
    fn test_merge_two_disjoint_indices() {
        let a = make_index(2, &[(1, 1.0), (2, 2.0)]);
        let b = make_index(2, &[(3, 3.0), (4, 4.0)]);
        let mut merger = IndexMerger::new();
        merger.add_index(a);
        merger.add_index(b);
        let out = merger.merge().expect("merge ok");
        assert_eq!(out.len(), 4);
        assert_eq!(ids_of(&out), vec![1, 2, 3, 4]);
    }

    #[test]
    fn test_merge_deduplication_last_write_wins() {
        let a = make_index(2, &[(1, 1.0)]);
        let b = make_index(2, &[(1, 9.9)]);
        let mut merger = IndexMerger::new();
        merger.add_index(a);
        merger.add_index(b);
        let out = merger.merge().expect("merge ok");
        assert_eq!(out.len(), 1);
        assert!((out.entries[0].vector[0] - 9.9).abs() < 1e-6);
    }

    #[test]
    fn test_merge_deduplication_count() {
        let a = make_index(2, &[(1, 1.0), (2, 2.0)]);
        let b = make_index(2, &[(2, 2.5), (3, 3.0)]);
        let mut merger = IndexMerger::new();
        merger.add_index(a);
        merger.add_index(b);
        let out = merger.merge().expect("merge ok");
        assert_eq!(out.len(), 3);
        assert_eq!(ids_of(&out), vec![1, 2, 3]);
        assert_eq!(out.entries[1].vector, vec![2.5, 2.5]);
    }

    #[test]
    fn test_merge_keeps_first_seen_position() {
        let a = make_index(2, &[(1, 1.0), (2, 2.0)]);
        let b = make_index(2, &[(3, 3.0), (1, 7.0)]);
        let mut merger = IndexMerger::new();
        merger.add_index(a);
        merger.add_index(b);
        let out = merger.merge().expect("merge ok");
        // Id 1 stays in front even though its value comes from the second index.
        assert_eq!(ids_of(&out), vec![1, 2, 3]);
        assert_eq!(out.entries[0].vector, vec![7.0, 7.0]);
    }

    #[test]
    fn test_merge_does_not_consume_inputs() {
        let mut merger = IndexMerger::new();
        merger.add_index(make_index(2, &[(1, 1.0), (2, 2.0)]));
        merger.add_index(make_index(2, &[(2, 5.0)]));
        let first = merger.merge().expect("ok");
        let second = merger.merge().expect("ok");
        assert_eq!(first, second);
    }

    #[test]
    fn test_merge_dimension_mismatch_error() {
        let a = make_index(2, &[(1, 1.0)]);
        let b = make_index(3, &[(2, 2.0)]);
        let mut merger = IndexMerger::new();
        merger.add_index(a);
        merger.add_index(b);
        assert_eq!(
            merger.merge(),
            Err(MergeError::DimensionMismatch {
                expected: 2,
                got: 3
            })
        );
    }

    #[test]
    fn test_merge_preserves_metadata() {
        let mut meta = HashMap::new();
        meta.insert("key".to_string(), "val".to_string());
        let entry = VectorEntry::with_metadata(42, vec![1.0, 2.0], meta.clone());
        let mut idx = FlatIndex::new(2);
        idx.insert(entry).expect("ok");
        let mut merger = IndexMerger::new();
        merger.add_index(idx);
        let out = merger.merge().expect("ok");
        assert_eq!(out.entries[0].metadata.get("key"), Some(&"val".to_string()));
        assert_eq!(out.entries[0].metadata, meta);
    }

    #[test]
    fn test_merge_three_indices() {
        let a = make_index(2, &[(1, 1.0)]);
        let b = make_index(2, &[(2, 2.0)]);
        let c = make_index(2, &[(3, 3.0)]);
        let mut merger = IndexMerger::new();
        merger.add_index(a);
        merger.add_index(b);
        merger.add_index(c);
        let out = merger.merge().expect("ok");
        assert_eq!(out.len(), 3);
        assert_eq!(ids_of(&out), vec![1, 2, 3]);
    }

    #[test]
    fn test_merge_large_index() {
        let pairs: Vec<(u64, f32)> = (1u64..=100).map(|i| (i, i as f32)).collect();
        let idx = make_index(4, &pairs);
        let mut merger = IndexMerger::new();
        merger.add_index(idx);
        let out = merger.merge().expect("ok");
        assert_eq!(out.len(), 100);
        assert_eq!(ids_of(&out), (1u64..=100).collect::<Vec<_>>());
    }

    #[test]
    fn test_flat_index_dims_preserved_through_merge() {
        let idx = make_index(128, &[(1, 0.1), (2, 0.2)]);
        let mut merger = IndexMerger::new();
        merger.add_index(idx);
        let out = merger.merge().expect("ok");
        assert_eq!(out.dims, 128);
    }

    #[test]
    fn test_index_merger_default() {
        let mut merger = IndexMerger::default();
        assert_eq!(merger.merge(), Err(MergeError::EmptyInput));
    }

    // ----- merge_with_filter -----

    #[test]
    fn test_merge_with_filter_keeps_matching() {
        let idx = make_index(2, &[(1, 1.0), (2, 2.0), (3, 3.0)]);
        let mut merger = IndexMerger::new();
        merger.add_index(idx);
        let out = merger.merge_with_filter(|e| e.id % 2 == 1).expect("ok");
        assert_eq!(out.len(), 2);
        assert_eq!(ids_of(&out), vec![1, 3]);
    }

    #[test]
    fn test_merge_with_filter_all_excluded() {
        let idx = make_index(2, &[(1, 1.0), (2, 2.0)]);
        let mut merger = IndexMerger::new();
        merger.add_index(idx);
        let out = merger.merge_with_filter(|_| false).expect("ok");
        assert!(out.is_empty());
        assert_eq!(out.dims, 2);
    }

    #[test]
    fn test_merge_with_filter_all_included() {
        let idx = make_index(2, &[(1, 1.0), (2, 2.0)]);
        let mut merger = IndexMerger::new();
        merger.add_index(idx);
        let out = merger.merge_with_filter(|_| true).expect("ok");
        assert_eq!(out.len(), 2);
        assert_eq!(ids_of(&out), vec![1, 2]);
    }

    #[test]
    fn test_merge_with_filter_empty_input() {
        let mut merger = IndexMerger::new();
        assert_eq!(
            merger.merge_with_filter(|_| true),
            Err(MergeError::EmptyInput)
        );
    }

    #[test]
    fn test_merge_with_filter_dimension_mismatch() {
        let mut merger = IndexMerger::new();
        merger.add_index(make_index(2, &[(1, 1.0)]));
        merger.add_index(make_index(3, &[(2, 2.0)]));
        assert_eq!(
            merger.merge_with_filter(|_| true),
            Err(MergeError::DimensionMismatch {
                expected: 2,
                got: 3
            })
        );
    }

    #[test]
    fn test_merge_with_filter_sees_deduplicated_value() {
        let mut merger = IndexMerger::new();
        merger.add_index(make_index(2, &[(1, 1.0)]));
        merger.add_index(make_index(2, &[(1, 9.0)]));
        // The filter is applied after last-write-wins deduplication.
        let high = merger.merge_with_filter(|e| e.vector[0] > 5.0).expect("ok");
        assert_eq!(ids_of(&high), vec![1]);
        let low = merger.merge_with_filter(|e| e.vector[0] < 5.0).expect("ok");
        assert!(low.is_empty());
    }

    #[test]
    fn test_merge_filter_by_vector_value() {
        let pairs: Vec<(u64, f32)> = (1u64..=10).map(|i| (i, i as f32)).collect();
        let idx = make_index(2, &pairs);
        let mut merger = IndexMerger::new();
        merger.add_index(idx);
        let out = merger
            .merge_with_filter(|e| e.vector[0] >= 5.0)
            .expect("ok");
        assert_eq!(out.len(), 6);
        assert_eq!(ids_of(&out), (5u64..=10).collect::<Vec<_>>());
    }

    // ----- merge_with_stats -----

    #[test]
    fn test_merge_stats_no_dedup() {
        let a = make_index(2, &[(1, 1.0), (2, 2.0)]);
        let b = make_index(2, &[(3, 3.0)]);
        let mut merger = IndexMerger::new();
        merger.add_index(a);
        merger.add_index(b);
        let (out, stats) = merger.merge_with_stats().expect("ok");
        assert_eq!(stats.input_count, 2);
        assert_eq!(stats.total_before, 3);
        assert_eq!(stats.deduplicated, 0);
        assert_eq!(stats.total_after, 3);
        assert_eq!(out.len(), 3);
    }

    #[test]
    fn test_merge_stats_with_dedup() {
        let a = make_index(2, &[(1, 1.0), (2, 2.0)]);
        let b = make_index(2, &[(2, 9.0), (3, 3.0)]);
        let mut merger = IndexMerger::new();
        merger.add_index(a);
        merger.add_index(b);
        let (out, stats) = merger.merge_with_stats().expect("ok");
        assert_eq!(
            stats,
            MergeStats {
                input_count: 2,
                total_before: 4,
                deduplicated: 1,
                total_after: 3,
            }
        );
        assert_eq!(out.len(), stats.total_after);
    }

    #[test]
    fn test_merge_stats_empty_input() {
        let mut merger = IndexMerger::new();
        assert_eq!(merger.merge_with_stats(), Err(MergeError::EmptyInput));
    }

    #[test]
    fn test_stats_input_count_three() {
        let mut merger = IndexMerger::new();
        merger.add_index(make_index(2, &[(1, 1.0)]));
        merger.add_index(make_index(2, &[(2, 2.0)]));
        merger.add_index(make_index(2, &[(3, 3.0)]));
        let (_, stats) = merger.merge_with_stats().expect("ok");
        assert_eq!(stats.input_count, 3);
        assert_eq!(stats.total_before, 3);
        assert_eq!(stats.deduplicated, 0);
    }

    // ----- split -----

    #[test]
    fn test_split_even() {
        let idx = make_index(2, &[(1, 1.0), (2, 2.0), (3, 3.0), (4, 4.0)]);
        let parts = IndexMerger::split(&idx, 2);
        assert_eq!(parts.len(), 2);
        assert_eq!(ids_of(&parts[0]), vec![1, 2]);
        assert_eq!(ids_of(&parts[1]), vec![3, 4]);
    }

    #[test]
    fn test_split_uneven() {
        let idx = make_index(2, &[(1, 1.0), (2, 2.0), (3, 3.0)]);
        let parts = IndexMerger::split(&idx, 2);
        assert_eq!(parts.len(), 2);
        // The leftover entry goes to the first chunk.
        assert_eq!(ids_of(&parts[0]), vec![1, 2]);
        assert_eq!(ids_of(&parts[1]), vec![3]);
    }

    #[test]
    fn test_split_into_one() {
        let idx = make_index(2, &[(1, 1.0), (2, 2.0)]);
        let parts = IndexMerger::split(&idx, 1);
        assert_eq!(parts.len(), 1);
        assert_eq!(parts[0], idx);
    }

    #[test]
    fn test_split_zero_parts() {
        let idx = make_index(2, &[(1, 1.0)]);
        let parts = IndexMerger::split(&idx, 0);
        assert!(parts.is_empty());
    }

    #[test]
    fn test_split_empty_index() {
        let idx = FlatIndex::new(3);
        let parts = IndexMerger::split(&idx, 3);
        assert_eq!(parts.len(), 3);
        assert!(parts.iter().all(|p| p.is_empty() && p.dims == 3));
    }

    #[test]
    fn test_split_more_parts_than_entries() {
        let idx = make_index(2, &[(1, 1.0), (2, 2.0)]);
        let parts = IndexMerger::split(&idx, 5);
        assert_eq!(parts.len(), 5);
        assert_eq!(sizes(&parts), vec![1, 1, 0, 0, 0]);
    }

    #[test]
    fn test_split_preserves_dims() {
        let idx = make_index(7, &[(1, 1.0), (2, 2.0), (3, 3.0)]);
        let parts = IndexMerger::split(&idx, 2);
        for p in &parts {
            assert_eq!(p.dims, 7);
        }
    }

    #[test]
    fn test_split_total_count_preserved() {
        let ids: Vec<(u64, f32)> = (1u64..=10).map(|i| (i, i as f32)).collect();
        let idx = make_index(4, &ids);
        let parts = IndexMerger::split(&idx, 3);
        assert_eq!(sizes(&parts), vec![4, 3, 3]);
        assert_eq!(sizes(&parts).iter().sum::<usize>(), 10);
    }

    #[test]
    fn test_split_four_parts() {
        let pairs: Vec<(u64, f32)> = (1u64..=8).map(|i| (i, i as f32)).collect();
        let idx = make_index(2, &pairs);
        let parts = IndexMerger::split(&idx, 4);
        assert_eq!(parts.len(), 4);
        assert_eq!(sizes(&parts), vec![2; 4]);
    }

    #[test]
    fn test_split_single_entry_many_parts() {
        let idx = make_index(2, &[(42, 1.0)]);
        let parts = IndexMerger::split(&idx, 4);
        assert_eq!(parts.len(), 4);
        assert_eq!(sizes(&parts), vec![1, 0, 0, 0]);
        assert_eq!(ids_of(&parts[0]), vec![42]);
    }

    #[test]
    fn test_split_then_merge_roundtrip() {
        let pairs: Vec<(u64, f32)> = (1u64..=10).map(|i| (i, i as f32)).collect();
        let idx = make_index(4, &pairs);
        let mut merger = IndexMerger::new();
        for part in IndexMerger::split(&idx, 3) {
            merger.add_index(part);
        }
        assert_eq!(merger.merge().expect("ok"), idx);
    }

    // ----- MergeError -----

    #[test]
    fn test_error_display_empty_input() {
        let e = MergeError::EmptyInput;
        assert_eq!(e.to_string(), "No input indices provided");
    }

    #[test]
    fn test_error_display_dimension_mismatch() {
        let e = MergeError::DimensionMismatch {
            expected: 4,
            got: 3,
        };
        assert_eq!(e.to_string(), "Dimension mismatch: expected 4, got 3");
    }

    #[test]
    fn test_error_display_invalid_parts() {
        let e = MergeError::InvalidParts;
        assert_eq!(e.to_string(), "Number of parts must be greater than zero");
    }

    #[test]
    fn test_error_is_std_error() {
        let e: Box<dyn std::error::Error> = Box::new(MergeError::EmptyInput);
        assert!(e.to_string().contains("No input"));
    }

    // ----- VectorEntry -----

    #[test]
    fn test_vector_entry_new() {
        let e = VectorEntry::new(7, vec![1.0, 2.0, 3.0]);
        assert_eq!(e.id, 7);
        assert_eq!(e.vector, vec![1.0, 2.0, 3.0]);
        assert!(e.metadata.is_empty());
    }

    #[test]
    fn test_vector_entry_with_metadata() {
        let mut meta = HashMap::new();
        meta.insert("source".into(), "test".into());
        let e = VectorEntry::with_metadata(1, vec![0.0], meta);
        assert_eq!(e.metadata.get("source"), Some(&"test".to_string()));
    }
}
677}