use std::ffi::c_void;
use std::ptr;
use super::context::{get_driver, CudaContext};
use super::graph::{CaptureMode, CudaGraph, CudaGraphExec};
use super::module::CudaModule;
use super::sys::{
CUevent, CUfunction, CUstream, CudaDriver, CUDA_ERROR_NOT_READY, CU_EVENT_DISABLE_TIMING,
CU_STREAM_NON_BLOCKING,
};
use super::types::LaunchConfig;
use crate::GpuError;
/// Owning wrapper around a CUDA driver stream handle (`CUstream`).
///
/// Created by [`CudaStream::new`] with `CU_STREAM_NON_BLOCKING`; the handle is released with
/// `cuStreamDestroy` when the value is dropped.
pub struct CudaStream {
stream: CUstream,
}
// SAFETY: `CudaStream` only holds an opaque driver handle that its methods pass to the
// driver by value; it is never dereferenced on the host.
// NOTE(review): soundness relies on the CUDA driver accepting this handle from any thread
// (with the owning context current on that thread) — confirm against the driver's
// thread-safety guarantees.
unsafe impl Send for CudaStream {}
unsafe impl Sync for CudaStream {}
impl CudaStream {
pub fn new(_ctx: &CudaContext) -> Result<Self, GpuError> {
let driver = get_driver()?;
let mut stream: CUstream = ptr::null_mut();
let result = unsafe { (driver.cuStreamCreate)(&mut stream, CU_STREAM_NON_BLOCKING) };
CudaDriver::check(result).map_err(|e| GpuError::StreamCreate(e.to_string()))?;
Ok(Self { stream })
}
#[must_use]
pub fn raw(&self) -> CUstream {
self.stream
}
pub fn synchronize(&self) -> Result<(), GpuError> {
let driver = get_driver()?;
let result = unsafe { (driver.cuStreamSynchronize)(self.stream) };
CudaDriver::check(result).map_err(|e| GpuError::StreamSync(e.to_string()))
}
pub fn memcpy_dtod_sync(
&self,
dst_ptr: u64,
src_ptr: u64,
size_bytes: usize,
) -> Result<(), GpuError> {
if size_bytes == 0 {
return Ok(());
}
let driver = get_driver()?;
let result = unsafe { (driver.cuMemcpyDtoD)(dst_ptr, src_ptr, size_bytes) };
CudaDriver::check(result).map_err(|e| GpuError::Transfer(format!("D2D copy failed: {e}")))
}
pub unsafe fn launch_kernel(
&self,
module: &mut CudaModule,
func_name: &str,
config: &LaunchConfig,
args: &mut [*mut c_void],
) -> Result<(), GpuError> {
let driver = get_driver()?;
let func = module.get_function(func_name)?;
unsafe { self.launch_function(driver, func, config, args) }
}
pub unsafe fn launch_function(
&self,
driver: &CudaDriver,
func: CUfunction,
config: &LaunchConfig,
args: &mut [*mut c_void],
) -> Result<(), GpuError> {
let result = unsafe {
(driver.cuLaunchKernel)(
func,
config.grid.0,
config.grid.1,
config.grid.2,
config.block.0,
config.block.1,
config.block.2,
config.shared_mem,
self.stream,
args.as_mut_ptr(),
ptr::null_mut(), )
};
CudaDriver::check(result).map_err(|e| GpuError::KernelLaunch(e.to_string()))?;
Ok(())
}
pub fn begin_capture(&self, mode: CaptureMode) -> Result<(), GpuError> {
let driver = get_driver()?;
let result = unsafe { (driver.cuStreamBeginCapture)(self.stream, mode.to_cuda_mode()) };
CudaDriver::check(result).map_err(|e| GpuError::GraphCapture(e.to_string()))
}
pub fn end_capture(&self) -> Result<CudaGraph, GpuError> {
let driver = get_driver()?;
let mut graph = ptr::null_mut();
let result = unsafe { (driver.cuStreamEndCapture)(self.stream, &mut graph) };
CudaDriver::check(result).map_err(|e| GpuError::GraphCapture(e.to_string()))?;
Ok(CudaGraph::from_raw(graph))
}
pub fn launch_graph(&self, exec: &CudaGraphExec) -> Result<(), GpuError> {
exec.launch(self.stream)
}
pub fn record_event(&self, event: &CudaEvent) -> Result<(), GpuError> {
let driver = get_driver()?;
let result = unsafe { (driver.cuEventRecord)(event.event, self.stream) };
CudaDriver::check(result).map_err(|e| GpuError::StreamSync(format!("event record: {e}")))
}
pub fn wait_event(&self, event: &CudaEvent) -> Result<(), GpuError> {
let driver = get_driver()?;
let result = unsafe { (driver.cuStreamWaitEvent)(self.stream, event.event, 0) };
CudaDriver::check(result)
.map_err(|e| GpuError::StreamSync(format!("stream wait event: {e}")))
}
}
/// Owning wrapper around a CUDA driver event handle (`CUevent`).
///
/// Created by [`CudaEvent::new`] with `CU_EVENT_DISABLE_TIMING`; the handle is released with
/// `cuEventDestroy` when the value is dropped.
pub struct CudaEvent {
event: CUevent,
}
// SAFETY: `CudaEvent` only holds an opaque driver handle that is passed to the driver by
// value and never dereferenced on the host.
// NOTE(review): soundness relies on the CUDA driver accepting event handles from any
// thread — confirm against the driver's thread-safety guarantees.
unsafe impl Send for CudaEvent {}
unsafe impl Sync for CudaEvent {}
impl CudaEvent {
    /// Creates an event with timing disabled (`CU_EVENT_DISABLE_TIMING`).
    ///
    /// # Errors
    ///
    /// Propagates any error from `get_driver`. A failing `cuEventCreate` is reported as
    /// [`GpuError::StreamCreate`] with an `"event create: "` prefix.
    pub fn new() -> Result<Self, GpuError> {
        let driver = get_driver()?;
        let mut handle: CUevent = ptr::null_mut();
        // SAFETY: `handle` is a valid, writable out-pointer for the duration of the call.
        let status = unsafe { (driver.cuEventCreate)(&mut handle, CU_EVENT_DISABLE_TIMING) };
        CudaDriver::check(status)
            .map(|_| Self { event: handle })
            .map_err(|e| GpuError::StreamCreate(format!("event create: {e}")))
    }

