1use std::sync::Arc;
49
50use bytes::Bytes;
51
52use crate::error::{Error, RepoError};
53use crate::id::NodeId;
54use crate::objects::{Dtype, Embedding, Node};
55use crate::prolly::Cursor;
56use crate::repo::readonly::{ReadonlyRepo, decode_from_store};
57use crate::store::Blockstore;
58
59#[derive(Clone, Debug, PartialEq)]
69#[non_exhaustive]
70pub struct VectorHit {
71 pub node_id: NodeId,
73 pub score: f32,
75}
76
77impl VectorHit {
78 #[must_use]
81 pub const fn new(node_id: NodeId, score: f32) -> Self {
82 Self { node_id, score }
83 }
84}
85
86pub trait VectorIndex: Send + Sync {
92 fn model(&self) -> &str;
94
95 fn dim(&self) -> u32;
97
98 fn search(&self, query: &[f32], k: usize) -> Result<Vec<VectorHit>, Error>;
104
105 fn len(&self) -> usize;
107
108 fn is_empty(&self) -> bool {
110 self.len() == 0
111 }
112}
113
114#[derive(Debug, Clone)]
125pub struct BruteForceVectorIndex {
126 model: String,
127 dim: u32,
128 ids: Vec<NodeId>,
129 data: Vec<f32>,
131}
132
133impl BruteForceVectorIndex {
134 #[must_use]
140 pub fn empty(model: impl Into<String>, dim: u32) -> Self {
141 Self {
142 model: model.into(),
143 dim,
144 ids: Vec::new(),
145 data: Vec::new(),
146 }
147 }
148
149 #[must_use]
155 pub fn model(&self) -> &str {
156 &self.model
157 }
158
159 #[must_use]
162 pub const fn dim(&self) -> u32 {
163 self.dim
164 }
165
166 #[must_use]
168 pub fn is_empty(&self) -> bool {
169 self.ids.is_empty()
170 }
171
172 pub fn points_iter(&self) -> impl Iterator<Item = (NodeId, &[f32])> + '_ {
183 let row_len = self.dim as usize;
184 self.ids.iter().enumerate().map(move |(i, id)| {
185 let slice = if row_len == 0 {
188 &[][..]
189 } else {
190 &self.data[i * row_len..(i + 1) * row_len]
191 };
192 (*id, slice)
193 })
194 }
195
196 pub fn try_insert(&mut self, node_id: NodeId, embed: &Embedding) -> bool {
203 if embed.model != self.model {
204 return false;
205 }
206 if embed.dim != self.dim {
207 return false;
208 }
209 let Some(vec_f32) = decode_to_f32(embed) else {
210 return false;
211 };
212 let normalised = normalise(vec_f32);
213 self.ids.push(node_id);
214 self.data.extend_from_slice(&normalised);
215 true
216 }
217
218 pub fn build_from_repo(repo: &ReadonlyRepo, model: &str) -> Result<Self, Error> {
240 let bs: Arc<dyn Blockstore> = repo.blockstore().clone();
241 let Some(commit) = repo.head_commit() else {
242 return Err(RepoError::Uninitialized.into());
243 };
244
245 let mut idx: Option<Self> = None;
250 let debug = std::env::var("MNEM_DEBUG_VEC").is_ok();
251 let mut dbg_total = 0usize;
252 let mut dbg_has_embed = 0usize;
253 let mut dbg_inserted = 0usize;
254 let cursor = Cursor::new(&*bs, &commit.nodes)?;
255 for entry in cursor {
256 let (_k, node_cid) = entry?;
257 let node: Node = decode_from_store(&*bs, &node_cid)?;
258 dbg_total += 1;
259
260 let Some(embed) = repo.embedding_for(&node_cid, model)? else {
264 continue;
265 };
266 dbg_has_embed += 1;
267 if debug && dbg_has_embed <= 3 {
268 eprintln!(
269 "[mnem-debug-vec] node embed.model={:?} want={:?} dim={}",
270 embed.model, model, embed.dim,
271 );
272 }
273 embed.validate()?;
274 let ok = match idx.as_mut() {
275 Some(existing) => existing.try_insert(node.id, &embed),
276 None => {
277 let mut fresh = Self::empty(model, embed.dim);
278 let ok = fresh.try_insert(node.id, &embed);
279 idx = Some(fresh);
280 ok
281 }
282 };
283 if ok {
284 dbg_inserted += 1;
285 }
286 }
287 if debug {
288 eprintln!(
289 "[mnem-debug-vec] total={dbg_total} has_embed={dbg_has_embed} \
290 inserted={dbg_inserted} idx_dim={}",
291 idx.as_ref().map_or(0, |i| i.dim)
292 );
293 }
294 Ok(idx.unwrap_or_else(|| Self::empty(model, 0)))
297 }
298}
299
300impl VectorIndex for BruteForceVectorIndex {
301 fn model(&self) -> &str {
302 &self.model
303 }
304
305 fn dim(&self) -> u32 {
306 self.dim
307 }
308
309 fn search(&self, query: &[f32], k: usize) -> Result<Vec<VectorHit>, Error> {
310 if self.dim == 0 && self.ids.is_empty() {
318 return Ok(Vec::new());
319 }
320 if query.len() != self.dim as usize {
321 return Err(RepoError::VectorDimMismatch {
322 index_dim: self.dim,
323 query_dim: query.len(),
324 }
325 .into());
326 }
327 if k == 0 || self.ids.is_empty() {
328 return Ok(Vec::new());
329 }
330
331 let q_norm = normalise(query.to_vec());
332 let row_len = self.dim as usize;
333 let mut hits: Vec<VectorHit> = Vec::with_capacity(self.ids.len());
334 for (i, id) in self.ids.iter().enumerate() {
335 let row = &self.data[i * row_len..(i + 1) * row_len];
336 let score = dot(&q_norm, row);
337 hits.push(VectorHit {
338 node_id: *id,
339 score,
340 });
341 }
342 hits.sort_by(|a, b| {
344 b.score
345 .partial_cmp(&a.score)
346 .unwrap_or(std::cmp::Ordering::Equal)
347 .then_with(|| a.node_id.cmp(&b.node_id))
348 });
349 hits.truncate(k);
350 Ok(hits)
351 }
352
353 fn len(&self) -> usize {
354 self.ids.len()
355 }
356}
357
358fn decode_to_f32(embed: &Embedding) -> Option<Vec<f32>> {
366 let dim = embed.dim as usize;
367 let bytes: &Bytes = &embed.vector;
368 if bytes.len() != dim * embed.dtype.byte_width() {
369 return None;
370 }
371 match embed.dtype {
372 Dtype::F32 => {
373 let mut out = Vec::with_capacity(dim);
374 for chunk in bytes.chunks_exact(4) {
375 out.push(f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]));
376 }
377 Some(out)
378 }
379 Dtype::F64 => {
380 let mut out = Vec::with_capacity(dim);
381 for chunk in bytes.chunks_exact(8) {
382 let raw = f64::from_le_bytes([
383 chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6], chunk[7],
384 ]);
385 out.push(raw as f32);
386 }
387 Some(out)
388 }
389 Dtype::F16 => {
390 let mut out = Vec::with_capacity(dim);
393 for chunk in bytes.chunks_exact(2) {
394 let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
395 out.push(f16_bits_to_f32(bits));
396 }
397 Some(out)
398 }
399 Dtype::I8 => {
400 let mut out = Vec::with_capacity(dim);
404 for &b in bytes {
405 out.push(f32::from(i8::from_ne_bytes([b])));
406 }
407 Some(out)
408 }
409 }
410}
411
/// Widens an IEEE 754 half-precision bit pattern to `f32`.
///
/// Zeros keep their sign, subnormals are renormalised, and infinities and
/// NaNs map to the `f32` all-ones exponent with the mantissa bits moved up.
fn f16_bits_to_f32(bits: u16) -> f32 {
    // Difference between the f32 exponent bias (127) and the f16 bias (15).
    const BIAS_SHIFT: u32 = 127 - 15;

    let sign = u32::from(bits & 0x8000) << 16;
    let exp = u32::from((bits >> 10) & 0x1F);
    let mant = u32::from(bits & 0x03FF);

    let magnitude = match (exp, mant) {
        // Signed zero.
        (0, 0) => 0,
        // Subnormal: shift the mantissa until its top set bit reaches the
        // implicit-one position (bit 10), lowering the exponent once per
        // shifted position.
        (0, _) => {
            let shift = mant.leading_zeros() - 21;
            let e = BIAS_SHIFT + 1 - shift;
            let m = (mant << shift) & 0x03FF;
            (e << 23) | (m << 13)
        }
        // Infinity (mantissa zero) or NaN (mantissa non-zero).
        (0x1F, _) => 0x7F80_0000 | (mant << 13),
        // Ordinary normal number: rebias the exponent, widen the mantissa.
        _ => ((exp + BIAS_SHIFT) << 23) | (mant << 13),
    };
    f32::from_bits(sign | magnitude)
}
443
/// Scales `v` to unit Euclidean length and returns it.
///
/// The norm is accumulated in `f64`: squaring `f32` components directly
/// overflows to infinity for large values and underflows to zero for tiny
/// ones, which would leave such vectors un-normalised.
///
/// The vector is returned unchanged when its norm is zero or not finite
/// (all-zero input, or input containing NaN or an infinity).
fn normalise(mut v: Vec<f32>) -> Vec<f32> {
    let norm = v
        .iter()
        .map(|&x| f64::from(x) * f64::from(x))
        .sum::<f64>()
        .sqrt();
    if norm > 0.0 && norm.is_finite() {
        for x in &mut v {
            // Divide in `f64` because the norm itself may exceed the `f32`
            // range; the quotient has magnitude at most 1, so narrowing it
            // back to `f32` is safe.
            *x = (f64::from(*x) / norm) as f32;
        }
    }
    v
}
455
/// Dot product of `a` and `b`, accumulated left to right in `f32`.
///
/// Both slices are expected to have the same length; this is checked in
/// debug builds. Only the first `a.len()` elements of `b` are read.
///
/// # Panics
///
/// Panics when `b` is shorter than `a`.
fn dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // One up-front bounds check instead of one per element.
    let rhs = &b[..a.len()];
    a.iter().zip(rhs).fold(0.0f32, |acc, (x, y)| acc + x * y)
}
466
#[cfg(test)]
mod tests {
    //! Unit tests for the numeric helpers and for `BruteForceVectorIndex`,
    //! including indexes built from an in-memory repository.

