1use ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, ArrayView3};
23
24use super::device_runtime::GpuRuntime;
25
/// Zero-sized handle that exposes this module's `try_fast_*` entry points
/// through the `gam_linalg::gpu_hook::GpuGemmDispatch` trait.
///
/// The type carries no state; the standard derives are provided so it can be
/// logged, copied and default-constructed like any other unit marker.
#[derive(Clone, Copy, Debug, Default)]
pub struct CudaGemmDispatch;
27
// Every method below forwards to the free function of the same name defined
// in this module; the trait impl adds no logic of its own apart from
// `device_count`.
impl gam_linalg::gpu_hook::GpuGemmDispatch for CudaGemmDispatch {
    /// Forwards to the module-level [`try_fast_atb`].
    fn try_fast_atb(&self, a: ArrayView2<'_, f64>, b: ArrayView2<'_, f64>) -> Option<Array2<f64>> {
        try_fast_atb(a, b)
    }

    /// Forwards to the module-level [`try_fast_ab`].
    fn try_fast_ab(&self, a: ArrayView2<'_, f64>, b: ArrayView2<'_, f64>) -> Option<Array2<f64>> {
        try_fast_ab(a, b)
    }

    /// Forwards to the module-level [`try_fast_av`].
    fn try_fast_av(&self, a: ArrayView2<'_, f64>, v: ArrayView1<'_, f64>) -> Option<Array1<f64>> {
        try_fast_av(a, v)
    }

    /// Forwards to the module-level [`try_fast_atv`].
    fn try_fast_atv(&self, a: ArrayView2<'_, f64>, v: ArrayView1<'_, f64>) -> Option<Array1<f64>> {
        try_fast_atv(a, v)
    }

    /// Forwards to the module-level [`try_fast_xt_diag_x`].
    fn try_fast_xt_diag_x(
        &self,
        x: ArrayView2<'_, f64>,
        w: ArrayView1<'_, f64>,
    ) -> Option<Array2<f64>> {
        try_fast_xt_diag_x(x, w)
    }

    /// Forwards to the module-level [`try_fast_xt_diag_y`].
    fn try_fast_xt_diag_y(
        &self,
        x: ArrayView2<'_, f64>,
        w: ArrayView1<'_, f64>,
        y: ArrayView2<'_, f64>,
    ) -> Option<Array2<f64>> {
        try_fast_xt_diag_y(x, w, y)
    }

    /// Forwards to the module-level [`try_fast_joint_hessian_2x2`].
    fn try_fast_joint_hessian_2x2(
        &self,
        x_a: ArrayView2<'_, f64>,
        x_b: ArrayView2<'_, f64>,
        w_aa: ArrayView1<'_, f64>,
        w_ab: ArrayView1<'_, f64>,
        w_bb: ArrayView1<'_, f64>,
    ) -> Option<Array2<f64>> {
        try_fast_joint_hessian_2x2(x_a, x_b, w_aa, w_ab, w_bb)
    }

    /// Device count reported by the global runtime, or `0` when
    /// `GpuRuntime::global()` yields no runtime.
    fn device_count(&self) -> usize {
        GpuRuntime::global().map_or(0, |rt| rt.device_count())
    }

    /// Forwards to the module-level [`try_fast_ab_broadcast_b_batched`].
    fn try_fast_ab_broadcast_b_batched(
        &self,
        a3: ArrayView3<'_, f64>,
        b: ArrayView2<'_, f64>,
    ) -> Option<Array3<f64>> {
        try_fast_ab_broadcast_b_batched(a3, b)
    }
}
85
/// Shape descriptor for one dense linear-algebra call that is a candidate for
/// GPU offload.
///
/// A variant carries only the extents needed to estimate the amount of work
/// (see [`DispatchOp::flops`]) and to evaluate the admission rules in
/// `route_through_gpu`.
#[derive(Clone, Copy, Debug)]
pub enum DispatchOp {
    /// Single matrix product; work is counted as `2·m·n·k`.
    Gemm { m: usize, n: usize, k: usize },
    /// `batch` independent matrix products of identical `m`/`n`/`k` extents.
    BatchedGemm {
        batch: usize,
        m: usize,
        n: usize,
        k: usize,
    },
    /// Cholesky factorization of `batch` matrices of order `p`.
    Potrf { p: usize, batch: usize },
    /// Batched Cholesky of many small `p × p` matrices; admitted by the
    /// dedicated small-dense policy limits rather than by the flop floor.
    SmallDenseBatchedPotrf { p: usize, batch: usize },
    /// Triangular solve with an order-`m` triangle and `n` right-hand sides.
    Trsm { m: usize, n: usize },
    /// Matrix-vector product; work is counted as `2·m·k`.
    Gemv { m: usize, k: usize },
    /// Weighted Gram-style reduction over an `n × p` matrix.
    XtDiagX { n: usize, p: usize },
    /// Weighted cross product of an `n × px` and an `n × q` matrix.
    XtDiagY { n: usize, px: usize, q: usize },
    /// Joint 2×2 block Hessian over two designs with `pa` and `pb` columns.
    JointHessian2x2 { n: usize, pa: usize, pb: usize },
}

impl DispatchOp {
    /// Saturating product of three extents, widened to `u128`.
    #[inline]
    const fn product3(a: usize, b: usize, c: usize) -> u128 {
        (a as u128)
            .saturating_mul(b as u128)
            .saturating_mul(c as u128)
    }

