// ferrotorch_gpu/graph.rs
//! CUDA graph capture and replay infrastructure.
//!
//! A CUDA graph records a sequence of GPU operations (kernel launches, memcpys)
//! and replays them as a single driver submission. This eliminates per-kernel
//! launch overhead (~70μs on WSL2, ~5μs on native Linux per call) by collapsing
//! hundreds of launches into one.
//!
//! # Usage
//!
//! ```ignore
//! use ferrotorch_gpu::graph::{DeviceScalar, begin_capture, end_capture};
//!
//! // Pre-allocate all buffers BEFORE capture
//! let mut out = alloc_zeros_f32(768, &device)?;
//!
//! // Parameters that change between replays go in DeviceScalar
//! let mut pos = DeviceScalar::new(device.stream(), 0u32)?;
//!
//! // Capture
//! begin_capture(device.stream())?;
//! gpu_add_into(&a, &b, &mut out, &device)?;  // recorded, not executed
//! let graph = end_capture(device.stream())?;
//!
//! // Replay loop
//! for i in 0..100 {
//!     pos.update(i as u32)?;  // memcpy before replay
//!     graph.launch()?;         // replay all captured ops
//! }
//! ```
30
#[cfg(feature = "cuda")]
use std::sync::Arc;
#[cfg(feature = "cuda")]
use std::sync::atomic::{AtomicU64, Ordering};

#[cfg(feature = "cuda")]
use cudarc::driver::{CudaSlice, CudaStream, DeviceRepr, ValidAsZeroBits};

use crate::error::{GpuError, GpuResult};
40
// ---------------------------------------------------------------------------
// CaptureMode — typed wrapper over cudarc's CUstreamCaptureMode
// ---------------------------------------------------------------------------

/// How CUDA graph capture serializes interactions with other threads.
/// Mirrors `cudaStreamCaptureMode`.
///
/// - `Global` — any CUDA API call from any thread that touches the
///   capturing stream (or from any other capturing thread) invalidates
///   the capture. Safest for debugging; matches PyTorch's default.
/// - `ThreadLocal` — only calls issued from the capturing thread can
///   invalidate the capture; other threads may freely use unrelated
///   streams. This is what ferrotorch-gpu has always used.
/// - `Relaxed` — the driver performs no cross-thread tracking at all.
///   Fastest, but the caller is fully responsible for making sure no
///   other thread interferes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum CaptureMode {
    /// Global serialization (`CU_STREAM_CAPTURE_MODE_GLOBAL`).
    Global,
    /// Thread-local serialization (`CU_STREAM_CAPTURE_MODE_THREAD_LOCAL`).
    /// This is the default in PyTorch's `cuda.graph` context.
    #[default]
    ThreadLocal,
    /// No cross-thread serialization
    /// (`CU_STREAM_CAPTURE_MODE_RELAXED`).
    Relaxed,
}

#[cfg(feature = "cuda")]
impl CaptureMode {
    /// Map this mode onto the raw cudarc driver enum.
    #[inline]
    pub fn to_cuda(self) -> cudarc::driver::sys::CUstreamCaptureMode {
        use cudarc::driver::sys::CUstreamCaptureMode as Raw;
        match self {
            Self::Global => Raw::CU_STREAM_CAPTURE_MODE_GLOBAL,
            Self::ThreadLocal => Raw::CU_STREAM_CAPTURE_MODE_THREAD_LOCAL,
            Self::Relaxed => Raw::CU_STREAM_CAPTURE_MODE_RELAXED,
        }
    }
}
86
// ---------------------------------------------------------------------------
// CaptureStatus — typed wrapper over cudarc's CUstreamCaptureStatus
// ---------------------------------------------------------------------------

/// The capture state of a CUDA stream. Matches `cudaStreamCaptureStatus`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CaptureStatus {
    /// No capture is in progress on the stream.
    None,
    /// A graph is actively being recorded on the stream.
    Active,
    /// Capture was invalidated (e.g., by a forbidden API call or a
    /// cross-stream dependency). The caller must call `end_capture`
    /// to discard the broken graph before doing anything else on the
    /// stream.
    Invalidated,
}

#[cfg(feature = "cuda")]
impl CaptureStatus {
    /// Translate the raw cudarc driver status into the typed wrapper.
    fn from_cuda(raw: cudarc::driver::sys::CUstreamCaptureStatus) -> Self {
        use cudarc::driver::sys::CUstreamCaptureStatus as Raw;
        match raw {
            Raw::CU_STREAM_CAPTURE_STATUS_NONE => Self::None,
            Raw::CU_STREAM_CAPTURE_STATUS_ACTIVE => Self::Active,
            Raw::CU_STREAM_CAPTURE_STATUS_INVALIDATED => Self::Invalidated,
        }
    }
}

impl CaptureStatus {
    /// `true` while a graph is actively being recorded on the stream.
    #[inline]
    pub fn is_capturing(&self) -> bool {
        *self == Self::Active
    }

    /// `true` once capture has been invalidated and must be ended.
    #[inline]
    pub fn is_invalidated(&self) -> bool {
        *self == Self::Invalidated
    }
}
130
// ---------------------------------------------------------------------------
// DeviceScalar — a single value in GPU memory, updatable before graph replay
// ---------------------------------------------------------------------------

/// One scalar value living in GPU device memory.
///
/// Designed for CUDA graph capture: the graph records the (fixed) device
/// pointer, and the host rewrites the value with
/// [`update`](DeviceScalar::update) before each [`CapturedGraph::launch`].
/// Each update is a 4-or-8 byte `cuMemcpyHtoDAsync` — effectively zero cost.
#[cfg(feature = "cuda")]
pub struct DeviceScalar<T: DeviceRepr + ValidAsZeroBits + Copy> {
    /// Single-element device allocation holding the value.
    value: CudaSlice<T>,
    /// Stream used for the initial upload and all subsequent updates.
    stream: Arc<CudaStream>,
}

