1use std::sync::Arc;
17
18use baracuda_cuda_sys::runtime::{
19 cudaGraphExec_t, cudaGraphNode_t, cudaGraph_t, runtime, types::cudaStreamCaptureStatus,
20};
21
22use crate::error::{check, Result};
23use crate::stream::Stream;
24
/// Stream-capture mode passed to [`Stream::begin_capture`].
///
/// Each variant mirrors one value of the runtime's `cudaStreamCaptureMode`; the numeric
/// mapping lives in `CaptureMode::raw`. The default is [`CaptureMode::ThreadLocal`].
#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
pub enum CaptureMode {
    /// `cudaStreamCaptureModeGlobal` (raw value `0`).
    Global,
    /// `cudaStreamCaptureModeThreadLocal` (raw value `1`); the default mode.
    #[default]
    ThreadLocal,
    /// `cudaStreamCaptureModeRelaxed` (raw value `2`).
    Relaxed,
}

impl CaptureMode {
    /// Returns the integer handed to `cudaStreamBeginCapture` for this mode.
    #[inline]
    fn raw(self) -> i32 {
        match self {
            Self::Global => 0,
            Self::ThreadLocal => 1,
            Self::Relaxed => 2,
        }
    }
}
44
45impl Stream {
46 pub fn begin_capture(&self, mode: CaptureMode) -> Result<()> {
48 let r = runtime()?;
49 let cu = r.cuda_stream_begin_capture()?;
50 check(unsafe { cu(self.as_raw(), mode.raw()) })
51 }
52
53 pub fn end_capture(&self) -> Result<Graph> {
55 let r = runtime()?;
56 let cu = r.cuda_stream_end_capture()?;
57 let mut graph: cudaGraph_t = core::ptr::null_mut();
58 check(unsafe { cu(self.as_raw(), &mut graph) })?;
59 Ok(Graph {
60 inner: Arc::new(GraphInner { handle: graph }),
61 })
62 }
63
64 pub fn capture<F>(&self, mode: CaptureMode, f: F) -> Result<Graph>
67 where
68 F: FnOnce(&Stream) -> Result<()>,
69 {
70 self.begin_capture(mode)?;
71 let inner_result = f(self);
72 let end_result = self.end_capture();
73 match (inner_result, end_result) {
74 (Ok(()), Ok(graph)) => Ok(graph),
75 (Err(e), _) => Err(e),
76 (Ok(()), Err(e)) => Err(e),
77 }
78 }
79
80 pub fn is_capturing(&self) -> Result<bool> {
82 let r = runtime()?;
83 let cu = r.cuda_stream_is_capturing()?;
84 let mut status: core::ffi::c_int = 0;
85 check(unsafe { cu(self.as_raw(), &mut status) })?;
86 Ok(status == cudaStreamCaptureStatus::ACTIVE)
87 }
88}
89
90#[derive(Clone)]
92pub struct Graph {
93 inner: Arc<GraphInner>,
94}
95
96struct GraphInner {
97 handle: cudaGraph_t,
98}
99
100unsafe impl Send for GraphInner {}
101unsafe impl Sync for GraphInner {}
102
103impl core::fmt::Debug for GraphInner {
104 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
105 f.debug_struct("Graph")
106 .field("handle", &self.handle)
107 .finish_non_exhaustive()
108 }
109}
110
111impl core::fmt::Debug for Graph {
112 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
113 self.inner.fmt(f)
114 }
115}
116
117impl Graph {
118 pub fn new() -> Result<Self> {
120 let r = runtime()?;
121 let cu = r.cuda_graph_create()?;
122 let mut graph: cudaGraph_t = core::ptr::null_mut();
123 check(unsafe { cu(&mut graph, 0) })?;
124 Ok(Self {
125 inner: Arc::new(GraphInner { handle: graph }),
126 })
127 }
128
129 pub fn instantiate(&self) -> Result<GraphExec> {
131 let r = runtime()?;
132 let cu = r.cuda_graph_instantiate()?;
133 let mut exec: cudaGraphExec_t = core::ptr::null_mut();
134 check(unsafe { cu(&mut exec, self.inner.handle, 0) })?;
135 Ok(GraphExec {
136 inner: Arc::new(GraphExecInner { handle: exec }),
137 })
138 }
139
140 pub fn node_count(&self) -> Result<usize> {
142 let r = runtime()?;
143 let cu = r.cuda_graph_get_nodes()?;
144 let mut count: usize = 0;
145 check(unsafe { cu(self.inner.handle, core::ptr::null_mut(), &mut count) })?;
146 Ok(count)
147 }
148
149 #[inline]
150 pub fn as_raw(&self) -> cudaGraph_t {
151 self.inner.handle
152 }
153
154 pub fn add_empty_node(&self, dependencies: &[GraphNode]) -> Result<GraphNode> {
156 let r = runtime()?;
157 let cu = r.cuda_graph_add_empty_node()?;
158 let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
159 let (dp, dl) = deps_raw(&deps);
160 let mut node: cudaGraphNode_t = core::ptr::null_mut();
161 check(unsafe { cu(&mut node, self.inner.handle, dp, dl) })?;
162 Ok(GraphNode { raw: node })
163 }
164
165 pub unsafe fn add_kernel_node(
172 &self,
173 dependencies: &[GraphNode],
174 kernel: &crate::Kernel,
175 grid: crate::Dim3,
176 block: crate::Dim3,
177 shared_mem_bytes: u32,
178 args: &mut [*mut core::ffi::c_void],
179 ) -> Result<GraphNode> {
180 use baracuda_cuda_sys::runtime::types::{cudaKernelNodeParams, dim3};
181 let r = runtime()?;
182 let cu = r.cuda_graph_add_kernel_node()?;
183 let params = cudaKernelNodeParams {
184 func: kernel.as_launch_ptr() as *mut core::ffi::c_void,
185 grid_dim: dim3::new(grid.x, grid.y, grid.z),
186 block_dim: dim3::new(block.x, block.y, block.z),
187 shared_mem_bytes,
188 kernel_params: if args.is_empty() {
189 core::ptr::null_mut()
190 } else {
191 args.as_mut_ptr()
192 },
193 extra: core::ptr::null_mut(),
194 };
195 let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
196 let (dp, dl) = deps_raw(&deps);
197 let mut node: cudaGraphNode_t = core::ptr::null_mut();
198 check(cu(&mut node, self.inner.handle, dp, dl, ¶ms))?;
199 Ok(GraphNode { raw: node })
200 }
201
202 pub fn add_memset_u32_node(
204 &self,
205 dependencies: &[GraphNode],
206 dst: *mut core::ffi::c_void,
207 value: u32,
208 count: usize,
209 ) -> Result<GraphNode> {
210 use baracuda_cuda_sys::runtime::types::cudaMemsetParams;
211 let r = runtime()?;
212 let cu = r.cuda_graph_add_memset_node()?;
213 let params = cudaMemsetParams {
214 dst,
215 pitch: 0,
216 value,
217 element_size: 4,
218 width: count,
219 height: 1,
220 };
221 let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
222 let (dp, dl) = deps_raw(&deps);
223 let mut node: cudaGraphNode_t = core::ptr::null_mut();
224 check(unsafe { cu(&mut node, self.inner.handle, dp, dl, ¶ms) })?;
225 Ok(GraphNode { raw: node })
226 }
227
228 pub unsafe fn add_host_node(
236 &self,
237 dependencies: &[GraphNode],
238 fn_: unsafe extern "C" fn(*mut core::ffi::c_void),
239 user_data: *mut core::ffi::c_void,
240 ) -> Result<GraphNode> {
241 use baracuda_cuda_sys::runtime::types::cudaHostNodeParams;
242 let r = runtime()?;
243 let cu = r.cuda_graph_add_host_node()?;
244 let params = cudaHostNodeParams {
245 fn_: Some(fn_),
246 user_data,
247 };
248 let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
249 let (dp, dl) = deps_raw(&deps);
250 let mut node: cudaGraphNode_t = core::ptr::null_mut();
251 check(cu(&mut node, self.inner.handle, dp, dl, ¶ms))?;
252 Ok(GraphNode { raw: node })
253 }
254
255 pub fn add_child_graph_node(
257 &self,
258 dependencies: &[GraphNode],
259 child: &Graph,
260 ) -> Result<GraphNode> {
261 let r = runtime()?;
262 let cu = r.cuda_graph_add_child_graph_node()?;
263 let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
264 let (dp, dl) = deps_raw(&deps);
265 let mut node: cudaGraphNode_t = core::ptr::null_mut();
266 check(unsafe { cu(&mut node, self.inner.handle, dp, dl, child.as_raw()) })?;
267 Ok(GraphNode { raw: node })
268 }
269
270 pub fn add_event_record_node(
272 &self,
273 dependencies: &[GraphNode],
274 event: &crate::Event,
275 ) -> Result<GraphNode> {
276 let r = runtime()?;
277 let cu = r.cuda_graph_add_event_record_node()?;
278 let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
279 let (dp, dl) = deps_raw(&deps);
280 let mut node: cudaGraphNode_t = core::ptr::null_mut();
281 check(unsafe { cu(&mut node, self.inner.handle, dp, dl, event.as_raw()) })?;
282 Ok(GraphNode { raw: node })
283 }
284
285 pub fn add_event_wait_node(
287 &self,
288 dependencies: &[GraphNode],
289 event: &crate::Event,
290 ) -> Result<GraphNode> {
291 let r = runtime()?;
292 let cu = r.cuda_graph_add_event_wait_node()?;
293 let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
294 let (dp, dl) = deps_raw(&deps);
295 let mut node: cudaGraphNode_t = core::ptr::null_mut();
296 check(unsafe { cu(&mut node, self.inner.handle, dp, dl, event.as_raw()) })?;
297 Ok(GraphNode { raw: node })
298 }
299
300 pub fn add_mem_alloc_node(
303 &self,
304 dependencies: &[GraphNode],
305 device: &crate::Device,
306 bytesize: usize,
307 ) -> Result<(GraphNode, *mut core::ffi::c_void)> {
308 use baracuda_cuda_sys::runtime::types::{
309 cudaMemAllocNodeParams, cudaMemAllocationHandleType, cudaMemAllocationType,
310 cudaMemLocation, cudaMemLocationType, cudaMemPoolProps,
311 };
312 let r = runtime()?;
313 let cu = r.cuda_graph_add_mem_alloc_node()?;
314 let mut params = cudaMemAllocNodeParams {
315 pool_props: cudaMemPoolProps {
316 alloc_type: cudaMemAllocationType::PINNED,
317 handle_types: cudaMemAllocationHandleType::NONE,
318 location: cudaMemLocation {
319 type_: cudaMemLocationType::DEVICE,
320 id: device.ordinal(),
321 },
322 ..Default::default()
323 },
324 access_descs: core::ptr::null(),
325 access_desc_count: 0,
326 bytesize,
327 dptr: core::ptr::null_mut(),
328 };
329 let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
330 let (dp, dl) = deps_raw(&deps);
331 let mut node: cudaGraphNode_t = core::ptr::null_mut();
332 check(unsafe { cu(&mut node, self.inner.handle, dp, dl, &mut params) })?;
333 Ok((GraphNode { raw: node }, params.dptr))
334 }
335
336 pub unsafe fn add_mem_free_node(
343 &self,
344 dependencies: &[GraphNode],
345 dptr: *mut core::ffi::c_void,
346 ) -> Result<GraphNode> {
347 let r = runtime()?;
348 let cu = r.cuda_graph_add_mem_free_node()?;
349 let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
350 let (dp, dl) = deps_raw(&deps);
351 let mut node: cudaGraphNode_t = core::ptr::null_mut();
352 check(cu(&mut node, self.inner.handle, dp, dl, dptr))?;
353 Ok(GraphNode { raw: node })
354 }
355
356 pub fn conditional_handle_create(&self, default_launch_value: u32, flags: u32) -> Result<u64> {
364 use baracuda_types::{supports, Feature};
365 let installed = crate::init::driver_version()?;
366 if !supports(installed, Feature::GraphConditionalNodes) {
367 return Err(crate::error::Error::FeatureNotSupported {
368 api: "cudaGraphConditionalHandleCreate",
369 since: Feature::GraphConditionalNodes.required_version(),
370 });
371 }
372 let r = runtime()?;
373 let cu = r.cuda_graph_conditional_handle_create()?;
374 let mut handle: u64 = 0;
375 check(unsafe { cu(&mut handle, self.inner.handle, default_launch_value, flags) })?;
376 Ok(handle)
377 }
378
379 pub unsafe fn add_node_raw(
391 &self,
392 dependencies: &[GraphNode],
393 node_params: *mut core::ffi::c_void,
394 ) -> Result<GraphNode> {
395 let r = runtime()?;
396 let cu = r.cuda_graph_add_node()?;
397 let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
398 let (dp, dl) = deps_raw(&deps);
399 let mut node: cudaGraphNode_t = core::ptr::null_mut();
400 check(cu(&mut node, self.inner.handle, dp, dl, node_params))?;
401 Ok(GraphNode { raw: node })
402 }
403
404 pub fn add_dependencies(&self, from: &[GraphNode], to: &[GraphNode]) -> Result<()> {
406 assert_eq!(from.len(), to.len());
407 if from.is_empty() {
408 return Ok(());
409 }
410 let r = runtime()?;
411 let cu = r.cuda_graph_add_dependencies()?;
412 let f: Vec<_> = from.iter().map(|n| n.raw).collect();
413 let t: Vec<_> = to.iter().map(|n| n.raw).collect();
414 check(unsafe { cu(self.inner.handle, f.as_ptr(), t.as_ptr(), f.len()) })
415 }
416}
417
418fn deps_raw(deps: &[cudaGraphNode_t]) -> (*const cudaGraphNode_t, usize) {
419 if deps.is_empty() {
420 (core::ptr::null(), 0)
421 } else {
422 (deps.as_ptr(), deps.len())
423 }
424}
425
426#[derive(Copy, Clone, Debug)]
429pub struct GraphNode {
430 raw: cudaGraphNode_t,
431}
432
433impl GraphNode {
434 #[inline]
435 pub fn as_raw(&self) -> cudaGraphNode_t {
436 self.raw
437 }
438
439 pub fn node_type(&self) -> Result<i32> {
444 let r = runtime()?;
445 let cu = r.cuda_graph_node_get_type()?;
446 let mut t: core::ffi::c_int = 0;
447 check(unsafe { cu(self.raw, &mut t) })?;
448 Ok(t)
449 }
450
451 pub fn mem_free_ptr(&self) -> Result<*mut core::ffi::c_void> {
453 let r = runtime()?;
454 let cu = r.cuda_graph_mem_free_node_get_params()?;
455 let mut p: *mut core::ffi::c_void = core::ptr::null_mut();
456 check(unsafe { cu(self.raw, &mut p) })?;
457 Ok(p)
458 }
459}
460
461impl Drop for GraphInner {
462 fn drop(&mut self) {
463 if let Ok(r) = runtime() {
464 if let Ok(cu) = r.cuda_graph_destroy() {
465 let _ = unsafe { cu(self.handle) };
466 }
467 }
468 }
469}
470
471#[derive(Clone)]
473pub struct GraphExec {
474 inner: Arc<GraphExecInner>,
475}
476
477struct GraphExecInner {
478 handle: cudaGraphExec_t,
479}
480
481unsafe impl Send for GraphExecInner {}
482unsafe impl Sync for GraphExecInner {}
483
484impl core::fmt::Debug for GraphExecInner {
485 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
486 f.debug_struct("GraphExec")
487 .field("handle", &self.handle)
488 .finish_non_exhaustive()
489 }
490}
491
492impl core::fmt::Debug for GraphExec {
493 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
494 self.inner.fmt(f)
495 }
496}
497
498impl GraphExec {
499 pub fn launch(&self, stream: &Stream) -> Result<()> {
501 let r = runtime()?;
502 let cu = r.cuda_graph_launch()?;
503 check(unsafe { cu(self.inner.handle, stream.as_raw()) })
504 }
505
506 pub fn update(&self, new_template: &Graph) -> Result<UpdateResult> {
510 let r = runtime()?;
511 let cu = r.cuda_graph_exec_update()?;
512 let mut error_node: cudaGraphNode_t = core::ptr::null_mut();
513 let mut result: core::ffi::c_int = 0;
514 let rc = unsafe {
519 cu(
520 self.inner.handle,
521 new_template.as_raw(),
522 &mut error_node,
523 &mut result,
524 )
525 };
526 if rc != baracuda_cuda_sys::runtime::cudaError_t::Success
527 && result == baracuda_cuda_sys::runtime::types::cudaGraphExecUpdateResult::SUCCESS
528 {
529 return Err(crate::error::Error::Status { status: rc });
530 }
531 Ok(UpdateResult {
532 result,
533 error_node: if error_node.is_null() {
534 None
535 } else {
536 Some(GraphNode { raw: error_node })
537 },
538 })
539 }
540
541 #[inline]
542 pub fn as_raw(&self) -> cudaGraphExec_t {
543 self.inner.handle
544 }
545}
546
547#[derive(Clone, Debug)]
551pub struct UpdateResult {
552 pub result: core::ffi::c_int,
553 pub error_node: Option<GraphNode>,
554}
555
556impl UpdateResult {
557 pub fn is_success(&self) -> bool {
558 self.result == baracuda_cuda_sys::runtime::types::cudaGraphExecUpdateResult::SUCCESS
559 }
560}
561
562impl Drop for GraphExecInner {
563 fn drop(&mut self) {
564 if let Ok(r) = runtime() {
565 if let Ok(cu) = r.cuda_graph_exec_destroy() {
566 let _ = unsafe { cu(self.handle) };
567 }
568 }
569 }
570}