manifoldb_vector/store/
inverted_index.rs

//! Inverted index for sparse vector similarity search.
//!
//! This module provides an inverted index implementation for SPLADE-style sparse vectors,
//! enabling efficient top-k retrieval with exhaustive (DAAT) and WAND search.
//!
//! # Storage Tables
//!
//! - `inverted_postings`: Posting lists mapping token_id → [(point_id, weight), ...]
//! - `inverted_meta`: Index metadata (document count and token statistics)
//! - `inverted_point_tokens`: Reverse mapping point_id → [token_ids], used for deletion
//!
//! # Search Algorithms
//!
//! - **DAAT (`search_daat`)**: Exact scoring of every document that matches a query token
//! - **WAND (Weak AND, `search_wand`)**: Top-k retrieval that skips documents whose score
//!   upper bound cannot enter the current top-k
//!
//! # Scoring Functions
//!
//! - **Dot product**: Standard for SPLADE vectors (used by every search method)
//! - **BM25-style**: Optional IDF weighting with length normalization (`search_daat` only)
//!
//! # Example
//!
//! ```ignore
//! use manifoldb_vector::store::InvertedIndex;
//! use manifoldb_core::PointId;
//!
//! let index = InvertedIndex::new(engine);
//!
//! // Index a sparse vector
//! let vector = vec![(100, 0.5), (200, 0.3), (300, 0.2)];
//! index.insert("documents", "keywords", PointId::new(1), &vector)?;
//!
//! // Search for similar vectors
//! let query = vec![(100, 1.0), (200, 0.8)];
//! let results = index.search_wand("documents", "keywords", &query, 10)?;
//! ```

39use std::cmp::Ordering;
40use std::collections::{BinaryHeap, HashMap};
41use std::ops::Bound;
42
43use manifoldb_core::PointId;
44use manifoldb_storage::{Cursor, StorageEngine, Transaction};
45
46use crate::encoding::{
47    encode_inverted_meta_collection_prefix, encode_inverted_meta_key,
48    encode_point_tokens_collection_prefix, encode_point_tokens_key, encode_point_tokens_prefix,
49    encode_posting_collection_prefix, encode_posting_key, encode_posting_prefix,
50};
51use crate::error::VectorError;
52
/// Storage table holding one posting list per (collection, vector name, token ID).
const TABLE_POSTINGS: &str = "inverted_postings";

/// Storage table holding the statistics of each (collection, vector name) index.
const TABLE_META: &str = "inverted_meta";

