/// Reports the CUDA backend status as seen from the BLAS layer.
///
/// This is a plain forwarder: it returns whatever the parent module's
/// `cuda_backend_status()` returns, without any additional probing.
pub fn blas_backend_status() -> super::CudaBackendStatus {
    super::cuda_backend_status()
}
20
21#[cfg(target_os = "linux")]
22mod cuda_impl {
23 use ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, ArrayView3, Axis};
24
25 use crate::driver::{from_col_major, to_col_major, to_i32};
26
27 use super::super::device_runtime::GpuRuntime;
28 use cudarc::cublas::sys::{
29 cublasDiagType_t, cublasFillMode_t, cublasOperation_t, cublasSideMode_t, cublasStatus_t,
30 };
31 use cudarc::cublas::{CudaBlas, Gemm, GemmConfig, Gemv, GemvConfig, StridedBatchedConfig};
32 use cudarc::cusolver::{DnHandle, sys as cusolver_sys};
33 use cudarc::driver::{CudaSlice, CudaStream, DevicePtr, DevicePtrMut};
34 use std::sync::Arc;
35
36 #[inline]
43 pub(crate) fn stream_and_blas_for(ordinal: usize) -> Option<(Arc<CudaStream>, CudaBlas)> {
44 let stream = super::super::device_runtime::cuda_context_for(ordinal)?
45 .new_stream()
46 .ok()?;
47 let blas = CudaBlas::new(stream.clone()).ok()?;
48 Some((stream, blas))
49 }
50
    /// Convenience wrapper around [`stream_and_blas_for`] that takes the device
    /// ordinal from `runtime.device.ordinal`.
    ///
    /// Returns `None` under the same conditions as [`stream_and_blas_for`].
    #[inline]
    fn stream_and_blas(runtime: &GpuRuntime) -> Option<(Arc<CudaStream>, CudaBlas)> {
        stream_and_blas_for(runtime.device.ordinal)
    }
55
56 #[inline]
57 fn vector_values(v: ArrayView1<'_, f64>) -> Vec<f64> {
58 v.iter().copied().collect()
59 }
60
61 #[inline]
62 fn to_col_major_batch(batch: ArrayView3<'_, f64>) -> Vec<f64> {
63 let (batch_len, rows, cols) = batch.dim();
64 let mut out = Vec::with_capacity(batch_len.saturating_mul(rows).saturating_mul(cols));
65 for matrix in batch.axis_iter(Axis(0)) {
66 out.extend(to_col_major(&matrix).iter().copied());
67 }
68 out
69 }
70
71 #[inline]
72 fn from_col_major_batch(
73 data: &[f64],
74 batch: usize,
75 rows: usize,
76 cols: usize,
77 ) -> Option<Array3<f64>> {
78 if data.len() != batch.checked_mul(rows)?.checked_mul(cols)? {
79 return None;
80 }
81 let mut out = Array3::<f64>::zeros((batch, rows, cols));
82 let matrix_len = rows.checked_mul(cols)?;
83 for batch_idx in 0..batch {
84 let base = batch_idx.checked_mul(matrix_len)?;
85 for col in 0..cols {
86 for row in 0..rows {
87 out[[batch_idx, row, col]] = data[base + col * rows + row];
88 }
89 }
90 }
91 Some(out)
92 }
93
94 #[inline]
95 fn row_scale_device(
96 blas: &CudaBlas,
97 stream: &Arc<CudaStream>,
98 matrix_dev: &CudaSlice<f64>,
99 weights_dev: &CudaSlice<f64>,
100 scaled_dev: &mut CudaSlice<f64>,
101 rows: usize,
102 cols: usize,
103 ) -> Option<()> {
104 let rows_i = to_i32(rows)?;
105 let cols_i = to_i32(cols)?;
106 let handle = *blas.handle();
107 let (matrix_ptr, _matrix_record) = matrix_dev.device_ptr(stream);
108 let (weights_ptr, _weights_record) = weights_dev.device_ptr(stream);
109 let (scaled_ptr, _scaled_record) = scaled_dev.device_ptr_mut(stream);
110 let status = unsafe {
114 cudarc::cublas::sys::cublasDdgmm(
115 handle,
116 cublasSideMode_t::CUBLAS_SIDE_LEFT,
117 rows_i,
118 cols_i,
119 matrix_ptr as *const f64,
120 rows_i,
121 weights_ptr as *const f64,
122 1,
123 scaled_ptr as *mut f64,
124 rows_i,
125 )
126 };
127 if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
128 Some(())
129 } else {
130 None
131 }
132 }
133
    /// Runtime-based entry point for the weighted cross product: forwards to
    /// [`weighted_crossprod_for`] using `runtime.device.ordinal` as the device.
    ///
    /// Returns `None` under the same conditions as [`weighted_crossprod_for`].
    #[inline]
    fn weighted_crossprod(
        runtime: &GpuRuntime,
        left: ArrayView2<'_, f64>,
        weights: ArrayView1<'_, f64>,
        right: ArrayView2<'_, f64>,
    ) -> Option<Array2<f64>> {
        weighted_crossprod_for(runtime.device.ordinal, left, weights, right)
    }
