// oxicuda_solver/dense/svd.rs
1//! Singular Value Decomposition (SVD).
2//!
3//! Computes `A = U * Σ * V^T` where:
4//! - U is an m x m (or m x k for thin SVD) orthogonal matrix
5//! - Σ is a diagonal matrix of singular values in descending order
6//! - V^T is an n x n (or k x n for thin SVD) orthogonal matrix
7//!
8//! Two algorithmic paths are implemented:
9//! - **Small matrices** (m, n <= 32): One-sided Jacobi SVD with parallel rotations,
10//!   executed entirely in a single CTA using shared memory.
11//! - **Large matrices**: Golub-Kahan bidiagonalization followed by implicit-shift
12//!   QR iteration on the bidiagonal matrix.
13//!
14//! The bidiagonalization step uses blocked Householder reflections (reusing the
15//! same infrastructure as the QR module), and the QR iteration operates on the
16//! small bidiagonal representation on the host side.
17
18#![allow(dead_code)]
19
20use std::sync::Arc;
21
22use oxicuda_blas::GpuFloat;
23use oxicuda_driver::Module;
24use oxicuda_launch::{Kernel, LaunchParams};
25use oxicuda_memory::DeviceBuffer;
26use oxicuda_ptx::ir::PtxType;
27use oxicuda_ptx::prelude::*;
28
29use crate::error::{SolverError, SolverResult};
30use crate::handle::SolverHandle;
31use crate::ptx_helpers::SOLVER_BLOCK_SIZE;
32
33/// Converts an `f64` value to `T: GpuFloat` via bit reinterpretation.
34fn from_f64_to_t<T: GpuFloat>(val: f64) -> T {
35    if T::SIZE == 4 {
36        T::from_bits_u64(u64::from((val as f32).to_bits()))
37    } else {
38        T::from_bits_u64(val.to_bits())
39    }
40}
41
42/// Converts a `T: GpuFloat` value to `f64` via bit reinterpretation.
43///
44/// For 8-byte types (f64), reinterprets bits directly.
45/// For all other types (f32, f16, bf16, FP8), first reinterprets the raw bits
46/// as f32 and then widens to f64.  This is a host-side fallback used when a
47/// GPU kernel is unavailable (e.g. on macOS).
48fn t_to_f64<T: GpuFloat>(val: T) -> f64 {
49    if T::SIZE == 8 {
50        f64::from_bits(val.to_bits_u64())
51    } else {
52        f64::from(f32::from_bits(val.to_bits_u64() as u32))
53    }
54}
55
/// Threshold below which the single-CTA Jacobi SVD path is used (both `m`
/// and `n` must be <= this value; see [`svd`]).
const JACOBI_SVD_THRESHOLD: u32 = 32;

/// Maximum number of Jacobi sweeps before declaring convergence failure.
const JACOBI_MAX_SWEEPS: u32 = 100;

/// Convergence tolerance for Jacobi sweeps (relative to the Frobenius norm).
/// Also reused as the deflation tolerance in the bidiagonal QR iteration.
const JACOBI_TOL: f64 = 1e-14;

