// ferrotorch_gpu/graph.rs
1//! CUDA graph capture and replay infrastructure.
2//!
3//! A CUDA graph records a sequence of GPU operations (kernel launches, memcpys)
4//! and replays them as a single driver submission. This eliminates per-kernel
5//! launch overhead (~70μs on WSL2, ~5μs on native Linux per call) by collapsing
6//! hundreds of launches into one.
7//!
8//! # Usage
9//!
10//! ```ignore
11//! use ferrotorch_gpu::graph::{DeviceScalar, begin_capture, end_capture};
12//!
13//! // Pre-allocate all buffers BEFORE capture
14//! let mut out = alloc_zeros_f32(768, &device)?;
15//!
16//! // Parameters that change between replays go in DeviceScalar
17//! let mut pos = DeviceScalar::new(device.stream(), 0u32)?;
18//!
19//! // Capture
20//! begin_capture(device.stream())?;
21//! gpu_add_into(&a, &b, &mut out, &device)?; // recorded, not executed
22//! let graph = end_capture(device.stream())?;
23//!
24//! // Replay loop
25//! for i in 0..100 {
//!     pos.update(i as u32)?; // memcpy before replay
//!     graph.launch()?; // replay all captured ops
28//! }
29//! ```
30
31#[cfg(feature = "cuda")]
32use std::sync::Arc;
33#[cfg(feature = "cuda")]
34use std::sync::atomic::{AtomicU64, Ordering};
35
36#[cfg(feature = "cuda")]
37use cudarc::driver::{CudaSlice, CudaStream, DeviceRepr, ValidAsZeroBits};
38
39use crate::error::{GpuError, GpuResult};
40
41// ---------------------------------------------------------------------------
42// CaptureMode — typed wrapper over cudarc's CUstreamCaptureMode
43// ---------------------------------------------------------------------------
44
/// Controls how CUDA graph capture is serialized against other
/// threads. Mirrors `cudaStreamCaptureMode`.
///
/// - [`Global`](Self::Global) — any CUDA API call from any thread that
///   touches the capturing stream (or from any other thread that is
///   also capturing) invalidates the capture. Safest for debugging;
///   matches PyTorch's default.
/// - [`ThreadLocal`](Self::ThreadLocal) — only calls issued from the
///   capturing thread can invalidate the capture; other threads may
///   freely use unrelated streams. This is ferrotorch-gpu's historical
///   behavior and the `Default` variant here.
/// - [`Relaxed`](Self::Relaxed) — the driver tracks no cross-thread
///   interactions at all. Fastest, but the caller is fully responsible
///   for ensuring no other thread interferes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum CaptureMode {
    /// Global serialization (`CU_STREAM_CAPTURE_MODE_GLOBAL`).
    Global,
    /// Thread-local serialization (`CU_STREAM_CAPTURE_MODE_THREAD_LOCAL`).
    #[default]
    ThreadLocal,
    /// No cross-thread serialization (`CU_STREAM_CAPTURE_MODE_RELAXED`).
    Relaxed,
}
71
72
#[cfg(feature = "cuda")]
impl CaptureMode {
    /// Map this mode onto the raw cudarc driver enum.
    #[inline]
    pub fn to_cuda(self) -> cudarc::driver::sys::CUstreamCaptureMode {
        use cudarc::driver::sys::CUstreamCaptureMode as Raw;
        match self {
            CaptureMode::Global => Raw::CU_STREAM_CAPTURE_MODE_GLOBAL,
            CaptureMode::ThreadLocal => Raw::CU_STREAM_CAPTURE_MODE_THREAD_LOCAL,
            CaptureMode::Relaxed => Raw::CU_STREAM_CAPTURE_MODE_RELAXED,
        }
    }
}
86
87// ---------------------------------------------------------------------------
88// CaptureStatus — typed wrapper over cudarc's CUstreamCaptureStatus
89// ---------------------------------------------------------------------------
90
/// Capture state of a CUDA stream. Mirrors `cudaStreamCaptureStatus`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CaptureStatus {
    /// No graph capture is in progress on the stream.
    None,
    /// A graph is actively being recorded on the stream.
    Active,
    /// The capture broke (e.g., a forbidden API call or cross-stream
    /// dependency). The caller must call `end_capture` to discard the
    /// broken graph before doing anything else on the stream.
    Invalidated,
}
104
#[cfg(feature = "cuda")]
impl CaptureStatus {
    /// Translate the raw cudarc driver status into the typed wrapper.
    fn from_cuda(raw: cudarc::driver::sys::CUstreamCaptureStatus) -> Self {
        use cudarc::driver::sys::CUstreamCaptureStatus as Raw;
        match raw {
            Raw::CU_STREAM_CAPTURE_STATUS_NONE => CaptureStatus::None,
            Raw::CU_STREAM_CAPTURE_STATUS_ACTIVE => CaptureStatus::Active,
            Raw::CU_STREAM_CAPTURE_STATUS_INVALIDATED => CaptureStatus::Invalidated,
        }
    }
}
116
117impl CaptureStatus {
118 /// Returns `true` if this stream is actively capturing a graph.
119 #[inline]
120 pub fn is_capturing(&self) -> bool {
121 matches!(self, Self::Active)
122 }
123
124 /// Returns `true` if capture was invalidated and must be ended.
125 #[inline]
126 pub fn is_invalidated(&self) -> bool {
127 matches!(self, Self::Invalidated)
128 }
129}
130
131// ---------------------------------------------------------------------------
132// DeviceScalar — a single value in GPU memory, updatable before graph replay
133// ---------------------------------------------------------------------------
134
/// A single scalar held in GPU device memory.
///
/// Designed for CUDA graph capture: the graph records the buffer's
/// (fixed) device address, and the host rewrites the value with
/// [`update`](DeviceScalar::update) before each
/// [`CapturedGraph::launch`]. Each update is a 4-or-8 byte
/// `cuMemcpyHtoDAsync` — effectively zero cost.
#[cfg(feature = "cuda")]
pub struct DeviceScalar<T: DeviceRepr + ValidAsZeroBits + Copy> {
    /// One-element device buffer holding the scalar.
    buf: CudaSlice<T>,
    /// Stream the update memcpys are issued on.
    stream: Arc<CudaStream>,
}
146
#[cfg(feature = "cuda")]
impl<T: DeviceRepr + ValidAsZeroBits + Copy> DeviceScalar<T> {
    /// Allocate the device-side scalar, initialized to `initial`.
    pub fn new(stream: &Arc<CudaStream>, initial: T) -> GpuResult<Self> {
        Ok(Self {
            buf: stream.clone_htod(&[initial])?,
            stream: Arc::clone(stream),
        })
    }