    /// Estimated floating-point operation count for this operation.
    ///
    /// All arithmetic saturates at `u128::MAX`. Three or four `usize` extents
    /// multiplied together can exceed 128 bits, and an unchecked product
    /// would panic in debug builds or wrap to a small number in release
    /// builds. Saturating keeps the estimate monotone, so an absurdly large
    /// shape is never mistaken for a tiny one.
    #[inline]
    pub const fn flops(self) -> u128 {
        match self {
            Self::Gemm { m, n, k } => Self::product3(m, n, k).saturating_mul(2),
            Self::BatchedGemm { batch, m, n, k } => Self::product3(m, n, k)
                .saturating_mul(batch as u128)
                .saturating_mul(2),
            Self::Gemv { m, k } => Self::product3(2, m, k),
            // Both Cholesky flavours share the `batch · p³ / 3` estimate.
            Self::Potrf { p, batch } | Self::SmallDenseBatchedPotrf { p, batch } => {
                Self::product3(p, p, p).saturating_mul(batch as u128) / 3
            }
            Self::Trsm { m, n } => Self::product3(m, m, n),
            Self::XtDiagX { n, p } => Self::product3(n, p, p).saturating_mul(2),
            Self::XtDiagY { n, px, q } => Self::product3(n, px, q).saturating_mul(2),
            Self::JointHessian2x2 { n, pa, pb } => {
                // The sum of two `usize` values always fits in `u128`.
                let total = (pa as u128) + (pb as u128);
                (n as u128)
                    .saturating_mul(total)
                    .saturating_mul(total)
                    .saturating_mul(2)
            }
        }
    }
}
140
141#[inline]
146#[must_use]
147pub fn route_through_gpu(op: DispatchOp) -> Option<&'static GpuRuntime> {
148 let runtime = GpuRuntime::global()?;
149 let policy = &runtime.policy;
150 let admit = match op {
151 DispatchOp::Gemm { m, n, k } => {
152 op.flops() >= (policy.gemm_min_flops as u128) && m.min(n).min(k) > 0
153 }
154 DispatchOp::BatchedGemm { batch, m, n, k } => {
155 op.flops() >= (policy.gemm_min_flops as u128) && batch > 1 && m.min(n).min(k) > 0
156 }
157 DispatchOp::Gemv { m, k } => {
158 op.flops() >= (policy.gemm_min_flops as u128) && m > 0 && k > 0
159 }
160 DispatchOp::Potrf { p, batch } => {
161 p > 0
162 && batch > 0
163 && (p >= policy.potrf_min_p
164 || (batch > 1 && op.flops() >= policy.gemm_min_flops as u128))
165 }
166 DispatchOp::SmallDenseBatchedPotrf { p, batch } => {
167 p > 0
168 && p <= policy.small_dense_batched_potrf_max_p
169 && batch >= policy.small_dense_batched_potrf_min_batch
170 }
171 DispatchOp::Trsm { m, n } => {
172 op.flops() >= (policy.gemm_min_flops as u128) && m > 0 && n > 0
173 }
174 DispatchOp::XtDiagX { n, p } => policy.xtwx_target_is_gpu(n, p, true),
175 DispatchOp::XtDiagY { n, px, q } => policy.xtwy_target_is_gpu(n, px, q, true),
176 DispatchOp::JointHessian2x2 { n, pa, pb } => {
177 n > 0 && (pa + pb) > 0 && op.flops() >= policy.gemm_min_flops as u128
178 }
179 };
180 if admit { Some(runtime) } else { None }
181}
182
/// Smallest batch size for which `should_split_batch` allows a batch to be
/// scattered across more than one device.
#[cfg(target_os = "linux")]
const MULTI_GPU_BATCH_FLOOR: usize = 64;
191
192#[cfg(target_os = "linux")]
195#[inline]
196fn should_split_batch(batch: usize) -> bool {
197 GpuRuntime::global().is_some_and(|rt| rt.device_count() > 1) && batch >= MULTI_GPU_BATCH_FLOOR
198}
199
200#[inline]
201#[must_use]
202pub fn try_fast_ab_broadcast_b_batched(
203 a: ArrayView3<'_, f64>,
204 b: ArrayView2<'_, f64>,
205) -> Option<Array3<f64>> {
206 let (batch, m, k) = a.dim();
207 let (bk, n) = b.dim();
208 if k != bk || batch == 0 || m == 0 || n == 0 {
209 return None;
210 }
211 #[cfg(not(target_os = "linux"))]
212 {
213 return None;
214 }
215 #[cfg(target_os = "linux")]
216 {
217 let runtime = route_through_gpu(DispatchOp::BatchedGemm { batch, m, n, k })?;
218 if should_split_batch(batch) {
219 if let Some(out) = scatter_broadcast_b_batched(runtime, a, b, m, n) {
220 return Some(out);
221 }
222 }
225 cuda_backend::gemm_broadcast_b_batched(runtime.device.ordinal, a, b)
226 }
227}
228
229#[cfg(target_os = "linux")]
235fn scatter_broadcast_b_batched(
236 runtime: &GpuRuntime,
237 a: ArrayView3<'_, f64>,
238 b: ArrayView2<'_, f64>,
239 m: usize,
240 n: usize,
241) -> Option<Array3<f64>> {
242 let batch = a.dim().0;
243 let mut items: Vec<(Array2<f64>, Option<Array2<f64>>)> = (0..batch)
246 .map(|i| (a.index_axis(ndarray::Axis(0), i).to_owned(), None))
247 .collect();
248 super::pool::scatter_batched(runtime, &mut items, |ordinal, tile| {
249 let tile_batch = tile.len();
250 if tile_batch == 0 {
251 return Some(());
252 }
253 let k = b.dim().0;
254 let mut a_tile = Array3::<f64>::zeros((tile_batch, m, k));
255 for (idx, (a_i, _)) in tile.iter().enumerate() {
256 a_tile.index_axis_mut(ndarray::Axis(0), idx).assign(a_i);
257 }
258 let out = cuda_backend::gemm_broadcast_b_batched(ordinal, a_tile.view(), b)?;
259 for (idx, (_, slot)) in tile.iter_mut().enumerate() {
260 *slot = Some(out.index_axis(ndarray::Axis(0), idx).to_owned());
261 }
262 Some(())
263 })?;
264 stitch_batched(items, m, n)
265}
266
267#[inline]
268#[must_use]
269pub fn try_fast_abt_strided_batched(
270 a: ArrayView3<'_, f64>,
271 b: ArrayView3<'_, f64>,
272) -> Option<Array3<f64>> {
273 let (batch, m, k) = a.dim();
274 let (batch_b, n, k_b) = b.dim();
275 if batch != batch_b || k != k_b || batch == 0 || m == 0 || n == 0 {
276 return None;
277 }
278 #[cfg(not(target_os = "linux"))]
279 {
280 return None;
281 }
282 #[cfg(target_os = "linux")]
283 {
284 let runtime = route_through_gpu(DispatchOp::BatchedGemm { batch, m, n, k })?;
285 if should_split_batch(batch) {
286 if let Some(out) = scatter_abt_strided_batched(runtime, a, b, m, n) {
287 return Some(out);
288 }
289 }
290 cuda_backend::gemm_abt_strided_batched(runtime.device.ordinal, a, b)
291 }
292}
293
294#[cfg(target_os = "linux")]
299fn scatter_abt_strided_batched(
300 runtime: &GpuRuntime,
301 a: ArrayView3<'_, f64>,
302 b: ArrayView3<'_, f64>,
303 m: usize,
304 n: usize,
305) -> Option<Array3<f64>> {
306 let batch = a.dim().0;
307 let mut items: Vec<(Array2<f64>, Array2<f64>, Option<Array2<f64>>)> = (0..batch)
308 .map(|i| {
309 (
310 a.index_axis(ndarray::Axis(0), i).to_owned(),
311 b.index_axis(ndarray::Axis(0), i).to_owned(),
312 None,
313 )
314 })
315 .collect();
316 super::pool::scatter_batched(runtime, &mut items, |ordinal, tile| {
317 let tile_batch = tile.len();
318 if tile_batch == 0 {
319 return Some(());
320 }
321 let k = tile[0].0.dim().1;
322 let mut a_tile = Array3::<f64>::zeros((tile_batch, m, k));
323 let mut b_tile = Array3::<f64>::zeros((tile_batch, n, k));
324 for (idx, (a_i, b_i, _)) in tile.iter().enumerate() {
325 a_tile.index_axis_mut(ndarray::Axis(0), idx).assign(a_i);
326 b_tile.index_axis_mut(ndarray::Axis(0), idx).assign(b_i);
327 }
328 let out = cuda_backend::gemm_abt_strided_batched(ordinal, a_tile.view(), b_tile.view())?;
329 for (idx, (_, _, slot)) in tile.iter_mut().enumerate() {
330 *slot = Some(out.index_axis(ndarray::Axis(0), idx).to_owned());
331 }
332 Some(())
333 })?;
334 let slots: Vec<((), Option<Array2<f64>>)> =
335 items.into_iter().map(|(_, _, slot)| ((), slot)).collect();
336 stitch_batched(slots, m, n)
337}
338
339#[cfg(target_os = "linux")]
343fn stitch_batched<L>(
344 items: Vec<(L, Option<Array2<f64>>)>,
345 m: usize,
346 n: usize,
347) -> Option<Array3<f64>> {
348 let batch = items.len();
349 let mut out = Array3::<f64>::zeros((batch, m, n));
350 for (idx, (_, slot)) in items.into_iter().enumerate() {
351 let block = slot?;
352 if block.dim() != (m, n) {
353 return None;
354 }
355 out.index_axis_mut(ndarray::Axis(0), idx).assign(&block);
356 }
357 Some(out)
358}
359
360#[inline]
371#[must_use]
372pub fn try_fast_ab(a: ArrayView2<'_, f64>, b: ArrayView2<'_, f64>) -> Option<Array2<f64>> {
373 let (m, k) = a.dim();
374 let (kb, n) = b.dim();
375 if k != kb {
376 return None;
377 }
378 let runtime = route_through_gpu(DispatchOp::Gemm { m, n, k });
384 let used_gpu = runtime.is_some();
385 super::profile::record(super::profile::KernelStat {
386 name: "try_fast_ab",
387 n: m,
388 p: n,
389 k,
390 flops_est: (DispatchOp::Gemm { m, n, k }.flops().min(usize::MAX as u128)) as usize,
391 gpu_ms: if used_gpu { Some(0.0) } else { None },
392 ..Default::default()
393 });
394 #[cfg(not(target_os = "linux"))]
395 {
396 None
397 }
398 #[cfg(target_os = "linux")]
399 {
400 let runtime = runtime?;
401 cuda_backend::gemm(runtime, a, b, false, false)
402 }
403}
404
405#[inline]
406#[must_use]
407pub fn try_fast_atb(a: ArrayView2<'_, f64>, b: ArrayView2<'_, f64>) -> Option<Array2<f64>> {
408 let (n_a, p) = a.dim();
409 let (n_b, q) = b.dim();
410 if n_a != n_b || p == 0 || q == 0 {
411 return None;
412 }
413 #[cfg(not(target_os = "linux"))]
414 {
415 return None;
416 }
417 #[cfg(target_os = "linux")]
418 {
419 let runtime = route_through_gpu(DispatchOp::Gemm { m: p, n: q, k: n_a })?;
420 cuda_backend::gemm(runtime, a, b, true, false)
421 }
422}
423
424#[inline]
432#[must_use]
433pub fn try_fast_atb_on_ordinal(
434 ordinal: usize,
435 a: ArrayView2<'_, f64>,
436 b: ArrayView2<'_, f64>,
437) -> Option<Array2<f64>> {
438 let (n_a, p) = a.dim();
439 let (n_b, q) = b.dim();
440 if n_a != n_b || p == 0 || q == 0 {
441 return None;
442 }
443 #[cfg(not(target_os = "linux"))]
444 {
445 log::trace!(
451 "try_fast_atb_on_ordinal: CUDA unavailable off Linux; declining ordinal {ordinal}"
452 );
453 return None;
454 }
455 #[cfg(target_os = "linux")]
456 {
457 route_through_gpu(DispatchOp::Gemm { m: p, n: q, k: n_a })?;
469 cuda_backend::gemm_on_ordinal(ordinal, a, b, true, false)
470 }
471}
472
473#[inline]
474#[must_use]
475pub fn try_fast_av(a: ArrayView2<'_, f64>, v: ArrayView1<'_, f64>) -> Option<Array1<f64>> {
476 let (m, k) = a.dim();
477 if k != v.len() || m == 0 || k == 0 {
478 return None;
479 }
480 #[cfg(not(target_os = "linux"))]
481 {
482 return None;
483 }
484 #[cfg(target_os = "linux")]
485 {
486 let runtime = route_through_gpu(DispatchOp::Gemv { m, k })?;
487 cuda_backend::gemv(runtime, a, v, false)
488 }
489}
490
491#[inline]
492#[must_use]
493pub fn try_fast_atv(a: ArrayView2<'_, f64>, v: ArrayView1<'_, f64>) -> Option<Array1<f64>> {
494 let (n, p) = a.dim();
495 if n != v.len() || n == 0 || p == 0 {
496 return None;
497 }
498 #[cfg(not(target_os = "linux"))]
499 {
500 return None;
501 }
502 #[cfg(target_os = "linux")]
503 {
504 let runtime = route_through_gpu(DispatchOp::Gemv { m: p, k: n })?;
505 cuda_backend::gemv(runtime, a, v, true)
506 }
507}
508
509#[inline]
510#[must_use]
511pub fn try_fast_xt_diag_x(x: ArrayView2<'_, f64>, w: ArrayView1<'_, f64>) -> Option<Array2<f64>> {
512 let (n, p) = x.dim();
513 if n != w.len() || n == 0 || p == 0 {
514 return None;
515 }
516 #[cfg(not(target_os = "linux"))]
517 {
518 return None;
519 }
520 #[cfg(target_os = "linux")]
521 {
522 let runtime = route_through_gpu(DispatchOp::XtDiagX { n, p })?;
523 cuda_backend::xt_diag_x(runtime, x, w)
524 }
525}
526
/// Handle around `super::blas::ResidentWeightedGram`.
///
/// On Linux the struct owns the backend object. On every other target its
/// only field is `std::convert::Infallible`, so no value of this type can
/// exist there.
pub struct ResidentDesignGram {
    /// Backend object that services `gram`, `solve_normal_equations` and
    /// `dims`.
    #[cfg(target_os = "linux")]
    inner: super::blas::ResidentWeightedGram,
    /// Uninhabited placeholder that makes the type unconstructible off Linux.
    #[cfg(not(target_os = "linux"))]
    _never: std::convert::Infallible,
}
552
impl ResidentDesignGram {
    /// Build a handle for the design matrix `x`.
    ///
    /// Returns `None` when `x` has zero rows or columns, on non-Linux
    /// targets, when `route_through_gpu` declines an `XtDiagX { n, p }`
    /// operation of this shape, or when `ResidentWeightedGram::new` itself
    /// returns `None`.
    #[must_use]
    pub fn try_new(x: ArrayView2<'_, f64>) -> Option<Self> {
        let (n, p) = x.dim();
        if n == 0 || p == 0 {
            return None;
        }
        #[cfg(not(target_os = "linux"))]
        {
            None
        }
        #[cfg(target_os = "linux")]
        {
            let runtime = route_through_gpu(DispatchOp::XtDiagX { n, p })?;
            let inner = super::blas::ResidentWeightedGram::new(runtime.device.ordinal, x)?;
            Some(Self { inner })
        }
    }

