use super::DenseTensor;
/// Truncation settings used when decomposing a dense tensor into TT format.
#[derive(Debug, Clone)]
pub struct TensorTrainConfig {
    /// Upper bound on every TT rank; `0` means "no cap" (ranks are then limited only by the
    /// size of each unfolding matrix).
    pub max_rank: usize,
    /// Absolute cutoff: singular values below this value stop further rank growth.
    pub tolerance: f64,
}

impl Default for TensorTrainConfig {
    /// Uncapped ranks with a `1e-12` singular-value cutoff.
    fn default() -> Self {
        let tolerance = 1e-12;
        let max_rank = 0;
        TensorTrainConfig {
            max_rank,
            tolerance,
        }
    }
}
/// One core of a tensor train: a 3-way array of shape `(rank_left, mode_size, rank_right)`.
///
/// Elements are stored row-major in `data`: element `(rl, i, rr)` lives at
/// `rl * mode_size * rank_right + i * rank_right + rr`.
#[derive(Debug, Clone)]
pub struct TTCore {
    /// Flat row-major storage of length `rank_left * mode_size * rank_right`.
    pub data: Vec<f64>,
    /// Left rank (size of the first axis).
    pub rank_left: usize,
    /// Mode size (size of the second axis).
    pub mode_size: usize,
    /// Right rank (size of the third axis).
    pub rank_right: usize,
}

impl TTCore {
    /// Wraps existing row-major data as a core.
    ///
    /// # Panics
    /// Panics if `data.len() != rank_left * mode_size * rank_right`.
    pub fn new(data: Vec<f64>, rank_left: usize, mode_size: usize, rank_right: usize) -> Self {
        assert_eq!(data.len(), rank_left * mode_size * rank_right);
        Self {
            data,
            rank_left,
            mode_size,
            rank_right,
        }
    }

    /// Creates an all-zero core of the given shape.
    pub fn zeros(rank_left: usize, mode_size: usize, rank_right: usize) -> Self {
        Self {
            data: vec![0.0; rank_left * mode_size * rank_right],
            rank_left,
            mode_size,
            rank_right,
        }
    }

    /// Flat offset of element `(rl, i, rr)` in `data`.
    ///
    /// Out-of-range components are caught in debug builds: without the check an index such as
    /// `i == mode_size` can still land inside `data` and silently alias a different element.
    fn offset(&self, rl: usize, i: usize, rr: usize) -> usize {
        debug_assert!(
            rl < self.rank_left && i < self.mode_size && rr < self.rank_right,
            "core index ({}, {}, {}) out of bounds for shape ({}, {}, {})",
            rl,
            i,
            rr,
            self.rank_left,
            self.mode_size,
            self.rank_right
        );
        (rl * self.mode_size + i) * self.rank_right + rr
    }

    /// Returns the `rank_left x rank_right` slice at mode index `i`, flattened row-major.
    ///
    /// # Panics
    /// Panics if `i >= mode_size`.
    pub fn get_matrix(&self, i: usize) -> Vec<f64> {
        assert!(
            i < self.mode_size,
            "mode index {} out of range for mode size {}",
            i,
            self.mode_size
        );
        let width = self.rank_right;
        // For fixed (rl, i) the `rank_right` entries are contiguous in `data`.
        (0..self.rank_left)
            .flat_map(|rl| {
                let start = (rl * self.mode_size + i) * width;
                self.data[start..start + width].iter().copied()
            })
            .collect()
    }

    /// Writes `value` at `(rl, i, rr)`.
    pub fn set(&mut self, rl: usize, i: usize, rr: usize, value: f64) {
        let idx = self.offset(rl, i, rr);
        self.data[idx] = value;
    }