/// Maximum iterations of the implicit-shift QR loop in `bidiagonal_svd_qr`.
const BIDIAG_QR_MAX_ITER: u32 = 200;
67
68// ---------------------------------------------------------------------------
69// Public types
70// ---------------------------------------------------------------------------
71
/// Controls which parts of the SVD to compute.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SvdJob {
    /// Compute full U (m x m) and V^T (n x n): all left and right singular vectors.
    All,
    /// Compute thin (economy-size) U and V^T: only the first min(m, n)
    /// columns of U and rows of V^T.
    Thin,
    /// Compute singular values only; `u` and `vt` in [`SvdResult`] are `None`.
    SingularValuesOnly,
}
82
/// Result of an SVD computation.
///
/// The singular values are always in descending order.
#[derive(Debug, Clone)]
pub struct SvdResult<T: GpuFloat> {
    /// Singular values in descending order (length = min(m, n)).
    pub singular_values: Vec<T>,
    /// Left singular vectors (column-major, m x k or m x m depending on [`SvdJob`]).
    /// `None` if `SvdJob::SingularValuesOnly` was requested.
    pub u: Option<Vec<T>>,
    /// Right singular vectors transposed (column-major, k x n or n x n depending on
    /// [`SvdJob`]). `None` if `SvdJob::SingularValuesOnly` was requested.
    pub vt: Option<Vec<T>>,
    /// Diagnostic info: 0 on success, positive if the algorithm did not converge.
    /// (Convergence failure on the bidiagonal path is additionally reported
    /// as [`SolverError::ConvergenceFailure`].)
    pub info: i32,
}
99
100// ---------------------------------------------------------------------------
101// Public API
102// ---------------------------------------------------------------------------
103
104/// Computes the SVD of an m x n matrix A.
105///
106/// The matrix `a` is stored in column-major order with leading dimension `lda`.
107/// On return, `a` is destroyed (overwritten with intermediate data).
108///
109/// # Arguments
110///
111/// * `handle` — solver handle providing BLAS, stream, PTX cache.
112/// * `a` — input matrix buffer (m x n, column-major), destroyed on output.
113/// * `m` — number of rows.
114/// * `n` — number of columns.
115/// * `lda` — leading dimension (>= m).
116/// * `job` — controls which parts of the SVD to compute.
117///
118/// # Returns
119///
120/// An [`SvdResult`] containing the singular values and optionally U and V^T.
121///
122/// # Errors
123///
124/// Returns [`SolverError::DimensionMismatch`] for invalid dimensions.
125/// Returns [`SolverError::ConvergenceFailure`] if the iterative algorithm does
126/// not converge within the allowed number of iterations.
127pub fn svd<T: GpuFloat>(
128    handle: &mut SolverHandle,
129    a: &mut DeviceBuffer<T>,
130    m: u32,
131    n: u32,
132    lda: u32,
133    job: SvdJob,
134) -> SolverResult<SvdResult<T>> {
135    // Validate dimensions.
136    if m == 0 || n == 0 {
137        return Ok(SvdResult {
138            singular_values: Vec::new(),
139            u: if job == SvdJob::SingularValuesOnly {
140                None
141            } else {
142                Some(Vec::new())
143            },
144            vt: if job == SvdJob::SingularValuesOnly {
145                None
146            } else {
147                Some(Vec::new())
148            },
149            info: 0,
150        });
151    }
152    if lda < m {
153        return Err(SolverError::DimensionMismatch(format!(
154            "svd: lda ({lda}) must be >= m ({m})"
155        )));
156    }
157    let required = n as usize * lda as usize;
158    if a.len() < required {
159        return Err(SolverError::DimensionMismatch(format!(
160            "svd: buffer too small ({} < {required})",
161            a.len()
162        )));
163    }
164
165    // Choose algorithm based on matrix size.
166    if m <= JACOBI_SVD_THRESHOLD && n <= JACOBI_SVD_THRESHOLD {
167        jacobi_svd(handle, a, m, n, lda, job)
168    } else {
169        bidiag_svd(handle, a, m, n, lda, job)
170    }
171}
172
173// ---------------------------------------------------------------------------
174// Jacobi SVD (for small matrices)
175// ---------------------------------------------------------------------------
176
/// One-sided Jacobi SVD for small matrices (m, n <= `JACOBI_SVD_THRESHOLD`).
///
/// Uses parallel Jacobi rotations applied via a GPU kernel in shared memory.
/// For each sweep, all off-diagonal element pairs are driven to zero by
/// computing 2x2 SVD rotations and applying them to the columns.
///
/// The algorithm converges when the sum of squares of off-diagonal elements
/// is below `JACOBI_TOL * ||A||_F^2`.
fn jacobi_svd<T: GpuFloat>(
    handle: &mut SolverHandle,
    a: &mut DeviceBuffer<T>,
    m: u32,
    n: u32,
    lda: u32,
    job: SvdJob,
) -> SolverResult<SvdResult<T>> {
    // Number of singular values.
    let k = m.min(n);

    // Workspace: need space for the V matrix (n x n) and convergence flag.
    // NOTE(review): the workspace is reserved but never passed to the kernel
    // launch below — confirm whether the kernel is supposed to receive it.
    let v_size = n as usize * n as usize * T::SIZE;
    let ws_needed = v_size + T::SIZE; // V matrix + convergence scalar
    handle.ensure_workspace(ws_needed)?;

    // Generate and launch the Jacobi SVD kernel, specialized for this element
    // type and (m, n) shape.
    let sm = handle.sm_version();
    let ptx = emit_jacobi_svd::<T>(sm, m, n)?;
    let module = Arc::new(Module::from_ptx(&ptx)?);
    let kernel = Kernel::from_module(module, &jacobi_svd_name::<T>(m, n))?;

    // Single CTA; the kernel uses shared memory for the m x n matrix and the
    // n x n V matrix.
    let shared_bytes = (m * n + n * n) * T::size_u32();
    let params = LaunchParams::new(1u32, SOLVER_BLOCK_SIZE).with_shared_mem(shared_bytes);

    let args = (a.as_device_ptr(), lda, m, n, JACOBI_MAX_SWEEPS);
    kernel.launch(&params, handle.stream(), &args)?;

    // Extract results from the device buffer.
    // After the Jacobi SVD kernel, the singular values are the column norms of A,
    // and V contains the accumulated right rotations.
    let singular_values = extract_singular_values::<T>(a, m, n, lda, k)?;
    let (u_out, vt_out) = match job {
        SvdJob::SingularValuesOnly => (None, None),
        SvdJob::Thin => {
            let u_vec = extract_u_thin::<T>(a, m, n, lda, k)?;
            // NOTE(review): V^T is a zero-filled placeholder — the accumulated
            // V matrix is never read back from the device here.
            let vt_vec = vec![T::gpu_zero(); k as usize * n as usize];
            (Some(u_vec), Some(vt_vec))
        }
        SvdJob::All => {
            let u_vec = extract_u_full::<T>(a, m, lda, k)?;
            // NOTE(review): zero-filled placeholder, as above.
            let vt_vec = vec![T::gpu_zero(); n as usize * n as usize];
            (Some(u_vec), Some(vt_vec))
        }
    };

    // `info` is unconditionally 0; the kernel's convergence flag is not
    // consulted here.
    Ok(SvdResult {
        singular_values,
        u: u_out,
        vt: vt_out,
        info: 0,
    })
}
238
/// Extracts singular values from the column norms of the post-Jacobi matrix.
///
/// After one-sided Jacobi SVD, each column `j` of the modified A has norm
/// equal to `sigma_j`. Copies the device buffer to host and computes the
/// Euclidean norm of each of the first `k` columns, accumulating in f64.
///
/// NOTE(review): the norms are returned in column order and are not sorted
/// here, although [`SvdResult`] documents descending order — confirm that the
/// kernel orders the columns (or add a sort that also permutes U).
fn extract_singular_values<T: GpuFloat>(
    a: &DeviceBuffer<T>,
    m: u32,
    n: u32,
    lda: u32,
    k: u32,
) -> SolverResult<Vec<T>> {
    // Full (lda x n) copy; only the leading m rows of each column are read.
    let total = lda as usize * n as usize;
    let mut host = vec![T::gpu_zero(); total];
    a.copy_to_host(&mut host).map_err(|e| {
        SolverError::InternalError(format!("extract_singular_values copy_to_host failed: {e}"))
    })?;

    let mut result = Vec::with_capacity(k as usize);
    for j in 0..k as usize {
        // Column j of A in column-major order: A[0..m, j] = host[j*lda .. j*lda + m]
        let col_start = j * lda as usize;
        let sum_sq: f64 = (0..m as usize)
            .map(|i| {
                let v = t_to_f64(host[col_start + i]);
                v * v
            })
            .sum();
        result.push(from_f64_to_t(sum_sq.sqrt()));
    }
    Ok(result)
}
270
271/// Extracts thin U (m x k) from the post-Jacobi matrix columns.
272///
273/// Copies the device buffer to host, then normalizes each of the first `k`
274/// columns by its Euclidean norm to produce the left singular vectors.
275fn extract_u_thin<T: GpuFloat>(
276    a: &DeviceBuffer<T>,
277    m: u32,
278    n: u32,
279    lda: u32,
280    k: u32,
281) -> SolverResult<Vec<T>> {
282    let total = lda as usize * n as usize;
283    let mut host = vec![T::gpu_zero(); total];
284    a.copy_to_host(&mut host).map_err(|e| {
285        SolverError::InternalError(format!("extract_u_thin copy_to_host failed: {e}"))
286    })?;
287
288    let m_usize = m as usize;
289    let k_usize = k as usize;
290    let lda_usize = lda as usize;
291
292    // U_thin (column-major, m x k): U[:, j] = A[:, j] / ||A[:, j]||
293    let mut u_vec = vec![T::gpu_zero(); m_usize * k_usize];
294    for j in 0..k_usize {
295        let col_start = j * lda_usize;
296        // Compute column norm.
297        let sum_sq: f64 = (0..m_usize)
298            .map(|i| {
299                let v = t_to_f64(host[col_start + i]);
300                v * v
301            })
302            .sum();
303        let norm = sum_sq.sqrt();
304        let inv_norm = if norm > 1e-300 { 1.0 / norm } else { 0.0 };
305
306        for i in 0..m_usize {
307            let val = t_to_f64(host[col_start + i]) * inv_norm;
308            u_vec[j * m_usize + i] = from_f64_to_t(val);
309        }
310    }
311    Ok(u_vec)
312}
313
314/// Extracts full U (m x m) from the post-Jacobi matrix.
315///
316/// Copies the device buffer to host, normalizes the first `k` columns,
317/// and fills the remaining `m - k` columns with the standard basis vectors
318/// orthogonal to the existing columns (identity extension).
319fn extract_u_full<T: GpuFloat>(
320    a: &DeviceBuffer<T>,
321    m: u32,
322    lda: u32,
323    k: u32,
324) -> SolverResult<Vec<T>> {
325    let n = k; // for the full U the buffer was m x k before extension
326    let total = lda as usize * n as usize;
327    let mut host = vec![T::gpu_zero(); total];
328    a.copy_to_host(&mut host).map_err(|e| {
329        SolverError::InternalError(format!("extract_u_full copy_to_host failed: {e}"))
330    })?;
331
332    let m_usize = m as usize;
333    let k_usize = k as usize;
334    let lda_usize = lda as usize;
335
336    // Start with zeros for the m x m output (column-major).
337    let mut u_vec = vec![T::gpu_zero(); m_usize * m_usize];
338
339    // Normalize the first k columns from the host buffer.
340    for j in 0..k_usize {
341        let col_start = j * lda_usize;
342        let sum_sq: f64 = (0..m_usize)
343            .map(|i| {
344                let v = t_to_f64(host[col_start + i]);
345                v * v
346            })
347            .sum();
348        let norm = sum_sq.sqrt();
349        let inv_norm = if norm > 1e-300 { 1.0 / norm } else { 0.0 };
350
351        for i in 0..m_usize {
352            let val = t_to_f64(host[col_start + i]) * inv_norm;
353            u_vec[j * m_usize + i] = from_f64_to_t(val);
354        }
355    }
356
357    // Extend with identity columns for columns k..m.
358    // This is a simple (non-Gram-Schmidt) extension: place 1 on the diagonal.
359    // In a full GPU implementation these would be generated via QR of a random matrix.
360    for j in k_usize..m_usize {
361        u_vec[j * m_usize + j] = T::gpu_one();
362    }
363
364    Ok(u_vec)
365}
366
367// ---------------------------------------------------------------------------
368// Golub-Kahan bidiagonalization + QR iteration (for large matrices)
369// ---------------------------------------------------------------------------
370
/// SVD via Golub-Kahan bidiagonalization and implicit-shift QR iteration.
///
/// Steps:
/// 1. Reduce A to upper bidiagonal form B using blocked Householder reflections.
/// 2. Apply implicit-shift QR iteration to B to compute singular values.
/// 3. Optionally reconstruct U and V from the Householder vectors and the
///    accumulated rotations.
fn bidiag_svd<T: GpuFloat>(
    handle: &mut SolverHandle,
    a: &mut DeviceBuffer<T>,
    m: u32,
    n: u32,
    lda: u32,
    job: SvdJob,
) -> SolverResult<SvdResult<T>> {
    // Number of singular values.
    let k = m.min(n);

    // Workspace accounting for Householder scalars and bidiagonal elements.
    // NOTE(review): the workspace is sized/reserved here, but tauq/taup below
    // are allocated as fresh device buffers and d/e live on the host —
    // confirm whether this reservation is still required.
    let tauq_size = k as usize * T::SIZE;
    let taup_size = k as usize * T::SIZE;
    let diag_size = k as usize * std::mem::size_of::<f64>();
    let super_diag_size = k.saturating_sub(1) as usize * std::mem::size_of::<f64>();
    let ws_needed = tauq_size + taup_size + diag_size + super_diag_size;
    handle.ensure_workspace(ws_needed)?;

    // Step 1: Bidiagonalize A -> B (A is overwritten with Householder vectors).
    let mut tauq = DeviceBuffer::<T>::zeroed(k as usize)?;
    let mut taup = DeviceBuffer::<T>::zeroed(k as usize)?;
    bidiagonalize(handle, a, m, n, lda, &mut tauq, &mut taup)?;

    // Step 2: Pull the diagonal d and superdiagonal e back to the host (f64).
    let mut d = vec![0.0_f64; k as usize];
    let mut e = vec![0.0_f64; k.saturating_sub(1) as usize];
    extract_bidiagonal::<T>(a, m, n, lda, &mut d, &mut e)?;

    // Step 3: Host-side QR iteration on the small bidiagonal matrix.
    // Rotation accumulators are only allocated when vectors are requested.
    let mut u_bidiag = if job != SvdJob::SingularValuesOnly {
        Some(vec![0.0_f64; k as usize * k as usize])
    } else {
        None
    };
    let mut vt_bidiag = if job != SvdJob::SingularValuesOnly {
        Some(vec![0.0_f64; k as usize * k as usize])
    } else {
        None
    };

    let converged = bidiagonal_svd_qr(
        &mut d,
        &mut e,
        u_bidiag.as_deref_mut(),
        vt_bidiag.as_deref_mut(),
        k,
    )?;

    if !converged {
        // Report the residual off-diagonal norm for diagnostics.
        return Err(SolverError::ConvergenceFailure {
            iterations: BIDIAG_QR_MAX_ITER,
            residual: e.iter().map(|v| v * v).sum::<f64>().sqrt(),
        });
    }

    // The QR iteration may leave signed diagonal entries; singular values are
    // their magnitudes, converted back to T.
    // NOTE(review): d is not sorted here, although `SvdResult` documents a
    // descending order — confirm where the sort is expected to happen.
    let singular_values: Vec<T> = d.iter().map(|&val| from_f64_to_t(val.abs())).collect();

    // Step 4: Reconstruct U and V^T if requested. The reconstruct_* helpers
    // below are currently placeholders that return zero matrices.
    let (u_out, vt_out) = match job {
        SvdJob::SingularValuesOnly => (None, None),
        SvdJob::Thin => {
            let u_vec =
                reconstruct_u_thin::<T>(handle, a, m, n, lda, &tauq, u_bidiag.as_deref(), k)?;
            let vt_vec =
                reconstruct_vt_thin::<T>(handle, a, m, n, lda, &taup, vt_bidiag.as_deref(), k)?;
            (Some(u_vec), Some(vt_vec))
        }
        SvdJob::All => {
            let u_vec =
                reconstruct_u_full::<T>(handle, a, m, n, lda, &tauq, u_bidiag.as_deref(), k)?;
            let vt_vec =
                reconstruct_vt_full::<T>(handle, a, m, n, lda, &taup, vt_bidiag.as_deref(), k)?;
            (Some(u_vec), Some(vt_vec))
        }
    };

    Ok(SvdResult {
        singular_values,
        u: u_out,
        vt: vt_out,
        info: 0,
    })
}
462
463/// Reduces A to upper bidiagonal form using blocked Householder reflections.
464///
465/// On exit, A is overwritten with the Householder vectors for both the left
466/// (column) and right (row) reflections. The scalars are stored in `tauq`
467/// (left reflections) and `taup` (right reflections).
468///
469/// The bidiagonal form is: B = Q^T * A * P, where Q and P are orthogonal.
470fn bidiagonalize<T: GpuFloat>(
471    handle: &SolverHandle,
472    a: &mut DeviceBuffer<T>,
473    m: u32,
474    n: u32,
475    lda: u32,
476    tauq: &mut DeviceBuffer<T>,
477    taup: &mut DeviceBuffer<T>,
478) -> SolverResult<()> {
479    let k = m.min(n);
480
481    // Process one column/row pair at a time.
482    // For each step i = 0..k:
483    //   1. Compute Householder reflection to zero out A[i+1:m, i] (left reflector).
484    //   2. Apply left reflector to trailing columns.
485    //   3. Compute Householder reflection to zero out A[i, i+2:n] (right reflector).
486    //   4. Apply right reflector to trailing rows.
487    //
488    // The blocked version groups multiple steps and uses the compact WY
489    // representation for efficient BLAS-3 updates.
490    let sm = handle.sm_version();
491    let ptx = emit_bidiag_step::<T>(sm)?;
492    let module = Arc::new(Module::from_ptx(&ptx)?);
493    let kernel = Kernel::from_module(module, &bidiag_step_name::<T>())?;
494
495    for i in 0..k {
496        let rows_below = m - i;
497        let cols_right = n.saturating_sub(i + 1);
498
499        let shared_bytes = (rows_below + cols_right) * T::size_u32();
500        let params = LaunchParams::new(1u32, SOLVER_BLOCK_SIZE).with_shared_mem(shared_bytes);
501
502        let a_offset = (i as u64 + i as u64 * lda as u64) * T::SIZE as u64;
503        let tauq_offset = i as u64 * T::SIZE as u64;
504        let taup_offset = i as u64 * T::SIZE as u64;
505
506        let args = (
507            a.as_device_ptr() + a_offset,
508            tauq.as_device_ptr() + tauq_offset,
509            taup.as_device_ptr() + taup_offset,
510            rows_below,
511            cols_right,
512            lda,
513        );
514        kernel.launch(&params, handle.stream(), &args)?;
515    }
516
517    Ok(())
518}
519
520/// Extracts diagonal (d) and superdiagonal (e) from the bidiagonalized matrix.
521///
522/// After bidiagonalization, the diagonal elements are `A[i, i]` and the
523/// superdiagonal elements are `A[i, i+1]` for i = 0..k-1 (column-major storage).
524fn extract_bidiagonal<T: GpuFloat>(
525    a: &DeviceBuffer<T>,
526    m: u32,
527    n: u32,
528    lda: u32,
529    d: &mut [f64],
530    e: &mut [f64],
531) -> SolverResult<()> {
532    let k = m.min(n) as usize;
533    let total = lda as usize * n as usize;
534    let mut host = vec![T::gpu_zero(); total];
535    a.copy_to_host(&mut host).map_err(|e_err| {
536        SolverError::InternalError(format!("extract_bidiagonal copy_to_host failed: {e_err}"))
537    })?;
538
539    let lda_usize = lda as usize;
540
541    // Diagonal: d[i] = A[i, i] (column-major: host[i * lda + i])
542    for i in 0..k {
543        d[i] = t_to_f64(host[i * lda_usize + i]);
544    }
545
546    // Superdiagonal: e[i] = A[i, i+1] (column-major: host[(i+1) * lda + i])
547    for i in 0..k.saturating_sub(1) {
548        e[i] = t_to_f64(host[(i + 1) * lda_usize + i]);
549    }
550
551    Ok(())
552}
553
554/// Implicit-shift QR iteration on a bidiagonal matrix.
555///
556/// Drives the superdiagonal elements to zero, leaving the singular values
557/// on the diagonal. Optionally accumulates the left and right rotations
558/// into the U and V^T matrices.
559///
560/// Returns `true` if the algorithm converged, `false` otherwise.
561fn bidiagonal_svd_qr(
562    d: &mut [f64],
563    e: &mut [f64],
564    u: Option<&mut [f64]>,
565    vt: Option<&mut [f64]>,
566    k: u32,
567) -> SolverResult<bool> {
568    let n = k as usize;
569    if n == 0 {
570        return Ok(true);
571    }
572
573    // Initialize U and V^T as identity matrices if provided.
574    if let Some(ref u_mat) = u {
575        for i in 0..n {
576            let _ = u_mat[i * n + i]; // Touch to verify bounds (structural).
577        }
578    }
579    if let Some(ref vt_mat) = vt {
580        for i in 0..n {
581            let _ = vt_mat[i * n + i]; // Touch to verify bounds (structural).
582        }
583    }
584
585    // Initialize identity matrices.
586    if let Some(u_mat) = u {
587        for val in u_mat.iter_mut() {
588            *val = 0.0;
589        }
590        for i in 0..n {
591            u_mat[i * n + i] = 1.0;
592        }
593    }
594    if let Some(vt_mat) = vt {
595        for val in vt_mat.iter_mut() {
596            *val = 0.0;
597        }
598        for i in 0..n {
599            vt_mat[i * n + i] = 1.0;
600        }
601    }
602
603    // Implicit-shift QR iteration on the bidiagonal matrix.
604    // Each step targets the smallest unconverged superdiagonal element.
605    let tol = JACOBI_TOL;
606
607    for _iter in 0..BIDIAG_QR_MAX_ITER {
608        // Find the active block: the largest subrange where e[i] != 0.
609        let mut q = n.saturating_sub(1);
610        while q > 0 && e[q - 1].abs() <= tol * (d[q - 1].abs() + d[q].abs()) {
611            e[q - 1] = 0.0;
612            q -= 1;
613        }
614        if q == 0 {
615            // All superdiagonal elements are zero — converged.
616            return Ok(true);
617        }
618
619        // Find the start of the active block.
620        let mut p = q - 1;
621        while p > 0 && e[p - 1].abs() > tol * (d[p - 1].abs() + d[p].abs()) {
622            p -= 1;
623        }
624
625        // Apply one implicit QR step to the active block d[p..=q], e[p..q].
626        bidiagonal_qr_step(d, e, p, q);
627    }
628
629    // Check convergence.
630    let off_norm: f64 = e.iter().map(|v| v * v).sum::<f64>().sqrt();
631    Ok(off_norm <= tol)
632}
633
/// One step of the implicit-shift QR iteration on a bidiagonal matrix.
///
/// Operates on the active block `d[start..=end]`, `e[start..end]`. Uses the
/// Golub-Kahan shift strategy: the shift `mu` is the eigenvalue of the
/// trailing 2x2 submatrix of B^T * B closest to its last diagonal entry; a
/// "bulge" is then chased down the band with alternating right/left Givens
/// rotations.
fn bidiagonal_qr_step(d: &mut [f64], e: &mut [f64], start: usize, end: usize) {
    // Trailing 2x2 of T = B^T * B:  [t11 t12; t12 t22].
    let dm1 = d[end - 1];
    let dm = d[end];
    let em1 = e[end - 1];

    // t11 picks up the superdiagonal entry above it when one exists.
    let t11 = dm1 * dm1
        + if end >= 2 {
            e[end - 2] * e[end - 2]
        } else {
            0.0
        };
    let t12 = dm1 * em1;
    let t22 = dm * dm + em1 * em1;

    // Wilkinson shift: eigenvalue of [[t11, t12], [t12, t22]] closest to t22,
    // in the cancellation-resistant form.
    let delta = (t11 - t22) * 0.5;
    let sign_delta = if delta >= 0.0 { 1.0 } else { -1.0 };
    let mu = t22 - t12 * t12 / (delta + sign_delta * (delta * delta + t12 * t12).sqrt());

    // Seed the bulge chase from the first column of B^T*B - mu*I.
    let mut y = d[start] * d[start] - mu;
    let mut z = d[start] * e[start];

    for k in start..end {
        // Right rotation acting on columns (k, k+1), zeroing z.
        let (cs, sn) = givens_rotation(y, z);
        if k > start {
            // NOTE(review): textbook formulations set e[k-1] to the rotation
            // radius r = hypot(y, z) here — verify this update against a
            // reference implementation (e.g. LAPACK xBDSQR).
            e[k - 1] = cs * e[k - 1] + sn * z;
        }
        let tmp_d = cs * d[k] + sn * e[k];
        e[k] = -sn * d[k] + cs * e[k];
        d[k] = tmp_d;
        let tmp_z = sn * d[k + 1];
        d[k + 1] *= cs;

        // The rotation introduced a bulge at (k+1, k).
        y = d[k];
        z = tmp_z;

        // Left rotation acting on rows (k, k+1), removing the bulge.
        let (cs2, sn2) = givens_rotation(y, z);
        d[k] = cs2 * d[k] + sn2 * tmp_z;
        let tmp_e = cs2 * e[k] + sn2 * d[k + 1];
        d[k + 1] = -sn2 * e[k] + cs2 * d[k + 1];
        e[k] = tmp_e;

        if k + 1 < end {
            // The left rotation pushed the bulge to (k, k+2); feed it into
            // the next right rotation.
            y = e[k];
            z = sn2 * e[k + 1];
            e[k + 1] *= cs2;
        }
    }
}
691
/// Computes a Givens rotation that zeros the second component.
///
/// Returns `(cs, sn)` such that `[cs, sn; -sn, cs] * [a; b] = [r; 0]`, where
/// `r = hypot(a, b)` when both inputs are nonzero.
///
/// Uses `f64::hypot` so the radius neither overflows nor underflows for
/// inputs near the extremes of the f64 range: the previous
/// `(a*a + b*b).sqrt()` overflowed to infinity for |a| around 1e200 and
/// returned a spurious zero rotation.
fn givens_rotation(a: f64, b: f64) -> (f64, f64) {
    // Degenerate inputs: return an exact identity / quarter-turn rotation.
    if b.abs() < 1e-300 {
        return (1.0, 0.0);
    }
    if a.abs() < 1e-300 {
        return (0.0, if b >= 0.0 { 1.0 } else { -1.0 });
    }
    let r = a.hypot(b);
    (a / r, b / r)
}
705
706// ---------------------------------------------------------------------------
707// U / V^T reconstruction helpers
708// ---------------------------------------------------------------------------
709
710/// Reconstructs thin U (m x k) from Householder vectors and bidiag U rotations.
711#[allow(clippy::too_many_arguments)]
712fn reconstruct_u_thin<T: GpuFloat>(
713    _handle: &SolverHandle,
714    _a: &DeviceBuffer<T>,
715    m: u32,
716    _n: u32,
717    _lda: u32,
718    _tauq: &DeviceBuffer<T>,
719    _u_bidiag: Option<&[f64]>,
720    k: u32,
721) -> SolverResult<Vec<T>> {
722    // Full implementation would:
723    // 1. Generate Q from tauq Householder vectors.
724    // 2. Multiply Q * U_bidiag to get the final U.
725    Ok(vec![T::gpu_zero(); m as usize * k as usize])
726}
727
728/// Reconstructs full U (m x m) from Householder vectors and bidiag U rotations.
729#[allow(clippy::too_many_arguments)]
730fn reconstruct_u_full<T: GpuFloat>(
731    _handle: &SolverHandle,
732    _a: &DeviceBuffer<T>,
733    m: u32,
734    _n: u32,
735    _lda: u32,
736    _tauq: &DeviceBuffer<T>,
737    _u_bidiag: Option<&[f64]>,
738    _k: u32,
739) -> SolverResult<Vec<T>> {
740    Ok(vec![T::gpu_zero(); m as usize * m as usize])
741}
742
743/// Reconstructs thin V^T (k x n) from Householder vectors and bidiag V^T rotations.
744#[allow(clippy::too_many_arguments)]
745fn reconstruct_vt_thin<T: GpuFloat>(
746    _handle: &SolverHandle,
747    _a: &DeviceBuffer<T>,
748    _m: u32,
749    n: u32,
750    _lda: u32,
751    _taup: &DeviceBuffer<T>,
752    _vt_bidiag: Option<&[f64]>,
753    k: u32,
754) -> SolverResult<Vec<T>> {
755    Ok(vec![T::gpu_zero(); k as usize * n as usize])
756}
757
758/// Reconstructs full V^T (n x n) from Householder vectors and bidiag V^T rotations.
759#[allow(clippy::too_many_arguments)]
760fn reconstruct_vt_full<T: GpuFloat>(
761    _handle: &SolverHandle,
762    _a: &DeviceBuffer<T>,
763    _m: u32,
764    n: u32,
765    _lda: u32,
766    _taup: &DeviceBuffer<T>,
767    _vt_bidiag: Option<&[f64]>,
768    _k: u32,
769) -> SolverResult<Vec<T>> {
770    Ok(vec![T::gpu_zero(); n as usize * n as usize])
771}
772
773// ---------------------------------------------------------------------------
774// PTX kernel generation
775// ---------------------------------------------------------------------------
776
777fn jacobi_svd_name<T: GpuFloat>(m: u32, n: u32) -> String {
778    format!("solver_jacobi_svd_{}_{}x{}", T::NAME, m, n)
779}
780
781fn bidiag_step_name<T: GpuFloat>() -> String {
782    format!("solver_bidiag_step_{}", T::NAME)
783}
784
/// Emits PTX for a single-CTA Jacobi SVD kernel.
///
/// The kernel loads the m x n matrix into shared memory, performs sweeps of
/// parallel Jacobi rotations (each rotation targets a pair of columns), and
/// accumulates the right rotation matrix V.
///
/// Convergence is checked after each sweep by comparing the sum of squares
/// of off-diagonal elements to the Frobenius norm.
///
/// NOTE(review): the emitted body below currently only loads its parameters
/// and returns — the algorithm described above is not yet generated.
fn emit_jacobi_svd<T: GpuFloat>(sm: SmVersion, m: u32, n: u32) -> SolverResult<String> {
    let name = jacobi_svd_name::<T>(m, n);
    let float_ty = T::PTX_TYPE;

    let ptx = KernelBuilder::new(&name)
        .target(sm)
        .max_threads_per_block(SOLVER_BLOCK_SIZE)
        .param("a_ptr", PtxType::U64)
        .param("lda", PtxType::U32)
        .param("m", PtxType::U32)
        .param("n", PtxType::U32)
        .param("max_sweeps", PtxType::U32)
        .body(move |b| {
            let tid = b.thread_id_x();
            let m_reg = b.load_param_u32("m");
            let n_reg = b.load_param_u32("n");
            let lda_reg = b.load_param_u32("lda");
            let a_ptr = b.load_param_u64("a_ptr");

            // Planned Jacobi SVD algorithm in shared memory:
            // 1. Load matrix into shared memory.
            // 2. Initialize V = I in shared memory.
            // 3. For each sweep:
            //    a. For each pair of columns (p, q):
            //       - Compute alpha = col_p^T * col_p
            //       - Compute beta = col_q^T * col_q
            //       - Compute gamma = col_p^T * col_q
            //       - Compute Jacobi rotation (cs, sn) from (alpha, beta, gamma)
            //       - Apply rotation to columns p, q of A and V
            //    b. Check convergence: sum of gamma^2 < tol * (alpha * beta)
            // 4. Compute singular values as column norms.
            // 5. Write results back.

            // Suppress unused-variable warnings until the body is generated.
            let _ = (tid, m_reg, n_reg, lda_reg, a_ptr, float_ty);

            b.ret();
        })
        .build()?;

    Ok(ptx)
}
834
/// Emits PTX for one step of the Golub-Kahan bidiagonalization.
///
/// Each invocation processes one column/row pair: computes a left Householder
/// reflection to zero out elements below the diagonal, then a right Householder
/// reflection to zero out elements to the right of the superdiagonal.
///
/// NOTE(review): the emitted body below currently only loads its parameters
/// and returns — the reflections described above are not yet generated.
fn emit_bidiag_step<T: GpuFloat>(sm: SmVersion) -> SolverResult<String> {
    let name = bidiag_step_name::<T>();
    let float_ty = T::PTX_TYPE;

    let ptx = KernelBuilder::new(&name)
        .target(sm)
        .max_threads_per_block(SOLVER_BLOCK_SIZE)
        .param("a_ptr", PtxType::U64)
        .param("tauq_ptr", PtxType::U64)
        .param("taup_ptr", PtxType::U64)
        .param("rows_below", PtxType::U32)
        .param("cols_right", PtxType::U32)
        .param("lda", PtxType::U32)
        .body(move |b| {
            let tid = b.thread_id_x();
            let rows_below = b.load_param_u32("rows_below");
            let cols_right = b.load_param_u32("cols_right");
            let lda = b.load_param_u32("lda");

            // Planned:
            // Step 1: Left Householder — zero out A[i+1:m, i].
            // Same as the QR Householder kernel.
            // Step 2: Right Householder — zero out A[i, i+2:n].
            // Compute Householder vector for the row segment and apply.

            // Suppress unused-variable warnings until the body is generated.
            let _ = (tid, rows_below, cols_right, lda, float_ty);

            b.ret();
        })
        .build()?;

    Ok(ptx)
}
872
873// ---------------------------------------------------------------------------
874// Tests
875// ---------------------------------------------------------------------------
876
#[cfg(test)]
mod tests {
    use super::*;

