1use std::sync::Arc;
17
18use baracuda_cuda_sys::runtime::{
19 cudaGraphExec_t, cudaGraphNode_t, cudaGraph_t, runtime, types::cudaStreamCaptureStatus,
20};
21
22use crate::error::{check, Result};
23use crate::stream::Stream;
24
/// Stream-capture mode used by [`Stream::begin_capture`] and [`Stream::capture`].
///
/// The explicit discriminants are the integers handed to the runtime by
/// [`CaptureMode::raw`]; they mirror CUDA's `cudaStreamCaptureMode`
/// (`Global = 0`, `ThreadLocal = 1`, `Relaxed = 2`). `ThreadLocal` is the default.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
pub enum CaptureMode {
    /// Raw value `0`.
    Global = 0,
    /// Raw value `1`; the default mode.
    #[default]
    ThreadLocal = 1,
    /// Raw value `2`.
    Relaxed = 2,
}

impl CaptureMode {
    /// Returns the integer code forwarded to `cudaStreamBeginCapture`.
    #[inline]
    fn raw(self) -> i32 {
        // Fieldless enum: the cast yields the explicit discriminant declared above.
        self as i32
    }
}
44
45impl Stream {
46 pub fn begin_capture(&self, mode: CaptureMode) -> Result<()> {
48 let r = runtime()?;
49 let cu = r.cuda_stream_begin_capture()?;
50 check(unsafe { cu(self.as_raw(), mode.raw()) })
51 }
52
53 pub fn end_capture(&self) -> Result<Graph> {
55 let r = runtime()?;
56 let cu = r.cuda_stream_end_capture()?;
57 let mut graph: cudaGraph_t = core::ptr::null_mut();
58 check(unsafe { cu(self.as_raw(), &mut graph) })?;
59 Ok(Graph {
60 inner: Arc::new(GraphInner { handle: graph }),
61 })
62 }
63
64 pub fn capture<F>(&self, mode: CaptureMode, f: F) -> Result<Graph>
76 where
77 F: FnOnce(&Stream) -> Result<()>,
78 {
79 struct CaptureGuard<'a> {
85 stream: &'a Stream,
86 armed: bool,
87 }
88 impl Drop for CaptureGuard<'_> {
89 fn drop(&mut self) {
90 if self.armed {
91 let _ = self.stream.end_capture();
92 }
93 }
94 }
95
96 self.begin_capture(mode)?;
97 let mut guard = CaptureGuard { stream: self, armed: true };
98 let inner_result = f(self);
99 guard.armed = false;
102 let end_result = self.end_capture();
103 match (inner_result, end_result) {
104 (Ok(()), Ok(graph)) => Ok(graph),
105 (Err(e), _) => Err(e),
106 (Ok(()), Err(e)) => Err(e),
107 }
108 }
109
110 pub fn is_capturing(&self) -> Result<bool> {
112 let r = runtime()?;
113 let cu = r.cuda_stream_is_capturing()?;
114 let mut status: core::ffi::c_int = 0;
115 check(unsafe { cu(self.as_raw(), &mut status) })?;
116 Ok(status == cudaStreamCaptureStatus::ACTIVE)
117 }
118}
119
120#[derive(Clone)]
122pub struct Graph {
123 inner: Arc<GraphInner>,
124}
125
126struct GraphInner {
127 handle: cudaGraph_t,
128}
129
130unsafe impl Send for GraphInner {}
131unsafe impl Sync for GraphInner {}
132
133impl core::fmt::Debug for GraphInner {
134 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
135 f.debug_struct("Graph")
136 .field("handle", &self.handle)
137 .finish_non_exhaustive()
138 }
139}
140
141impl core::fmt::Debug for Graph {
142 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
143 self.inner.fmt(f)
144 }
145}
146
147impl Graph {
148 pub fn new() -> Result<Self> {
150 let r = runtime()?;
151 let cu = r.cuda_graph_create()?;
152 let mut graph: cudaGraph_t = core::ptr::null_mut();
153 check(unsafe { cu(&mut graph, 0) })?;
154 Ok(Self {
155 inner: Arc::new(GraphInner { handle: graph }),
156 })
157 }
158
159 pub fn instantiate(&self) -> Result<GraphExec> {
161 let r = runtime()?;
162 let cu = r.cuda_graph_instantiate()?;
163 let mut exec: cudaGraphExec_t = core::ptr::null_mut();
164 check(unsafe { cu(&mut exec, self.inner.handle, 0) })?;
165 Ok(GraphExec {
166 inner: Arc::new(GraphExecInner { handle: exec }),
167 })
168 }
169
170 pub fn node_count(&self) -> Result<usize> {
172 let r = runtime()?;
173 let cu = r.cuda_graph_get_nodes()?;
174 let mut count: usize = 0;
175 check(unsafe { cu(self.inner.handle, core::ptr::null_mut(), &mut count) })?;
176 Ok(count)
177 }
178
179 #[inline]
180 pub fn as_raw(&self) -> cudaGraph_t {
181 self.inner.handle
182 }
183
184 pub fn add_empty_node(&self, dependencies: &[GraphNode]) -> Result<GraphNode> {
186 let r = runtime()?;
187 let cu = r.cuda_graph_add_empty_node()?;
188 let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
189 let (dp, dl) = deps_raw(&deps);
190 let mut node: cudaGraphNode_t = core::ptr::null_mut();
191 check(unsafe { cu(&mut node, self.inner.handle, dp, dl) })?;
192 Ok(GraphNode { raw: node })
193 }
194
195 pub unsafe fn add_kernel_node(
202 &self,
203 dependencies: &[GraphNode],
204 kernel: &crate::Kernel,
205 grid: crate::Dim3,
206 block: crate::Dim3,
207 shared_mem_bytes: u32,
208 args: &mut [*mut core::ffi::c_void],
209 ) -> Result<GraphNode> { unsafe {
210 use baracuda_cuda_sys::runtime::types::{cudaKernelNodeParams, dim3};
211 let r = runtime()?;
212 let cu = r.cuda_graph_add_kernel_node()?;
213 let params = cudaKernelNodeParams {
214 func: kernel.as_launch_ptr() as *mut core::ffi::c_void,
215 grid_dim: dim3::new(grid.x, grid.y, grid.z),
216 block_dim: dim3::new(block.x, block.y, block.z),
217 shared_mem_bytes,
218 kernel_params: if args.is_empty() {
219 core::ptr::null_mut()
220 } else {
221 args.as_mut_ptr()
222 },
223 extra: core::ptr::null_mut(),
224 };
225 let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
226 let (dp, dl) = deps_raw(&deps);
227 let mut node: cudaGraphNode_t = core::ptr::null_mut();
228 check(cu(&mut node, self.inner.handle, dp, dl, ¶ms))?;
229 Ok(GraphNode { raw: node })
230 }}
231
232 pub fn add_memset_u32_node(
234 &self,
235 dependencies: &[GraphNode],
236 dst: *mut core::ffi::c_void,
237 value: u32,
238 count: usize,
239 ) -> Result<GraphNode> {
240 use baracuda_cuda_sys::runtime::types::cudaMemsetParams;
241 let r = runtime()?;
242 let cu = r.cuda_graph_add_memset_node()?;
243 let params = cudaMemsetParams {
244 dst,
245 pitch: 0,
246 value,
247 element_size: 4,
248 width: count,
249 height: 1,
250 };
251 let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
252 let (dp, dl) = deps_raw(&deps);
253 let mut node: cudaGraphNode_t = core::ptr::null_mut();
254 check(unsafe { cu(&mut node, self.inner.handle, dp, dl, ¶ms) })?;
255 Ok(GraphNode { raw: node })
256 }
257
258 pub unsafe fn add_host_node(
266 &self,
267 dependencies: &[GraphNode],
268 fn_: unsafe extern "C" fn(*mut core::ffi::c_void),
269 user_data: *mut core::ffi::c_void,
270 ) -> Result<GraphNode> { unsafe {
271 use baracuda_cuda_sys::runtime::types::cudaHostNodeParams;
272 let r = runtime()?;
273 let cu = r.cuda_graph_add_host_node()?;
274 let params = cudaHostNodeParams {
275 fn_: Some(fn_),
276 user_data,
277 };
278 let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
279 let (dp, dl) = deps_raw(&deps);
280 let mut node: cudaGraphNode_t = core::ptr::null_mut();
281 check(cu(&mut node, self.inner.handle, dp, dl, ¶ms))?;
282 Ok(GraphNode { raw: node })
283 }}
284
285 pub fn add_child_graph_node(
287 &self,
288 dependencies: &[GraphNode],
289 child: &Graph,
290 ) -> Result<GraphNode> {
291 let r = runtime()?;
292 let cu = r.cuda_graph_add_child_graph_node()?;
293 let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
294 let (dp, dl) = deps_raw(&deps);
295 let mut node: cudaGraphNode_t = core::ptr::null_mut();
296 check(unsafe { cu(&mut node, self.inner.handle, dp, dl, child.as_raw()) })?;
297 Ok(GraphNode { raw: node })
298 }
299
300 pub fn add_event_record_node(
302 &self,
303 dependencies: &[GraphNode],
304 event: &crate::Event,
305 ) -> Result<GraphNode> {
306 let r = runtime()?;
307 let cu = r.cuda_graph_add_event_record_node()?;
308 let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
309 let (dp, dl) = deps_raw(&deps);
310 let mut node: cudaGraphNode_t = core::ptr::null_mut();
311 check(unsafe { cu(&mut node, self.inner.handle, dp, dl, event.as_raw()) })?;
312 Ok(GraphNode { raw: node })
313 }
314
315 pub fn add_event_wait_node(
317 &self,
318 dependencies: &[GraphNode],
319 event: &crate::Event,
320 ) -> Result<GraphNode> {
321 let r = runtime()?;
322 let cu = r.cuda_graph_add_event_wait_node()?;
323 let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
324 let (dp, dl) = deps_raw(&deps);
325 let mut node: cudaGraphNode_t = core::ptr::null_mut();
326 check(unsafe { cu(&mut node, self.inner.handle, dp, dl, event.as_raw()) })?;
327 Ok(GraphNode { raw: node })
328 }
329
330 pub fn add_mem_alloc_node(
333 &self,
334 dependencies: &[GraphNode],
335 device: &crate::Device,
336 bytesize: usize,
337 ) -> Result<(GraphNode, *mut core::ffi::c_void)> {
338 use baracuda_cuda_sys::runtime::types::{
339 cudaMemAllocNodeParams, cudaMemAllocationHandleType, cudaMemAllocationType,
340 cudaMemLocation, cudaMemLocationType, cudaMemPoolProps,
341 };
342 let r = runtime()?;
343 let cu = r.cuda_graph_add_mem_alloc_node()?;
344 let mut params = cudaMemAllocNodeParams {
345 pool_props: cudaMemPoolProps {
346 alloc_type: cudaMemAllocationType::PINNED,
347 handle_types: cudaMemAllocationHandleType::NONE,
348 location: cudaMemLocation {
349 type_: cudaMemLocationType::DEVICE,
350 id: device.ordinal(),
351 },
352 ..Default::default()
353 },
354 access_descs: core::ptr::null(),
355 access_desc_count: 0,
356 bytesize,
357 dptr: core::ptr::null_mut(),
358 };
359 let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
360 let (dp, dl) = deps_raw(&deps);
361 let mut node: cudaGraphNode_t = core::ptr::null_mut();
362 check(unsafe { cu(&mut node, self.inner.handle, dp, dl, &mut params) })?;
363 Ok((GraphNode { raw: node }, params.dptr))
364 }
365
366 pub unsafe fn add_mem_free_node(
373 &self,
374 dependencies: &[GraphNode],
375 dptr: *mut core::ffi::c_void,
376 ) -> Result<GraphNode> { unsafe {
377 let r = runtime()?;
378 let cu = r.cuda_graph_add_mem_free_node()?;
379 let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
380 let (dp, dl) = deps_raw(&deps);
381 let mut node: cudaGraphNode_t = core::ptr::null_mut();
382 check(cu(&mut node, self.inner.handle, dp, dl, dptr))?;
383 Ok(GraphNode { raw: node })
384 }}
385
386 pub fn conditional_handle_create(&self, default_launch_value: u32, flags: u32) -> Result<u64> {
394 use baracuda_types::{supports, Feature};
395 let installed = crate::init::driver_version()?;
396 if !supports(installed, Feature::GraphConditionalNodes) {
397 return Err(crate::error::Error::FeatureNotSupported {
398 api: "cudaGraphConditionalHandleCreate",
399 since: Feature::GraphConditionalNodes.required_version(),
400 });
401 }
402 let r = runtime()?;
403 let cu = r.cuda_graph_conditional_handle_create()?;
404 let mut handle: u64 = 0;
405 check(unsafe { cu(&mut handle, self.inner.handle, default_launch_value, flags) })?;
406 Ok(handle)
407 }
408
409 pub unsafe fn add_node_raw(
421 &self,
422 dependencies: &[GraphNode],
423 node_params: *mut core::ffi::c_void,
424 ) -> Result<GraphNode> { unsafe {
425 let r = runtime()?;
426 let cu = r.cuda_graph_add_node()?;
427 let deps: Vec<_> = dependencies.iter().map(|n| n.raw).collect();
428 let (dp, dl) = deps_raw(&deps);
429 let mut node: cudaGraphNode_t = core::ptr::null_mut();
430 check(cu(&mut node, self.inner.handle, dp, dl, node_params))?;
431 Ok(GraphNode { raw: node })
432 }}
433
434 pub fn add_dependencies(&self, from: &[GraphNode], to: &[GraphNode]) -> Result<()> {
436 assert_eq!(from.len(), to.len());
437 if from.is_empty() {
438 return Ok(());
439 }
440 let r = runtime()?;
441 let cu = r.cuda_graph_add_dependencies()?;
442 let f: Vec<_> = from.iter().map(|n| n.raw).collect();
443 let t: Vec<_> = to.iter().map(|n| n.raw).collect();
444 check(unsafe { cu(self.inner.handle, f.as_ptr(), t.as_ptr(), f.len()) })
445 }
446}
447
448fn deps_raw(deps: &[cudaGraphNode_t]) -> (*const cudaGraphNode_t, usize) {
449 if deps.is_empty() {
450 (core::ptr::null(), 0)
451 } else {
452 (deps.as_ptr(), deps.len())
453 }
454}
455
456#[derive(Copy, Clone, Debug)]
459pub struct GraphNode {
460 raw: cudaGraphNode_t,
461}
462
463impl GraphNode {
464 #[inline]
465 pub fn as_raw(&self) -> cudaGraphNode_t {
466 self.raw
467 }
468
469 pub fn node_type(&self) -> Result<i32> {
474 let r = runtime()?;
475 let cu = r.cuda_graph_node_get_type()?;
476 let mut t: core::ffi::c_int = 0;
477 check(unsafe { cu(self.raw, &mut t) })?;
478 Ok(t)
479 }
480
481 pub fn mem_free_ptr(&self) -> Result<*mut core::ffi::c_void> {
483 let r = runtime()?;
484 let cu = r.cuda_graph_mem_free_node_get_params()?;
485 let mut p: *mut core::ffi::c_void = core::ptr::null_mut();
486 check(unsafe { cu(self.raw, &mut p) })?;
487 Ok(p)
488 }
489}
490
491impl Drop for GraphInner {
492 fn drop(&mut self) {
493 if let Ok(r) = runtime() {
494 if let Ok(cu) = r.cuda_graph_destroy() {
495 let _ = unsafe { cu(self.handle) };
496 }
497 }
498 }
499}
500
501#[derive(Clone)]
503pub struct GraphExec {
504 inner: Arc<GraphExecInner>,
505}
506
507struct GraphExecInner {
508 handle: cudaGraphExec_t,
509}
510
511unsafe impl Send for GraphExecInner {}
512unsafe impl Sync for GraphExecInner {}
513
514impl core::fmt::Debug for GraphExecInner {
515 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
516 f.debug_struct("GraphExec")
517 .field("handle", &self.handle)
518 .finish_non_exhaustive()
519 }
520}
521
522impl core::fmt::Debug for GraphExec {
523 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
524 self.inner.fmt(f)
525 }
526}
527
528impl GraphExec {
529 pub fn launch(&self, stream: &Stream) -> Result<()> {
531 let r = runtime()?;
532 let cu = r.cuda_graph_launch()?;
533 check(unsafe { cu(self.inner.handle, stream.as_raw()) })
534 }
535
536 pub fn update(&self, new_template: &Graph) -> Result<UpdateResult> {
540 let r = runtime()?;
541 let cu = r.cuda_graph_exec_update()?;
542 let mut error_node: cudaGraphNode_t = core::ptr::null_mut();
543 let mut result: core::ffi::c_int = 0;
544 let rc = unsafe {
549 cu(
550 self.inner.handle,
551 new_template.as_raw(),
552 &mut error_node,
553 &mut result,
554 )
555 };
556 if rc != baracuda_cuda_sys::runtime::cudaError_t::Success
557 && result == baracuda_cuda_sys::runtime::types::cudaGraphExecUpdateResult::SUCCESS
558 {
559 return Err(crate::error::Error::Status { status: rc });
560 }
561 Ok(UpdateResult {
562 result,
563 error_node: if error_node.is_null() {
564 None
565 } else {
566 Some(GraphNode { raw: error_node })
567 },
568 })
569 }
570
571 #[inline]
572 pub fn as_raw(&self) -> cudaGraphExec_t {
573 self.inner.handle
574 }
575}
576
577#[derive(Clone, Debug)]
581pub struct UpdateResult {
582 pub result: core::ffi::c_int,
583 pub error_node: Option<GraphNode>,
584}
585
586impl UpdateResult {
587 pub fn is_success(&self) -> bool {
588 self.result == baracuda_cuda_sys::runtime::types::cudaGraphExecUpdateResult::SUCCESS
589 }
590}
591
592impl Drop for GraphExecInner {
593 fn drop(&mut self) {
594 if let Ok(r) = runtime() {
595 if let Ok(cu) = r.cuda_graph_exec_destroy() {
596 let _ = unsafe { cu(self.handle) };
597 }
598 }
599 }
600}