Skip to main content

aprender_rag/multivector/
index.rs

1//! WARP index with IVF structure
2//!
3//! This module implements the WARP index which organizes compressed token
4//! embeddings by centroid for cache-efficient search. The index supports:
5//!
6//! - Training from sample embeddings
7//! - Incremental insertion of documents
8//! - Building (compacting) for efficient search
9//! - MaxSim-based multi-vector search
10
11use crate::multivector::{
12    codec::ResidualCodec,
13    search::{CandidateScorer, CentroidSelector, ScoreMerger},
14    types::{MultiVectorEmbedding, WarpIndexConfig, WarpSearchConfig},
15};
16use crate::{Chunk, ChunkId, Result};
17use serde::{Deserialize, Serialize};
18use std::collections::HashMap;
19
20/// WARP index for efficient multi-vector retrieval.
21///
22/// The index organizes token embeddings by centroid assignment (IVF structure)
23/// for cache-efficient access during search. Each token embedding is stored as:
24/// - Centroid assignment
25/// - Quantized residual (2-4 bits per dimension)
26///
27/// # Lifecycle
28///
29/// 1. Create index with `new(config)`
30/// 2. Train codec with `train(samples)`
31/// 3. Insert documents with `insert(chunk, embedding)`
32/// 4. Build index with `build()` (compacts for efficient search)
33/// 5. Search with `search(query, config)`
34///
35/// # Memory Layout
36///
37/// After `build()`, data is organized by centroid:
38/// ```text
39/// Centroid 0: [chunk_ids...] [token_indices...] [residuals...]
40/// Centroid 1: [chunk_ids...] [token_indices...] [residuals...]
41/// ...
42/// ```
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct WarpIndex {
45    /// Index configuration
46    config: WarpIndexConfig,
47    /// Trained residual codec (None until trained)
48    codec: Option<ResidualCodec>,
49    /// Number of embeddings per centroid
50    sizes: Vec<usize>,
51    /// Cumulative offset for each centroid's data
52    offsets: Vec<usize>,
53    /// Chunk IDs, sorted by centroid assignment
54    chunk_ids: Vec<ChunkId>,
55    /// Token indices within each chunk
56    token_indices: Vec<u16>,
57    /// Packed residuals, sorted by centroid
58    residuals: Vec<u8>,
59    /// Original chunks for result retrieval
60    #[serde(skip)]
61    chunks: HashMap<ChunkId, Chunk>,
62    /// Pending embeddings before build
63    #[serde(skip)]
64    pending: Vec<(ChunkId, MultiVectorEmbedding)>,
65    /// Whether the index has been built
66    is_built: bool,
67}
68
69impl WarpIndex {
70    /// Create a new WARP index with the given configuration.
71    #[must_use]
72    pub fn new(config: WarpIndexConfig) -> Self {
73        Self {
74            config,
75            codec: None,
76            sizes: Vec::new(),
77            offsets: Vec::new(),
78            chunk_ids: Vec::new(),
79            token_indices: Vec::new(),
80            residuals: Vec::new(),
81            chunks: HashMap::new(),
82            pending: Vec::new(),
83            is_built: false,
84        }
85    }
86
87    /// Get the index configuration.
88    #[must_use]
89    pub fn config(&self) -> &WarpIndexConfig {
90        &self.config
91    }
92
93    /// Get the trained codec (if any).
94    #[must_use]
95    pub fn codec(&self) -> Option<&ResidualCodec> {
96        self.codec.as_ref()
97    }
98
99    /// Check if the codec has been trained.
100    #[must_use]
101    pub fn is_trained(&self) -> bool {
102        self.codec.is_some()
103    }
104
105    /// Check if the index has been built.
106    #[must_use]
107    pub fn is_built(&self) -> bool {
108        self.is_built
109    }
110
111    /// Get the number of indexed chunks.
112    #[must_use]
113    pub fn num_chunks(&self) -> usize {
114        self.chunks.len()
115    }
116
117    /// Get the number of indexed tokens.
118    #[must_use]
119    pub fn num_tokens(&self) -> usize {
120        self.chunk_ids.len()
121    }
122
123    /// Check if the index is empty.
124    #[must_use]
125    pub fn is_empty(&self) -> bool {
126        self.chunks.is_empty()
127    }
128
129    /// Get a chunk by ID.
130    #[must_use]
131    pub fn get_chunk(&self, id: &ChunkId) -> Option<&Chunk> {
132        self.chunks.get(id)
133    }
134
135    /// Get memory usage in bytes (approximate).
136    #[must_use]
137    pub fn memory_usage(&self) -> usize {
138        let codec_size = self
139            .codec
140            .as_ref()
141            .map(|c| {
142                c.centroids().len() * 4 // centroids
143                    + c.dim() * ((1 << c.nbits()) - 1) * 4 // cutoffs
144                    + c.dim() * (1 << c.nbits()) * 4 // weights
145            })
146            .unwrap_or(0);
147
148        let index_size = self.chunk_ids.len() * size_of::<ChunkId>()
149            + self.token_indices.len() * size_of::<u16>()
150            + self.residuals.len()
151            + self.sizes.len() * size_of::<usize>()
152            + self.offsets.len() * size_of::<usize>();
153
154        codec_size + index_size
155    }
156
157    /// Train the codec from sample embeddings.
158    ///
159    /// # Arguments
160    ///
161    /// * `samples` - Sample multi-vector embeddings for training
162    ///
163    /// # Errors
164    ///
165    /// Returns an error if:
166    /// - Not enough samples for training
167    /// - Configuration is invalid
168    pub fn train(&mut self, samples: &[MultiVectorEmbedding]) -> Result<()> {
169        // Collect all token embeddings
170        let total_tokens: usize = samples.iter().map(|s| s.num_tokens()).sum();
171        let min_samples = self.config.effective_min_training_samples();
172
173        if total_tokens < min_samples {
174            return Err(crate::Error::InvalidInput(format!(
175                "Insufficient training tokens: {total_tokens} < {min_samples} required"
176            )));
177        }
178
179        // Flatten all embeddings
180        let mut all_embeddings = Vec::with_capacity(total_tokens * self.config.token_dim);
181        for sample in samples {
182            all_embeddings.extend_from_slice(sample.as_slice());
183        }
184
185        // Train codec
186        let codec = ResidualCodec::train(
187            &all_embeddings,
188            self.config.token_dim,
189            self.config.num_centroids,
190            self.config.nbits,
191            self.config.kmeans_iterations,
192        )?;
193
194        self.codec = Some(codec);
195        Ok(())
196    }
197
198    /// Insert a chunk with its token embeddings.
199    ///
200    /// The chunk will be stored in pending state until `build()` is called.
201    ///
202    /// # Errors
203    ///
204    /// Returns an error if:
205    /// - Codec has not been trained
206    /// - Index has already been built (call `rebuild()` first)
207    pub fn insert(&mut self, chunk: Chunk, embedding: MultiVectorEmbedding) -> Result<()> {
208        if self.codec.is_none() {
209            return Err(crate::Error::InvalidInput(
210                "Codec not trained - call train() first".to_string(),
211            ));
212        }
213
214        if self.is_built {
215            return Err(crate::Error::InvalidInput(
216                "Index already built - cannot insert".to_string(),
217            ));
218        }
219
220        // Contract: embedding-algebra-v1.yaml precondition (pv codegen)
221        contract_pre_embedding_lookup!(embedding.as_slice());
222
223        let chunk_id = chunk.id;
224        self.chunks.insert(chunk_id, chunk);
225        self.pending.push((chunk_id, embedding));
226
227        Ok(())
228    }
229
230    /// Build the index for efficient search.
231    ///
232    /// This compacts all pending embeddings into a centroid-organized
233    /// IVF structure optimized for cache-efficient search.
234    ///
235    /// # Errors
236    ///
237    /// Returns an error if the codec has not been trained.
238    pub fn build(&mut self) -> Result<()> {
239        let codec = self.codec.as_ref().ok_or_else(|| {
240            crate::Error::InvalidInput("Codec not trained - call train() first".to_string())
241        })?;
242
243        // Assign each token to its nearest centroid
244        let mut centroid_assignments: Vec<Vec<(ChunkId, u16, Vec<u8>)>> =
245            vec![Vec::new(); self.config.num_centroids];
246
247        for (chunk_id, embedding) in &self.pending {
248            for (token_idx, token) in embedding.tokens().enumerate() {
249                let (centroid_id, residual) = codec.compress(token);
250                centroid_assignments[centroid_id].push((*chunk_id, token_idx as u16, residual));
251            }
252        }
253
254        // Build compacted arrays
255        let bytes_per_residual = self.config.packed_residual_size();
256
257        self.sizes = centroid_assignments.iter().map(|v| v.len()).collect();
258        self.offsets = self
259            .sizes
260            .iter()
261            .scan(0, |acc, &size| {
262                let offset = *acc;
263                *acc += size;
264                Some(offset)
265            })
266            .collect();
267
268        let total_tokens: usize = self.sizes.iter().sum();
269        self.chunk_ids = Vec::with_capacity(total_tokens);
270        self.token_indices = Vec::with_capacity(total_tokens);
271        self.residuals = Vec::with_capacity(total_tokens * bytes_per_residual);
272
273        for assignments in centroid_assignments {
274            for (chunk_id, token_idx, residual) in assignments {
275                self.chunk_ids.push(chunk_id);
276                self.token_indices.push(token_idx);
277                self.residuals.extend(residual);
278            }
279        }
280
281        self.pending.clear();
282        self.is_built = true;
283
284        Ok(())
285    }
286
287    /// Clear the built index to allow new insertions.
288    ///
289    /// Chunks are preserved, but the IVF structure is cleared.
290    /// Call `build()` again after inserting new chunks.
291    pub fn clear_index(&mut self) {
292        self.sizes.clear();
293        self.offsets.clear();
294        self.chunk_ids.clear();
295        self.token_indices.clear();
296        self.residuals.clear();
297        self.is_built = false;
298    }
299
300    /// Search for relevant chunks using MaxSim scoring.
301    ///
302    /// # Arguments
303    ///
304    /// * `query` - Query multi-vector embedding
305    /// * `search_config` - Search parameters
306    ///
307    /// # Returns
308    ///
309    /// Vector of (ChunkId, score) pairs sorted by score descending.
310    ///
311    /// # Errors
312    ///
313    /// Returns an error if the index has not been built.
314    pub fn search(
315        &self,
316        query: &MultiVectorEmbedding,
317        search_config: &WarpSearchConfig,
318    ) -> Result<Vec<(ChunkId, f32)>> {
319        let codec = self
320            .codec
321            .as_ref()
322            .ok_or_else(|| crate::Error::InvalidInput("Codec not trained".to_string()))?;
323
324        if !self.is_built {
325            return Err(crate::Error::InvalidInput(
326                "Index not built - call build() first".to_string(),
327            ));
328        }
329
330        // Phase 1: Select centroids per query token
331        let selected_centroids = CentroidSelector::select(
332            query,
333            codec.centroids(),
334            self.config.token_dim,
335            search_config,
336        );
337
338        // Apply bound: limit total centroids examined
339        let mut total_centroids = 0;
340        let max_tokens = search_config.t_prime.unwrap_or(usize::MAX);
341        let bounded_centroids: Vec<Vec<(usize, f32)>> = selected_centroids
342            .into_iter()
343            .take(max_tokens)
344            .map(|centroids| {
345                let take =
346                    (search_config.bound.saturating_sub(total_centroids)).min(centroids.len());
347                total_centroids += take;
348                centroids.into_iter().take(take).collect()
349            })
350            .collect();
351
352        // Phase 2: Score candidates from selected centroids
353        let bytes_per_residual = self.config.packed_residual_size();
354
355        let token_scores: Vec<Vec<(ChunkId, u16, f32)>> = bounded_centroids
356            .into_iter()
357            .enumerate()
358            .map(|(query_token_idx, centroids)| {
359                let query_token = query.token(query_token_idx);
360
361                centroids
362                    .into_iter()
363                    .flat_map(|(centroid_id, centroid_score)| {
364                        CandidateScorer::score(
365                            query_token,
366                            centroid_id,
367                            centroid_score,
368                            codec,
369                            &self.sizes,
370                            &self.offsets,
371                            &self.chunk_ids,
372                            &self.token_indices,
373                            &self.residuals,
374                            bytes_per_residual,
375                        )
376                    })
377                    .collect()
378            })
379            .collect();
380
381        // Phase 3: Merge via MaxSim
382        Ok(ScoreMerger::merge(token_scores, search_config.k))
383    }
384
385    /// Get centroid size (number of tokens assigned).
386    #[must_use]
387    pub fn centroid_size(&self, centroid_id: usize) -> usize {
388        self.sizes.get(centroid_id).copied().unwrap_or(0)
389    }
390
391    /// Get centroid offset in the compacted arrays.
392    #[must_use]
393    pub fn centroid_offset(&self, centroid_id: usize) -> usize {
394        self.offsets.get(centroid_id).copied().unwrap_or(0)
395    }
396}
397
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DocumentId;
    use proptest::prelude::*;

    // ============ Fixtures ============

    /// Build a chunk covering the whole of `content` in a fresh document.
    fn create_test_chunk(content: &str) -> Chunk {
        let text = content.to_string();
        Chunk::new(DocumentId::new(), text, 0, content.len())
    }

    /// Deterministic pseudo-random embedding driven by a simple LCG.
    ///
    /// NOTE(review): `state >> 33` keeps only 31 bits while the divisor is
    /// `u32::MAX`, so the generated values fall in [-1, 0] rather than
    /// [-1, 1]. Left as is so the test data stays unchanged.
    fn generate_embedding(num_tokens: usize, dim: usize, seed: u64) -> MultiVectorEmbedding {
        let mut state = seed;
        let values: Vec<f32> = (0..num_tokens * dim)
            .map(|_| {
                state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
                ((state >> 33) as f32 / u32::MAX as f32) * 2.0 - 1.0
            })
            .collect();

        MultiVectorEmbedding::new(values, num_tokens, dim)
    }

    /// 100 training samples of 10 tokens x 16 dims, seeded 0..100.
    fn training_samples() -> Vec<MultiVectorEmbedding> {
        (0..100).map(|i| generate_embedding(10, 16, i)).collect()
    }

    /// Index with 2-bit residuals, 8 centroids, 16 dims, already trained.
    fn trained_index() -> WarpIndex {
        let config = WarpIndexConfig::new(2, 8, 16).with_kmeans_iterations(3);
        let mut index = WarpIndex::new(config);
        index.train(&training_samples()).unwrap();
        index
    }

    /// Insert `count` chunks named "<label> <i>" with 5-token embeddings
    /// seeded `seed_base + i`.
    fn insert_docs(index: &mut WarpIndex, label: &str, count: u64, seed_base: u64) {
        for i in 0..count {
            let chunk = create_test_chunk(&format!("{label} {i}"));
            let embedding = generate_embedding(5, 16, seed_base + i);
            index.insert(chunk, embedding).unwrap();
        }
    }

    // ============ Basic Index Tests ============

    #[test]
    fn test_index_new() {
        let index = WarpIndex::new(WarpIndexConfig::new(2, 16, 32));

        assert!(!index.is_trained());
        assert!(!index.is_built());
        assert!(index.is_empty());
    }

    #[test]
    fn test_index_config() {
        let index = WarpIndex::new(WarpIndexConfig::new(4, 32, 64));
        let config = index.config();

        assert_eq!(config.nbits, 4);
        assert_eq!(config.num_centroids, 32);
        assert_eq!(config.token_dim, 64);
    }

    // ============ Training Tests ============

    #[test]
    fn test_index_train() {
        let index = trained_index();

        assert!(index.is_trained());
        assert!(index.codec().is_some());
    }

    #[test]
    fn test_index_train_insufficient_samples() {
        // 100 centroids needs 1000+ samples; only 5 x 10 tokens are supplied.
        let mut index = WarpIndex::new(WarpIndexConfig::new(2, 100, 16));
        let samples: Vec<_> = (0..5).map(|i| generate_embedding(10, 16, i)).collect();

        assert!(index.train(&samples).is_err());
    }

    // ============ Insert Tests ============

    #[test]
    fn test_index_insert() {
        let mut index = trained_index();

        let chunk = create_test_chunk("test content");
        index.insert(chunk, generate_embedding(5, 16, 999)).unwrap();

        assert_eq!(index.num_chunks(), 1);
    }

    #[test]
    fn test_index_insert_without_training() {
        let mut index = WarpIndex::new(WarpIndexConfig::new(2, 8, 16));

        let chunk = create_test_chunk("test");
        let outcome = index.insert(chunk, generate_embedding(5, 16, 0));

        assert!(outcome.is_err());
    }

    // ============ Build Tests ============

    #[test]
    fn test_index_build() {
        let mut index = trained_index();
        insert_docs(&mut index, "document", 10, 1000);

        index.build().unwrap();

        assert!(index.is_built());
        assert_eq!(index.num_chunks(), 10);
        assert_eq!(index.num_tokens(), 50); // 10 chunks x 5 tokens
    }

    #[test]
    fn test_index_cannot_insert_after_build() {
        let mut index = trained_index();
        index
            .insert(create_test_chunk("test"), generate_embedding(5, 16, 0))
            .unwrap();
        index.build().unwrap();

        // A second insert must be refused once the index is built.
        let outcome = index.insert(create_test_chunk("test2"), generate_embedding(5, 16, 1));

        assert!(outcome.is_err());
    }

    // ============ Search Tests ============

    #[test]
    fn test_index_search() {
        let mut index = trained_index();
        insert_docs(&mut index, "document", 20, 1000);
        index.build().unwrap();

        let query = generate_embedding(3, 16, 9999);
        let results = index.search(&query, &WarpSearchConfig::with_k(5)).unwrap();

        assert!(results.len() <= 5);
        assert!(!results.is_empty());

        // Results should be sorted by score descending
        for pair in results.windows(2) {
            assert!(pair[0].1 >= pair[1].1);
        }
    }

    #[test]
    fn test_index_search_without_build() {
        let index = trained_index();

        let query = generate_embedding(3, 16, 0);
        let outcome = index.search(&query, &WarpSearchConfig::with_k(5));

        assert!(outcome.is_err());
    }

    // ============ Memory & Stats Tests ============

    #[test]
    fn test_index_memory_usage() {
        let mut index = trained_index();
        insert_docs(&mut index, "doc", 10, 1000);
        index.build().unwrap();

        assert!(index.memory_usage() > 0);
    }

    #[test]
    fn test_index_centroid_stats() {
        let mut index = trained_index();
        insert_docs(&mut index, "doc", 10, 1000);
        index.build().unwrap();

        // Total tokens across centroids should equal num_tokens
        let total: usize = (0..8).map(|c| index.centroid_size(c)).sum();
        assert_eq!(total, index.num_tokens());
    }

    // ============ Clear & Rebuild Tests ============

    #[test]
    fn test_index_clear_and_rebuild() {
        let mut index = trained_index();
        index
            .insert(create_test_chunk("test"), generate_embedding(5, 16, 0))
            .unwrap();
        index.build().unwrap();
        assert!(index.is_built());

        index.clear_index();

        assert!(!index.is_built());
        assert_eq!(index.num_tokens(), 0);
        // Chunks are preserved
        assert_eq!(index.num_chunks(), 1);
    }

    // ============ Get Chunk Tests ============

    #[test]
    fn test_index_get_chunk() {
        let mut index = trained_index();

        let chunk = create_test_chunk("test content");
        let chunk_id = chunk.id;
        index.insert(chunk, generate_embedding(5, 16, 0)).unwrap();

        let retrieved = index.get_chunk(&chunk_id);
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().content, "test content");
    }

    // ============ Property-Based Tests ============

    proptest! {
        #[test]
        fn prop_search_returns_at_most_k(k in 1usize..20) {
            let mut index = trained_index();
            insert_docs(&mut index, "doc", 30, 1000);
            index.build().unwrap();

            let query = generate_embedding(3, 16, 9999);
            let results = index.search(&query, &WarpSearchConfig::with_k(k)).unwrap();

            prop_assert!(results.len() <= k);
        }

        #[test]
        fn prop_search_results_sorted_descending(seed in 0u64..1000) {
            let mut index = trained_index();
            insert_docs(&mut index, "doc", 20, seed);
            index.build().unwrap();

            let query = generate_embedding(3, 16, seed + 1000);
            let results = index.search(&query, &WarpSearchConfig::with_k(10)).unwrap();

            for pair in results.windows(2) {
                prop_assert!(pair[0].1 >= pair[1].1);
            }
        }
    }
}