// ferrotorch_gpu/graph.rs
1//! CUDA graph capture and replay infrastructure.
2//!
3//! A CUDA graph records a sequence of GPU operations (kernel launches, memcpys)
4//! and replays them as a single driver submission. This eliminates per-kernel
5//! launch overhead (~70μs on WSL2, ~5μs on native Linux per call) by collapsing
6//! hundreds of launches into one.
7//!
8//! # Usage
9//!
10//! ```ignore
11//! use ferrotorch_gpu::graph::{DeviceScalar, begin_capture, end_capture};
12//!
13//! // Pre-allocate all buffers BEFORE capture
14//! let mut out = alloc_zeros_f32(768, &device)?;
15//!
16//! // Parameters that change between replays go in DeviceScalar
17//! let mut pos = DeviceScalar::new(device.stream(), 0u32)?;
18//!
19//! // Capture
20//! begin_capture(device.stream())?;
21//! gpu_add_into(&a, &b, &mut out, &device)?; // recorded, not executed
22//! let graph = end_capture(device.stream())?;
23//!
24//! // Replay loop
25//! for i in 0..100 {
26//! pos.update(i as u32)?; // memcpy before replay
27//! graph.launch()?; // replay all captured ops
28//! }
29//! ```
30
#[cfg(feature = "cuda")]
use std::sync::Arc;
#[cfg(feature = "cuda")]
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};

#[cfg(feature = "cuda")]
use cudarc::driver::{CudaSlice, CudaStream, DeviceRepr, ValidAsZeroBits};

use crate::error::{GpuError, GpuResult};
40
41// ---------------------------------------------------------------------------
42// CaptureMode — typed wrapper over cudarc's CUstreamCaptureMode
43// ---------------------------------------------------------------------------
44
/// Selects how CUDA graph capture serializes interactions with other
/// threads. Mirrors `cudaStreamCaptureMode`.
///
/// - `Global` — any CUDA API call from any thread that touches the
///   capturing stream (or any thread that is also capturing) will
///   invalidate capture. Safest for debugging; matches PyTorch's
///   default.
/// - `ThreadLocal` — only calls from the capturing thread can
///   invalidate capture. Other threads may freely use unrelated
///   streams. This is what ferrotorch-gpu has always used.
/// - `Relaxed` — the driver does not track cross-thread interactions
///   at all. Fastest, but the caller is fully responsible for making
///   sure no other thread interferes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum CaptureMode {
    /// Global serialization (`CU_STREAM_CAPTURE_MODE_GLOBAL`).
    Global,
    /// Thread-local serialization (`CU_STREAM_CAPTURE_MODE_THREAD_LOCAL`).
    /// This is what ferrotorch-gpu has historically used and matches
    /// PyTorch's `torch.cuda.graph(..., capture_error_mode="thread_local")`,
    /// so it is the derived default.
    #[default]
    ThreadLocal,
    /// Relaxed — no cross-thread serialization
    /// (`CU_STREAM_CAPTURE_MODE_RELAXED`).
    Relaxed,
}

#[cfg(feature = "cuda")]
impl CaptureMode {
    /// Convert to the raw cudarc enum.
    #[inline]
    pub fn to_cuda(self) -> cudarc::driver::sys::CUstreamCaptureMode {
        use cudarc::driver::sys::CUstreamCaptureMode::*;
        match self {
            Self::Global => CU_STREAM_CAPTURE_MODE_GLOBAL,
            Self::ThreadLocal => CU_STREAM_CAPTURE_MODE_THREAD_LOCAL,
            Self::Relaxed => CU_STREAM_CAPTURE_MODE_RELAXED,
        }
    }
}
91
92// ---------------------------------------------------------------------------
93// CaptureStatus — typed wrapper over cudarc's CUstreamCaptureStatus
94// ---------------------------------------------------------------------------
95
/// The capture state of a CUDA stream. Matches `cudaStreamCaptureStatus`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CaptureStatus {
    /// No graph capture is in progress on the stream.
    None,
    /// A graph is currently being recorded on the stream.
    Active,
    /// Capture was broken (for example by a forbidden API call or a
    /// cross-stream dependency). `end_capture` must still be called
    /// to discard the broken graph before the stream is usable again.
    Invalidated,
}

#[cfg(feature = "cuda")]
impl CaptureStatus {
    /// Map the raw cudarc status enum onto this typed wrapper.
    fn from_cuda(raw: cudarc::driver::sys::CUstreamCaptureStatus) -> Self {
        use cudarc::driver::sys::CUstreamCaptureStatus::*;
        match raw {
            CU_STREAM_CAPTURE_STATUS_NONE => Self::None,
            CU_STREAM_CAPTURE_STATUS_ACTIVE => Self::Active,
            CU_STREAM_CAPTURE_STATUS_INVALIDATED => Self::Invalidated,
        }
    }
}

impl CaptureStatus {
    /// `true` while a graph is actively being recorded on the stream.
    #[inline]
    pub fn is_capturing(&self) -> bool {
        *self == Self::Active
    }

    /// `true` once capture has been invalidated and must be ended.
    #[inline]
    pub fn is_invalidated(&self) -> bool {
        *self == Self::Invalidated
    }
}
135
136// ---------------------------------------------------------------------------
137// DeviceScalar — a single value in GPU memory, updatable before graph replay
138// ---------------------------------------------------------------------------
139
/// A single scalar value stored in GPU device memory.
///
/// Designed for CUDA graph capture: the captured graph records the
/// fixed device address of this allocation, and the host mutates the
/// value with [`update`](DeviceScalar::update) before each
/// [`CapturedGraph::launch`]. Each update is a single small async
/// H→D memcpy (`size_of::<T>()` bytes) — effectively zero cost.
#[cfg(feature = "cuda")]
pub struct DeviceScalar<T: DeviceRepr + ValidAsZeroBits + Copy> {
    /// One-element device allocation holding the scalar.
    slot: CudaSlice<T>,
    /// Stream the allocation and all subsequent updates are ordered on.
    stream: Arc<CudaStream>,
}

