#[cfg(target_os = "linux")]
pub use linux::{DeviceArena, PtxModuleCache};
#[cfg(target_os = "linux")]
mod linux {
use super::super::gpu_error::GpuError;
use crate::gpu::gpu_error::GpuResultExt;
use cudarc::driver::{CudaContext, CudaModule, CudaSlice, CudaStream};
use cudarc::nvrtc::{CompileOptions, compile_ptx_with_opts};
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
/// Pool of reusable `f64` device buffers, grouped by bucket size.
///
/// Buffers are handed out by `alloc` and handed back with `release`; the
/// arena itself only stores the buffers that are currently idle.
#[derive(Default)]
pub struct DeviceArena {
/// Idle slabs keyed by the bucket value they were released under. `alloc`
/// looks slabs up by `bucket_of(elements)`, i.e. a power-of-two element count.
free: HashMap<usize, Vec<CudaSlice<f64>>>,
}
impl DeviceArena {
#[inline]
pub fn bucket_of(elements: usize) -> usize {
elements.max(1).next_power_of_two()
}
pub fn alloc(
&mut self,
stream: &Arc<CudaStream>,
elements: usize,
label: &'static str,
) -> Result<(usize, CudaSlice<f64>), GpuError> {
let bucket = Self::bucket_of(elements);
if let Some(bucket_vec) = self.free.get_mut(&bucket)
&& let Some(slot) = bucket_vec.pop()
{
return Ok((bucket, slot));
}
let fresh = stream
.alloc_zeros::<f64>(bucket)
.gpu_ctx_with(|err| format!("{label} arena alloc_zeros<{bucket}>: {err}"))?;
Ok((bucket, fresh))
}
pub fn release(&mut self, bucket: usize, slab: CudaSlice<f64>) {
self.free.entry(bucket).or_default().push(slab);
}
}
/// Write-once slot for a loaded CUDA module, filled on first successful
/// `get_or_compile` call and read back on every later call.
#[derive(Default)]
pub struct PtxModuleCache {
/// The loaded module, once one has been stored; empty until then.
module: std::sync::OnceLock<Arc<CudaModule>>,
}
impl PtxModuleCache {
    /// Creates an empty cache. Being `const`, this can initialize a `static`.
    pub const fn new() -> Self {
        Self {
            module: std::sync::OnceLock::new(),
        }
    }

    /// Returns the cached module, or `None` if nothing has been stored yet.
    pub fn get(&self) -> Option<&Arc<CudaModule>> {
        self.module.get()
    }

    /// Returns the cached module, compiling `source` with NVRTC and loading it
    /// into `ctx` first if the cache is still empty.
    ///
    /// Once a module is cached, `ctx` and `source` are ignored and the stored
    /// module is returned unchanged. Failures are not cached: an error leaves
    /// the cache empty, so a later call compiles again.
    ///
    /// No lock is held while compiling, so concurrent first calls may each
    /// compile; the first module stored wins and every caller gets that one.
    ///
    /// # Errors
    /// NVRTC compile failures and module load failures are routed through
    /// `gpu_ctx_with` with a message prefixed by `label`.
    pub fn get_or_compile(
        &self,
        ctx: &Arc<CudaContext>,
        label: &'static str,
        source: &str,
    ) -> Result<&Arc<CudaModule>, GpuError> {
        // Fast path: skip compilation entirely when a module is already stored.
        if let Some(existing) = self.module.get() {
            return Ok(existing);
        }
        let ptx = compile_ptx_with_opts(source, nvrtc_compile_options())
            .gpu_ctx_with(|err| format!("{label} NVRTC compile failed: {err}"))?;
        let module = ctx
            .load_module(ptx)
            .gpu_ctx_with(|err| format!("{label} module load failed: {err}"))?;
        // Store-or-fetch in one step: if another caller filled the slot in the
        // meantime, its module is returned and ours is dropped. This cannot
        // panic, unlike a separate `set` followed by `get().expect(..)`.
        Ok(self.module.get_or_init(|| module))
    }
}
fn nvrtc_compile_options() -> CompileOptions {
let mut opts = CompileOptions::default();
opts.include_paths = nvrtc_include_paths();
opts
}
/// Collects the include directories to pass to NVRTC, keeping only those that
/// exist on this machine.
///
/// Candidates are probed in a fixed order: the CUDA include directory under
/// `/usr/local/cuda`, the system include directories, and finally the
/// per-version `include` directories under the GCC install root. All
/// locations are hard-coded and the arch-specific ones name
/// `x86_64-linux-gnu`.
fn nvrtc_include_paths() -> Vec<String> {
    const FIXED_DIRS: [&str; 3] = [
        "/usr/local/cuda/include",
        "/usr/include",
        "/usr/include/x86_64-linux-gnu",
    ];

    let mut found = Vec::new();
    for dir in FIXED_DIRS {
        push_existing_include_path(&mut found, Path::new(dir));
    }
    push_gcc_include_paths(&mut found, Path::new("/usr/lib/gcc/x86_64-linux-gnu"));
    found
}
/// Appends `<root>/<entry>/include` to `paths` for every entry under `root`
/// whose `include` subdirectory exists.
///
/// If `root` cannot be read (for example, it does not exist) nothing is
/// added; unreadable individual entries are skipped. This is best-effort
/// discovery, so failures are deliberately not reported.
///
/// Candidates are sorted before being added because directory iteration
/// order is platform and filesystem dependent; without sorting, machines with
/// several entries under `root` would get an unpredictable include search order.
fn push_gcc_include_paths(paths: &mut Vec<String>, root: &Path) {
    let Ok(entries) = std::fs::read_dir(root) else {
        return;
    };
    let mut candidates: Vec<_> = entries
        .flatten()
        .map(|entry| entry.path().join("include"))
        .collect();
    // Make the resulting search order deterministic across runs and hosts.
    candidates.sort_unstable();
    for candidate in &candidates {
        push_existing_include_path(paths, candidate);
    }
}
/// Appends `path` to `paths` if it is an existing directory and is not
/// already listed.
///
/// The path is stored in its lossy UTF-8 string form, and duplicates are
/// detected by comparing those strings exactly.
fn push_existing_include_path(paths: &mut Vec<String>, path: &Path) {
    if path.is_dir() {
        let candidate = path.to_string_lossy().into_owned();
        if !paths.contains(&candidate) {
            paths.push(candidate);
        }
    }
}
}