use super::simd;
/// Similarity metric for scoring one dense query vector against one
/// dense document vector (see [`DenseScorer::score`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DenseScorer {
    /// Raw dot product.
    Dot,
    /// Cosine similarity.
    Cosine,
}
impl DenseScorer {
    /// Scores `query` against `doc` using this metric.
    ///
    /// Delegates to the SIMD kernels in [`super::simd`]; see that module
    /// for any expectations on slice lengths.
    #[must_use]
    pub fn score(&self, query: &[f32], doc: &[f32]) -> f32 {
        if matches!(self, Self::Dot) {
            simd::dot(query, doc)
        } else {
            simd::cosine(query, doc)
        }
    }
}
/// Anything that can score a single query vector against a single
/// document vector.
pub trait Scorer {
    /// Returns the similarity between `query` and `doc`.
    fn score(&self, query: &[f32], doc: &[f32]) -> f32;

    /// Scores `query` against every document in `docs` and returns
    /// `(id, score)` pairs ordered best-first via `super::sort_scored_desc`.
    fn rank<I: Clone>(&self, query: &[f32], docs: &[(I, &[f32])]) -> Vec<(I, f32)> {
        let mut scored: Vec<(I, f32)> = Vec::with_capacity(docs.len());
        for (id, doc) in docs {
            scored.push((id.clone(), self.score(query, doc)));
        }
        super::sort_scored_desc(&mut scored);
        scored
    }
}
impl Scorer for DenseScorer {
    /// Forwards to the inherent [`DenseScorer::score`].
    ///
    /// The fully-qualified call makes the non-recursive delegation explicit
    /// (inherent methods also win plain method resolution over trait
    /// methods, but this form leaves no doubt).
    fn score(&self, query: &[f32], doc: &[f32]) -> f32 {
        DenseScorer::score(self, query, doc)
    }
}
/// Late-interaction (token-level, ColBERT-style) scoring strategy over
/// per-token embedding matrices.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LateInteractionScorer {
    /// MaxSim using dot-product token similarity.
    MaxSimDot,
    /// MaxSim using cosine token similarity.
    MaxSimCosine,
}
impl LateInteractionScorer {
    /// MaxSim score of `query_tokens` against `doc_tokens`.
    ///
    /// Delegates to the MaxSim kernels in [`super::simd`].
    #[must_use]
    pub fn score(&self, query_tokens: &[&[f32]], doc_tokens: &[&[f32]]) -> f32 {
        match *self {
            Self::MaxSimDot => simd::maxsim(query_tokens, doc_tokens),
            Self::MaxSimCosine => simd::maxsim_cosine(query_tokens, doc_tokens),
        }
    }

    /// MaxSim score with per-query-token `weights` applied by the
    /// underlying weighted kernels.
    #[must_use]
    pub fn score_weighted(
        &self,
        query_tokens: &[&[f32]],
        doc_tokens: &[&[f32]],
        weights: &[f32],
    ) -> f32 {
        if matches!(self, Self::MaxSimDot) {
            simd::maxsim_weighted(query_tokens, doc_tokens, weights)
        } else {
            simd::maxsim_cosine_weighted(query_tokens, doc_tokens, weights)
        }
    }
}
/// Scoring over token-level (multi-vector) representations.
pub trait TokenScorer {
    /// Scores one query token matrix against one document token matrix.
    fn score_tokens(&self, query: &[&[f32]], doc: &[&[f32]]) -> f32;

    /// Convenience wrapper over [`TokenScorer::score_tokens`] that accepts
    /// owned token vectors.
    fn score_vecs(&self, query: &[Vec<f32>], doc: &[Vec<f32>]) -> f32 {
        let query_slices = super::simd::as_slices(query);
        let doc_slices = super::simd::as_slices(doc);
        self.score_tokens(&query_slices, &doc_slices)
    }

    /// Ranks borrowed documents best-first by token score
    /// (sorted via `super::sort_scored_desc`).
    fn maxsim_tokens<I: Clone>(
        &self,
        query: &[&[f32]],
        docs: &[(I, Vec<&[f32]>)],
    ) -> Vec<(I, f32)> {
        let mut scored: Vec<(I, f32)> = Vec::with_capacity(docs.len());
        for (id, doc_tokens) in docs {
            scored.push((id.clone(), self.score_tokens(query, doc_tokens)));
        }
        super::sort_scored_desc(&mut scored);
        scored
    }

