use crate::delta::index::DeltaIndex;
use crate::error::VectorError;
use crate::hnsw::HnswIndex;
/// Applies the pending contents of a [`DeltaIndex`] to an [`HnswIndex`].
///
/// Both indexes are borrowed mutably for the patcher's lifetime `'a`.
pub struct LirePatcher<'a> {
    /// Main graph that receives forwarded deletes and newly inserted vectors.
    pub main: &'a mut HnswIndex,
    /// Delta buffer whose tombstones and fresh vectors are drained by `patch`.
    pub delta: &'a mut DeltaIndex,
    /// Mean layer-0 overlap below which a patch pass is counted as drift.
    pub drift_threshold: f32,
}
/// Counters describing the outcome of one `LirePatcher::patch` call.
///
/// Compares by value so callers and tests can assert on a whole result at once.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct PatchStats {
    /// Number of fresh delta vectors inserted into the main index.
    pub patched: usize,
    /// Number of tombstoned ids for which `HnswIndex::delete` returned `true`.
    pub tombstoned_marked: usize,
    /// Incremented (at most once per patch) when the mean layer-0 overlap among
    /// the newly inserted nodes fell below the patcher's drift threshold.
    pub drift_subgraphs: usize,
}
impl<'a> LirePatcher<'a> {
    /// Creates a patcher over `main` and `delta` using the default drift
    /// threshold of `0.3`.
    pub fn new(main: &'a mut HnswIndex, delta: &'a mut DeltaIndex) -> Self {
        Self {
            main,
            delta,
            drift_threshold: 0.3,
        }
    }