    /// Overwrite the device value with an async H→D memcpy of
    /// `size_of::<T>()` bytes. Must be issued on the same stream as the
    /// graph replay to guarantee ordering.
    pub fn update(&mut self, value: T) -> GpuResult<()> {
        self.stream.memcpy_htod(&[value], &mut self.buf)?;
        Ok(())
    }

    /// Borrow the one-element `CudaSlice` to pass as a kernel
    /// parameter. The graph records this pointer, so later `update`
    /// calls change what replays read without re-capturing.
    #[inline]
    pub fn inner(&self) -> &CudaSlice<T> {
        &self.buf
    }
}
173
174// ---------------------------------------------------------------------------
175// CapturedGraph — a replayable CUDA graph
176// ---------------------------------------------------------------------------
177
178/// A captured and instantiated CUDA graph that can be replayed with
179/// [`launch`](CapturedGraph::launch).
180///
181/// Created via [`begin_capture`] + GPU ops + [`end_capture`].
182/// The graph holds references to all device memory used during capture.
183/// Those buffers must remain allocated for the lifetime of the graph.
184///
185/// **Allocator pool integration (CL-278).** When created via
186/// [`end_capture_with_pool`], the graph holds a strong reference to
187/// the [`CapturePool`] that recorded its allocations. The pool keeps
188/// every registered buffer alive until the last `CapturedGraph`
189/// referencing it is dropped, which guarantees the device pointers
190/// recorded in the graph remain valid across replays. Without the
191/// pool, callers must manually keep buffers alive (the original
192/// [`end_capture`] API).
193#[cfg(feature = "cuda")]
194pub struct CapturedGraph {
195 graph: cudarc::driver::CudaGraph,
196 /// Optional reference to the pool that owns the graph's
197 /// allocations. Some(pool) when constructed via
198 /// [`end_capture_with_pool`]. Dropping the graph drops this
199 /// Arc, which (if it's the last reference) drops every buffer
200 /// the pool holds. CL-278.
201 pool: Option<Arc<CapturePool>>,
202 /// Monotonic counter bumped by every successful [`launch`]. Lets
203 /// callers assert that a specific replay happened after some
204 /// other work completed, useful for graph-aware profilers and
205 /// integration tests. CL-454.
206 replay_count: AtomicU64,
207 /// True after the first successful [`upload`] so subsequent
208 /// uploads become cheap no-ops. CL-454.
209 uploaded: std::sync::atomic::AtomicBool,
210}
211
#[cfg(feature = "cuda")]
impl CapturedGraph {
    /// Replay every operation recorded in this graph.
    ///
    /// Update any [`DeviceScalar`] values and run any pre-launch
    /// memcpys first — all on the stream the graph was captured on.
    /// Bumps [`num_replays`](Self::num_replays) on success.
    pub fn launch(&self) -> GpuResult<()> {
        self.graph.launch()?;
        self.replay_count.fetch_add(1, Ordering::Relaxed);
        Ok(())
    }

    /// Push the graph's executable resources to the device ahead of
    /// time, moving the one-time first-launch cost out of the hot
    /// replay loop. Subsequent calls are no-ops. CL-454.
    pub fn upload(&self) -> GpuResult<()> {
        if !self.uploaded.load(Ordering::Acquire) {
            self.graph.upload()?;
            self.uploaded.store(true, Ordering::Release);
        }
        Ok(())
    }

    /// Number of successful replays issued on this graph. CL-454.
    #[inline]
    pub fn num_replays(&self) -> u64 {
        self.replay_count.load(Ordering::Relaxed)
    }

    /// Whether [`upload`](Self::upload) has completed on this graph.
    /// CL-454.
    #[inline]
    pub fn is_uploaded(&self) -> bool {
        self.uploaded.load(Ordering::Acquire)
    }

    /// Buffers kept alive by this graph's allocator pool, or 0 when
    /// the graph was created without a pool. CL-278.
    pub fn pool_buffer_count(&self) -> usize {
        self.pool.as_ref().map_or(0, |pool| pool.buffer_count())
    }

    /// Whether this graph holds a [`CapturePool`] reference. CL-278.
    pub fn has_pool(&self) -> bool {
        self.pool.is_some()
    }

