//! CUDA stream pool with thread-local current stream and event wrappers.
//!
//! Provides multi-stream concurrency for overlapping compute and data transfers:
//!
//! - [`CudaEventWrapper`] — safe wrapper around cudarc's `CudaEvent` with record/sync/query.
//! - [`StreamPool`] — per-device pool of CUDA streams, created lazily, round-robin dispatch.
//! - [`get_current_stream`] / [`set_current_stream`] — thread-local "active" stream per device.
//! - [`StreamGuard`] — RAII guard that sets the current stream and restores the previous on drop.
//!
//! # Design
//!
//! Each device gets [`STREAMS_PER_DEVICE`] non-blocking streams created via
//! [`CudaContext::new_stream`]. The pool is initialized lazily on first access
//! using [`OnceLock`]. Streams are distributed round-robin via an atomic counter.
//!
//! The thread-local current stream allows callers to override which stream a
//! device operation targets without threading a stream parameter through every
//! function. [`StreamGuard`] makes this ergonomic and exception-safe.
19
20#[cfg(feature = "cuda")]
21use std::cell::RefCell;
22#[cfg(feature = "cuda")]
23use std::collections::HashMap;
24#[cfg(feature = "cuda")]
25use std::sync::atomic::{AtomicUsize, Ordering};
26#[cfg(feature = "cuda")]
27use std::sync::{Arc, OnceLock};
28
29#[cfg(feature = "cuda")]
30use cudarc::driver::{CudaContext, CudaEvent, CudaStream};
31
32use crate::error::{GpuError, GpuResult};
33
// ---------------------------------------------------------------------------
// Constants
// ---------------------------------------------------------------------------

/// Number of streams created per device in the pool.
///
/// Eight streams is enough to overlap several kernels with H2D/D2H copies;
/// the round-robin dispatch in `StreamPool::get_stream` cycles through them.
#[cfg(feature = "cuda")]
const STREAMS_PER_DEVICE: usize = 8;

/// Maximum supported device ordinal. Guards against unbounded allocation
/// if a caller passes a bogus ordinal.
///
/// Also fixes the length of the global pool's slot vector, which is why it
/// must be a compile-time constant rather than a queried device count.
#[cfg(feature = "cuda")]
const MAX_DEVICES: usize = 64;
46
// ---------------------------------------------------------------------------
// CudaEventWrapper — safe wrapper around cudarc's CudaEvent
// ---------------------------------------------------------------------------

/// Safe wrapper around a cudarc [`CudaEvent`].
///
/// Records a point in a stream's execution timeline and allows the host or
/// other streams to wait until that point is reached.
///
/// All methods return [`GpuResult`] rather than panicking on CUDA errors.
#[cfg(feature = "cuda")]
pub struct CudaEventWrapper {
    // Owned driver event; dropped (and destroyed) when the wrapper is dropped.
    inner: CudaEvent,
}
61
#[cfg(feature = "cuda")]
impl CudaEventWrapper {
    /// Create a new event associated with the given device's context.
    ///
    /// The event is created with `CU_EVENT_DISABLE_TIMING` (the cudarc default
    /// when `None` is passed for flags). Use [`Self::new_with_timing`] if you
    /// need elapsed-time queries.
    pub fn new(ctx: &Arc<CudaContext>) -> GpuResult<Self> {
        Ok(Self {
            inner: ctx.new_event(None)?,
        })
    }

    /// Create a new event with timing enabled.
    ///
    /// Required if you want to call [`CudaEvent::elapsed_ms`]. Timing events
    /// carry slightly more overhead than non-timing events.
    pub fn new_with_timing(ctx: &Arc<CudaContext>) -> GpuResult<Self> {
        // CU_EVENT_DEFAULT leaves timing enabled (timing is only disabled by
        // the explicit CU_EVENT_DISABLE_TIMING flag).
        let inner =
            ctx.new_event(Some(cudarc::driver::sys::CUevent_flags::CU_EVENT_DEFAULT))?;
        Ok(Self { inner })
    }

