Skip to main content

ferrotorch_gpu/
graph.rs

1//! CUDA graph capture and replay infrastructure.
2//!
3//! A CUDA graph records a sequence of GPU operations (kernel launches, memcpys)
4//! and replays them as a single driver submission. This eliminates per-kernel
5//! launch overhead (~70μs on WSL2, ~5μs on native Linux per call) by collapsing
6//! hundreds of launches into one.
7//!
8//! # Usage
9//!
10//! ```ignore
11//! use ferrotorch_gpu::graph::{DeviceScalar, begin_capture, end_capture};
12//!
13//! // Pre-allocate all buffers BEFORE capture
14//! let mut out = alloc_zeros_f32(768, &device)?;
15//!
16//! // Parameters that change between replays go in DeviceScalar
17//! let mut pos = DeviceScalar::new(device.stream(), 0u32)?;
18//!
19//! // Capture
20//! begin_capture(device.stream())?;
21//! gpu_add_into(&a, &b, &mut out, &device)?;  // recorded, not executed
22//! let graph = end_capture(device.stream())?;
23//!
24//! // Replay loop
25//! for i in 0..100 {
26//!     pos.update(i as u32)?;  // memcpy before replay
27//!     graph.launch()?;         // replay all captured ops
28//! }
29//! ```
30
#[cfg(feature = "cuda")]
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
#[cfg(feature = "cuda")]
use std::sync::{Arc, Mutex};

#[cfg(feature = "cuda")]
use cudarc::driver::{CudaSlice, CudaStream, DeviceRepr, ValidAsZeroBits};

use crate::error::{GpuError, GpuResult};
40
41// ---------------------------------------------------------------------------
42// CaptureMode — typed wrapper over cudarc's CUstreamCaptureMode
43// ---------------------------------------------------------------------------
44
/// How CUDA graph capture serializes interactions with other threads.
/// Typed mirror of `cudaStreamCaptureMode`.
///
/// - `Global` — any CUDA API call from any thread that touches the
///   capturing stream (or from any other capturing thread) invalidates
///   the capture. Safest for debugging; PyTorch's default.
/// - `ThreadLocal` — only calls issued by the capturing thread can
///   invalidate the capture; other threads may freely use unrelated
///   streams. This is what ferrotorch-gpu has always used.
/// - `Relaxed` — the driver performs no cross-thread tracking at all.
///   Fastest, but the caller must guarantee no other thread interferes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CaptureMode {
    /// Global serialization (`CU_STREAM_CAPTURE_MODE_GLOBAL`).
    Global,
    /// Thread-local serialization (`CU_STREAM_CAPTURE_MODE_THREAD_LOCAL`).
    /// This is the default in PyTorch's `cuda.graph` context.
    ThreadLocal,
    /// Relaxed — no cross-thread serialization
    /// (`CU_STREAM_CAPTURE_MODE_RELAXED`).
    Relaxed,
}

impl Default for CaptureMode {
    /// `ThreadLocal`: the mode ferrotorch-gpu has historically used,
    /// matching PyTorch's
    /// `torch.cuda.graph(stream=..., capture_error_mode="thread_local")`.
    fn default() -> Self {
        CaptureMode::ThreadLocal
    }
}

#[cfg(feature = "cuda")]
impl CaptureMode {
    /// Translate this mode into the raw `CUstreamCaptureMode` value the
    /// cudarc driver API expects.
    #[inline]
    pub fn to_cuda(self) -> cudarc::driver::sys::CUstreamCaptureMode {
        type Raw = cudarc::driver::sys::CUstreamCaptureMode;
        match self {
            CaptureMode::Global => Raw::CU_STREAM_CAPTURE_MODE_GLOBAL,
            CaptureMode::ThreadLocal => Raw::CU_STREAM_CAPTURE_MODE_THREAD_LOCAL,
            CaptureMode::Relaxed => Raw::CU_STREAM_CAPTURE_MODE_RELAXED,
        }
    }
}
91
92// ---------------------------------------------------------------------------
93// CaptureStatus — typed wrapper over cudarc's CUstreamCaptureStatus
94// ---------------------------------------------------------------------------
95
/// Capture state of a CUDA stream. Typed mirror of
/// `cudaStreamCaptureStatus`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CaptureStatus {
    /// No graph capture is in progress on the stream.
    None,
    /// A graph capture is actively recording on the stream.
    Active,
    /// The capture was invalidated (forbidden API call, cross-stream
    /// dependency, ...). The caller must still call `end_capture` to
    /// discard the broken graph before using the stream again.
    Invalidated,
}

#[cfg(feature = "cuda")]
impl CaptureStatus {
    /// Map the raw cudarc status enum onto this typed wrapper.
    fn from_cuda(raw: cudarc::driver::sys::CUstreamCaptureStatus) -> Self {
        type Raw = cudarc::driver::sys::CUstreamCaptureStatus;
        match raw {
            Raw::CU_STREAM_CAPTURE_STATUS_NONE => CaptureStatus::None,
            Raw::CU_STREAM_CAPTURE_STATUS_ACTIVE => CaptureStatus::Active,
            Raw::CU_STREAM_CAPTURE_STATUS_INVALIDATED => CaptureStatus::Invalidated,
        }
    }
}

impl CaptureStatus {
    /// `true` while the stream has a graph capture actively recording.
    #[inline]
    pub fn is_capturing(&self) -> bool {
        *self == Self::Active
    }

    /// `true` when the capture was invalidated and must still be ended.
    #[inline]
    pub fn is_invalidated(&self) -> bool {
        *self == Self::Invalidated
    }
}
135
136// ---------------------------------------------------------------------------
137// DeviceScalar — a single value in GPU memory, updatable before graph replay
138// ---------------------------------------------------------------------------
139
/// A single scalar value stored in GPU device memory.
///
/// Built for CUDA graph capture: the graph records the (fixed) device
/// pointer, and before each [`CapturedGraph::launch`] the caller pushes
/// a fresh value via [`update`](DeviceScalar::update). That update is a
/// 4-or-8 byte `cuMemcpyHtoDAsync` — effectively zero cost.
#[cfg(feature = "cuda")]
pub struct DeviceScalar<T: DeviceRepr + ValidAsZeroBits + Copy> {
    /// One-element device buffer holding the scalar.
    cell: CudaSlice<T>,
    /// Stream the buffer was allocated on; all updates are ordered here.
    stream: Arc<CudaStream>,
}

