use cubecl_common::backtrace::BackTrace;
use cubecl_core::server::ExecutionError;
use cubecl_hip_sys::HIP_SUCCESS;
/// Synchronization point backed by a single HIP event.
///
/// [`Fence::new`] creates the event and records it on a stream. The fence is then consumed
/// by [`Fence::wait_async`] or [`Fence::wait_sync`], which are the only places in this file
/// where the event is destroyed.
///
/// NOTE(review): no `Drop` impl is visible in this file, so a `Fence` that is dropped without
/// being waited on presumably leaves its event undestroyed — confirm.
pub struct Fence {
/// Raw HIP event handle; created and recorded in [`Fence::new`].
event: cubecl_hip_sys::hipEvent_t,
}
// SAFETY: NOTE(review): `Fence` only holds a raw `hipEvent_t` pointer, which is why `Send` is
// not derived automatically. This impl asserts the handle may be moved to and used from another
// thread; presumably HIP event handles are not bound to the creating thread — confirm against
// the HIP runtime documentation.
unsafe impl Send for Fence {}
impl Fence {
    /// Creates a fence by creating a new HIP event and recording it on `stream`.
    ///
    /// The event is created with the `hipEventDefault` flag and recorded immediately, so the
    /// fence marks the point `stream` had reached when this function was called.
    ///
    /// `stream` is passed straight to the HIP runtime; it must be a handle the runtime accepts.
    ///
    /// # Panics
    ///
    /// Panics if the event cannot be created or cannot be recorded on `stream`.
    pub fn new(stream: cubecl_hip_sys::hipStream_t) -> Self {
        let mut event: cubecl_hip_sys::hipEvent_t = std::ptr::null_mut();

        // SAFETY: `event` is a live local used as the out-pointer of the create call, and it
        // is only recorded after creation has been checked to succeed. Validity of `stream`
        // is the caller's responsibility.
        unsafe {
            let status = cubecl_hip_sys::hipEventCreateWithFlags(
                &mut event,
                cubecl_hip_sys::hipEventDefault,
            );
            assert_eq!(status, HIP_SUCCESS, "Should create the stream event");

            let status = cubecl_hip_sys::hipEventRecord(event, stream);
            assert_eq!(status, HIP_SUCCESS, "Should record the stream event");
        }

        // `event` already has the field's type; no pointer cast is needed.
        Self { event }
    }

    /// Makes `stream` wait for this fence without blocking the calling thread.
    ///
    /// Enqueues a wait on the fence's event into `stream` and then destroys the event.
    /// The fence is consumed.
    ///
    /// # Panics
    ///
    /// Panics if the wait cannot be enqueued or the event cannot be destroyed. The event is
    /// destroyed before either status is checked, so a failed wait does not leak it.
    // NOTE(review): the reason for `allow(unused)` is not visible in this file; presumably the
    // method has no caller in some build configurations — confirm.
    #[allow(unused)]
    pub fn wait_async(self, stream: cubecl_hip_sys::hipStream_t) {
        // SAFETY: `self.event` was created in `Fence::new` and, because `self` is taken by
        // value, is destroyed exactly once here. Validity of `stream` is the caller's
        // responsibility.
        let (wait_status, destroy_status) = unsafe {
            let wait_status = cubecl_hip_sys::hipStreamWaitEvent(stream, self.event, 0);
            // Release the event regardless of whether the wait was enqueued successfully.
            let destroy_status = cubecl_hip_sys::hipEventDestroy(self.event);
            (wait_status, destroy_status)
        };

        assert_eq!(
            wait_status, HIP_SUCCESS,
            "Should successfully wait for stream event"
        );
        assert_eq!(
            destroy_status, HIP_SUCCESS,
            "Should destroy the stream event"
        );
    }

    /// Blocks the calling thread until the fence's event has completed, then destroys it.
    ///
    /// The fence is consumed.
    ///
    /// # Errors
    ///
    /// Returns [`ExecutionError::Generic`] if synchronizing on the event fails or if the event
    /// cannot be destroyed. The event is destroyed even when synchronization fails; in that
    /// case the synchronization error is the one reported.
    pub fn wait_sync(self) -> Result<(), ExecutionError> {
        // SAFETY: `self.event` was created in `Fence::new` and, because `self` is taken by
        // value, is destroyed exactly once here.
        let (sync_status, destroy_status) = unsafe {
            let sync_status = cubecl_hip_sys::hipEventSynchronize(self.event);
            // Always release the event: returning early on a failed synchronize would leak
            // it, since the fence is consumed and nothing else can destroy the handle.
            let destroy_status = cubecl_hip_sys::hipEventDestroy(self.event);
            (sync_status, destroy_status)
        };

        if sync_status != HIP_SUCCESS {
            return Err(ExecutionError::Generic {
                reason: format!("Should successfully wait for stream event: {sync_status}"),
                backtrace: BackTrace::capture(),
            });
        }

        if destroy_status != HIP_SUCCESS {
            return Err(ExecutionError::Generic {
                reason: format!("Should destroy the stream event: {destroy_status}"),
                backtrace: BackTrace::capture(),
            });
        }

        Ok(())
    }
}