    /// The [`Arc<CapturePool>`] this graph is using, if any. Sharing
    /// one pool between several graphs keeps the same buffers alive
    /// for all of them. CL-454.
    pub fn pool(&self) -> Option<&Arc<CapturePool>> {
        self.pool.as_ref()
    }
}
276
277// ---------------------------------------------------------------------------
278// Capture API
279// ---------------------------------------------------------------------------
280
/// Begin CUDA graph capture on `stream`.
///
/// Every GPU operation issued on this stream afterwards (kernel
/// launches, cuBLAS calls, memcpys) is **recorded, not executed**,
/// until [`end_capture`] finalizes and instantiates the graph.
///
/// # Requirements
///
/// - Pre-allocate all output buffers before capture begins.
/// - No `alloc_zeros` / `cpu_to_gpu` during capture (use `_into` variants).
/// - No CPU↔GPU synchronization during capture.
/// - Disable event tracking during capture to avoid interference
///   (call `ctx.disable_event_tracking()` before, re-enable after).
#[cfg(feature = "cuda")]
pub fn begin_capture(stream: &Arc<CudaStream>) -> GpuResult<()> {
    begin_capture_with_mode(stream, CaptureMode::default())
}
298
/// Begin CUDA graph capture with an explicit [`CaptureMode`]. CL-454.
///
/// [`begin_capture`] covers the common default (`ThreadLocal`); reach
/// for this form when you need `Global` (strict serialization for
/// debugging) or `Relaxed` (maximum throughput with single-thread
/// ownership).
#[cfg(feature = "cuda")]
pub fn begin_capture_with_mode(
    stream: &Arc<CudaStream>,
    mode: CaptureMode,
) -> GpuResult<()> {
    Ok(stream.begin_capture(mode.to_cuda())?)
}
312
/// Query the capture status of a CUDA stream. CL-454.
///
/// The ferrotorch-gpu equivalent of PyTorch's
/// `torch.cuda.is_current_stream_capturing`: callers can use it to
/// skip capture-invalid APIs (allocator calls, H↔D copies) while a
/// graph is being recorded.
#[cfg(feature = "cuda")]
pub fn capture_status(stream: &Arc<CudaStream>) -> GpuResult<CaptureStatus> {
    Ok(CaptureStatus::from_cuda(stream.capture_status()?))
}
324
/// Shorthand for `capture_status(stream)?.is_capturing()`. CL-454.
#[cfg(feature = "cuda")]
pub fn is_stream_capturing(stream: &Arc<CudaStream>) -> GpuResult<bool> {
    capture_status(stream).map(|status| status.is_capturing())
}
330
/// End CUDA graph capture, instantiate, and return the replayable graph.
///
/// # Errors
///
/// Returns [`GpuError::InvalidState`] if the driver reports no active
/// capture (a null graph), or a CUDA driver error if ending the
/// capture / instantiation fails.
///
/// The returned graph has no [`CapturePool`] attached: the caller is
/// responsible for keeping the buffers used by the captured kernels
/// alive for the graph's lifetime. Use [`end_capture_with_pool`] for
/// the lifetime-managed variant.
#[cfg(feature = "cuda")]
pub fn end_capture(stream: &Arc<CudaStream>) -> GpuResult<CapturedGraph> {
    let flags = cudarc::driver::sys::CUgraphInstantiate_flags_enum::CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH;
    // A null graph here means the stream was not capturing — report it
    // as an invalid-state error (consistent with `begin_capture_with_pool`)
    // instead of the previous, misleading `PtxCompileFailed` variant.
    let graph = stream
        .end_capture(flags)?
        .ok_or_else(|| GpuError::InvalidState {
            message: "CUDA graph capture returned a null graph (no active capture?)".into(),
        })?;
    Ok(CapturedGraph {
        graph,
        pool: None,
        replay_count: AtomicU64::new(0),
        uploaded: std::sync::atomic::AtomicBool::new(false),
    })
}
354
/// End CUDA graph capture and attach `pool` to the resulting
/// [`CapturedGraph`]. CL-278.
///
/// The pool's tracked buffers live as long as the returned graph:
/// dropping the graph drops its `Arc<CapturePool>`, and dropping the
/// last Arc drops every buffer the pool recorded. This keeps the
/// device pointers recorded in the captured graph valid across
/// replays.
///
/// Pair with [`CapturePool::record_buffer`]: allocate every buffer
/// used during capture before `begin_capture`, register each one with
/// the pool, run the kernels under capture, then call
/// `end_capture_with_pool(stream, pool)` to seal the lifetime
/// relationship.
#[cfg(feature = "cuda")]
pub fn end_capture_with_pool(
    stream: &Arc<CudaStream>,
    pool: Arc<CapturePool>,
) -> GpuResult<CapturedGraph> {
    end_capture(stream).map(|mut graph| {
        graph.pool = Some(pool);
        graph
    })
}
378
379// ---------------------------------------------------------------------------
380// GraphCaptureGuard — RAII wrapper that ends capture on drop
381// ---------------------------------------------------------------------------
382
/// RAII guard that scopes a CUDA graph capture.
///
/// Start capture with [`GraphCaptureGuard::begin`] (or
/// [`begin_with_mode`] / [`begin_with_pool`]); [`finish`] returns the
/// instantiated graph. Dropping the guard without calling `finish`
/// (e.g., because a kernel errored mid-capture) best-effort-ends the
/// capture and discards the graph so the stream returns to a usable
/// state. CL-454.
///
/// The Rust-RAII analogue of PyTorch's
/// `with torch.cuda.graph(g): ...` context manager.
///
/// # Example
///
/// ```ignore
/// use ferrotorch_gpu::graph::GraphCaptureGuard;
///
/// let mut guard = GraphCaptureGuard::begin(device.stream())?;
/// run_kernels()?; // any kernel launched on device.stream() is recorded
/// let graph = guard.finish()?;
/// graph.upload()?;
/// for _ in 0..1000 { graph.launch()?; }
/// ```
#[cfg(feature = "cuda")]
pub struct GraphCaptureGuard {
    /// Stream the capture is running on.
    stream: Arc<CudaStream>,
    /// Pool to attach to the graph when [`finish`] runs, if any.
    pool: Option<Arc<CapturePool>>,
    /// Cleared by [`finish`] so `Drop` knows capture already ended.
    active: bool,
}
416
#[cfg(feature = "cuda")]
impl GraphCaptureGuard {
    /// Start capture on `stream` using the default
    /// [`CaptureMode::ThreadLocal`] mode. CL-454.
    pub fn begin(stream: &Arc<CudaStream>) -> GpuResult<Self> {
        Self::begin_with_mode(stream, CaptureMode::default())
    }