#[cfg(feature = "cuda")]
impl<T: DeviceRepr + ValidAsZeroBits + Copy> DeviceScalar<T> {
    /// Allocate a device scalar holding `initial`.
    pub fn new(stream: &Arc<CudaStream>, initial: T) -> GpuResult<Self> {
        let stream = Arc::clone(stream);
        let cell = stream.clone_htod(&[initial])?;
        Ok(Self { cell, stream })
    }

    /// Overwrite the device value with `value`.
    ///
    /// This is an async H→D memcpy of `size_of::<T>()` bytes. It must be
    /// issued on the same stream as the graph so the copy is ordered
    /// before the next replay.
    pub fn update(&mut self, value: T) -> GpuResult<()> {
        let staged = [value];
        self.stream.memcpy_htod(&staged, &mut self.cell)?;
        Ok(())
    }

    /// Borrow the underlying `CudaSlice` for use as a kernel parameter.
    ///
    /// Graph capture records this pointer address, so later calls to
    /// `update` change what the kernel reads without re-capturing.
    #[inline]
    pub fn inner(&self) -> &CudaSlice<T> {
        &self.cell
    }
}
178
179// ---------------------------------------------------------------------------
180// CapturedGraph — a replayable CUDA graph
181// ---------------------------------------------------------------------------
182
/// A captured and instantiated CUDA graph that can be replayed with
/// [`launch`](CapturedGraph::launch).
///
/// Created via [`begin_capture`] + GPU ops + [`end_capture`].
/// The graph holds references to all device memory used during capture.
/// Those buffers must remain allocated for the lifetime of the graph.
///
/// **Allocator pool integration (CL-278).** When created via
/// [`end_capture_with_pool`], the graph holds a strong reference to
/// the [`CapturePool`] that recorded its allocations. The pool keeps
/// every registered buffer alive until the last `CapturedGraph`
/// referencing it is dropped, which guarantees the device pointers
/// recorded in the graph remain valid across replays. Without the
/// pool, callers must manually keep buffers alive (the original
/// [`end_capture`] API).
#[cfg(feature = "cuda")]
pub struct CapturedGraph {
    graph: cudarc::driver::CudaGraph,
    /// Optional reference to the pool that owns the graph's
    /// allocations. Some(pool) when constructed via
    /// [`end_capture_with_pool`]. Dropping the graph drops this
    /// Arc, which (if it's the last reference) drops every buffer
    /// the pool holds. CL-278.
    pool: Option<Arc<CapturePool>>,
    /// Monotonic counter bumped by every successful [`launch`]. Lets
    /// callers assert that a specific replay happened after some
    /// other work completed, useful for graph-aware profilers and
    /// integration tests. CL-454.
    replay_count: AtomicU64,
    /// True after the first successful [`upload`] so subsequent
    /// uploads become cheap no-ops. CL-454.
    uploaded: AtomicBool,
}
216
#[cfg(feature = "cuda")]
impl CapturedGraph {
    /// Replay every operation recorded in this graph.
    ///
    /// Before calling this, update any [`DeviceScalar`] values and issue
    /// any pre-launch memcpys (e.g., position embeddings) on the same
    /// stream the graph was captured on so they are ordered first.
    ///
    /// Bumps [`num_replays`](Self::num_replays) on success.
    pub fn launch(&self) -> GpuResult<()> {
        self.graph.launch()?;
        let _ = self.replay_count.fetch_add(1, Ordering::Relaxed);
        Ok(())
    }

    /// Pre-upload the graph's executable resources to the device.
    ///
    /// The first [`launch`](Self::launch) of a freshly instantiated graph
    /// pays a one-time cost for the driver to copy the exec into GPU
    /// memory; calling `upload` up front moves that cost out of the hot
    /// replay loop. Subsequent calls are no-ops. CL-454.
    pub fn upload(&self) -> GpuResult<()> {
        if !self.uploaded.load(Ordering::Acquire) {
            self.graph.upload()?;
            self.uploaded.store(true, Ordering::Release);
        }
        Ok(())
    }

    /// Number of successful replays issued on this graph. CL-454.
    #[inline]
    pub fn num_replays(&self) -> u64 {
        self.replay_count.load(Ordering::Relaxed)
    }

    /// Whether [`upload`](Self::upload) has already run on this graph.
    /// CL-454.
    #[inline]
    pub fn is_uploaded(&self) -> bool {
        self.uploaded.load(Ordering::Acquire)
    }

    /// Number of buffers kept alive by this graph's allocator pool, or
    /// 0 when the graph was created without a pool. CL-278.
    pub fn pool_buffer_count(&self) -> usize {
        self.pool.as_ref().map_or(0, |pool| pool.buffer_count())
    }

    /// Whether this graph holds a CapturePool reference. CL-278.
    pub fn has_pool(&self) -> bool {
        self.pool.is_some()
    }

