oxicuda_launch/graph_launch.rs
1//! Graph-based kernel launch capture and replay.
2//!
3//! CUDA Graphs allow recording a sequence of operations (kernel launches,
4//! memory copies, etc.) and replaying them as a single unit with reduced
5//! launch overhead. This module provides a lightweight capture facility
6//! that records kernel launch configurations for later replay or
7//! analysis.
8//!
9//! # Example
10//!
11//! ```rust,no_run
12//! # use oxicuda_launch::graph_launch::{GraphLaunchCapture, LaunchRecord};
13//! # use oxicuda_launch::{LaunchParams, Dim3};
14//! let mut capture = GraphLaunchCapture::begin();
15//! // In a real scenario you would record actual kernel launches:
//! // capture.record_launch(&kernel, &params);
17//! let records = capture.end();
18//! println!("captured {} launches", records.len());
19//! ```
20
21use crate::kernel::Kernel;
22use crate::params::LaunchParams;
23
24// ---------------------------------------------------------------------------
25// LaunchRecord
26// ---------------------------------------------------------------------------
27
28/// A recorded kernel launch operation.
29///
30/// Captures the kernel name and launch parameters at the time of
31/// recording. This is a lightweight snapshot that does not retain
32/// references to the kernel or its module.
33#[derive(Debug, Clone)]
34pub struct LaunchRecord {
35 /// The name of the kernel function that was recorded.
36 kernel_name: String,
37 /// The launch configuration (grid, block, shared memory).
38 params: LaunchParams,
39}
40
41impl LaunchRecord {
42 /// Returns the kernel function name.
43 #[inline]
44 pub fn kernel_name(&self) -> &str {
45 &self.kernel_name
46 }
47
48 /// Returns the recorded launch parameters.
49 #[inline]
50 pub fn params(&self) -> &LaunchParams {
51 &self.params
52 }
53}
54
55impl std::fmt::Display for LaunchRecord {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 write!(
58 f,
59 "LaunchRecord(kernel={}, grid={}x{}x{}, block={}x{}x{})",
60 self.kernel_name,
61 self.params.grid.x,
62 self.params.grid.y,
63 self.params.grid.z,
64 self.params.block.x,
65 self.params.block.y,
66 self.params.block.z,
67 )
68 }
69}
70
71// ---------------------------------------------------------------------------
72// GraphLaunchCapture
73// ---------------------------------------------------------------------------
74
75/// Captures a sequence of kernel launches for graph-based replay.
76///
77/// Create with [`begin`](Self::begin), record launches with
78/// [`record_launch`](Self::record_launch), and finalise with
79/// [`end`](Self::end) to obtain the list of recorded operations.
80///
81/// This is a host-side recording facility. On systems with GPU support,
82/// the recorded operations can be converted into a CUDA graph for
83/// optimised replay via the `oxicuda-driver` graph API.
84#[derive(Debug)]
85pub struct GraphLaunchCapture {
86 /// The sequence of recorded kernel launches.
87 stream_nodes: Vec<LaunchRecord>,
88 /// Whether the capture is currently active.
89 active: bool,
90}
91
92impl GraphLaunchCapture {
93 /// Begins a new graph launch capture session.
94 ///
95 /// Returns a capture object in the active state. Record launches
96 /// using [`record_launch`](Self::record_launch) and finalise
97 /// with [`end`](Self::end).
98 pub fn begin() -> Self {
99 Self {
100 stream_nodes: Vec::new(),
101 active: true,
102 }
103 }
104
105 /// Records a kernel launch into the capture sequence.
106 ///
107 /// The kernel name and launch parameters are snapshot at the time
108 /// of recording. If the capture is not active (i.e., [`end`](Self::end)
109 /// has already been called), this method is a no-op.
110 pub fn record_launch(&mut self, kernel: &Kernel, params: &LaunchParams) {
111 if !self.active {
112 return;
113 }
114 self.stream_nodes.push(LaunchRecord {
115 kernel_name: kernel.name().to_owned(),
116 params: *params,
117 });
118 }
119
120 /// Records a kernel launch directly from a kernel name and params.
121 ///
122 /// This is a lower-level variant of [`record_launch`](Self::record_launch)
123 /// that does not require a [`Kernel`] handle — useful for testing and
124 /// for replaying recorded descriptions.
125 pub fn record_raw(&mut self, kernel_name: impl Into<String>, params: LaunchParams) {
126 if !self.active {
127 return;
128 }
129 self.stream_nodes.push(LaunchRecord {
130 kernel_name: kernel_name.into(),
131 params,
132 });
133 }
134
135 /// Ends the capture session and returns all recorded launches.
136 ///
137 /// After calling this method, the capture is no longer active and
138 /// further calls to [`record_launch`](Self::record_launch) are
139 /// ignored.
140 pub fn end(mut self) -> Vec<LaunchRecord> {
141 self.active = false;
142 std::mem::take(&mut self.stream_nodes)
143 }
144
145 /// Returns the number of launches recorded so far.
146 #[inline]
147 pub fn len(&self) -> usize {
148 self.stream_nodes.len()
149 }
150
151 /// Returns `true` if no launches have been recorded.
152 #[inline]
153 pub fn is_empty(&self) -> bool {
154 self.stream_nodes.is_empty()
155 }
156
157 /// Returns `true` if the capture is currently active.
158 #[inline]
159 pub fn is_active(&self) -> bool {
160 self.active
161 }
162}
163
164// ---------------------------------------------------------------------------
165// Tests
166// ---------------------------------------------------------------------------
167
#[cfg(test)]
mod tests {
    use super::*;
    use crate::grid::Dim3;

