// hermes_core/query/reranker.rs
//! L2 reranker: rerank L1 candidates by exact dense vector distance on stored vectors
//!
//! Optimized for throughput:
//! - Candidates grouped by segment for batched I/O
//! - Flat indexes sorted for sequential mmap access (OS readahead)
//! - Single SIMD batch-score call per segment (not per candidate)
//! - Reusable buffers across segments (no per-candidate heap allocation)
//! - unit_norm fast path: skip per-vector norm when vectors are pre-normalized

use rustc_hash::FxHashMap;

use crate::dsl::Field;

use super::{MultiValueCombiner, ScoredPosition, SearchResult};

/// Query-side values computed once per search and shared across every
/// segment: the raw f32 query, its precomputed inverse L2 norm, and an
/// f16-encoded copy for scoring against f16-quantized storage.
struct PrecompQuery<'a> {
    /// Full-precision query vector.
    query: &'a [f32],
    /// `1 / ||query||`, or `0.0` when the query norm is (near) zero.
    inv_norm_q: f32,
    /// Query converted to f16 bit patterns, one `u16` per dimension.
    query_f16: &'a [u16],
}

23/// Batch SIMD scoring with precomputed query norm + f16 query.
24#[inline]
25#[allow(clippy::too_many_arguments)]
26fn score_batch_precomp(
27    pq: &PrecompQuery<'_>,
28    raw: &[u8],
29    quant: crate::dsl::DenseVectorQuantization,
30    dim: usize,
31    scores: &mut [f32],
32    unit_norm: bool,
33) {
34    let query = pq.query;
35    let inv_norm_q = pq.inv_norm_q;
36    let query_f16 = pq.query_f16;
37    use crate::dsl::DenseVectorQuantization;
38    use crate::structures::simd;
39    match (quant, unit_norm) {
40        (DenseVectorQuantization::F32, false) => {
41            let num_floats = scores.len() * dim;
42            let vectors: &[f32] =
43                unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const f32, num_floats) };
44            simd::batch_cosine_scores_precomp(query, vectors, dim, scores, inv_norm_q);
45        }
46        (DenseVectorQuantization::F32, true) => {
47            let num_floats = scores.len() * dim;
48            let vectors: &[f32] =
49                unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const f32, num_floats) };
50            simd::batch_dot_scores_precomp(query, vectors, dim, scores, inv_norm_q);
51        }
52        (DenseVectorQuantization::F16, false) => {
53            simd::batch_cosine_scores_f16_precomp(query_f16, raw, dim, scores, inv_norm_q);
54        }
55        (DenseVectorQuantization::F16, true) => {
56            simd::batch_dot_scores_f16_precomp(query_f16, raw, dim, scores, inv_norm_q);
57        }
58        (DenseVectorQuantization::UInt8, false) => {
59            simd::batch_cosine_scores_u8_precomp(query, raw, dim, scores, inv_norm_q);
60        }
61        (DenseVectorQuantization::UInt8, true) => {
62            simd::batch_dot_scores_u8_precomp(query, raw, dim, scores, inv_norm_q);
63        }
64    }
65}
66
67/// Configuration for L2 dense vector reranking
68#[derive(Debug, Clone)]
69pub struct RerankerConfig {
70    /// Dense vector field (must be stored)
71    pub field: Field,
72    /// Query vector
73    pub vector: Vec<f32>,
74    /// How to combine scores for multi-valued documents
75    pub combiner: MultiValueCombiner,
76    /// Whether stored vectors are pre-normalized to unit L2 norm.
77    /// When true, scoring uses dot-product only (skips per-vector norm — ~40% faster).
78    pub unit_norm: bool,
79    /// Matryoshka pre-filter: number of leading dimensions to use for cheap
80    /// approximate scoring before full-dimension exact reranking.
81    /// When set, scores all candidates on the first `matryoshka_dims` dimensions,
82    /// keeps the top `final_limit × 2` candidates, then does full-dimension
83    /// exact scoring on survivors only. Skips ~50-70% of full cosine computations.
84    /// Set to `None` to disable (default: score all candidates at full dimension).
85    pub matryoshka_dims: Option<usize>,
86}
87
#[cfg(test)]
use crate::structures::simd::cosine_similarity;

