use crate::error::LinalgResult;
use crate::tensor_decomp::hosvd::hooi;
use crate::tensor_decomp::parafac::fit_als;
use crate::tensor_decomp::tensor_utils::{mode_n_product, Tensor3D};
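/// Tucker decomposition of a third-order tensor: a core tensor `g` of shape
/// `ranks` plus one factor matrix per mode, chosen so that multiplying `g`
/// by `factors[n]` along each mode `n` approximately reconstructs the input.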
#[derive(Debug, Clone)]
pub struct TuckerDecomp {
pub g: Tensor3D,
pub factors: [Vec<Vec<f64>>; 3],
pub ranks: [usize; 3],
}
impl TuckerDecomp {
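    /// Reconstructs the (approximate) full tensor by multiplying the core
    /// with each factor matrix along its corresponding mode.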
pub fn reconstruct(&self) -> LinalgResult<Tensor3D> {
let t1 = mode_n_product(&self.g, &self.factors[0], 0)?;
let t2 = mode_n_product(&t1, &self.factors[1], 1)?;
mode_n_product(&t2, &self.factors[2], 2)
}
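    /// Relative reconstruction error `||x - x_hat||_F / ||x||_F` against the
    /// reference tensor `x`. A zero reference gives 0 for an exact match and
    /// infinity otherwise.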
pub fn relative_error(&self, x: &Tensor3D) -> LinalgResult<f64> {
let xhat = self.reconstruct()?;
let diff_sq: f64 = x
.data
.iter()
.zip(xhat.data.iter())
.map(|(a, b)| (a - b).powi(2))
.sum();
let orig_sq: f64 = x.data.iter().map(|v| v * v).sum();
if orig_sq == 0.0 {
if diff_sq == 0.0 {
Ok(0.0)
} else {
Ok(f64::INFINITY)
}
} else {
Ok((diff_sq / orig_sq).sqrt())
}
}
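    /// Ratio of the original tensor's element count to that of the core plus
    /// all factor matrices; values above 1 mean the Tucker representation is
    /// smaller than the original.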
pub fn compress_ratio(&self, original_shape: [usize; 3]) -> f64 {
let original: usize = original_shape.iter().product();
let core: usize = self.g.shape.iter().product();
        // Total entries across the three factor matrices (rows x cols each).
        let factors: usize = self
            .factors
            .iter()
            .map(|f| f.first().map_or(0, |row| f.len() * row.len()))
            .sum();
let compressed = core + factors;
if compressed == 0 {
return f64::INFINITY;
}
original as f64 / compressed as f64
}
}
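/// Tucker decomposition of `x` with multilinear ranks `ranks`, delegating to
/// higher-order orthogonal iteration (`hooi`) with at most `max_iter` sweeps
/// and convergence tolerance `tol`.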
pub fn tucker_als(
x: &Tensor3D,
ranks: [usize; 3],
max_iter: usize,
tol: f64,
) -> LinalgResult<TuckerDecomp> {
    let hooi_result = hooi(x, ranks, max_iter, tol)?;
    Ok(TuckerDecomp {
        g: hooi_result.g,
        factors: hooi_result.u,
ranks,
})
}
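/// Core consistency diagnostic (CORCONDIA, Bro & Kiers 2003) for selecting a
/// CP rank: fits a rank-`rank` CP model, computes the Tucker core implied by
/// its factors, and scores its distance from the ideal superdiagonal core.
/// Values near 100 support the chosen rank; values near or below zero
/// indicate over-factoring.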
pub fn core_consistency_diagnostic(
x: &Tensor3D,
rank: usize,
max_iter: usize,
) -> LinalgResult<f64> {
let cp = fit_als(x, rank, max_iter, 1e-8)?;
    // The Tucker core implied by the CP factors is G = X x_1 A^+ x_2 B^+ x_3 C^+,
    // where A^+ is the Moore-Penrose pseudoinverse of each factor matrix.
    // An orthonormal projection (e.g. Gram-Schmidt on the factors) is not
    // enough here: it mis-scales G whenever the factors are not orthonormal,
    // driving CORCONDIA far from 100 even for a perfectly trilinear tensor.
    let a_pinv = pseudo_inverse(&cp.a);
    let b_pinv = pseudo_inverse(&cp.b);
    let c_pinv = pseudo_inverse(&cp.c);
    let g1 = mode_n_product(x, &a_pinv, 0)?;
    let g2 = mode_n_product(&g1, &b_pinv, 1)?;
    let g = mode_n_product(&g2, &c_pinv, 2)?;
let r = rank.min(g.shape[0]).min(g.shape[1]).min(g.shape[2]);
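    // Ideal core of a perfectly trilinear model: ones on the superdiagonal,
    // zeros elsewhere.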
let mut t = Tensor3D::zeros(g.shape);
for ri in 0..r {
t.set(ri, ri, ri, 1.0);
}
let diff_sq: f64 = g
.data
.iter()
.zip(t.data.iter())
.map(|(gi, ti)| (gi - ti).powi(2))
.sum();
let t_norm_sq: f64 = t.data.iter().map(|v| v * v).sum();
if t_norm_sq == 0.0 {
return Ok(0.0);
}
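    // CORCONDIA = 100 * (1 - ||G - T||_F^2 / ||T||_F^2), where ||T||_F^2 = r.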
let corcondia = 100.0 * (1.0 - diff_sq / t_norm_sq);
Ok(corcondia)
}
/// Moore-Penrose pseudoinverse `A^+ = (A^T A)^{-1} A^T` of an `m x r` matrix
/// with (near) full column rank, via Gauss-Jordan elimination on the normal
/// equations. Returns an `r x m` matrix.
fn pseudo_inverse(mat: &[Vec<f64>]) -> Vec<Vec<f64>> {
    if mat.is_empty() || mat[0].is_empty() {
        return Vec::new();
    }
    let m = mat.len();
    let r = mat[0].len();
    // Augmented system [A^T A | A^T]; reducing the left block to the
    // identity leaves (A^T A)^{-1} A^T = A^+ in the right block.
    let mut aug = vec![vec![0.0_f64; r + m]; r];
    for i in 0..r {
        for j in 0..r {
            aug[i][j] = (0..m).map(|k| mat[k][i] * mat[k][j]).sum();
        }
        for k in 0..m {
            aug[i][r + k] = mat[k][i];
        }
        // Tiny ridge so a (numerically) singular Gram matrix still inverts.
        aug[i][i] += 1e-12;
    }
    for col in 0..r {
        // Partial pivoting: move the largest remaining entry to the diagonal.
        let pivot = (col..r)
            .max_by(|&i, &j| aug[i][col].abs().total_cmp(&aug[j][col].abs()))
            .expect("pivot range is non-empty");
        aug.swap(col, pivot);
        let p = aug[col][col];
        for v in aug[col].iter_mut() {
            *v /= p;
        }
        let pivot_row = aug[col].clone();
        for (row, r_vec) in aug.iter_mut().enumerate() {
            if row != col && r_vec[col] != 0.0 {
                let f = r_vec[col];
                for (v, pv) in r_vec.iter_mut().zip(pivot_row.iter()) {
                    *v -= f * pv;
                }
            }
        }
    }
    aug.into_iter().map(|row| row[r..].to_vec()).collect()
}
#[cfg(test)]
mod tests {
use super::*;
fn make_tensor() -> Tensor3D {
let data: Vec<f64> = (0..60_usize).map(|x| x as f64 + 1.0).collect();
Tensor3D::new(data, [3, 4, 5]).expect("ok")
}
#[test]
fn test_tucker_als_shapes() {
let t = make_tensor();
let d = tucker_als(&t, [2, 3, 4], 20, 1e-8).expect("ok");
assert_eq!(d.g.shape, [2, 3, 4]);
assert_eq!(d.factors[0].len(), 3);
assert_eq!(d.factors[0][0].len(), 2);
assert_eq!(d.factors[1].len(), 4);
assert_eq!(d.factors[1][0].len(), 3);
assert_eq!(d.factors[2].len(), 5);
assert_eq!(d.factors[2][0].len(), 4);
}
#[test]
fn test_tucker_als_full_rank_lossless() {
let t = Tensor3D::new(
(0..27_usize).map(|x| x as f64 + 1.0).collect(),
[3, 3, 3],
)
.expect("ok");
let d = tucker_als(&t, [3, 3, 3], 10, 1e-12).expect("ok");
let err = d.relative_error(&t).expect("err");
assert!(err < 1e-7, "full-rank Tucker error {err:.2e}");
}
#[test]
fn test_tucker_als_reconstruction_error() {
let t = make_tensor();
let d = tucker_als(&t, [2, 2, 2], 30, 1e-10).expect("ok");
let err = d.relative_error(&t).expect("err");
assert!(err < 1.0, "Tucker reconstruction error {err:.4}");
}
#[test]
fn test_compress_ratio() {
        let t = make_tensor();
        let d = tucker_als(&t, [2, 3, 4], 20, 1e-8).expect("ok");
let ratio = d.compress_ratio([3, 4, 5]);
assert!(ratio > 0.0, "compress_ratio should be positive");
}
#[test]
fn test_tucker_factor_orthogonality() {
let t = make_tensor();
let d = tucker_als(&t, [2, 3, 4], 20, 1e-8).expect("ok");
for n in 0..3 {
let u = &d.factors[n];
let m = u.len();
let r = u[0].len();
for i in 0..r {
for j in 0..r {
let dot: f64 = (0..m).map(|k| u[k][i] * u[k][j]).sum();
let expected = if i == j { 1.0 } else { 0.0 };
assert!(
(dot - expected).abs() < 1e-7,
"mode {n}: U^TU[{i},{j}] = {dot:.3e}"
);
}
}
}
}
#[test]
fn test_core_consistency_rank1() {
let a = [1.0_f64, 2.0, 3.0];
let b = [1.0_f64, 2.0];
let c = [1.0_f64, 2.0, 3.0, 4.0];
let data: Vec<f64> = (0..3)
.flat_map(|i| {
(0..2_usize).flat_map(move |j| (0..4_usize).map(move |k| a[i] * b[j] * c[k]))
})
.collect();
let x = Tensor3D::new(data, [3, 2, 4]).expect("ok");
let cc = core_consistency_diagnostic(&x, 1, 200).expect("cc ok");
assert!(
cc > 80.0,
"rank-1 tensor CORCONDIA for rank=1 should be near 100, got {cc:.1}"
);
}
}