    /// Record the current point in `stream`'s execution into this event.
    ///
    /// After recording, [`synchronize`](Self::synchronize) will block until all
    /// work submitted to `stream` before this call has completed.
    ///
    /// # Errors
    ///
    /// Returns `Err` if the stream belongs to a different CUDA context than
    /// the event, or if the CUDA driver reports an error.
    pub fn record(&self, stream: &CudaStream) -> GpuResult<()> {
        self.inner.record(stream).map_err(Into::into)
    }

    /// Block the calling CPU thread until all work recorded in this event
    /// has completed on the GPU.
    ///
    /// # Errors
    ///
    /// Returns `Err` if the CUDA driver reports an error (e.g., a previous
    /// async kernel launch failed).
    pub fn synchronize(&self) -> GpuResult<()> {
        self.inner.synchronize().map_err(Into::into)
    }

    /// Query whether all work recorded in this event has completed.
    ///
    /// Returns `Ok(true)` if complete, `Ok(false)` if still in progress.
    /// This is a non-blocking check.
    pub fn query(&self) -> GpuResult<bool> {
        let done = self.inner.is_complete();
        Ok(done)
    }

    /// Make `stream` wait for all work recorded in this event to complete
    /// before executing any subsequent operations.
    ///
    /// This is a GPU-side wait — it does not block the CPU.
    ///
    /// # Errors
    ///
    /// Returns `Err` if the stream and event belong to different CUDA contexts.
    pub fn wait_on(&self, stream: &CudaStream) -> GpuResult<()> {
        stream.wait(&self.inner).map_err(Into::into)
    }

    /// Borrow the underlying cudarc [`CudaEvent`].
    #[inline]
    pub fn inner(&self) -> &CudaEvent {
        &self.inner
    }
}
137
#[cfg(feature = "cuda")]
impl std::fmt::Debug for CudaEventWrapper {
    /// Opaque debug representation: the raw event handle is not exposed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("CudaEventWrapper");
        builder.finish_non_exhaustive()
    }
}
144
// ---------------------------------------------------------------------------
// StreamPool — per-device pool of CUDA streams
// ---------------------------------------------------------------------------