    /// The [`Arc<CapturePool>`] this graph is using, if any. Sharing the
    /// same pool between multiple graphs makes them all keep the same
    /// buffers alive. CL-454.
    pub fn pool(&self) -> Option<&Arc<CapturePool>> {
        self.pool.as_ref()
    }
}
281
282// ---------------------------------------------------------------------------
283// Capture API
284// ---------------------------------------------------------------------------
285
/// Begin CUDA graph capture on the given stream.
///
/// Every GPU operation issued on this stream after this call (kernel
/// launches, cuBLAS calls, memcpys) is **recorded, not executed**.
/// Finalize with [`end_capture`] to instantiate the replayable graph.
///
/// # Requirements
///
/// - Pre-allocate every output buffer before capture starts.
/// - No `alloc_zeros` / `cpu_to_gpu` while capturing (use `_into` variants).
/// - No CPU↔GPU synchronization while capturing.
/// - Disable event tracking around the capture to avoid interference
///   (call `ctx.disable_event_tracking()` before, re-enable after).
#[cfg(feature = "cuda")]
pub fn begin_capture(stream: &Arc<CudaStream>) -> GpuResult<()> {
    begin_capture_with_mode(stream, CaptureMode::default())
}
303
/// Begin CUDA graph capture with an explicit [`CaptureMode`]. CL-454.
///
/// Prefer [`begin_capture`] for the default (`ThreadLocal`) mode. Reach
/// for this form when you need `Global` (debugging / strict
/// serialization) or `Relaxed` (max throughput under single-thread
/// ownership).
#[cfg(feature = "cuda")]
pub fn begin_capture_with_mode(
    stream: &Arc<CudaStream>,
    mode: CaptureMode,
) -> GpuResult<()> {
    Ok(stream.begin_capture(mode.to_cuda())?)
}
317
/// Query the capture status of a CUDA stream. CL-454.
///
/// ferrotorch-gpu's equivalent of PyTorch's
/// `torch.cuda.is_current_stream_capturing`. Use it to skip
/// capture-invalid APIs (allocator calls, H↔D copies) while a graph is
/// being recorded.
#[cfg(feature = "cuda")]
pub fn capture_status(stream: &Arc<CudaStream>) -> GpuResult<CaptureStatus> {
    Ok(CaptureStatus::from_cuda(stream.capture_status()?))
}
329
/// Shorthand for `capture_status(stream)?.is_capturing()`. CL-454.
#[cfg(feature = "cuda")]
pub fn is_stream_capturing(stream: &Arc<CudaStream>) -> GpuResult<bool> {
    capture_status(stream).map(|status| status.is_capturing())
}
335
/// End CUDA graph capture, instantiate, and return the replayable graph.
///
/// # Errors
///
/// Returns `Err` if capture was not active, if the driver returns a
/// null graph, or if instantiation fails.
///
/// The returned graph has no [`CapturePool`] attached. The caller is
/// responsible for keeping the buffers used by the captured kernels
/// alive for the graph's lifetime. Use [`end_capture_with_pool`]
/// for the lifetime-managed variant.
#[cfg(feature = "cuda")]
pub fn end_capture(stream: &Arc<CudaStream>) -> GpuResult<CapturedGraph> {
    let flags = cudarc::driver::sys::CUgraphInstantiate_flags_enum::CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH;
    // Bug fix: this failure was previously reported as
    // `GpuError::PtxCompileFailed` with the message stuffed into the
    // `kernel` field — a misleading variant for a graph-capture error.
    // `InvalidState` (already used by `begin_capture_with_pool`) is the
    // accurate classification. `ok_or_else` keeps the String allocation
    // off the success path.
    let graph = stream
        .end_capture(flags)?
        .ok_or_else(|| GpuError::InvalidState {
            message: "CUDA graph capture returned null".into(),
        })?;
    Ok(CapturedGraph {
        graph,
        pool: None,
        replay_count: AtomicU64::new(0),
        uploaded: AtomicBool::new(false),
    })
}
359
/// End CUDA graph capture and attach a [`CapturePool`] reference to
/// the resulting [`CapturedGraph`]. CL-278.
///
/// The pool's tracked buffers stay alive for the lifetime of the
/// returned graph: dropping the graph drops its `Arc<CapturePool>`,
/// and dropping the last such Arc drops every buffer the pool
/// recorded. This keeps the device pointers recorded in the captured
/// graph valid across replays.
///
/// Use together with [`CapturePool::record_buffer`]: allocate every
/// buffer used during capture before `begin_capture`, register each
/// one with the pool, run the kernels under capture, then call
/// `end_capture_with_pool(stream, pool)` to seal the lifetime
/// relationship.
#[cfg(feature = "cuda")]
pub fn end_capture_with_pool(
    stream: &Arc<CudaStream>,
    pool: Arc<CapturePool>,
) -> GpuResult<CapturedGraph> {
    end_capture(stream).map(|mut graph| {
        graph.pool = Some(pool);
        graph
    })
}
383
384// ---------------------------------------------------------------------------
385// GraphCaptureGuard — RAII wrapper that ends capture on drop
386// ---------------------------------------------------------------------------
387
/// RAII guard that runs CUDA graph capture in a scoped block.
///
/// Call [`GraphCaptureGuard::begin`] (or [`begin_with_mode`] /
/// [`begin_with_pool`]) to start capture; calling [`finish`] returns
/// the instantiated graph. If the guard is dropped without calling
/// `finish` (for example because a kernel returned an error
/// mid-capture), its `Drop` impl best-effort-ends capture and
/// discards the resulting graph so the stream returns to a usable
/// state. CL-454.
///
/// This mirrors PyTorch's `with torch.cuda.graph(g): ...` context
/// manager semantics in Rust's RAII idiom.
///
/// # Example
///
/// ```ignore
/// use ferrotorch_gpu::graph::GraphCaptureGuard;
///
/// let mut guard = GraphCaptureGuard::begin(device.stream())?;
/// run_kernels()?; // any kernel launched on device.stream() is recorded
/// let graph = guard.finish()?;
/// graph.upload()?;
/// for _ in 0..1000 { graph.launch()?; }
/// ```
#[cfg(feature = "cuda")]
pub struct GraphCaptureGuard {
    /// Stream this guard began capture on; both `finish` and `Drop`
    /// end the capture on this same stream.
    stream: Arc<CudaStream>,
    /// Optional pool to attach when `finish` is called.
    pool: Option<Arc<CapturePool>>,
    /// Becomes `false` after [`finish`] consumes the guard, so `Drop`
    /// knows capture is already ended.
    active: bool,
}
421
#[cfg(feature = "cuda")]
impl GraphCaptureGuard {
    /// Start capturing on `stream` in the default
    /// [`CaptureMode::ThreadLocal`] mode. CL-454.
    pub fn begin(stream: &Arc<CudaStream>) -> GpuResult<Self> {
        Self::begin_with_mode(stream, CaptureMode::default())
    }

    /// Start capturing with an explicit [`CaptureMode`]. CL-454.
    pub fn begin_with_mode(
        stream: &Arc<CudaStream>,
        mode: CaptureMode,
    ) -> GpuResult<Self> {
        begin_capture_with_mode(stream, mode)?;
        Ok(Self {
            stream: Arc::clone(stream),
            pool: None,
            active: true,
        })
    }

