use triblespace_core::query::Variable;
use triblespace_core::inline::encodings::hash::Handle;
use triblespace_core::inline::Inline;
use crate::schemas::{EmbHandle, Embedding};
/// A node of the in-memory HNSW graph while it is being built.
#[derive(Debug)]
struct HNSWNode {
    /// The embedding, normalized at insert time so cosine distance is `1 - dot`.
    vector: Vec<f32>,
    /// Highest layer this node participates in.
    level: u8,
    /// One adjacency list per layer, indexed `0..=level`.
    neighbors: Vec<Vec<u32>>,
}
/// Graph-only node stored in the finished index; the vector itself is not
/// kept here — it is reached through the index's parallel `handles` list.
#[derive(Debug)]
struct HNSWIndexNode {
    /// Highest layer this node participates in.
    level: u8,
    /// One adjacency list per layer, indexed `0..=level`.
    neighbors: Vec<Vec<u32>>,
}
/// Incrementally constructs an HNSW graph over normalized embedding vectors.
pub struct HNSWBuilder {
    /// Dimensionality every inserted vector must match.
    dim: usize,
    /// Max neighbors per node on layers above 0.
    m: u16,
    /// Max neighbors per node on the base layer (layer 0).
    m0: u16,
    /// Candidate-list width used during insertion searches.
    ef_construction: u16,
    /// Level-sampling multiplier, `1 / ln(M)`.
    level_mult: f32,
    /// State of the level-sampling RNG (seedable via `with_seed`).
    rng: u64,
    /// Graph nodes; the index in this Vec is the node id used in adjacency lists.
    nodes: Vec<HNSWNode>,
    /// Blob-store handle for each node, parallel to `nodes`.
    handles: Vec<Inline<Handle<Embedding>>>,
    /// Node id of the current top-layer entry point, `None` while empty.
    entry_point: Option<u32>,
    /// Highest layer currently present in the graph.
    max_level: u8,
}
impl HNSWBuilder {
    /// Creates a builder for `dim`-dimensional embeddings with conventional
    /// defaults: M = 16, M0 = 2·M, ef_construction = 200, and a fixed RNG
    /// seed so builds are deterministic by default.
    ///
    /// # Panics
    /// Panics if `dim == 0`.
    pub fn new(dim: usize) -> Self {
        assert!(dim > 0, "HNSWBuilder: dim must be > 0");
        let m = 16u16;
        Self {
            dim,
            m,
            m0: m * 2,
            ef_construction: 200,
            // Standard HNSW level multiplier: 1 / ln(M).
            level_mult: 1.0 / (m as f32).ln(),
            rng: 0xC0FFEEu64,
            nodes: Vec::new(),
            handles: Vec::new(),
            entry_point: None,
            max_level: 0,
        }
    }

    /// Sets M (the neighbor cap for layers above 0) and re-derives M0 = 2·M
    /// and the level multiplier. Because this overwrites M0, call it *before*
    /// [`Self::m0`] when configuring both.
    ///
    /// # Panics
    /// Panics if `m < 2` or if `2 * m` does not fit in `u16`.
    pub fn m(mut self, m: u16) -> Self {
        assert!(m >= 2, "HNSWBuilder: M must be ≥ 2");
        self.m = m;
        // checked_mul: a plain `m * 2` wraps silently in release builds for
        // m > 32767 and would leave a nonsensical (tiny) M0.
        self.m0 = m
            .checked_mul(2)
            .expect("HNSWBuilder: M too large, 2*M must fit in u16");
        self.level_mult = 1.0 / (m as f32).ln();
        self
    }

    /// Sets M0, the neighbor cap for the base layer.
    ///
    /// # Panics
    /// Panics if `m0 < M`.
    pub fn m0(mut self, m0: u16) -> Self {
        assert!(m0 >= self.m, "HNSWBuilder: M0 must be ≥ M");
        self.m0 = m0;
        self
    }

    /// Sets the candidate-list width used while inserting. Larger values
    /// improve graph quality at the cost of build time.
    ///
    /// # Panics
    /// Panics if `ef == 0`.
    pub fn ef_construction(mut self, ef: u16) -> Self {
        assert!(ef >= 1, "HNSWBuilder: ef_construction must be ≥ 1");
        self.ef_construction = ef;
        self
    }

