use crate::{contexted_call, contexted_new, device::*, error::*};
use cuda::*;
use std::sync::Arc;
/// Owned handle to a CUDA stream, tied to the context it was created in.
///
/// The underlying stream is destroyed when this value is dropped (see the `Drop` impl).
pub struct Stream {
    /// Raw driver handle; visible crate-wide (presumably so other modules can pass it to
    /// driver calls that enqueue work — confirm against callers).
    pub(crate) stream: CUstream,
    /// Context the stream belongs to. Holding the `Arc` keeps the context alive at least as
    /// long as this `Stream`, including while `Drop::drop` destroys the stream.
    ctx: Arc<Context>,
}
impl Drop for Stream {
fn drop(&mut self) {
if let Err(e) = unsafe { contexted_call!(self, cuStreamDestroy_v2, self.stream) } {
log::error!("Failed to delete CUDA stream: {:?}", e);
}
}
}
impl Contexted for Stream {
    /// Hands out another shared reference to the context this stream was created in.
    fn get_context(&self) -> Arc<Context> {
        Arc::clone(&self.ctx)
    }
}
impl Stream {
    /// Creates a new CUDA stream in `ctx`.
    ///
    /// The stream is created with `CU_STREAM_NON_BLOCKING`. Per the CUDA driver API, this means
    /// work on it does not implicitly synchronize with the legacy NULL stream.
    ///
    /// # Panics
    ///
    /// Panics if `cuStreamCreate` fails.
    pub fn new(ctx: Arc<Context>) -> Self {
        let stream = unsafe {
            contexted_new!(
                &ctx,
                cuStreamCreate,
                CUstream_flags::CU_STREAM_NON_BLOCKING as u32
            )
        }
        .expect("Failed to create CUDA stream");
        Stream { ctx, stream }
    }

    /// Checks whether all work submitted to this stream has completed, without waiting for it.
    ///
    /// Returns `true` when `cuStreamQuery` succeeds. Returns `false` when it reports
    /// [`AccelError::AsyncOperationNotReady`].
    ///
    /// # Panics
    ///
    /// Panics if `cuStreamQuery` returns any other error.
    pub fn query(&self) -> bool {
        match unsafe { contexted_call!(self, cuStreamQuery, self.stream) } {
            Ok(_) => true,
            Err(AccelError::AsyncOperationNotReady) => false,
            // Message phrasing kept consistent with `Event::query`.
            Err(e) => panic!("Unknown error occurs while cuStreamQuery: {:?}", e),
        }
    }

    /// Blocks the calling thread until all work submitted to this stream has completed.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `cuStreamSynchronize`.
    pub fn sync(&self) -> Result<()> {
        unsafe { contexted_call!(self, cuStreamSynchronize, self.stream) }?;
        Ok(())
    }

    /// Makes all work submitted to this stream after this call wait until `event` has completed.
    ///
    /// The flags argument of `cuStreamWaitEvent` is passed as `0`.
    ///
    /// # Panics
    ///
    /// Panics if `cuStreamWaitEvent` fails.
    pub fn wait_event(&mut self, event: &Event) {
        unsafe { contexted_call!(self, cuStreamWaitEvent, self.stream, event.event, 0) }
            .expect("Failed to register a CUDA event waiting on CUDA stream");
    }
}
/// Owned handle to a CUDA event, tied to the context it was created in.
///
/// The underlying event is destroyed when this value is dropped (see the `Drop` impl).
pub struct Event {
    /// Raw driver handle for the event.
    event: CUevent,
    /// Context the event belongs to. Holding the `Arc` keeps the context alive at least as
    /// long as this `Event`, including while `Drop::drop` destroys the event.
    ctx: Arc<Context>,
}
impl Drop for Event {
fn drop(&mut self) {
if let Err(e) = unsafe { contexted_call!(self, cuEventDestroy_v2, self.event) } {
log::error!("Failed to delete CUDA event: {:?}", e);
}
}
}
impl Contexted for Event {
    /// Hands out another shared reference to the context this event was created in.
    fn get_context(&self) -> Arc<Context> {
        Arc::clone(&self.ctx)
    }
}
impl Event {
pub fn new(ctx: Arc<Context>) -> Self {
let event = unsafe {
contexted_new!(
&ctx,
cuEventCreate,
CUevent_flags_enum::CU_EVENT_BLOCKING_SYNC as u32
)
}
.expect("Failed to create CUDA event");
Event { ctx, event }
}
pub fn record(&mut self, stream: &mut Stream) {
unsafe { contexted_call!(self, cuEventRecord, self.event, stream.stream) }
.expect("Failed to set event record");
}
pub fn query(&self) -> bool {
match unsafe { contexted_call!(self, cuEventQuery, self.event) } {
Ok(_) => true,
Err(AccelError::AsyncOperationNotReady) => false,
Err(e) => panic!("Unknown error occurs while cuEventQuery: {:?}", e),
}
}
pub fn sync(&self) -> Result<()> {
unsafe { contexted_call!(self, cuEventSynchronize, self.event) }?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Creating a stream on the first device succeeds.
    #[test]
    fn new() -> Result<()> {
        let device = Device::nth(0)?;
        let ctx = device.create_context();
        let _st = Stream::new(ctx);
        Ok(())
    }

    /// An event recorded on a stream can be synchronized on, and so can the stream.
    #[test]
    fn trivial_sync() -> Result<()> {
        let device = Device::nth(0)?;
        let ctx = device.create_context();
        let mut stream = Stream::new(ctx.clone());
        let mut event = Event::new(ctx);
        event.record(&mut stream);
        event.sync()?;
        stream.sync()?;
        Ok(())
    }

    /// Covers `Stream::wait_event`, `Stream::query` and `Event::query`.
    ///
    /// A stream that waits on an event recorded on another stream can only finish after that
    /// event completes. Once everything is synchronized, both queries must report completion.
    #[test]
    fn wait_event_then_query() -> Result<()> {
        let device = Device::nth(0)?;
        let ctx = device.create_context();
        let mut producer = Stream::new(ctx.clone());
        let mut consumer = Stream::new(ctx.clone());
        let mut event = Event::new(ctx);
        event.record(&mut producer);
        consumer.wait_event(&event);
        consumer.sync()?;
        // The consumer's wait has finished, so the event it waited on must have completed.
        assert!(event.query());
        assert!(consumer.query());
        producer.sync()?;
        assert!(producer.query());
        Ok(())
    }
}