    /// Forwards `w` to the backend object's `gram` and returns its result.
    ///
    /// NOTE(review): presumably the weighted Gram matrix of the stored design
    /// for weights `w`; confirm against `ResidentWeightedGram::gram`.
    ///
    /// # Panics
    ///
    /// The non-Linux body panics, but no value of this type can be
    /// constructed there, so that branch cannot be reached through safe code.
    #[must_use]
    pub fn gram(&self, w: ArrayView1<'_, f64>) -> Option<Array2<f64>> {
        #[cfg(not(target_os = "linux"))]
        {
            // The arguments are formatted into the message so they count as
            // used on this target.
            panic!(
                "ResidentDesignGram cannot be constructed off CUDA (w.len()={})",
                w.len()
            )
        }
        #[cfg(target_os = "linux")]
        {
            self.inner.gram(w)
        }
    }

    /// Forwards `w`, `rhs` and `ridge` to the backend object's
    /// `solve_psd_normal_equations` and returns its result.
    ///
    /// # Panics
    ///
    /// The non-Linux body panics, but no value of this type can be
    /// constructed there, so that branch cannot be reached through safe code.
    #[must_use]
    pub fn solve_normal_equations(
        &self,
        w: ArrayView1<'_, f64>,
        rhs: ArrayView1<'_, f64>,
        ridge: f64,
    ) -> Option<Array1<f64>> {
        #[cfg(not(target_os = "linux"))]
        {
            panic!(
                "ResidentDesignGram cannot be constructed off CUDA (w.len()={}, rhs.len()={}, ridge={ridge})",
                w.len(),
                rhs.len()
            )
        }
        #[cfg(target_os = "linux")]
        {
            self.inner.solve_psd_normal_equations(w, rhs, ridge)
        }
    }

