1use std::cmp::Ordering;
40use std::collections::{BinaryHeap, HashMap};
41use std::ops::Bound;
42
43use manifoldb_core::PointId;
44use manifoldb_storage::{Cursor, StorageEngine, Transaction};
45
46use crate::encoding::{
47 encode_inverted_meta_collection_prefix, encode_inverted_meta_key,
48 encode_point_tokens_collection_prefix, encode_point_tokens_key, encode_point_tokens_prefix,
49 encode_posting_collection_prefix, encode_posting_key, encode_posting_prefix,
50};
51use crate::error::VectorError;
52
/// Table holding one serialized [`PostingList`] per (collection, vector, token).
const TABLE_POSTINGS: &str = "inverted_postings";

/// Table holding one serialized [`InvertedIndexMeta`] per (collection, vector).
const TABLE_META: &str = "inverted_meta";

/// Table mapping each indexed point to the token ids it was indexed under,
/// so its postings can be located again on delete.
const TABLE_POINT_TOKENS: &str = "inverted_point_tokens";
61
/// One document entry in a posting list: a point id plus the weight the
/// token carries in that point's sparse vector.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct PostingEntry {
    /// Id of the indexed point (document).
    pub point_id: PointId,
    /// Token weight for this point.
    pub weight: f32,
}
70
impl PostingEntry {
    /// Pair a point id with its token weight.
    #[must_use]
    pub const fn new(point_id: PointId, weight: f32) -> Self {
        Self { point_id, weight }
    }
}
78
/// Posting list for a single token: entries kept sorted by ascending point
/// id, plus the largest weight for WAND-style upper-bound pruning.
#[derive(Debug, Clone, Default)]
pub struct PostingList {
    // Invariant: sorted by `point_id` (maintained by `add`/`from_entries`).
    entries: Vec<PostingEntry>,
    // Maximum of `entries[*].weight`; 0.0 when empty.
    max_weight: f32,
}
87
88impl PostingList {
89 #[must_use]
91 pub const fn new() -> Self {
92 Self { entries: Vec::new(), max_weight: 0.0 }
93 }
94
95 #[must_use]
97 pub fn from_entries(mut entries: Vec<PostingEntry>) -> Self {
98 entries.sort_by_key(|e| e.point_id.as_u64());
99 let max_weight = entries.iter().map(|e| e.weight).fold(0.0f32, f32::max);
100 Self { entries, max_weight }
101 }
102
103 #[must_use]
105 pub fn entries(&self) -> &[PostingEntry] {
106 &self.entries
107 }
108
109 #[must_use]
111 pub fn max_weight(&self) -> f32 {
112 self.max_weight
113 }
114
115 #[must_use]
117 pub fn len(&self) -> usize {
118 self.entries.len()
119 }
120
121 #[must_use]
123 pub fn is_empty(&self) -> bool {
124 self.entries.is_empty()
125 }
126
127 pub fn add(&mut self, entry: PostingEntry) {
129 match self.entries.binary_search_by_key(&entry.point_id.as_u64(), |e| e.point_id.as_u64()) {
130 Ok(idx) => {
131 self.entries[idx] = entry;
133 }
134 Err(idx) => {
135 self.entries.insert(idx, entry);
137 }
138 }
139 self.max_weight = self.max_weight.max(entry.weight);
140 }
141
142 pub fn remove(&mut self, point_id: PointId) -> bool {
144 match self.entries.binary_search_by_key(&point_id.as_u64(), |e| e.point_id.as_u64()) {
145 Ok(idx) => {
146 let removed = self.entries.remove(idx);
147 if (removed.weight - self.max_weight).abs() < f32::EPSILON {
149 self.max_weight = self.entries.iter().map(|e| e.weight).fold(0.0f32, f32::max);
150 }
151 true
152 }
153 Err(_) => false,
154 }
155 }
156
157 #[must_use]
161 pub fn to_bytes(&self) -> Vec<u8> {
162 let mut bytes = Vec::with_capacity(8 + self.entries.len() * 12);
163 bytes.extend_from_slice(&(self.entries.len() as u32).to_le_bytes());
164 bytes.extend_from_slice(&self.max_weight.to_le_bytes());
165 for entry in &self.entries {
166 bytes.extend_from_slice(&entry.point_id.as_u64().to_le_bytes());
167 bytes.extend_from_slice(&entry.weight.to_le_bytes());
168 }
169 bytes
170 }
171
172 pub fn from_bytes(bytes: &[u8]) -> Result<Self, VectorError> {
174 if bytes.len() < 8 {
175 return Err(VectorError::Encoding("posting list too short".to_string()));
176 }
177
178 let count = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
179 let max_weight = f32::from_le_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]);
180
181 let expected_len = 8 + count * 12;
182 if bytes.len() != expected_len {
183 return Err(VectorError::Encoding(format!(
184 "posting list length mismatch: expected {}, got {}",
185 expected_len,
186 bytes.len()
187 )));
188 }
189
190 let mut entries = Vec::with_capacity(count);
191 for i in 0..count {
192 let offset = 8 + i * 12;
193 let point_id = u64::from_le_bytes([
194 bytes[offset],
195 bytes[offset + 1],
196 bytes[offset + 2],
197 bytes[offset + 3],
198 bytes[offset + 4],
199 bytes[offset + 5],
200 bytes[offset + 6],
201 bytes[offset + 7],
202 ]);
203 let weight = f32::from_le_bytes([
204 bytes[offset + 8],
205 bytes[offset + 9],
206 bytes[offset + 10],
207 bytes[offset + 11],
208 ]);
209 entries.push(PostingEntry::new(PointId::new(point_id), weight));
210 }
211
212 Ok(Self { entries, max_weight })
213 }
214}
215
/// Corpus-level statistics for one (collection, vector) index, consumed by
/// BM25 scoring.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct InvertedIndexMeta {
    /// Number of indexed documents.
    pub doc_count: u64,
    /// Total number of (token, weight) pairs across all indexed documents.
    pub total_tokens: u64,
    /// `total_tokens / doc_count`; 0.0 when the index is empty.
    pub avg_doc_length: f32,
}
226
impl InvertedIndexMeta {
    /// Statistics for an empty index (all counters zero).
    #[must_use]
    pub const fn new() -> Self {
        Self { doc_count: 0, total_tokens: 0, avg_doc_length: 0.0 }
    }

