1use crate::dsl::Field;
4
5use super::{MultiValueCombiner, ScoredPosition, SearchResult};
6
7#[inline]
10fn score_batch(
11 query: &[f32],
12 raw: &[u8],
13 quant: crate::dsl::DenseVectorQuantization,
14 dim: usize,
15 scores: &mut [f32],
16) {
17 use crate::dsl::DenseVectorQuantization;
18 match quant {
19 DenseVectorQuantization::F32 => {
20 let num_floats = scores.len() * dim;
21 let vectors: &[f32] =
22 unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const f32, num_floats) };
23 crate::structures::simd::batch_cosine_scores(query, vectors, dim, scores);
24 }
25 DenseVectorQuantization::F16 => {
26 crate::structures::simd::batch_cosine_scores_f16(query, raw, dim, scores);
27 }
28 DenseVectorQuantization::UInt8 => {
29 crate::structures::simd::batch_cosine_scores_u8(query, raw, dim, scores);
30 }
31 }
32}
33
/// Settings for reranking search candidates against a dense query vector.
#[derive(Debug, Clone)]
pub struct RerankerConfig {
    /// Dense vector field whose stored vectors are scored against the query.
    pub field: Field,
    /// Query vector. `rerank` returns no results when this is empty, and
    /// skips candidates whose stored dimension differs from its length.
    pub vector: Vec<f32>,
    /// Combiner used to fold the per-value scores of a multi-valued field
    /// into a single document score.
    pub combiner: MultiValueCombiner,
}
44
45#[cfg(test)]
47use crate::structures::simd::cosine_similarity;
/// Test-only reference scorer for a single in-memory document.
///
/// Computes the cosine similarity between `config.vector` and every dense
/// vector stored under `config.field`, ignoring values whose length differs
/// from the query. Returns `None` when nothing could be scored; otherwise
/// returns the combined score together with the per-value positions ordered
/// from highest to lowest score.
#[cfg(test)]
fn score_document(
    doc: &crate::dsl::Document,
    config: &RerankerConfig,
) -> Option<(f32, Vec<ScoredPosition>)> {
    let expected_len = config.vector.len();

    // Ordinals count dense-vector values only, in document order.
    let dense_values = doc
        .get_all(config.field)
        .filter_map(|fv| fv.as_dense_vector());

    let mut per_ordinal: Vec<(u32, f32)> = Vec::new();
    for (idx, stored) in dense_values.enumerate() {
        if stored.len() == expected_len {
            let similarity = cosine_similarity(&config.vector, stored);
            per_ordinal.push((idx as u32, similarity));
        }
    }

    if per_ordinal.is_empty() {
        return None;
    }

    // Combine before sorting so the combiner sees values in ordinal order.
    let combined = config.combiner.combine(&per_ordinal);

    per_ordinal.sort_by(|lhs, rhs| {
        rhs.1
            .partial_cmp(&lhs.1)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let mut positions: Vec<ScoredPosition> = Vec::with_capacity(per_ordinal.len());
    for (ordinal, score) in per_ordinal {
        positions.push(ScoredPosition::new(ordinal, score));
    }

    Some((combined, positions))
}
82
83pub async fn rerank<D: crate::directories::Directory + 'static>(
95 searcher: &crate::index::Searcher<D>,
96 candidates: &[SearchResult],
97 config: &RerankerConfig,
98 final_limit: usize,
99) -> crate::error::Result<Vec<SearchResult>> {
100 if config.vector.is_empty() || candidates.is_empty() {
101 return Ok(Vec::new());
102 }
103
104 let t0 = std::time::Instant::now();
105 let field_id = config.field.0;
106 let query = &config.vector;
107 let query_dim = query.len();
108 let segments = searcher.segment_readers();
109 let seg_by_id = searcher.segment_map();
110
111 let mut ordinal_scores: Vec<Vec<(u32, f32)>> = vec![Vec::new(); candidates.len()];
114 let mut skipped = 0u32;
115 let mut total_vectors = 0usize;
116
117 for (ci, candidate) in candidates.iter().enumerate() {
118 let Some(&si) = seg_by_id.get(&candidate.segment_id) else {
119 skipped += 1;
120 continue;
121 };
122
123 let local_doc_id = candidate.doc_id - segments[si].doc_id_offset();
124 let Some(lazy_flat) = segments[si].flat_vectors().get(&field_id) else {
125 skipped += 1;
126 continue;
127 };
128
129 if lazy_flat.dim != query_dim {
130 skipped += 1;
131 continue;
132 }
133
134 let (start, entries) = lazy_flat.flat_indexes_for_doc(local_doc_id);
135 if entries.is_empty() {
136 skipped += 1;
137 continue;
138 }
139
140 let count = entries.len();
141 total_vectors += count;
142
143 let batch = match lazy_flat.read_vectors_batch(start, count).await {
145 Ok(b) => b,
146 Err(_) => {
147 skipped += 1;
148 continue;
149 }
150 };
151
152 let raw = batch.as_slice();
153
154 let mut scores = vec![0f32; count];
156 score_batch(query, raw, lazy_flat.quantization, query_dim, &mut scores);
157
158 for (j, &(_doc_id, ordinal)) in entries.iter().enumerate() {
159 ordinal_scores[ci].push((ordinal as u32, scores[j]));
160 }
161 }
162
163 let read_score_elapsed = t0.elapsed();
164
165 if total_vectors == 0 {
166 log::debug!(
167 "[reranker] field {}: {} candidates, all skipped (no flat vectors)",
168 field_id,
169 candidates.len()
170 );
171 return Ok(Vec::new());
172 }
173
174 let mut scored: Vec<SearchResult> = Vec::with_capacity(candidates.len());
176 for (ci, ordinals) in ordinal_scores.into_iter().enumerate() {
177 if ordinals.is_empty() {
178 continue;
179 }
180 let combined = config.combiner.combine(&ordinals);
181 let mut positions: Vec<ScoredPosition> = ordinals
182 .into_iter()
183 .map(|(ord, score)| ScoredPosition::new(ord, score))
184 .collect();
185 positions.sort_by(|a, b| {
186 b.score
187 .partial_cmp(&a.score)
188 .unwrap_or(std::cmp::Ordering::Equal)
189 });
190 scored.push(SearchResult {
191 doc_id: candidates[ci].doc_id,
192 score: combined,
193 segment_id: candidates[ci].segment_id,
194 positions: vec![(field_id, positions)],
195 });
196 }
197
198 scored.sort_by(|a, b| {
199 b.score
200 .partial_cmp(&a.score)
201 .unwrap_or(std::cmp::Ordering::Equal)
202 });
203 scored.truncate(final_limit);
204
205 log::debug!(
206 "[reranker] field {}: {} candidates -> {} results (skipped {}, {} vectors): read+score={:.1}ms total={:.1}ms",
207 field_id,
208 candidates.len(),
209 scored.len(),
210 skipped,
211 total_vectors,
212 read_score_elapsed.as_secs_f64() * 1000.0,
213 t0.elapsed().as_secs_f64() * 1000.0,
214 );
215
216 Ok(scored)
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::{Document, Field};

    /// Absolute tolerance used when comparing floating-point scores.
    const EPS: f32 = 1e-6;

    /// Builds a reranker config that targets field 0.
    fn make_config(vector: Vec<f32>, combiner: MultiValueCombiner) -> RerankerConfig {
        RerankerConfig {
            field: Field(0),
            vector,
            combiner,
        }
    }

    /// Builds a document holding `vectors` as dense values of field 0, in order.
    fn doc_with_vectors(vectors: Vec<Vec<f32>>) -> Document {
        let mut doc = Document::new();
        for v in vectors {
            doc.add_dense_vector(Field(0), v);
        }
        doc
    }

    #[test]
    fn test_score_document_single_value() {
        let doc = doc_with_vectors(vec![vec![1.0, 0.0, 0.0]]);
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);

        let (score, positions) = score_document(&doc, &config).unwrap();

        // Identical vectors have cosine similarity 1.
        assert!((score - 1.0).abs() < EPS);
        assert_eq!(positions.len(), 1);
        assert_eq!(positions[0].position, 0);
    }

    #[test]
    fn test_score_document_orthogonal() {
        let doc = doc_with_vectors(vec![vec![0.0, 1.0, 0.0]]);
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);

        let (score, _) = score_document(&doc, &config).unwrap();

        // Orthogonal vectors have cosine similarity 0.
        assert!(score.abs() < EPS);
    }

    #[test]
    fn test_score_document_multi_value_max() {
        let doc = doc_with_vectors(vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]]);
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);

        let (score, positions) = score_document(&doc, &config).unwrap();

        assert!((score - 1.0).abs() < EPS);
        assert_eq!(positions.len(), 2);
        // The best-matching value (ordinal 0) must come first.
        assert_eq!(positions[0].position, 0);
        assert!((positions[0].score - 1.0).abs() < EPS);
    }

    #[test]
    fn test_score_document_multi_value_avg() {
        let doc = doc_with_vectors(vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]]);
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Avg);

        let (score, _) = score_document(&doc, &config).unwrap();

        // Average of similarities 1 and 0.
        assert!((score - 0.5).abs() < EPS);
    }

    #[test]
    fn test_score_document_missing_field() {
        // The vector lives in field 1, while the config targets field 0.
        let mut doc = Document::new();
        doc.add_dense_vector(Field(1), vec![1.0, 0.0, 0.0]);

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);

        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_wrong_field_type() {
        let mut doc = Document::new();
        doc.add_text(Field(0), "not a vector");

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);

        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_dimension_mismatch() {
        // Stored vector has 2 dimensions, the query has 3.
        let doc = doc_with_vectors(vec![vec![1.0, 0.0]]);
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);

        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_empty_query_vector() {
        let doc = doc_with_vectors(vec![vec![1.0, 0.0, 0.0]]);
        let config = make_config(vec![], MultiValueCombiner::Max);

        // An empty query never matches a non-empty stored vector's dimension.
        assert!(score_document(&doc, &config).is_none());
    }
}