1use std::collections::HashMap;
31use std::ops::Bound;
32
33use manifoldb_core::{CollectionId, EntityId};
34use manifoldb_storage::{Cursor, StorageEngine, Transaction};
35
36use crate::encoding::{encode_collection_vector_key, encode_entity_vector_prefix, hash_name};
37use crate::error::VectorError;
38use crate::types::VectorData;
39
/// Format version written as the first byte of every encoded vector value.
const VECTOR_FORMAT_VERSION: u8 = 0x01;

// Vector type tags stored in the second byte of an encoded value. `decode_vector_value`
// dispatches on these constants while `encode_vector_value` writes
// `VectorData::type_discriminant()`, so the two must stay in agreement.

/// Type tag for dense (`VectorData::Dense`) values.
const VECTOR_TYPE_DENSE: u8 = 0x00;
/// Type tag for sparse (`VectorData::Sparse`) values.
const VECTOR_TYPE_SPARSE: u8 = 0x01;
/// Type tag for multi-vector (`VectorData::Multi`) values.
const VECTOR_TYPE_MULTI: u8 = 0x02;
/// Type tag for binary (`VectorData::Binary`) values.
const VECTOR_TYPE_BINARY: u8 = 0x03;
48
49pub struct CollectionVectorStore<E: StorageEngine> {
55 engine: E,
56}
57
58impl<E: StorageEngine> CollectionVectorStore<E> {
59 #[must_use]
61 pub const fn new(engine: E) -> Self {
62 Self { engine }
63 }
64
65 pub fn put_vector(
71 &self,
72 collection_id: CollectionId,
73 entity_id: EntityId,
74 vector_name: &str,
75 data: &VectorData,
76 ) -> Result<(), VectorError> {
77 let mut tx = self.engine.begin_write()?;
78 self.put_vector_tx(&mut tx, collection_id, entity_id, vector_name, data)?;
79 tx.commit()?;
80 Ok(())
81 }
82
83 pub fn put_vector_tx<T: Transaction>(
92 &self,
93 tx: &mut T,
94 collection_id: CollectionId,
95 entity_id: EntityId,
96 vector_name: &str,
97 data: &VectorData,
98 ) -> Result<(), VectorError> {
99 let key = encode_collection_vector_key(collection_id, entity_id, vector_name);
100 let value = encode_vector_value(data, vector_name);
101 tx.put(TABLE_COLLECTION_VECTORS, &key, &value)?;
102 Ok(())
103 }
104
105 pub fn get_vector(
111 &self,
112 collection_id: CollectionId,
113 entity_id: EntityId,
114 vector_name: &str,
115 ) -> Result<Option<VectorData>, VectorError> {
116 let tx = self.engine.begin_read()?;
117 let key = encode_collection_vector_key(collection_id, entity_id, vector_name);
118
119 match tx.get(TABLE_COLLECTION_VECTORS, &key)? {
120 Some(bytes) => Ok(Some(decode_vector_value(&bytes)?.0)),
121 None => Ok(None),
122 }
123 }
124
125 pub fn get_all_vectors(
133 &self,
134 collection_id: CollectionId,
135 entity_id: EntityId,
136 ) -> Result<HashMap<String, VectorData>, VectorError> {
137 let tx = self.engine.begin_read()?;
138 let prefix = encode_entity_vector_prefix(collection_id, entity_id);
139 let prefix_end = next_prefix(&prefix);
140
141 let mut cursor = tx.range(
142 TABLE_COLLECTION_VECTORS,
143 Bound::Included(prefix.as_slice()),
144 Bound::Excluded(prefix_end.as_slice()),
145 )?;
146
147 let mut vectors = HashMap::new();
148 while let Some((_, value)) = cursor.next()? {
149 let (data, vector_name) = decode_vector_value(&value)?;
150 vectors.insert(vector_name, data);
151 }
152
153 Ok(vectors)
154 }
155
156 pub fn delete_vector(
166 &self,
167 collection_id: CollectionId,
168 entity_id: EntityId,
169 vector_name: &str,
170 ) -> Result<bool, VectorError> {
171 let mut tx = self.engine.begin_write()?;
172 let key = encode_collection_vector_key(collection_id, entity_id, vector_name);
173 let deleted = tx.delete(TABLE_COLLECTION_VECTORS, &key)?;
174 tx.commit()?;
175 Ok(deleted)
176 }
177
178 pub fn delete_all_vectors(
190 &self,
191 collection_id: CollectionId,
192 entity_id: EntityId,
193 ) -> Result<usize, VectorError> {
194 let mut tx = self.engine.begin_write()?;
195 let count = self.delete_all_vectors_tx(&mut tx, collection_id, entity_id)?;
196 tx.commit()?;
197 Ok(count)
198 }
199
200 pub fn delete_all_vectors_tx<T: Transaction>(
213 &self,
214 tx: &mut T,
215 collection_id: CollectionId,
216 entity_id: EntityId,
217 ) -> Result<usize, VectorError> {
218 let prefix = encode_entity_vector_prefix(collection_id, entity_id);
219 let prefix_end = next_prefix(&prefix);
220
221 let mut keys_to_delete = Vec::new();
223 {
224 let mut cursor = tx.range(
225 TABLE_COLLECTION_VECTORS,
226 Bound::Included(prefix.as_slice()),
227 Bound::Excluded(prefix_end.as_slice()),
228 )?;
229
230 while let Some((key, _)) = cursor.next()? {
231 keys_to_delete.push(key);
232 }
233 }
234
235 let count = keys_to_delete.len();
237 for key in keys_to_delete {
238 tx.delete(TABLE_COLLECTION_VECTORS, &key)?;
239 }
240
241 Ok(count)
242 }
243
244 pub fn put_vectors_batch(
253 &self,
254 collection_id: CollectionId,
255 vectors: &[(EntityId, &str, &VectorData)],
256 ) -> Result<(), VectorError> {
257 if vectors.is_empty() {
258 return Ok(());
259 }
260
261 let mut tx = self.engine.begin_write()?;
262 for (entity_id, vector_name, data) in vectors {
263 self.put_vector_tx(&mut tx, collection_id, *entity_id, vector_name, data)?;
264 }
265 tx.commit()?;
266 Ok(())
267 }
268
269 pub fn exists(
275 &self,
276 collection_id: CollectionId,
277 entity_id: EntityId,
278 vector_name: &str,
279 ) -> Result<bool, VectorError> {
280 let tx = self.engine.begin_read()?;
281 let key = encode_collection_vector_key(collection_id, entity_id, vector_name);
282 Ok(tx.get(TABLE_COLLECTION_VECTORS, &key)?.is_some())
283 }
284
285 pub fn count_entity_vectors(
291 &self,
292 collection_id: CollectionId,
293 entity_id: EntityId,
294 ) -> Result<usize, VectorError> {
295 let tx = self.engine.begin_read()?;
296 let prefix = encode_entity_vector_prefix(collection_id, entity_id);
297 let prefix_end = next_prefix(&prefix);
298
299 let mut cursor = tx.range(
300 TABLE_COLLECTION_VECTORS,
301 Bound::Included(prefix.as_slice()),
302 Bound::Excluded(prefix_end.as_slice()),
303 )?;
304
305 let mut count = 0;
306 while cursor.next()?.is_some() {
307 count += 1;
308 }
309
310 Ok(count)
311 }
312
313 pub fn list_entities_with_vector(
322 &self,
323 collection_id: CollectionId,
324 vector_name: &str,
325 ) -> Result<Vec<EntityId>, VectorError> {
326 use crate::encoding::{decode_collection_vector_key, encode_collection_vector_prefix};
327
328 let tx = self.engine.begin_read()?;
329 let prefix = encode_collection_vector_prefix(collection_id);
330 let prefix_end = next_prefix(&prefix);
331
332 let target_hash = hash_name(vector_name);
333
334 let mut cursor = tx.range(
335 TABLE_COLLECTION_VECTORS,
336 Bound::Included(prefix.as_slice()),
337 Bound::Excluded(prefix_end.as_slice()),
338 )?;
339
340 let mut entities = Vec::new();
341 while let Some((key, _)) = cursor.next()? {
342 if let Some(decoded) = decode_collection_vector_key(&key) {
343 if decoded.vector_name_hash == target_hash {
344 entities.push(decoded.entity_id);
345 }
346 }
347 }
348
349 Ok(entities)
350 }
351}
352
353pub const TABLE_COLLECTION_VECTORS: &str = "collection_vectors";
357
358pub fn encode_vector_value(data: &VectorData, vector_name: &str) -> Vec<u8> {
369 let timestamp = std::time::SystemTime::now()
370 .duration_since(std::time::UNIX_EPOCH)
371 .map(|d| d.as_secs())
372 .unwrap_or(0);
373
374 let name_bytes = vector_name.as_bytes();
375 let name_len = name_bytes.len().min(u16::MAX as usize);
376
377 let mut bytes = Vec::new();
378 bytes.push(VECTOR_FORMAT_VERSION);
379 bytes.push(data.type_discriminant());
380 bytes.extend_from_slice(×tamp.to_be_bytes());
381 bytes.extend_from_slice(&(name_len as u16).to_be_bytes());
382 bytes.extend_from_slice(&name_bytes[..name_len]);
383
384 match data {
385 VectorData::Dense(v) => {
386 bytes.extend_from_slice(&(v.len() as u32).to_be_bytes());
387 for &val in v {
388 bytes.extend_from_slice(&val.to_le_bytes());
389 }
390 }
391 VectorData::Sparse(v) => {
392 bytes.extend_from_slice(&(v.len() as u32).to_be_bytes());
393 for &(idx, val) in v {
394 bytes.extend_from_slice(&idx.to_be_bytes());
395 bytes.extend_from_slice(&val.to_le_bytes());
396 }
397 }
398 VectorData::Multi(v) => {
399 let num_vectors = v.len() as u32;
400 let dim = v.first().map(|inner| inner.len() as u32).unwrap_or(0);
401 bytes.extend_from_slice(&num_vectors.to_be_bytes());
402 bytes.extend_from_slice(&dim.to_be_bytes());
403 for inner in v {
404 for &val in inner {
405 bytes.extend_from_slice(&val.to_le_bytes());
406 }
407 }
408 }
409 VectorData::Binary(v) => {
410 bytes.extend_from_slice(&(v.len() as u32).to_be_bytes());
411 bytes.extend_from_slice(v);
412 }
413 }
414
415 bytes
416}
417
418pub fn decode_vector_value(bytes: &[u8]) -> Result<(VectorData, String), VectorError> {
426 if bytes.len() < 12 {
427 return Err(VectorError::Encoding("truncated vector value".to_string()));
428 }
429
430 let version = bytes[0];
431 if version != VECTOR_FORMAT_VERSION {
432 return Err(VectorError::Encoding(format!(
433 "unsupported vector format version: {}",
434 version
435 )));
436 }
437
438 let vec_type = bytes[1];
439 let name_len = u16::from_be_bytes([bytes[10], bytes[11]]) as usize;
441
442 if bytes.len() < 12 + name_len + 4 {
443 return Err(VectorError::Encoding("truncated vector value (name)".to_string()));
444 }
445
446 let vector_name = String::from_utf8(bytes[12..12 + name_len].to_vec())
447 .map_err(|e| VectorError::Encoding(format!("invalid vector name: {}", e)))?;
448
449 let data_offset = 12 + name_len;
450 let data_len = u32::from_be_bytes([
451 bytes[data_offset],
452 bytes[data_offset + 1],
453 bytes[data_offset + 2],
454 bytes[data_offset + 3],
455 ]) as usize;
456
457 let payload_offset = data_offset + 4;
458
459 let data = match vec_type {
460 VECTOR_TYPE_DENSE => {
461 let expected_len = payload_offset + data_len * 4;
462 if bytes.len() != expected_len {
463 return Err(VectorError::Encoding("dense vector length mismatch".to_string()));
464 }
465 let mut v = Vec::with_capacity(data_len);
466 for i in 0..data_len {
467 let offset = payload_offset + i * 4;
468 let val = f32::from_le_bytes([
469 bytes[offset],
470 bytes[offset + 1],
471 bytes[offset + 2],
472 bytes[offset + 3],
473 ]);
474 v.push(val);
475 }
476 VectorData::Dense(v)
477 }
478 VECTOR_TYPE_SPARSE => {
479 let expected_len = payload_offset + data_len * 8;
480 if bytes.len() != expected_len {
481 return Err(VectorError::Encoding("sparse vector length mismatch".to_string()));
482 }
483 let mut v = Vec::with_capacity(data_len);
484 for i in 0..data_len {
485 let offset = payload_offset + i * 8;
486 let idx = u32::from_be_bytes([
487 bytes[offset],
488 bytes[offset + 1],
489 bytes[offset + 2],
490 bytes[offset + 3],
491 ]);
492 let val = f32::from_le_bytes([
493 bytes[offset + 4],
494 bytes[offset + 5],
495 bytes[offset + 6],
496 bytes[offset + 7],
497 ]);
498 v.push((idx, val));
499 }
500 VectorData::Sparse(v)
501 }
502 VECTOR_TYPE_MULTI => {
503 if bytes.len() < payload_offset + 4 {
504 return Err(VectorError::Encoding("truncated multi-vector".to_string()));
505 }
506 let num_vectors = data_len;
507 let dim = u32::from_be_bytes([
508 bytes[payload_offset],
509 bytes[payload_offset + 1],
510 bytes[payload_offset + 2],
511 bytes[payload_offset + 3],
512 ]) as usize;
513 let expected_len = payload_offset + 4 + num_vectors * dim * 4;
514 if bytes.len() != expected_len {
515 return Err(VectorError::Encoding("multi-vector length mismatch".to_string()));
516 }
517 let mut v = Vec::with_capacity(num_vectors);
518 for i in 0..num_vectors {
519 let mut inner = Vec::with_capacity(dim);
520 for j in 0..dim {
521 let offset = payload_offset + 4 + (i * dim + j) * 4;
522 let val = f32::from_le_bytes([
523 bytes[offset],
524 bytes[offset + 1],
525 bytes[offset + 2],
526 bytes[offset + 3],
527 ]);
528 inner.push(val);
529 }
530 v.push(inner);
531 }
532 VectorData::Multi(v)
533 }
534 VECTOR_TYPE_BINARY => {
535 let expected_len = payload_offset + data_len;
536 if bytes.len() != expected_len {
537 return Err(VectorError::Encoding("binary vector length mismatch".to_string()));
538 }
539 VectorData::Binary(bytes[payload_offset..].to_vec())
540 }
541 _ => {
542 return Err(VectorError::Encoding(format!("unknown vector type: {}", vec_type)));
543 }
544 };
545
546 Ok((data, vector_name))
547}
548
/// Returns the smallest key that sorts after every key starting with `prefix`. It is used as
/// the exclusive upper bound of a prefix range scan.
///
/// A trailing `0xFF` byte cannot be incremented, so trailing `0xFF`s are dropped and the last
/// remaining byte is incremented. For example, `[0x01, 0xFF]` becomes `[0x02]`. Keeping the
/// trailing `0xFF` (i.e. `[0x02, 0xFF]`) would let the range also cover keys such as
/// `[0x02, 0x00, ..]`, which belong to a different prefix. The scan would then pick up,
/// or delete, another entity's vectors.
///
/// If `prefix` is empty or consists only of `0xFF` bytes, no finite key bounds the range. In
/// that case this falls back to `prefix` followed by `0xFF`. NOTE(review): that fallback is
/// not exact, because keys continuing with `0xFF` after the prefix fall outside the range. An
/// unbounded end would be needed to cover them.
fn next_prefix(prefix: &[u8]) -> Vec<u8> {
    match prefix.iter().rposition(|&byte| byte != 0xFF) {
        Some(last) => {
            let mut bound = prefix[..=last].to_vec();
            bound[last] += 1;
            bound
        }
        None => {
            let mut bound = prefix.to_vec();
            bound.push(0xFF);
            bound
        }
    }
}
561
#[cfg(test)]
mod tests {
    use super::*;
    use manifoldb_storage::backends::RedbEngine;

