/// Returns the CUDA backend status for the BLAS entry points in this file.
///
/// This is a thin forwarder: it performs no work of its own and simply returns whatever
/// the parent module's `cuda_backend_status()` reports.
pub fn blas_backend_status() -> super::CudaBackendStatus {
    super::cuda_backend_status()
}
20
21#[cfg(target_os = "linux")]
22mod cuda_impl {
    use ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, ArrayView3, Axis};

    use crate::driver::{
        array_from_row_major, from_col_major, to_col_major, to_i32, to_row_major,
    };

    use super::super::device_runtime::GpuRuntime;
    use cudarc::cublas::sys::{
        cublasDiagType_t, cublasFillMode_t, cublasOperation_t, cublasSideMode_t, cublasStatus_t,
    };
    use cudarc::cublas::{CudaBlas, Gemm, GemmConfig, Gemv, GemvConfig, StridedBatchedConfig};
    use cudarc::cusolver::{DnHandle, sys as cusolver_sys};
    use cudarc::driver::{CudaSlice, CudaStream, DevicePtr, DevicePtrMut, DeviceSlice};
    use std::sync::Arc;
37
38 #[inline]
45 pub(crate) fn stream_and_blas_for(ordinal: usize) -> Option<(Arc<CudaStream>, CudaBlas)> {
46 let stream = super::super::device_runtime::cuda_context_for(ordinal)?
47 .new_stream()
48 .ok()?;
49 let blas = CudaBlas::new(stream.clone()).ok()?;
50 Some((stream, blas))
51 }
52
53 #[inline]
54 fn stream_and_blas(runtime: &GpuRuntime) -> Option<(Arc<CudaStream>, CudaBlas)> {
55 stream_and_blas_for(runtime.device.ordinal)
56 }
57
58 #[inline]
59 fn vector_values(v: ArrayView1<'_, f64>) -> Vec<f64> {
60 v.iter().copied().collect()
61 }
62
63 #[inline]
64 fn to_col_major_batch(batch: ArrayView3<'_, f64>) -> Vec<f64> {
65 let (batch_len, rows, cols) = batch.dim();
66 let mut out = Vec::with_capacity(batch_len.saturating_mul(rows).saturating_mul(cols));
67 for matrix in batch.axis_iter(Axis(0)) {
68 out.extend(to_col_major(&matrix).iter().copied());
69 }
70 out
71 }
72
73 #[inline]
74 fn from_col_major_batch(
75 data: &[f64],
76 batch: usize,
77 rows: usize,
78 cols: usize,
79 ) -> Option<Array3<f64>> {
80 if data.len() != batch.checked_mul(rows)?.checked_mul(cols)? {
81 return None;
82 }
83 let mut out = Array3::<f64>::zeros((batch, rows, cols));
84 let matrix_len = rows.checked_mul(cols)?;
85 for batch_idx in 0..batch {
86 let base = batch_idx.checked_mul(matrix_len)?;
87 for col in 0..cols {
88 for row in 0..rows {
89 out[[batch_idx, row, col]] = data[base + col * rows + row];
90 }
91 }
92 }
93 Some(out)
94 }
95
96 #[inline]
97 fn row_scale_device(
98 blas: &CudaBlas,
99 stream: &Arc<CudaStream>,
100 matrix_dev: &CudaSlice<f64>,
101 weights_dev: &CudaSlice<f64>,
102 scaled_dev: &mut CudaSlice<f64>,
103 rows: usize,
104 cols: usize,
105 ) -> Option<()> {
106 let rows_i = to_i32(rows)?;
107 let cols_i = to_i32(cols)?;
108 let handle = *blas.handle();
109 let (matrix_ptr, _matrix_record) = matrix_dev.device_ptr(stream);
110 let (weights_ptr, _weights_record) = weights_dev.device_ptr(stream);
111 let (scaled_ptr, _scaled_record) = scaled_dev.device_ptr_mut(stream);
112 let status = unsafe {
116 cudarc::cublas::sys::cublasDdgmm(
117 handle,
118 cublasSideMode_t::CUBLAS_SIDE_LEFT,
119 rows_i,
120 cols_i,
121 matrix_ptr as *const f64,
122 rows_i,
123 weights_ptr as *const f64,
124 1,
125 scaled_ptr as *mut f64,
126 rows_i,
127 )
128 };
129 if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
130 Some(())
131 } else {
132 None
133 }
134 }
135
136 #[inline]
137 fn weighted_crossprod(
138 runtime: &GpuRuntime,
139 left: ArrayView2<'_, f64>,
140 weights: ArrayView1<'_, f64>,
141 right: ArrayView2<'_, f64>,
142 ) -> Option<Array2<f64>> {
143 weighted_crossprod_for(runtime.device.ordinal, left, weights, right)
144 }
145
146 #[inline]
147 fn weighted_crossprod_for(
148 ordinal: usize,
149 left: ArrayView2<'_, f64>,
150 weights: ArrayView1<'_, f64>,
151 right: ArrayView2<'_, f64>,
152 ) -> Option<Array2<f64>> {
153 let (rows, left_cols) = left.dim();
154 let (right_rows, right_cols) = right.dim();
155 if rows == 0
156 || left_cols == 0
157 || right_cols == 0
158 || rows != right_rows
159 || rows != weights.len()
160 {
161 return None;
162 }
163
164 let (stream, blas) = stream_and_blas_for(ordinal)?;
165 let same_operand = std::ptr::eq(left.as_ptr(), right.as_ptr())
172 && left.dim() == right.dim()
173 && left.strides() == right.strides();
174 let left_col = to_col_major(&left);
175 let weights_host = vector_values(weights);
176 let left_dev = stream.clone_htod(&*left_col).ok()?;
177 let right_dev = if same_operand {
181 None
182 } else {
183 let right_col = to_col_major(&right);
184 Some(stream.clone_htod(&*right_col).ok()?)
185 };
186 let weights_dev = stream.clone_htod(&weights_host).ok()?;
187 let mut weighted_right_dev = stream
188 .alloc_zeros::<f64>(rows.checked_mul(right_cols)?)
189 .ok()?;
190 row_scale_device(
191 &blas,
192 &stream,
193 right_dev.as_ref().unwrap_or(&left_dev),
194 &weights_dev,
195 &mut weighted_right_dev,
196 rows,
197 right_cols,
198 )?;
199
200 let mut out_dev = stream
201 .alloc_zeros::<f64>(left_cols.checked_mul(right_cols)?)
202 .ok()?;
203 let cfg = GemmConfig::<f64> {
204 transa: cublasOperation_t::CUBLAS_OP_T,
205 transb: cublasOperation_t::CUBLAS_OP_N,
206 m: to_i32(left_cols)?,
207 n: to_i32(right_cols)?,
208 k: to_i32(rows)?,
209 alpha: 1.0,
210 lda: to_i32(rows)?,
211 ldb: to_i32(rows)?,
212 beta: 0.0,
213 ldc: to_i32(left_cols)?,
214 };
215 unsafe { blas.gemm(cfg, &left_dev, &weighted_right_dev, &mut out_dev) }.ok()?;
218 let out_col = stream.clone_dtoh(&out_dev).ok()?;
219 from_col_major(&out_col, left_cols, right_cols)
220 }
221
    /// A matrix `X` (`rows × cols`) uploaded once to a CUDA device in column-major order.
    ///
    /// Repeated weighted Gram products `Xᵀ · diag(w) · X` then only need to transfer the
    /// weight vector `w`.
    pub(crate) struct ResidentWeightedGram {
        /// Stream on which all uploads, cuBLAS calls and downloads for this matrix are issued.
        stream: Arc<CudaStream>,
        /// cuBLAS handle bound to `stream` (both come from `stream_and_blas_for`).
        blas: CudaBlas,
        /// Device copy of `X`, column-major with leading dimension `rows`.
        x_dev: CudaSlice<f64>,
        /// Number of rows of `X`; weight vectors passed to the methods must have this length.
        rows: usize,
        /// Number of columns of `X` (the Gram matrix is `cols × cols`).
        cols: usize,
    }
