1#![warn(missing_debug_implementations)]
16
17use core::ffi::{c_int, c_void};
18use std::marker::PhantomData;
19
20use baracuda_cusolver_sys::{
21 cublasFillMode_t, cublasOperation_t, cuComplex, cuDoubleComplex, cusolver,
22 cusolverDnHandle_t, cusolverEigMode_t, cusolverStatus_t,
23};
24use baracuda_driver::{DeviceBuffer, Stream};
25use baracuda_types::{Complex32, Complex64, DeviceRepr};
26
27pub use baracuda_cusolver_sys::{
28 cublasFillMode_t as Fill, cusolverEigMode_t as EigMode,
29};
30
/// Error type used throughout this module: `baracuda_core::Error` specialized to
/// cuSOLVER status codes.
pub type Error = baracuda_core::Error<cusolverStatus_t>;
/// Result alias whose error parameter defaults to this module's [`Error`].
pub type Result<T, E = Error> = core::result::Result<T, E>;
35
/// Turns a raw cuSOLVER status into this module's [`Result`] by delegating to
/// `Error::check`.
#[inline]
fn check(status: cusolverStatus_t) -> Result<()> {
    Error::check(status)
}
40
/// Builds an `Error::Status` carrying `cusolverStatus_t::ALLOC_FAILED`.
///
/// The incoming error value is ignored; the generic parameter only exists so the
/// function can be handed directly to `map_err` for any error type.
fn alloc_fail<E>(_e: E) -> Error {
    Error::Status {
        status: cusolverStatus_t::ALLOC_FAILED,
    }
}
47
/// Owning wrapper around a raw `cusolverDnHandle_t`.
///
/// Instances are produced by `DnHandle::new`; the `Drop` impl in this file passes
/// the handle to the library's destroy entry point.
pub struct DnHandle {
    /// Raw handle value written by the create call in `DnHandle::new`.
    handle: cusolverDnHandle_t,
}

// SAFETY: NOTE(review) — this asserts the raw handle may be moved to another thread.
// The justification is not visible in this file; confirm against the cuSOLVER
// threading documentation.
unsafe impl Send for DnHandle {}
56
57impl core::fmt::Debug for DnHandle {
58 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
59 f.debug_struct("cusolver::DnHandle")
60 .field("handle", &self.handle)
61 .finish()
62 }
63}
64
65impl DnHandle {
66 pub fn new() -> Result<Self> {
67 let c = cusolver()?;
68 let cu = c.cusolver_dn_create()?;
69 let mut h: cusolverDnHandle_t = core::ptr::null_mut();
70 check(unsafe { cu(&mut h) })?;
71 Ok(Self { handle: h })
72 }
73
74 pub fn set_stream(&self, stream: &Stream) -> Result<()> {
75 let c = cusolver()?;
76 let cu = c.cusolver_dn_set_stream()?;
77 check(unsafe { cu(self.handle, stream.as_raw() as _) })
78 }
79
80 pub fn version() -> Result<i32> {
81 let c = cusolver()?;
82 let cu = c.cusolver_get_version()?;
83 let mut v: c_int = 0;
84 check(unsafe { cu(&mut v) })?;
85 Ok(v)
86 }
87
88 #[inline]
89 pub fn as_raw(&self) -> cusolverDnHandle_t {
90 self.handle
91 }
92}
93
94impl Drop for DnHandle {
95 fn drop(&mut self) {
96 if let Ok(c) = cusolver() {
97 if let Ok(cu) = c.cusolver_dn_destroy() {
98 let _ = unsafe { cu(self.handle) };
99 }
100 }
101 }
102}
103
/// Operation selector accepted by the safe wrappers in this module; converted to
/// `cublasOperation_t` by `Op::raw`.
///
/// In cuBLAS naming the variants stand for no-transpose, transpose and conjugate
/// transpose respectively.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
pub enum Op {
    /// Maps to `cublasOperation_t::N`; this is the `Default` variant.
    #[default]
    N,
    /// Maps to `cublasOperation_t::T`.
    T,
    /// Maps to `cublasOperation_t::C`.
    C,
}
112
impl Op {
    /// Converts to the FFI enum, mapping each variant to the `cublasOperation_t`
    /// variant of the same name.
    fn raw(self) -> cublasOperation_t {
        match self {
            Op::N => cublasOperation_t::N,
            Op::T => cublasOperation_t::T,
            Op::C => cublasOperation_t::C,
        }
    }
}
122
/// Scalar element types supported by the dense solver wrappers in this module.
///
/// The trait is sealed through `sealed::Sealed`; this file implements it for
/// `f32`, `f64`, `Complex32` and `Complex64`. Each method forwards its arguments
/// to the cuSOLVER entry point for the implementing type and returns the raw
/// status; the implementations in this file return `NOT_INITIALIZED` when the
/// entry point cannot be resolved.
///
/// # Safety
/// NOTE(review): the safety contract of the `unsafe fn`s is not written down in
/// this file. Presumably every pointer must satisfy the requirements of the
/// corresponding cuSOLVER routine — confirm against the cuSOLVER documentation.
pub trait SolverScalar: DeviceRepr + Copy + 'static + sealed::Sealed {
    /// Real-valued companion type: `Self` for the real impls, the component type
    /// (`f32` / `f64`) for the complex impls.
    type Real: DeviceRepr + Copy + 'static;

    /// Workspace-size query for `getrf`; the size is written through `lwork`.
    #[doc(hidden)]
    unsafe fn getrf_buf(
        h: cusolverDnHandle_t,
        m: c_int,
        n: c_int,
        a: *mut Self,
        lda: c_int,
        lwork: *mut c_int,
    ) -> cusolverStatus_t;

    /// Forwards to the `getrf` entry point (no `lwork` argument is passed).
    #[doc(hidden)]
    #[allow(clippy::too_many_arguments)]
    unsafe fn getrf(
        h: cusolverDnHandle_t,
        m: c_int,
        n: c_int,
        a: *mut Self,
        lda: c_int,
        workspace: *mut Self,
        ipiv: *mut c_int,
        info: *mut c_int,
    ) -> cusolverStatus_t;

    /// Forwards to the `getrs` entry point.
    #[doc(hidden)]
    #[allow(clippy::too_many_arguments)]
    unsafe fn getrs(
        h: cusolverDnHandle_t,
        trans: cublasOperation_t,
        n: c_int,
        nrhs: c_int,
        a: *const Self,
        lda: c_int,
        ipiv: *const c_int,
        b: *mut Self,
        ldb: c_int,
        info: *mut c_int,
    ) -> cusolverStatus_t;

    /// Workspace-size query for `geqrf`; the size is written through `lwork`.
    #[doc(hidden)]
    unsafe fn geqrf_buf(
        h: cusolverDnHandle_t,
        m: c_int,
        n: c_int,
        a: *mut Self,
        lda: c_int,
        lwork: *mut c_int,
    ) -> cusolverStatus_t;

    /// Forwards to the `geqrf` entry point.
    #[doc(hidden)]
    #[allow(clippy::too_many_arguments)]
    unsafe fn geqrf(
        h: cusolverDnHandle_t,
        m: c_int,
        n: c_int,
        a: *mut Self,
        lda: c_int,
        tau: *mut Self,
        workspace: *mut Self,
        lwork: c_int,
        info: *mut c_int,
    ) -> cusolverStatus_t;

    /// Workspace-size query for `potrf`; the size is written through `lwork`.
    #[doc(hidden)]
    unsafe fn potrf_buf(
        h: cusolverDnHandle_t,
        uplo: cublasFillMode_t,
        n: c_int,
        a: *mut Self,
        lda: c_int,
        lwork: *mut c_int,
    ) -> cusolverStatus_t;

    /// Forwards to the `potrf` entry point.
    #[doc(hidden)]
    #[allow(clippy::too_many_arguments)]
    unsafe fn potrf(
        h: cusolverDnHandle_t,
        uplo: cublasFillMode_t,
        n: c_int,
        a: *mut Self,
        lda: c_int,
        workspace: *mut Self,
        lwork: c_int,
        info: *mut c_int,
    ) -> cusolverStatus_t;

    /// Forwards to the `potrs` entry point.
    #[doc(hidden)]
    #[allow(clippy::too_many_arguments)]
    unsafe fn potrs(
        h: cusolverDnHandle_t,
        uplo: cublasFillMode_t,
        n: c_int,
        nrhs: c_int,
        a: *const Self,
        lda: c_int,
        b: *mut Self,
        ldb: c_int,
        info: *mut c_int,
    ) -> cusolverStatus_t;

    /// Workspace-size query for `gesvd`; takes only the matrix dimensions and
    /// writes the size through `lwork`.
    #[doc(hidden)]
    unsafe fn gesvd_buf(
        h: cusolverDnHandle_t,
        m: c_int,
        n: c_int,
        lwork: *mut c_int,
    ) -> cusolverStatus_t;

    /// Forwards to the `gesvd` entry point. `s` and `rwork` use `Self::Real`;
    /// the remaining matrix arguments use `Self`.
    #[doc(hidden)]
    #[allow(clippy::too_many_arguments)]
    unsafe fn gesvd(
        h: cusolverDnHandle_t,
        jobu: u8,
        jobvt: u8,
        m: c_int,
        n: c_int,
        a: *mut Self,
        lda: c_int,
        s: *mut Self::Real,
        u: *mut Self,
        ldu: c_int,
        vt: *mut Self,
        ldvt: c_int,
        work: *mut Self,
        lwork: c_int,
        rwork: *mut Self::Real,
        info: *mut c_int,
    ) -> cusolverStatus_t;

    /// Workspace-size query for `syevd`; the complex impls in this file route it
    /// to the `heevd` buffer-size entry points.
    #[doc(hidden)]
    #[allow(clippy::too_many_arguments)]
    unsafe fn syevd_buf(
        h: cusolverDnHandle_t,
        jobz: cusolverEigMode_t,
        uplo: cublasFillMode_t,
        n: c_int,
        a: *const Self,
        lda: c_int,
        w: *const Self::Real,
        lwork: *mut c_int,
    ) -> cusolverStatus_t;

    /// Forwards to the `syevd` entry point (`heevd` for the complex impls in this
    /// file). `w` uses `Self::Real`.
    #[doc(hidden)]
    #[allow(clippy::too_many_arguments)]
    unsafe fn syevd(
        h: cusolverDnHandle_t,
        jobz: cusolverEigMode_t,
        uplo: cublasFillMode_t,
        n: c_int,
        a: *mut Self,
        lda: c_int,
        w: *mut Self::Real,
        work: *mut Self,
        lwork: c_int,
        info: *mut c_int,
    ) -> cusolverStatus_t;
}
299
/// Private module holding the sealing trait for `SolverScalar`: because `Sealed`
/// lives in a non-public module, code outside this crate cannot name it and so
/// cannot add new `SolverScalar` implementations.
mod sealed {
    use baracuda_types::{Complex32, Complex64};
    /// Marker supertrait; implemented only for the four scalar types below.
    pub trait Sealed {}
    impl Sealed for f32 {}
    impl Sealed for f64 {}
    impl Sealed for Complex32 {}
    impl Sealed for Complex64 {}
}
308
/// Resolves the cuSOLVER entry point exposed by the loader method `$loader` and
/// calls it with the given arguments. Evaluates to
/// `cusolverStatus_t::NOT_INITIALIZED` when the library or the symbol lookup fails.
macro_rules! real_ffi_call {
    ($loader:ident($($arg:expr),* $(,)?)) => {
        match cusolver().and_then(|lib| lib.$loader()) {
            Ok(entry) => entry($($arg),*),
            Err(_) => cusolverStatus_t::NOT_INITIALIZED,
        }
    };
}