#[cfg(feature = "cuda")]
impl<T: DeviceRepr + ValidAsZeroBits + Copy> DeviceScalar<T> {
    /// Allocate a one-element device buffer seeded with `initial`.
    pub fn new(stream: &Arc<CudaStream>, initial: T) -> GpuResult<Self> {
        Ok(Self {
            slot: stream.clone_htod(&[initial])?,
            stream: Arc::clone(stream),
        })
    }

    /// Overwrite the device value with `value` via an async H→D memcpy.
    /// Must be issued on the same stream the graph replays on so the
    /// copy is ordered before the next launch.
    pub fn update(&mut self, value: T) -> GpuResult<()> {
        self.stream.memcpy_htod(&[value], &mut self.slot)?;
        Ok(())
    }

    /// Borrow the underlying `CudaSlice` to pass as a kernel parameter.
    /// The graph captures this pointer address; later `update` calls
    /// change what the kernel reads without re-capturing.
    #[inline]
    pub fn inner(&self) -> &CudaSlice<T> {
        &self.slot
    }
}
178
179// ---------------------------------------------------------------------------
180// CapturedGraph — a replayable CUDA graph
181// ---------------------------------------------------------------------------
182
/// A captured and instantiated CUDA graph that can be replayed with
/// [`launch`](CapturedGraph::launch).
///
/// Created via [`begin_capture`] + GPU ops + [`end_capture`].
/// The graph holds references to all device memory used during capture.
/// Those buffers must remain allocated for the lifetime of the graph.
///
/// **Allocator pool integration (CL-278).** When created via
/// [`end_capture_with_pool`], the graph holds a strong reference to
/// the [`CapturePool`] that recorded its allocations. The pool keeps
/// every registered buffer alive until the last `CapturedGraph`
/// referencing it is dropped, which guarantees the device pointers
/// recorded in the graph remain valid across replays. Without the
/// pool, callers must manually keep buffers alive (the original
/// [`end_capture`] API).
#[cfg(feature = "cuda")]
pub struct CapturedGraph {
    /// The instantiated, replayable executable graph.
    graph: cudarc::driver::CudaGraph,
    /// Optional reference to the pool that owns the graph's
    /// allocations. Some(pool) when constructed via
    /// [`end_capture_with_pool`]. Dropping the graph drops this
    /// Arc, which (if it's the last reference) drops every buffer
    /// the pool holds. CL-278.
    pool: Option<Arc<CapturePool>>,
    /// Monotonic counter bumped by every successful [`launch`]. Lets
    /// callers assert that a specific replay happened after some
    /// other work completed, useful for graph-aware profilers and
    /// integration tests. CL-454.
    replay_count: AtomicU64,
    /// True after the first successful [`upload`] so subsequent
    /// uploads become cheap no-ops. CL-454.
    // Uses the imported `AtomicBool` for consistency with `AtomicU64`
    // above (previously written fully-qualified).
    uploaded: AtomicBool,
}
216
#[cfg(feature = "cuda")]
impl CapturedGraph {
    /// Replay all operations captured in this graph.
    ///
    /// Update any [`DeviceScalar`] values and issue any pre-launch
    /// memcpys first (e.g., position embeddings); all of that must
    /// happen on the stream the graph was captured on so it is
    /// ordered before the replay.
    ///
    /// Bumps [`num_replays`](Self::num_replays) on success.
    pub fn launch(&self) -> GpuResult<()> {
        self.graph.launch()?;
        self.replay_count.fetch_add(1, Ordering::Relaxed);
        Ok(())
    }

    /// Pre-upload the graph's executable resources to the device.
    ///
    /// The first [`launch`](Self::launch) of a freshly instantiated
    /// graph pays a one-time cost while the driver copies the exec
    /// into GPU memory; calling `upload` ahead of time moves that
    /// cost out of the hot replay loop. Repeat calls are no-ops.
    /// CL-454.
    pub fn upload(&self) -> GpuResult<()> {
        if !self.uploaded.load(Ordering::Acquire) {
            self.graph.upload()?;
            self.uploaded.store(true, Ordering::Release);
        }
        Ok(())
    }

    /// Number of successful replays issued on this graph. CL-454.
    #[inline]
    pub fn num_replays(&self) -> u64 {
        self.replay_count.load(Ordering::Relaxed)
    }

    /// Whether [`upload`](Self::upload) has been called on this
    /// graph. CL-454.
    #[inline]
    pub fn is_uploaded(&self) -> bool {
        self.uploaded.load(Ordering::Acquire)
    }

    /// Number of buffers kept alive by this graph's allocator pool;
    /// 0 when the graph was created without a pool. CL-278.
    pub fn pool_buffer_count(&self) -> usize {
        self.pool.as_ref().map_or(0, |pool| pool.buffer_count())
    }

    /// True if this graph holds a CapturePool reference. CL-278.
    pub fn has_pool(&self) -> bool {
        self.pool.is_some()
    }

