ndrs 0.5.0

A tensor library with GPU support
Documentation
use super::device::{get_device, get_device_context};
use super::event::Event;
use anyhow::{Context, Result, bail};
use cudarc::driver::CudaStream;
use cudarc::driver::sys::CUevent_flags;
use std::sync::Arc;

#[derive(Clone)]
pub struct Stream {
    pub(crate) stream: Arc<CudaStream>,
    pub(crate) device_id: usize,
}

impl Stream {
    pub fn new(device_id: Option<usize>) -> Result<Self> {
        let dev_id = device_id.unwrap_or_else(|| get_device());
        let ctx = get_device_context(dev_id).context("Failed to get device context")?;
        let stream = ctx
            .create_stream()
            .context("Failed to create CUDA stream")?;
        Ok(Stream {
            stream,
            device_id: dev_id,
        })
    }

    pub fn synchronize(&self) -> Result<()> {
        self.stream
            .synchronize()
            .context("Failed to synchronize CUDA stream")
    }

    pub fn wait_event(&self, event: &Event) -> Result<()> {
        self.stream
            .wait(&event.event)
            .context("Failed to wait on CUDA event")
    }

    pub fn record(&self) -> Result<Event> {
        let flags = Some(CUevent_flags::CU_EVENT_DEFAULT);
        let event = self
            .stream
            .record_event(flags)
            .context("Failed to record CUDA event")?;
        Ok(Event { event })
    }

    pub fn join(&self, other: &Stream) -> Result<()> {
        self.stream
            .join(&other.stream)
            .context("Failed to join CUDA streams")
    }

    pub fn as_ptr(&self) -> *mut std::ffi::c_void {
        self.stream.cu_stream() as *mut _
    }

    pub fn inner(&self) -> &Arc<CudaStream> {
        &self.stream
    }
}

pub fn default_stream() -> Result<Stream> {
    let dev_id = get_device();
    let ctx = get_device_context(dev_id).context("Failed to get device context")?;
    let stream = ctx.default_stream();
    Ok(Stream {
        stream,
        device_id: dev_id,
    })
}

thread_local! {
    static CURRENT_STREAM: std::cell::RefCell<Option<Stream>> = const { std::cell::RefCell::new(None) };
}

pub fn set_stream(stream: Stream) -> Result<()> {
    let current_dev = get_device();
    if stream.device_id != current_dev {
        bail!(
            "Stream device {} does not match current CUDA device {}",
            stream.device_id,
            current_dev
        );
    }
    CURRENT_STREAM.with(|s| *s.borrow_mut() = Some(stream));
    Ok(())
}

pub fn get_stream() -> Result<Stream> {
    CURRENT_STREAM.with(|s| {
        if let Some(stream) = s.borrow().clone() {
            Ok(stream)
        } else {
            default_stream()
        }
    })
}