    /// Account for a newly indexed document carrying `token_count` tokens.
    pub fn add_document(&mut self, token_count: usize) {
        self.doc_count += 1;
        self.total_tokens += token_count as u64;
        self.avg_doc_length = self.total_tokens as f32 / self.doc_count as f32;
    }

    /// Account for the removal of a document carrying `token_count` tokens.
    /// A no-op when `doc_count` is already zero; `total_tokens` saturates at
    /// zero rather than underflowing.
    pub fn remove_document(&mut self, token_count: usize) {
        if self.doc_count > 0 {
            self.doc_count -= 1;
            self.total_tokens = self.total_tokens.saturating_sub(token_count as u64);
            if self.doc_count > 0 {
                self.avg_doc_length = self.total_tokens as f32 / self.doc_count as f32;
            } else {
                self.avg_doc_length = 0.0;
            }
        }
    }

    /// Serialize with bincode (standard config).
    ///
    /// NOTE(review): a serialization failure is silently mapped to an empty
    /// byte vector, which `from_bytes` will later reject as an encoding
    /// error — consider propagating the error instead of swallowing it.
    #[must_use]
    pub fn to_bytes(&self) -> Vec<u8> {
        bincode::serde::encode_to_vec(self, bincode::config::standard()).unwrap_or_default()
    }

    /// Deserialize the format produced by [`Self::to_bytes`].
    ///
    /// # Errors
    /// Returns [`VectorError::Encoding`] when the bytes are not valid bincode
    /// for this type.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, VectorError> {
        bincode::serde::decode_from_slice(bytes, bincode::config::standard())
            .map(|(v, _)| v)
            .map_err(|e| VectorError::Encoding(format!("failed to deserialize index meta: {}", e)))
    }
}
267
268impl Default for InvertedIndexMeta {
269 fn default() -> Self {
270 Self::new()
271 }
272}
273
/// How query and document term weights are combined into a score.
///
/// BM25 parameters are stored quantized into single bytes (`k1` in tenths,
/// `b` in hundredths) so the enum stays `Copy` and cheaply serializable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum ScoringFunction {
    /// score = sum of query_weight * doc_weight over matching terms.
    DotProduct,
    /// BM25 ranking with quantized free parameters.
    Bm25 {
        /// k1 * 10 (e.g. 12 encodes k1 = 1.2).
        k1_times_10: u8,
        /// b * 100 (e.g. 75 encodes b = 0.75).
        b_times_100: u8,
    },
}
287
288impl Default for ScoringFunction {
289 fn default() -> Self {
290 Self::DotProduct
291 }
292}
293
294impl ScoringFunction {
295 #[must_use]
297 pub const fn bm25() -> Self {
298 Self::Bm25 { k1_times_10: 12, b_times_100: 75 }
299 }
300
301 #[must_use]
303 pub fn bm25_custom(k1: f32, b: f32) -> Self {
304 Self::Bm25 {
305 k1_times_10: (k1 * 10.0).clamp(0.0, 255.0) as u8,
306 b_times_100: (b * 100.0).clamp(0.0, 255.0) as u8,
307 }
308 }
309}
310
/// A scored hit returned by the search methods.
#[derive(Debug, Clone, Copy)]
pub struct SearchResult {
    /// Id of the matching point.
    pub point_id: PointId,
    /// Accumulated relevance score (higher is better).
    pub score: f32,
}
319
impl SearchResult {
    /// Pair a point id with its score.
    #[must_use]
    pub const fn new(point_id: PointId, score: f32) -> Self {
        Self { point_id, score }
    }
}
327
/// Equality is by point id only — two results for the same point compare
/// equal even when their scores differ.
///
/// NOTE(review): `Ord` below orders by score, so `eq` and `cmp` can disagree
/// (equal scores, distinct ids). `BinaryHeap` only consults `Ord`, but
/// confirm no `Eq`-sensitive collection ever holds this type.
impl PartialEq for SearchResult {
    fn eq(&self, other: &Self) -> bool {
        self.point_id == other.point_id
    }
}

impl Eq for SearchResult {}

