use std::ffi::c_void;
use std::ptr;
use super::context::{get_driver, CudaContext};
use super::graph::{CaptureMode, CudaGraph, CudaGraphExec};
use super::module::CudaModule;
use super::sys::{CUfunction, CUstream, CudaDriver, CU_STREAM_NON_BLOCKING};
use super::types::LaunchConfig;
use crate::GpuError;
/// Owning wrapper around a raw CUDA stream handle.
///
/// The handle is obtained from `cuStreamCreate` in `CudaStream::new` and is
/// passed to `cuStreamDestroy` when the value is dropped.
pub struct CudaStream {
/// Raw driver handle for this stream.
stream: CUstream,
}
// SAFETY: NOTE(review) — the only field is a raw `CUstream` pointer, so these
// marker impls have to be written by hand. Their soundness relies on the CUDA
// driver permitting a stream handle to be used from any host thread; confirm
// against the driver documentation.
unsafe impl Send for CudaStream {}
unsafe impl Sync for CudaStream {}
impl CudaStream {
    /// Creates a new stream using the `CU_STREAM_NON_BLOCKING` flag.
    ///
    /// The context argument is accepted but not read by this function.
    ///
    /// # Errors
    ///
    /// Propagates the error from `get_driver`, or returns
    /// `GpuError::StreamCreate` when `cuStreamCreate` reports a failure.
    pub fn new(_ctx: &CudaContext) -> Result<Self, GpuError> {
        let api = get_driver()?;
        let mut handle: CUstream = ptr::null_mut();
        // SAFETY: `handle` is a live local, so the out-pointer handed to the
        // driver stays valid for the whole call.
        let status = unsafe { (api.cuStreamCreate)(&mut handle, CU_STREAM_NON_BLOCKING) };
        match CudaDriver::check(status) {
            Ok(()) => Ok(Self { stream: handle }),
            Err(err) => Err(GpuError::StreamCreate(err.to_string())),
        }
    }

    /// Returns the raw driver handle without transferring ownership.
    #[must_use]
    pub fn raw(&self) -> CUstream {
        self.stream
    }

    /// Calls `cuStreamSynchronize` on this stream.
    ///
    /// # Errors
    ///
    /// Propagates the error from `get_driver`, or returns
    /// `GpuError::StreamSync` when the driver call reports a failure.
    pub fn synchronize(&self) -> Result<(), GpuError> {
        let api = get_driver()?;
        // SAFETY: the handle passed is the one stored in `self` at construction.
        let status = unsafe { (api.cuStreamSynchronize)(self.stream) };
        if let Err(err) = CudaDriver::check(status) {
            return Err(GpuError::StreamSync(err.to_string()));
        }
        Ok(())
    }

    /// Looks up `func_name` in `module` and launches it on this stream.
    ///
    /// The driver is resolved first, then the function; either failure is
    /// returned before anything is launched.
    ///
    /// # Errors
    ///
    /// Propagates errors from `get_driver` and `CudaModule::get_function`, and
    /// returns `GpuError::KernelLaunch` when the launch itself fails.
    ///
    /// # Safety
    ///
    /// Same contract as [`CudaStream::launch_function`]: `args` is forwarded
    /// unchanged to `cuLaunchKernel`, so it must satisfy that call's
    /// requirements for the kernel being launched.
    pub unsafe fn launch_kernel(
        &self,
        module: &mut CudaModule,
        func_name: &str,
        config: &LaunchConfig,
        args: &mut [*mut c_void],
    ) -> Result<(), GpuError> {
        let api = get_driver()?;
        let kernel = module.get_function(func_name)?;
        // SAFETY: the caller upholds the contract documented above, which is
        // exactly the contract of `launch_function`.
        unsafe { self.launch_function(api, kernel, config, args) }
    }