    /// Seeds the level-sampling RNG so graph structure is reproducible.
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.rng = seed;
        self
    }

    /// Draws a level from the HNSW geometric distribution
    /// `floor(-ln(U) * level_mult)`, advancing the internal RNG state.
    fn sample_level(&mut self) -> u8 {
        // SplitMix64-style step: golden-ratio increment, then bit mixing.
        self.rng = self.rng.wrapping_add(0x9E3779B97F4A7C15);
        let mut z = self.rng;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58476D1CE4E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D049BB133111EB);
        z ^= z >> 31;
        // Top 53 bits -> uniform in [0, 1); clamp away exact 0 so ln is finite.
        let u = ((z >> 11) as f64 / (1u64 << 53) as f64).max(f64::MIN_POSITIVE);
        let l = (-u.ln() * self.level_mult as f64).floor() as i32;
        l.clamp(0, u8::MAX as i32) as u8
    }

    /// Inserts an embedding into the graph. The vector is normalized in
    /// place so cosine distance reduces to `1 - dot`.
    ///
    /// # Errors
    /// Returns [`DimMismatch`] when `vec.len()` differs from the configured
    /// dimensionality.
    pub fn insert(
        &mut self,
        handle: Inline<Handle<Embedding>>,
        mut vec: Vec<f32>,
    ) -> Result<(), DimMismatch> {
        if vec.len() != self.dim {
            return Err(DimMismatch {
                expected: self.dim,
                got: vec.len(),
            });
        }
        normalize(&mut vec);
        let new_level = self.sample_level();
        let new_idx = self.nodes.len() as u32;
        // Phase 1: greedy descent through the layers above the new node's
        // level to reach a good entry point for linking.
        let mut curr = self.entry_point;
        if let Some(mut cnode) = curr {
            // The guard also protects `new_level + 1` from overflowing u8
            // in the (clamped) extreme case new_level == u8::MAX.
            if new_level < self.max_level {
                for lvl in ((new_level + 1)..=self.max_level).rev() {
                    cnode = self.greedy_search_layer(&vec, cnode, lvl);
                }
            }
            curr = Some(cnode);
        }
        self.nodes.push(HNSWNode {
            vector: vec.clone(),
            level: new_level,
            neighbors: vec![Vec::new(); new_level as usize + 1],
        });
        self.handles.push(handle);
        // Phase 2: on every layer the new node shares with the graph, link it
        // bidirectionally to the closest candidates and re-prune both sides.
        if let Some(start) = curr {
            let mut entry = start;
            for lvl in (0..=new_level.min(self.max_level)).rev() {
                let cap = if lvl == 0 { self.m0 } else { self.m } as usize;
                let candidates =
                    self.search_layer(&vec, entry, self.ef_construction as usize, lvl);
                let selected = Self::select_neighbours(&candidates, cap);
                for &n in &selected {
                    self.nodes[new_idx as usize].neighbors[lvl as usize].push(n);
                    self.nodes[n as usize].neighbors[lvl as usize].push(new_idx);
                }
                self.prune_neighbours(new_idx, lvl, cap);
                for &n in &selected {
                    self.prune_neighbours(n, lvl, cap);
                }
                // Descend into the next layer from the closest candidate found.
                if let Some((best, _)) = candidates
                    .iter()
                    .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
                {
                    entry = *best;
                }
            }
        }
        // The first node, or any node topping out above every existing layer,
        // becomes the new global entry point.
        if new_level > self.max_level || self.entry_point.is_none() {
            self.max_level = new_level;
            self.entry_point = Some(new_idx);
        }
        Ok(())
    }

    /// Finalizes the graph into its succinct (compressed) representation.
    pub fn build(self) -> crate::succinct::SuccinctHNSWIndex {
        crate::succinct::SuccinctHNSWIndex::from_naive(&self.build_naive())
            .expect("from_naive cannot fail on a valid HNSWIndex built by HNSWBuilder")
    }

    /// Finalizes the graph into the naive pointer-based index, dropping the
    /// in-memory vectors (only structure and blob handles are retained).
    pub fn build_naive(self) -> HNSWIndex {
        let nodes: Vec<HNSWIndexNode> = self
            .nodes
            .into_iter()
            .map(|n| HNSWIndexNode {
                level: n.level,
                neighbors: n.neighbors,
            })
            .collect();
        HNSWIndex {
            dim: self.dim,
            m: self.m,
            m0: self.m0,
            nodes,
            handles: self.handles,
            entry_point: self.entry_point,
            max_level: self.max_level,
        }
    }

    /// Hill-climbs on `layer`: repeatedly moves to the closest neighbor of
    /// the current node until no neighbor improves on it.
    fn greedy_search_layer(&self, q: &[f32], entry: u32, layer: u8) -> u32 {
        let mut curr = entry;
        let mut curr_dist = cosine_dist(q, &self.nodes[curr as usize].vector);
        loop {
            let mut changed = false;
            let node = &self.nodes[curr as usize];
            // A node below `layer` has no adjacency list there; stop.
            let Some(neigh) = node.neighbors.get(layer as usize) else {
                return curr;
            };
            for &n in neigh {
                let d = cosine_dist(q, &self.nodes[n as usize].vector);
                if d < curr_dist {
                    curr_dist = d;
                    curr = n;
                    changed = true;
                }
            }
            if !changed {
                return curr;
            }
        }
    }

    /// Best-first beam search on `layer`, keeping up to `ef` closest nodes.
    /// The returned `(node, distance)` pairs are in arbitrary order.
    fn search_layer(&self, q: &[f32], entry: u32, ef: usize, layer: u8) -> Vec<(u32, f32)> {
        use std::collections::BinaryHeap;
        let mut visited: std::collections::HashSet<u32> = std::collections::HashSet::new();
        visited.insert(entry);
        let d0 = cosine_dist(q, &self.nodes[entry as usize].vector);
        // `candidates` pops closest-first; `results` pops farthest-first so
        // the worst survivor can be evicted once `ef` is exceeded.
        let mut candidates: BinaryHeap<MinDist> = BinaryHeap::new();
        candidates.push(MinDist {
            idx: entry,
            dist: d0,
        });
        let mut results: BinaryHeap<MaxDist> = BinaryHeap::new();
        results.push(MaxDist {
            idx: entry,
            dist: d0,
        });
        while let Some(c) = candidates.pop() {
            let farthest = results.peek().map(|r| r.dist).unwrap_or(f32::INFINITY);
            if c.dist > farthest && results.len() >= ef {
                // Closest unexpanded candidate is worse than the full result
                // set; nothing better can be reached.
                break;
            }
            let node = &self.nodes[c.idx as usize];
            let Some(neigh) = node.neighbors.get(layer as usize) else {
                continue;
            };
            for &n in neigh {
                if !visited.insert(n) {
                    continue;
                }
                let d = cosine_dist(q, &self.nodes[n as usize].vector);
                let farthest = results.peek().map(|r| r.dist).unwrap_or(f32::INFINITY);
                if d < farthest || results.len() < ef {
                    candidates.push(MinDist { idx: n, dist: d });
                    results.push(MaxDist { idx: n, dist: d });
                    if results.len() > ef {
                        results.pop();
                    }
                }
            }
        }
        results.into_iter().map(|m| (m.idx, m.dist)).collect()
    }

    /// Simple closest-first selection (no diversity heuristic): sort by
    /// distance and take the first `cap` candidates.
    fn select_neighbours(candidates: &[(u32, f32)], cap: usize) -> Vec<u32> {
        let mut sorted: Vec<&(u32, f32)> = candidates.iter().collect();
        sorted.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
        sorted.into_iter().take(cap).map(|&(i, _)| i).collect()
    }

    /// Caps `node`'s adjacency list on `layer` to the `cap` closest entries;
    /// lists already within the cap are just sorted and deduplicated.
    fn prune_neighbours(&mut self, node: u32, layer: u8, cap: usize) {
        let list_snapshot: Vec<u32> = self.nodes[node as usize].neighbors[layer as usize].clone();
        if list_snapshot.len() <= cap {
            let list = &mut self.nodes[node as usize].neighbors[layer as usize];
            list.sort_unstable();
            list.dedup();
            return;
        }
        // Clone the query vector so scoring can read other nodes while we
        // later take a mutable borrow of this node's list.
        let q = self.nodes[node as usize].vector.clone();
        let mut scored: Vec<(u32, f32)> = list_snapshot
            .iter()
            .map(|&n| (n, cosine_dist(&q, &self.nodes[n as usize].vector)))
            .collect();
        scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
        let list = &mut self.nodes[node as usize].neighbors[layer as usize];
        list.clear();
        list.extend(scored.into_iter().take(cap).map(|(i, _)| i));
        list.sort_unstable();
        list.dedup();
    }
}
/// Finished, immutable HNSW graph. Vectors are not stored here; each node's
/// embedding is addressed by the parallel `handles` entry and fetched from a
/// blob store after calling `attach`.
#[doc(hidden)]
pub struct HNSWIndex {
    /// Dimensionality of the indexed embeddings.
    dim: usize,
    /// Neighbor cap for layers above 0 (as configured at build time).
    m: u16,
    /// Neighbor cap for the base layer.
    m0: u16,
    /// Graph structure; index in this Vec is the node id.
    nodes: Vec<HNSWIndexNode>,
    /// Blob-store handle for each node, parallel to `nodes`.
    handles: Vec<Inline<Handle<Embedding>>>,
    /// Node id of the top-layer entry point, `None` for an empty index.
    entry_point: Option<u32>,
    /// Highest layer present in the graph.
    max_level: u8,
}
impl std::fmt::Debug for HNSWIndex {
    /// Compact summary (node count, dimensionality, top layer); adjacency
    /// lists are deliberately omitted as they can be huge.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("HNSWIndex");
        out.field("n_nodes", &self.handles.len());
        out.field("dim", &self.dim);
        out.field("max_level", &self.max_level);
        out.finish()
    }
}
impl HNSWIndex {
    /// Dimensionality of the indexed embeddings.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Number of indexed embeddings.
    pub fn doc_count(&self) -> usize {
        self.handles.len()
    }

