1use rustc_hash::FxHashMap;
11
12use crate::dsl::Field;
13
14use super::{MultiValueCombiner, ScoredPosition, SearchResult};
15
16#[inline]
18fn score_batch(
19 query: &[f32],
20 raw: &[u8],
21 quant: crate::dsl::DenseVectorQuantization,
22 dim: usize,
23 scores: &mut [f32],
24 unit_norm: bool,
25) {
26 use crate::dsl::DenseVectorQuantization;
27 use crate::structures::simd;
28 match (quant, unit_norm) {
29 (DenseVectorQuantization::F32, false) => {
30 let num_floats = scores.len() * dim;
31 let vectors: &[f32] =
32 unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const f32, num_floats) };
33 simd::batch_cosine_scores(query, vectors, dim, scores);
34 }
35 (DenseVectorQuantization::F32, true) => {
36 let num_floats = scores.len() * dim;
37 let vectors: &[f32] =
38 unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const f32, num_floats) };
39 simd::batch_dot_scores(query, vectors, dim, scores);
40 }
41 (DenseVectorQuantization::F16, false) => {
42 simd::batch_cosine_scores_f16(query, raw, dim, scores);
43 }
44 (DenseVectorQuantization::F16, true) => {
45 simd::batch_dot_scores_f16(query, raw, dim, scores);
46 }
47 (DenseVectorQuantization::UInt8, false) => {
48 simd::batch_cosine_scores_u8(query, raw, dim, scores);
49 }
50 (DenseVectorQuantization::UInt8, true) => {
51 simd::batch_dot_scores_u8(query, raw, dim, scores);
52 }
53 }
54}
55
56#[derive(Debug, Clone)]
58pub struct RerankerConfig {
59 pub field: Field,
61 pub vector: Vec<f32>,
63 pub combiner: MultiValueCombiner,
65 pub unit_norm: bool,
68}
69
70#[cfg(test)]
72use crate::structures::simd::cosine_similarity;
/// Test-only reference scorer: scores every dense vector of `config.field` in
/// `doc` with cosine similarity against `config.vector`.
///
/// Ordinals count the field's dense-vector values in insertion order
/// (non-vector values are not counted); values whose length differs from the
/// query are ignored but still consume an ordinal. Returns `None` when no
/// value could be scored, otherwise the combined score plus per-ordinal
/// positions sorted by descending score.
#[cfg(test)]
fn score_document(
    doc: &crate::dsl::Document,
    config: &RerankerConfig,
) -> Option<(f32, Vec<ScoredPosition>)> {
    let expected_dim = config.vector.len();
    let dense_values = doc
        .get_all(config.field)
        .filter_map(|value| value.as_dense_vector());

    let mut per_ordinal: Vec<(u32, f32)> = Vec::new();
    for (ordinal, candidate) in dense_values.enumerate() {
        if candidate.len() == expected_dim {
            let similarity = cosine_similarity(&config.vector, candidate);
            per_ordinal.push((ordinal as u32, similarity));
        }
    }

    if per_ordinal.is_empty() {
        return None;
    }

    // Combine before sorting so the combiner sees ordinal order.
    let doc_score = config.combiner.combine(&per_ordinal);

    per_ordinal.sort_by(|lhs, rhs| {
        rhs.1
            .partial_cmp(&lhs.1)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    let ranked: Vec<ScoredPosition> = per_ordinal
        .into_iter()
        .map(|(ordinal, similarity)| ScoredPosition::new(ordinal, similarity))
        .collect();

    Some((doc_score, ranked))
}
107
108pub async fn rerank<D: crate::directories::Directory + 'static>(
118 searcher: &crate::index::Searcher<D>,
119 candidates: &[SearchResult],
120 config: &RerankerConfig,
121 final_limit: usize,
122) -> crate::error::Result<Vec<SearchResult>> {
123 if config.vector.is_empty() || candidates.is_empty() {
124 return Ok(Vec::new());
125 }
126
127 let t0 = std::time::Instant::now();
128 let field_id = config.field.0;
129 let query = &config.vector;
130 let query_dim = query.len();
131 let segments = searcher.segment_readers();
132 let seg_by_id = searcher.segment_map();
133
134 let mut segment_groups: FxHashMap<usize, Vec<usize>> = FxHashMap::default();
136 let mut skipped = 0u32;
137
138 for (ci, candidate) in candidates.iter().enumerate() {
139 if let Some(&si) = seg_by_id.get(&candidate.segment_id) {
140 segment_groups.entry(si).or_default().push(ci);
141 } else {
142 skipped += 1;
143 }
144 }
145
146 let mut all_scores: Vec<(usize, u32, f32)> = Vec::new();
150 let mut total_vectors = 0usize;
151 let mut raw_buf: Vec<u8> = Vec::new();
153 let mut scores_buf: Vec<f32> = Vec::new();
154
155 for (si, candidate_indices) in &segment_groups {
156 let Some(lazy_flat) = segments[*si].flat_vectors().get(&field_id) else {
157 skipped += candidate_indices.len() as u32;
158 continue;
159 };
160 if lazy_flat.dim != query_dim {
161 skipped += candidate_indices.len() as u32;
162 continue;
163 }
164
165 let vbs = lazy_flat.vector_byte_size();
166 let quant = lazy_flat.quantization;
167
168 let mut resolved: Vec<(usize, usize, u32)> = Vec::new();
171 for &ci in candidate_indices {
172 let local_doc_id = candidates[ci].doc_id - segments[*si].doc_id_offset();
173 let (start, count) = lazy_flat.flat_indexes_for_doc_range(local_doc_id);
174 if count == 0 {
175 skipped += 1;
176 continue;
177 }
178 for j in 0..count {
179 let (_, ordinal) = lazy_flat.get_doc_id(start + j);
180 resolved.push((ci, start + j, ordinal as u32));
181 }
182 }
183
184 if resolved.is_empty() {
185 continue;
186 }
187
188 let n = resolved.len();
189 total_vectors += n;
190
191 resolved.sort_unstable_by_key(|&(_, flat_idx, _)| flat_idx);
193
194 let first_idx = resolved[0].1;
197 let last_idx = resolved[n - 1].1;
198 let span = last_idx - first_idx + 1;
199
200 raw_buf.resize(n * vbs, 0);
201
202 if span <= n * 4 {
205 let range_bytes = match lazy_flat.read_vectors_batch(first_idx, span).await {
206 Ok(b) => b,
207 Err(_) => continue,
208 };
209 let rb = range_bytes.as_slice();
210 for (buf_idx, &(_, flat_idx, _)) in resolved.iter().enumerate() {
211 let rel = flat_idx - first_idx;
212 let src = &rb[rel * vbs..(rel + 1) * vbs];
213 raw_buf[buf_idx * vbs..(buf_idx + 1) * vbs].copy_from_slice(src);
214 }
215 } else {
216 for (buf_idx, &(_, flat_idx, _)) in resolved.iter().enumerate() {
217 let _ = lazy_flat
218 .read_vector_raw_into(
219 flat_idx,
220 &mut raw_buf[buf_idx * vbs..(buf_idx + 1) * vbs],
221 )
222 .await;
223 }
224 }
225
226 scores_buf.resize(n, 0.0);
228 score_batch(
229 query,
230 &raw_buf[..n * vbs],
231 quant,
232 query_dim,
233 &mut scores_buf[..n],
234 config.unit_norm,
235 );
236
237 all_scores.reserve(n);
239 for (buf_idx, &(ci, _, ordinal)) in resolved.iter().enumerate() {
240 all_scores.push((ci, ordinal, scores_buf[buf_idx]));
241 }
242 }
243
244 let read_score_elapsed = t0.elapsed();
245
246 if total_vectors == 0 {
247 log::debug!(
248 "[reranker] field {}: {} candidates, all skipped (no flat vectors)",
249 field_id,
250 candidates.len()
251 );
252 return Ok(Vec::new());
253 }
254
255 all_scores.sort_unstable_by_key(|&(ci, _, _)| ci);
258
259 let mut scored: Vec<SearchResult> = Vec::with_capacity(candidates.len().min(final_limit * 2));
260 let mut i = 0;
261 while i < all_scores.len() {
262 let ci = all_scores[i].0;
263 let run_start = i;
264 while i < all_scores.len() && all_scores[i].0 == ci {
265 i += 1;
266 }
267 let run = &mut all_scores[run_start..i];
268
269 let ordinal_pairs: Vec<(u32, f32)> = run.iter().map(|&(_, ord, s)| (ord, s)).collect();
271 let combined = config.combiner.combine(&ordinal_pairs);
272
273 run.sort_unstable_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
275 let positions: Vec<ScoredPosition> = run
276 .iter()
277 .map(|&(_, ord, score)| ScoredPosition::new(ord, score))
278 .collect();
279
280 scored.push(SearchResult {
281 doc_id: candidates[ci].doc_id,
282 score: combined,
283 segment_id: candidates[ci].segment_id,
284 positions: vec![(field_id, positions)],
285 });
286 }
287
288 scored.sort_by(|a, b| {
289 b.score
290 .partial_cmp(&a.score)
291 .unwrap_or(std::cmp::Ordering::Equal)
292 });
293 scored.truncate(final_limit);
294
295 log::debug!(
296 "[reranker] field {}: {} candidates -> {} results (skipped {}, {} vectors, unit_norm={}): read+score={:.1}ms total={:.1}ms",
297 field_id,
298 candidates.len(),
299 scored.len(),
300 skipped,
301 total_vectors,
302 config.unit_norm,
303 read_score_elapsed.as_secs_f64() * 1000.0,
304 t0.elapsed().as_secs_f64() * 1000.0,
305 );
306
307 Ok(scored)
308}
309
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::{Document, Field};

    /// Tolerance for floating-point score comparisons.
    const EPS: f32 = 1e-6;

    /// Config over `Field(0)` with cosine scoring (`unit_norm = false`).
    fn make_config(vector: Vec<f32>, combiner: MultiValueCombiner) -> RerankerConfig {
        RerankerConfig {
            field: Field(0),
            vector,
            combiner,
            unit_norm: false,
        }
    }

    /// Builds a document holding each of `vectors` as a dense value of `field`,
    /// in order.
    fn doc_with_vectors(field: Field, vectors: Vec<Vec<f32>>) -> Document {
        let mut doc = Document::new();
        for vector in vectors {
            doc.add_dense_vector(field, vector);
        }
        doc
    }

    #[test]
    fn test_score_document_single_value() {
        let doc = doc_with_vectors(Field(0), vec![vec![1.0, 0.0, 0.0]]);
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);

        let (score, positions) = score_document(&doc, &config).unwrap();

        assert!((score - 1.0).abs() < EPS);
        assert_eq!(positions.len(), 1);
        assert_eq!(positions[0].position, 0);
    }

    #[test]
    fn test_score_document_orthogonal() {
        let doc = doc_with_vectors(Field(0), vec![vec![0.0, 1.0, 0.0]]);
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);

        let (score, _) = score_document(&doc, &config).unwrap();

        assert!(score.abs() < EPS);
    }

    #[test]
    fn test_score_document_multi_value_max() {
        let doc = doc_with_vectors(Field(0), vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]]);
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);

        let (score, positions) = score_document(&doc, &config).unwrap();

        assert!((score - 1.0).abs() < EPS);
        assert_eq!(positions.len(), 2);
        // Best-scoring ordinal comes first.
        assert_eq!(positions[0].position, 0);
        assert!((positions[0].score - 1.0).abs() < EPS);
    }

    #[test]
    fn test_score_document_multi_value_avg() {
        let doc = doc_with_vectors(Field(0), vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]]);
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Avg);

        let (score, _) = score_document(&doc, &config).unwrap();

        assert!((score - 0.5).abs() < EPS);
    }

    #[test]
    fn test_score_document_missing_field() {
        let doc = doc_with_vectors(Field(1), vec![vec![1.0, 0.0, 0.0]]);
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);

        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_wrong_field_type() {
        let mut doc = Document::new();
        doc.add_text(Field(0), "not a vector");
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);

        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_dimension_mismatch() {
        let doc = doc_with_vectors(Field(0), vec![vec![1.0, 0.0]]);
        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);

        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_empty_query_vector() {
        let doc = doc_with_vectors(Field(0), vec![vec![1.0, 0.0, 0.0]]);
        let config = make_config(vec![], MultiValueCombiner::Max);

        assert!(score_document(&doc, &config).is_none());
    }
}