use crate::analysis::finite::ensure_finite_2d;
use crate::errors::AnalysisError;
use ndarray::Array2;
/// Builds the linear (Gram) kernel `K = X Xᵀ`.
///
/// Entry `K[i][j]` is the dot product of rows `i` and `j` of `x`, so the result is square
/// with one row and one column per row of `x` (per sample, as `linear_cka` uses it).
fn linear_kernel(x: &Array2<f32>) -> Array2<f32> {
    let transposed = x.t();
    x.dot(&transposed)
}
/// Double-centres a kernel (Gram) matrix, i.e. computes `H K H` with `H = I - 11ᵀ/n`.
///
/// Element-wise: `Kc[i][j] = K[i][j] - row_mean[i] - col_mean[j] + grand_mean`, so every
/// row and every column of the result has (up to rounding) zero mean. For a linear kernel
/// `K = X Xᵀ` this equals the Gram matrix of `X` after subtracting each column's mean.
///
/// # Errors
///
/// Returns [`AnalysisError::EmptyInput`] if `k` has no rows or no columns.
fn centre_kernel(k: &Array2<f32>) -> Result<Array2<f32>, AnalysisError> {
    // Averaging down axis 0 yields one mean per column.
    let col_mean = k
        .mean_axis(ndarray::Axis(0))
        .ok_or_else(|| AnalysisError::EmptyInput("CKA kernel must be non-empty".into()))?;
    let grand_mean = col_mean
        .mean()
        .ok_or_else(|| AnalysisError::EmptyInput("CKA kernel mean must be non-empty".into()))?;
    // Averaging across axis 1 yields one mean per row. Row means must be subtracted per
    // row; subtracting the column means a second time does not centre the kernel.
    let row_mean = k
        .mean_axis(ndarray::Axis(1))
        .ok_or_else(|| AnalysisError::EmptyInput("CKA kernel must be non-empty".into()))?;
    let mut kc = k.clone();
    for (mut row, &mean_i) in kc.rows_mut().into_iter().zip(row_mean.iter()) {
        // Remove the column means (broadcast along this row) and this row's own mean.
        row -= &col_mean;
        row -= mean_i;
    }
    // The grand mean was subtracted twice (once via rows, once via columns); add it back.
    kc += grand_mean;
    Ok(kc)
}
/// Empirical HSIC estimate `tr(Kc · Lc) / (n - 1)²` for two centred kernels.
///
/// The trace is evaluated as the Frobenius inner product `Σᵢⱼ Kc[i][j] · Lc[i][j]`, which
/// equals `tr(Kc · Lc)` whenever `lc` is symmetric. Returns `0.0` when there are fewer
/// than two samples (rows).
///
/// Products are accumulated in `f64`, because summing `n²` terms of mixed sign in `f32`
/// loses most of its precision for realistic sample counts.
fn hsic(kc: &Array2<f32>, lc: &Array2<f32>) -> f32 {
    // `zip` would silently truncate mismatched kernels; catch that misuse in debug builds.
    debug_assert_eq!(kc.shape(), lc.shape(), "HSIC kernels must have identical shapes");
    let n = kc.nrows();
    if n < 2 {
        return 0.0;
    }
    let frobenius: f64 = kc
        .iter()
        .zip(lc.iter())
        .map(|(&a, &b)| f64::from(a) * f64::from(b))
        .sum();
    let scale = (n - 1) as f64;
    (frobenius / (scale * scale)) as f32
}
/// Computes linear Centered Kernel Alignment (CKA) between two representations.
///
/// Rows of `x` and `y` are samples and must correspond one-to-one; the number of columns
/// (features) may differ. The score is `HSIC(K, L) / sqrt(HSIC(K, K) · HSIC(L, L))` with
/// `K = X Xᵀ` and `L = Y Yᵀ` double-centred, clamped to `[0, 1]`.
///
/// Returns `Ok(0.0)` when the normalising denominator is non-finite or below `1e-10`
/// (e.g. a representation that is constant across samples).
/// NOTE(review): the cutoff is absolute, so inputs of very small magnitude may also be
/// reported as `0.0` even though CKA itself is scale-invariant.
///
/// # Errors
///
/// * [`AnalysisError::ShapeMismatch`] if `x` and `y` have different numbers of rows.
/// * [`AnalysisError::InsufficientData`] if there are fewer than two samples.
/// * Propagates the error from `ensure_finite_2d` if either input contains non-finite
///   values (the unit tests expect `AnalysisError::NonFiniteValues`).
pub fn linear_cka(x: &Array2<f32>, y: &Array2<f32>) -> Result<f32, AnalysisError> {
    let nx = x.nrows();
    let ny = y.nrows();
    if nx != ny {
        return Err(AnalysisError::ShapeMismatch {
            expected: vec![nx, x.ncols()],
            actual: vec![ny, y.ncols()],
        });
    }
    if nx < 2 {
        return Err(AnalysisError::InsufficientData(format!(
            "CKA requires ≥2 samples, got {nx}"
        )));
    }
    ensure_finite_2d(x, "left representation for CKA")?;
    ensure_finite_2d(y, "right representation for CKA")?;

    let kxc = centre_kernel(&linear_kernel(x))?;
    let kyc = centre_kernel(&linear_kernel(y))?;

    let hsic_xy = hsic(&kxc, &kyc);
    let hsic_xx = hsic(&kxc, &kxc);
    let hsic_yy = hsic(&kyc, &kyc);

    // Take the square roots before multiplying. `(hsic_xx * hsic_yy).sqrt()` overflows f32
    // to +inf for large-magnitude inputs even when both factors are finite, which would
    // wrongly report a valid comparison as degenerate.
    let denom = hsic_xx.sqrt() * hsic_yy.sqrt();
    if !denom.is_finite() || denom < 1e-10 {
        return Ok(0.0);
    }
    let cka = hsic_xy / denom;
    if !cka.is_finite() {
        return Ok(0.0);
    }
    // Mathematically linear CKA lies in [0, 1]; clamp away rounding excursions.
    Ok(cka.clamp(0.0, 1.0))
}
/// Cosine similarity between two vectors, clamped to `[-1, 1]`.
///
/// Returns `0.0` when either vector's Euclidean norm is below `1e-10`, because the angle
/// is undefined for a zero vector. NaN or infinite components are not rejected and
/// typically yield `NaN`.
///
/// `a` and `b` are expected to have the same length. The dot product pairs elements
/// positionally and stops at the shorter vector, while each norm uses its whole vector.
///
/// Sums are accumulated in `f64` so that squaring large-but-finite `f32` components cannot
/// overflow to infinity, which would otherwise turn a valid comparison into `NaN`.
pub fn cls_cosine_similarity(a: &ndarray::Array1<f32>, b: &ndarray::Array1<f32>) -> f32 {
    let dot: f64 = a
        .iter()
        .zip(b.iter())
        .map(|(&x, &y)| f64::from(x) * f64::from(y))
        .sum();
    let norm = |v: &ndarray::Array1<f32>| {
        v.iter()
            .map(|&x| f64::from(x) * f64::from(x))
            .sum::<f64>()
            .sqrt()
    };
    let norm_a = norm(a);
    let norm_b = norm(b);
    if norm_a < 1e-10 || norm_b < 1e-10 {
        return 0.0;
    }
    ((dot / (norm_a * norm_b)) as f32).clamp(-1.0, 1.0)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_cka_identical() {
        let x = Array2::from_shape_fn((20, 8), |(i, j)| (i + j) as f32);
        let cka = linear_cka(&x, &x).unwrap();
        assert_relative_eq!(cka, 1.0, epsilon = 1e-4);
    }

    #[test]
    fn test_cka_range() {
        let x = Array2::from_shape_fn((20, 8), |(i, j)| (i + j) as f32);
        let y = Array2::from_shape_fn((20, 8), |(i, j)| (i * j) as f32);
        let cka = linear_cka(&x, &y).unwrap();
        assert!((0.0..=1.0 + 1e-5).contains(&cka));
    }

    #[test]
    fn test_cka_scale_invariant() {
        // Scaling by a power of two is exact in binary floating point, so the score is 1.
        let x = Array2::from_shape_fn((20, 8), |(i, j)| (i + j) as f32);
        let y = x.mapv(|v| v * 2.0);
        let cka = linear_cka(&x, &y).unwrap();
        assert_relative_eq!(cka, 1.0, epsilon = 1e-4);
    }

    #[test]
    fn test_cka_shape_mismatch() {
        let x = Array2::from_elem((10, 4), 1.0_f32);
        let y = Array2::from_elem((12, 4), 1.0_f32);
        assert!(linear_cka(&x, &y).is_err());
    }

    #[test]
    fn test_cka_rejects_single_sample() {
        let x = Array2::from_elem((1, 4), 1.0_f32);
        let error = linear_cka(&x, &x).unwrap_err();
        assert!(matches!(error, AnalysisError::InsufficientData(_)));
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let a = ndarray::array![1.0f32, 2.0, 3.0];
        let b = ndarray::array![1.0f32, 2.0, 3.0];
        assert_relative_eq!(cls_cosine_similarity(&a, &b), 1.0, epsilon = 1e-5);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = ndarray::array![1.0f32, 0.0];
        let b = ndarray::array![0.0f32, 1.0];
        assert_relative_eq!(cls_cosine_similarity(&a, &b), 0.0, epsilon = 1e-5);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let a = ndarray::array![1.0f32, 2.0, 3.0];
        let b = ndarray::array![-1.0f32, -2.0, -3.0];
        assert_relative_eq!(cls_cosine_similarity(&a, &b), -1.0, epsilon = 1e-5);
    }

    #[test]
    fn test_cosine_similarity_zero_vector_is_zero() {
        let a = ndarray::array![0.0f32, 0.0, 0.0];
        let b = ndarray::array![1.0f32, 2.0, 3.0];
        assert_relative_eq!(cls_cosine_similarity(&a, &b), 0.0, epsilon = 1e-12);
    }

    #[test]
    fn test_cka_rejects_non_finite_values() {
        let x = Array2::from_elem((10, 4), 1.0_f32);
        let mut y = Array2::from_elem((10, 4), 1.0_f32);
        y[[3, 1]] = f32::INFINITY;
        let error = linear_cka(&x, &y).unwrap_err();
        assert!(matches!(error, AnalysisError::NonFiniteValues { .. }));
    }
}