use crate::dense::{cosine, dot};
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
use crate::arch;
/// Computes the MaxSim score of a document against a query.
///
/// The portable fallback pairs every query token with the document token that
/// gives the largest dot product and adds up those per-token maxima. On
/// `x86_64` and `aarch64` the work is handed to an architecture-specific kernel
/// from `arch` when the token dimension is large enough; those kernels are not
/// visible from this file and are expected to produce the same score.
///
/// Returns `0.0` when `query_tokens` or `doc_tokens` is empty.
///
/// # Panics
///
/// Panics with `"dimension mismatch (query)"` or `"dimension mismatch (doc)"`
/// when any token's length differs from the length of the first query token.
#[inline]
#[must_use]
// The SIMD dispatch below has to call `unsafe` kernels.
#[allow(unsafe_code)]
pub fn maxsim(query_tokens: &[&[f32]], doc_tokens: &[&[f32]]) -> f32 {
    // With nothing on one side there is nothing to align, so the score is zero.
    // The first query token fixes the dimension every other token must share.
    let dim = match (query_tokens.first(), doc_tokens.first()) {
        (Some(head), Some(_)) => head.len(),
        _ => return 0.0,
    };

    // Validate up front so every code path below sees uniformly sized tokens.
    let uniform_width = |tokens: &[&[f32]]| tokens.iter().all(|tok| tok.len() == dim);
    assert!(uniform_width(query_tokens), "dimension mismatch (query)");
    assert!(uniform_width(doc_tokens), "dimension mismatch (doc)");

    #[cfg(target_arch = "x86_64")]
    {
        // Widest kernel first. Each kernel is only selected at or above its
        // minimum dimension and when the CPU reports the features at runtime.
        if dim >= 64 && is_x86_feature_detected!("avx512f") {
            // SAFETY: `avx512f` was detected on the line above and all tokens were
            // asserted to have `dim` elements.
            // NOTE(review): confirm this covers the kernel's full safety contract.
            return unsafe { arch::x86_64::maxsim_avx512(query_tokens, doc_tokens) };
        }
        if dim >= 16 && is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            // SAFETY: `avx2` and `fma` were detected on the line above and all tokens
            // were asserted to have `dim` elements.
            // NOTE(review): confirm this covers the kernel's full safety contract.
            return unsafe { arch::x86_64::maxsim_avx2(query_tokens, doc_tokens) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if dim >= 16 {
            // SAFETY: all tokens were asserted to have `dim` elements. No runtime
            // feature check is made here.
            // NOTE(review): presumably relies on NEON being available on every
            // supported aarch64 target -- confirm against `maxsim_neon`.
            return unsafe { arch::aarch64::maxsim_neon(query_tokens, doc_tokens) };
        }
    }

    // Small dimensions, missing CPU features, or any other architecture.
    maxsim_portable(query_tokens, doc_tokens)
}
/// Architecture-independent MaxSim used when no SIMD kernel is selected.
///
/// For each query token, finds the largest `dot` product against any document
/// token, then sums those maxima over all query tokens. The running maximum
/// starts at negative infinity, so a query token scored against an empty
/// `doc_tokens` contributes negative infinity.
#[inline]
#[must_use]
fn maxsim_portable(query_tokens: &[&[f32]], doc_tokens: &[&[f32]]) -> f32 {
    query_tokens
        .iter()
        .map(|q| {
            // Best similarity seen so far for this query token.
            let mut best = f32::NEG_INFINITY;
            for d in doc_tokens {
                best = best.max(dot(q, d));
            }
            best
        })
        .sum()
}
/// MaxSim variant that scores each token pair with `cosine` instead of `dot`.
///
/// For every query token the largest `cosine` value against any document token
/// is taken, and those maxima are summed. Returns `0.0` when `query_tokens` or
/// `doc_tokens` is empty.
///
/// Unlike [`maxsim`], token lengths are not validated here; mismatched lengths
/// are handled however `cosine` handles them.
#[inline]
#[must_use]
pub fn maxsim_cosine(query_tokens: &[&[f32]], doc_tokens: &[&[f32]]) -> f32 {
    match (query_tokens, doc_tokens) {
        // An empty side leaves nothing to align, so the score is zero.
        ([], _) | (_, []) => 0.0,
        _ => query_tokens
            .iter()
            .map(|q| {
                doc_tokens
                    .iter()
                    .fold(f32::NEG_INFINITY, |best, d| best.max(cosine(q, d)))
            })
            .sum(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by every approximate comparison in this module.
    const TOL: f32 = 1e-6;

    /// Asserts that `actual` is within `TOL` of `expected`, printing both on failure.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < TOL,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// Each query token picks its closest document token: 0.9 + 0.9.
    #[test]
    fn test_maxsim_basic() {
        let q1 = [1.0f32, 0.0];
        let q2 = [0.0f32, 1.0];
        let d1 = [0.9f32, 0.1];
        let d2 = [0.1f32, 0.9];
        let query: &[&[f32]] = &[&q1, &q2];
        let doc: &[&[f32]] = &[&d1, &d2];
        assert_close(maxsim(query, doc), 1.8);
    }

    /// An empty query or an empty document scores exactly zero.
    #[test]
    fn test_maxsim_empty() {
        let q1 = [1.0f32, 0.0];
        let query: &[&[f32]] = &[&q1];
        let empty: &[&[f32]] = &[];
        assert_eq!(maxsim(query, empty), 0.0);
        assert_eq!(maxsim(empty, query), 0.0);
    }

    /// Swapping the roles changes the score because the sum runs over query tokens.
    #[test]
    fn test_maxsim_not_commutative() {
        let q1 = [1.0f32, 0.0];
        let d1 = [0.5f32, 0.5];
        let d2 = [0.5f32, 0.5];
        let query: &[&[f32]] = &[&q1];
        let doc: &[&[f32]] = &[&d1, &d2];
        let score_qd = maxsim(query, doc);
        let score_dq = maxsim(doc, query);
        // One query token -> one maximum of 0.5.
        assert_close(score_qd, 0.5);
        // Two "query" tokens -> 0.5 + 0.5.
        assert_close(score_dq, 1.0);
        assert!((score_qd - score_dq).abs() > 0.4);
    }

    /// On unit-length vectors the dot-product and cosine variants agree.
    #[test]
    fn test_maxsim_cosine_normalized() {
        let q1 = [1.0f32, 0.0];
        let d1 = [1.0f32, 0.0];
        let query: &[&[f32]] = &[&q1];
        let doc: &[&[f32]] = &[&d1];
        let dot_score = maxsim(query, doc);
        let cos_score = maxsim_cosine(query, doc);
        assert_close(dot_score, cos_score);
    }

    /// Parallel vectors of different magnitude still have cosine similarity 1.
    #[test]
    fn test_maxsim_cosine_unnormalized() {
        let q1 = [2.0f32, 0.0];
        let d1 = [3.0f32, 0.0];
        let query: &[&[f32]] = &[&q1];
        let doc: &[&[f32]] = &[&d1];
        assert_close(maxsim_cosine(query, doc), 1.0);
    }

    /// With one token per side the score is the plain dot product: 4 + 10 + 18.
    #[test]
    fn test_maxsim_single_query_single_doc() {
        let q = [1.0f32, 2.0, 3.0];
        let d = [4.0f32, 5.0, 6.0];
        let query: &[&[f32]] = &[&q];
        let doc: &[&[f32]] = &[&d];
        assert_close(maxsim(query, doc), 32.0);
    }

    /// Per-token maxima are 0.5, 0.7 and 0.9.
    #[test]
    fn test_maxsim_multiple_query_multiple_doc() {
        let q1 = [1.0f32, 0.0, 0.0];
        let q2 = [0.0f32, 1.0, 0.0];
        let q3 = [0.0f32, 0.0, 1.0];
        let d1 = [0.5f32, 0.3, 0.0];
        let d2 = [0.0f32, 0.7, 0.9];
        let query: &[&[f32]] = &[&q1, &q2, &q3];
        let doc: &[&[f32]] = &[&d1, &d2];
        assert_close(maxsim(query, doc), 2.1);
    }

    /// Three identical unit query tokens each score 1 against the same vector.
    #[test]
    fn test_maxsim_identical_embeddings() {
        let v = [1.0f32, 0.0, 0.0];
        let query: &[&[f32]] = &[&v, &v, &v];
        let doc: &[&[f32]] = &[&v, &v];
        assert_close(maxsim(query, doc), 3.0);
    }

    /// Mutually orthogonal tokens contribute nothing.
    #[test]
    fn test_maxsim_orthogonal_embeddings() {
        let q1 = [1.0f32, 0.0, 0.0, 0.0];
        let q2 = [0.0f32, 1.0, 0.0, 0.0];
        let d1 = [0.0f32, 0.0, 1.0, 0.0];
        let d2 = [0.0f32, 0.0, 0.0, 1.0];
        let query: &[&[f32]] = &[&q1, &q2];
        let doc: &[&[f32]] = &[&d1, &d2];
        assert_close(maxsim(query, doc), 0.0);
    }

    /// Orthogonal vectors have cosine similarity 0.
    #[test]
    fn test_maxsim_cosine_orthogonal() {
        let q1 = [1.0f32, 0.0];
        let d1 = [0.0f32, 1.0];
        let query: &[&[f32]] = &[&q1];
        let doc: &[&[f32]] = &[&d1];
        assert_close(maxsim_cosine(query, doc), 0.0);
    }

    /// A non-unit vector compared with itself scores 1 per query token.
    #[test]
    fn test_maxsim_cosine_identical() {
        let v = [3.0f32, 4.0];
        let query: &[&[f32]] = &[&v, &v];
        let doc: &[&[f32]] = &[&v];
        assert_close(maxsim_cosine(query, doc), 2.0);
    }

    /// The cosine variant also returns exactly zero for empty inputs.
    #[test]
    fn test_maxsim_cosine_empty() {
        let v = [1.0f32, 0.0];
        let query: &[&[f32]] = &[&v];
        let empty: &[&[f32]] = &[];
        assert_eq!(maxsim_cosine(query, empty), 0.0);
        assert_eq!(maxsim_cosine(empty, query), 0.0);
    }

    /// Dimension 8: larger than the toy cases but still below the SIMD thresholds.
    #[test]
    fn test_maxsim_higher_dim() {
        let q1 = [1.0f32, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
        let d1 = [0.0f32, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0];
        let d2 = [0.5f32, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
        let query: &[&[f32]] = &[&q1];
        let doc: &[&[f32]] = &[&d1, &d2];
        assert_close(maxsim(query, doc), 0.5);
    }

    /// Exercises dimensions at and just above the dispatch thresholds in `maxsim`
    /// (16 and 64), so whichever kernel the host selects is checked as well.
    ///
    /// Only the first and last components are non-zero and every value is exactly
    /// representable, so any correct summation order yields exactly 0.5 + 0.75.
    #[test]
    fn test_maxsim_simd_width_dims() {
        const MAX_DIM: usize = 67;
        for &dim in &[16usize, 17, 64, MAX_DIM] {
            let last = dim - 1;

            let mut q1 = [0.0f32; MAX_DIM];
            q1[0] = 1.0;
            let mut q2 = [0.0f32; MAX_DIM];
            q2[last] = 1.0;

            let mut d1 = [0.0f32; MAX_DIM];
            d1[0] = 0.5;
            d1[last] = 0.25;
            let mut d2 = [0.0f32; MAX_DIM];
            d2[0] = 0.25;
            d2[last] = 0.75;

            let query: &[&[f32]] = &[&q1[..dim], &q2[..dim]];
            let doc: &[&[f32]] = &[&d1[..dim], &d2[..dim]];
            let score = maxsim(query, doc);
            assert!(
                (score - 1.25).abs() < TOL,
                "dim {}: expected 1.25, got {}",
                dim,
                score
            );
        }
    }

    /// Query tokens of differing lengths are rejected before any scoring happens.
    #[test]
    #[should_panic(expected = "dimension mismatch (query)")]
    fn test_maxsim_query_dimension_mismatch_panics() {
        let q1 = [1.0f32, 0.0];
        let q2 = [1.0f32, 0.0, 0.0];
        let d1 = [1.0f32, 0.0];
        let query: &[&[f32]] = &[&q1[..], &q2[..]];
        let doc: &[&[f32]] = &[&d1[..]];
        let _ = maxsim(query, doc);
    }

    /// A document token whose length differs from the query dimension is rejected.
    #[test]
    #[should_panic(expected = "dimension mismatch (doc)")]
    fn test_maxsim_doc_dimension_mismatch_panics() {
        let q1 = [1.0f32, 0.0];
        let d1 = [1.0f32, 0.0, 0.0];
        let query: &[&[f32]] = &[&q1[..]];
        let doc: &[&[f32]] = &[&d1[..]];
        let _ = maxsim(query, doc);
    }
}