143
144 #[inline]
145 fn weighted_crossprod_for(
146 ordinal: usize,
147 left: ArrayView2<'_, f64>,
148 weights: ArrayView1<'_, f64>,
149 right: ArrayView2<'_, f64>,
150 ) -> Option<Array2<f64>> {
151 let (rows, left_cols) = left.dim();
152 let (right_rows, right_cols) = right.dim();
153 if rows == 0
154 || left_cols == 0
155 || right_cols == 0
156 || rows != right_rows
157 || rows != weights.len()
158 {
159 return None;
160 }
161
162 let (stream, blas) = stream_and_blas_for(ordinal)?;
163 let same_operand = std::ptr::eq(left.as_ptr(), right.as_ptr())
170 && left.dim() == right.dim()
171 && left.strides() == right.strides();
172 let left_col = to_col_major(&left);
173 let weights_host = vector_values(weights);
174 let left_dev = stream.clone_htod(&*left_col).ok()?;
175 let right_dev = if same_operand {
179 None
180 } else {
181 let right_col = to_col_major(&right);
182 Some(stream.clone_htod(&*right_col).ok()?)
183 };
184 let weights_dev = stream.clone_htod(&weights_host).ok()?;
185 let mut weighted_right_dev = stream
186 .alloc_zeros::<f64>(rows.checked_mul(right_cols)?)
187 .ok()?;
188 row_scale_device(
189 &blas,
190 &stream,
191 right_dev.as_ref().unwrap_or(&left_dev),
192 &weights_dev,
193 &mut weighted_right_dev,
194 rows,
195 right_cols,
196 )?;
197
198 let mut out_dev = stream
199 .alloc_zeros::<f64>(left_cols.checked_mul(right_cols)?)
200 .ok()?;
201 let cfg = GemmConfig::<f64> {
202 transa: cublasOperation_t::CUBLAS_OP_T,
203 transb: cublasOperation_t::CUBLAS_OP_N,
204 m: to_i32(left_cols)?,
205 n: to_i32(right_cols)?,
206 k: to_i32(rows)?,
207 alpha: 1.0,
208 lda: to_i32(rows)?,
209 ldb: to_i32(rows)?,
210 beta: 0.0,
211 ldc: to_i32(left_cols)?,
212 };
213 unsafe { blas.gemm(cfg, &left_dev, &weighted_right_dev, &mut out_dev) }.ok()?;
216 let out_col = stream.clone_dtoh(&out_dev).ok()?;
217 from_col_major(&out_col, left_cols, right_cols)
218 }
219
    /// A design matrix uploaded to the device once, bundled with the stream and
    /// cuBLAS handle used for all later weighted-Gram computations on it.
    ///
    /// Built by `ResidentWeightedGram::new`; the methods reuse `x_dev` so that
    /// only the weight vector (and, for solves, the right-hand side) is
    /// transferred per call.
    pub(crate) struct ResidentWeightedGram {
        /// Stream the matrix was uploaded on and all subsequent work is issued on.
        stream: Arc<CudaStream>,
        /// cuBLAS handle created together with `stream`.
        blas: CudaBlas,
        /// Device copy of the matrix, uploaded from the output of `to_col_major`.
        x_dev: CudaSlice<f64>,
        /// Number of rows of the resident matrix.
        rows: usize,
        /// Number of columns of the resident matrix.
        cols: usize,
    }