impl PartialOrd for SearchResult {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// Reversed score ordering: a *higher* score compares as *less*. This turns
/// `BinaryHeap<SearchResult>` (a max-heap) into a min-heap on score, so the
/// heap root is the weakest of the current top-k candidates.
impl Ord for SearchResult {
    fn cmp(&self, other: &Self) -> Ordering {
        // NaN scores compare as Equal instead of panicking.
        other.score.partial_cmp(&self.score).unwrap_or(Ordering::Equal)
    }
}
348
/// Inverted index over sparse vectors, persisted through a [`StorageEngine`].
pub struct InvertedIndex<E: StorageEngine> {
    engine: E,
}
353
impl<E: StorageEngine> InvertedIndex<E> {
    /// Create an index backed by `engine`.
    #[must_use]
    pub const fn new(engine: E) -> Self {
        Self { engine }
    }

    /// Borrow the underlying storage engine.
    #[must_use]
    pub fn engine(&self) -> &E {
        &self.engine
    }

    /// Index a sparse vector (`(token_id, weight)` pairs) for `point_id`.
    ///
    /// In a single write transaction this stores the point -> token-ids
    /// reverse mapping, updates one posting list per token, and refreshes the
    /// index metadata. An empty `vector` is a no-op.
    ///
    /// NOTE(review): inserting a point id that is already indexed increments
    /// `doc_count` again and leaves stale postings for tokens absent from the
    /// new vector — callers appear expected to use [`Self::update`] for
    /// re-indexing. Confirm this contract.
    ///
    /// # Errors
    /// Propagates storage and (de)serialization failures as [`VectorError`].
    pub fn insert(
        &self,
        collection: &str,
        vector_name: &str,
        point_id: PointId,
        vector: &[(u32, f32)],
    ) -> Result<(), VectorError> {
        if vector.is_empty() {
            return Ok(());
        }

        let mut tx = self.engine.begin_write()?;

        // Load existing metadata, or start fresh for a new index.
        let meta_key = encode_inverted_meta_key(collection, vector_name);
        let mut meta = tx
            .get(TABLE_META, &meta_key)?
            .map(|bytes| InvertedIndexMeta::from_bytes(&bytes))
            .transpose()?
            .unwrap_or_default();

        // Record which tokens this point was indexed under so `delete` can
        // find its postings later.
        let token_ids: Vec<u32> = vector.iter().map(|(idx, _)| *idx).collect();
        let point_tokens_key = encode_point_tokens_key(collection, vector_name, point_id);
        tx.put(TABLE_POINT_TOKENS, &point_tokens_key, &encode_token_ids(&token_ids))?;

        // Read-modify-write each token's posting list.
        for &(token_id, weight) in vector {
            let posting_key = encode_posting_key(collection, vector_name, token_id);

            let mut posting_list = tx
                .get(TABLE_POSTINGS, &posting_key)?
                .map(|bytes| PostingList::from_bytes(&bytes))
                .transpose()?
                .unwrap_or_default();

            posting_list.add(PostingEntry::new(point_id, weight));
            tx.put(TABLE_POSTINGS, &posting_key, &posting_list.to_bytes())?;
        }

        meta.add_document(vector.len());
        tx.put(TABLE_META, &meta_key, &meta.to_bytes())?;

        tx.commit()?;
        Ok(())
    }

    /// Remove `point_id` from the index.
    ///
    /// Returns `Ok(false)` when the point was not indexed. Otherwise removes
    /// the point from every posting list it appears in (dropping posting
    /// lists that become empty), deletes its token mapping, and updates the
    /// metadata — all in one transaction.
    ///
    /// # Errors
    /// Propagates storage and (de)serialization failures as [`VectorError`].
    pub fn delete(
        &self,
        collection: &str,
        vector_name: &str,
        point_id: PointId,
    ) -> Result<bool, VectorError> {
        let mut tx = self.engine.begin_write()?;

        // The reverse mapping tells us exactly which posting lists to touch.
        let point_tokens_key = encode_point_tokens_key(collection, vector_name, point_id);
        let token_ids = match tx.get(TABLE_POINT_TOKENS, &point_tokens_key)? {
            Some(bytes) => decode_token_ids(&bytes)?,
            None => return Ok(false),
        };

        let meta_key = encode_inverted_meta_key(collection, vector_name);
        let mut meta = tx
            .get(TABLE_META, &meta_key)?
            .map(|bytes| InvertedIndexMeta::from_bytes(&bytes))
            .transpose()?
            .unwrap_or_default();

        for token_id in &token_ids {
            let posting_key = encode_posting_key(collection, vector_name, *token_id);

            if let Some(bytes) = tx.get(TABLE_POSTINGS, &posting_key)? {
                let mut posting_list = PostingList::from_bytes(&bytes)?;
                posting_list.remove(point_id);

                // Drop empty posting lists rather than storing zero entries.
                if posting_list.is_empty() {
                    tx.delete(TABLE_POSTINGS, &posting_key)?;
                } else {
                    tx.put(TABLE_POSTINGS, &posting_key, &posting_list.to_bytes())?;
                }
            }
        }

        tx.delete(TABLE_POINT_TOKENS, &point_tokens_key)?;

        meta.remove_document(token_ids.len());
        tx.put(TABLE_META, &meta_key, &meta.to_bytes())?;

        tx.commit()?;
        Ok(true)
    }

    /// Replace the indexed vector for `point_id` (delete, then insert).
    ///
    /// NOTE(review): this runs as two separate transactions; a crash between
    /// them leaves the point unindexed. Confirm that is acceptable, or fold
    /// both steps into one transaction.
    ///
    /// # Errors
    /// Propagates failures from [`Self::delete`] / [`Self::insert`].
    pub fn update(
        &self,
        collection: &str,
        vector_name: &str,
        point_id: PointId,
        vector: &[(u32, f32)],
    ) -> Result<(), VectorError> {
        self.delete(collection, vector_name, point_id)?;
        self.insert(collection, vector_name, point_id, vector)
    }

