Skip to main content

baracuda_cusolver/
lib.rs

1//! Safe Rust wrappers for NVIDIA cuSOLVER.
2//!
3//! Covers the dense API (`Dn`) for all four BLAS scalar types:
4//! - LU factorization: `getrf` + `getrs`
5//! - QR factorization: `geqrf`
6//! - Cholesky: `potrf` + `potrs`
7//! - SVD: `gesvd`
8//! - Symmetric / Hermitian eigendecomposition: `syevd` / `heevd`
9//!
10//! The generic 64-bit X… API (`xgetrf`, `xgeqrf`, `xpotrf`) gives
11//! type-erased data pointers and is exposed under [`xapi`]. The sparse API
12//! (`cusolverSp*`) is under [`sparse`]. The refactor API (`cusolverRf*`) is
13//! under [`refactor`].
14
15#![warn(missing_debug_implementations)]
16
17use core::ffi::{c_int, c_void};
18use std::marker::PhantomData;
19
20use baracuda_cusolver_sys::{
21    cublasFillMode_t, cublasOperation_t, cuComplex, cuDoubleComplex, cusolver,
22    cusolverDnHandle_t, cusolverEigMode_t, cusolverStatus_t,
23};
24use baracuda_driver::{DeviceBuffer, Stream};
25use baracuda_types::{Complex32, Complex64, DeviceRepr};
26
27pub use baracuda_cusolver_sys::{
28    cublasFillMode_t as Fill, cusolverEigMode_t as EigMode,
29};
30
31/// Error type for cuSOLVER operations.
32pub type Error = baracuda_core::Error<cusolverStatus_t>;
33/// Result alias.
34pub type Result<T, E = Error> = core::result::Result<T, E>;
35
36#[inline]
37fn check(status: cusolverStatus_t) -> Result<()> {
38    Error::check(status)
39}
40
41/// Convert a driver allocation failure into a cuSOLVER ALLOC_FAILED.
42fn alloc_fail<E>(_e: E) -> Error {
43    Error::Status {
44        status: cusolverStatus_t::ALLOC_FAILED,
45    }
46}
47
48// ---- Handle -------------------------------------------------------------
49
50/// Dense cuSOLVER handle.
51pub struct DnHandle {
52    handle: cusolverDnHandle_t,
53}
54
55unsafe impl Send for DnHandle {}
56
57impl core::fmt::Debug for DnHandle {
58    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
59        f.debug_struct("cusolver::DnHandle")
60            .field("handle", &self.handle)
61            .finish()
62    }
63}
64
65impl DnHandle {
66    pub fn new() -> Result<Self> {
67        let c = cusolver()?;
68        let cu = c.cusolver_dn_create()?;
69        let mut h: cusolverDnHandle_t = core::ptr::null_mut();
70        check(unsafe { cu(&mut h) })?;
71        Ok(Self { handle: h })
72    }
73
74    pub fn set_stream(&self, stream: &Stream) -> Result<()> {
75        let c = cusolver()?;
76        let cu = c.cusolver_dn_set_stream()?;
77        check(unsafe { cu(self.handle, stream.as_raw() as _) })
78    }
79
80    pub fn version() -> Result<i32> {
81        let c = cusolver()?;
82        let cu = c.cusolver_get_version()?;
83        let mut v: c_int = 0;
84        check(unsafe { cu(&mut v) })?;
85        Ok(v)
86    }
87
88    #[inline]
89    pub fn as_raw(&self) -> cusolverDnHandle_t {
90        self.handle
91    }
92}
93
94impl Drop for DnHandle {
95    fn drop(&mut self) {
96        if let Ok(c) = cusolver() {
97            if let Ok(cu) = c.cusolver_dn_destroy() {
98                let _ = unsafe { cu(self.handle) };
99            }
100        }
101    }
102}
103
/// Transposition selector for solve-step calls.
///
/// Each variant maps one-to-one onto the `cublasOperation_t` value of the
/// same name; [`Op::N`] is the default.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum Op {
    /// Forwarded as `cublasOperation_t::N`.
    #[default]
    N,
    /// Forwarded as `cublasOperation_t::T`.
    T,
    /// Forwarded as `cublasOperation_t::C`.
    C,
}
112
113impl Op {
114    fn raw(self) -> cublasOperation_t {
115        match self {
116            Op::N => cublasOperation_t::N,
117            Op::T => cublasOperation_t::T,
118            Op::C => cublasOperation_t::C,
119        }
120    }
121}
122
123// ---- Trait framework ----------------------------------------------------
124
125/// Scalars supported by cuSOLVER's Dn S/D/C/Z API.
126pub trait SolverScalar: DeviceRepr + Copy + 'static + sealed::Sealed {
127    /// Real-valued associate type for ops that mix scalar and norm
128    /// (f32 → f32, f64 → f64, Complex32 → f32, Complex64 → f64).
129    type Real: DeviceRepr + Copy + 'static;
130
131    /// LU buffer size.
132    #[doc(hidden)]
133    unsafe fn getrf_buf(
134        h: cusolverDnHandle_t,
135        m: c_int,
136        n: c_int,
137        a: *mut Self,
138        lda: c_int,
139        lwork: *mut c_int,
140    ) -> cusolverStatus_t;
141
142    /// LU factorization.
143    #[doc(hidden)]
144    #[allow(clippy::too_many_arguments)]
145    unsafe fn getrf(
146        h: cusolverDnHandle_t,
147        m: c_int,
148        n: c_int,
149        a: *mut Self,
150        lda: c_int,
151        workspace: *mut Self,
152        ipiv: *mut c_int,
153        info: *mut c_int,
154    ) -> cusolverStatus_t;
155
156    /// LU solve.
157    #[doc(hidden)]
158    #[allow(clippy::too_many_arguments)]
159    unsafe fn getrs(
160        h: cusolverDnHandle_t,
161        trans: cublasOperation_t,
162        n: c_int,
163        nrhs: c_int,
164        a: *const Self,
165        lda: c_int,
166        ipiv: *const c_int,
167        b: *mut Self,
168        ldb: c_int,
169        info: *mut c_int,
170    ) -> cusolverStatus_t;
171
172    /// QR factorization buffer size.
173    #[doc(hidden)]
174    unsafe fn geqrf_buf(
175        h: cusolverDnHandle_t,
176        m: c_int,
177        n: c_int,
178        a: *mut Self,
179        lda: c_int,
180        lwork: *mut c_int,
181    ) -> cusolverStatus_t;
182
183    /// QR factorization.
184    #[doc(hidden)]
185    #[allow(clippy::too_many_arguments)]
186    unsafe fn geqrf(
187        h: cusolverDnHandle_t,
188        m: c_int,
189        n: c_int,
190        a: *mut Self,
191        lda: c_int,
192        tau: *mut Self,
193        workspace: *mut Self,
194        lwork: c_int,
195        info: *mut c_int,
196    ) -> cusolverStatus_t;
197
198    /// Cholesky buffer size.
199    #[doc(hidden)]
200    unsafe fn potrf_buf(
201        h: cusolverDnHandle_t,
202        uplo: cublasFillMode_t,
203        n: c_int,
204        a: *mut Self,
205        lda: c_int,
206        lwork: *mut c_int,
207    ) -> cusolverStatus_t;
208
209    /// Cholesky factorization.
210    #[doc(hidden)]
211    #[allow(clippy::too_many_arguments)]
212    unsafe fn potrf(
213        h: cusolverDnHandle_t,
214        uplo: cublasFillMode_t,
215        n: c_int,
216        a: *mut Self,
217        lda: c_int,
218        workspace: *mut Self,
219        lwork: c_int,
220        info: *mut c_int,
221    ) -> cusolverStatus_t;
222
223    /// Cholesky solve.
224    #[doc(hidden)]
225    #[allow(clippy::too_many_arguments)]
226    unsafe fn potrs(
227        h: cusolverDnHandle_t,
228        uplo: cublasFillMode_t,
229        n: c_int,
230        nrhs: c_int,
231        a: *const Self,
232        lda: c_int,
233        b: *mut Self,
234        ldb: c_int,
235        info: *mut c_int,
236    ) -> cusolverStatus_t;
237
238    /// SVD buffer size.
239    #[doc(hidden)]
240    unsafe fn gesvd_buf(
241        h: cusolverDnHandle_t,
242        m: c_int,
243        n: c_int,
244        lwork: *mut c_int,
245    ) -> cusolverStatus_t;
246
247    /// SVD (generic, taking real-valued S + rwork for complex variants).
248    #[doc(hidden)]
249    #[allow(clippy::too_many_arguments)]
250    unsafe fn gesvd(
251        h: cusolverDnHandle_t,
252        jobu: u8,
253        jobvt: u8,
254        m: c_int,
255        n: c_int,
256        a: *mut Self,
257        lda: c_int,
258        s: *mut Self::Real,
259        u: *mut Self,
260        ldu: c_int,
261        vt: *mut Self,
262        ldvt: c_int,
263        work: *mut Self,
264        lwork: c_int,
265        rwork: *mut Self::Real,
266        info: *mut c_int,
267    ) -> cusolverStatus_t;
268
269    /// syevd / heevd buffer size.
270    #[doc(hidden)]
271    #[allow(clippy::too_many_arguments)]
272    unsafe fn syevd_buf(
273        h: cusolverDnHandle_t,
274        jobz: cusolverEigMode_t,
275        uplo: cublasFillMode_t,
276        n: c_int,
277        a: *const Self,
278        lda: c_int,
279        w: *const Self::Real,
280        lwork: *mut c_int,
281    ) -> cusolverStatus_t;
282
283    /// syevd / heevd.
284    #[doc(hidden)]
285    #[allow(clippy::too_many_arguments)]
286    unsafe fn syevd(
287        h: cusolverDnHandle_t,
288        jobz: cusolverEigMode_t,
289        uplo: cublasFillMode_t,
290        n: c_int,
291        a: *mut Self,
292        lda: c_int,
293        w: *mut Self::Real,
294        work: *mut Self,
295        lwork: c_int,
296        info: *mut c_int,
297    ) -> cusolverStatus_t;
298}
299
300mod sealed {
301    use baracuda_types::{Complex32, Complex64};
302    pub trait Sealed {}
303    impl Sealed for f32 {}
304    impl Sealed for f64 {}
305    impl Sealed for Complex32 {}
306    impl Sealed for Complex64 {}
307}
308
/// Implements [`SolverScalar`] for a real scalar type.
///
/// Arguments: the scalar type, then the names of twelve loader accessors on
/// the object returned by `cusolver()`, in the order getrf-buffer-size, getrf,
/// getrs, geqrf-buffer-size, geqrf, potrf-buffer-size, potrf, potrs,
/// gesvd-buffer-size, gesvd, syevd-buffer-size, syevd.
///
/// Every generated method resolves its entry point on each call and forwards
/// the arguments unchanged. If resolution fails, the method returns
/// `cusolverStatus_t::NOT_INITIALIZED` instead of calling into the library.
macro_rules! real_impl {
    ($t:ty, $getrf_buf:ident, $getrf:ident, $getrs:ident,
           $geqrf_buf:ident, $geqrf:ident,
           $potrf_buf:ident, $potrf:ident, $potrs:ident,
           $gesvd_buf:ident, $gesvd:ident,
           $syevd_buf:ident, $syevd:ident) => {
        impl SolverScalar for $t {
            // Real scalars are their own real-valued companion type.
            type Real = $t;

            unsafe fn getrf_buf(
                handle: cusolverDnHandle_t,
                m: c_int,
                n: c_int,
                a: *mut $t,
                lda: c_int,
                lwork: *mut c_int,
            ) -> cusolverStatus_t {
                if let Ok(entry) = cusolver().and_then(|lib| lib.$getrf_buf()) {
                    entry(handle, m, n, a, lda, lwork)
                } else {
                    cusolverStatus_t::NOT_INITIALIZED
                }
            }

            unsafe fn getrf(
                handle: cusolverDnHandle_t,
                m: c_int,
                n: c_int,
                a: *mut $t,
                lda: c_int,
                work: *mut $t,
                ipiv: *mut c_int,
                info: *mut c_int,
            ) -> cusolverStatus_t {
                if let Ok(entry) = cusolver().and_then(|lib| lib.$getrf()) {
                    entry(handle, m, n, a, lda, work, ipiv, info)
                } else {
                    cusolverStatus_t::NOT_INITIALIZED
                }
            }

            unsafe fn getrs(
                handle: cusolverDnHandle_t,
                trans: cublasOperation_t,
                n: c_int,
                nrhs: c_int,
                a: *const $t,
                lda: c_int,
                ipiv: *const c_int,
                b: *mut $t,
                ldb: c_int,
                info: *mut c_int,
            ) -> cusolverStatus_t {
                if let Ok(entry) = cusolver().and_then(|lib| lib.$getrs()) {
                    entry(handle, trans, n, nrhs, a, lda, ipiv, b, ldb, info)
                } else {
                    cusolverStatus_t::NOT_INITIALIZED
                }
            }

            unsafe fn geqrf_buf(
                handle: cusolverDnHandle_t,
                m: c_int,
                n: c_int,
                a: *mut $t,
                lda: c_int,
                lwork: *mut c_int,
            ) -> cusolverStatus_t {
                if let Ok(entry) = cusolver().and_then(|lib| lib.$geqrf_buf()) {
                    entry(handle, m, n, a, lda, lwork)
                } else {
                    cusolverStatus_t::NOT_INITIALIZED
                }
            }

            unsafe fn geqrf(
                handle: cusolverDnHandle_t,
                m: c_int,
                n: c_int,
                a: *mut $t,
                lda: c_int,
                tau: *mut $t,
                work: *mut $t,
                lwork: c_int,
                info: *mut c_int,
            ) -> cusolverStatus_t {
                if let Ok(entry) = cusolver().and_then(|lib| lib.$geqrf()) {
                    entry(handle, m, n, a, lda, tau, work, lwork, info)
                } else {
                    cusolverStatus_t::NOT_INITIALIZED
                }
            }

            unsafe fn potrf_buf(
                handle: cusolverDnHandle_t,
                uplo: cublasFillMode_t,
                n: c_int,
                a: *mut $t,
                lda: c_int,
                lwork: *mut c_int,
            ) -> cusolverStatus_t {
                if let Ok(entry) = cusolver().and_then(|lib| lib.$potrf_buf()) {
                    entry(handle, uplo, n, a, lda, lwork)
                } else {
                    cusolverStatus_t::NOT_INITIALIZED
                }
            }

            unsafe fn potrf(
                handle: cusolverDnHandle_t,
                uplo: cublasFillMode_t,
                n: c_int,
                a: *mut $t,
                lda: c_int,
                work: *mut $t,
                lwork: c_int,
                info: *mut c_int,
            ) -> cusolverStatus_t {
                if let Ok(entry) = cusolver().and_then(|lib| lib.$potrf()) {
                    entry(handle, uplo, n, a, lda, work, lwork, info)
                } else {
                    cusolverStatus_t::NOT_INITIALIZED
                }
            }

            unsafe fn potrs(
                handle: cusolverDnHandle_t,
                uplo: cublasFillMode_t,
                n: c_int,
                nrhs: c_int,
                a: *const $t,
                lda: c_int,
                b: *mut $t,
                ldb: c_int,
                info: *mut c_int,
            ) -> cusolverStatus_t {
                if let Ok(entry) = cusolver().and_then(|lib| lib.$potrs()) {
                    entry(handle, uplo, n, nrhs, a, lda, b, ldb, info)
                } else {
                    cusolverStatus_t::NOT_INITIALIZED
                }
            }

            unsafe fn gesvd_buf(
                handle: cusolverDnHandle_t,
                m: c_int,
                n: c_int,
                lwork: *mut c_int,
            ) -> cusolverStatus_t {
                if let Ok(entry) = cusolver().and_then(|lib| lib.$gesvd_buf()) {
                    entry(handle, m, n, lwork)
                } else {
                    cusolverStatus_t::NOT_INITIALIZED
                }
            }

            unsafe fn gesvd(
                handle: cusolverDnHandle_t,
                jobu: u8,
                jobvt: u8,
                m: c_int,
                n: c_int,
                a: *mut $t,
                lda: c_int,
                s: *mut $t,
                u: *mut $t,
                ldu: c_int,
                vt: *mut $t,
                ldvt: c_int,
                work: *mut $t,
                lwork: c_int,
                rwork: *mut $t,
                info: *mut c_int,
            ) -> cusolverStatus_t {
                if let Ok(entry) = cusolver().and_then(|lib| lib.$gesvd()) {
                    entry(
                        handle, jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, work, lwork,
                        rwork, info,
                    )
                } else {
                    cusolverStatus_t::NOT_INITIALIZED
                }
            }

            unsafe fn syevd_buf(
                handle: cusolverDnHandle_t,
                jobz: cusolverEigMode_t,
                uplo: cublasFillMode_t,
                n: c_int,
                a: *const $t,
                lda: c_int,
                w: *const $t,
                lwork: *mut c_int,
            ) -> cusolverStatus_t {
                if let Ok(entry) = cusolver().and_then(|lib| lib.$syevd_buf()) {
                    entry(handle, jobz, uplo, n, a, lda, w, lwork)
                } else {
                    cusolverStatus_t::NOT_INITIALIZED
                }
            }

            unsafe fn syevd(
                handle: cusolverDnHandle_t,
                jobz: cusolverEigMode_t,
                uplo: cublasFillMode_t,
                n: c_int,
                a: *mut $t,
                lda: c_int,
                w: *mut $t,
                work: *mut $t,
                lwork: c_int,
                info: *mut c_int,
            ) -> cusolverStatus_t {
                if let Ok(entry) = cusolver().and_then(|lib| lib.$syevd()) {
                    entry(handle, jobz, uplo, n, a, lda, w, work, lwork, info)
                } else {
                    cusolverStatus_t::NOT_INITIALIZED
                }
            }
        }
    };
}
507
/// Implements [`SolverScalar`] for a complex scalar type.
///
/// Arguments: the Rust-side complex type `$t`, its real companion `$real`,
/// the FFI element type `$raw` the entry points are declared with, then the
/// same twelve loader accessor names as `real_impl!` (with the `heevd` pair
/// in place of `syevd`).
///
/// Pointers to `$t` are reinterpreted as pointers to `$raw` before the call.
/// NOTE(review): this assumes `$t` and `$raw` share a layout — confirm in
/// `baracuda_types`. If the entry point cannot be resolved, the method returns
/// `cusolverStatus_t::NOT_INITIALIZED` instead of calling into the library.
macro_rules! complex_impl {
    ($t:ty, $real:ty, $raw:ty,
     $getrf_buf:ident, $getrf:ident, $getrs:ident,
     $geqrf_buf:ident, $geqrf:ident,
     $potrf_buf:ident, $potrf:ident, $potrs:ident,
     $gesvd_buf:ident, $gesvd:ident,
     $heevd_buf:ident, $heevd:ident) => {
        impl SolverScalar for $t {
            type Real = $real;

            unsafe fn getrf_buf(
                handle: cusolverDnHandle_t,
                m: c_int,
                n: c_int,
                a: *mut $t,
                lda: c_int,
                lwork: *mut c_int,
            ) -> cusolverStatus_t {
                if let Ok(entry) = cusolver().and_then(|lib| lib.$getrf_buf()) {
                    entry(handle, m, n, a.cast::<$raw>(), lda, lwork)
                } else {
                    cusolverStatus_t::NOT_INITIALIZED
                }
            }

            unsafe fn getrf(
                handle: cusolverDnHandle_t,
                m: c_int,
                n: c_int,
                a: *mut $t,
                lda: c_int,
                work: *mut $t,
                ipiv: *mut c_int,
                info: *mut c_int,
            ) -> cusolverStatus_t {
                if let Ok(entry) = cusolver().and_then(|lib| lib.$getrf()) {
                    let a = a.cast::<$raw>();
                    let work = work.cast::<$raw>();
                    entry(handle, m, n, a, lda, work, ipiv, info)
                } else {
                    cusolverStatus_t::NOT_INITIALIZED
                }
            }

            unsafe fn getrs(
                handle: cusolverDnHandle_t,
                trans: cublasOperation_t,
                n: c_int,
                nrhs: c_int,
                a: *const $t,
                lda: c_int,
                ipiv: *const c_int,
                b: *mut $t,
                ldb: c_int,
                info: *mut c_int,
            ) -> cusolverStatus_t {
                if let Ok(entry) = cusolver().and_then(|lib| lib.$getrs()) {
                    let a = a.cast::<$raw>();
                    let b = b.cast::<$raw>();
                    entry(handle, trans, n, nrhs, a, lda, ipiv, b, ldb, info)
                } else {
                    cusolverStatus_t::NOT_INITIALIZED
                }
            }

            unsafe fn geqrf_buf(
                handle: cusolverDnHandle_t,
                m: c_int,
                n: c_int,
                a: *mut $t,
                lda: c_int,
                lwork: *mut c_int,
            ) -> cusolverStatus_t {
                if let Ok(entry) = cusolver().and_then(|lib| lib.$geqrf_buf()) {
                    entry(handle, m, n, a.cast::<$raw>(), lda, lwork)
                } else {
                    cusolverStatus_t::NOT_INITIALIZED
                }
            }

            unsafe fn geqrf(
                handle: cusolverDnHandle_t,
                m: c_int,
                n: c_int,
                a: *mut $t,
                lda: c_int,
                tau: *mut $t,
                work: *mut $t,
                lwork: c_int,
                info: *mut c_int,
            ) -> cusolverStatus_t {
                if let Ok(entry) = cusolver().and_then(|lib| lib.$geqrf()) {
                    let a = a.cast::<$raw>();
                    let tau = tau.cast::<$raw>();
                    let work = work.cast::<$raw>();
                    entry(handle, m, n, a, lda, tau, work, lwork, info)
                } else {
                    cusolverStatus_t::NOT_INITIALIZED
                }
            }

            unsafe fn potrf_buf(
                handle: cusolverDnHandle_t,
                uplo: cublasFillMode_t,
                n: c_int,
                a: *mut $t,
                lda: c_int,
                lwork: *mut c_int,
            ) -> cusolverStatus_t {
                if let Ok(entry) = cusolver().and_then(|lib| lib.$potrf_buf()) {
                    entry(handle, uplo, n, a.cast::<$raw>(), lda, lwork)
                } else {
                    cusolverStatus_t::NOT_INITIALIZED
                }
            }

            unsafe fn potrf(
                handle: cusolverDnHandle_t,
                uplo: cublasFillMode_t,
                n: c_int,
                a: *mut $t,
                lda: c_int,
                work: *mut $t,
                lwork: c_int,
                info: *mut c_int,
            ) -> cusolverStatus_t {
                if let Ok(entry) = cusolver().and_then(|lib| lib.$potrf()) {
                    let a = a.cast::<$raw>();
                    let work = work.cast::<$raw>();
                    entry(handle, uplo, n, a, lda, work, lwork, info)
                } else {
                    cusolverStatus_t::NOT_INITIALIZED
                }
            }

            unsafe fn potrs(
                handle: cusolverDnHandle_t,
                uplo: cublasFillMode_t,
                n: c_int,
                nrhs: c_int,
                a: *const $t,
                lda: c_int,
                b: *mut $t,
                ldb: c_int,
                info: *mut c_int,
            ) -> cusolverStatus_t {
                if let Ok(entry) = cusolver().and_then(|lib| lib.$potrs()) {
                    let a = a.cast::<$raw>();
                    let b = b.cast::<$raw>();
                    entry(handle, uplo, n, nrhs, a, lda, b, ldb, info)
                } else {
                    cusolverStatus_t::NOT_INITIALIZED
                }
            }

            unsafe fn gesvd_buf(
                handle: cusolverDnHandle_t,
                m: c_int,
                n: c_int,
                lwork: *mut c_int,
            ) -> cusolverStatus_t {
                if let Ok(entry) = cusolver().and_then(|lib| lib.$gesvd_buf()) {
                    entry(handle, m, n, lwork)
                } else {
                    cusolverStatus_t::NOT_INITIALIZED
                }
            }

            unsafe fn gesvd(
                handle: cusolverDnHandle_t,
                jobu: u8,
                jobvt: u8,
                m: c_int,
                n: c_int,
                a: *mut $t,
                lda: c_int,
                s: *mut $real,
                u: *mut $t,
                ldu: c_int,
                vt: *mut $t,
                ldvt: c_int,
                work: *mut $t,
                lwork: c_int,
                rwork: *mut $real,
                info: *mut c_int,
            ) -> cusolverStatus_t {
                if let Ok(entry) = cusolver().and_then(|lib| lib.$gesvd()) {
                    // `s` and `rwork` are already real-valued and pass through as-is.
                    let a = a.cast::<$raw>();
                    let u = u.cast::<$raw>();
                    let vt = vt.cast::<$raw>();
                    let work = work.cast::<$raw>();
                    entry(
                        handle, jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, work, lwork,
                        rwork, info,
                    )
                } else {
                    cusolverStatus_t::NOT_INITIALIZED
                }
            }

            unsafe fn syevd_buf(
                handle: cusolverDnHandle_t,
                jobz: cusolverEigMode_t,
                uplo: cublasFillMode_t,
                n: c_int,
                a: *const $t,
                lda: c_int,
                w: *const $real,
                lwork: *mut c_int,
            ) -> cusolverStatus_t {
                // Complex scalars route the `syevd` hook to the `heevd` entry point.
                if let Ok(entry) = cusolver().and_then(|lib| lib.$heevd_buf()) {
                    entry(handle, jobz, uplo, n, a.cast::<$raw>(), lda, w, lwork)
                } else {
                    cusolverStatus_t::NOT_INITIALIZED
                }
            }

            unsafe fn syevd(
                handle: cusolverDnHandle_t,
                jobz: cusolverEigMode_t,
                uplo: cublasFillMode_t,
                n: c_int,
                a: *mut $t,
                lda: c_int,
                w: *mut $real,
                work: *mut $t,
                lwork: c_int,
                info: *mut c_int,
            ) -> cusolverStatus_t {
                if let Ok(entry) = cusolver().and_then(|lib| lib.$heevd()) {
                    let a = a.cast::<$raw>();
                    let work = work.cast::<$raw>();
                    entry(handle, jobz, uplo, n, a, lda, w, work, lwork, info)
                } else {
                    cusolverStatus_t::NOT_INITIALIZED
                }
            }
        }
    };
}
782
783real_impl!(
784    f32,
785    cusolver_dn_sgetrf_buffer_size,
786    cusolver_dn_sgetrf,
787    cusolver_dn_sgetrs,
788    cusolver_dn_sgeqrf_buffer_size,
789    cusolver_dn_sgeqrf,
790    cusolver_dn_spotrf_buffer_size,
791    cusolver_dn_spotrf,
792    cusolver_dn_spotrs,
793    cusolver_dn_sgesvd_buffer_size,
794    cusolver_dn_sgesvd,
795    cusolver_dn_ssyevd_buffer_size,
796    cusolver_dn_ssyevd
797);
798
799real_impl!(
800    f64,
801    cusolver_dn_dgetrf_buffer_size,
802    cusolver_dn_dgetrf,
803    cusolver_dn_dgetrs,
804    cusolver_dn_dgeqrf_buffer_size,
805    cusolver_dn_dgeqrf,
806    cusolver_dn_dpotrf_buffer_size,
807    cusolver_dn_dpotrf,
808    cusolver_dn_dpotrs,
809    cusolver_dn_dgesvd_buffer_size,
810    cusolver_dn_dgesvd,
811    cusolver_dn_dsyevd_buffer_size,
812    cusolver_dn_dsyevd
813);
814
815complex_impl!(
816    Complex32,
817    f32,
818    cuComplex,
819    cusolver_dn_cgetrf_buffer_size,
820    cusolver_dn_cgetrf,
821    cusolver_dn_cgetrs,
822    cusolver_dn_cgeqrf_buffer_size,
823    cusolver_dn_cgeqrf,
824    cusolver_dn_cpotrf_buffer_size,
825    cusolver_dn_cpotrf,
826    cusolver_dn_cpotrs,
827    cusolver_dn_cgesvd_buffer_size,
828    cusolver_dn_cgesvd,
829    cusolver_dn_cheevd_buffer_size,
830    cusolver_dn_cheevd
831);
832
833complex_impl!(
834    Complex64,
835    f64,
836    cuDoubleComplex,
837    cusolver_dn_zgetrf_buffer_size,
838    cusolver_dn_zgetrf,
839    cusolver_dn_zgetrs,
840    cusolver_dn_zgeqrf_buffer_size,
841    cusolver_dn_zgeqrf,
842    cusolver_dn_zpotrf_buffer_size,
843    cusolver_dn_zpotrf,
844    cusolver_dn_zpotrs,
845    cusolver_dn_zgesvd_buffer_size,
846    cusolver_dn_zgesvd,
847    cusolver_dn_zheevd_buffer_size,
848    cusolver_dn_zheevd
849);
850
851// ---- Public API ---------------------------------------------------------
852
853/// In-place LU factorization of a column-major matrix. Overwrites `a`.
854#[allow(clippy::too_many_arguments)]
855pub fn getrf<T: SolverScalar>(
856    handle: &DnHandle,
857    m: i32,
858    n: i32,
859    a: &mut DeviceBuffer<T>,
860    lda: i32,
861    ipiv: &mut DeviceBuffer<i32>,
862    info: &mut DeviceBuffer<i32>,
863) -> Result<()> {
864    let mut lwork: c_int = 0;
865    check(unsafe { T::getrf_buf(handle.handle, m, n, a.as_raw().0 as *mut T, lda, &mut lwork) })?;
866    let workspace =
867        DeviceBuffer::<T>::new(a.context(), lwork as usize).map_err(alloc_fail)?;
868    check(unsafe {
869        T::getrf(
870            handle.handle,
871            m,
872            n,
873            a.as_raw().0 as *mut T,
874            lda,
875            workspace.as_raw().0 as *mut T,
876            ipiv.as_raw().0 as *mut c_int,
877            info.as_raw().0 as *mut c_int,
878        )
879    })
880}
881
882/// Solve `op(A) * X = B` using the LU factorization from [`getrf`].
883#[allow(clippy::too_many_arguments)]
884pub fn getrs<T: SolverScalar>(
885    handle: &DnHandle,
886    trans: Op,
887    n: i32,
888    nrhs: i32,
889    a: &DeviceBuffer<T>,
890    lda: i32,
891    ipiv: &DeviceBuffer<i32>,
892    b: &mut DeviceBuffer<T>,
893    ldb: i32,
894    info: &mut DeviceBuffer<i32>,
895) -> Result<()> {
896    check(unsafe {
897        T::getrs(
898            handle.handle,
899            trans.raw(),
900            n,
901            nrhs,
902            a.as_raw().0 as *const T,
903            lda,
904            ipiv.as_raw().0 as *const c_int,
905            b.as_raw().0 as *mut T,
906            ldb,
907            info.as_raw().0 as *mut c_int,
908        )
909    })
910}
911
912/// QR factorization: `A = Q * R`. Overwrites `a` (upper triangle = R,
913/// lower = Householder reflectors); `tau` receives reflector scalars.
914#[allow(clippy::too_many_arguments)]
915pub fn geqrf<T: SolverScalar>(
916    handle: &DnHandle,
917    m: i32,
918    n: i32,
919    a: &mut DeviceBuffer<T>,
920    lda: i32,
921    tau: &mut DeviceBuffer<T>,
922    info: &mut DeviceBuffer<i32>,
923) -> Result<()> {
924    let mut lwork: c_int = 0;
925    check(unsafe { T::geqrf_buf(handle.handle, m, n, a.as_raw().0 as *mut T, lda, &mut lwork) })?;
926    let workspace =
927        DeviceBuffer::<T>::new(a.context(), lwork as usize).map_err(alloc_fail)?;
928    check(unsafe {
929        T::geqrf(
930            handle.handle,
931            m,
932            n,
933            a.as_raw().0 as *mut T,
934            lda,
935            tau.as_raw().0 as *mut T,
936            workspace.as_raw().0 as *mut T,
937            lwork,
938            info.as_raw().0 as *mut c_int,
939        )
940    })
941}
942
943/// Cholesky factorization: `A = L * Lᵀ` (or `Uᵀ * U`). Overwrites `a`.
944pub fn potrf<T: SolverScalar>(
945    handle: &DnHandle,
946    uplo: Fill,
947    n: i32,
948    a: &mut DeviceBuffer<T>,
949    lda: i32,
950    info: &mut DeviceBuffer<i32>,
951) -> Result<()> {
952    let mut lwork: c_int = 0;
953    check(unsafe { T::potrf_buf(handle.handle, uplo, n, a.as_raw().0 as *mut T, lda, &mut lwork) })?;
954    let workspace =
955        DeviceBuffer::<T>::new(a.context(), lwork as usize).map_err(alloc_fail)?;
956    check(unsafe {
957        T::potrf(
958            handle.handle,
959            uplo,
960            n,
961            a.as_raw().0 as *mut T,
962            lda,
963            workspace.as_raw().0 as *mut T,
964            lwork,
965            info.as_raw().0 as *mut c_int,
966        )
967    })
968}
969
970/// Solve `A * X = B` using the Cholesky factorization from [`potrf`].
971#[allow(clippy::too_many_arguments)]
972pub fn potrs<T: SolverScalar>(
973    handle: &DnHandle,
974    uplo: Fill,
975    n: i32,
976    nrhs: i32,
977    a: &DeviceBuffer<T>,
978    lda: i32,
979    b: &mut DeviceBuffer<T>,
980    ldb: i32,
981    info: &mut DeviceBuffer<i32>,
982) -> Result<()> {
983    check(unsafe {
984        T::potrs(
985            handle.handle,
986            uplo,
987            n,
988            nrhs,
989            a.as_raw().0 as *const T,
990            lda,
991            b.as_raw().0 as *mut T,
992            ldb,
993            info.as_raw().0 as *mut c_int,
994        )
995    })
996}
997
998/// Full SVD: `A = U * Σ * Vᵀ`. `jobu`/`jobvt` are LAPACK-style single-byte
999/// selectors (b'A' = all, b'S' = economy, b'N' = none, b'O' = overwrite A).
1000///
1001/// `rwork` must be provided for complex element types; pass an empty buffer
1002/// for real types (pointer is still non-null; cuSOLVER ignores it).
1003#[allow(clippy::too_many_arguments)]
1004pub fn gesvd<T: SolverScalar>(
1005    handle: &DnHandle,
1006    jobu: u8,
1007    jobvt: u8,
1008    m: i32,
1009    n: i32,
1010    a: &mut DeviceBuffer<T>,
1011    lda: i32,
1012    s: &mut DeviceBuffer<T::Real>,
1013    u: &mut DeviceBuffer<T>,
1014    ldu: i32,
1015    vt: &mut DeviceBuffer<T>,
1016    ldvt: i32,
1017    rwork: &mut DeviceBuffer<T::Real>,
1018    info: &mut DeviceBuffer<i32>,
1019) -> Result<()> {
1020    let mut lwork: c_int = 0;
1021    check(unsafe { T::gesvd_buf(handle.handle, m, n, &mut lwork) })?;
1022    let workspace =
1023        DeviceBuffer::<T>::new(a.context(), lwork as usize).map_err(alloc_fail)?;
1024    check(unsafe {
1025        T::gesvd(
1026            handle.handle,
1027            jobu,
1028            jobvt,
1029            m,
1030            n,
1031            a.as_raw().0 as *mut T,
1032            lda,
1033            s.as_raw().0 as *mut T::Real,
1034            u.as_raw().0 as *mut T,
1035            ldu,
1036            vt.as_raw().0 as *mut T,
1037            ldvt,
1038            workspace.as_raw().0 as *mut T,
1039            lwork,
1040            rwork.as_raw().0 as *mut T::Real,
1041            info.as_raw().0 as *mut c_int,
1042        )
1043    })
1044}
1045
1046/// Symmetric / Hermitian eigenvalue decomposition: `A = Q * diag(w) * Qᵀ`.
1047#[allow(clippy::too_many_arguments)]
1048pub fn syevd<T: SolverScalar>(
1049    handle: &DnHandle,
1050    jobz: EigMode,
1051    uplo: Fill,
1052    n: i32,
1053    a: &mut DeviceBuffer<T>,
1054    lda: i32,
1055    w: &mut DeviceBuffer<T::Real>,
1056    info: &mut DeviceBuffer<i32>,
1057) -> Result<()> {
1058    let mut lwork: c_int = 0;
1059    check(unsafe {
1060        T::syevd_buf(
1061            handle.handle,
1062            jobz,
1063            uplo,
1064            n,
1065            a.as_raw().0 as *const T,
1066            lda,
1067            w.as_raw().0 as *const T::Real,
1068            &mut lwork,
1069        )
1070    })?;
1071    let workspace =
1072        DeviceBuffer::<T>::new(a.context(), lwork as usize).map_err(alloc_fail)?;
1073    check(unsafe {
1074        T::syevd(
1075            handle.handle,
1076            jobz,
1077            uplo,
1078            n,
1079            a.as_raw().0 as *mut T,
1080            lda,
1081            w.as_raw().0 as *mut T::Real,
1082            workspace.as_raw().0 as *mut T,
1083            lwork,
1084            info.as_raw().0 as *mut c_int,
1085        )
1086    })
1087}
1088
1089// ---- Jacobi-based solvers (syevj / gesvdj) -----------------------------
1090
1091pub use baracuda_cusolver_sys::{gesvdjInfo_t as GesvdjInfoRaw, syevjInfo_t as SyevjInfoRaw};
1092
1093/// Jacobi-eigen tuning handle (tolerance + max sweeps).
1094#[derive(Debug)]
1095pub struct SyevjInfo {
1096    raw: SyevjInfoRaw,
1097}
1098
1099impl SyevjInfo {
1100    pub fn new() -> Result<Self> {
1101        let c = cusolver()?;
1102        let cu = c.cusolver_dn_create_syevj_info()?;
1103        let mut raw: SyevjInfoRaw = core::ptr::null_mut();
1104        check(unsafe { cu(&mut raw) })?;
1105        Ok(Self { raw })
1106    }
1107
1108    pub fn set_tolerance(&self, tol: f64) -> Result<()> {
1109        let c = cusolver()?;
1110        let cu = c.cusolver_dn_xsyevj_set_tolerance()?;
1111        check(unsafe { cu(self.raw, tol) })
1112    }
1113
1114    pub fn set_max_sweeps(&self, n: i32) -> Result<()> {
1115        let c = cusolver()?;
1116        let cu = c.cusolver_dn_xsyevj_set_max_sweeps()?;
1117        check(unsafe { cu(self.raw, n) })
1118    }
1119
1120    pub fn as_raw(&self) -> SyevjInfoRaw {
1121        self.raw
1122    }
1123}
1124
1125impl Drop for SyevjInfo {
1126    fn drop(&mut self) {
1127        if let Ok(c) = cusolver() {
1128            if let Ok(cu) = c.cusolver_dn_destroy_syevj_info() {
1129                let _ = unsafe { cu(self.raw) };
1130            }
1131        }
1132    }
1133}
1134
1135/// Jacobi-SVD tuning handle.
1136#[derive(Debug)]
1137pub struct GesvdjInfo {
1138    raw: GesvdjInfoRaw,
1139}
1140
1141impl GesvdjInfo {
1142    pub fn new() -> Result<Self> {
1143        let c = cusolver()?;
1144        let cu = c.cusolver_dn_create_gesvdj_info()?;
1145        let mut raw: GesvdjInfoRaw = core::ptr::null_mut();
1146        check(unsafe { cu(&mut raw) })?;
1147        Ok(Self { raw })
1148    }
1149
1150    pub fn as_raw(&self) -> GesvdjInfoRaw {
1151        self.raw
1152    }
1153}
1154
1155impl Drop for GesvdjInfo {
1156    fn drop(&mut self) {
1157        if let Ok(c) = cusolver() {
1158            if let Ok(cu) = c.cusolver_dn_destroy_gesvdj_info() {
1159                let _ = unsafe { cu(self.raw) };
1160            }
1161        }
1162    }
1163}
1164
1165/// Jacobi symmetric/Hermitian eigendecomposition (smaller matrices than
1166/// [`syevd`], faster convergence on well-conditioned problems).
1167#[allow(clippy::too_many_arguments)]
1168pub fn syevj<T: SolverScalar>(
1169    handle: &DnHandle,
1170    jobz: EigMode,
1171    uplo: Fill,
1172    n: i32,
1173    a: &mut DeviceBuffer<T>,
1174    lda: i32,
1175    w: &mut DeviceBuffer<T::Real>,
1176    info: &mut DeviceBuffer<i32>,
1177    params: &SyevjInfo,
1178) -> Result<()> {
1179    use baracuda_cusolver_sys::{
1180        cuComplex, cuDoubleComplex,
1181    };
1182    use core::mem;
1183
1184    let mut lwork: c_int = 0;
1185
1186    // Dispatch is simpler done via a type check, since syevj doesn't share
1187    // the generic trait shape (extra params: SyevjInfo).
1188    macro_rules! dispatch_real {
1189        ($t:ty, $bufsize:ident, $solve:ident) => {{
1190            let c = cusolver()?;
1191            check(unsafe {
1192                (c.$bufsize()?)(
1193                    handle.as_raw(),
1194                    jobz,
1195                    uplo,
1196                    n,
1197                    a.as_raw().0 as *const $t,
1198                    lda,
1199                    w.as_raw().0 as *const $t,
1200                    &mut lwork,
1201                    params.raw,
1202                )
1203            })?;
1204            let workspace =
1205                DeviceBuffer::<T>::new(a.context(), lwork as usize).map_err(alloc_fail)?;
1206            check(unsafe {
1207                (c.$solve()?)(
1208                    handle.as_raw(),
1209                    jobz,
1210                    uplo,
1211                    n,
1212                    a.as_raw().0 as *mut $t,
1213                    lda,
1214                    w.as_raw().0 as *mut $t,
1215                    workspace.as_raw().0 as *mut $t,
1216                    lwork,
1217                    info.as_raw().0 as *mut c_int,
1218                    params.raw,
1219                )
1220            })
1221        }};
1222    }
1223    macro_rules! dispatch_complex {
1224        ($t:ty, $real:ty, $raw:ty, $bufsize:ident, $solve:ident) => {{
1225            let c = cusolver()?;
1226            check(unsafe {
1227                (c.$bufsize()?)(
1228                    handle.as_raw(),
1229                    jobz,
1230                    uplo,
1231                    n,
1232                    a.as_raw().0 as *const $raw,
1233                    lda,
1234                    w.as_raw().0 as *const $real,
1235                    &mut lwork,
1236                    params.raw,
1237                )
1238            })?;
1239            let workspace =
1240                DeviceBuffer::<T>::new(a.context(), lwork as usize).map_err(alloc_fail)?;
1241            check(unsafe {
1242                (c.$solve()?)(
1243                    handle.as_raw(),
1244                    jobz,
1245                    uplo,
1246                    n,
1247                    a.as_raw().0 as *mut $raw,
1248                    lda,
1249                    w.as_raw().0 as *mut $real,
1250                    workspace.as_raw().0 as *mut $raw,
1251                    lwork,
1252                    info.as_raw().0 as *mut c_int,
1253                    params.raw,
1254                )
1255            })
1256        }};
1257    }
1258
1259    if mem::size_of::<T>() == mem::size_of::<f32>() && mem::size_of::<T::Real>() == 4 {
1260        dispatch_real!(f32, cusolver_dn_ssyevj_buffer_size, cusolver_dn_ssyevj)
1261    } else if mem::size_of::<T>() == mem::size_of::<f64>() && mem::size_of::<T::Real>() == 8 {
1262        dispatch_real!(f64, cusolver_dn_dsyevj_buffer_size, cusolver_dn_dsyevj)
1263    } else if mem::size_of::<T>() == mem::size_of::<Complex32>() {
1264        dispatch_complex!(
1265            Complex32,
1266            f32,
1267            cuComplex,
1268            cusolver_dn_cheevj_buffer_size,
1269            cusolver_dn_cheevj
1270        )
1271    } else {
1272        dispatch_complex!(
1273            Complex64,
1274            f64,
1275            cuDoubleComplex,
1276            cusolver_dn_zheevj_buffer_size,
1277            cusolver_dn_zheevj
1278        )
1279    }
1280}
1281
1282/// Jacobi SVD: `A = U * diag(s) * Vᴴ`. `econ` selects thin-SVD when set.
1283#[allow(clippy::too_many_arguments)]
1284pub fn gesvdj<T: SolverScalar>(
1285    handle: &DnHandle,
1286    jobz: EigMode,
1287    econ: bool,
1288    m: i32,
1289    n: i32,
1290    a: &mut DeviceBuffer<T>,
1291    lda: i32,
1292    s: &mut DeviceBuffer<T::Real>,
1293    u: &mut DeviceBuffer<T>,
1294    ldu: i32,
1295    v: &mut DeviceBuffer<T>,
1296    ldv: i32,
1297    info: &mut DeviceBuffer<i32>,
1298    params: &GesvdjInfo,
1299) -> Result<()> {
1300    use baracuda_cusolver_sys::{cuComplex, cuDoubleComplex};
1301    use core::mem;
1302
1303    let mut lwork: c_int = 0;
1304    let econ_i = if econ { 1 } else { 0 };
1305
1306    macro_rules! dispatch_real {
1307        ($t:ty, $bufsize:ident, $solve:ident) => {{
1308            let c = cusolver()?;
1309            check(unsafe {
1310                (c.$bufsize()?)(
1311                    handle.as_raw(),
1312                    jobz,
1313                    econ_i,
1314                    m,
1315                    n,
1316                    a.as_raw().0 as *const $t,
1317                    lda,
1318                    s.as_raw().0 as *const $t,
1319                    u.as_raw().0 as *const $t,
1320                    ldu,
1321                    v.as_raw().0 as *const $t,
1322                    ldv,
1323                    &mut lwork,
1324                    params.raw,
1325                )
1326            })?;
1327            let workspace =
1328                DeviceBuffer::<T>::new(a.context(), lwork as usize).map_err(alloc_fail)?;
1329            check(unsafe {
1330                (c.$solve()?)(
1331                    handle.as_raw(),
1332                    jobz,
1333                    econ_i,
1334                    m,
1335                    n,
1336                    a.as_raw().0 as *mut $t,
1337                    lda,
1338                    s.as_raw().0 as *mut $t,
1339                    u.as_raw().0 as *mut $t,
1340                    ldu,
1341                    v.as_raw().0 as *mut $t,
1342                    ldv,
1343                    workspace.as_raw().0 as *mut $t,
1344                    lwork,
1345                    info.as_raw().0 as *mut c_int,
1346                    params.raw,
1347                )
1348            })
1349        }};
1350    }
1351    macro_rules! dispatch_complex {
1352        ($t:ty, $real:ty, $raw:ty, $bufsize:ident, $solve:ident) => {{
1353            let c = cusolver()?;
1354            check(unsafe {
1355                (c.$bufsize()?)(
1356                    handle.as_raw(),
1357                    jobz,
1358                    econ_i,
1359                    m,
1360                    n,
1361                    a.as_raw().0 as *const $raw,
1362                    lda,
1363                    s.as_raw().0 as *const $real,
1364                    u.as_raw().0 as *const $raw,
1365                    ldu,
1366                    v.as_raw().0 as *const $raw,
1367                    ldv,
1368                    &mut lwork,
1369                    params.raw,
1370                )
1371            })?;
1372            let workspace =
1373                DeviceBuffer::<T>::new(a.context(), lwork as usize).map_err(alloc_fail)?;
1374            check(unsafe {
1375                (c.$solve()?)(
1376                    handle.as_raw(),
1377                    jobz,
1378                    econ_i,
1379                    m,
1380                    n,
1381                    a.as_raw().0 as *mut $raw,
1382                    lda,
1383                    s.as_raw().0 as *mut $real,
1384                    u.as_raw().0 as *mut $raw,
1385                    ldu,
1386                    v.as_raw().0 as *mut $raw,
1387                    ldv,
1388                    workspace.as_raw().0 as *mut $raw,
1389                    lwork,
1390                    info.as_raw().0 as *mut c_int,
1391                    params.raw,
1392                )
1393            })
1394        }};
1395    }
1396
1397    if mem::size_of::<T>() == mem::size_of::<f32>() && mem::size_of::<T::Real>() == 4 {
1398        dispatch_real!(f32, cusolver_dn_sgesvdj_buffer_size, cusolver_dn_sgesvdj)
1399    } else if mem::size_of::<T>() == mem::size_of::<f64>() && mem::size_of::<T::Real>() == 8 {
1400        dispatch_real!(f64, cusolver_dn_dgesvdj_buffer_size, cusolver_dn_dgesvdj)
1401    } else if mem::size_of::<T>() == mem::size_of::<Complex32>() {
1402        dispatch_complex!(
1403            Complex32,
1404            f32,
1405            cuComplex,
1406            cusolver_dn_cgesvdj_buffer_size,
1407            cusolver_dn_cgesvdj
1408        )
1409    } else {
1410        dispatch_complex!(
1411            Complex64,
1412            f64,
1413            cuDoubleComplex,
1414            cusolver_dn_zgesvdj_buffer_size,
1415            cusolver_dn_zgesvdj
1416        )
1417    }
1418}
1419
1420// ---- Generate / apply Q from QR (orgqr / ormqr) -------------------------
1421
1422/// Generate the orthogonal matrix `Q` from the factorization produced by
1423/// [`geqrf`]. After this, `a` holds the first `n` columns of `Q`.
1424#[allow(clippy::too_many_arguments)]
1425pub fn orgqr<T: SolverScalar>(
1426    handle: &DnHandle,
1427    m: i32,
1428    n: i32,
1429    k: i32,
1430    a: &mut DeviceBuffer<T>,
1431    lda: i32,
1432    tau: &DeviceBuffer<T>,
1433    info: &mut DeviceBuffer<i32>,
1434) -> Result<()> {
1435    use baracuda_cusolver_sys::{cuComplex, cuDoubleComplex};
1436    use core::mem;
1437
1438    let mut lwork: c_int = 0;
1439    macro_rules! dispatch {
1440        ($t:ty, $raw:ty, $bufsize:ident, $solve:ident) => {{
1441            let c = cusolver()?;
1442            check(unsafe {
1443                (c.$bufsize()?)(
1444                    handle.as_raw(),
1445                    m,
1446                    n,
1447                    k,
1448                    a.as_raw().0 as *const $raw,
1449                    lda,
1450                    tau.as_raw().0 as *const $raw,
1451                    &mut lwork,
1452                )
1453            })?;
1454            let workspace =
1455                DeviceBuffer::<T>::new(a.context(), lwork as usize).map_err(alloc_fail)?;
1456            check(unsafe {
1457                (c.$solve()?)(
1458                    handle.as_raw(),
1459                    m,
1460                    n,
1461                    k,
1462                    a.as_raw().0 as *mut $raw,
1463                    lda,
1464                    tau.as_raw().0 as *const $raw,
1465                    workspace.as_raw().0 as *mut $raw,
1466                    lwork,
1467                    info.as_raw().0 as *mut c_int,
1468                )
1469            })
1470        }};
1471    }
1472
1473    if mem::size_of::<T>() == mem::size_of::<f32>() && mem::size_of::<T::Real>() == 4 {
1474        dispatch!(f32, f32, cusolver_dn_sorgqr_buffer_size, cusolver_dn_sorgqr)
1475    } else if mem::size_of::<T>() == mem::size_of::<f64>() && mem::size_of::<T::Real>() == 8 {
1476        dispatch!(f64, f64, cusolver_dn_dorgqr_buffer_size, cusolver_dn_dorgqr)
1477    } else if mem::size_of::<T>() == mem::size_of::<Complex32>() {
1478        dispatch!(
1479            Complex32,
1480            cuComplex,
1481            cusolver_dn_cungqr_buffer_size,
1482            cusolver_dn_cungqr
1483        )
1484    } else {
1485        dispatch!(
1486            Complex64,
1487            cuDoubleComplex,
1488            cusolver_dn_zungqr_buffer_size,
1489            cusolver_dn_zungqr
1490        )
1491    }
1492}
1493
/// Side argument for [`ormqr`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Side {
    /// Apply from the left: `C = op(Q) * C`.
    Left = 0,
    /// Apply from the right: `C = C * op(Q)`.
    Right = 1,
}

