use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet};
use turbovec::codebook::codebook;
use turbovec::rotation::make_rotation_matrix;
use crate::distance::Distance;
use crate::index::{HnswParams, IndexError, NodeId, SearchResult};
/// Distance lookup between two items addressed by their external ids.
pub trait CodecDistance {
/// Returns the distance between the items identified by `a` and `b`.
fn distance(&self, a: NodeId, b: NodeId) -> f32;
}
/// Per-node graph record: external id, per-layer adjacency lists and a tombstone flag.
#[derive(Clone, Debug)]
struct TurboHnswNode {
/// External identifier supplied by the caller on insert.
id: NodeId,
/// One adjacency list per layer; `levels[l]` holds the slots (positions in the
/// index's node list) this node is linked to on layer `l`.
levels: Vec<Vec<usize>>,
/// Set by `delete`; the node stays in the graph but is filtered out of search results.
deleted: bool,
}
impl TurboHnswNode {
    /// Index of the highest layer this node has an adjacency list for.
    ///
    /// A node without any layer reports `0` rather than underflowing.
    fn level(&self) -> usize {
        match self.levels.len() {
            0 => 0,
            layers => layers - 1,
        }
    }
}
/// Layered-graph nearest-neighbour index over quantised vectors, parameterised by the
/// code width `BITS`.
///
/// `new` accepts only `BITS` in `2..=4` and a dimension that is a non-zero multiple of 8.
pub struct TurboHnswIndex<const BITS: u8> {
/// Metric used to turn similarities into distances and to decide whether inputs are
/// L2-normalised before encoding.
distance: Distance,
/// Dimensionality every inserted vector and query must have.
dim: u16,
/// Graph parameters; this file reads `m`, `m0`, `ef_construction`, `ef_search` and `seed`.
params: HnswParams,
/// Row-major `dim x dim` matrix produced by `make_rotation_matrix`, applied by `rotate`.
rotation: Vec<f32>,
/// Quantisation thresholds from `codebook`; a coordinate's code is the number of
/// boundaries it exceeds.
boundaries: Vec<f32>,
/// Reconstruction value for each code, from `codebook`.
centroids: Vec<f32>,
/// Codes of all stored vectors, `dim` bytes per slot (one code per byte), in slot order.
packed: Vec<u8>,
/// Per-slot scale computed by `encode_one`; multiplies the reconstructed dot product.
scales: Vec<f32>,
/// Graph nodes in insertion order; a node's position is its slot.
nodes: Vec<TurboHnswNode>,
/// Maps external ids to slots. Entries are never removed, including after `delete`.
id_to_idx: HashMap<NodeId, usize>,
/// Slot where searches start; `None` until the first insert.
entry: Option<usize>,
/// State of the xorshift generator used to draw node levels.
rng_state: u64,
/// Level multiplier: `1 / ln(m)` when `m > 1`, otherwise `1.0`.
ml: f64,
}
/// Frontier entry ordered so that a `BinaryHeap` pops the *smallest* score first.
///
/// Ordering looks at `score` only (`idx` is ignored). NaN scores compare equal to each
/// other and worse than every number, so the ordering is total and `==` agrees with
/// `cmp` even when a distance computation produced NaN.
#[derive(Clone, Copy, Debug)]
struct Candidate {
    /// Slot of the node this entry refers to.
    idx: usize,
    /// Distance to the query; lower is better.
    score: f32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        // Derived from `cmp` so that `Eq` and `Ord` can never disagree.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // Operands are swapped on purpose: the std heap is a max-heap, so reversing the
        // comparison makes the lowest score the greatest element.
        let (lhs, rhs) = (other.score, self.score);
        lhs.partial_cmp(&rhs)
            // Only reached when at least one side is NaN: rank NaN above any number
            // (before the reversal), i.e. as the worst possible score.
            .unwrap_or_else(|| lhs.is_nan().cmp(&rhs.is_nan()))
    }
}
/// Result-set entry ordered so that a `BinaryHeap` keeps the *largest* score on top.
///
/// Ordering looks at `score` only (`idx` is ignored). NaN scores compare equal to each
/// other and greater than every number, so the ordering is total, `==` agrees with
/// `cmp`, and a NaN entry is the first one evicted from a bounded result heap.
#[derive(Clone, Copy, Debug)]
struct MaxCandidate {
    /// Slot of the node this entry refers to.
    idx: usize,
    /// Distance to the query; lower is better.
    score: f32,
}

