use crate::distance::FlatVectors;
use crate::error::Result;
use crate::graph::VamanaGraph;
/// Strategy that decides on which metric-update steps the graph is rebuilt.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RebuildPolicy {
    /// Rebuild on every update step.
    AlwaysRebuild,
    /// Never rebuild as part of the update cadence.
    ReweightOnly,
    /// Rebuild whenever the step number is a multiple of `k`.
    ///
    /// A cadence of `k == 0` never rebuilds.
    Periodic {
        /// Cadence length, measured in update steps.
        k: usize,
    },
}

impl RebuildPolicy {
    /// Reports whether this policy asks for a rebuild at update number `step`.
    fn rebuilds_at(self, step: usize) -> bool {
        match self {
            Self::AlwaysRebuild => true,
            // A zero cadence is matched here so the modulo below never divides by zero.
            Self::ReweightOnly | Self::Periodic { k: 0 } => false,
            Self::Periodic { k } => step % k == 0,
        }
    }
}
/// A graph index together with the parameters needed to rebuild it and
/// counters tracking how often updates and rebuilds have happened.
pub struct DriftingIndex {
/// Current graph; replaced wholesale whenever a rebuild happens.
graph: VamanaGraph,
/// Decides which update steps trigger a rebuild.
policy: RebuildPolicy,
/// Point count captured at build time; rebuilds debug-assert it is unchanged.
n: usize,
/// Forwarded unchanged to `VamanaGraph::new` on every (re)build.
max_degree: usize,
/// Forwarded unchanged to `VamanaGraph::new` on every (re)build.
build_beam: usize,
/// Forwarded unchanged to `VamanaGraph::new` on every (re)build.
alpha: f32,
/// Number of `on_metric_update` calls seen so far.
step: usize,
/// Number of rebuilds performed after the initial build (policy-driven or forced).
rebuilds: usize,
}
impl DriftingIndex {
    /// Builds the initial graph over `vectors` and remembers the parameters
    /// required to rebuild it later.
    ///
    /// The point count is captured here and later rebuilds assume it has not
    /// changed. The initial build does not count towards [`Self::rebuilds`].
    ///
    /// # Errors
    /// Propagates any error reported while building the graph.
    pub fn build(
        vectors: &FlatVectors,
        policy: RebuildPolicy,
        max_degree: usize,
        build_beam: usize,
        alpha: f32,
    ) -> Result<Self> {
        let point_count = vectors.len();
        Ok(Self {
            graph: build_graph(vectors, point_count, max_degree, build_beam, alpha)?,
            policy,
            n: point_count,
            max_degree,
            build_beam,
            alpha,
            step: 0,
            rebuilds: 0,
        })
    }

    /// Registers one metric update and rebuilds the graph if the policy asks
    /// for it at this step.
    ///
    /// The step counter advances on every call, whether or not a rebuild
    /// follows. Returns `Ok(true)` when the graph was rebuilt and `Ok(false)`
    /// otherwise.
    ///
    /// The point-count check is a `debug_assert_eq!`, so it is only active in
    /// builds with debug assertions enabled.
    ///
    /// # Errors
    /// Propagates any error reported while rebuilding the graph; in that case
    /// the rebuild counter is left untouched.
    pub fn on_metric_update(&mut self, vectors: &FlatVectors) -> Result<bool> {
        self.step += 1;
        if self.policy.rebuilds_at(self.step) {
            // Checked here first so a mismatch is reported with this method's message.
            debug_assert_eq!(
                vectors.len(),
                self.n,
                "reuse model assumes fixed membership; point count changed"
            );
            self.force_rebuild(vectors)?;
            Ok(true)
        } else {
            Ok(false)
        }
    }

    /// Runs a greedy search on the current graph and returns its result as-is.
    ///
    /// The tuple is whatever `VamanaGraph::greedy_search` produces: candidate
    /// ids plus a count (the tests in this file treat it as a visited count).
    pub fn search(
        &self,
        vectors: &FlatVectors,
        query: &[f32],
        beam_width: usize,
    ) -> (Vec<u32>, usize) {
        let Self { graph, .. } = self;
        graph.greedy_search(vectors, query, beam_width)
    }