    use super::*;
    use crate::objects::{Dtype, Embedding, Node};
    use crate::repo::ReadonlyRepo;
    use crate::store::{MemoryBlockstore, MemoryOpHeadsStore, OpHeadsStore};
    use std::sync::Arc;

    /// Fresh in-memory block store and op-heads store for a throwaway repo.
    fn stores() -> (Arc<dyn Blockstore>, Arc<dyn OpHeadsStore>) {
        (
            Arc::new(MemoryBlockstore::new()),
            Arc::new(MemoryOpHeadsStore::new()),
        )
    }

    /// Builds an `F32` embedding for `model` holding the components of `v`
    /// encoded little-endian.
    fn f32_embed(model: &str, v: &[f32]) -> Embedding {
        let mut bytes = Vec::with_capacity(v.len() * 4);
        for x in v {
            bytes.extend_from_slice(&x.to_le_bytes());
        }
        Embedding {
            model: model.to_string(),
            dtype: Dtype::F32,
            dim: u32::try_from(v.len()).expect("test vectors are short"),
            vector: Bytes::from(bytes),
        }
    }

    /// Node id whose sixteen bytes all equal `byte`.
    fn id(byte: u8) -> NodeId {
        NodeId::from_bytes_raw([byte; 16])
    }

    #[test]
    fn normalise_unit_vector_is_unchanged() {
        let v = normalise(vec![1.0, 0.0, 0.0]);
        assert!((dot(&v, &v) - 1.0).abs() < 1e-6);
        // The name promises "unchanged", so check the components too.
        assert_eq!(v, vec![1.0, 0.0, 0.0]);
    }

