// pc_rl_core/linalg/golub_kahan.rs
// Author: Julian Bolivar
// Version: 1.0.0
// Date: 2026-04-04

//! Golub-Kahan bidiagonalization SVD algorithm.
//!
//! Provides an O(n^3) SVD implementation as an upgrade path from the
//! existing Jacobi eigendecomposition approach in [`CpuLinAlg`](super::cpu::CpuLinAlg).
//!
//! # Overview
//!
//! The algorithm proceeds in two phases:
//! 1. **Householder bidiagonalization** reduces the input matrix to
//!    upper bidiagonal form.
//! 2. **Implicit QR iteration** with Wilkinson shift extracts singular
//!    values and vectors from the bidiagonal matrix.
//!
//! # Examples
//!
//! ```
//! use pc_rl_core::linalg::golub_kahan::GolubKahanSvd;
//!
//! let svd = GolubKahanSvd::new()
//!     .with_tolerance(1e-12)
//!     .with_max_iter_factor(40);
//! assert!((svd.tol - 1e-12).abs() < f64::EPSILON);
//! assert_eq!(svd.max_iter_factor, 40);
//! ```
use std::fmt;

use crate::error::PcError;
use crate::matrix::Matrix;
/// Error type for SVD decomposition failures.
///
/// Converts to [`PcError::ConfigValidation`] via the [`From`] impl,
/// preserving the error message for upstream callers.
///
/// # Examples
///
/// ```
/// use pc_rl_core::linalg::golub_kahan::SvdError;
///
/// let err = SvdError::Convergence { size: 10, iterations: 300 };
/// assert!(format!("{err}").contains("10"));
/// ```
#[derive(Debug)]
pub enum SvdError {
    /// The iterative solver did not converge within the allowed iterations.
    Convergence {
        /// Matrix dimension (min(rows, cols)) that failed to converge.
        size: usize,
        /// Total iterations attempted before giving up.
        iterations: usize,
    },
    /// The input matrix is invalid (e.g., contains NaN or Inf).
    InvalidInput {
        /// Human-readable description of the problem.
        reason: String,
    },
}
63
64impl fmt::Display for SvdError {
65    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66        match self {
67            SvdError::Convergence { size, iterations } => {
68                write!(
69                    f,
70                    "SVD failed to converge for matrix of size {size} \
71                     after {iterations} iterations"
72                )
73            }
74            SvdError::InvalidInput { reason } => {
75                write!(f, "SVD invalid input: {reason}")
76            }
77        }
78    }
79}
80
// Marker impl: `Display` + `Debug` above satisfy the trait; `source()`
// keeps its default `None` since these errors have no underlying cause.
impl std::error::Error for SvdError {}
82
impl From<SvdError> for PcError {
    /// Flattens an [`SvdError`] into [`PcError::ConfigValidation`], keeping
    /// only the display message so callers handling `PcError` still see
    /// the full diagnostic text.
    fn from(e: SvdError) -> Self {
        PcError::ConfigValidation(e.to_string())
    }
}
88
/// Golub-Kahan bidiagonalization SVD solver.
///
/// Decomposes a matrix `A` into `U * diag(S) * V^T` using Householder
/// bidiagonalization followed by implicit QR iteration with Wilkinson shift.
///
/// # Fields
///
/// * `tol` - Convergence tolerance for off-diagonal elements (default: 1e-14).
/// * `max_iter_factor` - Iteration budget multiplier; `compute` allows
///   `factor * k * k` QR sweeps where `k = min(rows, cols)` (default: 30).
///
/// # Examples
///
/// ```
/// use pc_rl_core::linalg::golub_kahan::GolubKahanSvd;
///
/// let svd = GolubKahanSvd::new();
/// assert!((svd.tol - 1e-14).abs() < f64::EPSILON);
/// assert_eq!(svd.max_iter_factor, 30);
/// ```
#[derive(Debug, Clone)]
pub struct GolubKahanSvd {
    /// Convergence tolerance for off-diagonal elements.
    pub tol: f64,
    /// Iteration budget multiplier: total sweeps = `max_iter_factor * k * k`.
    pub max_iter_factor: usize,
}
115
impl GolubKahanSvd {
    /// Creates a new solver with default parameters.
    ///
    /// Defaults: `tol = 1e-14`, `max_iter_factor = 30`.
    ///
    /// # Examples
    ///
    /// ```
    /// use pc_rl_core::linalg::golub_kahan::GolubKahanSvd;
    ///
    /// let svd = GolubKahanSvd::new();
    /// assert!((svd.tol - 1e-14).abs() < f64::EPSILON);
    /// ```
    pub fn new() -> Self {
        Self {
            tol: 1e-14,
            max_iter_factor: 30,
        }
    }