    /// Reads the element at `(rl, i, rr)`.
    pub fn get(&self, rl: usize, i: usize, rr: usize) -> f64 {
        self.data[self.offset(rl, i, rr)]
    }
}
/// A tensor stored in tensor-train (TT) format.
///
/// Element `(i_1, ..., i_d)` is the matrix product `G_1[i_1] * G_2[i_2] * ... * G_d[i_d]`,
/// where `G_k[i]` is the `rank_left x rank_right` slice of `cores[k]` at mode index `i`
/// (see [`TensorTrain::eval`]). The outer boundary ranks are expected to be 1.
#[derive(Debug, Clone)]
pub struct TensorTrain {
    /// One core per mode; adjacent cores are expected to agree on their shared rank.
    pub cores: Vec<TTCore>,
    /// Mode sizes, `shape[k] == cores[k].mode_size`.
    pub shape: Vec<usize>,
    /// Bond ranks: `ranks[0] == 1` and `ranks[k + 1] == cores[k].rank_right`
    /// (as built by [`TensorTrain::from_cores`]).
    pub ranks: Vec<usize>,
}

impl TensorTrain {
    /// Assembles a train from its cores, deriving `shape` and `ranks` from them.
    pub fn from_cores(cores: Vec<TTCore>) -> Self {
        let shape: Vec<usize> = cores.iter().map(|c| c.mode_size).collect();
        // The left boundary rank is fixed at 1; every core then contributes its right rank.
        let ranks: Vec<usize> = std::iter::once(1)
            .chain(cores.iter().map(|c| c.rank_right))
            .collect();
        Self {
            cores,
            shape,
            ranks,
        }
    }

    /// Builds the rank-one train of the outer product of `vectors`
    /// (each vector becomes a `1 x n x 1` core).
    pub fn from_vectors(vectors: Vec<Vec<f64>>) -> Self {
        let cores = vectors
            .into_iter()
            .map(|v| {
                let n = v.len();
                TTCore::new(v, 1, n, 1)
            })
            .collect();
        Self::from_cores(cores)
    }

    /// Number of modes of the tensor.
    pub fn order(&self) -> usize {
        self.shape.len()
    }

    /// Largest bond rank (1 if `ranks` is empty).
    pub fn max_rank(&self) -> usize {
        self.ranks.iter().copied().max().unwrap_or(1)
    }

    /// Total number of stored `f64` values across all cores.
    pub fn storage(&self) -> usize {
        self.cores.iter().map(|c| c.data.len()).sum()
    }

    /// Evaluates a single element by multiplying the core slices left to right.
    ///
    /// # Panics
    /// Panics if `indices.len() != self.order()` or if any index is out of range for its mode.
    pub fn eval(&self, indices: &[usize]) -> f64 {
        assert_eq!(indices.len(), self.order());
        // Row vector carried through the product; starts as the 1x1 identity.
        let mut acc = vec![1.0];
        for (core, &idx) in self.cores.iter().zip(indices) {
            assert!(
                idx < core.mode_size,
                "index {} out of range for mode of size {}",
                idx,
                core.mode_size
            );
            let mut next = vec![0.0; core.rank_right];
            for (rl, &a) in acc.iter().enumerate() {
                for (rr, out) in next.iter_mut().enumerate() {
                    *out += a * core.get(rl, idx, rr);
                }
            }
            acc = next;
        }
        acc[0]
    }

    /// Materializes the full tensor by evaluating every element.
    ///
    /// Elements are generated in row-major order (last index varies fastest) and handed to
    /// `DenseTensor::new` in that order.
    pub fn to_dense(&self) -> DenseTensor {
        let total_size: usize = self.shape.iter().product();
        let mut data = Vec::with_capacity(total_size);
        let mut indices = vec![0usize; self.order()];
        for _ in 0..total_size {
            data.push(self.eval(&indices));
            // Odometer-style increment of the multi-index, last mode first.
            for (idx, &n) in indices.iter_mut().zip(&self.shape).rev() {
                *idx += 1;
                if *idx < n {
                    break;
                }
                *idx = 0;
            }
        }
        DenseTensor::new(data, self.shape.clone())
    }

