// pc_rl_core/linalg/golub_kahan.rs
// Author: Julian Bolivar
// Version: 1.0.0
// Date: 2026-04-04

//! Golub-Kahan bidiagonalization SVD algorithm.
//!
//! Provides an O(n^3) SVD implementation as an upgrade path from the
//! existing Jacobi eigendecomposition approach in [`CpuLinAlg`](super::cpu::CpuLinAlg).
//!
//! # Overview
//!
//! The algorithm proceeds in two phases:
//! 1. **Householder bidiagonalization** reduces the input matrix to
//!    upper bidiagonal form.
//! 2. **Implicit QR iteration** with Wilkinson shift extracts singular
//!    values and vectors from the bidiagonal matrix.
//!
//! # Examples
//!
//! ```
//! use pc_rl_core::linalg::golub_kahan::GolubKahanSvd;
//!
//! let svd = GolubKahanSvd::new()
//!     .with_tolerance(1e-12)
//!     .with_max_iter_factor(40);
//! assert!((svd.tol - 1e-12).abs() < f64::EPSILON);
//! assert_eq!(svd.max_iter_factor, 40);
//! ```

use std::fmt;

use crate::error::PcError;
use crate::matrix::Matrix;
/// Error type for SVD decomposition failures.
///
/// Converts to [`PcError::ConfigValidation`] via the [`From`] impl,
/// preserving the error message for upstream callers.
///
/// # Examples
///
/// ```
/// use pc_rl_core::linalg::golub_kahan::SvdError;
///
/// let err = SvdError::Convergence { size: 10, iterations: 300 };
/// assert!(format!("{err}").contains("10"));
/// ```
#[derive(Debug)]
pub enum SvdError {
    /// The iterative solver did not converge within the allowed iterations.
    ///
    /// Raised by the implicit QR phase when `max_iter_factor * k * k`
    /// sweeps are exhausted without the superdiagonal vanishing.
    Convergence {
        /// Matrix dimension (min(rows, cols)) that failed to converge.
        size: usize,
        /// Total iterations attempted before giving up.
        iterations: usize,
    },
    /// The input matrix is invalid (e.g., contains NaN or Inf).
    InvalidInput {
        /// Human-readable description of the problem.
        reason: String,
    },
}

64impl fmt::Display for SvdError {
65    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66        match self {
67            SvdError::Convergence { size, iterations } => {
68                write!(
69                    f,
70                    "SVD failed to converge for matrix of size {size} \
71                     after {iterations} iterations"
72                )
73            }
74            SvdError::InvalidInput { reason } => {
75                write!(f, "SVD invalid input: {reason}")
76            }
77        }
78    }
79}
80
impl std::error::Error for SvdError {}

/// Maps SVD failures onto the crate-wide error type.
impl From<SvdError> for PcError {
    fn from(e: SvdError) -> Self {
        // NOTE(review): reuses the ConfigValidation variant rather than a
        // dedicated numeric-error variant; only the Display text survives.
        PcError::ConfigValidation(e.to_string())
    }
}

/// Golub-Kahan bidiagonalization SVD solver.
///
/// Decomposes a matrix `A` into `U * diag(S) * V^T` using Householder
/// bidiagonalization followed by implicit QR iteration with Wilkinson shift.
///
/// # Fields
///
/// * `tol` - Convergence tolerance for off-diagonal elements (default: 1e-14).
/// * `max_iter_factor` - Maximum iterations = `factor * n²` where `n = min(rows, cols)` (default: 30).
///
/// # Examples
///
/// ```
/// use pc_rl_core::linalg::golub_kahan::GolubKahanSvd;
///
/// let svd = GolubKahanSvd::new();
/// assert!((svd.tol - 1e-14).abs() < f64::EPSILON);
/// assert_eq!(svd.max_iter_factor, 30);
/// ```
#[derive(Debug, Clone)]
pub struct GolubKahanSvd {
    /// Convergence tolerance for off-diagonal elements.
    pub tol: f64,
    /// Maximum iterations as a multiple of `min(rows, cols)²`.
    pub max_iter_factor: usize,
}

impl GolubKahanSvd {
    /// Creates a new solver with default parameters.
    ///
    /// Defaults: `tol = 1e-14`, `max_iter_factor = 30`.
    ///
    /// # Examples
    ///
    /// ```
    /// use pc_rl_core::linalg::golub_kahan::GolubKahanSvd;
    ///
    /// let svd = GolubKahanSvd::new();
    /// assert!((svd.tol - 1e-14).abs() < f64::EPSILON);
    /// ```
    pub fn new() -> Self {
        Self {
            tol: 1e-14,
            max_iter_factor: 30,
        }
    }