    /// Folds everything pending in `delta` into `main` and reports what happened.
    ///
    /// The pass runs in three steps:
    ///
    /// 1. Every tombstoned id is drained from `delta` and forwarded to
    ///    `HnswIndex::delete`; `tombstoned_marked` counts the calls that
    ///    returned `true`.
    /// 2. Every fresh `(user_id, vector)` pair is drained from `delta`. Pairs
    ///    whose id was tombstoned are dropped; the rest are inserted into
    ///    `main` and counted in `patched`.
    /// 3. For each inserted node, the fraction of its layer-0 neighbours that
    ///    were also inserted earlier in this call is computed (a node with no
    ///    neighbours counts as `1.0`). If the mean fraction is below
    ///    `drift_threshold`, `drift_subgraphs` is incremented, so it grows by at
    ///    most one per call.
    ///
    /// `_k_neighbors` and `_ef_construction` are currently ignored.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `HnswIndex::insert`. Fresh entries
    /// have already been drained from `delta` at that point, so the failing
    /// entry and any not yet inserted are not put back into the delta.
    pub fn patch(
        &mut self,
        _k_neighbors: usize,
        _ef_construction: usize,
    ) -> Result<PatchStats, VectorError> {
        let mut stats = PatchStats::default();

        // Remember every drained tombstone: once drained, `delta` may no longer
        // report these ids through `is_tombstoned`, yet fresh vectors carrying
        // them must still be dropped in the loop below.
        let mut tombstoned = std::collections::HashSet::new();
        for id in self.delta.drain_tombstones() {
            tombstoned.insert(id);
            // NOTE(review): `id` comes from the delta's id space but is passed
            // straight to `HnswIndex::delete`; this assumes delta ids coincide
            // with the main index's node ids — confirm.
            if self.main.delete(id) {
                stats.tombstoned_marked += 1;
            }
        }

        let fresh = self.delta.drain_fresh();
        let mut patched_ids: std::collections::HashSet<u32> =
            std::collections::HashSet::with_capacity(fresh.len());
        // Running sum/count of per-node overlap fractions for the drift check.
        let mut overlap_sum = 0.0f32;
        let mut overlap_count = 0usize;

        for (user_id, vector) in fresh {
            if tombstoned.contains(&user_id) || self.delta.is_tombstoned(user_id) {
                continue;
            }
            // NOTE(review): assumes `insert` assigns the next node id, equal to
            // `len()` before the call — confirm, particularly if `len()` could
            // exclude deleted nodes.
            let new_internal_id = self.main.len() as u32;
            self.main.insert(vector)?;
            stats.patched += 1;

            let neighbors_l0 = self.main.hnsw_neighbors_layer0(new_internal_id);
            let overlap_fraction = if neighbors_l0.is_empty() {
                1.0f32
            } else {
                let overlap = neighbors_l0
                    .iter()
                    .filter(|&&nid| patched_ids.contains(&nid))
                    .count();
                overlap as f32 / neighbors_l0.len() as f32
            };
            overlap_sum += overlap_fraction;
            overlap_count += 1;
            patched_ids.insert(new_internal_id);
        }

        if overlap_count > 0 {
            let avg_overlap = overlap_sum / overlap_count as f32;
            if avg_overlap < self.drift_threshold {
                stats.drift_subgraphs += 1;
            }
        }
        Ok(stats)
    }
}
impl HnswIndex {
    /// Returns an owned copy of `node_id`'s neighbour list on layer 0,
    /// as reported by `neighbors_at(node_id, 0)`.
    pub fn hnsw_neighbors_layer0(&self, node_id: u32) -> Vec<u32> {
        let layer0 = self.neighbors_at(node_id, 0);
        layer0.iter().copied().collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hnsw::HnswIndex;
    use nodedb_types::hnsw::HnswParams;

    /// Small graph parameters so these fixtures stay cheap to build.
    fn small_params() -> HnswParams {
        let defaults = HnswParams::default();
        HnswParams {
            m: 4,
            m0: 8,
            ef_construction: 20,
            ..defaults
        }
    }

    /// Patching inserts every fresh delta vector into `main` and empties the delta.
    #[test]
    fn patch_grows_hnsw_and_drains_delta() {
        let mut main = HnswIndex::with_seed(3, small_params(), 1);
        for x in 0u32..10 {
            main.insert(vec![x as f32, 0.0, 0.0])
                .expect("pre-populate insert failed");
        }
        assert_eq!(main.len(), 10);

        let mut delta = DeltaIndex::new(3, 32);
        for id in 10u32..15 {
            delta.insert(id, vec![id as f32, 1.0, 0.0]);
        }
        assert_eq!(delta.fresh_len(), 5);

        let stats = LirePatcher::new(&mut main, &mut delta)
            .patch(8, 20)
            .expect("patch failed");
        assert_eq!(stats.patched, 5);
        assert_eq!(delta.fresh_len(), 0);
        assert_eq!(main.len(), 15);
    }

    /// A tombstone recorded in the delta is applied to the main graph.
    #[test]
    fn tombstone_forwarded_to_hnsw() {
        let mut main = HnswIndex::with_seed(3, small_params(), 2);
        for x in 0u32..5 {
            main.insert(vec![x as f32, 0.0, 0.0]).expect("insert failed");
        }
        assert!(!main.is_deleted(2));

        let mut delta = DeltaIndex::new(3, 16);
        delta.tombstone(2);

        let stats = LirePatcher::new(&mut main, &mut delta)
            .patch(4, 20)
            .expect("patch failed");
        assert_eq!(stats.tombstoned_marked, 1);
        assert!(main.is_deleted(2));
    }

    /// An empty delta leaves the main graph untouched and reports zero work.
    #[test]
    fn patch_empty_delta_is_noop() {
        let mut main = HnswIndex::with_seed(3, small_params(), 3);
        for x in 0u32..3 {
            main.insert(vec![x as f32, 0.0, 0.0])
                .expect("insert failed");
        }
        let initial_len = main.len();

        let mut delta = DeltaIndex::new(3, 16);
        let stats = LirePatcher::new(&mut main, &mut delta)
            .patch(4, 20)
            .expect("patch failed");
        assert_eq!(stats.patched, 0);
        assert_eq!(stats.tombstoned_marked, 0);
        assert_eq!(main.len(), initial_len);
    }
}