Skip to main content

aprender_rag/multivector/
search.rs

1//! WARP search algorithm components
2//!
3//! This module implements the three phases of WARP search:
4//!
5//! 1. **Centroid Selection** - For each query token, find top-nprobe centroids
6//! 2. **Candidate Scoring** - Decompress and score tokens from selected centroids
7//! 3. **Score Merging** - Aggregate per-token scores into document scores via MaxSim
8
9use crate::multivector::{codec::ResidualCodec, types::WarpSearchConfig, MultiVectorEmbedding};
10use crate::ChunkId;
11use std::collections::HashMap;
12
13/// Phase 1: Select top centroids per query token.
14///
15/// For each query token, compute its similarity with all centroids and
16/// select the top-nprobe centroids above the score threshold.
17pub struct CentroidSelector;
18
19impl CentroidSelector {
20    /// Select top centroids for each query token.
21    ///
22    /// # Arguments
23    ///
24    /// * `query` - Query multi-vector embedding
25    /// * `centroids` - Flattened centroid vectors [num_centroids × dim]
26    /// * `dim` - Token embedding dimension
27    /// * `config` - Search configuration
28    ///
29    /// # Returns
30    ///
31    /// For each query token, a vector of (centroid_id, centroid_score) pairs
32    /// sorted by score descending.
33    #[must_use]
34    pub fn select(
35        query: &MultiVectorEmbedding,
36        centroids: &[f32],
37        dim: usize,
38        config: &WarpSearchConfig,
39    ) -> Vec<Vec<(usize, f32)>> {
40        if dim == 0 || centroids.is_empty() {
41            return query.tokens().map(|_| vec![]).collect();
42        }
43        let num_centroids = centroids.len() / dim;
44
45        query
46            .tokens()
47            .map(|query_token| {
48                // Compute scores with all centroids
49                let mut scores: Vec<(usize, f32)> = (0..num_centroids)
50                    .map(|c| {
51                        let centroid = &centroids[c * dim..(c + 1) * dim];
52                        let score = Self::dot_product(query_token, centroid);
53                        (c, score)
54                    })
55                    .collect();
56
57                // Sort by score descending
58                scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
59
60                // Take top nprobe, filtered by threshold
61                scores
62                    .into_iter()
63                    .take(config.nprobe as usize)
64                    .filter(|(_, score)| *score >= config.centroid_score_threshold)
65                    .collect()
66            })
67            .collect()
68    }
69
70    /// Batch compute centroid scores for a single query token.
71    ///
72    /// Returns scores for all centroids sorted by score descending.
73    #[must_use]
74    pub fn batch_scores(query_token: &[f32], centroids: &[f32], dim: usize) -> Vec<(usize, f32)> {
75        if dim == 0 || centroids.is_empty() {
76            return vec![];
77        }
78        let num_centroids = centroids.len() / dim;
79
80        let mut scores: Vec<(usize, f32)> = (0..num_centroids)
81            .map(|c| {
82                let centroid = &centroids[c * dim..(c + 1) * dim];
83                let score = Self::dot_product(query_token, centroid);
84                (c, score)
85            })
86            .collect();
87
88        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
89        scores
90    }
91
92    fn dot_product(a: &[f32], b: &[f32]) -> f32 {
93        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
94    }
95}
96
97/// Phase 2: Score candidates from a centroid.
98///
99/// For a single query token and centroid, decompress and score all
100/// document tokens assigned to that centroid.
101pub struct CandidateScorer;
102
103impl CandidateScorer {
104    /// Score candidates from a centroid for one query token.
105    ///
106    /// # Arguments
107    ///
108    /// * `query_token` - Query embedding for this token
109    /// * `centroid_id` - Selected centroid ID
110    /// * `centroid_score` - Precomputed q · c
111    /// * `codec` - Residual codec for decompression
112    /// * `sizes` - Number of tokens per centroid
113    /// * `offsets` - Cumulative offsets per centroid
114    /// * `chunk_ids` - Chunk IDs for all tokens
115    /// * `token_indices` - Token indices within chunks
116    /// * `residuals` - Packed residuals for all tokens
117    /// * `bytes_per_residual` - Bytes per packed residual
118    ///
119    /// # Returns
120    ///
121    /// Vector of (ChunkId, token_index, score) for all candidates.
122    #[must_use]
123    #[allow(clippy::too_many_arguments)]
124    pub fn score(
125        query_token: &[f32],
126        centroid_id: usize,
127        centroid_score: f32,
128        codec: &ResidualCodec,
129        sizes: &[usize],
130        offsets: &[usize],
131        chunk_ids: &[ChunkId],
132        token_indices: &[u16],
133        residuals: &[u8],
134        bytes_per_residual: usize,
135    ) -> Vec<(ChunkId, u16, f32)> {
136        let size = sizes.get(centroid_id).copied().unwrap_or(0);
137        if size == 0 {
138            return Vec::new();
139        }
140
141        let offset = offsets.get(centroid_id).copied().unwrap_or(0);
142
143        (0..size)
144            .map(|i| {
145                let idx = offset + i;
146                let chunk_id = chunk_ids[idx];
147                let token_idx = token_indices[idx];
148
149                let residual_start = idx * bytes_per_residual;
150                let residual_end = residual_start + bytes_per_residual;
151                let residual = &residuals[residual_start..residual_end];
152
153                let score =
154                    codec.decompress_score(query_token, centroid_id, centroid_score, residual);
155
156                (chunk_id, token_idx, score)
157            })
158            .collect()
159    }
160
161    /// Score a single candidate.
162    #[must_use]
163    pub fn score_single(
164        query_token: &[f32],
165        centroid_id: usize,
166        centroid_score: f32,
167        codec: &ResidualCodec,
168        residual: &[u8],
169    ) -> f32 {
170        codec.decompress_score(query_token, centroid_id, centroid_score, residual)
171    }
172}
173
174/// Phase 3: Merge per-token scores into document scores via MaxSim.
175///
176/// MaxSim computes: score(Q, D) = Σ_i max_j(q_i · d_j)
177///
178/// For each query token, find the maximum score with any document token,
179/// then sum across query tokens.
180pub struct ScoreMerger;
181
182impl ScoreMerger {
183    /// Merge per-token scores into document scores via MaxSim.
184    ///
185    /// # Arguments
186    ///
187    /// * `token_scores` - For each query token: (ChunkId, doc_token_idx, score)
188    /// * `k` - Number of top results to return
189    ///
190    /// # Returns
191    ///
192    /// Vector of (ChunkId, total_score) sorted by score descending.
193    #[must_use]
194    pub fn merge(token_scores: Vec<Vec<(ChunkId, u16, f32)>>, k: usize) -> Vec<(ChunkId, f32)> {
195        if token_scores.is_empty() {
196            return Vec::new();
197        }
198
199        let num_query_tokens = token_scores.len();
200
201        // For each document, track max score per query token
202        let mut doc_token_maxes: HashMap<ChunkId, Vec<f32>> = HashMap::new();
203
204        for (query_token_idx, scores) in token_scores.into_iter().enumerate() {
205            for (chunk_id, _doc_token_idx, score) in scores {
206                let maxes = doc_token_maxes
207                    .entry(chunk_id)
208                    .or_insert_with(|| vec![f32::NEG_INFINITY; num_query_tokens]);
209
210                if score > maxes[query_token_idx] {
211                    maxes[query_token_idx] = score;
212                }
213            }
214        }
215
216        // Sum max scores across query tokens
217        let mut doc_scores: Vec<(ChunkId, f32)> = doc_token_maxes
218            .into_iter()
219            .map(|(chunk_id, maxes)| {
220                let score: f32 = maxes.into_iter().filter(|&s| s > f32::NEG_INFINITY).sum();
221                (chunk_id, score)
222            })
223            .collect();
224
225        // Sort by score descending
226        doc_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
227
228        // Take top-k
229        doc_scores.truncate(k);
230        doc_scores
231    }
232
233    /// Merge scores for a single document across query tokens.
234    ///
235    /// This is useful when you have per-token scores already grouped by document.
236    #[must_use]
237    pub fn merge_single_doc(token_max_scores: &[f32]) -> f32 {
238        token_max_scores.iter().filter(|&&s| s > f32::NEG_INFINITY).sum()
239    }
240}
241
242/// Compute exact MaxSim score (for testing/comparison).
243///
244/// This computes the full MaxSim score without compression:
245/// score(Q, D) = Σ_i max_j(q_i · d_j)
246#[must_use]
247pub fn exact_maxsim(query: &MultiVectorEmbedding, doc: &MultiVectorEmbedding) -> f32 {
248    query
249        .tokens()
250        .map(|q| doc.tokens().map(|d| dot_product(q, d)).fold(f32::NEG_INFINITY, f32::max))
251        .filter(|&s| s > f32::NEG_INFINITY)
252        .sum()
253}
254
/// Compute dot product between two vectors.
///
/// Only the overlapping prefix is used when the slices differ in length.
#[inline]
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs).sum::<f32>()
}
260
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Build a deterministic pseudo-random embedding from `seed`.
    ///
    /// Components are spread over [-1, 1]; the vectors are NOT normalized.
    fn generate_embedding(num_tokens: usize, dim: usize, seed: u64) -> MultiVectorEmbedding {
        let mut embeddings = Vec::with_capacity(num_tokens * dim);
        let mut rng = seed;

        for _ in 0..(num_tokens * dim) {
            rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1);
            // Keep the top 32 bits so the ratio against u32::MAX spans
            // [0, 1]; shifting by 33 would leave only 31 bits and confine
            // every component to [-1, 0).
            let val = ((rng >> 32) as f32 / u32::MAX as f32) * 2.0 - 1.0;
            embeddings.push(val);
        }

        MultiVectorEmbedding::new(embeddings, num_tokens, dim)
    }

    fn chunk_id(n: u128) -> ChunkId {
        ChunkId(uuid::Uuid::from_u128(n))
    }

    fn create_test_codec() -> ResidualCodec {
        // Create a minimal test codec
        let embeddings = vec![0.0f32; 200 * 4]; // 200 samples, dim=4
        ResidualCodec::train(&embeddings, 4, 4, 2, 3).unwrap()
    }

    // ============ CentroidSelector Tests ============

    #[test]
    fn test_centroid_selector_basic() {
        let query = generate_embedding(2, 4, 42);

        // Create 4 centroids
        let centroids = vec![
            1.0, 0.0, 0.0, 0.0, // centroid 0
            0.0, 1.0, 0.0, 0.0, // centroid 1
            0.0, 0.0, 1.0, 0.0, // centroid 2
            0.0, 0.0, 0.0, 1.0, // centroid 3
        ];

        let config = WarpSearchConfig::with_k(10).nprobe(2).centroid_score_threshold(-1.0); // Accept all

        let selected = CentroidSelector::select(&query, &centroids, 4, &config);

        assert_eq!(selected.len(), 2); // 2 query tokens
        assert!(selected[0].len() <= 2); // nprobe = 2
    }

    #[test]
    fn test_centroid_selector_threshold() {
        let query = MultiVectorEmbedding::new(vec![1.0, 0.0, 0.0, 0.0], 1, 4);

        let centroids = vec![
            1.0, 0.0, 0.0, 0.0, // centroid 0: score = 1.0
            0.0, 1.0, 0.0, 0.0, // centroid 1: score = 0.0
            0.5, 0.5, 0.0, 0.0, // centroid 2: score = 0.5
            0.0, 0.0, 1.0, 0.0, // centroid 3: score = 0.0
        ];

        let config = WarpSearchConfig::with_k(10).nprobe(4).centroid_score_threshold(0.4);

        let selected = CentroidSelector::select(&query, &centroids, 4, &config);

        // Only centroids with score >= 0.4 should be selected:
        // centroid 0 (1.0) then centroid 2 (0.5), best first.
        assert_eq!(selected.len(), 1);
        let ids: Vec<usize> = selected[0].iter().map(|&(id, _)| id).collect();
        assert_eq!(ids, vec![0, 2]);
    }

    #[test]
    fn test_centroid_selector_sorted() {
        let query = MultiVectorEmbedding::new(vec![0.5, 0.5, 0.0, 0.0], 1, 4);

        let centroids = vec![
            1.0, 0.0, 0.0, 0.0, // centroid 0: score = 0.5
            0.0, 1.0, 0.0, 0.0, // centroid 1: score = 0.5
            0.5, 0.5, 0.0, 0.0, // centroid 2: score = 0.5 (ties under raw dot product)
            0.0, 0.0, 1.0, 0.0, // centroid 3: score = 0.0
        ];

        let config = WarpSearchConfig::with_k(10).nprobe(4).centroid_score_threshold(-1.0);

        let selected = CentroidSelector::select(&query, &centroids, 4, &config);

        // Results should be sorted by score descending
        assert!(!selected[0].is_empty());
        for i in 1..selected[0].len() {
            assert!(selected[0][i - 1].1 >= selected[0][i].1);
        }
    }

    /// Regression test for paiml/trueno-rag#15: dim=0 must not divide by zero.
    #[test]
    fn test_centroid_selector_dim_zero_no_panic() {
        let query = MultiVectorEmbedding::from_tokens(&[]);
        let centroids: Vec<f32> = vec![];
        let config = WarpSearchConfig::with_k(10);

        let selected = CentroidSelector::select(&query, &centroids, 0, &config);
        assert!(selected.is_empty());
    }

    /// Regression test for paiml/trueno-rag#15: batch_scores with dim=0.
    #[test]
    fn test_batch_scores_dim_zero_no_panic() {
        let scores = CentroidSelector::batch_scores(&[], &[], 0);
        assert!(scores.is_empty());
    }

    #[test]
    fn test_batch_scores() {
        let query_token = vec![1.0, 0.0, 0.0, 0.0];
        let centroids = vec![
            1.0, 0.0, 0.0, 0.0, // centroid 0
            0.0, 1.0, 0.0, 0.0, // centroid 1
        ];

        let scores = CentroidSelector::batch_scores(&query_token, &centroids, 4);

        assert_eq!(scores.len(), 2);
        assert_eq!(scores[0].0, 0); // Best match is centroid 0
        assert!((scores[0].1 - 1.0).abs() < 1e-6);
    }

    // ============ CandidateScorer Tests ============

    #[test]
    fn test_candidate_scorer_empty_centroid() {
        let query_token = vec![1.0, 0.0, 0.0, 0.0];
        let codec = create_test_codec();

        let sizes = vec![0, 5, 3]; // centroid 0 is empty
        let offsets = vec![0, 0, 5];
        let chunk_ids: Vec<ChunkId> = vec![];
        let token_indices: Vec<u16> = vec![];
        let residuals: Vec<u8> = vec![];

        let results = CandidateScorer::score(
            &query_token,
            0, // empty centroid
            0.5,
            &codec,
            &sizes,
            &offsets,
            &chunk_ids,
            &token_indices,
            &residuals,
            2, // bytes per residual
        );

        assert!(results.is_empty());
    }

    // ============ ScoreMerger Tests ============

    #[test]
    fn test_score_merger_basic() {
        let token_scores = vec![
            vec![(chunk_id(1), 0, 0.9), (chunk_id(2), 0, 0.8), (chunk_id(1), 1, 0.7)],
            vec![(chunk_id(1), 0, 0.6), (chunk_id(2), 0, 0.5), (chunk_id(3), 0, 0.4)],
        ];

        let results = ScoreMerger::merge(token_scores, 10);

        // chunk_id(1): max(0.9, 0.7) + max(0.6) = 0.9 + 0.6 = 1.5
        // chunk_id(2): max(0.8) + max(0.5) = 0.8 + 0.5 = 1.3
        // chunk_id(3): 0 + max(0.4) = 0.4

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, chunk_id(1));
        assert!((results[0].1 - 1.5).abs() < 0.001);
        assert_eq!(results[1].0, chunk_id(2));
        assert!((results[1].1 - 1.3).abs() < 0.001);
        assert_eq!(results[2].0, chunk_id(3));
        assert!((results[2].1 - 0.4).abs() < 0.001);
    }

    #[test]
    fn test_score_merger_empty() {
        let token_scores: Vec<Vec<(ChunkId, u16, f32)>> = vec![];
        let results = ScoreMerger::merge(token_scores, 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_score_merger_respects_k() {
        let token_scores = vec![vec![
            (chunk_id(1), 0, 0.9),
            (chunk_id(2), 0, 0.8),
            (chunk_id(3), 0, 0.7),
            (chunk_id(4), 0, 0.6),
            (chunk_id(5), 0, 0.5),
        ]];

        let results = ScoreMerger::merge(token_scores, 3);
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_score_merger_sorted_descending() {
        let token_scores =
            vec![vec![(chunk_id(1), 0, 0.3), (chunk_id(2), 0, 0.9), (chunk_id(3), 0, 0.6)]];

        let results = ScoreMerger::merge(token_scores, 10);

        assert_eq!(results[0].0, chunk_id(2)); // highest
        assert_eq!(results[1].0, chunk_id(3));
        assert_eq!(results[2].0, chunk_id(1)); // lowest
    }

    #[test]
    fn test_merge_single_doc() {
        let scores = vec![0.9, 0.6, f32::NEG_INFINITY, 0.3];
        let total = ScoreMerger::merge_single_doc(&scores);

        assert!((total - 1.8).abs() < 0.001); // 0.9 + 0.6 + 0.3
    }

    // ============ Exact MaxSim Tests ============

    #[test]
    fn test_exact_maxsim_identical() {
        let emb = generate_embedding(3, 4, 42);
        let score = exact_maxsim(&emb, &emb);

        // Self-similarity: for normalized vectors, this should be num_tokens
        // For non-normalized, just check it's positive
        assert!(score > 0.0);
    }

    #[test]
    fn test_exact_maxsim_orthogonal() {
        let query = MultiVectorEmbedding::new(vec![1.0, 0.0, 0.0, 0.0], 1, 4);
        let doc = MultiVectorEmbedding::new(vec![0.0, 1.0, 0.0, 0.0], 1, 4);

        let score = exact_maxsim(&query, &doc);
        assert!((score - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_exact_maxsim_aligned() {
        let query = MultiVectorEmbedding::new(vec![1.0, 0.0, 0.0, 0.0], 1, 4);
        let doc = MultiVectorEmbedding::new(vec![1.0, 0.0, 0.0, 0.0], 1, 4);

        let score = exact_maxsim(&query, &doc);
        assert!((score - 1.0).abs() < 1e-6);
    }

    // ============ Property-Based Tests ============

    proptest! {
        #[test]
        fn prop_maxsim_is_finite(
            num_q in 1usize..5,
            num_d in 1usize..5
        ) {
            // Pseudo-random (not normalized) embeddings of varying length
            let query = generate_embedding(num_q, 4, 123);
            let doc = generate_embedding(num_d, 4, 456);

            let score = exact_maxsim(&query, &doc);

            // MaxSim may be negative for arbitrary vectors; this only
            // checks that it does not panic and stays finite.
            prop_assert!(score.is_finite());
        }

        #[test]
        fn prop_merger_results_count_bounded_by_k(
            k in 1usize..20,
            num_docs in 1usize..50
        ) {
            let token_scores = vec![
                (0..num_docs)
                    .map(|i| (chunk_id(i as u128), 0u16, i as f32 / 100.0))
                    .collect()
            ];

            let results = ScoreMerger::merge(token_scores, k);
            prop_assert!(results.len() <= k);
            prop_assert!(results.len() <= num_docs);
        }

        #[test]
        fn prop_centroid_selector_respects_nprobe(
            nprobe in 1u32..10
        ) {
            let query = generate_embedding(2, 4, 42);
            let centroids = vec![0.5f32; 20 * 4]; // 20 centroids

            let config = WarpSearchConfig::with_k(10)
                .nprobe(nprobe)
                .centroid_score_threshold(-10.0); // Accept all

            let selected = CentroidSelector::select(&query, &centroids, 4, &config);

            for token_selection in selected {
                prop_assert!(token_selection.len() <= nprobe as usize);
            }
        }
    }
}