1#![allow(non_camel_case_types, non_snake_case, non_upper_case_globals)]
4#![warn(missing_debug_implementations)]
5
6use core::ffi::{c_int, c_void};
7use std::sync::OnceLock;
8
9use baracuda_core::{platform, Library, LoaderError};
10use baracuda_cuda_sys::runtime::cudaStream_t;
11use baracuda_types::CudaStatus;
12
/// Opaque pointer passed as the `handle` argument of the `cusolverDn*` entry points.
pub type cusolverDnHandle_t = *mut c_void;

/// Opaque pointer passed as the `handle` argument of the `cusolverSp*` entry points.
pub type cusolverSpHandle_t = *mut c_void;

/// Opaque pointer passed as the `handle` argument of the `cusolverRf*` entry points.
pub type cusolverRfHandle_t = *mut c_void;

/// Opaque pointer passed as the `params` argument of the `cusolverDnX*` entry points.
pub type cusolverDnParams_t = *mut c_void;

/// Opaque pointer type for the `cusolverDnIRSParams_t` C typedef.
pub type cusolverDnIRSParams_t = *mut c_void;

/// Opaque pointer type for the `cusolverDnIRSInfos_t` C typedef.
pub type cusolverDnIRSInfos_t = *mut c_void;

/// Opaque pointer passed as the `params` argument of the `syevj`/`heevj` entry points.
pub type syevjInfo_t = *mut c_void;

/// Opaque pointer passed as the `params` argument of the `gesvdj` entry points.
pub type gesvdjInfo_t = *mut c_void;
23
/// Value passed as the `trans` argument of the solver entry points in this file.
///
/// The variant discriminants (`N = 0`, `T = 1`, `C = 2`) are the integers sent
/// across the FFI boundary.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum cublasOperation_t {
    N = 0,
    T = 1,
    C = 2,
}
34
/// Value passed as the `uplo` argument of the solver entry points in this file.
///
/// Discriminants: `Lower = 0`, `Upper = 1`, `Full = 2`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum cublasFillMode_t {
    Lower = 0,
    Upper = 1,
    Full = 2,
}
42
/// Two-valued side selector with discriminants `Left = 0` and `Right = 1`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum cublasSideMode_t {
    Left = 0,
    Right = 1,
}
49
/// Two-valued diagonal selector with discriminants `NonUnit = 0` and `Unit = 1`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum cublasDiagType_t {
    NonUnit = 0,
    Unit = 1,
}
56
/// Eigenproblem type selector; note the discriminants start at 1
/// (`Type1 = 1`, `Type2 = 2`, `Type3 = 3`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum cusolverEigType_t {
    Type1 = 1,
    Type2 = 2,
    Type3 = 3,
}
64
/// Value passed as the `jobz` argument of the eigenvalue and Jacobi-SVD entry
/// points in this file. Discriminants: `NoVector = 0`, `Vector = 1`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum cusolverEigMode_t {
    NoVector = 0,
    Vector = 1,
}
71
/// Eigenvalue range selector. Discriminants: `All = 1001`, `I = 1002`, `V = 1003`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum cusolverEigRange_t {
    All = 1001,
    I = 1002,
    V = 1003,
}
79
/// Element-type tag passed as the `data_type_*` / `compute_type` arguments of
/// the `cusolverDnX*` and `cusolverMg*` entry points in this file.
///
/// Only the six tags listed here are declared; the discriminants are not
/// contiguous (`0, 1, 2, 4, 5, 14`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum cudaDataType {
    R_32F = 0,
    R_64F = 1,
    R_16F = 2,
    C_32F = 4,
    C_64F = 5,
    R_16BF = 14,
}
90
/// Status value returned by every cuSOLVER function-pointer type in this file.
///
/// A transparent wrapper around the raw `i32`, so codes without a named
/// constant below are still representable.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct cusolverStatus_t(pub i32);

impl cusolverStatus_t {
    pub const SUCCESS: Self = Self(0);
    pub const NOT_INITIALIZED: Self = Self(1);
    pub const ALLOC_FAILED: Self = Self(2);
    pub const INVALID_VALUE: Self = Self(3);
    pub const ARCH_MISMATCH: Self = Self(4);
    pub const EXECUTION_FAILED: Self = Self(6);
    pub const INTERNAL_ERROR: Self = Self(7);
    pub const NOT_SUPPORTED: Self = Self(9);
    pub const ZERO_PIVOT: Self = Self(10);

    /// Returns `true` exactly when the wrapped code equals that of [`Self::SUCCESS`].
    pub const fn is_success(self) -> bool {
        self.0 == Self::SUCCESS.0
    }
}
112
113impl CudaStatus for cusolverStatus_t {
114 fn code(self) -> i32 {
115 self.0
116 }
117 fn name(self) -> &'static str {
118 match self.0 {
119 0 => "CUSOLVER_STATUS_SUCCESS",
120 1 => "CUSOLVER_STATUS_NOT_INITIALIZED",
121 2 => "CUSOLVER_STATUS_ALLOC_FAILED",
122 3 => "CUSOLVER_STATUS_INVALID_VALUE",
123 6 => "CUSOLVER_STATUS_EXECUTION_FAILED",
124 7 => "CUSOLVER_STATUS_INTERNAL_ERROR",
125 9 => "CUSOLVER_STATUS_NOT_SUPPORTED",
126 10 => "CUSOLVER_STATUS_ZERO_PIVOT",
127 _ => "CUSOLVER_STATUS_UNRECOGNIZED",
128 }
129 }
130 fn description(self) -> &'static str {
131 match self.0 {
132 0 => "success",
133 1 => "cuSOLVER not initialized",
134 6 => "execution failed on device",
135 10 => "factorization produced a zero pivot",
136 _ => "unrecognized cuSOLVER status code",
137 }
138 }
139 fn is_success(self) -> bool {
140 cusolverStatus_t::is_success(self)
141 }
142 fn library(self) -> &'static str {
143 "cusolver"
144 }
145}
146
/// Element type used for the single-precision complex (`C`-prefixed) entry
/// points in this file: a C-layout pair of `f32` fields, `x` then `y`.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct cuComplex {
    pub x: f32,
    pub y: f32,
}
155
/// Element type used for the double-precision complex (`Z`-prefixed) entry
/// points in this file: a C-layout pair of `f64` fields, `x` then `y`.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct cuDoubleComplex {
    pub x: f64,
    pub y: f64,
}
162
// Generators for the `cusolverDn?getrf_bufferSize`, `cusolverDn?getrf` and
// `cusolverDn?getrs` function-pointer aliases. In each macro `$alias` is the
// name of the type alias to define and `$elem` is the matrix element type.

// Alias for a `getrf_bufferSize` entry point; `lwork` is a `*mut c_int`.
macro_rules! dn_getrf_bufsize {
    ($alias:ident, $elem:ty) => {
        pub type $alias = unsafe extern "C" fn(
            handle: cusolverDnHandle_t,
            m: c_int,
            n: c_int,
            a: *mut $elem,
            lda: c_int,
            lwork: *mut c_int,
        ) -> cusolverStatus_t;
    };
}

// Alias for a `getrf` entry point. Unlike the other factorisations declared
// in this file, the signature carries no `lwork` argument.
macro_rules! dn_getrf {
    ($alias:ident, $elem:ty) => {
        pub type $alias = unsafe extern "C" fn(
            handle: cusolverDnHandle_t,
            m: c_int,
            n: c_int,
            a: *mut $elem,
            lda: c_int,
            workspace: *mut $elem,
            ipiv: *mut c_int,
            info: *mut c_int,
        ) -> cusolverStatus_t;
    };
}

// Alias for a `getrs` entry point; `a` and `ipiv` are read-only pointers,
// `b` is mutable.
macro_rules! dn_getrs {
    ($alias:ident, $elem:ty) => {
        pub type $alias = unsafe extern "C" fn(
            handle: cusolverDnHandle_t,
            trans: cublasOperation_t,
            n: c_int,
            nrhs: c_int,
            a: *const $elem,
            lda: c_int,
            ipiv: *const c_int,
            b: *mut $elem,
            ldb: c_int,
            info: *mut c_int,
        ) -> cusolverStatus_t;
    };
}
210
// Generators for the `cusolverDn?geqrf_bufferSize` and `cusolverDn?geqrf`
// function-pointer aliases. `$alias` is the name of the type alias to define
// and `$elem` is the matrix element type.

// Alias for a `geqrf_bufferSize` entry point; `lwork` is a `*mut c_int`.
macro_rules! dn_geqrf_bufsize {
    ($alias:ident, $elem:ty) => {
        pub type $alias = unsafe extern "C" fn(
            handle: cusolverDnHandle_t,
            m: c_int,
            n: c_int,
            a: *mut $elem,
            lda: c_int,
            lwork: *mut c_int,
        ) -> cusolverStatus_t;
    };
}

// Alias for a `geqrf` entry point; `tau` and `workspace` share the element
// type of `a`, and `lwork` is passed by value.
macro_rules! dn_geqrf {
    ($alias:ident, $elem:ty) => {
        pub type $alias = unsafe extern "C" fn(
            handle: cusolverDnHandle_t,
            m: c_int,
            n: c_int,
            a: *mut $elem,
            lda: c_int,
            tau: *mut $elem,
            workspace: *mut $elem,
            lwork: c_int,
            info: *mut c_int,
        ) -> cusolverStatus_t;
    };
}
239
// Generators for the `cusolverDn?potrf_bufferSize`, `cusolverDn?potrf` and
// `cusolverDn?potrs` function-pointer aliases. `$alias` is the name of the
// type alias to define and `$elem` is the matrix element type.

// Alias for a `potrf_bufferSize` entry point; `lwork` is a `*mut c_int`.
macro_rules! dn_potrf_bufsize {
    ($alias:ident, $elem:ty) => {
        pub type $alias = unsafe extern "C" fn(
            handle: cusolverDnHandle_t,
            uplo: cublasFillMode_t,
            n: c_int,
            a: *mut $elem,
            lda: c_int,
            lwork: *mut c_int,
        ) -> cusolverStatus_t;
    };
}

// Alias for a `potrf` entry point; `lwork` is passed by value.
macro_rules! dn_potrf {
    ($alias:ident, $elem:ty) => {
        pub type $alias = unsafe extern "C" fn(
            handle: cusolverDnHandle_t,
            uplo: cublasFillMode_t,
            n: c_int,
            a: *mut $elem,
            lda: c_int,
            workspace: *mut $elem,
            lwork: c_int,
            info: *mut c_int,
        ) -> cusolverStatus_t;
    };
}