    /// Builds a store backed by a fresh in-memory redb engine.
    fn create_test_store() -> CollectionVectorStore<RedbEngine> {
        let engine = RedbEngine::in_memory().unwrap();
        CollectionVectorStore::new(engine)
    }

    #[test]
    fn test_put_and_get_dense_vector() {
        let store = create_test_store();
        let collection_id = CollectionId::new(1);
        let entity_id = EntityId::new(42);
        let data = VectorData::Dense(vec![1.0, 2.0, 3.0]);

        store.put_vector(collection_id, entity_id, "text", &data).unwrap();

        let retrieved = store.get_vector(collection_id, entity_id, "text").unwrap().unwrap();

        assert_eq!(retrieved.as_dense(), Some([1.0, 2.0, 3.0].as_slice()));
    }

    #[test]
    fn test_put_and_get_sparse_vector() {
        let store = create_test_store();
        let collection_id = CollectionId::new(1);
        let entity_id = EntityId::new(42);
        let data = VectorData::Sparse(vec![(0, 1.0), (5, 2.0), (10, 3.0)]);

        store.put_vector(collection_id, entity_id, "sparse", &data).unwrap();

        let retrieved = store.get_vector(collection_id, entity_id, "sparse").unwrap().unwrap();

        assert_eq!(retrieved.as_sparse(), Some([(0, 1.0), (5, 2.0), (10, 3.0)].as_slice()));
    }

