Skip to main content

oxicuda_solver/dense/
tensor_decomp.rs

1//! Tensor decomposition algorithms: CP, Tucker, and Tensor-Train (TT).
2//!
3//! This module provides CPU-side implementations of the three most widely used
4//! tensor decomposition methods, which can be used to generate PTX kernels for
5//! GPU execution.
6//!
7//! ## Decompositions
8//!
9//! - **CP (CANDECOMP/PARAFAC)** — decomposes a tensor into a sum of rank-one
10//!   tensors via Alternating Least Squares (ALS).
11//! - **Tucker** — decomposes a tensor into a core tensor multiplied by factor
12//!   matrices along each mode, via HOSVD or HOOI.
13//! - **Tensor-Train (TT)** — decomposes a tensor into a chain of 3D cores
14//!   via sequential SVDs.
15//!
16//! ## References
17//!
18//! - Kolda & Bader, "Tensor Decompositions and Applications", SIAM Review 2009
19//! - Oseledets, "Tensor-Train Decomposition", SIAM J. Sci. Comput. 2011
20
21use crate::error::{SolverError, SolverResult};
22
23// ---------------------------------------------------------------------------
24// Matrix (minimal linear algebra support)
25// ---------------------------------------------------------------------------
26
/// Dense two-dimensional matrix of `f64` values stored in row-major order.
///
/// Element `(r, c)` lives at `data[r * cols + c]`. [`Matrix::new`] enforces
/// `data.len() == rows * cols`; because the fields are public, code that builds
/// a `Matrix` literal directly must uphold that invariant itself.
#[derive(Clone, Debug)]
pub struct Matrix {
    /// Row count.
    pub rows: usize,
    /// Column count.
    pub cols: usize,
    /// Element storage, one full row after another.
    pub data: Vec<f64>,
}
37
38impl Matrix {
39    /// Creates a new matrix from shape and data.
40    pub fn new(rows: usize, cols: usize, data: Vec<f64>) -> SolverResult<Self> {
41        if data.len() != rows * cols {
42            return Err(SolverError::DimensionMismatch(format!(
43                "matrix {}x{} requires {} elements, got {}",
44                rows,
45                cols,
46                rows * cols,
47                data.len()
48            )));
49        }
50        Ok(Self { rows, cols, data })
51    }
52
53    /// Creates a zero matrix.
54    pub fn zeros(rows: usize, cols: usize) -> Self {
55        Self {
56            rows,
57            cols,
58            data: vec![0.0; rows * cols],
59        }
60    }
61
62    /// Creates an identity matrix.
63    pub fn eye(n: usize) -> Self {
64        let mut data = vec![0.0; n * n];
65        for i in 0..n {
66            data[i * n + i] = 1.0;
67        }
68        Self {
69            rows: n,
70            cols: n,
71            data,
72        }
73    }
74
75    /// Element access.
76    #[inline]
77    pub fn get(&self, r: usize, c: usize) -> f64 {
78        self.data[r * self.cols + c]
79    }
80
81    /// Mutable element access.
82    #[inline]
83    pub fn set(&mut self, r: usize, c: usize, v: f64) {
84        self.data[r * self.cols + c] = v;
85    }
86
87    /// Returns the transpose.
88    pub fn transpose(&self) -> Self {
89        let mut out = Self::zeros(self.cols, self.rows);
90        for r in 0..self.rows {
91            for c in 0..self.cols {
92                out.set(c, r, self.get(r, c));
93            }
94        }
95        out
96    }
97
98    /// Matrix multiplication: self * other.
99    pub fn matmul(&self, other: &Matrix) -> SolverResult<Matrix> {
100        if self.cols != other.rows {
101            return Err(SolverError::DimensionMismatch(format!(
102                "matmul: {}x{} * {}x{}",
103                self.rows, self.cols, other.rows, other.cols
104            )));
105        }
106        let mut out = Matrix::zeros(self.rows, other.cols);
107        for i in 0..self.rows {
108            for k in 0..self.cols {
109                let a_ik = self.get(i, k);
110                for j in 0..other.cols {
111                    let cur = out.get(i, j);
112                    out.set(i, j, cur + a_ik * other.get(k, j));
113                }
114            }
115        }
116        Ok(out)
117    }
118
119    /// Frobenius norm.
120    pub fn frobenius_norm(&self) -> f64 {
121        self.data.iter().map(|x| x * x).sum::<f64>().sqrt()
122    }
123
124    /// Column norms (L2 norm of each column).
125    pub fn column_norms(&self) -> Vec<f64> {
126        let mut norms = vec![0.0; self.cols];
127        for r in 0..self.rows {
128            for (c, norm) in norms.iter_mut().enumerate() {
129                let v = self.get(r, c);
130                *norm += v * v;
131            }
132        }
133        norms.iter().map(|s| s.sqrt()).collect()
134    }
135
136    /// Normalize columns in-place, returning the norms.
137    pub fn normalize_columns(&mut self) -> Vec<f64> {
138        let norms = self.column_norms();
139        for (c, &norm) in norms.iter().enumerate() {
140            if norm > 1e-15 {
141                for r in 0..self.rows {
142                    let v = self.get(r, c) / norm;
143                    self.set(r, c, v);
144                }
145            }
146        }
147        norms
148    }
149
150    /// Extract a column as a vector.
151    pub fn column(&self, c: usize) -> Vec<f64> {
152        (0..self.rows).map(|r| self.get(r, c)).collect()
153    }
154
155    /// Truncated SVD via deflated power iteration.
156    ///
157    /// Computes the top `rank` singular triplets one at a time, deflating the
158    /// matrix after each. Returns (U, sigma, V) where U is m×k, sigma has k
159    /// entries, V is n×k.
160    pub fn svd_truncated(&self, rank: usize) -> SolverResult<(Matrix, Vec<f64>, Matrix)> {
161        let m = self.rows;
162        let n = self.cols;
163        let k = rank.min(m).min(n);
164        if k == 0 {
165            return Ok((Matrix::zeros(m, 0), Vec::new(), Matrix::zeros(n, 0)));
166        }
167
168        let mut u_mat = Matrix::zeros(m, k);
169        let mut v_mat = Matrix::zeros(n, k);
170        let mut sigma = vec![0.0; k];
171
172        // Work on a deflated copy of the matrix
173        let mut deflated = self.clone();
174
175        for (s, sigma_s) in sigma.iter_mut().enumerate().take(k) {
176            // Initialize v with a deterministic vector
177            let mut v: Vec<f64> = (0..n)
178                .map(|i| ((i + 1) as f64 * (s + 1) as f64 * 0.7 + 0.3).sin())
179                .collect();
180            // Normalize v
181            let mut vnorm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
182            if vnorm > 1e-15 {
183                for x in &mut v {
184                    *x /= vnorm;
185                }
186            }
187
188            let max_iters = 200;
189            for _ in 0..max_iters {
190                // u = A * v
191                let mut u: Vec<f64> = (0..m)
192                    .map(|i| {
193                        v.iter()
194                            .enumerate()
195                            .map(|(j, &vj)| deflated.get(i, j) * vj)
196                            .sum()
197                    })
198                    .collect();
199
200                // sigma_new = ||u||
201                let sigma_new: f64 = u.iter().map(|x| x * x).sum::<f64>().sqrt();
202                if sigma_new < 1e-15 {
203                    break;
204                }
205                // Normalize u
206                for x in &mut u {
207                    *x /= sigma_new;
208                }
209
210                // v_new = A^T * u
211                let mut v_new: Vec<f64> = (0..n)
212                    .map(|j| {
213                        u.iter()
214                            .enumerate()
215                            .map(|(i, &ui)| deflated.get(i, j) * ui)
216                            .sum()
217                    })
218                    .collect();
219
220                // Normalize v_new
221                vnorm = v_new.iter().map(|x| x * x).sum::<f64>().sqrt();
222                if vnorm < 1e-15 {
223                    break;
224                }
225                for x in &mut v_new {
226                    *x /= vnorm;
227                }
228
229                // Check convergence: |sigma_new - prev| / sigma_new
230                let diff: f64 = v
231                    .iter()
232                    .zip(v_new.iter())
233                    .map(|(a, b)| (a - b) * (a - b))
234                    .sum::<f64>()
235                    .sqrt();
236                v = v_new;
237
238                if diff < 1e-12 {
239                    break;
240                }
241            }
242
243            // Final pass: compute sigma from u = A*v
244            let mut u: Vec<f64> = (0..m)
245                .map(|i| {
246                    v.iter()
247                        .enumerate()
248                        .map(|(j, &vj)| deflated.get(i, j) * vj)
249                        .sum()
250                })
251                .collect();
252            let sv = u.iter().map(|x| x * x).sum::<f64>().sqrt();
253            if sv > 1e-15 {
254                for x in &mut u {
255                    *x /= sv;
256                }
257            }
258
259            *sigma_s = sv;
260            for (i, &ui) in u.iter().enumerate() {
261                u_mat.set(i, s, ui);
262            }
263            for (j, &vj) in v.iter().enumerate() {
264                v_mat.set(j, s, vj);
265            }
266
267            // Deflate: A <- A - sigma * u * v^T
268            for (i, &ui) in u.iter().enumerate() {
269                for (j, &vj) in v.iter().enumerate() {
270                    let old = deflated.get(i, j);
271                    deflated.set(i, j, old - sv * ui * vj);
272                }
273            }
274        }
275
276        Ok((u_mat, sigma, v_mat))
277    }
278}
279
280/// Modified Gram-Schmidt QR decomposition.
281/// Returns (Q, R) where Q is m×k orthonormal, R is k×k upper triangular.
282fn qr_gram_schmidt(a: &Matrix) -> (Matrix, Matrix) {
283    let m = a.rows;
284    let k = a.cols;
285    let mut q = a.clone();
286    let mut r = Matrix::zeros(k, k);
287
288    for j in 0..k {
289        // Orthogonalize against previous columns
290        for i in 0..j {
291            let mut dot = 0.0;
292            for row in 0..m {
293                dot += q.get(row, i) * q.get(row, j);
294            }
295            r.set(i, j, dot);
296            for row in 0..m {
297                let v = q.get(row, j) - dot * q.get(row, i);
298                q.set(row, j, v);
299            }
300        }
301        // Normalize
302        let mut norm = 0.0;
303        for row in 0..m {
304            let v = q.get(row, j);
305            norm += v * v;
306        }
307        norm = norm.sqrt();
308        r.set(j, j, norm);
309        if norm > 1e-15 {
310            for row in 0..m {
311                let v = q.get(row, j) / norm;
312                q.set(row, j, v);
313            }
314        }
315    }
316
317    (q, r)
318}
319
320// ---------------------------------------------------------------------------
321// Tensor
322// ---------------------------------------------------------------------------
323
/// Dense N-dimensional array of `f64` values.
///
/// Elements are stored contiguously in row-major (C) order, so the last mode
/// varies fastest.
#[derive(Clone, Debug)]
pub struct Tensor {
    /// Size along each mode; its length is the tensor's order.
    shape: Vec<usize>,
    /// Contiguous element buffer in row-major order.
    data: Vec<f64>,
}
332
333impl Tensor {
334    /// Creates a new tensor from shape and data.
335    pub fn new(shape: Vec<usize>, data: Vec<f64>) -> SolverResult<Self> {
336        let numel: usize = shape.iter().product();
337        if data.len() != numel {
338            return Err(SolverError::DimensionMismatch(format!(
339                "tensor with shape {:?} requires {} elements, got {}",
340                shape,
341                numel,
342                data.len()
343            )));
344        }
345        if shape.is_empty() {
346            return Err(SolverError::DimensionMismatch(
347                "tensor must have at least one dimension".to_string(),
348            ));
349        }
350        Ok(Self { shape, data })
351    }
352
353    /// Creates a zero tensor with the given shape.
354    pub fn zeros(shape: Vec<usize>) -> Self {
355        let numel: usize = shape.iter().product();
356        Self {
357            shape,
358            data: vec![0.0; numel],
359        }
360    }
361
362    /// Returns the number of dimensions (modes).
363    pub fn ndim(&self) -> usize {
364        self.shape.len()
365    }
366
367    /// Returns the shape.
368    pub fn shape(&self) -> &[usize] {
369        &self.shape
370    }
371
372    /// Returns the total number of elements.
373    pub fn numel(&self) -> usize {
374        self.data.len()
375    }
376
377    /// Returns a reference to the data.
378    pub fn data(&self) -> &[f64] {
379        &self.data
380    }
381
382    /// Computes the linear index from multi-dimensional indices (row-major).
383    fn linear_index(&self, indices: &[usize]) -> SolverResult<usize> {
384        if indices.len() != self.shape.len() {
385            return Err(SolverError::DimensionMismatch(format!(
386                "expected {} indices, got {}",
387                self.shape.len(),
388                indices.len()
389            )));
390        }
391        let mut idx = 0;
392        let mut stride = 1;
393        for d in (0..self.shape.len()).rev() {
394            if indices[d] >= self.shape[d] {
395                return Err(SolverError::DimensionMismatch(format!(
396                    "index {} out of range for dimension {} with size {}",
397                    indices[d], d, self.shape[d]
398                )));
399            }
400            idx += indices[d] * stride;
401            stride *= self.shape[d];
402        }
403        Ok(idx)
404    }
405
406    /// Gets an element by multi-dimensional index.
407    pub fn get(&self, indices: &[usize]) -> SolverResult<f64> {
408        let idx = self.linear_index(indices)?;
409        Ok(self.data[idx])
410    }
411
412    /// Sets an element by multi-dimensional index.
413    pub fn set(&mut self, indices: &[usize], value: f64) -> SolverResult<()> {
414        let idx = self.linear_index(indices)?;
415        self.data[idx] = value;
416        Ok(())
417    }
418
419    /// Mode-n unfolding (matricization).
420    ///
421    /// Rearranges the tensor into a matrix where mode `n` becomes the row
422    /// dimension and all other modes are collapsed into the column dimension.
423    pub fn unfold(&self, mode: usize) -> SolverResult<Matrix> {
424        if mode >= self.ndim() {
425            return Err(SolverError::DimensionMismatch(format!(
426                "mode {} out of range for {}-dimensional tensor",
427                mode,
428                self.ndim()
429            )));
430        }
431
432        let rows = self.shape[mode];
433        let cols = self.numel() / rows;
434        let mut mat = Matrix::zeros(rows, cols);
435
436        let ndim = self.ndim();
437        let mut indices = vec![0usize; ndim];
438        for flat in 0..self.numel() {
439            // Compute multi-index from flat index
440            let mut rem = flat;
441            for d in (0..ndim).rev() {
442                indices[d] = rem % self.shape[d];
443                rem /= self.shape[d];
444            }
445
446            // Row = index along unfolded mode
447            let row = indices[mode];
448
449            // Column = multi-index of remaining modes, in order
450            let mut col = 0;
451            let mut col_stride = 1;
452            for d in (0..ndim).rev() {
453                if d != mode {
454                    col += indices[d] * col_stride;
455                    col_stride *= self.shape[d];
456                }
457            }
458
459            mat.set(row, col, self.data[flat]);
460        }
461
462        Ok(mat)
463    }
464
465    /// Folds a matrix back into a tensor (inverse of unfold).
466    pub fn fold(matrix: &Matrix, mode: usize, shape: &[usize]) -> SolverResult<Tensor> {
467        let ndim = shape.len();
468        if mode >= ndim {
469            return Err(SolverError::DimensionMismatch(format!(
470                "mode {} out of range for {}-dimensional tensor",
471                mode, ndim
472            )));
473        }
474        if matrix.rows != shape[mode] {
475            return Err(SolverError::DimensionMismatch(format!(
476                "matrix rows {} != shape[{}] = {}",
477                matrix.rows, mode, shape[mode]
478            )));
479        }
480
481        let numel: usize = shape.iter().product();
482        let mut data = vec![0.0; numel];
483
484        let mut indices = vec![0usize; ndim];
485        for (flat, datum) in data.iter_mut().enumerate() {
486            let mut rem = flat;
487            for d in (0..ndim).rev() {
488                indices[d] = rem % shape[d];
489                rem /= shape[d];
490            }
491
492            let row = indices[mode];
493            let mut col = 0;
494            let mut col_stride = 1;
495            for d in (0..ndim).rev() {
496                if d != mode {
497                    col += indices[d] * col_stride;
498                    col_stride *= shape[d];
499                }
500            }
501
502            *datum = matrix.get(row, col);
503        }
504
505        Ok(Tensor {
506            shape: shape.to_vec(),
507            data,
508        })
509    }
510
511    /// Frobenius norm of the tensor.
512    pub fn frobenius_norm(&self) -> f64 {
513        self.data.iter().map(|x| x * x).sum::<f64>().sqrt()
514    }
515}
516
517// ---------------------------------------------------------------------------
518// Utility functions
519// ---------------------------------------------------------------------------
520
521/// Khatri-Rao product (column-wise Kronecker product).
522///
523/// Given A (I×R) and B (J×R), produces C ((I*J)×R) where
524/// `c_{(i-1)*J+j, r} = a_{i,r} * b_{j,r}`.
525pub fn khatri_rao_product(a: &Matrix, b: &Matrix) -> SolverResult<Matrix> {
526    if a.cols != b.cols {
527        return Err(SolverError::DimensionMismatch(format!(
528            "khatri-rao requires same number of columns: {} vs {}",
529            a.cols, b.cols
530        )));
531    }
532    let r = a.cols;
533    let rows = a.rows * b.rows;
534    let mut out = Matrix::zeros(rows, r);
535    for col in 0..r {
536        for i in 0..a.rows {
537            for j in 0..b.rows {
538                out.set(i * b.rows + j, col, a.get(i, col) * b.get(j, col));
539            }
540        }
541    }
542    Ok(out)
543}
544
545/// Element-wise (Hadamard) product of two matrices.
546pub fn hadamard_product(a: &Matrix, b: &Matrix) -> SolverResult<Matrix> {
547    if a.rows != b.rows || a.cols != b.cols {
548        return Err(SolverError::DimensionMismatch(format!(
549            "hadamard requires same dimensions: {}x{} vs {}x{}",
550            a.rows, a.cols, b.rows, b.cols
551        )));
552    }
553    let data: Vec<f64> = a
554        .data
555        .iter()
556        .zip(b.data.iter())
557        .map(|(x, y)| x * y)
558        .collect();
559    Matrix::new(a.rows, a.cols, data)
560}
561
562/// Mode-n product: multiplies a tensor by a matrix along mode n.
563///
564/// If tensor has shape `[I_0, ..., I_n, ..., I_{N-1}]` and matrix is `J × I_n`,
565/// the result has shape `[I_0, ..., J, ..., I_{N-1}]`.
566pub fn mode_n_product(tensor: &Tensor, matrix: &Matrix, mode: usize) -> SolverResult<Tensor> {
567    if mode >= tensor.ndim() {
568        return Err(SolverError::DimensionMismatch(format!(
569            "mode {} out of range for {}-dimensional tensor",
570            mode,
571            tensor.ndim()
572        )));
573    }
574    if matrix.cols != tensor.shape()[mode] {
575        return Err(SolverError::DimensionMismatch(format!(
576            "matrix cols {} != tensor dimension {} size {}",
577            matrix.cols,
578            mode,
579            tensor.shape()[mode]
580        )));
581    }
582
583    let unfolded = tensor.unfold(mode)?;
584    let result_mat = matrix.matmul(&unfolded)?;
585
586    let mut new_shape = tensor.shape().to_vec();
587    new_shape[mode] = matrix.rows;
588
589    Tensor::fold(&result_mat, mode, &new_shape)
590}
591
592// ---------------------------------------------------------------------------
593// CP Decomposition
594// ---------------------------------------------------------------------------
595
596/// CP (CANDECOMP/PARAFAC) decomposition result.
597///
598/// Represents a tensor as a weighted sum of rank-one tensors:
599/// `X ≈ Σ_r λ_r · a_r^(1) ◦ a_r^(2) ◦ ... ◦ a_r^(N)`
600#[derive(Debug, Clone)]
601pub struct CpDecomposition {
602    /// Weights (lambda) for each rank-one component.
603    pub weights: Vec<f64>,
604    /// Factor matrices, one per mode. Factor\[n\] is `I_n × R`.
605    pub factors: Vec<Matrix>,
606}
607
608impl CpDecomposition {
609    /// Returns the rank of the decomposition.
610    pub fn rank(&self) -> usize {
611        self.weights.len()
612    }
613
614    /// Reconstructs the full tensor from the CP factors.
615    pub fn reconstruct(&self) -> SolverResult<Tensor> {
616        if self.factors.is_empty() {
617            return Err(SolverError::InternalError(
618                "CP decomposition has no factors".to_string(),
619            ));
620        }
621
622        let shape: Vec<usize> = self.factors.iter().map(|f| f.rows).collect();
623        let numel: usize = shape.iter().product();
624        let ndim = shape.len();
625        let rank = self.rank();
626        let mut data = vec![0.0; numel];
627
628        let mut indices = vec![0usize; ndim];
629        for (flat, datum) in data.iter_mut().enumerate() {
630            let mut rem = flat;
631            for d in (0..ndim).rev() {
632                indices[d] = rem % shape[d];
633                rem /= shape[d];
634            }
635
636            let mut val = 0.0;
637            for r in 0..rank {
638                let mut term = self.weights[r];
639                for (d, idx) in indices.iter().enumerate() {
640                    term *= self.factors[d].get(*idx, r);
641                }
642                val += term;
643            }
644            *datum = val;
645        }
646
647        Tensor::new(shape, data)
648    }
649
650    /// Computes the relative fit error: `||X - X_hat|| / ||X||`.
651    pub fn fit_error(&self, original: &Tensor) -> SolverResult<f64> {
652        let reconstructed = self.reconstruct()?;
653        let orig_norm = original.frobenius_norm();
654        if orig_norm < 1e-15 {
655            return Ok(0.0);
656        }
657        let diff_norm: f64 = original
658            .data()
659            .iter()
660            .zip(reconstructed.data().iter())
661            .map(|(a, b)| (a - b) * (a - b))
662            .sum::<f64>()
663            .sqrt();
664        Ok(diff_norm / orig_norm)
665    }
666}
667
/// Tuning parameters for [`cp_als`].
#[derive(Clone, Debug)]
pub struct CpAlsConfig {
    /// Number of rank-one components to fit.
    pub rank: usize,
    /// Upper bound on the number of full ALS sweeps.
    pub max_iterations: usize,
    /// Stop once the relative fit error changes by less than this between sweeps.
    pub tolerance: f64,
    /// After each sweep, rescale factor columns to unit norm and move the scale
    /// into the weights.
    pub normalize_factors: bool,
}
680
681impl Default for CpAlsConfig {
682    fn default() -> Self {
683        Self {
684            rank: 3,
685            max_iterations: 100,
686            tolerance: 1e-8,
687            normalize_factors: true,
688        }
689    }
690}
691
692/// CP decomposition via Alternating Least Squares (ALS).
693///
694/// For each mode n, we fix all other factor matrices and solve the least
695/// squares problem for factor matrix A^(n).
696pub fn cp_als(tensor: &Tensor, config: &CpAlsConfig) -> SolverResult<CpDecomposition> {
697    let ndim = tensor.ndim();
698    let rank = config.rank;
699
700    if rank == 0 {
701        return Err(SolverError::InternalError(
702            "CP rank must be positive".to_string(),
703        ));
704    }
705
706    // Initialize factor matrices deterministically
707    let mut factors: Vec<Matrix> = Vec::with_capacity(ndim);
708    for n in 0..ndim {
709        let rows = tensor.shape()[n];
710        let mut data = vec![0.0; rows * rank];
711        for i in 0..rows {
712            for r in 0..rank {
713                data[i * rank + r] = ((i + 1) as f64 * (r + 1) as f64 * 0.37).sin().abs() + 0.01;
714            }
715        }
716        factors.push(Matrix::new(rows, rank, data)?);
717    }
718
719    let mut weights = vec![1.0; rank];
720    let mut prev_fit = f64::MAX;
721
722    for _iter in 0..config.max_iterations {
723        for n in 0..ndim {
724            // Build the list of modes excluding n, in the order that
725            // matches the column layout of the mode-n unfolding.
726            // The unfolding column index cycles through modes != n
727            // in the order they appear (skipping n), from last to first.
728            let other_modes: Vec<usize> = (0..ndim).filter(|&m| m != n).collect();
729
730            // Khatri-Rao product of factors for other modes (in column order)
731            let mut kr = factors[other_modes[0]].clone();
732            for &m in &other_modes[1..] {
733                kr = khatri_rao_product(&kr, &factors[m])?;
734            }
735
736            // Gram matrix = Hadamard of (F_m^T * F_m) for all m != n
737            let mut gram = {
738                let ft = factors[other_modes[0]].transpose();
739                ft.matmul(&factors[other_modes[0]])?
740            };
741            for &m in &other_modes[1..] {
742                let ft = factors[m].transpose();
743                let g = ft.matmul(&factors[m])?;
744                gram = hadamard_product(&gram, &g)?;
745            }
746
747            // Unfold tensor along mode n
748            let x_n = tensor.unfold(n)?;
749
750            // V = X_(n) * KR
751            let v = x_n.matmul(&kr)?;
752
753            // Solve A^(n) = V * gram^{-1} via normal equations
754            let gram_inv = invert_small_matrix(&gram)?;
755            factors[n] = v.matmul(&gram_inv)?;
756        }
757
758        // Normalize factors
759        if config.normalize_factors {
760            for w in weights.iter_mut() {
761                *w = 1.0;
762            }
763            for factor in factors.iter_mut() {
764                let norms = factor.normalize_columns();
765                for (w, &norm) in weights.iter_mut().zip(norms.iter()) {
766                    *w *= norm;
767                }
768            }
769        }
770
771        // Check convergence
772        let decomp = CpDecomposition {
773            weights: weights.clone(),
774            factors: factors.clone(),
775        };
776        let fit = decomp.fit_error(tensor).unwrap_or(f64::MAX);
777        if (prev_fit - fit).abs() < config.tolerance {
778            return Ok(decomp);
779        }
780        prev_fit = fit;
781    }
782
783    Ok(CpDecomposition { weights, factors })
784}
785
786/// Invert a small R×R matrix via Gauss-Jordan elimination.
787fn invert_small_matrix(m: &Matrix) -> SolverResult<Matrix> {
788    if m.rows != m.cols {
789        return Err(SolverError::DimensionMismatch(
790            "matrix must be square to invert".to_string(),
791        ));
792    }
793    let n = m.rows;
794    // Augmented matrix [A | I]
795    let mut aug = Matrix::zeros(n, 2 * n);
796    for r in 0..n {
797        for c in 0..n {
798            aug.set(r, c, m.get(r, c));
799        }
800        aug.set(r, n + r, 1.0);
801    }
802
803    for col in 0..n {
804        // Find pivot
805        let mut max_val = aug.get(col, col).abs();
806        let mut max_row = col;
807        for r in (col + 1)..n {
808            let v = aug.get(r, col).abs();
809            if v > max_val {
810                max_val = v;
811                max_row = r;
812            }
813        }
814        if max_val < 1e-14 {
815            return Err(SolverError::SingularMatrix);
816        }
817
818        // Swap rows
819        if max_row != col {
820            for c in 0..(2 * n) {
821                let tmp = aug.get(col, c);
822                aug.set(col, c, aug.get(max_row, c));
823                aug.set(max_row, c, tmp);
824            }
825        }
826
827        // Scale pivot row
828        let pivot = aug.get(col, col);
829        for c in 0..(2 * n) {
830            aug.set(col, c, aug.get(col, c) / pivot);
831        }
832
833        // Eliminate other rows
834        for r in 0..n {
835            if r == col {
836                continue;
837            }
838            let factor = aug.get(r, col);
839            for c in 0..(2 * n) {
840                let v = aug.get(r, c) - factor * aug.get(col, c);
841                aug.set(r, c, v);
842            }
843        }
844    }
845
846    // Extract inverse
847    let mut inv = Matrix::zeros(n, n);
848    for r in 0..n {
849        for c in 0..n {
850            inv.set(r, c, aug.get(r, n + c));
851        }
852    }
853    Ok(inv)
854}
855
856// ---------------------------------------------------------------------------
857// Tucker Decomposition
858// ---------------------------------------------------------------------------
859
860/// Tucker decomposition result.
861///
862/// Represents a tensor as a core tensor multiplied by factor matrices:
863/// `X ≈ G ×_1 U^(1) ×_2 U^(2) ... ×_N U^(N)`
864#[derive(Debug, Clone)]
865pub struct TuckerDecomposition {
866    /// Core tensor with shape `[R_1, R_2, ..., R_N]`.
867    pub core: Tensor,
868    /// Factor matrices, one per mode. Factor\[n\] is `I_n × R_n`.
869    pub factors: Vec<Matrix>,
870}
871
872impl TuckerDecomposition {
873    /// Reconstructs the full tensor from the Tucker factors.
874    pub fn reconstruct(&self) -> SolverResult<Tensor> {
875        let mut result = self.core.clone();
876        for (n, factor) in self.factors.iter().enumerate() {
877            result = mode_n_product(&result, factor, n)?;
878        }
879        Ok(result)
880    }
881
882    /// Computes the relative fit error: `||X - X_hat|| / ||X||`.
883    pub fn fit_error(&self, original: &Tensor) -> SolverResult<f64> {
884        let reconstructed = self.reconstruct()?;
885        let orig_norm = original.frobenius_norm();
886        if orig_norm < 1e-15 {
887            return Ok(0.0);
888        }
889        let diff_norm: f64 = original
890            .data()
891            .iter()
892            .zip(reconstructed.data().iter())
893            .map(|(a, b)| (a - b) * (a - b))
894            .sum::<f64>()
895            .sqrt();
896        Ok(diff_norm / orig_norm)
897    }
898
899    /// Compression ratio: original elements / decomposition elements.
900    pub fn compression_ratio(&self, original_shape: &[usize]) -> f64 {
901        let original_size: usize = original_shape.iter().product();
902        let core_size = self.core.numel();
903        let factor_size: usize = self.factors.iter().map(|f| f.rows * f.cols).sum();
904        let decomp_size = core_size + factor_size;
905        if decomp_size == 0 {
906            return 0.0;
907        }
908        original_size as f64 / decomp_size as f64
909    }
910}
911
/// Tuning parameters for the Tucker decompositions (HOSVD and HOOI).
#[derive(Clone, Debug)]
pub struct TuckerConfig {
    /// Requested rank for each mode; its length must equal the tensor's order.
    pub ranks: Vec<usize>,
    /// Upper bound on HOOI iterations (HOSVD does not read it).
    pub max_iterations: usize,
    /// Convergence tolerance for HOOI.
    pub tolerance: f64,
}
922
923impl Default for TuckerConfig {
924    fn default() -> Self {
925        Self {
926            ranks: vec![2, 2, 2],
927            max_iterations: 50,
928            tolerance: 1e-8,
929        }
930    }
931}
932
933/// Higher-Order SVD (HOSVD) for Tucker decomposition.
934///
935/// Non-iterative method: for each mode, compute the truncated SVD of the
936/// mode-n unfolding and use the left singular vectors as the factor matrix.
937pub fn tucker_hosvd(tensor: &Tensor, config: &TuckerConfig) -> SolverResult<TuckerDecomposition> {
938    let ndim = tensor.ndim();
939    if config.ranks.len() != ndim {
940        return Err(SolverError::DimensionMismatch(format!(
941            "Tucker config requires {} ranks (one per mode), got {}",
942            ndim,
943            config.ranks.len()
944        )));
945    }
946
947    // Compute factor matrices via mode-n SVD
948    let mut factors: Vec<Matrix> = Vec::with_capacity(ndim);
949    for n in 0..ndim {
950        let unfolded = tensor.unfold(n)?;
951        let rank_n = config.ranks[n].min(tensor.shape()[n]);
952        let (u, _sigma, _v) = unfolded.svd_truncated(rank_n)?;
953        factors.push(u);
954    }
955
956    // Core tensor = X ×_1 U^(1)^T ×_2 U^(2)^T ... ×_N U^(N)^T
957    let mut core = tensor.clone();
958    for (n, factor) in factors.iter().enumerate() {
959        let ft = factor.transpose();
960        core = mode_n_product(&core, &ft, n)?;
961    }
962
963    Ok(TuckerDecomposition { core, factors })
964}
965
966/// Higher-Order Orthogonal Iteration (HOOI) for Tucker decomposition.
967///
968/// Iterative refinement of HOSVD. Alternately optimizes each factor matrix
969/// while holding others fixed.
970pub fn tucker_hooi(tensor: &Tensor, config: &TuckerConfig) -> SolverResult<TuckerDecomposition> {
971    let ndim = tensor.ndim();
972    if config.ranks.len() != ndim {
973        return Err(SolverError::DimensionMismatch(format!(
974            "Tucker config requires {} ranks (one per mode), got {}",
975            ndim,
976            config.ranks.len()
977        )));
978    }
979
980    // Initialize with HOSVD
981    let mut decomp = tucker_hosvd(tensor, config)?;
982    let mut prev_core_norm = decomp.core.frobenius_norm();
983
984    for _iter in 0..config.max_iterations {
985        for n in 0..ndim {
986            // Compute Y = X ×_{m≠n} U^(m)^T
987            let mut y = tensor.clone();
988            for (m, factor) in decomp.factors.iter().enumerate() {
989                if m != n {
990                    let ft = factor.transpose();
991                    y = mode_n_product(&y, &ft, m)?;
992                }
993            }
994
995            // SVD of mode-n unfolding of Y
996            let y_n = y.unfold(n)?;
997            let rank_n = config.ranks[n].min(tensor.shape()[n]);
998            let (u, _sigma, _v) = y_n.svd_truncated(rank_n)?;
999            decomp.factors[n] = u;
1000        }
1001
1002        // Recompute core
1003        let mut core = tensor.clone();
1004        for (n, factor) in decomp.factors.iter().enumerate() {
1005            let ft = factor.transpose();
1006            core = mode_n_product(&core, &ft, n)?;
1007        }
1008        decomp.core = core;
1009
1010        // Check convergence
1011        let core_norm = decomp.core.frobenius_norm();
1012        if (core_norm - prev_core_norm).abs() / (prev_core_norm + 1e-15) < config.tolerance {
1013            break;
1014        }
1015        prev_core_norm = core_norm;
1016    }
1017
1018    Ok(decomp)
1019}
1020
1021// ---------------------------------------------------------------------------
1022// Tensor-Train (TT) Decomposition
1023// ---------------------------------------------------------------------------
1024
1025/// Tensor-Train (TT) decomposition result.
1026///
1027/// Represents a tensor as a chain of 3D cores:
1028/// `X(i_1, ..., i_N) = G_1(i_1) · G_2(i_2) · ... · G_N(i_N)`
1029/// where each `G_k(i_k)` is an `r_{k-1} × r_k` matrix slice, and
1030/// core k has shape `r_{k-1} × n_k × r_k`.
1031#[derive(Debug, Clone)]
1032pub struct TtDecomposition {
1033    /// TT-cores, each a 3D tensor of shape `[r_{k-1}, n_k, r_k]`.
1034    pub cores: Vec<Tensor>,
1035}
1036
1037impl TtDecomposition {
1038    /// Returns the TT-ranks `[r_0, r_1, ..., r_N]` where r_0 = r_N = 1.
1039    pub fn ranks(&self) -> Vec<usize> {
1040        let mut ranks = Vec::with_capacity(self.cores.len() + 1);
1041        if self.cores.is_empty() {
1042            return ranks;
1043        }
1044        ranks.push(self.cores[0].shape()[0]);
1045        for core in &self.cores {
1046            ranks.push(core.shape()[2]);
1047        }
1048        ranks
1049    }
1050
1051    /// Reconstructs the full tensor from TT-cores.
1052    pub fn reconstruct(&self) -> SolverResult<Tensor> {
1053        if self.cores.is_empty() {
1054            return Err(SolverError::InternalError(
1055                "TT decomposition has no cores".to_string(),
1056            ));
1057        }
1058
1059        let shape: Vec<usize> = self.cores.iter().map(|c| c.shape()[1]).collect();
1060        let ndim = shape.len();
1061        let numel: usize = shape.iter().product();
1062        let mut data = vec![0.0; numel];
1063
1064        let mut indices = vec![0usize; ndim];
1065        for (flat, datum) in data.iter_mut().enumerate() {
1066            let mut rem = flat;
1067            for d in (0..ndim).rev() {
1068                indices[d] = rem % shape[d];
1069                rem /= shape[d];
1070            }
1071
1072            // Compute the matrix product G_1(i_1) * G_2(i_2) * ... * G_N(i_N)
1073            // Start with G_1(i_1): shape [r_0, r_1] = [1, r_1]
1074            let core0 = &self.cores[0];
1075            let r1 = core0.shape()[2];
1076            let mut current: Vec<f64> = (0..r1)
1077                .map(|j| core0.get(&[0, indices[0], j]))
1078                .collect::<SolverResult<_>>()?;
1079
1080            for (k, &idx_k) in indices.iter().enumerate().skip(1) {
1081                let core_k = &self.cores[k];
1082                let r_next = core_k.shape()[2];
1083                let mut next = vec![0.0; r_next];
1084                for (j, nj) in next.iter_mut().enumerate() {
1085                    let mut sum = 0.0;
1086                    for (i, &ci) in current.iter().enumerate() {
1087                        sum += ci * core_k.get(&[i, idx_k, j])?;
1088                    }
1089                    *nj = sum;
1090                }
1091                current = next;
1092            }
1093
1094            *datum = current[0]; // r_N = 1, so scalar result
1095        }
1096
1097        Tensor::new(shape, data)
1098    }
1099
1100    /// Computes the relative fit error: `||X - X_hat|| / ||X||`.
1101    pub fn fit_error(&self, original: &Tensor) -> SolverResult<f64> {
1102        let reconstructed = self.reconstruct()?;
1103        let orig_norm = original.frobenius_norm();
1104        if orig_norm < 1e-15 {
1105            return Ok(0.0);
1106        }
1107        let diff_norm: f64 = original
1108            .data()
1109            .iter()
1110            .zip(reconstructed.data().iter())
1111            .map(|(a, b)| (a - b) * (a - b))
1112            .sum::<f64>()
1113            .sqrt();
1114        Ok(diff_norm / orig_norm)
1115    }
1116
1117    /// Compression ratio: original elements / decomposition elements.
1118    pub fn compression_ratio(&self, original_shape: &[usize]) -> f64 {
1119        let original_size: usize = original_shape.iter().product();
1120        let decomp_size: usize = self.cores.iter().map(|c| c.numel()).sum();
1121        if decomp_size == 0 {
1122            return 0.0;
1123        }
1124        original_size as f64 / decomp_size as f64
1125    }
1126
1127    /// TT-rounding: truncates TT-ranks to at most `max_rank`.
1128    ///
1129    /// Performs left-to-right QR orthogonalization, then right-to-left
1130    /// truncated SVD to reduce ranks.
1131    pub fn tt_round(&self, max_rank: usize) -> SolverResult<TtDecomposition> {
1132        if self.cores.is_empty() {
1133            return Err(SolverError::InternalError(
1134                "TT decomposition has no cores".to_string(),
1135            ));
1136        }
1137
1138        let ndim = self.cores.len();
1139        let mut cores = self.cores.clone();
1140
1141        // Left-to-right QR sweep (orthogonalize)
1142        for k in 0..(ndim - 1) {
1143            let r_prev = cores[k].shape()[0];
1144            let n_k = cores[k].shape()[1];
1145            let r_next = cores[k].shape()[2];
1146
1147            // Reshape core_k to (r_prev * n_k) × r_next matrix
1148            let mat = Matrix::new(r_prev * n_k, r_next, cores[k].data().to_vec())?;
1149            let (q, r_mat) = qr_gram_schmidt(&mat);
1150
1151            let new_r = q.cols;
1152            cores[k] = Tensor::new(vec![r_prev, n_k, new_r], q.data.clone())?;
1153
1154            // Absorb R into next core
1155            let r_next2 = cores[k + 1].shape()[2];
1156            let n_next = cores[k + 1].shape()[1];
1157            let next_mat = Matrix::new(r_next, n_next * r_next2, cores[k + 1].data().to_vec())?;
1158            let absorbed = r_mat.matmul(&next_mat)?;
1159            cores[k + 1] = Tensor::new(vec![new_r, n_next, r_next2], absorbed.data.clone())?;
1160        }
1161
1162        // Right-to-left SVD sweep (truncate)
1163        for k in (1..ndim).rev() {
1164            let r_prev = cores[k].shape()[0];
1165            let n_k = cores[k].shape()[1];
1166            let r_next = cores[k].shape()[2];
1167
1168            // Reshape core_k to r_prev × (n_k * r_next) matrix
1169            let mat = Matrix::new(r_prev, n_k * r_next, cores[k].data().to_vec())?;
1170            let trunc_rank = max_rank.min(r_prev).min(n_k * r_next);
1171            let (u, sigma, v) = mat.svd_truncated(trunc_rank)?;
1172
1173            // New core_k = diag(sigma) * V^T reshaped to [trunc_rank, n_k, r_next]
1174            let mut sv = Matrix::zeros(trunc_rank, n_k * r_next);
1175            for (i, &si) in sigma.iter().enumerate().take(trunc_rank) {
1176                for j in 0..(n_k * r_next) {
1177                    sv.set(i, j, si * v.get(j, i));
1178                }
1179            }
1180            cores[k] = Tensor::new(vec![trunc_rank, n_k, r_next], sv.data.clone())?;
1181
1182            // Absorb U into previous core
1183            let prev_r_prev = cores[k - 1].shape()[0];
1184            let prev_n = cores[k - 1].shape()[1];
1185            let prev_mat = Matrix::new(prev_r_prev * prev_n, r_prev, cores[k - 1].data().to_vec())?;
1186            let absorbed = prev_mat.matmul(&u)?;
1187            cores[k - 1] =
1188                Tensor::new(vec![prev_r_prev, prev_n, trunc_rank], absorbed.data.clone())?;
1189        }
1190
1191        Ok(TtDecomposition { cores })
1192    }
1193}
1194
/// Parameters for the TT-SVD algorithm ([`tt_svd`]).
#[derive(Debug, Clone)]
pub struct TtConfig {
    /// Upper bound on every TT-rank; `tt_svd` rejects `0`.
    pub max_rank: usize,
    /// Relative truncation threshold. At each TT-SVD step, trailing singular
    /// values are dropped as long as the norm of the dropped tail, divided by
    /// the norm of all computed singular values, stays at or below this value.
    pub tolerance: f64,
}