/// Implements `SolverScalar` for a real scalar type `$t` (`Real = $t`).
///
/// The identifiers after `$t` name the loader methods for, in order: `getrf`
/// buffer size, `getrf`, `getrs`, `geqrf` buffer size, `geqrf`, `potrf` buffer
/// size, `potrf`, `potrs`, `gesvd` buffer size, `gesvd`, `syevd` buffer size,
/// `syevd`. Pointers are forwarded without casts.
macro_rules! real_impl {
    ($t:ty, $getrf_buf:ident, $getrf:ident, $getrs:ident,
     $geqrf_buf:ident, $geqrf:ident,
     $potrf_buf:ident, $potrf:ident, $potrs:ident,
     $gesvd_buf:ident, $gesvd:ident,
     $syevd_buf:ident, $syevd:ident) => {
        impl SolverScalar for $t {
            type Real = $t;

            unsafe fn getrf_buf(
                h: cusolverDnHandle_t,
                m: c_int,
                n: c_int,
                a: *mut $t,
                lda: c_int,
                lwork: *mut c_int,
            ) -> cusolverStatus_t {
                real_ffi_call!($getrf_buf(h, m, n, a, lda, lwork))
            }

            unsafe fn getrf(
                h: cusolverDnHandle_t,
                m: c_int,
                n: c_int,
                a: *mut $t,
                lda: c_int,
                work: *mut $t,
                ipiv: *mut c_int,
                info: *mut c_int,
            ) -> cusolverStatus_t {
                real_ffi_call!($getrf(h, m, n, a, lda, work, ipiv, info))
            }

            unsafe fn getrs(
                h: cusolverDnHandle_t,
                trans: cublasOperation_t,
                n: c_int,
                nrhs: c_int,
                a: *const $t,
                lda: c_int,
                ipiv: *const c_int,
                b: *mut $t,
                ldb: c_int,
                info: *mut c_int,
            ) -> cusolverStatus_t {
                real_ffi_call!($getrs(h, trans, n, nrhs, a, lda, ipiv, b, ldb, info))
            }

            unsafe fn geqrf_buf(
                h: cusolverDnHandle_t,
                m: c_int,
                n: c_int,
                a: *mut $t,
                lda: c_int,
                lwork: *mut c_int,
            ) -> cusolverStatus_t {
                real_ffi_call!($geqrf_buf(h, m, n, a, lda, lwork))
            }

            unsafe fn geqrf(
                h: cusolverDnHandle_t,
                m: c_int,
                n: c_int,
                a: *mut $t,
                lda: c_int,
                tau: *mut $t,
                work: *mut $t,
                lwork: c_int,
                info: *mut c_int,
            ) -> cusolverStatus_t {
                real_ffi_call!($geqrf(h, m, n, a, lda, tau, work, lwork, info))
            }

            unsafe fn potrf_buf(
                h: cusolverDnHandle_t,
                uplo: cublasFillMode_t,
                n: c_int,
                a: *mut $t,
                lda: c_int,
                lwork: *mut c_int,
            ) -> cusolverStatus_t {
                real_ffi_call!($potrf_buf(h, uplo, n, a, lda, lwork))
            }

            unsafe fn potrf(
                h: cusolverDnHandle_t,
                uplo: cublasFillMode_t,
                n: c_int,
                a: *mut $t,
                lda: c_int,
                work: *mut $t,
                lwork: c_int,
                info: *mut c_int,
            ) -> cusolverStatus_t {
                real_ffi_call!($potrf(h, uplo, n, a, lda, work, lwork, info))
            }

            unsafe fn potrs(
                h: cusolverDnHandle_t,
                uplo: cublasFillMode_t,
                n: c_int,
                nrhs: c_int,
                a: *const $t,
                lda: c_int,
                b: *mut $t,
                ldb: c_int,
                info: *mut c_int,
            ) -> cusolverStatus_t {
                real_ffi_call!($potrs(h, uplo, n, nrhs, a, lda, b, ldb, info))
            }

            unsafe fn gesvd_buf(
                h: cusolverDnHandle_t,
                m: c_int,
                n: c_int,
                lwork: *mut c_int,
            ) -> cusolverStatus_t {
                real_ffi_call!($gesvd_buf(h, m, n, lwork))
            }

            unsafe fn gesvd(
                h: cusolverDnHandle_t,
                jobu: u8,
                jobvt: u8,
                m: c_int,
                n: c_int,
                a: *mut $t,
                lda: c_int,
                s: *mut $t,
                u: *mut $t,
                ldu: c_int,
                vt: *mut $t,
                ldvt: c_int,
                work: *mut $t,
                lwork: c_int,
                rwork: *mut $t,
                info: *mut c_int,
            ) -> cusolverStatus_t {
                real_ffi_call!($gesvd(
                    h, jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, work, lwork, rwork, info,
                ))
            }

            unsafe fn syevd_buf(
                h: cusolverDnHandle_t,
                jobz: cusolverEigMode_t,
                uplo: cublasFillMode_t,
                n: c_int,
                a: *const $t,
                lda: c_int,
                w: *const $t,
                lwork: *mut c_int,
            ) -> cusolverStatus_t {
                real_ffi_call!($syevd_buf(h, jobz, uplo, n, a, lda, w, lwork))
            }

            unsafe fn syevd(
                h: cusolverDnHandle_t,
                jobz: cusolverEigMode_t,
                uplo: cublasFillMode_t,
                n: c_int,
                a: *mut $t,
                lda: c_int,
                w: *mut $t,
                work: *mut $t,
                lwork: c_int,
                info: *mut c_int,
            ) -> cusolverStatus_t {
                real_ffi_call!($syevd(h, jobz, uplo, n, a, lda, w, work, lwork, info))
            }
        }
    };
}
507
/// Resolves the cuSOLVER entry point exposed by the loader method `$loader` and
/// calls it with the given arguments. Evaluates to
/// `cusolverStatus_t::NOT_INITIALIZED` when the library or the symbol lookup fails.
macro_rules! complex_ffi_call {
    ($loader:ident($($arg:expr),* $(,)?)) => {
        match cusolver().and_then(|lib| lib.$loader()) {
            Ok(entry) => entry($($arg),*),
            Err(_) => cusolverStatus_t::NOT_INITIALIZED,
        }
    };
}

/// Implements `SolverScalar` for a complex scalar type.
///
/// * `$t` — the Rust-side complex type the trait is implemented for.
/// * `$real` — its real component type, used as `Real` and left uncast.
/// * `$raw` — the FFI complex type; every `$t` pointer is cast to a `$raw`
///   pointer before the call.
///
/// The remaining identifiers name loader methods in the same order as the trait
/// methods; the last two are the `heevd` entry points that back `syevd_buf` and
/// `syevd`.
macro_rules! complex_impl {
    ($t:ty, $real:ty, $raw:ty,
     $getrf_buf:ident, $getrf:ident, $getrs:ident,
     $geqrf_buf:ident, $geqrf:ident,
     $potrf_buf:ident, $potrf:ident, $potrs:ident,
     $gesvd_buf:ident, $gesvd:ident,
     $heevd_buf:ident, $heevd:ident) => {
        impl SolverScalar for $t {
            type Real = $real;

            unsafe fn getrf_buf(
                h: cusolverDnHandle_t,
                m: c_int,
                n: c_int,
                a: *mut $t,
                lda: c_int,
                lwork: *mut c_int,
            ) -> cusolverStatus_t {
                complex_ffi_call!($getrf_buf(h, m, n, a.cast::<$raw>(), lda, lwork))
            }

            unsafe fn getrf(
                h: cusolverDnHandle_t,
                m: c_int,
                n: c_int,
                a: *mut $t,
                lda: c_int,
                work: *mut $t,
                ipiv: *mut c_int,
                info: *mut c_int,
            ) -> cusolverStatus_t {
                complex_ffi_call!($getrf(
                    h,
                    m,
                    n,
                    a.cast::<$raw>(),
                    lda,
                    work.cast::<$raw>(),
                    ipiv,
                    info,
                ))
            }

            unsafe fn getrs(
                h: cusolverDnHandle_t,
                trans: cublasOperation_t,
                n: c_int,
                nrhs: c_int,
                a: *const $t,
                lda: c_int,
                ipiv: *const c_int,
                b: *mut $t,
                ldb: c_int,
                info: *mut c_int,
            ) -> cusolverStatus_t {
                complex_ffi_call!($getrs(
                    h,
                    trans,
                    n,
                    nrhs,
                    a.cast::<$raw>(),
                    lda,
                    ipiv,
                    b.cast::<$raw>(),
                    ldb,
                    info,
                ))
            }

            unsafe fn geqrf_buf(
                h: cusolverDnHandle_t,
                m: c_int,
                n: c_int,
                a: *mut $t,
                lda: c_int,
                lwork: *mut c_int,
            ) -> cusolverStatus_t {
                complex_ffi_call!($geqrf_buf(h, m, n, a.cast::<$raw>(), lda, lwork))
            }

            unsafe fn geqrf(
                h: cusolverDnHandle_t,
                m: c_int,
                n: c_int,
                a: *mut $t,
                lda: c_int,
                tau: *mut $t,
                work: *mut $t,
                lwork: c_int,
                info: *mut c_int,
            ) -> cusolverStatus_t {
                complex_ffi_call!($geqrf(
                    h,
                    m,
                    n,
                    a.cast::<$raw>(),
                    lda,
                    tau.cast::<$raw>(),
                    work.cast::<$raw>(),
                    lwork,
                    info,
                ))
            }

            unsafe fn potrf_buf(
                h: cusolverDnHandle_t,
                uplo: cublasFillMode_t,
                n: c_int,
                a: *mut $t,
                lda: c_int,
                lwork: *mut c_int,
            ) -> cusolverStatus_t {
                complex_ffi_call!($potrf_buf(h, uplo, n, a.cast::<$raw>(), lda, lwork))
            }

            unsafe fn potrf(
                h: cusolverDnHandle_t,
                uplo: cublasFillMode_t,
                n: c_int,
                a: *mut $t,
                lda: c_int,
                work: *mut $t,
                lwork: c_int,
                info: *mut c_int,
            ) -> cusolverStatus_t {
                complex_ffi_call!($potrf(
                    h,
                    uplo,
                    n,
                    a.cast::<$raw>(),
                    lda,
                    work.cast::<$raw>(),
                    lwork,
                    info,
                ))
            }

            unsafe fn potrs(
                h: cusolverDnHandle_t,
                uplo: cublasFillMode_t,
                n: c_int,
                nrhs: c_int,
                a: *const $t,
                lda: c_int,
                b: *mut $t,
                ldb: c_int,
                info: *mut c_int,
            ) -> cusolverStatus_t {
                complex_ffi_call!($potrs(
                    h,
                    uplo,
                    n,
                    nrhs,
                    a.cast::<$raw>(),
                    lda,
                    b.cast::<$raw>(),
                    ldb,
                    info,
                ))
            }

            unsafe fn gesvd_buf(
                h: cusolverDnHandle_t,
                m: c_int,
                n: c_int,
                lwork: *mut c_int,
            ) -> cusolverStatus_t {
                complex_ffi_call!($gesvd_buf(h, m, n, lwork))
            }

            unsafe fn gesvd(
                h: cusolverDnHandle_t,
                jobu: u8,
                jobvt: u8,
                m: c_int,
                n: c_int,
                a: *mut $t,
                lda: c_int,
                s: *mut $real,
                u: *mut $t,
                ldu: c_int,
                vt: *mut $t,
                ldvt: c_int,
                work: *mut $t,
                lwork: c_int,
                rwork: *mut $real,
                info: *mut c_int,
            ) -> cusolverStatus_t {
                complex_ffi_call!($gesvd(
                    h,
                    jobu,
                    jobvt,
                    m,
                    n,
                    a.cast::<$raw>(),
                    lda,
                    s,
                    u.cast::<$raw>(),
                    ldu,
                    vt.cast::<$raw>(),
                    ldvt,
                    work.cast::<$raw>(),
                    lwork,
                    rwork,
                    info,
                ))
            }

            unsafe fn syevd_buf(
                h: cusolverDnHandle_t,
                jobz: cusolverEigMode_t,
                uplo: cublasFillMode_t,
                n: c_int,
                a: *const $t,
                lda: c_int,
                w: *const $real,
                lwork: *mut c_int,
            ) -> cusolverStatus_t {
                complex_ffi_call!($heevd_buf(h, jobz, uplo, n, a.cast::<$raw>(), lda, w, lwork))
            }

            unsafe fn syevd(
                h: cusolverDnHandle_t,
                jobz: cusolverEigMode_t,
                uplo: cublasFillMode_t,
                n: c_int,
                a: *mut $t,
                lda: c_int,
                w: *mut $real,
                work: *mut $t,
                lwork: c_int,
                info: *mut c_int,
            ) -> cusolverStatus_t {
                complex_ffi_call!($heevd(
                    h,
                    jobz,
                    uplo,
                    n,
                    a.cast::<$raw>(),
                    lda,
                    w,
                    work.cast::<$raw>(),
                    lwork,
                    info,
                ))
            }
        }
    };
}
782
// `SolverScalar` for `f32`, backed by the `s`-prefixed loader methods.
real_impl!(
    f32,
    cusolver_dn_sgetrf_buffer_size,
    cusolver_dn_sgetrf,
    cusolver_dn_sgetrs,
    cusolver_dn_sgeqrf_buffer_size,
    cusolver_dn_sgeqrf,
    cusolver_dn_spotrf_buffer_size,
    cusolver_dn_spotrf,
    cusolver_dn_spotrs,
    cusolver_dn_sgesvd_buffer_size,
    cusolver_dn_sgesvd,
    cusolver_dn_ssyevd_buffer_size,
    cusolver_dn_ssyevd
);