/// Storage table mapping each indexed point to its token IDs, so deletion can find its postings.
const TABLE_POINT_TOKENS: &str = "inverted_point_tokens";
61
62/// A posting list entry: (point_id, weight).
63#[derive(Debug, Clone, Copy, PartialEq)]
64pub struct PostingEntry {
65    /// The point ID.
66    pub point_id: PointId,
67    /// The weight (term frequency or TF-IDF weight).
68    pub weight: f32,
69}
70
71impl PostingEntry {
72    /// Create a new posting entry.
73    #[must_use]
74    pub const fn new(point_id: PointId, weight: f32) -> Self {
75        Self { point_id, weight }
76    }
77}
78
79/// A posting list for a single token.
80#[derive(Debug, Clone, Default)]
81pub struct PostingList {
82    /// Entries sorted by point_id for efficient merging.
83    entries: Vec<PostingEntry>,
84    /// Maximum weight in this posting list (for WAND upper bound).
85    max_weight: f32,
86}
87
88impl PostingList {
89    /// Create an empty posting list.
90    #[must_use]
91    pub const fn new() -> Self {
92        Self { entries: Vec::new(), max_weight: 0.0 }
93    }
94
95    /// Create a posting list from entries.
96    #[must_use]
97    pub fn from_entries(mut entries: Vec<PostingEntry>) -> Self {
98        entries.sort_by_key(|e| e.point_id.as_u64());
99        let max_weight = entries.iter().map(|e| e.weight).fold(0.0f32, f32::max);
100        Self { entries, max_weight }
101    }
102
103    /// Get the entries.
104    #[must_use]
105    pub fn entries(&self) -> &[PostingEntry] {
106        &self.entries
107    }
108
109    /// Get the maximum weight.
110    #[must_use]
111    pub fn max_weight(&self) -> f32 {
112        self.max_weight
113    }
114
115    /// Get the number of entries.
116    #[must_use]
117    pub fn len(&self) -> usize {
118        self.entries.len()
119    }
120
121    /// Check if empty.
122    #[must_use]
123    pub fn is_empty(&self) -> bool {
124        self.entries.is_empty()
125    }
126
127    /// Add an entry, maintaining sort order.
128    pub fn add(&mut self, entry: PostingEntry) {
129        match self.entries.binary_search_by_key(&entry.point_id.as_u64(), |e| e.point_id.as_u64()) {
130            Ok(idx) => {
131                // Update existing entry
132                self.entries[idx] = entry;
133            }
134            Err(idx) => {
135                // Insert new entry
136                self.entries.insert(idx, entry);
137            }
138        }
139        self.max_weight = self.max_weight.max(entry.weight);
140    }
141
142    /// Remove an entry by point_id.
143    pub fn remove(&mut self, point_id: PointId) -> bool {
144        match self.entries.binary_search_by_key(&point_id.as_u64(), |e| e.point_id.as_u64()) {
145            Ok(idx) => {
146                let removed = self.entries.remove(idx);
147                // Recalculate max_weight if we removed the max
148                if (removed.weight - self.max_weight).abs() < f32::EPSILON {
149                    self.max_weight = self.entries.iter().map(|e| e.weight).fold(0.0f32, f32::max);
150                }
151                true
152            }
153            Err(_) => false,
154        }
155    }
156
157    /// Serialize to bytes.
158    ///
159    /// Format: [count: u32][max_weight: f32][(point_id: u64, weight: f32), ...]
160    #[must_use]
161    pub fn to_bytes(&self) -> Vec<u8> {
162        let mut bytes = Vec::with_capacity(8 + self.entries.len() * 12);
163        bytes.extend_from_slice(&(self.entries.len() as u32).to_le_bytes());
164        bytes.extend_from_slice(&self.max_weight.to_le_bytes());
165        for entry in &self.entries {
166            bytes.extend_from_slice(&entry.point_id.as_u64().to_le_bytes());
167            bytes.extend_from_slice(&entry.weight.to_le_bytes());
168        }
169        bytes
170    }
171
172    /// Deserialize from bytes.
173    pub fn from_bytes(bytes: &[u8]) -> Result<Self, VectorError> {
174        if bytes.len() < 8 {
175            return Err(VectorError::Encoding("posting list too short".to_string()));
176        }
177
178        let count = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
179        let max_weight = f32::from_le_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]);
180
181        let expected_len = 8 + count * 12;
182        if bytes.len() != expected_len {
183            return Err(VectorError::Encoding(format!(
184                "posting list length mismatch: expected {}, got {}",
185                expected_len,
186                bytes.len()
187            )));
188        }
189
190        let mut entries = Vec::with_capacity(count);
191        for i in 0..count {
192            let offset = 8 + i * 12;
193            let point_id = u64::from_le_bytes([
194                bytes[offset],
195                bytes[offset + 1],
196                bytes[offset + 2],
197                bytes[offset + 3],
198                bytes[offset + 4],
199                bytes[offset + 5],
200                bytes[offset + 6],
201                bytes[offset + 7],
202            ]);
203            let weight = f32::from_le_bytes([
204                bytes[offset + 8],
205                bytes[offset + 9],
206                bytes[offset + 10],
207                bytes[offset + 11],
208            ]);
209            entries.push(PostingEntry::new(PointId::new(point_id), weight));
210        }
211
212        Ok(Self { entries, max_weight })
213    }
214}
215
216/// Inverted index metadata for a vector.
217#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
218pub struct InvertedIndexMeta {
219    /// Number of documents indexed.
220    pub doc_count: u64,
221    /// Total number of tokens across all documents.
222    pub total_tokens: u64,
223    /// Average document length (number of non-zero tokens).
224    pub avg_doc_length: f32,
225}
226
227impl InvertedIndexMeta {
228    /// Create new metadata.
229    #[must_use]
230    pub const fn new() -> Self {
231        Self { doc_count: 0, total_tokens: 0, avg_doc_length: 0.0 }
232    }
233
234    /// Update statistics after adding a document.
235    pub fn add_document(&mut self, token_count: usize) {
236        self.doc_count += 1;
237        self.total_tokens += token_count as u64;
238        self.avg_doc_length = self.total_tokens as f32 / self.doc_count as f32;
239    }
240
241    /// Update statistics after removing a document.
242    pub fn remove_document(&mut self, token_count: usize) {
243        if self.doc_count > 0 {
244            self.doc_count -= 1;
245            self.total_tokens = self.total_tokens.saturating_sub(token_count as u64);
246            if self.doc_count > 0 {
247                self.avg_doc_length = self.total_tokens as f32 / self.doc_count as f32;
248            } else {
249                self.avg_doc_length = 0.0;
250            }
251        }
252    }
253
254    /// Serialize to bytes.
255    #[must_use]
256    pub fn to_bytes(&self) -> Vec<u8> {
257        bincode::serde::encode_to_vec(self, bincode::config::standard()).unwrap_or_default()
258    }
259
260    /// Deserialize from bytes.
261    pub fn from_bytes(bytes: &[u8]) -> Result<Self, VectorError> {
262        bincode::serde::decode_from_slice(bytes, bincode::config::standard())
263            .map(|(v, _)| v)
264            .map_err(|e| VectorError::Encoding(format!("failed to deserialize index meta: {}", e)))
265    }
266}
267
268impl Default for InvertedIndexMeta {
269    fn default() -> Self {
270        Self::new()
271    }
272}
273
274/// Scoring function for sparse vector similarity.
275#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
276pub enum ScoringFunction {
277    /// Dot product: sum of query_weight * doc_weight for matching tokens.
278    DotProduct,
279    /// BM25-style scoring with length normalization.
280    Bm25 {
281        /// BM25 k1 parameter (default: 1.2).
282        k1_times_10: u8,
283        /// BM25 b parameter (default: 0.75).
284        b_times_100: u8,
285    },
286}
287
288impl Default for ScoringFunction {
289    fn default() -> Self {
290        Self::DotProduct
291    }
292}
293
294impl ScoringFunction {
295    /// Create BM25 scoring with default parameters.
296    #[must_use]
297    pub const fn bm25() -> Self {
298        Self::Bm25 { k1_times_10: 12, b_times_100: 75 }
299    }
300
301    /// Create BM25 scoring with custom parameters.
302    #[must_use]
303    pub fn bm25_custom(k1: f32, b: f32) -> Self {
304        Self::Bm25 {
305            k1_times_10: (k1 * 10.0).clamp(0.0, 255.0) as u8,
306            b_times_100: (b * 100.0).clamp(0.0, 255.0) as u8,
307        }
308    }
309}
310
311/// A search result with point ID and score.
312#[derive(Debug, Clone, Copy)]
313pub struct SearchResult {
314    /// The point ID.
315    pub point_id: PointId,
316    /// The similarity score.
317    pub score: f32,
318}
319
320impl SearchResult {
321    /// Create a new search result.
322    #[must_use]
323    pub const fn new(point_id: PointId, score: f32) -> Self {
324        Self { point_id, score }
325    }
326}
327
328impl PartialEq for SearchResult {
329    fn eq(&self, other: &Self) -> bool {
330        self.point_id == other.point_id
331    }
332}
333
334impl Eq for SearchResult {}
335
336impl PartialOrd for SearchResult {
337    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
338        Some(self.cmp(other))
339    }
340}
341
342impl Ord for SearchResult {
343    fn cmp(&self, other: &Self) -> Ordering {
344        // Reverse order for min-heap (we want to pop smallest scores)
345        other.score.partial_cmp(&self.score).unwrap_or(Ordering::Equal)
346    }
347}
348
349/// Inverted index for sparse vector similarity search.
350pub struct InvertedIndex<E: StorageEngine> {
351    engine: E,
352}
353
354impl<E: StorageEngine> InvertedIndex<E> {
355    /// Create a new inverted index with the given storage engine.
356    #[must_use]
357    pub const fn new(engine: E) -> Self {
358        Self { engine }
359    }
360
361    /// Get a reference to the storage engine.
362    #[must_use]
363    pub fn engine(&self) -> &E {
364        &self.engine
365    }
366
367    // ========================================================================
368    // Index operations
369    // ========================================================================
370
371    /// Insert a sparse vector into the index.
372    ///
373    /// # Arguments
374    ///
375    /// * `collection` - Collection name
376    /// * `vector_name` - Vector name within the collection
377    /// * `point_id` - Point ID
378    /// * `vector` - Sparse vector as (token_id, weight) pairs (must be sorted)
379    pub fn insert(
380        &self,
381        collection: &str,
382        vector_name: &str,
383        point_id: PointId,
384        vector: &[(u32, f32)],
385    ) -> Result<(), VectorError> {
386        if vector.is_empty() {
387            return Ok(());
388        }
389
390        let mut tx = self.engine.begin_write()?;
391
392        // Load or create metadata
393        let meta_key = encode_inverted_meta_key(collection, vector_name);
394        let mut meta = tx
395            .get(TABLE_META, &meta_key)?
396            .map(|bytes| InvertedIndexMeta::from_bytes(&bytes))
397            .transpose()?
398            .unwrap_or_default();
399
400        // Store token IDs for this point (for deletion)
401        let token_ids: Vec<u32> = vector.iter().map(|(idx, _)| *idx).collect();
402        let point_tokens_key = encode_point_tokens_key(collection, vector_name, point_id);
403        tx.put(TABLE_POINT_TOKENS, &point_tokens_key, &encode_token_ids(&token_ids))?;
404
405        // Add to posting lists
406        for &(token_id, weight) in vector {
407            let posting_key = encode_posting_key(collection, vector_name, token_id);
408
409            // Load existing posting list or create new
410            let mut posting_list = tx
411                .get(TABLE_POSTINGS, &posting_key)?
412                .map(|bytes| PostingList::from_bytes(&bytes))
413                .transpose()?
414                .unwrap_or_default();
415
416            posting_list.add(PostingEntry::new(point_id, weight));
417            tx.put(TABLE_POSTINGS, &posting_key, &posting_list.to_bytes())?;
418        }
419
420        // Update metadata
421        meta.add_document(vector.len());
422        tx.put(TABLE_META, &meta_key, &meta.to_bytes())?;
423
424        tx.commit()?;
425        Ok(())
426    }
427
428    /// Delete a sparse vector from the index.
429    ///
430    /// # Returns
431    ///
432    /// Returns `Ok(true)` if the vector was deleted, `Ok(false)` if it wasn't indexed.
433    pub fn delete(
434        &self,
435        collection: &str,
436        vector_name: &str,
437        point_id: PointId,
438    ) -> Result<bool, VectorError> {
439        let mut tx = self.engine.begin_write()?;
440
441        // Get token IDs for this point
442        let point_tokens_key = encode_point_tokens_key(collection, vector_name, point_id);
443        let token_ids = match tx.get(TABLE_POINT_TOKENS, &point_tokens_key)? {
444            Some(bytes) => decode_token_ids(&bytes)?,
445            None => return Ok(false),
446        };
447
448        // Load metadata
449        let meta_key = encode_inverted_meta_key(collection, vector_name);
450        let mut meta = tx
451            .get(TABLE_META, &meta_key)?
452            .map(|bytes| InvertedIndexMeta::from_bytes(&bytes))
453            .transpose()?
454            .unwrap_or_default();
455
456        // Remove from posting lists
457        for token_id in &token_ids {
458            let posting_key = encode_posting_key(collection, vector_name, *token_id);
459
460            if let Some(bytes) = tx.get(TABLE_POSTINGS, &posting_key)? {
461                let mut posting_list = PostingList::from_bytes(&bytes)?;
462                posting_list.remove(point_id);
463
464                if posting_list.is_empty() {
465                    tx.delete(TABLE_POSTINGS, &posting_key)?;
466                } else {
467                    tx.put(TABLE_POSTINGS, &posting_key, &posting_list.to_bytes())?;
468                }
469            }
470        }
471
472        // Delete point tokens mapping
473        tx.delete(TABLE_POINT_TOKENS, &point_tokens_key)?;
474
475        // Update metadata
476        meta.remove_document(token_ids.len());
477        tx.put(TABLE_META, &meta_key, &meta.to_bytes())?;
478
479        tx.commit()?;
480        Ok(true)
481    }
482
483    /// Update a sparse vector in the index (delete + insert).
484    pub fn update(
485        &self,
486        collection: &str,
487        vector_name: &str,
488        point_id: PointId,
489        vector: &[(u32, f32)],
490    ) -> Result<(), VectorError> {
491        // Delete existing
492        self.delete(collection, vector_name, point_id)?;
493        // Insert new
494        self.insert(collection, vector_name, point_id, vector)
495    }
496
497    /// Delete all index data for a collection.
498    pub fn delete_collection(&self, collection: &str) -> Result<(), VectorError> {
499        let mut tx = self.engine.begin_write()?;
500
501        // Delete all posting lists
502        delete_by_prefix(&mut tx, TABLE_POSTINGS, &encode_posting_collection_prefix(collection))?;
503
504        // Delete all point tokens
505        delete_by_prefix(
506            &mut tx,
507            TABLE_POINT_TOKENS,
508            &encode_point_tokens_collection_prefix(collection),
509        )?;
510
511        // Delete all metadata
512        delete_by_prefix(&mut tx, TABLE_META, &encode_inverted_meta_collection_prefix(collection))?;
513
514        tx.commit()?;
515        Ok(())
516    }
517
518    /// Delete all index data for a specific vector in a collection.
519    pub fn delete_vector(&self, collection: &str, vector_name: &str) -> Result<(), VectorError> {
520        let mut tx = self.engine.begin_write()?;
521
522        // Delete all posting lists for this vector
523        delete_by_prefix(&mut tx, TABLE_POSTINGS, &encode_posting_prefix(collection, vector_name))?;
524
525        // Delete all point tokens for this vector
526        delete_by_prefix(
527            &mut tx,
528            TABLE_POINT_TOKENS,
529            &encode_point_tokens_prefix(collection, vector_name),
530        )?;
531
532        // Delete metadata
533        let meta_key = encode_inverted_meta_key(collection, vector_name);
534        tx.delete(TABLE_META, &meta_key)?;
535
536        tx.commit()?;
537        Ok(())
538    }
539
540    // ========================================================================
541    // Query operations
542    // ========================================================================
543
544    /// Get index metadata.
545    pub fn get_meta(
546        &self,
547        collection: &str,
548        vector_name: &str,
549    ) -> Result<InvertedIndexMeta, VectorError> {
550        let tx = self.engine.begin_read()?;
551        let meta_key = encode_inverted_meta_key(collection, vector_name);
552        tx.get(TABLE_META, &meta_key)?
553            .map(|bytes| InvertedIndexMeta::from_bytes(&bytes))
554            .transpose()?
555            .ok_or_else(|| {
556                VectorError::SpaceNotFound(format!("index '{}/{}'", collection, vector_name))
557            })
558    }
559
560    /// Get a posting list for a specific token.
561    pub fn get_posting_list(
562        &self,
563        collection: &str,
564        vector_name: &str,
565        token_id: u32,
566    ) -> Result<Option<PostingList>, VectorError> {
567        let tx = self.engine.begin_read()?;
568        let posting_key = encode_posting_key(collection, vector_name, token_id);
569        tx.get(TABLE_POSTINGS, &posting_key)?
570            .map(|bytes| PostingList::from_bytes(&bytes))
571            .transpose()
572    }
573
574    // ========================================================================
575    // Search algorithms
576    // ========================================================================
577
578    /// Search using DAAT (Document-at-a-time) algorithm.
579    ///
580    /// This is exact scoring that traverses all posting lists to compute
581    /// the full similarity score for each candidate document.
582    ///
583    /// # Arguments
584    ///
585    /// * `collection` - Collection name
586    /// * `vector_name` - Vector name
587    /// * `query` - Query vector as (token_id, weight) pairs
588    /// * `top_k` - Number of results to return
589    /// * `scoring` - Scoring function to use
590    pub fn search_daat(
591        &self,
592        collection: &str,
593        vector_name: &str,
594        query: &[(u32, f32)],
595        top_k: usize,
596        scoring: ScoringFunction,
597    ) -> Result<Vec<SearchResult>, VectorError> {
598        if query.is_empty() || top_k == 0 {
599            return Ok(Vec::new());
600        }
601
602        let tx = self.engine.begin_read()?;
603
604        // Load metadata for BM25
605        let meta = if matches!(scoring, ScoringFunction::Bm25 { .. }) {
606            let meta_key = encode_inverted_meta_key(collection, vector_name);
607            tx.get(TABLE_META, &meta_key)?
608                .map(|bytes| InvertedIndexMeta::from_bytes(&bytes))
609                .transpose()?
610        } else {
611            None
612        };
613
614        // Load posting lists for all query tokens
615        let mut posting_lists: Vec<(u32, f32, PostingList)> = Vec::with_capacity(query.len());
616        for &(token_id, query_weight) in query {
617            let posting_key = encode_posting_key(collection, vector_name, token_id);
618            if let Some(bytes) = tx.get(TABLE_POSTINGS, &posting_key)? {
619                let posting_list = PostingList::from_bytes(&bytes)?;
620                if !posting_list.is_empty() {
621                    posting_lists.push((token_id, query_weight, posting_list));
622                }
623            }
624        }
625
626        if posting_lists.is_empty() {
627            return Ok(Vec::new());
628        }
629
630        // Accumulate scores for all documents
631        let mut scores: HashMap<u64, f32> = HashMap::new();
632
633        for (token_id, query_weight, posting_list) in &posting_lists {
634            for entry in posting_list.entries() {
635                let doc_id = entry.point_id.as_u64();
636                let term_score = match scoring {
637                    ScoringFunction::DotProduct => query_weight * entry.weight,
638                    ScoringFunction::Bm25 { k1_times_10, b_times_100 } => {
639                        let k1 = k1_times_10 as f32 / 10.0;
640                        let b = b_times_100 as f32 / 100.0;
641                        compute_bm25_term_score(
642                            *query_weight,
643                            entry.weight,
644                            meta.as_ref(),
645                            *token_id,
646                            posting_list.len(),
647                            k1,
648                            b,
649                        )
650                    }
651                };
652                *scores.entry(doc_id).or_insert(0.0) += term_score;
653            }
654        }
655
656        // Get top-k results
657        let mut results: Vec<SearchResult> = scores
658            .into_iter()
659            .map(|(doc_id, score)| SearchResult::new(PointId::new(doc_id), score))
660            .collect();
661
662        // Sort by score descending
663        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
664        results.truncate(top_k);
665
666        Ok(results)
667    }
668
669    /// Search using WAND (Weak AND) algorithm.
670    ///
671    /// This is an optimized top-k search that uses upper bound scores
672    /// to skip documents that cannot make it into the result set.
673    ///
674    /// # Arguments
675    ///
676    /// * `collection` - Collection name
677    /// * `vector_name` - Vector name
678    /// * `query` - Query vector as (token_id, weight) pairs
679    /// * `top_k` - Number of results to return
680    pub fn search_wand(
681        &self,
682        collection: &str,
683        vector_name: &str,
684        query: &[(u32, f32)],
685        top_k: usize,
686    ) -> Result<Vec<SearchResult>, VectorError> {
687        if query.is_empty() || top_k == 0 {
688            return Ok(Vec::new());
689        }
690
691        let tx = self.engine.begin_read()?;
692
693        // Load posting lists for all query tokens with their upper bounds
694        let mut cursors: Vec<WandCursor> = Vec::with_capacity(query.len());
695        for &(token_id, query_weight) in query {
696            let posting_key = encode_posting_key(collection, vector_name, token_id);
697            if let Some(bytes) = tx.get(TABLE_POSTINGS, &posting_key)? {
698                let posting_list = PostingList::from_bytes(&bytes)?;
699                if !posting_list.is_empty() {
700                    let upper_bound = query_weight * posting_list.max_weight();
701                    cursors.push(WandCursor::new(posting_list, query_weight, upper_bound));
702                }
703            }
704        }
705
706        if cursors.is_empty() {
707            return Ok(Vec::new());
708        }
709
710        // WAND algorithm
711        let mut heap: BinaryHeap<SearchResult> = BinaryHeap::with_capacity(top_k + 1);
712        let mut threshold = 0.0f32;
713
714        loop {
715            // Sort cursors by current document ID
716            cursors.sort_by_key(|c| c.current_doc_id());
717
718            // Skip exhausted cursors and find first valid one
719            let first_valid = cursors.iter().position(|c| !c.exhausted());
720            if first_valid.is_none() {
721                break;
722            }
723
724            // Find pivot: smallest index where sum of upper bounds >= threshold
725            let mut upper_sum = 0.0f32;
726            let mut pivot_idx = None;
727
728            for (i, cursor) in cursors.iter().enumerate() {
729                if cursor.exhausted() {
730                    continue;
731                }
732                upper_sum += cursor.upper_bound;
733                if upper_sum >= threshold {
734                    pivot_idx = Some(i);
735                    break;
736                }
737            }
738
739            let pivot_idx = match pivot_idx {
740                Some(idx) => idx,
741                None => break, // No more candidates can beat threshold
742            };
743
744            let pivot_doc_id = cursors[pivot_idx].current_doc_id();
745
746            // Check if all non-exhausted cursors before pivot point to the same document
747            let all_aligned = cursors[..pivot_idx]
748                .iter()
749                .filter(|c| !c.exhausted())
750                .all(|c| c.current_doc_id() == pivot_doc_id);
751
752            if all_aligned || pivot_idx == 0 {
753                // Score this document - include all cursors pointing to this doc
754                let mut score = 0.0f32;
755                for cursor in &cursors {
756                    if !cursor.exhausted() && cursor.current_doc_id() == pivot_doc_id {
757                        if let Some(entry) = cursor.current_entry() {
758                            score += cursor.query_weight * entry.weight;
759                        }
760                    }
761                }
762
763                if score > threshold || heap.len() < top_k {
764                    heap.push(SearchResult::new(PointId::new(pivot_doc_id), score));
765                    if heap.len() > top_k {
766                        heap.pop();
767                    }
768                    if heap.len() == top_k {
769                        threshold = heap.peek().map_or(0.0, |r| r.score);
770                    }
771                }
772
773                // Advance all cursors past this document
774                for cursor in &mut cursors {
775                    if !cursor.exhausted() && cursor.current_doc_id() == pivot_doc_id {
776                        cursor.advance();
777                    }
778                }
779            } else {
780                // Advance cursors before pivot to pivot_doc_id
781                for cursor in &mut cursors[..pivot_idx] {
782                    if !cursor.exhausted() {
783                        cursor.advance_to(pivot_doc_id);
784                    }
785                }
786            }
787
788            // Remove exhausted cursors
789            cursors.retain(|c| !c.exhausted());
790            if cursors.is_empty() {
791                break;
792            }
793        }
794
795        // Extract results from heap and sort by score descending
796        let mut results: Vec<SearchResult> = heap.into_vec();
797        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
798        Ok(results)
799    }
800
801    /// Search using MaxScore algorithm (optimized WAND variant).
802    ///
803    /// Similar to WAND but more aggressive at skipping low-scoring documents.
804    pub fn search_maxscore(
805        &self,
806        collection: &str,
807        vector_name: &str,
808        query: &[(u32, f32)],
809        top_k: usize,
810    ) -> Result<Vec<SearchResult>, VectorError> {
811        // For now, delegate to WAND. MaxScore can be implemented later
812        // as an optimization with block-max indices.
813        self.search_wand(collection, vector_name, query, top_k)
814    }
815}
816
817/// Cursor for WAND algorithm traversal.
818struct WandCursor {
819    posting_list: PostingList,
820    position: usize,
821    query_weight: f32,
822    upper_bound: f32,
823}
824
825impl WandCursor {
826    fn new(posting_list: PostingList, query_weight: f32, upper_bound: f32) -> Self {
827        Self { posting_list, position: 0, query_weight, upper_bound }
828    }
829
830    fn exhausted(&self) -> bool {
831        self.position >= self.posting_list.len()
832    }
833
834    fn current_doc_id(&self) -> u64 {
835        if self.exhausted() {
836            u64::MAX
837        } else {
838            self.posting_list.entries()[self.position].point_id.as_u64()
839        }
840    }
841
842    fn current_entry(&self) -> Option<&PostingEntry> {
843        if self.exhausted() {
844            None
845        } else {
846            Some(&self.posting_list.entries()[self.position])
847        }
848    }
849
850    fn advance(&mut self) {
851        if !self.exhausted() {
852            self.position += 1;
853        }
854    }
855
856    fn advance_to(&mut self, doc_id: u64) {
857        while !self.exhausted() && self.current_doc_id() < doc_id {
858            self.position += 1;
859        }
860    }
861}
862
863/// Compute BM25 term score.
864fn compute_bm25_term_score(
865    query_weight: f32,
866    doc_weight: f32,
867    meta: Option<&InvertedIndexMeta>,
868    _token_id: u32,
869    df: usize,
870    k1: f32,
871    b: f32,
872) -> f32 {
873    let meta = match meta {
874        Some(m) => m,
875        None => return query_weight * doc_weight, // Fallback to dot product
876    };
877
878    if meta.doc_count == 0 {
879        return 0.0;
880    }
881
882    // IDF component: log((N - df + 0.5) / (df + 0.5))
883    let n = meta.doc_count as f32;
884    let df = df as f32;
885    let idf = ((n - df + 0.5) / (df + 0.5)).ln_1p();
886
887    // TF component with length normalization
888    // For sparse vectors, we use the weight as a proxy for term frequency
889    let tf = doc_weight;
890    let avg_dl = meta.avg_doc_length.max(1.0);
891    // Assume document length is proportional to the weight
892    let dl = doc_weight;
893
894    let tf_component = (tf * (k1 + 1.0)) / (tf + k1 * (1.0 - b + b * (dl / avg_dl)));
895
896    query_weight * idf * tf_component
897}
898
/// Serialize token IDs as a little-endian `u32` count followed by each ID as a
/// little-endian `u32`.
fn encode_token_ids(token_ids: &[u32]) -> Vec<u8> {
    let count = token_ids.len() as u32;
    let mut out = Vec::with_capacity(4 * (token_ids.len() + 1));
    out.extend_from_slice(&count.to_le_bytes());
    out.extend(token_ids.iter().flat_map(|id| id.to_le_bytes()));
    out
}
908
909/// Decode token IDs from bytes.
910fn decode_token_ids(bytes: &[u8]) -> Result<Vec<u32>, VectorError> {
911    if bytes.len() < 4 {
912        return Err(VectorError::Encoding("token ids too short".to_string()));
913    }
914
915    let count = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
916    let expected_len = 4 + count * 4;
917
918    if bytes.len() != expected_len {
919        return Err(VectorError::Encoding(format!(
920            "token ids length mismatch: expected {}, got {}",
921            expected_len,
922            bytes.len()
923        )));
924    }
925
926    let mut token_ids = Vec::with_capacity(count);
927    for i in 0..count {
928        let offset = 4 + i * 4;
929        let token_id = u32::from_le_bytes([
930            bytes[offset],
931            bytes[offset + 1],
932            bytes[offset + 2],
933            bytes[offset + 3],
934        ]);
935        token_ids.push(token_id);
936    }
937
938    Ok(token_ids)
939}
940
/// Calculate the exclusive upper bound for a range scan over all keys starting
/// with `prefix`.
///
/// Trailing `0xFF` bytes cannot be incremented, so they are dropped and the last
/// remaining byte is incremented. For example, `[0x01, 0xFF]` becomes `[0x02]`.
/// Keeping the trailing `0xFF` (`[0x02, 0xFF]`) would wrongly admit keys such as
/// `[0x02, 0x00]` that do not share the prefix.
///
/// If `prefix` is empty or consists only of `0xFF` bytes, no finite upper bound
/// exists. The function then returns `prefix` followed by a single `0xFF`. That
/// bound still misses prefixed keys whose next byte is also `0xFF`.
fn next_prefix(prefix: &[u8]) -> Vec<u8> {
    match prefix.iter().rposition(|&byte| byte != 0xFF) {
        Some(last) => {
            let mut bound = prefix[..=last].to_vec();
            bound[last] += 1;
            bound
        }
        None => {
            // NOTE(review): a correct scan here needs `Bound::Unbounded`; this
            // fallback keeps the historical return value for callers.
            let mut bound = prefix.to_vec();
            bound.push(0xFF);
            bound
        }
    }
}
955
956/// Delete all keys matching a prefix.
957fn delete_by_prefix<T: Transaction>(
958    tx: &mut T,
959    table: &str,
960    prefix: &[u8],
961) -> Result<(), VectorError> {
962    let prefix_end = next_prefix(prefix);
963
964    let mut keys_to_delete = Vec::new();
965    {
966        let mut cursor =
967            tx.range(table, Bound::Included(prefix), Bound::Excluded(prefix_end.as_slice()))?;
968
969        while let Some((key, _)) = cursor.next()? {
970            keys_to_delete.push(key);
971        }
972    }
973
974    for key in keys_to_delete {
975        tx.delete(table, &key)?;
976    }
977
978    Ok(())
979}
980
#[cfg(test)]
mod tests {
    use super::*;
    use manifoldb_storage::backends::RedbEngine;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Process-wide counter used by `unique_name` so tests never share a collection.
    static TEST_COUNTER: AtomicUsize = AtomicUsize::new(0);

