use std::sync::Arc;
use bytes::Bytes;
use crate::error::{Error, RepoError};
use crate::id::NodeId;
use crate::objects::{Dtype, Embedding, Node};
use crate::prolly::Cursor;
use crate::repo::readonly::{ReadonlyRepo, decode_from_store};
use crate::store::Blockstore;
/// One result row returned by a vector search: a node and how well it matched.
///
/// Marked `#[non_exhaustive]`, so code outside this crate cannot build it with
/// a struct literal and should go through [`VectorHit::new`].
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub struct VectorHit {
/// Identifier of the node whose embedding matched the query.
pub node_id: NodeId,
/// Similarity score of the match. For the brute-force index in this file
/// this is the dot product of the normalised query and the normalised
/// stored vector.
pub score: f32,
}
impl VectorHit {
#[must_use]
pub const fn new(node_id: NodeId, score: f32) -> Self {
Self { node_id, score }
}
}
/// A searchable index over embedding vectors belonging to a single model.
///
/// Implementations must be `Send + Sync`.
pub trait VectorIndex: Send + Sync {
/// Name of the embedding model whose vectors this index holds.
fn model(&self) -> &str;
/// Number of components per vector in this index.
fn dim(&self) -> u32;
/// Returns at most `k` hits for `query`, or an error.
///
/// NOTE(review): ranking order and the exact error conditions are up to the
/// implementation; the brute-force implementation in this file ranks by
/// descending score and rejects queries whose length differs from `dim()`.
fn search(&self, query: &[f32], k: usize) -> Result<Vec<VectorHit>, Error>;
/// Number of vectors currently indexed.
fn len(&self) -> usize;
/// Returns `true` when `len()` is zero.
fn is_empty(&self) -> bool {
self.len() == 0
}
}
/// Exhaustive-scan vector index: every search scores the query against every
/// stored vector.
///
/// Vectors are normalised when inserted and stored back to back in one flat
/// buffer, so row `i` belongs to `ids[i]` and occupies
/// `data[i * dim..(i + 1) * dim]`.
#[derive(Debug, Clone)]
pub struct BruteForceVectorIndex {
// Embedding model name; embeddings naming a different model are rejected on insert.
model: String,
// Components per stored vector, i.e. the row length inside `data`.
dim: u32,
// Node id for each stored row, in insertion order.
ids: Vec<NodeId>,
// Flat row-major storage of the normalised vectors (`ids.len() * dim` floats).
data: Vec<f32>,
}
impl BruteForceVectorIndex {
    /// Creates an index for `model` holding no vectors, with `dim` components
    /// expected per vector.
    #[must_use]
    pub fn empty(model: impl Into<String>, dim: u32) -> Self {
        Self {
            model: model.into(),
            dim,
            ids: Vec::new(),
            data: Vec::new(),
        }
    }

    /// Name of the embedding model this index accepts.
    #[must_use]
    pub fn model(&self) -> &str {
        &self.model
    }

    /// Number of components per stored vector.
    #[must_use]
    pub const fn dim(&self) -> u32 {
        self.dim
    }

