use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashSet};
use std::sync::atomic::{AtomicU64, Ordering as AtomicOrdering};
use crate::core::LuciError;
use crate::mapping::QuantizationType;
use parking_lot::RwLock;
use rayon::ThreadPoolBuilder;
use rayon::current_num_threads;
use rayon::prelude::*;
use super::{DistanceMetric, normalize_in_place};
/// Entry-point id stored in the packed `entry` word while the graph has no
/// entry point yet (see `pack_entry` / `unpack_entry`).
const ENTRY_SENTINEL: u32 = u32::MAX;
/// Packs the entry-point id (high 32 bits) and the maximum level (low 32
/// bits) into one `u64` so both can live in a single atomic word.
fn pack_entry(entry_point: u32, max_level: u32) -> u64 {
    let high = u64::from(entry_point) << 32;
    let low = u64::from(max_level);
    high | low
}
/// Inverse of `pack_entry`: returns `(entry_point, max_level)` taken from the
/// high and low 32 bits of `packed`.
fn unpack_entry(packed: u64) -> (u32, u32) {
    let entry_point = (packed >> 32) as u32;
    let max_level = (packed & 0xFFFF_FFFF) as u32;
    (entry_point, max_level)
}
/// Thread policy for `HnswBuilder::connect_pending`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BuildThreads {
/// Use whatever rayon pool is current; `connect_pending` falls back to the
/// sequential path when `current_num_threads()` reports one thread or fewer.
Ambient,
/// Build a dedicated rayon pool with this many threads. `Fixed(1)` takes
/// the sequential path without creating a pool.
Fixed(usize),
}
/// Four-byte magic that `to_bytes` writes first and `read_header` uses to
/// recognise the versioned layout.
pub const HNSW_FORMAT_MAGIC: [u8; 4] = *b"LHNS";
/// Version byte written directly after the magic; `read_header` rejects any
/// other value when the magic is present.
pub const HNSW_FORMAT_VERSION: u8 = 2;
/// On-disk layout generation as detected by `read_header`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum HnswFormatVersion {
/// Blob without the magic prefix; the header starts at byte 0.
V1,
/// Blob starting with `HNSW_FORMAT_MAGIC` followed by the version byte.
V2,
}
/// Fixed-size header decoded by `read_header` from the front of an HNSW blob.
#[derive(Clone, Copy, Debug)]
pub struct HnswHeader {
/// Layout generation detected from the magic prefix.
pub version: HnswFormatVersion,
/// Vector dimensionality (stored as a little-endian `u32`).
pub dims: usize,
/// HNSW `m` parameter (stored as a little-endian `u32`).
pub m: usize,
/// Distance metric decoded from a single byte via `DistanceMetric::from_byte`.
pub metric: DistanceMetric,
/// Number of vectors (and graph nodes) in the blob.
pub num_vectors: usize,
/// Entry-point ordinal, or `None` when the stored value is `u32::MAX`.
pub entry_point: Option<u32>,
/// Highest graph level recorded in the blob.
pub max_level: usize,
/// Byte offset of the first vector, i.e. the position right after the header.
pub vectors_offset: usize,
}
/// Returns the next `n` bytes of `data` starting at `*pos` and advances `*pos`
/// past them.
///
/// # Errors
/// `LuciError::IndexCorrupted` when `*pos + n` overflows or runs past the end
/// of `data`; `*pos` is left untouched in that case.
pub(crate) fn take_bytes<'a>(
    data: &'a [u8],
    pos: &mut usize,
    n: usize,
) -> Result<&'a [u8], LuciError> {
    let start = *pos;
    match start.checked_add(n) {
        Some(end) if end <= data.len() => {
            *pos = end;
            Ok(&data[start..end])
        }
        _ => Err(LuciError::IndexCorrupted(format!(
            "vector index blob truncated: need {n} bytes at offset {start}, have {} total",
            data.len()
        ))),
    }
}
/// Reads a little-endian `u32` at `*pos` and advances `*pos` by four bytes.
///
/// # Errors
/// Propagates the truncation error from `take_bytes`.
pub(crate) fn read_u32(data: &[u8], pos: &mut usize) -> Result<u32, LuciError> {
    let raw = take_bytes(data, pos, 4)?;
    let mut le = [0u8; 4];
    // `take_bytes` returned exactly four bytes, so the lengths match.
    le.copy_from_slice(raw);
    Ok(u32::from_le_bytes(le))
}
/// Reads a little-endian `u64` at `*pos` and advances `*pos` by eight bytes.
///
/// # Errors
/// Propagates the truncation error from `take_bytes`.
pub(crate) fn read_u64(data: &[u8], pos: &mut usize) -> Result<u64, LuciError> {
    let raw = take_bytes(data, pos, 8)?;
    let mut le = [0u8; 8];
    // `take_bytes` returned exactly eight bytes, so the lengths match.
    le.copy_from_slice(raw);
    Ok(u64::from_le_bytes(le))
}
/// Reads a single byte at `*pos` and advances `*pos` by one.
///
/// # Errors
/// Propagates the truncation error from `take_bytes`.
fn read_u8(data: &[u8], pos: &mut usize) -> Result<u8, LuciError> {
    take_bytes(data, pos, 1).map(|bytes| bytes[0])
}
/// Reads a little-endian `f32` at `*pos` and advances `*pos` by four bytes.
///
/// # Errors
/// Propagates the truncation error from `take_bytes`.
fn read_f32(data: &[u8], pos: &mut usize) -> Result<f32, LuciError> {
    let raw = take_bytes(data, pos, 4)?;
    let mut le = [0u8; 4];
    // `take_bytes` returned exactly four bytes, so the lengths match.
    le.copy_from_slice(raw);
    Ok(f32::from_le_bytes(le))
}
/// Sanity-checks a declared element count against the bytes left in `data`
/// before the caller allocates for it.
///
/// Returns `count` unchanged when `count` elements of at least
/// `min_elem_bytes` each could fit in the bytes after `pos`. A
/// `min_elem_bytes` of zero disables the check.
///
/// # Errors
/// `LuciError::IndexCorrupted` when the declared count cannot fit.
pub(crate) fn checked_len(
    count: usize,
    min_elem_bytes: usize,
    data: &[u8],
    pos: usize,
) -> Result<usize, LuciError> {
    let remaining = data.len().saturating_sub(pos);
    let fits = min_elem_bytes == 0 || count <= remaining / min_elem_bytes;
    if fits {
        return Ok(count);
    }
    Err(LuciError::IndexCorrupted(format!(
        "vector index blob declares {count} elements (≥{min_elem_bytes} B each) \
         but only {remaining} bytes remain"
    )))
}
/// Decodes the fixed header at the front of an HNSW blob.
///
/// A blob that starts with `HNSW_FORMAT_MAGIC` is treated as the versioned
/// layout (header begins at byte 5); anything else is parsed as the legacy
/// layout starting at byte 0.
///
/// # Errors
/// * `SegmentFormatUnknown` when the magic is present but the version byte is
///   not `HNSW_FORMAT_VERSION`.
/// * `SegmentFormatMigrationRequired` for a legacy blob using the cosine metric.
/// * `IndexCorrupted` (via the read helpers) when the header is truncated.
pub fn read_header(data: &[u8]) -> Result<HnswHeader, LuciError> {
    let has_magic = data.len() >= 5 && data[0..4] == HNSW_FORMAT_MAGIC;
    let version;
    let mut pos;
    if has_magic {
        let v = data[4];
        if v != HNSW_FORMAT_VERSION {
            return Err(LuciError::SegmentFormatUnknown(format!(
                "unknown HNSW format version: {v}",
            )));
        }
        version = HnswFormatVersion::V2;
        pos = 5_usize;
    } else {
        version = HnswFormatVersion::V1;
        pos = 0_usize;
    }

    let dims = read_u32(data, &mut pos)? as usize;
    let m = read_u32(data, &mut pos)? as usize;
    let metric_byte = read_u8(data, &mut pos)?;
    let metric = DistanceMetric::from_byte(metric_byte);

    // Legacy cosine segments cannot be served by the current kernel.
    let legacy_cosine = version == HnswFormatVersion::V1 && metric == DistanceMetric::Cosine;
    if legacy_cosine {
        return Err(LuciError::SegmentFormatMigrationRequired(
            "cosine HNSW segment was built with Luci ≤ 0.7.1 which stored \
             raw vectors. The v0.7.2+ kernel requires unit-length vectors \
             on disk. Re-index this collection."
                .into(),
        ));
    }

    let num_vectors = read_u32(data, &mut pos)? as usize;
    let entry_point = match read_u32(data, &mut pos)? {
        u32::MAX => None,
        ep => Some(ep),
    };
    let max_level = read_u32(data, &mut pos)? as usize;

    Ok(HnswHeader {
        version,
        dims,
        m,
        metric,
        num_vectors,
        entry_point,
        max_level,
        vectors_offset: pos,
    })
}
/// Construction and search parameters shared by the builder, the in-memory
/// index and the parsed graph.
#[derive(Clone, Debug)]
pub struct HnswParams {
/// Vector dimensionality.
pub dims: usize,
// `m`: neighbour budget per node on upper layers (layer 0 uses `2 * m`, see `m_max`).
// `ef_construction`: candidate list size used while connecting nodes.
// `metric`: distance metric; cosine vectors are normalised before being stored.
pub m: usize, pub ef_construction: usize, pub metric: DistanceMetric,
/// Quantization mode; the builder accepts only `None` and `Int8`.
pub quantization: QuantizationType,
}
impl Default for HnswParams {
fn default() -> Self {
Self {
dims: 128,
m: 16,
ef_construction: 100,
metric: DistanceMetric::Cosine,
quantization: QuantizationType::DEFAULT,
}
}
}
/// A node id paired with its distance to the current query.
///
/// The ordering is reversed on `dist`, so a `BinaryHeap<Candidate>` pops the
/// closest candidate first. Equality and ordering look at `dist` only;
/// incomparable distances (NaN) are treated as equal.
#[derive(Clone, Copy)]
struct Candidate {
    id: u32,
    dist: f32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        self.dist == other.dist
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // Smaller distance ranks higher; NaN comparisons collapse to `Equal`.
        match other.dist.partial_cmp(&self.dist) {
            Some(ordering) => ordering,
            None => Ordering::Equal,
        }
    }
}
/// Wrapper that orders candidates by ascending `dist`, so a
/// `BinaryHeap<FurthestCandidate>` keeps the furthest result on top.
/// Incomparable distances (NaN) are treated as equal.
struct FurthestCandidate(Candidate);

impl PartialEq for FurthestCandidate {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.0.dist, other.0.dist);
        lhs == rhs
    }
}

impl Eq for FurthestCandidate {}