    /// Start capture with an explicit [`CaptureMode`]. CL-454.
    pub fn begin_with_mode(
        stream: &Arc<CudaStream>,
        mode: CaptureMode,
    ) -> GpuResult<Self> {
        begin_capture_with_mode(stream, mode)?;
        Ok(Self {
            stream: Arc::clone(stream),
            pool: None,
            active: true,
        })
    }

    /// Start capture bound to a [`CapturePool`]; [`finish`] attaches
    /// the pool to the resulting graph. Fails up front (before capture
    /// starts) if the pool is sealed. CL-454.
    pub fn begin_with_pool(
        stream: &Arc<CudaStream>,
        pool: Arc<CapturePool>,
    ) -> GpuResult<Self> {
        begin_capture_with_pool(&pool, stream)?;
        Ok(Self {
            stream: Arc::clone(stream),
            pool: Some(pool),
            active: true,
        })
    }

    /// End capture and return the instantiated [`CapturedGraph`].
    ///
    /// Consumes the guard, so its `Drop` becomes a no-op. When a pool
    /// was supplied at construction, the graph is produced via
    /// [`end_capture_with_pool`] and keeps the pool Arc alive for its
    /// own lifetime.
    pub fn finish(mut self) -> GpuResult<CapturedGraph> {
        self.active = false;
        match self.pool.take() {
            Some(pool) => end_capture_with_pool(&self.stream, pool),
            None => end_capture(&self.stream),
        }
    }

    /// Current capture status of the guarded stream. An unexpected
    /// `Invalidated` or `None` usually means a forbidden API call
    /// (alloc, sync, host copy) happened under capture.
    pub fn status(&self) -> GpuResult<CaptureStatus> {
        capture_status(&self.stream)
    }
}
475
#[cfg(feature = "cuda")]
impl Drop for GraphCaptureGuard {
    fn drop(&mut self) {
        // `finish` clears `active`; if it never ran, best-effort end
        // the in-flight capture so the stream becomes usable again.
        // Errors are ignored (we're in Drop) and the resulting
        // CapturedGraph is discarded immediately.
        if self.active {
            let _ = end_capture(&self.stream);
        }
    }
}
488
489// ---------------------------------------------------------------------------
490// Graph pool handle registry — share a CapturePool across multiple graphs
491// ---------------------------------------------------------------------------
492
493/// Opaque handle for a pool registered with the process-wide graph
494/// pool registry. Used to share the same buffer-lifetime pool across
495/// multiple captured graphs without passing `Arc<CapturePool>` around
496/// by hand. CL-454.
497///
498/// Mirrors PyTorch's `torch.cuda.graph_pool_handle()`.
499#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
500pub struct GraphPoolHandle(pub u64);
501
#[cfg(feature = "cuda")]
static NEXT_POOL_HANDLE: AtomicU64 = AtomicU64::new(1);

/// Handle → pool map shared by the registry accessors below.
#[cfg(feature = "cuda")]
type PoolMap = std::collections::HashMap<u64, Arc<CapturePool>>;

#[cfg(feature = "cuda")]
static POOL_REGISTRY: std::sync::OnceLock<std::sync::Mutex<PoolMap>> =
    std::sync::OnceLock::new();