/// Internal per-device state: the created streams plus the round-robin cursor.
///
/// One `DeviceStreams` is lazily built per device ordinal the first time
/// `StreamPool::get_stream` is called for that ordinal; it then lives for the
/// remainder of the process (the pool never shrinks or rebuilds).
#[cfg(feature = "cuda")]
struct DeviceStreams {
    // Up to STREAMS_PER_DEVICE streams; may be fewer if creation partially
    // failed during initialization (see StreamPool::get_stream).
    streams: Vec<Arc<CudaStream>>,
    // Monotonically increasing dispatch counter; taken modulo streams.len()
    // to pick the next stream round-robin.
    counter: AtomicUsize,
}
161
/// Global stream pool. Each entry is lazily initialized via `OnceLock`.
///
/// We use a fixed-size vector of `OnceLock` slots (length [`MAX_DEVICES`])
/// rather than a locked `HashMap` to avoid taking a lock on the hot path:
/// after first init, reads are lock-free. The index is the device ordinal.
#[cfg(feature = "cuda")]
static STREAM_POOL: OnceLock<Vec<OnceLock<DeviceStreams>>> = OnceLock::new();
168
/// Return the pool's slot vector, building it on first use.
///
/// The outer `OnceLock` guarantees the vector of [`MAX_DEVICES`] empty slots
/// is constructed exactly once, process-wide.
#[cfg(feature = "cuda")]
fn pool_slots() -> &'static Vec<OnceLock<DeviceStreams>> {
    STREAM_POOL.get_or_init(|| {
        std::iter::repeat_with(OnceLock::new)
            .take(MAX_DEVICES)
            .collect()
    })
}
176
/// Public interface for the CUDA stream pool.
///
/// Deliberately not feature-gated: the type exists with or without the `cuda`
/// feature so that both the real `impl` (below) and the no-CUDA stub `impl`
/// (at the bottom of this file) can attach to the same name.
pub struct StreamPool;
179
#[cfg(feature = "cuda")]
impl StreamPool {
    /// Get a stream for the given device, round-robin across the pool.
    ///
    /// On first call for a device ordinal, lazily creates [`STREAMS_PER_DEVICE`]
    /// non-blocking streams from the device's CUDA context.
    ///
    /// # Arguments
    ///
    /// * `ctx` — The CUDA context for the target device. Must match the
    ///   ordinal (callers are responsible for passing the correct context).
    /// * `device_ordinal` — The GPU device index (0-based).
    ///
    /// # Errors
    ///
    /// - Returns [`GpuError::InvalidDevice`] if `device_ordinal >= MAX_DEVICES`.
    /// - Returns a CUDA driver error if stream creation fails.
    ///
    /// NOTE(review): the `ctx` passed by the *first* caller for an ordinal is
    /// captured by the one-time init below; later callers' contexts are
    /// ignored for that ordinal. If a caller ever passes a context for a
    /// different device than `device_ordinal`, the mismatch is cached forever
    /// — verify all call sites pair ctx/ordinal correctly.
    pub fn get_stream(
        ctx: &Arc<CudaContext>,
        device_ordinal: usize,
    ) -> GpuResult<Arc<CudaStream>> {
        // Bounds check before indexing into the fixed-size slot vector.
        if device_ordinal >= MAX_DEVICES {
            return Err(GpuError::InvalidDevice {
                ordinal: device_ordinal,
                count: MAX_DEVICES,
            });
        }

        let slots = pool_slots();
        // get_or_init runs at most once per ordinal; concurrent first callers
        // block until one of them finishes initializing.
        let device_streams = slots[device_ordinal].get_or_init(|| {
            // We create the streams eagerly within this device's OnceLock init.
            // If any stream creation fails, we store what we got (at least 1).
            // NOTE(review): a transient driver error here permanently caches a
            // smaller pool for this ordinal — there is no retry path.
            let mut streams = Vec::with_capacity(STREAMS_PER_DEVICE);
            for _ in 0..STREAMS_PER_DEVICE {
                match ctx.new_stream() {
                    Ok(s) => streams.push(s),
                    Err(_) => break,
                }
            }
            // If we got zero streams, push a fallback: fork from default stream.
            if streams.is_empty() {
                if let Ok(s) = ctx.default_stream().fork() {
                    streams.push(s);
                }
            }
            DeviceStreams {
                streams,
                counter: AtomicUsize::new(0),
            }
        });

        // Both creation paths failed during init: surface an error rather than
        // panicking on the modulo-by-zero below.
        // NOTE(review): the real driver error was discarded above, so we
        // fabricate OUT_OF_MEMORY here — the most common cause, but not
        // necessarily the actual one. Consider capturing the original error.
        if device_streams.streams.is_empty() {
            return Err(GpuError::Driver(cudarc::driver::DriverError(
                cudarc::driver::sys::cudaError_enum::CUDA_ERROR_OUT_OF_MEMORY,
            )));
        }

        // Relaxed is sufficient: the counter only balances dispatch; no other
        // memory ordering depends on it. (Wraparound at usize::MAX causes one
        // uneven step, which is harmless.)
        let idx = device_streams.counter.fetch_add(1, Ordering::Relaxed)
            % device_streams.streams.len();
        Ok(Arc::clone(&device_streams.streams[idx]))
    }

    /// Return the number of streams currently in the pool for a device.
    /// Returns 0 if the device has not been initialized yet (or the ordinal
    /// is out of range).
    pub fn pool_size(device_ordinal: usize) -> usize {
        if device_ordinal >= MAX_DEVICES {
            return 0;
        }
        let slots = pool_slots();
        slots[device_ordinal]
            .get()
            .map(|ds| ds.streams.len())
            .unwrap_or(0)
    }
}
255
// ---------------------------------------------------------------------------
// Thread-local current stream
// ---------------------------------------------------------------------------