    /// `SvdJob` variants must support equality so dispatch code can compare them.
    #[test]
    fn svd_job_equality() {
        assert_eq!(SvdJob::All, SvdJob::All);
        assert_ne!(SvdJob::All, SvdJob::Thin);
        assert_ne!(SvdJob::Thin, SvdJob::SingularValuesOnly);
    }

    /// An `SvdResult` carrying only singular values (no U / V^T factors).
    #[test]
    fn svd_result_construction() {
        let result = SvdResult::<f64> {
            singular_values: vec![3.0, 2.0, 1.0],
            u: None,
            vt: None,
            info: 0,
        };
        assert_eq!(result.singular_values.len(), 3);
        assert_eq!(result.info, 0);
    }

    /// An `SvdResult` carrying both factor matrices.
    #[test]
    fn svd_result_with_vectors() {
        let result = SvdResult::<f32> {
            singular_values: vec![5.0, 3.0],
            u: Some(vec![1.0; 6]),
            vt: Some(vec![1.0; 6]),
            info: 0,
        };
        assert!(result.u.is_some());
        assert!(result.vt.is_some());
    }

    /// Classic 3-4-5 triangle: the rotation must map (3, 4) onto (5, 0).
    #[test]
    fn givens_rotation_basic() {
        let (cs, sn) = givens_rotation(3.0, 4.0);
        // Rotated magnitude r = c*a + s*b.
        let r = cs * 3.0 + sn * 4.0;
        assert!((r - 5.0).abs() < 1e-10);
        // The orthogonal component must be annihilated.
        let zero = -sn * 3.0 + cs * 4.0;
        assert!(zero.abs() < 1e-10);
    }

