pub mod backend_probe;
pub mod blas;
#[cfg(target_os = "linux")]
pub mod calibration;
pub mod cpu_traits;
pub mod device;
pub mod device_cache;
pub mod driver;
#[macro_use]
pub mod gpu_error;
pub mod device_runtime;
pub mod linalg_dispatch;
pub mod memory;
pub mod numerics_device;
pub mod numerics_host;
pub mod policy;
pub mod pool;
pub mod profile;
pub mod solver;
pub mod kernels;
pub use cpu_traits::{ExecutionTarget, MatrixLocation};
pub use device::GpuDeviceInfo;
pub use device_runtime::GpuRuntime;
pub use gpu_error::GpuError;
pub use memory::{DeviceBuffer, DeviceCsrMatrix, DeviceMatrix, DeviceVector};
pub use policy::{GpuDispatchPolicy, GpuMixedPrecisionPolicy};
pub use pool::{balanced_partition, scatter_batched};
pub use profile::{GpuExecutionTelemetry, KernelStat, KernelStatsSnapshot};
use serde::{Deserialize, Serialize};
use std::fmt;
use std::sync::OnceLock;
/// Coarse readiness of the CUDA backend, as reported by `cuda_backend_status`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CudaBackendStatus {
/// `GpuRuntime::global()` yielded no runtime.
CudaUnavailable,
/// `GpuRuntime::global()` yielded a runtime.
CudaReady,
}
/// Reports whether a global GPU runtime is currently present.
///
/// Returns [`CudaBackendStatus::CudaReady`] when
/// `device_runtime::GpuRuntime::global()` is `Some`, and
/// [`CudaBackendStatus::CudaUnavailable`] otherwise.
#[inline]
pub(crate) fn cuda_backend_status() -> CudaBackendStatus {
    let runtime_present = device_runtime::GpuRuntime::global().is_some();
    if !runtime_present {
        return CudaBackendStatus::CudaUnavailable;
    }
    CudaBackendStatus::CudaReady
}
/// GPU usage policy consulted by `decide` and `cuda_selected`.
///
/// Serialized and deserialized with kebab-case variant names; the default is `Auto`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum GpuPolicy {
/// `decide` selects the GPU only when the backend is compiled, the runtime
/// is available, and the workload is eligible.
#[default]
Auto,
/// `decide` always selects the CPU.
Off,
/// `decide` selects the GPU whenever the backend is compiled and the runtime
/// is available, even below the workload threshold; `GpuDecision::require_supported`
/// reports an error when a decision under this policy did not select the GPU.
Force,
}
impl GpuPolicy {
    /// Parses a policy from its textual label.
    ///
    /// Surrounding whitespace is trimmed and ASCII case is ignored, so
    /// `" Auto "` and `"FORCE"` are accepted. Any input that is not one of
    /// `auto`, `off`, or `force` yields `None`.
    pub fn parse(raw: &str) -> Option<Self> {
        let token = raw.trim();
        [Self::Auto, Self::Off, Self::Force]
            .iter()
            .copied()
            .find(|candidate| token.eq_ignore_ascii_case(candidate.as_str()))
    }

