1use rustc_hash::FxHashMap;
11
12use crate::dsl::Field;
13
14use super::{MultiValueCombiner, ScoredPosition, SearchResult};
15
/// Query-side values prepared once and then shared by every scored batch.
struct PrecompQuery<'a> {
    /// The query vector as `f32` components.
    query: &'a [f32],
    /// The query components passed through `simd::f32_to_f16`, one `u16`
    /// per component; used by the f16 kernels.
    query_f16: &'a [u16],
    /// Precomputed inverse norm of `query` handed to every kernel
    /// (`rerank` stores `0.0` when the squared norm is below `f32::EPSILON`).
    inv_norm_q: f32,
}
22
23#[inline]
25#[allow(clippy::too_many_arguments)]
26fn score_batch_precomp(
27 pq: &PrecompQuery<'_>,
28 raw: &[u8],
29 quant: crate::dsl::DenseVectorQuantization,
30 dim: usize,
31 scores: &mut [f32],
32 unit_norm: bool,
33) {
34 let query = pq.query;
35 let inv_norm_q = pq.inv_norm_q;
36 let query_f16 = pq.query_f16;
37 use crate::dsl::DenseVectorQuantization;
38 use crate::structures::simd;
39 match (quant, unit_norm) {
40 (DenseVectorQuantization::F32, false) => {
41 let num_floats = scores.len() * dim;
42 let vectors: &[f32] =
43 unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const f32, num_floats) };
44 simd::batch_cosine_scores_precomp(query, vectors, dim, scores, inv_norm_q);
45 }
46 (DenseVectorQuantization::F32, true) => {
47 let num_floats = scores.len() * dim;
48 let vectors: &[f32] =
49 unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const f32, num_floats) };
50 simd::batch_dot_scores_precomp(query, vectors, dim, scores, inv_norm_q);
51 }
52 (DenseVectorQuantization::F16, false) => {
53 simd::batch_cosine_scores_f16_precomp(query_f16, raw, dim, scores, inv_norm_q);
54 }
55 (DenseVectorQuantization::F16, true) => {
56 simd::batch_dot_scores_f16_precomp(query_f16, raw, dim, scores, inv_norm_q);
57 }
58 (DenseVectorQuantization::UInt8, false) => {
59 simd::batch_cosine_scores_u8_precomp(query, raw, dim, scores, inv_norm_q);
60 }
61 (DenseVectorQuantization::UInt8, true) => {
62 simd::batch_dot_scores_u8_precomp(query, raw, dim, scores, inv_norm_q);
63 }
64 }
65}
66
/// Settings for [`rerank`]: which dense-vector field to score, the query
/// vector to score it against, and how a document's per-vector scores are
/// merged.
#[derive(Debug, Clone)]
pub struct RerankerConfig {
    /// Field whose stored flat vectors are scored.
    pub field: Field,
    /// Query vector. An empty vector makes `rerank` return no results, and
    /// segments whose stored dimension differs from its length are skipped.
    pub vector: Vec<f32>,
    /// Merges the `(ordinal, score)` pairs of a document's vectors into the
    /// document's final score.
    pub combiner: MultiValueCombiner,
    /// Selects the dot-product kernels instead of the cosine kernels.
    /// NOTE(review): presumably set when stored vectors are already
    /// unit-normalized — confirm with the indexing side.
    pub unit_norm: bool,
}
80
81#[cfg(test)]
83use crate::structures::simd::cosine_similarity;
/// Test-only reference scorer mirroring the per-document work of
/// [`rerank`], computed from an in-memory [`crate::dsl::Document`] instead of
/// stored flat vectors.
///
/// Every dense vector of `config.field` is scored with cosine similarity.
/// Vectors whose length differs from the query are ignored but still consume
/// an ordinal, so each reported ordinal is the vector's position among the
/// field's dense vectors.
///
/// Returns the combined score plus the per-vector positions (best score
/// first, NaN last). Returns `None` when the query is empty or no vector has
/// the query's dimension.
#[cfg(test)]
fn score_document(
    doc: &crate::dsl::Document,
    config: &RerankerConfig,
) -> Option<(f32, Vec<ScoredPosition>)> {
    // Same contract as `rerank`, which returns nothing for an empty query.
    if config.vector.is_empty() {
        return None;
    }

    let query_dim = config.vector.len();
    let mut values: Vec<(u32, f32)> = doc
        .get_all(config.field)
        .filter_map(|fv| fv.as_dense_vector())
        // Enumerate before filtering so ordinals count every dense vector.
        .enumerate()
        .filter_map(|(ordinal, vec)| {
            if vec.len() != query_dim {
                return None;
            }
            Some((ordinal as u32, cosine_similarity(&config.vector, vec)))
        })
        .collect();

    if values.is_empty() {
        return None;
    }

    let combined = config.combiner.combine(&values);

    // Descending by score with NaN last. Unlike `partial_cmp` with an
    // `Equal` fallback, this is a total order, which std sorts require.
    values.sort_by(|a, b| match (a.1.is_nan(), b.1.is_nan()) {
        (true, true) => std::cmp::Ordering::Equal,
        (true, false) => std::cmp::Ordering::Greater,
        (false, true) => std::cmp::Ordering::Less,
        (false, false) => b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal),
    });

    let positions: Vec<ScoredPosition> = values
        .into_iter()
        .map(|(ordinal, score)| ScoredPosition::new(ordinal, score))
        .collect();

    Some((combined, positions))
}
118
119pub async fn rerank<D: crate::directories::Directory + 'static>(
129 searcher: &crate::index::Searcher<D>,
130 candidates: &[SearchResult],
131 config: &RerankerConfig,
132 final_limit: usize,
133) -> crate::error::Result<Vec<SearchResult>> {
134 if config.vector.is_empty() || candidates.is_empty() {
135 return Ok(Vec::new());
136 }
137
138 let t0 = std::time::Instant::now();
139 let field_id = config.field.0;
140 let query = &config.vector;
141 let query_dim = query.len();
142 let segments = searcher.segment_readers();
143 let seg_by_id = searcher.segment_map();
144
145 use crate::structures::simd;
147 let norm_q_sq = simd::dot_product_f32(query, query, query_dim);
148 let inv_norm_q = if norm_q_sq < f32::EPSILON {
149 0.0
150 } else {
151 simd::fast_inv_sqrt(norm_q_sq)
152 };
153 let query_f16: Vec<u16> = query.iter().map(|&v| simd::f32_to_f16(v)).collect();
154 let pq = PrecompQuery {
155 query,
156 inv_norm_q,
157 query_f16: &query_f16,
158 };
159
160 let mut segment_groups: FxHashMap<usize, Vec<usize>> = FxHashMap::default();
162 let mut skipped = 0u32;
163
164 for (ci, candidate) in candidates.iter().enumerate() {
165 if let Some(&si) = seg_by_id.get(&candidate.segment_id) {
166 segment_groups.entry(si).or_default().push(ci);
167 } else {
168 skipped += 1;
169 }
170 }
171
172 let mut all_scores: Vec<(usize, u32, f32)> = Vec::new();
176 let mut total_vectors = 0usize;
177 let mut raw_buf: Vec<u8> = Vec::new();
179 let mut scores_buf: Vec<f32> = Vec::new();
180
181 for (si, candidate_indices) in &segment_groups {
182 let Some(lazy_flat) = segments[*si].flat_vectors().get(&field_id) else {
183 skipped += candidate_indices.len() as u32;
184 continue;
185 };
186 if lazy_flat.dim != query_dim {
187 skipped += candidate_indices.len() as u32;
188 continue;
189 }
190
191 let vbs = lazy_flat.vector_byte_size();
192 let quant = lazy_flat.quantization;
193
194 let mut resolved: Vec<(usize, usize, u32)> = Vec::new();
197 for &ci in candidate_indices {
198 let local_doc_id = candidates[ci].doc_id - segments[*si].doc_id_offset();
199 let (start, count) = lazy_flat.flat_indexes_for_doc_range(local_doc_id);
200 if count == 0 {
201 skipped += 1;
202 continue;
203 }
204 for j in 0..count {
205 let (_, ordinal) = lazy_flat.get_doc_id(start + j);
206 resolved.push((ci, start + j, ordinal as u32));
207 }
208 }
209
210 if resolved.is_empty() {
211 continue;
212 }
213
214 let n = resolved.len();
215 total_vectors += n;
216
217 resolved.sort_unstable_by_key(|&(_, flat_idx, _)| flat_idx);
219
220 let first_idx = resolved[0].1;
223 let last_idx = resolved[n - 1].1;
224 let span = last_idx - first_idx + 1;
225
226 raw_buf.resize(n * vbs, 0);
227
228 if span <= n * 4 {
231 let range_bytes = match lazy_flat.read_vectors_batch(first_idx, span).await {
232 Ok(b) => b,
233 Err(_) => continue,
234 };
235 let rb = range_bytes.as_slice();
236 for (buf_idx, &(_, flat_idx, _)) in resolved.iter().enumerate() {
237 let rel = flat_idx - first_idx;
238 let src = &rb[rel * vbs..(rel + 1) * vbs];
239 raw_buf[buf_idx * vbs..(buf_idx + 1) * vbs].copy_from_slice(src);
240 }
241 } else {
242 for (buf_idx, &(_, flat_idx, _)) in resolved.iter().enumerate() {
243 let _ = lazy_flat
244 .read_vector_raw_into(
245 flat_idx,
246 &mut raw_buf[buf_idx * vbs..(buf_idx + 1) * vbs],
247 )
248 .await;
249 }
250 }
251
252 scores_buf.resize(n, 0.0);
254 score_batch_precomp(
255 &pq,
256 &raw_buf[..n * vbs],
257 quant,
258 query_dim,
259 &mut scores_buf[..n],
260 config.unit_norm,
261 );
262
263 all_scores.reserve(n);
265 for (buf_idx, &(ci, _, ordinal)) in resolved.iter().enumerate() {
266 all_scores.push((ci, ordinal, scores_buf[buf_idx]));
267 }
268 }
269
270 let read_score_elapsed = t0.elapsed();
271
272 if total_vectors == 0 {
273 log::debug!(
274 "[reranker] field {}: {} candidates, all skipped (no flat vectors)",
275 field_id,
276 candidates.len()
277 );
278 return Ok(Vec::new());
279 }
280
281 all_scores.sort_unstable_by_key(|&(ci, _, _)| ci);
284
285 let mut scored: Vec<SearchResult> = Vec::with_capacity(candidates.len().min(final_limit * 2));
286 let mut i = 0;
287 while i < all_scores.len() {
288 let ci = all_scores[i].0;
289 let run_start = i;
290 while i < all_scores.len() && all_scores[i].0 == ci {
291 i += 1;
292 }
293 let run = &mut all_scores[run_start..i];
294
295 let ordinal_pairs: Vec<(u32, f32)> = run.iter().map(|&(_, ord, s)| (ord, s)).collect();
297 let combined = config.combiner.combine(&ordinal_pairs);
298
299 run.sort_unstable_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
301 let positions: Vec<ScoredPosition> = run
302 .iter()
303 .map(|&(_, ord, score)| ScoredPosition::new(ord, score))
304 .collect();
305
306 scored.push(SearchResult {
307 doc_id: candidates[ci].doc_id,
308 score: combined,
309 segment_id: candidates[ci].segment_id,
310 positions: vec![(field_id, positions)],
311 });
312 }
313
314 scored.sort_by(|a, b| {
315 b.score
316 .partial_cmp(&a.score)
317 .unwrap_or(std::cmp::Ordering::Equal)
318 });
319 scored.truncate(final_limit);
320
321 log::debug!(
322 "[reranker] field {}: {} candidates -> {} results (skipped {}, {} vectors, unit_norm={}): read+score={:.1}ms total={:.1}ms",
323 field_id,
324 candidates.len(),
325 scored.len(),
326 skipped,
327 total_vectors,
328 config.unit_norm,
329 read_score_elapsed.as_secs_f64() * 1000.0,
330 t0.elapsed().as_secs_f64() * 1000.0,
331 );
332
333 Ok(scored)
334}
335
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::{Document, Field};

    /// Config targeting field 0 with cosine (non-unit-norm) scoring.
    fn make_config(vector: Vec<f32>, combiner: MultiValueCombiner) -> RerankerConfig {
        RerankerConfig {
            field: Field(0),
            vector,
            combiner,
            unit_norm: false,
        }
    }

    #[test]
    fn test_score_document_single_value() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]);

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        let (score, positions) = score_document(&doc, &config).unwrap();
        // Identical vectors have cosine similarity 1.
        assert!((score - 1.0).abs() < 1e-6);
        assert_eq!(positions.len(), 1);
        assert_eq!(positions[0].position, 0);
    }

    #[test]
    fn test_score_document_orthogonal() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![0.0, 1.0, 0.0]);

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        let (score, _) = score_document(&doc, &config).unwrap();
        // Orthogonal vectors have cosine similarity 0.
        assert!(score.abs() < 1e-6);
    }

    #[test]
    fn test_score_document_multi_value_max() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]);
        doc.add_dense_vector(Field(0), vec![0.0, 1.0, 0.0]);

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        let (score, positions) = score_document(&doc, &config).unwrap();
        assert!((score - 1.0).abs() < 1e-6);
        assert_eq!(positions.len(), 2);
        // The best match (ordinal 0) is reported first.
        assert_eq!(positions[0].position, 0);
        assert!((positions[0].score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_score_document_multi_value_avg() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]);
        doc.add_dense_vector(Field(0), vec![0.0, 1.0, 0.0]);

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Avg);
        let (score, _) = score_document(&doc, &config).unwrap();
        // Average of 1.0 and 0.0.
        assert!((score - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_score_document_missing_field() {
        let mut doc = Document::new();
        // Vector stored under a different field than the one configured.
        doc.add_dense_vector(Field(1), vec![1.0, 0.0, 0.0]);

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_wrong_field_type() {
        let mut doc = Document::new();
        doc.add_text(Field(0), "not a vector");

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_dimension_mismatch() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![1.0, 0.0]);

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_empty_query_vector() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]);

        let config = make_config(vec![], MultiValueCombiner::Max);
        assert!(score_document(&doc, &config).is_none());
    }

    #[test]
    fn test_score_document_positions_sorted_by_score() {
        let mut doc = Document::new();
        doc.add_dense_vector(Field(0), vec![0.0, 1.0, 0.0]);
        doc.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]);

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        let (score, positions) = score_document(&doc, &config).unwrap();
        assert!((score - 1.0).abs() < 1e-6);
        assert_eq!(positions.len(), 2);
        // The second vector (ordinal 1) matches best and is reported first.
        assert_eq!(positions[0].position, 1);
        assert!((positions[0].score - 1.0).abs() < 1e-6);
        assert_eq!(positions[1].position, 0);
        assert!(positions[1].score.abs() < 1e-6);
    }

    #[test]
    fn test_score_document_mismatched_dim_keeps_ordinals() {
        let mut doc = Document::new();
        // Wrong dimension: ignored, but it still occupies ordinal 0.
        doc.add_dense_vector(Field(0), vec![1.0, 0.0]);
        doc.add_dense_vector(Field(0), vec![1.0, 0.0, 0.0]);

        let config = make_config(vec![1.0, 0.0, 0.0], MultiValueCombiner::Max);
        let (_, positions) = score_document(&doc, &config).unwrap();
        assert_eq!(positions.len(), 1);
        assert_eq!(positions[0].position, 1);
    }
}