use crate::error::InterpolateError;
use scirs2_core::ndarray::{Array2, Array3, ArrayD, IxDyn};
/// Tensor-train (TT) representation of a `d`-dimensional tensor as a
/// chain of 3-D cores, where core `k` has shape
/// `(ranks[k], shape[k], ranks[k + 1])` and the boundary ranks are 1.
#[derive(Debug, Clone)]
pub struct TensorTrain {
    /// TT cores; entry `(i, j, l)` of core `k` couples left rank index `i`,
    /// mode index `j`, and right rank index `l`.
    pub cores: Vec<Array3<f64>>,
    /// Mode sizes `n_0..n_{d-1}` of the represented tensor.
    pub shape: Vec<usize>,
    /// TT ranks `r_0..r_d`; `r_0 == r_d == 1` for trains built via `new`.
    pub ranks: Vec<usize>,
}
impl TensorTrain {
    /// Validates the cores and assembles a `TensorTrain`.
    ///
    /// Core `k` must have shape `(r_k, n_k, r_{k+1})`: adjacent cores must
    /// chain their ranks and the boundary ranks `r_0` and `r_d` must be 1.
    ///
    /// # Errors
    /// `InvalidInput` if `cores` is empty, neighbouring ranks disagree, or
    /// a boundary rank is not 1.
    pub fn new(cores: Vec<Array3<f64>>) -> Result<Self, InterpolateError> {
        if cores.is_empty() {
            return Err(InterpolateError::InvalidInput {
                message: "TensorTrain requires at least one core".into(),
            });
        }
        let d = cores.len();
        let mut shape = Vec::with_capacity(d);
        let mut ranks = Vec::with_capacity(d + 1);
        ranks.push(cores[0].shape()[0]);
        for (k, core) in cores.iter().enumerate() {
            // `Array3` is statically 3-D, so only the rank chain needs checking.
            let s = core.shape();
            let prev_rank = *ranks.last().expect("ranks is seeded before the loop");
            if k > 0 && s[0] != prev_rank {
                return Err(InterpolateError::InvalidInput {
                    message: format!(
                        "Left rank of core {k} ({}) does not match right rank of core {} ({})",
                        s[0],
                        k - 1,
                        prev_rank,
                    ),
                });
            }
            shape.push(s[1]);
            ranks.push(s[2]);
        }
        // `ranks` now holds exactly d + 1 entries, so indexing with `d` is safe.
        if ranks[0] != 1 || ranks[d] != 1 {
            return Err(InterpolateError::InvalidInput {
                message: format!(
                    "Boundary ranks must be 1, got r_0={} and r_d={}",
                    ranks[0], ranks[d]
                ),
            });
        }
        Ok(Self { cores, shape, ranks })
    }

    /// Evaluates a single tensor entry by contracting the core chain
    /// left-to-right: `v <- v * G_k[:, idx[k], :]`.
    ///
    /// # Errors
    /// `DimensionMismatch` if `idx.len()` differs from the number of
    /// dimensions; `OutOfBounds` if any index exceeds its mode size;
    /// `ComputationError` if the final rank of a hand-built train is 0.
    pub fn eval(&self, idx: &[usize]) -> Result<f64, InterpolateError> {
        let d = self.cores.len();
        if idx.len() != d {
            return Err(InterpolateError::DimensionMismatch(format!(
                "idx has length {}, expected {d}",
                idx.len()
            )));
        }
        for (k, (&ik, &nk)) in idx.iter().zip(self.shape.iter()).enumerate() {
            if ik >= nk {
                return Err(InterpolateError::OutOfBounds(format!(
                    "Index {ik} out of range [0, {nk}) in dimension {k}"
                )));
            }
        }
        // Running row vector; starts as the 1-element identity (r_0 == 1).
        let mut v = vec![1.0f64];
        for (k, &ik) in idx.iter().enumerate() {
            let core = &self.cores[k];
            let (r_left, r_right) = (core.shape()[0], core.shape()[2]);
            let mut new_v = vec![0.0f64; r_right];
            for j in 0..r_right {
                let mut s = 0.0f64;
                for i in 0..r_left {
                    s += v[i] * core[[i, ik, j]];
                }
                new_v[j] = s;
            }
            v = new_v;
        }
        // A valid TT ends with rank 1; since the fields are `pub`, a
        // hand-built train could violate that — report it instead of panicking.
        v.first().copied().ok_or_else(|| {
            InterpolateError::ComputationError("eval: final TT rank is 0".to_string())
        })
    }

