Skip to main content

oxicuda_launch/
graph_launch.rs

1//! Graph-based kernel launch capture and replay.
2//!
3//! CUDA Graphs allow recording a sequence of operations (kernel launches,
4//! memory copies, etc.) and replaying them as a single unit with reduced
5//! launch overhead. This module provides a lightweight capture facility
6//! that records kernel launch configurations for later replay or
7//! analysis.
8//!
9//! # Example
10//!
11//! ```rust,no_run
12//! # use oxicuda_launch::graph_launch::{GraphLaunchCapture, LaunchRecord};
13//! # use oxicuda_launch::{LaunchParams, Dim3};
14//! let mut capture = GraphLaunchCapture::begin();
15//! // In a real scenario you would record actual kernel launches:
16//! // capture.record_launch(&kernel, &params);
17//! let records = capture.end();
18//! println!("captured {} launches", records.len());
19//! ```
20
21use crate::kernel::Kernel;
22use crate::params::LaunchParams;
23
24// ---------------------------------------------------------------------------
25// LaunchRecord
26// ---------------------------------------------------------------------------
27
28/// A recorded kernel launch operation.
29///
30/// Captures the kernel name and launch parameters at the time of
31/// recording. This is a lightweight snapshot that does not retain
32/// references to the kernel or its module.
33#[derive(Debug, Clone)]
34pub struct LaunchRecord {
35    /// The name of the kernel function that was recorded.
36    kernel_name: String,
37    /// The launch configuration (grid, block, shared memory).
38    params: LaunchParams,
39}
40
41impl LaunchRecord {
42    /// Returns the kernel function name.
43    #[inline]
44    pub fn kernel_name(&self) -> &str {
45        &self.kernel_name
46    }
47
48    /// Returns the recorded launch parameters.
49    #[inline]
50    pub fn params(&self) -> &LaunchParams {
51        &self.params
52    }
53}
54
55impl std::fmt::Display for LaunchRecord {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        write!(
58            f,
59            "LaunchRecord(kernel={}, grid={}x{}x{}, block={}x{}x{})",
60            self.kernel_name,
61            self.params.grid.x,
62            self.params.grid.y,
63            self.params.grid.z,
64            self.params.block.x,
65            self.params.block.y,
66            self.params.block.z,
67        )
68    }
69}
70
71// ---------------------------------------------------------------------------
72// GraphLaunchCapture
73// ---------------------------------------------------------------------------
74
75/// Captures a sequence of kernel launches for graph-based replay.
76///
77/// Create with [`begin`](Self::begin), record launches with
78/// [`record_launch`](Self::record_launch), and finalise with
79/// [`end`](Self::end) to obtain the list of recorded operations.
80///
81/// This is a host-side recording facility. On systems with GPU support,
82/// the recorded operations can be converted into a CUDA graph for
83/// optimised replay via the `oxicuda-driver` graph API.
84#[derive(Debug)]
85pub struct GraphLaunchCapture {
86    /// The sequence of recorded kernel launches.
87    stream_nodes: Vec<LaunchRecord>,
88    /// Whether the capture is currently active.
89    active: bool,
90}
91
92impl GraphLaunchCapture {
93    /// Begins a new graph launch capture session.
94    ///
95    /// Returns a capture object in the active state. Record launches
96    /// using [`record_launch`](Self::record_launch) and finalise
97    /// with [`end`](Self::end).
98    pub fn begin() -> Self {
99        Self {
100            stream_nodes: Vec::new(),
101            active: true,
102        }
103    }
104
105    /// Records a kernel launch into the capture sequence.
106    ///
107    /// The kernel name and launch parameters are snapshot at the time
108    /// of recording. If the capture is not active (i.e., [`end`](Self::end)
109    /// has already been called), this method is a no-op.
110    pub fn record_launch(&mut self, kernel: &Kernel, params: &LaunchParams) {
111        if !self.active {
112            return;
113        }
114        self.stream_nodes.push(LaunchRecord {
115            kernel_name: kernel.name().to_owned(),
116            params: *params,
117        });
118    }
119
120    /// Records a kernel launch directly from a kernel name and params.
121    ///
122    /// This is a lower-level variant of [`record_launch`](Self::record_launch)
123    /// that does not require a [`Kernel`] handle — useful for testing and
124    /// for replaying recorded descriptions.
125    pub fn record_raw(&mut self, kernel_name: impl Into<String>, params: LaunchParams) {
126        if !self.active {
127            return;
128        }
129        self.stream_nodes.push(LaunchRecord {
130            kernel_name: kernel_name.into(),
131            params,
132        });
133    }
134
135    /// Ends the capture session and returns all recorded launches.
136    ///
137    /// After calling this method, the capture is no longer active and
138    /// further calls to [`record_launch`](Self::record_launch) are
139    /// ignored.
140    pub fn end(mut self) -> Vec<LaunchRecord> {
141        self.active = false;
142        std::mem::take(&mut self.stream_nodes)
143    }
144
145    /// Returns the number of launches recorded so far.
146    #[inline]
147    pub fn len(&self) -> usize {
148        self.stream_nodes.len()
149    }
150
151    /// Returns `true` if no launches have been recorded.
152    #[inline]
153    pub fn is_empty(&self) -> bool {
154        self.stream_nodes.is_empty()
155    }
156
157    /// Returns `true` if the capture is currently active.
158    #[inline]
159    pub fn is_active(&self) -> bool {
160        self.active
161    }
162}
163
164// ---------------------------------------------------------------------------
165// Tests
166// ---------------------------------------------------------------------------
167
#[cfg(test)]
mod tests {
    use super::*;
    use crate::grid::Dim3;