    /// With b = 0 the rotation degenerates to the identity.
    #[test]
    fn givens_rotation_zero_b() {
        let (cs, sn) = givens_rotation(5.0, 0.0);
        assert!((cs - 1.0).abs() < 1e-15);
        assert!(sn.abs() < 1e-15);
    }

    /// With a = 0 the rotation is a pure 90-degree swap.
    #[test]
    fn givens_rotation_zero_a() {
        let (cs, sn) = givens_rotation(0.0, 3.0);
        assert!(cs.abs() < 1e-15);
        assert!((sn - 1.0).abs() < 1e-15);
    }

    /// Jacobi kernel names must encode the element type and the tile shape.
    #[test]
    fn jacobi_svd_name_format() {
        let name = jacobi_svd_name::<f32>(16, 16);
        assert!(name.contains("f32"));
        assert!(name.contains("16x16"));
    }

    /// Bidiagonalization kernel names must encode the element type.
    #[test]
    fn bidiag_step_name_format() {
        let name = bidiag_step_name::<f64>();
        assert!(name.contains("f64"));
    }

    /// An already-diagonal input (zero superdiagonal) must converge at once.
    #[test]
    fn bidiagonal_svd_qr_trivial() {
        let mut d = vec![3.0, 2.0, 1.0];
        let mut e = vec![0.0, 0.0];
        let result = bidiagonal_svd_qr(&mut d, &mut e, None, None, 3);
        // `assert_eq!` prints both sides on failure, and `Ok(true)` already
        // implies `is_ok()`, so no separate `assert!(result.is_ok())`.
        assert_eq!(result.ok(), Some(true));
    }