/// Lazily initialize and return the process-wide pool registry.
#[cfg(feature = "cuda")]
fn pool_registry() -> &'static std::sync::Mutex<PoolMap> {
    POOL_REGISTRY.get_or_init(|| std::sync::Mutex::new(PoolMap::new()))
}
514
/// Allocate a fresh [`GraphPoolHandle`] and register a new
/// [`CapturePool`] under it in the process-wide registry. CL-454.
///
/// Pass the handle to [`capture_pool_for_handle`] later (from any
/// thread) to retrieve the same `Arc<CapturePool>`, letting two
/// independently captured graphs share one buffer-keeping pool.
#[cfg(feature = "cuda")]
pub fn graph_pool_handle() -> GraphPoolHandle {
    let id = NEXT_POOL_HANDLE.fetch_add(1, Ordering::Relaxed);
    pool_registry()
        .lock()
        // Recover from poisoning; registry state stays consistent.
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .insert(id, Arc::new(CapturePool::new()));
    GraphPoolHandle(id)
}
532
/// Look up the [`CapturePool`] registered under `handle`, returning a
/// strong `Arc` to it. `None` if the handle was never allocated or has
/// been released via [`release_graph_pool_handle`]. CL-454.
#[cfg(feature = "cuda")]
pub fn capture_pool_for_handle(handle: GraphPoolHandle) -> Option<Arc<CapturePool>> {
    pool_registry()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .get(&handle.0)
        .cloned()
}
544
/// Drop the registry's strong reference to the pool behind `handle`.
/// Any [`CapturedGraph`] holding its own Arc (e.g. via
/// [`end_capture_with_pool`]) keeps the pool alive until that graph is
/// dropped too. CL-454.
#[cfg(feature = "cuda")]
pub fn release_graph_pool_handle(handle: GraphPoolHandle) {
    pool_registry()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .remove(&handle.0);
}
556
557// ---------------------------------------------------------------------------
558// make_graphed_callable — scoped capture over a closure
559// ---------------------------------------------------------------------------
560
/// Capture the GPU work performed by `f` into a CUDA graph and return
/// the replayable graph. CL-454.
///
/// The ferrotorch-gpu equivalent of PyTorch's
/// `torch.cuda.make_graphed_callables` for the single-callable case:
/// the closure runs exactly once, under capture, so any per-call work
/// that is invalid during capture (allocations, dtype decisions) must
/// happen outside it. The returned [`CapturedGraph`] can then be
/// replayed repeatedly.
///
/// If `f` returns an error, the capture is discarded and the error is
/// propagated.
#[cfg(feature = "cuda")]
pub fn make_graphed_callable<F>(
    stream: &Arc<CudaStream>,
    mode: CaptureMode,
    f: F,
) -> GpuResult<CapturedGraph>
where
    F: FnOnce() -> GpuResult<()>,
{
    let guard = GraphCaptureGuard::begin_with_mode(stream, mode)?;
    if let Err(e) = f() {
        // Dropping the guard best-effort-ends capture and discards
        // the partially recorded graph.
        drop(guard);
        return Err(e);
    }
    guard.finish()
}
593
594// ---------------------------------------------------------------------------
595// CapturePool — memory pool for graph capture
596// ---------------------------------------------------------------------------
597
/// A dedicated memory pool for CUDA graph capture.
///
/// It has two jobs:
///
/// 1. **Sealed flag** — gates [`begin_capture_with_pool`], letting the
///    caller express "no more allocations past this point". A sealed
///    pool cannot satisfy new allocations during capture.
///
/// 2. **Buffer lifetime tracking (CL-278)** — registered buffers are
///    owned by the pool, so they outlive any [`CapturedGraph`] holding
///    an `Arc<CapturePool>`. Dropping the graph drops the Arc; the
///    last Arc drops every registered buffer in registration order.
///
/// # Usage
///
/// ```ignore
/// use std::sync::Arc;
/// let pool = Arc::new(CapturePool::new());
///
/// // Allocate every buffer the captured kernels will read or
/// // write, and register each one with the pool so it stays alive
/// // for the graph's lifetime.
/// let mut buf_a = alloc_zeros_f32(1024, &device)?;
/// let mut buf_b = alloc_zeros_f32(1024, &device)?;
/// pool.record_buffer(buf_a.try_clone()?);
/// pool.record_buffer(buf_b.try_clone()?);
///
/// pool.seal();
/// begin_capture_with_pool(&pool, stream)?;
/// // ... launch kernels using buf_a and buf_b ...
/// let graph = end_capture_with_pool(stream, Arc::clone(&pool))?;
/// // Dropping `pool` here is safe — the graph holds its own Arc.
/// ```
#[cfg(feature = "cuda")]
pub struct CapturePool {
    /// "No further allocations" marker checked by
    /// [`begin_capture_with_pool`].
    sealed: std::sync::atomic::AtomicBool,
    /// Type-erased registered buffers, kept alive for the graph's
    /// lifetime and dropped in registration order. CL-278.
    buffers: std::sync::Mutex<Vec<Box<dyn std::any::Any + Send + Sync + 'static>>>,
}
641
#[cfg(feature = "cuda")]
impl CapturePool {
    /// Create a new, unsealed capture pool with no registered buffers.
    pub fn new() -> Self {
        Self {
            sealed: std::sync::atomic::AtomicBool::new(false),
            buffers: std::sync::Mutex::new(Vec::new()),
        }
    }

    /// Seal the pool, preventing any further allocations.
    pub fn seal(&self) {
        self.sealed
            .store(true, std::sync::atomic::Ordering::Release);
    }

    /// Unseal the pool, allowing allocations again.
    pub fn unseal(&self) {
        self.sealed
            .store(false, std::sync::atomic::Ordering::Release);
    }

    /// Check whether the pool is sealed.
    pub fn is_capture_pool_sealed(&self) -> bool {
        self.sealed.load(std::sync::atomic::Ordering::Acquire)
    }

    /// Register a buffer with the pool so it stays alive for the
    /// lifetime of any [`CapturedGraph`] that holds this pool.
    /// CL-278.
    ///
    /// `buffer` can be any type that owns GPU memory (typically
    /// `CudaBuffer<f32>`, `CudaBuffer<f64>`, or `Arc<CudaBuffer<T>>`).
    /// The pool stores it in a type-erased `Box<dyn Any + Send +
    /// Sync>` and drops it (in registration order) when the pool
    /// itself is dropped.
    ///
    /// Returns the index of the registered buffer for diagnostic
    /// purposes.
    pub fn record_buffer<B>(&self, buffer: B) -> usize
    where
        B: Send + Sync + 'static,
    {
        let mut guard = self
            .buffers
            .lock()
            .unwrap_or_else(|p| p.into_inner());
        let idx = guard.len();
        guard.push(Box::new(buffer));
        idx
    }

    /// Number of buffers currently registered with the pool. CL-278.
    ///
    /// Recovers from a poisoned lock like every other method here —
    /// previously this returned 0 on poisoning, misreporting the
    /// count even though the registered buffers were still alive.
    pub fn buffer_count(&self) -> usize {
        self.buffers
            .lock()
            .unwrap_or_else(|p| p.into_inner())
            .len()
    }