impl PartialEq for MaxCandidate {
    fn eq(&self, other: &Self) -> bool {
        // Derived from `cmp` so that `Eq` and `Ord` can never disagree.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for MaxCandidate {}

impl PartialOrd for MaxCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for MaxCandidate {
    fn cmp(&self, other: &Self) -> Ordering {
        self.score
            .partial_cmp(&other.score)
            // Only reached when at least one side is NaN: NaN ranks above any number.
            .unwrap_or_else(|| self.score.is_nan().cmp(&other.score.is_nan()))
    }
}
impl<const BITS: u8> TurboHnswIndex<BITS> {
/// Creates an empty index for `dim`-dimensional vectors under `distance`.
///
/// The rotation matrix and the quantisation codebook are built here, once, from
/// `dim` and `BITS`.
///
/// # Errors
///
/// * [`IndexError::Empty`] when `BITS` is outside `2..=4` or `dim` is zero.
/// * [`IndexError::DimensionMismatch`] when `dim` is not a multiple of 8; `expected`
///   carries the next multiple of 8, saturated at `u16::MAX` when that multiple does
///   not fit in a `u16`.
pub fn new(distance: Distance, dim: u16, params: HnswParams) -> Result<Self, IndexError> {
    // Replacement state for a zero seed. The xorshift generator in `rand_unit` maps
    // state 0 to 0 forever, which would make every draw 0.0 and every node level maximal.
    const ZERO_SEED_REPLACEMENT: u64 = 0x9E37_79B9_7F4A_7C15;

    if !(2..=4).contains(&BITS) || dim == 0 {
        return Err(IndexError::Empty);
    }
    if !dim.is_multiple_of(8) {
        return Err(IndexError::DimensionMismatch {
            // `dim / 8 + 1` is at most 8192; the multiplication saturates instead of
            // overflowing `u16` for dims in 65529..=65535.
            expected: (dim / 8 + 1).saturating_mul(8),
            got: dim,
        });
    }

    let dim_usize = usize::from(dim);
    let rotation = make_rotation_matrix(dim_usize);
    let (boundaries, centroids) = codebook(usize::from(BITS), dim_usize);

    let ml = if params.m > 1 {
        1.0 / f64::from(u32::try_from(params.m).unwrap_or(u32::MAX)).ln()
    } else {
        1.0
    };
    let rng_state = if params.seed == 0 {
        ZERO_SEED_REPLACEMENT
    } else {
        params.seed
    };

    Ok(Self {
        distance,
        dim,
        params,
        rotation,
        boundaries,
        centroids,
        packed: Vec::new(),
        scales: Vec::new(),
        nodes: Vec::new(),
        id_to_idx: HashMap::new(),
        entry: None,
        rng_state,
        ml,
    })
}
/// Number of stored vectors that have not been deleted.
///
/// Walks every node, so the cost grows with the total number of inserts.
#[must_use]
pub fn len(&self) -> usize {
    self.nodes
        .iter()
        .map(|node| usize::from(!node.deleted))
        .sum::<usize>()
}
/// Returns `true` when the index holds no live (non-deleted) vector.
#[must_use]
pub fn is_empty(&self) -> bool {
    // Vacuously true for an index with no nodes at all.
    self.nodes.iter().all(|node| node.deleted)
}
/// Dimensionality the index was created with.
#[must_use]
pub fn dim(&self) -> u16 {
self.dim
}
/// Distance metric the index was created with.
#[must_use]
pub fn distance_metric(&self) -> Distance {
self.distance
}
/// Value of the `BITS` const parameter of this index type.
#[must_use]
pub fn bits(&self) -> u8 {
BITS
}
/// Returns `true` when `id` was inserted and has not been deleted since.
#[must_use]
pub fn contains(&self, id: NodeId) -> bool {
    match self.id_to_idx.get(&id) {
        Some(&slot) => !self.nodes[slot].deleted,
        None => false,
    }
}
/// Encodes `vector`, stores it under `id` and links the new node into the graph.
///
/// For `Cosine` and `Euclidean` the vector is L2-normalised first; for `DotProduct`
/// it is stored as given. Nothing is mutated until all validation has passed.
///
/// # Errors
///
/// * [`IndexError::Empty`] when `vector` is empty, or when any prepared component is
///   non-finite or has magnitude `>= 1e16`.
/// * [`IndexError::DimensionMismatch`] when `vector.len()` differs from the index
///   dimension (lengths above `u16::MAX` are reported as `u16::MAX`).
/// * [`IndexError::Duplicate`] when `id` was inserted before, including ids that
///   were deleted afterwards, because the id map is never pruned.
pub fn insert(&mut self, id: NodeId, vector: Vec<f32>) -> Result<(), IndexError> {
    if vector.is_empty() {
        return Err(IndexError::Empty);
    }
    let got = u16::try_from(vector.len()).unwrap_or(u16::MAX);
    if got != self.dim {
        return Err(IndexError::DimensionMismatch {
            expected: self.dim,
            got,
        });
    }
    if self.id_to_idx.contains_key(&id) {
        return Err(IndexError::Duplicate(id));
    }

    let prepared = match self.distance {
        Distance::Cosine | Distance::Euclidean => l2_normalise(&vector),
        Distance::DotProduct => vector,
    };
    if prepared
        .iter()
        .any(|v| !v.is_finite() || v.abs() >= 1e16_f32)
    {
        return Err(IndexError::Empty);
    }

    // Store the codes and the scale; the slot of the new vector is the current
    // node count.
    let (packed, scale) = self.encode_one(&prepared);
    debug_assert_eq!(packed.len(), self.bytes_per_vec());
    self.packed.extend_from_slice(&packed);
    self.scales.push(scale);

    let level = self.random_level();
    let new_idx = self.nodes.len();
    self.nodes.push(TurboHnswNode {
        id,
        levels: vec![Vec::new(); level + 1],
        deleted: false,
    });
    self.id_to_idx.insert(id, new_idx);

    // The very first node simply becomes the entry point.
    let Some(entry) = self.entry else {
        self.entry = Some(new_idx);
        return Ok(());
    };

    let entry_level = self.nodes[entry].level();
    let q_rot = self.rotate(&prepared);

    // Greedy descent through the layers above the new node's top layer. The range is
    // empty when the new node reaches at least as high as the entry point.
    let mut current = entry;
    for lc in (level + 1..=entry_level).rev() {
        current = self.greedy_search_layer(&q_rot, current, lc, new_idx);
    }

    // Link the node on every layer it shares with the existing graph, top down.
    let mut entry_points = vec![current];
    for lc in (0..=level.min(entry_level)).rev() {
        // Layer 0 uses `m0`; every other layer uses `m`. The same limit bounds both
        // the number of links created here and each neighbour's list length.
        let cap = if lc == 0 {
            self.params.m0
        } else {
            self.params.m
        };
        let found = self.search_layer(
            &q_rot,
            &entry_points,
            lc,
            self.params.ef_construction,
            Some(new_idx),
        );
        let selected = Self::select_neighbours(&found, cap);
        for &nb in &selected {
            self.nodes[new_idx].levels[lc].push(nb);
            self.nodes[nb].levels[lc].push(new_idx);
            if self.nodes[nb].levels[lc].len() > cap {
                self.shrink_connections(nb, lc, cap);
            }
        }
        // The chosen neighbours seed the next layer down; fall back to the descent
        // result when nothing was selected.
        entry_points = if selected.is_empty() {
            vec![current]
        } else {
            selected
        };
    }

    if level > entry_level {
        self.entry = Some(new_idx);
    }
    Ok(())
}
/// Marks the vector stored under `id` as deleted.
///
/// Returns `true` when a live entry was tombstoned, `false` when `id` is unknown or
/// already deleted. The node keeps its links and its id-map entry.
pub fn delete(&mut self, id: NodeId) -> bool {
    match self.id_to_idx.get(&id).copied() {
        Some(slot) if !self.nodes[slot].deleted => {
            self.nodes[slot].deleted = true;
            true
        }
        _ => false,
    }
}
/// Returns up to `k` live entries closest to `query`, nearest first.
///
/// `ef` overrides `params.ef_search` as the layer-0 candidate budget and is raised
/// to at least `k`. Deleted nodes are filtered out *after* the traversal, so fewer
/// than `k` results can come back when tombstones crowd the candidate set.
///
/// An empty query or an index with no nodes yields `Ok(vec![])`. The query is not
/// checked for NaN/infinite components; such queries do not panic, but the ranking
/// they produce is meaningless.
///
/// # Errors
///
/// [`IndexError::DimensionMismatch`] when a non-empty query has the wrong length.
pub fn search(
    &self,
    query: &[f32],
    k: usize,
    ef: Option<usize>,
) -> Result<Vec<SearchResult>, IndexError> {
    if query.is_empty() || self.nodes.is_empty() {
        return Ok(Vec::new());
    }
    let got = u16::try_from(query.len()).unwrap_or(u16::MAX);
    if got != self.dim {
        return Err(IndexError::DimensionMismatch {
            expected: self.dim,
            got,
        });
    }
    // Nothing can be returned for k == 0, so skip the graph traversal entirely.
    if k == 0 {
        return Ok(Vec::new());
    }

    let prepared = match self.distance {
        Distance::Cosine | Distance::Euclidean => l2_normalise(query),
        Distance::DotProduct => query.to_vec(),
    };
    let q_rot = self.rotate(&prepared);

    let mut entry = self.entry.unwrap_or(0);
    let entry_level = self.nodes[entry].level();
    let ef = ef.unwrap_or(self.params.ef_search).max(k);

    // Greedy descent through the upper layers, then a wider search on layer 0.
    for lc in (1..=entry_level).rev() {
        entry = self.greedy_search_layer(&q_rot, entry, lc, usize::MAX);
    }
    let mut ranked = self.search_layer(&q_rot, &[entry], 0, ef, None);

    // Ascending by distance. NaN scores are ordered after every number so the
    // comparator stays a total order; `sort_by` may panic on a non-total one.
    ranked.sort_by(|a, b| {
        a.score
            .partial_cmp(&b.score)
            .unwrap_or_else(|| a.score.is_nan().cmp(&b.score.is_nan()))
    });

    Ok(ranked
        .into_iter()
        .filter(|c| !self.nodes[c.idx].deleted)
        .take(k)
        .map(|c| SearchResult {
            id: self.nodes[c.idx].id,
            score: c.score,
        })
        .collect())
}
/// Number of bytes one encoded vector occupies in `packed`: one code byte per
/// dimension (codes are not bit-packed).
fn bytes_per_vec(&self) -> usize {
usize::from(self.dim)
}
/// Quantises `vector` into one code byte per dimension plus a scalar scale.
///
/// The vector is scaled to unit length (or zeroed when its norm is at most `1e-10`),
/// rotated, and each rotated coordinate is coded as the number of boundaries it
/// exceeds. The scale is `norm / max(<rotated, centroids[codes]>, 1e-10)`.
fn encode_one(&self, vector: &[f32]) -> (Vec<u8>, f32) {
    let dim = usize::from(self.dim);

    let norm = vector.iter().fold(0.0_f32, |acc, &x| acc + x * x).sqrt();
    let inv_norm = if norm > 1e-10 { 1.0 / norm } else { 0.0 };
    let unit: Vec<f32> = vector[..dim].iter().map(|&x| x * inv_norm).collect();
    let rotated = self.rotate(&unit);

    // Accumulates the dot product between the rotated unit vector and its own
    // reconstruction while the codes are being produced.
    let mut reconstruction_dot = 0.0_f32;
    let packed: Vec<u8> = rotated
        .iter()
        .take(dim)
        .map(|&coord| {
            let code = self
                .boundaries
                .iter()
                .fold(0_u8, |count, &edge| if coord > edge { count + 1 } else { count });
            reconstruction_dot += coord * self.centroids[usize::from(code)];
            code
        })
        .collect();

    let scale = norm / reconstruction_dot.max(1e-10_f32);
    (packed, scale)
}
/// Code bytes of the vector stored at `slot` (`dim` bytes).
///
/// Panics when `slot` is out of range for `packed`.
fn codes(&self, slot: usize) -> &[u8] {
    let width = usize::from(self.dim);
    let begin = slot * width;
    let end = begin + width;
    &self.packed[begin..end]
}
/// Multiplies the row-major `dim x dim` rotation matrix by `q`.
///
/// Always returns `dim` values; only the first `dim` entries of `q` contribute.
fn rotate(&self, q: &[f32]) -> Vec<f32> {
    let dim = usize::from(self.dim);
    (0..dim)
        .map(|r| {
            let row = &self.rotation[r * dim..(r + 1) * dim];
            // `zip` stops at the shorter side, so a short `q` is treated as zero-padded.
            row.iter()
                .zip(q)
                .fold(0.0_f32, |acc, (&m, &x)| acc + m * x)
        })
        .collect()
}
/// Approximate dot product between a rotated query and the vector stored at `slot`:
/// the query is dotted with the slot's centroid reconstruction, then multiplied by
/// the slot's scale.
fn similarity_query(&self, q_rot: &[f32], slot: usize) -> f32 {
    let dim = usize::from(self.dim);
    let codes = self.codes(slot);
    let dot = q_rot[..dim]
        .iter()
        .zip(codes)
        .fold(0.0_f32, |acc, (&q, &code)| {
            acc + q * self.centroids[usize::from(code)]
        });
    dot * self.scales[slot]
}
/// Approximate dot product between the vectors stored at slots `a` and `b`, computed
/// from their centroid reconstructions and multiplied by both scales.
fn similarity_pair(&self, a: usize, b: usize) -> f32 {
    let codes_a = self.codes(a);
    let codes_b = self.codes(b);
    let dot = codes_a
        .iter()
        .zip(codes_b)
        .fold(0.0_f32, |acc, (&x, &y)| {
            acc + self.centroids[usize::from(x)] * self.centroids[usize::from(y)]
        });
    dot * self.scales[a] * self.scales[b]
}
/// Converts a similarity into the configured metric's distance (lower is closer).
///
/// * `Cosine`: `1 - s`
/// * `DotProduct`: `-s`
/// * `Euclidean`: `sqrt(max(2 - 2s, 0))`
fn similarity_to_distance(&self, similarity: f32) -> f32 {
    match self.distance {
        Distance::Cosine => 1.0 - similarity,
        Distance::DotProduct => -similarity,
        Distance::Euclidean => {
            // Clamp before the root so a similarity slightly above 1 cannot produce NaN.
            let squared = 2.0 - 2.0 * similarity;
            squared.max(0.0).sqrt()
        }
    }
}
/// Distance between a rotated query and the vector stored at `slot`.
fn distance_query(&self, q_rot: &[f32], slot: usize) -> f32 {
    let similarity = self.similarity_query(q_rot, slot);
    self.similarity_to_distance(similarity)
}
/// Distance between the vectors stored at slots `a` and `b`.
fn distance_pair(&self, a: usize, b: usize) -> f32 {
    let similarity = self.similarity_pair(a, b);
    self.similarity_to_distance(similarity)
}
/// Hill-climbs on layer `lc` from `entry` towards the query.
///
/// Each round scans the current node's layer-`lc` neighbours (ignoring `skip_idx`),
/// moves to the strictly closest one, and stops when no neighbour improves on the
/// current distance. Returns the slot where the walk stopped.
fn greedy_search_layer(
    &self,
    q_rot: &[f32],
    entry: usize,
    lc: usize,
    skip_idx: usize,
) -> usize {
    let mut here = entry;
    let mut here_dist = self.distance_query(q_rot, here);
    loop {
        // A node that does not reach layer `lc` has nothing to scan there.
        let Some(links) = self.nodes[here].levels.get(lc) else {
            break;
        };
        let mut moved = false;
        let (mut best, mut best_dist) = (here, here_dist);
        for nb in links.iter().copied().filter(|&nb| nb != skip_idx) {
            let dist = self.distance_query(q_rot, nb);
            if dist < best_dist {
                best = nb;
                best_dist = dist;
                moved = true;
            }
        }
        if !moved {
            break;
        }
        here = best;
        here_dist = best_dist;
    }
    here
}
/// Best-first search on layer `lc`, starting from `entry_points`.
///
/// Keeps a min-ordered frontier of nodes to expand and a max-ordered result heap
/// (`top`) that is trimmed back to `ef` entries after each admission. `skip_idx`,
/// when set, is never scored or returned. The result is the content of `top` in
/// heap order, i.e. NOT sorted by distance; deleted nodes are not filtered here.
fn search_layer(
&self,
q_rot: &[f32],
entry_points: &[usize],
lc: usize,
ef: usize,
skip_idx: Option<usize>,
) -> Vec<MaxCandidate> {
let mut visited: HashSet<usize> = HashSet::new();
let mut frontier: BinaryHeap<Candidate> = BinaryHeap::new();
let mut top: BinaryHeap<MaxCandidate> = BinaryHeap::new();
// Seed both heaps with the distinct entry points. `top` is not trimmed here, so it
// can start with more than `ef` entries when more entry points than `ef` are given.
for &ep in entry_points {
if Some(ep) == skip_idx {
continue;
}
if visited.insert(ep) {
let s = self.distance_query(q_rot, ep);
frontier.push(Candidate { idx: ep, score: s });
top.push(MaxCandidate { idx: ep, score: s });
}
}
// Expand the closest unexpanded node until the frontier is exhausted or cannot
// improve the result set any more.
while let Some(c) = frontier.pop() {
// Stop once the result set is full and the closest remaining frontier node is
// farther than the worst kept result.
if top.len() >= ef {
if let Some(worst) = top.peek() {
if c.score > worst.score {
break;
}
}
}
// Nodes that do not reach layer `lc` have no neighbours to expand here.
if lc < self.nodes[c.idx].levels.len() {
let neighbours = self.nodes[c.idx].levels[lc].as_slice();
for &nb in neighbours {
if Some(nb) == skip_idx {
continue;
}
// `insert` returns false for already-visited nodes, so each node is scored once.
if !visited.insert(nb) {
continue;
}
let s = self.distance_query(q_rot, nb);
// Admit when the result set still has room or the node beats the current worst.
let admit = match top.peek() {
Some(worst) => s < worst.score || top.len() < ef,
None => true,
};
if admit {
frontier.push(Candidate { idx: nb, score: s });
top.push(MaxCandidate { idx: nb, score: s });
// Evict the worst entry to keep at most `ef` results.
if top.len() > ef {
top.pop();
}
}
}
}
}
top.into_vec()
}
/// Picks the `m` closest candidates and returns their slots, nearest first.
///
/// Ties keep their input order (stable sort). NaN scores are ordered after every
/// number so the comparator is a total order; `sort_by` may panic on a non-total one.
fn select_neighbours(candidates: &[MaxCandidate], m: usize) -> Vec<usize> {
    let mut ranked: Vec<MaxCandidate> = candidates.to_vec();
    ranked.sort_by(|a, b| {
        a.score
            .partial_cmp(&b.score)
            .unwrap_or_else(|| a.score.is_nan().cmp(&b.score.is_nan()))
    });
    ranked.into_iter().take(m).map(|c| c.idx).collect()
}
/// Trims node `idx`'s layer-`lc` adjacency list to its `cap` closest neighbours.
///
/// Closeness is the stored-vector-to-stored-vector distance. Only `idx`'s own list
/// is rewritten; dropped neighbours keep any link they hold back to `idx`. NaN
/// distances are ordered after every number so the comparator is a total order;
/// `sort_by` may panic on a non-total one.
fn shrink_connections(&mut self, idx: usize, lc: usize, cap: usize) {
    let neighbours = std::mem::take(&mut self.nodes[idx].levels[lc]);
    let mut scored: Vec<(usize, f32)> = neighbours
        .into_iter()
        .map(|nb| (nb, self.distance_pair(idx, nb)))
        .collect();
    scored.sort_by(|a, b| {
        a.1.partial_cmp(&b.1)
            .unwrap_or_else(|| a.1.is_nan().cmp(&b.1.is_nan()))
    });
    scored.truncate(cap);
    self.nodes[idx].levels[lc] = scored.into_iter().map(|(nb, _)| nb).collect();
}
/// Advances the xorshift state and returns a value in `[0, 1)` built from 53 bits
/// of the multiplied state.
///
/// A state of zero stays zero, in which case every call returns `0.0`.
fn rand_unit(&mut self) -> f64 {
    const MULTIPLIER: u64 = 0x2545_F491_4F6C_DD1D;
    const MANTISSA_BITS: u32 = 53;

    let mut state = self.rng_state;
    state ^= state >> 12;
    state ^= state << 25;
    state ^= state >> 27;
    self.rng_state = state;

    let mixed = state.wrapping_mul(MULTIPLIER);
    let mantissa = (mixed >> 11) & ((1_u64 << MANTISSA_BITS) - 1);
    #[allow(
        clippy::cast_precision_loss,
        reason = "bits is in [0, 2^53), exactly representable as f64"
    )]
    let unit = (mantissa as f64) / ((1_u64 << MANTISSA_BITS) as f64);
    unit
}
/// Draws the top layer for a new node: `floor(-ln(u) * ml)` for a uniform `u`,
/// clamped to `0..=16`.
fn random_level(&mut self) -> usize {
    const MAX_LEVEL: f64 = 16.0;

    // Keep the draw strictly positive so the logarithm stays finite.
    let u = self.rand_unit().max(f64::MIN_POSITIVE);
    let raw = (-u.ln() * self.ml).floor();
    #[allow(
        clippy::cast_possible_truncation,
        clippy::cast_sign_loss,
        reason = "clamped to [0, 16]"
    )]
    let lvl = raw.clamp(0.0, MAX_LEVEL) as usize;
    lvl
}
}
impl<const BITS: u8> CodecDistance for TurboHnswIndex<BITS> {
    /// Distance between the stored vectors for `a` and `b`, computed from their codes.
    ///
    /// Returns `f32::INFINITY` when either id was never inserted. Deleted ids still
    /// resolve, because their slot stays in the id map.
    fn distance(&self, a: NodeId, b: NodeId) -> f32 {
        match (self.id_to_idx.get(&a), self.id_to_idx.get(&b)) {
            (Some(&slot_a), Some(&slot_b)) => self.distance_pair(slot_a, slot_b),
            _ => f32::INFINITY,
        }
    }
}
/// Returns `v` scaled to unit Euclidean length.
///
/// * An all-zero (or empty) input is returned unchanged.
/// * Finite inputs whose squared norm overflows or underflows `f32` are first divided
///   by their largest magnitude, so they normalise correctly instead of collapsing to
///   zeros or being returned as-is.
/// * Inputs containing NaN or infinity are passed through the plain division and
///   therefore yield NaN/zero components; callers are expected to reject those.
fn l2_normalise(v: &[f32]) -> Vec<f32> {
    let n = v.iter().map(|x| x * x).sum::<f32>().sqrt();

    // Common case: the norm is usable as computed.
    if n > 0.0 && n.is_finite() {
        return v.iter().map(|x| x / n).collect();
    }

    // The norm is 0, infinite or NaN. When every component is finite and at least
    // one is non-zero, the squares over- or underflowed: rescale and try again.
    let peak = v.iter().fold(0.0_f32, |m, x| m.max(x.abs()));
    if peak > 0.0 && v.iter().all(|x| x.is_finite()) {
        let scaled: Vec<f32> = v.iter().map(|x| x / peak).collect();
        // At least one scaled component is exactly 1, so this norm is >= 1 and finite.
        let scaled_norm = scaled.iter().map(|x| x * x).sum::<f32>().sqrt();
        return scaled.into_iter().map(|x| x / scaled_norm).collect();
    }

    if n <= 0.0 {
        return v.to_vec();
    }
    v.iter().map(|x| x / n).collect()
}
// Unit tests for the index: insert/search round trip, deletion, and input validation.
#[cfg(test)]
mod tests {
use super::*;
// Deterministic pseudo-random vector with components in [-1, 1), generated by an
// xorshift sequence. A zero seed is replaced because xorshift cannot leave state 0.
fn rand_vec(seed: u64, dim: usize) -> Vec<f32> {
let mut x = if seed == 0 { 0xDEAD_BEEF } else { seed };
let mut v = Vec::with_capacity(dim);
for _ in 0..dim {
x ^= x << 13;
x ^= x >> 7;
x ^= x << 17;
let bits = (x >> 11) & ((1_u64 << 53) - 1);
#[allow(
clippy::cast_precision_loss,
clippy::cast_possible_truncation,
reason = "test fixture: PRNG narrowed to f32"
)]
let r = (((bits as f64) / ((1_u64 << 53) as f64)) * 2.0 - 1.0) as f32;
v.push(r);
}
v
}
// Searching with a stored vector as the query must rank that vector first.
#[test]
fn insert_and_search_returns_self_first_4bit() {
let mut idx = TurboHnswIndex::<4>::new(Distance::Cosine, 64, HnswParams::default())
.expect("4-bit ctor");
let target = rand_vec(42, 64);
idx.insert(0, target.clone()).unwrap();
for i in 1..50_u64 {
idx.insert(i, rand_vec(i.wrapping_mul(1_000_003) + 1, 64))
.unwrap();
}
let res = idx.search(&target, 3, None).unwrap();
assert!(!res.is_empty());
assert_eq!(res[0].id, 0);
}
// A deleted id must no longer appear in search results.
#[test]
fn delete_excludes_from_search() {
let mut idx = TurboHnswIndex::<4>::new(Distance::Cosine, 64, HnswParams::default())
.expect("4-bit ctor");
for i in 0..30_u64 {
idx.insert(i, rand_vec(i + 1, 64)).unwrap();
}
let q = rand_vec(1, 64);
let before = idx.search(&q, 5, None).unwrap();
let target = before[0].id;
assert!(idx.delete(target));
let after = idx.search(&q, 5, None).unwrap();
assert!(after.iter().all(|r| r.id != target));
}
// Inserting the same id twice is rejected with `Duplicate`.
#[test]
fn duplicate_id_rejected() {
let mut idx = TurboHnswIndex::<4>::new(Distance::Cosine, 64, HnswParams::default())
.expect("4-bit ctor");
idx.insert(7, rand_vec(7, 64)).unwrap();
assert!(matches!(
idx.insert(7, rand_vec(8, 64)),
Err(IndexError::Duplicate(7))
));
}
// A vector whose length differs from the index dimension is rejected.
#[test]
fn dimension_mismatch_rejected() {
let mut idx = TurboHnswIndex::<4>::new(Distance::Cosine, 64, HnswParams::default())
.expect("4-bit ctor");
assert!(matches!(
idx.insert(0, vec![0.1; 32]),
Err(IndexError::DimensionMismatch { .. })
));
}
// Searching an index with no nodes returns an empty result, not an error.
#[test]
fn empty_index_search_is_empty() {
let idx = TurboHnswIndex::<4>::new(Distance::Cosine, 64, HnswParams::default())
.expect("4-bit ctor");
let res = idx.search(&rand_vec(0, 64), 5, None).unwrap();
assert!(res.is_empty());
}
// The constructor rejects a dimension that is not a multiple of 8 and reports the
// next multiple as `expected`.
#[test]
fn ctor_rejects_misaligned_dim() {
let r = TurboHnswIndex::<4>::new(Distance::Cosine, 7, HnswParams::default());
assert!(matches!(
r,
Err(IndexError::DimensionMismatch {
expected: 8,
got: 7
})
));
}
}