use crate::error::{CudaError, CudaResult, DropResult, ToResult};
use crate::stream::Stream;
use crate::sys::{
cuEventCreate, cuEventDestroy_v2, cuEventElapsedTime, cuEventQuery, cuEventRecord,
cuEventSynchronize, CUevent,
};
use std::mem;
use std::ptr;
use std::time::Duration;
bitflags::bitflags! {
    /// Flags controlling how an [`Event`] is created.
    ///
    /// The raw bits are passed unchanged to `cuEventCreate` by [`Event::new`].
    pub struct EventFlags: u32 {
        /// No special behavior (all bits clear).
        const DEFAULT = 0x0;
        /// NOTE(review): per the CUDA driver docs for `CU_EVENT_BLOCKING_SYNC`, waiting on
        /// this event blocks the host thread instead of busy-waiting — confirm.
        const BLOCKING_SYNC = 0x1;
        /// Do not record timing data. Such an event cannot be used as an endpoint of
        /// `Event::elapsed_time_f32` / `Event::elapsed`; the driver then reports
        /// `CudaError::InvalidHandle`.
        const DISABLE_TIMING = 0x2;
        /// NOTE(review): per the CUDA driver docs for `CU_EVENT_INTERPROCESS`, marks the
        /// event as usable across processes and must be combined with `DISABLE_TIMING`
        /// — confirm.
        const INTERPROCESS = 0x4;
    }
}
/// Completion state of an [`Event`], as reported by `Event::query`.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` so the status can be used
/// in exact comparisons and as a key in hashed collections.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum EventStatus {
    /// All work captured by the event has completed.
    Ready,
    /// Work captured by the event is still pending.
    NotReady,
}
/// An owned CUDA driver event handle (`CUevent`).
///
/// The handle is destroyed with `cuEventDestroy_v2` when the `Event` is dropped.
/// Use `Event::drop(event)` instead to observe a destruction error.
#[derive(Debug)]
pub struct Event(CUevent);
// SAFETY: NOTE(review): these impls rely on the CUDA driver API accepting an event
// handle from any host thread, including concurrent `&self` calls. Confirm against
// the driver's documented thread-safety guarantees.
unsafe impl Send for Event {}
unsafe impl Sync for Event {}
impl Event {
pub fn new(flags: EventFlags) -> CudaResult<Self> {
unsafe {
let mut event: CUevent = mem::zeroed();
cuEventCreate(&mut event, flags.bits()).to_result()?;
Ok(Event(event))
}
}
pub fn record(&self, stream: &Stream) -> CudaResult<()> {
unsafe {
cuEventRecord(self.0, stream.as_inner()).to_result()?;
Ok(())
}
}
pub fn query(&self) -> CudaResult<EventStatus> {
let result = unsafe { cuEventQuery(self.0).to_result() };
match result {
Ok(()) => Ok(EventStatus::Ready),
Err(CudaError::NotReady) => Ok(EventStatus::NotReady),
Err(other) => Err(other),
}
}
pub fn synchronize(&self) -> CudaResult<()> {
unsafe {
cuEventSynchronize(self.0).to_result()?;
Ok(())
}
}
pub fn elapsed_time_f32(&self, start: &Self) -> CudaResult<f32> {
unsafe {
let mut millis: f32 = 0.0;
cuEventElapsedTime(&mut millis, start.0, self.0).to_result()?;
Ok(millis)
}
}
pub fn elapsed(&self, start: &Self) -> CudaResult<Duration> {
let time_f32 = self.elapsed_time_f32(start)?;
Ok(Duration::from_nanos((time_f32 * 1e6) as u64))
}
pub(crate) fn as_inner(&self) -> CUevent {
self.0
}
pub fn drop(mut event: Event) -> DropResult<Event> {
if event.0.is_null() {
return Ok(());
}
unsafe {
let inner = mem::replace(&mut event.0, ptr::null_mut());
match cuEventDestroy_v2(inner).to_result() {
Ok(()) => {
mem::forget(event);
Ok(())
}
Err(e) => Err((e, Event(inner))),
}
}
}
}
impl Drop for Event {
    /// Destroys the underlying CUDA event handle.
    ///
    /// `Drop` cannot report failure, so any error from `cuEventDestroy_v2` is
    /// deliberately discarded. Call `Event::drop(event)` to observe it instead.
    fn drop(&mut self) {
        if self.0.is_null() {
            // A null handle owns no driver resource; skip the pointless FFI call.
            return;
        }
        // SAFETY: the handle is non-null and owned exclusively by this `Event`,
        // and `Drop::drop` runs at most once.
        let _ = unsafe { cuEventDestroy_v2(self.0) };
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::quick_init;
    use crate::stream::StreamFlags;
    use std::error::Error;

    /// Common return type for the GPU-backed tests below.
    type TestResult = Result<(), Box<dyn Error>>;

    /// An event can be created with several flags combined.
    #[test]
    fn test_new_with_flags() -> TestResult {
        let _ctx = quick_init()?;
        let combined = EventFlags::BLOCKING_SYNC | EventFlags::DISABLE_TIMING;
        let _evt = Event::new(combined)?;
        Ok(())
    }

    /// Timing works between events recorded on two different streams.
    #[test]
    fn test_elapsed_time_f32_with_different_streams() -> TestResult {
        let _ctx = quick_init()?;

        let first_stream = Stream::new(StreamFlags::NON_BLOCKING, None)?;
        let begin = Event::new(EventFlags::DEFAULT)?;
        begin.record(&first_stream)?;

        let second_stream = Stream::new(StreamFlags::NON_BLOCKING, None)?;
        let end = Event::new(EventFlags::DEFAULT)?;
        end.record(&second_stream)?;

        begin.synchronize()?;
        end.synchronize()?;
        end.elapsed_time_f32(&begin)?;
        Ok(())
    }

    /// Timing against an event created with `DISABLE_TIMING` is rejected.
    #[test]
    fn test_elapsed_time_f32_with_disable_timing() -> TestResult {
        let _ctx = quick_init()?;
        let stream = Stream::new(StreamFlags::NON_BLOCKING, None)?;

        let untimed_start = Event::new(EventFlags::DISABLE_TIMING)?;
        untimed_start.record(&stream)?;
        let timed_stop = Event::new(EventFlags::DEFAULT)?;
        timed_stop.record(&stream)?;
        timed_stop.synchronize()?;

        assert_eq!(
            timed_stop.elapsed_time_f32(&untimed_start),
            Err(CudaError::InvalidHandle)
        );
        Ok(())
    }
}