243
244 impl ResidentWeightedGram {
245 pub(crate) fn new(ordinal: usize, x: ArrayView2<'_, f64>) -> Option<Self> {
249 let (rows, cols) = x.dim();
250 if rows == 0 || cols == 0 {
251 return None;
252 }
253 let (stream, blas) = stream_and_blas_for(ordinal)?;
254 let x_col = to_col_major(&x);
255 let x_dev = stream.clone_htod(&*x_col).ok()?;
256 Some(Self {
257 stream,
258 blas,
259 x_dev,
260 rows,
261 cols,
262 })
263 }
264
265 #[inline]
266 pub(crate) fn dims(&self) -> (usize, usize) {
267 (self.rows, self.cols)
268 }
269
270 pub(crate) fn gram(&self, w: ArrayView1<'_, f64>) -> Option<Array2<f64>> {
275 if w.len() != self.rows {
276 return None;
277 }
278 let weights_host = vector_values(w);
279 let weights_dev = self.stream.clone_htod(&weights_host).ok()?;
280 let mut weighted_dev = self
281 .stream
282 .alloc_zeros::<f64>(self.rows.checked_mul(self.cols)?)
283 .ok()?;
284 row_scale_device(
285 &self.blas,
286 &self.stream,
287 &self.x_dev,
288 &weights_dev,
289 &mut weighted_dev,
290 self.rows,
291 self.cols,
292 )?;
293 let mut out_dev = self
294 .stream
295 .alloc_zeros::<f64>(self.cols.checked_mul(self.cols)?)
296 .ok()?;
297 let cfg = GemmConfig::<f64> {
298 transa: cublasOperation_t::CUBLAS_OP_T,
299 transb: cublasOperation_t::CUBLAS_OP_N,
300 m: to_i32(self.cols)?,
301 n: to_i32(self.cols)?,
302 k: to_i32(self.rows)?,
303 alpha: 1.0,
304 lda: to_i32(self.rows)?,
305 ldb: to_i32(self.rows)?,
306 beta: 0.0,
307 ldc: to_i32(self.cols)?,
308 };
309 unsafe {
312 self.blas
313 .gemm(cfg, &self.x_dev, &weighted_dev, &mut out_dev)
314 }
315 .ok()?;
316 let out_col = self.stream.clone_dtoh(&out_dev).ok()?;
317 from_col_major(&out_col, self.cols, self.cols)
318 }
319
320 pub(crate) fn solve_psd_normal_equations(
339 &self,
340 w: ArrayView1<'_, f64>,
341 rhs: ArrayView1<'_, f64>,
342 ridge: f64,
343 ) -> Option<Array1<f64>> {
344 if w.len() != self.rows || rhs.len() != self.cols {
345 return None;
346 }
347 let p = self.cols;
348
349 let weights_dev = self.stream.clone_htod(&vector_values(w)).ok()?;
351 let mut weighted_dev = self
352 .stream
353 .alloc_zeros::<f64>(self.rows.checked_mul(p)?)
354 .ok()?;
355 row_scale_device(
356 &self.blas,
357 &self.stream,
358 &self.x_dev,
359 &weights_dev,
360 &mut weighted_dev,
361 self.rows,
362 p,
363 )?;
364
365 let mut ridge_init = vec![0.0_f64; p.checked_mul(p)?];
372 for i in 0..p {
373 ridge_init[i * p + i] = ridge;
374 }
375 let mut g_dev = self.stream.clone_htod(&ridge_init).ok()?;
376 let cfg = GemmConfig::<f64> {
377 transa: cublasOperation_t::CUBLAS_OP_T,
378 transb: cublasOperation_t::CUBLAS_OP_N,
379 m: to_i32(p)?,
380 n: to_i32(p)?,
381 k: to_i32(self.rows)?,
382 alpha: 1.0,
383 lda: to_i32(self.rows)?,
384 ldb: to_i32(self.rows)?,
385 beta: 1.0,
387 ldc: to_i32(p)?,
388 };
389 unsafe { self.blas.gemm(cfg, &self.x_dev, &weighted_dev, &mut g_dev) }.ok()?;
392
393 let solver = DnHandle::new(self.stream.clone()).ok()?;
395 let info = potrf_single_dev(&solver, &self.stream, p, &mut g_dev)?;
396 if info != 0 {
397 return None;
399 }
400
401 let mut rhs_dev = self.stream.clone_htod(&vector_values(rhs)).ok()?;
403 trsm_single_vec(&self.blas, &self.stream, p, &g_dev, &mut rhs_dev, false)?; trsm_single_vec(&self.blas, &self.stream, p, &g_dev, &mut rhs_dev, true)?; let beta_host = self.stream.clone_dtoh(&rhs_dev).ok()?;
408 Some(Array1::from_vec(beta_host))
409 }
410 }
411
412 fn potrf_single_dev(
416 solver: &DnHandle,
417 stream: &Arc<CudaStream>,
418 p: usize,
419 matrix: &mut CudaSlice<f64>,
420 ) -> Option<i32> {
421 let p_i = to_i32(p)?;
422 let uplo = cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER;
423 let mut lwork = 0_i32;
424 {
425 let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
426 let status = unsafe {
428 cusolver_sys::cusolverDnDpotrf_bufferSize(
429 solver.cu(),
430 uplo,
431 p_i,
432 mat_ptr as *mut f64,
433 p_i,
434 &mut lwork,
435 )
436 };
437 if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
438 return None;
439 }
440 }
441 let mut workspace = stream.alloc_zeros::<f64>(lwork.max(1) as usize).ok()?;
442 let mut info_dev = stream.alloc_zeros::<i32>(1).ok()?;
443 {
444 let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
445 let (work_ptr, _wrec) = workspace.device_ptr_mut(stream);
446 let (info_ptr, _irec) = info_dev.device_ptr_mut(stream);
447 let status = unsafe {
449 cusolver_sys::cusolverDnDpotrf(
450 solver.cu(),
451 uplo,
452 p_i,
453 mat_ptr as *mut f64,
454 p_i,
455 work_ptr as *mut f64,
456 lwork,
457 info_ptr as *mut i32,
458 )
459 };
460 if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
461 return None;
462 }
463 }
464 let info_host = stream.clone_dtoh(&info_dev).ok()?;
465 info_host.first().copied()
466 }
467
468 fn trsm_single_vec(
472 blas: &CudaBlas,
473 stream: &Arc<CudaStream>,
474 p: usize,
475 l: &CudaSlice<f64>,
476 rhs: &mut CudaSlice<f64>,
477 transposed: bool,
478 ) -> Option<()> {
479 let alpha = 1.0_f64;
480 let p_i = to_i32(p)?;
481 let handle = *blas.handle();
482 let (l_ptr, _l_rec) = l.device_ptr(stream);
483 let (rhs_ptr, _rhs_rec) = rhs.device_ptr_mut(stream);
484 let status = unsafe {
486 cudarc::cublas::sys::cublasDtrsm_v2(
487 handle,
488 cublasSideMode_t::CUBLAS_SIDE_LEFT,
489 cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
490 if transposed {
491 cublasOperation_t::CUBLAS_OP_T
492 } else {
493 cublasOperation_t::CUBLAS_OP_N
494 },
495 cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
496 p_i,
497 1,
498 &alpha,
499 l_ptr as *const f64,
500 p_i,
501 rhs_ptr as *mut f64,
502 p_i,
503 )
504 };
505 if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
506 Some(())
507 } else {
508 None
509 }
510 }
511
512 #[inline]
513 fn assign_block(
514 out: &mut Array2<f64>,
515 row_offset: usize,
516 col_offset: usize,
517 block: &Array2<f64>,
518 ) {
519 let (rows, cols) = block.dim();
520 for col in 0..cols {
521 for row in 0..rows {
522 out[[row_offset + row, col_offset + col]] = block[[row, col]];
523 }
524 }
525 }
526
527 #[inline]
528 fn mirror_upper_to_lower(out: &mut Array2<f64>) {
529 let n = out.nrows();
530 for row in 0..n {
531 for col in 0..row {
532 out[[row, col]] = out[[col, row]];
533 }
534 }
535 }
536
537 #[inline]
538 pub(crate) fn gemm_cuda(
539 runtime: &GpuRuntime,
540 a: ArrayView2<'_, f64>,
541 b: ArrayView2<'_, f64>,
542 trans_a: bool,
543 trans_b: bool,
544 ) -> Option<Array2<f64>> {
545 gemm_on_ordinal_cuda(runtime.device.ordinal, a, b, trans_a, trans_b)
546 }
547
548 #[inline]
554 pub(crate) fn gemm_on_ordinal_cuda(
555 ordinal: usize,
556 a: ArrayView2<'_, f64>,
557 b: ArrayView2<'_, f64>,
558 trans_a: bool,
559 trans_b: bool,
560 ) -> Option<Array2<f64>> {
561 let (a_rows, a_cols) = a.dim();
562 let (b_rows, b_cols) = b.dim();
563 let (m, k_a) = if trans_a {
564 (a_cols, a_rows)
565 } else {
566 (a_rows, a_cols)
567 };
568 let (k_b, n) = if trans_b {
569 (b_cols, b_rows)
570 } else {
571 (b_rows, b_cols)
572 };
573 if m == 0 || n == 0 || k_a == 0 || k_a != k_b {
574 return None;
575 }
576 let (stream, blas) = stream_and_blas_for(ordinal)?;
577 let b_rm = to_row_major(&b);
599 let a_rm = to_row_major(&a);
600 let x_dev = stream.clone_htod(&*b_rm).ok()?;
601 let y_dev = stream.clone_htod(&*a_rm).ok()?;
602 let mut out_dev = stream.alloc_zeros::<f64>(m.checked_mul(n)?).ok()?;
603 let cfg = GemmConfig::<f64> {
604 transa: if trans_b {
605 cublasOperation_t::CUBLAS_OP_T
606 } else {
607 cublasOperation_t::CUBLAS_OP_N
608 },
609 transb: if trans_a {
610 cublasOperation_t::CUBLAS_OP_T
611 } else {
612 cublasOperation_t::CUBLAS_OP_N
613 },
614 m: to_i32(n)?,
615 n: to_i32(m)?,
616 k: to_i32(k_a)?,
617 alpha: 1.0,
618 lda: to_i32(b_cols)?,
621 ldb: to_i32(a_cols)?,
622 beta: 0.0,
623 ldc: to_i32(n)?,
624 };
625 unsafe { blas.gemm(cfg, &x_dev, &y_dev, &mut out_dev) }.ok()?;
628 let out_rm = stream.clone_dtoh(&out_dev).ok()?;
630 array_from_row_major(out_rm, m, n)
631 }
632
633 #[inline]
638 pub(crate) fn gemm_broadcast_b_batched_cuda(
639 ordinal: usize,
640 a: ArrayView3<'_, f64>,
641 b: ArrayView2<'_, f64>,
642 ) -> Option<Array3<f64>> {
643 let (batch, m, k) = a.dim();
644 let (b_rows, n) = b.dim();
645 if batch == 0 || m == 0 || n == 0 || k == 0 || b_rows != k {
646 return None;
647 }
648 let (stream, blas) = stream_and_blas_for(ordinal)?;
649 let a_col = to_col_major_batch(a);
650 let b_col = to_col_major(&b);
651 let a_dev = stream.clone_htod(&a_col).ok()?;
652 let b_dev = stream.clone_htod(&*b_col).ok()?;
653 let mut out_dev = stream
654 .alloc_zeros::<f64>(batch.checked_mul(m)?.checked_mul(n)?)
655 .ok()?;
656 let cfg = StridedBatchedConfig::<f64> {
657 gemm: GemmConfig::<f64> {
658 transa: cublasOperation_t::CUBLAS_OP_N,
659 transb: cublasOperation_t::CUBLAS_OP_N,
660 m: to_i32(m)?,
661 n: to_i32(n)?,
662 k: to_i32(k)?,
663 alpha: 1.0,
664 lda: to_i32(m)?,
665 ldb: to_i32(k)?,
666 beta: 0.0,
667 ldc: to_i32(m)?,
668 },
669 batch_size: to_i32(batch)?,
670 stride_a: i64::try_from(m.checked_mul(k)?).ok()?,
671 stride_b: 0,
672 stride_c: i64::try_from(m.checked_mul(n)?).ok()?,
673 };
674 unsafe { blas.gemm_strided_batched(cfg, &a_dev, &b_dev, &mut out_dev) }.ok()?;
678 let out_col = stream.clone_dtoh(&out_dev).ok()?;
679 from_col_major_batch(&out_col, batch, m, n)
680 }
681
682 #[inline]
686 pub(crate) fn gemm_abt_strided_batched_cuda(
687 ordinal: usize,
688 a: ArrayView3<'_, f64>,
689 b: ArrayView3<'_, f64>,
690 ) -> Option<Array3<f64>> {
691 let (batch, m, k) = a.dim();
692 let (batch_b, n, k_b) = b.dim();
693 if batch == 0 || m == 0 || n == 0 || k == 0 || batch != batch_b || k != k_b {
694 return None;
695 }
696 let (stream, blas) = stream_and_blas_for(ordinal)?;
697 let a_col = to_col_major_batch(a);
698 let b_col = to_col_major_batch(b);
699 let a_dev = stream.clone_htod(&a_col).ok()?;
700 let b_dev = stream.clone_htod(&b_col).ok()?;
701 let mut out_dev = stream
702 .alloc_zeros::<f64>(batch.checked_mul(m)?.checked_mul(n)?)
703 .ok()?;
704 let cfg = StridedBatchedConfig::<f64> {
705 gemm: GemmConfig::<f64> {
706 transa: cublasOperation_t::CUBLAS_OP_N,
707 transb: cublasOperation_t::CUBLAS_OP_T,
708 m: to_i32(m)?,
709 n: to_i32(n)?,
710 k: to_i32(k)?,
711 alpha: 1.0,
712 lda: to_i32(m)?,
713 ldb: to_i32(n)?,
714 beta: 0.0,
715 ldc: to_i32(m)?,
716 },
717 batch_size: to_i32(batch)?,
718 stride_a: i64::try_from(m.checked_mul(k)?).ok()?,
719 stride_b: i64::try_from(n.checked_mul(k)?).ok()?,
720 stride_c: i64::try_from(m.checked_mul(n)?).ok()?,
721 };
722 unsafe { blas.gemm_strided_batched(cfg, &a_dev, &b_dev, &mut out_dev) }.ok()?;
725 let out_col = stream.clone_dtoh(&out_dev).ok()?;
726 from_col_major_batch(&out_col, batch, m, n)
727 }
728
729 #[inline]
730 pub(crate) fn gemv_cuda(
731 runtime: &GpuRuntime,
732 a: ArrayView2<'_, f64>,
733 v: ArrayView1<'_, f64>,
734 trans_a: bool,
735 ) -> Option<Array1<f64>> {
736 let (rows, cols) = a.dim();
737 let out_len = if trans_a { cols } else { rows };
738 let needed = if trans_a { rows } else { cols };
739 if out_len == 0 || needed == 0 || v.len() != needed {
740 return None;
741 }
742 let (stream, blas) = stream_and_blas(runtime)?;
743 let a_col = to_col_major(&a);
744 let a_dev = stream.clone_htod(&*a_col).ok()?;
745 let v_host = vector_values(v);
746 let v_dev = stream.clone_htod(&v_host).ok()?;
747 let mut out_dev = stream.alloc_zeros::<f64>(out_len).ok()?;
748 let cfg = GemvConfig::<f64> {
749 trans: if trans_a {
750 cublasOperation_t::CUBLAS_OP_T
751 } else {
752 cublasOperation_t::CUBLAS_OP_N
753 },
754 m: to_i32(rows)?,
755 n: to_i32(cols)?,
756 alpha: 1.0,
757 lda: to_i32(rows)?,
758 incx: 1,
759 beta: 0.0,
760 incy: 1,
761 };
762 unsafe { blas.gemv(cfg, &a_dev, &v_dev, &mut out_dev) }.ok()?;
764 Some(Array1::from_vec(stream.clone_dtoh(&out_dev).ok()?))
765 }
766
767 #[inline]
768 pub fn xt_diag_x_cuda(
769 runtime: &GpuRuntime,
770 x: ArrayView2<'_, f64>,
771 w: ArrayView1<'_, f64>,
772 ) -> Option<Array2<f64>> {
773 let (rows, cols) = x.dim();
774 if rows == 0 || cols == 0 || rows != w.len() {
775 return None;
776 }
777 weighted_crossprod(runtime, x, w, x)
778 }
779
780 #[inline]
781 pub(crate) fn xt_diag_x_on_ordinal_cuda(
782 ordinal: usize,
783 x: ArrayView2<'_, f64>,
784 w: ArrayView1<'_, f64>,
785 ) -> Option<Array2<f64>> {
786 let (rows, cols) = x.dim();
787 if rows == 0 || cols == 0 || rows != w.len() {
788 return None;
789 }
790 weighted_crossprod_for(ordinal, x, w, x)
791 }
792
793 #[inline]
794 pub fn xt_diag_y_cuda(
795 runtime: &GpuRuntime,
796 x: ArrayView2<'_, f64>,
797 w: ArrayView1<'_, f64>,
798 y: ArrayView2<'_, f64>,
799 ) -> Option<Array2<f64>> {
800 weighted_crossprod(runtime, x, w, y)
801 }
802
803 #[inline]
804 pub(crate) fn joint_hessian_2x2_cuda(
805 runtime: &GpuRuntime,
806 x_a: ArrayView2<'_, f64>,
807 x_b: ArrayView2<'_, f64>,
808 w_aa: ArrayView1<'_, f64>,
809 w_ab: ArrayView1<'_, f64>,
810 w_bb: ArrayView1<'_, f64>,
811 ) -> Option<Array2<f64>> {
812 let (rows, pa) = x_a.dim();
813 let (rows_b, pb) = x_b.dim();
814 let total = pa.checked_add(pb)?;
815 if rows == 0
816 || total == 0
817 || rows != rows_b
818 || rows != w_aa.len()
819 || rows != w_ab.len()
820 || rows != w_bb.len()
821 {
822 return None;
823 }
824
825 let mut out = Array2::<f64>::zeros((total, total));
826 if pa > 0 {
827 let aa = weighted_crossprod(runtime, x_a, w_aa, x_a)?;
828 assign_block(&mut out, 0, 0, &aa);
829 }
830 if pa > 0 && pb > 0 {
831 let ab = weighted_crossprod(runtime, x_a, w_ab, x_b)?;
832 assign_block(&mut out, 0, pa, &ab);
833 }
834 if pb > 0 {
835 let bb = weighted_crossprod(runtime, x_b, w_bb, x_b)?;
836 assign_block(&mut out, pa, pa, &bb);
837 }
838 mirror_upper_to_lower(&mut out);
839 Some(out)
840 }
841
842 #[inline]
843 pub(crate) fn trsm_cuda(
844 runtime: &GpuRuntime,
845 triangular: ArrayView2<'_, f64>,
846 rhs: ArrayView2<'_, f64>,
847 upper: bool,
848 ) -> Option<Array2<f64>> {
849 let (n, n2) = triangular.dim();
850 if n == 0 || n != n2 || rhs.nrows() != n {
851 return None;
852 }
853 let nrhs = rhs.ncols();
854 let (stream, blas) = stream_and_blas(runtime)?;
855 let tri_col = to_col_major(&triangular);
856 let rhs_col = to_col_major(&rhs);
857 let tri_dev = stream.clone_htod(&*tri_col).ok()?;
858 let mut rhs_dev = stream.clone_htod(&*rhs_col).ok()?;
859 let alpha = 1.0_f64;
860 let handle = *blas.handle();
861 {
862 let (tri_ptr, _tri_record) = tri_dev.device_ptr(&stream);
863 let (rhs_ptr, _rhs_record) = rhs_dev.device_ptr_mut(&stream);
864 let status = unsafe {
867 cudarc::cublas::sys::cublasDtrsm_v2(
868 handle,
869 cublasSideMode_t::CUBLAS_SIDE_LEFT,
870 if upper {
871 cublasFillMode_t::CUBLAS_FILL_MODE_UPPER
872 } else {
873 cublasFillMode_t::CUBLAS_FILL_MODE_LOWER
874 },
875 cublasOperation_t::CUBLAS_OP_N,
876 cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
877 to_i32(n)?,
878 to_i32(nrhs)?,
879 &alpha,
880 tri_ptr as *const f64,
881 to_i32(n)?,
882 rhs_ptr as *mut f64,
883 to_i32(n)?,
884 )
885 };
886 if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
887 return None;
888 }
889 };
890 let out_col = stream.clone_dtoh(&rhs_dev).ok()?;
891 from_col_major(&out_col, n, nrhs)
892 }
893}
894
895#[cfg(target_os = "linux")]
896pub(crate) use cuda_impl::{
897 ResidentWeightedGram, gemm_abt_strided_batched_cuda, gemm_broadcast_b_batched_cuda, gemm_cuda,
898 gemm_on_ordinal_cuda, gemv_cuda, joint_hessian_2x2_cuda, trsm_cuda, xt_diag_x_on_ordinal_cuda,
899};
900#[cfg(target_os = "linux")]
907pub use cuda_impl::{xt_diag_x_cuda, xt_diag_y_cuda};