manifoldb_vector/store/multi_vector_store.rs

1//! Multi-vector store implementation for ColBERT-style embeddings.
2
3use std::ops::Bound;
4
5use manifoldb_core::EntityId;
6use manifoldb_storage::{Cursor, StorageEngine, Transaction};
7
8use crate::encoding::{
9    decode_embedding_entity_id, encode_embedding_key, encode_embedding_prefix,
10    PREFIX_MULTI_VECTOR_SPACE,
11};
12use crate::error::VectorError;
13use crate::types::{EmbeddingName, MultiVectorEmbedding, MultiVectorEmbeddingSpace};
14
/// Storage table holding the serialized metadata record of every
/// multi-vector embedding space, keyed by space name.
const TABLE_MULTI_VECTOR_SPACES: &str = "multi_vector_spaces";

/// Storage table holding the encoded multi-vector embedding of each entity,
/// keyed by (space name, entity id).
const TABLE_MULTI_VECTOR_EMBEDDINGS: &str = "multi_vector_embeddings";
20
21/// A store for multi-vector embeddings (ColBERT-style).
22///
23/// `MultiVectorStore` provides CRUD operations for multi-vector embeddings
24/// organized into named embedding spaces. Each multi-vector stores per-token
25/// embeddings for late interaction models like ColBERT.
26pub struct MultiVectorStore<E: StorageEngine> {
27    engine: E,
28}
29
30impl<E: StorageEngine> MultiVectorStore<E> {
31    /// Create a new multi-vector store with the given storage engine.
32    #[must_use]
33    pub const fn new(engine: E) -> Self {
34        Self { engine }
35    }
36
37    /// Create a new multi-vector embedding space.
38    ///
39    /// # Errors
40    ///
41    /// Returns an error if the space already exists or if the storage operation fails.
42    pub fn create_space(&self, space: &MultiVectorEmbeddingSpace) -> Result<(), VectorError> {
43        let mut tx = self.engine.begin_write()?;
44
45        let key = encode_multi_vector_space_key(space.name());
46
47        // Check if space already exists
48        if tx.get(TABLE_MULTI_VECTOR_SPACES, &key)?.is_some() {
49            return Err(VectorError::InvalidName(format!(
50                "multi-vector embedding space '{}' already exists",
51                space.name()
52            )));
53        }
54
55        // Store the space metadata
56        tx.put(TABLE_MULTI_VECTOR_SPACES, &key, &space.to_bytes()?)?;
57        tx.commit()?;
58
59        Ok(())
60    }
61
62    /// Get a multi-vector embedding space by name.
63    ///
64    /// # Errors
65    ///
66    /// Returns an error if the space doesn't exist or if the storage operation fails.
67    pub fn get_space(
68        &self,
69        name: &EmbeddingName,
70    ) -> Result<MultiVectorEmbeddingSpace, VectorError> {
71        let tx = self.engine.begin_read()?;
72        let key = encode_multi_vector_space_key(name);
73
74        let bytes = tx
75            .get(TABLE_MULTI_VECTOR_SPACES, &key)?
76            .ok_or_else(|| VectorError::SpaceNotFound(name.to_string()))?;
77
78        MultiVectorEmbeddingSpace::from_bytes(&bytes)
79    }
80
81    /// Delete a multi-vector embedding space and all its embeddings.
82    ///
83    /// # Errors
84    ///
85    /// Returns an error if the space doesn't exist or if the storage operation fails.
86    pub fn delete_space(&self, name: &EmbeddingName) -> Result<(), VectorError> {
87        let mut tx = self.engine.begin_write()?;
88
89        let space_key = encode_multi_vector_space_key(name);
90
91        // Check if space exists
92        if tx.get(TABLE_MULTI_VECTOR_SPACES, &space_key)?.is_none() {
93            return Err(VectorError::SpaceNotFound(name.to_string()));
94        }
95
96        // Delete all embeddings in this space
97        let prefix = encode_embedding_prefix(name);
98        let prefix_end = next_prefix(&prefix);
99
100        let mut keys_to_delete = Vec::new();
101        {
102            let cursor = tx.range(
103                TABLE_MULTI_VECTOR_EMBEDDINGS,
104                Bound::Included(prefix.as_slice()),
105                Bound::Excluded(prefix_end.as_slice()),
106            )?;
107
108            let mut cursor = cursor;
109            while let Some((key, _)) = cursor.next()? {
110                keys_to_delete.push(key);
111            }
112        }
113
114        for key in keys_to_delete {
115            tx.delete(TABLE_MULTI_VECTOR_EMBEDDINGS, &key)?;
116        }
117
118        // Delete the space metadata
119        tx.delete(TABLE_MULTI_VECTOR_SPACES, &space_key)?;
120
121        tx.commit()?;
122        Ok(())
123    }
124
125    /// List all multi-vector embedding spaces.
126    ///
127    /// # Errors
128    ///
129    /// Returns an error if the storage operation fails.
130    pub fn list_spaces(&self) -> Result<Vec<MultiVectorEmbeddingSpace>, VectorError> {
131        let tx = self.engine.begin_read()?;
132
133        let prefix = vec![PREFIX_MULTI_VECTOR_SPACE];
134        let prefix_end = next_prefix(&prefix);
135
136        let mut cursor = tx.range(
137            TABLE_MULTI_VECTOR_SPACES,
138            Bound::Included(prefix.as_slice()),
139            Bound::Excluded(prefix_end.as_slice()),
140        )?;
141
142        let mut spaces = Vec::new();
143        while let Some((_, value)) = cursor.next()? {
144            spaces.push(MultiVectorEmbeddingSpace::from_bytes(&value)?);
145        }
146
147        Ok(spaces)
148    }
149
150    /// Store a multi-vector embedding for an entity in a space.
151    ///
152    /// # Errors
153    ///
154    /// Returns an error if:
155    /// - The embedding space doesn't exist
156    /// - The token embedding dimension doesn't match the space dimension
157    /// - The storage operation fails
158    pub fn put(
159        &self,
160        entity_id: EntityId,
161        space_name: &EmbeddingName,
162        embedding: &MultiVectorEmbedding,
163    ) -> Result<(), VectorError> {
164        // Get space to validate dimension
165        let space = self.get_space(space_name)?;
166
167        if embedding.dimension() != space.dimension() {
168            return Err(VectorError::DimensionMismatch {
169                expected: space.dimension(),
170                actual: embedding.dimension(),
171            });
172        }
173
174        let mut tx = self.engine.begin_write()?;
175
176        let key = encode_embedding_key(space_name, entity_id);
177
178        // Encode multi-vector: dimension (u32) + data bytes
179        let mut bytes = Vec::new();
180        let dim = embedding.dimension() as u32;
181        bytes.extend_from_slice(&dim.to_le_bytes());
182        bytes.extend_from_slice(&embedding.to_bytes());
183
184        tx.put(TABLE_MULTI_VECTOR_EMBEDDINGS, &key, &bytes)?;
185
186        tx.commit()?;
187        Ok(())
188    }
189
190    /// Get a multi-vector embedding for an entity from a space.
191    ///
192    /// # Errors
193    ///
194    /// Returns an error if:
195    /// - The embedding space doesn't exist
196    /// - The embedding doesn't exist for this entity
197    /// - The storage operation fails
198    pub fn get(
199        &self,
200        entity_id: EntityId,
201        space_name: &EmbeddingName,
202    ) -> Result<MultiVectorEmbedding, VectorError> {
203        // Check space exists
204        let _ = self.get_space(space_name)?;
205
206        let tx = self.engine.begin_read()?;
207
208        let key = encode_embedding_key(space_name, entity_id);
209        let bytes = tx.get(TABLE_MULTI_VECTOR_EMBEDDINGS, &key)?.ok_or_else(|| {
210            VectorError::EmbeddingNotFound {
211                entity_id: entity_id.as_u64(),
212                space: space_name.to_string(),
213            }
214        })?;
215
216        // Decode: dimension (u32) + data bytes
217        if bytes.len() < 4 {
218            return Err(VectorError::Encoding("multi-vector data too short".to_string()));
219        }
220
221        let dim_bytes: [u8; 4] = bytes[..4]
222            .try_into()
223            .map_err(|_| VectorError::Encoding("failed to read dimension".to_string()))?;
224        let dimension = u32::from_le_bytes(dim_bytes) as usize;
225
226        MultiVectorEmbedding::from_bytes(&bytes[4..], dimension)
227    }
228
229    /// Delete a multi-vector embedding for an entity from a space.
230    ///
231    /// # Returns
232    ///
233    /// Returns `Ok(true)` if the embedding was deleted, `Ok(false)` if it didn't exist.
234    ///
235    /// # Errors
236    ///
237    /// Returns an error if the storage operation fails.
238    pub fn delete(
239        &self,
240        entity_id: EntityId,
241        space_name: &EmbeddingName,
242    ) -> Result<bool, VectorError> {
243        let mut tx = self.engine.begin_write()?;
244
245        let key = encode_embedding_key(space_name, entity_id);
246        let existed = tx.delete(TABLE_MULTI_VECTOR_EMBEDDINGS, &key)?;
247
248        tx.commit()?;
249        Ok(existed)
250    }
251
252    /// Check if a multi-vector embedding exists for an entity in a space.
253    ///
254    /// # Errors
255    ///
256    /// Returns an error if the storage operation fails.
257    pub fn exists(
258        &self,
259        entity_id: EntityId,
260        space_name: &EmbeddingName,
261    ) -> Result<bool, VectorError> {
262        let tx = self.engine.begin_read()?;
263        let key = encode_embedding_key(space_name, entity_id);
264        Ok(tx.get(TABLE_MULTI_VECTOR_EMBEDDINGS, &key)?.is_some())
265    }
266
267    /// List all entity IDs that have multi-vector embeddings in a space.
268    ///
269    /// # Errors
270    ///
271    /// Returns an error if the storage operation fails.
272    pub fn list_entities(&self, space_name: &EmbeddingName) -> Result<Vec<EntityId>, VectorError> {
273        let tx = self.engine.begin_read()?;
274
275        let prefix = encode_embedding_prefix(space_name);
276        let prefix_end = next_prefix(&prefix);
277
278        let mut cursor = tx.range(
279            TABLE_MULTI_VECTOR_EMBEDDINGS,
280            Bound::Included(prefix.as_slice()),
281            Bound::Excluded(prefix_end.as_slice()),
282        )?;
283
284        let mut entities = Vec::new();
285        while let Some((key, _)) = cursor.next()? {
286            if let Some(entity_id) = decode_embedding_entity_id(&key) {
287                entities.push(entity_id);
288            }
289        }
290
291        Ok(entities)
292    }
293
294    /// Count the number of multi-vector embeddings in a space.
295    ///
296    /// # Errors
297    ///
298    /// Returns an error if the storage operation fails.
299    pub fn count(&self, space_name: &EmbeddingName) -> Result<usize, VectorError> {
300        let tx = self.engine.begin_read()?;
301
302        let prefix = encode_embedding_prefix(space_name);
303        let prefix_end = next_prefix(&prefix);
304
305        let mut cursor = tx.range(
306            TABLE_MULTI_VECTOR_EMBEDDINGS,
307            Bound::Included(prefix.as_slice()),
308            Bound::Excluded(prefix_end.as_slice()),
309        )?;
310
311        let mut count = 0;
312        while cursor.next()?.is_some() {
313            count += 1;
314        }
315
316        Ok(count)
317    }
318
319    /// Get multiple multi-vector embeddings at once.
320    ///
321    /// Returns a vector of `(EntityId, Option<MultiVectorEmbedding>)` tuples.
322    /// If an embedding doesn't exist for an entity, the option is `None`.
323    ///
324    /// # Errors
325    ///
326    /// Returns an error if the storage operation fails.
327    pub fn get_many(
328        &self,
329        entity_ids: &[EntityId],
330        space_name: &EmbeddingName,
331    ) -> Result<Vec<(EntityId, Option<MultiVectorEmbedding>)>, VectorError> {
332        // Get space to know dimension
333        let space = self.get_space(space_name)?;
334        let dimension = space.dimension();
335
336        let tx = self.engine.begin_read()?;
337
338        let mut results = Vec::with_capacity(entity_ids.len());
339
340        for &entity_id in entity_ids {
341            let key = encode_embedding_key(space_name, entity_id);
342            let embedding = tx
343                .get(TABLE_MULTI_VECTOR_EMBEDDINGS, &key)?
344                .map(|bytes| {
345                    if bytes.len() < 4 {
346                        return Err(VectorError::Encoding(
347                            "multi-vector data too short".to_string(),
348                        ));
349                    }
350                    MultiVectorEmbedding::from_bytes(&bytes[4..], dimension)
351                })
352                .transpose()?;
353
354            results.push((entity_id, embedding));
355        }
356
357        Ok(results)
358    }
359
360    /// Store multiple multi-vector embeddings at once.
361    ///
362    /// All embeddings must have token dimensions matching the space's dimension.
363    ///
364    /// # Errors
365    ///
366    /// Returns an error if:
367    /// - The embedding space doesn't exist
368    /// - Any embedding dimension doesn't match the space dimension
369    /// - The storage operation fails
370    pub fn put_many(
371        &self,
372        embeddings: &[(EntityId, MultiVectorEmbedding)],
373        space_name: &EmbeddingName,
374    ) -> Result<(), VectorError> {
375        if embeddings.is_empty() {
376            return Ok(());
377        }
378
379        // Get space to validate dimension
380        let space = self.get_space(space_name)?;
381
382        // Validate all dimensions first
383        for (entity_id, embedding) in embeddings {
384            if embedding.dimension() != space.dimension() {
385                return Err(VectorError::DimensionMismatch {
386                    expected: space.dimension(),
387                    actual: embedding.dimension(),
388                });
389            }
390            let _ = entity_id;
391        }
392
393        let mut tx = self.engine.begin_write()?;
394
395        for (entity_id, embedding) in embeddings {
396            let key = encode_embedding_key(space_name, *entity_id);
397
398            let mut bytes = Vec::new();
399            let dim = embedding.dimension() as u32;
400            bytes.extend_from_slice(&dim.to_le_bytes());
401            bytes.extend_from_slice(&embedding.to_bytes());
402
403            tx.put(TABLE_MULTI_VECTOR_EMBEDDINGS, &key, &bytes)?;
404        }
405
406        tx.commit()?;
407        Ok(())
408    }
409}
410
411/// Encode a multi-vector space key.
412fn encode_multi_vector_space_key(name: &EmbeddingName) -> Vec<u8> {
413    let name_bytes = name.as_str().as_bytes();
414    let mut key = Vec::with_capacity(1 + name_bytes.len());
415    key.push(PREFIX_MULTI_VECTOR_SPACE);
416    key.extend_from_slice(name_bytes);
417    key
418}
419
/// Calculate the exclusive upper bound for a prefix range scan.
///
/// Returns the smallest byte string that sorts after every key starting with
/// `prefix`. Trailing `0xFF` bytes cannot be incremented, so they are dropped
/// and the last remaining byte is incremented. For example, `[0x01, 0xFF]`
/// yields `[0x02]`. Keeping the trailing `0xFF` (`[0x02, 0xFF]`) would
/// wrongly admit keys such as `[0x02, 0x00]`.
///
/// If `prefix` is empty or consists only of `0xFF` bytes, no finite exclusive
/// bound exists. In that case `prefix` followed by one `0xFF` byte is returned
/// (the historical behaviour). That bound excludes longer keys under the
/// prefix, such as `prefix ++ [0xFF, 0x00]`, so callers should not rely on it
/// for such prefixes.
fn next_prefix(prefix: &[u8]) -> Vec<u8> {
    if let Some(last_incrementable) = prefix.iter().rposition(|&byte| byte != 0xFF) {
        let mut bound = prefix[..=last_incrementable].to_vec();
        bound[last_incrementable] += 1;
        return bound;
    }

    // Empty or all-0xFF prefix: fall back to the legacy bound.
    let mut bound = prefix.to_vec();
    bound.push(0xFF);
    bound
}
434
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distance::DistanceMetric;
    use manifoldb_storage::backends::RedbEngine;
    use std::sync::atomic::{AtomicUsize, Ordering};

    static TEST_COUNTER: AtomicUsize = AtomicUsize::new(0);

    fn create_test_store() -> MultiVectorStore<RedbEngine> {
        let engine = RedbEngine::in_memory().unwrap();
        MultiVectorStore::new(engine)
    }

    fn unique_space_name() -> EmbeddingName {
        let count = TEST_COUNTER.fetch_add(1, Ordering::SeqCst);
        EmbeddingName::new(format!("multi_test_space_{}", count)).unwrap()
    }

    #[test]
    fn create_and_get_space() {
        let store = create_test_store();
        let name = unique_space_name();
        let space = MultiVectorEmbeddingSpace::new(name.clone(), 128, DistanceMetric::DotProduct);

        store.create_space(&space).unwrap();

        let retrieved = store.get_space(&name).unwrap();
        assert_eq!(retrieved.dimension(), 128);
        assert_eq!(retrieved.distance_metric(), DistanceMetric::DotProduct);
    }

    #[test]
    fn create_duplicate_space_fails() {
        let store = create_test_store();
        let name = unique_space_name();
        let space = MultiVectorEmbeddingSpace::new(name.clone(), 128, DistanceMetric::DotProduct);

        store.create_space(&space).unwrap();
        let result = store.create_space(&space);

        assert!(result.is_err());
    }

    #[test]
    fn put_and_get_multi_vector() {
        let store = create_test_store();
        let name = unique_space_name();
        let space = MultiVectorEmbeddingSpace::new(name.clone(), 3, DistanceMetric::DotProduct);
        store.create_space(&space).unwrap();

        let embedding = MultiVectorEmbedding::new(vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            vec![7.0, 8.0, 9.0],
        ])
        .unwrap();

        store.put(EntityId::new(42), &name, &embedding).unwrap();

        let retrieved = store.get(EntityId::new(42), &name).unwrap();
        assert_eq!(retrieved.num_vectors(), 3);
        assert_eq!(retrieved.dimension(), 3);
        assert_eq!(retrieved.get_vector(0), Some([1.0, 2.0, 3.0].as_slice()));
        assert_eq!(retrieved.get_vector(1), Some([4.0, 5.0, 6.0].as_slice()));
        assert_eq!(retrieved.get_vector(2), Some([7.0, 8.0, 9.0].as_slice()));
    }

    #[test]
    fn put_wrong_dimension_fails() {
        let store = create_test_store();
        let name = unique_space_name();
        let space = MultiVectorEmbeddingSpace::new(name.clone(), 128, DistanceMetric::DotProduct);
        store.create_space(&space).unwrap();

        let embedding = MultiVectorEmbedding::new(vec![vec![1.0, 2.0, 3.0]]).unwrap();

        let result = store.put(EntityId::new(1), &name, &embedding);
        assert!(result.is_err());
        match result.unwrap_err() {
            VectorError::DimensionMismatch { expected, actual } => {
                assert_eq!(expected, 128);
                assert_eq!(actual, 3);
            }
            _ => panic!("unexpected error type"),
        }
    }

    #[test]
    fn delete_multi_vector() {
        let store = create_test_store();
        let name = unique_space_name();
        let space = MultiVectorEmbeddingSpace::new(name.clone(), 3, DistanceMetric::DotProduct);
        store.create_space(&space).unwrap();

        let embedding = MultiVectorEmbedding::new(vec![vec![1.0, 2.0, 3.0]]).unwrap();
        store.put(EntityId::new(1), &name, &embedding).unwrap();

        assert!(store.exists(EntityId::new(1), &name).unwrap());
        assert!(store.delete(EntityId::new(1), &name).unwrap());
        assert!(!store.exists(EntityId::new(1), &name).unwrap());
    }

    #[test]
    fn list_entities() {
        let store = create_test_store();
        let name = unique_space_name();
        let space = MultiVectorEmbeddingSpace::new(name.clone(), 3, DistanceMetric::DotProduct);
        store.create_space(&space).unwrap();

        for i in 1..=5 {
            let embedding =
                MultiVectorEmbedding::new(vec![vec![i as f32, i as f32, i as f32]]).unwrap();
            store.put(EntityId::new(i), &name, &embedding).unwrap();
        }

        let entities = store.list_entities(&name).unwrap();
        assert_eq!(entities.len(), 5);
    }

    #[test]
    fn count_embeddings() {
        let store = create_test_store();
        let name = unique_space_name();
        let space = MultiVectorEmbeddingSpace::new(name.clone(), 3, DistanceMetric::DotProduct);
        store.create_space(&space).unwrap();

        assert_eq!(store.count(&name).unwrap(), 0);

        for i in 1..=10 {
            let embedding =
                MultiVectorEmbedding::new(vec![vec![i as f32, i as f32, i as f32]]).unwrap();
            store.put(EntityId::new(i), &name, &embedding).unwrap();
        }

        assert_eq!(store.count(&name).unwrap(), 10);
    }

    #[test]
    fn variable_token_count() {
        let store = create_test_store();
        let name = unique_space_name();
        let space = MultiVectorEmbeddingSpace::new(name.clone(), 4, DistanceMetric::DotProduct);
        store.create_space(&space).unwrap();

        // Documents with different numbers of tokens
        let doc1 = MultiVectorEmbedding::new(vec![vec![1.0, 0.0, 0.0, 0.0]]).unwrap();

        let doc2 =
            MultiVectorEmbedding::new(vec![vec![1.0, 0.0, 0.0, 0.0], vec![0.0, 1.0, 0.0, 0.0]])
                .unwrap();

        let doc3 = MultiVectorEmbedding::new(vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
            vec![0.0, 0.0, 0.0, 1.0],
        ])
        .unwrap();

        store.put(EntityId::new(1), &name, &doc1).unwrap();
        store.put(EntityId::new(2), &name, &doc2).unwrap();
        store.put(EntityId::new(3), &name, &doc3).unwrap();

        let retrieved1 = store.get(EntityId::new(1), &name).unwrap();
        let retrieved2 = store.get(EntityId::new(2), &name).unwrap();
        let retrieved3 = store.get(EntityId::new(3), &name).unwrap();

        assert_eq!(retrieved1.num_vectors(), 1);
        assert_eq!(retrieved2.num_vectors(), 2);
        assert_eq!(retrieved3.num_vectors(), 4);
    }

    #[test]
    fn put_many_and_get_many() {
        let store = create_test_store();
        let name = unique_space_name();
        let space = MultiVectorEmbeddingSpace::new(name.clone(), 2, DistanceMetric::DotProduct);
        store.create_space(&space).unwrap();

        let batch = vec![
            (EntityId::new(1), MultiVectorEmbedding::new(vec![vec![1.0, 2.0]]).unwrap()),
            (
                EntityId::new(2),
                MultiVectorEmbedding::new(vec![vec![3.0, 4.0], vec![5.0, 6.0]]).unwrap(),
            ),
        ];
        store.put_many(&batch, &name).unwrap();

        let results = store
            .get_many(&[EntityId::new(1), EntityId::new(2), EntityId::new(3)], &name)
            .unwrap();
        assert_eq!(results.len(), 3);

        let first = results[0].1.as_ref().expect("entity 1 was stored");
        assert_eq!(first.num_vectors(), 1);
        assert_eq!(first.get_vector(0), Some([1.0, 2.0].as_slice()));

        let second = results[1].1.as_ref().expect("entity 2 was stored");
        assert_eq!(second.num_vectors(), 2);
        assert_eq!(second.get_vector(1), Some([5.0, 6.0].as_slice()));

        // Entity 3 was never stored.
        assert!(results[2].1.is_none());
    }

    #[test]
    fn put_many_rejects_batch_with_wrong_dimension() {
        let store = create_test_store();
        let name = unique_space_name();
        let space = MultiVectorEmbeddingSpace::new(name.clone(), 3, DistanceMetric::DotProduct);
        store.create_space(&space).unwrap();

        let batch = vec![
            (EntityId::new(1), MultiVectorEmbedding::new(vec![vec![1.0, 2.0, 3.0]]).unwrap()),
            (EntityId::new(2), MultiVectorEmbedding::new(vec![vec![1.0, 2.0]]).unwrap()),
        ];

        let result = store.put_many(&batch, &name);
        assert!(matches!(
            result,
            Err(VectorError::DimensionMismatch { expected: 3, actual: 2 })
        ));
        // Validation precedes any write, so the valid first entry is not stored either.
        assert_eq!(store.count(&name).unwrap(), 0);
    }

    #[test]
    fn delete_space_removes_only_its_embeddings() {
        let store = create_test_store();
        // Fixed, non-overlapping names: neither is a byte prefix of the other.
        let doomed = EmbeddingName::new("doomed_space".to_string()).unwrap();
        let kept = EmbeddingName::new("kept_space".to_string()).unwrap();

        for name in [&doomed, &kept] {
            let space =
                MultiVectorEmbeddingSpace::new(name.clone(), 3, DistanceMetric::DotProduct);
            store.create_space(&space).unwrap();
            let embedding = MultiVectorEmbedding::new(vec![vec![1.0, 2.0, 3.0]]).unwrap();
            store.put(EntityId::new(1), name, &embedding).unwrap();
        }

        store.delete_space(&doomed).unwrap();

        assert!(matches!(store.get_space(&doomed), Err(VectorError::SpaceNotFound(_))));
        assert!(matches!(store.delete_space(&doomed), Err(VectorError::SpaceNotFound(_))));
        assert_eq!(store.count(&doomed).unwrap(), 0);
        assert_eq!(store.count(&kept).unwrap(), 1);
        assert_eq!(store.list_spaces().unwrap().len(), 1);
    }
}