use crate::error::{device_assert, device_error, DeviceError};
use crate::scheduling_policies::{SchedulingPolicy, StreamPoolRoundRobin};
use cuda_core::{CudaContext, CudaFunction, CudaModule, CudaStream};
use std::cell::Cell;
use std::collections::HashMap;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;
/// Device ordinal that a thread treats as its default until `set_default_device`
/// or `init_device_contexts` stores a different one.
pub const DEFAULT_DEVICE_ID: usize = 0;
/// Capacity hint that `init_device_contexts_default` passes to `init_device_contexts`.
pub const DEFAULT_NUM_DEVICES: usize = 1;
/// Pool size handed to `StreamPoolRoundRobin::new` when `init_with_default_policy`
/// builds the default scheduling policy for a device.
pub const DEFAULT_ROUND_ROBIN_STREAM_POOL_SIZE: usize = 4;
/// A hashable key that identifies an entry in the per-device function caches.
///
/// Implementors only have to provide [`Hash`]; the string used as the actual
/// map key is derived from it by [`FunctionKey::get_hash_string`].
pub trait FunctionKey: Hash {
    /// Returns this key's 64-bit hash, computed with a fresh [`DefaultHasher`],
    /// rendered as lowercase hexadecimal without padding or a `0x` prefix.
    fn get_hash_string(&self) -> String {
        let mut state = DefaultHasher::new();
        self.hash(&mut state);
        format!("{:x}", state.finish())
    }
}
/// One parameter descriptor stored in a [`Validator`].
///
/// Each variant wraps the descriptor struct of the same kind.
// NOTE(review): presumably describes the expected type of a single kernel
// argument — confirm against the code that builds and consumes validators.
#[derive(Debug, Clone)]
pub enum ValidParamType {
    /// Descriptor carrying only an element type name.
    Scalar(ScalarParamType),
    /// Descriptor carrying an element type name and a mutability flag.
    Pointer(PointerParamType),
    /// Descriptor carrying an element type name and a shape.
    Tensor(TensorParamType),
}
/// Descriptor for a scalar parameter.
#[derive(Debug, Clone)]
pub struct ScalarParamType {
    /// Name of the element type, stored as free-form text.
    // NOTE(review): the accepted spellings (e.g. "f32") are not defined here — confirm.
    pub element_type: String,
}
/// Descriptor for a pointer parameter.
#[derive(Debug, Clone)]
pub struct PointerParamType {
    /// Mutability flag for the pointer.
    // NOTE(review): presumably `true` means the pointee may be written — confirm.
    pub mutable: bool,
    /// Name of the pointed-to element type, stored as free-form text.
    pub element_type: String,
}
/// Descriptor for a tensor parameter.
#[derive(Debug, Clone)]
pub struct TensorParamType {
    /// Name of the tensor's element type, stored as free-form text.
    pub element_type: String,
    /// Extents of the tensor, one signed entry per dimension.
    // NOTE(review): the meaning of negative entries (if any) is not defined here — confirm.
    pub shape: Vec<i32>,
}
/// Parameter descriptors cached per function key alongside the loaded function.
#[derive(Debug, Clone)]
pub struct Validator {
    /// Ordered list of parameter descriptors.
    // NOTE(review): presumably one entry per kernel argument, in call order — confirm.
    pub params: Vec<ValidParamType>,
}
// Per-device cache of loaded modules and their functions, keyed by
// `FunctionKey::get_hash_string`.
type DeviceFunctions = HashMap<String, (Arc<CudaModule>, Arc<CudaFunction>)>;
// Per-device cache of validators, keyed by `FunctionKey::get_hash_string`.
type DeviceFunctionValidators = HashMap<String, Arc<Validator>>;
/// Everything this module keeps for a single device: its CUDA context, a
/// dedicated stream, the scheduling policy and the function/validator caches.
pub struct AsyncDeviceContext {
    /// Ordinal that was passed to `CudaContext::new` when this entry was built.
    #[expect(dead_code, reason = "will be used when multi-device is implemented")]
    device_id: usize,
    /// CUDA context handle for the device.
    context: Arc<CudaContext>,
    /// Stream created from `context` at construction time and exposed through
    /// `with_deallocator_stream`.
    // NOTE(review): presumably reserved for freeing device memory — confirm.
    deallocator_stream: Arc<CudaStream>,
    /// Scheduling policy associated with this device.
    policy: Arc<dyn SchedulingPolicy>,
    /// Loaded modules/functions keyed by `FunctionKey::get_hash_string`.
    functions: DeviceFunctions,
    /// Validators keyed by `FunctionKey::get_hash_string`.
    validators: DeviceFunctionValidators,
}
/// Per-thread registry of device contexts.
///
/// Both fields use [`Cell`], so the registry is updated through a shared
/// reference by moving values in and out rather than by borrowing them.
pub struct AsyncDeviceContexts {
    /// Ordinal of the device this thread currently treats as its default.
    default_device: Cell<usize>,
    /// Map from device ordinal to its context. `None` until
    /// `init_device_contexts` has run, and also while one of the
    /// `with_global_device_context*` helpers has temporarily taken the map out.
    devices: Cell<Option<HashMap<usize, AsyncDeviceContext>>>,
}
// Thread-local registry backing every accessor in this module. It starts with
// `DEFAULT_DEVICE_ID` as the default device and no device map; the map is
// created by `init_device_contexts`.
thread_local! {
    static DEVICE_CONTEXTS: AsyncDeviceContexts = const {
        AsyncDeviceContexts {
            default_device: Cell::new(DEFAULT_DEVICE_ID),
            devices: Cell::new(None),
        }
    };
}
/// Returns the device ordinal the calling thread currently uses as its default.
pub fn get_default_device() -> usize {
    DEVICE_CONTEXTS.with(|registry| {
        let current = registry.default_device.get();
        current
    })
}
/// Creates the calling thread's (empty) device map and records its default device.
///
/// `num_devices` is only used as a capacity hint for the map; individual
/// devices are added later (see `init_device` / `init_with_default_policy`).
///
/// # Errors
///
/// Returns the error produced by `device_assert` with the message
/// `"Context already initialized."` if a device map is already present. In
/// that case the existing map and the current default device are left
/// untouched.
pub fn init_device_contexts(
    default_device_id: usize,
    num_devices: usize,
) -> Result<(), DeviceError> {
    DEVICE_CONTEXTS.with(|ctx| {
        // `Cell` offers no way to peek at a non-`Copy` value, so move the map
        // out, inspect it, and put it straight back. The check must not
        // destroy an existing registry.
        let existing = ctx.devices.take();
        let already_initialized = existing.is_some();
        ctx.devices.set(existing);
        device_assert(
            default_device_id,
            !already_initialized,
            "Context already initialized.",
        )?;

        ctx.default_device.set(default_device_id);
        ctx.devices.set(Some(HashMap::with_capacity(num_devices)));
        Ok(())
    })
}
/// Initializes the calling thread's device map with the current default
/// device and a capacity hint of `DEFAULT_NUM_DEVICES`.
///
/// # Errors
///
/// Propagates any error from `init_device_contexts`.
pub fn init_device_contexts_default() -> Result<(), DeviceError> {
    init_device_contexts(get_default_device(), DEFAULT_NUM_DEVICES)
}
/// Builds a fresh [`AsyncDeviceContext`] for `device_id` using the supplied
/// scheduling policy. The result is returned to the caller and is not
/// registered anywhere.
///
/// # Errors
///
/// Propagates failures from `CudaContext::new` and from creating the
/// deallocator stream.
pub fn new_device_context(
    device_id: usize,
    policy: Arc<dyn SchedulingPolicy>,
) -> Result<AsyncDeviceContext, DeviceError> {
    let cuda_context = CudaContext::new(device_id)?;
    let stream = cuda_context.new_stream()?;
    let device_context = AsyncDeviceContext {
        device_id,
        context: cuda_context,
        deallocator_stream: stream,
        policy,
        functions: HashMap::new(),
        validators: HashMap::new(),
    };
    Ok(device_context)
}
/// Creates a context for `device_id` with the given scheduling policy and
/// registers it in `hashmap`.
///
/// # Errors
///
/// * Returns the `device_assert` error `"Device is already initialized."` if
///   `hashmap` already has an entry for `device_id`. The existing entry is
///   left in place and no new CUDA context is created.
/// * Propagates any error from `new_device_context`.
pub fn init_device(
    hashmap: &mut HashMap<usize, AsyncDeviceContext>,
    device_id: usize,
    policy: Arc<dyn SchedulingPolicy>,
) -> Result<(), DeviceError> {
    // Check before building anything: inserting first would overwrite (and
    // drop) the live context for this device before the error is reported.
    device_assert(
        device_id,
        !hashmap.contains_key(&device_id),
        "Device is already initialized.",
    )?;
    let device_context = new_device_context(device_id, policy)?;
    hashmap.insert(device_id, device_context);
    Ok(())
}
/// Creates a context for `device_id` that uses a `StreamPoolRoundRobin` policy
/// with `DEFAULT_ROUND_ROBIN_STREAM_POOL_SIZE` streams, and registers it in
/// `hashmap`.
///
/// # Errors
///
/// * Returns the `device_assert` error `"Device is already initialized."` if
///   `hashmap` already has an entry for `device_id`. The existing entry is
///   left in place and no CUDA resources are created.
/// * Propagates failures from `CudaContext::new`, `StreamPoolRoundRobin::new`
///   and from creating the deallocator stream.
pub fn init_with_default_policy(
    hashmap: &mut HashMap<usize, AsyncDeviceContext>,
    device_id: usize,
) -> Result<(), DeviceError> {
    // Check before building anything: inserting first would overwrite (and
    // drop) the live context for this device before the error is reported.
    device_assert(
        device_id,
        !hashmap.contains_key(&device_id),
        "Device is already initialized.",
    )?;

    let context = CudaContext::new(device_id)?;
    // The policy is built from the context, so it has to be created before
    // the context is moved into the struct below.
    let policy = StreamPoolRoundRobin::new(&context, DEFAULT_ROUND_ROBIN_STREAM_POOL_SIZE)?;
    let deallocator_stream = context.new_stream()?;
    hashmap.insert(
        device_id,
        AsyncDeviceContext {
            device_id,
            context,
            deallocator_stream,
            policy: Arc::new(policy),
            functions: HashMap::new(),
            validators: HashMap::new(),
        },
    );
    Ok(())
}
/// Runs `f` with shared access to the calling thread's context for `device_id`
/// and returns its result.
///
/// The thread's device map is created on first use (via
/// `init_device_contexts_default`), and a missing device is initialized with
/// the default policy (via `init_with_default_policy`).
///
/// While `f` runs, the map has been moved out of the thread-local cell. A
/// nested call to this function (or to the `_mut` variant) made from inside
/// `f` therefore sees no map, builds a fresh one, and that fresh map is
/// discarded when the outer call stores its own map back. Avoid re-entrant use.
///
/// # Errors
///
/// Propagates initialization failures and the `device_error` values
/// `"Failed to initialize context"` / `"Failed to get context"`. Once the map
/// has been obtained it is always stored back, so a failure to initialize one
/// device does not discard contexts that were already registered.
pub fn with_global_device_context<F, R>(device_id: usize, f: F) -> Result<R, DeviceError>
where
    F: FnOnce(&AsyncDeviceContext) -> R,
{
    DEVICE_CONTEXTS.with(|ctx| {
        let mut hashmap = match ctx.devices.take() {
            Some(hashmap) => hashmap,
            None => {
                // Nothing has been taken out yet, so returning early here
                // cannot lose any state.
                init_device_contexts_default()?;
                ctx.devices
                    .take()
                    .ok_or_else(|| device_error(device_id, "Failed to initialize context"))?
            }
        };

        // Every fallible step runs inside this closure so that an early `?`
        // return cannot skip storing the map back into the cell below.
        let outcome = (|| -> Result<R, DeviceError> {
            if !hashmap.contains_key(&device_id) {
                init_with_default_policy(&mut hashmap, device_id)?;
            }
            let device_context = hashmap
                .get(&device_id)
                .ok_or_else(|| device_error(device_id, "Failed to get context"))?;
            Ok(f(device_context))
        })();

        ctx.devices.set(Some(hashmap));
        outcome
    })
}
/// Runs `f` with exclusive access to the calling thread's context for
/// `device_id` and returns its result.
///
/// The thread's device map is created on first use (via
/// `init_device_contexts_default`), and a missing device is initialized with
/// the default policy (via `init_with_default_policy`).
///
/// While `f` runs, the map has been moved out of the thread-local cell. A
/// nested call to this function (or to the shared variant) made from inside
/// `f` therefore sees no map, builds a fresh one, and that fresh map is
/// discarded when the outer call stores its own map back. Avoid re-entrant use.
///
/// # Errors
///
/// Propagates initialization failures and the `device_error` values
/// `"Failed to initialize context"` / `"Failed to get context"`. Once the map
/// has been obtained it is always stored back, so a failure to initialize one
/// device does not discard contexts that were already registered.
pub fn with_global_device_context_mut<F, R>(device_id: usize, f: F) -> Result<R, DeviceError>
where
    F: FnOnce(&mut AsyncDeviceContext) -> R,
{
    DEVICE_CONTEXTS.with(|ctx| {
        let mut hashmap = match ctx.devices.take() {
            Some(hashmap) => hashmap,
            None => {
                // Nothing has been taken out yet, so returning early here
                // cannot lose any state.
                init_device_contexts_default()?;
                ctx.devices
                    .take()
                    .ok_or_else(|| device_error(device_id, "Failed to initialize context"))?
            }
        };

        // Every fallible step runs inside this closure so that an early `?`
        // return cannot skip storing the map back into the cell below.
        let outcome = (|| -> Result<R, DeviceError> {
            if !hashmap.contains_key(&device_id) {
                init_with_default_policy(&mut hashmap, device_id)?;
            }
            let device_context = hashmap
                .get_mut(&device_id)
                .ok_or_else(|| device_error(device_id, "Failed to get context"))?;
            Ok(f(device_context))
        })();

        ctx.devices.set(Some(hashmap));
        outcome
    })
}
/// Runs `f` with a reference to the scheduling policy of `device_id`,
/// initializing the device lazily if needed.
///
/// # Errors
///
/// Propagates any error from `with_global_device_context`.
pub fn with_device_policy<F, R>(device_id: usize, f: F) -> Result<R, DeviceError>
where
    F: FnOnce(&Arc<dyn SchedulingPolicy>) -> R,
{
    with_global_device_context(device_id, |dev| {
        let policy = &dev.policy;
        f(policy)
    })
}
/// Returns a new `Arc` handle to the scheduling policy of `device_id`,
/// initializing the device lazily if needed.
///
/// # Errors
///
/// Propagates any error from `with_global_device_context`.
pub fn global_policy(device_id: usize) -> Result<Arc<dyn SchedulingPolicy>, DeviceError> {
    with_global_device_context(device_id, |dev| Arc::clone(&dev.policy))
}
/// Runs `f` with a reference to the deallocator stream of `device_id`,
/// initializing the device lazily if needed.
///
/// # Safety
///
/// NOTE(review): this file does not state the contract that makes the
/// function `unsafe`. Presumably callers must only submit work to this stream
/// that is safe to order against in-flight deallocations — confirm with the
/// code that uses it and document the precise requirement here.
///
/// # Errors
///
/// Propagates any error from `with_global_device_context`.
pub unsafe fn with_deallocator_stream<F, R>(device_id: usize, f: F) -> Result<R, DeviceError>
where
    F: FnOnce(&Arc<CudaStream>) -> R,
{
    with_global_device_context(device_id, |dev| {
        let stream = &dev.deallocator_stream;
        f(stream)
    })
}
/// Runs `f` with a reference to the CUDA context of `device_id`,
/// initializing the device lazily if needed.
///
/// # Errors
///
/// Propagates any error from `with_global_device_context`.
pub fn with_cuda_context<F, R>(device_id: usize, f: F) -> Result<R, DeviceError>
where
    F: FnOnce(&Arc<CudaContext>) -> R,
{
    with_global_device_context(device_id, |dev| {
        let cuda_context = &dev.context;
        f(cuda_context)
    })
}
/// Stores `default_device_id` as the calling thread's default device.
///
/// Only the stored ordinal changes; no device is initialized or validated here.
pub fn set_default_device(default_device_id: usize) {
    DEVICE_CONTEXTS.with(|registry| registry.default_device.set(default_device_id))
}
pub fn with_default_device_policy<F, R>(f: F) -> Result<R, DeviceError>
where
F: FnOnce(&Arc<dyn SchedulingPolicy>) -> R,
{
let default_device = get_default_device();
with_global_device_context(default_device, |device_context| f(&device_context.policy))
}
pub fn load_module_from_file(
filename: &str,
device_id: usize,
) -> Result<Arc<CudaModule>, DeviceError> {
with_cuda_context(device_id, |cuda_ctx| {
let module = cuda_ctx.load_module_from_file(filename)?;
Ok(module)
})?
}
pub fn load_module_from_ptx(
ptx_src: &str,
device_id: usize,
) -> Result<Arc<CudaModule>, DeviceError> {
with_cuda_context(device_id, |cuda_ctx| {
let module = cuda_ctx.load_module_from_ptx_src(ptx_src)?;
Ok(module)
})?
}
/// Caches a loaded module/function pair for `device_id` under the hash string
/// of `func_key`.
///
/// # Errors
///
/// * Propagates any error from `with_global_device_context_mut`.
/// * Returns the `device_assert` error `"Unexpected cache key collision."` if
///   an entry already existed under the same key. Note that the new value has
///   already replaced the previous entry when this error is reported.
pub fn insert_cuda_function(
    device_id: usize,
    func_key: &impl FunctionKey,
    value: (Arc<CudaModule>, Arc<CudaFunction>),
) -> Result<(), DeviceError> {
    with_global_device_context_mut(device_id, |device_context| {
        // The key string is moved straight into the map; it is not needed afterwards.
        let previous = device_context
            .functions
            .insert(func_key.get_hash_string(), value);
        device_assert(
            device_id,
            previous.is_none(),
            "Unexpected cache key collision.",
        )
    })?
}
/// Reports whether a function is cached for `device_id` under the hash string
/// of `func_key`.
///
/// Any error while obtaining the device context is swallowed and reported as
/// `false`. Like the other accessors, this lazily initializes the device.
pub fn contains_cuda_function(device_id: usize, func_key: &impl FunctionKey) -> bool {
    let lookup = with_global_device_context(device_id, |device_context| {
        device_context
            .functions
            .contains_key(&func_key.get_hash_string())
    });
    matches!(lookup, Ok(true))
}
/// Returns a new `Arc` handle to the function cached for `device_id` under the
/// hash string of `func_key`.
///
/// # Errors
///
/// * Propagates any error from `with_global_device_context`.
/// * Returns the `device_error` value `"Failed to get cuda function."` if no
///   entry exists for the key.
pub fn get_cuda_function(
    device_id: usize,
    func_key: &impl FunctionKey,
) -> Result<Arc<CudaFunction>, DeviceError> {
    with_global_device_context(device_id, |device_context| {
        let key = func_key.get_hash_string();
        // Build the error lazily so the common cache-hit path does no extra work.
        let function = device_context
            .functions
            .get(&key)
            .map(|(_module, function)| Arc::clone(function))
            .ok_or_else(|| device_error(device_id, "Failed to get cuda function."))?;
        Ok(function)
    })?
}
/// Caches a validator for `device_id` under the hash string of `func_key`.
///
/// # Errors
///
/// * Propagates any error from `with_global_device_context_mut`.
/// * Returns the `device_assert` error `"Unexpected cache key collision."` if
///   an entry already existed under the same key. Note that the new value has
///   already replaced the previous entry when this error is reported.
pub fn insert_function_validator(
    device_id: usize,
    func_key: &impl FunctionKey,
    value: Arc<Validator>,
) -> Result<(), DeviceError> {
    with_global_device_context_mut(device_id, |device_context| {
        // The key string is moved straight into the map; it is not needed afterwards.
        let previous = device_context
            .validators
            .insert(func_key.get_hash_string(), value);
        device_assert(
            device_id,
            previous.is_none(),
            "Unexpected cache key collision.",
        )
    })?
}
/// Returns a new `Arc` handle to the validator cached for `device_id` under
/// the hash string of `func_key`.
///
/// # Errors
///
/// * Propagates any error from `with_global_device_context`.
/// * Returns the `device_error` value `"Failed to get function validator."`
///   if no entry exists for the key.
pub fn get_function_validator(
    device_id: usize,
    func_key: &impl FunctionKey,
) -> Result<Arc<Validator>, DeviceError> {
    with_global_device_context(device_id, |device_context| {
        let key = func_key.get_hash_string();
        // Build the error lazily so the common cache-hit path does no extra work.
        let validator = device_context
            .validators
            .get(&key)
            .map(Arc::clone)
            .ok_or_else(|| device_error(device_id, "Failed to get function validator."))?;
        Ok(validator)
    })?
}