    /// The [`Arc<CapturePool>`] this graph is using, if any. Sharing
    /// the pool between several graphs keeps one set of buffers
    /// alive for all of them. CL-454.
    pub fn pool(&self) -> Option<&Arc<CapturePool>> {
        self.pool.as_ref()
    }
}
281
282// ---------------------------------------------------------------------------
283// Capture API
284// ---------------------------------------------------------------------------
285
/// Begin CUDA graph capture on the given stream.
///
/// All GPU operations (kernel launches, cuBLAS calls, memcpys) issued on this
/// stream after this call are **recorded but not executed**. Call
/// [`end_capture`] to finalize and instantiate the graph.
///
/// # Requirements
///
/// - All output buffers must be pre-allocated before capture begins.
/// - No `alloc_zeros` / `cpu_to_gpu` during capture (use `_into` variants).
/// - No CPU↔GPU synchronization during capture.
/// - Event tracking should be disabled during capture to avoid interference
///   (call `ctx.disable_event_tracking()` before, re-enable after).
#[cfg(feature = "cuda")]
pub fn begin_capture(stream: &Arc<CudaStream>) -> GpuResult<()> {
    // Thin delegator: uses the default mode (ThreadLocal — see
    // `CaptureMode::default`), matching ferrotorch-gpu's historical behavior.
    begin_capture_with_mode(stream, CaptureMode::default())
}
303
/// Begin CUDA graph capture with an explicit [`CaptureMode`]. CL-454.
///
/// Prefer [`begin_capture`] for the default (`ThreadLocal`) mode. Use
/// this form when you need `Global` (debugging / strict serialization)
/// or `Relaxed` (max throughput, single-thread ownership).
#[cfg(feature = "cuda")]
pub fn begin_capture_with_mode(
    stream: &Arc<CudaStream>,
    mode: CaptureMode,
) -> GpuResult<()> {
    // From this point until `end_capture`, work queued on `stream` is
    // recorded into the graph rather than executed.
    Ok(stream.begin_capture(mode.to_cuda())?)
}
317
/// Query the capture status of a CUDA stream. CL-454.
///
/// This is the ferrotorch-gpu equivalent of PyTorch's
/// `torch.cuda.is_current_stream_capturing`. Callers can use this to
/// skip capture-invalid APIs (allocator calls, H↔D copies) when a
/// graph is being recorded.
#[cfg(feature = "cuda")]
pub fn capture_status(stream: &Arc<CudaStream>) -> GpuResult<CaptureStatus> {
    Ok(CaptureStatus::from_cuda(stream.capture_status()?))
}
329
/// Shorthand for `capture_status(stream)?.is_capturing()`. CL-454.
#[cfg(feature = "cuda")]
pub fn is_stream_capturing(stream: &Arc<CudaStream>) -> GpuResult<bool> {
    capture_status(stream).map(|status| status.is_capturing())
}
335
/// End CUDA graph capture, instantiate, and return the replayable graph.
///
/// Returns `Err` if capture was not active or if instantiation fails.
///
/// The returned graph has no [`CapturePool`] attached. The caller is
/// responsible for keeping the buffers used by the captured kernels
/// alive for the graph's lifetime. Use [`end_capture_with_pool`]
/// for the lifetime-managed variant.
#[cfg(feature = "cuda")]
pub fn end_capture(stream: &Arc<CudaStream>) -> GpuResult<CapturedGraph> {
    // AUTO_FREE_ON_LAUNCH: lets the driver free graph-owned device
    // allocations before relaunch. Capture here forbids allocations
    // (see `begin_capture` docs), so this is presumably inert for our
    // graphs — NOTE(review): confirm the flag choice is intentional.
    let flags = cudarc::driver::sys::CUgraphInstantiate_flags_enum::CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH;
    let graph = stream
        .end_capture(flags)?
        // NOTE(review): `PtxCompileFailed` is reused here for a
        // graph-capture failure; a dedicated variant would be clearer,
        // but changing it would alter what callers match on.
        .ok_or(GpuError::PtxCompileFailed {
            kernel: "CUDA graph capture returned null",
        })?;
    Ok(CapturedGraph {
        graph,
        pool: None,
        // Fresh graph: never replayed, not yet uploaded.
        replay_count: AtomicU64::new(0),
        uploaded: std::sync::atomic::AtomicBool::new(false),
    })
}
359
/// End CUDA graph capture and attach a [`CapturePool`] reference to
/// the resulting [`CapturedGraph`]. CL-278.
///
/// The pool's tracked buffers are kept alive for the lifetime of the
/// returned graph: dropping the graph drops its `Arc<CapturePool>`,
/// which (if it's the last reference) drops every buffer the pool
/// recorded. This guarantees that the device pointers recorded in
/// the captured graph remain valid across replays.
///
/// Use this in concert with [`CapturePool::record_buffer`]: allocate
/// every buffer used during capture before calling `begin_capture`,
/// register each one with the pool, run the kernels under capture,
/// then call `end_capture_with_pool(stream, pool)` to seal the
/// lifetime relationship.
#[cfg(feature = "cuda")]
pub fn end_capture_with_pool(
    stream: &Arc<CudaStream>,
    pool: Arc<CapturePool>,
) -> GpuResult<CapturedGraph> {
    end_capture(stream).map(|mut graph| {
        graph.pool = Some(pool);
        graph
    })
}
383
384// ---------------------------------------------------------------------------
385// GraphCaptureGuard — RAII wrapper that ends capture on drop
386// ---------------------------------------------------------------------------
387
/// RAII guard that runs CUDA graph capture in a scoped block.
///
/// Call [`GraphCaptureGuard::begin`] (or [`begin_with_mode`] /
/// [`begin_with_pool`]) to start capture; calling [`finish`] returns
/// the instantiated graph. If the guard is dropped without calling
/// `finish` (for example because a kernel returned an error
/// mid-capture), its `Drop` impl best-effort-ends capture and
/// discards the resulting graph so the stream returns to a usable
/// state. CL-454.
///
/// This mirrors PyTorch's `with torch.cuda.graph(g): ...` context
/// manager semantics in Rust's RAII idiom.
///
/// # Example
///
/// ```ignore
/// use ferrotorch_gpu::graph::GraphCaptureGuard;
///
/// let mut guard = GraphCaptureGuard::begin(device.stream())?;
/// run_kernels()?; // any kernel launched on device.stream() is recorded
/// let graph = guard.finish()?;
/// graph.upload()?;
/// for _ in 0..1000 { graph.launch()?; }
/// ```
#[cfg(feature = "cuda")]
pub struct GraphCaptureGuard {
    /// Stream the capture was started on; `finish`/`Drop` end capture here.
    stream: Arc<CudaStream>,
    /// Optional pool to attach when `finish` is called.
    pool: Option<Arc<CapturePool>>,
    /// Becomes `false` after [`finish`] consumes the guard, so `Drop`
    /// knows capture is already ended and does nothing.
    active: bool,
}
421
#[cfg(feature = "cuda")]
impl GraphCaptureGuard {
    /// Begin graph capture on `stream` using the default
    /// [`CaptureMode::ThreadLocal`] mode. CL-454.
    pub fn begin(stream: &Arc<CudaStream>) -> GpuResult<Self> {
        Self::begin_with_mode(stream, CaptureMode::default())
    }