impl Default for TtConfig {
    /// `max_rank = 10` and `tolerance = 1e-8`.
    fn default() -> Self {
        let tolerance = 1e-8;
        let max_rank = 10;
        TtConfig {
            max_rank,
            tolerance,
        }
    }
}
1212
1213/// TT-SVD decomposition.
1214///
1215/// Decomposes a tensor into Tensor-Train format via sequential SVDs
1216/// from left to right.
1217pub fn tt_svd(tensor: &Tensor, config: &TtConfig) -> SolverResult<TtDecomposition> {
1218    let ndim = tensor.ndim();
1219    if ndim < 2 {
1220        // For 1D tensors, wrap as a single core [1, n, 1]
1221        let n = tensor.shape()[0];
1222        let core = Tensor::new(vec![1, n, 1], tensor.data().to_vec())?;
1223        return Ok(TtDecomposition { cores: vec![core] });
1224    }
1225
1226    if config.max_rank == 0 {
1227        return Err(SolverError::InternalError(
1228            "TT max_rank must be positive".to_string(),
1229        ));
1230    }
1231
1232    let shape = tensor.shape().to_vec();
1233    let mut cores: Vec<Tensor> = Vec::with_capacity(ndim);
1234    let mut remaining_data = tensor.data().to_vec();
1235    let mut r_prev = 1usize;
1236
1237    for k in 0..(ndim - 1) {
1238        let n_k = shape[k];
1239        let remaining_size: usize = shape[(k + 1)..].iter().product();
1240
1241        // Reshape to (r_prev * n_k) × remaining
1242        let rows = r_prev * n_k;
1243        let cols = remaining_size;
1244
1245        // Handle edge case where data might have been truncated
1246        let actual_len = remaining_data.len();
1247        if actual_len != rows * cols {
1248            return Err(SolverError::InternalError(format!(
1249                "TT-SVD reshape error at mode {}: expected {} elements, have {}",
1250                k,
1251                rows * cols,
1252                actual_len
1253            )));
1254        }
1255
1256        let mat = Matrix::new(rows, cols, remaining_data)?;
1257
1258        // Truncated SVD
1259        let trunc_rank = config.max_rank.min(rows).min(cols);
1260        let (u, sigma, v) = mat.svd_truncated(trunc_rank)?;
1261
1262        // Determine effective rank (truncate small singular values)
1263        let total_sv_norm: f64 = sigma.iter().map(|s| s * s).sum::<f64>().sqrt();
1264        let mut effective_rank = trunc_rank;
1265        if total_sv_norm > 1e-15 {
1266            let mut accumulated = 0.0;
1267            for (i, &s) in sigma.iter().enumerate().rev() {
1268                accumulated += s * s;
1269                if accumulated.sqrt() / total_sv_norm > config.tolerance {
1270                    effective_rank = i + 1;
1271                    break;
1272                }
1273            }
1274        }
1275        effective_rank = effective_rank.min(trunc_rank);
1276        if effective_rank == 0 {
1277            effective_rank = 1;
1278        }
1279
1280        // Core_k = U reshaped to [r_prev, n_k, effective_rank]
1281        let mut core_data = vec![0.0; r_prev * n_k * effective_rank];
1282        for i in 0..rows {
1283            for j in 0..effective_rank {
1284                core_data[i * effective_rank + j] = u.get(i, j);
1285            }
1286        }
1287        cores.push(Tensor::new(vec![r_prev, n_k, effective_rank], core_data)?);
1288
1289        // Remaining = diag(sigma) * V^T
1290        let new_cols = cols;
1291        let mut new_remaining = vec![0.0; effective_rank * new_cols];
1292        for i in 0..effective_rank {
1293            for j in 0..new_cols {
1294                new_remaining[i * new_cols + j] = sigma[i] * v.get(j, i);
1295            }
1296        }
1297        remaining_data = new_remaining;
1298        r_prev = effective_rank;
1299    }
1300
1301    // Last core: [r_prev, n_{N-1}, 1]
1302    let n_last = shape[ndim - 1];
1303    if remaining_data.len() != r_prev * n_last {
1304        return Err(SolverError::InternalError(format!(
1305            "TT-SVD final reshape error: expected {} elements, have {}",
1306            r_prev * n_last,
1307            remaining_data.len()
1308        )));
1309    }
1310    let mut last_core_data = vec![0.0; r_prev * n_last];
1311    for i in 0..r_prev {
1312        for j in 0..n_last {
1313            last_core_data[i * n_last + j] = remaining_data[i * n_last + j];
1314        }
1315    }
1316    cores.push(Tensor::new(vec![r_prev, n_last, 1], last_core_data)?);
1317
1318    Ok(TtDecomposition { cores })
1319}
1320
1321// ---------------------------------------------------------------------------
1322// Tests
1323// ---------------------------------------------------------------------------
1324
#[cfg(test)]
mod tests {
    use super::*;

