1use ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, ArrayView3};
23
24use super::device_runtime::GpuRuntime;
25
26pub struct CudaGemmDispatch;
27
28impl gam_linalg::gpu_hook::GpuGemmDispatch for CudaGemmDispatch {
29 fn try_fast_atb(&self, a: ArrayView2<'_, f64>, b: ArrayView2<'_, f64>) -> Option<Array2<f64>> {
30 try_fast_atb(a, b)
31 }
32
33 fn try_fast_ab(&self, a: ArrayView2<'_, f64>, b: ArrayView2<'_, f64>) -> Option<Array2<f64>> {
34 try_fast_ab(a, b)
35 }
36
37 fn try_fast_av(&self, a: ArrayView2<'_, f64>, v: ArrayView1<'_, f64>) -> Option<Array1<f64>> {
38 try_fast_av(a, v)
39 }
40
41 fn try_fast_atv(&self, a: ArrayView2<'_, f64>, v: ArrayView1<'_, f64>) -> Option<Array1<f64>> {
42 try_fast_atv(a, v)
43 }
44
45 fn try_fast_xt_diag_x(
46 &self,
47 x: ArrayView2<'_, f64>,
48 w: ArrayView1<'_, f64>,
49 ) -> Option<Array2<f64>> {
50 try_fast_xt_diag_x(x, w)
51 }
52
53 fn try_fast_xt_diag_y(
54 &self,
55 x: ArrayView2<'_, f64>,
56 w: ArrayView1<'_, f64>,
57 y: ArrayView2<'_, f64>,
58 ) -> Option<Array2<f64>> {
59 try_fast_xt_diag_y(x, w, y)
60 }
61
62 fn try_fast_joint_hessian_2x2(
63 &self,
64 x_a: ArrayView2<'_, f64>,
65 x_b: ArrayView2<'_, f64>,
66 w_aa: ArrayView1<'_, f64>,
67 w_ab: ArrayView1<'_, f64>,
68 w_bb: ArrayView1<'_, f64>,
69 ) -> Option<Array2<f64>> {
70 try_fast_joint_hessian_2x2(x_a, x_b, w_aa, w_ab, w_bb)
71 }
72
73 fn device_count(&self) -> usize {
74 GpuRuntime::global().map_or(0, |rt| rt.device_count())
75 }
76
77 fn try_fast_ab_broadcast_b_batched(
78 &self,
79 a3: ArrayView3<'_, f64>,
80 b: ArrayView2<'_, f64>,
81 ) -> Option<Array3<f64>> {
82 try_fast_ab_broadcast_b_batched(a3, b)
83 }
84}
85
/// Shape descriptor for a dense linear-algebra kernel that may be offloaded to
/// the GPU. [`route_through_gpu`] inspects it to decide admission.
#[derive(Clone, Copy, Debug)]
pub enum DispatchOp {
    /// GEMM of an `m × k` matrix with a `k × n` matrix.
    Gemm { m: usize, n: usize, k: usize },
    /// `batch` independent GEMMs, each `m × k` by `k × n`.
    BatchedGemm {
        batch: usize,
        m: usize,
        n: usize,
        k: usize,
    },
    /// Cholesky factorisation of `batch` matrices of order `p`.
    Potrf { p: usize, batch: usize },
    /// Cholesky of `batch` small dense matrices of order `p`, gated by the
    /// policy's small-dense batched limits.
    SmallDenseBatchedPotrf { p: usize, batch: usize },
    /// Triangular solve with an `m × m` factor and `n` right-hand-side columns.
    Trsm { m: usize, n: usize },
    /// Matrix-vector product with an `m × k` matrix.
    Gemv { m: usize, k: usize },
    /// Weighted Gram `Xᵀ diag(w) X` with `X` of shape `n × p`.
    XtDiagX { n: usize, p: usize },
    /// Weighted cross product `Xᵀ diag(w) Y`, `X: n × px`, `Y: n × q`.
    XtDiagY { n: usize, px: usize, q: usize },
    /// Joint Hessian over two designs with `pa` and `pb` columns and `n` rows.
    JointHessian2x2 { n: usize, pa: usize, pb: usize },
}

