pub mod arrow_schur;
pub mod arrow_schur_nvrtc;
pub mod blas;
pub mod bms_flex;
pub mod bms_flex_row;
pub mod common;
pub mod cpu_traits;
pub mod cubic_bspline_moments;
pub mod cubic_cell;
pub mod device;
pub mod driver;
#[macro_use]
pub mod error;
pub mod identifiability_compile;
#[cfg(target_os = "linux")]
pub mod identifiability_compile_kernel;
pub mod linalg;
pub mod memory;
pub mod numerics_device;
pub mod pirls_row;
pub mod policy;
pub mod polya_gamma;
pub mod profile;
pub mod reml_trace;
pub mod row_hessian_ops;
pub mod runtime;
pub mod sigma_cubature;
pub mod solver;
pub mod sphere;
pub mod survival_flex;
pub mod survival_flex_prep;
pub use cpu_traits::{ExecutionTarget, MatrixLocation};
pub use device::GpuDeviceInfo;
pub use error::GpuError;
pub use memory::{DeviceBuffer, DeviceCsrMatrix, DeviceMatrix, DeviceVector};
pub use policy::{GpuDispatchPolicy, MixedPrecisionPolicy};
pub use profile::{KernelStat, KernelStatsSnapshot};
pub use runtime::GpuRuntime;
use serde::{Deserialize, Serialize};
use std::sync::OnceLock;
/// User-selectable GPU usage policy.
///
/// Serialized in kebab-case (`"auto"`, `"off"`, `"force"`); see [`decide`] for how each
/// variant interacts with kernel eligibility and runtime availability.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum GpuPolicy {
    /// Default: use the GPU only when the kernel is eligible and a runtime is available.
    #[default]
    Auto,
    /// Always take the CPU path.
    Off,
    /// Take the GPU path whenever a backend and runtime exist, ignoring the size threshold.
    Force,
}
impl GpuPolicy {
pub fn parse(raw: &str) -> Option<Self> {
match raw.trim().to_ascii_lowercase().as_str() {
"auto" | "" => Some(Self::Auto),
"off" | "false" | "0" | "cpu" => Some(Self::Off),
"force" | "on" | "true" | "1" | "gpu" => Some(Self::Force),
_ => None,
}
}
#[inline]
pub const fn as_str(self) -> &'static str {
match self {
Self::Auto => "auto",
Self::Off => "off",
Self::Force => "force",
}
}
}
/// Kernels whose CPU/GPU routing is chosen by [`decide`].
///
/// Each variant has a stable kebab-case name (see [`GpuKernel::as_str`]) used in log
/// lines and error messages.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GpuKernel {
    DenseMatvec,
    DenseTransposeMatvec,
    DenseXtWX,
    CandidateScreen,
    DenseSolve,
    MatrixFreePcg,
    SparseAssembly,
    SpatialKernelOperator,
    MarginalSlopeRows,
    RemlTrace,
    FinalInference,
}
impl GpuKernel {
    /// Stable kebab-case identifier for this kernel.
    #[inline]
    pub const fn as_str(self) -> &'static str {
        match self {
            GpuKernel::DenseMatvec => "dense-matvec",
            GpuKernel::DenseTransposeMatvec => "dense-transpose-matvec",
            GpuKernel::DenseXtWX => "dense-xtwx",
            GpuKernel::CandidateScreen => "candidate-screen",
            GpuKernel::DenseSolve => "dense-solve",
            GpuKernel::MatrixFreePcg => "matrix-free-pcg",
            GpuKernel::SparseAssembly => "sparse-assembly",
            GpuKernel::SpatialKernelOperator => "spatial-kernel-operator",
            GpuKernel::MarginalSlopeRows => "marginal-slope-rows",
            GpuKernel::RemlTrace => "reml-trace",
            GpuKernel::FinalInference => "final-inference",
        }
    }
}
/// Result of routing one kernel invocation: the policy in force, the kernel asked about,
/// whether the GPU path was chosen, and a short tag explaining why.
#[derive(Debug, Clone)]
pub struct GpuDecision {
    /// Global policy that was active when the decision was made.
    pub policy: GpuPolicy,
    /// Kernel this decision applies to.
    pub kernel: GpuKernel,
    /// `true` when the GPU path was selected, `false` for the CPU path.
    pub use_gpu: bool,
    /// Static tag naming the rule that produced this decision (e.g. `"cpu-gpu-policy-off"`).
    pub reason: &'static str,
}
/// Process-wide GPU policy. Stored in a `OnceLock`, so the first successful write wins.
static POLICY: OnceLock<GpuPolicy> = OnceLock::new();
/// Returns the configured process-wide [`GpuPolicy`], or [`GpuPolicy::Auto`] when no
/// policy has been configured yet.
#[inline]
pub fn global_policy() -> GpuPolicy {
    POLICY.get().copied().unwrap_or(GpuPolicy::Auto)
}
pub fn configure_global_policy(policy: GpuPolicy) {
POLICY.set(policy).ok();
}
/// Caller-side assessment of a workload, combined with the global policy in [`decide`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GpuEligibility {
    /// The caller reported no supported GPU backend for the kernel.
    BackendNotCompiled,
    /// A backend exists, but the caller judged the workload too small.
    WorkloadBelowThreshold,
    /// A backend exists and the workload is large enough.
    Eligible,
}
impl GpuEligibility {
    /// Builds an eligibility from two flags. A missing backend takes precedence over the
    /// size check, so `from_flags(false, true)` is `BackendNotCompiled`.
    #[inline]
    pub const fn from_flags(supported: bool, large_enough: bool) -> Self {
        match (supported, large_enough) {
            (false, _) => Self::BackendNotCompiled,
            (true, false) => Self::WorkloadBelowThreshold,
            (true, true) => Self::Eligible,
        }
    }
}
/// Chooses the CPU or GPU path for `kernel` by combining the process-wide [`GpuPolicy`],
/// the caller's [`GpuEligibility`], and GPU runtime availability.
///
/// Rules, in order:
/// - `Off` always selects the CPU.
/// - A `BackendNotCompiled` eligibility selects the CPU under `Auto` and `Force`.
/// - An unavailable runtime selects the CPU.
/// - `Auto` then requires `Eligible`, while `Force` also accepts `WorkloadBelowThreshold`.
///
/// The returned `reason` is a static tag naming the rule that fired.
///
/// The runtime is probed only when the outcome depends on it. `gpu=off` and an
/// uncompiled backend therefore never call `runtime::GpuRuntime::is_available`.
pub fn decide(kernel: GpuKernel, eligibility: GpuEligibility) -> GpuDecision {
    let policy = global_policy();
    // Match guards are evaluated only for arms whose pattern matches, so each call probes
    // the runtime at most once, and only on the `Auto`/`Force` paths that need it.
    let (use_gpu, reason) = match (policy, eligibility) {
        (GpuPolicy::Off, _) => (false, "cpu-gpu-policy-off"),
        (GpuPolicy::Auto, GpuEligibility::BackendNotCompiled) => {
            (false, "cpu-gpu-backend-not-compiled")
        }
        (GpuPolicy::Auto, _) if !runtime::GpuRuntime::is_available() => {
            (false, "cpu-gpu-runtime-unavailable")
        }
        (GpuPolicy::Auto, GpuEligibility::WorkloadBelowThreshold) => {
            (false, "cpu-workload-below-gpu-threshold")
        }
        (GpuPolicy::Auto, GpuEligibility::Eligible) => (true, "gpu-auto-supported"),
        (GpuPolicy::Force, GpuEligibility::BackendNotCompiled) => {
            (false, "cpu-gpu-force-unsupported")
        }
        (GpuPolicy::Force, _) if !runtime::GpuRuntime::is_available() => {
            (false, "cpu-gpu-force-runtime-unavailable")
        }
        (GpuPolicy::Force, GpuEligibility::WorkloadBelowThreshold | GpuEligibility::Eligible) => {
            (true, "gpu-force-supported")
        }
    };
    GpuDecision {
        policy,
        kernel,
        use_gpu,
        reason,
    }
}
impl GpuDecision {
    /// Checks that a `gpu=force` request was honoured.
    ///
    /// # Errors
    /// Returns a message naming the kernel and this decision's `reason` when `policy` is
    /// [`GpuPolicy::Force`] but the CPU path was selected. Otherwise returns `Ok(())`.
    pub fn require_supported(&self) -> Result<(), String> {
        let forced_onto_cpu = matches!(self.policy, GpuPolicy::Force) && !self.use_gpu;
        if !forced_onto_cpu {
            return Ok(());
        }
        Err(format!(
            "gpu=force requested kernel '{}' but no supported device backend is available ({})",
            self.kernel.as_str(),
            self.reason
        ))
    }