    /// Sets a custom convergence tolerance (builder style).
    ///
    /// # Arguments
    ///
    /// * `tol` - The convergence threshold for off-diagonal elements.
    ///
    /// # Examples
    ///
    /// ```
    /// use pc_rl_core::linalg::golub_kahan::GolubKahanSvd;
    ///
    /// let svd = GolubKahanSvd::new().with_tolerance(1e-8);
    /// assert!((svd.tol - 1e-8).abs() < f64::EPSILON);
    /// ```
    pub fn with_tolerance(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Sets a custom maximum iteration factor (builder style).
    ///
    /// Total maximum iterations will be `factor * min(rows, cols)²`.
    ///
    /// # Arguments
    ///
    /// * `factor` - Multiplier for the iteration limit.
    ///
    /// # Examples
    ///
    /// ```
    /// use pc_rl_core::linalg::golub_kahan::GolubKahanSvd;
    ///
    /// let svd = GolubKahanSvd::new().with_max_iter_factor(50);
    /// assert_eq!(svd.max_iter_factor, 50);
    /// ```
    pub fn with_max_iter_factor(mut self, factor: usize) -> Self {
        self.max_iter_factor = factor;
        self
    }

    /// Computes the SVD of matrix `a`: `A = U * diag(S) * V^T`.
    ///
    /// Returns `(U, S, V)` where:
    /// - `U` is `(m, k)` with orthonormal columns,
    /// - `S` is a `Vec<f64>` of `k` non-negative singular values in descending order,
    /// - `V` is `(n, k)` with orthonormal columns,
    /// - `k = min(m, n)`.
    ///
    /// `Matrix.data` is treated as row-major throughout (`data[r * cols + c]`).
    ///
    /// # Arguments
    ///
    /// * `a` - The input matrix to decompose.
    ///
    /// # Errors
    ///
    /// Returns [`SvdError::InvalidInput`] if the matrix contains NaN or Inf values.
    /// Returns [`SvdError::Convergence`] if the iterative solver does not converge.
    ///
    /// # Examples
    ///
    /// ```
    /// use pc_rl_core::linalg::golub_kahan::GolubKahanSvd;
    /// use pc_rl_core::matrix::Matrix;
    ///
    /// let svd = GolubKahanSvd::new();
    /// let empty = Matrix::zeros(0, 0);
    /// let (u, s, v) = svd.compute(&empty).unwrap();
    /// assert_eq!(s.len(), 0);
    /// ```
    pub fn compute(&self, a: &Matrix) -> Result<(Matrix, Vec<f64>, Matrix), SvdError> {
        // Validate: reject NaN/Inf up front so the iteration below never
        // has to reason about non-finite arithmetic.
        for &val in &a.data {
            if val.is_nan() || val.is_infinite() {
                return Err(SvdError::InvalidInput {
                    reason: "matrix contains NaN or Inf".to_string(),
                });
            }
        }

        let m = a.rows;
        let n = a.cols;

        // Handle empty matrix: zero singular values, degenerate factors.
        if m == 0 || n == 0 {
            return Ok((Matrix::zeros(m, 0), Vec::new(), Matrix::zeros(n, 0)));
        }

        // If wide (m < n), transpose and swap U/V at the end so the core
        // algorithm only ever sees a tall (work_m >= work_n) matrix.
        let transposed = m < n;
        let (work_m, work_n, work_data) = if transposed {
            let mut t = vec![0.0; m * n];
            for r in 0..m {
                for c in 0..n {
                    t[c * m + r] = a.data[r * n + c];
                }
            }
            (n, m, t)
        } else {
            (m, n, a.data.clone())
        };

        // k = min(work_m, work_n) = work_n since work_m >= work_n
        let k = work_n;

        // Handle the single-column case directly (no iteration needed).
        if k == 1 {
            // For a single-column tall matrix, SVD is: u = a/||a||, s = ||a||, v = [1]
            let norm: f64 = work_data.iter().map(|&x| x * x).sum::<f64>().sqrt();
            if norm < self.tol {
                // Numerically zero column: return a canonical basis vector.
                let u_mat = make_identity(work_m, 1);
                let v_mat = make_identity(1, 1);
                return Self::finalize(u_mat, vec![0.0], v_mat, transposed);
            }
            // `sign` keeps the leading entry of U non-negative while
            // preserving the reconstruction u * s * v^T = a.
            let sign = if work_data[0] >= 0.0 { 1.0 } else { -1.0 };
            let mut u_data = vec![0.0; work_m];
            for i in 0..work_m {
                u_data[i] = work_data[i] * sign / norm;
            }
            let u_mat = Matrix {
                data: u_data,
                rows: work_m,
                cols: 1,
            };
            let v_mat = Matrix {
                data: vec![sign],
                rows: 1,
                cols: 1,
            };
            return Self::finalize(u_mat, vec![norm], v_mat, transposed);
        }

        // Phase 1: Householder bidiagonalization.
        // Work on a copy. After this: A is overwritten, diag and superdiag extracted.
        let mut w = work_data;
        let mut u_acc = make_identity(work_m, work_m);
        let mut v_acc = make_identity(work_n, work_n);

        let mut diag = vec![0.0; k];
        let mut superdiag = vec![0.0; k.saturating_sub(1)];

        householder_bidiag(
            &mut w,
            work_m,
            work_n,
            &mut u_acc,
            &mut v_acc,
            &mut diag,
            &mut superdiag,
        );

        // Phase 2: Implicit QR iteration on the bidiagonal matrix.
        let max_iter = self.max_iter_factor * k * k;
        implicit_qr_svd(
            &mut diag,
            &mut superdiag,
            &mut u_acc,
            &mut v_acc,
            work_m,
            work_n,
            k,
            self.tol,
            max_iter,
        )?;

        // Phase 3: Make singular values non-negative; a sign flip is
        // absorbed into the corresponding column of U so U * S * V^T
        // is unchanged.
        for (i, d) in diag.iter_mut().enumerate().take(k) {
            if *d < 0.0 {
                *d = -*d;
                // Flip corresponding column of U
                for r in 0..work_m {
                    u_acc.data[r * work_m + i] = -u_acc.data[r * work_m + i];
                }
            }
        }

        // Sort by descending singular value (total_cmp: values are finite here).
        let mut indices: Vec<usize> = (0..k).collect();
        indices.sort_by(|&a_idx, &b_idx| diag[b_idx].total_cmp(&diag[a_idx]));

        let sorted_s: Vec<f64> = indices.iter().map(|&i| diag[i]).collect();

        // Extract thin U (work_m x k) with sorted columns
        let mut u_thin = Matrix::zeros(work_m, k);
        for (new_col, &old_col) in indices.iter().enumerate() {
            for r in 0..work_m {
                u_thin.data[r * k + new_col] = u_acc.data[r * work_m + old_col];
            }
        }

        // Extract thin V (work_n x k) with sorted columns
        let mut v_thin = Matrix::zeros(work_n, k);
        for (new_col, &old_col) in indices.iter().enumerate() {
            for r in 0..work_n {
                v_thin.data[r * k + new_col] = v_acc.data[r * work_n + old_col];
            }
        }

        Self::finalize(u_thin, sorted_s, v_thin, transposed)
    }

    /// Swaps U and V if the input was transposed, returning the final result.
    ///
    /// If `A^T = U S V^T` then `A = V S U^T`, so for a wide input the
    /// factors computed on the transpose simply change roles.
    fn finalize(
        u: Matrix,
        s: Vec<f64>,
        v: Matrix,
        transposed: bool,
    ) -> Result<(Matrix, Vec<f64>, Matrix), SvdError> {
        if transposed {
            Ok((v, s, u))
        } else {
            Ok((u, s, v))
        }
    }
}

350impl Default for GolubKahanSvd {
351    fn default() -> Self {
352        Self::new()
353    }
354}
355
356/// Creates an identity-like matrix of size `rows x cols`.
357///
358/// Places 1.0 on the diagonal up to `min(rows, cols)`.
359fn make_identity(rows: usize, cols: usize) -> Matrix {
360    let mut data = vec![0.0; rows * cols];
361    let k = rows.min(cols);
362    for i in 0..k {
363        data[i * cols + i] = 1.0;
364    }
365    Matrix { data, rows, cols }
366}
367
368/// Householder bidiagonalization: reduces an m x n matrix (m >= n)
369/// to upper bidiagonal form, accumulating transformations into U and V.
370///
371/// After completion, `diag` contains the diagonal and `superdiag` the
372/// superdiagonal of the bidiagonal matrix B such that A = U * B * V^T.
373///
374/// # Arguments
375///
376/// * `w` - Row-major m x n matrix data, overwritten during computation.
377/// * `m` - Number of rows.
378/// * `n` - Number of columns.
379/// * `u_acc` - m x m identity matrix, accumulates left reflections.
380/// * `v_acc` - n x n identity matrix, accumulates right reflections.
381/// * `diag` - Output diagonal of bidiagonal B (length n).
382/// * `superdiag` - Output superdiagonal of bidiagonal B (length n-1).
383fn householder_bidiag(
384    w: &mut [f64],
385    m: usize,
386    n: usize,
387    u_acc: &mut Matrix,
388    v_acc: &mut Matrix,
389    diag: &mut [f64],
390    superdiag: &mut [f64],
391) {
392    for j in 0..n {
393        // Left Householder: zero below w[j][j] in column j
394        {
395            let mut col = vec![0.0; m - j];
396            for i in j..m {
397                col[i - j] = w[i * n + j];
398            }
399            let (v_house, beta) = householder_vector(&col);
400            if beta != 0.0 {
401                // Apply H = I - beta * v * v^T to w[j..m, j..n]
402                apply_householder_left(w, m, n, j, j, &v_house, beta);
403                // Accumulate into U: U = U * H (apply on right to u_acc)
404                apply_householder_right_to_matrix(u_acc, m, m, j, &v_house, beta);
405            }
406        }
407        diag[j] = w[j * n + j];
408
409        // Right Householder: zero beyond w[j][j+1] in row j
410        if j + 2 <= n {
411            let start = j + 1;
412            let mut row = vec![0.0; n - start];
413            for c in start..n {
414                row[c - start] = w[j * n + c];
415            }
416            let (v_house, beta) = householder_vector(&row);
417            if beta != 0.0 {
418                // Apply H = I - beta * v * v^T to w[j..m, (j+1)..n] on the right
419                apply_householder_right(w, m, n, j, start, &v_house, beta);
420                // Accumulate into V: V = V * H
421                apply_householder_right_to_matrix(v_acc, n, n, start, &v_house, beta);
422            }
423            if j < n - 1 {
424                superdiag[j] = w[j * n + j + 1];
425            }
426        } else if j < n - 1 {
427            superdiag[j] = w[j * n + j + 1];
428        }
429    }
430}
431
/// Computes a Householder vector `v` and scalar `beta` such that
/// `(I - beta * v * v^T) * x` zeroes all elements below the first.
///
/// Returns `(v, beta)` where `v` is normalized so `v[0] = 1.0`.
/// Degenerate inputs (empty, single element, or a tail that is already
/// numerically zero) yield `beta = 0.0`, i.e. the identity reflection.
///
/// # Arguments
///
/// * `x` - Input vector to reflect.
fn householder_vector(x: &[f64]) -> (Vec<f64>, f64) {
    match x.len() {
        0 => return (Vec::new(), 0.0),
        1 => return (vec![1.0], 0.0),
        _ => {}
    }

    // sigma = ||x[1..]||^2 — the energy that the reflection must fold
    // into the leading component.
    let sigma: f64 = x[1..].iter().map(|&xi| xi * xi).sum();

    let mut v = x.to_vec();
    v[0] = 1.0;

    if sigma < 1e-300 {
        // Tail already zero: nothing to reflect.
        return (v, 0.0);
    }

    let norm_x = (x[0] * x[0] + sigma).sqrt();
    // Choose the sign that avoids catastrophic cancellation in v[0]
    // (Golub & Van Loan, Algorithm 5.1.1).
    v[0] = if x[0] <= 0.0 {
        x[0] - norm_x
    } else {
        -sigma / (x[0] + norm_x)
    };

    let beta = 2.0 * v[0] * v[0] / (sigma + v[0] * v[0]);

    // Rescale so the stored vector has a unit leading entry.
    let pivot = v[0];
    for elem in v.iter_mut() {
        *elem /= pivot;
    }

    (v, beta)
}

/// Applies a left Householder reflection: w[row_start..m, col_start..n] =
/// (I - beta * v * v^T) * w[row_start..m, col_start..n].
///
/// `w` is row-major with row stride `n`; the `_m` parameter is unused
/// because the affected row count is implied by `v.len()`.
fn apply_householder_left(
    w: &mut [f64],
    _m: usize,
    n: usize,
    row_start: usize,
    col_start: usize,
    v: &[f64],
    beta: f64,
) {
    // Columns are independent: for each column c of the submatrix,
    // compute proj = v . w[row_start.., c] and subtract beta * proj * v.
    for c in col_start..n {
        let mut proj = 0.0;
        for (offset, &vi) in v.iter().enumerate() {
            proj += vi * w[(row_start + offset) * n + c];
        }
        for (offset, &vi) in v.iter().enumerate() {
            w[(row_start + offset) * n + c] -= beta * vi * proj;
        }
    }
}

/// Applies a right Householder reflection: w[row_start..m, col_start..n] =
/// w[row_start..m, col_start..n] * (I - beta * v * v^T).
///
/// `w` is row-major with row stride `n`; the affected column count is
/// implied by `v.len()`.
fn apply_householder_right(
    w: &mut [f64],
    m: usize,
    n: usize,
    row_start: usize,
    col_start: usize,
    v: &[f64],
    beta: f64,
) {
    // Rows are independent: for each row, project onto v then subtract.
    for row in row_start..m {
        let base = row * n + col_start;
        let dot: f64 = v
            .iter()
            .enumerate()
            .map(|(idx, &vi)| w[base + idx] * vi)
            .sum();
        for (idx, &vi) in v.iter().enumerate() {
            w[base + idx] -= beta * dot * vi;
        }
    }
}

536/// Applies a Householder reflection on the right to a full accumulator matrix:
537/// `acc = acc * H` where H operates on rows `start..start+v.len()`.
538///
539/// Equivalent to: for each row of acc, update columns start..start+v.len().
540fn apply_householder_right_to_matrix(
541    acc: &mut Matrix,
542    rows: usize,
543    cols: usize,
544    col_start: usize,
545    v: &[f64],
546    beta: f64,
547) {
548    let v_len = v.len();
549    for r in 0..rows {
550        let mut dot = 0.0;
551        for (vi_idx, &vi) in v.iter().enumerate().take(v_len) {
552            dot += acc.data[r * cols + col_start + vi_idx] * vi;
553        }
554        for (vi_idx, &vi) in v.iter().enumerate().take(v_len) {
555            acc.data[r * cols + col_start + vi_idx] -= beta * dot * vi;
556        }
557    }
558}
559
/// Computes a Givens rotation `(c, s)` such that
/// `c * a + s * b = r` and `-s * a + c * b = 0`,
/// i.e. `[c s; -s c] * [a; b] = [r; 0]`.
///
/// Uses the cancellation-free formulation that divides the smaller
/// magnitude by the larger one.
///
/// # Arguments
///
/// * `a` - First element.
/// * `b` - Second element (to be zeroed).
///
/// # Returns
///
/// `(c, s)` cosine and sine of the rotation angle.
fn givens_rotation(a: f64, b: f64) -> (f64, f64) {
    if b.abs() < 1e-300 {
        // b is already (numerically) zero: identity rotation.
        (1.0, 0.0)
    } else if a.abs() < 1e-300 {
        // Pure swap, with sign chosen so r >= 0 in the b slot.
        (0.0, b.signum())
    } else if b.abs() > a.abs() {
        let tau = a / b;
        let s = (1.0 + tau * tau).sqrt().recip() * b.signum();
        (s * tau, s)
    } else {
        let tau = b / a;
        let c = (1.0 + tau * tau).sqrt().recip() * a.signum();
        (c, c * tau)
    }
}

591/// Applies a Givens rotation to columns `i` and `j` of a matrix,
592/// operating on all rows: `[col_i, col_j] = [col_i, col_j] * G`.
593fn apply_givens_cols(
594    mat: &mut Matrix,
595    rows: usize,
596    stride: usize,
597    i: usize,
598    j: usize,
599    c: f64,
600    s: f64,
601) {
602    for r in 0..rows {
603        let a = mat.data[r * stride + i];
604        let b = mat.data[r * stride + j];
605        mat.data[r * stride + i] = c * a + s * b;
606        mat.data[r * stride + j] = -s * a + c * b;
607    }
608}
609
/// Implicit QR iteration on a bidiagonal matrix to extract singular values.
///
/// Operates on the diagonal `diag` and superdiagonal `superdiag`, applying
/// Givens rotations accumulated into `u_acc` and `v_acc`. On success the
/// singular values are left in `diag` (possibly negative and unsorted —
/// the caller normalizes signs and ordering).
///
/// # Errors
///
/// Returns [`SvdError::Convergence`] if the iteration does not converge
/// within `max_iter` total sweeps.
#[allow(clippy::too_many_arguments)]
fn implicit_qr_svd(
    diag: &mut [f64],
    superdiag: &mut [f64],
    u_acc: &mut Matrix,
    v_acc: &mut Matrix,
    u_rows: usize,
    v_rows: usize,
    k: usize,
    tol: f64,
    max_iter: usize,
) -> Result<(), SvdError> {
    // A 1x1 bidiagonal matrix is already diagonal.
    if k <= 1 {
        return Ok(());
    }

    let mut iter_count = 0usize;

    loop {
        // Deflate from the bottom: find the largest q such that the
        // trailing q superdiagonal entries are negligible (relative to
        // their neighboring diagonal magnitudes) and flush them to zero.
        let mut q = 0usize;
        while q < k - 1 {
            let idx = k - 2 - q;
            let thresh = tol * (diag[idx].abs() + diag[idx + 1].abs());
            // The absolute floor `tol * 1e-2` also catches the case where
            // both neighboring diagonal entries are themselves tiny.
            if superdiag[idx].abs() <= thresh.max(tol * 1e-2) {
                superdiag[idx] = 0.0;
                q += 1;
            } else {
                break;
            }
        }

        if q >= k - 1 {
            // All superdiagonal elements are zero — converged
            break;
        }

        // Active block is diag[p..k-q], superdiag[p..k-q-1].
        let block_end = k - q; // exclusive index into diag
        // Find p: scan upward for a negligible superdiagonal entry that
        // bounds the unreduced block from above (or p = 0).
        let mut p = block_end - 1;
        while p > 0 {
            let idx = p - 1;
            let thresh = tol * (diag[idx].abs() + diag[idx + 1].abs());
            if superdiag[idx].abs() <= thresh.max(tol * 1e-2) {
                superdiag[idx] = 0.0;
                break;
            }
            p -= 1;
        }

        // A 1-element block needs no QR step; the next deflation pass will
        // absorb it (the zeroed superdiag above guarantees progress).
        let block_size = block_end - p;
        if block_size <= 1 {
            continue;
        }

        // Check for a (numerically) zero diagonal entry in the block — if
        // found, annihilate the adjacent superdiagonal entry by an explicit
        // chain of rotations and restart the outer loop.
        let mut found_zero_diag = false;
        for i in p..block_end {
            if diag[i].abs() < tol * 1e-2 {
                // Zero out superdiag element adjacent to this zero diagonal
                if i < block_end - 1 && superdiag[i].abs() > 0.0 {
                    zero_superdiag_row(diag, superdiag, u_acc, u_rows, i, block_end);
                } else if i > p && superdiag[i - 1].abs() > 0.0 {
                    zero_superdiag_col(diag, superdiag, v_acc, v_rows, i, p);
                }
                found_zero_diag = true;
                break;
            }
        }
        if found_zero_diag {
            // This counts as an iteration so a pathological matrix cannot
            // loop here forever.
            iter_count += 1;
            if iter_count > max_iter {
                return Err(SvdError::Convergence {
                    size: k,
                    iterations: max_iter,
                });
            }
            continue;
        }

        // Wilkinson shift from the trailing 2x2 block of T = B^T * B:
        //   [ d_n2^2 + e_{n2-1}^2    d_n2 * e_n2          ]
        //   [ d_n2 * e_n2            d_n1^2 + e_n2^2      ]
        let n1 = block_end - 1;
        let n2 = block_end - 2;
        let d_n1 = diag[n1];
        let d_n2 = diag[n2];
        let e_n2 = superdiag[n2];
        // e_{n2-1} exists only if the block extends above row n2.
        let e_n3_sq = if n2 > p {
            superdiag[n2 - 1] * superdiag[n2 - 1]
        } else {
            0.0
        };
        let t11 = d_n2 * d_n2 + e_n3_sq;
        let t12 = d_n2 * e_n2;
        let t22 = d_n1 * d_n1 + e_n2 * e_n2;

        let shift = wilkinson_shift(t11, t12, t22);

        // Golub-Kahan SVD step (bulge chase) over the active block.
        golub_kahan_step(
            diag, superdiag, u_acc, v_acc, u_rows, v_rows, p, block_end, shift,
        );

        iter_count += 1;
        if iter_count > max_iter {
            return Err(SvdError::Convergence {
                size: k,
                iterations: max_iter,
            });
        }
    }

    Ok(())
}

/// Computes the Wilkinson shift for the trailing 2x2 of B^T * B.
///
/// Given the symmetric 2x2 matrix `[[a, b], [b, d]]`, returns the
/// eigenvalue closest to `d`, using the cancellation-avoiding form
/// `d - b^2 / (delta + sign(delta) * sqrt(delta^2 + b^2))`.
fn wilkinson_shift(a: f64, b: f64, d: f64) -> f64 {
    let delta = 0.5 * (a - d);
    // Fully degenerate case: the matrix is (numerically) d * I.
    if delta.abs() < 1e-300 && b.abs() < 1e-300 {
        return d;
    }
    let root = (delta * delta + b * b).sqrt();
    let denom = if delta >= 0.0 {
        delta + root
    } else {
        delta - root
    };
    d - b * b / denom
}

/// Performs one implicit QR step (bulge chase) on the bidiagonal matrix.
///
/// This is the core of the Golub-Kahan SVD iteration (Algorithm 8.6.2
/// from Golub & Van Loan): an initial right rotation derived from the
/// shifted first column of B^T*B introduces an off-bidiagonal "bulge"
/// (tracked in `z`), which alternating right/left Givens rotations then
/// chase down and off the end of the block.
#[allow(clippy::too_many_arguments)]
fn golub_kahan_step(
    diag: &mut [f64],
    superdiag: &mut [f64],
    u_acc: &mut Matrix,
    v_acc: &mut Matrix,
    u_rows: usize,
    v_rows: usize,
    p: usize,
    block_end: usize,
    shift: f64,
) {
    // (y, z) is the first column of B^T*B - shift*I restricted to the block;
    // rotating it starts the chase.
    let mut y = diag[p] * diag[p] - shift;
    let mut z = diag[p] * superdiag[p];

    for i in p..block_end - 1 {
        // Right Givens rotation to zero z (applied to V)
        let (c, s) = givens_rotation(y, z);
        if i > p {
            superdiag[i - 1] = c * superdiag[i - 1] + s * z;
            // Note: the component that was z is now zero
        }
        let old_d_i = diag[i];
        let old_e_i = superdiag[i];
        diag[i] = c * old_d_i + s * old_e_i;
        superdiag[i] = -s * old_d_i + c * old_e_i;
        let old_d_i1 = diag[i + 1];
        // The rotation spills part of diag[i+1] into the bulge below the
        // diagonal; z now holds that sub-diagonal fill-in.
        z = s * old_d_i1;
        diag[i + 1] = c * old_d_i1;

        // Accumulate into V
        apply_givens_cols(v_acc, v_rows, v_rows, i, i + 1, c, s);

        // Left Givens rotation to zero z (applied to U)
        let (c, s) = givens_rotation(diag[i], z);
        diag[i] = c * diag[i] + s * z;
        let old_e_i = superdiag[i];
        let old_d_i1 = diag[i + 1];
        superdiag[i] = c * old_e_i + s * old_d_i1;
        diag[i + 1] = -s * old_e_i + c * old_d_i1;
        if i + 1 < block_end - 1 {
            // The left rotation creates the next bulge one column to the
            // right of the superdiagonal; carry it into the next iteration.
            let old_e_i1 = superdiag[i + 1];
            z = s * old_e_i1;
            superdiag[i + 1] = c * old_e_i1;
        }
        y = superdiag[i];

        // Accumulate into U
        apply_givens_cols(u_acc, u_rows, u_rows, i, i + 1, c, s);
    }
}

/// Zeros a superdiagonal element when a diagonal entry is zero,
/// chasing the bulge rightward via left rotations.
///
/// When `diag[zero_idx] ≈ 0`, the left rotation `G^T * B` between rows
/// `zero_idx` and `j+1` zeros the fill-in but creates a new bulge from
/// `superdiag[j+1]`. The bulge is tracked explicitly and passed forward
/// to the next iteration rather than stored back into `superdiag`.
fn zero_superdiag_row(
    diag: &mut [f64],
    superdiag: &mut [f64],
    u_acc: &mut Matrix,
    u_rows: usize,
    zero_idx: usize,
    block_end: usize,
) {
    // Lift the offending entry out of the matrix; it lives in `bulge`
    // from here on, moving one column right per rotation.
    let mut bulge = superdiag[zero_idx];
    superdiag[zero_idx] = 0.0;

    for j in zero_idx..block_end - 1 {
        // Rotate rows (j+1, zero_idx) so the bulge is absorbed into diag[j+1].
        let (c, s) = givens_rotation(diag[j + 1], bulge);
        diag[j + 1] = c * diag[j + 1] + s * bulge;
        // bulge position is now zero by construction
        if j + 1 < block_end - 1 {
            let old_e = superdiag[j + 1];
            superdiag[j + 1] = c * old_e;
            bulge = -s * old_e;
            // U accumulates the same rotation as a column rotation.
            // NOTE(review): column pair (j+1, zero_idx) mirrors the row pair
            // rotated in B — confirm sign convention against apply_givens_cols.
            apply_givens_cols(u_acc, u_rows, u_rows, j + 1, zero_idx, c, s);
            // Early exit once the chased bulge has decayed to nothing.
            if bulge.abs() < 1e-300 {
                break;
            }
        } else {
            apply_givens_cols(u_acc, u_rows, u_rows, j + 1, zero_idx, c, s);
        }
    }
}

/// Zeros a superdiagonal element when a diagonal entry is zero,
/// chasing the bulge leftward via right rotations.
///
/// When `diag[zero_idx] ≈ 0`, the right rotation `B * G` between columns
/// `zero_idx` and `j` zeros the fill-in but creates a new bulge from
/// `superdiag[j-1]`. The bulge is tracked explicitly and passed backward
/// to the next iteration rather than stored back into `superdiag`.
fn zero_superdiag_col(
    diag: &mut [f64],
    superdiag: &mut [f64],
    v_acc: &mut Matrix,
    v_rows: usize,
    zero_idx: usize,
    block_start: usize,
) {
    // Lift the entry above the zero diagonal out of the matrix; it lives
    // in `bulge` from here on, moving one column left per rotation.
    let mut bulge = superdiag[zero_idx - 1];
    superdiag[zero_idx - 1] = 0.0;

    // Walk upward from zero_idx - 1 back to the top of the block.
    for j in (block_start..zero_idx).rev() {
        // Rotate columns (j, zero_idx) so the bulge is absorbed into diag[j].
        let (c, s) = givens_rotation(diag[j], bulge);
        diag[j] = c * diag[j] + s * bulge;
        // bulge position is now zero by construction
        if j > block_start {
            let old_e = superdiag[j - 1];
            superdiag[j - 1] = c * old_e;
            bulge = -s * old_e;
            // V accumulates the same column rotation applied to B.
            apply_givens_cols(v_acc, v_rows, v_rows, j, zero_idx, c, s);
            // Early exit once the chased bulge has decayed to nothing.
            if bulge.abs() < 1e-300 {
                break;
            }
        } else {
            apply_givens_cols(v_acc, v_rows, v_rows, j, zero_idx, c, s);
        }
    }
}

880#[cfg(test)]
881mod tests {
882    use super::*;
883
884    #[test]
885    fn test_new_returns_default_parameters() {
886        let svd = GolubKahanSvd::new();
887        assert!((svd.tol - 1e-14).abs() < f64::EPSILON);
888        assert_eq!(svd.max_iter_factor, 30);
889    }
890
891    #[test]
892    fn test_with_tolerance_sets_custom_tol() {
893        let svd = GolubKahanSvd::new().with_tolerance(1e-8);
894        assert!((svd.tol - 1e-8).abs() < f64::EPSILON);
895    }
896
897    #[test]
898    fn test_with_max_iter_factor_sets_custom_factor() {
899        let svd = GolubKahanSvd::new().with_max_iter_factor(50);
900        assert_eq!(svd.max_iter_factor, 50);
901    }
902
903    #[test]
904    fn test_default_trait_matches_new() {
905        let a = GolubKahanSvd::new();
906        let b = GolubKahanSvd::default();
907        assert!((a.tol - b.tol).abs() < f64::EPSILON);
908        assert_eq!(a.max_iter_factor, b.max_iter_factor);
909    }
910
911    #[test]
912    fn test_svd_error_display() {
913        let err = SvdError::Convergence {
914            size: 10,
915            iterations: 300,
916        };
917        let msg = format!("{err}");
918        assert!(msg.contains("10"));
919        assert!(msg.contains("300"));
920    }
921
922    #[test]
923    fn test_svd_error_converts_to_pc_error() {
924        let err = SvdError::Convergence {
925            size: 5,
926            iterations: 150,
927        };
928        let pc_err: crate::error::PcError = err.into();
929        assert!(matches!(pc_err, crate::error::PcError::ConfigValidation(_)));
930    }
931
932    #[test]
933    fn test_empty_matrix() {
934        // B6: 0x0 matrix returns empty results
935        let a = Matrix::zeros(0, 0);
936        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
937        assert_eq!(u.rows, 0);
938        assert_eq!(u.cols, 0);
939        assert!(s.is_empty());
940        assert_eq!(v.rows, 0);
941        assert_eq!(v.cols, 0);
942    }
943
944    #[test]
945    fn test_nan_input_returns_error() {
946        // B13: NaN input returns Err
947        let a = Matrix {
948            data: vec![1.0, f64::NAN, 3.0, 4.0],
949            rows: 2,
950            cols: 2,
951        };
952        let result = GolubKahanSvd::new().compute(&a);
953        assert!(result.is_err());
954        let err = result.unwrap_err();
955        assert!(matches!(err, SvdError::InvalidInput { .. }));
956    }
957
958    #[test]
959    fn test_inf_input_returns_error() {
960        // B14: Inf input returns Err
961        let a = Matrix {
962            data: vec![1.0, f64::INFINITY, 3.0, 4.0],
963            rows: 2,
964            cols: 2,
965        };
966        let result = GolubKahanSvd::new().compute(&a);
967        assert!(result.is_err());
968        let err = result.unwrap_err();
969        assert!(matches!(err, SvdError::InvalidInput { .. }));
970    }
971
972    #[test]
973    fn test_neg_inf_input_returns_error() {
974        let a = Matrix {
975            data: vec![f64::NEG_INFINITY, 2.0, 3.0, 4.0],
976            rows: 2,
977            cols: 2,
978        };
979        let result = GolubKahanSvd::new().compute(&a);
980        assert!(result.is_err());
981    }
982
983    // ---- Helper functions for SVD correctness tests ----
984
985    /// Helper: multiply matrices using raw data.
986    fn mat_mul_raw(a: &Matrix, b: &Matrix) -> Matrix {
987        assert_eq!(a.cols, b.rows);
988        let mut c = Matrix::zeros(a.rows, b.cols);
989        for i in 0..a.rows {
990            for k in 0..a.cols {
991                let aik = a.data[i * a.cols + k];
992                for j in 0..b.cols {
993                    c.data[i * c.cols + j] += aik * b.data[k * b.cols + j];
994                }
995            }
996        }
997        c
998    }
999
1000    fn transpose_raw(a: &Matrix) -> Matrix {
1001        let mut t = Matrix::zeros(a.cols, a.rows);
1002        for r in 0..a.rows {
1003            for c in 0..a.cols {
1004                t.data[c * t.cols + r] = a.data[r * a.cols + c];
1005            }
1006        }
1007        t
1008    }
1009
1010    fn assert_reconstruction(a: &Matrix, u: &Matrix, s: &[f64], v: &Matrix, tol: f64) {
1011        let k = s.len();
1012        let mut diag_s = Matrix::zeros(k, k);
1013        for (i, &si) in s.iter().enumerate() {
1014            diag_s.data[i * k + i] = si;
1015        }
1016        let us = mat_mul_raw(u, &diag_s);
1017        let recon = mat_mul_raw(&us, &transpose_raw(v));
1018        for r in 0..a.rows {
1019            for c in 0..a.cols {
1020                let diff = (recon.data[r * recon.cols + c] - a.data[r * a.cols + c]).abs();
1021                assert!(
1022                    diff < tol,
1023                    "reconstruction mismatch at ({r},{c}): got {} expected {}, diff {diff}",
1024                    recon.data[r * recon.cols + c],
1025                    a.data[r * a.cols + c]
1026                );
1027            }
1028        }
1029    }
1030
1031    fn assert_orthonormal_columns(m: &Matrix, tol: f64) {
1032        let mtm = mat_mul_raw(&transpose_raw(m), m);
1033        let k = mtm.rows;
1034        for i in 0..k {
1035            for j in 0..k {
1036                let expected = if i == j { 1.0 } else { 0.0 };
1037                let diff = (mtm.data[i * k + j] - expected).abs();
1038                assert!(
1039                    diff < tol,
1040                    "orthonormality violated at ({i},{j}): got {}, expected {expected}",
1041                    mtm.data[i * k + j]
1042                );
1043            }
1044        }
1045    }
1046
1047    fn assert_singular_values_sorted(s: &[f64]) {
1048        for (i, &si) in s.iter().enumerate() {
1049            assert!(si >= -1e-14, "singular value s[{i}] = {si} is negative");
1050        }
1051        for i in 1..s.len() {
1052            assert!(
1053                s[i - 1] >= s[i] - 1e-12,
1054                "not descending: s[{}]={} < s[{}]={}",
1055                i - 1,
1056                s[i - 1],
1057                i,
1058                s[i]
1059            );
1060        }
1061    }
1062
1063    // ---- SVD correctness tests ----
1064
1065    #[test]
1066    fn test_identity_3x3() {
1067        let mut data = vec![0.0; 9];
1068        for i in 0..3 {
1069            data[i * 3 + i] = 1.0;
1070        }
1071        let a = Matrix {
1072            data,
1073            rows: 3,
1074            cols: 3,
1075        };
1076        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1077        for &si in &s {
1078            assert!((si - 1.0).abs() < 1e-10, "expected 1.0, got {si}");
1079        }
1080        assert_reconstruction(&a, &u, &s, &v, 1e-10);
1081        assert_orthonormal_columns(&u, 1e-10);
1082        assert_orthonormal_columns(&v, 1e-10);
1083        assert_singular_values_sorted(&s);
1084    }
1085
1086    #[test]
1087    fn test_diagonal_matrix() {
1088        let a = Matrix {
1089            data: vec![5.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 1.0],
1090            rows: 3,
1091            cols: 3,
1092        };
1093        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1094        assert!((s[0] - 5.0).abs() < 1e-10);
1095        assert!((s[1] - 3.0).abs() < 1e-10);
1096        assert!((s[2] - 1.0).abs() < 1e-10);
1097        assert_reconstruction(&a, &u, &s, &v, 1e-10);
1098        assert_orthonormal_columns(&u, 1e-10);
1099        assert_orthonormal_columns(&v, 1e-10);
1100    }
1101
1102    #[test]
1103    fn test_known_2x2() {
1104        // [[3,2],[2,3]] has singular values 5 and 1
1105        let a = Matrix {
1106            data: vec![3.0, 2.0, 2.0, 3.0],
1107            rows: 2,
1108            cols: 2,
1109        };
1110        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1111        assert!((s[0] - 5.0).abs() < 1e-10, "expected s[0]=5, got {}", s[0]);
1112        assert!((s[1] - 1.0).abs() < 1e-10, "expected s[1]=1, got {}", s[1]);
1113        assert_reconstruction(&a, &u, &s, &v, 1e-10);
1114        assert_orthonormal_columns(&u, 1e-10);
1115        assert_orthonormal_columns(&v, 1e-10);
1116    }
1117
1118    #[test]
1119    fn test_known_3x3() {
1120        let a = Matrix {
1121            data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0],
1122            rows: 3,
1123            cols: 3,
1124        };
1125        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1126        assert_reconstruction(&a, &u, &s, &v, 1e-10);
1127        assert_orthonormal_columns(&u, 1e-10);
1128        assert_orthonormal_columns(&v, 1e-10);
1129        assert_singular_values_sorted(&s);
1130    }
1131
1132    #[test]
1133    fn test_known_4x4() {
1134        let a = Matrix {
1135            data: vec![
1136                2.0, -1.0, 0.0, 0.0, -1.0, 2.0, -1.0, 0.0, 0.0, -1.0, 2.0, -1.0, 0.0, 0.0, -1.0,
1137                2.0,
1138            ],
1139            rows: 4,
1140            cols: 4,
1141        };
1142        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1143        assert_reconstruction(&a, &u, &s, &v, 1e-10);
1144        assert_orthonormal_columns(&u, 1e-10);
1145        assert_orthonormal_columns(&v, 1e-10);
1146        assert_singular_values_sorted(&s);
1147    }
1148
1149    #[test]
1150    fn test_tall_rectangular() {
1151        let a = Matrix {
1152            data: vec![
1153                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
1154            ],
1155            rows: 5,
1156            cols: 3,
1157        };
1158        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1159        assert_eq!(u.rows, 5);
1160        assert_eq!(u.cols, 3);
1161        assert_eq!(s.len(), 3);
1162        assert_eq!(v.rows, 3);
1163        assert_eq!(v.cols, 3);
1164        assert_reconstruction(&a, &u, &s, &v, 1e-10);
1165        assert_orthonormal_columns(&u, 1e-10);
1166        assert_orthonormal_columns(&v, 1e-10);
1167    }
1168
1169    #[test]
1170    fn test_wide_rectangular() {
1171        let a = Matrix {
1172            data: vec![
1173                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
1174            ],
1175            rows: 3,
1176            cols: 5,
1177        };
1178        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1179        assert_eq!(u.rows, 3);
1180        assert_eq!(u.cols, 3);
1181        assert_eq!(s.len(), 3);
1182        assert_eq!(v.rows, 5);
1183        assert_eq!(v.cols, 3);
1184        assert_reconstruction(&a, &u, &s, &v, 1e-10);
1185        assert_orthonormal_columns(&u, 1e-10);
1186        assert_orthonormal_columns(&v, 1e-10);
1187    }
1188
1189    #[test]
1190    fn test_single_element() {
1191        let a = Matrix {
1192            data: vec![7.0],
1193            rows: 1,
1194            cols: 1,
1195        };
1196        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1197        assert!((s[0] - 7.0).abs() < 1e-10);
1198        assert_reconstruction(&a, &u, &s, &v, 1e-10);
1199    }
1200
1201    #[test]
1202    fn test_single_element_negative() {
1203        let a = Matrix {
1204            data: vec![-5.0],
1205            rows: 1,
1206            cols: 1,
1207        };
1208        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1209        assert!((s[0] - 5.0).abs() < 1e-10);
1210        assert_reconstruction(&a, &u, &s, &v, 1e-10);
1211    }
1212
1213    #[test]
1214    fn test_single_row() {
1215        let a = Matrix {
1216            data: vec![1.0, 2.0, 3.0, 4.0],
1217            rows: 1,
1218            cols: 4,
1219        };
1220        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1221        assert_eq!(s.len(), 1);
1222        let expected = (1.0f64 + 4.0 + 9.0 + 16.0).sqrt();
1223        assert!((s[0] - expected).abs() < 1e-10);
1224        assert_reconstruction(&a, &u, &s, &v, 1e-10);
1225    }
1226
1227    #[test]
1228    fn test_single_column() {
1229        let a = Matrix {
1230            data: vec![1.0, 2.0, 3.0, 4.0],
1231            rows: 4,
1232            cols: 1,
1233        };
1234        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1235        assert_eq!(s.len(), 1);
1236        let expected = (1.0f64 + 4.0 + 9.0 + 16.0).sqrt();
1237        assert!((s[0] - expected).abs() < 1e-10);
1238        assert_reconstruction(&a, &u, &s, &v, 1e-10);
1239    }
1240
1241    #[test]
1242    fn test_zero_matrix() {
1243        let a = Matrix::zeros(3, 3);
1244        let (_u, s, _v) = GolubKahanSvd::new().compute(&a).unwrap();
1245        for &si in &s {
1246            assert!(si.abs() < 1e-12);
1247        }
1248        assert_singular_values_sorted(&s);
1249    }
1250
1251    #[test]
1252    fn test_rank_deficient() {
1253        // B8: rank-2 in 3x3 (row 3 = row 1 + row 2)
1254        let a = Matrix {
1255            data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 5.0, 7.0, 9.0],
1256            rows: 3,
1257            cols: 3,
1258        };
1259        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1260        assert!(
1261            s[2] < 1e-10,
1262            "third singular value should be ~0, got {}",
1263            s[2]
1264        );
1265        assert_reconstruction(&a, &u, &s, &v, 1e-10);
1266        assert_orthonormal_columns(&u, 1e-10);
1267        assert_orthonormal_columns(&v, 1e-10);
1268        assert_singular_values_sorted(&s);
1269    }
1270
1271    #[test]
1272    fn test_rank_one() {
1273        // B9: outer product [1,2,3] * [4,5]
1274        let a = Matrix {
1275            data: vec![4.0, 5.0, 8.0, 10.0, 12.0, 15.0],
1276            rows: 3,
1277            cols: 2,
1278        };
1279        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1280        let norm_u = (1.0f64 + 4.0 + 9.0).sqrt();
1281        let norm_v = (16.0f64 + 25.0).sqrt();
1282        let expected_s0 = norm_u * norm_v;
1283        assert!(
1284            (s[0] - expected_s0).abs() < 1e-8,
1285            "expected s[0]={expected_s0}, got {}",
1286            s[0]
1287        );
1288        assert!(s[1] < 1e-10, "expected s[1]~0, got {}", s[1]);
1289        assert_reconstruction(&a, &u, &s, &v, 1e-10);
1290    }
1291
1292    #[test]
1293    fn test_repeated_singular_values() {
1294        // B10: diag(4, 4, 2)
1295        let a = Matrix {
1296            data: vec![4.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 2.0],
1297            rows: 3,
1298            cols: 3,
1299        };
1300        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1301        assert!((s[0] - 4.0).abs() < 1e-10);
1302        assert!((s[1] - 4.0).abs() < 1e-10);
1303        assert!((s[2] - 2.0).abs() < 1e-10);
1304        assert_reconstruction(&a, &u, &s, &v, 1e-10);
1305        assert_orthonormal_columns(&u, 1e-10);
1306        assert_orthonormal_columns(&v, 1e-10);
1307    }
1308
1309    #[test]
1310    fn test_diagonal_with_zeros() {
1311        // B4, B8: diag(5, 0, 3)
1312        let a = Matrix {
1313            data: vec![5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0],
1314            rows: 3,
1315            cols: 3,
1316        };
1317        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1318        assert!((s[0] - 5.0).abs() < 1e-10);
1319        assert!((s[1] - 3.0).abs() < 1e-10);
1320        assert!(s[2] < 1e-10);
1321        assert_reconstruction(&a, &u, &s, &v, 1e-10);
1322        assert_singular_values_sorted(&s);
1323    }
1324
1325    #[test]
1326    fn test_ill_conditioned() {
1327        // B11: condition number > 1e10
1328        let a = Matrix {
1329            data: vec![1.0, 0.0, 0.0, 0.0, 1e-12, 0.0, 0.0, 0.0, 1e-6],
1330            rows: 3,
1331            cols: 3,
1332        };
1333        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1334        assert!((s[0] - 1.0).abs() < 1e-8);
1335        assert_reconstruction(&a, &u, &s, &v, 1e-6);
1336        assert_singular_values_sorted(&s);
1337    }
1338
1339    #[test]
1340    fn test_extreme_small_values() {
1341        // B12: values near underflow
1342        let a = Matrix {
1343            data: vec![1e-300, 0.0, 0.0, 2e-300],
1344            rows: 2,
1345            cols: 2,
1346        };
1347        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1348        assert!(s[0].is_finite());
1349        assert!(s[1].is_finite());
1350        assert_singular_values_sorted(&s);
1351        assert_reconstruction(&a, &u, &s, &v, 1e-290);
1352    }
1353
1354    #[test]
1355    fn test_extreme_large_values() {
1356        // B12: values near overflow
1357        let a = Matrix {
1358            data: vec![1e+150, 0.0, 0.0, 2e+150],
1359            rows: 2,
1360            cols: 2,
1361        };
1362        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1363        assert!(s[0].is_finite());
1364        assert!(s[1].is_finite());
1365        for &val in &u.data {
1366            assert!(val.is_finite());
1367        }
1368        for &val in &v.data {
1369            assert!(val.is_finite());
1370        }
1371        assert_singular_values_sorted(&s);
1372    }
1373
1374    #[test]
1375    fn test_convergence_64x64() {
1376        // B15
1377        use rand::Rng;
1378        use rand::SeedableRng;
1379        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
1380        let data: Vec<f64> = (0..64 * 64).map(|_| rng.gen_range(-1.0..1.0)).collect();
1381        let a = Matrix {
1382            data,
1383            rows: 64,
1384            cols: 64,
1385        };
1386        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1387        assert_reconstruction(&a, &u, &s, &v, 1e-8);
1388        assert_orthonormal_columns(&u, 1e-8);
1389        assert_orthonormal_columns(&v, 1e-8);
1390        assert_singular_values_sorted(&s);
1391    }
1392
1393    #[test]
1394    fn test_convergence_128x128() {
1395        // B15
1396        use rand::Rng;
1397        use rand::SeedableRng;
1398        let mut rng = rand::rngs::StdRng::seed_from_u64(123);
1399        let data: Vec<f64> = (0..128 * 128).map(|_| rng.gen_range(-1.0..1.0)).collect();
1400        let a = Matrix {
1401            data,
1402            rows: 128,
1403            cols: 128,
1404        };
1405        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1406        assert_reconstruction(&a, &u, &s, &v, 1e-8);
1407        assert_orthonormal_columns(&u, 1e-8);
1408        assert_orthonormal_columns(&v, 1e-8);
1409        assert_singular_values_sorted(&s);
1410    }
1411
1412    #[test]
1413    fn test_almost_bidiagonal() {
1414        // B16
1415        let a = Matrix {
1416            data: vec![
1417                5.0, 2.0, 0.0, 0.0, 0.0, 4.0, 1.0, 0.0, 0.0, 0.0, 3.0, 0.5, 0.0, 0.0, 0.0, 1.0,
1418            ],
1419            rows: 4,
1420            cols: 4,
1421        };
1422        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1423        assert_reconstruction(&a, &u, &s, &v, 1e-10);
1424        assert_singular_values_sorted(&s);
1425    }
1426
1427    #[test]
1428    fn test_custom_tolerance() {
1429        // B17
1430        let a = Matrix {
1431            data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0],
1432            rows: 3,
1433            cols: 3,
1434        };
1435        let (u, s, v) = GolubKahanSvd::new()
1436            .with_tolerance(1e-15)
1437            .compute(&a)
1438            .unwrap();
1439        assert_reconstruction(&a, &u, &s, &v, 1e-12);
1440    }
1441
1442    #[test]
1443    fn test_low_max_iter_triggers_error() {
1444        // B18: factor=0 -> max_iter=0, any non-trivial matrix must fail
1445        let a = Matrix {
1446            data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0],
1447            rows: 3,
1448            cols: 3,
1449        };
1450        let result = GolubKahanSvd::new().with_max_iter_factor(0).compute(&a);
1451        assert!(result.is_err(), "expected convergence error with factor=0");
1452        let err = result.unwrap_err();
1453        assert!(matches!(err, SvdError::Convergence { .. }));
1454    }
1455
1456    #[test]
1457    fn test_determinism() {
1458        // B19: same input -> identical output
1459        let a = Matrix {
1460            data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0],
1461            rows: 3,
1462            cols: 3,
1463        };
1464        let svd = GolubKahanSvd::new();
1465        let (u1, s1, v1) = svd.compute(&a).unwrap();
1466        let (u2, s2, v2) = svd.compute(&a).unwrap();
1467        assert_eq!(s1, s2, "singular values differ");
1468        assert_eq!(u1.data, u2.data, "U differs");
1469        assert_eq!(v1.data, v2.data, "V differs");
1470    }
1471}