use super::{simd, RerankConfig, RerankError, Result};
use std::collections::HashMap;
use std::hash::Hash;
/// Re-scores `candidates` against `query` using the default [`RerankConfig`].
///
/// Convenience wrapper: every argument is forwarded unchanged to
/// [`try_refine`], which does the actual work.
///
/// # Errors
///
/// Propagates the error from [`try_refine`], i.e.
/// [`RerankError::InvalidHeadDims`] when `head_dims >= query.len()`.
pub fn refine<I: Clone + Eq + Hash>(
    candidates: &[(I, f32)],
    query: &[f32],
    docs: &[(I, Vec<f32>)],
    head_dims: usize,
) -> Result<Vec<(I, f32)>> {
    let default_config = RerankConfig::default();
    try_refine(candidates, query, docs, head_dims, default_config)
}
/// Re-scores `candidates` like [`refine`], but with a caller-supplied blend
/// weight.
///
/// The configuration is `RerankConfig::default().with_alpha(alpha)`. In
/// [`try_refine`] the configured alpha multiplies the original candidate score
/// and `1.0 - alpha` multiplies the tail similarity.
/// NOTE(review): this assumes `with_alpha` stores `alpha` as given (no
/// clamping) - confirm against `RerankConfig`.
///
/// # Errors
///
/// Propagates the error from [`try_refine`], i.e.
/// [`RerankError::InvalidHeadDims`] when `head_dims >= query.len()`.
pub fn refine_with_alpha<I: Clone + Eq + Hash>(
    candidates: &[(I, f32)],
    query: &[f32],
    docs: &[(I, Vec<f32>)],
    head_dims: usize,
    alpha: f32,
) -> Result<Vec<(I, f32)>> {
    let weighted_config = RerankConfig::default().with_alpha(alpha);
    try_refine(candidates, query, docs, head_dims, weighted_config)
}
/// Re-scores `candidates` using the `RerankConfig::refinement_only()` preset.
///
/// Every argument is forwarded unchanged to [`try_refine`].
/// NOTE(review): the preset presumably gives the original candidate score no
/// weight, so ranking follows tail similarity alone - confirm against
/// `RerankConfig::refinement_only`.
///
/// # Errors
///
/// Propagates the error from [`try_refine`], i.e.
/// [`RerankError::InvalidHeadDims`] when `head_dims >= query.len()`.
pub fn refine_tail_only<I: Clone + Eq + Hash>(
    candidates: &[(I, f32)],
    query: &[f32],
    docs: &[(I, Vec<f32>)],
    head_dims: usize,
) -> Result<Vec<(I, f32)>> {
    let tail_only_config = RerankConfig::refinement_only();
    try_refine(candidates, query, docs, head_dims, tail_only_config)
}
/// Re-scores candidates by comparing the *tail* of each embedding (the
/// components from index `head_dims` onward) with the tail of `query`.
///
/// For every candidate `(id, original_score)` whose `id` has a usable entry in
/// `docs`, the new score is
/// `(1 - alpha) * cosine(query_tail, doc_tail) + alpha * original_score`,
/// where `alpha` is `config.alpha`.
///
/// * Entries in `docs` with `head_dims` or fewer components have no tail and
///   are ignored.
/// * Candidates without a usable doc entry are dropped from the output.
/// * If an id occurs more than once in `docs`, the last usable occurrence wins.
/// * The output is ordered by `super::sort_scored_desc` and, when
///   `config.top_k` is `Some(k)`, truncated to at most `k` entries.
///
/// NOTE(review): a doc embedding whose length differs from `query.len()` is
/// not rejected here; what happens then depends on `simd::cosine` - confirm.
///
/// # Errors
///
/// Returns [`RerankError::InvalidHeadDims`] when `head_dims >= query.len()`,
/// because the query would then have an empty tail.
pub fn try_refine<I: Clone + Eq + Hash>(
    candidates: &[(I, f32)],
    query: &[f32],
    docs: &[(I, Vec<f32>)],
    head_dims: usize,
    config: RerankConfig,
) -> Result<Vec<(I, f32)>> {
    let query_len = query.len();
    if query_len <= head_dims {
        return Err(RerankError::InvalidHeadDims {
            head_dims,
            query_len,
        });
    }
    // Safe split: the guard above guarantees `head_dims < query.len()`.
    let (_, query_tail) = query.split_at(head_dims);

    // Index the tail of every sufficiently long embedding by document id.
    // Later duplicates overwrite earlier ones.
    let mut tails_by_id: HashMap<&I, &[f32]> = HashMap::new();
    for (doc_id, embedding) in docs {
        if embedding.len() > head_dims {
            tails_by_id.insert(doc_id, &embedding[head_dims..]);
        }
    }

    let original_weight = config.alpha;
    let tail_weight = 1.0 - original_weight;

    let mut rescored: Vec<(I, f32)> = Vec::new();
    for (candidate_id, original_score) in candidates {
        // Candidates with no usable embedding are skipped.
        if let Some(&doc_tail) = tails_by_id.get(candidate_id) {
            let tail_similarity = simd::cosine(query_tail, doc_tail);
            let blended =
                tail_weight.mul_add(tail_similarity, original_weight * *original_score);
            rescored.push((candidate_id.clone(), blended));
        }
    }

    super::sort_scored_desc(&mut rescored);
    if let Some(limit) = config.top_k {
        rescored.truncate(limit);
    }
    Ok(rescored)
}
#[cfg(test)]
mod tests {
    use super::*;

