Skip to main content

baracuda_cusparse_sys/
lib.rs

1//! Raw FFI + dynamic loader for NVIDIA cuSPARSE (generic API subset).
2
3#![allow(non_camel_case_types, non_snake_case, non_upper_case_globals)]
4#![warn(missing_debug_implementations)]
5
6use core::ffi::{c_int, c_void};
7use std::sync::OnceLock;
8
9use baracuda_core::{platform, Library, LoaderError};
10use baracuda_cuda_sys::runtime::cudaStream_t;
11use baracuda_types::CudaStatus;
12
13// ---- handles --------------------------------------------------------------
14
/// Opaque handle to a cuSPARSE library context.
pub type cusparseHandle_t = *mut c_void;
/// Opaque descriptor of a sparse matrix (generic API).
pub type cusparseSpMatDescr_t = *mut c_void;
/// Opaque descriptor of a dense matrix (generic API).
pub type cusparseDnMatDescr_t = *mut c_void;
/// Opaque descriptor of a dense vector (generic API).
pub type cusparseDnVecDescr_t = *mut c_void;
/// Opaque descriptor carrying SpGEMM state between its phases.
pub type cusparseSpGEMMDescr_t = *mut c_void;
/// Opaque descriptor carrying SpSV (triangular solve, vector) state.
pub type cusparseSpSVDescr_t = *mut c_void;
/// Opaque descriptor carrying SpSM (triangular solve, matrix) state.
pub type cusparseSpSMDescr_t = *mut c_void;
/// Opaque matrix descriptor; not referenced by any signature in this file.
pub type cusparseMatDescr_t = *mut c_void;
23
24// ---- enums ----------------------------------------------------------------
25
/// Operation applied to a matrix operand before it is used
/// (`cusparseOperation_t`).
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum cusparseOperation_t {
    /// Use the operand as-is (non-transpose).
    N = 0,
    /// Use the transpose of the operand.
    T = 1,
    /// Use the conjugate transpose of the operand.
    C = 2,
}
33
/// Integer type used for the index arrays of a sparse descriptor
/// (`cusparseIndexType_t`).
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum cusparseIndexType_t {
    /// 16-bit unsigned indices.
    I16U = 1,
    /// 32-bit signed indices.
    I32I = 2,
    /// 64-bit signed indices.
    I64I = 3,
}
41
/// Base of the indices stored in a sparse structure
/// (`cusparseIndexBase_t`).
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum cusparseIndexBase_t {
    /// Indices start at 0.
    Zero = 0,
    /// Indices start at 1.
    One = 1,
}
48
/// Memory layout of a dense matrix (`cusparseOrder_t`).
///
/// The discriminants mirror `cusparse.h`, where column-major is 1 and
/// row-major is 2. They were previously swapped, which made every dense
/// descriptor created through these bindings use the opposite layout from
/// the one requested.
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum cusparseOrder_t {
    /// Row-major layout (`CUSPARSE_ORDER_ROW`).
    Row = 2,
    /// Column-major layout (`CUSPARSE_ORDER_COL`).
    Col = 1,
}
55
/// Algorithm selector for SpMV (`cusparseSpMVAlg_t`).
///
/// Variants are listed in discriminant order. The determinism notes follow
/// the cuSPARSE documentation for `cusparseSpMV`.
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum cusparseSpMVAlg_t {
    /// Library-chosen default for the matrix format in use.
    Default = 0,
    /// COO algorithm 1; results may differ slightly between runs.
    CooAlg1 = 1,
    /// CSR/CSC algorithm 1; results may differ slightly between runs.
    CsrAlg1 = 2,
    /// CSR/CSC algorithm 2; bit-wise reproducible results.
    CsrAlg2 = 3,
    /// COO algorithm 2; bit-wise reproducible results.
    CooAlg2 = 4,
}
68
/// Algorithm selector for SpMM (`cusparseSpMMAlg_t`).
///
/// Discriminants follow `cusparse.h`. The previous values were a dense
/// 0..=8 sequence that did not match the header, so e.g. `CsrAlg1` was
/// passed to cuSPARSE as `CUSPARSE_SPMM_COO_ALG2`.
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum cusparseSpMMAlg_t {
    /// `CUSPARSE_SPMM_ALG_DEFAULT`.
    Default = 0,
    /// `CUSPARSE_SPMM_COO_ALG1`.
    CooAlg1 = 1,
    /// `CUSPARSE_SPMM_CSR_ALG1`.
    CsrAlg1 = 4,
    /// `CUSPARSE_SPMM_COO_ALG2`.
    CooAlg2 = 2,
    /// `CUSPARSE_SPMM_COO_ALG3`.
    CooAlg3 = 3,
    /// `CUSPARSE_SPMM_CSR_ALG2`.
    CsrAlg2 = 6,
    /// `CUSPARSE_SPMM_CSR_ALG3`.
    CsrAlg3 = 12,
    /// `CUSPARSE_SPMM_BSR_ALG1`.
    Bsr = 14,
    /// NOTE(review): `cusparse.h` has no `CSR_ALG4` enumerator and does not
    /// assign the value 8. Kept only so existing references still compile;
    /// cuSPARSE is expected to reject it. Confirm and remove.
    CsrAlg4 = 8,
}
82
/// Algorithm selector for SpGEMM (`cusparseSpGEMMAlg_t`).
///
/// Discriminants follow `cusparse.h`, where values 1 and 2 are the
/// `CSR_ALG_DETERMINITIC` / `CSR_ALG_NONDETERMINITIC` selectors and
/// `ALG1..ALG3` are 3..=5. The previous values (1, 2, 3) therefore selected
/// different algorithms from the ones named.
#[repr(i32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum cusparseSpGEMMAlg_t {
    /// `CUSPARSE_SPGEMM_DEFAULT`.
    Default = 0,
    /// `CUSPARSE_SPGEMM_ALG1`.
    Alg1 = 3,
    /// `CUSPARSE_SPGEMM_ALG2`.
    Alg2 = 4,
    /// `CUSPARSE_SPGEMM_ALG3`.
    Alg3 = 5,
}