// Alias for a `potrs` entry point; `a` is a read-only pointer, `b` is mutable.
macro_rules! dn_potrs {
    ($alias:ident, $elem:ty) => {
        pub type $alias = unsafe extern "C" fn(
            handle: cusolverDnHandle_t,
            uplo: cublasFillMode_t,
            n: c_int,
            nrhs: c_int,
            a: *const $elem,
            lda: c_int,
            b: *mut $elem,
            ldb: c_int,
            info: *mut c_int,
        ) -> cusolverStatus_t;
    };
}
283
// Generators for the `cusolverDn?gesvd_bufferSize` and `cusolverDn?gesvd`
// function-pointer aliases. `$alias` is the name of the type alias to define.

// Alias for a `gesvd_bufferSize` entry point. The signature mentions no
// element type, so the macro takes only the alias name.
macro_rules! dn_gesvd_bufsize {
    ($alias:ident) => {
        pub type $alias = unsafe extern "C" fn(
            handle: cusolverDnHandle_t,
            m: c_int,
            n: c_int,
            lwork: *mut c_int,
        ) -> cusolverStatus_t;
    };
}

// Alias for a real-valued `gesvd` entry point: every array argument,
// including `s` and `rwork`, uses the single element type `$elem`.
// `jobu` / `jobvt` are declared as `u8`.
macro_rules! dn_gesvd_real {
    ($alias:ident, $elem:ty) => {
        pub type $alias = unsafe extern "C" fn(
            handle: cusolverDnHandle_t,
            jobu: u8,
            jobvt: u8,
            m: c_int,
            n: c_int,
            a: *mut $elem,
            lda: c_int,
            s: *mut $elem,
            u: *mut $elem,
            ldu: c_int,
            vt: *mut $elem,
            ldvt: c_int,
            work: *mut $elem,
            lwork: c_int,
            rwork: *mut $elem,
            info: *mut c_int,
        ) -> cusolverStatus_t;
    };
}

// Alias for a complex-valued `gesvd` entry point: `a`, `u`, `vt` and `work`
// use `$elem`, while `s` and `rwork` use the separate type `$scalar`.
macro_rules! dn_gesvd_complex {
    ($alias:ident, $elem:ty, $scalar:ty) => {
        pub type $alias = unsafe extern "C" fn(
            handle: cusolverDnHandle_t,
            jobu: u8,
            jobvt: u8,
            m: c_int,
            n: c_int,
            a: *mut $elem,
            lda: c_int,
            s: *mut $scalar,
            u: *mut $elem,
            ldu: c_int,
            vt: *mut $elem,
            ldvt: c_int,
            work: *mut $elem,
            lwork: c_int,
            rwork: *mut $scalar,
            info: *mut c_int,
        ) -> cusolverStatus_t;
    };
}
340
// Generators for the `cusolverDn?syevd` / `cusolverDn?heevd` function-pointer
// aliases and their `_bufferSize` companions. `$alias` is the name of the type
// alias to define, `$elem` the element type of `a`/`work`, and `$scalar` the
// element type of `w`.

// Alias for a `syevd_bufferSize` / `heevd_bufferSize` entry point; `a` and `w`
// are read-only pointers and `lwork` is a `*mut c_int`.
macro_rules! dn_syevd_bufsize {
    ($alias:ident, $elem:ty, $scalar:ty) => {
        pub type $alias = unsafe extern "C" fn(
            handle: cusolverDnHandle_t,
            jobz: cusolverEigMode_t,
            uplo: cublasFillMode_t,
            n: c_int,
            a: *const $elem,
            lda: c_int,
            w: *const $scalar,
            lwork: *mut c_int,
        ) -> cusolverStatus_t;
    };
}

