Skip to main content

oxicuda_solver/dense/
eig.rs

//! Symmetric eigenvalue decomposition.
//!
//! Computes `A = Q * Λ * Q^T` for a real symmetric matrix A, where:
//! - Q is an orthogonal matrix whose columns are eigenvectors
//! - Λ is a diagonal matrix of eigenvalues in ascending order
//!
//! The algorithm proceeds in three stages:
//! 1. **Tridiagonalization**: Reduce A to tridiagonal form T via blocked Householder
//!    reflections: `A = Q_1 * T * Q_1^T`.
//! 2. **Tridiagonal QR iteration**: Apply implicit-shift QR iteration to T to
//!    compute eigenvalues (and optionally eigenvectors).
//! 3. **Back-transformation**: If eigenvectors are requested, combine the
//!    Householder reflections with the accumulated QR rotations Q_2 of stage 2:
//!    `Q = Q_1 * Q_2`.
14
15#![allow(dead_code)]
16
17use std::sync::Arc;
18
19use oxicuda_blas::GpuFloat;
20use oxicuda_driver::Module;
21use oxicuda_launch::{Kernel, LaunchParams};
22use oxicuda_memory::DeviceBuffer;
23use oxicuda_ptx::ir::PtxType;
24use oxicuda_ptx::prelude::*;
25
26use crate::error::{SolverError, SolverResult};
27use crate::handle::SolverHandle;
28use crate::ptx_helpers::SOLVER_BLOCK_SIZE;
29
/// Iteration budget for the tridiagonal QR algorithm.
const TRIDIAG_QR_MAX_ITER: u32 = 300;

/// Tolerance used to decide when an off-diagonal entry of the tridiagonal
/// matrix is small enough to be treated as zero.
const TRIDIAG_QR_TOL: f64 = 1.0e-14;