241
242 impl ResidentWeightedGram {
243 pub(crate) fn new(ordinal: usize, x: ArrayView2<'_, f64>) -> Option<Self> {
247 let (rows, cols) = x.dim();
248 if rows == 0 || cols == 0 {
249 return None;
250 }
251 let (stream, blas) = stream_and_blas_for(ordinal)?;
252 let x_col = to_col_major(&x);
253 let x_dev = stream.clone_htod(&*x_col).ok()?;
254 Some(Self {
255 stream,
256 blas,
257 x_dev,
258 rows,
259 cols,
260 })
261 }
262
        /// Returns `(rows, cols)` of the resident matrix as recorded at construction.
        #[inline]
        pub(crate) fn dims(&self) -> (usize, usize) {
            (self.rows, self.cols)
        }
267
268 pub(crate) fn gram(&self, w: ArrayView1<'_, f64>) -> Option<Array2<f64>> {
273 if w.len() != self.rows {
274 return None;
275 }
276 let weights_host = vector_values(w);
277 let weights_dev = self.stream.clone_htod(&weights_host).ok()?;
278 let mut weighted_dev = self
279 .stream
280 .alloc_zeros::<f64>(self.rows.checked_mul(self.cols)?)
281 .ok()?;
282 row_scale_device(
283 &self.blas,
284 &self.stream,
285 &self.x_dev,
286 &weights_dev,
287 &mut weighted_dev,
288 self.rows,
289 self.cols,
290 )?;
291 let mut out_dev = self
292 .stream
293 .alloc_zeros::<f64>(self.cols.checked_mul(self.cols)?)
294 .ok()?;
295 let cfg = GemmConfig::<f64> {
296 transa: cublasOperation_t::CUBLAS_OP_T,
297 transb: cublasOperation_t::CUBLAS_OP_N,
298 m: to_i32(self.cols)?,
299 n: to_i32(self.cols)?,
300 k: to_i32(self.rows)?,
301 alpha: 1.0,
302 lda: to_i32(self.rows)?,
303 ldb: to_i32(self.rows)?,
304 beta: 0.0,
305 ldc: to_i32(self.cols)?,
306 };
307 unsafe {
310 self.blas
311 .gemm(cfg, &self.x_dev, &weighted_dev, &mut out_dev)
312 }
313 .ok()?;
314 let out_col = self.stream.clone_dtoh(&out_dev).ok()?;
315 from_col_major(&out_col, self.cols, self.cols)
316 }
317
318 pub(crate) fn solve_psd_normal_equations(
337 &self,
338 w: ArrayView1<'_, f64>,
339 rhs: ArrayView1<'_, f64>,
340 ridge: f64,
341 ) -> Option<Array1<f64>> {
342 if w.len() != self.rows || rhs.len() != self.cols {
343 return None;
344 }
345 let p = self.cols;
346
347 let weights_dev = self.stream.clone_htod(&vector_values(w)).ok()?;
349 let mut weighted_dev = self
350 .stream
351 .alloc_zeros::<f64>(self.rows.checked_mul(p)?)
352 .ok()?;
353 row_scale_device(
354 &self.blas,
355 &self.stream,
356 &self.x_dev,
357 &weights_dev,
358 &mut weighted_dev,
359 self.rows,
360 p,
361 )?;
362
363 let mut ridge_init = vec![0.0_f64; p.checked_mul(p)?];
370 for i in 0..p {
371 ridge_init[i * p + i] = ridge;
372 }
373 let mut g_dev = self.stream.clone_htod(&ridge_init).ok()?;
374 let cfg = GemmConfig::<f64> {
375 transa: cublasOperation_t::CUBLAS_OP_T,
376 transb: cublasOperation_t::CUBLAS_OP_N,
377 m: to_i32(p)?,
378 n: to_i32(p)?,
379 k: to_i32(self.rows)?,
380 alpha: 1.0,
381 lda: to_i32(self.rows)?,
382 ldb: to_i32(self.rows)?,
383 beta: 1.0,
385 ldc: to_i32(p)?,
386 };
387 unsafe { self.blas.gemm(cfg, &self.x_dev, &weighted_dev, &mut g_dev) }.ok()?;
390
391 let solver = DnHandle::new(self.stream.clone()).ok()?;
393 let info = potrf_single_dev(&solver, &self.stream, p, &mut g_dev)?;
394 if info != 0 {
395 return None;
397 }
398
399 let mut rhs_dev = self.stream.clone_htod(&vector_values(rhs)).ok()?;
401 trsm_single_vec(&self.blas, &self.stream, p, &g_dev, &mut rhs_dev, false)?; trsm_single_vec(&self.blas, &self.stream, p, &g_dev, &mut rhs_dev, true)?; let beta_host = self.stream.clone_dtoh(&rhs_dev).ok()?;
406 Some(Array1::from_vec(beta_host))
407 }
408 }
409
    /// Runs cuSOLVER's dense `potrf` (lower fill mode) in place on a `p`-by-`p`
    /// device matrix with leading dimension `p`.
    ///
    /// Returns the `info` value written by cuSOLVER (the caller treats anything
    /// other than `0` as failure), or `None` when `p` does not fit in `i32`, when
    /// either cuSOLVER call returns a non-success status, or when an allocation
    /// or the device-to-host copy of `info` fails.
    ///
    /// NOTE(review): `matrix` is assumed to hold at least `p * p` elements; that
    /// is not checked here.
    fn potrf_single_dev(
        solver: &DnHandle,
        stream: &Arc<CudaStream>,
        p: usize,
        matrix: &mut CudaSlice<f64>,
    ) -> Option<i32> {
        let p_i = to_i32(p)?;
        let uplo = cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER;
        let mut lwork = 0_i32;
        // Phase 1: ask cuSOLVER how much workspace the factorisation needs.
        // The pointer/guard pair is scoped so the mutable borrow of `matrix`
        // ends before it is borrowed again in phase 2.
        {
            let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
            // SAFETY: `mat_ptr` comes from the live `matrix` borrow above and
            // `lwork` is a valid host location for the size query result.
            let status = unsafe {
                cusolver_sys::cusolverDnDpotrf_bufferSize(
                    solver.cu(),
                    uplo,
                    p_i,
                    mat_ptr as *mut f64,
                    p_i,
                    &mut lwork,
                )
            };
            if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
                return None;
            }
        }
        // Allocate at least one element even if cuSOLVER reports a zero-sized workspace.
        let mut workspace = stream.alloc_zeros::<f64>(lwork.max(1) as usize).ok()?;
        let mut info_dev = stream.alloc_zeros::<i32>(1).ok()?;
        // Phase 2: the factorisation itself, overwriting `matrix` in place.
        {
            let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
            let (work_ptr, _wrec) = workspace.device_ptr_mut(stream);
            let (info_ptr, _irec) = info_dev.device_ptr_mut(stream);
            // SAFETY: all three pointers come from slices that are borrowed for
            // the duration of this block; `workspace` was sized from `lwork`.
            let status = unsafe {
                cusolver_sys::cusolverDnDpotrf(
                    solver.cu(),
                    uplo,
                    p_i,
                    mat_ptr as *mut f64,
                    p_i,
                    work_ptr as *mut f64,
                    lwork,
                    info_ptr as *mut i32,
                )
            };
            if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
                return None;
            }
        }
        // `info` lives on the device; copy the single value back to the host.
        let info_host = stream.clone_dtoh(&info_dev).ok()?;
        info_host.first().copied()
    }