    #[test]
    fn test_put_and_get_multi_vector() {
        let store = create_test_store();
        let collection_id = CollectionId::new(1);
        let entity_id = EntityId::new(42);
        let data = VectorData::Multi(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);

        store.put_vector(collection_id, entity_id, "multi", &data).unwrap();

        let retrieved = store.get_vector(collection_id, entity_id, "multi").unwrap().unwrap();

        assert!(retrieved.is_multi());
        // Compare the full contents, not just the variant, so dimension or ordering bugs in
        // the multi-vector encoding are caught.
        assert_eq!(retrieved, data);
    }

    #[test]
    fn test_put_and_get_binary_vector() {
        let store = create_test_store();
        let collection_id = CollectionId::new(1);
        let entity_id = EntityId::new(42);
        let data = VectorData::Binary(vec![0xFF, 0x00, 0xAB]);

        store.put_vector(collection_id, entity_id, "binary", &data).unwrap();

        let retrieved = store.get_vector(collection_id, entity_id, "binary").unwrap().unwrap();

        assert_eq!(retrieved.as_binary(), Some([0xFF, 0x00, 0xAB].as_slice()));
    }

    #[test]
    fn test_get_nonexistent_vector() {
        let store = create_test_store();
        let collection_id = CollectionId::new(1);
        let entity_id = EntityId::new(42);

        let result = store.get_vector(collection_id, entity_id, "nonexistent").unwrap();

        assert!(result.is_none());
    }