    /// Helper: creates a small 3×4×2 test tensor.
    fn make_test_tensor_3d() -> Tensor {
        let shape = vec![3, 4, 2];
        let data: Vec<f64> = (0..24).map(|i| (i as f64) * 0.5 + 1.0).collect();
        Tensor::new(shape, data).expect("failed to create test tensor")
    }

    /// Helper: creates a rank-1 tensor: outer product of [1,2,3] x [1,2] x [1,2,3,4].
    fn make_rank1_tensor() -> Tensor {
        let a = [1.0, 2.0, 3.0];
        let b = [1.0, 2.0];
        let c = [1.0, 2.0, 3.0, 4.0];
        let shape = vec![3, 2, 4];
        let mut data = vec![0.0; 24];
        for i in 0..3 {
            for j in 0..2 {
                for k in 0..4 {
                    data[i * 8 + j * 4 + k] = a[i] * b[j] * c[k];
                }
            }
        }
        Tensor::new(shape, data).expect("failed to create rank-1 tensor")
    }

    #[test]
    fn test_tensor_creation_and_indexing() {
        let t = make_test_tensor_3d();
        assert_eq!(t.ndim(), 3);
        assert_eq!(t.shape(), &[3, 4, 2]);
        assert_eq!(t.numel(), 24);

        // First element
        let v = t.get(&[0, 0, 0]).expect("get failed");
        assert!((v - 1.0).abs() < 1e-12);

        // Last element
        let v = t.get(&[2, 3, 1]).expect("get failed");
        assert!((v - 12.5).abs() < 1e-12);
    }

