use crate::error::{LinalgError, LinalgResult};
use crate::tensor_decomp::tensor_utils::{
    gram, hadamard, mat_mul, mat_transpose, solve_spd, truncated_svd, Tensor3D,
};
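/// A rank-`R` CP (CANDECOMP/PARAFAC) decomposition of a third-order tensor:
/// `X ≈ Σ_r lambda[r] * a[:, r] ∘ b[:, r] ∘ c[:, r]`, where `a`, `b`, and `c`
/// hold unit-norm factor columns and `lambda` carries the component weights.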
#[derive(Debug, Clone)]
pub struct CPDecomp {
pub a: Vec<Vec<f64>>,
pub b: Vec<Vec<f64>>,
pub c: Vec<Vec<f64>>,
pub rank: usize,
pub lambda: Vec<f64>,
}
impl CPDecomp {
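    /// Rebuilds the dense tensor entrywise as
    /// `x[i][j][k] = Σ_r lambda[r] * a[i][r] * b[j][r] * c[k][r]`.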
pub fn reconstruct(&self) -> LinalgResult<Tensor3D> {
let i_dim = self.a.len();
let j_dim = self.b.len();
let k_dim = self.c.len();
let mut result = Tensor3D::zeros([i_dim, j_dim, k_dim]);
for r in 0..self.rank {
let lam = self.lambda[r];
for i in 0..i_dim {
for j in 0..j_dim {
for k in 0..k_dim {
let v = result.get(i, j, k)
+ lam * self.a[i][r] * self.b[j][r] * self.c[k][r];
result.set(i, j, k, v);
}
}
}
}
Ok(result)
}
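    /// Relative Frobenius error `||X - X̂||_F / ||X||_F`. For a zero input
    /// tensor this returns 0.0 on a perfect match and infinity otherwise.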
pub fn relative_error(&self, x: &Tensor3D) -> LinalgResult<f64> {
let xhat = self.reconstruct()?;
let diff_sq: f64 = x
.data
.iter()
.zip(xhat.data.iter())
.map(|(a, b)| (a - b).powi(2))
.sum();
let orig_sq: f64 = x.data.iter().map(|v| v * v).sum();
if orig_sq == 0.0 {
if diff_sq == 0.0 {
Ok(0.0)
} else {
Ok(f64::INFINITY)
}
} else {
Ok((diff_sq / orig_sq).sqrt())
}
}
}
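/// Fits a rank-`rank` CP decomposition to `x` with unregularised alternating
/// least squares (ALS), stopping after `max_iter` sweeps or once the relative
/// reconstruction error changes by less than `tol` between sweeps.
///
/// A minimal usage sketch (the shape and data here are illustrative):
///
/// ```ignore
/// let x = Tensor3D::new(data, [4, 5, 6])?;  // row-major entries
/// let decomp = fit_als(&x, 2, 200, 1e-8)?;  // rank-2 fit
/// let rel_err = decomp.relative_error(&x)?; // relative Frobenius error
/// ```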
pub fn fit_als(
x: &Tensor3D,
rank: usize,
max_iter: usize,
tol: f64,
) -> LinalgResult<CPDecomp> {
fit_als_impl(x, rank, max_iter, tol, 0.0)
}
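/// ALS variant that adds a ridge penalty `lambda_reg` to the diagonal of each
/// normal-equation Gram matrix, stabilising the solves when factor columns are
/// nearly collinear. The convergence tolerance is fixed at `1e-10`.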
pub fn fit_sparse_als(
x: &Tensor3D,
rank: usize,
max_iter: usize,
lambda_reg: f64,
) -> LinalgResult<CPDecomp> {
fit_als_impl(x, rank, max_iter, 1e-10, lambda_reg)
}
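/// Shared ALS driver: initialises the factors from truncated SVDs of the mode
/// unfoldings (HOSVD-style), then alternately re-solves each factor's
/// regularised normal equations until convergence or `max_iter`.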
fn fit_als_impl(
x: &Tensor3D,
rank: usize,
max_iter: usize,
tol: f64,
lambda_reg: f64,
) -> LinalgResult<CPDecomp> {
if rank == 0 {
return Err(LinalgError::DomainError("CP rank must be ≥ 1".to_string()));
}
if max_iter == 0 {
return Err(LinalgError::DomainError(
"max_iter must be ≥ 1".to_string(),
));
}
let [i_dim, j_dim, k_dim] = x.shape;
let x_norm = x.frobenius_norm();
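    // Unfold X along each mode once; the unfoldings seed the SVD
    // initialisation and are reused by every ALS sweep.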
    let x_unfold_0 = x.mode_unfold(0)?;
    let x_unfold_1 = x.mode_unfold(1)?;
    let x_unfold_2 = x.mode_unfold(2)?;
    let (u0, _, _) = truncated_svd(&x_unfold_0, rank)?;
    let (u1, _, _) = truncated_svd(&x_unfold_1, rank)?;
    let (u2, _, _) = truncated_svd(&x_unfold_2, rank)?;
    let mut a = pad_or_trim_columns(u0, i_dim, rank);
    let mut b = pad_or_trim_columns(u1, j_dim, rank);
    let mut c = pad_or_trim_columns(u2, k_dim, rank);
    let mut lambda = vec![1.0_f64; rank];
let mut prev_err = f64::INFINITY;
for _iter in 0..max_iter {
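        // Update A with B and C fixed: A = X_(0)(C ⊙ B)(CᵀC ∘ BᵀB + ridge·I)⁻¹.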
        let cb = Tensor3D::khatri_rao(&c, &b)?;
        let gram_cb = gram_hadamard(&c, &b)?;
        let gram_cb_reg = add_ridge(&gram_cb, lambda_reg);
        let rhs_a = mat_mul(&x_unfold_0, &cb)?;
        a = solve_spd_rows(&gram_cb_reg, &rhs_a)?;
        normalise_columns(&mut a, &mut lambda);
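        // Update B with A and C fixed, using the mode-1 unfolding.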
        let ca = Tensor3D::khatri_rao(&c, &a)?;
        let gram_ca = gram_hadamard(&c, &a)?;
let gram_ca_reg = add_ridge(&gram_ca, lambda_reg);
let rhs_b = mat_mul(&x_unfold_1, &ca)?;
b = solve_spd_rows(&gram_ca_reg, &rhs_b)?;
normalise_columns(&mut b, &mut lambda);
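        // Update C with A and B fixed, using the mode-2 unfolding.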
        let ba = Tensor3D::khatri_rao(&b, &a)?;
        let gram_ba = gram_hadamard(&b, &a)?;
let gram_ba_reg = add_ridge(&gram_ba, lambda_reg);
let rhs_c = mat_mul(&x_unfold_2, &ba)?;
c = solve_spd_rows(&gram_ba_reg, &rhs_c)?;
normalise_columns(&mut c, &mut lambda);
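        // Stop early once the relative error changes by less than `tol`.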
if tol > 0.0 && x_norm > 0.0 {
let err = reconstruction_error_fast(x, &a, &b, &c, &lambda)?;
let rel = err / x_norm;
if (prev_err - rel).abs() < tol {
break;
}
prev_err = rel;
}
}
Ok(CPDecomp {
a,
b,
c,
rank,
lambda,
})
}
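/// Forces `mat` to exactly `rows` x `cols`, zero-padding (or truncating) rows
/// and columns; used when a truncated SVD yields fewer than `cols` columns.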
fn pad_or_trim_columns(mut mat: Vec<Vec<f64>>, rows: usize, cols: usize) -> Vec<Vec<f64>> {
mat.resize_with(rows, || vec![0.0_f64; cols]);
for row in mat.iter_mut() {
row.resize(cols, 0.0);
}
mat
}
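/// Hadamard product of two Gram matrices; this equals the Gram matrix of the
/// corresponding Khatri-Rao product and is the left-hand side of each ALS
/// normal-equation solve.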
fn gram_hadamard(a: &[Vec<f64>], b: &[Vec<f64>]) -> LinalgResult<Vec<Vec<f64>>> {
let ga = gram(a)?;
let gb = gram(b)?;
hadamard(&ga, &gb)
}
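/// Returns a copy of `mat` with `ridge` added to the diagonal (Tikhonov
/// regularisation).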
fn add_ridge(mat: &[Vec<f64>], ridge: f64) -> Vec<Vec<f64>> {
let mut m = mat.to_vec();
for (i, row) in m.iter_mut().enumerate() {
row[i] += ridge;
}
m
}
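/// Solves the SPD system `G · Fᵀ = RHSᵀ` and returns `F = RHS · G⁻¹`, the
/// row-major factor update used by each ALS step.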
fn solve_spd_rows(
    gram_mat: &[Vec<f64>],
    rhs: &[Vec<f64>],
) -> LinalgResult<Vec<Vec<f64>>> {
    let rhs_t = mat_transpose(rhs);
    let sol_t = solve_spd(gram_mat, &rhs_t)?;
    Ok(mat_transpose(&sol_t))
}
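/// Rescales every column of `mat` to unit 2-norm and folds each norm into the
/// matching entry of `lambda`; columns with norm below `1e-15` are skipped.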
fn normalise_columns(mat: &mut [Vec<f64>], lambda: &mut [f64]) {
let rows = mat.len();
if rows == 0 {
return;
}
let cols = mat[0].len();
for c in 0..cols {
let mut norm_sq = 0.0_f64;
for r in 0..rows {
norm_sq += mat[r][c] * mat[r][c];
}
let norm = norm_sq.sqrt();
if norm > 1e-15 {
lambda[c] *= norm;
for r in 0..rows {
mat[r][c] /= norm;
}
}
}
}
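/// Computes `||X - X̂||_F` without materialising `X̂`: `||X̂||²` comes from the
/// Gram-matrix identity `λᵀ (AᵀA ∘ BᵀB ∘ CᵀC) λ`, and the inner product
/// `<X, X̂>` from a single pass over the entries of `X`.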
fn reconstruction_error_fast(
x: &Tensor3D,
a: &[Vec<f64>],
b: &[Vec<f64>],
c: &[Vec<f64>],
lambda: &[f64],
) -> LinalgResult<f64> {
let rank = lambda.len();
let [i_dim, j_dim, k_dim] = x.shape;
let ga = gram(a)?;
let gb = gram(b)?;
let gc = gram(c)?;
let gabgc = hadamard(&hadamard(&ga, &gb)?, &gc)?;
let mut xhat_norm_sq = 0.0_f64;
for r in 0..rank {
for s in 0..rank {
xhat_norm_sq += lambda[r] * lambda[s] * gabgc[r][s];
}
}
let mut inner = 0.0_f64;
for i in 0..i_dim {
for j in 0..j_dim {
for k in 0..k_dim {
let xval = x.get(i, j, k);
let mut approx = 0.0_f64;
for r in 0..rank {
approx += lambda[r] * a[i][r] * b[j][r] * c[k][r];
}
inner += xval * approx;
}
}
}
let x_norm_sq: f64 = x.data.iter().map(|v| v * v).sum();
let err_sq = (x_norm_sq - 2.0 * inner + xhat_norm_sq).max(0.0);
Ok(err_sq.sqrt())
}
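/// Convenience wrapper around [`fit_als`] under the conventional `cp_als`
/// name.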
pub fn cp_als(
x: &Tensor3D,
rank: usize,
max_iter: usize,
tol: f64,
) -> LinalgResult<CPDecomp> {
fit_als(x, rank, max_iter, tol)
}
#[cfg(test)]
mod tests {
use super::*;
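    /// Builds a 4x5x6 tensor that is exactly rank 2: two weighted outer
    /// products of standard basis vectors.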
fn make_rank2_tensor() -> Tensor3D {
let i = 4usize;
let j = 5usize;
let k = 6usize;
let a1 = vec![1.0, 0.0, 0.0, 0.0];
let b1 = vec![0.0, 1.0, 0.0, 0.0, 0.0];
let c1 = vec![0.0, 0.0, 1.0, 0.0, 0.0, 0.0];
let a2 = vec![0.0, 1.0, 0.0, 0.0];
let b2 = vec![0.0, 0.0, 1.0, 0.0, 0.0];
let c2 = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0];
let mut data = vec![0.0_f64; i * j * k];
for ii in 0..i {
for jj in 0..j {
for kk in 0..k {
data[ii * j * k + jj * k + kk] =
a1[ii] * b1[jj] * c1[kk] + 2.0 * a2[ii] * b2[jj] * c2[kk];
}
}
}
Tensor3D::new(data, [i, j, k]).expect("ok")
}
#[test]
fn test_cp_reconstruct_rank2() {
let x = make_rank2_tensor();
let decomp = fit_als(&x, 2, 200, 1e-8).expect("als ok");
let err = decomp.relative_error(&x).expect("err ok");
assert!(
err < 0.1,
"CP reconstruction error {err:.4} >= 0.1"
);
}
#[test]
fn test_cp_sparse_als() {
let x = make_rank2_tensor();
let decomp = fit_sparse_als(&x, 2, 200, 1e-4).expect("sparse als ok");
let err = decomp.relative_error(&x).expect("err ok");
assert!(
err < 0.15,
"CP sparse ALS error {err:.4} >= 0.15"
);
}
#[test]
fn test_cp_rank_1() {
let a = [1.0_f64, 2.0, 3.0];
let b = [4.0_f64, 5.0];
let c = [6.0_f64, 7.0, 8.0];
let data: Vec<f64> = (0..3)
.flat_map(|i| {
(0..2).flat_map(move |j| (0..3).map(move |k| a[i] * b[j] * c[k]))
})
.collect();
let x = Tensor3D::new(data, [3, 2, 3]).expect("ok");
let decomp = fit_als(&x, 1, 100, 1e-10).expect("als ok");
let err = decomp.relative_error(&x).expect("err ok");
assert!(err < 1e-6, "Rank-1 error {err:.2e} >= 1e-6");
}
#[test]
fn test_cp_error_high_rank() {
let data: Vec<f64> = (0..60).map(|v| v as f64).collect();
let x = Tensor3D::new(data, [3, 4, 5]).expect("ok");
let decomp = fit_als(&x, 12, 300, 1e-8).expect("als ok");
let err = decomp.relative_error(&x).expect("err ok");
assert!(err < 0.05, "over-param error {err:.4} >= 0.05");
}
#[test]
fn test_cp_lambda_positive() {
let data: Vec<f64> = (0..24).map(|v| (v as f64) + 1.0).collect();
let x = Tensor3D::new(data, [2, 3, 4]).expect("ok");
let decomp = fit_als(&x, 3, 100, 1e-6).expect("als ok");
for (r, &lam) in decomp.lambda.iter().enumerate() {
assert!(lam >= 0.0, "lambda[{r}] = {lam} < 0");
}
}
}