use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use crate::core::{DocId, FieldId, LuciError, SegmentId};
use crate::mapping::Mapping;
use super::DistanceMetric;
use super::hnsw::{
BuildThreads, HnswBuilder, HnswIndex, HnswParams, checked_len, read_u32, read_u64, take_bytes,
};
/// `m` value placed into the `HnswParams` of every per-field builder created by `GlobalHnsw::new`.
const HNSW_M: usize = 16;
/// `ef_construction` value placed into the `HnswParams` of every per-field builder.
const HNSW_EF_CONSTRUCTION: usize = 100;
/// Four-byte magic prefix written at the start of every serialized per-field blob.
const VECTOR_INDEX_MAGIC: [u8; 4] = *b"VIDX";
/// Blob format version written right after the magic; `load_field` rejects any other value.
const VECTOR_INDEX_VERSION: u8 = 1;
/// State kept for one vector field: its builder, its ordinal resolver, and a cached built index.
struct FieldGlobalHnsw {
/// Builder that receives every vector added or stored for this field.
builder: HnswBuilder,
/// Entry `i` is the `(segment, local doc id)` recorded for the vector with ordinal `i`.
resolver: Vec<(SegmentId, u32)>,
/// Index built from a clone of `builder`; `None` until requested or after invalidation.
cached: Option<Arc<HnswIndex>>,
}
impl FieldGlobalHnsw {
    /// Creates empty per-field state whose builder is configured with `params`.
    fn new(params: HnswParams) -> Self {
        let builder = HnswBuilder::new(params);
        Self {
            builder,
            resolver: Vec::new(),
            cached: None,
        }
    }

    /// Forgets the cached index so the next request rebuilds it from the builder.
    fn invalidate_cache(&mut self) {
        self.cached = None;
    }

