Skip to main content

aprender_rag/multivector/
types.rs

1//! Core data structures for multi-vector retrieval
2//!
3//! This module defines the fundamental types used in WARP-based multi-vector
4//! retrieval, including embeddings, index configuration, and search parameters.
5
6use serde::{Deserialize, Serialize};
7
8/// A document or query represented as multiple token embeddings.
9///
10/// In ColBERT-style retrieval, each document and query is represented not by
11/// a single embedding vector, but by multiple vectors—one per token. This
12/// enables fine-grained "late interaction" scoring via MaxSim.
13///
14/// # Memory Layout
15///
16/// Embeddings are stored in a flattened contiguous array for cache efficiency:
17/// `[token_0_dim_0, token_0_dim_1, ..., token_1_dim_0, token_1_dim_1, ...]`
18///
19/// # Example
20///
21/// ```
22/// use aprender_rag::multivector::MultiVectorEmbedding;
23///
24/// // Create a 3-token embedding with 128 dimensions per token
25/// let embeddings = vec![0.0f32; 3 * 128];
26/// let mv = MultiVectorEmbedding::new(embeddings, 3, 128);
27///
28/// assert_eq!(mv.num_tokens(), 3);
29/// assert_eq!(mv.dim(), 128);
30/// assert_eq!(mv.token(0).len(), 128);
31/// ```
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct MultiVectorEmbedding {
34    /// Flattened embeddings: [num_tokens * dim]
35    embeddings: Vec<f32>,
36    /// Number of token embeddings
37    num_tokens: usize,
38    /// Dimension per token embedding
39    dim: usize,
40}
41
42impl MultiVectorEmbedding {
43    /// Create a new multi-vector embedding.
44    ///
45    /// # Panics
46    ///
47    /// Panics if `embeddings.len() != num_tokens * dim`.
48    #[must_use]
49    pub fn new(embeddings: Vec<f32>, num_tokens: usize, dim: usize) -> Self {
50        assert_eq!(
51            embeddings.len(),
52            num_tokens * dim,
53            "Embedding size mismatch: expected {} ({}×{}), got {}",
54            num_tokens * dim,
55            num_tokens,
56            dim,
57            embeddings.len()
58        );
59        // Contract: embedding-algebra-v1.yaml precondition (pv codegen)
60        contract_pre_embedding_lookup!(embeddings);
61        Self { embeddings, num_tokens, dim }
62    }
63
64    /// Create from a vector of token embeddings.
65    #[must_use]
66    pub fn from_tokens(tokens: &[Vec<f32>]) -> Self {
67        if tokens.is_empty() {
68            return Self { embeddings: Vec::new(), num_tokens: 0, dim: 0 };
69        }
70
71        let dim = tokens[0].len();
72        let num_tokens = tokens.len();
73        let mut embeddings = Vec::with_capacity(num_tokens * dim);
74
75        for token in tokens {
76            assert_eq!(token.len(), dim, "All tokens must have the same dimension");
77            embeddings.extend_from_slice(token);
78        }
79
80        Self { embeddings, num_tokens, dim }
81    }
82
83    /// Get the number of token embeddings.
84    #[must_use]
85    pub fn num_tokens(&self) -> usize {
86        self.num_tokens
87    }
88
89    /// Get the dimension of each token embedding.
90    #[must_use]
91    pub fn dim(&self) -> usize {
92        self.dim
93    }
94
95    /// Get the i-th token embedding as a slice.
96    ///
97    /// # Panics
98    ///
99    /// Panics if `i >= num_tokens`.
100    #[must_use]
101    pub fn token(&self, i: usize) -> &[f32] {
102        assert!(i < self.num_tokens, "Token index out of bounds");
103        let start = i * self.dim;
104        &self.embeddings[start..start + self.dim]
105    }
106
107    /// Iterate over token embeddings.
108    ///
109    /// Returns an empty iterator if `dim == 0` (poka-yoke: prevents
110    /// `chunks_exact(0)` panic from uninitialized embedding config).
111    pub fn tokens(&self) -> impl Iterator<Item = &[f32]> {
112        if self.dim == 0 {
113            // chunks_exact(0) panics — return empty iterator instead
114            [].chunks_exact(1)
115        } else {
116            self.embeddings.chunks_exact(self.dim)
117        }
118    }
119
120    /// Get the raw flattened embeddings.
121    #[must_use]
122    pub fn as_slice(&self) -> &[f32] {
123        &self.embeddings
124    }
125
126    /// Get the raw flattened embeddings mutably.
127    pub fn as_mut_slice(&mut self) -> &mut [f32] {
128        &mut self.embeddings
129    }
130
131    /// Memory size in bytes (uncompressed).
132    #[must_use]
133    pub fn size_bytes(&self) -> usize {
134        self.embeddings.len() * size_of::<f32>()
135    }
136
137    /// Check if the embedding is empty (no tokens).
138    #[must_use]
139    pub fn is_empty(&self) -> bool {
140        self.num_tokens == 0
141    }
142}
143
144/// Configuration for WARP index construction.
145///
146/// These parameters control the compression quality and index structure.
147/// The default values provide a good balance of memory efficiency and
148/// retrieval quality for most use cases.
149///
150/// # Parameter Guidance
151///
152/// | Corpus Size  | nbits | num_centroids |
153/// |--------------|-------|---------------|
154/// | < 100K docs  | 4     | 256           |
155/// | 100K - 1M    | 2     | 1024          |
156/// | > 1M docs    | 2     | 4096          |
157#[derive(Debug, Clone, Serialize, Deserialize)]
158pub struct WarpIndexConfig {
159    /// Bits per dimension for residual quantization (2 or 4).
160    ///
161    /// - 2-bit: 16× compression, ~3-5% MRR loss
162    /// - 4-bit: 8× compression, ~1-2% MRR loss
163    pub nbits: u8,
164
165    /// Number of centroids for IVF clustering.
166    ///
167    /// More centroids provide finer-grained partitioning but require
168    /// more memory for centroid storage. Typical values: 256-4096.
169    pub num_centroids: usize,
170
171    /// Token embedding dimension (e.g., 128 for ColBERT).
172    pub token_dim: usize,
173
174    /// Minimum training samples for codec training.
175    ///
176    /// Should be at least 10 × num_centroids for stable clustering.
177    /// If None, defaults to 10 × num_centroids.
178    pub min_training_samples: Option<usize>,
179
180    /// K-means iterations for centroid training.
181    pub kmeans_iterations: usize,
182}
183
184impl Default for WarpIndexConfig {
185    fn default() -> Self {
186        Self {
187            nbits: 2,
188            num_centroids: 1024,
189            token_dim: 128,
190            min_training_samples: None,
191            kmeans_iterations: 20,
192        }
193    }
194}
195
196impl WarpIndexConfig {
197    /// Create a new configuration with the specified parameters.
198    #[must_use]
199    pub fn new(nbits: u8, num_centroids: usize, token_dim: usize) -> Self {
200        Self { nbits, num_centroids, token_dim, ..Default::default() }
201    }
202
203    /// Set the minimum training samples.
204    #[must_use]
205    pub fn with_min_training_samples(mut self, samples: usize) -> Self {
206        self.min_training_samples = Some(samples);
207        self
208    }
209
210    /// Set the k-means iterations.
211    #[must_use]
212    pub fn with_kmeans_iterations(mut self, iterations: usize) -> Self {
213        self.kmeans_iterations = iterations;
214        self
215    }
216
217    /// Get the effective minimum training samples.
218    #[must_use]
219    pub fn effective_min_training_samples(&self) -> usize {
220        self.min_training_samples.unwrap_or(10 * self.num_centroids)
221    }
222
223    /// Calculate packed residual size in bytes.
224    #[must_use]
225    pub fn packed_residual_size(&self) -> usize {
226        (self.token_dim * self.nbits as usize + 7) / 8
227    }
228
229    /// Validate the configuration.
230    pub fn validate(&self) -> Result<(), &'static str> {
231        if self.nbits != 2 && self.nbits != 4 {
232            return Err("nbits must be 2 or 4");
233        }
234        if self.num_centroids == 0 {
235            return Err("num_centroids must be > 0");
236        }
237        if self.token_dim == 0 {
238            return Err("token_dim must be > 0");
239        }
240        if self.kmeans_iterations == 0 {
241            return Err("kmeans_iterations must be > 0");
242        }
243        Ok(())
244    }
245}
246
247/// Configuration for WARP search.
248///
249/// These parameters control the trade-off between search speed and
250/// recall quality. The defaults are tuned for high recall (>95%).
251#[derive(Debug, Clone, Serialize, Deserialize)]
252pub struct WarpSearchConfig {
253    /// Number of results to return.
254    pub k: usize,
255
256    /// Centroids to probe per query token.
257    ///
258    /// Higher values increase recall but also latency.
259    /// Default: 4 (provides ~95% recall on most datasets).
260    pub nprobe: u32,
261
262    /// Maximum total centroids examined across all tokens.
263    ///
264    /// Acts as an upper bound on computation. Default: 128.
265    pub bound: usize,
266
267    /// Early termination: skip tokens after this many.
268    ///
269    /// For very long queries, processing all tokens may be wasteful.
270    /// Setting this limits which tokens contribute to scoring.
271    pub t_prime: Option<usize>,
272
273    /// Skip tokens with centroid score below threshold.
274    ///
275    /// Tokens that don't match any centroid well are unlikely to
276    /// contribute meaningful scores. Default: 0.4.
277    pub centroid_score_threshold: f32,
278}
279
280impl Default for WarpSearchConfig {
281    fn default() -> Self {
282        Self { k: 10, nprobe: 4, bound: 128, t_prime: None, centroid_score_threshold: 0.4 }
283    }
284}
285
286impl WarpSearchConfig {
287    /// Create a search config with the specified k.
288    #[must_use]
289    pub fn with_k(k: usize) -> Self {
290        Self { k, ..Default::default() }
291    }
292
293    /// Set nprobe (centroids per token).
294    #[must_use]
295    pub fn nprobe(mut self, nprobe: u32) -> Self {
296        self.nprobe = nprobe;
297        self
298    }
299
300    /// Set the centroid bound.
301    #[must_use]
302    pub fn bound(mut self, bound: usize) -> Self {
303        self.bound = bound;
304        self
305    }
306
307    /// Set early termination threshold.
308    #[must_use]
309    pub fn t_prime(mut self, t_prime: usize) -> Self {
310        self.t_prime = Some(t_prime);
311        self
312    }
313
314    /// Set centroid score threshold.
315    #[must_use]
316    pub fn centroid_score_threshold(mut self, threshold: f32) -> Self {
317        self.centroid_score_threshold = threshold;
318        self
319    }
320}
321
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    // ============ MultiVectorEmbedding Tests ============

    #[test]
    fn test_multivector_new() {
        let embeddings = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mv = MultiVectorEmbedding::new(embeddings, 2, 3);

        assert_eq!(mv.num_tokens(), 2);
        assert_eq!(mv.dim(), 3);
        assert_eq!(mv.token(0), &[1.0, 2.0, 3.0]);
        assert_eq!(mv.token(1), &[4.0, 5.0, 6.0]);
    }

    #[test]
    #[should_panic(expected = "Embedding size mismatch")]
    fn test_multivector_size_mismatch() {
        let embeddings = vec![1.0, 2.0, 3.0];
        let _ = MultiVectorEmbedding::new(embeddings, 2, 3); // Should panic
    }

    #[test]
    fn test_multivector_from_tokens() {
        let tokens = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let mv = MultiVectorEmbedding::from_tokens(&tokens);

        assert_eq!(mv.num_tokens(), 3);
        assert_eq!(mv.dim(), 2);
        // Tokens are laid out back-to-back in input order.
        assert_eq!(mv.as_slice(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert_eq!(mv.token(2), &[5.0, 6.0]);
    }

    #[test]
    #[should_panic(expected = "All tokens must have the same dimension")]
    fn test_multivector_from_tokens_dim_mismatch() {
        let tokens = vec![vec![1.0, 2.0], vec![3.0]];
        let _ = MultiVectorEmbedding::from_tokens(&tokens); // Should panic
    }

    #[test]
    fn test_multivector_from_tokens_empty() {
        let tokens: Vec<Vec<f32>> = vec![];
        let mv = MultiVectorEmbedding::from_tokens(&tokens);

        assert_eq!(mv.num_tokens(), 0);
        assert!(mv.is_empty());
    }

    /// Regression test for paiml/trueno-rag#15: division by zero when dim is 0.
    /// `from_tokens(&[])` produces dim=0; `tokens()` must not panic.
    #[test]
    fn test_multivector_dim_zero_tokens_no_panic() {
        let mv = MultiVectorEmbedding::from_tokens(&[]);
        assert_eq!(mv.dim(), 0);
        assert_eq!(mv.tokens().count(), 0); // must not panic
    }

    /// Regression: `new(vec![], 0, 0)` is valid (empty embedding) and
    /// iterating tokens must return an empty iterator, not div-by-zero.
    #[test]
    fn test_multivector_new_zero_dim_zero_tokens() {
        let mv = MultiVectorEmbedding::new(vec![], 0, 0);
        assert_eq!(mv.tokens().count(), 0);
        assert!(mv.is_empty());
    }

    #[test]
    fn test_multivector_tokens_iterator() {
        let embeddings = vec![1.0, 2.0, 3.0, 4.0];
        let mv = MultiVectorEmbedding::new(embeddings, 2, 2);

        let tokens: Vec<&[f32]> = mv.tokens().collect();
        assert_eq!(tokens.len(), 2);
        assert_eq!(tokens[0], &[1.0, 2.0]);
        assert_eq!(tokens[1], &[3.0, 4.0]);
    }

    #[test]
    #[should_panic(expected = "Token index out of bounds")]
    fn test_multivector_token_out_of_bounds() {
        let mv = MultiVectorEmbedding::new(vec![1.0, 2.0], 1, 2);
        let _ = mv.token(1); // Should panic: only token 0 exists
    }

    #[test]
    fn test_multivector_is_empty_false_with_tokens() {
        let mv = MultiVectorEmbedding::new(vec![1.0], 1, 1);
        assert!(!mv.is_empty());
    }

    #[test]
    fn test_multivector_size_bytes() {
        let embeddings = vec![0.0; 100];
        let mv = MultiVectorEmbedding::new(embeddings, 10, 10);

        assert_eq!(mv.size_bytes(), 100 * 4); // 100 f32s × 4 bytes
    }

    #[test]
    fn test_multivector_as_slice() {
        let mv = MultiVectorEmbedding::new(vec![1.0, 2.0, 3.0], 1, 3);

        assert_eq!(mv.as_slice(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_multivector_as_mut_slice() {
        let mut mv = MultiVectorEmbedding::new(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
        mv.as_mut_slice()[3] = 9.0;

        // Writes through the flat buffer are visible via the token view.
        assert_eq!(mv.token(1), &[3.0, 9.0]);
    }

    #[test]
    fn test_multivector_serialization() {
        let mv = MultiVectorEmbedding::new(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
        let json = serde_json::to_string(&mv).unwrap();
        let deserialized: MultiVectorEmbedding = serde_json::from_str(&json).unwrap();

        assert_eq!(mv.num_tokens(), deserialized.num_tokens());
        assert_eq!(mv.dim(), deserialized.dim());
        assert_eq!(mv.as_slice(), deserialized.as_slice());
    }

    // ============ WarpIndexConfig Tests ============

    #[test]
    fn test_index_config_default() {
        let config = WarpIndexConfig::default();

        assert_eq!(config.nbits, 2);
        assert_eq!(config.num_centroids, 1024);
        assert_eq!(config.token_dim, 128);
        assert!(config.min_training_samples.is_none());
        assert_eq!(config.kmeans_iterations, 20);
    }

    #[test]
    fn test_index_config_new() {
        let config = WarpIndexConfig::new(4, 256, 64);

        assert_eq!(config.nbits, 4);
        assert_eq!(config.num_centroids, 256);
        assert_eq!(config.token_dim, 64);
        // Fields not passed to `new` keep their defaults.
        assert!(config.min_training_samples.is_none());
        assert_eq!(config.kmeans_iterations, 20);
    }

    #[test]
    fn test_index_config_builders() {
        let config = WarpIndexConfig::new(2, 512, 128)
            .with_min_training_samples(5000)
            .with_kmeans_iterations(30);

        assert_eq!(config.min_training_samples, Some(5000));
        assert_eq!(config.kmeans_iterations, 30);
    }

    #[test]
    fn test_index_config_effective_min_samples() {
        let config = WarpIndexConfig::new(2, 100, 128);
        assert_eq!(config.effective_min_training_samples(), 1000); // 10 × 100

        let config = config.with_min_training_samples(500);
        assert_eq!(config.effective_min_training_samples(), 500);
    }

    #[test]
    fn test_index_config_packed_size() {
        // 128 dims × 2 bits = 256 bits = 32 bytes
        let config = WarpIndexConfig::new(2, 1024, 128);
        assert_eq!(config.packed_residual_size(), 32);

        // 128 dims × 4 bits = 512 bits = 64 bytes
        let config = WarpIndexConfig::new(4, 1024, 128);
        assert_eq!(config.packed_residual_size(), 64);

        // Non-byte-aligned sizes round up: 3 × 2 = 6 bits → 1 byte, 3 × 4 = 12 bits → 2 bytes
        assert_eq!(WarpIndexConfig::new(2, 1024, 3).packed_residual_size(), 1);
        assert_eq!(WarpIndexConfig::new(4, 1024, 3).packed_residual_size(), 2);
    }

    #[test]
    fn test_index_config_validate() {
        let config = WarpIndexConfig::default();
        assert!(config.validate().is_ok());

        let four_bit = WarpIndexConfig { nbits: 4, ..Default::default() };
        assert!(four_bit.validate().is_ok());

        let bad_nbits = WarpIndexConfig { nbits: 3, ..Default::default() };
        assert!(bad_nbits.validate().is_err());

        let bad_centroids = WarpIndexConfig { num_centroids: 0, ..Default::default() };
        assert!(bad_centroids.validate().is_err());

        let bad_dim = WarpIndexConfig { token_dim: 0, ..Default::default() };
        assert!(bad_dim.validate().is_err());

        let bad_iterations = WarpIndexConfig { kmeans_iterations: 0, ..Default::default() };
        assert!(bad_iterations.validate().is_err());
    }

    #[test]
    fn test_index_config_serialization() {
        let config = WarpIndexConfig::new(4, 512, 64).with_min_training_samples(7000);
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: WarpIndexConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config.nbits, deserialized.nbits);
        assert_eq!(config.num_centroids, deserialized.num_centroids);
        assert_eq!(config.token_dim, deserialized.token_dim);
        assert_eq!(config.min_training_samples, deserialized.min_training_samples);
        assert_eq!(config.kmeans_iterations, deserialized.kmeans_iterations);
    }

    // ============ WarpSearchConfig Tests ============

    #[test]
    fn test_search_config_default() {
        let config = WarpSearchConfig::default();

        assert_eq!(config.k, 10);
        assert_eq!(config.nprobe, 4);
        assert_eq!(config.bound, 128);
        assert!(config.t_prime.is_none());
        assert!((config.centroid_score_threshold - 0.4).abs() < 0.001);
    }

    #[test]
    fn test_search_config_with_k() {
        let config = WarpSearchConfig::with_k(20);
        assert_eq!(config.k, 20);
        // Everything except k stays at its default.
        assert_eq!(config.nprobe, 4);
        assert_eq!(config.bound, 128);
        assert!(config.t_prime.is_none());
    }

    #[test]
    fn test_search_config_builders() {
        let config = WarpSearchConfig::with_k(5)
            .nprobe(8)
            .bound(256)
            .t_prime(10)
            .centroid_score_threshold(0.5);

        assert_eq!(config.k, 5);
        assert_eq!(config.nprobe, 8);
        assert_eq!(config.bound, 256);
        assert_eq!(config.t_prime, Some(10));
        assert!((config.centroid_score_threshold - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_search_config_serialization() {
        let config = WarpSearchConfig::with_k(15).nprobe(6).t_prime(3);
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: WarpSearchConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config.k, deserialized.k);
        assert_eq!(config.nprobe, deserialized.nprobe);
        assert_eq!(config.bound, deserialized.bound);
        assert_eq!(config.t_prime, deserialized.t_prime);
        assert!(
            (config.centroid_score_threshold - deserialized.centroid_score_threshold).abs()
                < 1e-6
        );
    }

    // ============ Property-Based Tests ============

    proptest! {
        #[test]
        fn prop_multivector_tokens_count_matches(
            num_tokens in 1usize..20,
            dim in 1usize..64
        ) {
            let embeddings = vec![0.0f32; num_tokens * dim];
            let mv = MultiVectorEmbedding::new(embeddings, num_tokens, dim);

            prop_assert_eq!(mv.num_tokens(), num_tokens);
            prop_assert_eq!(mv.dim(), dim);
            prop_assert_eq!(mv.tokens().count(), num_tokens);
        }

        #[test]
        fn prop_multivector_token_slices_correct_size(
            num_tokens in 1usize..10,
            dim in 1usize..32
        ) {
            let embeddings = vec![0.0f32; num_tokens * dim];
            let mv = MultiVectorEmbedding::new(embeddings, num_tokens, dim);

            for i in 0..num_tokens {
                prop_assert_eq!(mv.token(i).len(), dim);
            }
        }

        #[test]
        fn prop_index_config_packed_size_formula(
            nbits in prop::sample::select(vec![2u8, 4]),
            dim in 1usize..256
        ) {
            let config = WarpIndexConfig::new(nbits, 1024, dim);
            let expected = (dim * nbits as usize + 7) / 8;
            prop_assert_eq!(config.packed_residual_size(), expected);
        }
    }
}