impl cusparseSpGEMMAlg_t {
    /// NOTE(review): `cusparse.h` has no enumerator with this name, and its
    /// old value (4) is really `CUSPARSE_SPGEMM_ALG2`. It is kept as an alias
    /// of [`Self::Default`] so existing `cusparseSpGEMMAlg_t::CsrMemoryDefault`
    /// references still compile. Confirm the intended algorithm with callers.
    #[allow(non_upper_case_globals)]
    pub const CsrMemoryDefault: Self = Self::Default;
}
92
/// Algorithm selector for SpSV (`cusparseSpSVAlg_t`); only the default is
/// exposed here.
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum cusparseSpSVAlg_t {
    /// Default algorithm.
    Default = 0,
}
98
/// Algorithm selector for SpSM (`cusparseSpSMAlg_t`); only the default is
/// exposed here.
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum cusparseSpSMAlg_t {
    /// Default algorithm.
    Default = 0,
}
104
/// Algorithm selector for SDDMM (`cusparseSDDMMAlg_t`); only the default is
/// exposed here.
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum cusparseSDDMMAlg_t {
    /// Default algorithm.
    Default = 0,
}
110
/// Algorithm selector for the CSR-to-CSC conversion
/// (`cusparseCsr2CscAlg_t`).
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum cusparseCsr2CscAlg_t {
    /// Algorithm 1.
    Alg1 = 1,
    /// Algorithm 2.
    Alg2 = 2,
}

impl cusparseCsr2CscAlg_t {
    /// Default algorithm; an alias for [`Self::Alg1`].
    pub const DEFAULT: Self = cusparseCsr2CscAlg_t::Alg1;
}
121
/// Which triangle of a matrix is stored (`cusparseFillMode_t`).
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum cusparseFillMode_t {
    /// Lower triangle.
    Lower = 0,
    /// Upper triangle.
    Upper = 1,
}
128
/// Whether the matrix diagonal is implicitly all ones
/// (`cusparseDiagType_t`).
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum cusparseDiagType_t {
    /// Diagonal entries are taken from the stored values.
    NonUnit = 0,
    /// Diagonal entries are treated as 1.
    Unit = 1,
}
135
/// Attribute key accepted by `cusparseSpMatSetAttribute`
/// (`cusparseSpMatAttribute_t`).
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum cusparseSpMatAttribute_t {
    /// Selects the fill mode; the value is a [`cusparseFillMode_t`].
    FillMode = 0,
    /// Selects the diagonal type; the value is a [`cusparseDiagType_t`].
    DiagType = 1,
}
142
/// Element types (`cudaDataType`) accepted by the generic cuSPARSE /
/// cuSOLVER APIs. This is deliberately only the subset in use at v0.1.
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum cudaDataType {
    /// Real 32-bit float.
    R_32F = 0,
    /// Real 64-bit float.
    R_64F = 1,
    /// Real 16-bit (half) float.
    R_16F = 2,
    /// Complex number made of two 32-bit floats.
    C_32F = 4,
    /// Complex number made of two 64-bit floats.
    C_64F = 5,
    /// Real 16-bit bfloat.
    R_16BF = 14,
}
155
156// ---- status ---------------------------------------------------------------
157
/// Status code returned by every cuSPARSE entry point.
///
/// This is a transparent wrapper around the raw `i32`, so codes that have no
/// named constant here can still be represented.
#[repr(transparent)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct cusparseStatus_t(pub i32);

impl cusparseStatus_t {
    /// `CUSPARSE_STATUS_SUCCESS`
    pub const SUCCESS: Self = Self(0);
    /// `CUSPARSE_STATUS_NOT_INITIALIZED`
    pub const NOT_INITIALIZED: Self = Self(1);
    /// `CUSPARSE_STATUS_ALLOC_FAILED`
    pub const ALLOC_FAILED: Self = Self(2);
    /// `CUSPARSE_STATUS_INVALID_VALUE`
    pub const INVALID_VALUE: Self = Self(3);
    /// `CUSPARSE_STATUS_ARCH_MISMATCH`
    pub const ARCH_MISMATCH: Self = Self(4);
    /// `CUSPARSE_STATUS_MAPPING_ERROR`
    pub const MAPPING_ERROR: Self = Self(5);
    /// `CUSPARSE_STATUS_EXECUTION_FAILED`
    pub const EXECUTION_FAILED: Self = Self(6);
    /// `CUSPARSE_STATUS_INTERNAL_ERROR`
    pub const INTERNAL_ERROR: Self = Self(7);
    /// `CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED`
    pub const MATRIX_TYPE_NOT_SUPPORTED: Self = Self(8);
    /// `CUSPARSE_STATUS_ZERO_PIVOT`
    pub const ZERO_PIVOT: Self = Self(9);
    /// `CUSPARSE_STATUS_NOT_SUPPORTED`
    pub const NOT_SUPPORTED: Self = Self(10);
    /// `CUSPARSE_STATUS_INSUFFICIENT_RESOURCES`
    pub const INSUFFICIENT_RESOURCES: Self = Self(11);