    // d2's tail [1, 0] equals the query tail while d1's tail [0, 1] is
    // orthogonal to it, so tail-only refinement is expected to rank d2 first
    // even though its original score is lower.
    #[test]
    fn test_refinement() {
        let candidates = vec![("d1", 0.9), ("d2", 0.1)];
        let query = vec![0.0, 0.0, 1.0, 0.0];
        let docs = vec![
            ("d1", vec![0.0, 0.0, 0.0, 1.0]),
            ("d2", vec![0.0, 0.0, 1.0, 0.0]),
        ];
        let refined = refine_tail_only(&candidates, &query, &docs, 2).unwrap();
        assert_eq!(refined[0].0, "d2");
    }

    // alpha = 1.0 keeps the original ordering (d1 first); alpha = 0.0 ranks by
    // tail similarity (d2 first).
    #[test]
    fn test_refinement_with_alpha() {
        let candidates = vec![("d1", 1.0), ("d2", 0.0)];
        let query = vec![0.0, 0.0, 1.0, 0.0];
        let docs = vec![
            ("d1", vec![0.0, 0.0, 0.0, 1.0]),
            ("d2", vec![0.0, 0.0, 1.0, 0.0]),
        ];
        let refined = refine_with_alpha(&candidates, &query, &docs, 2, 1.0).unwrap();
        assert_eq!(refined[0].0, "d1");
        let refined = refine_with_alpha(&candidates, &query, &docs, 2, 0.0).unwrap();
        assert_eq!(refined[0].0, "d2");
    }

    // head_dims equal to the query length leaves no tail and must be rejected.
    #[test]
    fn test_try_refine_error() {
        let candidates = vec![("d1", 0.9)];
        let query = vec![1.0, 2.0];
        let docs = vec![("d1", vec![1.0, 2.0])];
        let result = try_refine(&candidates, &query, &docs, 2, RerankConfig::default());
        assert!(matches!(result, Err(RerankError::InvalidHeadDims { .. })));
    }