#[cfg(feature = "cuda")]
impl<T: DeviceRepr + ValidAsZeroBits + Copy> DeviceScalar<T> {
    /// Allocate a device scalar seeded with `initial`.
    pub fn new(stream: &Arc<CudaStream>, initial: T) -> GpuResult<Self> {
        let value = stream.clone_htod(&[initial])?;
        let stream = Arc::clone(stream);
        Ok(Self { value, stream })
    }

    /// Overwrite the device value via an async H→D memcpy of
    /// `size_of::<T>()` bytes. Must be issued on the same stream as the
    /// graph so the write is ordered before the next replay.
    pub fn update(&mut self, value: T) -> GpuResult<()> {
        self.stream.memcpy_htod(&[value], &mut self.value)?;
        Ok(())
    }

    /// Borrow the underlying `CudaSlice` to pass as a kernel parameter.
    /// The graph captures this pointer address; later `update` calls
    /// change what the kernel reads without re-capturing.
    #[inline]
    pub fn inner(&self) -> &CudaSlice<T> {
        &self.value
    }
}
173
// ---------------------------------------------------------------------------
// CapturedGraph — a replayable CUDA graph
// ---------------------------------------------------------------------------

/// A captured and instantiated CUDA graph that can be replayed with
/// [`launch`](CapturedGraph::launch).
///
/// Created via [`begin_capture`] + GPU ops + [`end_capture`].
/// The graph holds references to all device memory used during capture.
/// Those buffers must remain allocated for the lifetime of the graph.
///
/// **Allocator pool integration (CL-278).** When created via
/// [`end_capture_with_pool`], the graph holds a strong reference to
/// the [`CapturePool`] that recorded its allocations. The pool keeps
/// every registered buffer alive until the last `CapturedGraph`
/// referencing it is dropped, which guarantees the device pointers
/// recorded in the graph remain valid across replays. Without the
/// pool, callers must manually keep buffers alive (the original
/// [`end_capture`] API).
#[cfg(feature = "cuda")]
pub struct CapturedGraph {
    /// The instantiated driver graph object.
    graph: cudarc::driver::CudaGraph,
    /// Optional reference to the pool that owns the graph's
    /// allocations. `Some(pool)` when constructed via
    /// [`end_capture_with_pool`]. Dropping the graph drops this
    /// Arc, which (if it's the last reference) drops every buffer
    /// the pool holds. CL-278.
    pool: Option<Arc<CapturePool>>,
    /// Monotonic counter bumped by every successful [`launch`]. Lets
    /// callers assert that a specific replay happened after some
    /// other work completed, useful for graph-aware profilers and
    /// integration tests. CL-454.
    replay_count: AtomicU64,
    /// True after the first successful [`upload`] so subsequent
    /// uploads become cheap no-ops. CL-454.
    uploaded: std::sync::atomic::AtomicBool,
}

#[cfg(feature = "cuda")]
impl CapturedGraph {
    /// Replay all operations captured in this graph.
    ///
    /// Before calling this, update any [`DeviceScalar`] values and perform
    /// any pre-launch memcpys (e.g., position embeddings). All updates must
    /// be on the same stream the graph was captured on.
    ///
    /// Bumps [`num_replays`](Self::num_replays) on success.
    pub fn launch(&self) -> GpuResult<()> {
        self.graph.launch()?;
        self.replay_count.fetch_add(1, Ordering::Relaxed);
        Ok(())
    }

    /// Pre-upload the graph's executable resources to the device.
    ///
    /// The first [`launch`](Self::launch) of a freshly instantiated graph
    /// pays a one-time cost for the driver to copy the exec into GPU
    /// memory. Calling `upload` up front shifts that cost out of the
    /// hot replay loop. Subsequent uploads are a no-op. CL-454.
    ///
    /// NOTE(review): two threads racing past the flag check may both
    /// reach `graph.upload()` before the flag is stored; assumed
    /// harmless (the second upload is redundant) — confirm if this is
    /// ever called concurrently.
    pub fn upload(&self) -> GpuResult<()> {
        if self.uploaded.load(Ordering::Acquire) {
            return Ok(());
        }
        self.graph.upload()?;
        // Only mark uploaded after the driver call succeeded, so a
        // failed upload can be retried.
        self.uploaded.store(true, Ordering::Release);
        Ok(())
    }

    /// Number of successful replays issued on this graph. CL-454.
    #[inline]
    pub fn num_replays(&self) -> u64 {
        self.replay_count.load(Ordering::Relaxed)
    }

    /// Returns `true` if [`upload`](Self::upload) has been called on
    /// this graph. CL-454.
    #[inline]
    pub fn is_uploaded(&self) -> bool {
        self.uploaded.load(Ordering::Acquire)
    }

    /// Number of buffers held alive by this graph's allocator pool.
    /// Returns 0 if the graph was created without a pool. CL-278.
    pub fn pool_buffer_count(&self) -> usize {
        // `map_or` collapses the map/unwrap_or two-step (clippy:
        // map_unwrap_or).
        self.pool.as_ref().map_or(0, |p| p.buffer_count())
    }

    /// True if this graph holds a CapturePool reference. CL-278.
    pub fn has_pool(&self) -> bool {
        self.pool.is_some()
    }

    /// Return the [`Arc<CapturePool>`] this graph is using, if any.
    /// Allows sharing the same pool between multiple graphs so they
    /// all keep the same buffers alive. CL-454.
    pub fn pool(&self) -> Option<&Arc<CapturePool>> {
        self.pool.as_ref()
    }
}
276
// ---------------------------------------------------------------------------
// Capture API
// ---------------------------------------------------------------------------

