// ferrotorch_gpu/stream.rs
1//! CUDA stream pool with thread-local current stream and event wrappers.
2//!
3//! Provides multi-stream concurrency for overlapping compute and data transfers:
4//!
5//! - [`CudaEventWrapper`] — safe wrapper around cudarc's `CudaEvent` with record/sync/query.
6//! - [`StreamPool`] — per-device pool of CUDA streams, created lazily, round-robin dispatch.
7//! - [`get_current_stream`] / [`set_current_stream`] — thread-local "active" stream per device.
8//! - [`StreamGuard`] — RAII guard that sets the current stream and restores the previous on drop.
9//!
10//! # Design
11//!
12//! Each device gets [`STREAMS_PER_DEVICE`] non-blocking streams created via
13//! [`CudaContext::new_stream`]. The pool is initialized lazily on first access
14//! using [`OnceLock`]. Streams are distributed round-robin via an atomic counter.
15//!
16//! The thread-local current stream allows callers to override which stream a
17//! device operation targets without threading a stream parameter through every
//! function. [`StreamGuard`] makes this ergonomic and panic-safe (the previous
//! stream is restored on drop, including during unwinding).
19
20#[cfg(feature = "cuda")]
21use std::cell::RefCell;
22#[cfg(feature = "cuda")]
23use std::collections::HashMap;
24#[cfg(feature = "cuda")]
25use std::sync::atomic::{AtomicUsize, Ordering};
26#[cfg(feature = "cuda")]
27use std::sync::{Arc, OnceLock};
28
29#[cfg(feature = "cuda")]
30use cudarc::driver::{CudaContext, CudaEvent, CudaStream};
31
32use crate::error::{GpuError, GpuResult};
33
34// ---------------------------------------------------------------------------
35// Constants
36// ---------------------------------------------------------------------------
37
/// Number of streams created per device in the pool.
///
/// Sizes the `streams` vector built lazily for each device slot in
/// [`StreamPool::get_stream`].
#[cfg(feature = "cuda")]
const STREAMS_PER_DEVICE: usize = 8;

/// Maximum supported device ordinal. Guards against unbounded allocation
/// if a caller passes a bogus ordinal.
///
/// Also fixes the length of the global [`STREAM_POOL`] slot vector.
#[cfg(feature = "cuda")]
const MAX_DEVICES: usize = 64;
46
47// ---------------------------------------------------------------------------
48// CudaEventWrapper — safe wrapper around cudarc's CudaEvent
49// ---------------------------------------------------------------------------
50
/// Safe wrapper around a cudarc [`CudaEvent`].
///
/// Records a point in a stream's execution timeline and allows the host or
/// other streams to wait until that point is reached.
///
/// All methods return [`GpuResult`] rather than panicking on CUDA errors.
#[cfg(feature = "cuda")]
pub struct CudaEventWrapper {
    // The underlying cudarc event handle; exposed read-only via `inner()`.
    inner: CudaEvent,
}
61
#[cfg(feature = "cuda")]
impl CudaEventWrapper {
    /// Build an event bound to `ctx`'s device.
    ///
    /// Timing is disabled (cudarc's default when `None` flags are passed);
    /// use [`Self::new_with_timing`] if elapsed-time queries are needed.
    pub fn new(ctx: &Arc<CudaContext>) -> GpuResult<Self> {
        Ok(Self {
            inner: ctx.new_event(None)?,
        })
    }

    /// Build an event with timing enabled (`CU_EVENT_DEFAULT` flags).
    ///
    /// Required for elapsed-time queries between recorded events; timing
    /// events carry slightly more overhead than non-timing ones.
    pub fn new_with_timing(ctx: &Arc<CudaContext>) -> GpuResult<Self> {
        let inner = ctx
            .new_event(Some(cudarc::driver::sys::CUevent_flags::CU_EVENT_DEFAULT))?;
        Ok(Self { inner })
    }

    /// Capture the current position of `stream`'s execution timeline in
    /// this event.
    ///
    /// Once recorded, [`Self::synchronize`] blocks until everything queued
    /// on `stream` before this call has finished.
    ///
    /// # Errors
    ///
    /// Fails if the stream belongs to a different CUDA context than the
    /// event, or if the CUDA driver reports an error.
    pub fn record(&self, stream: &CudaStream) -> GpuResult<()> {
        Ok(self.inner.record(stream)?)
    }

