//! Minimal dense linear-algebra and tensor-decomposition routines: CP via
//! alternating least squares, Tucker via HOSVD/HOOI, and tensor-train via
//! sequential truncated SVD.

use crate::error::{SolverError, SolverResult};

/// A dense, row-major matrix of `f64` values.
#[derive(Debug, Clone)]
pub struct Matrix {
    /// Number of rows.
    pub rows: usize,
    /// Number of columns.
    pub cols: usize,
    /// Elements in row-major order; `data[r * cols + c]` is entry `(r, c)`.
    pub data: Vec<f64>,
}

impl Matrix {
    pub fn new(rows: usize, cols: usize, data: Vec<f64>) -> SolverResult<Self> {
        if data.len() != rows * cols {
            return Err(SolverError::DimensionMismatch(format!(
                "matrix {}x{} requires {} elements, got {}",
                rows,
                cols,
                rows * cols,
                data.len()
            )));
        }
        Ok(Self { rows, cols, data })
    }

    pub fn zeros(rows: usize, cols: usize) -> Self {
        Self {
            rows,
            cols,
            data: vec![0.0; rows * cols],
        }
    }

    pub fn eye(n: usize) -> Self {
        let mut data = vec![0.0; n * n];
        for i in 0..n {
            data[i * n + i] = 1.0;
        }
        Self {
            rows: n,
            cols: n,
            data,
        }
    }

    #[inline]
    pub fn get(&self, r: usize, c: usize) -> f64 {
        self.data[r * self.cols + c]
    }

    #[inline]
    pub fn set(&mut self, r: usize, c: usize, v: f64) {
        self.data[r * self.cols + c] = v;
    }

    pub fn transpose(&self) -> Self {
        let mut out = Self::zeros(self.cols, self.rows);
        for r in 0..self.rows {
            for c in 0..self.cols {
                out.set(c, r, self.get(r, c));
            }
        }
        out
    }

    pub fn matmul(&self, other: &Matrix) -> SolverResult<Matrix> {
        if self.cols != other.rows {
            return Err(SolverError::DimensionMismatch(format!(
                "matmul: {}x{} * {}x{}",
                self.rows, self.cols, other.rows, other.cols
            )));
        }
        let mut out = Matrix::zeros(self.rows, other.cols);
        for i in 0..self.rows {
            for k in 0..self.cols {
                let a_ik = self.get(i, k);
                for j in 0..other.cols {
                    let cur = out.get(i, j);
                    out.set(i, j, cur + a_ik * other.get(k, j));
                }
            }
        }
        Ok(out)
    }

    pub fn frobenius_norm(&self) -> f64 {
        self.data.iter().map(|x| x * x).sum::<f64>().sqrt()
    }

    pub fn column_norms(&self) -> Vec<f64> {
        let mut norms = vec![0.0; self.cols];
        for r in 0..self.rows {
            for (c, norm) in norms.iter_mut().enumerate() {
                let v = self.get(r, c);
                *norm += v * v;
            }
        }
        norms.iter().map(|s| s.sqrt()).collect()
    }

    pub fn normalize_columns(&mut self) -> Vec<f64> {
        let norms = self.column_norms();
        for (c, &norm) in norms.iter().enumerate() {
            if norm > 1e-15 {
                for r in 0..self.rows {
                    let v = self.get(r, c) / norm;
                    self.set(r, c, v);
                }
            }
        }
        norms
    }

    pub fn column(&self, c: usize) -> Vec<f64> {
        (0..self.rows).map(|r| self.get(r, c)).collect()
    }

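    /// Approximate truncated SVD by power iteration with deflation.
    ///
    /// For each of the requested components this iterates
    /// `u <- A v / ||A v||`, `v <- A^T u / ||A^T u||` on a working copy of
    /// the matrix until `v` stabilizes, records `sigma = ||A v||`, then
    /// deflates the working copy by subtracting `sigma * u * v^T`. The
    /// sinusoidal seed for `v` keeps the routine deterministic. Returns
    /// `(U, sigma, V)` with `U` of size `m x k` and `V` of size `n x k`.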
    pub fn svd_truncated(&self, rank: usize) -> SolverResult<(Matrix, Vec<f64>, Matrix)> {
        let m = self.rows;
        let n = self.cols;
        let k = rank.min(m).min(n);
        if k == 0 {
            return Ok((Matrix::zeros(m, 0), Vec::new(), Matrix::zeros(n, 0)));
        }

        let mut u_mat = Matrix::zeros(m, k);
        let mut v_mat = Matrix::zeros(n, k);
        let mut sigma = vec![0.0; k];

        let mut deflated = self.clone();

        for (s, sigma_s) in sigma.iter_mut().enumerate().take(k) {
            // Deterministic pseudo-random starting vector for this component.
            let mut v: Vec<f64> = (0..n)
                .map(|i| ((i + 1) as f64 * (s + 1) as f64 * 0.7 + 0.3).sin())
                .collect();
            let mut vnorm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
            if vnorm > 1e-15 {
                for x in &mut v {
                    *x /= vnorm;
                }
            }

            let max_iters = 200;
            for _ in 0..max_iters {
                let mut u: Vec<f64> = (0..m)
                    .map(|i| {
                        v.iter()
                            .enumerate()
                            .map(|(j, &vj)| deflated.get(i, j) * vj)
                            .sum()
                    })
                    .collect();

                let sigma_new: f64 = u.iter().map(|x| x * x).sum::<f64>().sqrt();
                if sigma_new < 1e-15 {
                    break;
                }
                for x in &mut u {
                    *x /= sigma_new;
                }

                let mut v_new: Vec<f64> = (0..n)
                    .map(|j| {
                        u.iter()
                            .enumerate()
                            .map(|(i, &ui)| deflated.get(i, j) * ui)
                            .sum()
                    })
                    .collect();

                vnorm = v_new.iter().map(|x| x * x).sum::<f64>().sqrt();
                if vnorm < 1e-15 {
                    break;
                }
                for x in &mut v_new {
                    *x /= vnorm;
                }

                let diff: f64 = v
                    .iter()
                    .zip(v_new.iter())
                    .map(|(a, b)| (a - b) * (a - b))
                    .sum::<f64>()
                    .sqrt();
                v = v_new;

                if diff < 1e-12 {
                    break;
                }
            }

            let mut u: Vec<f64> = (0..m)
                .map(|i| {
                    v.iter()
                        .enumerate()
                        .map(|(j, &vj)| deflated.get(i, j) * vj)
                        .sum()
                })
                .collect();
            let sv = u.iter().map(|x| x * x).sum::<f64>().sqrt();
            if sv > 1e-15 {
                for x in &mut u {
                    *x /= sv;
                }
            }

            *sigma_s = sv;
            for (i, &ui) in u.iter().enumerate() {
                u_mat.set(i, s, ui);
            }
            for (j, &vj) in v.iter().enumerate() {
                v_mat.set(j, s, vj);
            }

            // Deflate: remove the recovered rank-1 component.
            for (i, &ui) in u.iter().enumerate() {
                for (j, &vj) in v.iter().enumerate() {
                    let old = deflated.get(i, j);
                    deflated.set(i, j, old - sv * ui * vj);
                }
            }
        }

        Ok((u_mat, sigma, v_mat))
    }
}

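/// QR factorization via modified Gram-Schmidt: returns `(Q, R)` with `Q`
/// having orthonormal columns and `R` upper-triangular, so that `A = Q * R`.
/// Used by `TtDecomposition::tt_round` to orthogonalize cores before
/// truncation.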
fn qr_gram_schmidt(a: &Matrix) -> (Matrix, Matrix) {
    let m = a.rows;
    let k = a.cols;
    let mut q = a.clone();
    let mut r = Matrix::zeros(k, k);

    for j in 0..k {
        for i in 0..j {
            let mut dot = 0.0;
            for row in 0..m {
                dot += q.get(row, i) * q.get(row, j);
            }
            r.set(i, j, dot);
            for row in 0..m {
                let v = q.get(row, j) - dot * q.get(row, i);
                q.set(row, j, v);
            }
        }
        let mut norm = 0.0;
        for row in 0..m {
            let v = q.get(row, j);
            norm += v * v;
        }
        norm = norm.sqrt();
        r.set(j, j, norm);
        if norm > 1e-15 {
            for row in 0..m {
                let v = q.get(row, j) / norm;
                q.set(row, j, v);
            }
        }
    }

    (q, r)
}

/// A dense tensor of arbitrary order, stored in row-major (C) order with the
/// last index varying fastest.
#[derive(Debug, Clone)]
pub struct Tensor {
    shape: Vec<usize>,
    data: Vec<f64>,
}

impl Tensor {
    pub fn new(shape: Vec<usize>, data: Vec<f64>) -> SolverResult<Self> {
        let numel: usize = shape.iter().product();
        if data.len() != numel {
            return Err(SolverError::DimensionMismatch(format!(
                "tensor with shape {:?} requires {} elements, got {}",
                shape,
                numel,
                data.len()
            )));
        }
        if shape.is_empty() {
            return Err(SolverError::DimensionMismatch(
                "tensor must have at least one dimension".to_string(),
            ));
        }
        Ok(Self { shape, data })
    }

    pub fn zeros(shape: Vec<usize>) -> Self {
        let numel: usize = shape.iter().product();
        Self {
            shape,
            data: vec![0.0; numel],
        }
    }

    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    pub fn numel(&self) -> usize {
        self.data.len()
    }

    pub fn data(&self) -> &[f64] {
        &self.data
    }

    fn linear_index(&self, indices: &[usize]) -> SolverResult<usize> {
        if indices.len() != self.shape.len() {
            return Err(SolverError::DimensionMismatch(format!(
                "expected {} indices, got {}",
                self.shape.len(),
                indices.len()
            )));
        }
        let mut idx = 0;
        let mut stride = 1;
        for d in (0..self.shape.len()).rev() {
            if indices[d] >= self.shape[d] {
                return Err(SolverError::DimensionMismatch(format!(
                    "index {} out of range for dimension {} with size {}",
                    indices[d], d, self.shape[d]
                )));
            }
            idx += indices[d] * stride;
            stride *= self.shape[d];
        }
        Ok(idx)
    }

    pub fn get(&self, indices: &[usize]) -> SolverResult<f64> {
        let idx = self.linear_index(indices)?;
        Ok(self.data[idx])
    }

    pub fn set(&mut self, indices: &[usize], value: f64) -> SolverResult<()> {
        let idx = self.linear_index(indices)?;
        self.data[idx] = value;
        Ok(())
    }

    pub fn unfold(&self, mode: usize) -> SolverResult<Matrix> {
        if mode >= self.ndim() {
            return Err(SolverError::DimensionMismatch(format!(
                "mode {} out of range for {}-dimensional tensor",
                mode,
                self.ndim()
            )));
        }

        let rows = self.shape[mode];
        let cols = self.numel() / rows;
        let mut mat = Matrix::zeros(rows, cols);

        let ndim = self.ndim();
        let mut indices = vec![0usize; ndim];
        for flat in 0..self.numel() {
            // Decode the flat offset into multi-indices (last index fastest).
            let mut rem = flat;
            for d in (0..ndim).rev() {
                indices[d] = rem % self.shape[d];
                rem /= self.shape[d];
            }

            let row = indices[mode];

            // Column index: row-major order over all dimensions except `mode`.
            let mut col = 0;
            let mut col_stride = 1;
            for d in (0..ndim).rev() {
                if d != mode {
                    col += indices[d] * col_stride;
                    col_stride *= self.shape[d];
                }
            }

            mat.set(row, col, self.data[flat]);
        }

        Ok(mat)
    }

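    /// Inverse of `unfold` for the same `mode` and `shape`.
    ///
    /// As a small example of the column convention: a 2x2x2 tensor unfolded
    /// along mode 0 puts slice `X[i, :, :]` in row `i`, with columns ordered
    /// `(j, k) = (0,0), (0,1), (1,0), (1,1)`.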
    pub fn fold(matrix: &Matrix, mode: usize, shape: &[usize]) -> SolverResult<Tensor> {
        let ndim = shape.len();
        if mode >= ndim {
            return Err(SolverError::DimensionMismatch(format!(
                "mode {} out of range for {}-dimensional tensor",
                mode, ndim
            )));
        }
        if matrix.rows != shape[mode] {
            return Err(SolverError::DimensionMismatch(format!(
                "matrix rows {} != shape[{}] = {}",
                matrix.rows, mode, shape[mode]
            )));
        }

        let numel: usize = shape.iter().product();
        let mut data = vec![0.0; numel];

        let mut indices = vec![0usize; ndim];
        for (flat, datum) in data.iter_mut().enumerate() {
            let mut rem = flat;
            for d in (0..ndim).rev() {
                indices[d] = rem % shape[d];
                rem /= shape[d];
            }

            let row = indices[mode];
            let mut col = 0;
            let mut col_stride = 1;
            for d in (0..ndim).rev() {
                if d != mode {
                    col += indices[d] * col_stride;
                    col_stride *= shape[d];
                }
            }

            *datum = matrix.get(row, col);
        }

        Ok(Tensor {
            shape: shape.to_vec(),
            data,
        })
    }

    pub fn frobenius_norm(&self) -> f64 {
        self.data.iter().map(|x| x * x).sum::<f64>().sqrt()
    }
}

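/// Khatri-Rao product (column-wise Kronecker product): for `A` of size
/// `m x r` and `B` of size `p x r`, the result has size `(m * p) x r` and
/// column `c` equals `kron(A[:, c], B[:, c])`, i.e. entries ordered
/// `a_0*b_0, a_0*b_1, ..., a_1*b_0, ...` down the column.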
pub fn khatri_rao_product(a: &Matrix, b: &Matrix) -> SolverResult<Matrix> {
    if a.cols != b.cols {
        return Err(SolverError::DimensionMismatch(format!(
            "khatri-rao requires same number of columns: {} vs {}",
            a.cols, b.cols
        )));
    }
    let r = a.cols;
    let rows = a.rows * b.rows;
    let mut out = Matrix::zeros(rows, r);
    for col in 0..r {
        for i in 0..a.rows {
            for j in 0..b.rows {
                out.set(i * b.rows + j, col, a.get(i, col) * b.get(j, col));
            }
        }
    }
    Ok(out)
}

pub fn hadamard_product(a: &Matrix, b: &Matrix) -> SolverResult<Matrix> {
    if a.rows != b.rows || a.cols != b.cols {
        return Err(SolverError::DimensionMismatch(format!(
            "hadamard requires same dimensions: {}x{} vs {}x{}",
            a.rows, a.cols, b.rows, b.cols
        )));
    }
    let data: Vec<f64> = a
        .data
        .iter()
        .zip(b.data.iter())
        .map(|(x, y)| x * y)
        .collect();
    Matrix::new(a.rows, a.cols, data)
}

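/// Mode-n (tensor-times-matrix) product `Y = X x_n M`, computed through the
/// unfolding identity `Y_(n) = M * X_(n)` and folding back with dimension
/// `n` replaced by `M`'s row count.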
pub fn mode_n_product(tensor: &Tensor, matrix: &Matrix, mode: usize) -> SolverResult<Tensor> {
    if mode >= tensor.ndim() {
        return Err(SolverError::DimensionMismatch(format!(
            "mode {} out of range for {}-dimensional tensor",
            mode,
            tensor.ndim()
        )));
    }
    if matrix.cols != tensor.shape()[mode] {
        return Err(SolverError::DimensionMismatch(format!(
            "matrix cols {} != tensor dimension {} size {}",
            matrix.cols,
            mode,
            tensor.shape()[mode]
        )));
    }

    let unfolded = tensor.unfold(mode)?;
    let result_mat = matrix.matmul(&unfolded)?;

    let mut new_shape = tensor.shape().to_vec();
    new_shape[mode] = matrix.rows;

    Tensor::fold(&result_mat, mode, &new_shape)
}

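/// CP (CANDECOMP/PARAFAC) decomposition: a rank-R approximation of a tensor
/// as a sum of R outer products. For a 3-way tensor,
/// `X[i, j, k] ~= sum_r weights[r] * A[i, r] * B[j, r] * C[k, r]`.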
#[derive(Debug, Clone)]
pub struct CpDecomposition {
    /// Per-component scaling weights (length equals the rank).
    pub weights: Vec<f64>,
    /// One factor matrix per mode; factor `d` has shape `shape[d] x rank`.
    pub factors: Vec<Matrix>,
}

impl CpDecomposition {
    pub fn rank(&self) -> usize {
        self.weights.len()
    }

    pub fn reconstruct(&self) -> SolverResult<Tensor> {
        if self.factors.is_empty() {
            return Err(SolverError::InternalError(
                "CP decomposition has no factors".to_string(),
            ));
        }

        let shape: Vec<usize> = self.factors.iter().map(|f| f.rows).collect();
        let numel: usize = shape.iter().product();
        let ndim = shape.len();
        let rank = self.rank();
        let mut data = vec![0.0; numel];

        let mut indices = vec![0usize; ndim];
        for (flat, datum) in data.iter_mut().enumerate() {
            let mut rem = flat;
            for d in (0..ndim).rev() {
                indices[d] = rem % shape[d];
                rem /= shape[d];
            }

            let mut val = 0.0;
            for r in 0..rank {
                let mut term = self.weights[r];
                for (d, idx) in indices.iter().enumerate() {
                    term *= self.factors[d].get(*idx, r);
                }
                val += term;
            }
            *datum = val;
        }

        Tensor::new(shape, data)
    }

    pub fn fit_error(&self, original: &Tensor) -> SolverResult<f64> {
        let reconstructed = self.reconstruct()?;
        let orig_norm = original.frobenius_norm();
        if orig_norm < 1e-15 {
            return Ok(0.0);
        }
        let diff_norm: f64 = original
            .data()
            .iter()
            .zip(reconstructed.data().iter())
            .map(|(a, b)| (a - b) * (a - b))
            .sum::<f64>()
            .sqrt();
        Ok(diff_norm / orig_norm)
    }
}

#[derive(Debug, Clone)]
pub struct CpAlsConfig {
    /// Target CP rank (number of components).
    pub rank: usize,
    /// Maximum number of ALS sweeps.
    pub max_iterations: usize,
    /// Stop when the fit error changes by less than this between sweeps.
    pub tolerance: f64,
    /// If true, factor columns are normalized each sweep and their norms are
    /// absorbed into `weights`.
    pub normalize_factors: bool,
}

impl Default for CpAlsConfig {
    fn default() -> Self {
        Self {
            rank: 3,
            max_iterations: 100,
            tolerance: 1e-8,
            normalize_factors: true,
        }
    }
}

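/// CP decomposition by alternating least squares (ALS).
///
/// Each sweep updates one factor at a time while holding the others fixed;
/// the mode-n update solves the least-squares problem
/// `A_n <- X_(n) * KR * inv(Gram)`, where `KR` is the Khatri-Rao product of
/// the other factors and `Gram` is the Hadamard product of their Gram
/// matrices `A_m^T * A_m`. Iteration stops when the fit error changes by
/// less than `config.tolerance` between sweeps.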
pub fn cp_als(tensor: &Tensor, config: &CpAlsConfig) -> SolverResult<CpDecomposition> {
    let ndim = tensor.ndim();
    let rank = config.rank;

    if rank == 0 {
        return Err(SolverError::InternalError(
            "CP rank must be positive".to_string(),
        ));
    }

    // Deterministic, strictly positive initialization of the factor matrices.
    let mut factors: Vec<Matrix> = Vec::with_capacity(ndim);
    for n in 0..ndim {
        let rows = tensor.shape()[n];
        let mut data = vec![0.0; rows * rank];
        for i in 0..rows {
            for r in 0..rank {
                data[i * rank + r] = ((i + 1) as f64 * (r + 1) as f64 * 0.37).sin().abs() + 0.01;
            }
        }
        factors.push(Matrix::new(rows, rank, data)?);
    }

    let mut weights = vec![1.0; rank];
    let mut prev_fit = f64::MAX;

    for _iter in 0..config.max_iterations {
        for n in 0..ndim {
            let other_modes: Vec<usize> = (0..ndim).filter(|&m| m != n).collect();

            // Khatri-Rao product of all factors except mode n.
            let mut kr = factors[other_modes[0]].clone();
            for &m in &other_modes[1..] {
                kr = khatri_rao_product(&kr, &factors[m])?;
            }

            // Hadamard product of the Gram matrices of the other factors.
            let mut gram = {
                let ft = factors[other_modes[0]].transpose();
                ft.matmul(&factors[other_modes[0]])?
            };
            for &m in &other_modes[1..] {
                let ft = factors[m].transpose();
                let g = ft.matmul(&factors[m])?;
                gram = hadamard_product(&gram, &g)?;
            }

            let x_n = tensor.unfold(n)?;

            let v = x_n.matmul(&kr)?;

            let gram_inv = invert_small_matrix(&gram)?;
            factors[n] = v.matmul(&gram_inv)?;
        }

        if config.normalize_factors {
            for w in weights.iter_mut() {
                *w = 1.0;
            }
            for factor in factors.iter_mut() {
                let norms = factor.normalize_columns();
                for (w, &norm) in weights.iter_mut().zip(norms.iter()) {
                    *w *= norm;
                }
            }
        }

        let decomp = CpDecomposition {
            weights: weights.clone(),
            factors: factors.clone(),
        };
        let fit = decomp.fit_error(tensor).unwrap_or(f64::MAX);
        if (prev_fit - fit).abs() < config.tolerance {
            return Ok(decomp);
        }
        prev_fit = fit;
    }

    Ok(CpDecomposition { weights, factors })
}

/// Inverts a small square matrix by Gauss-Jordan elimination with partial
/// pivoting on an augmented `[M | I]` system.
fn invert_small_matrix(m: &Matrix) -> SolverResult<Matrix> {
    if m.rows != m.cols {
        return Err(SolverError::DimensionMismatch(
            "matrix must be square to invert".to_string(),
        ));
    }
    let n = m.rows;
    let mut aug = Matrix::zeros(n, 2 * n);
    for r in 0..n {
        for c in 0..n {
            aug.set(r, c, m.get(r, c));
        }
        aug.set(r, n + r, 1.0);
    }

    for col in 0..n {
        // Partial pivoting: pick the largest remaining entry in this column.
        let mut max_val = aug.get(col, col).abs();
        let mut max_row = col;
        for r in (col + 1)..n {
            let v = aug.get(r, col).abs();
            if v > max_val {
                max_val = v;
                max_row = r;
            }
        }
        if max_val < 1e-14 {
            return Err(SolverError::SingularMatrix);
        }

        if max_row != col {
            for c in 0..(2 * n) {
                let tmp = aug.get(col, c);
                aug.set(col, c, aug.get(max_row, c));
                aug.set(max_row, c, tmp);
            }
        }

        let pivot = aug.get(col, col);
        for c in 0..(2 * n) {
            aug.set(col, c, aug.get(col, c) / pivot);
        }

        for r in 0..n {
            if r == col {
                continue;
            }
            let factor = aug.get(r, col);
            for c in 0..(2 * n) {
                let v = aug.get(r, c) - factor * aug.get(col, c);
                aug.set(r, c, v);
            }
        }
    }

    let mut inv = Matrix::zeros(n, n);
    for r in 0..n {
        for c in 0..n {
            inv.set(r, c, aug.get(r, n + c));
        }
    }
    Ok(inv)
}

/// Tucker model: a (usually small) core tensor multiplied by one factor
/// matrix along each mode, `X ~= G x_1 U_1 x_2 U_2 ... x_d U_d`.
#[derive(Debug, Clone)]
pub struct TuckerDecomposition {
    /// Core tensor with shape given by the per-mode ranks.
    pub core: Tensor,
    /// One factor matrix per mode; factor `n` has shape `shape[n] x ranks[n]`.
    pub factors: Vec<Matrix>,
}

impl TuckerDecomposition {
    pub fn reconstruct(&self) -> SolverResult<Tensor> {
        let mut result = self.core.clone();
        for (n, factor) in self.factors.iter().enumerate() {
            result = mode_n_product(&result, factor, n)?;
        }
        Ok(result)
    }

    pub fn fit_error(&self, original: &Tensor) -> SolverResult<f64> {
        let reconstructed = self.reconstruct()?;
        let orig_norm = original.frobenius_norm();
        if orig_norm < 1e-15 {
            return Ok(0.0);
        }
        let diff_norm: f64 = original
            .data()
            .iter()
            .zip(reconstructed.data().iter())
            .map(|(a, b)| (a - b) * (a - b))
            .sum::<f64>()
            .sqrt();
        Ok(diff_norm / orig_norm)
    }

    pub fn compression_ratio(&self, original_shape: &[usize]) -> f64 {
        let original_size: usize = original_shape.iter().product();
        let core_size = self.core.numel();
        let factor_size: usize = self.factors.iter().map(|f| f.rows * f.cols).sum();
        let decomp_size = core_size + factor_size;
        if decomp_size == 0 {
            return 0.0;
        }
        original_size as f64 / decomp_size as f64
    }
}

#[derive(Debug, Clone)]
pub struct TuckerConfig {
    /// Target rank for each mode; must have one entry per tensor dimension.
    pub ranks: Vec<usize>,
    /// Maximum number of HOOI sweeps.
    pub max_iterations: usize,
    /// Relative change in core norm below which HOOI stops.
    pub tolerance: f64,
}

impl Default for TuckerConfig {
    fn default() -> Self {
        Self {
            ranks: vec![2, 2, 2],
            max_iterations: 50,
            tolerance: 1e-8,
        }
    }
}

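/// Tucker decomposition via truncated higher-order SVD (HOSVD).
///
/// Factor `U_n` is taken as the leading left singular vectors of the mode-n
/// unfolding `X_(n)`; the core is then
/// `G = X x_1 U_1^T x_2 U_2^T ... x_d U_d^T`.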
pub fn tucker_hosvd(tensor: &Tensor, config: &TuckerConfig) -> SolverResult<TuckerDecomposition> {
    let ndim = tensor.ndim();
    if config.ranks.len() != ndim {
        return Err(SolverError::DimensionMismatch(format!(
            "Tucker config requires {} ranks (one per mode), got {}",
            ndim,
            config.ranks.len()
        )));
    }

    let mut factors: Vec<Matrix> = Vec::with_capacity(ndim);
    for n in 0..ndim {
        let unfolded = tensor.unfold(n)?;
        let rank_n = config.ranks[n].min(tensor.shape()[n]);
        let (u, _sigma, _v) = unfolded.svd_truncated(rank_n)?;
        factors.push(u);
    }

    let mut core = tensor.clone();
    for (n, factor) in factors.iter().enumerate() {
        let ft = factor.transpose();
        core = mode_n_product(&core, &ft, n)?;
    }

    Ok(TuckerDecomposition { core, factors })
}

965
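/// Tucker decomposition via higher-order orthogonal iteration (HOOI).
///
/// Starts from the HOSVD and alternately re-estimates each factor from the
/// tensor projected onto all other factors, which typically improves the
/// fit. Convergence is measured by the relative change in the core's
/// Frobenius norm.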
pub fn tucker_hooi(tensor: &Tensor, config: &TuckerConfig) -> SolverResult<TuckerDecomposition> {
    let ndim = tensor.ndim();
    if config.ranks.len() != ndim {
        return Err(SolverError::DimensionMismatch(format!(
            "Tucker config requires {} ranks (one per mode), got {}",
            ndim,
            config.ranks.len()
        )));
    }

    let mut decomp = tucker_hosvd(tensor, config)?;
    let mut prev_core_norm = decomp.core.frobenius_norm();

    for _iter in 0..config.max_iterations {
        for n in 0..ndim {
            // Project onto all factors except mode n.
            let mut y = tensor.clone();
            for (m, factor) in decomp.factors.iter().enumerate() {
                if m != n {
                    let ft = factor.transpose();
                    y = mode_n_product(&y, &ft, m)?;
                }
            }

            let y_n = y.unfold(n)?;
            let rank_n = config.ranks[n].min(tensor.shape()[n]);
            let (u, _sigma, _v) = y_n.svd_truncated(rank_n)?;
            decomp.factors[n] = u;
        }

        let mut core = tensor.clone();
        for (n, factor) in decomp.factors.iter().enumerate() {
            let ft = factor.transpose();
            core = mode_n_product(&core, &ft, n)?;
        }
        decomp.core = core;

        let core_norm = decomp.core.frobenius_norm();
        if (core_norm - prev_core_norm).abs() / (prev_core_norm + 1e-15) < config.tolerance {
            break;
        }
        prev_core_norm = core_norm;
    }

    Ok(decomp)
}

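/// Tensor-train (TT) decomposition: every element is a product of matrix
/// slices of the 3-way cores,
/// `X[i1, ..., id] = G1[:, i1, :] * G2[:, i2, :] * ... * Gd[:, id, :]`,
/// where core `k` has shape `(r_{k-1}, n_k, r_k)` and the boundary ranks
/// are `r_0 = r_d = 1`.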
#[derive(Debug, Clone)]
pub struct TtDecomposition {
    /// TT cores; core `k` is a 3-way tensor of shape `(r_k, n_k, r_{k+1})`.
    pub cores: Vec<Tensor>,
}

impl TtDecomposition {
    pub fn ranks(&self) -> Vec<usize> {
        let mut ranks = Vec::with_capacity(self.cores.len() + 1);
        if self.cores.is_empty() {
            return ranks;
        }
        ranks.push(self.cores[0].shape()[0]);
        for core in &self.cores {
            ranks.push(core.shape()[2]);
        }
        ranks
    }

    pub fn reconstruct(&self) -> SolverResult<Tensor> {
        if self.cores.is_empty() {
            return Err(SolverError::InternalError(
                "TT decomposition has no cores".to_string(),
            ));
        }

        let shape: Vec<usize> = self.cores.iter().map(|c| c.shape()[1]).collect();
        let ndim = shape.len();
        let numel: usize = shape.iter().product();
        let mut data = vec![0.0; numel];

        let mut indices = vec![0usize; ndim];
        for (flat, datum) in data.iter_mut().enumerate() {
            let mut rem = flat;
            for d in (0..ndim).rev() {
                indices[d] = rem % shape[d];
                rem /= shape[d];
            }

            // Propagate a row vector through the chain of core slices.
            let core0 = &self.cores[0];
            let r1 = core0.shape()[2];
            let mut current: Vec<f64> = (0..r1)
                .map(|j| core0.get(&[0, indices[0], j]))
                .collect::<SolverResult<_>>()?;

            for (k, &idx_k) in indices.iter().enumerate().skip(1) {
                let core_k = &self.cores[k];
                let r_next = core_k.shape()[2];
                let mut next = vec![0.0; r_next];
                for (j, nj) in next.iter_mut().enumerate() {
                    let mut sum = 0.0;
                    for (i, &ci) in current.iter().enumerate() {
                        sum += ci * core_k.get(&[i, idx_k, j])?;
                    }
                    *nj = sum;
                }
                current = next;
            }

            // The final TT rank is 1, so the product collapses to a scalar.
            *datum = current[0];
        }

        Tensor::new(shape, data)
    }

    pub fn fit_error(&self, original: &Tensor) -> SolverResult<f64> {
        let reconstructed = self.reconstruct()?;
        let orig_norm = original.frobenius_norm();
        if orig_norm < 1e-15 {
            return Ok(0.0);
        }
        let diff_norm: f64 = original
            .data()
            .iter()
            .zip(reconstructed.data().iter())
            .map(|(a, b)| (a - b) * (a - b))
            .sum::<f64>()
            .sqrt();
        Ok(diff_norm / orig_norm)
    }

    pub fn compression_ratio(&self, original_shape: &[usize]) -> f64 {
        let original_size: usize = original_shape.iter().product();
        let decomp_size: usize = self.cores.iter().map(|c| c.numel()).sum();
        if decomp_size == 0 {
            return 0.0;
        }
        original_size as f64 / decomp_size as f64
    }

    pub fn tt_round(&self, max_rank: usize) -> SolverResult<TtDecomposition> {
        if self.cores.is_empty() {
            return Err(SolverError::InternalError(
                "TT decomposition has no cores".to_string(),
            ));
        }

        let ndim = self.cores.len();
        let mut cores = self.cores.clone();

        // Left-to-right sweep: orthogonalize each core via QR and absorb the
        // R factor into the next core.
        for k in 0..(ndim - 1) {
            let r_prev = cores[k].shape()[0];
            let n_k = cores[k].shape()[1];
            let r_next = cores[k].shape()[2];

            let mat = Matrix::new(r_prev * n_k, r_next, cores[k].data().to_vec())?;
            let (q, r_mat) = qr_gram_schmidt(&mat);

            let new_r = q.cols;
            cores[k] = Tensor::new(vec![r_prev, n_k, new_r], q.data.clone())?;

            let r_next2 = cores[k + 1].shape()[2];
            let n_next = cores[k + 1].shape()[1];
            let next_mat = Matrix::new(r_next, n_next * r_next2, cores[k + 1].data().to_vec())?;
            let absorbed = r_mat.matmul(&next_mat)?;
            cores[k + 1] = Tensor::new(vec![new_r, n_next, r_next2], absorbed.data.clone())?;
        }

        // Right-to-left sweep: truncate each core with an SVD and absorb the
        // left factor into the previous core.
        for k in (1..ndim).rev() {
            let r_prev = cores[k].shape()[0];
            let n_k = cores[k].shape()[1];
            let r_next = cores[k].shape()[2];

            let mat = Matrix::new(r_prev, n_k * r_next, cores[k].data().to_vec())?;
            let trunc_rank = max_rank.min(r_prev).min(n_k * r_next);
            let (u, sigma, v) = mat.svd_truncated(trunc_rank)?;

            let mut sv = Matrix::zeros(trunc_rank, n_k * r_next);
            for (i, &si) in sigma.iter().enumerate().take(trunc_rank) {
                for j in 0..(n_k * r_next) {
                    sv.set(i, j, si * v.get(j, i));
                }
            }
            cores[k] = Tensor::new(vec![trunc_rank, n_k, r_next], sv.data.clone())?;

            let prev_r_prev = cores[k - 1].shape()[0];
            let prev_n = cores[k - 1].shape()[1];
            let prev_mat = Matrix::new(prev_r_prev * prev_n, r_prev, cores[k - 1].data().to_vec())?;
            let absorbed = prev_mat.matmul(&u)?;
            cores[k - 1] =
                Tensor::new(vec![prev_r_prev, prev_n, trunc_rank], absorbed.data.clone())?;
        }

        Ok(TtDecomposition { cores })
    }
}

#[derive(Debug, Clone)]
pub struct TtConfig {
    /// Upper bound on every TT rank.
    pub max_rank: usize,
    /// Relative norm of trailing singular values that may be discarded at
    /// each split.
    pub tolerance: f64,
}

impl Default for TtConfig {
    fn default() -> Self {
        Self {
            max_rank: 10,
            tolerance: 1e-8,
        }
    }
}

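/// TT-SVD: builds a tensor-train by sweeping over the modes, at each step
/// reshaping the remaining data into an
/// `(r_{k-1} * n_k) x (n_{k+1} * ... * n_d)` matrix, taking a truncated SVD,
/// storing `U` as core `k`, and carrying `Sigma * V^T` forward as the new
/// remaining data. The rank at each split is reduced further when the
/// trailing singular values carry less than `config.tolerance` of the
/// spectrum's norm.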
pub fn tt_svd(tensor: &Tensor, config: &TtConfig) -> SolverResult<TtDecomposition> {
    let ndim = tensor.ndim();
    if ndim < 2 {
        // A vector is already a single TT core of shape (1, n, 1).
        let n = tensor.shape()[0];
        let core = Tensor::new(vec![1, n, 1], tensor.data().to_vec())?;
        return Ok(TtDecomposition { cores: vec![core] });
    }

    if config.max_rank == 0 {
        return Err(SolverError::InternalError(
            "TT max_rank must be positive".to_string(),
        ));
    }

    let shape = tensor.shape().to_vec();
    let mut cores: Vec<Tensor> = Vec::with_capacity(ndim);
    let mut remaining_data = tensor.data().to_vec();
    let mut r_prev = 1usize;

    for k in 0..(ndim - 1) {
        let n_k = shape[k];
        let remaining_size: usize = shape[(k + 1)..].iter().product();

        let rows = r_prev * n_k;
        let cols = remaining_size;

        let actual_len = remaining_data.len();
        if actual_len != rows * cols {
            return Err(SolverError::InternalError(format!(
                "TT-SVD reshape error at mode {}: expected {} elements, have {}",
                k,
                rows * cols,
                actual_len
            )));
        }

        let mat = Matrix::new(rows, cols, remaining_data)?;

        let trunc_rank = config.max_rank.min(rows).min(cols);
        let (u, sigma, v) = mat.svd_truncated(trunc_rank)?;

        // Shrink the rank further if trailing singular values are negligible.
        let total_sv_norm: f64 = sigma.iter().map(|s| s * s).sum::<f64>().sqrt();
        let mut effective_rank = trunc_rank;
        if total_sv_norm > 1e-15 {
            let mut accumulated = 0.0;
            for (i, &s) in sigma.iter().enumerate().rev() {
                accumulated += s * s;
                if accumulated.sqrt() / total_sv_norm > config.tolerance {
                    effective_rank = i + 1;
                    break;
                }
            }
        }
        effective_rank = effective_rank.min(trunc_rank);
        if effective_rank == 0 {
            effective_rank = 1;
        }

        let mut core_data = vec![0.0; r_prev * n_k * effective_rank];
        for i in 0..rows {
            for j in 0..effective_rank {
                core_data[i * effective_rank + j] = u.get(i, j);
            }
        }
        cores.push(Tensor::new(vec![r_prev, n_k, effective_rank], core_data)?);

        // Carry Sigma * V^T forward as the matrix to factor next.
        let new_cols = cols;
        let mut new_remaining = vec![0.0; effective_rank * new_cols];
        for i in 0..effective_rank {
            for j in 0..new_cols {
                new_remaining[i * new_cols + j] = sigma[i] * v.get(j, i);
            }
        }
        remaining_data = new_remaining;
        r_prev = effective_rank;
    }

    let n_last = shape[ndim - 1];
    if remaining_data.len() != r_prev * n_last {
        return Err(SolverError::InternalError(format!(
            "TT-SVD final reshape error: expected {} elements, have {}",
            r_prev * n_last,
            remaining_data.len()
        )));
    }
    let mut last_core_data = vec![0.0; r_prev * n_last];
    for i in 0..r_prev {
        for j in 0..n_last {
            last_core_data[i * n_last + j] = remaining_data[i * n_last + j];
        }
    }
    cores.push(Tensor::new(vec![r_prev, n_last, 1], last_core_data)?);

    Ok(TtDecomposition { cores })
}

#[cfg(test)]
mod tests {
    use super::*;

    fn make_test_tensor_3d() -> Tensor {
        let shape = vec![3, 4, 2];
        let data: Vec<f64> = (0..24).map(|i| (i as f64) * 0.5 + 1.0).collect();
        Tensor::new(shape, data).expect("failed to create test tensor")
    }

    fn make_rank1_tensor() -> Tensor {
        let a = [1.0, 2.0, 3.0];
        let b = [1.0, 2.0];
        let c = [1.0, 2.0, 3.0, 4.0];
        let shape = vec![3, 2, 4];
        let mut data = vec![0.0; 24];
        for i in 0..3 {
            for j in 0..2 {
                for k in 0..4 {
                    data[i * 8 + j * 4 + k] = a[i] * b[j] * c[k];
                }
            }
        }
        Tensor::new(shape, data).expect("failed to create rank-1 tensor")
    }

    #[test]
    fn test_tensor_creation_and_indexing() {
        let t = make_test_tensor_3d();
        assert_eq!(t.ndim(), 3);
        assert_eq!(t.shape(), &[3, 4, 2]);
        assert_eq!(t.numel(), 24);

        let v = t.get(&[0, 0, 0]).expect("get failed");
        assert!((v - 1.0).abs() < 1e-12);

        let v = t.get(&[2, 3, 1]).expect("get failed");
        assert!((v - 12.5).abs() < 1e-12);
    }

    #[test]
    fn test_tensor_set() {
        let mut t = make_test_tensor_3d();
        t.set(&[1, 2, 0], 99.0).expect("set failed");
        let v = t.get(&[1, 2, 0]).expect("get failed");
        assert!((v - 99.0).abs() < 1e-12);
    }

    #[test]
    fn test_tensor_index_out_of_range() {
        let t = make_test_tensor_3d();
        assert!(t.get(&[3, 0, 0]).is_err());
        assert!(t.get(&[0, 4, 0]).is_err());
        assert!(t.get(&[0, 0]).is_err());
    }

    #[test]
    fn test_mode_n_unfolding_and_folding_roundtrip() {
        let t = make_test_tensor_3d();
        for mode in 0..3 {
            let mat = t.unfold(mode).expect("unfold failed");

            assert_eq!(mat.rows, t.shape()[mode]);
            assert_eq!(mat.rows * mat.cols, t.numel());

            let recovered = Tensor::fold(&mat, mode, t.shape()).expect("fold failed");
            for i in 0..t.numel() {
                assert!(
                    (t.data()[i] - recovered.data()[i]).abs() < 1e-12,
                    "mismatch at element {} for mode {} unfold/fold",
                    i,
                    mode
                );
            }
        }
    }

    #[test]
    fn test_matrix_svd_truncated() {
        // Rank-1 test matrix: entry (i, j) = (i + 1) * (j + 1).
        let mut data = vec![0.0; 12];
        for i in 0..4 {
            for j in 0..3 {
                data[i * 3 + j] = (i + 1) as f64 * (j + 1) as f64;
            }
        }
        let m = Matrix::new(4, 3, data).expect("matrix creation failed");
        let (u, sigma, v) = m.svd_truncated(2).expect("svd failed");

        assert_eq!(u.rows, 4);
        assert_eq!(u.cols, 2);
        assert_eq!(sigma.len(), 2);
        assert_eq!(v.rows, 3);
        assert_eq!(v.cols, 2);

        assert!(sigma[0] >= sigma[1]);
        assert!(sigma[0] > 0.0);

        let mut reconstructed = Matrix::zeros(4, 3);
        for i in 0..4 {
            for j in 0..3 {
                let mut val = 0.0;
                for (r, sigma_r) in sigma.iter().enumerate().take(2) {
                    val += u.get(i, r) * sigma_r * v.get(j, r);
                }
                reconstructed.set(i, j, val);
            }
        }
        let error = (m.data)
            .iter()
            .zip(reconstructed.data.iter())
            .map(|(a, b)| (a - b) * (a - b))
            .sum::<f64>()
            .sqrt();
        let norm = m.frobenius_norm();
        assert!(
            error / norm < 0.05,
            "SVD reconstruction error too large: {}",
            error / norm
        );
    }

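    // A small extra sanity check (added as a sketch): transposing twice is
    // the identity, and multiplying by `eye` leaves a matrix unchanged.
    #[test]
    fn test_transpose_involution_and_identity() {
        let m = Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("matrix");
        let tt = m.transpose().transpose();
        for (a, b) in m.data.iter().zip(tt.data.iter()) {
            assert!((a - b).abs() < 1e-12);
        }

        let prod = m.matmul(&Matrix::eye(3)).expect("matmul");
        for (a, b) in m.data.iter().zip(prod.data.iter()) {
            assert!((a - b).abs() < 1e-12);
        }
    }
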
    #[test]
    fn test_khatri_rao_product() {
        let a = Matrix::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]).expect("a");
        let b = Matrix::new(3, 2, vec![5.0, 6.0, 7.0, 8.0, 9.0, 10.0]).expect("b");
        let kr = khatri_rao_product(&a, &b).expect("kr");

        assert_eq!(kr.rows, 6);
        assert_eq!(kr.cols, 2);

        assert!((kr.get(0, 0) - 5.0).abs() < 1e-12);
        assert!((kr.get(1, 0) - 7.0).abs() < 1e-12);
        assert!((kr.get(2, 0) - 9.0).abs() < 1e-12);
        assert!((kr.get(3, 0) - 15.0).abs() < 1e-12);
    }

    #[test]
    fn test_hadamard_product() {
        let a = Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("a");
        let b = Matrix::new(2, 3, vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]).expect("b");
        let h = hadamard_product(&a, &b).expect("hadamard");

        assert_eq!(h.rows, 2);
        assert_eq!(h.cols, 3);
        assert!((h.get(0, 0) - 7.0).abs() < 1e-12);
        assert!((h.get(1, 2) - 72.0).abs() < 1e-12);
    }

    #[test]
    fn test_hadamard_product_dimension_mismatch() {
        let a = Matrix::new(2, 3, vec![0.0; 6]).expect("a");
        let b = Matrix::new(3, 2, vec![0.0; 6]).expect("b");
        assert!(hadamard_product(&a, &b).is_err());
    }

    #[test]
    fn test_mode_n_product() {
        let t = make_test_tensor_3d();
        let m = Matrix::new(5, 4, vec![0.1; 20]).expect("matrix");
        let result = mode_n_product(&t, &m, 1).expect("mode_n_product");
        assert_eq!(result.shape(), &[3, 5, 2]);
    }

    #[test]
    fn test_cp_als_rank1_tensor() {
        let t = make_rank1_tensor();
        let config = CpAlsConfig {
            rank: 1,
            max_iterations: 200,
            tolerance: 1e-10,
            normalize_factors: true,
        };
        let decomp = cp_als(&t, &config).expect("cp_als failed");

        assert_eq!(decomp.rank(), 1);
        assert_eq!(decomp.factors.len(), 3);
        assert_eq!(decomp.factors[0].rows, 3);
        assert_eq!(decomp.factors[1].rows, 2);
        assert_eq!(decomp.factors[2].rows, 4);

        let error = decomp.fit_error(&t).expect("fit_error failed");
        assert!(error < 0.01, "CP-ALS rank-1 error too large: {}", error);
    }

    #[test]
    fn test_cp_als_rank3_tensor() {
        // Build a tensor as a sum of three rank-1 terms.
        let shape = vec![4, 3, 5];
        let numel = 60;
        let mut data = vec![0.0; numel];
        for r in 0..3 {
            for i in 0..4 {
                for j in 0..3 {
                    for k in 0..5 {
                        let a_val = ((i + 1) as f64 * (r + 1) as f64 * 0.3).sin();
                        let b_val = ((j + 1) as f64 * (r + 1) as f64 * 0.5).cos();
                        let c_val = ((k + 1) as f64 * (r + 1) as f64 * 0.7).sin();
                        data[i * 15 + j * 5 + k] += a_val * b_val * c_val;
                    }
                }
            }
        }
        let t = Tensor::new(shape, data).expect("tensor");

        let config = CpAlsConfig {
            rank: 3,
            max_iterations: 300,
            tolerance: 1e-10,
            normalize_factors: true,
        };
        let decomp = cp_als(&t, &config).expect("cp_als failed");
        assert_eq!(decomp.rank(), 3);

        let error = decomp.fit_error(&t).expect("fit_error");
        assert!(error < 0.5, "CP-ALS rank-3 error too large: {}", error);
    }

    #[test]
    fn test_cp_reconstruction() {
        let t = make_rank1_tensor();
        let config = CpAlsConfig {
            rank: 1,
            max_iterations: 200,
            tolerance: 1e-10,
            normalize_factors: true,
        };
        let decomp = cp_als(&t, &config).expect("cp_als");
        let recon = decomp.reconstruct().expect("reconstruct");

        assert_eq!(recon.shape(), t.shape());
        assert_eq!(recon.numel(), t.numel());
    }

    #[test]
    fn test_tucker_hosvd() {
        let t = make_test_tensor_3d();
        let config = TuckerConfig {
            ranks: vec![2, 3, 2],
            max_iterations: 50,
            tolerance: 1e-8,
        };
        let decomp = tucker_hosvd(&t, &config).expect("tucker_hosvd");

        assert_eq!(decomp.core.shape(), &[2, 3, 2]);
        assert_eq!(decomp.factors.len(), 3);
        assert_eq!(decomp.factors[0].rows, 3);
        assert_eq!(decomp.factors[0].cols, 2);
        assert_eq!(decomp.factors[1].rows, 4);
        assert_eq!(decomp.factors[1].cols, 3);
        assert_eq!(decomp.factors[2].rows, 2);
        assert_eq!(decomp.factors[2].cols, 2);

        let error = decomp.fit_error(&t).expect("fit_error");
        assert!(error < 0.5, "Tucker HOSVD error too large: {}", error);
    }

    #[test]
    fn test_tucker_hooi_convergence() {
        let shape = vec![6, 6, 6];
        let data: Vec<f64> = (0..216)
            .map(|i| ((i as f64) * 0.13 + 0.7).sin() * ((i as f64) * 0.07).cos())
            .collect();
        let t = Tensor::new(shape, data).expect("tensor");

        let config = TuckerConfig {
            ranks: vec![2, 2, 2],
            max_iterations: 30,
            tolerance: 1e-8,
        };

        let hosvd_decomp = tucker_hosvd(&t, &config).expect("hosvd");
        let hooi_decomp = tucker_hooi(&t, &config).expect("hooi");

        let hosvd_error = hosvd_decomp.fit_error(&t).expect("fit");
        let hooi_error = hooi_decomp.fit_error(&t).expect("fit");

        assert!(
            hooi_error <= hosvd_error + 0.05,
            "HOOI error {} should not be much worse than HOSVD error {}",
            hooi_error,
            hosvd_error
        );
    }

    #[test]
    fn test_tucker_compression_ratio() {
        let t = make_test_tensor_3d();
        let config = TuckerConfig {
            ranks: vec![2, 2, 2],
            max_iterations: 50,
            tolerance: 1e-8,
        };
        let decomp = tucker_hosvd(&t, &config).expect("tucker");
        let ratio = decomp.compression_ratio(t.shape());

        assert!(ratio > 0.0);
    }

    #[test]
    fn test_tt_svd_3d_tensor() {
        let t = make_test_tensor_3d();
        let config = TtConfig {
            max_rank: 10,
            tolerance: 1e-10,
        };
        let decomp = tt_svd(&t, &config).expect("tt_svd");

        assert_eq!(decomp.cores.len(), 3);

        assert_eq!(decomp.cores[0].shape()[0], 1);
        assert_eq!(decomp.cores[0].shape()[1], 3);

        assert_eq!(decomp.cores[2].shape()[1], 2);
        assert_eq!(decomp.cores[2].shape()[2], 1);

        let ranks = decomp.ranks();
        assert_eq!(ranks[0], 1);
        assert_eq!(*ranks.last().expect("no ranks"), 1);

        let recon = decomp.reconstruct().expect("reconstruct");
        assert_eq!(recon.shape(), t.shape());
    }

    #[test]
    fn test_tt_svd_4d_tensor() {
        let shape = vec![2, 3, 2, 4];
        let data: Vec<f64> = (0..48).map(|i| (i as f64 + 1.0) * 0.1).collect();
        let t = Tensor::new(shape, data).expect("tensor");

        let config = TtConfig {
            max_rank: 10,
            tolerance: 1e-10,
        };
        let decomp = tt_svd(&t, &config).expect("tt_svd");
        assert_eq!(decomp.cores.len(), 4);

        let ranks = decomp.ranks();
        assert_eq!(ranks[0], 1);
        assert_eq!(ranks[4], 1);
    }

    #[test]
    fn test_tt_reconstruction_error() {
        let t = make_rank1_tensor();
        let config = TtConfig {
            max_rank: 5,
            tolerance: 1e-10,
        };
        let decomp = tt_svd(&t, &config).expect("tt_svd");
        let error = decomp.fit_error(&t).expect("fit_error");

        assert!(error < 0.1, "TT reconstruction error too large: {}", error);
    }

    #[test]
    fn test_tt_rank_truncation() {
        let t = make_test_tensor_3d();
        let config = TtConfig {
            max_rank: 10,
            tolerance: 1e-10,
        };
        let decomp = tt_svd(&t, &config).expect("tt_svd");

        let truncated = decomp.tt_round(1).expect("tt_round");
        let trunc_ranks = truncated.ranks();
        for &r in &trunc_ranks {
            assert!(r <= 1, "rank {} exceeds max_rank 1", r);
        }
    }

    #[test]
    fn test_tt_compression_ratio() {
        let t = make_test_tensor_3d();
        let config = TtConfig {
            max_rank: 2,
            tolerance: 1e-10,
        };
        let decomp = tt_svd(&t, &config).expect("tt_svd");
        let ratio = decomp.compression_ratio(t.shape());
        assert!(ratio > 0.0);
    }

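    // Another small added check (a sketch): the Frobenius norm matches a
    // hand-computed value for both the Matrix and Tensor types.
    #[test]
    fn test_frobenius_norm_known_value() {
        // ||(1, 2, 2)|| = sqrt(1 + 4 + 4) = 3.
        let t = Tensor::new(vec![3], vec![1.0, 2.0, 2.0]).expect("tensor");
        assert!((t.frobenius_norm() - 3.0).abs() < 1e-12);

        // ||(3, 4, 0, 0)|| = 5.
        let m = Matrix::new(2, 2, vec![3.0, 4.0, 0.0, 0.0]).expect("matrix");
        assert!((m.frobenius_norm() - 5.0).abs() < 1e-12);
    }
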
    #[test]
    fn test_1d_tensor_vector() {
        let t = Tensor::new(vec![5], vec![1.0, 2.0, 3.0, 4.0, 5.0]).expect("1d tensor");
        assert_eq!(t.ndim(), 1);
        assert_eq!(t.numel(), 5);

        let v = t.get(&[3]).expect("get");
        assert!((v - 4.0).abs() < 1e-12);

        let config = TtConfig {
            max_rank: 5,
            tolerance: 1e-10,
        };
        let decomp = tt_svd(&t, &config).expect("tt_svd 1d");
        assert_eq!(decomp.cores.len(), 1);
        assert_eq!(decomp.cores[0].shape(), &[1, 5, 1]);
    }

    #[test]
    fn test_config_validation() {
        let t = make_test_tensor_3d();

        let config = CpAlsConfig {
            rank: 0,
            ..Default::default()
        };
        assert!(cp_als(&t, &config).is_err());

        let config = TuckerConfig {
            ranks: vec![2, 2],
            ..Default::default()
        };
        assert!(tucker_hosvd(&t, &config).is_err());

        let config = TtConfig {
            max_rank: 0,
            tolerance: 1e-8,
        };
        assert!(tt_svd(&t, &config).is_err());
    }
}