/// Begin CUDA graph capture on the given stream.
///
/// Every GPU operation (kernel launch, cuBLAS call, memcpy) issued on this
/// stream after this call is **recorded but not executed**. Call
/// [`end_capture`] to finalize and instantiate the graph.
///
/// # Requirements
///
/// - Pre-allocate all output buffers before capture begins.
/// - No `alloc_zeros` / `cpu_to_gpu` during capture (use `_into` variants).
/// - No CPU↔GPU synchronization during capture.
/// - Event tracking should be disabled during capture to avoid interference
///   (call `ctx.disable_event_tracking()` before, re-enable after).
#[cfg(feature = "cuda")]
pub fn begin_capture(stream: &Arc<CudaStream>) -> GpuResult<()> {
    let mode = CaptureMode::default();
    begin_capture_with_mode(stream, mode)
}
298
/// Begin CUDA graph capture using an explicit [`CaptureMode`]. CL-454.
///
/// [`begin_capture`] covers the common (`ThreadLocal`) case; reach for
/// this variant when you need `Global` (debugging / strict
/// serialization) or `Relaxed` (max throughput, single-thread
/// ownership).
#[cfg(feature = "cuda")]
pub fn begin_capture_with_mode(
    stream: &Arc<CudaStream>,
    mode: CaptureMode,
) -> GpuResult<()> {
    // `?` converts the cudarc driver error into GpuError.
    Ok(stream.begin_capture(mode.to_cuda())?)
}
312
/// Query the capture status of a CUDA stream. CL-454.
///
/// The ferrotorch-gpu equivalent of PyTorch's
/// `torch.cuda.is_current_stream_capturing`: lets callers skip
/// capture-invalid APIs (allocator calls, H↔D copies) while a graph is
/// being recorded.
#[cfg(feature = "cuda")]
pub fn capture_status(stream: &Arc<CudaStream>) -> GpuResult<CaptureStatus> {
    let status = stream.capture_status()?;
    Ok(CaptureStatus::from_cuda(status))
}
324
/// Shorthand for `capture_status(stream)?.is_capturing()`. CL-454.
#[cfg(feature = "cuda")]
pub fn is_stream_capturing(stream: &Arc<CudaStream>) -> GpuResult<bool> {
    capture_status(stream).map(|s| s.is_capturing())
}
330
/// End CUDA graph capture, instantiate, and return the replayable graph.
///
/// # Errors
///
/// Returns `Err` if capture was not active, if instantiation fails, or
/// (as [`GpuError::InvalidState`]) if the driver finishes the capture
/// without producing a graph.
///
/// The returned graph has no [`CapturePool`] attached. The caller is
/// responsible for keeping the buffers used by the captured kernels
/// alive for the graph's lifetime. Use [`end_capture_with_pool`]
/// for the lifetime-managed variant.
#[cfg(feature = "cuda")]
pub fn end_capture(stream: &Arc<CudaStream>) -> GpuResult<CapturedGraph> {
    // AUTO_FREE_ON_LAUNCH lets the driver reclaim the graph's internal
    // allocations after launch (same flag as before).
    let flags = cudarc::driver::sys::CUgraphInstantiate_flags_enum::CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH;
    // Fix: a null capture result was previously misreported as
    // `PtxCompileFailed`, which has nothing to do with graph capture.
    // `InvalidState` (already used elsewhere in this module) describes
    // the actual condition.
    let graph = stream
        .end_capture(flags)?
        .ok_or_else(|| GpuError::InvalidState {
            message: "CUDA graph capture returned null".into(),
        })?;
    Ok(CapturedGraph {
        graph,
        pool: None,
        replay_count: AtomicU64::new(0),
        uploaded: std::sync::atomic::AtomicBool::new(false),
    })
}
354
/// End CUDA graph capture and attach a [`CapturePool`] reference to
/// the resulting [`CapturedGraph`]. CL-278.
///
/// The pool's tracked buffers stay alive for the lifetime of the
/// returned graph: dropping the graph drops its `Arc<CapturePool>`,
/// and dropping the last such Arc drops every buffer the pool
/// recorded. This guarantees that the device pointers recorded in the
/// captured graph remain valid across replays.
///
/// Use this in concert with [`CapturePool::record_buffer`]: allocate
/// every buffer used during capture before calling `begin_capture`,
/// register each one with the pool, run the kernels under capture,
/// then call `end_capture_with_pool(stream, pool)` to seal the
/// lifetime relationship.
#[cfg(feature = "cuda")]
pub fn end_capture_with_pool(
    stream: &Arc<CudaStream>,
    pool: Arc<CapturePool>,
) -> GpuResult<CapturedGraph> {
    end_capture(stream).map(|mut graph| {
        graph.pool = Some(pool);
        graph
    })
}
378
// ---------------------------------------------------------------------------
// GraphCaptureGuard — RAII wrapper that ends capture on drop
// ---------------------------------------------------------------------------

/// RAII guard that scopes a CUDA graph capture.
///
/// Start capture with [`GraphCaptureGuard::begin`] (or
/// [`begin_with_mode`] / [`begin_with_pool`]); calling [`finish`]
/// returns the instantiated graph. If the guard is dropped without
/// `finish` (for example because a kernel returned an error
/// mid-capture), its `Drop` impl best-effort-ends capture and discards
/// the resulting graph so the stream returns to a usable state. CL-454.
///
/// This mirrors PyTorch's `with torch.cuda.graph(g): ...` context
/// manager semantics in Rust's RAII idiom.
///
/// # Example
///
/// ```ignore
/// use ferrotorch_gpu::graph::GraphCaptureGuard;
///
/// let mut guard = GraphCaptureGuard::begin(device.stream())?;
/// run_kernels()?; // any kernel launched on device.stream() is recorded
/// let graph = guard.finish()?;
/// graph.upload()?;
/// for _ in 0..1000 { graph.launch()?; }
/// ```
#[cfg(feature = "cuda")]
pub struct GraphCaptureGuard {
    stream: Arc<CudaStream>,
    /// Pool to attach to the graph when `finish` runs.
    pool: Option<Arc<CapturePool>>,
    /// Cleared by [`finish`]; tells `Drop` whether capture still needs
    /// tearing down.
    active: bool,
}

#[cfg(feature = "cuda")]
impl GraphCaptureGuard {
    /// Begin graph capture on `stream` in the default
    /// [`CaptureMode::ThreadLocal`] mode. CL-454.
    pub fn begin(stream: &Arc<CudaStream>) -> GpuResult<Self> {
        Self::begin_with_mode(stream, CaptureMode::default())
    }