    /// Start capturing bound to a [`CapturePool`]. The pool is attached
    /// to the graph produced by [`finish`](Self::finish). CL-454.
    pub fn begin_with_pool(
        stream: &Arc<CudaStream>,
        pool: Arc<CapturePool>,
    ) -> GpuResult<Self> {
        begin_capture_with_pool(&pool, stream)?;
        Ok(Self {
            stream: Arc::clone(stream),
            pool: Some(pool),
            active: true,
        })
    }

    /// Finish capture and return the instantiated [`CapturedGraph`].
    ///
    /// Consumes the guard, so its `Drop` becomes a no-op. If a pool was
    /// attached at construction, the graph is produced via
    /// [`end_capture_with_pool`] and keeps the pool Arc alive for the
    /// graph's lifetime.
    pub fn finish(mut self) -> GpuResult<CapturedGraph> {
        self.active = false;
        match self.pool.take() {
            Some(pool) => end_capture_with_pool(&self.stream, pool),
            None => end_capture(&self.stream),
        }
    }

    /// Report whether the guarded stream is still actively capturing.
    /// An unexpected `Invalidated` or `None` usually means a forbidden
    /// API call (alloc, sync, host copy) happened under capture.
    pub fn status(&self) -> GpuResult<CaptureStatus> {
        capture_status(&self.stream)
    }
}
480
#[cfg(feature = "cuda")]
impl Drop for GraphCaptureGuard {
    fn drop(&mut self) {
        if self.active {
            // Best-effort cleanup: end the in-flight capture so the
            // stream becomes usable again. Errors are ignored because
            // we are in Drop, and the resulting CapturedGraph is
            // discarded immediately.
            let _ = end_capture(&self.stream);
        }
    }
}
493
494// ---------------------------------------------------------------------------
495// Graph pool handle registry — share a CapturePool across multiple graphs
496// ---------------------------------------------------------------------------
497
/// Opaque handle for a pool registered with the process-wide graph
/// pool registry. Used to share the same buffer-lifetime pool across
/// multiple captured graphs without passing `Arc<CapturePool>` around
/// by hand. CL-454.
///
/// Mirrors PyTorch's `torch.cuda.graph_pool_handle()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct GraphPoolHandle(pub u64); // raw registry key minted by `graph_pool_handle`
506
/// Next handle id to hand out; starts at 1 so 0 never names a pool.
#[cfg(feature = "cuda")]
static NEXT_POOL_HANDLE: AtomicU64 = AtomicU64::new(1);

/// Map from handle id to its registered pool, shared process-wide.
#[cfg(feature = "cuda")]
type PoolMap = std::collections::HashMap<u64, Arc<CapturePool>>;

#[cfg(feature = "cuda")]
static POOL_REGISTRY: std::sync::OnceLock<std::sync::Mutex<PoolMap>> =
    std::sync::OnceLock::new();

/// Lazily initialize and return the process-wide pool registry.
#[cfg(feature = "cuda")]
fn pool_registry() -> &'static std::sync::Mutex<PoolMap> {
    POOL_REGISTRY.get_or_init(Default::default)
}
519
/// Allocate a fresh [`GraphPoolHandle`] and register a new
/// [`CapturePool`] under it in the process-wide registry. CL-454.
///
/// Pass the handle to [`capture_pool_for_handle`] later — from any
/// thread — to retrieve the same `Arc<CapturePool>`, letting two
/// independently captured graphs share one buffer-keeping pool.
#[cfg(feature = "cuda")]
pub fn graph_pool_handle() -> GraphPoolHandle {
    let handle = GraphPoolHandle(NEXT_POOL_HANDLE.fetch_add(1, Ordering::Relaxed));
    // Recover from lock poisoning: a panic elsewhere doesn't corrupt the map.
    pool_registry()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .insert(handle.0, Arc::new(CapturePool::new()));
    handle
}
537
/// Look up the [`CapturePool`] registered under `handle` and return a
/// strong `Arc` to it. `None` when the handle was never allocated or
/// has been released via [`release_graph_pool_handle`]. CL-454.
#[cfg(feature = "cuda")]
pub fn capture_pool_for_handle(handle: GraphPoolHandle) -> Option<Arc<CapturePool>> {
    pool_registry()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .get(&handle.0)
        .cloned()
}
549
/// Drop the registry's strong reference to the pool behind `handle`.
/// Any [`CapturedGraph`] holding its own Arc (for example via
/// [`end_capture_with_pool`]) keeps the pool alive until that graph
/// is dropped too. CL-454.
#[cfg(feature = "cuda")]
pub fn release_graph_pool_handle(handle: GraphPoolHandle) {
    pool_registry()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .remove(&handle.0);
}
561
562// ---------------------------------------------------------------------------
563// make_graphed_callable — scoped capture over a closure
564// ---------------------------------------------------------------------------
565
/// Capture the operations performed by `f` into a CUDA graph and
/// return the replayable graph. CL-454.
///
/// ferrotorch-gpu's equivalent of PyTorch's
/// `torch.cuda.make_graphed_callables` for the simple single-callable
/// case: the caller supplies a closure that runs all the GPU work to
/// capture, and the returned [`CapturedGraph`] can be replayed over
/// and over. The closure runs exactly once during capture, so all
/// per-call work (allocations, dtype decisions) that isn't valid
/// under capture must happen outside.
///
/// If the closure returns an error, capture is discarded and the
/// error is propagated.
#[cfg(feature = "cuda")]
pub fn make_graphed_callable<F>(
    stream: &Arc<CudaStream>,
    mode: CaptureMode,
    f: F,
) -> GpuResult<CapturedGraph>
where
    F: FnOnce() -> GpuResult<()>,
{
    let guard = GraphCaptureGuard::begin_with_mode(stream, mode)?;
    // On error, `?` drops `guard`, whose Drop ends and discards the
    // in-flight capture before the error propagates.
    f()?;
    guard.finish()
}
598
599// ---------------------------------------------------------------------------
600// CapturePool — memory pool for graph capture
601// ---------------------------------------------------------------------------
602
/// A dedicated memory pool for CUDA graph capture.
///
/// Two responsibilities:
///
/// 1. **Sealed flag** — gates [`begin_capture_with_pool`] so the
///    caller can express "no more allocations after this point"
///    semantically. Sealed pools cannot satisfy new allocations
///    during capture.
///
/// 2. **Buffer lifetime tracking (CL-278)** — registered buffers
///    are kept alive by the pool itself, so they outlive any
///    [`CapturedGraph`] that holds an `Arc<CapturePool>`. Dropping
///    the graph drops the Arc, and dropping the last Arc drops
///    every registered buffer in registration order.
///
/// # Usage
///
/// ```ignore
/// use std::sync::Arc;
/// let pool = Arc::new(CapturePool::new());
///
/// // Allocate every buffer the captured kernels will read or
/// // write, and register each one with the pool so it stays alive
/// // for the graph's lifetime.
/// let mut buf_a = alloc_zeros_f32(1024, &device)?;
/// let mut buf_b = alloc_zeros_f32(1024, &device)?;
/// pool.record_buffer(buf_a.try_clone()?);
/// pool.record_buffer(buf_b.try_clone()?);
///
/// pool.seal();
/// begin_capture_with_pool(&pool, stream)?;
/// // ... launch kernels using buf_a and buf_b ...
/// let graph = end_capture_with_pool(stream, Arc::clone(&pool))?;
/// // Dropping `pool` here is safe — the graph holds its own Arc.
/// ```
#[cfg(feature = "cuda")]
pub struct CapturePool {
    /// "No more allocations" flag checked by [`begin_capture_with_pool`].
    sealed: AtomicBool,
    /// Registered buffers (type-erased) kept alive for the graph's
    /// lifetime. Each entry is a Box<dyn Any + Send + Sync> wrapping
    /// the buffer's drop guard. CL-278.
    buffers: Mutex<Vec<Box<dyn std::any::Any + Send + Sync + 'static>>>,
}
646
#[cfg(feature = "cuda")]
impl CapturePool {
    /// Create a new, unsealed capture pool with no registered buffers.
    pub fn new() -> Self {
        Self {
            sealed: AtomicBool::new(false),
            buffers: Mutex::new(Vec::new()),
        }
    }

