use scirs2_core::ndarray::{Array1, Array2};
use crate::kernel_pca::error::{KernelPcaError, KernelPcaResult};
/// Statistics recorded while double-centering a training Gram matrix `K`.
///
/// `center_test_kernel` uses them to center kernel rows of new points in the
/// same way the training matrix was centered.
#[derive(Clone, Debug, PartialEq)]
pub struct KernelCenteringStats {
    /// One mean per training sample. `double_center` stores in entry `j` the
    /// mean of column `j` of `K` (the average of `K[i, j]` over `i`). This is
    /// identical to the mean of row `j` whenever `K` is symmetric.
    pub row_means: Array1<f64>,
    /// Mean of every entry of the training Gram matrix.
    pub grand_mean: f64,
}

impl KernelCenteringStats {
    /// Number of training samples described by these statistics, i.e. the
    /// length of `row_means`.
    pub fn n(&self) -> usize {
        let Self { row_means, .. } = self;
        row_means.len()
    }
}
/// Double-centers a square Gram matrix `K`, as required by kernel PCA.
///
/// Computes `K_c[i, j] = K[i, j] - r_i - c_j + g`, where:
/// - `r_i` is the mean of row `i`,
/// - `c_j` is the mean of column `j`,
/// - `g` is the mean of all entries.
///
/// Every row and every column of `K_c` then sums to zero, up to rounding.
/// The result is afterwards symmetrized by averaging mirrored entries, so it
/// is exactly symmetric even if `K` carries small round-off asymmetries.
/// Averaging a doubly-centered matrix with its transpose keeps it
/// doubly-centered.
///
/// # Returns
///
/// The centered matrix together with the [`KernelCenteringStats`] consumed by
/// [`center_test_kernel`]. `stats.row_means[j]` holds the column mean `c_j`
/// (mean of `K[i, j]` over training samples `i`). This equals the row mean
/// `r_j` for a symmetric `K`.
///
/// # Errors
///
/// Returns [`KernelPcaError::InvalidInput`] if `k` is empty or not square.
pub fn double_center(k: &Array2<f64>) -> KernelPcaResult<(Array2<f64>, KernelCenteringStats)> {
    let (rows, cols) = (k.nrows(), k.ncols());
    if rows == 0 || cols == 0 {
        return Err(KernelPcaError::InvalidInput(
            "double_center: Gram matrix must be non-empty".to_string(),
        ));
    }
    if rows != cols {
        return Err(KernelPcaError::InvalidInput(format!(
            "double_center: Gram matrix must be square, got {}x{}",
            rows, cols
        )));
    }
    let n = rows;
    let n_f = n as f64;

    // Accumulate row and column sums in a single row-major pass. Each sum adds
    // its entries sequentially along the other index. For an exactly symmetric
    // `K` the two vectors are therefore bit-for-bit identical.
    let mut row_means = Array1::<f64>::zeros(n);
    let mut col_means = Array1::<f64>::zeros(n);
    for i in 0..n {
        for j in 0..n {
            let v = k[(i, j)];
            row_means[i] += v;
            col_means[j] += v;
        }
    }
    row_means.mapv_inplace(|s| s / n_f);
    col_means.mapv_inplace(|s| s / n_f);
    let grand_mean = col_means.iter().copied().sum::<f64>() / n_f;

    // Subtract the row mean for the row index and the column mean for the
    // column index. Using a single mean vector for both indices only
    // double-centers an exactly symmetric `K`. This is the same formula
    // `center_test_kernel` applies to a new row.
    let mut centered = Array2::from_shape_fn((n, n), |(i, j)| {
        k[(i, j)] - row_means[i] - col_means[j] + grand_mean
    });

    // Force exact symmetry by averaging mirrored entries; this preserves the
    // zero row/column sums established above.
    for i in 0..n {
        for j in (i + 1)..n {
            let avg = 0.5 * (centered[(i, j)] + centered[(j, i)]);
            centered[(i, j)] = avg;
            centered[(j, i)] = avg;
        }
    }

    Ok((
        centered,
        KernelCenteringStats {
            row_means: col_means,
            grand_mean,
        },
    ))
}
/// Centers one row of a test-versus-training kernel matrix using statistics
/// recorded by [`double_center`].
///
/// For a test point `x`, `k_test[j]` should hold `k(x, x_j)` for each training
/// sample `x_j`. Entry `j` of the result is
/// `k_test[j] - stats.row_means[j] - mean(k_test) + stats.grand_mean`.
///
/// # Errors
///
/// Returns [`KernelPcaError::DimensionMismatch`] when `k_test.len()` differs
/// from `stats.n()`.
pub fn center_test_kernel(
    k_test: &[f64],
    stats: &KernelCenteringStats,
) -> KernelPcaResult<Array1<f64>> {
    let expected = stats.n();
    let got = k_test.len();
    if got != expected {
        return Err(KernelPcaError::DimensionMismatch {
            expected,
            got,
            context: "center_test_kernel: test kernel row length".to_string(),
        });
    }

    // Mean of the test point's kernel values against all training samples.
    let test_mean = k_test.iter().copied().sum::<f64>() / expected as f64;
    let grand = stats.grand_mean;

    let centered = k_test
        .iter()
        .zip(stats.row_means.iter())
        .map(|(&kv, &train_mean)| kv - train_mean - test_mean + grand)
        .collect::<Array1<f64>>();
    Ok(centered)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `n x n` matrix in which every entry equals `value`.
    fn constant_matrix(n: usize, value: f64) -> Array2<f64> {
        Array2::<f64>::from_elem((n, n), value)
    }

    #[test]
    fn double_center_rejects_empty_matrix() {
        let empty = Array2::<f64>::zeros((0, 0));
        assert!(double_center(&empty).is_err());
    }

    #[test]
    fn double_center_rejects_non_square() {
        let rectangular = Array2::<f64>::zeros((3, 4));
        assert!(double_center(&rectangular).is_err());
    }

    #[test]
    fn double_center_of_constant_matrix_is_zero() {
        let (centered, stats) = double_center(&constant_matrix(5, 3.7)).expect("double_center");
        for v in &centered {
            assert!(v.abs() < 1e-12, "expected zero, got {}", v);
        }
        assert_eq!(stats.row_means.len(), 5);
        assert!(stats.row_means.iter().all(|rm| (rm - 3.7).abs() < 1e-12));
        assert!((stats.grand_mean - 3.7).abs() < 1e-12);
    }

    #[test]
    fn double_center_row_column_sums_are_zero() {
        let gram = Array2::<f64>::from_shape_fn((4, 4), |(r, c)| ((r + 1) * (c + 1)) as f64);
        let (centered, _) = double_center(&gram).expect("double_center");
        for r in 0..4 {
            let row_sum: f64 = centered.row(r).iter().sum();
            assert!(row_sum.abs() < 1e-10, "row {} sum = {}", r, row_sum);
        }
        for c in 0..4 {
            let col_sum: f64 = centered.column(c).iter().sum();
            assert!(col_sum.abs() < 1e-10, "col {} sum = {}", c, col_sum);
        }
    }

    #[test]
    fn double_center_is_symmetric() {
        let gram = Array2::<f64>::from_shape_fn((6, 6), |(i, j)| {
            let (a, b) = ((i as f64).sin(), (j as f64).sin());
            1.0 + a + b + 0.5 * (a * b)
        });
        let (centered, _) = double_center(&gram).expect("double_center");
        for ((i, j), &v) in centered.indexed_iter() {
            assert!(
                (v - centered[(j, i)]).abs() < 1e-14,
                "asymmetry at ({},{})",
                i,
                j
            );
        }
    }

    #[test]
    fn center_test_kernel_matches_pulling_row_from_double_center() {
        let gram = Array2::<f64>::from_shape_fn((4, 4), |(i, j)| {
            let d = i as f64 - j as f64;
            (-(d.powi(2)) / 4.0).exp()
        });
        let (centered, stats) = double_center(&gram).expect("double_center");
        for i in 0..4 {
            let test_row = gram.row(i).to_vec();
            let out = center_test_kernel(&test_row, &stats).expect("center_test_kernel");
            for (j, (&got, &want)) in out.iter().zip(centered.row(i).iter()).enumerate() {
                assert!(
                    (got - want).abs() < 1e-12,
                    "row {} col {} mismatch: test={}, expected={}",
                    i,
                    j,
                    got,
                    want
                );
            }
        }
    }

    #[test]
    fn center_test_kernel_rejects_wrong_length() {
        let stats = KernelCenteringStats {
            row_means: Array1::<f64>::zeros(3),
            grand_mean: 0.0,
        };
        match center_test_kernel(&[1.0, 2.0], &stats).expect_err("must reject") {
            KernelPcaError::DimensionMismatch { expected, got, .. } => {
                assert_eq!((expected, got), (3, 2));
            }
            other => panic!("wrong variant: {:?}", other),
        }
    }
}