    /// Configured neighbor cap for layers above 0.
    pub fn m(&self) -> u16 {
        self.m
    }

    /// Configured neighbor cap for the base layer.
    pub fn m0(&self) -> u16 {
        self.m0
    }

    /// Highest layer present in the graph.
    pub fn max_level(&self) -> u8 {
        self.max_level
    }

    /// Level of node `i`, or `None` if `i` is out of range.
    pub fn node_level(&self, i: usize) -> Option<u8> {
        self.nodes.get(i).map(|n| n.level)
    }

    /// Neighbor list of node `i` on `layer`; empty for out-of-range inputs.
    pub fn node_neighbours(&self, i: usize, layer: u8) -> &[u32] {
        match self.nodes.get(i).and_then(|n| n.neighbors.get(layer as usize)) {
            Some(list) => list.as_slice(),
            None => &[],
        }
    }

    /// All embedding handles, in insertion order (parallel to node ids).
    pub fn handles(&self) -> &[Inline<Handle<Embedding>>] {
        self.handles.as_slice()
    }

    /// Node id of the top-layer entry point, if the index is non-empty.
    pub fn entry_point(&self) -> Option<u32> {
        self.entry_point
    }

    /// Pairs the index with a blob store so searches can fetch vectors on
    /// demand. The query-time candidate width defaults to 200.
    pub fn attach<'a, B>(&'a self, store: &B) -> AttachedHNSWIndex<'a, B>
    where
        B: triblespace_core::repo::BlobStoreGet + Clone,
    {
        AttachedHNSWIndex {
            index: self,
            cache: triblespace_core::blob::BlobCache::new(store.clone()),
            ef_search: 200,
        }
    }

    /// Rough estimate of the serialized size in bytes, broken into a fixed
    /// header, per-node handle/level/offset costs, and the edge lists.
    pub fn byte_size(&self) -> usize {
        let n = self.nodes.len();
        let total_neighbours: usize = self
            .nodes
            .iter()
            .flat_map(|node| node.neighbors.iter())
            .map(|layer| layer.len())
            .sum();
        let entries_per_node = self.max_level as usize + 2;
        let header = 24;
        let handle_bytes = n * 32;
        let level_bytes = n;
        let offset_bytes = n * entries_per_node * 4;
        let edge_bytes = total_neighbours * 4;
        header + handle_bytes + level_bytes + offset_bytes + edge_bytes
    }
}
/// An [`HNSWIndex`] paired with a caching blob-store reader so queries can
/// fetch embedding vectors on demand.
#[doc(hidden)]
pub struct AttachedHNSWIndex<'a, B>
where
    B: triblespace_core::repo::BlobStoreGet,
{
    /// The graph being searched.
    index: &'a HNSWIndex,
    /// Memoizing reader for embedding blobs.
    cache: triblespace_core::blob::BlobCache<B, Embedding, anybytes::View<[f32]>>,
    /// Query-time candidate-list width (tunable via `with_ef_search`).
    ef_search: usize,
}
impl<'a, B> AttachedHNSWIndex<'a, B>
where
    B: triblespace_core::repo::BlobStoreGet,
{
    /// The underlying graph.
    pub fn index(&self) -> &HNSWIndex {
        self.index
    }

    /// Sets the query-time candidate-list width; larger values trade speed
    /// for recall.
    pub fn with_ef_search(mut self, ef: usize) -> Self {
        self.ef_search = ef;
        self
    }

    /// Constraint relating two embedding variables by cosine similarity.
    pub fn similar(
        &self,
        a: Variable<EmbHandle>,
        b: Variable<EmbHandle>,
        score_floor: f32,
    ) -> crate::constraint::Similar<'_, Self> {
        crate::constraint::Similar::new(self, a, b, score_floor)
    }

    /// Constraint binding `var` to embeddings whose similarity to `probe`
    /// reaches `score_floor`. Blob-store failures degrade to an empty
    /// candidate set rather than erroring.
    pub fn similar_to(
        &self,
        probe: Inline<EmbHandle>,
        var: Variable<EmbHandle>,
        score_floor: f32,
    ) -> crate::constraint::SimilarTo {
        let candidates = self
            .candidates_above(probe, score_floor)
            .map(|v| v.into_iter().map(|h| h.raw).collect())
            .unwrap_or_default();
        crate::constraint::SimilarTo::from_candidates(var, candidates)
    }

    /// Runs an HNSW query seeded by the stored embedding `from_handle` and
    /// returns handles whose cosine similarity is at least `score_floor`.
    /// A probe of the wrong dimensionality or an empty index yields no
    /// candidates.
    ///
    /// # Errors
    /// Propagates blob-store read failures.
    #[doc(hidden)]
    pub fn candidates_above(
        &self,
        from_handle: Inline<EmbHandle>,
        score_floor: f32,
    ) -> Result<Vec<Inline<EmbHandle>>, B::GetError<anybytes::view::ViewError>> {
        let Some(entry) = self.index.entry_point else {
            return Ok(Vec::new());
        };
        let from = self.cache.get(from_handle)?;
        let query: Vec<f32> = from.as_ref().as_ref().to_vec();
        if query.len() != self.index.dim {
            return Ok(Vec::new());
        }
        // Greedy descent from the top layer down to layer 1, then a full
        // beam search on the base layer.
        let mut curr = entry;
        for lvl in (1..=self.index.max_level).rev() {
            curr = self.greedy_search_layer(&query, curr, lvl)?;
        }
        let candidates = self.search_layer(&query, curr, self.ef_search, 0)?;
        Ok(candidates
            .into_iter()
            // Cosine similarity = 1 - cosine distance for unit vectors.
            .filter(|(_, dist)| 1.0 - dist >= score_floor)
            .map(|(i, _)| self.index.handles[i as usize])
            .collect())
    }

    /// Cosine distance between `q` and node `i`'s stored embedding.
    fn dist_to(
        &self,
        q: &[f32],
        i: u32,
    ) -> Result<f32, B::GetError<anybytes::view::ViewError>> {
        let handle = self.index.handles[i as usize];
        let view = self.cache.get(handle)?;
        Ok(cosine_dist(q, view.as_ref().as_ref()))
    }

    /// Hill-climbs on `layer`: repeatedly moves to the closest neighbor of
    /// the current node until no neighbor improves on it.
    fn greedy_search_layer(
        &self,
        q: &[f32],
        entry: u32,
        layer: u8,
    ) -> Result<u32, B::GetError<anybytes::view::ViewError>> {
        let mut curr = entry;
        let mut curr_dist = self.dist_to(q, curr)?;
        loop {
            let mut changed = false;
            let node = &self.index.nodes[curr as usize];
            let Some(neigh) = node.neighbors.get(layer as usize) else {
                return Ok(curr);
            };
            // Iterate the adjacency list by shared reference: `dist_to` only
            // needs `&self`, so the previous per-hop `Vec` clone here was
            // pure overhead.
            for &n in neigh {
                let d = self.dist_to(q, n)?;
                if d < curr_dist {
                    curr_dist = d;
                    curr = n;
                    changed = true;
                }
            }
            if !changed {
                return Ok(curr);
            }
        }
    }

    /// Best-first beam search on `layer`, keeping up to `ef` closest nodes.
    /// Returned `(node, distance)` pairs are in arbitrary order.
    fn search_layer(
        &self,
        q: &[f32],
        entry: u32,
        ef: usize,
        layer: u8,
    ) -> Result<Vec<(u32, f32)>, B::GetError<anybytes::view::ViewError>> {
        use std::collections::BinaryHeap;
        let mut visited: std::collections::HashSet<u32> = std::collections::HashSet::new();
        visited.insert(entry);
        let d0 = self.dist_to(q, entry)?;
        // `candidates` pops closest-first; `results` pops farthest-first so
        // the worst survivor can be evicted once `ef` is exceeded.
        let mut candidates: BinaryHeap<MinDist> = BinaryHeap::new();
        candidates.push(MinDist {
            idx: entry,
            dist: d0,
        });
        let mut results: BinaryHeap<MaxDist> = BinaryHeap::new();
        results.push(MaxDist {
            idx: entry,
            dist: d0,
        });
        while let Some(c) = candidates.pop() {
            let farthest = results.peek().map(|r| r.dist).unwrap_or(f32::INFINITY);
            if c.dist > farthest && results.len() >= ef {
                break;
            }
            let node = &self.index.nodes[c.idx as usize];
            // Borrow the adjacency list directly instead of cloning it — the
            // loop body below only takes `&self`.
            let Some(neigh) = node.neighbors.get(layer as usize) else {
                continue;
            };
            for &n in neigh {
                if !visited.insert(n) {
                    continue;
                }
                let d = self.dist_to(q, n)?;
                let farthest = results.peek().map(|r| r.dist).unwrap_or(f32::INFINITY);
                if d < farthest || results.len() < ef {
                    candidates.push(MinDist { idx: n, dist: d });
                    results.push(MaxDist { idx: n, dist: d });
                    if results.len() > ef {
                        results.pop();
                    }
                }
            }
        }
        Ok(results.into_iter().map(|m| (m.idx, m.dist)).collect())
    }
}
/// Cosine distance between two equal-length, unit-norm vectors: `1 - a·b`.
/// For normalized inputs the dot product alone is the cosine similarity, so
/// the result is 0 for identical directions and up to 2 for opposites.
pub(crate) fn cosine_dist(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let similarity: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    1.0 - similarity
}
/// Heap entry ordered so the *smallest* distance is popped first from a
/// `BinaryHeap` (which is a max-heap): `Ord` is the reverse of the natural
/// ordering on `dist`.
#[derive(Clone, Copy)]
struct MinDist {
    idx: u32,
    dist: f32,
}

