1#![warn(missing_debug_implementations)]
13
14use core::ffi::c_void;
15use std::marker::PhantomData;
16
17use baracuda_cusparse_sys::{
18 cudaDataType, cusparse, cusparseDiagType_t, cusparseDnMatDescr_t, cusparseDnVecDescr_t,
19 cusparseFillMode_t, cusparseHandle_t, cusparseIndexBase_t, cusparseIndexType_t,
20 cusparseOperation_t, cusparseOrder_t, cusparseSpGEMMDescr_t, cusparseSpMatAttribute_t,
21 cusparseSpMatDescr_t, cusparseSpSMDescr_t, cusparseSpSVDescr_t, cusparseStatus_t,
22};
23use baracuda_driver::{DeviceBuffer, Stream};
24use baracuda_types::{Complex32, Complex64};
25
26pub use baracuda_cusparse_sys::{
27 cusparseCsr2CscAlg_t as Csr2CscAlg, cusparseIndexBase_t as IndexBase,
28 cusparseSDDMMAlg_t as SDDMMAlg, cusparseSpGEMMAlg_t as SpGEMMAlg,
29 cusparseSpMMAlg_t as SpMMAlg, cusparseSpMVAlg_t as SpMVAlg, cusparseSpSMAlg_t as SpSMAlg,
30 cusparseSpSVAlg_t as SpSVAlg,
31};
32
33pub type Error = baracuda_core::Error<cusparseStatus_t>;
35pub type Result<T, E = Error> = core::result::Result<T, E>;
37
38#[inline]
39fn check(status: cusparseStatus_t) -> Result<()> {
40 Error::check(status)
41}
42
43pub trait SparseScalar: sealed::Sealed + Copy + 'static {
47 fn data_type() -> cudaDataType;
49}
50
51impl SparseScalar for f32 {
52 fn data_type() -> cudaDataType {
53 cudaDataType::R_32F
54 }
55}
56impl SparseScalar for f64 {
57 fn data_type() -> cudaDataType {
58 cudaDataType::R_64F
59 }
60}
61impl SparseScalar for Complex32 {
62 fn data_type() -> cudaDataType {
63 cudaDataType::C_32F
64 }
65}
66impl SparseScalar for Complex64 {
67 fn data_type() -> cudaDataType {
68 cudaDataType::C_64F
69 }
70}
71
72mod sealed {
73 use baracuda_types::{Complex32, Complex64};
74 pub trait Sealed {}
75 impl Sealed for f32 {}
76 impl Sealed for f64 {}
77 impl Sealed for Complex32 {}
78 impl Sealed for Complex64 {}
79}
80
81#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
85pub enum Op {
86 #[default]
87 N,
88 T,
89 C,
90}
91
92impl Op {
93 fn raw(self) -> cusparseOperation_t {
94 match self {
95 Op::N => cusparseOperation_t::N,
96 Op::T => cusparseOperation_t::T,
97 Op::C => cusparseOperation_t::C,
98 }
99 }
100}
101
102#[derive(Copy, Clone, Debug, Eq, PartialEq)]
104pub enum Order {
105 Row,
106 Col,
107}
108
109impl Order {
110 fn raw(self) -> cusparseOrder_t {
111 match self {
112 Order::Row => cusparseOrder_t::Row,
113 Order::Col => cusparseOrder_t::Col,
114 }
115 }
116}
117
118#[derive(Copy, Clone, Debug, Eq, PartialEq)]
119pub enum Fill {
120 Lower,
121 Upper,
122}
123
124impl Fill {
125 fn raw(self) -> cusparseFillMode_t {
126 match self {
127 Fill::Lower => cusparseFillMode_t::Lower,
128 Fill::Upper => cusparseFillMode_t::Upper,
129 }
130 }
131}
132
133#[derive(Copy, Clone, Debug, Eq, PartialEq)]
134pub enum Diag {
135 NonUnit,
136 Unit,
137}
138
139impl Diag {
140 fn raw(self) -> cusparseDiagType_t {
141 match self {
142 Diag::NonUnit => cusparseDiagType_t::NonUnit,
143 Diag::Unit => cusparseDiagType_t::Unit,
144 }
145 }
146}
147
148pub struct Handle {
152 handle: cusparseHandle_t,
153}
154
155unsafe impl Send for Handle {}
156
157impl core::fmt::Debug for Handle {
158 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
159 f.debug_struct("cusparse::Handle")
160 .field("handle", &self.handle)
161 .finish()
162 }
163}
164
165impl Handle {
166 pub fn new() -> Result<Self> {
167 let c = cusparse()?;
168 let cu = c.cusparse_create()?;
169 let mut h: cusparseHandle_t = core::ptr::null_mut();
170 check(unsafe { cu(&mut h) })?;
171 Ok(Self { handle: h })
172 }
173
174 pub fn set_stream(&self, stream: &Stream) -> Result<()> {
175 let c = cusparse()?;
176 let cu = c.cusparse_set_stream()?;
177 check(unsafe { cu(self.handle, stream.as_raw() as _) })
178 }
179
180 pub fn version(&self) -> Result<i32> {
181 let c = cusparse()?;
182 let cu = c.cusparse_get_version()?;
183 let mut v: core::ffi::c_int = 0;
184 check(unsafe { cu(self.handle, &mut v) })?;
185 Ok(v)
186 }
187
188 #[inline]
189 pub fn as_raw(&self) -> cusparseHandle_t {
190 self.handle
191 }
192}
193
194impl Drop for Handle {
195 fn drop(&mut self) {
196 if let Ok(c) = cusparse() {
197 if let Ok(cu) = c.cusparse_destroy() {
198 let _ = unsafe { cu(self.handle) };
199 }
200 }
201 }
202}
203
204pub struct SpMat<'buf, T> {
210 descr: cusparseSpMatDescr_t,
211 _markers: PhantomData<&'buf mut T>,
212}
213
214unsafe impl<T> Send for SpMat<'_, T> {}
215
216impl<T> core::fmt::Debug for SpMat<'_, T> {
217 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
218 f.debug_struct("SpMat")
219 .field("descr", &self.descr)
220 .finish_non_exhaustive()
221 }
222}
223
224impl<'buf, T: SparseScalar + baracuda_types::DeviceRepr> SpMat<'buf, T> {
225 pub fn csr(
230 rows: i64,
231 cols: i64,
232 nnz: i64,
233 row_offsets: &'buf mut DeviceBuffer<i32>,
234 col_indices: &'buf mut DeviceBuffer<i32>,
235 values: &'buf mut DeviceBuffer<T>,
236 ) -> Result<Self> {
237 let c = cusparse()?;
238 let cu = c.cusparse_create_csr()?;
239 let mut descr: cusparseSpMatDescr_t = core::ptr::null_mut();
240 check(unsafe {
241 cu(
242 &mut descr,
243 rows,
244 cols,
245 nnz,
246 row_offsets.as_raw().0 as *mut c_void,
247 col_indices.as_raw().0 as *mut c_void,
248 values.as_raw().0 as *mut c_void,
249 cusparseIndexType_t::I32I,
250 cusparseIndexType_t::I32I,
251 cusparseIndexBase_t::Zero,
252 T::data_type(),
253 )
254 })?;
255 Ok(Self {
256 descr,
257 _markers: PhantomData,
258 })
259 }
260
261 pub fn csc(
263 rows: i64,
264 cols: i64,
265 nnz: i64,
266 col_offsets: &'buf mut DeviceBuffer<i32>,
267 row_indices: &'buf mut DeviceBuffer<i32>,
268 values: &'buf mut DeviceBuffer<T>,
269 ) -> Result<Self> {
270 let c = cusparse()?;
271 let cu = c.cusparse_create_csc()?;
272 let mut descr: cusparseSpMatDescr_t = core::ptr::null_mut();
273 check(unsafe {
274 cu(
275 &mut descr,
276 rows,
277 cols,
278 nnz,
279 col_offsets.as_raw().0 as *mut c_void,
280 row_indices.as_raw().0 as *mut c_void,
281 values.as_raw().0 as *mut c_void,
282 cusparseIndexType_t::I32I,
283 cusparseIndexType_t::I32I,
284 cusparseIndexBase_t::Zero,
285 T::data_type(),
286 )
287 })?;
288 Ok(Self {
289 descr,
290 _markers: PhantomData,
291 })
292 }
293
294 #[allow(clippy::too_many_arguments)]
296 pub fn bsr(
297 brows: i64,
298 bcols: i64,
299 bnnz: i64,
300 row_block_dim: i64,
301 col_block_dim: i64,
302 order: Order,
303 row_offsets: &'buf mut DeviceBuffer<i32>,
304 col_indices: &'buf mut DeviceBuffer<i32>,
305 values: &'buf mut DeviceBuffer<T>,
306 ) -> Result<Self> {
307 let c = cusparse()?;
308 let cu = c.cusparse_create_bsr()?;
309 let mut descr: cusparseSpMatDescr_t = core::ptr::null_mut();
310 check(unsafe {
311 cu(
312 &mut descr,
313 brows,
314 bcols,
315 bnnz,
316 row_block_dim,
317 col_block_dim,
318 row_offsets.as_raw().0 as *mut c_void,
319 col_indices.as_raw().0 as *mut c_void,
320 values.as_raw().0 as *mut c_void,
321 cusparseIndexType_t::I32I,
322 cusparseIndexType_t::I32I,
323 cusparseIndexBase_t::Zero,
324 T::data_type(),
325 order.raw(),
326 )
327 })?;
328 Ok(Self {
329 descr,
330 _markers: PhantomData,
331 })
332 }
333
334 pub fn coo(
336 rows: i64,
337 cols: i64,
338 nnz: i64,
339 row_indices: &'buf mut DeviceBuffer<i32>,
340 col_indices: &'buf mut DeviceBuffer<i32>,
341 values: &'buf mut DeviceBuffer<T>,
342 ) -> Result<Self> {
343 let c = cusparse()?;
344 let cu = c.cusparse_create_coo()?;
345 let mut descr: cusparseSpMatDescr_t = core::ptr::null_mut();
346 check(unsafe {
347 cu(
348 &mut descr,
349 rows,
350 cols,
351 nnz,
352 row_indices.as_raw().0 as *mut c_void,
353 col_indices.as_raw().0 as *mut c_void,
354 values.as_raw().0 as *mut c_void,
355 cusparseIndexType_t::I32I,
356 cusparseIndexBase_t::Zero,
357 T::data_type(),
358 )
359 })?;
360 Ok(Self {
361 descr,
362 _markers: PhantomData,
363 })
364 }
365}
366
367impl<T> SpMat<'_, T> {
368 pub fn shape(&self) -> Result<(i64, i64, i64)> {
370 let c = cusparse()?;
371 let cu = c.cusparse_sp_mat_get_size()?;
372 let (mut r, mut col, mut nz) = (0i64, 0i64, 0i64);
373 check(unsafe { cu(self.descr, &mut r, &mut col, &mut nz) })?;
374 Ok((r, col, nz))
375 }
376
377 pub unsafe fn set_csr_pointers(
388 &self,
389 row_offsets: *mut c_void,
390 col_indices: *mut c_void,
391 values: *mut c_void,
392 ) -> Result<()> { unsafe {
393 let c = cusparse()?;
394 let cu = c.cusparse_csr_set_pointers()?;
395 check(cu(self.descr, row_offsets, col_indices, values))
396 }}
397
398 pub unsafe fn set_csc_pointers(
404 &self,
405 col_offsets: *mut c_void,
406 row_indices: *mut c_void,
407 values: *mut c_void,
408 ) -> Result<()> { unsafe {
409 let c = cusparse()?;
410 let cu = c.cusparse_csc_set_pointers()?;
411 check(cu(self.descr, col_offsets, row_indices, values))
412 }}
413
414 pub unsafe fn set_coo_pointers(
420 &self,
421 row_indices: *mut c_void,
422 col_indices: *mut c_void,
423 values: *mut c_void,
424 ) -> Result<()> { unsafe {
425 let c = cusparse()?;
426 let cu = c.cusparse_coo_set_pointers()?;
427 check(cu(self.descr, row_indices, col_indices, values))
428 }}
429
430 pub fn set_fill(&self, fill: Fill) -> Result<()> {
432 let c = cusparse()?;
433 let cu = c.cusparse_sp_mat_set_attribute()?;
434 let raw = fill.raw();
435 check(unsafe {
436 cu(
437 self.descr,
438 cusparseSpMatAttribute_t::FillMode,
439 &raw as *const _ as *const c_void,
440 core::mem::size_of::<cusparseFillMode_t>(),
441 )
442 })
443 }
444
445 pub fn set_diag(&self, diag: Diag) -> Result<()> {
447 let c = cusparse()?;
448 let cu = c.cusparse_sp_mat_set_attribute()?;
449 let raw = diag.raw();
450 check(unsafe {
451 cu(
452 self.descr,
453 cusparseSpMatAttribute_t::DiagType,
454 &raw as *const _ as *const c_void,
455 core::mem::size_of::<cusparseDiagType_t>(),
456 )
457 })
458 }
459
460 #[inline]
461 pub fn as_raw(&self) -> cusparseSpMatDescr_t {
462 self.descr
463 }
464}
465
466impl<T> Drop for SpMat<'_, T> {
467 fn drop(&mut self) {
468 if let Ok(c) = cusparse() {
469 if let Ok(cu) = c.cusparse_destroy_sp_mat() {
470 let _ = unsafe { cu(self.descr) };
471 }
472 }
473 }
474}
475
476pub struct DnVec<'buf, T> {
479 descr: cusparseDnVecDescr_t,
480 _marker: PhantomData<&'buf mut T>,
481}
482
483unsafe impl<T> Send for DnVec<'_, T> {}
484
485impl<T> core::fmt::Debug for DnVec<'_, T> {
486 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
487 f.debug_struct("DnVec")
488 .field("descr", &self.descr)
489 .finish_non_exhaustive()
490 }
491}
492
493impl<'buf, T: SparseScalar + baracuda_types::DeviceRepr> DnVec<'buf, T> {
494 pub fn new(values: &'buf mut DeviceBuffer<T>) -> Result<Self> {
495 let c = cusparse()?;
496 let cu = c.cusparse_create_dn_vec()?;
497 let mut descr: cusparseDnVecDescr_t = core::ptr::null_mut();
498 check(unsafe {
499 cu(
500 &mut descr,
501 values.len() as i64,
502 values.as_raw().0 as *mut c_void,
503 T::data_type(),
504 )
505 })?;
506 Ok(Self {
507 descr,
508 _marker: PhantomData,
509 })
510 }
511}
512
513impl<T> DnVec<'_, T> {
514 #[inline]
515 pub fn as_raw(&self) -> cusparseDnVecDescr_t {
516 self.descr
517 }
518}
519
520impl<T> Drop for DnVec<'_, T> {
521 fn drop(&mut self) {
522 if let Ok(c) = cusparse() {
523 if let Ok(cu) = c.cusparse_destroy_dn_vec() {
524 let _ = unsafe { cu(self.descr) };
525 }
526 }
527 }
528}
529
530pub struct DnMat<'buf, T> {
531 descr: cusparseDnMatDescr_t,
532 _marker: PhantomData<&'buf mut T>,
533}
534
535unsafe impl<T> Send for DnMat<'_, T> {}
536
537impl<T> core::fmt::Debug for DnMat<'_, T> {
538 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
539 f.debug_struct("DnMat")
540 .field("descr", &self.descr)
541 .finish_non_exhaustive()
542 }
543}
544
545impl<'buf, T: SparseScalar + baracuda_types::DeviceRepr> DnMat<'buf, T> {
546 pub fn new(
547 rows: i64,
548 cols: i64,
549 ld: i64,
550 order: Order,
551 values: &'buf mut DeviceBuffer<T>,
552 ) -> Result<Self> {
553 let c = cusparse()?;
554 let cu = c.cusparse_create_dn_mat()?;
555 let mut descr: cusparseDnMatDescr_t = core::ptr::null_mut();
556 check(unsafe {
557 cu(
558 &mut descr,
559 rows,
560 cols,
561 ld,
562 values.as_raw().0 as *mut c_void,
563 T::data_type(),
564 order.raw(),
565 )
566 })?;
567 Ok(Self {
568 descr,
569 _marker: PhantomData,
570 })
571 }
572}
573
574impl<T> DnMat<'_, T> {
575 #[inline]
576 pub fn as_raw(&self) -> cusparseDnMatDescr_t {
577 self.descr
578 }
579}
580
581impl<T> Drop for DnMat<'_, T> {
582 fn drop(&mut self) {
583 if let Ok(c) = cusparse() {
584 if let Ok(cu) = c.cusparse_destroy_dn_mat() {
585 let _ = unsafe { cu(self.descr) };
586 }
587 }
588 }
589}
590
591#[allow(clippy::too_many_arguments)]
595pub fn spmv_buffer_size<T: SparseScalar>(
596 handle: &Handle,
597 op: Op,
598 alpha: &T,
599 a: &SpMat<'_, T>,
600 x: &DnVec<'_, T>,
601 beta: &T,
602 y: &DnVec<'_, T>,
603 alg: SpMVAlg,
604) -> Result<usize> {
605 let c = cusparse()?;
606 let cu = c.cusparse_spmv_buffer_size()?;
607 let mut size: usize = 0;
608 check(unsafe {
609 cu(
610 handle.as_raw(),
611 op.raw(),
612 alpha as *const T as *const c_void,
613 a.descr,
614 x.descr,
615 beta as *const T as *const c_void,
616 y.descr,
617 T::data_type(),
618 alg,
619 &mut size,
620 )
621 })?;
622 Ok(size)
623}
624
625#[allow(clippy::too_many_arguments)]
627pub fn spmv<T: SparseScalar>(
628 handle: &Handle,
629 op: Op,
630 alpha: &T,
631 a: &SpMat<'_, T>,
632 x: &DnVec<'_, T>,
633 beta: &T,
634 y: &mut DnVec<'_, T>,
635 alg: SpMVAlg,
636 workspace: &mut DeviceBuffer<u8>,
637) -> Result<()> {
638 let c = cusparse()?;
639 let cu = c.cusparse_spmv()?;
640 check(unsafe {
641 cu(
642 handle.as_raw(),
643 op.raw(),
644 alpha as *const T as *const c_void,
645 a.descr,
646 x.descr,
647 beta as *const T as *const c_void,
648 y.descr,
649 T::data_type(),
650 alg,
651 workspace.as_raw().0 as *mut c_void,
652 )
653 })
654}
655
656#[allow(clippy::too_many_arguments)]
660pub fn spmm_buffer_size<T: SparseScalar>(
661 handle: &Handle,
662 op_a: Op,
663 op_b: Op,
664 alpha: &T,
665 a: &SpMat<'_, T>,
666 b: &DnMat<'_, T>,
667 beta: &T,
668 c: &DnMat<'_, T>,
669 alg: SpMMAlg,
670) -> Result<usize> {
671 let c_api = cusparse()?;
672 let cu = c_api.cusparse_spmm_buffer_size()?;
673 let mut size = 0usize;
674 check(unsafe {
675 cu(
676 handle.as_raw(),
677 op_a.raw(),
678 op_b.raw(),
679 alpha as *const T as *const c_void,
680 a.descr,
681 b.descr,
682 beta as *const T as *const c_void,
683 c.descr,
684 T::data_type(),
685 alg,
686 &mut size,
687 )
688 })?;
689 Ok(size)
690}
691
692#[allow(clippy::too_many_arguments)]
693pub fn spmm<T: SparseScalar>(
694 handle: &Handle,
695 op_a: Op,
696 op_b: Op,
697 alpha: &T,
698 a: &SpMat<'_, T>,
699 b: &DnMat<'_, T>,
700 beta: &T,
701 c: &mut DnMat<'_, T>,
702 alg: SpMMAlg,
703 workspace: &mut DeviceBuffer<u8>,
704) -> Result<()> {
705 let c_api = cusparse()?;
706 let cu = c_api.cusparse_spmm()?;
707 check(unsafe {
708 cu(
709 handle.as_raw(),
710 op_a.raw(),
711 op_b.raw(),
712 alpha as *const T as *const c_void,
713 a.descr,
714 b.descr,
715 beta as *const T as *const c_void,
716 c.descr,
717 T::data_type(),
718 alg,
719 workspace.as_raw().0 as *mut c_void,
720 )
721 })
722}
723
724#[allow(clippy::too_many_arguments)]
729pub fn spmm_preprocess<T: SparseScalar>(
730 handle: &Handle,
731 op_a: Op,
732 op_b: Op,
733 alpha: &T,
734 a: &SpMat<'_, T>,
735 b: &DnMat<'_, T>,
736 beta: &T,
737 c: &mut DnMat<'_, T>,
738 alg: SpMMAlg,
739 workspace: &mut DeviceBuffer<u8>,
740) -> Result<()> {
741 let c_api = cusparse()?;
742 let cu = c_api.cusparse_spmm_preprocess()?;
743 check(unsafe {
744 cu(
745 handle.as_raw(),
746 op_a.raw(),
747 op_b.raw(),
748 alpha as *const T as *const c_void,
749 a.descr,
750 b.descr,
751 beta as *const T as *const c_void,
752 c.descr,
753 T::data_type(),
754 alg,
755 workspace.as_raw().0 as *mut c_void,
756 )
757 })
758}
759
760#[derive(Debug)]
764pub struct SpGEMMPlan {
765 raw: cusparseSpGEMMDescr_t,
766}
767
768impl SpGEMMPlan {
769 pub fn new() -> Result<Self> {
770 let c = cusparse()?;
771 let cu = c.cusparse_spgemm_create_descr()?;
772 let mut d: cusparseSpGEMMDescr_t = core::ptr::null_mut();
773 check(unsafe { cu(&mut d) })?;
774 Ok(Self { raw: d })
775 }
776}
777
778impl Drop for SpGEMMPlan {
779 fn drop(&mut self) {
780 if let Ok(c) = cusparse() {
781 if let Ok(cu) = c.cusparse_spgemm_destroy_descr() {
782 let _ = unsafe { cu(self.raw) };
783 }
784 }
785 }
786}
787
788#[allow(clippy::too_many_arguments)]
797pub unsafe fn spgemm_work_estimation<T: SparseScalar>(
798 handle: &Handle,
799 op_a: Op,
800 op_b: Op,
801 alpha: &T,
802 a: &SpMat<'_, T>,
803 b: &SpMat<'_, T>,
804 beta: &T,
805 c: &mut SpMat<'_, T>,
806 alg: SpGEMMAlg,
807 plan: &SpGEMMPlan,
808 size1: &mut usize,
809 buffer1: *mut c_void,
810) -> Result<()> { unsafe {
811 let c_api = cusparse()?;
812 let cu = c_api.cusparse_spgemm_work_estimation()?;
813 check(cu(
814 handle.as_raw(),
815 op_a.raw(),
816 op_b.raw(),
817 alpha as *const T as *const c_void,
818 a.descr,
819 b.descr,
820 beta as *const T as *const c_void,
821 c.descr,
822 T::data_type(),
823 alg,
824 plan.raw,
825 size1,
826 buffer1,
827 ))
828}}
829
830#[allow(clippy::too_many_arguments)]
838pub unsafe fn spgemm_compute<T: SparseScalar>(
839 handle: &Handle,
840 op_a: Op,
841 op_b: Op,
842 alpha: &T,
843 a: &SpMat<'_, T>,
844 b: &SpMat<'_, T>,
845 beta: &T,
846 c: &mut SpMat<'_, T>,
847 alg: SpGEMMAlg,
848 plan: &SpGEMMPlan,
849 size2: &mut usize,
850 buffer2: *mut c_void,
851) -> Result<()> { unsafe {
852 let c_api = cusparse()?;
853 let cu = c_api.cusparse_spgemm_compute()?;
854 check(cu(
855 handle.as_raw(),
856 op_a.raw(),
857 op_b.raw(),
858 alpha as *const T as *const c_void,
859 a.descr,
860 b.descr,
861 beta as *const T as *const c_void,
862 c.descr,
863 T::data_type(),
864 alg,
865 plan.raw,
866 size2,
867 buffer2,
868 ))
869}}
870
871#[allow(clippy::too_many_arguments)]
873pub fn spgemm_copy<T: SparseScalar>(
874 handle: &Handle,
875 op_a: Op,
876 op_b: Op,
877 alpha: &T,
878 a: &SpMat<'_, T>,
879 b: &SpMat<'_, T>,
880 beta: &T,
881 c: &mut SpMat<'_, T>,
882 alg: SpGEMMAlg,
883 plan: &SpGEMMPlan,
884) -> Result<()> {
885 let c_api = cusparse()?;
886 let cu = c_api.cusparse_spgemm_copy()?;
887 check(unsafe {
888 cu(
889 handle.as_raw(),
890 op_a.raw(),
891 op_b.raw(),
892 alpha as *const T as *const c_void,
893 a.descr,
894 b.descr,
895 beta as *const T as *const c_void,
896 c.descr,
897 T::data_type(),
898 alg,
899 plan.raw,
900 )
901 })
902}
903
904#[derive(Debug)]
907pub struct SpSVPlan {
908 raw: cusparseSpSVDescr_t,
909}
910
911impl SpSVPlan {
912 pub fn new() -> Result<Self> {
913 let c = cusparse()?;
914 let cu = c.cusparse_spsv_create_descr()?;
915 let mut d: cusparseSpSVDescr_t = core::ptr::null_mut();
916 check(unsafe { cu(&mut d) })?;
917 Ok(Self { raw: d })
918 }
919}
920
921impl Drop for SpSVPlan {
922 fn drop(&mut self) {
923 if let Ok(c) = cusparse() {
924 if let Ok(cu) = c.cusparse_spsv_destroy_descr() {
925 let _ = unsafe { cu(self.raw) };
926 }
927 }
928 }
929}
930
931#[allow(clippy::too_many_arguments)]
932pub fn spsv_buffer_size<T: SparseScalar>(
933 handle: &Handle,
934 op: Op,
935 alpha: &T,
936 a: &SpMat<'_, T>,
937 x: &DnVec<'_, T>,
938 y: &DnVec<'_, T>,
939 alg: SpSVAlg,
940 plan: &SpSVPlan,
941) -> Result<usize> {
942 let c = cusparse()?;
943 let cu = c.cusparse_spsv_buffer_size()?;
944 let mut size = 0usize;
945 check(unsafe {
946 cu(
947 handle.as_raw(),
948 op.raw(),
949 alpha as *const T as *const c_void,
950 a.descr,
951 x.descr,
952 y.descr,
953 T::data_type(),
954 alg,
955 plan.raw,
956 &mut size,
957 )
958 })?;
959 Ok(size)
960}
961
962#[allow(clippy::too_many_arguments)]
963pub fn spsv_analysis<T: SparseScalar>(
964 handle: &Handle,
965 op: Op,
966 alpha: &T,
967 a: &SpMat<'_, T>,
968 x: &DnVec<'_, T>,
969 y: &DnVec<'_, T>,
970 alg: SpSVAlg,
971 plan: &SpSVPlan,
972 workspace: &mut DeviceBuffer<u8>,
973) -> Result<()> {
974 let c = cusparse()?;
975 let cu = c.cusparse_spsv_analysis()?;
976 check(unsafe {
977 cu(
978 handle.as_raw(),
979 op.raw(),
980 alpha as *const T as *const c_void,
981 a.descr,
982 x.descr,
983 y.descr,
984 T::data_type(),
985 alg,
986 plan.raw,
987 workspace.as_raw().0 as *mut c_void,
988 )
989 })
990}
991
992#[allow(clippy::too_many_arguments)]
993pub fn spsv_solve<T: SparseScalar>(
994 handle: &Handle,
995 op: Op,
996 alpha: &T,
997 a: &SpMat<'_, T>,
998 x: &DnVec<'_, T>,
999 y: &mut DnVec<'_, T>,
1000 alg: SpSVAlg,
1001 plan: &SpSVPlan,
1002) -> Result<()> {
1003 let c = cusparse()?;
1004 let cu = c.cusparse_spsv_solve()?;
1005 check(unsafe {
1006 cu(
1007 handle.as_raw(),
1008 op.raw(),
1009 alpha as *const T as *const c_void,
1010 a.descr,
1011 x.descr,
1012 y.descr,
1013 T::data_type(),
1014 alg,
1015 plan.raw,
1016 )
1017 })
1018}
1019
1020#[derive(Debug)]
1021pub struct SpSMPlan {
1022 raw: cusparseSpSMDescr_t,
1023}
1024
1025impl SpSMPlan {
1026 pub fn new() -> Result<Self> {
1027 let c = cusparse()?;
1028 let cu = c.cusparse_spsm_create_descr()?;
1029 let mut d: cusparseSpSMDescr_t = core::ptr::null_mut();
1030 check(unsafe { cu(&mut d) })?;
1031 Ok(Self { raw: d })
1032 }
1033}
1034
1035impl Drop for SpSMPlan {
1036 fn drop(&mut self) {
1037 if let Ok(c) = cusparse() {
1038 if let Ok(cu) = c.cusparse_spsm_destroy_descr() {
1039 let _ = unsafe { cu(self.raw) };
1040 }
1041 }
1042 }
1043}
1044
1045#[allow(clippy::too_many_arguments)]
1046pub fn spsm_buffer_size<T: SparseScalar>(
1047 handle: &Handle,
1048 op_a: Op,
1049 op_b: Op,
1050 alpha: &T,
1051 a: &SpMat<'_, T>,
1052 b: &DnMat<'_, T>,
1053 c: &DnMat<'_, T>,
1054 alg: SpSMAlg,
1055 plan: &SpSMPlan,
1056) -> Result<usize> {
1057 let c_api = cusparse()?;
1058 let cu = c_api.cusparse_spsm_buffer_size()?;
1059 let mut size = 0usize;
1060 check(unsafe {
1061 cu(
1062 handle.as_raw(),
1063 op_a.raw(),
1064 op_b.raw(),
1065 alpha as *const T as *const c_void,
1066 a.descr,
1067 b.descr,
1068 c.descr,
1069 T::data_type(),
1070 alg,
1071 plan.raw,
1072 &mut size,
1073 )
1074 })?;
1075 Ok(size)
1076}
1077
1078#[allow(clippy::too_many_arguments)]
1079pub fn spsm_analysis<T: SparseScalar>(
1080 handle: &Handle,
1081 op_a: Op,
1082 op_b: Op,
1083 alpha: &T,
1084 a: &SpMat<'_, T>,
1085 b: &DnMat<'_, T>,
1086 c: &DnMat<'_, T>,
1087 alg: SpSMAlg,
1088 plan: &SpSMPlan,
1089 workspace: &mut DeviceBuffer<u8>,
1090) -> Result<()> {
1091 let c_api = cusparse()?;
1092 let cu = c_api.cusparse_spsm_analysis()?;
1093 check(unsafe {
1094 cu(
1095 handle.as_raw(),
1096 op_a.raw(),
1097 op_b.raw(),
1098 alpha as *const T as *const c_void,
1099 a.descr,
1100 b.descr,
1101 c.descr,
1102 T::data_type(),
1103 alg,
1104 plan.raw,
1105 workspace.as_raw().0 as *mut c_void,
1106 )
1107 })
1108}
1109
1110#[allow(clippy::too_many_arguments)]
1111pub fn spsm_solve<T: SparseScalar>(
1112 handle: &Handle,
1113 op_a: Op,
1114 op_b: Op,
1115 alpha: &T,
1116 a: &SpMat<'_, T>,
1117 b: &DnMat<'_, T>,
1118 c: &mut DnMat<'_, T>,
1119 alg: SpSMAlg,
1120 plan: &SpSMPlan,
1121) -> Result<()> {
1122 let c_api = cusparse()?;
1123 let cu = c_api.cusparse_spsm_solve()?;
1124 check(unsafe {
1125 cu(
1126 handle.as_raw(),
1127 op_a.raw(),
1128 op_b.raw(),
1129 alpha as *const T as *const c_void,
1130 a.descr,
1131 b.descr,
1132 c.descr,
1133 T::data_type(),
1134 alg,
1135 plan.raw,
1136 )
1137 })
1138}
1139
1140#[allow(clippy::too_many_arguments)]
1143pub fn sddmm_buffer_size<T: SparseScalar>(
1144 handle: &Handle,
1145 op_a: Op,
1146 op_b: Op,
1147 alpha: &T,
1148 a: &DnMat<'_, T>,
1149 b: &DnMat<'_, T>,
1150 beta: &T,
1151 c: &SpMat<'_, T>,
1152 alg: SDDMMAlg,
1153) -> Result<usize> {
1154 let c_api = cusparse()?;
1155 let cu = c_api.cusparse_sddmm_buffer_size()?;
1156 let mut size = 0usize;
1157 check(unsafe {
1158 cu(
1159 handle.as_raw(),
1160 op_a.raw(),
1161 op_b.raw(),
1162 alpha as *const T as *const c_void,
1163 a.descr,
1164 b.descr,
1165 beta as *const T as *const c_void,
1166 c.descr,
1167 T::data_type(),
1168 alg,
1169 &mut size,
1170 )
1171 })?;
1172 Ok(size)
1173}
1174
1175#[allow(clippy::too_many_arguments)]
1176pub fn sddmm<T: SparseScalar>(
1177 handle: &Handle,
1178 op_a: Op,
1179 op_b: Op,
1180 alpha: &T,
1181 a: &DnMat<'_, T>,
1182 b: &DnMat<'_, T>,
1183 beta: &T,
1184 c: &mut SpMat<'_, T>,
1185 alg: SDDMMAlg,
1186 workspace: &mut DeviceBuffer<u8>,
1187) -> Result<()> {
1188 let c_api = cusparse()?;
1189 let cu = c_api.cusparse_sddmm()?;
1190 check(unsafe {
1191 cu(
1192 handle.as_raw(),
1193 op_a.raw(),
1194 op_b.raw(),
1195 alpha as *const T as *const c_void,
1196 a.descr,
1197 b.descr,
1198 beta as *const T as *const c_void,
1199 c.descr,
1200 T::data_type(),
1201 alg,
1202 workspace.as_raw().0 as *mut c_void,
1203 )
1204 })
1205}
1206
1207#[allow(clippy::too_many_arguments)]
1210pub fn sddmm_preprocess<T: SparseScalar>(
1211 handle: &Handle,
1212 op_a: Op,
1213 op_b: Op,
1214 alpha: &T,
1215 a: &DnMat<'_, T>,
1216 b: &DnMat<'_, T>,
1217 beta: &T,
1218 c: &mut SpMat<'_, T>,
1219 alg: SDDMMAlg,
1220 workspace: &mut DeviceBuffer<u8>,
1221) -> Result<()> {
1222 let c_api = cusparse()?;
1223 let cu = c_api.cusparse_sddmm_preprocess()?;
1224 check(unsafe {
1225 cu(
1226 handle.as_raw(),
1227 op_a.raw(),
1228 op_b.raw(),
1229 alpha as *const T as *const c_void,
1230 a.descr,
1231 b.descr,
1232 beta as *const T as *const c_void,
1233 c.descr,
1234 T::data_type(),
1235 alg,
1236 workspace.as_raw().0 as *mut c_void,
1237 )
1238 })
1239}
1240
1241pub fn sparse_to_dense_buffer_size<T: SparseScalar>(
1244 handle: &Handle,
1245 sp: &SpMat<'_, T>,
1246 dn: &DnMat<'_, T>,
1247) -> Result<usize> {
1248 let c = cusparse()?;
1249 let cu = c.cusparse_sparse_to_dense_buffer_size()?;
1250 let mut size = 0usize;
1251 check(unsafe { cu(handle.as_raw(), sp.descr, dn.descr, 0, &mut size) })?;
1252 Ok(size)
1253}
1254
1255pub fn sparse_to_dense<T: SparseScalar>(
1256 handle: &Handle,
1257 sp: &SpMat<'_, T>,
1258 dn: &mut DnMat<'_, T>,
1259 workspace: &mut DeviceBuffer<u8>,
1260) -> Result<()> {
1261 let c = cusparse()?;
1262 let cu = c.cusparse_sparse_to_dense()?;
1263 check(unsafe {
1264 cu(
1265 handle.as_raw(),
1266 sp.descr,
1267 dn.descr,
1268 0,
1269 workspace.as_raw().0 as *mut c_void,
1270 )
1271 })
1272}
1273
1274pub fn dense_to_sparse_buffer_size<T: SparseScalar>(
1275 handle: &Handle,
1276 dn: &DnMat<'_, T>,
1277 sp: &SpMat<'_, T>,
1278) -> Result<usize> {
1279 let c = cusparse()?;
1280 let cu = c.cusparse_dense_to_sparse_buffer_size()?;
1281 let mut size = 0usize;
1282 check(unsafe { cu(handle.as_raw(), dn.descr, sp.descr, 0, &mut size) })?;
1283 Ok(size)
1284}
1285
1286pub fn dense_to_sparse_analysis<T: SparseScalar>(
1287 handle: &Handle,
1288 dn: &DnMat<'_, T>,
1289 sp: &SpMat<'_, T>,
1290 workspace: &mut DeviceBuffer<u8>,
1291) -> Result<()> {
1292 let c = cusparse()?;
1293 let cu = c.cusparse_dense_to_sparse_analysis()?;
1294 check(unsafe {
1295 cu(
1296 handle.as_raw(),
1297 dn.descr,
1298 sp.descr,
1299 0,
1300 workspace.as_raw().0 as *mut c_void,
1301 )
1302 })
1303}
1304
1305pub fn dense_to_sparse_convert<T: SparseScalar>(
1306 handle: &Handle,
1307 dn: &DnMat<'_, T>,
1308 sp: &mut SpMat<'_, T>,
1309 workspace: &mut DeviceBuffer<u8>,
1310) -> Result<()> {
1311 let c = cusparse()?;
1312 let cu = c.cusparse_dense_to_sparse_convert()?;
1313 check(unsafe {
1314 cu(
1315 handle.as_raw(),
1316 dn.descr,
1317 sp.descr,
1318 0,
1319 workspace.as_raw().0 as *mut c_void,
1320 )
1321 })
1322}
1323
1324#[allow(clippy::too_many_arguments)]
1326pub fn csr2csc_ex2_buffer_size<T: SparseScalar + baracuda_types::DeviceRepr>(
1327 handle: &Handle,
1328 m: i32,
1329 n: i32,
1330 nnz: i32,
1331 csr_val: &DeviceBuffer<T>,
1332 csr_row_ptr: &DeviceBuffer<i32>,
1333 csr_col_ind: &DeviceBuffer<i32>,
1334 csc_val: &mut DeviceBuffer<T>,
1335 csc_col_ptr: &mut DeviceBuffer<i32>,
1336 csc_row_ind: &mut DeviceBuffer<i32>,
1337 copy_values: bool,
1338 idx_base: IndexBase,
1339 alg: Csr2CscAlg,
1340) -> Result<usize> {
1341 let c = cusparse()?;
1342 let cu = c.cusparse_csr2csc_ex2_buffer_size()?;
1343 let mut size = 0usize;
1344 check(unsafe {
1345 cu(
1346 handle.as_raw(),
1347 m,
1348 n,
1349 nnz,
1350 csr_val.as_raw().0 as *const c_void,
1351 csr_row_ptr.as_raw().0 as *const i32,
1352 csr_col_ind.as_raw().0 as *const i32,
1353 csc_val.as_raw().0 as *mut c_void,
1354 csc_col_ptr.as_raw().0 as *mut i32,
1355 csc_row_ind.as_raw().0 as *mut i32,
1356 T::data_type(),
1357 copy_values as i32,
1358 idx_base,
1359 alg,
1360 &mut size,
1361 )
1362 })?;
1363 Ok(size)
1364}
1365
1366#[allow(clippy::too_many_arguments)]
1369pub fn csr2csc_ex2<T: SparseScalar + baracuda_types::DeviceRepr>(
1370 handle: &Handle,
1371 m: i32,
1372 n: i32,
1373 nnz: i32,
1374 csr_val: &DeviceBuffer<T>,
1375 csr_row_ptr: &DeviceBuffer<i32>,
1376 csr_col_ind: &DeviceBuffer<i32>,
1377 csc_val: &mut DeviceBuffer<T>,
1378 csc_col_ptr: &mut DeviceBuffer<i32>,
1379 csc_row_ind: &mut DeviceBuffer<i32>,
1380 copy_values: bool,
1381 idx_base: IndexBase,
1382 alg: Csr2CscAlg,
1383 workspace: &mut DeviceBuffer<u8>,
1384) -> Result<()> {
1385 let c = cusparse()?;
1386 let cu = c.cusparse_csr2csc_ex2()?;
1387 check(unsafe {
1388 cu(
1389 handle.as_raw(),
1390 m,
1391 n,
1392 nnz,
1393 csr_val.as_raw().0 as *const c_void,
1394 csr_row_ptr.as_raw().0 as *const i32,
1395 csr_col_ind.as_raw().0 as *const i32,
1396 csc_val.as_raw().0 as *mut c_void,
1397 csc_col_ptr.as_raw().0 as *mut i32,
1398 csc_row_ind.as_raw().0 as *mut i32,
1399 T::data_type(),
1400 copy_values as i32,
1401 idx_base,
1402 alg,
1403 workspace.as_raw().0 as *mut c_void,
1404 )
1405 })
1406}
1407
1408pub fn axpby<T: SparseScalar>(
1411 handle: &Handle,
1412 alpha: &T,
1413 x: &DnVec<'_, T>,
1414 beta: &T,
1415 y: &mut DnVec<'_, T>,
1416) -> Result<()> {
1417 let c = cusparse()?;
1418 let cu = c.cusparse_axpby()?;
1419 check(unsafe {
1420 cu(
1421 handle.as_raw(),
1422 alpha as *const T as *const c_void,
1423 x.descr,
1424 beta as *const T as *const c_void,
1425 y.descr,
1426 )
1427 })
1428}
1429
1430pub fn gather<T: SparseScalar>(
1431 handle: &Handle,
1432 y: &DnVec<'_, T>,
1433 x: &mut DnVec<'_, T>,
1434) -> Result<()> {
1435 let c = cusparse()?;
1436 let cu = c.cusparse_gather()?;
1437 check(unsafe { cu(handle.as_raw(), y.descr, x.descr) })
1438}
1439
1440pub fn scatter<T: SparseScalar>(
1441 handle: &Handle,
1442 x: &DnVec<'_, T>,
1443 y: &mut DnVec<'_, T>,
1444) -> Result<()> {
1445 let c = cusparse()?;
1446 let cu = c.cusparse_scatter()?;
1447 check(unsafe { cu(handle.as_raw(), x.descr, y.descr) })
1448}
1449
1450pub fn rot<T: SparseScalar>(
1451 handle: &Handle,
1452 c_cos: &T,
1453 s_sin: &T,
1454 x: &mut DnVec<'_, T>,
1455 y: &mut DnVec<'_, T>,
1456) -> Result<()> {
1457 let c_api = cusparse()?;
1458 let cu = c_api.cusparse_rot()?;
1459 check(unsafe {
1460 cu(
1461 handle.as_raw(),
1462 c_cos as *const T as *const c_void,
1463 s_sin as *const T as *const c_void,
1464 x.descr,
1465 y.descr,
1466 )
1467 })
1468}
1469
/// Shorthand for an `f32` sparse matrix.
///
/// NOTE(review): despite the name, the alias does not fix the storage format;
/// any `SpMat<f32>` (built with `csr`, `csc`, `coo` or `bsr`) has this type.
pub type CsrMatrix<'buf> = SpMat<'buf, f32>;
/// Alternative name for [`DnVec`].
pub type DenseVector<'buf, T> = DnVec<'buf, T>;