// `SolverScalar` for `f64`, backed by the `d`-prefixed loader methods.
real_impl!(
    f64,
    cusolver_dn_dgetrf_buffer_size,
    cusolver_dn_dgetrf,
    cusolver_dn_dgetrs,
    cusolver_dn_dgeqrf_buffer_size,
    cusolver_dn_dgeqrf,
    cusolver_dn_dpotrf_buffer_size,
    cusolver_dn_dpotrf,
    cusolver_dn_dpotrs,
    cusolver_dn_dgesvd_buffer_size,
    cusolver_dn_dgesvd,
    cusolver_dn_dsyevd_buffer_size,
    cusolver_dn_dsyevd
);

// `SolverScalar` for `Complex32` (real part `f32`, FFI type `cuComplex`), backed
// by the `c`-prefixed loader methods; the eigen-solver slots use `cheevd`.
complex_impl!(
    Complex32,
    f32,
    cuComplex,
    cusolver_dn_cgetrf_buffer_size,
    cusolver_dn_cgetrf,
    cusolver_dn_cgetrs,
    cusolver_dn_cgeqrf_buffer_size,
    cusolver_dn_cgeqrf,
    cusolver_dn_cpotrf_buffer_size,
    cusolver_dn_cpotrf,
    cusolver_dn_cpotrs,
    cusolver_dn_cgesvd_buffer_size,
    cusolver_dn_cgesvd,
    cusolver_dn_cheevd_buffer_size,
    cusolver_dn_cheevd
);