    /// Begin graph capture with an explicit [`CaptureMode`]. CL-454.
    pub fn begin_with_mode(stream: &Arc<CudaStream>, mode: CaptureMode) -> GpuResult<Self> {
        begin_capture_with_mode(stream, mode)?;
        Ok(Self {
            stream: Arc::clone(stream),
            pool: None,
            active: true,
        })
    }

    /// Begin graph capture bound to a [`CapturePool`]. The pool is
    /// attached to the resulting graph by [`finish`]. CL-454.
    pub fn begin_with_pool(stream: &Arc<CudaStream>, pool: Arc<CapturePool>) -> GpuResult<Self> {
        begin_capture_with_pool(&pool, stream)?;
        Ok(Self {
            stream: Arc::clone(stream),
            pool: Some(pool),
            active: true,
        })
    }

    /// Finish capture and return the instantiated [`CapturedGraph`].
    ///
    /// Consumes the guard so `Drop` becomes a no-op. If a pool was
    /// attached at construction, the resulting graph is produced via
    /// [`end_capture_with_pool`] and holds the pool Arc for the
    /// lifetime of the graph.
    pub fn finish(mut self) -> GpuResult<CapturedGraph> {
        // Disarm Drop first: capture is about to be ended here.
        self.active = false;
        match self.pool.take() {
            Some(pool) => end_capture_with_pool(&self.stream, pool),
            None => end_capture(&self.stream),
        }
    }

    /// Report whether the stream this guard is bound to is still
    /// actively capturing. An unexpected `Invalidated` or `None`
    /// usually means a forbidden API call (alloc, sync, host copy)
    /// happened under capture.
    pub fn status(&self) -> GpuResult<CaptureStatus> {
        capture_status(&self.stream)
    }
}
480
#[cfg(feature = "cuda")]
impl Drop for GraphCaptureGuard {
    fn drop(&mut self) {
        // Only fires when the guard was abandoned mid-capture (e.g. a
        // kernel errored before `finish` was reached). Ending capture
        // discards the partial graph and restores the stream to a
        // usable state. Errors are ignored: we are inside Drop, and
        // the resulting CapturedGraph (if any) is thrown away anyway.
        if self.active {
            let _ = end_capture(&self.stream);
        }
    }
}
493
494// ---------------------------------------------------------------------------
495// Graph pool handle registry — share a CapturePool across multiple graphs
496// ---------------------------------------------------------------------------
497
/// Opaque identifier for a pool registered in the process-wide graph
/// pool registry. Lets multiple captured graphs share the same
/// buffer-lifetime [`CapturePool`] without threading an
/// `Arc<CapturePool>` through every call site by hand. CL-454.
///
/// Mirrors PyTorch's `torch.cuda.graph_pool_handle()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct GraphPoolHandle(pub u64);
506
// Monotonic handle source; starts at 1 so 0 is never a live handle.
#[cfg(feature = "cuda")]
static NEXT_POOL_HANDLE: AtomicU64 = AtomicU64::new(1);

// Process-wide handle → pool map, created lazily on first use.
#[cfg(feature = "cuda")]
static POOL_REGISTRY: std::sync::OnceLock<
    std::sync::Mutex<std::collections::HashMap<u64, Arc<CapturePool>>>,
> = std::sync::OnceLock::new();