#[cfg(feature = "cuda")]
thread_local! {
    /// Per-thread map from device ordinal to the "current" stream for that device.
    /// When set, GPU operations on that device should use this stream instead of
    /// the device's default stream.
    ///
    /// Thread-local by design: two threads can target different streams on the
    /// same device without synchronization (RefCell, not Mutex).
    static CURRENT_STREAMS: RefCell<HashMap<usize, Arc<CudaStream>>> =
        RefCell::new(HashMap::new());
}
268
/// Look up the current thread-local stream for the given device.
///
/// Returns `None` if no stream has been set for this device on the calling
/// thread; callers should then fall back to the device's default stream.
#[cfg(feature = "cuda")]
pub fn get_current_stream(device: usize) -> Option<Arc<CudaStream>> {
    CURRENT_STREAMS.with(|streams| {
        let map = streams.borrow();
        map.get(&device).map(Arc::clone)
    })
}
277
/// Install `stream` as the current thread-local stream for the given device.
///
/// After this call, [`get_current_stream`] returns `Some(stream)` for this
/// device on the calling thread until it is changed or cleared. Any
/// previously-set stream for the device is replaced (and its `Arc` dropped).
#[cfg(feature = "cuda")]
pub fn set_current_stream(device: usize, stream: Arc<CudaStream>) {
    CURRENT_STREAMS.with(|streams| {
        // insert returns the displaced entry, if any; we intentionally drop it.
        let _previous = streams.borrow_mut().insert(device, stream);
    });
}
288
/// Remove the current thread-local stream for the given device, reverting
/// to the device's default stream.
///
/// A no-op if no stream was set for this device on the calling thread.
#[cfg(feature = "cuda")]
pub fn clear_current_stream(device: usize) {
    CURRENT_STREAMS.with(|streams| {
        // remove returns the evicted stream, if any; dropping it here releases
        // this thread's reference.
        let _ = streams.borrow_mut().remove(&device);
    });
}
297
/// Get the current stream for a device, falling back to the device's default
/// stream if none has been set on this thread.
///
/// This is the primary entry point for operations that need "the stream to use."
///
/// The fallback clones the `Arc` returned by `GpuDevice::default_stream`
/// (cheap refcount bump, no stream creation).
#[cfg(feature = "cuda")]
pub fn current_stream_or_default(device: &crate::device::GpuDevice) -> Arc<CudaStream> {
    get_current_stream(device.ordinal())
        .unwrap_or_else(|| Arc::clone(device.default_stream()))
}
307
// ---------------------------------------------------------------------------
// StreamGuard — RAII guard for thread-local current stream
// ---------------------------------------------------------------------------