    /// Sets a custom convergence tolerance.
    ///
    /// # Arguments
    ///
    /// * `tol` - The convergence threshold for off-diagonal elements.
    ///
    /// # Examples
    ///
    /// ```
    /// use pc_rl_core::linalg::golub_kahan::GolubKahanSvd;
    ///
    /// let svd = GolubKahanSvd::new().with_tolerance(1e-8);
    /// assert!((svd.tol - 1e-8).abs() < f64::EPSILON);
    /// ```
    pub fn with_tolerance(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Sets a custom maximum iteration factor.
    ///
    /// [`compute`](Self::compute) allows `factor * k * k` total QR sweeps,
    /// where `k = min(rows, cols)`.
    ///
    /// # Arguments
    ///
    /// * `factor` - Multiplier for the iteration limit.
    ///
    /// # Examples
    ///
    /// ```
    /// use pc_rl_core::linalg::golub_kahan::GolubKahanSvd;
    ///
    /// let svd = GolubKahanSvd::new().with_max_iter_factor(50);
    /// assert_eq!(svd.max_iter_factor, 50);
    /// ```
    pub fn with_max_iter_factor(mut self, factor: usize) -> Self {
        self.max_iter_factor = factor;
        self
    }

    /// Computes the SVD of matrix `a`: `A = U * diag(S) * V^T`.
    ///
    /// Returns `(U, S, V)` where:
    /// - `U` is `(m, k)` with orthonormal columns,
    /// - `S` is a `Vec<f64>` of `k` non-negative singular values in descending order,
    /// - `V` is `(n, k)` with orthonormal columns,
    /// - `k = min(m, n)`.
    ///
    /// # Arguments
    ///
    /// * `a` - The input matrix to decompose.
    ///
    /// # Errors
    ///
    /// Returns [`SvdError::InvalidInput`] if the matrix contains NaN or Inf values.
    /// Returns [`SvdError::Convergence`] if the iterative solver does not converge
    /// within `max_iter_factor * k * k` sweeps.
    ///
    /// # Examples
    ///
    /// ```
    /// use pc_rl_core::linalg::golub_kahan::GolubKahanSvd;
    /// use pc_rl_core::matrix::Matrix;
    ///
    /// let svd = GolubKahanSvd::new();
    /// let empty = Matrix::zeros(0, 0);
    /// let (u, s, v) = svd.compute(&empty).unwrap();
    /// assert_eq!(s.len(), 0);
    /// ```
    pub fn compute(&self, a: &Matrix) -> Result<(Matrix, Vec<f64>, Matrix), SvdError> {
        // Validate up front: NaN/Inf would silently poison every later
        // rotation, so reject the whole matrix immediately.
        for &val in &a.data {
            if val.is_nan() || val.is_infinite() {
                return Err(SvdError::InvalidInput {
                    reason: "matrix contains NaN or Inf".to_string(),
                });
            }
        }

        let m = a.rows;
        let n = a.cols;

        // Degenerate case: no rows or no columns -> zero singular values.
        if m == 0 || n == 0 {
            return Ok((Matrix::zeros(m, 0), Vec::new(), Matrix::zeros(n, 0)));
        }

        // The algorithm below requires a tall (or square) matrix. If wide
        // (m < n), decompose A^T instead and swap U/V in `finalize`.
        let transposed = m < n;
        let (work_m, work_n, work_data) = if transposed {
            // Row-major transpose: element (r, c) of A becomes (c, r).
            let mut t = vec![0.0; m * n];
            for r in 0..m {
                for c in 0..n {
                    t[c * m + r] = a.data[r * n + c];
                }
            }
            (n, m, t)
        } else {
            (m, n, a.data.clone())
        };

        // k = min(work_m, work_n) = work_n since work_m >= work_n here.
        let k = work_n;

        // Single-column shortcut: for a tall vector a, the SVD is
        // u = a/||a||, s = ||a||, v = [1] (with a sign flip to keep s >= 0
        // aligned with a non-negative first component of u).
        if k == 1 {
            let norm: f64 = work_data.iter().map(|&x| x * x).sum::<f64>().sqrt();
            if norm < self.tol {
                // Numerically zero column: return s = 0 with canonical bases.
                let u_mat = make_identity(work_m, 1);
                let v_mat = make_identity(1, 1);
                return Self::finalize(u_mat, vec![0.0], v_mat, transposed);
            }
            let sign = if work_data[0] >= 0.0 { 1.0 } else { -1.0 };
            let mut u_data = vec![0.0; work_m];
            for i in 0..work_m {
                u_data[i] = work_data[i] * sign / norm;
            }
            let u_mat = Matrix {
                data: u_data,
                rows: work_m,
                cols: 1,
            };
            let v_mat = Matrix {
                data: vec![sign],
                rows: 1,
                cols: 1,
            };
            return Self::finalize(u_mat, vec![norm], v_mat, transposed);
        }

        // Phase 1: Householder bidiagonalization.
        // Works on a copy; afterwards `diag`/`superdiag` hold the bidiagonal
        // B with A = U_acc * B * V_acc^T.
        let mut w = work_data;
        let mut u_acc = make_identity(work_m, work_m);
        let mut v_acc = make_identity(work_n, work_n);

        let mut diag = vec![0.0; k];
        let mut superdiag = vec![0.0; k.saturating_sub(1)];

        householder_bidiag(
            &mut w,
            work_m,
            work_n,
            &mut u_acc,
            &mut v_acc,
            &mut diag,
            &mut superdiag,
        );

        // Phase 2: implicit QR iteration on the bidiagonal matrix.
        // Budget is quadratic in k (not linear) to tolerate slow deflation.
        let max_iter = self.max_iter_factor * k * k;
        implicit_qr_svd(
            &mut diag,
            &mut superdiag,
            &mut u_acc,
            &mut v_acc,
            work_m,
            work_n,
            k,
            self.tol,
            max_iter,
        )?;

        // Phase 3: enforce non-negative singular values; a sign flip on s_i
        // is absorbed by flipping column i of U.
        for (i, d) in diag.iter_mut().enumerate().take(k) {
            if *d < 0.0 {
                *d = -*d;
                // Flip corresponding column of U
                for r in 0..work_m {
                    u_acc.data[r * work_m + i] = -u_acc.data[r * work_m + i];
                }
            }
        }

        // Sort by descending singular value (values are finite here, so
        // the partial_cmp fallback to Equal is unreachable in practice).
        let mut indices: Vec<usize> = (0..k).collect();
        indices.sort_by(|&a_idx, &b_idx| {
            diag[b_idx]
                .partial_cmp(&diag[a_idx])
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        let sorted_s: Vec<f64> = indices.iter().map(|&i| diag[i]).collect();

        // Extract thin U (work_m x k), permuting columns into sorted order.
        let mut u_thin = Matrix::zeros(work_m, k);
        for (new_col, &old_col) in indices.iter().enumerate() {
            for r in 0..work_m {
                u_thin.data[r * k + new_col] = u_acc.data[r * work_m + old_col];
            }
        }

        // Extract thin V (work_n x k) with the same column permutation.
        let mut v_thin = Matrix::zeros(work_n, k);
        for (new_col, &old_col) in indices.iter().enumerate() {
            for r in 0..work_n {
                v_thin.data[r * k + new_col] = v_acc.data[r * work_n + old_col];
            }
        }

        Self::finalize(u_thin, sorted_s, v_thin, transposed)
    }

    /// Swaps U and V if the input was transposed (wide matrix), returning
    /// the final `(U, S, V)` triple in the caller's original orientation.
    fn finalize(
        u: Matrix,
        s: Vec<f64>,
        v: Matrix,
        transposed: bool,
    ) -> Result<(Matrix, Vec<f64>, Matrix), SvdError> {
        if transposed {
            Ok((v, s, u))
        } else {
            Ok((u, s, v))
        }
    }
}
353
354impl Default for GolubKahanSvd {
355    fn default() -> Self {
356        Self::new()
357    }
358}
359
360/// Creates an identity-like matrix of size `rows x cols`.
361///
362/// Places 1.0 on the diagonal up to `min(rows, cols)`.
363fn make_identity(rows: usize, cols: usize) -> Matrix {
364    let mut data = vec![0.0; rows * cols];
365    let k = rows.min(cols);
366    for i in 0..k {
367        data[i * cols + i] = 1.0;
368    }
369    Matrix { data, rows, cols }
370}
371
/// Householder bidiagonalization: reduces an m x n matrix (m >= n)
/// to upper bidiagonal form, accumulating transformations into U and V.
///
/// After completion, `diag` contains the diagonal and `superdiag` the
/// superdiagonal of the bidiagonal matrix B such that A = U * B * V^T.
///
/// # Arguments
///
/// * `w` - Row-major m x n matrix data, overwritten during computation.
/// * `m` - Number of rows.
/// * `n` - Number of columns.
/// * `u_acc` - m x m identity matrix, accumulates left reflections.
/// * `v_acc` - n x n identity matrix, accumulates right reflections.
/// * `diag` - Output diagonal of bidiagonal B (length n).
/// * `superdiag` - Output superdiagonal of bidiagonal B (length n-1).
fn householder_bidiag(
    w: &mut [f64],
    m: usize,
    n: usize,
    u_acc: &mut Matrix,
    v_acc: &mut Matrix,
    diag: &mut [f64],
    superdiag: &mut [f64],
) {
    for j in 0..n {
        // Left Householder: zero w[j+1..m, j] (below the diagonal of column j).
        {
            // Gather the trailing part of column j into a contiguous buffer.
            let mut col = vec![0.0; m - j];
            for i in j..m {
                col[i - j] = w[i * n + j];
            }
            let (v_house, beta) = householder_vector(&col);
            // beta == 0.0 means the column is already in the desired form.
            if beta != 0.0 {
                // Apply H = I - beta * v * v^T to w[j..m, j..n]
                apply_householder_left(w, m, n, j, j, &v_house, beta);
                // Accumulate into U: U = U * H (H is symmetric, so applying
                // it on the right of u_acc accumulates the left reflection).
                apply_householder_right_to_matrix(u_acc, m, m, j, &v_house, beta);
            }
        }
        diag[j] = w[j * n + j];

        // Right Householder: zero w[j, j+2..n] (beyond the superdiagonal of
        // row j). For j == n-2 the row segment has length 1 and
        // householder_vector returns beta == 0, so nothing is applied.
        if j + 2 <= n {
            let start = j + 1;
            let mut row = vec![0.0; n - start];
            for c in start..n {
                row[c - start] = w[j * n + c];
            }
            let (v_house, beta) = householder_vector(&row);
            if beta != 0.0 {
                // Apply H = I - beta * v * v^T to w[j..m, (j+1)..n] on the right
                apply_householder_right(w, m, n, j, start, &v_house, beta);
                // Accumulate into V: V = V * H
                apply_householder_right_to_matrix(v_acc, n, n, start, &v_house, beta);
            }
            if j < n - 1 {
                superdiag[j] = w[j * n + j + 1];
            }
        } else if j < n - 1 {
            // No right reflection needed, but still record the superdiagonal.
            superdiag[j] = w[j * n + j + 1];
        }
    }
}
435
/// Computes a Householder reflector `(v, beta)` with `v[0] = 1.0` such that
/// `(I - beta * v * v^T) * x` annihilates every component of `x` after the
/// first (Golub & Van Loan, Algorithm 5.1.1).
///
/// Degenerate inputs (empty, single element, or tail already ~zero) return
/// `beta = 0.0`, signalling "no reflection needed" to callers.
///
/// # Arguments
///
/// * `x` - Input vector to reflect.
fn householder_vector(x: &[f64]) -> (Vec<f64>, f64) {
    // Trivial lengths: nothing to zero out.
    match x.len() {
        0 => return (Vec::new(), 0.0),
        1 => return (vec![1.0], 0.0),
        _ => {}
    }

    // sigma = ||x[1..]||^2, the energy that must be reflected away.
    let sigma: f64 = x[1..].iter().map(|&xi| xi * xi).sum();

    let mut v = x.to_vec();
    v[0] = 1.0;

    // Tail is effectively zero: x is already in reflected form.
    if sigma < 1e-300 {
        return (v, 0.0);
    }

    let norm_x = (x[0] * x[0] + sigma).sqrt();
    // Pick the sign of v[0] that avoids catastrophic cancellation:
    // for positive x[0] use the algebraically equivalent -sigma/(x0 + ||x||).
    v[0] = if x[0] <= 0.0 {
        x[0] - norm_x
    } else {
        -sigma / (x[0] + norm_x)
    };

    let beta = 2.0 * v[0] * v[0] / (sigma + v[0] * v[0]);

    // Rescale so the pivot component is exactly 1.
    let pivot = v[0];
    for component in v.iter_mut() {
        *component /= pivot;
    }

    (v, beta)
}
483
/// Applies a left Householder reflection in place:
/// `W_sub = (I - beta * v * v^T) * W_sub`, where `W_sub` is the submatrix
/// `w[row_start.., col_start..]` stored row-major with row stride `n`.
fn apply_householder_left(
    w: &mut [f64],
    _m: usize,
    n: usize,
    row_start: usize,
    col_start: usize,
    v: &[f64],
    beta: f64,
) {
    let num_cols = n - col_start;

    // p = v^T * W_sub (one dot product per column, accumulated row by row).
    let mut p = vec![0.0; num_cols];
    for (offset, &vi) in v.iter().enumerate() {
        let base = (row_start + offset) * n + col_start;
        for (c, pc) in p.iter_mut().enumerate() {
            *pc += vi * w[base + c];
        }
    }

    // Rank-1 update: W_sub -= beta * v * p^T.
    for (offset, &vi) in v.iter().enumerate() {
        let base = (row_start + offset) * n + col_start;
        for (c, &pc) in p.iter().enumerate() {
            w[base + c] -= beta * vi * pc;
        }
    }
}
513
/// Applies a right Householder reflection in place:
/// `W_sub = W_sub * (I - beta * v * v^T)`, where `W_sub` is the submatrix
/// `w[row_start..m, col_start..]` stored row-major with row stride `n`.
fn apply_householder_right(
    w: &mut [f64],
    m: usize,
    n: usize,
    row_start: usize,
    col_start: usize,
    v: &[f64],
    beta: f64,
) {
    // Each row is updated independently: row -= beta * (row . v) * v.
    for row in row_start..m {
        let base = row * n + col_start;
        let dot: f64 = v
            .iter()
            .enumerate()
            .map(|(idx, &vi)| w[base + idx] * vi)
            .sum();
        let scale = beta * dot;
        for (idx, &vi) in v.iter().enumerate() {
            w[base + idx] -= scale * vi;
        }
    }
}
539
540/// Applies a Householder reflection on the right to a full accumulator matrix:
541/// `acc = acc * H` where H operates on rows `start..start+v.len()`.
542///
543/// Equivalent to: for each row of acc, update columns start..start+v.len().
544fn apply_householder_right_to_matrix(
545    acc: &mut Matrix,
546    rows: usize,
547    cols: usize,
548    col_start: usize,
549    v: &[f64],
550    beta: f64,
551) {
552    let v_len = v.len();
553    for r in 0..rows {
554        let mut dot = 0.0;
555        for (vi_idx, &vi) in v.iter().enumerate().take(v_len) {
556            dot += acc.data[r * cols + col_start + vi_idx] * vi;
557        }
558        for (vi_idx, &vi) in v.iter().enumerate().take(v_len) {
559            acc.data[r * cols + col_start + vi_idx] -= beta * dot * vi;
560        }
561    }
562}
563
/// Computes a Givens rotation `(c, s)` such that
/// `[c s; -s c]^T * [a; b] = [r; 0]`.
///
/// Uses the ratio formulation (dividing the smaller magnitude by the
/// larger) to avoid overflow in `a^2 + b^2`.
///
/// # Arguments
///
/// * `a` - First element.
/// * `b` - Second element (to be zeroed).
///
/// # Returns
///
/// `(c, s)` cosine and sine of the rotation angle.
fn givens_rotation(a: f64, b: f64) -> (f64, f64) {
    if b.abs() < 1e-300 {
        // b is already (effectively) zero: identity rotation.
        (1.0, 0.0)
    } else if a.abs() < 1e-300 {
        // Pure swap, signed so that the resulting r = |b| is non-negative.
        (0.0, b.signum())
    } else if b.abs() > a.abs() {
        let tau = a / b;
        let s = b.signum() / (1.0 + tau * tau).sqrt();
        (s * tau, s)
    } else {
        let tau = b / a;
        let c = a.signum() / (1.0 + tau * tau).sqrt();
        (c, c * tau)
    }
}
594
595/// Applies a Givens rotation to columns `i` and `j` of a matrix,
596/// operating on all rows: `[col_i, col_j] = [col_i, col_j] * G`.
597fn apply_givens_cols(
598    mat: &mut Matrix,
599    rows: usize,
600    stride: usize,
601    i: usize,
602    j: usize,
603    c: f64,
604    s: f64,
605) {
606    for r in 0..rows {
607        let a = mat.data[r * stride + i];
608        let b = mat.data[r * stride + j];
609        mat.data[r * stride + i] = c * a + s * b;
610        mat.data[r * stride + j] = -s * a + c * b;
611    }
612}
613
/// Implicit QR iteration on a bidiagonal matrix to extract singular values.
///
/// Operates on the diagonal `diag` and superdiagonal `superdiag`, applying
/// Givens rotations accumulated into `u_acc` and `v_acc`. On success,
/// `diag` holds the (possibly signed, unsorted) singular values.
///
/// # Errors
///
/// Returns [`SvdError::Convergence`] if the iteration does not converge
/// within `max_iter` total sweeps.
#[allow(clippy::too_many_arguments)]
fn implicit_qr_svd(
    diag: &mut [f64],
    superdiag: &mut [f64],
    u_acc: &mut Matrix,
    v_acc: &mut Matrix,
    u_rows: usize,
    v_rows: usize,
    k: usize,
    tol: f64,
    max_iter: usize,
) -> Result<(), SvdError> {
    // A 1x1 bidiagonal matrix is already diagonal.
    if k <= 1 {
        return Ok(());
    }

    let mut iter_count = 0usize;

    loop {
        // Deflation scan from the bottom: find the largest q such that the
        // trailing q superdiagonal entries are negligible (relative to the
        // neighboring diagonal magnitudes, with an absolute floor).
        let mut q = 0usize;
        while q < k - 1 {
            let idx = k - 2 - q;
            let thresh = tol * (diag[idx].abs() + diag[idx + 1].abs());
            if superdiag[idx].abs() <= thresh.max(tol * 1e-2) {
                superdiag[idx] = 0.0;
                q += 1;
            } else {
                break;
            }
        }

        if q >= k - 1 {
            // All superdiagonal elements are zero — converged
            break;
        }

        // Active block is diag[p..block_end], superdiag[p..block_end-1].
        let block_end = k - q; // exclusive index into diag
        // Find p: walk upward until a negligible superdiag entry (or the
        // top) bounds the unreduced block. Zeroing the entry here also
        // guarantees the next deflation scan makes progress.
        let mut p = block_end - 1;
        while p > 0 {
            let idx = p - 1;
            let thresh = tol * (diag[idx].abs() + diag[idx + 1].abs());
            if superdiag[idx].abs() <= thresh.max(tol * 1e-2) {
                superdiag[idx] = 0.0;
                break;
            }
            p -= 1;
        }

        let block_size = block_end - p;
        if block_size <= 1 {
            // Block degenerated to a single entry; rescan (the superdiag
            // zeroed above will be absorbed by the next deflation pass).
            continue;
        }

        // A (numerically) zero diagonal entry inside the block makes the
        // Wilkinson-shift step unstable; instead rotate the adjacent
        // superdiagonal entry away and restart the scan.
        let mut found_zero_diag = false;
        for i in p..block_end {
            if diag[i].abs() < tol * 1e-2 {
                // Zero out superdiag element adjacent to this zero diagonal
                if i < block_end - 1 && superdiag[i].abs() > 0.0 {
                    zero_superdiag_row(diag, superdiag, u_acc, u_rows, i, block_end);
                } else if i > p && superdiag[i - 1].abs() > 0.0 {
                    zero_superdiag_col(diag, superdiag, v_acc, v_rows, i, p);
                }
                found_zero_diag = true;
                break;
            }
        }
        if found_zero_diag {
            // Counts against the iteration budget so pathological inputs
            // cannot loop forever in the zero-diagonal path.
            iter_count += 1;
            if iter_count > max_iter {
                return Err(SvdError::Convergence {
                    size: k,
                    iterations: max_iter,
                });
            }
            continue;
        }

        // Wilkinson shift from the trailing 2x2 of T = B^T * B:
        // [ d_n2^2 + e_{n2-1}^2    d_n2 * e_n2        ]
        // [ d_n2 * e_n2            d_n1^2 + e_n2^2    ]
        let n1 = block_end - 1;
        let n2 = block_end - 2;
        let d_n1 = diag[n1];
        let d_n2 = diag[n2];
        let e_n2 = superdiag[n2];
        // e_{n2-1}^2 exists only if n2 is not the top of the block.
        let e_n3_sq = if n2 > p {
            superdiag[n2 - 1] * superdiag[n2 - 1]
        } else {
            0.0
        };
        let t11 = d_n2 * d_n2 + e_n3_sq;
        let t12 = d_n2 * e_n2;
        let t22 = d_n1 * d_n1 + e_n2 * e_n2;

        let shift = wilkinson_shift(t11, t12, t22);

        // Golub-Kahan SVD step (bulge chase) on the active block.
        golub_kahan_step(
            diag, superdiag, u_acc, v_acc, u_rows, v_rows, p, block_end, shift,
        );

        iter_count += 1;
        if iter_count > max_iter {
            return Err(SvdError::Convergence {
                size: k,
                iterations: max_iter,
            });
        }
    }

    Ok(())
}
742
/// Computes the Wilkinson shift for the symmetric 2x2 matrix `[[a, b], [b, d]]`:
/// the eigenvalue closest to `d`, via the numerically stable form
/// `d - b^2 / (delta + sign(delta) * sqrt(delta^2 + b^2))`.
fn wilkinson_shift(a: f64, b: f64, d: f64) -> f64 {
    let delta = (a - d) * 0.5;
    // Effectively diagonal with equal entries: d itself is the eigenvalue.
    if delta.abs() < 1e-300 && b.abs() < 1e-300 {
        return d;
    }
    let root = (delta * delta + b * b).sqrt();
    // Same-sign denominator avoids cancellation between delta and the root.
    let denom = if delta >= 0.0 { delta + root } else { delta - root };
    d - b * b / denom
}
755
/// Performs one implicit QR step (bulge chase) on the bidiagonal matrix.
///
/// This is the core of the Golub-Kahan SVD iteration (Algorithm 8.6.2
/// from Golub & Van Loan): an initial right rotation derived from the
/// shifted first column of B^T*B introduces a bulge, which alternating
/// right (V) and left (U) Givens rotations chase down the band until it
/// falls off the end of the block.
#[allow(clippy::too_many_arguments)]
fn golub_kahan_step(
    diag: &mut [f64],
    superdiag: &mut [f64],
    u_acc: &mut Matrix,
    v_acc: &mut Matrix,
    u_rows: usize,
    v_rows: usize,
    p: usize,
    block_end: usize,
    shift: f64,
) {
    // (y, z): the 2-vector the next rotation must align; seeded from the
    // first column of B^T*B - shift*I.
    let mut y = diag[p] * diag[p] - shift;
    let mut z = diag[p] * superdiag[p];

    for i in p..block_end - 1 {
        // Right Givens rotation to zero z (applied to V)
        let (c, s) = givens_rotation(y, z);
        if i > p {
            // Fold the bulge back into the superdiagonal above.
            superdiag[i - 1] = c * superdiag[i - 1] + s * z;
            // Note: the component that was z is now zero
        }
        let old_d_i = diag[i];
        let old_e_i = superdiag[i];
        diag[i] = c * old_d_i + s * old_e_i;
        superdiag[i] = -s * old_d_i + c * old_e_i;
        let old_d_i1 = diag[i + 1];
        // The rotation smears diag[i+1] into a new sub-diagonal bulge z.
        z = s * old_d_i1;
        diag[i + 1] = c * old_d_i1;

        // Accumulate into V
        apply_givens_cols(v_acc, v_rows, v_rows, i, i + 1, c, s);

        // Left Givens rotation to zero z (applied to U)
        let (c, s) = givens_rotation(diag[i], z);
        diag[i] = c * diag[i] + s * z;
        let old_e_i = superdiag[i];
        let old_d_i1 = diag[i + 1];
        superdiag[i] = c * old_e_i + s * old_d_i1;
        diag[i + 1] = -s * old_e_i + c * old_d_i1;
        if i + 1 < block_end - 1 {
            // The left rotation pushes the bulge one position further right.
            let old_e_i1 = superdiag[i + 1];
            z = s * old_e_i1;
            superdiag[i + 1] = c * old_e_i1;
        }
        y = superdiag[i];

        // Accumulate into U
        apply_givens_cols(u_acc, u_rows, u_rows, i, i + 1, c, s);
    }
}
811
/// Zeros a superdiagonal element when a diagonal entry is zero,
/// chasing the bulge rightward via left rotations.
///
/// When `diag[zero_idx] ≈ 0`, the left rotation `G^T * B` between rows
/// `zero_idx` and `j+1` zeros the fill-in but creates a new bulge from
/// `superdiag[j+1]`. The bulge is tracked explicitly in a local variable
/// and passed forward to the next iteration rather than stored back into
/// `superdiag`.
fn zero_superdiag_row(
    diag: &mut [f64],
    superdiag: &mut [f64],
    u_acc: &mut Matrix,
    u_rows: usize,
    zero_idx: usize,
    block_end: usize,
) {
    // Lift the offending entry out of the band; it becomes the bulge.
    let mut bulge = superdiag[zero_idx];
    superdiag[zero_idx] = 0.0;

    for j in zero_idx..block_end - 1 {
        // Rotate rows (j+1, zero_idx) so the bulge is absorbed into diag[j+1].
        let (c, s) = givens_rotation(diag[j + 1], bulge);
        diag[j + 1] = c * diag[j + 1] + s * bulge;
        // bulge position is now zero by construction
        if j + 1 < block_end - 1 {
            let old_e = superdiag[j + 1];
            superdiag[j + 1] = c * old_e;
            // The rotation spawns a new bulge further right in row zero_idx.
            bulge = -s * old_e;
            apply_givens_cols(u_acc, u_rows, u_rows, j + 1, zero_idx, c, s);
            if bulge.abs() < 1e-300 {
                // Bulge died out early; no more fill-in to chase.
                break;
            }
        } else {
            // Last column: absorb and finish.
            apply_givens_cols(u_acc, u_rows, u_rows, j + 1, zero_idx, c, s);
        }
    }
}
847
/// Zeros a superdiagonal element when a diagonal entry is zero,
/// chasing the bulge leftward via right rotations.
///
/// When `diag[zero_idx] ≈ 0`, the right rotation `B * G` between columns
/// `zero_idx` and `j` zeros the fill-in but creates a new bulge from
/// `superdiag[j-1]`. The bulge is tracked explicitly in a local variable
/// and passed backward to the next iteration rather than stored back into
/// `superdiag`.
fn zero_superdiag_col(
    diag: &mut [f64],
    superdiag: &mut [f64],
    v_acc: &mut Matrix,
    v_rows: usize,
    zero_idx: usize,
    block_start: usize,
) {
    // Lift the entry above the zero diagonal out of the band as the bulge.
    let mut bulge = superdiag[zero_idx - 1];
    superdiag[zero_idx - 1] = 0.0;

    // Walk upward from zero_idx toward the top of the block.
    for j in (block_start..zero_idx).rev() {
        // Rotate columns (j, zero_idx) so the bulge is absorbed into diag[j].
        let (c, s) = givens_rotation(diag[j], bulge);
        diag[j] = c * diag[j] + s * bulge;
        // bulge position is now zero by construction
        if j > block_start {
            let old_e = superdiag[j - 1];
            superdiag[j - 1] = c * old_e;
            // The rotation spawns a new bulge further up in column zero_idx.
            bulge = -s * old_e;
            apply_givens_cols(v_acc, v_rows, v_rows, j, zero_idx, c, s);
            if bulge.abs() < 1e-300 {
                // Bulge died out early; no more fill-in to chase.
                break;
            }
        } else {
            // Top of the block: absorb and finish.
            apply_givens_cols(v_acc, v_rows, v_rows, j, zero_idx, c, s);
        }
    }
}
883
884#[cfg(test)]
885mod tests {
886    use super::*;
887
888    #[test]
889    fn test_new_returns_default_parameters() {
890        let svd = GolubKahanSvd::new();
891        assert!((svd.tol - 1e-14).abs() < f64::EPSILON);
892        assert_eq!(svd.max_iter_factor, 30);
893    }
894
895    #[test]
896    fn test_with_tolerance_sets_custom_tol() {
897        let svd = GolubKahanSvd::new().with_tolerance(1e-8);
898        assert!((svd.tol - 1e-8).abs() < f64::EPSILON);
899    }
900
901    #[test]
902    fn test_with_max_iter_factor_sets_custom_factor() {
903        let svd = GolubKahanSvd::new().with_max_iter_factor(50);
904        assert_eq!(svd.max_iter_factor, 50);
905    }
906
907    #[test]
908    fn test_default_trait_matches_new() {
909        let a = GolubKahanSvd::new();
910        let b = GolubKahanSvd::default();
911        assert!((a.tol - b.tol).abs() < f64::EPSILON);
912        assert_eq!(a.max_iter_factor, b.max_iter_factor);
913    }
914
915    #[test]
916    fn test_svd_error_display() {
917        let err = SvdError::Convergence {
918            size: 10,
919            iterations: 300,
920        };
921        let msg = format!("{err}");
922        assert!(msg.contains("10"));
923        assert!(msg.contains("300"));
924    }
925
926    #[test]
927    fn test_svd_error_converts_to_pc_error() {
928        let err = SvdError::Convergence {
929            size: 5,
930            iterations: 150,
931        };
932        let pc_err: crate::error::PcError = err.into();
933        assert!(matches!(pc_err, crate::error::PcError::ConfigValidation(_)));
934    }
935
936    #[test]
937    fn test_empty_matrix() {
938        // B6: 0x0 matrix returns empty results
939        let a = Matrix::zeros(0, 0);
940        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
941        assert_eq!(u.rows, 0);
942        assert_eq!(u.cols, 0);
943        assert!(s.is_empty());
944        assert_eq!(v.rows, 0);
945        assert_eq!(v.cols, 0);
946    }
947
948    #[test]
949    fn test_nan_input_returns_error() {
950        // B13: NaN input returns Err
951        let a = Matrix {
952            data: vec![1.0, f64::NAN, 3.0, 4.0],
953            rows: 2,
954            cols: 2,
955        };
956        let result = GolubKahanSvd::new().compute(&a);
957        assert!(result.is_err());
958        let err = result.unwrap_err();
959        assert!(matches!(err, SvdError::InvalidInput { .. }));
960    }
961
962    #[test]
963    fn test_inf_input_returns_error() {
964        // B14: Inf input returns Err
965        let a = Matrix {
966            data: vec![1.0, f64::INFINITY, 3.0, 4.0],
967            rows: 2,
968            cols: 2,
969        };
970        let result = GolubKahanSvd::new().compute(&a);
971        assert!(result.is_err());
972        let err = result.unwrap_err();
973        assert!(matches!(err, SvdError::InvalidInput { .. }));
974    }
975
976    #[test]
977    fn test_neg_inf_input_returns_error() {
978        let a = Matrix {
979            data: vec![f64::NEG_INFINITY, 2.0, 3.0, 4.0],
980            rows: 2,
981            cols: 2,
982        };
983        let result = GolubKahanSvd::new().compute(&a);
984        assert!(result.is_err());
985    }
986
987    // ---- Helper functions for SVD correctness tests ----
988
989    /// Helper: multiply matrices using raw data.
990    fn mat_mul_raw(a: &Matrix, b: &Matrix) -> Matrix {
991        assert_eq!(a.cols, b.rows);
992        let mut c = Matrix::zeros(a.rows, b.cols);
993        for i in 0..a.rows {
994            for k in 0..a.cols {
995                let aik = a.data[i * a.cols + k];
996                for j in 0..b.cols {
997                    c.data[i * c.cols + j] += aik * b.data[k * b.cols + j];
998                }
999            }
1000        }
1001        c
1002    }
1003
1004    fn transpose_raw(a: &Matrix) -> Matrix {
1005        let mut t = Matrix::zeros(a.cols, a.rows);
1006        for r in 0..a.rows {
1007            for c in 0..a.cols {
1008                t.data[c * t.cols + r] = a.data[r * a.cols + c];
1009            }
1010        }
1011        t
1012    }
1013
    /// Asserts that `u * diag(s) * v^T` reproduces `a` entrywise within `tol`.
    ///
    /// Expects the thin-SVD shapes used throughout these tests: `u` is
    /// (a.rows x k) and `v` is (a.cols x k) with k = s.len().
    fn assert_reconstruction(a: &Matrix, u: &Matrix, s: &[f64], v: &Matrix, tol: f64) {
        let k = s.len();
        // Expand the singular values into an explicit k x k diagonal matrix.
        let mut diag_s = Matrix::zeros(k, k);
        for (i, &si) in s.iter().enumerate() {
            diag_s.data[i * k + i] = si;
        }
        // recon = (U * S) * V^T
        let us = mat_mul_raw(u, &diag_s);
        let recon = mat_mul_raw(&us, &transpose_raw(v));
        // Entrywise absolute-difference check against the original matrix.
        for r in 0..a.rows {
            for c in 0..a.cols {
                let diff = (recon.data[r * recon.cols + c] - a.data[r * a.cols + c]).abs();
                assert!(
                    diff < tol,
                    "reconstruction mismatch at ({r},{c}): got {} expected {}, diff {diff}",
                    recon.data[r * recon.cols + c],
                    a.data[r * a.cols + c]
                );
            }
        }
    }
1034
1035    fn assert_orthonormal_columns(m: &Matrix, tol: f64) {
1036        let mtm = mat_mul_raw(&transpose_raw(m), m);
1037        let k = mtm.rows;
1038        for i in 0..k {
1039            for j in 0..k {
1040                let expected = if i == j { 1.0 } else { 0.0 };
1041                let diff = (mtm.data[i * k + j] - expected).abs();
1042                assert!(
1043                    diff < tol,
1044                    "orthonormality violated at ({i},{j}): got {}, expected {expected}",
1045                    mtm.data[i * k + j]
1046                );
1047            }
1048        }
1049    }
1050
1051    fn assert_singular_values_sorted(s: &[f64]) {
1052        for (i, &si) in s.iter().enumerate() {
1053            assert!(si >= -1e-14, "singular value s[{i}] = {si} is negative");
1054        }
1055        for i in 1..s.len() {
1056            assert!(
1057                s[i - 1] >= s[i] - 1e-12,
1058                "not descending: s[{}]={} < s[{}]={}",
1059                i - 1,
1060                s[i - 1],
1061                i,
1062                s[i]
1063            );
1064        }
1065    }
1066
1067    // ---- SVD correctness tests ----
1068
1069    #[test]
1070    fn test_identity_3x3() {
1071        let mut data = vec![0.0; 9];
1072        for i in 0..3 {
1073            data[i * 3 + i] = 1.0;
1074        }
1075        let a = Matrix {
1076            data,
1077            rows: 3,
1078            cols: 3,
1079        };
1080        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1081        for &si in &s {
1082            assert!((si - 1.0).abs() < 1e-10, "expected 1.0, got {si}");
1083        }
1084        assert_reconstruction(&a, &u, &s, &v, 1e-10);
1085        assert_orthonormal_columns(&u, 1e-10);
1086        assert_orthonormal_columns(&v, 1e-10);
1087        assert_singular_values_sorted(&s);
1088    }
1089
1090    #[test]
1091    fn test_diagonal_matrix() {
1092        let a = Matrix {
1093            data: vec![5.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 1.0],
1094            rows: 3,
1095            cols: 3,
1096        };
1097        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1098        assert!((s[0] - 5.0).abs() < 1e-10);
1099        assert!((s[1] - 3.0).abs() < 1e-10);
1100        assert!((s[2] - 1.0).abs() < 1e-10);
1101        assert_reconstruction(&a, &u, &s, &v, 1e-10);
1102        assert_orthonormal_columns(&u, 1e-10);
1103        assert_orthonormal_columns(&v, 1e-10);
1104    }
1105
1106    #[test]
1107    fn test_known_2x2() {
1108        // [[3,2],[2,3]] has singular values 5 and 1
1109        let a = Matrix {
1110            data: vec![3.0, 2.0, 2.0, 3.0],
1111            rows: 2,
1112            cols: 2,
1113        };
1114        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1115        assert!((s[0] - 5.0).abs() < 1e-10, "expected s[0]=5, got {}", s[0]);
1116        assert!((s[1] - 1.0).abs() < 1e-10, "expected s[1]=1, got {}", s[1]);
1117        assert_reconstruction(&a, &u, &s, &v, 1e-10);
1118        assert_orthonormal_columns(&u, 1e-10);
1119        assert_orthonormal_columns(&v, 1e-10);
1120    }
1121
1122    #[test]
1123    fn test_known_3x3() {
1124        let a = Matrix {
1125            data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0],
1126            rows: 3,
1127            cols: 3,
1128        };
1129        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1130        assert_reconstruction(&a, &u, &s, &v, 1e-10);
1131        assert_orthonormal_columns(&u, 1e-10);
1132        assert_orthonormal_columns(&v, 1e-10);
1133        assert_singular_values_sorted(&s);
1134    }
1135
1136    #[test]
1137    fn test_known_4x4() {
1138        let a = Matrix {
1139            data: vec![
1140                2.0, -1.0, 0.0, 0.0, -1.0, 2.0, -1.0, 0.0, 0.0, -1.0, 2.0, -1.0, 0.0, 0.0, -1.0,
1141                2.0,
1142            ],
1143            rows: 4,
1144            cols: 4,
1145        };
1146        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1147        assert_reconstruction(&a, &u, &s, &v, 1e-10);
1148        assert_orthonormal_columns(&u, 1e-10);
1149        assert_orthonormal_columns(&v, 1e-10);
1150        assert_singular_values_sorted(&s);
1151    }
1152
1153    #[test]
1154    fn test_tall_rectangular() {
1155        let a = Matrix {
1156            data: vec![
1157                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
1158            ],
1159            rows: 5,
1160            cols: 3,
1161        };
1162        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1163        assert_eq!(u.rows, 5);
1164        assert_eq!(u.cols, 3);
1165        assert_eq!(s.len(), 3);
1166        assert_eq!(v.rows, 3);
1167        assert_eq!(v.cols, 3);
1168        assert_reconstruction(&a, &u, &s, &v, 1e-10);
1169        assert_orthonormal_columns(&u, 1e-10);
1170        assert_orthonormal_columns(&v, 1e-10);
1171    }
1172
1173    #[test]
1174    fn test_wide_rectangular() {
1175        let a = Matrix {
1176            data: vec![
1177                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
1178            ],
1179            rows: 3,
1180            cols: 5,
1181        };
1182        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1183        assert_eq!(u.rows, 3);
1184        assert_eq!(u.cols, 3);
1185        assert_eq!(s.len(), 3);
1186        assert_eq!(v.rows, 5);
1187        assert_eq!(v.cols, 3);
1188        assert_reconstruction(&a, &u, &s, &v, 1e-10);
1189        assert_orthonormal_columns(&u, 1e-10);
1190        assert_orthonormal_columns(&v, 1e-10);
1191    }
1192
1193    #[test]
1194    fn test_single_element() {
1195        let a = Matrix {
1196            data: vec![7.0],
1197            rows: 1,
1198            cols: 1,
1199        };
1200        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1201        assert!((s[0] - 7.0).abs() < 1e-10);
1202        assert_reconstruction(&a, &u, &s, &v, 1e-10);
1203    }
1204
1205    #[test]
1206    fn test_single_element_negative() {
1207        let a = Matrix {
1208            data: vec![-5.0],
1209            rows: 1,
1210            cols: 1,
1211        };
1212        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1213        assert!((s[0] - 5.0).abs() < 1e-10);
1214        assert_reconstruction(&a, &u, &s, &v, 1e-10);
1215    }
1216
1217    #[test]
1218    fn test_single_row() {
1219        let a = Matrix {
1220            data: vec![1.0, 2.0, 3.0, 4.0],
1221            rows: 1,
1222            cols: 4,
1223        };
1224        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1225        assert_eq!(s.len(), 1);
1226        let expected = (1.0f64 + 4.0 + 9.0 + 16.0).sqrt();
1227        assert!((s[0] - expected).abs() < 1e-10);
1228        assert_reconstruction(&a, &u, &s, &v, 1e-10);
1229    }
1230
1231    #[test]
1232    fn test_single_column() {
1233        let a = Matrix {
1234            data: vec![1.0, 2.0, 3.0, 4.0],
1235            rows: 4,
1236            cols: 1,
1237        };
1238        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1239        assert_eq!(s.len(), 1);
1240        let expected = (1.0f64 + 4.0 + 9.0 + 16.0).sqrt();
1241        assert!((s[0] - expected).abs() < 1e-10);
1242        assert_reconstruction(&a, &u, &s, &v, 1e-10);
1243    }
1244
1245    #[test]
1246    fn test_zero_matrix() {
1247        let a = Matrix::zeros(3, 3);
1248        let (_u, s, _v) = GolubKahanSvd::new().compute(&a).unwrap();
1249        for &si in &s {
1250            assert!(si.abs() < 1e-12);
1251        }
1252        assert_singular_values_sorted(&s);
1253    }
1254
1255    #[test]
1256    fn test_rank_deficient() {
1257        // B8: rank-2 in 3x3 (row 3 = row 1 + row 2)
1258        let a = Matrix {
1259            data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 5.0, 7.0, 9.0],
1260            rows: 3,
1261            cols: 3,
1262        };
1263        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1264        assert!(
1265            s[2] < 1e-10,
1266            "third singular value should be ~0, got {}",
1267            s[2]
1268        );
1269        assert_reconstruction(&a, &u, &s, &v, 1e-10);
1270        assert_orthonormal_columns(&u, 1e-10);
1271        assert_orthonormal_columns(&v, 1e-10);
1272        assert_singular_values_sorted(&s);
1273    }
1274
1275    #[test]
1276    fn test_rank_one() {
1277        // B9: outer product [1,2,3] * [4,5]
1278        let a = Matrix {
1279            data: vec![4.0, 5.0, 8.0, 10.0, 12.0, 15.0],
1280            rows: 3,
1281            cols: 2,
1282        };
1283        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1284        let norm_u = (1.0f64 + 4.0 + 9.0).sqrt();
1285        let norm_v = (16.0f64 + 25.0).sqrt();
1286        let expected_s0 = norm_u * norm_v;
1287        assert!(
1288            (s[0] - expected_s0).abs() < 1e-8,
1289            "expected s[0]={expected_s0}, got {}",
1290            s[0]
1291        );
1292        assert!(s[1] < 1e-10, "expected s[1]~0, got {}", s[1]);
1293        assert_reconstruction(&a, &u, &s, &v, 1e-10);
1294    }
1295
1296    #[test]
1297    fn test_repeated_singular_values() {
1298        // B10: diag(4, 4, 2)
1299        let a = Matrix {
1300            data: vec![4.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 2.0],
1301            rows: 3,
1302            cols: 3,
1303        };
1304        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1305        assert!((s[0] - 4.0).abs() < 1e-10);
1306        assert!((s[1] - 4.0).abs() < 1e-10);
1307        assert!((s[2] - 2.0).abs() < 1e-10);
1308        assert_reconstruction(&a, &u, &s, &v, 1e-10);
1309        assert_orthonormal_columns(&u, 1e-10);
1310        assert_orthonormal_columns(&v, 1e-10);
1311    }
1312
1313    #[test]
1314    fn test_diagonal_with_zeros() {
1315        // B4, B8: diag(5, 0, 3)
1316        let a = Matrix {
1317            data: vec![5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0],
1318            rows: 3,
1319            cols: 3,
1320        };
1321        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1322        assert!((s[0] - 5.0).abs() < 1e-10);
1323        assert!((s[1] - 3.0).abs() < 1e-10);
1324        assert!(s[2] < 1e-10);
1325        assert_reconstruction(&a, &u, &s, &v, 1e-10);
1326        assert_singular_values_sorted(&s);
1327    }
1328
1329    #[test]
1330    fn test_ill_conditioned() {
1331        // B11: condition number > 1e10
1332        let a = Matrix {
1333            data: vec![1.0, 0.0, 0.0, 0.0, 1e-12, 0.0, 0.0, 0.0, 1e-6],
1334            rows: 3,
1335            cols: 3,
1336        };
1337        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1338        assert!((s[0] - 1.0).abs() < 1e-8);
1339        assert_reconstruction(&a, &u, &s, &v, 1e-6);
1340        assert_singular_values_sorted(&s);
1341    }
1342
1343    #[test]
1344    fn test_extreme_small_values() {
1345        // B12: values near underflow
1346        let a = Matrix {
1347            data: vec![1e-300, 0.0, 0.0, 2e-300],
1348            rows: 2,
1349            cols: 2,
1350        };
1351        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1352        assert!(s[0].is_finite());
1353        assert!(s[1].is_finite());
1354        assert_singular_values_sorted(&s);
1355        assert_reconstruction(&a, &u, &s, &v, 1e-290);
1356    }
1357
1358    #[test]
1359    fn test_extreme_large_values() {
1360        // B12: values near overflow
1361        let a = Matrix {
1362            data: vec![1e+150, 0.0, 0.0, 2e+150],
1363            rows: 2,
1364            cols: 2,
1365        };
1366        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1367        assert!(s[0].is_finite());
1368        assert!(s[1].is_finite());
1369        for &val in &u.data {
1370            assert!(val.is_finite());
1371        }
1372        for &val in &v.data {
1373            assert!(val.is_finite());
1374        }
1375        assert_singular_values_sorted(&s);
1376    }
1377
1378    #[test]
1379    fn test_convergence_64x64() {
1380        // B15
1381        use rand::Rng;
1382        use rand::SeedableRng;
1383        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
1384        let data: Vec<f64> = (0..64 * 64).map(|_| rng.gen_range(-1.0..1.0)).collect();
1385        let a = Matrix {
1386            data,
1387            rows: 64,
1388            cols: 64,
1389        };
1390        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1391        assert_reconstruction(&a, &u, &s, &v, 1e-8);
1392        assert_orthonormal_columns(&u, 1e-8);
1393        assert_orthonormal_columns(&v, 1e-8);
1394        assert_singular_values_sorted(&s);
1395    }
1396
1397    #[test]
1398    fn test_convergence_128x128() {
1399        // B15
1400        use rand::Rng;
1401        use rand::SeedableRng;
1402        let mut rng = rand::rngs::StdRng::seed_from_u64(123);
1403        let data: Vec<f64> = (0..128 * 128).map(|_| rng.gen_range(-1.0..1.0)).collect();
1404        let a = Matrix {
1405            data,
1406            rows: 128,
1407            cols: 128,
1408        };
1409        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1410        assert_reconstruction(&a, &u, &s, &v, 1e-8);
1411        assert_orthonormal_columns(&u, 1e-8);
1412        assert_orthonormal_columns(&v, 1e-8);
1413        assert_singular_values_sorted(&s);
1414    }
1415
1416    #[test]
1417    fn test_almost_bidiagonal() {
1418        // B16
1419        let a = Matrix {
1420            data: vec![
1421                5.0, 2.0, 0.0, 0.0, 0.0, 4.0, 1.0, 0.0, 0.0, 0.0, 3.0, 0.5, 0.0, 0.0, 0.0, 1.0,
1422            ],
1423            rows: 4,
1424            cols: 4,
1425        };
1426        let (u, s, v) = GolubKahanSvd::new().compute(&a).unwrap();
1427        assert_reconstruction(&a, &u, &s, &v, 1e-10);
1428        assert_singular_values_sorted(&s);
1429    }
1430
1431    #[test]
1432    fn test_custom_tolerance() {
1433        // B17
1434        let a = Matrix {
1435            data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0],
1436            rows: 3,
1437            cols: 3,
1438        };
1439        let (u, s, v) = GolubKahanSvd::new()
1440            .with_tolerance(1e-15)
1441            .compute(&a)
1442            .unwrap();
1443        assert_reconstruction(&a, &u, &s, &v, 1e-12);
1444    }
1445
1446    #[test]
1447    fn test_low_max_iter_triggers_error() {
1448        // B18: factor=0 -> max_iter=0, any non-trivial matrix must fail
1449        let a = Matrix {
1450            data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0],
1451            rows: 3,
1452            cols: 3,
1453        };
1454        let result = GolubKahanSvd::new().with_max_iter_factor(0).compute(&a);
1455        assert!(result.is_err(), "expected convergence error with factor=0");
1456        let err = result.unwrap_err();
1457        assert!(matches!(err, SvdError::Convergence { .. }));
1458    }
1459
1460    #[test]
1461    fn test_determinism() {
1462        // B19: same input -> identical output
1463        let a = Matrix {
1464            data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0],
1465            rows: 3,
1466            cols: 3,
1467        };
1468        let svd = GolubKahanSvd::new();
1469        let (u1, s1, v1) = svd.compute(&a).unwrap();
1470        let (u2, s2, v2) = svd.compute(&a).unwrap();
1471        assert_eq!(s1, s2, "singular values differ");
1472        assert_eq!(u1.data, u2.data, "U differs");
1473        assert_eq!(v1.data, v2.data, "V differs");
1474    }
1475}