#![allow(clippy::doc_markdown)]
use std::sync::Arc;
use instant_distance::{Builder, HnswMap, Point as IdPoint, Search};
use mnem_core::codec::from_canonical_bytes;
use mnem_core::error::{Error, RepoError};
use mnem_core::id::NodeId;
use mnem_core::index::vector::{VectorHit, VectorIndex};
use mnem_core::objects::{Dtype, Embedding, Node};
use mnem_core::prolly::Cursor;
use mnem_core::repo::ReadonlyRepo;
use mnem_core::store::Blockstore;
/// Tuning parameters used when building and querying an [`HnswVectorIndex`].
#[derive(Clone, Debug)]
pub struct HnswConfig {
/// Forwarded to the HNSW builder's `ef_construction` setting at build time.
pub ef_construction: usize,
/// Stored on the built index; a search pulls up to `max(k, ef_search)`
/// candidates from the graph before re-sorting and truncating to `k`.
pub ef_search: usize,
/// Seed handed to the HNSW builder.
pub seed: u64,
}
impl Default for HnswConfig {
    /// Returns the stock configuration: `ef_construction = 200`,
    /// `ef_search = 100`, and a fixed seed.
    fn default() -> Self {
        let ef_construction = 200;
        let ef_search = 100;
        let seed = 0x6DEF_1EE7_5CE8_7D55;
        Self {
            ef_construction,
            ef_search,
            seed,
        }
    }
}
/// One indexed vector, wrapped in a local type so it can implement the HNSW
/// library's point trait.
#[derive(Clone, Debug)]
pub(crate) struct Point {
/// Vector components. The build and search paths in this file pass values
/// through `normalise` before storing them here; `from_parts_for_test`
/// stores whatever the caller supplies.
pub(crate) vec: Vec<f32>,
}
impl IdPoint for Point {
    /// Returns the *squared* Euclidean distance between `self` and `other`
    /// (no square root is taken).
    ///
    /// Both vectors are expected to have the same length; this is only
    /// checked in debug builds. In release builds the shorter vector bounds
    /// the summation.
    fn distance(&self, other: &Self) -> f32 {
        debug_assert_eq!(self.vec.len(), other.vec.len());
        self.vec
            .iter()
            .zip(&other.vec)
            .fold(0.0_f32, |sum, (a, b)| {
                let diff = a - b;
                sum + diff * diff
            })
    }
}
/// Vector index for a single embedding model, backed by an HNSW graph.
pub struct HnswVectorIndex {
// Embedding model name; returned by `VectorIndex::model`.
model: String,
// Dimensionality of the indexed vectors; `build_from_repo_with` leaves it at
// 0 when it found no embedding for the model.
dim: u32,
/// Node id of each indexed point. The values stored in the HNSW map are
/// positions into this vector.
pub(crate) ids: Vec<NodeId>,
/// Copy of the points handed to the HNSW builder, kept in the same order as
/// `ids` so callers can iterate the raw vectors.
pub(crate) points: Vec<Point>,
// HNSW graph mapping each point to its position in `ids`.
inner: HnswMap<Point, usize>,
// Copied from `HnswConfig::ef_search`; used by `search` as the minimum
// number of candidates to pull from the graph.
ef_search: usize,
}
impl HnswVectorIndex {
    /// Iterates over every indexed entry as `(node id, vector slice)`,
    /// pairing `ids` and `points` position by position.
    pub fn points_iter(&self) -> impl Iterator<Item = (NodeId, &[f32])> + '_ {
        let ids = self.ids.iter();
        let vectors = self.points.iter();
        ids.zip(vectors).map(|(node_id, point)| {
            let slice: &[f32] = &point.vec;
            (*node_id, slice)
        })
    }

    /// Number of entries in the index (the length of `ids`).
    #[must_use]
    pub fn points_len(&self) -> usize {
        self.ids.len()
    }
}
impl std::fmt::Debug for HnswVectorIndex {
    /// Prints a compact summary (`model`, `dim`, and entry count as `len`)
    /// rather than the stored vectors or the graph.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let entry_count = self.ids.len();
        let mut summary = f.debug_struct("HnswVectorIndex");
        summary.field("model", &self.model);
        summary.field("dim", &self.dim);
        summary.field("len", &entry_count);
        summary.finish()
    }
}
impl HnswVectorIndex {
pub fn build_from_repo(repo: &ReadonlyRepo, model: &str) -> Result<Self, Error> {
Self::build_from_repo_with(repo, model, HnswConfig::default())
}
pub fn build_from_repo_with(
repo: &ReadonlyRepo,
model: &str,
cfg: HnswConfig,
) -> Result<Self, Error> {
let bs: Arc<dyn Blockstore> = repo.blockstore().clone();
let Some(commit) = repo.head_commit() else {
return Err(RepoError::Uninitialized.into());
};
let mut ids: Vec<NodeId> = Vec::new();
let mut points: Vec<Point> = Vec::new();
let mut dim: Option<u32> = None;
let cursor = Cursor::new(&*bs, &commit.nodes)?;
for entry in cursor {
let (_k, node_cid) = entry?;
let bytes = bs
.get(&node_cid)
.map_err(Error::from)?
.ok_or_else(|| Error::from(RepoError::NotFound))?;
let node: Node = from_canonical_bytes(&bytes).map_err(Error::from)?;
let Some(embed) = repo.embedding_for(&node_cid, model)? else {
continue;
};
embed.validate()?;
if let Some(d) = dim {
if embed.dim != d {
continue;
}
} else {
dim = Some(embed.dim);
}
let Some(vec_f32) = decode_to_f32(&embed) else {
continue;
};
let normalised = normalise(vec_f32);
ids.push(node.id);
points.push(Point { vec: normalised });
}
let dim = dim.unwrap_or(0);
if points.is_empty() {
return Ok(Self {
model: model.into(),
dim,
ids: Vec::new(),
points: Vec::new(),
inner: Builder::default().build(Vec::<Point>::new(), Vec::<usize>::new()),
ef_search: cfg.ef_search,
});
}
let values: Vec<usize> = (0..points.len()).collect();
let points_retained = points.clone();
let inner = Builder::default()
.ef_construction(cfg.ef_construction)
.seed(cfg.seed)
.build(points, values);
Ok(Self {
model: model.into(),
dim,
ids,
points: points_retained,
inner,
ef_search: cfg.ef_search,
})
}
#[doc(hidden)]
#[must_use]
pub fn from_parts_for_test(
model: &str,
dim: u32,
ids: Vec<NodeId>,
normalised_vecs: Vec<Vec<f32>>,
cfg: &HnswConfig,
) -> Self {
assert_eq!(ids.len(), normalised_vecs.len(), "ids/vecs length mismatch");
let points: Vec<Point> = normalised_vecs
.into_iter()
.map(|v| Point { vec: v })
.collect();
if points.is_empty() {
return Self {
model: model.into(),
dim,
ids,
points,
inner: Builder::default().build(Vec::<Point>::new(), Vec::<usize>::new()),
ef_search: cfg.ef_search,
};
}
let values: Vec<usize> = (0..points.len()).collect();
let points_retained = points.clone();
let inner = Builder::default()
.ef_construction(cfg.ef_construction)
.seed(cfg.seed)
.build(points, values);
Self {
model: model.into(),
dim,
ids,
points: points_retained,
inner,
ef_search: cfg.ef_search,
}
}
}
impl VectorIndex for HnswVectorIndex {
    /// Name of the embedding model this index was built for.
    fn model(&self) -> &str {
        &self.model
    }