impl Side {
    /// Integer code passed as the `side` argument of the cuSOLVER call
    /// (`Left` is 0, `Right` is 1), taken from the enum discriminant.
    fn raw(self) -> c_int {
        self as c_int
    }
}
1509
1510/// Apply `op(Q)` to `C`: `C = op(Q) * C` (Left) or `C = C * op(Q)` (Right),
1511/// where `Q` is packed in `a`+`tau` from [`geqrf`].
1512#[allow(clippy::too_many_arguments)]
1513pub fn ormqr<T: SolverScalar>(
1514    handle: &DnHandle,
1515    side: Side,
1516    trans: Op,
1517    m: i32,
1518    n: i32,
1519    k: i32,
1520    a: &DeviceBuffer<T>,
1521    lda: i32,
1522    tau: &DeviceBuffer<T>,
1523    c_mat: &mut DeviceBuffer<T>,
1524    ldc: i32,
1525    info: &mut DeviceBuffer<i32>,
1526) -> Result<()> {
1527    use baracuda_cusolver_sys::{cuComplex, cuDoubleComplex};
1528    use core::mem;
1529
1530    let mut lwork: c_int = 0;
1531    let side_i = side.raw();
1532    macro_rules! dispatch {
1533        ($t:ty, $raw:ty, $bufsize:ident, $solve:ident) => {{
1534            let ca = cusolver()?;
1535            check(unsafe {
1536                (ca.$bufsize()?)(
1537                    handle.as_raw(),
1538                    side_i,
1539                    trans.raw(),
1540                    m,
1541                    n,
1542                    k,
1543                    a.as_raw().0 as *const $raw,
1544                    lda,
1545                    tau.as_raw().0 as *const $raw,
1546                    c_mat.as_raw().0 as *const $raw,
1547                    ldc,
1548                    &mut lwork,
1549                )
1550            })?;
1551            let workspace =
1552                DeviceBuffer::<T>::new(c_mat.context(), lwork as usize).map_err(alloc_fail)?;
1553            check(unsafe {
1554                (ca.$solve()?)(
1555                    handle.as_raw(),
1556                    side_i,
1557                    trans.raw(),
1558                    m,
1559                    n,
1560                    k,
1561                    a.as_raw().0 as *const $raw,
1562                    lda,
1563                    tau.as_raw().0 as *const $raw,
1564                    c_mat.as_raw().0 as *mut $raw,
1565                    ldc,
1566                    workspace.as_raw().0 as *mut $raw,
1567                    lwork,
1568                    info.as_raw().0 as *mut c_int,
1569                )
1570            })
1571        }};
1572    }
1573
1574    if mem::size_of::<T>() == mem::size_of::<f32>() && mem::size_of::<T::Real>() == 4 {
1575        dispatch!(f32, f32, cusolver_dn_sormqr_buffer_size, cusolver_dn_sormqr)
1576    } else if mem::size_of::<T>() == mem::size_of::<f64>() && mem::size_of::<T::Real>() == 8 {
1577        dispatch!(f64, f64, cusolver_dn_dormqr_buffer_size, cusolver_dn_dormqr)
1578    } else if mem::size_of::<T>() == mem::size_of::<Complex32>() {
1579        dispatch!(
1580            Complex32,
1581            cuComplex,
1582            cusolver_dn_cunmqr_buffer_size,
1583            cusolver_dn_cunmqr
1584        )
1585    } else {
1586        dispatch!(
1587            Complex64,
1588            cuDoubleComplex,
1589            cusolver_dn_zunmqr_buffer_size,
1590            cusolver_dn_zunmqr
1591        )
1592    }
1593}
1594
1595// ---- gels: iterative-refinement least-squares solve ---------------------
1596
1597/// Solve `A * X = B` in the least-squares sense (iterative-refinement).
1598/// `A` is `m × n`, `B` is `m × nrhs`, `X` is `n × nrhs`. `A` and `B` may be
1599/// overwritten. Returns `iter`: number of refinement iterations used (-1 =
1600/// fallback to full precision).
1601#[allow(clippy::too_many_arguments)]
1602pub fn gels<T: SolverScalar>(
1603    handle: &DnHandle,
1604    m: i32,
1605    n: i32,
1606    nrhs: i32,
1607    a: &mut DeviceBuffer<T>,
1608    lda: i32,
1609    b: &mut DeviceBuffer<T>,
1610    ldb: i32,
1611    x: &mut DeviceBuffer<T>,
1612    ldx: i32,
1613    info: &mut DeviceBuffer<i32>,
1614) -> Result<i32> {
1615    use baracuda_cusolver_sys::{cuComplex, cuDoubleComplex};
1616    use core::mem;
1617
1618    let mut bytes: usize = 0;
1619
1620    macro_rules! dispatch {
1621        ($t:ty, $raw:ty, $bufsize:ident, $solve:ident) => {{
1622            let cs = cusolver()?;
1623            check(unsafe {
1624                (cs.$bufsize()?)(
1625                    handle.as_raw(),
1626                    m,
1627                    n,
1628                    nrhs,
1629                    a.as_raw().0 as *mut $raw,
1630                    lda,
1631                    b.as_raw().0 as *mut $raw,
1632                    ldb,
1633                    x.as_raw().0 as *mut $raw,
1634                    ldx,
1635                    core::ptr::null_mut(),
1636                    &mut bytes,
1637                )
1638            })?;
1639            // Allocate `bytes` worth of u8 workspace (rounding up to T units).
1640            let units = bytes.div_ceil(mem::size_of::<T>());
1641            let workspace =
1642                DeviceBuffer::<T>::new(a.context(), units).map_err(alloc_fail)?;
1643            let mut iter: c_int = 0;
1644            check(unsafe {
1645                (cs.$solve()?)(
1646                    handle.as_raw(),
1647                    m,
1648                    n,
1649                    nrhs,
1650                    a.as_raw().0 as *mut $raw,
1651                    lda,
1652                    b.as_raw().0 as *mut $raw,
1653                    ldb,
1654                    x.as_raw().0 as *mut $raw,
1655                    ldx,
1656                    workspace.as_raw().0 as *mut c_void,
1657                    bytes,
1658                    &mut iter,
1659                    info.as_raw().0 as *mut c_int,
1660                )
1661            })?;
1662            Ok(iter)
1663        }};
1664    }
1665
1666    if mem::size_of::<T>() == mem::size_of::<f32>() && mem::size_of::<T::Real>() == 4 {
1667        dispatch!(f32, f32, cusolver_dn_ssgels_buffer_size, cusolver_dn_ssgels)
1668    } else if mem::size_of::<T>() == mem::size_of::<f64>() && mem::size_of::<T::Real>() == 8 {
1669        dispatch!(f64, f64, cusolver_dn_ddgels_buffer_size, cusolver_dn_ddgels)
1670    } else if mem::size_of::<T>() == mem::size_of::<Complex32>() {
1671        dispatch!(
1672            Complex32,
1673            cuComplex,
1674            cusolver_dn_ccgels_buffer_size,
1675            cusolver_dn_ccgels
1676        )
1677    } else {
1678        dispatch!(
1679            Complex64,
1680            cuDoubleComplex,
1681            cusolver_dn_zzgels_buffer_size,
1682            cusolver_dn_zzgels
1683        )
1684    }
1685}
1686
1687// ---- potri: inverse from Cholesky factor --------------------------------
1688
1689/// Compute `A = (Lᵀ * L)⁻¹` or `A = (U * Uᵀ)⁻¹` given the Cholesky factor
1690/// already stored in the triangle selected by `uplo`. `a` must hold the
1691/// output of [`potrf`] in-place.
1692pub fn potri<T: SolverScalar>(
1693    handle: &DnHandle,
1694    uplo: Fill,
1695    n: i32,
1696    a: &mut DeviceBuffer<T>,
1697    lda: i32,
1698    info: &mut DeviceBuffer<i32>,
1699) -> Result<()> {
1700    use baracuda_cusolver_sys::{cuComplex, cuDoubleComplex};
1701    use core::mem;
1702
1703    let mut lwork: c_int = 0;
1704    macro_rules! dispatch {
1705        ($t:ty, $raw:ty, $bufsize:ident, $solve:ident) => {{
1706            let cs = cusolver()?;
1707            check(unsafe {
1708                (cs.$bufsize()?)(
1709                    handle.as_raw(),
1710                    uplo,
1711                    n,
1712                    a.as_raw().0 as *mut $raw,
1713                    lda,
1714                    &mut lwork,
1715                )
1716            })?;
1717            let workspace =
1718                DeviceBuffer::<T>::new(a.context(), lwork as usize).map_err(alloc_fail)?;
1719            check(unsafe {
1720                (cs.$solve()?)(
1721                    handle.as_raw(),
1722                    uplo,
1723                    n,
1724                    a.as_raw().0 as *mut $raw,
1725                    lda,
1726                    workspace.as_raw().0 as *mut $raw,
1727                    lwork,
1728                    info.as_raw().0 as *mut c_int,
1729                )
1730            })
1731        }};
1732    }
1733
1734    if mem::size_of::<T>() == mem::size_of::<f32>() && mem::size_of::<T::Real>() == 4 {
1735        dispatch!(f32, f32, cusolver_dn_spotri_buffer_size, cusolver_dn_spotri)
1736    } else if mem::size_of::<T>() == mem::size_of::<f64>() && mem::size_of::<T::Real>() == 8 {
1737        dispatch!(f64, f64, cusolver_dn_dpotri_buffer_size, cusolver_dn_dpotri)
1738    } else if mem::size_of::<T>() == mem::size_of::<Complex32>() {
1739        dispatch!(
1740            Complex32,
1741            cuComplex,
1742            cusolver_dn_cpotri_buffer_size,
1743            cusolver_dn_cpotri
1744        )
1745    } else {
1746        dispatch!(
1747            Complex64,
1748            cuDoubleComplex,
1749            cusolver_dn_zpotri_buffer_size,
1750            cusolver_dn_zpotri
1751        )
1752    }
1753}
1754
1755// ---- Batched Jacobi eigen / SVD -----------------------------------------
1756
1757/// Batched Jacobi symmetric/Hermitian eigendecomposition. Every matrix in
1758/// the batch is `n × n` and stride `n × n`. `w` holds `n * batch_size`
1759/// eigenvalues, strided by `n`.
1760#[allow(clippy::too_many_arguments)]
1761pub fn syevj_batched<T: SolverScalar>(
1762    handle: &DnHandle,
1763    jobz: EigMode,
1764    uplo: Fill,
1765    n: i32,
1766    a: &mut DeviceBuffer<T>,
1767    lda: i32,
1768    w: &mut DeviceBuffer<T::Real>,
1769    info: &mut DeviceBuffer<i32>,
1770    params: &SyevjInfo,
1771    batch_size: i32,
1772) -> Result<()> {
1773    use baracuda_cusolver_sys::{cuComplex, cuDoubleComplex};
1774    use core::mem;
1775
1776    let mut lwork: c_int = 0;
1777    macro_rules! dispatch_real {
1778        ($t:ty, $bufsize:ident, $solve:ident) => {{
1779            let c = cusolver()?;
1780            check(unsafe {
1781                (c.$bufsize()?)(
1782                    handle.as_raw(),
1783                    jobz,
1784                    uplo,
1785                    n,
1786                    a.as_raw().0 as *const $t,
1787                    lda,
1788                    w.as_raw().0 as *const $t,
1789                    &mut lwork,
1790                    params.as_raw(),
1791                    batch_size,
1792                )
1793            })?;
1794            let workspace =
1795                DeviceBuffer::<T>::new(a.context(), lwork as usize).map_err(alloc_fail)?;
1796            check(unsafe {
1797                (c.$solve()?)(
1798                    handle.as_raw(),
1799                    jobz,
1800                    uplo,
1801                    n,
1802                    a.as_raw().0 as *mut $t,
1803                    lda,
1804                    w.as_raw().0 as *mut $t,
1805                    workspace.as_raw().0 as *mut $t,
1806                    lwork,
1807                    info.as_raw().0 as *mut c_int,
1808                    params.as_raw(),
1809                    batch_size,
1810                )
1811            })
1812        }};
1813    }
1814    macro_rules! dispatch_complex {
1815        ($t:ty, $real:ty, $raw:ty, $bufsize:ident, $solve:ident) => {{
1816            let c = cusolver()?;
1817            check(unsafe {
1818                (c.$bufsize()?)(
1819                    handle.as_raw(),
1820                    jobz,
1821                    uplo,
1822                    n,
1823                    a.as_raw().0 as *const $raw,
1824                    lda,
1825                    w.as_raw().0 as *const $real,
1826                    &mut lwork,
1827                    params.as_raw(),
1828                    batch_size,
1829                )
1830            })?;
1831            let workspace =
1832                DeviceBuffer::<T>::new(a.context(), lwork as usize).map_err(alloc_fail)?;
1833            check(unsafe {
1834                (c.$solve()?)(
1835                    handle.as_raw(),
1836                    jobz,
1837                    uplo,
1838                    n,
1839                    a.as_raw().0 as *mut $raw,
1840                    lda,
1841                    w.as_raw().0 as *mut $real,
1842                    workspace.as_raw().0 as *mut $raw,
1843                    lwork,
1844                    info.as_raw().0 as *mut c_int,
1845                    params.as_raw(),
1846                    batch_size,
1847                )
1848            })
1849        }};
1850    }
1851
1852    if mem::size_of::<T>() == mem::size_of::<f32>() && mem::size_of::<T::Real>() == 4 {
1853        dispatch_real!(
1854            f32,
1855            cusolver_dn_ssyevj_batched_buffer_size,
1856            cusolver_dn_ssyevj_batched
1857        )
1858    } else if mem::size_of::<T>() == mem::size_of::<f64>() && mem::size_of::<T::Real>() == 8 {
1859        dispatch_real!(
1860            f64,
1861            cusolver_dn_dsyevj_batched_buffer_size,
1862            cusolver_dn_dsyevj_batched
1863        )
1864    } else if mem::size_of::<T>() == mem::size_of::<Complex32>() {
1865        dispatch_complex!(
1866            Complex32,
1867            f32,
1868            cuComplex,
1869            cusolver_dn_cheevj_batched_buffer_size,
1870            cusolver_dn_cheevj_batched
1871        )
1872    } else {
1873        dispatch_complex!(
1874            Complex64,
1875            f64,
1876            cuDoubleComplex,
1877            cusolver_dn_zheevj_batched_buffer_size,
1878            cusolver_dn_zheevj_batched
1879        )
1880    }
1881}
1882
1883/// Batched Jacobi SVD: batch of `m × n` matrices with stride `m×n`.
1884#[allow(clippy::too_many_arguments)]
1885pub fn gesvdj_batched<T: SolverScalar>(
1886    handle: &DnHandle,
1887    jobz: EigMode,
1888    m: i32,
1889    n: i32,
1890    a: &mut DeviceBuffer<T>,
1891    lda: i32,
1892    s: &mut DeviceBuffer<T::Real>,
1893    u: &mut DeviceBuffer<T>,
1894    ldu: i32,
1895    v: &mut DeviceBuffer<T>,
1896    ldv: i32,
1897    info: &mut DeviceBuffer<i32>,
1898    params: &GesvdjInfo,
1899    batch_size: i32,
1900) -> Result<()> {
1901    use baracuda_cusolver_sys::{cuComplex, cuDoubleComplex};
1902    use core::mem;
1903
1904    let mut lwork: c_int = 0;
1905    macro_rules! dispatch_real {
1906        ($t:ty, $bufsize:ident, $solve:ident) => {{
1907            let c = cusolver()?;
1908            check(unsafe {
1909                (c.$bufsize()?)(
1910                    handle.as_raw(),
1911                    jobz,
1912                    m,
1913                    n,
1914                    a.as_raw().0 as *const $t,
1915                    lda,
1916                    s.as_raw().0 as *const $t,
1917                    u.as_raw().0 as *const $t,
1918                    ldu,
1919                    v.as_raw().0 as *const $t,
1920                    ldv,
1921                    &mut lwork,
1922                    params.as_raw(),
1923                    batch_size,
1924                )
1925            })?;
1926            let workspace =
1927                DeviceBuffer::<T>::new(a.context(), lwork as usize).map_err(alloc_fail)?;
1928            check(unsafe {
1929                (c.$solve()?)(
1930                    handle.as_raw(),
1931                    jobz,
1932                    m,
1933                    n,
1934                    a.as_raw().0 as *mut $t,
1935                    lda,
1936                    s.as_raw().0 as *mut $t,
1937                    u.as_raw().0 as *mut $t,
1938                    ldu,
1939                    v.as_raw().0 as *mut $t,
1940                    ldv,
1941                    workspace.as_raw().0 as *mut $t,
1942                    lwork,
1943                    info.as_raw().0 as *mut c_int,
1944                    params.as_raw(),
1945                    batch_size,
1946                )
1947            })
1948        }};
1949    }
1950    macro_rules! dispatch_complex {
1951        ($t:ty, $real:ty, $raw:ty, $bufsize:ident, $solve:ident) => {{
1952            let c = cusolver()?;
1953            check(unsafe {
1954                (c.$bufsize()?)(
1955                    handle.as_raw(),
1956                    jobz,
1957                    m,
1958                    n,
1959                    a.as_raw().0 as *const $raw,
1960                    lda,
1961                    s.as_raw().0 as *const $real,
1962                    u.as_raw().0 as *const $raw,
1963                    ldu,
1964                    v.as_raw().0 as *const $raw,
1965                    ldv,
1966                    &mut lwork,
1967                    params.as_raw(),
1968                    batch_size,
1969                )
1970            })?;
1971            let workspace =
1972                DeviceBuffer::<T>::new(a.context(), lwork as usize).map_err(alloc_fail)?;
1973            check(unsafe {
1974                (c.$solve()?)(
1975                    handle.as_raw(),
1976                    jobz,
1977                    m,
1978                    n,
1979                    a.as_raw().0 as *mut $raw,
1980                    lda,
1981                    s.as_raw().0 as *mut $real,
1982                    u.as_raw().0 as *mut $raw,
1983                    ldu,
1984                    v.as_raw().0 as *mut $raw,
1985                    ldv,
1986                    workspace.as_raw().0 as *mut $raw,
1987                    lwork,
1988                    info.as_raw().0 as *mut c_int,
1989                    params.as_raw(),
1990                    batch_size,
1991                )
1992            })
1993        }};
1994    }
1995
1996    if mem::size_of::<T>() == mem::size_of::<f32>() && mem::size_of::<T::Real>() == 4 {
1997        dispatch_real!(
1998            f32,
1999            cusolver_dn_sgesvdj_batched_buffer_size,
2000            cusolver_dn_sgesvdj_batched
2001        )
2002    } else if mem::size_of::<T>() == mem::size_of::<f64>() && mem::size_of::<T::Real>() == 8 {
2003        dispatch_real!(
2004            f64,
2005            cusolver_dn_dgesvdj_batched_buffer_size,
2006            cusolver_dn_dgesvdj_batched
2007        )
2008    } else if mem::size_of::<T>() == mem::size_of::<Complex32>() {
2009        dispatch_complex!(
2010            Complex32,
2011            f32,
2012            cuComplex,
2013            cusolver_dn_cgesvdj_batched_buffer_size,
2014            cusolver_dn_cgesvdj_batched
2015        )
2016    } else {
2017        dispatch_complex!(
2018            Complex64,
2019            f64,
2020            cuDoubleComplex,
2021            cusolver_dn_zgesvdj_batched_buffer_size,
2022            cusolver_dn_zgesvdj_batched
2023        )
2024    }
2025}
2026
2027// ---- cuSOLVERMg: multi-GPU dense solvers --------------------------------
2028
2029pub mod mg {
2030    //! Multi-GPU dense solvers via `libcusolverMg`. Shares dimensions with
2031    //! the single-GPU API but takes arrays of device pointers (one per
2032    //! physical GPU after [`Handle::device_select`]).
2033
2034    use core::ffi::{c_int, c_void};
2035
2036    use baracuda_cusolver_sys::{
2037        cudaDataType, cudaLibMgGrid_t, cudaLibMgMatrixDesc_t, cusolver_mg, cusolverMgHandle_t,
2038    };
2039
2040    use super::{alloc_fail, check, EigMode, Fill, Result};
2041
2042    /// Multi-GPU cuSOLVER handle.
2043    #[derive(Debug)]
2044    pub struct Handle {
2045        raw: cusolverMgHandle_t,
2046    }
2047
2048    impl Handle {
2049        pub fn new() -> Result<Self> {
2050            let mg = cusolver_mg()?;
2051            let cu = mg.cusolver_mg_create()?;
2052            let mut h: cusolverMgHandle_t = core::ptr::null_mut();
2053            check(unsafe { cu(&mut h) })?;
2054            Ok(Self { raw: h })
2055        }
2056
2057        /// Assign a set of physical CUDA devices to this handle. Future
2058        /// factorizations will stripe across them.
2059        pub fn device_select(&self, devices: &[i32]) -> Result<()> {
2060            let mg = cusolver_mg()?;
2061            let cu = mg.cusolver_mg_device_select()?;
2062            check(unsafe { cu(self.raw, devices.len() as c_int, devices.as_ptr()) })
2063        }
2064
2065        pub fn as_raw(&self) -> cusolverMgHandle_t {
2066            self.raw
2067        }
2068    }
2069
2070    impl Drop for Handle {
2071        fn drop(&mut self) {
2072            if let Ok(mg) = cusolver_mg() {
2073                if let Ok(cu) = mg.cusolver_mg_destroy() {
2074                    let _ = unsafe { cu(self.raw) };
2075                }
2076            }
2077        }
2078    }
2079
2080    /// A device grid — assigns distribution roles to physical devices.
2081    #[derive(Debug)]
2082    pub struct DeviceGrid {
2083        raw: cudaLibMgGrid_t,
2084    }
2085
2086    impl DeviceGrid {
2087        /// `mapping` is typically `CUDALIBMG_GRID_MAPPING_COL_MAJOR (1)`.
2088        pub fn new(num_row_devices: i32, num_col_devices: i32, devices: &[i32], mapping: i32) -> Result<Self> {
2089            let mg = cusolver_mg()?;
2090            let cu = mg.cusolver_mg_create_device_grid()?;
2091            let mut raw: cudaLibMgGrid_t = core::ptr::null_mut();
2092            check(unsafe {
2093                cu(
2094                    &mut raw,
2095                    num_row_devices,
2096                    num_col_devices,
2097                    devices.as_ptr(),
2098                    mapping,
2099                )
2100            })?;
2101            Ok(Self { raw })
2102        }
2103
2104        pub fn as_raw(&self) -> cudaLibMgGrid_t {
2105            self.raw
2106        }
2107    }
2108
2109    impl Drop for DeviceGrid {
2110        fn drop(&mut self) {
2111            if let Ok(mg) = cusolver_mg() {
2112                if let Ok(cu) = mg.cusolver_mg_destroy_grid() {
2113                    let _ = unsafe { cu(self.raw) };
2114                }
2115            }
2116        }
2117    }
2118
2119    /// Matrix-distribution descriptor.
2120    #[derive(Debug)]
2121    pub struct MatrixDesc {
2122        raw: cudaLibMgMatrixDesc_t,
2123    }
2124
2125    impl MatrixDesc {
2126        pub fn new(
2127            num_rows: i64,
2128            num_cols: i64,
2129            row_block_size: i64,
2130            col_block_size: i64,
2131            data_type: cudaDataType,
2132            grid: &DeviceGrid,
2133        ) -> Result<Self> {
2134            let mg = cusolver_mg()?;
2135            let cu = mg.cusolver_mg_create_matrix_desc()?;
2136            let mut raw: cudaLibMgMatrixDesc_t = core::ptr::null_mut();
2137            check(unsafe {
2138                cu(
2139                    &mut raw,
2140                    num_rows,
2141                    num_cols,
2142                    row_block_size,
2143                    col_block_size,
2144                    data_type,
2145                    grid.as_raw(),
2146                )
2147            })?;
2148            Ok(Self { raw })
2149        }
2150
2151        pub fn as_raw(&self) -> cudaLibMgMatrixDesc_t {
2152            self.raw
2153        }
2154    }
2155
2156    impl Drop for MatrixDesc {
2157        fn drop(&mut self) {
2158            if let Ok(mg) = cusolver_mg() {
2159                if let Ok(cu) = mg.cusolver_mg_destroy_matrix_desc() {
2160                    let _ = unsafe { cu(self.raw) };
2161                }
2162            }
2163        }
2164    }
2165
2166    /// Multi-GPU LU buffer-size query.
2167    ///
2168    /// # Safety
2169    /// `array_d_a`, `array_d_ipiv` must be host arrays of device pointers
2170    /// matching the selected devices.
2171    #[allow(clippy::too_many_arguments)]
2172    pub unsafe fn getrf_buffer_size(
2173        handle: &Handle,
2174        m: i32,
2175        n: i32,
2176        array_d_a: *mut *mut c_void,
2177        ia: i32,
2178        ja: i32,
2179        desc_a: &MatrixDesc,
2180        array_d_ipiv: *mut *mut c_int,
2181        compute_type: cudaDataType,
2182    ) -> Result<i64> {
2183        let mg = cusolver_mg()?;
2184        let cu = mg.cusolver_mg_getrf_buffer_size()?;
2185        let mut lwork: i64 = 0;
2186        check(cu(
2187            handle.as_raw(),
2188            m,
2189            n,
2190            array_d_a,
2191            ia,
2192            ja,
2193            desc_a.as_raw(),
2194            array_d_ipiv,
2195            compute_type,
2196            &mut lwork,
2197        ))?;
2198        Ok(lwork)
2199    }
2200
2201    /// # Safety
2202    /// Same pointer-array requirements as [`getrf_buffer_size`].
2203    #[allow(clippy::too_many_arguments)]
2204    pub unsafe fn getrf(
2205        handle: &Handle,
2206        m: i32,
2207        n: i32,
2208        array_d_a: *mut *mut c_void,
2209        ia: i32,
2210        ja: i32,
2211        desc_a: &MatrixDesc,
2212        array_d_ipiv: *mut *mut c_int,
2213        compute_type: cudaDataType,
2214        array_d_work: *mut *mut c_void,
2215        lwork: i64,
2216        info: &mut [c_int],
2217    ) -> Result<()> {
2218        let mg = cusolver_mg()?;
2219        let cu = mg.cusolver_mg_getrf()?;
2220        let _ = alloc_fail::<()>; // silence unused-import in release builds
2221        check(cu(
2222            handle.as_raw(),
2223            m,
2224            n,
2225            array_d_a,
2226            ia,
2227            ja,
2228            desc_a.as_raw(),
2229            array_d_ipiv,
2230            compute_type,
2231            array_d_work,
2232            lwork,
2233            info.as_mut_ptr(),
2234        ))
2235    }
2236
2237    /// Multi-GPU Cholesky buffer-size.
2238    ///
2239    /// # Safety
2240    /// Same as [`getrf_buffer_size`].
2241    #[allow(clippy::too_many_arguments)]
2242    pub unsafe fn potrf_buffer_size(
2243        handle: &Handle,
2244        uplo: Fill,
2245        n: i32,
2246        array_d_a: *mut *mut c_void,
2247        ia: i32,
2248        ja: i32,
2249        desc_a: &MatrixDesc,
2250        compute_type: cudaDataType,
2251    ) -> Result<i64> {
2252        let mg = cusolver_mg()?;
2253        let cu = mg.cusolver_mg_potrf_buffer_size()?;
2254        let mut lwork: i64 = 0;
2255        check(cu(
2256            handle.as_raw(),
2257            uplo,
2258            n,
2259            array_d_a,
2260            ia,
2261            ja,
2262            desc_a.as_raw(),
2263            compute_type,
2264            &mut lwork,
2265        ))?;
2266        Ok(lwork)
2267    }
2268
2269    /// # Safety
2270    /// Same as [`getrf_buffer_size`].
2271    #[allow(clippy::too_many_arguments)]
2272    pub unsafe fn potrf(
2273        handle: &Handle,
2274        uplo: Fill,
2275        n: i32,
2276        array_d_a: *mut *mut c_void,
2277        ia: i32,
2278        ja: i32,
2279        desc_a: &MatrixDesc,
2280        compute_type: cudaDataType,
2281        array_d_work: *mut *mut c_void,
2282        lwork: i64,
2283        info: &mut [c_int],
2284    ) -> Result<()> {
2285        let mg = cusolver_mg()?;
2286        let cu = mg.cusolver_mg_potrf()?;
2287        check(cu(
2288            handle.as_raw(),
2289            uplo,
2290            n,
2291            array_d_a,
2292            ia,
2293            ja,
2294            desc_a.as_raw(),
2295            compute_type,
2296            array_d_work,
2297            lwork,
2298            info.as_mut_ptr(),
2299        ))
2300    }
2301
2302    /// Multi-GPU symmetric eigendecomposition buffer-size.
2303    ///
2304    /// # Safety
2305    /// Same as [`getrf_buffer_size`].
2306    #[allow(clippy::too_many_arguments)]
2307    pub unsafe fn syevd_buffer_size(
2308        handle: &Handle,
2309        jobz: EigMode,
2310        uplo: Fill,
2311        n: i32,
2312        array_d_a: *mut *mut c_void,
2313        ia: i32,
2314        ja: i32,
2315        desc_a: &MatrixDesc,
2316        w: *mut c_void,
2317        data_type_w: cudaDataType,
2318        compute_type: cudaDataType,
2319    ) -> Result<i64> {
2320        let mg = cusolver_mg()?;
2321        let cu = mg.cusolver_mg_syevd_buffer_size()?;
2322        let mut lwork: i64 = 0;
2323        check(cu(
2324            handle.as_raw(),
2325            jobz,
2326            uplo,
2327            n,
2328            array_d_a,
2329            ia,
2330            ja,
2331            desc_a.as_raw(),
2332            w,
2333            data_type_w,
2334            compute_type,
2335            &mut lwork,
2336        ))?;
2337        Ok(lwork)
2338    }
2339
2340    /// # Safety
2341    /// Same as [`getrf_buffer_size`].
2342    #[allow(clippy::too_many_arguments)]
2343    pub unsafe fn syevd(
2344        handle: &Handle,
2345        jobz: EigMode,
2346        uplo: Fill,
2347        n: i32,
2348        array_d_a: *mut *mut c_void,
2349        ia: i32,
2350        ja: i32,
2351        desc_a: &MatrixDesc,
2352        w: *mut c_void,
2353        data_type_w: cudaDataType,
2354        compute_type: cudaDataType,
2355        array_d_work: *mut *mut c_void,
2356        lwork: i64,
2357        info: &mut [c_int],
2358    ) -> Result<()> {
2359        let mg = cusolver_mg()?;
2360        let cu = mg.cusolver_mg_syevd()?;
2361        check(cu(
2362            handle.as_raw(),
2363            jobz,
2364            uplo,
2365            n,
2366            array_d_a,
2367            ia,
2368            ja,
2369            desc_a.as_raw(),
2370            w,
2371            data_type_w,
2372            compute_type,
2373            array_d_work,
2374            lwork,
2375            info.as_mut_ptr(),
2376        ))
2377    }
2378}
2379
2380// ---- Back-compat: single-precision shortcuts -----------------------------
2381
2382/// Shortcut for [`getrf`] on `f32`.
2383pub fn sgetrf(
2384    handle: &DnHandle,
2385    m: i32,
2386    n: i32,
2387    a: &mut DeviceBuffer<f32>,
2388    lda: i32,
2389    ipiv: &mut DeviceBuffer<i32>,
2390    info: &mut DeviceBuffer<i32>,
2391) -> Result<()> {
2392    getrf::<f32>(handle, m, n, a, lda, ipiv, info)
2393}
2394
2395/// Shortcut for [`getrs`] on `f32`.
2396#[allow(clippy::too_many_arguments)]
2397pub fn sgetrs(
2398    handle: &DnHandle,
2399    trans: Op,
2400    n: i32,
2401    nrhs: i32,
2402    a: &DeviceBuffer<f32>,
2403    lda: i32,
2404    ipiv: &DeviceBuffer<i32>,
2405    b: &mut DeviceBuffer<f32>,
2406    ldb: i32,
2407    info: &mut DeviceBuffer<i32>,
2408) -> Result<()> {
2409    getrs::<f32>(handle, trans, n, nrhs, a, lda, ipiv, b, ldb, info)
2410}
2411
2412// ---- Generic X... (64-bit-size, type-erased) ----------------------------
2413
2414pub mod xapi {
2415    //! The generic 64-bit cuSOLVER API (`cusolverDnX*`). Matrix dimensions
2416    //! are `i64`; element types are passed at call-time as
2417    //! [`cudaDataType`]. Workspace sizes are split between on-device and
2418    //! on-host buffers.
2419
2420    use super::*;
2421    use baracuda_cusolver_sys::{cudaDataType, cusolverDnParams_t};
2422
2423    #[derive(Debug)]
2424    pub struct Params {
2425        raw: cusolverDnParams_t,
2426    }
2427
2428    impl Params {
2429        pub fn new() -> Result<Self> {
2430            let c = cusolver()?;
2431            let cu = c.cusolver_dn_create_params()?;
2432            let mut p: cusolverDnParams_t = core::ptr::null_mut();
2433            check(unsafe { cu(&mut p) })?;
2434            Ok(Self { raw: p })
2435        }
2436
2437        pub fn as_raw(&self) -> cusolverDnParams_t {
2438            self.raw
2439        }
2440    }
2441
2442    impl Drop for Params {
2443        fn drop(&mut self) {
2444            if let Ok(c) = cusolver() {
2445                if let Ok(cu) = c.cusolver_dn_destroy_params() {
2446                    let _ = unsafe { cu(self.raw) };
2447                }
2448            }
2449        }
2450    }
2451
2452    /// Buffer-size query for generic LU factorization. Returns
2453    /// `(workspace_bytes_on_device, workspace_bytes_on_host)`.
2454    #[allow(clippy::too_many_arguments)]
2455    pub fn xgetrf_buffer_size(
2456        handle: &DnHandle,
2457        params: &Params,
2458        m: i64,
2459        n: i64,
2460        data_type_a: cudaDataType,
2461        a: *const c_void,
2462        lda: i64,
2463        compute_type: cudaDataType,
2464    ) -> Result<(usize, usize)> {
2465        let c = cusolver()?;
2466        let cu = c.cusolver_dn_xgetrf_buffer_size()?;
2467        let (mut dev, mut host) = (0usize, 0usize);
2468        check(unsafe {
2469            cu(
2470                handle.as_raw(),
2471                params.raw,
2472                m,
2473                n,
2474                data_type_a,
2475                a,
2476                lda,
2477                compute_type,
2478                &mut dev,
2479                &mut host,
2480            )
2481        })?;
2482        Ok((dev, host))
2483    }
2484
2485    #[allow(clippy::too_many_arguments)]
2486    pub unsafe fn xgetrf(
2487        handle: &DnHandle,
2488        params: &Params,
2489        m: i64,
2490        n: i64,
2491        data_type_a: cudaDataType,
2492        a: *mut c_void,
2493        lda: i64,
2494        ipiv: *mut i64,
2495        compute_type: cudaDataType,
2496        device_buf: *mut c_void,
2497        device_bytes: usize,
2498        host_buf: *mut c_void,
2499        host_bytes: usize,
2500        info: *mut c_int,
2501    ) -> Result<()> {
2502        let c = cusolver()?;
2503        let cu = c.cusolver_dn_xgetrf()?;
2504        check(cu(
2505            handle.as_raw(),
2506            params.raw,
2507            m,
2508            n,
2509            data_type_a,
2510            a,
2511            lda,
2512            ipiv,
2513            compute_type,
2514            device_buf,
2515            device_bytes,
2516            host_buf,
2517            host_bytes,
2518            info,
2519        ))
2520    }
2521}
2522
2523// ---- Sparse --------------------------------------------------------------
2524
2525pub mod sparse {
2526    //! `cusolverSp*` — solve sparse linear systems via Cholesky or QR.
2527
2528    use super::*;
2529    use baracuda_cusolver_sys::cusolverSpHandle_t;
2530    use core::ffi::c_int;
2531
2532    #[derive(Debug)]
2533    pub struct SpHandle {
2534        raw: cusolverSpHandle_t,
2535        _not_send: PhantomData<*mut ()>,
2536    }
2537
2538    impl SpHandle {
2539        pub fn new() -> Result<Self> {
2540            let c = cusolver()?;
2541            let cu = c.cusolver_sp_create()?;
2542            let mut h: cusolverSpHandle_t = core::ptr::null_mut();
2543            check(unsafe { cu(&mut h) })?;
2544            Ok(Self {
2545                raw: h,
2546                _not_send: PhantomData,
2547            })
2548        }
2549
2550        pub fn set_stream(&self, stream: &Stream) -> Result<()> {
2551            let c = cusolver()?;
2552            let cu = c.cusolver_sp_set_stream()?;
2553            check(unsafe { cu(self.raw, stream.as_raw() as _) })
2554        }
2555
2556        pub fn as_raw(&self) -> cusolverSpHandle_t {
2557            self.raw
2558        }
2559    }
2560
2561    impl Drop for SpHandle {
2562        fn drop(&mut self) {
2563            if let Ok(c) = cusolver() {
2564                if let Ok(cu) = c.cusolver_sp_destroy() {
2565                    let _ = unsafe { cu(self.raw) };
2566                }
2567            }
2568        }
2569    }
2570
2571    /// Sparse Cholesky solve: `A * x = b` for SPD `A`.
2572    ///
2573    /// # Safety
2574    /// `descr_a`, CSR arrays, b and x must live on-device (b + x on-device,
2575    /// CSR arrays + descriptor on-device) and satisfy cuSOLVER sparse
2576    /// format requirements.
2577    #[allow(clippy::too_many_arguments)]
2578    pub unsafe fn scsrlsvchol(
2579        handle: &SpHandle,
2580        m: i32,
2581        nnz: i32,
2582        descr_a: *mut c_void,
2583        csr_val: *const f32,
2584        csr_row_ptr: *const c_int,
2585        csr_col_ind: *const c_int,
2586        b: *const f32,
2587        tol: f32,
2588        reorder: i32,
2589        x: *mut f32,
2590        singularity: *mut c_int,
2591    ) -> Result<()> {
2592        let c = cusolver()?;
2593        let cu = c.cusolver_sp_scsrlsvchol()?;
2594        check(cu(
2595            handle.raw,
2596            m,
2597            nnz,
2598            descr_a,
2599            csr_val,
2600            csr_row_ptr,
2601            csr_col_ind,
2602            b,
2603            tol,
2604            reorder,
2605            x,
2606            singularity,
2607        ))
2608    }
2609
2610    /// Sparse QR solve (least-squares, handles non-SPD systems).
2611    ///
2612    /// # Safety
2613    /// Same as [`scsrlsvchol`].
2614    #[allow(clippy::too_many_arguments)]
2615    pub unsafe fn scsrlsvqr(
2616        handle: &SpHandle,
2617        m: i32,
2618        nnz: i32,
2619        descr_a: *mut c_void,
2620        csr_val: *const f32,
2621        csr_row_ptr: *const c_int,
2622        csr_col_ind: *const c_int,
2623        b: *const f32,
2624        tol: f32,
2625        reorder: i32,
2626        x: *mut f32,
2627        singularity: *mut c_int,
2628    ) -> Result<()> {
2629        let c = cusolver()?;
2630        let cu = c.cusolver_sp_scsrlsvqr()?;
2631        check(cu(
2632            handle.raw,
2633            m,
2634            nnz,
2635            descr_a,
2636            csr_val,
2637            csr_row_ptr,
2638            csr_col_ind,
2639            b,
2640            tol,
2641            reorder,
2642            x,
2643            singularity,
2644        ))
2645    }
2646}
2647
2648// ---- Refactor ------------------------------------------------------------
2649
2650pub mod refactor {
2651    //! `cusolverRf*` — fast re-factorization given a sparsity pattern, for
2652    //! solving many systems that differ only in numeric values.
2653
2654    use super::*;
2655    use baracuda_cusolver_sys::cusolverRfHandle_t;
2656
2657    #[derive(Debug)]
2658    pub struct RfHandle {
2659        raw: cusolverRfHandle_t,
2660        _not_send: PhantomData<*mut ()>,
2661    }
2662
2663    impl RfHandle {
2664        pub fn new() -> Result<Self> {
2665            let c = cusolver()?;
2666            let cu = c.cusolver_rf_create()?;
2667            let mut h: cusolverRfHandle_t = core::ptr::null_mut();
2668            check(unsafe { cu(&mut h) })?;
2669            Ok(Self {
2670                raw: h,
2671                _not_send: PhantomData,
2672            })
2673        }
2674
2675        pub fn as_raw(&self) -> cusolverRfHandle_t {
2676            self.raw
2677        }
2678
2679        pub fn analyze(&self) -> Result<()> {
2680            let c = cusolver()?;
2681            let cu = c.cusolver_rf_analyze()?;
2682            check(unsafe { cu(self.raw) })
2683        }
2684
2685        pub fn refactor(&self) -> Result<()> {
2686            let c = cusolver()?;
2687            let cu = c.cusolver_rf_refactor()?;
2688            check(unsafe { cu(self.raw) })
2689        }
2690    }
2691
2692    impl Drop for RfHandle {
2693        fn drop(&mut self) {
2694            if let Ok(c) = cusolver() {
2695                if let Ok(cu) = c.cusolver_rf_destroy() {
2696                    let _ = unsafe { cu(self.raw) };
2697                }
2698            }
2699        }
2700    }
2701}