//! hermes_core/query/reranker.rs — reranks search candidates by the distance between a
//! query vector and the dense vectors stored in each candidate document.

use crate::dsl::Field;
use crate::structures::simd::squared_euclidean_distance;

use super::{MultiValueCombiner, SearchResult};
8#[derive(Debug, Clone)]
10pub struct RerankerConfig {
11 pub field: Field,
13 pub vector: Vec<f32>,
15 pub combiner: MultiValueCombiner,
17}
18
19fn score_document(doc: &crate::dsl::Document, config: &RerankerConfig) -> Option<f32> {
24 let query_dim = config.vector.len();
25 let values: Vec<(u32, f32)> = doc
26 .get_all(config.field)
27 .enumerate()
28 .filter_map(|(ordinal, fv)| {
29 let vec = fv.as_dense_vector()?;
30 if vec.len() != query_dim {
31 return None;
32 }
33 let dist = squared_euclidean_distance(&config.vector, vec);
34 let score = 1.0 / (1.0 + dist);
35 Some((ordinal as u32, score))
36 })
37 .collect();
38
39 if values.is_empty() {
40 return None;
41 }
42
43 Some(config.combiner.combine(&values))
44}
45
46pub async fn rerank<D: crate::directories::Directory + 'static>(
54 searcher: &crate::index::Searcher<D>,
55 candidates: &[SearchResult],
56 config: &RerankerConfig,
57 final_limit: usize,
58) -> crate::error::Result<Vec<SearchResult>> {
59 if config.vector.is_empty() {
60 return Ok(Vec::new());
61 }
62
63 let doc_futures: Vec<_> = candidates.iter().map(|c| searcher.doc(c.doc_id)).collect();
65 let docs = futures::future::join_all(doc_futures).await;
66
67 let mut scored: Vec<SearchResult> = Vec::with_capacity(candidates.len());
68 let mut skipped = 0u32;
69
70 for (candidate, doc_result) in candidates.iter().zip(docs) {
71 match doc_result {
72 Ok(Some(doc)) => {
73 if let Some(score) = score_document(&doc, config) {
74 scored.push(SearchResult {
75 doc_id: candidate.doc_id,
76 score,
77 positions: Vec::new(),
78 });
79 } else {
80 skipped += 1;
81 }
82 }
83 _ => {
84 skipped += 1;
85 }
86 }
87 }
88
89 if skipped > 0 {
90 log::debug!(
91 "[reranker] skipped {skipped}/{} candidates (missing/incompatible vector field)",
92 candidates.len()
93 );
94 }
95
96 scored.sort_by(|a, b| {
97 b.score
98 .partial_cmp(&a.score)
99 .unwrap_or(std::cmp::Ordering::Equal)
100 });
101 scored.truncate(final_limit);
102
103 Ok(scored)
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::{Document, Field};

    /// Builds a config that targets `Field(0)`, the field every test writes its vectors to.
    fn make_config(vector: Vec<f32>, combiner: MultiValueCombiner) -> RerankerConfig {
        RerankerConfig {
            field: Field(0),
            vector,
            combiner,
        }
    }

    #[test]
    fn test_score_document_single_value() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]);

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        let score = score_document(&doc, &config).unwrap();
        // Identical vectors: distance 0 -> score 1 / (1 + 0) = 1.0.
        assert!((score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_score_document_distance_correctness() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![3.0, 0.0, 0.0]);

        let config = make_config(vec![0.0, 0.0, 0.0], MultiValueCombiner::Max);
        let score = score_document(&doc, &config).unwrap();
        // Squared distance 9 -> score 1 / (1 + 9) = 0.1.
        assert!((score - 0.1).abs() < 1e-6);
    }

    #[test]
    fn test_score_document_multi_value_max() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]);
        doc.add_dense_vector(Field(0), vec![3.0, 0.0, 0.0]);

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        let score = score_document(&doc, &config).unwrap();
        // Scores are 1.0 and 0.2; Max keeps 1.0.
        assert!((score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_score_document_multi_value_avg() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]);
        doc.add_dense_vector(Field(0), vec![3.0, 0.0, 0.0]);

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Avg);
        let score = score_document(&doc, &config).unwrap();
        // Scores are 1.0 and 0.2; their mean is 0.6.
        assert!((score - 0.6).abs() < 1e-6);
    }

    #[test]
    fn test_score_document_missing_field() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(1), vec![1.0, 0.0, 0.0]);

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_wrong_field_type() {
        let mut doc = Document::new();
        doc.add_text(Field(0), "not a vector");

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_dimension_mismatch() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![1.0, 0.0]);

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_empty_query_vector() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]);

        let config = make_config(vec![], MultiValueCombiner::Max);
        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_ignores_non_vector_values_in_field() {
        let mut doc = Document::new();
        doc.add_text(Field(0), "not a vector");
        doc.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]);

        // With Avg, a non-vector value that leaked into the combination would lower the score.
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Avg);
        let score = score_document(&doc, &config).unwrap();
        assert!((score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_score_document_ignores_mismatched_dimensions_among_values() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![1.0, 0.0]);
        doc.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]);

        // Only the 3-dimensional value is comparable to the query.
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Avg);
        let score = score_document(&doc, &config).unwrap();
        assert!((score - 1.0).abs() < 1e-6);
    }
}