Skip to main content

ferrotorch_gpu/
cusolver.rs

1//! cuSOLVER-backed GPU linear algebra: SVD, Cholesky, QR, Solve.
2//!
3//! Each operation follows the cuSOLVER pattern:
4//! 1. Query workspace size via `*_bufferSize`.
5//! 2. Allocate workspace + output buffers on the device.
6//! 3. Call the cuSOLVER routine.
7//! 4. Check `devInfo` — non-zero means the operation failed (singular matrix, etc.).
8//!
9//! All functions operate on column-major data because cuSOLVER (LAPACK-style)
10//! uses column-major layout. The caller is responsible for transposing
11//! row-major tensors before calling and transposing outputs back.
12
13#[cfg(feature = "cuda")]
14use cudarc::cusolver as cusolver_safe;
15#[cfg(feature = "cuda")]
16use cudarc::driver::{DevicePtr, DevicePtrMut};
17
18#[cfg(feature = "cuda")]
19use crate::buffer::CudaBuffer;
20#[cfg(feature = "cuda")]
21use crate::device::GpuDevice;
22#[cfg(feature = "cuda")]
23use crate::error::{GpuError, GpuResult};
24
25#[cfg(not(feature = "cuda"))]
26use crate::device::GpuDevice;
27#[cfg(not(feature = "cuda"))]
28use crate::error::{GpuError, GpuResult};
29
30// ---------------------------------------------------------------------------
31// Helper: transpose row-major <-> column-major in-place on CPU
32// ---------------------------------------------------------------------------
33
/// Transpose an m-by-n row-major flat array to column-major (or vice versa).
fn transpose_f32(data: &[f32], m: usize, n: usize) -> Vec<f32> {
    // Emit the output in column-major order: column j of the result is
    // data[i*n + j] for i in 0..m, so walking j (outer) then i (inner)
    // produces the transposed layout sequentially.
    (0..n)
        .flat_map(|j| (0..m).map(move |i| data[i * n + j]))
        .collect()
}
44
/// Transpose an m-by-n row-major flat array to column-major (or vice versa) — f64 variant.
fn transpose_f64(data: &[f64], m: usize, n: usize) -> Vec<f64> {
    // Build the transposed buffer in output order: for each target column,
    // gather that column's elements from the row-major input.
    let mut out = Vec::with_capacity(m * n);
    for col in 0..n {
        for row in 0..m {
            out.push(data[row * n + col]);
        }
    }
    out
}
55
56// ---------------------------------------------------------------------------
57// Helper: check devInfo (download single i32 from GPU, verify == 0)
58// ---------------------------------------------------------------------------
59
/// Download a single i32 `devInfo` value from the GPU.
///
/// Returns `Ok(info_val)` always (an `Err` only signals a transfer failure,
/// not a solver failure). The caller decides how to interpret non-zero.
#[cfg(feature = "cuda")]
fn read_dev_info(info_buf: &CudaBuffer<i32>, device: &GpuDevice) -> GpuResult<i32> {
    // devInfo is a single element; copy it back to the host and return it.
    let host = crate::transfer::gpu_to_cpu(info_buf, device)?;
    Ok(host[0])
}
68
/// Allocate a zero-initialized `CudaBuffer<i32>` (for devInfo / ipiv).
///
/// Zero-initialization matters for devInfo: per the module convention,
/// zero means "success", so an untouched buffer reads back as success
/// only after the solver has actually run and been checked.
#[cfg(feature = "cuda")]
fn alloc_zeros_i32(len: usize, device: &GpuDevice) -> GpuResult<CudaBuffer<i32>> {
    // Thin wrapper over the crate-level allocator, fixed to i32.
    crate::transfer::alloc_zeros::<i32>(len, device)
}
74
/// Allocate a `CudaBuffer<f32>` from host data.
#[cfg(feature = "cuda")]
fn upload_f32(data: &[f32], device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    // Thin alias over crate::transfer::cpu_to_gpu to keep call sites short.
    crate::transfer::cpu_to_gpu(data, device)
}
80
/// Download a `CudaBuffer<f32>` to host.
#[cfg(feature = "cuda")]
fn download_f32(buf: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<Vec<f32>> {
    // Thin alias over crate::transfer::gpu_to_cpu to keep call sites short.
    crate::transfer::gpu_to_cpu(buf, device)
}
86
/// Allocate a `CudaBuffer<f64>` from host data.
#[cfg(feature = "cuda")]
fn upload_f64(data: &[f64], device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    // Thin alias over crate::transfer::cpu_to_gpu to keep call sites short.
    crate::transfer::cpu_to_gpu(data, device)
}
92
/// Download a `CudaBuffer<f64>` to host.
#[cfg(feature = "cuda")]
fn download_f64(buf: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<Vec<f64>> {
    // Thin alias over crate::transfer::gpu_to_cpu to keep call sites short.
    crate::transfer::gpu_to_cpu(buf, device)
}
98
99// ---------------------------------------------------------------------------
100// SVD: A = U * diag(S) * Vh   (thin/reduced)
101// ---------------------------------------------------------------------------
102
/// Compute the thin SVD of an m-by-n matrix (row-major f32).
///
/// Returns `(U, S, Vh)` as flat row-major `Vec<f32>` with shapes:
/// - U:  [m, k]  where k = min(m, n)
/// - S:  [k]
/// - Vh: [k, n]
///
/// cuSOLVER's `Sgesvd` operates on column-major data and produces
/// column-major U and VT. We transpose on input and output.
///
/// `Sgesvd` only supports m >= n; wide matrices (m < n) are handled by
/// factoring the transpose and swapping/transposing the resulting factors.
///
/// # Errors
/// Returns an error if `data.len() != m * n`, if any CUDA call fails, or if
/// the factorization does not converge (non-zero `devInfo`).
#[cfg(feature = "cuda")]
pub fn gpu_svd_f32(
    data: &[f32],
    m: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<(Vec<f32>, Vec<f32>, Vec<f32>)> {
    use cudarc::cusolver::sys as csys;

    if m == 0 || n == 0 {
        return Ok((vec![], vec![], vec![]));
    }

    // Validate input length up front so a short slice yields an error
    // instead of an index panic inside the transpose below.
    if data.len() != m * n {
        return Err(GpuError::ShapeMismatch {
            op: "gpu_svd_f32",
            expected: vec![m, n],
            got: vec![data.len()],
        });
    }

    // cuSOLVER's gesvd requires m >= n. For a wide matrix, factor the
    // transpose instead: A^T = U' S V'h  =>  A = (V'h)^T S (U')^T,
    // so U = (V'h)^T and Vh = (U')^T.
    if m < n {
        // A^T in row-major layout is exactly A in column-major layout.
        let at = transpose_f32(data, m, n);
        let (ut, s, vth) = gpu_svd_f32(&at, n, m, device)?;
        let k = m; // k = min(m, n) = m because m < n here

        // U = (V'h)^T: V'h is [k, m] row-major -> U is [m, k] row-major.
        let mut u = vec![0.0f32; m * k];
        for i in 0..m {
            for j in 0..k {
                u[i * k + j] = vth[j * m + i];
            }
        }

        // Vh = (U')^T: U' is [n, k] row-major -> Vh is [k, n] row-major.
        let mut vh = vec![0.0f32; k * n];
        for i in 0..k {
            for j in 0..n {
                vh[i * n + j] = ut[j * k + i];
            }
        }

        return Ok((u, s, vh));
    }

    let k = m.min(n);
    let stream = device.stream();

    // Create cuSOLVER handle.
    let dn = cusolver_safe::DnHandle::new(stream.clone())?;

    // Transpose input from row-major to column-major.
    let col_major = transpose_f32(data, m, n);
    let mut d_a = upload_f32(&col_major, device)?;

    // Allocate output buffers on device.
    let mut d_s = crate::transfer::alloc_zeros_f32(k, device)?;
    let mut d_u = crate::transfer::alloc_zeros_f32(m * k, device)?;
    let mut d_vt = crate::transfer::alloc_zeros_f32(k * n, device)?;
    let mut d_info = alloc_zeros_i32(1, device)?;

    // Query workspace size.
    let mut lwork: i32 = 0;
    // SAFETY: dn.cu() is a valid cusolverDnHandle_t, m/n are valid dimensions.
    unsafe {
        csys::cusolverDnSgesvd_bufferSize(
            dn.cu(),
            m as i32,
            n as i32,
            &mut lwork,
        )
        .result()?;
    }

    let mut d_work = crate::transfer::alloc_zeros_f32(lwork.max(1) as usize, device)?;

    // cuSOLVER Sgesvd: jobu='S' (thin U), jobvt='S' (thin VT).
    // SAFETY: All device pointers are valid allocations of the required sizes.
    // The handle and stream are valid. We synchronize and check devInfo after.
    unsafe {
        let (a_ptr, _a_sync) = d_a.inner_mut().device_ptr_mut(&stream);
        let (s_ptr, _s_sync) = d_s.inner_mut().device_ptr_mut(&stream);
        let (u_ptr, _u_sync) = d_u.inner_mut().device_ptr_mut(&stream);
        let (vt_ptr, _vt_sync) = d_vt.inner_mut().device_ptr_mut(&stream);
        let (work_ptr, _work_sync) = d_work.inner_mut().device_ptr_mut(&stream);
        let (info_ptr, _info_sync) = d_info.inner_mut().device_ptr_mut(&stream);

        csys::cusolverDnSgesvd(
            dn.cu(),
            b'S' as i8,   // jobu: thin U
            b'S' as i8,   // jobvt: thin VT
            m as i32,
            n as i32,
            a_ptr as *mut f32,
            m as i32,      // lda = m (column-major)
            s_ptr as *mut f32,
            u_ptr as *mut f32,
            m as i32,      // ldu = m
            vt_ptr as *mut f32,
            k as i32,      // ldvt = k (for thin SVD)
            work_ptr as *mut f32,
            lwork,
            std::ptr::null_mut(), // rwork (unused for real)
            info_ptr as *mut i32,
        )
        .result()?;
    }

    stream.synchronize()?;

    // Check devInfo: non-zero means the factorization did not converge.
    // NOTE(review): ShapeMismatch is reused here to surface solver failures;
    // a dedicated error variant would be clearer.
    let info_val = read_dev_info(&d_info, device)?;
    if info_val != 0 {
        return Err(GpuError::ShapeMismatch {
            op: "gpu_svd_f32",
            expected: vec![0],
            got: vec![info_val as usize],
        });
    }

    // Download results and transpose from column-major back to row-major.
    let s_host = download_f32(&d_s, device)?;

    // U is m-by-k column-major -> transpose to k columns, m rows row-major.
    let u_col = download_f32(&d_u, device)?;
    // Column-major m-by-k means the data is laid out as k columns of m elements.
    // To convert to row-major m-by-k: out[i*k + j] = col[j*m + i].
    let mut u_host = vec![0.0f32; m * k];
    for i in 0..m {
        for j in 0..k {
            u_host[i * k + j] = u_col[j * m + i];
        }
    }

    // VT is k-by-n column-major -> convert to row-major k-by-n.
    let vt_col = download_f32(&d_vt, device)?;
    let mut vt_host = vec![0.0f32; k * n];
    for i in 0..k {
        for j in 0..n {
            vt_host[i * n + j] = vt_col[j * k + i];
        }
    }

    Ok((u_host, s_host, vt_host))
}
225
/// Compute the thin SVD of an m-by-n matrix (row-major f64).
///
/// Returns `(U, S, Vh)` as flat row-major `Vec<f64>` with shapes:
/// - U:  [m, k]  where k = min(m, n)
/// - S:  [k]
/// - Vh: [k, n]
///
/// cuSOLVER's `Dgesvd` operates on column-major data and produces
/// column-major U and VT. We transpose on input and output.
///
/// `Dgesvd` only supports m >= n; wide matrices (m < n) are handled by
/// factoring the transpose and swapping/transposing the resulting factors.
///
/// # Errors
/// Returns an error if `data.len() != m * n`, if any CUDA call fails, or if
/// the factorization does not converge (non-zero `devInfo`).
#[cfg(feature = "cuda")]
pub fn gpu_svd_f64(
    data: &[f64],
    m: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<(Vec<f64>, Vec<f64>, Vec<f64>)> {
    use cudarc::cusolver::sys as csys;

    if m == 0 || n == 0 {
        return Ok((vec![], vec![], vec![]));
    }

    // Validate input length up front so a short slice yields an error
    // instead of an index panic inside the transpose below.
    if data.len() != m * n {
        return Err(GpuError::ShapeMismatch {
            op: "gpu_svd_f64",
            expected: vec![m, n],
            got: vec![data.len()],
        });
    }

    // cuSOLVER's gesvd requires m >= n. For a wide matrix, factor the
    // transpose instead: A^T = U' S V'h  =>  A = (V'h)^T S (U')^T,
    // so U = (V'h)^T and Vh = (U')^T.
    if m < n {
        // A^T in row-major layout is exactly A in column-major layout.
        let at = transpose_f64(data, m, n);
        let (ut, s, vth) = gpu_svd_f64(&at, n, m, device)?;
        let k = m; // k = min(m, n) = m because m < n here

        // U = (V'h)^T: V'h is [k, m] row-major -> U is [m, k] row-major.
        let mut u = vec![0.0f64; m * k];
        for i in 0..m {
            for j in 0..k {
                u[i * k + j] = vth[j * m + i];
            }
        }

        // Vh = (U')^T: U' is [n, k] row-major -> Vh is [k, n] row-major.
        let mut vh = vec![0.0f64; k * n];
        for i in 0..k {
            for j in 0..n {
                vh[i * n + j] = ut[j * k + i];
            }
        }

        return Ok((u, s, vh));
    }

    let k = m.min(n);
    let stream = device.stream();

    // Create cuSOLVER handle.
    let dn = cusolver_safe::DnHandle::new(stream.clone())?;

    // Transpose input from row-major to column-major.
    let col_major = transpose_f64(data, m, n);
    let mut d_a = upload_f64(&col_major, device)?;

    // Allocate output buffers on device.
    let mut d_s = crate::transfer::alloc_zeros_f64(k, device)?;
    let mut d_u = crate::transfer::alloc_zeros_f64(m * k, device)?;
    let mut d_vt = crate::transfer::alloc_zeros_f64(k * n, device)?;
    let mut d_info = alloc_zeros_i32(1, device)?;

    // Query workspace size.
    let mut lwork: i32 = 0;
    // SAFETY: dn.cu() is a valid cusolverDnHandle_t, m/n are valid dimensions.
    unsafe {
        csys::cusolverDnDgesvd_bufferSize(
            dn.cu(),
            m as i32,
            n as i32,
            &mut lwork,
        )
        .result()?;
    }

    let mut d_work = crate::transfer::alloc_zeros_f64(lwork.max(1) as usize, device)?;

    // cuSOLVER Dgesvd: jobu='S' (thin U), jobvt='S' (thin VT).
    // SAFETY: All device pointers are valid allocations of the required sizes.
    // The handle and stream are valid. We synchronize and check devInfo after.
    unsafe {
        let (a_ptr, _a_sync) = d_a.inner_mut().device_ptr_mut(&stream);
        let (s_ptr, _s_sync) = d_s.inner_mut().device_ptr_mut(&stream);
        let (u_ptr, _u_sync) = d_u.inner_mut().device_ptr_mut(&stream);
        let (vt_ptr, _vt_sync) = d_vt.inner_mut().device_ptr_mut(&stream);
        let (work_ptr, _work_sync) = d_work.inner_mut().device_ptr_mut(&stream);
        let (info_ptr, _info_sync) = d_info.inner_mut().device_ptr_mut(&stream);

        csys::cusolverDnDgesvd(
            dn.cu(),
            b'S' as i8,   // jobu: thin U
            b'S' as i8,   // jobvt: thin VT
            m as i32,
            n as i32,
            a_ptr as *mut f64,
            m as i32,      // lda = m (column-major)
            s_ptr as *mut f64,
            u_ptr as *mut f64,
            m as i32,      // ldu = m
            vt_ptr as *mut f64,
            k as i32,      // ldvt = k (for thin SVD)
            work_ptr as *mut f64,
            lwork,
            std::ptr::null_mut(), // rwork (unused for real)
            info_ptr as *mut i32,
        )
        .result()?;
    }

    stream.synchronize()?;

    // Check devInfo: non-zero means the factorization did not converge.
    // NOTE(review): ShapeMismatch is reused here to surface solver failures;
    // a dedicated error variant would be clearer.
    let info_val = read_dev_info(&d_info, device)?;
    if info_val != 0 {
        return Err(GpuError::ShapeMismatch {
            op: "gpu_svd_f64",
            expected: vec![0],
            got: vec![info_val as usize],
        });
    }

    // Download results and transpose from column-major back to row-major.
    let s_host = download_f64(&d_s, device)?;

    // U is m-by-k column-major -> transpose to k columns, m rows row-major.
    let u_col = download_f64(&d_u, device)?;
    // Column-major m-by-k means the data is laid out as k columns of m elements.
    // To convert to row-major m-by-k: out[i*k + j] = col[j*m + i].
    let mut u_host = vec![0.0f64; m * k];
    for i in 0..m {
        for j in 0..k {
            u_host[i * k + j] = u_col[j * m + i];
        }
    }

    // VT is k-by-n column-major -> convert to row-major k-by-n.
    let vt_col = download_f64(&d_vt, device)?;
    let mut vt_host = vec![0.0f64; k * n];
    for i in 0..k {
        for j in 0..n {
            vt_host[i * n + j] = vt_col[j * k + i];
        }
    }

    Ok((u_host, s_host, vt_host))
}
348
349// ---------------------------------------------------------------------------
350// Cholesky: A = L * L^T   (lower-triangular)
351// ---------------------------------------------------------------------------
352
/// Compute the Cholesky decomposition of an n-by-n SPD matrix (row-major f32).
///
/// Returns the lower-triangular factor L as a flat row-major `Vec<f32>` [n, n]
/// with all upper-triangular entries zero.
///
/// # Errors
/// Returns an error if `data.len() != n * n`, if any CUDA call fails, or if
/// the matrix is not positive-definite (non-zero `devInfo` from Spotrf).
#[cfg(feature = "cuda")]
pub fn gpu_cholesky_f32(
    data: &[f32],
    n: usize,
    device: &GpuDevice,
) -> GpuResult<Vec<f32>> {
    use cudarc::cusolver::sys as csys;

    if n == 0 {
        return Ok(vec![]);
    }

    // Validate input length up front so a short slice yields an error
    // instead of an index panic inside the transpose below.
    if data.len() != n * n {
        return Err(GpuError::ShapeMismatch {
            op: "gpu_cholesky_f32",
            expected: vec![n, n],
            got: vec![data.len()],
        });
    }

    let stream = device.stream();
    let dn = cusolver_safe::DnHandle::new(stream.clone())?;

    // Transpose to column-major (cuSOLVER is LAPACK-style column-major).
    let col_major = transpose_f32(data, n, n);
    let mut d_a = upload_f32(&col_major, device)?;
    let mut d_info = alloc_zeros_i32(1, device)?;

    // Query workspace size.
    let mut lwork: i32 = 0;
    // SAFETY: dn.cu() is valid, d_a points to n*n f32 elements.
    unsafe {
        let (a_ptr, _a_sync) = d_a.inner_mut().device_ptr_mut(&stream);
        csys::cusolverDnSpotrf_bufferSize(
            dn.cu(),
            csys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
            n as i32,
            a_ptr as *mut f32,
            n as i32,
            &mut lwork,
        )
        .result()?;
    }

    let mut d_work = crate::transfer::alloc_zeros_f32(lwork.max(1) as usize, device)?;

    // Factorize in-place: afterwards the lower triangle of d_a holds L.
    // SAFETY: All device pointers are valid. We use LOWER fill mode.
    // devInfo is checked after synchronization.
    unsafe {
        let (a_ptr, _a_sync) = d_a.inner_mut().device_ptr_mut(&stream);
        let (work_ptr, _work_sync) = d_work.inner_mut().device_ptr_mut(&stream);
        let (info_ptr, _info_sync) = d_info.inner_mut().device_ptr_mut(&stream);

        csys::cusolverDnSpotrf(
            dn.cu(),
            csys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
            n as i32,
            a_ptr as *mut f32,
            n as i32,
            work_ptr as *mut f32,
            lwork,
            info_ptr as *mut i32,
        )
        .result()?;
    }

    stream.synchronize()?;

    // Non-zero devInfo from potrf indicates the matrix is not SPD.
    let info_val = read_dev_info(&d_info, device)?;
    if info_val != 0 {
        return Err(GpuError::ShapeMismatch {
            op: "gpu_cholesky_f32: matrix is not positive-definite",
            expected: vec![0],
            got: vec![info_val as usize],
        });
    }

    // Download the column-major result and copy only the lower triangle into
    // a row-major buffer. The upper triangle keeps its zero initialization,
    // which also discards the untouched input values cuSOLVER leaves above
    // the diagonal (no separate zeroing pass needed).
    let l_col = download_f32(&d_a, device)?;
    let mut l_host = vec![0.0f32; n * n];
    for i in 0..n {
        for j in 0..=i {
            l_host[i * n + j] = l_col[j * n + i];
        }
    }

    Ok(l_host)
}
445
/// Compute the Cholesky decomposition of an n-by-n SPD matrix (row-major f64).
///
/// Returns the lower-triangular factor L as a flat row-major `Vec<f64>` [n, n]
/// with all upper-triangular entries zero.
///
/// # Errors
/// Returns an error if `data.len() != n * n`, if any CUDA call fails, or if
/// the matrix is not positive-definite (non-zero `devInfo` from Dpotrf).
#[cfg(feature = "cuda")]
pub fn gpu_cholesky_f64(
    data: &[f64],
    n: usize,
    device: &GpuDevice,
) -> GpuResult<Vec<f64>> {
    use cudarc::cusolver::sys as csys;

    if n == 0 {
        return Ok(vec![]);
    }

    // Validate input length up front so a short slice yields an error
    // instead of an index panic inside the transpose below.
    if data.len() != n * n {
        return Err(GpuError::ShapeMismatch {
            op: "gpu_cholesky_f64",
            expected: vec![n, n],
            got: vec![data.len()],
        });
    }

    let stream = device.stream();
    let dn = cusolver_safe::DnHandle::new(stream.clone())?;

    // Transpose to column-major (cuSOLVER is LAPACK-style column-major).
    let col_major = transpose_f64(data, n, n);
    let mut d_a = upload_f64(&col_major, device)?;
    let mut d_info = alloc_zeros_i32(1, device)?;

    // Query workspace size.
    let mut lwork: i32 = 0;
    // SAFETY: dn.cu() is valid, d_a points to n*n f64 elements.
    unsafe {
        let (a_ptr, _a_sync) = d_a.inner_mut().device_ptr_mut(&stream);
        csys::cusolverDnDpotrf_bufferSize(
            dn.cu(),
            csys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
            n as i32,
            a_ptr as *mut f64,
            n as i32,
            &mut lwork,
        )
        .result()?;
    }

    let mut d_work = crate::transfer::alloc_zeros_f64(lwork.max(1) as usize, device)?;

    // Factorize in-place: afterwards the lower triangle of d_a holds L.
    // SAFETY: All device pointers are valid. We use LOWER fill mode.
    // devInfo is checked after synchronization.
    unsafe {
        let (a_ptr, _a_sync) = d_a.inner_mut().device_ptr_mut(&stream);
        let (work_ptr, _work_sync) = d_work.inner_mut().device_ptr_mut(&stream);
        let (info_ptr, _info_sync) = d_info.inner_mut().device_ptr_mut(&stream);

        csys::cusolverDnDpotrf(
            dn.cu(),
            csys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
            n as i32,
            a_ptr as *mut f64,
            n as i32,
            work_ptr as *mut f64,
            lwork,
            info_ptr as *mut i32,
        )
        .result()?;
    }

    stream.synchronize()?;

    // Non-zero devInfo from potrf indicates the matrix is not SPD.
    let info_val = read_dev_info(&d_info, device)?;
    if info_val != 0 {
        return Err(GpuError::ShapeMismatch {
            op: "gpu_cholesky_f64: matrix is not positive-definite",
            expected: vec![0],
            got: vec![info_val as usize],
        });
    }

    // Download the column-major result and copy only the lower triangle into
    // a row-major buffer. The upper triangle keeps its zero initialization,
    // which also discards the untouched input values cuSOLVER leaves above
    // the diagonal (no separate zeroing pass needed).
    let l_col = download_f64(&d_a, device)?;
    let mut l_host = vec![0.0f64; n * n];
    for i in 0..n {
        for j in 0..=i {
            l_host[i * n + j] = l_col[j * n + i];
        }
    }

    Ok(l_host)
}
538
539// ---------------------------------------------------------------------------
540// Solve: A * X = B   (via LU factorization: getrf + getrs)
541// ---------------------------------------------------------------------------
542
/// Solve A * X = B for X where A is n-by-n and B is n-by-nrhs (row-major f32).
///
/// Uses LU factorization (Sgetrf) followed by triangular solve (Sgetrs).
///
/// Returns X as flat row-major `Vec<f32>` with shape [n, nrhs] (or [n] if nrhs==1).
///
/// # Errors
/// Returns an error if the slice lengths do not match the given dimensions,
/// if any CUDA call fails, or if A is singular (non-zero `devInfo` from Sgetrf).
#[cfg(feature = "cuda")]
pub fn gpu_solve_f32(
    a_data: &[f32],
    b_data: &[f32],
    n: usize,
    nrhs: usize,
    device: &GpuDevice,
) -> GpuResult<Vec<f32>> {
    use cudarc::cusolver::sys as csys;

    if n == 0 {
        return Ok(vec![]);
    }

    // Validate input lengths up front so short slices yield an error
    // instead of an index panic inside the transposes below.
    if a_data.len() != n * n || b_data.len() != n * nrhs {
        return Err(GpuError::ShapeMismatch {
            op: "gpu_solve_f32",
            expected: vec![n * n, n * nrhs],
            got: vec![a_data.len(), b_data.len()],
        });
    }

    let stream = device.stream();
    let dn = cusolver_safe::DnHandle::new(stream.clone())?;

    // Convert A to column-major.
    let a_col = transpose_f32(a_data, n, n);
    let mut d_a = upload_f32(&a_col, device)?;

    // Convert B to column-major (n-by-nrhs).
    let b_col = transpose_f32(b_data, n, nrhs);
    let mut d_b = upload_f32(&b_col, device)?;

    // Pivot indices (ipiv) and devInfo for getrf.
    let mut d_ipiv = alloc_zeros_i32(n, device)?;
    let mut d_info = alloc_zeros_i32(1, device)?;

    // Query workspace for getrf.
    let mut lwork: i32 = 0;
    // SAFETY: dn.cu() is valid, d_a contains n*n f32 elements.
    unsafe {
        let (a_ptr, _a_sync) = d_a.inner_mut().device_ptr_mut(&stream);
        csys::cusolverDnSgetrf_bufferSize(
            dn.cu(),
            n as i32,
            n as i32,
            a_ptr as *mut f32,
            n as i32,
            &mut lwork,
        )
        .result()?;
    }

    let mut d_work = crate::transfer::alloc_zeros_f32(lwork.max(1) as usize, device)?;

    // LU factorization: A = P * L * U (in-place in d_a).
    // SAFETY: All device pointers are valid allocations of the required sizes.
    unsafe {
        let (a_ptr, _a_sync) = d_a.inner_mut().device_ptr_mut(&stream);
        let (work_ptr, _work_sync) = d_work.inner_mut().device_ptr_mut(&stream);
        let (ipiv_ptr, _ipiv_sync) = d_ipiv.inner_mut().device_ptr_mut(&stream);
        let (info_ptr, _info_sync) = d_info.inner_mut().device_ptr_mut(&stream);

        csys::cusolverDnSgetrf(
            dn.cu(),
            n as i32,
            n as i32,
            a_ptr as *mut f32,
            n as i32,
            work_ptr as *mut f32,
            ipiv_ptr as *mut i32,
            info_ptr as *mut i32,
        )
        .result()?;
    }

    stream.synchronize()?;

    // Non-zero devInfo from getrf indicates a zero pivot, i.e. A is singular.
    let info_val = read_dev_info(&d_info, device)?;
    if info_val != 0 {
        return Err(GpuError::ShapeMismatch {
            op: "gpu_solve_f32: LU factorization failed (singular matrix)",
            expected: vec![0],
            got: vec![info_val as usize],
        });
    }

    // Triangular solve: L * U * X = P * B.
    // Fresh devInfo for getrs (getrf already consumed the first one).
    let mut d_info2 = alloc_zeros_i32(1, device)?;

    // SAFETY: d_a now contains the LU factors, d_ipiv the pivot indices,
    // d_b will be overwritten with the solution X. All are properly sized.
    unsafe {
        let (a_ptr, _a_sync) = d_a.inner().device_ptr(&stream);
        let (ipiv_ptr, _ipiv_sync) = d_ipiv.inner().device_ptr(&stream);
        let (b_ptr, _b_sync) = d_b.inner_mut().device_ptr_mut(&stream);
        let (info_ptr, _info_sync) = d_info2.inner_mut().device_ptr_mut(&stream);

        csys::cusolverDnSgetrs(
            dn.cu(),
            csys::cublasOperation_t::CUBLAS_OP_N, // no transpose
            n as i32,
            nrhs as i32,
            a_ptr as *const f32,
            n as i32,
            ipiv_ptr as *const i32,
            b_ptr as *mut f32,
            n as i32,  // ldb = n
            info_ptr as *mut i32,
        )
        .result()?;
    }

    stream.synchronize()?;

    let info_val2 = read_dev_info(&d_info2, device)?;
    if info_val2 != 0 {
        return Err(GpuError::ShapeMismatch {
            op: "gpu_solve_f32: triangular solve failed",
            expected: vec![0],
            got: vec![info_val2 as usize],
        });
    }

    // Download solution (column-major) and convert to row-major.
    let x_col = download_f32(&d_b, device)?;
    let mut x_host = vec![0.0f32; n * nrhs];
    for i in 0..n {
        for j in 0..nrhs {
            x_host[i * nrhs + j] = x_col[j * n + i];
        }
    }

    Ok(x_host)
}
675
/// Solve A * X = B for X where A is n-by-n and B is n-by-nrhs (row-major f64).
///
/// Uses LU factorization (Dgetrf) followed by triangular solve (Dgetrs).
///
/// Returns X as flat row-major `Vec<f64>` with shape [n, nrhs] (or [n] if nrhs==1).
///
/// # Errors
/// Returns an error if the slice lengths do not match the given dimensions,
/// if any CUDA call fails, or if A is singular (non-zero `devInfo` from Dgetrf).
#[cfg(feature = "cuda")]
pub fn gpu_solve_f64(
    a_data: &[f64],
    b_data: &[f64],
    n: usize,
    nrhs: usize,
    device: &GpuDevice,
) -> GpuResult<Vec<f64>> {
    use cudarc::cusolver::sys as csys;

    if n == 0 {
        return Ok(vec![]);
    }

    // Validate input lengths up front so short slices yield an error
    // instead of an index panic inside the transposes below.
    if a_data.len() != n * n || b_data.len() != n * nrhs {
        return Err(GpuError::ShapeMismatch {
            op: "gpu_solve_f64",
            expected: vec![n * n, n * nrhs],
            got: vec![a_data.len(), b_data.len()],
        });
    }

    let stream = device.stream();
    let dn = cusolver_safe::DnHandle::new(stream.clone())?;

    // Convert A to column-major.
    let a_col = transpose_f64(a_data, n, n);
    let mut d_a = upload_f64(&a_col, device)?;

    // Convert B to column-major (n-by-nrhs).
    let b_col = transpose_f64(b_data, n, nrhs);
    let mut d_b = upload_f64(&b_col, device)?;

    // Pivot indices (ipiv) and devInfo for getrf.
    let mut d_ipiv = alloc_zeros_i32(n, device)?;
    let mut d_info = alloc_zeros_i32(1, device)?;

    // Query workspace for getrf.
    let mut lwork: i32 = 0;
    // SAFETY: dn.cu() is valid, d_a contains n*n f64 elements.
    unsafe {
        let (a_ptr, _a_sync) = d_a.inner_mut().device_ptr_mut(&stream);
        csys::cusolverDnDgetrf_bufferSize(
            dn.cu(),
            n as i32,
            n as i32,
            a_ptr as *mut f64,
            n as i32,
            &mut lwork,
        )
        .result()?;
    }

    let mut d_work = crate::transfer::alloc_zeros_f64(lwork.max(1) as usize, device)?;

    // LU factorization: A = P * L * U (in-place in d_a).
    // SAFETY: All device pointers are valid allocations of the required sizes.
    unsafe {
        let (a_ptr, _a_sync) = d_a.inner_mut().device_ptr_mut(&stream);
        let (work_ptr, _work_sync) = d_work.inner_mut().device_ptr_mut(&stream);
        let (ipiv_ptr, _ipiv_sync) = d_ipiv.inner_mut().device_ptr_mut(&stream);
        let (info_ptr, _info_sync) = d_info.inner_mut().device_ptr_mut(&stream);

        csys::cusolverDnDgetrf(
            dn.cu(),
            n as i32,
            n as i32,
            a_ptr as *mut f64,
            n as i32,
            work_ptr as *mut f64,
            ipiv_ptr as *mut i32,
            info_ptr as *mut i32,
        )
        .result()?;
    }

    stream.synchronize()?;

    // Non-zero devInfo from getrf indicates a zero pivot, i.e. A is singular.
    let info_val = read_dev_info(&d_info, device)?;
    if info_val != 0 {
        return Err(GpuError::ShapeMismatch {
            op: "gpu_solve_f64: LU factorization failed (singular matrix)",
            expected: vec![0],
            got: vec![info_val as usize],
        });
    }

    // Triangular solve: L * U * X = P * B.
    // Fresh devInfo for getrs (getrf already consumed the first one).
    let mut d_info2 = alloc_zeros_i32(1, device)?;

    // SAFETY: d_a now contains the LU factors, d_ipiv the pivot indices,
    // d_b will be overwritten with the solution X. All are properly sized.
    unsafe {
        let (a_ptr, _a_sync) = d_a.inner().device_ptr(&stream);
        let (ipiv_ptr, _ipiv_sync) = d_ipiv.inner().device_ptr(&stream);
        let (b_ptr, _b_sync) = d_b.inner_mut().device_ptr_mut(&stream);
        let (info_ptr, _info_sync) = d_info2.inner_mut().device_ptr_mut(&stream);

        csys::cusolverDnDgetrs(
            dn.cu(),
            csys::cublasOperation_t::CUBLAS_OP_N, // no transpose
            n as i32,
            nrhs as i32,
            a_ptr as *const f64,
            n as i32,
            ipiv_ptr as *const i32,
            b_ptr as *mut f64,
            n as i32,  // ldb = n
            info_ptr as *mut i32,
        )
        .result()?;
    }

    stream.synchronize()?;

    let info_val2 = read_dev_info(&d_info2, device)?;
    if info_val2 != 0 {
        return Err(GpuError::ShapeMismatch {
            op: "gpu_solve_f64: triangular solve failed",
            expected: vec![0],
            got: vec![info_val2 as usize],
        });
    }

    // Download solution (column-major) and convert to row-major.
    let x_col = download_f64(&d_b, device)?;
    let mut x_host = vec![0.0f64; n * nrhs];
    for i in 0..n {
        for j in 0..nrhs {
            x_host[i * nrhs + j] = x_col[j * n + i];
        }
    }

    Ok(x_host)
}
808
809// ---------------------------------------------------------------------------
810// QR: A = Q * R   (reduced/thin)
811// ---------------------------------------------------------------------------
812
813/// Compute the reduced QR decomposition of an m-by-n matrix (row-major f32).
814///
815/// Returns `(Q, R)` as flat row-major `Vec<f32>` with shapes:
816/// - Q: [m, k]  where k = min(m, n)
817/// - R: [k, n]
818///
819/// Uses Sgeqrf (Householder QR) followed by Sorgqr (generate Q).
820#[cfg(feature = "cuda")]
821pub fn gpu_qr_f32(
822    data: &[f32],
823    m: usize,
824    n: usize,
825    device: &GpuDevice,
826) -> GpuResult<(Vec<f32>, Vec<f32>)> {
827    use cudarc::cusolver::sys as csys;
828
829    if m == 0 || n == 0 {
830        return Ok((vec![], vec![]));
831    }
832
833    let k = m.min(n);
834    let stream = device.stream();
835    let dn = cusolver_safe::DnHandle::new(stream.clone())?;
836
837    // Transpose to column-major.
838    let col_major = transpose_f32(data, m, n);
839    let mut d_a = upload_f32(&col_major, device)?;
840    let mut d_tau = crate::transfer::alloc_zeros_f32(k, device)?;
841    let mut d_info = alloc_zeros_i32(1, device)?;
842
843    // Query workspace for geqrf.
844    let mut lwork: i32 = 0;
845    // SAFETY: dn.cu() is valid, d_a contains m*n f32 elements.
846    unsafe {
847        let (a_ptr, _a_sync) = d_a.inner_mut().device_ptr_mut(&stream);
848        csys::cusolverDnSgeqrf_bufferSize(
849            dn.cu(),
850            m as i32,
851            n as i32,
852            a_ptr as *mut f32,
853            m as i32,
854            &mut lwork,
855        )
856        .result()?;
857    }
858
859    let mut d_work = crate::transfer::alloc_zeros_f32(lwork.max(1) as usize, device)?;
860
861    // Compute QR factorization (Householder form).
862    // SAFETY: All device pointers are valid. d_a is overwritten in-place
863    // with Householder reflectors (lower triangle) and R (upper triangle).
864    unsafe {
865        let (a_ptr, _a_sync) = d_a.inner_mut().device_ptr_mut(&stream);
866        let (tau_ptr, _tau_sync) = d_tau.inner_mut().device_ptr_mut(&stream);
867        let (work_ptr, _work_sync) = d_work.inner_mut().device_ptr_mut(&stream);
868        let (info_ptr, _info_sync) = d_info.inner_mut().device_ptr_mut(&stream);
869
870        csys::cusolverDnSgeqrf(
871            dn.cu(),
872            m as i32,
873            n as i32,
874            a_ptr as *mut f32,
875            m as i32,
876            tau_ptr as *mut f32,
877            work_ptr as *mut f32,
878            lwork,
879            info_ptr as *mut i32,
880        )
881        .result()?;
882    }
883
884    stream.synchronize()?;
885
886    let info_val = read_dev_info(&d_info, device)?;
887    if info_val != 0 {
888        return Err(GpuError::ShapeMismatch {
889            op: "gpu_qr_f32: geqrf failed",
890            expected: vec![0],
891            got: vec![info_val as usize],
892        });
893    }
894
895    // Extract R from the upper triangle of d_a (column-major).
896    // R is k-by-n. We read the full m-by-n column-major buffer.
897    let qr_col = download_f32(&d_a, device)?;
898    let mut r_host = vec![0.0f32; k * n];
899    for i in 0..k {
900        for j in 0..n {
901            // In column-major m-by-n: element (i, j) is at index j*m + i.
902            if j >= i {
903                r_host[i * n + j] = qr_col[j * m + i]; // row-major output
904            }
905            // else: R[i,j] = 0 (already initialized)
906        }
907    }
908
909    // Generate explicit Q via Sorgqr.
910    // Sorgqr overwrites d_a in-place: the first k columns become Q (m-by-k, column-major).
911    let mut lwork_orgqr: i32 = 0;
912
913    // SAFETY: dn.cu() is valid, d_a and d_tau contain valid QR factorization data.
914    unsafe {
915        let (a_ptr, _a_sync) = d_a.inner().device_ptr(&stream);
916        let (tau_ptr, _tau_sync) = d_tau.inner().device_ptr(&stream);
917        csys::cusolverDnSorgqr_bufferSize(
918            dn.cu(),
919            m as i32,
920            k as i32,
921            k as i32,
922            a_ptr as *const f32,
923            m as i32,
924            tau_ptr as *const f32,
925            &mut lwork_orgqr,
926        )
927        .result()?;
928    }
929
930    let mut d_work2 = crate::transfer::alloc_zeros_f32(lwork_orgqr.max(1) as usize, device)?;
931    let mut d_info2 = alloc_zeros_i32(1, device)?;
932
933    // SAFETY: d_a contains the Householder reflectors from geqrf, d_tau the
934    // scalar factors. Sorgqr overwrites the first k columns of d_a with Q.
935    unsafe {
936        let (a_ptr, _a_sync) = d_a.inner_mut().device_ptr_mut(&stream);
937        let (tau_ptr, _tau_sync) = d_tau.inner().device_ptr(&stream);
938        let (work_ptr, _work_sync) = d_work2.inner_mut().device_ptr_mut(&stream);
939        let (info_ptr, _info_sync) = d_info2.inner_mut().device_ptr_mut(&stream);
940
941        csys::cusolverDnSorgqr(
942            dn.cu(),
943            m as i32,
944            k as i32,
945            k as i32,
946            a_ptr as *mut f32,
947            m as i32,
948            tau_ptr as *const f32,
949            work_ptr as *mut f32,
950            lwork_orgqr,
951            info_ptr as *mut i32,
952        )
953        .result()?;
954    }
955
956    stream.synchronize()?;
957
958    let info_val2 = read_dev_info(&d_info2, device)?;
959    if info_val2 != 0 {
960        return Err(GpuError::ShapeMismatch {
961            op: "gpu_qr_f32: orgqr failed",
962            expected: vec![0],
963            got: vec![info_val2 as usize],
964        });
965    }
966
967    // Download Q (m-by-k column-major from d_a, but d_a has n columns total;
968    // we only need the first k columns).
969    let q_full_col = download_f32(&d_a, device)?;
970    let mut q_host = vec![0.0f32; m * k];
971    for i in 0..m {
972        for j in 0..k {
973            q_host[i * k + j] = q_full_col[j * m + i]; // col-major -> row-major
974        }
975    }
976
977    Ok((q_host, r_host))
978}
979
/// Compute the reduced QR decomposition of an m-by-n matrix (row-major f64).
///
/// Returns `(Q, R)` as flat row-major `Vec<f64>` with shapes:
/// - Q: [m, k]  where k = min(m, n)
/// - R: [k, n]
///
/// Uses Dgeqrf (Householder QR) followed by Dorgqr (generate Q).
///
/// The input is transposed to column-major before the cuSOLVER calls and
/// both outputs are transposed back to row-major, per the module contract.
///
/// # Errors
///
/// Returns an error if any cuSOLVER call fails, or if `devInfo` is
/// non-zero after `Dgeqrf`/`Dorgqr` (reported via `GpuError::ShapeMismatch`
/// with the raw `devInfo` value in `got`).
#[cfg(feature = "cuda")]
pub fn gpu_qr_f64(
    data: &[f64],
    m: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<(Vec<f64>, Vec<f64>)> {
    use cudarc::cusolver::sys as csys;

    // Empty matrix: nothing to factor; both outputs are empty.
    if m == 0 || n == 0 {
        return Ok((vec![], vec![]));
    }

    let k = m.min(n);
    let stream = device.stream();
    let dn = cusolver_safe::DnHandle::new(stream.clone())?;

    // Transpose to column-major.
    let col_major = transpose_f64(data, m, n);
    let mut d_a = upload_f64(&col_major, device)?;
    // tau holds the k Householder scalar factors produced by geqrf.
    let mut d_tau = crate::transfer::alloc_zeros_f64(k, device)?;
    // devInfo: single i32, written by cuSOLVER; 0 means success.
    let mut d_info = alloc_zeros_i32(1, device)?;

    // Query workspace for geqrf.
    let mut lwork: i32 = 0;
    // SAFETY: dn.cu() is valid, d_a contains m*n f64 elements.
    unsafe {
        let (a_ptr, _a_sync) = d_a.inner_mut().device_ptr_mut(&stream);
        csys::cusolverDnDgeqrf_bufferSize(
            dn.cu(),
            m as i32,
            n as i32,
            a_ptr as *mut f64,
            m as i32,
            &mut lwork,
        )
        .result()?;
    }

    // max(1) guards against a zero-sized allocation if lwork comes back 0.
    let mut d_work = crate::transfer::alloc_zeros_f64(lwork.max(1) as usize, device)?;

    // Compute QR factorization (Householder form).
    // SAFETY: All device pointers are valid. d_a is overwritten in-place
    // with Householder reflectors (lower triangle) and R (upper triangle).
    unsafe {
        let (a_ptr, _a_sync) = d_a.inner_mut().device_ptr_mut(&stream);
        let (tau_ptr, _tau_sync) = d_tau.inner_mut().device_ptr_mut(&stream);
        let (work_ptr, _work_sync) = d_work.inner_mut().device_ptr_mut(&stream);
        let (info_ptr, _info_sync) = d_info.inner_mut().device_ptr_mut(&stream);

        csys::cusolverDnDgeqrf(
            dn.cu(),
            m as i32,
            n as i32,
            a_ptr as *mut f64,
            m as i32,
            tau_ptr as *mut f64,
            work_ptr as *mut f64,
            lwork,
            info_ptr as *mut i32,
        )
        .result()?;
    }

    // Ensure geqrf has completed before reading devInfo / d_a on the host.
    stream.synchronize()?;

    let info_val = read_dev_info(&d_info, device)?;
    if info_val != 0 {
        // NOTE(review): a negative devInfo (invalid i-th parameter) cast
        // to usize yields a huge value in `got` — consider reporting the
        // raw i32 instead; confirm against the f32 sibling for consistency.
        return Err(GpuError::ShapeMismatch {
            op: "gpu_qr_f64: geqrf failed",
            expected: vec![0],
            got: vec![info_val as usize],
        });
    }

    // Extract R from the upper triangle of d_a (column-major).
    // R is k-by-n. We read the full m-by-n column-major buffer.
    // This must happen BEFORE orgqr, which destroys the upper triangle.
    let qr_col = download_f64(&d_a, device)?;
    let mut r_host = vec![0.0f64; k * n];
    for i in 0..k {
        for j in 0..n {
            // In column-major m-by-n: element (i, j) is at index j*m + i.
            if j >= i {
                r_host[i * n + j] = qr_col[j * m + i]; // row-major output
            }
            // else: R[i,j] = 0 (already initialized)
        }
    }

    // Generate explicit Q via Dorgqr.
    // Dorgqr overwrites d_a in-place: the first k columns become Q (m-by-k, column-major).
    let mut lwork_orgqr: i32 = 0;

    // SAFETY: dn.cu() is valid, d_a and d_tau contain valid QR factorization data.
    unsafe {
        let (a_ptr, _a_sync) = d_a.inner().device_ptr(&stream);
        let (tau_ptr, _tau_sync) = d_tau.inner().device_ptr(&stream);
        csys::cusolverDnDorgqr_bufferSize(
            dn.cu(),
            m as i32,
            k as i32,
            k as i32,
            a_ptr as *const f64,
            m as i32,
            tau_ptr as *const f64,
            &mut lwork_orgqr,
        )
        .result()?;
    }

    // orgqr may need a differently-sized workspace than geqrf; allocate fresh.
    let mut d_work2 = crate::transfer::alloc_zeros_f64(lwork_orgqr.max(1) as usize, device)?;
    let mut d_info2 = alloc_zeros_i32(1, device)?;

    // SAFETY: d_a contains the Householder reflectors from geqrf, d_tau the
    // scalar factors. Dorgqr overwrites the first k columns of d_a with Q.
    unsafe {
        let (a_ptr, _a_sync) = d_a.inner_mut().device_ptr_mut(&stream);
        let (tau_ptr, _tau_sync) = d_tau.inner().device_ptr(&stream);
        let (work_ptr, _work_sync) = d_work2.inner_mut().device_ptr_mut(&stream);
        let (info_ptr, _info_sync) = d_info2.inner_mut().device_ptr_mut(&stream);

        csys::cusolverDnDorgqr(
            dn.cu(),
            m as i32,
            k as i32,
            k as i32,
            a_ptr as *mut f64,
            m as i32,
            tau_ptr as *const f64,
            work_ptr as *mut f64,
            lwork_orgqr,
            info_ptr as *mut i32,
        )
        .result()?;
    }

    // Ensure orgqr has completed before reading devInfo / d_a on the host.
    stream.synchronize()?;

    let info_val2 = read_dev_info(&d_info2, device)?;
    if info_val2 != 0 {
        return Err(GpuError::ShapeMismatch {
            op: "gpu_qr_f64: orgqr failed",
            expected: vec![0],
            got: vec![info_val2 as usize],
        });
    }

    // Download Q (m-by-k column-major from d_a, but d_a has n columns total;
    // we only need the first k columns).
    let q_full_col = download_f64(&d_a, device)?;
    let mut q_host = vec![0.0f64; m * k];
    for i in 0..m {
        for j in 0..k {
            q_host[i * k + j] = q_full_col[j * m + i]; // col-major -> row-major
        }
    }

    Ok((q_host, r_host))
}
1146
1147// ---------------------------------------------------------------------------
1148// Stubs — always return [`GpuError::NoCudaFeature`] when `cuda` is disabled.
1149// ---------------------------------------------------------------------------
1150
1151/// Stub — always returns [`GpuError::NoCudaFeature`].
1152#[cfg(not(feature = "cuda"))]
1153pub fn gpu_svd_f32(
1154    _data: &[f32],
1155    _m: usize,
1156    _n: usize,
1157    _device: &GpuDevice,
1158) -> GpuResult<(Vec<f32>, Vec<f32>, Vec<f32>)> {
1159    Err(GpuError::NoCudaFeature)
1160}
1161
1162/// Stub — always returns [`GpuError::NoCudaFeature`].
1163#[cfg(not(feature = "cuda"))]
1164pub fn gpu_svd_f64(
1165    _data: &[f64],
1166    _m: usize,
1167    _n: usize,
1168    _device: &GpuDevice,
1169) -> GpuResult<(Vec<f64>, Vec<f64>, Vec<f64>)> {
1170    Err(GpuError::NoCudaFeature)
1171}
1172
1173/// Stub — always returns [`GpuError::NoCudaFeature`].
1174#[cfg(not(feature = "cuda"))]
1175pub fn gpu_cholesky_f32(
1176    _data: &[f32],
1177    _n: usize,
1178    _device: &GpuDevice,
1179) -> GpuResult<Vec<f32>> {
1180    Err(GpuError::NoCudaFeature)
1181}
1182
1183/// Stub — always returns [`GpuError::NoCudaFeature`].
1184#[cfg(not(feature = "cuda"))]
1185pub fn gpu_cholesky_f64(
1186    _data: &[f64],
1187    _n: usize,
1188    _device: &GpuDevice,
1189) -> GpuResult<Vec<f64>> {
1190    Err(GpuError::NoCudaFeature)
1191}
1192
1193/// Stub — always returns [`GpuError::NoCudaFeature`].
1194#[cfg(not(feature = "cuda"))]
1195pub fn gpu_solve_f32(
1196    _a_data: &[f32],
1197    _b_data: &[f32],
1198    _n: usize,
1199    _nrhs: usize,
1200    _device: &GpuDevice,
1201) -> GpuResult<Vec<f32>> {
1202    Err(GpuError::NoCudaFeature)
1203}
1204
1205/// Stub — always returns [`GpuError::NoCudaFeature`].
1206#[cfg(not(feature = "cuda"))]
1207pub fn gpu_solve_f64(
1208    _a_data: &[f64],
1209    _b_data: &[f64],
1210    _n: usize,
1211    _nrhs: usize,
1212    _device: &GpuDevice,
1213) -> GpuResult<Vec<f64>> {
1214    Err(GpuError::NoCudaFeature)
1215}
1216
1217/// Stub — always returns [`GpuError::NoCudaFeature`].
1218#[cfg(not(feature = "cuda"))]
1219pub fn gpu_qr_f32(
1220    _data: &[f32],
1221    _m: usize,
1222    _n: usize,
1223    _device: &GpuDevice,
1224) -> GpuResult<(Vec<f32>, Vec<f32>)> {
1225    Err(GpuError::NoCudaFeature)
1226}
1227
1228/// Stub — always returns [`GpuError::NoCudaFeature`].
1229#[cfg(not(feature = "cuda"))]
1230pub fn gpu_qr_f64(
1231    _data: &[f64],
1232    _m: usize,
1233    _n: usize,
1234    _device: &GpuDevice,
1235) -> GpuResult<(Vec<f64>, Vec<f64>)> {
1236    Err(GpuError::NoCudaFeature)
1237}
1238
1239// ---------------------------------------------------------------------------
1240// Tests
1241// ---------------------------------------------------------------------------
1242
#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests {
    // NOTE: these tests compile only with the `cuda` feature and require
    // a physical CUDA device at index 0. All matrices are flat row-major
    // slices, matching this module's public API; each decomposition is
    // verified on the host by reconstructing the input within a tolerance.
    use super::*;

    /// Open CUDA device 0, panicking if no device is present.
    fn device() -> GpuDevice {
        GpuDevice::new(0).expect("CUDA device 0")
    }

    // -- SVD tests --

    // U * diag(S) * VT must reproduce the original tall (3x2) matrix.
    #[test]
    fn svd_reconstructs_3x2() {
        let dev = device();
        // A = [[1, 2], [3, 4], [5, 6]]
        let a: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let (m, n) = (3, 2);
        let (u, s, vt) = gpu_svd_f32(&a, m, n, &dev).unwrap();
        let k = m.min(n);

        // Reduced SVD shapes: U is m-by-k, S has k entries, VT is k-by-n.
        assert_eq!(u.len(), m * k);
        assert_eq!(s.len(), k);
        assert_eq!(vt.len(), k * n);

        // Reconstruct: U * diag(S) * VT
        let mut recon = vec![0.0f32; m * n];
        for i in 0..m {
            for j in 0..n {
                let mut acc = 0.0f32;
                for p in 0..k {
                    acc += u[i * k + p] * s[p] * vt[p * n + j];
                }
                recon[i * n + j] = acc;
            }
        }

        for i in 0..m * n {
            assert!(
                (recon[i] - a[i]).abs() < 1e-3,
                "SVD reconstruction failed at {i}: {} vs {}",
                recon[i],
                a[i]
            );
        }
    }

    // Singular values come back sorted largest-first.
    #[test]
    fn svd_singular_values_descending() {
        let dev = device();
        let a: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let (_, s, _) = gpu_svd_f32(&a, 3, 2, &dev).unwrap();
        assert!(s[0] >= s[1], "singular values must be descending");
    }

    // The identity matrix has all singular values equal to 1.
    #[test]
    fn svd_square_identity() {
        let dev = device();
        let eye: Vec<f32> = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let (_, s, _) = gpu_svd_f32(&eye, 3, 3, &dev).unwrap();
        for &sv in &s {
            assert!((sv - 1.0).abs() < 1e-4, "identity SVD should have all ones");
        }
    }

    // A 0x0 input produces empty outputs rather than an error.
    #[test]
    fn svd_empty() {
        let dev = device();
        let (u, s, vt) = gpu_svd_f32(&[], 0, 0, &dev).unwrap();
        assert!(u.is_empty());
        assert!(s.is_empty());
        assert!(vt.is_empty());
    }

    // -- Cholesky tests --

    // Factor a symmetric positive-definite matrix; L must be lower
    // triangular and L * L^T must reproduce A.
    #[test]
    fn cholesky_spd_3x3() {
        let dev = device();
        // SPD matrix: A = [[6,5,1],[5,12,5],[1,5,6]]
        #[rustfmt::skip]
        let a: Vec<f32> = vec![
            6.0, 5.0, 1.0,
            5.0, 12.0, 5.0,
            1.0, 5.0, 6.0,
        ];
        let l = gpu_cholesky_f32(&a, 3, &dev).unwrap();

        // Verify lower-triangular.
        for i in 0..3 {
            for j in (i + 1)..3 {
                assert!(
                    l[i * 3 + j].abs() < 1e-5,
                    "L[{i},{j}] = {} should be 0",
                    l[i * 3 + j]
                );
            }
        }

        // Reconstruct: L * L^T should equal A.
        let mut llt = [0.0f32; 9];
        for i in 0..3 {
            for j in 0..3 {
                let mut acc = 0.0f32;
                for p in 0..3 {
                    acc += l[i * 3 + p] * l[j * 3 + p];
                }
                llt[i * 3 + j] = acc;
            }
        }

        for i in 0..9 {
            assert!(
                (llt[i] - a[i]).abs() < 1e-3,
                "L*L^T[{i}] = {} vs A[{i}] = {}",
                llt[i],
                a[i]
            );
        }
    }

    // The identity is its own Cholesky factor.
    #[test]
    fn cholesky_identity() {
        let dev = device();
        let eye: Vec<f32> = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let l = gpu_cholesky_f32(&eye, 3, &dev).unwrap();
        // Cholesky of identity is identity.
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (l[i * 3 + j] - expected).abs() < 1e-5,
                    "L[{i},{j}] = {} (expected {})",
                    l[i * 3 + j],
                    expected
                );
            }
        }
    }

    // A 0x0 input produces an empty factor rather than an error.
    #[test]
    fn cholesky_empty() {
        let dev = device();
        let l = gpu_cholesky_f32(&[], 0, &dev).unwrap();
        assert!(l.is_empty());
    }

    // -- Solve tests --

    // Known 2x2 system with a hand-computed solution.
    #[test]
    fn solve_2x2_simple() {
        let dev = device();
        // A = [[2, 1], [1, 3]], b = [5, 10]
        // Solution: x = [1, 3]
        let a: Vec<f32> = vec![2.0, 1.0, 1.0, 3.0];
        let b: Vec<f32> = vec![5.0, 10.0];
        let x = gpu_solve_f32(&a, &b, 2, 1, &dev).unwrap();
        assert!((x[0] - 1.0).abs() < 1e-3, "x[0] = {} (expected 1.0)", x[0]);
        assert!((x[1] - 3.0).abs() < 1e-3, "x[1] = {} (expected 3.0)", x[1]);
    }

    // Solving against the identity returns b unchanged.
    #[test]
    fn solve_identity() {
        let dev = device();
        let eye: Vec<f32> = vec![1.0, 0.0, 0.0, 1.0];
        let b: Vec<f32> = vec![7.0, 11.0];
        let x = gpu_solve_f32(&eye, &b, 2, 1, &dev).unwrap();
        assert!((x[0] - 7.0).abs() < 1e-4);
        assert!((x[1] - 11.0).abs() < 1e-4);
    }

    // nrhs > 1: verify by checking A * X == B column-wise on the host.
    #[test]
    fn solve_multiple_rhs() {
        let dev = device();
        // A = [[2, 1], [1, 3]], B = [[5, 3], [10, 7]]
        // X = A^-1 * B
        let a: Vec<f32> = vec![2.0, 1.0, 1.0, 3.0];
        let b: Vec<f32> = vec![5.0, 3.0, 10.0, 7.0]; // 2x2 row-major
        let x = gpu_solve_f32(&a, &b, 2, 2, &dev).unwrap();
        // Verify: A * X should equal B
        let mut ax = [0.0f32; 4];
        for i in 0..2 {
            for j in 0..2 {
                ax[i * 2 + j] = a[i * 2] * x[j] + a[i * 2 + 1] * x[2 + j];
            }
        }
        for i in 0..4 {
            assert!(
                (ax[i] - b[i]).abs() < 1e-3,
                "A*X[{i}] = {} vs B[{i}] = {}",
                ax[i],
                b[i]
            );
        }
    }

    // A 0x0 system produces an empty solution rather than an error.
    #[test]
    fn solve_empty() {
        let dev = device();
        let x = gpu_solve_f32(&[], &[], 0, 0, &dev).unwrap();
        assert!(x.is_empty());
    }

    // -- QR tests --

    // Q * R must reproduce the original tall (3x2) matrix.
    #[test]
    fn qr_reconstructs_3x2() {
        let dev = device();
        let a: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let (m, n) = (3, 2);
        let (q, r) = gpu_qr_f32(&a, m, n, &dev).unwrap();
        let k = m.min(n);

        // Reduced QR shapes: Q is m-by-k, R is k-by-n.
        assert_eq!(q.len(), m * k);
        assert_eq!(r.len(), k * n);

        // Reconstruct: Q * R should equal A.
        let mut recon = vec![0.0f32; m * n];
        for i in 0..m {
            for j in 0..n {
                let mut acc = 0.0f32;
                for p in 0..k {
                    acc += q[i * k + p] * r[p * n + j];
                }
                recon[i * n + j] = acc;
            }
        }

        for i in 0..m * n {
            assert!(
                (recon[i] - a[i]).abs() < 1e-3,
                "QR reconstruction failed at {i}: {} vs {}",
                recon[i],
                a[i]
            );
        }
    }

    // Q must have orthonormal columns: Q^T * Q == I_k.
    #[test]
    fn qr_orthogonal_q() {
        let dev = device();
        let a: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let (m, n) = (3, 2);
        let (q, _) = gpu_qr_f32(&a, m, n, &dev).unwrap();
        let k = m.min(n);

        // Q^T * Q should be identity_k.
        let mut qtq = vec![0.0f32; k * k];
        for i in 0..k {
            for j in 0..k {
                let mut acc = 0.0f32;
                for p in 0..m {
                    acc += q[p * k + i] * q[p * k + j];
                }
                qtq[i * k + j] = acc;
            }
        }

        for i in 0..k {
            for j in 0..k {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (qtq[i * k + j] - expected).abs() < 1e-3,
                    "Q^T*Q[{i},{j}] = {} (expected {})",
                    qtq[i * k + j],
                    expected
                );
            }
        }
    }

    // Every entry strictly below R's diagonal must be (numerically) zero.
    #[test]
    fn qr_r_upper_triangular() {
        let dev = device();
        let a: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let (_, r) = gpu_qr_f32(&a, 3, 2, &dev).unwrap();
        let k = 2;
        let n = 2;
        // R should be upper-triangular.
        for i in 0..k {
            for j in 0..i.min(n) {
                assert!(
                    r[i * n + j].abs() < 1e-4,
                    "R[{i},{j}] = {} should be 0",
                    r[i * n + j]
                );
            }
        }
    }

    // Square (m == n) case: full QR, reconstruction check.
    #[test]
    fn qr_square() {
        let dev = device();
        // 3x3 matrix
        let a: Vec<f32> = vec![2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0];
        let (q, r) = gpu_qr_f32(&a, 3, 3, &dev).unwrap();
        let k = 3;
        let n = 3;

        // Reconstruct
        let mut recon = [0.0f32; 9];
        for i in 0..3 {
            for j in 0..3 {
                let mut acc = 0.0f32;
                for p in 0..k {
                    acc += q[i * k + p] * r[p * n + j];
                }
                recon[i * n + j] = acc;
            }
        }
        for i in 0..9 {
            assert!(
                (recon[i] - a[i]).abs() < 1e-3,
                "QR square reconstruction failed at {i}: {} vs {}",
                recon[i],
                a[i]
            );
        }
    }

    // A 0x0 input produces empty Q and R rather than an error.
    #[test]
    fn qr_empty() {
        let dev = device();
        let (q, r) = gpu_qr_f32(&[], 0, 0, &dev).unwrap();
        assert!(q.is_empty());
        assert!(r.is_empty());
    }
}