    // A candidate with no entry in `docs` is dropped; the remaining one is kept.
    #[test]
    fn test_missing_candidate() {
        let candidates = vec![("d1", 0.9), ("d2", 0.8)];
        let query = vec![0.0, 0.0, 1.0, 0.0];
        let docs = vec![("d1", vec![0.0, 0.0, 1.0, 0.0])];
        let refined = refine(&candidates, &query, &docs, 2).unwrap();
        assert_eq!(refined.len(), 1);
        assert_eq!(refined[0].0, "d1");
    }

    // An embedding no longer than head_dims has no tail, so its candidate is
    // dropped.
    #[test]
    fn test_short_doc_embedding() {
        let candidates = vec![("d1", 0.9), ("d2", 0.8)];
        let query = vec![0.0, 0.0, 1.0, 0.0];
        let docs = vec![("d1", vec![0.0, 0.0, 1.0, 0.0]), ("d2", vec![0.0])];
        let refined = refine(&candidates, &query, &docs, 2).unwrap();
        assert_eq!(refined.len(), 1);
        assert_eq!(refined[0].0, "d1");
    }

    // Same rejection as `test_try_refine_error`, reached through the `refine`
    // wrapper: the error is returned, not raised as a panic.
    #[test]
    fn test_head_dims_too_large() {
        let candidates = vec![("d1", 0.9)];
        let query = vec![1.0, 2.0];
        let docs = vec![("d1", vec![1.0, 2.0])];
        let result = refine(&candidates, &query, &docs, 2);
        assert!(matches!(result, Err(RerankError::InvalidHeadDims { .. })));
    }

    // A NaN original score must not drop the candidate; the test expects the
    // NaN-scored entry to come first in the output.
    #[test]
    fn test_nan_score_handling() {
        let candidates = vec![("d1", f32::NAN), ("d2", 0.5), ("d3", 0.9)];
        let query = vec![0.0, 0.0, 1.0, 0.0];
        let docs = vec![
            ("d1", vec![0.0, 0.0, 1.0, 0.0]),
            ("d2", vec![0.0, 0.0, 1.0, 0.0]),
            ("d3", vec![0.0, 0.0, 1.0, 0.0]),
        ];
        let refined = refine(&candidates, &query, &docs, 2).unwrap();
        assert_eq!(refined.len(), 3);
        assert!(refined[0].1.is_nan());
    }

    // `top_k = 2` truncates three scored candidates to two.
    #[test]
    fn test_top_k() {
        let candidates = vec![("d1", 0.9), ("d2", 0.8), ("d3", 0.7)];
        let query = vec![0.0, 0.0, 1.0, 0.0];
        let docs = vec![
            ("d1", vec![0.0, 0.0, 1.0, 0.0]),
            ("d2", vec![0.0, 0.0, 1.0, 0.0]),
            ("d3", vec![0.0, 0.0, 1.0, 0.0]),
        ];
        let refined = try_refine(
            &candidates,
            &query,
            &docs,
            2,
            RerankConfig::default().with_top_k(2),
        )
        .unwrap();
        assert_eq!(refined.len(), 2);
    }

