Skip to main content

hermes_core/query/
reranker.rs

1//! L2 reranker: rerank L1 candidates by exact dense vector distance on stored vectors
2
3use crate::dsl::Field;
4
5use super::{MultiValueCombiner, ScoredPosition, SearchResult};
6
7/// Batch SIMD cosine scoring — dispatches to native-precision scorer by quantization type.
8/// Scores all vectors in one pass without per-vector dequantization overhead.
9#[inline]
10fn score_batch(
11    query: &[f32],
12    raw: &[u8],
13    quant: crate::dsl::DenseVectorQuantization,
14    dim: usize,
15    scores: &mut [f32],
16) {
17    use crate::dsl::DenseVectorQuantization;
18    match quant {
19        DenseVectorQuantization::F32 => {
20            let num_floats = scores.len() * dim;
21            let vectors: &[f32] =
22                unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const f32, num_floats) };
23            crate::structures::simd::batch_cosine_scores(query, vectors, dim, scores);
24        }
25        DenseVectorQuantization::F16 => {
26            crate::structures::simd::batch_cosine_scores_f16(query, raw, dim, scores);
27        }
28        DenseVectorQuantization::UInt8 => {
29            crate::structures::simd::batch_cosine_scores_u8(query, raw, dim, scores);
30        }
31    }
32}
33
34/// Configuration for L2 dense vector reranking
35#[derive(Debug, Clone)]
36pub struct RerankerConfig {
37    /// Dense vector field (must be stored)
38    pub field: Field,
39    /// Query vector
40    pub vector: Vec<f32>,
41    /// How to combine scores for multi-valued documents
42    pub combiner: MultiValueCombiner,
43}
44
#[cfg(test)]
use crate::structures::simd::cosine_similarity;