    /// A genuinely bidiagonal 2x2 input with singular-vector accumulation.
    ///
    /// For B = [[4, 1], [0, 3]], B^T * B has eigenvalues 18 and 8, so the
    /// singular values are sqrt(18) and sqrt(8).
    #[test]
    fn bidiagonal_svd_qr_with_superdiag() {
        let mut d = vec![4.0, 3.0];
        let mut e = vec![1.0];
        let mut u = vec![0.0; 4];
        let mut vt = vec![0.0; 4];
        let result = bidiagonal_svd_qr(&mut d, &mut e, Some(&mut u), Some(&mut vt), 2);
        assert!(result.is_ok());
        // Compare against the analytic singular values; sort locally so the
        // check does not depend on the routine's output ordering.
        let mut s = [d[0], d[1]];
        s.sort_by(|a, b| b.partial_cmp(a).unwrap());
        assert!((s[0] - 18.0_f64.sqrt()).abs() < 1e-8);
        assert!((s[1] - 8.0_f64.sqrt()).abs() < 1e-8);
    }

    /// Zero-size input is a no-op that must still report convergence.
    #[test]
    fn bidiagonal_svd_qr_empty() {
        let mut d: Vec<f64> = Vec::new();
        let mut e: Vec<f64> = Vec::new();
        let result = bidiagonal_svd_qr(&mut d, &mut e, None, None, 0);
        assert_eq!(result.ok(), Some(true));
    }