    /// Delete every inverted index (postings, token mappings, metadata) for
    /// all vector names under `collection`, in one transaction.
    ///
    /// # Errors
    /// Propagates storage failures as [`VectorError`].
    pub fn delete_collection(&self, collection: &str) -> Result<(), VectorError> {
        let mut tx = self.engine.begin_write()?;

        delete_by_prefix(&mut tx, TABLE_POSTINGS, &encode_posting_collection_prefix(collection))?;

        delete_by_prefix(
            &mut tx,
            TABLE_POINT_TOKENS,
            &encode_point_tokens_collection_prefix(collection),
        )?;

        delete_by_prefix(&mut tx, TABLE_META, &encode_inverted_meta_collection_prefix(collection))?;

        tx.commit()?;
        Ok(())
    }

    /// Delete the inverted index for one (collection, vector) pair.
    ///
    /// # Errors
    /// Propagates storage failures as [`VectorError`].
    pub fn delete_vector(&self, collection: &str, vector_name: &str) -> Result<(), VectorError> {
        let mut tx = self.engine.begin_write()?;

        delete_by_prefix(&mut tx, TABLE_POSTINGS, &encode_posting_prefix(collection, vector_name))?;

        delete_by_prefix(
            &mut tx,
            TABLE_POINT_TOKENS,
            &encode_point_tokens_prefix(collection, vector_name),
        )?;

        let meta_key = encode_inverted_meta_key(collection, vector_name);
        tx.delete(TABLE_META, &meta_key)?;

        tx.commit()?;
        Ok(())
    }

    /// Fetch the index metadata for a (collection, vector) pair.
    ///
    /// # Errors
    /// [`VectorError::SpaceNotFound`] when no metadata exists; otherwise
    /// storage/deserialization failures.
    pub fn get_meta(
        &self,
        collection: &str,
        vector_name: &str,
    ) -> Result<InvertedIndexMeta, VectorError> {
        let tx = self.engine.begin_read()?;
        let meta_key = encode_inverted_meta_key(collection, vector_name);
        tx.get(TABLE_META, &meta_key)?
            .map(|bytes| InvertedIndexMeta::from_bytes(&bytes))
            .transpose()?
            .ok_or_else(|| {
                VectorError::SpaceNotFound(format!("index '{}/{}'", collection, vector_name))
            })
    }

    /// Fetch a single token's posting list, or `None` if the token is not
    /// indexed.
    ///
    /// # Errors
    /// Propagates storage/deserialization failures as [`VectorError`].
    pub fn get_posting_list(
        &self,
        collection: &str,
        vector_name: &str,
        token_id: u32,
    ) -> Result<Option<PostingList>, VectorError> {
        let tx = self.engine.begin_read()?;
        let posting_key = encode_posting_key(collection, vector_name, token_id);
        tx.get(TABLE_POSTINGS, &posting_key)?
            .map(|bytes| PostingList::from_bytes(&bytes))
            .transpose()
    }

    /// Exhaustive search: accumulate every candidate's score term by term in
    /// a hash map, then sort descending by score and keep `top_k`.
    ///
    /// Index metadata is loaded only when BM25 scoring is requested; with
    /// missing metadata BM25 degrades to dot-product per term (see
    /// `compute_bm25_term_score`). Empty query or `top_k == 0` returns an
    /// empty result set.
    ///
    /// # Errors
    /// Propagates storage/deserialization failures as [`VectorError`].
    pub fn search_daat(
        &self,
        collection: &str,
        vector_name: &str,
        query: &[(u32, f32)],
        top_k: usize,
        scoring: ScoringFunction,
    ) -> Result<Vec<SearchResult>, VectorError> {
        if query.is_empty() || top_k == 0 {
            return Ok(Vec::new());
        }

        let tx = self.engine.begin_read()?;

        // Only BM25 needs corpus statistics.
        let meta = if matches!(scoring, ScoringFunction::Bm25 { .. }) {
            let meta_key = encode_inverted_meta_key(collection, vector_name);
            tx.get(TABLE_META, &meta_key)?
                .map(|bytes| InvertedIndexMeta::from_bytes(&bytes))
                .transpose()?
        } else {
            None
        };

        // Load the posting list for each query term, skipping absent/empty ones.
        let mut posting_lists: Vec<(u32, f32, PostingList)> = Vec::with_capacity(query.len());
        for &(token_id, query_weight) in query {
            let posting_key = encode_posting_key(collection, vector_name, token_id);
            if let Some(bytes) = tx.get(TABLE_POSTINGS, &posting_key)? {
                let posting_list = PostingList::from_bytes(&bytes)?;
                if !posting_list.is_empty() {
                    posting_lists.push((token_id, query_weight, posting_list));
                }
            }
        }

        if posting_lists.is_empty() {
            return Ok(Vec::new());
        }

        // Accumulate per-document scores across all terms.
        let mut scores: HashMap<u64, f32> = HashMap::new();

        for (token_id, query_weight, posting_list) in &posting_lists {
            for entry in posting_list.entries() {
                let doc_id = entry.point_id.as_u64();
                let term_score = match scoring {
                    ScoringFunction::DotProduct => query_weight * entry.weight,
                    ScoringFunction::Bm25 { k1_times_10, b_times_100 } => {
                        // De-quantize the stored parameters.
                        let k1 = k1_times_10 as f32 / 10.0;
                        let b = b_times_100 as f32 / 100.0;
                        compute_bm25_term_score(
                            *query_weight,
                            entry.weight,
                            meta.as_ref(),
                            *token_id,
                            posting_list.len(), // document frequency
                            k1,
                            b,
                        )
                    }
                };
                *scores.entry(doc_id).or_insert(0.0) += term_score;
            }
        }

        let mut results: Vec<SearchResult> = scores
            .into_iter()
            .map(|(doc_id, score)| SearchResult::new(PointId::new(doc_id), score))
            .collect();

        // Highest score first; NaN-safe comparison.
        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
        results.truncate(top_k);

        Ok(results)
    }

