use burn::tensor::{ElementConversion, Tensor, backend::Backend};
/// Settings consumed by [`Whitening`].
#[derive(Debug, Clone)]
pub struct WhiteningConfig {
    /// Ridge term added to the diagonal of the covariance matrix before its inverse square
    /// root is taken; keeps the transform finite for rank-deficient or near-singular data.
    pub eps: f64,
    /// Which whitening transform [`Whitening::forward`] applies.
    pub method: WhiteningMethod,
}
/// The whitening transform applied by [`Whitening`].
///
/// The `None` variant shadows `Option::None` if this enum's variants are glob-imported;
/// spell it `WhiteningMethod::None` at use sites.
#[derive(Debug, Clone, Copy)]
pub enum WhiteningMethod {
    /// Zero-phase whitening: center the columns, then right-multiply by
    /// `(Σ + eps·I)^{-1/2}`, where `Σ` is the sample covariance.
    Zca,
    /// Principal-component whitening.
    ///
    /// NOTE(review): the PCA code path in this file computes the same transform as `Zca`
    /// (no projection onto the covariance eigenbasis). Confirm before relying on it.
    Pca,
    /// Identity: the input is returned untouched (no centering, no scaling).
    None,
}
impl Default for WhiteningConfig {
    /// ZCA whitening with a `1e-5` diagonal ridge on the covariance.
    fn default() -> Self {
        let method = WhiteningMethod::Zca;
        let eps = 1.0e-5;
        Self { eps, method }
    }
}
/// Applies the whitening transform described by a [`WhiteningConfig`].
///
/// The only state is the configuration. Mean and covariance are recomputed from every
/// batch handed to `forward`, and nothing is retained between calls.
#[derive(Debug, Clone)]
pub struct Whitening {
    cfg: WhiteningConfig,
}
impl Whitening {
    /// Number of coupled Newton–Schulz iterations used to approximate `A^{-1/2}`.
    ///
    /// In the iteration's early phase, small eigenvalues of the normalised matrix grow only
    /// by a constant factor per step (about 2.25x in the product `Z·Y`). Poorly conditioned
    /// covariances therefore need far more than a handful of steps before the quadratically
    /// convergent phase begins.
    const NEWTON_SCHULZ_ITERS: usize = 30;

    /// Creates a whitening transform with the given configuration.
    pub fn new(cfg: WhiteningConfig) -> Self {
        Self { cfg }
    }

    /// Whitens `x` according to the configured [`WhiteningMethod`].
    ///
    /// Rows of `x` are treated as samples and columns as features. The column means and the
    /// `d x d` covariance are computed from `x` itself.
    ///
    /// - `WhiteningMethod::None` returns `x` unchanged.
    /// - `WhiteningMethod::Zca` returns `(x - mean) · (Σ + eps·I)^{-1/2}`, where
    ///   `Σ = (x - mean)ᵀ (x - mean) / n`.
    /// - `WhiteningMethod::Pca` currently returns the same result as `Zca` (see `pca_whiten`).
    pub fn forward<B: Backend>(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        match self.cfg.method {
            WhiteningMethod::None => x,
            WhiteningMethod::Zca => self.zca_whiten(x),
            WhiteningMethod::Pca => self.pca_whiten(x),
        }
    }

    /// ZCA whitening: center the columns of `x`, then right-multiply by the (approximate)
    /// inverse square root of the ridge-regularised covariance.
    fn zca_whiten<B: Backend>(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        let (x_centered, cov) = Self::center_and_covariance(x);
        let [d, _] = cov.dims();
        let device = cov.device();
        // Ridge keeps the matrix positive definite even for rank-deficient data.
        let ridge = Tensor::<B, 2>::eye(d, &device).mul_scalar(self.cfg.eps);
        let whitening_matrix = self.inverse_sqrt_symmetric(cov + ridge);
        x_centered.matmul(whitening_matrix)
    }

    /// PCA whitening entry point.
    ///
    /// NOTE(review): true PCA whitening projects onto the covariance eigenvectors
    /// (`x_c · U · Λ^{-1/2}`). That requires an eigendecomposition, which is not available
    /// here. This path currently returns exactly the ZCA result, which differs from PCA
    /// output by the orthogonal factor `Uᵀ`. It is still white (unit covariance up to `eps`)
    /// but is not expressed in the principal-component basis. TODO: implement the
    /// eigenbasis projection.
    fn pca_whiten<B: Backend>(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        self.zca_whiten(x)
    }

    /// Subtracts the per-column mean from `x` and returns `(x_centered, cov)`.
    ///
    /// `cov` is the `d x d` population covariance `x_centeredᵀ · x_centered / n`.
    fn center_and_covariance<B: Backend>(x: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>) {
        let [n, _] = x.dims();
        // `mean_dim` keeps the reduced axis, so `mean` is already `[1, d]` and broadcasts
        // against the `[n, d]` input.
        let mean = x.clone().mean_dim(0);
        let x_centered = x - mean;
        let cov = x_centered
            .clone()
            .transpose()
            .matmul(x_centered.clone())
            .div_scalar(n as f64);
        (x_centered, cov)
    }

    /// Approximates `A^{-1/2}` for a symmetric positive-definite `A` using the coupled
    /// Newton–Schulz iteration.
    ///
    /// 1. `A` is scaled by its Frobenius norm `s`. For a symmetric PD matrix this places
    ///    every eigenvalue of `A / s` in `(0, 1]`, inside the iteration's convergence region
    ///    `‖I - A/s‖ < 1`.
    /// 2. Starting from `Y_0 = A / s` and `Z_0 = I`, each step computes
    ///    `T = (3I - Z·Y) / 2`, `Y ← Y·T`, `Z ← T·Z`.
    /// 3. This drives `Z → (A / s)^{-1/2}`, so the result is `Z / √s`.
    ///
    /// If `A` is the zero matrix (`s == 0`, possible only when `eps == 0` and the data has
    /// zero variance), the division produces non-finite values.
    fn inverse_sqrt_symmetric<B: Backend>(&self, a: Tensor<B, 2>) -> Tensor<B, 2> {
        let [d, _] = a.dims();
        let device = a.device();
        let fro_norm = a
            .clone()
            .powf_scalar(2.0)
            .sum()
            .sqrt()
            .into_scalar()
            .elem::<f64>();

        let mut y = a.div_scalar(fro_norm);
        let mut z = Tensor::<B, 2>::eye(d, &device);
        // Loop-invariant `3I`, built once instead of on every iteration.
        let three_i = z.clone().mul_scalar(3.0);

        for _ in 0..Self::NEWTON_SCHULZ_ITERS {
            let t = (three_i.clone() - z.clone().matmul(y.clone())).div_scalar(2.0);
            y = y.matmul(t.clone());
            z = t.matmul(z);
        }

        z.div_scalar(fro_norm.sqrt())
    }
}