1use rustc_hash::FxHashMap;
11
12use crate::dsl::Field;
13
14use super::{MultiValueCombiner, ScoredPosition, SearchResult};
15
/// Query-side values computed once and shared by every scoring batch.
struct PrecompQuery<'a> {
    /// Query vector as `f32` components; passed to the F32 and UInt8 kernels.
    query: &'a [f32],
    /// Scale factor forwarded to every scoring kernel. `rerank` sets it to
    /// `simd::fast_inv_sqrt` of the query's squared norm, or `0.0` when that
    /// squared norm is below `f32::EPSILON`.
    inv_norm_q: f32,
    /// The query converted component-wise with `simd::f32_to_f16`; passed to
    /// the F16 kernels.
    query_f16: &'a [u16],
}
22
23#[inline]
25#[allow(clippy::too_many_arguments)]
26fn score_batch_precomp(
27 pq: &PrecompQuery<'_>,
28 raw: &[u8],
29 quant: crate::dsl::DenseVectorQuantization,
30 dim: usize,
31 scores: &mut [f32],
32 unit_norm: bool,
33) {
34 let query = pq.query;
35 let inv_norm_q = pq.inv_norm_q;
36 let query_f16 = pq.query_f16;
37 use crate::dsl::DenseVectorQuantization;
38 use crate::structures::simd;
39 match (quant, unit_norm) {
40 (DenseVectorQuantization::F32, false) => {
41 let num_floats = scores.len() * dim;
42 let vectors: &[f32] =
43 unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const f32, num_floats) };
44 simd::batch_cosine_scores_precomp(query, vectors, dim, scores, inv_norm_q);
45 }
46 (DenseVectorQuantization::F32, true) => {
47 let num_floats = scores.len() * dim;
48 let vectors: &[f32] =
49 unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const f32, num_floats) };
50 simd::batch_dot_scores_precomp(query, vectors, dim, scores, inv_norm_q);
51 }
52 (DenseVectorQuantization::F16, false) => {
53 simd::batch_cosine_scores_f16_precomp(query_f16, raw, dim, scores, inv_norm_q);
54 }
55 (DenseVectorQuantization::F16, true) => {
56 simd::batch_dot_scores_f16_precomp(query_f16, raw, dim, scores, inv_norm_q);
57 }
58 (DenseVectorQuantization::UInt8, false) => {
59 simd::batch_cosine_scores_u8_precomp(query, raw, dim, scores, inv_norm_q);
60 }
61 (DenseVectorQuantization::UInt8, true) => {
62 simd::batch_dot_scores_u8_precomp(query, raw, dim, scores, inv_norm_q);
63 }
64 }
65}
66
/// Settings for re-scoring search candidates against a dense query vector.
#[derive(Debug, Clone)]
pub struct RerankerConfig {
    /// Field whose stored dense vectors are scored.
    pub field: Field,
    /// Query vector; stored vectors with a different dimension are skipped.
    pub vector: Vec<f32>,
    /// Combines the per-value scores of a document into its final score.
    pub combiner: MultiValueCombiner,
    /// Selects the dot-product kernels instead of the cosine kernels.
    /// NOTE(review): presumably set when stored vectors are already
    /// unit-normalized; confirm with the indexing side.
    pub unit_norm: bool,
    /// Optional prefix length for a cheaper first scoring pass over truncated
    /// vectors. Only used when smaller than the query dimension and when a
    /// segment contributes more than `2 * final_limit` vectors.
    pub matryoshka_dims: Option<usize>,
}
87
88#[cfg(test)]
90use crate::structures::simd::cosine_similarity;
/// Scores one in-memory document against the configured query (test helper).
///
/// Every dense-vector value of `config.field` whose length equals the query
/// length is scored with cosine similarity. Ordinals count dense-vector values
/// only, and values with a mismatched length are skipped without renumbering.
///
/// Returns the combined score plus the per-value scores sorted from best to
/// worst, or `None` when no value could be scored.
#[cfg(test)]
fn score_document(
    doc: &crate::dsl::Document,
    config: &RerankerConfig,
) -> Option<(f32, Vec<ScoredPosition>)> {
    let expected_len = config.vector.len();

    let dense_values = doc
        .get_all(config.field)
        .filter_map(|fv| fv.as_dense_vector())
        .enumerate();

    let mut per_value: Vec<(u32, f32)> = Vec::new();
    for (ordinal, candidate) in dense_values {
        if candidate.len() == expected_len {
            let similarity = cosine_similarity(&config.vector, candidate);
            per_value.push((ordinal as u32, similarity));
        }
    }

    if per_value.is_empty() {
        return None;
    }

    // Combine before sorting so the combiner sees values in ordinal order.
    let combined = config.combiner.combine(&per_value);

    per_value.sort_by(|lhs, rhs| {
        rhs.1
            .partial_cmp(&lhs.1)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let mut positions: Vec<ScoredPosition> = Vec::with_capacity(per_value.len());
    for (ordinal, score) in per_value {
        positions.push(ScoredPosition::new(ordinal, score));
    }

    Some((combined, positions))
}
125
126pub async fn rerank<D: crate::directories::Directory + 'static>(
136 searcher: &crate::index::Searcher<D>,
137 candidates: &[SearchResult],
138 config: &RerankerConfig,
139 final_limit: usize,
140) -> crate::error::Result<Vec<SearchResult>> {
141 if config.vector.is_empty() || candidates.is_empty() {
142 return Ok(Vec::new());
143 }
144
145 let t0 = std::time::Instant::now();
146 let field_id = config.field.0;
147 let query = &config.vector;
148 let query_dim = query.len();
149 let segments = searcher.segment_readers();
150 let seg_by_id = searcher.segment_map();
151
152 use crate::structures::simd;
154 let norm_q_sq = simd::dot_product_f32(query, query, query_dim);
155 let inv_norm_q = if norm_q_sq < f32::EPSILON {
156 0.0
157 } else {
158 simd::fast_inv_sqrt(norm_q_sq)
159 };
160 let query_f16: Vec<u16> = query.iter().map(|&v| simd::f32_to_f16(v)).collect();
161 let pq = PrecompQuery {
162 query,
163 inv_norm_q,
164 query_f16: &query_f16,
165 };
166
167 let mut segment_groups: FxHashMap<usize, Vec<usize>> = FxHashMap::default();
169 let mut skipped = 0u32;
170
171 for (ci, candidate) in candidates.iter().enumerate() {
172 if let Some(&si) = seg_by_id.get(&candidate.segment_id) {
173 segment_groups.entry(si).or_default().push(ci);
174 } else {
175 skipped += 1;
176 }
177 }
178
179 let mut all_scores: Vec<(usize, u32, f32)> = Vec::new();
183 let mut total_vectors = 0usize;
184 let mut raw_buf: Vec<u8> = Vec::new();
186 let mut scores_buf: Vec<f32> = Vec::new();
187
188 for (si, candidate_indices) in &segment_groups {
189 let Some(lazy_flat) = segments[*si].flat_vectors().get(&field_id) else {
190 skipped += candidate_indices.len() as u32;
191 continue;
192 };
193 if lazy_flat.dim != query_dim {
194 skipped += candidate_indices.len() as u32;
195 continue;
196 }
197
198 let vbs = lazy_flat.vector_byte_size();
199 let quant = lazy_flat.quantization;
200
201 let mut resolved: Vec<(usize, usize, u32)> = Vec::new();
204 for &ci in candidate_indices {
205 let local_doc_id = candidates[ci].doc_id;
206 let (start, count) = lazy_flat.flat_indexes_for_doc_range(local_doc_id);
207 if count == 0 {
208 skipped += 1;
209 continue;
210 }
211 for j in 0..count {
212 let (_, ordinal) = lazy_flat.get_doc_id(start + j);
213 resolved.push((ci, start + j, ordinal as u32));
214 }
215 }
216
217 if resolved.is_empty() {
218 continue;
219 }
220
221 let n = resolved.len();
222 total_vectors += n;
223
224 resolved.sort_unstable_by_key(|&(_, flat_idx, _)| flat_idx);
226
227 let first_idx = resolved[0].1;
230 let last_idx = resolved[n - 1].1;
231 let span = last_idx - first_idx + 1;
232
233 raw_buf.resize(n * vbs, 0);
234
235 if span <= n * 4 {
238 let range_bytes = lazy_flat
239 .read_vectors_batch(first_idx, span)
240 .await
241 .expect("reranker: failed to read vector batch from flat storage");
242 let rb = range_bytes.as_slice();
243 for (buf_idx, &(_, flat_idx, _)) in resolved.iter().enumerate() {
244 let rel = flat_idx - first_idx;
245 let src = &rb[rel * vbs..(rel + 1) * vbs];
246 raw_buf[buf_idx * vbs..(buf_idx + 1) * vbs].copy_from_slice(src);
247 }
248 } else {
249 for (buf_idx, &(_, flat_idx, _)) in resolved.iter().enumerate() {
250 lazy_flat
251 .read_vector_raw_into(
252 flat_idx,
253 &mut raw_buf[buf_idx * vbs..(buf_idx + 1) * vbs],
254 )
255 .await
256 .expect("reranker: failed to read individual vector from flat storage");
257 }
258 }
259
260 scores_buf.resize(n, 0.0);
262
263 if let Some(mdims) = config.matryoshka_dims
266 && mdims < query_dim
267 && n > final_limit * 2
268 {
269 let trunc_dim = mdims;
274 let trunc_pq = PrecompQuery {
275 query: &query[..trunc_dim],
276 inv_norm_q: {
277 let nq =
278 simd::dot_product_f32(&query[..trunc_dim], &query[..trunc_dim], trunc_dim);
279 if nq < f32::EPSILON {
280 0.0
281 } else {
282 simd::fast_inv_sqrt(nq)
283 }
284 },
285 query_f16: &query_f16[..trunc_dim],
286 };
287 let trunc_vbs = trunc_dim * quant.element_size();
288 for i in 0..n {
289 let vec_start = i * vbs;
290 score_batch_precomp(
291 &trunc_pq,
292 &raw_buf[vec_start..vec_start + trunc_vbs],
293 quant,
294 trunc_dim,
295 &mut scores_buf[i..i + 1],
296 config.unit_norm,
297 );
298 }
299
300 let keep = (final_limit * 2).min(n);
302 let mut ranked: Vec<(usize, f32)> = (0..n).map(|i| (i, scores_buf[i])).collect();
303 ranked.sort_unstable_by(|a, b| {
304 b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
305 });
306 ranked.truncate(keep);
307
308 all_scores.reserve(ranked.len());
311 for &(orig_idx, _) in &ranked {
312 let vec_start = orig_idx * vbs;
313 let mut score = 0.0f32;
314 score_batch_precomp(
315 &pq,
316 &raw_buf[vec_start..vec_start + vbs],
317 quant,
318 query_dim,
319 std::slice::from_mut(&mut score),
320 config.unit_norm,
321 );
322 let (ci, _, ordinal) = resolved[orig_idx];
323 all_scores.push((ci, ordinal, score));
324 }
325
326 let filtered = n - ranked.len();
327 log::debug!(
328 "[reranker] matryoshka pre-filter: {}/{} dims, {}/{} vectors survived (filtered {})",
329 trunc_dim,
330 query_dim,
331 ranked.len(),
332 n,
333 filtered
334 );
335 } else {
336 score_batch_precomp(
338 &pq,
339 &raw_buf[..n * vbs],
340 quant,
341 query_dim,
342 &mut scores_buf[..n],
343 config.unit_norm,
344 );
345
346 all_scores.reserve(n);
347 for (buf_idx, &(ci, _, ordinal)) in resolved.iter().enumerate() {
348 all_scores.push((ci, ordinal, scores_buf[buf_idx]));
349 }
350 }
351 }
352
353 let read_score_elapsed = t0.elapsed();
354
355 if total_vectors == 0 {
356 log::debug!(
357 "[reranker] field {}: {} candidates, all skipped (no flat vectors)",
358 field_id,
359 candidates.len()
360 );
361 return Ok(Vec::new());
362 }
363
364 all_scores.sort_unstable_by_key(|&(ci, _, _)| ci);
367
368 let mut scored: Vec<SearchResult> = Vec::with_capacity(candidates.len().min(final_limit * 2));
369 let mut i = 0;
370 while i < all_scores.len() {
371 let ci = all_scores[i].0;
372 let run_start = i;
373 while i < all_scores.len() && all_scores[i].0 == ci {
374 i += 1;
375 }
376 let run = &mut all_scores[run_start..i];
377
378 let ordinal_pairs: Vec<(u32, f32)> = run.iter().map(|&(_, ord, s)| (ord, s)).collect();
380 let combined = config.combiner.combine(&ordinal_pairs);
381
382 run.sort_unstable_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
384 let positions: Vec<ScoredPosition> = run
385 .iter()
386 .map(|&(_, ord, score)| ScoredPosition::new(ord, score))
387 .collect();
388
389 scored.push(SearchResult {
390 doc_id: candidates[ci].doc_id,
391 score: combined,
392 segment_id: candidates[ci].segment_id,
393 positions: vec![(field_id, positions)],
394 });
395 }
396
397 scored.sort_by(|a, b| {
398 b.score
399 .partial_cmp(&a.score)
400 .unwrap_or(std::cmp::Ordering::Equal)
401 });
402 scored.truncate(final_limit);
403
404 log::debug!(
405 "[reranker] field {}: {} candidates -> {} results (skipped {}, {} vectors, unit_norm={}): read+score={:.1}ms total={:.1}ms",
406 field_id,
407 candidates.len(),
408 scored.len(),
409 skipped,
410 total_vectors,
411 config.unit_norm,
412 read_score_elapsed.as_secs_f64() * 1000.0,
413 t0.elapsed().as_secs_f64() * 1000.0,
414 );
415
416 Ok(scored)
417}
418
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::{Document, Field};

    /// Absolute tolerance used when comparing floating-point scores.
    const TOLERANCE: f32 = 1e-6;

    /// Query used by most tests: the unit vector along the first axis.
    fn x_axis() -> Vec<f32> {
        vec![1.0, 0.0, 0.0]
    }

    /// Builds a cosine-scoring config on field 0 without any pre-filter.
    fn make_config(vector: Vec<f32>, combiner: MultiValueCombiner) -> RerankerConfig {
        RerankerConfig {
            field: Field(0),
            vector,
            combiner,
            unit_norm: false,
            matryoshka_dims: None,
        }
    }

    /// Builds a document that stores `vectors` as successive values of `field`.
    fn doc_with_vectors(field: Field, vectors: &[&[f32]]) -> Document {
        let mut doc = Document::new();
        for values in vectors {
            doc.add_dense_vector(field, values.to_vec());
        }
        doc
    }

    #[test]
    fn test_score_document_single_value() {
        let doc = doc_with_vectors(Field(0), &[&[1.0, 0.0, 0.0]]);
        let config = make_config(x_axis(), MultiValueCombiner::Max);

        let (score, positions) = score_document(&doc, &config).unwrap();

        // Identical vectors have cosine similarity 1.
        assert!((score - 1.0).abs() < TOLERANCE);
        assert_eq!(positions.len(), 1);
        assert_eq!(positions[0].position, 0);
    }

    #[test]
    fn test_score_document_orthogonal() {
        let doc = doc_with_vectors(Field(0), &[&[0.0, 1.0, 0.0]]);
        let config = make_config(x_axis(), MultiValueCombiner::Max);

        let (score, _) = score_document(&doc, &config).unwrap();

        // Orthogonal vectors have cosine similarity 0.
        assert!(score.abs() < TOLERANCE);
    }

    #[test]
    fn test_score_document_multi_value_max() {
        let doc = doc_with_vectors(Field(0), &[&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]]);
        let config = make_config(x_axis(), MultiValueCombiner::Max);

        let (score, positions) = score_document(&doc, &config).unwrap();

        assert!((score - 1.0).abs() < TOLERANCE);
        assert_eq!(positions.len(), 2);
        // The best-scoring value is listed first.
        assert_eq!(positions[0].position, 0);
        assert!((positions[0].score - 1.0).abs() < TOLERANCE);
    }

    #[test]
    fn test_score_document_multi_value_avg() {
        let doc = doc_with_vectors(Field(0), &[&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]]);
        let config = make_config(x_axis(), MultiValueCombiner::Avg);

        let (score, _) = score_document(&doc, &config).unwrap();

        // Average of similarity 1 and similarity 0.
        assert!((score - 0.5).abs() < TOLERANCE);
    }

    #[test]
    fn test_score_document_missing_field() {
        // The vector lives on field 1 while the config targets field 0.
        let doc = doc_with_vectors(Field(1), &[&[1.0, 0.0, 0.0]]);
        let config = make_config(x_axis(), MultiValueCombiner::Max);

        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_wrong_field_type() {
        let mut doc = Document::new();
        doc.add_text(Field(0), "not a vector");
        let config = make_config(x_axis(), MultiValueCombiner::Max);

        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_dimension_mismatch() {
        // Two stored dimensions against a three-dimensional query.
        let doc = doc_with_vectors(Field(0), &[&[1.0, 0.0]]);
        let config = make_config(x_axis(), MultiValueCombiner::Max);

        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_empty_query_vector() {
        let doc = doc_with_vectors(Field(0), &[&[1.0, 0.0, 0.0]]);
        let config = make_config(vec![], MultiValueCombiner::Max);

        // A zero-length query never matches a stored three-dimensional vector.
        assert!(score_document(&doc, &config).is_none());
    }
}