    /// Builds an index backed by a fresh in-memory redb engine.
    fn create_test_index() -> InvertedIndex<RedbEngine> {
        let engine = RedbEngine::in_memory().unwrap();
        InvertedIndex::new(engine)
    }

    /// Returns `"{prefix}_{n}"` where `n` is unique within this test process.
    fn unique_name(prefix: &str) -> String {
        let count = TEST_COUNTER.fetch_add(1, Ordering::SeqCst);
        format!("{}_{}", prefix, count)
    }

    #[test]
    fn posting_list_roundtrip() {
        let mut list = PostingList::new();
        list.add(PostingEntry::new(PointId::new(1), 0.5));
        list.add(PostingEntry::new(PointId::new(3), 0.3));
        list.add(PostingEntry::new(PointId::new(2), 0.8));

        let bytes = list.to_bytes();
        let restored = PostingList::from_bytes(&bytes).unwrap();

        assert_eq!(restored.len(), 3);
        assert!((restored.max_weight() - 0.8).abs() < 1e-6);

        // Should be sorted by point_id
        assert_eq!(restored.entries()[0].point_id, PointId::new(1));
        assert_eq!(restored.entries()[1].point_id, PointId::new(2));
        assert_eq!(restored.entries()[2].point_id, PointId::new(3));
    }