/// Score a single document against the query vector (used by tests).
///
/// Returns `None` when the document has no dense vectors for the field
/// whose dimensionality matches the query; otherwise returns the combined
/// score plus per-ordinal positions sorted best-first.
#[cfg(test)]
fn score_document(
    doc: &crate::dsl::Document,
    config: &RerankerConfig,
) -> Option<(f32, Vec<ScoredPosition>)> {
    let dim = config.vector.len();

    // Collect (ordinal, cosine) for every stored vector whose length
    // matches the query; mismatched vectors are silently ignored.
    let mut pairs: Vec<(u32, f32)> = Vec::new();
    for (ordinal, vec) in doc
        .get_all(config.field)
        .filter_map(|fv| fv.as_dense_vector())
        .enumerate()
    {
        if vec.len() == dim {
            pairs.push((ordinal as u32, cosine_similarity(&config.vector, vec)));
        }
    }

    if pairs.is_empty() {
        return None;
    }

    let combined = config.combiner.combine(&pairs);

    // Best chunk first: ordinals ordered by descending score.
    pairs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    let positions = pairs
        .into_iter()
        .map(|(ordinal, score)| ScoredPosition::new(ordinal, score))
        .collect();

    Some((combined, positions))
}

/// Rerank L1 candidates by exact dense vector distance.
///
/// Groups candidates by segment for batched I/O, sorts flat indexes for
/// sequential mmap access, and scores all vectors in a single SIMD batch
/// per segment. Reuses buffers across segments to avoid per-candidate
/// heap allocation.
///
/// When `unit_norm` is set in the config, scoring uses dot-product only
/// (skips per-vector norm computation — ~40% less work).
///
/// Returns at most `final_limit` results sorted by combined score
/// descending. Candidates are skipped (and counted in the debug log) when
/// their segment is unknown, the field has no flat vectors in that
/// segment, or the stored dimension differs from the query's.
pub async fn rerank<D: crate::directories::Directory + 'static>(
    searcher: &crate::index::Searcher<D>,
    candidates: &[SearchResult],
    config: &RerankerConfig,
    final_limit: usize,
) -> crate::error::Result<Vec<SearchResult>> {
    // Nothing to do: an empty query can never match, and no candidates
    // means no work.
    if config.vector.is_empty() || candidates.is_empty() {
        return Ok(Vec::new());
    }

    let t0 = std::time::Instant::now();
    let field_id = config.field.0;
    let query = &config.vector;
    let query_dim = query.len();
    let segments = searcher.segment_readers();
    let seg_by_id = searcher.segment_map();

    // Precompute query inverse-norm and f16 query once (reused across all segments)
    use crate::structures::simd;
    let norm_q_sq = simd::dot_product_f32(query, query, query_dim);
    // Zero-norm queries get inv_norm 0.0 so scores come out 0 instead of NaN/inf.
    let inv_norm_q = if norm_q_sq < f32::EPSILON {
        0.0
    } else {
        simd::fast_inv_sqrt(norm_q_sq)
    };
    let query_f16: Vec<u16> = query.iter().map(|&v| simd::f32_to_f16(v)).collect();
    let pq = PrecompQuery {
        query,
        inv_norm_q,
        query_f16: &query_f16,
    };

    // ── Phase 1: Group candidates by segment ──────────────────────────────
    // Keyed by segment ordinal; values are indexes into `candidates`.
    let mut segment_groups: FxHashMap<usize, Vec<usize>> = FxHashMap::default();
    let mut skipped = 0u32;

    for (ci, candidate) in candidates.iter().enumerate() {
        if let Some(&si) = seg_by_id.get(&candidate.segment_id) {
            segment_groups.entry(si).or_default().push(ci);
        } else {
            // Candidate references a segment this searcher doesn't know about.
            skipped += 1;
        }
    }

    // ── Phase 2: Per-segment batched resolve + read + score ───────────────
    // Flat buffer: (candidate_idx, ordinal, score) — one allocation for all candidates.
    // Multi-value docs produce multiple entries per candidate_idx.
    let mut all_scores: Vec<(usize, u32, f32)> = Vec::new();
    let mut total_vectors = 0usize;
    // Reusable buffers across segments (avoids per-candidate allocation)
    let mut raw_buf: Vec<u8> = Vec::new();
    let mut scores_buf: Vec<f32> = Vec::new();

    for (si, candidate_indices) in &segment_groups {
        // No flat vector storage for this field in this segment: skip the group.
        let Some(lazy_flat) = segments[*si].flat_vectors().get(&field_id) else {
            skipped += candidate_indices.len() as u32;
            continue;
        };
        // Stored dimension must match the query exactly.
        if lazy_flat.dim != query_dim {
            skipped += candidate_indices.len() as u32;
            continue;
        }

        let vbs = lazy_flat.vector_byte_size();
        let quant = lazy_flat.quantization;

        // Resolve flat indexes for all candidates in this segment
        // Each entry: (candidate_idx, flat_vector_idx, ordinal)
        let mut resolved: Vec<(usize, usize, u32)> = Vec::new();
        for &ci in candidate_indices {
            let local_doc_id = candidates[ci].doc_id;
            let (start, count) = lazy_flat.flat_indexes_for_doc_range(local_doc_id);
            if count == 0 {
                // Document has no stored vectors for this field.
                skipped += 1;
                continue;
            }
            for j in 0..count {
                let (_, ordinal) = lazy_flat.get_doc_id(start + j);
                resolved.push((ci, start + j, ordinal as u32));
            }
        }

        if resolved.is_empty() {
            continue;
        }

        let n = resolved.len();
        total_vectors += n;

        // Sort by flat_idx for sequential mmap access (better page locality)
        resolved.sort_unstable_by_key(|&(_, flat_idx, _)| flat_idx);

        // Coalesced range read: single mmap read covering [min_idx..max_idx+1],
        // then selectively copy needed vectors. One page fault instead of N.
        let first_idx = resolved[0].1;
        let last_idx = resolved[n - 1].1;
        let span = last_idx - first_idx + 1;

        raw_buf.resize(n * vbs, 0);

        // Use coalesced read if span waste is reasonable (< 4× the needed count),
        // otherwise fall back to individual reads for very sparse patterns
        if span <= n * 4 {
            // NOTE(review): `.expect` panics on storage failure even though
            // this function returns `Result`; if the storage error converts
            // into `crate::error::Error`, propagating via `?`/`map_err`
            // would be safer — confirm the error types.
            let range_bytes = lazy_flat
                .read_vectors_batch(first_idx, span)
                .await
                .expect("reranker: failed to read vector batch from flat storage");
            let rb = range_bytes.as_slice();
            for (buf_idx, &(_, flat_idx, _)) in resolved.iter().enumerate() {
                let rel = flat_idx - first_idx;
                let src = &rb[rel * vbs..(rel + 1) * vbs];
                raw_buf[buf_idx * vbs..(buf_idx + 1) * vbs].copy_from_slice(src);
            }
        } else {
            for (buf_idx, &(_, flat_idx, _)) in resolved.iter().enumerate() {
                // NOTE(review): same panic-on-I/O concern as the batch read above.
                lazy_flat
                    .read_vector_raw_into(
                        flat_idx,
                        &mut raw_buf[buf_idx * vbs..(buf_idx + 1) * vbs],
                    )
                    .await
                    .expect("reranker: failed to read individual vector from flat storage");
            }
        }

        // Single batch SIMD scoring for all vectors in this segment
        scores_buf.resize(n, 0.0);

        // Matryoshka pre-filter: score on truncated dimensions first, then
        // full-dimension scoring only on survivors.
        if let Some(mdims) = config.matryoshka_dims
            && mdims < query_dim
            && n > final_limit * 2
        {
            // Phase 2a: truncated-dimension approximate SIMD scoring.
            // Score directly from raw_buf — no intermediate buffer copy needed.
            // Each vector's first trunc_vbs bytes are the leading dimensions
            // (storage is dimension-contiguous within each vector).
            let trunc_dim = mdims;
            let trunc_pq = PrecompQuery {
                query: &query[..trunc_dim],
                inv_norm_q: {
                    // Inverse norm of the truncated query (same zero-norm
                    // guard as for the full query above).
                    let nq =
                        simd::dot_product_f32(&query[..trunc_dim], &query[..trunc_dim], trunc_dim);
                    if nq < f32::EPSILON {
                        0.0
                    } else {
                        simd::fast_inv_sqrt(nq)
                    }
                },
                query_f16: &query_f16[..trunc_dim],
            };
            let trunc_vbs = trunc_dim * quant.element_size();
            for i in 0..n {
                let vec_start = i * vbs;
                score_batch_precomp(
                    &trunc_pq,
                    &raw_buf[vec_start..vec_start + trunc_vbs],
                    quant,
                    trunc_dim,
                    &mut scores_buf[i..i + 1],
                    config.unit_norm,
                );
            }

            // Phase 2b: diversity-aware selection — ensure at least final_limit
            // unique documents survive the pre-filter. For saturating combiners
            // (Max, WeightedTopK), cap per-doc vectors to avoid one multi-valued
            // doc crowding out others.
            let per_doc_cap: usize = match &config.combiner {
                super::MultiValueCombiner::Max => 1,
                super::MultiValueCombiner::WeightedTopK { k, .. } => *k,
                _ => usize::MAX,
            };

            // Rank all vectors by approximate score, best first.
            let mut ranked: Vec<(usize, f32)> = (0..n).map(|i| (i, scores_buf[i])).collect();
            ranked.sort_unstable_by(|a, b| {
                b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
            });

            let mut survivors: Vec<(usize, f32)> = Vec::with_capacity(n.min(final_limit * 4));
            let mut doc_vector_counts: FxHashMap<usize, usize> = FxHashMap::default();
            let mut unique_docs = 0usize;

            for &(orig_idx, score) in &ranked {
                let ci = resolved[orig_idx].0;
                let count = doc_vector_counts.entry(ci).or_insert(0);

                if *count >= per_doc_cap {
                    continue;
                }
                if *count == 0 {
                    unique_docs += 1;
                }
                *count += 1;
                survivors.push((orig_idx, score));

                // Stop once we have enough unique docs AND enough entries
                if unique_docs >= final_limit && survivors.len() >= final_limit * 2 {
                    break;
                }
            }

            // Phase 2c: full-dimension exact SIMD scoring on survivors only.
            // Score directly from raw_buf at original offsets — zero-copy.
            all_scores.reserve(survivors.len());
            for &(orig_idx, _) in &survivors {
                let vec_start = orig_idx * vbs;
                let mut score = 0.0f32;
                score_batch_precomp(
                    &pq,
                    &raw_buf[vec_start..vec_start + vbs],
                    quant,
                    query_dim,
                    std::slice::from_mut(&mut score),
                    config.unit_norm,
                );
                let (ci, _, ordinal) = resolved[orig_idx];
                all_scores.push((ci, ordinal, score));
            }

            let filtered = n - survivors.len();
            log::debug!(
                "[reranker] matryoshka pre-filter: {}/{} dims, {}/{} vectors survived from {} unique docs (filtered {}, per_doc_cap={})",
                trunc_dim,
                query_dim,
                survivors.len(),
                n,
                unique_docs,
                filtered,
                per_doc_cap
            );
        } else {
            // No pre-filter: full-dimension SIMD scoring on all candidates
            score_batch_precomp(
                &pq,
                &raw_buf[..n * vbs],
                quant,
                query_dim,
                &mut scores_buf[..n],
                config.unit_norm,
            );

            all_scores.reserve(n);
            for (buf_idx, &(ci, _, ordinal)) in resolved.iter().enumerate() {
                all_scores.push((ci, ordinal, scores_buf[buf_idx]));
            }
        }
    }

    let read_score_elapsed = t0.elapsed();

    if total_vectors == 0 {
        log::debug!(
            "[reranker] field {}: {} candidates, all skipped (no flat vectors)",
            field_id,
            candidates.len()
        );
        return Ok(Vec::new());
    }

    // ── Phase 3: Combine scores and build results ─────────────────────────
    // Sort flat buffer by candidate_idx so contiguous runs belong to the same doc
    all_scores.sort_unstable_by_key(|&(ci, _, _)| ci);

    let mut scored: Vec<SearchResult> = Vec::with_capacity(candidates.len().min(final_limit * 2));
    let mut i = 0;
    while i < all_scores.len() {
        let ci = all_scores[i].0;
        let run_start = i;
        // Advance `i` past this candidate's contiguous run of entries.
        while i < all_scores.len() && all_scores[i].0 == ci {
            i += 1;
        }
        let run = &mut all_scores[run_start..i];

        // Build (ordinal, score) slice for combiner
        let ordinal_pairs: Vec<(u32, f32)> = run.iter().map(|&(_, ord, s)| (ord, s)).collect();
        let combined = config.combiner.combine(&ordinal_pairs);

        // Sort positions by score descending (best chunk first)
        run.sort_unstable_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
        let positions: Vec<ScoredPosition> = run
            .iter()
            .map(|&(_, ord, score)| ScoredPosition::new(ord, score))
            .collect();

        scored.push(SearchResult {
            doc_id: candidates[ci].doc_id,
            score: combined,
            segment_id: candidates[ci].segment_id,
            positions: vec![(field_id, positions)],
        });
    }

    // Final global ordering across segments, truncated to the caller's limit.
    scored.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    scored.truncate(final_limit);

    log::debug!(
        "[reranker] field {}: {} candidates -> {} results (skipped {}, {} vectors, unit_norm={}): read+score={:.1}ms total={:.1}ms",
        field_id,
        candidates.len(),
        scored.len(),
        skipped,
        total_vectors,
        config.unit_norm,
        read_score_elapsed.as_secs_f64() * 1000.0,
        t0.elapsed().as_secs_f64() * 1000.0,
    );

    Ok(scored)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::{Document, Field};

    /// Shorthand: a config over `Field(0)` with unit_norm off and no
    /// matryoshka pre-filter.
    fn make_config(vector: Vec<f32>, combiner: MultiValueCombiner) -> RerankerConfig {
        RerankerConfig {
            field: Field(0),
            vector,
            combiner,
            unit_norm: false,
            matryoshka_dims: None,
        }
    }

    #[test]
    fn test_score_document_single_value() {
        let mut document = Document::new();
        document.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]);

        let cfg = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        let (score, positions) = score_document(&document, &cfg).unwrap();
        // Same direction as the query: cosine is exactly 1.
        assert!((score - 1.0).abs() < 1e-6);
        assert_eq!(positions.len(), 1);
        assert_eq!(positions[0].position, 0); // only vector -> ordinal 0
    }

    #[test]
    fn test_score_document_orthogonal() {
        let mut document = Document::new();
        document.add_dense_vector(Field(0), vec![0.0, 1.0, 0.0]);

        let cfg = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        let (score, _) = score_document(&document, &cfg).unwrap();
        // Orthogonal vectors: cosine is 0.
        assert!(score.abs() < 1e-6);
    }

    #[test]
    fn test_score_document_multi_value_max() {
        let mut document = Document::new();
        document.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]); // aligned: cos = 1.0
        document.add_dense_vector(Field(0), vec![0.0, 1.0, 0.0]); // orthogonal: cos = 0.0

        let cfg = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        let (score, positions) = score_document(&document, &cfg).unwrap();
        assert!((score - 1.0).abs() < 1e-6);
        // Positions come back best-chunk-first.
        assert_eq!(positions.len(), 2);
        assert_eq!(positions[0].position, 0); // ordinal 0 carries the top score
        assert!((positions[0].score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_score_document_multi_value_avg() {
        let mut document = Document::new();
        document.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]); // cos = 1.0
        document.add_dense_vector(Field(0), vec![0.0, 1.0, 0.0]); // cos = 0.0

        let cfg = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Avg);
        let (score, _) = score_document(&document, &cfg).unwrap();
        // Mean of 1.0 and 0.0.
        assert!((score - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_score_document_missing_field() {
        let mut document = Document::new();
        // Vector lives in Field(1); the config asks for Field(0).
        document.add_dense_vector(Field(1), vec![1.0, 0.0, 0.0]);

        let cfg = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        assert!(score_document(&document, &cfg).is_none());
    }

    #[test]
    fn test_score_document_wrong_field_type() {
        let mut document = Document::new();
        document.add_text(Field(0), "not a vector");

        let cfg = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        assert!(score_document(&document, &cfg).is_none());
    }

    #[test]
    fn test_score_document_dimension_mismatch() {
        let mut document = Document::new();
        document.add_dense_vector(Field(0), vec![1.0, 0.0]); // stored 2-D

        let cfg = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max); // 3-D query
        assert!(score_document(&document, &cfg).is_none());
    }

    #[test]
    fn test_score_document_empty_query_vector() {
        let mut document = Document::new();
        document.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]);

        let cfg = make_config(vec![], MultiValueCombiner::Max);
        // A zero-length query mismatches every stored vector's dimension.
        assert!(score_document(&document, &cfg).is_none());
    }
}