use std::any::{Any, TypeId};
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use async_trait::async_trait;
use atomr_core::actor::{Actor, ActorRef, Context, Props};
use bitflags::bitflags;
use parking_lot::RwLock;
use tokio::sync::oneshot;
use tracing::{debug, warn};
use crate::dtype::CudaDtype;
use crate::error::GpuError;
use crate::gpu_ref::GpuRef;
use crate::kernel::BlasMsg;
use super::alloc_dispatch::{
AllocDispatch, AllocReq, CopyFromHostDispatch, CopyFromHostReq, CopyToHostDispatch,
CopyToHostReq,
};
use super::alloc_msg::{DeviceLoad, HostBuf};
use super::context_actor::{ContextActor, ContextMsg};
use super::state::DeviceState;
bitflags! {
    /// Bitset selecting which optional GPU library integrations are enabled for
    /// a device (carried in `DeviceConfig::enabled_libraries`).
    ///
    /// Bit positions are pinned by the `enabled_libraries_bit_values_are_stable`
    /// test: never renumber an existing flag. New flags take the next free bit
    /// and must also be OR-ed into `ALL`.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct EnabledLibraries: u32 {
        const BLAS = 1 << 0;
        const CUDNN = 1 << 1;
        const CUFFT = 1 << 2;
        const CURAND = 1 << 3;
        const CUSOLVER = 1 << 4;
        const CUBLASLT = 1 << 5;
        const NVRTC = 1 << 6;
        const CUTENSOR = 1 << 7;
        const CUSPARSE = 1 << 8;
        const NCCL = 1 << 9;
        const CUTLASS = 1 << 10;
        const TENSORRT = 1 << 11;
        const FLASHATTN = 1 << 12;
        const CUB_THRUST = 1 << 13;
        const TELEMETRY = 1 << 14;
        /// Union of every individual flag declared above.
        const ALL = Self::BLAS.bits()
            | Self::CUDNN.bits()
            | Self::CUFFT.bits()
            | Self::CURAND.bits()
            | Self::CUSOLVER.bits()
            | Self::CUBLASLT.bits()
            | Self::NVRTC.bits()
            | Self::CUTENSOR.bits()
            | Self::CUSPARSE.bits()
            | Self::NCCL.bits()
            | Self::CUTLASS.bits()
            | Self::TENSORRT.bits()
            | Self::FLASHATTN.bits()
            | Self::CUB_THRUST.bits()
            | Self::TELEMETRY.bits();
    }
}
impl Default for EnabledLibraries {
fn default() -> Self {
Self::BLAS
}
}
/// Static configuration for one [`DeviceActor`].
///
/// The actor keeps its own copy and hands a clone to its `ContextActor` child.
#[derive(Clone, Debug)]
pub struct DeviceConfig {
    /// Identifier of the GPU this actor manages; used to build its `DeviceState`.
    pub device_id: u32,
    /// `true` for the `mock` constructor. NOTE(review): the exact mock behaviour
    /// lives outside this file (presumably in `ContextActor`) — confirm there.
    pub mock_mode: bool,
    /// Upper bound on requests parked while the context is not ready; work
    /// arriving beyond this bound is rejected rather than queued.
    pub pending_queue_capacity: usize,
    /// Set of optional library integrations requested for this device.
    pub enabled_libraries: EnabledLibraries,
}
impl DeviceConfig {
    /// Pending-queue bound shared by [`DeviceConfig::new`] and [`DeviceConfig::mock`].
    const DEFAULT_PENDING_QUEUE_CAPACITY: usize = 1024;

    /// Configuration for a real (non-mock) device with default settings:
    /// a pending queue of 1024 entries and only the default libraries enabled.
    pub fn new(device_id: u32) -> Self {
        Self {
            device_id,
            mock_mode: false,
            pending_queue_capacity: Self::DEFAULT_PENDING_QUEUE_CAPACITY,
            enabled_libraries: EnabledLibraries::default(),
        }
    }

    /// Same defaults as [`DeviceConfig::new`], but with `mock_mode` set.
    ///
    /// Derived from `new` so the two constructors can never disagree on
    /// their defaults.
    pub fn mock(device_id: u32) -> Self {
        Self {
            mock_mode: true,
            ..Self::new(device_id)
        }
    }