    /// Drop every registered buffer immediately, in registration
    /// order. The pool itself remains usable; new buffers can still
    /// be registered after this call. CL-278.
    ///
    /// Use this when reusing a pool across multiple capture cycles.
    /// Calling clear while a [`CapturedGraph`] still holds an Arc
    /// to this pool is safe — the graph's strong reference keeps
    /// the pool struct alive, but the buffer slots are reset.
    pub fn clear_buffers(&self) {
        self.buffers
            .lock()
            .unwrap_or_else(|p| p.into_inner())
            .clear();
    }
}
718
#[cfg(feature = "cuda")]
impl Default for CapturePool {
    /// Equivalent to [`CapturePool::new`]: unsealed, no buffers.
    fn default() -> Self {
        CapturePool::new()
    }
}
725
/// Begin CUDA graph capture with a capture pool.
///
/// Like [`begin_capture`], but first checks that `pool` is not sealed.
/// A sealed pool cannot satisfy allocations during graph capture,
/// which would otherwise surface as CUDA errors mid-capture.
///
/// # Errors
///
/// Returns [`GpuError::InvalidState`] if the pool is sealed (the doc
/// previously said `InvalidArgument`, which did not match the code).
/// Returns a CUDA driver error if `begin_capture` fails.
#[cfg(feature = "cuda")]
pub fn begin_capture_with_pool(pool: &CapturePool, stream: &Arc<CudaStream>) -> GpuResult<()> {
    if pool.is_capture_pool_sealed() {
        return Err(GpuError::InvalidState {
            message: "cannot begin graph capture: capture pool is sealed".into(),
        });
    }
    begin_capture(stream)
}
745
/// Stub CapturePool when the cuda feature is disabled. Mirrors the
/// cuda-enabled type's surface API (seal/unseal/sealed-check,
/// record/clear/count) so callers compile on both feature
/// configurations.
#[cfg(not(feature = "cuda"))]
pub struct CapturePool;

#[cfg(not(feature = "cuda"))]
impl CapturePool {
    /// Create an empty CapturePool. Without the cuda feature the
    /// pool has no internal state to initialize.
    pub fn new() -> Self {
        Self
    }

    /// No-op without the cuda feature: there is no real CUDA pool
    /// to seal because no real allocations can happen.
    pub fn seal(&self) {}

    /// No-op without the cuda feature: there is no real CUDA pool
    /// to unseal because no real allocations can happen.
    pub fn unseal(&self) {}

    /// Always returns `false` without the cuda feature since there
    /// is no real pool that could be in either state.
    pub fn is_capture_pool_sealed(&self) -> bool {
        false
    }

    /// No-op without the cuda feature; always reports index 0 since
    /// nothing is actually tracked. Added for surface-API parity with
    /// the cuda-enabled `record_buffer`. CL-278.
    pub fn record_buffer<B>(&self, _buffer: B) -> usize
    where
        B: Send + Sync + 'static,
    {
        0
    }

    /// No-op without the cuda feature: nothing is tracked, so there
    /// is nothing to clear. Surface-API parity with the cuda-enabled
    /// `clear_buffers`. CL-278.
    pub fn clear_buffers(&self) {}

    /// Always returns 0 without the cuda feature since no real
    /// allocations can be tracked. CL-278.
    pub fn buffer_count(&self) -> usize {
        0
    }
}

