1use rustc_hash::FxHashMap;
11
12use crate::dsl::Field;
13
14use super::{MultiValueCombiner, ScoredPosition, SearchResult};
15
/// Query-side data computed once per rerank call and shared by every scoring batch.
struct PrecompQuery<'a> {
    /// Query vector in full `f32` precision.
    query: &'a [f32],
    /// Precomputed inverse L2 norm of `query`; the caller stores `0.0` here when the
    /// squared norm is below `f32::EPSILON`.
    inv_norm_q: f32,
    /// The same query converted element by element to f16, kept as raw `u16` values.
    query_f16: &'a [u16],
}
22
23#[inline]
25#[allow(clippy::too_many_arguments)]
26fn score_batch_precomp(
27 pq: &PrecompQuery<'_>,
28 raw: &[u8],
29 quant: crate::dsl::DenseVectorQuantization,
30 dim: usize,
31 scores: &mut [f32],
32 unit_norm: bool,
33) {
34 let query = pq.query;
35 let inv_norm_q = pq.inv_norm_q;
36 let query_f16 = pq.query_f16;
37 use crate::dsl::DenseVectorQuantization;
38 use crate::structures::simd;
39 match (quant, unit_norm) {
40 (DenseVectorQuantization::F32, false) => {
41 let num_floats = scores.len() * dim;
42 assert!(
46 (raw.as_ptr() as usize).is_multiple_of(std::mem::align_of::<f32>()),
47 "f32 vector data not 4-byte aligned"
48 );
49 let vectors: &[f32] =
50 unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const f32, num_floats) };
51 simd::batch_cosine_scores_precomp(query, vectors, dim, scores, inv_norm_q);
52 }
53 (DenseVectorQuantization::F32, true) => {
54 let num_floats = scores.len() * dim;
55 assert!(
56 (raw.as_ptr() as usize).is_multiple_of(std::mem::align_of::<f32>()),
57 "f32 vector data not 4-byte aligned"
58 );
59 let vectors: &[f32] =
60 unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const f32, num_floats) };
61 simd::batch_dot_scores_precomp(query, vectors, dim, scores, inv_norm_q);
62 }
63 (DenseVectorQuantization::F16, false) => {
64 simd::batch_cosine_scores_f16_precomp(query_f16, raw, dim, scores, inv_norm_q);
65 }
66 (DenseVectorQuantization::F16, true) => {
67 simd::batch_dot_scores_f16_precomp(query_f16, raw, dim, scores, inv_norm_q);
68 }
69 (DenseVectorQuantization::UInt8, false) => {
70 simd::batch_cosine_scores_u8_precomp(query, raw, dim, scores, inv_norm_q);
71 }
72 (DenseVectorQuantization::UInt8, true) => {
73 simd::batch_dot_scores_u8_precomp(query, raw, dim, scores, inv_norm_q);
74 }
75 }
76}
77
78#[derive(Debug, Clone)]
80pub struct RerankerConfig {
81 pub field: Field,
83 pub vector: Vec<f32>,
85 pub combiner: MultiValueCombiner,
87 pub unit_norm: bool,
90 pub matryoshka_dims: Option<usize>,
97}
98
99#[cfg(test)]
101use crate::structures::simd::cosine_similarity;
/// Test-only reference scorer: cosine-scores every dense-vector value of `config.field`
/// in `doc` against `config.vector`.
///
/// Ordinals are assigned by position among the field's dense-vector values, before values
/// with a mismatching dimension are dropped. Returns `None` when nothing could be scored;
/// otherwise the combined score and the per-value scores sorted best-first.
#[cfg(test)]
fn score_document(
    doc: &crate::dsl::Document,
    config: &RerankerConfig,
) -> Option<(f32, Vec<ScoredPosition>)> {
    let expected_len = config.vector.len();

    let mut per_value: Vec<(u32, f32)> = Vec::new();
    let dense_values = doc
        .get_all(config.field)
        .filter_map(|fv| fv.as_dense_vector())
        .enumerate();
    for (ordinal, candidate) in dense_values {
        if candidate.len() == expected_len {
            let similarity = cosine_similarity(&config.vector, candidate);
            per_value.push((ordinal as u32, similarity));
        }
    }

    if per_value.is_empty() {
        return None;
    }

    // Combine while the values are still in ordinal order, then rank them for reporting.
    let combined = config.combiner.combine(&per_value);
    per_value.sort_by(|lhs, rhs| rhs.1.total_cmp(&lhs.1));

    let mut positions: Vec<ScoredPosition> = Vec::with_capacity(per_value.len());
    for (ordinal, similarity) in per_value {
        positions.push(ScoredPosition::new(ordinal, similarity));
    }

    Some((combined, positions))
}
136
137pub async fn rerank<D: crate::directories::Directory + 'static>(
147 searcher: &crate::index::Searcher<D>,
148 candidates: &[SearchResult],
149 config: &RerankerConfig,
150 final_limit: usize,
151) -> crate::error::Result<Vec<SearchResult>> {
152 if config.vector.is_empty() || candidates.is_empty() {
153 return Ok(Vec::new());
154 }
155
156 let t0 = std::time::Instant::now();
157 let field_id = config.field.0;
158 let query = &config.vector;
159 let query_dim = query.len();
160 let segments = searcher.segment_readers();
161 let seg_by_id = searcher.segment_map();
162
163 use crate::structures::simd;
165 let norm_q_sq = simd::dot_product_f32(query, query, query_dim);
166 let inv_norm_q = if norm_q_sq < f32::EPSILON {
167 0.0
168 } else {
169 simd::fast_inv_sqrt(norm_q_sq)
170 };
171 let query_f16: Vec<u16> = query.iter().map(|&v| simd::f32_to_f16(v)).collect();
172 let pq = PrecompQuery {
173 query,
174 inv_norm_q,
175 query_f16: &query_f16,
176 };
177
178 let mut segment_groups: FxHashMap<usize, Vec<usize>> = FxHashMap::default();
180 let mut skipped = 0u32;
181
182 for (ci, candidate) in candidates.iter().enumerate() {
183 if let Some(&si) = seg_by_id.get(&candidate.segment_id) {
184 segment_groups.entry(si).or_default().push(ci);
185 } else {
186 skipped += 1;
187 }
188 }
189
190 let query_ref = pq.query;
194 let inv_norm_q_val = pq.inv_norm_q;
195 let query_f16_ref = pq.query_f16;
196
197 let segment_futs: Vec<_> = segment_groups
198 .into_iter()
199 .map(|(si, candidate_indices)| {
200 #[allow(clippy::redundant_locals)]
201 let segments = &segments;
202 #[allow(clippy::redundant_locals)]
203 let candidates = candidates;
204 #[allow(clippy::redundant_locals)]
205 let query_ref = query_ref;
206 #[allow(clippy::redundant_locals)]
207 let query_f16_ref = query_f16_ref;
208 #[allow(clippy::redundant_locals)]
209 let config = config;
210 async move {
211 let mut scores: Vec<(usize, u32, f32)> = Vec::new();
212 let mut vectors = 0usize;
213 let mut seg_skipped = 0u32;
214
215 let Some(lazy_flat) = segments[si].flat_vectors().get(&field_id) else {
216 return Ok::<_, crate::error::Error>((
217 scores,
218 vectors,
219 candidate_indices.len() as u32,
220 ));
221 };
222 if lazy_flat.dim != query_dim {
223 return Ok((scores, vectors, candidate_indices.len() as u32));
224 }
225
226 let vbs = lazy_flat.vector_byte_size();
227 let quant = lazy_flat.quantization;
228
229 let mut resolved: Vec<(usize, usize, u32)> = Vec::new();
231 for &ci in &candidate_indices {
232 let local_doc_id = candidates[ci].doc_id;
233 let (start, count) = lazy_flat.flat_indexes_for_doc_range(local_doc_id);
234 if count == 0 {
235 seg_skipped += 1;
236 continue;
237 }
238 for j in 0..count {
239 let (_, ordinal) = lazy_flat.get_doc_id(start + j);
240 resolved.push((ci, start + j, ordinal as u32));
241 }
242 }
243
244 if resolved.is_empty() {
245 return Ok((scores, vectors, seg_skipped));
246 }
247
248 let n = resolved.len();
249 vectors = n;
250
251 resolved.sort_unstable_by_key(|&(_, flat_idx, _)| flat_idx);
253
254 let first_idx = resolved[0].1;
255 let last_idx = resolved[n - 1].1;
256 let span = last_idx - first_idx + 1;
257
258 let mut raw_buf: Vec<u8> = vec![0u8; n * vbs];
259
260 if span <= n * 4 {
261 let range_bytes = lazy_flat
262 .read_vectors_batch(first_idx, span)
263 .await
264 .map_err(crate::error::Error::Io)?;
265 let rb = range_bytes.as_slice();
266 for (buf_idx, &(_, flat_idx, _)) in resolved.iter().enumerate() {
267 let rel = flat_idx - first_idx;
268 let src = &rb[rel * vbs..(rel + 1) * vbs];
269 raw_buf[buf_idx * vbs..(buf_idx + 1) * vbs].copy_from_slice(src);
270 }
271 } else {
272 for (buf_idx, &(_, flat_idx, _)) in resolved.iter().enumerate() {
273 lazy_flat
274 .read_vector_raw_into(
275 flat_idx,
276 &mut raw_buf[buf_idx * vbs..(buf_idx + 1) * vbs],
277 )
278 .await
279 .map_err(crate::error::Error::Io)?;
280 }
281 }
282
283 let pq = PrecompQuery {
285 query: query_ref,
286 inv_norm_q: inv_norm_q_val,
287 query_f16: query_f16_ref,
288 };
289
290 let mut scores_buf: Vec<f32> = vec![0.0; n];
291
292 if let Some(mdims) = config.matryoshka_dims
294 && mdims < query_dim
295 && n > final_limit * 2
296 {
297 let trunc_dim = mdims;
298 let trunc_pq = PrecompQuery {
299 query: &query_ref[..trunc_dim],
300 inv_norm_q: {
301 let nq = simd::dot_product_f32(
302 &query_ref[..trunc_dim],
303 &query_ref[..trunc_dim],
304 trunc_dim,
305 );
306 if nq < f32::EPSILON {
307 0.0
308 } else {
309 simd::fast_inv_sqrt(nq)
310 }
311 },
312 query_f16: &query_f16_ref[..trunc_dim],
313 };
314 let trunc_vbs = trunc_dim * quant.element_size();
315 for i in 0..n {
316 let vec_start = i * vbs;
317 score_batch_precomp(
318 &trunc_pq,
319 &raw_buf[vec_start..vec_start + trunc_vbs],
320 quant,
321 trunc_dim,
322 &mut scores_buf[i..i + 1],
323 config.unit_norm,
324 );
325 }
326
327 let per_doc_cap: usize = match &config.combiner {
328 super::MultiValueCombiner::Max => 1,
329 super::MultiValueCombiner::WeightedTopK { k, .. } => *k,
330 _ => usize::MAX,
331 };
332
333 let mut ranked: Vec<(usize, f32)> =
334 (0..n).map(|i| (i, scores_buf[i])).collect();
335 ranked.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
336
337 let mut survivors: Vec<(usize, f32)> =
338 Vec::with_capacity(n.min(final_limit * 4));
339 let mut doc_vector_counts: FxHashMap<usize, usize> = FxHashMap::default();
340 let mut unique_docs = 0usize;
341
342 for &(orig_idx, score) in &ranked {
343 let ci = resolved[orig_idx].0;
344 let count = doc_vector_counts.entry(ci).or_insert(0);
345
346 if *count >= per_doc_cap {
347 continue;
348 }
349 if *count == 0 {
350 unique_docs += 1;
351 }
352 *count += 1;
353 survivors.push((orig_idx, score));
354
355 if unique_docs >= final_limit && survivors.len() >= final_limit * 2 {
356 break;
357 }
358 }
359
360 scores.reserve(survivors.len());
361 for &(orig_idx, _) in &survivors {
362 let vec_start = orig_idx * vbs;
363 let mut score = 0.0f32;
364 score_batch_precomp(
365 &pq,
366 &raw_buf[vec_start..vec_start + vbs],
367 quant,
368 query_dim,
369 std::slice::from_mut(&mut score),
370 config.unit_norm,
371 );
372 let (ci, _, ordinal) = resolved[orig_idx];
373 scores.push((ci, ordinal, score));
374 }
375
376 let filtered = n - survivors.len();
377 log::debug!(
378 "[reranker] matryoshka pre-filter: {}/{} dims, {}/{} vectors survived from {} unique docs (filtered {}, per_doc_cap={})",
379 trunc_dim,
380 query_dim,
381 survivors.len(),
382 n,
383 unique_docs,
384 filtered,
385 per_doc_cap
386 );
387 } else {
388 score_batch_precomp(
389 &pq,
390 &raw_buf[..n * vbs],
391 quant,
392 query_dim,
393 &mut scores_buf[..n],
394 config.unit_norm,
395 );
396
397 scores.reserve(n);
398 for (buf_idx, &(ci, _, ordinal)) in resolved.iter().enumerate() {
399 scores.push((ci, ordinal, scores_buf[buf_idx]));
400 }
401 }
402
403 Ok((scores, vectors, seg_skipped))
404 }
405 })
406 .collect();
407
408 let results = futures::future::join_all(segment_futs).await;
409
410 let mut all_scores: Vec<(usize, u32, f32)> = Vec::new();
411 let mut total_vectors = 0usize;
412 for result in results {
413 let (scores, vectors, seg_skipped) = result?;
414 all_scores.extend(scores);
415 total_vectors += vectors;
416 skipped += seg_skipped;
417 }
418
419 let read_score_elapsed = t0.elapsed();
420
421 if total_vectors == 0 {
422 log::debug!(
423 "[reranker] field {}: {} candidates, all skipped (no flat vectors)",
424 field_id,
425 candidates.len()
426 );
427 return Ok(Vec::new());
428 }
429
430 all_scores.sort_unstable_by_key(|&(ci, _, _)| ci);
433
434 let mut scored: Vec<SearchResult> = Vec::with_capacity(candidates.len().min(final_limit * 2));
435 let mut i = 0;
436 while i < all_scores.len() {
437 let ci = all_scores[i].0;
438 let run_start = i;
439 while i < all_scores.len() && all_scores[i].0 == ci {
440 i += 1;
441 }
442 let run = &mut all_scores[run_start..i];
443
444 let ordinal_pairs: Vec<(u32, f32)> = run.iter().map(|&(_, ord, s)| (ord, s)).collect();
446 let combined = config.combiner.combine(&ordinal_pairs);
447
448 run.sort_unstable_by(|a, b| b.2.total_cmp(&a.2));
450 let positions: Vec<ScoredPosition> = run
451 .iter()
452 .map(|&(_, ord, score)| ScoredPosition::new(ord, score))
453 .collect();
454
455 scored.push(SearchResult {
456 doc_id: candidates[ci].doc_id,
457 score: combined,
458 segment_id: candidates[ci].segment_id,
459 positions: vec![(field_id, positions)],
460 });
461 }
462
463 scored.sort_by(|a, b| {
464 b.score
465 .partial_cmp(&a.score)
466 .unwrap_or(std::cmp::Ordering::Equal)
467 });
468 scored.truncate(final_limit);
469
470 log::debug!(
471 "[reranker] field {}: {} candidates -> {} results (skipped {}, {} vectors, unit_norm={}): read+score={:.1}ms total={:.1}ms",
472 field_id,
473 candidates.len(),
474 scored.len(),
475 skipped,
476 total_vectors,
477 config.unit_norm,
478 read_score_elapsed.as_secs_f64() * 1000.0,
479 t0.elapsed().as_secs_f64() * 1000.0,
480 );
481
482 Ok(scored)
483}
484
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::{Document, Field};

    /// Tolerance used when comparing floating-point scores.
    const EPS: f32 = 1e-6;

    /// Builds a config on field 0 with `unit_norm` off and no Matryoshka pre-filter.
    fn make_config(vector: Vec<f32>, combiner: MultiValueCombiner) -> RerankerConfig {
        let field = Field(0);
        RerankerConfig {
            field,
            vector,
            combiner,
            unit_norm: false,
            matryoshka_dims: None,
        }
    }

    /// Creates a document holding `vectors` as dense-vector values of `field`, in order.
    fn doc_with_vectors(field: Field, vectors: Vec<Vec<f32>>) -> Document {
        let mut doc = Document::new();
        for v in vectors {
            doc.add_dense_vector(field, v);
        }
        doc
    }

    /// Asserts that `actual` is within `EPS` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < EPS);
    }

    #[test]
    fn test_score_document_single_value() {
        let doc = doc_with_vectors(Field(0), vec![vec![1.0, 0.0, 0.0]]);
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);

        let (score, positions) = score_document(&doc, &config).unwrap();

        // Identical vectors: cosine similarity is 1.
        assert_close(score, 1.0);
        assert_eq!(positions.len(), 1);
        assert_eq!(positions[0].position, 0);
    }

    #[test]
    fn test_score_document_orthogonal() {
        let doc = doc_with_vectors(Field(0), vec![vec![0.0, 1.0, 0.0]]);
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);

        let (score, _) = score_document(&doc, &config).unwrap();

        // Orthogonal vectors: cosine similarity is 0.
        assert!(score.abs() < EPS);
    }

    #[test]
    fn test_score_document_multi_value_max() {
        let doc = doc_with_vectors(
            Field(0),
            vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]],
        );
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);

        let (score, positions) = score_document(&doc, &config).unwrap();

        assert_close(score, 1.0);
        assert_eq!(positions.len(), 2);
        // The best-scoring value (ordinal 0) is listed first.
        assert_eq!(positions[0].position, 0);
        assert_close(positions[0].score, 1.0);
    }

    #[test]
    fn test_score_document_multi_value_avg() {
        let doc = doc_with_vectors(
            Field(0),
            vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]],
        );
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Avg);

        let (score, _) = score_document(&doc, &config).unwrap();

        // Average of 1.0 and 0.0.
        assert_close(score, 0.5);
    }

    #[test]
    fn test_score_document_missing_field() {
        // The vector lives on field 1 while the config targets field 0.
        let doc = doc_with_vectors(Field(1), vec![vec![1.0, 0.0, 0.0]]);
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);

        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_wrong_field_type() {
        let mut doc = Document::new();
        doc.add_text(Field(0), "not a vector");
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);

        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_dimension_mismatch() {
        // 2-d stored vector against a 3-d query.
        let doc = doc_with_vectors(Field(0), vec![vec![1.0, 0.0]]);
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);

        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_empty_query_vector() {
        let doc = doc_with_vectors(Field(0), vec![vec![1.0, 0.0, 0.0]]);
        let config = make_config(vec![], MultiValueCombiner::Max);

        // An empty query never matches the stored dimension.
        assert!(score_document(&doc, &config).is_none());
    }
}