    // Both docs share the same first four components and the same original
    // score; only the last four components differ. doc_b's tail equals the
    // query tail, doc_a's tail is all zeros.
    #[test]
    fn test_tail_dims_provide_discrimination() {
        let candidates = vec![("doc_a", 0.5), ("doc_b", 0.5)];
        let query = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0];
        let docs = vec![
            ("doc_a", vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
            ("doc_b", vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]),
        ];
        let refined = refine_tail_only(&candidates, &query, &docs, 4).unwrap();
        assert_eq!(
            refined[0].0, "doc_b",
            "Tail should discriminate: doc_b matches query tail"
        );
        assert!(
            refined[0].1 > refined[1].1,
            "doc_b score {} should be > doc_a score {}",
            refined[0].1,
            refined[1].1
        );
        let score_diff = refined[0].1 - refined[1].1;
        assert!(
            score_diff > 0.9,
            "Score difference {} should be large (orthogonal vs identical)",
            score_diff
        );
    }
}
// Property-based tests (proptest) for the refinement entry points.
#[cfg(test)]
mod proptests {
use super::*;
use proptest::prelude::*;
proptest! {
// Refinement never returns more entries than there were candidates,
// whatever the number of available docs (including none).
#[test]
fn output_bounded_by_candidates(
n_candidates in 1usize..10,
n_docs in 0usize..10,
dim in 4usize..16
) {
let head_dims = dim / 2;
let candidates: Vec<(u32, f32)> = (0..n_candidates as u32)
.map(|i| (i, 1.0 - i as f32 * 0.1))
.collect();
let query: Vec<f32> = (0..dim).map(|i| i as f32 * 0.1).collect();
let docs: Vec<(u32, Vec<f32>)> = (0..n_docs as u32)
.map(|i| (i, (0..dim).map(|j| (i + j as u32) as f32 * 0.1).collect()))
.collect();
let refined = refine(&candidates, &query, &docs, head_dims).unwrap();
prop_assert!(refined.len() <= candidates.len());
}
// With alpha = 1.0 the candidates come back in their original score order
// (ids 1, 2, 3), regardless of the doc embeddings.
#[test]
fn alpha_one_preserves_order(dim in 4usize..8) {
let head_dims = dim / 2;
let candidates = vec![(1u32, 0.9), (2u32, 0.8), (3u32, 0.7)];
let query: Vec<f32> = (0..dim).map(|_| 0.5).collect();
let docs: Vec<(u32, Vec<f32>)> = vec![
(1, (0..dim).map(|_| 0.1).collect()),
(2, (0..dim).map(|_| 0.2).collect()),
(3, (0..dim).map(|_| 0.3).collect()),
];
let refined = refine_with_alpha(&candidates, &query, &docs, head_dims, 1.0).unwrap();
prop_assert_eq!(refined.len(), 3);
prop_assert_eq!(refined[0].0, 1);
prop_assert_eq!(refined[1].0, 2);
prop_assert_eq!(refined[2].0, 3);
}
// No adjacent pair of results is in ascending order under `f32::total_cmp`.
#[test]
fn results_sorted_descending(n in 2usize..8, dim in 4usize..8) {
let head_dims = dim / 2;
let candidates: Vec<(u32, f32)> = (0..n as u32)
.map(|i| (i, (i as f32) * 0.1))
.collect();
let query: Vec<f32> = (0..dim).map(|i| i as f32 * 0.1).collect();
let docs: Vec<(u32, Vec<f32>)> = (0..n as u32)
.map(|i| (i, (0..dim).map(|j| (i + j as u32) as f32 * 0.05).collect()))
.collect();
let refined = refine(&candidates, &query, &docs, head_dims).unwrap();
for window in refined.windows(2) {
prop_assert!(
window[0].1.total_cmp(&window[1].1) != std::cmp::Ordering::Less,
"Results not sorted: {:?}",
refined
);
}
}
// For each id, the score at alpha = 0.5 lies between the scores at alpha = 0.0
// and alpha = 1.0, with a 0.01 tolerance for floating-point error.
#[test]
fn alpha_interpolation(dim in 8usize..16) {
let head_dims = dim / 2;
let candidates = vec![(1u32, 0.8), (2u32, 0.5)];
let query: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.1).sin()).collect();
let docs: Vec<(u32, Vec<f32>)> = vec![
(1, (0..dim).map(|i| (i as f32 * 0.2).cos()).collect()),
(2, (0..dim).map(|i| (i as f32 * 0.3).sin()).collect()),
];
let r_0 = refine_with_alpha(&candidates, &query, &docs, head_dims, 0.0).unwrap();
let r_half = refine_with_alpha(&candidates, &query, &docs, head_dims, 0.5).unwrap();
let r_1 = refine_with_alpha(&candidates, &query, &docs, head_dims, 1.0).unwrap();
for id in [1u32, 2] {
// Results are re-sorted per call, so look scores up by id rather than by position.
let s_0 = r_0.iter().find(|(i, _)| *i == id).map(|(_, s)| *s);
let s_half = r_half.iter().find(|(i, _)| *i == id).map(|(_, s)| *s);
let s_1 = r_1.iter().find(|(i, _)| *i == id).map(|(_, s)| *s);
if let (Some(s0), Some(sh), Some(s1)) = (s_0, s_half, s_1) {
let min_score = s0.min(s1);
let max_score = s0.max(s1);
prop_assert!(
sh >= min_score - 0.01 && sh <= max_score + 0.01,
"alpha=0.5 score {} not between {} and {}",
sh, min_score, max_score
);
}
}
}
// Runs tail-only refinement with a small head (dim / 4) and a large head
// (dim * 3 / 4) and checks that both produce a finite score.
// NOTE(review): despite the name, the two scores are not compared with each other.
#[test]
fn smaller_head_uses_more_dims(dim in 16usize..32) {
let candidates = vec![(1u32, 0.5)];
let query: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.1).sin()).collect();
let doc: Vec<f32> = (0..dim)
.map(|i| if i < dim / 4 { 0.1 } else { 0.9 })
.collect();
let docs = vec![(1u32, doc)];
let r_small_head = refine_tail_only(&candidates, &query, &docs, dim / 4).unwrap();
let r_big_head = refine_tail_only(&candidates, &query, &docs, dim * 3 / 4).unwrap();
prop_assert!(r_small_head[0].1.is_finite());
prop_assert!(r_big_head[0].1.is_finite());
}
// Doc 2's embedding has only head_dims - 1 components, so it has no tail and
// its candidate is dropped; only id 1 remains.
#[test]
fn short_docs_filtered(dim in 8usize..16) {
let head_dims = dim / 2;
let candidates = vec![(1u32, 0.9), (2u32, 0.8)];
let query: Vec<f32> = (0..dim).map(|i| i as f32 * 0.1).collect();
let docs = vec![
(1u32, (0..dim).map(|i| i as f32 * 0.1).collect()),
(2u32, (0..head_dims - 1).map(|i| i as f32 * 0.1).collect()),
];
let refined = refine(&candidates, &query, &docs, head_dims).unwrap();
prop_assert_eq!(refined.len(), 1);
prop_assert_eq!(refined[0].0, 1);
}
// Boundary check: head_dims equal to the query length is rejected, while
// head_dims one less than the query length is accepted.
#[test]
fn try_refine_validates_head_dims(dim in 2usize..8) {
let query: Vec<f32> = (0..dim).map(|i| i as f32 * 0.1).collect();
let candidates = vec![(1u32, 0.5)];
let docs = vec![(1u32, query.clone())];
let result = try_refine(&candidates, &query, &docs, dim, RerankConfig::default());
prop_assert!(result.is_err());
let result = try_refine(&candidates, &query, &docs, dim - 1, RerankConfig::default());
prop_assert!(result.is_ok());
}
// With `top_k = k` configured, the output never holds more than k entries
// (n is always larger than k here).
#[test]
fn config_top_k_limits_output(n in 5usize..10, k in 1usize..4) {
let dim = 8;
let head_dims = 4;
let candidates: Vec<(u32, f32)> = (0..n as u32)
.map(|i| (i, 1.0 - i as f32 * 0.05))
.collect();
let query: Vec<f32> = (0..dim).map(|i| i as f32 * 0.1).collect();
let docs: Vec<(u32, Vec<f32>)> = (0..n as u32)
.map(|i| (i, (0..dim).map(|j| (i + j as u32) as f32 * 0.05).collect()))
.collect();
let config = RerankConfig::default().with_top_k(k);
let refined = try_refine(&candidates, &query, &docs, head_dims, config).unwrap();
prop_assert!(refined.len() <= k, "Expected at most {} results, got {}", k, refined.len());
}
}
}