nodedb_vector/multivec/
storage.rs1use std::collections::HashMap;
11
/// A document represented by several embedding vectors.
///
/// [`MultiVectorStore::insert`] requires every vector in `vectors` to have
/// the store's `dim` length, and in `MetaToken` mode also constrains how many
/// vectors the document may carry.
#[derive(Debug, Clone, PartialEq)]
pub struct MultiVectorDoc {
    /// Identifier of the document; used as the lookup key inside the store.
    pub doc_id: u32,
    /// The document's vectors, one inner `Vec<f32>` per vector.
    pub vectors: Vec<Vec<f32>>,
}
19
/// How many vectors a [`MultiVectorStore`] accepts per document.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum MultiVecMode {
    /// No constraint on the number of vectors per document.
    ///
    /// NOTE(review): presumably one vector per input token — confirm against
    /// the producer of the documents.
    PerToken,
    /// Every document must carry exactly `k` vectors; the store rejects any
    /// other count on insert.
    MetaToken {
        /// Required number of vectors per document.
        k: u8,
    },
}
29
30pub struct MultiVectorStore {
32 pub dim: usize,
33 pub mode: MultiVecMode,
34 docs: HashMap<u32, MultiVectorDoc>,
35}
36
37#[derive(Debug, thiserror::Error)]
39#[non_exhaustive]
40pub enum MultivecError {
41 #[error("dim mismatch: expected {expected}, got {actual}")]
42 DimMismatch { expected: usize, actual: usize },
43 #[error("meta-token count mismatch: expected k={expected}, got {actual}")]
44 MetaTokenCountMismatch { expected: u8, actual: usize },
45}
46
47impl MultiVectorStore {
48 pub fn new(dim: usize, mode: MultiVecMode) -> Self {
50 Self {
51 dim,
52 mode,
53 docs: HashMap::new(),
54 }
55 }
56
57 pub fn insert(&mut self, doc: MultiVectorDoc) -> Result<(), MultivecError> {
59 for v in &doc.vectors {
61 if v.len() != self.dim {
62 return Err(MultivecError::DimMismatch {
63 expected: self.dim,
64 actual: v.len(),
65 });
66 }
67 }
68
69 if let MultiVecMode::MetaToken { k } = self.mode
71 && doc.vectors.len() != k as usize
72 {
73 return Err(MultivecError::MetaTokenCountMismatch {
74 expected: k,
75 actual: doc.vectors.len(),
76 });
77 }
78
79 self.docs.insert(doc.doc_id, doc);
80 Ok(())
81 }
82
83 pub fn get(&self, doc_id: u32) -> Option<&MultiVectorDoc> {
85 self.docs.get(&doc_id)
86 }
87
88 pub fn len(&self) -> usize {
90 self.docs.len()
91 }
92
93 pub fn is_empty(&self) -> bool {
95 self.docs.is_empty()
96 }
97
98 pub fn iter(&self) -> impl Iterator<Item = &MultiVectorDoc> {
100 self.docs.values()
101 }
102
103 pub fn k(&self) -> Option<u8> {
105 match self.mode {
106 MultiVecMode::MetaToken { k } => Some(k),
107 MultiVecMode::PerToken => None,
108 }
109 }
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a document holding `n_vecs` zero vectors of length `dim`.
    fn make_doc(doc_id: u32, n_vecs: usize, dim: usize) -> MultiVectorDoc {
        MultiVectorDoc {
            doc_id,
            vectors: (0..n_vecs).map(|_| vec![0.0f32; dim]).collect(),
        }
    }

    #[test]
    fn insert_per_token_valid() {
        let mut store = MultiVectorStore::new(4, MultiVecMode::PerToken);
        let doc = make_doc(1, 5, 4);
        assert!(store.insert(doc).is_ok());
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn insert_dim_mismatch() {
        let mut store = MultiVectorStore::new(4, MultiVecMode::PerToken);
        // One vector of length 3 in a store configured for length 4.
        let doc = MultiVectorDoc {
            doc_id: 2,
            vectors: vec![vec![0.0f32; 3]],
        };
        let err = store.insert(doc).unwrap_err();
        assert!(matches!(
            err,
            MultivecError::DimMismatch {
                expected: 4,
                actual: 3
            }
        ));
    }

    #[test]
    fn insert_meta_token_valid() {
        let mut store = MultiVectorStore::new(8, MultiVecMode::MetaToken { k: 4 });
        let doc = make_doc(10, 4, 8);
        assert!(store.insert(doc).is_ok());
        assert_eq!(store.k(), Some(4));
    }

    #[test]
    fn insert_meta_token_count_mismatch() {
        let mut store = MultiVectorStore::new(8, MultiVecMode::MetaToken { k: 4 });
        // Three vectors where the store requires exactly four.
        let doc = make_doc(10, 3, 8);
        let err = store.insert(doc).unwrap_err();
        assert!(matches!(
            err,
            MultivecError::MetaTokenCountMismatch {
                expected: 4,
                actual: 3
            }
        ));
    }

    #[test]
    fn insert_checks_dim_before_count() {
        let mut store = MultiVectorStore::new(8, MultiVecMode::MetaToken { k: 4 });
        // Wrong on both axes: the dimension error is the one reported.
        let err = store.insert(make_doc(10, 3, 7)).unwrap_err();
        assert!(matches!(
            err,
            MultivecError::DimMismatch {
                expected: 8,
                actual: 7
            }
        ));
    }

    #[test]
    fn rejected_insert_leaves_store_empty() {
        let mut store = MultiVectorStore::new(4, MultiVecMode::PerToken);
        assert!(store.insert(make_doc(1, 1, 3)).is_err());
        assert!(store.is_empty());
        assert_eq!(store.len(), 0);
    }

    #[test]
    fn insert_same_id_replaces_doc() {
        let mut store = MultiVectorStore::new(2, MultiVecMode::PerToken);
        store.insert(make_doc(42, 1, 2)).unwrap();
        store.insert(make_doc(42, 3, 2)).unwrap();
        assert_eq!(store.len(), 1);
        let doc = store.get(42).expect("doc should be present");
        assert_eq!(doc.vectors.len(), 3);
    }

    #[test]
    fn get_returns_inserted_doc() {
        let mut store = MultiVectorStore::new(2, MultiVecMode::PerToken);
        store.insert(make_doc(42, 2, 2)).unwrap();
        let doc = store.get(42).expect("doc should be present");
        assert_eq!(doc.doc_id, 42);
    }

    #[test]
    fn get_missing_returns_none() {
        let store = MultiVectorStore::new(2, MultiVecMode::PerToken);
        assert!(store.get(7).is_none());
    }

    #[test]
    fn is_empty_tracks_inserts() {
        let mut store = MultiVectorStore::new(2, MultiVecMode::PerToken);
        assert!(store.is_empty());
        store.insert(make_doc(1, 1, 2)).unwrap();
        assert!(!store.is_empty());
    }

    #[test]
    fn iter_yields_all_docs() {
        let mut store = MultiVectorStore::new(2, MultiVecMode::PerToken);
        for id in 0..5u32 {
            store.insert(make_doc(id, 1, 2)).unwrap();
        }
        assert_eq!(store.iter().count(), 5);
    }

    #[test]
    fn k_none_for_per_token() {
        let store = MultiVectorStore::new(4, MultiVecMode::PerToken);
        assert_eq!(store.k(), None);
    }
}