/// RAII guard that sets the thread-local current stream on construction and
/// restores the previous stream (or clears it) on drop.
///
/// # Example
///
/// ```ignore
/// use ferrotorch_gpu::stream::{StreamGuard, StreamPool};
///
/// let stream = StreamPool::get_stream(&ctx, 0)?;
/// {
///     let _guard = StreamGuard::new(0, stream);
///     // All operations on device 0 in this scope use `stream`.
///     // ...
/// }
/// // Previous stream (or default) is restored here.
/// ```
#[cfg(feature = "cuda")]
pub struct StreamGuard {
    // Device ordinal whose thread-local entry this guard manages.
    device: usize,
    // Stream that was current when the guard was created; None means there
    // was no current stream, so drop clears the entry instead of restoring.
    previous: Option<Arc<CudaStream>>,
}
333
#[cfg(feature = "cuda")]
impl StreamGuard {
    /// Make `stream` the current stream for `device` on this thread.
    ///
    /// Whatever stream was current beforehand (possibly none) is saved and
    /// will be reinstated when the guard goes out of scope.
    pub fn new(device: usize, stream: Arc<CudaStream>) -> Self {
        // Snapshot the existing thread-local entry *before* overwriting it.
        let saved = get_current_stream(device);
        set_current_stream(device, stream);
        Self {
            device,
            previous: saved,
        }
    }
}
346
#[cfg(feature = "cuda")]
impl Drop for StreamGuard {
    /// Reinstate the stream that was current when the guard was created,
    /// or clear the entry entirely if there was none.
    fn drop(&mut self) {
        if let Some(prev) = self.previous.take() {
            set_current_stream(self.device, prev);
        } else {
            clear_current_stream(self.device);
        }
    }
}
356
#[cfg(feature = "cuda")]
impl std::fmt::Debug for StreamGuard {
    /// Shows the managed device and whether a previous stream will be
    /// restored on drop; the stream handles themselves are opaque.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("StreamGuard");
        builder.field("device", &self.device);
        builder.field("has_previous", &self.previous.is_some());
        builder.finish()
    }
}
366
// ---------------------------------------------------------------------------
// Stubs when `cuda` feature is disabled
// ---------------------------------------------------------------------------

/// Stub `CudaEventWrapper` when the `cuda` feature is not enabled.
///
/// A unit struct with no methods: it exists only so code can name the type
/// without the feature. Construction/recording APIs are unavailable.
#[cfg(not(feature = "cuda"))]
#[derive(Debug)]
pub struct CudaEventWrapper;
375
#[cfg(not(feature = "cuda"))]
impl StreamPool {
    /// Always returns an error — compile with `features = ["cuda"]`.
    ///
    /// NOTE(review): this stub's signature (no `ctx` parameter, `Ok` type `()`)
    /// differs from the CUDA version's `get_stream(&Arc<CudaContext>, usize)`,
    /// so feature-generic call sites cannot compile against both — confirm
    /// this asymmetry is intentional.
    pub fn get_stream(_device_ordinal: usize) -> GpuResult<()> {
        Err(GpuError::NoCudaFeature)
    }

    /// Returns 0 — no streams without CUDA.
    pub fn pool_size(_device_ordinal: usize) -> usize {
        0
    }
}
388
/// Stub `StreamGuard` when the `cuda` feature is not enabled.
///
/// Setting/restoring a current stream is meaningless without CUDA, so the
/// guard is a no-op unit struct.
#[cfg(not(feature = "cuda"))]
#[derive(Debug)]
pub struct StreamGuard;

#[cfg(not(feature = "cuda"))]
impl StreamGuard {
    /// No-op constructor mirroring the CUDA version's `StreamGuard::new`.
    ///
    /// The stream type is `()` to match the stub `set_current_stream`
    /// signature. Without this, downstream code that calls `StreamGuard::new`
    /// fails to compile when the `cuda` feature is disabled.
    pub fn new(_device: usize, _stream: ()) -> Self {
        Self
    }
}
393
/// Stub — returns `None` without CUDA.
///
/// The `Option` payload is `()` because no stream type exists without cudarc.
#[cfg(not(feature = "cuda"))]
pub fn get_current_stream(_device: usize) -> Option<()> {
    None
}

/// Stub — no-op without CUDA. Accepts `()` where the CUDA version takes
/// an `Arc<CudaStream>`.
#[cfg(not(feature = "cuda"))]
pub fn set_current_stream(_device: usize, _stream: ()) {}

/// Stub — no-op without CUDA.
#[cfg(not(feature = "cuda"))]
pub fn clear_current_stream(_device: usize) {}
407
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

// NOTE(review): these tests run in one process and therefore share the global
// STREAM_POOL; they stay correct only while `stream_pool_round_robin` is the
// sole test that successfully calls get_stream for device 0. Thread-local
// CURRENT_STREAMS state is naturally isolated per test thread.
#[cfg(all(test, feature = "cuda"))]
mod tests {
    use super::*;
    use cudarc::driver::CudaContext;