    /// Returns `true` when no vector has been inserted.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.ids.is_empty()
    }

    /// Iterates over `(node id, normalised vector)` pairs in insertion order.
    pub fn points_iter(&self) -> impl Iterator<Item = (NodeId, &[f32])> + '_ {
        let row_len = self.dim as usize;
        self.ids.iter().enumerate().map(move |(i, id)| {
            // A zero-dimensional index stores no floats at all, so hand out an
            // empty slice instead of indexing into `data`.
            let slice = if row_len == 0 {
                &[][..]
            } else {
                &self.data[i * row_len..(i + 1) * row_len]
            };
            (*id, slice)
        })
    }

    /// Inserts `embed` for `node_id`, storing it normalised.
    ///
    /// Returns `false` and leaves the index untouched when the embedding's
    /// model or dimension differs from the index's, or when its bytes cannot
    /// be decoded; returns `true` once the vector has been stored.
    pub fn try_insert(&mut self, node_id: NodeId, embed: &Embedding) -> bool {
        if embed.model != self.model || embed.dim != self.dim {
            return false;
        }
        let Some(decoded) = decode_to_f32(embed) else {
            return false;
        };
        let unit = normalise(decoded);
        self.ids.push(node_id);
        self.data.extend_from_slice(&unit);
        true
    }

    /// Builds an index of every node under the repo's head commit that has an
    /// embedding for `model`.
    ///
    /// The index dimension is taken from the first embedding that is actually
    /// accepted; later embeddings with a different dimension are skipped. When
    /// nothing is accepted the result is an empty index with dimension 0.
    ///
    /// Setting the `MNEM_DEBUG_VEC` environment variable prints diagnostic
    /// counters to stderr.
    ///
    /// # Errors
    ///
    /// Returns [`RepoError::Uninitialized`] when the repo has no head commit,
    /// and propagates failures from walking the node tree, decoding a node,
    /// looking up its embedding, or validating that embedding.
    pub fn build_from_repo(repo: &ReadonlyRepo, model: &str) -> Result<Self, Error> {
        let bs: Arc<dyn Blockstore> = repo.blockstore().clone();
        let Some(commit) = repo.head_commit() else {
            return Err(RepoError::Uninitialized.into());
        };
        let mut idx: Option<Self> = None;
        let debug = std::env::var("MNEM_DEBUG_VEC").is_ok();
        let mut dbg_total = 0usize;
        let mut dbg_has_embed = 0usize;
        let mut dbg_inserted = 0usize;
        let cursor = Cursor::new(&*bs, &commit.nodes)?;
        for entry in cursor {
            let (_k, node_cid) = entry?;
            let node: Node = decode_from_store(&*bs, &node_cid)?;
            dbg_total += 1;
            let Some(embed) = repo.embedding_for(&node_cid, model)? else {
                continue;
            };
            dbg_has_embed += 1;
            // Only the first few embeddings are printed to keep the log short.
            if debug && dbg_has_embed <= 3 {
                eprintln!(
                    "[mnem-debug-vec] node embed.model={:?} want={:?} dim={}",
                    embed.model, model, embed.dim,
                );
            }
            embed.validate()?;
            let ok = match idx.as_mut() {
                Some(existing) => existing.try_insert(node.id, &embed),
                None => {
                    let mut fresh = Self::empty(model, embed.dim);
                    let ok = fresh.try_insert(node.id, &embed);
                    // Keep the fresh index only if its first embedding was
                    // accepted. Otherwise a rejected embedding would pin the
                    // dimension and cause every valid embedding of another
                    // dimension to be dropped.
                    if ok {
                        idx = Some(fresh);
                    }
                    ok
                }
            };
            if ok {
                dbg_inserted += 1;
            }
        }
        if debug {
            eprintln!(
                "[mnem-debug-vec] total={dbg_total} has_embed={dbg_has_embed} \
                 inserted={dbg_inserted} idx_dim={}",
                idx.as_ref().map_or(0, |i| i.dim)
            );
        }
        Ok(idx.unwrap_or_else(|| Self::empty(model, 0)))
    }
}
impl VectorIndex for BruteForceVectorIndex {
    /// Name of the embedding model this index accepts.
    fn model(&self) -> &str {
        &self.model
    }

    /// Number of components per stored vector.
    fn dim(&self) -> u32 {
        self.dim
    }

