use std::ffi::c_void;
use std::ptr;
use flodl_sys as ffi;
use crate::tensor::{check_err, Device, Result, TensorError};
use super::cuda_event::CudaEvent;
/// Owning RAII handle to a CUDA stream created through the `flodl` FFI.
///
/// The raw handle is allocated by `flodl_cuda_stream_new` and released by
/// `flodl_cuda_stream_delete` when the struct is dropped.
pub struct CudaStream {
// Raw stream handle; owned by this struct and freed in `Drop`.
// Null only after `Drop` has already run.
ptr: *mut c_void,
// Index of the CUDA device this stream was created on.
device_index: i32,
}
// SAFETY: the raw handle is only a pointer into the flodl runtime and this
// wrapper never hands out aliased mutable access.
// NOTE(review): this assumes the flodl C API allows stream calls from a
// thread other than the creating one — confirm against the FFI contract.
unsafe impl Send for CudaStream {}
impl CudaStream {
    /// Creates a stream on `device`.
    ///
    /// `high_priority` asks the driver for a high-priority stream.
    ///
    /// # Errors
    ///
    /// Fails when `device` is not a CUDA device, or when the underlying
    /// `flodl_cuda_stream_new` call reports an error.
    pub fn new(device: Device, high_priority: bool) -> Result<Self> {
        let Device::CUDA(idx) = device else {
            return Err(TensorError::new(
                "CudaStream requires a CUDA device",
            ));
        };
        let device_index = idx as i32;
        let mut raw: *mut c_void = ptr::null_mut();
        let status = unsafe {
            ffi::flodl_cuda_stream_new(device_index, high_priority as i32, &mut raw)
        };
        check_err(status)?;
        Ok(CudaStream {
            ptr: raw,
            device_index,
        })
    }

    /// Blocks the calling thread until all work queued on this stream has
    /// finished.
    pub fn synchronize(&self) -> Result<()> {
        check_err(unsafe { ffi::flodl_cuda_stream_synchronize(self.ptr) })
    }

    /// Makes all future work on this stream wait until `event` has fired.
    pub fn wait_event(&self, event: &CudaEvent) -> Result<()> {
        let status =
            unsafe { ffi::flodl_cuda_stream_wait_event(self.ptr, event.as_ptr()) };
        check_err(status)
    }

    /// Returns `true` when every operation queued on the stream has
    /// completed (non-blocking query).
    pub fn is_complete(&self) -> bool {
        let done = unsafe { ffi::flodl_cuda_stream_query(self.ptr) };
        done != 0
    }

    /// Device this stream was created on.
    pub fn device(&self) -> Device {
        Device::CUDA(self.device_index as u8)
    }

