use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{Array, Array2, Array3, ArrayD, ArrayView, Dimension};
use scirs2_core::numeric::{Float, NumAssign, Zero};
use std::fmt::Debug;
use std::iter::Sum;
/// Tensor-train (TT) representation of a multi-dimensional tensor.
///
/// The tensor is stored as a chain of third-order cores, where core `i`
/// has shape `(ranks[i], shape[i], ranks[i + 1])` and the boundary ranks
/// `ranks[0]` and `ranks[shape.len()]` are both 1 (enforced by `new`).
#[derive(Debug, Clone)]
pub struct TensorTrain<A>
where
A: Clone + Float + Debug,
{
/// TT cores; core `i` has shape `(ranks[i], shape[i], ranks[i + 1])`.
pub cores: Vec<Array3<A>>,
/// TT ranks; length is `shape.len() + 1`, with unit boundary ranks.
pub ranks: Vec<usize>,
/// Mode sizes of the full tensor this train represents.
pub shape: Vec<usize>,
}
impl<A> TensorTrain<A>
where
A: Clone
+ Float
+ NumAssign
+ Zero
+ Debug
+ Sum
+ 'static
+ scirs2_core::ndarray::ScalarOperand
+ Send
+ Sync,
{
pub fn new(cores: Vec<Array3<A>>, ranks: Vec<usize>, shape: Vec<usize>) -> LinalgResult<Self> {
if cores.is_empty() {
return Err(LinalgError::ValueError(
"Cores list cannot be empty".to_string(),
));
}
if ranks.len() != shape.len() + 1 {
return Err(LinalgError::ShapeError(format!(
"Ranks vector length ({}) must be 1 more than shape length ({})",
ranks.len(),
shape.len()
)));
}
if ranks[0] != 1 || ranks[ranks.len() - 1] != 1 {
return Err(LinalgError::ValueError(
"First and last TT-ranks must be 1".to_string(),
));
}
if cores.len() != shape.len() {
return Err(LinalgError::ShapeError(format!(
"Number of _cores ({}) must match the shape length ({})",
cores.len(),
shape.len()
)));
}
for (i, core) in cores.iter().enumerate() {
if core.ndim() != 3 {
return Err(LinalgError::ShapeError(format!(
"Core {} must be 3-dimensional, got {} dimensions",
i,
core.ndim()
)));
}
if core.shape()[0] != ranks[i] {
return Err(LinalgError::ShapeError(format!(
"Core {} has first dimension {}, expected rank {}",
i,
core.shape()[0],
ranks[i]
)));
}
if core.shape()[1] != shape[i] {
return Err(LinalgError::ShapeError(format!(
"Core {} has second dimension {}, expected shape dimension {}",
i,
core.shape()[1],
shape[i]
)));
}
if core.shape()[2] != ranks[i + 1] {
return Err(LinalgError::ShapeError(format!(
"Core {} has third dimension {}, expected rank {}",
i,
core.shape()[2],
ranks[i + 1]
)));
}
}
Ok(TensorTrain {
cores,
ranks,
shape,
})
}
pub fn to_full(&self) -> LinalgResult<ArrayD<A>> {
let n_dims = self.shape.len();
let mut result =
ArrayD::zeros(scirs2_core::ndarray::IxDyn(&[self.shape[0], self.ranks[1]]));
for i in 0..self.shape[0] {
for j in 0..self.ranks[1] {
result[[i, j]] = self.cores[0][[0, i, j]];
}
}
for i in 1..n_dims {
let core = &self.cores[i];
let currentshape = result.shape().to_vec();
let current_rank = currentshape[currentshape.len() - 1];
let result_flat = result
.into_shape_with_order((
currentshape[..currentshape.len() - 1].iter().product(),
current_rank,
))
.map_err(|e| LinalgError::ComputationError(format!("Reshape error: {}", e)))?;
let mut newshape = currentshape[..currentshape.len() - 1].to_vec();
newshape.push(self.shape[i]);
newshape.push(self.ranks[i + 1]);
let mut result_contracted = Array::zeros(scirs2_core::ndarray::IxDyn(&newshape));
match newshape.len() - 2 {
0 => {
for k in 0..self.shape[i] {
for l in 0..self.ranks[i + 1] {
let mut sum = A::zero();
for r in 0..current_rank {
sum += result_flat[[0, r]] * core[[r, k, l]];
}
result_contracted[[k, l]] = sum;
}
}
}
1 => {
for idx1 in 0..newshape[0] {
for k in 0..self.shape[i] {
for l in 0..self.ranks[i + 1] {
let mut sum = A::zero();
for r in 0..current_rank {
sum += result_flat[[idx1, r]] * core[[r, k, l]];
}
result_contracted[[idx1, k, l]] = sum;
}
}
}
}
2 => {
for idx1 in 0..newshape[0] {
for idx2 in 0..newshape[1] {
let flat_idx = idx1 * newshape[1] + idx2;
for k in 0..self.shape[i] {
for l in 0..self.ranks[i + 1] {
let mut sum = A::zero();
for r in 0..current_rank {
sum += result_flat[[flat_idx, r]] * core[[r, k, l]];
}
result_contracted[[idx1, idx2, k, l]] = sum;
}
}
}
}
}
3 => {
for idx1 in 0..newshape[0] {
for idx2 in 0..newshape[1] {
for idx3 in 0..newshape[2] {
let stride2 = newshape[2];
let flat_idx = idx1 * newshape[1] * stride2 + idx2 * stride2 + idx3;
for k in 0..self.shape[i] {
for l in 0..self.ranks[i + 1] {
let mut sum = A::zero();
for r in 0..current_rank {
sum += result_flat[[flat_idx, r]] * core[[r, k, l]];
}
result_contracted[[idx1, idx2, idx3, k, l]] = sum;
}
}
}
}
}
}
_ => {
fn visit_indices(
shape: &[usize],
current_indices: &mut Vec<usize>,
depth: usize,
callback: &mut dyn FnMut(&[usize]),
) {
if depth == shape.len() {
callback(current_indices);
return;
}
for i in 0..shape[depth] {
current_indices.push(i);
visit_indices(shape, current_indices, depth + 1, callback);
current_indices.pop();
}
}
fn flat_index(indices: &[usize], shape: &[usize]) -> usize {
let mut idx = 0;
let mut stride = 1;
for i in (0..indices.len()).rev() {
idx += indices[i] * stride;
if i > 0 {
stride *= shape[i];
}
}
idx
}
let free_dims = newshape[..newshape.len() - 2].to_vec();
let mut indices = Vec::new();
let mut callback = |idx: &[usize]| {
let flat_idx = flat_index(idx, &free_dims);
for k in 0..self.shape[i] {
for l in 0..self.ranks[i + 1] {
let mut sum = A::zero();
for r in 0..current_rank {
sum += result_flat[[flat_idx, r]] * core[[r, k, l]];
}
let mut idx_full = idx.to_vec();
idx_full.push(k);
idx_full.push(l);
result_contracted[scirs2_core::ndarray::IxDyn(&idx_full)] = sum;
}
}
};
visit_indices(&free_dims, &mut indices, 0, &mut callback);
}
}
result = result_contracted;
}
let mut final_result = ArrayD::zeros(scirs2_core::ndarray::IxDyn(self.shape.as_slice()));
if self.ranks[n_dims] == 1 {
fn set_values<A>(
result: &ArrayD<A>,
final_result: &mut ArrayD<A>,
current_idx: &mut Vec<usize>,
shape: &[usize],
depth: usize,
) where
A: Clone,
{
if depth == shape.len() {
let mut source_idx = current_idx.clone();
source_idx.push(0); final_result[current_idx.as_slice()] = result[source_idx.as_slice()].clone();
return;
}
for i in 0..shape[depth] {
current_idx.push(i);
set_values(result, final_result, current_idx, shape, depth + 1);
current_idx.pop();
}
}
let mut current_idx = Vec::new();
set_values(
&result,
&mut final_result,
&mut current_idx,
self.shape.as_slice(),
0,
);
} else {
return Err(LinalgError::ComputationError(
"Last rank dimension must be 1 for a valid tensor train".to_string(),
));
}
Ok(final_result)
}
pub fn get(&self, indices: &[usize]) -> LinalgResult<A> {
if indices.len() != self.shape.len() {
return Err(LinalgError::ShapeError(format!(
"Index length ({}) must match tensor dimensionality ({})",
indices.len(),
self.shape.len()
)));
}
for (i, (&idx, &dim)) in indices.iter().zip(self.shape.iter()).enumerate() {
if idx >= dim {
return Err(LinalgError::ShapeError(format!(
"Index {} for dimension {} out of bounds ({})",
idx, i, dim
)));
}
}
let mut result = self.cores[0]
.slice(scirs2_core::ndarray::s![0, indices[0], ..])
.to_owned();
for i in 1..self.shape.len() {
let core_slice = self.cores[i].slice(scirs2_core::ndarray::s![.., indices[i], ..]);
let mut new_result = Array::zeros((1, core_slice.shape()[1]));
for j in 0..core_slice.shape()[1] {
let mut sum = A::zero();
for k in 0..result.len() {
sum += result[k] * core_slice[[k, j]];
}
new_result[[0, j]] = sum;
}
let shape1 = new_result.shape()[1];
result = new_result
.into_shape_with_order(shape1)
.map_err(|e| LinalgError::ComputationError(format!("Reshape error: {}", e)))?;
}
if result.len() != 1 {
return Err(LinalgError::ComputationError(
"Final result should be a scalar".to_string(),
));
}
Ok(result[0])
}
pub fn round(&self, epsilon: A) -> LinalgResult<Self> {
if epsilon <= A::zero() {
return Err(LinalgError::ValueError(
"Epsilon must be positive".to_string(),
));
}
let mut cores = self.cores.clone();
let mut ranks = self.ranks.clone();
for i in 0..cores.len() - 1 {
let core = &cores[i];
let (r1, n, r2) = (core.shape()[0], core.shape()[1], core.shape()[2]);
let core_mat = core
.clone()
.into_shape_with_order((r1 * n, r2))
.map_err(|e| LinalgError::ComputationError(format!("Reshape error: {}", e)))?;
let (q, r) = qr_decomposition(&core_mat)?;
let qshape1 = q.shape()[1];
cores[i] = q
.into_shape_with_order((r1, n, qshape1))
.map_err(|e| LinalgError::ComputationError(format!("Reshape error: {}", e)))?;
let next_core = &cores[i + 1];
let (next_r1, next_n, next_r2) = (
next_core.shape()[0],
next_core.shape()[1],
next_core.shape()[2],
);
let next_core_mat = next_core
.clone()
.into_shape_with_order((next_r1, next_n * next_r2))
.map_err(|e| LinalgError::ComputationError(format!("Reshape error: {}", e)))?;
let updated_next_core = r.dot(&next_core_mat);
cores[i + 1] = updated_next_core
.into_shape_with_order((r.shape()[0], next_n, next_r2))
.map_err(|e| LinalgError::ComputationError(format!("Reshape error: {}", e)))?;
}
for i in (1..cores.len()).rev() {
let core = &cores[i];
let (r1, n, r2) = (core.shape()[0], core.shape()[1], core.shape()[2]);
let core_mat = core
.clone()
.into_shape_with_order((r1, n * r2))
.map_err(|e| LinalgError::ComputationError(format!("Reshape error: {}", e)))?;
let (u, s, vt) = svd_with_truncation(&core_mat, epsilon)?;
ranks[i] = u.shape()[1];
cores[i] = vt
.into_shape_with_order((u.shape()[1], n, r2))
.map_err(|e| LinalgError::ComputationError(format!("Reshape error: {}", e)))?;
let prev_core = &cores[i - 1];
let (prev_r1, prev_n, prev_r2) = (
prev_core.shape()[0],
prev_core.shape()[1],
prev_core.shape()[2],
);
let u_s = Array2::from_diag(&s).dot(&u.t());
let prev_core_mat = prev_core
.clone()
.into_shape_with_order((prev_r1 * prev_n, prev_r2))
.map_err(|e| LinalgError::ComputationError(format!("Reshape error: {}", e)))?;
let updated_prev_core = prev_core_mat.dot(&u_s);
cores[i - 1] = updated_prev_core
.into_shape_with_order((prev_r1, prev_n, u.shape()[1]))
.map_err(|e| LinalgError::ComputationError(format!("Reshape error: {}", e)))?;
}
TensorTrain::new(cores, ranks, self.shape.clone())
}
}
/// Computes a tensor-train decomposition of a dense tensor via the TT-SVD
/// algorithm: one mode is peeled off at a time by reshaping the remainder
/// into a matrix and taking a (possibly truncated) SVD.
///
/// `max_rank` caps every TT-rank; `epsilon` truncates singular values below
/// a relative threshold; both may be combined or omitted.
///
/// # Errors
/// `ComputationError` on reshape or SVD failures; shape-validation errors
/// are propagated from `TensorTrain::new`.
#[allow(dead_code)]
pub fn tensor_train_decomposition<A, D>(
tensor: &ArrayView<A, D>,
max_rank: Option<usize>,
epsilon: Option<A>,
) -> LinalgResult<TensorTrain<A>>
where
A: Clone
+ Float
+ NumAssign
+ Zero
+ Debug
+ Sum
+ 'static
+ scirs2_core::ndarray::ScalarOperand
+ Send
+ Sync,
D: Dimension,
{
let tensor_dyn = tensor.to_owned().into_dyn();
let shape = tensor.shape().to_vec();
let ndim = shape.len();
// Boundary ranks are fixed at 1; interior ranks are filled in per step.
let mut ranks = vec![1];
ranks.resize(ndim + 1, 1);
let mut cores = Vec::with_capacity(ndim);
let mut curr_tensor = tensor_dyn;
// Peel off one mode per iteration; the last core is what remains.
for k in 0..ndim - 1 {
let rows = ranks[k] * shape[k];
let cols: usize = shape.iter().skip(k + 1).product();
let tensor_mat = curr_tensor
.clone()
.into_shape_with_order((rows, cols))
.map_err(|e| LinalgError::ComputationError(format!("Reshape error: {}", e)))?;
// Select the SVD variant matching the requested truncation options.
let (u, s, vt) = match (max_rank, epsilon) {
(Some(max_r), Some(eps)) => {
let (u, s, vt) = svd_with_truncation_and_max_rank(&tensor_mat, eps, max_r)?;
ranks[k + 1] = u.shape()[1];
(u, s, vt)
}
(Some(max_r), None) => {
let (u, s, vt) = svd_with_max_rank(&tensor_mat, max_r)?;
ranks[k + 1] = u.shape()[1];
(u, s, vt)
}
(None, Some(eps)) => {
let (u, s, vt) = svd_with_truncation(&tensor_mat, eps)?;
ranks[k + 1] = u.shape()[1];
(u, s, vt)
}
(None, None) => {
let (u, s, vt) = svd(&tensor_mat)?;
ranks[k + 1] = s.len();
(u, s, vt)
}
};
// U becomes the k-th core; S * Vt carries the remainder forward.
let core = u
.into_shape_with_order((ranks[k], shape[k], ranks[k + 1]))
.map_err(|e| LinalgError::ComputationError(format!("Reshape error: {}", e)))?;
cores.push(core);
let s_vt = Array2::from_diag(&s).dot(&vt);
curr_tensor = s_vt
.into_shape_with_order(
std::iter::once(ranks[k + 1])
.chain(shape.iter().skip(k + 1).copied())
.collect::<Vec<_>>(),
)
.map_err(|e| LinalgError::ComputationError(format!("Reshape error: {}", e)))?;
}
// The remaining tensor is the final core (trailing rank is 1).
let last_core = curr_tensor
.into_shape_with_order((ranks[ndim - 1], shape[ndim - 1], ranks[ndim]))
.map_err(|e| LinalgError::ComputationError(format!("Reshape error: {}", e)))?;
cores.push(last_core);
TensorTrain::new(cores, ranks, shape)
}
/// Thin wrapper around the crate's SVD routine, returning `(U, S, Vt)`.
#[allow(dead_code)]
fn svd<A>(matrix: &Array2<A>) -> LinalgResult<(Array2<A>, Array1<A>, Array2<A>)>
where
    A: Clone
        + Float
        + NumAssign
        + Zero
        + Debug
        + Sum
        + 'static
        + scirs2_core::ndarray::ScalarOperand
        + Send
        + Sync,
{
    use crate::decomposition::svd as svd_decomp;
    // Thin (non-full) SVD with default options.
    svd_decomp(&matrix.view(), false, None)
}
/// SVD with relative truncation: singular values are normalized by the
/// largest one and components whose relative magnitude falls below
/// `epsilon` are dropped. At least one component is always retained.
#[allow(dead_code)]
fn svd_with_truncation<A>(
    matrix: &Array2<A>,
    epsilon: A,
) -> LinalgResult<(Array2<A>, Array1<A>, Array2<A>)>
where
    A: Clone
        + Float
        + NumAssign
        + Zero
        + Debug
        + Sum
        + 'static
        + scirs2_core::ndarray::ScalarOperand
        + Send
        + Sync,
{
    let (u, s, vt) = svd(matrix)?;
    // Normalize so the threshold is relative to the leading singular value.
    let s_norm = if !s.is_empty() && s[0] > A::zero() {
        s.mapv(|x| x / s[0])
    } else {
        s.clone()
    };
    // Keep everything up to the first value below the threshold; if none
    // falls below it, keep the full spectrum. Always keep at least one.
    let rank = s_norm
        .iter()
        .position(|&v| v < epsilon)
        .unwrap_or(s_norm.len())
        .max(1);
    let u_trunc = u.slice(scirs2_core::ndarray::s![.., ..rank]).to_owned();
    let s_trunc = s.slice(scirs2_core::ndarray::s![..rank]).to_owned();
    let vt_trunc = vt.slice(scirs2_core::ndarray::s![..rank, ..]).to_owned();
    Ok((u_trunc, s_trunc, vt_trunc))
}
/// SVD truncated to at most `max_rank` leading components.
#[allow(dead_code)]
fn svd_with_max_rank<A>(
    matrix: &Array2<A>,
    max_rank: usize,
) -> LinalgResult<(Array2<A>, Array1<A>, Array2<A>)>
where
    A: Clone
        + Float
        + NumAssign
        + Zero
        + Debug
        + Sum
        + 'static
        + scirs2_core::ndarray::ScalarOperand
        + Send
        + Sync,
{
    let (u, s, vt) = svd(matrix)?;
    // Never keep more components than the spectrum actually has.
    let keep = max_rank.min(s.len());
    Ok((
        u.slice(scirs2_core::ndarray::s![.., ..keep]).to_owned(),
        s.slice(scirs2_core::ndarray::s![..keep]).to_owned(),
        vt.slice(scirs2_core::ndarray::s![..keep, ..]).to_owned(),
    ))
}
/// SVD truncated both by a relative threshold `epsilon` (applied to
/// singular values normalized by the largest) and by a hard cap `max_rank`.
/// At least one component is retained before the cap is applied.
#[allow(dead_code)]
fn svd_with_truncation_and_max_rank<A>(
    matrix: &Array2<A>,
    epsilon: A,
    max_rank: usize,
) -> LinalgResult<(Array2<A>, Array1<A>, Array2<A>)>
where
    A: Clone
        + Float
        + NumAssign
        + Zero
        + Debug
        + Sum
        + 'static
        + scirs2_core::ndarray::ScalarOperand
        + Send
        + Sync,
{
    let (u, s, vt) = svd(matrix)?;
    // Scale by the leading singular value so `epsilon` acts relatively.
    let s_norm = if !s.is_empty() && s[0] > A::zero() {
        s.mapv(|x| x / s[0])
    } else {
        s.clone()
    };
    // First index below the threshold (or the whole spectrum), clamped to
    // at least 1, then capped by `max_rank`.
    let rank = s_norm
        .iter()
        .position(|&v| v < epsilon)
        .unwrap_or(s_norm.len())
        .max(1)
        .min(max_rank);
    let u_trunc = u.slice(scirs2_core::ndarray::s![.., ..rank]).to_owned();
    let s_trunc = s.slice(scirs2_core::ndarray::s![..rank]).to_owned();
    let vt_trunc = vt.slice(scirs2_core::ndarray::s![..rank, ..]).to_owned();
    Ok((u_trunc, s_trunc, vt_trunc))
}
/// Thin wrapper around the crate's QR decomposition, returning `(Q, R)`.
#[allow(dead_code)]
fn qr_decomposition<A>(matrix: &Array2<A>) -> LinalgResult<(Array2<A>, Array2<A>)>
where
    A: Clone
        + Float
        + NumAssign
        + Zero
        + Debug
        + Sum
        + 'static
        + scirs2_core::ndarray::ScalarOperand
        + Send
        + Sync,
{
    use crate::decomposition::qr;
    // Default mode; any factorization error is propagated unchanged.
    qr(&matrix.view(), None)
}
/// One-dimensional array alias used by the SVD helpers in this module.
pub type Array1<A> = Array<A, scirs2_core::ndarray::Ix1>;
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
use scirs2_core::ndarray::array;
// Decompose a small 2x3x2 tensor without truncation and verify core
// shapes, boundary ranks, and exact reconstruction via `to_full`.
#[test]
fn test_tensor_train_decomposition_3d() {
let tensor = array![
[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
[[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
];
let tt = tensor_train_decomposition(&tensor.view(), None, None).expect("Operation failed");
assert_eq!(tt.shape, vec![2, 3, 2]);
assert_eq!(tt.ranks.len(), 4); assert_eq!(tt.ranks[0], 1); assert_eq!(tt.ranks[3], 1);
// Each core must agree with the stored ranks and mode sizes.
for i in 0..tt.cores.len() {
let core = &tt.cores[i];
assert_eq!(core.shape()[0], tt.ranks[i]);
assert_eq!(core.shape()[1], tt.shape[i]);
assert_eq!(core.shape()[2], tt.ranks[i + 1]);
}
let reconstructed = tt.to_full().expect("Operation failed");
for i in 0..2 {
for j in 0..3 {
for k in 0..2 {
assert_abs_diff_eq!(
reconstructed[[i, j, k]],
tensor[[i, j, k]],
epsilon = 1e-10
);
}
}
}
}
// Rank-capped decomposition of a separable (outer-product) tensor; the
// true TT-ranks are all 1, so reconstruction should be near-exact even
// with max_rank = 2.
#[test]
#[ignore = "SVD fails for small matrices due to unimplemented eigendecomposition"]
fn test_tensor_train_decomposition_with_truncation() {
let mut tensor = ArrayD::<f64>::zeros(scirs2_core::ndarray::IxDyn(&[4, 3, 2, 2]));
for i in 0..4 {
for j in 0..3 {
for k in 0..2 {
for l in 0..2 {
tensor[[i, j, k, l]] =
(i + 1) as f64 * (j + 1) as f64 * (k + 1) as f64 * (l + 1) as f64;
}
}
}
}
let tt =
tensor_train_decomposition(&tensor.view(), Some(2), None).expect("Operation failed");
for &r in &tt.ranks {
assert!(r <= 2);
}
let reconstructed = tt.to_full().expect("Operation failed");
for i in 0..4 {
for j in 0..3 {
for k in 0..2 {
for l in 0..2 {
assert_abs_diff_eq!(
reconstructed[[i, j, k, l]],
tensor[[i, j, k, l]],
epsilon = 1e-6
);
}
}
}
}
}
// `get` must reproduce every element of the dense tensor without
// materializing the full reconstruction.
#[test]
fn test_get_tensor_element() {
let tensor = array![
[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
[[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
];
let tt = tensor_train_decomposition(&tensor.view(), None, None).expect("Operation failed");
for i in 0..2 {
for j in 0..3 {
for k in 0..2 {
let value = tt.get(&[i, j, k]).expect("Operation failed");
assert_abs_diff_eq!(value, tensor[[i, j, k]], epsilon = 1e-10);
}
}
}
}
// Rounding at several tolerances: the relative reconstruction error
// (max abs error over the Frobenius norm) must stay within epsilon.
#[test]
#[ignore = "SVD fails for small matrices due to unimplemented eigendecomposition"]
fn test_round_tensor_train() {
let mut tensor = ArrayD::<f64>::zeros(scirs2_core::ndarray::IxDyn(&[3, 4, 3, 2]));
for i in 0..3 {
for j in 0..4 {
for k in 0..3 {
for l in 0..2 {
tensor[[i, j, k, l]] =
(i + 1) as f64 * (j + 1) as f64 * (k + 1) as f64 * (l + 1) as f64;
}
}
}
}
let tt = tensor_train_decomposition(&tensor.view(), None, None).expect("Operation failed");
for epsilon in [1e-8, 1e-4, 1e-2].iter() {
let rounded_tt = tt.round(*epsilon).expect("Operation failed");
let reconstructed = rounded_tt.to_full().expect("Operation failed");
let mut max_error = 0.0;
let norm = tensor.mapv(|x| x * x).sum().sqrt();
for i in 0..3 {
for j in 0..4 {
for k in 0..3 {
for l in 0..2 {
let error = (reconstructed[[i, j, k, l]] - tensor[[i, j, k, l]]).abs();
max_error = max_error.max(error);
}
}
}
}
let relative_error = max_error / norm;
assert!(relative_error <= *epsilon || relative_error <= 1e-10);
}
}
}