465
466 fn trsm_single_vec(
470 blas: &CudaBlas,
471 stream: &Arc<CudaStream>,
472 p: usize,
473 l: &CudaSlice<f64>,
474 rhs: &mut CudaSlice<f64>,
475 transposed: bool,
476 ) -> Option<()> {
477 let alpha = 1.0_f64;
478 let p_i = to_i32(p)?;
479 let handle = *blas.handle();
480 let (l_ptr, _l_rec) = l.device_ptr(stream);
481 let (rhs_ptr, _rhs_rec) = rhs.device_ptr_mut(stream);
482 let status = unsafe {
484 cudarc::cublas::sys::cublasDtrsm_v2(
485 handle,
486 cublasSideMode_t::CUBLAS_SIDE_LEFT,
487 cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
488 if transposed {
489 cublasOperation_t::CUBLAS_OP_T
490 } else {
491 cublasOperation_t::CUBLAS_OP_N
492 },
493 cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
494 p_i,
495 1,
496 &alpha,
497 l_ptr as *const f64,
498 p_i,
499 rhs_ptr as *mut f64,
500 p_i,
501 )
502 };
503 if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
504 Some(())
505 } else {
506 None
507 }
508 }
509
510 #[inline]
511 fn assign_block(
512 out: &mut Array2<f64>,
513 row_offset: usize,
514 col_offset: usize,
515 block: &Array2<f64>,
516 ) {
517 let (rows, cols) = block.dim();
518 for col in 0..cols {
519 for row in 0..rows {
520 out[[row_offset + row, col_offset + col]] = block[[row, col]];
521 }
522 }
523 }
524
525 #[inline]
526 fn mirror_upper_to_lower(out: &mut Array2<f64>) {
527 let n = out.nrows();
528 for row in 0..n {
529 for col in 0..row {
530 out[[row, col]] = out[[col, row]];
531 }
532 }
533 }
534
    /// Runtime-based entry point for a single GEMM: forwards to
    /// [`gemm_on_ordinal_cuda`] using `runtime.device.ordinal` as the device.
    ///
    /// Returns `None` under the same conditions as [`gemm_on_ordinal_cuda`].
    #[inline]
    pub(crate) fn gemm_cuda(
        runtime: &GpuRuntime,
        a: ArrayView2<'_, f64>,
        b: ArrayView2<'_, f64>,
        trans_a: bool,
        trans_b: bool,
    ) -> Option<Array2<f64>> {
        gemm_on_ordinal_cuda(runtime.device.ordinal, a, b, trans_a, trans_b)
    }
