use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix1, Ix2};
use scirs2_core::numeric::{Float, NumCast};
use crate::error::{Result, TransformError};
/// Supported kernel functions for pairwise similarity computations.
#[derive(Debug, Clone, PartialEq)]
pub enum KernelType {
    /// Linear kernel: k(x, y) = <x, y>.
    Linear,
    /// Polynomial kernel: k(x, y) = (gamma * <x, y> + coef0)^degree.
    Polynomial {
        /// Scale applied to the dot product.
        gamma: f64,
        /// Constant offset added before exponentiation.
        coef0: f64,
        /// Polynomial degree (cast to i32 at evaluation time).
        degree: u32,
    },
    /// Gaussian (RBF) kernel: k(x, y) = exp(-gamma * ||x - y||^2).
    RBF {
        /// Inverse-width parameter; larger values narrow the kernel.
        gamma: f64,
    },
    /// Laplacian kernel: k(x, y) = exp(-gamma * ||x - y||_1).
    Laplacian {
        gamma: f64,
    },
    /// Sigmoid kernel: k(x, y) = tanh(gamma * <x, y> + coef0).
    Sigmoid {
        gamma: f64,
        coef0: f64,
    },
}
impl KernelType {
    /// Builds an RBF kernel whose `gamma` is estimated from the data with
    /// the median-squared-distance heuristic (see [`estimate_rbf_gamma`]).
    ///
    /// # Errors
    /// Propagates the error from `estimate_rbf_gamma` (e.g. fewer than 2 samples).
    pub fn rbf_auto<S>(x: &ArrayBase<S, Ix2>) -> Result<Self>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        let gamma = estimate_rbf_gamma(x)?;
        Ok(KernelType::RBF { gamma })
    }
    /// Polynomial kernel with `gamma = 1.0`, `coef0 = 1.0`, `degree = 3`.
    pub fn polynomial_default() -> Self {
        KernelType::Polynomial {
            gamma: 1.0,
            coef0: 1.0,
            degree: 3,
        }
    }
    /// RBF kernel with the given `gamma`.
    pub fn rbf(gamma: f64) -> Self {
        KernelType::RBF { gamma }
    }
    /// Laplacian kernel with the given `gamma`.
    pub fn laplacian(gamma: f64) -> Self {
        KernelType::Laplacian { gamma }
    }
    /// Sigmoid kernel with `gamma = 1.0` and `coef0 = 0.0`.
    pub fn sigmoid_default() -> Self {
        KernelType::Sigmoid {
            gamma: 1.0,
            coef0: 0.0,
        }
    }
}
/// Evaluates the kernel function `kernel` on a pair of vectors.
///
/// Elements that cannot be cast to `f64` are treated as `0.0`.
///
/// # Errors
/// Returns `TransformError::InvalidInput` when `x` and `y` differ in length.
pub fn kernel_eval<S1, S2>(
    x: &ArrayBase<S1, Ix1>,
    y: &ArrayBase<S2, Ix1>,
    kernel: &KernelType,
) -> Result<f64>
where
    S1: Data,
    S2: Data,
    S1::Elem: Float + NumCast,
    S2::Elem: Float + NumCast,
{
    if x.len() != y.len() {
        return Err(TransformError::InvalidInput(format!(
            "Vector dimensions must match: {} vs {}",
            x.len(),
            y.len()
        )));
    }
    // All five kernels reduce the element pairs with a single combining
    // function; this helper deduplicates the cast-and-accumulate loop that
    // was previously repeated in every match arm. Summation is sequential,
    // matching the original left-to-right accumulation order.
    let reduce = |combine: fn(f64, f64) -> f64| -> f64 {
        x.iter()
            .zip(y.iter())
            .map(|(&a, &b)| {
                let xi: f64 = NumCast::from(a).unwrap_or(0.0);
                let yi: f64 = NumCast::from(b).unwrap_or(0.0);
                combine(xi, yi)
            })
            .sum()
    };
    match kernel {
        // k(x, y) = <x, y>
        KernelType::Linear => Ok(reduce(|a, b| a * b)),
        // k(x, y) = (gamma * <x, y> + coef0)^degree
        KernelType::Polynomial {
            gamma,
            coef0,
            degree,
        } => {
            let dot = reduce(|a, b| a * b);
            Ok((gamma * dot + coef0).powi(*degree as i32))
        }
        // k(x, y) = exp(-gamma * ||x - y||^2)
        KernelType::RBF { gamma } => {
            let dist_sq = reduce(|a, b| (a - b) * (a - b));
            Ok((-gamma * dist_sq).exp())
        }
        // k(x, y) = exp(-gamma * ||x - y||_1)
        KernelType::Laplacian { gamma } => {
            let l1_dist = reduce(|a, b| (a - b).abs());
            Ok((-gamma * l1_dist).exp())
        }
        // k(x, y) = tanh(gamma * <x, y> + coef0)
        KernelType::Sigmoid { gamma, coef0 } => {
            let dot = reduce(|a, b| a * b);
            Ok((gamma * dot + coef0).tanh())
        }
    }
}
/// Computes the Gram (kernel) matrix of a sample set against itself.
///
/// Kernels are symmetric, so each pair is evaluated once on the upper
/// triangle and mirrored into the lower triangle.
///
/// # Errors
/// Propagates any error returned by [`kernel_eval`].
pub fn gram_matrix<S>(x: &ArrayBase<S, Ix2>, kernel: &KernelType) -> Result<Array2<f64>>
where
    S: Data,
    S::Elem: Float + NumCast,
{
    let n = x.nrows();
    let mut gram = Array2::zeros((n, n));
    for row in 0..n {
        // Upper triangle (diagonal included); lower half filled by symmetry.
        for col in row..n {
            let value = kernel_eval(&x.row(row), &x.row(col), kernel)?;
            gram[[row, col]] = value;
            gram[[col, row]] = value;
        }
    }
    Ok(gram)
}
/// Computes the cross-kernel matrix between two sample sets: entry
/// `(i, j)` is `kernel(x[i], y[j])`.
///
/// # Errors
/// Returns `InvalidInput` when the two sets have different feature
/// dimensions; otherwise propagates any error from [`kernel_eval`].
pub fn cross_gram_matrix<S1, S2>(
    x: &ArrayBase<S1, Ix2>,
    y: &ArrayBase<S2, Ix2>,
    kernel: &KernelType,
) -> Result<Array2<f64>>
where
    S1: Data,
    S2: Data,
    S1::Elem: Float + NumCast,
    S2::Elem: Float + NumCast,
{
    if x.ncols() != y.ncols() {
        return Err(TransformError::InvalidInput(format!(
            "Feature dimensions must match: {} vs {}",
            x.ncols(),
            y.ncols()
        )));
    }
    let rows = x.nrows();
    let cols = y.nrows();
    let mut result = Array2::zeros((rows, cols));
    for r in 0..rows {
        for c in 0..cols {
            result[[r, c]] = kernel_eval(&x.row(r), &y.row(c), kernel)?;
        }
    }
    Ok(result)
}
/// Double-centers a square kernel matrix (centering in feature space):
/// `K_c[i, j] = K[i, j] - colmean(j) - rowmean(i) + grandmean`.
///
/// # Errors
/// Returns `InvalidInput` when `k` is not square or is empty, and
/// `ComputationError` if a mean along an axis cannot be computed.
pub fn center_kernel_matrix(k: &Array2<f64>) -> Result<Array2<f64>> {
    let n = k.nrows();
    if n != k.ncols() {
        return Err(TransformError::InvalidInput(
            "Kernel matrix must be square".to_string(),
        ));
    }
    if n == 0 {
        return Err(TransformError::InvalidInput(
            "Kernel matrix must be non-empty".to_string(),
        ));
    }
    // Mean over Axis(0) averages down each column; mean over Axis(1)
    // averages across each row.
    let col_means = k.mean_axis(Axis(0)).ok_or_else(|| {
        TransformError::ComputationError("Failed to compute row means".to_string())
    })?;
    let row_means = k.mean_axis(Axis(1)).ok_or_else(|| {
        TransformError::ComputationError("Failed to compute column means".to_string())
    })?;
    let grand_mean = col_means.sum() / n as f64;
    let centered = Array2::from_shape_fn((n, n), |(i, j)| {
        k[[i, j]] - col_means[j] - row_means[i] + grand_mean
    });
    Ok(centered)
}
/// Centers a test-vs-train kernel matrix with the statistics of the
/// training Gram matrix, so test rows are centered consistently with
/// [`center_kernel_matrix`] applied to `k_train`.
///
/// # Errors
/// Returns `InvalidInput` when `k_train` is not square or when the column
/// count of `k_test` does not equal the training sample count;
/// `ComputationError` if an axis mean cannot be computed.
pub fn center_kernel_matrix_test(
    k_test: &Array2<f64>,
    k_train: &Array2<f64>,
) -> Result<Array2<f64>> {
    let n_train = k_train.nrows();
    let n_test = k_test.nrows();
    if k_train.nrows() != k_train.ncols() {
        return Err(TransformError::InvalidInput(
            "Training kernel matrix must be square".to_string(),
        ));
    }
    if k_test.ncols() != n_train {
        return Err(TransformError::InvalidInput(format!(
            "Test kernel matrix columns ({}) must match training samples ({})",
            k_test.ncols(),
            n_train
        )));
    }
    // Per-column means of the training Gram and per-row means of the test
    // block; both appear in the standard test-centering formula.
    let train_col_means = k_train.mean_axis(Axis(0)).ok_or_else(|| {
        TransformError::ComputationError("Failed to compute train column means".to_string())
    })?;
    let test_row_means = k_test.mean_axis(Axis(1)).ok_or_else(|| {
        TransformError::ComputationError("Failed to compute test row means".to_string())
    })?;
    let train_grand_mean = train_col_means.sum() / n_train as f64;
    let centered = Array2::from_shape_fn((n_test, n_train), |(i, j)| {
        k_test[[i, j]] - test_row_means[i] - train_col_means[j] + train_grand_mean
    });
    Ok(centered)
}
/// Estimates an RBF `gamma` with the median heuristic:
/// `gamma = 1 / (2 * median(||x_i - x_j||^2))` over all sample pairs.
///
/// Falls back to `1.0` when the median squared distance is (near) zero,
/// e.g. for duplicated samples.
///
/// # Errors
/// Returns `InvalidInput` when fewer than 2 samples are provided.
pub fn estimate_rbf_gamma<S>(x: &ArrayBase<S, Ix2>) -> Result<f64>
where
    S: Data,
    S::Elem: Float + NumCast,
{
    let n_samples = x.nrows();
    if n_samples < 2 {
        return Err(TransformError::InvalidInput(
            "Need at least 2 samples to estimate gamma".to_string(),
        ));
    }
    // Collect the squared Euclidean distance of every unordered pair.
    let mut sq_dists: Vec<f64> = Vec::with_capacity(n_samples * (n_samples - 1) / 2);
    for a in 0..n_samples {
        for b in (a + 1)..n_samples {
            let d2: f64 = (0..x.ncols())
                .map(|c| {
                    let va: f64 = NumCast::from(x[[a, c]]).unwrap_or(0.0);
                    let vb: f64 = NumCast::from(x[[b, c]]).unwrap_or(0.0);
                    let diff = va - vb;
                    diff * diff
                })
                .sum();
            sq_dists.push(d2);
        }
    }
    // NaNs compare as Equal, matching the original tolerant sort.
    sq_dists.sort_by(|p, q| p.partial_cmp(q).unwrap_or(std::cmp::Ordering::Equal));
    let m = sq_dists.len();
    let median_sq = if m % 2 == 0 {
        (sq_dists[m / 2 - 1] + sq_dists[m / 2]) / 2.0
    } else {
        sq_dists[m / 2]
    };
    if median_sq < 1e-15 {
        Ok(1.0)
    } else {
        Ok(1.0 / (2.0 * median_sq))
    }
}
/// Evaluates the kernel of every sample with itself, i.e. the diagonal of
/// the Gram matrix, without materializing the full matrix.
///
/// # Errors
/// Propagates any error returned by [`kernel_eval`].
pub fn kernel_diagonal<S>(x: &ArrayBase<S, Ix2>, kernel: &KernelType) -> Result<Array1<f64>>
where
    S: Data,
    S::Elem: Float + NumCast,
{
    let rows = x.nrows();
    let mut values = Vec::with_capacity(rows);
    for idx in 0..rows {
        values.push(kernel_eval(&x.row(idx), &x.row(idx), kernel)?);
    }
    Ok(Array1::from_vec(values))
}
/// Checks positive semidefiniteness by testing whether the smallest
/// eigenvalue of `k` is at least `tol`.
///
/// Pass a small negative `tol` (e.g. `-1e-10`) to tolerate numerical
/// round-off in matrices that are PSD in exact arithmetic.
///
/// # Errors
/// Returns `InvalidInput` when `k` is not square, or wraps the error from
/// `scirs2_linalg::eigh` if the eigendecomposition fails.
pub fn is_positive_semidefinite(k: &Array2<f64>, tol: f64) -> Result<bool> {
    if k.nrows() != k.ncols() {
        return Err(TransformError::InvalidInput(
            "Matrix must be square".to_string(),
        ));
    }
    // Only the eigenvalues are needed; the eigenvector matrix is discarded.
    // NOTE(review): `eigh` presumably assumes `k` is symmetric — confirm.
    let (eigenvalues, _) =
        scirs2_linalg::eigh(&k.view(), None).map_err(TransformError::LinalgError)?;
    let min_eigenvalue = eigenvalues.iter().copied().fold(f64::INFINITY, f64::min);
    Ok(min_eigenvalue >= tol)
}
pub fn kernel_alignment(k1: &Array2<f64>, k2: &Array2<f64>) -> Result<f64> {
if k1.dim() != k2.dim() {
return Err(TransformError::InvalidInput(
"Kernel matrices must have the same dimensions".to_string(),
));
}
let frobenius_inner: f64 = k1.iter().zip(k2.iter()).map(|(&a, &b)| a * b).sum();
let norm1: f64 = k1.iter().map(|&a| a * a).sum::<f64>().sqrt();
let norm2: f64 = k2.iter().map(|&a| a * a).sum::<f64>().sqrt();
let denom = norm1 * norm2;
if denom < 1e-15 {
Ok(0.0)
} else {
Ok((frobenius_inner / denom).clamp(0.0, 1.0))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;
    // Shared 5x3 fixture matrix used by most tests below.
    fn sample_data() -> Array2<f64> {
        Array::from_shape_vec(
            (5, 3),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
            ],
        )
        .expect("Failed to create sample data")
    }
    // <x, y> = 1*4 + 2*5 + 3*6 = 32.
    #[test]
    fn test_linear_kernel() {
        let x = Array::from_vec(vec![1.0, 2.0, 3.0]);
        let y = Array::from_vec(vec![4.0, 5.0, 6.0]);
        let result =
            kernel_eval(&x.view(), &y.view(), &KernelType::Linear).expect("kernel eval failed");
        assert!((result - 32.0).abs() < 1e-10);
    }
    // (1 * (1*3 + 2*4) + 1)^2 = 12^2 = 144.
    #[test]
    fn test_polynomial_kernel() {
        let x = Array::from_vec(vec![1.0, 2.0]);
        let y = Array::from_vec(vec![3.0, 4.0]);
        let kernel = KernelType::Polynomial {
            gamma: 1.0,
            coef0: 1.0,
            degree: 2,
        };
        let result = kernel_eval(&x.view(), &y.view(), &kernel).expect("kernel eval failed");
        assert!((result - 144.0).abs() < 1e-10);
    }
    // ||x - y||^2 = 2, so k = exp(-0.5 * 2) = e^-1.
    #[test]
    fn test_rbf_kernel() {
        let x = Array::from_vec(vec![1.0, 0.0]);
        let y = Array::from_vec(vec![0.0, 1.0]);
        let kernel = KernelType::RBF { gamma: 0.5 };
        let result = kernel_eval(&x.view(), &y.view(), &kernel).expect("kernel eval failed");
        assert!((result - (-1.0_f64).exp()).abs() < 1e-10);
    }
    // An RBF kernel of a vector with itself must be exactly 1.
    #[test]
    fn test_rbf_kernel_self() {
        let x = Array::from_vec(vec![1.0, 2.0, 3.0]);
        let kernel = KernelType::RBF { gamma: 1.0 };
        let result = kernel_eval(&x.view(), &x.view(), &kernel).expect("kernel eval failed");
        assert!((result - 1.0).abs() < 1e-10);
    }
    // ||x - y||_1 = 4, so k = exp(-0.5 * 4) = e^-2.
    #[test]
    fn test_laplacian_kernel() {
        let x = Array::from_vec(vec![1.0, 2.0]);
        let y = Array::from_vec(vec![3.0, 4.0]);
        let kernel = KernelType::Laplacian { gamma: 0.5 };
        let result = kernel_eval(&x.view(), &y.view(), &kernel).expect("kernel eval failed");
        assert!((result - (-2.0_f64).exp()).abs() < 1e-10);
    }
    // Orthogonal vectors: dot = 0, so tanh(1 * 0 + 0) = 0.
    #[test]
    fn test_sigmoid_kernel() {
        let x = Array::from_vec(vec![1.0, 0.0]);
        let y = Array::from_vec(vec![0.0, 1.0]);
        let kernel = KernelType::Sigmoid {
            gamma: 1.0,
            coef0: 0.0,
        };
        let result = kernel_eval(&x.view(), &y.view(), &kernel).expect("kernel eval failed");
        assert!((result - 0.0).abs() < 1e-10);
    }
    // The Gram matrix must equal its transpose entry-by-entry.
    #[test]
    fn test_gram_matrix_symmetry() {
        let data = sample_data();
        let kernel = KernelType::RBF { gamma: 0.1 };
        let k = gram_matrix(&data.view(), &kernel).expect("gram matrix failed");
        assert_eq!(k.shape(), &[5, 5]);
        for i in 0..5 {
            for j in 0..5 {
                assert!(
                    (k[[i, j]] - k[[j, i]]).abs() < 1e-10,
                    "Gram matrix not symmetric at ({}, {})",
                    i,
                    j
                );
            }
        }
    }
    // RBF of a sample with itself is 1, so the diagonal must be all ones.
    #[test]
    fn test_gram_matrix_diagonal() {
        let data = sample_data();
        let kernel = KernelType::RBF { gamma: 0.1 };
        let k = gram_matrix(&data.view(), &kernel).expect("gram matrix failed");
        for i in 0..5 {
            assert!(
                (k[[i, i]] - 1.0).abs() < 1e-10,
                "RBF diagonal should be 1.0"
            );
        }
    }
    // k[0, 0] = <(1, 2), (1.5, 2.5)> = 1.5 + 5.0 = 6.5.
    #[test]
    fn test_cross_gram_matrix() {
        let x = Array::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("Failed");
        let y = Array::from_shape_vec((2, 2), vec![1.5, 2.5, 3.5, 4.5]).expect("Failed");
        let kernel = KernelType::Linear;
        let k = cross_gram_matrix(&x.view(), &y.view(), &kernel).expect("cross gram matrix failed");
        assert_eq!(k.shape(), &[3, 2]);
        assert!((k[[0, 0]] - 6.5).abs() < 1e-10);
    }
    // Double-centering must zero out both the row and column means.
    #[test]
    fn test_center_kernel_matrix() {
        let data = sample_data();
        let kernel = KernelType::RBF { gamma: 0.01 };
        let k = gram_matrix(&data.view(), &kernel).expect("gram matrix failed");
        let k_centered = center_kernel_matrix(&k).expect("centering failed");
        let col_means = k_centered
            .mean_axis(Axis(0))
            .expect("Failed to compute means");
        for i in 0..col_means.len() {
            assert!(
                col_means[i].abs() < 1e-10,
                "Centered kernel column mean should be ~0, got {}",
                col_means[i]
            );
        }
        let row_means = k_centered
            .mean_axis(Axis(1))
            .expect("Failed to compute means");
        for i in 0..row_means.len() {
            assert!(
                row_means[i].abs() < 1e-10,
                "Centered kernel row mean should be ~0, got {}",
                row_means[i]
            );
        }
    }
    // Test-side centering: only shape and finiteness are asserted here.
    #[test]
    fn test_center_kernel_matrix_test() {
        let x_train = sample_data();
        let x_test =
            Array::from_shape_vec((2, 3), vec![1.5, 2.5, 3.5, 4.5, 5.5, 6.5]).expect("Failed");
        let kernel = KernelType::RBF { gamma: 0.01 };
        let k_train = gram_matrix(&x_train.view(), &kernel).expect("gram failed");
        let k_test =
            cross_gram_matrix(&x_test.view(), &x_train.view(), &kernel).expect("cross gram failed");
        let k_test_centered =
            center_kernel_matrix_test(&k_test, &k_train).expect("test centering failed");
        assert_eq!(k_test_centered.shape(), &[2, 5]);
        for val in k_test_centered.iter() {
            assert!(val.is_finite());
        }
    }
    // The median heuristic must produce a positive, finite gamma.
    #[test]
    fn test_estimate_rbf_gamma() {
        let data = sample_data();
        let gamma = estimate_rbf_gamma(&data.view()).expect("gamma estimation failed");
        assert!(gamma > 0.0);
        assert!(gamma.is_finite());
    }
    // Linear diagonal entry 0 is ||(1, 2, 3)||^2 = 14.
    #[test]
    fn test_kernel_diagonal() {
        let data = sample_data();
        let kernel = KernelType::Linear;
        let diag = kernel_diagonal(&data.view(), &kernel).expect("diagonal failed");
        assert_eq!(diag.len(), 5);
        assert!((diag[0] - 14.0).abs() < 1e-10);
    }
    // RBF Gram matrices are PSD; tolerance allows tiny negative eigenvalues.
    #[test]
    fn test_rbf_gram_psd() {
        let data = sample_data();
        let kernel = KernelType::RBF { gamma: 0.1 };
        let k = gram_matrix(&data.view(), &kernel).expect("gram matrix failed");
        let psd = is_positive_semidefinite(&k, -1e-10).expect("PSD check failed");
        assert!(psd, "RBF Gram matrix should be PSD");
    }
    // Alignment of a matrix with itself is exactly 1.
    #[test]
    fn test_kernel_alignment() {
        let data = sample_data();
        let k1 = gram_matrix(&data.view(), &KernelType::RBF { gamma: 0.1 }).expect("gram failed");
        let k2 = gram_matrix(&data.view(), &KernelType::RBF { gamma: 0.1 }).expect("gram failed");
        let alignment = kernel_alignment(&k1, &k2).expect("alignment failed");
        assert!(
            (alignment - 1.0).abs() < 1e-10,
            "Self-alignment should be 1.0, got {}",
            alignment
        );
    }
    // Different kernels: alignment is still within the clamped [0, 1] range.
    #[test]
    fn test_kernel_alignment_different() {
        let data = sample_data();
        let k1 = gram_matrix(&data.view(), &KernelType::RBF { gamma: 0.01 }).expect("gram failed");
        let k2 = gram_matrix(&data.view(), &KernelType::Linear).expect("gram failed");
        let alignment = kernel_alignment(&k1, &k2).expect("alignment failed");
        assert!(alignment >= 0.0 && alignment <= 1.0);
    }
    // rbf_auto must yield an RBF variant with a valid estimated gamma.
    #[test]
    fn test_rbf_auto() {
        let data = sample_data();
        let kernel = KernelType::rbf_auto(&data.view()).expect("auto rbf failed");
        match kernel {
            KernelType::RBF { gamma } => {
                assert!(gamma > 0.0);
                assert!(gamma.is_finite());
            }
            _ => panic!("Expected RBF kernel type"),
        }
    }
    // Mismatched vector lengths must be rejected, not silently truncated.
    #[test]
    fn test_dimension_mismatch() {
        let x = Array::from_vec(vec![1.0, 2.0]);
        let y = Array::from_vec(vec![1.0, 2.0, 3.0]);
        let result = kernel_eval(&x.view(), &y.view(), &KernelType::Linear);
        assert!(result.is_err());
    }
}