use std::sync::Arc;
use cudarc::driver::CudaStream;
use tokio::sync::oneshot;
use crate::completion::CompletionStrategy;
use crate::dtype::{CudaDtype, DType};
use crate::error::GpuError;
use crate::gpu_ref::GpuRef;
use super::alloc_msg::HostBuf;
use super::state::DeviceState;
/// Type-erased device-allocation request.
///
/// Lets requests for different element types be handled uniformly (e.g. as
/// `Box<dyn AllocDispatch>`). The request can be inspected via [`dtype`](Self::dtype)
/// and [`len`](Self::len) before it is consumed by [`run`](Self::run).
pub trait AllocDispatch: Send + 'static {
    /// Element data type of the requested allocation.
    fn dtype(&self) -> DType;

    /// Requested size of the allocation.
    ///
    /// For [`AllocReq`] this is the element count handed to `alloc_zeros::<T>`.
    fn len(&self) -> usize;

    /// Returns `true` when [`len`](Self::len) is zero.
    ///
    /// Provided as a default method so that the public `len` has a matching
    /// `is_empty` (Clippy `len_without_is_empty`). Existing implementors are unaffected.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Consumes the boxed request and carries it out.
    ///
    /// Nothing is returned. [`AllocReq`] reports its outcome through its own
    /// `reply` sender.
    ///
    /// * `stream` — stream to allocate on, or `None` when no stream is available.
    /// * `state` — device state handed to the resulting [`GpuRef`].
    /// * `mock_mode` — whether the caller is running in mock mode. [`AllocReq`]
    ///   refuses to allocate when this is set.
    fn run(
        self: Box<Self>,
        stream: Option<&Arc<CudaStream>>,
        state: &Arc<DeviceState>,
        mock_mode: bool,
    );
}
/// Request to allocate a device buffer of `len` elements of `T`.
///
/// The outcome (a [`GpuRef<T>`] or a [`GpuError`]) is sent back over `reply`.
pub struct AllocReq<T: CudaDtype> {
/// Number of `T` elements to allocate (passed to `alloc_zeros::<T>` in `run`).
pub len: usize,
/// One-shot channel on which the allocation result is delivered.
pub reply: oneshot::Sender<Result<GpuRef<T>, GpuError>>,
}
impl<T: CudaDtype> AllocDispatch for AllocReq<T> {
fn dtype(&self) -> DType {
T::KIND
}
fn len(&self) -> usize {
self.len
}
fn run(
self: Box<Self>,
stream: Option<&Arc<CudaStream>>,
state: &Arc<DeviceState>,
mock_mode: bool,
) {
let AllocReq { len, reply } = *self;
if mock_mode {
let _ = reply.send(Err(GpuError::Unrecoverable(
"alloc not supported in mock mode".into(),
)));
return;
}
let Some(stream) = stream else {
let _ = reply.send(Err(GpuError::GpuRefStale("context not ready")));
return;
};
match stream.alloc_zeros::<T>(len) {
Ok(slice) => {
let _ = reply.send(Ok(GpuRef::<T>::new(Arc::new(slice), state)));
}
Err(e) => {
let _ = reply.send(Err(GpuError::OutOfMemory(format!("alloc {len}: {e}"))));
}
}
}
}
/// Type-erased device-to-host copy request.
///
/// Lets copy requests for different element types be handled uniformly. The
/// element type is visible via [`dtype`](Self::dtype) before the request is
/// consumed by [`run`](Self::run).
pub trait CopyToHostDispatch: Send + 'static {
    /// Element data type of the buffers involved in the copy.
    fn dtype(&self) -> DType;

    /// Consumes the boxed request and starts the copy on `stream`.
    ///
    /// `completion` is forwarded to the copy routine. How it is used is up to that
    /// routine. Presumably it signals when the copy has finished (TODO confirm).
    fn run(
        self: Box<Self>,
        stream: Arc<CudaStream>,
        completion: Arc<dyn CompletionStrategy>,
    );
}
/// Request to copy a device buffer into a host buffer.
///
/// The result is delivered on `reply` as a `HostBuf<T>` or a [`GpuError`].
/// Presumably the returned buffer is `dst` after it has been filled. Verify in
/// `context_actor::run_copy_to_host`.
pub struct CopyToHostReq<T: CudaDtype> {
/// Device-side source of the copy.
pub src: GpuRef<T>,
/// Host-side destination buffer.
pub dst: HostBuf<T>,
/// One-shot channel on which the copy result is delivered.
pub reply: oneshot::Sender<Result<HostBuf<T>, GpuError>>,
}
impl<T: CudaDtype> CopyToHostDispatch for CopyToHostReq<T> {
    /// The element type `T`, as reported by its `CudaDtype::KIND`.
    fn dtype(&self) -> DType {
        <T as CudaDtype>::KIND
    }

    /// Unboxes the request and hands all of its parts to
    /// `context_actor::run_copy_to_host`, which is then responsible for replying.
    fn run(self: Box<Self>, stream: Arc<CudaStream>, completion: Arc<dyn CompletionStrategy>) {
        let Self {
            src: device_src,
            dst: host_dst,
            reply: responder,
        } = *self;
        super::context_actor::run_copy_to_host(device_src, host_dst, stream, completion, responder);
    }
}
/// Type-erased host-to-device copy request.
///
/// Lets copy requests for different element types be handled uniformly. The
/// element type is visible via [`dtype`](Self::dtype) before the request is
/// consumed by [`run`](Self::run).
pub trait CopyFromHostDispatch: Send + 'static {
    /// Element data type of the buffers involved in the copy.
    fn dtype(&self) -> DType;

    /// Consumes the boxed request and starts the copy on `stream`.
    ///
    /// `completion` is forwarded to the copy routine. How it is used is up to that
    /// routine. Presumably it signals when the copy has finished (TODO confirm).
    fn run(
        self: Box<Self>,
        stream: Arc<CudaStream>,
        completion: Arc<dyn CompletionStrategy>,
    );
}
/// Request to copy a host buffer into a device buffer.
///
/// The result is delivered on `reply` as a `HostBuf<T>` or a [`GpuError`].
/// Presumably the returned buffer is `src`, handed back once the upload no longer
/// needs it. Verify in `context_actor::run_copy_from_host`.
pub struct CopyFromHostReq<T: CudaDtype> {
/// Host-side source buffer.
pub src: HostBuf<T>,
/// Device-side destination of the copy.
pub dst: GpuRef<T>,
/// One-shot channel on which the copy result is delivered.
pub reply: oneshot::Sender<Result<HostBuf<T>, GpuError>>,
}
impl<T: CudaDtype> CopyFromHostDispatch for CopyFromHostReq<T> {
    /// The element type `T`, as reported by its `CudaDtype::KIND`.
    fn dtype(&self) -> DType {
        <T as CudaDtype>::KIND
    }

    /// Unboxes the request and hands all of its parts to
    /// `context_actor::run_copy_from_host`, which is then responsible for replying.
    fn run(self: Box<Self>, stream: Arc<CudaStream>, completion: Arc<dyn CompletionStrategy>) {
        let Self {
            src: host_src,
            dst: device_dst,
            reply: responder,
        } = *self;
        super::context_actor::run_copy_from_host(host_src, device_dst, stream, completion, responder);
    }
}