Skip to main content

nodedb_vector/delta/
compaction.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! SPFresh LIRE-style topology-aware local patching.
4//!
5//! Rather than rebuilding the entire HNSW on every update, `LirePatcher`
6//! stitches fresh delta vectors into the main graph one node at a time using
7//! the graph's own `insert` routine (which runs heuristic neighbor selection
8//! and bidirectional edge maintenance internally).  Tombstones are forwarded
9//! to the HNSW as soft-deletes so search skips them immediately; physical
10//! removal is deferred to a later background compaction sweep via
11//! `HnswIndex::compact`.
12//!
13//! ## Partition-quality drift estimate
14//!
//! After inserting each fresh node we inspect the layer-0 neighbors the graph
//! assigned to it and compute the fraction of those neighbors that were
//! themselves inserted earlier in the same batch.  A node with no layer-0
//! neighbors counts as full overlap (`1.0`).  A high overlap means the delta
//! nodes clustered into an already-dense region — minimal drift.  A low
//! overlap means the new node landed in a sparse region that may require
//! broader re-wiring.
21//!
//! This is an O(patch_size × M) approximation of LIRE's full local-rebuild
//! signal (SOSP 2023 §4.2).  When the average overlap fraction across all
//! patched nodes falls below `drift_threshold`, we increment
//! `PatchStats::drift_subgraphs` (at most once per `patch` call) so the
//! caller can schedule a deeper re-pruning pass at lower priority.
27
28use crate::delta::index::DeltaIndex;
29use crate::error::VectorError;
30use crate::hnsw::HnswIndex;
31
32/// SPFresh LIRE-style topology-aware local patcher.
33///
34/// Holds mutable references to both the main HNSW and the delta buffer so
35/// both can be updated atomically within a single flush.
36pub struct LirePatcher<'a> {
37    /// Main HNSW graph that fresh vectors will be patched into.
38    pub main: &'a mut HnswIndex,
39    /// Delta buffer whose fresh vectors (and tombstones) will be drained.
40    pub delta: &'a mut DeltaIndex,
41    /// Drift threshold in `[0.0, 1.0]`.  When the average neighbor-overlap
42    /// fraction for a batch falls below this value the subgraph is flagged
43    /// for deeper re-pruning.  Default: `0.3`.
44    pub drift_threshold: f32,
45}
46
/// Counters reported by a single `LirePatcher::patch` call.
///
/// Every counter starts at zero via `Default`.  The type implements
/// `PartialEq`/`Eq` so callers and tests can compare a whole result with
/// `==` instead of checking each field separately.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct PatchStats {
    /// Number of fresh vectors inserted into the main HNSW.
    pub patched: usize,
    /// Number of tombstone ids for which `HnswIndex::delete` returned `true`
    /// during this flush.
    pub tombstoned_marked: usize,
    /// Number of times the flush was flagged for deeper re-pruning due to
    /// drift.  `LirePatcher::patch` raises this by at most one per call.
    pub drift_subgraphs: usize,
}
57
58impl<'a> LirePatcher<'a> {
59    /// Create a patcher with the default drift threshold (`0.3`).
60    pub fn new(main: &'a mut HnswIndex, delta: &'a mut DeltaIndex) -> Self {
61        Self {
62            main,
63            delta,
64            drift_threshold: 0.3,
65        }
66    }
67
68    /// Flush the delta buffer into the main HNSW.
69    ///
70    /// ## Steps
71    ///
72    /// 1. Drain tombstones → call `HnswIndex::delete` on each.
73    /// 2. Drain fresh vectors → call `HnswIndex::insert` on each.
74    /// 3. After each insert, estimate local topology drift for the newly
75    ///    assigned node and accumulate the overlap fraction.
76    /// 4. If average overlap fraction < `drift_threshold`, increment
77    ///    `drift_subgraphs`.
78    ///
79    /// `k_neighbors` and `ef_construction` are accepted for API completeness
80    /// and forward-compatibility with future Vamana-style patchers; the
81    /// current HNSW implementation derives its own neighbor count from
82    /// `HnswParams` stored on the index, so these values are informational.
83    pub fn patch(
84        &mut self,
85        _k_neighbors: usize,
86        _ef_construction: usize,
87    ) -> Result<PatchStats, VectorError> {
88        let mut stats = PatchStats::default();
89
90        // --- Step 1: Forward tombstones to the main HNSW ---
91        let tombstone_ids = self.delta.drain_tombstones();
92        for id in tombstone_ids {
93            if self.main.delete(id) {
94                stats.tombstoned_marked += 1;
95            }
96        }
97
98        // --- Step 2 + 3: Insert fresh vectors and estimate drift ---
99        let fresh = self.delta.drain_fresh();
100
101        // Collect the node IDs that will be assigned to freshly inserted nodes
102        // so we can measure neighborhood overlap after each insert.
103        // The HNSW appends nodes sequentially, so the new id = len() before insert.
104        let mut overlap_fractions: Vec<f32> = Vec::with_capacity(fresh.len());
105        // Track the set of recently-patched node ids for overlap estimation.
106        let mut patched_ids: std::collections::HashSet<u32> =
107            std::collections::HashSet::with_capacity(fresh.len());
108
109        for (user_id, vector) in fresh {
110            // Skip tombstoned fresh inserts — they were deleted before we
111            // could patch them.
112            if self.delta.is_tombstoned(user_id) {
113                continue;
114            }
115
116            // The HNSW uses its own internal monotonic IDs (insertion order).
117            // We record what the next id will be before the insert.
118            let new_internal_id = self.main.len() as u32;
119
120            self.main.insert(vector)?;
121            stats.patched += 1;
122
123            // --- Drift estimation (LIRE approximation) ---
124            // Inspect neighbors assigned to the new node at layer 0.
125            let neighbors_l0 = self.main.hnsw_neighbors_layer0(new_internal_id);
126
127            let overlap_fraction = if neighbors_l0.is_empty() {
128                // First node or isolated — perfect connectivity by definition.
129                1.0f32
130            } else {
131                // Count how many neighbors are themselves in the current
132                // patched-ids set (i.e., recently inserted into this batch).
133                let overlap = neighbors_l0
134                    .iter()
135                    .filter(|&&nid| patched_ids.contains(&nid))
136                    .count();
137                overlap as f32 / neighbors_l0.len() as f32
138            };
139
140            overlap_fractions.push(overlap_fraction);
141            patched_ids.insert(new_internal_id);
142        }
143
144        // --- Step 4: Flag drift subgraphs ---
145        if !overlap_fractions.is_empty() {
146            let avg_overlap =
147                overlap_fractions.iter().sum::<f32>() / overlap_fractions.len() as f32;
148            if avg_overlap < self.drift_threshold {
149                stats.drift_subgraphs += 1;
150            }
151        }
152
153        Ok(stats)
154    }
155}
156
157// ---------------------------------------------------------------------------
158// Additive accessor on HnswIndex required by LirePatcher.
159//
160// `mark_deleted` does not exist on `HnswIndex`; the equivalent is `delete`.
161// We add a thin helper that exposes layer-0 neighbors so the drift estimator
162// can read them without re-exposing internal fields.
163// ---------------------------------------------------------------------------
164impl HnswIndex {
165    /// Return the layer-0 neighbor list of `node_id`, or an empty slice if
166    /// the node does not exist or has no layer-0 neighbors.
167    ///
168    /// Used by `LirePatcher` for local topology-drift estimation.
169    pub fn hnsw_neighbors_layer0(&self, node_id: u32) -> Vec<u32> {
170        self.neighbors_at(node_id, 0).to_vec()
171    }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hnsw::HnswIndex;
    use nodedb_types::hnsw::HnswParams;

    /// Graph parameters shared by the tests (`m = 4`, `m0 = 8`,
    /// `ef_construction = 20`); every other field keeps its default.
    fn small_params() -> HnswParams {
        HnswParams {
            m: 4,
            m0: 8,
            ef_construction: 20,
            ..HnswParams::default()
        }
    }

    /// Flushing five fresh vectors grows the main index by five and leaves
    /// the delta's fresh buffer empty.
    #[test]
    fn patch_grows_hnsw_and_drains_delta() {
        let mut index = HnswIndex::with_seed(3, small_params(), 1);
        // Seed the main graph with ten vectors so the fresh nodes have
        // neighbors to link to.
        for x in 0u32..10 {
            index
                .insert(vec![x as f32, 0.0, 0.0])
                .expect("pre-populate insert failed");
        }
        assert_eq!(index.len(), 10);

        let mut buffer = DeltaIndex::new(3, 32);
        for id in 10u32..15 {
            buffer.insert(id, vec![id as f32, 1.0, 0.0]);
        }
        assert_eq!(buffer.fresh_len(), 5);

        let stats = LirePatcher::new(&mut index, &mut buffer)
            .patch(8, 20)
            .expect("patch failed");

        assert_eq!(stats.patched, 5);
        assert_eq!(buffer.fresh_len(), 0);
        assert_eq!(index.len(), 15);
    }

    /// A tombstone queued in the delta is applied to the main index as a
    /// delete and counted in the stats.
    #[test]
    fn tombstone_forwarded_to_hnsw() {
        let mut index = HnswIndex::with_seed(3, small_params(), 2);
        for x in 0u32..5 {
            index
                .insert(vec![x as f32, 0.0, 0.0])
                .expect("insert failed");
        }
        assert!(!index.is_deleted(2));

        let mut buffer = DeltaIndex::new(3, 16);
        buffer.tombstone(2);

        let stats = LirePatcher::new(&mut index, &mut buffer)
            .patch(4, 20)
            .expect("patch failed");

        assert_eq!(stats.tombstoned_marked, 1);
        assert!(index.is_deleted(2));
    }

    /// Patching with an empty delta reports zero work and leaves the main
    /// index length unchanged.
    #[test]
    fn patch_empty_delta_is_noop() {
        let mut index = HnswIndex::with_seed(3, small_params(), 3);
        for x in 0u32..3 {
            index
                .insert(vec![x as f32, 0.0, 0.0])
                .expect("insert failed");
        }
        let len_before = index.len();

        let mut buffer = DeltaIndex::new(3, 16);
        let stats = LirePatcher::new(&mut index, &mut buffer)
            .patch(4, 20)
            .expect("patch failed");

        assert_eq!(stats.patched, 0);
        assert_eq!(stats.tombstoned_marked, 0);
        assert_eq!(index.len(), len_before);
    }
}