    /// Returns the backend object's `dims()` pair.
    ///
    /// # Panics
    ///
    /// The non-Linux body panics, but no value of this type can be
    /// constructed there, so that branch cannot be reached through safe code.
    #[must_use]
    pub fn dims(&self) -> (usize, usize) {
        #[cfg(not(target_os = "linux"))]
        {
            panic!("ResidentDesignGram cannot be constructed off CUDA")
        }
        #[cfg(target_os = "linux")]
        {
            self.inner.dims()
        }
    }
}
646
/// Number of row chunks per device that
/// `try_fast_spectral_leverage_diagonal` aims for when sizing its tiles.
#[cfg(target_os = "linux")]
const LEVERAGE_CHUNKS_PER_DEVICE: usize = 4;
654
/// Number of rows per chunk so that one chunk of `cols` `f64` columns is
/// about 8 MiB.
///
/// The result is never below 512 rows unless the input itself is shorter, and
/// never above `n_rows` (treated as 1 when `n_rows` is 0).
#[cfg(target_os = "linux")]
#[inline]
fn leverage_chunk_rows(cols: usize, n_rows: usize) -> usize {
    const TARGET_BYTES: usize = 8 * 1024 * 1024;
    const MIN_CHUNK_ROWS: usize = 512;
    let row_bytes = std::mem::size_of::<f64>() * cols.max(1);
    let budget_rows = TARGET_BYTES / row_bytes;
    let floored = if budget_rows < MIN_CHUNK_ROWS {
        MIN_CHUNK_ROWS
    } else {
        budget_rows
    };
    let available = if n_rows == 0 { 1 } else { n_rows };
    if floored > available {
        available
    } else {
        floored
    }
}
669
670#[inline]
686#[must_use]
687pub fn try_fast_spectral_leverage_diagonal(
688 x: &gam_linalg::matrix::DesignMatrix,
689 g: ArrayView2<'_, f64>,
690) -> Option<Array1<f64>> {
691 let n = x.nrows();
692 let p = x.ncols();
693 let rank = g.ncols();
694 if n == 0 || p == 0 || rank == 0 || g.nrows() != p {
695 return None;
696 }
697 #[cfg(not(target_os = "linux"))]
698 {
699 return None;
700 }
701 #[cfg(target_os = "linux")]
702 {
703 let runtime = route_through_gpu(DispatchOp::XtDiagX { n, p })?;
706 let device_count = runtime.device_count().max(1);
707 let byte_chunk = leverage_chunk_rows(p + rank, n);
708 let target_chunks = device_count
709 .saturating_mul(LEVERAGE_CHUNKS_PER_DEVICE)
710 .max(1);
711 let chunk_rows = byte_chunk.min(n.div_ceil(target_chunks).max(1)).max(1);
712
713 let mut tiles: Vec<(std::ops::Range<usize>, Option<Array1<f64>>)> = Vec::new();
716 let mut start = 0usize;
717 while start < n {
718 let end = (start + chunk_rows).min(n);
719 tiles.push((start..end, None));
720 start = end;
721 }
722
723 super::pool::scatter_batched(runtime, &mut tiles, |ordinal, tile| {
724 for (range, slot) in tile.iter_mut() {
725 let rows = x.try_row_chunk(range.clone()).ok()?;
726 let xg = cuda_backend::gemm_on_ordinal(ordinal, rows.view(), g, false, false)?;
727 let mut out = Array1::<f64>::zeros(range.end - range.start);
728 for (local, row) in xg.outer_iter().enumerate() {
729 out[local] = row.iter().map(|&v| v * v).sum();
730 }
731 *slot = Some(out);
732 }
733 Some(())
734 })?;
735
736 let mut h = Array1::<f64>::zeros(n);
737 for (range, slot) in tiles {
738 let vals = slot?;
739 if vals.len() != range.end - range.start {
740 return None;
741 }
742 h.slice_mut(ndarray::s![range]).assign(&vals);
743 }
744 Some(h)
745 }
746}
747
748#[inline]
749#[must_use]
750pub fn try_fast_xt_diag_y(
751 x: ArrayView2<'_, f64>,
752 w: ArrayView1<'_, f64>,
753 y: ArrayView2<'_, f64>,
754) -> Option<Array2<f64>> {
755 let (n, px) = x.dim();
756 let (n_y, q) = y.dim();
757 if n != n_y || n != w.len() || n == 0 || px == 0 || q == 0 {
758 return None;
759 }
760 #[cfg(not(target_os = "linux"))]
761 {
762 return None;
763 }
764 #[cfg(target_os = "linux")]
765 {
766 let runtime = route_through_gpu(DispatchOp::XtDiagY { n, px, q })?;
767 cuda_backend::xt_diag_y(runtime, x, w, y)
768 }
769}
770
771#[inline]
772#[must_use]
773pub fn try_fast_joint_hessian_2x2(
774 x_a: ArrayView2<'_, f64>,
775 x_b: ArrayView2<'_, f64>,
776 w_aa: ArrayView1<'_, f64>,
777 w_ab: ArrayView1<'_, f64>,
778 w_bb: ArrayView1<'_, f64>,
779) -> Option<Array2<f64>> {
780 let (n, pa) = x_a.dim();
781 let (n_b, pb) = x_b.dim();
782 if n != n_b || n != w_aa.len() || n != w_ab.len() || n != w_bb.len() || pa + pb == 0 {
783 return None;
784 }
785 #[cfg(not(target_os = "linux"))]
786 {
787 return None;
788 }
789 #[cfg(target_os = "linux")]
790 {
791 let runtime = route_through_gpu(DispatchOp::JointHessian2x2 { n, pa, pb })?;
792 cuda_backend::joint_hessian_2x2(runtime, x_a, x_b, w_aa, w_ab, w_bb)
793 }
794}
795
796#[inline]
797#[must_use]
798pub fn try_cholesky_lower_inplace(a: &mut Array2<f64>) -> Option<()> {
799 let p = a.nrows();
800 if p != a.ncols() {
801 return None;
802 }
803 #[cfg(not(target_os = "linux"))]
804 {
805 return None;
806 }
807 #[cfg(target_os = "linux")]
808 {
809 let runtime = route_through_gpu(DispatchOp::Potrf { p, batch: 1 })?;
810 let lower = cuda_backend::cholesky_lower(runtime, a.view())?;
811 *a = lower;
812 Some(())
813 }
814}
815
816#[inline]
817#[must_use]
818pub fn try_cholesky_batched_lower_inplace(matrices: &mut [Array2<f64>]) -> Option<()> {
819 let first = matrices.first()?;
820 let p = first.nrows();
821 if p == 0 || first.ncols() != p || matrices.iter().any(|matrix| matrix.dim() != (p, p)) {
822 return None;
823 }
824 #[cfg(not(target_os = "linux"))]
825 {
826 return None;
827 }
828 #[cfg(target_os = "linux")]
829 {
830 let batch = matrices.len();
831 let runtime = route_through_gpu(DispatchOp::SmallDenseBatchedPotrf { p, batch })
832 .or_else(|| route_through_gpu(DispatchOp::Potrf { p, batch }))?;
833 if should_split_batch(batch) {
834 let split = super::pool::scatter_batched(runtime, matrices, |ordinal, tile| {
840 cuda_backend::cholesky_batched_lower(ordinal, tile)
841 });
842 if split.is_some() {
843 return Some(());
844 }
845 }
846 cuda_backend::cholesky_batched_lower(runtime.device.ordinal, matrices)
847 }
848}
849
850#[inline]
851#[must_use]
852pub fn try_solve_lower_triangular_matrix(
853 lower: ArrayView2<'_, f64>,
854 rhs: ArrayView2<'_, f64>,
855) -> Option<Array2<f64>> {
856 let (m, n) = rhs.dim();
857 if m == 0 || n == 0 || lower.nrows() != m {
858 return None;
859 }
860 #[cfg(not(target_os = "linux"))]
861 {
862 return None;
863 }
864 #[cfg(target_os = "linux")]
865 {
866 let runtime = route_through_gpu(DispatchOp::Trsm { m, n })?;
867 cuda_backend::trsm(runtime, lower, rhs, false)
868 }
869}
870
871#[inline]
872#[must_use]
873pub fn try_solve_upper_triangular_matrix(
874 upper: ArrayView2<'_, f64>,
875 rhs: ArrayView2<'_, f64>,
876) -> Option<Array2<f64>> {
877 let (m, n) = rhs.dim();
878 if m == 0 || n == 0 || upper.nrows() != m {
879 return None;
880 }
881 #[cfg(not(target_os = "linux"))]
882 {
883 return None;
884 }
885 #[cfg(target_os = "linux")]
886 {
887 let runtime = route_through_gpu(DispatchOp::Trsm { m, n })?;
888 cuda_backend::trsm(runtime, upper, rhs, true)
889 }
890}
891
#[cfg(test)]
mod tests {
    use super::{DispatchOp, route_through_gpu, try_fast_ab};
    use crate::device_runtime::GpuRuntime;
    use ndarray::Array2;