#[cfg(not(feature = "cuda"))]
impl Default for CapturePool {
    fn default() -> Self {
        Self::new()
    }
}
795
796/// Stub begin_capture_with_pool when cuda feature is disabled.
797#[cfg(not(feature = "cuda"))]
798pub fn begin_capture_with_pool<T>(_pool: &CapturePool, _stream: &T) -> GpuResult<()> {
799 Err(GpuError::NoCudaFeature)
800}
801
802// ---------------------------------------------------------------------------
803// Stubs when cuda feature is disabled
804// ---------------------------------------------------------------------------
805
806/// Stub DeviceScalar.
807#[cfg(not(feature = "cuda"))]
808pub struct DeviceScalar<T: Copy> {
809 _phantom: std::marker::PhantomData<T>,
810}
811
812/// Stub CapturedGraph.
813#[cfg(not(feature = "cuda"))]
814pub struct CapturedGraph;
815
816#[cfg(not(feature = "cuda"))]
817impl CapturedGraph {
818 pub fn launch(&self) -> GpuResult<()> {
819 Err(GpuError::NoCudaFeature)
820 }
821
822 /// Stub upload for CL-454.
823 pub fn upload(&self) -> GpuResult<()> {
824 Err(GpuError::NoCudaFeature)
825 }
826
827 /// Stub num_replays — always 0 without the cuda feature. CL-454.
828 pub fn num_replays(&self) -> u64 {
829 0
830 }
831
832 /// Stub is_uploaded — always false without the cuda feature. CL-454.
833 pub fn is_uploaded(&self) -> bool {
834 false
835 }
836}
837
838#[cfg(not(feature = "cuda"))]
839pub fn begin_capture<T>(_stream: &T) -> GpuResult<()> {
840 Err(GpuError::NoCudaFeature)
841}
842
843/// Stub `begin_capture_with_mode` when the cuda feature is not enabled.
844/// CL-454.
845#[cfg(not(feature = "cuda"))]
846pub fn begin_capture_with_mode<T>(_stream: &T, _mode: CaptureMode) -> GpuResult<()> {
847 Err(GpuError::NoCudaFeature)
848}
849
850/// Stub `capture_status` when the cuda feature is not enabled. CL-454.
851#[cfg(not(feature = "cuda"))]
852pub fn capture_status<T>(_stream: &T) -> GpuResult<CaptureStatus> {
853 Err(GpuError::NoCudaFeature)
854}
855
856/// Stub `is_stream_capturing` when the cuda feature is not enabled.
857/// CL-454.
858#[cfg(not(feature = "cuda"))]
859pub fn is_stream_capturing<T>(_stream: &T) -> GpuResult<bool> {
860 Err(GpuError::NoCudaFeature)
861}
862
863#[cfg(not(feature = "cuda"))]
864pub fn end_capture<T>(_stream: &T) -> GpuResult<CapturedGraph> {
865 Err(GpuError::NoCudaFeature)
866}
867
868/// Stub `end_capture_with_pool` when the cuda feature is not enabled.
869/// CL-278.
870#[cfg(not(feature = "cuda"))]
871pub fn end_capture_with_pool<T>(
872 _stream: &T,
873 _pool: std::sync::Arc<CapturePool>,
874) -> GpuResult<CapturedGraph> {
875 Err(GpuError::NoCudaFeature)
876}
877
/// Stub `GraphCaptureGuard` when the cuda feature is not enabled. CL-454.
///
/// Holds a `core::convert::Infallible` — an uninhabited type — so no value
/// of this struct can ever be constructed. The `begin*` constructors always
/// return `Err`, and the methods that take an existing guard can be written
/// as empty (provably unreachable) matches.
#[cfg(not(feature = "cuda"))]
pub struct GraphCaptureGuard {
    // Uninhabited field: makes the stub impossible to instantiate.
    _never: core::convert::Infallible,
}
883
884#[cfg(not(feature = "cuda"))]
885impl GraphCaptureGuard {
886 pub fn begin<T>(_stream: &T) -> GpuResult<Self> {
887 Err(GpuError::NoCudaFeature)
888 }
889
890 pub fn begin_with_mode<T>(_stream: &T, _mode: CaptureMode) -> GpuResult<Self> {
891 Err(GpuError::NoCudaFeature)
892 }
893
894 pub fn begin_with_pool<T>(
895 _stream: &T,
896 _pool: std::sync::Arc<CapturePool>,
897 ) -> GpuResult<Self> {
898 Err(GpuError::NoCudaFeature)
899 }
900
901 pub fn finish(self) -> GpuResult<CapturedGraph> {
902 match self._never {}
903 }
904
905 pub fn status(&self) -> GpuResult<CaptureStatus> {
906 match self._never {}
907 }
908}
909
/// Stub `graph_pool_handle` when the cuda feature is not enabled. CL-454.
/// Returns the sentinel handle `0`, which resolves to no pool (the stub
/// `capture_pool_for_handle` always answers `None`).
#[cfg(not(feature = "cuda"))]
pub fn graph_pool_handle() -> GraphPoolHandle {
    GraphPoolHandle(0)
}
915
/// Stub `capture_pool_for_handle` when the cuda feature is not enabled —
/// no registry exists, so every lookup yields `None`. CL-454.
#[cfg(not(feature = "cuda"))]
pub fn capture_pool_for_handle(_handle: GraphPoolHandle) -> Option<std::sync::Arc<CapturePool>> {
    None
}
922
/// Stub `release_graph_pool_handle` when the cuda feature is not enabled.
/// CL-454.
#[cfg(not(feature = "cuda"))]
pub fn release_graph_pool_handle(_handle: GraphPoolHandle) {
    // Nothing was ever registered, so there is nothing to release.
}
929
/// Stub `make_graphed_callable` when the cuda feature is not enabled.
/// The closure `_f` is never invoked; the call fails immediately.
/// CL-454.
#[cfg(not(feature = "cuda"))]
pub fn make_graphed_callable<T, F>(
    _stream: &T,
    _mode: CaptureMode,
    _f: F,
) -> GpuResult<CapturedGraph>
where
    F: FnOnce() -> GpuResult<()>,
{
    Err(GpuError::NoCudaFeature)
}
943
944// ---------------------------------------------------------------------------
945// Tests — CL-278 capture pool buffer tracking
946// ---------------------------------------------------------------------------
947
#[cfg(all(test, feature = "cuda"))]
mod tests {
    use super::*;

    #[test]
    fn capture_pool_buffer_count_starts_at_zero() {
        assert_eq!(CapturePool::new().buffer_count(), 0);
    }

    #[test]
    fn capture_pool_record_buffer_increments_count() {
        let pool = CapturePool::new();

        let first: Vec<f32> = vec![0.0; 10];
        assert_eq!(pool.record_buffer(first), 0);
        assert_eq!(pool.buffer_count(), 1);

        let second: Vec<f64> = vec![0.0; 5];
        assert_eq!(pool.record_buffer(second), 1);
        assert_eq!(pool.buffer_count(), 2);
    }

    #[test]
    fn capture_pool_clear_buffers_resets_count_but_keeps_pool() {
        let pool = CapturePool::new();
        for len in [16usize, 32] {
            pool.record_buffer(vec![0u8; len]);
        }
        assert_eq!(pool.buffer_count(), 2);

        pool.clear_buffers();
        assert_eq!(pool.buffer_count(), 0);

        // A cleared pool must remain usable for further recordings.
        pool.record_buffer(vec![0u8; 8]);
        assert_eq!(pool.buffer_count(), 1);
    }

    #[test]
    fn capture_pool_drop_releases_registered_buffers() {
        // An Arc's strong count reveals when the pool lets go of a buffer.
        let shared = Arc::new(vec![1.0f32, 2.0, 3.0]);
        let pool = CapturePool::new();
        pool.record_buffer(Arc::clone(&shared));
        assert_eq!(Arc::strong_count(&shared), 2);

        // Dropping the pool drops its recorded clone, leaving only ours.
        drop(pool);
        assert_eq!(Arc::strong_count(&shared), 1);
    }

    #[test]
    fn capture_pool_records_heterogeneous_types() {
        // The pool is type-erased: any owned buffer type can be recorded.
        let pool = CapturePool::new();
        pool.record_buffer(vec![0.0f32; 4]);
        pool.record_buffer(vec![0.0f64; 4]);
        pool.record_buffer(vec![0u8; 4]);
        pool.record_buffer(Arc::new(42i32));
        assert_eq!(pool.buffer_count(), 4);
    }

    #[test]
    fn capture_pool_seal_unseal() {
        let pool = CapturePool::new();
        assert!(!pool.is_capture_pool_sealed());
        pool.seal();
        assert!(pool.is_capture_pool_sealed());
        pool.unseal();
        assert!(!pool.is_capture_pool_sealed());
    }