    #[test]
    fn posting_list_remove() {
        let mut list = PostingList::new();
        list.add(PostingEntry::new(PointId::new(1), 0.5));
        list.add(PostingEntry::new(PointId::new(2), 0.8));
        list.add(PostingEntry::new(PointId::new(3), 0.3));

        assert!(list.remove(PointId::new(2)));
        assert_eq!(list.len(), 2);
        assert!((list.max_weight() - 0.5).abs() < 1e-6);

        assert!(!list.remove(PointId::new(2))); // Already removed
    }

    #[test]
    fn token_ids_roundtrip() {
        let ids = [7u32, 0, u32::MAX, 42];
        let bytes = encode_token_ids(&ids);
        assert_eq!(bytes.len(), 4 + ids.len() * 4);
        assert_eq!(decode_token_ids(&bytes).unwrap(), ids);

        let empty = encode_token_ids(&[]);
        assert!(decode_token_ids(&empty).unwrap().is_empty());
    }

    #[test]
    fn token_ids_reject_malformed() {
        // Shorter than the 4-byte count header.
        assert!(decode_token_ids(&[]).is_err());
        assert!(decode_token_ids(&[1, 0, 0]).is_err());
        // Header claims one id but the payload is missing.
        assert!(decode_token_ids(&[1, 0, 0, 0]).is_err());
        // Trailing byte after a valid (empty) payload.
        assert!(decode_token_ids(&[0, 0, 0, 0, 9]).is_err());
    }