/// Lazily initialize and return the process-wide pool registry.
#[cfg(feature = "cuda")]
fn pool_registry() -> &'static std::sync::Mutex<std::collections::HashMap<u64, Arc<CapturePool>>> {
    POOL_REGISTRY.get_or_init(Default::default)
}
519
/// Allocate a fresh [`GraphPoolHandle`] and register a new
/// [`CapturePool`] under it in the process-wide registry. CL-454.
///
/// The handle can later be passed to [`capture_pool_for_handle`] to
/// retrieve the same `Arc<CapturePool>` from any thread, which lets
/// two independently captured graphs share the same buffer-keeping
/// pool.
#[cfg(feature = "cuda")]
pub fn graph_pool_handle() -> GraphPoolHandle {
    let id = NEXT_POOL_HANDLE.fetch_add(1, Ordering::Relaxed);
    pool_registry()
        .lock()
        // A poisoned registry lock is still structurally sound; recover it.
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .insert(id, Arc::new(CapturePool::new()));
    GraphPoolHandle(id)
}
537
/// Look up the [`CapturePool`] registered under `handle` and return
/// a strong `Arc` to it. Returns `None` if the handle was never
/// allocated or has been released via [`release_graph_pool_handle`].
/// CL-454.
#[cfg(feature = "cuda")]
pub fn capture_pool_for_handle(handle: GraphPoolHandle) -> Option<Arc<CapturePool>> {
    pool_registry()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .get(&handle.0)
        .cloned()
}
549
/// Drop the registry's strong reference to the pool behind `handle`.
/// Any [`CapturedGraph`] that holds its own Arc (for example via
/// [`end_capture_with_pool`]) keeps the pool alive until that graph
/// is dropped too. CL-454.
#[cfg(feature = "cuda")]
pub fn release_graph_pool_handle(handle: GraphPoolHandle) {
    pool_registry()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .remove(&handle.0);
}
561
562// ---------------------------------------------------------------------------
563// make_graphed_callable — scoped capture over a closure
564// ---------------------------------------------------------------------------
565
/// Capture the operations performed by `f` into a CUDA graph and
/// return the replayable graph. CL-454.
///
/// This is the ferrotorch-gpu equivalent of PyTorch's
/// `torch.cuda.make_graphed_callables` for the simple single-callable
/// case: the caller supplies a closure that runs all the GPU work to
/// capture, and the returned [`CapturedGraph`] can be replayed over
/// and over. The closure runs exactly once during capture, so all
/// per-call work (allocations, dtype decisions) that isn't valid
/// under capture must happen outside.
///
/// If the closure returns an error, capture is discarded and the
/// error is propagated.
#[cfg(feature = "cuda")]
pub fn make_graphed_callable<F>(
    stream: &Arc<CudaStream>,
    mode: CaptureMode,
    f: F,
) -> GpuResult<CapturedGraph>
where
    F: FnOnce() -> GpuResult<()>,
{
    let guard = GraphCaptureGuard::begin_with_mode(stream, mode)?;
    // On error the `?` returns early and the guard's Drop ends
    // capture, discarding the partial graph.
    f()?;
    guard.finish()
}
598
599// ---------------------------------------------------------------------------
600// CapturePool — memory pool for graph capture
601// ---------------------------------------------------------------------------
602
/// A dedicated memory pool for CUDA graph capture.
///
/// Two responsibilities:
///
/// 1. **Sealed flag** — gates [`begin_capture_with_pool`] so the
///    caller can express "no more allocations after this point"
///    semantically. Sealed pools cannot satisfy new allocations
///    during capture, and `begin_capture_with_pool` refuses to
///    start capture on a sealed pool.
///
/// 2. **Buffer lifetime tracking (CL-278)** — registered buffers
///    are kept alive by the pool itself, so they outlive any
///    [`CapturedGraph`] that holds an `Arc<CapturePool>`. Dropping
///    the graph drops the Arc, and dropping the last Arc drops
///    every registered buffer in registration order.
///
/// # Usage
///
/// ```ignore
/// use std::sync::Arc;
/// let pool = Arc::new(CapturePool::new());
///
/// // Allocate every buffer the captured kernels will read or
/// // write, and register each one with the pool so it stays alive
/// // for the graph's lifetime.
/// let mut buf_a = alloc_zeros_f32(1024, &device)?;
/// let mut buf_b = alloc_zeros_f32(1024, &device)?;
/// pool.record_buffer(buf_a.try_clone()?);
/// pool.record_buffer(buf_b.try_clone()?);
///
/// // Note: do NOT seal before starting capture —
/// // `begin_capture_with_pool` rejects a sealed pool.
/// begin_capture_with_pool(&pool, stream)?;
/// // ... launch kernels using buf_a and buf_b ...
/// let graph = end_capture_with_pool(stream, Arc::clone(&pool))?;
/// pool.seal(); // optional: forbid further allocations from here on
/// // Dropping `pool` here is safe — the graph holds its own Arc.
/// ```
#[cfg(feature = "cuda")]
pub struct CapturePool {
    // "No more allocations" flag; checked by begin_capture_with_pool.
    sealed: std::sync::atomic::AtomicBool,
    /// Registered buffers (type-erased) kept alive for the graph's
    /// lifetime. Each entry is a Box<dyn Any + Send + Sync> wrapping
    /// the buffer's drop guard. CL-278.
    buffers: std::sync::Mutex<Vec<Box<dyn std::any::Any + Send + Sync + 'static>>>,
}
646
#[cfg(feature = "cuda")]
impl CapturePool {
    /// Create a new, unsealed capture pool with no registered buffers.
    pub fn new() -> Self {
        Self {
            sealed: std::sync::atomic::AtomicBool::new(false),
            buffers: std::sync::Mutex::new(Vec::new()),
        }
    }

    /// Seal the pool, preventing any further allocations.
    pub fn seal(&self) {
        self.sealed
            .store(true, std::sync::atomic::Ordering::Release);
    }

    /// Unseal the pool, allowing allocations again.
    pub fn unseal(&self) {
        self.sealed
            .store(false, std::sync::atomic::Ordering::Release);
    }

    /// Check whether the pool is sealed.
    pub fn is_capture_pool_sealed(&self) -> bool {
        self.sealed.load(std::sync::atomic::Ordering::Acquire)
    }

    /// Register a buffer with the pool so it stays alive for the
    /// lifetime of any [`CapturedGraph`] that holds this pool.
    /// CL-278.
    ///
    /// `buffer` can be any type that owns GPU memory (typically
    /// `CudaBuffer<f32>`, `CudaBuffer<f64>`, or `Arc<CudaBuffer<T>>`).
    /// The pool stores it in a type-erased `Box<dyn Any + Send +
    /// Sync>` and drops it (in registration order) when the pool
    /// itself is dropped.
    ///
    /// Returns the index of the registered buffer for diagnostic
    /// purposes.
    pub fn record_buffer<B>(&self, buffer: B) -> usize
    where
        B: Send + Sync + 'static,
    {
        let mut guard = self
            .buffers
            .lock()
            .unwrap_or_else(|p| p.into_inner());
        let idx = guard.len();
        guard.push(Box::new(buffer));
        idx
    }

    /// Number of buffers currently registered with the pool. CL-278.
    pub fn buffer_count(&self) -> usize {
        // Fix: recover a poisoned lock like every other accessor in
        // this impl, instead of silently reporting 0 buffers.
        self.buffers
            .lock()
            .unwrap_or_else(|p| p.into_inner())
            .len()
    }

    /// Drop every registered buffer immediately, in registration
    /// order. The pool itself remains usable; new buffers can still
    /// be registered after this call. CL-278.
    ///
    /// Use this when reusing a pool across multiple capture cycles.
    /// Calling clear while a [`CapturedGraph`] still holds an Arc
    /// to this pool is safe — the graph's strong reference keeps
    /// the pool struct alive, but the buffer slots are reset.
    pub fn clear_buffers(&self) {
        let mut guard = self
            .buffers
            .lock()
            .unwrap_or_else(|p| p.into_inner());
        guard.clear();
    }
}