    #[test]
    fn normalise_scales_to_unit_length() {
        let v = normalise(vec![3.0, 4.0]);
        assert!((dot(&v, &v) - 1.0).abs() < 1e-6);
        assert!((v[0] - 0.6).abs() < 1e-6);
        assert!((v[1] - 0.8).abs() < 1e-6);
    }

    #[test]
    fn normalise_zero_vector_stays_zero() {
        let v = normalise(vec![0.0, 0.0, 0.0]);
        assert_eq!(v, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn f16_round_trip_for_common_values() {
        assert!((f16_bits_to_f32(0x3C00) - 1.0).abs() < 1e-6);
        assert!((f16_bits_to_f32(0xBC00) + 1.0).abs() < 1e-6);
        assert_eq!(f16_bits_to_f32(0x0000), 0.0);
        assert!(f16_bits_to_f32(0x0000).is_sign_positive());
        // `-0.0 == 0.0` under float comparison, so check the sign bit too.
        assert_eq!(f16_bits_to_f32(0x8000), -0.0);
        assert!(f16_bits_to_f32(0x8000).is_sign_negative());
        assert!(f16_bits_to_f32(0x7C00).is_infinite());
    }

    #[test]
    fn f16_subnormals_and_nan_decode() {
        // Smallest positive subnormal is 2^-24.
        assert_eq!(f16_bits_to_f32(0x0001), 1.0_f32 / 16_777_216.0);
        // Largest subnormal is 1023 * 2^-24.
        assert_eq!(f16_bits_to_f32(0x03FF), 1023.0_f32 / 16_777_216.0);
        assert!(f16_bits_to_f32(0x7E00).is_nan());
        assert_eq!(f16_bits_to_f32(0xFC00), f32::NEG_INFINITY);
    }

    #[test]
    fn empty_index_returns_no_hits() {
        let idx = BruteForceVectorIndex::empty("m", 4);
        let hits = idx.search(&[0.0, 0.0, 0.0, 0.0], 5).unwrap();
        assert!(hits.is_empty());
        assert_eq!(idx.len(), 0);
        assert!(idx.is_empty());
    }

    #[test]
    fn zero_dim_empty_index_accepts_any_query() {
        let idx = BruteForceVectorIndex::empty("m", 0);
        assert!(idx.search(&[1.0, 2.0, 3.0], 3).unwrap().is_empty());
        assert_eq!(idx.points_iter().count(), 0);
    }

    #[test]
    fn k_zero_returns_no_hits() {
        let mut idx = BruteForceVectorIndex::empty("m", 3);
        assert!(idx.try_insert(id(1), &f32_embed("m", &[1.0, 0.0, 0.0])));
        let hits = idx.search(&[1.0, 0.0, 0.0], 0).unwrap();
        assert!(hits.is_empty());
    }

    #[test]
    fn dim_mismatch_errors_with_both_sides() {
        let idx = BruteForceVectorIndex::empty("m", 4);
        let err = idx.search(&[0.0, 0.0, 0.0], 3).unwrap_err();
        match err {
            Error::Repo(RepoError::VectorDimMismatch {
                index_dim,
                query_dim,
            }) => {
                assert_eq!(index_dim, 4);
                assert_eq!(query_dim, 3);
            }
            e => panic!("expected VectorDimMismatch, got {e:?}"),
        }
    }

    #[test]
    fn wrong_model_is_silently_skipped_on_insert() {
        let mut idx = BruteForceVectorIndex::empty("mA", 3);
        let inserted = idx.try_insert(id(1), &f32_embed("mB", &[1.0, 0.0, 0.0]));
        assert!(!inserted);
        assert!(idx.is_empty());
    }

    #[test]
    fn wrong_dim_is_silently_skipped_on_insert() {
        let mut idx = BruteForceVectorIndex::empty("m", 3);
        let inserted = idx.try_insert(id(1), &f32_embed("m", &[1.0, 0.0]));
        assert!(!inserted);
        assert!(idx.is_empty());
    }

    #[test]
    fn points_iter_yields_normalised_rows() {
        let mut idx = BruteForceVectorIndex::empty("m", 2);
        assert!(idx.try_insert(id(7), &f32_embed("m", &[3.0, 4.0])));
        let rows: Vec<(NodeId, Vec<f32>)> = idx
            .points_iter()
            .map(|(node_id, row)| (node_id, row.to_vec()))
            .collect();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].0, id(7));
        assert!((rows[0].1[0] - 0.6).abs() < 1e-6);
        assert!((rows[0].1[1] - 0.8).abs() < 1e-6);
    }