impl PartialOrd for FurthestCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl Ord for FurthestCandidate {
    fn cmp(&self, other: &Self) -> Ordering {
        match self.0.dist.partial_cmp(&other.0.dist) {
            Some(ordering) => ordering,
            None => Ordering::Equal,
        }
    }
}
/// One graph node: its adjacency lists and the level it was assigned.
#[derive(Clone)]
pub(crate) struct Node {
// `neighbors[lev]` holds the neighbour ordinals on layer `lev`. Nodes created
// by the builder get `level + 1` (initially empty) layers.
neighbors: Vec<Vec<u32>>,
// Level drawn for this node when it was stored (or read from the blob).
level: usize,
}
/// Mutable HNSW graph under construction; `build` turns it into an `HnswIndex`.
pub struct HnswBuilder {
// Build parameters (dims, m, ef_construction, metric, quantization).
params: HnswParams,
// Stored vectors indexed by ordinal; normalised on insert when the metric is cosine.
vectors: Vec<Vec<f32>>,
// One lock-protected node per ordinal, parallel to `vectors`.
nodes: Vec<RwLock<Node>>,
// Packed `(entry_point, max_level)`; the entry point is `ENTRY_SENTINEL` while the graph is empty.
entry: AtomicU64,
// Bitmap with one bit per ordinal, set by `set_ready` and consulted by the
// `*_ready` search helpers used on the parallel connect path.
ready: Vec<AtomicU64>,
// Ordinals below this count are treated as already connected; `connect_pending`
// connects the range from here to `vectors.len()`.
connected_count: usize,
// Level normalisation factor, `1 / ln(m)`.
level_mult: f64,
// xorshift state driving `random_level`; every constructor seeds it with 42.
rng_state: u64,
}
impl Clone for HnswBuilder {
fn clone(&self) -> Self {
Self {
params: self.params.clone(),
vectors: self.vectors.clone(),
nodes: self
.nodes
.iter()
.map(|n| RwLock::new(n.read().clone()))
.collect(),
entry: AtomicU64::new(self.entry.load(AtomicOrdering::Relaxed)),
ready: self
.ready
.iter()
.map(|w| AtomicU64::new(w.load(AtomicOrdering::Relaxed)))
.collect(),
connected_count: self.connected_count,
level_mult: self.level_mult,
rng_state: self.rng_state,
}
}
}
impl HnswBuilder {
/// Creates an empty builder.
///
/// # Panics
/// When `params.quantization` is `Int4` or `Bbq`, which the builder does not
/// implement.
pub fn new(params: HnswParams) -> Self {
    match params.quantization {
        unsupported @ (QuantizationType::Int4 | QuantizationType::Bbq) => {
            panic!(
                "HnswBuilder constructed with unimplemented quantization \
                 type {unsupported:?}; the mapping parser should have \
                 rejected this at index creation. This is an internal \
                 wiring bug, not user input."
            );
        }
        QuantizationType::None | QuantizationType::Int8 => {}
    }

    let m = params.m as f64;
    let empty_entry = pack_entry(ENTRY_SENTINEL, 0);
    Self {
        level_mult: 1.0 / m.ln(),
        params,
        vectors: Vec::new(),
        nodes: Vec::new(),
        entry: AtomicU64::new(empty_entry),
        ready: Vec::new(),
        connected_count: 0,
        rng_state: 42,
    }
}
pub fn with_capacity_for_merge(params: HnswParams, capacity: usize) -> Self {
match params.quantization {
QuantizationType::None | QuantizationType::Int8 => {}
unsupported @ (QuantizationType::Int4 | QuantizationType::Bbq) => {
panic!(
"HnswBuilder::with_capacity_for_merge constructed with \
unimplemented quantization type {unsupported:?}; the \
mapping parser should have rejected this at index \
creation. This is an internal wiring bug, not user \
input."
);
}
}
let level_mult = 1.0 / (params.m as f64).ln();
let mut vectors = Vec::with_capacity(capacity);
let mut nodes = Vec::with_capacity(capacity);
for _ in 0..capacity {
vectors.push(Vec::new());
nodes.push(RwLock::new(Node {
neighbors: Vec::new(),
level: 0,
}));
}
let mut ready = Vec::with_capacity(capacity.div_ceil(64));
for _ in 0..capacity.div_ceil(64) {
ready.push(AtomicU64::new(0));
}
Self {
params,
vectors,
nodes,
entry: AtomicU64::new(pack_entry(ENTRY_SENTINEL, 0)),
ready,
connected_count: 0,
level_mult,
rng_state: 42,
}
}
/// Copies an already-built graph into this builder, remapping every ordinal
/// (vector slot, neighbour ids and entry point) through `ord_map`.
///
/// Vectors are decoded as little-endian `f32` from `hnsw_bytes`, starting at
/// `seed.vector_data_offset`. The builder's entry word is overwritten with the
/// remapped entry point and the seed's `max_level`.
///
/// # Panics
/// When a remapped ordinal is outside the builder's slots or `hnsw_bytes` is
/// too short for the vector region. In debug builds also when the seed's
/// dims, metric or `m` differ from this builder's.
pub fn seed_from_graph<F>(&mut self, seed: &ParsedGraph, hnsw_bytes: &[u8], ord_map: F)
where
    F: Fn(u32) -> u32,
{
    let dims = self.params.dims;
    debug_assert_eq!(seed.params.dims, dims);
    debug_assert_eq!(seed.params.metric, self.params.metric);
    debug_assert_eq!(seed.params.m, self.params.m);

    for src_ord in 0..seed.num_vectors as u32 {
        let src_idx = src_ord as usize;
        let dst_idx = ord_map(src_ord) as usize;

        // Decode the source vector straight out of the serialized blob.
        let base = seed.vector_data_offset + src_idx * dims * 4;
        let vector: Vec<f32> = (0..dims)
            .map(|d| {
                let off = base + d * 4;
                f32::from_le_bytes(hnsw_bytes[off..off + 4].try_into().unwrap())
            })
            .collect();
        self.vectors[dst_idx] = vector;

        // Rewrite every adjacency list in merged-ordinal space.
        let source = &seed.nodes[src_idx];
        let mut remapped: Vec<Vec<u32>> = Vec::with_capacity(source.neighbors.len());
        for layer in &source.neighbors {
            remapped.push(layer.iter().map(|&n| ord_map(n)).collect());
        }
        self.nodes[dst_idx] = RwLock::new(Node {
            neighbors: remapped,
            level: source.level,
        });
    }

    let ep = match seed.entry_point {
        Some(src_ep) => ord_map(src_ep),
        None => ENTRY_SENTINEL,
    };
    let packed = pack_entry(ep, seed.max_level as u32);
    self.entry.store(packed, AtomicOrdering::Relaxed);
}
/// Advances the xorshift generator (shifts 13 / 7 / 17) and returns the new
/// state scaled into `[0, 1]`.
fn next_rand(&mut self) -> f64 {
    let mut x = self.rng_state;
    x ^= x << 13;
    x ^= x >> 7;
    x ^= x << 17;
    self.rng_state = x;
    (x as f64) / (u64::MAX as f64)
}
/// Draws a node level as `floor(-ln(r) * level_mult)`, where `r` is the next
/// random number clamped from below to `1e-10` so the logarithm stays finite.
fn random_level(&mut self) -> usize {
    let r = f64::max(self.next_rand(), 1e-10);
    let scaled = (-(r.ln())) * self.level_mult;
    scaled.floor() as usize
}
/// Appends a vector and an unconnected node for it, without linking the node
/// into the graph. `connect_pending` links such nodes later.
///
/// Cosine vectors are normalised first. The node level is drawn only after
/// normalisation succeeds, so a failed insert does not advance the RNG.
///
/// # Errors
/// Propagates the error from `normalize_in_place`.
pub fn store_vector(&mut self, mut vector: Vec<f32>) -> Result<(), LuciError> {
    debug_assert_eq!(vector.len(), self.params.dims);
    if matches!(self.params.metric, DistanceMetric::Cosine) {
        normalize_in_place(&mut vector)?;
    }

    let ord = self.vectors.len();
    let level = self.random_level();
    self.vectors.push(vector);

    // One empty adjacency list per layer 0..=level.
    let neighbors: Vec<Vec<u32>> = (0..=level).map(|_| Vec::new()).collect();
    self.nodes.push(RwLock::new(Node { neighbors, level }));

    // Grow the ready bitmap when this ordinal starts a new 64-bit word.
    let word = ord / 64;
    if word >= self.ready.len() {
        self.ready.push(AtomicU64::new(0));
    }
    Ok(())
}
/// Stores `vector` at the next ordinal and immediately connects it into the
/// graph, marks it ready and sets `connected_count` to the total number of
/// stored vectors.
///
/// # Errors
/// Propagates the error from `store_vector`.
pub fn add_vector(&mut self, vector: Vec<f32>) -> Result<(), LuciError> {
    let id = self.vectors.len() as u32;
    self.store_vector(vector)?;

    let level = {
        let node = self.nodes[id as usize].read();
        node.level
    };
    self.connect_node(id, level);
    self.set_ready(id);

    self.connected_count = self.vectors.len();
    Ok(())
}
/// Fills the pre-allocated slot `ord` with `vector`, gives it a fresh node at
/// a randomly drawn level and connects it into the graph.
///
/// Cosine vectors are normalised first. In debug builds the slot must exist
/// and still be empty.
///
/// # Errors
/// Propagates the error from `normalize_in_place`.
pub fn add_vector_at_ordinal(
    &mut self,
    ord: u32,
    mut vector: Vec<f32>,
) -> Result<(), LuciError> {
    let slot = ord as usize;
    debug_assert_eq!(vector.len(), self.params.dims);
    debug_assert!(slot < self.vectors.len());
    debug_assert!(
        self.vectors[slot].is_empty(),
        "add_vector_at_ordinal called on already-filled ordinal {ord}",
    );

    if matches!(self.params.metric, DistanceMetric::Cosine) {
        normalize_in_place(&mut vector)?;
    }

    let level = self.random_level();
    self.vectors[slot] = vector;

    // One empty adjacency list per layer 0..=level.
    let neighbors: Vec<Vec<u32>> = (0..=level).map(|_| Vec::new()).collect();
    self.nodes[slot] = RwLock::new(Node { neighbors, level });

    self.connect_node(ord, level);
    Ok(())
}
// Links node `id` (whose vector and empty adjacency lists are already stored)
// into the graph. Used on the single-threaded paths.
//
// * Empty graph: `id` simply becomes the entry point at `level`.
// * Otherwise: descend greedily from the entry point through the layers above
//   `level`, then for each layer from `min(level, max_level)` down to 0 search
//   for candidates, pick neighbours with the diversity heuristic and add
//   edges in both directions, pruning any neighbour that exceeds `m_max`.
// * If `level` is higher than the previous maximum, `id` becomes the new
//   entry point.
fn connect_node(&mut self, id: u32, level: usize) {
let (ep0, max_l0) = unpack_entry(self.entry.load(AtomicOrdering::Relaxed));
if ep0 == ENTRY_SENTINEL {
// First node in the graph: publish it as the entry point and stop.
self.entry
.store(pack_entry(id, level as u32), AtomicOrdering::Relaxed);
return;
}
let max_level = max_l0 as usize;
let mut current = ep0;
// Greedy descent through the layers the new node does not belong to.
for lev in (level + 1..=max_level).rev() {
current = self.greedy_closest(current, id, lev);
}
let insert_from = level.min(max_level);
let mut ep_for_level = current;
for lev in (0..=insert_from).rev() {
let candidates = self.search_layer(id, ep_for_level, self.params.ef_construction, lev);
let neighbors_to_connect =
self.select_neighbors_heuristic(&candidates, self.m_max(lev));
{
// Outgoing edges of the new node on this layer.
let mut node = self.nodes[id as usize].write();
if lev < node.neighbors.len() {
node.neighbors[lev] = neighbors_to_connect.iter().map(|c| c.id).collect();
}
}
for &n in &neighbors_to_connect {
// Back-edge from each selected neighbour; prune if it is now over budget.
let mut node = self.nodes[n.id as usize].write();
if lev < node.neighbors.len() {
node.neighbors[lev].push(id);
if node.neighbors[lev].len() > self.m_max(lev) {
self.prune_connections_in(&mut node, n.id, lev);
}
}
}
// `search_layer` returns candidates sorted closest-first, so the first one
// seeds the search on the next layer down.
if !candidates.is_empty() {
ep_for_level = candidates[0].id;
}
}
if level as u32 > max_l0 {
// The new node sits above every existing node: it becomes the entry point.
self.entry
.store(pack_entry(id, level as u32), AtomicOrdering::Relaxed);
}
}
/// Distance between the stored vectors of nodes `a` and `b` under the
/// builder's metric.
fn dist(&self, a: u32, b: u32) -> f32 {
    let lhs = &self.vectors[a as usize];
    let rhs = &self.vectors[b as usize];
    super::distance(lhs, rhs, self.params.metric)
}
/// Distance between the stored vector of node `a` and an arbitrary `query`
/// under the builder's metric.
fn dist_to_vec(&self, a: u32, query: &[f32]) -> f32 {
    let stored = &self.vectors[a as usize];
    super::distance(stored, query, self.params.metric)
}
/// Greedy walk on `level`: starting at `start`, keep moving to whichever
/// neighbour is strictly closer to node `target` until a full pass over the
/// current node's neighbours brings no improvement. Returns the node reached.
fn greedy_closest(&self, start: u32, target: u32, level: usize) -> u32 {
    let mut best = start;
    let mut best_dist = self.dist(best, target);
    loop {
        let mut improved = false;
        {
            // The guard covers one pass over the neighbours of the node that
            // was best when the pass started.
            let guard = self.nodes[best as usize].read();
            if let Some(layer) = guard.neighbors.get(level) {
                for &cand in layer {
                    let d = self.dist(cand, target);
                    if d < best_dist {
                        best = cand;
                        best_dist = d;
                        improved = true;
                    }
                }
            }
        }
        if !improved {
            return best;
        }
    }
}
/// Unfiltered layer search using the stored vector of `query_id` as the query.
fn search_layer(&self, query_id: u32, entry: u32, ef: usize, level: usize) -> Vec<Candidate> {
    let query = self.vectors[query_id as usize].as_slice();
    self.search_layer_vec(query, entry, ef, level, None)
}
// Best-first search on one layer, starting at `entry`.
//
// Keeps at most `ef` results and returns them sorted closest-first. When
// `filter` is given, only ids contained in the bitmap are added to the
// results, but every visited node is still expanded so the search can pass
// through filtered-out nodes.
fn search_layer_vec(
&self,
query: &[f32],
entry: u32,
ef: usize,
level: usize,
filter: Option<&roaring::RoaringBitmap>,
) -> Vec<Candidate> {
let mut visited = HashSet::new();
// `candidates` pops the closest unexpanded node; `results` keeps the furthest kept result on top.
let mut candidates = BinaryHeap::new(); let mut results = BinaryHeap::new();
let d = self.dist_to_vec(entry, query);
visited.insert(entry);
candidates.push(Candidate { id: entry, dist: d });
if filter.is_none() || filter.unwrap().contains(entry) {
results.push(FurthestCandidate(Candidate { id: entry, dist: d }));
}
while let Some(c) = candidates.pop() {
let furthest_dist = results.peek().map(|f| f.0.dist).unwrap_or(f32::MAX);
// Stop once the closest open candidate is worse than the worst of a full result set.
if c.dist > furthest_dist && results.len() >= ef {
break;
}
let node = self.nodes[c.id as usize].read();
if level < node.neighbors.len() {
for &neighbor in &node.neighbors[level] {
// `insert` returns true only the first time an id is seen.
if visited.insert(neighbor) {
let d = self.dist_to_vec(neighbor, query);
let furthest_dist = results.peek().map(|f| f.0.dist).unwrap_or(f32::MAX);
if d < furthest_dist || results.len() < ef {
candidates.push(Candidate {
id: neighbor,
dist: d,
});
if filter.is_none() || filter.unwrap().contains(neighbor) {
results.push(FurthestCandidate(Candidate {
id: neighbor,
dist: d,
}));
// Evict the furthest result to stay within `ef`.
if results.len() > ef {
results.pop();
}
}
}
}
}
}
}
// `into_sorted_vec` is ascending by distance, i.e. closest first.
results.into_sorted_vec().into_iter().map(|f| f.0).collect()
}
/// Maximum number of neighbours a node may keep on `level`: `2 * m` on the
/// base layer, `m` on every other layer.
fn m_max(&self, level: usize) -> usize {
    match level {
        0 => self.params.m * 2,
        _ => self.params.m,
    }
}
/// Picks up to `m` neighbours from `candidates` with a diversity rule.
///
/// Candidates are visited closest-first. One is kept only if, for every
/// neighbour already kept, its distance to that neighbour is at least its own
/// distance to the query (`c.dist`).
fn select_neighbors_heuristic(&self, candidates: &[Candidate], m: usize) -> Vec<Candidate> {
    let mut ranked = candidates.to_vec();
    ranked.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap_or(Ordering::Equal));