    /// Lock the buffer list, recovering from lock poisoning.
    ///
    /// A poisoned mutex only means another thread panicked while
    /// holding the guard; the Vec inside is still structurally valid,
    /// so every accessor recovers with `into_inner` instead of
    /// propagating the panic. Centralizing this in one helper keeps
    /// all accessors consistent.
    fn lock_buffers(
        &self,
    ) -> std::sync::MutexGuard<'_, Vec<Box<dyn std::any::Any + Send + Sync + 'static>>> {
        self.buffers.lock().unwrap_or_else(|p| p.into_inner())
    }

    /// Seal the pool, preventing any further allocations.
    pub fn seal(&self) {
        self.sealed.store(true, Ordering::Release);
    }

    /// Unseal the pool, allowing allocations again.
    pub fn unseal(&self) {
        self.sealed.store(false, Ordering::Release);
    }

    /// Check whether the pool is sealed.
    pub fn is_capture_pool_sealed(&self) -> bool {
        self.sealed.load(Ordering::Acquire)
    }

    /// Register a buffer with the pool so it stays alive for the
    /// lifetime of any [`CapturedGraph`] that holds this pool.
    /// CL-278.
    ///
    /// `buffer` can be any type that owns GPU memory (typically
    /// `CudaBuffer<f32>`, `CudaBuffer<f64>`, or `Arc<CudaBuffer<T>>`).
    /// The pool stores it in a type-erased `Box<dyn Any + Send +
    /// Sync>` and drops it (in registration order) when the pool
    /// itself is dropped.
    ///
    /// Returns the index of the registered buffer for diagnostic
    /// purposes.
    pub fn record_buffer<B>(&self, buffer: B) -> usize
    where
        B: Send + Sync + 'static,
    {
        let mut buffers = self.lock_buffers();
        let idx = buffers.len();
        buffers.push(Box::new(buffer));
        idx
    }

    /// Number of buffers currently registered with the pool. CL-278.
    ///
    /// Bug fix: this previously reported 0 whenever the lock was
    /// poisoned, while `record_buffer`/`clear_buffers` recovered via
    /// `into_inner` — so a panic elsewhere could make the count lie.
    /// It now recovers the same way as every other accessor.
    pub fn buffer_count(&self) -> usize {
        self.lock_buffers().len()
    }

    /// Drop every registered buffer immediately, in registration
    /// order. The pool itself remains usable; new buffers can still
    /// be registered after this call. CL-278.
    ///
    /// Use this when reusing a pool across multiple capture cycles.
    /// Calling clear while a [`CapturedGraph`] still holds an Arc
    /// to this pool is safe — the graph's strong reference keeps
    /// the pool struct alive, but the buffer slots are reset.
    pub fn clear_buffers(&self) {
        self.lock_buffers().clear();
    }
}
723
#[cfg(feature = "cuda")]
impl Default for CapturePool {
    /// Equivalent to [`CapturePool::new`].
    fn default() -> Self {
        CapturePool::new()
    }
}
730
/// Begin CUDA graph capture with a capture pool.
///
/// Like [`begin_capture`], but first verifies the pool is not sealed.
/// A sealed pool cannot satisfy allocations during graph capture,
/// which would cause CUDA errors.
///
/// # Errors
///
/// Returns [`GpuError::InvalidState`](GpuError) if the pool is sealed,
/// or a CUDA driver error if `begin_capture` fails.
#[cfg(feature = "cuda")]
pub fn begin_capture_with_pool(pool: &CapturePool, stream: &Arc<CudaStream>) -> GpuResult<()> {
    if !pool.is_capture_pool_sealed() {
        begin_capture(stream)
    } else {
        Err(GpuError::InvalidState {
            message: "cannot begin graph capture: capture pool is sealed".into(),
        })
    }
}
750
/// Stub CapturePool when the cuda feature is disabled. Provides the
/// same surface API as the cuda-enabled type so callers compile on
/// both feature configurations.
#[cfg(not(feature = "cuda"))]
pub struct CapturePool;