    #[test]
    fn test_tensor_set() {
        let mut t = make_test_tensor_3d();
        t.set(&[1, 2, 0], 99.0).expect("set failed");
        let v = t.get(&[1, 2, 0]).expect("get failed");
        assert!((v - 99.0).abs() < 1e-12);
    }

    #[test]
    fn test_tensor_index_out_of_range() {
        let t = make_test_tensor_3d();
        assert!(t.get(&[3, 0, 0]).is_err());
        assert!(t.get(&[0, 4, 0]).is_err());
        assert!(t.get(&[0, 0]).is_err()); // wrong ndim
    }

    #[test]
    fn test_mode_n_unfolding_and_folding_roundtrip() {
        let t = make_test_tensor_3d();
        for mode in 0..3 {
            let mat = t.unfold(mode).expect("unfold failed");

            // Check dimensions
            assert_eq!(mat.rows, t.shape()[mode]);
            assert_eq!(mat.rows * mat.cols, t.numel());

            // Fold back
            let recovered = Tensor::fold(&mat, mode, t.shape()).expect("fold failed");
            for i in 0..t.numel() {
                assert!(
                    (t.data()[i] - recovered.data()[i]).abs() < 1e-12,
                    "mismatch at element {} for mode {} unfold/fold",
                    i,
                    mode
                );
            }
        }
    }