impl PartialEq for MinDist {
    fn eq(&self, other: &Self) -> bool {
        self.dist == other.dist
    }
}

// Distances here come from dot products of finite vectors; incomparable
// (NaN) values collapse to `Equal`, which is all the heap requires.
impl Eq for MinDist {}

impl Ord for MinDist {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Reversed comparison: a smaller `dist` sorts as "greater" so the
        // max-heap surfaces the closest candidate first.
        other
            .dist
            .partial_cmp(&self.dist)
            .unwrap_or(std::cmp::Ordering::Equal)
    }
}

impl PartialOrd for MinDist {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
/// Heap entry using the natural ordering on `dist`, so a `BinaryHeap` pops
/// the *largest* distance first — used to evict the worst result once the
/// result set is full.
#[derive(Clone, Copy)]
struct MaxDist {
    idx: u32,
    dist: f32,
}

impl PartialEq for MaxDist {
    fn eq(&self, other: &Self) -> bool {
        self.dist == other.dist
    }
}

// Incomparable (NaN) distances collapse to `Equal`; see `Ord` below.
impl Eq for MaxDist {}

impl Ord for MaxDist {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.dist
            .partial_cmp(&other.dist)
            .unwrap_or(std::cmp::Ordering::Equal)
    }
}

impl PartialOrd for MaxDist {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
/// Error returned when an inserted vector's length does not match the
/// dimensionality the index was configured with.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DimMismatch {
    /// Dimensionality the index expects.
    pub expected: usize,
    /// Dimensionality of the rejected vector.
    pub got: usize,
}