// `SolverScalar` for `Complex64` (real part `f64`, FFI type `cuDoubleComplex`),
// backed by the `z`-prefixed loader methods; the eigen-solver slots use `zheevd`.
complex_impl!(
    Complex64,
    f64,
    cuDoubleComplex,
    cusolver_dn_zgetrf_buffer_size,
    cusolver_dn_zgetrf,
    cusolver_dn_zgetrs,
    cusolver_dn_zgeqrf_buffer_size,
    cusolver_dn_zgeqrf,
    cusolver_dn_zpotrf_buffer_size,
    cusolver_dn_zpotrf,
    cusolver_dn_zpotrs,
    cusolver_dn_zgesvd_buffer_size,
    cusolver_dn_zgesvd,
    cusolver_dn_zheevd_buffer_size,
    cusolver_dn_zheevd
);
850
851#[allow(clippy::too_many_arguments)]
855pub fn getrf<T: SolverScalar>(
856 handle: &DnHandle,
857 m: i32,
858 n: i32,
859 a: &mut DeviceBuffer<T>,
860 lda: i32,
861 ipiv: &mut DeviceBuffer<i32>,
862 info: &mut DeviceBuffer<i32>,
863) -> Result<()> {
864 let mut lwork: c_int = 0;
865 check(unsafe { T::getrf_buf(handle.handle, m, n, a.as_raw().0 as *mut T, lda, &mut lwork) })?;
866 let workspace =
867 DeviceBuffer::<T>::new(a.context(), lwork as usize).map_err(alloc_fail)?;
868 check(unsafe {
869 T::getrf(
870 handle.handle,
871 m,
872 n,
873 a.as_raw().0 as *mut T,
874 lda,
875 workspace.as_raw().0 as *mut T,
876 ipiv.as_raw().0 as *mut c_int,
877 info.as_raw().0 as *mut c_int,
878 )
879 })
880}
881
882#[allow(clippy::too_many_arguments)]
884pub fn getrs<T: SolverScalar>(
885 handle: &DnHandle,
886 trans: Op,
887 n: i32,
888 nrhs: i32,
889 a: &DeviceBuffer<T>,
890 lda: i32,
891 ipiv: &DeviceBuffer<i32>,
892 b: &mut DeviceBuffer<T>,
893 ldb: i32,
894 info: &mut DeviceBuffer<i32>,
895) -> Result<()> {
896 check(unsafe {
897 T::getrs(
898 handle.handle,
899 trans.raw(),
900 n,
901 nrhs,
902 a.as_raw().0 as *const T,
903 lda,
904 ipiv.as_raw().0 as *const c_int,
905 b.as_raw().0 as *mut T,
906 ldb,
907 info.as_raw().0 as *mut c_int,
908 )
909 })
910}
911
912#[allow(clippy::too_many_arguments)]
915pub fn geqrf<T: SolverScalar>(
916 handle: &DnHandle,
917 m: i32,
918 n: i32,
919 a: &mut DeviceBuffer<T>,
920 lda: i32,
921 tau: &mut DeviceBuffer<T>,
922 info: &mut DeviceBuffer<i32>,
923) -> Result<()> {
924 let mut lwork: c_int = 0;
925 check(unsafe { T::geqrf_buf(handle.handle, m, n, a.as_raw().0 as *mut T, lda, &mut lwork) })?;
926 let workspace =
927 DeviceBuffer::<T>::new(a.context(), lwork as usize).map_err(alloc_fail)?;
928 check(unsafe {
929 T::geqrf(
930 handle.handle,
931 m,
932 n,
933 a.as_raw().0 as *mut T,
934 lda,
935 tau.as_raw().0 as *mut T,
936 workspace.as_raw().0 as *mut T,
937 lwork,
938 info.as_raw().0 as *mut c_int,
939 )
940 })
941}
942
943pub fn potrf<T: SolverScalar>(
945 handle: &DnHandle,
946 uplo: Fill,
947 n: i32,
948 a: &mut DeviceBuffer<T>,
949 lda: i32,
950 info: &mut DeviceBuffer<i32>,
951) -> Result<()> {
952 let mut lwork: c_int = 0;
953 check(unsafe { T::potrf_buf(handle.handle, uplo, n, a.as_raw().0 as *mut T, lda, &mut lwork) })?;
954 let workspace =
955 DeviceBuffer::<T>::new(a.context(), lwork as usize).map_err(alloc_fail)?;
956 check(unsafe {
957 T::potrf(
958 handle.handle,
959 uplo,
960 n,
961 a.as_raw().0 as *mut T,
962 lda,
963 workspace.as_raw().0 as *mut T,
964 lwork,
965 info.as_raw().0 as *mut c_int,
966 )
967 })
968}
969
970#[allow(clippy::too_many_arguments)]
972pub fn potrs<T: SolverScalar>(
973 handle: &DnHandle,
974 uplo: Fill,
975 n: i32,
976 nrhs: i32,
977 a: &DeviceBuffer<T>,
978 lda: i32,
979 b: &mut DeviceBuffer<T>,
980 ldb: i32,
981 info: &mut DeviceBuffer<i32>,
982) -> Result<()> {
983 check(unsafe {
984 T::potrs(
985 handle.handle,
986 uplo,
987 n,
988 nrhs,
989 a.as_raw().0 as *const T,
990 lda,
991 b.as_raw().0 as *mut T,
992 ldb,
993 info.as_raw().0 as *mut c_int,
994 )
995 })
996}
997
998#[allow(clippy::too_many_arguments)]
1004pub fn gesvd<T: SolverScalar>(
1005 handle: &DnHandle,
1006 jobu: u8,
1007 jobvt: u8,
1008 m: i32,
1009 n: i32,
1010 a: &mut DeviceBuffer<T>,
1011 lda: i32,
1012 s: &mut DeviceBuffer<T::Real>,
1013 u: &mut DeviceBuffer<T>,
1014 ldu: i32,
1015 vt: &mut DeviceBuffer<T>,
1016 ldvt: i32,
1017 rwork: &mut DeviceBuffer<T::Real>,
1018 info: &mut DeviceBuffer<i32>,
1019) -> Result<()> {
1020 let mut lwork: c_int = 0;
1021 check(unsafe { T::gesvd_buf(handle.handle, m, n, &mut lwork) })?;
1022 let workspace =
1023 DeviceBuffer::<T>::new(a.context(), lwork as usize).map_err(alloc_fail)?;
1024 check(unsafe {
1025 T::gesvd(
1026 handle.handle,
1027 jobu,
1028 jobvt,
1029 m,
1030 n,
1031 a.as_raw().0 as *mut T,
1032 lda,
1033 s.as_raw().0 as *mut T::Real,
1034 u.as_raw().0 as *mut T,
1035 ldu,
1036 vt.as_raw().0 as *mut T,
1037 ldvt,
1038 workspace.as_raw().0 as *mut T,
1039 lwork,
1040 rwork.as_raw().0 as *mut T::Real,
1041 info.as_raw().0 as *mut c_int,
1042 )
1043 })
1044}
1045
1046#[allow(clippy::too_many_arguments)]
1048pub fn syevd<T: SolverScalar>(
1049 handle: &DnHandle,
1050 jobz: EigMode,
1051 uplo: Fill,
1052 n: i32,
1053 a: &mut DeviceBuffer<T>,
1054 lda: i32,
1055 w: &mut DeviceBuffer<T::Real>,
1056 info: &mut DeviceBuffer<i32>,
1057) -> Result<()> {
1058 let mut lwork: c_int = 0;
1059 check(unsafe {
1060 T::syevd_buf(
1061 handle.handle,
1062 jobz,
1063 uplo,
1064 n,
1065 a.as_raw().0 as *const T,
1066 lda,
1067 w.as_raw().0 as *const T::Real,
1068 &mut lwork,
1069 )
1070 })?;
1071 let workspace =
1072 DeviceBuffer::<T>::new(a.context(), lwork as usize).map_err(alloc_fail)?;
1073 check(unsafe {
1074 T::syevd(
1075 handle.handle,
1076 jobz,
1077 uplo,
1078 n,
1079 a.as_raw().0 as *mut T,
1080 lda,
1081 w.as_raw().0 as *mut T::Real,
1082 workspace.as_raw().0 as *mut T,
1083 lwork,
1084 info.as_raw().0 as *mut c_int,
1085 )
1086 })
1087}
1088
1089pub use baracuda_cusolver_sys::{gesvdjInfo_t as GesvdjInfoRaw, syevjInfo_t as SyevjInfoRaw};
1092
/// Owning wrapper around a raw `syevjInfo_t` parameter object, consumed by
/// `syevj`. The `Drop` impl in this file hands it to the library's destroy entry
/// point.
#[derive(Debug)]
pub struct SyevjInfo {
    /// Raw parameter object written by the create call in `SyevjInfo::new`.
    raw: SyevjInfoRaw,
}
1098
1099impl SyevjInfo {
1100 pub fn new() -> Result<Self> {
1101 let c = cusolver()?;
1102 let cu = c.cusolver_dn_create_syevj_info()?;
1103 let mut raw: SyevjInfoRaw = core::ptr::null_mut();
1104 check(unsafe { cu(&mut raw) })?;
1105 Ok(Self { raw })
1106 }
1107
1108 pub fn set_tolerance(&self, tol: f64) -> Result<()> {
1109 let c = cusolver()?;
1110 let cu = c.cusolver_dn_xsyevj_set_tolerance()?;
1111 check(unsafe { cu(self.raw, tol) })
1112 }
1113
1114 pub fn set_max_sweeps(&self, n: i32) -> Result<()> {
1115 let c = cusolver()?;
1116 let cu = c.cusolver_dn_xsyevj_set_max_sweeps()?;
1117 check(unsafe { cu(self.raw, n) })
1118 }
1119
1120 pub fn as_raw(&self) -> SyevjInfoRaw {
1121 self.raw
1122 }
1123}
1124
1125impl Drop for SyevjInfo {
1126 fn drop(&mut self) {
1127 if let Ok(c) = cusolver() {
1128 if let Ok(cu) = c.cusolver_dn_destroy_syevj_info() {
1129 let _ = unsafe { cu(self.raw) };
1130 }
1131 }
1132 }
1133}
1134
/// Owning wrapper around a raw `gesvdjInfo_t` parameter object, consumed by
/// `gesvdj`. The `Drop` impl in this file hands it to the library's destroy entry
/// point.
#[derive(Debug)]
pub struct GesvdjInfo {
    /// Raw parameter object written by the create call in `GesvdjInfo::new`.
    raw: GesvdjInfoRaw,
}
1140
1141impl GesvdjInfo {
1142 pub fn new() -> Result<Self> {
1143 let c = cusolver()?;
1144 let cu = c.cusolver_dn_create_gesvdj_info()?;
1145 let mut raw: GesvdjInfoRaw = core::ptr::null_mut();
1146 check(unsafe { cu(&mut raw) })?;
1147 Ok(Self { raw })
1148 }
1149
1150 pub fn as_raw(&self) -> GesvdjInfoRaw {
1151 self.raw
1152 }
1153}
1154
1155impl Drop for GesvdjInfo {
1156 fn drop(&mut self) {
1157 if let Ok(c) = cusolver() {
1158 if let Ok(cu) = c.cusolver_dn_destroy_gesvdj_info() {
1159 let _ = unsafe { cu(self.raw) };
1160 }
1161 }
1162 }
1163}
1164
1165#[allow(clippy::too_many_arguments)]
1168pub fn syevj<T: SolverScalar>(
1169 handle: &DnHandle,
1170 jobz: EigMode,
1171 uplo: Fill,
1172 n: i32,
1173 a: &mut DeviceBuffer<T>,
1174 lda: i32,
1175 w: &mut DeviceBuffer<T::Real>,
1176 info: &mut DeviceBuffer<i32>,
1177 params: &SyevjInfo,
1178) -> Result<()> {
1179 use baracuda_cusolver_sys::{
1180 cuComplex, cuDoubleComplex,
1181 };
1182 use core::mem;
1183
1184 let mut lwork: c_int = 0;
1185
1186 macro_rules! dispatch_real {
1189 ($t:ty, $bufsize:ident, $solve:ident) => {{
1190 let c = cusolver()?;
1191 check(unsafe {
1192 (c.$bufsize()?)(
1193 handle.as_raw(),
1194 jobz,
1195 uplo,
1196 n,
1197 a.as_raw().0 as *const $t,
1198 lda,
1199 w.as_raw().0 as *const $t,
1200 &mut lwork,
1201 params.raw,
1202 )
1203 })?;
1204 let workspace =
1205 DeviceBuffer::<T>::new(a.context(), lwork as usize).map_err(alloc_fail)?;
1206 check(unsafe {
1207 (c.$solve()?)(
1208 handle.as_raw(),
1209 jobz,
1210 uplo,
1211 n,
1212 a.as_raw().0 as *mut $t,
1213 lda,
1214 w.as_raw().0 as *mut $t,
1215 workspace.as_raw().0 as *mut $t,
1216 lwork,
1217 info.as_raw().0 as *mut c_int,
1218 params.raw,
1219 )
1220 })
1221 }};
1222 }
1223 macro_rules! dispatch_complex {
1224 ($t:ty, $real:ty, $raw:ty, $bufsize:ident, $solve:ident) => {{
1225 let c = cusolver()?;
1226 check(unsafe {
1227 (c.$bufsize()?)(
1228 handle.as_raw(),
1229 jobz,
1230 uplo,
1231 n,
1232 a.as_raw().0 as *const $raw,
1233 lda,
1234 w.as_raw().0 as *const $real,
1235 &mut lwork,
1236 params.raw,
1237 )
1238 })?;
1239 let workspace =
1240 DeviceBuffer::<T>::new(a.context(), lwork as usize).map_err(alloc_fail)?;
1241 check(unsafe {
1242 (c.$solve()?)(
1243 handle.as_raw(),
1244 jobz,
1245 uplo,
1246 n,
1247 a.as_raw().0 as *mut $raw,
1248 lda,
1249 w.as_raw().0 as *mut $real,
1250 workspace.as_raw().0 as *mut $raw,
1251 lwork,
1252 info.as_raw().0 as *mut c_int,
1253 params.raw,
1254 )
1255 })
1256 }};
1257 }
1258
1259 if mem::size_of::<T>() == mem::size_of::<f32>() && mem::size_of::<T::Real>() == 4 {
1260 dispatch_real!(f32, cusolver_dn_ssyevj_buffer_size, cusolver_dn_ssyevj)
1261 } else if mem::size_of::<T>() == mem::size_of::<f64>() && mem::size_of::<T::Real>() == 8 {
1262 dispatch_real!(f64, cusolver_dn_dsyevj_buffer_size, cusolver_dn_dsyevj)
1263 } else if mem::size_of::<T>() == mem::size_of::<Complex32>() {
1264 dispatch_complex!(
1265 Complex32,
1266 f32,
1267 cuComplex,
1268 cusolver_dn_cheevj_buffer_size,
1269 cusolver_dn_cheevj
1270 )
1271 } else {
1272 dispatch_complex!(
1273 Complex64,
1274 f64,
1275 cuDoubleComplex,
1276 cusolver_dn_zheevj_buffer_size,
1277 cusolver_dn_zheevj
1278 )
1279 }
1280}
1281
1282#[allow(clippy::too_many_arguments)]
1284pub fn gesvdj<T: SolverScalar>(
1285 handle: &DnHandle,
1286 jobz: EigMode,
1287 econ: bool,
1288 m: i32,
1289 n: i32,
1290 a: &mut DeviceBuffer<T>,
1291 lda: i32,
1292 s: &mut DeviceBuffer<T::Real>,
1293 u: &mut DeviceBuffer<T>,
1294 ldu: i32,
1295 v: &mut DeviceBuffer<T>,
1296 ldv: i32,
1297 info: &mut DeviceBuffer<i32>,
1298 params: &GesvdjInfo,
1299) -> Result<()> {
1300 use baracuda_cusolver_sys::{cuComplex, cuDoubleComplex};
1301 use core::mem;
1302
1303 let mut lwork: c_int = 0;
1304 let econ_i = if econ { 1 } else { 0 };
1305
1306 macro_rules! dispatch_real {
1307 ($t:ty, $bufsize:ident, $solve:ident) => {{
1308 let c = cusolver()?;
1309 check(unsafe {
1310 (c.$bufsize()?)(
1311 handle.as_raw(),
1312 jobz,
1313 econ_i,
1314 m,
1315 n,
1316 a.as_raw().0 as *const $t,
1317 lda,
1318 s.as_raw().0 as *const $t,
1319 u.as_raw().0 as *const $t,
1320 ldu,
1321 v.as_raw().0 as *const $t,
1322 ldv,
1323 &mut lwork,
1324 params.raw,
1325 )
1326 })?;
1327 let workspace =
1328 DeviceBuffer::<T>::new(a.context(), lwork as usize).map_err(alloc_fail)?;
1329 check(unsafe {
1330 (c.$solve()?)(
1331 handle.as_raw(),
1332 jobz,
1333 econ_i,
1334 m,
1335 n,
1336 a.as_raw().0 as *mut $t,
1337 lda,
1338 s.as_raw().0 as *mut $t,
1339 u.as_raw().0 as *mut $t,
1340 ldu,
1341 v.as_raw().0 as *mut $t,
1342 ldv,
1343 workspace.as_raw().0 as *mut $t,
1344 lwork,
1345 info.as_raw().0 as *mut c_int,
1346 params.raw,
1347 )
1348 })
1349 }};
1350 }
1351 macro_rules! dispatch_complex {
1352 ($t:ty, $real:ty, $raw:ty, $bufsize:ident, $solve:ident) => {{
1353 let c = cusolver()?;
1354 check(unsafe {
1355 (c.$bufsize()?)(
1356 handle.as_raw(),
1357 jobz,
1358 econ_i,
1359 m,
1360 n,
1361 a.as_raw().0 as *const $raw,
1362 lda,
1363 s.as_raw().0 as *const $real,
1364 u.as_raw().0 as *const $raw,
1365 ldu,
1366 v.as_raw().0 as *const $raw,
1367 ldv,
1368 &mut lwork,
1369 params.raw,
1370 )
1371 })?;
1372 let workspace =
1373 DeviceBuffer::<T>::new(a.context(), lwork as usize).map_err(alloc_fail)?;
1374 check(unsafe {
1375 (c.$solve()?)(
1376 handle.as_raw(),
1377 jobz,
1378 econ_i,
1379 m,
1380 n,
1381 a.as_raw().0 as *mut $raw,
1382 lda,
1383 s.as_raw().0 as *mut $real,
1384 u.as_raw().0 as *mut $raw,
1385 ldu,
1386 v.as_raw().0 as *mut $raw,
1387 ldv,
1388 workspace.as_raw().0 as *mut $raw,
1389 lwork,
1390 info.as_raw().0 as *mut c_int,
1391 params.raw,
1392 )
1393 })
1394 }};
1395 }
1396
1397 if mem::size_of::<T>() == mem::size_of::<f32>() && mem::size_of::<T::Real>() == 4 {
1398 dispatch_real!(f32, cusolver_dn_sgesvdj_buffer_size, cusolver_dn_sgesvdj)
1399 } else if mem::size_of::<T>() == mem::size_of::<f64>() && mem::size_of::<T::Real>() == 8 {
1400 dispatch_real!(f64, cusolver_dn_dgesvdj_buffer_size, cusolver_dn_dgesvdj)
1401 } else if mem::size_of::<T>() == mem::size_of::<Complex32>() {
1402 dispatch_complex!(
1403 Complex32,
1404 f32,
1405 cuComplex,
1406 cusolver_dn_cgesvdj_buffer_size,
1407 cusolver_dn_cgesvdj
1408 )
1409 } else {
1410 dispatch_complex!(
1411 Complex64,
1412 f64,
1413 cuDoubleComplex,
1414 cusolver_dn_zgesvdj_buffer_size,
1415 cusolver_dn_zgesvdj
1416 )
1417 }
1418}
1419
1420#[allow(clippy::too_many_arguments)]
1425pub fn orgqr<T: SolverScalar>(
1426 handle: &DnHandle,
1427 m: i32,
1428 n: i32,
1429 k: i32,
1430 a: &mut DeviceBuffer<T>,
1431 lda: i32,
1432 tau: &DeviceBuffer<T>,
1433 info: &mut DeviceBuffer<i32>,
1434) -> Result<()> {
1435 use baracuda_cusolver_sys::{cuComplex, cuDoubleComplex};
1436 use core::mem;
1437
1438 let mut lwork: c_int = 0;
1439 macro_rules! dispatch {
1440 ($t:ty, $raw:ty, $bufsize:ident, $solve:ident) => {{
1441 let c = cusolver()?;
1442 check(unsafe {
1443 (c.$bufsize()?)(
1444 handle.as_raw(),
1445 m,
1446 n,
1447 k,
1448 a.as_raw().0 as *const $raw,
1449 lda,
1450 tau.as_raw().0 as *const $raw,
1451 &mut lwork,
1452 )
1453 })?;
1454 let workspace =
1455 DeviceBuffer::<T>::new(a.context(), lwork as usize).map_err(alloc_fail)?;
1456 check(unsafe {
1457 (c.$solve()?)(
1458 handle.as_raw(),
1459 m,
1460 n,
1461 k,
1462 a.as_raw().0 as *mut $raw,
1463 lda,
1464 tau.as_raw().0 as *const $raw,
1465 workspace.as_raw().0 as *mut $raw,
1466 lwork,
1467 info.as_raw().0 as *mut c_int,
1468 )
1469 })
1470 }};
1471 }
1472
1473 if mem::size_of::<T>() == mem::size_of::<f32>() && mem::size_of::<T::Real>() == 4 {
1474 dispatch!(f32, f32, cusolver_dn_sorgqr_buffer_size, cusolver_dn_sorgqr)
1475 } else if mem::size_of::<T>() == mem::size_of::<f64>() && mem::size_of::<T::Real>() == 8 {
1476 dispatch!(f64, f64, cusolver_dn_dorgqr_buffer_size, cusolver_dn_dorgqr)
1477 } else if mem::size_of::<T>() == mem::size_of::<Complex32>() {
1478 dispatch!(
1479 Complex32,
1480 cuComplex,
1481 cusolver_dn_cungqr_buffer_size,
1482 cusolver_dn_cungqr
1483 )
1484 } else {
1485 dispatch!(
1486 Complex64,
1487 cuDoubleComplex,
1488 cusolver_dn_zungqr_buffer_size,
1489 cusolver_dn_zungqr
1490 )
1491 }
1492}
1493
/// Side selector that [`ormqr`] forwards to the cuSOLVER `?ormqr` / `?unmqr`
/// routines.
///
/// The discriminants are the integer codes handed to the FFI layer.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Side {
    /// Encoded as `0`.
    Left = 0,
    /// Encoded as `1`.
    Right = 1,
}

