/// An owned `f32` vector produced by this file's normalization constructors.
///
/// The buffer is private; within this file, values are created by [`normalize`]
/// (components divided by the vector's norm) and [`normalize_or_zero`] (which may
/// also produce an all-zero vector).
#[derive(Debug, Clone)]
pub struct Normalized {
    data: Vec<f32>,
}

impl Normalized {
    /// Borrows the stored components as a slice.
    #[inline]
    #[must_use]
    pub fn as_slice(&self) -> &[f32] {
        self.data.as_slice()
    }

    /// Returns the number of stored components.
    #[inline]
    #[must_use]
    pub fn dim(&self) -> usize {
        self.as_slice().len()
    }

    /// Dot product of `self` and `other`, delegated to `super::simd::dot`.
    #[inline]
    #[must_use]
    pub fn dot(&self, other: &Normalized) -> f32 {
        super::simd::dot(self.as_slice(), other.as_slice())
    }

    /// Cosine similarity between `self` and `other`.
    ///
    /// This simply forwards to [`Normalized::dot`]: for unit-length vectors the
    /// dot product already equals the cosine, so no further division is done.
    #[inline]
    #[must_use]
    pub fn cosine(&self, other: &Normalized) -> f32 {
        Self::dot(self, other)
    }
}
/// Scales `v` so that its norm becomes one.
///
/// The norm is obtained from `super::simd::norm` (presumably the Euclidean norm —
/// confirm against that helper) and every component is divided by it.
///
/// # Returns
///
/// * `None` when the norm is strictly below `1e-9`, i.e. the vector is too close
///   to zero to be scaled meaningfully.
/// * `None` when the norm is not finite (NaN or infinite); dividing by such a
///   norm would produce NaN or zero components rather than a unit vector.
/// * `Some(Normalized)` holding `v[i] / norm` otherwise.
#[must_use]
pub fn normalize(v: &[f32]) -> Option<Normalized> {
    let norm = super::simd::norm(v);
    // `NaN < 1e-9` and `inf < 1e-9` are both false, so the threshold test alone
    // would let non-finite norms through; reject them explicitly.
    if !norm.is_finite() || norm < 1e-9 {
        return None;
    }
    let data: Vec<f32> = v.iter().map(|&x| x / norm).collect();
    Some(Normalized { data })
}
/// Infallible variant of [`normalize`].
///
/// Returns the result of `normalize(v)` when it yields a value; otherwise returns
/// a vector of `v.len()` zeros.
#[must_use]
pub fn normalize_or_zero(v: &[f32]) -> Normalized {
    if let Some(unit) = normalize(v) {
        return unit;
    }
    // Fallback: same length as the input, every component zero.
    Normalized {
        data: vec![0.0; v.len()],
    }
}
/// A list of token vectors paired with a per-token validity flag.
///
/// `mask[i]` tells whether `tokens[i]` is valid (`true`) or should be skipped
/// (`false`). Both constructors guarantee the two vectors have equal length.
#[derive(Debug, Clone)]
pub struct MaskedTokens {
    tokens: Vec<Vec<f32>>,
    mask: Vec<bool>,
}

impl MaskedTokens {
    /// Builds a masked token list from explicit tokens and mask.
    ///
    /// # Panics
    ///
    /// Panics if `tokens.len() != mask.len()`.
    #[must_use]
    pub fn new(tokens: Vec<Vec<f32>>, mask: Vec<bool>) -> Self {
        assert_eq!(
            tokens.len(),
            mask.len(),
            "Token count {} must match mask length {}",
            tokens.len(),
            mask.len()
        );
        Self { tokens, mask }
    }

    /// Builds a masked token list in which every token is marked valid.
    #[must_use]
    pub fn from_tokens(tokens: Vec<Vec<f32>>) -> Self {
        let all_valid = vec![true; tokens.len()];
        Self {
            tokens,
            mask: all_valid,
        }
    }

    /// Total number of tokens, valid or not.
    #[inline]
    #[must_use]
    pub fn len(&self) -> usize {
        self.all_tokens().len()
    }

    /// Returns `true` when there are no tokens at all (the mask is not consulted).
    #[inline]
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Number of tokens whose mask entry is `true`.
    #[inline]
    #[must_use]
    pub fn valid_count(&self) -> usize {
        // Each `true` contributes 1 and each `false` contributes 0.
        self.mask.iter().map(|&flag| usize::from(flag)).sum()
    }

