// hermes_core/query/reranker.rs
//! L2 reranker: rerank L1 candidates by exact dense vector distance on stored vectors
//!
//! Optimized for throughput:
//! - Candidates grouped by segment for batched I/O
//! - Flat indexes sorted for sequential mmap access (OS readahead)
//! - Single SIMD batch-score call per segment (not per candidate)
//! - Reusable buffers across segments (no per-candidate heap allocation)
//! - unit_norm fast path: skip per-vector norm when vectors are pre-normalized
10use rustc_hash::FxHashMap;
11
12use crate::dsl::Field;
13
14use super::{MultiValueCombiner, ScoredPosition, SearchResult};
15
16/// Precomputed query data for dense reranking (computed once, reused across segments).
17struct PrecompQuery<'a> {
18    query: &'a [f32],
19    inv_norm_q: f32,
20    query_f16: &'a [u16],
21}
22
23/// Batch SIMD scoring with precomputed query norm + f16 query.
24#[inline]
25#[allow(clippy::too_many_arguments)]
26fn score_batch_precomp(
27    pq: &PrecompQuery<'_>,
28    raw: &[u8],
29    quant: crate::dsl::DenseVectorQuantization,
30    dim: usize,
31    scores: &mut [f32],
32    unit_norm: bool,
33) {
34    let query = pq.query;
35    let inv_norm_q = pq.inv_norm_q;
36    let query_f16 = pq.query_f16;
37    use crate::dsl::DenseVectorQuantization;
38    use crate::structures::simd;
39    match (quant, unit_norm) {
40        (DenseVectorQuantization::F32, false) => {
41            let num_floats = scores.len() * dim;
42            let vectors: &[f32] =
43                unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const f32, num_floats) };
44            simd::batch_cosine_scores_precomp(query, vectors, dim, scores, inv_norm_q);
45        }
46        (DenseVectorQuantization::F32, true) => {
47            let num_floats = scores.len() * dim;
48            let vectors: &[f32] =
49                unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const f32, num_floats) };
50            simd::batch_dot_scores_precomp(query, vectors, dim, scores, inv_norm_q);
51        }
52        (DenseVectorQuantization::F16, false) => {
53            simd::batch_cosine_scores_f16_precomp(query_f16, raw, dim, scores, inv_norm_q);
54        }
55        (DenseVectorQuantization::F16, true) => {
56            simd::batch_dot_scores_f16_precomp(query_f16, raw, dim, scores, inv_norm_q);
57        }
58        (DenseVectorQuantization::UInt8, false) => {
59            simd::batch_cosine_scores_u8_precomp(query, raw, dim, scores, inv_norm_q);
60        }
61        (DenseVectorQuantization::UInt8, true) => {
62            simd::batch_dot_scores_u8_precomp(query, raw, dim, scores, inv_norm_q);
63        }
64    }
65}
66
/// Configuration for L2 dense vector reranking.
#[derive(Debug, Clone)]
pub struct RerankerConfig {
    /// Dense vector field to rerank on (its vectors must be stored)
    pub field: Field,
    /// Query vector; its length must match the stored field's dimension —
    /// candidates in segments with a mismatched dimension are skipped
    pub vector: Vec<f32>,
    /// How to combine per-value (per-chunk) scores into a single document
    /// score for multi-valued documents
    pub combiner: MultiValueCombiner,
    /// Whether stored vectors are pre-normalized to unit L2 norm.
    /// When true, scoring uses dot-product only (skips per-vector norm — ~40% faster).
    pub unit_norm: bool,
}
80
#[cfg(test)]
use crate::structures::simd::cosine_similarity;

