Skip to main content

baracuda_cusparse/
lib.rs

1//! Safe Rust wrappers for NVIDIA cuSPARSE.
2//!
3//! Covers the modern generic-API surface: `Handle`, `SpMat` (CSR/CSC/COO/BSR),
4//! `DnMat`, `DnVec`, and the family of op algorithms — SpMV, SpMM, SpGEMM,
5//! SpSV, SpSM, SDDMM — plus CSR↔CSC and sparse↔dense conversions, and the
6//! sparse BLAS-1 helpers (`axpby`, `gather`, `scatter`, `rot`).
7//!
8//! All matrix/vector descriptors borrow the underlying [`DeviceBuffer`]s and
9//! tie their lifetime to them so the buffers can't be freed while cuSPARSE
10//! is still holding references.
11
12#![warn(missing_debug_implementations)]
13
14use core::ffi::c_void;
15use std::marker::PhantomData;
16
17use baracuda_cusparse_sys::{
18    cudaDataType, cusparse, cusparseDiagType_t, cusparseDnMatDescr_t, cusparseDnVecDescr_t,
19    cusparseFillMode_t, cusparseHandle_t, cusparseIndexBase_t, cusparseIndexType_t,
20    cusparseOperation_t, cusparseOrder_t, cusparseSpGEMMDescr_t, cusparseSpMatAttribute_t,
21    cusparseSpMatDescr_t, cusparseSpSMDescr_t, cusparseSpSVDescr_t, cusparseStatus_t,
22};
23use baracuda_driver::{DeviceBuffer, Stream};
24use baracuda_types::{Complex32, Complex64};
25
26pub use baracuda_cusparse_sys::{
27    cusparseCsr2CscAlg_t as Csr2CscAlg, cusparseSDDMMAlg_t as SDDMMAlg,
28    cusparseSpGEMMAlg_t as SpGEMMAlg, cusparseSpMMAlg_t as SpMMAlg, cusparseSpMVAlg_t as SpMVAlg,
29    cusparseSpSMAlg_t as SpSMAlg, cusparseSpSVAlg_t as SpSVAlg,
30};
31
32/// Error type for cuSPARSE operations.
33pub type Error = baracuda_core::Error<cusparseStatus_t>;
34/// Result alias.
35pub type Result<T, E = Error> = core::result::Result<T, E>;
36
37#[inline]
38fn check(status: cusparseStatus_t) -> Result<()> {
39    Error::check(status)
40}
41
42// ---- scalar <-> cudaDataType --------------------------------------------
43
44/// Types that cuSPARSE's generic API accepts as element / compute type.
45pub trait SparseScalar: sealed::Sealed + Copy + 'static {
46    /// cuSPARSE / cuBLAS element-type tag.
47    fn data_type() -> cudaDataType;
48}
49
50impl SparseScalar for f32 {
51    fn data_type() -> cudaDataType {
52        cudaDataType::R_32F
53    }
54}
55impl SparseScalar for f64 {
56    fn data_type() -> cudaDataType {
57        cudaDataType::R_64F
58    }
59}
60impl SparseScalar for Complex32 {
61    fn data_type() -> cudaDataType {
62        cudaDataType::C_32F
63    }
64}
65impl SparseScalar for Complex64 {
66    fn data_type() -> cudaDataType {
67        cudaDataType::C_64F
68    }
69}
70
71mod sealed {
72    use baracuda_types::{Complex32, Complex64};
73    pub trait Sealed {}
74    impl Sealed for f32 {}
75    impl Sealed for f64 {}
76    impl Sealed for Complex32 {}
77    impl Sealed for Complex64 {}
78}
79
80// ---- Op / Order / Fill / Diag wrappers ----------------------------------
81
82/// Transpose / conjugate-transpose selector.
83#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
84pub enum Op {
85    #[default]
86    N,
87    T,
88    C,
89}
90
91impl Op {
92    fn raw(self) -> cusparseOperation_t {
93        match self {
94            Op::N => cusparseOperation_t::N,
95            Op::T => cusparseOperation_t::T,
96            Op::C => cusparseOperation_t::C,
97        }
98    }
99}
100
101/// Dense-matrix storage order.
102#[derive(Copy, Clone, Debug, Eq, PartialEq)]
103pub enum Order {
104    Row,
105    Col,
106}
107
108impl Order {
109    fn raw(self) -> cusparseOrder_t {
110        match self {
111            Order::Row => cusparseOrder_t::Row,
112            Order::Col => cusparseOrder_t::Col,
113        }
114    }
115}
116
117#[derive(Copy, Clone, Debug, Eq, PartialEq)]
118pub enum Fill {
119    Lower,
120    Upper,
121}
122
123impl Fill {
124    fn raw(self) -> cusparseFillMode_t {
125        match self {
126            Fill::Lower => cusparseFillMode_t::Lower,
127            Fill::Upper => cusparseFillMode_t::Upper,
128        }
129    }
130}
131
132#[derive(Copy, Clone, Debug, Eq, PartialEq)]
133pub enum Diag {
134    NonUnit,
135    Unit,
136}
137
138impl Diag {
139    fn raw(self) -> cusparseDiagType_t {
140        match self {
141            Diag::NonUnit => cusparseDiagType_t::NonUnit,
142            Diag::Unit => cusparseDiagType_t::Unit,
143        }
144    }
145}
146
147// ---- Handle -------------------------------------------------------------
148
149/// Owned cuSPARSE handle.
150pub struct Handle {
151    handle: cusparseHandle_t,
152}
153
154unsafe impl Send for Handle {}
155
156impl core::fmt::Debug for Handle {
157    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
158        f.debug_struct("cusparse::Handle")
159            .field("handle", &self.handle)
160            .finish()
161    }
162}
163
164impl Handle {
165    pub fn new() -> Result<Self> {
166        let c = cusparse()?;
167        let cu = c.cusparse_create()?;
168        let mut h: cusparseHandle_t = core::ptr::null_mut();
169        check(unsafe { cu(&mut h) })?;
170        Ok(Self { handle: h })
171    }
172
173    pub fn set_stream(&self, stream: &Stream) -> Result<()> {
174        let c = cusparse()?;
175        let cu = c.cusparse_set_stream()?;
176        check(unsafe { cu(self.handle, stream.as_raw() as _) })
177    }
178
179    pub fn version(&self) -> Result<i32> {
180        let c = cusparse()?;
181        let cu = c.cusparse_get_version()?;
182        let mut v: core::ffi::c_int = 0;
183        check(unsafe { cu(self.handle, &mut v) })?;
184        Ok(v)
185    }
186
187    #[inline]
188    pub fn as_raw(&self) -> cusparseHandle_t {
189        self.handle
190    }
191}
192
193impl Drop for Handle {
194    fn drop(&mut self) {
195        if let Ok(c) = cusparse() {
196            if let Ok(cu) = c.cusparse_destroy() {
197                let _ = unsafe { cu(self.handle) };
198            }
199        }
200    }
201}
202
203// ---- Sparse matrix descriptor -------------------------------------------
204
205/// A sparse-matrix descriptor (CSR / CSC / COO / BSR). The descriptor keeps
206/// pointers to externally-owned device buffers; the lifetime parameter ties
207/// those buffers to the descriptor.
208pub struct SpMat<'buf, T> {
209    descr: cusparseSpMatDescr_t,
210    _markers: PhantomData<&'buf mut T>,
211}
212
213unsafe impl<T> Send for SpMat<'_, T> {}
214
215impl<T> core::fmt::Debug for SpMat<'_, T> {
216    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
217        f.debug_struct("SpMat")
218            .field("descr", &self.descr)
219            .finish_non_exhaustive()
220    }
221}
222
223impl<'buf, T: SparseScalar + baracuda_types::DeviceRepr> SpMat<'buf, T> {
224    /// Build a CSR (compressed sparse row) descriptor.
225    ///
226    /// `row_offsets.len()` must equal `rows + 1`; `col_indices.len()` and
227    /// `values.len()` must equal `nnz`.
228    pub fn csr(
229        rows: i64,
230        cols: i64,
231        nnz: i64,
232        row_offsets: &'buf mut DeviceBuffer<i32>,
233        col_indices: &'buf mut DeviceBuffer<i32>,
234        values: &'buf mut DeviceBuffer<T>,
235    ) -> Result<Self> {
236        let c = cusparse()?;
237        let cu = c.cusparse_create_csr()?;
238        let mut descr: cusparseSpMatDescr_t = core::ptr::null_mut();
239        check(unsafe {
240            cu(
241                &mut descr,
242                rows,
243                cols,
244                nnz,
245                row_offsets.as_raw().0 as *mut c_void,
246                col_indices.as_raw().0 as *mut c_void,
247                values.as_raw().0 as *mut c_void,
248                cusparseIndexType_t::I32I,
249                cusparseIndexType_t::I32I,
250                cusparseIndexBase_t::Zero,
251                T::data_type(),
252            )
253        })?;
254        Ok(Self {
255            descr,
256            _markers: PhantomData,
257        })
258    }
259
260    /// Build a CSC (compressed sparse column) descriptor.
261    pub fn csc(
262        rows: i64,
263        cols: i64,
264        nnz: i64,
265        col_offsets: &'buf mut DeviceBuffer<i32>,
266        row_indices: &'buf mut DeviceBuffer<i32>,
267        values: &'buf mut DeviceBuffer<T>,
268    ) -> Result<Self> {
269        let c = cusparse()?;
270        let cu = c.cusparse_create_csc()?;
271        let mut descr: cusparseSpMatDescr_t = core::ptr::null_mut();
272        check(unsafe {
273            cu(
274                &mut descr,
275                rows,
276                cols,
277                nnz,
278                col_offsets.as_raw().0 as *mut c_void,
279                row_indices.as_raw().0 as *mut c_void,
280                values.as_raw().0 as *mut c_void,
281                cusparseIndexType_t::I32I,
282                cusparseIndexType_t::I32I,
283                cusparseIndexBase_t::Zero,
284                T::data_type(),
285            )
286        })?;
287        Ok(Self {
288            descr,
289            _markers: PhantomData,
290        })
291    }
292
293    /// Build a BSR (block-sparse-row) descriptor.
294    #[allow(clippy::too_many_arguments)]
295    pub fn bsr(
296        brows: i64,
297        bcols: i64,
298        bnnz: i64,
299        row_block_dim: i64,
300        col_block_dim: i64,
301        order: Order,
302        row_offsets: &'buf mut DeviceBuffer<i32>,
303        col_indices: &'buf mut DeviceBuffer<i32>,
304        values: &'buf mut DeviceBuffer<T>,
305    ) -> Result<Self> {
306        let c = cusparse()?;
307        let cu = c.cusparse_create_bsr()?;
308        let mut descr: cusparseSpMatDescr_t = core::ptr::null_mut();
309        check(unsafe {
310            cu(
311                &mut descr,
312                brows,
313                bcols,
314                bnnz,
315                row_block_dim,
316                col_block_dim,
317                row_offsets.as_raw().0 as *mut c_void,
318                col_indices.as_raw().0 as *mut c_void,
319                values.as_raw().0 as *mut c_void,
320                cusparseIndexType_t::I32I,
321                cusparseIndexType_t::I32I,
322                cusparseIndexBase_t::Zero,
323                T::data_type(),
324                order.raw(),
325            )
326        })?;
327        Ok(Self {
328            descr,
329            _markers: PhantomData,
330        })
331    }
332
333    /// Build a COO (coordinate) descriptor.
334    pub fn coo(
335        rows: i64,
336        cols: i64,
337        nnz: i64,
338        row_indices: &'buf mut DeviceBuffer<i32>,
339        col_indices: &'buf mut DeviceBuffer<i32>,
340        values: &'buf mut DeviceBuffer<T>,
341    ) -> Result<Self> {
342        let c = cusparse()?;
343        let cu = c.cusparse_create_coo()?;
344        let mut descr: cusparseSpMatDescr_t = core::ptr::null_mut();
345        check(unsafe {
346            cu(
347                &mut descr,
348                rows,
349                cols,
350                nnz,
351                row_indices.as_raw().0 as *mut c_void,
352                col_indices.as_raw().0 as *mut c_void,
353                values.as_raw().0 as *mut c_void,
354                cusparseIndexType_t::I32I,
355                cusparseIndexBase_t::Zero,
356                T::data_type(),
357            )
358        })?;
359        Ok(Self {
360            descr,
361            _markers: PhantomData,
362        })
363    }
364}
365
366impl<T> SpMat<'_, T> {
367    /// Sparse matrix dimensions: `(rows, cols, nnz)`.
368    pub fn shape(&self) -> Result<(i64, i64, i64)> {
369        let c = cusparse()?;
370        let cu = c.cusparse_sp_mat_get_size()?;
371        let (mut r, mut col, mut nz) = (0i64, 0i64, 0i64);
372        check(unsafe { cu(self.descr, &mut r, &mut col, &mut nz) })?;
373        Ok((r, col, nz))
374    }
375
376    /// Set the fill-triangle attribute (for triangular solves).
377    pub fn set_fill(&self, fill: Fill) -> Result<()> {
378        let c = cusparse()?;
379        let cu = c.cusparse_sp_mat_set_attribute()?;
380        let raw = fill.raw();
381        check(unsafe {
382            cu(
383                self.descr,
384                cusparseSpMatAttribute_t::FillMode,
385                &raw as *const _ as *const c_void,
386                core::mem::size_of::<cusparseFillMode_t>(),
387            )
388        })
389    }
390
391    /// Set the diagonal-type attribute (unit vs non-unit, for triangular solves).
392    pub fn set_diag(&self, diag: Diag) -> Result<()> {
393        let c = cusparse()?;
394        let cu = c.cusparse_sp_mat_set_attribute()?;
395        let raw = diag.raw();
396        check(unsafe {
397            cu(
398                self.descr,
399                cusparseSpMatAttribute_t::DiagType,
400                &raw as *const _ as *const c_void,
401                core::mem::size_of::<cusparseDiagType_t>(),
402            )
403        })
404    }
405
406    #[inline]
407    pub fn as_raw(&self) -> cusparseSpMatDescr_t {
408        self.descr
409    }
410}
411
412impl<T> Drop for SpMat<'_, T> {
413    fn drop(&mut self) {
414        if let Ok(c) = cusparse() {
415            if let Ok(cu) = c.cusparse_destroy_sp_mat() {
416                let _ = unsafe { cu(self.descr) };
417            }
418        }
419    }
420}
421
422// ---- Dense vector / matrix -----------------------------------------------
423
424pub struct DnVec<'buf, T> {
425    descr: cusparseDnVecDescr_t,
426    _marker: PhantomData<&'buf mut T>,
427}
428
429unsafe impl<T> Send for DnVec<'_, T> {}
430
431impl<T> core::fmt::Debug for DnVec<'_, T> {
432    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
433        f.debug_struct("DnVec")
434            .field("descr", &self.descr)
435            .finish_non_exhaustive()
436    }
437}
438
439impl<'buf, T: SparseScalar + baracuda_types::DeviceRepr> DnVec<'buf, T> {
440    pub fn new(values: &'buf mut DeviceBuffer<T>) -> Result<Self> {
441        let c = cusparse()?;
442        let cu = c.cusparse_create_dn_vec()?;
443        let mut descr: cusparseDnVecDescr_t = core::ptr::null_mut();
444        check(unsafe {
445            cu(
446                &mut descr,
447                values.len() as i64,
448                values.as_raw().0 as *mut c_void,
449                T::data_type(),
450            )
451        })?;
452        Ok(Self {
453            descr,
454            _marker: PhantomData,
455        })
456    }
457}
458
459impl<T> DnVec<'_, T> {
460    #[inline]
461    pub fn as_raw(&self) -> cusparseDnVecDescr_t {
462        self.descr
463    }
464}
465
466impl<T> Drop for DnVec<'_, T> {
467    fn drop(&mut self) {
468        if let Ok(c) = cusparse() {
469            if let Ok(cu) = c.cusparse_destroy_dn_vec() {
470                let _ = unsafe { cu(self.descr) };
471            }
472        }
473    }
474}
475
476pub struct DnMat<'buf, T> {
477    descr: cusparseDnMatDescr_t,
478    _marker: PhantomData<&'buf mut T>,
479}
480
481unsafe impl<T> Send for DnMat<'_, T> {}
482
483impl<T> core::fmt::Debug for DnMat<'_, T> {
484    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
485        f.debug_struct("DnMat")
486            .field("descr", &self.descr)
487            .finish_non_exhaustive()
488    }
489}
490
491impl<'buf, T: SparseScalar + baracuda_types::DeviceRepr> DnMat<'buf, T> {
492    pub fn new(
493        rows: i64,
494        cols: i64,
495        ld: i64,
496        order: Order,
497        values: &'buf mut DeviceBuffer<T>,
498    ) -> Result<Self> {
499        let c = cusparse()?;
500        let cu = c.cusparse_create_dn_mat()?;
501        let mut descr: cusparseDnMatDescr_t = core::ptr::null_mut();
502        check(unsafe {
503            cu(
504                &mut descr,
505                rows,
506                cols,
507                ld,
508                values.as_raw().0 as *mut c_void,
509                T::data_type(),
510                order.raw(),
511            )
512        })?;
513        Ok(Self {
514            descr,
515            _marker: PhantomData,
516        })
517    }
518}
519
520impl<T> DnMat<'_, T> {
521    #[inline]
522    pub fn as_raw(&self) -> cusparseDnMatDescr_t {
523        self.descr
524    }
525}
526
527impl<T> Drop for DnMat<'_, T> {
528    fn drop(&mut self) {
529        if let Ok(c) = cusparse() {
530            if let Ok(cu) = c.cusparse_destroy_dn_mat() {
531                let _ = unsafe { cu(self.descr) };
532            }
533        }
534    }
535}
536
537// ---- SpMV ---------------------------------------------------------------
538
539/// Query buffer-size for `y = alpha * op(A) * x + beta * y`.
540#[allow(clippy::too_many_arguments)]
541pub fn spmv_buffer_size<T: SparseScalar>(
542    handle: &Handle,
543    op: Op,
544    alpha: &T,
545    a: &SpMat<'_, T>,
546    x: &DnVec<'_, T>,
547    beta: &T,
548    y: &DnVec<'_, T>,
549    alg: SpMVAlg,
550) -> Result<usize> {
551    let c = cusparse()?;
552    let cu = c.cusparse_spmv_buffer_size()?;
553    let mut size: usize = 0;
554    check(unsafe {
555        cu(
556            handle.as_raw(),
557            op.raw(),
558            alpha as *const T as *const c_void,
559            a.descr,
560            x.descr,
561            beta as *const T as *const c_void,
562            y.descr,
563            T::data_type(),
564            alg,
565            &mut size,
566        )
567    })?;
568    Ok(size)
569}
570
571/// `y = alpha * op(A) * x + beta * y`.
572#[allow(clippy::too_many_arguments)]
573pub fn spmv<T: SparseScalar>(
574    handle: &Handle,
575    op: Op,
576    alpha: &T,
577    a: &SpMat<'_, T>,
578    x: &DnVec<'_, T>,
579    beta: &T,
580    y: &mut DnVec<'_, T>,
581    alg: SpMVAlg,
582    workspace: &mut DeviceBuffer<u8>,
583) -> Result<()> {
584    let c = cusparse()?;
585    let cu = c.cusparse_spmv()?;
586    check(unsafe {
587        cu(
588            handle.as_raw(),
589            op.raw(),
590            alpha as *const T as *const c_void,
591            a.descr,
592            x.descr,
593            beta as *const T as *const c_void,
594            y.descr,
595            T::data_type(),
596            alg,
597            workspace.as_raw().0 as *mut c_void,
598        )
599    })
600}
601
602// ---- SpMM ---------------------------------------------------------------
603
604/// Query buffer-size for `C = alpha * op(A) * op(B) + beta * C`, `A` sparse.
605#[allow(clippy::too_many_arguments)]
606pub fn spmm_buffer_size<T: SparseScalar>(
607    handle: &Handle,
608    op_a: Op,
609    op_b: Op,
610    alpha: &T,
611    a: &SpMat<'_, T>,
612    b: &DnMat<'_, T>,
613    beta: &T,
614    c: &DnMat<'_, T>,
615    alg: SpMMAlg,
616) -> Result<usize> {
617    let c_api = cusparse()?;
618    let cu = c_api.cusparse_spmm_buffer_size()?;
619    let mut size = 0usize;
620    check(unsafe {
621        cu(
622            handle.as_raw(),
623            op_a.raw(),
624            op_b.raw(),
625            alpha as *const T as *const c_void,
626            a.descr,
627            b.descr,
628            beta as *const T as *const c_void,
629            c.descr,
630            T::data_type(),
631            alg,
632            &mut size,
633        )
634    })?;
635    Ok(size)
636}
637
638#[allow(clippy::too_many_arguments)]
639pub fn spmm<T: SparseScalar>(
640    handle: &Handle,
641    op_a: Op,
642    op_b: Op,
643    alpha: &T,
644    a: &SpMat<'_, T>,
645    b: &DnMat<'_, T>,
646    beta: &T,
647    c: &mut DnMat<'_, T>,
648    alg: SpMMAlg,
649    workspace: &mut DeviceBuffer<u8>,
650) -> Result<()> {
651    let c_api = cusparse()?;
652    let cu = c_api.cusparse_spmm()?;
653    check(unsafe {
654        cu(
655            handle.as_raw(),
656            op_a.raw(),
657            op_b.raw(),
658            alpha as *const T as *const c_void,
659            a.descr,
660            b.descr,
661            beta as *const T as *const c_void,
662            c.descr,
663            T::data_type(),
664            alg,
665            workspace.as_raw().0 as *mut c_void,
666        )
667    })
668}
669
670// ---- SpGEMM -------------------------------------------------------------
671
672/// Per-plan handle for a 3-phase SpGEMM computation.
673#[derive(Debug)]
674pub struct SpGEMMPlan {
675    raw: cusparseSpGEMMDescr_t,
676}
677
678impl SpGEMMPlan {
679    pub fn new() -> Result<Self> {
680        let c = cusparse()?;
681        let cu = c.cusparse_spgemm_create_descr()?;
682        let mut d: cusparseSpGEMMDescr_t = core::ptr::null_mut();
683        check(unsafe { cu(&mut d) })?;
684        Ok(Self { raw: d })
685    }
686}
687
688impl Drop for SpGEMMPlan {
689    fn drop(&mut self) {
690        if let Ok(c) = cusparse() {
691            if let Ok(cu) = c.cusparse_spgemm_destroy_descr() {
692                let _ = unsafe { cu(self.raw) };
693            }
694        }
695    }
696}
697
698/// Phase 1: work-estimation. The caller provides `buffer1` whose size is
699/// returned in `size1`; pass `null` the first time, then allocate & re-call.
700#[allow(clippy::too_many_arguments)]
701pub unsafe fn spgemm_work_estimation<T: SparseScalar>(
702    handle: &Handle,
703    op_a: Op,
704    op_b: Op,
705    alpha: &T,
706    a: &SpMat<'_, T>,
707    b: &SpMat<'_, T>,
708    beta: &T,
709    c: &mut SpMat<'_, T>,
710    alg: SpGEMMAlg,
711    plan: &SpGEMMPlan,
712    size1: &mut usize,
713    buffer1: *mut c_void,
714) -> Result<()> {
715    let c_api = cusparse()?;
716    let cu = c_api.cusparse_spgemm_work_estimation()?;
717    check(cu(
718        handle.as_raw(),
719        op_a.raw(),
720        op_b.raw(),
721        alpha as *const T as *const c_void,
722        a.descr,
723        b.descr,
724        beta as *const T as *const c_void,
725        c.descr,
726        T::data_type(),
727        alg,
728        plan.raw,
729        size1,
730        buffer1,
731    ))
732}
733
734/// Phase 2: compute. Same two-step pattern for `buffer2`.
735#[allow(clippy::too_many_arguments)]
736pub unsafe fn spgemm_compute<T: SparseScalar>(
737    handle: &Handle,
738    op_a: Op,
739    op_b: Op,
740    alpha: &T,
741    a: &SpMat<'_, T>,
742    b: &SpMat<'_, T>,
743    beta: &T,
744    c: &mut SpMat<'_, T>,
745    alg: SpGEMMAlg,
746    plan: &SpGEMMPlan,
747    size2: &mut usize,
748    buffer2: *mut c_void,
749) -> Result<()> {
750    let c_api = cusparse()?;
751    let cu = c_api.cusparse_spgemm_compute()?;
752    check(cu(
753        handle.as_raw(),
754        op_a.raw(),
755        op_b.raw(),
756        alpha as *const T as *const c_void,
757        a.descr,
758        b.descr,
759        beta as *const T as *const c_void,
760        c.descr,
761        T::data_type(),
762        alg,
763        plan.raw,
764        size2,
765        buffer2,
766    ))
767}
768
769/// Phase 3: write output arrays into the pre-populated output `SpMat`.
770#[allow(clippy::too_many_arguments)]
771pub fn spgemm_copy<T: SparseScalar>(
772    handle: &Handle,
773    op_a: Op,
774    op_b: Op,
775    alpha: &T,
776    a: &SpMat<'_, T>,
777    b: &SpMat<'_, T>,
778    beta: &T,
779    c: &mut SpMat<'_, T>,
780    alg: SpGEMMAlg,
781    plan: &SpGEMMPlan,
782) -> Result<()> {
783    let c_api = cusparse()?;
784    let cu = c_api.cusparse_spgemm_copy()?;
785    check(unsafe {
786        cu(
787            handle.as_raw(),
788            op_a.raw(),
789            op_b.raw(),
790            alpha as *const T as *const c_void,
791            a.descr,
792            b.descr,
793            beta as *const T as *const c_void,
794            c.descr,
795            T::data_type(),
796            alg,
797            plan.raw,
798        )
799    })
800}
801
802// ---- SpSV / SpSM --------------------------------------------------------
803
804#[derive(Debug)]
805pub struct SpSVPlan {
806    raw: cusparseSpSVDescr_t,
807}
808
809impl SpSVPlan {
810    pub fn new() -> Result<Self> {
811        let c = cusparse()?;
812        let cu = c.cusparse_spsv_create_descr()?;
813        let mut d: cusparseSpSVDescr_t = core::ptr::null_mut();
814        check(unsafe { cu(&mut d) })?;
815        Ok(Self { raw: d })
816    }
817}
818
819impl Drop for SpSVPlan {
820    fn drop(&mut self) {
821        if let Ok(c) = cusparse() {
822            if let Ok(cu) = c.cusparse_spsv_destroy_descr() {
823                let _ = unsafe { cu(self.raw) };
824            }
825        }
826    }
827}
828
829#[allow(clippy::too_many_arguments)]
830pub fn spsv_buffer_size<T: SparseScalar>(
831    handle: &Handle,
832    op: Op,
833    alpha: &T,
834    a: &SpMat<'_, T>,
835    x: &DnVec<'_, T>,
836    y: &DnVec<'_, T>,
837    alg: SpSVAlg,
838    plan: &SpSVPlan,
839) -> Result<usize> {
840    let c = cusparse()?;
841    let cu = c.cusparse_spsv_buffer_size()?;
842    let mut size = 0usize;
843    check(unsafe {
844        cu(
845            handle.as_raw(),
846            op.raw(),
847            alpha as *const T as *const c_void,
848            a.descr,
849            x.descr,
850            y.descr,
851            T::data_type(),
852            alg,
853            plan.raw,
854            &mut size,
855        )
856    })?;
857    Ok(size)
858}
859
860#[allow(clippy::too_many_arguments)]
861pub fn spsv_analysis<T: SparseScalar>(
862    handle: &Handle,
863    op: Op,
864    alpha: &T,
865    a: &SpMat<'_, T>,
866    x: &DnVec<'_, T>,
867    y: &DnVec<'_, T>,
868    alg: SpSVAlg,
869    plan: &SpSVPlan,
870    workspace: &mut DeviceBuffer<u8>,
871) -> Result<()> {
872    let c = cusparse()?;
873    let cu = c.cusparse_spsv_analysis()?;
874    check(unsafe {
875        cu(
876            handle.as_raw(),
877            op.raw(),
878            alpha as *const T as *const c_void,
879            a.descr,
880            x.descr,
881            y.descr,
882            T::data_type(),
883            alg,
884            plan.raw,
885            workspace.as_raw().0 as *mut c_void,
886        )
887    })
888}
889
890#[allow(clippy::too_many_arguments)]
891pub fn spsv_solve<T: SparseScalar>(
892    handle: &Handle,
893    op: Op,
894    alpha: &T,
895    a: &SpMat<'_, T>,
896    x: &DnVec<'_, T>,
897    y: &mut DnVec<'_, T>,
898    alg: SpSVAlg,
899    plan: &SpSVPlan,
900) -> Result<()> {
901    let c = cusparse()?;
902    let cu = c.cusparse_spsv_solve()?;
903    check(unsafe {
904        cu(
905            handle.as_raw(),
906            op.raw(),
907            alpha as *const T as *const c_void,
908            a.descr,
909            x.descr,
910            y.descr,
911            T::data_type(),
912            alg,
913            plan.raw,
914        )
915    })
916}
917
918#[derive(Debug)]
919pub struct SpSMPlan {
920    raw: cusparseSpSMDescr_t,
921}
922
923impl SpSMPlan {
924    pub fn new() -> Result<Self> {
925        let c = cusparse()?;
926        let cu = c.cusparse_spsm_create_descr()?;
927        let mut d: cusparseSpSMDescr_t = core::ptr::null_mut();
928        check(unsafe { cu(&mut d) })?;
929        Ok(Self { raw: d })
930    }
931}
932
933impl Drop for SpSMPlan {
934    fn drop(&mut self) {
935        if let Ok(c) = cusparse() {
936            if let Ok(cu) = c.cusparse_spsm_destroy_descr() {
937                let _ = unsafe { cu(self.raw) };
938            }
939        }
940    }
941}
942
943#[allow(clippy::too_many_arguments)]
944pub fn spsm_buffer_size<T: SparseScalar>(
945    handle: &Handle,
946    op_a: Op,
947    op_b: Op,
948    alpha: &T,
949    a: &SpMat<'_, T>,
950    b: &DnMat<'_, T>,
951    c: &DnMat<'_, T>,
952    alg: SpSMAlg,
953    plan: &SpSMPlan,
954) -> Result<usize> {
955    let c_api = cusparse()?;
956    let cu = c_api.cusparse_spsm_buffer_size()?;
957    let mut size = 0usize;
958    check(unsafe {
959        cu(
960            handle.as_raw(),
961            op_a.raw(),
962            op_b.raw(),
963            alpha as *const T as *const c_void,
964            a.descr,
965            b.descr,
966            c.descr,
967            T::data_type(),
968            alg,
969            plan.raw,
970            &mut size,
971        )
972    })?;
973    Ok(size)
974}
975
976#[allow(clippy::too_many_arguments)]
977pub fn spsm_analysis<T: SparseScalar>(
978    handle: &Handle,
979    op_a: Op,
980    op_b: Op,
981    alpha: &T,
982    a: &SpMat<'_, T>,
983    b: &DnMat<'_, T>,
984    c: &DnMat<'_, T>,
985    alg: SpSMAlg,
986    plan: &SpSMPlan,
987    workspace: &mut DeviceBuffer<u8>,
988) -> Result<()> {
989    let c_api = cusparse()?;
990    let cu = c_api.cusparse_spsm_analysis()?;
991    check(unsafe {
992        cu(
993            handle.as_raw(),
994            op_a.raw(),
995            op_b.raw(),
996            alpha as *const T as *const c_void,
997            a.descr,
998            b.descr,
999            c.descr,
1000            T::data_type(),
1001            alg,
1002            plan.raw,
1003            workspace.as_raw().0 as *mut c_void,
1004        )
1005    })
1006}
1007
1008#[allow(clippy::too_many_arguments)]
1009pub fn spsm_solve<T: SparseScalar>(
1010    handle: &Handle,
1011    op_a: Op,
1012    op_b: Op,
1013    alpha: &T,
1014    a: &SpMat<'_, T>,
1015    b: &DnMat<'_, T>,
1016    c: &mut DnMat<'_, T>,
1017    alg: SpSMAlg,
1018    plan: &SpSMPlan,
1019) -> Result<()> {
1020    let c_api = cusparse()?;
1021    let cu = c_api.cusparse_spsm_solve()?;
1022    check(unsafe {
1023        cu(
1024            handle.as_raw(),
1025            op_a.raw(),
1026            op_b.raw(),
1027            alpha as *const T as *const c_void,
1028            a.descr,
1029            b.descr,
1030            c.descr,
1031            T::data_type(),
1032            alg,
1033            plan.raw,
1034        )
1035    })
1036}
1037
1038// ---- SDDMM -------------------------------------------------------------
1039
1040#[allow(clippy::too_many_arguments)]
1041pub fn sddmm_buffer_size<T: SparseScalar>(
1042    handle: &Handle,
1043    op_a: Op,
1044    op_b: Op,
1045    alpha: &T,
1046    a: &DnMat<'_, T>,
1047    b: &DnMat<'_, T>,
1048    beta: &T,
1049    c: &SpMat<'_, T>,
1050    alg: SDDMMAlg,
1051) -> Result<usize> {
1052    let c_api = cusparse()?;
1053    let cu = c_api.cusparse_sddmm_buffer_size()?;
1054    let mut size = 0usize;
1055    check(unsafe {
1056        cu(
1057            handle.as_raw(),
1058            op_a.raw(),
1059            op_b.raw(),
1060            alpha as *const T as *const c_void,
1061            a.descr,
1062            b.descr,
1063            beta as *const T as *const c_void,
1064            c.descr,
1065            T::data_type(),
1066            alg,
1067            &mut size,
1068        )
1069    })?;
1070    Ok(size)
1071}
1072
1073#[allow(clippy::too_many_arguments)]
1074pub fn sddmm<T: SparseScalar>(
1075    handle: &Handle,
1076    op_a: Op,
1077    op_b: Op,
1078    alpha: &T,
1079    a: &DnMat<'_, T>,
1080    b: &DnMat<'_, T>,
1081    beta: &T,
1082    c: &mut SpMat<'_, T>,
1083    alg: SDDMMAlg,
1084    workspace: &mut DeviceBuffer<u8>,
1085) -> Result<()> {
1086    let c_api = cusparse()?;
1087    let cu = c_api.cusparse_sddmm()?;
1088    check(unsafe {
1089        cu(
1090            handle.as_raw(),
1091            op_a.raw(),
1092            op_b.raw(),
1093            alpha as *const T as *const c_void,
1094            a.descr,
1095            b.descr,
1096            beta as *const T as *const c_void,
1097            c.descr,
1098            T::data_type(),
1099            alg,
1100            workspace.as_raw().0 as *mut c_void,
1101        )
1102    })
1103}
1104
1105// ---- Sparse / dense conversions ----------------------------------------
1106
1107pub fn sparse_to_dense_buffer_size<T: SparseScalar>(
1108    handle: &Handle,
1109    sp: &SpMat<'_, T>,
1110    dn: &DnMat<'_, T>,
1111) -> Result<usize> {
1112    let c = cusparse()?;
1113    let cu = c.cusparse_sparse_to_dense_buffer_size()?;
1114    let mut size = 0usize;
1115    check(unsafe { cu(handle.as_raw(), sp.descr, dn.descr, 0, &mut size) })?;
1116    Ok(size)
1117}
1118
1119pub fn sparse_to_dense<T: SparseScalar>(
1120    handle: &Handle,
1121    sp: &SpMat<'_, T>,
1122    dn: &mut DnMat<'_, T>,
1123    workspace: &mut DeviceBuffer<u8>,
1124) -> Result<()> {
1125    let c = cusparse()?;
1126    let cu = c.cusparse_sparse_to_dense()?;
1127    check(unsafe {
1128        cu(
1129            handle.as_raw(),
1130            sp.descr,
1131            dn.descr,
1132            0,
1133            workspace.as_raw().0 as *mut c_void,
1134        )
1135    })
1136}
1137
1138pub fn dense_to_sparse_buffer_size<T: SparseScalar>(
1139    handle: &Handle,
1140    dn: &DnMat<'_, T>,
1141    sp: &SpMat<'_, T>,
1142) -> Result<usize> {
1143    let c = cusparse()?;
1144    let cu = c.cusparse_dense_to_sparse_buffer_size()?;
1145    let mut size = 0usize;
1146    check(unsafe { cu(handle.as_raw(), dn.descr, sp.descr, 0, &mut size) })?;
1147    Ok(size)
1148}
1149
1150pub fn dense_to_sparse_analysis<T: SparseScalar>(
1151    handle: &Handle,
1152    dn: &DnMat<'_, T>,
1153    sp: &SpMat<'_, T>,
1154    workspace: &mut DeviceBuffer<u8>,
1155) -> Result<()> {
1156    let c = cusparse()?;
1157    let cu = c.cusparse_dense_to_sparse_analysis()?;
1158    check(unsafe {
1159        cu(
1160            handle.as_raw(),
1161            dn.descr,
1162            sp.descr,
1163            0,
1164            workspace.as_raw().0 as *mut c_void,
1165        )
1166    })
1167}
1168
1169pub fn dense_to_sparse_convert<T: SparseScalar>(
1170    handle: &Handle,
1171    dn: &DnMat<'_, T>,
1172    sp: &mut SpMat<'_, T>,
1173    workspace: &mut DeviceBuffer<u8>,
1174) -> Result<()> {
1175    let c = cusparse()?;
1176    let cu = c.cusparse_dense_to_sparse_convert()?;
1177    check(unsafe {
1178        cu(
1179            handle.as_raw(),
1180            dn.descr,
1181            sp.descr,
1182            0,
1183            workspace.as_raw().0 as *mut c_void,
1184        )
1185    })
1186}
1187
1188// ---- Sparse BLAS-1 helpers ---------------------------------------------
1189
1190pub fn axpby<T: SparseScalar>(
1191    handle: &Handle,
1192    alpha: &T,
1193    x: &DnVec<'_, T>,
1194    beta: &T,
1195    y: &mut DnVec<'_, T>,
1196) -> Result<()> {
1197    let c = cusparse()?;
1198    let cu = c.cusparse_axpby()?;
1199    check(unsafe {
1200        cu(
1201            handle.as_raw(),
1202            alpha as *const T as *const c_void,
1203            x.descr,
1204            beta as *const T as *const c_void,
1205            y.descr,
1206        )
1207    })
1208}
1209
1210pub fn gather<T: SparseScalar>(
1211    handle: &Handle,
1212    y: &DnVec<'_, T>,
1213    x: &mut DnVec<'_, T>,
1214) -> Result<()> {
1215    let c = cusparse()?;
1216    let cu = c.cusparse_gather()?;
1217    check(unsafe { cu(handle.as_raw(), y.descr, x.descr) })
1218}
1219
1220pub fn scatter<T: SparseScalar>(
1221    handle: &Handle,
1222    x: &DnVec<'_, T>,
1223    y: &mut DnVec<'_, T>,
1224) -> Result<()> {
1225    let c = cusparse()?;
1226    let cu = c.cusparse_scatter()?;
1227    check(unsafe { cu(handle.as_raw(), x.descr, y.descr) })
1228}
1229
1230pub fn rot<T: SparseScalar>(
1231    handle: &Handle,
1232    c_cos: &T,
1233    s_sin: &T,
1234    x: &mut DnVec<'_, T>,
1235    y: &mut DnVec<'_, T>,
1236) -> Result<()> {
1237    let c_api = cusparse()?;
1238    let cu = c_api.cusparse_rot()?;
1239    check(unsafe {
1240        cu(
1241            handle.as_raw(),
1242            c_cos as *const T as *const c_void,
1243            s_sin as *const T as *const c_void,
1244            x.descr,
1245            y.descr,
1246        )
1247    })
1248}
1249
1250// ---- Back-compat re-exports for existing users --------------------------
1251
1252/// Legacy alias kept for callers from v0.1 — prefer [`SpMat::csr`].
1253pub type CsrMatrix<'buf> = SpMat<'buf, f32>;
1254/// Legacy alias — prefer [`DnVec`].
1255pub type DenseVector<'buf, T> = DnVec<'buf, T>;