1use std::ops::Bound;
4
5use manifoldb_core::EntityId;
6use manifoldb_storage::{Cursor, StorageEngine, Transaction};
7
8use crate::encoding::{
9 decode_embedding_entity_id, encode_embedding_key, encode_embedding_prefix,
10 PREFIX_MULTI_VECTOR_SPACE,
11};
12use crate::error::VectorError;
13use crate::types::{EmbeddingName, MultiVectorEmbedding, MultiVectorEmbeddingSpace};
14
/// Table holding serialized [`MultiVectorEmbeddingSpace`] definitions.
///
/// Keys are built by `encode_multi_vector_space_key`: the
/// `PREFIX_MULTI_VECTOR_SPACE` tag byte followed by the space name's UTF-8 bytes.
const TABLE_MULTI_VECTOR_SPACES: &str = "multi_vector_spaces";

/// Table holding per-entity multi-vector embeddings.
///
/// Keys are built by `encode_embedding_key(space, entity)`. Each value is a
/// 4-byte little-endian dimension header followed by the bytes produced by
/// `MultiVectorEmbedding::to_bytes()`.
const TABLE_MULTI_VECTOR_EMBEDDINGS: &str = "multi_vector_embeddings";
20
/// Persists multi-vector embedding spaces and per-entity multi-vector
/// embeddings on top of a [`StorageEngine`].
///
/// Space definitions and embeddings live in two separate tables
/// (`multi_vector_spaces` and `multi_vector_embeddings`).
pub struct MultiVectorStore<E: StorageEngine> {
    /// Underlying key-value engine; each public method opens its own
    /// transaction(s) on it.
    engine: E,
}
29
30impl<E: StorageEngine> MultiVectorStore<E> {
31 #[must_use]
33 pub const fn new(engine: E) -> Self {
34 Self { engine }
35 }
36
37 pub fn create_space(&self, space: &MultiVectorEmbeddingSpace) -> Result<(), VectorError> {
43 let mut tx = self.engine.begin_write()?;
44
45 let key = encode_multi_vector_space_key(space.name());
46
47 if tx.get(TABLE_MULTI_VECTOR_SPACES, &key)?.is_some() {
49 return Err(VectorError::InvalidName(format!(
50 "multi-vector embedding space '{}' already exists",
51 space.name()
52 )));
53 }
54
55 tx.put(TABLE_MULTI_VECTOR_SPACES, &key, &space.to_bytes()?)?;
57 tx.commit()?;
58
59 Ok(())
60 }
61
62 pub fn get_space(
68 &self,
69 name: &EmbeddingName,
70 ) -> Result<MultiVectorEmbeddingSpace, VectorError> {
71 let tx = self.engine.begin_read()?;
72 let key = encode_multi_vector_space_key(name);
73
74 let bytes = tx
75 .get(TABLE_MULTI_VECTOR_SPACES, &key)?
76 .ok_or_else(|| VectorError::SpaceNotFound(name.to_string()))?;
77
78 MultiVectorEmbeddingSpace::from_bytes(&bytes)
79 }
80
81 pub fn delete_space(&self, name: &EmbeddingName) -> Result<(), VectorError> {
87 let mut tx = self.engine.begin_write()?;
88
89 let space_key = encode_multi_vector_space_key(name);
90
91 if tx.get(TABLE_MULTI_VECTOR_SPACES, &space_key)?.is_none() {
93 return Err(VectorError::SpaceNotFound(name.to_string()));
94 }
95
96 let prefix = encode_embedding_prefix(name);
98 let prefix_end = next_prefix(&prefix);
99
100 let mut keys_to_delete = Vec::new();
101 {
102 let cursor = tx.range(
103 TABLE_MULTI_VECTOR_EMBEDDINGS,
104 Bound::Included(prefix.as_slice()),
105 Bound::Excluded(prefix_end.as_slice()),
106 )?;
107
108 let mut cursor = cursor;
109 while let Some((key, _)) = cursor.next()? {
110 keys_to_delete.push(key);
111 }
112 }
113
114 for key in keys_to_delete {
115 tx.delete(TABLE_MULTI_VECTOR_EMBEDDINGS, &key)?;
116 }
117
118 tx.delete(TABLE_MULTI_VECTOR_SPACES, &space_key)?;
120
121 tx.commit()?;
122 Ok(())
123 }
124
125 pub fn list_spaces(&self) -> Result<Vec<MultiVectorEmbeddingSpace>, VectorError> {
131 let tx = self.engine.begin_read()?;
132
133 let prefix = vec![PREFIX_MULTI_VECTOR_SPACE];
134 let prefix_end = next_prefix(&prefix);
135
136 let mut cursor = tx.range(
137 TABLE_MULTI_VECTOR_SPACES,
138 Bound::Included(prefix.as_slice()),
139 Bound::Excluded(prefix_end.as_slice()),
140 )?;
141
142 let mut spaces = Vec::new();
143 while let Some((_, value)) = cursor.next()? {
144 spaces.push(MultiVectorEmbeddingSpace::from_bytes(&value)?);
145 }
146
147 Ok(spaces)
148 }
149
150 pub fn put(
159 &self,
160 entity_id: EntityId,
161 space_name: &EmbeddingName,
162 embedding: &MultiVectorEmbedding,
163 ) -> Result<(), VectorError> {
164 let space = self.get_space(space_name)?;
166
167 if embedding.dimension() != space.dimension() {
168 return Err(VectorError::DimensionMismatch {
169 expected: space.dimension(),
170 actual: embedding.dimension(),
171 });
172 }
173
174 let mut tx = self.engine.begin_write()?;
175
176 let key = encode_embedding_key(space_name, entity_id);
177
178 let mut bytes = Vec::new();
180 let dim = embedding.dimension() as u32;
181 bytes.extend_from_slice(&dim.to_le_bytes());
182 bytes.extend_from_slice(&embedding.to_bytes());
183
184 tx.put(TABLE_MULTI_VECTOR_EMBEDDINGS, &key, &bytes)?;
185
186 tx.commit()?;
187 Ok(())
188 }
189
190 pub fn get(
199 &self,
200 entity_id: EntityId,
201 space_name: &EmbeddingName,
202 ) -> Result<MultiVectorEmbedding, VectorError> {
203 let _ = self.get_space(space_name)?;
205
206 let tx = self.engine.begin_read()?;
207
208 let key = encode_embedding_key(space_name, entity_id);
209 let bytes = tx.get(TABLE_MULTI_VECTOR_EMBEDDINGS, &key)?.ok_or_else(|| {
210 VectorError::EmbeddingNotFound {
211 entity_id: entity_id.as_u64(),
212 space: space_name.to_string(),
213 }
214 })?;
215
216 if bytes.len() < 4 {
218 return Err(VectorError::Encoding("multi-vector data too short".to_string()));
219 }
220
221 let dim_bytes: [u8; 4] = bytes[..4]
222 .try_into()
223 .map_err(|_| VectorError::Encoding("failed to read dimension".to_string()))?;
224 let dimension = u32::from_le_bytes(dim_bytes) as usize;
225
226 MultiVectorEmbedding::from_bytes(&bytes[4..], dimension)
227 }
228
229 pub fn delete(
239 &self,
240 entity_id: EntityId,
241 space_name: &EmbeddingName,
242 ) -> Result<bool, VectorError> {
243 let mut tx = self.engine.begin_write()?;
244
245 let key = encode_embedding_key(space_name, entity_id);
246 let existed = tx.delete(TABLE_MULTI_VECTOR_EMBEDDINGS, &key)?;
247
248 tx.commit()?;
249 Ok(existed)
250 }
251
252 pub fn exists(
258 &self,
259 entity_id: EntityId,
260 space_name: &EmbeddingName,
261 ) -> Result<bool, VectorError> {
262 let tx = self.engine.begin_read()?;
263 let key = encode_embedding_key(space_name, entity_id);
264 Ok(tx.get(TABLE_MULTI_VECTOR_EMBEDDINGS, &key)?.is_some())
265 }
266
267 pub fn list_entities(&self, space_name: &EmbeddingName) -> Result<Vec<EntityId>, VectorError> {
273 let tx = self.engine.begin_read()?;
274
275 let prefix = encode_embedding_prefix(space_name);
276 let prefix_end = next_prefix(&prefix);
277
278 let mut cursor = tx.range(
279 TABLE_MULTI_VECTOR_EMBEDDINGS,
280 Bound::Included(prefix.as_slice()),
281 Bound::Excluded(prefix_end.as_slice()),
282 )?;
283
284 let mut entities = Vec::new();
285 while let Some((key, _)) = cursor.next()? {
286 if let Some(entity_id) = decode_embedding_entity_id(&key) {
287 entities.push(entity_id);
288 }
289 }
290
291 Ok(entities)
292 }
293
294 pub fn count(&self, space_name: &EmbeddingName) -> Result<usize, VectorError> {
300 let tx = self.engine.begin_read()?;
301
302 let prefix = encode_embedding_prefix(space_name);
303 let prefix_end = next_prefix(&prefix);
304
305 let mut cursor = tx.range(
306 TABLE_MULTI_VECTOR_EMBEDDINGS,
307 Bound::Included(prefix.as_slice()),
308 Bound::Excluded(prefix_end.as_slice()),
309 )?;
310
311 let mut count = 0;
312 while cursor.next()?.is_some() {
313 count += 1;
314 }
315
316 Ok(count)
317 }
318
319 pub fn get_many(
328 &self,
329 entity_ids: &[EntityId],
330 space_name: &EmbeddingName,
331 ) -> Result<Vec<(EntityId, Option<MultiVectorEmbedding>)>, VectorError> {
332 let space = self.get_space(space_name)?;
334 let dimension = space.dimension();
335
336 let tx = self.engine.begin_read()?;
337
338 let mut results = Vec::with_capacity(entity_ids.len());
339
340 for &entity_id in entity_ids {
341 let key = encode_embedding_key(space_name, entity_id);
342 let embedding = tx
343 .get(TABLE_MULTI_VECTOR_EMBEDDINGS, &key)?
344 .map(|bytes| {
345 if bytes.len() < 4 {
346 return Err(VectorError::Encoding(
347 "multi-vector data too short".to_string(),
348 ));
349 }
350 MultiVectorEmbedding::from_bytes(&bytes[4..], dimension)
351 })
352 .transpose()?;
353
354 results.push((entity_id, embedding));
355 }
356
357 Ok(results)
358 }
359
360 pub fn put_many(
371 &self,
372 embeddings: &[(EntityId, MultiVectorEmbedding)],
373 space_name: &EmbeddingName,
374 ) -> Result<(), VectorError> {
375 if embeddings.is_empty() {
376 return Ok(());
377 }
378
379 let space = self.get_space(space_name)?;
381
382 for (entity_id, embedding) in embeddings {
384 if embedding.dimension() != space.dimension() {
385 return Err(VectorError::DimensionMismatch {
386 expected: space.dimension(),
387 actual: embedding.dimension(),
388 });
389 }
390 let _ = entity_id;
391 }
392
393 let mut tx = self.engine.begin_write()?;
394
395 for (entity_id, embedding) in embeddings {
396 let key = encode_embedding_key(space_name, *entity_id);
397
398 let mut bytes = Vec::new();
399 let dim = embedding.dimension() as u32;
400 bytes.extend_from_slice(&dim.to_le_bytes());
401 bytes.extend_from_slice(&embedding.to_bytes());
402
403 tx.put(TABLE_MULTI_VECTOR_EMBEDDINGS, &key, &bytes)?;
404 }
405
406 tx.commit()?;
407 Ok(())
408 }
409}
410
411fn encode_multi_vector_space_key(name: &EmbeddingName) -> Vec<u8> {
413 let name_bytes = name.as_str().as_bytes();
414 let mut key = Vec::with_capacity(1 + name_bytes.len());
415 key.push(PREFIX_MULTI_VECTOR_SPACE);
416 key.extend_from_slice(name_bytes);
417 key
418}
419
/// Returns the exclusive upper bound for a range scan over every key that
/// starts with `prefix`.
///
/// This is the smallest byte string that sorts after all such keys. Trailing
/// `0xFF` bytes cannot be incremented, so they are dropped and the last
/// remaining byte is incremented: `[0x01, 0xFF]` becomes `[0x02]`.
/// Incrementing in place without truncating would give `[0x02, 0xFF]`, which
/// wrongly admits unrelated keys such as `[0x02, 0x00]` into the range.
///
/// If `prefix` is empty or consists only of `0xFF` bytes, no finite upper
/// bound exists. For compatibility the result is then `prefix` followed by one
/// `0xFF`. That bound does NOT cover keys that extend `prefix` with a suffix
/// starting with `0xFF`.
fn next_prefix(prefix: &[u8]) -> Vec<u8> {
    match prefix.iter().rposition(|&byte| byte != 0xFF) {
        Some(last) => {
            let mut upper = prefix[..=last].to_vec();
            upper[last] += 1;
            upper
        }
        None => {
            let mut upper = prefix.to_vec();
            upper.push(0xFF);
            upper
        }
    }
}
434
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distance::DistanceMetric;
    use manifoldb_storage::backends::RedbEngine;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Source of unique space names so tests never collide on a name.
    static TEST_COUNTER: AtomicUsize = AtomicUsize::new(0);

    fn create_test_store() -> MultiVectorStore<RedbEngine> {
        let engine = RedbEngine::in_memory().unwrap();
        MultiVectorStore::new(engine)
    }

    fn unique_space_name() -> EmbeddingName {
        let count = TEST_COUNTER.fetch_add(1, Ordering::SeqCst);
        EmbeddingName::new(format!("multi_test_space_{}", count)).unwrap()
    }

    /// Creates a fresh store with one registered space of the given dimension.
    fn store_with_space(dimension: usize) -> (MultiVectorStore<RedbEngine>, EmbeddingName) {
        let store = create_test_store();
        let name = unique_space_name();
        let space =
            MultiVectorEmbeddingSpace::new(name.clone(), dimension, DistanceMetric::DotProduct);
        store.create_space(&space).unwrap();
        (store, name)
    }

    #[test]
    fn create_and_get_space() {
        let store = create_test_store();
        let name = unique_space_name();
        let space = MultiVectorEmbeddingSpace::new(name.clone(), 128, DistanceMetric::DotProduct);

        store.create_space(&space).unwrap();

        let retrieved = store.get_space(&name).unwrap();
        assert_eq!(retrieved.dimension(), 128);
        assert_eq!(retrieved.distance_metric(), DistanceMetric::DotProduct);
    }

    #[test]
    fn create_duplicate_space_fails() {
        let store = create_test_store();
        let name = unique_space_name();
        let space = MultiVectorEmbeddingSpace::new(name.clone(), 128, DistanceMetric::DotProduct);

        store.create_space(&space).unwrap();
        let result = store.create_space(&space);

        assert!(result.is_err());
    }

    #[test]
    fn put_and_get_multi_vector() {
        let store = create_test_store();
        let name = unique_space_name();
        let space = MultiVectorEmbeddingSpace::new(name.clone(), 3, DistanceMetric::DotProduct);
        store.create_space(&space).unwrap();

        let embedding = MultiVectorEmbedding::new(vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            vec![7.0, 8.0, 9.0],
        ])
        .unwrap();

        store.put(EntityId::new(42), &name, &embedding).unwrap();

        let retrieved = store.get(EntityId::new(42), &name).unwrap();
        assert_eq!(retrieved.num_vectors(), 3);
        assert_eq!(retrieved.dimension(), 3);
        assert_eq!(retrieved.get_vector(0), Some([1.0, 2.0, 3.0].as_slice()));
        assert_eq!(retrieved.get_vector(1), Some([4.0, 5.0, 6.0].as_slice()));
        assert_eq!(retrieved.get_vector(2), Some([7.0, 8.0, 9.0].as_slice()));
    }

    #[test]
    fn put_wrong_dimension_fails() {
        let store = create_test_store();
        let name = unique_space_name();
        let space = MultiVectorEmbeddingSpace::new(name.clone(), 128, DistanceMetric::DotProduct);
        store.create_space(&space).unwrap();

        let embedding = MultiVectorEmbedding::new(vec![vec![1.0, 2.0, 3.0]]).unwrap();

        let result = store.put(EntityId::new(1), &name, &embedding);
        assert!(result.is_err());
        match result.unwrap_err() {
            VectorError::DimensionMismatch { expected, actual } => {
                assert_eq!(expected, 128);
                assert_eq!(actual, 3);
            }
            _ => panic!("unexpected error type"),
        }
    }

    #[test]
    fn put_into_missing_space_fails() {
        let store = create_test_store();
        let missing = unique_space_name();
        let embedding = MultiVectorEmbedding::new(vec![vec![1.0, 2.0, 3.0]]).unwrap();

        let result = store.put(EntityId::new(1), &missing, &embedding);
        assert!(matches!(result, Err(VectorError::SpaceNotFound(_))));
    }

    #[test]
    fn get_reports_missing_space_and_missing_entity() {
        let (store, name) = store_with_space(3);
        let missing_space = unique_space_name();

        assert!(matches!(
            store.get(EntityId::new(1), &missing_space),
            Err(VectorError::SpaceNotFound(_))
        ));
        assert!(matches!(
            store.get(EntityId::new(99), &name),
            Err(VectorError::EmbeddingNotFound { .. })
        ));
    }

    #[test]
    fn delete_multi_vector() {
        let store = create_test_store();
        let name = unique_space_name();
        let space = MultiVectorEmbeddingSpace::new(name.clone(), 3, DistanceMetric::DotProduct);
        store.create_space(&space).unwrap();

        let embedding = MultiVectorEmbedding::new(vec![vec![1.0, 2.0, 3.0]]).unwrap();
        store.put(EntityId::new(1), &name, &embedding).unwrap();

        assert!(store.exists(EntityId::new(1), &name).unwrap());
        assert!(store.delete(EntityId::new(1), &name).unwrap());
        assert!(!store.exists(EntityId::new(1), &name).unwrap());
    }

    #[test]
    fn list_entities() {
        let store = create_test_store();
        let name = unique_space_name();
        let space = MultiVectorEmbeddingSpace::new(name.clone(), 3, DistanceMetric::DotProduct);
        store.create_space(&space).unwrap();

        for i in 1..=5 {
            let embedding =
                MultiVectorEmbedding::new(vec![vec![i as f32, i as f32, i as f32]]).unwrap();
            store.put(EntityId::new(i), &name, &embedding).unwrap();
        }

        let entities = store.list_entities(&name).unwrap();
        assert_eq!(entities.len(), 5);
    }

    #[test]
    fn count_embeddings() {
        let store = create_test_store();
        let name = unique_space_name();
        let space = MultiVectorEmbeddingSpace::new(name.clone(), 3, DistanceMetric::DotProduct);
        store.create_space(&space).unwrap();

        assert_eq!(store.count(&name).unwrap(), 0);

        for i in 1..=10 {
            let embedding =
                MultiVectorEmbedding::new(vec![vec![i as f32, i as f32, i as f32]]).unwrap();
            store.put(EntityId::new(i), &name, &embedding).unwrap();
        }

        assert_eq!(store.count(&name).unwrap(), 10);
    }

    #[test]
    fn variable_token_count() {
        let store = create_test_store();
        let name = unique_space_name();
        let space = MultiVectorEmbeddingSpace::new(name.clone(), 4, DistanceMetric::DotProduct);
        store.create_space(&space).unwrap();

        let doc1 = MultiVectorEmbedding::new(vec![vec![1.0, 0.0, 0.0, 0.0]]).unwrap();

        let doc2 =
            MultiVectorEmbedding::new(vec![vec![1.0, 0.0, 0.0, 0.0], vec![0.0, 1.0, 0.0, 0.0]])
                .unwrap();

        let doc3 = MultiVectorEmbedding::new(vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
            vec![0.0, 0.0, 0.0, 1.0],
        ])
        .unwrap();

        store.put(EntityId::new(1), &name, &doc1).unwrap();
        store.put(EntityId::new(2), &name, &doc2).unwrap();
        store.put(EntityId::new(3), &name, &doc3).unwrap();

        let retrieved1 = store.get(EntityId::new(1), &name).unwrap();
        let retrieved2 = store.get(EntityId::new(2), &name).unwrap();
        let retrieved3 = store.get(EntityId::new(3), &name).unwrap();

        assert_eq!(retrieved1.num_vectors(), 1);
        assert_eq!(retrieved2.num_vectors(), 2);
        assert_eq!(retrieved3.num_vectors(), 4);
    }

    #[test]
    fn get_many_preserves_order_and_reports_missing() {
        let (store, name) = store_with_space(2);
        let embedding = MultiVectorEmbedding::new(vec![vec![1.0, 2.0], vec![3.0, 4.0]]).unwrap();
        store.put(EntityId::new(1), &name, &embedding).unwrap();

        let results = store.get_many(&[EntityId::new(1), EntityId::new(2)], &name).unwrap();
        assert_eq!(results.len(), 2);

        assert_eq!(results[0].0.as_u64(), 1);
        assert_eq!(results[0].1.as_ref().map(|e| e.num_vectors()), Some(2));
        assert_eq!(
            results[0].1.as_ref().and_then(|e| e.get_vector(1)),
            Some([3.0, 4.0].as_slice())
        );

        assert_eq!(results[1].0.as_u64(), 2);
        assert!(results[1].1.is_none());
    }

    #[test]
    fn put_many_writes_every_embedding() {
        let (store, name) = store_with_space(3);
        let batch: Vec<_> = (1..=3)
            .map(|i| {
                let embedding = MultiVectorEmbedding::new(vec![vec![i as f32; 3]]).unwrap();
                (EntityId::new(i), embedding)
            })
            .collect();

        store.put_many(&batch, &name).unwrap();

        assert_eq!(store.count(&name).unwrap(), 3);
        assert!(store.exists(EntityId::new(2), &name).unwrap());
    }

    #[test]
    fn put_many_rejects_whole_batch_on_dimension_mismatch() {
        let (store, name) = store_with_space(3);
        let batch = vec![
            (EntityId::new(1), MultiVectorEmbedding::new(vec![vec![1.0, 2.0, 3.0]]).unwrap()),
            (EntityId::new(2), MultiVectorEmbedding::new(vec![vec![1.0, 2.0]]).unwrap()),
        ];

        let result = store.put_many(&batch, &name);
        assert!(matches!(
            result,
            Err(VectorError::DimensionMismatch { expected: 3, actual: 2 })
        ));
        // Validation happens before any write, so nothing was stored.
        assert_eq!(store.count(&name).unwrap(), 0);
    }

    #[test]
    fn delete_space_removes_its_embeddings() {
        let (store, name) = store_with_space(3);
        let embedding = MultiVectorEmbedding::new(vec![vec![1.0, 2.0, 3.0]]).unwrap();
        store.put(EntityId::new(1), &name, &embedding).unwrap();

        store.delete_space(&name).unwrap();

        assert!(matches!(store.get_space(&name), Err(VectorError::SpaceNotFound(_))));
        assert!(!store.exists(EntityId::new(1), &name).unwrap());
        assert_eq!(store.count(&name).unwrap(), 0);
        assert!(matches!(store.delete_space(&name), Err(VectorError::SpaceNotFound(_))));
    }

    #[test]
    fn list_spaces_includes_created_spaces() {
        let store = create_test_store();
        let first = unique_space_name();
        let second = unique_space_name();
        store
            .create_space(&MultiVectorEmbeddingSpace::new(
                first.clone(),
                3,
                DistanceMetric::DotProduct,
            ))
            .unwrap();
        store
            .create_space(&MultiVectorEmbeddingSpace::new(
                second.clone(),
                8,
                DistanceMetric::DotProduct,
            ))
            .unwrap();

        let names: Vec<String> = store
            .list_spaces()
            .unwrap()
            .iter()
            .map(|space| space.name().as_str().to_owned())
            .collect();

        assert!(names.contains(&first.as_str().to_owned()));
        assert!(names.contains(&second.as_str().to_owned()));
    }
}