use std::{ptr, sync::Arc};
use singe_cuda_sys::{driver, runtime};
use crate::{
context::Context,
error::{Error, Result},
stream::{BorrowedStream, Stream, StreamBinding},
try_cuda,
};
bitflags::bitflags! {
/// Creation-time flags for an [`Event`].
///
/// The bits are handed unchanged to `cudaEventCreateWithFlags` by
/// [`Context::create_event_with_flags`]. Every constant takes its value from
/// the `driver::CUevent_flags` variant of the same name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct EventFlags: u32 {
/// Mirrors `CU_EVENT_DEFAULT`; this is what [`Context::create_event`] uses.
const DEFAULT = driver::CUevent_flags::CU_EVENT_DEFAULT as _;
/// Mirrors `CU_EVENT_BLOCKING_SYNC`.
/// NOTE(review): per the CUDA docs this makes host-side waits block instead
/// of spin — confirm against the toolkit version in use.
const BLOCKING_SYNC = driver::CUevent_flags::CU_EVENT_BLOCKING_SYNC as _;
/// Mirrors `CU_EVENT_DISABLE_TIMING`.
/// NOTE(review): presumably such events cannot be used with
/// [`Event::elapsed_time_since`] — confirm.
const DISABLE_TIMING = driver::CUevent_flags::CU_EVENT_DISABLE_TIMING as _;
/// Mirrors `CU_EVENT_INTERPROCESS`.
/// NOTE(review): the CUDA docs require combining this with
/// `DISABLE_TIMING`; nothing here enforces that — confirm.
const INTERPROCESS = driver::CUevent_flags::CU_EVENT_INTERPROCESS as _;
}
}
bitflags::bitflags! {
/// Flags that modify a single record operation on an [`Event`].
///
/// The bits are forwarded unchanged to `cudaEventRecordWithFlags` by the
/// `Event::record*` methods.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct EventRecordFlags: u32 {
/// Mirrors `runtime::cudaEventRecordDefault`.
const DEFAULT = runtime::cudaEventRecordDefault;
/// Mirrors `runtime::cudaEventRecordExternal`.
/// NOTE(review): per the CUDA docs this is meant for recording during
/// stream capture — confirm before relying on it.
const EXTERNAL = runtime::cudaEventRecordExternal;
}
}
/// An owned CUDA event together with the [`Context`] it was created in.
///
/// Instances are produced by [`Context::create_event`] and
/// [`Context::create_event_with_flags`]; the underlying handle is destroyed
/// with `cudaEventDestroy` when the value is dropped.
#[derive(Debug)]
pub struct Event {
// Raw runtime-API event handle; checked to be non-null at construction.
handle: runtime::cudaEvent_t,
// Context the event was created in; bound before CUDA calls on the event
// and kept alive for as long as the event exists.
ctx: Arc<Context>,
}
impl Event {
    /// Records this event on an owned [`Stream`].
    ///
    /// # Errors
    ///
    /// Returns `CUDA_ERROR_INVALID_CONTEXT` when `stream` does not belong to
    /// this event's context. Also propagates any failure from binding the
    /// context or from `cudaEventRecordWithFlags`.
    pub fn record(&self, stream: &Stream, flags: EventRecordFlags) -> Result<()> {
        if stream.context() != self.context() {
            return Err(runtime::cudaError_t::CUDA_ERROR_INVALID_CONTEXT.into());
        }
        // SAFETY: the raw handle is only forwarded to `record_raw` while
        // `stream` is still borrowed.
        // NOTE(review): confirm this satisfies `Stream::as_raw`'s contract.
        self.record_raw(unsafe { stream.as_raw() }, flags)
    }

    /// Records this event on a [`BorrowedStream`].
    ///
    /// # Errors
    ///
    /// Same conditions as [`Event::record`].
    pub fn record_borrowed(&self, stream: &BorrowedStream, flags: EventRecordFlags) -> Result<()> {
        if stream.context() != self.context() {
            return Err(runtime::cudaError_t::CUDA_ERROR_INVALID_CONTEXT.into());
        }
        self.record_raw(stream.as_raw(), flags)
    }

    /// Records this event on a [`StreamBinding`].
    ///
    /// # Errors
    ///
    /// Same conditions as [`Event::record`].
    pub fn record_on(&self, stream: &StreamBinding, flags: EventRecordFlags) -> Result<()> {
        if stream.context() != self.context() {
            return Err(runtime::cudaError_t::CUDA_ERROR_INVALID_CONTEXT.into());
        }
        self.record_raw(stream.as_raw(), flags)
    }

    /// Shared implementation of the `record*` methods: binds the event's
    /// context and calls `cudaEventRecordWithFlags` on the raw stream handle.
    ///
    /// Callers are expected to have already checked that `stream` belongs to
    /// this event's context.
    fn record_raw(&self, stream: runtime::cudaStream_t, flags: EventRecordFlags) -> Result<()> {
        self.ctx.bind()?;
        // SAFETY: `self.handle` was checked non-null at construction and is
        // only destroyed in `Drop`, so it is live for the duration of `&self`.
        // NOTE(review): validity of `stream` is the caller's responsibility.
        unsafe {
            try_cuda!(runtime::cudaEventRecordWithFlags(
                self.as_raw(),
                stream,
                flags.bits(),
            ))?;
        }
        Ok(())
    }

