Skip to main content

baracuda_cusparse/
lib.rs

1//! Safe Rust wrappers for NVIDIA cuSPARSE.
2//!
3//! Covers the modern generic-API surface: `Handle`, `SpMat` (CSR/CSC/COO/BSR),
4//! `DnMat`, `DnVec`, and the family of op algorithms — SpMV, SpMM, SpGEMM,
5//! SpSV, SpSM, SDDMM — plus CSR↔CSC and sparse↔dense conversions, and the
6//! sparse BLAS-1 helpers (`axpby`, `gather`, `scatter`, `rot`).
7//!
8//! All matrix/vector descriptors borrow the underlying [`DeviceBuffer`]s and
9//! tie their lifetime to them so the buffers can't be freed while cuSPARSE
10//! is still holding references.
11
12#![warn(missing_debug_implementations)]
13
14use core::ffi::c_void;
15use std::marker::PhantomData;
16
17use baracuda_cusparse_sys::{
18    cudaDataType, cusparse, cusparseDiagType_t, cusparseDnMatDescr_t, cusparseDnVecDescr_t,
19    cusparseFillMode_t, cusparseHandle_t, cusparseIndexBase_t, cusparseIndexType_t,
20    cusparseOperation_t, cusparseOrder_t, cusparseSpGEMMDescr_t, cusparseSpMatAttribute_t,
21    cusparseSpMatDescr_t, cusparseSpSMDescr_t, cusparseSpSVDescr_t, cusparseStatus_t,
22};
23use baracuda_driver::{DeviceBuffer, Stream};
24use baracuda_types::{Complex32, Complex64};
25
26pub use baracuda_cusparse_sys::{
27    cusparseCsr2CscAlg_t as Csr2CscAlg, cusparseIndexBase_t as IndexBase,
28    cusparseSDDMMAlg_t as SDDMMAlg, cusparseSpGEMMAlg_t as SpGEMMAlg,
29    cusparseSpMMAlg_t as SpMMAlg, cusparseSpMVAlg_t as SpMVAlg, cusparseSpSMAlg_t as SpSMAlg,
30    cusparseSpSVAlg_t as SpSVAlg,
31};
32
33/// Error type for cuSPARSE operations.
34pub type Error = baracuda_core::Error<cusparseStatus_t>;
35/// Result alias.
36pub type Result<T, E = Error> = core::result::Result<T, E>;
37
38#[inline]
39fn check(status: cusparseStatus_t) -> Result<()> {
40    Error::check(status)
41}
42
43// ---- scalar <-> cudaDataType --------------------------------------------
44
45/// Types that cuSPARSE's generic API accepts as element / compute type.
46pub trait SparseScalar: sealed::Sealed + Copy + 'static {
47    /// cuSPARSE / cuBLAS element-type tag.
48    fn data_type() -> cudaDataType;
49}
50
51impl SparseScalar for f32 {
52    fn data_type() -> cudaDataType {
53        cudaDataType::R_32F
54    }
55}
56impl SparseScalar for f64 {
57    fn data_type() -> cudaDataType {
58        cudaDataType::R_64F
59    }
60}
61impl SparseScalar for Complex32 {
62    fn data_type() -> cudaDataType {
63        cudaDataType::C_32F
64    }
65}
66impl SparseScalar for Complex64 {
67    fn data_type() -> cudaDataType {
68        cudaDataType::C_64F
69    }
70}
71
72mod sealed {
73    use baracuda_types::{Complex32, Complex64};
74    pub trait Sealed {}
75    impl Sealed for f32 {}
76    impl Sealed for f64 {}
77    impl Sealed for Complex32 {}
78    impl Sealed for Complex64 {}
79}
80
81// ---- Op / Order / Fill / Diag wrappers ----------------------------------
82
/// Selects whether a matrix operand is used as-is, transposed, or
/// conjugate-transposed.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum Op {
    /// Maps to `cusparseOperation_t::N` (no transpose); the default.
    #[default]
    N,
    /// Maps to `cusparseOperation_t::T` (transpose).
    T,
    /// Maps to `cusparseOperation_t::C` (conjugate transpose).
    C,
}
91
92impl Op {
93    fn raw(self) -> cusparseOperation_t {
94        match self {
95            Op::N => cusparseOperation_t::N,
96            Op::T => cusparseOperation_t::T,
97            Op::C => cusparseOperation_t::C,
98        }
99    }
100}
101
/// Memory layout of a dense matrix.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Order {
    /// Maps to `cusparseOrder_t::Row`.
    Row,
    /// Maps to `cusparseOrder_t::Col`.
    Col,
}
108
109impl Order {
110    fn raw(self) -> cusparseOrder_t {
111        match self {
112            Order::Row => cusparseOrder_t::Row,
113            Order::Col => cusparseOrder_t::Col,
114        }
115    }
116}
117
/// Fill-triangle selector, applied to a sparse matrix via `SpMat::set_fill`
/// (used for triangular solves).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Fill {
    /// Maps to `cusparseFillMode_t::Lower`.
    Lower,
    /// Maps to `cusparseFillMode_t::Upper`.
    Upper,
}
123
124impl Fill {
125    fn raw(self) -> cusparseFillMode_t {
126        match self {
127            Fill::Lower => cusparseFillMode_t::Lower,
128            Fill::Upper => cusparseFillMode_t::Upper,
129        }
130    }
131}
132
/// Diagonal-type selector, applied to a sparse matrix via `SpMat::set_diag`
/// (unit vs non-unit, used for triangular solves).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Diag {
    /// Maps to `cusparseDiagType_t::NonUnit`.
    NonUnit,
    /// Maps to `cusparseDiagType_t::Unit`.
    Unit,
}
138
139impl Diag {
140    fn raw(self) -> cusparseDiagType_t {
141        match self {
142            Diag::NonUnit => cusparseDiagType_t::NonUnit,
143            Diag::Unit => cusparseDiagType_t::Unit,
144        }
145    }
146}
147
148// ---- Handle -------------------------------------------------------------
149
150/// Owned cuSPARSE handle.
151pub struct Handle {
152    handle: cusparseHandle_t,
153}
154
155unsafe impl Send for Handle {}
156
157impl core::fmt::Debug for Handle {
158    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
159        f.debug_struct("cusparse::Handle")
160            .field("handle", &self.handle)
161            .finish()
162    }
163}
164
165impl Handle {
166    pub fn new() -> Result<Self> {
167        let c = cusparse()?;
168        let cu = c.cusparse_create()?;
169        let mut h: cusparseHandle_t = core::ptr::null_mut();
170        check(unsafe { cu(&mut h) })?;
171        Ok(Self { handle: h })
172    }
173
174    pub fn set_stream(&self, stream: &Stream) -> Result<()> {
175        let c = cusparse()?;
176        let cu = c.cusparse_set_stream()?;
177        check(unsafe { cu(self.handle, stream.as_raw() as _) })
178    }
179
180    pub fn version(&self) -> Result<i32> {
181        let c = cusparse()?;
182        let cu = c.cusparse_get_version()?;
183        let mut v: core::ffi::c_int = 0;
184        check(unsafe { cu(self.handle, &mut v) })?;
185        Ok(v)
186    }
187
188    #[inline]
189    pub fn as_raw(&self) -> cusparseHandle_t {
190        self.handle
191    }
192}
193
194impl Drop for Handle {
195    fn drop(&mut self) {
196        if let Ok(c) = cusparse() {
197            if let Ok(cu) = c.cusparse_destroy() {
198                let _ = unsafe { cu(self.handle) };
199            }
200        }
201    }
202}
203
204// ---- Sparse matrix descriptor -------------------------------------------
205
206/// A sparse-matrix descriptor (CSR / CSC / COO / BSR). The descriptor keeps
207/// pointers to externally-owned device buffers; the lifetime parameter ties
208/// those buffers to the descriptor.
209pub struct SpMat<'buf, T> {
210    descr: cusparseSpMatDescr_t,
211    _markers: PhantomData<&'buf mut T>,
212}
213
214unsafe impl<T> Send for SpMat<'_, T> {}
215
216impl<T> core::fmt::Debug for SpMat<'_, T> {
217    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
218        f.debug_struct("SpMat")
219            .field("descr", &self.descr)
220            .finish_non_exhaustive()
221    }
222}
223
224impl<'buf, T: SparseScalar + baracuda_types::DeviceRepr> SpMat<'buf, T> {
225    /// Build a CSR (compressed sparse row) descriptor.
226    ///
227    /// `row_offsets.len()` must equal `rows + 1`; `col_indices.len()` and
228    /// `values.len()` must equal `nnz`.
229    pub fn csr(
230        rows: i64,
231        cols: i64,
232        nnz: i64,
233        row_offsets: &'buf mut DeviceBuffer<i32>,
234        col_indices: &'buf mut DeviceBuffer<i32>,
235        values: &'buf mut DeviceBuffer<T>,
236    ) -> Result<Self> {
237        let c = cusparse()?;
238        let cu = c.cusparse_create_csr()?;
239        let mut descr: cusparseSpMatDescr_t = core::ptr::null_mut();
240        check(unsafe {
241            cu(
242                &mut descr,
243                rows,
244                cols,
245                nnz,
246                row_offsets.as_raw().0 as *mut c_void,
247                col_indices.as_raw().0 as *mut c_void,
248                values.as_raw().0 as *mut c_void,
249                cusparseIndexType_t::I32I,
250                cusparseIndexType_t::I32I,
251                cusparseIndexBase_t::Zero,
252                T::data_type(),
253            )
254        })?;
255        Ok(Self {
256            descr,
257            _markers: PhantomData,
258        })
259    }
260
261    /// Build a CSC (compressed sparse column) descriptor.
262    pub fn csc(
263        rows: i64,
264        cols: i64,
265        nnz: i64,
266        col_offsets: &'buf mut DeviceBuffer<i32>,
267        row_indices: &'buf mut DeviceBuffer<i32>,
268        values: &'buf mut DeviceBuffer<T>,
269    ) -> Result<Self> {
270        let c = cusparse()?;
271        let cu = c.cusparse_create_csc()?;
272        let mut descr: cusparseSpMatDescr_t = core::ptr::null_mut();
273        check(unsafe {
274            cu(
275                &mut descr,
276                rows,
277                cols,
278                nnz,
279                col_offsets.as_raw().0 as *mut c_void,
280                row_indices.as_raw().0 as *mut c_void,
281                values.as_raw().0 as *mut c_void,
282                cusparseIndexType_t::I32I,
283                cusparseIndexType_t::I32I,
284                cusparseIndexBase_t::Zero,
285                T::data_type(),
286            )
287        })?;
288        Ok(Self {
289            descr,
290            _markers: PhantomData,
291        })
292    }
293
294    /// Build a BSR (block-sparse-row) descriptor.
295    #[allow(clippy::too_many_arguments)]
296    pub fn bsr(
297        brows: i64,
298        bcols: i64,
299        bnnz: i64,
300        row_block_dim: i64,
301        col_block_dim: i64,
302        order: Order,
303        row_offsets: &'buf mut DeviceBuffer<i32>,
304        col_indices: &'buf mut DeviceBuffer<i32>,
305        values: &'buf mut DeviceBuffer<T>,
306    ) -> Result<Self> {
307        let c = cusparse()?;
308        let cu = c.cusparse_create_bsr()?;
309        let mut descr: cusparseSpMatDescr_t = core::ptr::null_mut();
310        check(unsafe {
311            cu(
312                &mut descr,
313                brows,
314                bcols,
315                bnnz,
316                row_block_dim,
317                col_block_dim,
318                row_offsets.as_raw().0 as *mut c_void,
319                col_indices.as_raw().0 as *mut c_void,
320                values.as_raw().0 as *mut c_void,
321                cusparseIndexType_t::I32I,
322                cusparseIndexType_t::I32I,
323                cusparseIndexBase_t::Zero,
324                T::data_type(),
325                order.raw(),
326            )
327        })?;
328        Ok(Self {
329            descr,
330            _markers: PhantomData,
331        })
332    }
333
334    /// Build a COO (coordinate) descriptor.
335    pub fn coo(
336        rows: i64,
337        cols: i64,
338        nnz: i64,
339        row_indices: &'buf mut DeviceBuffer<i32>,
340        col_indices: &'buf mut DeviceBuffer<i32>,
341        values: &'buf mut DeviceBuffer<T>,
342    ) -> Result<Self> {
343        let c = cusparse()?;
344        let cu = c.cusparse_create_coo()?;
345        let mut descr: cusparseSpMatDescr_t = core::ptr::null_mut();
346        check(unsafe {
347            cu(
348                &mut descr,
349                rows,
350                cols,
351                nnz,
352                row_indices.as_raw().0 as *mut c_void,
353                col_indices.as_raw().0 as *mut c_void,
354                values.as_raw().0 as *mut c_void,
355                cusparseIndexType_t::I32I,
356                cusparseIndexBase_t::Zero,
357                T::data_type(),
358            )
359        })?;
360        Ok(Self {
361            descr,
362            _markers: PhantomData,
363        })
364    }
365}
366
367impl<T> SpMat<'_, T> {
368    /// Sparse matrix dimensions: `(rows, cols, nnz)`.
369    pub fn shape(&self) -> Result<(i64, i64, i64)> {
370        let c = cusparse()?;
371        let cu = c.cusparse_sp_mat_get_size()?;
372        let (mut r, mut col, mut nz) = (0i64, 0i64, 0i64);
373        check(unsafe { cu(self.descr, &mut r, &mut col, &mut nz) })?;
374        Ok((r, col, nz))
375    }
376
377    /// Rebind a CSR descriptor's underlying device pointers without
378    /// rebuilding it. Saves descriptor recreation when the same shape
379    /// is reused with new data.
380    ///
381    /// # Safety
382    ///
383    /// All three pointers must be live device allocations matching the
384    /// original `(rows + 1, nnz, nnz)` element counts and the original
385    /// element types. They must stay valid until the next operation
386    /// on this descriptor completes.
387    pub unsafe fn set_csr_pointers(
388        &self,
389        row_offsets: *mut c_void,
390        col_indices: *mut c_void,
391        values: *mut c_void,
392    ) -> Result<()> {
393        let c = cusparse()?;
394        let cu = c.cusparse_csr_set_pointers()?;
395        check(cu(self.descr, row_offsets, col_indices, values))
396    }
397
398    /// Rebind a CSC descriptor's underlying device pointers.
399    ///
400    /// # Safety
401    ///
402    /// See [`Self::set_csr_pointers`].
403    pub unsafe fn set_csc_pointers(
404        &self,
405        col_offsets: *mut c_void,
406        row_indices: *mut c_void,
407        values: *mut c_void,
408    ) -> Result<()> {
409        let c = cusparse()?;
410        let cu = c.cusparse_csc_set_pointers()?;
411        check(cu(self.descr, col_offsets, row_indices, values))
412    }
413
414    /// Rebind a COO descriptor's underlying device pointers.
415    ///
416    /// # Safety
417    ///
418    /// See [`Self::set_csr_pointers`].
419    pub unsafe fn set_coo_pointers(
420        &self,
421        row_indices: *mut c_void,
422        col_indices: *mut c_void,
423        values: *mut c_void,
424    ) -> Result<()> {
425        let c = cusparse()?;
426        let cu = c.cusparse_coo_set_pointers()?;
427        check(cu(self.descr, row_indices, col_indices, values))
428    }
429
430    /// Set the fill-triangle attribute (for triangular solves).
431    pub fn set_fill(&self, fill: Fill) -> Result<()> {
432        let c = cusparse()?;
433        let cu = c.cusparse_sp_mat_set_attribute()?;
434        let raw = fill.raw();
435        check(unsafe {
436            cu(
437                self.descr,
438                cusparseSpMatAttribute_t::FillMode,
439                &raw as *const _ as *const c_void,
440                core::mem::size_of::<cusparseFillMode_t>(),
441            )
442        })
443    }
444
445    /// Set the diagonal-type attribute (unit vs non-unit, for triangular solves).
446    pub fn set_diag(&self, diag: Diag) -> Result<()> {
447        let c = cusparse()?;
448        let cu = c.cusparse_sp_mat_set_attribute()?;
449        let raw = diag.raw();
450        check(unsafe {
451            cu(
452                self.descr,
453                cusparseSpMatAttribute_t::DiagType,
454                &raw as *const _ as *const c_void,
455                core::mem::size_of::<cusparseDiagType_t>(),
456            )
457        })
458    }
459
460    #[inline]
461    pub fn as_raw(&self) -> cusparseSpMatDescr_t {
462        self.descr
463    }
464}
465
466impl<T> Drop for SpMat<'_, T> {
467    fn drop(&mut self) {
468        if let Ok(c) = cusparse() {
469            if let Ok(cu) = c.cusparse_destroy_sp_mat() {
470                let _ = unsafe { cu(self.descr) };
471            }
472        }
473    }
474}
475
476// ---- Dense vector / matrix -----------------------------------------------
477
478pub struct DnVec<'buf, T> {
479    descr: cusparseDnVecDescr_t,
480    _marker: PhantomData<&'buf mut T>,
481}
482
483unsafe impl<T> Send for DnVec<'_, T> {}
484
485impl<T> core::fmt::Debug for DnVec<'_, T> {
486    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
487        f.debug_struct("DnVec")
488            .field("descr", &self.descr)
489            .finish_non_exhaustive()
490    }
491}
492
493impl<'buf, T: SparseScalar + baracuda_types::DeviceRepr> DnVec<'buf, T> {
494    pub fn new(values: &'buf mut DeviceBuffer<T>) -> Result<Self> {
495        let c = cusparse()?;
496        let cu = c.cusparse_create_dn_vec()?;
497        let mut descr: cusparseDnVecDescr_t = core::ptr::null_mut();
498        check(unsafe {
499            cu(
500                &mut descr,
501                values.len() as i64,
502                values.as_raw().0 as *mut c_void,
503                T::data_type(),
504            )
505        })?;
506        Ok(Self {
507            descr,
508            _marker: PhantomData,
509        })
510    }
511}
512
513impl<T> DnVec<'_, T> {
514    #[inline]
515    pub fn as_raw(&self) -> cusparseDnVecDescr_t {
516        self.descr
517    }
518}
519
520impl<T> Drop for DnVec<'_, T> {
521    fn drop(&mut self) {
522        if let Ok(c) = cusparse() {
523            if let Ok(cu) = c.cusparse_destroy_dn_vec() {
524                let _ = unsafe { cu(self.descr) };
525            }
526        }
527    }
528}
529
530pub struct DnMat<'buf, T> {
531    descr: cusparseDnMatDescr_t,
532    _marker: PhantomData<&'buf mut T>,
533}
534
535unsafe impl<T> Send for DnMat<'_, T> {}
536
537impl<T> core::fmt::Debug for DnMat<'_, T> {
538    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
539        f.debug_struct("DnMat")
540            .field("descr", &self.descr)
541            .finish_non_exhaustive()
542    }
543}
544
545impl<'buf, T: SparseScalar + baracuda_types::DeviceRepr> DnMat<'buf, T> {
546    pub fn new(
547        rows: i64,
548        cols: i64,
549        ld: i64,
550        order: Order,
551        values: &'buf mut DeviceBuffer<T>,
552    ) -> Result<Self> {
553        let c = cusparse()?;
554        let cu = c.cusparse_create_dn_mat()?;
555        let mut descr: cusparseDnMatDescr_t = core::ptr::null_mut();
556        check(unsafe {
557            cu(
558                &mut descr,
559                rows,
560                cols,
561                ld,
562                values.as_raw().0 as *mut c_void,
563                T::data_type(),
564                order.raw(),
565            )
566        })?;
567        Ok(Self {
568            descr,
569            _marker: PhantomData,
570        })
571    }
572}
573
574impl<T> DnMat<'_, T> {
575    #[inline]
576    pub fn as_raw(&self) -> cusparseDnMatDescr_t {
577        self.descr
578    }
579}
580
581impl<T> Drop for DnMat<'_, T> {
582    fn drop(&mut self) {
583        if let Ok(c) = cusparse() {
584            if let Ok(cu) = c.cusparse_destroy_dn_mat() {
585                let _ = unsafe { cu(self.descr) };
586            }
587        }
588    }
589}
590
591// ---- SpMV ---------------------------------------------------------------
592
593/// Query buffer-size for `y = alpha * op(A) * x + beta * y`.
594#[allow(clippy::too_many_arguments)]
595pub fn spmv_buffer_size<T: SparseScalar>(
596    handle: &Handle,
597    op: Op,
598    alpha: &T,
599    a: &SpMat<'_, T>,
600    x: &DnVec<'_, T>,
601    beta: &T,
602    y: &DnVec<'_, T>,
603    alg: SpMVAlg,
604) -> Result<usize> {
605    let c = cusparse()?;
606    let cu = c.cusparse_spmv_buffer_size()?;
607    let mut size: usize = 0;
608    check(unsafe {
609        cu(
610            handle.as_raw(),
611            op.raw(),
612            alpha as *const T as *const c_void,
613            a.descr,
614            x.descr,
615            beta as *const T as *const c_void,
616            y.descr,
617            T::data_type(),
618            alg,
619            &mut size,
620        )
621    })?;
622    Ok(size)
623}
624
625/// `y = alpha * op(A) * x + beta * y`.
626#[allow(clippy::too_many_arguments)]
627pub fn spmv<T: SparseScalar>(
628    handle: &Handle,
629    op: Op,
630    alpha: &T,
631    a: &SpMat<'_, T>,
632    x: &DnVec<'_, T>,
633    beta: &T,
634    y: &mut DnVec<'_, T>,
635    alg: SpMVAlg,
636    workspace: &mut DeviceBuffer<u8>,
637) -> Result<()> {
638    let c = cusparse()?;
639    let cu = c.cusparse_spmv()?;
640    check(unsafe {
641        cu(
642            handle.as_raw(),
643            op.raw(),
644            alpha as *const T as *const c_void,
645            a.descr,
646            x.descr,
647            beta as *const T as *const c_void,
648            y.descr,
649            T::data_type(),
650            alg,
651            workspace.as_raw().0 as *mut c_void,
652        )
653    })
654}
655
656// ---- SpMM ---------------------------------------------------------------
657
658/// Query buffer-size for `C = alpha * op(A) * op(B) + beta * C`, `A` sparse.
659#[allow(clippy::too_many_arguments)]
660pub fn spmm_buffer_size<T: SparseScalar>(
661    handle: &Handle,
662    op_a: Op,
663    op_b: Op,
664    alpha: &T,
665    a: &SpMat<'_, T>,
666    b: &DnMat<'_, T>,
667    beta: &T,
668    c: &DnMat<'_, T>,
669    alg: SpMMAlg,
670) -> Result<usize> {
671    let c_api = cusparse()?;
672    let cu = c_api.cusparse_spmm_buffer_size()?;
673    let mut size = 0usize;
674    check(unsafe {
675        cu(
676            handle.as_raw(),
677            op_a.raw(),
678            op_b.raw(),
679            alpha as *const T as *const c_void,
680            a.descr,
681            b.descr,
682            beta as *const T as *const c_void,
683            c.descr,
684            T::data_type(),
685            alg,
686            &mut size,
687        )
688    })?;
689    Ok(size)
690}
691
692#[allow(clippy::too_many_arguments)]
693pub fn spmm<T: SparseScalar>(
694    handle: &Handle,
695    op_a: Op,
696    op_b: Op,
697    alpha: &T,
698    a: &SpMat<'_, T>,
699    b: &DnMat<'_, T>,
700    beta: &T,
701    c: &mut DnMat<'_, T>,
702    alg: SpMMAlg,
703    workspace: &mut DeviceBuffer<u8>,
704) -> Result<()> {
705    let c_api = cusparse()?;
706    let cu = c_api.cusparse_spmm()?;
707    check(unsafe {
708        cu(
709            handle.as_raw(),
710            op_a.raw(),
711            op_b.raw(),
712            alpha as *const T as *const c_void,
713            a.descr,
714            b.descr,
715            beta as *const T as *const c_void,
716            c.descr,
717            T::data_type(),
718            alg,
719            workspace.as_raw().0 as *mut c_void,
720        )
721    })
722}
723
724/// One-time preprocessing before [`spmm`]. Pre-computes algorithm-specific
725/// state into `workspace` so subsequent [`spmm`] calls (with the same A
726/// sparsity pattern + dimensions) are faster. Use this when the same
727/// matrix is multiplied many times.
728#[allow(clippy::too_many_arguments)]
729pub fn spmm_preprocess<T: SparseScalar>(
730    handle: &Handle,
731    op_a: Op,
732    op_b: Op,
733    alpha: &T,
734    a: &SpMat<'_, T>,
735    b: &DnMat<'_, T>,
736    beta: &T,
737    c: &mut DnMat<'_, T>,
738    alg: SpMMAlg,
739    workspace: &mut DeviceBuffer<u8>,
740) -> Result<()> {
741    let c_api = cusparse()?;
742    let cu = c_api.cusparse_spmm_preprocess()?;
743    check(unsafe {
744        cu(
745            handle.as_raw(),
746            op_a.raw(),
747            op_b.raw(),
748            alpha as *const T as *const c_void,
749            a.descr,
750            b.descr,
751            beta as *const T as *const c_void,
752            c.descr,
753            T::data_type(),
754            alg,
755            workspace.as_raw().0 as *mut c_void,
756        )
757    })
758}
759
760// ---- SpGEMM -------------------------------------------------------------
761
762/// Per-plan handle for a 3-phase SpGEMM computation.
763#[derive(Debug)]
764pub struct SpGEMMPlan {
765    raw: cusparseSpGEMMDescr_t,
766}
767
768impl SpGEMMPlan {
769    pub fn new() -> Result<Self> {
770        let c = cusparse()?;
771        let cu = c.cusparse_spgemm_create_descr()?;
772        let mut d: cusparseSpGEMMDescr_t = core::ptr::null_mut();
773        check(unsafe { cu(&mut d) })?;
774        Ok(Self { raw: d })
775    }
776}
777
778impl Drop for SpGEMMPlan {
779    fn drop(&mut self) {
780        if let Ok(c) = cusparse() {
781            if let Ok(cu) = c.cusparse_spgemm_destroy_descr() {
782                let _ = unsafe { cu(self.raw) };
783            }
784        }
785    }
786}
787
788/// Phase 1: work-estimation. The caller provides `buffer1` whose size is
789/// returned in `size1`; pass `null` the first time, then allocate & re-call.
790#[allow(clippy::too_many_arguments)]
791pub unsafe fn spgemm_work_estimation<T: SparseScalar>(
792    handle: &Handle,
793    op_a: Op,
794    op_b: Op,
795    alpha: &T,
796    a: &SpMat<'_, T>,
797    b: &SpMat<'_, T>,
798    beta: &T,
799    c: &mut SpMat<'_, T>,
800    alg: SpGEMMAlg,
801    plan: &SpGEMMPlan,
802    size1: &mut usize,
803    buffer1: *mut c_void,
804) -> Result<()> {
805    let c_api = cusparse()?;
806    let cu = c_api.cusparse_spgemm_work_estimation()?;
807    check(cu(
808        handle.as_raw(),
809        op_a.raw(),
810        op_b.raw(),
811        alpha as *const T as *const c_void,
812        a.descr,
813        b.descr,
814        beta as *const T as *const c_void,
815        c.descr,
816        T::data_type(),
817        alg,
818        plan.raw,
819        size1,
820        buffer1,
821    ))
822}
823
824/// Phase 2: compute. Same two-step pattern for `buffer2`.
825#[allow(clippy::too_many_arguments)]
826pub unsafe fn spgemm_compute<T: SparseScalar>(
827    handle: &Handle,
828    op_a: Op,
829    op_b: Op,
830    alpha: &T,
831    a: &SpMat<'_, T>,
832    b: &SpMat<'_, T>,
833    beta: &T,
834    c: &mut SpMat<'_, T>,
835    alg: SpGEMMAlg,
836    plan: &SpGEMMPlan,
837    size2: &mut usize,
838    buffer2: *mut c_void,
839) -> Result<()> {
840    let c_api = cusparse()?;
841    let cu = c_api.cusparse_spgemm_compute()?;
842    check(cu(
843        handle.as_raw(),
844        op_a.raw(),
845        op_b.raw(),
846        alpha as *const T as *const c_void,
847        a.descr,
848        b.descr,
849        beta as *const T as *const c_void,
850        c.descr,
851        T::data_type(),
852        alg,
853        plan.raw,
854        size2,
855        buffer2,
856    ))
857}
858
859/// Phase 3: write output arrays into the pre-populated output `SpMat`.
860#[allow(clippy::too_many_arguments)]
861pub fn spgemm_copy<T: SparseScalar>(
862    handle: &Handle,
863    op_a: Op,
864    op_b: Op,
865    alpha: &T,
866    a: &SpMat<'_, T>,
867    b: &SpMat<'_, T>,
868    beta: &T,
869    c: &mut SpMat<'_, T>,
870    alg: SpGEMMAlg,
871    plan: &SpGEMMPlan,
872) -> Result<()> {
873    let c_api = cusparse()?;
874    let cu = c_api.cusparse_spgemm_copy()?;
875    check(unsafe {
876        cu(
877            handle.as_raw(),
878            op_a.raw(),
879            op_b.raw(),
880            alpha as *const T as *const c_void,
881            a.descr,
882            b.descr,
883            beta as *const T as *const c_void,
884            c.descr,
885            T::data_type(),
886            alg,
887            plan.raw,
888        )
889    })
890}
891
892// ---- SpSV / SpSM --------------------------------------------------------
893
894#[derive(Debug)]
895pub struct SpSVPlan {
896    raw: cusparseSpSVDescr_t,
897}
898
899impl SpSVPlan {
900    pub fn new() -> Result<Self> {
901        let c = cusparse()?;
902        let cu = c.cusparse_spsv_create_descr()?;
903        let mut d: cusparseSpSVDescr_t = core::ptr::null_mut();
904        check(unsafe { cu(&mut d) })?;
905        Ok(Self { raw: d })
906    }
907}
908
909impl Drop for SpSVPlan {
910    fn drop(&mut self) {
911        if let Ok(c) = cusparse() {
912            if let Ok(cu) = c.cusparse_spsv_destroy_descr() {
913                let _ = unsafe { cu(self.raw) };
914            }
915        }
916    }
917}
918
919#[allow(clippy::too_many_arguments)]
920pub fn spsv_buffer_size<T: SparseScalar>(
921    handle: &Handle,
922    op: Op,
923    alpha: &T,
924    a: &SpMat<'_, T>,
925    x: &DnVec<'_, T>,
926    y: &DnVec<'_, T>,
927    alg: SpSVAlg,
928    plan: &SpSVPlan,
929) -> Result<usize> {
930    let c = cusparse()?;
931    let cu = c.cusparse_spsv_buffer_size()?;
932    let mut size = 0usize;
933    check(unsafe {
934        cu(
935            handle.as_raw(),
936            op.raw(),
937            alpha as *const T as *const c_void,
938            a.descr,
939            x.descr,
940            y.descr,
941            T::data_type(),
942            alg,
943            plan.raw,
944            &mut size,
945        )
946    })?;
947    Ok(size)
948}
949
950#[allow(clippy::too_many_arguments)]
951pub fn spsv_analysis<T: SparseScalar>(
952    handle: &Handle,
953    op: Op,
954    alpha: &T,
955    a: &SpMat<'_, T>,
956    x: &DnVec<'_, T>,
957    y: &DnVec<'_, T>,
958    alg: SpSVAlg,
959    plan: &SpSVPlan,
960    workspace: &mut DeviceBuffer<u8>,
961) -> Result<()> {
962    let c = cusparse()?;
963    let cu = c.cusparse_spsv_analysis()?;
964    check(unsafe {
965        cu(
966            handle.as_raw(),
967            op.raw(),
968            alpha as *const T as *const c_void,
969            a.descr,
970            x.descr,
971            y.descr,
972            T::data_type(),
973            alg,
974            plan.raw,
975            workspace.as_raw().0 as *mut c_void,
976        )
977    })
978}
979
980#[allow(clippy::too_many_arguments)]
981pub fn spsv_solve<T: SparseScalar>(
982    handle: &Handle,
983    op: Op,
984    alpha: &T,
985    a: &SpMat<'_, T>,
986    x: &DnVec<'_, T>,
987    y: &mut DnVec<'_, T>,
988    alg: SpSVAlg,
989    plan: &SpSVPlan,
990) -> Result<()> {
991    let c = cusparse()?;
992    let cu = c.cusparse_spsv_solve()?;
993    check(unsafe {
994        cu(
995            handle.as_raw(),
996            op.raw(),
997            alpha as *const T as *const c_void,
998            a.descr,
999            x.descr,
1000            y.descr,
1001            T::data_type(),
1002            alg,
1003            plan.raw,
1004        )
1005    })
1006}
1007
1008#[derive(Debug)]
1009pub struct SpSMPlan {
1010    raw: cusparseSpSMDescr_t,
1011}
1012
1013impl SpSMPlan {
1014    pub fn new() -> Result<Self> {
1015        let c = cusparse()?;
1016        let cu = c.cusparse_spsm_create_descr()?;
1017        let mut d: cusparseSpSMDescr_t = core::ptr::null_mut();
1018        check(unsafe { cu(&mut d) })?;
1019        Ok(Self { raw: d })
1020    }
1021}
1022
1023impl Drop for SpSMPlan {
1024    fn drop(&mut self) {
1025        if let Ok(c) = cusparse() {
1026            if let Ok(cu) = c.cusparse_spsm_destroy_descr() {
1027                let _ = unsafe { cu(self.raw) };
1028            }
1029        }
1030    }
1031}
1032
1033#[allow(clippy::too_many_arguments)]
1034pub fn spsm_buffer_size<T: SparseScalar>(
1035    handle: &Handle,
1036    op_a: Op,
1037    op_b: Op,
1038    alpha: &T,
1039    a: &SpMat<'_, T>,
1040    b: &DnMat<'_, T>,
1041    c: &DnMat<'_, T>,
1042    alg: SpSMAlg,
1043    plan: &SpSMPlan,
1044) -> Result<usize> {
1045    let c_api = cusparse()?;
1046    let cu = c_api.cusparse_spsm_buffer_size()?;
1047    let mut size = 0usize;
1048    check(unsafe {
1049        cu(
1050            handle.as_raw(),
1051            op_a.raw(),
1052            op_b.raw(),
1053            alpha as *const T as *const c_void,
1054            a.descr,
1055            b.descr,
1056            c.descr,
1057            T::data_type(),
1058            alg,
1059            plan.raw,
1060            &mut size,
1061        )
1062    })?;
1063    Ok(size)
1064}
1065
1066#[allow(clippy::too_many_arguments)]
1067pub fn spsm_analysis<T: SparseScalar>(
1068    handle: &Handle,
1069    op_a: Op,
1070    op_b: Op,
1071    alpha: &T,
1072    a: &SpMat<'_, T>,
1073    b: &DnMat<'_, T>,
1074    c: &DnMat<'_, T>,
1075    alg: SpSMAlg,
1076    plan: &SpSMPlan,
1077    workspace: &mut DeviceBuffer<u8>,
1078) -> Result<()> {
1079    let c_api = cusparse()?;
1080    let cu = c_api.cusparse_spsm_analysis()?;
1081    check(unsafe {
1082        cu(
1083            handle.as_raw(),
1084            op_a.raw(),
1085            op_b.raw(),
1086            alpha as *const T as *const c_void,
1087            a.descr,
1088            b.descr,
1089            c.descr,
1090            T::data_type(),
1091            alg,
1092            plan.raw,
1093            workspace.as_raw().0 as *mut c_void,
1094        )
1095    })
1096}
1097
1098#[allow(clippy::too_many_arguments)]
1099pub fn spsm_solve<T: SparseScalar>(
1100    handle: &Handle,
1101    op_a: Op,
1102    op_b: Op,
1103    alpha: &T,
1104    a: &SpMat<'_, T>,
1105    b: &DnMat<'_, T>,
1106    c: &mut DnMat<'_, T>,
1107    alg: SpSMAlg,
1108    plan: &SpSMPlan,
1109) -> Result<()> {
1110    let c_api = cusparse()?;
1111    let cu = c_api.cusparse_spsm_solve()?;
1112    check(unsafe {
1113        cu(
1114            handle.as_raw(),
1115            op_a.raw(),
1116            op_b.raw(),
1117            alpha as *const T as *const c_void,
1118            a.descr,
1119            b.descr,
1120            c.descr,
1121            T::data_type(),
1122            alg,
1123            plan.raw,
1124        )
1125    })
1126}
1127
1128// ---- SDDMM -------------------------------------------------------------
1129
1130#[allow(clippy::too_many_arguments)]
1131pub fn sddmm_buffer_size<T: SparseScalar>(
1132    handle: &Handle,
1133    op_a: Op,
1134    op_b: Op,
1135    alpha: &T,
1136    a: &DnMat<'_, T>,
1137    b: &DnMat<'_, T>,
1138    beta: &T,
1139    c: &SpMat<'_, T>,
1140    alg: SDDMMAlg,
1141) -> Result<usize> {
1142    let c_api = cusparse()?;
1143    let cu = c_api.cusparse_sddmm_buffer_size()?;
1144    let mut size = 0usize;
1145    check(unsafe {
1146        cu(
1147            handle.as_raw(),
1148            op_a.raw(),
1149            op_b.raw(),
1150            alpha as *const T as *const c_void,
1151            a.descr,
1152            b.descr,
1153            beta as *const T as *const c_void,
1154            c.descr,
1155            T::data_type(),
1156            alg,
1157            &mut size,
1158        )
1159    })?;
1160    Ok(size)
1161}
1162
1163#[allow(clippy::too_many_arguments)]
1164pub fn sddmm<T: SparseScalar>(
1165    handle: &Handle,
1166    op_a: Op,
1167    op_b: Op,
1168    alpha: &T,
1169    a: &DnMat<'_, T>,
1170    b: &DnMat<'_, T>,
1171    beta: &T,
1172    c: &mut SpMat<'_, T>,
1173    alg: SDDMMAlg,
1174    workspace: &mut DeviceBuffer<u8>,
1175) -> Result<()> {
1176    let c_api = cusparse()?;
1177    let cu = c_api.cusparse_sddmm()?;
1178    check(unsafe {
1179        cu(
1180            handle.as_raw(),
1181            op_a.raw(),
1182            op_b.raw(),
1183            alpha as *const T as *const c_void,
1184            a.descr,
1185            b.descr,
1186            beta as *const T as *const c_void,
1187            c.descr,
1188            T::data_type(),
1189            alg,
1190            workspace.as_raw().0 as *mut c_void,
1191        )
1192    })
1193}
1194
1195/// One-time preprocessing before [`sddmm`]. See [`spmm_preprocess`] for
1196/// the rationale.
1197#[allow(clippy::too_many_arguments)]
1198pub fn sddmm_preprocess<T: SparseScalar>(
1199    handle: &Handle,
1200    op_a: Op,
1201    op_b: Op,
1202    alpha: &T,
1203    a: &DnMat<'_, T>,
1204    b: &DnMat<'_, T>,
1205    beta: &T,
1206    c: &mut SpMat<'_, T>,
1207    alg: SDDMMAlg,
1208    workspace: &mut DeviceBuffer<u8>,
1209) -> Result<()> {
1210    let c_api = cusparse()?;
1211    let cu = c_api.cusparse_sddmm_preprocess()?;
1212    check(unsafe {
1213        cu(
1214            handle.as_raw(),
1215            op_a.raw(),
1216            op_b.raw(),
1217            alpha as *const T as *const c_void,
1218            a.descr,
1219            b.descr,
1220            beta as *const T as *const c_void,
1221            c.descr,
1222            T::data_type(),
1223            alg,
1224            workspace.as_raw().0 as *mut c_void,
1225        )
1226    })
1227}
1228
1229// ---- Sparse / dense conversions ----------------------------------------
1230
1231pub fn sparse_to_dense_buffer_size<T: SparseScalar>(
1232    handle: &Handle,
1233    sp: &SpMat<'_, T>,
1234    dn: &DnMat<'_, T>,
1235) -> Result<usize> {
1236    let c = cusparse()?;
1237    let cu = c.cusparse_sparse_to_dense_buffer_size()?;
1238    let mut size = 0usize;
1239    check(unsafe { cu(handle.as_raw(), sp.descr, dn.descr, 0, &mut size) })?;
1240    Ok(size)
1241}
1242
1243pub fn sparse_to_dense<T: SparseScalar>(
1244    handle: &Handle,
1245    sp: &SpMat<'_, T>,
1246    dn: &mut DnMat<'_, T>,
1247    workspace: &mut DeviceBuffer<u8>,
1248) -> Result<()> {
1249    let c = cusparse()?;
1250    let cu = c.cusparse_sparse_to_dense()?;
1251    check(unsafe {
1252        cu(
1253            handle.as_raw(),
1254            sp.descr,
1255            dn.descr,
1256            0,
1257            workspace.as_raw().0 as *mut c_void,
1258        )
1259    })
1260}
1261
1262pub fn dense_to_sparse_buffer_size<T: SparseScalar>(
1263    handle: &Handle,
1264    dn: &DnMat<'_, T>,
1265    sp: &SpMat<'_, T>,
1266) -> Result<usize> {
1267    let c = cusparse()?;
1268    let cu = c.cusparse_dense_to_sparse_buffer_size()?;
1269    let mut size = 0usize;
1270    check(unsafe { cu(handle.as_raw(), dn.descr, sp.descr, 0, &mut size) })?;
1271    Ok(size)
1272}
1273
1274pub fn dense_to_sparse_analysis<T: SparseScalar>(
1275    handle: &Handle,
1276    dn: &DnMat<'_, T>,
1277    sp: &SpMat<'_, T>,
1278    workspace: &mut DeviceBuffer<u8>,
1279) -> Result<()> {
1280    let c = cusparse()?;
1281    let cu = c.cusparse_dense_to_sparse_analysis()?;
1282    check(unsafe {
1283        cu(
1284            handle.as_raw(),
1285            dn.descr,
1286            sp.descr,
1287            0,
1288            workspace.as_raw().0 as *mut c_void,
1289        )
1290    })
1291}
1292
1293pub fn dense_to_sparse_convert<T: SparseScalar>(
1294    handle: &Handle,
1295    dn: &DnMat<'_, T>,
1296    sp: &mut SpMat<'_, T>,
1297    workspace: &mut DeviceBuffer<u8>,
1298) -> Result<()> {
1299    let c = cusparse()?;
1300    let cu = c.cusparse_dense_to_sparse_convert()?;
1301    check(unsafe {
1302        cu(
1303            handle.as_raw(),
1304            dn.descr,
1305            sp.descr,
1306            0,
1307            workspace.as_raw().0 as *mut c_void,
1308        )
1309    })
1310}
1311
1312/// Workspace size in bytes for [`csr2csc_ex2`].
1313#[allow(clippy::too_many_arguments)]
1314pub fn csr2csc_ex2_buffer_size<T: SparseScalar + baracuda_types::DeviceRepr>(
1315    handle: &Handle,
1316    m: i32,
1317    n: i32,
1318    nnz: i32,
1319    csr_val: &DeviceBuffer<T>,
1320    csr_row_ptr: &DeviceBuffer<i32>,
1321    csr_col_ind: &DeviceBuffer<i32>,
1322    csc_val: &mut DeviceBuffer<T>,
1323    csc_col_ptr: &mut DeviceBuffer<i32>,
1324    csc_row_ind: &mut DeviceBuffer<i32>,
1325    copy_values: bool,
1326    idx_base: IndexBase,
1327    alg: Csr2CscAlg,
1328) -> Result<usize> {
1329    let c = cusparse()?;
1330    let cu = c.cusparse_csr2csc_ex2_buffer_size()?;
1331    let mut size = 0usize;
1332    check(unsafe {
1333        cu(
1334            handle.as_raw(),
1335            m,
1336            n,
1337            nnz,
1338            csr_val.as_raw().0 as *const c_void,
1339            csr_row_ptr.as_raw().0 as *const i32,
1340            csr_col_ind.as_raw().0 as *const i32,
1341            csc_val.as_raw().0 as *mut c_void,
1342            csc_col_ptr.as_raw().0 as *mut i32,
1343            csc_row_ind.as_raw().0 as *mut i32,
1344            T::data_type(),
1345            copy_values as i32,
1346            idx_base,
1347            alg,
1348            &mut size,
1349        )
1350    })?;
1351    Ok(size)
1352}
1353
1354/// Convert a CSR matrix to CSC format using the modern Ex2 entry point —
1355/// supports algorithm selection (`alg`) and arbitrary value types.
1356#[allow(clippy::too_many_arguments)]
1357pub fn csr2csc_ex2<T: SparseScalar + baracuda_types::DeviceRepr>(
1358    handle: &Handle,
1359    m: i32,
1360    n: i32,
1361    nnz: i32,
1362    csr_val: &DeviceBuffer<T>,
1363    csr_row_ptr: &DeviceBuffer<i32>,
1364    csr_col_ind: &DeviceBuffer<i32>,
1365    csc_val: &mut DeviceBuffer<T>,
1366    csc_col_ptr: &mut DeviceBuffer<i32>,
1367    csc_row_ind: &mut DeviceBuffer<i32>,
1368    copy_values: bool,
1369    idx_base: IndexBase,
1370    alg: Csr2CscAlg,
1371    workspace: &mut DeviceBuffer<u8>,
1372) -> Result<()> {
1373    let c = cusparse()?;
1374    let cu = c.cusparse_csr2csc_ex2()?;
1375    check(unsafe {
1376        cu(
1377            handle.as_raw(),
1378            m,
1379            n,
1380            nnz,
1381            csr_val.as_raw().0 as *const c_void,
1382            csr_row_ptr.as_raw().0 as *const i32,
1383            csr_col_ind.as_raw().0 as *const i32,
1384            csc_val.as_raw().0 as *mut c_void,
1385            csc_col_ptr.as_raw().0 as *mut i32,
1386            csc_row_ind.as_raw().0 as *mut i32,
1387            T::data_type(),
1388            copy_values as i32,
1389            idx_base,
1390            alg,
1391            workspace.as_raw().0 as *mut c_void,
1392        )
1393    })
1394}
1395
1396// ---- Sparse BLAS-1 helpers ---------------------------------------------
1397
1398pub fn axpby<T: SparseScalar>(
1399    handle: &Handle,
1400    alpha: &T,
1401    x: &DnVec<'_, T>,
1402    beta: &T,
1403    y: &mut DnVec<'_, T>,
1404) -> Result<()> {
1405    let c = cusparse()?;
1406    let cu = c.cusparse_axpby()?;
1407    check(unsafe {
1408        cu(
1409            handle.as_raw(),
1410            alpha as *const T as *const c_void,
1411            x.descr,
1412            beta as *const T as *const c_void,
1413            y.descr,
1414        )
1415    })
1416}
1417
1418pub fn gather<T: SparseScalar>(
1419    handle: &Handle,
1420    y: &DnVec<'_, T>,
1421    x: &mut DnVec<'_, T>,
1422) -> Result<()> {
1423    let c = cusparse()?;
1424    let cu = c.cusparse_gather()?;
1425    check(unsafe { cu(handle.as_raw(), y.descr, x.descr) })
1426}
1427
1428pub fn scatter<T: SparseScalar>(
1429    handle: &Handle,
1430    x: &DnVec<'_, T>,
1431    y: &mut DnVec<'_, T>,
1432) -> Result<()> {
1433    let c = cusparse()?;
1434    let cu = c.cusparse_scatter()?;
1435    check(unsafe { cu(handle.as_raw(), x.descr, y.descr) })
1436}
1437
1438pub fn rot<T: SparseScalar>(
1439    handle: &Handle,
1440    c_cos: &T,
1441    s_sin: &T,
1442    x: &mut DnVec<'_, T>,
1443    y: &mut DnVec<'_, T>,
1444) -> Result<()> {
1445    let c_api = cusparse()?;
1446    let cu = c_api.cusparse_rot()?;
1447    check(unsafe {
1448        cu(
1449            handle.as_raw(),
1450            c_cos as *const T as *const c_void,
1451            s_sin as *const T as *const c_void,
1452            x.descr,
1453            y.descr,
1454        )
1455    })
1456}
1457
1458// ---- Back-compat re-exports for existing users --------------------------
1459
/// Legacy alias kept for callers from v0.1 — prefer [`SpMat::csr`].
///
/// Pins the element type of [`SpMat`] to `f32`. Being a plain type alias,
/// it does not by itself restrict the storage format of the matrix.
pub type CsrMatrix<'buf> = SpMat<'buf, f32>;
/// Legacy alias — prefer [`DnVec`].
///
/// A pure rename: the buffer lifetime and the element type `T` are
/// forwarded to [`DnVec`] unchanged.
pub type DenseVector<'buf, T> = DnVec<'buf, T>;