    #[test]
    fn nearest_neighbour_wins() {
        let mut idx = BruteForceVectorIndex::empty("m", 3);
        assert!(idx.try_insert(id(1), &f32_embed("m", &[1.0, 0.0, 0.0])));
        assert!(idx.try_insert(id(2), &f32_embed("m", &[0.0, 1.0, 0.0])));
        assert!(idx.try_insert(id(3), &f32_embed("m", &[0.0, 0.0, 1.0])));
        let hits = idx.search(&[0.9, 0.1, 0.0], 3).unwrap();
        assert_eq!(hits[0].node_id, id(1));
        assert_eq!(hits[1].node_id, id(2));
        assert_eq!(hits[2].node_id, id(3));
        // The third axis is orthogonal to the query.
        assert!((hits[2].score).abs() < 1e-6);
    }

    #[test]
    fn scale_invariance_cosine_similarity() {
        let mut idx = BruteForceVectorIndex::empty("m", 3);
        // Stored vector and query point the same way but differ in length.
        assert!(idx.try_insert(id(1), &f32_embed("m", &[10.0, 0.0, 0.0])));
        let hits = idx.search(&[0.5, 0.0, 0.0], 1).unwrap();
        assert!((hits[0].score - 1.0).abs() < 1e-5);
    }

    #[test]
    fn k_truncates_results() {
        let mut idx = BruteForceVectorIndex::empty("m", 2);
        for i in 0..20u8 {
            assert!(idx.try_insert(id(i), &f32_embed("m", &[f32::from(i), 1.0])));
        }
        let hits = idx.search(&[1.0, 1.0], 5).unwrap();
        assert_eq!(hits.len(), 5);
        // The kept hits must be the head of a descending ranking.
        assert!(hits.windows(2).all(|w| w[0].score >= w[1].score));
    }

