use crate::error::{LinalgError, LinalgResult};
use crate::tensor::core::{Tensor, TensorScalar};
use scirs2_core::ndarray::Array2;
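/// Configuration for the CP (CANDECOMP/PARAFAC) decomposition solvers.
///
/// * `max_iter` - maximum number of outer iterations.
/// * `tol` - relative change in loss below which iteration stops early.
/// * `lr` - learning rate, used only by the gradient-descent solver [`cp_grad`].
/// * `random_init` - use a seeded LCG for factor initialisation instead of the
///   default deterministic sinusoidal pattern.
/// * `seed` - seed for the LCG when `random_init` is set.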
#[derive(Debug, Clone)]
pub struct CPConfig {
pub max_iter: usize,
pub tol: f64,
pub lr: f64,
pub random_init: bool,
pub seed: u64,
}
impl Default for CPConfig {
fn default() -> Self {
Self {
max_iter: 500,
tol: 1e-8,
lr: 1e-3,
random_init: false,
seed: 42,
}
}
}
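/// Result of a CP decomposition.
///
/// * `lambdas` - per-component weights; all ones when the weights have been
///   absorbed into the factor matrices.
/// * `factors` - one factor matrix per mode, each of shape `(dim_n, rank)`.
/// * `loss` - squared-error loss history: the value for the initial factors
///   followed by one entry per iteration performed.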
#[derive(Debug, Clone)]
pub struct CPResult<F> {
pub lambdas: Vec<F>,
pub factors: Vec<Array2<F>>,
pub loss: Vec<F>,
}
impl<F: TensorScalar> CPResult<F> {
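/// Reconstruct a dense tensor of the given `shape` from the stored factors.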
pub fn reconstruct(&self, shape: &[usize]) -> LinalgResult<Tensor<F>> {
cp_reconstruct(self, shape)
}
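/// Relative Frobenius-norm error `||X - X_hat||_F / ||X||_F` against `original`.
/// Returns zero when `original` itself has zero norm.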
pub fn relative_error(&self, original: &Tensor<F>) -> LinalgResult<F> {
let recon = self.reconstruct(&original.shape)?;
let orig_norm = original.frobenius_norm();
if orig_norm == F::zero() {
return Ok(F::zero());
}
let diff_sq: F = original
.data
.iter()
.zip(recon.data.iter())
.map(|(&a, &b)| {
let d = a - b;
d * d
})
.fold(F::zero(), |acc, x| acc + x);
Ok((diff_sq / (orig_norm * orig_norm)).sqrt())
}
}
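/// CP decomposition via Alternating Least Squares (CP-ALS).
///
/// Each mode-`n` factor is updated in turn as `A(n) <- X_(n) * KR * pinv(G)`,
/// where `KR` is the Khatri-Rao product of the other factors and `G` is the
/// Hadamard product of their Gram matrices. Columns are renormalised after
/// every update, and the accumulated weights are folded back into the factors
/// before returning, so the result's `lambdas` are all ones.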
pub fn cp_als<F: TensorScalar>(
tensor: &Tensor<F>,
rank: usize,
config: &CPConfig,
) -> LinalgResult<CPResult<F>> {
if rank == 0 {
return Err(LinalgError::ValueError("rank must be > 0".to_string()));
}
let ndim = tensor.ndim();
let shape = &tensor.shape;
let mut factors: Vec<Array2<F>> = init_factors(shape, rank, config.random_init, config.seed);
let tol = F::from(config.tol).unwrap_or(F::zero());
let mut lambdas: Vec<F> = vec![F::one(); rank];
let mut loss_history: Vec<F> = Vec::with_capacity(config.max_iter + 1);
loss_history.push(reconstruction_loss(tensor, &factors, &lambdas));
for _iter in 0..config.max_iter {
for n in 0..ndim {
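// Mode-n least-squares update: A(n) <- X_(n) * KR * pinv(G).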
let kr = khatri_rao_except(&factors, n);
let unfolded = tensor.unfold(n)?;
let gram_hadamard = hadamard_gram_except(&factors, n, rank);
let gram_pinv = pinv_symmetric(&gram_hadamard)?;
let new_factor = unfolded.dot(&kr).dot(&gram_pinv);
let (normed, norms) = normalise_columns(new_factor);
lambdas = norms;
factors[n] = normed;
}
let loss = reconstruction_loss(tensor, &factors, &lambdas);
let prev_loss = loss_history.last().copied().unwrap_or(loss);
loss_history.push(loss);
let delta = if prev_loss > F::zero() {
(prev_loss - loss).abs() / prev_loss
} else {
F::zero()
};
if delta < tol {
break;
}
}
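// Absorb the component weights into the factors: every column r of every
// factor is scaled by lambda_r^(1/ndim), so the returned lambdas are ones.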
let root_n = F::from(ndim as f64).unwrap_or(F::one());
let scale: Vec<F> = lambdas
.iter()
.map(|&l| {
if l < F::zero() {
-(-l).powf(F::one() / root_n)
} else {
l.powf(F::one() / root_n)
}
})
.collect();
for n in 0..ndim {
let mut factor = factors[n].clone();
for r in 0..rank {
let s = scale[r];
for i in 0..factor.nrows() {
factor[[i, r]] = factor[[i, r]] * s;
}
}
factors[n] = factor;
}
let final_lambdas = vec![F::one(); rank];
Ok(CPResult {
lambdas: final_lambdas,
factors,
loss: loss_history,
})
}
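/// CP decomposition via plain gradient descent on the squared reconstruction
/// error, using the fixed learning rate `config.lr`. Component weights stay
/// at one; only the factor matrices are updated.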
pub fn cp_grad<F: TensorScalar>(
tensor: &Tensor<F>,
rank: usize,
config: &CPConfig,
) -> LinalgResult<CPResult<F>> {
if rank == 0 {
return Err(LinalgError::ValueError("rank must be > 0".to_string()));
}
let ndim = tensor.ndim();
let shape = &tensor.shape;
let lr = F::from(config.lr).unwrap_or(F::from(1e-3_f64).unwrap_or(F::one()));
let tol = F::from(config.tol).unwrap_or(F::zero());
let mut factors: Vec<Array2<F>> = init_factors(shape, rank, config.random_init, config.seed);
let lambdas: Vec<F> = vec![F::one(); rank];
let mut loss_history: Vec<F> = Vec::with_capacity(config.max_iter + 1);
loss_history.push(reconstruction_loss(tensor, &factors, &lambdas));
for _iter in 0..config.max_iter {
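// Residual T - X_hat of the current model; its mode-n unfolding times the
// Khatri-Rao product of the other factors gives (minus) the gradient.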
let recon = reconstruct_from_factors(tensor.shape.clone(), &factors, &lambdas);
let residual_data: Vec<F> = tensor
.data
.iter()
.zip(recon.data.iter())
.map(|(&t, &r)| t - r)
.collect();
let residual = Tensor::new(residual_data, tensor.shape.clone())?;
for n in 0..ndim {
let res_unfold = residual.unfold(n)?;
let kr = khatri_rao_except(&factors, n);
// The gradient of the squared error w.r.t. factor n is minus the mode-n
// unfolded residual times KR, so the descent step A(n) -= lr * grad
// simplifies to adding lr * res_unfold.dot(&kr).
let update = res_unfold.dot(&kr);
let mut new_factor = factors[n].clone();
for i in 0..new_factor.nrows() {
for j in 0..new_factor.ncols() {
new_factor[[i, j]] = new_factor[[i, j]] + lr * update[[i, j]];
}
}
factors[n] = new_factor;
}
let loss = reconstruction_loss(tensor, &factors, &lambdas);
let prev_loss = loss_history.last().copied().unwrap_or(loss);
loss_history.push(loss);
let delta = if prev_loss > F::zero() {
(prev_loss - loss).abs() / prev_loss
} else {
F::zero()
};
if delta < tol {
break;
}
}
Ok(CPResult {
lambdas,
factors,
loss: loss_history,
})
}
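/// Rebuild a dense tensor of the given `shape` from a [`CPResult`], checking
/// that the number of factor matrices matches the number of modes.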
pub fn cp_reconstruct<F: TensorScalar>(
result: &CPResult<F>,
shape: &[usize],
) -> LinalgResult<Tensor<F>> {
if result.factors.len() != shape.len() {
return Err(LinalgError::DimensionError(format!(
"factors.len() {} != shape.len() {}",
result.factors.len(),
shape.len()
)));
}
Ok(reconstruct_from_factors(
shape.to_vec(),
&result.factors,
&result.lambdas,
))
}
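/// Build one initial factor matrix of shape `(dim, rank)` per mode. The
/// default initialisation is a deterministic pattern of scaled sinusoids;
/// with `random_init`, entries are drawn in `[-1, 1)` from a seeded linear
/// congruential generator so no external RNG dependency is required.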
fn init_factors<F: TensorScalar>(
shape: &[usize],
rank: usize,
random_init: bool,
seed: u64,
) -> Vec<Array2<F>> {
let mut lcg_state = seed.wrapping_add(1);
let mut lcg = || -> f64 {
lcg_state = lcg_state
.wrapping_mul(6_364_136_223_846_793_005)
.wrapping_add(1_442_695_040_888_963_407);
((lcg_state >> 33) as f64) / ((1u64 << 31) as f64) // top 31 bits, uniform in [0, 1)
};
shape
.iter()
.map(|&dim| {
let mut mat = Array2::<F>::zeros((dim, rank));
for i in 0..dim {
for r in 0..rank {
let val: f64 = if random_init {
lcg() * 2.0 - 1.0
} else {
let t = (i as f64 + 0.5) * std::f64::consts::PI * (r as f64 + 1.0)
/ (dim as f64 + 1.0);
t.sin() / (dim as f64).sqrt()
};
mat[[i, r]] = F::from(val).unwrap_or(F::zero());
}
}
mat
})
.collect()
}
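/// Column-wise Khatri-Rao product of all factor matrices except mode `skip`,
/// combined in increasing mode order with earlier modes varying slowest.
/// Returns the `rank x rank` identity when no other modes remain.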
fn khatri_rao_except<F: TensorScalar>(factors: &[Array2<F>], skip: usize) -> Array2<F> {
let rank = factors[0].ncols();
let ndim = factors.len();
let modes: Vec<usize> = (0..ndim).filter(|&k| k != skip).collect();
if modes.is_empty() {
return Array2::<F>::eye(rank);
}
let mut result = factors[modes[0]].clone();
for &m in &modes[1..] {
let a = &result;
let b = &factors[m];
let rows_a = a.nrows();
let rows_b = b.nrows();
let mut kr = Array2::<F>::zeros((rows_a * rows_b, rank));
for r in 0..rank {
for i in 0..rows_a {
for j in 0..rows_b {
kr[[i * rows_b + j, r]] = a[[i, r]] * b[[j, r]];
}
}
}
result = kr;
}
result
}
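/// Hadamard (element-wise) product of the Gram matrices `A(k)^T A(k)` over all
/// modes `k != skip`; this is the normal-equations matrix of the ALS update.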
fn hadamard_gram_except<F: TensorScalar>(
factors: &[Array2<F>],
skip: usize,
rank: usize,
) -> Array2<F> {
let mut result = Array2::<F>::from_elem((rank, rank), F::one());
for (k, factor) in factors.iter().enumerate() {
if k == skip {
continue;
}
let gram = factor.t().dot(factor);
for i in 0..rank {
for j in 0..rank {
result[[i, j]] = result[[i, j]] * gram[[i, j]];
}
}
}
result
}
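/// Inverse of a small symmetric positive semi-definite matrix via Gauss-Jordan
/// elimination with partial pivoting, after adding a tiny ridge (`1e-12`) to
/// the diagonal so that near-singular Gram matrices remain solvable.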
fn pinv_symmetric<F: TensorScalar>(mat: &Array2<F>) -> LinalgResult<Array2<F>> {
let n = mat.nrows();
let eps = F::from(1e-12_f64).unwrap_or(F::zero());
let mut reg = mat.clone();
for i in 0..n {
reg[[i, i]] = reg[[i, i]] + eps;
}
let mut augmented: Vec<Vec<F>> = (0..n)
.map(|i| {
let mut row: Vec<F> = (0..n).map(|j| reg[[i, j]]).collect();
for j in 0..n {
row.push(if i == j { F::one() } else { F::zero() });
}
row
})
.collect();
for col in 0..n {
let pivot_row = (col..n)
.max_by(|&a, &b| {
augmented[a][col]
.abs()
.partial_cmp(&augmented[b][col].abs())
.unwrap_or(std::cmp::Ordering::Equal)
})
.ok_or_else(|| LinalgError::SingularMatrixError("gram matrix singular".to_string()))?;
augmented.swap(col, pivot_row);
let pivot = augmented[col][col];
if pivot.abs() < F::from(1e-30_f64).unwrap_or(F::zero()) {
return Err(LinalgError::SingularMatrixError(
"gram matrix numerically singular".to_string(),
));
}
let inv_pivot = F::one() / pivot;
for j in 0..(2 * n) {
augmented[col][j] = augmented[col][j] * inv_pivot;
}
for row in 0..n {
if row == col {
continue;
}
let factor = augmented[row][col];
for j in 0..(2 * n) {
let sub = factor * augmented[col][j];
augmented[row][j] = augmented[row][j] - sub;
}
}
}
let mut inv = Array2::<F>::zeros((n, n));
for i in 0..n {
for j in 0..n {
inv[[i, j]] = augmented[i][n + j];
}
}
Ok(inv)
}
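/// Normalise each column of `mat` to unit Euclidean length, returning the
/// normalised matrix together with the original column norms. Columns with
/// near-zero norm are left unchanged.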
fn normalise_columns<F: TensorScalar>(mat: Array2<F>) -> (Array2<F>, Vec<F>) {
let (m, r) = (mat.nrows(), mat.ncols());
let mut normed = mat.clone();
let mut norms = vec![F::one(); r];
for j in 0..r {
let sq: F = (0..m).map(|i| mat[[i, j]] * mat[[i, j]]).fold(F::zero(), |a, b| a + b);
let n = sq.sqrt();
norms[j] = n;
if n > F::from(1e-30_f64).unwrap_or(F::zero()) {
let inv_n = F::one() / n;
for i in 0..m {
normed[[i, j]] = mat[[i, j]] * inv_n;
}
}
}
(normed, norms)
}
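/// Squared Frobenius-norm error between `tensor` and the CP model defined by
/// `factors` and `lambdas` (sum of squared element-wise differences).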
fn reconstruction_loss<F: TensorScalar>(
tensor: &Tensor<F>,
factors: &[Array2<F>],
lambdas: &[F],
) -> F {
let recon = reconstruct_from_factors(tensor.shape.clone(), factors, lambdas);
tensor
.data
.iter()
.zip(recon.data.iter())
.map(|(&t, &r)| {
let d = t - r;
d * d
})
.fold(F::zero(), |a, b| a + b)
}
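/// Materialise the dense tensor `sum_r lambda_r * a_r^(1) o ... o a_r^(N)`
/// (outer products of the r-th factor columns), visiting every multi-index of
/// `shape` in row-major order.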
pub(crate) fn reconstruct_from_factors<F: TensorScalar>(
shape: Vec<usize>,
factors: &[Array2<F>],
lambdas: &[F],
) -> Tensor<F> {
let ndim = shape.len();
let rank = lambdas.len();
let total: usize = shape.iter().product();
let strides = crate::tensor::core::compute_row_major_strides(&shape);
let mut data = vec![F::zero(); total];
for flat in 0..total {
let mut multi = vec![0usize; ndim];
let mut rem = flat;
for d in (0..ndim).rev() {
multi[d] = rem % shape[d];
rem /= shape[d];
}
let mut val = F::zero();
for r in 0..rank {
let mut contrib = lambdas[r];
for n in 0..ndim {
contrib = contrib * factors[n][[multi[n], r]];
}
val = val + contrib;
}
let idx: usize = multi.iter().zip(strides.iter()).map(|(i, s)| i * s).sum();
data[idx] = val;
}
Tensor {
data,
shape,
strides,
}
}
#[cfg(test)]
mod tests {
use super::*;
fn make_tensor() -> Tensor<f64> {
let data: Vec<f64> = (0..24).map(|x| x as f64 + 1.0).collect();
Tensor::new(data, vec![2, 3, 4]).expect("valid")
}
#[test]
fn test_cp_als_shape() {
let t = make_tensor();
let cfg = CPConfig { max_iter: 50, ..Default::default() };
let r = cp_als(&t, 3, &cfg).expect("cp_als");
assert_eq!(r.factors.len(), 3);
assert_eq!(r.factors[0].shape(), &[2, 3]);
assert_eq!(r.factors[1].shape(), &[3, 3]);
assert_eq!(r.factors[2].shape(), &[4, 3]);
assert_eq!(r.lambdas.len(), 3);
}
#[test]
fn test_cp_als_loss_decreasing() {
let t = make_tensor();
let cfg = CPConfig { max_iter: 100, ..Default::default() };
let r = cp_als(&t, 4, &cfg).expect("cp_als");
for window in r.loss.windows(2) {
assert!(
window[1] <= window[0] + 1e-6,
"loss increased: {} -> {}",
window[0],
window[1]
);
}
}
#[test]
fn test_cp_reconstruct_shape() {
let t = make_tensor();
let cfg = CPConfig::default();
let r = cp_als(&t, 3, &cfg).expect("ok");
let recon = cp_reconstruct(&r, &[2, 3, 4]).expect("ok");
assert_eq!(recon.shape, vec![2, 3, 4]);
}
#[test]
fn test_cp_grad_shape() {
let t = make_tensor();
let cfg = CPConfig { max_iter: 50, lr: 1e-3, ..Default::default() };
let r = cp_grad(&t, 2, &cfg).expect("cp_grad");
assert_eq!(r.factors.len(), 3);
}
#[test]
fn test_cp_als_low_rank_approximation() {
let a = vec![1.0_f64, 2.0];
let b = vec![1.0_f64, 3.0, 5.0];
let c = vec![1.0_f64, 0.5, 2.0, 0.25];
let mut data = vec![0.0_f64; 2 * 3 * 4];
for i in 0..2 {
for j in 0..3 {
for k in 0..4 {
data[i * 12 + j * 4 + k] = a[i] * b[j] * c[k];
}
}
}
let t = Tensor::new(data, vec![2, 3, 4]).expect("ok");
let cfg = CPConfig { max_iter: 300, tol: 1e-10, ..Default::default() };
let r = cp_als(&t, 1, &cfg).expect("rank-1 ALS");
let err = r.relative_error(&t).expect("err");
assert!(err < 1e-5, "rank-1 CP should reconstruct exactly, err={err}");
}
#[test]
fn test_cp_invalid_rank() {
let t = make_tensor();
let cfg = CPConfig::default();
assert!(cp_als(&t, 0, &cfg).is_err());
assert!(cp_grad(&t, 0, &cfg).is_err());
}
}