flodl 0.3.0

floDl — a flow-graph deep learning framework built on libtorch
//! CUDA stream for async GPU operations.
//!
//! Streams represent ordered queues of GPU work. Operations on different
//! streams can execute concurrently. The default stream serializes all work.
//!
//! Common use: run GPU-to-CPU tensor copies on a non-default stream so
//! they overlap with training on the default stream.
//!
//! CUDA only. Returns error on CPU builds.
//!
//! # Usage
//!
//! ```ignore
//! let copy_stream = CudaStream::new(Device::CUDA(0), false)?;
//! {
//!     let _guard = StreamGuard::new(&copy_stream);
//!     // All CUDA ops here run on copy_stream instead of the default stream.
//!     let cpu_copy = gpu_tensor.to_device_async(Device::CPU)?;
//! }
//! // Default stream restored automatically.
//! ```

use std::ffi::c_void;
use std::ptr;

use flodl_sys as ffi;

use crate::tensor::{check_err, Device, Result, TensorError};
use super::cuda_event::CudaEvent;

/// A CUDA stream obtained from the libtorch stream pool.
///
/// RAII: the stream is returned to the pool on drop.
pub struct CudaStream {
    ptr: *mut c_void,
    device_index: i32,
}

// cudaStream_t is a device-global handle safe to reference from any thread.
unsafe impl Send for CudaStream {}

impl CudaStream {
    /// Create a new CUDA stream from the pool on the given device.
    ///
    /// `high_priority`: if true, uses a high-priority stream that preempts
    /// normal-priority work at SM boundaries.
    pub fn new(device: Device, high_priority: bool) -> Result<Self> {
        let device_index = match device {
            Device::CUDA(idx) => idx as i32,
            Device::CPU => {
                return Err(TensorError::new(
                    "CudaStream requires a CUDA device",
                ))
            }
        };
        let mut ptr: *mut c_void = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_cuda_stream_new(device_index, high_priority as i32, &mut ptr)
        };
        check_err(err)?;
        Ok(CudaStream { ptr, device_index })
    }

    /// Block the CPU thread until all work on this stream completes.
    pub fn synchronize(&self) -> Result<()> {
        let err = unsafe { ffi::flodl_cuda_stream_synchronize(self.ptr) };
        check_err(err)
    }

    /// Make this stream wait for a recorded event before executing any
    /// further work. Does not block the CPU.
    pub fn wait_event(&self, event: &CudaEvent) -> Result<()> {
        let err = unsafe {
            ffi::flodl_cuda_stream_wait_event(self.ptr, event.as_ptr())
        };
        check_err(err)
    }

    /// Non-blocking check: has all work on this stream completed?
    pub fn is_complete(&self) -> bool {
        unsafe { ffi::flodl_cuda_stream_query(self.ptr) != 0 }
    }

    /// The device this stream belongs to.
    pub fn device(&self) -> Device {
        Device::CUDA(self.device_index as u8)
    }

    /// Raw pointer for cross-module use (e.g., event.record_on).
    pub(crate) fn as_ptr(&self) -> *mut c_void {
        self.ptr
    }
}

impl Drop for CudaStream {
    fn drop(&mut self) {
        if !self.ptr.is_null() {
            unsafe { ffi::flodl_cuda_stream_delete(self.ptr) };
            self.ptr = ptr::null_mut();
        }
    }
}

/// RAII guard that sets a stream as the current CUDA stream and
/// restores the default stream on drop.
///
/// Analogous to [`crate::autograd::NoGradGuard`] but for CUDA streams.
///
/// ```ignore
/// let stream = CudaStream::new(Device::CUDA(0), false)?;
/// {
///     let _guard = StreamGuard::new(&stream);
///     // All CUDA ops execute on `stream`.
/// }
/// // Default stream restored.
/// ```
pub struct StreamGuard {
    device_index: i32,
}

impl StreamGuard {
    /// Set `stream` as the current CUDA stream. The default stream
    /// is restored when this guard is dropped.
    pub fn new(stream: &CudaStream) -> Self {
        unsafe { ffi::flodl_cuda_stream_set_current(stream.ptr) };
        StreamGuard {
            device_index: stream.device_index,
        }
    }
}

