Skip to main content

hermes_core/query/
reranker.rs

//! L2 reranker: rerank L1 candidates by exact cosine similarity against stored dense vectors
2
3use crate::dsl::Field;
4use crate::structures::simd::cosine_similarity;
5
6use super::{MultiValueCombiner, ScoredPosition, SearchResult};
7
8/// Configuration for L2 dense vector reranking
9#[derive(Debug, Clone)]
10pub struct RerankerConfig {
11    /// Dense vector field (must be stored)
12    pub field: Field,
13    /// Query vector
14    pub vector: Vec<f32>,
15    /// How to combine scores for multi-valued documents
16    pub combiner: MultiValueCombiner,
17}
18
19/// Score a single document against the query vector.
20///
21/// Returns `None` if the document has no values for the given field,
22/// or if stored vectors have a different dimension than the query.
23///
24/// Returns `(combined_score, ordinal_scores)` where ordinal_scores contains
25/// per-vector scores sorted by score descending (best chunk first).
26fn score_document(
27    doc: &crate::dsl::Document,
28    config: &RerankerConfig,
29) -> Option<(f32, Vec<ScoredPosition>)> {
30    let query_dim = config.vector.len();
31    let mut values: Vec<(u32, f32)> = doc
32        .get_all(config.field)
33        .filter_map(|fv| fv.as_dense_vector())
34        .enumerate()
35        .filter_map(|(ordinal, vec)| {
36            if vec.len() != query_dim {
37                return None;
38            }
39            let score = cosine_similarity(&config.vector, vec);
40            Some((ordinal as u32, score))
41        })
42        .collect();
43
44    if values.is_empty() {
45        return None;
46    }
47
48    let combined = config.combiner.combine(&values);
49
50    // Sort ordinals by score descending (best chunk first)
51    values.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
52    let positions: Vec<ScoredPosition> = values
53        .into_iter()
54        .map(|(ordinal, score)| ScoredPosition::new(ordinal, score))
55        .collect();
56
57    Some((combined, positions))
58}
59
60/// Rerank L1 candidates by exact dense vector distance.
61///
62/// For each candidate, loads the stored document, extracts the dense vector field,
63/// computes squared Euclidean distance, and converts to a similarity score via
64/// `1 / (1 + dist)`. Multi-valued fields are combined using `config.combiner`.
65///
66/// Documents missing the vector field are skipped.
67pub async fn rerank<D: crate::directories::Directory + 'static>(
68    searcher: &crate::index::Searcher<D>,
69    candidates: &[SearchResult],
70    config: &RerankerConfig,
71    final_limit: usize,
72) -> crate::error::Result<Vec<SearchResult>> {
73    if config.vector.is_empty() {
74        return Ok(Vec::new());
75    }
76
77    // Load all candidate documents in parallel
78    let doc_futures: Vec<_> = candidates.iter().map(|c| searcher.doc(c.doc_id)).collect();
79    let docs = futures::future::join_all(doc_futures).await;
80
81    let mut scored: Vec<SearchResult> = Vec::with_capacity(candidates.len());
82    let mut skipped = 0u32;
83
84    let field_id = config.field.0;
85
86    for (candidate, doc_result) in candidates.iter().zip(docs) {
87        match doc_result {
88            Ok(Some(doc)) => {
89                if let Some((score, ordinal_positions)) = score_document(&doc, config) {
90                    scored.push(SearchResult {
91                        doc_id: candidate.doc_id,
92                        score,
93                        positions: vec![(field_id, ordinal_positions)],
94                    });
95                } else {
96                    skipped += 1;
97                }
98            }
99            _ => {
100                skipped += 1;
101            }
102        }
103    }
104
105    if skipped > 0 {
106        log::debug!(
107            "[reranker] skipped {skipped}/{} candidates (missing/incompatible vector field)",
108            candidates.len()
109        );
110    }
111
112    scored.sort_by(|a, b| {
113        b.score
114            .partial_cmp(&a.score)
115            .unwrap_or(std::cmp::Ordering::Equal)
116    });
117    scored.truncate(final_limit);
118
119    Ok(scored)
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::{Document, Field};

    /// Absolute tolerance for floating-point score comparisons.
    const EPS: f32 = 1e-6;

    /// Build a reranker config that reads vectors from field 0.
    fn make_config(vector: Vec<f32>, combiner: MultiValueCombiner) -> RerankerConfig {
        RerankerConfig {
            field: Field(0),
            vector,
            combiner,
        }
    }

    /// Build a document whose `field` holds the given dense vectors, in order.
    fn doc_with_vectors(field: Field, vectors: Vec<Vec<f32>>) -> Document {
        let mut doc = Document::new();
        for vector in vectors {
            doc.add_dense_vector(field, vector);
        }
        doc
    }

    /// Assert that `actual` is within `EPS` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_score_document_single_value() {
        let doc = doc_with_vectors(Field(0), vec![vec![1.0, 0.0, 0.0]]);
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);

        let (score, positions) = score_document(&doc, &config).unwrap();

        // Identical direction: cosine([1,0,0], [1,0,0]) = 1.0
        assert_close(score, 1.0);
        assert_eq!(positions.len(), 1);
        assert_eq!(positions[0].position, 0); // the only vector has ordinal 0
    }

    #[test]
    fn test_score_document_orthogonal() {
        let doc = doc_with_vectors(Field(0), vec![vec![0.0, 1.0, 0.0]]);
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);

        let (score, _) = score_document(&doc, &config).unwrap();

        // Perpendicular vectors: cosine([1,0,0], [0,1,0]) = 0.0
        assert_close(score, 0.0);
    }

    #[test]
    fn test_score_document_multi_value_max() {
        let doc = doc_with_vectors(
            Field(0),
            vec![
                vec![1.0, 0.0, 0.0], // cos = 1.0 (same direction)
                vec![0.0, 1.0, 0.0], // cos = 0.0 (orthogonal)
            ],
        );
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);

        let (score, positions) = score_document(&doc, &config).unwrap();

        assert_close(score, 1.0);
        // The best-scoring chunk is listed first
        assert_eq!(positions.len(), 2);
        assert_eq!(positions[0].position, 0); // ordinal 0 scored highest
        assert_close(positions[0].score, 1.0);
    }

    #[test]
    fn test_score_document_multi_value_avg() {
        let doc = doc_with_vectors(
            Field(0),
            vec![
                vec![1.0, 0.0, 0.0], // cos = 1.0
                vec![0.0, 1.0, 0.0], // cos = 0.0
            ],
        );
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Avg);

        let (score, _) = score_document(&doc, &config).unwrap();

        // avg(1.0, 0.0) = 0.5
        assert_close(score, 0.5);
    }

    #[test]
    fn test_score_document_missing_field() {
        // The vector lives in field 1, while the config reads field 0
        let doc = doc_with_vectors(Field(1), vec![vec![1.0, 0.0, 0.0]]);
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);

        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_wrong_field_type() {
        let mut doc = Document::new();
        doc.add_text(Field(0), "not a vector");
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);

        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_dimension_mismatch() {
        // Stored vector is 2D, query vector is 3D
        let doc = doc_with_vectors(Field(0), vec![vec![1.0, 0.0]]);
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);

        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_empty_query_vector() {
        let doc = doc_with_vectors(Field(0), vec![vec![1.0, 0.0, 0.0]]);
        let config = make_config(vec![], MultiValueCombiner::Max);

        // Empty query can't match any stored vector (dimension mismatch)
        assert!(score_document(&doc, &config).is_none());
    }
}