    /// Largest element-wise gap tolerated between a device result and the
    /// CPU oracle.
    const ORACLE_TOL: f64 = 1e-9;

    /// Deterministic fixture whose `(i, j)` entry is
    /// `((i * mul.0 + j * mul.1) % modulus) * scale - shift`.
    fn patterned(
        shape: (usize, usize),
        mul: (usize, usize),
        modulus: usize,
        scale: f64,
        shift: f64,
    ) -> Array2<f64> {
        Array2::<f64>::from_shape_fn(shape, |(i, j)| {
            ((i * mul.0 + j * mul.1) % modulus) as f64 * scale - shift
        })
    }

    /// Triple-loop reference product `op(a) · op(b)` of logical shape
    /// `m × k × n`, where `op` transposes its operand when the flag is set.
    fn naive_product(
        a: &Array2<f64>,
        b: &Array2<f64>,
        trans_a: bool,
        trans_b: bool,
        (m, k, n): (usize, usize, usize),
    ) -> Array2<f64> {
        let mut cpu = Array2::<f64>::zeros((m, n));
        for i in 0..m {
            for j in 0..n {
                let mut acc = 0.0f64;
                for p in 0..k {
                    let av = if trans_a { a[[p, i]] } else { a[[i, p]] };
                    let bv = if trans_b { b[[j, p]] } else { b[[p, j]] };
                    acc += av * bv;
                }
                cpu[[i, j]] = acc;
            }
        }
        cpu
    }