    #[test]
    fn test_get_all_vectors() {
        let store = create_test_store();
        let collection_id = CollectionId::new(1);
        let entity_id = EntityId::new(42);

        store
            .put_vector(collection_id, entity_id, "text", &VectorData::Dense(vec![1.0, 2.0]))
            .unwrap();
        store
            .put_vector(collection_id, entity_id, "image", &VectorData::Dense(vec![3.0, 4.0]))
            .unwrap();

        let vectors = store.get_all_vectors(collection_id, entity_id).unwrap();

        assert_eq!(vectors.len(), 2);
        assert_eq!(vectors.get("text"), Some(&VectorData::Dense(vec![1.0, 2.0])));
        assert_eq!(vectors.get("image"), Some(&VectorData::Dense(vec![3.0, 4.0])));
    }

    #[test]
    fn test_get_all_vectors_is_scoped_to_entity() {
        let store = create_test_store();
        let collection_id = CollectionId::new(1);

        store
            .put_vector(collection_id, EntityId::new(42), "text", &VectorData::Dense(vec![1.0]))
            .unwrap();
        store
            .put_vector(collection_id, EntityId::new(43), "text", &VectorData::Dense(vec![2.0]))
            .unwrap();

        let vectors = store.get_all_vectors(collection_id, EntityId::new(42)).unwrap();

        assert_eq!(vectors.len(), 1);
        assert_eq!(vectors.get("text"), Some(&VectorData::Dense(vec![1.0])));
    }