    /// Builder-style override of the enabled library set.
    pub fn with_libraries(mut self, libs: EnabledLibraries) -> Self {
        self.enabled_libraries = libs;
        self
    }
}
/// Messages accepted by [`DeviceActor`].
///
/// `Alloc` / `CopyToHost` / `CopyFromHost` carry type-erased requests built by
/// [`DeviceMsg::alloc`], [`DeviceMsg::copy_to_host`] and
/// [`DeviceMsg::copy_from_host`]. The per-dtype `Allocate*`, `CopyToHost*` and
/// `CopyFromHost*` variants are deprecated shims that `DeviceActor::handle`
/// rewrites into those generic forms before routing them.
pub enum DeviceMsg {
    /// Type-erased device allocation request.
    Alloc(Box<dyn AllocDispatch>),
    /// Type-erased device-to-host copy request.
    CopyToHost(Box<dyn CopyToHostDispatch>),
    /// Type-erased host-to-device copy request.
    CopyFromHost(Box<dyn CopyFromHostDispatch>),
    // ---- Deprecated per-dtype allocation shims ----
    /// Legacy `f32` allocation; handled exactly like `AllocateF32`.
    #[deprecated(note = "use DeviceMsg::alloc::<f32>(len, reply)")]
    Allocate {
        len: usize,
        reply: oneshot::Sender<Result<GpuRef<f32>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::alloc::<f32>(len, reply)")]
    AllocateF32 {
        len: usize,
        reply: oneshot::Sender<Result<GpuRef<f32>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::alloc::<f64>(len, reply)")]
    AllocateF64 {
        len: usize,
        reply: oneshot::Sender<Result<GpuRef<f64>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::alloc::<i8>(len, reply)")]
    AllocateI8 {
        len: usize,
        reply: oneshot::Sender<Result<GpuRef<i8>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::alloc::<i32>(len, reply)")]
    AllocateI32 {
        len: usize,
        reply: oneshot::Sender<Result<GpuRef<i32>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::alloc::<i64>(len, reply)")]
    AllocateI64 {
        len: usize,
        reply: oneshot::Sender<Result<GpuRef<i64>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::alloc::<u8>(len, reply)")]
    AllocateU8 {
        len: usize,
        reply: oneshot::Sender<Result<GpuRef<u8>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::alloc::<u32>(len, reply)")]
    AllocateU32 {
        len: usize,
        reply: oneshot::Sender<Result<GpuRef<u32>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::alloc::<u64>(len, reply)")]
    AllocateU64 {
        len: usize,
        reply: oneshot::Sender<Result<GpuRef<u64>, GpuError>>,
    },
    #[cfg(feature = "f16")]
    #[deprecated(note = "use DeviceMsg::alloc::<half::f16>(len, reply)")]
    AllocateF16 {
        len: usize,
        reply: oneshot::Sender<Result<GpuRef<half::f16>, GpuError>>,
    },
    #[cfg(feature = "f16")]
    #[deprecated(note = "use DeviceMsg::alloc::<half::bf16>(len, reply)")]
    AllocateBf16 {
        len: usize,
        reply: oneshot::Sender<Result<GpuRef<half::bf16>, GpuError>>,
    },
    // ---- Deprecated per-dtype copy shims ----
    // CopyToHost*: device buffer `src` -> host buffer `dst`; the reply returns a
    // `HostBuf`. CopyFromHost*: host buffer `src` -> device buffer `dst`; the
    // reply also returns a `HostBuf` (presumably the source buffer handed back
    // for reuse — confirm in alloc_dispatch).
    #[deprecated(note = "use DeviceMsg::copy_to_host::<f32>(src, dst, reply)")]
    CopyToHostF32 {
        src: GpuRef<f32>,
        dst: HostBuf<f32>,
        reply: oneshot::Sender<Result<HostBuf<f32>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::copy_from_host::<f32>(src, dst, reply)")]
    CopyFromHostF32 {
        src: HostBuf<f32>,
        dst: GpuRef<f32>,
        reply: oneshot::Sender<Result<HostBuf<f32>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::copy_to_host::<f64>(src, dst, reply)")]
    CopyToHostF64 {
        src: GpuRef<f64>,
        dst: HostBuf<f64>,
        reply: oneshot::Sender<Result<HostBuf<f64>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::copy_from_host::<f64>(src, dst, reply)")]
    CopyFromHostF64 {
        src: HostBuf<f64>,
        dst: GpuRef<f64>,
        reply: oneshot::Sender<Result<HostBuf<f64>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::copy_to_host::<i32>(src, dst, reply)")]
    CopyToHostI32 {
        src: GpuRef<i32>,
        dst: HostBuf<i32>,
        reply: oneshot::Sender<Result<HostBuf<i32>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::copy_from_host::<i32>(src, dst, reply)")]
    CopyFromHostI32 {
        src: HostBuf<i32>,
        dst: GpuRef<i32>,
        reply: oneshot::Sender<Result<HostBuf<i32>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::copy_to_host::<u32>(src, dst, reply)")]
    CopyToHostU32 {
        src: GpuRef<u32>,
        dst: HostBuf<u32>,
        reply: oneshot::Sender<Result<HostBuf<u32>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::copy_from_host::<u32>(src, dst, reply)")]
    CopyFromHostU32 {
        src: HostBuf<u32>,
        dst: GpuRef<u32>,
        reply: oneshot::Sender<Result<HostBuf<u32>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::copy_to_host::<u8>(src, dst, reply)")]
    CopyToHostU8 {
        src: GpuRef<u8>,
        dst: HostBuf<u8>,
        reply: oneshot::Sender<Result<HostBuf<u8>, GpuError>>,
    },
    #[deprecated(note = "use DeviceMsg::copy_from_host::<u8>(src, dst, reply)")]
    CopyFromHostU8 {
        src: HostBuf<u8>,
        dst: GpuRef<u8>,
        reply: oneshot::Sender<Result<HostBuf<u8>, GpuError>>,
    },
    /// Single-precision GEMM; forwarded to the BLAS child as `BlasMsg::Sgemm`,
    /// or parked until the kernel children are available.
    Sgemm(Box<SgemmRequest>),
    /// Answered immediately with the device state's current CUDA context.
    SnapshotContext {
        reply: oneshot::Sender<Option<Arc<cudarc::driver::CudaContext>>>,
    },
    /// Answered immediately with a clone of the kernel children
    /// (`None` until `ContextReady` has been received, or after `ContextLost`).
    SnapshotChildren {
        reply: oneshot::Sender<Option<KernelChildren>>,
    },
    /// Answered with a `watch` receiver over the device state's generation counter.
    WatchGeneration {
        reply: oneshot::Sender<tokio::sync::watch::Receiver<u64>>,
    },
    /// Answered with a `DeviceLoad` snapshot of this actor.
    Stats { reply: oneshot::Sender<DeviceLoad> },
    /// Installs the kernel children and flushes all parked work.
    /// NOTE(review): presumably sent by the `ContextActor` child — confirm.
    ContextReady { children: KernelChildren },
    /// Clears the kernel children so that new work is parked again.
    ContextLost,
}
/// Handles to the per-device kernel actors. `DeviceActor` receives them in
/// `DeviceMsg::ContextReady` and hands out clones via `SnapshotChildren`.
///
/// The `extras` registry sits behind an `Arc`, so every clone of a
/// `KernelChildren` shares it: a value registered through one clone is
/// visible through all the others.
#[derive(Clone)]
pub struct KernelChildren {
    /// BLAS kernel actor; always present and the target of `BlasMsg::Sgemm`.
    pub blas: ActorRef<BlasMsg>,
    /// Optional cuDNN kernel actor (feature `cudnn`).
    #[cfg(feature = "cudnn")]
    pub cudnn: Option<ActorRef<crate::kernel::CudnnMsg>>,
    /// Optional FFT kernel actor (feature `cufft`).
    #[cfg(feature = "cufft")]
    pub fft: Option<ActorRef<crate::kernel::FftMsg>>,
    /// Optional RNG kernel actor (feature `curand`).
    #[cfg(feature = "curand")]
    pub rng: Option<ActorRef<crate::kernel::RngMsg>>,
    /// Type-keyed store for extra handles that have no dedicated field;
    /// it holds at most one value per concrete type.
    extras: Arc<RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
}
impl KernelChildren {
    /// Builds a child set that contains only the BLAS actor.
    ///
    /// The feature-gated children (`cudnn`, `fft`, `rng`) start as `None`, and
    /// the extras registry starts empty.
    pub fn new(blas: ActorRef<BlasMsg>) -> Self {
        Self {
            blas,
            #[cfg(feature = "cudnn")]
            cudnn: None,
            #[cfg(feature = "cufft")]
            fft: None,
            #[cfg(feature = "curand")]
            rng: None,
            extras: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Stores `value` under its concrete type `T`.
    ///
    /// Any value previously registered for the same `T` is replaced. The
    /// registry is shared by all clones of this `KernelChildren`.
    pub fn register_extra<T: Any + Send + Sync>(&self, value: T) {
        self.extras
            .write()
            .insert(TypeId::of::<T>(), Arc::new(value));
    }

    /// Returns a clone of the value registered for type `T`, or `None` if no
    /// value of that type has been registered.
    pub fn extra<T: Any + Send + Sync + Clone>(&self) -> Option<T> {
        // Downcast by reference under the read lock and clone only the inner
        // value. This skips the Arc clone (an atomic inc/dec pair) that a
        // by-value `Arc::downcast` would need.
        self.extras
            .read()
            .get(&TypeId::of::<T>())
            .and_then(|v| v.downcast_ref::<T>())
            .cloned()
    }

    /// Number of distinct types currently held in the extras registry.
    pub fn extras_len(&self) -> usize {
        self.extras.read().len()
    }
}
impl DeviceMsg {
pub fn alloc<T: CudaDtype>(
len: usize,
reply: oneshot::Sender<Result<GpuRef<T>, GpuError>>,
) -> Self {
DeviceMsg::Alloc(Box::new(AllocReq::<T> { len, reply }))
}
pub fn copy_to_host<T: CudaDtype>(
src: GpuRef<T>,
dst: HostBuf<T>,
reply: oneshot::Sender<Result<HostBuf<T>, GpuError>>,
) -> Self {
DeviceMsg::CopyToHost(Box::new(CopyToHostReq::<T> { src, dst, reply }))
}
pub fn copy_from_host<T: CudaDtype>(
src: HostBuf<T>,
dst: GpuRef<T>,
reply: oneshot::Sender<Result<HostBuf<T>, GpuError>>,
) -> Self {
DeviceMsg::CopyFromHost(Box::new(CopyFromHostReq::<T> { src, dst, reply }))
}
}
/// Single-precision GEMM request, forwarded to the BLAS actor as
/// `BlasMsg::Sgemm`.
///
/// NOTE(review): presumably follows the usual SGEMM contract
/// `C = alpha * A * B + beta * C`, with `A` m×k, `B` k×n and `C` m×n. Layout and
/// transpose conventions live in the BLAS actor — confirm there.
pub struct SgemmRequest {
    /// Left operand.
    pub a: GpuRef<f32>,
    /// Right operand.
    pub b: GpuRef<f32>,
    /// Accumulator / output operand.
    pub c: GpuRef<f32>,
    /// GEMM dimension `m`.
    pub m: i32,
    /// GEMM dimension `n`.
    pub n: i32,
    /// GEMM dimension `k`.
    pub k: i32,
    /// Scale applied to the product term.
    pub alpha: f32,
    /// Scale applied to the existing contents of `c`.
    pub beta: f32,
    /// Completion channel. It receives an error if the request is rejected
    /// (full pending queue or shutdown).
    pub reply: oneshot::Sender<Result<(), GpuError>>,
}
/// Work parked in `DeviceActor::pending` while the context is not ready.
pub enum WorkRequest {
    /// Deferred action, invoked with the context-actor ref and the BLAS-actor
    /// ref once the queue is drained.
    Boxed(Box<dyn FnOnce(&ActorRef<ContextMsg>, &ActorRef<BlasMsg>) + Send>),
    /// Deferred GEMM, forwarded to the BLAS actor on drain.
    Sgemm(Box<SgemmRequest>),
    /// Deferred context snapshot, answered from the device state on drain.
    /// NOTE(review): `DeviceActor::handle` answers `SnapshotContext` directly and
    /// never enqueues this variant — confirm whether anything still builds it.
    SnapshotContext {
        reply: oneshot::Sender<Option<Arc<cudarc::driver::CudaContext>>>,
    },
}
/// Actor that owns one GPU device: it routes allocation, copy and GEMM work to
/// its context and kernel children, and parks that work until they are ready.
pub struct DeviceActor {
    /// Configuration this actor was created with.
    config: DeviceConfig,
    /// Per-device state shared with the `ContextActor` child.
    state: Arc<DeviceState>,
    /// Ref to the `ContextActor` child; set in `pre_start`.
    context_ref: Option<ActorRef<ContextMsg>>,
    /// Kernel actors; `Some` between `ContextReady` and `ContextLost`.
    children: Option<KernelChildren>,
    /// FIFO of work waiting for the context and children to become available.
    pending: VecDeque<WorkRequest>,
}
impl DeviceActor {
    /// Creates an actor in the "not ready" state.
    ///
    /// A fresh `DeviceState` is built for `config.device_id`. The context ref
    /// and kernel children start empty and the pending queue starts empty.
    pub fn new(config: DeviceConfig) -> Self {
        let state = Arc::new(DeviceState::new(config.device_id));
        Self {
            config,
            state,
            context_ref: None,
            children: None,
            pending: VecDeque::new(),
        }
    }

    /// Returns the [`Props`] used to spawn a `DeviceActor` for `config`.
    pub fn props(config: DeviceConfig) -> Props<Self> {
        // The factory already clones the captured config on every call, so the
        // config can be moved straight in with no extra up-front clone.
        Props::create(move || DeviceActor::new(config.clone()))
    }

    /// Shared per-device state.
    pub fn state(&self) -> &Arc<DeviceState> {
        &self.state
    }

    /// Parks `work` until the context becomes ready.
    ///
    /// If `pending_queue_capacity` items are already queued, the work is
    /// rejected instead and a warning is logged:
    /// - `Sgemm` receives an `Unrecoverable` error.
    /// - `SnapshotContext` receives `None`.
    /// - `Boxed` is dropped. For the closures built in `handle`, this drops the
    ///   captured request and its `oneshot` reply sender, so the caller sees a
    ///   closed channel rather than an error value.
    fn enqueue_pending(&mut self, work: WorkRequest) {
        if self.pending.len() < self.config.pending_queue_capacity {
            self.pending.push_back(work);
            return;
        }
        warn!(
            device_id = self.config.device_id,
            cap = self.config.pending_queue_capacity,
            "dropping work — pending queue full"
        );
        match work {
            WorkRequest::Sgemm(req) => {
                let _ = req.reply.send(Err(GpuError::Unrecoverable(
                    "device pending queue full".into(),
                )));
            }
            WorkRequest::SnapshotContext { reply } => {
                let _ = reply.send(None);
            }
            WorkRequest::Boxed(_) => {}
        }
    }

    /// Flushes parked work to the context and BLAS actors in FIFO order.
    ///
    /// This is a no-op unless both the context ref and the kernel children are set.
    fn drain_pending(&mut self) {
        // Borrow the refs in place. They live in fields disjoint from `pending`,
        // so nothing needs to be cloned while the queue is popped.
        let (Some(ctx), Some(children)) = (&self.context_ref, &self.children) else {
            return;
        };
        while let Some(work) = self.pending.pop_front() {
            match work {
                WorkRequest::Boxed(f) => f(ctx, &children.blas),
                WorkRequest::Sgemm(req) => children.blas.tell(BlasMsg::Sgemm(req)),
                WorkRequest::SnapshotContext { reply } => {
                    let _ = reply.send(self.state.current_context());
                }
            }
        }
    }
}
#[async_trait]
impl Actor for DeviceActor {
type Msg = DeviceMsg;
async fn pre_start(&mut self, ctx: &mut Context<Self>) {
debug!(device_id = self.config.device_id, "DeviceActor pre_start");
let parent_ref = ctx.self_ref().clone();
let props = ContextActor::props(self.state.clone(), self.config.clone(), parent_ref);
match ctx.spawn::<ContextActor>(props, "ctx") {
Ok(r) => {
self.context_ref = Some(r);
}
Err(e) => {
panic!("Unrecoverable: failed to spawn ContextActor: {e}");
}
}
}
async fn handle(&mut self, _ctx: &mut Context<Self>, msg: DeviceMsg) {
#[allow(deprecated)]
let msg = match msg {
DeviceMsg::Allocate { len, reply } | DeviceMsg::AllocateF32 { len, reply } => {
DeviceMsg::alloc::<f32>(len, reply)
}
DeviceMsg::AllocateF64 { len, reply } => DeviceMsg::alloc::<f64>(len, reply),
DeviceMsg::AllocateI8 { len, reply } => DeviceMsg::alloc::<i8>(len, reply),
DeviceMsg::AllocateI32 { len, reply } => DeviceMsg::alloc::<i32>(len, reply),
DeviceMsg::AllocateI64 { len, reply } => DeviceMsg::alloc::<i64>(len, reply),
DeviceMsg::AllocateU8 { len, reply } => DeviceMsg::alloc::<u8>(len, reply),
DeviceMsg::AllocateU32 { len, reply } => DeviceMsg::alloc::<u32>(len, reply),
DeviceMsg::AllocateU64 { len, reply } => DeviceMsg::alloc::<u64>(len, reply),
#[cfg(feature = "f16")]
DeviceMsg::AllocateF16 { len, reply } => DeviceMsg::alloc::<half::f16>(len, reply),
#[cfg(feature = "f16")]
DeviceMsg::AllocateBf16 { len, reply } => DeviceMsg::alloc::<half::bf16>(len, reply),
DeviceMsg::CopyToHostF32 { src, dst, reply } => {
DeviceMsg::copy_to_host::<f32>(src, dst, reply)
}
DeviceMsg::CopyToHostF64 { src, dst, reply } => {
DeviceMsg::copy_to_host::<f64>(src, dst, reply)
}
DeviceMsg::CopyToHostI32 { src, dst, reply } => {
DeviceMsg::copy_to_host::<i32>(src, dst, reply)
}
DeviceMsg::CopyToHostU32 { src, dst, reply } => {
DeviceMsg::copy_to_host::<u32>(src, dst, reply)
}
DeviceMsg::CopyToHostU8 { src, dst, reply } => {
DeviceMsg::copy_to_host::<u8>(src, dst, reply)
}
DeviceMsg::CopyFromHostF32 { src, dst, reply } => {
DeviceMsg::copy_from_host::<f32>(src, dst, reply)
}
DeviceMsg::CopyFromHostF64 { src, dst, reply } => {
DeviceMsg::copy_from_host::<f64>(src, dst, reply)
}
DeviceMsg::CopyFromHostI32 { src, dst, reply } => {
DeviceMsg::copy_from_host::<i32>(src, dst, reply)
}
DeviceMsg::CopyFromHostU32 { src, dst, reply } => {
DeviceMsg::copy_from_host::<u32>(src, dst, reply)
}
DeviceMsg::CopyFromHostU8 { src, dst, reply } => {
DeviceMsg::copy_from_host::<u8>(src, dst, reply)
}
other => other,
};
let ready = self.context_ref.is_some() && self.children.is_some();
match msg {
DeviceMsg::Alloc(boxed) => {
if ready {
self.context_ref
.as_ref()
.unwrap()
.tell(ContextMsg::Alloc(boxed));
} else {
self.enqueue_pending(WorkRequest::Boxed(Box::new(move |c, _b| {
c.tell(ContextMsg::Alloc(boxed))
})));
}
}
DeviceMsg::CopyToHost(boxed) => {
if ready {
self.context_ref
.as_ref()
.unwrap()
.tell(ContextMsg::CopyToHost(boxed));
} else {
self.enqueue_pending(WorkRequest::Boxed(Box::new(move |c, _b| {
c.tell(ContextMsg::CopyToHost(boxed))
})));
}
}
DeviceMsg::CopyFromHost(boxed) => {
if ready {
self.context_ref
.as_ref()
.unwrap()
.tell(ContextMsg::CopyFromHost(boxed));
} else {
self.enqueue_pending(WorkRequest::Boxed(Box::new(move |c, _b| {
c.tell(ContextMsg::CopyFromHost(boxed))
})));
}
}
#[allow(deprecated)]
DeviceMsg::Allocate { .. }
| DeviceMsg::AllocateF32 { .. }
| DeviceMsg::AllocateF64 { .. }
| DeviceMsg::AllocateI8 { .. }
| DeviceMsg::AllocateI32 { .. }
| DeviceMsg::AllocateI64 { .. }
| DeviceMsg::AllocateU8 { .. }
| DeviceMsg::AllocateU32 { .. }
| DeviceMsg::AllocateU64 { .. }
| DeviceMsg::CopyToHostF32 { .. }
| DeviceMsg::CopyFromHostF32 { .. }
| DeviceMsg::CopyToHostF64 { .. }
| DeviceMsg::CopyFromHostF64 { .. }
| DeviceMsg::CopyToHostI32 { .. }
| DeviceMsg::CopyFromHostI32 { .. }
| DeviceMsg::CopyToHostU32 { .. }
| DeviceMsg::CopyFromHostU32 { .. }
| DeviceMsg::CopyToHostU8 { .. }
| DeviceMsg::CopyFromHostU8 { .. } => unreachable!(
"Phase 0.4 translation collapses all legacy alloc/copy variants \
into DeviceMsg::Alloc / CopyToHost / CopyFromHost"
),
#[cfg(feature = "f16")]
#[allow(deprecated)]
DeviceMsg::AllocateF16 { .. } | DeviceMsg::AllocateBf16 { .. } => {
unreachable!(
"Phase 0.4 translation collapses all legacy alloc/copy variants \
into DeviceMsg::Alloc"
)
}
DeviceMsg::Sgemm(req) => match &self.children {
Some(c) => c.blas.tell(BlasMsg::Sgemm(req)),
None => self.enqueue_pending(WorkRequest::Sgemm(req)),
},
DeviceMsg::SnapshotContext { reply } => {
let _ = reply.send(self.state.current_context());
}
DeviceMsg::SnapshotChildren { reply } => {
let _ = reply.send(self.children.clone());
}
DeviceMsg::WatchGeneration { reply } => {
let _ = reply.send(self.state.generation_watch());
}
DeviceMsg::Stats { reply } => {
let _ = reply.send(self.snapshot_load());
}
DeviceMsg::ContextReady { children } => {
debug!(device_id = self.config.device_id, "context ready");
self.children = Some(children);
self.drain_pending();
}
DeviceMsg::ContextLost => {
debug!(device_id = self.config.device_id, "context lost");
self.children = None;
}
}
}
async fn post_stop(&mut self, _ctx: &mut Context<Self>) {
debug!(device_id = self.config.device_id, "DeviceActor post_stop");
self.state.begin_shutdown();
while let Some(work) = self.pending.pop_front() {
match work {
WorkRequest::Boxed(_) => { }
WorkRequest::Sgemm(req) => {
let _ = req
.reply
.send(Err(GpuError::GpuRefStale("device shutting down")));
}
WorkRequest::SnapshotContext { reply } => {
let _ = reply.send(None);
}
}
}
}
}
impl DeviceActor {
    /// Builds the [`DeviceLoad`] returned for `DeviceMsg::Stats`.
    ///
    /// Only `queue_depth` (the number of parked requests) comes from live data.
    /// The memory, stream and compute-capability fields are currently reported
    /// as zero.
    fn snapshot_load(&self) -> DeviceLoad {
        // The queue bound is a `usize`, so its length can exceed `u32::MAX` in
        // principle. Saturate rather than let an `as` cast wrap modulo 2^32 and
        // report a misleadingly small depth.
        let queue_depth = u32::try_from(self.pending.len()).unwrap_or(u32::MAX);
        DeviceLoad {
            free_bytes: 0,
            total_bytes: 0,
            active_streams: 0,
            queue_depth,
            compute_cap: (0, 0),
        }
    }
}
// Tests for the device actor module. `allow(deprecated)` is required because
// some tests deliberately send the legacy per-dtype `DeviceMsg` variants.
#[cfg(test)]
#[allow(deprecated)]
mod tests {
    use super::*;
    use crate::dtype::DType;
    use atomr_config::Config;
    use atomr_core::actor::ActorSystem;
    use std::time::Duration;

    // Pins the bit position of every flag; renumbering any flag must fail here.
    #[test]
    fn enabled_libraries_bit_values_are_stable() {
        assert_eq!(EnabledLibraries::BLAS.bits(), 1 << 0);
        assert_eq!(EnabledLibraries::CUDNN.bits(), 1 << 1);
        assert_eq!(EnabledLibraries::CUFFT.bits(), 1 << 2);
        assert_eq!(EnabledLibraries::CURAND.bits(), 1 << 3);
        assert_eq!(EnabledLibraries::CUSOLVER.bits(), 1 << 4);
        assert_eq!(EnabledLibraries::CUBLASLT.bits(), 1 << 5);
        assert_eq!(EnabledLibraries::NVRTC.bits(), 1 << 6);
        assert_eq!(EnabledLibraries::CUTENSOR.bits(), 1 << 7);
        assert_eq!(EnabledLibraries::CUSPARSE.bits(), 1 << 8);
        assert_eq!(EnabledLibraries::NCCL.bits(), 1 << 9);
        assert_eq!(EnabledLibraries::CUTLASS.bits(), 1 << 10);
        assert_eq!(EnabledLibraries::TENSORRT.bits(), 1 << 11);
        assert_eq!(EnabledLibraries::FLASHATTN.bits(), 1 << 12);
        assert_eq!(EnabledLibraries::CUB_THRUST.bits(), 1 << 13);
        assert_eq!(EnabledLibraries::TELEMETRY.bits(), 1 << 14);
    }

    // A combination of known flags survives a bits() -> from_bits() round trip.
    #[test]
    fn enabled_libraries_round_trip_via_bits() {
        let original = EnabledLibraries::BLAS
            | EnabledLibraries::CUTENSOR
            | EnabledLibraries::FLASHATTN
            | EnabledLibraries::TELEMETRY;
        let bits = original.bits();
        let restored =
            EnabledLibraries::from_bits(bits).expect("known bits round-trip through from_bits");
        assert_eq!(original, restored);
        assert!(restored.contains(EnabledLibraries::FLASHATTN));
        assert!(!restored.contains(EnabledLibraries::CUDNN));
    }

    // `ALL` must include every individual flag.
    #[test]
    fn enabled_libraries_all_contains_every_phase_0_8_bit() {
        let all = EnabledLibraries::ALL;
        for bit in [
            EnabledLibraries::BLAS,
            EnabledLibraries::CUDNN,
            EnabledLibraries::CUFFT,
            EnabledLibraries::CURAND,
            EnabledLibraries::CUSOLVER,
            EnabledLibraries::CUBLASLT,
            EnabledLibraries::NVRTC,
            EnabledLibraries::CUTENSOR,
            EnabledLibraries::CUSPARSE,
            EnabledLibraries::NCCL,
            EnabledLibraries::CUTLASS,
            EnabledLibraries::TENSORRT,
            EnabledLibraries::FLASHATTN,
            EnabledLibraries::CUB_THRUST,
            EnabledLibraries::TELEMETRY,
        ] {
            assert!(all.contains(bit), "ALL missing {bit:?}");
        }
    }

    // Exercises the same type-keyed map logic as `KernelChildren::register_extra`
    // / `extra`, but on a standalone map, so no actor system (and no BLAS
    // ActorRef) is needed.
    #[test]
    fn kernel_children_extras_register_and_retrieve_by_type() {
        let extras: Arc<RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>> =
            Arc::new(RwLock::new(HashMap::new()));
        fn register<T: Any + Send + Sync>(
            map: &Arc<RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
            v: T,
        ) {
            map.write().insert(TypeId::of::<T>(), Arc::new(v));
        }
        fn lookup<T: Any + Send + Sync + Clone>(
            map: &Arc<RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
        ) -> Option<T> {
            map.read()
                .get(&TypeId::of::<T>())
                .and_then(|v| v.clone().downcast::<T>().ok())
                .map(|arc| (*arc).clone())
        }
        #[derive(Clone, PartialEq, Eq, Debug)]
        struct CutlassRef(u32);
        #[derive(Clone, PartialEq, Eq, Debug)]
        struct TensorRtRef(&'static str);
        register(&extras, CutlassRef(7));
        register(&extras, TensorRtRef("trt"));
        assert_eq!(lookup::<CutlassRef>(&extras), Some(CutlassRef(7)));
        assert_eq!(lookup::<TensorRtRef>(&extras), Some(TensorRtRef("trt")));
        #[derive(Clone)]
        struct Unknown;
        assert!(lookup::<Unknown>(&extras).is_none());
        assert_eq!(extras.read().len(), 2);
        // Re-registering the same type replaces the value instead of adding one.
        register(&extras, CutlassRef(99));
        assert_eq!(lookup::<CutlassRef>(&extras), Some(CutlassRef(99)));
        assert_eq!(extras.read().len(), 2);
    }

    // End-to-end: waits for a mock device to publish its kernel children, then
    // checks that the extras registry is shared between clones.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn kernel_children_extras_via_snapshot() {
        let sys = ActorSystem::create("kc_extras", Config::empty())
            .await
            .unwrap();
        let dev = sys
            .actor_of(DeviceActor::props(DeviceConfig::mock(0)), "dev0")
            .unwrap();
        let mut snap: Option<KernelChildren> = None;
        // SnapshotChildren replies `None` until ContextReady arrives, so poll
        // (up to 50 tries, 20 ms apart).
        for _ in 0..50 {
            let (tx, rx) = oneshot::channel();
            dev.tell(DeviceMsg::SnapshotChildren { reply: tx });
            if let Ok(Some(c)) = rx.await {
                snap = Some(c);
                break;
            }
            tokio::time::sleep(Duration::from_millis(20)).await;
        }
        let children = snap.expect("KernelChildren snapshot should arrive in mock mode");
        assert_eq!(children.extras_len(), 0);
        #[derive(Clone, Debug, PartialEq, Eq)]
        struct FakeCutlassRef(u64);
        children.register_extra(FakeCutlassRef(42));
        assert_eq!(children.extras_len(), 1);
        assert_eq!(children.extra::<FakeCutlassRef>(), Some(FakeCutlassRef(42)));
        let cloned = children.clone();
        assert_eq!(cloned.extras_len(), 1);
        assert_eq!(cloned.extra::<FakeCutlassRef>(), Some(FakeCutlassRef(42)));
        sys.terminate().await;
    }

    // A legacy `Allocate` sent right after spawn must still get a reply. In mock
    // mode that reply is expected to be `Err(Unrecoverable)`.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn pending_work_drains_on_context_ready() {
        let sys = ActorSystem::create("test", Config::empty()).await.unwrap();
        let dev = sys
            .actor_of(DeviceActor::props(DeviceConfig::mock(0)), "dev0")
            .unwrap();
        let (tx, rx) = oneshot::channel();
        dev.tell(DeviceMsg::Allocate { len: 16, reply: tx });
        let res = tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .expect("alloc reply should arrive within timeout")
            .expect("oneshot dropped");
        assert!(matches!(res, Err(GpuError::Unrecoverable(_))));
        sys.terminate().await;
    }

    // Same flow as above, but through the typed `DeviceMsg::alloc` constructor.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn alloc_dispatch_via_typed_constructor() {
        let sys = ActorSystem::create("test", Config::empty()).await.unwrap();
        let dev = sys
            .actor_of(DeviceActor::props(DeviceConfig::mock(0)), "dev1")
            .unwrap();
        let (tx, rx) = oneshot::channel::<Result<GpuRef<f32>, GpuError>>();
        dev.tell(DeviceMsg::alloc::<f32>(64, tx));
        let res = tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .expect("alloc reply within timeout")
            .expect("oneshot dropped");
        assert!(matches!(res, Err(GpuError::Unrecoverable(_))));
        sys.terminate().await;
    }

    // The boxed `AllocDispatch` reports the dtype and length of its `AllocReq<T>`.
    #[test]
    fn alloc_dispatch_dtype_kind_correct() {
        let (tx, _rx) = oneshot::channel::<Result<GpuRef<f32>, GpuError>>();
        let boxed: Box<dyn AllocDispatch> = Box::new(AllocReq::<f32> { len: 4, reply: tx });
        assert_eq!(boxed.dtype(), DType::F32);
        assert_eq!(boxed.len(), 4);
        let (tx, _rx) = oneshot::channel::<Result<GpuRef<i32>, GpuError>>();
        let boxed: Box<dyn AllocDispatch> = Box::new(AllocReq::<i32> { len: 7, reply: tx });
        assert_eq!(boxed.dtype(), DType::I32);
        let (tx, _rx) = oneshot::channel::<Result<GpuRef<u8>, GpuError>>();
        let boxed: Box<dyn AllocDispatch> = Box::new(AllocReq::<u8> { len: 1, reply: tx });
        assert_eq!(boxed.dtype(), DType::U8);
    }

    // The deprecated `AllocateF32` variant is still translated and answered.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn deprecated_allocate_f32_still_works() {
        let sys = ActorSystem::create("test", Config::empty()).await.unwrap();
        let dev = sys
            .actor_of(DeviceActor::props(DeviceConfig::mock(0)), "dev2")
            .unwrap();
        let (tx, rx) = oneshot::channel::<Result<GpuRef<f32>, GpuError>>();
        dev.tell(DeviceMsg::AllocateF32 { len: 8, reply: tx });
        let res = tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .expect("alloc reply within timeout")
            .expect("oneshot dropped");
        assert!(matches!(res, Err(GpuError::Unrecoverable(_))));
        sys.terminate().await;
    }

    // A stub `CopyToHostDispatch` reports `T::KIND` as its dtype, including when
    // wrapped in `DeviceMsg::CopyToHost`. `run` is never invoked here.
    #[test]
    fn copy_to_host_typed() {
        struct Stub<T: CudaDtype>(std::marker::PhantomData<T>);
        impl<T: CudaDtype> CopyToHostDispatch for Stub<T> {
            fn dtype(&self) -> DType {
                T::KIND
            }
            fn run(
                self: Box<Self>,
                _stream: Arc<cudarc::driver::CudaStream>,
                _completion: Arc<dyn crate::completion::CompletionStrategy>,
            ) {
            }
        }
        let boxed: Box<dyn CopyToHostDispatch> = Box::new(Stub::<f32>(std::marker::PhantomData));
        assert_eq!(boxed.dtype(), DType::F32);
        let boxed: Box<dyn CopyToHostDispatch> = Box::new(Stub::<i32>(std::marker::PhantomData));
        assert_eq!(boxed.dtype(), DType::I32);
        let msg = DeviceMsg::CopyToHost(Box::new(Stub::<u32>(std::marker::PhantomData)));
        match msg {
            DeviceMsg::CopyToHost(b) => assert_eq!(b.dtype(), DType::U32),
            _ => panic!("expected CopyToHost variant"),
        }
    }

    // Mirror of `copy_to_host_typed`, for the host-to-device dispatch trait.
    #[test]
    fn copy_from_host_typed() {
        struct Stub<T: CudaDtype>(std::marker::PhantomData<T>);
        impl<T: CudaDtype> CopyFromHostDispatch for Stub<T> {
            fn dtype(&self) -> DType {
                T::KIND
            }
            fn run(
                self: Box<Self>,
                _stream: Arc<cudarc::driver::CudaStream>,
                _completion: Arc<dyn crate::completion::CompletionStrategy>,
            ) {
            }
        }
        let boxed: Box<dyn CopyFromHostDispatch> = Box::new(Stub::<u8>(std::marker::PhantomData));
        assert_eq!(boxed.dtype(), DType::U8);
        let msg = DeviceMsg::CopyFromHost(Box::new(Stub::<f64>(std::marker::PhantomData)));
        match msg {
            DeviceMsg::CopyFromHost(b) => assert_eq!(b.dtype(), DType::F64),
            _ => panic!("expected CopyFromHost variant"),
        }
    }
}