1use nodedb_types::vector_ann::VectorAnnOptions;
4use nodedb_types::vector_distance::DistanceMetric;
5
6use super::gating::codec_name_for_quant;
7use super::sidecar::CodecSidecar;
8use super::types::{Candidate, Ranked, RerankError};
9
10pub fn rerank<'v, F>(
29 candidates: Vec<Candidate>,
30 query: &[f32],
31 metric: DistanceMetric,
32 k: usize,
33 opts: &VectorAnnOptions,
34 sidecar: Option<&CodecSidecar>,
35 mut fetch_vector: F,
36) -> Result<Vec<Ranked>, RerankError>
37where
38 F: FnMut(u32) -> Option<&'v [f32]>,
39{
40 if k == 0 {
41 return Err(RerankError::BadInput("k must be > 0".into()));
42 }
43 if query.is_empty() {
44 return Err(RerankError::BadInput("query is empty".into()));
45 }
46
47 let requested_codec = opts.quantization.and_then(codec_name_for_quant);
49
50 if opts.query_dim.is_some() && requested_codec.is_some() {
52 return Err(RerankError::BadInput(
53 "rerank: query_dim (Matryoshka truncation) is not yet supported in combination \
54 with quantization codecs — use one or the other"
55 .into(),
56 ));
57 }
58
59 if candidates.is_empty() {
60 return Ok(Vec::new());
61 }
62
63 if let Some(requested) = requested_codec {
65 let sc = sidecar.ok_or_else(|| {
66 RerankError::BadInput(
67 "rerank: opts.quantization requested but no codec sidecar provided".into(),
68 )
69 })?;
70
71 let actual = sc.codec_name();
72 if actual != requested {
73 return Err(RerankError::BadInput(format!(
74 "rerank: requested codec {requested:?} does not match sidecar codec {actual:?}"
75 )));
76 }
77
78 let prepared = sc.prepare_query(query)?;
79
80 let mut scored: Vec<Ranked> = Vec::with_capacity(candidates.len());
81 for c in candidates {
82 match sc.distance_prepared(&prepared, c.id)? {
83 None => {
84 return Err(RerankError::BadInput(format!(
85 "rerank: candidate id {} not present in sidecar (index/sidecar drift)",
86 c.id
87 )));
88 }
89 Some(d) => {
90 scored.push(Ranked {
91 id: c.id,
92 distance: d,
93 });
94 }
95 }
96 }
97
98 scored.sort_unstable_by(|a, b| {
99 a.distance
100 .partial_cmp(&b.distance)
101 .unwrap_or(std::cmp::Ordering::Equal)
102 });
103 scored.truncate(k);
104
105 let _ = &mut fetch_vector;
107 return Ok(scored);
108 }
109
110 let effective_dim: usize = match opts.query_dim {
112 Some(d) => {
113 let d = d as usize;
114 if d == 0 || d > query.len() {
115 return Err(RerankError::BadInput(format!(
116 "query_dim={d} is out of range; query has {} dimensions \
117 (must be 0 < query_dim <= query.len())",
118 query.len(),
119 )));
120 }
121 d
122 }
123 None => query.len(),
124 };
125
126 let query_slice = crate::matryoshka::truncate(query, effective_dim);
128
129 let mut scored: Vec<Ranked> = Vec::with_capacity(candidates.len());
130 let query_dim = query.len();
131
132 for c in candidates {
133 let vec = fetch_vector(c.id).ok_or_else(|| {
134 RerankError::BadInput(format!(
135 "rerank: fetch_vector returned None for id {}",
136 c.id
137 ))
138 })?;
139 if vec.len() != query_dim {
140 return Err(RerankError::BadInput(format!(
141 "candidate id={} has dim {} but query has dim {}",
142 c.id,
143 vec.len(),
144 query_dim,
145 )));
146 }
147 let vec_slice = crate::matryoshka::truncate(vec, effective_dim);
148 let d = crate::distance::distance(query_slice, vec_slice, metric);
149 scored.push(Ranked {
150 id: c.id,
151 distance: d,
152 });
153 }
154
155 scored.sort_unstable_by(|a, b| {
156 a.distance
157 .partial_cmp(&b.distance)
158 .unwrap_or(std::cmp::Ordering::Equal)
159 });
160 scored.truncate(k);
161
162 Ok(scored)
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::sync::Arc;

    use nodedb_types::vector_ann::{VectorAnnOptions, VectorQuantization};

    use crate::rerank::codec::{CodecName, PreparedQuery, RerankCodec};
    use crate::rerank::sidecar::CodecSidecar;
    use crate::rerank::types::RerankError;

    // ------------------------------------------------------------------
    // Fixtures
    // ------------------------------------------------------------------

    /// Options with nothing set: full-dimension FP32 rerank.
    fn plain_opts() -> VectorAnnOptions {
        VectorAnnOptions::default()
    }

    /// Options requesting Matryoshka truncation to `dim` leading dimensions.
    fn truncated_opts(dim: u32) -> VectorAnnOptions {
        VectorAnnOptions {
            query_dim: Some(dim),
            ..Default::default()
        }
    }

    /// Options requesting the quantized (sidecar) scoring path.
    fn quantized_opts(quant: VectorQuantization) -> VectorAnnOptions {
        VectorAnnOptions {
            quantization: Some(quant),
            ..Default::default()
        }
    }

    /// One candidate per id; `rerank` does not read `index_distance`.
    fn candidates(ids: &[u32]) -> Vec<Candidate> {
        ids.iter()
            .map(|&id| Candidate {
                id,
                index_distance: 0.0,
            })
            .collect()
    }

    /// Owned `id -> vector` table backing the `fetch_vector` callback.
    fn vectors(entries: &[(u32, Vec<f32>)]) -> HashMap<u32, Vec<f32>> {
        entries.iter().cloned().collect()
    }

    /// A `fetch_vector` callback that borrows vectors out of `table`.
    fn lookup<'a>(table: &'a HashMap<u32, Vec<f32>>) -> impl FnMut(u32) -> Option<&'a [f32]> {
        move |id| table.get(&id).map(Vec::as_slice)
    }

    /// Calls `rerank` with the L2 metric for the given candidate ids.
    fn rerank_l2<'v, F>(
        ids: &[u32],
        query: &[f32],
        k: usize,
        opts: &VectorAnnOptions,
        sidecar: Option<&CodecSidecar>,
        fetch_vector: F,
    ) -> Result<Vec<Ranked>, RerankError>
    where
        F: FnMut(u32) -> Option<&'v [f32]>,
    {
        rerank(
            candidates(ids),
            query,
            DistanceMetric::L2,
            k,
            opts,
            sidecar,
            fetch_vector,
        )
    }

    /// The ids of a ranked result, in rank order.
    fn ids_of(ranked: &[Ranked]) -> Vec<u32> {
        ranked.iter().map(|r| r.id).collect()
    }

    /// Test codec: stores raw little-endian `f32`s and scores with L2.
    struct StubCodec {
        name: CodecName,
    }

    impl RerankCodec for StubCodec {
        fn encode(&self, v: &[f32]) -> Result<Vec<u8>, RerankError> {
            let mut bytes = Vec::with_capacity(v.len() * 4);
            for x in v {
                bytes.extend_from_slice(&x.to_le_bytes());
            }
            Ok(bytes)
        }

        fn prepare_query(&self, q: &[f32]) -> Result<PreparedQuery, RerankError> {
            Ok(PreparedQuery::Raw(q.to_vec()))
        }

        fn distance_prepared(
            &self,
            prepared: &PreparedQuery,
            encoded: &[u8],
        ) -> Result<f32, RerankError> {
            let query = match prepared {
                PreparedQuery::Raw(raw) => raw,
                _ => {
                    return Err(RerankError::BadInput(
                        "StubCodec expects Raw prepared query".into(),
                    ))
                }
            };
            let stored: Vec<f32> = encoded
                .chunks_exact(4)
                .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
                .collect();
            if stored.len() != query.len() {
                return Err(RerankError::BadInput("dimension mismatch".into()));
            }
            let sum_sq: f32 = query
                .iter()
                .zip(stored.iter())
                .map(|(a, b)| {
                    let diff = a - b;
                    diff * diff
                })
                .sum();
            Ok(sum_sq.sqrt())
        }

        fn name(&self) -> CodecName {
            self.name
        }

        fn to_bytes(&self) -> Result<Vec<u8>, RerankError> {
            Err(RerankError::BadInput(
                "StubCodec does not support serialization".into(),
            ))
        }
    }

    /// A sidecar wrapping a `StubCodec` that reports `name`.
    fn stub_sidecar(name: CodecName) -> CodecSidecar {
        CodecSidecar::new(Arc::new(StubCodec { name }))
    }

    // ------------------------------------------------------------------
    // FP32 path
    // ------------------------------------------------------------------

    #[test]
    fn happy_path_top2() {
        let table = vectors(&[
            (1, vec![1.0, 0.0]),
            (2, vec![0.1, 0.0]),
            (3, vec![0.5, 0.0]),
            (4, vec![2.0, 0.0]),
            (5, vec![0.3, 0.0]),
        ]);
        let ranked = rerank_l2(
            &[1, 2, 3, 4, 5],
            &[0.0, 0.0],
            2,
            &plain_opts(),
            None,
            lookup(&table),
        )
        .unwrap();
        // Nearest to the origin: id 2 (0.1), then id 5 (0.3).
        assert_eq!(ids_of(&ranked), vec![2, 5]);
    }

    #[test]
    fn empty_candidates_returns_empty() {
        let table: HashMap<u32, Vec<f32>> = HashMap::new();
        let ranked = rerank_l2(&[], &[1.0, 2.0], 3, &plain_opts(), None, lookup(&table)).unwrap();
        assert!(ranked.is_empty());
    }

    #[test]
    fn dim_mismatch_returns_bad_input() {
        let table = vectors(&[(7, vec![1.0, 2.0, 3.0])]);
        let msg = rerank_l2(&[7], &[1.0, 2.0], 1, &plain_opts(), None, lookup(&table))
            .unwrap_err()
            .to_string();
        assert!(msg.contains("7"), "expected id in message: {msg}");
        assert!(
            msg.contains("3"),
            "expected candidate dim in message: {msg}"
        );
        assert!(msg.contains("2"), "expected query dim in message: {msg}");
    }

    #[test]
    fn k_zero_returns_bad_input() {
        let table = vectors(&[(1, vec![1.0])]);
        let err = rerank_l2(&[1], &[0.0], 0, &plain_opts(), None, lookup(&table)).unwrap_err();
        assert!(err.to_string().contains("k must be > 0"));
    }

    #[test]
    fn k_exceeds_candidates_returns_all() {
        let table = vectors(&[
            (1, vec![1.0, 0.0]),
            (2, vec![2.0, 0.0]),
            (3, vec![3.0, 0.0]),
        ]);
        let ranked = rerank_l2(
            &[1, 2, 3],
            &[0.0, 0.0],
            10,
            &plain_opts(),
            None,
            lookup(&table),
        )
        .unwrap();
        assert_eq!(ranked.len(), 3);
    }

    #[test]
    fn query_dim_truncated_ranking_differs_from_full() {
        let table = vectors(&[(1, vec![0.1, 0.1]), (2, vec![0.0, 9.0])]);
        let query = [0.0_f32, 1.0];

        let full = rerank_l2(&[1, 2], &query, 1, &plain_opts(), None, lookup(&table)).unwrap();
        assert_eq!(full[0].id, 1, "full-dim should rank id=1 first");

        let trunc = rerank_l2(
            &[1, 2],
            &query,
            1,
            &truncated_opts(1),
            None,
            lookup(&table),
        )
        .unwrap();
        assert_eq!(trunc[0].id, 2, "truncated-dim=1 should rank id=2 first");
    }

    #[test]
    fn query_dim_zero_returns_bad_input() {
        let table = vectors(&[(1, vec![1.0, 2.0])]);
        let msg = rerank_l2(
            &[1],
            &[0.0, 0.0],
            1,
            &truncated_opts(0),
            None,
            lookup(&table),
        )
        .unwrap_err()
        .to_string();
        assert!(
            msg.contains("query_dim=0"),
            "error should name query_dim=0: {msg}"
        );
    }

    #[test]
    fn query_dim_exceeds_query_len_returns_bad_input() {
        let table = vectors(&[(1, vec![1.0, 2.0])]);
        let msg = rerank_l2(
            &[1],
            &[0.0, 0.0],
            1,
            &truncated_opts(5),
            None,
            lookup(&table),
        )
        .unwrap_err()
        .to_string();
        assert!(
            msg.contains("query_dim=5"),
            "error should name query_dim=5: {msg}"
        );
        assert!(
            msg.contains('2'),
            "error should mention query len (2): {msg}"
        );
    }

    #[test]
    fn query_dim_equal_to_query_len_matches_full_dim() {
        let table = vectors(&[
            (1, vec![1.0, 0.0, 0.0]),
            (2, vec![0.5, 0.0, 0.0]),
            (3, vec![3.0, 0.0, 0.0]),
        ]);
        let query = [0.0_f32, 0.0, 0.0];
        let ids: [u32; 3] = [1, 2, 3];

        let full = rerank_l2(&ids, &query, 3, &plain_opts(), None, lookup(&table)).unwrap();
        let trunc = rerank_l2(&ids, &query, 3, &truncated_opts(3), None, lookup(&table)).unwrap();

        assert_eq!(
            ids_of(&full),
            ids_of(&trunc),
            "query_dim == query.len() should produce identical ranking"
        );
    }

    #[test]
    fn fetch_returns_none_is_bad_input() {
        let table: HashMap<u32, Vec<f32>> = HashMap::new();
        let err = rerank_l2(&[99], &[0.0, 0.0], 1, &plain_opts(), None, lookup(&table))
            .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("99"),
            "error should name the missing id (99): {msg}"
        );
        assert!(
            matches!(err, RerankError::BadInput(_)),
            "expected BadInput, got: {err}"
        );
    }

    // ------------------------------------------------------------------
    // Codec (sidecar) path
    // ------------------------------------------------------------------

    #[test]
    fn codec_path_uses_sidecar() {
        let mut sidecar = stub_sidecar(CodecName::Binary);
        sidecar.encode_and_insert(1, &[1.0, 0.0]).unwrap();
        sidecar.encode_and_insert(2, &[0.0, 2.0]).unwrap();
        sidecar.encode_and_insert(3, &[3.0, 0.0]).unwrap();

        let ranked = rerank_l2(
            &[1, 2, 3],
            &[0.0, 0.0],
            3,
            &quantized_opts(VectorQuantization::Binary),
            Some(&sidecar),
            |_id| panic!("fetch_vector must not be called in codec path"),
        )
        .unwrap();

        // Stub L2 distances from the origin are 1, 2 and 3 respectively.
        assert_eq!(ids_of(&ranked), vec![1, 2, 3]);
        assert!((ranked[0].distance - 1.0).abs() < 1e-5);
    }

    #[test]
    fn codec_requested_but_no_sidecar_returns_bad_input() {
        let table: HashMap<u32, Vec<f32>> = HashMap::new();
        let err = rerank_l2(
            &[1],
            &[0.0, 0.0],
            1,
            &quantized_opts(VectorQuantization::Binary),
            None,
            lookup(&table),
        )
        .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("no codec sidecar provided"),
            "expected sidecar-missing message: {msg}"
        );
        assert!(matches!(err, RerankError::BadInput(_)));
    }

    #[test]
    fn codec_name_mismatch_returns_bad_input() {
        let mut sidecar = stub_sidecar(CodecName::Binary);
        sidecar.encode_and_insert(1, &[1.0, 0.0]).unwrap();
        let table: HashMap<u32, Vec<f32>> = HashMap::new();

        let err = rerank_l2(
            &[1],
            &[0.0, 0.0],
            1,
            &quantized_opts(VectorQuantization::Sq8),
            Some(&sidecar),
            lookup(&table),
        )
        .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("Sq8") || msg.contains("sq8"),
            "expected requested codec in message: {msg}"
        );
        assert!(
            msg.contains("Binary") || msg.contains("binary"),
            "expected actual codec in message: {msg}"
        );
        assert!(matches!(err, RerankError::BadInput(_)));
    }

    #[test]
    fn codec_with_query_dim_returns_bad_input() {
        let mut sidecar = stub_sidecar(CodecName::Binary);
        sidecar.encode_and_insert(1, &[1.0, 0.0]).unwrap();
        let table: HashMap<u32, Vec<f32>> = HashMap::new();

        let both = VectorAnnOptions {
            query_dim: Some(1),
            quantization: Some(VectorQuantization::Binary),
            ..Default::default()
        };
        let err = rerank_l2(
            &[1],
            &[0.0, 0.0],
            1,
            &both,
            Some(&sidecar),
            lookup(&table),
        )
        .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("query_dim") && msg.contains("quantization"),
            "expected both terms in message: {msg}"
        );
        assert!(matches!(err, RerankError::BadInput(_)));
    }

    #[test]
    fn fp32_path_with_some_sidecar_argument() {
        // A sidecar is supplied but no quantization is requested, so the FP32
        // path must be used and the sidecar ignored.
        let sidecar = stub_sidecar(CodecName::Binary);
        let table = vectors(&[(1, vec![1.0, 0.0]), (2, vec![0.1, 0.0])]);

        let ranked = rerank_l2(
            &[1, 2],
            &[0.0, 0.0],
            2,
            &plain_opts(),
            Some(&sidecar),
            lookup(&table),
        )
        .unwrap();
        assert_eq!(ranked[0].id, 2);
        assert_eq!(ranked[1].id, 1);
    }

    #[test]
    fn codec_path_missing_id_in_sidecar_returns_bad_input() {
        let sidecar = stub_sidecar(CodecName::Binary);
        let table: HashMap<u32, Vec<f32>> = HashMap::new();

        let err = rerank_l2(
            &[99],
            &[0.0, 0.0],
            1,
            &quantized_opts(VectorQuantization::Binary),
            Some(&sidecar),
            lookup(&table),
        )
        .unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("99"), "expected id 99 in message: {msg}");
        assert!(
            msg.contains("sidecar drift") || msg.contains("not present in sidecar"),
            "expected drift message: {msg}"
        );
        assert!(matches!(err, RerankError::BadInput(_)));
    }
}