    /// Returns the canonical lowercase label for this policy.
    #[inline]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Force => "force",
            Self::Off => "off",
            Self::Auto => "auto",
        }
    }

    /// Returns `true` only for [`GpuPolicy::Force`].
    #[inline]
    pub const fn is_force(self) -> bool {
        match self {
            Self::Force => true,
            Self::Auto | Self::Off => false,
        }
    }
}
impl fmt::Display for GpuPolicy {
    /// Writes the label from [`GpuPolicy::as_str`] verbatim; width and
    /// alignment flags on the formatter are not applied.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = self.as_str();
        formatter.write_str(label)
    }
}
/// Process-wide GPU mode, stored once through `set_gpu_mode` and read back with `gpu_mode`.
///
/// Serialized and deserialized with kebab-case variant names; the default is `Auto`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum GpuMode {
/// Value returned by `gpu_mode` when no mode has been set.
/// NOTE(review): the tests in this file expect `GpuRuntime::global_or_fail(GpuMode::Auto)`
/// to succeed when a runtime is available and to fail otherwise.
#[default]
Auto,
/// NOTE(review): the tests in this file expect `GpuRuntime::global_or_fail(GpuMode::Required)`
/// to return an error when no device is available.
Required,
/// NOTE(review): the tests in this file expect `GpuRuntime::global_or_fail(GpuMode::Off)`
/// to return `GpuError::DriverLibraryUnavailable`.
Off,
}
impl GpuMode {
#[inline]
pub const fn as_str(self) -> &'static str {
match self {
Self::Auto => "auto",
Self::Required => "required",
Self::Off => "off",
}
}
}
impl fmt::Display for GpuMode {
    /// Writes the label from [`GpuMode::as_str`] verbatim; width and
    /// alignment flags on the formatter are not applied.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = self.as_str();
        formatter.write_str(label)
    }
}
/// Write-once storage for the process-wide [`GpuMode`].
static GPU_MODE: OnceLock<GpuMode> = OnceLock::new();

/// Records the process-wide GPU mode.
///
/// Only the first call has any effect; later calls are silently ignored and
/// leave the previously stored mode in place.
pub fn set_gpu_mode(mode: GpuMode) {
    // First write wins: a rejected `set` is deliberately discarded.
    let _ = GPU_MODE.set(mode);
}

/// Returns the stored GPU mode, or [`GpuMode::Auto`] when none has been set.
#[inline]
pub fn gpu_mode() -> GpuMode {
    GPU_MODE.get().copied().unwrap_or(GpuMode::Auto)
}
/// Identifier for a kernel family that a GPU dispatch decision refers to.
///
/// Each variant has a stable kebab-case label available through
/// [`GpuKernel::as_str`]; that label is what appears in log lines and in the
/// error text produced by `GpuDecision::require_supported`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GpuKernel {
    DenseMatvec,
    DenseTransposeMatvec,
    DenseXtWX,
    CandidateScreen,
    DenseSolve,
    MatrixFreePcg,
    SparseAssembly,
    SpatialKernelOperator,
    MarginalSlopeRows,
    RemlTrace,
    FinalInference,
}

impl GpuKernel {
    /// Returns the kebab-case label for this kernel.
    pub const fn as_str(self) -> &'static str {
        match self {
            GpuKernel::DenseMatvec => "dense-matvec",
            GpuKernel::DenseTransposeMatvec => "dense-transpose-matvec",
            GpuKernel::DenseXtWX => "dense-xtwx",
            GpuKernel::CandidateScreen => "candidate-screen",
            GpuKernel::DenseSolve => "dense-solve",
            GpuKernel::MatrixFreePcg => "matrix-free-pcg",
            GpuKernel::SparseAssembly => "sparse-assembly",
            GpuKernel::SpatialKernelOperator => "spatial-kernel-operator",
            GpuKernel::MarginalSlopeRows => "marginal-slope-rows",
            GpuKernel::RemlTrace => "reml-trace",
            GpuKernel::FinalInference => "final-inference",
        }
    }
}
/// Outcome of a GPU dispatch decision for one kernel, as built by `decide`.
#[derive(Clone, Debug)]
pub struct GpuDecision {
/// Policy that was in effect when the decision was made.
pub policy: GpuPolicy,
/// Kernel the decision applies to.
pub kernel: GpuKernel,
/// `true` when the GPU path was selected.
pub use_gpu: bool,
/// Static tag explaining the outcome, e.g. `"cpu-gpu-policy-off"` or `"gpu-auto-supported"`.
pub reason: &'static str,
}
/// Write-once storage for the process-wide [`GpuPolicy`].
static POLICY: OnceLock<GpuPolicy> = OnceLock::new();

/// Returns the configured global policy, or [`GpuPolicy::Auto`] when none
/// has been configured.
#[inline]
pub fn global_policy() -> GpuPolicy {
    POLICY.get().copied().unwrap_or(GpuPolicy::Auto)
}