    #[test]
    fn test_matrix_svd_truncated() {
        // Create a rank-1 matrix: outer product of [1,2,3,4] and [1,2,3]
        let mut data = vec![0.0; 12];
        for i in 0..4 {
            for j in 0..3 {
                data[i * 3 + j] = (i + 1) as f64 * (j + 1) as f64;
            }
        }
        let m = Matrix::new(4, 3, data).expect("matrix creation failed");
        let (u, sigma, v) = m.svd_truncated(2).expect("svd failed");

        assert_eq!(u.rows, 4);
        assert_eq!(u.cols, 2);
        assert_eq!(sigma.len(), 2);
        assert_eq!(v.rows, 3);
        assert_eq!(v.cols, 2);

        // Singular values should be non-negative and decreasing
        assert!(sigma[0] >= sigma[1]);
        assert!(sigma[0] > 0.0);

        // Reconstruct: U * diag(sigma) * V^T should approximate original
        let mut reconstructed = Matrix::zeros(4, 3);
        for i in 0..4 {
            for j in 0..3 {
                let mut val = 0.0;
                for (r, sigma_r) in sigma.iter().enumerate().take(2) {
                    val += u.get(i, r) * sigma_r * v.get(j, r);
                }
                reconstructed.set(i, j, val);
            }
        }
        let error = (m.data)
            .iter()
            .zip(reconstructed.data.iter())
            .map(|(a, b)| (a - b) * (a - b))
            .sum::<f64>()
            .sqrt();
        let norm = m.frobenius_norm();
        // The matrix is rank-1, so a rank-2 truncation should reproduce it up
        // to the accuracy of the SVD routine.
        assert!(
            error / norm < 0.05,
            "SVD reconstruction error too large: {}",
            error / norm
        );
    }

