use crate::error::{LinalgError, LinalgResult};
use crate::tensor_decomp::tensor_utils::{mat_mul, truncated_svd};
/// A single third-order core `G_k ∈ R^{r_left × n × r_right}` of a tensor train.
#[derive(Debug, Clone)]
pub struct TTCore {
    /// Core entries in row-major order: element `(r_left, n, r_right)` lives at
    /// flat offset `r_left * shape.1 * shape.2 + n * shape.2 + r_right`.
    pub data: Vec<f64>,
    /// Dimensions `(r_left, n, r_right)`: left TT-rank, physical mode size, right TT-rank.
    pub shape: (usize, usize, usize),
}
impl TTCore {
    /// Flat offset of element `(r_left, n_idx, r_right)` within `data`
    /// (row-major over the three dimensions).
    #[inline]
    fn offset(&self, r_left: usize, n_idx: usize, r_right: usize) -> usize {
        (r_left * self.shape.1 + n_idx) * self.shape.2 + r_right
    }

    /// Builds a core from a flat buffer, validating that the buffer length
    /// equals the product of the three dimensions.
    pub fn new(data: Vec<f64>, shape: (usize, usize, usize)) -> LinalgResult<Self> {
        let expected = shape.0 * shape.1 * shape.2;
        if data.len() == expected {
            Ok(Self { data, shape })
        } else {
            Err(LinalgError::ShapeError(format!(
                "TTCore: data length {} != shape {:?} (expected {})",
                data.len(),
                shape,
                expected
            )))
        }
    }

    /// Zero-initialized core of the given dimensions.
    pub fn zeros(shape: (usize, usize, usize)) -> Self {
        let (rl, n, rr) = shape;
        Self {
            data: vec![0.0_f64; rl * n * rr],
            shape,
        }
    }

    /// Reads the entry at `(r_left, n_idx, r_right)`.
    #[inline]
    pub fn get(&self, r_left: usize, n_idx: usize, r_right: usize) -> f64 {
        self.data[self.offset(r_left, n_idx, r_right)]
    }

    /// Writes `v` at `(r_left, n_idx, r_right)`.
    #[inline]
    pub fn set(&mut self, r_left: usize, n_idx: usize, r_right: usize, v: f64) {
        let idx = self.offset(r_left, n_idx, r_right);
        self.data[idx] = v;
    }

    /// Extracts the `r_left × r_right` matrix slice `G[:, n_idx, :]` for a
    /// fixed physical index, as a row-major `Vec<Vec<f64>>`.
    pub fn slice_n(&self, n_idx: usize) -> Vec<Vec<f64>> {
        let (r_left, _, r_right) = self.shape;
        let mut rows = Vec::with_capacity(r_left);
        for rl in 0..r_left {
            rows.push((0..r_right).map(|rr| self.get(rl, n_idx, rr)).collect());
        }
        rows
    }
}
/// Tensor in tensor-train (TT / matrix-product-state) format: a chain of
/// third-order cores with boundary ranks 1, so that
/// `X[i_0, …, i_{d-1}] = G_0[:, i_0, :] · G_1[:, i_1, :] ⋯ G_{d-1}[:, i_{d-1}, :]`.
#[derive(Debug, Clone)]
pub struct TensorTrain {
    /// One core per mode; core `k` has shape `(r_k, n_k, r_{k+1})` with `r_0 = r_d = 1`.
    pub cores: Vec<TTCore>,
    /// Physical mode sizes `n_0, …, n_{d-1}`.
    pub shape: Vec<usize>,
}
impl TensorTrain {
    /// Evaluates a single element `X[indices]` by multiplying the matrix
    /// slices `G_0[:, i_0, :] · G_1[:, i_1, :] ⋯`, which collapses to a 1×1
    /// matrix for a valid TT (boundary ranks 1).
    ///
    /// # Errors
    /// `ShapeError` if `indices.len()` differs from the number of cores,
    /// `IndexError` if any index is out of range for its mode, and
    /// `ComputationError` if the accumulated product is not 1×1 (malformed TT).
    pub fn get(&self, indices: &[usize]) -> LinalgResult<f64> {
        let d = self.cores.len();
        if indices.len() != d {
            return Err(LinalgError::ShapeError(format!(
                "TT get: indices length {} != d={}",
                indices.len(),
                d
            )));
        }
        for (k, (&idx, &n_k)) in indices.iter().zip(self.shape.iter()).enumerate() {
            if idx >= n_k {
                return Err(LinalgError::IndexError(format!(
                    "TT get: index {idx} out of bounds for mode {k} (size {n_k})"
                )));
            }
        }
        // Left-to-right contraction: after mode k, `current` is 1 × r_{k+1}.
        let mut current = self.cores[0].slice_n(indices[0]);
        for k in 1..d {
            let g_slice = self.cores[k].slice_n(indices[k]);
            // FIX: was `mat_mul(¤t, …)` — a mojibake corruption of
            // `&current` that did not compile.
            current = mat_mul(&current, &g_slice)?;
        }
        if current.len() == 1 && current[0].len() == 1 {
            Ok(current[0][0])
        } else {
            Err(LinalgError::ComputationError(
                "TT get: final matrix is not 1×1".to_string(),
            ))
        }
    }