    #[test]
    fn capture_begin_is_active() {
        let capture = GraphLaunchCapture::begin();
        assert!(capture.is_active());
        assert!(capture.is_empty());
        assert_eq!(capture.len(), 0);
    }

    #[test]
    fn capture_end_returns_empty_vec() {
        let capture = GraphLaunchCapture::begin();
        let records = capture.end();
        assert!(records.is_empty());
    }

    #[test]
    fn launch_record_display() {
        let record = LaunchRecord {
            kernel_name: "vector_add".to_owned(),
            params: LaunchParams::new(Dim3::x(4), Dim3::x(256)),
        };
        let s = format!("{record}");
        assert!(s.contains("vector_add"));
        assert!(s.contains("4x1x1"));
        assert!(s.contains("256x1x1"));
    }

    #[test]
    fn launch_record_display_multi_dim() {
        // Exact rendering, with distinct y/z components so a swapped field
        // would be caught.
        let record = LaunchRecord {
            kernel_name: "stencil".to_owned(),
            params: LaunchParams::new(Dim3::new(8, 2, 3), Dim3::new(32, 4, 2)),
        };
        assert_eq!(
            record.to_string(),
            "LaunchRecord(kernel=stencil, grid=8x2x3, block=32x4x2)"
        );
    }

    #[test]
    fn launch_record_accessors() {
        let record = LaunchRecord {
            kernel_name: "my_kernel".to_owned(),
            params: LaunchParams::new(8u32, 128u32),
        };
        assert_eq!(record.kernel_name(), "my_kernel");
        assert_eq!(record.params().grid.x, 8);
        assert_eq!(record.params().block.x, 128);
    }

    #[test]
    fn capture_debug() {
        let capture = GraphLaunchCapture::begin();
        let dbg = format!("{capture:?}");
        assert!(dbg.contains("GraphLaunchCapture"));
        assert!(dbg.contains("active: true"));
    }

    #[test]
    fn launch_record_clone() {
        let record = LaunchRecord {
            kernel_name: "clone_test".to_owned(),
            params: LaunchParams::new(2u32, 64u32),
        };
        let cloned = record.clone();
        assert_eq!(cloned.kernel_name(), record.kernel_name());
        assert_eq!(cloned.params().grid.x, record.params().grid.x);
    }

    // ---------------------------------------------------------------------------
    // Quality gate tests (CPU-only)
    // ---------------------------------------------------------------------------

    #[test]
    fn graph_capture_records_launches() {
        // GraphLaunchCapture::begin() creates an empty capture; after pushing one
        // record via record_raw, len() == 1.
        let mut capture = GraphLaunchCapture::begin();
        assert_eq!(capture.len(), 0);
        assert!(capture.is_empty());

        capture.record_raw("vector_add", LaunchParams::new(Dim3::x(4), Dim3::x(256)));
        assert_eq!(capture.len(), 1);
        assert!(!capture.is_empty());
    }

    #[test]
    fn graph_record_contains_params() {
        // A LaunchRecord stores grid/block dims accurately.
        let params = LaunchParams::new(Dim3::new(8, 2, 1), Dim3::new(32, 8, 1));
        let record = LaunchRecord {
            kernel_name: "my_kernel".to_owned(),
            params,
        };
        assert_eq!(record.params().grid.x, 8);
        assert_eq!(record.params().grid.y, 2);
        assert_eq!(record.params().grid.z, 1);
        assert_eq!(record.params().block.x, 32);
        assert_eq!(record.params().block.y, 8);
        assert_eq!(record.params().block.z, 1);
    }

    #[test]
    fn record_raw_preserves_name_and_params_through_end() {
        // Params pushed via record_raw (with an owned String name) must come
        // back unchanged from end(), and recording must not deactivate the
        // capture.
        let mut capture = GraphLaunchCapture::begin();
        capture.record_raw(
            String::from("owned_name"),
            LaunchParams::new(Dim3::new(8, 2, 3), Dim3::new(32, 4, 2)),
        );
        assert!(capture.is_active());

        let records = capture.end();
        assert_eq!(records.len(), 1);
        let rec = &records[0];
        assert_eq!(rec.kernel_name(), "owned_name");
        assert_eq!(rec.params().grid.x, 8);
        assert_eq!(rec.params().grid.y, 2);
        assert_eq!(rec.params().grid.z, 3);
        assert_eq!(rec.params().block.x, 32);
        assert_eq!(rec.params().block.y, 4);
        assert_eq!(rec.params().block.z, 2);
    }

    #[test]
    fn graph_replay_count() {
        // After recording 3 launches via record_raw, the records vector has length 3.
        let mut capture = GraphLaunchCapture::begin();
        let params = LaunchParams::new(Dim3::x(4), Dim3::x(128));
        capture.record_raw("kernel_a", params);
        capture.record_raw("kernel_b", params);
        capture.record_raw("kernel_c", params);

        assert_eq!(capture.len(), 3);

        // end() consumes the capture and returns all records in recording order.
        let records = capture.end();
        assert_eq!(records.len(), 3);
        assert_eq!(records[0].kernel_name(), "kernel_a");
        assert_eq!(records[1].kernel_name(), "kernel_b");
        assert_eq!(records[2].kernel_name(), "kernel_c");
    }
}