#[cfg(feature = "cuda")]
impl Default for CapturePool {
    fn default() -> Self {
        Self::new()
    }
}
730
/// Begin CUDA graph capture with a capture pool.
///
/// Like [`begin_capture`], but checks that the capture pool is not sealed
/// before starting capture. A sealed pool cannot satisfy allocations
/// during graph capture, which would cause CUDA errors.
///
/// # Errors
///
/// Returns [`GpuError::InvalidState`](GpuError) if the pool is sealed.
/// Returns a CUDA driver error if `begin_capture` fails.
#[cfg(feature = "cuda")]
pub fn begin_capture_with_pool(pool: &CapturePool, stream: &Arc<CudaStream>) -> GpuResult<()> {
    if pool.is_capture_pool_sealed() {
        return Err(GpuError::InvalidState {
            message: "cannot begin graph capture: capture pool is sealed".into(),
        });
    }
    begin_capture(stream)
}
750
/// Stub CapturePool when cuda feature is disabled. Provides the
/// same surface API as the cuda-enabled type so callers compile on
/// both feature configurations.
#[cfg(not(feature = "cuda"))]
pub struct CapturePool;

#[cfg(not(feature = "cuda"))]
impl CapturePool {
    /// Create an empty CapturePool. Without the cuda feature the
    /// pool has no internal state to initialize.
    pub fn new() -> Self {
        Self
    }

    /// No-op without the cuda feature: there is no real CUDA pool
    /// to seal because no real allocations can happen.
    pub fn seal(&self) {
        // Without the cuda feature there is no allocator state to
        // mutate; the CapturePool exists only so callers can write
        // feature-portable code.
    }

    /// No-op without the cuda feature: there is no real CUDA pool
    /// to unseal because no real allocations can happen.
    pub fn unseal(&self) {
        // Without the cuda feature there is no allocator state to
        // mutate; the CapturePool exists only so callers can write
        // feature-portable code.
    }

    /// Always returns `false` without the cuda feature since there
    /// is no real pool that could be in either state.
    pub fn is_capture_pool_sealed(&self) -> bool {
        false
    }

    /// Always returns 0 without the cuda feature since no real
    /// allocations can be tracked. CL-278.
    pub fn buffer_count(&self) -> usize {
        0
    }

    /// Stub record_buffer for API parity with the cuda-enabled pool:
    /// accepts and immediately drops `buffer`, always reporting
    /// index 0 since nothing is tracked. CL-278.
    pub fn record_buffer<B>(&self, _buffer: B) -> usize
    where
        B: Send + Sync + 'static,
    {
        0
    }

    /// Stub clear_buffers for API parity: nothing is tracked, so
    /// there is nothing to drop. CL-278.
    pub fn clear_buffers(&self) {}
}

#[cfg(not(feature = "cuda"))]
impl Default for CapturePool {
    fn default() -> Self {
        Self::new()
    }
}
800
/// Stub begin_capture_with_pool when cuda feature is disabled.
/// Always fails with `GpuError::NoCudaFeature`; the stream parameter
/// is generic because the cudarc stream type does not exist in this
/// configuration.
#[cfg(not(feature = "cuda"))]
pub fn begin_capture_with_pool<T>(_pool: &CapturePool, _stream: &T) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
806
807// ---------------------------------------------------------------------------
808// Stubs when cuda feature is disabled
809// ---------------------------------------------------------------------------
810
811/// Stub DeviceScalar.
812#[cfg(not(feature = "cuda"))]
813pub struct DeviceScalar<T: Copy> {
814 _phantom: std::marker::PhantomData<T>,
815}
816
817/// Stub CapturedGraph.
818#[cfg(not(feature = "cuda"))]
819pub struct CapturedGraph;
820
821#[cfg(not(feature = "cuda"))]
822impl CapturedGraph {
823 pub fn launch(&self) -> GpuResult<()> {
824 Err(GpuError::NoCudaFeature)
825 }
826
827 /// Stub upload for CL-454.
828 pub fn upload(&self) -> GpuResult<()> {
829 Err(GpuError::NoCudaFeature)
830 }
831
832 /// Stub num_replays — always 0 without the cuda feature. CL-454.
833 pub fn num_replays(&self) -> u64 {
834 0
835 }
836
837 /// Stub is_uploaded — always false without the cuda feature. CL-454.
838 pub fn is_uploaded(&self) -> bool {
839 false
840 }
841}
842
/// Stub `begin_capture` when the cuda feature is not enabled.
#[cfg(not(feature = "cuda"))]
pub fn begin_capture<T>(_stream: &T) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}

/// Stub `begin_capture_with_mode` when the cuda feature is not enabled.
/// CL-454.
#[cfg(not(feature = "cuda"))]
pub fn begin_capture_with_mode<T>(_stream: &T, _mode: CaptureMode) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}

/// Stub `capture_status` when the cuda feature is not enabled. CL-454.
#[cfg(not(feature = "cuda"))]
pub fn capture_status<T>(_stream: &T) -> GpuResult<CaptureStatus> {
    Err(GpuError::NoCudaFeature)
}

/// Stub `is_stream_capturing` when the cuda feature is not enabled.
/// CL-454.
#[cfg(not(feature = "cuda"))]
pub fn is_stream_capturing<T>(_stream: &T) -> GpuResult<bool> {
    Err(GpuError::NoCudaFeature)
}

/// Stub `end_capture` when the cuda feature is not enabled.
#[cfg(not(feature = "cuda"))]
pub fn end_capture<T>(_stream: &T) -> GpuResult<CapturedGraph> {
    Err(GpuError::NoCudaFeature)
}

