#[cfg(not(feature = "innr"))]
const MIN_DIM_SIMD: usize = 16;
#[cfg(not(feature = "innr"))]
const NORM_EPSILON: f32 = 1e-9;
#[cfg(feature = "innr")]
pub use innr::{cosine, dot, maxsim, maxsim_cosine, norm};
#[cfg(not(feature = "innr"))]
mod fallback {
use super::{MIN_DIM_SIMD, NORM_EPSILON};
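/// Dot product with runtime SIMD dispatch.
///
/// Vectors of at least `MIN_DIM_SIMD` elements use AVX-512 or AVX2+FMA on
/// x86_64 (when detected at runtime) or NEON on aarch64; everything else
/// falls through to the portable scalar loop. Length mismatches panic in
/// debug builds via `debug_assert_eq!`; in release builds the shorter
/// length wins.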
#[inline]
#[must_use]
pub fn dot(a: &[f32], b: &[f32]) -> f32 {
debug_assert_eq!(
a.len(),
b.len(),
"dot: dimension mismatch ({} vs {})",
a.len(),
b.len()
);
let n = a.len().min(b.len());
#[cfg(target_arch = "x86_64")]
{
if n >= MIN_DIM_SIMD && is_x86_feature_detected!("avx512f") {
return unsafe { super::dot_avx512(a, b) };
}
if n >= MIN_DIM_SIMD
&& is_x86_feature_detected!("avx2")
&& is_x86_feature_detected!("fma")
{
return unsafe { super::dot_avx2(a, b) };
}
}
#[cfg(target_arch = "aarch64")]
{
if n >= MIN_DIM_SIMD {
return unsafe { super::dot_neon(a, b) };
}
}
#[allow(unreachable_code)]
dot_portable(a, b)
}
#[inline]
#[must_use]
pub fn dot_portable(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}
#[inline]
#[must_use]
pub fn norm(v: &[f32]) -> f32 {
dot(v, v).sqrt()
}
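/// Cosine similarity of `a` and `b`.
///
/// Returns 0.0 (instead of NaN or infinity) whenever either norm is at or
/// below `NORM_EPSILON`, e.g. for zero or near-zero vectors.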
#[inline]
#[must_use]
pub fn cosine(a: &[f32], b: &[f32]) -> f32 {
let d = dot(a, b);
let na = norm(a);
let nb = norm(b);
if na > NORM_EPSILON && nb > NORM_EPSILON {
d / (na * nb)
} else {
0.0
}
}
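/// MaxSim late-interaction score: for every query token, take the maximum
/// dot product over all document tokens, then sum those maxima.
///
/// Returns 0.0 when either token list is empty. With unit-norm embeddings
/// the score is at most the number of query tokens; dividing by the query
/// length (see `normalize_maxsim`) maps it into roughly [0, 1].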
#[inline]
#[must_use]
pub fn maxsim(query_tokens: &[&[f32]], doc_tokens: &[&[f32]]) -> f32 {
if query_tokens.is_empty() || doc_tokens.is_empty() {
return 0.0;
}
query_tokens
.iter()
.map(|q| {
doc_tokens
.iter()
.map(|d| dot(q, d))
.fold(f32::NEG_INFINITY, f32::max)
})
.sum()
}
#[inline]
#[must_use]
pub fn maxsim_cosine(query_tokens: &[&[f32]], doc_tokens: &[&[f32]]) -> f32 {
if query_tokens.is_empty() || doc_tokens.is_empty() {
return 0.0;
}
query_tokens
.iter()
.map(|q| {
doc_tokens
.iter()
.map(|d| cosine(q, d))
.fold(f32::NEG_INFINITY, f32::max)
})
.sum()
}
}
#[cfg(not(feature = "innr"))]
pub use fallback::{cosine, dot, maxsim, maxsim_cosine, norm};
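/// Like `maxsim`, but also records which document token each query token
/// aligned to.
///
/// Returns one `(query_idx, doc_idx, score)` triple per query token, where
/// `doc_idx` is the document token with the highest dot product (ties keep
/// the earliest index). The scores sum to the `maxsim` value for the same
/// inputs; an empty query or document yields an empty vector.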
#[must_use]
pub fn maxsim_alignments(
query_tokens: &[&[f32]],
doc_tokens: &[&[f32]],
) -> Vec<(usize, usize, f32)> {
if query_tokens.is_empty() || doc_tokens.is_empty() {
return Vec::new();
}
query_tokens
.iter()
.enumerate()
.map(|(q_idx, q)| {
let (best_doc_idx, best_score) = doc_tokens
.iter()
.enumerate()
.map(|(d_idx, d)| (d_idx, dot(q, d)))
.fold(
(0, f32::NEG_INFINITY),
|(best_idx, best_score), (idx, score)| {
if score > best_score {
(idx, score)
} else {
(best_idx, best_score)
}
},
);
(q_idx, best_doc_idx, best_score)
})
.collect()
}
#[must_use]
pub fn maxsim_alignments_cosine(
query_tokens: &[&[f32]],
doc_tokens: &[&[f32]],
) -> Vec<(usize, usize, f32)> {
if query_tokens.is_empty() || doc_tokens.is_empty() {
return Vec::new();
}
query_tokens
.iter()
.enumerate()
.map(|(q_idx, q)| {
let (best_doc_idx, best_score) = doc_tokens
.iter()
.enumerate()
.map(|(d_idx, d)| (d_idx, cosine(q, d)))
.fold(
(0, f32::NEG_INFINITY),
|(best_idx, best_score), (idx, score)| {
if score > best_score {
(idx, score)
} else {
(best_idx, best_score)
}
},
);
(q_idx, best_doc_idx, best_score)
})
.collect()
}
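/// Indices of document tokens whose dot product with any query token
/// reaches `threshold`.
///
/// The result is deduplicated and sorted ascending; empty inputs produce an
/// empty vector.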
#[must_use]
pub fn highlight_matches(
query_tokens: &[&[f32]],
doc_tokens: &[&[f32]],
threshold: f32,
) -> Vec<usize> {
if query_tokens.is_empty() || doc_tokens.is_empty() {
return Vec::new();
}
let mut matched_indices = std::collections::HashSet::new();
for q in query_tokens {
for (d_idx, d) in doc_tokens.iter().enumerate() {
let similarity = dot(q, d);
if similarity >= threshold {
matched_indices.insert(d_idx);
}
}
}
let mut result: Vec<usize> = matched_indices.into_iter().collect();
result.sort_unstable();
result
}
#[inline]
#[must_use]
pub fn maxsim_weighted(query_tokens: &[&[f32]], doc_tokens: &[&[f32]], weights: &[f32]) -> f32 {
if query_tokens.is_empty() || doc_tokens.is_empty() {
return 0.0;
}
query_tokens
.iter()
.enumerate()
.map(|(i, q)| {
let w = weights.get(i).copied().unwrap_or(1.0);
let max_sim = doc_tokens
.iter()
.map(|d| dot(q, d))
.fold(f32::NEG_INFINITY, f32::max);
w * max_sim
})
.sum()
}
#[inline]
#[must_use]
pub fn maxsim_cosine_weighted(
query_tokens: &[&[f32]],
doc_tokens: &[&[f32]],
weights: &[f32],
) -> f32 {
if query_tokens.is_empty() || doc_tokens.is_empty() {
return 0.0;
}
query_tokens
.iter()
.enumerate()
.map(|(i, q)| {
let w = weights.get(i).copied().unwrap_or(1.0);
let max_sim = doc_tokens
.iter()
.map(|d| cosine(q, d))
.fold(f32::NEG_INFINITY, f32::max);
w * max_sim
})
.sum()
}
#[inline]
#[must_use]
pub fn maxsim_weighted_vecs(
query_tokens: &[Vec<f32>],
doc_tokens: &[Vec<f32>],
weights: &[f32],
) -> f32 {
let q = as_slices(query_tokens);
let d = as_slices(doc_tokens);
maxsim_weighted(&q, &d, weights)
}
#[inline]
#[must_use]
pub fn maxsim_vecs(query_tokens: &[Vec<f32>], doc_tokens: &[Vec<f32>]) -> f32 {
let q = as_slices(query_tokens);
let d = as_slices(doc_tokens);
maxsim(&q, &d)
}
#[inline]
#[must_use]
pub fn maxsim_cosine_vecs(query_tokens: &[Vec<f32>], doc_tokens: &[Vec<f32>]) -> f32 {
let q = as_slices(query_tokens);
let d = as_slices(doc_tokens);
maxsim_cosine(&q, &d)
}
#[inline]
#[must_use]
pub fn maxsim_alignments_vecs(
query_tokens: &[Vec<f32>],
doc_tokens: &[Vec<f32>],
) -> Vec<(usize, usize, f32)> {
let q = as_slices(query_tokens);
let d = as_slices(doc_tokens);
maxsim_alignments(&q, &d)
}
#[inline]
#[must_use]
pub fn maxsim_alignments_cosine_vecs(
query_tokens: &[Vec<f32>],
doc_tokens: &[Vec<f32>],
) -> Vec<(usize, usize, f32)> {
let q = as_slices(query_tokens);
let d = as_slices(doc_tokens);
maxsim_alignments_cosine(&q, &d)
}
#[inline]
#[must_use]
pub fn highlight_matches_vecs(
query_tokens: &[Vec<f32>],
doc_tokens: &[Vec<f32>],
threshold: f32,
) -> Vec<usize> {
let q = as_slices(query_tokens);
let d = as_slices(doc_tokens);
highlight_matches(&q, &d, threshold)
}
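/// Scores one query against a batch of documents with `maxsim`, returning
/// one score per document in the same order as `docs`.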
#[must_use]
pub fn maxsim_batch(query: &[Vec<f32>], docs: &[Vec<Vec<f32>>]) -> Vec<f32> {
let q = as_slices(query);
docs.iter()
.map(|doc| {
let d = as_slices(doc);
maxsim(&q, &d)
})
.collect()
}
#[must_use]
pub fn maxsim_cosine_batch(query: &[Vec<f32>], docs: &[Vec<Vec<f32>>]) -> Vec<f32> {
let q = as_slices(query);
docs.iter()
.map(|doc| {
let d = as_slices(doc);
maxsim_cosine(&q, &d)
})
.collect()
}
#[must_use]
pub fn maxsim_alignments_batch(
query: &[Vec<f32>],
docs: &[Vec<Vec<f32>>],
) -> Vec<Vec<(usize, usize, f32)>> {
let q = as_slices(query);
docs.iter()
.map(|doc| {
let d = as_slices(doc);
maxsim_alignments(&q, &d)
})
.collect()
}
#[must_use]
pub fn maxsim_alignments_cosine_batch(
query: &[Vec<f32>],
docs: &[Vec<Vec<f32>>],
) -> Vec<Vec<(usize, usize, f32)>> {
let q = as_slices(query);
docs.iter()
.map(|doc| {
let d = as_slices(doc);
maxsim_alignments_cosine(&q, &d)
})
.collect()
}
#[must_use]
pub fn highlight_matches_batch(
query: &[Vec<f32>],
docs: &[Vec<Vec<f32>>],
threshold: f32,
) -> Vec<Vec<usize>> {
let q = as_slices(query);
docs.iter()
.map(|doc| {
let d = as_slices(doc);
highlight_matches(&q, &d, threshold)
})
.collect()
}
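/// The `k` highest-scoring alignments, sorted by descending score.
///
/// Scores are compared with `partial_cmp`, treating incomparable (NaN)
/// pairs as equal; `k == 0` or an empty input returns an empty vector.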
#[must_use]
pub fn top_k_alignments(alignments: &[(usize, usize, f32)], k: usize) -> Vec<(usize, usize, f32)> {
if k == 0 || alignments.is_empty() {
return Vec::new();
}
let mut sorted: Vec<_> = alignments.to_vec();
sorted.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
sorted.into_iter().take(k).collect()
}
#[must_use]
pub fn filter_alignments(
alignments: &[(usize, usize, f32)],
min_score: f32,
) -> Vec<(usize, usize, f32)> {
alignments
.iter()
.filter(|(_, _, score)| *score >= min_score)
.copied()
.collect()
}
#[must_use]
pub fn alignments_for_query_tokens(
alignments: &[(usize, usize, f32)],
query_indices: &std::collections::HashSet<usize>,
) -> Vec<(usize, usize, f32)> {
alignments
.iter()
.filter(|(q_idx, _, _)| query_indices.contains(q_idx))
.copied()
.collect()
}
#[must_use]
pub fn alignments_for_doc_tokens(
alignments: &[(usize, usize, f32)],
doc_indices: &std::collections::HashSet<usize>,
) -> Vec<(usize, usize, f32)> {
alignments
.iter()
.filter(|(_, d_idx, _)| doc_indices.contains(d_idx))
.copied()
.collect()
}
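/// Summary statistics over alignment scores as `(min, max, mean, sum, count)`.
///
/// An empty slice returns all zeros.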
#[must_use]
pub fn alignment_stats(alignments: &[(usize, usize, f32)]) -> (f32, f32, f32, f32, usize) {
if alignments.is_empty() {
return (0.0, 0.0, 0.0, 0.0, 0);
}
let scores: Vec<f32> = alignments.iter().map(|(_, _, s)| *s).collect();
let min = scores.iter().copied().fold(f32::INFINITY, f32::min);
let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
let sum: f32 = scores.iter().sum();
let count = scores.len();
let mean = sum / count as f32;
(min, max, mean, sum, count)
}
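/// Smoothed IDF weights per query token, min-max normalized to [0, 1].
///
/// Each document frequency `df` is scored as `ln((N + 1) / (df + 1))`, with
/// `df == 0` scored as `ln(N + 1)` and `df > N` clamped to 0. If every token
/// ends up with the same IDF (or there is only one token), all weights are
/// 1.0. For example, with `N = 1000`, `df = 0` scores `ln(1001) ≈ 6.9` and
/// `df = 1000` scores `ln(1) = 0`, which normalize to 1.0 and 0.0.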
#[must_use]
pub fn idf_weights(token_doc_freqs: &[usize], total_docs: usize) -> Vec<f32> {
if token_doc_freqs.is_empty() || total_docs == 0 {
return vec![];
}
let idf_scores: Vec<f32> = token_doc_freqs
.iter()
.map(|&df| {
if df == 0 {
(total_docs as f32 + 1.0).ln()
} else if df > total_docs {
0.0
} else {
((total_docs as f32 + 1.0) / (df as f32 + 1.0)).ln()
}
})
.collect();
let min_idf = idf_scores.iter().copied().fold(f32::INFINITY, f32::min);
let max_idf = idf_scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
if (max_idf - min_idf).abs() < 1e-9 {
vec![1.0; token_doc_freqs.len()]
} else {
let mut weights = Vec::with_capacity(idf_scores.len());
for &idf in &idf_scores {
weights.push((idf - min_idf) / (max_idf - min_idf));
}
weights
}
}
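/// BM25-style query-token weights, min-max normalized to [0, 1].
///
/// Each token is scored as `idf * tf * (k1 + 1) / (tf + k1)`, using the same
/// smoothed IDF as `idf_weights` and the query-side term frequency `tf`.
/// Negative `k1` is clamped to 0. Mismatched input lengths, empty inputs, or
/// `total_docs == 0` return an empty vector; if every score is identical the
/// weights are all 1.0.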
#[must_use]
pub fn bm25_weights(
token_doc_freqs: &[usize],
token_query_freqs: &[usize],
total_docs: usize,
k1: f32,
) -> Vec<f32> {
if token_doc_freqs.len() != token_query_freqs.len() {
return vec![];
}
if token_doc_freqs.is_empty() || total_docs == 0 {
return vec![];
}
let k1 = k1.max(0.0);
let mut idf_scores = Vec::with_capacity(token_doc_freqs.len());
for &df in token_doc_freqs {
let idf = if df == 0 {
(total_docs as f32 + 1.0).ln()
} else if df > total_docs {
0.0
} else {
((total_docs as f32 + 1.0) / (df as f32 + 1.0)).ln()
};
idf_scores.push(idf);
}
let mut bm25_scores = Vec::with_capacity(idf_scores.len());
for (&idf, &tf) in idf_scores.iter().zip(token_query_freqs.iter()) {
let tf_f32 = tf as f32;
bm25_scores.push(idf * (tf_f32 * (k1 + 1.0)) / (tf_f32 + k1));
}
let min_score = bm25_scores.iter().copied().fold(f32::INFINITY, f32::min);
let max_score = bm25_scores
.iter()
.copied()
.fold(f32::NEG_INFINITY, f32::max);
if (max_score - min_score).abs() < 1e-9 {
vec![1.0; token_doc_freqs.len()]
} else {
let mut weights = Vec::with_capacity(bm25_scores.len());
for &score in &bm25_scores {
weights.push((score - min_score) / (max_score - min_score));
}
weights
}
}
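/// Maps flat, row-major patch indices on a `patches_per_side` x
/// `patches_per_side` grid to `(x, y, width, height)` pixel regions.
///
/// Patch sizes come from integer division of the image dimensions, so any
/// remainder pixels on the right and bottom edges are not covered.
/// Out-of-range patch indices are skipped; zero dimensions or an empty index
/// list return an empty vector.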
#[must_use]
pub fn patches_to_regions(
patch_indices: &[usize],
image_width: usize,
image_height: usize,
patches_per_side: usize,
) -> Vec<(usize, usize, usize, usize)> {
if patch_indices.is_empty() || patches_per_side == 0 || image_width == 0 || image_height == 0 {
return Vec::new();
}
let patch_width = image_width / patches_per_side;
let patch_height = image_height / patches_per_side;
let total_patches = patches_per_side * patches_per_side;
let mut regions = Vec::with_capacity(patch_indices.len());
for &patch_idx in patch_indices {
if patch_idx >= total_patches {
continue;
}
let row = patch_idx / patches_per_side;
let col = patch_idx % patches_per_side;
let x = col * patch_width;
let y = row * patch_height;
regions.push((x, y, patch_width, patch_height));
}
regions
}
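/// Document-token indices to show as a snippet: every aligned token plus up
/// to `context_window` neighbors on each side, deduplicated, sorted, and
/// truncated to `max_tokens`.
///
/// Indices after a matched token are only guarded against integer overflow
/// (`saturating_add`), not against the document length, so callers may still
/// need to clamp them to the actual token count.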
#[must_use]
pub fn extract_snippet_indices(
alignments: &[(usize, usize, f32)],
context_window: usize,
max_tokens: usize,
) -> Vec<usize> {
if alignments.is_empty() {
return Vec::new();
}
let estimated_capacity = alignments.len() * (1 + 2 * context_window.min(10));
let capacity = estimated_capacity.min(max_tokens.max(1));
let mut indices = std::collections::HashSet::with_capacity(capacity);
for (_, doc_idx, _) in alignments {
indices.insert(*doc_idx);
for i in 0..context_window {
let offset = i + 1;
if *doc_idx >= offset {
indices.insert(doc_idx - offset);
}
}
for i in 1..=context_window {
indices.insert(doc_idx.saturating_add(i));
}
}
let mut result: Vec<usize> = indices.into_iter().collect();
result.sort_unstable();
if max_tokens == 0 {
return Vec::new();
}
if result.len() > max_tokens {
result.truncate(max_tokens);
}
result
}
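/// Divides a raw MaxSim score by the maximum query length; for unit-norm
/// embeddings this brings the score into roughly [0, 1].
///
/// Assumes `query_maxlen > 0`; a zero length divides by zero and yields a
/// non-finite result.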
#[inline]
#[must_use]
pub fn normalize_maxsim(score: f32, query_maxlen: u32) -> f32 {
score / query_maxlen as f32
}
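/// Numerically stable softmax over `scores`.
///
/// The finite maximum is subtracted before exponentiation, non-finite inputs
/// contribute 0, and an all-non-finite input (or a zero exponential sum)
/// returns a vector of zeros rather than NaNs.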
#[must_use]
pub fn softmax_scores(scores: &[f32]) -> Vec<f32> {
if scores.is_empty() {
return Vec::new();
}
let max_score = scores
.iter()
.copied()
.filter(|s| s.is_finite())
.fold(f32::NEG_INFINITY, f32::max);
if !max_score.is_finite() {
return vec![0.0; scores.len()];
}
let exp_scores: Vec<f32> = scores
.iter()
.map(|s| {
if s.is_finite() {
(s - max_score).exp()
} else {
0.0
}
})
.collect();
let sum: f32 = exp_scores.iter().sum();
if sum == 0.0 {
return vec![0.0; scores.len()];
}
exp_scores.iter().map(|s| s / sum).collect()
}
#[inline]
#[must_use]
pub fn normalize_maxsim_batch(scores: &[f32], query_maxlen: u32) -> Vec<f32> {
let divisor = query_maxlen as f32;
scores.iter().map(|s| s / divisor).collect()
}
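/// Indices of the `k` largest scores in descending order.
///
/// NaN scores sort after all comparable values, and at most `scores.len()`
/// indices are returned.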
#[must_use]
pub fn top_k_indices(scores: &[f32], k: usize) -> Vec<usize> {
let mut indexed: Vec<(usize, f32)> = scores.iter().copied().enumerate().collect();
indexed.sort_by(|a, b| {
match (a.1.is_nan(), b.1.is_nan()) {
(true, true) => std::cmp::Ordering::Equal,
(true, false) => std::cmp::Ordering::Greater,
(false, true) => std::cmp::Ordering::Less,
(false, false) => b.1.total_cmp(&a.1),
}
});
indexed.into_iter().take(k).map(|(i, _)| i).collect()
}
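/// `dot` over the common prefix of `a` and `b`.
///
/// Unlike `dot`, mismatched lengths never trip a debug assertion: the longer
/// vector is silently truncated, and an empty input scores 0.0.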
#[inline]
#[must_use]
pub fn dot_truncating(a: &[f32], b: &[f32]) -> f32 {
let n = a.len().min(b.len());
if n == 0 {
return 0.0;
}
dot(&a[..n], &b[..n])
}
#[inline]
#[must_use]
pub fn cosine_truncating(a: &[f32], b: &[f32]) -> f32 {
let n = a.len().min(b.len());
if n == 0 {
return 0.0;
}
cosine(&a[..n], &b[..n])
}
#[cfg(not(feature = "innr"))]
#[allow(dead_code)]
#[inline]
#[must_use]
pub(crate) fn dot_portable(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}
#[cfg(all(not(feature = "innr"), target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
unsafe fn dot_avx512(a: &[f32], b: &[f32]) -> f32 {
use std::arch::x86_64::{
__m256, __m512, _mm256_add_ps, _mm256_castps256_ps128, _mm256_extractf128_ps,
_mm512_castps512_ps256, _mm512_extractf32x8_ps, _mm512_fmadd_ps, _mm512_loadu_ps,
_mm512_setzero_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32, _mm_movehl_ps, _mm_shuffle_ps,
};
let n = a.len().min(b.len());
if n == 0 {
return 0.0;
}
let chunks = n / 16;
let remainder = n % 16;
let mut sum: __m512 = _mm512_setzero_ps();
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
for i in 0..chunks {
let offset = i * 16;
let va = _mm512_loadu_ps(a_ptr.add(offset));
let vb = _mm512_loadu_ps(b_ptr.add(offset));
sum = _mm512_fmadd_ps(va, vb, sum);
}
let sum256_lo: __m256 = _mm512_castps512_ps256(sum);
let sum256_hi: __m256 = _mm512_extractf32x8_ps::<1>(sum);
let sum256: __m256 = _mm256_add_ps(sum256_lo, sum256_hi);
let hi = _mm256_extractf128_ps(sum256, 1);
let lo = _mm256_castps256_ps128(sum256);
let sum128 = _mm_add_ps(lo, hi);
let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
let mut result = _mm_cvtss_f32(sum32);
let tail_start = chunks * 16;
for i in 0..remainder {
result += *a.get_unchecked(tail_start + i) * *b.get_unchecked(tail_start + i);
}
result
}
#[cfg(all(not(feature = "innr"), target_arch = "x86_64"))]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn dot_avx2(a: &[f32], b: &[f32]) -> f32 {
use std::arch::x86_64::{
__m256, _mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_loadu_ps,
_mm256_setzero_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32, _mm_movehl_ps, _mm_shuffle_ps,
};
let n = a.len().min(b.len());
if n == 0 {
return 0.0;
}
let chunks = n / 8;
let remainder = n % 8;
let mut sum: __m256 = _mm256_setzero_ps();
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
for i in 0..chunks {
let offset = i * 8;
let va = _mm256_loadu_ps(a_ptr.add(offset));
let vb = _mm256_loadu_ps(b_ptr.add(offset));
sum = _mm256_fmadd_ps(va, vb, sum);
}
let hi = _mm256_extractf128_ps(sum, 1);
let lo = _mm256_castps256_ps128(sum);
let sum128 = _mm_add_ps(lo, hi);
let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
let mut result = _mm_cvtss_f32(sum32);
let tail_start = chunks * 8;
for i in 0..remainder {
result += *a.get_unchecked(tail_start + i) * *b.get_unchecked(tail_start + i);
}
result
}
#[cfg(all(not(feature = "innr"), target_arch = "aarch64"))]
#[target_feature(enable = "neon")]
unsafe fn dot_neon(a: &[f32], b: &[f32]) -> f32 {
use std::arch::aarch64::{float32x4_t, vaddvq_f32, vdupq_n_f32, vfmaq_f32, vld1q_f32};
let n = a.len().min(b.len());
if n == 0 {
return 0.0;
}
let chunks = n / 4;
let remainder = n % 4;
let mut sum: float32x4_t = vdupq_n_f32(0.0);
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
for i in 0..chunks {
let offset = i * 4;
let va = vld1q_f32(a_ptr.add(offset));
let vb = vld1q_f32(b_ptr.add(offset));
sum = vfmaq_f32(sum, va, vb);
}
let mut result = vaddvq_f32(sum);
let tail_start = chunks * 4;
for i in 0..remainder {
result += *a.get_unchecked(tail_start + i) * *b.get_unchecked(tail_start + i);
}
result
}
#[cfg(test)]
mod tests {
use super::*;
fn dot_reference(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}
#[test]
fn test_maxsim_alignments_basic() {
let q1 = [1.0, 0.0];
let q2 = [0.0, 1.0];
let d1 = [0.9, 0.1];
let d2 = [0.1, 0.9];
let d3 = [0.5, 0.5];
let query: &[&[f32]] = &[&q1, &q2];
let doc: &[&[f32]] = &[&d1, &d2, &d3];
let alignments = maxsim_alignments(query, doc);
assert_eq!(alignments.len(), 2);
assert_eq!(alignments[0].0, 0);
assert_eq!(alignments[0].1, 0);
assert!(alignments[0].2 > 0.8);
assert_eq!(alignments[1].0, 1);
assert_eq!(alignments[1].1, 1);
assert!(alignments[1].2 > 0.8);
}
#[test]
fn test_maxsim_alignments_empty() {
let query: &[&[f32]] = &[];
let doc: &[&[f32]] = &[&[1.0, 0.0]];
assert!(maxsim_alignments(query, doc).is_empty());
let query: &[&[f32]] = &[&[1.0, 0.0]];
let doc: &[&[f32]] = &[];
assert!(maxsim_alignments(query, doc).is_empty());
}
#[test]
fn test_maxsim_alignments_consistency_with_maxsim() {
let q1 = [1.0, 0.0];
let q2 = [0.0, 1.0];
let d1 = [0.9, 0.1];
let d2 = [0.1, 0.9];
let query: &[&[f32]] = &[&q1, &q2];
let doc: &[&[f32]] = &[&d1, &d2];
let alignments = maxsim_alignments(query, doc);
let maxsim_score = maxsim(query, doc);
let alignment_sum: f32 = alignments.iter().map(|(_, _, score)| score).sum();
assert!(
(alignment_sum - maxsim_score).abs() < 1e-5,
"Alignment sum {} should equal MaxSim {}",
alignment_sum,
maxsim_score
);
}
#[test]
fn test_highlight_matches_basic() {
let q1 = [1.0, 0.0];
let q2 = [0.0, 1.0];
let d1 = [0.9, 0.1];
let d2 = [0.1, 0.9];
let d3 = [0.5, 0.5];
let query: &[&[f32]] = &[&q1, &q2];
let doc: &[&[f32]] = &[&d1, &d2, &d3];
let highlighted = highlight_matches(query, doc, 0.7);
assert_eq!(highlighted.len(), 2);
assert!(highlighted.contains(&0));
assert!(highlighted.contains(&1));
assert!(!highlighted.contains(&2));
}
#[test]
fn test_highlight_matches_threshold() {
let q1 = [1.0, 0.0];
let d1 = [0.9, 0.1];
let d2 = [0.5, 0.5];
let d3 = [0.1, 0.9];
let query: &[&[f32]] = &[&q1];
let doc: &[&[f32]] = &[&d1, &d2, &d3];
let high = highlight_matches(query, doc, 0.8);
assert_eq!(high, vec![0]);
let low = highlight_matches(query, doc, 0.0);
assert_eq!(low.len(), 3);
}
#[test]
fn test_highlight_matches_empty() {
let query: &[&[f32]] = &[];
let doc: &[&[f32]] = &[&[1.0, 0.0]];
assert!(highlight_matches(query, doc, 0.5).is_empty());
let query: &[&[f32]] = &[&[1.0, 0.0]];
let doc: &[&[f32]] = &[];
assert!(highlight_matches(query, doc, 0.5).is_empty());
}
#[test]
fn test_maxsim_alignments_vecs() {
let query = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
let doc = vec![vec![0.9, 0.1], vec![0.1, 0.9]];
let alignments = maxsim_alignments_vecs(&query, &doc);
assert_eq!(alignments.len(), 2);
assert_eq!(alignments[0].0, 0);
assert_eq!(alignments[0].1, 0);
assert_eq!(alignments[1].0, 1);
assert_eq!(alignments[1].1, 1);
}
#[test]
fn test_highlight_matches_vecs() {
let query = vec![vec![1.0, 0.0]];
let doc = vec![vec![0.9, 0.1], vec![0.5, 0.5]];
let highlighted = highlight_matches_vecs(&query, &doc, 0.7);
assert_eq!(highlighted, vec![0]);
}
#[test]
fn test_dot_basic() {
assert!((dot(&[1.0, 2.0], &[3.0, 4.0]) - 11.0).abs() < 1e-5);
}
#[test]
fn test_dot_empty() {
assert_eq!(dot(&[], &[]), 0.0);
}
#[test]
fn test_dot_truncating_empty() {
assert_eq!(dot_truncating(&[], &[]), 0.0);
assert_eq!(dot_truncating(&[1.0], &[]), 0.0);
assert_eq!(dot_truncating(&[], &[1.0]), 0.0);
}
#[test]
fn test_dot_truncating_mismatched() {
assert!((dot_truncating(&[1.0, 2.0, 3.0], &[4.0, 5.0]) - 14.0).abs() < 1e-5);
}
#[test]
fn test_dot_simd_vs_portable() {
for len in [
0, 1, 3, 4, 5, 7, 8, 9, 15, 16, 17, 31, 32, 33, 48, 64, 100, 128, 256, 512, 1024,
] {
let a: Vec<f32> = (0..len).map(|i| (i as f32) * 0.1).collect();
let b: Vec<f32> = (0..len).map(|i| (i as f32) * 0.2 + 1.0).collect();
let reference = dot_reference(&a, &b);
let simd = dot(&a, &b);
let tolerance = (reference.abs() * 1e-5).max(1e-5);
assert!(
(reference - simd).abs() < tolerance,
"Mismatch at len={}: reference={}, simd={}, diff={}",
len,
reference,
simd,
(reference - simd).abs()
);
}
}
#[cfg(target_arch = "x86_64")]
#[test]
#[ignore]
fn test_avx512_dispatch() {
let a: Vec<f32> = (0..128).map(|i| (i as f32) * 0.1).collect();
let b: Vec<f32> = (0..128).map(|i| (i as f32) * 0.2 + 1.0).collect();
let result = dot(&a, &b);
assert!(result.is_finite());
assert!((result - dot_reference(&a, &b)).abs() < 1e-3);
}
#[test]
fn test_cosine_basic() {
assert!((cosine(&[1.0, 0.0], &[1.0, 0.0]) - 1.0).abs() < 1e-5);
assert!(cosine(&[1.0, 0.0], &[0.0, 1.0]).abs() < 1e-5);
}
#[test]
fn test_cosine_zero_norm() {
assert_eq!(cosine(&[0.0, 0.0], &[1.0, 0.0]), 0.0);
assert_eq!(cosine(&[1.0, 0.0], &[0.0, 0.0]), 0.0);
}
#[test]
fn test_maxsim_basic() {
let q1 = [1.0, 0.0];
let q2 = [0.0, 1.0];
let d1 = [0.5, 0.5];
let d2 = [1.0, 0.0];
let d3 = [0.0, 1.0];
let query: Vec<&[f32]> = vec![&q1, &q2];
let doc: Vec<&[f32]> = vec![&d1, &d2, &d3];
assert!((maxsim(&query, &doc) - 2.0).abs() < 1e-5);
}
#[test]
fn test_maxsim_empty_query() {
let doc: Vec<&[f32]> = vec![&[1.0, 0.0]];
assert_eq!(maxsim(&[], &doc), 0.0);
}
#[test]
fn test_maxsim_empty_doc() {
let q1 = [1.0, 0.0];
let query: Vec<&[f32]> = vec![&q1];
assert_eq!(maxsim(&query, &[]), 0.0);
}
#[test]
fn test_cosine_zero_vector() {
let a = [1.0, 0.0];
let zero = [0.0, 0.0];
assert_eq!(cosine(&a, &zero), 0.0);
assert_eq!(cosine(&zero, &a), 0.0);
assert_eq!(cosine(&zero, &zero), 0.0);
}
#[test]
fn test_maxsim_zero_tokens() {
let q1 = [1.0, 0.0];
let q2 = [0.0, 0.0];
let d1 = [0.9, 0.1];
let query: Vec<&[f32]> = vec![&q1, &q2];
let doc: Vec<&[f32]> = vec![&d1];
let score = maxsim(&query, &doc);
assert!(score >= 0.0);
assert!((score - 0.9).abs() < 0.1);
}
#[test]
fn test_maxsim_nan_handling() {
let q1 = [1.0, 0.0];
let q2 = [f32::NAN, 0.0];
let d1 = [0.9, 0.1];
let query: Vec<&[f32]> = vec![&q1, &q2];
let doc: Vec<&[f32]> = vec![&d1];
let score = maxsim(&query, &doc);
assert!(score.is_finite() || score.is_nan() || score.is_infinite());
}
#[test]
fn test_maxsim_inf_handling() {
let q1 = [1.0, 0.0];
let q2 = [f32::INFINITY, 0.0];
let d1 = [0.9, 0.1];
let query: Vec<&[f32]> = vec![&q1, &q2];
let doc: Vec<&[f32]> = vec![&d1];
let score = maxsim(&query, &doc);
assert!(!score.is_nan());
}
#[test]
fn test_dot_empty_vectors() {
assert_eq!(dot(&[], &[]), 0.0);
assert_eq!(dot_truncating(&[1.0], &[]), 0.0);
assert_eq!(dot_truncating(&[], &[1.0]), 0.0);
}
#[test]
fn test_norm_zero_vector() {
let zero = [0.0, 0.0];
assert_eq!(norm(&zero), 0.0);
}
#[test]
fn test_maxsim_single_token() {
let q1 = [1.0, 0.0];
let d1 = [0.9, 0.1];
let query: Vec<&[f32]> = vec![&q1];
let doc: Vec<&[f32]> = vec![&d1];
let score = maxsim(&query, &doc);
assert!((score - 0.9).abs() < 0.1);
}
#[test]
fn test_maxsim_identical_tokens() {
let q1 = [1.0, 0.0];
let d1 = [1.0, 0.0];
let query: Vec<&[f32]> = vec![&q1];
let doc: Vec<&[f32]> = vec![&d1];
let score = maxsim(&query, &doc);
assert!(score > 0.9);
}
#[test]
fn test_maxsim_vecs() {
let query = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
let doc = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
assert!((maxsim_vecs(&query, &doc) - 2.0).abs() < 1e-5);
}
#[test]
fn test_maxsim_cosine_vecs() {
let query = vec![vec![2.0, 0.0]];
let doc = vec![vec![1.0, 0.0]];
assert!((maxsim_cosine_vecs(&query, &doc) - 1.0).abs() < 1e-5);
}
#[test]
fn idf_weights_empty_input() {
assert_eq!(idf_weights(&[], 1000), Vec::<f32>::new());
assert_eq!(idf_weights(&[100, 200], 0), Vec::<f32>::new());
}
#[test]
fn idf_weights_zero_doc_freq() {
let weights = idf_weights(&[0, 100], 1000);
assert_eq!(weights.len(), 2);
assert!(
weights[0] > weights[1],
"Zero doc freq should get higher weight"
);
}
#[test]
fn idf_weights_all_same_freq() {
let weights = idf_weights(&[100, 100, 100], 1000);
assert_eq!(weights.len(), 3);
assert!((weights[0] - weights[1]).abs() < 1e-5);
assert!((weights[1] - weights[2]).abs() < 1e-5);
assert!(
(weights[0] - 1.0).abs() < 1e-5,
"Uniform weights should be 1.0"
);
}
#[test]
fn idf_weights_doc_freq_exceeds_total() {
let weights = idf_weights(&[100, 2000], 1000);
assert_eq!(weights.len(), 2);
assert!(weights[0] >= weights[1]);
}
#[test]
fn idf_weights_normalized_range() {
let weights = idf_weights(&[10, 100, 1000], 10000);
assert_eq!(weights.len(), 3);
for w in &weights {
assert!(*w >= 0.0 && *w <= 1.0, "Weight {} should be in [0, 1]", w);
}
let min_w = weights.iter().copied().fold(f32::INFINITY, f32::min);
let max_w = weights.iter().copied().fold(f32::NEG_INFINITY, f32::max);
assert!((min_w - 0.0).abs() < 1e-5, "Min weight should be 0.0");
assert!((max_w - 1.0).abs() < 1e-5, "Max weight should be 1.0");
}
#[test]
fn bm25_weights_length_mismatch() {
assert_eq!(bm25_weights(&[100], &[1, 2], 1000, 1.5), Vec::<f32>::new());
assert_eq!(
bm25_weights(&[100, 200], &[1], 1000, 1.5),
Vec::<f32>::new()
);
}
#[test]
fn bm25_weights_empty_input() {
assert_eq!(bm25_weights(&[], &[], 1000, 1.5), Vec::<f32>::new());
assert_eq!(bm25_weights(&[100], &[1], 0, 1.5), Vec::<f32>::new());
}
#[test]
fn bm25_weights_negative_k1() {
let weights = bm25_weights(&[100, 200], &[1, 1], 1000, -1.0);
assert_eq!(weights.len(), 2);
for w in &weights {
assert!(*w >= 0.0 && *w <= 1.0);
}
}
#[test]
fn bm25_weights_zero_k1() {
let weights = bm25_weights(&[100, 200], &[1, 1], 1000, 0.0);
assert_eq!(weights.len(), 2);
for w in &weights {
assert!(*w >= 0.0 && *w <= 1.0);
}
}
#[test]
fn bm25_weights_higher_tf_higher_weight() {
let weights = bm25_weights(&[100, 100], &[1, 3], 1000, 1.5);
assert_eq!(weights.len(), 2);
assert!(
weights[1] > weights[0],
"Higher tf should give higher weight"
);
}
#[test]
fn bm25_weights_normalized_range() {
let weights = bm25_weights(&[10, 100, 1000], &[1, 2, 1], 10000, 1.5);
assert_eq!(weights.len(), 3);
for w in &weights {
assert!(*w >= 0.0 && *w <= 1.0, "Weight {} should be in [0, 1]", w);
}
}
#[test]
fn patches_to_regions_empty() {
assert_eq!(patches_to_regions(&[], 1024, 768, 32), vec![]);
}
#[test]
fn patches_to_regions_zero_patches_per_side() {
assert_eq!(patches_to_regions(&[0, 1], 1024, 768, 0), vec![]);
}
#[test]
fn patches_to_regions_zero_dimensions() {
assert_eq!(patches_to_regions(&[0, 1], 0, 768, 32), vec![]);
assert_eq!(patches_to_regions(&[0, 1], 1024, 0, 32), vec![]);
}
#[test]
fn patches_to_regions_invalid_index() {
let regions = patches_to_regions(&[0, 1024, 1], 1024, 768, 32);
assert_eq!(regions.len(), 2, "Should filter out invalid patch index");
}
#[test]
fn patches_to_regions_valid() {
let regions = patches_to_regions(&[0, 1, 32, 33], 1024, 768, 32);
assert_eq!(regions.len(), 4);
assert_eq!(regions[0], (0, 0, 32, 24));
assert_eq!(regions[1], (32, 0, 32, 24));
assert_eq!(regions[2], (0, 24, 32, 24));
assert_eq!(regions[3], (32, 24, 32, 24));
}
#[test]
fn patches_to_regions_bounds_check() {
let regions = patches_to_regions(&[0, 100, 500, 1023], 1024, 768, 32);
for (x, y, w, h) in &regions {
assert!(*x + *w <= 1024, "Region x+width should be <= image width");
assert!(*y + *h <= 768, "Region y+height should be <= image height");
}
}
#[test]
fn extract_snippet_indices_empty() {
assert_eq!(extract_snippet_indices(&[], 2, 10), Vec::<usize>::new());
}
#[test]
fn extract_snippet_indices_max_tokens_zero() {
let alignments = vec![(0, 5, 0.9), (1, 10, 0.8)];
assert_eq!(
extract_snippet_indices(&alignments, 1, 0),
Vec::<usize>::new()
);
}
#[test]
fn extract_snippet_indices_no_context() {
let alignments = vec![(0, 5, 0.9), (1, 10, 0.8)];
let snippet = extract_snippet_indices(&alignments, 0, 100);
assert_eq!(snippet, vec![5, 10], "Should include only matched tokens");
}
#[test]
fn extract_snippet_indices_with_context() {
let alignments = vec![(0, 5, 0.9)];
let snippet = extract_snippet_indices(&alignments, 2, 100);
assert!(snippet.contains(&5), "Should include matched token");
assert!(
snippet.contains(&3) || snippet.contains(&4),
"Should include context before"
);
assert!(
snippet.contains(&6) || snippet.contains(&7),
"Should include context after"
);
}
#[test]
fn extract_snippet_indices_max_tokens_limit() {
let alignments: Vec<(usize, usize, f32)> = (0..100).map(|i| (0, i, 0.9)).collect();
let snippet = extract_snippet_indices(&alignments, 2, 10);
assert_eq!(snippet.len(), 10, "Should respect max_tokens limit");
}
#[test]
fn extract_snippet_indices_sorted() {
let alignments = vec![(0, 10, 0.9), (1, 5, 0.8), (2, 20, 0.7)];
let snippet = extract_snippet_indices(&alignments, 1, 100);
for i in 1..snippet.len() {
assert!(snippet[i] >= snippet[i - 1], "Should be sorted");
}
}
#[test]
fn extract_snippet_indices_context_bounds() {
let alignments = vec![(0, 0, 0.9)];
let snippet = extract_snippet_indices(&alignments, 5, 100);
assert!(snippet.contains(&0), "Should include the matched token at index 0");
assert!(
snippet.iter().all(|&idx| idx <= 5),
"Context around index 0 should neither underflow nor exceed the window"
);
}
#[test]
fn extract_snippet_indices_deduplication() {
let alignments = vec![(0, 5, 0.9), (1, 5, 0.8), (2, 5, 0.7)];
let snippet = extract_snippet_indices(&alignments, 2, 100);
let count_5 = snippet.iter().filter(|&&x| x == 5).count();
assert_eq!(count_5, 1, "Should deduplicate same token");
}
#[test]
fn test_maxsim_cosine_basic() {
let q1 = [1.0, 0.0];
let d1 = [1.0, 0.0];
let d2 = [0.0, 1.0];
let query: Vec<&[f32]> = vec![&q1];
let doc: Vec<&[f32]> = vec![&d1, &d2];
assert!((maxsim_cosine(&query, &doc) - 1.0).abs() < 1e-5);
}
#[test]
fn test_maxsim_cosine_empty_doc() {
let q1 = [1.0, 0.0];
let query: Vec<&[f32]> = vec![&q1];
assert_eq!(maxsim_cosine(&query, &[]), 0.0);
}
#[test]
fn cosine_zero_norm_returns_zero_not_nan() {
let zero = [0.0, 0.0];
let nonzero = [1.0, 2.0];
let c1 = cosine(&zero, &nonzero);
let c2 = cosine(&nonzero, &zero);
let c3 = cosine(&zero, &zero);
assert_eq!(c1, 0.0, "cosine(zero, x) should be 0, got {}", c1);
assert_eq!(c2, 0.0, "cosine(x, zero) should be 0, got {}", c2);
assert_eq!(c3, 0.0, "cosine(zero, zero) should be 0, got {}", c3);
assert!(!c1.is_nan(), "should not return NaN");
}
#[test]
fn cosine_near_zero_norm_stable() {
let tiny = [1e-20, 0.0];
let normal = [1.0, 0.0];
let c = cosine(&tiny, &normal);
assert!(c.is_finite(), "cosine with tiny norm should be finite");
assert_eq!(c, 0.0, "tiny norm should return 0.0");
let small = [1e-8, 0.0];
let c2 = cosine(&small, &normal);
assert!(c2.is_finite(), "cosine with small norm should be finite");
}
#[test]
fn dot_exact_orthogonal() {
let a = [1.0, 0.0, 0.0];
let b = [0.0, 1.0, 0.0];
assert_eq!(dot(&a, &b), 0.0);
}
#[test]
fn dot_exact_parallel() {
let a = [1.0, 0.0];
let b = [1.0, 0.0];
assert!((dot(&a, &b) - 1.0).abs() < 1e-9);
}
#[test]
fn maxsim_single_query_single_doc() {
let q = [1.0, 2.0, 3.0];
let d = [4.0, 5.0, 6.0];
let query: Vec<&[f32]> = vec![&q];
let doc: Vec<&[f32]> = vec![&d];
let expected_dot = 1.0 * 4.0 + 2.0 * 5.0 + 3.0 * 6.0;
let actual = maxsim(&query, &doc);
assert!(
(actual - expected_dot).abs() < 1e-5,
"expected {}, got {}",
expected_dot,
actual
);
}
#[test]
fn maxsim_sum_of_maxes() {
let q1 = [1.0, 0.0];
let q2 = [0.0, 1.0];
let d1 = [0.5, 0.0];
let d2 = [0.0, 0.8];
let query: Vec<&[f32]> = vec![&q1, &q2];
let doc: Vec<&[f32]> = vec![&d1, &d2];
let expected = 0.5 + 0.8;
let actual = maxsim(&query, &doc);
assert!(
(actual - expected).abs() < 1e-5,
"expected {}, got {}",
expected,
actual
);
}
#[test]
fn norm_exact_values() {
assert!((norm(&[3.0, 4.0]) - 5.0).abs() < 1e-9, "3-4-5 triangle");
assert!((norm(&[1.0, 0.0]) - 1.0).abs() < 1e-9, "unit x");
assert!((norm(&[0.0, 0.0]) - 0.0).abs() < 1e-9, "zero vector");
}
#[test]
fn maxsim_weighted_basic() {
let q1 = [1.0, 0.0];
let q2 = [0.0, 1.0];
let d1 = [1.0, 0.0];
let d2 = [0.0, 1.0];
let query: Vec<&[f32]> = vec![&q1, &q2];
let doc: Vec<&[f32]> = vec![&d1, &d2];
let unweighted = maxsim(&query, &doc);
assert!((unweighted - 2.0).abs() < 1e-5);
let equal_weights = [1.0, 1.0];
let weighted_equal = maxsim_weighted(&query, &doc, &equal_weights);
assert!((weighted_equal - 2.0).abs() < 1e-5);
let weights = [2.0, 1.0];
let weighted = maxsim_weighted(&query, &doc, &weights);
assert!((weighted - 3.0).abs() < 1e-5);
}
#[test]
fn maxsim_weighted_single_token() {
let q = [1.0, 0.0, 0.0];
let d = [1.0, 0.0, 0.0];
let query: Vec<&[f32]> = vec![&q];
let doc: Vec<&[f32]> = vec![&d];
let weights = [2.0];
let score = maxsim_weighted(&query, &doc, &weights);
assert!((score - 2.0).abs() < 1e-5);
}
#[test]
fn maxsim_weighted_empty() {
let q1 = [1.0, 0.0];
let query: Vec<&[f32]> = vec![&q1];
let empty: Vec<&[f32]> = vec![];
let weights = [1.0];
assert_eq!(maxsim_weighted(&query, &empty, &weights), 0.0);
assert_eq!(maxsim_weighted(&[], &query, &weights), 0.0);
}
#[test]
fn maxsim_weighted_missing_weights_default_to_one() {
let q1 = [1.0, 0.0];
let q2 = [0.0, 1.0];
let d = [1.0, 1.0];
let query: Vec<&[f32]> = vec![&q1, &q2];
let doc: Vec<&[f32]> = vec![&d];
let weights = [2.0];
let score = maxsim_weighted(&query, &doc, &weights);
assert!((score - 3.0).abs() < 1e-5);
}
#[test]
fn maxsim_weighted_zero_weight_ignores_token() {
let q1 = [1.0, 0.0];
let q2 = [0.0, 1.0];
let d = [1.0, 0.0];
let query: Vec<&[f32]> = vec![&q1, &q2];
let doc: Vec<&[f32]> = vec![&d];
let weights = [1.0, 0.0];
let score = maxsim_weighted(&query, &doc, &weights);
assert!((score - 1.0).abs() < 1e-5);
}
#[test]
fn maxsim_cosine_weighted_basic() {
let q1 = [2.0, 0.0];
let q2 = [0.0, 3.0];
let d = [1.0, 0.0];
let query: Vec<&[f32]> = vec![&q1, &q2];
let doc: Vec<&[f32]> = vec![&d];
let weights = [2.0, 1.0];
let score = maxsim_cosine_weighted(&query, &doc, &weights);
assert!((score - 2.0).abs() < 1e-5);
}
#[test]
fn maxsim_weighted_vecs_convenience() {
let query = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
let doc = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
let weights = [2.0, 0.5];
let score = maxsim_weighted_vecs(&query, &doc, &weights);
assert!((score - 2.5).abs() < 1e-5);
}
#[test]
fn normalize_maxsim_basic() {
assert!((normalize_maxsim(16.0, 32) - 0.5).abs() < 1e-9);
assert!((normalize_maxsim(32.0, 32) - 1.0).abs() < 1e-9);
assert!((normalize_maxsim(0.0, 32) - 0.0).abs() < 1e-9);
}
#[test]
fn normalize_maxsim_batch_basic() {
let scores = vec![16.0, 32.0, 8.0];
let normalized = normalize_maxsim_batch(&scores, 32);
assert!((normalized[0] - 0.5).abs() < 1e-9);
assert!((normalized[1] - 1.0).abs() < 1e-9);
assert!((normalized[2] - 0.25).abs() < 1e-9);
}
#[test]
fn softmax_scores_sums_to_one() {
let scores = vec![2.0, 1.0, 0.5, 0.1];
let probs = softmax_scores(&scores);
let sum: f32 = probs.iter().sum();
assert!(
(sum - 1.0).abs() < 1e-5,
"softmax should sum to 1, got {}",
sum
);
}
#[test]
fn softmax_scores_preserves_order() {
let scores = vec![3.0, 1.0, 2.0];
let probs = softmax_scores(&scores);
assert!(probs[0] > probs[2]);
assert!(probs[2] > probs[1]);
}
#[test]
fn softmax_scores_empty() {
let probs = softmax_scores(&[]);
assert!(probs.is_empty());
}
#[test]
fn softmax_scores_single() {
let probs = softmax_scores(&[5.0]);
assert_eq!(probs.len(), 1);
assert!((probs[0] - 1.0).abs() < 1e-9);
}
#[test]
fn softmax_scores_handles_nan() {
let scores = vec![f32::NAN, 1.0, 2.0];
let probs = softmax_scores(&scores);
assert!(probs[0].is_finite());
let sum: f32 = probs.iter().sum();
assert!((sum - 1.0).abs() < 1e-5);
}
#[test]
fn softmax_scores_large_values_stable() {
let scores = vec![1000.0, 1001.0, 999.0];
let probs = softmax_scores(&scores);
assert!(probs.iter().all(|p| p.is_finite()));
let sum: f32 = probs.iter().sum();
assert!((sum - 1.0).abs() < 1e-5);
}
#[test]
fn top_k_indices_basic() {
let scores = vec![0.5, 0.9, 0.1, 0.7, 0.3];
let top2 = top_k_indices(&scores, 2);
assert_eq!(top2, vec![1, 3]);
}
#[test]
fn top_k_indices_with_ties() {
let scores = vec![0.5, 0.5, 0.5];
let top2 = top_k_indices(&scores, 2);
assert_eq!(top2.len(), 2);
}
#[test]
fn top_k_indices_k_larger_than_len() {
let scores = vec![0.5, 0.9];
let top5 = top_k_indices(&scores, 5);
assert_eq!(top5.len(), 2);
}
#[test]
fn top_k_indices_with_nan() {
let scores = vec![0.5, f32::NAN, 0.9, 0.7];
let top2 = top_k_indices(&scores, 2);
assert_eq!(top2, vec![2, 3]);
}
#[test]
fn top_k_indices_empty() {
let top = top_k_indices(&[], 5);
assert!(top.is_empty());
}
#[test]
fn cosine_zero_norm_returns_zero() {
let zero = [0.0, 0.0, 0.0];
let nonzero = [1.0, 2.0, 3.0];
assert_eq!(cosine(&zero, &nonzero), 0.0);
assert_eq!(cosine(&nonzero, &zero), 0.0);
assert_eq!(cosine(&zero, &zero), 0.0);
}
#[test]
fn cosine_near_zero_norm_returns_zero() {
let tiny = [1e-10, 1e-10, 1e-10];
let nonzero = [1.0, 2.0, 3.0];
assert_eq!(cosine(&tiny, &nonzero), 0.0);
}
#[test]
fn cosine_at_threshold_boundary() {
let at_threshold = [1e-9, 0.0, 0.0];
let nonzero = [1.0, 0.0, 0.0];
assert_eq!(cosine(&at_threshold, &nonzero), 0.0);
let above_threshold = [1e-8, 0.0, 0.0];
let result = cosine(&above_threshold, &nonzero);
assert!(
result.is_finite(),
"Above threshold should be finite: {}",
result
);
}
#[test]
fn maxsim_cosine_vecs_not_one() {
let query = vec![vec![1.0, 0.0, 0.0]];
let doc = vec![vec![0.0, 1.0, 0.0]];
let score = maxsim_cosine_vecs(&query, &doc);
assert!(
score.abs() < 0.01,
"Orthogonal vectors should have ~0 cosine, got {}",
score
);
}
#[test]
fn maxsim_cosine_vecs_matches_slice_version() {
let query = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
let doc = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]];
let q_refs: Vec<&[f32]> = query.iter().map(|v| v.as_slice()).collect();
let d_refs: Vec<&[f32]> = doc.iter().map(|v| v.as_slice()).collect();
let vec_score = maxsim_cosine_vecs(&query, &doc);
let slice_score = maxsim_cosine(&q_refs, &d_refs);
assert!(
(vec_score - slice_score).abs() < 1e-5,
"Vec and slice versions should match: {} vs {}",
vec_score,
slice_score
);
}
#[test]
fn maxsim_cosine_weighted_empty_returns_zero() {
let empty_q: Vec<&[f32]> = vec![];
let empty_d: Vec<&[f32]> = vec![];
let some: Vec<&[f32]> = vec![&[1.0, 0.0]];
let weights = [1.0];
assert_eq!(maxsim_cosine_weighted(&empty_q, &some, &weights), 0.0);
assert_eq!(maxsim_cosine_weighted(&some, &empty_d, &weights), 0.0);
}
#[test]
fn dot_short_vector_uses_portable() {
let short_a: Vec<f32> = (0..8).map(|i| i as f32).collect();
let short_b: Vec<f32> = (0..8).map(|i| (i + 1) as f32).collect();
let result = dot(&short_a, &short_b);
assert!(
(result - 168.0).abs() < 1e-3,
"Short vector dot: {} != 168",
result
);
}
#[test]
fn dot_exactly_min_dim() {
let a: Vec<f32> = (0..16).map(|i| i as f32).collect();
let b: Vec<f32> = vec![1.0; 16];
let result = dot(&a, &b);
assert!(
(result - 120.0).abs() < 1e-3,
"MIN_DIM dot: {} != 120",
result
);
}
#[test]
fn dot_truncating_mismatched_lengths() {
let a = [1.0, 2.0, 3.0, 4.0, 5.0];
let b = [1.0, 1.0, 1.0];
let result = dot_truncating(&a, &b);
assert!(
(result - 6.0).abs() < 1e-5,
"Mismatched len dot_truncating: {} != 6",
result
);
}
#[test]
#[should_panic]
#[cfg(debug_assertions)]
fn dot_panics_on_mismatch_in_debug() {
let a = [1.0, 2.0, 3.0, 4.0, 5.0];
let b = [1.0, 1.0, 1.0];
let _ = dot(&a, &b);
}
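// Illustrative end-to-end sketch: score two documents against one query,
// normalize by the query length, and rank them. Uses only functions from
// this module; the token embeddings are made up for the example.
#[test]
fn end_to_end_ranking_example() {
let query = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
let doc_a = vec![vec![1.0, 0.0], vec![0.0, 1.0]]; // matches both query tokens
let doc_b = vec![vec![0.5, 0.0]]; // weak match for the first query token only
let scores = maxsim_batch(&query, &[doc_a, doc_b]);
// doc_a: 1.0 + 1.0 = 2.0; doc_b: 0.5 + 0.0 = 0.5
assert!((scores[0] - 2.0).abs() < 1e-5);
assert!((scores[1] - 0.5).abs() < 1e-5);
let normalized = normalize_maxsim_batch(&scores, 2);
assert!((normalized[0] - 1.0).abs() < 1e-5);
assert_eq!(top_k_indices(&scores, 1), vec![0]);
}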
}
#[cfg(test)]
mod proptests {
use super::*;
use proptest::prelude::*;
fn dot_reference(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}
fn arb_vec(len: usize) -> impl Strategy<Value = Vec<f32>> {
proptest::collection::vec(-10.0f32..10.0, len)
}
proptest! {
#[test]
fn dot_simd_matches_portable(a in arb_vec(128), b in arb_vec(128)) {
let simd_result = dot(&a, &b);
let reference_result = dot_reference(&a, &b);
prop_assert!(
(simd_result - reference_result).abs() < 1e-3,
"SIMD {} != reference {}",
simd_result,
reference_result
);
}
#[test]
fn dot_commutative(a in arb_vec(64), b in arb_vec(64)) {
let ab = dot(&a, &b);
let ba = dot(&b, &a);
prop_assert!((ab - ba).abs() < 1e-5);
}
#[test]
fn cosine_bounded(
a in arb_vec(32).prop_filter("non-zero", |v| v.iter().any(|x| x.abs() > 1e-6)),
b in arb_vec(32).prop_filter("non-zero", |v| v.iter().any(|x| x.abs() > 1e-6))
) {
let c = cosine(&a, &b);
prop_assert!((-1.0 - 1e-5..=1.0 + 1e-5).contains(&c), "cosine {} out of bounds", c);
}
#[test]
fn cosine_commutative(a in arb_vec(32), b in arb_vec(32)) {
let ab = cosine(&a, &b);
let ba = cosine(&b, &a);
prop_assert!((ab - ba).abs() < 1e-5);
}
#[test]
fn maxsim_nonnegative_inputs(
q_data in proptest::collection::vec(arb_vec(16), 1..5),
d_data in proptest::collection::vec(arb_vec(16), 1..5)
) {
let q_pos: Vec<Vec<f32>> = q_data.iter()
.map(|v| v.iter().map(|x| x.abs()).collect())
.collect();
let d_pos: Vec<Vec<f32>> = d_data.iter()
.map(|v| v.iter().map(|x| x.abs()).collect())
.collect();
let q_refs: Vec<&[f32]> = q_pos.iter().map(|v| v.as_slice()).collect();
let d_refs: Vec<&[f32]> = d_pos.iter().map(|v| v.as_slice()).collect();
let score = maxsim(&q_refs, &d_refs);
prop_assert!(score >= 0.0, "maxsim {} should be non-negative", score);
}
#[test]
fn maxsim_empty_returns_zero(q_data in proptest::collection::vec(arb_vec(8), 0..3)) {
let q_refs: Vec<&[f32]> = q_data.iter().map(|v| v.as_slice()).collect();
let empty: Vec<&[f32]> = vec![];
prop_assert_eq!(maxsim(&empty, &q_refs), 0.0);
prop_assert_eq!(maxsim(&q_refs, &empty), 0.0);
}
#[test]
fn dot_self_is_squared_norm(v in arb_vec(32)) {
let dot_self = dot(&v, &v);
let n = norm(&v);
let squared_norm = n * n;
let tolerance = (squared_norm.abs() * 1e-4).max(1e-4);
prop_assert!(
(dot_self - squared_norm).abs() < tolerance,
"dot(v,v) = {} but norm²= {}",
dot_self,
squared_norm
);
}
#[test]
fn cosine_self_is_one(v in arb_vec(16).prop_filter("non-zero", |v| norm(v) > 1e-6)) {
let c = cosine(&v, &v);
prop_assert!(
(c - 1.0).abs() < 1e-5,
"cosine(v, v) = {} should be 1",
c
);
}
#[test]
fn norm_nonnegative(v in arb_vec(64)) {
let n = norm(&v);
prop_assert!(n >= 0.0, "norm {} should be non-negative", n);
}
#[test]
fn norm_scaling(v in arb_vec(16), alpha in -10.0f32..10.0) {
let scaled: Vec<f32> = v.iter().map(|x| x * alpha).collect();
let n_v = norm(&v);
let n_scaled = norm(&scaled);
let expected = alpha.abs() * n_v;
prop_assert!(
(n_scaled - expected).abs() < 1e-4,
"||αv|| = {} but |α|||v|| = {}",
n_scaled,
expected
);
}
#[test]
fn dot_bilinear(a in arb_vec(16), b in arb_vec(16), alpha in -5.0f32..5.0) {
let scaled_a: Vec<f32> = a.iter().map(|x| x * alpha).collect();
let dot_scaled = dot(&scaled_a, &b);
let expected = alpha * dot(&a, &b);
prop_assert!(
(dot_scaled - expected).abs() < 1e-3,
"dot(αa, b) = {} but α·dot(a, b) = {}",
dot_scaled,
expected
);
}
#[test]
fn cauchy_schwarz(a in arb_vec(32), b in arb_vec(32)) {
let d = dot(&a, &b).abs();
let bound = norm(&a) * norm(&b);
prop_assert!(
d <= bound + 1e-4,
"|dot(a,b)| = {} should be <= ||a||·||b|| = {}",
d,
bound
);
}
#[test]
fn maxsim_scales_with_query_count(n_query in 1usize..5, dim in 4usize..8) {
let token: Vec<f32> = (0..dim).map(|i| (i as f32 + 1.0) * 0.1).collect();
let query: Vec<Vec<f32>> = vec![token.clone(); n_query];
let doc = [token.clone()];
let q_refs: Vec<&[f32]> = query.iter().map(Vec::as_slice).collect();
let d_refs: Vec<&[f32]> = doc.iter().map(Vec::as_slice).collect();
let score = maxsim(&q_refs, &d_refs);
let expected = n_query as f32 * dot(&token, &token);
prop_assert!(
(score - expected).abs() < 1e-4,
"MaxSim should scale linearly: {} vs expected {}",
score,
expected
);
}
#[test]
fn maxsim_cosine_bounded_by_query_count(n_query in 1usize..4, n_doc in 1usize..4) {
let query: Vec<Vec<f32>> = (0..n_query)
.map(|i| {
let mut v = vec![0.0f32; 8];
v[i % 8] = 1.0;
v
})
.collect();
let doc: Vec<Vec<f32>> = (0..n_doc)
.map(|i| {
let mut v = vec![0.0f32; 8];
v[(i + 1) % 8] = 1.0;
v
})
.collect();
let q_refs: Vec<&[f32]> = query.iter().map(Vec::as_slice).collect();
let d_refs: Vec<&[f32]> = doc.iter().map(Vec::as_slice).collect();
let score = maxsim_cosine(&q_refs, &d_refs);
let upper_bound = n_query as f32;
prop_assert!(
score <= upper_bound + 1e-5,
"MaxSim cosine {} should be <= {}",
score,
upper_bound
);
}
#[test]
fn maxsim_weighted_unit_weights_equals_unweighted(
q_data in proptest::collection::vec(arb_vec(8), 1..5),
d_data in proptest::collection::vec(arb_vec(8), 1..5)
) {
let q_refs: Vec<&[f32]> = q_data.iter().map(Vec::as_slice).collect();
let d_refs: Vec<&[f32]> = d_data.iter().map(Vec::as_slice).collect();
let weights: Vec<f32> = vec![1.0; q_data.len()];
let unweighted = maxsim(&q_refs, &d_refs);
let weighted = maxsim_weighted(&q_refs, &d_refs, &weights);
prop_assert!(
(unweighted - weighted).abs() < 1e-5,
"unweighted {} != weighted {}",
unweighted,
weighted
);
}
#[test]
fn maxsim_alignments_sum_equals_maxsim(
q_data in proptest::collection::vec(arb_vec(8), 1..5),
d_data in proptest::collection::vec(arb_vec(8), 1..5)
) {
let q_refs: Vec<&[f32]> = q_data.iter().map(Vec::as_slice).collect();
let d_refs: Vec<&[f32]> = d_data.iter().map(Vec::as_slice).collect();
let alignments = maxsim_alignments(&q_refs, &d_refs);
let maxsim_score = maxsim(&q_refs, &d_refs);
let alignment_sum: f32 = alignments.iter().map(|(_, _, s)| s).sum();
prop_assert!(
(alignment_sum - maxsim_score).abs() < 1e-4,
"Alignment sum {} != MaxSim {}",
alignment_sum,
maxsim_score
);
}
#[test]
fn maxsim_alignments_count_equals_query_tokens(
q_data in proptest::collection::vec(arb_vec(8), 1..5),
d_data in proptest::collection::vec(arb_vec(8), 1..5)
) {
let q_refs: Vec<&[f32]> = q_data.iter().map(Vec::as_slice).collect();
let d_refs: Vec<&[f32]> = d_data.iter().map(Vec::as_slice).collect();
let alignments = maxsim_alignments(&q_refs, &d_refs);
prop_assert_eq!(
alignments.len(),
q_data.len(),
"Should have one alignment per query token"
);
}
#[test]
fn maxsim_alignments_query_indices_sequential(
q_data in proptest::collection::vec(arb_vec(8), 1..5),
d_data in proptest::collection::vec(arb_vec(8), 1..5)
) {
let q_refs: Vec<&[f32]> = q_data.iter().map(Vec::as_slice).collect();
let d_refs: Vec<&[f32]> = d_data.iter().map(Vec::as_slice).collect();
let alignments = maxsim_alignments(&q_refs, &d_refs);
for (i, (q_idx, _, _)) in alignments.iter().enumerate() {
prop_assert_eq!(
*q_idx, i,
"Query index {} should match position {}",
q_idx, i
);
}
}
#[test]
fn maxsim_alignments_doc_indices_valid(
q_data in proptest::collection::vec(arb_vec(8), 1..5),
d_data in proptest::collection::vec(arb_vec(8), 1..5)
) {
let q_refs: Vec<&[f32]> = q_data.iter().map(Vec::as_slice).collect();
let d_refs: Vec<&[f32]> = d_data.iter().map(Vec::as_slice).collect();
let alignments = maxsim_alignments(&q_refs, &d_refs);
for (_, d_idx, _) in &alignments {
prop_assert!(
*d_idx < d_data.len(),
"Doc index {} out of bounds (len={})",
d_idx,
d_data.len()
);
}
}
#[test]
fn highlight_matches_valid_indices(
q_data in proptest::collection::vec(arb_vec(8), 1..5),
d_data in proptest::collection::vec(arb_vec(8), 1..5),
threshold in -1.0f32..2.0
) {
let q_refs: Vec<&[f32]> = q_data.iter().map(Vec::as_slice).collect();
let d_refs: Vec<&[f32]> = d_data.iter().map(Vec::as_slice).collect();
let highlighted = highlight_matches(&q_refs, &d_refs, threshold);
for &idx in &highlighted {
prop_assert!(
idx < d_data.len(),
"Highlighted index {} out of bounds (len={})",
idx,
d_data.len()
);
}
for i in 1..highlighted.len() {
prop_assert!(
highlighted[i - 1] < highlighted[i],
"Highlighted indices should be sorted"
);
}
}
#[test]
fn highlight_matches_high_threshold(
q_data in proptest::collection::vec(arb_vec(8), 1..5),
d_data in proptest::collection::vec(arb_vec(8), 1..5)
) {
let q_refs: Vec<&[f32]> = q_data.iter().map(Vec::as_slice).collect();
let d_refs: Vec<&[f32]> = d_data.iter().map(Vec::as_slice).collect();
let high_threshold = highlight_matches(&q_refs, &d_refs, 10.0);
let low_threshold = highlight_matches(&q_refs, &d_refs, -10.0);
prop_assert!(
high_threshold.len() <= low_threshold.len(),
"High threshold should return fewer or equal matches"
);
}
#[test]
fn maxsim_weighted_uniform_scaling(
q_data in proptest::collection::vec(arb_vec(8), 1..4),
d_data in proptest::collection::vec(arb_vec(8), 1..4),
scale in 0.1f32..5.0
) {
let q_refs: Vec<&[f32]> = q_data.iter().map(Vec::as_slice).collect();
let d_refs: Vec<&[f32]> = d_data.iter().map(Vec::as_slice).collect();
let weights: Vec<f32> = vec![scale; q_data.len()];
let unweighted = maxsim(&q_refs, &d_refs);
let weighted = maxsim_weighted(&q_refs, &d_refs, &weights);
let expected = scale * unweighted;
let tolerance = (expected.abs() * 1e-4).max(1e-4);
prop_assert!(
(weighted - expected).abs() < tolerance,
"weighted {} != scale * unweighted {}",
weighted,
expected
);
}
#[test]
fn maxsim_weighted_zero_weights_returns_zero(
q_data in proptest::collection::vec(arb_vec(8), 1..4),
d_data in proptest::collection::vec(arb_vec(8), 1..4)
) {
let q_refs: Vec<&[f32]> = q_data.iter().map(Vec::as_slice).collect();
let d_refs: Vec<&[f32]> = d_data.iter().map(Vec::as_slice).collect();
let weights: Vec<f32> = vec![0.0; q_data.len()];
let weighted = maxsim_weighted(&q_refs, &d_refs, &weights);
prop_assert!(
weighted.abs() < 1e-9,
"zero weights should give 0, got {}",
weighted
);
}
#[test]
fn maxsim_weighted_empty_returns_zero(
weights in proptest::collection::vec(0.1f32..2.0, 0..5)
) {
let empty_q: Vec<&[f32]> = vec![];
let empty_d: Vec<&[f32]> = vec![];
let some_q: Vec<&[f32]> = vec![&[1.0, 0.0]];
prop_assert_eq!(maxsim_weighted(&empty_q, &some_q, &weights), 0.0);
prop_assert_eq!(maxsim_weighted(&some_q, &empty_d, &weights), 0.0);
}
#[test]
fn softmax_sums_to_one(scores in proptest::collection::vec(-100.0f32..100.0, 1..20)) {
let probs = softmax_scores(&scores);
let sum: f32 = probs.iter().sum();
prop_assert!(
(sum - 1.0).abs() < 1e-5,
"softmax sum {} should be 1",
sum
);
}
#[test]
fn softmax_preserves_order(scores in proptest::collection::vec(-10.0f32..10.0, 2..10)) {
let probs = softmax_scores(&scores);
for i in 0..scores.len() {
for j in (i + 1)..scores.len() {
if scores[i] > scores[j] {
prop_assert!(
probs[i] >= probs[j],
"softmax should preserve order: score[{}]={} > score[{}]={}, but prob {} <= {}",
i, scores[i], j, scores[j], probs[i], probs[j]
);
}
}
}
}
#[test]
fn softmax_outputs_bounded(scores in proptest::collection::vec(-50.0f32..50.0, 1..15)) {
let probs = softmax_scores(&scores);
for (i, p) in probs.iter().enumerate() {
prop_assert!(
*p >= 0.0 && *p <= 1.0,
"softmax[{}] = {} should be in [0, 1]",
i, p
);
}
}
#[test]
fn normalize_is_linear(a in 0.0f32..100.0, b in 0.0f32..100.0, maxlen in 1u32..100) {
let sum_norm = normalize_maxsim(a + b, maxlen);
let norm_sum = normalize_maxsim(a, maxlen) + normalize_maxsim(b, maxlen);
prop_assert!(
(sum_norm - norm_sum).abs() < 1e-5,
"normalize should be linear: {} vs {}",
sum_norm,
norm_sum
);
}
#[test]
fn top_k_returns_correct_count(
scores in proptest::collection::vec(-100.0f32..100.0, 1..20),
k in 1usize..25
) {
let result = top_k_indices(&scores, k);
let expected_len = k.min(scores.len());
prop_assert_eq!(
result.len(),
expected_len,
"top_k should return min(k, len)"
);
}
#[test]
fn top_k_indices_valid(
scores in proptest::collection::vec(-100.0f32..100.0, 1..20),
k in 1usize..25
) {
let result = top_k_indices(&scores, k);
for idx in &result {
prop_assert!(
*idx < scores.len(),
"index {} should be < len {}",
idx,
scores.len()
);
}
}
#[test]
fn top_k_indices_unique(
scores in proptest::collection::vec(-100.0f32..100.0, 1..20),
k in 1usize..25
) {
let result = top_k_indices(&scores, k);
let mut seen = std::collections::HashSet::new();
for idx in &result {
prop_assert!(
seen.insert(*idx),
"index {} appears twice",
idx
);
}
}
#[test]
fn cosine_divides_by_norms(dim in 2usize..16) {
let a: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.1).sin()).collect();
let b: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.1).cos()).collect();
let cos = cosine(&a, &b);
prop_assert!(
(-1.1..=1.1).contains(&cos) || cos.is_nan(),
"Cosine should be in [-1, 1]: {}",
cos
);
}
#[test]
fn cosine_checks_norm_gt_epsilon(dim in 2usize..8) {
let tiny: Vec<f32> = (0..dim).map(|i| if i == 0 { 1e-10 } else { 0.0 }).collect();
let normal: Vec<f32> = (0..dim).map(|i| if i == 0 { 1.0 } else { 0.0 }).collect();
let cos = cosine(&tiny, &normal);
prop_assert_eq!(cos, 0.0, "Tiny norm should return 0.0: {}", cos);
}
#[test]
fn dot_multiplies_elements(dim in 2usize..16) {
let a: Vec<f32> = (0..dim).map(|_| 1.0).collect();
let b: Vec<f32> = (0..dim).map(|_| 2.0).collect();
let dot_product = dot(&a, &b);
prop_assert!((dot_product - (dim as f32 * 2.0)).abs() < 0.01, "Dot should multiply: {} ≈ {}", dot_product, dim * 2);
}
#[test]
fn norm_uses_sqrt(dim in 2usize..16) {
let v: Vec<f32> = (0..dim).map(|_| 3.0).collect();
let n = norm(&v);
let expected = 3.0 * (dim as f32).sqrt();
prop_assert!((n - expected).abs() < 0.01, "Norm should use sqrt: {} ≈ {}", n, expected);
}
#[test]
fn maxsim_sums_max_scores(n_query in 1usize..5, n_doc in 1usize..5, dim in 2usize..8) {
let query: Vec<Vec<f32>> = (0..n_query)
.map(|_| (0..dim).map(|j| if j == 0 { 1.0 } else { 0.0 }).collect())
.collect();
let doc: Vec<Vec<f32>> = (0..n_doc)
.map(|_| (0..dim).map(|j| if j == 0 { 0.9 } else { 0.0 }).collect())
.collect();
let score = maxsim_vecs(&query, &doc);
prop_assert!(score > 0.0 && score.is_finite(), "MaxSim should be positive and finite: {}", score);
}
#[test]
fn maxsim_alignments_compares_gt(n_query in 1usize..4, n_doc in 1usize..4, dim in 2usize..8) {
let query: Vec<Vec<f32>> = (0..n_query)
.map(|i| (0..dim).map(|j| if j == i % dim { 1.0 } else { 0.0 }).collect())
.collect();
let doc: Vec<Vec<f32>> = (0..n_doc)
.map(|i| (0..dim).map(|j| if j == i % dim { 0.9 } else { 0.0 }).collect())
.collect();
let alignments = maxsim_alignments_vecs(&query, &doc);
prop_assert_eq!(alignments.len(), n_query, "Should have one alignment per query token");
for (_, _, score) in &alignments {
prop_assert!(score.is_finite(), "Alignment score should be finite: {}", score);
}
}
#[test]
fn bm25_weights_rejects_df_gt_total(total_docs in 10usize..100, df in 1usize..200) {
let token_doc_freqs = vec![df];
let token_query_freqs = vec![1];
let k1 = 1.5;
let weights = bm25_weights(&token_doc_freqs, &token_query_freqs, total_docs, k1);
prop_assert_eq!(weights.len(), 1);
if df > total_docs {
let token_doc_freqs_multi = vec![df, total_docs / 2];
let token_query_freqs_multi = vec![1, 1];
let weights_multi = bm25_weights(&token_doc_freqs_multi, &token_query_freqs_multi, total_docs, k1);
prop_assert_eq!(
weights_multi[0], 0.0,
"df {} > total_docs {} should return 0.0 for first weight, got {}",
df, total_docs, weights_multi[0]
);
prop_assert!(
weights_multi[1] > 0.0,
"Valid df should return positive weight, got {}",
weights_multi[1]
);
} else if df == total_docs {
prop_assert!(
weights[0] != 0.0,
"df {} == total_docs {} should return non-zero weight, got {}",
df, total_docs, weights[0]
);
}
}
}
}
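/// Borrows each owned token vector as a slice, for handing `Vec<Vec<f32>>`
/// data to the slice-based APIs above.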
#[inline]
#[must_use]
pub fn as_slices(tokens: &[Vec<f32>]) -> Vec<&[f32]> {
tokens.iter().map(Vec::as_slice).collect()
}