1use crate::dsl::Field;
4use crate::structures::simd::cosine_similarity;
5
6use super::{MultiValueCombiner, ScoredPosition, SearchResult};
7
8#[derive(Debug, Clone)]
10pub struct RerankerConfig {
11 pub field: Field,
13 pub vector: Vec<f32>,
15 pub combiner: MultiValueCombiner,
17}
18
19fn score_document(
27 doc: &crate::dsl::Document,
28 config: &RerankerConfig,
29) -> Option<(f32, Vec<ScoredPosition>)> {
30 let query_dim = config.vector.len();
31 let mut values: Vec<(u32, f32)> = doc
32 .get_all(config.field)
33 .filter_map(|fv| fv.as_dense_vector())
34 .enumerate()
35 .filter_map(|(ordinal, vec)| {
36 if vec.len() != query_dim {
37 return None;
38 }
39 let score = cosine_similarity(&config.vector, vec);
40 Some((ordinal as u32, score))
41 })
42 .collect();
43
44 if values.is_empty() {
45 return None;
46 }
47
48 let combined = config.combiner.combine(&values);
49
50 values.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
52 let positions: Vec<ScoredPosition> = values
53 .into_iter()
54 .map(|(ordinal, score)| ScoredPosition::new(ordinal, score))
55 .collect();
56
57 Some((combined, positions))
58}
59
60pub async fn rerank<D: crate::directories::Directory + 'static>(
68 searcher: &crate::index::Searcher<D>,
69 candidates: &[SearchResult],
70 config: &RerankerConfig,
71 final_limit: usize,
72) -> crate::error::Result<Vec<SearchResult>> {
73 if config.vector.is_empty() {
74 return Ok(Vec::new());
75 }
76
77 let doc_futures: Vec<_> = candidates.iter().map(|c| searcher.doc(c.doc_id)).collect();
79 let docs = futures::future::join_all(doc_futures).await;
80
81 let mut scored: Vec<SearchResult> = Vec::with_capacity(candidates.len());
82 let mut skipped = 0u32;
83
84 let field_id = config.field.0;
85
86 for (candidate, doc_result) in candidates.iter().zip(docs) {
87 match doc_result {
88 Ok(Some(doc)) => {
89 if let Some((score, ordinal_positions)) = score_document(&doc, config) {
90 scored.push(SearchResult {
91 doc_id: candidate.doc_id,
92 score,
93 positions: vec![(field_id, ordinal_positions)],
94 });
95 } else {
96 skipped += 1;
97 }
98 }
99 _ => {
100 skipped += 1;
101 }
102 }
103 }
104
105 if skipped > 0 {
106 log::debug!(
107 "[reranker] skipped {skipped}/{} candidates (missing/incompatible vector field)",
108 candidates.len()
109 );
110 }
111
112 scored.sort_by(|a, b| {
113 b.score
114 .partial_cmp(&a.score)
115 .unwrap_or(std::cmp::Ordering::Equal)
116 });
117 scored.truncate(final_limit);
118
119 Ok(scored)
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::{Document, Field};

    /// Builds a config that scores field 0 against `vector`.
    fn make_config(vector: Vec<f32>, combiner: MultiValueCombiner) -> RerankerConfig {
        RerankerConfig {
            field: Field(0),
            vector,
            combiner,
        }
    }

    /// Builds a document holding `vectors`, in order, as dense-vector values of `field`.
    fn doc_with_vectors(field: Field, vectors: Vec<Vec<f32>>) -> Document {
        let mut doc = Document::new();
        for vector in vectors {
            doc.add_dense_vector(field, vector);
        }
        doc
    }

    /// Asserts that `actual` is within 1e-6 of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < 1e-6);
    }

    #[test]
    fn test_score_document_single_value() {
        let doc = doc_with_vectors(Field(0), vec![vec![1.0, 0.0, 0.0]]);
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);

        // Identical vectors: similarity 1, reported at ordinal 0.
        let (score, positions) = score_document(&doc, &config).unwrap();
        assert_close(score, 1.0);
        assert_eq!(positions.len(), 1);
        assert_eq!(positions[0].position, 0);
    }

    #[test]
    fn test_score_document_orthogonal() {
        let doc = doc_with_vectors(Field(0), vec![vec![0.0, 1.0, 0.0]]);
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);

        // Orthogonal vectors: similarity 0.
        let (score, _) = score_document(&doc, &config).unwrap();
        assert_close(score, 0.0);
    }

    #[test]
    fn test_score_document_multi_value_max() {
        let doc = doc_with_vectors(
            Field(0),
            vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]],
        );
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);

        let (score, positions) = score_document(&doc, &config).unwrap();
        assert_close(score, 1.0);
        assert_eq!(positions.len(), 2);
        // The best-scoring value (ordinal 0) is listed first.
        assert_eq!(positions[0].position, 0);
        assert_close(positions[0].score, 1.0);
    }

    #[test]
    fn test_score_document_multi_value_avg() {
        let doc = doc_with_vectors(
            Field(0),
            vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]],
        );
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Avg);

        // Scores 1 and 0 average to 0.5.
        let (score, _) = score_document(&doc, &config).unwrap();
        assert_close(score, 0.5);
    }

    #[test]
    fn test_score_document_missing_field() {
        // The vector lives in field 1 while the config targets field 0.
        let doc = doc_with_vectors(Field(1), vec![vec![1.0, 0.0, 0.0]]);
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);

        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_wrong_field_type() {
        let mut doc = Document::new();
        doc.add_text(Field(0), "not a vector");
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);

        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_dimension_mismatch() {
        // Stored vector has 2 components, the query has 3.
        let doc = doc_with_vectors(Field(0), vec![vec![1.0, 0.0]]);
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);

        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_empty_query_vector() {
        let doc = doc_with_vectors(Field(0), vec![vec![1.0, 0.0, 0.0]]);
        let config = make_config(vec![], MultiValueCombiner::Max);

        assert!(score_document(&doc, &config).is_none());
    }
}