use super::{FixedPoint, FixedVector, FixedMatrix};
use super::tensor::Tensor;
use super::decompose::svd_decompose;
use crate::fixed_point::core_types::errors::OverflowDetected;
/// Rank-`k` factorisation `U · diag(sigma) · Vt` of a matrix.
///
/// As built by `truncated_svd`, `u` is `m × k`, `sigma` holds `k` values and
/// `vt` is `k × n`.
pub struct TruncatedSVD {
/// Left factor; column `r` is paired with `sigma[r]`.
pub u: FixedMatrix,
/// Retained singular values, in the order `svd_decompose` produced them.
pub sigma: FixedVector,
/// Right factor (stored transposed); row `r` is paired with `sigma[r]`.
pub vt: FixedMatrix,
}
impl TruncatedSVD {
    /// Rebuilds the dense matrix `U · diag(sigma) · Vt`.
    ///
    /// The output has `u.rows()` rows and `vt.cols()` columns. Entry `(i, j)`
    /// accumulates `(u[i][r] * sigma[r]) * vt[r][j]` for `r` in ascending order.
    pub fn reconstruct(&self) -> FixedMatrix {
        let row_count = self.u.rows();
        let col_count = self.vt.cols();
        let terms = self.sigma.len();
        let mut out = FixedMatrix::new(row_count, col_count);

        for i in 0..row_count {
            // Row `i` of `U`, scaled by the singular values once up front so
            // the inner loops only need one multiplication per term.
            let scaled_row: Vec<_> = (0..terms)
                .map(|r| self.u.get(i, r) * self.sigma[r])
                .collect();

            for j in 0..col_count {
                let mut acc = out.get(i, j);
                for (r, &scaled) in scaled_row.iter().enumerate() {
                    acc = acc + scaled * self.vt.get(r, j);
                }
                out.set(i, j, acc);
            }
        }
        out
    }

    /// Ratio of dense storage (`m * n` entries) to factored storage
    /// (`m * k + k + k * n` entries), where `k` is the number of retained
    /// singular values.
    ///
    /// `m` and `n` are supplied by the caller; they are not read from the
    /// stored factors.
    pub fn compression_ratio(&self, m: usize, n: usize) -> f64 {
        let k = self.sigma.len();
        let dense_entries = m * n;
        let factored_entries = m * k + k + k * n;
        dense_entries as f64 / factored_entries as f64
    }
}
/// Computes the SVD of `a` and keeps only its leading `k` components.
///
/// `k` is clamped to the number of singular values `svd_decompose` returns,
/// so asking for more components than exist yields the full factorisation.
///
/// # Errors
///
/// Propagates the error returned by `svd_decompose`.
pub fn truncated_svd(a: &FixedMatrix, k: usize) -> Result<TruncatedSVD, OverflowDetected> {
    let svd = svd_decompose(a)?;
    let k = k.min(svd.sigma.len());
    let m = svd.u.rows();
    let n = svd.vt.cols();

    // Leading `k` columns of U.
    let mut u_k = FixedMatrix::new(m, k);
    for i in 0..m {
        for j in 0..k {
            u_k.set(i, j, svd.u.get(i, j));
        }
    }

    // Leading `k` singular values.
    let mut sigma_k = FixedVector::new(k);
    for i in 0..k {
        sigma_k[i] = svd.sigma[i];
    }

    // Leading `k` rows of Vt.
    let mut vt_k = FixedMatrix::new(k, n);
    for i in 0..k {
        for j in 0..n {
            vt_k.set(i, j, svd.vt.get(i, j));
        }
    }

    Ok(TruncatedSVD { u: u_k, sigma: sigma_k, vt: vt_k })
}