impl DispatchOp {
    /// Approximate floating-point operation count used as the admission "work"
    /// measure.
    ///
    /// All products saturate at `u128::MAX` instead of overflowing. Arbitrary
    /// `usize` shapes therefore never panic (debug) or wrap to a small count
    /// (release) and slip under a work floor.
    #[inline]
    #[must_use]
    pub const fn flops(self) -> u128 {
        match self {
            Self::Gemm { m, n, k } => 2u128
                .saturating_mul(m as u128)
                .saturating_mul(n as u128)
                .saturating_mul(k as u128),
            Self::BatchedGemm { batch, m, n, k } => 2u128
                .saturating_mul(batch as u128)
                .saturating_mul(m as u128)
                .saturating_mul(n as u128)
                .saturating_mul(k as u128),
            Self::Gemv { m, k } => 2u128.saturating_mul(m as u128).saturating_mul(k as u128),
            // p³/3 per factorisation.
            Self::Potrf { p, batch } | Self::SmallDenseBatchedPotrf { p, batch } => {
                (batch as u128).saturating_mul((p as u128).saturating_pow(3)) / 3
            }
            // m² flops per right-hand-side column.
            Self::Trsm { m, n } => (m as u128)
                .saturating_mul(m as u128)
                .saturating_mul(n as u128),
            Self::XtDiagX { n, p } => 2u128
                .saturating_mul(n as u128)
                .saturating_mul(p as u128)
                .saturating_mul(p as u128),
            Self::XtDiagY { n, px, q } => 2u128
                .saturating_mul(n as u128)
                .saturating_mul(px as u128)
                .saturating_mul(q as u128),
            Self::JointHessian2x2 { n, pa, pb } => {
                // Each summand is < 2^64, so this addition cannot overflow u128.
                let total = (pa as u128) + (pb as u128);
                2u128
                    .saturating_mul(n as u128)
                    .saturating_mul(total.saturating_mul(total))
            }
        }
    }
}
140
141#[inline]
146#[must_use]
147pub fn route_through_gpu(op: DispatchOp) -> Option<&'static GpuRuntime> {
148 let runtime = GpuRuntime::global()?;
149 let policy = &runtime.policy;
150 let admit = match op {
151 DispatchOp::Gemm { m, n, k } => {
152 op.flops() >= (policy.gemm_min_flops as u128) && m.min(n).min(k) > 0
153 }
154 DispatchOp::BatchedGemm { batch, m, n, k } => {
155 op.flops() >= (policy.gemm_min_flops as u128) && batch > 1 && m.min(n).min(k) > 0
156 }
157 DispatchOp::Gemv { m, k } => {
158 op.flops() >= (policy.gemm_min_flops as u128) && m > 0 && k > 0
159 }
160 DispatchOp::Potrf { p, batch } => {
161 p > 0
162 && batch > 0
163 && (p >= policy.potrf_min_p
164 || (batch > 1 && op.flops() >= policy.gemm_min_flops as u128))
165 }
166 DispatchOp::SmallDenseBatchedPotrf { p, batch } => {
167 p > 0
168 && p <= policy.small_dense_batched_potrf_max_p
169 && batch >= policy.small_dense_batched_potrf_min_batch
170 }
171 DispatchOp::Trsm { m, n } => {
172 op.flops() >= (policy.gemm_min_flops as u128) && m > 0 && n > 0
173 }
174 DispatchOp::XtDiagX { n, p } => policy.xtwx_target_is_gpu(n, p, true),
175 DispatchOp::XtDiagY { n, px, q } => policy.xtwy_target_is_gpu(n, px, q, true),
176 DispatchOp::JointHessian2x2 { n, pa, pb } => {
177 n > 0 && (pa + pb) > 0 && op.flops() >= policy.gemm_min_flops as u128
178 }
179 };
180 if admit { Some(runtime) } else { None }
181}
182
183#[cfg(target_os = "linux")]
190const MULTI_GPU_BATCH_FLOOR: usize = 64;
191
192#[cfg(target_os = "linux")]
195#[inline]
196fn should_split_batch(batch: usize) -> bool {
197 GpuRuntime::global().is_some_and(|rt| rt.device_count() > 1) && batch >= MULTI_GPU_BATCH_FLOOR
198}
199
200#[inline]
201#[must_use]
202pub fn try_fast_ab_broadcast_b_batched(
203 a: ArrayView3<'_, f64>,
204 b: ArrayView2<'_, f64>,
205) -> Option<Array3<f64>> {
206 let (batch, m, k) = a.dim();
207 let (bk, n) = b.dim();
208 if k != bk || batch == 0 || m == 0 || n == 0 {
209 return None;
210 }
211 #[cfg(not(target_os = "linux"))]
212 {
213 return None;
214 }
215 #[cfg(target_os = "linux")]
216 {
217 let runtime = route_through_gpu(DispatchOp::BatchedGemm { batch, m, n, k })?;
218 if should_split_batch(batch) {
219 if let Some(out) = scatter_broadcast_b_batched(runtime, a, b, m, n) {
220 return Some(out);
221 }
222 }
225 cuda_backend::gemm_broadcast_b_batched(runtime.device.ordinal, a, b)
226 }
227}
228
229#[cfg(target_os = "linux")]
235fn scatter_broadcast_b_batched(
236 runtime: &GpuRuntime,
237 a: ArrayView3<'_, f64>,
238 b: ArrayView2<'_, f64>,
239 m: usize,
240 n: usize,
241) -> Option<Array3<f64>> {
242 let batch = a.dim().0;
243 let mut items: Vec<(Array2<f64>, Option<Array2<f64>>)> = (0..batch)
246 .map(|i| (a.index_axis(ndarray::Axis(0), i).to_owned(), None))
247 .collect();
248 super::pool::scatter_batched(runtime, &mut items, |ordinal, tile| {
249 let tile_batch = tile.len();
250 if tile_batch == 0 {
251 return Some(());
252 }
253 let k = b.dim().0;
254 let mut a_tile = Array3::<f64>::zeros((tile_batch, m, k));
255 for (idx, (a_i, _)) in tile.iter().enumerate() {
256 a_tile.index_axis_mut(ndarray::Axis(0), idx).assign(a_i);
257 }
258 let out = cuda_backend::gemm_broadcast_b_batched(ordinal, a_tile.view(), b)?;
259 for (idx, (_, slot)) in tile.iter_mut().enumerate() {
260 *slot = Some(out.index_axis(ndarray::Axis(0), idx).to_owned());
261 }
262 Some(())
263 })?;
264 stitch_batched(items, m, n)
265}
266
267#[inline]
268#[must_use]
269pub fn try_fast_abt_strided_batched(
270 a: ArrayView3<'_, f64>,
271 b: ArrayView3<'_, f64>,
272) -> Option<Array3<f64>> {
273 let (batch, m, k) = a.dim();
274 let (batch_b, n, k_b) = b.dim();
275 if batch != batch_b || k != k_b || batch == 0 || m == 0 || n == 0 {
276 return None;
277 }
278 #[cfg(not(target_os = "linux"))]
279 {
280 return None;
281 }
282 #[cfg(target_os = "linux")]
283 {
284 let runtime = route_through_gpu(DispatchOp::BatchedGemm { batch, m, n, k })?;
285 if should_split_batch(batch) {
286 if let Some(out) = scatter_abt_strided_batched(runtime, a, b, m, n) {
287 return Some(out);
288 }
289 }
290 cuda_backend::gemm_abt_strided_batched(runtime.device.ordinal, a, b)
291 }
292}
293
294#[cfg(target_os = "linux")]
299fn scatter_abt_strided_batched(
300 runtime: &GpuRuntime,
301 a: ArrayView3<'_, f64>,
302 b: ArrayView3<'_, f64>,
303 m: usize,
304 n: usize,
305) -> Option<Array3<f64>> {
306 let batch = a.dim().0;
307 let mut items: Vec<(Array2<f64>, Array2<f64>, Option<Array2<f64>>)> = (0..batch)
308 .map(|i| {
309 (
310 a.index_axis(ndarray::Axis(0), i).to_owned(),
311 b.index_axis(ndarray::Axis(0), i).to_owned(),
312 None,
313 )
314 })
315 .collect();
316 super::pool::scatter_batched(runtime, &mut items, |ordinal, tile| {
317 let tile_batch = tile.len();
318 if tile_batch == 0 {
319 return Some(());
320 }
321 let k = tile[0].0.dim().1;
322 let mut a_tile = Array3::<f64>::zeros((tile_batch, m, k));
323 let mut b_tile = Array3::<f64>::zeros((tile_batch, n, k));
324 for (idx, (a_i, b_i, _)) in tile.iter().enumerate() {
325 a_tile.index_axis_mut(ndarray::Axis(0), idx).assign(a_i);
326 b_tile.index_axis_mut(ndarray::Axis(0), idx).assign(b_i);
327 }
328 let out = cuda_backend::gemm_abt_strided_batched(ordinal, a_tile.view(), b_tile.view())?;
329 for (idx, (_, _, slot)) in tile.iter_mut().enumerate() {
330 *slot = Some(out.index_axis(ndarray::Axis(0), idx).to_owned());
331 }
332 Some(())
333 })?;
334 let slots: Vec<((), Option<Array2<f64>>)> =
335 items.into_iter().map(|(_, _, slot)| ((), slot)).collect();
336 stitch_batched(slots, m, n)
337}
338
339#[cfg(target_os = "linux")]
343fn stitch_batched<L>(
344 items: Vec<(L, Option<Array2<f64>>)>,
345 m: usize,
346 n: usize,
347) -> Option<Array3<f64>> {
348 let batch = items.len();
349 let mut out = Array3::<f64>::zeros((batch, m, n));
350 for (idx, (_, slot)) in items.into_iter().enumerate() {
351 let block = slot?;
352 if block.dim() != (m, n) {
353 return None;
354 }
355 out.index_axis_mut(ndarray::Axis(0), idx).assign(&block);
356 }
357 Some(out)
358}
359
360#[inline]
371#[must_use]
372pub fn try_fast_ab(a: ArrayView2<'_, f64>, b: ArrayView2<'_, f64>) -> Option<Array2<f64>> {
373 let (m, k) = a.dim();
374 let (kb, n) = b.dim();
375 if k != kb {
376 return None;
377 }
378 let runtime = route_through_gpu(DispatchOp::Gemm { m, n, k });
384 let used_gpu = runtime.is_some();
385 super::profile::record(super::profile::KernelStat {
386 name: "try_fast_ab",
387 n: m,
388 p: n,
389 k,
390 flops_est: (DispatchOp::Gemm { m, n, k }.flops().min(usize::MAX as u128)) as usize,
391 gpu_ms: if used_gpu { Some(0.0) } else { None },
392 ..Default::default()
393 });
394 #[cfg(not(target_os = "linux"))]
395 {
396 None
397 }
398 #[cfg(target_os = "linux")]
399 {
400 let runtime = runtime?;
401 cuda_backend::gemm(runtime, a, b, false, false)
402 }
403}
404
405#[inline]
406#[must_use]
407pub fn try_fast_atb(a: ArrayView2<'_, f64>, b: ArrayView2<'_, f64>) -> Option<Array2<f64>> {
408 let (n_a, p) = a.dim();
409 let (n_b, q) = b.dim();
410 if n_a != n_b || p == 0 || q == 0 {
411 return None;
412 }
413 #[cfg(not(target_os = "linux"))]
414 {
415 return None;
416 }
417 #[cfg(target_os = "linux")]
418 {
419 let runtime = route_through_gpu(DispatchOp::Gemm { m: p, n: q, k: n_a })?;
420 cuda_backend::gemm(runtime, a, b, true, false)
421 }
422}
423
424#[inline]
432#[must_use]
433pub fn try_fast_atb_on_ordinal(
434 ordinal: usize,
435 a: ArrayView2<'_, f64>,
436 b: ArrayView2<'_, f64>,
437) -> Option<Array2<f64>> {
438 let (n_a, p) = a.dim();
439 let (n_b, q) = b.dim();
440 if n_a != n_b || p == 0 || q == 0 {
441 return None;
442 }
443 #[cfg(not(target_os = "linux"))]
444 {
445 log::trace!(
451 "try_fast_atb_on_ordinal: CUDA unavailable off Linux; declining ordinal {ordinal}"
452 );
453 return None;
454 }
455 #[cfg(target_os = "linux")]
456 {
457 route_through_gpu(DispatchOp::Gemm { m: p, n: q, k: n_a })?;
469 cuda_backend::gemm_on_ordinal(ordinal, a, b, true, false)
470 }
471}
472
473#[inline]
474#[must_use]
475pub fn try_fast_av(a: ArrayView2<'_, f64>, v: ArrayView1<'_, f64>) -> Option<Array1<f64>> {
476 let (m, k) = a.dim();
477 if k != v.len() || m == 0 || k == 0 {
478 return None;
479 }
480 #[cfg(not(target_os = "linux"))]
481 {
482 return None;
483 }
484 #[cfg(target_os = "linux")]
485 {
486 let runtime = route_through_gpu(DispatchOp::Gemv { m, k })?;
487 cuda_backend::gemv(runtime, a, v, false)
488 }
489}
490
491#[inline]
492#[must_use]
493pub fn try_fast_atv(a: ArrayView2<'_, f64>, v: ArrayView1<'_, f64>) -> Option<Array1<f64>> {
494 let (n, p) = a.dim();
495 if n != v.len() || n == 0 || p == 0 {
496 return None;
497 }
498 #[cfg(not(target_os = "linux"))]
499 {
500 return None;
501 }
502 #[cfg(target_os = "linux")]
503 {
504 let runtime = route_through_gpu(DispatchOp::Gemv { m: p, k: n })?;
505 cuda_backend::gemv(runtime, a, v, true)
506 }
507}
508
509#[inline]
510#[must_use]
511pub fn try_fast_xt_diag_x(x: ArrayView2<'_, f64>, w: ArrayView1<'_, f64>) -> Option<Array2<f64>> {
512 let (n, p) = x.dim();
513 if n != w.len() || n == 0 || p == 0 {
514 return None;
515 }
516 #[cfg(not(target_os = "linux"))]
517 {
518 return None;
519 }
520 #[cfg(target_os = "linux")]
521 {
522 let runtime = route_through_gpu(DispatchOp::XtDiagX { n, p })?;
523 cuda_backend::xt_diag_x(runtime, x, w)
524 }
525}
526
/// A design matrix `x` kept on a GPU so that repeated weighted products can
/// reuse it.
///
/// `x` is supplied once to `try_new`; later calls supply only weights (and a
/// right-hand side for the normal-equation solve). On Linux the work is
/// delegated to `super::blas::ResidentWeightedGram`. Off Linux the only field
/// is uninhabited, so a value of this type can never exist there.
pub struct ResidentDesignGram {
    #[cfg(target_os = "linux")]
    inner: super::blas::ResidentWeightedGram,
    // Uninhabited placeholder: makes the type unconstructible off Linux.
    #[cfg(not(target_os = "linux"))]
    _never: std::convert::Infallible,
}
552
553impl ResidentDesignGram {
554 #[must_use]
558 pub fn try_new(x: ArrayView2<'_, f64>) -> Option<Self> {
559 let (n, p) = x.dim();
560 if n == 0 || p == 0 {
561 return None;
562 }
563 #[cfg(not(target_os = "linux"))]
564 {
565 None
566 }
567 #[cfg(target_os = "linux")]
568 {
569 let runtime = route_through_gpu(DispatchOp::XtDiagX { n, p })?;
570 let inner = super::blas::ResidentWeightedGram::new(runtime.device.ordinal, x)?;
571 Some(Self { inner })
572 }
573 }
574
575 #[must_use]
578 pub fn gram(&self, w: ArrayView1<'_, f64>) -> Option<Array2<f64>> {
579 #[cfg(not(target_os = "linux"))]
580 {
581 panic!(
587 "ResidentDesignGram cannot be constructed off CUDA (w.len()={})",
588 w.len()
589 )
590 }
591 #[cfg(target_os = "linux")]
592 {
593 self.inner.gram(w)
594 }
595 }
596
597 #[must_use]
609 pub fn solve_normal_equations(
610 &self,
611 w: ArrayView1<'_, f64>,
612 rhs: ArrayView1<'_, f64>,
613 ridge: f64,
614 ) -> Option<Array1<f64>> {
615 #[cfg(not(target_os = "linux"))]
616 {
617 panic!(
619 "ResidentDesignGram cannot be constructed off CUDA (w.len()={}, rhs.len()={}, ridge={ridge})",
620 w.len(),
621 rhs.len()
622 )
623 }
624 #[cfg(target_os = "linux")]
625 {
626 self.inner.solve_psd_normal_equations(w, rhs, ridge)
627 }
628 }
629
630 #[must_use]
632 pub fn dims(&self) -> (usize, usize) {
633 #[cfg(not(target_os = "linux"))]
634 {
635 panic!("ResidentDesignGram cannot be constructed off CUDA")
639 }
640 #[cfg(target_os = "linux")]
641 {
642 self.inner.dims()
643 }
644 }
645}
646
/// How many row chunks to aim for per device in the leverage scatter.
#[cfg(target_os = "linux")]
const LEVERAGE_CHUNKS_PER_DEVICE: usize = 4;