    let mut kept: Vec<Candidate> = Vec::with_capacity(m);
    let mut remaining = ranked.into_iter();
    while kept.len() < m {
        let Some(cand) = remaining.next() else {
            break;
        };
        // Keep `all` with `>=` so NaN distances reject the candidate.
        let diverse = kept
            .iter()
            .all(|chosen| self.dist(cand.id, chosen.id) >= cand.dist);
        if diverse {
            kept.push(cand);
        }
    }
    kept
}
/// Re-selects the neighbours of `node` (whose ordinal is `node_id`) on
/// `level` with the diversity heuristic, shrinking the list to at most
/// `m_max(level)` entries. The caller holds the node's write lock.
fn prune_connections_in(&self, node: &mut Node, node_id: u32, level: usize) {
    let current = &node.neighbors[level];
    let mut scored: Vec<Candidate> = Vec::with_capacity(current.len());
    for &n in current {
        scored.push(Candidate {
            id: n,
            dist: self.dist(node_id, n),
        });
    }
    let budget = self.m_max(level);
    let survivors = self.select_neighbors_heuristic(&scored, budget);
    node.neighbors[level] = survivors.into_iter().map(|c| c.id).collect();
}
/// Sets the ready bit of ordinal `ord` (release ordering).
fn set_ready(&self, ord: u32) {
    let (word, bit) = ((ord / 64) as usize, ord % 64);
    let mask = 1u64 << bit;
    self.ready[word].fetch_or(mask, AtomicOrdering::Release);
}
/// Returns whether the ready bit of ordinal `ord` is set (acquire ordering).
fn is_ready(&self, ord: u32) -> bool {
    let (word, bit) = ((ord / 64) as usize, ord % 64);
    let mask = 1u64 << bit;
    self.ready[word].load(AtomicOrdering::Acquire) & mask != 0
}
/// Connects every stored-but-unconnected vector (ordinals from
/// `connected_count` up to `vectors.len()`) and then marks them all as
/// connected.
///
/// * `Fixed(1)`: sequential.
/// * `Fixed(n)`: parallel on a freshly built rayon pool with `n` threads.
/// * `Ambient`: parallel on the current pool, or sequential when that pool
///   reports one thread or fewer.
///
/// # Panics
/// When the dedicated thread pool cannot be built.
pub fn connect_pending(&mut self, threads: BuildThreads) {
    let (start, end) = (self.connected_count, self.vectors.len());
    if start >= end {
        return;
    }
    match threads {
        BuildThreads::Fixed(1) => self.connect_tail_sequential(start, end),
        BuildThreads::Fixed(n) => {
            let pool = ThreadPoolBuilder::new()
                .num_threads(n)
                .build()
                .expect("failed to build HNSW connect thread pool");
            pool.install(|| self.connect_tail_parallel(start, end));
        }
        BuildThreads::Ambient if current_num_threads() <= 1 => {
            self.connect_tail_sequential(start, end);
        }
        BuildThreads::Ambient => self.connect_tail_parallel(start, end),
    }
    self.connected_count = end;
}
/// Connects ordinals `start..end` one after another on the calling thread,
/// marking each ready right after it is linked.
fn connect_tail_sequential(&mut self, start: usize, end: usize) {
    (start..end).for_each(|ord| {
        let id = ord as u32;
        let level = {
            let node = self.nodes[ord].read();
            node.level
        };
        self.connect_node(id, level);
        self.set_ready(id);
    });
}
/// Connects ordinals `start..end` concurrently on the current rayon pool
/// using the lock-aware `connect_node_locked`.
fn connect_tail_parallel(&self, start: usize, end: usize) {
    let pending = start..end;
    pending.into_par_iter().for_each(|ord| {
        let level = {
            let node = self.nodes[ord].read();
            node.level
        };
        self.connect_node_locked(ord as u32, level);
    });
}
// Concurrent counterpart of `connect_node`, callable through `&self` from
// several rayon workers at once.
//
// Differences from the sequential version:
// * the entry point is claimed and raised with compare-and-swap on `entry`;
// * graph traversal uses the `*_ready` helpers, which skip neighbours whose
//   ready bit is not set;
// * the node's own ready bit is set after its edges are written.
fn connect_node_locked(&self, ord: u32, level: usize) {
let (mut ep, mut max_level_u) = unpack_entry(self.entry.load(AtomicOrdering::Acquire));
if ep == ENTRY_SENTINEL {
// Graph looks empty: mark this node ready, then try to claim the entry slot.
// NOTE(review): the ready bit stays set even when the CAS below loses the
// race and the node goes on to be connected normally — confirm this is intended.
self.set_ready(ord);
match self.entry.compare_exchange(
pack_entry(ENTRY_SENTINEL, 0),
pack_entry(ord, level as u32),
AtomicOrdering::AcqRel,
AtomicOrdering::Acquire,
) {
Ok(_) => return,
Err(actual) => {
// Another worker published an entry point first; continue from it.
let (a_ep, a_max) = unpack_entry(actual);
ep = a_ep;
max_level_u = a_max;
}
}
}
let max_level = max_level_u as usize;
// Greedy descent through the layers above this node's level.
for lev in (level + 1..=max_level).rev() {
ep = self.greedy_closest_ready(ep, ord, lev);
}
let insert_from = level.min(max_level);
let mut ep_for_level = ep;
for lev in (0..=insert_from).rev() {
let candidates =
self.search_layer_ready(ord, ep_for_level, self.params.ef_construction, lev);
let selected = self.select_neighbors_heuristic(&candidates, self.m_max(lev));
{
// Outgoing edges; the write guard is dropped before any neighbour is locked.
let mut node = self.nodes[ord as usize].write();
if lev < node.neighbors.len() {
node.neighbors[lev] = selected.iter().map(|c| c.id).collect();
}
}
for &n in &selected {
// Back-edges, one neighbour lock at a time; prune when over budget.
let mut node = self.nodes[n.id as usize].write();
if lev < node.neighbors.len() {
node.neighbors[lev].push(ord);
if node.neighbors[lev].len() > self.m_max(lev) {
self.prune_connections_in(&mut node, n.id, lev);
}
}
}
// Candidates are sorted closest-first; the best one seeds the next layer.
if !candidates.is_empty() {
ep_for_level = candidates[0].id;
}
}
self.set_ready(ord);
// Raise the entry point to this node if its level exceeds the current
// maximum; retry while other workers change `entry` concurrently.
loop {
let cur = self.entry.load(AtomicOrdering::Acquire);
let (_, cur_max) = unpack_entry(cur);
if (level as u32) <= cur_max {
break;
}
if self
.entry
.compare_exchange(
cur,
pack_entry(ord, level as u32),
AtomicOrdering::AcqRel,
AtomicOrdering::Acquire,
)
.is_ok()
{
break;
}
}
}
/// Greedy walk on `level` towards node `target`, like `greedy_closest`, but
/// it skips neighbours whose ready bit is not set and copies each adjacency
/// list out so the node lock is not held while distances are computed.
fn greedy_closest_ready(&self, start: u32, target: u32, level: usize) -> u32 {
    let mut best = start;
    let mut best_dist = self.dist(best, target);
    loop {
        // Snapshot the adjacency list under the read lock, then release it.
        let snapshot: Vec<u32> = {
            let guard = self.nodes[best as usize].read();
            match guard.neighbors.get(level) {
                Some(layer) => layer.clone(),
                None => Vec::new(),
            }
        };

        let mut improved = false;
        for cand in snapshot.into_iter().filter(|&n| self.is_ready(n)) {
            let d = self.dist(cand, target);
            if d < best_dist {
                best = cand;
                best_dist = d;
                improved = true;
            }
        }
        if !improved {
            return best;
        }
    }
}
// Best-first search on one layer for the parallel connect path.
//
// Same algorithm as `search_layer_vec` without a filter, except that
// neighbours whose ready bit is not set are skipped and each adjacency list
// is cloned so the node lock is released before distances are computed.
// Returns at most `ef` candidates sorted closest-first.
fn search_layer_ready(
&self,
query_id: u32,
entry: u32,
ef: usize,
level: usize,
) -> Vec<Candidate> {
let query = self.vectors[query_id as usize].as_slice();
let mut visited = HashSet::new();
// Pops the closest unexpanded node first.
let mut candidates = BinaryHeap::new();
// Keeps the furthest kept result on top so it can be evicted.
let mut results = BinaryHeap::new();
let d = self.dist_to_vec(entry, query);
visited.insert(entry);
candidates.push(Candidate { id: entry, dist: d });
results.push(FurthestCandidate(Candidate { id: entry, dist: d }));
while let Some(c) = candidates.pop() {
let furthest_dist = results.peek().map(|f| f.0.dist).unwrap_or(f32::MAX);
// Stop once the closest open candidate is worse than the worst of a full result set.
if c.dist > furthest_dist && results.len() >= ef {
break;
}
// Copy the adjacency list so the read lock is dropped at the end of this block.
let neighbors: Vec<u32> = {
let node = self.nodes[c.id as usize].read();
if level < node.neighbors.len() {
node.neighbors[level].clone()
} else {
Vec::new()
}
};
for neighbor in neighbors {
// Not-ready neighbours are skipped without being marked visited, so they
// are re-checked if they show up again later in this search.
if !self.is_ready(neighbor) {
continue;
}
if visited.insert(neighbor) {
let d = self.dist_to_vec(neighbor, query);
let furthest_dist = results.peek().map(|f| f.0.dist).unwrap_or(f32::MAX);
if d < furthest_dist || results.len() < ef {
candidates.push(Candidate {
id: neighbor,
dist: d,
});
results.push(FurthestCandidate(Candidate {
id: neighbor,
dist: d,
}));
// Evict the furthest result to stay within `ef`.
if results.len() > ef {
results.pop();
}
}
}
}
}
// `into_sorted_vec` is ascending by distance, i.e. closest first.
results.into_sorted_vec().into_iter().map(|f| f.0).collect()
}
/// Number of vector slots held by the builder. Placeholder slots created by
/// `with_capacity_for_merge` are counted as well.
pub fn len(&self) -> usize {
self.vectors.len()
}
/// Returns `true` when the builder holds no vector slots.
pub fn is_empty(&self) -> bool {
self.vectors.is_empty()
}
/// Returns `true` when `connected_count` is below the number of vector
/// slots, i.e. `connect_pending` still has work to do.
pub fn has_pending_tail(&self) -> bool {
self.connected_count < self.vectors.len()
}
/// Returns a copy of the build parameters.
pub fn params(&self) -> HnswParams {
self.params.clone()
}
/// Re-opens a finished `HnswIndex` for further inserts.
///
/// Every existing node is treated as connected and ready. The RNG is
/// re-seeded with 42 and the index's quantized payload is dropped (it is
/// rebuilt by `build`).
///
/// Only the ready bits of the `n` existing ordinals are set. The unused high
/// bits of the last bitmap word stay clear, so vectors stored afterwards are
/// not seen as ready by the parallel connect path until `set_ready` is called
/// for them.
pub fn from_index(index: HnswIndex) -> Self {
    let level_mult = 1.0 / (index.params.m as f64).ln();
    let ep = index.entry_point.unwrap_or(ENTRY_SENTINEL);
    let entry = AtomicU64::new(pack_entry(ep, index.max_level as u32));
    let n = index.vectors.len();

    let words = n.div_ceil(64);
    let ready: Vec<AtomicU64> = (0..words)
        .map(|word| {
            // `word < ceil(n / 64)` guarantees `word * 64 < n`, so this cannot underflow.
            let live = n - word * 64;
            let bits = if live >= 64 {
                u64::MAX
            } else {
                // Partial last word: mark only the ordinals that actually exist.
                (1u64 << live) - 1
            };
            AtomicU64::new(bits)
        })
        .collect();

    Self {
        params: index.params,
        vectors: index.vectors,
        nodes: index.nodes.into_iter().map(RwLock::new).collect(),
        entry,
        ready,
        connected_count: n,
        level_mult,
        rng_state: 42,
    }
}
pub fn build(self) -> HnswIndex {
let quantized = match self.params.quantization {
QuantizationType::None => None,
QuantizationType::Int8 if !self.vectors.is_empty() => Some(
super::quantize::QuantizedVectors::quantize(&self.vectors, self.params.metric),
),
QuantizationType::Int8 => None,
unsupported @ (QuantizationType::Int4 | QuantizationType::Bbq) => {
panic!(
"HnswBuilder::build reached with unimplemented \
quantization {unsupported:?}; the constructor was \
supposed to reject this."
);
}
};
let (ep, max_level) = unpack_entry(self.entry.load(AtomicOrdering::Relaxed));
let entry_point = if ep == ENTRY_SENTINEL { None } else { Some(ep) };
HnswIndex {
params: self.params,
vectors: self.vectors,
nodes: self
.nodes
.into_iter()
.map(|lock| lock.into_inner())
.collect(),
entry_point,
max_level: max_level as usize,
quantized,
}
}
}
/// Immutable, fully in-memory HNSW index produced by `HnswBuilder::build` or
/// `HnswIndex::from_bytes`.
pub struct HnswIndex {
// Parameters the index was built or loaded with.
params: HnswParams,
// Full-precision vectors indexed by ordinal.
vectors: Vec<Vec<f32>>,
// Graph nodes, parallel to `vectors`.
nodes: Vec<Node>,
// Ordinal where every search starts; `None` for an empty graph.
entry_point: Option<u32>,
// Highest layer searched during the greedy descent.
max_level: usize,
// Optional quantized copy of the vectors, used for graph traversal distances.
quantized: Option<super::quantize::QuantizedVectors>,
}
impl HnswIndex {
/// Unfiltered k-nearest-neighbour search; shorthand for `search_filtered`
/// with no filter.
pub fn search(&self, query: &[f32], k: usize, ef: usize) -> Result<Vec<(u32, f32)>, LuciError> {
self.search_filtered(query, k, ef, None)
}
/// k-nearest-neighbour search, optionally restricted to ids in `filter`.
///
/// Returns up to `k` `(ordinal, distance)` pairs sorted by ascending
/// distance; an index without an entry point yields an empty list. The query
/// is normalised for the cosine metric. When the filter selects fewer than
/// 1% of the vectors, the graph is bypassed in favour of a brute-force scan
/// of the filtered ids. Otherwise the search descends greedily to layer 0,
/// runs a best-first search with `max(ef, k)` results and, if the index is
/// quantized, re-scores the survivors with exact distances.
///
/// # Errors
/// Propagates the error from `normalize_in_place`.
pub fn search_filtered(
    &self,
    query: &[f32],
    k: usize,
    ef: usize,
    filter: Option<&roaring::RoaringBitmap>,
) -> Result<Vec<(u32, f32)>, LuciError> {
    let Some(ep) = self.entry_point else {
        return Ok(Vec::new());
    };

    let mut query_owned: Vec<f32> = query.to_vec();
    if self.params.metric == DistanceMetric::Cosine {
        normalize_in_place(&mut query_owned)?;
    }
    let query = query_owned.as_slice();

    // Very selective filters are cheaper to scan exhaustively.
    match filter {
        Some(bm) if (bm.len() as f64) < (self.vectors.len() as f64 * 0.01) => {
            return Ok(self.brute_force_search(query, k, bm));
        }
        _ => {}
    }

    let ef_actual = ef.max(k);
    let start = (1..=self.max_level)
        .rev()
        .fold(ep, |cur, lev| self.greedy_closest_vec(cur, query, lev));

    let candidates = self.search_layer_0(query, start, ef_actual, filter);
    let rescore = self.quantized.is_some();
    let mut results: Vec<(u32, f32)> = candidates
        .into_iter()
        .map(|c| {
            let d = if rescore {
                self.exact_dist(c.id, query)
            } else {
                c.dist
            };
            (c.id, d)
        })
        .collect();
    results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
    results.truncate(k);
    Ok(results)
}
/// Greedy walk on `level` towards `query`: starting at `start`, keep moving
/// to whichever neighbour is strictly closer (by `quantized_dist`) until a
/// full pass over the current node's neighbours brings no improvement.
fn greedy_closest_vec(&self, start: u32, query: &[f32], level: usize) -> u32 {
    let mut best = start;
    let mut best_dist = self.quantized_dist(best, query);
    loop {
        // One pass scans the neighbours of the node that was best when the pass began.
        let origin = best as usize;
        let mut improved = false;
        if let Some(layer) = self.nodes[origin].neighbors.get(level) {
            for &cand in layer {
                let d = self.quantized_dist(cand, query);
                if d < best_dist {
                    best = cand;
                    best_dist = d;
                    improved = true;
                }
            }
        }
        if !improved {
            return best;
        }
    }
}
// Best-first search on the base layer, starting at `entry`.
//
// Distances come from `quantized_dist`. At most `ef` results are kept and
// returned sorted closest-first. When `filter` is given, only ids contained
// in the bitmap are added to the results, but every visited node is still
// expanded.
fn search_layer_0(
&self,
query: &[f32],
entry: u32,
ef: usize,
filter: Option<&roaring::RoaringBitmap>,
) -> Vec<Candidate> {
let mut visited = HashSet::new();
// Pops the closest unexpanded node first.
let mut candidates = BinaryHeap::new();
// Keeps the furthest kept result on top so it can be evicted.
let mut results = BinaryHeap::new();
let d = self.quantized_dist(entry, query);
visited.insert(entry);
candidates.push(Candidate { id: entry, dist: d });
if filter.is_none() || filter.unwrap().contains(entry) {
results.push(FurthestCandidate(Candidate { id: entry, dist: d }));
}
while let Some(c) = candidates.pop() {
let furthest_dist = results.peek().map(|f| f.0.dist).unwrap_or(f32::MAX);
// Stop once the closest open candidate is worse than the worst of a full result set.
if c.dist > furthest_dist && results.len() >= ef {
break;
}
// A node without any layer has no base-layer neighbours to expand.
if !self.nodes[c.id as usize].neighbors.is_empty() {
for &neighbor in &self.nodes[c.id as usize].neighbors[0] {
if visited.insert(neighbor) {
let d = self.quantized_dist(neighbor, query);
let furthest_dist = results.peek().map(|f| f.0.dist).unwrap_or(f32::MAX);
if d < furthest_dist || results.len() < ef {
candidates.push(Candidate {
id: neighbor,
dist: d,
});
if filter.is_none() || filter.unwrap().contains(neighbor) {
results.push(FurthestCandidate(Candidate {
id: neighbor,
dist: d,
}));
// Evict the furthest result to stay within `ef`.
if results.len() > ef {
results.pop();
}
}
}
}
}
}
}
let mut result: Vec<Candidate> = results.into_iter().map(|f| f.0).collect();
result.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap_or(Ordering::Equal));
result
}
/// Exhaustive scan over the ids in `filter`, used by `search_filtered` when
/// the filter is very selective.
///
/// Ids outside the index are ignored. Returns up to `k` `(ordinal, distance)`
/// pairs sorted by ascending distance.
///
/// Distances are computed with `exact_dist` on the full-precision vectors.
/// This matches the graph path of `search_filtered`, which re-scores its
/// results with exact distances whenever the index is quantized, so callers
/// get the same score scale whichever path served the query.
fn brute_force_search(
    &self,
    query: &[f32],
    k: usize,
    filter: &roaring::RoaringBitmap,
) -> Vec<(u32, f32)> {
    let mut results: Vec<(u32, f32)> = filter
        .iter()
        .filter(|&id| (id as usize) < self.vectors.len())
        .map(|id| (id, self.exact_dist(id, query)))
        .collect();
    results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
    results.truncate(k);
    results
}
#[inline]
fn quantized_dist(&self, a: u32, query: &[f32]) -> f32 {
if let Some(ref qv) = self.quantized {
qv.asymmetric_distance(a as usize, query)
} else {
super::distance(&self.vectors[a as usize], query, self.params.metric)
}
}
/// Full-precision distance from node `a` to `query` under the index metric.
fn exact_dist(&self, a: u32, query: &[f32]) -> f32 {
    let stored = &self.vectors[a as usize];
    super::distance(stored, query, self.params.metric)
}
/// Number of vectors in the index.
pub fn len(&self) -> usize {
self.vectors.len()
}
/// Returns `true` when the index holds no vectors.
pub fn is_empty(&self) -> bool {
self.vectors.is_empty()
}
/// Vector dimensionality recorded in the index parameters.
pub fn dims(&self) -> usize {
self.params.dims
}
/// Stored full-precision vector for ordinal `id`.
///
/// # Panics
/// When `id` is not a valid ordinal.
pub fn vector(&self, id: u32) -> &[f32] {
&self.vectors[id as usize]
}
pub fn to_bytes(&self) -> Vec<u8> {
let mut buf = Vec::new();
buf.extend_from_slice(&HNSW_FORMAT_MAGIC);
buf.push(HNSW_FORMAT_VERSION);
buf.extend_from_slice(&(self.params.dims as u32).to_le_bytes());
buf.extend_from_slice(&(self.params.m as u32).to_le_bytes());
buf.push(self.params.metric as u8);
buf.extend_from_slice(&(self.vectors.len() as u32).to_le_bytes());
buf.extend_from_slice(&self.entry_point.unwrap_or(u32::MAX).to_le_bytes());
buf.extend_from_slice(&(self.max_level as u32).to_le_bytes());
for v in &self.vectors {
for &f in v {
buf.extend_from_slice(&f.to_le_bytes());
}
}
for node in &self.nodes {
buf.extend_from_slice(&(node.level as u32).to_le_bytes());
buf.extend_from_slice(&(node.neighbors.len() as u32).to_le_bytes());
for layer in &node.neighbors {
buf.extend_from_slice(&(layer.len() as u32).to_le_bytes());
for &n in layer {
buf.extend_from_slice(&n.to_le_bytes());
}
}
}
if let Some(ref qv) = self.quantized {
buf.push(1u8); let qbytes = qv.to_bytes();
buf.extend_from_slice(&(qbytes.len() as u32).to_le_bytes());
buf.extend_from_slice(&qbytes);
} else {
buf.push(0u8); }
buf
}
/// Deserializes an index from a blob produced by `to_bytes`, copying vectors
/// and graph into memory.
///
/// Declared counts are checked against the remaining bytes before anything is
/// allocated, and the entry point and every neighbour id are checked against
/// `num_vectors`. `ef_construction` is not stored in the blob and is set to 100.
///
/// # Errors
/// Header errors from `read_header`, or `LuciError::IndexCorrupted` when the
/// blob is truncated, declares impossible counts, or references an
/// out-of-range ordinal.
pub fn from_bytes(data: &[u8]) -> Result<Self, LuciError> {
let header = read_header(data)?;
let dims = header.dims;
let m = header.m;
let metric = header.metric;
let num_vectors = header.num_vectors;
let entry_point = header.entry_point;
let max_level = header.max_level;
let mut pos = header.vectors_offset;
// Each vector needs `dims * 4` bytes plus at least 8 bytes for its node
// record (level + layer count), which bounds `num_vectors`.
checked_len(
num_vectors,
dims.saturating_mul(4).saturating_add(8),
data,
pos,
)?;
let mut vectors = Vec::with_capacity(num_vectors);
for _ in 0..num_vectors {
let mut v = Vec::with_capacity(dims);
for _ in 0..dims {
v.push(read_f32(data, &mut pos)?);
}
vectors.push(v);
}
let mut nodes = Vec::with_capacity(num_vectors);
for _ in 0..num_vectors {
let level = read_u32(data, &mut pos)? as usize;
let num_layers = read_u32(data, &mut pos)? as usize;
// Every layer carries at least its 4-byte neighbour count.
let mut neighbors = Vec::with_capacity(checked_len(num_layers, 4, data, pos)?);
for _ in 0..num_layers {
let num_neighbors = read_u32(data, &mut pos)? as usize;
let mut layer = Vec::with_capacity(checked_len(num_neighbors, 4, data, pos)?);
for _ in 0..num_neighbors {
layer.push(read_u32(data, &mut pos)?);
}
neighbors.push(layer);
}
nodes.push(Node { neighbors, level });
}
// Optional trailer: flag byte 1, u32 length, quantized payload. A missing
// or different flag byte means "not quantized".
let quantized = if pos < data.len() && data[pos] == 1 {
pos += 1;
let qlen = read_u32(data, &mut pos)? as usize;
let qbytes = take_bytes(data, &mut pos, qlen)?;
Some(super::quantize::QuantizedVectors::from_bytes(qbytes))
} else {
None
};
// The quantization type is not stored; it is inferred from the trailer.
let quantization = if quantized.is_some() {
QuantizationType::Int8
} else {
QuantizationType::None
};
// Reject ordinals that would index out of bounds during search.
if let Some(ep) = entry_point {
if ep as usize >= num_vectors {
return Err(LuciError::IndexCorrupted(format!(
"HNSW entry_point {ep} out of range (num_vectors {num_vectors})"
)));
}
}
for (node_id, node) in nodes.iter().enumerate() {
for layer in &node.neighbors {
for &nid in layer {
if nid as usize >= num_vectors {
return Err(LuciError::IndexCorrupted(format!(
"HNSW node {node_id} neighbour id {nid} out of range \
(num_vectors {num_vectors})"
)));
}
}
}
}
Ok(Self {
params: HnswParams {
dims,
m,
ef_construction: 100,
metric,
quantization,
},
vectors,
nodes,
entry_point,
max_level,
quantized,
})
}
}
/// Graph structure parsed from a serialized blob without copying the vector
/// data; vectors and the quantized payload are referenced by byte offset into
/// the original blob.
pub struct ParsedGraph {
/// Parameters recovered from the header (`ef_construction` is set to 100 by `parse`).
pub params: HnswParams,
/// Number of vectors and graph nodes.
pub num_vectors: usize,
/// Entry-point ordinal, `None` for an empty graph.
pub entry_point: Option<u32>,
/// Highest graph level recorded in the header.
pub max_level: usize,
// Adjacency lists, one node per ordinal.
pub(crate) nodes: Vec<Node>,
/// Byte offset of the first `f32` vector inside the blob.
pub vector_data_offset: usize,
/// Byte offset of the quantized payload inside the blob, if one is present.
pub quantized_offset: Option<usize>,
/// Length in bytes of the quantized payload (0 when absent).
pub quantized_len: usize,
}
impl ParsedGraph {
pub fn parse(data: &[u8]) -> Result<Self, LuciError> {
let header = read_header(data)?;
let dims = header.dims;
let num_vectors = header.num_vectors;
let mut pos = header.vectors_offset;
let vector_data_offset = pos;
pos += num_vectors * dims * 4;
let mut nodes = Vec::with_capacity(num_vectors);
for _ in 0..num_vectors {
let level = u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap()) as usize;
pos += 4;
let num_layers = u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap()) as usize;
pos += 4;
let mut neighbors = Vec::with_capacity(num_layers);
for _ in 0..num_layers {
let num_neighbors =
u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap()) as usize;
pos += 4;
let mut layer = Vec::with_capacity(num_neighbors);
for _ in 0..num_neighbors {
layer.push(u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap()));
pos += 4;
}
neighbors.push(layer);
}
nodes.push(Node { neighbors, level });
}
let (quantized_offset, quantized_len) = if pos < data.len() && data[pos] == 1 {
pos += 1;
let qlen = u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap()) as usize;
pos += 4;
(Some(pos), qlen)
} else {
(None, 0)
};
let quantization = if quantized_offset.is_some() {
QuantizationType::Int8
} else {
QuantizationType::None
};
Ok(Self {
params: HnswParams {
dims,
m: header.m,
ef_construction: 100,
metric: header.metric,
quantization,
},
num_vectors,
entry_point: header.entry_point,
max_level: header.max_level,
nodes,
vector_data_offset,
quantized_offset,
quantized_len,
})
}
}
/// Search view over a `ParsedGraph` that reads vectors directly from the
/// serialized blob instead of holding them in memory.
pub struct HnswSearcher<'a> {
// Parsed adjacency lists and byte offsets into `data`.
graph: &'a ParsedGraph,
// The serialized blob the graph was parsed from.
data: &'a [u8],
// Borrowed views into the quantized payload, when the blob has one.
quantized_cal: Option<QuantizedCalibration<'a>>,
}
// Borrowed slices of the quantized payload as laid out by `HnswSearcher::new`.
struct QuantizedCalibration<'a> {
// Vector dimensionality.
dims: usize,
// Kept for reference; not read in the code visible here.
_num_vectors: usize,
// Kept for reference; not read in the code visible here.
_metric: DistanceMetric,
// `mins` / `scales`: `dims` little-endian f32 values each (per-dimension offset and step).
// `norms`: one little-endian f32 per vector.
// `data`: one quantized byte per dimension per vector, vectors stored back to back.
mins: &'a [u8], scales: &'a [u8], norms: &'a [u8], data: &'a [u8], }
impl<'a> HnswSearcher<'a> {
/// Creates a searcher over `data` using an already parsed `graph`.
///
/// When the graph records a quantized payload, the payload is sliced into its
/// per-dimension mins and scales, per-vector norms and quantized vector bytes.
///
/// # Panics
/// When the quantized payload recorded in `graph` does not fit inside `data`
/// or is shorter than the layout below requires.
pub fn new(data: &'a [u8], graph: &'a ParsedGraph) -> Self {
let quantized_cal = graph.quantized_offset.map(|qoff| {
let qdata = &data[qoff..qoff + graph.quantized_len];
let dims = graph.params.dims;
let num_vectors = graph.num_vectors;
let mut p = 0;
// Skip a 9-byte prefix (4 + 4 + 1 bytes).
// NOTE(review): presumably dims, vector count and metric as written by
// `QuantizedVectors::to_bytes` — confirm against that serializer.
p += 4 + 4 + 1; let mins = &qdata[p..p + dims * 4];
p += dims * 4;
let scales = &qdata[p..p + dims * 4];
p += dims * 4;
let norms = &qdata[p..p + num_vectors * 4];
p += num_vectors * 4;
// One byte per dimension per vector.
let vdata = &qdata[p..p + num_vectors * dims];
QuantizedCalibration {
dims,
_num_vectors: num_vectors,
_metric: graph.params.metric,
mins,
scales,
norms,
data: vdata,
}
});
Self {
graph,
data,
quantized_cal,
}
}
/// Unfiltered k-nearest-neighbour search; shorthand for `search_filtered`
/// with no filter.
pub fn search(&self, query: &[f32], k: usize, ef: usize) -> Result<Vec<(u32, f32)>, LuciError> {
self.search_filtered(query, k, ef, None)
}
/// k-nearest-neighbour search over the serialized graph, optionally
/// restricted to ids in `filter`.
///
/// Returns up to `k` `(ordinal, distance)` pairs sorted by ascending
/// distance; a graph without an entry point yields an empty list. The query
/// is normalised for the cosine metric. When the filter selects fewer than
/// 1% of the vectors, the graph is bypassed in favour of a brute-force scan.
/// Otherwise the search descends greedily to layer 0, runs a best-first
/// search with `max(ef, k)` results and re-scores every survivor with the
/// exact distance.
///
/// # Errors
/// Propagates the error from `normalize_in_place`.
pub fn search_filtered(
    &self,
    query: &[f32],
    k: usize,
    ef: usize,
    filter: Option<&roaring::RoaringBitmap>,
) -> Result<Vec<(u32, f32)>, LuciError> {
    let Some(ep) = self.graph.entry_point else {
        return Ok(Vec::new());
    };

    let mut query_owned: Vec<f32> = query.to_vec();
    if self.graph.params.metric == DistanceMetric::Cosine {
        normalize_in_place(&mut query_owned)?;
    }
    let query = query_owned.as_slice();

    // Very selective filters are cheaper to scan exhaustively.
    match filter {
        Some(bm) if (bm.len() as f64) < (self.graph.num_vectors as f64 * 0.01) => {
            return Ok(self.brute_force_search(query, k, bm));
        }
        _ => {}
    }

    let ef_actual = ef.max(k);
    let start = (1..=self.graph.max_level)
        .rev()
        .fold(ep, |cur, lev| self.greedy_closest(cur, query, lev));

    let mut results: Vec<(u32, f32)> = Vec::new();
    for c in self.search_layer_0(query, start, ef_actual, filter) {
        results.push((c.id, self.exact_dist(c.id, query)));
    }
    results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
    results.truncate(k);
    Ok(results)
}
/// Decodes the little-endian `f32` stored at `byte_offset` in the blob.
///
/// Panics if fewer than four bytes are available at that offset.
#[inline]
fn read_f32(&self, byte_offset: usize) -> f32 {
    let mut raw = [0u8; 4];
    raw.copy_from_slice(&self.data[byte_offset..byte_offset + 4]);
    f32::from_le_bytes(raw)
}
/// Full-precision distance between stored vector `idx` and `query` under the graph's metric.
///
/// The stored vector is decoded from the blob (`dims` consecutive little-endian `f32`s
/// starting at `vector_data_offset + idx * dims * 4`) and handed to `super::distance`.
fn exact_dist(&self, idx: u32, query: &[f32]) -> f32 {
    let dims = self.graph.params.dims;
    let start = self.graph.vector_data_offset + (idx as usize) * dims * 4;
    let stored: Vec<f32> = (0..dims)
        .map(|component| self.read_f32(start + component * 4))
        .collect();
    super::distance(&stored, query, self.graph.params.metric)
}
#[inline]
fn approx_dist(&self, idx: u32, query: &[f32]) -> f32 {
if let Some(ref cal) = self.quantized_cal {
self.quantized_cosine_dist(cal, idx as usize, query)
} else {
self.exact_dist(idx, query)
}
}
/// Cosine-style distance between `query` and the quantized copy of vector `idx`.
///
/// Each one-byte code is dequantized as `min[d] + code * scale[d]`, the dot product with
/// `query` is accumulated, and the result is `1 - dot / norm`, where `norm` is the stored
/// per-vector norm. A stored norm of exactly zero yields a distance of `1.0`.
///
/// Panics if `query` has fewer than `cal.dims` components.
fn quantized_cosine_dist(&self, cal: &QuantizedCalibration, idx: usize, query: &[f32]) -> f32 {
    let dims = cal.dims;
    let codes = &cal.data[idx * dims..(idx + 1) * dims];
    // Reads the `slot`-th little-endian f32 out of a packed byte table.
    let decode = |table: &[u8], slot: usize| -> f32 {
        let at = slot * 4;
        f32::from_le_bytes(table[at..at + 4].try_into().unwrap())
    };

    let mut dot = 0.0f32;
    for (d, &code) in codes.iter().enumerate() {
        let min = decode(cal.mins, d);
        let scale = decode(cal.scales, d);
        let dequant = min + code as f32 * scale;
        dot += dequant * query[d];
    }

    let stored_norm = decode(cal.norms, idx);
    if stored_norm == 0.0 {
        1.0
    } else {
        1.0 - dot / stored_norm
    }
}
/// Greedy walk on `level`: starting at `start`, keep moving to any neighbour that is
/// closer to `query` until a full pass over the current neighbour list brings no
/// improvement. Returns the node reached.
///
/// Nodes that have no neighbour list at `level` are treated as having no neighbours.
fn greedy_closest(&self, start: u32, query: &[f32], level: usize) -> u32 {
    let mut best = start;
    let mut best_dist = self.approx_dist(best, query);
    let mut improved = true;
    while improved {
        improved = false;
        // The list is fixed before the scan, so moving `best` mid-scan does not switch
        // lists; the new node's neighbours are visited on the next pass.
        if let Some(layer) = self.graph.nodes[best as usize].neighbors.get(level) {
            for &candidate in layer {
                let dist = self.approx_dist(candidate, query);
                if dist < best_dist {
                    best = candidate;
                    best_dist = dist;
                    improved = true;
                }
            }
        }
    }
    best
}
/// Beam search on layer 0 starting from `entry`.
///
/// Keeps a frontier of nodes still to expand and a bounded set of at most `ef` best
/// results. Nodes rejected by `filter` are still expanded (so the walk can pass through
/// them) but never enter the result set. Returns the results sorted by ascending
/// approximate distance.
fn search_layer_0(
    &self,
    query: &[f32],
    entry: u32,
    ef: usize,
    filter: Option<&roaring::RoaringBitmap>,
) -> Vec<Candidate> {
    // A missing filter admits every id.
    let passes_filter = |id: u32| filter.map_or(true, |bitmap| bitmap.contains(id));

    let mut seen = HashSet::new();
    let mut frontier = BinaryHeap::new();
    let mut best = BinaryHeap::new();

    let entry_dist = self.approx_dist(entry, query);
    seen.insert(entry);
    frontier.push(Candidate {
        id: entry,
        dist: entry_dist,
    });
    if passes_filter(entry) {
        best.push(FurthestCandidate(Candidate {
            id: entry,
            dist: entry_dist,
        }));
    }

    while let Some(closest) = frontier.pop() {
        let worst = best.peek().map(|f| f.0.dist).unwrap_or(f32::MAX);
        // Stop once the result set is full and the frontier cannot improve it.
        if best.len() >= ef && closest.dist > worst {
            break;
        }
        let links = match self.graph.nodes[closest.id as usize].neighbors.first() {
            Some(links) => links,
            None => continue,
        };
        for &neighbor in links {
            if !seen.insert(neighbor) {
                continue;
            }
            let dist = self.approx_dist(neighbor, query);
            let worst = best.peek().map(|f| f.0.dist).unwrap_or(f32::MAX);
            if !(dist < worst || best.len() < ef) {
                continue;
            }
            frontier.push(Candidate { id: neighbor, dist });
            if passes_filter(neighbor) {
                best.push(FurthestCandidate(Candidate { id: neighbor, dist }));
                if best.len() > ef {
                    best.pop();
                }
            }
        }
    }

    let mut ordered: Vec<Candidate> = best.into_iter().map(|f| f.0).collect();
    ordered.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap_or(Ordering::Equal));
    ordered
}
/// Exhaustive search over the ids in `filter`.
///
/// Ids at or beyond `num_vectors` are skipped; the rest are scored with the exact
/// distance, sorted ascending and truncated to `k`.
fn brute_force_search(
    &self,
    query: &[f32],
    k: usize,
    filter: &roaring::RoaringBitmap,
) -> Vec<(u32, f32)> {
    let limit = self.graph.num_vectors;
    let mut scored: Vec<(u32, f32)> = Vec::new();
    for id in filter.iter() {
        if (id as usize) < limit {
            scored.push((id, self.exact_dist(id, query)));
        }
    }
    scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
    scored.truncate(k);
    scored
}
}
#[cfg(test)]
mod tests {
use super::*;
/// Default test parameters: L2 metric, `m = 8`, `ef_construction = 50`, default quantization.
fn make_params(dims: usize) -> HnswParams {
    let (m, ef_construction) = (8, 50);
    HnswParams {
        dims,
        m,
        ef_construction,
        metric: DistanceMetric::L2,
        quantization: QuantizationType::DEFAULT,
    }
}
/// Four corners of the unit square; a query near the origin must rank id 0 first.
#[test]
fn build_and_search_small() {
    let mut builder = HnswBuilder::new(make_params(2));
    for corner in [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]] {
        builder.add_vector(corner.to_vec()).unwrap();
    }
    let index = builder.build();
    let hits = index.search(&[0.1, 0.1], 2, 10).unwrap();
    assert_eq!(hits.len(), 2);
    assert_eq!(hits[0].0, 0);
}
/// Twenty points on the x-axis; asking for five neighbours of x = 5 returns exactly five
/// hits, one of which is the point itself.
#[test]
fn search_returns_k_results() {
    let mut builder = HnswBuilder::new(make_params(3));
    for x in 0..20 {
        let point = vec![x as f32, 0.0, 0.0];
        builder.add_vector(point).unwrap();
    }
    let index = builder.build();
    let hits = index.search(&[5.0, 0.0, 0.0], 5, 20).unwrap();
    assert_eq!(hits.len(), 5);
    assert!(hits.iter().any(|&(id, _)| id == 5));
}
/// Recall check on 200 pseudo-random 16-dimensional vectors: a stored vector must find
/// itself, and recall@10 against an exhaustive L2 scan must be at least 0.8.
#[test]
fn recall_on_random_data() {
    let dims = 16;
    let n = 200;
    // Same xorshift stream the test used to generate inline (seed 12345 is already odd,
    // so `gen_vectors`' `seed | 1` leaves it unchanged).
    let vectors = gen_vectors(n, dims, 12345);

    let mut builder = HnswBuilder::new(HnswParams {
        dims,
        m: 12,
        ef_construction: 50,
        metric: DistanceMetric::L2,
        quantization: QuantizationType::DEFAULT,
    });
    for v in &vectors {
        builder.add_vector(v.clone()).unwrap();
    }
    let index = builder.build();

    let self_hit = index.search(&vectors[0], 1, 50).unwrap();
    assert_eq!(self_hit[0].0, 0);

    let query = &vectors[42];
    let hnsw_top10: HashSet<u32> = index
        .search(query, 10, 50)
        .unwrap()
        .iter()
        .map(|hit| hit.0)
        .collect();

    // Ground truth from an exhaustive scan.
    let mut brute: Vec<(u32, f32)> = vectors
        .iter()
        .enumerate()
        .map(|(i, v)| (i as u32, super::super::distance(v, query, DistanceMetric::L2)))
        .collect();
    brute.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
    let brute_top10: HashSet<u32> = brute[..10].iter().map(|hit| hit.0).collect();

    let recall = brute_top10.intersection(&hnsw_top10).count() as f64 / 10.0;
    assert!(recall >= 0.8, "recall@10 = {recall}, expected >= 0.8");
}
/// Deterministic test data: `n` vectors of `dims` components in `[-1, 1]`, drawn from a
/// 64-bit xorshift generator (shifts 13 / 7 / 17). The seed is OR-ed with 1 so the
/// generator state is never zero.
fn gen_vectors(n: usize, dims: usize, seed: u64) -> Vec<Vec<f32>> {
    let mut state = seed | 1;
    let mut next_component = || {
        state ^= state << 13;
        state ^= state >> 7;
        state ^= state << 17;
        (state as f32 / u64::MAX as f32) * 2.0 - 1.0
    };
    (0..n)
        .map(|_| (0..dims).map(|_| next_component()).collect::<Vec<f32>>())
        .collect()
}
/// Mean recall@10 of `index` (searched with `ef = 64`) over the query vectors selected by
/// `queries`, using an exhaustive L2 scan of `vectors` as ground truth.
fn mean_recall_at_10(index: &HnswIndex, vectors: &[Vec<f32>], queries: &[usize]) -> f64 {
    let mut recall_sum = 0.0;
    for &qi in queries {
        let query = &vectors[qi];

        let found: HashSet<u32> = index
            .search(query, 10, 64)
            .unwrap()
            .iter()
            .map(|hit| hit.0)
            .collect();

        let mut ranked: Vec<(u32, f32)> = vectors
            .iter()
            .enumerate()
            .map(|(i, v)| (i as u32, super::super::distance(v, query, DistanceMetric::L2)))
            .collect();
        ranked.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
        let truth: HashSet<u32> = ranked[..10].iter().map(|hit| hit.0).collect();

        recall_sum += truth.intersection(&found).count() as f64 / 10.0;
    }
    recall_sum / queries.len() as f64
}
/// Storing all vectors first and connecting them with a single thread must serialize to
/// exactly the same bytes as inserting them one by one with `add_vector`.
#[test]
fn connect_pending_fixed1_matches_add_vector_bytes() {
    let dims = 16;
    let vectors = gen_vectors(300, dims, 0x00AB_CDEF);

    let mut immediate = HnswBuilder::new(make_params(dims));
    for v in vectors.iter().cloned() {
        immediate.add_vector(v).unwrap();
    }
    let bytes_immediate = immediate.build().to_bytes();

    let mut deferred = HnswBuilder::new(make_params(dims));
    for v in vectors.iter().cloned() {
        deferred.store_vector(v).unwrap();
    }
    deferred.connect_pending(BuildThreads::Fixed(1));
    let bytes_deferred = deferred.build().to_bytes();

    assert_eq!(
        bytes_immediate, bytes_deferred,
        "Fixed(1) deferred build must be byte-identical to immediate add_vector",
    );
}
/// Deferred builds (one thread, eight threads, ambient pool) must reach a mean recall@10
/// no more than 0.10 below the sequential `add_vector` build on the same data.
#[test]
fn parallel_build_recall_matches_sequential() {
    let dims = 16;
    let vectors = gen_vectors(800, dims, 0x1234_5678);
    let queries: Vec<usize> = (0..40).map(|i| i * 17 % vectors.len()).collect();

    let mut seq = HnswBuilder::new(make_params(dims));
    for v in vectors.iter().cloned() {
        seq.add_vector(v).unwrap();
    }
    let recall_seq = mean_recall_at_10(&seq.build(), &vectors, &queries);

    let thread_modes = [
        BuildThreads::Fixed(1),
        BuildThreads::Fixed(8),
        BuildThreads::Ambient,
    ];
    for threads in thread_modes {
        let mut par = HnswBuilder::new(make_params(dims));
        for v in vectors.iter().cloned() {
            par.store_vector(v).unwrap();
        }
        par.connect_pending(threads);
        let recall_par = mean_recall_at_10(&par.build(), &vectors, &queries);
        assert!(
            recall_par >= recall_seq - 0.10,
            "{threads:?}: parallel recall {recall_par:.3} should track sequential \
             {recall_seq:.3} (within 0.10)",
        );
    }
}
/// 2000 vectors cycling through only 10 distinct values, connected on 16 threads: the
/// build must finish and the resulting index must still answer a query.
#[test]
fn parallel_build_no_deadlock_on_duplicates() {
    let dims = 8;
    let distinct = gen_vectors(10, dims, 0x0000_9999);

    let mut builder = HnswBuilder::new(make_params(dims));
    for v in distinct.iter().cycle().take(2000) {
        builder.store_vector(v.clone()).unwrap();
    }
    builder.connect_pending(BuildThreads::Fixed(16));

    let index = builder.build();
    let hits = index.search(&distinct[0], 5, 32).unwrap();
    assert!(
        !hits.is_empty(),
        "duplicate-heavy graph should still answer queries",
    );
}
/// Ten points on the x-axis with only even ids allowed: every hit must be even, and the
/// best hit for x = 3 must be one of its two even neighbours (2 or 4).
#[test]
fn filtered_search() {
    let mut builder = HnswBuilder::new(make_params(2));
    for x in 0..10 {
        builder.add_vector(vec![x as f32, 0.0]).unwrap();
    }
    let index = builder.build();

    let mut even_ids = roaring::RoaringBitmap::new();
    for id in (0..10).step_by(2) {
        even_ids.insert(id);
    }

    let hits = index
        .search_filtered(&[3.0, 0.0], 3, 20, Some(&even_ids))
        .unwrap();
    for &(id, _) in &hits {
        assert!(id % 2 == 0, "filtered result should be even, got {id}");
    }
    assert!(matches!(hits[0].0, 2 | 4));
}
/// `to_bytes` followed by `from_bytes` keeps the vector count and dimensionality, and
/// yields the same nearest neighbour as the original index.
#[test]
fn serialization_round_trip() {
    let mut builder = HnswBuilder::new(make_params(3));
    for row in [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]] {
        builder.add_vector(row.to_vec()).unwrap();
    }
    let index = builder.build();

    let restored = HnswIndex::from_bytes(&index.to_bytes()).unwrap();
    assert_eq!(restored.len(), 3);
    assert_eq!(restored.dims(), 3);

    let probe = [1.0, 2.0, 3.0];
    let before = index.search(&probe, 1, 10).unwrap();
    let after = restored.search(&probe, 1, 10).unwrap();
    assert_eq!(before[0].0, after[0].0);
}
/// `from_bytes` must return an error, never panic, on damaged input: truncated blobs, an
/// out-of-range entry point, and an out-of-range neighbour id.
#[test]
fn from_bytes_rejects_corrupt_blob() {
let mut builder = HnswBuilder::new(make_params(3));
builder.add_vector(vec![1.0, 2.0, 3.0]).unwrap();
builder.add_vector(vec![4.0, 5.0, 6.0]).unwrap();
builder.add_vector(vec![7.0, 8.0, 9.0]).unwrap();
let valid = builder.build().to_bytes();
// Sanity check: the untouched blob loads.
assert!(
HnswIndex::from_bytes(&valid).is_ok(),
"valid blob must load"
);
// Case 1: prefixes of the valid blob, cut at several lengths.
for cut in [
1usize,
10,
20,
valid.len() / 2,
valid.len().saturating_sub(6),
] {
assert!(
HnswIndex::from_bytes(&valid[..cut]).is_err(),
"truncated-to-{cut} blob must be rejected, not panic"
);
}
// Case 2: overwrite bytes 18..22 with 9999, far beyond the three stored vectors.
// Presumably this is the entry-point field (4 magic + 1 version + 4 + 4 + 1 metric + 4)
// — confirm against the serializer.
let mut bad_ep = valid.clone();
bad_ep[18..22].copy_from_slice(&9999u32.to_le_bytes());
assert!(
matches!(
HnswIndex::from_bytes(&bad_ep),
Err(LuciError::IndexCorrupted(_))
),
"out-of-range entry_point must be IndexCorrupted"
);
// Case 3: a hand-assembled blob. The metric byte is copied from the valid blob (offset
// 13) so the header stays well-formed. The remaining fields follow the serialized layout
// defined elsewhere in this file; the 99 is presumably the neighbour id that exceeds the
// vector count — TODO confirm field order against `to_bytes`.
let metric_byte = valid[13];
let mut blob = Vec::new();
blob.extend_from_slice(&HNSW_FORMAT_MAGIC);
blob.push(HNSW_FORMAT_VERSION);
blob.extend_from_slice(&1u32.to_le_bytes()); blob.extend_from_slice(&1u32.to_le_bytes()); blob.push(metric_byte);
blob.extend_from_slice(&2u32.to_le_bytes()); blob.extend_from_slice(&0u32.to_le_bytes()); blob.extend_from_slice(&0u32.to_le_bytes()); blob.extend_from_slice(&0.0f32.to_le_bytes()); blob.extend_from_slice(&1.0f32.to_le_bytes()); blob.extend_from_slice(&0u32.to_le_bytes()); blob.extend_from_slice(&1u32.to_le_bytes()); blob.extend_from_slice(&1u32.to_le_bytes()); blob.extend_from_slice(&99u32.to_le_bytes()); blob.extend_from_slice(&0u32.to_le_bytes());
blob.extend_from_slice(&1u32.to_le_bytes());
blob.extend_from_slice(&1u32.to_le_bytes());
blob.extend_from_slice(&0u32.to_le_bytes()); blob.push(0); assert!(
matches!(
HnswIndex::from_bytes(&blob),
Err(LuciError::IndexCorrupted(_))
),
"out-of-range neighbour id must be IndexCorrupted"
);
}
/// An index built with `QuantizationType::None` must come back from `from_bytes` without
/// quantized data and with the same quantization setting.
#[test]
fn from_bytes_honors_quantization_none() {
    let params = HnswParams {
        dims: 4,
        m: 16,
        ef_construction: 100,
        metric: DistanceMetric::L2,
        quantization: QuantizationType::None,
    };
    let mut builder = HnswBuilder::new(params);
    for row in [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]] {
        builder.add_vector(row.to_vec()).unwrap();
    }
    let built = builder.build();
    assert!(
        built.quantized.is_none(),
        "build must produce quantized = None when mapping says None",
    );

    let restored = HnswIndex::from_bytes(&built.to_bytes()).unwrap();
    assert!(
        restored.quantized.is_none(),
        "from_bytes must honor on-disk `no quantized data` flag; \
         auto-synthesising int8 overrides the user's mapping",
    );
    assert_eq!(restored.params.quantization, QuantizationType::None);
}
/// An index built with `QuantizationType::Int8` must keep its quantized data and its
/// quantization setting across `to_bytes` / `from_bytes`.
#[test]
fn from_bytes_preserves_quantization_int8() {
    let params = HnswParams {
        dims: 4,
        m: 16,
        ef_construction: 100,
        metric: DistanceMetric::L2,
        quantization: QuantizationType::Int8,
    };
    let mut builder = HnswBuilder::new(params);
    for row in [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]] {
        builder.add_vector(row.to_vec()).unwrap();
    }
    let built = builder.build();
    assert!(
        built.quantized.is_some(),
        "build must produce quantized = Some when mapping says Int8",
    );

    let restored = HnswIndex::from_bytes(&built.to_bytes()).unwrap();
    assert!(restored.quantized.is_some());
    assert_eq!(restored.params.quantization, QuantizationType::Int8);
}
/// Searching an index built from zero vectors returns no hits and no error.
#[test]
fn empty_index() {
    let index = HnswBuilder::new(make_params(2)).build();
    let hits = index.search(&[0.0, 0.0], 5, 10).unwrap();
    assert!(hits.is_empty());
}
/// With one stored vector, any query returns exactly that vector (id 0).
#[test]
fn single_vector() {
    let mut builder = HnswBuilder::new(make_params(2));
    builder.add_vector(vec![1.0, 1.0]).unwrap();
    let hits = builder.build().search(&[0.0, 0.0], 1, 10).unwrap();
    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].0, 0);
}
/// Cosine ranking: for a query along the x-axis the x-axis vector (id 0) comes first and
/// the slightly tilted vector (id 2) second, ahead of the orthogonal one.
#[test]
fn cosine_metric() {
    let mut builder = HnswBuilder::new(HnswParams {
        dims: 3,
        m: 8,
        ef_construction: 50,
        metric: DistanceMetric::Cosine,
        quantization: QuantizationType::DEFAULT,
    });
    for direction in [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.9, 0.1, 0.0]] {
        builder.add_vector(direction.to_vec()).unwrap();
    }
    let index = builder.build();
    let hits = index.search(&[1.0, 0.0, 0.0], 2, 10).unwrap();
    let (first, second) = (hits[0].0, hits[1].0);
    assert_eq!(first, 0);
    assert_eq!(second, 2);
}
/// Constructing a builder with `QuantizationType::Int4` must panic with the
/// "unimplemented quantization" message.
#[test]
#[should_panic(expected = "unimplemented quantization")]
fn hnsw_builder_panics_on_int4_quantization() {
    let params = HnswParams {
        dims: 4,
        m: 8,
        ef_construction: 50,
        metric: DistanceMetric::Cosine,
        quantization: QuantizationType::Int4,
    };
    let _ = HnswBuilder::new(params);
}
/// Constructing a builder with `QuantizationType::Bbq` must panic with the
/// "unimplemented quantization" message.
#[test]
#[should_panic(expected = "unimplemented quantization")]
fn hnsw_builder_panics_on_bbq_quantization() {
    let params = HnswParams {
        dims: 4,
        m: 8,
        ef_construction: 50,
        metric: DistanceMetric::Cosine,
        quantization: QuantizationType::Bbq,
    };
    let _ = HnswBuilder::new(params);
}
/// With the cosine metric the builder stores a unit-length copy: (3, 0, 4) becomes
/// (0.6, 0, 0.8).
#[test]
fn builder_normalizes_input_on_cosine() {
    let mut builder = HnswBuilder::new(HnswParams {
        dims: 3,
        m: 8,
        ef_construction: 50,
        metric: DistanceMetric::Cosine,
        quantization: QuantizationType::None,
    });
    builder.add_vector(vec![3.0, 0.0, 4.0]).unwrap();

    let stored = &builder.vectors[0];
    let norm_sq: f32 = stored.iter().map(|component| component * component).sum();
    assert!(
        (norm_sq - 1.0).abs() < 1e-5,
        "stored vector must be unit-length, got norm_sq = {norm_sq}",
    );
    let (x, z) = (stored[0], stored[2]);
    assert!((x - 0.6).abs() < 1e-5);
    assert!((z - 0.8).abs() < 1e-5);
}
/// A zero vector cannot be normalised, so adding one under the cosine metric must fail
/// with `LuciError::InvalidQuery`.
#[test]
fn builder_rejects_zero_vector_with_cosine() {
    let params = HnswParams {
        dims: 3,
        m: 8,
        ef_construction: 50,
        metric: DistanceMetric::Cosine,
        quantization: QuantizationType::None,
    };
    let mut builder = HnswBuilder::new(params);
    let outcome = builder.add_vector(vec![0.0, 0.0, 0.0]);
    let err = outcome.unwrap_err();
    assert!(
        matches!(err, LuciError::InvalidQuery(_)),
        "expected InvalidQuery, got {err:?}",
    );
}
/// Under the dot-product metric a zero vector is accepted and stored unchanged.
#[test]
fn builder_accepts_zero_vector_with_dot_product() {
    let params = HnswParams {
        dims: 3,
        m: 8,
        ef_construction: 50,
        metric: DistanceMetric::DotProduct,
        quantization: QuantizationType::None,
    };
    let mut builder = HnswBuilder::new(params);
    let zero = vec![0.0, 0.0, 0.0];
    builder.add_vector(zero.clone()).unwrap();
    assert_eq!(builder.vectors[0], zero);
}
/// A rejected zero vector must not be stored: after one good insert and one failed
/// insert the builder still holds exactly one vector.
#[test]
fn bulk_aborts_on_zero_vector() {
    let params = HnswParams {
        dims: 3,
        m: 8,
        ef_construction: 50,
        metric: DistanceMetric::Cosine,
        quantization: QuantizationType::None,
    };
    let mut builder = HnswBuilder::new(params);
    builder.add_vector(vec![1.0, 0.0, 0.0]).unwrap();

    let rejected = builder.add_vector(vec![0.0, 0.0, 0.0]);
    let err = rejected.unwrap_err();
    assert!(matches!(err, LuciError::InvalidQuery(_)));
    assert_eq!(builder.vectors.len(), 1);
}
/// `QuantizationType::None` is a supported setting: the builder works and the built index
/// carries no quantized data.
#[test]
fn hnsw_builder_accepts_none_quantization() {
    let params = HnswParams {
        dims: 3,
        m: 8,
        ef_construction: 50,
        metric: DistanceMetric::Cosine,
        quantization: QuantizationType::None,
    };
    let mut builder = HnswBuilder::new(params);
    builder.add_vector(vec![1.0, 0.0, 0.0]).unwrap();
    let built = builder.build();
    assert!(built.quantized.is_none());
}
/// With `m = 8`, layer 0 allows 16 neighbours per node and every higher layer allows 8.
#[test]
fn m_max_is_2m_at_layer_0() {
    let params = HnswParams {
        dims: 2,
        m: 8,
        ef_construction: 50,
        metric: DistanceMetric::L2,
        quantization: QuantizationType::None,
    };
    let builder = HnswBuilder::new(params);
    for (level, expected) in [(0, 16), (1, 8), (5, 8)] {
        assert_eq!(builder.m_max(level), expected);
    }
}
/// Three candidates for the origin: A (id 1) and B (id 2) sit almost on top of each other,
/// C (id 3) lies in a different direction. With room for two neighbours the heuristic must
/// keep A and C and drop B.
#[test]
fn select_neighbors_heuristic_rejects_clustered_candidate() {
    let mut builder = HnswBuilder::new(HnswParams {
        dims: 2,
        m: 2,
        ef_construction: 10,
        metric: DistanceMetric::L2,
        quantization: QuantizationType::None,
    });
    for point in [[0.0, 0.0], [1.0, 0.0], [1.0, 0.05], [0.0, 1.05]] {
        builder.add_vector(point.to_vec()).unwrap();
    }

    let candidates: Vec<Candidate> = [(1, 1.0), (2, 1.00125), (3, 1.05)]
        .into_iter()
        .map(|(id, dist)| Candidate { id, dist })
        .collect();
    let selected = builder.select_neighbors_heuristic(&candidates, 2);
    let ids: HashSet<u32> = selected.iter().map(|c| c.id).collect();

    assert_eq!(selected.len(), 2);
    assert!(
        ids.contains(&1),
        "Expected A (id 1) selected, got {:?}",
        ids
    );
    assert!(
        ids.contains(&3),
        "Expected C (id 3) selected (diverse direction), got {:?}",
        ids
    );
    assert!(
        !ids.contains(&2),
        "Expected B (id 2) rejected (too close to A), got {:?}",
        ids
    );
}
/// Diversity invariant of the neighbour-selection heuristic on random data: every selected
/// candidate must be at least as far from each earlier-selected candidate as it is from
/// the query node (id 0).
#[test]
fn select_neighbors_heuristic_satisfies_diversity_invariant() {
    let dims = 4;
    let n = 30;
    let mut builder = HnswBuilder::new(HnswParams {
        dims,
        m: 8,
        ef_construction: 50,
        metric: DistanceMetric::L2,
        quantization: QuantizationType::None,
    });
    // Same xorshift stream the test used to generate inline (seed 12345 is already odd,
    // so `gen_vectors`' `seed | 1` leaves it unchanged).
    for v in gen_vectors(n, dims, 12345) {
        builder.add_vector(v).unwrap();
    }

    // Every node except 0 is a candidate, scored by its distance to node 0.
    let candidates: Vec<Candidate> = (1..n as u32)
        .map(|id| Candidate {
            id,
            dist: builder.dist(id, 0),
        })
        .collect();
    let selected = builder.select_neighbors_heuristic(&candidates, 8);
    assert!(!selected.is_empty(), "heuristic returned empty selection");

    for (i, s_i) in selected.iter().enumerate() {
        for s_k in &selected[..i] {
            let d_ik = builder.dist(s_i.id, s_k.id);
            assert!(
                d_ik >= s_i.dist,
                "Diversity invariant violated: selected[{i}] (id={}, dist_to_query={}) \
                 has dist to earlier-selected (id={}) of {}; expected >= {}",
                s_i.id,
                s_i.dist,
                s_k.id,
                d_ik,
                s_i.dist
            );
        }
    }
}
/// Vectors inserted out of order into a builder pre-sized for five slots must land at the
/// requested ordinals, and the built index must find the x-axis vector at ordinal 0.
#[test]
fn add_vector_at_ordinal_fills_reserved_slot() {
    let mut builder = HnswBuilder::with_capacity_for_merge(make_params(3), 5);

    // Deliberately shuffled insertion order.
    let inserts = [
        (2, vec![0.0, 0.0, 1.0]),
        (0, vec![1.0, 0.0, 0.0]),
        (4, vec![0.5, 0.5, 0.0]),
        (1, vec![0.0, 1.0, 0.0]),
        (3, vec![0.7, 0.0, 0.7]),
    ];
    for (ordinal, vector) in inserts {
        builder.add_vector_at_ordinal(ordinal, vector).unwrap();
    }

    assert_eq!(builder.vectors[0], vec![1.0, 0.0, 0.0]);
    assert_eq!(builder.vectors[2], vec![0.0, 0.0, 1.0]);
    assert_eq!(builder.vectors[4], vec![0.5, 0.5, 0.0]);

    let index = builder.build();
    let hits = index.search(&[1.0, 0.0, 0.0], 1, 10).unwrap();
    assert_eq!(hits[0].0, 0, "x-axis vector should match query closest");
}
/// Seeding a merge builder from a parsed graph with an identity ordinal mapping must carry
/// over the entry point, the maximum level and every vector, and the rebuilt index must
/// still answer a query correctly.
#[test]
fn seed_from_graph_round_trips_topology() {
    let dims = 3;
    let mut src = HnswBuilder::new(make_params(dims));
    for v in [
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.0],
        [0.7, 0.0, 0.7],
    ] {
        src.add_vector(v.to_vec()).unwrap();
    }
    let src_bytes = src.build().to_bytes();
    let graph = ParsedGraph::parse(&src_bytes).unwrap();

    let mut merged = HnswBuilder::with_capacity_for_merge(make_params(dims), 4);
    merged.seed_from_graph(&graph, &src_bytes, |src_ord| src_ord);

    let packed = merged.entry.load(AtomicOrdering::Relaxed);
    let (merged_ep, merged_max) = unpack_entry(packed);
    assert_eq!(merged_ep, graph.entry_point.unwrap_or(ENTRY_SENTINEL));
    assert_eq!(merged_max as usize, graph.max_level);

    for ord in 0..4 {
        assert_eq!(
            merged.vectors[ord].len(),
            dims,
            "ordinal {ord} should be seeded with a {dims}-dim vector",
        );
    }

    let merged_index = merged.build();
    let hits = merged_index.search(&[1.0, 0.0, 0.0], 1, 10).unwrap();
    assert_eq!(hits[0].0, 0);
}
}