    /// Scores `query` against every stored vector and returns the best `k`.
    ///
    /// The query is normalised first, so each score is the dot product of two
    /// normalised vectors. Hits are ordered by descending score; equal scores
    /// are ordered by ascending node id, and NaN scores always rank last.
    ///
    /// An index with dimension 0 and no vectors answers any query with an
    /// empty list. `k == 0` or an index without vectors also yields an empty
    /// list.
    ///
    /// # Errors
    ///
    /// Returns [`RepoError::VectorDimMismatch`] when `query.len()` differs
    /// from the index dimension.
    fn search(&self, query: &[f32], k: usize) -> Result<Vec<VectorHit>, Error> {
        // A dimensionless, empty index cannot reject any query shape.
        if self.dim == 0 && self.ids.is_empty() {
            return Ok(Vec::new());
        }
        if query.len() != self.dim as usize {
            return Err(RepoError::VectorDimMismatch {
                index_dim: self.dim,
                query_dim: query.len(),
            }
            .into());
        }
        if k == 0 || self.ids.is_empty() {
            return Ok(Vec::new());
        }
        let q_norm = normalise(query.to_vec());
        let mut hits: Vec<VectorHit> = self
            .points_iter()
            .map(|(node_id, row)| VectorHit {
                node_id,
                score: dot(&q_norm, row),
            })
            .collect();
        hits.sort_by(|a, b| {
            // NaN has no place in the partial order of floats; treating it as
            // "equal to everything" breaks transitivity, and `sort_by` may
            // panic or misorder on a non-total comparator. Rank NaN last.
            let by_score = match (a.score.is_nan(), b.score.is_nan()) {
                (true, true) => std::cmp::Ordering::Equal,
                (true, false) => std::cmp::Ordering::Greater,
                (false, true) => std::cmp::Ordering::Less,
                (false, false) => b
                    .score
                    .partial_cmp(&a.score)
                    .unwrap_or(std::cmp::Ordering::Equal),
            };
            by_score.then_with(|| a.node_id.cmp(&b.node_id))
        });
        hits.truncate(k);
        Ok(hits)
    }