    /// Inner product `<self, other>` (sum of element-wise products), computed core by core
    /// without materializing either tensor.
    ///
    /// # Panics
    /// Panics if the shapes differ.
    pub fn dot(&self, other: &TensorTrain) -> f64 {
        assert_eq!(self.shape, other.shape);
        // `z` is the (r1 x r2) contraction of both trains over the modes processed so far.
        let mut z = vec![1.0];
        let mut z_rows = 1;
        let mut z_cols = 1;
        for (c1, c2) in self.cores.iter().zip(&other.cores) {
            let new_cols = c2.rank_right;
            let mut new_z = vec![0.0; c1.rank_right * new_cols];
            for i in 0..c1.mode_size {
                for r1l in 0..z_rows {
                    for r2l in 0..z_cols {
                        let z_val = z[r1l * z_cols + r2l];
                        for r1r in 0..c1.rank_right {
                            for r2r in 0..c2.rank_right {
                                new_z[r1r * new_cols + r2r] +=
                                    z_val * c1.get(r1l, i, r1r) * c2.get(r2l, i, r2r);
                            }
                        }
                    }
                }
            }
            z = new_z;
            z_rows = c1.rank_right;
            z_cols = new_cols;
        }
        z[0]
    }

    /// Frobenius norm `sqrt(<self, self>)`.
    ///
    /// Rounding can push `<self, self>` slightly below zero for (near-)zero tensors such as
    /// `x + (-x)`; that is clamped to 0 instead of producing NaN. A NaN input still yields NaN.
    pub fn frobenius_norm(&self) -> f64 {
        let sq = self.dot(self);
        if sq < 0.0 {
            0.0
        } else {
            sq.sqrt()
        }
    }

    /// Returns a train representing `self + other`.
    ///
    /// Interior cores are stacked block-diagonally, so interior ranks become the sums of the
    /// operands' ranks (no recompression). The first core is concatenated along its right
    /// rank and the last along its left rank; for an order-1 train the two cores are summed.
    ///
    /// # Panics
    /// Panics if the shapes differ.
    pub fn add(&self, other: &TensorTrain) -> TensorTrain {
        assert_eq!(self.shape, other.shape);
        let d = self.order();
        let mut new_cores = Vec::with_capacity(d);
        for (k, (c1, c2)) in self.cores.iter().zip(&other.cores).enumerate() {
            let first = k == 0;
            let last = k + 1 == d;
            let new_rl = if first { 1 } else { c1.rank_left + c2.rank_left };
            let new_rr = if last { 1 } else { c1.rank_right + c2.rank_right };
            // Where `other`'s block starts; on a boundary side both blocks share index 0,
            // which is why values are accumulated rather than overwritten.
            let row_off = if first { 0 } else { c1.rank_left };
            let col_off = if last { 0 } else { c1.rank_right };
            let n = c1.mode_size;
            let mut new_core = TTCore::zeros(new_rl, n, new_rr);
            for (src, r0, c0) in [(c1, 0, 0), (c2, row_off, col_off)] {
                // Boundary sides only carry their single (index 0) slice.
                let src_rows = if first { 1 } else { src.rank_left };
                let src_cols = if last { 1 } else { src.rank_right };
                for rl in 0..src_rows {
                    for i in 0..n {
                        for rr in 0..src_cols {
                            let value = new_core.get(r0 + rl, i, c0 + rr) + src.get(rl, i, rr);
                            new_core.set(r0 + rl, i, c0 + rr, value);
                        }
                    }
                }
            }
            new_cores.push(new_core);
        }
        TensorTrain::from_cores(new_cores)
    }

    /// Returns `alpha * self`, scaling only the first core.
    ///
    /// # Panics
    /// Panics if the train has no cores.
    pub fn scale(&self, alpha: f64) -> TensorTrain {
        let mut new_cores = self.cores.clone();
        let first = new_cores
            .first_mut()
            .expect("cannot scale a tensor train with no cores");
        first.data.iter_mut().for_each(|x| *x *= alpha);
        TensorTrain::from_cores(new_cores)
    }