    /// WAND top-k search with dot-product scoring.
    ///
    /// Maintains one cursor per query term (each carrying the upper bound
    /// `query_weight * max_weight`) and a min-heap of the current top-k.
    /// Documents whose summed upper bounds cannot reach the current threshold
    /// (the k-th best score) are skipped without being fully scored, so the
    /// result matches [`Self::search_daat`] with
    /// [`ScoringFunction::DotProduct`] at lower cost.
    ///
    /// # Errors
    /// Propagates storage/deserialization failures as [`VectorError`].
    pub fn search_wand(
        &self,
        collection: &str,
        vector_name: &str,
        query: &[(u32, f32)],
        top_k: usize,
    ) -> Result<Vec<SearchResult>, VectorError> {
        if query.is_empty() || top_k == 0 {
            return Ok(Vec::new());
        }

        let tx = self.engine.begin_read()?;

        // One cursor per query term that has a non-empty posting list.
        let mut cursors: Vec<WandCursor> = Vec::with_capacity(query.len());
        for &(token_id, query_weight) in query {
            let posting_key = encode_posting_key(collection, vector_name, token_id);
            if let Some(bytes) = tx.get(TABLE_POSTINGS, &posting_key)? {
                let posting_list = PostingList::from_bytes(&bytes)?;
                if !posting_list.is_empty() {
                    // Best possible contribution of this term to any document.
                    let upper_bound = query_weight * posting_list.max_weight();
                    cursors.push(WandCursor::new(posting_list, query_weight, upper_bound));
                }
            }
        }

        if cursors.is_empty() {
            return Ok(Vec::new());
        }

        // Min-heap on score (see `Ord for SearchResult`): the root is the
        // weakest of the current top-k, i.e. the pruning threshold.
        let mut heap: BinaryHeap<SearchResult> = BinaryHeap::with_capacity(top_k + 1);
        let mut threshold = 0.0f32;

        loop {
            // Order cursors by their current doc id; exhausted cursors report
            // u64::MAX and therefore sort last.
            cursors.sort_by_key(|c| c.current_doc_id());

            let first_valid = cursors.iter().position(|c| !c.exhausted());
            if first_valid.is_none() {
                break;
            }

            // Find the pivot: the first cursor at which the cumulative upper
            // bounds reach the threshold. Documents before the pivot's doc id
            // cannot make it into the top-k.
            let mut upper_sum = 0.0f32;
            let mut pivot_idx = None;

            for (i, cursor) in cursors.iter().enumerate() {
                if cursor.exhausted() {
                    continue;
                }
                upper_sum += cursor.upper_bound;
                if upper_sum >= threshold {
                    pivot_idx = Some(i);
                    break;
                }
            }

            let pivot_idx = match pivot_idx {
                Some(idx) => idx,
                // Even the sum of all upper bounds is below the threshold:
                // no remaining document can qualify.
                None => break,
            };

            let pivot_doc_id = cursors[pivot_idx].current_doc_id();

            // If every cursor before the pivot already sits on the pivot
            // document, it can be fully scored now.
            let all_aligned = cursors[..pivot_idx]
                .iter()
                .filter(|c| !c.exhausted())
                .all(|c| c.current_doc_id() == pivot_doc_id);

            if all_aligned || pivot_idx == 0 {
                // Score the pivot document across all terms that contain it.
                let mut score = 0.0f32;
                for cursor in &cursors {
                    if !cursor.exhausted() && cursor.current_doc_id() == pivot_doc_id {
                        if let Some(entry) = cursor.current_entry() {
                            score += cursor.query_weight * entry.weight;
                        }
                    }
                }

                if score > threshold || heap.len() < top_k {
                    heap.push(SearchResult::new(PointId::new(pivot_doc_id), score));
                    if heap.len() > top_k {
                        // Evict the weakest candidate (min-heap root).
                        heap.pop();
                    }
                    if heap.len() == top_k {
                        threshold = heap.peek().map_or(0.0, |r| r.score);
                    }
                }

                // Move every cursor past the document just scored.
                for cursor in &mut cursors {
                    if !cursor.exhausted() && cursor.current_doc_id() == pivot_doc_id {
                        cursor.advance();
                    }
                }
            } else {
                // Skip the leading cursors forward to the pivot document;
                // everything they pointed at below it is provably pruned.
                for cursor in &mut cursors[..pivot_idx] {
                    if !cursor.exhausted() {
                        cursor.advance_to(pivot_doc_id);
                    }
                }
            }

            cursors.retain(|c| !c.exhausted());
            if cursors.is_empty() {
                break;
            }
        }

        // Heap order is score-ascending; emit best-first.
        let mut results: Vec<SearchResult> = heap.into_vec();
        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
        Ok(results)
    }