    /// Emits this decision as a single `debug`-level log line. Consumes the decision.
    pub fn log(self) {
        let Self {
            policy,
            kernel,
            use_gpu,
            reason,
        } = self;
        log::debug!(
            "[GPU backend] kernel={} policy={} selected={} reason={}",
            kernel.as_str(),
            policy.as_str(),
            use_gpu,
            reason
        );
    }
}
/// Logs, at `debug` level and at most once per process, the active GPU policy, the
/// compiled backend family, and a fixed list of kernel names. Later calls do nothing.
pub fn log_backend_inventory_once() {
    // Builds and writes the single inventory line; run through the `OnceLock` below.
    fn emit_inventory() {
        let compiled = if cfg!(not(target_os = "linux")) {
            "none"
        } else {
            "cuda-dynamic"
        };
        log::debug!(
            "[GPU backend] policy={} compiled_backends={} kernels=dense-matvec,dense-transpose-matvec,dense-xtwx,candidate-screen,dense-solve,matrix-free-pcg,sparse-assembly,spatial-kernel-operator,marginal-slope-rows,reml-trace,final-inference",
            global_policy().as_str(),
            compiled
        );
    }
    static INVENTORY_LOGGED: OnceLock<()> = OnceLock::new();
    INVENTORY_LOGGED.get_or_init(emit_inventory);
}
#[inline]
pub fn try_fast_ab(
a: ndarray::ArrayView2<'_, f64>,
b: ndarray::ArrayView2<'_, f64>,
) -> Option<ndarray::Array2<f64>> {
linalg::try_fast_ab(a, b)
}
#[inline]
pub fn try_fast_av(
a: ndarray::ArrayView2<'_, f64>,
v: ndarray::ArrayView1<'_, f64>,
) -> Option<ndarray::Array1<f64>> {
linalg::try_fast_av(a, v)
}
#[inline]
pub fn try_fast_atv(
a: ndarray::ArrayView2<'_, f64>,
v: ndarray::ArrayView1<'_, f64>,
) -> Option<ndarray::Array1<f64>> {
linalg::try_fast_atv(a, v)
}
#[inline]
pub fn try_fast_ab_broadcast_b_batched(
a: ndarray::ArrayView3<'_, f64>,
b: ndarray::ArrayView2<'_, f64>,
) -> Option<ndarray::Array3<f64>> {
linalg::try_fast_ab_broadcast_b_batched(a, b)
}
#[inline]
pub fn try_fast_abt_strided_batched(
a: ndarray::ArrayView3<'_, f64>,
b: ndarray::ArrayView3<'_, f64>,
) -> Option<ndarray::Array3<f64>> {
linalg::try_fast_abt_strided_batched(a, b)
}
#[inline]
pub fn try_cholesky_lower_inplace(a: &mut ndarray::Array2<f64>) -> Option<()> {
linalg::try_cholesky_lower_inplace(a)
}
#[inline]
pub fn try_cholesky_batched_lower_inplace(matrices: &mut [ndarray::Array2<f64>]) -> Option<()> {
linalg::try_cholesky_batched_lower_inplace(matrices)
}
#[inline]
pub fn try_solve_lower_triangular_matrix(
lower: ndarray::ArrayView2<'_, f64>,
rhs: ndarray::ArrayView2<'_, f64>,
) -> Option<ndarray::Array2<f64>> {
linalg::try_solve_lower_triangular_matrix(lower, rhs)
}
#[inline]
pub fn try_solve_upper_triangular_matrix(
upper: ndarray::ArrayView2<'_, f64>,
rhs: ndarray::ArrayView2<'_, f64>,
) -> Option<ndarray::Array2<f64>> {
linalg::try_solve_upper_triangular_matrix(upper, rhs)
}
#[cfg(test)]
mod policy_tests {
    use super::*;
    #[test]
    fn parses_user_gpu_policy_aliases() {
        assert_eq!(GpuPolicy::parse("auto"), Some(GpuPolicy::Auto));
        assert_eq!(GpuPolicy::parse("cpu"), Some(GpuPolicy::Off));
        assert_eq!(GpuPolicy::parse("force"), Some(GpuPolicy::Force));
        assert_eq!(GpuPolicy::parse("wat"), None);
        // Input is trimmed and ASCII-case-folded; an empty value means `Auto`.
        assert_eq!(GpuPolicy::parse("  GPU\n"), Some(GpuPolicy::Force));
        assert_eq!(GpuPolicy::parse("Off"), Some(GpuPolicy::Off));
        assert_eq!(GpuPolicy::parse(""), Some(GpuPolicy::Auto));
    }
    #[test]
    fn policy_as_str_round_trips_through_parse() {
        for policy in [GpuPolicy::Auto, GpuPolicy::Off, GpuPolicy::Force] {
            assert_eq!(GpuPolicy::parse(policy.as_str()), Some(policy));
        }
    }
    #[test]
    fn eligibility_flags_prioritise_backend_support() {
        assert_eq!(
            GpuEligibility::from_flags(false, true),
            GpuEligibility::BackendNotCompiled
        );
        assert_eq!(
            GpuEligibility::from_flags(false, false),
            GpuEligibility::BackendNotCompiled
        );
        assert_eq!(
            GpuEligibility::from_flags(true, false),
            GpuEligibility::WorkloadBelowThreshold
        );
        assert_eq!(
            GpuEligibility::from_flags(true, true),
            GpuEligibility::Eligible
        );
    }
    #[test]
    fn pirls_loop_admission_requires_runtime_size_and_known_family() {
        use crate::gpu::policy::{PirlsLoopAdmission, PirlsLoopCurvatureKind, PirlsLoopFamilyKind};
        let pol = GpuDispatchPolicy::default();
        let base = PirlsLoopAdmission {
            n: 80_000,
            p: 44,
            family: Some(PirlsLoopFamilyKind::BernoulliLogit),
            curvature: PirlsLoopCurvatureKind::Fisher,
            gpu_available: true,
        };
        assert!(pol.should_use_gpu_pirls_loop(base));
        assert!(!pol.should_use_gpu_pirls_loop(PirlsLoopAdmission {
            gpu_available: false,
            ..base
        }));
        assert!(!pol.should_use_gpu_pirls_loop(PirlsLoopAdmission { n: 1_000, ..base }));
        assert!(!pol.should_use_gpu_pirls_loop(PirlsLoopAdmission { p: 8, ..base }));
        assert!(!pol.should_use_gpu_pirls_loop(PirlsLoopAdmission {
            family: None,
            ..base
        }));
    }
    #[test]
    fn force_policy_reports_unsupported_kernel() {
        // Use the same reason tag `decide` produces for (Force, BackendNotCompiled).
        let decision = GpuDecision {
            policy: GpuPolicy::Force,
            kernel: GpuKernel::DenseXtWX,
            use_gpu: false,
            reason: "cpu-gpu-force-unsupported",
        };
        let err = decision.require_supported().unwrap_err();
        assert!(err.contains("dense-xtwx"));
        assert!(err.contains("gpu=force"));
        assert!(err.contains("cpu-gpu-force-unsupported"));
    }
    #[test]
    fn non_forced_cpu_fallback_is_not_an_error() {
        for policy in [GpuPolicy::Auto, GpuPolicy::Off] {
            let decision = GpuDecision {
                policy,
                kernel: GpuKernel::DenseSolve,
                use_gpu: false,
                reason: "cpu-gpu-policy-off",
            };
            assert!(decision.require_supported().is_ok());
        }
    }
}