use super::{simd, RerankConfig};
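/// Pools `tokens` down to roughly `tokens.len() / pool_factor` vectors by
/// clustering similar tokens and mean-pooling each cluster.
///
/// Uses greedy agglomerative merging by default, or Ward-linkage clustering
/// (via `kodama`) when the `hierarchical` feature is enabled. Empty input and
/// `pool_factor == 1` return the tokens unchanged; `pool_factor == 0` is
/// rejected with `RerankError::InvalidPoolFactor`.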
pub fn pool_tokens(tokens: &[Vec<f32>], pool_factor: usize) -> super::Result<Vec<Vec<f32>>> {
if tokens.is_empty() {
return Ok(tokens.to_vec());
}
if pool_factor == 0 {
return Err(super::RerankError::InvalidPoolFactor { pool_factor: 0 });
}
if pool_factor == 1 {
return Ok(tokens.to_vec());
}
let n = tokens.len();
let target_count = (n / pool_factor).max(1);
if n <= target_count {
return Ok(tokens.to_vec());
}
#[cfg(feature = "hierarchical")]
{
Ok(pool_tokens_hierarchical(tokens, target_count))
}
#[cfg(not(feature = "hierarchical"))]
{
Ok(pool_tokens_greedy(tokens, target_count))
}
}
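/// Greedy agglomerative pooling: repeatedly merges the pair of clusters whose
/// mean-pooled centroids are most cosine-similar until `target_count` clusters
/// remain, then mean-pools each cluster. O(n^3) in the number of tokens.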
#[cfg(not(feature = "hierarchical"))]
fn pool_tokens_greedy(tokens: &[Vec<f32>], target_count: usize) -> Vec<Vec<f32>> {
let n = tokens.len();
let mut clusters: Vec<Vec<usize>> = (0..n).map(|i| vec![i]).collect();
while clusters.len() > target_count {
let mut best_i = 0;
let mut best_j = 1;
let mut best_sim = f32::NEG_INFINITY;
for i in 0..clusters.len() {
for j in (i + 1)..clusters.len() {
let sim = cluster_similarity(tokens, &clusters[i], &clusters[j]);
if sim > best_sim {
best_sim = sim;
best_i = i;
best_j = j;
}
}
}
let merged = clusters.remove(best_j);
clusters[best_i].extend(merged);
}
clusters
.iter()
.map(|indices| mean_pool(tokens, indices))
.collect()
}
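/// Hierarchical pooling via Ward linkage (`kodama`): builds the condensed
/// cosine-distance matrix (NaN similarities, e.g. from zero vectors, are
/// treated as maximally distant), cuts the dendrogram at `target_count`
/// clusters, and mean-pools each non-empty cluster.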
#[cfg(feature = "hierarchical")]
fn pool_tokens_hierarchical(tokens: &[Vec<f32>], target_count: usize) -> Vec<Vec<f32>> {
use kodama::{linkage, Method};
let n = tokens.len();
let mut condensed = Vec::with_capacity(n * (n - 1) / 2);
for i in 0..n {
for j in (i + 1)..n {
let sim = simd::cosine(&tokens[i], &tokens[j]);
let sim_safe = if sim.is_nan() { -1.0 } else { sim };
#[allow(clippy::cast_lossless)]
let dist = f64::from((1.0 - sim_safe).clamp(0.0, 2.0));
condensed.push(dist);
}
}
let dendrogram = linkage(&mut condensed, n, Method::Ward);
let labels = cut_dendrogram(&dendrogram, n, target_count);
let num_clusters = labels.iter().max().map_or(0, |&m| m + 1);
let mut clusters: Vec<Vec<usize>> = vec![vec![]; num_clusters];
for (i, &label) in labels.iter().enumerate() {
clusters[label].push(i);
}
clusters
.iter()
.filter(|c| !c.is_empty())
.map(|indices| mean_pool(tokens, indices))
.collect()
}
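/// Cuts a dendrogram so that exactly `target_count` clusters remain: replays
/// the first `n - target_count` merge steps with a union-find, then assigns
/// each of the `n` observations a compact label (0, 1, ...) keyed by its root.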
#[cfg(feature = "hierarchical")]
fn cut_dendrogram(
dendrogram: &kodama::Dendrogram<f64>,
n: usize,
target_count: usize,
) -> Vec<usize> {
let steps_to_take = n.saturating_sub(target_count);
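// Path-halving union-find over the 2n possible cluster ids (n leaves plus up
// to n - 1 merged clusters, which kodama numbers n, n + 1, ...).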
#[allow(clippy::items_after_statements)]
fn find(parent: &mut [usize], mut x: usize) -> usize {
while parent[x] != x {
parent[x] = parent[parent[x]];
x = parent[x];
}
x
}
let mut parent: Vec<usize> = (0..2 * n).collect();
for (step_idx, step) in dendrogram.steps().iter().enumerate() {
if step_idx >= steps_to_take {
break;
}
let new_cluster = n + step_idx;
parent[step.cluster1] = new_cluster;
parent[step.cluster2] = new_cluster;
}
let mut label_map = std::collections::HashMap::new();
let mut next_label = 0;
let mut labels = vec![0; n];
for (i, label_slot) in labels.iter_mut().enumerate() {
let root = find(&mut parent, i);
let label = *label_map.entry(root).or_insert_with(|| {
let l = next_label;
next_label += 1;
l
});
*label_slot = label;
}
labels
}
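/// Order-preserving pooling: mean-pools consecutive windows of `window_size`
/// tokens (the final window may be shorter). Empty input and
/// `window_size == 1` return the tokens unchanged; `window_size == 0` is
/// rejected with `RerankError::InvalidWindowSize`.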
pub fn pool_tokens_sequential(
tokens: &[Vec<f32>],
window_size: usize,
) -> super::Result<Vec<Vec<f32>>> {
if tokens.is_empty() {
return Ok(tokens.to_vec());
}
if window_size == 0 {
return Err(super::RerankError::InvalidWindowSize { window_size: 0 });
}
if window_size == 1 {
return Ok(tokens.to_vec());
}
Ok(tokens
.chunks(window_size)
.map(|chunk| {
let dim = chunk[0].len();
let mut pooled = vec![0.0; dim];
for token in chunk {
for (k, v) in pooled.iter_mut().enumerate() {
*v += token[k];
}
}
#[allow(clippy::cast_precision_loss)]
let n = chunk.len() as f32;
for v in &mut pooled {
*v /= n;
}
pooled
})
.collect())
}
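/// Like `pool_tokens`, but keeps the first `protected_count` tokens verbatim
/// (e.g. special tokens such as `[CLS]`) and only pools the remainder.
/// `protected_count` is clamped to the token count.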
pub fn pool_tokens_with_protected(
tokens: &[Vec<f32>],
pool_factor: usize,
protected_count: usize,
) -> super::Result<Vec<Vec<f32>>> {
if tokens.is_empty() {
return Ok(tokens.to_vec());
}
if pool_factor == 0 {
return Err(super::RerankError::InvalidPoolFactor { pool_factor: 0 });
}
if pool_factor == 1 {
return Ok(tokens.to_vec());
}
let protected_count = protected_count.min(tokens.len());
let protected = &tokens[..protected_count];
let poolable = &tokens[protected_count..];
let mut result = protected.to_vec();
result.extend(pool_tokens(poolable, pool_factor)?);
Ok(result)
}
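/// Picks a pooling strategy from the factor: aggressive factors (`>= 4`) use
/// cheap order-preserving sequential windowing, smaller factors use
/// similarity clustering.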
pub fn pool_tokens_adaptive(
tokens: &[Vec<f32>],
pool_factor: usize,
) -> super::Result<Vec<Vec<f32>>> {
if tokens.is_empty() {
return Ok(tokens.to_vec());
}
if pool_factor == 0 {
return Err(super::RerankError::InvalidPoolFactor { pool_factor: 0 });
}
if pool_factor == 1 {
return Ok(tokens.to_vec());
}
if pool_factor >= 4 {
pool_tokens_sequential(tokens, pool_factor)
} else {
pool_tokens(tokens, pool_factor)
}
}
#[cfg(not(feature = "hierarchical"))]
fn cluster_similarity(tokens: &[Vec<f32>], c1: &[usize], c2: &[usize]) -> f32 {
let centroid1 = mean_pool(tokens, c1);
let centroid2 = mean_pool(tokens, c2);
simd::cosine(&centroid1, &centroid2)
}
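/// Element-wise mean of the token vectors selected by `indices`; returns an
/// empty vector when `indices` is empty.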
fn mean_pool(tokens: &[Vec<f32>], indices: &[usize]) -> Vec<f32> {
if indices.is_empty() {
return vec![];
}
let dim = tokens[indices[0]].len();
let mut pooled = vec![0.0; dim];
for &idx in indices {
for (k, v) in pooled.iter_mut().enumerate() {
*v += tokens[idx][k];
}
}
#[allow(clippy::cast_precision_loss)]
let n = indices.len() as f32;
for v in &mut pooled {
*v /= n;
}
pooled
}
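/// Ranks every document in `docs` against `query` by `MaxSim`, best first.
/// Shorthand for `maxsim_with_top_k(query, docs, None)`.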
pub fn rank<I: Clone>(query: &[Vec<f32>], docs: &[(I, Vec<Vec<f32>>)]) -> Vec<(I, f32)> {
maxsim_with_top_k(query, docs, None)
}
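/// Scores each `(id, tokens)` document with `MaxSim` against `query`, sorts
/// the results descending, and truncates to `top_k` results when given.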
#[must_use]
pub fn maxsim_with_top_k<I: Clone>(
query: &[Vec<f32>],
docs: &[(I, Vec<Vec<f32>>)],
top_k: Option<usize>,
) -> Vec<(I, f32)> {
let query_refs = super::simd::as_slices(query);
let mut results: Vec<(I, f32)> = docs
.iter()
.map(|(id, doc_tokens)| {
let doc_refs = super::simd::as_slices(doc_tokens);
let score = simd::maxsim(&query_refs, &doc_refs);
(id.clone(), score)
})
.collect();
super::sort_scored_desc(&mut results);
if let Some(k) = top_k {
results.truncate(k);
}
results
}
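/// Reranks `candidates` by blending each original score with a fresh `MaxSim`
/// score as `alpha * original + (1 - alpha) * maxsim`: `alpha = 0` ranks by
/// `MaxSim` alone, `alpha = 1` keeps the original scores. Candidates without
/// a token matrix in `docs` are dropped.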
#[must_use]
pub fn refine<I: Clone + Eq + std::hash::Hash>(
candidates: &[(I, f32)],
query: &[Vec<f32>],
docs: &[(I, Vec<Vec<f32>>)],
alpha: f32,
) -> Vec<(I, f32)> {
refine_with_config(
candidates,
query,
docs,
RerankConfig::default().with_alpha(alpha),
)
}
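/// Like `refine`, but driven by a full `RerankConfig` (blend weight `alpha`
/// plus optional `top_k` truncation).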
#[must_use]
pub fn refine_with_config<I: Clone + Eq + std::hash::Hash>(
candidates: &[(I, f32)],
query: &[Vec<f32>],
docs: &[(I, Vec<Vec<f32>>)],
config: RerankConfig,
) -> Vec<(I, f32)> {
use std::collections::HashMap;
let doc_map: HashMap<&I, &Vec<Vec<f32>>> = docs.iter().map(|(id, toks)| (id, toks)).collect();
let query_refs = super::simd::as_slices(query);
let alpha = config.alpha;
let mut results: Vec<(I, f32)> = candidates
.iter()
.filter_map(|(id, orig_score)| {
let doc_tokens = doc_map.get(id)?;
let doc_refs = super::simd::as_slices(doc_tokens);
let maxsim_score = simd::maxsim(&query_refs, &doc_refs);
let blended = (1.0 - alpha).mul_add(maxsim_score, alpha * orig_score);
Some((id.clone(), blended))
})
.collect();
super::sort_scored_desc(&mut results);
if let Some(k) = config.top_k {
results.truncate(k);
}
results
}
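/// `MaxSim` alignment triples `(query_idx, doc_idx, score)` between `query`
/// and `doc`; a thin wrapper over `simd::maxsim_alignments_vecs`.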
#[must_use]
pub fn alignments(query: &[Vec<f32>], doc: &[Vec<f32>]) -> Vec<(usize, usize, f32)> {
simd::maxsim_alignments_vecs(query, doc)
}
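/// Token indices that `simd::highlight_matches_vecs` reports as matching
/// `query` with similarity above `threshold`.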
#[must_use]
pub fn highlight(query: &[Vec<f32>], doc: &[Vec<f32>], threshold: f32) -> Vec<usize> {
simd::highlight_matches_vecs(query, doc, threshold)
}
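/// An in-memory set of `(id, token embeddings)` entries with `MaxSim` scoring
/// helpers. Lookups by id (`get`, `contains`) are linear scans, so this suits
/// small-to-medium candidate sets rather than large corpora.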
#[derive(Debug, Clone)]
pub struct TokenIndex<I> {
entries: Vec<(I, Vec<Vec<f32>>)>,
}
impl<I> TokenIndex<I> {
#[must_use]
pub fn new(entries: Vec<(I, Vec<Vec<f32>>)>) -> Self {
Self { entries }
}
#[inline]
#[must_use]
pub fn len(&self) -> usize {
self.entries.len()
}
#[inline]
#[must_use]
pub fn is_empty(&self) -> bool {
self.entries.is_empty()
}
pub fn iter(&self) -> impl Iterator<Item = &(I, Vec<Vec<f32>>)> {
self.entries.iter()
}
#[must_use]
pub fn entries(&self) -> &[(I, Vec<Vec<f32>>)] {
&self.entries
}
#[must_use]
pub fn into_entries(self) -> Vec<(I, Vec<Vec<f32>>)> {
self.entries
}
}
impl<I: Clone> TokenIndex<I> {
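/// Scores every entry against `query` with `MaxSim`, preserving insertion order.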
#[must_use]
pub fn score_all(&self, query: &[Vec<f32>]) -> Vec<(I, f32)> {
let query_refs = super::simd::as_slices(query);
self.entries
.iter()
.map(|(id, doc_tokens)| {
let doc_refs = super::simd::as_slices(doc_tokens);
(id.clone(), simd::maxsim(&query_refs, &doc_refs))
})
.collect()
}
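/// Like `score_all`, but uses the cosine `MaxSim` variant, so scores are
/// invariant to token magnitudes.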
#[must_use]
pub fn score_all_cosine(&self, query: &[Vec<f32>]) -> Vec<(I, f32)> {
let query_refs = super::simd::as_slices(query);
self.entries
.iter()
.map(|(id, doc_tokens)| {
let doc_refs = super::simd::as_slices(doc_tokens);
(id.clone(), simd::maxsim_cosine(&query_refs, &doc_refs))
})
.collect()
}
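/// Scores all entries and sorts them best-first.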
#[must_use]
pub fn rank(&self, query: &[Vec<f32>]) -> Vec<(I, f32)> {
let mut results = self.score_all(query);
super::sort_scored_desc(&mut results);
results
}
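/// Ranks all entries and keeps at most `k` results.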
#[must_use]
pub fn top_k(&self, query: &[Vec<f32>], k: usize) -> Vec<(I, f32)> {
let mut results = self.score_all(query);
super::sort_scored_desc(&mut results);
results.truncate(k);
results
}
#[must_use]
pub fn top_k_cosine(&self, query: &[Vec<f32>], k: usize) -> Vec<(I, f32)> {
let mut results = self.score_all_cosine(query);
super::sort_scored_desc(&mut results);
results.truncate(k);
results
}
}
impl<I: Clone + Eq + std::hash::Hash> TokenIndex<I> {
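/// Linear scan for the token matrix stored under `id`.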
#[must_use]
pub fn get(&self, id: &I) -> Option<&Vec<Vec<f32>>> {
self.entries
.iter()
.find(|(entry_id, _)| entry_id == id)
.map(|(_, tokens)| tokens)
}
#[must_use]
pub fn contains(&self, id: &I) -> bool {
self.get(id).is_some()
}
}
impl<I> Default for TokenIndex<I> {
fn default() -> Self {
Self {
entries: Vec::new(),
}
}
}
impl<I> FromIterator<(I, Vec<Vec<f32>>)> for TokenIndex<I> {
fn from_iter<T: IntoIterator<Item = (I, Vec<Vec<f32>>)>>(iter: T) -> Self {
Self {
entries: iter.into_iter().collect(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_rank() {
let query = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
let docs = vec![
("d1", vec![vec![1.0, 0.0], vec![0.0, 1.0]]),
("d2", vec![vec![0.5, 0.5]]),
];
let ranked = rank(&query, &docs);
assert_eq!(ranked[0].0, "d1");
}
#[test]
fn test_maxsim_with_top_k() {
let query = vec![vec![1.0, 0.0]];
let docs = vec![
("d1", vec![vec![1.0, 0.0]]),
("d2", vec![vec![0.9, 0.1]]),
("d3", vec![vec![0.8, 0.2]]),
];
let ranked = maxsim_with_top_k(&query, &docs, Some(2));
assert_eq!(ranked.len(), 2);
}
#[test]
fn test_maxsim_empty_query() {
let query: Vec<Vec<f32>> = vec![];
let docs = vec![("d1", vec![vec![1.0, 0.0]])];
let ranked = rank(&query, &docs);
assert_eq!(ranked[0].0, "d1");
assert_eq!(ranked[0].1, 0.0);
}
#[test]
fn test_maxsim_empty_docs() {
let query = vec![vec![1.0, 0.0]];
let docs: Vec<(&str, Vec<Vec<f32>>)> = vec![("d1", vec![])];
let ranked = rank(&query, &docs);
assert_eq!(ranked[0].0, "d1");
assert_eq!(ranked[0].1, 0.0);
}
#[test]
fn test_refine() {
let candidates = vec![("d1", 0.5), ("d2", 0.9)];
let query = vec![vec![1.0, 0.0]];
let docs = vec![("d1", vec![vec![1.0, 0.0]]), ("d2", vec![vec![0.0, 1.0]])];
let refined = refine(&candidates, &query, &docs, 0.0);
assert_eq!(refined[0].0, "d1");
let refined = refine(&candidates, &query, &docs, 1.0);
assert_eq!(refined[0].0, "d2");
}
#[test]
fn test_refine_with_config_top_k() {
let candidates = vec![("d1", 0.9), ("d2", 0.8), ("d3", 0.7)];
let query = vec![vec![1.0, 0.0]];
let docs = vec![
("d1", vec![vec![1.0, 0.0]]),
("d2", vec![vec![1.0, 0.0]]),
("d3", vec![vec![1.0, 0.0]]),
];
let refined = refine_with_config(
&candidates,
&query,
&docs,
RerankConfig::default().with_top_k(2),
);
assert_eq!(refined.len(), 2);
}
#[test]
fn test_refine_missing_doc() {
let candidates = vec![("d1", 0.9), ("d2", 0.8)];
let query = vec![vec![1.0, 0.0]];
let docs = vec![("d1", vec![vec![1.0, 0.0]])];
let refined = refine(&candidates, &query, &docs, 0.5);
assert_eq!(refined.len(), 1);
assert_eq!(refined[0].0, "d1");
}
#[test]
fn test_nan_score_handling() {
let candidates = vec![("d1", f32::NAN), ("d2", 0.5)];
let query = vec![vec![1.0, 0.0]];
let docs = vec![("d1", vec![vec![1.0, 0.0]]), ("d2", vec![vec![1.0, 0.0]])];
let refined = refine(&candidates, &query, &docs, 0.5);
assert_eq!(refined.len(), 2);
assert!(refined[0].1.is_nan());
}
#[test]
fn test_pool_tokens_empty() {
let tokens: Vec<Vec<f32>> = vec![];
assert!(pool_tokens(&tokens, 2).unwrap().is_empty());
}
#[test]
fn test_pool_tokens_factor_one() {
let tokens = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
let pooled = pool_tokens(&tokens, 1).unwrap();
assert_eq!(pooled.len(), tokens.len());
}
#[test]
fn test_pool_tokens_reduces_count() {
let tokens = vec![
vec![1.0, 0.0, 0.0],
vec![0.9, 0.1, 0.0],
vec![0.0, 1.0, 0.0],
vec![0.0, 0.9, 0.1],
];
let pooled = pool_tokens(&tokens, 2).unwrap();
assert!(pooled.len() <= 2);
assert!(!pooled.is_empty());
}
#[test]
fn test_pool_tokens_sequential() {
let tokens = vec![
vec![1.0, 0.0],
vec![0.0, 1.0],
vec![0.5, 0.5],
vec![0.3, 0.7],
];
let pooled = pool_tokens_sequential(&tokens, 2).unwrap();
assert_eq!(pooled.len(), 2);
assert!((pooled[0][0] - 0.5).abs() < 1e-5);
assert!((pooled[0][1] - 0.5).abs() < 1e-5);
}
#[test]
fn test_pool_tokens_with_protected() {
let tokens = vec![
vec![0.0, 0.0, 1.0],
vec![1.0, 0.0, 0.0],
vec![0.9, 0.1, 0.0],
vec![0.0, 1.0, 0.0],
];
let pooled = pool_tokens_with_protected(&tokens, 2, 1).unwrap();
assert_eq!(pooled[0], vec![0.0, 0.0, 1.0]);
assert!(pooled.len() >= 2);
}
#[test]
fn test_pool_tokens_adaptive_low_factor() {
let tokens = vec![
vec![1.0, 0.0, 0.0, 0.0],
vec![0.9, 0.1, 0.0, 0.0],
vec![0.0, 1.0, 0.0, 0.0],
vec![0.0, 0.9, 0.1, 0.0],
];
let pooled = pool_tokens_adaptive(&tokens, 2).unwrap();
assert_eq!(pooled.len(), 2);
}
#[test]
fn test_pool_tokens_adaptive_high_factor() {
let tokens: Vec<Vec<f32>> = (0..8)
.map(|i| vec![(i as f32) * 0.1, 0.0, 0.0, 0.0])
.collect();
let pooled = pool_tokens_adaptive(&tokens, 4).unwrap();
assert_eq!(pooled.len(), 2);
}
#[test]
fn test_pool_tokens_adaptive_factor_one() {
let tokens = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
let pooled = pool_tokens_adaptive(&tokens, 1).unwrap();
assert_eq!(pooled.len(), 2);
}
#[test]
fn test_pool_tokens_adaptive_empty() {
let tokens: Vec<Vec<f32>> = vec![];
let pooled = pool_tokens_adaptive(&tokens, 2).unwrap();
assert!(pooled.is_empty());
}
#[test]
fn test_pool_factor_zero_returns_error() {
let tokens = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
assert!(pool_tokens(&tokens, 0).is_err());
}
#[test]
fn test_pool_tokens_sequential_window_zero_returns_error() {
let tokens = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
assert!(pool_tokens_sequential(&tokens, 0).is_err());
}
#[test]
fn test_pool_tokens_adaptive_factor_zero_returns_error() {
let tokens = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
assert!(pool_tokens_adaptive(&tokens, 0).is_err());
}
#[test]
fn test_pool_tokens_with_protected_factor_zero_returns_error() {
let tokens = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
assert!(pool_tokens_with_protected(&tokens, 0, 0).is_err());
}
#[test]
fn test_pool_tokens_factor_larger_than_count() {
let tokens = vec![
vec![1.0, 0.0, 0.0],
vec![0.0, 1.0, 0.0],
vec![0.0, 0.0, 1.0],
];
let pooled = pool_tokens(&tokens, 10).unwrap();
assert!(!pooled.is_empty(), "Should return at least one token");
assert!(pooled.len() <= 3, "Should not exceed original count");
}
#[test]
fn test_pool_tokens_sequential_factor_larger_than_count() {
let tokens = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
let pooled = pool_tokens_sequential(&tokens, 10).unwrap();
assert_eq!(pooled.len(), 1);
}
#[test]
fn test_pooling_methods_produce_same_dimensions() {
let tokens: Vec<Vec<f32>> = (0..8)
.map(|i| {
(0..16)
.map(|j| ((i * 16 + j) as f32 * 0.01).sin())
.collect()
})
.collect();
let greedy = pool_tokens(&tokens, 2).unwrap();
let sequential = pool_tokens_sequential(&tokens, 2).unwrap();
let adaptive = pool_tokens_adaptive(&tokens, 2).unwrap();
assert!(greedy.iter().all(|v| v.len() == 16));
assert!(sequential.iter().all(|v| v.len() == 16));
assert!(adaptive.iter().all(|v| v.len() == 16));
}
#[test]
fn test_maxsim_with_pooled_tokens() {
let query = [vec![1.0, 0.0, 0.0, 0.0], vec![0.0, 1.0, 0.0, 0.0]];
let doc = vec![
vec![0.9, 0.1, 0.0, 0.0],
vec![0.8, 0.2, 0.0, 0.0],
vec![0.1, 0.9, 0.0, 0.0],
vec![0.2, 0.8, 0.0, 0.0],
];
let q_refs: Vec<&[f32]> = query.iter().map(Vec::as_slice).collect();
let d_refs: Vec<&[f32]> = doc.iter().map(Vec::as_slice).collect();
let score_original = super::simd::maxsim(&q_refs, &d_refs);
let pooled = pool_tokens(&doc, 2).unwrap();
let p_refs: Vec<&[f32]> = pooled.iter().map(Vec::as_slice).collect();
let score_pooled = super::simd::maxsim(&q_refs, &p_refs);
assert!(score_pooled > 0.0);
assert!(score_pooled.is_finite());
assert!(
score_pooled >= score_original * 0.5,
"pooled {score_pooled} vs original {score_original}"
);
}
#[cfg(not(feature = "hierarchical"))]
#[test]
fn pool_greedy_exact_count() {
let tokens = vec![
vec![1.0, 0.0],
vec![0.9, 0.1],
vec![0.0, 1.0],
vec![0.1, 0.9],
];
let pooled = pool_tokens_greedy(&tokens, 2);
assert_eq!(pooled.len(), 2);
}
#[test]
fn pool_sequential_exact_count() {
let tokens: Vec<Vec<f32>> = (0..8).map(|i| vec![i as f32]).collect();
let pooled = pool_tokens_sequential(&tokens, 2).unwrap();
assert_eq!(pooled.len(), 4);
}
#[test]
fn refine_alpha_zero_ignores_original() {
let query = vec![vec![1.0, 0.0]];
let candidates = vec![("d1", 100.0), ("d2", 0.0)];
let docs = vec![
("d1", vec![vec![0.0, 1.0]]), ("d2", vec![vec![1.0, 0.0]]), ];
let config = super::RerankConfig::default().with_alpha(0.0);
let refined = refine_with_config(&candidates, &query, &docs, config);
assert_eq!(refined[0].0, "d2", "alpha=0 should rank by maxsim only");
}
#[test]
fn refine_alpha_one_ignores_maxsim() {
let query = vec![vec![1.0, 0.0]];
let candidates = vec![("d1", 1.0), ("d2", 0.5)];
let docs = vec![
("d1", vec![vec![0.0, 1.0]]), ("d2", vec![vec![1.0, 0.0]]), ];
let config = super::RerankConfig::default().with_alpha(1.0);
let refined = refine_with_config(&candidates, &query, &docs, config);
assert_eq!(refined[0].0, "d1", "alpha=1 should rank by original only");
}
#[cfg(feature = "hierarchical")]
#[test]
fn hierarchical_returns_target_count() {
let tokens: Vec<Vec<f32>> = (0..8).map(|i| vec![(i as f32 * 0.1).sin(); 16]).collect();
let pooled = pool_tokens_hierarchical(&tokens, 4);
assert_eq!(pooled.len(), 4);
}
#[cfg(feature = "hierarchical")]
#[test]
fn cut_dendrogram_produces_correct_clusters() {
use kodama::{linkage, Method};
let tokens = [
vec![1.0, 0.0, 0.0, 0.0],
vec![0.9, 0.1, 0.0, 0.0],
vec![0.0, 1.0, 0.0, 0.0],
vec![0.1, 0.9, 0.0, 0.0],
vec![0.0, 0.0, 1.0, 0.0],
vec![0.0, 0.0, 0.9, 0.1],
];
let n = tokens.len();
let mut condensed = Vec::with_capacity(n * (n - 1) / 2);
for i in 0..n {
for j in (i + 1)..n {
let sim = super::simd::cosine(&tokens[i], &tokens[j]);
let dist = f64::from((1.0 - sim).clamp(0.0, 2.0));
condensed.push(dist);
}
}
let dendrogram = linkage(&mut condensed, n, Method::Ward);
let labels = cut_dendrogram(&dendrogram, n, 3);
let unique_labels: std::collections::HashSet<_> = labels.iter().collect();
assert_eq!(
unique_labels.len(),
3,
"Expected 3 clusters, got labels: {:?}",
labels
);
assert_eq!(labels[0], labels[1], "Tokens 0,1 should be in same cluster");
assert_eq!(labels[2], labels[3], "Tokens 2,3 should be in same cluster");
assert_eq!(labels[4], labels[5], "Tokens 4,5 should be in same cluster");
assert_ne!(labels[0], labels[2], "Groups A,B should differ");
assert_ne!(labels[2], labels[4], "Groups B,C should differ");
}
#[cfg(feature = "hierarchical")]
#[test]
fn hierarchical_clusters_similar_tokens() {
let tokens = vec![
vec![1.0, 0.0, 0.0, 0.0],
vec![0.0, 1.0, 0.0, 0.0],
vec![0.0, 1.0, 0.0, 0.0],
vec![0.0, 1.0, 0.0, 0.0],
vec![0.0, 1.0, 0.0, 0.0],
vec![0.0, 0.0, 1.0, 0.0],
];
let pooled = pool_tokens_hierarchical(&tokens, 3);
assert_eq!(pooled.len(), 3);
let target = vec![0.0, 1.0, 0.0, 0.0];
let max_sim = pooled
.iter()
.map(|p| super::simd::cosine(p, &target))
.fold(f32::NEG_INFINITY, f32::max);
assert!(
max_sim > 0.99,
"Expected pooled vector near [0,1,0,0], best sim: {}",
max_sim
);
}
#[test]
fn token_index_new_and_len() {
let index: TokenIndex<&str> = TokenIndex::new(vec![
("doc1", vec![vec![1.0, 0.0]]),
("doc2", vec![vec![0.0, 1.0]]),
]);
assert_eq!(index.len(), 2);
assert!(!index.is_empty());
}
#[test]
fn token_index_empty() {
let index: TokenIndex<&str> = TokenIndex::default();
assert_eq!(index.len(), 0);
assert!(index.is_empty());
}
#[test]
fn token_index_score_all() {
let index = TokenIndex::new(vec![
("doc1", vec![vec![1.0, 0.0]]),
("doc2", vec![vec![0.0, 1.0]]),
]);
let query = vec![vec![1.0, 0.0]];
let scores = index.score_all(&query);
assert_eq!(scores.len(), 2);
let doc1_score = scores.iter().find(|(id, _)| *id == "doc1").unwrap().1;
let doc2_score = scores.iter().find(|(id, _)| *id == "doc2").unwrap().1;
assert!((doc1_score - 1.0).abs() < 1e-5);
assert!(doc2_score.abs() < 1e-5);
}
#[test]
fn token_index_rank() {
let index = TokenIndex::new(vec![
("doc1", vec![vec![1.0, 0.0], vec![0.0, 1.0]]),
("doc2", vec![vec![0.5, 0.5]]),
("doc3", vec![vec![0.9, 0.1]]),
]);
let query = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
let ranked = index.rank(&query);
assert_eq!(ranked.len(), 3);
assert_eq!(ranked[0].0, "doc1");
}
#[test]
fn token_index_top_k() {
let index = TokenIndex::new(vec![
("doc1", vec![vec![1.0, 0.0]]),
("doc2", vec![vec![0.9, 0.1]]),
("doc3", vec![vec![0.8, 0.2]]),
]);
let query = vec![vec![1.0, 0.0]];
let top2 = index.top_k(&query, 2);
assert_eq!(top2.len(), 2);
assert_eq!(top2[0].0, "doc1");
assert_eq!(top2[1].0, "doc2");
}
#[test]
fn token_index_top_k_larger_than_size() {
let index = TokenIndex::new(vec![("doc1", vec![vec![1.0, 0.0]])]);
let query = vec![vec![1.0, 0.0]];
let top10 = index.top_k(&query, 10);
assert_eq!(top10.len(), 1);
}
#[test]
fn token_index_get() {
let index = TokenIndex::new(vec![
("doc1", vec![vec![1.0, 0.0]]),
("doc2", vec![vec![0.0, 1.0]]),
]);
assert!(index.get(&"doc1").is_some());
assert!(index.get(&"doc2").is_some());
assert!(index.get(&"doc3").is_none());
}
#[test]
fn token_index_contains() {
let index = TokenIndex::new(vec![("doc1", vec![vec![1.0, 0.0]])]);
assert!(index.contains(&"doc1"));
assert!(!index.contains(&"doc2"));
}
#[test]
fn token_index_from_iter() {
let entries = vec![
("doc1", vec![vec![1.0, 0.0]]),
("doc2", vec![vec![0.0, 1.0]]),
];
let index: TokenIndex<&str> = entries.into_iter().collect();
assert_eq!(index.len(), 2);
}
#[test]
fn token_index_iter() {
let index = TokenIndex::new(vec![
("doc1", vec![vec![1.0, 0.0]]),
("doc2", vec![vec![0.0, 1.0]]),
]);
let ids: Vec<_> = index.iter().map(|(id, _)| *id).collect();
assert_eq!(ids.len(), 2);
assert!(ids.contains(&"doc1"));
assert!(ids.contains(&"doc2"));
}
#[test]
fn token_index_into_entries() {
let index = TokenIndex::new(vec![("doc1", vec![vec![1.0, 0.0]])]);
let entries = index.into_entries();
assert_eq!(entries.len(), 1);
assert_eq!(entries[0].0, "doc1");
}
#[test]
fn token_index_entries_returns_slice() {
let index = TokenIndex::new(vec![
("doc1", vec![vec![1.0, 0.0]]),
("doc2", vec![vec![0.0, 1.0]]),
]);
let entries = index.entries();
assert_eq!(entries.len(), 2);
assert_eq!(entries[0].0, "doc1");
assert_eq!(entries[1].0, "doc2");
assert_eq!(entries[0].1.len(), 1);
assert_eq!(entries[0].1[0], vec![1.0, 0.0]);
}
#[test]
fn token_index_score_all_cosine() {
let index = TokenIndex::new(vec![
("doc1", vec![vec![2.0, 0.0]]), ("doc2", vec![vec![0.0, 2.0]]), ]);
let query = vec![vec![1.0, 0.0]];
let scores = index.score_all_cosine(&query);
let doc1_score = scores.iter().find(|(id, _)| *id == "doc1").unwrap().1;
assert!((doc1_score - 1.0).abs() < 1e-5);
}
#[test]
fn token_index_matches_maxsim_function() {
let docs = vec![
("d1", vec![vec![1.0, 0.0], vec![0.0, 1.0]]),
("d2", vec![vec![0.5, 0.5]]),
("d3", vec![vec![0.9, 0.1]]),
];
let query = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
let index = TokenIndex::new(docs.clone());
let index_ranked = index.rank(&query);
let func_ranked = rank(&query, &docs);
assert_eq!(index_ranked.len(), func_ranked.len());
for (a, b) in index_ranked.iter().zip(func_ranked.iter()) {
assert_eq!(a.0, b.0);
assert!((a.1 - b.1).abs() < 1e-6);
}
}
}
#[cfg(test)]
mod proptests {
use super::*;
use crate::rerank::RerankError;
use proptest::prelude::*;
proptest! {
#[test]
fn maxsim_preserves_doc_count(n_docs in 1usize..5, n_query_tok in 1usize..4, dim in 2usize..8) {
let query = (0..n_query_tok)
.map(|i| (0..dim).map(|j| (i + j) as f32 * 0.1).collect())
.collect::<Vec<Vec<f32>>>();
let docs: Vec<(u32, Vec<Vec<f32>>)> = (0..n_docs as u32)
.map(|i| {
let toks = (0..2)
.map(|t| (0..dim).map(|j| (i as usize + t + j) as f32 * 0.1).collect())
.collect();
(i, toks)
})
.collect();
let ranked = rank(&query, &docs);
prop_assert_eq!(ranked.len(), n_docs);
}
#[test]
fn maxsim_sorted_descending(n_docs in 2usize..6, dim in 2usize..6) {
let query = vec![(0..dim).map(|i| i as f32 * 0.1).collect::<Vec<f32>>()];
let docs: Vec<(u32, Vec<Vec<f32>>)> = (0..n_docs as u32)
.map(|i| {
let toks = vec![(0..dim).map(|j| (i as usize + j) as f32 * 0.1).collect()];
(i, toks)
})
.collect();
let ranked = rank(&query, &docs);
for window in ranked.windows(2) {
prop_assert!(window[0].1 >= window[1].1);
}
}
#[test]
fn refine_output_bounded(n_cand in 1usize..5, n_docs in 0usize..5, dim in 2usize..6) {
let candidates: Vec<(u32, f32)> = (0..n_cand as u32)
.map(|i| (i, 1.0 - i as f32 * 0.1))
.collect();
let query = vec![(0..dim).map(|i| i as f32 * 0.1).collect::<Vec<f32>>()];
let docs: Vec<(u32, Vec<Vec<f32>>)> = (0..n_docs as u32)
.map(|i| {
let toks = vec![(0..dim).map(|j| (i as usize + j) as f32 * 0.1).collect()];
(i, toks)
})
.collect();
let refined = refine(&candidates, &query, &docs, 0.5);
prop_assert!(refined.len() <= candidates.len());
prop_assert!(refined.len() <= docs.len());
}
#[test]
fn pool_reduces_count(n_tokens in 4usize..20, dim in 2usize..8, pool_factor in 2usize..4) {
let tokens: Vec<Vec<f32>> = (0..n_tokens)
.map(|i| (0..dim).map(|j| ((i + j) as f32 * 0.1).sin()).collect())
.collect();
let pooled = pool_tokens(&tokens, pool_factor).unwrap();
let expected_max = (n_tokens / pool_factor).max(1);
prop_assert!(pooled.len() <= expected_max + 1);
prop_assert!(!pooled.is_empty());
}
#[test]
fn sequential_pool_exact_count(n_tokens in 2usize..20, dim in 2usize..8, window in 2usize..4) {
let tokens: Vec<Vec<f32>> = (0..n_tokens)
.map(|i| (0..dim).map(|j| (i + j) as f32 * 0.1).collect())
.collect();
let pooled = pool_tokens_sequential(&tokens, window).unwrap();
let expected = n_tokens.div_ceil(window);
prop_assert_eq!(pooled.len(), expected);
}
#[test]
fn pool_preserves_dimension(n_tokens in 2usize..10, dim in 2usize..16, pool_factor in 2usize..4) {
let tokens: Vec<Vec<f32>> = (0..n_tokens)
.map(|i| (0..dim).map(|j| (i + j) as f32 * 0.1).collect())
.collect();
let pooled = pool_tokens(&tokens, pool_factor).unwrap();
for tok in &pooled {
prop_assert_eq!(tok.len(), dim, "Dimension mismatch: expected {}, got {}", dim, tok.len());
}
}
#[test]
fn sequential_pool_preserves_dimension(n_tokens in 2usize..10, dim in 2usize..16, window in 2usize..4) {
let tokens: Vec<Vec<f32>> = (0..n_tokens)
.map(|i| (0..dim).map(|j| (i + j) as f32 * 0.1).collect())
.collect();
let pooled = pool_tokens_sequential(&tokens, window).unwrap();
for tok in &pooled {
prop_assert_eq!(tok.len(), dim, "Dimension mismatch: expected {}, got {}", dim, tok.len());
}
}
#[test]
fn pool_factor_one_identity(n_tokens in 1usize..10, dim in 2usize..8) {
let tokens: Vec<Vec<f32>> = (0..n_tokens)
.map(|i| (0..dim).map(|j| (i + j) as f32 * 0.1).collect())
.collect();
let pooled = pool_tokens(&tokens, 1).unwrap();
prop_assert_eq!(pooled.len(), n_tokens, "Factor 1 should preserve count");
}
#[test]
fn protected_tokens_preserved(n_tokens in 3usize..10, dim in 2usize..8) {
let tokens: Vec<Vec<f32>> = (0..n_tokens)
.map(|i| (0..dim).map(|j| (i + j) as f32 * 0.1).collect())
.collect();
let pooled = pool_tokens_with_protected(&tokens, 2, 2).unwrap();
prop_assert!(pooled.len() >= 2, "Should have at least protected tokens");
prop_assert_eq!(&pooled[0], &tokens[0], "First protected token should be preserved");
prop_assert_eq!(&pooled[1], &tokens[1], "Second protected token should be preserved");
}
#[test]
fn pool_maintains_score_quality(dim in 8usize..16) {
let query: Vec<Vec<f32>> = (0..4)
.map(|i| (0..dim).map(|j| if i == j % 4 { 1.0 } else { 0.1 }).collect())
.collect();
let doc: Vec<Vec<f32>> = (0..8)
.map(|i| (0..dim).map(|j| if i % 4 == j % 4 { 1.0 } else { 0.1 }).collect())
.collect();
let query_refs: Vec<&[f32]> = query.iter().map(Vec::as_slice).collect();
let doc_refs: Vec<&[f32]> = doc.iter().map(Vec::as_slice).collect();
let original_score = super::simd::maxsim(&query_refs, &doc_refs);
let pooled = pool_tokens(&doc, 2).unwrap();
let pooled_refs: Vec<&[f32]> = pooled.iter().map(Vec::as_slice).collect();
let pooled_score = super::simd::maxsim(&query_refs, &pooled_refs);
prop_assert!(
pooled_score >= original_score * 0.5,
"Score dropped too much: {} -> {}",
original_score,
pooled_score
);
}
#[test]
fn refine_alpha_one_preserves_order(n_cand in 2usize..5) {
let candidates: Vec<(u32, f32)> = (0..n_cand as u32)
.map(|i| (i, 10.0 - i as f32))
.collect();
let query = vec![vec![1.0f32; 4]];
let docs: Vec<(u32, Vec<Vec<f32>>)> = (0..n_cand as u32)
.map(|i| (i, vec![vec![0.5f32; 4]]))
.collect();
let refined = refine(&candidates, &query, &docs, 1.0);
for (i, (id, _)) in refined.iter().enumerate() {
prop_assert_eq!(*id, i as u32, "Order not preserved at index {}", i);
}
}
#[test]
fn pool_preserves_vector_mass(n_tokens in 4usize..16, dim in 4usize..16) {
let tokens: Vec<Vec<f32>> = (0..n_tokens)
.map(|i| (0..dim).map(|j| ((i * dim + j) as f32 * 0.1).sin()).collect())
.collect();
let orig_total_norm: f32 = tokens
.iter()
.map(|t| super::simd::norm(t))
.sum();
let pooled = pool_tokens(&tokens, 2).unwrap();
let pooled_total_norm: f32 = pooled
.iter()
.map(|t| super::simd::norm(t))
.sum();
prop_assert!(
pooled_total_norm > 0.0,
"Pooled vectors should have positive norm"
);
prop_assert!(
pooled_total_norm <= orig_total_norm * 1.1,
"Pooled norm {} too large vs original {}",
pooled_total_norm,
orig_total_norm
);
}
#[test]
fn duplicate_tokens_cluster(dim in 4usize..16) {
let base_token: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.1).sin()).collect();
let mut tokens = vec![base_token.clone(); 4];
for i in 0..4 {
tokens.push((0..dim).map(|j| ((i * 10 + j) as f32 * 0.3).cos()).collect());
}
let pooled = pool_tokens(&tokens, 2).unwrap();
let max_sim = pooled
.iter()
.map(|p| super::simd::cosine(p, &base_token))
.fold(f32::NEG_INFINITY, f32::max);
prop_assert!(
max_sim > 0.95,
"Duplicate tokens didn't cluster well: max_sim = {}",
max_sim
);
}
#[test]
// `MaxSim` is asymmetric in general, so no equality is asserted here;
// this only checks that both orderings produce finite scores.
fn maxsim_not_commutative(dim in 4usize..8) {
let a: Vec<Vec<f32>> = vec![
(0..dim).map(|i| if i == 0 { 1.0 } else { 0.0 }).collect(),
];
let b: Vec<Vec<f32>> = vec![
(0..dim).map(|_| 0.5).collect(),
(0..dim).map(|i| if i == 1 { 1.0 } else { 0.0 }).collect(),
];
let a_refs: Vec<&[f32]> = a.iter().map(Vec::as_slice).collect();
let b_refs: Vec<&[f32]> = b.iter().map(Vec::as_slice).collect();
let score_ab = super::simd::maxsim(&a_refs, &b_refs);
let score_ba = super::simd::maxsim(&b_refs, &a_refs);
prop_assert!(score_ab.is_finite() && score_ba.is_finite());
}
#[test]
fn more_doc_tokens_higher_score(dim in 4usize..8) {
let query: Vec<Vec<f32>> = vec![
(0..dim).map(|i| (i as f32 * 0.1).sin()).collect(),
(0..dim).map(|i| (i as f32 * 0.2).cos()).collect(),
];
let doc_small: Vec<Vec<f32>> = vec![
(0..dim).map(|i| (i as f32 * 0.15).sin()).collect(),
];
let mut doc_large = doc_small.clone();
doc_large.push((0..dim).map(|i| (i as f32 * 0.25).cos()).collect());
let q_refs: Vec<&[f32]> = query.iter().map(Vec::as_slice).collect();
let small_refs: Vec<&[f32]> = doc_small.iter().map(Vec::as_slice).collect();
let large_refs: Vec<&[f32]> = doc_large.iter().map(Vec::as_slice).collect();
let score_small = super::simd::maxsim(&q_refs, &small_refs);
let score_large = super::simd::maxsim(&q_refs, &large_refs);
prop_assert!(
score_large >= score_small - 1e-6,
"More tokens should help: {} vs {}",
score_large,
score_small
);
}
#[test]
fn pooling_idempotent_at_target(dim in 4usize..8) {
let tokens: Vec<Vec<f32>> = (0..8)
.map(|i| (0..dim).map(|j| ((i + j) as f32 * 0.1).sin()).collect())
.collect();
let pooled_once = pool_tokens(&tokens, 2).unwrap();
let pooled_twice = pool_tokens(&pooled_once, 1).unwrap();
prop_assert_eq!(
pooled_once.len(),
pooled_twice.len(),
"Second pool changed count"
);
for (a, b) in pooled_once.iter().zip(pooled_twice.iter()) {
let sim = super::simd::cosine(a, b);
prop_assert!(
sim > 0.999,
"Pool not idempotent: similarity = {}",
sim
);
}
}
#[test]
fn adaptive_uses_clustering_for_low_factors(n_tokens in 4usize..12, dim in 4usize..8) {
let tokens: Vec<Vec<f32>> = (0..n_tokens)
.map(|i| (0..dim).map(|j| ((i * dim + j) as f32 * 0.1).sin()).collect())
.collect();
let adaptive = pool_tokens_adaptive(&tokens, 2).unwrap();
let clustering = pool_tokens(&tokens, 2).unwrap();
prop_assert_eq!(adaptive.len(), clustering.len());
}
#[test]
fn adaptive_uses_sequential_for_high_factors(n_tokens in 8usize..20, dim in 4usize..8) {
let tokens: Vec<Vec<f32>> = (0..n_tokens)
.map(|i| (0..dim).map(|j| ((i * dim + j) as f32 * 0.1).sin()).collect())
.collect();
let adaptive = pool_tokens_adaptive(&tokens, 4).unwrap();
let sequential = pool_tokens_sequential(&tokens, 4).unwrap();
prop_assert_eq!(adaptive.len(), sequential.len());
for (a, s) in adaptive.iter().zip(sequential.iter()) {
prop_assert_eq!(a, s);
}
}
#[test]
fn all_pooling_methods_preserve_dim(n_tokens in 4usize..16, dim in 4usize..16) {
let tokens: Vec<Vec<f32>> = (0..n_tokens)
.map(|i| (0..dim).map(|j| ((i * dim + j) as f32 * 0.1).sin()).collect())
.collect();
let p1 = pool_tokens(&tokens, 2).unwrap();
let p2 = pool_tokens_sequential(&tokens, 2).unwrap();
let p3 = pool_tokens_adaptive(&tokens, 2).unwrap();
let p4 = pool_tokens_with_protected(&tokens, 2, 1).unwrap();
prop_assert!(p1.iter().all(|v| v.len() == dim));
prop_assert!(p2.iter().all(|v| v.len() == dim));
prop_assert!(p3.iter().all(|v| v.len() == dim));
prop_assert!(p4.iter().all(|v| v.len() == dim));
}
#[test]
fn pooling_never_increases_count(n_tokens in 1usize..20, factor in 1usize..5) {
let tokens: Vec<Vec<f32>> = (0..n_tokens)
.map(|i| vec![(i as f32) * 0.1; 8])
.collect();
let p1 = pool_tokens(&tokens, factor).unwrap();
let p2 = pool_tokens_sequential(&tokens, factor).unwrap();
let p3 = pool_tokens_adaptive(&tokens, factor).unwrap();
prop_assert!(p1.len() <= n_tokens);
prop_assert!(p2.len() <= n_tokens);
prop_assert!(p3.len() <= n_tokens);
}
#[test]
fn empty_tokens_all_methods(factor in 1usize..5) {
let empty: Vec<Vec<f32>> = vec![];
prop_assert!(pool_tokens(&empty, factor).unwrap().is_empty());
prop_assert!(pool_tokens_sequential(&empty, factor).unwrap().is_empty());
prop_assert!(pool_tokens_adaptive(&empty, factor).unwrap().is_empty());
prop_assert!(pool_tokens_with_protected(&empty, factor, 0).unwrap().is_empty());
}
#[test]
fn single_token_unchanged(dim in 2usize..16, factor in 2usize..5) {
let tokens = vec![vec![1.0f32; dim]];
let p1 = pool_tokens(&tokens, factor).unwrap();
let p2 = pool_tokens_sequential(&tokens, factor).unwrap();
let p3 = pool_tokens_adaptive(&tokens, factor).unwrap();
prop_assert_eq!(p1.len(), 1);
prop_assert_eq!(p2.len(), 1);
prop_assert_eq!(p3.len(), 1);
}
#[test]
fn greedy_uses_strict_greater_than(n_tokens in 3usize..8, dim in 4usize..8) {
let mut tokens: Vec<Vec<f32>> = (0..n_tokens)
.map(|i| {
let mut v = vec![0.0f32; dim];
v[0] = i as f32;
v
})
.collect();
tokens[1] = tokens[0].clone();
let pooled = pool_tokens(&tokens, 2).unwrap();
prop_assert!(pooled.len() <= n_tokens, "Should not exceed original count");
prop_assert!(!pooled.is_empty(), "Should have at least one cluster");
}
#[test]
#[cfg(feature = "hierarchical")]
fn hierarchical_filters_empty_clusters(n_tokens in 4usize..10, dim in 4usize..8) {
let tokens: Vec<Vec<f32>> = (0..n_tokens)
.map(|i| {
let mut v = vec![0.0f32; dim];
v[0] = i as f32;
v
})
.collect();
let pooled = pool_tokens(&tokens, 2).unwrap();
prop_assert!(!pooled.is_empty(), "Should have at least one cluster");
for tok in &pooled {
prop_assert!(!tok.is_empty(), "Pooled token should not be empty");
}
}
#[test]
#[cfg(feature = "hierarchical")]
fn hierarchical_uses_addition_for_cluster_count(n_tokens in 4usize..10, dim in 4usize..8) {
let tokens: Vec<Vec<f32>> = (0..n_tokens)
.map(|i| {
let mut v = vec![0.0f32; dim];
v[0] = i as f32;
v
})
.collect();
let pooled = pool_tokens(&tokens, 2).unwrap();
prop_assert!(pooled.len() <= n_tokens, "Should not exceed original token count");
prop_assert!(!pooled.is_empty(), "Should have at least one cluster");
}
#[test]
#[cfg(feature = "hierarchical")]
fn hierarchical_uses_subtraction_for_distance(n_tokens in 4usize..8, dim in 4usize..8) {
let tokens: Vec<Vec<f32>> = (0..n_tokens)
.map(|i| {
let mut v = vec![0.0f32; dim];
v[0] = (i % 2) as f32;
v
})
.collect();
let pooled = pool_tokens(&tokens, 2).unwrap();
prop_assert!(pooled.len() <= n_tokens, "Should not exceed original token count");
prop_assert!(!pooled.is_empty(), "Should have at least one cluster");
}
#[test]
#[cfg(feature = "hierarchical")]
fn hierarchical_handles_nan_with_negative_one(n_tokens in 3usize..6, dim in 4usize..8) {
let tokens: Vec<Vec<f32>> = (0..n_tokens)
.map(|i| {
if i == 0 {
vec![0.0f32; dim]
} else {
let mut v = vec![0.0f32; dim];
v[0] = i as f32;
v
}
})
.collect();
let pooled = pool_tokens(&tokens, 2).unwrap();
prop_assert!(!pooled.is_empty(), "Should handle NaN and produce clusters");
}
#[test]
fn maxsim_pooled_finite(n_query in 1usize..4, n_doc in 2usize..8, dim in 4usize..8) {
let query: Vec<Vec<f32>> = (0..n_query)
.map(|i| (0..dim).map(|j| ((i * dim + j) as f32 * 0.1).sin()).collect())
.collect();
let doc: Vec<Vec<f32>> = (0..n_doc)
.map(|i| (0..dim).map(|j| ((i * dim + j + 100) as f32 * 0.1).cos()).collect())
.collect();
let pooled = pool_tokens_adaptive(&doc, 2).unwrap();
let q_refs: Vec<&[f32]> = query.iter().map(Vec::as_slice).collect();
let p_refs: Vec<&[f32]> = pooled.iter().map(Vec::as_slice).collect();
let score = super::simd::maxsim(&q_refs, &p_refs);
prop_assert!(score.is_finite(), "`MaxSim` with pooled docs returned {}", score);
}
#[test]
fn token_index_maxsim_sorted(n_docs in 2usize..8, n_query in 1usize..4, dim in 2usize..8) {
let docs: Vec<(u32, Vec<Vec<f32>>)> = (0..n_docs as u32)
.map(|i| {
let tokens: Vec<Vec<f32>> = (0..2)
.map(|t| (0..dim).map(|j| ((i as usize * 2 + t + j) as f32 * 0.1).sin()).collect())
.collect();
(i, tokens)
})
.collect();
let query: Vec<Vec<f32>> = (0..n_query)
.map(|i| (0..dim).map(|j| ((i * dim + j) as f32 * 0.1).cos()).collect())
.collect();
let index = TokenIndex::new(docs);
let ranked = index.rank(&query);
for window in ranked.windows(2) {
prop_assert!(
window[0].1 >= window[1].1 - 1e-6,
"Not sorted: {} >= {}",
window[0].1,
window[1].1
);
}
}
#[test]
fn token_index_top_k_bounded(n_docs in 1usize..10, k in 1usize..5, dim in 2usize..8) {
let docs: Vec<(u32, Vec<Vec<f32>>)> = (0..n_docs as u32)
.map(|i| (i, vec![vec![(i as f32 * 0.1).sin(); dim]]))
.collect();
let query = vec![vec![0.5f32; dim]];
let index = TokenIndex::new(docs);
let top = index.top_k(&query, k);
prop_assert!(top.len() <= k.min(n_docs));
}
#[test]
fn token_index_preserves_count(n_docs in 1usize..10, dim in 2usize..8) {
let docs: Vec<(u32, Vec<Vec<f32>>)> = (0..n_docs as u32)
.map(|i| (i, vec![vec![(i as f32 * 0.1).sin(); dim]]))
.collect();
let index = TokenIndex::new(docs);
prop_assert_eq!(index.len(), n_docs);
}
#[test]
fn token_index_score_all_count(n_docs in 1usize..10, dim in 2usize..8) {
let docs: Vec<(u32, Vec<Vec<f32>>)> = (0..n_docs as u32)
.map(|i| (i, vec![vec![(i as f32 * 0.1).sin(); dim]]))
.collect();
let query = vec![vec![0.5f32; dim]];
let index = TokenIndex::new(docs);
let scores = index.score_all(&query);
prop_assert_eq!(scores.len(), n_docs);
}
#[test]
fn token_index_scores_finite(n_docs in 1usize..5, n_query in 1usize..3, dim in 2usize..8) {
let docs: Vec<(u32, Vec<Vec<f32>>)> = (0..n_docs as u32)
.map(|i| {
let tokens: Vec<Vec<f32>> = (0..2)
.map(|t| (0..dim).map(|j| ((i as usize * 2 + t + j) as f32 * 0.1).sin()).collect())
.collect();
(i, tokens)
})
.collect();
let query: Vec<Vec<f32>> = (0..n_query)
.map(|i| (0..dim).map(|j| ((i * dim + j) as f32 * 0.1).cos()).collect())
.collect();
let index = TokenIndex::new(docs);
let scores = index.score_all(&query);
for (id, score) in &scores {
prop_assert!(score.is_finite(), "Score for {} is not finite: {}", id, score);
}
}
#[test]
fn pool_tokens_zero_factor_returns_error(n_tokens in 1usize..10, dim in 2usize..8) {
let tokens: Vec<Vec<f32>> = (0..n_tokens)
.map(|i| (0..dim).map(|j| ((i * dim + j) as f32 * 0.1).sin()).collect())
.collect();
let result = pool_tokens(&tokens, 0);
prop_assert!(result.is_err(), "Should return error for pool_factor = 0");
if let Err(e) = result {
prop_assert!(matches!(e, RerankError::InvalidPoolFactor { pool_factor: 0 }), "Should be InvalidPoolFactor error");
}
}
#[test]
fn pool_tokens_sequential_zero_window_returns_error(n_tokens in 1usize..10, dim in 2usize..8) {
let tokens: Vec<Vec<f32>> = (0..n_tokens)
.map(|i| (0..dim).map(|j| ((i * dim + j) as f32 * 0.1).sin()).collect())
.collect();
let result = pool_tokens_sequential(&tokens, 0);
prop_assert!(result.is_err(), "Should return error for window_size = 0");
if let Err(e) = result {
prop_assert!(matches!(e, RerankError::InvalidWindowSize { window_size: 0 }), "Should be InvalidWindowSize error");
}
}
#[test]
fn pool_tokens_adaptive_zero_factor_returns_error(n_tokens in 1usize..10, dim in 2usize..8) {
let tokens: Vec<Vec<f32>> = (0..n_tokens)
.map(|i| (0..dim).map(|j| ((i * dim + j) as f32 * 0.1).sin()).collect())
.collect();
let result = pool_tokens_adaptive(&tokens, 0);
prop_assert!(result.is_err(), "Should return error for pool_factor = 0");
if let Err(e) = result {
prop_assert!(matches!(e, RerankError::InvalidPoolFactor { pool_factor: 0 }), "Should be InvalidPoolFactor error");
}
}
#[test]
fn pool_tokens_with_protected_zero_factor_returns_error(n_tokens in 1usize..10, dim in 2usize..8, protected in 0usize..10) {
prop_assume!(protected < n_tokens);
let tokens: Vec<Vec<f32>> = (0..n_tokens)
.map(|i| (0..dim).map(|j| ((i * dim + j) as f32 * 0.1).sin()).collect())
.collect();
let result = pool_tokens_with_protected(&tokens, 0, protected);
prop_assert!(result.is_err(), "Should return error for pool_factor = 0");
if let Err(e) = result {
prop_assert!(matches!(e, RerankError::InvalidPoolFactor { pool_factor: 0 }), "Should be InvalidPoolFactor error");
}
}
#[test]
fn pool_tokens_valid_factor_succeeds(n_tokens in 1usize..20, pool_factor in 1usize..10, dim in 2usize..8) {
let tokens: Vec<Vec<f32>> = (0..n_tokens)
.map(|i| (0..dim).map(|j| ((i * dim + j) as f32 * 0.1).sin()).collect())
.collect();
let result = pool_tokens(&tokens, pool_factor);
prop_assert!(result.is_ok(), "Should succeed for valid pool_factor");
if let Ok(pooled) = result {
prop_assert!(!pooled.is_empty() || tokens.is_empty(), "Should return non-empty unless input is empty");
prop_assert!(pooled.len() <= tokens.len(), "Pooled should not exceed input length");
}
}
#[test]
fn pool_tokens_sequential_valid_window_succeeds(n_tokens in 1usize..20, window_size in 1usize..10, dim in 2usize..8) {
let tokens: Vec<Vec<f32>> = (0..n_tokens)
.map(|i| (0..dim).map(|j| ((i * dim + j) as f32 * 0.1).sin()).collect())
.collect();
let result = pool_tokens_sequential(&tokens, window_size);
prop_assert!(result.is_ok(), "Should succeed for valid window_size");
if let Ok(pooled) = result {
prop_assert!(!pooled.is_empty() || tokens.is_empty(), "Should return non-empty unless input is empty");
prop_assert!(pooled.len() <= tokens.len(), "Pooled should not exceed input length");
}
}
#[test]
fn pool_tokens_adaptive_valid_factor_succeeds(n_tokens in 1usize..20, pool_factor in 1usize..10, dim in 2usize..8) {
let tokens: Vec<Vec<f32>> = (0..n_tokens)
.map(|i| (0..dim).map(|j| ((i * dim + j) as f32 * 0.1).sin()).collect())
.collect();
let result = pool_tokens_adaptive(&tokens, pool_factor);
prop_assert!(result.is_ok(), "Should succeed for valid pool_factor");
if let Ok(pooled) = result {
prop_assert!(!pooled.is_empty() || tokens.is_empty(), "Should return non-empty unless input is empty");
prop_assert!(pooled.len() <= tokens.len(), "Pooled should not exceed input length");
}
}
#[test]
fn pool_tokens_preserves_dimensions(n_tokens in 1usize..20, pool_factor in 1usize..10, dim in 2usize..8) {
let tokens: Vec<Vec<f32>> = (0..n_tokens)
.map(|i| (0..dim).map(|j| ((i * dim + j) as f32 * 0.1).sin()).collect())
.collect();
if let Ok(pooled) = pool_tokens(&tokens, pool_factor) {
for pooled_vec in &pooled {
prop_assert_eq!(pooled_vec.len(), dim, "Pooled vector should preserve dimension");
}
}
}
#[test]
fn pool_tokens_sequential_preserves_dimensions(n_tokens in 1usize..20, window_size in 1usize..10, dim in 2usize..8) {
let tokens: Vec<Vec<f32>> = (0..n_tokens)
.map(|i| (0..dim).map(|j| ((i * dim + j) as f32 * 0.1).sin()).collect())
.collect();
if let Ok(pooled) = pool_tokens_sequential(&tokens, window_size) {
for pooled_vec in &pooled {
prop_assert_eq!(pooled_vec.len(), dim, "Pooled vector should preserve dimension");
}
}
}
#[test]
fn pool_tokens_adaptive_preserves_dimensions(n_tokens in 1usize..20, pool_factor in 1usize..10, dim in 2usize..8) {
let tokens: Vec<Vec<f32>> = (0..n_tokens)
.map(|i| (0..dim).map(|j| ((i * dim + j) as f32 * 0.1).sin()).collect())
.collect();
if let Ok(pooled) = pool_tokens_adaptive(&tokens, pool_factor) {
for pooled_vec in &pooled {
prop_assert_eq!(pooled_vec.len(), dim, "Pooled vector should preserve dimension");
}
}
}
#[test]
fn pool_tokens_empty_input(pool_factor in 1usize..10) {
let tokens: Vec<Vec<f32>> = vec![];
let result = pool_tokens(&tokens, pool_factor);
prop_assert!(result.is_ok(), "Should succeed for empty input");
if let Ok(pooled) = result {
prop_assert_eq!(pooled.len(), 0, "Should return empty for empty input");
}
}
#[test]
fn pool_tokens_sequential_empty_input(window_size in 1usize..10) {
let tokens: Vec<Vec<f32>> = vec![];
let result = pool_tokens_sequential(&tokens, window_size);
prop_assert!(result.is_ok(), "Should succeed for empty input");
if let Ok(pooled) = result {
prop_assert_eq!(pooled.len(), 0, "Should return empty for empty input");
}
}
#[test]
fn pool_tokens_adaptive_empty_input(pool_factor in 1usize..10) {
let tokens: Vec<Vec<f32>> = vec![];
let result = pool_tokens_adaptive(&tokens, pool_factor);
prop_assert!(result.is_ok(), "Should succeed for empty input");
if let Ok(pooled) = result {
prop_assert_eq!(pooled.len(), 0, "Should return empty for empty input");
}
}
}
}