Skip to main content

oxiphysics_core/tensor/
operations.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5#![allow(clippy::needless_range_loop)]
6use super::types::*;
7
/// Einstein summation helper for common rank-2 tensor operations.
///
/// Supported notation strings:
/// - `"ij,jk->ik"` : matrix–matrix product
/// - `"ij,ij->"` : Frobenius inner product (returns scalar in a 1-element vec)
/// - `"ij->ji"` : transpose
/// - `"ii->"` : trace (returns scalar in a 1-element vec)
///
/// Empty operands yield an empty result for the matrix-shaped notations
/// instead of panicking.
///
/// # Panics
/// Panics on an unsupported notation string, or when a two-operand notation
/// is given `b = None`.
#[allow(dead_code)]
pub fn einsum_2d(notation: &str, a: &[Vec<f64>], b: Option<&[Vec<f64>]>) -> Vec<Vec<f64>> {
    let notation = notation.trim();
    match notation {
        "ij,jk->ik" => {
            let b = b.expect("einsum 'ij,jk->ik' requires two operands");
            // Guard: `a[0]` / `b[0]` below would panic on empty operands.
            if a.is_empty() || b.is_empty() {
                return vec![];
            }
            let m = a.len();
            let k = a[0].len();
            let n = b[0].len();
            let mut c = vec![vec![0.0; n]; m];
            // i-k-j loop order keeps the innermost loop contiguous over b's rows.
            for i in 0..m {
                for kk in 0..k {
                    for j in 0..n {
                        c[i][j] += a[i][kk] * b[kk][j];
                    }
                }
            }
            c
        }
        "ij,ij->" => {
            let b = b.expect("einsum 'ij,ij->' requires two operands");
            // Frobenius inner product: Σ_ij a_ij * b_ij.
            let s: f64 = a
                .iter()
                .zip(b.iter())
                .flat_map(|(ar, br)| ar.iter().zip(br.iter()).map(|(&x, &y)| x * y))
                .sum();
            vec![vec![s]]
        }
        "ij->ji" => {
            // Guard: `a[0]` below would panic on an empty operand.
            if a.is_empty() {
                return vec![];
            }
            let m = a.len();
            let n = a[0].len();
            let mut out = vec![vec![0.0; m]; n];
            for i in 0..m {
                for j in 0..n {
                    out[j][i] = a[i][j];
                }
            }
            out
        }
        "ii->" => {
            // Trace: sum of diagonal entries (empty input sums to 0.0).
            let s: f64 = a.iter().enumerate().map(|(i, row)| row[i]).sum();
            vec![vec![s]]
        }
        _ => panic!("einsum_2d: unsupported notation '{notation}'"),
    }
}
61/// Helper: multiply m×k matrix by k×n matrix.
62#[allow(dead_code)]
63pub(super) fn matmul(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
64    let m = a.len();
65    let k = a[0].len();
66    let n = b[0].len();
67    let mut c = vec![vec![0.0; n]; m];
68    for i in 0..m {
69        for p in 0..k {
70            for j in 0..n {
71                c[i][j] += a[i][p] * b[p][j];
72            }
73        }
74    }
75    c
76}
/// Helper: element-wise (Hadamard) product of two matrices of the same shape.
#[allow(dead_code)]
fn hadamard(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let mut out = Vec::with_capacity(a.len());
    for (row_a, row_b) in a.iter().zip(b) {
        // Multiply matching entries of the paired rows.
        let row: Vec<f64> = row_a.iter().zip(row_b).map(|(x, y)| x * y).collect();
        out.push(row);
    }
    out
}
85/// Helper: transpose a matrix.
86#[allow(dead_code)]
87pub(super) fn transpose(a: &[Vec<f64>]) -> Vec<Vec<f64>> {
88    if a.is_empty() {
89        return vec![];
90    }
91    let m = a.len();
92    let n = a[0].len();
93    let mut t = vec![vec![0.0; m]; n];
94    for i in 0..m {
95        for j in 0..n {
96            t[j][i] = a[i][j];
97        }
98    }
99    t
100}
/// Helper: solve a small linear system A x = b via Gaussian elimination (dense).
///
/// `a` is an r×r coefficient matrix and `b` an r×m right-hand side (all m
/// columns are solved simultaneously).  Returns the r×m solution matrix.
///
/// Uses Gauss–Jordan elimination with partial pivoting on the augmented
/// matrix [A | B].  NOTE(review): a (near-)singular pivot column is skipped
/// (`continue`) rather than reported, so the corresponding solution entries
/// remain 0 — callers get a best-effort answer, never an error.
#[allow(dead_code)]
fn solve_ls(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let r = a.len();
    let m = b[0].len();
    // Build the augmented matrix [A | B], one row at a time.
    let mut aug: Vec<Vec<f64>> = a
        .iter()
        .enumerate()
        .map(|(i, row)| {
            let mut r = row.clone();
            r.extend_from_slice(&b[i]);
            r
        })
        .collect();
    for col in 0..r {
        // Partial pivoting: bring the largest |entry| at/below the diagonal
        // of this column onto the diagonal.
        let mut pivot = col;
        for row in (col + 1)..r {
            if aug[row][col].abs() > aug[pivot][col].abs() {
                pivot = row;
            }
        }
        aug.swap(col, pivot);
        let d = aug[col][col];
        if d.abs() < 1e-14 {
            // Singular column — skip it (see NOTE in the doc comment).
            continue;
        }
        // Normalise the pivot row so the diagonal entry becomes 1.
        for j in col..r + m {
            aug[col][j] /= d;
        }
        // Eliminate this column from every other row (full Gauss–Jordan, so
        // no separate back-substitution pass is needed afterwards).
        for row in 0..r {
            if row == col {
                continue;
            }
            let factor = aug[row][col];
            for j in col..r + m {
                aug[row][j] -= factor * aug[col][j];
            }
        }
    }
    // The right-hand block of the augmented matrix now holds the solution.
    let mut x = vec![vec![0.0; m]; r];
    for i in 0..r {
        for j in 0..m {
            x[i][j] = aug[i][r + j];
        }
    }
    x
}
148/// Alternating Least Squares (ALS) for CP decomposition of a 3-way tensor.
149///
150/// Decomposes `tensor` (shape n0 × n1 × n2) into `rank` components.
151/// Returns the factor matrices A (n0×rank), B (n1×rank), C (n2×rank) and
152/// normalisation weights λ.
153///
154/// # Arguments
155/// * `tensor`   – reference to a `DenseTensor` with rank 3
156/// * `rank`     – number of CP components
157/// * `max_iter` – maximum ALS iterations
158/// * `tol`      – relative reconstruction error tolerance for convergence
159#[allow(dead_code)]
160pub fn cp_als(tensor: &DenseTensor, rank: usize, max_iter: usize, tol: f64) -> CpDecomposition {
161    assert_eq!(tensor.shape.len(), 3, "cp_als requires a rank-3 tensor");
162    let n0 = tensor.shape[0];
163    let n1 = tensor.shape[1];
164    let n2 = tensor.shape[2];
165    let init_val = 1.0 / rank as f64;
166    let mut a: Vec<Vec<f64>> = (0..n0)
167        .map(|i| {
168            (0..rank)
169                .map(|r| ((i + r) as f64 + 1.0) * init_val)
170                .collect()
171        })
172        .collect();
173    let mut b: Vec<Vec<f64>> = (0..n1)
174        .map(|j| (0..rank).map(|r| ((j + r + 1) as f64) * init_val).collect())
175        .collect();
176    let mut c: Vec<Vec<f64>> = (0..n2)
177        .map(|k| (0..rank).map(|r| ((k + r + 2) as f64) * init_val).collect())
178        .collect();
179    let x0 = tensor.mode_n_unfold(0);
180    let x1 = tensor.mode_n_unfold(1);
181    let x2 = tensor.mode_n_unfold(2);
182    let mut prev_err = f64::MAX;
183    for _iter in 0..max_iter {
184        let c_tc = gram(&c);
185        let b_tb = gram(&b);
186        let v_cb = hadamard(&c_tc, &b_tb);
187        let khatri_rao_cb = khatri_rao(&c, &b);
188        let rhs_a = matmul(&x0, &khatri_rao_cb);
189        let rhs_at = transpose(&rhs_a);
190        let sol_a = solve_ls(&v_cb, &rhs_at);
191        a = transpose(&sol_a);
192        let a_ta = gram(&a);
193        let v_ca = hadamard(&c_tc, &a_ta);
194        let khatri_rao_ca = khatri_rao(&c, &a);
195        let rhs_b = matmul(&x1, &khatri_rao_ca);
196        let rhs_bt = transpose(&rhs_b);
197        let sol_b = solve_ls(&v_ca, &rhs_bt);
198        b = transpose(&sol_b);
199        let b_tb2 = gram(&b);
200        let a_ta2 = gram(&a);
201        let v_ba = hadamard(&b_tb2, &a_ta2);
202        let khatri_rao_ba = khatri_rao(&b, &a);
203        let rhs_c = matmul(&x2, &khatri_rao_ba);
204        let rhs_ct = transpose(&rhs_c);
205        let sol_c = solve_ls(&v_ba, &rhs_ct);
206        c = transpose(&sol_c);
207        let mut lambdas = vec![1.0f64; rank];
208        for r in 0..rank {
209            let na: f64 = a.iter().map(|row| row[r] * row[r]).sum::<f64>().sqrt();
210            let nb: f64 = b.iter().map(|row| row[r] * row[r]).sum::<f64>().sqrt();
211            let nc: f64 = c.iter().map(|row| row[r] * row[r]).sum::<f64>().sqrt();
212            let lam = na * nb * nc;
213            lambdas[r] = lam;
214            if na > 1e-14 {
215                for row in &mut a {
216                    row[r] /= na;
217                }
218            }
219            if nb > 1e-14 {
220                for row in &mut b {
221                    row[r] /= nb;
222                }
223            }
224            if nc > 1e-14 {
225                for row in &mut c {
226                    row[r] /= nc;
227                }
228            }
229        }
230        let err = cp_reconstruction_error(tensor, &a, &b, &c, &lambdas);
231        if (prev_err - err).abs() / (prev_err.abs() + 1e-14) < tol {
232            return CpDecomposition { a, b, c, lambdas };
233        }
234        prev_err = err;
235    }
236    let lambdas = vec![1.0f64; rank];
237    CpDecomposition { a, b, c, lambdas }
238}
/// Compute Gram matrix A^T A of a factor matrix A (n×r) → (r×r).
#[allow(dead_code)]
fn gram(a: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let r = a[0].len();
    let mut g = vec![vec![0.0; r]; r];
    // Accumulate the outer product of each row with itself:
    // G[i][j] = Σ_rows row[i] * row[j].
    for row in a {
        for (i, g_row) in g.iter_mut().enumerate() {
            for (j, g_ij) in g_row.iter_mut().enumerate() {
                *g_ij += row[i] * row[j];
            }
        }
    }
    g
}
/// Khatri-Rao product (column-wise Kronecker product) of A (m×r) and B (n×r)
/// → (m*n)×r.
#[allow(dead_code)]
fn khatri_rao(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let r = a[0].len();
    let mut rows = Vec::with_capacity(a.len() * b.len());
    // Output row (i*n + j) pairs row i of A with row j of B, column-wise.
    for row_a in a {
        for row_b in b {
            let prod: Vec<f64> = (0..r).map(|col| row_a[col] * row_b[col]).collect();
            rows.push(prod);
        }
    }
    rows
}
270/// Frobenius norm of the reconstruction error for a CP decomposition.
271#[allow(dead_code)]
272fn cp_reconstruction_error(
273    tensor: &DenseTensor,
274    a: &[Vec<f64>],
275    b: &[Vec<f64>],
276    c: &[Vec<f64>],
277    lambdas: &[f64],
278) -> f64 {
279    let n0 = tensor.shape[0];
280    let n1 = tensor.shape[1];
281    let n2 = tensor.shape[2];
282    let rank = lambdas.len();
283    let mut err = 0.0;
284    for i in 0..n0 {
285        for j in 0..n1 {
286            for k in 0..n2 {
287                let approx: f64 = (0..rank)
288                    .map(|r| lambdas[r] * a[i][r] * b[j][r] * c[k][r])
289                    .sum();
290                let diff = tensor.get(&[i, j, k]) - approx;
291                err += diff * diff;
292            }
293        }
294    }
295    err.sqrt()
296}
/// Helper: extract leading `k` left singular vectors of a matrix via power
/// iteration on A A^T with Gram-Schmidt orthogonalisation.  Returns an n×k matrix.
///
/// Each vector is found by deflated power iteration: project out the vectors
/// already found, iterate w ← (A Aᵀ) v until the direction stabilises, then
/// store the result as column `r`.  `k` is silently capped at n.
#[allow(dead_code)]
pub(super) fn truncated_svd_left(a: &[Vec<f64>], k: usize) -> Vec<Vec<f64>> {
    let n = a.len();
    let m = a[0].len();
    // Form the n×n Gram matrix A·Aᵀ; its eigenvectors are A's left singular
    // vectors.
    let mut aat = vec![vec![0.0; n]; n];
    for i in 0..n {
        for j in 0..n {
            for p in 0..m {
                aat[i][j] += a[i][p] * a[j][p];
            }
        }
    }
    let k = k.min(n);
    let mut result = vec![vec![0.0f64; k]; n];
    let mut found: Vec<Vec<f64>> = Vec::new();
    for r in 0..k {
        // Start from the r-th standard basis vector …
        let mut v: Vec<f64> = (0..n).map(|i| if i == r { 1.0 } else { 0.0 }).collect();
        // … deflated against the vectors already found (Gram-Schmidt).
        for fv in &found {
            let dot: f64 = v.iter().zip(fv.iter()).map(|(a, b)| a * b).sum();
            for i in 0..n {
                v[i] -= dot * fv[i];
            }
        }
        let norm = v.iter().map(|x| x * x).sum::<f64>().sqrt();
        if norm < 1e-14 {
            // Degenerate start (e_r lies in the span of `found`): restart
            // from the next basis vector.  NOTE(review): the restart is not
            // re-orthogonalised here; the per-step deflation below corrects it.
            v = (0..n)
                .map(|i| if i == (r + 1) % n { 1.0 } else { 0.0 })
                .collect();
        } else {
            for vi in &mut v {
                *vi /= norm;
            }
        }
        // Power iteration, capped at 300 steps.
        for _ in 0..300 {
            // w = (A Aᵀ) v
            let mut w = vec![0.0; n];
            for i in 0..n {
                for j in 0..n {
                    w[i] += aat[i][j] * v[j];
                }
            }
            // Deflate each step so the iteration converges to the dominant
            // direction orthogonal to all previously found vectors.
            for fv in &found {
                let dot: f64 = w.iter().zip(fv.iter()).map(|(a, b)| a * b).sum();
                for i in 0..n {
                    w[i] -= dot * fv[i];
                }
            }
            let norm = w.iter().map(|x| x * x).sum::<f64>().sqrt();
            if norm < 1e-14 {
                // Remaining spectrum is (numerically) zero — keep current v.
                break;
            }
            let v_new: Vec<f64> = w.iter().map(|x| x / norm).collect();
            // Stop once successive iterates agree to ~1e-12.
            let diff: f64 = v_new
                .iter()
                .zip(v.iter())
                .map(|(a, b)| (a - b).powi(2))
                .sum::<f64>()
                .sqrt();
            v = v_new;
            if diff < 1e-12 {
                break;
            }
        }
        // Store the converged vector as column r of the result.
        for i in 0..n {
            result[i][r] = v[i];
        }
        found.push(v);
    }
    result
}
/// Higher-order SVD (HOSVD) Tucker decomposition of a 3-way tensor.
///
/// Computes a Tucker-(r0, r1, r2) approximation using sequentially truncated
/// SVD on each unfolding.
///
/// # Arguments
/// * `tensor` – 3-way `DenseTensor`
/// * `ranks`  – target rank for each mode (r0, r1, r2)
///
/// # Panics
/// Panics if `tensor` is not rank 3.
///
/// NOTE(review): `truncated_svd_left` caps its column count at the matrix
/// height, so a requested rank larger than the corresponding dimension would
/// make the core-computation loop index out of bounds — confirm callers pass
/// ranks ≤ shape.
#[allow(dead_code)]
pub fn tucker_hosvd(tensor: &DenseTensor, ranks: [usize; 3]) -> TuckerDecomposition {
    assert_eq!(
        tensor.shape.len(),
        3,
        "tucker_hosvd requires a rank-3 tensor"
    );
    let [r0, r1, r2] = ranks;
    // Factor matrices: leading left singular vectors of each mode-n unfolding.
    let x0 = tensor.mode_n_unfold(0);
    let x1 = tensor.mode_n_unfold(1);
    let x2 = tensor.mode_n_unfold(2);
    let u0 = truncated_svd_left(&x0, r0);
    let u1 = truncated_svd_left(&x1, r1);
    let u2 = truncated_svd_left(&x2, r2);
    let n0 = tensor.shape[0];
    let n1 = tensor.shape[1];
    let n2 = tensor.shape[2];
    // Core tensor G = T ×₀ U0ᵀ ×₁ U1ᵀ ×₂ U2ᵀ, computed entry by entry.
    // O(r0·r1·r2·n0·n1·n2) — fine for small tensors only.
    let mut core = DenseTensor::zeros(&[r0, r1, r2]);
    for i in 0..r0 {
        for j in 0..r1 {
            for k in 0..r2 {
                let mut val = 0.0;
                for ii in 0..n0 {
                    for jj in 0..n1 {
                        for kk in 0..n2 {
                            val += u0[ii][i] * u1[jj][j] * u2[kk][k] * tensor.get(&[ii, jj, kk]);
                        }
                    }
                }
                core.set(&[i, j, k], val);
            }
        }
    }
    TuckerDecomposition { core, u0, u1, u2 }
}
411/// Reconstruct a 3-way tensor from a Tucker decomposition.
412#[allow(dead_code)]
413pub fn tucker_reconstruct(td: &TuckerDecomposition) -> DenseTensor {
414    let r0 = td.core.shape[0];
415    let r1 = td.core.shape[1];
416    let r2 = td.core.shape[2];
417    let n0 = td.u0.len();
418    let n1 = td.u1.len();
419    let n2 = td.u2.len();
420    let mut out = DenseTensor::zeros(&[n0, n1, n2]);
421    for i in 0..n0 {
422        for j in 0..n1 {
423            for k in 0..n2 {
424                let mut val = 0.0;
425                for ii in 0..r0 {
426                    for jj in 0..r1 {
427                        for kk in 0..r2 {
428                            val += td.u0[i][ii]
429                                * td.u1[j][jj]
430                                * td.u2[k][kk]
431                                * td.core.get(&[ii, jj, kk]);
432                        }
433                    }
434                }
435                out.set(&[i, j, k], val);
436            }
437        }
438    }
439    out
440}
441/// Compute the bilinear form v^T A w for a Tensor2.
442///
443/// Returns the scalar `Σ_ij v_i A_ij w_j`.
444#[allow(dead_code)]
445pub fn bilinear_form(a: &Tensor2, v: &[f64; 3], w: &[f64; 3]) -> f64 {
446    let mut acc = 0.0f64;
447    for i in 0..3 {
448        for j in 0..3 {
449            acc += v[i] * a.data[i][j] * w[j];
450        }
451    }
452    acc
453}
454/// Compute the quadratic form v^T A v for a Tensor2.
455#[allow(dead_code)]
456pub fn quadratic_form(a: &Tensor2, v: &[f64; 3]) -> f64 {
457    bilinear_form(a, v, v)
458}
459/// Outer product of two 3-vectors: result_ij = a_i b_j (returns Tensor2).
460#[allow(dead_code)]
461pub fn vec_outer(a: &[f64; 3], b: &[f64; 3]) -> Tensor2 {
462    let mut d = [[0.0f64; 3]; 3];
463    for i in 0..3 {
464        for j in 0..3 {
465            d[i][j] = a[i] * b[j];
466        }
467    }
468    Tensor2 { data: d }
469}
/// Tensor contraction: contract modes `mode_a` of `a` and `mode_b` of `b`
/// over their shared dimension.
///
/// For rank-2 tensors this reduces to matrix–matrix multiplication when
/// `mode_a == 1` (col) and `mode_b == 0` (row).
///
/// The output shape is a's modes (minus `mode_a`) followed by b's modes
/// (minus `mode_b`); a full contraction of two rank-1 tensors yields a
/// 1-element tensor of shape `[1]`.
///
/// NOTE(review): brute-force O(|a|·|b|) loop over all flat index pairs,
/// filtering on matching contracted indices — fine for small tensors,
/// expensive for large ones.
///
/// # Panics
/// Panics if the contracted dimensions differ.
#[allow(dead_code)]
pub fn tensor_contraction(
    a: &DenseTensor,
    mode_a: usize,
    b: &DenseTensor,
    mode_b: usize,
) -> DenseTensor {
    let dim_a = a.shape[mode_a];
    let dim_b = b.shape[mode_b];
    assert_eq!(dim_a, dim_b, "contracted dimensions must match");
    // Output shape: a's surviving modes in order, then b's surviving modes.
    let mut out_shape: Vec<usize> = Vec::new();
    for (k, &s) in a.shape.iter().enumerate() {
        if k != mode_a {
            out_shape.push(s);
        }
    }
    for (k, &s) in b.shape.iter().enumerate() {
        if k != mode_b {
            out_shape.push(s);
        }
    }
    // A scalar result is represented as a 1-element rank-1 tensor.
    if out_shape.is_empty() {
        out_shape.push(1);
    }
    let n_out: usize = out_shape.iter().product();
    let mut out_data = vec![0.0f64; n_out];
    let strides_a = a.strides();
    let strides_b = b.strides();
    // Strides for flattening output multi-indices (last mode fastest,
    // consistent with how `strides()` is used to decode below).
    let out_rank = out_shape.len();
    let mut out_strides = vec![1usize; out_rank];
    for k in (0..out_rank.saturating_sub(1)).rev() {
        out_strides[k] = out_strides[k + 1] * out_shape[k + 1];
    }
    let total_a: usize = a.data.len();
    let total_b: usize = b.data.len();
    for fa in 0..total_a {
        // Decode a's flat index into a multi-index via its strides.
        let mut tmp = fa;
        let mut midx_a = vec![0usize; a.shape.len()];
        for k in 0..a.shape.len() {
            midx_a[k] = tmp / strides_a[k];
            tmp %= strides_a[k];
        }
        let contract_val_a = midx_a[mode_a];
        for fb in 0..total_b {
            // Decode b's flat index likewise.
            let mut tmp2 = fb;
            let mut midx_b = vec![0usize; b.shape.len()];
            for k in 0..b.shape.len() {
                midx_b[k] = tmp2 / strides_b[k];
                tmp2 %= strides_b[k];
            }
            // Only pairs agreeing on the contracted index contribute.
            if midx_b[mode_b] != contract_val_a {
                continue;
            }
            // Assemble the output multi-index from the surviving modes.
            let mut oidx = vec![0usize; out_rank];
            let mut oi = 0;
            for k in 0..a.shape.len() {
                if k != mode_a {
                    oidx[oi] = midx_a[k];
                    oi += 1;
                }
            }
            for k in 0..b.shape.len() {
                if k != mode_b {
                    oidx[oi] = midx_b[k];
                    oi += 1;
                }
            }
            // Re-flatten the output index and accumulate the product.
            let fo: usize = oidx
                .iter()
                .zip(out_strides.iter())
                .map(|(&i, &s)| i * s)
                .sum();
            out_data[fo] += a.data[fa] * b.data[fb];
        }
    }
    DenseTensor {
        shape: out_shape,
        data: out_data,
    }
}
555/// Compute the general outer product of two DenseTensors: result has rank = rank(a) + rank(b).
556#[allow(dead_code)]
557pub fn tensor_outer(a: &DenseTensor, b: &DenseTensor) -> DenseTensor {
558    let mut shape = a.shape.clone();
559    shape.extend_from_slice(&b.shape);
560    let n_out: usize = shape.iter().product();
561    let mut data = vec![0.0f64; n_out];
562    let stride_b: usize = b.data.len();
563    for (ia, &va) in a.data.iter().enumerate() {
564        for (ib, &vb) in b.data.iter().enumerate() {
565            data[ia * stride_b + ib] = va * vb;
566        }
567    }
568    DenseTensor { shape, data }
569}
570/// Check whether a DenseTensor is symmetric under swap of two given modes.
571#[allow(dead_code)]
572pub fn is_symmetric_modes(t: &DenseTensor, mode1: usize, mode2: usize, tol: f64) -> bool {
573    assert_eq!(
574        t.shape[mode1], t.shape[mode2],
575        "modes must have equal dimension"
576    );
577    let n = t.shape[mode1];
578    let strides = t.strides();
579    let total = t.data.len();
580    for flat in 0..total {
581        let mut tmp = flat;
582        let mut midx = vec![0usize; t.shape.len()];
583        for k in 0..t.shape.len() {
584            midx[k] = tmp / strides[k];
585            tmp %= strides[k];
586        }
587        if midx[mode1] >= midx[mode2] {
588            continue;
589        }
590        let mut midx2 = midx.clone();
591        midx2[mode1] = midx[mode2];
592        midx2[mode2] = midx[mode1];
593        let _n = n;
594        let flat2: usize = midx2.iter().zip(strides.iter()).map(|(&i, &s)| i * s).sum();
595        if (t.data[flat] - t.data[flat2]).abs() > tol {
596            return false;
597        }
598    }
599    true
600}
/// Build a fully symmetric DenseTensor by averaging over all permutations of indices
/// (only for rank-2 and rank-3 tensors for now; rank-2 reduces to symmetrising a matrix).
///
/// Tensors of any other rank are returned unchanged (as a clone).
///
/// # Panics
/// Panics if a rank-2 input is not square, or a rank-3 input is not cubic.
#[allow(dead_code)]
pub fn symmetrize_tensor(t: &DenseTensor) -> DenseTensor {
    let rank = t.shape.len();
    match rank {
        2 => {
            let n = t.shape[0];
            assert_eq!(t.shape[1], n, "rank-2 symmetrize requires square tensor");
            let mut out = DenseTensor::zeros(&t.shape);
            for i in 0..n {
                for j in 0..n {
                    // Matrix symmetrisation: (A + Aᵀ) / 2.
                    let v = 0.5 * (t.get(&[i, j]) + t.get(&[j, i]));
                    out.set(&[i, j], v);
                }
            }
            out
        }
        3 => {
            let n = t.shape[0];
            assert!(
                t.shape.iter().all(|&s| s == n),
                "rank-3 symmetrize requires cubic tensor"
            );
            let mut out = DenseTensor::zeros(&t.shape);
            // All 6 permutations of three indices (the symmetric group S₃).
            let perms: [[usize; 3]; 6] = [
                [0, 1, 2],
                [0, 2, 1],
                [1, 0, 2],
                [1, 2, 0],
                [2, 0, 1],
                [2, 1, 0],
            ];
            for i in 0..n {
                for j in 0..n {
                    for k in 0..n {
                        let idx_orig = [i, j, k];
                        // Average the entry over every permutation of (i, j, k).
                        let s: f64 = perms
                            .iter()
                            .map(|p| {
                                let perm_idx = [idx_orig[p[0]], idx_orig[p[1]], idx_orig[p[2]]];
                                t.get(&perm_idx)
                            })
                            .sum::<f64>()
                            / 6.0;
                        out.set(&[i, j, k], s);
                    }
                }
            }
            out
        }
        // Other ranks: no-op copy.
        _ => t.clone(),
    }
}
655/// The 4th-order symmetric identity tensor I^S_ijkl = (delta_ik delta_jl + delta_il delta_jk)/2.
656///
657/// This is the projector onto symmetric second-order tensors:
658/// I^S : A = sym(A) = (A + A^T)/2.
659#[allow(dead_code)]
660pub fn fourth_order_symmetric_identity() -> Tensor4 {
661    let mut t = Tensor4::zero();
662    for i in 0..3 {
663        for j in 0..3 {
664            for k in 0..3 {
665                for l in 0..3 {
666                    let v = 0.5
667                        * (if i == k && j == l { 1.0 } else { 0.0 }
668                            + if i == l && j == k { 1.0 } else { 0.0 });
669                    t.data[i][j][k][l] = v;
670                }
671            }
672        }
673    }
674    t
675}
676/// The 4th-order skew (antisymmetric) identity tensor:
677/// I^A_ijkl = (delta_ik delta_jl - delta_il delta_jk)/2.
678#[allow(dead_code)]
679pub fn fourth_order_skew_identity() -> Tensor4 {
680    let mut t = Tensor4::zero();
681    for i in 0..3 {
682        for j in 0..3 {
683            for k in 0..3 {
684                for l in 0..3 {
685                    let v = 0.5
686                        * (if i == k && j == l { 1.0 } else { 0.0 }
687                            - if i == l && j == k { 1.0 } else { 0.0 });
688                    t.data[i][j][k][l] = v;
689                }
690            }
691        }
692    }
693    t
694}
695/// Compute the double inner product C :: T for a 4th-order elasticity tensor C
696/// and a symmetric second-order tensor T.  Returns sigma_ij = C_ijkl T_kl.
697#[allow(dead_code)]
698pub fn elasticity_stress(c: &Tensor4, strain: &Tensor2) -> Tensor2 {
699    c.double_contract_2(strain)
700}
701/// Compliance tensor from stiffness: S = C^{-1} in Voigt form (6×6 matrix inversion).
702///
703/// Returns `None` if the Voigt matrix is singular (det ≈ 0).
704#[allow(dead_code)]
705pub fn compliance_from_stiffness(c: &Tensor4) -> Option<[[f64; 6]; 6]> {
706    let m = KelvinTensor::from_tensor4(c);
707    invert_6x6(&m)
708}
/// Invert a 6×6 matrix using Gauss–Jordan elimination.
///
/// Returns `None` when a pivot is numerically zero (|pivot| < 1e-14),
/// i.e. the matrix is singular or nearly so.
#[allow(dead_code)]
pub fn invert_6x6(m: &[[f64; 6]; 6]) -> Option<[[f64; 6]; 6]> {
    const N: usize = 6;
    // Reduce a working copy of `m` to the identity while applying the same
    // row operations to `inverse` (which starts as the identity).
    let mut work = *m;
    let mut inverse = [[0.0f64; N]; N];
    for (d, row) in inverse.iter_mut().enumerate() {
        row[d] = 1.0;
    }
    for col in 0..N {
        // Partial pivoting: largest |entry| in this column at/below the diagonal.
        let mut best = col;
        for r in (col + 1)..N {
            if work[r][col].abs() > work[best][col].abs() {
                best = r;
            }
        }
        if work[best][col].abs() < 1e-14 {
            return None; // singular (or numerically so)
        }
        work.swap(col, best);
        inverse.swap(col, best);
        // Scale the pivot row so the diagonal entry becomes 1.
        let piv = work[col][col];
        for j in 0..N {
            work[col][j] /= piv;
            inverse[col][j] /= piv;
        }
        // Eliminate this column from every other row.
        for r in 0..N {
            if r == col {
                continue;
            }
            let f = work[r][col];
            for j in 0..N {
                work[r][j] -= f * work[col][j];
                inverse[r][j] -= f * inverse[col][j];
            }
        }
    }
    Some(inverse)
}
751/// Build the Eshelby inclusion tensor for an isotropic matrix.
752///
753/// Valid for a spherical inclusion (aspect ratio = 1).
754/// Returns the Eshelby tensor S_ijkl.
755///
756/// Reference: J.D. Eshelby (1957), nu = Poisson's ratio.
757#[allow(dead_code)]
758pub fn eshelby_sphere(nu: f64) -> Tensor4 {
759    let mut s = Tensor4::zero();
760    let s_iiii = (7.0 - 5.0 * nu) / (15.0 * (1.0 - nu));
761    let s_iijj = (5.0 * nu - 1.0) / (15.0 * (1.0 - nu));
762    let s_ijij = (4.0 - 5.0 * nu) / (15.0 * (1.0 - nu));
763    for i in 0..3 {
764        s.data[i][i][i][i] = s_iiii;
765    }
766    for i in 0..3 {
767        for j in 0..3 {
768            if i != j {
769                s.data[i][i][j][j] = s_iijj;
770                s.data[i][j][i][j] = s_ijij;
771                s.data[i][j][j][i] = s_ijij;
772            }
773        }
774    }
775    s
776}
777#[cfg(test)]
778mod tests {
779    use super::*;
780    const EPS: f64 = 1e-12;
781    fn approx_eq(a: f64, b: f64) -> bool {
782        (a - b).abs() < EPS
783    }
784    fn tensor_approx_eq(a: &Tensor2, b: &Tensor2) -> bool {
785        for i in 0..3 {
786            for j in 0..3 {
787                if !approx_eq(a.data[i][j], b.data[i][j]) {
788                    return false;
789                }
790            }
791        }
792        true
793    }
794    #[test]
795    fn test_identity_double_contract_equals_trace() {
796        let a = Tensor2::new([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]);
797        let id = Tensor2::identity();
798        let dc = id.double_contract(&a);
799        assert!(
800            approx_eq(dc, a.trace()),
801            "I:A should equal tr(A), got {dc} vs {}",
802            a.trace()
803        );
804    }
805    #[test]
806    fn test_det_identity_is_one() {
807        let id = Tensor2::identity();
808        assert!(approx_eq(id.det(), 1.0), "det(I) should be 1");
809    }
810    #[test]
811    fn test_inverse_of_identity_is_identity() {
812        let id = Tensor2::identity();
813        let inv = id.inverse().expect("identity is invertible");
814        assert!(
815            tensor_approx_eq(&inv, &Tensor2::identity()),
816            "inv(I) should be I"
817        );
818    }
819    #[test]
820    fn test_outer_product_e1_e2() {
821        let t = Tensor2::outer_product([1.0, 0.0, 0.0], [0.0, 1.0, 0.0]);
822        let mut expected = [[0.0f64; 3]; 3];
823        expected[0][1] = 1.0;
824        assert!(tensor_approx_eq(&t, &Tensor2::new(expected)));
825    }
826    #[test]
827    fn test_von_mises_pure_shear() {
828        let tau = 1.0f64;
829        let t = Tensor2::new([[0.0, tau, 0.0], [tau, 0.0, 0.0], [0.0, 0.0, 0.0]]);
830        let vm = t.von_mises();
831        let expected = 3.0f64.sqrt() * tau;
832        assert!(
833            (vm - expected).abs() < 1e-10,
834            "von Mises for pure shear: got {vm}, expected {expected}"
835        );
836    }
837    #[test]
838    fn test_isotropic_c_double_contract_identity() {
839        let lambda = 1.0f64;
840        let mu = 1.0f64;
841        let c = Tensor4::isotropic(lambda, mu);
842        let id = Tensor2::identity();
843        let result = c.double_contract_2(&id);
844        let expected = Tensor2::identity().scale(3.0 * lambda + 2.0 * mu);
845        assert!(
846            tensor_approx_eq(&result, &expected),
847            "C(λ=1,μ=1):I should be (3λ+2μ)I = 5I"
848        );
849    }
850    #[test]
851    fn test_eigenvalues_diagonal_matrix() {
852        let t = Tensor2::new([[2.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 5.0]]);
853        let ev = t.eigenvalues_symmetric();
854        assert!(
855            (ev[0] - 2.0).abs() < 1e-10
856                && (ev[1] - 3.0).abs() < 1e-10
857                && (ev[2] - 5.0).abs() < 1e-10,
858            "eigenvalues should be [2,3,5], got {ev:?}"
859        );
860    }
861    #[test]
862    fn test_deviatoric_trace_is_zero() {
863        let t = Tensor2::new([[3.0, 1.0, 0.5], [1.0, 2.0, 0.3], [0.5, 0.3, 4.0]]);
864        let dev = t.deviatoric();
865        assert!(
866            dev.trace().abs() < EPS,
867            "trace of deviatoric should be 0, got {}",
868            dev.trace()
869        );
870    }
    // Symmetric tensor -> Voigt 6-vector -> tensor must be lossless.
    #[test]
    fn test_voigt_round_trip() {
        let t = Tensor2::new([[1.0, 2.0, 3.0], [2.0, 4.0, 5.0], [3.0, 5.0, 6.0]]);
        let v = VoigtTensor::from_tensor2(&t);
        let t2 = VoigtTensor::to_tensor2(&v);
        assert!(tensor_approx_eq(&t, &t2), "Voigt round-trip failed");
    }
    // I . v = v: contracting the identity tensor with a vector is a no-op.
    #[test]
    fn test_contract_vec() {
        let id = Tensor2::identity();
        let v = [1.0, 2.0, 3.0];
        let result = id.contract_vec(v);
        assert!(approx_eq(result[0], 1.0));
        assert!(approx_eq(result[1], 2.0));
        assert!(approx_eq(result[2], 3.0));
    }
    // (a (x) b)_ij = a_i * b_j: with a = e1, b = e2 only entry (0,1) is 1.
    #[test]
    fn test_dyadic_product() {
        let a = [1.0, 0.0, 0.0];
        let b = [0.0, 1.0, 0.0];
        let t = Tensor2::dyadic(a, b);
        assert!(approx_eq(t.data[0][1], 1.0));
        assert!(approx_eq(t.data[0][0], 0.0));
    }
    // Rotating by the identity rotation must leave the tensor unchanged.
    #[test]
    fn test_rotate_identity() {
        let a = Tensor2::new([[1.0, 2.0, 0.0], [2.0, 3.0, 0.0], [0.0, 0.0, 4.0]]);
        let r = Tensor2::identity();
        let rotated = a.rotate(&r);
        assert!(
            tensor_approx_eq(&a, &rotated),
            "Rotating by I should be identity"
        );
    }
    // A 90-degree rotation about z maps a uniaxial sigma_xx state onto sigma_yy.
    #[test]
    fn test_rotate_90_degrees_z() {
        let r = Tensor2::new([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]);
        let a = Tensor2::diagonal([1.0, 0.0, 0.0]);
        let rotated = a.rotate(&r);
        assert!(
            approx_eq(rotated.data[1][1], 1.0),
            "After 90° rotation, sigma_xx should become sigma_yy"
        );
        assert!(
            rotated.data[0][0].abs() < EPS,
            "sigma_xx should be 0 after rotation"
        );
    }
    // First invariant I1 = tr(T) = 1 + 2 + 3 = 6.
    #[test]
    fn test_invariant_i1() {
        let t = Tensor2::diagonal([1.0, 2.0, 3.0]);
        assert!(approx_eq(t.invariant_i1(), 6.0));
    }
    // Second invariant = sum of principal 2x2 minors: 1*2 + 2*3 + 1*3 = 11.
    #[test]
    fn test_invariant_i2_diagonal() {
        let t = Tensor2::diagonal([1.0, 2.0, 3.0]);
        assert!(
            approx_eq(t.invariant_i2(), 11.0),
            "I2 = {}",
            t.invariant_i2()
        );
    }
    // Third invariant is the determinant; for a diagonal tensor that is the
    // product of the diagonal entries.
    #[test]
    fn test_invariant_i3_diagonal() {
        let t = Tensor2::diagonal([1.0, 2.0, 3.0]);
        assert!(approx_eq(t.invariant_i3(), 6.0), "I3 = det = 1*2*3 = 6");
    }
    // For the identity the principal invariants are (I1, I2, I3) = (3, 3, 1).
    #[test]
    fn test_principal_invariants() {
        let t = Tensor2::identity();
        let inv = t.principal_invariants();
        assert!(approx_eq(inv[0], 3.0));
        assert!(approx_eq(inv[1], 3.0));
        assert!(approx_eq(inv[2], 1.0));
    }
    // Hydrostatic part is tr/3 = (3 + 6 + 9)/3 = 6; the deviator is traceless.
    #[test]
    fn test_decompose_dev_hydro() {
        let t = Tensor2::diagonal([3.0, 6.0, 9.0]);
        let (dev, h) = t.decompose_dev_hydro();
        assert!(approx_eq(h, 6.0), "hydrostatic = {h}");
        assert!(
            dev.trace().abs() < EPS,
            "deviatoric trace = {}",
            dev.trace()
        );
    }
    // In this codebase's Voigt ordering slot 3 populates the symmetric
    // off-diagonal pair (0,1)/(1,0).
    #[test]
    fn test_from_voigt() {
        let v = [1.0, 2.0, 3.0, 0.5, 0.3, 0.1];
        let t = Tensor2::from_voigt(v);
        assert!(approx_eq(t.data[0][0], 1.0));
        assert!(approx_eq(t.data[1][1], 2.0));
        assert!(approx_eq(t.data[0][1], 0.5));
        assert!(approx_eq(t.data[1][0], 0.5));
    }
    // Inverse mapping: v[0] = sigma_11 and v[3] = sigma_12.
    #[test]
    fn test_to_voigt() {
        let t = Tensor2::new([[1.0, 0.5, 0.1], [0.5, 2.0, 0.3], [0.1, 0.3, 3.0]]);
        let v = t.to_voigt();
        assert!(approx_eq(v[0], 1.0));
        assert!(approx_eq(v[3], 0.5));
    }
    // The effective (scalar) strain of the zero tensor is zero.
    #[test]
    fn test_effective_strain() {
        let t = Tensor2::zero();
        assert!(
            t.effective_strain() < EPS,
            "zero tensor should give zero strain"
        );
    }
    // diagonal() places the three values on the diagonal and zeros elsewhere.
    #[test]
    fn test_diagonal_constructor() {
        let t = Tensor2::diagonal([1.0, 2.0, 3.0]);
        assert!(approx_eq(t.data[0][0], 1.0));
        assert!(approx_eq(t.data[1][1], 2.0));
        assert!(approx_eq(t.data[2][2], 3.0));
        assert!(approx_eq(t.data[0][1], 0.0));
    }
    // is_symmetric must accept a symmetric matrix and reject a general one.
    #[test]
    fn test_is_symmetric() {
        let t = Tensor2::new([[1.0, 2.0, 3.0], [2.0, 4.0, 5.0], [3.0, 5.0, 6.0]]);
        assert!(t.is_symmetric(1e-14));
        let t2 = Tensor2::new([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]);
        assert!(!t2.is_symmetric(1e-14));
    }
    // exp(0) = I for the (approximate) matrix exponential.
    #[test]
    fn test_matrix_exp_identity() {
        let zero = Tensor2::zero();
        let result = zero.matrix_exp_approx();
        assert!(
            tensor_approx_eq(&result, &Tensor2::identity()),
            "exp(0) = I"
        );
    }
    // The deviatoric projector annihilates spherical tensors: P_dev : I = 0.
    #[test]
    fn test_deviatoric_projector() {
        let p = Tensor4::deviatoric_projector();
        let id = Tensor2::identity();
        let result = p.double_contract_2(&id);
        assert!(
            result.norm() < 1e-10,
            "P_dev : I should be zero: norm = {}",
            result.norm()
        );
    }
    // Isotropic stiffness with lambda = mu = 1 in Voigt matrix form:
    // C_1111 = lambda + 2mu = 3, C_1122 = lambda = 1, C_1212 = mu = 1.
    #[test]
    fn test_tensor4_to_voigt_matrix_isotropic() {
        let lambda = 1.0;
        let mu = 1.0;
        let c = Tensor4::isotropic(lambda, mu);
        let m = c.to_voigt_matrix();
        assert!(approx_eq(m[0][0], 3.0), "C_1111 = {}", m[0][0]);
        assert!(approx_eq(m[0][1], 1.0), "C_1122 = {}", m[0][1]);
        assert!(approx_eq(m[3][3], 1.0), "C_1212 = {}", m[3][3]);
    }
    // scale(2) must double every one of the 81 components.
    #[test]
    fn test_tensor4_scale() {
        let c = Tensor4::isotropic(1.0, 1.0);
        let c2 = c.scale(2.0);
        for i in 0..3 {
            for j in 0..3 {
                for k in 0..3 {
                    for l in 0..3 {
                        assert!(approx_eq(c2.data[i][j][k][l], 2.0 * c.data[i][j][k][l]));
                    }
                }
            }
        }
    }
    // Isotropic tensors add parameter-wise: iso(1,1) + iso(2,3) = iso(3,4).
    #[test]
    fn test_tensor4_add() {
        let c1 = Tensor4::isotropic(1.0, 1.0);
        let c2 = Tensor4::isotropic(2.0, 3.0);
        let c3 = c1.add(&c2);
        let expected = Tensor4::isotropic(3.0, 4.0);
        for i in 0..3 {
            for j in 0..3 {
                for k in 0..3 {
                    for l in 0..3 {
                        assert!(approx_eq(c3.data[i][j][k][l], expected.data[i][j][k][l]));
                    }
                }
            }
        }
    }
    // Tensorial -> engineering shear strains and back must round-trip exactly.
    #[test]
    fn test_engineering_strain_round_trip() {
        let v = [0.01, 0.02, 0.03, 0.005, 0.003, 0.001];
        let eng = VoigtTensor::to_engineering_strain(&v);
        let back = VoigtTensor::from_engineering_strain(&eng);
        for i in 0..6 {
            assert!(approx_eq(v[i], back[i]), "round trip failed at {i}");
        }
    }
    // Pure shear sigma_12 = 1 (Voigt slot 3) gives von Mises stress sqrt(3).
    #[test]
    fn test_von_mises_voigt() {
        let v = [0.0, 0.0, 0.0, 1.0, 0.0, 0.0];
        let vm = VoigtTensor::von_mises_voigt(&v);
        let expected = 3.0_f64.sqrt();
        assert!(
            (vm - expected).abs() < 1e-10,
            "von Mises = {vm}, expected {expected}"
        );
    }
    // p = -tr(sigma)/3: uniform compression of -100 gives pressure +100.
    #[test]
    fn test_hydrostatic_pressure() {
        let v = [-100.0, -100.0, -100.0, 0.0, 0.0, 0.0];
        let p = VoigtTensor::hydrostatic_pressure(&v);
        assert!(approx_eq(p, 100.0), "pressure = {p}, expected 100.0");
    }
    // A = sym(A) + skew(A) holds for any (even non-symmetric) tensor.
    #[test]
    fn test_sym_and_skew_add_to_original() {
        let a = Tensor2::new([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]);
        let sym = a.sym();
        let skew = a.skew();
        let reconstructed = sym.add(&skew);
        assert!(
            tensor_approx_eq(&a, &reconstructed),
            "sym + skew should equal original"
        );
    }
    // Isotropic stiffness has the minor symmetries C_ijkl = C_jikl = C_ijlk.
    #[test]
    fn test_tensor4_has_minor_symmetry_isotropic() {
        let c = Tensor4::isotropic(1.0, 1.0);
        assert!(
            c.has_minor_symmetry(1e-12),
            "isotropic tensor should have minor symmetry"
        );
    }
    // Isotropic stiffness has the major symmetry C_ijkl = C_klij.
    #[test]
    fn test_tensor4_has_major_symmetry_isotropic() {
        let c = Tensor4::isotropic(2.0, 3.0);
        assert!(
            c.has_major_symmetry(1e-12),
            "isotropic tensor should have major symmetry"
        );
    }
    // Symmetrization is a projection: applying it twice changes nothing.
    #[test]
    fn test_tensor4_symmetrize_minor_idempotent() {
        let c = Tensor4::isotropic(1.0, 1.0);
        let cs = c.symmetrize_minor();
        let cs2 = cs.symmetrize_minor();
        for i in 0..3 {
            for j in 0..3 {
                for k in 0..3 {
                    for l in 0..3 {
                        assert!(approx_eq(cs.data[i][j][k][l], cs2.data[i][j][k][l]));
                    }
                }
            }
        }
    }
    // Same idempotence check for the major symmetrization.
    #[test]
    fn test_tensor4_symmetrize_major_idempotent() {
        let c = Tensor4::isotropic(1.0, 1.0);
        let cs = c.symmetrize_major();
        let cs2 = cs.symmetrize_major();
        for i in 0..3 {
            for j in 0..3 {
                for k in 0..3 {
                    for l in 0..3 {
                        assert!(approx_eq(cs.data[i][j][k][l], cs2.data[i][j][k][l]));
                    }
                }
            }
        }
    }
    // Double-contracting with the rank-4 identity must reproduce C: C :: I4 = C.
    #[test]
    fn test_tensor4_double_contract_4_identity() {
        let c = Tensor4::isotropic(1.0, 1.0);
        let id4 = Tensor4::identity();
        let result = c.double_contract_4(&id4);
        for i in 0..3 {
            for j in 0..3 {
                for k in 0..3 {
                    for l in 0..3 {
                        assert!(
                            approx_eq(result.data[i][j][k][l], c.data[i][j][k][l]),
                            "C::I4 should equal C at [{i}][{j}][{k}][{l}]"
                        );
                    }
                }
            }
        }
    }
    // Rotating a rank-4 tensor by the identity rotation leaves it unchanged.
    #[test]
    fn test_tensor4_rotate_identity_unchanged() {
        let c = Tensor4::isotropic(1.0, 1.0);
        let r = Tensor2::identity();
        let c_rot = c.rotate(&r);
        for i in 0..3 {
            for j in 0..3 {
                for k in 0..3 {
                    for l in 0..3 {
                        assert!(
                            (c_rot.data[i][j][k][l] - c.data[i][j][k][l]).abs() < 1e-10,
                            "rotating by I should be unchanged"
                        );
                    }
                }
            }
        }
    }
    // Frobenius norm of a nonzero tensor must be strictly positive.
    #[test]
    fn test_tensor4_norm_positive() {
        let c = Tensor4::isotropic(1.0, 1.0);
        assert!(c.norm() > 0.0, "Frobenius norm should be positive");
    }
    // Isotropic tensors subtract parameter-wise: iso(2,2) - iso(1,1) = iso(1,1).
    #[test]
    fn test_tensor4_sub() {
        let c1 = Tensor4::isotropic(2.0, 2.0);
        let c2 = Tensor4::isotropic(1.0, 1.0);
        let diff = c1.sub(&c2);
        let expected = Tensor4::isotropic(1.0, 1.0);
        for i in 0..3 {
            for j in 0..3 {
                for k in 0..3 {
                    for l in 0..3 {
                        assert!(approx_eq(diff.data[i][j][k][l], expected.data[i][j][k][l]));
                    }
                }
            }
        }
    }
    // The deviatoric and volumetric projectors partition the symmetric
    // identity: P_dev + P_vol = I_sym, componentwise.
    #[test]
    fn test_tensor4_deviatoric_plus_volumetric_is_identity_sym() {
        let pd = Tensor4::deviatoric_projector();
        let pv = Tensor4::volumetric_projector();
        let sum = pd.add(&pv);
        let id_sym = Tensor4::identity_sym();
        for i in 0..3 {
            for j in 0..3 {
                for k in 0..3 {
                    for l in 0..3 {
                        assert!(
                            (sum.data[i][j][k][l] - id_sym.data[i][j][k][l]).abs() < 1e-12,
                            "P_dev + P_vol != I_sym at [{i}][{j}][{k}][{l}]"
                        );
                    }
                }
            }
        }
    }
    // Kelvin 6x6 of iso(1,1): normal-normal entry lambda + 2mu = 3, and the
    // shear diagonal carries 2mu = 2 (Kelvin keeps the factor on the matrix).
    #[test]
    fn test_kelvin_from_tensor4_isotropic_diagonal() {
        let c = Tensor4::isotropic(1.0, 1.0);
        let km = KelvinTensor::from_tensor4(&c);
        assert!(
            (km[0][0] - 3.0).abs() < 1e-10,
            "K[0][0] = {} expected 3",
            km[0][0]
        );
        assert!(
            (km[3][3] - 2.0).abs() < 1e-10,
            "K[3][3] = {} expected 2",
            km[3][3]
        );
    }
    // Rank-4 tensor -> Kelvin 6x6 -> rank-4 must be lossless.
    #[test]
    fn test_kelvin_round_trip() {
        let c = Tensor4::isotropic(2.0, 3.0);
        let km = KelvinTensor::from_tensor4(&c);
        let c2 = KelvinTensor::to_tensor4(&km);
        for i in 0..3 {
            for j in 0..3 {
                for k in 0..3 {
                    for l in 0..3 {
                        assert!(
                            (c.data[i][j][k][l] - c2.data[i][j][k][l]).abs() < 1e-10,
                            "Kelvin round-trip failed at [{i}][{j}][{k}][{l}]"
                        );
                    }
                }
            }
        }
    }
    // Rank-2 stress -> Kelvin 6-vector -> stress must be lossless.
    #[test]
    fn test_kelvin_stress_round_trip() {
        let sigma = Tensor2::new([[1.0, 0.5, 0.3], [0.5, 2.0, 0.2], [0.3, 0.2, 3.0]]);
        let kv = KelvinTensor::stress_to_kelvin(&sigma);
        let sigma2 = KelvinTensor::kelvin_to_stress(&kv);
        assert!(
            tensor_approx_eq(&sigma, &sigma2),
            "Kelvin stress round-trip failed"
        );
    }
    // In Kelvin form, C : I must equal the closed-form (3*lambda + 2*mu) * I
    // for isotropic elasticity, with no spurious shear response.
    #[test]
    fn test_kelvin_matvec_isotropic() {
        let lambda = 1.0_f64;
        let mu = 1.0_f64;
        let c = Tensor4::isotropic(lambda, mu);
        let km = KelvinTensor::from_tensor4(&c);
        let id_kelvin = KelvinTensor::stress_to_kelvin(&Tensor2::identity());
        let result_kelvin = KelvinTensor::matvec(&km, &id_kelvin);
        let result = KelvinTensor::kelvin_to_stress(&result_kelvin);
        let expected_val = 3.0 * lambda + 2.0 * mu;
        assert!((result.data[0][0] - expected_val).abs() < 1e-10);
        assert!((result.data[1][1] - expected_val).abs() < 1e-10);
        assert!((result.data[0][1]).abs() < 1e-10);
    }
    // Rotation matrices are orthogonal: R * R^T = I for a z-rotation.
    #[test]
    fn test_rotation_z_is_orthogonal() {
        use std::f64::consts::PI;
        let r = TensorBasis::rotation_z(PI / 6.0);
        let rt = r.transpose();
        let rrt = r.dot(&rt);
        assert!(
            tensor_approx_eq(&rrt, &Tensor2::identity()),
            "R*R^T should be I"
        );
    }
    // Same orthogonality check for an x-rotation.
    #[test]
    fn test_rotation_x_is_orthogonal() {
        let r = TensorBasis::rotation_x(0.7);
        let rrt = r.dot(&r.transpose());
        assert!(tensor_approx_eq(&rrt, &Tensor2::identity()));
    }
    // Rotating a 6x6 Voigt stiffness matrix by the identity leaves it unchanged.
    #[test]
    fn test_rotate_voigt_stiffness_identity_rotation() {
        let c = Tensor4::isotropic(1.0, 1.0);
        let m = c.to_voigt_matrix();
        let r = Tensor2::identity();
        let m_rot = TensorBasis::rotate_voigt_stiffness(&m, &r);
        for p in 0..6 {
            for q in 0..6 {
                assert!(
                    (m[p][q] - m_rot[p][q]).abs() < 1e-10,
                    "rotation by I should not change stiffness matrix at [{p}][{q}]"
                );
            }
        }
    }
    // Single right contraction with the identity is a no-op: (C . I) = C.
    #[test]
    fn test_tensor4_single_contract_right_identity() {
        let c = Tensor4::isotropic(1.0, 1.0);
        let id = Tensor2::identity();
        let result = c.single_contract_right(&id);
        for i in 0..3 {
            for j in 0..3 {
                for k in 0..3 {
                    for l in 0..3 {
                        assert!(
                            (result.data[i][j][k][l] - c.data[i][j][k][l]).abs() < 1e-10,
                            "C·I should equal C at [{i}][{j}][{k}][{l}]"
                        );
                    }
                }
            }
        }
    }
    // Tensor3::zero must produce an all-zero 3x3x3 array.
    #[test]
    fn test_tensor3_zero() {
        let t = Tensor3::zero();
        for i in 0..3 {
            for j in 0..3 {
                for k in 0..3 {
                    assert_eq!(t.data[i][j][k], 0.0, "Tensor3::zero should be all zeros");
                }
            }
        }
    }
    // NOTE(review): despite the name this writes/reads the raw `data` field
    // directly; if Tensor3 grows set/get accessors this test should use them.
    #[test]
    fn test_tensor3_set_get() {
        let mut t = Tensor3::zero();
        t.data[1][2][0] = 3.125;
        assert!((t.data[1][2][0] - 3.125).abs() < 1e-12);
    }
    // Levi-Civita symbol: +1 on even permutations of (0,1,2), -1 on odd,
    // 0 when any index repeats.
    #[test]
    fn test_tensor3_levi_civita() {
        let eps = Tensor3::levi_civita();
        assert!((eps.data[0][1][2] - 1.0).abs() < EPS, "ε_012 should be 1");
        assert!((eps.data[0][2][1] + 1.0).abs() < EPS, "ε_021 should be -1");
        assert!((eps.data[1][0][2] + 1.0).abs() < EPS, "ε_102 should be -1");
        assert!((eps.data[1][2][0] - 1.0).abs() < EPS, "ε_120 should be 1");
        assert!((eps.data[2][0][1] - 1.0).abs() < EPS, "ε_201 should be 1");
        assert!((eps.data[2][1][0] + 1.0).abs() < EPS, "ε_210 should be -1");
        assert!((eps.data[0][0][0]).abs() < EPS, "ε_000 should be 0");
    }
    // Contracting the last index of ε with e1 selects the ε_ij0 slice.
    #[test]
    fn test_tensor3_contract_vec_with_levi_civita() {
        let eps = Tensor3::levi_civita();
        let v = [1.0_f64, 0.0, 0.0];
        let result = eps.contract_last(&v);
        assert!(
            (result.data[1][2] - 1.0).abs() < EPS,
            "ε_12k*e_0_k = ε_120 = 1"
        );
        assert!(
            (result.data[2][1] + 1.0).abs() < EPS,
            "ε_21k*e_0_k = ε_210 = -1"
        );
    }
    // Scaling by 2 doubles each component.
    #[test]
    fn test_tensor3_scale() {
        let eps = Tensor3::levi_civita();
        let scaled = eps.scale(2.0);
        assert!(
            (scaled.data[0][1][2] - 2.0).abs() < EPS,
            "scaled ε_012 should be 2"
        );
    }
    // Adding the zero tensor is the identity operation, componentwise.
    #[test]
    fn test_tensor3_add_zero() {
        let eps = Tensor3::levi_civita();
        let zero = Tensor3::zero();
        let sum = eps.add(&zero);
        for i in 0..3 {
            for j in 0..3 {
                for k in 0..3 {
                    assert!(
                        (sum.data[i][j][k] - eps.data[i][j][k]).abs() < EPS,
                        "add with zero should not change tensor at [{i}][{j}][{k}]"
                    );
                }
            }
        }
    }
    // NOTE(review): this checks cyclic-permutation invariance
    // (ε_ijk = ε_jki = ε_kij), a consequence of total antisymmetry; pairwise
    // swap checks (ε_ijk = -ε_jik) would pin the full property.
    #[test]
    fn test_tensor3_totally_antisymmetric_levi_civita() {
        let eps = Tensor3::levi_civita();
        for i in 0..3 {
            for j in 0..3 {
                for k in 0..3 {
                    let e_ijk = eps.data[i][j][k];
                    let e_jki = eps.data[j][k][i];
                    let e_kij = eps.data[k][i][j];
                    assert!(
                        (e_ijk - e_jki).abs() < EPS,
                        "ε_ijk != ε_jki for ({i},{j},{k})"
                    );
                    assert!(
                        (e_ijk - e_kij).abs() < EPS,
                        "ε_ijk != ε_kij for ({i},{j},{k})"
                    );
                }
            }
        }
    }
    // Mandel slots 0..2 carry the diagonal components without any scaling.
    #[test]
    fn test_mandel_normal_components_unchanged() {
        let sigma = Tensor2::new([[1.0, 0.5, 0.3], [0.5, 2.0, 0.2], [0.3, 0.2, 3.0]]);
        let mv = MandelNotation::from_tensor2(&sigma);
        assert!((mv[0] - 1.0).abs() < EPS, "Mandel[0] = σ_11 = 1");
        assert!((mv[1] - 2.0).abs() < EPS, "Mandel[1] = σ_22 = 2");
        assert!((mv[2] - 3.0).abs() < EPS, "Mandel[2] = σ_33 = 3");
    }
1417    #[test]
1418    fn test_mandel_shear_scaled_by_sqrt2() {
1419        let sigma = Tensor2::new([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]]);
1420        let mv = MandelNotation::from_tensor2(&sigma);
1421        let s2 = std::f64::consts::SQRT_2;
1422        assert!((mv[3] - s2).abs() < EPS, "Mandel[3] = √2 * σ_12");
1423    }
    // Tensor -> Mandel 6-vector -> tensor must be lossless.
    #[test]
    fn test_mandel_round_trip() {
        let sigma = Tensor2::new([[1.0, 0.5, 0.3], [0.5, 2.0, 0.2], [0.3, 0.2, 3.0]]);
        let mv = MandelNotation::from_tensor2(&sigma);
        let sigma2 = MandelNotation::to_tensor2(&mv);
        assert!(
            tensor_approx_eq(&sigma, &sigma2),
            "Mandel round-trip failed"
        );
    }
    // Mandel notation is an isometry: the plain dot product of two Mandel
    // vectors equals the tensors' double contraction A:B.
    #[test]
    fn test_mandel_inner_product_preserved() {
        let a = Tensor2::new([[1.0, 0.5, 0.0], [0.5, 2.0, 0.0], [0.0, 0.0, 3.0]]);
        let b = Tensor2::new([[2.0, 0.1, 0.0], [0.1, 1.0, 0.0], [0.0, 0.0, 0.5]]);
        let ab_direct = a.double_contract(&b);
        let ma = MandelNotation::from_tensor2(&a);
        let mb = MandelNotation::from_tensor2(&b);
        let ab_mandel: f64 = ma.iter().zip(mb.iter()).map(|(x, y)| x * y).sum();
        assert!(
            (ab_direct - ab_mandel).abs() < 1e-10,
            "Mandel inner product should match direct double contraction: {} vs {}",
            ab_direct,
            ab_mandel
        );
    }
    // Special case of the isometry: the Euclidean norm of the Mandel vector
    // equals the Frobenius norm of the tensor.
    #[test]
    fn test_mandel_norm_equals_frobenius() {
        let a = Tensor2::new([[1.0, 0.5, 0.3], [0.5, 2.0, 0.2], [0.3, 0.2, 3.0]]);
        let frob_sq: f64 = (0..3)
            .flat_map(|i| (0..3).map(move |j| a.data[i][j] * a.data[i][j]))
            .sum();
        let mv = MandelNotation::from_tensor2(&a);
        let mandel_sq: f64 = mv.iter().map(|x| x * x).sum();
        assert!(
            (frob_sq - mandel_sq).abs() < 1e-10,
            "Mandel norm² should equal Frobenius norm²: {} vs {}",
            frob_sq,
            mandel_sq
        );
    }
    // Every eigenvalue of the identity is 1.
    #[test]
    fn test_spectral_decomp_identity() {
        let id = Tensor2::identity();
        let evals = id.eigenvalues_symmetric();
        for &e in &evals {
            assert!(
                (e - 1.0).abs() < 1e-8,
                "Identity eigenvalue should be 1, got {e}"
            );
        }
    }
    // eigenvalues_symmetric returns ascending order: diag(3,1,2) -> [1,2,3].
    #[test]
    fn test_spectral_decomp_diagonal_sorted() {
        let d = Tensor2::diagonal([3.0, 1.0, 2.0]);
        let evals = d.eigenvalues_symmetric();
        assert!(
            evals[0] <= evals[1] && evals[1] <= evals[2],
            "Eigenvalues should be sorted ascending: {:?}",
            evals
        );
        assert!((evals[0] - 1.0).abs() < 1e-8);
        assert!((evals[1] - 2.0).abs() < 1e-8);
        assert!((evals[2] - 3.0).abs() < 1e-8);
    }
    // The 2x2 block [[2,1],[1,2]] has eigenvalues 1 and 3; the decoupled
    // z-entry contributes another 1, so the spectrum is {1, 1, 3}.
    // (The defensive re-sort also guards against an unsorted return order.)
    #[test]
    fn test_spectral_decomp_stress_tensor() {
        let sigma = Tensor2::new([[2.0, 1.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 1.0]]);
        let evals = sigma.eigenvalues_symmetric();
        let mut sorted = evals;
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert!(
            (sorted[0] - 1.0).abs() < 1e-6,
            "eval[0] = {} expected 1",
            sorted[0]
        );
        assert!(
            (sorted[1] - 1.0).abs() < 1e-6,
            "eval[1] = {} expected 1",
            sorted[1]
        );
        assert!(
            (sorted[2] - 3.0).abs() < 1e-6,
            "eval[2] = {} expected 3",
            sorted[2]
        );
    }
    // Major symmetry of C shows up as symmetry of its 6x6 Voigt matrix.
    #[test]
    fn test_voigt6x6_isotropic_symmetry() {
        let c = Tensor4::isotropic(1.0, 1.0);
        let m = c.to_voigt_matrix();
        for p in 0..6 {
            for q in 0..6 {
                assert!(
                    (m[p][q] - m[q][p]).abs() < 1e-10,
                    "Voigt 6x6 should be symmetric at [{p}][{q}]: {} vs {}",
                    m[p][q],
                    m[q][p]
                );
            }
        }
    }
    // A physically admissible stiffness matrix has a strictly positive diagonal.
    #[test]
    fn test_voigt6x6_positive_diagonal() {
        let c = Tensor4::isotropic(1.0, 1.0);
        let m = c.to_voigt_matrix();
        for p in 0..6 {
            assert!(
                m[p][p] > 0.0,
                "Voigt diagonal M[{p}][{p}] = {} should be > 0",
                m[p][p]
            );
        }
    }
    // Spot-check the closed-form isotropic entries for lambda = mu = 1.
    #[test]
    fn test_voigt6x6_isotropic_values() {
        let c = Tensor4::isotropic(1.0, 1.0);
        let m = c.to_voigt_matrix();
        assert!(
            (m[0][0] - 3.0).abs() < 1e-10,
            "C_1111 = λ+2μ = 3, got {}",
            m[0][0]
        );
        assert!(
            (m[0][1] - 1.0).abs() < 1e-10,
            "C_1122 = λ = 1, got {}",
            m[0][1]
        );
        assert!(
            (m[3][3] - 1.0).abs() < 1e-10,
            "C_1212 = μ = 1, got {}",
            m[3][3]
        );
    }
    // The double contraction A:B is symmetric in its arguments.
    #[test]
    fn test_double_contract_commutative_for_symmetric() {
        let a = Tensor2::new([[1.0, 0.5, 0.0], [0.5, 2.0, 0.0], [0.0, 0.0, 3.0]]);
        let b = Tensor2::new([[2.0, 0.1, 0.0], [0.1, 1.0, 0.0], [0.0, 0.0, 0.5]]);
        let ab = a.double_contract(&b);
        let ba = b.double_contract(&a);
        assert!(
            (ab - ba).abs() < 1e-10,
            "A:B should equal B:A for symmetric tensors"
        );
    }
    // Reassembling dev(sigma) + p*I must reproduce the original tensor.
    #[test]
    fn test_deviatoric_plus_hydrostatic_equals_original() {
        let sigma = Tensor2::new([[3.0, 1.0, 0.5], [1.0, 4.0, 0.2], [0.5, 0.2, 5.0]]);
        let (dev, p) = sigma.decompose_dev_hydro();
        let hydro = Tensor2::identity().scale(p);
        let reconstructed = dev.add(&hydro);
        assert!(
            tensor_approx_eq(&reconstructed, &sigma),
            "dev + hydro should equal original tensor"
        );
    }
    // The deviatoric part has zero trace by construction.
    #[test]
    fn test_deviatoric_is_traceless() {
        let sigma = Tensor2::new([[3.0, 1.0, 0.5], [1.0, 4.0, 0.2], [0.5, 0.2, 5.0]]);
        let dev = sigma.deviatoric();
        let tr = dev.trace();
        assert!(
            tr.abs() < 1e-10,
            "Deviatoric tensor should be traceless, got {tr}"
        );
    }
    // NOTE(review): near-duplicate of test_tensor4_has_major_symmetry_isotropic
    // above (different tolerance only) — candidates for consolidation.
    #[test]
    fn test_major_symmetry_isotropic() {
        let c = Tensor4::isotropic(2.0, 3.0);
        assert!(
            c.has_major_symmetry(1e-10),
            "Isotropic C should have major symmetry"
        );
    }
    // NOTE(review): near-duplicate of test_tensor4_has_minor_symmetry_isotropic
    // above (different parameters/tolerance) — candidates for consolidation.
    #[test]
    fn test_minor_symmetry_isotropic() {
        let c = Tensor4::isotropic(2.0, 3.0);
        assert!(
            c.has_minor_symmetry(1e-10),
            "Isotropic C should have minor symmetry"
        );
    }
    // zeros(&[2,3,4]) allocates 2*3*4 = 24 zero entries with the given shape.
    #[test]
    fn test_dense_tensor_zeros() {
        let t = DenseTensor::zeros(&[2, 3, 4]);
        assert_eq!(t.shape, vec![2, 3, 4]);
        assert_eq!(t.data.len(), 24);
        assert!(t.data.iter().all(|&x| x == 0.0));
    }
    // from_data uses row-major layout: for shape [2,3], index [1,2] -> flat 5.
    #[test]
    fn test_dense_tensor_from_data() {
        let data: Vec<f64> = (0..6).map(|x| x as f64).collect();
        let t = DenseTensor::from_data(&[2, 3], data.clone());
        assert_eq!(t.get(&[0, 0]), 0.0);
        assert_eq!(t.get(&[0, 1]), 1.0);
        assert_eq!(t.get(&[1, 2]), 5.0);
    }
    // set followed by get must round-trip, without disturbing other entries.
    #[test]
    fn test_dense_tensor_set_get() {
        let mut t = DenseTensor::zeros(&[3, 3]);
        t.set(&[1, 2], 42.0);
        assert_eq!(t.get(&[1, 2]), 42.0);
        assert_eq!(t.get(&[0, 0]), 0.0);
    }
    // ||[1,2,3,4]||_F = sqrt(1 + 4 + 9 + 16) = sqrt(30).
    #[test]
    fn test_dense_tensor_frobenius_norm() {
        let t = DenseTensor::from_data(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let expected = (1.0f64 + 4.0 + 9.0 + 16.0).sqrt();
        assert!(
            (t.frobenius_norm() - expected).abs() < 1e-12,
            "Frobenius norm mismatch: {} vs {}",
            t.frobenius_norm(),
            expected
        );
    }
    // Mode-n unfolding of a [2,3] tensor: mode 0 gives 2x3, mode 1 gives 3x2.
    #[test]
    fn test_mode_n_unfold_matrix() {
        let t = DenseTensor::from_data(&[2, 3], vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
        let m0 = t.mode_n_unfold(0);
        assert_eq!(m0.len(), 2);
        assert_eq!(m0[0].len(), 3);
        let m1 = t.mode_n_unfold(1);
        assert_eq!(m1.len(), 3);
        assert_eq!(m1[0].len(), 2);
    }
    // Unfold then fold must reproduce the original data for every mode.
    #[test]
    fn test_mode_n_fold_roundtrip() {
        let data: Vec<f64> = (0..24).map(|x| x as f64).collect();
        let t = DenseTensor::from_data(&[2, 3, 4], data.clone());
        for mode in 0..3 {
            let unfolded = t.mode_n_unfold(mode);
            let refolded = DenseTensor::mode_n_fold(&unfolded, mode, &[2, 3, 4]);
            for (a, b) in t.data.iter().zip(refolded.data.iter()) {
                assert!(
                    (a - b).abs() < 1e-12,
                    "mode_{mode} fold round-trip failed: {a} vs {b}"
                );
            }
        }
    }
1663    #[test]
1664    fn test_einsum_matmul() {
1665        let a = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
1666        let b = vec![vec![5.0, 6.0], vec![7.0, 8.0]];
1667        let c = einsum_2d("ij,jk->ik", &a, Some(&b));
1668        assert!((c[0][0] - 19.0).abs() < 1e-12);
1669        assert!((c[0][1] - 22.0).abs() < 1e-12);
1670        assert!((c[1][0] - 43.0).abs() < 1e-12);
1671        assert!((c[1][1] - 50.0).abs() < 1e-12);
1672    }
1673    #[test]
1674    fn test_einsum_frobenius_inner_product() {
1675        let a = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
1676        let b = vec![vec![2.0, 3.0], vec![4.0, 5.0]];
1677        let r = einsum_2d("ij,ij->", &a, Some(&b));
1678        assert!((r[0][0] - 7.0).abs() < 1e-12);
1679    }
1680    #[test]
1681    fn test_einsum_transpose() {
1682        let a = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
1683        let t = einsum_2d("ij->ji", &a, None);
1684        assert_eq!(t.len(), 3);
1685        assert_eq!(t[0].len(), 2);
1686        assert!((t[0][0] - 1.0).abs() < 1e-12);
1687        assert!((t[2][1] - 6.0).abs() < 1e-12);
1688    }
1689    #[test]
1690    fn test_einsum_trace() {
1691        let a = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
1692        let r = einsum_2d("ii->", &a, None);
1693        assert!(
1694            (r[0][0] - 5.0).abs() < 1e-12,
1695            "trace should be 5, got {}",
1696            r[0][0]
1697        );
1698    }
    // Build an exactly rank-1 tensor (outer product a x b x c); a rank-1
    // CP-ALS fit should then reconstruct it to within 10% relative error.
    #[test]
    fn test_cp_als_rank1_reconstruction() {
        let a = [1.0, 2.0];
        let b = [1.0, 2.0, 3.0];
        let c = [1.0, 2.0];
        let mut data = vec![];
        // Row-major fill matching DenseTensor's layout: index order (i, j, k).
        for &ai in &a {
            for &bj in &b {
                for &ck in &c {
                    data.push(ai * bj * ck);
                }
            }
        }
        let tensor = DenseTensor::from_data(&[2, 3, 2], data.clone());
        let fro_norm = tensor.frobenius_norm();
        let cp = cp_als(&tensor, 1, 500, 1e-12);
        let err = cp_reconstruction_error(&tensor, &cp.a, &cp.b, &cp.c, &cp.lambdas);
        let rel_err = err / (fro_norm + 1e-14);
        assert!(
            rel_err < 0.1,
            "CP rank-1 relative error should be < 10%, got {rel_err}"
        );
    }
    // Factor matrices are (mode size) x rank; lambdas has one weight per rank.
    #[test]
    fn test_cp_als_returns_correct_shapes() {
        let tensor = DenseTensor::zeros(&[3, 4, 5]);
        let cp = cp_als(&tensor, 2, 10, 1e-6);
        assert_eq!(cp.a.len(), 3);
        assert_eq!(cp.b.len(), 4);
        assert_eq!(cp.c.len(), 5);
        assert_eq!(cp.lambdas.len(), 2);
    }
    // The zero tensor must be fit (near-)exactly by any CP model.
    #[test]
    fn test_cp_reconstruction_error_zeros() {
        let tensor = DenseTensor::zeros(&[2, 2, 2]);
        let cp = cp_als(&tensor, 1, 5, 1e-10);
        let err = cp_reconstruction_error(&tensor, &cp.a, &cp.b, &cp.c, &cp.lambdas);
        assert!(err < 1e-6, "Zero tensor CP error should be tiny, got {err}");
    }
    // HOSVD with target ranks [2,2,2]: core is 2x2x2 and each factor matrix
    // is (mode size) x (target rank).
    #[test]
    fn test_tucker_hosvd_shape() {
        let data: Vec<f64> = (0..24).map(|x| x as f64).collect();
        let tensor = DenseTensor::from_data(&[2, 3, 4], data);
        let td = tucker_hosvd(&tensor, [2, 2, 2]);
        assert_eq!(td.core.shape, vec![2, 2, 2]);
        assert_eq!(td.u0.len(), 2);
        assert_eq!(td.u0[0].len(), 2);
        assert_eq!(td.u1.len(), 3);
        assert_eq!(td.u2.len(), 4);
    }
    // Full-rank Tucker of a [2,2,2] tensor should reconstruct it.
    // NOTE(review): 0.5 is a very loose bound — a full-rank HOSVD should
    // reconstruct near machine precision; tighten once the SVD backend is
    // trusted to converge.
    #[test]
    fn test_tucker_hosvd_full_rank_reconstruction() {
        let data: Vec<f64> = (0..8).map(|x| x as f64).collect();
        let tensor = DenseTensor::from_data(&[2, 2, 2], data);
        let fro = tensor.frobenius_norm();
        let td = tucker_hosvd(&tensor, [2, 2, 2]);
        let recon = tucker_reconstruct(&td);
        let err: f64 = tensor
            .data
            .iter()
            .zip(recon.data.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f64>()
            .sqrt();
        let rel_err = err / (fro + 1e-14);
        assert!(
            rel_err < 0.5,
            "Tucker full-rank reconstruction relative error: {rel_err} (err={err}, fro={fro})"
        );
    }
    // Reconstruction restores the original shape even from a truncated core.
    #[test]
    fn test_tucker_reconstruct_shape() {
        let data: Vec<f64> = (0..27).map(|x| x as f64).collect();
        let tensor = DenseTensor::from_data(&[3, 3, 3], data);
        let td = tucker_hosvd(&tensor, [2, 2, 2]);
        let recon = tucker_reconstruct(&td);
        assert_eq!(recon.shape, vec![3, 3, 3]);
        assert_eq!(recon.data.len(), 27);
    }
    // TT-core set/get round-trip; untouched entries stay zero.
    #[test]
    fn test_tt_core_get_set() {
        let mut core = TtCore::zeros(2, 3, 4);
        core.set(1, 2, 3, 5.0);
        assert!((core.get(1, 2, 3) - 5.0).abs() < 1e-12);
        assert!((core.get(0, 0, 0)).abs() < 1e-12);
    }
    // Classic 3-4-5 check: sqrt(3^2 + 4^2) = 5 for the TT-core Frobenius norm.
    #[test]
    fn test_tt_core_frobenius_norm() {
        let mut core = TtCore::zeros(1, 2, 1);
        core.set(0, 0, 0, 3.0);
        core.set(0, 1, 0, 4.0);
        assert!(
            (core.frobenius_norm() - 5.0).abs() < 1e-12,
            "TT core norm: expected 5, got {}",
            core.frobenius_norm()
        );
    }
1796    #[test]
1797    fn test_tt_evaluate_simple() {
1798        let mut c0 = TtCore::zeros(1, 2, 1);
1799        c0.set(0, 0, 0, 1.0);
1800        c0.set(0, 1, 0, 2.0);
1801        let mut c1 = TtCore::zeros(1, 2, 1);
1802        c1.set(0, 0, 0, 1.0);
1803        c1.set(0, 1, 0, 2.0);
1804        let tt = TensorTrain {
1805            cores: vec![c0, c1],
1806            shape: vec![2, 2],
1807        };
1808        assert!((tt.evaluate(&[0, 0]) - 1.0).abs() < 1e-12);
1809        assert!((tt.evaluate(&[0, 1]) - 2.0).abs() < 1e-12);
1810        assert!((tt.evaluate(&[1, 0]) - 2.0).abs() < 1e-12);
1811        assert!((tt.evaluate(&[1, 1]) - 4.0).abs() < 1e-12);
1812    }
1813    #[test]
1814    fn test_tt_from_dense_shape() {
1815        let data: Vec<f64> = (0..8).map(|x| x as f64).collect();
1816        let tensor = DenseTensor::from_data(&[2, 2, 2], data);
1817        let tt = TensorTrain::from_dense(&tensor, 4, 1e-10);
1818        assert_eq!(tt.shape, vec![2, 2, 2]);
1819        assert_eq!(tt.cores.len(), 3);
1820        assert_eq!(tt.cores[0].r_left, 1);
1821        assert_eq!(tt.cores.last().unwrap().r_right, 1);
1822    }
1823    #[test]
1824    fn test_tt_frobenius_norm_identity_like() {
1825        let mut c0 = TtCore::zeros(1, 2, 1);
1826        c0.set(0, 0, 0, 1.0);
1827        c0.set(0, 1, 0, 1.0);
1828        let mut c1 = TtCore::zeros(1, 2, 1);
1829        c1.set(0, 0, 0, 1.0);
1830        c1.set(0, 1, 0, 1.0);
1831        let tt = TensorTrain {
1832            cores: vec![c0, c1],
1833            shape: vec![2, 2],
1834        };
1835        let norm = tt.frobenius_norm_approx();
1836        assert!(
1837            (norm - 2.0).abs() < 1e-12,
1838            "TT Frobenius norm expected 2, got {norm}"
1839        );
1840    }
1841    #[test]
1842    fn test_matmul_helper() {
1843        let a = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
1844        let b = vec![vec![5.0, 6.0], vec![7.0, 8.0]];
1845        let c = matmul(&a, &b);
1846        assert!((c[0][0] - 19.0).abs() < 1e-12);
1847        assert!((c[1][1] - 50.0).abs() < 1e-12);
1848    }
1849    #[test]
1850    fn test_gram_helper() {
1851        let a = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
1852        let g = gram(&a);
1853        assert!((g[0][0] - 1.0).abs() < 1e-12);
1854        assert!((g[0][1]).abs() < 1e-12);
1855        assert!((g[1][1] - 1.0).abs() < 1e-12);
1856    }
1857    #[test]
1858    fn test_khatri_rao_shape() {
1859        let a = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
1860        let b = vec![vec![7.0, 8.0], vec![9.0, 10.0]];
1861        let kr = khatri_rao(&a, &b);
1862        assert_eq!(kr.len(), 6);
1863        assert_eq!(kr[0].len(), 2);
1864    }
1865    #[test]
1866    fn test_transpose_helper() {
1867        let a = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
1868        let t = transpose(&a);
1869        assert_eq!(t.len(), 3);
1870        assert_eq!(t[0].len(), 2);
1871        assert!((t[2][1] - 6.0).abs() < 1e-12);
1872    }
1873    #[test]
1874    fn test_solve_ls_identity() {
1875        let a = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
1876        let b = vec![vec![3.0, 5.0]];
1877        let b_t = vec![vec![3.0], vec![5.0]];
1878        let x = solve_ls(&a, &b_t);
1879        assert!((x[0][0] - 3.0).abs() < 1e-12);
1880        assert!((x[1][0] - 5.0).abs() < 1e-12);
1881        let _ = b;
1882    }
1883}