1#[macro_use]
18pub mod gpu_error;
19pub mod backend_probe;
20pub mod blas;
21#[cfg(target_os = "linux")]
22pub mod calibration;
23pub mod cpu_traits;
24pub mod device;
25pub mod device_cache;
26pub mod device_runtime;
27pub mod dictionary_score;
28pub mod driver;
29pub mod encode_throughput;
30pub mod linalg_dispatch;
31pub mod memory;
32pub mod numerics_device;
33pub mod numerics_host;
34pub mod policy;
35pub mod pool;
36pub mod profile;
37pub mod solver;
38
39pub mod kernels;
41
42pub use cpu_traits::MatrixLocation;
43pub use device::GpuDeviceInfo;
44pub use device_runtime::GpuRuntime;
45pub use dictionary_score::{
46 DEFAULT_DICTIONARY_SCORE_MIN_ELEMS, DEFAULT_DICTIONARY_SCORE_TILE_ELEMS,
47 DictionaryScoreRoutePlan,
48};
49pub use gpu_error::GpuError;
50pub use memory::{DeviceBuffer, DeviceCsrMatrix, DeviceMatrix, DeviceVector};
51pub use policy::{GpuDispatchPolicy, GpuMixedPrecisionPolicy};
52pub use pool::{balanced_partition, scatter_batched};
53pub use profile::{GpuExecutionTelemetry, KernelStat, KernelStatsSnapshot};
54
55use serde::{Deserialize, Serialize};
68use std::fmt;
69use std::sync::OnceLock;
70
/// Whether the CUDA backend can be used in this process.
///
/// Produced by `cuda_backend_status` from the presence of the global GPU runtime.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CudaBackendStatus {
    /// No global GPU runtime handle exists.
    CudaUnavailable,
    /// A global GPU runtime handle exists.
    CudaReady,
}
76
77#[inline]
78pub(crate) fn cuda_backend_status() -> CudaBackendStatus {
79 if device_runtime::GpuRuntime::global().is_some() {
80 CudaBackendStatus::CudaReady
81 } else {
82 CudaBackendStatus::CudaUnavailable
83 }
84}
85
86#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
88#[serde(rename_all = "kebab-case")]
89pub enum GpuPolicy {
90 #[default]
92 Auto,
93 Off,
95 Force,
97}
98
99impl GpuPolicy {
100 pub fn parse(raw: &str) -> Option<Self> {
101 match raw.trim().to_ascii_lowercase().as_str() {
102 "auto" => Some(Self::Auto),
103 "off" => Some(Self::Off),
104 "force" => Some(Self::Force),
105 _ => None,
106 }
107 }
108
109 #[inline]
110 pub const fn as_str(self) -> &'static str {
111 match self {
112 Self::Auto => "auto",
113 Self::Off => "off",
114 Self::Force => "force",
115 }
116 }
117
118 #[inline]
120 pub const fn is_force(self) -> bool {
121 matches!(self, Self::Force)
122 }
123}
124
125impl fmt::Display for GpuPolicy {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 f.write_str(self.as_str())
128 }
129}
130
131#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
144#[serde(rename_all = "kebab-case")]
145pub enum GpuMode {
146 #[default]
148 Auto,
149 Required,
151 Off,
153}
154
155impl GpuMode {
156 #[inline]
158 pub const fn as_str(self) -> &'static str {
159 match self {
160 Self::Auto => "auto",
161 Self::Required => "required",
162 Self::Off => "off",
163 }
164 }
165}
166
167impl fmt::Display for GpuMode {
168 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
169 f.write_str(self.as_str())
170 }
171}
172
173static GPU_MODE: OnceLock<GpuMode> = OnceLock::new();
174
175pub fn set_gpu_mode(mode: GpuMode) {
178 GPU_MODE.set(mode).ok();
179}
180
181#[inline]
185pub fn gpu_mode() -> GpuMode {
186 match GPU_MODE.get() {
187 Some(m) => *m,
188 None => GpuMode::Auto,
189 }
190}
191
/// Kernel families whose execution can be routed to the GPU.
///
/// Each variant has a stable kebab-case name ([`GpuKernel::as_str`]). That
/// name appears in the `[GPU backend]` debug logs and in `gpu=force` error
/// messages.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GpuKernel {
    DenseMatvec,
    DenseTransposeMatvec,
    DenseXtWX,
    CandidateScreen,
    DenseSolve,
    MatrixFreePcg,
    SparseAssembly,
    SpatialKernelOperator,
    MarginalSlopeRows,
    RemlTrace,
    FinalInference,
}