    /// MaxScore-style search entry point.
    ///
    /// Currently delegates to [`Self::search_wand`]; kept as a separate
    /// method so a dedicated MaxScore implementation can slot in later
    /// without changing callers.
    pub fn search_maxscore(
        &self,
        collection: &str,
        vector_name: &str,
        query: &[(u32, f32)],
        top_k: usize,
    ) -> Result<Vec<SearchResult>, VectorError> {
        self.search_wand(collection, vector_name, query, top_k)
    }
}
816
/// Iteration state over one query term's posting list during WAND search.
struct WandCursor {
    // Entries sorted by point id (guaranteed by `PostingList`).
    posting_list: PostingList,
    // Index of the current entry; equals the list length once exhausted.
    position: usize,
    // Query-side weight of this term.
    query_weight: f32,
    // query_weight * max doc weight: ceiling on this term's contribution.
    upper_bound: f32,
}
824
825impl WandCursor {
826 fn new(posting_list: PostingList, query_weight: f32, upper_bound: f32) -> Self {
827 Self { posting_list, position: 0, query_weight, upper_bound }
828 }
829
830 fn exhausted(&self) -> bool {
831 self.position >= self.posting_list.len()
832 }
833
834 fn current_doc_id(&self) -> u64 {
835 if self.exhausted() {
836 u64::MAX
837 } else {
838 self.posting_list.entries()[self.position].point_id.as_u64()
839 }
840 }
841
842 fn current_entry(&self) -> Option<&PostingEntry> {
843 if self.exhausted() {
844 None
845 } else {
846 Some(&self.posting_list.entries()[self.position])
847 }
848 }
849
850 fn advance(&mut self) {
851 if !self.exhausted() {
852 self.position += 1;
853 }
854 }
855
856 fn advance_to(&mut self, doc_id: u64) {
857 while !self.exhausted() && self.current_doc_id() < doc_id {
858 self.position += 1;
859 }
860 }
861}
862
/// Compute one term's BM25 contribution for a single document.
///
/// `df` is the term's document frequency (its posting-list length); `k1` and
/// `b` are the usual BM25 free parameters. Without metadata the score
/// degrades to a plain dot-product term; with an empty index it is 0.0.
fn compute_bm25_term_score(
    query_weight: f32,
    doc_weight: f32,
    meta: Option<&InvertedIndexMeta>,
    _token_id: u32,
    df: usize,
    k1: f32,
    b: f32,
) -> f32 {
    let meta = match meta {
        Some(m) => m,
        // No corpus statistics available: fall back to dot-product scoring.
        None => return query_weight * doc_weight,
    };

    if meta.doc_count == 0 {
        return 0.0;
    }

    let n = meta.doc_count as f32;
    let df = df as f32;
    // idf = ln(1 + (N - df + 0.5) / (df + 0.5)) — the "+1 inside the log"
    // (Lucene-style) variant, which keeps idf non-negative.
    let idf = ((n - df + 0.5) / (df + 0.5)).ln_1p();

    // NOTE(review): both the term frequency and the document length are
    // approximated by the stored sparse weight — the true token count /
    // document length is not available here. Confirm this approximation is
    // intended; `avg_doc_length` is clamped to >= 1 to avoid division blowup.
    let tf = doc_weight;
    let avg_dl = meta.avg_doc_length.max(1.0);
    let dl = doc_weight;

    let tf_component = (tf * (k1 + 1.0)) / (tf + k1 * (1.0 - b + b * (dl / avg_dl)));

    query_weight * idf * tf_component
}
898
/// Serialize token ids as a little-endian u32 count followed by each id as a
/// little-endian u32.
fn encode_token_ids(token_ids: &[u32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(4 + token_ids.len() * 4);
    bytes.extend_from_slice(&(token_ids.len() as u32).to_le_bytes());
    bytes.extend(token_ids.iter().flat_map(|id| id.to_le_bytes()));
    bytes
}
908
909fn decode_token_ids(bytes: &[u8]) -> Result<Vec<u32>, VectorError> {
911 if bytes.len() < 4 {
912 return Err(VectorError::Encoding("token ids too short".to_string()));
913 }
914
915 let count = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
916 let expected_len = 4 + count * 4;
917
918 if bytes.len() != expected_len {
919 return Err(VectorError::Encoding(format!(
920 "token ids length mismatch: expected {}, got {}",
921 expected_len,
922 bytes.len()
923 )));
924 }
925
926 let mut token_ids = Vec::with_capacity(count);
927 for i in 0..count {
928 let offset = 4 + i * 4;
929 let token_id = u32::from_le_bytes([
930 bytes[offset],
931 bytes[offset + 1],
932 bytes[offset + 2],
933 bytes[offset + 3],
934 ]);
935 token_ids.push(token_id);
936 }
937
938 Ok(token_ids)
939}
940
/// Smallest byte string strictly greater than every key starting with
/// `prefix`, for use as an exclusive range upper bound.
///
/// Trailing `0xFF` bytes are dropped before incrementing: the successor of
/// `[0x01, 0xFF]` is `[0x02]`, NOT `[0x02, 0xFF]`. (Merely incrementing the
/// last non-0xFF byte, as the previous version did, produced a bound like
/// `[0x02, 0xFF]` that also covered keys such as `[0x02, 0x00]` which do not
/// carry the prefix — causing `delete_by_prefix` to over-delete.)
///
/// When the prefix is empty or all `0xFF`, no finite exclusive bound exists;
/// we fall back to `prefix + 0xFF`, which under-covers (some `0xFF…` keys
/// are missed) but never touches keys outside the prefix.
fn next_prefix(prefix: &[u8]) -> Vec<u8> {
    let mut end = prefix.to_vec();

    while let Some(&last) = end.last() {
        if last == 0xFF {
            // A maxed-out byte cannot be incremented; carry into the
            // preceding byte by dropping it.
            end.pop();
        } else {
            *end.last_mut().expect("checked non-empty above") += 1;
            return end;
        }
    }

    // Empty or all-0xFF prefix: keep the historical (under-covering) fallback.
    let mut fallback = prefix.to_vec();
    fallback.push(0xFF);
    fallback
}
955
956fn delete_by_prefix<T: Transaction>(
958 tx: &mut T,
959 table: &str,
960 prefix: &[u8],
961) -> Result<(), VectorError> {
962 let prefix_end = next_prefix(prefix);
963
964 let mut keys_to_delete = Vec::new();
965 {
966 let mut cursor =
967 tx.range(table, Bound::Included(prefix), Bound::Excluded(prefix_end.as_slice()))?;
968
969 while let Some((key, _)) = cursor.next()? {
970 keys_to_delete.push(key);
971 }
972 }
973
974 for key in keys_to_delete {
975 tx.delete(table, &key)?;
976 }
977
978 Ok(())
979}
980
#[cfg(test)]
mod tests {
    use super::*;
    use manifoldb_storage::backends::RedbEngine;
    use std::sync::atomic::{AtomicUsize, Ordering};

    // Monotonic counter so every test gets a unique collection name even
    // when tests share an engine or run concurrently.
    static TEST_COUNTER: AtomicUsize = AtomicUsize::new(0);

    fn create_test_index() -> InvertedIndex<RedbEngine> {
        let engine = RedbEngine::in_memory().unwrap();
        InvertedIndex::new(engine)
    }

    fn unique_name(prefix: &str) -> String {
        let count = TEST_COUNTER.fetch_add(1, Ordering::SeqCst);
        format!("{}_{}", prefix, count)
    }

    // Serialization round-trip preserves count, max weight, and sorted order.
    #[test]
    fn posting_list_roundtrip() {
        let mut list = PostingList::new();
        list.add(PostingEntry::new(PointId::new(1), 0.5));
        list.add(PostingEntry::new(PointId::new(3), 0.3));
        list.add(PostingEntry::new(PointId::new(2), 0.8));

        let bytes = list.to_bytes();
        let restored = PostingList::from_bytes(&bytes).unwrap();

        assert_eq!(restored.len(), 3);
        assert!((restored.max_weight() - 0.8).abs() < 1e-6);

        // Entries come back sorted by point id regardless of insertion order.
        assert_eq!(restored.entries()[0].point_id, PointId::new(1));
        assert_eq!(restored.entries()[1].point_id, PointId::new(2));
        assert_eq!(restored.entries()[2].point_id, PointId::new(3));
    }

    // Removing the max-weight entry must recompute max_weight; removing a
    // missing id reports false.
    #[test]
    fn posting_list_remove() {
        let mut list = PostingList::new();
        list.add(PostingEntry::new(PointId::new(1), 0.5));
        list.add(PostingEntry::new(PointId::new(2), 0.8));
        list.add(PostingEntry::new(PointId::new(3), 0.3));

        assert!(list.remove(PointId::new(2)));
        assert_eq!(list.len(), 2);
        assert!((list.max_weight() - 0.5).abs() < 1e-6);

        // Second removal of the same id is a no-op.
        assert!(!list.remove(PointId::new(2)));
    }

    // Dot-product DAAT search ranks the highest-scoring document first.
    #[test]
    fn insert_and_search() {
        let index = create_test_index();
        let collection = unique_name("collection");
        let vector = "keywords";

        index.insert(&collection, vector, PointId::new(1), &[(100, 0.5), (200, 0.3)]).unwrap();
        index.insert(&collection, vector, PointId::new(2), &[(100, 0.8), (300, 0.2)]).unwrap();
        index.insert(&collection, vector, PointId::new(3), &[(200, 0.6), (300, 0.4)]).unwrap();

        let query = vec![(100, 1.0), (200, 0.5)];
        let results = index
            .search_daat(&collection, vector, &query, 10, ScoringFunction::DotProduct)
            .unwrap();

        assert!(!results.is_empty());
        // Point 2 scores 0.8 on token 100, the largest contribution.
        assert_eq!(results[0].point_id, PointId::new(2));
    }

    // A deleted point disappears from subsequent searches.
    #[test]
    fn delete_document() {
        let index = create_test_index();
        let collection = unique_name("collection");
        let vector = "keywords";

        index.insert(&collection, vector, PointId::new(1), &[(100, 0.5)]).unwrap();
        index.insert(&collection, vector, PointId::new(2), &[(100, 0.8)]).unwrap();

        let results = index
            .search_daat(&collection, vector, &[(100, 1.0)], 10, ScoringFunction::DotProduct)
            .unwrap();
        assert_eq!(results.len(), 2);

        assert!(index.delete(&collection, vector, PointId::new(1)).unwrap());

        let results = index
            .search_daat(&collection, vector, &[(100, 1.0)], 10, ScoringFunction::DotProduct)
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].point_id, PointId::new(2));
    }

