Skip to main content

hermes_core/query/
reranker.rs

//! L2 reranker: reranks L1 candidates by exact dense-vector similarity computed on stored vectors.
//!
//! Optimized for throughput:
//! - Candidates grouped by segment for batched I/O
//! - Flat indexes sorted for sequential mmap access (OS readahead)
//! - Single SIMD batch-score call per segment (not per candidate) when no pre-filter is active
//! - Buffers allocated once per segment (no per-candidate heap allocation)
//! - unit_norm fast path: skip per-vector norm when vectors are pre-normalized
9
10use rustc_hash::FxHashMap;
11
12use crate::dsl::Field;
13
14use super::{MultiValueCombiner, ScoredPosition, SearchResult};
15
/// Query-side inputs shared by every scoring call of a single rerank pass.
///
/// The values are prepared once per query and only borrowed here, so the
/// struct itself never allocates.
struct PrecompQuery<'a> {
    /// Query vector components as `f32`.
    query: &'a [f32],
    /// The query converted element-wise to 16-bit values (produced with
    /// `simd::f32_to_f16`); consumed by the F16 scoring kernels.
    query_f16: &'a [u16],
    /// Reciprocal of the query's L2 norm; `0.0` when the query norm is
    /// (near) zero.
    inv_norm_q: f32,
}
22
23/// Batch SIMD scoring with precomputed query norm + f16 query.
24#[inline]
25#[allow(clippy::too_many_arguments)]
26fn score_batch_precomp(
27    pq: &PrecompQuery<'_>,
28    raw: &[u8],
29    quant: crate::dsl::DenseVectorQuantization,
30    dim: usize,
31    scores: &mut [f32],
32    unit_norm: bool,
33) {
34    let query = pq.query;
35    let inv_norm_q = pq.inv_norm_q;
36    let query_f16 = pq.query_f16;
37    use crate::dsl::DenseVectorQuantization;
38    use crate::structures::simd;
39    match (quant, unit_norm) {
40        (DenseVectorQuantization::F32, false) => {
41            let num_floats = scores.len() * dim;
42            // Safety: Vec<u8> from the global allocator is guaranteed to be at least
43            // 8-byte aligned on 64-bit platforms (aligned to max_align_t). Assert at
44            // runtime to guard against custom allocators with weaker guarantees.
45            assert!(
46                (raw.as_ptr() as usize).is_multiple_of(std::mem::align_of::<f32>()),
47                "f32 vector data not 4-byte aligned"
48            );
49            let vectors: &[f32] =
50                unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const f32, num_floats) };
51            simd::batch_cosine_scores_precomp(query, vectors, dim, scores, inv_norm_q);
52        }
53        (DenseVectorQuantization::F32, true) => {
54            let num_floats = scores.len() * dim;
55            assert!(
56                (raw.as_ptr() as usize).is_multiple_of(std::mem::align_of::<f32>()),
57                "f32 vector data not 4-byte aligned"
58            );
59            let vectors: &[f32] =
60                unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const f32, num_floats) };
61            simd::batch_dot_scores_precomp(query, vectors, dim, scores, inv_norm_q);
62        }
63        (DenseVectorQuantization::F16, false) => {
64            simd::batch_cosine_scores_f16_precomp(query_f16, raw, dim, scores, inv_norm_q);
65        }
66        (DenseVectorQuantization::F16, true) => {
67            simd::batch_dot_scores_f16_precomp(query_f16, raw, dim, scores, inv_norm_q);
68        }
69        (DenseVectorQuantization::UInt8, false) => {
70            simd::batch_cosine_scores_u8_precomp(query, raw, dim, scores, inv_norm_q);
71        }
72        (DenseVectorQuantization::UInt8, true) => {
73            simd::batch_dot_scores_u8_precomp(query, raw, dim, scores, inv_norm_q);
74        }
75    }
76}
77
78/// Configuration for L2 dense vector reranking
79#[derive(Debug, Clone)]
80pub struct RerankerConfig {
81    /// Dense vector field (must be stored)
82    pub field: Field,
83    /// Query vector
84    pub vector: Vec<f32>,
85    /// How to combine scores for multi-valued documents
86    pub combiner: MultiValueCombiner,
87    /// Whether stored vectors are pre-normalized to unit L2 norm.
88    /// When true, scoring uses dot-product only (skips per-vector norm — ~40% faster).
89    pub unit_norm: bool,
90    /// Matryoshka pre-filter: number of leading dimensions to use for cheap
91    /// approximate scoring before full-dimension exact reranking.
92    /// When set, scores all candidates on the first `matryoshka_dims` dimensions,
93    /// keeps the top `final_limit × 2` candidates, then does full-dimension
94    /// exact scoring on survivors only. Skips ~50-70% of full cosine computations.
95    /// Set to `None` to disable (default: score all candidates at full dimension).
96    pub matryoshka_dims: Option<usize>,
97}
98
#[cfg(test)]
use crate::structures::simd::cosine_similarity;