    /// Ranks owned documents best-first by token score.
    fn maxsim_vecs<I: Clone>(
        &self,
        query: &[Vec<f32>],
        docs: &[(I, Vec<Vec<f32>>)],
    ) -> Vec<(I, f32)> {
        let query_slices = super::simd::as_slices(query);
        let mut scored: Vec<(I, f32)> = Vec::with_capacity(docs.len());
        for (id, doc_tokens) in docs {
            let doc_slices = super::simd::as_slices(doc_tokens);
            scored.push((id.clone(), self.score_tokens(&query_slices, &doc_slices)));
        }
        super::sort_scored_desc(&mut scored);
        scored
    }
}
impl TokenScorer for LateInteractionScorer {
    /// Forwards to the inherent [`LateInteractionScorer::score`].
    ///
    /// `self.score(...)` resolves to the inherent method here (inherent
    /// methods take precedence over trait methods), so this call is not
    /// recursive.
    fn score_tokens(&self, query: &[&[f32]], doc: &[&[f32]]) -> f32 {
        self.score(query, doc)
    }
}
/// Linear interpolation of two scores:
/// `alpha * score_a + (1 - alpha) * score_b`.
///
/// `alpha = 1.0` returns `score_a`; `alpha = 0.0` returns `score_b`.
/// Uses a fused multiply-add for the `score_b` term.
#[inline]
#[must_use]
pub fn blend(score_a: f32, score_b: f32, alpha: f32) -> f32 {
    let weight_b = 1.0 - alpha;
    weight_b.mul_add(score_b, alpha * score_a)
}
/// Min-max normalizes `scores` into `[0, 1]`.
///
/// Returns an empty vector for empty input. If all scores are (nearly)
/// identical — range below `1e-9` — every score maps to `0.5`, since
/// there is no meaningful ordering to preserve.
#[must_use]
pub fn normalize_scores(scores: &[f32]) -> Vec<f32> {
    if scores.is_empty() {
        return Vec::new();
    }
    let mut lo = f32::INFINITY;
    let mut hi = f32::NEG_INFINITY;
    for &s in scores {
        lo = lo.min(s);
        hi = hi.max(s);
    }
    let range = hi - lo;
    if range < 1e-9 {
        return vec![0.5; scores.len()];
    }
    scores.iter().map(|&s| (s - lo) / range).collect()
}
/// Reduces a token matrix to (roughly) a target number of tokens.
pub trait Pooler {
    /// Pools `tokens` down toward `target_count` tokens.
    fn pool(&self, tokens: &[Vec<f32>], target_count: usize) -> Vec<Vec<f32>>;

