1use std::ops::Bound;
4
5use manifoldb_core::EntityId;
6use manifoldb_storage::{Cursor, StorageEngine, Transaction};
7
8use crate::encoding::{
9 decode_sparse_embedding_entity_id, encode_sparse_embedding_key, encode_sparse_embedding_prefix,
10 encode_sparse_embedding_space_key, PREFIX_SPARSE_EMBEDDING_SPACE,
11};
12use crate::error::VectorError;
13use crate::types::{EmbeddingName, SparseEmbedding, SparseEmbeddingSpace};
14
/// Name of the storage table holding serialized sparse embedding space
/// definitions, keyed by the encoded space name.
const TABLE_SPARSE_EMBEDDING_SPACES: &str = "sparse_vector_spaces";

/// Name of the storage table holding serialized sparse embeddings, keyed by
/// the encoded (space name, entity id) pair.
const TABLE_SPARSE_EMBEDDINGS: &str = "sparse_vector_embeddings";
20
21pub struct SparseVectorStore<E: StorageEngine> {
27 engine: E,
28}
29
30impl<E: StorageEngine> SparseVectorStore<E> {
31 #[must_use]
33 pub const fn new(engine: E) -> Self {
34 Self { engine }
35 }
36
37 pub fn create_space(&self, space: &SparseEmbeddingSpace) -> Result<(), VectorError> {
43 let mut tx = self.engine.begin_write()?;
44
45 let key = encode_sparse_embedding_space_key(space.name());
46
47 if tx.get(TABLE_SPARSE_EMBEDDING_SPACES, &key)?.is_some() {
49 return Err(VectorError::InvalidName(format!(
50 "sparse embedding space '{}' already exists",
51 space.name()
52 )));
53 }
54
55 tx.put(TABLE_SPARSE_EMBEDDING_SPACES, &key, &space.to_bytes()?)?;
57 tx.commit()?;
58
59 Ok(())
60 }
61
62 pub fn get_space(&self, name: &EmbeddingName) -> Result<SparseEmbeddingSpace, VectorError> {
68 let tx = self.engine.begin_read()?;
69 let key = encode_sparse_embedding_space_key(name);
70
71 let bytes = tx
72 .get(TABLE_SPARSE_EMBEDDING_SPACES, &key)?
73 .ok_or_else(|| VectorError::SpaceNotFound(name.to_string()))?;
74
75 SparseEmbeddingSpace::from_bytes(&bytes)
76 }
77
78 pub fn delete_space(&self, name: &EmbeddingName) -> Result<(), VectorError> {
84 let mut tx = self.engine.begin_write()?;
85
86 let space_key = encode_sparse_embedding_space_key(name);
87
88 if tx.get(TABLE_SPARSE_EMBEDDING_SPACES, &space_key)?.is_none() {
90 return Err(VectorError::SpaceNotFound(name.to_string()));
91 }
92
93 let prefix = encode_sparse_embedding_prefix(name);
95 let prefix_end = next_prefix(&prefix);
96
97 let mut keys_to_delete = Vec::new();
98 {
99 let cursor = tx.range(
100 TABLE_SPARSE_EMBEDDINGS,
101 Bound::Included(prefix.as_slice()),
102 Bound::Excluded(prefix_end.as_slice()),
103 )?;
104
105 let mut cursor = cursor;
106 while let Some((key, _)) = cursor.next()? {
107 keys_to_delete.push(key);
108 }
109 }
110
111 for key in keys_to_delete {
112 tx.delete(TABLE_SPARSE_EMBEDDINGS, &key)?;
113 }
114
115 tx.delete(TABLE_SPARSE_EMBEDDING_SPACES, &space_key)?;
117
118 tx.commit()?;
119 Ok(())
120 }
121
122 pub fn list_spaces(&self) -> Result<Vec<SparseEmbeddingSpace>, VectorError> {
128 let tx = self.engine.begin_read()?;
129
130 let prefix = vec![PREFIX_SPARSE_EMBEDDING_SPACE];
131 let prefix_end = next_prefix(&prefix);
132
133 let mut cursor = tx.range(
134 TABLE_SPARSE_EMBEDDING_SPACES,
135 Bound::Included(prefix.as_slice()),
136 Bound::Excluded(prefix_end.as_slice()),
137 )?;
138
139 let mut spaces = Vec::new();
140 while let Some((_, value)) = cursor.next()? {
141 spaces.push(SparseEmbeddingSpace::from_bytes(&value)?);
142 }
143
144 Ok(spaces)
145 }
146
147 pub fn put(
156 &self,
157 entity_id: EntityId,
158 space_name: &EmbeddingName,
159 embedding: &SparseEmbedding,
160 ) -> Result<(), VectorError> {
161 let space = self.get_space(space_name)?;
163
164 for &(idx, _) in embedding.as_pairs() {
166 if idx >= space.max_dimension() {
167 return Err(VectorError::Encoding(format!(
168 "sparse vector index {} exceeds max dimension {}",
169 idx,
170 space.max_dimension()
171 )));
172 }
173 }
174
175 let mut tx = self.engine.begin_write()?;
176
177 let key = encode_sparse_embedding_key(space_name, entity_id);
178 tx.put(TABLE_SPARSE_EMBEDDINGS, &key, &embedding.to_bytes())?;
179
180 tx.commit()?;
181 Ok(())
182 }
183
184 pub fn get(
193 &self,
194 entity_id: EntityId,
195 space_name: &EmbeddingName,
196 ) -> Result<SparseEmbedding, VectorError> {
197 let _ = self.get_space(space_name)?;
199
200 let tx = self.engine.begin_read()?;
201
202 let key = encode_sparse_embedding_key(space_name, entity_id);
203 let bytes = tx.get(TABLE_SPARSE_EMBEDDINGS, &key)?.ok_or_else(|| {
204 VectorError::EmbeddingNotFound {
205 entity_id: entity_id.as_u64(),
206 space: space_name.to_string(),
207 }
208 })?;
209
210 SparseEmbedding::from_bytes(&bytes)
211 }
212
213 pub fn delete(
223 &self,
224 entity_id: EntityId,
225 space_name: &EmbeddingName,
226 ) -> Result<bool, VectorError> {
227 let mut tx = self.engine.begin_write()?;
228
229 let key = encode_sparse_embedding_key(space_name, entity_id);
230 let existed = tx.delete(TABLE_SPARSE_EMBEDDINGS, &key)?;
231
232 tx.commit()?;
233 Ok(existed)
234 }
235
236 pub fn exists(
242 &self,
243 entity_id: EntityId,
244 space_name: &EmbeddingName,
245 ) -> Result<bool, VectorError> {
246 let tx = self.engine.begin_read()?;
247 let key = encode_sparse_embedding_key(space_name, entity_id);
248 Ok(tx.get(TABLE_SPARSE_EMBEDDINGS, &key)?.is_some())
249 }
250
251 pub fn list_entities(&self, space_name: &EmbeddingName) -> Result<Vec<EntityId>, VectorError> {
257 let tx = self.engine.begin_read()?;
258
259 let prefix = encode_sparse_embedding_prefix(space_name);
260 let prefix_end = next_prefix(&prefix);
261
262 let mut cursor = tx.range(
263 TABLE_SPARSE_EMBEDDINGS,
264 Bound::Included(prefix.as_slice()),
265 Bound::Excluded(prefix_end.as_slice()),
266 )?;
267
268 let mut entities = Vec::new();
269 while let Some((key, _)) = cursor.next()? {
270 if let Some(entity_id) = decode_sparse_embedding_entity_id(&key) {
271 entities.push(entity_id);
272 }
273 }
274
275 Ok(entities)
276 }
277
278 pub fn count(&self, space_name: &EmbeddingName) -> Result<usize, VectorError> {
284 let tx = self.engine.begin_read()?;
285
286 let prefix = encode_sparse_embedding_prefix(space_name);
287 let prefix_end = next_prefix(&prefix);
288
289 let mut cursor = tx.range(
290 TABLE_SPARSE_EMBEDDINGS,
291 Bound::Included(prefix.as_slice()),
292 Bound::Excluded(prefix_end.as_slice()),
293 )?;
294
295 let mut count = 0;
296 while cursor.next()?.is_some() {
297 count += 1;
298 }
299
300 Ok(count)
301 }
302
303 pub fn get_many(
311 &self,
312 entity_ids: &[EntityId],
313 space_name: &EmbeddingName,
314 ) -> Result<Vec<(EntityId, Option<SparseEmbedding>)>, VectorError> {
315 let tx = self.engine.begin_read()?;
316
317 let mut results = Vec::with_capacity(entity_ids.len());
318
319 for &entity_id in entity_ids {
320 let key = encode_sparse_embedding_key(space_name, entity_id);
321 let embedding = tx
322 .get(TABLE_SPARSE_EMBEDDINGS, &key)?
323 .map(|bytes| SparseEmbedding::from_bytes(&bytes))
324 .transpose()?;
325
326 results.push((entity_id, embedding));
327 }
328
329 Ok(results)
330 }
331
332 pub fn put_many(
339 &self,
340 embeddings: &[(EntityId, SparseEmbedding)],
341 space_name: &EmbeddingName,
342 ) -> Result<(), VectorError> {
343 if embeddings.is_empty() {
344 return Ok(());
345 }
346
347 let space = self.get_space(space_name)?;
349
350 for (entity_id, embedding) in embeddings {
352 for &(idx, _) in embedding.as_pairs() {
353 if idx >= space.max_dimension() {
354 return Err(VectorError::Encoding(format!(
355 "sparse vector index {} exceeds max dimension {} for entity {}",
356 idx,
357 space.max_dimension(),
358 entity_id.as_u64()
359 )));
360 }
361 }
362 }
363
364 let mut tx = self.engine.begin_write()?;
365
366 for (entity_id, embedding) in embeddings {
367 let key = encode_sparse_embedding_key(space_name, *entity_id);
368 tx.put(TABLE_SPARSE_EMBEDDINGS, &key, &embedding.to_bytes())?;
369 }
370
371 tx.commit()?;
372 Ok(())
373 }
374
375 pub fn delete_entity(&self, entity_id: EntityId) -> Result<usize, VectorError> {
381 let spaces = self.list_spaces()?;
382
383 let mut deleted = 0;
384 for space in spaces {
385 if self.delete(entity_id, space.name())? {
386 deleted += 1;
387 }
388 }
389
390 Ok(deleted)
391 }
392}
393
/// Computes the exclusive upper bound for a range scan over all keys that
/// start with `prefix`.
///
/// The bound is the shortest byte string that sorts after every key with
/// that prefix: trailing `0xFF` bytes are dropped and the last remaining
/// byte is incremented (for example `[0x01, 0xFF]` becomes `[0x02]`).
/// Keeping the trailing `0xFF` bytes after the carry would yield a bound
/// that is too large and lets keys outside the prefix into the range.
///
/// When `prefix` is empty or consists solely of `0xFF` bytes, no finite
/// upper bound exists. In that case `prefix` with one extra `0xFF` appended
/// is returned as a best-effort bound; keys that extend the prefix with
/// further `0xFF` bytes fall outside it.
fn next_prefix(prefix: &[u8]) -> Vec<u8> {
    match prefix.iter().rposition(|&byte| byte != 0xFF) {
        Some(pos) => {
            // Keep everything up to and including the last incrementable
            // byte; the 0xFF bytes after it are discarded by the carry.
            let mut end = prefix[..=pos].to_vec();
            end[pos] += 1;
            end
        }
        None => {
            let mut end = prefix.to_vec();
            end.push(0xFF);
            end
        }
    }
}
408
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distance::sparse::SparseDistanceMetric;
    use manifoldb_storage::backends::RedbEngine;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Counter used to hand out a distinct space name to every caller.
    static TEST_COUNTER: AtomicUsize = AtomicUsize::new(0);

    /// Builds a store backed by a fresh in-memory engine.
    fn create_test_store() -> SparseVectorStore<RedbEngine> {
        let engine = RedbEngine::in_memory().unwrap();
        SparseVectorStore::new(engine)
    }

    /// Returns a space name that no earlier call in this process has returned.
    fn unique_space_name() -> EmbeddingName {
        let count = TEST_COUNTER.fetch_add(1, Ordering::SeqCst);
        EmbeddingName::new(format!("sparse_test_space_{}", count)).unwrap()
    }

    #[test]
    fn create_and_get_space() {
        let store = create_test_store();
        let name = unique_space_name();
        let space =
            SparseEmbeddingSpace::new(name.clone(), 30522, SparseDistanceMetric::DotProduct);

        store.create_space(&space).unwrap();

        let retrieved = store.get_space(&name).unwrap();
        assert_eq!(retrieved.max_dimension(), 30522);
        assert_eq!(retrieved.distance_metric(), SparseDistanceMetric::DotProduct);
    }

    #[test]
    fn create_duplicate_space_fails() {
        let store = create_test_store();
        let name = unique_space_name();
        let space = SparseEmbeddingSpace::new(name.clone(), 10000, SparseDistanceMetric::Cosine);

        store.create_space(&space).unwrap();
        let result = store.create_space(&space);

        // A duplicate is reported through the `InvalidName` variant.
        assert!(matches!(result, Err(VectorError::InvalidName(_))));
    }

    #[test]
    fn get_nonexistent_space_fails() {
        let store = create_test_store();
        let name = EmbeddingName::new("nonexistent").unwrap();

        let result = store.get_space(&name);
        assert!(result.is_err());
    }

    #[test]
    fn put_and_get_embedding() {
        let store = create_test_store();
        let name = unique_space_name();
        let space = SparseEmbeddingSpace::new(name.clone(), 1000, SparseDistanceMetric::DotProduct);
        store.create_space(&space).unwrap();

        let embedding = SparseEmbedding::new(vec![(10, 0.5), (50, 0.3), (100, 0.2)]).unwrap();
        store.put(EntityId::new(42), &name, &embedding).unwrap();

        let retrieved = store.get(EntityId::new(42), &name).unwrap();
        assert_eq!(retrieved.nnz(), 3);
        assert!((retrieved.get(10) - 0.5).abs() < 1e-6);
        assert!((retrieved.get(50) - 0.3).abs() < 1e-6);
        assert!((retrieved.get(100) - 0.2).abs() < 1e-6);
    }

    #[test]
    fn put_index_exceeds_dimension_fails() {
        let store = create_test_store();
        let name = unique_space_name();
        let space = SparseEmbeddingSpace::new(name.clone(), 100, SparseDistanceMetric::DotProduct);
        store.create_space(&space).unwrap();

        // Index 100 equals the max dimension; `put` rejects any index that
        // is greater than or equal to it.
        let embedding = SparseEmbedding::new(vec![(0, 0.5), (100, 0.3)]).unwrap();
        let result = store.put(EntityId::new(1), &name, &embedding);

        assert!(matches!(result, Err(VectorError::Encoding(_))));
    }

    #[test]
    fn get_nonexistent_embedding_fails() {
        let store = create_test_store();
        let name = unique_space_name();
        let space = SparseEmbeddingSpace::new(name.clone(), 1000, SparseDistanceMetric::DotProduct);
        store.create_space(&space).unwrap();

        let result = store.get(EntityId::new(999), &name);
        assert!(result.is_err());
    }

    #[test]
    fn delete_embedding() {
        let store = create_test_store();
        let name = unique_space_name();
        let space = SparseEmbeddingSpace::new(name.clone(), 1000, SparseDistanceMetric::DotProduct);
        store.create_space(&space).unwrap();

        let embedding = SparseEmbedding::new(vec![(10, 0.5)]).unwrap();
        store.put(EntityId::new(1), &name, &embedding).unwrap();

        assert!(store.exists(EntityId::new(1), &name).unwrap());
        assert!(store.delete(EntityId::new(1), &name).unwrap());
        assert!(!store.exists(EntityId::new(1), &name).unwrap());

        // Deleting again reports that nothing was removed.
        assert!(!store.delete(EntityId::new(1), &name).unwrap());
    }

    #[test]
    fn list_entities() {
        let store = create_test_store();
        let name = unique_space_name();
        let space = SparseEmbeddingSpace::new(name.clone(), 1000, SparseDistanceMetric::DotProduct);
        store.create_space(&space).unwrap();

        for i in 1..=5 {
            let embedding = SparseEmbedding::new(vec![(i as u32, i as f32)]).unwrap();
            store.put(EntityId::new(i), &name, &embedding).unwrap();
        }

        let entities = store.list_entities(&name).unwrap();
        assert_eq!(entities.len(), 5);

        let ids: Vec<u64> = entities.iter().map(|e| e.as_u64()).collect();
        for i in 1..=5 {
            assert!(ids.contains(&i));
        }
    }

    #[test]
    fn count_embeddings() {
        let store = create_test_store();
        let name = unique_space_name();
        let space = SparseEmbeddingSpace::new(name.clone(), 1000, SparseDistanceMetric::DotProduct);
        store.create_space(&space).unwrap();

        assert_eq!(store.count(&name).unwrap(), 0);

        for i in 1..=10 {
            let embedding = SparseEmbedding::new(vec![(i as u32, 1.0)]).unwrap();
            store.put(EntityId::new(i), &name, &embedding).unwrap();
        }

        assert_eq!(store.count(&name).unwrap(), 10);
    }

    #[test]
    fn delete_space_removes_embeddings() {
        let store = create_test_store();
        let name = unique_space_name();
        let space = SparseEmbeddingSpace::new(name.clone(), 1000, SparseDistanceMetric::DotProduct);
        store.create_space(&space).unwrap();

        let embedding = SparseEmbedding::new(vec![(10, 0.5)]).unwrap();
        store.put(EntityId::new(1), &name, &embedding).unwrap();
        assert_eq!(store.count(&name).unwrap(), 1);

        store.delete_space(&name).unwrap();

        // The space definition is gone...
        assert!(store.get_space(&name).is_err());
        // ...and so are the embeddings that were stored under it. `exists`
        // and `count` do not require the space to exist, so they can be
        // used to inspect the embeddings table directly.
        assert!(!store.exists(EntityId::new(1), &name).unwrap());
        assert_eq!(store.count(&name).unwrap(), 0);
    }

    #[test]
    fn put_many_and_get_many() {
        let store = create_test_store();
        let name = unique_space_name();
        let space = SparseEmbeddingSpace::new(name.clone(), 1000, SparseDistanceMetric::DotProduct);
        store.create_space(&space).unwrap();

        let embeddings: Vec<_> = (1..=5)
            .map(|i| {
                (EntityId::new(i), SparseEmbedding::new(vec![(i as u32 * 10, i as f32)]).unwrap())
            })
            .collect();

        store.put_many(&embeddings, &name).unwrap();

        assert_eq!(store.count(&name).unwrap(), 5);

        let results = store
            .get_many(&[EntityId::new(1), EntityId::new(2), EntityId::new(999)], &name)
            .unwrap();

        assert_eq!(results.len(), 3);
        assert!(results[0].1.is_some());
        assert!(results[1].1.is_some());
        assert!(results[2].1.is_none());
    }
}