1use crate::multivector::{codec::ResidualCodec, types::WarpSearchConfig, MultiVectorEmbedding};
10use crate::ChunkId;
11use std::collections::HashMap;
12
13pub struct CentroidSelector;
18
19impl CentroidSelector {
20 #[must_use]
34 pub fn select(
35 query: &MultiVectorEmbedding,
36 centroids: &[f32],
37 dim: usize,
38 config: &WarpSearchConfig,
39 ) -> Vec<Vec<(usize, f32)>> {
40 if dim == 0 || centroids.is_empty() {
41 return query.tokens().map(|_| vec![]).collect();
42 }
43 let num_centroids = centroids.len() / dim;
44
45 query
46 .tokens()
47 .map(|query_token| {
48 let mut scores: Vec<(usize, f32)> = (0..num_centroids)
50 .map(|c| {
51 let centroid = ¢roids[c * dim..(c + 1) * dim];
52 let score = Self::dot_product(query_token, centroid);
53 (c, score)
54 })
55 .collect();
56
57 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
59
60 scores
62 .into_iter()
63 .take(config.nprobe as usize)
64 .filter(|(_, score)| *score >= config.centroid_score_threshold)
65 .collect()
66 })
67 .collect()
68 }
69
70 #[must_use]
74 pub fn batch_scores(query_token: &[f32], centroids: &[f32], dim: usize) -> Vec<(usize, f32)> {
75 if dim == 0 || centroids.is_empty() {
76 return vec![];
77 }
78 let num_centroids = centroids.len() / dim;
79
80 let mut scores: Vec<(usize, f32)> = (0..num_centroids)
81 .map(|c| {
82 let centroid = ¢roids[c * dim..(c + 1) * dim];
83 let score = Self::dot_product(query_token, centroid);
84 (c, score)
85 })
86 .collect();
87
88 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
89 scores
90 }
91
92 fn dot_product(a: &[f32], b: &[f32]) -> f32 {
93 a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
94 }
95}
96
97pub struct CandidateScorer;
102
103impl CandidateScorer {
104 #[must_use]
123 #[allow(clippy::too_many_arguments)]
124 pub fn score(
125 query_token: &[f32],
126 centroid_id: usize,
127 centroid_score: f32,
128 codec: &ResidualCodec,
129 sizes: &[usize],
130 offsets: &[usize],
131 chunk_ids: &[ChunkId],
132 token_indices: &[u16],
133 residuals: &[u8],
134 bytes_per_residual: usize,
135 ) -> Vec<(ChunkId, u16, f32)> {
136 let size = sizes.get(centroid_id).copied().unwrap_or(0);
137 if size == 0 {
138 return Vec::new();
139 }
140
141 let offset = offsets.get(centroid_id).copied().unwrap_or(0);
142
143 (0..size)
144 .map(|i| {
145 let idx = offset + i;
146 let chunk_id = chunk_ids[idx];
147 let token_idx = token_indices[idx];
148
149 let residual_start = idx * bytes_per_residual;
150 let residual_end = residual_start + bytes_per_residual;
151 let residual = &residuals[residual_start..residual_end];
152
153 let score =
154 codec.decompress_score(query_token, centroid_id, centroid_score, residual);
155
156 (chunk_id, token_idx, score)
157 })
158 .collect()
159 }
160
161 #[must_use]
163 pub fn score_single(
164 query_token: &[f32],
165 centroid_id: usize,
166 centroid_score: f32,
167 codec: &ResidualCodec,
168 residual: &[u8],
169 ) -> f32 {
170 codec.decompress_score(query_token, centroid_id, centroid_score, residual)
171 }
172}
173
174pub struct ScoreMerger;
181
182impl ScoreMerger {
183 #[must_use]
194 pub fn merge(token_scores: Vec<Vec<(ChunkId, u16, f32)>>, k: usize) -> Vec<(ChunkId, f32)> {
195 if token_scores.is_empty() {
196 return Vec::new();
197 }
198
199 let num_query_tokens = token_scores.len();
200
201 let mut doc_token_maxes: HashMap<ChunkId, Vec<f32>> = HashMap::new();
203
204 for (query_token_idx, scores) in token_scores.into_iter().enumerate() {
205 for (chunk_id, _doc_token_idx, score) in scores {
206 let maxes = doc_token_maxes
207 .entry(chunk_id)
208 .or_insert_with(|| vec![f32::NEG_INFINITY; num_query_tokens]);
209
210 if score > maxes[query_token_idx] {
211 maxes[query_token_idx] = score;
212 }
213 }
214 }
215
216 let mut doc_scores: Vec<(ChunkId, f32)> = doc_token_maxes
218 .into_iter()
219 .map(|(chunk_id, maxes)| {
220 let score: f32 = maxes.into_iter().filter(|&s| s > f32::NEG_INFINITY).sum();
221 (chunk_id, score)
222 })
223 .collect();
224
225 doc_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
227
228 doc_scores.truncate(k);
230 doc_scores
231 }
232
233 #[must_use]
237 pub fn merge_single_doc(token_max_scores: &[f32]) -> f32 {
238 token_max_scores.iter().filter(|&&s| s > f32::NEG_INFINITY).sum()
239 }
240}
241
242#[must_use]
247pub fn exact_maxsim(query: &MultiVectorEmbedding, doc: &MultiVectorEmbedding) -> f32 {
248 query
249 .tokens()
250 .map(|q| doc.tokens().map(|d| dot_product(q, d)).fold(f32::NEG_INFINITY, f32::max))
251 .filter(|&s| s > f32::NEG_INFINITY)
252 .sum()
253}
254
/// Inner product of two vectors, pairing elements by position.
///
/// If the slices differ in length, the extra tail of the longer one is ignored.
#[inline]
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
260
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Deterministic pseudo-random embedding (64-bit LCG) with components in `[-1, 1]`.
    fn generate_embedding(num_tokens: usize, dim: usize, seed: u64) -> MultiVectorEmbedding {
        let mut embeddings = Vec::with_capacity(num_tokens * dim);
        let mut rng = seed;

        for _ in 0..(num_tokens * dim) {
            rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1);
            // The top 32 bits span the full `u32` range, so `unit` covers [0, 1]
            // and the component covers [-1, 1]. A shift of 33 would cap `unit`
            // below 0.5 and make every component negative.
            let unit = (rng >> 32) as f32 / u32::MAX as f32;
            embeddings.push(unit * 2.0 - 1.0);
        }

        MultiVectorEmbedding::new(embeddings, num_tokens, dim)
    }

    fn chunk_id(n: u128) -> ChunkId {
        ChunkId(uuid::Uuid::from_u128(n))
    }

    #[test]
    fn test_centroid_selector_basic() {
        let query = generate_embedding(2, 4, 42);

        let centroids = vec![
            1.0, 0.0, 0.0, 0.0, // centroid 0
            0.0, 1.0, 0.0, 0.0, // centroid 1
            0.0, 0.0, 1.0, 0.0, // centroid 2
            0.0, 0.0, 0.0, 1.0, // centroid 3
        ];

        let config = WarpSearchConfig::with_k(10).nprobe(2).centroid_score_threshold(-1.0);
        let selected = CentroidSelector::select(&query, &centroids, 4, &config);

        // One selection per query token, each capped at nprobe.
        assert_eq!(selected.len(), 2);
        assert!(selected.iter().all(|s| s.len() <= 2));
    }

    #[test]
    fn test_centroid_selector_threshold() {
        let query = MultiVectorEmbedding::new(vec![1.0, 0.0, 0.0, 0.0], 1, 4);

        let centroids = vec![
            1.0, 0.0, 0.0, 0.0, // score 1.0
            0.0, 1.0, 0.0, 0.0, // score 0.0
            0.5, 0.5, 0.0, 0.0, // score 0.5
            0.0, 0.0, 1.0, 0.0, // score 0.0
        ];

        let config = WarpSearchConfig::with_k(10).nprobe(4).centroid_score_threshold(0.4);

        let selected = CentroidSelector::select(&query, &centroids, 4, &config);

        assert_eq!(selected.len(), 1);
        // Only centroids 0 (1.0) and 2 (0.5) clear the 0.4 threshold, best first.
        let ids: Vec<usize> = selected[0].iter().map(|&(c, _)| c).collect();
        assert_eq!(ids, vec![0, 2]);
    }

    #[test]
    fn test_centroid_selector_sorted() {
        let query = MultiVectorEmbedding::new(vec![0.5, 0.5, 0.0, 0.0], 1, 4);

        let centroids = vec![
            1.0, 0.0, 0.0, 0.0, // centroid 0
            0.0, 1.0, 0.0, 0.0, // centroid 1
            0.5, 0.5, 0.0, 0.0, // centroid 2
            0.0, 0.0, 1.0, 0.0, // centroid 3
        ];

        let config = WarpSearchConfig::with_k(10).nprobe(4).centroid_score_threshold(-1.0);

        let selected = CentroidSelector::select(&query, &centroids, 4, &config);

        assert!(!selected[0].is_empty());
        for i in 1..selected[0].len() {
            assert!(selected[0][i - 1].1 >= selected[0][i].1);
        }
    }

    #[test]
    fn test_centroid_selector_dim_zero_no_panic() {
        let query = MultiVectorEmbedding::from_tokens(&[]);
        let centroids: Vec<f32> = vec![];
        let config = WarpSearchConfig::with_k(10);

        let selected = CentroidSelector::select(&query, &centroids, 0, &config);
        assert!(selected.is_empty());
    }

    #[test]
    fn test_batch_scores_dim_zero_no_panic() {
        let scores = CentroidSelector::batch_scores(&[], &[], 0);
        assert!(scores.is_empty());
    }

    #[test]
    fn test_batch_scores() {
        let query_token = vec![1.0, 0.0, 0.0, 0.0];
        let centroids = vec![
            1.0, 0.0, 0.0, 0.0, // centroid 0
            0.0, 1.0, 0.0, 0.0, // centroid 1
        ];

        let scores = CentroidSelector::batch_scores(&query_token, &centroids, 4);

        assert_eq!(scores.len(), 2);
        // Centroid 0 is aligned with the query and must rank first with score 1.
        assert_eq!(scores[0].0, 0);
        assert!((scores[0].1 - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_candidate_scorer_empty_centroid() {
        let query_token = vec![1.0, 0.0, 0.0, 0.0];
        let codec = create_test_codec();

        // Centroid 0 has no postings, so none of the (empty) index arrays are read.
        let sizes = vec![0, 5, 3];
        let offsets = vec![0, 0, 5];
        let chunk_ids: Vec<ChunkId> = vec![];
        let token_indices: Vec<u16> = vec![];
        let residuals: Vec<u8> = vec![];

        let results = CandidateScorer::score(
            &query_token,
            0,   // centroid_id
            0.5, // centroid_score
            &codec,
            &sizes,
            &offsets,
            &chunk_ids,
            &token_indices,
            &residuals,
            2, // bytes_per_residual
        );

        assert!(results.is_empty());
    }

    fn create_test_codec() -> ResidualCodec {
        let embeddings = vec![0.0f32; 200 * 4];
        ResidualCodec::train(&embeddings, 4, 4, 2, 3).unwrap()
    }

    #[test]
    fn test_score_merger_basic() {
        let token_scores = vec![
            vec![(chunk_id(1), 0, 0.9), (chunk_id(2), 0, 0.8), (chunk_id(1), 1, 0.7)],
            vec![(chunk_id(1), 0, 0.6), (chunk_id(2), 0, 0.5), (chunk_id(3), 0, 0.4)],
        ];

        let results = ScoreMerger::merge(token_scores, 10);

        assert_eq!(results.len(), 3);
        // chunk 1: max(0.9, 0.7) + 0.6 = 1.5; chunk 2: 0.8 + 0.5 = 1.3; chunk 3: 0.4.
        assert_eq!(results[0].0, chunk_id(1));
        assert!((results[0].1 - 1.5).abs() < 0.001);
    }

    #[test]
    fn test_score_merger_empty() {
        let token_scores: Vec<Vec<(ChunkId, u16, f32)>> = vec![];
        let results = ScoreMerger::merge(token_scores, 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_score_merger_k_zero() {
        let token_scores = vec![vec![(chunk_id(1), 0, 0.9), (chunk_id(2), 0, 0.8)]];
        let results = ScoreMerger::merge(token_scores, 0);
        assert!(results.is_empty());
    }

    #[test]
    fn test_score_merger_respects_k() {
        let token_scores = vec![vec![
            (chunk_id(1), 0, 0.9),
            (chunk_id(2), 0, 0.8),
            (chunk_id(3), 0, 0.7),
            (chunk_id(4), 0, 0.6),
            (chunk_id(5), 0, 0.5),
        ]];

        let results = ScoreMerger::merge(token_scores, 3);
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_score_merger_sorted_descending() {
        let token_scores =
            vec![vec![(chunk_id(1), 0, 0.3), (chunk_id(2), 0, 0.9), (chunk_id(3), 0, 0.6)]];

        let results = ScoreMerger::merge(token_scores, 10);

        assert_eq!(results[0].0, chunk_id(2));
        assert_eq!(results[1].0, chunk_id(3));
        assert_eq!(results[2].0, chunk_id(1));
    }

    #[test]
    fn test_merge_single_doc() {
        let scores = vec![0.9, 0.6, f32::NEG_INFINITY, 0.3];
        let total = ScoreMerger::merge_single_doc(&scores);

        // The NEG_INFINITY (unmatched) entry is skipped: 0.9 + 0.6 + 0.3.
        assert!((total - 1.8).abs() < 0.001);
    }

    #[test]
    fn test_exact_maxsim_identical() {
        let emb = generate_embedding(3, 4, 42);
        let score = exact_maxsim(&emb, &emb);

        // Each token matches itself with its squared norm, so the sum is positive.
        assert!(score > 0.0);
    }

    #[test]
    fn test_exact_maxsim_orthogonal() {
        let query = MultiVectorEmbedding::new(vec![1.0, 0.0, 0.0, 0.0], 1, 4);
        let doc = MultiVectorEmbedding::new(vec![0.0, 1.0, 0.0, 0.0], 1, 4);

        let score = exact_maxsim(&query, &doc);
        assert!((score - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_exact_maxsim_aligned() {
        let query = MultiVectorEmbedding::new(vec![1.0, 0.0, 0.0, 0.0], 1, 4);
        let doc = MultiVectorEmbedding::new(vec![1.0, 0.0, 0.0, 0.0], 1, 4);

        let score = exact_maxsim(&query, &doc);
        assert!((score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_exact_maxsim_empty_doc() {
        let query = MultiVectorEmbedding::new(vec![1.0, 0.0, 0.0, 0.0], 1, 4);
        let doc = MultiVectorEmbedding::from_tokens(&[]);

        // No document tokens: every query token is unmatched and contributes nothing.
        let score = exact_maxsim(&query, &doc);
        assert!(score.abs() < 1e-6);
    }

    proptest! {
        #[test]
        fn prop_maxsim_is_finite(
            num_q in 1usize..5,
            num_d in 1usize..5,
            q_seed in any::<u64>(),
            d_seed in any::<u64>()
        ) {
            let query = generate_embedding(num_q, 4, q_seed);
            let doc = generate_embedding(num_d, 4, d_seed);

            let score = exact_maxsim(&query, &doc);

            prop_assert!(score.is_finite());
        }

        #[test]
        fn prop_merger_results_count_bounded_by_k(
            k in 1usize..20,
            num_docs in 1usize..50
        ) {
            let token_scores = vec![
                (0..num_docs)
                    .map(|i| (chunk_id(i as u128), 0u16, i as f32 / 100.0))
                    .collect()
            ];

            let results = ScoreMerger::merge(token_scores, k);
            prop_assert!(results.len() <= k);
            prop_assert!(results.len() <= num_docs);
        }

        #[test]
        fn prop_centroid_selector_respects_nprobe(
            nprobe in 1u32..10
        ) {
            let query = generate_embedding(2, 4, 42);
            // 20 identical centroids; the low threshold admits all of them.
            let centroids = vec![0.5f32; 20 * 4];
            let config = WarpSearchConfig::with_k(10)
                .nprobe(nprobe)
                .centroid_score_threshold(-10.0);

            let selected = CentroidSelector::select(&query, &centroids, 4, &config);

            for token_selection in selected {
                prop_assert!(token_selection.len() <= nprobe as usize);
            }
        }
    }
}