impl Drop for StreamGuard {
    fn drop(&mut self) {
        unsafe { ffi::flodl_cuda_stream_restore_default(self.device_index) };
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use super::super::cuda_event::CudaEventFlags;
    use crate::tensor::{Tensor, test_device, test_opts};

    use std::sync::Mutex;
    static STREAM_LOCK: Mutex<()> = Mutex::new(());

    #[test]
    fn test_cuda_stream_requires_cuda_device() {
        let result = CudaStream::new(Device::CPU, false);
        assert!(result.is_err(), "CudaStream::new(CPU) should fail");
    }

    #[test]
    fn test_cuda_stream_create_synchronize() {
        if !test_device().is_cuda() {
            return;
        }
        let _lock = STREAM_LOCK.lock().unwrap_or_else(|e| e.into_inner());

        let stream = CudaStream::new(test_device(), false).unwrap();
        assert_eq!(stream.device(), test_device());
        stream.synchronize().unwrap();
        assert!(stream.is_complete(), "empty stream should be complete");
    }

    #[test]
    fn test_stream_guard_restores_default() {
        if !test_device().is_cuda() {
            return;
        }
        let _lock = STREAM_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let opts = test_opts();

        let stream = CudaStream::new(test_device(), false).unwrap();
        {
            let _guard = StreamGuard::new(&stream);
            // Ops run on the non-default stream
            let _a = Tensor::randn(&[32, 32], opts).unwrap();
        }
        // Guard dropped — default stream restored.
        // Verify we can still do GPU ops normally.
        let b = Tensor::ones(&[4], opts).unwrap();
        let c = b.add(&b).unwrap();
        let vals = c.to_f32_vec().unwrap();
        assert!(vals.iter().all(|&v| (v - 2.0).abs() < 1e-5));
    }

    #[test]
    fn test_async_copy_on_stream() {
        if !test_device().is_cuda() {
            return;
        }
        let _lock = STREAM_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let opts = test_opts();

        // Create a GPU tensor with known values
        let gpu = Tensor::full(&[128], 42.0, opts).unwrap();

        // Create a non-default copy stream
        let copy_stream = CudaStream::new(test_device(), false).unwrap();

        // Record event on default stream to capture when gpu tensor is ready
        let ready = CudaEvent::new(CudaEventFlags::DisableTiming).unwrap();
        ready.record().unwrap();

        // Copy stream waits for the event, then copies
        copy_stream.wait_event(&ready).unwrap();
        let cpu_copy = {
            let _guard = StreamGuard::new(&copy_stream);
            gpu.to_device_async(Device::CPU).unwrap()
        };

        // Record completion on copy stream
        let done = CudaEvent::new(CudaEventFlags::DisableTiming).unwrap();
        done.record_on(&copy_stream).unwrap();
        done.synchronize().unwrap();

        let vals = cpu_copy.to_f32_vec().unwrap();
        assert_eq!(vals.len(), 128);
        assert!(vals.iter().all(|&v| (v - 42.0).abs() < 1e-5),
            "async copy should preserve values");
    }

    #[test]
    fn test_cross_stream_wait_event() {
        if !test_device().is_cuda() {
            return;
        }
        let _lock = STREAM_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let opts = test_opts();

        let stream_a = CudaStream::new(test_device(), false).unwrap();
        let stream_b = CudaStream::new(test_device(), false).unwrap();

        // On stream A: create a tensor
        let result = {
            let _guard = StreamGuard::new(&stream_a);
            Tensor::full(&[64], 7.0, opts).unwrap()
        };

        // Record event on stream A
        let event = CudaEvent::new(CudaEventFlags::DisableTiming).unwrap();
        event.record_on(&stream_a).unwrap();

        // Stream B waits for stream A, then reads the tensor
        stream_b.wait_event(&event).unwrap();
        let doubled = {
            let _guard = StreamGuard::new(&stream_b);
            result.add(&result).unwrap()
        };

        // Wait for stream B to finish
        stream_b.synchronize().unwrap();

        let vals = doubled.to_f32_vec().unwrap();
        assert!(vals.iter().all(|&v| (v - 14.0).abs() < 1e-5),
            "cross-stream result should be 14.0");
    }
}