    /// Helper: create a context for device 0. Skips the test if no GPU.
    fn test_ctx() -> Option<Arc<CudaContext>> {
        CudaContext::new(0).ok()
    }

    /// Record + host-sync + query round trip on an idle stream.
    #[test]
    fn event_record_sync() {
        let Some(ctx) = test_ctx() else { return };
        let stream = ctx.default_stream();

        let event = CudaEventWrapper::new(&ctx)
            .expect("event creation should succeed");

        // Record on the default stream (which has no pending work).
        event.record(&stream).expect("record should succeed");

        // Synchronize should complete immediately (no work queued).
        event.synchronize().expect("synchronize should succeed");

        // Query should return true — all work is done.
        assert!(
            event.query().expect("query should succeed"),
            "event should be complete after synchronize"
        );
    }

    /// Querying an event that was never recorded must report "complete".
    #[test]
    fn event_query_before_record() {
        let Some(ctx) = test_ctx() else { return };

        let event = CudaEventWrapper::new(&ctx)
            .expect("event creation should succeed");

        // A freshly created event with no work recorded. Per CUDA semantics,
        // cuEventQuery on an event that has never been recorded returns
        // CUDA_SUCCESS (it is considered "complete"). cudarc's is_complete()
        // wraps this.
        let complete = event.query().expect("query should not error");
        // The event has no recorded work, so it reports complete.
        assert!(complete, "unrecorded event should report complete");
    }

    /// Full round-robin cycle returns to the first stream.
    #[test]
    fn stream_pool_round_robin() {
        let Some(ctx) = test_ctx() else { return };
        // Device 0. Only this test successfully calls get_stream, so the
        // pool's round-robin counter for device 0 starts at zero here and
        // the wrap-around position below is predictable.
        let dev = 0;

        let s1 = StreamPool::get_stream(&ctx, dev)
            .expect("first get_stream should succeed");
        let s2 = StreamPool::get_stream(&ctx, dev)
            .expect("second get_stream should succeed");

        // After STREAMS_PER_DEVICE calls, we should wrap around.
        // pool_size may be < STREAMS_PER_DEVICE if some creations failed.
        let pool_size = StreamPool::pool_size(dev);
        assert!(pool_size > 0, "pool should have streams");
        assert!(pool_size <= STREAMS_PER_DEVICE, "pool should not exceed configured size");

        // Collect all streams from a full cycle.
        let mut streams = vec![s1, s2];
        for _ in 2..pool_size {
            streams.push(
                StreamPool::get_stream(&ctx, dev).expect("get_stream should succeed"),
            );
        }

        // The next stream should wrap around to the same as the first.
        let wrap = StreamPool::get_stream(&ctx, dev)
            .expect("wrapped get_stream should succeed");

        // Because round-robin, `wrap` should be the same Arc as `streams[0]`.
        // Arc::as_ptr compares identity of the pooled allocation, since
        // get_stream hands out clones of the pooled Arcs.
        assert_eq!(
            Arc::as_ptr(&wrap),
            Arc::as_ptr(&streams[0]),
            "round-robin should wrap back to the first stream"
        );
    }

    /// Ordinals at or beyond MAX_DEVICES are rejected, not allocated.
    #[test]
    fn stream_pool_invalid_device() {
        let Some(ctx) = test_ctx() else { return };
        let result = StreamPool::get_stream(&ctx, MAX_DEVICES + 1);
        assert!(result.is_err(), "should reject ordinal >= MAX_DEVICES");
    }

