use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use atomr_core::actor::{Actor, Context, Props};
use cudarc::driver::{CudaContext, CudaEvent, CudaStream};
use parking_lot::Mutex;
use tokio::sync::oneshot;
use crate::error::GpuError;
#[cfg(feature = "cuda-ipc")]
use crate::sys::cuda_driver;
#[cfg(feature = "cuda-ipc")]
use cudarc::driver::sys as driver_sys;
/// `lib` tag attached to every `GpuError::LibraryError` constructed in this module.
const LIB: &str = "event";
/// Shared handle to a CUDA event.
///
/// Cloning is cheap: all clones point at one `CudaEvent` through an `Arc`, so
/// the underlying event is dropped only when the last clone goes away.
#[derive(Clone)]
pub struct Event {
    inner: Arc<EventInner>,
}
/// Payload shared by all clones of an [`Event`]: the owned cudarc event.
struct EventInner {
    event: CudaEvent,
}
impl Event {
pub fn from_cuda(event: CudaEvent) -> Self {
Self {
inner: Arc::new(EventInner { event }),
}
}
pub fn cuda_event(&self) -> &CudaEvent {
&self.inner.event
}
#[cfg(feature = "cuda-ipc")]
pub fn cu_event(&self) -> driver_sys::CUevent {
self.inner.event.cu_event()
}
}
#[cfg(feature = "cuda-ipc")]
/// Copyable wrapper around the driver's IPC event handle, as returned by the
/// `GetIpcHandle` request.
///
/// Use `as_bytes` / `from_bytes` to move it across a process boundary
/// (presumably its intended use, given the IPC naming — confirm).
#[derive(Clone, Copy)]
pub struct IpcEventHandle {
    pub(crate) raw: driver_sys::CUipcEventHandle,
}
#[cfg(feature = "cuda-ipc")]
impl IpcEventHandle {
    /// Returns the 64 opaque bytes of the driver IPC handle.
    ///
    /// Each `c_char` is reinterpreted bit-for-bit as a `u8`. `c_char` is signed
    /// on some targets and unsigned on others, and the cast is a no-op
    /// reinterpretation either way, so no `unsafe` transmute is needed.
    pub fn as_bytes(&self) -> [u8; 64] {
        self.raw.reserved.map(|c| c as u8)
    }