/// Scores one in-memory document against the configured query vector
/// (test-only helper).
///
/// Every dense vector stored under `config.field` receives an ordinal in
/// insertion order; vectors whose length differs from the query are ignored
/// but still consume an ordinal. Returns the combined document score together
/// with the per-ordinal scores sorted best-first, or `None` when no vector
/// could be scored.
#[cfg(test)]
fn score_document(
    doc: &crate::dsl::Document,
    config: &RerankerConfig,
) -> Option<(f32, Vec<ScoredPosition>)> {
    let expected_len = config.vector.len();

    let mut per_ordinal: Vec<(u32, f32)> = Vec::new();
    let dense_values = doc
        .get_all(config.field)
        .filter_map(|fv| fv.as_dense_vector())
        .enumerate();
    for (ordinal, stored) in dense_values {
        if stored.len() == expected_len {
            let similarity = cosine_similarity(&config.vector, stored);
            per_ordinal.push((ordinal as u32, similarity));
        }
    }

    if per_ordinal.is_empty() {
        return None;
    }

    // Combine before sorting so the combiner sees ordinals in document order.
    let doc_score = config.combiner.combine(&per_ordinal);

    // Highest-scoring chunk first.
    per_ordinal.sort_unstable_by(|lhs, rhs| rhs.1.total_cmp(&lhs.1));
    let mut positions: Vec<ScoredPosition> = Vec::with_capacity(per_ordinal.len());
    for (ordinal, similarity) in per_ordinal {
        positions.push(ScoredPosition::new(ordinal, similarity));
    }

    Some((doc_score, positions))
}
136
137/// Rerank L1 candidates by exact dense vector distance.
138///
139/// Groups candidates by segment for batched I/O, sorts flat indexes for
140/// sequential mmap access, and scores all vectors in a single SIMD batch
141/// per segment. Reuses buffers across segments to avoid per-candidate
142/// heap allocation.
143///
144/// When `unit_norm` is set in the config, scoring uses dot-product only
145/// (skips per-vector norm computation — ~40% less work).
146pub async fn rerank<D: crate::directories::Directory + 'static>(
147    searcher: &crate::index::Searcher<D>,
148    candidates: &[SearchResult],
149    config: &RerankerConfig,
150    final_limit: usize,
151) -> crate::error::Result<Vec<SearchResult>> {
152    if config.vector.is_empty() || candidates.is_empty() {
153        return Ok(Vec::new());
154    }
155
156    let t0 = std::time::Instant::now();
157    let field_id = config.field.0;
158    let query = &config.vector;
159    let query_dim = query.len();
160    let segments = searcher.segment_readers();
161    let seg_by_id = searcher.segment_map();
162
163    // Precompute query inverse-norm and f16 query once (reused across all segments)
164    use crate::structures::simd;
165    let norm_q_sq = simd::dot_product_f32(query, query, query_dim);
166    let inv_norm_q = if norm_q_sq < f32::EPSILON {
167        0.0
168    } else {
169        simd::fast_inv_sqrt(norm_q_sq)
170    };
171    let query_f16: Vec<u16> = query.iter().map(|&v| simd::f32_to_f16(v)).collect();
172    let pq = PrecompQuery {
173        query,
174        inv_norm_q,
175        query_f16: &query_f16,
176    };
177
178    // ── Phase 1: Group candidates by segment ──────────────────────────────
179    let mut segment_groups: FxHashMap<usize, Vec<usize>> = FxHashMap::default();
180    let mut skipped = 0u32;
181
182    for (ci, candidate) in candidates.iter().enumerate() {
183        if let Some(&si) = seg_by_id.get(&candidate.segment_id) {
184            segment_groups.entry(si).or_default().push(ci);
185        } else {
186            skipped += 1;
187        }
188    }
189
190    // ── Phase 2: Per-segment batched resolve + read + score (concurrent) ──
191    // Each segment runs independently: resolve flat indexes, read vectors,
192    // and score — all overlapping I/O across segments via join_all.
193    let query_ref = pq.query;
194    let inv_norm_q_val = pq.inv_norm_q;
195    let query_f16_ref = pq.query_f16;
196
197    let segment_futs: Vec<_> = segment_groups
198        .into_iter()
199        .map(|(si, candidate_indices)| {
200            #[allow(clippy::redundant_locals)]
201            let segments = &segments;
202            #[allow(clippy::redundant_locals)]
203            let candidates = candidates;
204            #[allow(clippy::redundant_locals)]
205            let query_ref = query_ref;
206            #[allow(clippy::redundant_locals)]
207            let query_f16_ref = query_f16_ref;
208            #[allow(clippy::redundant_locals)]
209            let config = config;
210            async move {
211                let mut scores: Vec<(usize, u32, f32)> = Vec::new();
212                let mut vectors = 0usize;
213                let mut seg_skipped = 0u32;
214
215                let Some(lazy_flat) = segments[si].flat_vectors().get(&field_id) else {
216                    return Ok::<_, crate::error::Error>((
217                        scores,
218                        vectors,
219                        candidate_indices.len() as u32,
220                    ));
221                };
222                if lazy_flat.dim != query_dim {
223                    return Ok((scores, vectors, candidate_indices.len() as u32));
224                }
225
226                let vbs = lazy_flat.vector_byte_size();
227                let quant = lazy_flat.quantization;
228
229                // Resolve flat indexes for all candidates in this segment
230                let mut resolved: Vec<(usize, usize, u32)> = Vec::new();
231                for &ci in &candidate_indices {
232                    let local_doc_id = candidates[ci].doc_id;
233                    let (start, count) = lazy_flat.flat_indexes_for_doc_range(local_doc_id);
234                    if count == 0 {
235                        seg_skipped += 1;
236                        continue;
237                    }
238                    for j in 0..count {
239                        let (_, ordinal) = lazy_flat.get_doc_id(start + j);
240                        resolved.push((ci, start + j, ordinal as u32));
241                    }
242                }
243
244                if resolved.is_empty() {
245                    return Ok((scores, vectors, seg_skipped));
246                }
247
248                let n = resolved.len();
249                vectors = n;
250
251                // Sort by flat_idx for sequential mmap access
252                resolved.sort_unstable_by_key(|&(_, flat_idx, _)| flat_idx);
253
254                let first_idx = resolved[0].1;
255                let last_idx = resolved[n - 1].1;
256                let span = last_idx - first_idx + 1;
257
258                let mut raw_buf: Vec<u8> = vec![0u8; n * vbs];
259
260                if span <= n * 4 {
261                    let range_bytes = lazy_flat
262                        .read_vectors_batch(first_idx, span)
263                        .await
264                        .map_err(crate::error::Error::Io)?;
265                    let rb = range_bytes.as_slice();
266                    for (buf_idx, &(_, flat_idx, _)) in resolved.iter().enumerate() {
267                        let rel = flat_idx - first_idx;
268                        let src = &rb[rel * vbs..(rel + 1) * vbs];
269                        raw_buf[buf_idx * vbs..(buf_idx + 1) * vbs].copy_from_slice(src);
270                    }
271                } else {
272                    for (buf_idx, &(_, flat_idx, _)) in resolved.iter().enumerate() {
273                        lazy_flat
274                            .read_vector_raw_into(
275                                flat_idx,
276                                &mut raw_buf[buf_idx * vbs..(buf_idx + 1) * vbs],
277                            )
278                            .await
279                            .map_err(crate::error::Error::Io)?;
280                    }
281                }
282
283                // Reconstruct PrecompQuery from captured components
284                let pq = PrecompQuery {
285                    query: query_ref,
286                    inv_norm_q: inv_norm_q_val,
287                    query_f16: query_f16_ref,
288                };
289
290                let mut scores_buf: Vec<f32> = vec![0.0; n];
291
292                // Matryoshka pre-filter
293                if let Some(mdims) = config.matryoshka_dims
294                    && mdims < query_dim
295                    && n > final_limit * 2
296                {
297                    let trunc_dim = mdims;
298                    let trunc_pq = PrecompQuery {
299                        query: &query_ref[..trunc_dim],
300                        inv_norm_q: {
301                            let nq = simd::dot_product_f32(
302                                &query_ref[..trunc_dim],
303                                &query_ref[..trunc_dim],
304                                trunc_dim,
305                            );
306                            if nq < f32::EPSILON {
307                                0.0
308                            } else {
309                                simd::fast_inv_sqrt(nq)
310                            }
311                        },
312                        query_f16: &query_f16_ref[..trunc_dim],
313                    };
314                    let trunc_vbs = trunc_dim * quant.element_size();
315                    for i in 0..n {
316                        let vec_start = i * vbs;
317                        score_batch_precomp(
318                            &trunc_pq,
319                            &raw_buf[vec_start..vec_start + trunc_vbs],
320                            quant,
321                            trunc_dim,
322                            &mut scores_buf[i..i + 1],
323                            config.unit_norm,
324                        );
325                    }
326
327                    let per_doc_cap: usize = match &config.combiner {
328                        super::MultiValueCombiner::Max => 1,
329                        super::MultiValueCombiner::WeightedTopK { k, .. } => *k,
330                        _ => usize::MAX,
331                    };
332
333                    let mut ranked: Vec<(usize, f32)> =
334                        (0..n).map(|i| (i, scores_buf[i])).collect();
335                    ranked.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
336
337                    let mut survivors: Vec<(usize, f32)> =
338                        Vec::with_capacity(n.min(final_limit * 4));
339                    let mut doc_vector_counts: FxHashMap<usize, usize> = FxHashMap::default();
340                    let mut unique_docs = 0usize;
341
342                    for &(orig_idx, score) in &ranked {
343                        let ci = resolved[orig_idx].0;
344                        let count = doc_vector_counts.entry(ci).or_insert(0);
345
346                        if *count >= per_doc_cap {
347                            continue;
348                        }
349                        if *count == 0 {
350                            unique_docs += 1;
351                        }
352                        *count += 1;
353                        survivors.push((orig_idx, score));
354
355                        if unique_docs >= final_limit && survivors.len() >= final_limit * 2 {
356                            break;
357                        }
358                    }
359
360                    scores.reserve(survivors.len());
361                    for &(orig_idx, _) in &survivors {
362                        let vec_start = orig_idx * vbs;
363                        let mut score = 0.0f32;
364                        score_batch_precomp(
365                            &pq,
366                            &raw_buf[vec_start..vec_start + vbs],
367                            quant,
368                            query_dim,
369                            std::slice::from_mut(&mut score),
370                            config.unit_norm,
371                        );
372                        let (ci, _, ordinal) = resolved[orig_idx];
373                        scores.push((ci, ordinal, score));
374                    }
375
376                    let filtered = n - survivors.len();
377                    log::debug!(
378                        "[reranker] matryoshka pre-filter: {}/{} dims, {}/{} vectors survived from {} unique docs (filtered {}, per_doc_cap={})",
379                        trunc_dim,
380                        query_dim,
381                        survivors.len(),
382                        n,
383                        unique_docs,
384                        filtered,
385                        per_doc_cap
386                    );
387                } else {
388                    score_batch_precomp(
389                        &pq,
390                        &raw_buf[..n * vbs],
391                        quant,
392                        query_dim,
393                        &mut scores_buf[..n],
394                        config.unit_norm,
395                    );
396
397                    scores.reserve(n);
398                    for (buf_idx, &(ci, _, ordinal)) in resolved.iter().enumerate() {
399                        scores.push((ci, ordinal, scores_buf[buf_idx]));
400                    }
401                }
402
403                Ok((scores, vectors, seg_skipped))
404            }
405        })
406        .collect();
407
408    let results = futures::future::join_all(segment_futs).await;
409
410    let mut all_scores: Vec<(usize, u32, f32)> = Vec::new();
411    let mut total_vectors = 0usize;
412    for result in results {
413        let (scores, vectors, seg_skipped) = result?;
414        all_scores.extend(scores);
415        total_vectors += vectors;
416        skipped += seg_skipped;
417    }
418
419    let read_score_elapsed = t0.elapsed();
420
421    if total_vectors == 0 {
422        log::debug!(
423            "[reranker] field {}: {} candidates, all skipped (no flat vectors)",
424            field_id,
425            candidates.len()
426        );
427        return Ok(Vec::new());
428    }
429
430    // ── Phase 3: Combine scores and build results ─────────────────────────
431    // Sort flat buffer by candidate_idx so contiguous runs belong to the same doc
432    all_scores.sort_unstable_by_key(|&(ci, _, _)| ci);
433
434    let mut scored: Vec<SearchResult> = Vec::with_capacity(candidates.len().min(final_limit * 2));
435    let mut ordinal_pairs: Vec<(u32, f32)> = Vec::new();
436    let mut i = 0;
437    while i < all_scores.len() {
438        let ci = all_scores[i].0;
439        let run_start = i;
440        while i < all_scores.len() && all_scores[i].0 == ci {
441            i += 1;
442        }
443        let run = &mut all_scores[run_start..i];
444
445        // Build (ordinal, score) slice for combiner (reuses hoisted buffer)
446        ordinal_pairs.clear();
447        ordinal_pairs.extend(run.iter().map(|&(_, ord, s)| (ord, s)));
448        let combined = config.combiner.combine(&ordinal_pairs);
449
450        // Sort positions by score descending (best chunk first)
451        run.sort_unstable_by(|a, b| b.2.total_cmp(&a.2));
452        let positions: Vec<ScoredPosition> = run
453            .iter()
454            .map(|&(_, ord, score)| ScoredPosition::new(ord, score))
455            .collect();
456
457        scored.push(SearchResult {
458            doc_id: candidates[ci].doc_id,
459            score: combined,
460            segment_id: candidates[ci].segment_id,
461            positions: vec![(field_id, positions)],
462        });
463    }
464
465    scored.sort_unstable_by(|a, b| {
466        b.score
467            .partial_cmp(&a.score)
468            .unwrap_or(std::cmp::Ordering::Equal)
469    });
470    scored.truncate(final_limit);
471
472    log::debug!(
473        "[reranker] field {}: {} candidates -> {} results (skipped {}, {} vectors, unit_norm={}): read+score={:.1}ms total={:.1}ms",
474        field_id,
475        candidates.len(),
476        scored.len(),
477        skipped,
478        total_vectors,
479        config.unit_norm,
480        read_score_elapsed.as_secs_f64() * 1000.0,
481        t0.elapsed().as_secs_f64() * 1000.0,
482    );
483
484    Ok(scored)
485}
486
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::{Document, Field};

    /// Tolerance used for all floating-point score comparisons.
    const TOL: f32 = 1e-6;

    /// The 3-D unit vector along the first axis, used as the query throughout.
    fn x_axis() -> Vec<f32> {
        vec![1.0, 0.0, 0.0]
    }

    /// Builds a cosine (non-unit-norm) config on field 0 without pre-filtering.
    fn make_config(vector: Vec<f32>, combiner: MultiValueCombiner) -> RerankerConfig {
        RerankerConfig {
            field: Field(0),
            vector,
            combiner,
            unit_norm: false,
            matryoshka_dims: None,
        }
    }

    /// Builds a document holding `vectors` (in order) under `field`.
    fn dense_doc(field: Field, vectors: Vec<Vec<f32>>) -> Document {
        let mut doc = Document::new();
        for v in vectors {
            doc.add_dense_vector(field, v);
        }
        doc
    }

    #[test]
    fn test_score_document_single_value() {
        let doc = dense_doc(Field(0), vec![x_axis()]);
        let config = make_config(x_axis(), MultiValueCombiner::Max);

        let (score, positions) = score_document(&doc, &config).unwrap();

        // Identical direction: cosine similarity is 1.0.
        assert!((score - 1.0).abs() < TOL);
        assert_eq!(positions.len(), 1);
        // The only vector has ordinal 0.
        assert_eq!(positions[0].position, 0);
    }

    #[test]
    fn test_score_document_orthogonal() {
        let doc = dense_doc(Field(0), vec![vec![0.0, 1.0, 0.0]]);
        let config = make_config(x_axis(), MultiValueCombiner::Max);

        let (score, _) = score_document(&doc, &config).unwrap();

        // Orthogonal vectors: cosine similarity is 0.0.
        assert!(score.abs() < TOL);
    }

    #[test]
    fn test_score_document_multi_value_max() {
        // First vector matches the query (cos = 1.0); second is orthogonal (cos = 0.0).
        let doc = dense_doc(Field(0), vec![x_axis(), vec![0.0, 1.0, 0.0]]);
        let config = make_config(x_axis(), MultiValueCombiner::Max);

        let (score, positions) = score_document(&doc, &config).unwrap();

        assert!((score - 1.0).abs() < TOL);
        // Positions are ordered best chunk first, so ordinal 0 leads.
        assert_eq!(positions.len(), 2);
        assert_eq!(positions[0].position, 0);
        assert!((positions[0].score - 1.0).abs() < TOL);
    }

    #[test]
    fn test_score_document_multi_value_avg() {
        // Scores are 1.0 and 0.0, so the average is 0.5.
        let doc = dense_doc(Field(0), vec![x_axis(), vec![0.0, 1.0, 0.0]]);
        let config = make_config(x_axis(), MultiValueCombiner::Avg);

        let (score, _) = score_document(&doc, &config).unwrap();

        assert!((score - 0.5).abs() < TOL);
    }

    #[test]
    fn test_score_document_missing_field() {
        // The vector lives on field 1 while the config targets field 0.
        let doc = dense_doc(Field(1), vec![x_axis()]);
        let config = make_config(x_axis(), MultiValueCombiner::Max);

        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_wrong_field_type() {
        let mut doc = Document::new();
        doc.add_text(Field(0), "not a vector");
        let config = make_config(x_axis(), MultiValueCombiner::Max);

        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_dimension_mismatch() {
        // 2-D stored vector against a 3-D query.
        let doc = dense_doc(Field(0), vec![vec![1.0, 0.0]]);
        let config = make_config(x_axis(), MultiValueCombiner::Max);

        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_empty_query_vector() {
        let doc = dense_doc(Field(0), vec![x_axis()]);
        let config = make_config(vec![], MultiValueCombiner::Max);

        // An empty query cannot match any stored vector (dimension mismatch).
        assert!(score_document(&doc, &config).is_none());
    }
}