cjc_runtime/linalg.rs

use crate::accumulator::BinnedAccumulatorF64;
use crate::error::RuntimeError;
use crate::tensor::Tensor;

// ---------------------------------------------------------------------------
// 6. Linalg Operations on Tensor
// ---------------------------------------------------------------------------

impl Tensor {
    /// LU decomposition with partial pivoting. Returns (L, U, pivot_indices).
    /// Input must be square 2D.
    ///
    /// **Determinism contract:** Pivot selection uses strict `>` comparison on
    /// absolute values. When two candidates have identical absolute values, the
    /// first (lowest row index) is chosen. This is deterministic given identical
    /// input bits.
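    ///
    /// # Example (illustrative sketch)
    ///
    /// A minimal usage sketch, assuming only the `Tensor` API defined in this
    /// file; checks the contract `P*A = L*U` on a 2x2 matrix:
    ///
    /// ```ignore
    /// let a = Tensor::from_vec(vec![0.0, 2.0, 1.0, 1.0], &[2, 2])?;
    /// let (l, u, piv) = a.lu_decompose()?;
    /// // Row i of P*A is row piv[i] of A; it must equal row i of L*U.
    /// let (ad, ld, ud) = (a.to_vec(), l.to_vec(), u.to_vec());
    /// for i in 0..2 {
    ///     for j in 0..2 {
    ///         let lu = ld[i * 2] * ud[j] + ld[i * 2 + 1] * ud[2 + j];
    ///         assert!((ad[piv[i] * 2 + j] - lu).abs() < 1e-12);
    ///     }
    /// }
    /// ```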
    pub fn lu_decompose(&self) -> Result<(Tensor, Tensor, Vec<usize>), RuntimeError> {
        if self.ndim() != 2 || self.shape[0] != self.shape[1] {
            return Err(RuntimeError::InvalidOperation(
                "LU decomposition requires a square 2D matrix".to_string(),
            ));
        }
        let n = self.shape[0];
        let mut a = self.to_vec();
        let mut pivots: Vec<usize> = (0..n).collect();

        for k in 0..n {
            // Find pivot
            let mut max_val = a[k * n + k].abs();
            let mut max_row = k;
            for i in (k + 1)..n {
                let v = a[i * n + k].abs();
                if v > max_val {
                    max_val = v;
                    max_row = i;
                }
            }
            if max_val < 1e-15 {
                return Err(RuntimeError::InvalidOperation(
                    "LU decomposition: singular matrix".to_string(),
                ));
            }
            if max_row != k {
                pivots.swap(k, max_row);
                for j in 0..n {
                    let tmp = a[k * n + j];
                    a[k * n + j] = a[max_row * n + j];
                    a[max_row * n + j] = tmp;
                }
            }
            for i in (k + 1)..n {
                a[i * n + k] /= a[k * n + k];
                for j in (k + 1)..n {
                    a[i * n + j] -= a[i * n + k] * a[k * n + j];
                }
            }
        }

        // Extract L and U
        let mut l_data = vec![0.0f64; n * n];
        let mut u_data = vec![0.0f64; n * n];
        for i in 0..n {
            for j in 0..n {
                if i == j {
                    l_data[i * n + j] = 1.0;
                    u_data[i * n + j] = a[i * n + j];
                } else if i > j {
                    l_data[i * n + j] = a[i * n + j];
                } else {
                    u_data[i * n + j] = a[i * n + j];
                }
            }
        }

        Ok((
            Tensor::from_vec(l_data, &[n, n])?,
            Tensor::from_vec(u_data, &[n, n])?,
            pivots,
        ))
    }

    /// QR decomposition via Householder reflections. Returns thin factors
    /// (Q, R): Q is m × min(m, n), R is min(m, n) × n. Input must be 2D.
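    ///
    /// # Example (illustrative sketch)
    ///
    /// A minimal sketch, assuming this file's `Tensor` API (including the
    /// crate's `matmul` for 2D tensors); checks the roundtrip `Q*R ≈ A`:
    ///
    /// ```ignore
    /// let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2])?;
    /// let (q, r) = a.qr_decompose()?;
    /// let qr = q.matmul(&r)?;
    /// for (x, y) in qr.to_vec().iter().zip(a.to_vec().iter()) {
    ///     assert!((x - y).abs() < 1e-10);
    /// }
    /// ```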
    pub fn qr_decompose(&self) -> Result<(Tensor, Tensor), RuntimeError> {
        if self.ndim() != 2 {
            return Err(RuntimeError::InvalidOperation(
                "QR decomposition requires a 2D matrix".to_string(),
            ));
        }
        let m = self.shape[0];
        let n = self.shape[1];
        let min_mn = m.min(n);

        // Work on a column-major copy for better locality in Householder reflections.
        let row_major = self.to_vec();
        let mut cm = vec![0.0f64; m * n]; // column-major
        for i in 0..m {
            for j in 0..n {
                cm[j * m + i] = row_major[i * n + j];
            }
        }

        // Store Householder vectors and tau values
        let mut tau = vec![0.0f64; min_mn];

        for k in 0..min_mn {
            // Compute norm of column k below diagonal
            let mut sigma = 0.0f64;
            for i in (k + 1)..m {
                sigma += cm[k * m + i] * cm[k * m + i];
            }
            let x0 = cm[k * m + k];
            let norm_x = (x0 * x0 + sigma).sqrt();

            if norm_x < 1e-15 {
                tau[k] = 0.0;
                continue;
            }

            // Householder vector: v[k] = x[k] - alpha, v[k+1:] = x[k+1:]
            let alpha = if x0 >= 0.0 { -norm_x } else { norm_x };
            cm[k * m + k] -= alpha;
            let v_norm_sq = cm[k * m + k] * cm[k * m + k] + sigma;
            tau[k] = if v_norm_sq > 1e-30 { 2.0 / v_norm_sq } else { 0.0 };

            // Apply H = I - tau * v * v^T to columns k+1..n
            for j in (k + 1)..n {
                let mut dot = 0.0f64;
                for i in k..m {
                    dot += cm[k * m + i] * cm[j * m + i];
                }
                let scale = tau[k] * dot;
                for i in k..m {
                    cm[j * m + i] -= scale * cm[k * m + i];
                }
            }

            // Store alpha on the diagonal (R[k,k])
            cm[k * m + k] = alpha;
        }

        // Extract R (min_mn × n) from upper triangle of cm (column-major)
        let mut r_data = vec![0.0f64; min_mn * n];
        for j in 0..n {
            for i in 0..min_mn.min(j + 1) {
                r_data[i * n + j] = cm[j * m + i];
            }
        }

        // Reconstruct Q (m × min_mn) from Householder vectors
        // Start with identity, apply reflectors in reverse order
        let mut q_cm = vec![0.0f64; m * min_mn]; // column-major
        for i in 0..min_mn { q_cm[i * m + i] = 1.0; }

        for k in (0..min_mn).rev() {
            if tau[k] == 0.0 { continue; }
            // Reconstruct the Householder vector. Its tail v[k+1..m] is still
            // stored in cm, but cm[k*m+k] was overwritten with alpha (R[k,k]),
            // so v[k] must be recovered. Since tau[k] = 2 / ||v||^2:
            //   ||v||^2 = 2 / tau[k],  v[k]^2 = ||v||^2 - sigma,
            // where sigma = sum(cm[k*m+i]^2 for i in k+1..m).
            let mut sigma2 = 0.0f64;
            for i in (k + 1)..m {
                sigma2 += cm[k * m + i] * cm[k * m + i];
            }
            let v_norm_sq = 2.0 / tau[k];
            let vk_sq = v_norm_sq - sigma2;
            let vk_abs = if vk_sq > 0.0 { vk_sq.sqrt() } else { 0.0 };
            // Recover the sign of v[k]: v[k] = x[k] - alpha, and alpha was
            // chosen with the opposite sign to x[k], so
            // sign(v[k]) = sign(x[k]) = -sign(alpha), with alpha = cm[k*m+k].
            let vk = if cm[k * m + k] >= 0.0 { -vk_abs } else { vk_abs };

            // Apply H_k to Q columns
            for j in k..min_mn {
                let mut dot = vk * q_cm[j * m + k];
                for i in (k + 1)..m {
                    dot += cm[k * m + i] * q_cm[j * m + i];
                }
                let scale = tau[k] * dot;
                q_cm[j * m + k] -= scale * vk;
                for i in (k + 1)..m {
                    q_cm[j * m + i] -= scale * cm[k * m + i];
                }
            }
        }

        // Convert Q from column-major to row-major
        let mut q_data = vec![0.0f64; m * min_mn];
        for i in 0..m {
            for j in 0..min_mn {
                q_data[i * min_mn + j] = q_cm[j * m + i];
            }
        }

        Ok((
            Tensor::from_vec(q_data, &[m, min_mn])?,
            Tensor::from_vec(r_data, &[min_mn, n])?,
        ))
    }

    /// Cholesky decomposition: A = L * L^T.
    /// Input must be symmetric positive definite 2D.
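    ///
    /// # Example (illustrative sketch)
    ///
    /// A minimal sketch on a known SPD matrix; `L` should satisfy
    /// `A = L * L^T`:
    ///
    /// ```ignore
    /// // A = [[4, 2], [2, 3]] is symmetric positive definite.
    /// let a = Tensor::from_vec(vec![4.0, 2.0, 2.0, 3.0], &[2, 2])?;
    /// let l = a.cholesky()?;
    /// let ld = l.to_vec();
    /// // Expected L = [[2, 0], [1, sqrt(2)]]
    /// assert!((ld[0] - 2.0).abs() < 1e-12);
    /// assert!((ld[2] - 1.0).abs() < 1e-12);
    /// assert!((ld[3] - 2.0_f64.sqrt()).abs() < 1e-12);
    /// ```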
    pub fn cholesky(&self) -> Result<Tensor, RuntimeError> {
        if self.ndim() != 2 || self.shape[0] != self.shape[1] {
            return Err(RuntimeError::InvalidOperation(
                "Cholesky decomposition requires a square 2D matrix".to_string(),
            ));
        }
        let n = self.shape[0];
        let a = self.to_vec();
        let mut l = vec![0.0f64; n * n];

        for j in 0..n {
            // Kahan summation: lightweight (16 bytes) vs BinnedAccumulator (32KB).
            // Deterministic because iteration order is fixed (0..j).
            let mut sum = 0.0f64;
            let mut comp = 0.0f64;
            for k in 0..j {
                let y = l[j * n + k] * l[j * n + k] - comp;
                let t = sum + y;
                comp = (t - sum) - y;
                sum = t;
            }
            let diag = a[j * n + j] - sum;
            if diag <= 0.0 {
                return Err(RuntimeError::InvalidOperation(
                    "Cholesky: matrix is not positive definite".to_string(),
                ));
            }
            l[j * n + j] = diag.sqrt();

            for i in (j + 1)..n {
                let mut s = 0.0f64;
                let mut c = 0.0f64;
                for k in 0..j {
                    let y = l[i * n + k] * l[j * n + k] - c;
                    let t = s + y;
                    c = (t - s) - y;
                    s = t;
                }
                l[i * n + j] = (a[i * n + j] - s) / l[j * n + j];
            }
        }

        Tensor::from_vec(l, &[n, n])
    }

    /// Determinant via LU decomposition: product of U diagonal * parity.
    pub fn det(&self) -> Result<f64, RuntimeError> {
        if self.ndim() != 2 || self.shape[0] != self.shape[1] {
            return Err(RuntimeError::InvalidOperation(
                "det requires a square 2D matrix".to_string(),
            ));
        }
        let n = self.shape[0];
        let mut a = self.to_vec();
        let mut sign = 1.0f64;
        for k in 0..n {
            let mut max_val = a[k * n + k].abs();
            let mut max_row = k;
            for i in (k + 1)..n {
                let v = a[i * n + k].abs();
                if v > max_val {
                    max_val = v;
                    max_row = i;
                }
            }
            if max_val < 1e-15 {
                return Ok(0.0); // singular
            }
            if max_row != k {
                sign = -sign;
                for j in 0..n {
                    let tmp = a[k * n + j];
                    a[k * n + j] = a[max_row * n + j];
                    a[max_row * n + j] = tmp;
                }
            }
            for i in (k + 1)..n {
                a[i * n + k] /= a[k * n + k];
                for j in (k + 1)..n {
                    a[i * n + j] -= a[i * n + k] * a[k * n + j];
                }
            }
        }
        let mut det = sign;
        for i in 0..n {
            det *= a[i * n + i];
        }
        Ok(det)
    }

    /// Solve Ax = b via LU decomposition.
    /// self = A (n x n), b = vector (n).
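    ///
    /// # Example (illustrative sketch)
    ///
    /// A minimal sketch, assuming this file's API; solves a 2x2 system:
    ///
    /// ```ignore
    /// // A = [[2, 1], [1, 3]], b = [3, 5]  =>  x = [0.8, 1.4]
    /// let a = Tensor::from_vec(vec![2.0, 1.0, 1.0, 3.0], &[2, 2])?;
    /// let b = Tensor::from_vec(vec![3.0, 5.0], &[2])?;
    /// let x = a.solve(&b)?;
    /// assert!((x.to_vec()[0] - 0.8).abs() < 1e-12);
    /// assert!((x.to_vec()[1] - 1.4).abs() < 1e-12);
    /// ```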
    pub fn solve(&self, b: &Tensor) -> Result<Tensor, RuntimeError> {
        if self.ndim() != 2 || self.shape[0] != self.shape[1] {
            return Err(RuntimeError::InvalidOperation(
                "solve requires a square 2D matrix A".to_string(),
            ));
        }
        let n = self.shape[0];
        if b.len() != n {
            return Err(RuntimeError::InvalidOperation(
                format!("solve: b length {} != n = {n}", b.len()),
            ));
        }
        let (l, u, pivots) = self.lu_decompose()?;
        let l_data = l.to_vec();
        let u_data = u.to_vec();
        let b_data = b.to_vec();

        // Permute b
        let mut pb = vec![0.0; n];
        for i in 0..n {
            pb[i] = b_data[pivots[i]];
        }

        // Forward substitution: L * y = pb
        let mut y = vec![0.0; n];
        for i in 0..n {
            let mut s = pb[i];
            for j in 0..i {
                s -= l_data[i * n + j] * y[j];
            }
            y[i] = s;
        }

        // Back substitution: U * x = y
        let mut x = vec![0.0; n];
        for i in (0..n).rev() {
            let mut s = y[i];
            for j in (i + 1)..n {
                s -= u_data[i * n + j] * x[j];
            }
            x[i] = s / u_data[i * n + i];
        }

        Tensor::from_vec(x, &[n])
    }

    /// Least squares solution: min ||Ax - b||_2 via QR decomposition.
    /// self = A (m x n, m >= n), b = vector (m).
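    ///
    /// # Example (illustrative sketch)
    ///
    /// A minimal sketch: fit a line `y = c0 + c1*t` through three points
    /// (design matrix columns: ones, then t):
    ///
    /// ```ignore
    /// // Points (t, y): (0, 1), (1, 3), (2, 5) lie exactly on y = 1 + 2t.
    /// let a = Tensor::from_vec(vec![1.0, 0.0, 1.0, 1.0, 1.0, 2.0], &[3, 2])?;
    /// let b = Tensor::from_vec(vec![1.0, 3.0, 5.0], &[3])?;
    /// let coeffs = a.lstsq(&b)?;
    /// assert!((coeffs.to_vec()[0] - 1.0).abs() < 1e-10); // intercept
    /// assert!((coeffs.to_vec()[1] - 2.0).abs() < 1e-10); // slope
    /// ```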
    pub fn lstsq(&self, b: &Tensor) -> Result<Tensor, RuntimeError> {
        if self.ndim() != 2 {
            return Err(RuntimeError::InvalidOperation(
                "lstsq requires a 2D matrix".to_string(),
            ));
        }
        let m = self.shape[0];
        let n = self.shape[1];
        if m < n {
            return Err(RuntimeError::InvalidOperation(
                "lstsq requires m >= n".to_string(),
            ));
        }
        if b.len() != m {
            return Err(RuntimeError::InvalidOperation(
                format!("lstsq: b length {} != m = {m}", b.len()),
            ));
        }
        let (q, r) = self.qr_decompose()?;
        let q_data = q.to_vec();
        let r_data = r.to_vec();
        let b_data = b.to_vec();

        // Q^T * b (Q is m x n, so Q^T * b gives n-vector)
        // Use BinnedAccumulator for deterministic dot products.
        let mut qtb = vec![0.0; n];
        for j in 0..n {
            let mut acc = BinnedAccumulatorF64::new();
            for i in 0..m {
                acc.add(q_data[i * n + j] * b_data[i]);
            }
            qtb[j] = acc.finalize();
        }

        // Back substitution: R * x = Q^T * b
        let mut x = vec![0.0; n];
        for i in (0..n).rev() {
            let mut s = qtb[i];
            for j in (i + 1)..n {
                s -= r_data[i * n + j] * x[j];
            }
            if r_data[i * n + i].abs() < 1e-15 {
                return Err(RuntimeError::InvalidOperation(
                    "lstsq: rank-deficient matrix".to_string(),
                ));
            }
            x[i] = s / r_data[i * n + i];
        }

        Tensor::from_vec(x, &[n])
    }

    /// Matrix trace: sum of diagonal elements.
    pub fn trace(&self) -> Result<f64, RuntimeError> {
        if self.ndim() != 2 || self.shape[0] != self.shape[1] {
            return Err(RuntimeError::InvalidOperation(
                "trace requires a square 2D matrix".to_string(),
            ));
        }
        let n = self.shape[0];
        let data = self.to_vec();
        let mut acc = BinnedAccumulatorF64::new();
        for i in 0..n {
            acc.add(data[i * n + i]);
        }
        Ok(acc.finalize())
    }

    /// Frobenius norm: sqrt(sum(aij^2)).
    pub fn norm_frobenius(&self) -> Result<f64, RuntimeError> {
        if self.ndim() != 2 {
            return Err(RuntimeError::InvalidOperation(
                "norm_frobenius requires a 2D matrix".to_string(),
            ));
        }
        let data = self.to_vec();
        let mut acc = BinnedAccumulatorF64::new();
        for &v in &data {
            acc.add(v * v);
        }
        Ok(acc.finalize().sqrt())
    }

    /// Eigenvalue decomposition for symmetric matrices, via Householder
    /// tridiagonalization followed by implicit QR iteration with Wilkinson
    /// shifts. Returns (eigenvalues sorted ascending, eigenvectors n x n).
    /// DETERMINISM: fixed iteration order; eigenvalues sorted with `total_cmp`
    /// and eigenvector columns sign-canonicalized.
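    ///
    /// # Example (illustrative sketch)
    ///
    /// A minimal sketch on a 2x2 symmetric matrix with known spectrum:
    ///
    /// ```ignore
    /// // A = [[2, 1], [1, 2]] has eigenvalues 1 and 3.
    /// let a = Tensor::from_vec(vec![2.0, 1.0, 1.0, 2.0], &[2, 2])?;
    /// let (vals, _vecs) = a.eigh()?;
    /// assert!((vals[0] - 1.0).abs() < 1e-10); // ascending order
    /// assert!((vals[1] - 3.0).abs() < 1e-10);
    /// ```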
    pub fn eigh(&self) -> Result<(Vec<f64>, Tensor), RuntimeError> {
        if self.ndim() != 2 || self.shape[0] != self.shape[1] {
            return Err(RuntimeError::InvalidOperation(
                "eigh requires a square 2D matrix".to_string(),
            ));
        }
        let n = self.shape[0];

        if n <= 1 {
            // n == 0: no eigenvalues; n == 1: the single entry is the eigenvalue.
            let vals = if n == 1 { vec![self.to_vec()[0]] } else { vec![] };
            let v_data = if n == 1 { vec![1.0] } else { vec![] };
            return Ok((vals, Tensor::from_vec(v_data, &[n, n])?));
        }

        // ── Step 1: Householder tridiagonalization ──────────────────────
        // Reduce symmetric A to tridiagonal form T = Q^T A Q.
        // Q is accumulated as a product of Householder reflectors.
        // This is O(n^3) — far better than Jacobi's O(n^4) worst case.

        let mut a = self.to_vec();
        let mut q = vec![0.0f64; n * n];
        for i in 0..n { q[i * n + i] = 1.0; }

        // diag[i] = T[i,i], offd[i] = T[i,i+1]
        let mut diag = vec![0.0f64; n];
        let mut offd = vec![0.0f64; n]; // offd[0..n-1], offd[n-1] unused

        for k in 0..n.saturating_sub(2) {
            // Compute Householder vector for column k, rows k+1..n
            let mut sigma = 0.0f64;
            for i in (k + 1)..n {
                sigma += a[i * n + k] * a[i * n + k];
            }
            let sigma_sqrt = sigma.sqrt();

            if sigma_sqrt < 1e-15 {
                // Column already zero below diagonal
                offd[k] = a[(k + 1) * n + k];
                continue;
            }

            let alpha = if a[(k + 1) * n + k] >= 0.0 { -sigma_sqrt } else { sigma_sqrt };
            offd[k] = alpha;

            // Householder vector v stored in a[k+1..n, k]
            a[(k + 1) * n + k] -= alpha;
            let mut v_norm_sq = 0.0f64;
            for i in (k + 1)..n {
                v_norm_sq += a[i * n + k] * a[i * n + k];
            }
            if v_norm_sq < 1e-30 { continue; }
            let inv_v_norm_sq = 2.0 / v_norm_sq;

            // p = inv_v_norm_sq * A[k+1..n, k+1..n] * v
            let mut p = vec![0.0f64; n];
            for i in (k + 1)..n {
                let mut s = 0.0f64;
                for j in (k + 1)..n {
                    s += a[i * n + j] * a[j * n + k];
                }
                p[i] = inv_v_norm_sq * s;
            }

            // K = (v^T p) * inv_v_norm_sq / 2  (i.e. v^T p / ||v||^2)
            let mut vtp = 0.0f64;
            for i in (k + 1)..n { vtp += a[i * n + k] * p[i]; }
            let kk = inv_v_norm_sq * vtp * 0.5;

            // q = p - K * v
            let mut qq = vec![0.0f64; n];
            for i in (k + 1)..n {
                qq[i] = p[i] - kk * a[i * n + k];
            }

            // A = A - v * q^T - q * v^T
            for i in (k + 1)..n {
                for j in (k + 1)..n {
                    a[i * n + j] -= a[i * n + k] * qq[j] + qq[i] * a[j * n + k];
                }
            }

            // Accumulate Q = Q * (I - inv_v_norm_sq * v * v^T)
            // Q := Q - (Q * v) * (inv_v_norm_sq * v^T)
            let mut qv = vec![0.0f64; n];
            for i in 0..n {
                let mut s = 0.0f64;
                for j in (k + 1)..n {
                    s += q[i * n + j] * a[j * n + k];
                }
                qv[i] = s;
            }
            for i in 0..n {
                for j in (k + 1)..n {
                    q[i * n + j] -= inv_v_norm_sq * qv[i] * a[j * n + k];
                }
            }
        }

        // Fill diagonal and last off-diagonal from the reduced matrix
        for i in 0..n { diag[i] = a[i * n + i]; }
        if n >= 2 { offd[n - 2] = a[(n - 2) * n + (n - 1)]; }

        // ── Step 2: Implicit QR iteration on tridiagonal matrix ────────
        // Wilkinson shift for cubic convergence.
        // O(n) per iteration, O(n) iterations typical => O(n^2) total.

        let max_iter = 30 * n;
        let mut lo = 0usize;
        let mut hi = n - 1;

        for _iter in 0..max_iter {
            // Find unreduced block
            while lo < hi {
                if offd[hi - 1].abs() < 1e-14 * (diag[hi - 1].abs() + diag[hi].abs()).max(1e-14) {
                    offd[hi - 1] = 0.0;
                    hi -= 1;
                } else {
                    break;
                }
            }
            if hi <= lo { break; }

            // Find start of unreduced block
            let mut block_start = hi - 1;
            while block_start > lo {
                if offd[block_start - 1].abs() < 1e-14 * (diag[block_start - 1].abs() + diag[block_start].abs()).max(1e-14) {
                    offd[block_start - 1] = 0.0;
                    break;
                }
                block_start -= 1;
            }

            if block_start == hi {
                // 1×1 block converged
                hi = if hi > 0 { hi - 1 } else { break };
                continue;
            }

            // Wilkinson shift: eigenvalue of trailing 2×2 closer to d[hi]
            let d_hi = diag[hi];
            let d_him1 = diag[hi - 1];
            let e_him1 = offd[hi - 1];
            let delta = (d_him1 - d_hi) * 0.5;
            let sign_delta = if delta >= 0.0 { 1.0 } else { -1.0 };
            let shift = d_hi - e_him1 * e_him1 / (delta + sign_delta * (delta * delta + e_him1 * e_him1).sqrt());

            // Implicit QR step (Givens rotations)
            let mut x = diag[block_start] - shift;
            let mut z = offd[block_start];

            for k in block_start..hi {
                // Compute Givens rotation to zero z
                let r = (x * x + z * z).sqrt();
                let c = if r > 1e-30 { x / r } else { 1.0 };
                let s = if r > 1e-30 { -z / r } else { 0.0 };

                // Apply to tridiagonal elements
                if k > block_start {
                    offd[k - 1] = r;
                }
                let d0 = diag[k];
                let d1 = diag[k + 1];
                let e0 = offd[k];
                diag[k] = c * c * d0 + s * s * d1 - 2.0 * c * s * e0;
                diag[k + 1] = s * s * d0 + c * c * d1 + 2.0 * c * s * e0;
                offd[k] = c * s * (d0 - d1) + (c * c - s * s) * e0;

                if k + 1 < hi {
                    x = offd[k];
                    z = -s * offd[k + 1];
                    offd[k + 1] *= c;
                }

                // Accumulate eigenvector rotation: Q[:, k:k+2] *= G
                for i in 0..n {
                    let qik = q[i * n + k];
                    let qik1 = q[i * n + k + 1];
                    q[i * n + k] = c * qik - s * qik1;
                    q[i * n + k + 1] = s * qik + c * qik1;
                }
            }

            // Check if all off-diagonals converged
            let mut all_converged = true;
            for i in lo..hi {
                if offd[i].abs() >= 1e-14 * (diag[i].abs() + diag[i + 1].abs()).max(1e-14) {
                    all_converged = false;
                    break;
                }
            }
            if all_converged { break; }
        }

        // ── Step 3: Sort eigenvalues ascending, reorder eigenvectors ────
        let mut eigenvalues: Vec<(f64, usize)> = (0..n).map(|i| (diag[i], i)).collect();
        eigenvalues.sort_by(|a, b| a.0.total_cmp(&b.0));
        let vals: Vec<f64> = eigenvalues.iter().map(|&(v, _)| v).collect();

        let mut v_sorted = vec![0.0; n * n];
        for (new_col, &(_, old_col)) in eigenvalues.iter().enumerate() {
            for row in 0..n {
                v_sorted[row * n + new_col] = q[row * n + old_col];
            }
        }

        // Sign-canonical: first nonzero component positive
        for col in 0..n {
            let mut first_nonzero = 0.0;
            for row in 0..n {
                if v_sorted[row * n + col].abs() > 1e-15 {
                    first_nonzero = v_sorted[row * n + col];
                    break;
                }
            }
            if first_nonzero < 0.0 {
                for row in 0..n {
                    v_sorted[row * n + col] = -v_sorted[row * n + col];
                }
            }
        }

        Ok((vals, Tensor::from_vec(v_sorted, &[n, n])?))
    }

    /// Matrix rank estimated via QR: counts diagonal entries of R above tolerance.
    pub fn matrix_rank(&self) -> Result<usize, RuntimeError> {
        if self.ndim() != 2 {
            return Err(RuntimeError::InvalidOperation(
                "matrix_rank requires a 2D matrix".to_string(),
            ));
        }
        // Simple approach: use QR and count non-zero diagonal
        let (_q, r) = self.qr_decompose()?;
        let r_data = r.to_vec();
        let n = r.shape()[0].min(r.shape()[1]);
        let cols = r.shape()[1];
        let mut rank = 0;
        for i in 0..n {
            if r_data[i * cols + i].abs() > 1e-10 {
                rank += 1;
            }
        }
        Ok(rank)
    }

    /// Kronecker product: A ⊗ B.
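    ///
    /// # Example (illustrative sketch)
    ///
    /// A minimal sketch: `I_2 ⊗ B` places `B` twice on the block diagonal:
    ///
    /// ```ignore
    /// let eye = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], &[2, 2])?;
    /// let b = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])?;
    /// let k = eye.kron(&b)?; // 4x4: [[B, 0], [0, B]]
    /// assert_eq!(k.to_vec()[0], 1.0);  // (0,0): top-left block = B
    /// assert_eq!(k.to_vec()[10], 1.0); // (2,2): bottom-right block = B
    /// ```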
    pub fn kron(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
        if self.ndim() != 2 || other.ndim() != 2 {
            return Err(RuntimeError::InvalidOperation(
                "kron requires two 2D matrices".to_string(),
            ));
        }
        let (m, n) = (self.shape[0], self.shape[1]);
        let (p, q) = (other.shape()[0], other.shape()[1]);
        let a = self.to_vec();
        let b = other.to_vec();
        let mut result = vec![0.0; m * p * n * q];
        let out_cols = n * q;
        for i in 0..m {
            for j in 0..n {
                let aij = a[i * n + j];
                for k in 0..p {
                    for l in 0..q {
                        result[(i * p + k) * out_cols + (j * q + l)] = aij * b[k * q + l];
                    }
                }
            }
        }
        Tensor::from_vec(result, &[m * p, n * q])
    }

    /// Matrix inverse via LU decomposition + back-substitution.
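    ///
    /// # Example (illustrative sketch)
    ///
    /// A minimal sketch, assuming the crate's `matmul`; `A * A^{-1}` should
    /// be the identity:
    ///
    /// ```ignore
    /// let a = Tensor::from_vec(vec![4.0, 7.0, 2.0, 6.0], &[2, 2])?;
    /// let inv = a.inverse()?;
    /// let prod = a.matmul(&inv)?;
    /// let p = prod.to_vec();
    /// assert!((p[0] - 1.0).abs() < 1e-12 && p[1].abs() < 1e-12);
    /// ```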
    pub fn inverse(&self) -> Result<Tensor, RuntimeError> {
        if self.ndim() != 2 || self.shape[0] != self.shape[1] {
            return Err(RuntimeError::InvalidOperation(
                "Matrix inverse requires a square 2D matrix".to_string(),
            ));
        }
        let n = self.shape[0];
        let (l, u, pivots) = self.lu_decompose()?;
        let l_data = l.to_vec();
        let u_data = u.to_vec();

        let mut inv = vec![0.0f64; n * n];

        // Solve for each column of the inverse
        for col in 0..n {
            // Permuted identity column: (P * e_col)[i] = 1 iff pivots[i] == col,
            // matching the permutation convention used by `solve`.
            let mut b = vec![0.0f64; n];
            for i in 0..n {
                if pivots[i] == col {
                    b[i] = 1.0;
                }
            }

            // Forward substitution: L * y = b
            let mut y = vec![0.0f64; n];
            for i in 0..n {
                let mut sum = b[i];
                for j in 0..i {
                    sum -= l_data[i * n + j] * y[j];
                }
                y[i] = sum; // L has 1s on diagonal
            }

            // Back substitution: U * x = y
            let mut x = vec![0.0f64; n];
            for i in (0..n).rev() {
                let mut sum = y[i];
                for j in (i + 1)..n {
                    sum -= u_data[i * n + j] * x[j];
                }
                x[i] = sum / u_data[i * n + i];
            }

            for i in 0..n {
                inv[i * n + col] = x[i];
            }
        }

        Tensor::from_vec(inv, &[n, n])
    }

    // -----------------------------------------------------------------------
    // Phase B3: Linear algebra extensions
    // -----------------------------------------------------------------------

    /// 1-norm: maximum absolute column sum.
    pub fn norm_1(&self) -> Result<f64, RuntimeError> {
        if self.ndim() != 2 {
            return Err(RuntimeError::InvalidOperation(
                "norm_1 requires a 2D matrix".to_string(),
            ));
        }
        let (m, n) = (self.shape[0], self.shape[1]);
        let data = self.to_vec();
        let mut max_col_sum = 0.0_f64;
        for j in 0..n {
            let mut acc = BinnedAccumulatorF64::new();
            for i in 0..m {
                acc.add(data[i * n + j].abs());
            }
            let col_sum = acc.finalize();
            if col_sum > max_col_sum {
                max_col_sum = col_sum;
            }
        }
        Ok(max_col_sum)
    }

    /// Infinity norm: maximum absolute row sum.
    pub fn norm_inf(&self) -> Result<f64, RuntimeError> {
        if self.ndim() != 2 {
            return Err(RuntimeError::InvalidOperation(
                "norm_inf requires a 2D matrix".to_string(),
            ));
        }
        let (m, n) = (self.shape[0], self.shape[1]);
        let data = self.to_vec();
        let mut max_row_sum = 0.0_f64;
        for i in 0..m {
            let mut acc = BinnedAccumulatorF64::new();
            for j in 0..n {
                acc.add(data[i * n + j].abs());
            }
            let row_sum = acc.finalize();
            if row_sum > max_row_sum {
                max_row_sum = row_sum;
            }
        }
        Ok(max_row_sum)
    }

    /// Condition number via eigenvalue ratio.
    /// For symmetric matrices: |lambda_max| / |lambda_min|. For general
    /// matrices: sigma_max / sigma_min, obtained as the square root of the
    /// eigenvalue ratio of A^T * A.
    pub fn cond(&self) -> Result<f64, RuntimeError> {
        if self.ndim() != 2 || self.shape[0] != self.shape[1] {
            return Err(RuntimeError::InvalidOperation(
                "cond requires a square 2D matrix".to_string(),
            ));
        }
        let n = self.shape[0];
        // Check if symmetric
        let data = self.to_vec();
        let mut is_sym = true;
        'outer: for i in 0..n {
            for j in (i + 1)..n {
                if (data[i * n + j] - data[j * n + i]).abs() > 1e-14 {
                    is_sym = false;
                    break 'outer;
                }
            }
        }
        if is_sym {
            let (eigenvalues, _) = self.eigh()?;
            let abs_min = eigenvalues.iter().map(|v| v.abs()).fold(f64::INFINITY, f64::min);
            let abs_max = eigenvalues.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
            if abs_min < 1e-15 {
                return Ok(f64::INFINITY);
            }
            Ok(abs_max / abs_min)
        } else {
            // General: compute A^T * A, then eigh
            let at = self.transpose();
            let ata = at.matmul(self)?;
            let (eigenvalues, _) = ata.eigh()?;
            let abs_min = eigenvalues.iter().map(|v| v.abs()).fold(f64::INFINITY, f64::min);
            let abs_max = eigenvalues.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
            if abs_min < 1e-15 {
                return Ok(f64::INFINITY);
            }
            Ok((abs_max / abs_min).sqrt())
        }
    }

    /// Real Schur decomposition: A = Q * T * Q^T.
    /// Uses Hessenberg reduction + implicit single-shift QR (Givens rotations).
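    ///
    /// # Example (illustrative sketch)
    ///
    /// A minimal sketch, assuming the crate's `matmul`/`transpose`; `Q` is
    /// orthogonal and `Q * T * Q^T` reconstructs `A` up to iteration
    /// tolerance:
    ///
    /// ```ignore
    /// let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])?;
    /// let (q, t) = a.schur()?;
    /// let recon = q.matmul(&t)?.matmul(&q.transpose())?;
    /// for (x, y) in recon.to_vec().iter().zip(a.to_vec().iter()) {
    ///     assert!((x - y).abs() < 1e-8);
    /// }
    /// ```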
    pub fn schur(&self) -> Result<(Tensor, Tensor), RuntimeError> {
        if self.ndim() != 2 || self.shape[0] != self.shape[1] {
            return Err(RuntimeError::InvalidOperation(
                "schur requires a square 2D matrix".to_string(),
            ));
        }
        let n = self.shape[0];
        if n == 0 {
            return Err(RuntimeError::InvalidOperation("schur: empty matrix".to_string()));
        }
        if n == 1 {
            return Ok((
                Tensor::from_vec(vec![1.0], &[1, 1])?,
                self.clone(),
            ));
        }

        let mut h = self.to_vec();
        let mut q = vec![0.0; n * n];
        for i in 0..n {
            q[i * n + i] = 1.0;
        }

        // Step 1: Reduce to upper Hessenberg form
        for k in 0..n.saturating_sub(2) {
            // Compute Householder reflector for column k below diagonal
            let mut col = vec![0.0; n - k - 1];
            for i in 0..col.len() {
                col[i] = h[(k + 1 + i) * n + k];
            }
            let norm_col = {
                let mut acc = BinnedAccumulatorF64::new();
                for &v in &col { acc.add(v * v); }
                acc.finalize().sqrt()
            };
            if norm_col < 1e-15 {
                continue;
            }
            // Sign convention: positive diagonal
            let sign = if col[0] >= 0.0 { 1.0 } else { -1.0 };
            col[0] += sign * norm_col;
            let norm_v = {
                let mut acc = BinnedAccumulatorF64::new();
                for &v in &col { acc.add(v * v); }
                acc.finalize().sqrt()
            };
            if norm_v < 1e-15 {
                continue;
            }
            for v in &mut col {
                *v /= norm_v;
            }
            // Apply H = (I - 2vv^T) * H from left
            for j in 0..n {
                let mut acc = BinnedAccumulatorF64::new();
                for i in 0..col.len() {
                    acc.add(col[i] * h[(k + 1 + i) * n + j]);
                }
                let dot = acc.finalize();
                for i in 0..col.len() {
                    h[(k + 1 + i) * n + j] -= 2.0 * col[i] * dot;
                }
            }
            // Apply H = H * (I - 2vv^T) from right
            for i in 0..n {
                let mut acc = BinnedAccumulatorF64::new();
                for j in 0..col.len() {
                    acc.add(col[j] * h[i * n + (k + 1 + j)]);
                }
                let dot = acc.finalize();
                for j in 0..col.len() {
                    h[i * n + (k + 1 + j)] -= 2.0 * col[j] * dot;
                }
            }
            // Accumulate Q
            for i in 0..n {
                let mut acc = BinnedAccumulatorF64::new();
                for j in 0..col.len() {
                    acc.add(col[j] * q[i * n + (k + 1 + j)]);
                }
                let dot = acc.finalize();
                for j in 0..col.len() {
                    q[i * n + (k + 1 + j)] -= 2.0 * col[j] * dot;
                }
            }
        }

        // Step 2: Implicit QR iterations
        let eps = 1e-14;
        let max_iter = 200 * n;
        let mut ihi = n - 1;

        for _iter in 0..max_iter {
            if ihi == 0 {
                break;
            }
            // Find active block
            let mut ilo = ihi;
            while ilo > 0 {
                if h[ilo * n + (ilo - 1)].abs()
                    < eps * (h[(ilo - 1) * n + (ilo - 1)].abs() + h[ilo * n + ilo].abs())
                {
                    h[ilo * n + (ilo - 1)] = 0.0;
                    break;
                }
                ilo -= 1;
            }
            if ilo == ihi {
                // 1x1 block converged
                ihi -= 1;
                continue;
            }
            if ilo + 1 == ihi {
                // 2x2 block: deflate it whole; it stays as a 2x2 bump on the
                // quasi-triangular diagonal (complex-conjugate eigenvalue
                // pairs cannot be reduced further over the reals).
                if ihi < 2 {
                    break;
                }
                ihi -= 2;
                continue;
            }

            // Wilkinson shift from trailing 2x2
            let a11 = h[(ihi - 1) * n + (ihi - 1)];
            let a12 = h[(ihi - 1) * n + ihi];
            let a21 = h[ihi * n + (ihi - 1)];
            let a22 = h[ihi * n + ihi];
            let tr = a11 + a22;
            let det = a11 * a22 - a12 * a21;
            let disc = tr * tr - 4.0 * det;

            let shift = if disc >= 0.0 {
                let sqrt_disc = disc.sqrt();
                let ev1 = (tr + sqrt_disc) / 2.0;
                let ev2 = (tr - sqrt_disc) / 2.0;
                if (ev1 - a22).abs() < (ev2 - a22).abs() { ev1 } else { ev2 }
            } else {
                a22 // complex eigenvalues: use a22 as shift
            };

            // Apply single-shift QR step with Givens rotations
            let mut x = h[ilo * n + ilo] - shift;
            let mut z = h[(ilo + 1) * n + ilo];
            for k in ilo..ihi {
                let r = (x * x + z * z).sqrt();
                let c = if r < 1e-15 { 1.0 } else { x / r };
                let s = if r < 1e-15 { 0.0 } else { z / r };
                // Apply Givens from left: rows k and k+1
                for j in 0..n {
                    let t1 = h[k * n + j];
                    let t2 = h[(k + 1) * n + j];
                    h[k * n + j] = c * t1 + s * t2;
                    h[(k + 1) * n + j] = -s * t1 + c * t2;
                }
                // Apply Givens from right: cols k and k+1
                let jmax = if k + 3 < n { k + 3 } else { n };
                for i in 0..jmax {
                    let t1 = h[i * n + k];
                    let t2 = h[i * n + (k + 1)];
                    h[i * n + k] = c * t1 + s * t2;
                    h[i * n + (k + 1)] = -s * t1 + c * t2;
                }
                // Accumulate Q
                for i in 0..n {
                    let t1 = q[i * n + k];
                    let t2 = q[i * n + (k + 1)];
                    q[i * n + k] = c * t1 + s * t2;
                    q[i * n + (k + 1)] = -s * t1 + c * t2;
                }
                if k + 2 <= ihi {
                    x = h[(k + 1) * n + k];
                    z = h[(k + 2) * n + k];
                }
            }
        }

        // Zero entries below the first subdiagonal; 2x2 blocks on the
        // quasi-triangular diagonal are preserved.
        for i in 0..n {
            for j in 0..i.saturating_sub(1) {
                h[i * n + j] = 0.0;
            }
        }

        Ok((
            Tensor::from_vec(q, &[n, n])?,
            Tensor::from_vec(h, &[n, n])?,
        ))
    }

    // -----------------------------------------------------------------------
    // Phase 3A: SVD via eigendecomposition of A^T * A
    // -----------------------------------------------------------------------

    /// Compute the Singular Value Decomposition: A = U @ diag(S) @ Vt.
    /// Returns (U, S, Vt) as (Tensor, Vec<f64>, Tensor).
    ///
    /// Implementation: via eigendecomposition of A^T*A (for V and S^2),
    /// then U = A*V*diag(1/s_i).
    ///
    /// **Determinism contract:** All intermediate float reductions use
    /// `BinnedAccumulatorF64`. Iteration order is fixed row-major.
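    ///
    /// # Example (illustrative sketch)
    ///
    /// A minimal sketch on a diagonal matrix with known singular values:
    ///
    /// ```ignore
    /// let a = Tensor::from_vec(vec![3.0, 0.0, 0.0, 2.0], &[2, 2])?;
    /// let (_u, s, _vt) = a.svd()?;
    /// assert!((s[0] - 3.0).abs() < 1e-10); // descending order
    /// assert!((s[1] - 2.0).abs() < 1e-10);
    /// ```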
    pub fn svd(&self) -> Result<(Tensor, Vec<f64>, Tensor), RuntimeError> {
        if self.ndim() != 2 {
            return Err(RuntimeError::InvalidOperation(
                "SVD requires a 2D matrix".to_string(),
            ));
        }
        let m = self.shape[0];
        let n = self.shape[1];
        if m == 0 || n == 0 {
            return Err(RuntimeError::InvalidOperation(
                "SVD: empty matrix".to_string(),
            ));
        }

        let min_mn = m.min(n);

        // Compute A^T * A (n x n symmetric matrix)
        let at = self.transpose();
        let ata = at.matmul(self)?;

        // Eigendecomposition of A^T*A gives V and eigenvalues = sigma^2
        let (eigenvalues, eigenvectors) = ata.eigh()?;

        // eigenvalues are in ascending order from eigh; we want descending singular values
        // Singular values = sqrt(eigenvalues), clamp negatives to 0
        let mut singular_values: Vec<f64> = eigenvalues.iter()
            .map(|&ev| if ev > 0.0 { ev.sqrt() } else { 0.0 })
            .collect();

        // Reverse to get descending order
        singular_values.reverse();

        // Take only min(m,n) singular values
        let k = min_mn.min(singular_values.len());
        let s: Vec<f64> = singular_values[..k].to_vec();

        // V columns are eigenvectors of A^T*A, reversed for descending order
        // eigenvectors is n x n, columns are eigenvectors
        let ev_data = eigenvectors.to_vec();
        let ev_n = eigenvectors.shape()[1]; // should be n

        // Build V matrix (n x k) with columns in descending singular value order
        let mut v_data = vec![0.0f64; n * k];
        for col in 0..k {
            let ev_col = n - 1 - col; // reverse index for descending order
            for row in 0..n {
                v_data[row * k + col] = ev_data[row * ev_n + ev_col];
            }
        }
        let v_mat = Tensor::from_vec(v_data.clone(), &[n, k])?;

        // U = A * V * diag(1/s)
        // First compute A * V
        let av = self.matmul(&v_mat)?;
        let av_data = av.to_vec();

        // Then scale each column by 1/s_i
        let mut u_data = vec![0.0f64; m * k];
        for col in 0..k {
            if s[col] > 1e-14 {
                let inv_s = 1.0 / s[col];
                for row in 0..m {
                    u_data[row * k + col] = av_data[row * k + col] * inv_s;
                }
            }
            // If s[col] ≈ 0, leave u column as zeros
        }

        // Sign-canonical: ensure largest-magnitude element of each u column is positive
        for col in 0..k {
            let mut max_abs = 0.0f64;
            let mut max_sign = 1.0f64;
            for row in 0..m {
                let val = u_data[row * k + col];
                if val.abs() > max_abs {
                    max_abs = val.abs();
                    max_sign = if val >= 0.0 { 1.0 } else { -1.0 };
                }
            }
            if max_sign < 0.0 {
                for row in 0..m {
                    u_data[row * k + col] = -u_data[row * k + col];
                }
                for row in 0..n {
                    v_data[row * k + col] = -v_data[row * k + col];
                }
            }
        }

        let u_tensor = Tensor::from_vec(u_data, &[m, k])?;

        // Vt = V^T (k x n)
        let mut vt_data = vec![0.0f64; k * n];
        for row in 0..k {
            for col in 0..n {
                vt_data[row * n + col] = v_data[col * k + row];
            }
        }
        let vt_tensor = Tensor::from_vec(vt_data, &[k, n])?;

        Ok((u_tensor, s, vt_tensor))
    }

    /// Truncated SVD — only the top `k` singular values/vectors.
    /// Returns (U_k, S_k, Vt_k) where U_k is m x k, Vt_k is k x n.
    pub fn svd_truncated(
        &self,
        k: usize,
    ) -> Result<(Tensor, Vec<f64>, Tensor), RuntimeError> {
        let (u_full, s_full, vt_full) = self.svd()?;
        let m = u_full.shape()[0];
        let n = vt_full.shape()[1];
        let actual_k = k.min(s_full.len());

        if actual_k == 0 {
            return Err(RuntimeError::InvalidOperation(
                "svd_truncated: k must be > 0".to_string(),
            ));
        }

        let s_k: Vec<f64> = s_full[..actual_k].to_vec();

        // Extract first k columns of U
        let u_data = u_full.to_vec();
        let u_cols = u_full.shape()[1];
        let mut u_k = vec![0.0f64; m * actual_k];
        for row in 0..m {
            for col in 0..actual_k {
                u_k[row * actual_k + col] = u_data[row * u_cols + col];
            }
        }

        // Extract first k rows of Vt
        let vt_data = vt_full.to_vec();
        let mut vt_k = vec![0.0f64; actual_k * n];
        for row in 0..actual_k {
            for col in 0..n {
                vt_k[row * n + col] = vt_data[row * n + col];
            }
        }

        Ok((
            Tensor::from_vec(u_k, &[m, actual_k])?,
            s_k,
            Tensor::from_vec(vt_k, &[actual_k, n])?,
        ))
    }

    // -----------------------------------------------------------------------
    // Phase 3B: Pseudoinverse (Moore-Penrose, via SVD)
    // -----------------------------------------------------------------------

    /// Compute the Moore-Penrose pseudoinverse via SVD.
    /// A+ = V @ diag(1/s_i) @ Ut (with default tolerance for near-zero singular values).
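    ///
    /// # Example (illustrative sketch)
    ///
    /// A minimal sketch, assuming the crate's `matmul`; for a full-rank
    /// square matrix the pseudoinverse coincides with the inverse, so
    /// `A+ * A ≈ I`:
    ///
    /// ```ignore
    /// let a = Tensor::from_vec(vec![4.0, 7.0, 2.0, 6.0], &[2, 2])?;
    /// let pinv = a.pinv()?;
    /// let prod = pinv.matmul(&a)?;
    /// let p = prod.to_vec();
    /// assert!((p[0] - 1.0).abs() < 1e-8 && p[1].abs() < 1e-8);
    /// ```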
    pub fn pinv(&self) -> Result<Tensor, RuntimeError> {
        // Default tolerance: max(m,n) * eps * max(S)
        let (u, s, vt) = self.svd()?;
        let m = self.shape[0];
        let n = self.shape[1];
        let max_s = s.first().copied().unwrap_or(0.0);
        let tol = (m.max(n) as f64) * f64::EPSILON * max_s;
        Self::pinv_from_svd(&u, &s, &vt, tol)
    }

    /// Compute the Moore-Penrose pseudoinverse via SVD with explicit tolerance.
    pub fn pinv_with_tol(&self, tol: f64) -> Result<Tensor, RuntimeError> {
        let (u, s, vt) = self.svd()?;
        Self::pinv_from_svd(&u, &s, &vt, tol)
    }

    /// Internal: compute pseudoinverse from pre-computed SVD.
    /// A+ = V @ diag(1/s_i) @ Ut, zeroing 1/s_i where s_i <= tol.
    fn pinv_from_svd(
        u: &Tensor,
        s: &[f64],
        vt: &Tensor,
        tol: f64,
    ) -> Result<Tensor, RuntimeError> {
        let m = u.shape()[0];
        let k = s.len();
        let n = vt.shape()[1];

        // Build S_inv: k-vector with 1/s_i or 0
        let s_inv: Vec<f64> = s
            .iter()
            .map(|&si| if si > tol { 1.0 / si } else { 0.0 })
            .collect();

        // Compute Vt^T @ diag(s_inv) @ U^T = V @ diag(s_inv) @ Ut
        // Result is n x m
        let u_data = u.to_vec();
        let vt_data = vt.to_vec();
        let mut result = vec![0.0f64; n * m];

        for i in 0..n {
            for j in 0..m {
                let mut acc = BinnedAccumulatorF64::new();
                for l in 0..k {
                    // V[i, l] = Vt[l, i] (transposed)
                    // Ut[l, j] = U[j, l] (transposed)
                    acc.add(vt_data[l * n + i] * s_inv[l] * u_data[j * k + l]);
                }
                result[i * m + j] = acc.finalize();
            }
        }

        Tensor::from_vec(result, &[n, m])
    }

    /// Helper: compute Givens rotation parameters.
    /// Returns (c, s, r) such that [c s; -s c] * [a; b] = [r; 0].
    fn givens_rotation(a: f64, b: f64) -> (f64, f64, f64) {
        if b.abs() < 1e-15 {
            (1.0, 0.0, a)
        } else if a.abs() < 1e-15 {
            (0.0, if b >= 0.0 { 1.0 } else { -1.0 }, b.abs())
        } else {
            let r = (a * a + b * b).sqrt();
            (a / r, b / r, r)
        }
    }

    /// Matrix exponential via scaling and squaring with Pade(13,13) approximation.
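    ///
    /// # Example (illustrative sketch)
    ///
    /// A minimal sketch; for a diagonal matrix, exp acts entrywise on the
    /// diagonal:
    ///
    /// ```ignore
    /// // exp(diag(0, ln 2)) = diag(1, 2)
    /// let a = Tensor::from_vec(vec![0.0, 0.0, 0.0, 2.0_f64.ln()], &[2, 2])?;
    /// let e = a.matrix_exp()?;
    /// let ed = e.to_vec();
    /// assert!((ed[0] - 1.0).abs() < 1e-12);
    /// assert!((ed[3] - 2.0).abs() < 1e-12);
    /// ```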
    pub fn matrix_exp(&self) -> Result<Tensor, RuntimeError> {
        if self.ndim() != 2 || self.shape[0] != self.shape[1] {
            return Err(RuntimeError::InvalidOperation(
                "matrix_exp requires a square 2D matrix".to_string(),
            ));
        }
        let n = self.shape[0];
        if n == 0 {
            return Err(RuntimeError::InvalidOperation("matrix_exp: empty matrix".to_string()));
        }

        const PADE_COEFFS: [f64; 14] = [
            64764752532480000.0,
            32382376266240000.0,
            7771770303897600.0,
            1187353796428800.0,
            129060195264000.0,
            10559470521600.0,
            670442572800.0,
            33522128640.0,
            1323241920.0,
            40840800.0,
            960960.0,
            16380.0,
            182.0,
            1.0,
        ];
        const THETA_13: f64 = 5.371920351148152;

        // Scaling: s = max(0, ceil(log2(||A||_1 / theta_13)))
        let norm = self.norm_1()?;
        let s = if norm <= THETA_13 {
            0u32
        } else {
            (norm / THETA_13).log2().ceil() as u32
        };

        // B = A / 2^s
        let scale = 2.0_f64.powi(-(s as i32));
        let b_data: Vec<f64> = self.to_vec().iter().map(|&x| x * scale).collect();
        let b = Tensor::from_vec(b_data, &[n, n])?;

        // Compute B^2, B^4, B^6
        let b2 = b.matmul(&b)?;
        let b4 = b2.matmul(&b2)?;
        let b6 = b4.matmul(&b2)?;

        // Identity matrix
        let mut eye = vec![0.0; n * n];
        for i in 0..n {
            eye[i * n + i] = 1.0;
        }
        let eye_t = Tensor::from_vec(eye, &[n, n])?;

        // Build U and V
        // U = B * (b_13*B^6 + b_11*B^4 + b_9*B^2 + b_7*I) * B^6
        //   + B * (b_5*B^4 + b_3*B^2 + b_1*I)
        // V = (b_12*B^6 + b_10*B^4 + b_8*B^2 + b_6*I) * B^6
        //   + (b_4*B^4 + b_2*B^2 + b_0*I)

        // Helper: scale and add tensors
        fn scale_add(a: &Tensor, sa: f64, b: &Tensor, sb: f64, n: usize) -> Vec<f64> {
            let ad = a.to_vec();
            let bd = b.to_vec();
            let mut r = vec![0.0; n * n];
            for i in 0..n * n {
                r[i] = sa * ad[i] + sb * bd[i];
            }
            r
        }

        let c = &PADE_COEFFS;

        // inner_u1 = b_13*B6 + b_11*B4 + b_9*B2 + b_7*I
        let mut iu1 = scale_add(&b6, c[13], &b4, c[11], n);
        let t = scale_add(&b2, c[9], &eye_t, c[7], n);
        for i in 0..n * n {
            iu1[i] += t[i];
        }
        let iu1_t = Tensor::from_vec(iu1, &[n, n])?;
        let iu1_b6 = iu1_t.matmul(&b6)?;

        // inner_u2 = b_5*B4 + b_3*B2 + b_1*I
        let mut iu2 = scale_add(&b4, c[5], &b2, c[3], n);
        let t = Tensor::from_vec({
            let mut v = vec![0.0; n * n];
            for i in 0..n { v[i * n + i] = c[1]; }
            v
        }, &[n, n])?;
        let td = t.to_vec();
        for i in 0..n * n {
            iu2[i] += td[i];
        }
        let iu2_t = Tensor::from_vec(iu2, &[n, n])?;

        // U_inner = iu1_b6 + iu2
        let iu1d = iu1_b6.to_vec();
        let iu2d = iu2_t.to_vec();
        let mut u_inner = vec![0.0; n * n];
        for i in 0..n * n {
            u_inner[i] = iu1d[i] + iu2d[i];
        }
        let u_inner_t = Tensor::from_vec(u_inner, &[n, n])?;
        let u = b.matmul(&u_inner_t)?;

        // inner_v1 = b_12*B6 + b_10*B4 + b_8*B2 + b_6*I
        let mut iv1 = scale_add(&b6, c[12], &b4, c[10], n);
        let t = scale_add(&b2, c[8], &eye_t, c[6], n);
        for i in 0..n * n {
            iv1[i] += t[i];
        }
        let iv1_t = Tensor::from_vec(iv1, &[n, n])?;
        let iv1_b6 = iv1_t.matmul(&b6)?;

        // inner_v2 = b_4*B4 + b_2*B2 + b_0*I
        let mut iv2 = scale_add(&b4, c[4], &b2, c[2], n);
        let t = Tensor::from_vec({
            let mut v = vec![0.0; n * n];
            for i in 0..n { v[i * n + i] = c[0]; }
            v
        }, &[n, n])?;
        let td = t.to_vec();
        for i in 0..n * n {
            iv2[i] += td[i];
        }
        let iv2_t = Tensor::from_vec(iv2, &[n, n])?;

        // V = iv1_b6 + iv2
        let iv1d = iv1_b6.to_vec();
        let iv2d = iv2_t.to_vec();
        let mut v_data = vec![0.0; n * n];
        for i in 0..n * n {
            v_data[i] = iv1d[i] + iv2d[i];
        }
        let v_mat = Tensor::from_vec(v_data, &[n, n])?;

        // Solve (V - U) * r = (V + U)
        let ud = u.to_vec();
        let vd = v_mat.to_vec();
        let mut lhs_data = vec![0.0; n * n];
        let mut rhs_data = vec![0.0; n * n];
        for i in 0..n * n {
            lhs_data[i] = vd[i] - ud[i];
            rhs_data[i] = vd[i] + ud[i];
        }
        let lhs = Tensor::from_vec(lhs_data, &[n, n])?;

        // Solve column-by-column
        let mut result = vec![0.0; n * n];
        for col in 0..n {
            let mut rhs_col = vec![0.0; n];
            for row in 0..n {
                rhs_col[row] = rhs_data[row * n + col];
            }
            let rhs_tensor = Tensor::from_vec(rhs_col, &[n])?;
            let sol = lhs.solve(&rhs_tensor)?;
            let sol_data = sol.to_vec();
            for row in 0..n {
                result[row * n + col] = sol_data[row];
            }
        }

        let mut r = Tensor::from_vec(result, &[n, n])?;

        // Square s times: r = r * r
        for _ in 0..s {
            r = r.matmul(&r)?;
        }

        Ok(r)
    }
1442}
1443
1444// ---------------------------------------------------------------------------
1445// Tests
1446// ---------------------------------------------------------------------------
1447
1448#[cfg(test)]
1449mod tests {
1450    use super::*;
1451
    /// Helper: reconstruct A from SVD: U @ diag(S) @ Vt
    fn reconstruct_svd(u: &Tensor, s: &[f64], vt: &Tensor) -> Vec<f64> {
        let m = u.shape()[0];
        let k = s.len();
        let n = vt.shape()[1];
        let u_data = u.to_vec();
        let vt_data = vt.to_vec();
        let u_cols = u.shape()[1];
        let mut result = vec![0.0f64; m * n];
        for i in 0..m {
            for j in 0..n {
                let mut sum = 0.0;
                for l in 0..k {
                    sum += u_data[i * u_cols + l] * s[l] * vt_data[l * n + j];
                }
                result[i * n + j] = sum;
            }
        }
        result
    }

    /// Helper: check two flat arrays are approximately equal
    fn assert_approx_eq(a: &[f64], b: &[f64], tol: f64, msg: &str) {
        assert_eq!(a.len(), b.len(), "{}: length mismatch", msg);
        for (i, (&ai, &bi)) in a.iter().zip(b.iter()).enumerate() {
            assert!(
                (ai - bi).abs() < tol,
                "{}: element [{}] differs: {} vs {} (diff={})",
                msg,
                i,
                ai,
                bi,
                (ai - bi).abs()
            );
        }
    }

    #[test]
    fn test_svd_identity_2x2() {
        let eye = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], &[2, 2]).unwrap();
        let (u, s, vt) = eye.svd().unwrap();
        // Singular values should be [1, 1]
        assert!((s[0] - 1.0).abs() < 1e-10, "s[0] = {}", s[0]);
        assert!((s[1] - 1.0).abs() < 1e-10, "s[1] = {}", s[1]);
        // Roundtrip
        let recon = reconstruct_svd(&u, &s, &vt);
        assert_approx_eq(&recon, &[1.0, 0.0, 0.0, 1.0], 1e-10, "SVD identity roundtrip");
    }

    #[test]
    fn test_svd_identity_3x3() {
        let mut data = vec![0.0; 9];
        for i in 0..3 {
            data[i * 3 + i] = 1.0;
        }
        let eye = Tensor::from_vec(data.clone(), &[3, 3]).unwrap();
        let (u, s, vt) = eye.svd().unwrap();
        for &si in &s {
            assert!((si - 1.0).abs() < 1e-10, "singular value = {}", si);
        }
        let recon = reconstruct_svd(&u, &s, &vt);
        assert_approx_eq(&recon, &data, 1e-10, "SVD 3x3 identity roundtrip");
    }

    #[test]
    fn test_svd_known_matrix() {
        // A = [[3, 0], [0, 2]]: singular values should be 3 and 2
        let a = Tensor::from_vec(vec![3.0, 0.0, 0.0, 2.0], &[2, 2]).unwrap();
        let (_u, s, _vt) = a.svd().unwrap();
        assert!((s[0] - 3.0).abs() < 1e-10, "s[0] = {}", s[0]);
        assert!((s[1] - 2.0).abs() < 1e-10, "s[1] = {}", s[1]);
    }

    #[test]
    fn test_svd_roundtrip_general() {
        let a = Tensor::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.5],
            &[3, 3],
        )
        .unwrap();
        let (u, s, vt) = a.svd().unwrap();
        let recon = reconstruct_svd(&u, &s, &vt);
        let original = a.to_vec();
        assert_approx_eq(&recon, &original, 1e-8, "SVD general roundtrip");
    }

    #[test]
    fn test_svd_rectangular_tall() {
        // 3x2 matrix
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap();
        let (u, s, vt) = a.svd().unwrap();
        assert_eq!(u.shape(), &[3, 2]);
        assert_eq!(s.len(), 2);
        assert_eq!(vt.shape(), &[2, 2]);
        let recon = reconstruct_svd(&u, &s, &vt);
        let original = a.to_vec();
        assert_approx_eq(&recon, &original, 1e-8, "SVD tall rectangular roundtrip");
    }

    #[test]
    fn test_svd_rectangular_wide() {
        // 2x3 matrix
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let (u, s, vt) = a.svd().unwrap();
        assert_eq!(u.shape(), &[2, 2]);
        assert_eq!(s.len(), 2);
        assert_eq!(vt.shape(), &[2, 3]);
        let recon = reconstruct_svd(&u, &s, &vt);
        let original = a.to_vec();
        assert_approx_eq(&recon, &original, 1e-8, "SVD wide rectangular roundtrip");
    }

    #[test]
    fn test_svd_singular_values_descending() {
        let a = Tensor::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0],
            &[3, 3],
        )
        .unwrap();
        let (_, s, _) = a.svd().unwrap();
        for i in 0..s.len() - 1 {
            assert!(s[i] >= s[i + 1], "singular values not descending: {} < {}", s[i], s[i + 1]);
        }
    }
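
    // Added check (a sketch, not part of the original suite): the columns of
    // U from a full SVD should be orthonormal, i.e. U^T U = I.
    #[test]
    fn test_svd_u_orthonormal_columns() {
        let a = Tensor::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0],
            &[3, 3],
        )
        .unwrap();
        let (u, _, _) = a.svd().unwrap();
        let ud = u.to_vec();
        let n = 3;
        for c1 in 0..n {
            for c2 in 0..n {
                let mut dot = 0.0;
                for r in 0..n {
                    dot += ud[r * n + c1] * ud[r * n + c2];
                }
                let expected = if c1 == c2 { 1.0 } else { 0.0 };
                assert!(
                    (dot - expected).abs() < 1e-8,
                    "U^T U [{}, {}] = {}",
                    c1, c2, dot
                );
            }
        }
    }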

    #[test]
    fn test_svd_truncated_basic() {
        let a = Tensor::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0],
            &[3, 3],
        )
        .unwrap();
        let (u, s, vt) = a.svd_truncated(2).unwrap();
        assert_eq!(u.shape(), &[3, 2]);
        assert_eq!(s.len(), 2);
        assert_eq!(vt.shape(), &[2, 3]);
    }
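
    // Added check (a sketch): the k singular values from `svd_truncated(k)`
    // should agree with the leading k values of the full `svd`, assuming the
    // truncated path computes the same dominant subspace.
    #[test]
    fn test_svd_truncated_matches_full() {
        let a = Tensor::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0],
            &[3, 3],
        )
        .unwrap();
        let (_, s_full, _) = a.svd().unwrap();
        let (_, s_trunc, _) = a.svd_truncated(2).unwrap();
        for i in 0..2 {
            assert!(
                (s_full[i] - s_trunc[i]).abs() < 1e-8,
                "truncated s[{}] = {} vs full {}",
                i, s_trunc[i], s_full[i]
            );
        }
    }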

    #[test]
    fn test_svd_deterministic() {
        let a = Tensor::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0],
            &[3, 3],
        )
        .unwrap();
        let (u1, s1, vt1) = a.svd().unwrap();
        let (u2, s2, vt2) = a.svd().unwrap();
        assert_eq!(u1.to_vec(), u2.to_vec(), "U not deterministic");
        assert_eq!(s1, s2, "S not deterministic");
        assert_eq!(vt1.to_vec(), vt2.to_vec(), "Vt not deterministic");
    }

    #[test]
    fn test_pinv_square() {
        // For a non-singular square matrix, pinv(A) ≈ inv(A)
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 5.0], &[2, 2]).unwrap();
        let a_pinv = a.pinv().unwrap();
        // Check A @ A+ @ A ≈ A
        let a_ap = a.matmul(&a_pinv).unwrap();
        let a_ap_a = a_ap.matmul(&a).unwrap();
        assert_approx_eq(&a_ap_a.to_vec(), &a.to_vec(), 1e-8, "pinv square: A @ A+ @ A ≈ A");
    }
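
    // Added check (a sketch): for an invertible matrix the pseudoinverse
    // coincides with the true inverse, so A @ A+ should be the identity —
    // a stronger check than the first Penrose condition alone.
    #[test]
    fn test_pinv_square_is_inverse() {
        // det([[1, 2], [3, 5]]) = -1, so A is invertible.
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 5.0], &[2, 2]).unwrap();
        let ap = a.pinv().unwrap();
        let prod = a.matmul(&ap).unwrap();
        assert_approx_eq(&prod.to_vec(), &[1.0, 0.0, 0.0, 1.0], 1e-8, "A @ A+ ≈ I");
    }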

    #[test]
    fn test_pinv_identity() {
        let eye = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], &[2, 2]).unwrap();
        let eye_pinv = eye.pinv().unwrap();
        assert_approx_eq(
            &eye_pinv.to_vec(),
            &[1.0, 0.0, 0.0, 1.0],
            1e-10,
            "pinv of identity",
        );
    }

    #[test]
    fn test_pinv_rectangular() {
        // Tall matrix: 3x2
        let a = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0, 0.0, 0.0], &[3, 2]).unwrap();
        let a_pinv = a.pinv().unwrap();
        assert_eq!(a_pinv.shape(), &[2, 3]);
        // A @ A+ @ A ≈ A
        let a_ap = a.matmul(&a_pinv).unwrap();
        let a_ap_a = a_ap.matmul(&a).unwrap();
        assert_approx_eq(&a_ap_a.to_vec(), &a.to_vec(), 1e-8, "pinv rect: A @ A+ @ A ≈ A");
    }

    #[test]
    fn test_pinv_moore_penrose_conditions() {
        // Moore-Penrose conditions 1 and 2 for a general matrix; the two
        // symmetry conditions (3 and 4) are checked in the next test.
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let ap = a.pinv().unwrap();
        // Condition 1: A @ A+ @ A ≈ A
        let aapa = a.matmul(&ap).unwrap().matmul(&a).unwrap();
        assert_approx_eq(&aapa.to_vec(), &a.to_vec(), 1e-6, "MP condition 1");
        // Condition 2: A+ @ A @ A+ ≈ A+
        let apaap = ap.matmul(&a).unwrap().matmul(&ap).unwrap();
        assert_approx_eq(&apaap.to_vec(), &ap.to_vec(), 1e-6, "MP condition 2");
    }
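
    // Added check (a sketch): Moore-Penrose conditions 3 and 4, i.e. that
    // A @ A+ and A+ @ A are symmetric. Checked elementwise since no transpose
    // helper is assumed here.
    #[test]
    fn test_pinv_moore_penrose_symmetry() {
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let ap = a.pinv().unwrap();
        // Condition 3: A @ A+ is symmetric (2x2)
        let aap = a.matmul(&ap).unwrap().to_vec();
        for i in 0..2 {
            for j in 0..2 {
                assert!((aap[i * 2 + j] - aap[j * 2 + i]).abs() < 1e-6, "MP condition 3");
            }
        }
        // Condition 4: A+ @ A is symmetric (3x3)
        let apa = ap.matmul(&a).unwrap().to_vec();
        for i in 0..3 {
            for j in 0..3 {
                assert!((apa[i * 3 + j] - apa[j * 3 + i]).abs() < 1e-6, "MP condition 4");
            }
        }
    }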

    #[test]
    fn test_pinv_with_tol() {
        let a = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1e-16], &[2, 2]).unwrap();
        // With large tolerance, treat 1e-16 as zero
        let ap = a.pinv_with_tol(1e-10).unwrap();
        let ap_data = ap.to_vec();
        // Should act like pseudoinverse of [[1,0],[0,0]]
        assert!((ap_data[0] - 1.0).abs() < 1e-8, "pinv_with_tol [0,0]");
        assert!(ap_data[3].abs() < 1e-8, "pinv_with_tol [1,1] should be ~0");
    }

    #[test]
    fn test_svd_1x1() {
        let a = Tensor::from_vec(vec![5.0], &[1, 1]).unwrap();
        let (u, s, vt) = a.svd().unwrap();
        assert!((s[0] - 5.0).abs() < 1e-10);
        let recon = reconstruct_svd(&u, &s, &vt);
        assert_approx_eq(&recon, &[5.0], 1e-10, "SVD 1x1 roundtrip");
    }
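
    // Added check (a sketch): a rank-1 matrix. [[1, 2], [2, 4]] has A^T A
    // eigenvalues 25 and 0, so the singular values should be 5 and 0.
    #[test]
    fn test_svd_rank_deficient() {
        let a = Tensor::from_vec(vec![1.0, 2.0, 2.0, 4.0], &[2, 2]).unwrap();
        let (u, s, vt) = a.svd().unwrap();
        assert!((s[0] - 5.0).abs() < 1e-8, "s[0] = {}", s[0]);
        assert!(s[1].abs() < 1e-8, "s[1] = {}", s[1]);
        let recon = reconstruct_svd(&u, &s, &vt);
        assert_approx_eq(&recon, &[1.0, 2.0, 2.0, 4.0], 1e-8, "rank-deficient roundtrip");
    }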
}