    // Update removes the old token postings and indexes the new ones.
    #[test]
    fn update_document() {
        let index = create_test_index();
        let collection = unique_name("collection");
        let vector = "keywords";

        index.insert(&collection, vector, PointId::new(1), &[(100, 0.5)]).unwrap();

        index.update(&collection, vector, PointId::new(1), &[(200, 0.9)]).unwrap();

        // Old token no longer matches…
        let results = index
            .search_daat(&collection, vector, &[(100, 1.0)], 10, ScoringFunction::DotProduct)
            .unwrap();
        assert!(results.is_empty());

        // …and the new token does.
        let results = index
            .search_daat(&collection, vector, &[(200, 1.0)], 10, ScoringFunction::DotProduct)
            .unwrap();
        assert_eq!(results.len(), 1);
    }

    // WAND must return exactly the same top-k as exhaustive DAAT.
    #[test]
    fn wand_search() {
        let index = create_test_index();
        let collection = unique_name("collection");
        let vector = "keywords";

        for i in 0..100 {
            let weight = (i as f32 + 1.0) / 100.0;
            index
                .insert(&collection, vector, PointId::new(i), &[(100, weight), (200, weight * 0.5)])
                .unwrap();
        }

        let results = index.search_wand(&collection, vector, &[(100, 1.0), (200, 0.5)], 5).unwrap();

        assert_eq!(results.len(), 5);
        for i in 0..4 {
            assert!(
                results[i].score >= results[i + 1].score,
                "Results should be sorted by score: {} >= {}",
                results[i].score,
                results[i + 1].score
            );
        }

        // Cross-check against the exhaustive scorer.
        let daat_results = index
            .search_daat(
                &collection,
                vector,
                &[(100, 1.0), (200, 0.5)],
                5,
                ScoringFunction::DotProduct,
            )
            .unwrap();

        assert_eq!(results.len(), daat_results.len());
        for (wand_r, daat_r) in results.iter().zip(daat_results.iter()) {
            assert_eq!(wand_r.point_id, daat_r.point_id);
            assert!((wand_r.score - daat_r.score).abs() < 1e-5);
        }
    }