    /// Begin graph capture with an explicit [`CaptureMode`]. CL-454.
    pub fn begin_with_mode(
        stream: &Arc<CudaStream>,
        mode: CaptureMode,
    ) -> GpuResult<Self> {
        begin_capture_with_mode(stream, mode)?;
        Ok(Self::new_active(stream, None))
    }

    /// Begin graph capture bound to a [`CapturePool`]. The pool is
    /// attached to the resulting graph by [`finish`]. CL-454.
    pub fn begin_with_pool(
        stream: &Arc<CudaStream>,
        pool: Arc<CapturePool>,
    ) -> GpuResult<Self> {
        begin_capture_with_pool(&pool, stream)?;
        Ok(Self::new_active(stream, Some(pool)))
    }

    /// Shared constructor for the `begin_*` entry points.
    fn new_active(stream: &Arc<CudaStream>, pool: Option<Arc<CapturePool>>) -> Self {
        Self {
            stream: Arc::clone(stream),
            pool,
            active: true,
        }
    }

    /// Finish capture and return the instantiated [`CapturedGraph`].
    ///
    /// Consumes the guard so `Drop` becomes a no-op. If a pool was
    /// attached at construction, the resulting graph is produced via
    /// [`end_capture_with_pool`] and holds the pool Arc for the
    /// lifetime of the graph.
    pub fn finish(mut self) -> GpuResult<CapturedGraph> {
        self.active = false;
        match self.pool.take() {
            Some(pool) => end_capture_with_pool(&self.stream, pool),
            None => end_capture(&self.stream),
        }
    }

    /// Report whether the stream this guard is bound to is still
    /// actively capturing. An unexpected `Invalidated` or `None`
    /// usually means a forbidden API call (alloc, sync, host copy)
    /// happened under capture.
    pub fn status(&self) -> GpuResult<CaptureStatus> {
        capture_status(&self.stream)
    }
}
475
#[cfg(feature = "cuda")]
impl Drop for GraphCaptureGuard {
    fn drop(&mut self) {
        // Only tear down when `finish` never ran. Errors are ignored —
        // we're in Drop and the discarded CapturedGraph result is
        // dropped immediately; the goal is just to leave the stream
        // usable again.
        if self.active {
            let _ = end_capture(&self.stream);
        }
    }
}
488
// ---------------------------------------------------------------------------
// Graph pool handle registry — share a CapturePool across multiple graphs
// ---------------------------------------------------------------------------

/// Opaque handle for a pool registered with the process-wide graph
/// pool registry. Used to share the same buffer-lifetime pool across
/// multiple captured graphs without passing `Arc<CapturePool>` around
/// by hand. CL-454.
///
/// The wrapped `u64` is the registry key allocated by
/// [`graph_pool_handle`]; ids are handed out by a monotonically
/// increasing counter, so a released id is never reused within a
/// process.
///
/// Mirrors PyTorch's `torch.cuda.graph_pool_handle()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct GraphPoolHandle(pub u64);
501
#[cfg(feature = "cuda")]
static NEXT_POOL_HANDLE: AtomicU64 = AtomicU64::new(1);

/// Map from handle id to the pool registered under it.
#[cfg(feature = "cuda")]
type PoolMap = std::collections::HashMap<u64, Arc<CapturePool>>;

#[cfg(feature = "cuda")]
static POOL_REGISTRY: std::sync::OnceLock<std::sync::Mutex<PoolMap>> =
    std::sync::OnceLock::new();

/// Lazily initialize and return the process-wide pool registry.
#[cfg(feature = "cuda")]
fn pool_registry() -> &'static std::sync::Mutex<PoolMap> {
    POOL_REGISTRY.get_or_init(|| std::sync::Mutex::new(PoolMap::new()))
}
514
/// Allocate a fresh [`GraphPoolHandle`] and register a new
/// [`CapturePool`] under it in the process-wide registry. CL-454.
///
/// Pass the handle to [`capture_pool_for_handle`] later to retrieve
/// the same `Arc<CapturePool>` from any thread, which lets two
/// independently captured graphs share the same buffer-keeping pool.
#[cfg(feature = "cuda")]
pub fn graph_pool_handle() -> GraphPoolHandle {
    let id = NEXT_POOL_HANDLE.fetch_add(1, Ordering::Relaxed);
    // Recover from a poisoned lock — the registry map itself is always
    // in a consistent state.
    pool_registry()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .insert(id, Arc::new(CapturePool::new()));
    GraphPoolHandle(id)
}
532
/// Look up the [`CapturePool`] registered under `handle` and return a
/// strong `Arc` to it. Returns `None` if the handle was never
/// allocated or has been released via [`release_graph_pool_handle`].
/// CL-454.
#[cfg(feature = "cuda")]
pub fn capture_pool_for_handle(handle: GraphPoolHandle) -> Option<Arc<CapturePool>> {
    pool_registry()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .get(&handle.0)
        .cloned()
}
544
/// Drop the registry's strong reference to the pool behind `handle`.
/// Any [`CapturedGraph`] that holds its own Arc (for example via
/// [`end_capture_with_pool`]) keeps the pool alive until that graph
/// is dropped too. CL-454.
#[cfg(feature = "cuda")]
pub fn release_graph_pool_handle(handle: GraphPoolHandle) {
    pool_registry()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .remove(&handle.0);
}
556
// ---------------------------------------------------------------------------
// make_graphed_callable — scoped capture over a closure
// ---------------------------------------------------------------------------

