use crate::error::{LinalgError, LinalgResult};
use crate::tensor_decomp::tensor_utils::{mat_transpose, mode_n_product, truncated_svd, Tensor3D};
/// Result of a higher-order SVD (Tucker-style) decomposition of a 3-way tensor.
///
/// The original tensor is approximated by multiplying the core `g` along each
/// mode by the corresponding factor matrix in `u` (see `reconstruct`).
#[derive(Debug, Clone)]
pub struct HOSVDDecomp {
// Core tensor; its shape equals `ranks`.
pub g: Tensor3D,
// Factor matrices, one per mode: `u[n]` has `shape[n]` rows and `ranks[n]` columns.
pub u: [Vec<Vec<f64>>; 3],
// Multilinear (per-mode) ranks used for the decomposition.
pub ranks: [usize; 3],
}
impl HOSVDDecomp {
    /// Rebuild the (approximate) tensor by applying each factor matrix to the
    /// core along its mode: `G ×₀ U0 ×₁ U1 ×₂ U2`.
    pub fn reconstruct(&self) -> LinalgResult<Tensor3D> {
        let mut t = mode_n_product(&self.g, &self.u[0], 0)?;
        t = mode_n_product(&t, &self.u[1], 1)?;
        mode_n_product(&t, &self.u[2], 2)
    }