    #[test]
    fn ties_broken_by_node_id_ascending() {
        let mut idx = BruteForceVectorIndex::empty("m", 2);
        let hi = id(0xFF);
        let lo = id(0x01);
        // Insert the larger id first so insertion order cannot explain the
        // result.
        assert!(idx.try_insert(hi, &f32_embed("m", &[1.0, 0.0])));
        assert!(idx.try_insert(lo, &f32_embed("m", &[1.0, 0.0])));
        let hits = idx.search(&[1.0, 0.0], 2).unwrap();
        assert_eq!(hits[0].node_id, lo);
        assert_eq!(hits[1].node_id, hi);
    }

    #[test]
    fn f64_embeddings_are_indexed() {
        let mut bytes = Vec::new();
        for x in &[1.0f64, 0.0, 0.0] {
            bytes.extend_from_slice(&x.to_le_bytes());
        }
        let embed = Embedding {
            model: "m".into(),
            dtype: Dtype::F64,
            dim: 3,
            vector: Bytes::from(bytes),
        };
        let mut idx = BruteForceVectorIndex::empty("m", 3);
        assert!(idx.try_insert(id(1), &embed));
        let hits = idx.search(&[1.0, 0.0, 0.0], 1).unwrap();
        assert!((hits[0].score - 1.0).abs() < 1e-5);
    }

    #[test]
    fn i8_embeddings_are_indexed() {
        let bytes: Vec<u8> = vec![127, 0, 0].into_iter().map(|v: i8| v as u8).collect();
        let embed = Embedding {
            model: "m".into(),
            dtype: Dtype::I8,
            dim: 3,
            vector: Bytes::from(bytes),
        };
        let mut idx = BruteForceVectorIndex::empty("m", 3);
        assert!(idx.try_insert(id(1), &embed));
        let hits = idx.search(&[1.0, 0.0, 0.0], 1).unwrap();
        assert!((hits[0].score - 1.0).abs() < 1e-5);
    }

    #[test]
    fn f16_embeddings_are_indexed() {
        // Little-endian halves: 0x3C00 (1.0) followed by 0x0000 (0.0).
        let bytes: Vec<u8> = vec![0x00, 0x3C, 0x00, 0x00];
        let embed = Embedding {
            model: "m".into(),
            dtype: Dtype::F16,
            dim: 2,
            vector: Bytes::from(bytes),
        };
        let mut idx = BruteForceVectorIndex::empty("m", 2);
        assert!(idx.try_insert(id(1), &embed));
        let hits = idx.search(&[1.0, 0.0], 1).unwrap();
        assert!((hits[0].score - 1.0).abs() < 1e-5);
    }