    #[test]
    fn test_delete_vector() {
        let store = create_test_store();
        let collection_id = CollectionId::new(1);
        let entity_id = EntityId::new(42);

        store.put_vector(collection_id, entity_id, "text", &VectorData::Dense(vec![1.0])).unwrap();

        assert!(store.exists(collection_id, entity_id, "text").unwrap());

        let deleted = store.delete_vector(collection_id, entity_id, "text").unwrap();
        assert!(deleted);

        assert!(!store.exists(collection_id, entity_id, "text").unwrap());

        // Deleting again reports that nothing was removed.
        let deleted = store.delete_vector(collection_id, entity_id, "text").unwrap();
        assert!(!deleted);
    }

    #[test]
    fn test_delete_all_vectors() {
        let store = create_test_store();
        let collection_id = CollectionId::new(1);
        let entity_id = EntityId::new(42);
        let neighbour = EntityId::new(43);

        store.put_vector(collection_id, entity_id, "text", &VectorData::Dense(vec![1.0])).unwrap();
        store.put_vector(collection_id, entity_id, "image", &VectorData::Dense(vec![2.0])).unwrap();
        store
            .put_vector(collection_id, entity_id, "summary", &VectorData::Dense(vec![3.0]))
            .unwrap();
        store.put_vector(collection_id, neighbour, "text", &VectorData::Dense(vec![4.0])).unwrap();

        let count = store.delete_all_vectors(collection_id, entity_id).unwrap();
        assert_eq!(count, 3);

        let vectors = store.get_all_vectors(collection_id, entity_id).unwrap();
        assert!(vectors.is_empty());

        // Another entity's vectors must survive.
        assert_eq!(store.count_entity_vectors(collection_id, neighbour).unwrap(), 1);
    }

