Skip to main content

ferrotorch_gpu/
graph.rs

1//! CUDA graph capture and replay infrastructure.
2//!
3//! A CUDA graph records a sequence of GPU operations (kernel launches, memcpys)
4//! and replays them as a single driver submission. This eliminates per-kernel
5//! launch overhead (~70μs on WSL2, ~5μs on native Linux per call) by collapsing
6//! hundreds of launches into one.
7//!
8//! # Usage
9//!
10//! ```ignore
11//! use ferrotorch_gpu::graph::{DeviceScalar, begin_capture, end_capture};
12//!
13//! // Pre-allocate all buffers BEFORE capture
14//! let mut out = alloc_zeros_f32(768, &device)?;
15//!
16//! // Parameters that change between replays go in DeviceScalar
17//! let mut pos = DeviceScalar::new(device.stream(), 0u32)?;
18//!
19//! // Capture
20//! begin_capture(device.stream())?;
21//! gpu_add_into(&a, &b, &mut out, &device)?;  // recorded, not executed
22//! let graph = end_capture(device.stream())?;
23//!
24//! // Replay loop
25//! for i in 0..100 {
26//!     pos.update(i as u32)?;  // memcpy before replay
27//!     graph.launch()?;         // replay all captured ops
28//! }
29//! ```
30
31#[cfg(feature = "cuda")]
32use std::sync::Arc;
33
34#[cfg(feature = "cuda")]
35use cudarc::driver::{CudaSlice, CudaStream, DeviceRepr, ValidAsZeroBits};
36
37use crate::error::{GpuError, GpuResult};
38
39// ---------------------------------------------------------------------------
40// DeviceScalar — a single value in GPU memory, updatable before graph replay
41// ---------------------------------------------------------------------------
42
/// One value of type `T` kept in a single-element device buffer.
///
/// Meant for CUDA graph capture: kernels recorded into a graph receive the
/// buffer returned by [`inner`](DeviceScalar::inner), and the caller rewrites
/// the stored value with [`update`](DeviceScalar::update) ahead of every
/// [`CapturedGraph::launch`]. Each update copies just `size_of::<T>()` bytes
/// from host to device.
#[cfg(feature = "cuda")]
pub struct DeviceScalar<T>
where
    T: DeviceRepr + ValidAsZeroBits + Copy,
{
    /// Device-side storage for the value (one element).
    buf: CudaSlice<T>,
    /// Stream used to create `buf`; every `update` copy is issued on it.
    stream: Arc<CudaStream>,
}

#[cfg(feature = "cuda")]
impl<T> DeviceScalar<T>
where
    T: DeviceRepr + ValidAsZeroBits + Copy,
{
    /// Create a device scalar that initially holds `initial`.
    ///
    /// Copies a one-element host array to the device through `stream` and
    /// keeps a handle to that stream for later
    /// [`update`](DeviceScalar::update) calls.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the host-to-device copy.
    pub fn new(stream: &Arc<CudaStream>, initial: T) -> GpuResult<Self> {
        let host = [initial];
        let buf = stream.clone_htod(&host)?;
        let stream = Arc::clone(stream);
        Ok(Self { buf, stream })
    }

    /// Overwrite the device value with `value`.
    ///
    /// Issues a host-to-device copy of `size_of::<T>()` bytes on the stream
    /// that was passed to [`new`](DeviceScalar::new). For the copy to be
    /// ordered before a graph replay, that stream should be the one the graph
    /// is launched on.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the copy.
    pub fn update(&mut self, value: T) -> GpuResult<()> {
        let host = [value];
        self.stream.memcpy_htod(&host, &mut self.buf)?;
        Ok(())
    }

    /// Borrow the backing `CudaSlice` so it can be passed as a kernel
    /// parameter.
    ///
    /// The same buffer is reused for the lifetime of this value, so a graph
    /// that captured it observes later [`update`](DeviceScalar::update)
    /// writes without being re-captured.
    #[inline]
    pub fn inner(&self) -> &CudaSlice<T> {
        let Self { buf, .. } = self;
        buf
    }
}
81
82// ---------------------------------------------------------------------------
83// CapturedGraph — a replayable CUDA graph
84// ---------------------------------------------------------------------------
85
/// A CUDA graph produced by [`end_capture`], ready to be replayed through
/// [`launch`](CapturedGraph::launch).
///
/// Build one by calling [`begin_capture`], issuing GPU work on the stream,
/// and finishing with [`end_capture`]. The graph refers to the device memory
/// that was used while capturing, so those buffers have to stay allocated
/// for as long as the graph is alive.
#[cfg(feature = "cuda")]
pub struct CapturedGraph {
    /// Graph handle obtained from the stream when capture ended.
    graph: cudarc::driver::CudaGraph,
}