    /// Block the calling CPU thread until all work captured by this event
    /// has completed on the GPU.
    ///
    /// # Errors
    ///
    /// Propagates any driver error (e.g. a previously launched async kernel
    /// having failed).
    pub fn synchronize(&self) -> GpuResult<()> {
        Ok(self.inner.synchronize()?)
    }

    /// Non-blocking completion check.
    ///
    /// Returns `Ok(true)` once all captured work has finished and
    /// `Ok(false)` while it is still in progress.
    pub fn query(&self) -> GpuResult<bool> {
        let done = self.inner.is_complete();
        Ok(done)
    }

    /// Queue a GPU-side wait on `stream`: later work submitted to `stream`
    /// will not start until this event's captured work completes.
    ///
    /// The CPU is never blocked by this call.
    ///
    /// # Errors
    ///
    /// Fails if the stream and event belong to different CUDA contexts.
    pub fn wait_on(&self, stream: &CudaStream) -> GpuResult<()> {
        Ok(stream.wait(&self.inner)?)
    }

    /// Access the wrapped cudarc [`CudaEvent`].
    #[inline]
    pub fn inner(&self) -> &CudaEvent {
        &self.inner
    }
}
137
#[cfg(feature = "cuda")]
impl std::fmt::Debug for CudaEventWrapper {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The raw driver handle carries no printable state; show only the
        // type name with a non-exhaustive marker.
        let mut builder = f.debug_struct("CudaEventWrapper");
        builder.finish_non_exhaustive()
    }
}
144
145// ---------------------------------------------------------------------------
146// StreamPool — per-device pool of CUDA streams
147// ---------------------------------------------------------------------------
148
/// Internal per-device entry of the global stream pool: the device's
/// streams plus the round-robin cursor used to hand them out.
///
/// Built lazily by [`StreamPool::get_stream`] on first access for a device
/// ordinal. Normally holds [`STREAMS_PER_DEVICE`] streams, but may hold
/// fewer (or zero) if stream creation partially failed during init.
#[cfg(feature = "cuda")]
struct DeviceStreams {
    // Streams dispatched round-robin by `StreamPool::get_stream`.
    streams: Vec<Arc<CudaStream>>,
    // Monotonic dispatch counter; chosen index = counter % streams.len().
    counter: AtomicUsize,
}
161
/// Global stream pool. Each entry is lazily initialized via `OnceLock`.
///
/// The outer `OnceLock` holds a `Vec` of [`MAX_DEVICES`] inner `OnceLock`
/// slots (index = device ordinal), so the hot path is two `OnceLock::get`
/// reads with no `HashMap` hashing or locking.
#[cfg(feature = "cuda")]
static STREAM_POOL: OnceLock<Vec<OnceLock<DeviceStreams>>> = OnceLock::new();
168
/// Return the global slot vector, building the `MAX_DEVICES` empty
/// `OnceLock` slots on the very first call.
#[cfg(feature = "cuda")]
fn pool_slots() -> &'static Vec<OnceLock<DeviceStreams>> {
    STREAM_POOL.get_or_init(|| {
        std::iter::repeat_with(OnceLock::new)
            .take(MAX_DEVICES)
            .collect()
    })
}
176
/// Public interface for the CUDA stream pool.
///
/// A unit struct used purely as a namespace; all pool state lives in
/// module-level statics. Declared unconditionally so both the CUDA and
/// non-CUDA builds can attach an `impl` to it.
pub struct StreamPool;
179
#[cfg(feature = "cuda")]
impl StreamPool {
    /// Get a stream for the given device, round-robin across the pool.
    ///
    /// On first call for a device ordinal, lazily creates [`STREAMS_PER_DEVICE`]
    /// non-blocking streams from the device's CUDA context.
    ///
    /// # Arguments
    ///
    /// * `ctx` — The CUDA context for the target device. Must match the
    ///   ordinal (callers are responsible for passing the correct context).
    /// * `device_ordinal` — The GPU device index (0-based).
    ///
    /// # Errors
    ///
    /// - Returns [`GpuError::InvalidDevice`] if `device_ordinal >= MAX_DEVICES`.
    /// - Returns a CUDA driver error if stream creation fails.
    pub fn get_stream(
        ctx: &Arc<CudaContext>,
        device_ordinal: usize,
    ) -> GpuResult<Arc<CudaStream>> {
        // Reject bogus ordinals before touching the (bounded) slot vector.
        if device_ordinal >= MAX_DEVICES {
            return Err(GpuError::InvalidDevice {
                ordinal: device_ordinal,
                count: MAX_DEVICES,
            });
        }

        let slots = pool_slots();
        let device_streams = slots[device_ordinal].get_or_init(|| {
            // Streams are created eagerly inside this device's OnceLock init.
            // On the first creation failure we stop and keep whatever was
            // created so far — possibly zero streams.
            let mut streams = Vec::with_capacity(STREAMS_PER_DEVICE);
            for _ in 0..STREAMS_PER_DEVICE {
                match ctx.new_stream() {
                    Ok(s) => streams.push(s),
                    Err(_) => break,
                }
            }
            // If we got zero streams, push a fallback: fork from default stream.
            if streams.is_empty() {
                if let Ok(s) = ctx.default_stream().fork() {
                    streams.push(s);
                }
            }
            // NOTE(review): if even the fallback fails, this slot is cached
            // empty forever (a OnceLock initializes exactly once), so every
            // later call for this ordinal hits the error branch below with
            // no retry. Confirm this sticky-failure behavior is intended.
            DeviceStreams {
                streams,
                counter: AtomicUsize::new(0),
            }
        });

        if device_streams.streams.is_empty() {
            // Surface the empty pool as a driver OOM error — presumably the
            // most likely cause of stream creation failing; the original
            // failure reason is not preserved.
            return Err(GpuError::Driver(cudarc::driver::DriverError(
                cudarc::driver::sys::cudaError_enum::CUDA_ERROR_OUT_OF_MEMORY,
            )));
        }

        // Relaxed suffices: the counter only balances load across streams
        // and does not order any other memory accesses. Wrap-around on
        // overflow is harmless for the modulo below.
        let idx = device_streams.counter.fetch_add(1, Ordering::Relaxed)
            % device_streams.streams.len();
        Ok(Arc::clone(&device_streams.streams[idx]))
    }