impl std::fmt::Display for DimMismatch {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let DimMismatch { expected, got } = self;
        write!(
            f,
            "embedding dimensionality mismatch: expected {expected}, got {got}"
        )
    }
}

impl std::error::Error for DimMismatch {}
/// Scales `v` in place to unit L2 norm. An all-zero vector is left untouched:
/// it has no direction, and dividing by zero would fill it with NaNs.
pub(crate) fn normalize(v: &mut [f32]) {
    let norm_sq = v.iter().map(|&x| x * x).sum::<f32>();
    if norm_sq > 0.0 {
        let inv = norm_sq.sqrt().recip();
        v.iter_mut().for_each(|x| *x *= inv);
    }
}
/// Inner product of two equal-length slices (length checked in debug builds).
fn dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let mut acc = 0.0f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        acc += x * y;
    }
    acc
}
/// Accumulates embedding handles for a brute-force (exact) similarity index.
#[doc(hidden)]
pub struct FlatBuilder {
    /// Dimensionality every indexed embedding is expected to have.
    dim: usize,
    /// Handles of the inserted embeddings, in insertion order.
    handles: Vec<Inline<Handle<Embedding>>>,
}
impl FlatBuilder {
pub fn new(dim: usize) -> Self {
assert!(dim > 0, "FlatBuilder: dim must be > 0");
Self {
dim,
handles: Vec::new(),
}
}
pub fn insert(&mut self, handle: Inline<Handle<Embedding>>) {
self.handles.push(handle);
}
pub fn build(self) -> FlatIndex {
FlatIndex {
dim: self.dim,
handles: self.handles,
}
}
pub fn len(&self) -> usize {
self.handles.len()
}
pub fn is_empty(&self) -> bool {
self.handles.is_empty()
}
pub fn dim(&self) -> usize {
self.dim
}
}
/// Brute-force similarity index: every query scans all stored embeddings.
/// Serves as the exact baseline for the approximate HNSW index.
#[doc(hidden)]
#[derive(Debug, Clone)]
pub struct FlatIndex {
    /// Dimensionality of the indexed embeddings.
    dim: usize,
    /// Handles of all indexed embeddings, in insertion order.
    handles: Vec<Inline<Handle<Embedding>>>,
}
impl FlatIndex {
    /// Dimensionality of the indexed embeddings.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Number of indexed embeddings.
    pub fn doc_count(&self) -> usize {
        self.handles.len()
    }

    /// All indexed embedding handles, in insertion order.
    pub fn handles(&self) -> &[Inline<Handle<Embedding>>] {
        self.handles.as_slice()
    }