// Alias for a `syevd` / `heevd` entry point; `a` and `w` are mutable pointers
// and `lwork` is passed by value.
macro_rules! dn_syevd {
    ($alias:ident, $elem:ty, $scalar:ty) => {
        pub type $alias = unsafe extern "C" fn(
            handle: cusolverDnHandle_t,
            jobz: cusolverEigMode_t,
            uplo: cublasFillMode_t,
            n: c_int,
            a: *mut $elem,
            lda: c_int,
            w: *mut $scalar,
            work: *mut $elem,
            lwork: c_int,
            info: *mut c_int,
        ) -> cusolverStatus_t;
    };
}
372
373pub type PFN_cusolverDnCreate =
376 unsafe extern "C" fn(handle: *mut cusolverDnHandle_t) -> cusolverStatus_t;
377pub type PFN_cusolverDnDestroy =
378 unsafe extern "C" fn(handle: cusolverDnHandle_t) -> cusolverStatus_t;
379pub type PFN_cusolverDnSetStream =
380 unsafe extern "C" fn(handle: cusolverDnHandle_t, stream: cudaStream_t) -> cusolverStatus_t;
381pub type PFN_cusolverDnGetStream =
382 unsafe extern "C" fn(handle: cusolverDnHandle_t, stream: *mut cudaStream_t) -> cusolverStatus_t;
383
384pub type PFN_cusolverGetVersion = unsafe extern "C" fn(version: *mut c_int) -> cusolverStatus_t;
385
386dn_getrf_bufsize!(PFN_cusolverDnSgetrf_bufferSize, f32);
389dn_getrf_bufsize!(PFN_cusolverDnDgetrf_bufferSize, f64);
390dn_getrf_bufsize!(PFN_cusolverDnCgetrf_bufferSize, cuComplex);
391dn_getrf_bufsize!(PFN_cusolverDnZgetrf_bufferSize, cuDoubleComplex);
392
393dn_getrf!(PFN_cusolverDnSgetrf, f32);
394dn_getrf!(PFN_cusolverDnDgetrf, f64);
395dn_getrf!(PFN_cusolverDnCgetrf, cuComplex);
396dn_getrf!(PFN_cusolverDnZgetrf, cuDoubleComplex);
397
398dn_getrs!(PFN_cusolverDnSgetrs, f32);
399dn_getrs!(PFN_cusolverDnDgetrs, f64);
400dn_getrs!(PFN_cusolverDnCgetrs, cuComplex);
401dn_getrs!(PFN_cusolverDnZgetrs, cuDoubleComplex);
402
403dn_geqrf_bufsize!(PFN_cusolverDnSgeqrf_bufferSize, f32);
406dn_geqrf_bufsize!(PFN_cusolverDnDgeqrf_bufferSize, f64);
407dn_geqrf_bufsize!(PFN_cusolverDnCgeqrf_bufferSize, cuComplex);
408dn_geqrf_bufsize!(PFN_cusolverDnZgeqrf_bufferSize, cuDoubleComplex);
409
410dn_geqrf!(PFN_cusolverDnSgeqrf, f32);
411dn_geqrf!(PFN_cusolverDnDgeqrf, f64);
412dn_geqrf!(PFN_cusolverDnCgeqrf, cuComplex);
413dn_geqrf!(PFN_cusolverDnZgeqrf, cuDoubleComplex);
414
415dn_potrf_bufsize!(PFN_cusolverDnSpotrf_bufferSize, f32);
418dn_potrf_bufsize!(PFN_cusolverDnDpotrf_bufferSize, f64);
419dn_potrf_bufsize!(PFN_cusolverDnCpotrf_bufferSize, cuComplex);
420dn_potrf_bufsize!(PFN_cusolverDnZpotrf_bufferSize, cuDoubleComplex);
421
422dn_potrf!(PFN_cusolverDnSpotrf, f32);
423dn_potrf!(PFN_cusolverDnDpotrf, f64);
424dn_potrf!(PFN_cusolverDnCpotrf, cuComplex);
425dn_potrf!(PFN_cusolverDnZpotrf, cuDoubleComplex);
426
427dn_potrs!(PFN_cusolverDnSpotrs, f32);
428dn_potrs!(PFN_cusolverDnDpotrs, f64);
429dn_potrs!(PFN_cusolverDnCpotrs, cuComplex);
430dn_potrs!(PFN_cusolverDnZpotrs, cuDoubleComplex);
431
432dn_gesvd_bufsize!(PFN_cusolverDnSgesvd_bufferSize);
435dn_gesvd_bufsize!(PFN_cusolverDnDgesvd_bufferSize);
436dn_gesvd_bufsize!(PFN_cusolverDnCgesvd_bufferSize);
437dn_gesvd_bufsize!(PFN_cusolverDnZgesvd_bufferSize);
438
439dn_gesvd_real!(PFN_cusolverDnSgesvd, f32);
440dn_gesvd_real!(PFN_cusolverDnDgesvd, f64);
441dn_gesvd_complex!(PFN_cusolverDnCgesvd, cuComplex, f32);
442dn_gesvd_complex!(PFN_cusolverDnZgesvd, cuDoubleComplex, f64);
443
444dn_syevd_bufsize!(PFN_cusolverDnSsyevd_bufferSize, f32, f32);
447dn_syevd_bufsize!(PFN_cusolverDnDsyevd_bufferSize, f64, f64);
448dn_syevd_bufsize!(PFN_cusolverDnCheevd_bufferSize, cuComplex, f32);
449dn_syevd_bufsize!(PFN_cusolverDnZheevd_bufferSize, cuDoubleComplex, f64);
450
451dn_syevd!(PFN_cusolverDnSsyevd, f32, f32);
452dn_syevd!(PFN_cusolverDnDsyevd, f64, f64);
453dn_syevd!(PFN_cusolverDnCheevd, cuComplex, f32);
454dn_syevd!(PFN_cusolverDnZheevd, cuDoubleComplex, f64);
455
456pub type PFN_cusolverDnCreateParams =
459 unsafe extern "C" fn(params: *mut cusolverDnParams_t) -> cusolverStatus_t;
460pub type PFN_cusolverDnDestroyParams =
461 unsafe extern "C" fn(params: cusolverDnParams_t) -> cusolverStatus_t;
462
463pub type PFN_cusolverDnXgetrf_bufferSize = unsafe extern "C" fn(
464 handle: cusolverDnHandle_t,
465 params: cusolverDnParams_t,
466 m: i64,
467 n: i64,
468 data_type_a: cudaDataType,
469 a: *const c_void,
470 lda: i64,
471 compute_type: cudaDataType,
472 workspace_in_bytes_on_device: *mut usize,
473 workspace_in_bytes_on_host: *mut usize,
474) -> cusolverStatus_t;
475
476pub type PFN_cusolverDnXgetrf = unsafe extern "C" fn(
477 handle: cusolverDnHandle_t,
478 params: cusolverDnParams_t,
479 m: i64,
480 n: i64,
481 data_type_a: cudaDataType,
482 a: *mut c_void,
483 lda: i64,
484 ipiv: *mut i64,
485 compute_type: cudaDataType,
486 bufferondevice: *mut c_void,
487 workspace_in_bytes_on_device: usize,
488 bufferonhost: *mut c_void,
489 workspace_in_bytes_on_host: usize,
490 info: *mut c_int,
491) -> cusolverStatus_t;
492
493pub type PFN_cusolverDnXgetrs = unsafe extern "C" fn(
494 handle: cusolverDnHandle_t,
495 params: cusolverDnParams_t,
496 trans: cublasOperation_t,
497 n: i64,
498 nrhs: i64,
499 data_type_a: cudaDataType,
500 a: *const c_void,
501 lda: i64,
502 ipiv: *const i64,
503 data_type_b: cudaDataType,
504 b: *mut c_void,
505 ldb: i64,
506 info: *mut c_int,
507) -> cusolverStatus_t;
508
509pub type PFN_cusolverDnXgeqrf_bufferSize = unsafe extern "C" fn(
510 handle: cusolverDnHandle_t,
511 params: cusolverDnParams_t,
512 m: i64,
513 n: i64,
514 data_type_a: cudaDataType,
515 a: *const c_void,
516 lda: i64,
517 data_type_tau: cudaDataType,
518 tau: *const c_void,
519 compute_type: cudaDataType,
520 workspace_in_bytes_on_device: *mut usize,
521 workspace_in_bytes_on_host: *mut usize,
522) -> cusolverStatus_t;
523
524pub type PFN_cusolverDnXgeqrf = unsafe extern "C" fn(
525 handle: cusolverDnHandle_t,
526 params: cusolverDnParams_t,
527 m: i64,
528 n: i64,
529 data_type_a: cudaDataType,
530 a: *mut c_void,
531 lda: i64,
532 data_type_tau: cudaDataType,
533 tau: *mut c_void,
534 compute_type: cudaDataType,
535 bufferondevice: *mut c_void,
536 workspace_in_bytes_on_device: usize,
537 bufferonhost: *mut c_void,
538 workspace_in_bytes_on_host: usize,
539 info: *mut c_int,
540) -> cusolverStatus_t;
541
542pub type PFN_cusolverDnXpotrf_bufferSize = unsafe extern "C" fn(
543 handle: cusolverDnHandle_t,
544 params: cusolverDnParams_t,
545 uplo: cublasFillMode_t,
546 n: i64,
547 data_type_a: cudaDataType,
548 a: *const c_void,
549 lda: i64,
550 compute_type: cudaDataType,
551 workspace_in_bytes_on_device: *mut usize,
552 workspace_in_bytes_on_host: *mut usize,
553) -> cusolverStatus_t;
554
555pub type PFN_cusolverDnXpotrf = unsafe extern "C" fn(
556 handle: cusolverDnHandle_t,
557 params: cusolverDnParams_t,
558 uplo: cublasFillMode_t,
559 n: i64,
560 data_type_a: cudaDataType,
561 a: *mut c_void,
562 lda: i64,
563 compute_type: cudaDataType,
564 bufferondevice: *mut c_void,
565 workspace_in_bytes_on_device: usize,
566 bufferonhost: *mut c_void,
567 workspace_in_bytes_on_host: usize,
568 info: *mut c_int,
569) -> cusolverStatus_t;
570
571pub type PFN_cusolverDnXpotrs = unsafe extern "C" fn(
572 handle: cusolverDnHandle_t,
573 params: cusolverDnParams_t,
574 uplo: cublasFillMode_t,
575 n: i64,
576 nrhs: i64,
577 data_type_a: cudaDataType,
578 a: *const c_void,
579 lda: i64,
580 data_type_b: cudaDataType,
581 b: *mut c_void,
582 ldb: i64,
583 info: *mut c_int,
584) -> cusolverStatus_t;
585
586pub type PFN_cusolverDnXsyevd_bufferSize = unsafe extern "C" fn(
587 handle: cusolverDnHandle_t,
588 params: cusolverDnParams_t,
589 jobz: cusolverEigMode_t,
590 uplo: cublasFillMode_t,
591 n: i64,
592 data_type_a: cudaDataType,
593 a: *const c_void,
594 lda: i64,
595 data_type_w: cudaDataType,
596 w: *const c_void,
597 compute_type: cudaDataType,
598 device_bytes: *mut usize,
599 host_bytes: *mut usize,
600) -> cusolverStatus_t;
601
602pub type PFN_cusolverDnXsyevd = unsafe extern "C" fn(
603 handle: cusolverDnHandle_t,
604 params: cusolverDnParams_t,
605 jobz: cusolverEigMode_t,
606 uplo: cublasFillMode_t,
607 n: i64,
608 data_type_a: cudaDataType,
609 a: *mut c_void,
610 lda: i64,
611 data_type_w: cudaDataType,
612 w: *mut c_void,
613 compute_type: cudaDataType,
614 bufferondevice: *mut c_void,
615 device_bytes: usize,
616 bufferonhost: *mut c_void,
617 host_bytes: usize,
618 info: *mut c_int,
619) -> cusolverStatus_t;
620
621pub type PFN_cusolverDnCreateSyevjInfo =
626 unsafe extern "C" fn(info: *mut syevjInfo_t) -> cusolverStatus_t;
627pub type PFN_cusolverDnDestroySyevjInfo =
628 unsafe extern "C" fn(info: syevjInfo_t) -> cusolverStatus_t;
629pub type PFN_cusolverDnXsyevjSetTolerance =
630 unsafe extern "C" fn(info: syevjInfo_t, tolerance: f64) -> cusolverStatus_t;
631pub type PFN_cusolverDnXsyevjSetMaxSweeps =
632 unsafe extern "C" fn(info: syevjInfo_t, max_sweeps: c_int) -> cusolverStatus_t;
633
634macro_rules! dn_syevj_bufsize {
635 ($name:ident, $t:ty, $real:ty) => {
636 pub type $name = unsafe extern "C" fn(
637 handle: cusolverDnHandle_t,
638 jobz: cusolverEigMode_t,
639 uplo: cublasFillMode_t,
640 n: c_int,
641 a: *const $t,
642 lda: c_int,
643 w: *const $real,
644 lwork: *mut c_int,
645 params: syevjInfo_t,
646 ) -> cusolverStatus_t;
647 };
648}
649dn_syevj_bufsize!(PFN_cusolverDnSsyevj_bufferSize, f32, f32);
650dn_syevj_bufsize!(PFN_cusolverDnDsyevj_bufferSize, f64, f64);
651dn_syevj_bufsize!(PFN_cusolverDnCheevj_bufferSize, cuComplex, f32);
652dn_syevj_bufsize!(PFN_cusolverDnZheevj_bufferSize, cuDoubleComplex, f64);
653
654macro_rules! dn_syevj {
655 ($name:ident, $t:ty, $real:ty) => {
656 pub type $name = unsafe extern "C" fn(
657 handle: cusolverDnHandle_t,
658 jobz: cusolverEigMode_t,
659 uplo: cublasFillMode_t,
660 n: c_int,
661 a: *mut $t,
662 lda: c_int,
663 w: *mut $real,
664 work: *mut $t,
665 lwork: c_int,
666 info: *mut c_int,
667 params: syevjInfo_t,
668 ) -> cusolverStatus_t;
669 };
670}
671dn_syevj!(PFN_cusolverDnSsyevj, f32, f32);
672dn_syevj!(PFN_cusolverDnDsyevj, f64, f64);
673dn_syevj!(PFN_cusolverDnCheevj, cuComplex, f32);
674dn_syevj!(PFN_cusolverDnZheevj, cuDoubleComplex, f64);
675
676pub type PFN_cusolverDnCreateGesvdjInfo =
677 unsafe extern "C" fn(info: *mut gesvdjInfo_t) -> cusolverStatus_t;
678pub type PFN_cusolverDnDestroyGesvdjInfo =
679 unsafe extern "C" fn(info: gesvdjInfo_t) -> cusolverStatus_t;
680
681macro_rules! dn_gesvdj_bufsize {
682 ($name:ident, $t:ty, $real:ty) => {
683 pub type $name = unsafe extern "C" fn(
684 handle: cusolverDnHandle_t,
685 jobz: cusolverEigMode_t,
686 econ: c_int,
687 m: c_int,
688 n: c_int,
689 a: *const $t,
690 lda: c_int,
691 s: *const $real,
692 u: *const $t,
693 ldu: c_int,
694 v: *const $t,
695 ldv: c_int,
696 lwork: *mut c_int,
697 params: gesvdjInfo_t,
698 ) -> cusolverStatus_t;
699 };
700}
701dn_gesvdj_bufsize!(PFN_cusolverDnSgesvdj_bufferSize, f32, f32);
702dn_gesvdj_bufsize!(PFN_cusolverDnDgesvdj_bufferSize, f64, f64);
703dn_gesvdj_bufsize!(PFN_cusolverDnCgesvdj_bufferSize, cuComplex, f32);
704dn_gesvdj_bufsize!(PFN_cusolverDnZgesvdj_bufferSize, cuDoubleComplex, f64);
705
706macro_rules! dn_gesvdj {
707 ($name:ident, $t:ty, $real:ty) => {
708 pub type $name = unsafe extern "C" fn(
709 handle: cusolverDnHandle_t,
710 jobz: cusolverEigMode_t,
711 econ: c_int,
712 m: c_int,
713 n: c_int,
714 a: *mut $t,
715 lda: c_int,
716 s: *mut $real,
717 u: *mut $t,
718 ldu: c_int,
719 v: *mut $t,
720 ldv: c_int,
721 work: *mut $t,
722 lwork: c_int,
723 info: *mut c_int,
724 params: gesvdjInfo_t,
725 ) -> cusolverStatus_t;
726 };
727}
728dn_gesvdj!(PFN_cusolverDnSgesvdj, f32, f32);
729dn_gesvdj!(PFN_cusolverDnDgesvdj, f64, f64);
730dn_gesvdj!(PFN_cusolverDnCgesvdj, cuComplex, f32);
731dn_gesvdj!(PFN_cusolverDnZgesvdj, cuDoubleComplex, f64);
732
733macro_rules! dn_orgqr_bufsize {
738 ($name:ident, $t:ty) => {
739 pub type $name = unsafe extern "C" fn(
740 handle: cusolverDnHandle_t,
741 m: c_int,
742 n: c_int,
743 k: c_int,
744 a: *const $t,
745 lda: c_int,
746 tau: *const $t,
747 lwork: *mut c_int,
748 ) -> cusolverStatus_t;
749 };
750}
751dn_orgqr_bufsize!(PFN_cusolverDnSorgqr_bufferSize, f32);
752dn_orgqr_bufsize!(PFN_cusolverDnDorgqr_bufferSize, f64);
753dn_orgqr_bufsize!(PFN_cusolverDnCungqr_bufferSize, cuComplex);
754dn_orgqr_bufsize!(PFN_cusolverDnZungqr_bufferSize, cuDoubleComplex);
755
756macro_rules! dn_orgqr {
757 ($name:ident, $t:ty) => {
758 pub type $name = unsafe extern "C" fn(
759 handle: cusolverDnHandle_t,
760 m: c_int,
761 n: c_int,
762 k: c_int,
763 a: *mut $t,
764 lda: c_int,
765 tau: *const $t,
766 work: *mut $t,
767 lwork: c_int,
768 info: *mut c_int,
769 ) -> cusolverStatus_t;
770 };
771}
772dn_orgqr!(PFN_cusolverDnSorgqr, f32);
773dn_orgqr!(PFN_cusolverDnDorgqr, f64);
774dn_orgqr!(PFN_cusolverDnCungqr, cuComplex);
775dn_orgqr!(PFN_cusolverDnZungqr, cuDoubleComplex);
776
777macro_rules! dn_ormqr_bufsize {
778 ($name:ident, $t:ty) => {
779 pub type $name = unsafe extern "C" fn(
780 handle: cusolverDnHandle_t,
781 side: c_int,
782 trans: cublasOperation_t,
783 m: c_int,
784 n: c_int,
785 k: c_int,
786 a: *const $t,
787 lda: c_int,
788 tau: *const $t,
789 c: *const $t,
790 ldc: c_int,
791 lwork: *mut c_int,
792 ) -> cusolverStatus_t;
793 };
794}
795dn_ormqr_bufsize!(PFN_cusolverDnSormqr_bufferSize, f32);
796dn_ormqr_bufsize!(PFN_cusolverDnDormqr_bufferSize, f64);
797dn_ormqr_bufsize!(PFN_cusolverDnCunmqr_bufferSize, cuComplex);
798dn_ormqr_bufsize!(PFN_cusolverDnZunmqr_bufferSize, cuDoubleComplex);
799
800macro_rules! dn_ormqr {
801 ($name:ident, $t:ty) => {
802 pub type $name = unsafe extern "C" fn(
803 handle: cusolverDnHandle_t,
804 side: c_int,
805 trans: cublasOperation_t,
806 m: c_int,
807 n: c_int,
808 k: c_int,
809 a: *const $t,
810 lda: c_int,
811 tau: *const $t,
812 c: *mut $t,
813 ldc: c_int,
814 work: *mut $t,
815 lwork: c_int,
816 info: *mut c_int,
817 ) -> cusolverStatus_t;
818 };
819}
820dn_ormqr!(PFN_cusolverDnSormqr, f32);
821dn_ormqr!(PFN_cusolverDnDormqr, f64);
822dn_ormqr!(PFN_cusolverDnCunmqr, cuComplex);
823dn_ormqr!(PFN_cusolverDnZunmqr, cuDoubleComplex);
824
825pub type PFN_cusolverSpCreate =
828 unsafe extern "C" fn(handle: *mut cusolverSpHandle_t) -> cusolverStatus_t;
829pub type PFN_cusolverSpDestroy =
830 unsafe extern "C" fn(handle: cusolverSpHandle_t) -> cusolverStatus_t;
831pub type PFN_cusolverSpSetStream =
832 unsafe extern "C" fn(handle: cusolverSpHandle_t, stream: cudaStream_t) -> cusolverStatus_t;
833
834pub type PFN_cusolverSpScsrlsvchol = unsafe extern "C" fn(
835 handle: cusolverSpHandle_t,
836 m: c_int,
837 nnz: c_int,
838 descr_a: *mut c_void,
839 csr_val: *const f32,
840 csr_row_ptr: *const c_int,
841 csr_col_ind: *const c_int,
842 b: *const f32,
843 tol: f32,
844 reorder: c_int,
845 x: *mut f32,
846 singularity: *mut c_int,
847) -> cusolverStatus_t;
848
849pub type PFN_cusolverSpDcsrlsvchol = unsafe extern "C" fn(
850 handle: cusolverSpHandle_t,
851 m: c_int,
852 nnz: c_int,
853 descr_a: *mut c_void,
854 csr_val: *const f64,
855 csr_row_ptr: *const c_int,
856 csr_col_ind: *const c_int,
857 b: *const f64,
858 tol: f64,
859 reorder: c_int,
860 x: *mut f64,
861 singularity: *mut c_int,
862) -> cusolverStatus_t;
863
864pub type PFN_cusolverSpScsrlsvqr = unsafe extern "C" fn(
865 handle: cusolverSpHandle_t,
866 m: c_int,
867 nnz: c_int,
868 descr_a: *mut c_void,
869 csr_val: *const f32,
870 csr_row_ptr: *const c_int,
871 csr_col_ind: *const c_int,
872 b: *const f32,
873 tol: f32,
874 reorder: c_int,
875 x: *mut f32,
876 singularity: *mut c_int,
877) -> cusolverStatus_t;
878
879pub type PFN_cusolverSpDcsrlsvqr = unsafe extern "C" fn(
880 handle: cusolverSpHandle_t,
881 m: c_int,
882 nnz: c_int,
883 descr_a: *mut c_void,
884 csr_val: *const f64,
885 csr_row_ptr: *const c_int,
886 csr_col_ind: *const c_int,
887 b: *const f64,
888 tol: f64,
889 reorder: c_int,
890 x: *mut f64,
891 singularity: *mut c_int,
892) -> cusolverStatus_t;
893
894pub type PFN_cusolverRfCreate =
897 unsafe extern "C" fn(handle: *mut cusolverRfHandle_t) -> cusolverStatus_t;
898pub type PFN_cusolverRfDestroy =
899 unsafe extern "C" fn(handle: cusolverRfHandle_t) -> cusolverStatus_t;
900pub type PFN_cusolverRfSetupDevice = unsafe extern "C" fn(
901 n: c_int,
902 nnz_a: c_int,
903 h_csr_row_ptr_a: *mut c_int,
904 h_csr_col_ind_a: *mut c_int,
905 h_csr_val_a: *mut f64,
906 nnz_l: c_int,
907 h_csr_row_ptr_l: *mut c_int,
908 h_csr_col_ind_l: *mut c_int,
909 h_csr_val_l: *mut f64,
910 nnz_u: c_int,
911 h_csr_row_ptr_u: *mut c_int,
912 h_csr_col_ind_u: *mut c_int,
913 h_csr_val_u: *mut f64,
914 p: *mut c_int,
915 q: *mut c_int,
916 handle: cusolverRfHandle_t,
917) -> cusolverStatus_t;
918pub type PFN_cusolverRfAnalyze =
919 unsafe extern "C" fn(handle: cusolverRfHandle_t) -> cusolverStatus_t;
920pub type PFN_cusolverRfRefactor =
921 unsafe extern "C" fn(handle: cusolverRfHandle_t) -> cusolverStatus_t;
922pub type PFN_cusolverRfSolve = unsafe extern "C" fn(
923 handle: cusolverRfHandle_t,
924 p: *mut c_int,
925 q: *mut c_int,
926 nrhs: c_int,
927 temp: *mut f64,
928 ld_temp: c_int,
929 xf: *mut f64,
930 ld_xf: c_int,
931) -> cusolverStatus_t;
932
933macro_rules! dn_gels_bufsize {
938 ($name:ident, $t:ty) => {
939 pub type $name = unsafe extern "C" fn(
940 handle: cusolverDnHandle_t,
941 m: c_int,
942 n: c_int,
943 nrhs: c_int,
944 d_a: *mut $t,
945 lda: c_int,
946 d_b: *mut $t,
947 ldb: c_int,
948 d_x: *mut $t,
949 ldx: c_int,
950 d_work: *mut c_void,
951 lwork_bytes: *mut usize,
952 ) -> cusolverStatus_t;
953 };
954}
955dn_gels_bufsize!(PFN_cusolverDnSSgels_bufferSize, f32);
956dn_gels_bufsize!(PFN_cusolverDnDDgels_bufferSize, f64);
957dn_gels_bufsize!(PFN_cusolverDnCCgels_bufferSize, cuComplex);
958dn_gels_bufsize!(PFN_cusolverDnZZgels_bufferSize, cuDoubleComplex);
959
960macro_rules! dn_gels {
961 ($name:ident, $t:ty) => {
962 pub type $name = unsafe extern "C" fn(
963 handle: cusolverDnHandle_t,
964 m: c_int,
965 n: c_int,
966 nrhs: c_int,
967 d_a: *mut $t,
968 lda: c_int,
969 d_b: *mut $t,
970 ldb: c_int,
971 d_x: *mut $t,
972 ldx: c_int,
973 d_work: *mut c_void,
974 lwork_bytes: usize,
975 iter: *mut c_int,
976 d_info: *mut c_int,
977 ) -> cusolverStatus_t;
978 };
979}
980dn_gels!(PFN_cusolverDnSSgels, f32);
981dn_gels!(PFN_cusolverDnDDgels, f64);
982dn_gels!(PFN_cusolverDnCCgels, cuComplex);
983dn_gels!(PFN_cusolverDnZZgels, cuDoubleComplex);
984
985macro_rules! dn_potri_bufsize {
990 ($name:ident, $t:ty) => {
991 pub type $name = unsafe extern "C" fn(
992 handle: cusolverDnHandle_t,
993 uplo: cublasFillMode_t,
994 n: c_int,
995 a: *mut $t,
996 lda: c_int,
997 lwork: *mut c_int,
998 ) -> cusolverStatus_t;
999 };
1000}
1001dn_potri_bufsize!(PFN_cusolverDnSpotri_bufferSize, f32);
1002dn_potri_bufsize!(PFN_cusolverDnDpotri_bufferSize, f64);
1003dn_potri_bufsize!(PFN_cusolverDnCpotri_bufferSize, cuComplex);
1004dn_potri_bufsize!(PFN_cusolverDnZpotri_bufferSize, cuDoubleComplex);
1005
1006macro_rules! dn_potri {
1007 ($name:ident, $t:ty) => {
1008 pub type $name = unsafe extern "C" fn(
1009 handle: cusolverDnHandle_t,
1010 uplo: cublasFillMode_t,
1011 n: c_int,
1012 a: *mut $t,
1013 lda: c_int,
1014 work: *mut $t,
1015 lwork: c_int,
1016 info: *mut c_int,
1017 ) -> cusolverStatus_t;
1018 };
1019}
1020dn_potri!(PFN_cusolverDnSpotri, f32);
1021dn_potri!(PFN_cusolverDnDpotri, f64);
1022dn_potri!(PFN_cusolverDnCpotri, cuComplex);
1023dn_potri!(PFN_cusolverDnZpotri, cuDoubleComplex);
1024
1025macro_rules! dn_syevj_batched_bufsize {
1030 ($name:ident, $t:ty, $real:ty) => {
1031 pub type $name = unsafe extern "C" fn(
1032 handle: cusolverDnHandle_t,
1033 jobz: cusolverEigMode_t,
1034 uplo: cublasFillMode_t,
1035 n: c_int,
1036 a: *const $t,
1037 lda: c_int,
1038 w: *const $real,
1039 lwork: *mut c_int,
1040 params: syevjInfo_t,
1041 batch_size: c_int,
1042 ) -> cusolverStatus_t;
1043 };
1044}
1045dn_syevj_batched_bufsize!(PFN_cusolverDnSsyevjBatched_bufferSize, f32, f32);
1046dn_syevj_batched_bufsize!(PFN_cusolverDnDsyevjBatched_bufferSize, f64, f64);
1047dn_syevj_batched_bufsize!(PFN_cusolverDnCheevjBatched_bufferSize, cuComplex, f32);
1048dn_syevj_batched_bufsize!(PFN_cusolverDnZheevjBatched_bufferSize, cuDoubleComplex, f64);
1049
1050macro_rules! dn_syevj_batched {
1051 ($name:ident, $t:ty, $real:ty) => {
1052 pub type $name = unsafe extern "C" fn(
1053 handle: cusolverDnHandle_t,
1054 jobz: cusolverEigMode_t,
1055 uplo: cublasFillMode_t,
1056 n: c_int,
1057 a: *mut $t,
1058 lda: c_int,
1059 w: *mut $real,
1060 work: *mut $t,
1061 lwork: c_int,
1062 info: *mut c_int,
1063 params: syevjInfo_t,
1064 batch_size: c_int,
1065 ) -> cusolverStatus_t;
1066 };
1067}
1068dn_syevj_batched!(PFN_cusolverDnSsyevjBatched, f32, f32);
1069dn_syevj_batched!(PFN_cusolverDnDsyevjBatched, f64, f64);
1070dn_syevj_batched!(PFN_cusolverDnCheevjBatched, cuComplex, f32);
1071dn_syevj_batched!(PFN_cusolverDnZheevjBatched, cuDoubleComplex, f64);
1072
1073macro_rules! dn_gesvdj_batched_bufsize {
1074 ($name:ident, $t:ty, $real:ty) => {
1075 pub type $name = unsafe extern "C" fn(
1076 handle: cusolverDnHandle_t,
1077 jobz: cusolverEigMode_t,
1078 m: c_int,
1079 n: c_int,
1080 a: *const $t,
1081 lda: c_int,
1082 s: *const $real,
1083 u: *const $t,
1084 ldu: c_int,
1085 v: *const $t,
1086 ldv: c_int,
1087 lwork: *mut c_int,
1088 params: gesvdjInfo_t,
1089 batch_size: c_int,
1090 ) -> cusolverStatus_t;
1091 };
1092}
1093dn_gesvdj_batched_bufsize!(PFN_cusolverDnSgesvdjBatched_bufferSize, f32, f32);
1094dn_gesvdj_batched_bufsize!(PFN_cusolverDnDgesvdjBatched_bufferSize, f64, f64);
1095dn_gesvdj_batched_bufsize!(PFN_cusolverDnCgesvdjBatched_bufferSize, cuComplex, f32);
1096dn_gesvdj_batched_bufsize!(PFN_cusolverDnZgesvdjBatched_bufferSize, cuDoubleComplex, f64);
1097
1098macro_rules! dn_gesvdj_batched {
1099 ($name:ident, $t:ty, $real:ty) => {
1100 pub type $name = unsafe extern "C" fn(
1101 handle: cusolverDnHandle_t,
1102 jobz: cusolverEigMode_t,
1103 m: c_int,
1104 n: c_int,
1105 a: *mut $t,
1106 lda: c_int,
1107 s: *mut $real,
1108 u: *mut $t,
1109 ldu: c_int,
1110 v: *mut $t,
1111 ldv: c_int,
1112 work: *mut $t,
1113 lwork: c_int,
1114 info: *mut c_int,
1115 params: gesvdjInfo_t,
1116 batch_size: c_int,
1117 ) -> cusolverStatus_t;
1118 };
1119}
1120dn_gesvdj_batched!(PFN_cusolverDnSgesvdjBatched, f32, f32);
1121dn_gesvdj_batched!(PFN_cusolverDnDgesvdjBatched, f64, f64);
1122dn_gesvdj_batched!(PFN_cusolverDnCgesvdjBatched, cuComplex, f32);
1123dn_gesvdj_batched!(PFN_cusolverDnZgesvdjBatched, cuDoubleComplex, f64);
1124
/// Opaque pointer passed as the `handle` argument of the `cusolverMg*` entry points.
pub type cusolverMgHandle_t = *mut c_void;