impl GpuKernel {
    /// Stable kebab-case name of this kernel.
    #[inline]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::DenseMatvec => "dense-matvec",
            Self::DenseTransposeMatvec => "dense-transpose-matvec",
            Self::DenseXtWX => "dense-xtwx",
            Self::CandidateScreen => "candidate-screen",
            Self::DenseSolve => "dense-solve",
            Self::MatrixFreePcg => "matrix-free-pcg",
            Self::SparseAssembly => "sparse-assembly",
            Self::SpatialKernelOperator => "spatial-kernel-operator",
            Self::MarginalSlopeRows => "marginal-slope-rows",
            Self::RemlTrace => "reml-trace",
            Self::FinalInference => "final-inference",
        }
    }
}

// Formats the kernel by its `as_str` name, mirroring the `Display` impls of
// `GpuPolicy` and `GpuMode` so kernels can be used directly in `{}` formats.
impl fmt::Display for GpuKernel {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())
    }
}
224
225#[derive(Clone, Debug)]
227pub struct GpuDecision {
228 pub policy: GpuPolicy,
229 pub kernel: GpuKernel,
230 pub use_gpu: bool,
231 pub reason: &'static str,
232}
233
234static POLICY: OnceLock<GpuPolicy> = OnceLock::new();
235
236#[inline]
237pub fn global_policy() -> GpuPolicy {
238 match POLICY.get() {
245 Some(p) => *p,
246 None => GpuPolicy::Auto,
247 }
248}
249
250pub fn configure_global_policy(policy: GpuPolicy) {
257 POLICY.set(policy).ok();
259}
260
261#[inline]
268pub fn cuda_selected() -> bool {
269 match global_policy() {
270 GpuPolicy::Auto => device_runtime::GpuRuntime::is_available(),
271 GpuPolicy::Off => false,
272 GpuPolicy::Force => true,
273 }
274}
275
/// Static eligibility of a workload for GPU execution, independent of policy
/// and of runtime availability.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GpuEligibility {
    /// No GPU implementation is available for this workload.
    BackendNotCompiled,
    /// A GPU implementation exists, but the workload is below its size threshold.
    WorkloadBelowThreshold,
    /// A GPU implementation exists and the workload is large enough.
    Eligible,
}