    /// Number of vectors currently stored.
    fn len(&self) -> usize {
        self.ids.len()
    }
}
/// Decodes the raw bytes of `embed` into `f32` components.
///
/// Multi-byte element types are read as little-endian. `F64` values are
/// narrowed with an `as` cast, `F16` values are widened bit-exactly, and `I8`
/// values are converted to their numeric value.
///
/// Returns `None` when the byte length does not equal
/// `dim * dtype.byte_width()`, including when that product does not fit in
/// `usize`.
fn decode_to_f32(embed: &Embedding) -> Option<Vec<f32>> {
    let dim = embed.dim as usize;
    let bytes: &Bytes = &embed.vector;
    // Checked so that a wrapped product on a narrow target can never
    // accidentally match a short buffer.
    let expected_len = dim.checked_mul(embed.dtype.byte_width())?;
    if bytes.len() != expected_len {
        return None;
    }
    let decoded: Vec<f32> = match embed.dtype {
        Dtype::F32 => bytes
            .chunks_exact(4)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect(),
        Dtype::F64 => bytes
            .chunks_exact(8)
            .map(|c| {
                let wide = f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]);
                // Deliberately lossy: the index works in single precision.
                wide as f32
            })
            .collect(),
        Dtype::F16 => bytes
            .chunks_exact(2)
            .map(|c| f16_bits_to_f32(u16::from_le_bytes([c[0], c[1]])))
            .collect(),
        Dtype::I8 => bytes
            .iter()
            .map(|&b| f32::from(i8::from_ne_bytes([b])))
            .collect(),
    };
    Some(decoded)
}
/// Widens an IEEE 754 half-precision bit pattern to the `f32` with the same
/// value.
///
/// Zeros keep their sign, subnormal halves become normal `f32` values,
/// infinities stay infinite and NaN mantissa bits are carried over.
fn f16_bits_to_f32(bits: u16) -> f32 {
    // Difference between the f32 and f16 exponent biases.
    const REBIAS: u32 = 127 - 15;
    // The f32 mantissa is 13 bits wider than the f16 mantissa.
    const MANT_SHIFT: u32 = 13;

    let sign_bit = u32::from(bits & 0x8000) << 16;
    let exponent = u32::from((bits >> 10) & 0x1F);
    let fraction = u32::from(bits & 0x03FF);

    let magnitude = match (exponent, fraction) {
        // Signed zero.
        (0, 0) => 0,
        // Subnormal half: shift the leading set bit up to the implicit-one
        // position (bit 10) and lower the exponent once per shifted place.
        (0, frac) => {
            let shift = frac.leading_zeros() - 21;
            let rebiased = (REBIAS + 1) - shift;
            let trimmed = (frac << shift) & 0x03FF;
            (rebiased << 23) | (trimmed << MANT_SHIFT)
        }
        // Infinity (fraction == 0) or NaN (fraction != 0).
        (0x1F, frac) => 0x7F80_0000 | (frac << MANT_SHIFT),
        // Ordinary normal number.
        (exp, frac) => ((exp + REBIAS) << 23) | (frac << MANT_SHIFT),
    };
    f32::from_bits(sign_bit | magnitude)
}
/// Scales `v` to unit Euclidean length and returns it.
///
/// The vector is handed back untouched when its length is zero, NaN or
/// infinite, because dividing by such a length would not give a usable
/// direction.
fn normalise(mut v: Vec<f32>) -> Vec<f32> {
    let length = dot(&v, &v).sqrt();
    let usable = length.is_finite() && length > 0.0;
    if !usable {
        return v;
    }
    v.iter_mut().for_each(|component| *component /= length);
    v
}
/// Dot product of `a` and `b`, accumulated left to right in `f32`.
///
/// Both slices are expected to have the same length; debug builds assert it.
/// Without that assertion, extra trailing elements of `b` are ignored and a
/// `b` shorter than `a` panics.
fn dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // One up-front bounds check instead of one per element.
    let rhs = &b[..a.len()];
    a.iter().zip(rhs).fold(0.0f32, |sum, (x, y)| sum + x * y)
}
/// Unit tests for the numeric helpers and the brute-force index.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::objects::{Dtype, Embedding, Node};
    use crate::repo::ReadonlyRepo;
    use crate::store::{MemoryBlockstore, MemoryOpHeadsStore, OpHeadsStore};
    use std::sync::Arc;

    /// Fresh in-memory stores for a throwaway repo.
    fn stores() -> (Arc<dyn Blockstore>, Arc<dyn OpHeadsStore>) {
        (
            Arc::new(MemoryBlockstore::new()),
            Arc::new(MemoryOpHeadsStore::new()),
        )
    }

    /// Builds an `F32` embedding for `model` from the components in `v`.
    fn f32_embed(model: &str, v: &[f32]) -> Embedding {
        let mut bytes = Vec::with_capacity(v.len() * 4);
        for x in v {
            bytes.extend_from_slice(&x.to_le_bytes());
        }
        Embedding {
            model: model.to_string(),
            dtype: Dtype::F32,
            dim: v.len() as u32,
            vector: Bytes::from(bytes),
        }
    }

    #[test]
    fn normalise_unit_vector_is_unchanged() {
        let v = normalise(vec![1.0, 0.0, 0.0]);
        // Dividing by a norm of exactly 1.0 must leave every component as is.
        assert_eq!(v, vec![1.0, 0.0, 0.0]);
        assert!((dot(&v, &v) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn normalise_scales_to_unit_length() {
        let v = normalise(vec![3.0, 4.0]);
        assert!((dot(&v, &v) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn normalise_zero_vector_stays_zero() {
        let v = normalise(vec![0.0, 0.0, 0.0]);
        assert_eq!(v, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn f16_round_trip_for_common_values() {
        assert!((f16_bits_to_f32(0x3C00) - 1.0).abs() < 1e-6);
        assert!((f16_bits_to_f32(0xBC00) + 1.0).abs() < 1e-6);
        assert_eq!(f16_bits_to_f32(0x0000), 0.0);
        assert_eq!(f16_bits_to_f32(0x8000), -0.0);
        // `0.0 == -0.0`, so the sign has to be checked separately.
        assert!(f16_bits_to_f32(0x0000).is_sign_positive());
        assert!(f16_bits_to_f32(0x8000).is_sign_negative());
        assert!(f16_bits_to_f32(0x7C00).is_infinite());
        assert_eq!(f16_bits_to_f32(0xFC00), f32::NEG_INFINITY);
        assert!(f16_bits_to_f32(0x7E00).is_nan());
        // Smallest and largest subnormal halves: mantissa * 2^-24.
        assert_eq!(f16_bits_to_f32(0x0001), 1.0 / 16_777_216.0);
        assert_eq!(f16_bits_to_f32(0x03FF), 1023.0 / 16_777_216.0);
    }

    #[test]
    fn empty_index_returns_no_hits() {
        let idx = BruteForceVectorIndex::empty("m", 4);
        let hits = idx.search(&[0.0, 0.0, 0.0, 0.0], 5).unwrap();
        assert!(hits.is_empty());
        assert_eq!(idx.len(), 0);
        assert!(idx.is_empty());
    }

    #[test]
    fn k_zero_returns_no_hits() {
        let mut idx = BruteForceVectorIndex::empty("m", 3);
        idx.try_insert(
            NodeId::from_bytes_raw([1u8; 16]),
            &f32_embed("m", &[1.0, 0.0, 0.0]),
        );
        let hits = idx.search(&[1.0, 0.0, 0.0], 0).unwrap();
        assert!(hits.is_empty());
    }

    #[test]
    fn dim_mismatch_errors_with_both_sides() {
        let idx = BruteForceVectorIndex::empty("m", 4);
        let err = idx.search(&[0.0, 0.0, 0.0], 3).unwrap_err();
        match err {
            Error::Repo(RepoError::VectorDimMismatch {
                index_dim,
                query_dim,
            }) => {
                assert_eq!(index_dim, 4);
                assert_eq!(query_dim, 3);
            }
            e => panic!("expected VectorDimMismatch, got {e:?}"),
        }
    }

    #[test]
    fn wrong_model_is_silently_skipped_on_insert() {
        let mut idx = BruteForceVectorIndex::empty("mA", 3);
        let inserted = idx.try_insert(
            NodeId::from_bytes_raw([1u8; 16]),
            &f32_embed("mB", &[1.0, 0.0, 0.0]),
        );
        assert!(!inserted);
        assert!(idx.is_empty());
    }

    #[test]
    fn wrong_dim_is_silently_skipped_on_insert() {
        let mut idx = BruteForceVectorIndex::empty("m", 3);
        let inserted = idx.try_insert(
            NodeId::from_bytes_raw([1u8; 16]),
            &f32_embed("m", &[1.0, 0.0]),
        );
        assert!(!inserted);
        // A rejected insert must not leave a partial row behind.
        assert!(idx.is_empty());
    }

    #[test]
    fn nearest_neighbour_wins() {
        let mut idx = BruteForceVectorIndex::empty("m", 3);
        idx.try_insert(
            NodeId::from_bytes_raw([1u8; 16]),
            &f32_embed("m", &[1.0, 0.0, 0.0]),
        );
        idx.try_insert(
            NodeId::from_bytes_raw([2u8; 16]),
            &f32_embed("m", &[0.0, 1.0, 0.0]),
        );
        idx.try_insert(
            NodeId::from_bytes_raw([3u8; 16]),
            &f32_embed("m", &[0.0, 0.0, 1.0]),
        );
        let hits = idx.search(&[0.9, 0.1, 0.0], 3).unwrap();
        assert_eq!(hits[0].node_id, NodeId::from_bytes_raw([1u8; 16]));
        assert_eq!(hits[1].node_id, NodeId::from_bytes_raw([2u8; 16]));
        assert_eq!(hits[2].node_id, NodeId::from_bytes_raw([3u8; 16]));
        assert!((hits[2].score).abs() < 1e-6);
    }

    #[test]
    fn scale_invariance_cosine_similarity() {
        let mut idx = BruteForceVectorIndex::empty("m", 3);
        idx.try_insert(
            NodeId::from_bytes_raw([1u8; 16]),
            &f32_embed("m", &[10.0, 0.0, 0.0]),
        );
        let hits = idx.search(&[0.5, 0.0, 0.0], 1).unwrap();
        assert!((hits[0].score - 1.0).abs() < 1e-5);
    }

    #[test]
    fn k_truncates_results() {
        let mut idx = BruteForceVectorIndex::empty("m", 2);
        for i in 0..20u8 {
            idx.try_insert(
                NodeId::from_bytes_raw([i; 16]),
                &f32_embed("m", &[f32::from(i), 1.0]),
            );
        }
        let hits = idx.search(&[1.0, 1.0], 5).unwrap();
        assert_eq!(hits.len(), 5);
    }

    #[test]
    fn ties_broken_by_node_id_ascending() {
        let mut idx = BruteForceVectorIndex::empty("m", 2);
        let hi = NodeId::from_bytes_raw([0xFFu8; 16]);
        let lo = NodeId::from_bytes_raw([0x01u8; 16]);
        // Insert the larger id first so insertion order cannot explain the result.
        idx.try_insert(hi, &f32_embed("m", &[1.0, 0.0]));
        idx.try_insert(lo, &f32_embed("m", &[1.0, 0.0]));
        let hits = idx.search(&[1.0, 0.0], 2).unwrap();
        assert_eq!(hits[0].node_id, lo);
        assert_eq!(hits[1].node_id, hi);
    }

    #[test]
    fn f64_embeddings_are_indexed() {
        let mut bytes = Vec::new();
        for x in &[1.0f64, 0.0, 0.0] {
            bytes.extend_from_slice(&x.to_le_bytes());
        }
        let embed = Embedding {
            model: "m".into(),
            dtype: Dtype::F64,
            dim: 3,
            vector: Bytes::from(bytes),
        };
        let mut idx = BruteForceVectorIndex::empty("m", 3);
        assert!(idx.try_insert(NodeId::from_bytes_raw([1u8; 16]), &embed));
        let hits = idx.search(&[1.0, 0.0, 0.0], 1).unwrap();
        assert!((hits[0].score - 1.0).abs() < 1e-5);
    }

    #[test]
    fn i8_embeddings_are_indexed() {
        let bytes: Vec<u8> = vec![127, 0, 0].into_iter().map(|v: i8| v as u8).collect();
        let embed = Embedding {
            model: "m".into(),
            dtype: Dtype::I8,
            dim: 3,
            vector: Bytes::from(bytes),
        };
        let mut idx = BruteForceVectorIndex::empty("m", 3);
        assert!(idx.try_insert(NodeId::from_bytes_raw([1u8; 16]), &embed));
        let hits = idx.search(&[1.0, 0.0, 0.0], 1).unwrap();
        assert!((hits[0].score - 1.0).abs() < 1e-5);
    }

    #[test]
    fn f16_embeddings_are_indexed() {
        // Little-endian halves: 0x3C00 (1.0) followed by 0x0000 (0.0).
        let bytes: Vec<u8> = vec![0x00, 0x3C, 0x00, 0x00];
        let embed = Embedding {
            model: "m".into(),
            dtype: Dtype::F16,
            dim: 2,
            vector: Bytes::from(bytes),
        };
        let mut idx = BruteForceVectorIndex::empty("m", 2);
        assert!(idx.try_insert(NodeId::from_bytes_raw([1u8; 16]), &embed));
        let hits = idx.search(&[1.0, 0.0], 1).unwrap();
        assert!((hits[0].score - 1.0).abs() < 1e-5);
    }

    #[test]
    fn build_from_repo_indexes_only_matching_model() {
        let (bs, ohs) = stores();
        let repo = ReadonlyRepo::init(bs, ohs).unwrap();
        let mut tx = repo.start_transaction();
        let mut add = |id: [u8; 16], model: &str, v: &[f32]| {
            let node = Node::new(NodeId::from_bytes_raw(id), "Doc");
            let cid = tx.add_node(&node).unwrap();
            let emb = f32_embed(model, v);
            tx.set_embedding(cid, emb.model.clone(), emb).unwrap();
        };
        add([1u8; 16], "mA", &[1.0, 0.0]);
        add([2u8; 16], "mA", &[0.0, 1.0]);
        add([3u8; 16], "mB", &[1.0, 0.0]);
        // A node without any embedding must simply be skipped.
        tx.add_node(&Node::new(NodeId::from_bytes_raw([4u8; 16]), "Doc"))
            .unwrap();
        let repo = tx.commit("t", "seed").unwrap();
        let idx = BruteForceVectorIndex::build_from_repo(&repo, "mA").unwrap();
        assert_eq!(idx.len(), 2);
        assert_eq!(idx.dim(), 2);
        assert_eq!(idx.model(), "mA");
        let hits = idx.search(&[1.0, 0.0], 2).unwrap();
        assert_eq!(hits[0].node_id, NodeId::from_bytes_raw([1u8; 16]));
    }

    #[test]
    fn build_for_absent_model_returns_empty_index() {
        let (bs, ohs) = stores();
        let repo = ReadonlyRepo::init(bs, ohs).unwrap();
        let mut tx = repo.start_transaction();
        let cid = tx
            .add_node(&Node::new(NodeId::from_bytes_raw([1u8; 16]), "Doc"))
            .unwrap();
        let emb = f32_embed("mA", &[1.0, 0.0]);
        tx.set_embedding(cid, emb.model.clone(), emb).unwrap();
        let repo = tx.commit("t", "seed").unwrap();
        let idx = BruteForceVectorIndex::build_from_repo(&repo, "unknown").unwrap();
        assert!(idx.is_empty());
        assert_eq!(idx.model(), "unknown");
    }

    #[test]
    fn build_on_empty_repo_errors() {
        let (bs, ohs) = stores();
        let repo = ReadonlyRepo::init(bs, ohs).unwrap();
        let err = BruteForceVectorIndex::build_from_repo(&repo, "mA").unwrap_err();
        match err {
            Error::Repo(RepoError::Uninitialized) => {}
            e => panic!("expected Uninitialized, got {e:?}"),
        }
    }

    #[test]
    fn determinism_same_repo_same_results() {
        // Builds an identical repo from scratch and runs the same query.
        let build = || {
            let (bs, ohs) = stores();
            let repo = ReadonlyRepo::init(bs, ohs).unwrap();
            let mut tx = repo.start_transaction();
            for i in 0..5u8 {
                let cid = tx
                    .add_node(&Node::new(NodeId::from_bytes_raw([i; 16]), "Doc"))
                    .unwrap();
                let emb = f32_embed("m", &[f32::from(i), 1.0]);
                tx.set_embedding(cid, emb.model.clone(), emb).unwrap();
            }
            let repo = tx.commit("t", "seed").unwrap();
            let idx = BruteForceVectorIndex::build_from_repo(&repo, "m").unwrap();
            idx.search(&[2.0, 1.0], 3).unwrap()
        };
        let a = build();
        let b = build();
        assert_eq!(a, b, "same inputs -> byte-identical hit list");
    }

    #[test]
    fn index_reads_embedding_from_sidecar() {
        let (bs, ohs) = stores();
        let repo = ReadonlyRepo::init(bs, ohs).unwrap();
        let mut tx = repo.start_transaction();
        let node = Node::new(NodeId::from_bytes_raw([1u8; 16]), "Doc");
        let node_cid = tx.add_node(&node).unwrap();
        let emb = f32_embed("mA", &[1.0, 0.0, 0.0]);
        tx.set_embedding(node_cid, "mA".into(), emb).unwrap();
        let repo = tx.commit("t", "seed via sidecar").unwrap();
        let idx = BruteForceVectorIndex::build_from_repo(&repo, "mA").unwrap();
        assert_eq!(idx.len(), 1, "sidecar embedding must surface in the index");
        assert_eq!(idx.dim(), 3);
        let hits = idx.search(&[1.0, 0.0, 0.0], 1).unwrap();
        assert_eq!(hits[0].node_id, NodeId::from_bytes_raw([1u8; 16]));
        assert!((hits[0].score - 1.0).abs() < 1e-5);
    }
}