    /// Returns the cached index, building it from a clone of the builder on a cache miss.
    ///
    /// In debug builds this asserts that the builder has no pending (unlinked) tail.
    fn get_or_build_index(&mut self) -> Arc<HnswIndex> {
        debug_assert!(
            !self.builder.has_pending_tail(),
            "get_or_build_index called with an unlinked pending tail; \
             connect_pending was not run before persist",
        );
        // Borrow the builder separately so the closure does not need all of `self`
        // while `self.cached` is mutably borrowed.
        let builder = &self.builder;
        let index = self
            .cached
            .get_or_insert_with(|| Arc::new(builder.clone().build()));
        Arc::clone(index)
    }
}
/// Registry of per-field HNSW state, keyed by field id and guarded by a single mutex.
pub struct GlobalHnsw {
/// Per-field builder/resolver/cache state; the `&self` methods lock this map to read or mutate it.
per_field: Mutex<HashMap<FieldId, FieldGlobalHnsw>>,
}
/// One search result, resolved from a global ordinal to a segment-local document.
#[derive(Clone, Copy, Debug)]
pub struct GlobalHit {
/// Segment that the resolver records for the matched vector.
pub segment_id: SegmentId,
/// Document id constructed from the resolver's local doc id for the matched vector.
pub doc_id: DocId,
/// Distance value returned by the index search for this hit.
pub distance: f32,
}
impl GlobalHnsw {
/// Creates one empty per-field HNSW state for every vector field in `schema`.
///
/// Fields whose type reports no vector dimensions, or zero dimensions, are skipped.
/// Every registered field uses `HNSW_M`, `HNSW_EF_CONSTRUCTION` and the cosine metric.
///
/// # Panics
///
/// Panics when `schema.field_id` cannot resolve a name returned by `schema.fields()`,
/// or when a vector field carries no quantization setting.
pub fn new(schema: &Mapping) -> Self {
    let mut per_field = HashMap::new();
    for mapping in schema.fields() {
        // Only vector fields with a non-zero dimension count get a graph.
        let dims = match mapping.field_type.vector_dims() {
            Some(d) if d != 0 => d,
            _ => continue,
        };
        let Some(field_id) = schema.field_id(&mapping.name) else {
            panic!(
                "schema.fields() returned mapping for {:?} but \
                 field_id() couldn't find it; schema is internally \
                 inconsistent",
                mapping.name
            );
        };
        let quantization = mapping
            .field_type
            .vector_quantization()
            .expect("dense_vector mapping must carry quantization");
        let params = HnswParams {
            dims,
            m: HNSW_M,
            ef_construction: HNSW_EF_CONSTRUCTION,
            metric: DistanceMetric::Cosine,
            quantization,
        };
        per_field.insert(field_id, FieldGlobalHnsw::new(params));
    }
    Self {
        per_field: Mutex::new(per_field),
    }
}
pub fn add_vector(
&self,
field_id: FieldId,
segment_id: SegmentId,
local_doc_id: u32,
vector: Vec<f32>,
) -> Result<u32, LuciError> {
let mut guard = self.per_field.lock().expect("GlobalHnsw mutex poisoned");
let field = guard.get_mut(&field_id).ok_or_else(|| {
LuciError::InvalidQuery(format!(
"GlobalHnsw::add_vector called for field {field_id:?} which is \
not a dense_vector field in the schema; this is an internal \
wiring bug",
))
})?;
let ord = field.builder.len() as u32;
field.builder.add_vector(vector)?;
field.resolver.push((segment_id, local_doc_id));
field.invalidate_cache();
debug_assert_eq!(
field.resolver.len(),
field.builder.len(),
"resolver and builder lengths must agree after add_vector",
);
Ok(ord)
}
pub fn store_vector(
&self,
field_id: FieldId,
segment_id: SegmentId,
local_doc_id: u32,
vector: Vec<f32>,
) -> Result<u32, LuciError> {
let mut guard = self.per_field.lock().expect("GlobalHnsw mutex poisoned");
let field = guard.get_mut(&field_id).ok_or_else(|| {
LuciError::InvalidQuery(format!(
"GlobalHnsw::store_vector called for field {field_id:?} which is \
not a dense_vector field in the schema; this is an internal \
wiring bug",
))
})?;
let ord = field.builder.len() as u32;
field.builder.store_vector(vector)?;
field.resolver.push((segment_id, local_doc_id));
field.invalidate_cache();
debug_assert_eq!(
field.resolver.len(),
field.builder.len(),
"resolver and builder lengths must agree after store_vector",
);
Ok(ord)
}
/// Runs `HnswBuilder::connect_pending` for `field_id` and invalidates its cached index.
///
/// Does nothing when `field_id` is not registered.
pub fn connect_pending(&self, field_id: FieldId, threads: BuildThreads) {
    let mut fields = self.per_field.lock().expect("GlobalHnsw mutex poisoned");
    let Some(state) = fields.get_mut(&field_id) else {
        return;
    };
    state.builder.connect_pending(threads);
    state.invalidate_cache();
}
/// Searches the graph of `field_id` and maps each hit back to its segment-local document.
///
/// Returns `Ok(None)` when `field_id` is not registered; otherwise the resolved hits together
/// with the field's distance metric. Errors from the index search are propagated.
///
/// The mutex is held only long enough to snapshot the index, the resolver and the metric;
/// the graph search itself runs after the lock is released.
pub fn search(
    &self,
    field_id: FieldId,
    query: &[f32],
    k: usize,
    ef: usize,
) -> Result<Option<(Vec<GlobalHit>, DistanceMetric)>, LuciError> {
    let (metric, resolver, index) = {
        let mut fields = self.per_field.lock().expect("GlobalHnsw mutex poisoned");
        let Some(state) = fields.get_mut(&field_id) else {
            return Ok(None);
        };
        let metric = state.builder.params().metric;
        // Copy the resolver so ordinals can be resolved without holding the lock.
        let resolver = state.resolver.clone();
        let index = state.get_or_build_index();
        (metric, resolver, index)
    };

    let hits: Vec<GlobalHit> = index
        .search(query, k, ef)?
        .into_iter()
        .map(|(global_ord, distance)| {
            let (segment_id, local_doc) = resolver[global_ord as usize];
            GlobalHit {
                segment_id,
                doc_id: DocId::new(local_doc),
                distance,
            }
        })
        .collect();
    Ok(Some((hits, metric)))
}
/// Number of vectors held by the builder of `field_id`, or `None` when it is not registered.
pub fn len(&self, field_id: FieldId) -> Option<usize> {
    let fields = self.per_field.lock().expect("GlobalHnsw mutex poisoned");
    let state = fields.get(&field_id)?;
    Some(state.builder.len())
}
/// `true` when no registered field holds any vector (including when no field is registered).
pub fn is_empty(&self) -> bool {
    let fields = self.per_field.lock().expect("GlobalHnsw mutex poisoned");
    !fields.values().any(|state| !state.builder.is_empty())
}
/// Rewrites resolver entries according to `merge_map` (old `(segment, doc)` → new one).
///
/// Entries without a mapping are left untouched. A field's cached index is invalidated only
/// when at least one of its entries was rewritten. An empty map returns without locking.
pub fn rewrite_after_merge(&self, merge_map: &HashMap<(SegmentId, u32), (SegmentId, u32)>) {
    if merge_map.is_empty() {
        return;
    }
    let mut fields = self.per_field.lock().expect("GlobalHnsw mutex poisoned");
    for state in fields.values_mut() {
        let mut touched = false;
        for slot in state.resolver.iter_mut() {
            if let Some(target) = merge_map.get(slot) {
                *slot = *target;
                touched = true;
            }
        }
        if touched {
            state.invalidate_cache();
        }
    }
}
/// Ids of every registered vector field, in the map's (unspecified) iteration order.
pub fn field_ids(&self) -> Vec<FieldId> {
    let fields = self.per_field.lock().expect("GlobalHnsw mutex poisoned");
    let mut ids = Vec::with_capacity(fields.len());
    ids.extend(fields.keys().copied());
    ids
}
/// Ids of the registered fields that hold at least one vector, sorted ascending.
pub fn non_empty_field_ids(&self) -> Vec<FieldId> {
    let fields = self.per_field.lock().expect("GlobalHnsw mutex poisoned");
    let mut ids: Vec<FieldId> = fields
        .iter()
        .filter_map(|(field_id, state)| {
            if state.builder.is_empty() {
                None
            } else {
                Some(*field_id)
            }
        })
        .collect();
    ids.sort();
    ids
}
/// Serializes the graph and resolver of `field_id` into one blob.
///
/// Layout (integers little-endian): magic `VIDX` (4 bytes), version (1 byte), HNSW byte
/// length (`u32`), HNSW bytes, resolver entry count (`u32`), then per entry the segment id
/// (`u64`) followed by the local doc id (`u32`).
///
/// Returns `None` when `field_id` is not registered.
///
/// # Panics
///
/// Panics when the serialized graph or the resolver is too large for the format's `u32`
/// length prefixes. Writing a truncated length would produce a blob that cannot be read
/// back correctly, so this fails loudly instead. The lock is released before panicking so
/// the mutex is not poisoned.
pub fn field_to_bytes(&self, field_id: FieldId) -> Option<Vec<u8>> {
    // One resolver entry on disk: u64 segment id + u32 local doc id.
    const RESOLVER_ENTRY_LEN: usize = 8 + 4;

    let mut guard = self.per_field.lock().expect("GlobalHnsw mutex poisoned");
    let field = guard.get_mut(&field_id)?;
    let index = field.get_or_build_index();
    let hnsw_bytes = index.to_bytes();
    let hnsw_size = hnsw_bytes.len();
    let resolver_count = field.resolver.len();

    let (Ok(hnsw_len), Ok(resolver_len)) =
        (u32::try_from(hnsw_size), u32::try_from(resolver_count))
    else {
        drop(guard);
        panic!(
            "GlobalHnsw::field_to_bytes: field {:?} does not fit the vector index blob \
             format (graph is {} bytes, resolver has {} entries; both must fit in u32)",
            field_id, hnsw_size, resolver_count
        );
    };

    // magic + version + two u32 length prefixes, then the two payloads.
    let fixed_len = VECTOR_INDEX_MAGIC.len() + 1 + 4 + 4;
    let mut buf =
        Vec::with_capacity(fixed_len + hnsw_size + resolver_count * RESOLVER_ENTRY_LEN);
    buf.extend_from_slice(&VECTOR_INDEX_MAGIC);
    buf.push(VECTOR_INDEX_VERSION);
    buf.extend_from_slice(&hnsw_len.to_le_bytes());
    buf.extend_from_slice(&hnsw_bytes);
    buf.extend_from_slice(&resolver_len.to_le_bytes());
    for (seg, doc) in &field.resolver {
        buf.extend_from_slice(&seg.as_u64().to_le_bytes());
        buf.extend_from_slice(&doc.to_le_bytes());
    }
    Some(buf)
}
/// Parses a blob produced by `field_to_bytes` and installs it as the state of `field_id`.
///
/// The blob is fully parsed and validated before the lock is taken, so a failed load leaves
/// the existing state untouched. Bytes after the resolver section are not inspected.
///
/// # Errors
///
/// * `LuciError::IndexCorrupted` — blob shorter than the 5-byte header, wrong magic, or a
///   resolver entry count that differs from the graph's vector count.
/// * `LuciError::SegmentFormatUnknown` — version byte is not `VECTOR_INDEX_VERSION`.
/// * `LuciError::InvalidQuery` — `field_id` is not registered.
/// * Any error propagated from the byte readers or `HnswIndex::from_bytes`.
pub fn load_field(&self, field_id: FieldId, data: &[u8]) -> Result<(), LuciError> {
    let total = data.len();
    if total < 5 {
        return Err(LuciError::IndexCorrupted(format!(
            "vector index blob for field {field_id:?} too short: {} bytes",
            total
        )));
    }
    if !data.starts_with(&VECTOR_INDEX_MAGIC) {
        return Err(LuciError::IndexCorrupted(format!(
            "vector index blob for field {field_id:?} missing magic prefix"
        )));
    }
    let version = data[4];
    if version != VECTOR_INDEX_VERSION {
        return Err(LuciError::SegmentFormatUnknown(format!(
            "unknown vector index blob version {} for field {field_id:?}",
            version
        )));
    }

    // Cursor into `data`, positioned just past magic + version.
    let mut cursor = 5;

    // Graph section: u32 length prefix followed by that many bytes.
    let graph_len = read_u32(data, &mut cursor)? as usize;
    let graph_bytes = take_bytes(data, &mut cursor, graph_len)?;
    let builder = HnswBuilder::from_index(HnswIndex::from_bytes(graph_bytes)?);

    // Resolver section: u32 entry count, then 12 bytes per entry. `checked_len` vets the
    // count before it is used as an allocation size.
    let entry_count = read_u32(data, &mut cursor)? as usize;
    let capacity = checked_len(entry_count, 12, data, cursor)?;
    let mut resolver = Vec::with_capacity(capacity);
    for _ in 0..entry_count {
        let segment = SegmentId::new(read_u64(data, &mut cursor)?);
        let local_doc = read_u32(data, &mut cursor)?;
        resolver.push((segment, local_doc));
    }

    if builder.len() != resolver.len() {
        return Err(LuciError::IndexCorrupted(format!(
            "vector index resolver/graph mismatch for field {field_id:?}: \
             graph has {} vectors, resolver has {}",
            builder.len(),
            resolver.len()
        )));
    }

    let mut fields = self.per_field.lock().expect("GlobalHnsw mutex poisoned");
    let Some(state) = fields.get_mut(&field_id) else {
        return Err(LuciError::InvalidQuery(format!(
            "GlobalHnsw::load_field called for field {field_id:?} which is \
             not a dense_vector field in the current schema"
        )));
    };
    state.builder = builder;
    state.resolver = resolver;
    state.invalidate_cache();
    Ok(())
}
}
#[cfg(test)]
mod tests {
use crate::mapping::{FieldType, Mapping};
use super::*;
/// Builds a mapping with a single dense-vector field called `name` with `dims` dimensions.
fn vector_schema(name: &str, dims: usize) -> Mapping {
    let field_type = FieldType::dense_vector(dims);
    Mapping::builder().field(name, field_type).build()
}
/// A schema with one dense-vector field registers exactly one field, initially empty.
#[test]
fn new_finds_dense_vector_fields() {
    let schema = vector_schema("embedding", 4);
    let embedding = schema.field_id("embedding").unwrap();
    let global = GlobalHnsw::new(&schema);
    assert_eq!(global.field_ids().len(), 1);
    assert_eq!(global.len(embedding), Some(0));
}
/// A schema without any fields produces a registry that reports itself empty.
#[test]
fn new_with_no_vector_fields_is_empty() {
    let empty_schema = Mapping::builder().build();
    let global = GlobalHnsw::new(&empty_schema);
    assert!(global.is_empty());
}
/// Ordinals returned by `add_vector` start at 0 and increase by one per insert.
#[test]
fn add_vector_returns_increasing_ordinals() {
    let schema = vector_schema("embedding", 3);
    let field_id = schema.field_id("embedding").unwrap();
    let global = GlobalHnsw::new(&schema);
    let segment = SegmentId::new(1);

    let unit_vectors = vec![
        vec![1.0, 0.0, 0.0],
        vec![0.0, 1.0, 0.0],
        vec![0.0, 0.0, 1.0],
    ];
    // Insert everything first, then compare, mirroring the insert-then-assert order.
    let ordinals: Vec<u32> = unit_vectors
        .into_iter()
        .enumerate()
        .map(|(doc, vector)| {
            global
                .add_vector(field_id, segment, doc as u32, vector)
                .unwrap()
        })
        .collect();

    assert_eq!(ordinals[0], 0);
    assert_eq!(ordinals[1], 1);
    assert_eq!(ordinals[2], 2);
    assert_eq!(global.len(field_id), Some(3));
}
/// Adding to a field id that the schema never registered yields `InvalidQuery`.
#[test]
fn add_vector_for_unknown_field_errors() {
    let schema = vector_schema("embedding", 3);
    let global = GlobalHnsw::new(&schema);
    let unknown_field = FieldId(999);
    let outcome = global.add_vector(unknown_field, SegmentId::new(1), 0, vec![1.0, 0.0, 0.0]);
    assert!(matches!(outcome, Err(LuciError::InvalidQuery(_))));
}
/// An all-zero vector is rejected with `InvalidQuery` for a cosine-metric field.
#[test]
fn cosine_zero_vector_rejected() {
    let schema = vector_schema("embedding", 3);
    let field_id = schema.field_id("embedding").unwrap();
    let global = GlobalHnsw::new(&schema);
    let zero = vec![0.0, 0.0, 0.0];
    let outcome = global.add_vector(field_id, SegmentId::new(1), 0, zero);
    assert!(matches!(outcome, Err(LuciError::InvalidQuery(_))));
}
/// Search hits carry the segment and segment-local doc id recorded at insert time.
#[test]
fn search_returns_hits_in_segment_local_doc_space() {
    let schema = vector_schema("embedding", 3);
    let field_id = schema.field_id("embedding").unwrap();
    let global = GlobalHnsw::new(&schema);
    let first_segment = SegmentId::new(1);
    let second_segment = SegmentId::new(2);

    let inserts = vec![
        (first_segment, 0, vec![1.0, 0.0, 0.0]),
        (first_segment, 1, vec![0.0, 1.0, 0.0]),
        (second_segment, 0, vec![0.9, 0.1, 0.0]),
    ];
    for (segment, local_doc, vector) in inserts {
        global
            .add_vector(field_id, segment, local_doc, vector)
            .unwrap();
    }

    let found = global.search(field_id, &[1.0, 0.0, 0.0], 3, 16).unwrap();
    let (hits, metric) = found.unwrap();
    assert_eq!(metric, DistanceMetric::Cosine);
    assert_eq!(hits.len(), 3);
    // The exact-match vector was inserted as doc 0 of the first segment.
    let nearest = &hits[0];
    assert_eq!(nearest.segment_id, first_segment);
    assert_eq!(nearest.doc_id, DocId::new(0));
}
/// A blob written by `field_to_bytes` loads into a fresh registry and stays searchable.
#[test]
fn roundtrip_field_to_bytes_load_field() {
    let schema = vector_schema("embedding", 3);
    let field_id = schema.field_id("embedding").unwrap();
    let segment = SegmentId::new(1);

    let source = GlobalHnsw::new(&schema);
    let unit_vectors = vec![
        vec![1.0, 0.0, 0.0],
        vec![0.0, 1.0, 0.0],
        vec![0.0, 0.0, 1.0],
    ];
    for (local_doc, vector) in unit_vectors.into_iter().enumerate() {
        source
            .add_vector(field_id, segment, local_doc as u32, vector)
            .unwrap();
    }
    let blob = source.field_to_bytes(field_id).unwrap();

    let restored = GlobalHnsw::new(&schema);
    restored.load_field(field_id, &blob).unwrap();
    assert_eq!(restored.len(field_id), Some(3));

    let found = restored.search(field_id, &[1.0, 0.0, 0.0], 1, 16).unwrap();
    let (hits, _) = found.unwrap();
    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].segment_id, segment);
    assert_eq!(hits[0].doc_id, DocId::new(0));
}
/// Truncated blobs and an absurd resolver count must be rejected with an error, not a panic.
#[test]
fn load_field_rejects_corrupt_blob() {
    let schema = vector_schema("embedding", 3);
    let field_id = schema.field_id("embedding").unwrap();
    let writer = GlobalHnsw::new(&schema);
    let segment = SegmentId::new(1);
    writer
        .add_vector(field_id, segment, 0, vec![1.0, 0.0, 0.0])
        .unwrap();
    writer
        .add_vector(field_id, segment, 1, vec![0.0, 1.0, 0.0])
        .unwrap();
    let valid = writer.field_to_bytes(field_id).unwrap();

    // Each attempt loads into a brand-new registry so attempts cannot affect each other.
    let try_load = |blob: &[u8]| GlobalHnsw::new(&schema).load_field(field_id, blob);

    assert!(try_load(&valid).is_ok(), "valid blob must load");

    for cut in [5usize, 6, 9, valid.len() / 2, valid.len() - 1] {
        assert!(
            try_load(&valid[..cut]).is_err(),
            "truncated-to-{cut} blob must be rejected, not panic"
        );
    }

    // Locate the resolver count (it follows the length-prefixed graph bytes) and
    // overwrite it with u32::MAX.
    let mut len_prefix = [0u8; 4];
    len_prefix.copy_from_slice(&valid[5..9]);
    let hnsw_len = u32::from_le_bytes(len_prefix) as usize;
    let count_at = 5 + 4 + hnsw_len;
    let mut bad_resolver = valid.clone();
    bad_resolver[count_at..count_at + 4].copy_from_slice(&u32::MAX.to_le_bytes());
    assert!(
        matches!(try_load(&bad_resolver), Err(LuciError::IndexCorrupted(_))),
        "corrupt resolver length must be IndexCorrupted, not OOM/panic"
    );
}
/// Only fields that received at least one vector are listed by `non_empty_field_ids`.
#[test]
fn non_empty_field_ids_omits_empty_fields() {
    let schema = Mapping::builder()
        .field("a", FieldType::dense_vector(2))
        .field("b", FieldType::dense_vector(2))
        .build();
    let filled = schema.field_id("a").unwrap();
    let untouched = schema.field_id("b").unwrap();
    let global = GlobalHnsw::new(&schema);

    global
        .add_vector(filled, SegmentId::new(1), 0, vec![1.0, 0.0])
        .unwrap();

    assert_eq!(global.non_empty_field_ids(), vec![filled]);
    assert_eq!(global.len(untouched), Some(0));
}
/// After a merge rewrite, search hits must point at the merged segment *and* at the doc ids
/// the merge map assigned, not just at the right segment.
#[test]
fn rewrite_after_merge_remaps_resolver() {
    let schema = vector_schema("embedding", 3);
    let field_id = schema.field_id("embedding").unwrap();
    let g = GlobalHnsw::new(&schema);
    let s1 = SegmentId::new(1);
    let s2 = SegmentId::new(2);
    let s3 = SegmentId::new(3);
    g.add_vector(field_id, s1, 0, vec![1.0, 0.0, 0.0]).unwrap();
    g.add_vector(field_id, s2, 0, vec![0.0, 1.0, 0.0]).unwrap();

    let mut merge_map = HashMap::new();
    merge_map.insert((s1, 0), (s3, 0));
    merge_map.insert((s2, 0), (s3, 1));
    g.rewrite_after_merge(&merge_map);

    let (hits, _) = g
        .search(field_id, &[1.0, 0.0, 0.0], 2, 16)
        .unwrap()
        .unwrap();
    assert_eq!(hits.len(), 2);
    for hit in &hits {
        assert_eq!(hit.segment_id, s3);
    }
    // Checked without relying on hit order: both remapped doc ids must appear.
    for expected_doc in [0u32, 1] {
        assert!(
            hits.iter()
                .any(|hit| hit.segment_id == s3 && hit.doc_id == DocId::new(expected_doc)),
            "expected a hit remapped to doc {} of the merged segment",
            expected_doc
        );
    }
}
}