impl GpuEligibility {
    /// Build an eligibility from two flags.
    ///
    /// A missing backend takes precedence over the size check: when
    /// `supported` is false the result is `BackendNotCompiled` whatever
    /// `large_enough` says.
    #[inline]
    pub const fn from_flags(supported: bool, large_enough: bool) -> Self {
        match (supported, large_enough) {
            (false, _) => Self::BackendNotCompiled,
            (true, false) => Self::WorkloadBelowThreshold,
            (true, true) => Self::Eligible,
        }
    }
}
310
311pub fn decide(kernel: GpuKernel, eligibility: GpuEligibility) -> GpuDecision {
315 let policy = global_policy();
316 let runtime_available = device_runtime::GpuRuntime::is_available();
322 let (use_gpu, reason) = match (policy, eligibility) {
323 (GpuPolicy::Off, _) => (false, "cpu-gpu-policy-off"),
324 (GpuPolicy::Auto, GpuEligibility::BackendNotCompiled) => {
325 (false, "cpu-gpu-backend-not-compiled")
326 }
327 (GpuPolicy::Auto, _) if !runtime_available => (false, "cpu-gpu-runtime-unavailable"),
328 (GpuPolicy::Auto, GpuEligibility::WorkloadBelowThreshold) => {
329 (false, "cpu-workload-below-gpu-threshold")
330 }
331 (GpuPolicy::Auto, GpuEligibility::Eligible) => (true, "gpu-auto-supported"),
332 (GpuPolicy::Force, GpuEligibility::BackendNotCompiled) => {
333 (false, "cpu-gpu-force-unsupported")
334 }
335 (GpuPolicy::Force, _) if !runtime_available => (false, "cpu-gpu-force-runtime-unavailable"),
336 (GpuPolicy::Force, GpuEligibility::WorkloadBelowThreshold)
339 | (GpuPolicy::Force, GpuEligibility::Eligible) => (true, "gpu-force-supported"),
340 };
341 GpuDecision {
342 policy,
343 kernel,
344 use_gpu,
345 reason,
346 }
347}
348
349impl GpuDecision {
350 pub fn require_supported(&self) -> Result<(), String> {
351 if self.policy == GpuPolicy::Force && !self.use_gpu {
352 return Err(format!(
353 "gpu=force requested kernel '{}' but no supported device backend is available ({})",
354 self.kernel.as_str(),
355 self.reason
356 ));
357 }
358 Ok(())
359 }
360
361 pub fn log(self) {
362 log::debug!(
363 "[GPU backend] kernel={} policy={} selected={} reason={}",
364 self.kernel.as_str(),
365 self.policy.as_str(),
366 self.use_gpu,
367 self.reason
368 );
369 }
370}
371
372pub fn log_backend_inventory_once() {
376 static LOGGED: OnceLock<()> = OnceLock::new();
377 LOGGED.get_or_init(|| {
378 let compiled_backends = if cfg!(target_os = "linux") {
379 "cuda-dynamic"
380 } else {
381 "none"
382 };
383 log::debug!(
384 "[GPU backend] policy={} compiled_backends={} kernels=dense-matvec,dense-transpose-matvec,dense-xtwx,candidate-screen,dense-solve,matrix-free-pcg,sparse-assembly,spatial-kernel-operator,marginal-slope-rows,reml-trace,final-inference",
385 global_policy().as_str(),
386 compiled_backends
387 );
388 });
389}
390
// Thin public entry points that forward unchanged to `linalg_dispatch`.
// NOTE(review): each returns `Option`; `None` presumably means the accelerated
// path declined and the caller should use its own fallback. Confirm in
// `linalg_dispatch`.