    /// Pairs the index with a blob store so queries can fetch vectors.
    pub fn attach<'a, B>(&'a self, store: &B) -> AttachedFlatIndex<'a, B>
    where
        B: triblespace_core::repo::BlobStoreGet + Clone,
    {
        let cache = triblespace_core::blob::BlobCache::new(store.clone());
        AttachedFlatIndex { index: self, cache }
    }
}
/// A [`FlatIndex`] paired with a caching blob-store reader so queries can
/// fetch embedding vectors on demand.
#[doc(hidden)]
pub struct AttachedFlatIndex<'a, B>
where
    B: triblespace_core::repo::BlobStoreGet,
{
    /// The index being scanned.
    index: &'a FlatIndex,
    /// Memoizing reader for embedding blobs.
    cache: triblespace_core::blob::BlobCache<B, Embedding, anybytes::View<[f32]>>,
}
impl<'a, B> AttachedFlatIndex<'a, B>
where
    B: triblespace_core::repo::BlobStoreGet,
{
    /// The underlying index.
    pub fn index(&self) -> &FlatIndex {
        self.index
    }

    /// Constraint relating two embedding variables by cosine similarity.
    pub fn similar(
        &self,
        a: Variable<EmbHandle>,
        b: Variable<EmbHandle>,
        score_floor: f32,
    ) -> crate::constraint::Similar<'_, Self> {
        crate::constraint::Similar::new(self, a, b, score_floor)
    }

    /// Constraint binding `var` to embeddings whose similarity to `probe`
    /// reaches `score_floor`. Blob-store failures degrade to an empty
    /// candidate set rather than erroring.
    pub fn similar_to(
        &self,
        probe: Inline<EmbHandle>,
        var: Variable<EmbHandle>,
        score_floor: f32,
    ) -> crate::constraint::SimilarTo {
        let candidates = match self.candidates_above(probe, score_floor) {
            Ok(hits) => hits.into_iter().map(|h| h.raw).collect(),
            Err(_) => Default::default(),
        };
        crate::constraint::SimilarTo::from_candidates(var, candidates)
    }

    /// Exhaustively scans every stored embedding and returns the handles
    /// whose dot product with `from_handle`'s vector is at least
    /// `score_floor`. A probe of the wrong dimensionality yields no hits.
    ///
    /// # Errors
    /// Propagates blob-store read failures.
    #[doc(hidden)]
    pub fn candidates_above(
        &self,
        from_handle: Inline<EmbHandle>,
        score_floor: f32,
    ) -> Result<Vec<Inline<EmbHandle>>, B::GetError<anybytes::view::ViewError>> {
        let probe_view = self.cache.get(from_handle)?;
        let probe: &[f32] = probe_view.as_ref().as_ref();
        if probe.len() != self.index.dim {
            return Ok(Vec::new());
        }
        let mut hits = Vec::new();
        for &candidate in self.index.handles.iter() {
            let view = self.cache.get(candidate)?;
            if dot(probe, view.as_ref().as_ref()) >= score_floor {
                hits.push(candidate);
            }
        }
        Ok(hits)
    }
}
impl FlatIndex {
    /// Rough serialized-size estimate: a fixed header plus 32 bytes per
    /// stored handle.
    pub fn byte_size(&self) -> usize {
        const HEADER_BYTES: usize = 24;
        const HANDLE_BYTES: usize = 32;
        HEADER_BYTES + HANDLE_BYTES * self.handles.len()
    }
}
impl<'a, B> crate::constraint::SimilaritySearch for AttachedHNSWIndex<'a, B>
where
    B: triblespace_core::repo::BlobStoreGet,
{
    /// Approximate neighbours above `score_floor`; store failures degrade to
    /// an empty result rather than erroring.
    fn neighbours_above(
        &self,
        from: Inline<Handle<Embedding>>,
        score_floor: f32,
    ) -> Vec<Inline<Handle<Embedding>>> {
        self.candidates_above(from, score_floor).unwrap_or_default()
    }

    /// Dot product of the two stored vectors (cosine for unit vectors), or
    /// `None` when either fetch fails or the lengths differ.
    fn cosine_between(
        &self,
        a: Inline<Handle<Embedding>>,
        b: Inline<Handle<Embedding>>,
    ) -> Option<f32> {
        let va = self.cache.get(a).ok()?;
        let vb = self.cache.get(b).ok()?;
        let xs: &[f32] = va.as_ref().as_ref();
        let ys: &[f32] = vb.as_ref().as_ref();
        (xs.len() == ys.len()).then(|| dot(xs, ys))
    }
}
impl<'a, B> crate::constraint::SimilaritySearch for AttachedFlatIndex<'a, B>
where
    B: triblespace_core::repo::BlobStoreGet,
{
    /// Exact neighbours above `score_floor`; store failures degrade to an
    /// empty result rather than erroring.
    fn neighbours_above(
        &self,
        from: Inline<Handle<Embedding>>,
        score_floor: f32,
    ) -> Vec<Inline<Handle<Embedding>>> {
        self.candidates_above(from, score_floor).unwrap_or_default()
    }

    /// Dot product of the two stored vectors (cosine for unit vectors), or
    /// `None` when either fetch fails or the lengths differ.
    fn cosine_between(
        &self,
        a: Inline<Handle<Embedding>>,
        b: Inline<Handle<Embedding>>,
    ) -> Option<f32> {
        let va = self.cache.get(a).ok()?;
        let vb = self.cache.get(b).ok()?;
        let xs: &[f32] = va.as_ref().as_ref();
        let ys: &[f32] = vb.as_ref().as_ref();
        (xs.len() == ys.len()).then(|| dot(xs, ys))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use triblespace_core::blob::MemoryBlobStore;
    use triblespace_core::repo::BlobStore;

    /// Stores `vec` as an embedding blob and returns its handle.
    fn put_emb(
        store: &mut MemoryBlobStore,
        vec: Vec<f32>,
    ) -> Inline<Handle<Embedding>> {
        crate::schemas::put_embedding::<_>(store, vec).unwrap()
    }

    /// Builds a flat index over `vecs`, returning the index, its backing
    /// store, and the handles in insertion order.
    fn build_flat(
        dim: usize,
        vecs: &[Vec<f32>],
    ) -> (
        FlatIndex,
        MemoryBlobStore,
        Vec<Inline<Handle<Embedding>>>,
    ) {
        let mut store = MemoryBlobStore::new();
        let mut b = FlatBuilder::new(dim);
        let mut handles = Vec::with_capacity(vecs.len());
        for v in vecs {
            let h = put_emb(&mut store, v.clone());
            b.insert(h);
            handles.push(h);
        }
        (b.build(), store, handles)
    }

    /// Convenience wrapper producing a read handle for `store`.
    fn reader_of(
        store: &mut MemoryBlobStore,
    ) -> <MemoryBlobStore as BlobStore>::Reader {
        store.reader().unwrap()
    }

    #[test]
    fn flat_exact_match_includes_self_at_cos_one() {
        // Orthogonal basis vectors: only the probe itself scores ~1.0.
        let (idx, mut store, handles) = build_flat(
            3,
            &[
                vec![1.0, 0.0, 0.0],
                vec![0.0, 1.0, 0.0],
                vec![0.0, 0.0, 1.0],
            ],
        );
        let hits = idx
            .attach(&reader_of(&mut store))
            .candidates_above(handles[0], 0.999)
            .unwrap();
        assert_eq!(hits, vec![handles[0]]);
    }

    #[test]
    fn flat_threshold_selects_near_matches() {
        // The 0.8 floor keeps the probe and its near neighbour but drops
        // the orthogonal vector.
        let (idx, mut store, handles) = build_flat(
            2,
            &[
                vec![1.0, 0.0],
                vec![0.9, 0.1],
                vec![0.0, 1.0],
            ],
        );
        let got: std::collections::HashSet<_> = idx
            .attach(&reader_of(&mut store))
            .candidates_above(handles[0], 0.8)
            .unwrap()
            .into_iter()
            .collect();
        assert!(got.contains(&handles[0]));
        assert!(got.contains(&handles[1]));
        assert!(!got.contains(&handles[2]));
    }

    #[test]
    fn flat_parallel_inputs_dedupe_at_put() {
        // Both inputs point the same direction, so after normalization at
        // put time they hash to the same blob handle.
        let (_idx, _store, handles) = build_flat(
            2,
            &[vec![3.0, 0.0], vec![100.0, 0.0]],
        );
        assert_eq!(handles[0], handles[1]);
    }

    #[test]
    fn flat_empty_index_has_no_candidates() {
        let mut store = MemoryBlobStore::new();
        let idx = FlatBuilder::new(4).build();
        let probe = put_emb(&mut store, vec![1.0, 0.0, 0.0, 0.0]);
        let reader = store.reader().unwrap();
        assert!(idx.attach(&reader).candidates_above(probe, 0.0).unwrap().is_empty());
    }

    /// Small three-vector flat fixture shared by size tests.
    fn sample_flat() -> (
        FlatIndex,
        MemoryBlobStore,
        Vec<Inline<Handle<Embedding>>>,
    ) {
        build_flat(
            3,
            &[
                vec![1.0, 0.0, 0.0],
                vec![0.0, 1.0, 0.0],
                vec![0.5, 0.5, 0.0],
            ],
        )
    }

    #[test]
    fn flat_byte_size_matches_formula() {
        let (idx, _, _) = sample_flat();
        assert_eq!(idx.byte_size(), 24 + idx.doc_count() * 32);
    }

    /// Builds a seeded HNSW index over `vecs`, returning the succinct index,
    /// its backing store, and the handles in insertion order.
    fn build_hnsw(
        dim: usize,
        seed: u64,
        vecs: &[Vec<f32>],
    ) -> (
        crate::succinct::SuccinctHNSWIndex,
        MemoryBlobStore,
        Vec<Inline<Handle<Embedding>>>,
    ) {
        let mut store = MemoryBlobStore::new();
        let mut b = HNSWBuilder::new(dim).with_seed(seed);
        let mut handles = Vec::with_capacity(vecs.len());
        for v in vecs {
            let h = put_emb(&mut store, v.clone());
            b.insert(h, v.clone()).unwrap();
            handles.push(h);
        }
        (b.build(), store, handles)
    }

    #[test]
    fn hnsw_empty_index_has_no_candidates() {
        let mut store = MemoryBlobStore::new();
        let idx = HNSWBuilder::new(4).build();
        assert_eq!(idx.doc_count(), 0);
        let probe = put_emb(&mut store, vec![1.0, 0.0, 0.0, 0.0]);
        let reader = store.reader().unwrap();
        assert!(idx
            .attach(&reader)
            .candidates_above(probe, 0.0)
            .unwrap()
            .is_empty());
    }

    #[test]
    fn hnsw_single_doc_returns_itself() {
        let (idx, mut store, handles) = build_hnsw(3, 42, &[vec![1.0, 0.0, 0.0]]);
        let hits = idx
            .attach(&reader_of(&mut store))
            .candidates_above(handles[0], 0.999)
            .unwrap();
        assert_eq!(hits, vec![handles[0]]);
    }

    #[test]
    fn hnsw_threshold_excludes_orthogonal() {
        // Mirrors flat_threshold_selects_near_matches on the HNSW path.
        let (idx, mut store, handles) = build_hnsw(
            2,
            42,
            &[vec![1.0, 0.0], vec![0.9, 0.1], vec![0.0, 1.0]],
        );
        let got: std::collections::HashSet<_> = idx
            .attach(&reader_of(&mut store))
            .candidates_above(handles[0], 0.8)
            .unwrap()
            .into_iter()
            .collect();
        assert!(got.contains(&handles[0]));
        assert!(got.contains(&handles[1]));
        assert!(!got.contains(&handles[2]));
    }

    #[test]
    fn hnsw_threshold_recall_matches_flat_on_small_corpus() {
        // Deterministic SplitMix64-style generator for reproducible vectors.
        let mut rng = 0xBABE_u64;
        let next = |r: &mut u64| {
            *r = r.wrapping_add(0x9E3779B97F4A7C15);
            let mut z = *r;
            z = (z ^ (z >> 30)).wrapping_mul(0xBF58476D1CE4E5B9);
            z = (z ^ (z >> 27)).wrapping_mul(0x94D049BB133111EB);
            z ^ (z >> 31)
        };
        let dim = 16;
        let vecs: Vec<Vec<f32>> = (0..200)
            .map(|_| {
                (0..dim)
                    .map(|_| (next(&mut rng) as i32 as f32) / (i32::MAX as f32))
                    .collect()
            })
            .collect();
        let (flat, mut fstore, fhandles) = build_flat(dim, &vecs);
        let (hnsw, mut hstore, hhandles) = build_hnsw(dim, 42, &vecs);
        assert_eq!(fhandles, hhandles);
        let freader = fstore.reader().unwrap();
        let hreader = hstore.reader().unwrap();
        let hnsw_view = hnsw.attach(&hreader).with_ef_search(50);
        let flat_view = flat.attach(&freader);
        let floor = 0.6;
        // Compare approximate hits against the exact flat scan and require
        // at least 70% recall over a handful of probes.
        let mut total_hits = 0usize;
        let mut total_overlap = 0usize;
        for probe in fhandles.iter().take(5) {
            let truth: std::collections::HashSet<_> =
                flat_view.candidates_above(*probe, floor).unwrap().into_iter().collect();
            let got: std::collections::HashSet<_> =
                hnsw_view.candidates_above(*probe, floor).unwrap().into_iter().collect();
            total_hits += truth.len();
            total_overlap += truth.intersection(&got).count();
        }
        assert!(total_hits > 0, "test fixture: floor excluded everything");
        let recall = total_overlap as f32 / total_hits as f32;
        assert!(recall >= 0.7, "HNSW recall {recall:.2} below 0.7 threshold");
    }

    #[test]
    fn hnsw_deterministic_seed_reproduces_structure() {
        // Identical inputs and seeds must produce identical graphs/results.
        let vecs: Vec<Vec<f32>> = (1u8..=20)
            .map(|i| {
                vec![
                    (i as f32) / 20.0,
                    ((i as f32) * 2.0) % 1.0,
                    ((i as f32) * 3.0) % 1.0,
                ]
            })
            .collect();
        let (a, mut a_store, a_handles) = build_hnsw(3, 123, &vecs);
        let (b, mut b_store, b_handles) = build_hnsw(3, 123, &vecs);
        assert_eq!(a.doc_count(), b.doc_count());
        assert_eq!(a.max_level(), b.max_level());
        assert_eq!(a_handles, b_handles);
        let ra = a
            .attach(&a_store.reader().unwrap())
            .candidates_above(a_handles[0], 0.5)
            .unwrap();
        let rb = b
            .attach(&b_store.reader().unwrap())
            .candidates_above(b_handles[0], 0.5)
            .unwrap();
        assert_eq!(ra, rb);
    }

    #[test]
    fn hnsw_dim_mismatch_rejected_at_insert() {
        let mut store = MemoryBlobStore::new();
        let mut b = HNSWBuilder::new(3);
        let h = put_emb(&mut store, vec![1.0, 0.0]);
        let err = b.insert(h, vec![1.0, 0.0]).unwrap_err();
        assert_eq!(err.expected, 3);
        assert_eq!(err.got, 2);
    }

    /// Small four-vector naive HNSW fixture for size tests.
    fn sample_hnsw() -> (
        HNSWIndex,
        MemoryBlobStore,
        Vec<Inline<Handle<Embedding>>>,
    ) {
        let mut store = MemoryBlobStore::new();
        let mut b = HNSWBuilder::new(3).with_seed(42);
        let vecs = [
            vec![1.0f32, 0.0, 0.0],
            vec![0.9, 0.1, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];
        let mut handles = Vec::with_capacity(vecs.len());
        for v in &vecs {
            let h = put_emb(&mut store, v.clone());
            b.insert(h, v.clone()).unwrap();
            handles.push(h);
        }
        (b.build_naive(), store, handles)
    }

    #[test]
    fn hnsw_byte_size_positive_and_growing() {
        let (idx, _, _) = sample_hnsw();
        let small = idx.byte_size();
        assert!(small > 0);
        // A five-node index must report a larger estimate than four nodes.
        let vecs = [
            vec![1.0f32, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.5, 0.5, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![0.2, 0.3, 0.5],
        ];
        let mut store = MemoryBlobStore::new();
        let mut b = HNSWBuilder::new(3).with_seed(19);
        for v in &vecs {
            let h = put_emb(&mut store, v.clone());
            b.insert(h, v.clone()).unwrap();
        }
        let larger = b.build_naive().byte_size();
        assert!(larger > small);
    }
}