1use crate::distance::DistanceMetric;
8use crate::hnsw::HnswIndex;
9
10const DEFAULT_CAPACITY: usize = 10_000;
12
/// Staging area that accumulates `(external_id, vector)` pairs until they are
/// flushed into an index.
///
/// Staged vectors remain searchable through `search_merged`, which scores them
/// by brute force alongside the results of the live index.
#[derive(Debug, Clone)]
pub struct StagingBuffer {
    /// Staged `(external_id, vector)` pairs, kept in insertion order.
    pending: Vec<(usize, Vec<f32>)>,
    /// Pending-count threshold at which `add` starts returning `true`.
    capacity: usize,
}
18
19impl Default for StagingBuffer {
20 fn default() -> Self {
21 Self::new()
22 }
23}
24
25impl StagingBuffer {
26 pub fn new() -> Self {
27 Self {
28 pending: Vec::new(),
29 capacity: DEFAULT_CAPACITY,
30 }
31 }
32
33 pub fn add(&mut self, external_id: usize, vector: Vec<f32>) -> bool {
35 self.pending.push((external_id, vector));
36 self.pending.len() >= self.capacity
37 }
38
39 pub fn pending_count(&self) -> usize {
41 self.pending.len()
42 }
43
44 pub fn search_merged(
46 &self,
47 live: &HnswIndex,
48 query: &[f32],
49 k: usize,
50 ef: usize,
51 metric: DistanceMetric,
52 ) -> Vec<(usize, f32)> {
53 let mut results = if !live.is_empty() {
55 live.search(query, k, ef)
56 } else {
57 Vec::new()
58 };
59
60 for (ext_id, vec) in &self.pending {
62 let dist = metric.distance(query, vec);
63 results.push((*ext_id, dist));
64 }
65
66 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
68 results.truncate(k);
69 results
70 }
71
72 pub fn flush(&mut self, index: &mut HnswIndex) {
74 for (_ext_id, vec) in self.pending.drain(..) {
75 index.insert(&vec);
76 }
77 }
78}
79
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hnsw::HnswConfig;

    /// `add` reports `true` only once the pending count reaches the capacity,
    /// and the buffer still holds everything that was added.
    #[test]
    fn auto_flush_at_capacity() {
        let mut staging = StagingBuffer {
            pending: Vec::new(),
            capacity: 3,
        };

        let below_first = staging.add(0, vec![1.0]);
        assert!(!below_first);

        let below_second = staging.add(1, vec![2.0]);
        assert!(!below_second);

        let at_capacity = staging.add(2, vec![3.0]);
        assert!(at_capacity);

        assert_eq!(staging.pending_count(), 3);
    }

    /// With an empty live index, staged vectors alone answer the query and the
    /// closest one comes first.
    #[test]
    fn search_merged_includes_pending() {
        let empty_index = HnswIndex::new(2, HnswConfig::default());
        let mut staging = StagingBuffer::new();

        staging.add(0, vec![1.0, 0.0]);
        staging.add(1, vec![0.0, 1.0]);

        let query = [0.9, 0.0];
        let hits = staging.search_merged(&empty_index, &query, 2, 50, DistanceMetric::L2);

        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].0, 0);
    }

    /// Flushing empties the buffer and makes the vectors visible in the index.
    #[test]
    fn flush_moves_to_live() {
        let mut target = HnswIndex::new(2, HnswConfig::default());
        let mut staging = StagingBuffer::new();

        staging.add(0, vec![1.0, 0.0]);
        staging.add(1, vec![0.0, 1.0]);
        assert_eq!(staging.pending_count(), 2);

        staging.flush(&mut target);
        assert_eq!(staging.pending_count(), 0);
        assert_eq!(target.len(), 2);

        let hits = target.search(&[1.0, 0.0], 2, 50);
        assert_eq!(hits.len(), 2);
    }
}