    /// Return the number of streams currently in the pool for a device.
    /// Returns 0 if the device has not been initialized yet.
    pub fn pool_size(device_ordinal: usize) -> usize {
        if device_ordinal >= MAX_DEVICES {
            return 0;
        }
        let slots = pool_slots();
        slots[device_ordinal]
            .get()
            .map(|ds| ds.streams.len())
            .unwrap_or(0)
    }
}
255
256// ---------------------------------------------------------------------------
257// Thread-local current stream
258// ---------------------------------------------------------------------------
259
#[cfg(feature = "cuda")]
thread_local! {
    /// Per-thread map from device ordinal to the "current" stream for that device.
    /// When set, GPU operations on that device should use this stream instead of
    /// the device's default stream.
    ///
    /// `RefCell` is sound here: every accessor below confines its borrow to
    /// a short closure and never calls back into these helpers while the
    /// borrow is live.
    static CURRENT_STREAMS: RefCell<HashMap<usize, Arc<CudaStream>>> =
        RefCell::new(HashMap::new());
}
268
/// Look up the thread-local current stream for `device`.
///
/// `None` means no stream has been set for this device on the calling
/// thread; callers should then fall back to the device's default stream.
#[cfg(feature = "cuda")]
pub fn get_current_stream(device: usize) -> Option<Arc<CudaStream>> {
    CURRENT_STREAMS.with(|slots| {
        let slots = slots.borrow();
        slots.get(&device).map(Arc::clone)
    })
}
277
/// Install `stream` as the thread-local current stream for `device`.
///
/// Subsequent [`get_current_stream`] calls on this thread return
/// `Some(stream)` for this device until it is replaced or cleared.
#[cfg(feature = "cuda")]
pub fn set_current_stream(device: usize, stream: Arc<CudaStream>) {
    CURRENT_STREAMS.with(|slots| {
        // Any previously-set stream for this device is replaced (and its
        // Arc reference dropped).
        let _ = slots.borrow_mut().insert(device, stream);
    });
}
288
/// Remove the thread-local current stream for `device`, so callers revert
/// to the device's default stream.
#[cfg(feature = "cuda")]
pub fn clear_current_stream(device: usize) {
    CURRENT_STREAMS.with(|slots| {
        let _ = slots.borrow_mut().remove(&device);
    });
}
297
/// Resolve the stream a device operation should use: the thread-local
/// current stream if one is set, otherwise the device's default stream.
///
/// This is the primary entry point for operations that need "the stream to use."
#[cfg(feature = "cuda")]
pub fn current_stream_or_default(device: &crate::device::GpuDevice) -> Arc<CudaStream> {
    match get_current_stream(device.ordinal()) {
        Some(stream) => stream,
        None => Arc::clone(device.default_stream()),
    }
}
307
308// ---------------------------------------------------------------------------
309// StreamGuard — RAII guard for thread-local current stream
310// ---------------------------------------------------------------------------
311
/// RAII guard that sets the thread-local current stream on construction and
/// restores the previous stream (or clears it) on drop.
///
/// # Example
///
/// ```ignore
/// use ferrotorch_gpu::stream::{StreamGuard, StreamPool};
///
/// let stream = StreamPool::get_stream(&ctx, 0)?;
/// {
///     let _guard = StreamGuard::new(0, stream);
///     // All operations on device 0 in this scope use `stream`.
///     // ...
/// }
/// // Previous stream (or default) is restored here.
/// ```
#[cfg(feature = "cuda")]
pub struct StreamGuard {
    // Device ordinal whose thread-local slot this guard manages.
    device: usize,
    // Stream that was current when the guard was created; `None` means the
    // slot was empty and will be cleared again on drop.
    previous: Option<Arc<CudaStream>>,
}
333
#[cfg(feature = "cuda")]
impl StreamGuard {
    /// Install `stream` as the current stream for `device` on this thread,
    /// remembering whatever was current before.
    ///
    /// The saved stream (if any) is reinstated when the guard drops.
    pub fn new(device: usize, stream: Arc<CudaStream>) -> Self {
        let saved = get_current_stream(device);
        set_current_stream(device, stream);
        StreamGuard {
            device,
            previous: saved,
        }
    }
}
346
#[cfg(feature = "cuda")]
impl Drop for StreamGuard {
    fn drop(&mut self) {
        // Reinstate the saved stream, or clear the slot entirely if the
        // guard was created on a thread that had no current stream.
        if let Some(prev) = self.previous.take() {
            set_current_stream(self.device, prev);
        } else {
            clear_current_stream(self.device);
        }
    }
}
356
#[cfg(feature = "cuda")]
impl std::fmt::Debug for StreamGuard {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Report only whether a previous stream was saved; the stream
        // itself has no meaningful Debug representation.
        let mut builder = f.debug_struct("StreamGuard");
        builder.field("device", &self.device);
        builder.field("has_previous", &self.previous.is_some());
        builder.finish()
    }
}
366
367// ---------------------------------------------------------------------------
368// Stubs when `cuda` feature is disabled
369// ---------------------------------------------------------------------------
370
/// Stub `CudaEventWrapper` when the `cuda` feature is not enabled.
#[cfg(not(feature = "cuda"))]
#[derive(Debug)]
pub struct CudaEventWrapper;