/// Score a single document against the query vector (used by tests).
///
/// Walks the dense-vector values stored under `config.field`. The ordinal of a
/// value is its index among those dense-vector values; values whose length
/// differs from the query vector are dropped but still consume an ordinal.
///
/// Returns `None` when no value survives. Otherwise returns the combined score
/// from `config.combiner` together with the per-ordinal scores, best first
/// (incomparable scores such as NaN are treated as equal when sorting).
#[cfg(test)]
fn score_document(
    doc: &crate::dsl::Document,
    config: &RerankerConfig,
) -> Option<(f32, Vec<ScoredPosition>)> {
    let expected_dim = config.vector.len();

    let dense_values = doc
        .get_all(config.field)
        .filter_map(|fv| fv.as_dense_vector())
        .enumerate();

    let mut per_ordinal: Vec<(u32, f32)> = Vec::new();
    for (ordinal, stored) in dense_values {
        if stored.len() == expected_dim {
            let similarity = cosine_similarity(&config.vector, stored);
            per_ordinal.push((ordinal as u32, similarity));
        }
    }

    if per_ordinal.is_empty() {
        return None;
    }

    // Combine before sorting so the combiner sees the scores in ordinal order.
    let doc_score = config.combiner.combine(&per_ordinal);

    // Highest-scoring chunk goes first.
    per_ordinal.sort_by(|lhs, rhs| {
        rhs.1
            .partial_cmp(&lhs.1)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let mut ranked: Vec<ScoredPosition> = Vec::with_capacity(per_ordinal.len());
    for (ordinal, similarity) in per_ordinal {
        ranked.push(ScoredPosition::new(ordinal, similarity));
    }

    Some((doc_score, ranked))
}
82
83/// Rerank L1 candidates by exact dense vector distance.
84///
85/// Reads vectors directly from flat vector data (mmap) instead of loading
86/// full documents from the store. This avoids store block decompression
87/// and document deserialization — typically 10-50× faster.
88///
89/// For each candidate, resolves the segment and flat vector index via
90/// binary search, reads the raw vector, dequantizes to f32, and scores
91/// with SIMD cosine similarity.
92///
93/// Documents missing the vector field are skipped.
94pub async fn rerank<D: crate::directories::Directory + 'static>(
95    searcher: &crate::index::Searcher<D>,
96    candidates: &[SearchResult],
97    config: &RerankerConfig,
98    final_limit: usize,
99) -> crate::error::Result<Vec<SearchResult>> {
100    if config.vector.is_empty() || candidates.is_empty() {
101        return Ok(Vec::new());
102    }
103
104    let t0 = std::time::Instant::now();
105    let field_id = config.field.0;
106    let query = &config.vector;
107    let query_dim = query.len();
108    let segments = searcher.segment_readers();
109    let seg_by_id = searcher.segment_map();
110
111    // For each candidate, batch-read all its ordinals in one mmap call
112    // (vectors for the same doc are contiguous in the flat store)
113    let mut ordinal_scores: Vec<Vec<(u32, f32)>> = vec![Vec::new(); candidates.len()];
114    let mut skipped = 0u32;
115    let mut total_vectors = 0usize;
116
117    for (ci, candidate) in candidates.iter().enumerate() {
118        let Some(&si) = seg_by_id.get(&candidate.segment_id) else {
119            skipped += 1;
120            continue;
121        };
122
123        let local_doc_id = candidate.doc_id - segments[si].doc_id_offset();
124        let Some(lazy_flat) = segments[si].flat_vectors().get(&field_id) else {
125            skipped += 1;
126            continue;
127        };
128
129        if lazy_flat.dim != query_dim {
130            skipped += 1;
131            continue;
132        }
133
134        let (start, entries) = lazy_flat.flat_indexes_for_doc(local_doc_id);
135        if entries.is_empty() {
136            skipped += 1;
137            continue;
138        }
139
140        let count = entries.len();
141        total_vectors += count;
142
143        // One batch read for all ordinals of this doc (contiguous in flat store)
144        let batch = match lazy_flat.read_vectors_batch(start, count).await {
145            Ok(b) => b,
146            Err(_) => {
147                skipped += 1;
148                continue;
149            }
150        };
151
152        let raw = batch.as_slice();
153
154        // Batch SIMD scoring: scores all vectors in one pass without per-vector dequantization
155        let mut scores = vec![0f32; count];
156        score_batch(query, raw, lazy_flat.quantization, query_dim, &mut scores);
157
158        for (j, &(_doc_id, ordinal)) in entries.iter().enumerate() {
159            ordinal_scores[ci].push((ordinal as u32, scores[j]));
160        }
161    }
162
163    let read_score_elapsed = t0.elapsed();
164
165    if total_vectors == 0 {
166        log::debug!(
167            "[reranker] field {}: {} candidates, all skipped (no flat vectors)",
168            field_id,
169            candidates.len()
170        );
171        return Ok(Vec::new());
172    }
173
174    // Combine per-candidate ordinal scores and build results
175    let mut scored: Vec<SearchResult> = Vec::with_capacity(candidates.len());
176    for (ci, ordinals) in ordinal_scores.into_iter().enumerate() {
177        if ordinals.is_empty() {
178            continue;
179        }
180        let combined = config.combiner.combine(&ordinals);
181        let mut positions: Vec<ScoredPosition> = ordinals
182            .into_iter()
183            .map(|(ord, score)| ScoredPosition::new(ord, score))
184            .collect();
185        positions.sort_by(|a, b| {
186            b.score
187                .partial_cmp(&a.score)
188                .unwrap_or(std::cmp::Ordering::Equal)
189        });
190        scored.push(SearchResult {
191            doc_id: candidates[ci].doc_id,
192            score: combined,
193            segment_id: candidates[ci].segment_id,
194            positions: vec![(field_id, positions)],
195        });
196    }
197
198    scored.sort_by(|a, b| {
199        b.score
200            .partial_cmp(&a.score)
201            .unwrap_or(std::cmp::Ordering::Equal)
202    });
203    scored.truncate(final_limit);
204
205    log::debug!(
206        "[reranker] field {}: {} candidates -> {} results (skipped {}, {} vectors): read+score={:.1}ms total={:.1}ms",
207        field_id,
208        candidates.len(),
209        scored.len(),
210        skipped,
211        total_vectors,
212        read_score_elapsed.as_secs_f64() * 1000.0,
213        t0.elapsed().as_secs_f64() * 1000.0,
214    );
215
216    Ok(scored)
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::{Document, Field};

    /// Tolerance used when comparing cosine scores.
    const EPS: f32 = 1e-6;

    /// Unit vector along the first axis; also used as the query in most tests.
    fn unit_x() -> Vec<f32> {
        vec![1.0, 0.0, 0.0]
    }

    /// Unit vector along the second axis (orthogonal to `unit_x`).
    fn unit_y() -> Vec<f32> {
        vec![0.0, 1.0, 0.0]
    }

    /// Reranker config that targets field 0.
    fn make_config(vector: Vec<f32>, combiner: MultiValueCombiner) -> RerankerConfig {
        RerankerConfig {
            field: Field(0),
            vector,
            combiner,
        }
    }

    /// Document holding `vectors` as dense-vector values of `field`, in order.
    fn doc_with_vectors(field: Field, vectors: Vec<Vec<f32>>) -> Document {
        let mut doc = Document::new();
        for vector in vectors {
            doc.add_dense_vector(field, vector);
        }
        doc
    }

    #[test]
    fn test_score_document_single_value() {
        let doc = doc_with_vectors(Field(0), vec![unit_x()]);
        let config = make_config(unit_x(), MultiValueCombiner::Max);

        let (score, positions) = score_document(&doc, &config).unwrap();

        // Identical direction: cosine is 1.0
        assert!((score - 1.0).abs() < EPS);
        assert_eq!(positions.len(), 1);
        assert_eq!(positions[0].position, 0); // ordinal 0
    }

    #[test]
    fn test_score_document_orthogonal() {
        let doc = doc_with_vectors(Field(0), vec![unit_y()]);
        let config = make_config(unit_x(), MultiValueCombiner::Max);

        let (score, _) = score_document(&doc, &config).unwrap();

        // Orthogonal directions: cosine is 0.0
        assert!(score.abs() < EPS);
    }

    #[test]
    fn test_score_document_multi_value_max() {
        // First value scores 1.0 (same direction), second scores 0.0 (orthogonal).
        let doc = doc_with_vectors(Field(0), vec![unit_x(), unit_y()]);
        let config = make_config(unit_x(), MultiValueCombiner::Max);

        let (score, positions) = score_document(&doc, &config).unwrap();

        assert!((score - 1.0).abs() < EPS);
        // Best chunk first
        assert_eq!(positions.len(), 2);
        assert_eq!(positions[0].position, 0); // ordinal 0 scored highest
        assert!((positions[0].score - 1.0).abs() < EPS);
    }

    #[test]
    fn test_score_document_multi_value_avg() {
        // Per-value scores are 1.0 and 0.0, so the average is 0.5.
        let doc = doc_with_vectors(Field(0), vec![unit_x(), unit_y()]);
        let config = make_config(unit_x(), MultiValueCombiner::Avg);

        let (score, _) = score_document(&doc, &config).unwrap();

        assert!((score - 0.5).abs() < EPS);
    }

    #[test]
    fn test_score_document_missing_field() {
        // The vector lives in field 1, while the config targets field 0.
        let doc = doc_with_vectors(Field(1), vec![unit_x()]);
        let config = make_config(unit_x(), MultiValueCombiner::Max);

        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_wrong_field_type() {
        let mut doc = Document::new();
        doc.add_text(Field(0), "not a vector");
        let config = make_config(unit_x(), MultiValueCombiner::Max);

        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_dimension_mismatch() {
        // Stored vector is 2D, query is 3D.
        let doc = doc_with_vectors(Field(0), vec![vec![1.0, 0.0]]);
        let config = make_config(unit_x(), MultiValueCombiner::Max);

        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_empty_query_vector() {
        let doc = doc_with_vectors(Field(0), vec![unit_x()]);
        let config = make_config(vec![], MultiValueCombiner::Max);

        // Empty query can't match any stored vector (dimension mismatch)
        assert!(score_document(&doc, &config).is_none());
    }
}