    #[test]
    fn test_put_vectors_batch() {
        let store = create_test_store();
        let collection_id = CollectionId::new(1);

        let text_data = VectorData::Dense(vec![1.0, 2.0]);
        let image_data = VectorData::Dense(vec![3.0, 4.0]);

        let vectors: Vec<(EntityId, &str, &VectorData)> = vec![
            (EntityId::new(1), "text", &text_data),
            (EntityId::new(1), "image", &image_data),
            (EntityId::new(2), "text", &text_data),
        ];

        store.put_vectors_batch(collection_id, &vectors).unwrap();

        assert!(store.exists(collection_id, EntityId::new(1), "text").unwrap());
        assert!(store.exists(collection_id, EntityId::new(1), "image").unwrap());
        assert!(store.exists(collection_id, EntityId::new(2), "text").unwrap());
        assert!(!store.exists(collection_id, EntityId::new(2), "image").unwrap());
    }

    #[test]
    fn test_put_vectors_batch_empty() {
        let store = create_test_store();
        let collection_id = CollectionId::new(1);

        store.put_vectors_batch(collection_id, &[]).unwrap();

        assert_eq!(store.count_entity_vectors(collection_id, EntityId::new(1)).unwrap(), 0);
    }

    #[test]
    fn test_count_entity_vectors() {
        let store = create_test_store();
        let collection_id = CollectionId::new(1);
        let entity_id = EntityId::new(42);

        assert_eq!(store.count_entity_vectors(collection_id, entity_id).unwrap(), 0);

        store.put_vector(collection_id, entity_id, "text", &VectorData::Dense(vec![1.0])).unwrap();
        store.put_vector(collection_id, entity_id, "image", &VectorData::Dense(vec![2.0])).unwrap();

        assert_eq!(store.count_entity_vectors(collection_id, entity_id).unwrap(), 2);
    }