/// Rows per leverage chunk for rows of `cols` `f64` values.
///
/// Targets roughly 8 MiB per chunk and raises that to at least 512 rows. The
/// result is then capped at `max(n_rows, 1)`, so small inputs get a single
/// chunk.
#[cfg(target_os = "linux")]
#[inline]
fn leverage_chunk_rows(cols: usize, n_rows: usize) -> usize {
    const TARGET_BYTES: usize = 8 * 1024 * 1024;
    const MIN_CHUNK_ROWS: usize = 512;
    let row_bytes = std::mem::size_of::<f64>() * std::cmp::max(cols, 1);
    let by_budget = std::cmp::max(TARGET_BYTES / row_bytes, MIN_CHUNK_ROWS);
    std::cmp::min(by_budget, std::cmp::max(n_rows, 1))
}
669
670#[inline]
686#[must_use]
687pub fn try_fast_spectral_leverage_diagonal(
688 x: &gam_linalg::matrix::DesignMatrix,
689 g: ArrayView2<'_, f64>,
690) -> Option<Array1<f64>> {
691 let n = x.nrows();
692 let p = x.ncols();
693 let rank = g.ncols();
694 if n == 0 || p == 0 || rank == 0 || g.nrows() != p {
695 return None;
696 }
697 #[cfg(not(target_os = "linux"))]
698 {
699 return None;
700 }
701 #[cfg(target_os = "linux")]
702 {
703 let runtime = route_through_gpu(DispatchOp::XtDiagX { n, p })?;
706 let device_count = runtime.device_count().max(1);
707 let byte_chunk = leverage_chunk_rows(p + rank, n);
708 let target_chunks = device_count
709 .saturating_mul(LEVERAGE_CHUNKS_PER_DEVICE)
710 .max(1);
711 let chunk_rows = byte_chunk.min(n.div_ceil(target_chunks).max(1)).max(1);
712
713 let mut tiles: Vec<(std::ops::Range<usize>, Option<Array1<f64>>)> = Vec::new();
716 let mut start = 0usize;
717 while start < n {
718 let end = (start + chunk_rows).min(n);
719 tiles.push((start..end, None));
720 start = end;
721 }
722
723 super::pool::scatter_batched(runtime, &mut tiles, |ordinal, tile| {
724 for (range, slot) in tile.iter_mut() {
725 let rows = x.try_row_chunk(range.clone()).ok()?;
726 let xg = cuda_backend::gemm_on_ordinal(ordinal, rows.view(), g, false, false)?;
727 let mut out = Array1::<f64>::zeros(range.end - range.start);
728 for (local, row) in xg.outer_iter().enumerate() {
729 out[local] = row.iter().map(|&v| v * v).sum();
730 }
731 *slot = Some(out);
732 }
733 Some(())
734 })?;
735
736 let mut h = Array1::<f64>::zeros(n);
737 for (range, slot) in tiles {
738 let vals = slot?;
739 if vals.len() != range.end - range.start {
740 return None;
741 }
742 h.slice_mut(ndarray::s![range]).assign(&vals);
743 }
744 Some(h)
745 }
746}
747
748#[inline]
749#[must_use]
750pub fn try_fast_xt_diag_y(
751 x: ArrayView2<'_, f64>,
752 w: ArrayView1<'_, f64>,
753 y: ArrayView2<'_, f64>,
754) -> Option<Array2<f64>> {
755 let (n, px) = x.dim();
756 let (n_y, q) = y.dim();
757 if n != n_y || n != w.len() || n == 0 || px == 0 || q == 0 {
758 return None;
759 }
760 #[cfg(not(target_os = "linux"))]
761 {
762 return None;
763 }
764 #[cfg(target_os = "linux")]
765 {
766 let runtime = route_through_gpu(DispatchOp::XtDiagY { n, px, q })?;
767 cuda_backend::xt_diag_y(runtime, x, w, y)
768 }
769}
770
771#[inline]
772#[must_use]
773pub fn try_fast_joint_hessian_2x2(
774 x_a: ArrayView2<'_, f64>,
775 x_b: ArrayView2<'_, f64>,
776 w_aa: ArrayView1<'_, f64>,
777 w_ab: ArrayView1<'_, f64>,
778 w_bb: ArrayView1<'_, f64>,
779) -> Option<Array2<f64>> {
780 let (n, pa) = x_a.dim();
781 let (n_b, pb) = x_b.dim();
782 if n != n_b || n != w_aa.len() || n != w_ab.len() || n != w_bb.len() || pa + pb == 0 {
783 return None;
784 }
785 #[cfg(not(target_os = "linux"))]
786 {
787 return None;
788 }
789 #[cfg(target_os = "linux")]
790 {
791 let runtime = route_through_gpu(DispatchOp::JointHessian2x2 { n, pa, pb })?;
792 cuda_backend::joint_hessian_2x2(runtime, x_a, x_b, w_aa, w_ab, w_bb)
793 }
794}
795
796#[inline]
797#[must_use]
798pub fn try_cholesky_lower_inplace(a: &mut Array2<f64>) -> Option<()> {
799 let p = a.nrows();
800 if p != a.ncols() {
801 return None;
802 }
803 #[cfg(not(target_os = "linux"))]
804 {
805 return None;
806 }
807 #[cfg(target_os = "linux")]
808 {
809 let runtime = route_through_gpu(DispatchOp::Potrf { p, batch: 1 })?;
810 let lower = cuda_backend::cholesky_lower(runtime, a.view())?;
811 *a = lower;
812 Some(())
813 }
814}
815
816#[inline]
817#[must_use]
818pub fn try_cholesky_batched_lower_inplace(matrices: &mut [Array2<f64>]) -> Option<()> {
819 let first = matrices.first()?;
820 let p = first.nrows();
821 if p == 0 || first.ncols() != p || matrices.iter().any(|matrix| matrix.dim() != (p, p)) {
822 return None;
823 }
824 #[cfg(not(target_os = "linux"))]
825 {
826 return None;
827 }
828 #[cfg(target_os = "linux")]
829 {
830 let batch = matrices.len();
831 let runtime = route_through_gpu(DispatchOp::SmallDenseBatchedPotrf { p, batch })
832 .or_else(|| route_through_gpu(DispatchOp::Potrf { p, batch }))?;
833 if should_split_batch(batch) {
834 let split = super::pool::scatter_batched(runtime, matrices, |ordinal, tile| {
840 cuda_backend::cholesky_batched_lower(ordinal, tile)
841 });
842 if split.is_some() {
843 return Some(());
844 }
845 }
846 cuda_backend::cholesky_batched_lower(runtime.device.ordinal, matrices)
847 }
848}
849
850#[inline]
851#[must_use]
852pub fn try_solve_lower_triangular_matrix(
853 lower: ArrayView2<'_, f64>,
854 rhs: ArrayView2<'_, f64>,
855) -> Option<Array2<f64>> {
856 let (m, n) = rhs.dim();
857 if m == 0 || n == 0 || lower.nrows() != m {
858 return None;
859 }
860 #[cfg(not(target_os = "linux"))]
861 {
862 return None;
863 }
864 #[cfg(target_os = "linux")]
865 {
866 let runtime = route_through_gpu(DispatchOp::Trsm { m, n })?;
867 cuda_backend::trsm(runtime, lower, rhs, false)
868 }
869}
870
871#[inline]
872#[must_use]
873pub fn try_solve_upper_triangular_matrix(
874 upper: ArrayView2<'_, f64>,
875 rhs: ArrayView2<'_, f64>,
876) -> Option<Array2<f64>> {
877 let (m, n) = rhs.dim();
878 if m == 0 || n == 0 || upper.nrows() != m {
879 return None;
880 }
881 #[cfg(not(target_os = "linux"))]
882 {
883 return None;
884 }
885 #[cfg(target_os = "linux")]
886 {
887 let runtime = route_through_gpu(DispatchOp::Trsm { m, n })?;
888 cuda_backend::trsm(runtime, upper, rhs, true)
889 }
890}
891
#[cfg(test)]
mod tests {
    use super::{DispatchOp, route_through_gpu};
    use crate::device_runtime::GpuRuntime;