    /// Builds a record straight from its parts, without going through a
    /// capture session.
    fn make_record(name: &str, params: LaunchParams) -> LaunchRecord {
        LaunchRecord {
            kernel_name: name.to_owned(),
            params,
        }
    }

    // A freshly begun capture is active and holds nothing.
    #[test]
    fn capture_begin_is_active() {
        let session = GraphLaunchCapture::begin();
        assert!(session.is_active());
        assert!(session.is_empty());
        assert_eq!(session.len(), 0);
    }

    // Ending a capture that recorded nothing yields an empty list.
    #[test]
    fn capture_end_returns_empty_vec() {
        let recorded = GraphLaunchCapture::begin().end();
        assert!(recorded.is_empty());
    }

    // The Display output mentions the kernel name and both dimension triples.
    #[test]
    fn launch_record_display() {
        let rendered =
            make_record("vector_add", LaunchParams::new(Dim3::x(4), Dim3::x(256))).to_string();
        for fragment in ["vector_add", "4x1x1", "256x1x1"] {
            assert!(rendered.contains(fragment));
        }
    }

    // Accessors hand back exactly what the record was built with.
    #[test]
    fn launch_record_accessors() {
        let rec = make_record("my_kernel", LaunchParams::new(8u32, 128u32));
        assert_eq!(rec.kernel_name(), "my_kernel");
        let stored = rec.params();
        assert_eq!(stored.grid.x, 8);
        assert_eq!(stored.block.x, 128);
    }

    // The derived Debug output names the type and shows the activity flag.
    #[test]
    fn capture_debug() {
        let session = GraphLaunchCapture::begin();
        let rendered = format!("{session:?}");
        assert!(rendered.contains("GraphLaunchCapture"));
        assert!(rendered.contains("active: true"));
    }

    // A clone carries the same name and grid width as its source.
    #[test]
    fn launch_record_clone() {
        let source = make_record("clone_test", LaunchParams::new(2u32, 64u32));
        let copy = source.clone();
        assert_eq!(copy.kernel_name(), source.kernel_name());
        assert_eq!(copy.params().grid.x, source.params().grid.x);
    }

    // ---------------------------------------------------------------------------
    // Quality gate tests (CPU-only)
    // ---------------------------------------------------------------------------

    // One `record_raw` call takes an empty capture to a length of one.
    #[test]
    fn graph_capture_records_launches() {
        let mut session = GraphLaunchCapture::begin();
        assert_eq!(session.len(), 0);
        assert!(session.is_empty());

        session.record_raw("vector_add", LaunchParams::new(Dim3::x(4), Dim3::x(256)));
        assert_eq!(session.len(), 1);
        assert!(!session.is_empty());
    }

    // Every grid/block component survives the trip through a record.
    #[test]
    fn graph_record_contains_params() {
        let rec = make_record(
            "my_kernel",
            LaunchParams::new(Dim3::new(8, 2, 1), Dim3::new(32, 8, 1)),
        );
        let grid = &rec.params().grid;
        let block = &rec.params().block;
        assert_eq!((grid.x, grid.y, grid.z), (8, 2, 1));
        assert_eq!((block.x, block.y, block.z), (32, 8, 1));
    }

    // Three recorded launches come back from `end` in recording order.
    #[test]
    fn graph_replay_count() {
        let mut session = GraphLaunchCapture::begin();
        let shared = LaunchParams::new(Dim3::x(4), Dim3::x(128));
        for name in ["kernel_a", "kernel_b", "kernel_c"] {
            session.record_raw(name, shared);
        }

        assert_eq!(session.len(), 3);

        // `end` consumes the capture and hands back every record.
        let recorded = session.end();
        assert_eq!(recorded.len(), 3);
        let names: Vec<&str> = recorded.iter().map(LaunchRecord::kernel_name).collect();
        assert_eq!(names, ["kernel_a", "kernel_b", "kernel_c"]);
    }
}