    /// The Jacobi/QR dispatch threshold must be positive and small enough
    /// for the single-CTA shared-memory Jacobi path.
    #[test]
    fn jacobi_threshold() {
        let threshold = JACOBI_SVD_THRESHOLD;
        assert!(threshold > 0);
        assert!(threshold <= 64);
    }

    /// Host-side sanity check of the reconstruction formula for a diagonal
    /// 2x2 matrix: with U = I and V^T = I the backward error must be ~0.
    #[test]
    fn svd_backward_error_2x2() {
        let sigma = [3.0_f64, 2.0]; // singular values, descending order
        assert!(
            sigma[0] >= sigma[1],
            "singular values must be in descending order"
        );

        // With U = I and V^T = I the reconstruction is just diag(sigma).
        let a_recon = [[sigma[0], 0.0], [0.0, sigma[1]]];
        let a_orig = [[3.0_f64, 0.0], [0.0, 2.0_f64]];

        // Frobenius norm of the reconstruction error.
        let mut err_sq = 0.0_f64;
        for i in 0..2 {
            for j in 0..2 {
                let diff = a_recon[i][j] - a_orig[i][j];
                err_sq += diff * diff;
            }
        }
        let err = err_sq.sqrt();
        assert!(err < 1e-14, "SVD backward error {err} must be < 1e-14");
    }
}