/// Stub `end_capture_with_pool` when the cuda feature is not enabled.
/// CL-278.
#[cfg(not(feature = "cuda"))]
pub fn end_capture_with_pool<T>(
    _stream: &T,
    _pool: std::sync::Arc<CapturePool>,
) -> GpuResult<CapturedGraph> {
    Err(GpuError::NoCudaFeature)
}
882
883/// Stub `GraphCaptureGuard` when the cuda feature is not enabled. CL-454.
884#[cfg(not(feature = "cuda"))]
885pub struct GraphCaptureGuard {
886 _never: core::convert::Infallible,
887}
888
889#[cfg(not(feature = "cuda"))]
890impl GraphCaptureGuard {
891 pub fn begin<T>(_stream: &T) -> GpuResult<Self> {
892 Err(GpuError::NoCudaFeature)
893 }
894
895 pub fn begin_with_mode<T>(_stream: &T, _mode: CaptureMode) -> GpuResult<Self> {
896 Err(GpuError::NoCudaFeature)
897 }
898
899 pub fn begin_with_pool<T>(
900 _stream: &T,
901 _pool: std::sync::Arc<CapturePool>,
902 ) -> GpuResult<Self> {
903 Err(GpuError::NoCudaFeature)
904 }
905
906 pub fn finish(self) -> GpuResult<CapturedGraph> {
907 match self._never {}
908 }
909
910 pub fn status(&self) -> GpuResult<CaptureStatus> {
911 match self._never {}
912 }
913}
914
/// Stub `graph_pool_handle` when the cuda feature is not enabled. CL-454.
///
/// Returns the sentinel handle `GraphPoolHandle(0)`. With no registry
/// behind it, the handle never resolves to a pool (see
/// [`capture_pool_for_handle`]); `no_cuda_tests` relies on the zero
/// sentinel.
#[cfg(not(feature = "cuda"))]
pub fn graph_pool_handle() -> GraphPoolHandle {
    GraphPoolHandle(0)
}
920
/// Stub `capture_pool_for_handle` when the cuda feature is not enabled.
/// CL-454.
///
/// Always returns `None`: there is no pool registry without the cuda
/// feature, so no handle — including the zero sentinel from
/// [`graph_pool_handle`] — maps to a pool.
#[cfg(not(feature = "cuda"))]
pub fn capture_pool_for_handle(_handle: GraphPoolHandle) -> Option<std::sync::Arc<CapturePool>> {
    None
}
927
/// Stub `release_graph_pool_handle` when the cuda feature is not enabled.
/// CL-454.
///
/// No-op: the stub never registers anything, so releasing any handle —
/// any number of times — is safe and does nothing.
#[cfg(not(feature = "cuda"))]
pub fn release_graph_pool_handle(_handle: GraphPoolHandle) {
    // nothing to release
}
934
935/// Stub `make_graphed_callable` when the cuda feature is not enabled.
936/// CL-454.
937#[cfg(not(feature = "cuda"))]
938pub fn make_graphed_callable<T, F>(
939 _stream: &T,
940 _mode: CaptureMode,
941 _f: F,
942) -> GpuResult<CapturedGraph>
943where
944 F: FnOnce() -> GpuResult<()>,
945{
946 Err(GpuError::NoCudaFeature)
947}
948
949// ---------------------------------------------------------------------------
950// Tests — CL-278 capture pool buffer tracking
951// ---------------------------------------------------------------------------
952
#[cfg(all(test, feature = "cuda"))]
mod tests {
    use super::*;

    // -----------------------------------------------------------------------
    // CL-278 — CapturePool buffer tracking. None of these touch a device;
    // they exercise the host-side bookkeeping only.
    // -----------------------------------------------------------------------

    #[test]
    fn capture_pool_buffer_count_starts_at_zero() {
        let pool = CapturePool::new();
        assert_eq!(pool.buffer_count(), 0);
    }

    #[test]
    fn capture_pool_record_buffer_increments_count() {
        let pool = CapturePool::new();
        let buf_a: Vec<f32> = vec![0.0; 10];
        let idx = pool.record_buffer(buf_a);
        assert_eq!(idx, 0);
        assert_eq!(pool.buffer_count(), 1);

        // Indices are sequential even across different element types.
        let buf_b: Vec<f64> = vec![0.0; 5];
        let idx = pool.record_buffer(buf_b);
        assert_eq!(idx, 1);
        assert_eq!(pool.buffer_count(), 2);
    }

    #[test]
    fn capture_pool_clear_buffers_resets_count_but_keeps_pool() {
        let pool = CapturePool::new();
        pool.record_buffer(vec![0u8; 16]);
        pool.record_buffer(vec![0u8; 32]);
        assert_eq!(pool.buffer_count(), 2);
        pool.clear_buffers();
        assert_eq!(pool.buffer_count(), 0);
        // Pool is still usable.
        pool.record_buffer(vec![0u8; 8]);
        assert_eq!(pool.buffer_count(), 1);
    }

    #[test]
    fn capture_pool_drop_releases_registered_buffers() {
        // Use Arc to detect when the inner buffer is dropped.
        let buf = Arc::new(vec![1.0f32, 2.0, 3.0]);
        let pool = CapturePool::new();
        pool.record_buffer(Arc::clone(&buf));
        // One count held here, one held by the pool's record.
        assert_eq!(Arc::strong_count(&buf), 2);
        drop(pool);
        // Pool dropped → recorded Arc dropped → strong count back to 1.
        assert_eq!(Arc::strong_count(&buf), 1);
    }

    #[test]
    fn capture_pool_records_heterogeneous_types() {
        // record_buffer is type-erased: any owned value can be parked.
        let pool = CapturePool::new();
        pool.record_buffer(vec![0.0f32; 4]);
        pool.record_buffer(vec![0.0f64; 4]);
        pool.record_buffer(vec![0u8; 4]);
        pool.record_buffer(Arc::new(42i32));
        assert_eq!(pool.buffer_count(), 4);
    }

    #[test]
    fn capture_pool_seal_unseal() {
        // seal/unseal toggle the flag; unseal restores the initial state.
        let pool = CapturePool::new();
        assert!(!pool.is_capture_pool_sealed());
        pool.seal();
        assert!(pool.is_capture_pool_sealed());
        pool.unseal();
        assert!(!pool.is_capture_pool_sealed());
    }