    /// Maximum absolute element-wise difference over the leading `m × n`
    /// entries of both arrays.
    fn max_abs_gap(lhs: &Array2<f64>, rhs: &Array2<f64>, (m, n): (usize, usize)) -> f64 {
        let mut worst = 0.0f64;
        for i in 0..m {
            for j in 0..n {
                worst = worst.max((lhs[[i, j]] - rhs[[i, j]]).abs());
            }
        }
        worst
    }

    /// With a CUDA runtime present, representative large shapes must clear
    /// the work floor and be admitted by `route_through_gpu`.
    #[test]
    fn sae_shape_dispatch_ops_route_when_cuda_runtime_is_present() {
        let runtime = match GpuRuntime::global() {
            Some(runtime) => runtime,
            None => {
                eprintln!("[sae dispatch gate] no CUDA runtime - skipping branch-admission check");
                return;
            }
        };

        let (n, p, m, k) = (2_000usize, 2_048usize, 12usize, 8usize);
        let dense_reduction_ops = [
            DispatchOp::XtDiagX { n, p },
            DispatchOp::XtDiagY { n, px: p, q: m * k },
            DispatchOp::JointHessian2x2 {
                n,
                pa: p,
                pb: m * k,
            },
            DispatchOp::Gemm {
                m: p,
                n: p,
                k: n * m,
            },
        ];

        for op in dense_reduction_ops {
            let clears_floor = op.flops() >= runtime.policy.gemm_min_flops as u128;
            assert!(
                clears_floor,
                "SAE dispatch fixture must clear the runtime GEMM work floor: op={op:?}, flops={}, floor={}",
                op.flops(),
                runtime.policy.gemm_min_flops
            );
            let routed = route_through_gpu(op);
            assert!(
                routed.is_some(),
                "SAE dispatch fixture should route to GPU when CUDA is present: {op:?}"
            );
        }

        let batched_potrf = DispatchOp::SmallDenseBatchedPotrf { p: m, batch: n };
        let routed = route_through_gpu(batched_potrf);
        assert!(
            routed.is_some(),
            "uniform SAE row blocks should reach the small-dense batched POTRF gate"
        );
    }

    /// With a CUDA runtime present, the dispatch hook must be registered and
    /// `try_fast_ab` must agree with a naive CPU product.
    #[test]
    fn global_runtime_installs_fast_ab_hook_and_matches_cpu() {
        if GpuRuntime::global().is_none() {
            eprintln!("[fast_ab hook] no CUDA runtime - skipping engagement check");
            return;
        }
        assert!(
            gam_linalg::gpu_hook::gpu_dispatch().is_some(),
            "GpuRuntime::global() returned a device but did not register the \
             dense-GEMM dispatch hook — fast_ab would silently stay on the CPU"
        );

        let (m, k, n) = (512usize, 512usize, 512usize);
        assert!(
            route_through_gpu(DispatchOp::Gemm { m, n, k }).is_some(),
            "a 268 MFLOP GEMM must clear the policy floor and route to GPU"
        );

        let a = patterned((m, k), (7, 3), 13, 0.01, 0.06);
        let b = patterned((k, n), (5, 11), 17, 0.01, 0.08);

        let gpu = try_fast_ab(a.view(), b.view())
            .expect("profitable GEMM must produce a device result once admitted");
        let cpu = naive_product(&a, &b, false, false, (m, k, n));

        let max_abs = max_abs_gap(&gpu, &cpu, (m, n));
        assert!(
            max_abs < ORACLE_TOL,
            "device GEMM disagreed with the CPU oracle: max|Δ| = {max_abs:e}"
        );
    }