#[cfg(feature = "cuda")]
impl CapturedGraph {
    /// Replay every operation recorded in this graph.
    ///
    /// Refresh any [`DeviceScalar`] values and issue any pre-launch memcpys
    /// (for example position embeddings) first. Those updates must go to the
    /// same stream the graph was captured on.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the underlying graph launch.
    pub fn launch(&self) -> GpuResult<()> {
        let Self { graph } = self;
        graph.launch()?;
        Ok(())
    }
}
109
110// ---------------------------------------------------------------------------
111// Capture API
112// ---------------------------------------------------------------------------
113
/// Start CUDA graph capture on `stream`, using the thread-local capture mode.
///
/// After this call, GPU operations issued on the stream (kernel launches,
/// cuBLAS calls, memcpys) are **recorded but not executed**. Finish with
/// [`end_capture`] to finalize and instantiate the graph.
///
/// # Requirements
///
/// - Pre-allocate every output buffer before capture starts.
/// - Do not call `alloc_zeros` / `cpu_to_gpu` while capturing (use the
///   `_into` variants).
/// - Do not synchronize CPU↔GPU while capturing.
/// - Event tracking should be turned off while capturing to avoid
///   interference (call `ctx.disable_event_tracking()` beforehand and
///   re-enable it afterwards).
///
/// # Errors
///
/// Propagates any error reported when the stream begins capture.
#[cfg(feature = "cuda")]
pub fn begin_capture(stream: &Arc<CudaStream>) -> GpuResult<()> {
    use cudarc::driver::sys::CUstreamCaptureMode as CaptureMode;

    let mode = CaptureMode::CU_STREAM_CAPTURE_MODE_THREAD_LOCAL;
    stream.begin_capture(mode)?;
    Ok(())
}
134
/// End CUDA graph capture on `stream`, instantiate the result, and return
/// the replayable graph.
///
/// The graph is instantiated with
/// `CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH`.
///
/// # Errors
///
/// - Propagates any driver error raised while ending capture or
///   instantiating the graph.
/// - Returns [`GpuError::InvalidState`] when ending capture yields no graph
///   (for example because no capture was active on the stream).
#[cfg(feature = "cuda")]
pub fn end_capture(stream: &Arc<CudaStream>) -> GpuResult<CapturedGraph> {
    use cudarc::driver::sys::CUgraphInstantiate_flags_enum as InstantiateFlags;

    let flags = InstantiateFlags::CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH;
    let graph = stream
        .end_capture(flags)?
        // A missing graph is a capture-state problem, not a PTX compile
        // failure, so report it with the state-error variant. The closure
        // keeps the message allocation off the success path.
        .ok_or_else(|| GpuError::InvalidState {
            message: "CUDA graph capture returned null".into(),
        })?;
    Ok(CapturedGraph { graph })
}
148
149// ---------------------------------------------------------------------------
150// CapturePool — memory pool for graph capture
151// ---------------------------------------------------------------------------
152
/// A seal flag that guards the start of CUDA graph capture.
///
/// A pool starts out unsealed. [`seal`](CapturePool::seal) marks it closed
/// and [`unseal`](CapturePool::unseal) reopens it;
/// [`begin_capture_with_pool`] refuses to start a capture while the pool is
/// sealed. The intent is to seal the pool once every buffer has been
/// pre-allocated.
///
/// NOTE(review): this type only stores the flag — no allocation goes through
/// it in this module; presumably an allocator consults the flag elsewhere.
///
/// # Usage
///
/// ```ignore
/// let pool = CapturePool::new();
/// // ... allocate buffers ...
/// pool.seal(); // mark the pool as closed
///
/// // begin_capture_with_pool(&pool, stream) returns an error at this point
/// // because the pool is sealed. Call `unseal()` first or use a fresh pool.
/// ```
#[cfg(feature = "cuda")]
#[derive(Default)]
pub struct CapturePool {
    /// `true` while the pool is sealed; defaults to `false`.
    sealed: std::sync::atomic::AtomicBool,
}