    /// Guard restores the previously-set current stream on drop.
    #[test]
    fn stream_guard_restores_previous() {
        let Some(ctx) = test_ctx() else { return };
        let dev = 0;

        // Initially, no current stream (thread-local state is fresh per
        // test thread).
        assert!(
            get_current_stream(dev).is_none(),
            "should start with no current stream"
        );

        let s1 = ctx.new_stream().expect("new_stream should succeed");
        let s2 = ctx.new_stream().expect("new_stream should succeed");

        let s1_ptr = Arc::as_ptr(&s1);
        let s2_ptr = Arc::as_ptr(&s2);

        // Set s1 as current.
        set_current_stream(dev, Arc::clone(&s1));
        assert_eq!(
            Arc::as_ptr(&get_current_stream(dev).unwrap()),
            s1_ptr,
            "current stream should be s1"
        );

        // Create a guard that sets s2.
        {
            let _guard = StreamGuard::new(dev, Arc::clone(&s2));
            assert_eq!(
                Arc::as_ptr(&get_current_stream(dev).unwrap()),
                s2_ptr,
                "current stream should be s2 inside guard"
            );
        }

        // After guard drop, s1 should be restored.
        assert_eq!(
            Arc::as_ptr(&get_current_stream(dev).unwrap()),
            s1_ptr,
            "current stream should be restored to s1 after guard drop"
        );

        // Clean up.
        clear_current_stream(dev);
        assert!(
            get_current_stream(dev).is_none(),
            "should be cleared after explicit clear"
        );
    }

    /// Guard created with no previous stream clears the entry on drop.
    #[test]
    fn stream_guard_clears_when_no_previous() {
        let Some(ctx) = test_ctx() else { return };
        let dev = 0;

        // Ensure no current stream.
        clear_current_stream(dev);
        assert!(get_current_stream(dev).is_none());

        let s1 = ctx.new_stream().expect("new_stream should succeed");

        {
            let _guard = StreamGuard::new(dev, Arc::clone(&s1));
            assert!(
                get_current_stream(dev).is_some(),
                "guard should set current stream"
            );
        }

        // Guard had no previous — should clear.
        assert!(
            get_current_stream(dev).is_none(),
            "guard with no previous should clear current stream on drop"
        );
    }

    /// current_stream_or_default: default-stream fallback and override.
    #[test]
    fn current_stream_or_default_fallback() {
        // We can't easily construct a GpuDevice in tests without a real GPU
        // context, but we can test the thread-local logic in isolation.
        let Some(ctx) = test_ctx() else { return };
        let dev_ordinal = 0;

        // Clear any leftover state.
        clear_current_stream(dev_ordinal);

        let device = crate::device::GpuDevice::new(dev_ordinal)
            .expect("GpuDevice::new should succeed");
        let default_ptr = Arc::as_ptr(device.default_stream());

        // No current stream set — should fall back to device default.
        let stream = current_stream_or_default(&device);
        assert_eq!(
            Arc::as_ptr(&stream),
            default_ptr,
            "should fall back to device default stream"
        );

        // Set a custom stream — should use it instead.
        let custom = ctx.new_stream().expect("new_stream should succeed");
        let custom_ptr = Arc::as_ptr(&custom);
        set_current_stream(dev_ordinal, custom);

        let stream = current_stream_or_default(&device);
        assert_eq!(
            Arc::as_ptr(&stream),
            custom_ptr,
            "should use thread-local current stream"
        );

        // Clean up.
        clear_current_stream(dev_ordinal);
    }

    /// GPU-side cross-stream wait via an event.
    #[test]
    fn event_wait_on_stream() {
        let Some(ctx) = test_ctx() else { return };
        let stream1 = ctx.default_stream();
        let stream2 = ctx.new_stream().expect("new_stream should succeed");

        let event = CudaEventWrapper::new(&ctx)
            .expect("event creation should succeed");

        // Record on stream1.
        event.record(&stream1).expect("record should succeed");

        // Make stream2 wait on the event (GPU-side sync).
        event.wait_on(&stream2).expect("wait_on should succeed");

        // Synchronize stream2 — this implicitly waits for stream1's work too.
        stream2.synchronize().expect("synchronize should succeed");
    }
}
635}