impl Side {
    /// Returns the integer code passed across the FFI boundary.
    fn raw(self) -> core::ffi::c_int {
        self as core::ffi::c_int
    }
}
1509
1510#[allow(clippy::too_many_arguments)]
1513pub fn ormqr<T: SolverScalar>(
1514 handle: &DnHandle,
1515 side: Side,
1516 trans: Op,
1517 m: i32,
1518 n: i32,
1519 k: i32,
1520 a: &DeviceBuffer<T>,
1521 lda: i32,
1522 tau: &DeviceBuffer<T>,
1523 c_mat: &mut DeviceBuffer<T>,
1524 ldc: i32,
1525 info: &mut DeviceBuffer<i32>,
1526) -> Result<()> {
1527 use baracuda_cusolver_sys::{cuComplex, cuDoubleComplex};
1528 use core::mem;
1529
1530 let mut lwork: c_int = 0;
1531 let side_i = side.raw();
1532 macro_rules! dispatch {
1533 ($t:ty, $raw:ty, $bufsize:ident, $solve:ident) => {{
1534 let ca = cusolver()?;
1535 check(unsafe {
1536 (ca.$bufsize()?)(
1537 handle.as_raw(),
1538 side_i,
1539 trans.raw(),
1540 m,
1541 n,
1542 k,
1543 a.as_raw().0 as *const $raw,
1544 lda,
1545 tau.as_raw().0 as *const $raw,
1546 c_mat.as_raw().0 as *const $raw,
1547 ldc,
1548 &mut lwork,
1549 )
1550 })?;
1551 let workspace =
1552 DeviceBuffer::<T>::new(c_mat.context(), lwork as usize).map_err(alloc_fail)?;
1553 check(unsafe {
1554 (ca.$solve()?)(
1555 handle.as_raw(),
1556 side_i,
1557 trans.raw(),
1558 m,
1559 n,
1560 k,
1561 a.as_raw().0 as *const $raw,
1562 lda,
1563 tau.as_raw().0 as *const $raw,
1564 c_mat.as_raw().0 as *mut $raw,
1565 ldc,
1566 workspace.as_raw().0 as *mut $raw,
1567 lwork,
1568 info.as_raw().0 as *mut c_int,
1569 )
1570 })
1571 }};
1572 }
1573
1574 if mem::size_of::<T>() == mem::size_of::<f32>() && mem::size_of::<T::Real>() == 4 {
1575 dispatch!(f32, f32, cusolver_dn_sormqr_buffer_size, cusolver_dn_sormqr)
1576 } else if mem::size_of::<T>() == mem::size_of::<f64>() && mem::size_of::<T::Real>() == 8 {
1577 dispatch!(f64, f64, cusolver_dn_dormqr_buffer_size, cusolver_dn_dormqr)
1578 } else if mem::size_of::<T>() == mem::size_of::<Complex32>() {
1579 dispatch!(
1580 Complex32,
1581 cuComplex,
1582 cusolver_dn_cunmqr_buffer_size,
1583 cusolver_dn_cunmqr
1584 )
1585 } else {
1586 dispatch!(
1587 Complex64,
1588 cuDoubleComplex,
1589 cusolver_dn_zunmqr_buffer_size,
1590 cusolver_dn_zunmqr
1591 )
1592 }
1593}
1594
1595#[allow(clippy::too_many_arguments)]
1602pub fn gels<T: SolverScalar>(
1603 handle: &DnHandle,
1604 m: i32,
1605 n: i32,
1606 nrhs: i32,
1607 a: &mut DeviceBuffer<T>,
1608 lda: i32,
1609 b: &mut DeviceBuffer<T>,
1610 ldb: i32,
1611 x: &mut DeviceBuffer<T>,
1612 ldx: i32,
1613 info: &mut DeviceBuffer<i32>,
1614) -> Result<i32> {
1615 use baracuda_cusolver_sys::{cuComplex, cuDoubleComplex};
1616 use core::mem;
1617
1618 let mut bytes: usize = 0;
1619
1620 macro_rules! dispatch {
1621 ($t:ty, $raw:ty, $bufsize:ident, $solve:ident) => {{
1622 let cs = cusolver()?;
1623 check(unsafe {
1624 (cs.$bufsize()?)(
1625 handle.as_raw(),
1626 m,
1627 n,
1628 nrhs,
1629 a.as_raw().0 as *mut $raw,
1630 lda,
1631 b.as_raw().0 as *mut $raw,
1632 ldb,
1633 x.as_raw().0 as *mut $raw,
1634 ldx,
1635 core::ptr::null_mut(),
1636 &mut bytes,
1637 )
1638 })?;
1639 let units = bytes.div_ceil(mem::size_of::<T>());
1641 let workspace =
1642 DeviceBuffer::<T>::new(a.context(), units).map_err(alloc_fail)?;
1643 let mut iter: c_int = 0;
1644 check(unsafe {
1645 (cs.$solve()?)(
1646 handle.as_raw(),
1647 m,
1648 n,
1649 nrhs,
1650 a.as_raw().0 as *mut $raw,
1651 lda,
1652 b.as_raw().0 as *mut $raw,
1653 ldb,
1654 x.as_raw().0 as *mut $raw,
1655 ldx,
1656 workspace.as_raw().0 as *mut c_void,
1657 bytes,
1658 &mut iter,
1659 info.as_raw().0 as *mut c_int,
1660 )
1661 })?;
1662 Ok(iter)
1663 }};
1664 }
1665
1666 if mem::size_of::<T>() == mem::size_of::<f32>() && mem::size_of::<T::Real>() == 4 {
1667 dispatch!(f32, f32, cusolver_dn_ssgels_buffer_size, cusolver_dn_ssgels)
1668 } else if mem::size_of::<T>() == mem::size_of::<f64>() && mem::size_of::<T::Real>() == 8 {
1669 dispatch!(f64, f64, cusolver_dn_ddgels_buffer_size, cusolver_dn_ddgels)
1670 } else if mem::size_of::<T>() == mem::size_of::<Complex32>() {
1671 dispatch!(
1672 Complex32,
1673 cuComplex,
1674 cusolver_dn_ccgels_buffer_size,
1675 cusolver_dn_ccgels
1676 )
1677 } else {
1678 dispatch!(
1679 Complex64,
1680 cuDoubleComplex,
1681 cusolver_dn_zzgels_buffer_size,
1682 cusolver_dn_zzgels
1683 )
1684 }
1685}
1686
1687pub fn potri<T: SolverScalar>(
1693 handle: &DnHandle,
1694 uplo: Fill,
1695 n: i32,
1696 a: &mut DeviceBuffer<T>,
1697 lda: i32,
1698 info: &mut DeviceBuffer<i32>,
1699) -> Result<()> {
1700 use baracuda_cusolver_sys::{cuComplex, cuDoubleComplex};
1701 use core::mem;
1702
1703 let mut lwork: c_int = 0;
1704 macro_rules! dispatch {
1705 ($t:ty, $raw:ty, $bufsize:ident, $solve:ident) => {{
1706 let cs = cusolver()?;
1707 check(unsafe {
1708 (cs.$bufsize()?)(
1709 handle.as_raw(),
1710 uplo,
1711 n,
1712 a.as_raw().0 as *mut $raw,
1713 lda,
1714 &mut lwork,
1715 )
1716 })?;
1717 let workspace =
1718 DeviceBuffer::<T>::new(a.context(), lwork as usize).map_err(alloc_fail)?;
1719 check(unsafe {
1720 (cs.$solve()?)(
1721 handle.as_raw(),
1722 uplo,
1723 n,
1724 a.as_raw().0 as *mut $raw,
1725 lda,
1726 workspace.as_raw().0 as *mut $raw,
1727 lwork,
1728 info.as_raw().0 as *mut c_int,
1729 )
1730 })
1731 }};
1732 }
1733
1734 if mem::size_of::<T>() == mem::size_of::<f32>() && mem::size_of::<T::Real>() == 4 {
1735 dispatch!(f32, f32, cusolver_dn_spotri_buffer_size, cusolver_dn_spotri)
1736 } else if mem::size_of::<T>() == mem::size_of::<f64>() && mem::size_of::<T::Real>() == 8 {
1737 dispatch!(f64, f64, cusolver_dn_dpotri_buffer_size, cusolver_dn_dpotri)
1738 } else if mem::size_of::<T>() == mem::size_of::<Complex32>() {
1739 dispatch!(
1740 Complex32,
1741 cuComplex,
1742 cusolver_dn_cpotri_buffer_size,
1743 cusolver_dn_cpotri
1744 )
1745 } else {
1746 dispatch!(
1747 Complex64,
1748 cuDoubleComplex,
1749 cusolver_dn_zpotri_buffer_size,
1750 cusolver_dn_zpotri
1751 )
1752 }
1753}
1754
1755#[allow(clippy::too_many_arguments)]
1761pub fn syevj_batched<T: SolverScalar>(
1762 handle: &DnHandle,
1763 jobz: EigMode,
1764 uplo: Fill,
1765 n: i32,
1766 a: &mut DeviceBuffer<T>,
1767 lda: i32,
1768 w: &mut DeviceBuffer<T::Real>,
1769 info: &mut DeviceBuffer<i32>,
1770 params: &SyevjInfo,
1771 batch_size: i32,
1772) -> Result<()> {
1773 use baracuda_cusolver_sys::{cuComplex, cuDoubleComplex};
1774 use core::mem;
1775
1776 let mut lwork: c_int = 0;
1777 macro_rules! dispatch_real {
1778 ($t:ty, $bufsize:ident, $solve:ident) => {{
1779 let c = cusolver()?;
1780 check(unsafe {
1781 (c.$bufsize()?)(
1782 handle.as_raw(),
1783 jobz,
1784 uplo,
1785 n,
1786 a.as_raw().0 as *const $t,
1787 lda,
1788 w.as_raw().0 as *const $t,
1789 &mut lwork,
1790 params.as_raw(),
1791 batch_size,
1792 )
1793 })?;
1794 let workspace =
1795 DeviceBuffer::<T>::new(a.context(), lwork as usize).map_err(alloc_fail)?;
1796 check(unsafe {
1797 (c.$solve()?)(
1798 handle.as_raw(),
1799 jobz,
1800 uplo,
1801 n,
1802 a.as_raw().0 as *mut $t,
1803 lda,
1804 w.as_raw().0 as *mut $t,
1805 workspace.as_raw().0 as *mut $t,
1806 lwork,
1807 info.as_raw().0 as *mut c_int,
1808 params.as_raw(),
1809 batch_size,
1810 )
1811 })
1812 }};
1813 }
1814 macro_rules! dispatch_complex {
1815 ($t:ty, $real:ty, $raw:ty, $bufsize:ident, $solve:ident) => {{
1816 let c = cusolver()?;
1817 check(unsafe {
1818 (c.$bufsize()?)(
1819 handle.as_raw(),
1820 jobz,
1821 uplo,
1822 n,
1823 a.as_raw().0 as *const $raw,
1824 lda,
1825 w.as_raw().0 as *const $real,
1826 &mut lwork,
1827 params.as_raw(),
1828 batch_size,
1829 )
1830 })?;
1831 let workspace =
1832 DeviceBuffer::<T>::new(a.context(), lwork as usize).map_err(alloc_fail)?;
1833 check(unsafe {
1834 (c.$solve()?)(
1835 handle.as_raw(),
1836 jobz,
1837 uplo,
1838 n,
1839 a.as_raw().0 as *mut $raw,
1840 lda,
1841 w.as_raw().0 as *mut $real,
1842 workspace.as_raw().0 as *mut $raw,
1843 lwork,
1844 info.as_raw().0 as *mut c_int,
1845 params.as_raw(),
1846 batch_size,
1847 )
1848 })
1849 }};
1850 }
1851
1852 if mem::size_of::<T>() == mem::size_of::<f32>() && mem::size_of::<T::Real>() == 4 {
1853 dispatch_real!(
1854 f32,
1855 cusolver_dn_ssyevj_batched_buffer_size,
1856 cusolver_dn_ssyevj_batched
1857 )
1858 } else if mem::size_of::<T>() == mem::size_of::<f64>() && mem::size_of::<T::Real>() == 8 {
1859 dispatch_real!(
1860 f64,
1861 cusolver_dn_dsyevj_batched_buffer_size,
1862 cusolver_dn_dsyevj_batched
1863 )
1864 } else if mem::size_of::<T>() == mem::size_of::<Complex32>() {
1865 dispatch_complex!(
1866 Complex32,
1867 f32,
1868 cuComplex,
1869 cusolver_dn_cheevj_batched_buffer_size,
1870 cusolver_dn_cheevj_batched
1871 )
1872 } else {
1873 dispatch_complex!(
1874 Complex64,
1875 f64,
1876 cuDoubleComplex,
1877 cusolver_dn_zheevj_batched_buffer_size,
1878 cusolver_dn_zheevj_batched
1879 )
1880 }
1881}
1882
1883#[allow(clippy::too_many_arguments)]
1885pub fn gesvdj_batched<T: SolverScalar>(
1886 handle: &DnHandle,
1887 jobz: EigMode,
1888 m: i32,
1889 n: i32,
1890 a: &mut DeviceBuffer<T>,
1891 lda: i32,
1892 s: &mut DeviceBuffer<T::Real>,
1893 u: &mut DeviceBuffer<T>,
1894 ldu: i32,
1895 v: &mut DeviceBuffer<T>,
1896 ldv: i32,
1897 info: &mut DeviceBuffer<i32>,
1898 params: &GesvdjInfo,
1899 batch_size: i32,
1900) -> Result<()> {
1901 use baracuda_cusolver_sys::{cuComplex, cuDoubleComplex};
1902 use core::mem;
1903
1904 let mut lwork: c_int = 0;
1905 macro_rules! dispatch_real {
1906 ($t:ty, $bufsize:ident, $solve:ident) => {{
1907 let c = cusolver()?;
1908 check(unsafe {
1909 (c.$bufsize()?)(
1910 handle.as_raw(),
1911 jobz,
1912 m,
1913 n,
1914 a.as_raw().0 as *const $t,
1915 lda,
1916 s.as_raw().0 as *const $t,
1917 u.as_raw().0 as *const $t,
1918 ldu,
1919 v.as_raw().0 as *const $t,
1920 ldv,
1921 &mut lwork,
1922 params.as_raw(),
1923 batch_size,
1924 )
1925 })?;
1926 let workspace =
1927 DeviceBuffer::<T>::new(a.context(), lwork as usize).map_err(alloc_fail)?;
1928 check(unsafe {
1929 (c.$solve()?)(
1930 handle.as_raw(),
1931 jobz,
1932 m,
1933 n,
1934 a.as_raw().0 as *mut $t,
1935 lda,
1936 s.as_raw().0 as *mut $t,
1937 u.as_raw().0 as *mut $t,
1938 ldu,
1939 v.as_raw().0 as *mut $t,
1940 ldv,
1941 workspace.as_raw().0 as *mut $t,
1942 lwork,
1943 info.as_raw().0 as *mut c_int,
1944 params.as_raw(),
1945 batch_size,
1946 )
1947 })
1948 }};
1949 }
1950 macro_rules! dispatch_complex {
1951 ($t:ty, $real:ty, $raw:ty, $bufsize:ident, $solve:ident) => {{
1952 let c = cusolver()?;
1953 check(unsafe {
1954 (c.$bufsize()?)(
1955 handle.as_raw(),
1956 jobz,
1957 m,
1958 n,
1959 a.as_raw().0 as *const $raw,
1960 lda,
1961 s.as_raw().0 as *const $real,
1962 u.as_raw().0 as *const $raw,
1963 ldu,
1964 v.as_raw().0 as *const $raw,
1965 ldv,
1966 &mut lwork,
1967 params.as_raw(),
1968 batch_size,
1969 )
1970 })?;
1971 let workspace =
1972 DeviceBuffer::<T>::new(a.context(), lwork as usize).map_err(alloc_fail)?;
1973 check(unsafe {
1974 (c.$solve()?)(
1975 handle.as_raw(),
1976 jobz,
1977 m,
1978 n,
1979 a.as_raw().0 as *mut $raw,
1980 lda,
1981 s.as_raw().0 as *mut $real,
1982 u.as_raw().0 as *mut $raw,
1983 ldu,
1984 v.as_raw().0 as *mut $raw,
1985 ldv,
1986 workspace.as_raw().0 as *mut $raw,
1987 lwork,
1988 info.as_raw().0 as *mut c_int,
1989 params.as_raw(),
1990 batch_size,
1991 )
1992 })
1993 }};
1994 }
1995
1996 if mem::size_of::<T>() == mem::size_of::<f32>() && mem::size_of::<T::Real>() == 4 {
1997 dispatch_real!(
1998 f32,
1999 cusolver_dn_sgesvdj_batched_buffer_size,
2000 cusolver_dn_sgesvdj_batched
2001 )
2002 } else if mem::size_of::<T>() == mem::size_of::<f64>() && mem::size_of::<T::Real>() == 8 {
2003 dispatch_real!(
2004 f64,
2005 cusolver_dn_dgesvdj_batched_buffer_size,
2006 cusolver_dn_dgesvdj_batched
2007 )
2008 } else if mem::size_of::<T>() == mem::size_of::<Complex32>() {
2009 dispatch_complex!(
2010 Complex32,
2011 f32,
2012 cuComplex,
2013 cusolver_dn_cgesvdj_batched_buffer_size,
2014 cusolver_dn_cgesvdj_batched
2015 )
2016 } else {
2017 dispatch_complex!(
2018 Complex64,
2019 f64,
2020 cuDoubleComplex,
2021 cusolver_dn_zgesvdj_batched_buffer_size,
2022 cusolver_dn_zgesvdj_batched
2023 )
2024 }
2025}
2026
2027pub mod mg {
2030 use core::ffi::{c_int, c_void};
2035
2036 use baracuda_cusolver_sys::{
2037 cudaDataType, cudaLibMgGrid_t, cudaLibMgMatrixDesc_t, cusolver_mg, cusolverMgHandle_t,
2038 };
2039
2040 use super::{alloc_fail, check, EigMode, Fill, Result};
2041
2042 #[derive(Debug)]
2044 pub struct Handle {
2045 raw: cusolverMgHandle_t,
2046 }
2047
2048 impl Handle {
2049 pub fn new() -> Result<Self> {
2050 let mg = cusolver_mg()?;
2051 let cu = mg.cusolver_mg_create()?;
2052 let mut h: cusolverMgHandle_t = core::ptr::null_mut();
2053 check(unsafe { cu(&mut h) })?;
2054 Ok(Self { raw: h })
2055 }
2056
2057 pub fn device_select(&self, devices: &[i32]) -> Result<()> {
2060 let mg = cusolver_mg()?;
2061 let cu = mg.cusolver_mg_device_select()?;
2062 check(unsafe { cu(self.raw, devices.len() as c_int, devices.as_ptr()) })
2063 }
2064
2065 pub fn as_raw(&self) -> cusolverMgHandle_t {
2066 self.raw
2067 }
2068 }
2069
2070 impl Drop for Handle {
2071 fn drop(&mut self) {
2072 if let Ok(mg) = cusolver_mg() {
2073 if let Ok(cu) = mg.cusolver_mg_destroy() {
2074 let _ = unsafe { cu(self.raw) };
2075 }
2076 }
2077 }
2078 }
2079
2080 #[derive(Debug)]
2082 pub struct DeviceGrid {
2083 raw: cudaLibMgGrid_t,
2084 }
2085
2086 impl DeviceGrid {
2087 pub fn new(num_row_devices: i32, num_col_devices: i32, devices: &[i32], mapping: i32) -> Result<Self> {
2089 let mg = cusolver_mg()?;
2090 let cu = mg.cusolver_mg_create_device_grid()?;
2091 let mut raw: cudaLibMgGrid_t = core::ptr::null_mut();
2092 check(unsafe {
2093 cu(
2094 &mut raw,
2095 num_row_devices,
2096 num_col_devices,
2097 devices.as_ptr(),
2098 mapping,
2099 )
2100 })?;
2101 Ok(Self { raw })
2102 }
2103
2104 pub fn as_raw(&self) -> cudaLibMgGrid_t {
2105 self.raw
2106 }
2107 }
2108
2109 impl Drop for DeviceGrid {
2110 fn drop(&mut self) {
2111 if let Ok(mg) = cusolver_mg() {
2112 if let Ok(cu) = mg.cusolver_mg_destroy_grid() {
2113 let _ = unsafe { cu(self.raw) };
2114 }
2115 }
2116 }
2117 }
2118
2119 #[derive(Debug)]
2121 pub struct MatrixDesc {
2122 raw: cudaLibMgMatrixDesc_t,
2123 }
2124
2125 impl MatrixDesc {
2126 pub fn new(
2127 num_rows: i64,
2128 num_cols: i64,
2129 row_block_size: i64,
2130 col_block_size: i64,
2131 data_type: cudaDataType,
2132 grid: &DeviceGrid,
2133 ) -> Result<Self> {
2134 let mg = cusolver_mg()?;
2135 let cu = mg.cusolver_mg_create_matrix_desc()?;
2136 let mut raw: cudaLibMgMatrixDesc_t = core::ptr::null_mut();
2137 check(unsafe {
2138 cu(
2139 &mut raw,
2140 num_rows,
2141 num_cols,
2142 row_block_size,
2143 col_block_size,
2144 data_type,
2145 grid.as_raw(),
2146 )
2147 })?;
2148 Ok(Self { raw })
2149 }
2150
2151 pub fn as_raw(&self) -> cudaLibMgMatrixDesc_t {
2152 self.raw
2153 }
2154 }
2155
2156 impl Drop for MatrixDesc {
2157 fn drop(&mut self) {
2158 if let Ok(mg) = cusolver_mg() {
2159 if let Ok(cu) = mg.cusolver_mg_destroy_matrix_desc() {
2160 let _ = unsafe { cu(self.raw) };
2161 }
2162 }
2163 }
2164 }
2165
2166 #[allow(clippy::too_many_arguments)]
2172 pub unsafe fn getrf_buffer_size(
2173 handle: &Handle,
2174 m: i32,
2175 n: i32,
2176 array_d_a: *mut *mut c_void,
2177 ia: i32,
2178 ja: i32,
2179 desc_a: &MatrixDesc,
2180 array_d_ipiv: *mut *mut c_int,
2181 compute_type: cudaDataType,
2182 ) -> Result<i64> {
2183 let mg = cusolver_mg()?;
2184 let cu = mg.cusolver_mg_getrf_buffer_size()?;
2185 let mut lwork: i64 = 0;
2186 check(cu(
2187 handle.as_raw(),
2188 m,
2189 n,
2190 array_d_a,
2191 ia,
2192 ja,
2193 desc_a.as_raw(),
2194 array_d_ipiv,
2195 compute_type,
2196 &mut lwork,
2197 ))?;
2198 Ok(lwork)
2199 }
2200
2201 #[allow(clippy::too_many_arguments)]
2204 pub unsafe fn getrf(
2205 handle: &Handle,
2206 m: i32,
2207 n: i32,
2208 array_d_a: *mut *mut c_void,
2209 ia: i32,
2210 ja: i32,
2211 desc_a: &MatrixDesc,
2212 array_d_ipiv: *mut *mut c_int,
2213 compute_type: cudaDataType,
2214 array_d_work: *mut *mut c_void,
2215 lwork: i64,
2216 info: &mut [c_int],
2217 ) -> Result<()> {
2218 let mg = cusolver_mg()?;
2219 let cu = mg.cusolver_mg_getrf()?;
2220 let _ = alloc_fail::<()>; check(cu(
2222 handle.as_raw(),
2223 m,
2224 n,
2225 array_d_a,
2226 ia,
2227 ja,
2228 desc_a.as_raw(),
2229 array_d_ipiv,
2230 compute_type,
2231 array_d_work,
2232 lwork,
2233 info.as_mut_ptr(),
2234 ))
2235 }
2236
2237 #[allow(clippy::too_many_arguments)]
2242 pub unsafe fn potrf_buffer_size(
2243 handle: &Handle,
2244 uplo: Fill,
2245 n: i32,
2246 array_d_a: *mut *mut c_void,
2247 ia: i32,
2248 ja: i32,
2249 desc_a: &MatrixDesc,
2250 compute_type: cudaDataType,
2251 ) -> Result<i64> {
2252 let mg = cusolver_mg()?;
2253 let cu = mg.cusolver_mg_potrf_buffer_size()?;
2254 let mut lwork: i64 = 0;
2255 check(cu(
2256 handle.as_raw(),
2257 uplo,
2258 n,
2259 array_d_a,
2260 ia,
2261 ja,
2262 desc_a.as_raw(),
2263 compute_type,
2264 &mut lwork,
2265 ))?;
2266 Ok(lwork)
2267 }
2268
2269 #[allow(clippy::too_many_arguments)]
2272 pub unsafe fn potrf(
2273 handle: &Handle,
2274 uplo: Fill,
2275 n: i32,
2276 array_d_a: *mut *mut c_void,
2277 ia: i32,
2278 ja: i32,
2279 desc_a: &MatrixDesc,
2280 compute_type: cudaDataType,
2281 array_d_work: *mut *mut c_void,
2282 lwork: i64,
2283 info: &mut [c_int],
2284 ) -> Result<()> {
2285 let mg = cusolver_mg()?;
2286 let cu = mg.cusolver_mg_potrf()?;
2287 check(cu(
2288 handle.as_raw(),
2289 uplo,
2290 n,
2291 array_d_a,
2292 ia,
2293 ja,
2294 desc_a.as_raw(),
2295 compute_type,
2296 array_d_work,
2297 lwork,
2298 info.as_mut_ptr(),
2299 ))
2300 }
2301
2302 #[allow(clippy::too_many_arguments)]
2307 pub unsafe fn syevd_buffer_size(
2308 handle: &Handle,
2309 jobz: EigMode,
2310 uplo: Fill,
2311 n: i32,
2312 array_d_a: *mut *mut c_void,
2313 ia: i32,
2314 ja: i32,
2315 desc_a: &MatrixDesc,
2316 w: *mut c_void,
2317 data_type_w: cudaDataType,
2318 compute_type: cudaDataType,
2319 ) -> Result<i64> {
2320 let mg = cusolver_mg()?;
2321 let cu = mg.cusolver_mg_syevd_buffer_size()?;
2322 let mut lwork: i64 = 0;
2323 check(cu(
2324 handle.as_raw(),
2325 jobz,
2326 uplo,
2327 n,
2328 array_d_a,
2329 ia,
2330 ja,
2331 desc_a.as_raw(),
2332 w,
2333 data_type_w,
2334 compute_type,
2335 &mut lwork,
2336 ))?;
2337 Ok(lwork)
2338 }
2339
2340 #[allow(clippy::too_many_arguments)]
2343 pub unsafe fn syevd(
2344 handle: &Handle,
2345 jobz: EigMode,
2346 uplo: Fill,
2347 n: i32,
2348 array_d_a: *mut *mut c_void,
2349 ia: i32,
2350 ja: i32,
2351 desc_a: &MatrixDesc,
2352 w: *mut c_void,
2353 data_type_w: cudaDataType,
2354 compute_type: cudaDataType,
2355 array_d_work: *mut *mut c_void,
2356 lwork: i64,
2357 info: &mut [c_int],
2358 ) -> Result<()> {
2359 let mg = cusolver_mg()?;
2360 let cu = mg.cusolver_mg_syevd()?;
2361 check(cu(
2362 handle.as_raw(),
2363 jobz,
2364 uplo,
2365 n,
2366 array_d_a,
2367 ia,
2368 ja,
2369 desc_a.as_raw(),
2370 w,
2371 data_type_w,
2372 compute_type,
2373 array_d_work,
2374 lwork,
2375 info.as_mut_ptr(),
2376 ))
2377 }
2378}
2379
2380pub fn sgetrf(
2384 handle: &DnHandle,
2385 m: i32,
2386 n: i32,
2387 a: &mut DeviceBuffer<f32>,
2388 lda: i32,
2389 ipiv: &mut DeviceBuffer<i32>,
2390 info: &mut DeviceBuffer<i32>,
2391) -> Result<()> {
2392 getrf::<f32>(handle, m, n, a, lda, ipiv, info)
2393}
2394
2395#[allow(clippy::too_many_arguments)]
2397pub fn sgetrs(
2398 handle: &DnHandle,
2399 trans: Op,
2400 n: i32,
2401 nrhs: i32,
2402 a: &DeviceBuffer<f32>,
2403 lda: i32,
2404 ipiv: &DeviceBuffer<i32>,
2405 b: &mut DeviceBuffer<f32>,
2406 ldb: i32,
2407 info: &mut DeviceBuffer<i32>,
2408) -> Result<()> {
2409 getrs::<f32>(handle, trans, n, nrhs, a, lda, ipiv, b, ldb, info)
2410}
2411
2412pub mod xapi {
2415 use super::*;
2421 use baracuda_cusolver_sys::{cudaDataType, cusolverDnParams_t};
2422
2423 #[derive(Debug)]
2424 pub struct Params {
2425 raw: cusolverDnParams_t,
2426 }
2427
2428 impl Params {
2429 pub fn new() -> Result<Self> {
2430 let c = cusolver()?;
2431 let cu = c.cusolver_dn_create_params()?;
2432 let mut p: cusolverDnParams_t = core::ptr::null_mut();
2433 check(unsafe { cu(&mut p) })?;
2434 Ok(Self { raw: p })
2435 }
2436
2437 pub fn as_raw(&self) -> cusolverDnParams_t {
2438 self.raw
2439 }
2440 }
2441
2442 impl Drop for Params {
2443 fn drop(&mut self) {
2444 if let Ok(c) = cusolver() {
2445 if let Ok(cu) = c.cusolver_dn_destroy_params() {
2446 let _ = unsafe { cu(self.raw) };
2447 }
2448 }
2449 }
2450 }
2451
2452 #[allow(clippy::too_many_arguments)]
2455 pub fn xgetrf_buffer_size(
2456 handle: &DnHandle,
2457 params: &Params,
2458 m: i64,
2459 n: i64,
2460 data_type_a: cudaDataType,
2461 a: *const c_void,
2462 lda: i64,
2463 compute_type: cudaDataType,
2464 ) -> Result<(usize, usize)> {
2465 let c = cusolver()?;
2466 let cu = c.cusolver_dn_xgetrf_buffer_size()?;
2467 let (mut dev, mut host) = (0usize, 0usize);
2468 check(unsafe {
2469 cu(
2470 handle.as_raw(),
2471 params.raw,
2472 m,
2473 n,
2474 data_type_a,
2475 a,
2476 lda,
2477 compute_type,
2478 &mut dev,
2479 &mut host,
2480 )
2481 })?;
2482 Ok((dev, host))
2483 }
2484
2485 #[allow(clippy::too_many_arguments)]
2486 pub unsafe fn xgetrf(
2487 handle: &DnHandle,
2488 params: &Params,
2489 m: i64,
2490 n: i64,
2491 data_type_a: cudaDataType,
2492 a: *mut c_void,
2493 lda: i64,
2494 ipiv: *mut i64,
2495 compute_type: cudaDataType,
2496 device_buf: *mut c_void,
2497 device_bytes: usize,
2498 host_buf: *mut c_void,
2499 host_bytes: usize,
2500 info: *mut c_int,
2501 ) -> Result<()> {
2502 let c = cusolver()?;
2503 let cu = c.cusolver_dn_xgetrf()?;
2504 check(cu(
2505 handle.as_raw(),
2506 params.raw,
2507 m,
2508 n,
2509 data_type_a,
2510 a,
2511 lda,
2512 ipiv,
2513 compute_type,
2514 device_buf,
2515 device_bytes,
2516 host_buf,
2517 host_bytes,
2518 info,
2519 ))
2520 }
2521}
2522
2523pub mod sparse {
2526 use super::*;
2529 use baracuda_cusolver_sys::cusolverSpHandle_t;
2530 use core::ffi::c_int;
2531
2532 #[derive(Debug)]
2533 pub struct SpHandle {
2534 raw: cusolverSpHandle_t,
2535 _not_send: PhantomData<*mut ()>,
2536 }
2537
2538 impl SpHandle {
2539 pub fn new() -> Result<Self> {
2540 let c = cusolver()?;
2541 let cu = c.cusolver_sp_create()?;
2542 let mut h: cusolverSpHandle_t = core::ptr::null_mut();
2543 check(unsafe { cu(&mut h) })?;
2544 Ok(Self {
2545 raw: h,
2546 _not_send: PhantomData,
2547 })
2548 }
2549
2550 pub fn set_stream(&self, stream: &Stream) -> Result<()> {
2551 let c = cusolver()?;
2552 let cu = c.cusolver_sp_set_stream()?;
2553 check(unsafe { cu(self.raw, stream.as_raw() as _) })
2554 }
2555
2556 pub fn as_raw(&self) -> cusolverSpHandle_t {
2557 self.raw
2558 }
2559 }
2560
2561 impl Drop for SpHandle {
2562 fn drop(&mut self) {
2563 if let Ok(c) = cusolver() {
2564 if let Ok(cu) = c.cusolver_sp_destroy() {
2565 let _ = unsafe { cu(self.raw) };
2566 }
2567 }
2568 }
2569 }
2570
2571 #[allow(clippy::too_many_arguments)]
2578 pub unsafe fn scsrlsvchol(
2579 handle: &SpHandle,
2580 m: i32,
2581 nnz: i32,
2582 descr_a: *mut c_void,
2583 csr_val: *const f32,
2584 csr_row_ptr: *const c_int,
2585 csr_col_ind: *const c_int,
2586 b: *const f32,
2587 tol: f32,
2588 reorder: i32,
2589 x: *mut f32,
2590 singularity: *mut c_int,
2591 ) -> Result<()> {
2592 let c = cusolver()?;
2593 let cu = c.cusolver_sp_scsrlsvchol()?;
2594 check(cu(
2595 handle.raw,
2596 m,
2597 nnz,
2598 descr_a,
2599 csr_val,
2600 csr_row_ptr,
2601 csr_col_ind,
2602 b,
2603 tol,
2604 reorder,
2605 x,
2606 singularity,
2607 ))
2608 }
2609
2610 #[allow(clippy::too_many_arguments)]
2615 pub unsafe fn scsrlsvqr(
2616 handle: &SpHandle,
2617 m: i32,
2618 nnz: i32,
2619 descr_a: *mut c_void,
2620 csr_val: *const f32,
2621 csr_row_ptr: *const c_int,
2622 csr_col_ind: *const c_int,
2623 b: *const f32,
2624 tol: f32,
2625 reorder: i32,
2626 x: *mut f32,
2627 singularity: *mut c_int,
2628 ) -> Result<()> {
2629 let c = cusolver()?;
2630 let cu = c.cusolver_sp_scsrlsvqr()?;
2631 check(cu(
2632 handle.raw,
2633 m,
2634 nnz,
2635 descr_a,
2636 csr_val,
2637 csr_row_ptr,
2638 csr_col_ind,
2639 b,
2640 tol,
2641 reorder,
2642 x,
2643 singularity,
2644 ))
2645 }
2646}
2647
2648pub mod refactor {
2651 use super::*;
2655 use baracuda_cusolver_sys::cusolverRfHandle_t;
2656
2657 #[derive(Debug)]
2658 pub struct RfHandle {
2659 raw: cusolverRfHandle_t,
2660 _not_send: PhantomData<*mut ()>,
2661 }
2662
2663 impl RfHandle {
2664 pub fn new() -> Result<Self> {
2665 let c = cusolver()?;
2666 let cu = c.cusolver_rf_create()?;
2667 let mut h: cusolverRfHandle_t = core::ptr::null_mut();
2668 check(unsafe { cu(&mut h) })?;
2669 Ok(Self {
2670 raw: h,
2671 _not_send: PhantomData,
2672 })
2673 }
2674
2675 pub fn as_raw(&self) -> cusolverRfHandle_t {
2676 self.raw
2677 }
2678
2679 pub fn analyze(&self) -> Result<()> {
2680 let c = cusolver()?;
2681 let cu = c.cusolver_rf_analyze()?;
2682 check(unsafe { cu(self.raw) })
2683 }
2684
2685 pub fn refactor(&self) -> Result<()> {
2686 let c = cusolver()?;
2687 let cu = c.cusolver_rf_refactor()?;
2688 check(unsafe { cu(self.raw) })
2689 }
2690 }
2691
2692 impl Drop for RfHandle {
2693 fn drop(&mut self) {
2694 if let Ok(c) = cusolver() {
2695 if let Ok(cu) = c.cusolver_rf_destroy() {
2696 let _ = unsafe { cu(self.raw) };
2697 }
2698 }
2699 }
2700 }
2701}