/// Configures the process-wide GPU policy.
///
/// Only the first call has any effect; later calls are silently ignored and
/// leave the previously stored policy in place.
pub fn configure_global_policy(policy: GpuPolicy) {
    // First write wins: a rejected `set` is deliberately discarded.
    let _ = POLICY.set(policy);
}
/// Reports whether the global policy selects the CUDA path.
///
/// `Force` always yields `true` and `Off` always yields `false`, without
/// consulting the runtime; `Auto` yields whatever
/// `device_runtime::GpuRuntime::is_available()` reports.
#[inline]
pub fn cuda_selected() -> bool {
    let policy = global_policy();
    match policy {
        GpuPolicy::Force => true,
        GpuPolicy::Off => false,
        GpuPolicy::Auto => device_runtime::GpuRuntime::is_available(),
    }
}
/// Caller-supplied classification of whether a workload may run on the GPU.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GpuEligibility {
    /// The backend is not supported for this workload.
    BackendNotCompiled,
    /// The backend is supported but the workload is not large enough.
    WorkloadBelowThreshold,
    /// The backend is supported and the workload is large enough.
    Eligible,
}

impl GpuEligibility {
    /// Builds an eligibility value from two flags.
    ///
    /// `supported == false` takes precedence and yields `BackendNotCompiled`
    /// regardless of `large_enough`.
    #[inline]
    pub const fn from_flags(supported: bool, large_enough: bool) -> Self {
        match (supported, large_enough) {
            (false, _) => Self::BackendNotCompiled,
            (true, false) => Self::WorkloadBelowThreshold,
            (true, true) => Self::Eligible,
        }
    }
}
/// Decides whether `kernel` should run on the GPU under the global policy.
///
/// The returned [`GpuDecision`] records the policy, the kernel, the verdict,
/// and a static reason tag:
///
/// * `Off` always selects the CPU.
/// * `BackendNotCompiled` always selects the CPU (with a distinct tag under
///   `Force`).
/// * Otherwise the GPU runtime must be available; under `Auto` the workload
///   must additionally be `Eligible`, while `Force` also accepts
///   `WorkloadBelowThreshold`.
///
/// The runtime availability probe is evaluated lazily and at most once: it is
/// skipped entirely when the policy is `Off` or the backend is not compiled,
/// because its answer cannot affect those outcomes.
pub fn decide(kernel: GpuKernel, eligibility: GpuEligibility) -> GpuDecision {
    let policy = global_policy();
    // Deferred so that policy-off / not-compiled decisions never touch the runtime.
    let runtime_available = || device_runtime::GpuRuntime::is_available();
    let (use_gpu, reason) = match (policy, eligibility) {
        (GpuPolicy::Off, _) => (false, "cpu-gpu-policy-off"),
        (GpuPolicy::Auto, GpuEligibility::BackendNotCompiled) => {
            (false, "cpu-gpu-backend-not-compiled")
        }
        (GpuPolicy::Force, GpuEligibility::BackendNotCompiled) => {
            (false, "cpu-gpu-force-unsupported")
        }
        // A guard only runs when its pattern matches, so exactly one of the
        // two guarded arms below can invoke the probe for a given policy.
        (GpuPolicy::Auto, _) if !runtime_available() => (false, "cpu-gpu-runtime-unavailable"),
        (GpuPolicy::Auto, GpuEligibility::WorkloadBelowThreshold) => {
            (false, "cpu-workload-below-gpu-threshold")
        }
        (GpuPolicy::Auto, GpuEligibility::Eligible) => (true, "gpu-auto-supported"),
        (GpuPolicy::Force, _) if !runtime_available() => {
            (false, "cpu-gpu-force-runtime-unavailable")
        }
        (GpuPolicy::Force, GpuEligibility::WorkloadBelowThreshold)
        | (GpuPolicy::Force, GpuEligibility::Eligible) => (true, "gpu-force-supported"),
    };
    GpuDecision {
        policy,
        kernel,
        use_gpu,
        reason,
    }
}
impl GpuDecision {
    /// Turns a CPU fallback under `Force` into an error.
    ///
    /// # Errors
    ///
    /// Returns a message naming the kernel and the decision's reason tag when
    /// the policy is [`GpuPolicy::Force`] and the GPU was not selected. Every
    /// other combination returns `Ok(())`.
    pub fn require_supported(&self) -> Result<(), String> {
        let force_fell_back = matches!(self.policy, GpuPolicy::Force) && !self.use_gpu;
        if !force_fell_back {
            return Ok(());
        }
        Err(format!(
            "gpu=force requested kernel '{}' but no supported device backend is available ({})",
            self.kernel.as_str(),
            self.reason
        ))
    }