#[cfg(not(feature = "cuda"))]
impl StreamPool {
    /// Always returns an error — compile with `features = ["cuda"]`.
    ///
    /// NOTE(review): this stub's signature differs from the CUDA build's
    /// `get_stream(ctx, ordinal) -> GpuResult<Arc<CudaStream>>` (no context
    /// argument, `()` success type) — confirm all callers are feature-gated.
    pub fn get_stream(_device_ordinal: usize) -> GpuResult<()> {
        Err(GpuError::NoCudaFeature)
    }

    /// Returns 0 — no streams without CUDA.
    pub fn pool_size(_device_ordinal: usize) -> usize {
        0
    }
}

/// Stub `StreamGuard` when the `cuda` feature is not enabled.
#[cfg(not(feature = "cuda"))]
#[derive(Debug)]
pub struct StreamGuard;

/// Stub — returns `None` without CUDA.
#[cfg(not(feature = "cuda"))]
pub fn get_current_stream(_device: usize) -> Option<()> {
    None
}

/// Stub — no-op without CUDA.
#[cfg(not(feature = "cuda"))]
pub fn set_current_stream(_device: usize, _stream: ()) {}

/// Stub — no-op without CUDA.
#[cfg(not(feature = "cuda"))]
pub fn clear_current_stream(_device: usize) {}
407
408// ---------------------------------------------------------------------------
409// Tests
410// ---------------------------------------------------------------------------
411
#[cfg(all(test, feature = "cuda"))]
mod tests {
    use super::*;
    use cudarc::driver::CudaContext;

