use crate::decomposition::svd;
use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayD, ArrayView, Dimension};
use scirs2_core::numeric::{Float, NumAssign, Zero};
use std::fmt::Debug;
use std::iter::Sum;
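/// Canonical Polyadic (CP, also PARAFAC) decomposition of a tensor.
///
/// An N-way tensor `X` is represented as a sum of `R` rank-one terms,
/// `X ≈ Σ_r w_r · a_r ∘ b_r ∘ … ∘ z_r`, where `∘` denotes the outer product,
/// the vectors `a_r, b_r, …` are the columns of the factor matrices (one
/// matrix per mode, each with `R` columns), and the optional `w_r` are
/// per-component weights.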
#[derive(Debug, Clone)]
pub struct CanonicalPolyadic<A>
where
A: Clone + Float + Debug,
{
/// Factor matrices, one per tensor mode; each has `rank` columns.
pub factors: Vec<Array2<A>>,
/// Optional per-component weights of length `rank`.
pub weights: Option<Array1<A>>,
/// Shape of the represented tensor (the row counts of the factors).
pub shape: Vec<usize>,
}
impl<A> CanonicalPolyadic<A>
where
A: Clone + Float + NumAssign + Zero + Debug + Sum + Send + Sync + 'static,
{
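/// Builds a CP decomposition from factor matrices, optional per-component
/// weights, and an optional explicit shape.
///
/// Every factor matrix must have the same number of columns (the rank), the
/// weights vector, if given, must have `rank` entries, and the shape, if
/// given, must match the factor row counts. When no shape is supplied it is
/// inferred from the factors.
///
/// # Example
///
/// A minimal sketch, not compiled as a doctest (a two-mode "tensor" is just
/// a matrix):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let factors = vec![array![[1.0_f64], [2.0]], array![[3.0_f64], [4.0], [5.0]]];
/// let cp = CanonicalPolyadic::new(factors, None, None)?;
/// assert_eq!(cp.shape, vec![2, 3]);
/// ```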
pub fn new(
factors: Vec<Array2<A>>,
weights: Option<Array1<A>>,
shape: Option<Vec<usize>>,
) -> LinalgResult<Self> {
if factors.is_empty() {
return Err(LinalgError::ValueError(
"Factor matrices list cannot be empty".to_string(),
));
}
let rank = factors[0].shape()[1];
for (i, factor) in factors.iter().enumerate().skip(1) {
if factor.shape()[1] != rank {
return Err(LinalgError::ShapeError(format!(
"All factor matrices must have the same number of columns (rank). Factor 0 has {} columns, but factor {} has {} columns.",
rank,
i,
factor.shape()[1]
)));
}
}
if let Some(ref w) = weights {
if w.len() != rank {
return Err(LinalgError::ShapeError(format!(
"Weights length ({}) must match decomposition rank ({})",
w.len(),
rank
)));
}
}
let shape = match shape {
Some(s) => {
if s.len() != factors.len() {
return Err(LinalgError::ShapeError(format!(
"Shape length ({}) must match number of factor matrices ({})",
s.len(),
factors.len()
)));
}
for (i, &dim) in s.iter().enumerate() {
if dim != factors[i].shape()[0] {
return Err(LinalgError::ShapeError(format!(
"Shape dimension {} ({}) must match factor matrix {} rows ({})",
i,
dim,
i,
factors[i].shape()[0]
)));
}
}
s
}
None => factors.iter().map(|f| f.shape()[0]).collect(),
};
Ok(CanonicalPolyadic {
factors,
weights,
shape,
})
}
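/// Reconstructs the dense tensor by summing the weighted rank-one terms.
/// The cost is proportional to `rank * shape.iter().product()`, so this is
/// intended for small tensors and for testing.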
pub fn to_full(&self) -> LinalgResult<ArrayD<A>> {
let rank = self.factors[0].shape()[1];
let mut result = ArrayD::zeros(self.shape.clone());
// Enumerate every multi-index of the output tensor once, up front; the same
// index list is reused for every rank-one component below.
fn generate_indices(
shape: &[usize],
current: Vec<usize>,
depth: usize,
all_indices: &mut Vec<Vec<usize>>,
) {
if depth == shape.len() {
all_indices.push(current);
return;
}
let mut current = current;
for i in 0..shape[depth] {
current.push(i);
generate_indices(shape, current.clone(), depth + 1, all_indices);
current.pop();
}
}
let mut indices_vec = Vec::new();
generate_indices(&self.shape, Vec::new(), 0, &mut indices_vec);
for r in 0..rank {
let weight = match &self.weights {
Some(w) => w[r],
None => A::one(),
};
// Accumulate this weighted rank-one term into every entry of the result.
for idx in &indices_vec {
let mut value = weight;
for (mode, &i) in idx.iter().enumerate() {
value *= self.factors[mode][[i, r]];
}
let result_idx = scirs2_core::ndarray::IxDyn(idx.as_slice());
result[&result_idx] += value;
}
}
Ok(result)
}
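/// Computes the relative reconstruction error `||X - X_hat||_F / ||X||_F`,
/// where `X_hat` is the tensor rebuilt by `to_full`. Returns zero for a
/// perfect fit, and infinity when `X` is all zeros but `X_hat` is not.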
pub fn reconstruction_error(&self, tensor: &ArrayView<A, impl Dimension>) -> LinalgResult<A> {
let reconstructed = self.to_full()?;
let tensor_dyn = tensor.to_owned().into_dyn();
if tensor_dyn.shape() != reconstructed.shape() {
return Err(LinalgError::ShapeError(format!(
"Original tensor shape {:?} does not match reconstructed tensor shape {:?}",
tensor_dyn.shape(),
reconstructed.shape()
)));
}
let mut diff_squared_sum = A::zero();
let mut orig_squared_sum = A::zero();
for (idx, &orig_val) in tensor_dyn.indexed_iter() {
let rec_val = reconstructed[idx.clone()];
diff_squared_sum += (orig_val - rec_val).powi(2);
orig_squared_sum += orig_val.powi(2);
}
if orig_squared_sum == A::zero() {
return Ok(if diff_squared_sum == A::zero() {
A::zero()
} else {
A::infinity()
});
}
Ok((diff_squared_sum / orig_squared_sum).sqrt())
}
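/// Truncates the decomposition to its first `newrank` components by keeping
/// the leading columns of every factor (and the leading weights, if any).
/// Components are kept in stored order; they are not re-sorted by weight
/// before truncation.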
pub fn compress(&self, newrank: usize) -> LinalgResult<Self> {
let current_rank = self.factors[0].shape()[1];
if newrank > current_rank {
return Err(LinalgError::ValueError(format!(
"New rank ({}) must be less than or equal to the current rank ({})",
newrank, current_rank
)));
}
if newrank == 0 {
return Err(LinalgError::ValueError(
"New rank must be at least 1".to_string(),
));
}
if newrank == current_rank {
return Ok(self.clone());
}
let compressed_factors: Vec<Array2<A>> = self
.factors
.iter()
.map(|factor| {
factor
.slice(scirs2_core::ndarray::s![.., ..newrank])
.to_owned()
})
.collect();
let compressed_weights = self
.weights
.as_ref()
.map(|w| w.slice(scirs2_core::ndarray::s![..newrank]).to_owned());
CanonicalPolyadic::new(
compressed_factors,
compressed_weights,
Some(self.shape.clone()),
)
}
}
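/// Computes a rank-`rank` CP decomposition of `tensor` by alternating least
/// squares (ALS): each factor matrix is updated in turn, with the others held
/// fixed, by solving a linear least-squares problem against the corresponding
/// mode-n unfolding. Iteration stops when the relative improvement of the
/// reconstruction error falls below `tolerance` or after `max_iterations`
/// sweeps. With `normalize = true`, factor columns are rescaled to unit norm
/// and the norms are absorbed into the returned weights.
///
/// # Example
///
/// A minimal sketch, not compiled as a doctest; the `use` path is an
/// assumption, adjust it to wherever this module lives in your crate:
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let tensor = array![
///     [[1.0_f64, 2.0], [3.0, 4.0]],
///     [[5.0, 6.0], [7.0, 8.0]]
/// ];
/// // Rank-2 fit with at most 100 ALS sweeps and a 1e-6 relative tolerance.
/// let cp = cp_als(&tensor.view(), 2, 100, 1e-6, true)?;
/// let error = cp.reconstruction_error(&tensor.view())?;
/// assert!(error < 0.5);
/// ```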
#[allow(dead_code)]
pub fn cp_als<A, D>(
tensor: &ArrayView<A, D>,
rank: usize,
max_iterations: usize,
tolerance: A,
normalize: bool,
) -> LinalgResult<CanonicalPolyadic<A>>
where
A: Clone
+ Float
+ NumAssign
+ Zero
+ Debug
+ Sum
+ Send
+ Sync
+ 'static
+ scirs2_core::ndarray::ScalarOperand,
D: Dimension,
{
if rank == 0 {
return Err(LinalgError::ValueError(
"Rank must be at least 1".to_string(),
));
}
let n_modes = tensor.ndim();
let shape = tensor.shape().to_vec();
let mut factors: Vec<Array2<A>> = Vec::with_capacity(n_modes);
// Deterministic initialization with entries in [0, 1); a cheap, reproducible
// stand-in for random initialization so results are stable across runs.
for &dim in shape.iter() {
let mut factor = Array2::zeros((dim, rank));
for i in 0..dim {
for j in 0..rank {
factor[[i, j]] = A::from(((i + 1) * (j + 1)) % 10).expect("usize-to-float conversion failed")
/ A::from(10).expect("usize-to-float conversion failed");
}
}
factors.push(factor);
}
let tensor_dyn = tensor.to_owned().into_dyn();
// Precompute every mode-n unfolding once; they are reused in every ALS sweep.
// Unfolding errors are propagated instead of panicking.
let unfolded_tensors: Vec<Array2<A>> = (0..n_modes)
.map(|mode| unfold_tensor(&tensor_dyn, mode))
.collect::<LinalgResult<Vec<_>>>()?;
let mut prev_error = A::infinity();
for iteration in 0..max_iterations {
// One ALS sweep: update each factor in turn while holding the others fixed.
for mode in 0..n_modes {
// MTTKRP: mode-n unfolding times the Khatri-Rao product of the other factors.
let kr_product = khatri_rao_product(&factors, mode)?;
let mttkrp = unfolded_tensors[mode].dot(&kr_product);
// Closed-form least-squares update via the pseudo-inverse of the Gram matrix.
let grammatrix = compute_grammatrix(&factors, mode)?;
let gram_inv = pseudo_inverse(&grammatrix)?;
factors[mode] = mttkrp.dot(&gram_inv);
}
let mut weights = None;
if normalize {
weights = Some(normalize_factors(&mut factors));
}
let cp = CanonicalPolyadic::new(factors.clone(), weights.clone(), Some(shape.clone()))?;
let error = cp.reconstruction_error(tensor)?;
// Converge once the relative improvement drops below `tolerance`; on the
// first sweep `prev_error` is infinite, the ratio is NaN, and the check is skipped.
let rel_improvement = (prev_error - error) / prev_error;
if !rel_improvement.is_nan() && rel_improvement < tolerance && iteration > 0 {
return Ok(cp);
}
prev_error = error;
}
let weights = if normalize {
Some(normalize_factors(&mut factors))
} else {
None
};
CanonicalPolyadic::new(factors, weights, Some(shape))
}
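/// Mode-n unfolding (matricization): rearranges `tensor` into a matrix whose
/// rows are indexed by mode `mode` and whose columns enumerate all remaining
/// indices in row-major order (last axis varies fastest). For example, a
/// 2x3x4 tensor unfolded along mode 1 yields a 3x8 matrix.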
#[allow(dead_code)]
fn unfold_tensor<A>(tensor: &ArrayD<A>, mode: usize) -> LinalgResult<Array2<A>>
where
A: Clone + Float + NumAssign + Zero + Debug + Send + Sync + 'static,
{
let shape = tensor.shape();
if mode >= shape.len() {
return Err(LinalgError::ShapeError(format!(
"Mode {} is out of bounds for _tensor with {} dimensions",
mode,
shape.len()
)));
}
let mode_dim = shape[mode];
let other_dims_prod: usize = shape
.iter()
.enumerate()
.filter(|&(i, _)| i != mode)
.map(|(_, &dim)| dim)
.product();
let mut result = Array2::zeros((mode_dim, other_dims_prod));
// Column index of a multi-index in the mode-n unfolding: the remaining axes
// are flattened in row-major order (last axis varies fastest), matching the
// factor ordering used by khatri_rao_product.
fn calc_col_idx(idx: &[usize], shape: &[usize], mode: usize) -> usize {
let mut col_idx = 0;
let mut stride = 1;
for dim in (0..shape.len()).rev() {
if dim != mode {
col_idx += idx[dim] * stride;
stride *= shape[dim];
}
}
col_idx
}
for idx in scirs2_core::ndarray::indices(shape) {
let mode_idx = idx[mode];
let idx_vec: Vec<usize> = idx.as_array_view().to_vec();
let col_idx = calc_col_idx(&idx_vec, shape, mode);
result[[mode_idx, col_idx]] = tensor[idx.clone()];
}
Ok(result)
}
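/// Khatri-Rao product (column-wise Kronecker product) of all factor matrices
/// except `skipmode`: for `A` of shape `I x R` and `B` of shape `J x R`,
/// column `r` of `A ⊙ B` is the Kronecker product of column `r` of `A` with
/// column `r` of `B`, giving an `(I*J) x R` matrix. In the ALS update this
/// matrix multiplies the mode-`skipmode` unfolding from the right.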
#[allow(dead_code)]
fn khatri_rao_product<A>(factors: &[Array2<A>], skipmode: usize) -> LinalgResult<Array2<A>>
where
A: Clone + Float + NumAssign + Zero + Debug + Send + Sync + 'static,
{
let n_modes = factors.len();
if n_modes <= 1 {
return Err(LinalgError::ValueError(
"Need at least two factor matrices to compute Khatri-Rao product".to_string(),
));
}
let rank = factors[0].shape()[1];
// With exactly two modes, skipping one leaves a single factor, which is the
// Khatri-Rao product of the remaining (one-element) set by definition.
if n_modes == 2 && skipmode < n_modes {
let other_mode = if skipmode == 0 { 1 } else { 0 };
return Ok(factors[other_mode].clone());
}
let mut result = None;
let mut result_rows = 1;
for (mode, factor) in factors.iter().enumerate() {
if mode == skipmode {
continue;
}
match result {
None => {
result_rows = factor.shape()[0];
result = Some(factor.clone());
}
Some(prev_result) => {
let factor_rows = factor.shape()[0];
let mut new_result = Array2::zeros((result_rows * factor_rows, rank));
// Column r of the result is the Kronecker product of column r of the
// accumulated matrix with column r of the current factor.
for r in 0..rank {
let mut row_idx = 0;
for i in 0..result_rows {
for j in 0..factor_rows {
new_result[[row_idx, r]] = prev_result[[i, r]] * factor[[j, r]];
row_idx += 1;
}
}
}
result_rows *= factor_rows;
result = Some(new_result);
}
}
}
result.ok_or_else(|| LinalgError::ValueError("All factors were skipped".to_string()))
}
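/// Gram matrix of the ALS normal equations for mode `skipmode`: the
/// elementwise (Hadamard) product of `F_m^T F_m` over all modes `m` other
/// than `skipmode`, an `R x R` matrix whose pseudo-inverse completes the
/// least-squares factor update.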
#[allow(dead_code)]
fn compute_grammatrix<A>(factors: &[Array2<A>], skipmode: usize) -> LinalgResult<Array2<A>>
where
A: Clone + Float + NumAssign + Zero + Debug + Send + Sync + 'static,
{
let rank = factors[0].shape()[1];
// Hadamard (elementwise) product of F_m^T F_m over all non-skipped modes.
let mut gram = Array2::ones((rank, rank));
for (mode, factor) in factors.iter().enumerate() {
if mode == skipmode {
continue;
}
let factor_t = factor.t();
let factor_gram = factor_t.dot(factor);
for i in 0..rank {
for j in 0..rank {
gram[[i, j]] *= factor_gram[[i, j]];
}
}
}
Ok(gram)
}
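/// Moore-Penrose pseudo-inverse via the SVD, `A^+ = V S^+ U^T`, where
/// singular values near zero are dropped rather than inverted so the result
/// stays well-behaved for rank-deficient inputs such as the Gram matrices
/// arising in ALS.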
#[allow(dead_code)]
fn pseudo_inverse<A>(matrix: &Array2<A>) -> LinalgResult<Array2<A>>
where
A: Clone
+ Float
+ NumAssign
+ Zero
+ Debug
+ Sum
+ Send
+ Sync
+ 'static
+ scirs2_core::ndarray::ScalarOperand,
{
let (u, s, vt) = svd(&matrix.view(), false, None)?;
let mut s_inv = Array2::zeros((s.len(), s.len()));
for i in 0..s.len() {
// Drop singular values within ~10 machine epsilons of zero instead of
// inverting them, keeping the result stable for rank-deficient inputs.
if s[i] > A::epsilon() * A::from(10.0).expect("f64-to-float conversion failed") {
s_inv[[i, i]] = A::one() / s[i];
}
}
let v = vt.t();
let u_t = u.t();
let vs_inv = v.dot(&s_inv);
let result = vs_inv.dot(&u_t);
Ok(result)
}
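/// Rescales every factor column to unit 2-norm and returns the accumulated
/// column norms as per-component weights, so the represented tensor is
/// unchanged: each rank-one term `a_r ∘ b_r ∘ …` becomes
/// `w_r · (a_r/||a_r||) ∘ (b_r/||b_r||) ∘ …` with `w_r` the product of the norms.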
#[allow(dead_code)]
fn normalize_factors<A>(factors: &mut [Array2<A>]) -> Array1<A>
where
A: Clone + Float + NumAssign + Zero + Debug + Send + Sync + 'static,
{
let rank = factors[0].shape()[1];
// Scale each factor column to unit norm, folding the norm into the weights.
let mut weights = Array1::ones(rank);
for r in 0..rank {
for factor in factors.iter_mut() {
let mut norm_sq = A::zero();
for i in 0..factor.shape()[0] {
norm_sq += factor[[i, r]].powi(2);
}
let norm = norm_sq.sqrt();
if norm > A::epsilon() {
for i in 0..factor.shape()[0] {
factor[[i, r]] /= norm;
}
weights[r] *= norm;
}
}
}
weights
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
use scirs2_core::ndarray::array;
#[test]
fn test_cp_decomposition_basic() {
let tensor = array![
[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
[[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
];
let cp = cp_als(&tensor.view(), 2, 50, 1e-4, true).expect("Operation failed");
assert_eq!(cp.factors.len(), 3);
assert_eq!(cp.factors[0].shape(), &[2, 2]);
assert_eq!(cp.factors[1].shape(), &[3, 2]);
assert_eq!(cp.factors[2].shape(), &[2, 2]);
assert!(cp.weights.is_some());
assert_eq!(cp.weights.as_ref().expect("Operation failed").len(), 2);
let _reconstructed = cp.to_full().expect("Operation failed");
let error = cp
.reconstruction_error(&tensor.view())
.expect("Operation failed");
assert!(error < 0.1);
}
#[test]
fn test_cp_decomposition_truncated() {
let tensor = array![
[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
[[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
];
let cp = cp_als(&tensor.view(), 2, 50, 1e-4, true).expect("Operation failed");
assert_eq!(cp.factors.len(), 3);
assert_eq!(cp.factors[0].shape(), &[2, 2]);
assert_eq!(cp.factors[1].shape(), &[3, 2]);
assert_eq!(cp.factors[2].shape(), &[2, 2]);
let _reconstructed = cp.to_full().expect("Operation failed");
let error = cp
.reconstruction_error(&tensor.view())
.expect("Operation failed");
assert!(error > 0.0);
assert!(error < 0.5);
}
#[test]
#[ignore = "SVD fails for small matrices due to unimplemented eigendecomposition"]
fn test_compress() {
let tensor = array![
[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
[[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
];
let cp = cp_als(&tensor.view(), 4, 50, 1e-4, true).expect("Operation failed");
let compressed = cp.compress(2).expect("Operation failed");
assert_eq!(compressed.factors.len(), 3);
assert_eq!(compressed.factors[0].shape(), &[2, 2]);
assert_eq!(compressed.factors[1].shape(), &[3, 2]);
assert_eq!(compressed.factors[2].shape(), &[2, 2]);
assert!(compressed.weights.is_some());
assert_eq!(
compressed.weights.as_ref().expect("Operation failed").len(),
2
);
let error_orig = cp
.reconstruction_error(&tensor.view())
.expect("Operation failed");
let error_comp = compressed
.reconstruction_error(&tensor.view())
.expect("Operation failed");
assert!(error_comp >= error_orig * 0.99);
}
#[test]
fn test_reconstruction() {
let a = array![1.0, 2.0];
let b = array![3.0, 4.0, 5.0];
let c = array![6.0, 7.0];
let mut tensor = ArrayD::<f64>::zeros(scirs2_core::ndarray::IxDyn(&[2, 3, 2]));
for i in 0..2 {
for j in 0..3 {
for k in 0..2 {
tensor[[i, j, k]] = a[i] * b[j] * c[k];
}
}
}
let factors = vec![
Array2::from_shape_fn((2, 1), |(i, _)| a[i]),
Array2::from_shape_fn((3, 1), |(j, _)| b[j]),
Array2::from_shape_fn((2, 1), |(k, _)| c[k]),
];
let cp =
CanonicalPolyadic::new(factors, None, Some(vec![2, 3, 2])).expect("Operation failed");
let reconstructed = cp.to_full().expect("Operation failed");
for i in 0..2 {
for j in 0..3 {
for k in 0..2 {
assert_abs_diff_eq!(
reconstructed[[i, j, k]],
tensor[[i, j, k]],
epsilon = 1e-10
);
}
}
}
}
#[test]
fn test_khatri_rao_product() {
let a = array![[1.0, 2.0], [3.0, 4.0]];
let b = array![[5.0, 6.0], [7.0, 8.0], [9.0, 10.0]];
let factors = vec![a.clone(), b.clone()];
let kr = khatri_rao_product(&factors, 0).expect("Operation failed");
assert_eq!(kr.shape(), &[3, 2]);
for i in 0..3 {
for j in 0..2 {
assert_abs_diff_eq!(kr[[i, j]], b[[i, j]], epsilon = 1e-10);
}
}
let kr = khatri_rao_product(&factors, 1).expect("Operation failed");
assert_eq!(kr.shape(), &[2, 2]);
for i in 0..2 {
for j in 0..2 {
assert_abs_diff_eq!(kr[[i, j]], a[[i, j]], epsilon = 1e-10);
}
}
}
}