    /// Non-blocking poll: `Ok(true)` once the event has completed, `Ok(false)` while the
    /// driver still reports `CUDA_ERROR_NOT_READY`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `get_driver`. Any other failing status is reported as
    /// [`GpuError::StreamSync`] with an `"event query: "` prefix.
    pub fn is_complete(&self) -> Result<bool, GpuError> {
        let driver = get_driver()?;
        // SAFETY: `self.event` is a live event handle owned by `self`.
        let status = unsafe { (driver.cuEventQuery)(self.event) };
        if status == CUDA_ERROR_NOT_READY {
            // "Not ready" is an expected answer, not a failure.
            Ok(false)
        } else {
            CudaDriver::check(status)
                .map(|_| true)
                .map_err(|e| GpuError::StreamSync(format!("event query: {e}")))
        }
    }

    /// Blocks the calling thread until the event has completed.
    ///
    /// # Errors
    ///
    /// Propagates any error from `get_driver`. A failing `cuEventSynchronize` is reported as
    /// [`GpuError::StreamSync`] with an `"event sync: "` prefix.
    pub fn synchronize(&self) -> Result<(), GpuError> {
        let driver = get_driver()?;
        // SAFETY: `self.event` is a live event handle owned by `self`.
        let status = unsafe { (driver.cuEventSynchronize)(self.event) };
        CudaDriver::check(status).map_err(|e| GpuError::StreamSync(format!("event sync: {e}")))
    }
}
impl Drop for CudaEvent {
    /// Best-effort release of the event handle.
    ///
    /// `drop` cannot report errors, so the destroy status is ignored. If the driver cannot be
    /// obtained, the handle is simply not destroyed.
    fn drop(&mut self) {
        let handle = self.event;
        if let Ok(driver) = get_driver() {
            // SAFETY: `handle` was created by `cuEventCreate` in `CudaEvent::new` and is not
            // destroyed anywhere else in this type.
            let _status = unsafe { (driver.cuEventDestroy)(handle) };
        }
    }
}
impl Drop for CudaStream {
    /// Best-effort release of the stream handle.
    ///
    /// `drop` cannot report errors, so the destroy status is ignored. If the driver cannot be
    /// obtained, the handle is simply not destroyed.
    fn drop(&mut self) {
        let handle = self.stream;
        if let Ok(driver) = get_driver() {
            // SAFETY: `handle` was created by `cuStreamCreate` in `CudaStream::new` and is not
            // destroyed anywhere else in this type.
            let _status = unsafe { (driver.cuStreamDestroy)(handle) };
        }
    }
}
pub const DEFAULT_STREAM: CUstream = ptr::null_mut();
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_stream_is_null() {
        assert!(DEFAULT_STREAM.is_null());
    }

    /// Compile-time check that the handle wrappers are `Send + Sync`, as the `unsafe impl`s
    /// in this module declare. Needs no GPU.
    #[test]
    fn test_handles_are_send_and_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<CudaStream>();
        assert_send_sync::<CudaEvent>();
    }

    /// Intentionally empty: it passes as long as this module builds without the `cuda`
    /// feature. A constant assertion here would be a no-op that clippy rejects
    /// (`assertions_on_constants`).
    #[test]
    #[cfg(not(feature = "cuda"))]
    fn test_stream_requires_cuda_feature() {}
}