/// Record the GPU work performed by `f` into a CUDA graph and return
/// the replayable graph. CL-454.
///
/// The ferrotorch-gpu equivalent of PyTorch's
/// `torch.cuda.make_graphed_callables` for the simple single-callable
/// case: the caller supplies a closure that runs all the GPU work to
/// capture, and the returned [`CapturedGraph`] can be replayed over
/// and over. The closure runs exactly once, under capture, so all
/// per-call work (allocations, dtype decisions) that isn't valid
/// under capture must happen outside.
///
/// If the closure returns an error, capture is discarded and the
/// error is propagated.
#[cfg(feature = "cuda")]
pub fn make_graphed_callable<F>(
    stream: &Arc<CudaStream>,
    mode: CaptureMode,
    f: F,
) -> GpuResult<CapturedGraph>
where
    F: FnOnce() -> GpuResult<()>,
{
    let guard = GraphCaptureGuard::begin_with_mode(stream, mode)?;
    // On error, `?` returns early and `guard`'s Drop ends the capture,
    // discarding the broken graph.
    f()?;
    guard.finish()
}
593
// ---------------------------------------------------------------------------
// CapturePool — memory pool for graph capture
// ---------------------------------------------------------------------------

/// A dedicated memory pool for CUDA graph capture.
///
/// Two responsibilities:
///
/// 1. **Sealed flag** — gates [`begin_capture_with_pool`] so the
///    caller can express "no more allocations after this point"
///    semantically. Sealed pools cannot satisfy new allocations
///    during capture.
///
/// 2. **Buffer lifetime tracking (CL-278)** — registered buffers
///    are kept alive by the pool itself, so they outlive any
///    [`CapturedGraph`] that holds an `Arc<CapturePool>`. Dropping
///    the graph drops the Arc, and dropping the last Arc drops
///    every registered buffer in registration order.
///
/// # Usage
///
/// ```ignore
/// use std::sync::Arc;
/// let pool = Arc::new(CapturePool::new());
///
/// // Allocate every buffer the captured kernels will read or
/// // write, and register each one with the pool so it stays alive
/// // for the graph's lifetime.
/// let mut buf_a = alloc_zeros_f32(1024, &device)?;
/// let mut buf_b = alloc_zeros_f32(1024, &device)?;
/// pool.record_buffer(buf_a.try_clone()?);
/// pool.record_buffer(buf_b.try_clone()?);
///
/// pool.seal();
/// begin_capture_with_pool(&pool, stream)?;
/// // ... launch kernels using buf_a and buf_b ...
/// let graph = end_capture_with_pool(stream, Arc::clone(&pool))?;
/// // Dropping `pool` here is safe — the graph holds its own Arc.
/// ```
#[cfg(feature = "cuda")]
pub struct CapturePool {
    /// "No more allocations" marker checked by `begin_capture_with_pool`.
    sealed: std::sync::atomic::AtomicBool,
    /// Registered buffers (type-erased) kept alive for the graph's
    /// lifetime. Each entry is a `Box<dyn Any + Send + Sync>` wrapping
    /// the buffer's drop guard. CL-278.
    buffers: std::sync::Mutex<Vec<Box<dyn std::any::Any + Send + Sync + 'static>>>,
}

#[cfg(feature = "cuda")]
impl CapturePool {
    /// Create a new, unsealed capture pool.
    pub fn new() -> Self {
        Self {
            sealed: std::sync::atomic::AtomicBool::new(false),
            buffers: std::sync::Mutex::new(Vec::new()),
        }
    }

    /// Seal the pool, preventing any further allocations.
    pub fn seal(&self) {
        self.sealed
            .store(true, std::sync::atomic::Ordering::Release);
    }

    /// Unseal the pool, allowing allocations again.
    pub fn unseal(&self) {
        self.sealed
            .store(false, std::sync::atomic::Ordering::Release);
    }

    /// Check whether the pool is sealed.
    pub fn is_capture_pool_sealed(&self) -> bool {
        self.sealed.load(std::sync::atomic::Ordering::Acquire)
    }

    /// Register a buffer with the pool so it stays alive for the
    /// lifetime of any [`CapturedGraph`] that holds this pool.
    /// CL-278.
    ///
    /// `buffer` can be any type that owns GPU memory (typically
    /// `CudaBuffer<f32>`, `CudaBuffer<f64>`, or `Arc<CudaBuffer<T>>`).
    /// The pool stores it in a type-erased `Box<dyn Any + Send +
    /// Sync>` and drops it (in registration order) when the pool
    /// itself is dropped.
    ///
    /// Returns the index of the registered buffer for diagnostic
    /// purposes.
    pub fn record_buffer<B>(&self, buffer: B) -> usize
    where
        B: Send + Sync + 'static,
    {
        let mut guard = self
            .buffers
            .lock()
            .unwrap_or_else(|p| p.into_inner());
        let idx = guard.len();
        guard.push(Box::new(buffer));
        idx
    }

    /// Number of buffers currently registered with the pool. CL-278.
    pub fn buffer_count(&self) -> usize {
        // Fix: recover from a poisoned mutex like every other accessor
        // on this type (`record_buffer`, `clear_buffers`) instead of
        // silently reporting 0, which under-counted live buffers after
        // a panic in a thread holding the lock.
        self.buffers
            .lock()
            .unwrap_or_else(|p| p.into_inner())
            .len()
    }

    /// Drop every registered buffer immediately, in registration
    /// order. The pool itself remains usable; new buffers can still
    /// be registered after this call. CL-278.
    ///
    /// Use this when reusing a pool across multiple capture cycles.
    /// Calling clear while a [`CapturedGraph`] still holds an Arc
    /// to this pool is safe — the graph's strong reference keeps
    /// the pool struct alive, but the buffer slots are reset.
    pub fn clear_buffers(&self) {
        let mut guard = self
            .buffers
            .lock()
            .unwrap_or_else(|p| p.into_inner());
        guard.clear();
    }
}