/// Dense matrix product, forwarded to `linalg_dispatch::try_fast_ab`.
///
/// Presumably computes `a · b` (TODO confirm in `linalg_dispatch`).
#[inline]
pub fn try_fast_ab(
    a: ndarray::ArrayView2<'_, f64>,
    b: ndarray::ArrayView2<'_, f64>,
) -> Option<ndarray::Array2<f64>> {
    linalg_dispatch::try_fast_ab(a, b)
}
/// Transposed product on a specific device, forwarded to
/// `linalg_dispatch::try_fast_atb_on_ordinal`.
///
/// Presumably computes `aᵀ · b` on the device with index `ordinal`
/// (TODO confirm in `linalg_dispatch`).
#[inline]
pub fn try_fast_atb_on_ordinal(
    ordinal: usize,
    a: ndarray::ArrayView2<'_, f64>,
    b: ndarray::ArrayView2<'_, f64>,
) -> Option<ndarray::Array2<f64>> {
    linalg_dispatch::try_fast_atb_on_ordinal(ordinal, a, b)
}
/// Matrix-vector product, forwarded to `linalg_dispatch::try_fast_av`.
///
/// Presumably computes `a · v` (TODO confirm in `linalg_dispatch`).
#[inline]
pub fn try_fast_av(
    a: ndarray::ArrayView2<'_, f64>,
    v: ndarray::ArrayView1<'_, f64>,
) -> Option<ndarray::Array1<f64>> {
    linalg_dispatch::try_fast_av(a, v)
}
/// Transposed matrix-vector product, forwarded to `linalg_dispatch::try_fast_atv`.
///
/// Presumably computes `aᵀ · v` (TODO confirm in `linalg_dispatch`).
#[inline]
pub fn try_fast_atv(
    a: ndarray::ArrayView2<'_, f64>,
    v: ndarray::ArrayView1<'_, f64>,
) -> Option<ndarray::Array1<f64>> {
    linalg_dispatch::try_fast_atv(a, v)
}
/// Batched product with one shared right-hand matrix, forwarded to
/// `linalg_dispatch::try_fast_ab_broadcast_b_batched`.
///
/// Presumably multiplies every 2-D slice of the 3-D `a` by the same `b`.
/// The batch axis is assumed to be axis 0 — TODO confirm.
#[inline]
pub fn try_fast_ab_broadcast_b_batched(
    a: ndarray::ArrayView3<'_, f64>,
    b: ndarray::ArrayView2<'_, f64>,
) -> Option<ndarray::Array3<f64>> {
    linalg_dispatch::try_fast_ab_broadcast_b_batched(a, b)
}
/// Batched `a · bᵀ` over matching 3-D stacks, forwarded to
/// `linalg_dispatch::try_fast_abt_strided_batched`.
///
/// The pairing of slices and the batch axis are inferred from the name
/// (TODO confirm).
#[inline]
pub fn try_fast_abt_strided_batched(
    a: ndarray::ArrayView3<'_, f64>,
    b: ndarray::ArrayView3<'_, f64>,
) -> Option<ndarray::Array3<f64>> {
    linalg_dispatch::try_fast_abt_strided_batched(a, b)
}
/// In-place lower Cholesky factorization, forwarded to
/// `linalg_dispatch::try_cholesky_lower_inplace`.
///
/// NOTE(review): presumably overwrites `a` with its lower factor on `Some(())`.
/// Confirm whether `a` may be modified when `None` is returned.
#[inline]
pub fn try_cholesky_lower_inplace(a: &mut ndarray::Array2<f64>) -> Option<()> {
    linalg_dispatch::try_cholesky_lower_inplace(a)
}
/// Batched in-place lower Cholesky factorization, forwarded to
/// `linalg_dispatch::try_cholesky_batched_lower_inplace`.
///
/// NOTE(review): same in-place caveat as the single-matrix version. Confirm
/// the state of `matrices` when `None` is returned.
#[inline]
pub fn try_cholesky_batched_lower_inplace(matrices: &mut [ndarray::Array2<f64>]) -> Option<()> {
    linalg_dispatch::try_cholesky_batched_lower_inplace(matrices)
}
/// Lower-triangular solve with a matrix right-hand side, forwarded to
/// `linalg_dispatch::try_solve_lower_triangular_matrix`.
///
/// Presumably returns `X` with `lower · X = rhs` (TODO confirm).
#[inline]
pub fn try_solve_lower_triangular_matrix(
    lower: ndarray::ArrayView2<'_, f64>,
    rhs: ndarray::ArrayView2<'_, f64>,
) -> Option<ndarray::Array2<f64>> {
    linalg_dispatch::try_solve_lower_triangular_matrix(lower, rhs)
}
/// Upper-triangular solve with a matrix right-hand side, forwarded to
/// `linalg_dispatch::try_solve_upper_triangular_matrix`.
///
/// Presumably returns `X` with `upper · X = rhs` (TODO confirm).
#[inline]
pub fn try_solve_upper_triangular_matrix(
    upper: ndarray::ArrayView2<'_, f64>,
    rhs: ndarray::ArrayView2<'_, f64>,
) -> Option<ndarray::Array2<f64>> {
    linalg_dispatch::try_solve_upper_triangular_matrix(upper, rhs)
}
455}
456#[cfg(test)]
457mod policy_tests {
458 use super::*;
459
460 #[test]
461 fn parses_canonical_user_gpu_policy_values() {
462 assert_eq!(GpuPolicy::parse("auto"), Some(GpuPolicy::Auto));
463 assert_eq!(GpuPolicy::parse("off"), Some(GpuPolicy::Off));
464 assert_eq!(GpuPolicy::parse("force"), Some(GpuPolicy::Force));
465 assert_eq!(GpuPolicy::parse("cpu"), None);
466 assert_eq!(GpuPolicy::parse(""), None);
467 assert_eq!(GpuPolicy::parse("wat"), None);
468 }
469
470 #[test]
471 fn execution_path_defaults_to_cpu() {
472 use gam_problem::ExecutionPath;
473 assert_eq!(ExecutionPath::default(), ExecutionPath::Cpu);
478 assert!(!ExecutionPath::Cpu.used_device());
479 assert!(ExecutionPath::GpuResidentFull.used_device());
480 }
481
482 #[test]
483 fn gpu_mode_required_fails_closed_when_device_absent() {
484 use crate::device_runtime::GpuRuntime;
485 assert!(matches!(
487 GpuRuntime::global_or_fail(GpuMode::Off),
488 Err(GpuError::DriverLibraryUnavailable { .. })
489 ));
490
491 if GpuRuntime::is_available() {
492 assert!(GpuRuntime::global_or_fail(GpuMode::Required).is_ok());
494 assert!(GpuRuntime::global_or_fail(GpuMode::Auto).is_ok());
495 } else {
496 let required = GpuRuntime::global_or_fail(GpuMode::Required);
501 assert!(
502 matches!(required, Err(GpuError::DriverLibraryUnavailable { .. })),
503 "GpuMode::Required must fail closed when the device is absent, got {required:?}"
504 );
505 assert!(GpuRuntime::global_or_fail(GpuMode::Auto).is_err());
506 }
507 }
508
509 #[test]
510 fn pirls_loop_admission_requires_runtime_size_and_known_family() {
511 use crate::policy::{PirlsLoopAdmission, PirlsLoopCurvatureKind, PirlsLoopFamilyKind};
512 let pol = GpuDispatchPolicy::default();
513 let base = PirlsLoopAdmission {
514 n: 80_000,
515 p: 44,
516 family: Some(PirlsLoopFamilyKind::BernoulliLogit),
517 curvature: PirlsLoopCurvatureKind::Fisher,
518 gpu_available: true,
519 };
520 assert!(pol.should_use_gpu_pirls_loop(base));
521 assert!(!pol.should_use_gpu_pirls_loop(PirlsLoopAdmission {
523 gpu_available: false,
524 ..base
525 }));
526 assert!(!pol.should_use_gpu_pirls_loop(PirlsLoopAdmission { n: 1_000, ..base }));
528 assert!(pol.should_use_gpu_pirls_loop(PirlsLoopAdmission {
530 n: 2_000,
531 p: 2_048,
532 ..base
533 }));
534 assert!(!pol.should_use_gpu_pirls_loop(PirlsLoopAdmission { p: 8, ..base }));
536 assert!(!pol.should_use_gpu_pirls_loop(PirlsLoopAdmission {
538 family: None,
539 ..base
540 }));
541 }
542
543 #[test]
544 fn force_policy_reports_unsupported_kernel() {
545 let decision = GpuDecision {
546 policy: GpuPolicy::Force,
547 kernel: GpuKernel::DenseXtWX,
548 use_gpu: false,
549 reason: "gpu-force-unsupported",
550 };
551 let err = decision.require_supported().unwrap_err();
552 assert!(err.contains("dense-xtwx"));
553 assert!(err.contains("gpu=force"));
554 }
555}