1use rustc_hash::FxHashMap;
11
12use crate::dsl::Field;
13
14use super::{MultiValueCombiner, ScoredPosition, SearchResult};
15
/// Query vector plus per-query values computed once and reused for every
/// scoring call made while reranking.
struct PrecompQuery<'a> {
    /// Query in f32; consumed by the F32 and UInt8 scoring kernels.
    query: &'a [f32],
    /// Inverse L2 norm of `query` as computed by the caller (`rerank` uses
    /// `simd::fast_inv_sqrt`, and `0.0` when the squared norm is below
    /// `f32::EPSILON`).
    inv_norm_q: f32,
    /// `query` converted element-wise with `simd::f32_to_f16`; consumed by
    /// the F16 scoring kernels.
    query_f16: &'a [u16],
}
22
23#[inline]
25#[allow(clippy::too_many_arguments)]
26fn score_batch_precomp(
27 pq: &PrecompQuery<'_>,
28 raw: &[u8],
29 quant: crate::dsl::DenseVectorQuantization,
30 dim: usize,
31 scores: &mut [f32],
32 unit_norm: bool,
33) {
34 let query = pq.query;
35 let inv_norm_q = pq.inv_norm_q;
36 let query_f16 = pq.query_f16;
37 use crate::dsl::DenseVectorQuantization;
38 use crate::structures::simd;
39 match (quant, unit_norm) {
40 (DenseVectorQuantization::F32, false) => {
41 let num_floats = scores.len() * dim;
42 assert!(
46 (raw.as_ptr() as usize).is_multiple_of(std::mem::align_of::<f32>()),
47 "f32 vector data not 4-byte aligned"
48 );
49 let vectors: &[f32] =
50 unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const f32, num_floats) };
51 simd::batch_cosine_scores_precomp(query, vectors, dim, scores, inv_norm_q);
52 }
53 (DenseVectorQuantization::F32, true) => {
54 let num_floats = scores.len() * dim;
55 assert!(
56 (raw.as_ptr() as usize).is_multiple_of(std::mem::align_of::<f32>()),
57 "f32 vector data not 4-byte aligned"
58 );
59 let vectors: &[f32] =
60 unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const f32, num_floats) };
61 simd::batch_dot_scores_precomp(query, vectors, dim, scores, inv_norm_q);
62 }
63 (DenseVectorQuantization::F16, false) => {
64 simd::batch_cosine_scores_f16_precomp(query_f16, raw, dim, scores, inv_norm_q);
65 }
66 (DenseVectorQuantization::F16, true) => {
67 simd::batch_dot_scores_f16_precomp(query_f16, raw, dim, scores, inv_norm_q);
68 }
69 (DenseVectorQuantization::UInt8, false) => {
70 simd::batch_cosine_scores_u8_precomp(query, raw, dim, scores, inv_norm_q);
71 }
72 (DenseVectorQuantization::UInt8, true) => {
73 simd::batch_dot_scores_u8_precomp(query, raw, dim, scores, inv_norm_q);
74 }
75 }
76}
77
/// Configuration for [`rerank`]: which field to read, the query vector, and
/// how a document's per-value scores are turned into one document score.
#[derive(Debug, Clone)]
pub struct RerankerConfig {
    /// Dense-vector field whose stored flat vectors are rescored. Segments
    /// whose stored dimension differs from `vector.len()` are skipped.
    pub field: Field,
    /// Query vector. An empty vector makes [`rerank`] return no results.
    pub vector: Vec<f32>,
    /// Combines the `(ordinal, score)` pairs of one document into the
    /// document's score. The Matryoshka pre-filter also reads it to decide
    /// how many vectors per document to keep.
    pub combiner: MultiValueCombiner,
    /// `true` selects the dot-product scoring kernels instead of the cosine
    /// kernels (presumably only correct when stored vectors are already
    /// unit-normalised — confirm with the indexing side).
    pub unit_norm: bool,
    /// Number of leading dimensions used for an optional truncated
    /// (Matryoshka) pre-scoring pass; see [`rerank`] for when it applies.
    pub matryoshka_dims: Option<usize>,
}
98
99#[cfg(test)]
101use crate::structures::simd::cosine_similarity;
/// Reference (non-batched) scorer for one in-memory document, used by tests.
///
/// Every dense-vector value of `config.field` whose length equals the query
/// length is scored by cosine similarity. The scores are combined with
/// `config.combiner`, and per-value positions are returned best-first.
/// Returns `None` when the query is empty (as [`rerank`] does) or when no
/// value matches.
#[cfg(test)]
fn score_document(
    doc: &crate::dsl::Document,
    config: &RerankerConfig,
) -> Option<(f32, Vec<ScoredPosition>)> {
    // Mirror `rerank`, which produces nothing for an empty query; otherwise an
    // empty stored value would pass the length check and be scored.
    if config.vector.is_empty() {
        return None;
    }
    let query_dim = config.vector.len();

    // Ordinals are assigned before the dimension check, so they keep indexing
    // the field's dense-vector values even when some are skipped.
    let mut values: Vec<(u32, f32)> = doc
        .get_all(config.field)
        .filter_map(|fv| fv.as_dense_vector())
        .enumerate()
        .filter_map(|(ordinal, vec)| {
            (vec.len() == query_dim)
                .then(|| (ordinal as u32, cosine_similarity(&config.vector, vec)))
        })
        .collect();

    if values.is_empty() {
        return None;
    }

    let combined = config.combiner.combine(&values);

    values.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
    let positions: Vec<ScoredPosition> = values
        .into_iter()
        .map(|(ordinal, score)| ScoredPosition::new(ordinal, score))
        .collect();

    Some((combined, positions))
}
136
137pub async fn rerank<D: crate::directories::Directory + 'static>(
147 searcher: &crate::index::Searcher<D>,
148 candidates: &[SearchResult],
149 config: &RerankerConfig,
150 final_limit: usize,
151) -> crate::error::Result<Vec<SearchResult>> {
152 if config.vector.is_empty() || candidates.is_empty() {
153 return Ok(Vec::new());
154 }
155
156 let t0 = std::time::Instant::now();
157 let field_id = config.field.0;
158 let query = &config.vector;
159 let query_dim = query.len();
160 let segments = searcher.segment_readers();
161 let seg_by_id = searcher.segment_map();
162
163 use crate::structures::simd;
165 let norm_q_sq = simd::dot_product_f32(query, query, query_dim);
166 let inv_norm_q = if norm_q_sq < f32::EPSILON {
167 0.0
168 } else {
169 simd::fast_inv_sqrt(norm_q_sq)
170 };
171 let query_f16: Vec<u16> = query.iter().map(|&v| simd::f32_to_f16(v)).collect();
172 let pq = PrecompQuery {
173 query,
174 inv_norm_q,
175 query_f16: &query_f16,
176 };
177
178 let mut segment_groups: FxHashMap<usize, Vec<usize>> = FxHashMap::default();
180 let mut skipped = 0u32;
181
182 for (ci, candidate) in candidates.iter().enumerate() {
183 if let Some(&si) = seg_by_id.get(&candidate.segment_id) {
184 segment_groups.entry(si).or_default().push(ci);
185 } else {
186 skipped += 1;
187 }
188 }
189
190 let query_ref = pq.query;
194 let inv_norm_q_val = pq.inv_norm_q;
195 let query_f16_ref = pq.query_f16;
196
197 let segment_futs: Vec<_> = segment_groups
198 .into_iter()
199 .map(|(si, candidate_indices)| {
200 #[allow(clippy::redundant_locals)]
201 let segments = &segments;
202 #[allow(clippy::redundant_locals)]
203 let candidates = candidates;
204 #[allow(clippy::redundant_locals)]
205 let query_ref = query_ref;
206 #[allow(clippy::redundant_locals)]
207 let query_f16_ref = query_f16_ref;
208 #[allow(clippy::redundant_locals)]
209 let config = config;
210 async move {
211 let mut scores: Vec<(usize, u32, f32)> = Vec::new();
212 let mut vectors = 0usize;
213 let mut seg_skipped = 0u32;
214
215 let Some(lazy_flat) = segments[si].flat_vectors().get(&field_id) else {
216 return Ok::<_, crate::error::Error>((
217 scores,
218 vectors,
219 candidate_indices.len() as u32,
220 ));
221 };
222 if lazy_flat.dim != query_dim {
223 return Ok((scores, vectors, candidate_indices.len() as u32));
224 }
225
226 let vbs = lazy_flat.vector_byte_size();
227 let quant = lazy_flat.quantization;
228
229 let mut resolved: Vec<(usize, usize, u32)> = Vec::new();
231 for &ci in &candidate_indices {
232 let local_doc_id = candidates[ci].doc_id;
233 let (start, count) = lazy_flat.flat_indexes_for_doc_range(local_doc_id);
234 if count == 0 {
235 seg_skipped += 1;
236 continue;
237 }
238 for j in 0..count {
239 let (_, ordinal) = lazy_flat.get_doc_id(start + j);
240 resolved.push((ci, start + j, ordinal as u32));
241 }
242 }
243
244 if resolved.is_empty() {
245 return Ok((scores, vectors, seg_skipped));
246 }
247
248 let n = resolved.len();
249 vectors = n;
250
251 resolved.sort_unstable_by_key(|&(_, flat_idx, _)| flat_idx);
253
254 let first_idx = resolved[0].1;
255 let last_idx = resolved[n - 1].1;
256 let span = last_idx - first_idx + 1;
257
258 let mut raw_buf: Vec<u8> = vec![0u8; n * vbs];
259
260 if span <= n * 4 {
261 let range_bytes = lazy_flat
262 .read_vectors_batch(first_idx, span)
263 .await
264 .map_err(crate::error::Error::Io)?;
265 let rb = range_bytes.as_slice();
266 for (buf_idx, &(_, flat_idx, _)) in resolved.iter().enumerate() {
267 let rel = flat_idx - first_idx;
268 let src = &rb[rel * vbs..(rel + 1) * vbs];
269 raw_buf[buf_idx * vbs..(buf_idx + 1) * vbs].copy_from_slice(src);
270 }
271 } else {
272 for (buf_idx, &(_, flat_idx, _)) in resolved.iter().enumerate() {
273 lazy_flat
274 .read_vector_raw_into(
275 flat_idx,
276 &mut raw_buf[buf_idx * vbs..(buf_idx + 1) * vbs],
277 )
278 .await
279 .map_err(crate::error::Error::Io)?;
280 }
281 }
282
283 let pq = PrecompQuery {
285 query: query_ref,
286 inv_norm_q: inv_norm_q_val,
287 query_f16: query_f16_ref,
288 };
289
290 let mut scores_buf: Vec<f32> = vec![0.0; n];
291
292 if let Some(mdims) = config.matryoshka_dims
294 && mdims < query_dim
295 && n > final_limit * 2
296 {
297 let trunc_dim = mdims;
298 let trunc_pq = PrecompQuery {
299 query: &query_ref[..trunc_dim],
300 inv_norm_q: {
301 let nq = simd::dot_product_f32(
302 &query_ref[..trunc_dim],
303 &query_ref[..trunc_dim],
304 trunc_dim,
305 );
306 if nq < f32::EPSILON {
307 0.0
308 } else {
309 simd::fast_inv_sqrt(nq)
310 }
311 },
312 query_f16: &query_f16_ref[..trunc_dim],
313 };
314 let trunc_vbs = trunc_dim * quant.element_size();
315 for i in 0..n {
316 let vec_start = i * vbs;
317 score_batch_precomp(
318 &trunc_pq,
319 &raw_buf[vec_start..vec_start + trunc_vbs],
320 quant,
321 trunc_dim,
322 &mut scores_buf[i..i + 1],
323 config.unit_norm,
324 );
325 }
326
327 let per_doc_cap: usize = match &config.combiner {
328 super::MultiValueCombiner::Max => 1,
329 super::MultiValueCombiner::WeightedTopK { k, .. } => *k,
330 _ => usize::MAX,
331 };
332
333 let mut ranked: Vec<(usize, f32)> =
334 (0..n).map(|i| (i, scores_buf[i])).collect();
335 ranked.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
336
337 let mut survivors: Vec<(usize, f32)> =
338 Vec::with_capacity(n.min(final_limit * 4));
339 let mut doc_vector_counts: FxHashMap<usize, usize> = FxHashMap::default();
340 let mut unique_docs = 0usize;
341
342 for &(orig_idx, score) in &ranked {
343 let ci = resolved[orig_idx].0;
344 let count = doc_vector_counts.entry(ci).or_insert(0);
345
346 if *count >= per_doc_cap {
347 continue;
348 }
349 if *count == 0 {
350 unique_docs += 1;
351 }
352 *count += 1;
353 survivors.push((orig_idx, score));
354
355 if unique_docs >= final_limit && survivors.len() >= final_limit * 2 {
356 break;
357 }
358 }
359
360 scores.reserve(survivors.len());
361 for &(orig_idx, _) in &survivors {
362 let vec_start = orig_idx * vbs;
363 let mut score = 0.0f32;
364 score_batch_precomp(
365 &pq,
366 &raw_buf[vec_start..vec_start + vbs],
367 quant,
368 query_dim,
369 std::slice::from_mut(&mut score),
370 config.unit_norm,
371 );
372 let (ci, _, ordinal) = resolved[orig_idx];
373 scores.push((ci, ordinal, score));
374 }
375
376 let filtered = n - survivors.len();
377 log::debug!(
378 "[reranker] matryoshka pre-filter: {}/{} dims, {}/{} vectors survived from {} unique docs (filtered {}, per_doc_cap={})",
379 trunc_dim,
380 query_dim,
381 survivors.len(),
382 n,
383 unique_docs,
384 filtered,
385 per_doc_cap
386 );
387 } else {
388 score_batch_precomp(
389 &pq,
390 &raw_buf[..n * vbs],
391 quant,
392 query_dim,
393 &mut scores_buf[..n],
394 config.unit_norm,
395 );
396
397 scores.reserve(n);
398 for (buf_idx, &(ci, _, ordinal)) in resolved.iter().enumerate() {
399 scores.push((ci, ordinal, scores_buf[buf_idx]));
400 }
401 }
402
403 Ok((scores, vectors, seg_skipped))
404 }
405 })
406 .collect();
407
408 let results = futures::future::join_all(segment_futs).await;
409
410 let mut all_scores: Vec<(usize, u32, f32)> = Vec::new();
411 let mut total_vectors = 0usize;
412 for result in results {
413 let (scores, vectors, seg_skipped) = result?;
414 all_scores.extend(scores);
415 total_vectors += vectors;
416 skipped += seg_skipped;
417 }
418
419 let read_score_elapsed = t0.elapsed();
420
421 if total_vectors == 0 {
422 log::debug!(
423 "[reranker] field {}: {} candidates, all skipped (no flat vectors)",
424 field_id,
425 candidates.len()
426 );
427 return Ok(Vec::new());
428 }
429
430 all_scores.sort_unstable_by_key(|&(ci, _, _)| ci);
433
434 let mut scored: Vec<SearchResult> = Vec::with_capacity(candidates.len().min(final_limit * 2));
435 let mut ordinal_pairs: Vec<(u32, f32)> = Vec::new();
436 let mut i = 0;
437 while i < all_scores.len() {
438 let ci = all_scores[i].0;
439 let run_start = i;
440 while i < all_scores.len() && all_scores[i].0 == ci {
441 i += 1;
442 }
443 let run = &mut all_scores[run_start..i];
444
445 ordinal_pairs.clear();
447 ordinal_pairs.extend(run.iter().map(|&(_, ord, s)| (ord, s)));
448 let combined = config.combiner.combine(&ordinal_pairs);
449
450 run.sort_unstable_by(|a, b| b.2.total_cmp(&a.2));
452 let positions: Vec<ScoredPosition> = run
453 .iter()
454 .map(|&(_, ord, score)| ScoredPosition::new(ord, score))
455 .collect();
456
457 scored.push(SearchResult {
458 doc_id: candidates[ci].doc_id,
459 score: combined,
460 segment_id: candidates[ci].segment_id,
461 positions: vec![(field_id, positions)],
462 });
463 }
464
465 scored.sort_unstable_by(|a, b| {
466 b.score
467 .partial_cmp(&a.score)
468 .unwrap_or(std::cmp::Ordering::Equal)
469 });
470 scored.truncate(final_limit);
471
472 log::debug!(
473 "[reranker] field {}: {} candidates -> {} results (skipped {}, {} vectors, unit_norm={}): read+score={:.1}ms total={:.1}ms",
474 field_id,
475 candidates.len(),
476 scored.len(),
477 skipped,
478 total_vectors,
479 config.unit_norm,
480 read_score_elapsed.as_secs_f64() * 1000.0,
481 t0.elapsed().as_secs_f64() * 1000.0,
482 );
483
484 Ok(scored)
485}
486
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::{Document, Field};

    /// Config on field 0 with cosine scoring and no Matryoshka pre-filter.
    fn make_config(vector: Vec<f32>, combiner: MultiValueCombiner) -> RerankerConfig {
        RerankerConfig {
            field: Field(0),
            vector,
            combiner,
            unit_norm: false,
            matryoshka_dims: None,
        }
    }

    #[test]
    fn test_score_document_single_value() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]);

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        let (score, positions) = score_document(&doc, &config).unwrap();
        assert!((score - 1.0).abs() < 1e-6);
        assert_eq!(positions.len(), 1);
        assert_eq!(positions[0].position, 0);
    }

    #[test]
    fn test_score_document_orthogonal() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![0.0, 1.0, 0.0]);

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        let (score, _) = score_document(&doc, &config).unwrap();
        assert!(score.abs() < 1e-6);
    }

    #[test]
    fn test_score_document_multi_value_max() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]);
        doc.add_dense_vector(Field(0), vec![0.0, 1.0, 0.0]);

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        let (score, positions) = score_document(&doc, &config).unwrap();
        assert!((score - 1.0).abs() < 1e-6);
        assert_eq!(positions.len(), 2);
        assert_eq!(positions[0].position, 0);
        assert!((positions[0].score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_score_document_multi_value_avg() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]);
        doc.add_dense_vector(Field(0), vec![0.0, 1.0, 0.0]);

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Avg);
        let (score, _) = score_document(&doc, &config).unwrap();
        assert!((score - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_score_document_positions_sorted_best_first() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![0.0, 1.0, 0.0]);
        doc.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]);

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        let (_, positions) = score_document(&doc, &config).unwrap();
        let order: Vec<_> = positions.iter().map(|p| p.position).collect();
        assert_eq!(order, vec![1, 0]);
    }

    #[test]
    fn test_score_document_keeps_ordinals_when_skipping_mismatched_values() {
        let mut doc = Document::new();
        // Wrong dimension: skipped, but it still occupies ordinal 0.
        doc.add_dense_vector(Field(0), vec![1.0, 0.0]);
        doc.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]);

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        let (score, positions) = score_document(&doc, &config).unwrap();
        assert!((score - 1.0).abs() < 1e-6);
        assert_eq!(positions.len(), 1);
        assert_eq!(positions[0].position, 1);
    }

    #[test]
    fn test_score_document_missing_field() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(1), vec![1.0, 0.0, 0.0]);

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_wrong_field_type() {
        let mut doc = Document::new();
        doc.add_text(Field(0), "not a vector");

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_dimension_mismatch() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![1.0, 0.0]);

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_empty_query_vector() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]);

        let config = make_config(vec![], MultiValueCombiner::Max);
        assert!(score_document(&doc, &config).is_none());
    }
}