/// Truncated SVD whose rank is chosen from the singular values themselves.
///
/// The rank is the length of the leading run of singular values strictly
/// greater than the threshold, with a floor of one component (clamped to the
/// number of singular values available, so an empty spectrum yields rank 0).
///
/// When `threshold` is `None` the cut-off defaults to
/// `max(a.rows(), a.cols()) * convergence_threshold(sigma[0])`, or to
/// `FixedPoint::one()` if the decomposition has no singular values.
///
/// The decomposition is computed once and truncated in place rather than
/// being recomputed through `truncated_svd`.
///
/// # Errors
///
/// Propagates the error returned by `svd_decompose`.
pub fn truncated_svd_auto(a: &FixedMatrix, threshold: Option<FixedPoint>) -> Result<TruncatedSVD, OverflowDetected> {
    let svd = svd_decompose(a)?;
    let full_k = svd.sigma.len();

    let thresh = threshold.unwrap_or_else(|| {
        if full_k == 0 {
            return FixedPoint::one();
        }
        let sigma_max = svd.sigma[0];
        let dim_factor = FixedPoint::from_int(a.rows().max(a.cols()) as i32);
        let eps = super::linalg::convergence_threshold(sigma_max);
        dim_factor * eps
    });

    // Count the leading singular values above the cut-off; always keep at
    // least one, but never more than the decomposition actually provides.
    let significant = (0..full_k)
        .take_while(|&i| svd.sigma[i] > thresh)
        .count();
    let k = significant.max(1).min(full_k);

    let m = svd.u.rows();
    let n = svd.vt.cols();

    // Leading `k` columns of U.
    let mut u_k = FixedMatrix::new(m, k);
    for i in 0..m {
        for j in 0..k {
            u_k.set(i, j, svd.u.get(i, j));
        }
    }

    // Leading `k` singular values.
    let mut sigma_k = FixedVector::new(k);
    for i in 0..k {
        sigma_k[i] = svd.sigma[i];
    }

    // Leading `k` rows of Vt.
    let mut vt_k = FixedMatrix::new(k, n);
    for i in 0..k {
        for j in 0..n {
            vt_k.set(i, j, svd.vt.get(i, j));
        }
    }

    Ok(TruncatedSVD { u: u_k, sigma: sigma_k, vt: vt_k })
}
/// Tucker factorisation: a core tensor plus one factor matrix per mode.
///
/// `reconstruct` multiplies the core by `factors[n]` along mode `n` for every
/// `n`, in order.
pub struct TuckerDecomposition {
/// Core tensor; as built by `tucker_decompose` it is the input multiplied by
/// the transpose of every factor along the matching mode.
pub core: Tensor,
/// Factor matrices, indexed by mode.
pub factors: Vec<FixedMatrix>,
}
impl TuckerDecomposition {
    /// Expands the decomposition back into a full tensor by applying
    /// `factors[n]` to the core along mode `n`, for each mode in order.
    pub fn reconstruct(&self) -> Tensor {
        self.factors
            .iter()
            .enumerate()
            .fold(self.core.clone(), |acc, (mode, factor)| {
                mode_n_product(&acc, factor, mode)
            })
    }

