1#![warn(missing_debug_implementations)]
13
14use core::ffi::c_void;
15use std::marker::PhantomData;
16
17use baracuda_cusparse_sys::{
18 cudaDataType, cusparse, cusparseDiagType_t, cusparseDnMatDescr_t, cusparseDnVecDescr_t,
19 cusparseFillMode_t, cusparseHandle_t, cusparseIndexBase_t, cusparseIndexType_t,
20 cusparseOperation_t, cusparseOrder_t, cusparseSpGEMMDescr_t, cusparseSpMatAttribute_t,
21 cusparseSpMatDescr_t, cusparseSpSMDescr_t, cusparseSpSVDescr_t, cusparseStatus_t,
22};
23use baracuda_driver::{DeviceBuffer, Stream};
24use baracuda_types::{Complex32, Complex64};
25
26pub use baracuda_cusparse_sys::{
27 cusparseCsr2CscAlg_t as Csr2CscAlg, cusparseIndexBase_t as IndexBase,
28 cusparseSDDMMAlg_t as SDDMMAlg, cusparseSpGEMMAlg_t as SpGEMMAlg,
29 cusparseSpMMAlg_t as SpMMAlg, cusparseSpMVAlg_t as SpMVAlg, cusparseSpSMAlg_t as SpSMAlg,
30 cusparseSpSVAlg_t as SpSVAlg,
31};
32
33pub type Error = baracuda_core::Error<cusparseStatus_t>;
35pub type Result<T, E = Error> = core::result::Result<T, E>;
37
38#[inline]
39fn check(status: cusparseStatus_t) -> Result<()> {
40 Error::check(status)
41}
42
43pub trait SparseScalar: sealed::Sealed + Copy + 'static {
47 fn data_type() -> cudaDataType;
49}
50
51impl SparseScalar for f32 {
52 fn data_type() -> cudaDataType {
53 cudaDataType::R_32F
54 }
55}
56impl SparseScalar for f64 {
57 fn data_type() -> cudaDataType {
58 cudaDataType::R_64F
59 }
60}
61impl SparseScalar for Complex32 {
62 fn data_type() -> cudaDataType {
63 cudaDataType::C_32F
64 }
65}
66impl SparseScalar for Complex64 {
67 fn data_type() -> cudaDataType {
68 cudaDataType::C_64F
69 }
70}
71
72mod sealed {
73 use baracuda_types::{Complex32, Complex64};
74 pub trait Sealed {}
75 impl Sealed for f32 {}
76 impl Sealed for f64 {}
77 impl Sealed for Complex32 {}
78 impl Sealed for Complex64 {}
79}
80
81#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
85pub enum Op {
86 #[default]
87 N,
88 T,
89 C,
90}
91
92impl Op {
93 fn raw(self) -> cusparseOperation_t {
94 match self {
95 Op::N => cusparseOperation_t::N,
96 Op::T => cusparseOperation_t::T,
97 Op::C => cusparseOperation_t::C,
98 }
99 }
100}
101
102#[derive(Copy, Clone, Debug, Eq, PartialEq)]
104pub enum Order {
105 Row,
106 Col,
107}
108
109impl Order {
110 fn raw(self) -> cusparseOrder_t {
111 match self {
112 Order::Row => cusparseOrder_t::Row,
113 Order::Col => cusparseOrder_t::Col,
114 }
115 }
116}
117
118#[derive(Copy, Clone, Debug, Eq, PartialEq)]
119pub enum Fill {
120 Lower,
121 Upper,
122}
123
124impl Fill {
125 fn raw(self) -> cusparseFillMode_t {
126 match self {
127 Fill::Lower => cusparseFillMode_t::Lower,
128 Fill::Upper => cusparseFillMode_t::Upper,
129 }
130 }
131}
132
133#[derive(Copy, Clone, Debug, Eq, PartialEq)]
134pub enum Diag {
135 NonUnit,
136 Unit,
137}
138
139impl Diag {
140 fn raw(self) -> cusparseDiagType_t {
141 match self {
142 Diag::NonUnit => cusparseDiagType_t::NonUnit,
143 Diag::Unit => cusparseDiagType_t::Unit,
144 }
145 }
146}
147
148pub struct Handle {
152 handle: cusparseHandle_t,
153}
154
155unsafe impl Send for Handle {}
156
157impl core::fmt::Debug for Handle {
158 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
159 f.debug_struct("cusparse::Handle")
160 .field("handle", &self.handle)
161 .finish()
162 }
163}
164
165impl Handle {
166 pub fn new() -> Result<Self> {
167 let c = cusparse()?;
168 let cu = c.cusparse_create()?;
169 let mut h: cusparseHandle_t = core::ptr::null_mut();
170 check(unsafe { cu(&mut h) })?;
171 Ok(Self { handle: h })
172 }
173
174 pub fn set_stream(&self, stream: &Stream) -> Result<()> {
175 let c = cusparse()?;
176 let cu = c.cusparse_set_stream()?;
177 check(unsafe { cu(self.handle, stream.as_raw() as _) })
178 }
179
180 pub fn version(&self) -> Result<i32> {
181 let c = cusparse()?;
182 let cu = c.cusparse_get_version()?;
183 let mut v: core::ffi::c_int = 0;
184 check(unsafe { cu(self.handle, &mut v) })?;
185 Ok(v)
186 }
187
188 #[inline]
189 pub fn as_raw(&self) -> cusparseHandle_t {
190 self.handle
191 }
192}
193
194impl Drop for Handle {
195 fn drop(&mut self) {
196 if let Ok(c) = cusparse() {
197 if let Ok(cu) = c.cusparse_destroy() {
198 let _ = unsafe { cu(self.handle) };
199 }
200 }
201 }
202}
203
204pub struct SpMat<'buf, T> {
210 descr: cusparseSpMatDescr_t,
211 _markers: PhantomData<&'buf mut T>,
212}
213
214unsafe impl<T> Send for SpMat<'_, T> {}
215
216impl<T> core::fmt::Debug for SpMat<'_, T> {
217 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
218 f.debug_struct("SpMat")
219 .field("descr", &self.descr)
220 .finish_non_exhaustive()
221 }
222}
223
224impl<'buf, T: SparseScalar + baracuda_types::DeviceRepr> SpMat<'buf, T> {
225 pub fn csr(
230 rows: i64,
231 cols: i64,
232 nnz: i64,
233 row_offsets: &'buf mut DeviceBuffer<i32>,
234 col_indices: &'buf mut DeviceBuffer<i32>,
235 values: &'buf mut DeviceBuffer<T>,
236 ) -> Result<Self> {
237 let c = cusparse()?;
238 let cu = c.cusparse_create_csr()?;
239 let mut descr: cusparseSpMatDescr_t = core::ptr::null_mut();
240 check(unsafe {
241 cu(
242 &mut descr,
243 rows,
244 cols,
245 nnz,
246 row_offsets.as_raw().0 as *mut c_void,
247 col_indices.as_raw().0 as *mut c_void,
248 values.as_raw().0 as *mut c_void,
249 cusparseIndexType_t::I32I,
250 cusparseIndexType_t::I32I,
251 cusparseIndexBase_t::Zero,
252 T::data_type(),
253 )
254 })?;
255 Ok(Self {
256 descr,
257 _markers: PhantomData,
258 })
259 }
260
261 pub fn csc(
263 rows: i64,
264 cols: i64,
265 nnz: i64,
266 col_offsets: &'buf mut DeviceBuffer<i32>,
267 row_indices: &'buf mut DeviceBuffer<i32>,
268 values: &'buf mut DeviceBuffer<T>,
269 ) -> Result<Self> {
270 let c = cusparse()?;
271 let cu = c.cusparse_create_csc()?;
272 let mut descr: cusparseSpMatDescr_t = core::ptr::null_mut();
273 check(unsafe {
274 cu(
275 &mut descr,
276 rows,
277 cols,
278 nnz,
279 col_offsets.as_raw().0 as *mut c_void,
280 row_indices.as_raw().0 as *mut c_void,
281 values.as_raw().0 as *mut c_void,
282 cusparseIndexType_t::I32I,
283 cusparseIndexType_t::I32I,
284 cusparseIndexBase_t::Zero,
285 T::data_type(),
286 )
287 })?;
288 Ok(Self {
289 descr,
290 _markers: PhantomData,
291 })
292 }
293
294 #[allow(clippy::too_many_arguments)]
296 pub fn bsr(
297 brows: i64,
298 bcols: i64,
299 bnnz: i64,
300 row_block_dim: i64,
301 col_block_dim: i64,
302 order: Order,
303 row_offsets: &'buf mut DeviceBuffer<i32>,
304 col_indices: &'buf mut DeviceBuffer<i32>,
305 values: &'buf mut DeviceBuffer<T>,
306 ) -> Result<Self> {
307 let c = cusparse()?;
308 let cu = c.cusparse_create_bsr()?;
309 let mut descr: cusparseSpMatDescr_t = core::ptr::null_mut();
310 check(unsafe {
311 cu(
312 &mut descr,
313 brows,
314 bcols,
315 bnnz,
316 row_block_dim,
317 col_block_dim,
318 row_offsets.as_raw().0 as *mut c_void,
319 col_indices.as_raw().0 as *mut c_void,
320 values.as_raw().0 as *mut c_void,
321 cusparseIndexType_t::I32I,
322 cusparseIndexType_t::I32I,
323 cusparseIndexBase_t::Zero,
324 T::data_type(),
325 order.raw(),
326 )
327 })?;
328 Ok(Self {
329 descr,
330 _markers: PhantomData,
331 })
332 }
333
334 pub fn coo(
336 rows: i64,
337 cols: i64,
338 nnz: i64,
339 row_indices: &'buf mut DeviceBuffer<i32>,
340 col_indices: &'buf mut DeviceBuffer<i32>,
341 values: &'buf mut DeviceBuffer<T>,
342 ) -> Result<Self> {
343 let c = cusparse()?;
344 let cu = c.cusparse_create_coo()?;
345 let mut descr: cusparseSpMatDescr_t = core::ptr::null_mut();
346 check(unsafe {
347 cu(
348 &mut descr,
349 rows,
350 cols,
351 nnz,
352 row_indices.as_raw().0 as *mut c_void,
353 col_indices.as_raw().0 as *mut c_void,
354 values.as_raw().0 as *mut c_void,
355 cusparseIndexType_t::I32I,
356 cusparseIndexBase_t::Zero,
357 T::data_type(),
358 )
359 })?;
360 Ok(Self {
361 descr,
362 _markers: PhantomData,
363 })
364 }
365}
366
367impl<T> SpMat<'_, T> {
368 pub fn shape(&self) -> Result<(i64, i64, i64)> {
370 let c = cusparse()?;
371 let cu = c.cusparse_sp_mat_get_size()?;
372 let (mut r, mut col, mut nz) = (0i64, 0i64, 0i64);
373 check(unsafe { cu(self.descr, &mut r, &mut col, &mut nz) })?;
374 Ok((r, col, nz))
375 }
376
377 pub unsafe fn set_csr_pointers(
388 &self,
389 row_offsets: *mut c_void,
390 col_indices: *mut c_void,
391 values: *mut c_void,
392 ) -> Result<()> {
393 let c = cusparse()?;
394 let cu = c.cusparse_csr_set_pointers()?;
395 check(cu(self.descr, row_offsets, col_indices, values))
396 }
397
398 pub unsafe fn set_csc_pointers(
404 &self,
405 col_offsets: *mut c_void,
406 row_indices: *mut c_void,
407 values: *mut c_void,
408 ) -> Result<()> {
409 let c = cusparse()?;
410 let cu = c.cusparse_csc_set_pointers()?;
411 check(cu(self.descr, col_offsets, row_indices, values))
412 }
413
414 pub unsafe fn set_coo_pointers(
420 &self,
421 row_indices: *mut c_void,
422 col_indices: *mut c_void,
423 values: *mut c_void,
424 ) -> Result<()> {
425 let c = cusparse()?;
426 let cu = c.cusparse_coo_set_pointers()?;
427 check(cu(self.descr, row_indices, col_indices, values))
428 }
429
430 pub fn set_fill(&self, fill: Fill) -> Result<()> {
432 let c = cusparse()?;
433 let cu = c.cusparse_sp_mat_set_attribute()?;
434 let raw = fill.raw();
435 check(unsafe {
436 cu(
437 self.descr,
438 cusparseSpMatAttribute_t::FillMode,
439 &raw as *const _ as *const c_void,
440 core::mem::size_of::<cusparseFillMode_t>(),
441 )
442 })
443 }
444
445 pub fn set_diag(&self, diag: Diag) -> Result<()> {
447 let c = cusparse()?;
448 let cu = c.cusparse_sp_mat_set_attribute()?;
449 let raw = diag.raw();
450 check(unsafe {
451 cu(
452 self.descr,
453 cusparseSpMatAttribute_t::DiagType,
454 &raw as *const _ as *const c_void,
455 core::mem::size_of::<cusparseDiagType_t>(),
456 )
457 })
458 }
459
460 #[inline]
461 pub fn as_raw(&self) -> cusparseSpMatDescr_t {
462 self.descr
463 }
464}
465
466impl<T> Drop for SpMat<'_, T> {
467 fn drop(&mut self) {
468 if let Ok(c) = cusparse() {
469 if let Ok(cu) = c.cusparse_destroy_sp_mat() {
470 let _ = unsafe { cu(self.descr) };
471 }
472 }
473 }
474}
475
476pub struct DnVec<'buf, T> {
479 descr: cusparseDnVecDescr_t,
480 _marker: PhantomData<&'buf mut T>,
481}
482
483unsafe impl<T> Send for DnVec<'_, T> {}
484
485impl<T> core::fmt::Debug for DnVec<'_, T> {
486 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
487 f.debug_struct("DnVec")
488 .field("descr", &self.descr)
489 .finish_non_exhaustive()
490 }
491}
492
493impl<'buf, T: SparseScalar + baracuda_types::DeviceRepr> DnVec<'buf, T> {
494 pub fn new(values: &'buf mut DeviceBuffer<T>) -> Result<Self> {
495 let c = cusparse()?;
496 let cu = c.cusparse_create_dn_vec()?;
497 let mut descr: cusparseDnVecDescr_t = core::ptr::null_mut();
498 check(unsafe {
499 cu(
500 &mut descr,
501 values.len() as i64,
502 values.as_raw().0 as *mut c_void,
503 T::data_type(),
504 )
505 })?;
506 Ok(Self {
507 descr,
508 _marker: PhantomData,
509 })
510 }
511}
512
513impl<T> DnVec<'_, T> {
514 #[inline]
515 pub fn as_raw(&self) -> cusparseDnVecDescr_t {
516 self.descr
517 }
518}
519
520impl<T> Drop for DnVec<'_, T> {
521 fn drop(&mut self) {
522 if let Ok(c) = cusparse() {
523 if let Ok(cu) = c.cusparse_destroy_dn_vec() {
524 let _ = unsafe { cu(self.descr) };
525 }
526 }
527 }
528}
529
530pub struct DnMat<'buf, T> {
531 descr: cusparseDnMatDescr_t,
532 _marker: PhantomData<&'buf mut T>,
533}
534
535unsafe impl<T> Send for DnMat<'_, T> {}
536
537impl<T> core::fmt::Debug for DnMat<'_, T> {
538 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
539 f.debug_struct("DnMat")
540 .field("descr", &self.descr)
541 .finish_non_exhaustive()
542 }
543}
544
545impl<'buf, T: SparseScalar + baracuda_types::DeviceRepr> DnMat<'buf, T> {
546 pub fn new(
547 rows: i64,
548 cols: i64,
549 ld: i64,
550 order: Order,
551 values: &'buf mut DeviceBuffer<T>,
552 ) -> Result<Self> {
553 let c = cusparse()?;
554 let cu = c.cusparse_create_dn_mat()?;
555 let mut descr: cusparseDnMatDescr_t = core::ptr::null_mut();
556 check(unsafe {
557 cu(
558 &mut descr,
559 rows,
560 cols,
561 ld,
562 values.as_raw().0 as *mut c_void,
563 T::data_type(),
564 order.raw(),
565 )
566 })?;
567 Ok(Self {
568 descr,
569 _marker: PhantomData,
570 })
571 }
572}
573
574impl<T> DnMat<'_, T> {
575 #[inline]
576 pub fn as_raw(&self) -> cusparseDnMatDescr_t {
577 self.descr
578 }
579}
580
581impl<T> Drop for DnMat<'_, T> {
582 fn drop(&mut self) {
583 if let Ok(c) = cusparse() {
584 if let Ok(cu) = c.cusparse_destroy_dn_mat() {
585 let _ = unsafe { cu(self.descr) };
586 }
587 }
588 }
589}
590
591#[allow(clippy::too_many_arguments)]
595pub fn spmv_buffer_size<T: SparseScalar>(
596 handle: &Handle,
597 op: Op,
598 alpha: &T,
599 a: &SpMat<'_, T>,
600 x: &DnVec<'_, T>,
601 beta: &T,
602 y: &DnVec<'_, T>,
603 alg: SpMVAlg,
604) -> Result<usize> {
605 let c = cusparse()?;
606 let cu = c.cusparse_spmv_buffer_size()?;
607 let mut size: usize = 0;
608 check(unsafe {
609 cu(
610 handle.as_raw(),
611 op.raw(),
612 alpha as *const T as *const c_void,
613 a.descr,
614 x.descr,
615 beta as *const T as *const c_void,
616 y.descr,
617 T::data_type(),
618 alg,
619 &mut size,
620 )
621 })?;
622 Ok(size)
623}
624
625#[allow(clippy::too_many_arguments)]
627pub fn spmv<T: SparseScalar>(
628 handle: &Handle,
629 op: Op,
630 alpha: &T,
631 a: &SpMat<'_, T>,
632 x: &DnVec<'_, T>,
633 beta: &T,
634 y: &mut DnVec<'_, T>,
635 alg: SpMVAlg,
636 workspace: &mut DeviceBuffer<u8>,
637) -> Result<()> {
638 let c = cusparse()?;
639 let cu = c.cusparse_spmv()?;
640 check(unsafe {
641 cu(
642 handle.as_raw(),
643 op.raw(),
644 alpha as *const T as *const c_void,
645 a.descr,
646 x.descr,
647 beta as *const T as *const c_void,
648 y.descr,
649 T::data_type(),
650 alg,
651 workspace.as_raw().0 as *mut c_void,
652 )
653 })
654}
655
656#[allow(clippy::too_many_arguments)]
660pub fn spmm_buffer_size<T: SparseScalar>(
661 handle: &Handle,
662 op_a: Op,
663 op_b: Op,
664 alpha: &T,
665 a: &SpMat<'_, T>,
666 b: &DnMat<'_, T>,
667 beta: &T,
668 c: &DnMat<'_, T>,
669 alg: SpMMAlg,
670) -> Result<usize> {
671 let c_api = cusparse()?;
672 let cu = c_api.cusparse_spmm_buffer_size()?;
673 let mut size = 0usize;
674 check(unsafe {
675 cu(
676 handle.as_raw(),
677 op_a.raw(),
678 op_b.raw(),
679 alpha as *const T as *const c_void,
680 a.descr,
681 b.descr,
682 beta as *const T as *const c_void,
683 c.descr,
684 T::data_type(),
685 alg,
686 &mut size,
687 )
688 })?;
689 Ok(size)
690}
691
692#[allow(clippy::too_many_arguments)]
693pub fn spmm<T: SparseScalar>(
694 handle: &Handle,
695 op_a: Op,
696 op_b: Op,
697 alpha: &T,
698 a: &SpMat<'_, T>,
699 b: &DnMat<'_, T>,
700 beta: &T,
701 c: &mut DnMat<'_, T>,
702 alg: SpMMAlg,
703 workspace: &mut DeviceBuffer<u8>,
704) -> Result<()> {
705 let c_api = cusparse()?;
706 let cu = c_api.cusparse_spmm()?;
707 check(unsafe {
708 cu(
709 handle.as_raw(),
710 op_a.raw(),
711 op_b.raw(),
712 alpha as *const T as *const c_void,
713 a.descr,
714 b.descr,
715 beta as *const T as *const c_void,
716 c.descr,
717 T::data_type(),
718 alg,
719 workspace.as_raw().0 as *mut c_void,
720 )
721 })
722}
723
724#[allow(clippy::too_many_arguments)]
729pub fn spmm_preprocess<T: SparseScalar>(
730 handle: &Handle,
731 op_a: Op,
732 op_b: Op,
733 alpha: &T,
734 a: &SpMat<'_, T>,
735 b: &DnMat<'_, T>,
736 beta: &T,
737 c: &mut DnMat<'_, T>,
738 alg: SpMMAlg,
739 workspace: &mut DeviceBuffer<u8>,
740) -> Result<()> {
741 let c_api = cusparse()?;
742 let cu = c_api.cusparse_spmm_preprocess()?;
743 check(unsafe {
744 cu(
745 handle.as_raw(),
746 op_a.raw(),
747 op_b.raw(),
748 alpha as *const T as *const c_void,
749 a.descr,
750 b.descr,
751 beta as *const T as *const c_void,
752 c.descr,
753 T::data_type(),
754 alg,
755 workspace.as_raw().0 as *mut c_void,
756 )
757 })
758}
759
760#[derive(Debug)]
764pub struct SpGEMMPlan {
765 raw: cusparseSpGEMMDescr_t,
766}
767
768impl SpGEMMPlan {
769 pub fn new() -> Result<Self> {
770 let c = cusparse()?;
771 let cu = c.cusparse_spgemm_create_descr()?;
772 let mut d: cusparseSpGEMMDescr_t = core::ptr::null_mut();
773 check(unsafe { cu(&mut d) })?;
774 Ok(Self { raw: d })
775 }
776}
777
778impl Drop for SpGEMMPlan {
779 fn drop(&mut self) {
780 if let Ok(c) = cusparse() {
781 if let Ok(cu) = c.cusparse_spgemm_destroy_descr() {
782 let _ = unsafe { cu(self.raw) };
783 }
784 }
785 }
786}
787
788#[allow(clippy::too_many_arguments)]
791pub unsafe fn spgemm_work_estimation<T: SparseScalar>(
792 handle: &Handle,
793 op_a: Op,
794 op_b: Op,
795 alpha: &T,
796 a: &SpMat<'_, T>,
797 b: &SpMat<'_, T>,
798 beta: &T,
799 c: &mut SpMat<'_, T>,
800 alg: SpGEMMAlg,
801 plan: &SpGEMMPlan,
802 size1: &mut usize,
803 buffer1: *mut c_void,
804) -> Result<()> {
805 let c_api = cusparse()?;
806 let cu = c_api.cusparse_spgemm_work_estimation()?;
807 check(cu(
808 handle.as_raw(),
809 op_a.raw(),
810 op_b.raw(),
811 alpha as *const T as *const c_void,
812 a.descr,
813 b.descr,
814 beta as *const T as *const c_void,
815 c.descr,
816 T::data_type(),
817 alg,
818 plan.raw,
819 size1,
820 buffer1,
821 ))
822}
823
824#[allow(clippy::too_many_arguments)]
826pub unsafe fn spgemm_compute<T: SparseScalar>(
827 handle: &Handle,
828 op_a: Op,
829 op_b: Op,
830 alpha: &T,
831 a: &SpMat<'_, T>,
832 b: &SpMat<'_, T>,
833 beta: &T,
834 c: &mut SpMat<'_, T>,
835 alg: SpGEMMAlg,
836 plan: &SpGEMMPlan,
837 size2: &mut usize,
838 buffer2: *mut c_void,
839) -> Result<()> {
840 let c_api = cusparse()?;
841 let cu = c_api.cusparse_spgemm_compute()?;
842 check(cu(
843 handle.as_raw(),
844 op_a.raw(),
845 op_b.raw(),
846 alpha as *const T as *const c_void,
847 a.descr,
848 b.descr,
849 beta as *const T as *const c_void,
850 c.descr,
851 T::data_type(),
852 alg,
853 plan.raw,
854 size2,
855 buffer2,
856 ))
857}
858
859#[allow(clippy::too_many_arguments)]
861pub fn spgemm_copy<T: SparseScalar>(
862 handle: &Handle,
863 op_a: Op,
864 op_b: Op,
865 alpha: &T,
866 a: &SpMat<'_, T>,
867 b: &SpMat<'_, T>,
868 beta: &T,
869 c: &mut SpMat<'_, T>,
870 alg: SpGEMMAlg,
871 plan: &SpGEMMPlan,
872) -> Result<()> {
873 let c_api = cusparse()?;
874 let cu = c_api.cusparse_spgemm_copy()?;
875 check(unsafe {
876 cu(
877 handle.as_raw(),
878 op_a.raw(),
879 op_b.raw(),
880 alpha as *const T as *const c_void,
881 a.descr,
882 b.descr,
883 beta as *const T as *const c_void,
884 c.descr,
885 T::data_type(),
886 alg,
887 plan.raw,
888 )
889 })
890}
891
892#[derive(Debug)]
895pub struct SpSVPlan {
896 raw: cusparseSpSVDescr_t,
897}
898
899impl SpSVPlan {
900 pub fn new() -> Result<Self> {
901 let c = cusparse()?;
902 let cu = c.cusparse_spsv_create_descr()?;
903 let mut d: cusparseSpSVDescr_t = core::ptr::null_mut();
904 check(unsafe { cu(&mut d) })?;
905 Ok(Self { raw: d })
906 }
907}
908
909impl Drop for SpSVPlan {
910 fn drop(&mut self) {
911 if let Ok(c) = cusparse() {
912 if let Ok(cu) = c.cusparse_spsv_destroy_descr() {
913 let _ = unsafe { cu(self.raw) };
914 }
915 }
916 }
917}
918
919#[allow(clippy::too_many_arguments)]
920pub fn spsv_buffer_size<T: SparseScalar>(
921 handle: &Handle,
922 op: Op,
923 alpha: &T,
924 a: &SpMat<'_, T>,
925 x: &DnVec<'_, T>,
926 y: &DnVec<'_, T>,
927 alg: SpSVAlg,
928 plan: &SpSVPlan,
929) -> Result<usize> {
930 let c = cusparse()?;
931 let cu = c.cusparse_spsv_buffer_size()?;
932 let mut size = 0usize;
933 check(unsafe {
934 cu(
935 handle.as_raw(),
936 op.raw(),
937 alpha as *const T as *const c_void,
938 a.descr,
939 x.descr,
940 y.descr,
941 T::data_type(),
942 alg,
943 plan.raw,
944 &mut size,
945 )
946 })?;
947 Ok(size)
948}
949
950#[allow(clippy::too_many_arguments)]
951pub fn spsv_analysis<T: SparseScalar>(
952 handle: &Handle,
953 op: Op,
954 alpha: &T,
955 a: &SpMat<'_, T>,
956 x: &DnVec<'_, T>,
957 y: &DnVec<'_, T>,
958 alg: SpSVAlg,
959 plan: &SpSVPlan,
960 workspace: &mut DeviceBuffer<u8>,
961) -> Result<()> {
962 let c = cusparse()?;
963 let cu = c.cusparse_spsv_analysis()?;
964 check(unsafe {
965 cu(
966 handle.as_raw(),
967 op.raw(),
968 alpha as *const T as *const c_void,
969 a.descr,
970 x.descr,
971 y.descr,
972 T::data_type(),
973 alg,
974 plan.raw,
975 workspace.as_raw().0 as *mut c_void,
976 )
977 })
978}
979
980#[allow(clippy::too_many_arguments)]
981pub fn spsv_solve<T: SparseScalar>(
982 handle: &Handle,
983 op: Op,
984 alpha: &T,
985 a: &SpMat<'_, T>,
986 x: &DnVec<'_, T>,
987 y: &mut DnVec<'_, T>,
988 alg: SpSVAlg,
989 plan: &SpSVPlan,
990) -> Result<()> {
991 let c = cusparse()?;
992 let cu = c.cusparse_spsv_solve()?;
993 check(unsafe {
994 cu(
995 handle.as_raw(),
996 op.raw(),
997 alpha as *const T as *const c_void,
998 a.descr,
999 x.descr,
1000 y.descr,
1001 T::data_type(),
1002 alg,
1003 plan.raw,
1004 )
1005 })
1006}
1007
1008#[derive(Debug)]
1009pub struct SpSMPlan {
1010 raw: cusparseSpSMDescr_t,
1011}
1012
1013impl SpSMPlan {
1014 pub fn new() -> Result<Self> {
1015 let c = cusparse()?;
1016 let cu = c.cusparse_spsm_create_descr()?;
1017 let mut d: cusparseSpSMDescr_t = core::ptr::null_mut();
1018 check(unsafe { cu(&mut d) })?;
1019 Ok(Self { raw: d })
1020 }
1021}
1022
1023impl Drop for SpSMPlan {
1024 fn drop(&mut self) {
1025 if let Ok(c) = cusparse() {
1026 if let Ok(cu) = c.cusparse_spsm_destroy_descr() {
1027 let _ = unsafe { cu(self.raw) };
1028 }
1029 }
1030 }
1031}
1032
1033#[allow(clippy::too_many_arguments)]
1034pub fn spsm_buffer_size<T: SparseScalar>(
1035 handle: &Handle,
1036 op_a: Op,
1037 op_b: Op,
1038 alpha: &T,
1039 a: &SpMat<'_, T>,
1040 b: &DnMat<'_, T>,
1041 c: &DnMat<'_, T>,
1042 alg: SpSMAlg,
1043 plan: &SpSMPlan,
1044) -> Result<usize> {
1045 let c_api = cusparse()?;
1046 let cu = c_api.cusparse_spsm_buffer_size()?;
1047 let mut size = 0usize;
1048 check(unsafe {
1049 cu(
1050 handle.as_raw(),
1051 op_a.raw(),
1052 op_b.raw(),
1053 alpha as *const T as *const c_void,
1054 a.descr,
1055 b.descr,
1056 c.descr,
1057 T::data_type(),
1058 alg,
1059 plan.raw,
1060 &mut size,
1061 )
1062 })?;
1063 Ok(size)
1064}
1065
1066#[allow(clippy::too_many_arguments)]
1067pub fn spsm_analysis<T: SparseScalar>(
1068 handle: &Handle,
1069 op_a: Op,
1070 op_b: Op,
1071 alpha: &T,
1072 a: &SpMat<'_, T>,
1073 b: &DnMat<'_, T>,
1074 c: &DnMat<'_, T>,
1075 alg: SpSMAlg,
1076 plan: &SpSMPlan,
1077 workspace: &mut DeviceBuffer<u8>,
1078) -> Result<()> {
1079 let c_api = cusparse()?;
1080 let cu = c_api.cusparse_spsm_analysis()?;
1081 check(unsafe {
1082 cu(
1083 handle.as_raw(),
1084 op_a.raw(),
1085 op_b.raw(),
1086 alpha as *const T as *const c_void,
1087 a.descr,
1088 b.descr,
1089 c.descr,
1090 T::data_type(),
1091 alg,
1092 plan.raw,
1093 workspace.as_raw().0 as *mut c_void,
1094 )
1095 })
1096}
1097
1098#[allow(clippy::too_many_arguments)]
1099pub fn spsm_solve<T: SparseScalar>(
1100 handle: &Handle,
1101 op_a: Op,
1102 op_b: Op,
1103 alpha: &T,
1104 a: &SpMat<'_, T>,
1105 b: &DnMat<'_, T>,
1106 c: &mut DnMat<'_, T>,
1107 alg: SpSMAlg,
1108 plan: &SpSMPlan,
1109) -> Result<()> {
1110 let c_api = cusparse()?;
1111 let cu = c_api.cusparse_spsm_solve()?;
1112 check(unsafe {
1113 cu(
1114 handle.as_raw(),
1115 op_a.raw(),
1116 op_b.raw(),
1117 alpha as *const T as *const c_void,
1118 a.descr,
1119 b.descr,
1120 c.descr,
1121 T::data_type(),
1122 alg,
1123 plan.raw,
1124 )
1125 })
1126}
1127
1128#[allow(clippy::too_many_arguments)]
1131pub fn sddmm_buffer_size<T: SparseScalar>(
1132 handle: &Handle,
1133 op_a: Op,
1134 op_b: Op,
1135 alpha: &T,
1136 a: &DnMat<'_, T>,
1137 b: &DnMat<'_, T>,
1138 beta: &T,
1139 c: &SpMat<'_, T>,
1140 alg: SDDMMAlg,
1141) -> Result<usize> {
1142 let c_api = cusparse()?;
1143 let cu = c_api.cusparse_sddmm_buffer_size()?;
1144 let mut size = 0usize;
1145 check(unsafe {
1146 cu(
1147 handle.as_raw(),
1148 op_a.raw(),
1149 op_b.raw(),
1150 alpha as *const T as *const c_void,
1151 a.descr,
1152 b.descr,
1153 beta as *const T as *const c_void,
1154 c.descr,
1155 T::data_type(),
1156 alg,
1157 &mut size,
1158 )
1159 })?;
1160 Ok(size)
1161}
1162
1163#[allow(clippy::too_many_arguments)]
1164pub fn sddmm<T: SparseScalar>(
1165 handle: &Handle,
1166 op_a: Op,
1167 op_b: Op,
1168 alpha: &T,
1169 a: &DnMat<'_, T>,
1170 b: &DnMat<'_, T>,
1171 beta: &T,
1172 c: &mut SpMat<'_, T>,
1173 alg: SDDMMAlg,
1174 workspace: &mut DeviceBuffer<u8>,
1175) -> Result<()> {
1176 let c_api = cusparse()?;
1177 let cu = c_api.cusparse_sddmm()?;
1178 check(unsafe {
1179 cu(
1180 handle.as_raw(),
1181 op_a.raw(),
1182 op_b.raw(),
1183 alpha as *const T as *const c_void,
1184 a.descr,
1185 b.descr,
1186 beta as *const T as *const c_void,
1187 c.descr,
1188 T::data_type(),
1189 alg,
1190 workspace.as_raw().0 as *mut c_void,
1191 )
1192 })
1193}
1194
1195#[allow(clippy::too_many_arguments)]
1198pub fn sddmm_preprocess<T: SparseScalar>(
1199 handle: &Handle,
1200 op_a: Op,
1201 op_b: Op,
1202 alpha: &T,
1203 a: &DnMat<'_, T>,
1204 b: &DnMat<'_, T>,
1205 beta: &T,
1206 c: &mut SpMat<'_, T>,
1207 alg: SDDMMAlg,
1208 workspace: &mut DeviceBuffer<u8>,
1209) -> Result<()> {
1210 let c_api = cusparse()?;
1211 let cu = c_api.cusparse_sddmm_preprocess()?;
1212 check(unsafe {
1213 cu(
1214 handle.as_raw(),
1215 op_a.raw(),
1216 op_b.raw(),
1217 alpha as *const T as *const c_void,
1218 a.descr,
1219 b.descr,
1220 beta as *const T as *const c_void,
1221 c.descr,
1222 T::data_type(),
1223 alg,
1224 workspace.as_raw().0 as *mut c_void,
1225 )
1226 })
1227}
1228
1229pub fn sparse_to_dense_buffer_size<T: SparseScalar>(
1232 handle: &Handle,
1233 sp: &SpMat<'_, T>,
1234 dn: &DnMat<'_, T>,
1235) -> Result<usize> {
1236 let c = cusparse()?;
1237 let cu = c.cusparse_sparse_to_dense_buffer_size()?;
1238 let mut size = 0usize;
1239 check(unsafe { cu(handle.as_raw(), sp.descr, dn.descr, 0, &mut size) })?;
1240 Ok(size)
1241}
1242
1243pub fn sparse_to_dense<T: SparseScalar>(
1244 handle: &Handle,
1245 sp: &SpMat<'_, T>,
1246 dn: &mut DnMat<'_, T>,
1247 workspace: &mut DeviceBuffer<u8>,
1248) -> Result<()> {
1249 let c = cusparse()?;
1250 let cu = c.cusparse_sparse_to_dense()?;
1251 check(unsafe {
1252 cu(
1253 handle.as_raw(),
1254 sp.descr,
1255 dn.descr,
1256 0,
1257 workspace.as_raw().0 as *mut c_void,
1258 )
1259 })
1260}
1261
1262pub fn dense_to_sparse_buffer_size<T: SparseScalar>(
1263 handle: &Handle,
1264 dn: &DnMat<'_, T>,
1265 sp: &SpMat<'_, T>,
1266) -> Result<usize> {
1267 let c = cusparse()?;
1268 let cu = c.cusparse_dense_to_sparse_buffer_size()?;
1269 let mut size = 0usize;
1270 check(unsafe { cu(handle.as_raw(), dn.descr, sp.descr, 0, &mut size) })?;
1271 Ok(size)
1272}
1273
1274pub fn dense_to_sparse_analysis<T: SparseScalar>(
1275 handle: &Handle,
1276 dn: &DnMat<'_, T>,
1277 sp: &SpMat<'_, T>,
1278 workspace: &mut DeviceBuffer<u8>,
1279) -> Result<()> {
1280 let c = cusparse()?;
1281 let cu = c.cusparse_dense_to_sparse_analysis()?;
1282 check(unsafe {
1283 cu(
1284 handle.as_raw(),
1285 dn.descr,
1286 sp.descr,
1287 0,
1288 workspace.as_raw().0 as *mut c_void,
1289 )
1290 })
1291}
1292
1293pub fn dense_to_sparse_convert<T: SparseScalar>(
1294 handle: &Handle,
1295 dn: &DnMat<'_, T>,
1296 sp: &mut SpMat<'_, T>,
1297 workspace: &mut DeviceBuffer<u8>,
1298) -> Result<()> {
1299 let c = cusparse()?;
1300 let cu = c.cusparse_dense_to_sparse_convert()?;
1301 check(unsafe {
1302 cu(
1303 handle.as_raw(),
1304 dn.descr,
1305 sp.descr,
1306 0,
1307 workspace.as_raw().0 as *mut c_void,
1308 )
1309 })
1310}
1311
1312#[allow(clippy::too_many_arguments)]
1314pub fn csr2csc_ex2_buffer_size<T: SparseScalar + baracuda_types::DeviceRepr>(
1315 handle: &Handle,
1316 m: i32,
1317 n: i32,
1318 nnz: i32,
1319 csr_val: &DeviceBuffer<T>,
1320 csr_row_ptr: &DeviceBuffer<i32>,
1321 csr_col_ind: &DeviceBuffer<i32>,
1322 csc_val: &mut DeviceBuffer<T>,
1323 csc_col_ptr: &mut DeviceBuffer<i32>,
1324 csc_row_ind: &mut DeviceBuffer<i32>,
1325 copy_values: bool,
1326 idx_base: IndexBase,
1327 alg: Csr2CscAlg,
1328) -> Result<usize> {
1329 let c = cusparse()?;
1330 let cu = c.cusparse_csr2csc_ex2_buffer_size()?;
1331 let mut size = 0usize;
1332 check(unsafe {
1333 cu(
1334 handle.as_raw(),
1335 m,
1336 n,
1337 nnz,
1338 csr_val.as_raw().0 as *const c_void,
1339 csr_row_ptr.as_raw().0 as *const i32,
1340 csr_col_ind.as_raw().0 as *const i32,
1341 csc_val.as_raw().0 as *mut c_void,
1342 csc_col_ptr.as_raw().0 as *mut i32,
1343 csc_row_ind.as_raw().0 as *mut i32,
1344 T::data_type(),
1345 copy_values as i32,
1346 idx_base,
1347 alg,
1348 &mut size,
1349 )
1350 })?;
1351 Ok(size)
1352}
1353
1354#[allow(clippy::too_many_arguments)]
1357pub fn csr2csc_ex2<T: SparseScalar + baracuda_types::DeviceRepr>(
1358 handle: &Handle,
1359 m: i32,
1360 n: i32,
1361 nnz: i32,
1362 csr_val: &DeviceBuffer<T>,
1363 csr_row_ptr: &DeviceBuffer<i32>,
1364 csr_col_ind: &DeviceBuffer<i32>,
1365 csc_val: &mut DeviceBuffer<T>,
1366 csc_col_ptr: &mut DeviceBuffer<i32>,
1367 csc_row_ind: &mut DeviceBuffer<i32>,
1368 copy_values: bool,
1369 idx_base: IndexBase,
1370 alg: Csr2CscAlg,
1371 workspace: &mut DeviceBuffer<u8>,
1372) -> Result<()> {
1373 let c = cusparse()?;
1374 let cu = c.cusparse_csr2csc_ex2()?;
1375 check(unsafe {
1376 cu(
1377 handle.as_raw(),
1378 m,
1379 n,
1380 nnz,
1381 csr_val.as_raw().0 as *const c_void,
1382 csr_row_ptr.as_raw().0 as *const i32,
1383 csr_col_ind.as_raw().0 as *const i32,
1384 csc_val.as_raw().0 as *mut c_void,
1385 csc_col_ptr.as_raw().0 as *mut i32,
1386 csc_row_ind.as_raw().0 as *mut i32,
1387 T::data_type(),
1388 copy_values as i32,
1389 idx_base,
1390 alg,
1391 workspace.as_raw().0 as *mut c_void,
1392 )
1393 })
1394}
1395
1396pub fn axpby<T: SparseScalar>(
1399 handle: &Handle,
1400 alpha: &T,
1401 x: &DnVec<'_, T>,
1402 beta: &T,
1403 y: &mut DnVec<'_, T>,
1404) -> Result<()> {
1405 let c = cusparse()?;
1406 let cu = c.cusparse_axpby()?;
1407 check(unsafe {
1408 cu(
1409 handle.as_raw(),
1410 alpha as *const T as *const c_void,
1411 x.descr,
1412 beta as *const T as *const c_void,
1413 y.descr,
1414 )
1415 })
1416}
1417
1418pub fn gather<T: SparseScalar>(
1419 handle: &Handle,
1420 y: &DnVec<'_, T>,
1421 x: &mut DnVec<'_, T>,
1422) -> Result<()> {
1423 let c = cusparse()?;
1424 let cu = c.cusparse_gather()?;
1425 check(unsafe { cu(handle.as_raw(), y.descr, x.descr) })
1426}
1427
1428pub fn scatter<T: SparseScalar>(
1429 handle: &Handle,
1430 x: &DnVec<'_, T>,
1431 y: &mut DnVec<'_, T>,
1432) -> Result<()> {
1433 let c = cusparse()?;
1434 let cu = c.cusparse_scatter()?;
1435 check(unsafe { cu(handle.as_raw(), x.descr, y.descr) })
1436}
1437
1438pub fn rot<T: SparseScalar>(
1439 handle: &Handle,
1440 c_cos: &T,
1441 s_sin: &T,
1442 x: &mut DnVec<'_, T>,
1443 y: &mut DnVec<'_, T>,
1444) -> Result<()> {
1445 let c_api = cusparse()?;
1446 let cu = c_api.cusparse_rot()?;
1447 check(unsafe {
1448 cu(
1449 handle.as_raw(),
1450 c_cos as *const T as *const c_void,
1451 s_sin as *const T as *const c_void,
1452 x.descr,
1453 y.descr,
1454 )
1455 })
1456}
1457
1458pub type CsrMatrix<'buf> = SpMat<'buf, f32>;
1462pub type DenseVector<'buf, T> = DnVec<'buf, T>;