/// Score a single document against the query vector (used by tests).
///
/// Returns the combined document score plus per-ordinal positions sorted
/// best-chunk-first, or `None` when no stored vector matches the query
/// dimension.
#[cfg(test)]
fn score_document(
    doc: &crate::dsl::Document,
    config: &RerankerConfig,
) -> Option<(f32, Vec<ScoredPosition>)> {
    let dim = config.vector.len();

    // Ordinals count dense-vector values on the field; values whose
    // dimension does not match the query keep their ordinal slot but
    // contribute no score.
    let mut scored: Vec<(u32, f32)> = Vec::new();
    for (ordinal, vec) in doc
        .get_all(config.field)
        .filter_map(|fv| fv.as_dense_vector())
        .enumerate()
    {
        if vec.len() == dim {
            scored.push((ordinal as u32, cosine_similarity(&config.vector, vec)));
        }
    }

    if scored.is_empty() {
        return None;
    }

    let combined = config.combiner.combine(&scored);

    // Best chunk first.
    scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    let positions = scored
        .into_iter()
        .map(|(ordinal, score)| ScoredPosition::new(ordinal, score))
        .collect();

    Some((combined, positions))
}
118
119/// Rerank L1 candidates by exact dense vector distance.
120///
121/// Groups candidates by segment for batched I/O, sorts flat indexes for
122/// sequential mmap access, and scores all vectors in a single SIMD batch
123/// per segment. Reuses buffers across segments to avoid per-candidate
124/// heap allocation.
125///
126/// When `unit_norm` is set in the config, scoring uses dot-product only
127/// (skips per-vector norm computation — ~40% less work).
128pub async fn rerank<D: crate::directories::Directory + 'static>(
129    searcher: &crate::index::Searcher<D>,
130    candidates: &[SearchResult],
131    config: &RerankerConfig,
132    final_limit: usize,
133) -> crate::error::Result<Vec<SearchResult>> {
134    if config.vector.is_empty() || candidates.is_empty() {
135        return Ok(Vec::new());
136    }
137
138    let t0 = std::time::Instant::now();
139    let field_id = config.field.0;
140    let query = &config.vector;
141    let query_dim = query.len();
142    let segments = searcher.segment_readers();
143    let seg_by_id = searcher.segment_map();
144
145    // Precompute query inverse-norm and f16 query once (reused across all segments)
146    use crate::structures::simd;
147    let norm_q_sq = simd::dot_product_f32(query, query, query_dim);
148    let inv_norm_q = if norm_q_sq < f32::EPSILON {
149        0.0
150    } else {
151        simd::fast_inv_sqrt(norm_q_sq)
152    };
153    let query_f16: Vec<u16> = query.iter().map(|&v| simd::f32_to_f16(v)).collect();
154    let pq = PrecompQuery {
155        query,
156        inv_norm_q,
157        query_f16: &query_f16,
158    };
159
160    // ── Phase 1: Group candidates by segment ──────────────────────────────
161    let mut segment_groups: FxHashMap<usize, Vec<usize>> = FxHashMap::default();
162    let mut skipped = 0u32;
163
164    for (ci, candidate) in candidates.iter().enumerate() {
165        if let Some(&si) = seg_by_id.get(&candidate.segment_id) {
166            segment_groups.entry(si).or_default().push(ci);
167        } else {
168            skipped += 1;
169        }
170    }
171
172    // ── Phase 2: Per-segment batched resolve + read + score ───────────────
173    // Flat buffer: (candidate_idx, ordinal, score) — one allocation for all candidates.
174    // Multi-value docs produce multiple entries per candidate_idx.
175    let mut all_scores: Vec<(usize, u32, f32)> = Vec::new();
176    let mut total_vectors = 0usize;
177    // Reusable buffers across segments (avoids per-candidate allocation)
178    let mut raw_buf: Vec<u8> = Vec::new();
179    let mut scores_buf: Vec<f32> = Vec::new();
180
181    for (si, candidate_indices) in &segment_groups {
182        let Some(lazy_flat) = segments[*si].flat_vectors().get(&field_id) else {
183            skipped += candidate_indices.len() as u32;
184            continue;
185        };
186        if lazy_flat.dim != query_dim {
187            skipped += candidate_indices.len() as u32;
188            continue;
189        }
190
191        let vbs = lazy_flat.vector_byte_size();
192        let quant = lazy_flat.quantization;
193
194        // Resolve flat indexes for all candidates in this segment
195        // Each entry: (candidate_idx, flat_vector_idx, ordinal)
196        let mut resolved: Vec<(usize, usize, u32)> = Vec::new();
197        for &ci in candidate_indices {
198            let local_doc_id = candidates[ci].doc_id - segments[*si].doc_id_offset();
199            let (start, count) = lazy_flat.flat_indexes_for_doc_range(local_doc_id);
200            if count == 0 {
201                skipped += 1;
202                continue;
203            }
204            for j in 0..count {
205                let (_, ordinal) = lazy_flat.get_doc_id(start + j);
206                resolved.push((ci, start + j, ordinal as u32));
207            }
208        }
209
210        if resolved.is_empty() {
211            continue;
212        }
213
214        let n = resolved.len();
215        total_vectors += n;
216
217        // Sort by flat_idx for sequential mmap access (better page locality)
218        resolved.sort_unstable_by_key(|&(_, flat_idx, _)| flat_idx);
219
220        // Coalesced range read: single mmap read covering [min_idx..max_idx+1],
221        // then selectively copy needed vectors. One page fault instead of N.
222        let first_idx = resolved[0].1;
223        let last_idx = resolved[n - 1].1;
224        let span = last_idx - first_idx + 1;
225
226        raw_buf.resize(n * vbs, 0);
227
228        // Use coalesced read if span waste is reasonable (< 4× the needed count),
229        // otherwise fall back to individual reads for very sparse patterns
230        if span <= n * 4 {
231            let range_bytes = match lazy_flat.read_vectors_batch(first_idx, span).await {
232                Ok(b) => b,
233                Err(_) => continue,
234            };
235            let rb = range_bytes.as_slice();
236            for (buf_idx, &(_, flat_idx, _)) in resolved.iter().enumerate() {
237                let rel = flat_idx - first_idx;
238                let src = &rb[rel * vbs..(rel + 1) * vbs];
239                raw_buf[buf_idx * vbs..(buf_idx + 1) * vbs].copy_from_slice(src);
240            }
241        } else {
242            for (buf_idx, &(_, flat_idx, _)) in resolved.iter().enumerate() {
243                let _ = lazy_flat
244                    .read_vector_raw_into(
245                        flat_idx,
246                        &mut raw_buf[buf_idx * vbs..(buf_idx + 1) * vbs],
247                    )
248                    .await;
249            }
250        }
251
252        // Single batch SIMD scoring for all vectors in this segment
253        scores_buf.resize(n, 0.0);
254        score_batch_precomp(
255            &pq,
256            &raw_buf[..n * vbs],
257            quant,
258            query_dim,
259            &mut scores_buf[..n],
260            config.unit_norm,
261        );
262
263        // Append (candidate_idx, ordinal, score) to flat buffer
264        all_scores.reserve(n);
265        for (buf_idx, &(ci, _, ordinal)) in resolved.iter().enumerate() {
266            all_scores.push((ci, ordinal, scores_buf[buf_idx]));
267        }
268    }
269
270    let read_score_elapsed = t0.elapsed();
271
272    if total_vectors == 0 {
273        log::debug!(
274            "[reranker] field {}: {} candidates, all skipped (no flat vectors)",
275            field_id,
276            candidates.len()
277        );
278        return Ok(Vec::new());
279    }
280
281    // ── Phase 3: Combine scores and build results ─────────────────────────
282    // Sort flat buffer by candidate_idx so contiguous runs belong to the same doc
283    all_scores.sort_unstable_by_key(|&(ci, _, _)| ci);
284
285    let mut scored: Vec<SearchResult> = Vec::with_capacity(candidates.len().min(final_limit * 2));
286    let mut i = 0;
287    while i < all_scores.len() {
288        let ci = all_scores[i].0;
289        let run_start = i;
290        while i < all_scores.len() && all_scores[i].0 == ci {
291            i += 1;
292        }
293        let run = &mut all_scores[run_start..i];
294
295        // Build (ordinal, score) slice for combiner
296        let ordinal_pairs: Vec<(u32, f32)> = run.iter().map(|&(_, ord, s)| (ord, s)).collect();
297        let combined = config.combiner.combine(&ordinal_pairs);
298
299        // Sort positions by score descending (best chunk first)
300        run.sort_unstable_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
301        let positions: Vec<ScoredPosition> = run
302            .iter()
303            .map(|&(_, ord, score)| ScoredPosition::new(ord, score))
304            .collect();
305
306        scored.push(SearchResult {
307            doc_id: candidates[ci].doc_id,
308            score: combined,
309            segment_id: candidates[ci].segment_id,
310            positions: vec![(field_id, positions)],
311        });
312    }
313
314    scored.sort_by(|a, b| {
315        b.score
316            .partial_cmp(&a.score)
317            .unwrap_or(std::cmp::Ordering::Equal)
318    });
319    scored.truncate(final_limit);
320
321    log::debug!(
322        "[reranker] field {}: {} candidates -> {} results (skipped {}, {} vectors, unit_norm={}): read+score={:.1}ms total={:.1}ms",
323        field_id,
324        candidates.len(),
325        scored.len(),
326        skipped,
327        total_vectors,
328        config.unit_norm,
329        read_score_elapsed.as_secs_f64() * 1000.0,
330        t0.elapsed().as_secs_f64() * 1000.0,
331    );
332
333    Ok(scored)
334}
335
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::{Document, Field};

    /// Convenience constructor: targets field 0 with unit_norm disabled.
    fn make_config(vector: Vec<f32>, combiner: MultiValueCombiner) -> RerankerConfig {
        RerankerConfig {
            field: Field(0),
            vector,
            combiner,
            unit_norm: false,
        }
    }

    #[test]
    fn test_score_document_single_value() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]);

        let cfg = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        let (score, positions) = score_document(&doc, &cfg).expect("vector present");

        // Identical direction => cosine similarity of exactly 1.0.
        assert!((score - 1.0).abs() < 1e-6);
        assert_eq!(positions.len(), 1);
        assert_eq!(positions[0].position, 0); // ordinal 0
    }

    #[test]
    fn test_score_document_orthogonal() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![0.0, 1.0, 0.0]);

        let cfg = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        let (score, _) = score_document(&doc, &cfg).expect("vector present");

        // Perpendicular vectors score 0.
        assert!(score.abs() < 1e-6);
    }

    #[test]
    fn test_score_document_multi_value_max() {
        let mut doc = Document::new();
        // Ordinal 0 is aligned with the query; ordinal 1 is orthogonal.
        doc.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]);
        doc.add_dense_vector(Field(0), vec![0.0, 1.0, 0.0]);

        let cfg = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        let (score, positions) = score_document(&doc, &cfg).expect("vectors present");

        // Max combiner keeps the best chunk's score.
        assert!((score - 1.0).abs() < 1e-6);
        // Positions come back best-chunk-first.
        assert_eq!(positions.len(), 2);
        assert_eq!(positions[0].position, 0);
        assert!((positions[0].score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_score_document_multi_value_avg() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]); // cos = 1.0
        doc.add_dense_vector(Field(0), vec![0.0, 1.0, 0.0]); // cos = 0.0

        let cfg = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Avg);
        let (score, _) = score_document(&doc, &cfg).expect("vectors present");

        // Avg combiner: mean of 1.0 and 0.0.
        assert!((score - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_score_document_missing_field() {
        let mut doc = Document::new();
        // Vector lives on field 1; the config targets field 0.
        doc.add_dense_vector(Field(1), vec![1.0, 0.0, 0.0]);

        let cfg = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        assert!(score_document(&doc, &cfg).is_none());
    }

    #[test]
    fn test_score_document_wrong_field_type() {
        let mut doc = Document::new();
        doc.add_text(Field(0), "not a vector");

        let cfg = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        assert!(score_document(&doc, &cfg).is_none());
    }

    #[test]
    fn test_score_document_dimension_mismatch() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![1.0, 0.0]); // stored 2D vs 3D query

        let cfg = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        assert!(score_document(&doc, &cfg).is_none());
    }

    #[test]
    fn test_score_document_empty_query_vector() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]);

        // A zero-length query can never match a stored dimension.
        let cfg = make_config(vec![], MultiValueCombiner::Max);
        assert!(score_document(&doc, &cfg).is_none());
    }
}