    /// TT-ranks `[r_0, r_1, …, r_d]`. `r_0` is 1 by construction; each
    /// following entry is the right rank of the corresponding core.
    pub fn ranks(&self) -> Vec<usize> {
        let mut r = vec![1_usize];
        for core in &self.cores {
            r.push(core.shape.2);
        }
        r
    }

    /// Frobenius norm via `sqrt(<X, X>)`; the inner product is clamped at 0
    /// to guard against tiny negative round-off.
    pub fn frobenius_norm(&self) -> LinalgResult<f64> {
        let ip = inner_product(self, self)?;
        Ok(ip.max(0.0).sqrt())
    }

    /// Rounds this TT to rank ≤ `max_rank` with relative tolerance `eps`.
    /// Convenience wrapper around [`tt_round`].
    pub fn compress(&self, max_rank: usize, eps: f64) -> LinalgResult<TensorTrain> {
        tt_round(self, max_rank, eps)
    }

    /// Structural TT addition. Convenience wrapper around [`tt_add`].
    pub fn add(&self, other: &TensorTrain) -> LinalgResult<TensorTrain> {
        tt_add(self, other)
    }

    /// Elementwise product. Convenience wrapper around [`tt_hadamard`].
    pub fn hadamard(&self, other: &TensorTrain) -> LinalgResult<TensorTrain> {
        tt_hadamard(self, other)
    }