    /// Pools by a reduction `factor`: the target is `len / factor`,
    /// clamped to at least one token. A factor of `0` or `1` (or empty
    /// input) is the identity.
    fn pool_by_factor(&self, tokens: &[Vec<f32>], factor: usize) -> Vec<Vec<f32>> {
        if factor <= 1 || tokens.is_empty() {
            tokens.to_vec()
        } else {
            let target = (tokens.len() / factor).max(1);
            self.pool(tokens, target)
        }
    }
}
/// Pooler that reduces tokens in sequential windows via
/// `super::colbert::pool_tokens_sequential`.
#[derive(Debug, Clone, Copy, Default)]
pub struct SequentialPooler;
impl Pooler for SequentialPooler {
    fn pool(&self, tokens: &[Vec<f32>], target_count: usize) -> Vec<Vec<f32>> {
        // Clamp to 1: `div_ceil` panics on a zero divisor, so a target of 0
        // would otherwise crash. Treat it as "pool everything into one",
        // matching the `.max(1)` convention in `pool_by_factor`.
        let target = target_count.max(1);
        if tokens.is_empty() || target >= tokens.len() {
            return tokens.to_vec();
        }
        // Window size that covers all tokens in `target` windows.
        let window = tokens.len().div_ceil(target);
        // On pooling failure, fall back to the unpooled tokens.
        super::colbert::pool_tokens_sequential(tokens, window).unwrap_or_else(|_| tokens.to_vec())
    }
}
/// Pooler that reduces tokens by clustering via
/// `super::colbert::pool_tokens`.
#[derive(Debug, Clone, Copy, Default)]
pub struct ClusteringPooler;
impl Pooler for ClusteringPooler {
    fn pool(&self, tokens: &[Vec<f32>], target_count: usize) -> Vec<Vec<f32>> {
        // Clamp to 1: `div_ceil` panics on a zero divisor, so a target of 0
        // would otherwise crash. Treat it as "pool everything into one",
        // matching the `.max(1)` convention in `pool_by_factor`.
        let target = target_count.max(1);
        if tokens.is_empty() || target >= tokens.len() {
            return tokens.to_vec();
        }
        // Reduction factor that yields roughly `target` tokens.
        let factor = tokens.len().div_ceil(target);
        // On pooling failure, fall back to the unpooled tokens.
        super::colbert::pool_tokens(tokens, factor).unwrap_or_else(|_| tokens.to_vec())
    }
}
/// Pooler that reduces tokens adaptively via
/// `super::colbert::pool_tokens_adaptive`.
#[derive(Debug, Clone, Copy, Default)]
pub struct AdaptivePooler;
impl Pooler for AdaptivePooler {
    fn pool(&self, tokens: &[Vec<f32>], target_count: usize) -> Vec<Vec<f32>> {
        // Clamp to 1: `div_ceil` panics on a zero divisor, so a target of 0
        // would otherwise crash. Treat it as "pool everything into one",
        // matching the `.max(1)` convention in `pool_by_factor`.
        let target = target_count.max(1);
        if tokens.is_empty() || target >= tokens.len() {
            return tokens.to_vec();
        }
        // Reduction factor that yields roughly `target` tokens.
        let factor = tokens.len().div_ceil(target);
        // On pooling failure, fall back to the unpooled tokens.
        super::colbert::pool_tokens_adaptive(tokens, factor).unwrap_or_else(|_| tokens.to_vec())
    }
}
/// A [`Pooler`] backed by an arbitrary pooling closure.
pub struct FnPooler<F> {
    // The closure invoked by `Pooler::pool`.
    pool_fn: F,
}
impl<F> FnPooler<F>
where
    F: Fn(&[Vec<f32>], usize) -> Vec<Vec<f32>>,
{
    /// Wraps `pool_fn` so it can be used wherever a [`Pooler`] is expected.
    pub const fn new(pool_fn: F) -> Self {
        Self { pool_fn }
    }
}
impl<F> Pooler for FnPooler<F>
where
    F: Fn(&[Vec<f32>], usize) -> Vec<Vec<f32>>,
{
    fn pool(&self, tokens: &[Vec<f32>], target_count: usize) -> Vec<Vec<f32>> {
        let pool_fn = &self.pool_fn;
        pool_fn(tokens, target_count)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dense_dot() {
        let scorer = DenseScorer::Dot;
        assert!((scorer.score(&[1.0, 0.0], &[1.0, 0.0]) - 1.0).abs() < 1e-5);
        assert!((scorer.score(&[1.0, 0.0], &[0.0, 1.0])).abs() < 1e-5);
    }

    #[test]
    fn test_dense_cosine() {
        let scorer = DenseScorer::Cosine;
        // Cosine is scale-invariant: [2, 0] vs [1, 0] is still 1.0.
        assert!((scorer.score(&[2.0, 0.0], &[1.0, 0.0]) - 1.0).abs() < 1e-5);
        assert!((scorer.score(&[1.0, 0.0], &[0.0, 1.0])).abs() < 1e-5);
    }

    #[test]
    fn test_dense_rank() {
        let scorer = DenseScorer::Cosine;
        let query = &[1.0f32, 0.0][..];
        let docs: Vec<(&str, &[f32])> = vec![("d1", &[0.0, 1.0][..]), ("d2", &[1.0, 0.0][..])];
        let ranked = scorer.rank(query, &docs);
        assert_eq!(ranked[0].0, "d2");
    }

    #[test]
    fn test_late_interaction_maxsim() {
        let scorer = LateInteractionScorer::MaxSimDot;
        let q1: &[f32] = &[1.0, 0.0];
        let d1: &[f32] = &[1.0, 0.0];
        let d2: &[f32] = &[0.0, 1.0];
        let query = vec![q1];
        let doc = vec![d1, d2];
        assert!((scorer.score_tokens(&query, &doc) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_blend() {
        // One assertion per line (previously crammed onto a single line).
        assert!((blend(1.0, 0.0, 1.0) - 1.0).abs() < 1e-5);
        assert!((blend(1.0, 0.0, 0.0) - 0.0).abs() < 1e-5);
        assert!((blend(1.0, 0.0, 0.5) - 0.5).abs() < 1e-5);
    }

    #[test]
    fn test_normalize_scores() {
        let scores = vec![0.0, 0.5, 1.0];
        let normalized = normalize_scores(&scores);
        assert!((normalized[0] - 0.0).abs() < 1e-5);
        assert!((normalized[1] - 0.5).abs() < 1e-5);
        assert!((normalized[2] - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_normalize_scores_equal() {
        // Constant input has no range; everything maps to 0.5.
        let scores = vec![0.5, 0.5, 0.5];
        let normalized = normalize_scores(&scores);
        assert!(normalized.iter().all(|&s| (s - 0.5).abs() < 1e-5));
    }

    #[test]
    fn test_normalize_scores_empty() {
        let scores: Vec<f32> = vec![];
        let normalized = normalize_scores(&scores);
        assert!(normalized.is_empty());
    }

    #[test]
    fn test_token_scorer_score_vecs() {
        let scorer = LateInteractionScorer::MaxSimDot;
        let query = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let doc = vec![vec![0.9, 0.1], vec![0.1, 0.9]];
        let score = scorer.score_vecs(&query, &doc);
        // Each query token matches its aligned doc token at 0.9: total ~1.8.
        assert!(score > 1.5);
    }

    #[test]
    fn test_token_scorer_maxsim_vecs() {
        let scorer = LateInteractionScorer::MaxSimDot;
        let query = vec![vec![1.0, 0.0]];
        let docs = vec![
            ("d1", vec![vec![0.0, 1.0]]),
            ("d2", vec![vec![1.0, 0.0]]),
        ];
        let ranked = scorer.maxsim_vecs(&query, &docs);
        assert_eq!(ranked[0].0, "d2");
    }

    #[test]
    fn test_fn_pooler_custom() {
        let first_only = FnPooler::new(|tokens: &[Vec<f32>], _target| {
            if tokens.is_empty() {
                vec![]
            } else {
                vec![tokens[0].clone()]
            }
        });
        let tokens = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]];
        let pooled = first_only.pool(&tokens, 1);
        assert_eq!(pooled.len(), 1);
        assert_eq!(pooled[0], vec![1.0, 0.0]);
    }

    #[test]
    fn test_fn_pooler_mean() {
        let mean_pool = FnPooler::new(|tokens: &[Vec<f32>], _target| {
            if tokens.is_empty() {
                return vec![];
            }
            let dim = tokens[0].len();
            let mut mean = vec![0.0; dim];
            for tok in tokens {
                for (i, &v) in tok.iter().enumerate() {
                    mean[i] += v;
                }
            }
            let n = tokens.len() as f32;
            for v in &mut mean {
                *v /= n;
            }
            vec![mean]
        });
        let tokens = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let pooled = mean_pool.pool(&tokens, 1);
        assert_eq!(pooled.len(), 1);
        assert!((pooled[0][0] - 0.5).abs() < 1e-5);
        assert!((pooled[0][1] - 0.5).abs() < 1e-5);
    }
}
// Property-based tests (proptest): invariants of scoring, blending,
// normalization, and pooling over randomized inputs.
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    // Strategy producing a `len`-element f32 vector with entries in [-10, 10).
    fn arb_vec(len: usize) -> impl Strategy<Value = Vec<f32>> {
        proptest::collection::vec(-10.0f32..10.0, len)
    }

    proptest! {
        // Both dense metrics are symmetric in their arguments.
        #[test]
        fn scorer_cosine_commutative(a in arb_vec(32), b in arb_vec(32)) {
            let scorer = DenseScorer::Cosine;
            let ab = scorer.score(&a, &b);
            let ba = scorer.score(&b, &a);
            prop_assert!((ab - ba).abs() < 1e-5);
        }
        #[test]
        fn scorer_dot_commutative(a in arb_vec(32), b in arb_vec(32)) {
            let scorer = DenseScorer::Dot;
            let ab = scorer.score(&a, &b);
            let ba = scorer.score(&b, &a);
            prop_assert!((ab - ba).abs() < 1e-5);
        }
        // NOTE(review): exercises `Scorer::rank`, not MaxSim — the
        // `scorer_maxsim_` prefix in the name appears historical.
        #[test]
        fn scorer_maxsim_preserves_count(n in 1usize..10, dim in 2usize..8) {
            let scorer = DenseScorer::Cosine;
            let query: Vec<f32> = (0..dim).map(|i| i as f32 * 0.1).collect();
            let docs: Vec<(u32, Vec<f32>)> = (0..n as u32)
                .map(|i| (i, (0..dim).map(|j| (i as usize + j) as f32 * 0.1).collect()))
                .collect();
            let doc_refs: Vec<(u32, &[f32])> = docs.iter()
                .map(|(id, v)| (*id, v.as_slice()))
                .collect();
            let ranked = scorer.rank(&query, &doc_refs);
            prop_assert_eq!(ranked.len(), n);
        }
        #[test]
        fn blend_alpha_one(a in -100.0f32..100.0, b in -100.0f32..100.0) {
            let blended = blend(a, b, 1.0);
            prop_assert!((blended - a).abs() < 1e-5);
        }
        // pool_by_factor must reduce (divide) the token count, never grow it.
        #[test]
        fn pool_by_factor_uses_division(n_tokens in 10usize..50, factor in 2usize..10) {
            let tokens: Vec<Vec<f32>> = (0..n_tokens)
                .map(|i| vec![i as f32; 4])
                .collect();
            let pooler = ClusteringPooler;
            let pooled = pooler.pool_by_factor(&tokens, factor);
            let expected_count = (n_tokens / factor).max(1);
            prop_assert!(
                pooled.len() <= expected_count + 1,
                "pool_by_factor should divide: {} tokens / {} factor = {} expected, got {}",
                n_tokens, factor, expected_count, pooled.len()
            );
            prop_assert!(
                pooled.len() < n_tokens * factor,
                "Should not multiply: {} tokens * {} factor would be {}, got {}",
                n_tokens, factor, n_tokens * factor, pooled.len()
            );
        }
        #[test]
        fn blend_alpha_zero(a in -100.0f32..100.0, b in -100.0f32..100.0) {
            let blended = blend(a, b, 0.0);
            prop_assert!((blended - b).abs() < 1e-5);
        }
        // Min-max normalization always lands in (approximately) [0, 1].
        #[test]
        fn normalize_bounded(scores in proptest::collection::vec(-100.0f32..100.0, 2..20)) {
            let normalized = normalize_scores(&scores);
            for &s in &normalized {
                prop_assert!((-0.01..=1.01).contains(&s), "Score {} out of bounds", s);
            }
        }
        // Normalization is monotone: relative order of distinct scores survives.
        #[test]
        fn normalize_preserves_order(scores in proptest::collection::vec(-100.0f32..100.0, 2..10)) {
            let normalized = normalize_scores(&scores);
            let range = scores.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b))
                - scores.iter().fold(f32::INFINITY, |a, &b| a.min(b));
            let eps = range * 1e-5;
            for i in 0..scores.len() {
                for j in 0..scores.len() {
                    // Skip near-ties: float rounding could flip their order.
                    if (scores[i] - scores[j]).abs() < eps.max(1e-5) {
                        continue;
                    }
                    let orig_cmp = scores[i].total_cmp(&scores[j]);
                    let norm_cmp = normalized[i].total_cmp(&normalized[j]);
                    prop_assert_eq!(orig_cmp, norm_cmp, "Order changed at indices ({}, {})", i, j);
                }
            }
        }
        #[test]
        fn blend_is_linear(a in -10.0f32..10.0, b in -10.0f32..10.0, alpha in 0.0f32..1.0) {
            let blended = blend(a, b, alpha);
            let expected = alpha * a + (1.0 - alpha) * b;
            prop_assert!((blended - expected).abs() < 1e-5, "blend({}, {}, {}) = {}, expected {}", a, b, alpha, blended, expected);
        }
        // NOTE(review): also exercises `Scorer::rank`, not MaxSim.
        #[test]
        fn scorer_maxsim_is_sorted(n in 2usize..10, dim in 2usize..8) {
            let scorer = DenseScorer::Cosine;
            let query: Vec<f32> = (0..dim).map(|i| (i + 1) as f32).collect();
            let docs: Vec<(u32, Vec<f32>)> = (0..n as u32)
                .map(|i| (i, (0..dim).map(|j| ((i as usize * dim + j) % 10) as f32).collect()))
                .collect();
            let doc_refs: Vec<(u32, &[f32])> = docs.iter()
                .map(|(id, v)| (*id, v.as_slice()))
                .collect();
            let ranked = scorer.rank(&query, &doc_refs);
            for w in ranked.windows(2) {
                prop_assert!(w[0].1 >= w[1].1, "Not sorted: {} < {}", w[0].1, w[1].1);
            }
        }
        // All token entries are positive here, so MaxSim-dot cannot go negative.
        #[test]
        fn late_interaction_nonnegative(
            q_tokens in 1usize..4,
            d_tokens in 1usize..4,
            dim in 2usize..8
        ) {
            let query: Vec<Vec<f32>> = (0..q_tokens)
                .map(|i| (0..dim).map(|j| ((i * dim + j) % 5) as f32 * 0.1 + 0.1).collect())
                .collect();
            let doc: Vec<Vec<f32>> = (0..d_tokens)
                .map(|i| (0..dim).map(|j| ((i * dim + j + 3) % 5) as f32 * 0.1 + 0.1).collect())
                .collect();
            let query_refs: Vec<&[f32]> = query.iter().map(Vec::as_slice).collect();
            let doc_refs: Vec<&[f32]> = doc.iter().map(Vec::as_slice).collect();
            let scorer = LateInteractionScorer::MaxSimDot;
            let score = scorer.score(&query_refs, &doc_refs);
            prop_assert!(score >= 0.0, "`MaxSim` score {} should be non-negative", score);
        }
        #[test]
        fn late_interaction_empty_doc(dim in 2usize..8) {
            let query: Vec<Vec<f32>> = vec![vec![1.0; dim], vec![0.5; dim]];
            let query_refs: Vec<&[f32]> = query.iter().map(Vec::as_slice).collect();
            let doc_refs: Vec<&[f32]> = vec![];
            let scorer = LateInteractionScorer::MaxSimDot;
            let score = scorer.score(&query_refs, &doc_refs);
            prop_assert!((score - 0.0).abs() < 1e-9, "Empty doc should return 0, got {}", score);
        }
        #[test]
        fn scorer_cosine_bounded_normalized(dim in 2usize..16) {
            let a: Vec<f32> = (0..dim).map(|i| if i == 0 { 1.0 } else { 0.0 }).collect();
            let b: Vec<f32> = (0..dim).map(|i| if i == 1 { 1.0 } else { 0.0 }).collect();
            let scorer = DenseScorer::Cosine;
            let score = scorer.score(&a, &b);
            prop_assert!((-1.01..=1.01).contains(&score), "Cosine {} out of bounds", score);
        }
        #[test]
        fn pooler_never_increases_count(n_tokens in 2usize..16, dim in 2usize..8, target in 1usize..8) {
            let tokens: Vec<Vec<f32>> = (0..n_tokens)
                .map(|i| (0..dim).map(|j| ((i * dim + j) as f32 * 0.1).sin()).collect())
                .collect();
            let seq = SequentialPooler.pool(&tokens, target);
            let cluster = ClusteringPooler.pool(&tokens, target);
            let adaptive = AdaptivePooler.pool(&tokens, target);
            prop_assert!(seq.len() <= n_tokens, "Sequential increased count: {} -> {}", n_tokens, seq.len());
            prop_assert!(cluster.len() <= n_tokens, "Clustering increased count: {} -> {}", n_tokens, cluster.len());
            prop_assert!(adaptive.len() <= n_tokens, "Adaptive increased count: {} -> {}", n_tokens, adaptive.len());
        }
        #[test]
        fn pooler_preserves_dimension(n_tokens in 2usize..16, dim in 2usize..16, factor in 2usize..4) {
            let tokens: Vec<Vec<f32>> = (0..n_tokens)
                .map(|i| (0..dim).map(|j| ((i * dim + j) as f32 * 0.1).sin()).collect())
                .collect();
            let seq = SequentialPooler.pool_by_factor(&tokens, factor);
            let cluster = ClusteringPooler.pool_by_factor(&tokens, factor);
            let adaptive = AdaptivePooler.pool_by_factor(&tokens, factor);
            prop_assert!(seq.iter().all(|t| t.len() == dim), "Sequential changed dim");
            prop_assert!(cluster.iter().all(|t| t.len() == dim), "Clustering changed dim");
            prop_assert!(adaptive.iter().all(|t| t.len() == dim), "Adaptive changed dim");
        }
        #[test]
        fn pooler_empty_input(target in 1usize..10) {
            let empty: Vec<Vec<f32>> = vec![];
            prop_assert!(SequentialPooler.pool(&empty, target).is_empty());
            prop_assert!(ClusteringPooler.pool(&empty, target).is_empty());
            prop_assert!(AdaptivePooler.pool(&empty, target).is_empty());
        }
        #[test]
        fn pooler_factor_one_identity(n_tokens in 1usize..8, dim in 2usize..8) {
            let tokens: Vec<Vec<f32>> = (0..n_tokens)
                .map(|i| (0..dim).map(|j| (i + j) as f32 * 0.1).collect())
                .collect();
            let seq = SequentialPooler.pool_by_factor(&tokens, 1);
            let cluster = ClusteringPooler.pool_by_factor(&tokens, 1);
            let adaptive = AdaptivePooler.pool_by_factor(&tokens, 1);
            prop_assert_eq!(seq.len(), n_tokens);
            prop_assert_eq!(cluster.len(), n_tokens);
            prop_assert_eq!(adaptive.len(), n_tokens);
        }
        #[test]
        fn token_scorer_maxsim_is_sorted(n_docs in 2usize..6, n_q in 1usize..3, dim in 2usize..8) {
            let query: Vec<Vec<f32>> = (0..n_q)
                .map(|i| (0..dim).map(|j| ((i * dim + j) as f32 * 0.1).sin()).collect())
                .collect();
            let docs: Vec<(u32, Vec<Vec<f32>>)> = (0..n_docs as u32)
                .map(|i| {
                    let toks: Vec<Vec<f32>> = (0..3)
                        .map(|t| (0..dim).map(|j| ((i as usize * 3 + t + j) as f32 * 0.1).cos()).collect())
                        .collect();
                    (i, toks)
                })
                .collect();
            let query_refs: Vec<&[f32]> = query.iter().map(Vec::as_slice).collect();
            let doc_refs: Vec<(u32, Vec<&[f32]>)> = docs.iter()
                .map(|(id, toks)| (*id, toks.iter().map(Vec::as_slice).collect()))
                .collect();
            let scorer = LateInteractionScorer::MaxSimDot;
            let ranked = scorer.maxsim_tokens(&query_refs, &doc_refs);
            for window in ranked.windows(2) {
                prop_assert!(
                    window[0].1 >= window[1].1 - 1e-6,
                    "Not sorted: {} >= {}",
                    window[0].1,
                    window[1].1
                );
            }
        }
        #[test]
        fn token_scorer_maxsim_preserves_count(n_docs in 1usize..6, n_q in 1usize..3, dim in 2usize..8) {
            let query: Vec<Vec<f32>> = (0..n_q)
                .map(|i| (0..dim).map(|j| ((i * dim + j) as f32 * 0.1).sin()).collect())
                .collect();
            let docs: Vec<(u32, Vec<Vec<f32>>)> = (0..n_docs as u32)
                .map(|i| {
                    let toks: Vec<Vec<f32>> = (0..2)
                        .map(|t| (0..dim).map(|j| ((i as usize * 2 + t + j) as f32 * 0.1).cos()).collect())
                        .collect();
                    (i, toks)
                })
                .collect();
            let query_refs: Vec<&[f32]> = query.iter().map(Vec::as_slice).collect();
            let doc_refs: Vec<(u32, Vec<&[f32]>)> = docs.iter()
                .map(|(id, toks)| (*id, toks.iter().map(Vec::as_slice).collect()))
                .collect();
            let scorer = LateInteractionScorer::MaxSimDot;
            let ranked = scorer.maxsim_tokens(&query_refs, &doc_refs);
            prop_assert_eq!(ranked.len(), n_docs);
        }
    }
}