    /// Raw handle for passing to other FFI calls; ownership is retained by
    /// `self`.
    pub(crate) fn as_ptr(&self) -> *mut c_void {
        self.ptr
    }
}
impl Drop for CudaStream {
fn drop(&mut self) {
if !self.ptr.is_null() {
unsafe { ffi::flodl_cuda_stream_delete(self.ptr) };
self.ptr = ptr::null_mut();
}
}
}
/// RAII guard that installs a `CudaStream` as the device's current stream
/// and restores the previous state when dropped.
pub struct StreamGuard {
// Stream that was current before this guard was installed, as returned by
// `flodl_cuda_stream_get_current`; null when no explicit stream was set.
// NOTE(review): treated as a non-owning (borrowed) handle — confirm the
// FFI's ownership contract for `get_current`.
prev: *mut std::ffi::c_void,
// Device whose current-stream state this guard manipulates.
device_index: i32,
}
impl StreamGuard {
    /// Makes `stream` the current stream for its device, remembering
    /// whichever stream (possibly none) was current before.
    pub fn new(stream: &CudaStream) -> Self {
        let device_index = stream.device_index;
        // Capture the old current stream, then switch to the new one.
        let previous = unsafe { ffi::flodl_cuda_stream_get_current(device_index) };
        unsafe { ffi::flodl_cuda_stream_set_current(stream.ptr) };
        StreamGuard {
            prev: previous,
            device_index,
        }
    }
}
impl Drop for StreamGuard {
    /// Restores the stream that was current when the guard was created.
    fn drop(&mut self) {
        if !self.prev.is_null() {
            // Reinstate the previously-current stream. `prev` is a borrowed
            // handle: `flodl_cuda_stream_set_current` does not take
            // ownership (StreamGuard::new passes `stream.ptr`, which the
            // CudaStream still owns and frees in its own Drop), so it must
            // NOT be deleted here. The previous code deleted `prev` right
            // after restoring it, which both undid the restore (leaving a
            // dangling current stream) and caused a double free once the
            // owning CudaStream was dropped.
            unsafe { ffi::flodl_cuda_stream_set_current(self.prev) };
        } else {
            // No explicit stream was current before the guard; fall back to
            // the device's default stream.
            unsafe { ffi::flodl_cuda_stream_restore_default(self.device_index) };
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::cuda_event::CudaEventFlags;
    use crate::tensor::{Tensor, test_device, test_opts};
    use std::sync::Mutex;

    // Serializes the stream tests: they mutate the per-device
    // "current stream" global, so concurrent runs would interfere.
    static STREAM_LOCK: Mutex<()> = Mutex::new(());

    #[test]
    fn test_cuda_stream_requires_cuda_device() {
        // CPU devices must be rejected at construction time.
        let result = CudaStream::new(Device::CPU, false);
        assert!(result.is_err(), "CudaStream::new(CPU) should fail");
    }

    #[test]
    fn test_cuda_stream_create_synchronize() {
        if !test_device().is_cuda() {
            return; // CUDA-only test; skip on CPU builds.
        }
        let _lock = STREAM_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let stream = CudaStream::new(test_device(), false).unwrap();
        assert_eq!(stream.device(), test_device());
        stream.synchronize().unwrap();
        // A stream with no queued work reports complete immediately.
        assert!(stream.is_complete(), "empty stream should be complete");
    }

    #[test]
    fn test_stream_guard_restores_default() {
        if !test_device().is_cuda() {
            return;
        }
        let _lock = STREAM_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let opts = test_opts();
        let stream = CudaStream::new(test_device(), false).unwrap();
        {
            // Work inside the scope runs on `stream`; the guard must put
            // the default stream back when it drops.
            let _guard = StreamGuard::new(&stream);
            let _a = Tensor::randn(&[32, 32], opts).unwrap();
        }
        // After the guard is gone, ordinary ops must still work correctly
        // on the restored (default) stream.
        let b = Tensor::ones(&[4], opts).unwrap();
        let c = b.add(&b).unwrap();
        let vals = c.to_f32_vec().unwrap();
        assert!(vals.iter().all(|&v| (v - 2.0).abs() < 1e-5));
    }

    #[test]
    fn test_async_copy_on_stream() {
        if !test_device().is_cuda() {
            return;
        }
        let _lock = STREAM_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let opts = test_opts();
        let gpu = Tensor::full(&[128], 42.0, opts).unwrap();
        let copy_stream = CudaStream::new(test_device(), false).unwrap();
        // Ensure the copy stream waits for all prior work before copying.
        let ready = CudaEvent::new(CudaEventFlags::DisableTiming).unwrap();
        ready.record().unwrap();
        copy_stream.wait_event(&ready).unwrap();
        let cpu_copy = {
            // (Fixed: the reference here had been corrupted to `©_stream`,
            // an HTML-entity mangling of `&copy_stream`.)
            let _guard = StreamGuard::new(&copy_stream);
            gpu.to_device_async(Device::CPU).unwrap()
        };
        // Record completion of the copy and wait on it before reading.
        let done = CudaEvent::new(CudaEventFlags::DisableTiming).unwrap();
        done.record_on(&copy_stream).unwrap();
        done.synchronize().unwrap();
        let vals = cpu_copy.to_f32_vec().unwrap();
        assert_eq!(vals.len(), 128);
        assert!(vals.iter().all(|&v| (v - 42.0).abs() < 1e-5),
        "async copy should preserve values");
    }

    #[test]
    fn test_cross_stream_wait_event() {
        if !test_device().is_cuda() {
            return;
        }
        let _lock = STREAM_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let opts = test_opts();
        let stream_a = CudaStream::new(test_device(), false).unwrap();
        let stream_b = CudaStream::new(test_device(), false).unwrap();
        // Produce on stream A…
        let result = {
            let _guard = StreamGuard::new(&stream_a);
            Tensor::full(&[64], 7.0, opts).unwrap()
        };
        // …then have stream B wait on A's event before consuming.
        let event = CudaEvent::new(CudaEventFlags::DisableTiming).unwrap();
        event.record_on(&stream_a).unwrap();
        stream_b.wait_event(&event).unwrap();
        let doubled = {
            let _guard = StreamGuard::new(&stream_b);
            result.add(&result).unwrap()
        };
        stream_b.synchronize().unwrap();
        let vals = doubled.to_f32_vec().unwrap();
        assert!(vals.iter().all(|&v| (v - 14.0).abs() < 1e-5),
        "cross-stream result should be 14.0");
    }
}