#[cfg(feature = "cuda")]
impl Default for CapturePool {
    fn default() -> Self {
        Self::new()
    }
}
725
/// Begin CUDA graph capture with a capture pool.
///
/// Like [`begin_capture`], but checks that the capture pool is not sealed
/// before starting capture. A sealed pool cannot satisfy allocations
/// during graph capture, which would cause CUDA errors.
///
/// # Errors
///
/// Returns [`GpuError::InvalidState`] if the pool is sealed (the doc
/// previously said `InvalidArgument`, which did not match the code).
/// Returns a CUDA driver error if `begin_capture` fails.
#[cfg(feature = "cuda")]
pub fn begin_capture_with_pool(pool: &CapturePool, stream: &Arc<CudaStream>) -> GpuResult<()> {
    if pool.is_capture_pool_sealed() {
        return Err(GpuError::InvalidState {
            message: "cannot begin graph capture: capture pool is sealed".into(),
        });
    }
    begin_capture(stream)
}
745
/// Stub CapturePool when cuda feature is disabled. Provides the
/// same surface API as the cuda-enabled type so callers compile on
/// both feature configurations.
#[cfg(not(feature = "cuda"))]
pub struct CapturePool;

#[cfg(not(feature = "cuda"))]
impl CapturePool {
    /// Create an empty CapturePool. Without the cuda feature the
    /// pool has no internal state to initialize.
    pub fn new() -> Self {
        Self
    }

    /// No-op without the cuda feature: there is no real CUDA pool
    /// to seal because no real allocations can happen.
    pub fn seal(&self) {
        // Without the cuda feature there is no allocator state to
        // mutate; the CapturePool exists only so callers can write
        // feature-portable code.
    }

    /// No-op without the cuda feature: there is no real CUDA pool
    /// to unseal because no real allocations can happen.
    pub fn unseal(&self) {
        // Without the cuda feature there is no allocator state to
        // mutate; the CapturePool exists only so callers can write
        // feature-portable code.
    }

    /// Always returns `false` without the cuda feature since there
    /// is no real pool that could be in either state.
    pub fn is_capture_pool_sealed(&self) -> bool {
        false
    }

    /// Always returns 0 without the cuda feature since no real
    /// allocations can be tracked. CL-278.
    pub fn buffer_count(&self) -> usize {
        0
    }

    /// No-op stub added for API parity with the cuda-enabled
    /// `record_buffer` (the stub previously lacked it despite the
    /// "same surface API" contract). Always returns index 0. CL-278.
    pub fn record_buffer<B>(&self, _buffer: B) -> usize
    where
        B: Send + Sync + 'static,
    {
        0
    }

    /// No-op stub added for API parity with the cuda-enabled
    /// `clear_buffers`. CL-278.
    pub fn clear_buffers(&self) {}
}

#[cfg(not(feature = "cuda"))]
impl Default for CapturePool {
    fn default() -> Self {
        Self::new()
    }
}
795
/// Stub begin_capture_with_pool when cuda feature is disabled.
///
/// # Errors
///
/// Always returns `GpuError::NoCudaFeature`: starting a graph capture
/// requires a real CUDA driver.
#[cfg(not(feature = "cuda"))]
pub fn begin_capture_with_pool<T>(_pool: &CapturePool, _stream: &T) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
801
// ---------------------------------------------------------------------------
// Stubs when cuda feature is disabled
// ---------------------------------------------------------------------------

/// Stub DeviceScalar.
///
/// Carries only a `PhantomData<T>` so the generic signature matches the
/// cuda-enabled type. NOTE(review): unlike the real type, this stub
/// exposes no `new`/`update`/`inner` — confirm non-cuda callers only
/// name the type, never its methods.
#[cfg(not(feature = "cuda"))]
pub struct DeviceScalar<T: Copy> {
    // Zero-sized marker tying the otherwise-unused type parameter to
    // the struct.
    _phantom: std::marker::PhantomData<T>,
}
811
/// Stub CapturedGraph.
#[cfg(not(feature = "cuda"))]
pub struct CapturedGraph;

#[cfg(not(feature = "cuda"))]
impl CapturedGraph {
    /// Stub launch — replaying a graph requires the cuda feature, so
    /// this always fails with `NoCudaFeature`.
    pub fn launch(&self) -> GpuResult<()> {
        Err(GpuError::NoCudaFeature)
    }

    /// Stub upload for CL-454.
    pub fn upload(&self) -> GpuResult<()> {
        Err(GpuError::NoCudaFeature)
    }

    /// Stub num_replays — always 0 without the cuda feature. CL-454.
    pub fn num_replays(&self) -> u64 {
        0
    }