    #[test]
    fn build_from_repo_indexes_only_matching_model() {
        let (bs, ohs) = stores();
        let repo = ReadonlyRepo::init(bs, ohs).unwrap();
        let mut tx = repo.start_transaction();

        let mut add = |id: [u8; 16], model: &str, v: &[f32]| {
            let node = Node::new(NodeId::from_bytes_raw(id), "Doc");
            let cid = tx.add_node(&node).unwrap();
            let emb = f32_embed(model, v);
            tx.set_embedding(cid, emb.model.clone(), emb).unwrap();
        };
        add([1u8; 16], "mA", &[1.0, 0.0]);
        add([2u8; 16], "mA", &[0.0, 1.0]);
        add([3u8; 16], "mB", &[1.0, 0.0]);
        // A node without any embedding must simply be skipped.
        tx.add_node(&Node::new(NodeId::from_bytes_raw([4u8; 16]), "Doc"))
            .unwrap();
        let repo = tx.commit("t", "seed").unwrap();

        let idx = BruteForceVectorIndex::build_from_repo(&repo, "mA").unwrap();
        assert_eq!(idx.len(), 2);
        assert_eq!(idx.dim(), 2);
        assert_eq!(idx.model(), "mA");

        let hits = idx.search(&[1.0, 0.0], 2).unwrap();
        assert_eq!(hits[0].node_id, NodeId::from_bytes_raw([1u8; 16]));
    }

    #[test]
    fn build_for_absent_model_returns_empty_index() {
        let (bs, ohs) = stores();
        let repo = ReadonlyRepo::init(bs, ohs).unwrap();
        let mut tx = repo.start_transaction();
        let cid = tx
            .add_node(&Node::new(NodeId::from_bytes_raw([1u8; 16]), "Doc"))
            .unwrap();
        let emb = f32_embed("mA", &[1.0, 0.0]);
        tx.set_embedding(cid, emb.model.clone(), emb).unwrap();
        let repo = tx.commit("t", "seed").unwrap();

        let idx = BruteForceVectorIndex::build_from_repo(&repo, "unknown").unwrap();
        assert!(idx.is_empty());
        assert_eq!(idx.model(), "unknown");
    }

    #[test]
    fn build_on_empty_repo_errors() {
        let (bs, ohs) = stores();
        let repo = ReadonlyRepo::init(bs, ohs).unwrap();
        let err = BruteForceVectorIndex::build_from_repo(&repo, "mA").unwrap_err();
        match err {
            Error::Repo(RepoError::Uninitialized) => {}
            e => panic!("expected Uninitialized, got {e:?}"),
        }
    }

    #[test]
    fn determinism_same_repo_same_results() {
        // Build two independent but identically seeded repositories.
        let build = || {
            let (bs, ohs) = stores();
            let repo = ReadonlyRepo::init(bs, ohs).unwrap();
            let mut tx = repo.start_transaction();
            for i in 0..5u8 {
                let cid = tx
                    .add_node(&Node::new(NodeId::from_bytes_raw([i; 16]), "Doc"))
                    .unwrap();
                let emb = f32_embed("m", &[f32::from(i), 1.0]);
                tx.set_embedding(cid, emb.model.clone(), emb).unwrap();
            }
            let repo = tx.commit("t", "seed").unwrap();
            let idx = BruteForceVectorIndex::build_from_repo(&repo, "m").unwrap();
            idx.search(&[2.0, 1.0], 3).unwrap()
        };
        let a = build();
        let b = build();
        assert_eq!(a, b, "same inputs -> byte-identical hit list");
    }

    #[test]
    fn index_reads_embedding_from_sidecar() {
        let (bs, ohs) = stores();
        let repo = ReadonlyRepo::init(bs, ohs).unwrap();
        let mut tx = repo.start_transaction();

        let node = Node::new(NodeId::from_bytes_raw([1u8; 16]), "Doc");
        // The embedding is attached through `set_embedding`, not through the
        // node value itself.
        let node_cid = tx.add_node(&node).unwrap();
        let emb = f32_embed("mA", &[1.0, 0.0, 0.0]);
        tx.set_embedding(node_cid, "mA".into(), emb).unwrap();
        let repo = tx.commit("t", "seed via sidecar").unwrap();

        let idx = BruteForceVectorIndex::build_from_repo(&repo, "mA").unwrap();
        assert_eq!(idx.len(), 1, "sidecar embedding must surface in the index");
        assert_eq!(idx.dim(), 3);
        let hits = idx.search(&[1.0, 0.0, 0.0], 1).unwrap();
        assert_eq!(hits[0].node_id, NodeId::from_bytes_raw([1u8; 16]));
        assert!((hits[0].score - 1.0).abs() < 1e-5);
    }
}