    #[test]
    fn insert_and_search() {
        let index = create_test_index();
        let collection = unique_name("collection");
        let vector = "keywords";

        // Insert some documents
        index.insert(&collection, vector, PointId::new(1), &[(100, 0.5), (200, 0.3)]).unwrap();
        index.insert(&collection, vector, PointId::new(2), &[(100, 0.8), (300, 0.2)]).unwrap();
        index.insert(&collection, vector, PointId::new(3), &[(200, 0.6), (300, 0.4)]).unwrap();

        // Search with DAAT
        let query = [(100, 1.0), (200, 0.5)];
        let results = index
            .search_daat(&collection, vector, &query, 10, ScoringFunction::DotProduct)
            .unwrap();

        // Expected dot-product scores, highest first:
        //   point 2: 0.8*1.0           = 0.80
        //   point 1: 0.5*1.0 + 0.3*0.5 = 0.65
        //   point 3: 0.6*0.5           = 0.30
        let expected = [(2, 0.8), (1, 0.65), (3, 0.3)];
        assert_eq!(results.len(), expected.len());
        for (result, &(id, score)) in results.iter().zip(expected.iter()) {
            assert_eq!(result.point_id, PointId::new(id));
            assert!(
                (result.score - score).abs() < 1e-5,
                "point {:?} scored {}, expected {}",
                result.point_id,
                result.score,
                score
            );
        }
    }