    /// Stub is_uploaded — always false without the cuda feature. CL-454.
    pub fn is_uploaded(&self) -> bool {
        false
    }
}
837
838#[cfg(not(feature = "cuda"))]
839pub fn begin_capture<T>(_stream: &T) -> GpuResult<()> {
840    Err(GpuError::NoCudaFeature)
841}
842
843/// Stub `begin_capture_with_mode` when the cuda feature is not enabled.
844/// CL-454.
845#[cfg(not(feature = "cuda"))]
846pub fn begin_capture_with_mode<T>(_stream: &T, _mode: CaptureMode) -> GpuResult<()> {
847    Err(GpuError::NoCudaFeature)
848}
849
850/// Stub `capture_status` when the cuda feature is not enabled. CL-454.
851#[cfg(not(feature = "cuda"))]
852pub fn capture_status<T>(_stream: &T) -> GpuResult<CaptureStatus> {
853    Err(GpuError::NoCudaFeature)
854}
855
856/// Stub `is_stream_capturing` when the cuda feature is not enabled.
857/// CL-454.
858#[cfg(not(feature = "cuda"))]
859pub fn is_stream_capturing<T>(_stream: &T) -> GpuResult<bool> {
860    Err(GpuError::NoCudaFeature)
861}
862
863#[cfg(not(feature = "cuda"))]
864pub fn end_capture<T>(_stream: &T) -> GpuResult<CapturedGraph> {
865    Err(GpuError::NoCudaFeature)
866}
867
868/// Stub `end_capture_with_pool` when the cuda feature is not enabled.
869/// CL-278.
870#[cfg(not(feature = "cuda"))]
871pub fn end_capture_with_pool<T>(
872    _stream: &T,
873    _pool: std::sync::Arc<CapturePool>,
874) -> GpuResult<CapturedGraph> {
875    Err(GpuError::NoCudaFeature)
876}
877
/// Stub `GraphCaptureGuard` when the cuda feature is not enabled. CL-454.
#[cfg(not(feature = "cuda"))]
pub struct GraphCaptureGuard {
    // `Infallible` has no values, so this struct can never be
    // constructed; methods that take `self` are statically unreachable.
    _never: core::convert::Infallible,
}
883
884#[cfg(not(feature = "cuda"))]
885impl GraphCaptureGuard {
886    pub fn begin<T>(_stream: &T) -> GpuResult<Self> {
887        Err(GpuError::NoCudaFeature)
888    }
889
890    pub fn begin_with_mode<T>(_stream: &T, _mode: CaptureMode) -> GpuResult<Self> {
891        Err(GpuError::NoCudaFeature)
892    }
893
894    pub fn begin_with_pool<T>(
895        _stream: &T,
896        _pool: std::sync::Arc<CapturePool>,
897    ) -> GpuResult<Self> {
898        Err(GpuError::NoCudaFeature)
899    }
900
901    pub fn finish(self) -> GpuResult<CapturedGraph> {
902        match self._never {}
903    }
904
905    pub fn status(&self) -> GpuResult<CaptureStatus> {
906        match self._never {}
907    }
908}
909
910/// Stub `graph_pool_handle` when the cuda feature is not enabled. CL-454.
911#[cfg(not(feature = "cuda"))]
912pub fn graph_pool_handle() -> GraphPoolHandle {
913    GraphPoolHandle(0)
914}
915
916/// Stub `capture_pool_for_handle` when the cuda feature is not enabled.
917/// CL-454.
918#[cfg(not(feature = "cuda"))]
919pub fn capture_pool_for_handle(_handle: GraphPoolHandle) -> Option<std::sync::Arc<CapturePool>> {
920    None
921}
922
923/// Stub `release_graph_pool_handle` when the cuda feature is not enabled.
924/// CL-454.
925#[cfg(not(feature = "cuda"))]
926pub fn release_graph_pool_handle(_handle: GraphPoolHandle) {
927    // nothing to release
928}
929
930/// Stub `make_graphed_callable` when the cuda feature is not enabled.
931/// CL-454.
932#[cfg(not(feature = "cuda"))]
933pub fn make_graphed_callable<T, F>(
934    _stream: &T,
935    _mode: CaptureMode,
936    _f: F,
937) -> GpuResult<CapturedGraph>
938where
939    F: FnOnce() -> GpuResult<()>,
940{
941    Err(GpuError::NoCudaFeature)
942}
943
// ---------------------------------------------------------------------------
// Tests — CL-278 capture pool buffer tracking and CL-454 typed wrappers /
// pool-handle registry
// ---------------------------------------------------------------------------
947
#[cfg(all(test, feature = "cuda"))]
mod tests {
    use super::*;

    /// A freshly constructed pool has recorded no buffers.
    #[test]
    fn capture_pool_buffer_count_starts_at_zero() {
        let pool = CapturePool::new();
        assert_eq!(pool.buffer_count(), 0);
    }

    /// `record_buffer` returns a monotonically increasing index and
    /// bumps the count; buffers of different element types coexist.
    #[test]
    fn capture_pool_record_buffer_increments_count() {
        let pool = CapturePool::new();
        let buf_a: Vec<f32> = vec![0.0; 10];
        let idx = pool.record_buffer(buf_a);
        assert_eq!(idx, 0);
        assert_eq!(pool.buffer_count(), 1);

        let buf_b: Vec<f64> = vec![0.0; 5];
        let idx = pool.record_buffer(buf_b);
        assert_eq!(idx, 1);
        assert_eq!(pool.buffer_count(), 2);
    }

    /// Clearing drops the recorded buffers but leaves the pool usable.
    #[test]
    fn capture_pool_clear_buffers_resets_count_but_keeps_pool() {
        let pool = CapturePool::new();
        pool.record_buffer(vec![0u8; 16]);
        pool.record_buffer(vec![0u8; 32]);
        assert_eq!(pool.buffer_count(), 2);
        pool.clear_buffers();
        assert_eq!(pool.buffer_count(), 0);
        // Pool is still usable.
        pool.record_buffer(vec![0u8; 8]);
        assert_eq!(pool.buffer_count(), 1);
    }

    /// Dropping the pool must release everything it recorded.
    #[test]
    fn capture_pool_drop_releases_registered_buffers() {
        // Use Arc to detect when the inner buffer is dropped.
        let buf = Arc::new(vec![1.0f32, 2.0, 3.0]);
        let pool = CapturePool::new();
        pool.record_buffer(Arc::clone(&buf));
        assert_eq!(Arc::strong_count(&buf), 2);
        drop(pool);
        // Pool dropped → recorded Arc dropped → strong count back to 1.
        assert_eq!(Arc::strong_count(&buf), 1);
    }

    /// The pool is type-erased: values of unrelated types can all be
    /// kept alive by the same pool.
    #[test]
    fn capture_pool_records_heterogeneous_types() {
        let pool = CapturePool::new();
        pool.record_buffer(vec![0.0f32; 4]);
        pool.record_buffer(vec![0.0f64; 4]);
        pool.record_buffer(vec![0u8; 4]);
        pool.record_buffer(Arc::new(42i32));
        assert_eq!(pool.buffer_count(), 4);
    }

    /// Seal then unseal must round-trip the flag reported by
    /// `is_capture_pool_sealed`.
    #[test]
    fn capture_pool_seal_unseal() {
        let pool = CapturePool::new();
        assert!(!pool.is_capture_pool_sealed());
        pool.seal();
        assert!(pool.is_capture_pool_sealed());
        pool.unseal();
        assert!(!pool.is_capture_pool_sealed());
    }

    // -----------------------------------------------------------------------
    // CL-454 — CaptureMode / CaptureStatus / graph pool handle tests.
    //
    // These tests exercise the typed wrappers and the process-wide
    // pool-handle registry without requiring a real CUDA device.
    // Tests that actually touch a device live under the
    // `feature = "cuda-live"` gate (run them with
    //   cargo test -p ferrotorch-gpu --features cuda,cuda-live
    // on a machine with a functioning CUDA driver).
    // -----------------------------------------------------------------------

