//! Dense linear algebra operations on [`Tensor`].
//!
//! Provides matrix decompositions (LU, QR, Cholesky, Schur, SVD), solvers
//! (direct and least-squares), eigenvalue routines, norms, and related
//! utilities. All floating-point reductions use [`BinnedAccumulatorF64`] or
//! Kahan summation for deterministic, order-invariant results.
//!
//! # Determinism Contract
//!
//! - Pivot selection uses strict `>` on absolute values with lowest-index
//!   tie-breaking.
//! - Eigenvalue / SVD sign-canonical: the first nonzero element of each
//!   eigenvector / singular vector is forced positive.
//! - No `HashMap`, no parallel iteration, no OS randomness.

use crate::accumulator::BinnedAccumulatorF64;
use crate::error::RuntimeError;
use crate::tensor::Tensor;

// ---------------------------------------------------------------------------
// 6. Linalg Operations on Tensor
// ---------------------------------------------------------------------------

impl Tensor {
    /// Compute the LU decomposition with partial pivoting.
    ///
    /// Returns `(L, U, pivot_indices)` where `P * A = L * U` and `pivot_indices`
    /// encodes the row permutation `P`.
    ///
    /// # Arguments
    ///
    /// * `self` - A square 2-D [`Tensor`] (n x n).
    ///
    /// # Returns
    ///
    /// * `L` - Lower-triangular matrix with unit diagonal.
    /// * `U` - Upper-triangular matrix.
    /// * `pivot_indices` - Permutation vector of length n.
    ///
    /// # Errors
    ///
    /// Returns [`RuntimeError::InvalidOperation`] if the matrix is not square
    /// 2-D or is singular.
    ///
    /// # Determinism
    ///
    /// Pivot selection uses strict `>` comparison on absolute values; when two
    /// candidates have identical absolute values, the first (lowest row index)
    /// is chosen. The result is deterministic given identical input bits.
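    ///
    /// # Example
    ///
    /// A minimal sketch (assuming `Tensor` is re-exported at the crate root):
    ///
    /// ```
    /// use cjc_runtime::Tensor;
    ///
    /// // |6| > |4|, so row 1 is pivoted to the top.
    /// let a = Tensor::from_vec(vec![4.0, 3.0, 6.0, 3.0], &[2, 2]).unwrap();
    /// let (_l, _u, pivots) = a.lu_decompose().unwrap();
    /// assert_eq!(pivots, vec![1, 0]);
    /// ```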
    pub fn lu_decompose(&self) -> Result<(Tensor, Tensor, Vec<usize>), RuntimeError> {
        if self.ndim() != 2 || self.shape[0] != self.shape[1] {
            return Err(RuntimeError::InvalidOperation(
                "LU decomposition requires a square 2D matrix".to_string(),
            ));
        }
        let n = self.shape[0];
        let mut a = self.to_vec();
        let mut pivots: Vec<usize> = (0..n).collect();

        for k in 0..n {
            // Find pivot
            let mut max_val = a[k * n + k].abs();
            let mut max_row = k;
            for i in (k + 1)..n {
                let v = a[i * n + k].abs();
                if v > max_val {
                    max_val = v;
                    max_row = i;
                }
            }
            if max_val < 1e-15 {
                return Err(RuntimeError::InvalidOperation(
                    "LU decomposition: singular matrix".to_string(),
                ));
            }
            if max_row != k {
                pivots.swap(k, max_row);
                for j in 0..n {
                    let tmp = a[k * n + j];
                    a[k * n + j] = a[max_row * n + j];
                    a[max_row * n + j] = tmp;
                }
            }
            for i in (k + 1)..n {
                a[i * n + k] /= a[k * n + k];
                for j in (k + 1)..n {
                    a[i * n + j] -= a[i * n + k] * a[k * n + j];
                }
            }
        }

        // Extract L and U
        let mut l_data = vec![0.0f64; n * n];
        let mut u_data = vec![0.0f64; n * n];
        for i in 0..n {
            for j in 0..n {
                if i == j {
                    l_data[i * n + j] = 1.0;
                    u_data[i * n + j] = a[i * n + j];
                } else if i > j {
                    l_data[i * n + j] = a[i * n + j];
                } else {
                    u_data[i * n + j] = a[i * n + j];
                }
            }
        }

        Ok((
            Tensor::from_vec(l_data, &[n, n])?,
            Tensor::from_vec(u_data, &[n, n])?,
            pivots,
        ))
    }

    /// Compute the QR decomposition via Householder reflections.
    ///
    /// Returns `(Q, R)` where `A = Q * R`, `Q` is orthogonal (m x min(m,n)),
    /// and `R` is upper-triangular (min(m,n) x n).
    ///
    /// # Errors
    ///
    /// Returns [`RuntimeError::InvalidOperation`] if the tensor is not 2-D.
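    ///
    /// # Example
    ///
    /// A minimal sketch (assuming `Tensor` is re-exported at the crate root):
    ///
    /// ```
    /// use cjc_runtime::Tensor;
    ///
    /// let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap();
    /// let (q, r) = a.qr_decompose().unwrap();
    /// assert_eq!(q.shape()[0], 3); // Q is m x min(m,n) = 3 x 2
    /// assert_eq!(r.shape()[1], 2); // R is min(m,n) x n = 2 x 2
    /// ```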
    pub fn qr_decompose(&self) -> Result<(Tensor, Tensor), RuntimeError> {
        if self.ndim() != 2 {
            return Err(RuntimeError::InvalidOperation(
                "QR decomposition requires a 2D matrix".to_string(),
            ));
        }
        let m = self.shape[0];
        let n = self.shape[1];
        let min_mn = m.min(n);

        // Work on a column-major copy for better locality in Householder reflections.
        let row_major = self.to_vec();
        let mut cm = vec![0.0f64; m * n]; // column-major
        for i in 0..m {
            for j in 0..n {
                cm[j * m + i] = row_major[i * n + j];
            }
        }

        // Store Householder vectors and tau values
        let mut tau = vec![0.0f64; min_mn];

        for k in 0..min_mn {
            // Compute norm of column k below diagonal
            let mut sigma = 0.0f64;
            for i in (k + 1)..m {
                sigma += cm[k * m + i] * cm[k * m + i];
            }
            let x0 = cm[k * m + k];
            let norm_x = (x0 * x0 + sigma).sqrt();

            if norm_x < 1e-15 {
                tau[k] = 0.0;
                continue;
            }

            // Householder vector: v[k] = x[k] - alpha, v[k+1:] = x[k+1:]
            let alpha = if x0 >= 0.0 { -norm_x } else { norm_x };
            cm[k * m + k] -= alpha;
            let v_norm_sq = cm[k * m + k] * cm[k * m + k] + sigma;
            tau[k] = if v_norm_sq > 1e-30 { 2.0 / v_norm_sq } else { 0.0 };

            // Apply H = I - tau * v * v^T to columns k+1..n
            for j in (k + 1)..n {
                let mut dot = 0.0f64;
                for i in k..m {
                    dot += cm[k * m + i] * cm[j * m + i];
                }
                let scale = tau[k] * dot;
                for i in k..m {
                    cm[j * m + i] -= scale * cm[k * m + i];
                }
            }

            // Store alpha on the diagonal (R[k,k])
            cm[k * m + k] = alpha;
        }

        // Extract R (min_mn × n) from upper triangle of cm (column-major)
        let mut r_data = vec![0.0f64; min_mn * n];
        for j in 0..n {
            for i in 0..min_mn.min(j + 1) {
                r_data[i * n + j] = cm[j * m + i];
            }
        }

        // Reconstruct Q (m × min_mn) from Householder vectors
        // Start with identity, apply reflectors in reverse order
        let mut q_cm = vec![0.0f64; m * min_mn]; // column-major
        for i in 0..min_mn { q_cm[i * m + i] = 1.0; }

        for k in (0..min_mn).rev() {
            if tau[k] == 0.0 { continue; }
            // Reconstruct v. The diagonal entry v[k] was overwritten with
            // alpha (R[k,k]), but it can be recovered: tau[k] = 2 / v_norm_sq,
            // so v_norm_sq = 2 / tau[k], and v[k]^2 = v_norm_sq - sigma where
            // sigma = sum(cm[k*m+i]^2 for i in k+1..m). The tail v[k+1..m] is
            // still stored in cm[k*m+k+1 .. k*m+m].
            let mut sigma2 = 0.0f64;
            for i in (k + 1)..m {
                sigma2 += cm[k * m + i] * cm[k * m + i];
            }
            let v_norm_sq = 2.0 / tau[k];
            let vk_sq = v_norm_sq - sigma2;
            let vk = if vk_sq > 0.0 { vk_sq.sqrt() } else { 0.0 };
            // Determine sign: v[k] = x[k] - alpha, both known at construction.
            // Since we set v[k] = x[k] - alpha and alpha has opposite sign to x[k],
            // v[k] is always positive (|x[k]| + |alpha| > 0).

            // Apply H_k to Q columns
            for j in k..min_mn {
                let mut dot = vk * q_cm[j * m + k];
                for i in (k + 1)..m {
                    dot += cm[k * m + i] * q_cm[j * m + i];
                }
                let scale = tau[k] * dot;
                q_cm[j * m + k] -= scale * vk;
                for i in (k + 1)..m {
                    q_cm[j * m + i] -= scale * cm[k * m + i];
                }
            }
        }

        // Convert Q from column-major to row-major
        let mut q_data = vec![0.0f64; m * min_mn];
        for i in 0..m {
            for j in 0..min_mn {
                q_data[i * min_mn + j] = q_cm[j * m + i];
            }
        }

        Ok((
            Tensor::from_vec(q_data, &[m, min_mn])?,
            Tensor::from_vec(r_data, &[min_mn, n])?,
        ))
    }

    /// Compute the Cholesky decomposition: `A = L * L^T`.
    ///
    /// Returns the lower-triangular factor `L`.
    ///
    /// # Errors
    ///
    /// Returns [`RuntimeError::InvalidOperation`] if the matrix is not square
    /// 2-D or is not positive definite.
    ///
    /// # Determinism
    ///
    /// Inner-loop summation uses Kahan compensation with fixed iteration order.
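    ///
    /// # Example
    ///
    /// A minimal sketch (assuming `Tensor` is re-exported at the crate root):
    ///
    /// ```
    /// use cjc_runtime::Tensor;
    ///
    /// // A = [[4, 2], [2, 3]] factors with L = [[2, 0], [1, sqrt(2)]].
    /// let a = Tensor::from_vec(vec![4.0, 2.0, 2.0, 3.0], &[2, 2]).unwrap();
    /// let l = a.cholesky().unwrap().to_vec();
    /// assert!((l[0] - 2.0).abs() < 1e-12 && (l[2] - 1.0).abs() < 1e-12);
    /// ```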
    pub fn cholesky(&self) -> Result<Tensor, RuntimeError> {
        if self.ndim() != 2 || self.shape[0] != self.shape[1] {
            return Err(RuntimeError::InvalidOperation(
                "Cholesky decomposition requires a square 2D matrix".to_string(),
            ));
        }
        let n = self.shape[0];
        let a = self.to_vec();
        let mut l = vec![0.0f64; n * n];

        for j in 0..n {
            // Kahan summation: lightweight (16 bytes) vs BinnedAccumulator (32KB).
            // Deterministic because iteration order is fixed (0..j).
            let mut sum = 0.0f64;
            let mut comp = 0.0f64;
            for k in 0..j {
                let y = l[j * n + k] * l[j * n + k] - comp;
                let t = sum + y;
                comp = (t - sum) - y;
                sum = t;
            }
            let diag = a[j * n + j] - sum;
            if diag <= 0.0 {
                return Err(RuntimeError::InvalidOperation(
                    "Cholesky: matrix is not positive definite".to_string(),
                ));
            }
            l[j * n + j] = diag.sqrt();

            for i in (j + 1)..n {
                let mut s = 0.0f64;
                let mut c = 0.0f64;
                for k in 0..j {
                    let y = l[i * n + k] * l[j * n + k] - c;
                    let t = s + y;
                    c = (t - s) - y;
                    s = t;
                }
                l[i * n + j] = (a[i * n + j] - s) / l[j * n + j];
            }
        }

        Tensor::from_vec(l, &[n, n])
    }

    /// Compute the determinant via LU decomposition.
    ///
    /// Returns the product of the `U` diagonal elements multiplied by the
    /// permutation parity sign. Returns `0.0` for singular matrices.
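    ///
    /// # Example
    ///
    /// A minimal sketch (assuming `Tensor` is re-exported at the crate root):
    ///
    /// ```
    /// use cjc_runtime::Tensor;
    ///
    /// // det([[1, 2], [3, 4]]) = 1*4 - 2*3 = -2
    /// let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
    /// assert!((a.det().unwrap() + 2.0).abs() < 1e-12);
    /// ```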
    pub fn det(&self) -> Result<f64, RuntimeError> {
        if self.ndim() != 2 || self.shape[0] != self.shape[1] {
            return Err(RuntimeError::InvalidOperation(
                "det requires a square 2D matrix".to_string(),
            ));
        }
        let n = self.shape[0];
        let mut a = self.to_vec();
        let mut sign = 1.0f64;
        for k in 0..n {
            let mut max_val = a[k * n + k].abs();
            let mut max_row = k;
            for i in (k + 1)..n {
                let v = a[i * n + k].abs();
                if v > max_val {
                    max_val = v;
                    max_row = i;
                }
            }
            if max_val < 1e-15 {
                return Ok(0.0); // singular
            }
            if max_row != k {
                sign = -sign;
                for j in 0..n {
                    let tmp = a[k * n + j];
                    a[k * n + j] = a[max_row * n + j];
                    a[max_row * n + j] = tmp;
                }
            }
            for i in (k + 1)..n {
                a[i * n + k] /= a[k * n + k];
                for j in (k + 1)..n {
                    a[i * n + j] -= a[i * n + k] * a[k * n + j];
                }
            }
        }
        let mut det = sign;
        for i in 0..n {
            det *= a[i * n + i];
        }
        Ok(det)
    }

    /// Solve the linear system `A * x = b` via LU decomposition with partial pivoting.
    ///
    /// # Arguments
    ///
    /// * `self` - Coefficient matrix `A` (n x n).
    /// * `b` - Right-hand side vector (length n).
    ///
    /// # Returns
    ///
    /// Solution vector `x` as a 1-D [`Tensor`].
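    ///
    /// # Example
    ///
    /// A minimal sketch (assuming `Tensor` is re-exported at the crate root):
    ///
    /// ```
    /// use cjc_runtime::Tensor;
    ///
    /// // [[3, 1], [1, 2]] * [2, 3] = [9, 8]
    /// let a = Tensor::from_vec(vec![3.0, 1.0, 1.0, 2.0], &[2, 2]).unwrap();
    /// let b = Tensor::from_vec(vec![9.0, 8.0], &[2]).unwrap();
    /// let x = a.solve(&b).unwrap().to_vec();
    /// assert!((x[0] - 2.0).abs() < 1e-12 && (x[1] - 3.0).abs() < 1e-12);
    /// ```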
    pub fn solve(&self, b: &Tensor) -> Result<Tensor, RuntimeError> {
        if self.ndim() != 2 || self.shape[0] != self.shape[1] {
            return Err(RuntimeError::InvalidOperation(
                "solve requires a square 2D matrix A".to_string(),
            ));
        }
        let n = self.shape[0];
        if b.len() != n {
            return Err(RuntimeError::InvalidOperation(
                format!("solve: b length {} != n = {n}", b.len()),
            ));
        }
        let (l, u, pivots) = self.lu_decompose()?;
        let l_data = l.to_vec();
        let u_data = u.to_vec();
        let b_data = b.to_vec();

        // Permute b
        let mut pb = vec![0.0; n];
        for i in 0..n {
            pb[i] = b_data[pivots[i]];
        }

        // Forward substitution: L * y = pb
        let mut y = vec![0.0; n];
        for i in 0..n {
            let mut s = pb[i];
            for j in 0..i {
                s -= l_data[i * n + j] * y[j];
            }
            y[i] = s;
        }

        // Back substitution: U * x = y
        let mut x = vec![0.0; n];
        for i in (0..n).rev() {
            let mut s = y[i];
            for j in (i + 1)..n {
                s -= u_data[i * n + j] * x[j];
            }
            x[i] = s / u_data[i * n + i];
        }

        Tensor::from_vec(x, &[n])
    }

    /// Compute the ordinary least-squares solution minimizing `||A*x - b||_2`.
    ///
    /// Uses QR decomposition for numerical stability. The dot products for
    /// `Q^T * b` use [`BinnedAccumulatorF64`] for determinism.
    ///
    /// # Arguments
    ///
    /// * `self` - Design matrix `A` (m x n, m >= n).
    /// * `b` - Observation vector (length m).
    ///
    /// # Returns
    ///
    /// Solution vector `x` as a 1-D [`Tensor`] of length n.
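    ///
    /// # Example
    ///
    /// A minimal sketch (assuming `Tensor` is re-exported at the crate root):
    ///
    /// ```
    /// use cjc_runtime::Tensor;
    ///
    /// // Fit y = c * x through (1, 2) and (2, 4): c = 2 exactly.
    /// let a = Tensor::from_vec(vec![1.0, 2.0], &[2, 1]).unwrap();
    /// let b = Tensor::from_vec(vec![2.0, 4.0], &[2]).unwrap();
    /// let x = a.lstsq(&b).unwrap().to_vec();
    /// assert!((x[0] - 2.0).abs() < 1e-12);
    /// ```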
    pub fn lstsq(&self, b: &Tensor) -> Result<Tensor, RuntimeError> {
        if self.ndim() != 2 {
            return Err(RuntimeError::InvalidOperation(
                "lstsq requires a 2D matrix".to_string(),
            ));
        }
        let m = self.shape[0];
        let n = self.shape[1];
        if m < n {
            return Err(RuntimeError::InvalidOperation(
                "lstsq requires m >= n".to_string(),
            ));
        }
        if b.len() != m {
            return Err(RuntimeError::InvalidOperation(
                format!("lstsq: b length {} != m = {m}", b.len()),
            ));
        }
        let (q, r) = self.qr_decompose()?;
        let q_data = q.to_vec();
        let r_data = r.to_vec();
        let b_data = b.to_vec();

        // Q^T * b (Q is m x n, so Q^T * b gives n-vector)
        // Use BinnedAccumulator for deterministic dot products.
        let mut qtb = vec![0.0; n];
        for j in 0..n {
            let mut acc = BinnedAccumulatorF64::new();
            for i in 0..m {
                acc.add(q_data[i * n + j] * b_data[i]);
            }
            qtb[j] = acc.finalize();
        }

        // Back substitution: R * x = Q^T * b
        let mut x = vec![0.0; n];
        for i in (0..n).rev() {
            let mut s = qtb[i];
            for j in (i + 1)..n {
                s -= r_data[i * n + j] * x[j];
            }
            if r_data[i * n + i].abs() < 1e-15 {
                return Err(RuntimeError::InvalidOperation(
                    "lstsq: rank-deficient matrix".to_string(),
                ));
            }
            x[i] = s / r_data[i * n + i];
        }

        Tensor::from_vec(x, &[n])
    }

    /// Compute the matrix trace (sum of diagonal elements) using [`BinnedAccumulatorF64`].
    pub fn trace(&self) -> Result<f64, RuntimeError> {
        if self.ndim() != 2 || self.shape[0] != self.shape[1] {
            return Err(RuntimeError::InvalidOperation(
                "trace requires a square 2D matrix".to_string(),
            ));
        }
        let n = self.shape[0];
        let data = self.to_vec();
        let mut acc = BinnedAccumulatorF64::new();
        for i in 0..n {
            acc.add(data[i * n + i]);
        }
        Ok(acc.finalize())
    }

    /// Compute the Frobenius norm: `sqrt(sum(a_ij^2))` using [`BinnedAccumulatorF64`].
    pub fn norm_frobenius(&self) -> Result<f64, RuntimeError> {
        if self.ndim() != 2 {
            return Err(RuntimeError::InvalidOperation(
                "norm_frobenius requires a 2D matrix".to_string(),
            ));
        }
        let data = self.to_vec();
        let mut acc = BinnedAccumulatorF64::new();
        for &v in &data {
            acc.add(v * v);
        }
        Ok(acc.finalize().sqrt())
    }

    /// Compute the symmetric eigenvalue decomposition via Householder
    /// tridiagonalization followed by implicit QR iteration with Wilkinson shift.
    ///
    /// Returns `(eigenvalues, eigenvectors)` where eigenvalues are sorted in
    /// ascending order and eigenvectors form the columns of an n x n [`Tensor`].
    ///
    /// # Algorithm
    ///
    /// 1. Householder reduction to tridiagonal form -- O(n^3).
    /// 2. Implicit QR iteration on the tridiagonal matrix -- O(n^2) total.
    /// 3. Eigenvectors are sign-canonicalized (first nonzero element positive).
    ///
    /// # Determinism
    ///
    /// All loops and reductions use fixed iteration order; eigenvalues are
    /// sorted with `total_cmp` and eigenvectors are sign-canonicalized.
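    ///
    /// # Example
    ///
    /// A minimal sketch (assuming `Tensor` is re-exported at the crate root):
    ///
    /// ```
    /// use cjc_runtime::Tensor;
    ///
    /// // [[2, 1], [1, 2]] has eigenvalues 1 and 3 (returned ascending).
    /// let a = Tensor::from_vec(vec![2.0, 1.0, 1.0, 2.0], &[2, 2]).unwrap();
    /// let (vals, _vecs) = a.eigh().unwrap();
    /// assert!((vals[0] - 1.0).abs() < 1e-10 && (vals[1] - 3.0).abs() < 1e-10);
    /// ```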
    pub fn eigh(&self) -> Result<(Vec<f64>, Tensor), RuntimeError> {
        if self.ndim() != 2 || self.shape[0] != self.shape[1] {
            return Err(RuntimeError::InvalidOperation(
                "eigh requires a square 2D matrix".to_string(),
            ));
        }
        let n = self.shape[0];

        if n <= 1 {
            let val = if n == 1 { self.to_vec()[0] } else { 0.0 };
            let v_data = if n == 1 { vec![1.0] } else { vec![] };
            return Ok((vec![val], Tensor::from_vec(v_data, &[n, n])?));
        }

        // ── Step 1: Householder tridiagonalization ──────────────────────
        // Reduce symmetric A to tridiagonal form T = Q^T A Q.
        // Q is accumulated as a product of Householder reflectors.
        // This is O(n^3) -- far better than Jacobi's O(n^4) worst case.

        let mut a = self.to_vec();
        let mut q = vec![0.0f64; n * n];
        for i in 0..n { q[i * n + i] = 1.0; }

        // diag[i] = T[i,i], offd[i] = T[i,i+1]
        let mut diag = vec![0.0f64; n];
        let mut offd = vec![0.0f64; n]; // offd[0..n-1], offd[n-1] unused

        for k in 0..n.saturating_sub(2) {
            // Compute Householder vector for column k, rows k+1..n
            let mut sigma = 0.0f64;
            for i in (k + 1)..n {
                sigma += a[i * n + k] * a[i * n + k];
            }
            let sigma_sqrt = sigma.sqrt();

            if sigma_sqrt < 1e-15 {
                // Column already zero below diagonal
                offd[k] = a[(k + 1) * n + k];
                continue;
            }

            let alpha = if a[(k + 1) * n + k] >= 0.0 { -sigma_sqrt } else { sigma_sqrt };
            offd[k] = alpha;

            // Householder vector v stored in a[k+1..n, k]
            a[(k + 1) * n + k] -= alpha;
            let mut v_norm_sq = 0.0f64;
            for i in (k + 1)..n {
                v_norm_sq += a[i * n + k] * a[i * n + k];
            }
            if v_norm_sq < 1e-30 { continue; }
            let inv_v_norm_sq = 2.0 / v_norm_sq;

            // p = inv_v_norm_sq * A[k+1..n, k+1..n] * v
            let mut p = vec![0.0f64; n];
            for i in (k + 1)..n {
                let mut s = 0.0f64;
                for j in (k + 1)..n {
                    s += a[i * n + j] * a[j * n + k];
                }
                p[i] = inv_v_norm_sq * s;
            }

            // K = beta * (v^T p) / 2, with beta = 2 / v_norm_sq
            let mut vtp = 0.0f64;
            for i in (k + 1)..n { vtp += a[i * n + k] * p[i]; }
            let kk = inv_v_norm_sq * vtp * 0.5;

            // q = p - K * v
            let mut qq = vec![0.0f64; n];
            for i in (k + 1)..n {
                qq[i] = p[i] - kk * a[i * n + k];
            }

            // A = A - v * q^T - q * v^T
            for i in (k + 1)..n {
                for j in (k + 1)..n {
                    a[i * n + j] -= a[i * n + k] * qq[j] + qq[i] * a[j * n + k];
                }
            }

            // Accumulate Q = Q * (I - inv_v_norm_sq * v * v^T)
            // Q := Q - (Q * v) * (inv_v_norm_sq * v^T)
            let mut qv = vec![0.0f64; n];
            for i in 0..n {
                let mut s = 0.0f64;
                for j in (k + 1)..n {
                    s += q[i * n + j] * a[j * n + k];
                }
                qv[i] = s;
            }
            for i in 0..n {
                for j in (k + 1)..n {
                    q[i * n + j] -= inv_v_norm_sq * qv[i] * a[j * n + k];
                }
            }
        }

        // Fill diagonal and last off-diagonal from the reduced matrix
        for i in 0..n { diag[i] = a[i * n + i]; }
        if n >= 2 { offd[n - 2] = a[(n - 2) * n + (n - 1)]; }

        // ── Step 2: Implicit QR iteration on tridiagonal matrix ────────
        // Wilkinson shift for cubic convergence. Each sweep costs O(n) on the
        // tridiagonal data plus O(n^2) for eigenvector accumulation.

        let max_iter = 30 * n;
        let mut lo = 0usize;
        let mut hi = n - 1;

        for _iter in 0..max_iter {
            // Find unreduced block
            while lo < hi {
                if offd[hi - 1].abs() < 1e-14 * (diag[hi - 1].abs() + diag[hi].abs()).max(1e-14) {
                    offd[hi - 1] = 0.0;
                    hi -= 1;
                } else {
                    break;
                }
            }
            if hi <= lo { break; }

            // Find start of unreduced block
            let mut block_start = hi - 1;
            while block_start > lo {
                if offd[block_start - 1].abs() < 1e-14 * (diag[block_start - 1].abs() + diag[block_start].abs()).max(1e-14) {
                    offd[block_start - 1] = 0.0;
                    break;
                }
                block_start -= 1;
            }

            if block_start == hi {
                // 1×1 block converged
                hi = if hi > 0 { hi - 1 } else { break };
                continue;
            }

            // Wilkinson shift: eigenvalue of trailing 2×2 closer to d[hi]
            let d_hi = diag[hi];
            let d_him1 = diag[hi - 1];
            let e_him1 = offd[hi - 1];
            let delta = (d_him1 - d_hi) * 0.5;
            let sign_delta = if delta >= 0.0 { 1.0 } else { -1.0 };
            let shift = d_hi - e_him1 * e_him1 / (delta + sign_delta * (delta * delta + e_him1 * e_him1).sqrt());

            // Implicit QR step (Givens rotations)
            let mut x = diag[block_start] - shift;
            let mut z = offd[block_start];

            for k in block_start..hi {
                // Compute Givens rotation to zero z
                let r = (x * x + z * z).sqrt();
                let c = if r > 1e-30 { x / r } else { 1.0 };
                let s = if r > 1e-30 { -z / r } else { 0.0 };

                // Apply to tridiagonal elements
                if k > block_start {
                    offd[k - 1] = r;
                }
                let d0 = diag[k];
                let d1 = diag[k + 1];
                let e0 = offd[k];
                diag[k] = c * c * d0 + s * s * d1 - 2.0 * c * s * e0;
                diag[k + 1] = s * s * d0 + c * c * d1 + 2.0 * c * s * e0;
                offd[k] = c * s * (d0 - d1) + (c * c - s * s) * e0;

                if k + 1 < hi {
                    x = offd[k];
                    z = -s * offd[k + 1];
                    offd[k + 1] *= c;
                }

                // Accumulate eigenvector rotation: Q[:, k:k+2] *= G
                for i in 0..n {
                    let qik = q[i * n + k];
                    let qik1 = q[i * n + k + 1];
                    q[i * n + k] = c * qik - s * qik1;
                    q[i * n + k + 1] = s * qik + c * qik1;
                }
            }

            // Check if all off-diagonals converged
            let mut all_converged = true;
            for i in lo..hi {
                if offd[i].abs() >= 1e-14 * (diag[i].abs() + diag[i + 1].abs()).max(1e-14) {
                    all_converged = false;
                    break;
                }
            }
            if all_converged { break; }
        }

        // ── Step 3: Sort eigenvalues ascending, reorder eigenvectors ────
        let mut eigenvalues: Vec<(f64, usize)> = (0..n).map(|i| (diag[i], i)).collect();
        eigenvalues.sort_by(|a, b| a.0.total_cmp(&b.0));
        let vals: Vec<f64> = eigenvalues.iter().map(|&(v, _)| v).collect();

        let mut v_sorted = vec![0.0; n * n];
        for (new_col, &(_, old_col)) in eigenvalues.iter().enumerate() {
            for row in 0..n {
                v_sorted[row * n + new_col] = q[row * n + old_col];
            }
        }

        // Sign-canonical: first nonzero component positive
        for col in 0..n {
            let mut first_nonzero = 0.0;
            for row in 0..n {
                if v_sorted[row * n + col].abs() > 1e-15 {
                    first_nonzero = v_sorted[row * n + col];
                    break;
                }
            }
            if first_nonzero < 0.0 {
                for row in 0..n {
                    v_sorted[row * n + col] = -v_sorted[row * n + col];
                }
            }
        }

        Ok((vals, Tensor::from_vec(v_sorted, &[n, n])?))
    }

    /// Estimate the matrix rank by counting nonzero diagonal elements of `R`
    /// from a QR decomposition (tolerance `1e-10`).
    pub fn matrix_rank(&self) -> Result<usize, RuntimeError> {
        if self.ndim() != 2 {
            return Err(RuntimeError::InvalidOperation(
                "matrix_rank requires a 2D matrix".to_string(),
            ));
        }
        // Simple approach: use QR and count non-zero diagonal
        let (_q, r) = self.qr_decompose()?;
        let r_data = r.to_vec();
        let n = r.shape()[0].min(r.shape()[1]);
        let cols = r.shape()[1];
        let mut rank = 0;
        for i in 0..n {
            if r_data[i * cols + i].abs() > 1e-10 {
                rank += 1;
            }
        }
        Ok(rank)
    }

    /// Compute the Kronecker product `A (x) B`.
    ///
    /// For `A` of shape (m, n) and `B` of shape (p, q), the result has
    /// shape (m*p, n*q).
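    ///
    /// # Example
    ///
    /// A minimal sketch (assuming `Tensor` is re-exported at the crate root):
    ///
    /// ```
    /// use cjc_runtime::Tensor;
    ///
    /// // [1, 2] (1x2) kron [3, 4]^T (2x1) = [[3, 6], [4, 8]] (2x2)
    /// let a = Tensor::from_vec(vec![1.0, 2.0], &[1, 2]).unwrap();
    /// let b = Tensor::from_vec(vec![3.0, 4.0], &[2, 1]).unwrap();
    /// let k = a.kron(&b).unwrap();
    /// assert_eq!(k.to_vec(), vec![3.0, 6.0, 4.0, 8.0]);
    /// ```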
    pub fn kron(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
        if self.ndim() != 2 || other.ndim() != 2 {
            return Err(RuntimeError::InvalidOperation(
                "kron requires two 2D matrices".to_string(),
            ));
        }
        let (m, n) = (self.shape[0], self.shape[1]);
        let (p, q) = (other.shape()[0], other.shape()[1]);
        let a = self.to_vec();
        let b = other.to_vec();
        let mut result = vec![0.0; m * p * n * q];
        let out_cols = n * q;
        for i in 0..m {
            for j in 0..n {
                let aij = a[i * n + j];
                for k in 0..p {
                    for l in 0..q {
                        result[(i * p + k) * out_cols + (j * q + l)] = aij * b[k * q + l];
                    }
                }
            }
        }
        Tensor::from_vec(result, &[m * p, n * q])
    }

    /// Compute the matrix inverse via LU decomposition and column-wise
    /// forward/back substitution.
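    ///
    /// # Example
    ///
    /// A minimal sketch (assuming `Tensor` is re-exported at the crate root):
    ///
    /// ```
    /// use cjc_runtime::Tensor;
    ///
    /// let a = Tensor::from_vec(vec![4.0, 7.0, 2.0, 6.0], &[2, 2]).unwrap();
    /// let inv = a.inverse().unwrap();
    /// // A * A^{-1} is the identity, up to rounding.
    /// let prod = a.matmul(&inv).unwrap().to_vec();
    /// assert!((prod[0] - 1.0).abs() < 1e-12 && prod[1].abs() < 1e-12);
    /// ```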
    pub fn inverse(&self) -> Result<Tensor, RuntimeError> {
        if self.ndim() != 2 || self.shape[0] != self.shape[1] {
            return Err(RuntimeError::InvalidOperation(
                "Matrix inverse requires a square 2D matrix".to_string(),
            ));
        }
        let n = self.shape[0];
        let (l, u, pivots) = self.lu_decompose()?;
        let l_data = l.to_vec();
        let u_data = u.to_vec();

        let mut inv = vec![0.0f64; n * n];

        // Solve for each column of the inverse
        for col in 0..n {
            // Permuted identity column: pb[i] = e_col[pivots[i]] (as in `solve`)
            let mut b = vec![0.0f64; n];
            for i in 0..n { b[i] = if pivots[i] == col { 1.0 } else { 0.0 }; }

            // Forward substitution: L * y = b
            let mut y = vec![0.0f64; n];
            for i in 0..n {
                let mut sum = b[i];
                for j in 0..i {
                    sum -= l_data[i * n + j] * y[j];
                }
                y[i] = sum; // L has 1s on diagonal
            }

            // Back substitution: U * x = y
            let mut x = vec![0.0f64; n];
            for i in (0..n).rev() {
                let mut sum = y[i];
                for j in (i + 1)..n {
                    sum -= u_data[i * n + j] * x[j];
                }
                x[i] = sum / u_data[i * n + i];
            }

            for i in 0..n {
                inv[i * n + col] = x[i];
            }
        }

        Tensor::from_vec(inv, &[n, n])
    }

    // -----------------------------------------------------------------------
    // Phase B3: Linear algebra extensions
    // -----------------------------------------------------------------------

    /// Compute the matrix 1-norm (maximum absolute column sum) using
    /// [`BinnedAccumulatorF64`].
    pub fn norm_1(&self) -> Result<f64, RuntimeError> {
        if self.ndim() != 2 {
            return Err(RuntimeError::InvalidOperation(
                "norm_1 requires a 2D matrix".to_string(),
            ));
        }
        let (m, n) = (self.shape[0], self.shape[1]);
        let data = self.to_vec();
        let mut max_col_sum = 0.0_f64;
        for j in 0..n {
            let mut acc = BinnedAccumulatorF64::new();
            for i in 0..m {
                acc.add(data[i * n + j].abs());
            }
            let col_sum = acc.finalize();
            if col_sum > max_col_sum {
                max_col_sum = col_sum;
            }
        }
        Ok(max_col_sum)
    }

    /// Compute the matrix infinity-norm (maximum absolute row sum) using
    /// [`BinnedAccumulatorF64`].
    pub fn norm_inf(&self) -> Result<f64, RuntimeError> {
        if self.ndim() != 2 {
            return Err(RuntimeError::InvalidOperation(
                "norm_inf requires a 2D matrix".to_string(),
            ));
        }
        let (m, n) = (self.shape[0], self.shape[1]);
        let data = self.to_vec();
        let mut max_row_sum = 0.0_f64;
        for i in 0..m {
            let mut acc = BinnedAccumulatorF64::new();
            for j in 0..n {
                acc.add(data[i * n + j].abs());
            }
            let row_sum = acc.finalize();
            if row_sum > max_row_sum {
                max_row_sum = row_sum;
            }
        }
        Ok(max_row_sum)
    }

    /// Estimate the 2-norm condition number of the matrix.
    ///
    /// For symmetric matrices, computes `|lambda_max| / |lambda_min|` via [`eigh`](Tensor::eigh).
    /// For general matrices, computes `sqrt(lambda_max / lambda_min)` from the
    /// eigenvalues of `A^T * A`, which are the squared singular values.
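    ///
    /// # Example
    ///
    /// A minimal sketch (assuming `Tensor` is re-exported at the crate root):
    ///
    /// ```
    /// use cjc_runtime::Tensor;
    ///
    /// // Symmetric diag(4, 1): condition number 4 / 1 = 4.
    /// let a = Tensor::from_vec(vec![4.0, 0.0, 0.0, 1.0], &[2, 2]).unwrap();
    /// assert!((a.cond().unwrap() - 4.0).abs() < 1e-10);
    /// ```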
    pub fn cond(&self) -> Result<f64, RuntimeError> {
        if self.ndim() != 2 || self.shape[0] != self.shape[1] {
            return Err(RuntimeError::InvalidOperation(
                "cond requires a square 2D matrix".to_string(),
            ));
        }
        let n = self.shape[0];
        // Check if symmetric
        let data = self.to_vec();
        let mut is_sym = true;
        'outer: for i in 0..n {
            for j in (i + 1)..n {
                if (data[i * n + j] - data[j * n + i]).abs() > 1e-14 {
                    is_sym = false;
                    break 'outer;
                }
            }
        }
        if is_sym {
            let (eigenvalues, _) = self.eigh()?;
            let abs_min = eigenvalues.iter().map(|v| v.abs()).fold(f64::INFINITY, f64::min);
            let abs_max = eigenvalues.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
            if abs_min < 1e-15 {
                return Ok(f64::INFINITY);
            }
            Ok(abs_max / abs_min)
        } else {
            // General: compute A^T * A, then eigh
            let at = self.transpose();
            let ata = at.matmul(self)?;
            let (eigenvalues, _) = ata.eigh()?;
            let abs_min = eigenvalues.iter().map(|v| v.abs()).fold(f64::INFINITY, f64::min);
            let abs_max = eigenvalues.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
            if abs_min < 1e-15 {
                return Ok(f64::INFINITY);
            }
            Ok((abs_max / abs_min).sqrt())
        }
    }

    /// Compute the real Schur decomposition: `A = Q * T * Q^T`.
    ///
    /// `T` is quasi-upper-triangular (upper triangular with possible 2x2 blocks
    /// on the diagonal for complex eigenvalue pairs) and `Q` is orthogonal.
    ///
    /// # Algorithm
    ///
    /// 1. Householder reduction to upper Hessenberg form.
    /// 2. Implicit single-shift QR iteration with Wilkinson shift and Givens
    ///    rotations.
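    ///
    /// # Example
    ///
    /// A minimal sketch (assuming `Tensor` is re-exported at the crate root).
    /// An upper-triangular input is already in Schur form, so `T` keeps the
    /// eigenvalues 1 and 3 on its diagonal:
    ///
    /// ```
    /// use cjc_runtime::Tensor;
    ///
    /// let a = Tensor::from_vec(vec![1.0, 2.0, 0.0, 3.0], &[2, 2]).unwrap();
    /// let (_q, t) = a.schur().unwrap();
    /// let td = t.to_vec();
    /// assert!((td[0] - 1.0).abs() < 1e-10 && (td[3] - 3.0).abs() < 1e-10);
    /// ```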
    pub fn schur(&self) -> Result<(Tensor, Tensor), RuntimeError> {
        if self.ndim() != 2 || self.shape[0] != self.shape[1] {
            return Err(RuntimeError::InvalidOperation(
                "schur requires a square 2D matrix".to_string(),
            ));
        }
        let n = self.shape[0];
        if n == 0 {
            return Err(RuntimeError::InvalidOperation("schur: empty matrix".to_string()));
        }
        if n == 1 {
            return Ok((
                Tensor::from_vec(vec![1.0], &[1, 1])?,
                self.clone(),
            ));
        }

        let mut h = self.to_vec();
        let mut q = vec![0.0; n * n];
        for i in 0..n {
            q[i * n + i] = 1.0;
        }

        // Step 1: Reduce to upper Hessenberg form
        for k in 0..n.saturating_sub(2) {
            // Compute Householder reflector for column k below diagonal
            let mut col = vec![0.0; n - k - 1];
            for i in 0..col.len() {
                col[i] = h[(k + 1 + i) * n + k];
            }
            let norm_col = {
                let mut acc = BinnedAccumulatorF64::new();
                for &v in &col { acc.add(v * v); }
                acc.finalize().sqrt()
            };
            if norm_col < 1e-15 {
                continue;
            }
            // Sign convention: positive diagonal
            let sign = if col[0] >= 0.0 { 1.0 } else { -1.0 };
            col[0] += sign * norm_col;
            let norm_v = {
                let mut acc = BinnedAccumulatorF64::new();
                for &v in &col { acc.add(v * v); }
                acc.finalize().sqrt()
            };
            if norm_v < 1e-15 {
                continue;
            }
            for v in &mut col {
                *v /= norm_v;
            }
            // Apply H = (I - 2vv^T) * H from left
            for j in 0..n {
                let mut acc = BinnedAccumulatorF64::new();
                for i in 0..col.len() {
                    acc.add(col[i] * h[(k + 1 + i) * n + j]);
                }
                let dot = acc.finalize();
                for i in 0..col.len() {
                    h[(k + 1 + i) * n + j] -= 2.0 * col[i] * dot;
                }
            }
            // Apply H = H * (I - 2vv^T) from right
            for i in 0..n {
                let mut acc = BinnedAccumulatorF64::new();
                for j in 0..col.len() {
                    acc.add(col[j] * h[i * n + (k + 1 + j)]);
                }
                let dot = acc.finalize();
                for j in 0..col.len() {
                    h[i * n + (k + 1 + j)] -= 2.0 * col[j] * dot;
                }
            }
            // Accumulate Q
            for i in 0..n {
                let mut acc = BinnedAccumulatorF64::new();
                for j in 0..col.len() {
                    acc.add(col[j] * q[i * n + (k + 1 + j)]);
                }
                let dot = acc.finalize();
                for j in 0..col.len() {
                    q[i * n + (k + 1 + j)] -= 2.0 * col[j] * dot;
                }
            }
        }

        // Step 2: Implicit QR iterations
        let eps = 1e-14;
        let max_iter = 200 * n;
        let mut ihi = n - 1;

        for _iter in 0..max_iter {
            if ihi == 0 {
                break;
            }
            // Find active block
            let mut ilo = ihi;
            while ilo > 0 {
                if h[ilo * n + (ilo - 1)].abs()
                    < eps * (h[(ilo - 1) * n + (ilo - 1)].abs() + h[ilo * n + ilo].abs())
                {
                    h[ilo * n + (ilo - 1)] = 0.0;
                    break;
                }
                ilo -= 1;
            }
            if ilo == ihi {
                // 1x1 block converged
                ihi -= 1;
                continue;
            }
            if ilo + 1 == ihi {
                // 2x2 block: accept as converged (it may hold a complex pair).
                // Guard against usize underflow when the block sits at the top.
                if ihi >= 2 { ihi -= 2; } else { break; }
                continue;
            }

            // Wilkinson shift from trailing 2x2
            let a11 = h[(ihi - 1) * n + (ihi - 1)];
            let a12 = h[(ihi - 1) * n + ihi];
            let a21 = h[ihi * n + (ihi - 1)];
            let a22 = h[ihi * n + ihi];
            let tr = a11 + a22;
            let det = a11 * a22 - a12 * a21;
            let disc = tr * tr - 4.0 * det;

            let shift = if disc >= 0.0 {
                let sqrt_disc = disc.sqrt();
                let ev1 = (tr + sqrt_disc) / 2.0;
                let ev2 = (tr - sqrt_disc) / 2.0;
                if (ev1 - a22).abs() < (ev2 - a22).abs() { ev1 } else { ev2 }
            } else {
                a22 // complex eigenvalues: use a22 as shift
            };

            // Apply single-shift QR step with Givens rotations
            let mut x = h[ilo * n + ilo] - shift;
            let mut z = h[(ilo + 1) * n + ilo];
            for k in ilo..ihi {
                let r = (x * x + z * z).sqrt();
                let c = if r < 1e-15 { 1.0 } else { x / r };
                let s = if r < 1e-15 { 0.0 } else { z / r };
                // Apply Givens from left: rows k and k+1
                for j in 0..n {
                    let t1 = h[k * n + j];
                    let t2 = h[(k + 1) * n + j];
                    h[k * n + j] = c * t1 + s * t2;
                    h[(k + 1) * n + j] = -s * t1 + c * t2;
                }
                // Apply Givens from right: cols k and k+1
                let jmax = if k + 3 < n { k + 3 } else { n };
                for i in 0..jmax {
                    let t1 = h[i * n + k];
                    let t2 = h[i * n + (k + 1)];
                    h[i * n + k] = c * t1 + s * t2;
                    h[i * n + (k + 1)] = -s * t1 + c * t2;
                }
                // Accumulate Q
                for i in 0..n {
                    let t1 = q[i * n + k];
                    let t2 = q[i * n + (k + 1)];
                    q[i * n + k] = c * t1 + s * t2;
                    q[i * n + (k + 1)] = -s * t1 + c * t2;
                }
                if k + 2 <= ihi {
                    x = h[(k + 1) * n + k];
                    z = h[(k + 2) * n + k];
                }
            }
        }

        // Clean up sub-diagonal entries
        for i in 0..n {
            for j in 0..i.saturating_sub(1) {
                h[i * n + j] = 0.0;
            }
        }

        Ok((
            Tensor::from_vec(q, &[n, n])?,
            Tensor::from_vec(h, &[n, n])?,
        ))
    }

    // -----------------------------------------------------------------------
    // Phase 3A: SVD via eigendecomposition of A^T * A
    // -----------------------------------------------------------------------

    /// Compute the Singular Value Decomposition: `A = U * diag(S) * Vt`.
    ///
    /// Returns `(U, S, Vt)` where `S` is a `Vec<f64>` of singular values in
    /// descending order, `U` is m x k, and `Vt` is k x n (k = min(m, n)).
    ///
    /// # Algorithm
    ///
    /// Eigendecomposition of `A^T * A` yields `V` and `sigma^2`. Then
    /// `U = A * V * diag(1 / sigma_i)`. Sign-canonical: largest-magnitude
    /// element of each `U` column is positive.
    ///
    /// # Determinism
    ///
    /// All intermediate floating-point reductions use [`BinnedAccumulatorF64`].
    /// Iteration order is fixed row-major.
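    ///
    /// # Example
    ///
    /// A minimal sketch (assuming `Tensor` is re-exported at the crate root):
    ///
    /// ```
    /// use cjc_runtime::Tensor;
    ///
    /// // diag(3, 2) has singular values 3 and 2 (returned descending).
    /// let a = Tensor::from_vec(vec![3.0, 0.0, 0.0, 2.0], &[2, 2]).unwrap();
    /// let (_u, s, _vt) = a.svd().unwrap();
    /// assert!((s[0] - 3.0).abs() < 1e-10 && (s[1] - 2.0).abs() < 1e-10);
    /// ```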
    pub fn svd(&self) -> Result<(Tensor, Vec<f64>, Tensor), RuntimeError> {
        if self.ndim() != 2 {
            return Err(RuntimeError::InvalidOperation(
                "SVD requires a 2D matrix".to_string(),
            ));
        }
        let m = self.shape[0];
        let n = self.shape[1];
        if m == 0 || n == 0 {
            return Err(RuntimeError::InvalidOperation(
                "SVD: empty matrix".to_string(),
            ));
        }

        let min_mn = m.min(n);

        // Compute A^T * A (n x n symmetric matrix)
        let at = self.transpose();
        let ata = at.matmul(self)?;

        // Eigendecomposition of A^T*A gives V and eigenvalues = sigma^2
        let (eigenvalues, eigenvectors) = ata.eigh()?;

        // eigenvalues are in ascending order from eigh; we want descending singular values
        // Singular values = sqrt(eigenvalues), clamp negatives to 0
        let mut singular_values: Vec<f64> = eigenvalues.iter()
            .map(|&ev| if ev > 0.0 { ev.sqrt() } else { 0.0 })
            .collect();

        // Reverse to get descending order
        singular_values.reverse();

        // Take only min(m,n) singular values
        let k = min_mn.min(singular_values.len());
        let s: Vec<f64> = singular_values[..k].to_vec();

        // V columns are eigenvectors of A^T*A, reversed for descending order
        // eigenvectors is n x n, columns are eigenvectors
        let ev_data = eigenvectors.to_vec();
        let ev_n = eigenvectors.shape()[1]; // should be n

        // Build V matrix (n x k) with columns in descending singular value order
        let mut v_data = vec![0.0f64; n * k];
        for col in 0..k {
            let ev_col = n - 1 - col; // reverse index for descending order
            for row in 0..n {
                v_data[row * k + col] = ev_data[row * ev_n + ev_col];
            }
        }
        let v_mat = Tensor::from_vec(v_data.clone(), &[n, k])?;

        // U = A * V * diag(1/s)
        // First compute A * V
        let av = self.matmul(&v_mat)?;
        let av_data = av.to_vec();

        // Then scale each column by 1/s_i
        let mut u_data = vec![0.0f64; m * k];
        for col in 0..k {
            if s[col] > 1e-14 {
                let inv_s = 1.0 / s[col];
                for row in 0..m {
                    u_data[row * k + col] = av_data[row * k + col] * inv_s;
                }
            }
            // If s[col] ≈ 0, leave u column as zeros
        }

        // Sign-canonical: ensure largest-magnitude element of each u column is positive
        for col in 0..k {
            let mut max_abs = 0.0f64;
            let mut max_sign = 1.0f64;
            for row in 0..m {
                let val = u_data[row * k + col];
                if val.abs() > max_abs {
                    max_abs = val.abs();
                    max_sign = if val >= 0.0 { 1.0 } else { -1.0 };
                }
            }
            if max_sign < 0.0 {
                for row in 0..m {
                    u_data[row * k + col] = -u_data[row * k + col];
                }
                for row in 0..n {
                    v_data[row * k + col] = -v_data[row * k + col];
                }
            }
        }

        let u_tensor = Tensor::from_vec(u_data, &[m, k])?;

        // Vt = V^T (k x n)
        let mut vt_data = vec![0.0f64; k * n];
        for row in 0..k {
            for col in 0..n {
                vt_data[row * n + col] = v_data[col * k + row];
            }
        }
        let vt_tensor = Tensor::from_vec(vt_data, &[k, n])?;

        Ok((u_tensor, s, vt_tensor))
    }

    /// Compute a truncated SVD retaining only the top `k` singular triplets.
    ///
    /// Returns `(U_k, S_k, Vt_k)` where `U_k` is m x k and `Vt_k` is k x n.
    pub fn svd_truncated(
        &self,
        k: usize,
    ) -> Result<(Tensor, Vec<f64>, Tensor), RuntimeError> {
        let (u_full, s_full, vt_full) = self.svd()?;
        let m = u_full.shape()[0];
        let n = vt_full.shape()[1];
        let actual_k = k.min(s_full.len());

        if actual_k == 0 {
            return Err(RuntimeError::InvalidOperation(
                "svd_truncated: k must be > 0".to_string(),
            ));
        }

        let s_k: Vec<f64> = s_full[..actual_k].to_vec();

        // Extract first k columns of U
        let u_data = u_full.to_vec();
        let u_cols = u_full.shape()[1];
        let mut u_k = vec![0.0f64; m * actual_k];
        for row in 0..m {
            for col in 0..actual_k {
                u_k[row * actual_k + col] = u_data[row * u_cols + col];
            }
        }

        // Extract first k rows of Vt
        let vt_data = vt_full.to_vec();
        let mut vt_k = vec![0.0f64; actual_k * n];
        for row in 0..actual_k {
            for col in 0..n {
                vt_k[row * n + col] = vt_data[row * n + col];
            }
        }

        Ok((
            Tensor::from_vec(u_k, &[m, actual_k])?,
            s_k,
            Tensor::from_vec(vt_k, &[actual_k, n])?,
        ))
    }

    // -----------------------------------------------------------------------
    // Phase 3B: Pseudoinverse (Moore-Penrose, via SVD)
    // -----------------------------------------------------------------------

    /// Compute the Moore-Penrose pseudoinverse via SVD.
    ///
    /// `A+ = V * diag(1/s_i) * U^T`, with default tolerance
    /// `max(m, n) * eps * max(S)` for near-zero singular values.
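    ///
    /// # Example
    ///
    /// A minimal sketch (assuming `Tensor` is re-exported at the crate root):
    ///
    /// ```
    /// use cjc_runtime::Tensor;
    ///
    /// // For invertible diag(2, 4), the pseudoinverse is diag(0.5, 0.25).
    /// let a = Tensor::from_vec(vec![2.0, 0.0, 0.0, 4.0], &[2, 2]).unwrap();
    /// let p = a.pinv().unwrap().to_vec();
    /// assert!((p[0] - 0.5).abs() < 1e-10 && (p[3] - 0.25).abs() < 1e-10);
    /// ```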
    pub fn pinv(&self) -> Result<Tensor, RuntimeError> {
        // Default tolerance: max(m,n) * eps * max(S)
        let (u, s, vt) = self.svd()?;
        let m = self.shape[0];
        let n = self.shape[1];
        let max_s = s.first().copied().unwrap_or(0.0);
        let tol = (m.max(n) as f64) * f64::EPSILON * max_s;
        Self::pinv_from_svd(&u, &s, &vt, tol)
    }

    /// Compute the Moore-Penrose pseudoinverse via SVD with an explicit
    /// singular-value cutoff tolerance.
    pub fn pinv_with_tol(&self, tol: f64) -> Result<Tensor, RuntimeError> {
        let (u, s, vt) = self.svd()?;
        Self::pinv_from_svd(&u, &s, &vt, tol)
    }

    /// Internal: compute pseudoinverse from pre-computed SVD.
    /// A+ = V @ diag(1/s_i) @ Ut, zeroing 1/s_i where s_i <= tol.
    fn pinv_from_svd(
        u: &Tensor,
        s: &[f64],
        vt: &Tensor,
        tol: f64,
    ) -> Result<Tensor, RuntimeError> {
        let m = u.shape()[0];
        let k = s.len();
        let n = vt.shape()[1];

        // Build S_inv: k-vector with 1/s_i or 0
        let s_inv: Vec<f64> = s
            .iter()
            .map(|&si| if si > tol { 1.0 / si } else { 0.0 })
            .collect();

        // Compute Vt^T @ diag(s_inv) @ U^T = V @ diag(s_inv) @ Ut
        // Result is n x m
        let u_data = u.to_vec();
        let vt_data = vt.to_vec();
        let mut result = vec![0.0f64; n * m];

        for i in 0..n {
            for j in 0..m {
                let mut acc = BinnedAccumulatorF64::new();
                for l in 0..k {
                    // V[i, l] = Vt[l, i] (transposed)
                    // Ut[l, j] = U[j, l] (transposed)
                    acc.add(vt_data[l * n + i] * s_inv[l] * u_data[j * k + l]);
                }
                result[i * m + j] = acc.finalize();
            }
        }

        Tensor::from_vec(result, &[n, m])
    }

    /// Helper: compute Givens rotation parameters.
    /// Returns (c, s, r) such that [c s; -s c]^T * [a; b] = [r; 0].
    fn givens_rotation(a: f64, b: f64) -> (f64, f64, f64) {
        if b.abs() < 1e-15 {
            (1.0, 0.0, a)
        } else if a.abs() < 1e-15 {
            (0.0, if b >= 0.0 { 1.0 } else { -1.0 }, b.abs())
        } else {
            let r = (a * a + b * b).sqrt();
            (a / r, b / r, r)
        }
    }

    /// Compute the matrix exponential `exp(A)` via scaling-and-squaring with a
    /// Pade(13,13) rational approximation.
    ///
    /// # Errors
    ///
    /// Returns [`RuntimeError::InvalidOperation`] if the matrix is not square 2-D.
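    ///
    /// # Example
    ///
    /// A minimal sketch (assuming `Tensor` is re-exported at the crate root):
    ///
    /// ```
    /// use cjc_runtime::Tensor;
    ///
    /// // exp of a diagonal matrix exponentiates the diagonal entries.
    /// let a = Tensor::from_vec(vec![0.0, 0.0, 0.0, 1.0], &[2, 2]).unwrap();
    /// let e = a.matrix_exp().unwrap().to_vec();
    /// assert!((e[0] - 1.0).abs() < 1e-10);                 // exp(0) = 1
    /// assert!((e[3] - std::f64::consts::E).abs() < 1e-10); // exp(1) = e
    /// ```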
1393    pub fn matrix_exp(&self) -> Result<Tensor, RuntimeError> {
1394        if self.ndim() != 2 || self.shape[0] != self.shape[1] {
1395            return Err(RuntimeError::InvalidOperation(
1396                "matrix_exp requires a square 2D matrix".to_string(),
1397            ));
1398        }
1399        let n = self.shape[0];
1400        if n == 0 {
1401            return Err(RuntimeError::InvalidOperation("matrix_exp: empty matrix".to_string()));
1402        }
1403
1404        const PADE_COEFFS: [f64; 14] = [
1405            64764752532480000.0,
1406            32382376266240000.0,
1407            7771770303897600.0,
1408            1187353796428800.0,
1409            129060195264000.0,
1410            10559470521600.0,
1411            670442572800.0,
1412            33522128640.0,
1413            1323241920.0,
1414            40840800.0,
1415            960960.0,
1416            16380.0,
1417            182.0,
1418            1.0,
1419        ];
1420        const THETA_13: f64 = 5.371920351148152;
1421
1422        // Scaling: s = max(0, ceil(log2(||A||_1 / theta_13)))
1423        let norm = self.norm_1()?;
1424        let s = if norm <= THETA_13 {
1425            0u32
1426        } else {
1427            (norm / THETA_13).log2().ceil() as u32
1428        };
1429
1430        // B = A / 2^s
1431        let scale = 2.0_f64.powi(-(s as i32));
1432        let b_data: Vec<f64> = self.to_vec().iter().map(|&x| x * scale).collect();
1433        let b = Tensor::from_vec(b_data, &[n, n])?;
1434
1435        // Compute B^2, B^4, B^6
1436        let b2 = b.matmul(&b)?;
1437        let b4 = b2.matmul(&b2)?;
1438        let b6 = b4.matmul(&b2)?;
1439
1440        // Identity matrix
1441        let mut eye = vec![0.0; n * n];
1442        for i in 0..n {
1443            eye[i * n + i] = 1.0;
1444        }
1445        let eye_t = Tensor::from_vec(eye, &[n, n])?;
1446
1447        // Build U and V
1448        // U = B * (b_13*B^6 + b_11*B^4 + b_9*B^2 + b_7*I) * B^6
1449        //   + B * (b_5*B^4 + b_3*B^2 + b_1*I)
1450        // V = (b_12*B^6 + b_10*B^4 + b_8*B^2 + b_6*I) * B^6
1451        //   + (b_4*B^4 + b_2*B^2 + b_0*I)

        // Helper: scale and add tensors
        fn scale_add(a: &Tensor, sa: f64, b: &Tensor, sb: f64, n: usize) -> Vec<f64> {
            let ad = a.to_vec();
            let bd = b.to_vec();
            let mut r = vec![0.0; n * n];
            for i in 0..n * n {
                r[i] = sa * ad[i] + sb * bd[i];
            }
            r
        }

        let c = &PADE_COEFFS;

        // inner_u1 = b_13*B6 + b_11*B4 + b_9*B2 + b_7*I
        let mut iu1 = scale_add(&b6, c[13], &b4, c[11], n);
        let t = scale_add(&b2, c[9], &eye_t, c[7], n);
        for i in 0..n * n {
            iu1[i] += t[i];
        }
        let iu1_t = Tensor::from_vec(iu1, &[n, n])?;
        let iu1_b6 = iu1_t.matmul(&b6)?;

        // inner_u2 = b_5*B4 + b_3*B2 + b_1*I
        let mut iu2 = scale_add(&b4, c[5], &b2, c[3], n);
        // Add b_1 * I by bumping the diagonal directly (equivalent to building
        // a scaled identity tensor and adding it elementwise).
        for i in 0..n {
            iu2[i * n + i] += c[1];
        }
        let iu2_t = Tensor::from_vec(iu2, &[n, n])?;

        // U_inner = iu1_b6 + iu2
        let iu1d = iu1_b6.to_vec();
        let iu2d = iu2_t.to_vec();
        let mut u_inner = vec![0.0; n * n];
        for i in 0..n * n {
            u_inner[i] = iu1d[i] + iu2d[i];
        }
        let u_inner_t = Tensor::from_vec(u_inner, &[n, n])?;
        let u = b.matmul(&u_inner_t)?;

        // inner_v1 = b_12*B6 + b_10*B4 + b_8*B2 + b_6*I
        let mut iv1 = scale_add(&b6, c[12], &b4, c[10], n);
        let t = scale_add(&b2, c[8], &eye_t, c[6], n);
        for i in 0..n * n {
            iv1[i] += t[i];
        }
        let iv1_t = Tensor::from_vec(iv1, &[n, n])?;
        let iv1_b6 = iv1_t.matmul(&b6)?;

        // inner_v2 = b_4*B4 + b_2*B2 + b_0*I
        let mut iv2 = scale_add(&b4, c[4], &b2, c[2], n);
        // Add b_0 * I by bumping the diagonal directly.
        for i in 0..n {
            iv2[i * n + i] += c[0];
        }
        let iv2_t = Tensor::from_vec(iv2, &[n, n])?;

        // V = iv1_b6 + iv2
        let iv1d = iv1_b6.to_vec();
        let iv2d = iv2_t.to_vec();
        let mut v_data = vec![0.0; n * n];
        for i in 0..n * n {
            v_data[i] = iv1d[i] + iv2d[i];
        }
        let v_mat = Tensor::from_vec(v_data, &[n, n])?;

        // Solve (V - U) * r = (V + U)
        let ud = u.to_vec();
        let vd = v_mat.to_vec();
        let mut lhs_data = vec![0.0; n * n];
        let mut rhs_data = vec![0.0; n * n];
        for i in 0..n * n {
            lhs_data[i] = vd[i] - ud[i];
            rhs_data[i] = vd[i] + ud[i];
        }
        let lhs = Tensor::from_vec(lhs_data, &[n, n])?;

        // Solve column-by-column
        let mut result = vec![0.0; n * n];
        for col in 0..n {
            let mut rhs_col = vec![0.0; n];
            for row in 0..n {
                rhs_col[row] = rhs_data[row * n + col];
            }
            let rhs_tensor = Tensor::from_vec(rhs_col, &[n])?;
            let sol = lhs.solve(&rhs_tensor)?;
            let sol_data = sol.to_vec();
            for row in 0..n {
                result[row * n + col] = sol_data[row];
            }
        }

        let mut r = Tensor::from_vec(result, &[n, n])?;

        // Undo the scaling: exp(A) = (exp(A / 2^s))^(2^s), so square s times.
        for _ in 0..s {
            r = r.matmul(&r)?;
        }

        Ok(r)
    }
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    /// Helper: reconstruct A from SVD: U @ diag(S) @ Vt
    fn reconstruct_svd(u: &Tensor, s: &[f64], vt: &Tensor) -> Vec<f64> {
        let m = u.shape()[0];
        let k = s.len();
        let n = vt.shape()[1];
        let u_data = u.to_vec();
        let vt_data = vt.to_vec();
        let u_cols = u.shape()[1];
        let mut result = vec![0.0f64; m * n];
        for i in 0..m {
            for j in 0..n {
                let mut sum = 0.0;
                for l in 0..k {
                    sum += u_data[i * u_cols + l] * s[l] * vt_data[l * n + j];
                }
                result[i * n + j] = sum;
            }
        }
        result
    }

    /// Helper: check two flat arrays are approximately equal
    fn assert_approx_eq(a: &[f64], b: &[f64], tol: f64, msg: &str) {
        assert_eq!(a.len(), b.len(), "{}: length mismatch", msg);
        for (i, (&ai, &bi)) in a.iter().zip(b.iter()).enumerate() {
            assert!(
                (ai - bi).abs() < tol,
                "{}: element [{}] differs: {} vs {} (diff={})",
                msg,
                i,
                ai,
                bi,
                (ai - bi).abs()
            );
        }
    }

    #[test]
    fn test_svd_identity_2x2() {
        let eye = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], &[2, 2]).unwrap();
        let (u, s, vt) = eye.svd().unwrap();
        // Singular values should be [1, 1]
        assert!((s[0] - 1.0).abs() < 1e-10, "s[0] = {}", s[0]);
        assert!((s[1] - 1.0).abs() < 1e-10, "s[1] = {}", s[1]);
        // Roundtrip
        let recon = reconstruct_svd(&u, &s, &vt);
        assert_approx_eq(&recon, &[1.0, 0.0, 0.0, 1.0], 1e-10, "SVD identity roundtrip");
    }

    #[test]
    fn test_svd_identity_3x3() {
        let mut data = vec![0.0; 9];
        for i in 0..3 {
            data[i * 3 + i] = 1.0;
        }
        let eye = Tensor::from_vec(data.clone(), &[3, 3]).unwrap();
        let (u, s, vt) = eye.svd().unwrap();
        for &si in &s {
            assert!((si - 1.0).abs() < 1e-10, "singular value = {}", si);
        }
        let recon = reconstruct_svd(&u, &s, &vt);
        assert_approx_eq(&recon, &data, 1e-10, "SVD 3x3 identity roundtrip");
    }

    #[test]
    fn test_svd_known_matrix() {
        // A = [[3, 0], [0, 2]]; singular values should be 3, 2
        let a = Tensor::from_vec(vec![3.0, 0.0, 0.0, 2.0], &[2, 2]).unwrap();
        let (_u, s, _vt) = a.svd().unwrap();
        assert!((s[0] - 3.0).abs() < 1e-10, "s[0] = {}", s[0]);
        assert!((s[1] - 2.0).abs() < 1e-10, "s[1] = {}", s[1]);
    }

    #[test]
    fn test_svd_roundtrip_general() {
        let a = Tensor::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.5],
            &[3, 3],
        )
        .unwrap();
        let (u, s, vt) = a.svd().unwrap();
        let recon = reconstruct_svd(&u, &s, &vt);
        let original = a.to_vec();
        assert_approx_eq(&recon, &original, 1e-8, "SVD general roundtrip");
    }

    #[test]
    fn test_svd_rectangular_tall() {
        // 3x2 matrix
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap();
        let (u, s, vt) = a.svd().unwrap();
        assert_eq!(u.shape(), &[3, 2]);
        assert_eq!(s.len(), 2);
        assert_eq!(vt.shape(), &[2, 2]);
        let recon = reconstruct_svd(&u, &s, &vt);
        let original = a.to_vec();
        assert_approx_eq(&recon, &original, 1e-8, "SVD tall rectangular roundtrip");
    }

    #[test]
    fn test_svd_rectangular_wide() {
        // 2x3 matrix
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let (u, s, vt) = a.svd().unwrap();
        assert_eq!(u.shape(), &[2, 2]);
        assert_eq!(s.len(), 2);
        assert_eq!(vt.shape(), &[2, 3]);
        let recon = reconstruct_svd(&u, &s, &vt);
        let original = a.to_vec();
        assert_approx_eq(&recon, &original, 1e-8, "SVD wide rectangular roundtrip");
    }

    #[test]
    fn test_svd_singular_values_descending() {
        let a = Tensor::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0],
            &[3, 3],
        )
        .unwrap();
        let (_, s, _) = a.svd().unwrap();
        for i in 0..s.len() - 1 {
            assert!(s[i] >= s[i + 1], "singular values not descending: {} < {}", s[i], s[i + 1]);
        }
    }

    #[test]
    fn test_svd_truncated_basic() {
        let a = Tensor::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0],
            &[3, 3],
        )
        .unwrap();
        let (u, s, vt) = a.svd_truncated(2).unwrap();
        assert_eq!(u.shape(), &[3, 2]);
        assert_eq!(s.len(), 2);
        assert_eq!(vt.shape(), &[2, 3]);
    }

    #[test]
    fn test_svd_deterministic() {
        let a = Tensor::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0],
            &[3, 3],
        )
        .unwrap();
        let (u1, s1, vt1) = a.svd().unwrap();
        let (u2, s2, vt2) = a.svd().unwrap();
        assert_eq!(u1.to_vec(), u2.to_vec(), "U not deterministic");
        assert_eq!(s1, s2, "S not deterministic");
        assert_eq!(vt1.to_vec(), vt2.to_vec(), "Vt not deterministic");
    }

    #[test]
    fn test_pinv_square() {
        // For a non-singular square matrix, pinv(A) ≈ inv(A)
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 5.0], &[2, 2]).unwrap();
        let a_pinv = a.pinv().unwrap();
        // Check A @ A+ @ A ≈ A
        let a_ap = a.matmul(&a_pinv).unwrap();
        let a_ap_a = a_ap.matmul(&a).unwrap();
        assert_approx_eq(&a_ap_a.to_vec(), &a.to_vec(), 1e-8, "pinv square: A @ A+ @ A ≈ A");
    }

    #[test]
    fn test_pinv_identity() {
        let eye = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], &[2, 2]).unwrap();
        let eye_pinv = eye.pinv().unwrap();
        assert_approx_eq(
            &eye_pinv.to_vec(),
            &[1.0, 0.0, 0.0, 1.0],
            1e-10,
            "pinv of identity",
        );
    }

    #[test]
    fn test_pinv_rectangular() {
        // Tall matrix: 3x2
        let a = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0, 0.0, 0.0], &[3, 2]).unwrap();
        let a_pinv = a.pinv().unwrap();
        assert_eq!(a_pinv.shape(), &[2, 3]);
        // A @ A+ @ A ≈ A
        let a_ap = a.matmul(&a_pinv).unwrap();
        let a_ap_a = a_ap.matmul(&a).unwrap();
        assert_approx_eq(&a_ap_a.to_vec(), &a.to_vec(), 1e-8, "pinv rect: A @ A+ @ A ≈ A");
    }

    #[test]
    fn test_pinv_moore_penrose_conditions() {
        // All 4 Moore-Penrose conditions for a general matrix
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let ap = a.pinv().unwrap();
        // Condition 1: A @ A+ @ A ≈ A
        let aapa = a.matmul(&ap).unwrap().matmul(&a).unwrap();
        assert_approx_eq(&aapa.to_vec(), &a.to_vec(), 1e-6, "MP condition 1");
        // Condition 2: A+ @ A @ A+ ≈ A+
        let apaap = ap.matmul(&a).unwrap().matmul(&ap).unwrap();
        assert_approx_eq(&apaap.to_vec(), &ap.to_vec(), 1e-6, "MP condition 2");
        // Conditions 3 & 4: A @ A+ (2x2) and A+ @ A (3x3) are symmetric
        for (p, dim) in [(a.matmul(&ap).unwrap(), 2usize), (ap.matmul(&a).unwrap(), 3usize)] {
            let d = p.to_vec();
            for i in 0..dim {
                for j in 0..dim {
                    assert!(
                        (d[i * dim + j] - d[j * dim + i]).abs() < 1e-6,
                        "MP conditions 3/4: asymmetry at [{}, {}]",
                        i,
                        j
                    );
                }
            }
        }
    }

    #[test]
    fn test_pinv_with_tol() {
        let a = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1e-16], &[2, 2]).unwrap();
        // With large tolerance, treat 1e-16 as zero
        let ap = a.pinv_with_tol(1e-10).unwrap();
        let ap_data = ap.to_vec();
        // Should act like pseudoinverse of [[1,0],[0,0]]
        assert!((ap_data[0] - 1.0).abs() < 1e-8, "pinv_with_tol [0,0]");
        assert!(ap_data[3].abs() < 1e-8, "pinv_with_tol [1,1] should be ~0");
    }

    #[test]
    fn test_svd_1x1() {
        let a = Tensor::from_vec(vec![5.0], &[1, 1]).unwrap();
        let (u, s, vt) = a.svd().unwrap();
        assert!((s[0] - 5.0).abs() < 1e-10);
        let recon = reconstruct_svd(&u, &s, &vt);
        assert_approx_eq(&recon, &[5.0], 1e-10, "SVD 1x1 roundtrip");
    }
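
    // Sketch tests for the Givens helper and matrix_exp. They rely only on
    // APIs defined in this file and on textbook identities; the tolerances are
    // illustrative assumptions, not part of any documented contract.
    #[test]
    fn test_givens_rotation_annihilates_b() {
        // 3-4-5 triangle: expect (c, s, r) = (0.6, 0.8, 5.0).
        let (c, s, r) = Tensor::givens_rotation(3.0, 4.0);
        assert!((c - 0.6).abs() < 1e-12 && (s - 0.8).abs() < 1e-12);
        assert!((r - 5.0).abs() < 1e-12);
        // The rotation zeroes the second component: -s*a + c*b = 0.
        assert!((-s * 3.0 + c * 4.0).abs() < 1e-12);
    }

    #[test]
    fn test_matrix_exp_zero_is_identity() {
        // exp(0) = I.
        let z = Tensor::from_vec(vec![0.0; 4], &[2, 2]).unwrap();
        let e = z.matrix_exp().unwrap();
        assert_approx_eq(&e.to_vec(), &[1.0, 0.0, 0.0, 1.0], 1e-12, "exp(0) = I");
    }

    #[test]
    fn test_matrix_exp_diagonal() {
        // exp(diag(1, 2)) = diag(e, e^2).
        let a = Tensor::from_vec(vec![1.0, 0.0, 0.0, 2.0], &[2, 2]).unwrap();
        let e = a.matrix_exp().unwrap();
        let d = e.to_vec();
        assert!((d[0] - 1.0f64.exp()).abs() < 1e-9, "exp(1) entry: {}", d[0]);
        assert!((d[3] - 2.0f64.exp()).abs() < 1e-9, "exp(2) entry: {}", d[3]);
        assert!(d[1].abs() < 1e-10 && d[2].abs() < 1e-10, "off-diagonal entries");
    }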
}