#[cfg(not(feature = "cuda"))]
impl CapturePool {
    /// Create an empty CapturePool. Without the cuda feature the
    /// pool has no internal state to initialize.
    pub fn new() -> Self {
        Self
    }

    /// No-op without the cuda feature: there is no real CUDA pool
    /// to seal because no real allocations can happen.
    pub fn seal(&self) {}

    /// No-op without the cuda feature: there is no real CUDA pool
    /// to unseal because no real allocations can happen.
    pub fn unseal(&self) {}

    /// Always returns `false` without the cuda feature since there
    /// is no real pool that could be in either state.
    pub fn is_capture_pool_sealed(&self) -> bool {
        false
    }

    /// Accept (and immediately drop) a buffer so feature-portable
    /// code can call `record_buffer` unconditionally. Always returns
    /// index 0 because nothing is retained. Added for API parity with
    /// the cuda-enabled type. CL-278.
    pub fn record_buffer<B>(&self, buffer: B) -> usize
    where
        B: Send + Sync + 'static,
    {
        drop(buffer);
        0
    }

    /// Always returns 0 without the cuda feature since no real
    /// allocations can be tracked. CL-278.
    pub fn buffer_count(&self) -> usize {
        0
    }

    /// No-op without the cuda feature: no buffers are ever retained.
    /// Added for API parity with the cuda-enabled type. CL-278.
    pub fn clear_buffers(&self) {}
}
793
794#[cfg(not(feature = "cuda"))]
795impl Default for CapturePool {
796    fn default() -> Self {
797        Self::new()
798    }
799}
800
/// Stub `begin_capture_with_pool` when the cuda feature is disabled.
/// Always fails with [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn begin_capture_with_pool<T>(_pool: &CapturePool, _stream: &T) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
806
807// ---------------------------------------------------------------------------
808// Stubs when cuda feature is disabled
809// ---------------------------------------------------------------------------
810
/// Stub DeviceScalar: a zero-sized placeholder so the type name can be
/// referenced without the cuda feature. No methods are provided here,
/// so code that actually uses a device scalar (see the `new`/`update`
/// calls in the module docs) still requires the cuda feature.
#[cfg(not(feature = "cuda"))]
pub struct DeviceScalar<T: Copy> {
    // Carries the value type without storing one; presumably mirrors
    // the cuda version's type parameter — confirm against the real impl.
    _phantom: std::marker::PhantomData<T>,
}
816
/// Stub CapturedGraph: a unit placeholder standing in for a captured
/// CUDA graph when the cuda feature is disabled. All fallible methods
/// on it return `GpuError::NoCudaFeature`.
#[cfg(not(feature = "cuda"))]
pub struct CapturedGraph;
820
#[cfg(not(feature = "cuda"))]
impl CapturedGraph {
    /// Stub replay: always fails with `GpuError::NoCudaFeature`
    /// because no graph can be captured without the cuda feature.
    pub fn launch(&self) -> GpuResult<()> {
        Err(GpuError::NoCudaFeature)
    }

    /// Stub upload for CL-454.
    pub fn upload(&self) -> GpuResult<()> {
        Err(GpuError::NoCudaFeature)
    }

    /// Stub num_replays — always 0 without the cuda feature. CL-454.
    pub fn num_replays(&self) -> u64 {
        0
    }

    /// Stub is_uploaded — always false without the cuda feature. CL-454.
    pub fn is_uploaded(&self) -> bool {
        false
    }
}
842
/// Stub `begin_capture` when the cuda feature is not enabled: always
/// fails with `GpuError::NoCudaFeature`. The stream is generic so the
/// signature compiles without cudarc types in scope.
#[cfg(not(feature = "cuda"))]
pub fn begin_capture<T>(_stream: &T) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
847
/// Stub `begin_capture_with_mode` when the cuda feature is not enabled:
/// the requested [`CaptureMode`] is ignored and the call always fails
/// with `GpuError::NoCudaFeature`. CL-454.
#[cfg(not(feature = "cuda"))]
pub fn begin_capture_with_mode<T>(_stream: &T, _mode: CaptureMode) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
854
/// Stub `capture_status` when the cuda feature is not enabled. Unlike
/// the real query, this never reports a [`CaptureStatus`] — it always
/// fails with `GpuError::NoCudaFeature`. CL-454.
#[cfg(not(feature = "cuda"))]
pub fn capture_status<T>(_stream: &T) -> GpuResult<CaptureStatus> {
    Err(GpuError::NoCudaFeature)
}
860
/// Stub `is_stream_capturing` when the cuda feature is not enabled.
/// Always fails with `GpuError::NoCudaFeature` rather than answering
/// `false`, so callers can distinguish "not capturing" from "no CUDA".
/// CL-454.
#[cfg(not(feature = "cuda"))]
pub fn is_stream_capturing<T>(_stream: &T) -> GpuResult<bool> {
    Err(GpuError::NoCudaFeature)
}
867
/// Stub `end_capture` when the cuda feature is not enabled: always
/// fails with `GpuError::NoCudaFeature`, so a [`CapturedGraph`] can
/// never be produced through this path.
#[cfg(not(feature = "cuda"))]
pub fn end_capture<T>(_stream: &T) -> GpuResult<CapturedGraph> {
    Err(GpuError::NoCudaFeature)
}
872
/// Stub `end_capture_with_pool` when the cuda feature is not enabled.
/// Always fails with `GpuError::NoCudaFeature`.
/// CL-278.
///
/// NOTE(review): takes the pool by `Arc` value while
/// `begin_capture_with_pool` borrows `&CapturePool` — presumably this
/// mirrors the cuda-side signatures; confirm against the real impls.
#[cfg(not(feature = "cuda"))]
pub fn end_capture_with_pool<T>(
    _stream: &T,
    _pool: std::sync::Arc<CapturePool>,
) -> GpuResult<CapturedGraph> {
    Err(GpuError::NoCudaFeature)
}
882
/// Stub `GraphCaptureGuard` when the cuda feature is not enabled. CL-454.
///
/// The `Infallible` field makes this type uninhabited: no value of it
/// can ever exist, which lets the methods below use an empty `match`
/// instead of fabricating unreachable return values.
#[cfg(not(feature = "cuda"))]
pub struct GraphCaptureGuard {
    // Uninhabited marker — guarantees the `begin*` constructors are the
    // only way to (fail to) obtain a guard.
    _never: core::convert::Infallible,
}
888
#[cfg(not(feature = "cuda"))]
impl GraphCaptureGuard {
    /// Stub `begin`: always fails with `GpuError::NoCudaFeature`, so a
    /// guard value can never actually be obtained without the cuda
    /// feature.
    pub fn begin<T>(_stream: &T) -> GpuResult<Self> {
        Err(GpuError::NoCudaFeature)
    }

