use tensorlogic_sklears_kernels::{
deep_kernel::{Activation, DeepKernelBuilder},
Kernel, RbfKernel, RbfKernelConfig,
};
/// Returns `true` when `g` is square and `|g[i][j] - g[j][i]|` does not
/// exceed `tol` for every index pair `(i, j)`.
///
/// Any row whose length differs from the number of rows makes the result
/// `false`. All row lengths are validated before a single element is read, so
/// ragged input is rejected instead of triggering an out-of-bounds panic when
/// the transposed entry `g[j][i]` is looked up.
///
/// Note: the comparison is `difference > tol`, so a NaN difference is never
/// reported as an asymmetry.
fn is_symmetric(g: &[Vec<f64>], tol: f64) -> bool {
    let n = g.len();

    // Shape check first: the element scan below indexes `g[j][i]` for rows
    // that the outer iteration has not reached yet.
    if g.iter().any(|row| row.len() != n) {
        return false;
    }

    let asymmetric = g.iter().enumerate().any(|(i, row)| {
        row.iter()
            .enumerate()
            .any(|(j, &g_ij)| (g_ij - g[j][i]).abs() > tol)
    });
    !asymmetric
}
/// Computes the lower-triangular Cholesky factor `L` of `g` (the factor for
/// which `L * Lᵀ` matches `g` on its lower triangle).
///
/// Only the lower triangle of `g` (entries `g[i][j]` with `j <= i`) is read;
/// the strict upper triangle of the returned matrix is left at zero.
///
/// # Errors
///
/// Returns a descriptive message when
/// - row `i` of `g` has fewer than `i + 1` entries (previously a panic),
/// - a diagonal pivot is NaN (previously accepted silently, producing a
///   NaN-filled factor and a false "PSD" verdict),
/// - a diagonal pivot is `<= 0`, i.e. the factorization cannot proceed.
// The index loops mirror the textbook recurrence, which addresses `g` and `l`
// by (row, column) at the same time.
#[allow(clippy::needless_range_loop)]
fn cholesky(g: &[Vec<f64>]) -> Result<Vec<Vec<f64>>, String> {
    let n = g.len();
    let mut l = vec![vec![0.0f64; n]; n];

    for i in 0..n {
        if g[i].len() <= i {
            return Err(format!(
                "row {} has {} entries, need at least {} for the lower triangle",
                i,
                g[i].len(),
                i + 1
            ));
        }

        for j in 0..=i {
            // g[i][j] minus the contributions of the columns already factored,
            // subtracted one at a time in column order.
            let reduced = l[i][..j]
                .iter()
                .zip(&l[j][..j])
                .fold(g[i][j], |acc, (a, b)| acc - a * b);

            if i != j {
                l[i][j] = reduced / l[j][j];
                continue;
            }

            // Any NaN in row `i` (from `g` or from earlier columns) ends up in
            // this pivot, so checking here keeps NaN out of the whole result.
            if reduced.is_nan() {
                return Err(format!("not PSD: pivot is NaN at index {}", i));
            }
            if reduced <= 0.0 {
                return Err(format!(
                    "not PSD: leading minor ({}) ≤ 0 at index {}",
                    reduced, i
                ));
            }
            l[i][j] = reduced.sqrt();
        }
    }

    Ok(l)
}
/// The Gram matrix produced by a deep kernel over six 2-D points must be
/// symmetric, have a unit diagonal, and admit a Cholesky factorization once a
/// tiny ridge is added to the diagonal.
#[test]
fn deep_kernel_gram_is_symmetric_and_psd() {
    // Three points near (-1, -1) followed by three points near (1, 1).
    let xs: Vec<Vec<f64>> = [
        [-1.0, -1.0],
        [-0.8, -0.9],
        [-1.1, -0.7],
        [0.9, 1.0],
        [1.1, 0.8],
        [1.0, 0.9],
    ]
    .iter()
    .map(|point| point.to_vec())
    .collect();

    // Feature network 2 -> 4 (tanh) -> 3 (identity) feeding an RBF base kernel.
    let base = RbfKernel::new(RbfKernelConfig::new(0.75)).expect("valid gamma");
    let dkl = DeepKernelBuilder::new()
        .input_dim(2)
        .hidden_layer(4, Activation::Tanh)
        .output_dim(3, Activation::Identity)
        .seed(0x1337)
        .build(base)
        .expect("valid topology");

    let gram = dkl
        .compute_symmetric_gram(&xs)
        .expect("symmetric gram succeeds");

    let tol = 1e-12;

    assert!(
        is_symmetric(&gram, tol),
        "DKL Gram must be symmetric: {:?}",
        gram
    );

    // Every point compared with itself is expected to score exactly 1
    // (presumably because the base kernel is RBF; confirm against `RbfKernel`).
    gram.iter()
        .enumerate()
        .map(|(i, row)| (i, row[i]))
        .for_each(|(i, diag)| {
            assert!(
                (diag - 1.0).abs() < tol,
                "diagonal {} should be 1, got {}",
                i,
                diag
            );
        });

    // Nudge the diagonal by a small ridge so that rounding noise alone cannot
    // make the factorization fail.
    let ridge = 1e-9;
    let mut ridged = gram.clone();
    ridged
        .iter_mut()
        .enumerate()
        .for_each(|(i, row)| row[i] += ridge);

    let _factor = cholesky(&ridged).expect("Gram + εI must be PSD");
}
/// Every entry of the symmetric Gram matrix must agree (to 1e-12) with the
/// value obtained by calling `compute` on the corresponding pair of inputs.
#[test]
fn deep_kernel_matches_direct_composition() {
    let xs: Vec<Vec<f64>> = vec![vec![0.0, 1.0], vec![1.0, 0.0], vec![-1.0, -1.0]];

    // Feature network 2 -> 5 (ReLU) -> 2 (identity) feeding an RBF base kernel.
    let base = RbfKernel::new(RbfKernelConfig::new(0.5)).expect("valid gamma");
    let dkl = DeepKernelBuilder::new()
        .input_dim(2)
        .hidden_layer(5, Activation::ReLU)
        .output_dim(2, Activation::Identity)
        .seed(99)
        .build(base)
        .expect("valid");

    let gram = dkl.compute_symmetric_gram(&xs).expect("gram");

    // All ordered index pairs (i, j), in row-major order, with their inputs.
    let pairs = xs.iter().enumerate().flat_map(|(i, xi)| {
        xs.iter()
            .enumerate()
            .map(move |(j, xj)| (i, j, xi, xj))
    });

    for (i, j, xi, xj) in pairs {
        let direct = dkl.compute(xi, xj).expect("direct");
        let from_gram = gram[i][j];
        assert!(
            (from_gram - direct).abs() < 1e-12,
            "pair ({},{}) mismatch: gram={}, direct={}",
            i,
            j,
            from_gram,
            direct
        );
    }
}