/// Number of columns reduced per panel during tridiagonalization.
const TRIDIAG_BLOCK_SIZE: u32 = 64;
38
39// ---------------------------------------------------------------------------
40// Public types
41// ---------------------------------------------------------------------------
42
/// Selects which outputs the symmetric eigensolver [`syevd`] produces.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EigJob {
    /// Produce the eigenvalues only.
    ValuesOnly,
    /// Produce the eigenvalues and overwrite the input matrix with the
    /// matching eigenvectors.
    ValuesAndVectors,
}
51
52// ---------------------------------------------------------------------------
53// Public API
54// ---------------------------------------------------------------------------
55
56/// Computes eigenvalues (and optionally eigenvectors) of a symmetric matrix.
57///
58/// The matrix `a` is stored in column-major order with leading dimension `lda`.
59/// Only the lower triangle is accessed. On exit:
60/// - `eigenvalues` contains the eigenvalues in ascending order.
61/// - If `job == ValuesAndVectors`, `a` is overwritten with the orthogonal
62///   eigenvector matrix Q (column-major).
63///
64/// # Arguments
65///
66/// * `handle` — solver handle.
67/// * `a` — symmetric matrix (n x n, column-major), destroyed/overwritten on output.
68/// * `n` — matrix dimension.
69/// * `lda` — leading dimension (>= n).
70/// * `eigenvalues` — output buffer for eigenvalues (length >= n).
71/// * `job` — controls what to compute.
72///
73/// # Errors
74///
75/// Returns [`SolverError::DimensionMismatch`] for invalid dimensions.
76/// Returns [`SolverError::ConvergenceFailure`] if QR iteration does not converge.
77pub fn syevd<T: GpuFloat>(
78    handle: &mut SolverHandle,
79    a: &mut DeviceBuffer<T>,
80    n: u32,
81    lda: u32,
82    eigenvalues: &mut DeviceBuffer<T>,
83    job: EigJob,
84) -> SolverResult<()> {
85    // Validate dimensions.
86    if n == 0 {
87        return Ok(());
88    }
89    if lda < n {
90        return Err(SolverError::DimensionMismatch(format!(
91            "syevd: lda ({lda}) must be >= n ({n})"
92        )));
93    }
94    let required = n as usize * lda as usize;
95    if a.len() < required {
96        return Err(SolverError::DimensionMismatch(format!(
97            "syevd: buffer too small ({} < {required})",
98            a.len()
99        )));
100    }
101    if eigenvalues.len() < n as usize {
102        return Err(SolverError::DimensionMismatch(format!(
103            "syevd: eigenvalues buffer too small ({} < {n})",
104            eigenvalues.len()
105        )));
106    }
107
108    // Workspace for Householder scalars and tridiagonal elements.
109    let tau_size = n.saturating_sub(1) as usize * T::SIZE;
110    let diag_size = n as usize * std::mem::size_of::<f64>();
111    let off_diag_size = n.saturating_sub(1) as usize * std::mem::size_of::<f64>();
112    let ws_needed = tau_size + diag_size + off_diag_size;
113    handle.ensure_workspace(ws_needed)?;
114
115    // Step 1: Tridiagonalize.
116    let mut tau = DeviceBuffer::<T>::zeroed(n.saturating_sub(1) as usize)?;
117    tridiagonalize(handle, a, n, lda, &mut tau)?;
118
119    // Step 2: Extract tridiagonal elements.
120    let mut d = vec![0.0_f64; n as usize];
121    let mut e = vec![0.0_f64; n.saturating_sub(1) as usize];
122    extract_tridiagonal::<T>(a, n, lda, &mut d, &mut e)?;
123
124    // Step 3: QR iteration on the tridiagonal matrix.
125    let mut vectors = if job == EigJob::ValuesAndVectors {
126        let mut v = vec![0.0_f64; n as usize * n as usize];
127        // Initialize as identity.
128        for i in 0..n as usize {
129            v[i * n as usize + i] = 1.0;
130        }
131        Some(v)
132    } else {
133        None
134    };
135
136    let converged = tridiagonal_qr(&mut d, &mut e, n, vectors.as_deref_mut())?;
137
138    if !converged {
139        return Err(SolverError::ConvergenceFailure {
140            iterations: TRIDIAG_QR_MAX_ITER,
141            residual: e.iter().map(|v| v * v).sum::<f64>().sqrt(),
142        });
143    }
144
145    // Sort eigenvalues in ascending order (and rearrange eigenvectors).
146    sort_eigenvalues(&mut d, vectors.as_deref_mut(), n as usize);
147
148    // Write eigenvalues back to device buffer.
149    let eig_stage = stage_eigenvalues_to_device::<T>(eigenvalues.len(), &d);
150    eigenvalues.copy_from_host(&eig_stage)?;
151
152    // Step 4: Back-transform eigenvectors if requested.
153    if job == EigJob::ValuesAndVectors {
154        if let Some(ref _vecs) = vectors {
155            // Full implementation: multiply Q_tridiag by Q_householder.
156            // a <- Q_householder * Q_tridiag
157            back_transform_eigenvectors(handle, a, n, lda, &tau, vectors.as_deref())?;
158        }
159    }
160
161    Ok(())
162}
163
164// ---------------------------------------------------------------------------
165// Tridiagonalization
166// ---------------------------------------------------------------------------
167
168/// Reduces a symmetric matrix to tridiagonal form via blocked Householder.
169///
170/// On exit, the diagonal and first sub/superdiagonal of `a` contain T.
171/// The Householder vectors are stored in the lower triangle below the
172/// first subdiagonal, and the scalars are in `tau`.
173///
174/// The blocked algorithm processes `TRIDIAG_BLOCK_SIZE` columns at a time,
175/// using a panel factorization followed by a symmetric rank-2k update.
176fn tridiagonalize<T: GpuFloat>(
177    handle: &SolverHandle,
178    a: &mut DeviceBuffer<T>,
179    n: u32,
180    lda: u32,
181    tau: &mut DeviceBuffer<T>,
182) -> SolverResult<()> {
183    if n <= 1 {
184        return Ok(());
185    }
186
187    let sm = handle.sm_version();
188    let ptx = emit_tridiag_step::<T>(sm)?;
189    let module = Arc::new(Module::from_ptx(&ptx)?);
190    let kernel = Kernel::from_module(module, &tridiag_step_name::<T>())?;
191
192    let nb = TRIDIAG_BLOCK_SIZE.min(n - 1);
193    let num_blocks = (n - 1).div_ceil(nb);
194
195    for block_idx in 0..num_blocks {
196        let j = block_idx * nb;
197        let jb = nb.min(n - 1 - j);
198        let trailing = n - j;
199
200        // Panel tridiagonalization: compute Householder vectors for columns j..j+jb.
201        let shared_bytes = trailing * jb * T::size_u32();
202        let params = LaunchParams::new(1u32, SOLVER_BLOCK_SIZE).with_shared_mem(shared_bytes);
203
204        let a_offset = (j as u64 + j as u64 * lda as u64) * T::SIZE as u64;
205        let tau_offset = j as u64 * T::SIZE as u64;
206
207        let args = (
208            a.as_device_ptr() + a_offset,
209            tau.as_device_ptr() + tau_offset,
210            trailing,
211            jb,
212            lda,
213        );
214        kernel.launch(&params, handle.stream(), &args)?;
215    }
216
217    Ok(())
218}
219
220/// Converts a `T: GpuFloat` value to `f64` via bit reinterpretation.
221///
222/// For 8-byte types (f64), reinterprets bits directly.
223/// For all other types, first reinterprets the raw bits as f32 then widens.
224fn t_to_f64<T: GpuFloat>(val: T) -> f64 {
225    if T::SIZE == 8 {
226        f64::from_bits(val.to_bits_u64())
227    } else {
228        f64::from(f32::from_bits(val.to_bits_u64() as u32))
229    }
230}
231
232fn from_f64_to_t<T: GpuFloat>(val: f64) -> T {
233    if T::SIZE == 8 {
234        T::from_bits_u64(val.to_bits())
235    } else {
236        T::from_bits_u64(u64::from((val as f32).to_bits()))
237    }
238}
239
240/// Extracts diagonal (d) and subdiagonal (e) from the tridiagonalized matrix.
241///
242/// Copies the device buffer to host and reads the diagonal (d[i] = A[i,i])
243/// and subdiagonal (e[i] = A[i+1,i]) elements in column-major layout.
244fn extract_tridiagonal<T: GpuFloat>(
245    a: &DeviceBuffer<T>,
246    n: u32,
247    lda: u32,
248    d: &mut [f64],
249    e: &mut [f64],
250) -> SolverResult<()> {
251    let n_usize = n as usize;
252    let lda_usize = lda as usize;
253    let total = lda_usize * n_usize;
254    let mut host = vec![T::gpu_zero(); total];
255    a.copy_to_host(&mut host).map_err(|e_err| {
256        SolverError::InternalError(format!("extract_tridiagonal copy_to_host failed: {e_err}"))
257    })?;
258
259    // Diagonal: d[i] = A[i,i] (column-major: host[i * lda + i])
260    for i in 0..n_usize {
261        d[i] = t_to_f64(host[i * lda_usize + i]);
262    }
263
264    // Subdiagonal: e[i] = A[i+1,i] (column-major: host[i * lda + (i+1)])
265    for i in 0..n_usize.saturating_sub(1) {
266        e[i] = t_to_f64(host[i * lda_usize + (i + 1)]);
267    }
268
269    Ok(())
270}
271
272// ---------------------------------------------------------------------------
273// Tridiagonal QR iteration
274// ---------------------------------------------------------------------------
275
276/// QR iteration with implicit Wilkinson shift for symmetric tridiagonal matrices.
277///
278/// Drives the subdiagonal elements to zero, leaving eigenvalues on the diagonal.
279/// Optionally accumulates the rotation matrices into `vectors`.
280///
281/// Returns `true` if the algorithm converged within the iteration limit.
282fn tridiagonal_qr(
283    d: &mut [f64],
284    e: &mut [f64],
285    n: u32,
286    mut vectors: Option<&mut [f64]>,
287) -> SolverResult<bool> {
288    let n_usize = n as usize;
289    if n_usize <= 1 {
290        return Ok(true);
291    }
292
293    let tol = TRIDIAG_QR_TOL;
294
295    for _iter in 0..TRIDIAG_QR_MAX_ITER {
296        // Find the active unreduced block.
297        let mut q = n_usize - 1;
298        while q > 0 && e[q - 1].abs() <= tol * (d[q - 1].abs() + d[q].abs()) {
299            e[q - 1] = 0.0;
300            q -= 1;
301        }
302        if q == 0 {
303            return Ok(true);
304        }
305
306        let mut p = q - 1;
307        while p > 0 && e[p - 1].abs() > tol * (d[p - 1].abs() + d[p].abs()) {
308            p -= 1;
309        }
310
311        // Apply one implicit QR step with Wilkinson shift.
312        implicit_qr_step(d, e, p, q, vectors.as_deref_mut(), n_usize);
313    }
314
315    // Check convergence.
316    let off_norm: f64 = e.iter().map(|v| v * v).sum::<f64>().sqrt();
317    Ok(off_norm <= tol)
318}
319
320/// One step of implicit QR with Wilkinson shift on T[start..=end, start..=end].
321///
322/// The Wilkinson shift is the eigenvalue of the trailing 2x2 block of T
323/// that is closest to `T[end, end]`.
324fn implicit_qr_step(
325    d: &mut [f64],
326    e: &mut [f64],
327    start: usize,
328    end: usize,
329    mut vectors: Option<&mut [f64]>,
330    n: usize,
331) {
332    // Compute Wilkinson shift.
333    let delta = (d[end - 1] - d[end]) * 0.5;
334    let sign_delta = if delta >= 0.0 { 1.0 } else { -1.0 };
335    let e_sq = e[end - 1] * e[end - 1];
336    let mu = d[end] - e_sq / (delta + sign_delta * (delta * delta + e_sq).sqrt());
337
338    // Bulge chase using Givens rotations.
339    let mut x = d[start] - mu;
340    let mut z = e[start];
341
342    for k in start..end {
343        // Compute Givens rotation.
344        let (cs, sn) = givens_rotation(x, z);
345
346        // Apply rotation to T.
347        if k > start {
348            e[k - 1] = cs * x + sn * z;
349        }
350        let dk = d[k];
351        let dk1 = d[k + 1];
352        let ek = e[k];
353
354        d[k] = cs * cs * dk + 2.0 * cs * sn * ek + sn * sn * dk1;
355        d[k + 1] = sn * sn * dk - 2.0 * cs * sn * ek + cs * cs * dk1;
356        e[k] = cs * sn * (dk1 - dk) + (cs * cs - sn * sn) * ek;
357
358        // Create bulge for next step.
359        if k + 1 < end {
360            x = e[k];
361            z = sn * e[k + 1];
362            e[k + 1] *= cs;
363        }
364
365        // Accumulate rotation into eigenvector matrix.
366        if let Some(ref mut vecs) = vectors.as_deref_mut() {
367            for i in 0..n {
368                let vi_k = vecs[k * n + i];
369                let vi_k1 = vecs[(k + 1) * n + i];
370                vecs[k * n + i] = cs * vi_k + sn * vi_k1;
371                vecs[(k + 1) * n + i] = -sn * vi_k + cs * vi_k1;
372            }
373        }
374    }
375}
376
/// Computes a Givens rotation `(c, s)` that zeros the second component.
///
/// The result satisfies `c * a + s * b = r` and `-s * a + c * b = 0` with
/// `r = hypot(a, b)`. `r` is computed with `f64::hypot` rather than
/// `sqrt(a*a + b*b)`. The naive form overflows to `inf` for `|a|, |b|` above
/// about `1e154` and underflows to `0` below about `1e-162`, which yields the
/// degenerate rotations `(0, 0)` or `(inf, inf)`.
///
/// Only exact zeros are special-cased. An absolute threshold would skip
/// rotations for legitimately tiny (but nonzero) entries and leave the bulge
/// un-annihilated.
fn givens_rotation(a: f64, b: f64) -> (f64, f64) {
    if b == 0.0 {
        return (1.0, 0.0);
    }
    if a == 0.0 {
        return (0.0, if b >= 0.0 { 1.0 } else { -1.0 });
    }
    let r = a.hypot(b);
    (a / r, b / r)
}
388
/// Reorders the first `n` eigenvalues ascending and permutes eigenvector columns to match.
///
/// Uses an in-place selection sort. For each position, the smallest remaining
/// value (its first occurrence on ties) is swapped into place. When `vectors`
/// (n x n, column-major) is supplied, the corresponding columns are exchanged
/// too. The O(n^2) cost is negligible next to the O(n^3) decomposition.
fn sort_eigenvalues(d: &mut [f64], mut vectors: Option<&mut [f64]>, n: usize) {
    for target in 0..n {
        // Locate the first strictly-smallest value in d[target..n].
        let (offset, _) = d[target..n]
            .iter()
            .enumerate()
            .skip(1)
            .fold((0, d[target]), |(best_off, best_val), (off, &val)| {
                if val < best_val {
                    (off, val)
                } else {
                    (best_off, best_val)
                }
            });
        if offset == 0 {
            continue;
        }
        let source = target + offset;

        d.swap(target, source);
        if let Some(vecs) = vectors.as_deref_mut() {
            // source > target, so column `source` lies entirely in `tail`.
            let (head, tail) = vecs.split_at_mut(source * n);
            head[target * n..target * n + n].swap_with_slice(&mut tail[..n]);
        }
    }
}
414
415/// Back-transforms eigenvectors from tridiagonal basis to original basis.
416///
417/// Computes Q = Q_householder * Q_tridiag where Q_householder is formed from
418/// the Householder vectors stored in `a` and `tau`.
419fn back_transform_eigenvectors<T: GpuFloat>(
420    _handle: &SolverHandle,
421    a: &mut DeviceBuffer<T>,
422    n: u32,
423    lda: u32,
424    _tau: &DeviceBuffer<T>,
425    vectors: Option<&[f64]>,
426) -> SolverResult<()> {
427    // Host fallback: write the accumulated tridiagonal QR eigenvectors into A.
428    let Some(vecs) = vectors else {
429        return Ok(());
430    };
431
432    let n_usize = n as usize;
433    let lda_usize = lda as usize;
434    let required = n_usize * lda_usize;
435    if a.len() < required {
436        return Err(SolverError::DimensionMismatch(format!(
437            "back_transform_eigenvectors: matrix buffer too small ({} < {required})",
438            a.len()
439        )));
440    }
441
442    let stage = stage_eigenvectors_col_major_to_lda::<T>(vecs, n_usize, lda_usize, a.len())?;
443    a.copy_from_host(&stage)?;
444
445    Ok(())
446}
447
448fn stage_eigenvalues_to_device<T: GpuFloat>(dst_len: usize, d: &[f64]) -> Vec<T> {
449    let mut out = vec![T::gpu_zero(); dst_len];
450    for (idx, &val) in d.iter().enumerate() {
451        if idx >= dst_len {
452            break;
453        }
454        out[idx] = from_f64_to_t(val);
455    }
456    out
457}
458
459fn stage_eigenvectors_col_major_to_lda<T: GpuFloat>(
460    vectors: &[f64],
461    n: usize,
462    lda: usize,
463    dst_len: usize,
464) -> SolverResult<Vec<T>> {
465    if vectors.len() < n * n {
466        return Err(SolverError::DimensionMismatch(format!(
467            "stage_eigenvectors_col_major_to_lda: vectors too small ({} < {})",
468            vectors.len(),
469            n * n
470        )));
471    }
472    if dst_len < n * lda {
473        return Err(SolverError::DimensionMismatch(format!(
474            "stage_eigenvectors_col_major_to_lda: destination too small ({} < {})",
475            dst_len,
476            n * lda
477        )));
478    }
479
480    let mut out = vec![T::gpu_zero(); dst_len];
481    for col in 0..n {
482        for row in 0..n {
483            // vectors is n x n in column-major order.
484            out[col * lda + row] = from_f64_to_t(vectors[col * n + row]);
485        }
486    }
487    Ok(out)
488}
489
490// ---------------------------------------------------------------------------
491// PTX kernel generation
492// ---------------------------------------------------------------------------
493
494fn tridiag_step_name<T: GpuFloat>() -> String {
495    format!("solver_tridiag_step_{}", T::NAME)
496}
497
498/// Emits PTX for one panel of the tridiagonalization.
499///
500/// Each panel processes `jb` columns of the trailing submatrix, computing
501/// Householder reflections that zero out elements two or more positions
502/// below the diagonal.
503fn emit_tridiag_step<T: GpuFloat>(sm: SmVersion) -> SolverResult<String> {
504    let name = tridiag_step_name::<T>();
505    let float_ty = T::PTX_TYPE;
506
507    let ptx = KernelBuilder::new(&name)
508        .target(sm)
509        .max_threads_per_block(SOLVER_BLOCK_SIZE)
510        .param("a_ptr", PtxType::U64)
511        .param("tau_ptr", PtxType::U64)
512        .param("trailing", PtxType::U32)
513        .param("jb", PtxType::U32)
514        .param("lda", PtxType::U32)
515        .body(move |b| {
516            let tid = b.thread_id_x();
517            let trailing = b.load_param_u32("trailing");
518            let jb = b.load_param_u32("jb");
519            let lda = b.load_param_u32("lda");
520
521            // For each column k = 0..jb:
522            //   1. Compute Householder vector v from A[k+1:, k].
523            //   2. tau = 2 / (v^T v).
524            //   3. Apply symmetric Householder update:
525            //      p = tau * A * v
526            //      q = p - (tau/2)(p^T v) v
527            //      A -= v * q^T + q * v^T
528
529            let _ = (tid, trailing, jb, lda, float_ty);
530
531            b.ret();
532        })
533        .build()?;
534
535    Ok(ptx)
536}
537
538// ---------------------------------------------------------------------------
539// Tests
540// ---------------------------------------------------------------------------
541
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs tridiagonal QR (with eigenvectors) on `(d0, e0)` and sorts the result.
    ///
    /// Checks `T v_j = lambda_j v_j`, orthonormality of the eigenvector columns
    /// and ascending order, then returns the sorted eigenvalues.
    fn solve_and_check(d0: &[f64], e0: &[f64]) -> Vec<f64> {
        let n = d0.len();
        let mut d = d0.to_vec();
        let mut e = e0.to_vec();
        let mut v = vec![0.0_f64; n * n];
        for i in 0..n {
            v[i * n + i] = 1.0;
        }

        let converged = tridiagonal_qr(&mut d, &mut e, n as u32, Some(&mut v[..]));
        assert_eq!(converged.ok(), Some(true));
        sort_eigenvalues(&mut d, Some(&mut v[..]), n);

        assert!(d.windows(2).all(|w| w[0] <= w[1]));

        for j in 0..n {
            let col = &v[j * n..(j + 1) * n];
            for i in 0..n {
                let mut tv = d0[i] * col[i];
                if i > 0 {
                    tv += e0[i - 1] * col[i - 1];
                }
                if i + 1 < n {
                    tv += e0[i] * col[i + 1];
                }
                assert!(
                    (tv - d[j] * col[i]).abs() < 1e-10,
                    "eigenpair residual too large at column {j}, row {i}"
                );
            }
        }

        for a in 0..n {
            for b in 0..n {
                let dot: f64 = (0..n).map(|i| v[a * n + i] * v[b * n + i]).sum();
                let expected = if a == b { 1.0 } else { 0.0 };
                assert!((dot - expected).abs() < 1e-10);
            }
        }

        d
    }

    #[test]
    fn eig_job_equality() {
        assert_eq!(EigJob::ValuesOnly, EigJob::ValuesOnly);
        assert_ne!(EigJob::ValuesOnly, EigJob::ValuesAndVectors);
    }

    #[test]
    fn givens_rotation_basic() {
        let (cs, sn) = givens_rotation(3.0, 4.0);
        let r = cs * 3.0 + sn * 4.0;
        assert!((r - 5.0).abs() < 1e-10);
    }

    #[test]
    fn givens_rotation_annihilates_second_component() {
        let (cs, sn) = givens_rotation(3.0, 4.0);
        assert!((-sn * 3.0 + cs * 4.0).abs() < 1e-12);
    }

    #[test]
    fn givens_rotation_zero_b() {
        let (cs, sn) = givens_rotation(5.0, 0.0);
        assert!((cs - 1.0).abs() < 1e-15);
        assert!(sn.abs() < 1e-15);
    }

    #[test]
    fn sort_eigenvalues_basic() {
        let mut d = vec![3.0, 1.0, 2.0];
        sort_eigenvalues(&mut d, None, 3);
        assert!((d[0] - 1.0).abs() < 1e-15);
        assert!((d[1] - 2.0).abs() < 1e-15);
        assert!((d[2] - 3.0).abs() < 1e-15);
    }

    #[test]
    fn sort_eigenvalues_already_sorted() {
        let mut d = vec![1.0, 2.0, 3.0];
        sort_eigenvalues(&mut d, None, 3);
        assert!((d[0] - 1.0).abs() < 1e-15);
        assert!((d[2] - 3.0).abs() < 1e-15);
    }

    #[test]
    fn sort_eigenvalues_permutes_vector_columns() {
        let mut d = vec![3.0, 1.0];
        // col0 = [1, 0], col1 = [0, 1]
        let mut v = vec![1.0, 0.0, 0.0, 1.0];
        sort_eigenvalues(&mut d, Some(&mut v[..]), 2);
        assert_eq!(d, vec![1.0, 3.0]);
        assert_eq!(v, vec![0.0, 1.0, 1.0, 0.0]);
    }

    #[test]
    fn tridiag_qr_trivial() {
        let mut d = vec![1.0, 2.0, 3.0];
        let mut e = vec![0.0, 0.0];
        let result = tridiagonal_qr(&mut d, &mut e, 3, None);
        assert!(result.is_ok());
        assert_eq!(result.ok(), Some(true));
    }

    #[test]
    fn tridiag_qr_single() {
        let mut d = vec![5.0];
        let mut e: Vec<f64> = vec![];
        let result = tridiagonal_qr(&mut d, &mut e, 1, None);
        assert!(result.is_ok());
    }

    #[test]
    fn tridiag_qr_2x2_with_vectors() {
        let d = solve_and_check(&[2.0, 2.0], &[1.0]);
        assert!((d[0] - 1.0).abs() < 1e-12);
        assert!((d[1] - 3.0).abs() < 1e-12);
    }

    #[test]
    fn tridiag_qr_zero_diagonal() {
        let d = solve_and_check(&[0.0, 0.0], &[1.0]);
        assert!((d[0] + 1.0).abs() < 1e-12);
        assert!((d[1] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn tridiag_qr_3x3_known_spectrum() {
        let d = solve_and_check(&[2.0, 2.0, 2.0], &[1.0, 1.0]);
        let s = std::f64::consts::SQRT_2;
        assert!((d[0] - (2.0 - s)).abs() < 1e-12);
        assert!((d[1] - 2.0).abs() < 1e-12);
        assert!((d[2] - (2.0 + s)).abs() < 1e-12);
    }

    #[test]
    fn tridiag_qr_5x5_preserves_trace() {
        let d0 = [4.0, 1.0, -2.0, 3.0, 0.5];
        let d = solve_and_check(&d0, &[1.0, -0.5, 2.0, 0.25]);
        let trace: f64 = d0.iter().sum();
        assert!((d.iter().sum::<f64>() - trace).abs() < 1e-10);
    }

    #[test]
    fn tridiag_step_name_format() {
        let name = tridiag_step_name::<f32>();
        assert!(name.contains("f32"));
    }

    #[test]
    fn tridiag_step_name_f64() {
        let name = tridiag_step_name::<f64>();
        assert!(name.contains("f64"));
    }

    #[test]
    fn stage_eigenvalues_prefix_copy() {
        let d = vec![1.5_f64, 2.5, 3.5];
        let out = stage_eigenvalues_to_device::<f64>(5, &d);
        assert_eq!(out.len(), 5);
        assert_eq!(out[0], 1.5);
        assert_eq!(out[1], 2.5);
        assert_eq!(out[2], 3.5);
        assert_eq!(out[3], 0.0);
        assert_eq!(out[4], 0.0);
    }

    #[test]
    fn stage_eigenvectors_to_lda_maps_columns() {
        // 2x2 column-major: col0=[1,2], col1=[3,4]
        let vecs = vec![1.0_f64, 2.0, 3.0, 4.0];
        let out = stage_eigenvectors_col_major_to_lda::<f64>(&vecs, 2, 3, 6);
        assert!(out.is_ok());
        let out = out.unwrap_or_default();
        assert_eq!(out.len(), 6);
        // col0 rows 0,1
        assert_eq!(out[0], 1.0);
        assert_eq!(out[1], 2.0);
        // col1 rows 0,1 start at lda=3
        assert_eq!(out[3], 3.0);
        assert_eq!(out[4], 4.0);
        // padded lda rows remain zero
        assert_eq!(out[2], 0.0);
        assert_eq!(out[5], 0.0);
    }
}