1pub fn blas_backend_status() -> super::CudaBackendStatus {
18 super::cuda_backend_status()
19}
20
21#[cfg(target_os = "linux")]
22mod cuda_impl {
23 use ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, ArrayView3, Axis};
24
25 use crate::driver::{array_from_row_major, from_col_major, to_col_major, to_i32, to_row_major};
26
27 use super::super::device_runtime::GpuRuntime;
28 use cudarc::cublas::sys::{
29 cublasDiagType_t, cublasFillMode_t, cublasOperation_t, cublasSideMode_t, cublasStatus_t,
30 };
31 use cudarc::cublas::{CudaBlas, Gemm, GemmConfig, Gemv, GemvConfig, StridedBatchedConfig};
32 use cudarc::cusolver::{DnHandle, sys as cusolver_sys};
33 use cudarc::driver::{CudaSlice, CudaStream, DevicePtr, DevicePtrMut};
34 use std::sync::Arc;
35
36 #[inline]
43 pub(crate) fn stream_and_blas_for(ordinal: usize) -> Option<(Arc<CudaStream>, CudaBlas)> {
44 let stream = super::super::device_runtime::cuda_context_for(ordinal)?
45 .new_stream()
46 .ok()?;
47 let blas = CudaBlas::new(stream.clone()).ok()?;
48 Some((stream, blas))
49 }
50
51 #[inline]
52 fn stream_and_blas(runtime: &GpuRuntime) -> Option<(Arc<CudaStream>, CudaBlas)> {
53 stream_and_blas_for(runtime.device.ordinal)
54 }
55
56 #[inline]
57 fn vector_values(v: ArrayView1<'_, f64>) -> Vec<f64> {
58 v.iter().copied().collect()
59 }
60
61 #[inline]
62 fn to_col_major_batch(batch: ArrayView3<'_, f64>) -> Vec<f64> {
63 let (batch_len, rows, cols) = batch.dim();
64 let mut out = Vec::with_capacity(batch_len.saturating_mul(rows).saturating_mul(cols));
65 for matrix in batch.axis_iter(Axis(0)) {
66 out.extend(to_col_major(&matrix).iter().copied());
67 }
68 out
69 }
70
71 #[inline]
72 fn from_col_major_batch(
73 data: &[f64],
74 batch: usize,
75 rows: usize,
76 cols: usize,
77 ) -> Option<Array3<f64>> {
78 if data.len() != batch.checked_mul(rows)?.checked_mul(cols)? {
79 return None;
80 }
81 let mut out = Array3::<f64>::zeros((batch, rows, cols));
82 let matrix_len = rows.checked_mul(cols)?;
83 for batch_idx in 0..batch {
84 let base = batch_idx.checked_mul(matrix_len)?;
85 for col in 0..cols {
86 for row in 0..rows {
87 out[[batch_idx, row, col]] = data[base + col * rows + row];
88 }
89 }
90 }
91 Some(out)
92 }
93
94 #[inline]
95 fn row_scale_device(
96 blas: &CudaBlas,
97 stream: &Arc<CudaStream>,
98 matrix_dev: &CudaSlice<f64>,
99 weights_dev: &CudaSlice<f64>,
100 scaled_dev: &mut CudaSlice<f64>,
101 rows: usize,
102 cols: usize,
103 ) -> Option<()> {
104 let rows_i = to_i32(rows)?;
105 let cols_i = to_i32(cols)?;
106 let handle = *blas.handle();
107 let (matrix_ptr, _matrix_record) = matrix_dev.device_ptr(stream);
108 let (weights_ptr, _weights_record) = weights_dev.device_ptr(stream);
109 let (scaled_ptr, _scaled_record) = scaled_dev.device_ptr_mut(stream);
110 let status = unsafe {
114 cudarc::cublas::sys::cublasDdgmm(
115 handle,
116 cublasSideMode_t::CUBLAS_SIDE_LEFT,
117 rows_i,
118 cols_i,
119 matrix_ptr as *const f64,
120 rows_i,
121 weights_ptr as *const f64,
122 1,
123 scaled_ptr as *mut f64,
124 rows_i,
125 )
126 };
127 if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
128 Some(())
129 } else {
130 None
131 }
132 }
133
134 #[inline]
135 fn weighted_crossprod(
136 runtime: &GpuRuntime,
137 left: ArrayView2<'_, f64>,
138 weights: ArrayView1<'_, f64>,
139 right: ArrayView2<'_, f64>,
140 ) -> Option<Array2<f64>> {
141 weighted_crossprod_for(runtime.device.ordinal, left, weights, right)
142 }
143
144 #[inline]
145 fn weighted_crossprod_for(
146 ordinal: usize,
147 left: ArrayView2<'_, f64>,
148 weights: ArrayView1<'_, f64>,
149 right: ArrayView2<'_, f64>,
150 ) -> Option<Array2<f64>> {
151 let (rows, left_cols) = left.dim();
152 let (right_rows, right_cols) = right.dim();
153 if rows == 0
154 || left_cols == 0
155 || right_cols == 0
156 || rows != right_rows
157 || rows != weights.len()
158 {
159 return None;
160 }
161
162 let (stream, blas) = stream_and_blas_for(ordinal)?;
163 let same_operand = std::ptr::eq(left.as_ptr(), right.as_ptr())
170 && left.dim() == right.dim()
171 && left.strides() == right.strides();
172 let left_col = to_col_major(&left);
173 let weights_host = vector_values(weights);
174 let left_dev = stream.clone_htod(&*left_col).ok()?;
175 let right_dev = if same_operand {
179 None
180 } else {
181 let right_col = to_col_major(&right);
182 Some(stream.clone_htod(&*right_col).ok()?)
183 };
184 let weights_dev = stream.clone_htod(&weights_host).ok()?;
185 let mut weighted_right_dev = stream
186 .alloc_zeros::<f64>(rows.checked_mul(right_cols)?)
187 .ok()?;
188 row_scale_device(
189 &blas,
190 &stream,
191 right_dev.as_ref().unwrap_or(&left_dev),
192 &weights_dev,
193 &mut weighted_right_dev,
194 rows,
195 right_cols,
196 )?;
197
198 let mut out_dev = stream
199 .alloc_zeros::<f64>(left_cols.checked_mul(right_cols)?)
200 .ok()?;
201 let cfg = GemmConfig::<f64> {
202 transa: cublasOperation_t::CUBLAS_OP_T,
203 transb: cublasOperation_t::CUBLAS_OP_N,
204 m: to_i32(left_cols)?,
205 n: to_i32(right_cols)?,
206 k: to_i32(rows)?,
207 alpha: 1.0,
208 lda: to_i32(rows)?,
209 ldb: to_i32(rows)?,
210 beta: 0.0,
211 ldc: to_i32(left_cols)?,
212 };
213 unsafe { blas.gemm(cfg, &left_dev, &weighted_right_dev, &mut out_dev) }.ok()?;
216 let out_col = stream.clone_dtoh(&out_dev).ok()?;
217 from_col_major(&out_col, left_cols, right_cols)
218 }
219
220 pub(crate) struct ResidentWeightedGram {
235 stream: Arc<CudaStream>,
236 blas: CudaBlas,
237 x_dev: CudaSlice<f64>,
238 rows: usize,
239 cols: usize,
240 }
241
242 impl ResidentWeightedGram {
243 pub(crate) fn new(ordinal: usize, x: ArrayView2<'_, f64>) -> Option<Self> {
247 let (rows, cols) = x.dim();
248 if rows == 0 || cols == 0 {
249 return None;
250 }
251 let (stream, blas) = stream_and_blas_for(ordinal)?;
252 let x_col = to_col_major(&x);
253 let x_dev = stream.clone_htod(&*x_col).ok()?;
254 Some(Self {
255 stream,
256 blas,
257 x_dev,
258 rows,
259 cols,
260 })
261 }
262
263 #[inline]
264 pub(crate) fn dims(&self) -> (usize, usize) {
265 (self.rows, self.cols)
266 }
267
268 pub(crate) fn gram(&self, w: ArrayView1<'_, f64>) -> Option<Array2<f64>> {
273 if w.len() != self.rows {
274 return None;
275 }
276 let weights_host = vector_values(w);
277 let weights_dev = self.stream.clone_htod(&weights_host).ok()?;
278 let mut weighted_dev = self
279 .stream
280 .alloc_zeros::<f64>(self.rows.checked_mul(self.cols)?)
281 .ok()?;
282 row_scale_device(
283 &self.blas,
284 &self.stream,
285 &self.x_dev,
286 &weights_dev,
287 &mut weighted_dev,
288 self.rows,
289 self.cols,
290 )?;
291 let mut out_dev = self
292 .stream
293 .alloc_zeros::<f64>(self.cols.checked_mul(self.cols)?)
294 .ok()?;
295 let cfg = GemmConfig::<f64> {
296 transa: cublasOperation_t::CUBLAS_OP_T,
297 transb: cublasOperation_t::CUBLAS_OP_N,
298 m: to_i32(self.cols)?,
299 n: to_i32(self.cols)?,
300 k: to_i32(self.rows)?,
301 alpha: 1.0,
302 lda: to_i32(self.rows)?,
303 ldb: to_i32(self.rows)?,
304 beta: 0.0,
305 ldc: to_i32(self.cols)?,
306 };
307 unsafe {
310 self.blas
311 .gemm(cfg, &self.x_dev, &weighted_dev, &mut out_dev)
312 }
313 .ok()?;
314 let out_col = self.stream.clone_dtoh(&out_dev).ok()?;
315 from_col_major(&out_col, self.cols, self.cols)
316 }
317
318 pub(crate) fn solve_psd_normal_equations(
337 &self,
338 w: ArrayView1<'_, f64>,
339 rhs: ArrayView1<'_, f64>,
340 ridge: f64,
341 ) -> Option<Array1<f64>> {
342 if w.len() != self.rows || rhs.len() != self.cols {
343 return None;
344 }
345 let p = self.cols;
346
347 let weights_dev = self.stream.clone_htod(&vector_values(w)).ok()?;
349 let mut weighted_dev = self
350 .stream
351 .alloc_zeros::<f64>(self.rows.checked_mul(p)?)
352 .ok()?;
353 row_scale_device(
354 &self.blas,
355 &self.stream,
356 &self.x_dev,
357 &weights_dev,
358 &mut weighted_dev,
359 self.rows,
360 p,
361 )?;
362
363 let mut ridge_init = vec![0.0_f64; p.checked_mul(p)?];
370 for i in 0..p {
371 ridge_init[i * p + i] = ridge;
372 }
373 let mut g_dev = self.stream.clone_htod(&ridge_init).ok()?;
374 let cfg = GemmConfig::<f64> {
375 transa: cublasOperation_t::CUBLAS_OP_T,
376 transb: cublasOperation_t::CUBLAS_OP_N,
377 m: to_i32(p)?,
378 n: to_i32(p)?,
379 k: to_i32(self.rows)?,
380 alpha: 1.0,
381 lda: to_i32(self.rows)?,
382 ldb: to_i32(self.rows)?,
383 beta: 1.0,
385 ldc: to_i32(p)?,
386 };
387 unsafe { self.blas.gemm(cfg, &self.x_dev, &weighted_dev, &mut g_dev) }.ok()?;
390
391 let solver = DnHandle::new(self.stream.clone()).ok()?;
393 let info = potrf_single_dev(&solver, &self.stream, p, &mut g_dev)?;
394 if info != 0 {
395 return None;
397 }
398
399 let mut rhs_dev = self.stream.clone_htod(&vector_values(rhs)).ok()?;
401 trsm_single_vec(&self.blas, &self.stream, p, &g_dev, &mut rhs_dev, false)?; trsm_single_vec(&self.blas, &self.stream, p, &g_dev, &mut rhs_dev, true)?; let beta_host = self.stream.clone_dtoh(&rhs_dev).ok()?;
406 Some(Array1::from_vec(beta_host))
407 }
408 }
409
410 fn potrf_single_dev(
414 solver: &DnHandle,
415 stream: &Arc<CudaStream>,
416 p: usize,
417 matrix: &mut CudaSlice<f64>,
418 ) -> Option<i32> {
419 let p_i = to_i32(p)?;
420 let uplo = cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER;
421 let mut lwork = 0_i32;
422 {
423 let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
424 let status = unsafe {
426 cusolver_sys::cusolverDnDpotrf_bufferSize(
427 solver.cu(),
428 uplo,
429 p_i,
430 mat_ptr as *mut f64,
431 p_i,
432 &mut lwork,
433 )
434 };
435 if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
436 return None;
437 }
438 }
439 let mut workspace = stream.alloc_zeros::<f64>(lwork.max(1) as usize).ok()?;
440 let mut info_dev = stream.alloc_zeros::<i32>(1).ok()?;
441 {
442 let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
443 let (work_ptr, _wrec) = workspace.device_ptr_mut(stream);
444 let (info_ptr, _irec) = info_dev.device_ptr_mut(stream);
445 let status = unsafe {
447 cusolver_sys::cusolverDnDpotrf(
448 solver.cu(),
449 uplo,
450 p_i,
451 mat_ptr as *mut f64,
452 p_i,
453 work_ptr as *mut f64,
454 lwork,
455 info_ptr as *mut i32,
456 )
457 };
458 if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
459 return None;
460 }
461 }
462 let info_host = stream.clone_dtoh(&info_dev).ok()?;
463 info_host.first().copied()
464 }
465
466 fn trsm_single_vec(
470 blas: &CudaBlas,
471 stream: &Arc<CudaStream>,
472 p: usize,
473 l: &CudaSlice<f64>,
474 rhs: &mut CudaSlice<f64>,
475 transposed: bool,
476 ) -> Option<()> {
477 let alpha = 1.0_f64;
478 let p_i = to_i32(p)?;
479 let handle = *blas.handle();
480 let (l_ptr, _l_rec) = l.device_ptr(stream);
481 let (rhs_ptr, _rhs_rec) = rhs.device_ptr_mut(stream);
482 let status = unsafe {
484 cudarc::cublas::sys::cublasDtrsm_v2(
485 handle,
486 cublasSideMode_t::CUBLAS_SIDE_LEFT,
487 cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
488 if transposed {
489 cublasOperation_t::CUBLAS_OP_T
490 } else {
491 cublasOperation_t::CUBLAS_OP_N
492 },
493 cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
494 p_i,
495 1,
496 &alpha,
497 l_ptr as *const f64,
498 p_i,
499 rhs_ptr as *mut f64,
500 p_i,
501 )
502 };
503 if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
504 Some(())
505 } else {
506 None
507 }
508 }
509
510 #[inline]
511 fn assign_block(
512 out: &mut Array2<f64>,
513 row_offset: usize,
514 col_offset: usize,
515 block: &Array2<f64>,
516 ) {
517 let (rows, cols) = block.dim();
518 for col in 0..cols {
519 for row in 0..rows {
520 out[[row_offset + row, col_offset + col]] = block[[row, col]];
521 }
522 }
523 }
524
525 #[inline]
526 fn mirror_upper_to_lower(out: &mut Array2<f64>) {
527 let n = out.nrows();
528 for row in 0..n {
529 for col in 0..row {
530 out[[row, col]] = out[[col, row]];
531 }
532 }
533 }
534
535 #[inline]
536 pub(crate) fn gemm_cuda(
537 runtime: &GpuRuntime,
538 a: ArrayView2<'_, f64>,
539 b: ArrayView2<'_, f64>,
540 trans_a: bool,
541 trans_b: bool,
542 ) -> Option<Array2<f64>> {
543 gemm_on_ordinal_cuda(runtime.device.ordinal, a, b, trans_a, trans_b)
544 }
545
546 #[inline]
552 pub(crate) fn gemm_on_ordinal_cuda(
553 ordinal: usize,
554 a: ArrayView2<'_, f64>,
555 b: ArrayView2<'_, f64>,
556 trans_a: bool,
557 trans_b: bool,
558 ) -> Option<Array2<f64>> {
559 let (a_rows, a_cols) = a.dim();
560 let (b_rows, b_cols) = b.dim();
561 let (m, k_a) = if trans_a {
562 (a_cols, a_rows)
563 } else {
564 (a_rows, a_cols)
565 };
566 let (k_b, n) = if trans_b {
567 (b_cols, b_rows)
568 } else {
569 (b_rows, b_cols)
570 };
571 if m == 0 || n == 0 || k_a == 0 || k_a != k_b {
572 return None;
573 }
574 let (stream, blas) = stream_and_blas_for(ordinal)?;
575 let b_rm = to_row_major(&b);
597 let a_rm = to_row_major(&a);
598 let x_dev = stream.clone_htod(&*b_rm).ok()?;
599 let y_dev = stream.clone_htod(&*a_rm).ok()?;
600 let mut out_dev = stream.alloc_zeros::<f64>(m.checked_mul(n)?).ok()?;
601 let cfg = GemmConfig::<f64> {
602 transa: if trans_b {
603 cublasOperation_t::CUBLAS_OP_T
604 } else {
605 cublasOperation_t::CUBLAS_OP_N
606 },
607 transb: if trans_a {
608 cublasOperation_t::CUBLAS_OP_T
609 } else {
610 cublasOperation_t::CUBLAS_OP_N
611 },
612 m: to_i32(n)?,
613 n: to_i32(m)?,
614 k: to_i32(k_a)?,
615 alpha: 1.0,
616 lda: to_i32(b_cols)?,
619 ldb: to_i32(a_cols)?,
620 beta: 0.0,
621 ldc: to_i32(n)?,
622 };
623 unsafe { blas.gemm(cfg, &x_dev, &y_dev, &mut out_dev) }.ok()?;
626 let out_rm = stream.clone_dtoh(&out_dev).ok()?;
628 array_from_row_major(out_rm, m, n)
629 }
630
631 #[inline]
636 pub(crate) fn gemm_broadcast_b_batched_cuda(
637 ordinal: usize,
638 a: ArrayView3<'_, f64>,
639 b: ArrayView2<'_, f64>,
640 ) -> Option<Array3<f64>> {
641 let (batch, m, k) = a.dim();
642 let (b_rows, n) = b.dim();
643 if batch == 0 || m == 0 || n == 0 || k == 0 || b_rows != k {
644 return None;
645 }
646 let (stream, blas) = stream_and_blas_for(ordinal)?;
647 let a_col = to_col_major_batch(a);
648 let b_col = to_col_major(&b);
649 let a_dev = stream.clone_htod(&a_col).ok()?;
650 let b_dev = stream.clone_htod(&*b_col).ok()?;
651 let mut out_dev = stream
652 .alloc_zeros::<f64>(batch.checked_mul(m)?.checked_mul(n)?)
653 .ok()?;
654 let cfg = StridedBatchedConfig::<f64> {
655 gemm: GemmConfig::<f64> {
656 transa: cublasOperation_t::CUBLAS_OP_N,
657 transb: cublasOperation_t::CUBLAS_OP_N,
658 m: to_i32(m)?,
659 n: to_i32(n)?,
660 k: to_i32(k)?,
661 alpha: 1.0,
662 lda: to_i32(m)?,
663 ldb: to_i32(k)?,
664 beta: 0.0,
665 ldc: to_i32(m)?,
666 },
667 batch_size: to_i32(batch)?,
668 stride_a: i64::try_from(m.checked_mul(k)?).ok()?,
669 stride_b: 0,
670 stride_c: i64::try_from(m.checked_mul(n)?).ok()?,
671 };
672 unsafe { blas.gemm_strided_batched(cfg, &a_dev, &b_dev, &mut out_dev) }.ok()?;
676 let out_col = stream.clone_dtoh(&out_dev).ok()?;
677 from_col_major_batch(&out_col, batch, m, n)
678 }
679
680 #[inline]
684 pub(crate) fn gemm_abt_strided_batched_cuda(
685 ordinal: usize,
686 a: ArrayView3<'_, f64>,
687 b: ArrayView3<'_, f64>,
688 ) -> Option<Array3<f64>> {
689 let (batch, m, k) = a.dim();
690 let (batch_b, n, k_b) = b.dim();
691 if batch == 0 || m == 0 || n == 0 || k == 0 || batch != batch_b || k != k_b {
692 return None;
693 }
694 let (stream, blas) = stream_and_blas_for(ordinal)?;
695 let a_col = to_col_major_batch(a);
696 let b_col = to_col_major_batch(b);
697 let a_dev = stream.clone_htod(&a_col).ok()?;
698 let b_dev = stream.clone_htod(&b_col).ok()?;
699 let mut out_dev = stream
700 .alloc_zeros::<f64>(batch.checked_mul(m)?.checked_mul(n)?)
701 .ok()?;
702 let cfg = StridedBatchedConfig::<f64> {
703 gemm: GemmConfig::<f64> {
704 transa: cublasOperation_t::CUBLAS_OP_N,
705 transb: cublasOperation_t::CUBLAS_OP_T,
706 m: to_i32(m)?,
707 n: to_i32(n)?,
708 k: to_i32(k)?,
709 alpha: 1.0,
710 lda: to_i32(m)?,
711 ldb: to_i32(n)?,
712 beta: 0.0,
713 ldc: to_i32(m)?,
714 },
715 batch_size: to_i32(batch)?,
716 stride_a: i64::try_from(m.checked_mul(k)?).ok()?,
717 stride_b: i64::try_from(n.checked_mul(k)?).ok()?,
718 stride_c: i64::try_from(m.checked_mul(n)?).ok()?,
719 };
720 unsafe { blas.gemm_strided_batched(cfg, &a_dev, &b_dev, &mut out_dev) }.ok()?;
723 let out_col = stream.clone_dtoh(&out_dev).ok()?;
724 from_col_major_batch(&out_col, batch, m, n)
725 }
726
727 #[inline]
728 pub(crate) fn gemv_cuda(
729 runtime: &GpuRuntime,
730 a: ArrayView2<'_, f64>,
731 v: ArrayView1<'_, f64>,
732 trans_a: bool,
733 ) -> Option<Array1<f64>> {
734 let (rows, cols) = a.dim();
735 let out_len = if trans_a { cols } else { rows };
736 let needed = if trans_a { rows } else { cols };
737 if out_len == 0 || needed == 0 || v.len() != needed {
738 return None;
739 }
740 let (stream, blas) = stream_and_blas(runtime)?;
741 let a_col = to_col_major(&a);
742 let a_dev = stream.clone_htod(&*a_col).ok()?;
743 let v_host = vector_values(v);
744 let v_dev = stream.clone_htod(&v_host).ok()?;
745 let mut out_dev = stream.alloc_zeros::<f64>(out_len).ok()?;
746 let cfg = GemvConfig::<f64> {
747 trans: if trans_a {
748 cublasOperation_t::CUBLAS_OP_T
749 } else {
750 cublasOperation_t::CUBLAS_OP_N
751 },
752 m: to_i32(rows)?,
753 n: to_i32(cols)?,
754 alpha: 1.0,
755 lda: to_i32(rows)?,
756 incx: 1,
757 beta: 0.0,
758 incy: 1,
759 };
760 unsafe { blas.gemv(cfg, &a_dev, &v_dev, &mut out_dev) }.ok()?;
762 Some(Array1::from_vec(stream.clone_dtoh(&out_dev).ok()?))
763 }
764
765 #[inline]
766 pub fn xt_diag_x_cuda(
767 runtime: &GpuRuntime,
768 x: ArrayView2<'_, f64>,
769 w: ArrayView1<'_, f64>,
770 ) -> Option<Array2<f64>> {
771 let (rows, cols) = x.dim();
772 if rows == 0 || cols == 0 || rows != w.len() {
773 return None;
774 }
775 weighted_crossprod(runtime, x, w, x)
776 }
777
778 #[inline]
779 pub(crate) fn xt_diag_x_on_ordinal_cuda(
780 ordinal: usize,
781 x: ArrayView2<'_, f64>,
782 w: ArrayView1<'_, f64>,
783 ) -> Option<Array2<f64>> {
784 let (rows, cols) = x.dim();
785 if rows == 0 || cols == 0 || rows != w.len() {
786 return None;
787 }
788 weighted_crossprod_for(ordinal, x, w, x)
789 }
790
791 #[inline]
792 pub fn xt_diag_y_cuda(
793 runtime: &GpuRuntime,
794 x: ArrayView2<'_, f64>,
795 w: ArrayView1<'_, f64>,
796 y: ArrayView2<'_, f64>,
797 ) -> Option<Array2<f64>> {
798 weighted_crossprod(runtime, x, w, y)
799 }
800
801 #[inline]
802 pub(crate) fn joint_hessian_2x2_cuda(
803 runtime: &GpuRuntime,
804 x_a: ArrayView2<'_, f64>,
805 x_b: ArrayView2<'_, f64>,
806 w_aa: ArrayView1<'_, f64>,
807 w_ab: ArrayView1<'_, f64>,
808 w_bb: ArrayView1<'_, f64>,
809 ) -> Option<Array2<f64>> {
810 let (rows, pa) = x_a.dim();
811 let (rows_b, pb) = x_b.dim();
812 let total = pa.checked_add(pb)?;
813 if rows == 0
814 || total == 0
815 || rows != rows_b
816 || rows != w_aa.len()
817 || rows != w_ab.len()
818 || rows != w_bb.len()
819 {
820 return None;
821 }
822
823 let mut out = Array2::<f64>::zeros((total, total));
824 if pa > 0 {
825 let aa = weighted_crossprod(runtime, x_a, w_aa, x_a)?;
826 assign_block(&mut out, 0, 0, &aa);
827 }
828 if pa > 0 && pb > 0 {
829 let ab = weighted_crossprod(runtime, x_a, w_ab, x_b)?;
830 assign_block(&mut out, 0, pa, &ab);
831 }
832 if pb > 0 {
833 let bb = weighted_crossprod(runtime, x_b, w_bb, x_b)?;
834 assign_block(&mut out, pa, pa, &bb);
835 }
836 mirror_upper_to_lower(&mut out);
837 Some(out)
838 }
839
840 #[inline]
841 pub(crate) fn trsm_cuda(
842 runtime: &GpuRuntime,
843 triangular: ArrayView2<'_, f64>,
844 rhs: ArrayView2<'_, f64>,
845 upper: bool,
846 ) -> Option<Array2<f64>> {
847 let (n, n2) = triangular.dim();
848 if n == 0 || n != n2 || rhs.nrows() != n {
849 return None;
850 }
851 let nrhs = rhs.ncols();
852 let (stream, blas) = stream_and_blas(runtime)?;
853 let tri_col = to_col_major(&triangular);
854 let rhs_col = to_col_major(&rhs);
855 let tri_dev = stream.clone_htod(&*tri_col).ok()?;
856 let mut rhs_dev = stream.clone_htod(&*rhs_col).ok()?;
857 let alpha = 1.0_f64;
858 let handle = *blas.handle();
859 {
860 let (tri_ptr, _tri_record) = tri_dev.device_ptr(&stream);
861 let (rhs_ptr, _rhs_record) = rhs_dev.device_ptr_mut(&stream);
862 let status = unsafe {
865 cudarc::cublas::sys::cublasDtrsm_v2(
866 handle,
867 cublasSideMode_t::CUBLAS_SIDE_LEFT,
868 if upper {
869 cublasFillMode_t::CUBLAS_FILL_MODE_UPPER
870 } else {
871 cublasFillMode_t::CUBLAS_FILL_MODE_LOWER
872 },
873 cublasOperation_t::CUBLAS_OP_N,
874 cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
875 to_i32(n)?,
876 to_i32(nrhs)?,
877 &alpha,
878 tri_ptr as *const f64,
879 to_i32(n)?,
880 rhs_ptr as *mut f64,
881 to_i32(n)?,
882 )
883 };
884 if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
885 return None;
886 }
887 };
888 let out_col = stream.clone_dtoh(&rhs_dev).ok()?;
889 from_col_major(&out_col, n, nrhs)
890 }
891}
892
893#[cfg(target_os = "linux")]
894pub(crate) use cuda_impl::{
895 ResidentWeightedGram, gemm_abt_strided_batched_cuda, gemm_broadcast_b_batched_cuda, gemm_cuda,
896 gemm_on_ordinal_cuda, gemv_cuda, joint_hessian_2x2_cuda, trsm_cuda, xt_diag_x_on_ordinal_cuda,
897};
898#[cfg(target_os = "linux")]
905pub use cuda_impl::{xt_diag_x_cuda, xt_diag_y_cuda};