    #[test]
    fn test_list_entities_with_vector() {
        let store = create_test_store();
        let collection_id = CollectionId::new(1);

        store
            .put_vector(collection_id, EntityId::new(1), "text", &VectorData::Dense(vec![1.0]))
            .unwrap();
        store
            .put_vector(collection_id, EntityId::new(2), "text", &VectorData::Dense(vec![2.0]))
            .unwrap();
        store
            .put_vector(collection_id, EntityId::new(2), "image", &VectorData::Dense(vec![3.0]))
            .unwrap();
        store
            .put_vector(collection_id, EntityId::new(3), "image", &VectorData::Dense(vec![4.0]))
            .unwrap();

        let text_entities = store.list_entities_with_vector(collection_id, "text").unwrap();
        assert_eq!(text_entities.len(), 2);
        assert!(text_entities.contains(&EntityId::new(1)));
        assert!(text_entities.contains(&EntityId::new(2)));

        let image_entities = store.list_entities_with_vector(collection_id, "image").unwrap();
        assert_eq!(image_entities.len(), 2);
        assert!(image_entities.contains(&EntityId::new(2)));
        assert!(image_entities.contains(&EntityId::new(3)));
    }

    #[test]
    fn test_update_vector() {
        let store = create_test_store();
        let collection_id = CollectionId::new(1);
        let entity_id = EntityId::new(42);

        store
            .put_vector(collection_id, entity_id, "text", &VectorData::Dense(vec![1.0, 2.0]))
            .unwrap();

        // A second put under the same key replaces the first value.
        store
            .put_vector(collection_id, entity_id, "text", &VectorData::Dense(vec![3.0, 4.0, 5.0]))
            .unwrap();

        let retrieved = store.get_vector(collection_id, entity_id, "text").unwrap().unwrap();

        assert_eq!(retrieved.as_dense(), Some([3.0, 4.0, 5.0].as_slice()));
    }

    #[test]
    fn test_isolation_between_collections() {
        let store = create_test_store();
        let collection1 = CollectionId::new(1);
        let collection2 = CollectionId::new(2);
        let entity_id = EntityId::new(42);

        store.put_vector(collection1, entity_id, "text", &VectorData::Dense(vec![1.0])).unwrap();
        store.put_vector(collection2, entity_id, "text", &VectorData::Dense(vec![2.0])).unwrap();

        let v1 = store.get_vector(collection1, entity_id, "text").unwrap().unwrap();
        let v2 = store.get_vector(collection2, entity_id, "text").unwrap().unwrap();

        assert_eq!(v1.as_dense(), Some([1.0].as_slice()));
        assert_eq!(v2.as_dense(), Some([2.0].as_slice()));
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        let data = VectorData::Dense(vec![1.0, 2.0, 3.0]);
        let encoded = encode_vector_value(&data, "test_vector");
        let (decoded, name) = decode_vector_value(&encoded).unwrap();

        assert_eq!(decoded, data);
        assert_eq!(name, "test_vector");
    }

    #[test]
    fn test_encode_decode_roundtrip_all_types() {
        let cases = vec![
            VectorData::Dense(vec![1.0, 2.0, 3.0]),
            VectorData::Dense(Vec::new()),
            VectorData::Sparse(vec![(0, 1.0), (7, -2.5)]),
            VectorData::Multi(vec![vec![1.0, 2.0], vec![3.0, 4.0]]),
            VectorData::Binary(vec![0xFF, 0x00, 0xAB]),
        ];

        for data in cases {
            let encoded = encode_vector_value(&data, "vecteur_é");
            let (decoded, name) = decode_vector_value(&encoded).unwrap();
            assert_eq!(decoded, data);
            assert_eq!(name, "vecteur_é");
        }
    }

    #[test]
    fn test_decode_rejects_malformed_values() {
        let valid = encode_vector_value(&VectorData::Dense(vec![1.0]), "v");

        assert!(decode_vector_value(&[]).is_err());

        let mut bad_version = valid.clone();
        bad_version[0] = VECTOR_FORMAT_VERSION.wrapping_add(1);
        assert!(decode_vector_value(&bad_version).is_err());

        let mut bad_type = valid.clone();
        bad_type[1] = 0x7F;
        assert!(decode_vector_value(&bad_type).is_err());

        let mut trailing = valid.clone();
        trailing.push(0);
        assert!(decode_vector_value(&trailing).is_err());

        let mut truncated = valid;
        truncated.pop();
        assert!(decode_vector_value(&truncated).is_err());
    }
}
810}