    // Metadata counters track inserts and deletes.
    #[test]
    fn metadata_tracking() {
        let index = create_test_index();
        let collection = unique_name("collection");
        let vector = "keywords";

        index.insert(&collection, vector, PointId::new(1), &[(100, 0.5), (200, 0.3)]).unwrap();
        index.insert(&collection, vector, PointId::new(2), &[(100, 0.8)]).unwrap();

        let meta = index.get_meta(&collection, vector).unwrap();
        assert_eq!(meta.doc_count, 2);
        assert_eq!(meta.total_tokens, 3);
        assert!((meta.avg_doc_length - 1.5).abs() < 0.01);

        index.delete(&collection, vector, PointId::new(1)).unwrap();

        let meta = index.get_meta(&collection, vector).unwrap();
        assert_eq!(meta.doc_count, 1);
        assert_eq!(meta.total_tokens, 1);
    }

    // BM25 preserves the ranking of a higher-weight document here.
    #[test]
    fn bm25_scoring() {
        let index = create_test_index();
        let collection = unique_name("collection");
        let vector = "keywords";

        index.insert(&collection, vector, PointId::new(1), &[(100, 0.5)]).unwrap();
        index.insert(&collection, vector, PointId::new(2), &[(100, 0.8)]).unwrap();

        let results = index
            .search_daat(&collection, vector, &[(100, 1.0)], 10, ScoringFunction::bm25())
            .unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].point_id, PointId::new(2));
    }

    // Empty queries short-circuit to empty results for both strategies.
    #[test]
    fn empty_query() {
        let index = create_test_index();
        let collection = unique_name("collection");
        let vector = "keywords";

        index.insert(&collection, vector, PointId::new(1), &[(100, 0.5)]).unwrap();

        let results =
            index.search_daat(&collection, vector, &[], 10, ScoringFunction::DotProduct).unwrap();
        assert!(results.is_empty());

        let results = index.search_wand(&collection, vector, &[], 10).unwrap();
        assert!(results.is_empty());
    }

    // Querying an unindexed token yields no results.
    #[test]
    fn no_matching_tokens() {
        let index = create_test_index();
        let collection = unique_name("collection");
        let vector = "keywords";

        index.insert(&collection, vector, PointId::new(1), &[(100, 0.5)]).unwrap();

        let results = index
            .search_daat(&collection, vector, &[(999, 1.0)], 10, ScoringFunction::DotProduct)
            .unwrap();
        assert!(results.is_empty());
    }

    // delete_vector removes one named index but leaves siblings intact.
    #[test]
    fn delete_vector_index() {
        let index = create_test_index();
        let collection = unique_name("collection");

        index.insert(&collection, "v1", PointId::new(1), &[(100, 0.5)]).unwrap();
        index.insert(&collection, "v2", PointId::new(1), &[(100, 0.8)]).unwrap();

        index.delete_vector(&collection, "v1").unwrap();

        assert!(index.get_meta(&collection, "v1").is_err());

        assert!(index.get_meta(&collection, "v2").is_ok());
    }
}