#[cfg(target_os = "linux")]
pub(crate) use linux::{CudaBackendContext, probe_cuda_backend};
#[cfg(target_os = "linux")]
mod linux {
use crate::gpu::common::{DeviceArena, PtxModuleCache};
use crate::gpu::device::GpuCapability;
use crate::gpu::error::GpuError;
use crate::gpu::runtime::{GpuRuntime, cuda_context_for};
use crate::gpu_err;
use cudarc::driver::{CudaContext, CudaStream};
use std::sync::{Arc, Mutex};
/// Bundle of CUDA handles produced by `probe_cuda_backend`.
///
/// `CudaBackendContext::from_parts` consumes this bundle, keeping `ctx` and
/// `stream` and leaving `capability` behind.
#[derive(Debug)]
pub(crate) struct CudaBackendParts {
/// Shared handle to the CUDA context obtained via `cuda_context_for`.
pub(crate) ctx: Arc<CudaContext>,
/// Shared handle to a stream; the probe fills it from `ctx.default_stream()`.
pub(crate) stream: Arc<CudaStream>,
/// Capability record cloned from the runtime's selected device.
pub(crate) capability: GpuCapability,
}
/// Gathers the CUDA context, default stream, and device capability for a backend.
///
/// The device is the one reported by `GpuRuntime::global()` as selected.
/// `label` identifies the calling backend and is used only as the prefix of
/// the error messages.
///
/// # Errors
///
/// * [`GpuError::DriverLibraryUnavailable`] when `GpuRuntime::global()` is `None`.
/// * The error built by `gpu_err!` when `cuda_context_for` returns `None` for
///   the selected device's ordinal.
pub(crate) fn probe_cuda_backend(label: &'static str) -> Result<CudaBackendParts, GpuError> {
    // Guard: without a runtime there is no device to probe.
    let Some(runtime) = GpuRuntime::global() else {
        return Err(GpuError::DriverLibraryUnavailable {
            reason: format!("{label} backend: no CUDA runtime available"),
        });
    };

    let ordinal = runtime.selected_device().ordinal;
    // Kept as `ok_or_else(..)?` so any error conversion done by `?` still applies.
    let ctx = cuda_context_for(ordinal).ok_or_else(|| {
        gpu_err!("{label} backend: failed to create CUDA context for device {ordinal}")
    })?;

    // The stream must be taken before `ctx` is moved into the result.
    let stream = ctx.default_stream();

    Ok(CudaBackendParts {
        capability: runtime.selected_device().capability.clone(),
        ctx,
        stream,
    })
}
/// CUDA state held by a backend: context and stream handles, a
/// `PtxModuleCache`, and a mutex-guarded `DeviceArena`.
///
/// Built from a `CudaBackendParts` via `CudaBackendContext::from_parts`.
pub(crate) struct CudaBackendContext {
/// Shared handle to the CUDA context, taken from the probed parts.
pub(crate) ctx: Arc<CudaContext>,
/// Shared handle to the stream, taken from the probed parts.
pub(crate) stream: Arc<CudaStream>,
/// Module cache; `from_parts` starts it from `PtxModuleCache::new()`.
pub(crate) module: PtxModuleCache,
/// Arena behind a mutex; `from_parts` starts it from `DeviceArena::default()`.
pub(crate) arena: Mutex<DeviceArena>,
}
impl CudaBackendContext {
pub(crate) fn from_parts(parts: CudaBackendParts) -> Self {
CudaBackendContext {
ctx: parts.ctx,
stream: parts.stream,
module: PtxModuleCache::new(),
arena: Mutex::new(DeviceArena::default()),
}
}
}
}
#[cfg(all(test, target_os = "linux"))]
mod tests {
use super::probe_cuda_backend;
use crate::gpu::error::GpuError;
use crate::gpu::runtime::GpuRuntime;
/// Checks the shared probe against whatever `GpuRuntime::global()` reports.
///
/// Without a runtime, the probe must fail with the uniform
/// `DriverLibraryUnavailable` message. With a runtime, every label must probe
/// successfully, bind the selected device's ordinal, report its capability,
/// and yield a stream that synchronizes.
#[test]
fn shared_probe_matches_runtime_device_and_labels_errors() {
    // Labels exercised against the shared probe when a runtime is present.
    const BACKEND_LABELS: [&str; 6] = [
        "bms_flex",
        "survival_flex",
        "cubic_bspline_moments",
        "cubic_cell",
        "pirls_row",
        "sphere",
    ];

    let Some(runtime) = GpuRuntime::global() else {
        // No runtime on this host: only the error path can be checked.
        match probe_cuda_backend("bms_flex") {
            Err(GpuError::DriverLibraryUnavailable { reason }) => assert_eq!(
                reason, "bms_flex backend: no CUDA runtime available",
                "shared probe must emit the uniform no-runtime message"
            ),
            other => panic!(
                "expected DriverLibraryUnavailable on a host without a CUDA runtime, \
                 got {other:?}"
            ),
        }
        return;
    };

    let expected_ordinal = runtime.selected_device().ordinal;
    let expected_capability = &runtime.selected_device().capability;

    for label in BACKEND_LABELS {
        let parts = match probe_cuda_backend(label) {
            Ok(parts) => parts,
            Err(err) => panic!("probe for {label} must succeed: {err:?}"),
        };

        assert_eq!(
            parts.ctx.ordinal(),
            expected_ordinal,
            "{label}: context must bind the runtime's selected device ordinal"
        );
        assert_eq!(
            &parts.capability, expected_capability,
            "{label}: probe capability must match the runtime's selected device"
        );

        if let Err(err) = parts.stream.synchronize() {
            panic!("{label}: default stream must sync: {err:?}");
        }
    }
}
}