    /// Decomposes a dense tensor into TT format with a left-to-right TT-SVD sweep.
    ///
    /// Step `k` views the remaining data as a `(left_rank * n_k) x rest` row-major matrix,
    /// factors it with `simple_svd` (truncated according to `config`), stores the left factor
    /// as core `k` and carries `diag(s) * V^T` forward. The final remainder is the last core.
    /// If no singular value survives truncation, a single all-zero component is kept so that
    /// ranks never drop to 0. A zero-order tensor yields a train with no cores.
    ///
    /// NOTE(review): assumes `tensor.data` is row-major (last index fastest), matching the
    /// order produced by [`TensorTrain::to_dense`] — confirm against `DenseTensor`.
    pub fn from_dense(tensor: &DenseTensor, config: &TensorTrainConfig) -> Self {
        let d = tensor.order();
        if d == 0 {
            return TensorTrain::from_cores(vec![]);
        }
        let shape = &tensor.shape;
        let mut cores = Vec::with_capacity(d);
        let mut c = tensor.data.clone();
        let mut left_rank = 1usize;
        for k in 0..d - 1 {
            let n_k = shape[k];
            let rows = left_rank * n_k;
            let cols: usize = shape[k + 1..].iter().product();
            let (u, s, vt, _) = simple_svd(&c, rows, cols, config);
            let kept = s.len();
            let new_rank = kept.max(1);
            // `simple_svd` stores each left singular vector contiguously (`u[r * rows + row]`),
            // whereas a core expects row-major `rows x new_rank` data: transpose while copying.
            let mut core_data = vec![0.0; rows * new_rank];
            for r in 0..kept {
                for row in 0..rows {
                    core_data[row * new_rank + r] = u[r * rows + row];
                }
            }
            cores.push(TTCore::new(core_data, left_rank, n_k, new_rank));
            // Next remainder: diag(s) * V^T, a row-major `new_rank x cols` matrix.
            let mut next = vec![0.0; new_rank * cols];
            for r in 0..kept {
                for j in 0..cols {
                    next[r * cols + j] = s[r] * vt[r * cols + j];
                }
            }
            c = next;
            left_rank = new_rank;
        }
        cores.push(TTCore::new(c, left_rank, shape[d - 1], 1));
        TensorTrain::from_cores(cores)
    }
}
/// Truncated SVD of a row-major `rows x cols` matrix via power iteration with deflation.
///
/// Returns `(u, s, vt, rank)`:
/// * `s` — the extracted singular values, in the order they were extracted;
/// * `u` — the matching left singular vectors stored one after another (`u[r * rows + i]`);
/// * `vt` — the matching right singular vectors stored one after another (`vt[r * cols + j]`);
/// * `rank` — always equal to `s.len()` and at least 1.
///
/// At most `config.max_rank` components are extracted (up to `min(rows, cols)` when it is 0),
/// each with 20 power-iteration rounds, and extraction stops as soon as a singular value falls
/// below `config.tolerance`. If nothing is extracted (e.g. a zero matrix or an empty
/// dimension), one all-zero component is returned so `rank` matches the vector lengths.
fn simple_svd(
    a: &[f64],
    rows: usize,
    cols: usize,
    config: &TensorTrainConfig,
) -> (Vec<f64>, Vec<f64>, Vec<f64>, usize) {
    let limit = match config.max_rank {
        0 => rows.min(cols),
        cap => cap.min(rows).min(cols),
    };
    let mut u = Vec::with_capacity(limit * rows);
    let mut s = Vec::with_capacity(limit);
    let mut vt = Vec::with_capacity(limit * cols);
    let mut residual = a.to_vec();
    for _ in 0..limit {
        let (sigma, u_vec, v_vec) = power_iteration(&residual, rows, cols, 20);
        if sigma < config.tolerance {
            break;
        }
        // Deflate: subtract the rank-one component just found.
        for (i, &ui) in u_vec.iter().enumerate() {
            let row = &mut residual[i * cols..(i + 1) * cols];
            for (x, &vj) in row.iter_mut().zip(&v_vec) {
                *x -= sigma * ui * vj;
            }
        }
        s.push(sigma);
        u.extend_from_slice(&u_vec);
        vt.extend_from_slice(&v_vec);
    }
    if s.is_empty() {
        // Previously rank 1 was reported with empty vectors, crashing callers that index them.
        s.push(0.0);
        u.resize(rows, 0.0);
        vt.resize(cols, 0.0);
    }
    let rank = s.len();
    (u, s, vt, rank)
}
/// Estimates the dominant singular triplet of a row-major `rows x cols` matrix `a`.
///
/// Runs `max_iter` rounds of alternating power iteration (`u <- A v`, then `v <- A^T u`, each
/// normalized) from a fixed, deterministic start vector and returns `(|u . A v|, u, v)`.
/// If `A v` vanishes (for example when `a` is all zeros) the vectors stay zero and the
/// returned value is 0.
fn power_iteration(
    a: &[f64],
    rows: usize,
    cols: usize,
    max_iter: usize,
) -> (f64, Vec<f64>, Vec<f64>) {
    // Deterministic start vector 2 * (i * 2654435769 / 2^32) - 1 (2654435769 / 2^32 ~ 0.618).
    // Evaluated in f64 so nothing can overflow: as a `usize` product, `i * 2654435769`
    // exceeds `u32::MAX` already at i = 2 on 32-bit targets.
    let mut v: Vec<f64> = (0..cols)
        .map(|i| (i as f64 * 2654435769.0 / 4294967296.0) * 2.0 - 1.0)
        .collect();
    normalize(&mut v);
    let mut u = vec![0.0; rows];
    for _ in 0..max_iter {
        // u <- normalize(A v)
        for (i, ui) in u.iter_mut().enumerate() {
            let row = &a[i * cols..(i + 1) * cols];
            *ui = row.iter().zip(&v).fold(0.0, |acc, (aij, vj)| acc + aij * vj);
        }
        normalize(&mut u);
        // v <- normalize(A^T u)
        for (j, vj) in v.iter_mut().enumerate() {
            *vj = u
                .iter()
                .enumerate()
                .fold(0.0, |acc, (i, ui)| acc + a[i * cols + j] * ui);
        }
        normalize(&mut v);
    }
    // sigma = u . (A v)
    let sigma: f64 = u
        .iter()
        .enumerate()
        .map(|(i, ui)| {
            let row = &a[i * cols..(i + 1) * cols];
            let av_i = row.iter().zip(&v).fold(0.0, |acc, (aij, vj)| acc + aij * vj);
            ui * av_i
        })
        .sum();
    (sigma.abs(), u, v)
}