    /// Rebuilds a handle from bytes previously produced by [`Self::as_bytes`].
    ///
    /// The bytes are copied verbatim (bit-for-bit into `c_char`). They are not
    /// validated here.
    pub fn from_bytes(bytes: [u8; 64]) -> Self {
        let raw = driver_sys::CUipcEventHandle_st {
            reserved: bytes.map(|b| b as std::ffi::c_char),
        };
        Self { raw }
    }
}
#[cfg(feature = "cuda-ipc")]
impl std::fmt::Debug for IpcEventHandle {
    /// Prints a 64-bit digest of the handle bytes instead of all 64 raw bytes.
    /// This keeps log lines short while usually still distinguishing handles.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let digest = fxhash(&self.as_bytes());
        let mut out = f.debug_struct("IpcEventHandle");
        out.field("bytes_hash", &digest);
        out.finish()
    }
}
#[cfg(feature = "cuda-ipc")]
/// 64-bit digest of `bytes`, used by the `Debug` impl of `IpcEventHandle`.
///
/// Despite the name, this is FNV-1a (offset basis `0xcbf29ce484222325`, prime
/// `0x100000001b3`), not rustc's FxHash. It is not cryptographic.
fn fxhash(bytes: &[u8]) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    bytes.iter().fold(FNV_OFFSET_BASIS, |acc, &byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
/// Requests understood by [`EventActor`].
///
/// Each variant carries a oneshot `reply` sender that receives the outcome. In
/// mock mode every variant is answered with `GpuError::Unrecoverable`.
pub enum EventMsg {
    /// Create a new event on the actor's context (`CudaContext::new_event(None)`).
    Create {
        reply: oneshot::Sender<Result<Event, GpuError>>,
    },
    /// Record `event` on `stream` (`CudaEvent::record`).
    Record {
        event: Event,
        stream: Arc<CudaStream>,
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
    /// Make `stream` wait on `event` (`CudaStream::wait`).
    Wait {
        event: Event,
        stream: Arc<CudaStream>,
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
    /// Report whether `event` has completed (`CudaEvent::is_complete`).
    Query {
        event: Event,
        reply: oneshot::Sender<Result<bool, GpuError>>,
    },
    /// Time between `start` and `end` (`CudaEvent::elapsed_ms`), converted to a
    /// `Duration`.
    ElapsedTime {
        start: Event,
        end: Event,
        reply: oneshot::Sender<Result<Duration, GpuError>>,
    },
    /// Call `CudaEvent::synchronize` on `event`.
    ///
    /// The actor runs this synchronously inside its message handler, so the
    /// handler is occupied until the call returns.
    Synchronize {
        event: Event,
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
    /// Export `event` as an [`IpcEventHandle`].
    #[cfg(feature = "cuda-ipc")]
    GetIpcHandle {
        event: Event,
        reply: oneshot::Sender<Result<IpcEventHandle, GpuError>>,
    },
    /// Open an IPC event handle (presumably exported by another process).
    ///
    /// NOTE(review): the real handler currently always replies with an error.
    /// That error is either the driver's open failure or `GpuError::Unrecoverable`,
    /// because the opened event cannot be wrapped as a `CudaEvent`.
    #[cfg(feature = "cuda-ipc")]
    OpenIpcHandle {
        handle: IpcEventHandle,
        reply: oneshot::Sender<Result<Event, GpuError>>,
    },
}
/// Newtype around the shared CUDA context, so the actor state can be marked
/// `Send`/`Sync` explicitly.
// NOTE(review): if `CudaContext` is already `Send + Sync` in the cudarc version
// in use, this wrapper and both impls below are redundant. Confirm and drop them.
struct SendCtx(Arc<CudaContext>);
// SAFETY: assumes the CUDA driver context may be used from any thread (the
// driver API is documented as thread-safe). Confirm for the cudarc version in use.
unsafe impl Send for SendCtx {}
// SAFETY: the actor only clones the inner `Arc` out under its mutex and then
// calls context methods through that clone. This assumes those methods are
// safe to call concurrently. Confirm.
unsafe impl Sync for SendCtx {}
/// Backend chosen when the actor is built: a real CUDA context, or a GPU-free
/// mock that answers every request with an error.
// NOTE(review): both variants are constructed (`EventActor::props` and
// `EventActor::mock_props`), and `ctx` is read in the actor's `handle`, so it
// is unclear what this `allow` silences. Confirm, and remove it if nothing.
#[allow(dead_code)]
enum EventInnerActor {
    /// Real backend. The context sits behind a mutex and is cloned out per message.
    Real { ctx: Mutex<SendCtx> },
    /// Mock backend. Every request is answered with `GpuError::Unrecoverable`.
    Mock,
}
/// Actor that performs CUDA event operations requested through [`EventMsg`].
///
/// Build it with `EventActor::props` (real CUDA context) or
/// `EventActor::mock_props` (no GPU; every request fails).
pub struct EventActor {
    inner: EventInnerActor,
}
impl EventActor {
    /// Props for an actor that executes event operations on `ctx`.
    ///
    /// Every invocation of the factory closure clones the `Arc`, so each actor
    /// instance holds its own reference to the shared context.
    pub fn props(ctx: Arc<CudaContext>) -> Props<Self> {
        Props::create(move || {
            let shared = SendCtx(Arc::clone(&ctx));
            EventActor {
                inner: EventInnerActor::Real {
                    ctx: Mutex::new(shared),
                },
            }
        })
    }

    /// Props for a GPU-free actor that answers every request with
    /// `GpuError::Unrecoverable`. Useful in tests.
    pub fn mock_props() -> Props<Self> {
        Props::create(|| {
            let inner = EventInnerActor::Mock;
            EventActor { inner }
        })
    }
}
#[async_trait]
impl Actor for EventActor {
type Msg = EventMsg;
async fn handle(&mut self, _ctx: &mut Context<Self>, msg: EventMsg) {
match &self.inner {
EventInnerActor::Mock => mock_reply(msg),
EventInnerActor::Real { ctx } => {
let ctx = ctx.lock().0.clone();
handle_real(&ctx, msg);
}
}
}
}
fn mock_reply(msg: EventMsg) {
match msg {
EventMsg::Create { reply } => {
let _ = reply.send(Err(GpuError::Unrecoverable(
"EventActor in mock mode".into(),
)));
}
EventMsg::Record { reply, .. } => {
let _ = reply.send(Err(GpuError::Unrecoverable(
"EventActor in mock mode".into(),
)));
}
EventMsg::Wait { reply, .. } => {
let _ = reply.send(Err(GpuError::Unrecoverable(
"EventActor in mock mode".into(),
)));
}
EventMsg::Query { reply, .. } => {
let _ = reply.send(Err(GpuError::Unrecoverable(
"EventActor in mock mode".into(),
)));
}
EventMsg::ElapsedTime { reply, .. } => {
let _ = reply.send(Err(GpuError::Unrecoverable(
"EventActor in mock mode".into(),
)));
}
EventMsg::Synchronize { reply, .. } => {
let _ = reply.send(Err(GpuError::Unrecoverable(
"EventActor in mock mode".into(),
)));
}
#[cfg(feature = "cuda-ipc")]
EventMsg::GetIpcHandle { reply, .. } => {
let _ = reply.send(Err(GpuError::Unrecoverable(
"EventActor in mock mode".into(),
)));
}
#[cfg(feature = "cuda-ipc")]
EventMsg::OpenIpcHandle { reply, .. } => {
let _ = reply.send(Err(GpuError::Unrecoverable(
"EventActor in mock mode".into(),
)));
}
}
}
/// Executes `msg` against the real CUDA driver and sends the outcome on the
/// message's `reply` channel.
///
/// Driver failures are reported as `GpuError::LibraryError { lib: LIB, .. }`.
/// The non-IPC operations run under [`guard_driver`], so a panic inside cudarc
/// becomes `GpuError::Unrecoverable` instead of unwinding into the actor. If
/// the requester has already dropped its receiver, the failed send is ignored.
///
/// NOTE(review): this runs synchronously inside the actor's async `handle`, so
/// a `Synchronize` request occupies the executor thread until the event
/// completes. Consider `spawn_blocking` if that becomes a problem.
fn handle_real(ctx: &Arc<CudaContext>, msg: EventMsg) {
    match msg {
        EventMsg::Create { reply } => {
            let r = guard_driver("EventActor::Create: CUDA driver not loadable", || {
                ctx.new_event(None)
                    .map(Event::from_cuda)
                    .map_err(|e| lib_err("new_event", e))
            });
            let _ = reply.send(r);
        }
        EventMsg::Record {
            event,
            stream,
            reply,
        } => {
            let r = guard_driver("EventActor::Record: CUDA driver not loadable", || {
                event
                    .cuda_event()
                    .record(&stream)
                    .map_err(|e| lib_err("record", e))
            });
            let _ = reply.send(r);
        }
        EventMsg::Wait {
            event,
            stream,
            reply,
        } => {
            let r = guard_driver("EventActor::Wait: CUDA driver not loadable", || {
                stream
                    .wait(event.cuda_event())
                    .map_err(|e| lib_err("wait", e))
            });
            let _ = reply.send(r);
        }
        EventMsg::Query { event, reply } => {
            let r = guard_driver("EventActor::Query: CUDA driver not loadable", || {
                Ok(event.cuda_event().is_complete())
            });
            let _ = reply.send(r);
        }
        EventMsg::ElapsedTime { start, end, reply } => {
            let r = guard_driver("EventActor::ElapsedTime: CUDA driver not loadable", || {
                let ms = start
                    .cuda_event()
                    .elapsed_ms(end.cuda_event())
                    .map_err(|e| lib_err("elapsed", e))?;
                // `Duration::from_secs_f64` panics on negative, NaN or
                // overflowing input. The guard would then misreport that panic
                // as a driver-loading failure, so reject such readings here
                // with an accurate error.
                Duration::try_from_secs_f64(ms as f64 / 1000.0).map_err(|e| {
                    lib_err("elapsed", format!("{ms} ms is not a valid duration ({e})"))
                })
            });
            let _ = reply.send(r);
        }
        EventMsg::Synchronize { event, reply } => {
            let r = guard_driver("EventActor::Synchronize: CUDA driver not loadable", || {
                event
                    .cuda_event()
                    .synchronize()
                    .map_err(|e| lib_err("synchronize", e))
            });
            let _ = reply.send(r);
        }
        #[cfg(feature = "cuda-ipc")]
        EventMsg::GetIpcHandle { event, reply } => {
            let r = cuda_driver::ipc_get_event_handle(event.cu_event())
                .map(|raw| IpcEventHandle { raw });
            let _ = reply.send(r);
        }
        #[cfg(feature = "cuda-ipc")]
        EventMsg::OpenIpcHandle { handle, reply } => {
            let raw_event = match cuda_driver::ipc_open_event_handle(handle.raw) {
                Ok(e) => e,
                Err(e) => {
                    let _ = reply.send(Err(e));
                    return;
                }
            };
            // No `CudaEvent` can be built around the opened handle. Release it
            // rather than leaking a driver event on every request (events
            // obtained from cuIpcOpenEventHandle must be freed with
            // cuEventDestroy). Best effort: the request fails below regardless
            // of the destroy result.
            // SAFETY: `raw_event` was just returned by a successful IPC open,
            // is owned solely by this arm, and is never used after this call.
            let _ = unsafe { driver_sys::cuEventDestroy_v2(raw_event) };
            let _ = reply.send(Err(GpuError::Unrecoverable(
                "EventActor::OpenIpcHandle: cudarc 0.19 lacks CudaEvent::from_raw — \
                 use the IpcEventHandle bytes directly with cuStreamWaitEvent on the \
                 destination context"
                    .into(),
            )));
        }
    }
}

/// Runs a driver operation and converts any panic into
/// `GpuError::Unrecoverable(unloadable_msg)`.
///
/// Every panic is reported with `unloadable_msg`. The messages name a missing
/// driver library because that is the failure this guard exists to absorb
/// (presumably cudarc's lazily loaded driver panicking when the library cannot
/// be found; confirm against the cudarc version in use).
fn guard_driver<T>(
    unloadable_msg: &'static str,
    op: impl FnOnce() -> Result<T, GpuError>,
) -> Result<T, GpuError> {
    std::panic::catch_unwind(std::panic::AssertUnwindSafe(op))
        .unwrap_or_else(|_| Err(GpuError::Unrecoverable(unloadable_msg.into())))
}

/// Wraps a driver-level failure as this module's `GpuError::LibraryError`,
/// formatted as `"{what}: {err}"`.
fn lib_err(what: &str, err: impl std::fmt::Display) -> GpuError {
    GpuError::LibraryError {
        lib: LIB,
        msg: format!("{what}: {err}"),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_config::Config;
    use atomr_core::actor::ActorSystem;

    /// A mock-mode actor must answer `Create` (not hang or panic) with an
    /// `Unrecoverable` error.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn event_msg_round_trip() {
        let sys = ActorSystem::create("event-msg-test", Config::empty())
            .await
            .unwrap();
        let actor = sys.actor_of(EventActor::mock_props(), "evt").unwrap();
        let (tx, rx) = oneshot::channel();
        actor.tell(EventMsg::Create { reply: tx });
        let r = tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .unwrap()
            .unwrap();
        assert!(matches!(r, Err(GpuError::Unrecoverable(_))));
        sys.terminate().await;
    }

    /// Handle bytes survive a round trip, including values above 127 (the
    /// sign bit of a signed `c_char`), and the handle is `Copy + Send + Sync`.
    #[cfg(feature = "cuda-ipc")]
    #[test]
    fn ipc_event_handle_serializes() {
        fn assert_send_sync_copy<T: Send + Sync + Copy>() {}
        assert_send_sync_copy::<IpcEventHandle>();

        let bytes: [u8; 64] = std::array::from_fn(|i| (i * 4 + 3) as u8);
        let h = IpcEventHandle::from_bytes(bytes);
        assert_eq!(h.as_bytes(), bytes);

        // Reading `h` after copying it only compiles because the handle is `Copy`.
        let copy = h;
        assert_eq!(copy.as_bytes(), h.as_bytes());
    }

    /// The `Debug` digest is 64-bit FNV-1a; pin it to published test vectors.
    #[cfg(feature = "cuda-ipc")]
    #[test]
    fn debug_digest_is_fnv1a() {
        assert_eq!(fxhash(b""), 0xcbf2_9ce4_8422_2325);
        assert_eq!(fxhash(b"a"), 0xaf63_dc4c_8601_ec8c);
        assert_eq!(fxhash(b"foobar"), 0x8594_4171_f739_67e8);
    }
}