    #[test]
    fn delete_document() {
        let index = create_test_index();
        let collection = unique_name("collection");
        let vector = "keywords";

        index.insert(&collection, vector, PointId::new(1), &[(100, 0.5)]).unwrap();
        index.insert(&collection, vector, PointId::new(2), &[(100, 0.8)]).unwrap();

        // Verify both exist
        let results = index
            .search_daat(&collection, vector, &[(100, 1.0)], 10, ScoringFunction::DotProduct)
            .unwrap();
        assert_eq!(results.len(), 2);

        // Delete one
        assert!(index.delete(&collection, vector, PointId::new(1)).unwrap());

        // Verify only one remains
        let results = index
            .search_daat(&collection, vector, &[(100, 1.0)], 10, ScoringFunction::DotProduct)
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].point_id, PointId::new(2));
    }

    #[test]
    fn update_document() {
        let index = create_test_index();
        let collection = unique_name("collection");
        let vector = "keywords";

        index.insert(&collection, vector, PointId::new(1), &[(100, 0.5)]).unwrap();

        // Update with new vector
        index.update(&collection, vector, PointId::new(1), &[(200, 0.9)]).unwrap();

        // Old token should not match
        let results = index
            .search_daat(&collection, vector, &[(100, 1.0)], 10, ScoringFunction::DotProduct)
            .unwrap();
        assert!(results.is_empty());

        // New token should match
        let results = index
            .search_daat(&collection, vector, &[(200, 1.0)], 10, ScoringFunction::DotProduct)
            .unwrap();
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn wand_search() {
        let index = create_test_index();
        let collection = unique_name("collection");
        let vector = "keywords";

        // Insert many documents with 2 tokens to make WAND more effective
        for i in 0..100 {
            let weight = (i as f32 + 1.0) / 100.0;
            index
                .insert(&collection, vector, PointId::new(i), &[(100, weight), (200, weight * 0.5)])
                .unwrap();
        }

        let query = [(100, 1.0), (200, 0.5)];

        // WAND search for top 5
        let results = index.search_wand(&collection, vector, &query, 5).unwrap();

        assert_eq!(results.len(), 5);
        // Results should be sorted by score descending
        for pair in results.windows(2) {
            assert!(
                pair[0].score >= pair[1].score,
                "Results should be sorted by score: {} >= {}",
                pair[0].score,
                pair[1].score
            );
        }

        // Compare with DAAT to ensure correctness
        let daat_results = index
            .search_daat(&collection, vector, &query, 5, ScoringFunction::DotProduct)
            .unwrap();

        // Both should return the same result set (same point IDs in same order)
        assert_eq!(results.len(), daat_results.len());
        for (wand_r, daat_r) in results.iter().zip(daat_results.iter()) {
            assert_eq!(wand_r.point_id, daat_r.point_id);
            assert!((wand_r.score - daat_r.score).abs() < 1e-5);
        }
    }

    #[test]
    fn metadata_tracking() {
        let index = create_test_index();
        let collection = unique_name("collection");
        let vector = "keywords";

        index.insert(&collection, vector, PointId::new(1), &[(100, 0.5), (200, 0.3)]).unwrap();
        index.insert(&collection, vector, PointId::new(2), &[(100, 0.8)]).unwrap();

        let meta = index.get_meta(&collection, vector).unwrap();
        assert_eq!(meta.doc_count, 2);
        assert_eq!(meta.total_tokens, 3);
        assert!((meta.avg_doc_length - 1.5).abs() < 0.01);

        index.delete(&collection, vector, PointId::new(1)).unwrap();

        let meta = index.get_meta(&collection, vector).unwrap();
        assert_eq!(meta.doc_count, 1);
        assert_eq!(meta.total_tokens, 1);
    }

    #[test]
    fn bm25_scoring() {
        let index = create_test_index();
        let collection = unique_name("collection");
        let vector = "keywords";

        index.insert(&collection, vector, PointId::new(1), &[(100, 0.5)]).unwrap();
        index.insert(&collection, vector, PointId::new(2), &[(100, 0.8)]).unwrap();

        let results = index
            .search_daat(&collection, vector, &[(100, 1.0)], 10, ScoringFunction::bm25())
            .unwrap();

        assert_eq!(results.len(), 2);
        // BM25 should still rank doc 2 higher due to higher weight
        assert_eq!(results[0].point_id, PointId::new(2));
    }

    #[test]
    fn empty_query() {
        let index = create_test_index();
        let collection = unique_name("collection");
        let vector = "keywords";

        index.insert(&collection, vector, PointId::new(1), &[(100, 0.5)]).unwrap();

        let results =
            index.search_daat(&collection, vector, &[], 10, ScoringFunction::DotProduct).unwrap();
        assert!(results.is_empty());

        let results = index.search_wand(&collection, vector, &[], 10).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn no_matching_tokens() {
        let index = create_test_index();
        let collection = unique_name("collection");
        let vector = "keywords";

        index.insert(&collection, vector, PointId::new(1), &[(100, 0.5)]).unwrap();

        // Query for a token that doesn't exist
        let results = index
            .search_daat(&collection, vector, &[(999, 1.0)], 10, ScoringFunction::DotProduct)
            .unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn delete_vector_index() {
        let index = create_test_index();
        let collection = unique_name("collection");

        index.insert(&collection, "v1", PointId::new(1), &[(100, 0.5)]).unwrap();
        index.insert(&collection, "v2", PointId::new(1), &[(100, 0.8)]).unwrap();

        // Delete v1 index
        index.delete_vector(&collection, "v1").unwrap();

        // v1 should be gone
        assert!(index.get_meta(&collection, "v1").is_err());

        // v2 should still exist
        assert!(index.get_meta(&collection, "v2").is_ok());
    }
}