    // -----------------------------------------------------------------------
    // CL-454 — CaptureMode / CaptureStatus / graph pool handle tests.
    //
    // Everything below exercises the typed wrappers and the process-wide
    // pool-handle registry; none of it needs a real CUDA device. Tests
    // that do touch a device live behind `feature = "cuda-live"` (run via
    //     cargo test -p ferrotorch-gpu --features cuda,cuda-live
    // on a machine with a working CUDA driver).
    // -----------------------------------------------------------------------

    #[test]
    fn capture_mode_default_is_thread_local() {
        assert_eq!(CaptureMode::ThreadLocal, CaptureMode::default());
    }

    #[test]
    fn capture_mode_to_cuda_round_trip() {
        use cudarc::driver::sys::CUstreamCaptureMode::*;
        let expected = [
            (CaptureMode::Global, CU_STREAM_CAPTURE_MODE_GLOBAL),
            (CaptureMode::ThreadLocal, CU_STREAM_CAPTURE_MODE_THREAD_LOCAL),
            (CaptureMode::Relaxed, CU_STREAM_CAPTURE_MODE_RELAXED),
        ];
        for (mode, raw) in expected {
            assert_eq!(mode.to_cuda(), raw);
        }
    }

    #[test]
    fn capture_status_is_capturing_only_when_active() {
        assert!(CaptureStatus::Active.is_capturing());
        assert!(!CaptureStatus::None.is_capturing());
        assert!(!CaptureStatus::Invalidated.is_capturing());
    }

    #[test]
    fn capture_status_is_invalidated_only_when_broken() {
        assert!(CaptureStatus::Invalidated.is_invalidated());
        assert!(!CaptureStatus::None.is_invalidated());
        assert!(!CaptureStatus::Active.is_invalidated());
    }

    #[test]
    fn capture_status_from_cuda_maps_all_variants() {
        use cudarc::driver::sys::CUstreamCaptureStatus::*;
        let mapping = [
            (CU_STREAM_CAPTURE_STATUS_NONE, CaptureStatus::None),
            (CU_STREAM_CAPTURE_STATUS_ACTIVE, CaptureStatus::Active),
            (CU_STREAM_CAPTURE_STATUS_INVALIDATED, CaptureStatus::Invalidated),
        ];
        for (raw, wrapped) in mapping {
            assert_eq!(CaptureStatus::from_cuda(raw), wrapped);
        }
    }

    #[test]
    fn graph_pool_handle_allocates_unique_ids() {
        let first = graph_pool_handle();
        let second = graph_pool_handle();
        assert_ne!(first, second, "each call should return a fresh id");
        // While registered, both ids must resolve to live pools.
        assert!(capture_pool_for_handle(first).is_some());
        assert!(capture_pool_for_handle(second).is_some());
        release_graph_pool_handle(first);
        release_graph_pool_handle(second);
    }

    #[test]
    fn graph_pool_handle_shares_single_pool_across_lookups() {
        let handle = graph_pool_handle();
        let first = capture_pool_for_handle(handle).expect("handle registered");
        let second = capture_pool_for_handle(handle).expect("handle still registered");
        assert!(
            Arc::ptr_eq(&first, &second),
            "both lookups should return the same pool Arc"
        );

        // A buffer recorded through one Arc is visible through the other.
        first.record_buffer(vec![1.0f32, 2.0]);
        assert_eq!(second.buffer_count(), 1);

        release_graph_pool_handle(handle);
        // A released handle no longer resolves through the registry…
        assert!(capture_pool_for_handle(handle).is_none());
        // …but outstanding Arcs keep the pool (and its buffers) alive.
        assert_eq!(first.buffer_count(), 1);
    }

    #[test]
    fn graph_pool_handle_release_is_idempotent() {
        let handle = graph_pool_handle();
        assert!(capture_pool_for_handle(handle).is_some());
        release_graph_pool_handle(handle);
        release_graph_pool_handle(handle); // releasing twice must not panic
        assert!(capture_pool_for_handle(handle).is_none());
    }

    #[test]
    fn graph_pool_handle_unknown_id_returns_none() {
        // An id the registry never allocated resolves to nothing.
        assert!(capture_pool_for_handle(GraphPoolHandle(u64::MAX)).is_none());
    }
}
1128
1129// ---------------------------------------------------------------------------
1130// CL-454 — tests that don't need cudarc type info.
1131// ---------------------------------------------------------------------------
1132
#[cfg(all(test, not(feature = "cuda")))]
mod no_cuda_tests {
    use super::*;

    #[test]
    fn capture_mode_and_status_exist_without_cuda_feature() {
        // CaptureMode / CaptureStatus are feature-portable: they must
        // compile and behave sensibly even when the cuda feature is off,
        // so callers can write cfg-free code.
        let _default_mode = CaptureMode::default();
        assert!(CaptureStatus::Active.is_capturing());
        assert!(!CaptureStatus::None.is_capturing());
        assert!(CaptureStatus::Invalidated.is_invalidated());
    }

    #[test]
    fn graph_pool_handle_without_cuda_returns_sentinel() {
        let handle = graph_pool_handle();
        assert_eq!(handle.0, 0, "stub handle is always zero without cuda feature");
        assert!(capture_pool_for_handle(handle).is_none());
        // Releasing the sentinel must be a harmless no-op.
        release_graph_pool_handle(handle);
    }

    #[test]
    fn captured_graph_stub_num_replays_and_is_uploaded_are_zero() {
        let stub = CapturedGraph;
        assert!(!stub.is_uploaded());
        assert_eq!(stub.num_replays(), 0);
    }
}