    /// Frobenius inner product `<self, other>`. Wrapper around the free
    /// function [`inner_product`].
    pub fn inner_product(&self, other: &TensorTrain) -> LinalgResult<f64> {
        inner_product(self, other)
    }
}
/// Decomposes a dense row-major tensor into tensor-train format via the
/// TT-SVD sweep: repeatedly unfold the remainder into a matrix, take a
/// truncated SVD, keep `U` as the next core, and carry `S·Vᵀ` forward.
///
/// `eps` is the target relative Frobenius error for the whole tensor and is
/// split evenly over the `d − 1` truncation steps as `eps / sqrt(d − 1)`.
/// `max_rank` caps every TT-rank.
///
/// # Errors
/// `ShapeError` for an empty shape, a zero-sized mode, or a data/shape
/// length mismatch; `DomainError` when `max_rank == 0`; any error from the
/// underlying `truncated_svd` is propagated.
pub fn tt_svd(
    tensor_data: &[f64],
    shape: &[usize],
    max_rank: usize,
    eps: f64,
) -> LinalgResult<TensorTrain> {
    if shape.is_empty() {
        return Err(LinalgError::ShapeError(
            "tt_svd: shape must be non-empty".to_string(),
        ));
    }
    // FIX: a zero-sized mode slipped past the length check (total == 0 with
    // empty data) and then hit `c[0]` on a degenerate unfolding or called
    // truncated_svd with rank 0; reject it explicitly.
    if shape.contains(&0) {
        return Err(LinalgError::ShapeError(
            "tt_svd: all mode sizes must be ≥ 1".to_string(),
        ));
    }
    let d = shape.len();
    let total: usize = shape.iter().product();
    if tensor_data.len() != total {
        return Err(LinalgError::ShapeError(format!(
            "tt_svd: data length {} != product(shape)={}",
            tensor_data.len(),
            total
        )));
    }
    if max_rank == 0 {
        return Err(LinalgError::DomainError(
            "tt_svd: max_rank must be ≥ 1".to_string(),
        ));
    }
    // Distribute the error budget across the d − 1 truncation steps.
    let step_eps = if d > 1 { eps / ((d - 1) as f64).sqrt() } else { eps };
    let mut cores: Vec<TTCore> = Vec::with_capacity(d);
    let mut r_left = 1_usize;
    // Current unfolding: (r_left · n_k) rows × (product of remaining modes) cols.
    let mut c: Vec<Vec<f64>> = {
        let n_rest: usize = shape.iter().skip(1).product();
        let rows = shape[0];
        let mut mat = vec![vec![0.0_f64; n_rest]; rows];
        for row in 0..rows {
            for col in 0..n_rest {
                mat[row][col] = tensor_data[row * n_rest + col];
            }
        }
        mat
    };
    for k in 0..(d - 1) {
        let n_k = shape[k];
        let rows = r_left * n_k;
        let cols = c[0].len();
        debug_assert_eq!(c.len(), rows, "rows mismatch at k={k}");
        // The SVD rank can never exceed min(rows, cols); also honor the cap.
        let rank_cap = max_rank.min(rows).min(cols);
        let (u_full, s_full, vt_full) = truncated_svd(&c, rank_cap)?;
        let r_k = determine_rank(&s_full, step_eps, max_rank);
        // Core G_k = the first r_k left singular vectors, reshaped to
        // (r_left, n_k, r_k). Bounds guards tolerate an SVD returning fewer
        // columns than requested.
        let mut core = TTCore::zeros((r_left, n_k, r_k));
        for rl in 0..r_left {
            for ni in 0..n_k {
                let row_idx = rl * n_k + ni;
                for rr in 0..r_k {
                    if row_idx < u_full.len() && rr < u_full[row_idx].len() {
                        core.set(rl, ni, rr, u_full[row_idx][rr]);
                    }
                }
            }
        }
        cores.push(core);
        // Carry S·Vᵀ forward, reshaped to (r_k · n_{k+1}) × (rest).
        let n_next = shape[k + 1];
        let n_remaining: usize = if k + 2 < d {
            shape[k + 2..].iter().product()
        } else {
            1
        };
        let new_rows = r_k * n_next;
        let new_cols = n_remaining;
        let vt_trunc: Vec<Vec<f64>> = vt_full
            .iter()
            .take(r_k)
            .enumerate()
            .map(|(ri, row)| row.iter().map(|v| v * s_full[ri]).collect())
            .collect();
        c = vec![vec![0.0_f64; new_cols]; new_rows];
        for rr in 0..r_k {
            for ni in 0..n_next {
                let out_row = rr * n_next + ni;
                for nc in 0..n_remaining {
                    let in_col = ni * n_remaining + nc;
                    if in_col < vt_trunc[0].len() {
                        c[out_row][nc] = vt_trunc[rr][in_col];
                    }
                }
            }
        }
        r_left = r_k;
    }
    // Last core absorbs whatever remains of the carried matrix (right rank 1).
    let n_d = shape[d - 1];
    let mut last_core = TTCore::zeros((r_left, n_d, 1));
    for rl in 0..r_left {
        for ni in 0..n_d {
            let v = if rl < c.len() && ni < c[rl].len() {
                c[rl][ni]
            } else {
                0.0
            };
            last_core.set(rl, ni, 0, v);
        }
    }
    cores.push(last_core);
    Ok(TensorTrain {
        shape: shape.to_vec(),
        cores,
    })
}
/// Re-compresses a tensor train by densifying it and re-running TT-SVD with
/// the given rank cap and tolerance.
///
/// NOTE(review): tensors with more than 1_000_000 entries are returned
/// unchanged (a clone) rather than rounded, because the intermediate dense
/// representation would be too large.
pub fn tt_round(tt: &TensorTrain, max_rank: usize, eps: f64) -> LinalgResult<TensorTrain> {
    if tt.cores.is_empty() {
        return Err(LinalgError::ShapeError(
            "tt_round: empty TT tensor".to_string(),
        ));
    }
    let total: usize = tt.shape.iter().product();
    if total > 1_000_000 {
        // Too large to densify — skip rounding entirely.
        return Ok(tt.clone());
    }
    let mut dense = vec![0.0_f64; total];
    fill_dense(tt, &mut dense)?;
    tt_svd(&dense, &tt.shape, max_rank, eps)
}
/// Structural TT addition `Z = X + Y`.
///
/// Interior cores are stacked block-diagonally, the first core side-by-side,
/// and the last core on top of each other, so interior ranks add:
/// `r_k(Z) = r_k(X) + r_k(Y)` for `0 < k < d`.
///
/// # Errors
/// `ShapeError` if the operands have different orders or mode sizes.
pub fn tt_add(x: &TensorTrain, y: &TensorTrain) -> LinalgResult<TensorTrain> {
    let d = x.cores.len();
    if d != y.cores.len() {
        return Err(LinalgError::ShapeError(format!(
            "tt_add: X has {} cores but Y has {}",
            d,
            y.cores.len()
        )));
    }
    for k in 0..d {
        if x.shape[k] != y.shape[k] {
            return Err(LinalgError::ShapeError(format!(
                "tt_add: physical dim mismatch at mode {k}: {} vs {}",
                x.shape[k], y.shape[k]
            )));
        }
    }
    // FIX: previously d == 1 fell into the k == 0 branch below, which wrote
    // Y's entries at r_right = rrx + rr = 1 inside a (1, n, 1) core — an
    // out-of-bounds panic at ni = n-1 and silent corruption before that.
    // A one-core TT is just a vector, so add elementwise.
    if d == 1 {
        let cx = &x.cores[0];
        let cy = &y.cores[0];
        let n = x.shape[0];
        let mut core = TTCore::zeros((1, n, 1));
        for ni in 0..n {
            core.set(0, ni, 0, cx.get(0, ni, 0) + cy.get(0, ni, 0));
        }
        return Ok(TensorTrain {
            cores: vec![core],
            shape: x.shape.clone(),
        });
    }
    let mut new_cores: Vec<TTCore> = Vec::with_capacity(d);
    for k in 0..d {
        let cx = &x.cores[k];
        let cy = &y.cores[k];
        let (rlx, n_k, rrx) = cx.shape;
        let (rly, _, rry) = cy.shape;
        // Boundary ranks stay 1; interior ranks are summed.
        let new_rl = if k == 0 { 1 } else { rlx + rly };
        let new_rr = if k == d - 1 { 1 } else { rrx + rry };
        let mut core = TTCore::zeros((new_rl, n_k, new_rr));
        if k == 0 {
            // First core: X's columns followed by Y's columns (row vector).
            for ni in 0..n_k {
                for rr in 0..rrx {
                    core.set(0, ni, rr, cx.get(0, ni, rr));
                }
                for rr in 0..rry {
                    core.set(0, ni, rrx + rr, cy.get(0, ni, rr));
                }
            }
        } else if k == d - 1 {
            // Last core: X's rows stacked over Y's rows (column vector).
            for ni in 0..n_k {
                for rl in 0..rlx {
                    core.set(rl, ni, 0, cx.get(rl, ni, 0));
                }
                for rl in 0..rly {
                    core.set(rlx + rl, ni, 0, cy.get(rl, ni, 0));
                }
            }
        } else {
            // Interior core: block-diagonal [X 0; 0 Y] per physical index.
            for ni in 0..n_k {
                for rl in 0..rlx {
                    for rr in 0..rrx {
                        core.set(rl, ni, rr, cx.get(rl, ni, rr));
                    }
                }
                for rl in 0..rly {
                    for rr in 0..rry {
                        core.set(rlx + rl, ni, rrx + rr, cy.get(rl, ni, rr));
                    }
                }
            }
        }
        new_cores.push(core);
    }
    Ok(TensorTrain {
        cores: new_cores,
        shape: x.shape.clone(),
    })
}
/// Elementwise (Hadamard) product of two tensor trains.
///
/// Each result core is the slice-wise Kronecker product of the operands'
/// cores, so the ranks multiply: `r_k(Z) = r_k(X) · r_k(Y)`.
///
/// # Errors
/// `ShapeError` if the operands have different orders or mode sizes.
pub fn tt_hadamard(x: &TensorTrain, y: &TensorTrain) -> LinalgResult<TensorTrain> {
    let d = x.cores.len();
    if d != y.cores.len() {
        return Err(LinalgError::ShapeError(format!(
            "tt_hadamard: X has {} cores but Y has {}",
            d,
            y.cores.len()
        )));
    }
    for k in 0..d {
        if x.shape[k] != y.shape[k] {
            return Err(LinalgError::ShapeError(format!(
                "tt_hadamard: shape mismatch at mode {k}: {} vs {}",
                x.shape[k], y.shape[k]
            )));
        }
    }
    let mut product_cores: Vec<TTCore> = Vec::with_capacity(d);
    for (cx, cy) in x.cores.iter().zip(y.cores.iter()) {
        let (xl, n_k, xr) = cx.shape;
        let (yl, _, yr) = cy.shape;
        let mut core = TTCore::zeros((xl * yl, n_k, xr * yr));
        for ni in 0..n_k {
            for a in 0..xl {
                for b in 0..yl {
                    // Kronecker row index for the (a, b) pair.
                    let left = a * yl + b;
                    for p in 0..xr {
                        let gx = cx.get(a, ni, p);
                        for q in 0..yr {
                            // Kronecker column index for the (p, q) pair.
                            core.set(left, ni, p * yr + q, gx * cy.get(b, ni, q));
                        }
                    }
                }
            }
        }
        product_cores.push(core);
    }
    Ok(TensorTrain {
        cores: product_cores,
        shape: x.shape.clone(),
    })
}
/// Frobenius inner product `<X, Y>` of two tensor trains, computed by sweeping
/// a transfer matrix `T_k` (size `r_k(X) × r_k(Y)`) left to right:
/// `T_k[a,b] = Σ_{a',b',i} T_{k-1}[a',b'] · X_k[a',i,a] · Y_k[b',i,b]`.
///
/// # Errors
/// `ShapeError` if the operands have different orders or mode sizes;
/// `ComputationError` if the final transfer matrix is not 1×1 (malformed TT).
pub fn inner_product(x: &TensorTrain, y: &TensorTrain) -> LinalgResult<f64> {
    let d = x.cores.len();
    if d != y.cores.len() {
        return Err(LinalgError::ShapeError(format!(
            "inner_product: X has {} cores but Y has {}",
            d,
            y.cores.len()
        )));
    }
    // FIX: unlike tt_add/tt_hadamard, mode sizes were never validated here;
    // a mismatch caused out-of-bounds core accesses in the sweep below.
    for k in 0..d {
        if x.shape[k] != y.shape[k] {
            return Err(LinalgError::ShapeError(format!(
                "inner_product: shape mismatch at mode {k}: {} vs {}",
                x.shape[k], y.shape[k]
            )));
        }
    }
    // T_{-1} is the 1×1 identity (boundary ranks are 1).
    let mut transfer = vec![vec![1.0_f64]];
    for k in 0..d {
        let cx = &x.cores[k];
        let cy = &y.cores[k];
        let (rlx, n_k, rrx) = cx.shape;
        let (rly, _, rry) = cy.shape;
        let mut new_transfer = vec![vec![0.0_f64; rry]; rrx];
        for rr_x in 0..rrx {
            for rr_y in 0..rry {
                let mut val = 0.0_f64;
                for rl_x in 0..rlx {
                    for rl_y in 0..rly {
                        // Tolerate a transfer matrix smaller than the cores'
                        // left ranks by treating missing entries as 0.
                        let t_prev = if rl_x < transfer.len() && rl_y < transfer[rl_x].len() {
                            transfer[rl_x][rl_y]
                        } else {
                            0.0
                        };
                        // Skip dead entries: saves the O(n_k) inner loop.
                        if t_prev == 0.0 {
                            continue;
                        }
                        for ni in 0..n_k {
                            val += t_prev * cx.get(rl_x, ni, rr_x) * cy.get(rl_y, ni, rr_y);
                        }
                    }
                }
                new_transfer[rr_x][rr_y] = val;
            }
        }
        transfer = new_transfer;
    }
    if transfer.len() == 1 && transfer[0].len() == 1 {
        Ok(transfer[0][0])
    } else {
        Err(LinalgError::ComputationError(
            "inner_product: final transfer matrix is not 1×1".to_string(),
        ))
    }
}
/// Chooses how many singular values to keep so that the *discarded* tail
/// energy satisfies `Σ_{i ≥ rank} s_i² ≤ (eps · ||s||)²`, capped at
/// `max_rank` and floored at 1.
fn determine_rank(s: &[f64], eps: f64, max_rank: usize) -> usize {
    if s.is_empty() {
        return 1;
    }
    if eps == 0.0 {
        // No truncation requested: keep everything up to the cap.
        return s.len().min(max_rank).max(1);
    }
    let total_sq: f64 = s.iter().map(|v| v * v).sum();
    if total_sq == 0.0 {
        return 1;
    }
    let threshold = eps * eps * total_sq;
    // Grow the tail from the smallest singular value; the first index whose
    // inclusion would overshoot the error budget must be kept.
    let mut tail_sq = 0.0_f64;
    let mut rank = s.len();
    for (i, &sv) in s.iter().enumerate().rev() {
        tail_sq += sv * sv;
        if tail_sq > threshold {
            rank = i + 1;
            break;
        }
        rank = i;
    }
    rank.max(1).min(max_rank)
}
/// Transposes a rectangular row-major matrix (column count taken from the
/// first row); an empty input yields an empty matrix.
fn transpose_mat(mat: &[Vec<f64>]) -> Vec<Vec<f64>> {
    match mat.first() {
        None => Vec::new(),
        Some(first_row) => (0..first_row.len())
            .map(|j| mat.iter().map(|row| row[j]).collect())
            .collect(),
    }
}
/// Truncates every row of `u_transposed` to at most its first `r_k` entries
/// (rows shorter than `r_k` are copied whole).
fn u_full_trunc_rows(u_transposed: &[Vec<f64>], r_k: usize) -> Vec<Vec<f64>> {
    let mut trimmed = Vec::with_capacity(u_transposed.len());
    for row in u_transposed {
        let keep = row.len().min(r_k);
        trimmed.push(row[..keep].to_vec());
    }
    trimmed
}
/// Expands a tensor train into the dense row-major buffer `out`.
///
/// Walks the modes depth-first, carrying the partial left contraction (a
/// vector of length `r_k`), so shared prefixes of the index tuple are
/// contracted once instead of per element.
///
/// # Errors
/// `ShapeError` if `out.len()` differs from the product of the mode sizes.
fn fill_dense(tt: &TensorTrain, out: &mut [f64]) -> LinalgResult<()> {
    let total: usize = tt.shape.iter().product();
    if out.len() != total {
        return Err(LinalgError::ShapeError(format!(
            "fill_dense: output length {} != total {}",
            out.len(),
            total
        )));
    }
    // Invariant: on entry, `stride == prod(tt.shape[mode..])`, i.e. the number
    // of dense elements spanned by the remaining modes; `stride / n_k` is then
    // the row-major stride of the current mode.
    fn fill_recursive(
        tt: &TensorTrain,
        mode: usize,
        current_transfer: &[f64],
        r_prev: usize,
        base_flat: usize,
        stride: usize,
        out: &mut [f64],
    ) -> LinalgResult<()> {
        let d = tt.shape.len();
        let n_k = tt.shape[mode];
        let core = &tt.cores[mode];
        let r_next = core.shape.2;
        if mode == d - 1 {
            // Last mode: contract the carried vector with the final core
            // (right rank 1) and write the resulting scalars.
            for ni in 0..n_k {
                let flat_idx = base_flat + ni;
                let mut val = 0.0_f64;
                for rl in 0..r_prev {
                    val += current_transfer[rl] * core.get(rl, ni, 0);
                }
                out[flat_idx] = val;
            }
        } else {
            let stride_next = stride / n_k;
            for ni in 0..n_k {
                // Absorb this core's slice for index `ni` into the carry.
                let mut new_transfer = vec![0.0_f64; r_next];
                for rl in 0..r_prev {
                    for rr in 0..r_next {
                        new_transfer[rr] += current_transfer[rl] * core.get(rl, ni, rr);
                    }
                }
                fill_recursive(
                    tt,
                    mode + 1,
                    &new_transfer,
                    r_next,
                    base_flat + ni * stride_next,
                    stride_next,
                    out,
                )?;
            }
        }
        Ok(())
    }
    // FIX: the initial stride was `total / tt.shape[0]`, which violated the
    // recursion invariant and shifted every mode's stride down by one level
    // (e.g. mode 0 of a 2×3×4 tensor used stride 6 instead of 12), scattering
    // and overwriting entries. The root call must start from `total`.
    let init_transfer = vec![1.0_f64];
    fill_recursive(tt, 0, &init_transfer, 1, 0, total, out)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Shared fixture: a 2×3×4 tensor whose row-major entries are 1..=24,
    // so X[i,j,k] = 12i + 4j + k + 1.
    fn small_tensor_data() -> (Vec<f64>, Vec<usize>) {
        let shape = vec![2_usize, 3, 4];
        let data: Vec<f64> = (0..24).map(|x| x as f64 + 1.0).collect();
        (data, shape)
    }

    // TT-SVD with a generous rank cap and tight eps must reproduce the dense
    // tensor to near machine precision after re-densifying with fill_dense.
    #[test]
    fn test_tt_svd_reconstruct() {
        let (data, shape) = small_tensor_data();
        let tt = tt_svd(&data, &shape, 100, 1e-10).expect("tt_svd ok");
        let total: usize = shape.iter().product();
        let mut reconstructed = vec![0.0_f64; total];
        fill_dense(&tt, &mut reconstructed).expect("fill ok");
        for (i, (&orig, &rec)) in data.iter().zip(reconstructed.iter()).enumerate() {
            assert!(
                (orig - rec).abs() < 1e-6,
                "element {i}: orig={orig}, rec={rec}"
            );
        }
    }

    // Single-element evaluation: the fixture gives X[1,2,3] = 24, X[0,0,0] = 1.
    #[test]
    fn test_tt_get() {
        let (data, shape) = small_tensor_data();
        let tt = tt_svd(&data, &shape, 100, 1e-10).expect("tt_svd ok");
        let val = tt.get(&[1, 2, 3]).expect("get ok");
        assert!((val - 24.0).abs() < 1e-6, "X[1,2,3] = {val}");
        let val2 = tt.get(&[0, 0, 0]).expect("get ok");
        assert!((val2 - 1.0).abs() < 1e-6, "X[0,0,0] = {val2}");
    }

    // The rank vector has d+1 entries with boundary ranks fixed at 1.
    #[test]
    fn test_tt_ranks() {
        let (data, shape) = small_tensor_data();
        let tt = tt_svd(&data, &shape, 100, 1e-10).expect("ok");
        let ranks = tt.ranks();
        assert_eq!(ranks[0], 1, "r_0 must be 1");
        assert_eq!(*ranks.last().expect("last"), 1, "r_d must be 1");
        assert_eq!(ranks.len(), shape.len() + 1);
    }

    // Structural addition: (X + X)[0,0,0] should be 2 · X[0,0,0] = 2.
    #[test]
    fn test_tt_add() {
        let (data, shape) = small_tensor_data();
        let tt = tt_svd(&data, &shape, 100, 1e-10).expect("ok");
        let tt2 = tt_add(&tt, &tt).expect("add ok");
        let val = tt2.get(&[0, 0, 0]).expect("get ok");
        assert!((val - 2.0).abs() < 1e-5, "X+X[0,0,0] = {val}");
    }

    // Hadamard product squares each entry: 1² = 1 and 24² = 576.
    #[test]
    fn test_tt_hadamard() {
        let (data, shape) = small_tensor_data();
        let tt = tt_svd(&data, &shape, 100, 1e-10).expect("ok");
        let tt_had = tt_hadamard(&tt, &tt).expect("hadamard ok");
        let val = tt_had.get(&[0, 0, 0]).expect("get ok");
        assert!((val - 1.0).abs() < 1e-4, "X⊙X[0,0,0] = {val}");
        let val2 = tt_had.get(&[1, 2, 3]).expect("get ok");
        assert!((val2 - 576.0).abs() < 1e-2, "X⊙X[1,2,3] = {val2}");
    }

    // <X, X> must equal the sum of squared entries of the dense tensor.
    #[test]
    fn test_tt_inner_product() {
        let (data, shape) = small_tensor_data();
        let tt = tt_svd(&data, &shape, 100, 1e-10).expect("ok");
        let ip = inner_product(&tt, &tt).expect("inner ok");
        let expected: f64 = data.iter().map(|v| v * v).sum();
        assert!((ip - expected).abs() < 1e-4, "inner product {ip} != {expected}");
    }

    // Rank-2 compression of a near-low-rank tensor stays close elementwise
    // (loose 2.0 tolerance: only approximation quality is asserted).
    #[test]
    fn test_tt_compress() {
        let data: Vec<f64> = (0..24).map(|x| x as f64).collect();
        let shape = vec![2_usize, 3, 4];
        let tt = tt_svd(&data, &shape, 100, 0.0).expect("ok");
        let tt_small = tt.compress(2, 1e-3).expect("compress ok");
        let val_orig = tt.get(&[0, 1, 2]).expect("ok");
        let val_comp = tt_small.get(&[0, 1, 2]).expect("ok");
        assert!(
            (val_orig - val_comp).abs() < 2.0,
            "compressed val {val_comp} too far from {val_orig}"
        );
    }

    // Data length (2) inconsistent with shape product (9) must be rejected.
    #[test]
    fn test_tt_svd_error_bad_shape() {
        let data = vec![1.0_f64, 2.0];
        let result = tt_svd(&data, &[3, 3], 10, 0.0);
        assert!(result.is_err());
    }

    // TTCore::new validates the buffer length against the shape product.
    #[test]
    fn test_tt_core_new_validates() {
        let bad = TTCore::new(vec![1.0; 5], (2, 2, 2));
        assert!(bad.is_err(), "Should fail: 5 != 2*2*2=8");
        let good = TTCore::new(vec![0.0; 8], (2, 2, 2));
        assert!(good.is_ok());
    }
}