/// Scales `v` to unit Euclidean length in place; vectors with norm <= 1e-15 are left as-is.
fn normalize(v: &mut [f64]) {
    let norm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
    if norm > 1e-15 {
        v.iter_mut().for_each(|x| *x /= norm);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `actual` is strictly within `tol` of `expected`.
    fn close(actual: f64, expected: f64, tol: f64) -> bool {
        (actual - expected).abs() < tol
    }

    /// Rank-one train of the outer product of `[1, 2]` and `[3, 4]`.
    fn outer_1234() -> TensorTrain {
        TensorTrain::from_vectors(vec![vec![1.0, 2.0], vec![3.0, 4.0]])
    }

    #[test]
    fn test_tt_eval() {
        let tt = outer_1234();
        let expected = [([0, 0], 3.0), ([0, 1], 4.0), ([1, 0], 6.0), ([1, 1], 8.0)];
        for (idx, want) in expected {
            assert!(close(tt.eval(&idx), want, 1e-10));
        }
    }

    #[test]
    fn test_tt_dot() {
        // 3^2 + 4^2 + 6^2 + 8^2 = 125
        let tt = outer_1234();
        assert!(close(tt.dot(&tt), 125.0, 1e-10));
    }

    #[test]
    fn test_tt_from_dense() {
        let tensor = DenseTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let tt = TensorTrain::from_dense(&tensor, &TensorTrainConfig::default());
        let rebuilt = tt.to_dense();
        let squared_error: f64 = tensor
            .data
            .iter()
            .zip(rebuilt.data.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum();
        assert!(squared_error.sqrt() < 1e-6);
    }

    #[test]
    fn test_tt_add() {
        let sum = outer_1234().add(&outer_1234());
        assert!(close(sum.eval(&[0, 0]), 6.0, 1e-10));
        assert!(close(sum.eval(&[1, 1]), 16.0, 1e-10));
    }
}