    /// Emits the decision as a single debug-level log line, consuming `self`.
    pub fn log(self) {
        let Self {
            policy,
            kernel,
            use_gpu,
            reason,
        } = self;
        log::debug!(
            "[GPU backend] kernel={} policy={} selected={} reason={}",
            kernel.as_str(),
            policy.as_str(),
            use_gpu,
            reason
        );
    }
}
pub fn log_backend_inventory_once() {
static LOGGED: OnceLock<()> = OnceLock::new();
LOGGED.get_or_init(|| {
let compiled_backends = if cfg!(target_os = "linux") {
"cuda-dynamic"
} else {
"none"
};
log::debug!(
"[GPU backend] policy={} compiled_backends={} kernels=dense-matvec,dense-transpose-matvec,dense-xtwx,candidate-screen,dense-solve,matrix-free-pcg,sparse-assembly,spatial-kernel-operator,marginal-slope-rows,reml-trace,final-inference",
global_policy().as_str(),
compiled_backends
);
});
}
/// Forwards `a` and `b` unchanged to `linalg_dispatch::try_fast_ab` and returns its result.
///
/// NOTE(review): the name suggests the product `A * B`, with `None` meaning the
/// fast path was not taken — confirm against `linalg_dispatch`.
#[inline]
pub fn try_fast_ab(
a: ndarray::ArrayView2<'_, f64>,
b: ndarray::ArrayView2<'_, f64>,
) -> Option<ndarray::Array2<f64>> {
linalg_dispatch::try_fast_ab(a, b)
}
/// Forwards `ordinal`, `a`, and `b` unchanged to
/// `linalg_dispatch::try_fast_atb_on_ordinal` and returns its result.
///
/// NOTE(review): the name suggests `Aᵀ * B` on the device with the given
/// ordinal, with `None` meaning the fast path was not taken — confirm against
/// `linalg_dispatch`.
#[inline]
pub fn try_fast_atb_on_ordinal(
ordinal: usize,
a: ndarray::ArrayView2<'_, f64>,
b: ndarray::ArrayView2<'_, f64>,
) -> Option<ndarray::Array2<f64>> {
linalg_dispatch::try_fast_atb_on_ordinal(ordinal, a, b)
}
/// Forwards `a` and `v` unchanged to `linalg_dispatch::try_fast_av` and returns its result.
///
/// NOTE(review): the name suggests the matrix-vector product `A * v`, with
/// `None` meaning the fast path was not taken — confirm against `linalg_dispatch`.
#[inline]
pub fn try_fast_av(
a: ndarray::ArrayView2<'_, f64>,
v: ndarray::ArrayView1<'_, f64>,
) -> Option<ndarray::Array1<f64>> {
linalg_dispatch::try_fast_av(a, v)
}
/// Forwards `a` and `v` unchanged to `linalg_dispatch::try_fast_atv` and returns its result.
///
/// NOTE(review): the name suggests the transposed product `Aᵀ * v`, with
/// `None` meaning the fast path was not taken — confirm against `linalg_dispatch`.
#[inline]
pub fn try_fast_atv(
a: ndarray::ArrayView2<'_, f64>,
v: ndarray::ArrayView1<'_, f64>,
) -> Option<ndarray::Array1<f64>> {
linalg_dispatch::try_fast_atv(a, v)
}
/// Forwards the 3-D `a` and 2-D `b` unchanged to
/// `linalg_dispatch::try_fast_ab_broadcast_b_batched` and returns its result.
///
/// NOTE(review): the name suggests a batched `A[i] * B` with a single `B`
/// shared across the batch; which axis of `a` is the batch axis is not visible
/// here — confirm against `linalg_dispatch`.
#[inline]
pub fn try_fast_ab_broadcast_b_batched(
a: ndarray::ArrayView3<'_, f64>,
b: ndarray::ArrayView2<'_, f64>,
) -> Option<ndarray::Array3<f64>> {
linalg_dispatch::try_fast_ab_broadcast_b_batched(a, b)
}
/// Forwards the 3-D `a` and `b` unchanged to
/// `linalg_dispatch::try_fast_abt_strided_batched` and returns its result.
///
/// NOTE(review): the name suggests a batched `A[i] * B[i]ᵀ`; the batch axis and
/// any layout requirements are not visible here — confirm against `linalg_dispatch`.
#[inline]
pub fn try_fast_abt_strided_batched(
a: ndarray::ArrayView3<'_, f64>,
b: ndarray::ArrayView3<'_, f64>,
) -> Option<ndarray::Array3<f64>> {
linalg_dispatch::try_fast_abt_strided_batched(a, b)
}
/// Forwards the mutable matrix `a` to `linalg_dispatch::try_cholesky_lower_inplace`
/// and returns its result.
///
/// NOTE(review): the name suggests an in-place lower Cholesky factorization;
/// what `a` contains when `None` is returned is not visible here — confirm
/// against `linalg_dispatch` before relying on it.
#[inline]
pub fn try_cholesky_lower_inplace(a: &mut ndarray::Array2<f64>) -> Option<()> {
linalg_dispatch::try_cholesky_lower_inplace(a)
}
/// Forwards the mutable slice `matrices` to
/// `linalg_dispatch::try_cholesky_batched_lower_inplace` and returns its result.
///
/// NOTE(review): the name suggests an in-place lower Cholesky factorization of
/// every matrix in the slice; the state of `matrices` when `None` is returned is
/// not visible here — confirm against `linalg_dispatch`.
#[inline]
pub fn try_cholesky_batched_lower_inplace(matrices: &mut [ndarray::Array2<f64>]) -> Option<()> {
linalg_dispatch::try_cholesky_batched_lower_inplace(matrices)
}
/// Forwards `lower` and `rhs` unchanged to
/// `linalg_dispatch::try_solve_lower_triangular_matrix` and returns its result.
///
/// NOTE(review): the name suggests solving `L * X = rhs` for a lower-triangular
/// `L`, with `None` meaning the fast path was not taken — confirm against
/// `linalg_dispatch`.
#[inline]
pub fn try_solve_lower_triangular_matrix(
lower: ndarray::ArrayView2<'_, f64>,
rhs: ndarray::ArrayView2<'_, f64>,
) -> Option<ndarray::Array2<f64>> {
linalg_dispatch::try_solve_lower_triangular_matrix(lower, rhs)
}
/// Forwards `upper` and `rhs` unchanged to
/// `linalg_dispatch::try_solve_upper_triangular_matrix` and returns its result.
///
/// NOTE(review): the name suggests solving `U * X = rhs` for an upper-triangular
/// `U`, with `None` meaning the fast path was not taken — confirm against
/// `linalg_dispatch`.
#[inline]
pub fn try_solve_upper_triangular_matrix(
upper: ndarray::ArrayView2<'_, f64>,
rhs: ndarray::ArrayView2<'_, f64>,
) -> Option<ndarray::Array2<f64>> {
linalg_dispatch::try_solve_upper_triangular_matrix(upper, rhs)
}
// Unit tests for policy parsing, execution-path defaults, runtime admission,
// and the `gpu=force` error path.
#[cfg(test)]
mod policy_tests {
    use super::*;