/// Opaque pointer passed as the `desc_a` argument of the `cusolverMg*` solver
/// entry points.
pub type cudaLibMgMatrixDesc_t = *mut c_void;

/// Opaque pointer passed as the `grid` argument of the `cusolverMg*` grid and
/// matrix-descriptor entry points.
pub type cudaLibMgGrid_t = *mut c_void;
1132
1133pub type PFN_cusolverMgCreate =
1134 unsafe extern "C" fn(handle: *mut cusolverMgHandle_t) -> cusolverStatus_t;
1135pub type PFN_cusolverMgDestroy =
1136 unsafe extern "C" fn(handle: cusolverMgHandle_t) -> cusolverStatus_t;
1137pub type PFN_cusolverMgDeviceSelect = unsafe extern "C" fn(
1138 handle: cusolverMgHandle_t,
1139 n_devices: c_int,
1140 device_id: *const c_int,
1141) -> cusolverStatus_t;
1142
1143pub type PFN_cusolverMgCreateDeviceGrid = unsafe extern "C" fn(
1144 grid: *mut cudaLibMgGrid_t,
1145 num_row_devices: i32,
1146 num_col_devices: i32,
1147 device_id: *const i32,
1148 mapping: i32,
1149) -> cusolverStatus_t;
1150
1151pub type PFN_cusolverMgDestroyGrid =
1152 unsafe extern "C" fn(grid: cudaLibMgGrid_t) -> cusolverStatus_t;
1153
1154pub type PFN_cusolverMgCreateMatrixDesc = unsafe extern "C" fn(
1155 desc: *mut cudaLibMgMatrixDesc_t,
1156 num_rows: i64,
1157 num_cols: i64,
1158 row_block_size: i64,
1159 col_block_size: i64,
1160 data_type: cudaDataType,
1161 grid: cudaLibMgGrid_t,
1162) -> cusolverStatus_t;
1163
1164pub type PFN_cusolverMgDestroyMatrixDesc =
1165 unsafe extern "C" fn(desc: cudaLibMgMatrixDesc_t) -> cusolverStatus_t;
1166
1167pub type PFN_cusolverMgGetrf_bufferSize = unsafe extern "C" fn(
1168 handle: cusolverMgHandle_t,
1169 m: c_int,
1170 n: c_int,
1171 array_d_a: *mut *mut c_void,
1172 ia: c_int,
1173 ja: c_int,
1174 desc_a: cudaLibMgMatrixDesc_t,
1175 array_d_ipiv: *mut *mut c_int,
1176 compute_type: cudaDataType,
1177 lwork: *mut i64,
1178) -> cusolverStatus_t;
1179
1180pub type PFN_cusolverMgGetrf = unsafe extern "C" fn(
1181 handle: cusolverMgHandle_t,
1182 m: c_int,
1183 n: c_int,
1184 array_d_a: *mut *mut c_void,
1185 ia: c_int,
1186 ja: c_int,
1187 desc_a: cudaLibMgMatrixDesc_t,
1188 array_d_ipiv: *mut *mut c_int,
1189 compute_type: cudaDataType,
1190 array_d_work: *mut *mut c_void,
1191 lwork: i64,
1192 info: *mut c_int,
1193) -> cusolverStatus_t;
1194
/// Function-pointer type for the `cusolverMgPotrf_bufferSize` symbol, as
/// resolved by `CusolverMg::cusolver_mg_potrf_buffer_size`.
///
/// Takes the same leading arguments as [`PFN_cusolverMgPotrf`] but ends with
/// an out-pointer `lwork` instead of a workspace array, size, and `info`.
///
/// NOTE(review): `lwork` presumably receives the workspace size to pass to
/// the matching `Potrf` call — confirm against the cuSOLVERMg reference.
pub type PFN_cusolverMgPotrf_bufferSize = unsafe extern "C" fn(
    handle: cusolverMgHandle_t,
    uplo: cublasFillMode_t,
    n: c_int,
    array_d_a: *mut *mut c_void,
    ia: c_int,
    ja: c_int,
    desc_a: cudaLibMgMatrixDesc_t,
    compute_type: cudaDataType,
    lwork: *mut i64,
) -> cusolverStatus_t;
1206
/// Function-pointer type for the `cusolverMgPotrf` symbol, as resolved by
/// `CusolverMg::cusolver_mg_potrf`.
///
/// `uplo` selects the triangle via [`cublasFillMode_t`]; `lwork` is passed by
/// value and `info` is an out-pointer to a C `int`.
///
/// NOTE(review): `potrf` is the LAPACK name for Cholesky factorization;
/// `array_d_a` and `array_d_work` are presumably per-device pointer arrays —
/// confirm against the cuSOLVERMg reference.
pub type PFN_cusolverMgPotrf = unsafe extern "C" fn(
    handle: cusolverMgHandle_t,
    uplo: cublasFillMode_t,
    n: c_int,
    array_d_a: *mut *mut c_void,
    ia: c_int,
    ja: c_int,
    desc_a: cudaLibMgMatrixDesc_t,
    compute_type: cudaDataType,
    array_d_work: *mut *mut c_void,
    lwork: i64,
    info: *mut c_int,
) -> cusolverStatus_t;
1220
/// Function-pointer type for the `cusolverMgSyevd_bufferSize` symbol, as
/// resolved by `CusolverMg::cusolver_mg_syevd_buffer_size`.
///
/// Takes the same leading arguments as [`PFN_cusolverMgSyevd`] but ends with
/// an out-pointer `lwork` instead of a workspace array, size, and `info`.
///
/// NOTE(review): `lwork` presumably receives the workspace size to pass to
/// the matching `Syevd` call — confirm against the cuSOLVERMg reference.
pub type PFN_cusolverMgSyevd_bufferSize = unsafe extern "C" fn(
    handle: cusolverMgHandle_t,
    jobz: cusolverEigMode_t,
    uplo: cublasFillMode_t,
    n: c_int,
    array_d_a: *mut *mut c_void,
    ia: c_int,
    ja: c_int,
    desc_a: cudaLibMgMatrixDesc_t,
    w: *mut c_void,
    data_type_w: cudaDataType,
    compute_type: cudaDataType,
    lwork: *mut i64,
) -> cusolverStatus_t;
1235
/// Function-pointer type for the `cusolverMgSyevd` symbol, as resolved by
/// `CusolverMg::cusolver_mg_syevd`.
///
/// `jobz` is a [`cusolverEigMode_t`] (`NoVector` / `Vector`) and `uplo` a
/// [`cublasFillMode_t`]. `w` is an untyped pointer whose element type is
/// described separately by `data_type_w`. `lwork` is passed by value and
/// `info` is an out-pointer to a C `int`.
///
/// NOTE(review): `syevd` is the LAPACK name for the symmetric eigenvalue
/// solver; `w` presumably receives the eigenvalues — confirm against the
/// cuSOLVERMg reference.
pub type PFN_cusolverMgSyevd = unsafe extern "C" fn(
    handle: cusolverMgHandle_t,
    jobz: cusolverEigMode_t,
    uplo: cublasFillMode_t,
    n: c_int,
    array_d_a: *mut *mut c_void,
    ia: c_int,
    ja: c_int,
    desc_a: cudaLibMgMatrixDesc_t,
    w: *mut c_void,
    data_type_w: cudaDataType,
    compute_type: cudaDataType,
    array_d_work: *mut *mut c_void,
    lwork: i64,
    info: *mut c_int,
) -> cusolverStatus_t;
1252
1253fn cusolver_candidates() -> Vec<String> {
1256 platform::versioned_library_candidates("cusolver", &["13", "12", "11"])
1257}
1258
/// Generates the `Cusolver` symbol table from a list of
/// `accessor as "exported symbol": FnPointerType;` rows.
///
/// Each row expands to a private `OnceLock` field that caches the resolved
/// function pointer, plus a public accessor method of the same name that
/// looks the symbol up in the wrapped `Library` the first time it is called.
macro_rules! cusolver_fns {
    ($($field:ident as $symbol:literal : $sig:ty);* $(;)?) => {
        /// Loaded cuSOLVER library together with one lazily filled
        /// function-pointer cache per bound symbol.
        pub struct Cusolver {
            lib: Library,
            $($field: OnceLock<$sig>,)*
        }

        // Written by hand so that only `lib` is printed; the per-symbol
        // caches are elided via `finish_non_exhaustive`.
        impl core::fmt::Debug for Cusolver {
            fn fmt(&self, out: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                let mut builder = out.debug_struct("Cusolver");
                builder.field("lib", &self.lib);
                builder.finish_non_exhaustive()
            }
        }

        impl Cusolver {
            /// Wraps `lib` with every symbol cache still unset.
            fn empty(lib: Library) -> Self {
                Self { lib, $($field: OnceLock::new(),)* }
            }

            $(
                /// Returns the function pointer for this symbol, resolving
                /// and caching it on first use. Lookup failures are
                /// propagated and leave the cache unset.
                pub fn $field(&self) -> Result<$sig, LoaderError> {
                    match self.$field.get() {
                        Some(&cached) => Ok(cached),
                        None => {
                            let addr: *mut () = unsafe { self.lib.raw_symbol($symbol)? };
                            // NOTE(review): sound only if the exported symbol
                            // really has the declared signature and `addr`
                            // is non-null — confirm `raw_symbol` guarantees
                            // the latter.
                            let resolved: $sig =
                                unsafe { core::mem::transmute_copy::<*mut (), $sig>(&addr) };
                            // A concurrent caller may have filled the cache
                            // first; the locally resolved pointer is returned
                            // either way.
                            let _ = self.$field.set(resolved);
                            Ok(resolved)
                        }
                    }
                }
            )*
        }
    };
}
1286
// Symbol table for the main cuSOLVER library. Each row reads
// `accessor as "exported symbol": function-pointer type`.
// Routine families follow LAPACK naming; the S/D/C/Z prefix is the precision.
cusolver_fns! {
    // Dense (cusolverDn) handle lifecycle, stream binding, library version.
    cusolver_dn_create as "cusolverDnCreate": PFN_cusolverDnCreate;
    cusolver_dn_destroy as "cusolverDnDestroy": PFN_cusolverDnDestroy;
    cusolver_dn_set_stream as "cusolverDnSetStream": PFN_cusolverDnSetStream;
    cusolver_dn_get_stream as "cusolverDnGetStream": PFN_cusolverDnGetStream;
    cusolver_get_version as "cusolverGetVersion": PFN_cusolverGetVersion;
    // getrf / getrs family.
    cusolver_dn_sgetrf_buffer_size as "cusolverDnSgetrf_bufferSize": PFN_cusolverDnSgetrf_bufferSize;
    cusolver_dn_dgetrf_buffer_size as "cusolverDnDgetrf_bufferSize": PFN_cusolverDnDgetrf_bufferSize;
    cusolver_dn_cgetrf_buffer_size as "cusolverDnCgetrf_bufferSize": PFN_cusolverDnCgetrf_bufferSize;
    cusolver_dn_zgetrf_buffer_size as "cusolverDnZgetrf_bufferSize": PFN_cusolverDnZgetrf_bufferSize;
    cusolver_dn_sgetrf as "cusolverDnSgetrf": PFN_cusolverDnSgetrf;
    cusolver_dn_dgetrf as "cusolverDnDgetrf": PFN_cusolverDnDgetrf;
    cusolver_dn_cgetrf as "cusolverDnCgetrf": PFN_cusolverDnCgetrf;
    cusolver_dn_zgetrf as "cusolverDnZgetrf": PFN_cusolverDnZgetrf;
    cusolver_dn_sgetrs as "cusolverDnSgetrs": PFN_cusolverDnSgetrs;
    cusolver_dn_dgetrs as "cusolverDnDgetrs": PFN_cusolverDnDgetrs;
    cusolver_dn_cgetrs as "cusolverDnCgetrs": PFN_cusolverDnCgetrs;
    cusolver_dn_zgetrs as "cusolverDnZgetrs": PFN_cusolverDnZgetrs;
    // geqrf family.
    cusolver_dn_sgeqrf_buffer_size as "cusolverDnSgeqrf_bufferSize": PFN_cusolverDnSgeqrf_bufferSize;
    cusolver_dn_dgeqrf_buffer_size as "cusolverDnDgeqrf_bufferSize": PFN_cusolverDnDgeqrf_bufferSize;
    cusolver_dn_cgeqrf_buffer_size as "cusolverDnCgeqrf_bufferSize": PFN_cusolverDnCgeqrf_bufferSize;
    cusolver_dn_zgeqrf_buffer_size as "cusolverDnZgeqrf_bufferSize": PFN_cusolverDnZgeqrf_bufferSize;
    cusolver_dn_sgeqrf as "cusolverDnSgeqrf": PFN_cusolverDnSgeqrf;
    cusolver_dn_dgeqrf as "cusolverDnDgeqrf": PFN_cusolverDnDgeqrf;
    cusolver_dn_cgeqrf as "cusolverDnCgeqrf": PFN_cusolverDnCgeqrf;
    cusolver_dn_zgeqrf as "cusolverDnZgeqrf": PFN_cusolverDnZgeqrf;
    // potrf / potrs family.
    cusolver_dn_spotrf_buffer_size as "cusolverDnSpotrf_bufferSize": PFN_cusolverDnSpotrf_bufferSize;
    cusolver_dn_dpotrf_buffer_size as "cusolverDnDpotrf_bufferSize": PFN_cusolverDnDpotrf_bufferSize;
    cusolver_dn_cpotrf_buffer_size as "cusolverDnCpotrf_bufferSize": PFN_cusolverDnCpotrf_bufferSize;
    cusolver_dn_zpotrf_buffer_size as "cusolverDnZpotrf_bufferSize": PFN_cusolverDnZpotrf_bufferSize;
    cusolver_dn_spotrf as "cusolverDnSpotrf": PFN_cusolverDnSpotrf;
    cusolver_dn_dpotrf as "cusolverDnDpotrf": PFN_cusolverDnDpotrf;
    cusolver_dn_cpotrf as "cusolverDnCpotrf": PFN_cusolverDnCpotrf;
    cusolver_dn_zpotrf as "cusolverDnZpotrf": PFN_cusolverDnZpotrf;
    cusolver_dn_spotrs as "cusolverDnSpotrs": PFN_cusolverDnSpotrs;
    cusolver_dn_dpotrs as "cusolverDnDpotrs": PFN_cusolverDnDpotrs;
    cusolver_dn_cpotrs as "cusolverDnCpotrs": PFN_cusolverDnCpotrs;
    cusolver_dn_zpotrs as "cusolverDnZpotrs": PFN_cusolverDnZpotrs;
    // gesvd family.
    cusolver_dn_sgesvd_buffer_size as "cusolverDnSgesvd_bufferSize": PFN_cusolverDnSgesvd_bufferSize;
    cusolver_dn_dgesvd_buffer_size as "cusolverDnDgesvd_bufferSize": PFN_cusolverDnDgesvd_bufferSize;
    cusolver_dn_cgesvd_buffer_size as "cusolverDnCgesvd_bufferSize": PFN_cusolverDnCgesvd_bufferSize;
    cusolver_dn_zgesvd_buffer_size as "cusolverDnZgesvd_bufferSize": PFN_cusolverDnZgesvd_bufferSize;
    cusolver_dn_sgesvd as "cusolverDnSgesvd": PFN_cusolverDnSgesvd;
    cusolver_dn_dgesvd as "cusolverDnDgesvd": PFN_cusolverDnDgesvd;
    cusolver_dn_cgesvd as "cusolverDnCgesvd": PFN_cusolverDnCgesvd;
    cusolver_dn_zgesvd as "cusolverDnZgesvd": PFN_cusolverDnZgesvd;
    // syevd / heevd family.
    cusolver_dn_ssyevd_buffer_size as "cusolverDnSsyevd_bufferSize": PFN_cusolverDnSsyevd_bufferSize;
    cusolver_dn_dsyevd_buffer_size as "cusolverDnDsyevd_bufferSize": PFN_cusolverDnDsyevd_bufferSize;
    cusolver_dn_cheevd_buffer_size as "cusolverDnCheevd_bufferSize": PFN_cusolverDnCheevd_bufferSize;
    cusolver_dn_zheevd_buffer_size as "cusolverDnZheevd_bufferSize": PFN_cusolverDnZheevd_bufferSize;
    cusolver_dn_ssyevd as "cusolverDnSsyevd": PFN_cusolverDnSsyevd;
    cusolver_dn_dsyevd as "cusolverDnDsyevd": PFN_cusolverDnDsyevd;
    cusolver_dn_cheevd as "cusolverDnCheevd": PFN_cusolverDnCheevd;
    cusolver_dn_zheevd as "cusolverDnZheevd": PFN_cusolverDnZheevd;
    // `cusolverDnParams` lifecycle and the `cusolverDnX*` entry points.
    cusolver_dn_create_params as "cusolverDnCreateParams": PFN_cusolverDnCreateParams;
    cusolver_dn_destroy_params as "cusolverDnDestroyParams": PFN_cusolverDnDestroyParams;
    cusolver_dn_xgetrf_buffer_size as "cusolverDnXgetrf_bufferSize": PFN_cusolverDnXgetrf_bufferSize;
    cusolver_dn_xgetrf as "cusolverDnXgetrf": PFN_cusolverDnXgetrf;
    cusolver_dn_xgetrs as "cusolverDnXgetrs": PFN_cusolverDnXgetrs;
    cusolver_dn_xgeqrf_buffer_size as "cusolverDnXgeqrf_bufferSize": PFN_cusolverDnXgeqrf_bufferSize;
    cusolver_dn_xgeqrf as "cusolverDnXgeqrf": PFN_cusolverDnXgeqrf;
    cusolver_dn_xpotrf_buffer_size as "cusolverDnXpotrf_bufferSize": PFN_cusolverDnXpotrf_bufferSize;
    cusolver_dn_xpotrf as "cusolverDnXpotrf": PFN_cusolverDnXpotrf;
    cusolver_dn_xpotrs as "cusolverDnXpotrs": PFN_cusolverDnXpotrs;
    cusolver_dn_xsyevd_buffer_size as "cusolverDnXsyevd_bufferSize": PFN_cusolverDnXsyevd_bufferSize;
    cusolver_dn_xsyevd as "cusolverDnXsyevd": PFN_cusolverDnXsyevd;
    // syevj / heevj family and its info object.
    cusolver_dn_create_syevj_info as "cusolverDnCreateSyevjInfo": PFN_cusolverDnCreateSyevjInfo;
    cusolver_dn_destroy_syevj_info as "cusolverDnDestroySyevjInfo": PFN_cusolverDnDestroySyevjInfo;
    cusolver_dn_xsyevj_set_tolerance as "cusolverDnXsyevjSetTolerance": PFN_cusolverDnXsyevjSetTolerance;
    cusolver_dn_xsyevj_set_max_sweeps as "cusolverDnXsyevjSetMaxSweeps": PFN_cusolverDnXsyevjSetMaxSweeps;
    cusolver_dn_ssyevj_buffer_size as "cusolverDnSsyevj_bufferSize": PFN_cusolverDnSsyevj_bufferSize;
    cusolver_dn_dsyevj_buffer_size as "cusolverDnDsyevj_bufferSize": PFN_cusolverDnDsyevj_bufferSize;
    cusolver_dn_cheevj_buffer_size as "cusolverDnCheevj_bufferSize": PFN_cusolverDnCheevj_bufferSize;
    cusolver_dn_zheevj_buffer_size as "cusolverDnZheevj_bufferSize": PFN_cusolverDnZheevj_bufferSize;
    cusolver_dn_ssyevj as "cusolverDnSsyevj": PFN_cusolverDnSsyevj;
    cusolver_dn_dsyevj as "cusolverDnDsyevj": PFN_cusolverDnDsyevj;
    cusolver_dn_cheevj as "cusolverDnCheevj": PFN_cusolverDnCheevj;
    cusolver_dn_zheevj as "cusolverDnZheevj": PFN_cusolverDnZheevj;
    // gesvdj family and its info object.
    cusolver_dn_create_gesvdj_info as "cusolverDnCreateGesvdjInfo": PFN_cusolverDnCreateGesvdjInfo;
    cusolver_dn_destroy_gesvdj_info as "cusolverDnDestroyGesvdjInfo": PFN_cusolverDnDestroyGesvdjInfo;
    cusolver_dn_sgesvdj_buffer_size as "cusolverDnSgesvdj_bufferSize": PFN_cusolverDnSgesvdj_bufferSize;
    cusolver_dn_dgesvdj_buffer_size as "cusolverDnDgesvdj_bufferSize": PFN_cusolverDnDgesvdj_bufferSize;
    cusolver_dn_cgesvdj_buffer_size as "cusolverDnCgesvdj_bufferSize": PFN_cusolverDnCgesvdj_bufferSize;
    cusolver_dn_zgesvdj_buffer_size as "cusolverDnZgesvdj_bufferSize": PFN_cusolverDnZgesvdj_bufferSize;
    cusolver_dn_sgesvdj as "cusolverDnSgesvdj": PFN_cusolverDnSgesvdj;
    cusolver_dn_dgesvdj as "cusolverDnDgesvdj": PFN_cusolverDnDgesvdj;
    cusolver_dn_cgesvdj as "cusolverDnCgesvdj": PFN_cusolverDnCgesvdj;
    cusolver_dn_zgesvdj as "cusolverDnZgesvdj": PFN_cusolverDnZgesvdj;
    // orgqr / ungqr family.
    cusolver_dn_sorgqr_buffer_size as "cusolverDnSorgqr_bufferSize": PFN_cusolverDnSorgqr_bufferSize;
    cusolver_dn_dorgqr_buffer_size as "cusolverDnDorgqr_bufferSize": PFN_cusolverDnDorgqr_bufferSize;
    cusolver_dn_cungqr_buffer_size as "cusolverDnCungqr_bufferSize": PFN_cusolverDnCungqr_bufferSize;
    cusolver_dn_zungqr_buffer_size as "cusolverDnZungqr_bufferSize": PFN_cusolverDnZungqr_bufferSize;
    cusolver_dn_sorgqr as "cusolverDnSorgqr": PFN_cusolverDnSorgqr;
    cusolver_dn_dorgqr as "cusolverDnDorgqr": PFN_cusolverDnDorgqr;
    cusolver_dn_cungqr as "cusolverDnCungqr": PFN_cusolverDnCungqr;
    cusolver_dn_zungqr as "cusolverDnZungqr": PFN_cusolverDnZungqr;
    // ormqr / unmqr family.
    cusolver_dn_sormqr_buffer_size as "cusolverDnSormqr_bufferSize": PFN_cusolverDnSormqr_bufferSize;
    cusolver_dn_dormqr_buffer_size as "cusolverDnDormqr_bufferSize": PFN_cusolverDnDormqr_bufferSize;
    cusolver_dn_cunmqr_buffer_size as "cusolverDnCunmqr_bufferSize": PFN_cusolverDnCunmqr_bufferSize;
    cusolver_dn_zunmqr_buffer_size as "cusolverDnZunmqr_bufferSize": PFN_cusolverDnZunmqr_bufferSize;
    cusolver_dn_sormqr as "cusolverDnSormqr": PFN_cusolverDnSormqr;
    cusolver_dn_dormqr as "cusolverDnDormqr": PFN_cusolverDnDormqr;
    cusolver_dn_cunmqr as "cusolverDnCunmqr": PFN_cusolverDnCunmqr;
    cusolver_dn_zunmqr as "cusolverDnZunmqr": PFN_cusolverDnZunmqr;
    // Sparse (cusolverSp) handle lifecycle and csrlsvchol / csrlsvqr.
    cusolver_sp_create as "cusolverSpCreate": PFN_cusolverSpCreate;
    cusolver_sp_destroy as "cusolverSpDestroy": PFN_cusolverSpDestroy;
    cusolver_sp_set_stream as "cusolverSpSetStream": PFN_cusolverSpSetStream;
    cusolver_sp_scsrlsvchol as "cusolverSpScsrlsvchol": PFN_cusolverSpScsrlsvchol;
    cusolver_sp_dcsrlsvchol as "cusolverSpDcsrlsvchol": PFN_cusolverSpDcsrlsvchol;
    cusolver_sp_scsrlsvqr as "cusolverSpScsrlsvqr": PFN_cusolverSpScsrlsvqr;
    cusolver_sp_dcsrlsvqr as "cusolverSpDcsrlsvqr": PFN_cusolverSpDcsrlsvqr;
    // cusolverRf handle lifecycle, setup, analyze, refactor, solve.
    cusolver_rf_create as "cusolverRfCreate": PFN_cusolverRfCreate;
    cusolver_rf_destroy as "cusolverRfDestroy": PFN_cusolverRfDestroy;
    cusolver_rf_setup_device as "cusolverRfSetupDevice": PFN_cusolverRfSetupDevice;
    cusolver_rf_analyze as "cusolverRfAnalyze": PFN_cusolverRfAnalyze;
    cusolver_rf_refactor as "cusolverRfRefactor": PFN_cusolverRfRefactor;
    cusolver_rf_solve as "cusolverRfSolve": PFN_cusolverRfSolve;
    // gels family (SS / DD / CC / ZZ variants).
    cusolver_dn_ssgels_buffer_size as "cusolverDnSSgels_bufferSize": PFN_cusolverDnSSgels_bufferSize;
    cusolver_dn_ddgels_buffer_size as "cusolverDnDDgels_bufferSize": PFN_cusolverDnDDgels_bufferSize;
    cusolver_dn_ccgels_buffer_size as "cusolverDnCCgels_bufferSize": PFN_cusolverDnCCgels_bufferSize;
    cusolver_dn_zzgels_buffer_size as "cusolverDnZZgels_bufferSize": PFN_cusolverDnZZgels_bufferSize;
    cusolver_dn_ssgels as "cusolverDnSSgels": PFN_cusolverDnSSgels;
    cusolver_dn_ddgels as "cusolverDnDDgels": PFN_cusolverDnDDgels;
    cusolver_dn_ccgels as "cusolverDnCCgels": PFN_cusolverDnCCgels;
    cusolver_dn_zzgels as "cusolverDnZZgels": PFN_cusolverDnZZgels;
    // potri family.
    cusolver_dn_spotri_buffer_size as "cusolverDnSpotri_bufferSize": PFN_cusolverDnSpotri_bufferSize;
    cusolver_dn_dpotri_buffer_size as "cusolverDnDpotri_bufferSize": PFN_cusolverDnDpotri_bufferSize;
    cusolver_dn_cpotri_buffer_size as "cusolverDnCpotri_bufferSize": PFN_cusolverDnCpotri_bufferSize;
    cusolver_dn_zpotri_buffer_size as "cusolverDnZpotri_bufferSize": PFN_cusolverDnZpotri_bufferSize;
    cusolver_dn_spotri as "cusolverDnSpotri": PFN_cusolverDnSpotri;
    cusolver_dn_dpotri as "cusolverDnDpotri": PFN_cusolverDnDpotri;
    cusolver_dn_cpotri as "cusolverDnCpotri": PFN_cusolverDnCpotri;
    cusolver_dn_zpotri as "cusolverDnZpotri": PFN_cusolverDnZpotri;
    // Batched syevj / heevj family.
    cusolver_dn_ssyevj_batched_buffer_size as "cusolverDnSsyevjBatched_bufferSize": PFN_cusolverDnSsyevjBatched_bufferSize;
    cusolver_dn_dsyevj_batched_buffer_size as "cusolverDnDsyevjBatched_bufferSize": PFN_cusolverDnDsyevjBatched_bufferSize;
    cusolver_dn_cheevj_batched_buffer_size as "cusolverDnCheevjBatched_bufferSize": PFN_cusolverDnCheevjBatched_bufferSize;
    cusolver_dn_zheevj_batched_buffer_size as "cusolverDnZheevjBatched_bufferSize": PFN_cusolverDnZheevjBatched_bufferSize;
    cusolver_dn_ssyevj_batched as "cusolverDnSsyevjBatched": PFN_cusolverDnSsyevjBatched;
    cusolver_dn_dsyevj_batched as "cusolverDnDsyevjBatched": PFN_cusolverDnDsyevjBatched;
    cusolver_dn_cheevj_batched as "cusolverDnCheevjBatched": PFN_cusolverDnCheevjBatched;
    cusolver_dn_zheevj_batched as "cusolverDnZheevjBatched": PFN_cusolverDnZheevjBatched;
    // Batched gesvdj family.
    cusolver_dn_sgesvdj_batched_buffer_size as "cusolverDnSgesvdjBatched_bufferSize": PFN_cusolverDnSgesvdjBatched_bufferSize;
    cusolver_dn_dgesvdj_batched_buffer_size as "cusolverDnDgesvdjBatched_bufferSize": PFN_cusolverDnDgesvdjBatched_bufferSize;
    cusolver_dn_cgesvdj_batched_buffer_size as "cusolverDnCgesvdjBatched_bufferSize": PFN_cusolverDnCgesvdjBatched_bufferSize;
    cusolver_dn_zgesvdj_batched_buffer_size as "cusolverDnZgesvdjBatched_bufferSize": PFN_cusolverDnZgesvdjBatched_bufferSize;
    cusolver_dn_sgesvdj_batched as "cusolverDnSgesvdjBatched": PFN_cusolverDnSgesvdjBatched;
    cusolver_dn_dgesvdj_batched as "cusolverDnDgesvdjBatched": PFN_cusolverDnDgesvdjBatched;
    cusolver_dn_cgesvdj_batched as "cusolverDnCgesvdjBatched": PFN_cusolverDnCgesvdjBatched;
    cusolver_dn_zgesvdj_batched as "cusolverDnZgesvdjBatched": PFN_cusolverDnZgesvdjBatched;
}
1453
1454pub fn cusolver() -> Result<&'static Cusolver, LoaderError> {
1455 static CUSOLVER: OnceLock<Cusolver> = OnceLock::new();
1456 if let Some(c) = CUSOLVER.get() {
1457 return Ok(c);
1458 }
1459 let candidates: Vec<&'static str> = cusolver_candidates()
1460 .into_iter()
1461 .map(|s| Box::leak(s.into_boxed_str()) as &'static str)
1462 .collect();
1463 let candidates_leaked: &'static [&'static str] = Box::leak(candidates.into_boxed_slice());
1464 let lib = Library::open("cusolver", candidates_leaked)?;
1465 let c = Cusolver::empty(lib);
1466 let _ = CUSOLVER.set(c);
1467 Ok(CUSOLVER.get().expect("OnceLock set or lost race"))
1468}
1469
1470fn cusolver_mg_candidates() -> Vec<String> {
1475 platform::versioned_library_candidates("cusolverMg", &["13", "12", "11"])
1476}
1477
/// Generates the `CusolverMg` symbol table from a list of
/// `accessor as "exported symbol": FnPointerType;` rows.
///
/// Each row expands to a private `OnceLock` field that caches the resolved
/// function pointer, plus a public accessor method of the same name that
/// looks the symbol up in the wrapped `Library` the first time it is called.
macro_rules! cusolver_mg_fns {
    ($($field:ident as $symbol:literal : $sig:ty);* $(;)?) => {
        /// Loaded cuSOLVERMg library together with one lazily filled
        /// function-pointer cache per bound symbol.
        pub struct CusolverMg {
            lib: Library,
            $($field: OnceLock<$sig>,)*
        }

        // Written by hand so that only `lib` is printed; the per-symbol
        // caches are elided via `finish_non_exhaustive`.
        impl core::fmt::Debug for CusolverMg {
            fn fmt(&self, out: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                let mut builder = out.debug_struct("CusolverMg");
                builder.field("lib", &self.lib);
                builder.finish_non_exhaustive()
            }
        }

        impl CusolverMg {
            /// Wraps `lib` with every symbol cache still unset.
            fn empty(lib: Library) -> Self {
                Self { lib, $($field: OnceLock::new(),)* }
            }

            $(
                /// Returns the function pointer for this symbol, resolving
                /// and caching it on first use. Lookup failures are
                /// propagated and leave the cache unset.
                pub fn $field(&self) -> Result<$sig, LoaderError> {
                    match self.$field.get() {
                        Some(&cached) => Ok(cached),
                        None => {
                            let addr: *mut () = unsafe { self.lib.raw_symbol($symbol)? };
                            // NOTE(review): sound only if the exported symbol
                            // really has the declared signature and `addr`
                            // is non-null — confirm `raw_symbol` guarantees
                            // the latter.
                            let resolved: $sig =
                                unsafe { core::mem::transmute_copy::<*mut (), $sig>(&addr) };
                            // A concurrent caller may have filled the cache
                            // first; the locally resolved pointer is returned
                            // either way.
                            let _ = self.$field.set(resolved);
                            Ok(resolved)
                        }
                    }
                }
            )*
        }
    };
}
1505
// Symbol table for the cuSOLVERMg library. Each row reads
// `accessor as "exported symbol": function-pointer type`.
cusolver_mg_fns! {
    // Handle lifecycle and device selection.
    cusolver_mg_create as "cusolverMgCreate": PFN_cusolverMgCreate;
    cusolver_mg_destroy as "cusolverMgDestroy": PFN_cusolverMgDestroy;
    cusolver_mg_device_select as "cusolverMgDeviceSelect": PFN_cusolverMgDeviceSelect;
    // Device grid and matrix descriptor lifecycle.
    cusolver_mg_create_device_grid as "cusolverMgCreateDeviceGrid": PFN_cusolverMgCreateDeviceGrid;
    cusolver_mg_destroy_grid as "cusolverMgDestroyGrid": PFN_cusolverMgDestroyGrid;
    cusolver_mg_create_matrix_desc as "cusolverMgCreateMatrixDesc": PFN_cusolverMgCreateMatrixDesc;
    cusolver_mg_destroy_matrix_desc as "cusolverMgDestroyMatrixDesc": PFN_cusolverMgDestroyMatrixDesc;
    // Solvers: each routine is paired with its `_bufferSize` query.
    cusolver_mg_getrf_buffer_size as "cusolverMgGetrf_bufferSize": PFN_cusolverMgGetrf_bufferSize;
    cusolver_mg_getrf as "cusolverMgGetrf": PFN_cusolverMgGetrf;
    cusolver_mg_potrf_buffer_size as "cusolverMgPotrf_bufferSize": PFN_cusolverMgPotrf_bufferSize;
    cusolver_mg_potrf as "cusolverMgPotrf": PFN_cusolverMgPotrf;
    cusolver_mg_syevd_buffer_size as "cusolverMgSyevd_bufferSize": PFN_cusolverMgSyevd_bufferSize;
    cusolver_mg_syevd as "cusolverMgSyevd": PFN_cusolverMgSyevd;
}
1521
1522pub fn cusolver_mg() -> Result<&'static CusolverMg, LoaderError> {
1523 static MG: OnceLock<CusolverMg> = OnceLock::new();
1524 if let Some(c) = MG.get() {
1525 return Ok(c);
1526 }
1527 let candidates: Vec<&'static str> = cusolver_mg_candidates()
1528 .into_iter()
1529 .map(|s| Box::leak(s.into_boxed_str()) as &'static str)
1530 .collect();
1531 let leaked: &'static [&'static str] = Box::leak(candidates.into_boxed_slice());
1532 let lib = Library::open("cusolverMg", leaked)?;
1533 let _ = MG.set(CusolverMg::empty(lib));
1534 Ok(MG.get().expect("OnceLock set or lost race"))
1535}