    /// Helper: create a context for device 0. Skips the test if no GPU.
    fn test_ctx() -> Option<Arc<CudaContext>> {
        CudaContext::new(0).ok()
    }

    #[test]
    fn event_record_sync() {
        let Some(ctx) = test_ctx() else { return };
        let stream = ctx.default_stream();

        let event = CudaEventWrapper::new(&ctx)
            .expect("event creation should succeed");

        // Record on the default stream (which has no pending work).
        event.record(&stream).expect("record should succeed");

        // Synchronize should complete immediately (no work queued).
        event.synchronize().expect("synchronize should succeed");

        // Query should return true — all work is done.
        assert!(
            event.query().expect("query should succeed"),
            "event should be complete after synchronize"
        );
    }

    #[test]
    fn event_query_before_record() {
        let Some(ctx) = test_ctx() else { return };

        let event = CudaEventWrapper::new(&ctx)
            .expect("event creation should succeed");

        // A freshly created event with no work recorded. Per CUDA semantics,
        // cuEventQuery on an event that has never been recorded returns
        // CUDA_SUCCESS (it is considered "complete"). cudarc's is_complete()
        // wraps this.
        let complete = event.query().expect("query should not error");
        // The event has no recorded work, so it reports complete.
        assert!(complete, "unrecorded event should report complete");
    }

    #[test]
    fn stream_pool_round_robin() {
        let Some(ctx) = test_ctx() else { return };
        // Device 0 — the only ordinal these tests exercise with a real
        // context (the comment previously claimed "a high ordinal").
        let dev = 0;

        // NOTE(review): the wrap-around assertion below assumes this test
        // is the only caller of get_stream(0) in the process, so the
        // global round-robin counter starts at 0 — confirm no other test
        // or fixture draws from the pool for this ordinal.
        let s1 = StreamPool::get_stream(&ctx, dev)
            .expect("first get_stream should succeed");
        let s2 = StreamPool::get_stream(&ctx, dev)
            .expect("second get_stream should succeed");

        // After STREAMS_PER_DEVICE calls, we should wrap around.
        let pool_size = StreamPool::pool_size(dev);
        assert!(pool_size > 0, "pool should have streams");
        assert!(pool_size <= STREAMS_PER_DEVICE, "pool should not exceed configured size");

        // Collect all streams from a full cycle.
        let mut streams = vec![s1, s2];
        for _ in 2..pool_size {
            streams.push(
                StreamPool::get_stream(&ctx, dev).expect("get_stream should succeed"),
            );
        }

        // The next stream should wrap around to the same as the first.
        let wrap = StreamPool::get_stream(&ctx, dev)
            .expect("wrapped get_stream should succeed");

        // Because round-robin, `wrap` should be the same Arc as `streams[0]`.
        // We compare the underlying Arc pointers.
        assert_eq!(
            Arc::as_ptr(&wrap),
            Arc::as_ptr(&streams[0]),
            "round-robin should wrap back to the first stream"
        );
    }

    #[test]
    fn stream_pool_invalid_device() {
        let Some(ctx) = test_ctx() else { return };
        // Any ordinal >= MAX_DEVICES must be rejected before pool access.
        let result = StreamPool::get_stream(&ctx, MAX_DEVICES + 1);
        assert!(result.is_err(), "should reject ordinal >= MAX_DEVICES");
    }