545
546 #[inline]
552 pub(crate) fn gemm_on_ordinal_cuda(
553 ordinal: usize,
554 a: ArrayView2<'_, f64>,
555 b: ArrayView2<'_, f64>,
556 trans_a: bool,
557 trans_b: bool,
558 ) -> Option<Array2<f64>> {
559 let (a_rows, a_cols) = a.dim();
560 let (b_rows, b_cols) = b.dim();
561 let (m, k_a) = if trans_a {
562 (a_cols, a_rows)
563 } else {
564 (a_rows, a_cols)
565 };
566 let (k_b, n) = if trans_b {
567 (b_cols, b_rows)
568 } else {
569 (b_rows, b_cols)
570 };
571 if m == 0 || n == 0 || k_a == 0 || k_a != k_b {
572 return None;
573 }
574 let (stream, blas) = stream_and_blas_for(ordinal)?;
575 let a_col = to_col_major(&a);
576 let b_col = to_col_major(&b);
577 let a_dev = stream.clone_htod(&*a_col).ok()?;
578 let b_dev = stream.clone_htod(&*b_col).ok()?;
579 let mut out_dev = stream.alloc_zeros::<f64>(m.checked_mul(n)?).ok()?;
580 let cfg = GemmConfig::<f64> {
581 transa: if trans_a {
582 cublasOperation_t::CUBLAS_OP_T
583 } else {
584 cublasOperation_t::CUBLAS_OP_N
585 },
586 transb: if trans_b {
587 cublasOperation_t::CUBLAS_OP_T
588 } else {
589 cublasOperation_t::CUBLAS_OP_N
590 },
591 m: to_i32(m)?,
592 n: to_i32(n)?,
593 k: to_i32(k_a)?,
594 alpha: 1.0,
595 lda: to_i32(a_rows)?,
596 ldb: to_i32(b_rows)?,
597 beta: 0.0,
598 ldc: to_i32(m)?,
599 };
600 unsafe { blas.gemm(cfg, &a_dev, &b_dev, &mut out_dev) }.ok()?;
602 let out_col = stream.clone_dtoh(&out_dev).ok()?;
603 from_col_major(&out_col, m, n)
604 }
605
606 #[inline]
611 pub(crate) fn gemm_broadcast_b_batched_cuda(
612 ordinal: usize,
613 a: ArrayView3<'_, f64>,
614 b: ArrayView2<'_, f64>,
615 ) -> Option<Array3<f64>> {
616 let (batch, m, k) = a.dim();
617 let (b_rows, n) = b.dim();
618 if batch == 0 || m == 0 || n == 0 || k == 0 || b_rows != k {
619 return None;
620 }
621 let (stream, blas) = stream_and_blas_for(ordinal)?;
622 let a_col = to_col_major_batch(a);
623 let b_col = to_col_major(&b);
624 let a_dev = stream.clone_htod(&a_col).ok()?;
625 let b_dev = stream.clone_htod(&*b_col).ok()?;
626 let mut out_dev = stream
627 .alloc_zeros::<f64>(batch.checked_mul(m)?.checked_mul(n)?)
628 .ok()?;
629 let cfg = StridedBatchedConfig::<f64> {
630 gemm: GemmConfig::<f64> {
631 transa: cublasOperation_t::CUBLAS_OP_N,
632 transb: cublasOperation_t::CUBLAS_OP_N,
633 m: to_i32(m)?,
634 n: to_i32(n)?,
635 k: to_i32(k)?,
636 alpha: 1.0,
637 lda: to_i32(m)?,
638 ldb: to_i32(k)?,
639 beta: 0.0,
640 ldc: to_i32(m)?,
641 },
642 batch_size: to_i32(batch)?,
643 stride_a: i64::try_from(m.checked_mul(k)?).ok()?,
644 stride_b: 0,
645 stride_c: i64::try_from(m.checked_mul(n)?).ok()?,
646 };
647 unsafe { blas.gemm_strided_batched(cfg, &a_dev, &b_dev, &mut out_dev) }.ok()?;
651 let out_col = stream.clone_dtoh(&out_dev).ok()?;
652 from_col_major_batch(&out_col, batch, m, n)
653 }
654
655 #[inline]
659 pub(crate) fn gemm_abt_strided_batched_cuda(
660 ordinal: usize,
661 a: ArrayView3<'_, f64>,
662 b: ArrayView3<'_, f64>,
663 ) -> Option<Array3<f64>> {
664 let (batch, m, k) = a.dim();
665 let (batch_b, n, k_b) = b.dim();
666 if batch == 0 || m == 0 || n == 0 || k == 0 || batch != batch_b || k != k_b {
667 return None;
668 }
669 let (stream, blas) = stream_and_blas_for(ordinal)?;
670 let a_col = to_col_major_batch(a);
671 let b_col = to_col_major_batch(b);
672 let a_dev = stream.clone_htod(&a_col).ok()?;
673 let b_dev = stream.clone_htod(&b_col).ok()?;
674 let mut out_dev = stream
675 .alloc_zeros::<f64>(batch.checked_mul(m)?.checked_mul(n)?)
676 .ok()?;
677 let cfg = StridedBatchedConfig::<f64> {
678 gemm: GemmConfig::<f64> {
679 transa: cublasOperation_t::CUBLAS_OP_N,
680 transb: cublasOperation_t::CUBLAS_OP_T,
681 m: to_i32(m)?,
682 n: to_i32(n)?,
683 k: to_i32(k)?,
684 alpha: 1.0,
685 lda: to_i32(m)?,
686 ldb: to_i32(n)?,
687 beta: 0.0,
688 ldc: to_i32(m)?,
689 },
690 batch_size: to_i32(batch)?,
691 stride_a: i64::try_from(m.checked_mul(k)?).ok()?,
692 stride_b: i64::try_from(n.checked_mul(k)?).ok()?,
693 stride_c: i64::try_from(m.checked_mul(n)?).ok()?,
694 };
695 unsafe { blas.gemm_strided_batched(cfg, &a_dev, &b_dev, &mut out_dev) }.ok()?;
698 let out_col = stream.clone_dtoh(&out_dev).ok()?;
699 from_col_major_batch(&out_col, batch, m, n)
700 }
701
702 #[inline]
703 pub(crate) fn gemv_cuda(
704 runtime: &GpuRuntime,
705 a: ArrayView2<'_, f64>,
706 v: ArrayView1<'_, f64>,
707 trans_a: bool,
708 ) -> Option<Array1<f64>> {
709 let (rows, cols) = a.dim();
710 let out_len = if trans_a { cols } else { rows };
711 let needed = if trans_a { rows } else { cols };
712 if out_len == 0 || needed == 0 || v.len() != needed {
713 return None;
714 }
715 let (stream, blas) = stream_and_blas(runtime)?;
716 let a_col = to_col_major(&a);
717 let a_dev = stream.clone_htod(&*a_col).ok()?;
718 let v_host = vector_values(v);
719 let v_dev = stream.clone_htod(&v_host).ok()?;
720 let mut out_dev = stream.alloc_zeros::<f64>(out_len).ok()?;
721 let cfg = GemvConfig::<f64> {
722 trans: if trans_a {
723 cublasOperation_t::CUBLAS_OP_T
724 } else {
725 cublasOperation_t::CUBLAS_OP_N
726 },
727 m: to_i32(rows)?,
728 n: to_i32(cols)?,
729 alpha: 1.0,
730 lda: to_i32(rows)?,
731 incx: 1,
732 beta: 0.0,
733 incy: 1,
734 };
735 unsafe { blas.gemv(cfg, &a_dev, &v_dev, &mut out_dev) }.ok()?;
737 Some(Array1::from_vec(stream.clone_dtoh(&out_dev).ok()?))
738 }
739
740 #[inline]
741 pub fn xt_diag_x_cuda(
742 runtime: &GpuRuntime,
743 x: ArrayView2<'_, f64>,
744 w: ArrayView1<'_, f64>,
745 ) -> Option<Array2<f64>> {
746 let (rows, cols) = x.dim();
747 if rows == 0 || cols == 0 || rows != w.len() {
748 return None;
749 }
750 weighted_crossprod(runtime, x, w, x)
751 }
752
753 #[inline]
754 pub(crate) fn xt_diag_x_on_ordinal_cuda(
755 ordinal: usize,
756 x: ArrayView2<'_, f64>,
757 w: ArrayView1<'_, f64>,
758 ) -> Option<Array2<f64>> {
759 let (rows, cols) = x.dim();
760 if rows == 0 || cols == 0 || rows != w.len() {
761 return None;
762 }
763 weighted_crossprod_for(ordinal, x, w, x)
764 }
765
    /// Computes `xᵀ · diag(w) · y` on the runtime's device.
    ///
    /// Forwards directly to `weighted_crossprod`, which performs all shape
    /// validation; returns `None` under the same conditions as that function.
    #[inline]
    pub fn xt_diag_y_cuda(
        runtime: &GpuRuntime,
        x: ArrayView2<'_, f64>,
        w: ArrayView1<'_, f64>,
        y: ArrayView2<'_, f64>,
    ) -> Option<Array2<f64>> {
        weighted_crossprod(runtime, x, w, y)
    }