    /// Ratio of the original element count to the stored element count.
    ///
    /// Stored size is the core's element count plus, for each factor,
    /// `original_shape[n] * factors[n].cols()`.
    ///
    /// # Panics
    ///
    /// Panics if `original_shape` has fewer entries than there are factors.
    pub fn compression_ratio(&self, original_shape: &[usize]) -> f64 {
        let dense_entries: usize = original_shape.iter().product();
        let core_entries: usize = self.core.shape().iter().product();

        let mut factor_entries: usize = 0;
        for (mode, factor) in self.factors.iter().enumerate() {
            factor_entries += original_shape[mode] * factor.cols();
        }

        dense_entries as f64 / (core_entries + factor_entries) as f64
    }
}
/// Tucker decomposition of `t` via a truncated SVD of every mode unfolding.
///
/// For each mode `n` the factor is the `u` matrix of the truncated SVD of the
/// mode-`n` unfolding, with the requested rank clamped to the unfolding's
/// dimensions. The core is `t` multiplied along every mode by the transpose
/// of that mode's factor.
///
/// # Panics
///
/// Panics if `ranks.len()` differs from the number of tensor modes.
///
/// # Errors
///
/// Propagates the error returned by `truncated_svd`.
pub fn tucker_decompose(t: &Tensor, ranks: &[usize]) -> Result<TuckerDecomposition, OverflowDetected> {
    let ndim = t.rank();
    assert_eq!(ranks.len(), ndim, "ranks must have one entry per tensor mode");

    // One factor per mode; the first failing SVD aborts the whole collection.
    let factors = ranks
        .iter()
        .enumerate()
        .map(|(mode, &requested)| {
            let unfolding = mode_unfold(t, mode);
            let kept = requested.min(unfolding.rows()).min(unfolding.cols());
            Ok(truncated_svd(&unfolding, kept)?.u)
        })
        .collect::<Result<Vec<FixedMatrix>, OverflowDetected>>()?;

    // Project the tensor onto each factor's column space, mode by mode.
    let core = factors
        .iter()
        .enumerate()
        .fold(t.clone(), |projected, (mode, factor)| {
            mode_n_product(&projected, &factor.transpose(), mode)
        });

    Ok(TuckerDecomposition { core, factors })
}
/// CP (sum of rank-one terms) factorisation of a tensor.
///
/// `reconstruct` evaluates each element as
/// `Σ_r weights[r] · Π_n factors[n][i_n, r]`.
pub struct CPDecomposition {
/// One scalar weight per rank-one component.
pub weights: FixedVector,
/// One matrix per mode; column `r` belongs to component `r`.
pub factors: Vec<FixedMatrix>,
}
impl CPDecomposition {
    /// Builds the full tensor of the given `shape` by summing every weighted
    /// rank-one component into a zero-initialised flat buffer.
    pub fn reconstruct(&self, shape: &[usize]) -> Tensor {
        let element_count: usize = shape.iter().product();
        let mut flat = vec![FixedPoint::ZERO; element_count];

        for component in 0..self.weights.len() {
            let weight = self.weights[component];
            add_rank1_to_flat(&mut flat, shape, &self.factors, component, weight);
        }

        Tensor::from_data(shape, &flat)
    }
}
/// CP decomposition of `t` into `rank` components by alternating least squares.
///
/// 1. **Initialisation** – factor `n` is a `shape[n] × rank` matrix whose
///    leading columns are copied from the `u` matrix of the SVD of the
///    mode-`n` unfolding. Columns beyond the unfolding's smaller dimension are
///    left exactly as `FixedMatrix::new` created them.
/// 2. **ALS sweeps** – for `max_iter` sweeps, each factor in turn is replaced
///    by `X₍ₙ₎ · V · (Vᵀ·V)⁻¹`, where `V` is the Khatri–Rao product of all
///    other factors. If the inverse cannot be computed, that factor is left
///    unchanged for the sweep.
/// 3. **Normalisation** – every non-zero factor column is scaled to unit
///    Euclidean norm, and the product of the removed norms becomes the
///    component's weight.
///
/// The mode unfoldings depend only on `t`, so they are computed once during
/// initialisation and reused by every sweep. This holds one extra copy of the
/// tensor's elements per mode for the duration of the call.
///
/// `_tol` is currently unused: there is no convergence test and exactly
/// `max_iter` sweeps are always performed.
///
/// # Errors
///
/// Propagates the error returned by `svd_decompose` during initialisation.
pub fn cp_decompose(
    t: &Tensor,
    rank: usize,
    max_iter: usize,
    _tol: FixedPoint,
) -> Result<CPDecomposition, OverflowDetected> {
    let ndim = t.rank();
    let shape = t.shape().to_vec();

    let mut unfoldings: Vec<FixedMatrix> = Vec::with_capacity(ndim);
    let mut factors: Vec<FixedMatrix> = Vec::with_capacity(ndim);
    for n in 0..ndim {
        let unfolded = mode_unfold(t, n);
        // `k <= rank`, so only the first `k` columns can be seeded from U.
        let k = rank.min(unfolded.rows()).min(unfolded.cols());
        let svd = svd_decompose(&unfolded)?;

        let mut f = FixedMatrix::new(shape[n], rank);
        for i in 0..shape[n] {
            for r in 0..k {
                f.set(i, r, svd.u.get(i, r));
            }
        }
        factors.push(f);
        unfoldings.push(unfolded);
    }

    for _ in 0..max_iter {
        for n in 0..ndim {
            let v = khatri_rao_except(&factors, n, &shape);
            let vt = v.transpose();
            let vtv = &vt * &v;
            let rhs = &unfoldings[n] * &v;

            // Best effort: a non-invertible Gram matrix keeps the previous
            // factor for this mode instead of aborting the decomposition.
            if let Ok(vtv_inv) = super::derived::inverse(&vtv) {
                factors[n] = &rhs * &vtv_inv;
            }
        }
    }

    let mut weights = FixedVector::new(rank);
    for r in 0..rank {
        let mut norm_product = FixedPoint::one();
        for n in 0..ndim {
            let mut col_norm_sq = FixedPoint::ZERO;
            for i in 0..shape[n] {
                let v = factors[n].get(i, r);
                col_norm_sq = col_norm_sq + v * v;
            }
            let col_norm = col_norm_sq.sqrt();

            // A zero column cannot be normalised; leave it and its
            // contribution to the weight untouched.
            if !col_norm.is_zero() {
                for i in 0..shape[n] {
                    let v = factors[n].get(i, r);
                    factors[n].set(i, r, v / col_norm);
                }
                norm_product = norm_product * col_norm;
            }
        }
        weights[r] = norm_product;
    }

    Ok(CPDecomposition { weights, factors })
}
/// Flattens `t` into a matrix whose rows are indexed by mode `mode`.
///
/// The result has `shape[mode]` rows and one column per combination of the
/// remaining indices. Columns are ordered with the remaining modes in their
/// original order and the last of them varying fastest.
fn mode_unfold(t: &Tensor, mode: usize) -> FixedMatrix {
    let shape = t.shape();
    let ndim = shape.len();
    let row_count = shape[mode];

    // Every mode except the one that becomes the row index, in original order.
    let other_modes: Vec<usize> = (0..ndim).filter(|&d| d != mode).collect();
    let col_count: usize = other_modes.iter().map(|&d| shape[d]).product();

    let mut unfolded = FixedMatrix::new(row_count, col_count);
    let element_count: usize = shape.iter().product();
    let mut indices = vec![0usize; ndim];

    for flat in 0..element_count {
        // Decode the flat offset into a multi-index, last mode fastest.
        let mut rem = flat;
        for (slot, &extent) in indices.iter_mut().zip(shape.iter()).rev() {
            *slot = rem % extent;
            rem /= extent;
        }

        // Re-encode the non-row indices into a single column offset.
        let col = other_modes
            .iter()
            .fold(0usize, |acc, &d| acc * shape[d] + indices[d]);

        unfolded.set(indices[mode], col, t.get(&indices));
    }
    unfolded
}
/// Multiplies tensor `t` by matrix `m` along mode `mode`.
///
/// The output has the same shape as `t` except that dimension `mode` becomes
/// `m.rows()`. Each output element is `Σ_k m[i, k] · t[..., k, ...]`, where
/// `i` is the output index along `mode`.
///
/// # Panics
///
/// Panics if `m.cols()` differs from `t`'s extent along `mode`.
fn mode_n_product(t: &Tensor, m: &FixedMatrix, mode: usize) -> Tensor {
    let shape = t.shape();
    let ndim = shape.len();
    let contracted = shape[mode];
    let out_extent = m.rows();
    assert_eq!(m.cols(), contracted, "Matrix cols must match tensor mode dimension");

    let mut out_shape = shape.to_vec();
    out_shape[mode] = out_extent;
    let out_len: usize = out_shape.iter().product();

    let mut out = Tensor::new(&out_shape);
    let mut out_idx = vec![0usize; ndim];
    // Scratch multi-index into `t`, refreshed from `out_idx` per element.
    let mut src_idx = vec![0usize; ndim];

    for flat in 0..out_len {
        // Decode the flat output offset, last mode fastest.
        let mut rem = flat;
        for (slot, &extent) in out_idx.iter_mut().zip(out_shape.iter()).rev() {
            *slot = rem % extent;
            rem /= extent;
        }

        let row = out_idx[mode];
        src_idx.clone_from(&out_idx);

        let mut acc = FixedPoint::ZERO;
        for k in 0..contracted {
            src_idx[mode] = k;
            acc = acc + m.get(row, k) * t.get(&src_idx);
        }
        out.set(&out_idx, acc);
    }
    out
}
/// Column-wise Khatri–Rao product of every factor except `factors[skip]`.
///
/// The result has one row per combination of indices over the non-skipped
/// modes (extents taken from `shape`) and `factors[0].cols()` columns. Rows
/// are ordered with the last non-skipped mode varying fastest.
///
/// # Panics
///
/// Panics if `factors` is empty.
fn khatri_rao_except(factors: &[FixedMatrix], skip: usize, shape: &[usize]) -> FixedMatrix {
    let rank = factors[0].cols();
    let ndim = factors.len();

    let mut rows: usize = 1;
    for (mode, &extent) in shape.iter().enumerate() {
        if mode != skip {
            rows *= extent;
        }
    }

    // Modes that take part in the product, in ascending order.
    let active: Vec<usize> = (0..ndim).filter(|&mode| mode != skip).collect();

    let mut product = FixedMatrix::new(rows, rank);
    for r in 0..rank {
        for row in 0..rows {
            // Peel indices off the row number from the fastest mode backwards,
            // multiplying in the matching factor entry as each is decoded.
            let mut rem = row;
            let entry = active.iter().rev().fold(FixedPoint::one(), |acc, &mode| {
                let idx = rem % shape[mode];
                rem /= shape[mode];
                acc * factors[mode].get(idx, r)
            });
            product.set(row, r, entry);
        }
    }
    product
}
/// Adds one weighted rank-one component to a flat tensor buffer.
///
/// `data` is interpreted as a tensor of the given `shape` with the last mode
/// varying fastest. Each element at multi-index `(i_0, …, i_{N-1})` is
/// increased by `weight · Π_n factors[n][i_n, r]`.
fn add_rank1_to_flat(
    data: &mut [FixedPoint],
    shape: &[usize],
    factors: &[FixedMatrix],
    r: usize,
    weight: FixedPoint,
) {
    let mut multi_index = vec![0usize; shape.len()];

    for (flat, cell) in data.iter_mut().enumerate() {
        // Decode the flat offset into one index per mode.
        let mut rem = flat;
        for (slot, &extent) in multi_index.iter_mut().zip(shape.iter()).rev() {
            *slot = rem % extent;
            rem /= extent;
        }

        // Start from the weight and multiply in each mode's factor entry.
        let term = multi_index
            .iter()
            .enumerate()
            .fold(weight, |acc, (mode, &idx)| acc * factors[mode].get(idx, r));

        *cell = *cell + term;
    }
}