    /// Iterates, in order, over the tokens whose mask entry is `true`.
    pub fn valid_tokens(&self) -> impl Iterator<Item = &[f32]> {
        self.tokens
            .iter()
            .zip(&self.mask)
            .filter(|(_, keep)| **keep)
            .map(|(token, _)| token.as_slice())
    }

    /// All tokens, including the masked-out ones.
    #[must_use]
    pub fn all_tokens(&self) -> &[Vec<f32>] {
        self.tokens.as_slice()
    }

    /// The validity mask, one entry per token.
    #[must_use]
    pub fn mask(&self) -> &[bool] {
        self.mask.as_slice()
    }
}
/// MaxSim score restricted to valid tokens.
///
/// For every valid query token, takes the largest dot product (computed by
/// `super::simd::dot`) against any valid document token, and sums those maxima.
///
/// Returns `0.0` when either `query` or `doc` has no valid tokens.
#[must_use]
pub fn maxsim_masked(query: &MaskedTokens, doc: &MaskedTokens) -> f32 {
    // Materialize the valid document tokens once; they are scanned per query token.
    let doc_vecs: Vec<&[f32]> = doc.valid_tokens().collect();
    if doc_vecs.is_empty() || query.valid_count() == 0 {
        return 0.0;
    }
    query
        .valid_tokens()
        .map(|q_vec| {
            doc_vecs.iter().fold(f32::NEG_INFINITY, |best, d_vec| {
                best.max(super::simd::dot(q_vec, d_vec))
            })
        })
        .sum()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean length computed with plain iterators, used to cross-check results.
    fn l2_norm(xs: &[f32]) -> f32 {
        xs.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    #[test]
    fn test_normalize() {
        // A 3-4-5 triangle: the normalized vector must have length one.
        let unit = normalize(&[3.0, 4.0]).unwrap();
        let length = l2_norm(unit.as_slice());
        assert!((length - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_normalize_zero() {
        let outcome = normalize(&[0.0, 0.0, 0.0]);
        assert!(outcome.is_none());
    }

    #[test]
    fn test_normalized_dot_is_cosine() {
        let x_axis = normalize(&[1.0, 0.0]).unwrap();
        let diagonal = normalize(&[1.0, 1.0]).unwrap();
        let via_dot = x_axis.dot(&diagonal);
        let via_cosine = x_axis.cosine(&diagonal);
        assert!((via_dot - via_cosine).abs() < 1e-6);
    }

    #[test]
    fn test_masked_tokens() {
        // Third token is masked out, so only two of the three count as valid.
        let masked = MaskedTokens::new(
            vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.0, 0.0]],
            vec![true, true, false],
        );
        assert_eq!(masked.len(), 3);
        assert_eq!(masked.valid_count(), 2);
    }

    #[test]
    fn test_maxsim_masked() {
        let query = MaskedTokens::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![true, true]);
        let doc_tokens = vec![vec![1.0, 0.0], vec![0.5, 0.5], vec![0.0, 0.0]];
        let doc_mask = vec![true, true, false];
        let doc = MaskedTokens::new(doc_tokens, doc_mask);
        // First query token matches doc[0] exactly (1.0); second peaks at doc[1] (0.5).
        let score = maxsim_masked(&query, &doc);
        assert!((score - 1.5).abs() < 1e-5);
    }

    #[test]
    fn test_maxsim_masked_empty() {
        // The only query token is masked out, so the score must be zero.
        let query = MaskedTokens::new(vec![vec![1.0, 0.0]], vec![false]);
        let doc = MaskedTokens::from_tokens(vec![vec![1.0, 0.0]]);
        let score = maxsim_masked(&query, &doc);
        assert!(score.abs() < 1e-9);
    }

    #[test]
    #[should_panic(expected = "Token count")]
    fn test_masked_mismatched_lengths() {
        let _ = MaskedTokens::new(vec![vec![1.0]], vec![true, false]);
    }

    #[test]
    fn normalize_uses_strict_less_than() {
        // Ten components of 1e-10 give a norm of about 3.2e-10, below the cutoff.
        let tiny = vec![1e-10_f32; 10];
        assert!(
            normalize(&tiny).is_none(),
            "Tiny vector with norm < 1e-9 should return None"
        );

        // Ten components of 1e-8 give a norm of about 3.2e-8, above the cutoff.
        let small = vec![1e-8_f32; 10];
        match normalize(&small) {
            None => panic!("Vector with norm > 1e-9 should normalize"),
            Some(unit) => {
                let length = l2_norm(unit.as_slice());
                assert!((length - 1.0).abs() < 1e-4, "Should be normalized");
            }
        }
    }
}
#[cfg(test)]
mod proptests {
    use super::*;
    use crate::rerank::simd;
    use proptest::prelude::*;

    proptest! {
        // Whenever normalization succeeds, the result has length one.
        #[test]
        fn normalized_has_unit_norm(v in proptest::collection::vec(-10.0f32..10.0, 2..16)) {
            if let Some(unit) = normalize(&v) {
                let squared: f32 = unit.as_slice().iter().map(|c| c * c).sum();
                let norm = squared.sqrt();
                prop_assert!((norm - 1.0).abs() < 1e-4, "Norm was {}", norm);
            }
        }

        // dot(a, b) and dot(b, a) agree.
        #[test]
        fn normalized_dot_symmetric(
            dim in 4usize..8,
            a_vals in proptest::collection::vec(-10.0f32..10.0, 8),
            b_vals in proptest::collection::vec(-10.0f32..10.0, 8)
        ) {
            // Both pools hold 8 values and `dim` is below 8, so the prefix always exists.
            let lhs_raw = a_vals[..dim].to_vec();
            let rhs_raw = b_vals[..dim].to_vec();
            if let (Some(lhs), Some(rhs)) = (normalize(&lhs_raw), normalize(&rhs_raw)) {
                let forward = lhs.dot(&rhs);
                let backward = rhs.dot(&lhs);
                prop_assert!((forward - backward).abs() < 1e-5);
            }
        }

        // Cosine of two normalized vectors stays within [-1, 1] up to rounding slack.
        #[test]
        fn normalized_cosine_bounded(
            dim in 4usize..8,
            a_vals in proptest::collection::vec(-10.0f32..10.0, 8),
            b_vals in proptest::collection::vec(-10.0f32..10.0, 8)
        ) {
            let lhs_raw = a_vals[..dim].to_vec();
            let rhs_raw = b_vals[..dim].to_vec();
            if let (Some(lhs), Some(rhs)) = (normalize(&lhs_raw), normalize(&rhs_raw)) {
                let cos = lhs.cosine(&rhs);
                prop_assert!((-1.01..=1.01).contains(&cos), "Cosine was {}", cos);
            }
        }

        // valid_count equals the number of `true` entries in the mask.
        #[test]
        fn masked_valid_count_matches(
            n_tokens in 1usize..10,
            dim in 2usize..8
        ) {
            // Token `row` holds (row + col) * 0.1 for col in 0..dim.
            let tokens: Vec<Vec<f32>> = (0..n_tokens)
                .map(|row| (row..row + dim).map(|s| s as f32 * 0.1).collect())
                .collect();
            // Every even-indexed token is valid.
            let mask: Vec<bool> = (0..n_tokens).map(|idx| idx % 2 == 0).collect();
            let expected_valid = mask.iter().filter(|flag| **flag).count();
            let masked = MaskedTokens::new(tokens, mask);
            prop_assert_eq!(masked.valid_count(), expected_valid);
        }

        // With nothing masked out, the masked score matches the unmasked MaxSim.
        #[test]
        fn maxsim_masked_all_valid_equals_regular(
            n_q in 1usize..4,
            n_d in 1usize..4,
            dim in 2usize..8
        ) {
            let query_tokens: Vec<Vec<f32>> = (0..n_q)
                .map(|row| (0..dim).map(|col| ((row * dim + col) as f32 * 0.1).sin()).collect())
                .collect();
            let doc_tokens: Vec<Vec<f32>> = (0..n_d)
                .map(|row| (0..dim).map(|col| ((row * dim + col + 1) as f32 * 0.1).cos()).collect())
                .collect();
            // Compute the reference first so the token lists can then be moved.
            let regular_score = simd::maxsim_vecs(&query_tokens, &doc_tokens);
            let masked_q = MaskedTokens::from_tokens(query_tokens);
            let masked_d = MaskedTokens::from_tokens(doc_tokens);
            let masked_score = maxsim_masked(&masked_q, &masked_d);
            prop_assert!((masked_score - regular_score).abs() < 1e-5);
        }
    }
}