Skip to main content

hermes_core/query/
reranker.rs

//! Second-stage (L2) reranker: rescores first-stage (L1) candidates by the exact squared
//! Euclidean distance between the query vector and each document's stored dense vectors.
2
3use crate::dsl::Field;
4use crate::structures::simd::squared_euclidean_distance;
5
6use super::{MultiValueCombiner, SearchResult};
7
8/// Configuration for L2 dense vector reranking
9#[derive(Debug, Clone)]
10pub struct RerankerConfig {
11    /// Dense vector field (must be stored)
12    pub field: Field,
13    /// Query vector
14    pub vector: Vec<f32>,
15    /// How to combine scores for multi-valued documents
16    pub combiner: MultiValueCombiner,
17}
18
19/// Score a single document against the query vector.
20///
21/// Returns `None` if the document has no values for the given field,
22/// or if stored vectors have a different dimension than the query.
23fn score_document(doc: &crate::dsl::Document, config: &RerankerConfig) -> Option<f32> {
24    let query_dim = config.vector.len();
25    let values: Vec<(u32, f32)> = doc
26        .get_all(config.field)
27        .enumerate()
28        .filter_map(|(ordinal, fv)| {
29            let vec = fv.as_dense_vector()?;
30            if vec.len() != query_dim {
31                return None;
32            }
33            let dist = squared_euclidean_distance(&config.vector, vec);
34            let score = 1.0 / (1.0 + dist);
35            Some((ordinal as u32, score))
36        })
37        .collect();
38
39    if values.is_empty() {
40        return None;
41    }
42
43    Some(config.combiner.combine(&values))
44}
45
46/// Rerank L1 candidates by exact dense vector distance.
47///
48/// For each candidate, loads the stored document, extracts the dense vector field,
49/// computes squared Euclidean distance, and converts to a similarity score via
50/// `1 / (1 + dist)`. Multi-valued fields are combined using `config.combiner`.
51///
52/// Documents missing the vector field are skipped.
53pub async fn rerank<D: crate::directories::Directory + 'static>(
54    searcher: &crate::index::Searcher<D>,
55    candidates: &[SearchResult],
56    config: &RerankerConfig,
57    final_limit: usize,
58) -> crate::error::Result<Vec<SearchResult>> {
59    if config.vector.is_empty() {
60        return Ok(Vec::new());
61    }
62
63    // Load all candidate documents in parallel
64    let doc_futures: Vec<_> = candidates.iter().map(|c| searcher.doc(c.doc_id)).collect();
65    let docs = futures::future::join_all(doc_futures).await;
66
67    let mut scored: Vec<SearchResult> = Vec::with_capacity(candidates.len());
68    let mut skipped = 0u32;
69
70    for (candidate, doc_result) in candidates.iter().zip(docs) {
71        match doc_result {
72            Ok(Some(doc)) => {
73                if let Some(score) = score_document(&doc, config) {
74                    scored.push(SearchResult {
75                        doc_id: candidate.doc_id,
76                        score,
77                        positions: Vec::new(),
78                    });
79                } else {
80                    skipped += 1;
81                }
82            }
83            _ => {
84                skipped += 1;
85            }
86        }
87    }
88
89    if skipped > 0 {
90        log::debug!(
91            "[reranker] skipped {skipped}/{} candidates (missing/incompatible vector field)",
92            candidates.len()
93        );
94    }
95
96    scored.sort_by(|a, b| {
97        b.score
98            .partial_cmp(&a.score)
99            .unwrap_or(std::cmp::Ordering::Equal)
100    });
101    scored.truncate(final_limit);
102
103    Ok(scored)
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::{Document, Field};

    fn make_config(vector: Vec<f32>, combiner: MultiValueCombiner) -> RerankerConfig {
        RerankerConfig {
            field: Field(0),
            vector,
            combiner,
        }
    }

    #[test]
    fn test_score_document_single_value() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]);

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        let score = score_document(&doc, &config).unwrap();
        // Distance = 0, score = 1 / (1 + 0) = 1.0
        assert!((score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_score_document_distance_correctness() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![3.0, 0.0, 0.0]);

        let config = make_config(vec![0.0, 0.0, 0.0], MultiValueCombiner::Max);
        let score = score_document(&doc, &config).unwrap();
        // Distance = 9.0, score = 1 / (1 + 9) = 0.1
        assert!((score - 0.1).abs() < 1e-6);
    }

    #[test]
    fn test_score_document_multi_value_max() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]); // dist=0, score=1.0
        doc.add_dense_vector(Field(0), vec![3.0, 0.0, 0.0]); // dist=4, score=0.2

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        let score = score_document(&doc, &config).unwrap();
        assert!((score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_score_document_multi_value_avg() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]); // dist=0, score=1.0
        doc.add_dense_vector(Field(0), vec![3.0, 0.0, 0.0]); // dist=4, score=0.2

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Avg);
        let score = score_document(&doc, &config).unwrap();
        // avg(1.0, 0.2) = 0.6
        assert!((score - 0.6).abs() < 1e-6);
    }

    #[test]
    fn test_score_document_missing_field() {
        let mut doc = Document::new();
        // Add to field 1, not field 0
        doc.add_dense_vector(Field(1), vec![1.0, 0.0, 0.0]);

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_wrong_field_type() {
        let mut doc = Document::new();
        doc.add_text(Field(0), "not a vector");

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_dimension_mismatch() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![1.0, 0.0]); // 2D

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max); // 3D query
        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_empty_query_vector() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]);

        let config = make_config(vec![], MultiValueCombiner::Max);
        // Empty query can't match any stored vector (dimension mismatch)
        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_ignores_mismatched_dimension_values() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![1.0, 0.0]); // 2D, ignored
        doc.add_dense_vector(Field(0), vec![3.0, 0.0, 0.0]); // dist=4, score=0.2

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Avg);
        let score = score_document(&doc, &config).unwrap();
        // Only the 3D value contributes, so the average is just 0.2.
        assert!((score - 0.2).abs() < 1e-6);
    }

    #[test]
    fn test_score_document_ignores_non_vector_values() {
        let mut doc = Document::new();
        doc.add_text(Field(0), "not a vector"); // ignored
        doc.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]); // dist=0, score=1.0

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        let score = score_document(&doc, &config).unwrap();
        assert!((score - 1.0).abs() < 1e-6);
    }
}