    /// Only the three canonical labels parse; anything else is rejected.
    #[test]
    fn parses_canonical_user_gpu_policy_values() {
        assert_eq!(GpuPolicy::parse("auto"), Some(GpuPolicy::Auto));
        assert_eq!(GpuPolicy::parse("off"), Some(GpuPolicy::Off));
        assert_eq!(GpuPolicy::parse("force"), Some(GpuPolicy::Force));
        assert_eq!(GpuPolicy::parse("cpu"), None);
        assert_eq!(GpuPolicy::parse(""), None);
        assert_eq!(GpuPolicy::parse("wat"), None);
    }

    /// The default execution path is the CPU, and only device paths report
    /// `used_device()`.
    #[test]
    fn execution_path_defaults_to_cpu() {
        use crate::model_types::ExecutionPath;
        assert_eq!(ExecutionPath::default(), ExecutionPath::Cpu);
        assert!(!ExecutionPath::Cpu.used_device());
        assert!(ExecutionPath::GpuResidentFull.used_device());
    }

    /// `Off` never hands out a runtime; `Required` and `Auto` succeed only
    /// when a device is actually available.
    #[test]
    fn gpu_mode_required_fails_closed_when_device_absent() {
        use crate::gpu::device_runtime::GpuRuntime;
        assert!(matches!(
            GpuRuntime::global_or_fail(GpuMode::Off),
            Err(GpuError::DriverLibraryUnavailable { .. })
        ));
        if GpuRuntime::is_available() {
            assert!(GpuRuntime::global_or_fail(GpuMode::Required).is_ok());
            assert!(GpuRuntime::global_or_fail(GpuMode::Auto).is_ok());
        } else {
            let required = GpuRuntime::global_or_fail(GpuMode::Required);
            assert!(
                matches!(required, Err(GpuError::DriverLibraryUnavailable { .. })),
                "GpuMode::Required must fail closed when the device is absent, got {required:?}"
            );
            assert!(GpuRuntime::global_or_fail(GpuMode::Auto).is_err());
        }
    }

