use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{Array2, ArrayD, ArrayView, Dimension};
use scirs2_core::numeric::{Float, NumAssign, Zero};
use std::fmt::Debug;
use std::iter::Sum;
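/// A tensor in Tucker format: a core tensor plus one factor matrix per mode.
///
/// The represented tensor is `G ×_1 U^(1) ×_2 U^(2) ⋯ ×_N U^(N)`, where `×_n`
/// denotes the mode-n product, `G` is the core, and `U^(n)` are the factors.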
#[derive(Debug, Clone)]
pub struct Tucker<A>
where
A: Clone + Float + Debug,
{
    /// Core tensor `G`, one dimension per mode.
    pub core: ArrayD<A>,
    /// Factor matrices `U^(n)`, one per mode; `factors[n]` has shape
    /// `(shape[n], core.shape()[n])`.
    pub factors: Vec<Array2<A>>,
    /// Shape of the full tensor reconstructed by `to_full`.
    pub shape: Vec<usize>,
}
impl<A> Tucker<A>
where
A: Clone
+ Float
+ NumAssign
+ Zero
+ Debug
+ Sum
+ Send
+ Sync
+ 'static
+ scirs2_core::ndarray::ScalarOperand,
{
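    /// Builds a Tucker tensor from a core and factor matrices, validating
    /// that factor `n` has exactly `core.shape()[n]` columns. When `shape`
    /// is `None`, it is inferred from the factor row counts.
    ///
    /// A minimal construction sketch (assuming `scirs2_core::ndarray`
    /// re-exports ndarray's `arr2`, `Array`, and `IxDyn`, as the other
    /// imports in this module suggest):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::{arr2, Array, IxDyn};
    ///
    /// // 2x2 core and factors of shapes 3x2 and 2x2.
    /// let core = Array::from_shape_vec(IxDyn(&[2, 2]), vec![1.0, 0.0, 0.0, 1.0])
    ///     .expect("valid shape");
    /// let factors = vec![
    ///     arr2(&[[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]),
    ///     arr2(&[[1.0, 0.0], [0.0, 1.0]]),
    /// ];
    /// let tucker = Tucker::new(core, factors, None).expect("valid Tucker");
    /// assert_eq!(tucker.shape, vec![3, 2]); // inferred from factor rows
    /// ```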
pub fn new(
core: ArrayD<A>,
factors: Vec<Array2<A>>,
shape: Option<Vec<usize>>,
) -> LinalgResult<Self> {
if factors.is_empty() {
return Err(LinalgError::ValueError(
"Factor matrices list cannot be empty".to_string(),
));
}
if core.ndim() != factors.len() {
return Err(LinalgError::ShapeError(format!(
"Number of factor matrices ({}) must match core tensor dimensionality ({})",
factors.len(),
core.ndim()
)));
}
for (i, factor) in factors.iter().enumerate() {
if factor.ndim() != 2 {
return Err(LinalgError::ShapeError(format!(
"Factor matrix {} must be 2-dimensional, got {} dimensions",
i,
factor.ndim()
)));
}
if factor.shape()[1] != core.shape()[i] {
return Err(LinalgError::ShapeError(format!(
"Factor matrix {} columns ({}) must match core tensor dimension {} ({})",
i,
factor.shape()[1],
i,
core.shape()[i]
)));
}
}
let shape = match shape {
Some(s) => {
if s.len() != factors.len() {
return Err(LinalgError::ShapeError(format!(
"Shape length ({}) must match number of factor matrices ({})",
s.len(),
factors.len()
)));
}
for (i, &dim) in s.iter().enumerate() {
if dim != factors[i].shape()[0] {
return Err(LinalgError::ShapeError(format!(
"Shape dimension {} ({}) must match factor matrix {} rows ({})",
i,
dim,
i,
factors[i].shape()[0]
)));
}
}
s
}
None => factors.iter().map(|f| f.shape()[0]).collect(),
};
Ok(Tucker {
core,
factors,
shape,
})
}
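    /// Reconstructs the full tensor by applying each factor along its mode:
    /// `X = G ×_1 U^(1) ×_2 U^(2) ⋯ ×_N U^(N)`.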
pub fn to_full(&self) -> LinalgResult<ArrayD<A>> {
use crate::tensor_contraction::mode_n_product;
let mut result = self.core.clone();
for (mode, factor) in self.factors.iter().enumerate() {
result = mode_n_product(&result.view(), &factor.view(), mode)?;
}
Ok(result)
}
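    /// Returns the relative reconstruction error in the Frobenius norm,
    /// `‖X − X̂‖_F / ‖X‖_F`, where `X̂` is `self.to_full()`. A zero tensor
    /// reconstructed exactly yields `0`; a zero tensor with a nonzero
    /// reconstruction yields `+∞`.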
pub fn reconstruction_error(&self, tensor: &ArrayView<A, impl Dimension>) -> LinalgResult<A> {
let reconstructed = self.to_full()?;
let tensor_dyn = tensor.to_owned().into_dyn();
if tensor_dyn.shape() != reconstructed.shape() {
return Err(LinalgError::ShapeError(format!(
"Original tensor shape {:?} does not match reconstructed tensor shape {:?}",
tensor_dyn.shape(),
reconstructed.shape()
)));
}
let mut diff_squared_sum = A::zero();
let mut orig_squared_sum = A::zero();
for (idx, &orig_val) in tensor_dyn.indexed_iter() {
let rec_val = reconstructed[idx.clone()];
diff_squared_sum += (orig_val - rec_val).powi(2);
orig_squared_sum += orig_val.powi(2);
}
if orig_squared_sum == A::zero() {
return Ok(if diff_squared_sum == A::zero() {
A::zero()
} else {
A::infinity()
});
}
Ok((diff_squared_sum / orig_squared_sum).sqrt())
}
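    /// Compresses the decomposition by truncating each factor matrix.
    ///
    /// With explicit `ranks`, each factor is truncated to the given rank.
    /// With `epsilon`, each mode keeps the leading singular values whose
    /// ratio to the largest singular value is at least `epsilon` (always at
    /// least one). When both are given, the `epsilon`-selected rank is
    /// capped by the corresponding entry of `ranks`; with neither, the
    /// decomposition is returned unchanged.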
pub fn compress(&self, ranks: Option<Vec<usize>>, epsilon: Option<A>) -> LinalgResult<Self> {
        use super::svd_truncated;
        use scirs2_core::parallel_ops::*;
if let Some(ref r) = ranks {
if r.len() != self.factors.len() {
return Err(LinalgError::ShapeError(format!(
"Ranks length ({}) must match number of factor matrices ({})",
r.len(),
self.factors.len()
)));
}
}
if let Some(eps) = epsilon {
if eps <= A::zero() {
return Err(LinalgError::ValueError(
"Epsilon must be positive".to_string(),
));
}
}
if ranks.is_none() && epsilon.is_none() {
return Ok(self.clone());
}
        // Count how many leading normalized singular values of `factor` are
        // at least `eps`; always keep at least one.
        let significant_rank = |factor: &Array2<A>, eps: A| -> usize {
            let (_, s, _) =
                svd_truncated(factor, factor.shape()[1]).expect("SVD of factor matrix failed");
            // Normalize by the largest singular value so `eps` is a relative threshold.
            let s_norm = if !s.is_empty() && s[[0, 0]] > A::zero() {
                s.mapv(|v| v / s[[0, 0]])
            } else {
                s.clone()
            };
            let mut count = 0;
            for i in 0..s_norm.shape()[0] {
                if s_norm[[i, i]] >= eps {
                    count += 1;
                } else {
                    break;
                }
            }
            count.max(1)
        };
        let target_ranks: Vec<usize> = match (&ranks, epsilon) {
            (Some(r), None) => r.clone(),
            (None, Some(eps)) => self
                .factors
                .par_iter()
                .map(|factor| significant_rank(factor, eps))
                .collect(),
            (Some(r), Some(eps)) => self
                .factors
                .par_iter()
                .zip(r.par_iter())
                .map(|(factor, &max_rank)| significant_rank(factor, eps).min(max_rank))
                .collect(),
            (None, None) => unreachable!("handled by the early return above"),
        };
let compressed_factors: Vec<Array2<A>> = self
.factors
.par_iter()
.zip(target_ranks.par_iter())
.map(|(factor, &rank)| {
let rank = rank.min(factor.shape()[1]);
let (u, _, _) = svd_truncated(factor, rank).expect("SVD of factor matrix failed");
u
})
.collect();
let mut compressed_core = self.core.clone();
        #[allow(clippy::needless_range_loop)]
        for mode in 0..compressed_factors.len() {
            // Re-express the core in the new bases: the mode-n transform
            // `U_new^T · U_old` maps old core coordinates to compressed ones.
            let transform = compressed_factors[mode].t().dot(&self.factors[mode]);
compressed_core = crate::tensor_contraction::mode_n_product(
&compressed_core.view(),
&transform.view(),
mode,
)?;
}
Tucker::new(
compressed_core,
compressed_factors,
Some(self.shape.clone()),
)
}
}
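/// Computes a Tucker decomposition of `tensor` with the given per-mode
/// `ranks` via higher-order SVD (HOSVD).
///
/// A minimal usage sketch (shapes mirror the tests below; error handling is
/// up to the caller):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let tensor = array![
///     [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
///     [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
/// ];
/// let tucker = tucker_decomposition(&tensor.view(), &[2, 2, 2]).expect("HOSVD failed");
/// assert_eq!(tucker.core.shape(), &[2, 2, 2]);
/// ```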
#[allow(dead_code)]
pub fn tucker_decomposition<A, D>(
tensor: &ArrayView<A, D>,
ranks: &[usize],
) -> LinalgResult<Tucker<A>>
where
A: Clone
+ Float
+ NumAssign
+ Zero
+ Debug
+ Sum
+ Send
+ Sync
+ 'static
+ scirs2_core::ndarray::ScalarOperand,
D: Dimension,
{
use super::hosvd;
if ranks.len() != tensor.ndim() {
return Err(LinalgError::ShapeError(format!(
"Ranks length ({}) must match tensor dimensionality ({})",
ranks.len(),
tensor.ndim()
)));
}
for (i, &rank) in ranks.iter().enumerate() {
if rank > tensor.shape()[i] {
return Err(LinalgError::ShapeError(format!(
"Rank for mode {} ({}) cannot exceed the mode dimension ({})",
i,
rank,
tensor.shape()[i]
)));
}
if rank == 0 {
return Err(LinalgError::ValueError(format!(
"Rank for mode {} must be positive",
i
)));
}
}
let (core, factors) = hosvd(tensor, ranks)?;
Tucker::new(core, factors, Some(tensor.shape().to_vec()))
}
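/// Refines a HOSVD initialization with alternating least squares (ALS): each
/// sweep re-fits one factor at a time from the mode-n unfolding of the
/// tensor and recomputes the core. Iteration stops after `max_iterations`
/// sweeps or once the relative improvement of the reconstruction error drops
/// below `tolerance`, e.g. `tucker_als(&tensor.view(), &[2, 2, 2], 10, 1e-4)`.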
#[allow(dead_code)]
pub fn tucker_als<A, D>(
tensor: &ArrayView<A, D>,
ranks: &[usize],
max_iterations: usize,
tolerance: A,
) -> LinalgResult<Tucker<A>>
where
A: Clone
+ Float
+ NumAssign
+ Zero
+ Debug
+ Sum
+ Send
+ Sync
+ 'static
+ scirs2_core::ndarray::ScalarOperand,
D: Dimension,
{
use super::mode_n_product;
use crate::decomposition::svd;
let mut tucker = tucker_decomposition(tensor, ranks)?;
let tensor_dyn = tensor.to_owned().into_dyn();
    // Track the reconstruction error across ALS sweeps for the convergence check.
    let mut prev_error = tucker.reconstruction_error(tensor)?;
for iteration in 0..max_iterations {
#[allow(clippy::needless_range_loop)]
    for mode in 0..ranks.len() {
        // Fit factor `mode` from the mode-n unfolded tensor multiplied by the
        // (transposed) projection of the core through the other factors.
        let tensor_unfolded = unfold_tensor(&tensor_dyn, mode)?;
        let khatri_rao_product =
            compute_khatri_rao_product(&tucker.factors, mode, &tucker.core)?;
        let tensor_result = tensor_unfolded.dot(&khatri_rao_product.t());
        // The leading left singular vectors give the new orthonormal factor.
        let (u, _, _) = svd(&tensor_result.view(), false, None)?;
        let new_factor = u
            .slice(scirs2_core::ndarray::s![.., ..ranks[mode]])
            .to_owned();
        tucker.factors[mode] = new_factor;
        // Recompute the core with the updated factor set: G = X ×_m U^(m)T.
        let mut temp_core = tensor_dyn.clone();
        for m in 0..tucker.factors.len() {
            let factor_t = tucker.factors[m].t().to_owned();
            temp_core = mode_n_product(&temp_core.view(), &factor_t.view(), m)?;
        }
        tucker.core = temp_core;
    }
    let error = tucker.reconstruction_error(tensor)?;
    // Stop once the relative improvement falls below `tolerance`; guard
    // against dividing by a zero previous error (already exact).
    if prev_error == A::zero() {
        break;
    }
    let rel_improvement = (prev_error - error) / prev_error;
    if rel_improvement < tolerance && iteration > 0 {
        break;
    }
    prev_error = error;
}
Ok(tucker)
}
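/// Mode-n unfolding (matricization): maps a tensor of shape
/// `(d_0, …, d_{N-1})` to a matrix of shape `(d_mode, ∏_{i≠mode} d_i)`.
/// Columns are ordered row-major over the remaining modes, with the last
/// dimension varying fastest; e.g. for a 2x3 matrix, the mode-1 unfolding is
/// its transpose.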
#[allow(dead_code)]
fn unfold_tensor<A>(tensor: &ArrayD<A>, mode: usize) -> LinalgResult<Array2<A>>
where
A: Clone + Float + NumAssign + Zero + Debug + Send + Sync + 'static,
{
let shape = tensor.shape();
if mode >= shape.len() {
return Err(LinalgError::ShapeError(format!(
"Mode {} is out of bounds for _tensor with {} dimensions",
mode,
shape.len()
)));
}
let mode_dim = shape[mode];
let other_dims_prod: usize = shape
.iter()
.enumerate()
.filter(|&(i, _)| i != mode)
.map(|(_, &dim)| dim)
.product();
let mut result = Array2::zeros((mode_dim, other_dims_prod));
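    // Column index of a multi-index in the mode-n unfolding: row-major over
    // all dimensions except `mode`, with the last dimension varying fastest.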
fn calc_col_idx(idx: &[usize], shape: &[usize], mode: usize) -> usize {
let mut col_idx = 0;
let mut stride = 1;
for dim in (0..shape.len()).rev() {
if dim != mode {
col_idx += idx[dim] * stride;
stride *= shape[dim];
}
}
col_idx
}
for idx in scirs2_core::ndarray::indices(shape) {
let mode_idx = idx[mode];
let idx_vec: Vec<usize> = idx.as_array_view().to_vec();
let col_idx = calc_col_idx(&idx_vec, shape, mode);
result[[mode_idx, col_idx]] = tensor[idx.clone()];
}
Ok(result)
}
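/// Despite its name, this does not form a literal Khatri-Rao (column-wise
/// Kronecker) product: it computes the mode-`skip_mode` unfolding of the core
/// projected through every factor except the skipped one, i.e.
/// `(G ×_{m≠n} U^(m))_(n)`, which plays the same role in the ALS update.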
#[allow(dead_code)]
fn compute_khatri_rao_product<A>(
factors: &[Array2<A>],
skip_mode: usize,
core: &ArrayD<A>,
) -> LinalgResult<Array2<A>>
where
A: Clone + Float + NumAssign + Zero + Debug + Send + Sync + 'static,
{
use crate::tensor_contraction::mode_n_product;
    // Apply every factor except the skipped mode, then unfold along it.
let mut projected_tensor = core.clone();
for (mode, factor) in factors.iter().enumerate() {
if mode == skip_mode {
continue;
}
projected_tensor = mode_n_product(&projected_tensor.view(), &factor.view(), mode)?;
}
let result = unfold_tensor(&projected_tensor, skip_mode)?;
Ok(result)
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
use scirs2_core::ndarray::array;
#[test]
#[ignore = "SVD fails for small matrices due to unimplemented eigendecomposition"]
fn test_tucker_decomposition_basic() {
let tensor = array![
[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
[[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
];
let tucker = tucker_decomposition(&tensor.view(), &[2, 3, 2]).expect("Operation failed");
assert_eq!(tucker.core.shape(), &[2, 3, 2]);
assert_eq!(tucker.factors.len(), 3);
assert_eq!(tucker.factors[0].shape(), &[2, 2]);
assert_eq!(tucker.factors[1].shape(), &[3, 3]);
assert_eq!(tucker.factors[2].shape(), &[2, 2]);
let reconstructed = tucker.to_full().expect("Operation failed");
for i in 0..2 {
for j in 0..3 {
for k in 0..2 {
assert_abs_diff_eq!(
reconstructed[[i, j, k]],
tensor[[i, j, k]],
epsilon = 1e-10
);
}
}
}
}
#[test]
#[ignore = "SVD fails for small matrices due to unimplemented eigendecomposition"]
fn test_tucker_decomposition_truncated() {
let tensor = array![
[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
[[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
];
let tucker = tucker_decomposition(&tensor.view(), &[2, 2, 2]).expect("Operation failed");
assert_eq!(tucker.core.shape(), &[2, 2, 2]);
assert_eq!(tucker.factors.len(), 3);
assert_eq!(tucker.factors[0].shape(), &[2, 2]);
assert_eq!(tucker.factors[1].shape(), &[3, 2]);
assert_eq!(tucker.factors[2].shape(), &[2, 2]);
let _reconstructed = tucker.to_full().expect("Operation failed");
let error = tucker
.reconstruction_error(&tensor.view())
.expect("Operation failed");
assert!(error > 0.0);
        assert!(error < 0.1);
    }
#[test]
#[ignore = "SVD fails for small matrices due to unimplemented eigendecomposition"]
fn test_tucker_als() {
let tensor = array![
[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
[[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
];
let tucker = tucker_als(&tensor.view(), &[2, 2, 2], 10, 1e-4).expect("Operation failed");
assert_eq!(tucker.core.shape(), &[2, 2, 2]);
assert_eq!(tucker.factors.len(), 3);
assert_eq!(tucker.factors[0].shape(), &[2, 2]);
assert_eq!(tucker.factors[1].shape(), &[3, 2]);
assert_eq!(tucker.factors[2].shape(), &[2, 2]);
let _reconstructed = tucker.to_full().expect("Operation failed");
let hosvd_tucker =
tucker_decomposition(&tensor.view(), &[2, 2, 2]).expect("Operation failed");
let als_error = tucker
.reconstruction_error(&tensor.view())
.expect("Operation failed");
let hosvd_error = hosvd_tucker
.reconstruction_error(&tensor.view())
.expect("Operation failed");
        assert!(als_error <= hosvd_error * 1.001);
    }
#[test]
#[ignore = "SVD fails for small matrices due to unimplemented eigendecomposition"]
fn test_compress() {
let tensor = array![
[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
[[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
];
let tucker = tucker_decomposition(&tensor.view(), &[2, 3, 2]).expect("Operation failed");
let compressed = tucker
.compress(Some(vec![2, 2, 2]), None)
.expect("Operation failed");
assert_eq!(compressed.core.shape(), &[2, 2, 2]);
assert_eq!(compressed.factors[0].shape(), &[2, 2]);
assert_eq!(compressed.factors[1].shape(), &[3, 2]);
assert_eq!(compressed.factors[2].shape(), &[2, 2]);
let compressed_eps = tucker.compress(None, Some(0.1)).expect("Operation failed");
for factor in &compressed_eps.factors {
assert!(factor.shape()[1] >= 1);
}
}
}