    /// `gemm_cuda` must match the CPU oracle for every transpose combination
    /// over several small, non-square shapes.
    #[cfg(target_os = "linux")]
    #[test]
    fn transpose_free_gemm_matches_cpu_all_trans_and_shapes() {
        use crate::blas::gemm_cuda;

        let runtime = match GpuRuntime::global() {
            Some(runtime) => runtime,
            None => {
                eprintln!("[gemm transpose-free] no CUDA runtime - skipping");
                return;
            }
        };

        let cases = [(6usize, 4usize, 5usize), (17, 23, 9), (200, 31, 7)];
        for (m, k, n) in cases {
            // One fixture per storage layout of each logical operand.
            let mk = patterned((m, k), (31, 17), 19, 0.013, 0.11);
            let km = patterned((k, m), (13, 29), 23, 0.011, 0.07);
            let kn = patterned((k, n), (7, 5), 17, 0.017, 0.09);
            let nk = patterned((n, k), (19, 11), 13, 0.015, 0.05);

            for trans_a in [false, true] {
                for trans_b in [false, true] {
                    let a = if trans_a { &km } else { &mk };
                    let b = if trans_b { &nk } else { &kn };

                    let gpu = gemm_cuda(runtime, a.view(), b.view(), trans_a, trans_b).expect(
                        "transpose-free device GEMM must produce a result when a device is present",
                    );
                    assert_eq!(gpu.dim(), (m, n), "output shape wrong for trans_a={trans_a} trans_b={trans_b} ({m}×{k}×{n})");

                    let cpu = naive_product(a, b, trans_a, trans_b, (m, k, n));
                    let max_abs = max_abs_gap(&gpu, &cpu, (m, n));
                    assert!(
                        max_abs < ORACLE_TOL,
                        "transpose-free GEMM mismatch (trans_a={trans_a} trans_b={trans_b}, \
                         {m}×{k}×{n}): max|Δ| = {max_abs:e}"
                    );
                }
            }
        }
    }
}
1086
/// Linux-only glue between the dispatch functions of the parent module and
/// the CUDA implementations in `super::super::blas`, plus the cuSOLVER-based
/// Cholesky routines defined directly in this module.
#[cfg(target_os = "linux")]
mod cuda_backend {
    use ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, ArrayView3};

    use super::super::device_runtime::GpuRuntime;
    use crate::driver::{from_col_major, to_col_major, to_i32};
    use cudarc::cusolver::{DnHandle, sys as cusolver_sys};
    use cudarc::driver::{DevicePtrMut, sys as driver_sys};

    /// Forwards to `blas::gemm_cuda`; `trans_a` / `trans_b` are passed through
    /// unchanged.
    #[inline]
    pub(super) fn gemm(
        runtime: &GpuRuntime,
        a: ArrayView2<'_, f64>,
        b: ArrayView2<'_, f64>,
        trans_a: bool,
        trans_b: bool,
    ) -> Option<Array2<f64>> {
        super::super::blas::gemm_cuda(runtime, a, b, trans_a, trans_b)
    }

    /// Forwards to `blas::gemm_on_ordinal_cuda`, which takes an explicit
    /// device ordinal instead of a runtime.
    #[inline]
    pub(super) fn gemm_on_ordinal(
        ordinal: usize,
        a: ArrayView2<'_, f64>,
        b: ArrayView2<'_, f64>,
        trans_a: bool,
        trans_b: bool,
    ) -> Option<Array2<f64>> {
        super::super::blas::gemm_on_ordinal_cuda(ordinal, a, b, trans_a, trans_b)
    }

    /// Forwards to `blas::gemv_cuda`.
    #[inline]
    pub(super) fn gemv(
        runtime: &GpuRuntime,
        a: ArrayView2<'_, f64>,
        v: ArrayView1<'_, f64>,
        trans_a: bool,
    ) -> Option<Array1<f64>> {
        super::super::blas::gemv_cuda(runtime, a, v, trans_a)
    }

    /// Forwards to `blas::gemm_broadcast_b_batched_cuda` on device `ordinal`.
    #[inline]
    pub(super) fn gemm_broadcast_b_batched(
        ordinal: usize,
        a: ArrayView3<'_, f64>,
        b: ArrayView2<'_, f64>,
    ) -> Option<Array3<f64>> {
        super::super::blas::gemm_broadcast_b_batched_cuda(ordinal, a, b)
    }

    /// Forwards to `blas::gemm_abt_strided_batched_cuda` on device `ordinal`.
    #[inline]
    pub(super) fn gemm_abt_strided_batched(
        ordinal: usize,
        a: ArrayView3<'_, f64>,
        b: ArrayView3<'_, f64>,
    ) -> Option<Array3<f64>> {
        super::super::blas::gemm_abt_strided_batched_cuda(ordinal, a, b)
    }

    /// Forwards to `blas::xt_diag_x_cuda`.
    #[inline]
    pub(super) fn xt_diag_x(
        runtime: &GpuRuntime,
        x: ArrayView2<'_, f64>,
        w: ArrayView1<'_, f64>,
    ) -> Option<Array2<f64>> {
        super::super::blas::xt_diag_x_cuda(runtime, x, w)
    }

    /// Forwards to `blas::xt_diag_y_cuda`.
    #[inline]
    pub(super) fn xt_diag_y(
        runtime: &GpuRuntime,
        x: ArrayView2<'_, f64>,
        w: ArrayView1<'_, f64>,
        y: ArrayView2<'_, f64>,
    ) -> Option<Array2<f64>> {
        super::super::blas::xt_diag_y_cuda(runtime, x, w, y)
    }

    /// Forwards to `blas::joint_hessian_2x2_cuda`.
    #[inline]
    pub(super) fn joint_hessian_2x2(
        runtime: &GpuRuntime,
        x_a: ArrayView2<'_, f64>,
        x_b: ArrayView2<'_, f64>,
        w_aa: ArrayView1<'_, f64>,
        w_ab: ArrayView1<'_, f64>,
        w_bb: ArrayView1<'_, f64>,
    ) -> Option<Array2<f64>> {
        super::super::blas::joint_hessian_2x2_cuda(runtime, x_a, x_b, w_aa, w_ab, w_bb)
    }

    /// Forwards to `blas::trsm_cuda`; `upper` selects which triangle the
    /// operand represents.
    #[inline]
    pub(super) fn trsm(
        runtime: &GpuRuntime,
        triangular: ArrayView2<'_, f64>,
        rhs: ArrayView2<'_, f64>,
        upper: bool,
    ) -> Option<Array2<f64>> {
        super::super::blas::trsm_cuda(runtime, triangular, rhs, upper)
    }

    /// Lower Cholesky factor of the square matrix `a`, computed on the
    /// runtime's device.
    ///
    /// Returns `None` for an empty or non-square input and whenever a CUDA
    /// step (context lookup, stream or handle creation, a transfer, or the
    /// factorization) fails. The strict upper triangle of the returned matrix
    /// is explicitly zeroed.
    #[inline]
    pub(super) fn cholesky_lower(
        runtime: &GpuRuntime,
        a: ArrayView2<'_, f64>,
    ) -> Option<Array2<f64>> {
        let (p, p2) = a.dim();
        if p == 0 || p != p2 {
            return None;
        }
        // A fresh stream on the runtime's device; the solver handle is bound
        // to that same stream.
        let stream = super::super::device_runtime::cuda_context_for(runtime.device.ordinal)?
            .new_stream()
            .ok()?;
        let solver = DnHandle::new(stream.clone()).ok()?;
        let a_col = to_col_major(&a);
        let mut a_dev = stream.clone_htod(&*a_col).ok()?;
        potrf_lower_in_place(&solver, &stream, p, &mut a_dev)?;
        let factor_col = stream.clone_dtoh(&a_dev).ok()?;
        let mut lower = from_col_major(&factor_col, p, p)?;
        // Clear everything above the diagonal so only the factor remains.
        for row in 0..p {
            for col in (row + 1)..p {
                lower[[row, col]] = 0.0;
            }
        }
        Some(lower)
    }

    /// Factor every matrix in `matrices` with one `cusolverDnDpotrfBatched`
    /// call on device `ordinal`, overwriting each entry with its lower factor.
    ///
    /// All matrices must be square with the same non-zero order. The input
    /// slice is written only after the call reported success for every
    /// matrix, so on `None` the slice is left as it was passed in.
    #[inline]
    pub(super) fn cholesky_batched_lower(
        ordinal: usize,
        matrices: &mut [Array2<f64>],
    ) -> Option<()> {
        let first = matrices.first()?;
        let p = first.nrows();
        if p == 0 || first.ncols() != p || matrices.iter().any(|matrix| matrix.dim() != (p, p)) {
            return None;
        }

        let stream = super::super::device_runtime::cuda_context_for(ordinal)?
            .new_stream()
            .ok()?;
        let solver = DnHandle::new(stream.clone()).ok()?;
        let matrix_len = p.checked_mul(p)?;
        // Pack all matrices back to back into one host buffer that is then
        // uploaded in a single transfer.
        let mut batch_col = Vec::with_capacity(matrices.len().checked_mul(matrix_len)?);
        for matrix in matrices.iter() {
            batch_col.extend(to_col_major(&matrix.view()).iter().copied());
        }
        let mut matrices_dev = stream.clone_htod(&batch_col).ok()?;
        // Build the per-matrix device pointer table: base address plus
        // `idx * p * p * size_of::<f64>()` bytes.
        let matrix_ptrs = {
            let (base_ptr, _matrix_record) = matrices_dev.device_ptr_mut(&stream);
            let bytes_per_matrix = driver_sys::CUdeviceptr::try_from(
                matrix_len.checked_mul(std::mem::size_of::<f64>())?,
            )
            .ok()?;
            let mut matrix_ptrs = Vec::with_capacity(matrices.len());
            for idx in 0..matrices.len() {
                let offset = driver_sys::CUdeviceptr::try_from(idx).ok()? * bytes_per_matrix;
                matrix_ptrs.push(base_ptr + offset);
            }
            matrix_ptrs
        };
        let mut matrix_ptrs_dev = stream.clone_htod(&matrix_ptrs).ok()?;
        // One status word per matrix, zero-initialized on the device.
        let mut info_dev = stream.alloc_zeros::<i32>(matrices.len()).ok()?;
        let p_i = to_i32(p)?;
        let batch_i = to_i32(matrices.len())?;
        {
            let (ptrs_ptr, _ptrs_record) = matrix_ptrs_dev.device_ptr_mut(&stream);
            let (info_ptr, _info_record) = info_dev.device_ptr_mut(&stream);
            // SAFETY (as far as this function shows): `ptrs_ptr` addresses
            // `matrices.len()` device pointers, each offset by a whole
            // `p * p` matrix into `matrices_dev`, which holds exactly
            // `matrices.len() * p * p` values. `info_ptr` addresses
            // `matrices.len()` `i32` slots. `p_i` and `batch_i` come from
            // checked conversions of those same extents, and `solver` was
            // created on `stream`.
            let status = unsafe {
                cusolver_sys::cusolverDnDpotrfBatched(
                    solver.cu(),
                    cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
                    p_i,
                    ptrs_ptr as *mut *mut f64,
                    p_i,
                    info_ptr as *mut i32,
                    batch_i,
                )
            };
            check_cusolver(status)?;
        }
        // Any non-zero status word means at least one matrix failed; reject
        // the whole batch before touching the caller's data.
        let info_host = stream.clone_dtoh(&info_dev).ok()?;
        if info_host.iter().any(|info| *info != 0) {
            return None;
        }
        let factored_col = stream.clone_dtoh(&matrices_dev).ok()?;
        for (idx, matrix) in matrices.iter_mut().enumerate() {
            let start = idx.checked_mul(matrix_len)?;
            let end = start.checked_add(matrix_len)?;
            let mut lower = from_col_major(&factored_col[start..end], p, p)?;
            // Clear everything above the diagonal so only the factor remains.
            for row in 0..p {
                for col in (row + 1)..p {
                    lower[[row, col]] = 0.0;
                }
            }
            *matrix = lower;
        }
        Some(())
    }

    /// Runs `crate::solver::potrf_in_place_generic::<f64>` on the device
    /// buffer `a` of order `p`, mapping its error to `None`.
    fn potrf_lower_in_place(
        solver: &DnHandle,
        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
        p: usize,
        a: &mut cudarc::driver::CudaSlice<f64>,
    ) -> Option<()> {
        crate::solver::potrf_in_place_generic::<f64>(solver, stream, p, a).ok()
    }

    /// `Some(())` for `CUSOLVER_STATUS_SUCCESS`, `None` for any other status.
    #[inline]
    fn check_cusolver(status: cusolver_sys::cusolverStatus_t) -> Option<()> {
        if status == cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
            Some(())
        } else {
            None
        }
    }
}