    /// Materializes the full dense tensor (cost grows with the product of
    /// all mode sizes; intended for small tensors and testing).
    ///
    /// # Errors
    /// Propagates `eval` errors and reports shape mismatches from the
    /// final array construction.
    pub fn to_dense(&self) -> Result<ArrayD<f64>, InterpolateError> {
        let d = self.cores.len();
        let total: usize = self.shape.iter().product();
        let mut data = vec![0.0f64; total];
        // Skip enumeration for degenerate (zero-sized) tensors; writing
        // `data[0]` would otherwise panic on an empty buffer.
        if total > 0 {
            let mut idx = vec![0usize; d];
            'entries: loop {
                data[row_major_index(&idx, &self.shape)] = self.eval(&idx)?;
                // Odometer increment, least-significant dimension last.
                let mut k = d;
                loop {
                    if k == 0 {
                        // Carried past the most significant digit: all done.
                        break 'entries;
                    }
                    k -= 1;
                    idx[k] += 1;
                    if idx[k] < self.shape[k] {
                        break;
                    }
                    idx[k] = 0;
                }
            }
        }
        ArrayD::from_shape_vec(IxDyn(&self.shape), data)
            .map_err(|e| InterpolateError::ComputationError(format!("to_dense shape error: {e}")))
    }

    /// Frobenius norm computed core-by-core via Gram matrix propagation,
    /// without materializing the dense tensor.
    pub fn norm(&self) -> f64 {
        if self.cores.is_empty() {
            return 0.0;
        }
        // gram[a1, a2] accumulates, over all processed mode indices, the
        // product of the two partial left contractions; starts as the 1x1 identity.
        let mut gram = Array2::<f64>::eye(1);
        for core in &self.cores {
            let (r_left, n, r_right) = (core.shape()[0], core.shape()[1], core.shape()[2]);
            let mut new_gram = Array2::<f64>::zeros((r_right, r_right));
            for ik in 0..n {
                for beta1 in 0..r_right {
                    for beta2 in 0..r_right {
                        let mut s = 0.0f64;
                        for alpha1 in 0..r_left {
                            for alpha2 in 0..r_left {
                                s += core[[alpha1, ik, beta1]]
                                    * gram[[alpha1, alpha2]]
                                    * core[[alpha2, ik, beta2]];
                            }
                        }
                        new_gram[[beta1, beta2]] += s;
                    }
                }
            }
            gram = new_gram;
        }
        // Clamp tiny negative values caused by floating-point round-off.
        gram[[0, 0]].max(0.0).sqrt()
    }

    /// Total number of stored parameters across all cores.
    pub fn n_params(&self) -> usize {
        self.cores.iter().map(|c| c.len()).sum()
    }

    /// Compresses a dense tensor into TT format; see [`tt_svd`].
    ///
    /// # Errors
    /// Propagates any error from [`tt_svd`].
    pub fn from_dense(
        tensor: &ArrayD<f64>,
        max_rank: usize,
        tol: f64,
    ) -> Result<Self, InterpolateError> {
        tt_svd(tensor, max_rank, tol)
    }
}
/// Converts a multi-index into a flat row-major (C-order) offset.
///
/// Equivalent to `sum_k idx[k] * prod(shape[k+1..])`, computed by a
/// Horner-style left fold.
fn row_major_index(idx: &[usize], shape: &[usize]) -> usize {
    idx.iter()
        .zip(shape.iter())
        .fold(0usize, |acc, (&i, &n)| acc * n + i)
}
/// Computes a rank-truncated SVD of `a`, keeping at most `max_rank`
/// singular triplets and dropping those below `tol` relative to the
/// Frobenius norm. Returns `(U, s, V^T)` with `U: m x r` and `V^T: r x n`.
///
/// # Errors
/// `InvalidInput` for an empty matrix; `ComputationError` if the factor
/// buffers cannot be reshaped (internal inconsistency).
pub fn truncated_svd(
    a: &Array2<f64>,
    max_rank: usize,
    tol: f64,
) -> Result<(Array2<f64>, Vec<f64>, Array2<f64>), InterpolateError> {
    let (m, n) = (a.nrows(), a.ncols());
    if m == 0 || n == 0 {
        return Err(InterpolateError::InvalidInput {
            message: "truncated_svd: matrix must have positive dimensions".into(),
        });
    }
    // The effective rank can never exceed min(m, n).
    let cap = max_rank.min(m).min(n);
    let flat: Vec<f64> = a.iter().copied().collect();
    let (u_raw, singular, vt_raw) = svd_deflation(&flat, m, n, cap, tol);
    let rank = singular.len();
    let u = Array2::from_shape_vec((m, rank), u_raw)
        .map_err(|e| InterpolateError::ComputationError(format!("SVD U shape error: {e}")))?;
    let vt = Array2::from_shape_vec((rank, n), vt_raw)
        .map_err(|e| InterpolateError::ComputationError(format!("SVD Vt shape error: {e}")))?;
    Ok((u, singular, vt))
}
/// Rank-revealing SVD via greedy deflation with power iteration.
///
/// Peels singular triplets off a residual copy of the `m` x `n` row-major
/// matrix `a` (largest first), stopping after `max_rank` triplets or when
/// a singular value drops below `tol * ||a||_F`. Returns `(U, s, V^T)` as
/// flat row-major buffers of shapes `m x r`, `r`, and `r x n`.
fn svd_deflation(
    a: &[f64],
    m: usize,
    n: usize,
    max_rank: usize,
    tol: f64,
) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
    let norm2 = |v: &[f64]| -> f64 { v.iter().map(|x| x * x).sum::<f64>().sqrt() };
    // Degenerate fallback: first unit basis vectors in each space.
    let fallback = || {
        let mut u0 = vec![0.0f64; m];
        let mut v0 = vec![0.0f64; n];
        if let Some(first) = u0.first_mut() {
            *first = 1.0;
        }
        if let Some(first) = v0.first_mut() {
            *first = 1.0;
        }
        (u0, v0)
    };

    let sigma_ref = norm2(a);
    if sigma_ref < 1e-300 {
        // Numerically zero input: report a single zero singular value.
        let (u0, v0) = fallback();
        return (u0, vec![0.0], v0);
    }
    let threshold = tol * sigma_ref;

    let mut residual = a.to_vec();
    let mut u_cols: Vec<Vec<f64>> = Vec::new();
    let mut s_vals: Vec<f64> = Vec::new();
    let mut v_rows: Vec<Vec<f64>> = Vec::new();

    while s_vals.len() < max_rank {
        // Seed v with the residual row of largest 2-norm (first on ties).
        let mut best_row = 0usize;
        let mut best_norm = 0.0f64;
        for i in 0..m {
            let row_norm = norm2(&residual[i * n..(i + 1) * n]);
            if row_norm > best_norm {
                best_norm = row_norm;
                best_row = i;
            }
        }
        let mut vk = if m > 0 {
            residual[best_row * n..(best_row + 1) * n].to_vec()
        } else {
            vec![0.0f64; n]
        };
        let vnorm = norm2(&vk);
        if vnorm < 1e-300 {
            // Residual is (numerically) zero: nothing left to extract.
            break;
        }
        vk.iter_mut().for_each(|x| *x /= vnorm);

        // Alternate u <- R v, v <- R^T u until v stabilizes.
        let mut uk = vec![0.0f64; m];
        for _ in 0..20 {
            for i in 0..m {
                uk[i] = (0..n).map(|j| residual[i * n + j] * vk[j]).sum::<f64>();
            }
            let unorm = norm2(&uk);
            if unorm < 1e-300 {
                break;
            }
            uk.iter_mut().for_each(|x| *x /= unorm);

            let next_v: Vec<f64> = (0..n)
                .map(|j| (0..m).map(|i| residual[i * n + j] * uk[i]).sum::<f64>())
                .collect();
            let next_norm = norm2(&next_v);
            if next_norm < 1e-300 {
                break;
            }
            // Convergence measure: distance between successive normalized v's.
            let drift = next_v
                .iter()
                .zip(vk.iter())
                .map(|(x, y)| (x / next_norm - y).powi(2))
                .sum::<f64>()
                .sqrt();
            for (dst, &src) in vk.iter_mut().zip(next_v.iter()) {
                *dst = src / next_norm;
            }
            if drift < 1e-12 {
                break;
            }
        }

        // The unnormalized image R v has length sigma.
        let mut u_new: Vec<f64> = (0..m)
            .map(|i| (0..n).map(|j| residual[i * n + j] * vk[j]).sum::<f64>())
            .collect();
        let sigma = norm2(&u_new);
        if sigma < threshold {
            break;
        }
        u_new.iter_mut().for_each(|x| *x /= sigma);
        // Deflate: remove the captured rank-1 component from the residual.
        for i in 0..m {
            for j in 0..n {
                residual[i * n + j] -= sigma * u_new[i] * vk[j];
            }
        }
        u_cols.push(u_new);
        s_vals.push(sigma);
        v_rows.push(vk);
    }

    if s_vals.is_empty() {
        // Nothing exceeded the threshold: emit one placeholder triplet so
        // callers always receive a rank >= 1 factorization.
        let (u0, v0) = fallback();
        return (u0, vec![threshold.max(1e-300)], v0);
    }

    // Pack the column vectors into a row-major m x r matrix.
    let r = s_vals.len();
    let mut u_data = vec![0.0f64; m * r];
    for (k, col) in u_cols.iter().enumerate() {
        for (i, &val) in col.iter().enumerate() {
            u_data[i * r + k] = val;
        }
    }
    let vt_data: Vec<f64> = v_rows.into_iter().flatten().collect();
    (u_data, s_vals, vt_data)
}
/// Decomposes a dense tensor into tensor-train format via sequential
/// truncated SVDs (the TT-SVD algorithm).
///
/// `max_rank` caps every internal TT rank; `tol` is the relative
/// truncation tolerance forwarded to each SVD.
///
/// # Errors
/// `InvalidInput` for a 0-dimensional tensor or `max_rank == 0`;
/// `ComputationError` if an internal reshape fails.
pub fn tt_svd(
    tensor: &ArrayD<f64>,
    max_rank: usize,
    tol: f64,
) -> Result<TensorTrain, InterpolateError> {
    let shape: Vec<usize> = tensor.shape().to_vec();
    let d = shape.len();
    if d == 0 {
        return Err(InterpolateError::InvalidInput {
            message: "tt_svd: tensor must have at least one dimension".into(),
        });
    }
    if max_rank == 0 {
        return Err(InterpolateError::InvalidInput {
            message: "tt_svd: max_rank must be >= 1".into(),
        });
    }
    let mut cores = Vec::with_capacity(d);
    let mut r_left = 1usize;
    let mut remainder: Vec<f64> = tensor.iter().copied().collect();
    for k in 0..d {
        let n_k = shape[k];
        let n_right: usize = shape[k + 1..].iter().product::<usize>().max(1);
        let rows = r_left * n_k;
        let cols = n_right;
        // Move the buffer instead of cloning it: `remainder` is either
        // rebuilt below (k < d-1) or never read again (last core).
        let mat =
            Array2::from_shape_vec((rows, cols), std::mem::take(&mut remainder)).map_err(|e| {
                InterpolateError::ComputationError(format!("tt_svd reshape error at k={k}: {e}"))
            })?;
        if k < d - 1 {
            let (u, s, vt) = truncated_svd(&mat, max_rank, tol)?;
            let r_right = s.len();
            let u_flat: Vec<f64> = u.iter().copied().collect();
            let core = Array3::from_shape_vec((r_left, n_k, r_right), u_flat).map_err(|e| {
                InterpolateError::ComputationError(format!("tt_svd core shape error k={k}: {e}"))
            })?;
            cores.push(core);
            // Fold the singular values into V^T: remainder = diag(s) * V^T.
            let mut new_rem = vec![0.0f64; r_right * cols];
            for i in 0..r_right {
                let si = s[i];
                for j in 0..cols {
                    new_rem[i * cols + j] = si * vt[[i, j]];
                }
            }
            remainder = new_rem;
            r_left = r_right;
        } else {
            // Last core absorbs the full remaining matrix; right rank is 1.
            let mat_flat: Vec<f64> = mat.iter().copied().collect();
            let core = Array3::from_shape_vec((r_left, n_k, 1), mat_flat).map_err(|e| {
                InterpolateError::ComputationError(format!(
                    "tt_svd last core shape error k={k}: {e}"
                ))
            })?;
            cores.push(core);
        }
    }
    TensorTrain::new(cores)
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::IxDyn;

    /// Builds a rank-1 TT whose every entry is 1.0 (all cores are ones).
    fn make_rank1_tt(shape: &[usize]) -> TensorTrain {
        let cores: Vec<Array3<f64>> = shape.iter().map(|&n| Array3::ones((1, n, 1))).collect();
        TensorTrain::new(cores).expect("rank-1 TT valid")
    }

    // eval on a separable 2-D train must reproduce (i + 1) * (j + 1).
    #[test]
    fn test_tt_eval_correct() {
        let core0 = Array3::from_shape_fn((1, 3, 1), |(_, i, _)| (i + 1) as f64);
        let core1 = Array3::from_shape_fn((1, 4, 1), |(_, j, _)| (j + 1) as f64);
        let tt = TensorTrain::new(vec![core0, core1]).expect("valid TT");
        for i in 0..3 {
            for j in 0..4 {
                let val = tt.eval(&[i, j]).expect("eval ok");
                let expected = ((i + 1) * (j + 1)) as f64;
                assert!(
                    (val - expected).abs() < 1e-12,
                    "TT[{i},{j}] expected {expected} got {val}"
                );
            }
        }
    }

    // The all-ones 2x2 tensor has Frobenius norm sqrt(4) = 2.
    #[test]
    fn test_tt_norm() {
        let tt = make_rank1_tt(&[2, 2]);
        let norm = tt.norm();
        assert!((norm - 2.0).abs() < 1e-10, "norm={norm}");
    }

    // Rank-1 cores of mode sizes 3 and 4 store 3 + 4 = 7 parameters.
    #[test]
    fn test_tt_n_params() {
        let tt = make_rank1_tt(&[3, 4]);
        assert_eq!(tt.n_params(), 7);
    }

    // TT-SVD of an exactly rank-1 matrix must reconstruct it entrywise.
    #[test]
    fn test_tt_svd_2d() {
        let a = [1.0, 2.0, 3.0f64];
        let b = [1.0, -1.0, 2.0, -2.0f64];
        let data: Vec<f64> = a
            .iter()
            .flat_map(|&ai| b.iter().map(move |&bj| ai * bj))
            .collect();
        let tensor = ArrayD::from_shape_vec(IxDyn(&[3, 4]), data).expect("valid");
        let tt = tt_svd(&tensor, 4, 1e-10).expect("TT-SVD ok");
        assert_eq!(tt.shape, vec![3, 4]);
        for i in 0..3 {
            for j in 0..4 {
                let val = tt.eval(&[i, j]).expect("eval ok");
                let expected = a[i] * b[j];
                assert!(
                    (val - expected).abs() < 1e-7,
                    "TT-SVD[{i},{j}] expected {expected:.6} got {val:.6}"
                );
            }
        }
    }

    // A constant tensor is TT-rank 1, so compression should need few parameters
    // (2 + 2 + 2 = 6 for rank-1 cores; 8 allows some slack).
    #[test]
    fn test_tt_from_dense_rank_compression() {
        let tensor = ArrayD::ones(IxDyn(&[2, 2, 2]));
        let tt = TensorTrain::from_dense(&tensor, 4, 1e-8).expect("from_dense ok");
        assert!(tt.n_params() <= 8, "n_params={}", tt.n_params());
        for i in 0..2 {
            for j in 0..2 {
                for k in 0..2 {
                    let val = tt.eval(&[i, j, k]).expect("eval ok");
                    assert!((val - 1.0).abs() < 1e-6, "val={val}");
                }
            }
        }
    }

    // to_dense of a separable 2x2 train must equal the outer product [1,2] x [1,2].
    #[test]
    fn test_tt_to_dense() {
        let core0 = Array3::from_shape_fn((1, 2, 1), |(_, i, _)| (i + 1) as f64);
        let core1 = Array3::from_shape_fn((1, 2, 1), |(_, j, _)| (j + 1) as f64);
        let tt = TensorTrain::new(vec![core0, core1]).expect("valid");
        let dense = tt.to_dense().expect("to_dense ok");
        assert_eq!(dense.shape(), &[2, 2]);
        assert!((dense[[0, 0]] - 1.0).abs() < 1e-12);
        assert!((dense[[0, 1]] - 2.0).abs() < 1e-12);
        assert!((dense[[1, 0]] - 2.0).abs() < 1e-12);
        assert!((dense[[1, 1]] - 4.0).abs() < 1e-12);
    }
}