    /// Relative reconstruction error `‖x − x̂‖_F / ‖x‖_F` against `x`.
    ///
    /// Returns `0.0` when both `x` and the residual are zero, and
    /// `f64::INFINITY` when `x` is zero but the reconstruction is not.
    pub fn relative_error(&self, x: &Tensor3D) -> LinalgResult<f64> {
        let xhat = self.reconstruct()?;
        let diff_sq: f64 = x
            .data
            .iter()
            .zip(&xhat.data)
            .map(|(a, b)| (a - b) * (a - b))
            .sum();
        let orig_sq: f64 = x.data.iter().map(|v| v * v).sum();
        if orig_sq == 0.0 {
            // Zero tensor: error is 0 only if the reconstruction is also zero.
            return Ok(if diff_sq == 0.0 { 0.0 } else { f64::INFINITY });
        }
        Ok((diff_sq / orig_sq).sqrt())
    }
}
/// Full (untruncated) HOSVD: every mode keeps its complete rank, so the
/// decomposition is lossless up to floating-point error.
pub fn hosvd(x: &Tensor3D) -> LinalgResult<HOSVDDecomp> {
    hosvd_truncated(x, x.shape)
}
/// Truncated higher-order SVD (HOSVD) of a 3-way tensor.
///
/// For each mode `n`, the factor `U_n` holds the leading `ranks[n]` left
/// singular vectors of the mode-`n` unfolding of `x`; the core tensor is
/// `G = X ×₀ U0ᵀ ×₁ U1ᵀ ×₂ U2ᵀ`.
///
/// # Errors
/// Returns `LinalgError::DomainError` when any requested rank is zero or
/// exceeds the matching tensor dimension; errors from the unfolding, SVD, and
/// mode-product helpers are propagated.
pub fn hosvd_truncated(x: &Tensor3D, ranks: [usize; 3]) -> LinalgResult<HOSVDDecomp> {
    for n in 0..3 {
        if ranks[n] == 0 {
            return Err(LinalgError::DomainError(format!(
                "hosvd_truncated: ranks[{n}] must be ≥ 1"
            )));
        }
        if ranks[n] > x.shape[n] {
            return Err(LinalgError::DomainError(format!(
                "hosvd_truncated: ranks[{n}]={} > shape[{n}]={}",
                ranks[n], x.shape[n]
            )));
        }
    }
    // Leading left singular vectors of each mode-n unfolding. Building the
    // array directly avoids the previous Vec + repeated remove(0) shuffle.
    let factor = |n: usize| -> LinalgResult<Vec<Vec<f64>>> {
        let unfolding = x.mode_unfold(n)?;
        let (u, _, _) = truncated_svd(&unfolding, ranks[n])?;
        Ok(u)
    };
    let u = [factor(0)?, factor(1)?, factor(2)?];
    // Core tensor: project x onto the factor column spaces, mode by mode.
    let g0 = mode_n_product(x, &mat_transpose(&u[0]), 0)?;
    let g1 = mode_n_product(&g0, &mat_transpose(&u[1]), 1)?;
    let g = mode_n_product(&g1, &mat_transpose(&u[2]), 2)?;
    Ok(HOSVDDecomp { g, u, ranks })
}
/// Higher-order orthogonal iteration (HOOI): refines a truncated HOSVD by
/// alternating optimization over the three factor matrices.
///
/// Starting from the truncated-HOSVD factors, each sweep replaces `U_n` with
/// the leading `ranks[n]` left singular vectors of `x` projected along the
/// other two modes. Iteration stops after `max_iter` sweeps or once the total
/// per-sweep factor change (as measured by `subspace_change`) drops below
/// `tol`.
///
/// # Errors
/// Propagates errors from `hosvd_truncated` (including its rank validation)
/// and from the unfolding / SVD / mode-product helpers.
pub fn hooi(
    x: &Tensor3D,
    ranks: [usize; 3],
    max_iter: usize,
    tol: f64,
) -> LinalgResult<HOSVDDecomp> {
    let mut us = hosvd_truncated(x, ranks)?.u;
    for _iter in 0..max_iter {
        let mut delta = 0.0_f64;
        for n in 0..3_usize {
            // Project x along every mode except n...
            let mut y = x.clone();
            for m in (0..3).filter(|&m| m != n) {
                y = mode_n_product(&y, &mat_transpose(&us[m]), m)?;
            }
            // ...then take the dominant left subspace of the mode-n unfolding.
            let (u_new, _, _) = truncated_svd(&y.mode_unfold(n)?, ranks[n])?;
            delta += subspace_change(&us[n], &u_new);
            us[n] = u_new;
        }
        if delta < tol {
            break;
        }
    }
    // Core tensor from the final factors.
    let g1 = mode_n_product(x, &mat_transpose(&us[0]), 0)?;
    let g2 = mode_n_product(&g1, &mat_transpose(&us[1]), 1)?;
    let g = mode_n_product(&g2, &mat_transpose(&us[2]), 2)?;
    // `us` already has the field's array type; move it — the previous version
    // deep-cloned all three factor matrices for no reason.
    Ok(HOSVDDecomp { g, u: us, ranks })
}
/// Frobenius norm of the element-wise difference between two factor matrices,
/// used as a per-mode convergence measure for HOOI.
///
/// Rows are compared pairwise and each pair only over the columns both rows
/// actually have, so ragged input can no longer cause an out-of-bounds panic
/// (the old version sized every row from row 0). Returns `f64::INFINITY` when
/// the matrices cannot be compared (different or zero row counts), which
/// forces another HOOI sweep.
fn subspace_change(u_old: &[Vec<f64>], u_new: &[Vec<f64>]) -> f64 {
    if u_old.len() != u_new.len() || u_old.is_empty() {
        return f64::INFINITY;
    }
    u_old
        .iter()
        .zip(u_new)
        .map(|(row_old, row_new)| {
            row_old
                .iter()
                .zip(row_new)
                .map(|(a, b)| (b - a) * (b - a))
                .sum::<f64>()
        })
        .sum::<f64>()
        .sqrt()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic 3×4×5 tensor with entry (i, j, k) = (i + 1)(j + 1) + k.
    fn make_tensor() -> Tensor3D {
        let mut data = Vec::with_capacity(60);
        for i in 0..3_usize {
            for j in 0..4_usize {
                for k in 0..5_usize {
                    data.push(((i + 1) * (j + 1)) as f64 + k as f64);
                }
            }
        }
        Tensor3D::new(data, [3, 4, 5]).expect("ok")
    }

    #[test]
    fn test_hosvd_shapes() {
        let t = make_tensor();
        let d = hosvd_truncated(&t, [2, 3, 4]).expect("ok");
        assert_eq!(d.g.shape, [2, 3, 4]);
        // Each factor U_n must be shape[n] rows by ranks[n] columns.
        for (n, &(rows, cols)) in [(3, 2), (4, 3), (5, 4)].iter().enumerate() {
            assert_eq!(d.u[n].len(), rows);
            assert_eq!(d.u[n][0].len(), cols);
        }
    }

    #[test]
    fn test_hosvd_full_rank_lossless() {
        let t = make_tensor();
        let err = hosvd(&t)
            .expect("full rank ok")
            .relative_error(&t)
            .expect("err ok");
        assert!(err < 1e-8, "HOSVD full-rank error {err:.2e}");
    }

    #[test]
    fn test_hosvd_truncated_reduces_error_with_rank() {
        let t = make_tensor();
        let e1 = hosvd_truncated(&t, [1, 1, 1])
            .expect("rank-1")
            .relative_error(&t)
            .expect("e1");
        let e2 = hosvd_truncated(&t, [2, 2, 2])
            .expect("rank-2")
            .relative_error(&t)
            .expect("e2");
        assert!(e2 <= e1 + 1e-10, "rank-2 error {e2} > rank-1 error {e1}");
    }

    #[test]
    fn test_hooi_shapes() {
        let t = make_tensor();
        let d = hooi(&t, [2, 3, 4], 20, 1e-8).expect("hooi ok");
        assert_eq!(d.g.shape, [2, 3, 4]);
        for (n, &rows) in [3_usize, 4, 5].iter().enumerate() {
            assert_eq!(d.u[n].len(), rows);
        }
    }

    #[test]
    fn test_hooi_better_or_equal_to_hosvd() {
        let t = make_tensor();
        let e_hosvd = hosvd_truncated(&t, [2, 2, 2])
            .expect("hosvd")
            .relative_error(&t)
            .expect("e_hosvd");
        let e_hooi = hooi(&t, [2, 2, 2], 30, 1e-10)
            .expect("hooi")
            .relative_error(&t)
            .expect("e_hooi");
        assert!(
            e_hooi <= e_hosvd + 1e-6,
            "HOOI error {e_hooi} > HOSVD error {e_hosvd}"
        );
    }

    #[test]
    fn test_hooi_full_rank_lossless() {
        let data: Vec<f64> = (0..27_usize).map(|x| x as f64 + 1.0).collect();
        let t = Tensor3D::new(data, [3, 3, 3]).expect("ok");
        let err = hooi(&t, [3, 3, 3], 10, 1e-12)
            .expect("ok")
            .relative_error(&t)
            .expect("err");
        assert!(err < 1e-7, "full-rank HOOI error {err:.2e}");
    }

    #[test]
    fn test_factor_orthogonality() {
        let t = make_tensor();
        let d = hosvd_truncated(&t, [2, 3, 4]).expect("ok");
        for n in 0..3 {
            let u = &d.u[n];
            let (m, r) = (u.len(), u[0].len());
            // Columns of each factor should be orthonormal: UᵀU = I.
            for i in 0..r {
                for j in 0..r {
                    let dot: f64 = (0..m).map(|k| u[k][i] * u[k][j]).sum();
                    let expected = if i == j { 1.0 } else { 0.0 };
                    assert!(
                        (dot - expected).abs() < 1e-8,
                        "mode {n}: U^TU[{i},{j}] = {dot:.3e}, expected {expected}"
                    );
                }
            }
        }
    }
}