use crate::cuda_sys::{self, cudaStream_t, CUDA_SUCCESS};
use std::sync::OnceLock;
/// Wrapper around a raw CUDA stream handle.
///
/// A null handle is a valid state: it is what `new_default` stores, it makes
/// `synchronize` fall back to a device-wide synchronization, and it is skipped
/// by the `Drop` implementation. A non-null handle is destroyed on drop.
pub struct CudaStreamWrapper {
/// Raw stream handle; null for a wrapper built by `new_default`.
stream: cudaStream_t,
}
// SAFETY: NOTE(review) — these impls assert that the wrapped raw handle may be
// moved to, and shared between, threads. The wrapper itself adds no
// synchronization, so this presumably relies on the CUDA runtime accepting a
// stream handle from any thread — TODO confirm against the CUDA documentation.
unsafe impl Send for CudaStreamWrapper {}
unsafe impl Sync for CudaStreamWrapper {}
impl CudaStreamWrapper {
    /// Builds a wrapper around the null stream handle.
    ///
    /// No CUDA call is made here. A wrapper holding the null handle makes
    /// [`synchronize`](Self::synchronize) wait on the whole device, and is
    /// left untouched when dropped.
    pub fn new_default() -> Self {
        Self {
            stream: std::ptr::null_mut(),
        }
    }

    /// Creates a new stream through `cudaStreamCreate` and wraps it.
    ///
    /// # Errors
    ///
    /// Returns a message containing the status value reported by
    /// `cudaStreamCreate` when that status is not `CUDA_SUCCESS`.
    pub fn new() -> Result<Self, String> {
        let mut handle: cudaStream_t = std::ptr::null_mut();
        // SAFETY: `handle` is a live local, so the reference passed to the
        // runtime stays valid for the whole call.
        let status = unsafe { cuda_sys::cudaStreamCreate(&mut handle) };
        if status == CUDA_SUCCESS {
            Ok(Self { stream: handle })
        } else {
            Err(format!("cudaStreamCreate failed: {}", status))
        }
    }

    /// Returns the raw stream handle held by this wrapper.
    ///
    /// The handle is null for a wrapper built by
    /// [`new_default`](Self::new_default).
    pub fn raw(&self) -> cudaStream_t {
        self.stream
    }

    /// Waits for outstanding work to finish.
    ///
    /// A non-null handle is synchronized with `cudaStreamSynchronize`; the
    /// null handle falls back to `cudaDeviceSynchronize`.
    ///
    /// # Errors
    ///
    /// The result is whatever the `cuda_check!` macro (defined outside this
    /// block) produces for the underlying call.
    pub fn synchronize(&self) -> Result<(), String> {
        if !self.stream.is_null() {
            cuda_check!(cuda_sys::cudaStreamSynchronize(self.stream))
        } else {
            cuda_check!(cuda_sys::cudaDeviceSynchronize())
        }
    }
}
impl Drop for CudaStreamWrapper {
fn drop(&mut self) {
if !self.stream.is_null() {
unsafe {
cuda_sys::cudaStreamDestroy(self.stream);
}
}
}
}
/// Process-wide stream wrapper, initialized lazily on the first call to `get_stream`.
static GLOBAL_STREAM: OnceLock<CudaStreamWrapper> = OnceLock::new();
/// Returns the shared, process-wide stream wrapper.
///
/// The wrapper is created on first use with
/// `CudaStreamWrapper::new_default`, so it holds the null stream handle;
/// later calls return the same instance.
pub fn get_stream() -> &'static CudaStreamWrapper {
    GLOBAL_STREAM.get_or_init(CudaStreamWrapper::new_default)
}
/// Synchronizes the shared stream returned by `get_stream`.
///
/// This is a best-effort helper: a failure is written to stderr and is not
/// propagated to the caller.
pub fn sync_stream() {
    match get_stream().synchronize() {
        Ok(()) => {}
        Err(reason) => eprintln!("CUDA sync_stream error: {}", reason),
    }
}