    /// When a CUDA runtime exists, SAE-sized shapes must clear the GEMM work
    /// floor and be admitted by `route_through_gpu`. Without one, the check is
    /// skipped.
    #[test]
    fn sae_shape_dispatch_ops_route_when_cuda_runtime_is_present() {
        let runtime = match GpuRuntime::global() {
            Some(runtime) => runtime,
            None => {
                eprintln!("[sae dispatch gate] no CUDA runtime - skipping branch-admission check");
                return;
            }
        };

        let (rows, cols) = (2_000usize, 2_048usize);
        let (blocks, per_block) = (12usize, 8usize);
        let block_cols = blocks * per_block;
        let floor = runtime.policy.gemm_min_flops as u128;

        let fixtures = [
            DispatchOp::XtDiagX { n: rows, p: cols },
            DispatchOp::XtDiagY {
                n: rows,
                px: cols,
                q: block_cols,
            },
            DispatchOp::JointHessian2x2 {
                n: rows,
                pa: cols,
                pb: block_cols,
            },
            DispatchOp::Gemm {
                m: cols,
                n: cols,
                k: rows * blocks,
            },
        ];

        // `op` is captured by name in the assertion messages below.
        for op in fixtures.iter().copied() {
            let work = op.flops();
            assert!(
                work >= floor,
                "SAE dispatch fixture must clear the runtime GEMM work floor: op={op:?}, flops={}, floor={}",
                op.flops(),
                runtime.policy.gemm_min_flops
            );
            assert!(
                route_through_gpu(op).is_some(),
                "SAE dispatch fixture should route to GPU when CUDA is present: {op:?}"
            );
        }

        let small_dense = DispatchOp::SmallDenseBatchedPotrf {
            p: blocks,
            batch: rows,
        };
        assert!(
            route_through_gpu(small_dense).is_some(),
            "uniform SAE row blocks should reach the small-dense batched POTRF gate"
        );
    }
}
943
/// CUDA backend for the dispatch entry points in the parent module.
///
/// The GEMM/GEMV/weighted-product/TRSM helpers forward to `super::super::blas`.
/// The Cholesky helpers drive cuSOLVER directly. Every helper returns `Option`,
/// where `None` means the GPU could not produce a result.
#[cfg(target_os = "linux")]
mod cuda_backend {
    use ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, ArrayView3};

    use super::super::device_runtime::GpuRuntime;
    use crate::driver::{from_col_major, to_col_major, to_i32};
    use cudarc::cusolver::{DnHandle, sys as cusolver_sys};
    use cudarc::driver::{DevicePtrMut, sys as driver_sys};

    /// GEMM via `blas::gemm_cuda`; `trans_a` / `trans_b` request the
    /// transposed operand (the parent calls `(true, false)` for `aᵀ·b`).
    #[inline]
    pub(super) fn gemm(
        runtime: &GpuRuntime,
        a: ArrayView2<'_, f64>,
        b: ArrayView2<'_, f64>,
        trans_a: bool,
        trans_b: bool,
    ) -> Option<Array2<f64>> {
        super::super::blas::gemm_cuda(runtime, a, b, trans_a, trans_b)
    }

    /// Like [`gemm`], but targets an explicit device `ordinal`.
    #[inline]
    pub(super) fn gemm_on_ordinal(
        ordinal: usize,
        a: ArrayView2<'_, f64>,
        b: ArrayView2<'_, f64>,
        trans_a: bool,
        trans_b: bool,
    ) -> Option<Array2<f64>> {
        super::super::blas::gemm_on_ordinal_cuda(ordinal, a, b, trans_a, trans_b)
    }

    /// Matrix-vector product via `blas::gemv_cuda`; `trans_a` selects `aᵀ·v`.
    #[inline]
    pub(super) fn gemv(
        runtime: &GpuRuntime,
        a: ArrayView2<'_, f64>,
        v: ArrayView1<'_, f64>,
        trans_a: bool,
    ) -> Option<Array1<f64>> {
        super::super::blas::gemv_cuda(runtime, a, v, trans_a)
    }

    /// Batched `a[i] · b` with a shared `b`, on device `ordinal`.
    #[inline]
    pub(super) fn gemm_broadcast_b_batched(
        ordinal: usize,
        a: ArrayView3<'_, f64>,
        b: ArrayView2<'_, f64>,
    ) -> Option<Array3<f64>> {
        super::super::blas::gemm_broadcast_b_batched_cuda(ordinal, a, b)
    }

    /// Batched `a[i] · b[i]ᵀ` on device `ordinal`.
    #[inline]
    pub(super) fn gemm_abt_strided_batched(
        ordinal: usize,
        a: ArrayView3<'_, f64>,
        b: ArrayView3<'_, f64>,
    ) -> Option<Array3<f64>> {
        super::super::blas::gemm_abt_strided_batched_cuda(ordinal, a, b)
    }

    /// Weighted Gram via `blas::xt_diag_x_cuda`.
    #[inline]
    pub(super) fn xt_diag_x(
        runtime: &GpuRuntime,
        x: ArrayView2<'_, f64>,
        w: ArrayView1<'_, f64>,
    ) -> Option<Array2<f64>> {
        super::super::blas::xt_diag_x_cuda(runtime, x, w)
    }

    /// Weighted cross product via `blas::xt_diag_y_cuda`.
    #[inline]
    pub(super) fn xt_diag_y(
        runtime: &GpuRuntime,
        x: ArrayView2<'_, f64>,
        w: ArrayView1<'_, f64>,
        y: ArrayView2<'_, f64>,
    ) -> Option<Array2<f64>> {
        super::super::blas::xt_diag_y_cuda(runtime, x, w, y)
    }

    /// Joint two-block Hessian via `blas::joint_hessian_2x2_cuda`.
    #[inline]
    pub(super) fn joint_hessian_2x2(
        runtime: &GpuRuntime,
        x_a: ArrayView2<'_, f64>,
        x_b: ArrayView2<'_, f64>,
        w_aa: ArrayView1<'_, f64>,
        w_ab: ArrayView1<'_, f64>,
        w_bb: ArrayView1<'_, f64>,
    ) -> Option<Array2<f64>> {
        super::super::blas::joint_hessian_2x2_cuda(runtime, x_a, x_b, w_aa, w_ab, w_bb)
    }

    /// Triangular solve via `blas::trsm_cuda`; `upper` selects an upper
    /// (vs. lower) triangular factor.
    #[inline]
    pub(super) fn trsm(
        runtime: &GpuRuntime,
        triangular: ArrayView2<'_, f64>,
        rhs: ArrayView2<'_, f64>,
        upper: bool,
    ) -> Option<Array2<f64>> {
        super::super::blas::trsm_cuda(runtime, triangular, rhs, upper)
    }

    /// Lower Cholesky factor of the square matrix `a` on the runtime's primary
    /// device.
    ///
    /// Uploads `a` column-major and factors it in place with
    /// `potrf_lower_in_place`. It then downloads the result and zeroes the strict
    /// upper triangle. Returns `None` for empty/non-square input or if any CUDA
    /// step fails. NOTE(review): whether a non-positive-definite input yields
    /// `None` depends on `crate::solver::potrf_in_place_generic` checking
    /// `devInfo`; confirm.
    #[inline]
    pub(super) fn cholesky_lower(
        runtime: &GpuRuntime,
        a: ArrayView2<'_, f64>,
    ) -> Option<Array2<f64>> {
        let (p, p2) = a.dim();
        if p == 0 || p != p2 {
            return None;
        }
        let stream = super::super::device_runtime::cuda_context_for(runtime.device.ordinal)?
            .new_stream()
            .ok()?;
        let solver = DnHandle::new(stream.clone()).ok()?;
        let a_col = to_col_major(&a);
        let mut a_dev = stream.clone_htod(&*a_col).ok()?;
        potrf_lower_in_place(&solver, &stream, p, &mut a_dev)?;
        let factor_col = stream.clone_dtoh(&a_dev).ok()?;
        let mut lower = from_col_major(&factor_col, p, p)?;
        // The solver leaves the upper triangle untouched; clear it so the
        // result is a proper lower-triangular factor.
        for row in 0..p {
            for col in (row + 1)..p {
                lower[[row, col]] = 0.0;
            }
        }
        Some(lower)
    }

    /// Lower Cholesky factorisation of every matrix in `matrices` on device
    /// `ordinal`, using `cusolverDnDpotrfBatched`.
    ///
    /// All matrices must share one non-zero square order `p`. They are packed
    /// column-major into a single device buffer. An array of per-matrix device
    /// pointers into that buffer is passed to cuSOLVER. If any `info` entry is
    /// non-zero the function returns `None` before touching `matrices`.
    /// Otherwise each matrix is overwritten in a final loop with its factor,
    /// strict upper triangle zeroed.
    #[inline]
    pub(super) fn cholesky_batched_lower(
        ordinal: usize,
        matrices: &mut [Array2<f64>],
    ) -> Option<()> {
        let first = matrices.first()?;
        let p = first.nrows();
        if p == 0 || first.ncols() != p || matrices.iter().any(|matrix| matrix.dim() != (p, p)) {
            return None;
        }

        let stream = super::super::device_runtime::cuda_context_for(ordinal)?
            .new_stream()
            .ok()?;
        let solver = DnHandle::new(stream.clone()).ok()?;
        let matrix_len = p.checked_mul(p)?;
        // Contiguous column-major staging of the whole batch.
        let mut batch_col = Vec::with_capacity(matrices.len().checked_mul(matrix_len)?);
        for matrix in matrices.iter() {
            batch_col.extend(to_col_major(&matrix.view()).iter().copied());
        }
        let mut matrices_dev = stream.clone_htod(&batch_col).ok()?;
        // Device address of each matrix: base + idx * (p * p * 8 bytes).
        let matrix_ptrs = {
            let (base_ptr, _matrix_record) = matrices_dev.device_ptr_mut(&stream);
            let bytes_per_matrix = driver_sys::CUdeviceptr::try_from(
                matrix_len.checked_mul(std::mem::size_of::<f64>())?,
            )
            .ok()?;
            let mut matrix_ptrs = Vec::with_capacity(matrices.len());
            for idx in 0..matrices.len() {
                let offset = driver_sys::CUdeviceptr::try_from(idx).ok()? * bytes_per_matrix;
                matrix_ptrs.push(base_ptr + offset);
            }
            matrix_ptrs
        };
        let mut matrix_ptrs_dev = stream.clone_htod(&matrix_ptrs).ok()?;
        let mut info_dev = stream.alloc_zeros::<i32>(matrices.len()).ok()?;
        let p_i = to_i32(p)?;
        let batch_i = to_i32(matrices.len())?;
        {
            let (ptrs_ptr, _ptrs_record) = matrix_ptrs_dev.device_ptr_mut(&stream);
            let (info_ptr, _info_record) = info_dev.device_ptr_mut(&stream);
            let status = unsafe {
                // SAFETY: `ptrs_ptr` addresses `matrices.len()` device pointers,
                // each pointing at a `p × p` column-major block inside the live
                // `matrices_dev` allocation (leading dimension `p_i`).
                // `info_ptr` addresses `matrices.len()` `i32`s. `p_i` and
                // `batch_i` are those same sizes, and `matrix_ptrs_dev`,
                // `info_dev`, `matrices_dev` and `solver` all outlive the call.
                cusolver_sys::cusolverDnDpotrfBatched(
                    solver.cu(),
                    cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
                    p_i,
                    ptrs_ptr as *mut *mut f64,
                    p_i,
                    info_ptr as *mut i32,
                    batch_i,
                )
            };
            check_cusolver(status)?;
        }
        // A non-zero info marks a failed factorisation (e.g. not positive
        // definite); reject the whole batch before writing anything back.
        let info_host = stream.clone_dtoh(&info_dev).ok()?;
        if info_host.iter().any(|info| *info != 0) {
            return None;
        }
        let factored_col = stream.clone_dtoh(&matrices_dev).ok()?;
        for (idx, matrix) in matrices.iter_mut().enumerate() {
            let start = idx.checked_mul(matrix_len)?;
            let end = start.checked_add(matrix_len)?;
            let mut lower = from_col_major(&factored_col[start..end], p, p)?;
            for row in 0..p {
                for col in (row + 1)..p {
                    lower[[row, col]] = 0.0;
                }
            }
            *matrix = lower;
        }
        Some(())
    }

    /// Runs `crate::solver::potrf_in_place_generic::<f64>` on the `p × p`
    /// device matrix `a`, discarding the error detail.
    fn potrf_lower_in_place(
        solver: &DnHandle,
        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
        p: usize,
        a: &mut cudarc::driver::CudaSlice<f64>,
    ) -> Option<()> {
        crate::solver::potrf_in_place_generic::<f64>(solver, stream, p, a).ok()
    }

    /// Maps a cuSOLVER status to `Some(())` on success and `None` otherwise.
    #[inline]
    fn check_cusolver(status: cusolver_sys::cusolverStatus_t) -> Option<()> {
        if status == cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
            Some(())
        } else {
            None
        }
    }
}