    #[test]
    fn stream_guard_restores_previous() {
        let Some(ctx) = test_ctx() else { return };
        let dev = 0;

        // Initially, no current stream. (Thread-local state, so parallel
        // tests on other threads cannot interfere.)
        assert!(
            get_current_stream(dev).is_none(),
            "should start with no current stream"
        );

        let s1 = ctx.new_stream().expect("new_stream should succeed");
        let s2 = ctx.new_stream().expect("new_stream should succeed");

        let s1_ptr = Arc::as_ptr(&s1);
        let s2_ptr = Arc::as_ptr(&s2);

        // Set s1 as current.
        set_current_stream(dev, Arc::clone(&s1));
        assert_eq!(
            Arc::as_ptr(&get_current_stream(dev).unwrap()),
            s1_ptr,
            "current stream should be s1"
        );

        // Create a guard that sets s2.
        {
            let _guard = StreamGuard::new(dev, Arc::clone(&s2));
            assert_eq!(
                Arc::as_ptr(&get_current_stream(dev).unwrap()),
                s2_ptr,
                "current stream should be s2 inside guard"
            );
        }

        // After guard drop, s1 should be restored.
        assert_eq!(
            Arc::as_ptr(&get_current_stream(dev).unwrap()),
            s1_ptr,
            "current stream should be restored to s1 after guard drop"
        );

        // Clean up.
        clear_current_stream(dev);
        assert!(
            get_current_stream(dev).is_none(),
            "should be cleared after explicit clear"
        );
    }

    #[test]
    fn stream_guard_clears_when_no_previous() {
        let Some(ctx) = test_ctx() else { return };
        let dev = 0;

        // Ensure no current stream.
        clear_current_stream(dev);
        assert!(get_current_stream(dev).is_none());

        let s1 = ctx.new_stream().expect("new_stream should succeed");

        {
            let _guard = StreamGuard::new(dev, Arc::clone(&s1));
            assert!(
                get_current_stream(dev).is_some(),
                "guard should set current stream"
            );
        }

        // Guard had no previous — should clear.
        assert!(
            get_current_stream(dev).is_none(),
            "guard with no previous should clear current stream on drop"
        );
    }

    #[test]
    fn current_stream_or_default_fallback() {
        // We can't easily construct a GpuDevice in tests without a real GPU
        // context, but we can test the thread-local logic in isolation.
        let Some(ctx) = test_ctx() else { return };
        let dev_ordinal = 0;

        // Clear any leftover state.
        clear_current_stream(dev_ordinal);

        let device = crate::device::GpuDevice::new(dev_ordinal)
            .expect("GpuDevice::new should succeed");
        let default_ptr = Arc::as_ptr(device.default_stream());

        // No current stream set — should fall back to device default.
        let stream = current_stream_or_default(&device);
        assert_eq!(
            Arc::as_ptr(&stream),
            default_ptr,
            "should fall back to device default stream"
        );

        // Set a custom stream — should use it instead.
        let custom = ctx.new_stream().expect("new_stream should succeed");
        let custom_ptr = Arc::as_ptr(&custom);
        set_current_stream(dev_ordinal, custom);

        let stream = current_stream_or_default(&device);
        assert_eq!(
            Arc::as_ptr(&stream),
            custom_ptr,
            "should use thread-local current stream"
        );

        // Clean up.
        clear_current_stream(dev_ordinal);
    }

    #[test]
    fn event_wait_on_stream() {
        let Some(ctx) = test_ctx() else { return };
        let stream1 = ctx.default_stream();
        let stream2 = ctx.new_stream().expect("new_stream should succeed");

        let event = CudaEventWrapper::new(&ctx)
            .expect("event creation should succeed");

        // Record on stream1.
        event.record(&stream1).expect("record should succeed");

        // Make stream2 wait on the event (GPU-side sync).
        event.wait_on(&stream2).expect("wait_on should succeed");

        // Synchronize stream2 — this implicitly waits for stream1's work too.
        stream2.synchronize().expect("synchronize should succeed");
    }
}
635}