use crate::error::{FFTError, FFTResult};
/// Multiplies two dense row-major matrices.
///
/// `a` is read as an `m × k` matrix and `b` as a `k × n` matrix; the returned
/// vector holds the `m × n` product in row-major order.
///
/// # Errors
///
/// Returns `FFTError::ValueError` when `a.len() != m * k` or
/// `b.len() != k * n`.
pub fn matmul(a: &[f64], m: usize, k: usize, b: &[f64], n: usize) -> FFTResult<Vec<f64>> {
    if a.len() != m * k {
        let msg = format!("matmul: a has {} elements, expected {}", a.len(), m * k);
        return Err(FFTError::ValueError(msg));
    }
    if b.len() != k * n {
        let msg = format!("matmul: b has {} elements, expected {}", b.len(), k * n);
        return Err(FFTError::ValueError(msg));
    }

    let mut product = vec![0.0_f64; m * n];
    for row in 0..m {
        let lhs_row = &a[row * k..(row + 1) * k];
        let out_row = &mut product[row * n..(row + 1) * n];
        // Accumulate one scaled row of `b` at a time so the innermost loop
        // walks both operands contiguously.
        for (inner, &weight) in lhs_row.iter().enumerate() {
            let rhs_row = &b[inner * n..(inner + 1) * n];
            for (acc, &rhs) in out_row.iter_mut().zip(rhs_row) {
                *acc += weight * rhs;
            }
        }
    }
    Ok(product)
}
/// Returns the square root of the sum of squares of the entries of `v`.
pub fn frobenius_norm(v: &[f64]) -> f64 {
    let sum_of_squares: f64 = v.iter().map(|&entry| entry * entry).sum();
    sum_of_squares.sqrt()
}
/// Computes an SVD of the row-major `m × n` matrix `a` via [`bidiag_svd`] and
/// keeps only the leading `r = min(max_rank, m, n)` components.
///
/// Returns `(u, s, vt)` where `u` holds `m × r` entries, `s` holds `r`
/// entries and `vt` holds `r × n` entries, all row-major.
///
/// # Errors
///
/// Returns `FFTError::ValueError` when `a.len() != m * n`, and forwards any
/// error produced by [`bidiag_svd`].
pub fn truncated_svd(
    a: &[f64],
    m: usize,
    n: usize,
    max_rank: usize,
) -> FFTResult<(Vec<f64>, Vec<f64>, Vec<f64>)> {
    if a.len() != m * n {
        let msg = format!("truncated_svd: a has {} elements, expected {}", a.len(), m * n);
        return Err(FFTError::ValueError(msg));
    }

    let full_rank = m.min(n);
    let keep = max_rank.min(full_rank);
    let (u_full, s_full, vt_full) = bidiag_svd(a, m, n)?;

    // Keep the first `keep` columns of every row of U.
    let mut u = Vec::with_capacity(m * keep);
    for row in 0..m {
        let start = row * full_rank;
        u.extend_from_slice(&u_full[start..start + keep]);
    }

    let s = s_full[..keep].to_vec();

    // The first `keep` rows of Vt are contiguous in row-major storage.
    let vt = vt_full[..keep * n].to_vec();

    Ok((u, s, vt))
}
/// Computes a thin singular value decomposition of the row-major `m × n`
/// matrix `a`.
///
/// The matrix is reduced to bidiagonal form with [`householder_bidiag`] and the
/// bidiagonal factor is then iterated with [`golub_kahan_step`] until every
/// super-diagonal entry is negligible relative to its neighbouring diagonal
/// entries (or the iteration cap of `1000 * min(m, n)` steps is reached, in
/// which case the current approximation is returned as-is).
///
/// Returns `(u, s, vt)` with `r = min(m, n)`:
/// * `u`  – `m × r`, row-major,
/// * `s`  – `r` non-negative values sorted in descending order,
/// * `vt` – `r × n`, row-major.
///
/// Wide inputs (`m < n`) are handled by factoring the transpose and swapping
/// the roles of the two orthogonal factors, because the bidiagonal reduction
/// used here only yields a square bidiagonal factor when `m >= n`.
///
/// # Errors
///
/// Returns `FFTError::ValueError` when either dimension is zero or when
/// `a.len() != m * n`, and forwards errors from the helper routines.
pub fn bidiag_svd(a: &[f64], m: usize, n: usize) -> FFTResult<(Vec<f64>, Vec<f64>, Vec<f64>)> {
    if m == 0 || n == 0 {
        return Err(FFTError::ValueError("bidiag_svd: zero-dimension matrix".into()));
    }
    if a.len() != m * n {
        return Err(FFTError::ValueError(format!(
            "bidiag_svd: a has {} elements, expected {}",
            a.len(),
            m * n
        )));
    }

    if m < n {
        // A^T = U' S V'^T  =>  A = V' S U'^T.
        let mut a_t = vec![0.0_f64; n * m];
        for i in 0..m {
            for j in 0..n {
                a_t[j * m + i] = a[i * n + j];
            }
        }
        let (u_t, s, vt_t) = bidiag_svd(&a_t, n, m)?;
        let r = m;
        // u = (vt_t)^T, where vt_t is r × m.
        let mut u = Vec::with_capacity(m * r);
        for i in 0..m {
            for k in 0..r {
                u.push(vt_t[k * m + i]);
            }
        }
        // vt = (u_t)^T, where u_t is n × r.
        let mut vt = Vec::with_capacity(r * n);
        for k in 0..r {
            for j in 0..n {
                vt.push(u_t[j * r + k]);
            }
        }
        return Ok((u, s, vt));
    }

    let r_full = m.min(n);
    let (mut ub, mut b_diag, mut b_super, mut vbt) = householder_bidiag(a, m, n)?;

    let max_iter = 1000 * r_full;
    for _ in 0..max_iter {
        let pending = b_super
            .iter()
            .enumerate()
            .any(|(k, e)| e.abs() > 1e-14 * (b_diag[k].abs() + b_diag[k + 1].abs()));
        if !pending {
            break;
        }
        golub_kahan_step(&mut b_diag, &mut b_super, &mut ub, &mut vbt, m, n)?;
    }

    // Singular values are reported non-negative: flip the sign of a negative
    // diagonal entry together with the matching column of U.
    for k in 0..r_full {
        if b_diag[k] < 0.0 {
            b_diag[k] = -b_diag[k];
            for i in 0..m {
                ub[i * r_full + k] = -ub[i * r_full + k];
            }
        }
    }

    // Sort descending and permute the columns of U / rows of Vt accordingly.
    let mut order: Vec<usize> = (0..r_full).collect();
    order.sort_by(|&lhs, &rhs| {
        b_diag[rhs]
            .partial_cmp(&b_diag[lhs])
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let s: Vec<f64> = order.iter().map(|&k| b_diag[k]).collect();

    let mut u: Vec<f64> = Vec::with_capacity(m * r_full);
    for i in 0..m {
        u.extend(order.iter().map(|&k| ub[i * r_full + k]));
    }

    let mut vt: Vec<f64> = Vec::with_capacity(r_full * n);
    for &k in order.iter() {
        vt.extend_from_slice(&vbt[k * n..(k + 1) * n]);
    }

    Ok((u, s, vt))
}
/// Reduces the row-major `m × n` matrix `a` towards upper-bidiagonal form with
/// alternating column and row Householder reflectors.
///
/// Returns `(u_r, diag, super_d, vt_r)` with `r = min(m, n)`:
/// * `u_r`     – first `r` columns of the accumulated left reflectors (`m × r`),
/// * `diag`    – `r` diagonal entries,
/// * `super_d` – `r - 1` super-diagonal entries (empty when `r == 0`),
/// * `vt_r`    – first `r` columns of the accumulated right reflectors,
///   stored transposed (`r × n`).
///
/// NOTE(review): row reflectors are only generated for `k < r - 1`, so when
/// `m < n` the entries of row `r - 1` to the right of the diagonal are neither
/// eliminated nor returned — callers should confirm they only pass `m >= n`.
fn householder_bidiag(
    a: &[f64],
    m: usize,
    n: usize,
) -> FFTResult<(Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>)> {
    let rank = m.min(n);
    let mut work = a.to_vec();
    let mut left_acc = eye(m);
    let mut right_acc = eye(n);
    let mut diag = vec![0.0_f64; rank];
    let mut super_d = vec![0.0_f64; rank.saturating_sub(1)];

    for k in 0..rank {
        // Column reflector: acts on rows k.. of column k.
        let mut col_reflector: Vec<f64> = (k..m).map(|row| work[row * n + k]).collect();
        let col_norm = householder_vec(&mut col_reflector);
        // The pivot is read before the reflector is applied; the reflected
        // value carries the opposite sign of the original pivot.
        diag[k] = if work[k * n + k] >= 0.0 { -col_norm } else { col_norm };
        if col_norm.abs() > 1e-15 {
            apply_householder_left(&mut work, m, n, k, &col_reflector, col_norm);
            apply_householder_right_rect(&mut left_acc, m, m, k, &col_reflector, col_norm);
        }

        // No row reflector for the last column or the last bidiagonal row.
        if k + 1 >= n || k + 1 >= rank {
            continue;
        }

        // Row reflector: acts on columns k+1.. of row k.
        let mut row_reflector: Vec<f64> = (k + 1..n).map(|col| work[k * n + col]).collect();
        let row_norm = householder_vec(&mut row_reflector);
        super_d[k] = if work[k * n + k + 1] >= 0.0 { -row_norm } else { row_norm };
        if row_norm.abs() > 1e-15 {
            apply_householder_right(&mut work, m, n, k, &row_reflector, row_norm);
            apply_householder_right_rect(&mut right_acc, n, n, k + 1, &row_reflector, row_norm);
        }
    }

    // Leading `rank` columns of the m × m left accumulator.
    let mut u_r = Vec::with_capacity(m * rank);
    for row in 0..m {
        u_r.extend_from_slice(&left_acc[row * m..row * m + rank]);
    }

    // Leading `rank` columns of the n × n right accumulator, transposed.
    let mut vt_r = Vec::with_capacity(rank * n);
    for k in 0..rank {
        for col in 0..n {
            vt_r.push(right_acc[col * n + k]);
        }
    }

    Ok((u_r, diag, super_d, vt_r))
}
/// Turns `v` in place into a unit Householder direction and returns the
/// Euclidean norm of the original vector.
///
/// When the norm is below `1e-15` the vector is left untouched and `0.0` is
/// returned. Otherwise the norm is added to the first entry with the sign of
/// that entry and the result is normalised to unit length.
fn householder_vec(v: &mut Vec<f64>) -> f64 {
    let norm = v.iter().map(|&x| x * x).sum::<f64>().sqrt();
    if norm < 1e-15 {
        return 0.0;
    }

    // Push the first entry away from zero to avoid cancellation.
    if v[0] >= 0.0 {
        v[0] += norm;
    } else {
        v[0] -= norm;
    }

    let length = v.iter().map(|&x| x * x).sum::<f64>().sqrt();
    if length > 1e-15 {
        v.iter_mut().for_each(|x| *x /= length);
    }
    norm
}
/// Applies the reflector `I - 2·v·vᵀ` from the left to the row-major matrix
/// `mat` (row length `n`), touching rows `k..k + v.len()` and columns `k..n`.
///
/// `m` and `_sigma` are accepted for signature symmetry with the other
/// reflector helpers and are not read here.
fn apply_householder_left(mat: &mut Vec<f64>, m: usize, n: usize, k: usize, v: &[f64], _sigma: f64) {
    for col in k..n {
        // Projection of the affected part of this column onto `v`.
        let projection = v
            .iter()
            .enumerate()
            .map(|(offset, weight)| weight * mat[(k + offset) * n + col])
            .sum::<f64>();
        let twice = 2.0 * projection;
        for (offset, weight) in v.iter().enumerate() {
            mat[(k + offset) * n + col] -= twice * weight;
        }
    }
}
/// Applies the reflector `I - 2·v·vᵀ` from the right to the row-major matrix
/// `mat` (row length `n`), touching rows `k..m` and columns
/// `k + 1..k + 1 + v.len()`.
///
/// `_sigma` is not read.
fn apply_householder_right(mat: &mut Vec<f64>, m: usize, n: usize, k: usize, v: &[f64], _sigma: f64) {
    let first_col = k + 1;
    for row in k..m {
        let base = row * n;
        // Projection of the affected part of this row onto `v`.
        let projection = v
            .iter()
            .enumerate()
            .map(|(offset, weight)| weight * mat[base + (first_col + offset)])
            .sum::<f64>();
        let twice = 2.0 * projection;
        for (offset, weight) in v.iter().enumerate() {
            mat[base + (first_col + offset)] -= twice * weight;
        }
    }
}
/// Applies the reflector `I - 2·v·vᵀ` from the right to every row of the
/// row-major `rows × cols` matrix `mat`, touching columns
/// `start_col..start_col + v.len()`.
///
/// `_sigma` is not read.
fn apply_householder_right_rect(
    mat: &mut Vec<f64>,
    rows: usize,
    cols: usize,
    start_col: usize,
    v: &[f64],
    _sigma: f64,
) {
    for row in 0..rows {
        let base = row * cols;
        // Projection of the affected part of this row onto `v`.
        let projection = v
            .iter()
            .enumerate()
            .map(|(offset, weight)| weight * mat[base + (start_col + offset)])
            .sum::<f64>();
        let twice = 2.0 * projection;
        for (offset, weight) in v.iter().enumerate() {
            mat[base + (start_col + offset)] -= twice * weight;
        }
    }
}
fn golub_kahan_step(
diag: &mut Vec<f64>,
super_d: &mut Vec<f64>,
u: &mut Vec<f64>,
vt: &mut Vec<f64>,
m: usize,
n: usize,
) -> FFTResult<()> {
let r = diag.len();
if r == 0 { return Ok(()); }
let t = if r >= 2 {
let d = diag[r - 1];
let e = super_d.last().copied().unwrap_or(0.0);
let d2 = diag[r - 2];
let e2 = if r >= 3 { super_d.get(r - 3).copied().unwrap_or(0.0) } else { 0.0 };
let tr = d2 * d2 + e2 * e2 + d * d + e * e;
let det = (d2 * d2 + e2 * e2) * (d * d + e * e) - e2 * e2 * d * d;
let mu1 = tr * 0.5 + (tr * tr * 0.25 - det).max(0.0).sqrt();
let mu2 = tr * 0.5 - (tr * tr * 0.25 - det).max(0.0).sqrt();
let shift1 = (mu1 - d * d - e * e).abs();
let shift2 = (mu2 - d * d - e * e).abs();
if shift1 < shift2 { mu1 } else { mu2 }
} else {
0.0
};
let mut f = diag[0] * diag[0] - t;
let mut g = diag[0] * super_d.first().copied().unwrap_or(0.0);
for k in 0..r.saturating_sub(1) {
let (c, s) = givens(f, g);
if k > 0 {
let sd = super_d[k - 1];
super_d[k - 1] = c * sd + s * diag[k];
diag[k] = -s * sd + c * diag[k];
} else {
diag[k] = c * diag[k] + s * (super_d.get(k).copied().unwrap_or(0.0));
}
if k < super_d.len() {
let d_k = diag[k];
let e_k = super_d[k];
let d_k1 = diag[k + 1];
let new_dk = c * d_k + s * e_k;
let new_ek = -s * d_k + c * e_k; let new_dk1 = d_k1;
f = new_dk;
g = s * new_dk1;
super_d[k] = s * new_ek + c * (new_dk1 * 0.0);
let r_full = n.min(m);
if k + 1 < r_full {
for i in 0..n {
let vk = vt[k * n + i];
let vk1 = vt[(k + 1) * n + i];
vt[k * n + i] = c * vk + s * vk1;
vt[(k + 1) * n + i] = -s * vk + c * vk1;
}
}
let (c2, s2) = givens(f, g);
let _ = (c2, s2);
let new_ek2 = c2 * new_ek + s2 * d_k1;
let new_dk1_2 = -s2 * new_ek + c2 * d_k1;
super_d[k] = new_ek2;
diag[k + 1] = new_dk1_2;
diag[k] = new_dk;
if k + 1 < r - 1 {
let e_next = super_d[k + 1];
f = new_dk1_2;
g = s2 * e_next;
super_d[k + 1] = c2 * e_next;
}
for i in 0..m {
let uk = u[i * r + k];
let uk1 = u[i * r + k + 1];
u[i * r + k] = c2 * uk + s2 * uk1;
u[i * r + k + 1] = -s2 * uk + c2 * uk1;
}
}
}
Ok(())
}
/// Returns `(c, s)` of a plane rotation with `-s·f + c·g == 0`, i.e. the
/// rotation that annihilates `g` against `f`.
///
/// * `g == 0` yields the identity `(1, 0)`.
/// * `f == 0` yields `(0, 1)`.
/// * Otherwise `c = f / h` and `s = g / h` with `h = hypot(f, g)`.
///
/// The zero tests are exact and the length is computed with `hypot`, so the
/// result depends only on the ratio of the inputs: very small pairs are not
/// mistaken for zero and very large pairs do not overflow while squaring.
fn givens(f: f64, g: f64) -> (f64, f64) {
    if g == 0.0 {
        return (1.0, 0.0);
    }
    if f == 0.0 {
        return (0.0, 1.0);
    }
    let h = f.hypot(g);
    (f / h, g / h)
}
/// Builds the `n × n` identity matrix in row-major order.
fn eye(n: usize) -> Vec<f64> {
    let mut identity = vec![0.0_f64; n * n];
    // Diagonal entries of a row-major square matrix are `n + 1` apart.
    identity
        .iter_mut()
        .step_by(n + 1)
        .for_each(|entry| *entry = 1.0);
    identity
}
/// Unfolds (matricises) a row-major tensor along `mode`.
///
/// The result is a row-major matrix with `shape[mode]` rows and one column per
/// combination of the remaining indices. Columns are ordered by the axes
/// `mode + 1, …, d - 1, 0, …, mode - 1`, with the last axis of that list
/// varying fastest.
///
/// Returns `(matrix, n_rows, n_cols)`. A zero-length axis is supported and
/// yields an empty matrix; `n_cols` is always the product of the extents of
/// the axes other than `mode`.
///
/// # Errors
///
/// Returns `FFTError::ValueError` when `mode >= shape.len()` or when
/// `tensor.len()` differs from the product of `shape`.
pub fn n_mode_unfolding(tensor: &[f64], shape: &[usize], mode: usize) -> FFTResult<(Vec<f64>, usize, usize)> {
    let d = shape.len();
    if mode >= d {
        return Err(FFTError::ValueError(
            format!("n_mode_unfolding: mode {mode} ≥ d={d}"),
        ));
    }
    let n_total: usize = shape.iter().product();
    if tensor.len() != n_total {
        return Err(FFTError::ValueError(
            format!("n_mode_unfolding: tensor length {} ≠ {}", tensor.len(), n_total),
        ));
    }

    let n_rows = shape[mode];
    // Computed as a product rather than `n_total / n_rows` so that a
    // zero-length `mode` axis does not divide by zero.
    let n_cols: usize = shape
        .iter()
        .enumerate()
        .filter(|&(axis, _)| axis != mode)
        .map(|(_, &len)| len)
        .product();

    // Row-major strides of the source tensor.
    let mut src_stride = vec![1usize; d];
    for k in (0..d - 1).rev() {
        src_stride[k] = src_stride[k + 1] * shape[k + 1];
    }

    // Stride of every axis inside the unfolded matrix: the `mode` axis selects
    // the row, the remaining axes build the column index.
    let mut dst_stride = vec![0usize; d];
    dst_stride[mode] = n_cols;
    let mut acc = 1usize;
    for axis in (0..mode).rev().chain((mode + 1..d).rev()) {
        dst_stride[axis] = acc;
        acc *= shape[axis];
    }

    let mut mat = vec![0.0_f64; n_total];
    for (flat_idx, &value) in tensor.iter().enumerate() {
        let mut rem = flat_idx;
        let mut dest = 0usize;
        for axis in 0..d {
            dest += (rem / src_stride[axis]) * dst_stride[axis];
            rem %= src_stride[axis];
        }
        mat[dest] = value;
    }

    Ok((mat, n_rows, n_cols))
}
/// Folds a mode-`mode` unfolding back into a row-major tensor of the given
/// `shape`.
///
/// `mat` is read as a row-major matrix with `shape[mode]` rows whose columns
/// are ordered by the axes `mode + 1, …, d - 1, 0, …, mode - 1` (last axis of
/// that list fastest). A zero-length axis is supported and yields an empty
/// tensor.
///
/// # Errors
///
/// Returns `FFTError::ValueError` when `mode >= shape.len()` or when
/// `mat.len()` differs from the number of tensor entries.
pub fn n_mode_folding(mat: &[f64], shape: &[usize], mode: usize) -> FFTResult<Vec<f64>> {
    let d = shape.len();
    if mode >= d {
        return Err(FFTError::ValueError(
            format!("n_mode_folding: mode {mode} ≥ d={d}"),
        ));
    }
    let n_total: usize = shape.iter().product();
    let n_rows = shape[mode];
    // Computed as a product rather than `n_total / n_rows` so that a
    // zero-length `mode` axis does not divide by zero.
    let n_cols: usize = shape
        .iter()
        .enumerate()
        .filter(|&(axis, _)| axis != mode)
        .map(|(_, &len)| len)
        .product();
    if mat.len() != n_rows * n_cols {
        return Err(FFTError::ValueError(
            format!("n_mode_folding: mat length {} ≠ {}", mat.len(), n_rows * n_cols),
        ));
    }

    // Row-major strides of the destination tensor.
    let mut dst_stride = vec![1usize; d];
    for k in (0..d - 1).rev() {
        dst_stride[k] = dst_stride[k + 1] * shape[k + 1];
    }

    // Stride of every axis inside the unfolded matrix: the `mode` axis selects
    // the row, the remaining axes build the column index.
    let mut src_stride = vec![0usize; d];
    src_stride[mode] = n_cols;
    let mut acc = 1usize;
    for axis in (0..mode).rev().chain((mode + 1..d).rev()) {
        src_stride[axis] = acc;
        acc *= shape[axis];
    }

    let mut tensor = vec![0.0_f64; n_total];
    for (flat_idx, slot) in tensor.iter_mut().enumerate() {
        let mut rem = flat_idx;
        let mut src = 0usize;
        for axis in 0..d {
            src += (rem / dst_stride[axis]) * src_stride[axis];
            rem %= dst_stride[axis];
        }
        *slot = mat[src];
    }

    Ok(tensor)
}
/// Computes the mode-`mode` product of a row-major tensor with the row-major
/// `r × shape[mode]` matrix `u`.
///
/// The tensor is unfolded along `mode`, multiplied from the left by `u`, and
/// folded back into a tensor whose `mode` axis has length `r`.
///
/// # Errors
///
/// Returns `FFTError::ValueError` when `u.len() != r * shape[mode]`, and
/// forwards errors from [`n_mode_unfolding`], [`matmul`] and
/// [`n_mode_folding`].
pub fn n_mode_product(
    tensor: &[f64],
    shape: &[usize],
    mode: usize,
    u: &[f64],
    r: usize,
) -> FFTResult<Vec<f64>> {
    // Unfolding also validates `mode` and the tensor length.
    let (unfolded, mode_len, other_len) = n_mode_unfolding(tensor, shape, mode)?;

    if u.len() != r * mode_len {
        let msg = format!(
            "n_mode_product: U has {} elements, expected {}×{}={}",
            u.len(),
            r,
            mode_len,
            r * mode_len
        );
        return Err(FFTError::ValueError(msg));
    }

    let projected = matmul(u, r, mode_len, &unfolded, other_len)?;

    // Same shape as the input except that the `mode` axis now has length `r`.
    let out_shape: Vec<usize> = shape
        .iter()
        .enumerate()
        .map(|(axis, &len)| if axis == mode { r } else { len })
        .collect();

    n_mode_folding(&projected, &out_shape, mode)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Multiplying by the 2×2 identity leaves the right-hand matrix unchanged.
    #[test]
    fn test_matmul_identity() {
        let identity = vec![1.0, 0.0, 0.0, 1.0];
        let rhs = vec![3.0, 4.0, 5.0, 6.0];
        let product = matmul(&identity, 2, 2, &rhs, 2).expect("failed to create c");
        assert!((product[0] - 3.0).abs() < 1e-12);
        assert!((product[3] - 6.0).abs() < 1e-12);
    }

    /// Unfolding followed by folding reproduces the tensor for every mode.
    #[test]
    fn test_n_mode_unfolding_roundtrip() {
        let shape = vec![2, 3, 4];
        let total = shape.iter().product::<usize>();
        let tensor: Vec<f64> = (0..total).map(|i| i as f64).collect();

        for mode in 0..3 {
            let (unfolded, rows, cols) =
                n_mode_unfolding(&tensor, &shape, mode).expect("unexpected None or Err");
            assert_eq!(rows, shape[mode]);
            assert_eq!(cols, total / shape[mode]);
            assert_eq!(unfolded.len(), total);

            let recovered =
                n_mode_folding(&unfolded, &shape, mode).expect("failed to create recovered");
            for (a, b) in tensor.iter().zip(recovered.iter()) {
                assert!((a - b).abs() < 1e-12, "roundtrip failed at a={a} b={b}");
            }
        }
    }

    /// The norm of (3, 4) is 5.
    #[test]
    fn test_frobenius_norm() {
        let sample = vec![3.0_f64, 4.0];
        assert!((frobenius_norm(&sample) - 5.0).abs() < 1e-12);
    }
}