    /// Rebuilds the graph immediately, regardless of the policy.
    ///
    /// Increments the rebuild counter but leaves the step counter alone.
    /// The point-count check is only active when debug assertions are enabled.
    ///
    /// # Errors
    /// Propagates any error reported while rebuilding the graph; in that case
    /// neither the graph nor the rebuild counter changes.
    pub fn force_rebuild(&mut self, vectors: &FlatVectors) -> Result<()> {
        debug_assert_eq!(vectors.len(), self.n, "force_rebuild: point count changed");
        let rebuilt = build_graph(
            vectors,
            self.n,
            self.max_degree,
            self.build_beam,
            self.alpha,
        )?;
        self.graph = rebuilt;
        self.rebuilds += 1;
        Ok(())
    }

    /// The rebuild policy chosen at construction.
    pub fn policy(&self) -> RebuildPolicy {
        self.policy
    }

    /// Number of metric updates registered so far.
    pub fn step(&self) -> usize {
        self.step
    }

    /// Number of rebuilds performed since the initial build.
    pub fn rebuilds(&self) -> usize {
        self.rebuilds
    }

    /// Borrows the current graph.
    pub fn graph(&self) -> &VamanaGraph {
        &self.graph
    }
}
/// Creates a new `VamanaGraph` with the given parameters and builds it over
/// `vectors`.
///
/// # Errors
/// Propagates any error returned by `VamanaGraph::build`.
fn build_graph(
    vectors: &FlatVectors,
    n: usize,
    max_degree: usize,
    build_beam: usize,
    alpha: f32,
) -> Result<VamanaGraph> {
    let mut fresh = VamanaGraph::new(n, max_degree, build_beam, alpha);
    fresh.build(vectors)?;
    Ok(fresh)
}
/// Exhaustively ranks every point other than `q` by its `l2_squared` distance
/// to point `q` and returns the ids of the (at most) `k` closest, nearest first.
///
/// Points at equal distance keep ascending id order because the sort is stable.
fn brute_force_topk(vectors: &FlatVectors, q: usize, k: usize) -> Vec<u32> {
    let query = vectors.get(q);
    let total = vectors.len();

    let mut ranked: Vec<(f32, u32)> = Vec::with_capacity(total);
    for idx in (0..total).filter(|&idx| idx != q) {
        let dist = crate::distance::l2_squared(vectors.get(idx), query);
        ranked.push((dist, idx as u32));
    }

    ranked.sort_by(|lhs, rhs| lhs.0.total_cmp(&rhs.0));
    ranked.truncate(k);
    ranked.into_iter().map(|(_, idx)| idx).collect()
}
/// Wraps a [`DriftingIndex`] and rebuilds it when recall measured on a fixed
/// set of probe queries drops below a threshold.
pub struct RecallTrigger {
/// Wrapped index; constructed with `RebuildPolicy::ReweightOnly`.
index: DriftingIndex,
/// Ids of stored points that are used as probe queries.
probe_queries: Vec<u32>,
/// Number of neighbours compared per probe.
k: usize,
/// Recall threshold; a rebuild is forced when probe recall is strictly below it.
floor: f32,
/// Beam width passed to the index search while probing.
search_beam: usize,
}
impl RecallTrigger {
    /// Builds the underlying index and stores the probe configuration.
    ///
    /// The inner [`DriftingIndex`] uses [`RebuildPolicy::ReweightOnly`], so it
    /// never rebuilds by itself; rebuilds happen only when
    /// [`RecallTrigger::on_metric_update`] measures probe recall below `floor`.
    ///
    /// `probe_queries` are ids of points stored in `vectors`; each one is later
    /// used as an index into the vector set.
    ///
    /// # Errors
    /// Propagates any error reported while building the initial graph.
    // The flat argument list is part of the existing public signature; a
    // builder would be an interface change, so the lint is silenced instead.
    #[allow(clippy::too_many_arguments)]
    pub fn build(
        vectors: &FlatVectors,
        probe_queries: Vec<u32>,
        k: usize,
        floor: f32,
        search_beam: usize,
        max_degree: usize,
        build_beam: usize,
        alpha: f32,
    ) -> Result<Self> {
        let index = DriftingIndex::build(
            vectors,
            RebuildPolicy::ReweightOnly,
            max_degree,
            build_beam,
            alpha,
        )?;
        Ok(Self {
            index,
            probe_queries,
            k,
            floor,
            search_beam,
        })
    }