    /// Stub `begin_with_mode`: the mode is ignored; always fails with
    /// `GpuError::NoCudaFeature`.
    pub fn begin_with_mode<T>(_stream: &T, _mode: CaptureMode) -> GpuResult<Self> {
        Err(GpuError::NoCudaFeature)
    }

    /// Stub `begin_with_pool`: the pool is ignored; always fails with
    /// `GpuError::NoCudaFeature`.
    pub fn begin_with_pool<T>(
        _stream: &T,
        _pool: std::sync::Arc<CapturePool>,
    ) -> GpuResult<Self> {
        Err(GpuError::NoCudaFeature)
    }

    /// Statically unreachable: `Self` is uninhabited (see the `_never`
    /// field), so the empty match type-checks without a body.
    pub fn finish(self) -> GpuResult<CapturedGraph> {
        match self._never {}
    }

    /// Statically unreachable, same reasoning as [`Self::finish`].
    pub fn status(&self) -> GpuResult<CaptureStatus> {
        match self._never {}
    }
}
914
/// Stub `graph_pool_handle` when the cuda feature is not enabled. CL-454.
///
/// Always returns the zero sentinel handle; no pool is registered for
/// it, so [`capture_pool_for_handle`] on the result yields `None`.
#[cfg(not(feature = "cuda"))]
pub fn graph_pool_handle() -> GraphPoolHandle {
    GraphPoolHandle(0)
}
920
/// Stub `capture_pool_for_handle` when the cuda feature is not enabled.
/// Always `None`: the stub registry never holds any pools.
/// CL-454.
#[cfg(not(feature = "cuda"))]
pub fn capture_pool_for_handle(_handle: GraphPoolHandle) -> Option<std::sync::Arc<CapturePool>> {
    None
}
927
/// Stub `release_graph_pool_handle` when the cuda feature is not enabled.
/// Safe to call any number of times with any handle value.
/// CL-454.
#[cfg(not(feature = "cuda"))]
pub fn release_graph_pool_handle(_handle: GraphPoolHandle) {
    // nothing to release — the stub registry never holds any pools
}
934
935/// Stub `make_graphed_callable` when the cuda feature is not enabled.
936/// CL-454.
937#[cfg(not(feature = "cuda"))]
938pub fn make_graphed_callable<T, F>(
939    _stream: &T,
940    _mode: CaptureMode,
941    _f: F,
942) -> GpuResult<CapturedGraph>
943where
944    F: FnOnce() -> GpuResult<()>,
945{
946    Err(GpuError::NoCudaFeature)
947}
948
949// ---------------------------------------------------------------------------
950// Tests — CL-278 capture pool buffer tracking
951// ---------------------------------------------------------------------------
952
// Unit tests for the cuda-feature build. None of these require an
// actual device: they exercise host-side bookkeeping (CapturePool,
// CaptureMode/CaptureStatus wrappers, the pool-handle registry).
#[cfg(all(test, feature = "cuda"))]
mod tests {
    use super::*;

    #[test]
    fn capture_pool_buffer_count_starts_at_zero() {
        let pool = CapturePool::new();
        assert_eq!(pool.buffer_count(), 0);
    }

    #[test]
    fn capture_pool_record_buffer_increments_count() {
        // Indices are assigned sequentially from 0 in recording order.
        let pool = CapturePool::new();
        let buf_a: Vec<f32> = vec![0.0; 10];
        let idx = pool.record_buffer(buf_a);
        assert_eq!(idx, 0);
        assert_eq!(pool.buffer_count(), 1);

        let buf_b: Vec<f64> = vec![0.0; 5];
        let idx = pool.record_buffer(buf_b);
        assert_eq!(idx, 1);
        assert_eq!(pool.buffer_count(), 2);
    }

    #[test]
    fn capture_pool_clear_buffers_resets_count_but_keeps_pool() {
        let pool = CapturePool::new();
        pool.record_buffer(vec![0u8; 16]);
        pool.record_buffer(vec![0u8; 32]);
        assert_eq!(pool.buffer_count(), 2);
        pool.clear_buffers();
        assert_eq!(pool.buffer_count(), 0);
        // Pool is still usable.
        pool.record_buffer(vec![0u8; 8]);
        assert_eq!(pool.buffer_count(), 1);
    }

    #[test]
    fn capture_pool_drop_releases_registered_buffers() {
        // Use Arc to detect when the inner buffer is dropped.
        let buf = Arc::new(vec![1.0f32, 2.0, 3.0]);
        let pool = CapturePool::new();
        pool.record_buffer(Arc::clone(&buf));
        assert_eq!(Arc::strong_count(&buf), 2);
        drop(pool);
        // Pool dropped → recorded Arc dropped → strong count back to 1.
        assert_eq!(Arc::strong_count(&buf), 1);
    }

    #[test]
    fn capture_pool_records_heterogeneous_types() {
        // The pool accepts buffers of unrelated element types (and even
        // non-Vec values) in a single pool.
        let pool = CapturePool::new();
        pool.record_buffer(vec![0.0f32; 4]);
        pool.record_buffer(vec![0.0f64; 4]);
        pool.record_buffer(vec![0u8; 4]);
        pool.record_buffer(Arc::new(42i32));
        assert_eq!(pool.buffer_count(), 4);
    }

