use ndarray::Array2;
/// Cosine similarity of two equal-length vectors: `a·b / (|a| * |b|)`.
///
/// Returns `0.0` when the product of the two norms is exactly zero (for
/// example when either input is all zeros), instead of dividing by zero.
/// NaN components propagate into a NaN result.
///
/// Equal lengths are only asserted in debug builds. In any build, a `b`
/// shorter than `a` panics on indexing.
#[inline]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let (mut ab, mut aa, mut bb) = (0.0_f32, 0.0_f32, 0.0_f32);
    for (idx, &x) in a.iter().enumerate() {
        // Index `b` explicitly so a short `b` still panics like a bounds error.
        let y = b[idx];
        ab += x * y;
        aa += x * x;
        bb += y * y;
    }
    let magnitude = aa.sqrt() * bb.sqrt();
    if magnitude == 0.0 {
        return 0.0;
    }
    ab / magnitude
}
/// Builds the symmetric, quality-weighted similarity kernel for `n` candidates.
///
/// Each entry is `L[i][j] = w[i] * w[j] * max(cos(e[i], e[j]), 0)`, where
/// `w[i] = clamp(scores[i], 0, 1).powf(lambda)`. Here `cos` is the same cosine
/// similarity as [`cosine_similarity`], which is `0.0` when the norm product is
/// zero. `regularization` is then added to every diagonal entry.
///
/// * `embeddings` – one embedding per candidate. All slices are expected to
///   share one length; this is checked only in debug builds.
/// * `scores` – one score per candidate. `scores.len() == embeddings.len()` is
///   checked only in debug builds. A NaN score is treated as `0.0`.
/// * `lambda` – exponent applied to the clamped score. `0.0` makes every weight
///   `1.0`, so only similarity matters. A negative value turns a zero score
///   into an infinite weight, so callers presumably pass `lambda >= 0`
///   (confirm).
/// * `regularization` – value added to each diagonal entry.
///
/// # Panics
/// Panics if `scores` has fewer entries than `embeddings`. In debug builds it
/// also panics on any length mismatch between `scores` and `embeddings`, or
/// between the embeddings themselves.
pub fn build_kernel_matrix(
    embeddings: &[&[f32]],
    scores: &[f32],
    lambda: f32,
    regularization: f32,
) -> Array2<f32> {
    let n = embeddings.len();
    debug_assert_eq!(n, scores.len());
    debug_assert!(embeddings.windows(2).all(|w| w[0].len() == w[1].len()));

    // `f32::clamp` passes NaN through unchanged. A NaN weight would make every
    // entry in row and column i NaN (NaN * 0 is still NaN), so treat it as
    // "no relevance" instead.
    let weights: Vec<f32> = scores
        .iter()
        .map(|&s| {
            let s = if s.is_nan() { 0.0 } else { s.clamp(0.0, 1.0) };
            s.powf(lambda)
        })
        .collect();

    // Each norm is reused by n pairs, so compute it once here instead of
    // re-accumulating both norms inside the O(n^2) pair loop. The accumulation
    // order matches `cosine_similarity`, so for equal-length embeddings the
    // similarities are bit-identical to calling it per pair.
    let norms: Vec<f32> = embeddings
        .iter()
        .map(|e| e.iter().fold(0.0_f32, |acc, &x| acc + x * x).sqrt())
        .collect();
    let dot = |a: &[f32], b: &[f32]| a.iter().zip(b).fold(0.0_f32, |acc, (&x, &y)| acc + x * y);

    let mut kernel = Array2::<f32>::zeros((n, n));
    for i in 0..n {
        for j in i..n {
            let denom = norms[i] * norms[j];
            let sim = if denom == 0.0 {
                0.0
            } else {
                dot(embeddings[i], embeddings[j]) / denom
            };
            // Negative similarities are clamped to 0 before weighting.
            let val = weights[i] * weights[j] * sim.max(0.0);
            kernel[[i, j]] = val;
            kernel[[j, i]] = val;
        }
        kernel[[i, i]] += regularization;
    }
    kernel
}
/// Extracts the square submatrix of `kernel` picked out by `indices`.
///
/// Entry `[r, c]` of the result is `kernel[[indices[r], indices[c]]]`. The
/// output is `indices.len() x indices.len()` and follows the order of `indices`.
///
/// # Panics
/// Panics if any index in `indices` is out of bounds for `kernel`.
pub fn submatrix(kernel: &Array2<f32>, indices: &[usize]) -> Array2<f32> {
    let size = indices.len();
    Array2::from_shape_fn((size, size), |(r, c)| kernel[[indices[r], indices[c]]])
}
/// Collects the kernel entries that link each selected index to `candidate`.
///
/// The k-th element is `kernel[[selected[k], candidate]]`, in the order of
/// `selected`. An empty `selected` yields an empty vector without touching
/// `kernel`.
///
/// # Panics
/// Panics if an index in `selected` is out of bounds. It also panics if
/// `selected` is non-empty and `candidate` is out of bounds.
pub fn cross_column(kernel: &Array2<f32>, selected: &[usize], candidate: usize) -> Vec<f32> {
    let mut column = Vec::with_capacity(selected.len());
    for &row in selected {
        column.push(kernel[[row, candidate]]);
    }
    column
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for the approximate float comparisons below.
    const EPS: f32 = 1e-6;

    fn close(x: f32, y: f32) -> bool {
        (x - y).abs() < EPS
    }

    /// Kernel over [1,0], [0,1] and [1,1], with scores 0.9/0.8/0.7,
    /// lambda 0.5 and regularization 1e-6. Shared by several tests.
    fn small_kernel() -> Array2<f32> {
        let (x, y, xy) = ([1.0_f32, 0.0], [0.0_f32, 1.0], [1.0_f32, 1.0]);
        let embeddings: [&[f32]; 3] = [&x, &y, &xy];
        build_kernel_matrix(&embeddings, &[0.9, 0.8, 0.7], 0.5, 1e-6)
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let v = [1.0_f32, 2.0, 3.0];
        assert!(close(cosine_similarity(&v, &v), 1.0));
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let sim = cosine_similarity(&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]);
        assert!(close(sim, 0.0));
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        assert_eq!(cosine_similarity(&[1.0, 2.0, 3.0], &[0.0, 0.0, 0.0]), 0.0);
    }

    #[test]
    fn test_kernel_matrix_shape() {
        let kernel = small_kernel();
        assert_eq!(kernel.nrows(), 3);
        assert_eq!(kernel.ncols(), 3);
    }

    #[test]
    fn test_kernel_matrix_symmetric() {
        let (p, q, r) = ([1.0_f32, 2.0, 3.0], [4.0_f32, 5.0, 6.0], [7.0_f32, 8.0, 9.0]);
        let embeddings: [&[f32]; 3] = [&p, &q, &r];
        let kernel = build_kernel_matrix(&embeddings, &[0.9, 0.8, 0.7], 0.5, 1e-6);
        for ((row, col), &value) in kernel.indexed_iter() {
            assert!(
                close(value, kernel[[col, row]]),
                "not symmetric at [{},{}]",
                row,
                col
            );
        }
    }

    #[test]
    fn test_kernel_lambda_zero_ignores_relevance() {
        let (x, y) = ([1.0_f32, 0.0], [0.0_f32, 1.0]);
        let embeddings: [&[f32]; 2] = [&x, &y];
        // Very different scores must not matter when lambda == 0.
        let kernel = build_kernel_matrix(&embeddings, &[0.1, 0.9], 0.0, 0.0);
        assert!(close(kernel[[0, 1]], 0.0));
        assert!(close(kernel[[0, 0]], 1.0));
        assert!(close(kernel[[1, 1]], 1.0));
    }

    #[test]
    fn test_submatrix_extraction() {
        let kernel = small_kernel();
        let picked = [0usize, 2];
        let sub = submatrix(&kernel, &picked);
        assert_eq!(sub.nrows(), 2);
        assert_eq!(sub.ncols(), 2);
        for (a, &ra) in picked.iter().enumerate() {
            for (b, &rb) in picked.iter().enumerate() {
                assert!(close(sub[[a, b]], kernel[[ra, rb]]));
            }
        }
    }

    #[test]
    fn test_cross_column() {
        let kernel = small_kernel();
        let cross = cross_column(&kernel, &[0, 1], 2);
        assert_eq!(cross.len(), 2);
        assert!(close(cross[0], kernel[[0, 2]]));
        assert!(close(cross[1], kernel[[1, 2]]));
    }
}