    /// Returns `true` only when the code equals [`Self::SUCCESS`].
    pub const fn is_success(self) -> bool {
        self.0 == Self::SUCCESS.0
    }
}
180
181impl CudaStatus for cusparseStatus_t {
182    fn code(self) -> i32 {
183        self.0
184    }
185    fn name(self) -> &'static str {
186        match self.0 {
187            0 => "CUSPARSE_STATUS_SUCCESS",
188            1 => "CUSPARSE_STATUS_NOT_INITIALIZED",
189            2 => "CUSPARSE_STATUS_ALLOC_FAILED",
190            3 => "CUSPARSE_STATUS_INVALID_VALUE",
191            4 => "CUSPARSE_STATUS_ARCH_MISMATCH",
192            6 => "CUSPARSE_STATUS_EXECUTION_FAILED",
193            7 => "CUSPARSE_STATUS_INTERNAL_ERROR",
194            10 => "CUSPARSE_STATUS_NOT_SUPPORTED",
195            _ => "CUSPARSE_STATUS_UNRECOGNIZED",
196        }
197    }
198    fn description(self) -> &'static str {
199        match self.0 {
200            0 => "success",
201            1 => "cuSPARSE handle not initialized",
202            2 => "allocation failed",
203            3 => "invalid argument",
204            6 => "GPU execution failed",
205            10 => "operation not supported",
206            _ => "unrecognized cuSPARSE status code",
207        }
208    }
209    fn is_success(self) -> bool {
210        cusparseStatus_t::is_success(self)
211    }
212    fn library(self) -> &'static str {
213        "cusparse"
214    }
215}
216
217// ---- function-pointer types ----------------------------------------------
218
219pub type PFN_cusparseCreate =
220    unsafe extern "C" fn(handle: *mut cusparseHandle_t) -> cusparseStatus_t;
221pub type PFN_cusparseDestroy = unsafe extern "C" fn(handle: cusparseHandle_t) -> cusparseStatus_t;
222pub type PFN_cusparseSetStream =
223    unsafe extern "C" fn(handle: cusparseHandle_t, stream: cudaStream_t) -> cusparseStatus_t;
224pub type PFN_cusparseGetVersion =
225    unsafe extern "C" fn(handle: cusparseHandle_t, version: *mut c_int) -> cusparseStatus_t;
226
227pub type PFN_cusparseCreateCsr = unsafe extern "C" fn(
228    sp_mat: *mut cusparseSpMatDescr_t,
229    rows: i64,
230    cols: i64,
231    nnz: i64,
232    csr_row_offsets: *mut c_void,
233    csr_col_ind: *mut c_void,
234    csr_values: *mut c_void,
235    csr_row_offsets_type: cusparseIndexType_t,
236    csr_col_ind_type: cusparseIndexType_t,
237    idx_base: cusparseIndexBase_t,
238    value_type: cudaDataType,
239) -> cusparseStatus_t;
240pub type PFN_cusparseDestroySpMat =
241    unsafe extern "C" fn(descr: cusparseSpMatDescr_t) -> cusparseStatus_t;
242
243pub type PFN_cusparseCreateDnVec = unsafe extern "C" fn(
244    descr: *mut cusparseDnVecDescr_t,
245    size: i64,
246    values: *mut c_void,
247    value_type: cudaDataType,
248) -> cusparseStatus_t;
249pub type PFN_cusparseDestroyDnVec =
250    unsafe extern "C" fn(descr: cusparseDnVecDescr_t) -> cusparseStatus_t;
251
252pub type PFN_cusparseSpMV_bufferSize = unsafe extern "C" fn(
253    handle: cusparseHandle_t,
254    op: cusparseOperation_t,
255    alpha: *const c_void,
256    mat_a: cusparseSpMatDescr_t,
257    vec_x: cusparseDnVecDescr_t,
258    beta: *const c_void,
259    vec_y: cusparseDnVecDescr_t,
260    compute_type: cudaDataType,
261    alg: cusparseSpMVAlg_t,
262    buffer_size: *mut usize,
263) -> cusparseStatus_t;
264
265pub type PFN_cusparseSpMV = unsafe extern "C" fn(
266    handle: cusparseHandle_t,
267    op: cusparseOperation_t,
268    alpha: *const c_void,
269    mat_a: cusparseSpMatDescr_t,
270    vec_x: cusparseDnVecDescr_t,
271    beta: *const c_void,
272    vec_y: cusparseDnVecDescr_t,
273    compute_type: cudaDataType,
274    alg: cusparseSpMVAlg_t,
275    external_buffer: *mut c_void,
276) -> cusparseStatus_t;
277
278// ---- CSC / COO / BSR / Dense descriptors ---------------------------------
279
280pub type PFN_cusparseCreateCsc = unsafe extern "C" fn(
281    sp_mat: *mut cusparseSpMatDescr_t,
282    rows: i64,
283    cols: i64,
284    nnz: i64,
285    csc_col_offsets: *mut c_void,
286    csc_row_ind: *mut c_void,
287    csc_values: *mut c_void,
288    csc_col_offsets_type: cusparseIndexType_t,
289    csc_row_ind_type: cusparseIndexType_t,
290    idx_base: cusparseIndexBase_t,
291    value_type: cudaDataType,
292) -> cusparseStatus_t;
293
294pub type PFN_cusparseCreateCoo = unsafe extern "C" fn(
295    sp_mat: *mut cusparseSpMatDescr_t,
296    rows: i64,
297    cols: i64,
298    nnz: i64,
299    coo_row_ind: *mut c_void,
300    coo_col_ind: *mut c_void,
301    coo_values: *mut c_void,
302    coo_idx_type: cusparseIndexType_t,
303    idx_base: cusparseIndexBase_t,
304    value_type: cudaDataType,
305) -> cusparseStatus_t;
306
307pub type PFN_cusparseCreateBsr = unsafe extern "C" fn(
308    sp_mat: *mut cusparseSpMatDescr_t,
309    brows: i64,
310    bcols: i64,
311    bnnz: i64,
312    row_block_dim: i64,
313    col_block_dim: i64,
314    bsr_row_offsets: *mut c_void,
315    bsr_col_ind: *mut c_void,
316    bsr_values: *mut c_void,
317    bsr_row_offsets_type: cusparseIndexType_t,
318    bsr_col_ind_type: cusparseIndexType_t,
319    idx_base: cusparseIndexBase_t,
320    value_type: cudaDataType,
321    order: cusparseOrder_t,
322) -> cusparseStatus_t;
323
324pub type PFN_cusparseCreateDnMat = unsafe extern "C" fn(
325    descr: *mut cusparseDnMatDescr_t,
326    rows: i64,
327    cols: i64,
328    ld: i64,
329    values: *mut c_void,
330    value_type: cudaDataType,
331    order: cusparseOrder_t,
332) -> cusparseStatus_t;
333
334pub type PFN_cusparseDestroyDnMat =
335    unsafe extern "C" fn(descr: cusparseDnMatDescr_t) -> cusparseStatus_t;
336
337pub type PFN_cusparseSpMatGetSize = unsafe extern "C" fn(
338    sp_mat: cusparseSpMatDescr_t,
339    rows: *mut i64,
340    cols: *mut i64,
341    nnz: *mut i64,
342) -> cusparseStatus_t;
343
344pub type PFN_cusparseSpMatSetAttribute = unsafe extern "C" fn(
345    sp_mat: cusparseSpMatDescr_t,
346    attribute: cusparseSpMatAttribute_t,
347    data: *const c_void,
348    data_size: usize,
349) -> cusparseStatus_t;
350
351pub type PFN_cusparseCsrSetPointers = unsafe extern "C" fn(
352    sp_mat: cusparseSpMatDescr_t,
353    csr_row_offsets: *mut c_void,
354    csr_col_ind: *mut c_void,
355    csr_values: *mut c_void,
356) -> cusparseStatus_t;
357
358pub type PFN_cusparseCscSetPointers = unsafe extern "C" fn(
359    sp_mat: cusparseSpMatDescr_t,
360    csc_col_offsets: *mut c_void,
361    csc_row_ind: *mut c_void,
362    csc_values: *mut c_void,
363) -> cusparseStatus_t;
364
365pub type PFN_cusparseCooSetPointers = unsafe extern "C" fn(
366    sp_mat: cusparseSpMatDescr_t,
367    coo_row_ind: *mut c_void,
368    coo_col_ind: *mut c_void,
369    coo_values: *mut c_void,
370) -> cusparseStatus_t;
371
372// ---- SpMM (sparse × dense = dense) ---------------------------------------
373
374pub type PFN_cusparseSpMM_bufferSize = unsafe extern "C" fn(
375    handle: cusparseHandle_t,
376    op_a: cusparseOperation_t,
377    op_b: cusparseOperation_t,
378    alpha: *const c_void,
379    mat_a: cusparseSpMatDescr_t,
380    mat_b: cusparseDnMatDescr_t,
381    beta: *const c_void,
382    mat_c: cusparseDnMatDescr_t,
383    compute_type: cudaDataType,
384    alg: cusparseSpMMAlg_t,
385    buffer_size: *mut usize,
386) -> cusparseStatus_t;
387
388pub type PFN_cusparseSpMM_preprocess = unsafe extern "C" fn(
389    handle: cusparseHandle_t,
390    op_a: cusparseOperation_t,
391    op_b: cusparseOperation_t,
392    alpha: *const c_void,
393    mat_a: cusparseSpMatDescr_t,
394    mat_b: cusparseDnMatDescr_t,
395    beta: *const c_void,
396    mat_c: cusparseDnMatDescr_t,
397    compute_type: cudaDataType,
398    alg: cusparseSpMMAlg_t,
399    external_buffer: *mut c_void,
400) -> cusparseStatus_t;
401
402pub type PFN_cusparseSpMM = unsafe extern "C" fn(
403    handle: cusparseHandle_t,
404    op_a: cusparseOperation_t,
405    op_b: cusparseOperation_t,
406    alpha: *const c_void,
407    mat_a: cusparseSpMatDescr_t,
408    mat_b: cusparseDnMatDescr_t,
409    beta: *const c_void,
410    mat_c: cusparseDnMatDescr_t,
411    compute_type: cudaDataType,
412    alg: cusparseSpMMAlg_t,
413    external_buffer: *mut c_void,
414) -> cusparseStatus_t;
415
416// ---- SpGEMM (sparse × sparse = sparse) -----------------------------------
417
418pub type PFN_cusparseSpGEMM_createDescr =
419    unsafe extern "C" fn(descr: *mut cusparseSpGEMMDescr_t) -> cusparseStatus_t;
420pub type PFN_cusparseSpGEMM_destroyDescr =
421    unsafe extern "C" fn(descr: cusparseSpGEMMDescr_t) -> cusparseStatus_t;
422
423pub type PFN_cusparseSpGEMM_workEstimation = unsafe extern "C" fn(
424    handle: cusparseHandle_t,
425    op_a: cusparseOperation_t,
426    op_b: cusparseOperation_t,
427    alpha: *const c_void,
428    mat_a: cusparseSpMatDescr_t,
429    mat_b: cusparseSpMatDescr_t,
430    beta: *const c_void,
431    mat_c: cusparseSpMatDescr_t,
432    compute_type: cudaDataType,
433    alg: cusparseSpGEMMAlg_t,
434    descr: cusparseSpGEMMDescr_t,
435    buffer_size1: *mut usize,
436    external_buffer1: *mut c_void,
437) -> cusparseStatus_t;
438
439pub type PFN_cusparseSpGEMM_compute = unsafe extern "C" fn(
440    handle: cusparseHandle_t,
441    op_a: cusparseOperation_t,
442    op_b: cusparseOperation_t,
443    alpha: *const c_void,
444    mat_a: cusparseSpMatDescr_t,
445    mat_b: cusparseSpMatDescr_t,
446    beta: *const c_void,
447    mat_c: cusparseSpMatDescr_t,
448    compute_type: cudaDataType,
449    alg: cusparseSpGEMMAlg_t,
450    descr: cusparseSpGEMMDescr_t,
451    buffer_size2: *mut usize,
452    external_buffer2: *mut c_void,
453) -> cusparseStatus_t;
454
455pub type PFN_cusparseSpGEMM_copy = unsafe extern "C" fn(
456    handle: cusparseHandle_t,
457    op_a: cusparseOperation_t,
458    op_b: cusparseOperation_t,
459    alpha: *const c_void,
460    mat_a: cusparseSpMatDescr_t,
461    mat_b: cusparseSpMatDescr_t,
462    beta: *const c_void,
463    mat_c: cusparseSpMatDescr_t,
464    compute_type: cudaDataType,
465    alg: cusparseSpGEMMAlg_t,
466    descr: cusparseSpGEMMDescr_t,
467) -> cusparseStatus_t;
468
469// ---- SpSV (sparse triangular solve, vector) ------------------------------
470
471pub type PFN_cusparseSpSV_createDescr =
472    unsafe extern "C" fn(descr: *mut cusparseSpSVDescr_t) -> cusparseStatus_t;
473pub type PFN_cusparseSpSV_destroyDescr =
474    unsafe extern "C" fn(descr: cusparseSpSVDescr_t) -> cusparseStatus_t;
475
476pub type PFN_cusparseSpSV_bufferSize = unsafe extern "C" fn(
477    handle: cusparseHandle_t,
478    op_a: cusparseOperation_t,
479    alpha: *const c_void,
480    mat_a: cusparseSpMatDescr_t,
481    vec_x: cusparseDnVecDescr_t,
482    vec_y: cusparseDnVecDescr_t,
483    compute_type: cudaDataType,
484    alg: cusparseSpSVAlg_t,
485    descr: cusparseSpSVDescr_t,
486    buffer_size: *mut usize,
487) -> cusparseStatus_t;
488
489pub type PFN_cusparseSpSV_analysis = unsafe extern "C" fn(
490    handle: cusparseHandle_t,
491    op_a: cusparseOperation_t,
492    alpha: *const c_void,
493    mat_a: cusparseSpMatDescr_t,
494    vec_x: cusparseDnVecDescr_t,
495    vec_y: cusparseDnVecDescr_t,
496    compute_type: cudaDataType,
497    alg: cusparseSpSVAlg_t,
498    descr: cusparseSpSVDescr_t,
499    external_buffer: *mut c_void,
500) -> cusparseStatus_t;
501
502pub type PFN_cusparseSpSV_solve = unsafe extern "C" fn(
503    handle: cusparseHandle_t,
504    op_a: cusparseOperation_t,
505    alpha: *const c_void,
506    mat_a: cusparseSpMatDescr_t,
507    vec_x: cusparseDnVecDescr_t,
508    vec_y: cusparseDnVecDescr_t,
509    compute_type: cudaDataType,
510    alg: cusparseSpSVAlg_t,
511    descr: cusparseSpSVDescr_t,
512) -> cusparseStatus_t;
513
514// ---- SpSM (sparse triangular solve, matrix) ------------------------------
515
516pub type PFN_cusparseSpSM_createDescr =
517    unsafe extern "C" fn(descr: *mut cusparseSpSMDescr_t) -> cusparseStatus_t;
518pub type PFN_cusparseSpSM_destroyDescr =
519    unsafe extern "C" fn(descr: cusparseSpSMDescr_t) -> cusparseStatus_t;
520
521pub type PFN_cusparseSpSM_bufferSize = unsafe extern "C" fn(
522    handle: cusparseHandle_t,
523    op_a: cusparseOperation_t,
524    op_b: cusparseOperation_t,
525    alpha: *const c_void,
526    mat_a: cusparseSpMatDescr_t,
527    mat_b: cusparseDnMatDescr_t,
528    mat_c: cusparseDnMatDescr_t,
529    compute_type: cudaDataType,
530    alg: cusparseSpSMAlg_t,
531    descr: cusparseSpSMDescr_t,
532    buffer_size: *mut usize,
533) -> cusparseStatus_t;
534
535pub type PFN_cusparseSpSM_analysis = unsafe extern "C" fn(
536    handle: cusparseHandle_t,
537    op_a: cusparseOperation_t,
538    op_b: cusparseOperation_t,
539    alpha: *const c_void,
540    mat_a: cusparseSpMatDescr_t,
541    mat_b: cusparseDnMatDescr_t,
542    mat_c: cusparseDnMatDescr_t,
543    compute_type: cudaDataType,
544    alg: cusparseSpSMAlg_t,
545    descr: cusparseSpSMDescr_t,
546    external_buffer: *mut c_void,
547) -> cusparseStatus_t;
548
549pub type PFN_cusparseSpSM_solve = unsafe extern "C" fn(
550    handle: cusparseHandle_t,
551    op_a: cusparseOperation_t,
552    op_b: cusparseOperation_t,
553    alpha: *const c_void,
554    mat_a: cusparseSpMatDescr_t,
555    mat_b: cusparseDnMatDescr_t,
556    mat_c: cusparseDnMatDescr_t,
557    compute_type: cudaDataType,
558    alg: cusparseSpSMAlg_t,
559    descr: cusparseSpSMDescr_t,
560) -> cusparseStatus_t;
561
562// ---- SDDMM (sampled dense-dense matmul) ----------------------------------
563
564pub type PFN_cusparseSDDMM_bufferSize = unsafe extern "C" fn(
565    handle: cusparseHandle_t,
566    op_a: cusparseOperation_t,
567    op_b: cusparseOperation_t,
568    alpha: *const c_void,
569    mat_a: cusparseDnMatDescr_t,
570    mat_b: cusparseDnMatDescr_t,
571    beta: *const c_void,
572    mat_c: cusparseSpMatDescr_t,
573    compute_type: cudaDataType,
574    alg: cusparseSDDMMAlg_t,
575    buffer_size: *mut usize,
576) -> cusparseStatus_t;
577
578pub type PFN_cusparseSDDMM_preprocess = unsafe extern "C" fn(
579    handle: cusparseHandle_t,
580    op_a: cusparseOperation_t,
581    op_b: cusparseOperation_t,
582    alpha: *const c_void,
583    mat_a: cusparseDnMatDescr_t,
584    mat_b: cusparseDnMatDescr_t,
585    beta: *const c_void,
586    mat_c: cusparseSpMatDescr_t,
587    compute_type: cudaDataType,
588    alg: cusparseSDDMMAlg_t,
589    external_buffer: *mut c_void,
590) -> cusparseStatus_t;
591
592pub type PFN_cusparseSDDMM = unsafe extern "C" fn(
593    handle: cusparseHandle_t,
594    op_a: cusparseOperation_t,
595    op_b: cusparseOperation_t,
596    alpha: *const c_void,
597    mat_a: cusparseDnMatDescr_t,
598    mat_b: cusparseDnMatDescr_t,
599    beta: *const c_void,
600    mat_c: cusparseSpMatDescr_t,
601    compute_type: cudaDataType,
602    alg: cusparseSDDMMAlg_t,
603    external_buffer: *mut c_void,
604) -> cusparseStatus_t;
605
606// ---- CSR ↔ CSC conversion -------------------------------------------------
607
608pub type PFN_cusparseCsr2cscEx2_bufferSize = unsafe extern "C" fn(
609    handle: cusparseHandle_t,
610    m: c_int,
611    n: c_int,
612    nnz: c_int,
613    csr_val: *const c_void,
614    csr_row_ptr: *const c_int,
615    csr_col_ind: *const c_int,
616    csc_val: *mut c_void,
617    csc_col_ptr: *mut c_int,
618    csc_row_ind: *mut c_int,
619    value_type: cudaDataType,
620    copy_values: c_int,
621    idx_base: cusparseIndexBase_t,
622    alg: cusparseCsr2CscAlg_t,
623    buffer_size: *mut usize,
624) -> cusparseStatus_t;
625
626pub type PFN_cusparseCsr2cscEx2 = unsafe extern "C" fn(
627    handle: cusparseHandle_t,
628    m: c_int,
629    n: c_int,
630    nnz: c_int,
631    csr_val: *const c_void,
632    csr_row_ptr: *const c_int,
633    csr_col_ind: *const c_int,
634    csc_val: *mut c_void,
635    csc_col_ptr: *mut c_int,
636    csc_row_ind: *mut c_int,
637    value_type: cudaDataType,
638    copy_values: c_int,
639    idx_base: cusparseIndexBase_t,
640    alg: cusparseCsr2CscAlg_t,
641    buffer: *mut c_void,
642) -> cusparseStatus_t;
643
644// ---- Sparse↔Dense conversion ---------------------------------------------
645
646pub type PFN_cusparseSparseToDense_bufferSize = unsafe extern "C" fn(
647    handle: cusparseHandle_t,
648    mat_a: cusparseSpMatDescr_t,
649    mat_b: cusparseDnMatDescr_t,
650    alg: c_int,
651    buffer_size: *mut usize,
652) -> cusparseStatus_t;
653
654pub type PFN_cusparseSparseToDense = unsafe extern "C" fn(
655    handle: cusparseHandle_t,
656    mat_a: cusparseSpMatDescr_t,
657    mat_b: cusparseDnMatDescr_t,
658    alg: c_int,
659    external_buffer: *mut c_void,
660) -> cusparseStatus_t;
661
662pub type PFN_cusparseDenseToSparse_bufferSize = unsafe extern "C" fn(
663    handle: cusparseHandle_t,
664    mat_a: cusparseDnMatDescr_t,
665    mat_b: cusparseSpMatDescr_t,
666    alg: c_int,
667    buffer_size: *mut usize,
668) -> cusparseStatus_t;
669
670pub type PFN_cusparseDenseToSparse_analysis = unsafe extern "C" fn(
671    handle: cusparseHandle_t,
672    mat_a: cusparseDnMatDescr_t,
673    mat_b: cusparseSpMatDescr_t,
674    alg: c_int,
675    external_buffer: *mut c_void,
676) -> cusparseStatus_t;
677
678pub type PFN_cusparseDenseToSparse_convert = unsafe extern "C" fn(
679    handle: cusparseHandle_t,
680    mat_a: cusparseDnMatDescr_t,
681    mat_b: cusparseSpMatDescr_t,
682    alg: c_int,
683    external_buffer: *mut c_void,
684) -> cusparseStatus_t;
685
686// ---- Axpby / Gather / Scatter / Rot (sparse BLAS L1) --------------------
687
688pub type PFN_cusparseAxpby = unsafe extern "C" fn(
689    handle: cusparseHandle_t,
690    alpha: *const c_void,
691    vec_x: cusparseDnVecDescr_t,
692    beta: *const c_void,
693    vec_y: cusparseDnVecDescr_t,
694) -> cusparseStatus_t;
695
696pub type PFN_cusparseGather = unsafe extern "C" fn(
697    handle: cusparseHandle_t,
698    vec_y: cusparseDnVecDescr_t,
699    vec_x: cusparseDnVecDescr_t,
700) -> cusparseStatus_t;
701
702pub type PFN_cusparseScatter = unsafe extern "C" fn(
703    handle: cusparseHandle_t,
704    vec_x: cusparseDnVecDescr_t,
705    vec_y: cusparseDnVecDescr_t,
706) -> cusparseStatus_t;
707
708pub type PFN_cusparseRot = unsafe extern "C" fn(
709    handle: cusparseHandle_t,
710    c: *const c_void,
711    s: *const c_void,
712    vec_x: cusparseDnVecDescr_t,
713    vec_y: cusparseDnVecDescr_t,
714) -> cusparseStatus_t;
715
716// ---- loader --------------------------------------------------------------
717
718fn cusparse_candidates() -> Vec<String> {
719    platform::versioned_library_candidates("cusparse", &["13", "12", "11"])
720}
721
/// Generates the `Cusparse` loader struct from a table of
/// `accessor as "symbol": PfnType;` entries.
///
/// For each entry it emits a `OnceLock` cache field and a public accessor
/// that resolves the symbol on first use and returns the cached pointer
/// afterwards. It also emits a manual `Debug` impl that prints only the
/// `lib` field, and a private `empty` constructor with every cache unset.
macro_rules! cusparse_fns {
    ($($name:ident as $sym:literal : $pfn:ty);* $(;)?) => {
        pub struct Cusparse {
            lib: Library,
            $($name: OnceLock<$pfn>,)*
        }

        impl core::fmt::Debug for Cusparse {
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                let mut out = f.debug_struct("Cusparse");
                out.field("lib", &self.lib);
                out.finish_non_exhaustive()
            }
        }

        impl Cusparse {
            $(
                pub fn $name(&self) -> Result<$pfn, LoaderError> {
                    match self.$name.get() {
                        // Already resolved earlier: hand back the cached pointer.
                        Some(&cached) => Ok(cached),
                        None => {
                            // SAFETY: defers to `Library::raw_symbol`'s contract;
                            // a lookup failure is propagated with `?`.
                            let address: *mut () = unsafe { self.lib.raw_symbol($sym)? };
                            // SAFETY: reinterprets the symbol address as the
                            // function-pointer type declared for this entry.
                            // This relies on the table pairing each symbol with
                            // its real C signature.
                            let resolved: $pfn = unsafe {
                                core::mem::transmute_copy::<*mut (), $pfn>(&address)
                            };
                            // A concurrent caller may have filled the cache
                            // first; that is fine, so the result is ignored.
                            let _ = self.$name.set(resolved);
                            Ok(resolved)
                        }
                    }
                }
            )*

            fn empty(lib: Library) -> Self {
                Self {
                    lib,
                    $($name: OnceLock::new(),)*
                }
            }
        }
    };
}
749
750cusparse_fns! {
751    cusparse_create as "cusparseCreate": PFN_cusparseCreate;
752    cusparse_destroy as "cusparseDestroy": PFN_cusparseDestroy;
753    cusparse_set_stream as "cusparseSetStream": PFN_cusparseSetStream;
754    cusparse_get_version as "cusparseGetVersion": PFN_cusparseGetVersion;
755    // Sparse-matrix descriptors
756    cusparse_create_csr as "cusparseCreateCsr": PFN_cusparseCreateCsr;
757    cusparse_create_csc as "cusparseCreateCsc": PFN_cusparseCreateCsc;
758    cusparse_create_coo as "cusparseCreateCoo": PFN_cusparseCreateCoo;
759    cusparse_create_bsr as "cusparseCreateBsr": PFN_cusparseCreateBsr;
760    cusparse_destroy_sp_mat as "cusparseDestroySpMat": PFN_cusparseDestroySpMat;
761    cusparse_sp_mat_get_size as "cusparseSpMatGetSize": PFN_cusparseSpMatGetSize;
762    cusparse_sp_mat_set_attribute as "cusparseSpMatSetAttribute": PFN_cusparseSpMatSetAttribute;
763    cusparse_csr_set_pointers as "cusparseCsrSetPointers": PFN_cusparseCsrSetPointers;
764    cusparse_csc_set_pointers as "cusparseCscSetPointers": PFN_cusparseCscSetPointers;
765    cusparse_coo_set_pointers as "cusparseCooSetPointers": PFN_cusparseCooSetPointers;
766    // Dense descriptors
767    cusparse_create_dn_vec as "cusparseCreateDnVec": PFN_cusparseCreateDnVec;
768    cusparse_destroy_dn_vec as "cusparseDestroyDnVec": PFN_cusparseDestroyDnVec;
769    cusparse_create_dn_mat as "cusparseCreateDnMat": PFN_cusparseCreateDnMat;
770    cusparse_destroy_dn_mat as "cusparseDestroyDnMat": PFN_cusparseDestroyDnMat;
771    // SpMV
772    cusparse_spmv_buffer_size as "cusparseSpMV_bufferSize": PFN_cusparseSpMV_bufferSize;
773    cusparse_spmv as "cusparseSpMV": PFN_cusparseSpMV;
774    // SpMM
775    cusparse_spmm_buffer_size as "cusparseSpMM_bufferSize": PFN_cusparseSpMM_bufferSize;
776    cusparse_spmm_preprocess as "cusparseSpMM_preprocess": PFN_cusparseSpMM_preprocess;
777    cusparse_spmm as "cusparseSpMM": PFN_cusparseSpMM;
778    // SpGEMM
779    cusparse_spgemm_create_descr as "cusparseSpGEMM_createDescr": PFN_cusparseSpGEMM_createDescr;
780    cusparse_spgemm_destroy_descr as "cusparseSpGEMM_destroyDescr": PFN_cusparseSpGEMM_destroyDescr;
781    cusparse_spgemm_work_estimation as "cusparseSpGEMM_workEstimation": PFN_cusparseSpGEMM_workEstimation;
782    cusparse_spgemm_compute as "cusparseSpGEMM_compute": PFN_cusparseSpGEMM_compute;
783    cusparse_spgemm_copy as "cusparseSpGEMM_copy": PFN_cusparseSpGEMM_copy;
784    // SpSV
785    cusparse_spsv_create_descr as "cusparseSpSV_createDescr": PFN_cusparseSpSV_createDescr;
786    cusparse_spsv_destroy_descr as "cusparseSpSV_destroyDescr": PFN_cusparseSpSV_destroyDescr;
787    cusparse_spsv_buffer_size as "cusparseSpSV_bufferSize": PFN_cusparseSpSV_bufferSize;
788    cusparse_spsv_analysis as "cusparseSpSV_analysis": PFN_cusparseSpSV_analysis;
789    cusparse_spsv_solve as "cusparseSpSV_solve": PFN_cusparseSpSV_solve;
790    // SpSM
791    cusparse_spsm_create_descr as "cusparseSpSM_createDescr": PFN_cusparseSpSM_createDescr;
792    cusparse_spsm_destroy_descr as "cusparseSpSM_destroyDescr": PFN_cusparseSpSM_destroyDescr;
793    cusparse_spsm_buffer_size as "cusparseSpSM_bufferSize": PFN_cusparseSpSM_bufferSize;
794    cusparse_spsm_analysis as "cusparseSpSM_analysis": PFN_cusparseSpSM_analysis;
795    cusparse_spsm_solve as "cusparseSpSM_solve": PFN_cusparseSpSM_solve;
796    // SDDMM
797    cusparse_sddmm_buffer_size as "cusparseSDDMM_bufferSize": PFN_cusparseSDDMM_bufferSize;
798    cusparse_sddmm_preprocess as "cusparseSDDMM_preprocess": PFN_cusparseSDDMM_preprocess;
799    cusparse_sddmm as "cusparseSDDMM": PFN_cusparseSDDMM;
800    // CSR ↔ CSC
801    cusparse_csr2csc_ex2_buffer_size as "cusparseCsr2cscEx2_bufferSize": PFN_cusparseCsr2cscEx2_bufferSize;
802    cusparse_csr2csc_ex2 as "cusparseCsr2cscEx2": PFN_cusparseCsr2cscEx2;
803    // Sparse ↔ Dense
804    cusparse_sparse_to_dense_buffer_size as "cusparseSparseToDense_bufferSize": PFN_cusparseSparseToDense_bufferSize;
805    cusparse_sparse_to_dense as "cusparseSparseToDense": PFN_cusparseSparseToDense;
806    cusparse_dense_to_sparse_buffer_size as "cusparseDenseToSparse_bufferSize": PFN_cusparseDenseToSparse_bufferSize;
807    cusparse_dense_to_sparse_analysis as "cusparseDenseToSparse_analysis": PFN_cusparseDenseToSparse_analysis;
808    cusparse_dense_to_sparse_convert as "cusparseDenseToSparse_convert": PFN_cusparseDenseToSparse_convert;
809    // Sparse BLAS L1
810    cusparse_axpby as "cusparseAxpby": PFN_cusparseAxpby;
811    cusparse_gather as "cusparseGather": PFN_cusparseGather;
812    cusparse_scatter as "cusparseScatter": PFN_cusparseScatter;
813    cusparse_rot as "cusparseRot": PFN_cusparseRot;
814}
815
816pub fn cusparse() -> Result<&'static Cusparse, LoaderError> {
817    static CUSPARSE: OnceLock<Cusparse> = OnceLock::new();
818    if let Some(c) = CUSPARSE.get() {
819        return Ok(c);
820    }
821    let candidates: Vec<&'static str> = cusparse_candidates()
822        .into_iter()
823        .map(|s| Box::leak(s.into_boxed_str()) as &'static str)
824        .collect();
825    let candidates_leaked: &'static [&'static str] = Box::leak(candidates.into_boxed_slice());
826    let lib = Library::open("cusparse", candidates_leaked)?;
827    let c = Cusparse::empty(lib);
828    let _ = CUSPARSE.set(c);
829    Ok(CUSPARSE.get().expect("OnceLock set or lost race"))
830}