775
776 #[inline]
777 pub(crate) fn joint_hessian_2x2_cuda(
778 runtime: &GpuRuntime,
779 x_a: ArrayView2<'_, f64>,
780 x_b: ArrayView2<'_, f64>,
781 w_aa: ArrayView1<'_, f64>,
782 w_ab: ArrayView1<'_, f64>,
783 w_bb: ArrayView1<'_, f64>,
784 ) -> Option<Array2<f64>> {
785 let (rows, pa) = x_a.dim();
786 let (rows_b, pb) = x_b.dim();
787 let total = pa.checked_add(pb)?;
788 if rows == 0
789 || total == 0
790 || rows != rows_b
791 || rows != w_aa.len()
792 || rows != w_ab.len()
793 || rows != w_bb.len()
794 {
795 return None;
796 }
797
798 let mut out = Array2::<f64>::zeros((total, total));
799 if pa > 0 {
800 let aa = weighted_crossprod(runtime, x_a, w_aa, x_a)?;
801 assign_block(&mut out, 0, 0, &aa);
802 }
803 if pa > 0 && pb > 0 {
804 let ab = weighted_crossprod(runtime, x_a, w_ab, x_b)?;
805 assign_block(&mut out, 0, pa, &ab);
806 }
807 if pb > 0 {
808 let bb = weighted_crossprod(runtime, x_b, w_bb, x_b)?;
809 assign_block(&mut out, pa, pa, &bb);
810 }
811 mirror_upper_to_lower(&mut out);
812 Some(out)
813 }
814
815 #[inline]
816 pub(crate) fn trsm_cuda(
817 runtime: &GpuRuntime,
818 triangular: ArrayView2<'_, f64>,
819 rhs: ArrayView2<'_, f64>,
820 upper: bool,
821 ) -> Option<Array2<f64>> {
822 let (n, n2) = triangular.dim();
823 if n == 0 || n != n2 || rhs.nrows() != n {
824 return None;
825 }
826 let nrhs = rhs.ncols();
827 let (stream, blas) = stream_and_blas(runtime)?;
828 let tri_col = to_col_major(&triangular);
829 let rhs_col = to_col_major(&rhs);
830 let tri_dev = stream.clone_htod(&*tri_col).ok()?;
831 let mut rhs_dev = stream.clone_htod(&*rhs_col).ok()?;
832 let alpha = 1.0_f64;
833 let handle = *blas.handle();
834 {
835 let (tri_ptr, _tri_record) = tri_dev.device_ptr(&stream);
836 let (rhs_ptr, _rhs_record) = rhs_dev.device_ptr_mut(&stream);
837 let status = unsafe {
840 cudarc::cublas::sys::cublasDtrsm_v2(
841 handle,
842 cublasSideMode_t::CUBLAS_SIDE_LEFT,
843 if upper {
844 cublasFillMode_t::CUBLAS_FILL_MODE_UPPER
845 } else {
846 cublasFillMode_t::CUBLAS_FILL_MODE_LOWER
847 },
848 cublasOperation_t::CUBLAS_OP_N,
849 cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
850 to_i32(n)?,
851 to_i32(nrhs)?,
852 &alpha,
853 tri_ptr as *const f64,
854 to_i32(n)?,
855 rhs_ptr as *mut f64,
856 to_i32(n)?,
857 )
858 };
859 if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
860 return None;
861 }
862 };
863 let out_col = stream.clone_dtoh(&rhs_dev).ok()?;
864 from_col_major(&out_col, n, nrhs)
865 }
866}
867
868#[cfg(target_os = "linux")]
869pub(crate) use cuda_impl::{
870 ResidentWeightedGram, gemm_abt_strided_batched_cuda, gemm_broadcast_b_batched_cuda, gemm_cuda,
871 gemm_on_ordinal_cuda, gemv_cuda, joint_hessian_2x2_cuda, trsm_cuda, xt_diag_x_on_ordinal_cuda,
872};
873#[cfg(target_os = "linux")]
880pub use cuda_impl::{xt_diag_x_cuda, xt_diag_y_cuda};