
hermes_core/query/reranker.rs

//! L2 reranker: rerank L1 candidates by exact dense vector distance on stored vectors
//!
//! Optimized for throughput:
//! - Candidates grouped by segment for batched I/O
//! - Flat indexes sorted for sequential mmap access (OS readahead)
//! - Single SIMD batch-score call per segment (not per candidate)
//! - Reusable buffers across segments (no per-candidate heap allocation)
//! - unit_norm fast path: skip per-vector norm when vectors are pre-normalized
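//!
//! ("L1" here refers to the first-stage candidate retrieval pass; "L2" is this
//! exact-distance rerank over the stored vectors of those candidates.)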

use rustc_hash::FxHashMap;

use crate::dsl::Field;

use super::{MultiValueCombiner, ScoredPosition, SearchResult};

/// Batch SIMD scoring — dispatches to cosine or dot-product scorer by quantization + unit_norm.
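///
/// `raw` must hold `scores.len()` vectors of `dim` components packed back to back in the
/// given quantization. For `F32` the buffer is reinterpreted in place as `&[f32]`, so it
/// must contain exactly `scores.len() * dim` values and be suitably aligned for `f32`.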
#[inline]
fn score_batch(
    query: &[f32],
    raw: &[u8],
    quant: crate::dsl::DenseVectorQuantization,
    dim: usize,
    scores: &mut [f32],
    unit_norm: bool,
) {
    use crate::dsl::DenseVectorQuantization;
    use crate::structures::simd;
    match (quant, unit_norm) {
        (DenseVectorQuantization::F32, false) => {
            let num_floats = scores.len() * dim;
            let vectors: &[f32] =
                unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const f32, num_floats) };
            simd::batch_cosine_scores(query, vectors, dim, scores);
        }
        (DenseVectorQuantization::F32, true) => {
            let num_floats = scores.len() * dim;
            let vectors: &[f32] =
                unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const f32, num_floats) };
            simd::batch_dot_scores(query, vectors, dim, scores);
        }
        (DenseVectorQuantization::F16, false) => {
            simd::batch_cosine_scores_f16(query, raw, dim, scores);
        }
        (DenseVectorQuantization::F16, true) => {
            simd::batch_dot_scores_f16(query, raw, dim, scores);
        }
        (DenseVectorQuantization::UInt8, false) => {
            simd::batch_cosine_scores_u8(query, raw, dim, scores);
        }
        (DenseVectorQuantization::UInt8, true) => {
            simd::batch_dot_scores_u8(query, raw, dim, scores);
        }
    }
}

/// Configuration for L2 dense vector reranking
#[derive(Debug, Clone)]
pub struct RerankerConfig {
    /// Dense vector field (must be stored)
    pub field: Field,
    /// Query vector
    pub vector: Vec<f32>,
    /// How to combine scores for multi-valued documents
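    /// (e.g. per-value scores `[1.0, 0.0]` combine to `1.0` under `Max` and `0.5` under `Avg`)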
    pub combiner: MultiValueCombiner,
    /// Whether stored vectors are pre-normalized to unit L2 norm.
    /// When true, scoring uses dot-product only (skips per-vector norm — ~40% faster).
    pub unit_norm: bool,
}

#[cfg(test)]
use crate::structures::simd::cosine_similarity;

/// Score a single document against the query vector (used by tests).
#[cfg(test)]
fn score_document(
    doc: &crate::dsl::Document,
    config: &RerankerConfig,
) -> Option<(f32, Vec<ScoredPosition>)> {
    let query_dim = config.vector.len();
    let mut values: Vec<(u32, f32)> = doc
        .get_all(config.field)
        .filter_map(|fv| fv.as_dense_vector())
        .enumerate()
        .filter_map(|(ordinal, vec)| {
            if vec.len() != query_dim {
                return None;
            }
            let score = cosine_similarity(&config.vector, vec);
            Some((ordinal as u32, score))
        })
        .collect();

    if values.is_empty() {
        return None;
    }

    let combined = config.combiner.combine(&values);

    // Sort ordinals by score descending (best chunk first)
    values.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    let positions: Vec<ScoredPosition> = values
        .into_iter()
        .map(|(ordinal, score)| ScoredPosition::new(ordinal, score))
        .collect();

    Some((combined, positions))
}

/// Rerank L1 candidates by exact dense vector distance.
///
/// Groups candidates by segment for batched I/O, sorts flat indexes for
/// sequential mmap access, and scores all vectors in a single SIMD batch
/// per segment. Reuses buffers across segments to avoid per-candidate
/// heap allocation.
///
/// When `unit_norm` is set in the config, scoring uses dot-product only
/// (skips per-vector norm computation — ~40% less work).
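///
/// # Example
///
/// A minimal invocation sketch; `searcher`, `l1_candidates`, and `query_embedding` stand
/// in for values produced by the surrounding search pipeline:
///
/// ```ignore
/// let config = RerankerConfig {
///     field: Field(0),
///     vector: query_embedding,           // dense query vector (f32 components)
///     combiner: MultiValueCombiner::Max, // document score = best-scoring value
///     unit_norm: false,                  // true only if stored vectors are unit-length
/// };
/// let top_10 = rerank(&searcher, &l1_candidates, &config, 10).await?;
/// ```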
pub async fn rerank<D: crate::directories::Directory + 'static>(
    searcher: &crate::index::Searcher<D>,
    candidates: &[SearchResult],
    config: &RerankerConfig,
    final_limit: usize,
) -> crate::error::Result<Vec<SearchResult>> {
    if config.vector.is_empty() || candidates.is_empty() {
        return Ok(Vec::new());
    }

    let t0 = std::time::Instant::now();
    let field_id = config.field.0;
    let query = &config.vector;
    let query_dim = query.len();
    let segments = searcher.segment_readers();
    let seg_by_id = searcher.segment_map();

    // ── Phase 1: Group candidates by segment ──────────────────────────────
    let mut segment_groups: FxHashMap<usize, Vec<usize>> = FxHashMap::default();
    let mut skipped = 0u32;

    for (ci, candidate) in candidates.iter().enumerate() {
        if let Some(&si) = seg_by_id.get(&candidate.segment_id) {
            segment_groups.entry(si).or_default().push(ci);
        } else {
            skipped += 1;
        }
    }

    // ── Phase 2: Per-segment batched resolve + read + score ───────────────
    // Flat buffer: (candidate_idx, ordinal, score) — one allocation for all candidates.
    // Multi-value docs produce multiple entries per candidate_idx.
    let mut all_scores: Vec<(usize, u32, f32)> = Vec::new();
    let mut total_vectors = 0usize;
    // Reusable buffers across segments (avoids per-candidate allocation)
    let mut raw_buf: Vec<u8> = Vec::new();
    let mut scores_buf: Vec<f32> = Vec::new();

    for (si, candidate_indices) in &segment_groups {
        let Some(lazy_flat) = segments[*si].flat_vectors().get(&field_id) else {
            skipped += candidate_indices.len() as u32;
            continue;
        };
        if lazy_flat.dim != query_dim {
            skipped += candidate_indices.len() as u32;
            continue;
        }

        let vbs = lazy_flat.vector_byte_size();
        let quant = lazy_flat.quantization;

        // Resolve flat indexes for all candidates in this segment
        // Each entry: (candidate_idx, flat_vector_idx, ordinal)
        let mut resolved: Vec<(usize, usize, u32)> = Vec::new();
        for &ci in candidate_indices {
            let local_doc_id = candidates[ci].doc_id - segments[*si].doc_id_offset();
            let (start, count) = lazy_flat.flat_indexes_for_doc_range(local_doc_id);
            if count == 0 {
                skipped += 1;
                continue;
            }
            for j in 0..count {
                let (_, ordinal) = lazy_flat.get_doc_id(start + j);
                resolved.push((ci, start + j, ordinal as u32));
            }
        }

        if resolved.is_empty() {
            continue;
        }

        let n = resolved.len();
        total_vectors += n;

        // Sort by flat_idx for sequential mmap access (better page locality)
        resolved.sort_unstable_by_key(|&(_, flat_idx, _)| flat_idx);

        // Coalesced range read: single mmap read covering [min_idx..max_idx+1],
        // then selectively copy needed vectors. One page fault instead of N.
        let first_idx = resolved[0].1;
        let last_idx = resolved[n - 1].1;
        let span = last_idx - first_idx + 1;

        raw_buf.resize(n * vbs, 0);

        // Use the coalesced read if span waste is reasonable (at most 4× the needed count),
        // otherwise fall back to individual reads for very sparse patterns
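        // (e.g. 100 needed vectors spread over a 300-slot span -> one coalesced read of
        // 300 vectors; spread over a 1000-slot span -> 100 individual reads instead)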
        if span <= n * 4 {
            let range_bytes = match lazy_flat.read_vectors_batch(first_idx, span).await {
                Ok(b) => b,
                Err(_) => continue,
            };
            let rb = range_bytes.as_slice();
            for (buf_idx, &(_, flat_idx, _)) in resolved.iter().enumerate() {
                let rel = flat_idx - first_idx;
                let src = &rb[rel * vbs..(rel + 1) * vbs];
                raw_buf[buf_idx * vbs..(buf_idx + 1) * vbs].copy_from_slice(src);
            }
        } else {
            for (buf_idx, &(_, flat_idx, _)) in resolved.iter().enumerate() {
                let _ = lazy_flat
                    .read_vector_raw_into(
                        flat_idx,
                        &mut raw_buf[buf_idx * vbs..(buf_idx + 1) * vbs],
                    )
                    .await;
            }
        }

        // Single batch SIMD scoring for all vectors in this segment
        scores_buf.resize(n, 0.0);
        score_batch(
            query,
            &raw_buf[..n * vbs],
            quant,
            query_dim,
            &mut scores_buf[..n],
            config.unit_norm,
        );

        // Append (candidate_idx, ordinal, score) to flat buffer
        all_scores.reserve(n);
        for (buf_idx, &(ci, _, ordinal)) in resolved.iter().enumerate() {
            all_scores.push((ci, ordinal, scores_buf[buf_idx]));
        }
    }

    let read_score_elapsed = t0.elapsed();

    if total_vectors == 0 {
        log::debug!(
            "[reranker] field {}: {} candidates, all skipped (no flat vectors)",
            field_id,
            candidates.len()
        );
        return Ok(Vec::new());
    }

    // ── Phase 3: Combine scores and build results ─────────────────────────
    // Sort flat buffer by candidate_idx so contiguous runs belong to the same doc
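    // (e.g. entries [(3, 0, 0.9), (7, 0, 0.5), (3, 1, 0.2)] group into the runs
    // [(3, 0, 0.9), (3, 1, 0.2)] and [(7, 0, 0.5)], one contiguous run per candidate)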
    all_scores.sort_unstable_by_key(|&(ci, _, _)| ci);

    let mut scored: Vec<SearchResult> = Vec::with_capacity(candidates.len().min(final_limit * 2));
    let mut i = 0;
    while i < all_scores.len() {
        let ci = all_scores[i].0;
        let run_start = i;
        while i < all_scores.len() && all_scores[i].0 == ci {
            i += 1;
        }
        let run = &mut all_scores[run_start..i];

        // Build (ordinal, score) slice for combiner
        let ordinal_pairs: Vec<(u32, f32)> = run.iter().map(|&(_, ord, s)| (ord, s)).collect();
        let combined = config.combiner.combine(&ordinal_pairs);

        // Sort positions by score descending (best chunk first)
        run.sort_unstable_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
        let positions: Vec<ScoredPosition> = run
            .iter()
            .map(|&(_, ord, score)| ScoredPosition::new(ord, score))
            .collect();

        scored.push(SearchResult {
            doc_id: candidates[ci].doc_id,
            score: combined,
            segment_id: candidates[ci].segment_id,
            positions: vec![(field_id, positions)],
        });
    }

    scored.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    scored.truncate(final_limit);

    log::debug!(
        "[reranker] field {}: {} candidates -> {} results (skipped {}, {} vectors, unit_norm={}): read+score={:.1}ms total={:.1}ms",
        field_id,
        candidates.len(),
        scored.len(),
        skipped,
        total_vectors,
        config.unit_norm,
        read_score_elapsed.as_secs_f64() * 1000.0,
        t0.elapsed().as_secs_f64() * 1000.0,
    );

    Ok(scored)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::{Document, Field};

    fn make_config(vector: Vec<f32>, combiner: MultiValueCombiner) -> RerankerConfig {
        RerankerConfig {
            field: Field(0),
            vector,
            combiner,
            unit_norm: false,
        }
    }

    #[test]
    fn test_score_document_single_value() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]);

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        let (score, positions) = score_document(&doc, &config).unwrap();
        // cosine([1,0,0], [1,0,0]) = 1.0
        assert!((score - 1.0).abs() < 1e-6);
        assert_eq!(positions.len(), 1);
        assert_eq!(positions[0].position, 0); // ordinal 0
    }

    #[test]
    fn test_score_document_orthogonal() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![0.0, 1.0, 0.0]);

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        let (score, _) = score_document(&doc, &config).unwrap();
        // cosine([1,0,0], [0,1,0]) = 0.0
        assert!(score.abs() < 1e-6);
    }

    #[test]
    fn test_score_document_multi_value_max() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]); // cos=1.0 (same direction)
        doc.add_dense_vector(Field(0), vec![0.0, 1.0, 0.0]); // cos=0.0 (orthogonal)

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        let (score, positions) = score_document(&doc, &config).unwrap();
        assert!((score - 1.0).abs() < 1e-6);
        // Best chunk first
        assert_eq!(positions.len(), 2);
        assert_eq!(positions[0].position, 0); // ordinal 0 scored highest
        assert!((positions[0].score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_score_document_multi_value_avg() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]); // cos=1.0
        doc.add_dense_vector(Field(0), vec![0.0, 1.0, 0.0]); // cos=0.0

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Avg);
        let (score, _) = score_document(&doc, &config).unwrap();
        // avg(1.0, 0.0) = 0.5
        assert!((score - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_score_document_missing_field() {
        let mut doc = Document::new();
        // Add to field 1, not field 0
        doc.add_dense_vector(Field(1), vec![1.0, 0.0, 0.0]);

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_wrong_field_type() {
        let mut doc = Document::new();
        doc.add_text(Field(0), "not a vector");

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_dimension_mismatch() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![1.0, 0.0]); // 2D

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max); // 3D query
        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_empty_query_vector() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]);

        let config = make_config(vec![], MultiValueCombiner::Max);
        // Empty query can't match any stored vector (dimension mismatch)
        assert!(score_document(&doc, &config).is_none());
    }
}