    #[test]
    fn test_khatri_rao_product() {
        let a = Matrix::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]).expect("a");
        let b = Matrix::new(3, 2, vec![5.0, 6.0, 7.0, 8.0, 9.0, 10.0]).expect("b");
        let kr = khatri_rao_product(&a, &b).expect("kr");

        assert_eq!(kr.rows, 6);
        assert_eq!(kr.cols, 2);

        // Column 0: [1*5, 1*7, 1*9, 3*5, 3*7, 3*9]
        assert!((kr.get(0, 0) - 5.0).abs() < 1e-12);
        assert!((kr.get(1, 0) - 7.0).abs() < 1e-12);
        assert!((kr.get(2, 0) - 9.0).abs() < 1e-12);
        assert!((kr.get(3, 0) - 15.0).abs() < 1e-12);
    }

    #[test]
    fn test_hadamard_product() {
        let a = Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("a");
        let b = Matrix::new(2, 3, vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]).expect("b");
        let h = hadamard_product(&a, &b).expect("hadamard");

        assert_eq!(h.rows, 2);
        assert_eq!(h.cols, 3);
        assert!((h.get(0, 0) - 7.0).abs() < 1e-12);
        assert!((h.get(1, 2) - 72.0).abs() < 1e-12);
    }

    #[test]
    fn test_hadamard_product_dimension_mismatch() {
        let a = Matrix::new(2, 3, vec![0.0; 6]).expect("a");
        let b = Matrix::new(3, 2, vec![0.0; 6]).expect("b");
        assert!(hadamard_product(&a, &b).is_err());
    }

    #[test]
    fn test_mode_n_product() {
        let t = make_test_tensor_3d(); // 3×4×2
        let m = Matrix::new(5, 4, vec![0.1; 20]).expect("matrix");
        let result = mode_n_product(&t, &m, 1).expect("mode_n_product");
        assert_eq!(result.shape(), &[3, 5, 2]);
    }

    #[test]
    fn test_cp_als_rank1_tensor() {
        let t = make_rank1_tensor();
        let config = CpAlsConfig {
            rank: 1,
            max_iterations: 200,
            tolerance: 1e-10,
            normalize_factors: true,
        };
        let decomp = cp_als(&t, &config).expect("cp_als failed");

        assert_eq!(decomp.rank(), 1);
        assert_eq!(decomp.factors.len(), 3);
        assert_eq!(decomp.factors[0].rows, 3);
        assert_eq!(decomp.factors[1].rows, 2);
        assert_eq!(decomp.factors[2].rows, 4);

        // Rank-1 tensor should be recovered nearly exactly
        let error = decomp.fit_error(&t).expect("fit_error failed");
        assert!(error < 0.01, "CP-ALS rank-1 error too large: {}", error);
    }

    #[test]
    fn test_cp_als_rank3_tensor() {
        // Create a rank-3 tensor by summing 3 rank-1 components
        let shape = vec![4, 3, 5];
        let numel = 60;
        let mut data = vec![0.0; numel];
        for r in 0..3 {
            for i in 0..4 {
                for j in 0..3 {
                    for k in 0..5 {
                        let a_val = ((i + 1) as f64 * (r + 1) as f64 * 0.3).sin();
                        let b_val = ((j + 1) as f64 * (r + 1) as f64 * 0.5).cos();
                        let c_val = ((k + 1) as f64 * (r + 1) as f64 * 0.7).sin();
                        data[i * 15 + j * 5 + k] += a_val * b_val * c_val;
                    }
                }
            }
        }
        let t = Tensor::new(shape, data).expect("tensor");

        let config = CpAlsConfig {
            rank: 3,
            max_iterations: 300,
            tolerance: 1e-10,
            normalize_factors: true,
        };
        let decomp = cp_als(&t, &config).expect("cp_als failed");
        assert_eq!(decomp.rank(), 3);

        let error = decomp.fit_error(&t).expect("fit_error");
        assert!(error < 0.5, "CP-ALS rank-3 error too large: {}", error);
    }

    #[test]
    fn test_cp_reconstruction() {
        let t = make_rank1_tensor();
        let config = CpAlsConfig {
            rank: 1,
            max_iterations: 200,
            tolerance: 1e-10,
            normalize_factors: true,
        };
        let decomp = cp_als(&t, &config).expect("cp_als");
        let recon = decomp.reconstruct().expect("reconstruct");

        assert_eq!(recon.shape(), t.shape());
        assert_eq!(recon.numel(), t.numel());
    }

    #[test]
    fn test_tucker_hosvd() {
        let t = make_test_tensor_3d();
        let config = TuckerConfig {
            ranks: vec![2, 3, 2],
            max_iterations: 50,
            tolerance: 1e-8,
        };
        let decomp = tucker_hosvd(&t, &config).expect("tucker_hosvd");

        assert_eq!(decomp.core.shape(), &[2, 3, 2]);
        assert_eq!(decomp.factors.len(), 3);
        assert_eq!(decomp.factors[0].rows, 3); // original dim
        assert_eq!(decomp.factors[0].cols, 2); // rank
        assert_eq!(decomp.factors[1].rows, 4);
        assert_eq!(decomp.factors[1].cols, 3);
        assert_eq!(decomp.factors[2].rows, 2);
        assert_eq!(decomp.factors[2].cols, 2);

        // Should reconstruct reasonably well
        let error = decomp.fit_error(&t).expect("fit_error");
        assert!(error < 0.5, "Tucker HOSVD error too large: {}", error);
    }

    #[test]
    fn test_tucker_hooi_convergence() {
        // Use a larger tensor where compression actually happens
        let shape = vec![6, 6, 6];
        let data: Vec<f64> = (0..216)
            .map(|i| ((i as f64) * 0.13 + 0.7).sin() * ((i as f64) * 0.07).cos())
            .collect();
        let t = Tensor::new(shape, data).expect("tensor");

        let config = TuckerConfig {
            ranks: vec![2, 2, 2],
            max_iterations: 30,
            tolerance: 1e-8,
        };

        let hosvd_decomp = tucker_hosvd(&t, &config).expect("hosvd");
        let hooi_decomp = tucker_hooi(&t, &config).expect("hooi");

        let hosvd_error = hosvd_decomp.fit_error(&t).expect("fit");
        let hooi_error = hooi_decomp.fit_error(&t).expect("fit");

        // HOOI should be at least as good as HOSVD (or close)
        assert!(
            hooi_error <= hosvd_error + 0.05,
            "HOOI error {} should not be much worse than HOSVD error {}",
            hooi_error,
            hosvd_error
        );
    }

    #[test]
    fn test_tucker_compression_ratio() {
        let t = make_test_tensor_3d();
        let config = TuckerConfig {
            ranks: vec![2, 2, 2],
            max_iterations: 50,
            tolerance: 1e-8,
        };
        let decomp = tucker_hosvd(&t, &config).expect("tucker");
        let ratio = decomp.compression_ratio(t.shape());

        // Core: 2*2*2 = 8, Factors: 3*2 + 4*2 + 2*2 = 18, Total = 26
        // Original: 24. Ratio = 24/26 ~ 0.92
        let expected = 24.0 / 26.0;
        assert!(
            (ratio - expected).abs() < 1e-12,
            "compression ratio {} != expected {}",
            ratio,
            expected
        );
    }

    #[test]
    fn test_tt_svd_3d_tensor() {
        let t = make_test_tensor_3d();
        let config = TtConfig {
            max_rank: 10,
            tolerance: 1e-10,
        };
        let decomp = tt_svd(&t, &config).expect("tt_svd");

        assert_eq!(decomp.cores.len(), 3);

        // First core: [1, 3, r1]
        assert_eq!(decomp.cores[0].shape()[0], 1);
        assert_eq!(decomp.cores[0].shape()[1], 3);

        // Last core: [r2, 2, 1]
        assert_eq!(decomp.cores[2].shape()[1], 2);
        assert_eq!(decomp.cores[2].shape()[2], 1);

        // Ranks start with 1 and end with 1
        let ranks = decomp.ranks();
        assert_eq!(ranks[0], 1);
        assert_eq!(*ranks.last().expect("no ranks"), 1);

        // Reconstruction
        let recon = decomp.reconstruct().expect("reconstruct");
        assert_eq!(recon.shape(), t.shape());
    }

    #[test]
    fn test_tt_svd_4d_tensor() {
        let shape = vec![2, 3, 2, 4];
        let data: Vec<f64> = (0..48).map(|i| (i as f64 + 1.0) * 0.1).collect();
        let t = Tensor::new(shape, data).expect("tensor");

        let config = TtConfig {
            max_rank: 10,
            tolerance: 1e-10,
        };
        let decomp = tt_svd(&t, &config).expect("tt_svd");
        assert_eq!(decomp.cores.len(), 4);

        let ranks = decomp.ranks();
        assert_eq!(ranks[0], 1);
        assert_eq!(ranks[4], 1);
    }

    #[test]
    fn test_tt_reconstruction_error() {
        let t = make_rank1_tensor();
        let config = TtConfig {
            max_rank: 5,
            tolerance: 1e-10,
        };
        let decomp = tt_svd(&t, &config).expect("tt_svd");
        let error = decomp.fit_error(&t).expect("fit_error");

        // Rank-1 tensor should have small reconstruction error
        assert!(error < 0.1, "TT reconstruction error too large: {}", error);
    }

    #[test]
    fn test_tt_rank_truncation() {
        let t = make_test_tensor_3d();
        let config = TtConfig {
            max_rank: 10,
            tolerance: 1e-10,
        };
        let decomp = tt_svd(&t, &config).expect("tt_svd");

        // Truncate to max_rank=1
        let truncated = decomp.tt_round(1).expect("tt_round");
        let trunc_ranks = truncated.ranks();
        for &r in &trunc_ranks {
            assert!(r <= 1, "rank {} exceeds max_rank 1", r);
        }

        // The rounded train must still describe a tensor of the original shape
        let recon = truncated.reconstruct().expect("reconstruct after tt_round");
        assert_eq!(recon.shape(), t.shape());
    }

    #[test]
    fn test_tt_compression_ratio() {
        let t = make_test_tensor_3d();
        let config = TtConfig {
            max_rank: 2,
            tolerance: 1e-10,
        };
        let decomp = tt_svd(&t, &config).expect("tt_svd");
        let ratio = decomp.compression_ratio(t.shape());
        assert!(ratio > 0.0);
    }

    #[test]
    fn test_1d_tensor_vector() {
        let t = Tensor::new(vec![5], vec![1.0, 2.0, 3.0, 4.0, 5.0]).expect("1d tensor");
        assert_eq!(t.ndim(), 1);
        assert_eq!(t.numel(), 5);

        let v = t.get(&[3]).expect("get");
        assert!((v - 4.0).abs() < 1e-12);

        // TT of 1D tensor
        let config = TtConfig {
            max_rank: 5,
            tolerance: 1e-10,
        };
        let decomp = tt_svd(&t, &config).expect("tt_svd 1d");
        assert_eq!(decomp.cores.len(), 1);
        assert_eq!(decomp.cores[0].shape(), &[1, 5, 1]);
    }

    #[test]
    fn test_config_validation() {
        let t = make_test_tensor_3d();

        // CP with rank 0
        let config = CpAlsConfig {
            rank: 0,
            ..Default::default()
        };
        assert!(cp_als(&t, &config).is_err());

        // Tucker with wrong number of ranks
        let config = TuckerConfig {
            ranks: vec![2, 2], // 3D tensor needs 3 ranks
            ..Default::default()
        };
        assert!(tucker_hosvd(&t, &config).is_err());

        // TT with max_rank 0
        let config = TtConfig {
            max_rank: 0,
            tolerance: 1e-8,
        };
        assert!(tt_svd(&t, &config).is_err());
    }
}