    /// Polls the event without blocking.
    ///
    /// Returns `Ok(true)` when `cudaEventQuery` reports `CUDA_SUCCESS` and
    /// `Ok(false)` when it reports `CUDA_ERROR_NOT_READY`.
    ///
    /// # Errors
    ///
    /// Propagates a failure to bind the event's context, and any status from
    /// `cudaEventQuery` other than the two listed above.
    pub fn query(&self) -> Result<bool> {
        // Bind the owning context first, as every other CUDA call on this
        // type does, so the query never runs against whatever context happens
        // to be current on the calling thread.
        self.ctx.bind()?;
        // SAFETY: `self.handle` is a live event handle for the duration of `&self`.
        let status = unsafe { runtime::cudaEventQuery(self.as_raw()) };
        match status {
            runtime::cudaError_t::CUDA_SUCCESS => Ok(true),
            runtime::cudaError_t::CUDA_ERROR_NOT_READY => Ok(false),
            _ => Err(status.into()),
        }
    }

    /// Blocks the calling thread in `cudaEventSynchronize` until it returns.
    ///
    /// # Errors
    ///
    /// Propagates a failure to bind the context or an error from
    /// `cudaEventSynchronize`.
    pub fn synchronize(&self) -> Result<()> {
        self.ctx.bind()?;
        // SAFETY: `self.handle` is a live event handle for the duration of `&self`.
        unsafe {
            try_cuda!(runtime::cudaEventSynchronize(self.as_raw()))?;
        }
        Ok(())
    }

    /// Returns the time elapsed from `start` to `self`, as reported by
    /// `cudaEventElapsedTime` (milliseconds).
    ///
    /// # Errors
    ///
    /// Returns `CUDA_ERROR_INVALID_CONTEXT` when the two events belong to
    /// different contexts. Also propagates a failure to bind the context or
    /// an error from `cudaEventElapsedTime`.
    pub fn elapsed_time_since(&self, start: &Event) -> Result<f32> {
        if self.context() != start.context() {
            return Err(runtime::cudaError_t::CUDA_ERROR_INVALID_CONTEXT.into());
        }
        self.ctx.bind()?;
        let mut milliseconds = 0.0f32;
        // SAFETY: both handles are live while their `Event`s are borrowed, and
        // the out-pointer refers to a local that outlives the call.
        unsafe {
            try_cuda!(runtime::cudaEventElapsedTime(
                &raw mut milliseconds,
                start.as_raw(),
                self.as_raw(),
            ))?;
        }
        Ok(milliseconds)
    }

    /// Returns the context this event was created in.
    pub fn context(&self) -> &Context {
        &self.ctx
    }

    /// Returns the underlying raw event handle without transferring ownership.
    ///
    /// # Safety
    ///
    /// The handle is destroyed when this `Event` is dropped, so the caller
    /// must not destroy it themselves or use it after the `Event` is gone.
    pub const unsafe fn as_raw(&self) -> runtime::cudaEvent_t {
        self.handle
    }
}
// SAFETY(review): `Event` holds a raw `cudaEvent_t`, so these impls must be
// asserted by hand. Nothing in this file proves them sound; they presumably
// rely on CUDA event handles being usable from any host thread and on
// `Context` itself being `Send + Sync` — TODO confirm.
unsafe impl Send for Event {}
unsafe impl Sync for Event {}
impl Drop for Event {
    /// Destroys the underlying CUDA event on a best-effort basis.
    ///
    /// `drop` cannot return an error, so failures are only reported on stderr
    /// in debug builds and are otherwise ignored. Destruction is still
    /// attempted when binding the context fails.
    fn drop(&mut self) {
        if let Err(err) = self.ctx.bind() {
            // `cfg!` (rather than `#[cfg]` on the statement) keeps `err` used
            // in release builds, avoiding an unused-variable warning there.
            if cfg!(debug_assertions) {
                eprintln!("failed to bind context before destroying event: {err}");
            }
        }
        // SAFETY: `self.handle` was checked non-null at construction and this
        // is the only place it is destroyed; `drop` runs at most once.
        let destroyed = unsafe { try_cuda!(runtime::cudaEventDestroy(self.handle)) };
        if let Err(err) = destroyed {
            if cfg!(debug_assertions) {
                eprintln!("failed to destroy CUDA event: {err}");
            }
        }
    }
}
impl Context {
pub fn create_event(self: &Arc<Self>) -> Result<Event> {
self.create_event_with_flags(EventFlags::DEFAULT)
}
pub fn create_event_with_flags(self: &Arc<Self>, flags: EventFlags) -> Result<Event> {
self.bind()?;
let mut handle = ptr::null_mut();
unsafe {
try_cuda!(runtime::cudaEventCreateWithFlags(
&raw mut handle,
flags.bits()
))?;
}
if handle.is_null() {
return Err(Error::NullHandle);
}
Ok(Event {
handle,
ctx: Arc::clone(self),
})
}
}