    #[test]
    fn capture_pool_seal_unseal() {
        // With the cuda feature, seal/unseal actually toggles pool state
        // (contrast with the no-cuda stub, which always reports false).
        let pool = CapturePool::new();
        assert!(!pool.is_capture_pool_sealed());
        pool.seal();
        assert!(pool.is_capture_pool_sealed());
        pool.unseal();
        assert!(!pool.is_capture_pool_sealed());
    }

    // -----------------------------------------------------------------------
    // CL-454 — CaptureMode / CaptureStatus / graph pool handle tests.
    //
    // These tests exercise the typed wrappers and the process-wide
    // pool-handle registry without requiring a real CUDA device.
    // Tests that actually touch a device live under the
    // `feature = "cuda-live"` gate (run them with
    //   cargo test -p ferrotorch-gpu --features cuda,cuda-live
    // on a machine with a functioning CUDA driver).
    // -----------------------------------------------------------------------

    #[test]
    fn capture_mode_default_is_thread_local() {
        assert_eq!(CaptureMode::default(), CaptureMode::ThreadLocal);
    }

    // NOTE(review): despite the name, only the Rust→cudarc direction is
    // exercised here; there is no assertion converting back.
    #[test]
    fn capture_mode_to_cuda_round_trip() {
        use cudarc::driver::sys::CUstreamCaptureMode::*;
        assert_eq!(CaptureMode::Global.to_cuda(), CU_STREAM_CAPTURE_MODE_GLOBAL);
        assert_eq!(
            CaptureMode::ThreadLocal.to_cuda(),
            CU_STREAM_CAPTURE_MODE_THREAD_LOCAL
        );
        assert_eq!(
            CaptureMode::Relaxed.to_cuda(),
            CU_STREAM_CAPTURE_MODE_RELAXED
        );
    }

    #[test]
    fn capture_status_is_capturing_only_when_active() {
        assert!(!CaptureStatus::None.is_capturing());
        assert!(CaptureStatus::Active.is_capturing());
        assert!(!CaptureStatus::Invalidated.is_capturing());
    }

    #[test]
    fn capture_status_is_invalidated_only_when_broken() {
        assert!(!CaptureStatus::None.is_invalidated());
        assert!(!CaptureStatus::Active.is_invalidated());
        assert!(CaptureStatus::Invalidated.is_invalidated());
    }

    #[test]
    fn capture_status_from_cuda_maps_all_variants() {
        use cudarc::driver::sys::CUstreamCaptureStatus::*;
        assert_eq!(
            CaptureStatus::from_cuda(CU_STREAM_CAPTURE_STATUS_NONE),
            CaptureStatus::None
        );
        assert_eq!(
            CaptureStatus::from_cuda(CU_STREAM_CAPTURE_STATUS_ACTIVE),
            CaptureStatus::Active
        );
        assert_eq!(
            CaptureStatus::from_cuda(CU_STREAM_CAPTURE_STATUS_INVALIDATED),
            CaptureStatus::Invalidated
        );
    }

    #[test]
    fn graph_pool_handle_allocates_unique_ids() {
        let h1 = graph_pool_handle();
        let h2 = graph_pool_handle();
        assert_ne!(h1, h2, "each call should return a fresh id");
        // Both handles should map back to a real pool.
        assert!(capture_pool_for_handle(h1).is_some());
        assert!(capture_pool_for_handle(h2).is_some());
        // Release so the process-wide registry doesn't accumulate pools
        // across the test run.
        release_graph_pool_handle(h1);
        release_graph_pool_handle(h2);
    }

    #[test]
    fn graph_pool_handle_shares_single_pool_across_lookups() {
        let h = graph_pool_handle();
        let a = capture_pool_for_handle(h).expect("handle registered");
        let b = capture_pool_for_handle(h).expect("handle still registered");
        assert!(
            Arc::ptr_eq(&a, &b),
            "both lookups should return the same pool Arc"
        );

        // Register a buffer through one lookup; the other should see it.
        a.record_buffer(vec![1.0f32, 2.0]);
        assert_eq!(b.buffer_count(), 1);

        release_graph_pool_handle(h);
        // After release the registry no longer hands out the pool, but
        // existing Arcs keep it alive.
        assert!(capture_pool_for_handle(h).is_none());
        // The existing Arc still has its buffer.
        assert_eq!(a.buffer_count(), 1);
    }

    #[test]
    fn graph_pool_handle_release_is_idempotent() {
        let h = graph_pool_handle();
        assert!(capture_pool_for_handle(h).is_some());
        release_graph_pool_handle(h);
        release_graph_pool_handle(h); // second call is fine
        assert!(capture_pool_for_handle(h).is_none());
    }

    #[test]
    fn graph_pool_handle_unknown_id_returns_none() {
        // A fresh handle ID that was never registered.
        let fake = GraphPoolHandle(u64::MAX);
        assert!(capture_pool_for_handle(fake).is_none());
    }
}
1133
1134// ---------------------------------------------------------------------------
1135// CL-454 — tests that don't need cudarc type info.
1136// ---------------------------------------------------------------------------
1137
#[cfg(all(test, not(feature = "cuda")))]
mod no_cuda_tests {
    use super::*;

    #[test]
    fn capture_mode_and_status_exist_without_cuda_feature() {
        // Feature-portable types must be usable with no cfg gates in
        // caller code, even when the cuda feature is off.
        let _default_mode = CaptureMode::default();
        assert!(CaptureStatus::Active.is_capturing());
        assert!(!CaptureStatus::None.is_capturing());
        assert!(CaptureStatus::Invalidated.is_invalidated());
    }

    #[test]
    fn graph_pool_handle_without_cuda_returns_sentinel() {
        let handle = graph_pool_handle();
        assert_eq!(handle.0, 0, "stub handle is always zero without cuda feature");
        assert!(capture_pool_for_handle(handle).is_none());
        // Releasing the stub handle has nothing to do, but must not panic.
        release_graph_pool_handle(handle);
    }

    #[test]
    fn captured_graph_stub_num_replays_and_is_uploaded_are_zero() {
        let graph = CapturedGraph;
        assert!(!graph.is_uploaded());
        assert_eq!(graph.num_replays(), 0);
    }
}