use cubecl_common::backtrace::BackTrace;
use cubecl_core::server::ServerError;
use cudarc::driver::sys::{CUevent_flags, CUevent_st, CUevent_wait_flags, CUstream_st};
/// Owns a raw CUDA event handle that [`Fence::new`] creates and records on a stream.
///
/// The handle is destroyed when the fence is consumed by [`Fence::wait_sync`] or
/// [`Fence::wait_async`]. No `Drop` impl is visible here. NOTE(review): a fence that is
/// dropped without being waited on therefore appears to leak its event; confirm this
/// is intended.
#[derive(Debug)]
pub struct Fence {
    // Raw driver handle; created in `Fence::new`, released when the fence is consumed.
    event: *mut CUevent_st,
}
// SAFETY: raw pointers are `!Send` by default, so moving a `Fence` across threads must
// be opted into explicitly. `Fence` is not `Clone`, so only one owner can ever destroy
// the handle. NOTE(review): this assumes the CUDA driver permits using an event handle
// from a thread other than the one that created it, with the owning context current
// there. Confirm against the CUDA driver threading rules.
unsafe impl Send for Fence {}
// NOTE(review): presumably not every method is reached in every build configuration,
// so unused-code lints are silenced for the whole impl; narrow this if possible.
#[allow(unused)]
impl Fence {
    /// Creates a CUDA event with default flags and records it on `stream`.
    ///
    /// # Panics
    ///
    /// Panics if the driver fails to create or record the event.
    pub fn new(stream: *mut CUstream_st) -> Self {
        // SAFETY: creating an event takes no caller-provided pointers.
        let event = unsafe {
            cudarc::driver::result::event::create(CUevent_flags::CU_EVENT_DEFAULT).unwrap()
        };
        // SAFETY: `event` was just created above. NOTE(review): assumes `stream` is a
        // valid stream handle in the current CUDA context; this is not checked here.
        unsafe { cudarc::driver::result::event::record(event, stream) }.unwrap();
        Self { event }
    }

    /// Blocks the calling thread until the recorded event completes, then destroys it.
    ///
    /// The event is destroyed even when synchronization fails, so the handle is never
    /// leaked.
    ///
    /// # Errors
    ///
    /// Returns [`ServerError::Generic`] if synchronizing or destroying the event fails.
    /// When both fail, the synchronization error is reported.
    pub fn wait_sync(self) -> Result<(), ServerError> {
        // SAFETY: `self.event` was created in `new` and is destroyed only by this method
        // or `wait_async`. Both consume `self`, and `Fence` is not `Clone`, so the
        // handle is still live here.
        let synced = unsafe { cudarc::driver::result::event::synchronize(self.event) };
        // Destroy unconditionally so a failed synchronization does not leak the event.
        // SAFETY: same invariant as above; `self` is consumed, so the handle is never
        // used again.
        let destroyed = unsafe { cudarc::driver::result::event::destroy(self.event) };
        synced.and(destroyed).map_err(Self::driver_error)
    }

    /// Enqueues a wait for the recorded event on `stream`, then destroys the event.
    ///
    /// # Panics
    ///
    /// Panics if enqueuing the wait or destroying the event fails. The event is
    /// destroyed before any panic, so a failed wait does not leak it.
    pub fn wait_async(self, stream: *mut CUstream_st) {
        // SAFETY: `self.event` is live (same invariant as in `wait_sync`).
        // NOTE(review): assumes `stream` is a valid stream handle in the current context.
        let waited = unsafe {
            cudarc::driver::result::stream::wait_event(
                stream,
                self.event,
                CUevent_wait_flags::CU_EVENT_WAIT_DEFAULT,
            )
        };
        // SAFETY: `self` is consumed, so the handle is never used after this call.
        let destroyed = unsafe { cudarc::driver::result::event::destroy(self.event) };
        waited.unwrap();
        destroyed.unwrap();
    }

    /// Wraps a CUDA driver error in [`ServerError::Generic`], capturing a backtrace.
    fn driver_error(err: impl std::fmt::Debug) -> ServerError {
        ServerError::Generic {
            reason: format!("{err:?}"),
            backtrace: BackTrace::capture(),
        }
    }
}