    #[test]
    fn capture_mode_default_is_thread_local() {
        assert_eq!(CaptureMode::default(), CaptureMode::ThreadLocal);
    }

    /// Each typed mode maps onto the matching raw cudarc enum variant.
    #[test]
    fn capture_mode_to_cuda_round_trip() {
        use cudarc::driver::sys::CUstreamCaptureMode::*;
        assert_eq!(CaptureMode::Global.to_cuda(), CU_STREAM_CAPTURE_MODE_GLOBAL);
        assert_eq!(
            CaptureMode::ThreadLocal.to_cuda(),
            CU_STREAM_CAPTURE_MODE_THREAD_LOCAL
        );
        assert_eq!(
            CaptureMode::Relaxed.to_cuda(),
            CU_STREAM_CAPTURE_MODE_RELAXED
        );
    }

    /// `is_capturing` is true only for the `Active` status.
    #[test]
    fn capture_status_is_capturing_only_when_active() {
        assert!(!CaptureStatus::None.is_capturing());
        assert!(CaptureStatus::Active.is_capturing());
        assert!(!CaptureStatus::Invalidated.is_capturing());
    }

    /// `is_invalidated` is true only for the `Invalidated` status.
    #[test]
    fn capture_status_is_invalidated_only_when_broken() {
        assert!(!CaptureStatus::None.is_invalidated());
        assert!(!CaptureStatus::Active.is_invalidated());
        assert!(CaptureStatus::Invalidated.is_invalidated());
    }

    /// All three raw driver statuses convert to the typed wrapper.
    #[test]
    fn capture_status_from_cuda_maps_all_variants() {
        use cudarc::driver::sys::CUstreamCaptureStatus::*;
        assert_eq!(
            CaptureStatus::from_cuda(CU_STREAM_CAPTURE_STATUS_NONE),
            CaptureStatus::None
        );
        assert_eq!(
            CaptureStatus::from_cuda(CU_STREAM_CAPTURE_STATUS_ACTIVE),
            CaptureStatus::Active
        );
        assert_eq!(
            CaptureStatus::from_cuda(CU_STREAM_CAPTURE_STATUS_INVALIDATED),
            CaptureStatus::Invalidated
        );
    }

    /// Handles from the process-wide registry never repeat.
    #[test]
    fn graph_pool_handle_allocates_unique_ids() {
        let h1 = graph_pool_handle();
        let h2 = graph_pool_handle();
        assert_ne!(h1, h2, "each call should return a fresh id");
        // Both handles should map back to a real pool.
        assert!(capture_pool_for_handle(h1).is_some());
        assert!(capture_pool_for_handle(h2).is_some());
        release_graph_pool_handle(h1);
        release_graph_pool_handle(h2);
    }

    /// Repeated lookups of one handle must share a single pool, and a
    /// released handle stops resolving while live Arcs keep the pool
    /// (and its recorded buffers) alive.
    #[test]
    fn graph_pool_handle_shares_single_pool_across_lookups() {
        let h = graph_pool_handle();
        let a = capture_pool_for_handle(h).expect("handle registered");
        let b = capture_pool_for_handle(h).expect("handle still registered");
        assert!(
            Arc::ptr_eq(&a, &b),
            "both lookups should return the same pool Arc"
        );

        // Register a buffer through one lookup; the other should see it.
        a.record_buffer(vec![1.0f32, 2.0]);
        assert_eq!(b.buffer_count(), 1);

        release_graph_pool_handle(h);
        // After release the registry no longer hands out the pool, but
        // existing Arcs keep it alive.
        assert!(capture_pool_for_handle(h).is_none());
        // The existing Arc still has its buffer.
        assert_eq!(a.buffer_count(), 1);
    }

    /// Releasing the same handle twice must not panic or corrupt the
    /// registry.
    #[test]
    fn graph_pool_handle_release_is_idempotent() {
        let h = graph_pool_handle();
        assert!(capture_pool_for_handle(h).is_some());
        release_graph_pool_handle(h);
        release_graph_pool_handle(h); // second call is fine
        assert!(capture_pool_for_handle(h).is_none());
    }

    /// Looking up an id that was never handed out must miss.
    #[test]
    fn graph_pool_handle_unknown_id_returns_none() {
        // A fresh handle ID that was never registered.
        let fake = GraphPoolHandle(u64::MAX);
        assert!(capture_pool_for_handle(fake).is_none());
    }
}
1128
1129// ---------------------------------------------------------------------------
1130// CL-454 — tests that don't need cudarc type info.
1131// ---------------------------------------------------------------------------
1132
#[cfg(all(test, not(feature = "cuda")))]
mod no_cuda_tests {
    use super::*;

    /// The feature-portable wrappers must compile and behave sensibly
    /// with the cuda feature off, so callers can stay cfg-free.
    #[test]
    fn capture_mode_and_status_exist_without_cuda_feature() {
        let _ = CaptureMode::default();
        assert!(!CaptureStatus::None.is_capturing());
        assert!(CaptureStatus::Active.is_capturing());
        assert!(CaptureStatus::Invalidated.is_invalidated());
    }

    /// Without the cuda feature the pool-handle registry is a stub:
    /// every handle is the zero sentinel and never resolves to a pool.
    #[test]
    fn graph_pool_handle_without_cuda_returns_sentinel() {
        let handle = graph_pool_handle();
        assert_eq!(handle.0, 0, "stub handle is always zero without cuda feature");
        assert!(capture_pool_for_handle(handle).is_none());
        release_graph_pool_handle(handle); // must be a harmless no-op
    }

    /// The stub graph reports a pristine state: never replayed and
    /// never uploaded.
    #[test]
    fn captured_graph_stub_num_replays_and_is_uploaded_are_zero() {
        let graph = CapturedGraph;
        assert_eq!(graph.num_replays(), 0);
        assert!(!graph.is_uploaded());
    }
}