Skip to main content

baracuda_cusparse/
lib.rs

1//! Safe Rust wrappers for NVIDIA cuSPARSE.
2//!
3//! Covers the modern generic-API surface: `Handle`, `SpMat` (CSR/CSC/COO/BSR),
4//! `DnMat`, `DnVec`, and the family of op algorithms — SpMV, SpMM, SpGEMM,
5//! SpSV, SpSM, SDDMM — plus CSR↔CSC and sparse↔dense conversions, and the
6//! sparse BLAS-1 helpers (`axpby`, `gather`, `scatter`, `rot`).
7//!
8//! All matrix/vector descriptors borrow the underlying [`DeviceBuffer`]s and
9//! tie their lifetime to them so the buffers can't be freed while cuSPARSE
10//! is still holding references.
11
12#![warn(missing_debug_implementations)]
13
14use core::ffi::c_void;
15use std::marker::PhantomData;
16
17use baracuda_cusparse_sys::{
18    cudaDataType, cusparse, cusparseDiagType_t, cusparseDnMatDescr_t, cusparseDnVecDescr_t,
19    cusparseFillMode_t, cusparseHandle_t, cusparseIndexBase_t, cusparseIndexType_t,
20    cusparseOperation_t, cusparseOrder_t, cusparseSpGEMMDescr_t, cusparseSpMatAttribute_t,
21    cusparseSpMatDescr_t, cusparseSpSMDescr_t, cusparseSpSVDescr_t, cusparseStatus_t,
22};
23use baracuda_driver::{DeviceBuffer, Stream};
24use baracuda_types::{Complex32, Complex64};
25
26pub use baracuda_cusparse_sys::{
27    cusparseCsr2CscAlg_t as Csr2CscAlg, cusparseIndexBase_t as IndexBase,
28    cusparseSDDMMAlg_t as SDDMMAlg, cusparseSpGEMMAlg_t as SpGEMMAlg,
29    cusparseSpMMAlg_t as SpMMAlg, cusparseSpMVAlg_t as SpMVAlg, cusparseSpSMAlg_t as SpSMAlg,
30    cusparseSpSVAlg_t as SpSVAlg,
31};
32
33/// Error type for cuSPARSE operations.
34pub type Error = baracuda_core::Error<cusparseStatus_t>;
35/// Result alias.
36pub type Result<T, E = Error> = core::result::Result<T, E>;
37
38#[inline]
39fn check(status: cusparseStatus_t) -> Result<()> {
40    Error::check(status)
41}
42
43// ---- scalar <-> cudaDataType --------------------------------------------
44
45/// Types that cuSPARSE's generic API accepts as element / compute type.
46pub trait SparseScalar: sealed::Sealed + Copy + 'static {
47    /// cuSPARSE / cuBLAS element-type tag.
48    fn data_type() -> cudaDataType;
49}
50
51impl SparseScalar for f32 {
52    fn data_type() -> cudaDataType {
53        cudaDataType::R_32F
54    }
55}
56impl SparseScalar for f64 {
57    fn data_type() -> cudaDataType {
58        cudaDataType::R_64F
59    }
60}
61impl SparseScalar for Complex32 {
62    fn data_type() -> cudaDataType {
63        cudaDataType::C_32F
64    }
65}
66impl SparseScalar for Complex64 {
67    fn data_type() -> cudaDataType {
68        cudaDataType::C_64F
69    }
70}
71
72mod sealed {
73    use baracuda_types::{Complex32, Complex64};
74    pub trait Sealed {}
75    impl Sealed for f32 {}
76    impl Sealed for f64 {}
77    impl Sealed for Complex32 {}
78    impl Sealed for Complex64 {}
79}
80
81// ---- Op / Order / Fill / Diag wrappers ----------------------------------
82
83/// Transpose / conjugate-transpose selector.
84#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
85pub enum Op {
86    #[default]
87    N,
88    T,
89    C,
90}
91
92impl Op {
93    fn raw(self) -> cusparseOperation_t {
94        match self {
95            Op::N => cusparseOperation_t::N,
96            Op::T => cusparseOperation_t::T,
97            Op::C => cusparseOperation_t::C,
98        }
99    }
100}
101
102/// Dense-matrix storage order.
103#[derive(Copy, Clone, Debug, Eq, PartialEq)]
104pub enum Order {
105    Row,
106    Col,
107}
108
109impl Order {
110    fn raw(self) -> cusparseOrder_t {
111        match self {
112            Order::Row => cusparseOrder_t::Row,
113            Order::Col => cusparseOrder_t::Col,
114        }
115    }
116}
117
118#[derive(Copy, Clone, Debug, Eq, PartialEq)]
119pub enum Fill {
120    Lower,
121    Upper,
122}
123
124impl Fill {
125    fn raw(self) -> cusparseFillMode_t {
126        match self {
127            Fill::Lower => cusparseFillMode_t::Lower,
128            Fill::Upper => cusparseFillMode_t::Upper,
129        }
130    }
131}
132
133#[derive(Copy, Clone, Debug, Eq, PartialEq)]
134pub enum Diag {
135    NonUnit,
136    Unit,
137}
138
139impl Diag {
140    fn raw(self) -> cusparseDiagType_t {
141        match self {
142            Diag::NonUnit => cusparseDiagType_t::NonUnit,
143            Diag::Unit => cusparseDiagType_t::Unit,
144        }
145    }
146}
147
148// ---- Handle -------------------------------------------------------------
149
150/// Owned cuSPARSE handle.
151pub struct Handle {
152    handle: cusparseHandle_t,
153}
154
155unsafe impl Send for Handle {}
156
157impl core::fmt::Debug for Handle {
158    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
159        f.debug_struct("cusparse::Handle")
160            .field("handle", &self.handle)
161            .finish()
162    }
163}
164
165impl Handle {
166    pub fn new() -> Result<Self> {
167        let c = cusparse()?;
168        let cu = c.cusparse_create()?;
169        let mut h: cusparseHandle_t = core::ptr::null_mut();
170        check(unsafe { cu(&mut h) })?;
171        Ok(Self { handle: h })
172    }
173
174    pub fn set_stream(&self, stream: &Stream) -> Result<()> {
175        let c = cusparse()?;
176        let cu = c.cusparse_set_stream()?;
177        check(unsafe { cu(self.handle, stream.as_raw() as _) })
178    }
179
180    pub fn version(&self) -> Result<i32> {
181        let c = cusparse()?;
182        let cu = c.cusparse_get_version()?;
183        let mut v: core::ffi::c_int = 0;
184        check(unsafe { cu(self.handle, &mut v) })?;
185        Ok(v)
186    }
187
188    #[inline]
189    pub fn as_raw(&self) -> cusparseHandle_t {
190        self.handle
191    }
192}
193
194impl Drop for Handle {
195    fn drop(&mut self) {
196        if let Ok(c) = cusparse() {
197            if let Ok(cu) = c.cusparse_destroy() {
198                let _ = unsafe { cu(self.handle) };
199            }
200        }
201    }
202}
203
204// ---- Sparse matrix descriptor -------------------------------------------
205
206/// A sparse-matrix descriptor (CSR / CSC / COO / BSR). The descriptor keeps
207/// pointers to externally-owned device buffers; the lifetime parameter ties
208/// those buffers to the descriptor.
209pub struct SpMat<'buf, T> {
210    descr: cusparseSpMatDescr_t,
211    _markers: PhantomData<&'buf mut T>,
212}
213
214unsafe impl<T> Send for SpMat<'_, T> {}
215
216impl<T> core::fmt::Debug for SpMat<'_, T> {
217    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
218        f.debug_struct("SpMat")
219            .field("descr", &self.descr)
220            .finish_non_exhaustive()
221    }
222}
223
224impl<'buf, T: SparseScalar + baracuda_types::DeviceRepr> SpMat<'buf, T> {
225    /// Build a CSR (compressed sparse row) descriptor.
226    ///
227    /// `row_offsets.len()` must equal `rows + 1`; `col_indices.len()` and
228    /// `values.len()` must equal `nnz`.
229    pub fn csr(
230        rows: i64,
231        cols: i64,
232        nnz: i64,
233        row_offsets: &'buf mut DeviceBuffer<i32>,
234        col_indices: &'buf mut DeviceBuffer<i32>,
235        values: &'buf mut DeviceBuffer<T>,
236    ) -> Result<Self> {
237        let c = cusparse()?;
238        let cu = c.cusparse_create_csr()?;
239        let mut descr: cusparseSpMatDescr_t = core::ptr::null_mut();
240        check(unsafe {
241            cu(
242                &mut descr,
243                rows,
244                cols,
245                nnz,
246                row_offsets.as_raw().0 as *mut c_void,
247                col_indices.as_raw().0 as *mut c_void,
248                values.as_raw().0 as *mut c_void,
249                cusparseIndexType_t::I32I,
250                cusparseIndexType_t::I32I,
251                cusparseIndexBase_t::Zero,
252                T::data_type(),
253            )
254        })?;
255        Ok(Self {
256            descr,
257            _markers: PhantomData,
258        })
259    }
260
261    /// Build a CSC (compressed sparse column) descriptor.
262    pub fn csc(
263        rows: i64,
264        cols: i64,
265        nnz: i64,
266        col_offsets: &'buf mut DeviceBuffer<i32>,
267        row_indices: &'buf mut DeviceBuffer<i32>,
268        values: &'buf mut DeviceBuffer<T>,
269    ) -> Result<Self> {
270        let c = cusparse()?;
271        let cu = c.cusparse_create_csc()?;
272        let mut descr: cusparseSpMatDescr_t = core::ptr::null_mut();
273        check(unsafe {
274            cu(
275                &mut descr,
276                rows,
277                cols,
278                nnz,
279                col_offsets.as_raw().0 as *mut c_void,
280                row_indices.as_raw().0 as *mut c_void,
281                values.as_raw().0 as *mut c_void,
282                cusparseIndexType_t::I32I,
283                cusparseIndexType_t::I32I,
284                cusparseIndexBase_t::Zero,
285                T::data_type(),
286            )
287        })?;
288        Ok(Self {
289            descr,
290            _markers: PhantomData,
291        })
292    }
293
294    /// Build a BSR (block-sparse-row) descriptor.
295    #[allow(clippy::too_many_arguments)]
296    pub fn bsr(
297        brows: i64,
298        bcols: i64,
299        bnnz: i64,
300        row_block_dim: i64,
301        col_block_dim: i64,
302        order: Order,
303        row_offsets: &'buf mut DeviceBuffer<i32>,
304        col_indices: &'buf mut DeviceBuffer<i32>,
305        values: &'buf mut DeviceBuffer<T>,
306    ) -> Result<Self> {
307        let c = cusparse()?;
308        let cu = c.cusparse_create_bsr()?;
309        let mut descr: cusparseSpMatDescr_t = core::ptr::null_mut();
310        check(unsafe {
311            cu(
312                &mut descr,
313                brows,
314                bcols,
315                bnnz,
316                row_block_dim,
317                col_block_dim,
318                row_offsets.as_raw().0 as *mut c_void,
319                col_indices.as_raw().0 as *mut c_void,
320                values.as_raw().0 as *mut c_void,
321                cusparseIndexType_t::I32I,
322                cusparseIndexType_t::I32I,
323                cusparseIndexBase_t::Zero,
324                T::data_type(),
325                order.raw(),
326            )
327        })?;
328        Ok(Self {
329            descr,
330            _markers: PhantomData,
331        })
332    }
333
334    /// Build a COO (coordinate) descriptor.
335    pub fn coo(
336        rows: i64,
337        cols: i64,
338        nnz: i64,
339        row_indices: &'buf mut DeviceBuffer<i32>,
340        col_indices: &'buf mut DeviceBuffer<i32>,
341        values: &'buf mut DeviceBuffer<T>,
342    ) -> Result<Self> {
343        let c = cusparse()?;
344        let cu = c.cusparse_create_coo()?;
345        let mut descr: cusparseSpMatDescr_t = core::ptr::null_mut();
346        check(unsafe {
347            cu(
348                &mut descr,
349                rows,
350                cols,
351                nnz,
352                row_indices.as_raw().0 as *mut c_void,
353                col_indices.as_raw().0 as *mut c_void,
354                values.as_raw().0 as *mut c_void,
355                cusparseIndexType_t::I32I,
356                cusparseIndexBase_t::Zero,
357                T::data_type(),
358            )
359        })?;
360        Ok(Self {
361            descr,
362            _markers: PhantomData,
363        })
364    }
365}
366
367impl<T> SpMat<'_, T> {
368    /// Sparse matrix dimensions: `(rows, cols, nnz)`.
369    pub fn shape(&self) -> Result<(i64, i64, i64)> {
370        let c = cusparse()?;
371        let cu = c.cusparse_sp_mat_get_size()?;
372        let (mut r, mut col, mut nz) = (0i64, 0i64, 0i64);
373        check(unsafe { cu(self.descr, &mut r, &mut col, &mut nz) })?;
374        Ok((r, col, nz))
375    }
376
377    /// Rebind a CSR descriptor's underlying device pointers without
378    /// rebuilding it. Saves descriptor recreation when the same shape
379    /// is reused with new data.
380    ///
381    /// # Safety
382    ///
383    /// All three pointers must be live device allocations matching the
384    /// original `(rows + 1, nnz, nnz)` element counts and the original
385    /// element types. They must stay valid until the next operation
386    /// on this descriptor completes.
387    pub unsafe fn set_csr_pointers(
388        &self,
389        row_offsets: *mut c_void,
390        col_indices: *mut c_void,
391        values: *mut c_void,
392    ) -> Result<()> { unsafe {
393        let c = cusparse()?;
394        let cu = c.cusparse_csr_set_pointers()?;
395        check(cu(self.descr, row_offsets, col_indices, values))
396    }}
397
398    /// Rebind a CSC descriptor's underlying device pointers.
399    ///
400    /// # Safety
401    ///
402    /// See [`Self::set_csr_pointers`].
403    pub unsafe fn set_csc_pointers(
404        &self,
405        col_offsets: *mut c_void,
406        row_indices: *mut c_void,
407        values: *mut c_void,
408    ) -> Result<()> { unsafe {
409        let c = cusparse()?;
410        let cu = c.cusparse_csc_set_pointers()?;
411        check(cu(self.descr, col_offsets, row_indices, values))
412    }}
413
414    /// Rebind a COO descriptor's underlying device pointers.
415    ///
416    /// # Safety
417    ///
418    /// See [`Self::set_csr_pointers`].
419    pub unsafe fn set_coo_pointers(
420        &self,
421        row_indices: *mut c_void,
422        col_indices: *mut c_void,
423        values: *mut c_void,
424    ) -> Result<()> { unsafe {
425        let c = cusparse()?;
426        let cu = c.cusparse_coo_set_pointers()?;
427        check(cu(self.descr, row_indices, col_indices, values))
428    }}
429
430    /// Set the fill-triangle attribute (for triangular solves).
431    pub fn set_fill(&self, fill: Fill) -> Result<()> {
432        let c = cusparse()?;
433        let cu = c.cusparse_sp_mat_set_attribute()?;
434        let raw = fill.raw();
435        check(unsafe {
436            cu(
437                self.descr,
438                cusparseSpMatAttribute_t::FillMode,
439                &raw as *const _ as *const c_void,
440                core::mem::size_of::<cusparseFillMode_t>(),
441            )
442        })
443    }
444
445    /// Set the diagonal-type attribute (unit vs non-unit, for triangular solves).
446    pub fn set_diag(&self, diag: Diag) -> Result<()> {
447        let c = cusparse()?;
448        let cu = c.cusparse_sp_mat_set_attribute()?;
449        let raw = diag.raw();
450        check(unsafe {
451            cu(
452                self.descr,
453                cusparseSpMatAttribute_t::DiagType,
454                &raw as *const _ as *const c_void,
455                core::mem::size_of::<cusparseDiagType_t>(),
456            )
457        })
458    }
459
460    #[inline]
461    pub fn as_raw(&self) -> cusparseSpMatDescr_t {
462        self.descr
463    }
464}
465
466impl<T> Drop for SpMat<'_, T> {
467    fn drop(&mut self) {
468        if let Ok(c) = cusparse() {
469            if let Ok(cu) = c.cusparse_destroy_sp_mat() {
470                let _ = unsafe { cu(self.descr) };
471            }
472        }
473    }
474}
475
476// ---- Dense vector / matrix -----------------------------------------------
477
478pub struct DnVec<'buf, T> {
479    descr: cusparseDnVecDescr_t,
480    _marker: PhantomData<&'buf mut T>,
481}
482
483unsafe impl<T> Send for DnVec<'_, T> {}
484
485impl<T> core::fmt::Debug for DnVec<'_, T> {
486    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
487        f.debug_struct("DnVec")
488            .field("descr", &self.descr)
489            .finish_non_exhaustive()
490    }
491}
492
493impl<'buf, T: SparseScalar + baracuda_types::DeviceRepr> DnVec<'buf, T> {
494    pub fn new(values: &'buf mut DeviceBuffer<T>) -> Result<Self> {
495        let c = cusparse()?;
496        let cu = c.cusparse_create_dn_vec()?;
497        let mut descr: cusparseDnVecDescr_t = core::ptr::null_mut();
498        check(unsafe {
499            cu(
500                &mut descr,
501                values.len() as i64,
502                values.as_raw().0 as *mut c_void,
503                T::data_type(),
504            )
505        })?;
506        Ok(Self {
507            descr,
508            _marker: PhantomData,
509        })
510    }
511}
512
513impl<T> DnVec<'_, T> {
514    #[inline]
515    pub fn as_raw(&self) -> cusparseDnVecDescr_t {
516        self.descr
517    }
518}
519
520impl<T> Drop for DnVec<'_, T> {
521    fn drop(&mut self) {
522        if let Ok(c) = cusparse() {
523            if let Ok(cu) = c.cusparse_destroy_dn_vec() {
524                let _ = unsafe { cu(self.descr) };
525            }
526        }
527    }
528}
529
530pub struct DnMat<'buf, T> {
531    descr: cusparseDnMatDescr_t,
532    _marker: PhantomData<&'buf mut T>,
533}
534
535unsafe impl<T> Send for DnMat<'_, T> {}
536
537impl<T> core::fmt::Debug for DnMat<'_, T> {
538    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
539        f.debug_struct("DnMat")
540            .field("descr", &self.descr)
541            .finish_non_exhaustive()
542    }
543}
544
545impl<'buf, T: SparseScalar + baracuda_types::DeviceRepr> DnMat<'buf, T> {
546    pub fn new(
547        rows: i64,
548        cols: i64,
549        ld: i64,
550        order: Order,
551        values: &'buf mut DeviceBuffer<T>,
552    ) -> Result<Self> {
553        let c = cusparse()?;
554        let cu = c.cusparse_create_dn_mat()?;
555        let mut descr: cusparseDnMatDescr_t = core::ptr::null_mut();
556        check(unsafe {
557            cu(
558                &mut descr,
559                rows,
560                cols,
561                ld,
562                values.as_raw().0 as *mut c_void,
563                T::data_type(),
564                order.raw(),
565            )
566        })?;
567        Ok(Self {
568            descr,
569            _marker: PhantomData,
570        })
571    }
572}
573
574impl<T> DnMat<'_, T> {
575    #[inline]
576    pub fn as_raw(&self) -> cusparseDnMatDescr_t {
577        self.descr
578    }
579}
580
581impl<T> Drop for DnMat<'_, T> {
582    fn drop(&mut self) {
583        if let Ok(c) = cusparse() {
584            if let Ok(cu) = c.cusparse_destroy_dn_mat() {
585                let _ = unsafe { cu(self.descr) };
586            }
587        }
588    }
589}
590
591// ---- SpMV ---------------------------------------------------------------
592
593/// Query buffer-size for `y = alpha * op(A) * x + beta * y`.
594#[allow(clippy::too_many_arguments)]
595pub fn spmv_buffer_size<T: SparseScalar>(
596    handle: &Handle,
597    op: Op,
598    alpha: &T,
599    a: &SpMat<'_, T>,
600    x: &DnVec<'_, T>,
601    beta: &T,
602    y: &DnVec<'_, T>,
603    alg: SpMVAlg,
604) -> Result<usize> {
605    let c = cusparse()?;
606    let cu = c.cusparse_spmv_buffer_size()?;
607    let mut size: usize = 0;
608    check(unsafe {
609        cu(
610            handle.as_raw(),
611            op.raw(),
612            alpha as *const T as *const c_void,
613            a.descr,
614            x.descr,
615            beta as *const T as *const c_void,
616            y.descr,
617            T::data_type(),
618            alg,
619            &mut size,
620        )
621    })?;
622    Ok(size)
623}
624
625/// `y = alpha * op(A) * x + beta * y`.
626#[allow(clippy::too_many_arguments)]
627pub fn spmv<T: SparseScalar>(
628    handle: &Handle,
629    op: Op,
630    alpha: &T,
631    a: &SpMat<'_, T>,
632    x: &DnVec<'_, T>,
633    beta: &T,
634    y: &mut DnVec<'_, T>,
635    alg: SpMVAlg,
636    workspace: &mut DeviceBuffer<u8>,
637) -> Result<()> {
638    let c = cusparse()?;
639    let cu = c.cusparse_spmv()?;
640    check(unsafe {
641        cu(
642            handle.as_raw(),
643            op.raw(),
644            alpha as *const T as *const c_void,
645            a.descr,
646            x.descr,
647            beta as *const T as *const c_void,
648            y.descr,
649            T::data_type(),
650            alg,
651            workspace.as_raw().0 as *mut c_void,
652        )
653    })
654}
655
656// ---- SpMM ---------------------------------------------------------------
657
658/// Query buffer-size for `C = alpha * op(A) * op(B) + beta * C`, `A` sparse.
659#[allow(clippy::too_many_arguments)]
660pub fn spmm_buffer_size<T: SparseScalar>(
661    handle: &Handle,
662    op_a: Op,
663    op_b: Op,
664    alpha: &T,
665    a: &SpMat<'_, T>,
666    b: &DnMat<'_, T>,
667    beta: &T,
668    c: &DnMat<'_, T>,
669    alg: SpMMAlg,
670) -> Result<usize> {
671    let c_api = cusparse()?;
672    let cu = c_api.cusparse_spmm_buffer_size()?;
673    let mut size = 0usize;
674    check(unsafe {
675        cu(
676            handle.as_raw(),
677            op_a.raw(),
678            op_b.raw(),
679            alpha as *const T as *const c_void,
680            a.descr,
681            b.descr,
682            beta as *const T as *const c_void,
683            c.descr,
684            T::data_type(),
685            alg,
686            &mut size,
687        )
688    })?;
689    Ok(size)
690}
691
692#[allow(clippy::too_many_arguments)]
693pub fn spmm<T: SparseScalar>(
694    handle: &Handle,
695    op_a: Op,
696    op_b: Op,
697    alpha: &T,
698    a: &SpMat<'_, T>,
699    b: &DnMat<'_, T>,
700    beta: &T,
701    c: &mut DnMat<'_, T>,
702    alg: SpMMAlg,
703    workspace: &mut DeviceBuffer<u8>,
704) -> Result<()> {
705    let c_api = cusparse()?;
706    let cu = c_api.cusparse_spmm()?;
707    check(unsafe {
708        cu(
709            handle.as_raw(),
710            op_a.raw(),
711            op_b.raw(),
712            alpha as *const T as *const c_void,
713            a.descr,
714            b.descr,
715            beta as *const T as *const c_void,
716            c.descr,
717            T::data_type(),
718            alg,
719            workspace.as_raw().0 as *mut c_void,
720        )
721    })
722}
723
724/// One-time preprocessing before [`spmm`]. Pre-computes algorithm-specific
725/// state into `workspace` so subsequent [`spmm`] calls (with the same A
726/// sparsity pattern + dimensions) are faster. Use this when the same
727/// matrix is multiplied many times.
728#[allow(clippy::too_many_arguments)]
729pub fn spmm_preprocess<T: SparseScalar>(
730    handle: &Handle,
731    op_a: Op,
732    op_b: Op,
733    alpha: &T,
734    a: &SpMat<'_, T>,
735    b: &DnMat<'_, T>,
736    beta: &T,
737    c: &mut DnMat<'_, T>,
738    alg: SpMMAlg,
739    workspace: &mut DeviceBuffer<u8>,
740) -> Result<()> {
741    let c_api = cusparse()?;
742    let cu = c_api.cusparse_spmm_preprocess()?;
743    check(unsafe {
744        cu(
745            handle.as_raw(),
746            op_a.raw(),
747            op_b.raw(),
748            alpha as *const T as *const c_void,
749            a.descr,
750            b.descr,
751            beta as *const T as *const c_void,
752            c.descr,
753            T::data_type(),
754            alg,
755            workspace.as_raw().0 as *mut c_void,
756        )
757    })
758}
759
760// ---- SpGEMM -------------------------------------------------------------
761
762/// Per-plan handle for a 3-phase SpGEMM computation.
763#[derive(Debug)]
764pub struct SpGEMMPlan {
765    raw: cusparseSpGEMMDescr_t,
766}
767
768impl SpGEMMPlan {
769    pub fn new() -> Result<Self> {
770        let c = cusparse()?;
771        let cu = c.cusparse_spgemm_create_descr()?;
772        let mut d: cusparseSpGEMMDescr_t = core::ptr::null_mut();
773        check(unsafe { cu(&mut d) })?;
774        Ok(Self { raw: d })
775    }
776}
777
778impl Drop for SpGEMMPlan {
779    fn drop(&mut self) {
780        if let Ok(c) = cusparse() {
781            if let Ok(cu) = c.cusparse_spgemm_destroy_descr() {
782                let _ = unsafe { cu(self.raw) };
783            }
784        }
785    }
786}
787
788/// Phase 1: work-estimation. The caller provides `buffer1` whose size is
789/// returned in `size1`; pass `null` the first time, then allocate & re-call.
790///
791/// # Safety
792///
793/// `buffer1` must be either null (size-query mode) or a valid device
794/// pointer to at least `*size1` bytes of writable scratch memory that
795/// remains live for the duration of the underlying cuSPARSE call.
796#[allow(clippy::too_many_arguments)]
797pub unsafe fn spgemm_work_estimation<T: SparseScalar>(
798    handle: &Handle,
799    op_a: Op,
800    op_b: Op,
801    alpha: &T,
802    a: &SpMat<'_, T>,
803    b: &SpMat<'_, T>,
804    beta: &T,
805    c: &mut SpMat<'_, T>,
806    alg: SpGEMMAlg,
807    plan: &SpGEMMPlan,
808    size1: &mut usize,
809    buffer1: *mut c_void,
810) -> Result<()> { unsafe {
811    let c_api = cusparse()?;
812    let cu = c_api.cusparse_spgemm_work_estimation()?;
813    check(cu(
814        handle.as_raw(),
815        op_a.raw(),
816        op_b.raw(),
817        alpha as *const T as *const c_void,
818        a.descr,
819        b.descr,
820        beta as *const T as *const c_void,
821        c.descr,
822        T::data_type(),
823        alg,
824        plan.raw,
825        size1,
826        buffer1,
827    ))
828}}
829
830/// Phase 2: compute. Same two-step pattern for `buffer2`.
831///
832/// # Safety
833///
834/// `buffer2` must be either null (size-query mode) or a valid device
835/// pointer to at least `*size2` bytes of writable scratch memory that
836/// remains live for the duration of the underlying cuSPARSE call.
837#[allow(clippy::too_many_arguments)]
838pub unsafe fn spgemm_compute<T: SparseScalar>(
839    handle: &Handle,
840    op_a: Op,
841    op_b: Op,
842    alpha: &T,
843    a: &SpMat<'_, T>,
844    b: &SpMat<'_, T>,
845    beta: &T,
846    c: &mut SpMat<'_, T>,
847    alg: SpGEMMAlg,
848    plan: &SpGEMMPlan,
849    size2: &mut usize,
850    buffer2: *mut c_void,
851) -> Result<()> { unsafe {
852    let c_api = cusparse()?;
853    let cu = c_api.cusparse_spgemm_compute()?;
854    check(cu(
855        handle.as_raw(),
856        op_a.raw(),
857        op_b.raw(),
858        alpha as *const T as *const c_void,
859        a.descr,
860        b.descr,
861        beta as *const T as *const c_void,
862        c.descr,
863        T::data_type(),
864        alg,
865        plan.raw,
866        size2,
867        buffer2,
868    ))
869}}
870
871/// Phase 3: write output arrays into the pre-populated output `SpMat`.
872#[allow(clippy::too_many_arguments)]
873pub fn spgemm_copy<T: SparseScalar>(
874    handle: &Handle,
875    op_a: Op,
876    op_b: Op,
877    alpha: &T,
878    a: &SpMat<'_, T>,
879    b: &SpMat<'_, T>,
880    beta: &T,
881    c: &mut SpMat<'_, T>,
882    alg: SpGEMMAlg,
883    plan: &SpGEMMPlan,
884) -> Result<()> {
885    let c_api = cusparse()?;
886    let cu = c_api.cusparse_spgemm_copy()?;
887    check(unsafe {
888        cu(
889            handle.as_raw(),
890            op_a.raw(),
891            op_b.raw(),
892            alpha as *const T as *const c_void,
893            a.descr,
894            b.descr,
895            beta as *const T as *const c_void,
896            c.descr,
897            T::data_type(),
898            alg,
899            plan.raw,
900        )
901    })
902}
903
904// ---- SpSV / SpSM --------------------------------------------------------
905
906#[derive(Debug)]
907pub struct SpSVPlan {
908    raw: cusparseSpSVDescr_t,
909}
910
911impl SpSVPlan {
912    pub fn new() -> Result<Self> {
913        let c = cusparse()?;
914        let cu = c.cusparse_spsv_create_descr()?;
915        let mut d: cusparseSpSVDescr_t = core::ptr::null_mut();
916        check(unsafe { cu(&mut d) })?;
917        Ok(Self { raw: d })
918    }
919}
920
921impl Drop for SpSVPlan {
922    fn drop(&mut self) {
923        if let Ok(c) = cusparse() {
924            if let Ok(cu) = c.cusparse_spsv_destroy_descr() {
925                let _ = unsafe { cu(self.raw) };
926            }
927        }
928    }
929}
930
931#[allow(clippy::too_many_arguments)]
932pub fn spsv_buffer_size<T: SparseScalar>(
933    handle: &Handle,
934    op: Op,
935    alpha: &T,
936    a: &SpMat<'_, T>,
937    x: &DnVec<'_, T>,
938    y: &DnVec<'_, T>,
939    alg: SpSVAlg,
940    plan: &SpSVPlan,
941) -> Result<usize> {
942    let c = cusparse()?;
943    let cu = c.cusparse_spsv_buffer_size()?;
944    let mut size = 0usize;
945    check(unsafe {
946        cu(
947            handle.as_raw(),
948            op.raw(),
949            alpha as *const T as *const c_void,
950            a.descr,
951            x.descr,
952            y.descr,
953            T::data_type(),
954            alg,
955            plan.raw,
956            &mut size,
957        )
958    })?;
959    Ok(size)
960}
961
962#[allow(clippy::too_many_arguments)]
963pub fn spsv_analysis<T: SparseScalar>(
964    handle: &Handle,
965    op: Op,
966    alpha: &T,
967    a: &SpMat<'_, T>,
968    x: &DnVec<'_, T>,
969    y: &DnVec<'_, T>,
970    alg: SpSVAlg,
971    plan: &SpSVPlan,
972    workspace: &mut DeviceBuffer<u8>,
973) -> Result<()> {
974    let c = cusparse()?;
975    let cu = c.cusparse_spsv_analysis()?;
976    check(unsafe {
977        cu(
978            handle.as_raw(),
979            op.raw(),
980            alpha as *const T as *const c_void,
981            a.descr,
982            x.descr,
983            y.descr,
984            T::data_type(),
985            alg,
986            plan.raw,
987            workspace.as_raw().0 as *mut c_void,
988        )
989    })
990}
991
992#[allow(clippy::too_many_arguments)]
993pub fn spsv_solve<T: SparseScalar>(
994    handle: &Handle,
995    op: Op,
996    alpha: &T,
997    a: &SpMat<'_, T>,
998    x: &DnVec<'_, T>,
999    y: &mut DnVec<'_, T>,
1000    alg: SpSVAlg,
1001    plan: &SpSVPlan,
1002) -> Result<()> {
1003    let c = cusparse()?;
1004    let cu = c.cusparse_spsv_solve()?;
1005    check(unsafe {
1006        cu(
1007            handle.as_raw(),
1008            op.raw(),
1009            alpha as *const T as *const c_void,
1010            a.descr,
1011            x.descr,
1012            y.descr,
1013            T::data_type(),
1014            alg,
1015            plan.raw,
1016        )
1017    })
1018}
1019
1020#[derive(Debug)]
1021pub struct SpSMPlan {
1022    raw: cusparseSpSMDescr_t,
1023}
1024
1025impl SpSMPlan {
1026    pub fn new() -> Result<Self> {
1027        let c = cusparse()?;
1028        let cu = c.cusparse_spsm_create_descr()?;
1029        let mut d: cusparseSpSMDescr_t = core::ptr::null_mut();
1030        check(unsafe { cu(&mut d) })?;
1031        Ok(Self { raw: d })
1032    }
1033}
1034
1035impl Drop for SpSMPlan {
1036    fn drop(&mut self) {
1037        if let Ok(c) = cusparse() {
1038            if let Ok(cu) = c.cusparse_spsm_destroy_descr() {
1039                let _ = unsafe { cu(self.raw) };
1040            }
1041        }
1042    }
1043}
1044
1045#[allow(clippy::too_many_arguments)]
1046pub fn spsm_buffer_size<T: SparseScalar>(
1047    handle: &Handle,
1048    op_a: Op,
1049    op_b: Op,
1050    alpha: &T,
1051    a: &SpMat<'_, T>,
1052    b: &DnMat<'_, T>,
1053    c: &DnMat<'_, T>,
1054    alg: SpSMAlg,
1055    plan: &SpSMPlan,
1056) -> Result<usize> {
1057    let c_api = cusparse()?;
1058    let cu = c_api.cusparse_spsm_buffer_size()?;
1059    let mut size = 0usize;
1060    check(unsafe {
1061        cu(
1062            handle.as_raw(),
1063            op_a.raw(),
1064            op_b.raw(),
1065            alpha as *const T as *const c_void,
1066            a.descr,
1067            b.descr,
1068            c.descr,
1069            T::data_type(),
1070            alg,
1071            plan.raw,
1072            &mut size,
1073        )
1074    })?;
1075    Ok(size)
1076}
1077
1078#[allow(clippy::too_many_arguments)]
1079pub fn spsm_analysis<T: SparseScalar>(
1080    handle: &Handle,
1081    op_a: Op,
1082    op_b: Op,
1083    alpha: &T,
1084    a: &SpMat<'_, T>,
1085    b: &DnMat<'_, T>,
1086    c: &DnMat<'_, T>,
1087    alg: SpSMAlg,
1088    plan: &SpSMPlan,
1089    workspace: &mut DeviceBuffer<u8>,
1090) -> Result<()> {
1091    let c_api = cusparse()?;
1092    let cu = c_api.cusparse_spsm_analysis()?;
1093    check(unsafe {
1094        cu(
1095            handle.as_raw(),
1096            op_a.raw(),
1097            op_b.raw(),
1098            alpha as *const T as *const c_void,
1099            a.descr,
1100            b.descr,
1101            c.descr,
1102            T::data_type(),
1103            alg,
1104            plan.raw,
1105            workspace.as_raw().0 as *mut c_void,
1106        )
1107    })
1108}
1109
1110#[allow(clippy::too_many_arguments)]
1111pub fn spsm_solve<T: SparseScalar>(
1112    handle: &Handle,
1113    op_a: Op,
1114    op_b: Op,
1115    alpha: &T,
1116    a: &SpMat<'_, T>,
1117    b: &DnMat<'_, T>,
1118    c: &mut DnMat<'_, T>,
1119    alg: SpSMAlg,
1120    plan: &SpSMPlan,
1121) -> Result<()> {
1122    let c_api = cusparse()?;
1123    let cu = c_api.cusparse_spsm_solve()?;
1124    check(unsafe {
1125        cu(
1126            handle.as_raw(),
1127            op_a.raw(),
1128            op_b.raw(),
1129            alpha as *const T as *const c_void,
1130            a.descr,
1131            b.descr,
1132            c.descr,
1133            T::data_type(),
1134            alg,
1135            plan.raw,
1136        )
1137    })
1138}
1139
1140// ---- SDDMM -------------------------------------------------------------
1141
1142#[allow(clippy::too_many_arguments)]
1143pub fn sddmm_buffer_size<T: SparseScalar>(
1144    handle: &Handle,
1145    op_a: Op,
1146    op_b: Op,
1147    alpha: &T,
1148    a: &DnMat<'_, T>,
1149    b: &DnMat<'_, T>,
1150    beta: &T,
1151    c: &SpMat<'_, T>,
1152    alg: SDDMMAlg,
1153) -> Result<usize> {
1154    let c_api = cusparse()?;
1155    let cu = c_api.cusparse_sddmm_buffer_size()?;
1156    let mut size = 0usize;
1157    check(unsafe {
1158        cu(
1159            handle.as_raw(),
1160            op_a.raw(),
1161            op_b.raw(),
1162            alpha as *const T as *const c_void,
1163            a.descr,
1164            b.descr,
1165            beta as *const T as *const c_void,
1166            c.descr,
1167            T::data_type(),
1168            alg,
1169            &mut size,
1170        )
1171    })?;
1172    Ok(size)
1173}
1174
1175#[allow(clippy::too_many_arguments)]
1176pub fn sddmm<T: SparseScalar>(
1177    handle: &Handle,
1178    op_a: Op,
1179    op_b: Op,
1180    alpha: &T,
1181    a: &DnMat<'_, T>,
1182    b: &DnMat<'_, T>,
1183    beta: &T,
1184    c: &mut SpMat<'_, T>,
1185    alg: SDDMMAlg,
1186    workspace: &mut DeviceBuffer<u8>,
1187) -> Result<()> {
1188    let c_api = cusparse()?;
1189    let cu = c_api.cusparse_sddmm()?;
1190    check(unsafe {
1191        cu(
1192            handle.as_raw(),
1193            op_a.raw(),
1194            op_b.raw(),
1195            alpha as *const T as *const c_void,
1196            a.descr,
1197            b.descr,
1198            beta as *const T as *const c_void,
1199            c.descr,
1200            T::data_type(),
1201            alg,
1202            workspace.as_raw().0 as *mut c_void,
1203        )
1204    })
1205}
1206
1207/// One-time preprocessing before [`sddmm`]. See [`spmm_preprocess`] for
1208/// the rationale.
1209#[allow(clippy::too_many_arguments)]
1210pub fn sddmm_preprocess<T: SparseScalar>(
1211    handle: &Handle,
1212    op_a: Op,
1213    op_b: Op,
1214    alpha: &T,
1215    a: &DnMat<'_, T>,
1216    b: &DnMat<'_, T>,
1217    beta: &T,
1218    c: &mut SpMat<'_, T>,
1219    alg: SDDMMAlg,
1220    workspace: &mut DeviceBuffer<u8>,
1221) -> Result<()> {
1222    let c_api = cusparse()?;
1223    let cu = c_api.cusparse_sddmm_preprocess()?;
1224    check(unsafe {
1225        cu(
1226            handle.as_raw(),
1227            op_a.raw(),
1228            op_b.raw(),
1229            alpha as *const T as *const c_void,
1230            a.descr,
1231            b.descr,
1232            beta as *const T as *const c_void,
1233            c.descr,
1234            T::data_type(),
1235            alg,
1236            workspace.as_raw().0 as *mut c_void,
1237        )
1238    })
1239}
1240
1241// ---- Sparse / dense conversions ----------------------------------------
1242
1243pub fn sparse_to_dense_buffer_size<T: SparseScalar>(
1244    handle: &Handle,
1245    sp: &SpMat<'_, T>,
1246    dn: &DnMat<'_, T>,
1247) -> Result<usize> {
1248    let c = cusparse()?;
1249    let cu = c.cusparse_sparse_to_dense_buffer_size()?;
1250    let mut size = 0usize;
1251    check(unsafe { cu(handle.as_raw(), sp.descr, dn.descr, 0, &mut size) })?;
1252    Ok(size)
1253}
1254
1255pub fn sparse_to_dense<T: SparseScalar>(
1256    handle: &Handle,
1257    sp: &SpMat<'_, T>,
1258    dn: &mut DnMat<'_, T>,
1259    workspace: &mut DeviceBuffer<u8>,
1260) -> Result<()> {
1261    let c = cusparse()?;
1262    let cu = c.cusparse_sparse_to_dense()?;
1263    check(unsafe {
1264        cu(
1265            handle.as_raw(),
1266            sp.descr,
1267            dn.descr,
1268            0,
1269            workspace.as_raw().0 as *mut c_void,
1270        )
1271    })
1272}
1273
1274pub fn dense_to_sparse_buffer_size<T: SparseScalar>(
1275    handle: &Handle,
1276    dn: &DnMat<'_, T>,
1277    sp: &SpMat<'_, T>,
1278) -> Result<usize> {
1279    let c = cusparse()?;
1280    let cu = c.cusparse_dense_to_sparse_buffer_size()?;
1281    let mut size = 0usize;
1282    check(unsafe { cu(handle.as_raw(), dn.descr, sp.descr, 0, &mut size) })?;
1283    Ok(size)
1284}
1285
1286pub fn dense_to_sparse_analysis<T: SparseScalar>(
1287    handle: &Handle,
1288    dn: &DnMat<'_, T>,
1289    sp: &SpMat<'_, T>,
1290    workspace: &mut DeviceBuffer<u8>,
1291) -> Result<()> {
1292    let c = cusparse()?;
1293    let cu = c.cusparse_dense_to_sparse_analysis()?;
1294    check(unsafe {
1295        cu(
1296            handle.as_raw(),
1297            dn.descr,
1298            sp.descr,
1299            0,
1300            workspace.as_raw().0 as *mut c_void,
1301        )
1302    })
1303}
1304
1305pub fn dense_to_sparse_convert<T: SparseScalar>(
1306    handle: &Handle,
1307    dn: &DnMat<'_, T>,
1308    sp: &mut SpMat<'_, T>,
1309    workspace: &mut DeviceBuffer<u8>,
1310) -> Result<()> {
1311    let c = cusparse()?;
1312    let cu = c.cusparse_dense_to_sparse_convert()?;
1313    check(unsafe {
1314        cu(
1315            handle.as_raw(),
1316            dn.descr,
1317            sp.descr,
1318            0,
1319            workspace.as_raw().0 as *mut c_void,
1320        )
1321    })
1322}
1323
1324/// Workspace size in bytes for [`csr2csc_ex2`].
1325#[allow(clippy::too_many_arguments)]
1326pub fn csr2csc_ex2_buffer_size<T: SparseScalar + baracuda_types::DeviceRepr>(
1327    handle: &Handle,
1328    m: i32,
1329    n: i32,
1330    nnz: i32,
1331    csr_val: &DeviceBuffer<T>,
1332    csr_row_ptr: &DeviceBuffer<i32>,
1333    csr_col_ind: &DeviceBuffer<i32>,
1334    csc_val: &mut DeviceBuffer<T>,
1335    csc_col_ptr: &mut DeviceBuffer<i32>,
1336    csc_row_ind: &mut DeviceBuffer<i32>,
1337    copy_values: bool,
1338    idx_base: IndexBase,
1339    alg: Csr2CscAlg,
1340) -> Result<usize> {
1341    let c = cusparse()?;
1342    let cu = c.cusparse_csr2csc_ex2_buffer_size()?;
1343    let mut size = 0usize;
1344    check(unsafe {
1345        cu(
1346            handle.as_raw(),
1347            m,
1348            n,
1349            nnz,
1350            csr_val.as_raw().0 as *const c_void,
1351            csr_row_ptr.as_raw().0 as *const i32,
1352            csr_col_ind.as_raw().0 as *const i32,
1353            csc_val.as_raw().0 as *mut c_void,
1354            csc_col_ptr.as_raw().0 as *mut i32,
1355            csc_row_ind.as_raw().0 as *mut i32,
1356            T::data_type(),
1357            copy_values as i32,
1358            idx_base,
1359            alg,
1360            &mut size,
1361        )
1362    })?;
1363    Ok(size)
1364}
1365
1366/// Convert a CSR matrix to CSC format using the modern Ex2 entry point —
1367/// supports algorithm selection (`alg`) and arbitrary value types.
1368#[allow(clippy::too_many_arguments)]
1369pub fn csr2csc_ex2<T: SparseScalar + baracuda_types::DeviceRepr>(
1370    handle: &Handle,
1371    m: i32,
1372    n: i32,
1373    nnz: i32,
1374    csr_val: &DeviceBuffer<T>,
1375    csr_row_ptr: &DeviceBuffer<i32>,
1376    csr_col_ind: &DeviceBuffer<i32>,
1377    csc_val: &mut DeviceBuffer<T>,
1378    csc_col_ptr: &mut DeviceBuffer<i32>,
1379    csc_row_ind: &mut DeviceBuffer<i32>,
1380    copy_values: bool,
1381    idx_base: IndexBase,
1382    alg: Csr2CscAlg,
1383    workspace: &mut DeviceBuffer<u8>,
1384) -> Result<()> {
1385    let c = cusparse()?;
1386    let cu = c.cusparse_csr2csc_ex2()?;
1387    check(unsafe {
1388        cu(
1389            handle.as_raw(),
1390            m,
1391            n,
1392            nnz,
1393            csr_val.as_raw().0 as *const c_void,
1394            csr_row_ptr.as_raw().0 as *const i32,
1395            csr_col_ind.as_raw().0 as *const i32,
1396            csc_val.as_raw().0 as *mut c_void,
1397            csc_col_ptr.as_raw().0 as *mut i32,
1398            csc_row_ind.as_raw().0 as *mut i32,
1399            T::data_type(),
1400            copy_values as i32,
1401            idx_base,
1402            alg,
1403            workspace.as_raw().0 as *mut c_void,
1404        )
1405    })
1406}
1407
1408// ---- Sparse BLAS-1 helpers ---------------------------------------------
1409
1410pub fn axpby<T: SparseScalar>(
1411    handle: &Handle,
1412    alpha: &T,
1413    x: &DnVec<'_, T>,
1414    beta: &T,
1415    y: &mut DnVec<'_, T>,
1416) -> Result<()> {
1417    let c = cusparse()?;
1418    let cu = c.cusparse_axpby()?;
1419    check(unsafe {
1420        cu(
1421            handle.as_raw(),
1422            alpha as *const T as *const c_void,
1423            x.descr,
1424            beta as *const T as *const c_void,
1425            y.descr,
1426        )
1427    })
1428}
1429
1430pub fn gather<T: SparseScalar>(
1431    handle: &Handle,
1432    y: &DnVec<'_, T>,
1433    x: &mut DnVec<'_, T>,
1434) -> Result<()> {
1435    let c = cusparse()?;
1436    let cu = c.cusparse_gather()?;
1437    check(unsafe { cu(handle.as_raw(), y.descr, x.descr) })
1438}
1439
1440pub fn scatter<T: SparseScalar>(
1441    handle: &Handle,
1442    x: &DnVec<'_, T>,
1443    y: &mut DnVec<'_, T>,
1444) -> Result<()> {
1445    let c = cusparse()?;
1446    let cu = c.cusparse_scatter()?;
1447    check(unsafe { cu(handle.as_raw(), x.descr, y.descr) })
1448}
1449
1450pub fn rot<T: SparseScalar>(
1451    handle: &Handle,
1452    c_cos: &T,
1453    s_sin: &T,
1454    x: &mut DnVec<'_, T>,
1455    y: &mut DnVec<'_, T>,
1456) -> Result<()> {
1457    let c_api = cusparse()?;
1458    let cu = c_api.cusparse_rot()?;
1459    check(unsafe {
1460        cu(
1461            handle.as_raw(),
1462            c_cos as *const T as *const c_void,
1463            s_sin as *const T as *const c_void,
1464            x.descr,
1465            y.descr,
1466        )
1467    })
1468}
1469
// ---- Back-compat type aliases for existing users ------------------------
1471
1472/// Legacy alias kept for callers from v0.1 — prefer [`SpMat::csr`].
1473pub type CsrMatrix<'buf> = SpMat<'buf, f32>;
1474/// Legacy alias — prefer [`DnVec`].
1475pub type DenseVector<'buf, T> = DnVec<'buf, T>;