    /// Mean recall of the current graph over the probe queries, measured
    /// against brute-force ground truth computed on `vectors`.
    ///
    /// For each probe, the search candidates are re-ranked by distance, the
    /// probe itself is dropped, and the best `k` are compared with the exact
    /// top-`k`. The per-probe score is the fraction of the exact neighbours
    /// that were retrieved.
    ///
    /// Returns `1.0` when there are no probes. A probe with no ground-truth
    /// neighbours at all (`k == 0`, or no other points in `vectors`) also
    /// scores `1.0`, since there is nothing it could have missed.
    ///
    /// Every probe costs a full scan of `vectors` for the ground truth.
    pub fn probe_recall(&self, vectors: &FlatVectors) -> f32 {
        if self.probe_queries.is_empty() {
            return 1.0;
        }
        let mut sum = 0.0f32;
        for &q in &self.probe_queries {
            let qi = q as usize;
            let truth = brute_force_topk(vectors, qi, self.k);
            if truth.is_empty() {
                // Vacuous probe: same convention as the empty-probe-list case.
                sum += 1.0;
                continue;
            }
            let qv = vectors.get(qi);
            let (cands, _) = self.index.search(vectors, qv, self.search_beam);
            let mut scored: Vec<(f32, u32)> = cands
                .iter()
                .map(|&c| (crate::distance::l2_squared(vectors.get(c as usize), qv), c))
                .collect();
            scored.sort_by(|a, b| a.0.total_cmp(&b.0));
            let hits = scored
                .into_iter()
                .filter(|&(_, c)| c as usize != qi)
                .take(self.k)
                .filter(|(_, c)| truth.contains(c))
                .count();
            // Normalise by the number of neighbours that actually exist. When
            // fewer than `k` other points are available, dividing by `k` would
            // cap recall below 1.0 and could keep it under `floor` forever.
            sum += hits as f32 / truth.len() as f32;
        }
        sum / self.probe_queries.len() as f32
    }

    /// Measures probe recall on `vectors` and forces a rebuild when it is
    /// strictly below the configured floor.
    ///
    /// Returns `Ok(true)` when a rebuild happened and `Ok(false)` otherwise.
    /// The inner index's step counter is not advanced by this call.
    ///
    /// # Errors
    /// Propagates any error reported while rebuilding the graph.
    pub fn on_metric_update(&mut self, vectors: &FlatVectors) -> Result<bool> {
        if self.probe_recall(vectors) < self.floor {
            self.index.force_rebuild(vectors)?;
            Ok(true)
        } else {
            Ok(false)
        }
    }

    /// Searches the wrapped index; see [`DriftingIndex::search`].
    pub fn search(
        &self,
        vectors: &FlatVectors,
        query: &[f32],
        beam_width: usize,
    ) -> (Vec<u32>, usize) {
        self.index.search(vectors, query, beam_width)
    }

