Skip to main content

nodedb_vector/multivec/
storage.rs

// SPDX-License-Identifier: Apache-2.0

//! Variable-length multi-vector document storage with Meta Token mode.
//!
//! `MultiVectorStore` holds a collection of `MultiVectorDoc` entries where
//! each document carries one or more embedding vectors.  In `MetaToken` mode
//! every document has exactly `k` vectors (MetaEmbed learnable summary tokens);
//! in `PerToken` mode the count is unconstrained (naive ColBERT).
9
10use std::collections::HashMap;
11
/// A document represented by one or more embedding vectors.
///
/// Equality compares `doc_id` and every vector component with `f32`
/// semantics, so a document containing `NaN` never compares equal to
/// anything, including a clone of itself.
#[derive(Debug, Clone, PartialEq)]
pub struct MultiVectorDoc {
    /// Identifier under which the document is stored and looked up.
    pub doc_id: u32,
    /// Variable-length list of vectors.  For Meta Token mode this has fixed length K.
    pub vectors: Vec<Vec<f32>>,
}
19
/// Selects the per-document vector-count policy of a store.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must keep a
/// wildcard arm when matching on it.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MultiVecMode {
    /// One vector per token (naive ColBERT, expensive); the number of
    /// vectors per document is not constrained.
    PerToken,
    /// MetaEmbed-style Meta Tokens: each document carries `k` learnable
    /// summary vectors.
    MetaToken {
        /// Number of summary vectors per document.
        k: u8,
    },
}
29
30/// In-memory store for multi-vector documents.
31pub struct MultiVectorStore {
32    pub dim: usize,
33    pub mode: MultiVecMode,
34    docs: HashMap<u32, MultiVectorDoc>,
35}
36
37/// Errors produced by `MultiVectorStore`.
38#[derive(Debug, thiserror::Error)]
39#[non_exhaustive]
40pub enum MultivecError {
41    #[error("dim mismatch: expected {expected}, got {actual}")]
42    DimMismatch { expected: usize, actual: usize },
43    #[error("meta-token count mismatch: expected k={expected}, got {actual}")]
44    MetaTokenCountMismatch { expected: u8, actual: usize },
45}
46
47impl MultiVectorStore {
48    /// Create a new store with the given embedding dimension and mode.
49    pub fn new(dim: usize, mode: MultiVecMode) -> Self {
50        Self {
51            dim,
52            mode,
53            docs: HashMap::new(),
54        }
55    }
56
57    /// Insert a document, validating dimensions and (for MetaToken mode) vector count.
58    pub fn insert(&mut self, doc: MultiVectorDoc) -> Result<(), MultivecError> {
59        // Validate each vector's dimension.
60        for v in &doc.vectors {
61            if v.len() != self.dim {
62                return Err(MultivecError::DimMismatch {
63                    expected: self.dim,
64                    actual: v.len(),
65                });
66            }
67        }
68
69        // In MetaToken mode the count must equal k exactly.
70        if let MultiVecMode::MetaToken { k } = self.mode
71            && doc.vectors.len() != k as usize
72        {
73            return Err(MultivecError::MetaTokenCountMismatch {
74                expected: k,
75                actual: doc.vectors.len(),
76            });
77        }
78
79        self.docs.insert(doc.doc_id, doc);
80        Ok(())
81    }
82
83    /// Look up a document by ID.
84    pub fn get(&self, doc_id: u32) -> Option<&MultiVectorDoc> {
85        self.docs.get(&doc_id)
86    }
87
88    /// Number of documents currently stored.
89    pub fn len(&self) -> usize {
90        self.docs.len()
91    }
92
93    /// Returns `true` if the store holds no documents.
94    pub fn is_empty(&self) -> bool {
95        self.docs.is_empty()
96    }
97
98    /// Iterate over all stored documents (order unspecified).
99    pub fn iter(&self) -> impl Iterator<Item = &MultiVectorDoc> {
100        self.docs.values()
101    }
102
103    /// Returns `Some(k)` for `MetaToken` mode; `None` for `PerToken`.
104    pub fn k(&self) -> Option<u8> {
105        match self.mode {
106            MultiVecMode::MetaToken { k } => Some(k),
107            MultiVecMode::PerToken => None,
108        }
109    }
110}
111
112// ---------------------------------------------------------------------------
113// Tests
114// ---------------------------------------------------------------------------
115
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a document holding `n_vecs` all-zero vectors of length `dim`.
    fn make_doc(doc_id: u32, n_vecs: usize, dim: usize) -> MultiVectorDoc {
        let vectors = vec![vec![0.0f32; dim]; n_vecs];
        MultiVectorDoc { doc_id, vectors }
    }

    /// A `PerToken` store accepts any number of correctly sized vectors.
    #[test]
    fn insert_per_token_valid() {
        let mut store = MultiVectorStore::new(4, MultiVecMode::PerToken);
        let outcome = store.insert(make_doc(1, 5, 4));
        assert!(outcome.is_ok());
        assert_eq!(store.len(), 1);
    }

    /// A vector of the wrong length is rejected with both sizes reported.
    #[test]
    fn insert_dim_mismatch() {
        let mut store = MultiVectorStore::new(4, MultiVecMode::PerToken);
        // One vector of length 3 in a store configured for length 4.
        let wrong_dim = make_doc(2, 1, 3);
        match store.insert(wrong_dim).unwrap_err() {
            MultivecError::DimMismatch { expected, actual } => {
                assert_eq!(expected, 4);
                assert_eq!(actual, 3);
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    /// A `MetaToken` store accepts a document with exactly `k` vectors.
    #[test]
    fn insert_meta_token_valid() {
        let mut store = MultiVectorStore::new(8, MultiVecMode::MetaToken { k: 4 });
        let outcome = store.insert(make_doc(10, 4, 8));
        assert!(outcome.is_ok());
        assert_eq!(store.k(), Some(4));
    }

    /// A `MetaToken` store rejects a document whose vector count is not `k`.
    #[test]
    fn insert_meta_token_count_mismatch() {
        let mut store = MultiVectorStore::new(8, MultiVecMode::MetaToken { k: 4 });
        // Three vectors supplied where k = 4 is required.
        let too_few = make_doc(10, 3, 8);
        match store.insert(too_few).unwrap_err() {
            MultivecError::MetaTokenCountMismatch { expected, actual } => {
                assert_eq!(expected, 4);
                assert_eq!(actual, 3);
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    /// `get` finds a document by the id it was inserted under.
    #[test]
    fn get_returns_inserted_doc() {
        let mut store = MultiVectorStore::new(2, MultiVecMode::PerToken);
        store.insert(make_doc(42, 2, 2)).unwrap();
        let found = store.get(42).expect("doc should be present");
        assert_eq!(found.doc_id, 42);
    }

    /// `iter` visits every stored document exactly once.
    #[test]
    fn iter_yields_all_docs() {
        let mut store = MultiVectorStore::new(2, MultiVecMode::PerToken);
        (0..5u32).for_each(|id| store.insert(make_doc(id, 1, 2)).unwrap());
        assert_eq!(store.iter().count(), 5);
    }

    /// `k` is `None` when the store is in `PerToken` mode.
    #[test]
    fn k_none_for_per_token() {
        let store = MultiVectorStore::new(4, MultiVecMode::PerToken);
        assert!(store.k().is_none());
    }
}