#[cfg(feature = "cuda")]
impl CapturePool {
    /// Build a fresh pool in the unsealed state.
    pub fn new() -> Self {
        Self::default()
    }

    /// Mark the pool as sealed.
    pub fn seal(&self) {
        self.set_sealed(true);
    }

    /// Clear the sealed mark so the pool counts as open again.
    pub fn unseal(&self) {
        self.set_sealed(false);
    }

    /// Report whether the pool is currently sealed.
    pub fn is_capture_pool_sealed(&self) -> bool {
        use std::sync::atomic::Ordering;

        self.sealed.load(Ordering::Acquire)
    }

    /// Store the sealed flag with release ordering.
    fn set_sealed(&self, sealed: bool) {
        use std::sync::atomic::Ordering;

        self.sealed.store(sealed, Ordering::Release);
    }
}
208
/// Start CUDA graph capture, but only if `pool` is not sealed.
///
/// Behaves like [`begin_capture`] after first checking the pool's sealed
/// flag. Capture is refused for a sealed pool because such a pool is not
/// meant to serve allocations during capture.
///
/// # Errors
///
/// - Returns [`GpuError::InvalidState`] when the pool is sealed.
/// - Otherwise returns whatever error [`begin_capture`] produces.
#[cfg(feature = "cuda")]
pub fn begin_capture_with_pool(pool: &CapturePool, stream: &Arc<CudaStream>) -> GpuResult<()> {
    if !pool.is_capture_pool_sealed() {
        return begin_capture(stream);
    }

    Err(GpuError::InvalidState {
        message: "cannot begin graph capture: capture pool is sealed".into(),
    })
}
228
/// Placeholder `CapturePool` used when the `cuda` feature is off.
///
/// It carries no state: sealing and unsealing do nothing, and the pool never
/// reports itself as sealed.
#[cfg(not(feature = "cuda"))]
#[derive(Default)]
pub struct CapturePool;

#[cfg(not(feature = "cuda"))]
impl CapturePool {
    /// Build the placeholder pool.
    pub fn new() -> Self {
        CapturePool
    }

    /// No-op: the placeholder has no sealed flag to set.
    pub fn seal(&self) {}

    /// No-op: the placeholder has no sealed flag to clear.
    pub fn unseal(&self) {}

    /// Always `false` for the placeholder pool.
    pub fn is_capture_pool_sealed(&self) -> bool {
        false
    }
}
254
255/// Stub begin_capture_with_pool when cuda feature is disabled.
256#[cfg(not(feature = "cuda"))]
257pub fn begin_capture_with_pool<T>(_pool: &CapturePool, _stream: &T) -> GpuResult<()> {
258    Err(GpuError::NoCudaFeature)
259}
260
261// ---------------------------------------------------------------------------
262// Stubs when cuda feature is disabled
263// ---------------------------------------------------------------------------
264
265/// Stub DeviceScalar.
266#[cfg(not(feature = "cuda"))]
267pub struct DeviceScalar<T: Copy> {
268    _phantom: std::marker::PhantomData<T>,
269}
270
271/// Stub CapturedGraph.
272#[cfg(not(feature = "cuda"))]
273pub struct CapturedGraph;
274
275#[cfg(not(feature = "cuda"))]
276impl CapturedGraph {
277    pub fn launch(&self) -> GpuResult<()> {
278        Err(GpuError::NoCudaFeature)
279    }
280}
281
282#[cfg(not(feature = "cuda"))]
283pub fn begin_capture<T>(_stream: &T) -> GpuResult<()> {
284    Err(GpuError::NoCudaFeature)
285}
286
287#[cfg(not(feature = "cuda"))]
288pub fn end_capture<T>(_stream: &T) -> GpuResult<CapturedGraph> {
289    Err(GpuError::NoCudaFeature)
290}