    /// Launches an already-resolved kernel function on this stream.
    ///
    /// Grid dimensions, block dimensions and the shared-memory size are taken
    /// from `config`; `args` is passed as the kernel-parameter array and the
    /// trailing "extra" argument is always null.
    ///
    /// # Errors
    ///
    /// Returns `GpuError::KernelLaunch` when `cuLaunchKernel` reports a failure.
    ///
    /// # Safety
    ///
    /// `func` and `args` are handed to `cuLaunchKernel` without validation. The
    /// caller must make sure they meet the driver's requirements for that call
    /// (NOTE(review): presumably one valid pointer per kernel parameter, each
    /// remaining valid as long as the driver needs it — confirm against the
    /// CUDA driver documentation).
    pub unsafe fn launch_function(
        &self,
        driver: &CudaDriver,
        func: CUfunction,
        config: &LaunchConfig,
        args: &mut [*mut c_void],
    ) -> Result<(), GpuError> {
        let (grid_x, grid_y, grid_z) = (config.grid.0, config.grid.1, config.grid.2);
        let (block_x, block_y, block_z) = (config.block.0, config.block.1, config.block.2);
        let kernel_params = args.as_mut_ptr();

        // SAFETY: validity of `func` and the parameter array is the caller's
        // obligation (see `# Safety`); the stream handle comes from `self`.
        let status = unsafe {
            (driver.cuLaunchKernel)(
                func,
                grid_x,
                grid_y,
                grid_z,
                block_x,
                block_y,
                block_z,
                config.shared_mem,
                self.stream,
                kernel_params,
                ptr::null_mut(),
            )
        };

        match CudaDriver::check(status) {
            Ok(()) => Ok(()),
            Err(err) => Err(GpuError::KernelLaunch(err.to_string())),
        }
    }

    /// Starts stream capture on this stream using the given capture mode.
    ///
    /// # Errors
    ///
    /// Propagates the error from `get_driver`, or returns
    /// `GpuError::GraphCapture` when `cuStreamBeginCapture` fails.
    pub fn begin_capture(&self, mode: CaptureMode) -> Result<(), GpuError> {
        let api = get_driver()?;
        let raw_mode = mode.to_cuda_mode();
        // SAFETY: the handle passed is the one stored in `self` at construction.
        let status = unsafe { (api.cuStreamBeginCapture)(self.stream, raw_mode) };
        CudaDriver::check(status).map_err(|err| GpuError::GraphCapture(err.to_string()))
    }

    /// Ends stream capture and wraps the resulting graph handle.
    ///
    /// # Errors
    ///
    /// Propagates the error from `get_driver`, or returns
    /// `GpuError::GraphCapture` when `cuStreamEndCapture` fails.
    pub fn end_capture(&self) -> Result<CudaGraph, GpuError> {
        let api = get_driver()?;
        let mut captured = ptr::null_mut();
        // SAFETY: `captured` is a live local, so the out-pointer is valid for
        // the duration of the call.
        let status = unsafe { (api.cuStreamEndCapture)(self.stream, &mut captured) };
        if let Err(err) = CudaDriver::check(status) {
            return Err(GpuError::GraphCapture(err.to_string()));
        }
        Ok(CudaGraph::from_raw(captured))
    }

    /// Launches an instantiated graph on this stream by delegating to
    /// `CudaGraphExec::launch`.
    ///
    /// # Errors
    ///
    /// Returns whatever error `CudaGraphExec::launch` produces.
    pub fn launch_graph(&self, exec: &CudaGraphExec) -> Result<(), GpuError> {
        exec.launch(self.raw())
    }
}
impl Drop for CudaStream {
    /// Destroys the underlying stream on a best-effort basis.
    ///
    /// If the driver cannot be obtained nothing is done, and the status
    /// returned by `cuStreamDestroy` is deliberately discarded because `drop`
    /// has no way to report a failure.
    fn drop(&mut self) {
        let api = match get_driver() {
            Ok(api) => api,
            Err(_) => return,
        };
        // SAFETY: the handle passed is the one stored in `self`, and `drop`
        // runs at most once for this value.
        let _ = unsafe { (api.cuStreamDestroy)(self.stream) };
    }
}
/// Null stream handle, exposed as the default stream.
pub const DEFAULT_STREAM: CUstream = ptr::null_mut();
#[cfg(test)]
mod tests {
    use super::*;

    /// `DEFAULT_STREAM` must be the null handle.
    #[test]
    fn test_default_stream_is_null() {
        assert!(DEFAULT_STREAM.is_null());
    }

    /// Placeholder that only exists when the `cuda` feature is disabled.
    #[test]
    #[cfg(not(feature = "cuda"))]
    fn test_stream_requires_cuda_feature() {
        // Intentionally empty: the previous `assert!(true)` was a constant
        // assertion that verified nothing (clippy `assertions_on_constants`).
        // The test passes simply by compiling and running in this configuration.
    }
}