    /// Dimensionality of the indexed vectors (0 for an index built from no
    /// embeddings).
    fn dim(&self) -> u32 {
        self.dim
    }

    /// Returns up to `k` hits for `query`, best first.
    ///
    /// The query is L2-normalised, then up to `max(k, ef_search)` candidates
    /// are pulled from the graph. Each candidate's squared distance `d` is
    /// converted to a score `1 - d / 2` (the cosine similarity when both
    /// vectors are unit length). Candidates are sorted by descending score
    /// with ties broken by ascending node id, then truncated to `k`.
    ///
    /// An index with `dim == 0` and no entries returns an empty result for
    /// any query. `k == 0` returns an empty result.
    ///
    /// # Errors
    ///
    /// [`RepoError::VectorDimMismatch`] if `query.len()` differs from the
    /// index dimension.
    fn search(&self, query: &[f32], k: usize) -> Result<Vec<VectorHit>, Error> {
        if self.dim == 0 && self.ids.is_empty() {
            return Ok(Vec::new());
        }
        if query.len() != self.dim as usize {
            return Err(RepoError::VectorDimMismatch {
                index_dim: self.dim,
                query_dim: query.len(),
            }
            .into());
        }
        if k == 0 {
            return Ok(Vec::new());
        }
        let q = Point {
            vec: normalise(query.to_vec()),
        };
        let mut searcher = Search::default();
        let fetch = std::cmp::max(k, self.ef_search);
        // Never reserve more than the index can return: `k` is caller
        // supplied and may be huge (e.g. `usize::MAX` meaning "everything"),
        // which would make `Vec::with_capacity(k)` panic or abort.
        let capacity = fetch.min(self.ids.len());
        let mut hits: Vec<VectorHit> = Vec::with_capacity(capacity);
        for item in self.inner.search(&q, &mut searcher).take(fetch) {
            let ord = *item.value;
            let node_id = self.ids[ord];
            // Squared distance between unit vectors is 2 - 2*cos.
            let score = 1.0 - item.distance * 0.5;
            hits.push(VectorHit::new(node_id, score));
        }
        hits.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| a.node_id.cmp(&b.node_id))
        });
        hits.truncate(k);
        Ok(hits)
    }

    /// Number of indexed entries.
    fn len(&self) -> usize {
        self.ids.len()
    }
}
/// Decodes an embedding's little-endian byte payload into `f32` components.
///
/// Returns `None` when the payload length is not exactly
/// `dim * dtype.byte_width()`, or when the dtype is neither `F32` nor `F64`.
/// `F64` components are narrowed to `f32` with an `as` cast.
fn decode_to_f32(embed: &Embedding) -> Option<Vec<f32>> {
    let expected_len = embed.dim as usize * embed.dtype.byte_width();
    let raw = &embed.vector;
    if raw.len() != expected_len {
        return None;
    }
    let decoded: Vec<f32> = match embed.dtype {
        Dtype::F32 => raw
            .chunks_exact(4)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect(),
        Dtype::F64 => raw
            .chunks_exact(8)
            .map(|c| {
                let wide = f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]);
                wide as f32
            })
            .collect(),
        _ => return None,
    };
    Some(decoded)
}
/// Scales `v` in place to unit L2 length and returns it.
///
/// If the sum of squares is not strictly positive (an all-zero or empty
/// vector, or a sum that compares false against zero such as NaN), the
/// vector is returned unchanged.
fn normalise(mut v: Vec<f32>) -> Vec<f32> {
    let norm_sq = v.iter().fold(0.0_f32, |acc, x| acc + x * x);
    if norm_sq > 0.0 {
        // Multiply by the reciprocal norm instead of dividing per component.
        let scale = norm_sq.sqrt().recip();
        v.iter_mut().for_each(|x| *x *= scale);
    }
    v
}
// Unit tests for the HNSW index. Most tests assemble an `HnswVectorIndex`
// by hand from a tiny set of points; the last one goes through
// `build_from_repo` against in-memory stores.
#[cfg(test)]
mod tests {
use super::*;
// An index with no points and `dim == 0` reports empty and answers any
// query (here a 3-component one) with no hits instead of an error.
#[test]
fn empty_build_returns_len_zero_index() {
let cfg = HnswConfig::default();
let built = Builder::default()
.ef_construction(cfg.ef_construction)
.seed(cfg.seed)
.build(Vec::<Point>::new(), Vec::<usize>::new());
let idx = HnswVectorIndex {
model: "m".into(),
dim: 0,
ids: Vec::new(),
points: Vec::new(),
inner: built,
ef_search: cfg.ef_search,
};
assert!(idx.is_empty());
let hits = idx.search(&[0.0_f32; 3], 5).unwrap();
assert!(hits.is_empty());
}
// Querying a 3-dimensional index with a 2-component vector must fail with
// `VectorDimMismatch` carrying both dimensions.
#[test]
fn dim_mismatch_errors() {
use mnem_core::error::RepoError;
let points = vec![
Point {
vec: normalise(vec![1.0, 0.0, 0.0]),
},
Point {
vec: normalise(vec![0.0, 1.0, 0.0]),
},
];
let values = vec![0_usize, 1];
let points_retained = points.clone();
let inner = Builder::default().build(points, values);
let idx = HnswVectorIndex {
model: "m".into(),
dim: 3,
ids: vec![NodeId::new_v7(), NodeId::new_v7()],
points: points_retained,
inner,
ef_search: 10,
};
let err = idx.search(&[1.0, 0.0], 1).unwrap_err();
assert!(matches!(
err,
Error::Repo(RepoError::VectorDimMismatch {
index_dim: 3,
query_dim: 2,
})
));
}
// A query equal to an indexed vector must come back first with a score of
// (approximately) 1.
#[test]
fn identical_query_is_top_hit() {
let id_a = NodeId::new_v7();
let id_b = NodeId::new_v7();
let points = vec![
Point {
vec: normalise(vec![1.0, 0.0, 0.0]),
},
Point {
vec: normalise(vec![0.0, 1.0, 0.0]),
},
];
let points_retained = points.clone();
let inner = Builder::default().build(points, vec![0_usize, 1]);
let idx = HnswVectorIndex {
model: "m".into(),
dim: 3,
ids: vec![id_a, id_b],
points: points_retained,
inner,
ef_search: 10,
};
let hits = idx.search(&[1.0, 0.0, 0.0], 2).unwrap();
assert_eq!(hits[0].node_id, id_a, "exact match should rank #1");
assert!(
(hits[0].score - 1.0).abs() < 1e-5,
"expected cos == 1, got {}",
hits[0].score
);
}
// The reported score is a cosine similarity, not a raw distance: a vector
// orthogonal to the query must score (approximately) 0.
#[test]
fn score_is_cosine_not_euclidean() {
let id_a = NodeId::new_v7();
let id_b = NodeId::new_v7();
let points = vec![
Point {
vec: normalise(vec![1.0, 0.0]),
},
Point {
vec: normalise(vec![0.0, 1.0]),
},
];
let points_retained = points.clone();
let inner = Builder::default().build(points, vec![0_usize, 1]);
let idx = HnswVectorIndex {
model: "m".into(),
dim: 2,
ids: vec![id_a, id_b],
points: points_retained,
inner,
ef_search: 10,
};
let hits = idx.search(&[1.0, 0.0], 2).unwrap();
let orth = hits.iter().find(|h| h.node_id == id_b).unwrap();
assert!(
orth.score.abs() < 1e-5,
"expected orthogonal cos ~= 0; got {}",
orth.score
);
}
// Helper: packs `v` as little-endian f32 bytes into an `Embedding` for
// `model`.
fn f32_embed(model: &str, v: &[f32]) -> Embedding {
let mut bytes = Vec::with_capacity(v.len() * 4);
for x in v {
bytes.extend_from_slice(&x.to_le_bytes());
}
Embedding {
model: model.to_string(),
dtype: Dtype::F32,
dim: u32::try_from(v.len()).expect("test vec fits in u32"),
vector: bytes::Bytes::from(bytes),
}
}
// Helper: fresh in-memory blockstore and op-heads store for a throwaway
// repo.
fn stores() -> (
Arc<dyn mnem_core::store::Blockstore>,
Arc<dyn mnem_core::store::OpHeadsStore>,
) {
(
Arc::new(mnem_core::store::MemoryBlockstore::new()),
Arc::new(mnem_core::store::MemoryOpHeadsStore::new()),
)
}
// End-to-end build: four nodes are committed, two with "mA" embeddings, one
// with an "mB" embedding and one with none. Building for "mA" must index
// exactly the two "mA" nodes and rank the exact match first.
#[test]
fn build_from_repo_reads_sidecar_embeddings() {
let (bs, ohs) = stores();
let repo = ReadonlyRepo::init(bs, ohs).unwrap();
let mut tx = repo.start_transaction();
let id_a = NodeId::from_bytes_raw([1u8; 16]);
let id_b = NodeId::from_bytes_raw([2u8; 16]);
let cid_a = tx.add_node(&Node::new(id_a, "Doc")).unwrap();
let cid_b = tx.add_node(&Node::new(id_b, "Doc")).unwrap();
tx.set_embedding(cid_a, "mA".into(), f32_embed("mA", &[1.0, 0.0]))
.unwrap();
tx.set_embedding(cid_b, "mA".into(), f32_embed("mA", &[0.0, 1.0]))
.unwrap();
// Node embedded under a different model; must not appear in the "mA" index.
let id_c = NodeId::from_bytes_raw([3u8; 16]);
let cid_c = tx.add_node(&Node::new(id_c, "Doc")).unwrap();
tx.set_embedding(cid_c, "mB".into(), f32_embed("mB", &[1.0, 0.0]))
.unwrap();
// Node with no embedding at all.
tx.add_node(&Node::new(NodeId::from_bytes_raw([4u8; 16]), "Doc"))
.unwrap();
let repo = tx.commit("t", "seed").unwrap();
let idx = HnswVectorIndex::build_from_repo(&repo, "mA").unwrap();
assert_eq!(idx.len(), 2, "only the two mA nodes should index");
assert_eq!(idx.dim(), 2);
let hits = idx.search(&[1.0, 0.0], 2).unwrap();
assert_eq!(hits[0].node_id, id_a, "exact-match node should rank #1");
assert!(
(hits[0].score - 1.0).abs() < 1e-5,
"expected cos == 1, got {}",
hits[0].score
);
}
}