    /// Number of rebuilds the trigger has forced since the initial build.
    pub fn rebuilds(&self) -> usize {
        self.index.rebuilds()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic `n` x `dim` fixture with every component in `[0, 1)`.
    fn fixture(n: usize, dim: usize) -> FlatVectors {
        let mut out = FlatVectors::with_capacity(dim, n);
        for row in 0..n {
            let mut point: Vec<f32> = Vec::with_capacity(dim);
            for col in 0..dim {
                point.push(((row * 31 + col * 7) % 97) as f32 / 97.0);
            }
            out.push(&point);
        }
        out
    }

    /// Second deterministic fixture, laid out differently from [`fixture`],
    /// used to simulate drift of the underlying vectors.
    fn fixture_b(n: usize, dim: usize) -> FlatVectors {
        let mut out = FlatVectors::with_capacity(dim, n);
        for row in 0..n {
            let mut point: Vec<f32> = Vec::with_capacity(dim);
            for col in 0..dim {
                point.push((((n - row) * 53 + col * 17) % 89) as f32 / 89.0);
            }
            out.push(&point);
        }
        out
    }

    #[test]
    fn reweight_only_never_rebuilds() {
        let data = fixture(64, 8);
        let mut index =
            DriftingIndex::build(&data, RebuildPolicy::ReweightOnly, 16, 32, 1.2).unwrap();
        for _ in 0..10 {
            let rebuilt = index.on_metric_update(&data).unwrap();
            assert!(!rebuilt);
        }
        assert_eq!(index.rebuilds(), 0);
        assert_eq!(index.step(), 10);
    }

    #[test]
    fn always_rebuild_rebuilds_every_step() {
        let data = fixture(64, 8);
        let mut index =
            DriftingIndex::build(&data, RebuildPolicy::AlwaysRebuild, 16, 32, 1.2).unwrap();
        for _ in 0..10 {
            let rebuilt = index.on_metric_update(&data).unwrap();
            assert!(rebuilt);
        }
        assert_eq!(index.rebuilds(), 10);
    }

    #[test]
    fn periodic_rebuilds_on_cadence() {
        let data = fixture(64, 8);
        let mut index =
            DriftingIndex::build(&data, RebuildPolicy::Periodic { k: 4 }, 16, 32, 1.2).unwrap();
        let mut did: Vec<bool> = Vec::with_capacity(12);
        for _ in 0..12 {
            did.push(index.on_metric_update(&data).unwrap());
        }
        let expected = vec![
            false, false, false, true, false, false, false, true, false, false, false, true,
        ];
        assert_eq!(did, expected);
        assert_eq!(index.rebuilds(), 3);
    }

    #[test]
    fn periodic_k0_is_reweight_only() {
        let data = fixture(32, 8);
        let mut index =
            DriftingIndex::build(&data, RebuildPolicy::Periodic { k: 0 }, 16, 32, 1.2).unwrap();
        for _ in 0..5 {
            let rebuilt = index.on_metric_update(&data).unwrap();
            assert!(!rebuilt);
        }
        assert_eq!(index.rebuilds(), 0);
    }

    #[test]
    fn force_rebuild_counts_but_does_not_advance_step() {
        let data = fixture(64, 8);
        let mut index =
            DriftingIndex::build(&data, RebuildPolicy::ReweightOnly, 16, 32, 1.2).unwrap();
        // One regular update followed by two forced rebuilds.
        index.on_metric_update(&data).unwrap();
        index.force_rebuild(&data).unwrap();
        index.force_rebuild(&data).unwrap();
        assert_eq!(
            index.step(),
            1,
            "force_rebuild must not advance the update step"
        );
        assert_eq!(
            index.rebuilds(),
            2,
            "force_rebuild must count toward rebuilds"
        );
    }

    #[test]
    fn recall_trigger_holds_under_no_drift() {
        let data = fixture(128, 8);
        let probes: Vec<u32> = (0..16).collect();
        let mut trigger = RecallTrigger::build(&data, probes, 5, 0.9, 32, 16, 32, 1.2).unwrap();
        assert!(trigger.probe_recall(&data) >= 0.9);
        let rebuilt = trigger.on_metric_update(&data).unwrap();
        assert!(!rebuilt);
        assert_eq!(trigger.rebuilds(), 0);
    }

    #[test]
    fn recall_trigger_fires_then_recovers_under_drift() {
        let before = fixture(128, 8);
        let probes: Vec<u32> = (0..16).collect();
        let mut trigger = RecallTrigger::build(&before, probes, 5, 0.9, 32, 16, 32, 1.2).unwrap();
        let drifted = fixture_b(128, 8);
        assert!(
            trigger.probe_recall(&drifted) < 0.9,
            "drift should drop probe recall below floor"
        );
        assert!(
            trigger.on_metric_update(&drifted).unwrap(),
            "trigger must fire on the drift"
        );
        assert_eq!(trigger.rebuilds(), 1);
        // After the rebuild the same data must no longer trip the trigger.
        let rebuilt_again = trigger.on_metric_update(&drifted).unwrap();
        assert!(!rebuilt_again);
        assert_eq!(trigger.rebuilds(), 1);
    }

    #[test]
    fn search_returns_self_as_nearest() {
        let data = fixture(128, 8);
        let index = DriftingIndex::build(&data, RebuildPolicy::ReweightOnly, 16, 32, 1.2).unwrap();
        let query = data.get(5).to_vec();
        let (cands, visited) = index.search(&data, &query, 16);
        assert!(visited > 0);
        assert!(cands.contains(&5), "self should be retrieved: {cands:?}");
    }
}