    /// Admission to the GPU PIRLS loop is exercised against runtime
    /// availability, problem size, and family.
    #[test]
    fn pirls_loop_admission_requires_runtime_size_and_known_family() {
        use crate::gpu::policy::{PirlsLoopAdmission, PirlsLoopCurvatureKind, PirlsLoopFamilyKind};
        let pol = GpuDispatchPolicy::default();
        let base = PirlsLoopAdmission {
            n: 80_000,
            p: 44,
            family: Some(PirlsLoopFamilyKind::BernoulliLogit),
            curvature: PirlsLoopCurvatureKind::Fisher,
            gpu_available: true,
        };
        assert!(pol.should_use_gpu_pirls_loop(base));
        assert!(!pol.should_use_gpu_pirls_loop(PirlsLoopAdmission {
            gpu_available: false,
            ..base
        }));
        assert!(!pol.should_use_gpu_pirls_loop(PirlsLoopAdmission { n: 1_000, ..base }));
        assert!(pol.should_use_gpu_pirls_loop(PirlsLoopAdmission {
            n: 2_000,
            p: 2_048,
            ..base
        }));
        assert!(!pol.should_use_gpu_pirls_loop(PirlsLoopAdmission { p: 8, ..base }));
        assert!(!pol.should_use_gpu_pirls_loop(PirlsLoopAdmission {
            family: None,
            ..base
        }));
    }

    /// A `Force` decision that fell back to the CPU must surface an error that
    /// names the kernel, the policy, and the reason tag.
    #[test]
    fn force_policy_reports_unsupported_kernel() {
        // Use the tag that `decide` emits for Force + BackendNotCompiled so
        // the fixture mirrors a decision the production path can produce.
        let decision = GpuDecision {
            policy: GpuPolicy::Force,
            kernel: GpuKernel::DenseXtWX,
            use_gpu: false,
            reason: "cpu-gpu-force-unsupported",
        };
        let err = decision.require_supported().unwrap_err();
        assert!(err.contains("dense-xtwx"));
        assert!(err.contains("gpu=force"));
        assert!(err.contains("cpu-gpu-force-unsupported"));
    }
}