    // -----------------------------------------------------------------------
    // CL-454 — CaptureMode / CaptureStatus / graph pool handle tests.
    //
    // These tests exercise the typed wrappers and the process-wide
    // pool-handle registry without requiring a real CUDA device.
    // Tests that actually touch a device live under the
    // `feature = "cuda-live"` gate (run them with
    // cargo test -p ferrotorch-gpu --features cuda,cuda-live
    // on a machine with a functioning CUDA driver).
    // -----------------------------------------------------------------------

    #[test]
    fn capture_mode_default_is_thread_local() {
        assert_eq!(CaptureMode::default(), CaptureMode::ThreadLocal);
    }

    // NOTE(review): despite the name, only the to_cuda direction is
    // checked here — consider renaming or adding from-cuda coverage if a
    // reverse mapping exists for CaptureMode.
    #[test]
    fn capture_mode_to_cuda_round_trip() {
        use cudarc::driver::sys::CUstreamCaptureMode::*;
        assert_eq!(CaptureMode::Global.to_cuda(), CU_STREAM_CAPTURE_MODE_GLOBAL);
        assert_eq!(
            CaptureMode::ThreadLocal.to_cuda(),
            CU_STREAM_CAPTURE_MODE_THREAD_LOCAL
        );
        assert_eq!(
            CaptureMode::Relaxed.to_cuda(),
            CU_STREAM_CAPTURE_MODE_RELAXED
        );
    }

    #[test]
    fn capture_status_is_capturing_only_when_active() {
        assert!(!CaptureStatus::None.is_capturing());
        assert!(CaptureStatus::Active.is_capturing());
        assert!(!CaptureStatus::Invalidated.is_capturing());
    }

    #[test]
    fn capture_status_is_invalidated_only_when_broken() {
        assert!(!CaptureStatus::None.is_invalidated());
        assert!(!CaptureStatus::Active.is_invalidated());
        assert!(CaptureStatus::Invalidated.is_invalidated());
    }

    #[test]
    fn capture_status_from_cuda_maps_all_variants() {
        use cudarc::driver::sys::CUstreamCaptureStatus::*;
        assert_eq!(
            CaptureStatus::from_cuda(CU_STREAM_CAPTURE_STATUS_NONE),
            CaptureStatus::None
        );
        assert_eq!(
            CaptureStatus::from_cuda(CU_STREAM_CAPTURE_STATUS_ACTIVE),
            CaptureStatus::Active
        );
        assert_eq!(
            CaptureStatus::from_cuda(CU_STREAM_CAPTURE_STATUS_INVALIDATED),
            CaptureStatus::Invalidated
        );
    }

    #[test]
    fn graph_pool_handle_allocates_unique_ids() {
        let h1 = graph_pool_handle();
        let h2 = graph_pool_handle();
        assert_ne!(h1, h2, "each call should return a fresh id");
        // Both handles should map back to a real pool.
        assert!(capture_pool_for_handle(h1).is_some());
        assert!(capture_pool_for_handle(h2).is_some());
        release_graph_pool_handle(h1);
        release_graph_pool_handle(h2);
    }

    #[test]
    fn graph_pool_handle_shares_single_pool_across_lookups() {
        let h = graph_pool_handle();
        let a = capture_pool_for_handle(h).expect("handle registered");
        let b = capture_pool_for_handle(h).expect("handle still registered");
        assert!(
            Arc::ptr_eq(&a, &b),
            "both lookups should return the same pool Arc"
        );

        // Register a buffer through one lookup; the other should see it.
        a.record_buffer(vec![1.0f32, 2.0]);
        assert_eq!(b.buffer_count(), 1);

        release_graph_pool_handle(h);
        // After release the registry no longer hands out the pool, but
        // existing Arcs keep it alive.
        assert!(capture_pool_for_handle(h).is_none());
        // The existing Arc still has its buffer.
        assert_eq!(a.buffer_count(), 1);
    }

    #[test]
    fn graph_pool_handle_release_is_idempotent() {
        let h = graph_pool_handle();
        assert!(capture_pool_for_handle(h).is_some());
        release_graph_pool_handle(h);
        release_graph_pool_handle(h); // second call is fine
        assert!(capture_pool_for_handle(h).is_none());
    }

    #[test]
    fn graph_pool_handle_unknown_id_returns_none() {
        // A fresh handle ID that was never registered.
        let fake = GraphPoolHandle(u64::MAX);
        assert!(capture_pool_for_handle(fake).is_none());
    }
}
1133
1134// ---------------------------------------------------------------------------
1135// CL-454 — tests that don't need cudarc type info.
1136// ---------------------------------------------------------------------------
1137
#[cfg(all(test, not(feature = "cuda")))]
mod no_cuda_tests {
    use super::*;

    /// The feature-portable types compile and behave sensibly without
    /// the cuda feature, so callers can write cfg-free code.
    #[test]
    fn capture_mode_and_status_exist_without_cuda_feature() {
        let _default_mode = CaptureMode::default();
        assert!(!CaptureStatus::None.is_capturing());
        assert!(CaptureStatus::Active.is_capturing());
        assert!(CaptureStatus::Invalidated.is_invalidated());
    }

    /// The stub registry hands out the zero sentinel and never resolves
    /// it to a pool; releasing is a no-op.
    #[test]
    fn graph_pool_handle_without_cuda_returns_sentinel() {
        let handle = graph_pool_handle();
        assert_eq!(handle.0, 0, "stub handle is always zero without cuda feature");
        assert!(capture_pool_for_handle(handle).is_none());
        release_graph_pool_handle(handle); // no-op by contract
    }

    /// Stub graph accessors report the inert defaults.
    #[test]
    fn captured_graph_stub_num_replays_and_is_uploaded_are_zero() {
        let graph = CapturedGraph;
        assert_eq!(graph.num_replays(), 0);
        assert!(!graph.is_uploaded());
    }
}