1use std::sync::Arc;
17
18use baracuda_cuda_sys::types::{
19 CUgraphConditionalHandle, CUgraphExecUpdateResultInfo, CUgraphNodeParams, CUgraphNodeType,
20 CUmemAllocationHandleType, CUmemAllocationType, CUmemLocation, CUmemLocationType,
21 CUmemPoolProps, CUDA_CONDITIONAL_NODE_PARAMS, CUDA_HOST_NODE_PARAMS, CUDA_KERNEL_NODE_PARAMS,
22 CUDA_MEMCPY3D, CUDA_MEMSET_NODE_PARAMS, CUDA_MEM_ALLOC_NODE_PARAMS,
23};
24use baracuda_cuda_sys::{driver, CUdeviceptr, CUgraph, CUgraphExec, CUgraphNode};
25
26use crate::context::Context;
27use crate::error::{check, Result};
28use crate::event::Event;
29use crate::launch::Dim3;
30use crate::module::Function;
31use crate::stream::Stream;
32
/// Selects the `CUstreamCaptureMode` handed to the driver when a stream
/// capture begins (see [`Stream::begin_capture`]).
///
/// Each variant maps to a fixed raw driver value via [`CaptureMode::raw`].
#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
pub enum CaptureMode {
    /// Raw driver value `0`.
    Global,
    /// Raw driver value `1`. This is the `Default` mode.
    #[default]
    ThreadLocal,
    /// Raw driver value `2`.
    Relaxed,
}

impl CaptureMode {
    /// Returns the integer passed to `cuStreamBeginCapture` for this mode.
    #[inline]
    fn raw(self) -> u32 {
        match self {
            Self::Global => 0,
            Self::ThreadLocal => 1,
            Self::Relaxed => 2,
        }
    }
}
59
60impl Stream {
61 pub fn begin_capture(&self, mode: CaptureMode) -> Result<()> {
67 let d = driver()?;
68 let cu = d.cu_stream_begin_capture()?;
69 check(unsafe { cu(self.as_raw(), mode.raw()) })
70 }
71
72 pub fn end_capture(&self) -> Result<Graph> {
74 let d = driver()?;
75 let cu = d.cu_stream_end_capture()?;
76 let mut graph: CUgraph = core::ptr::null_mut();
77 check(unsafe { cu(self.as_raw(), &mut graph) })?;
78 Ok(Graph {
79 inner: Arc::new(GraphInner {
80 handle: graph,
81 context: self.context().clone(),
82 owned: true,
83 }),
84 })
85 }
86
87 pub fn capture<F>(&self, mode: CaptureMode, f: F) -> Result<Graph>
93 where
94 F: FnOnce(&Stream) -> Result<()>,
95 {
96 self.begin_capture(mode)?;
97 let inner_result = f(self);
98 let end_result = self.end_capture();
100 match (inner_result, end_result) {
101 (Ok(()), Ok(graph)) => Ok(graph),
102 (Err(e), _) => Err(e),
103 (Ok(()), Err(e)) => Err(e),
104 }
105 }
106
107 pub fn is_capturing(&self) -> Result<bool> {
109 let d = driver()?;
110 let cu = d.cu_stream_is_capturing()?;
111 let mut status: core::ffi::c_uint = 0;
112 check(unsafe { cu(self.as_raw(), &mut status) })?;
113 Ok(status == 1)
115 }
116}
117
118#[derive(Clone)]
120pub struct Graph {
121 inner: Arc<GraphInner>,
122}
123
124struct GraphInner {
125 handle: CUgraph,
126 context: Context,
127 owned: bool,
130}
131
132unsafe impl Send for GraphInner {}
133unsafe impl Sync for GraphInner {}
134
135impl core::fmt::Debug for GraphInner {
136 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
137 f.debug_struct("Graph")
138 .field("handle", &self.handle)
139 .finish_non_exhaustive()
140 }
141}
142
143impl core::fmt::Debug for Graph {
144 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
145 self.inner.fmt(f)
146 }
147}
148
149impl Graph {
150 pub fn new(context: &Context) -> Result<Self> {
155 context.set_current()?;
156 let d = driver()?;
157 let cu = d.cu_graph_create()?;
158 let mut graph: CUgraph = core::ptr::null_mut();
159 check(unsafe { cu(&mut graph, 0) })?;
160 Ok(Self {
161 inner: Arc::new(GraphInner {
162 handle: graph,
163 context: context.clone(),
164 owned: true,
165 }),
166 })
167 }
168
169 pub fn instantiate(&self) -> Result<GraphExec> {
171 self.instantiate_with_flags(0)
172 }
173
174 pub fn instantiate_with_flags(&self, flags: u64) -> Result<GraphExec> {
177 let d = driver()?;
178 let cu = d.cu_graph_instantiate_with_flags()?;
179 let mut exec: CUgraphExec = core::ptr::null_mut();
180 check(unsafe { cu(&mut exec, self.inner.handle, flags) })?;
181 Ok(GraphExec {
182 inner: Arc::new(GraphExecInner {
183 handle: exec,
184 context: self.inner.context.clone(),
185 }),
186 })
187 }
188
189 pub fn node_count(&self) -> Result<usize> {
191 let d = driver()?;
192 let cu = d.cu_graph_get_nodes()?;
193 let mut count: usize = 0;
194 check(unsafe { cu(self.inner.handle, core::ptr::null_mut(), &mut count) })?;
195 Ok(count)
196 }
197
198 #[inline]
200 pub fn as_raw(&self) -> CUgraph {
201 self.inner.handle
202 }
203
204 pub fn add_empty_node(&self, dependencies: &[GraphNode]) -> Result<GraphNode> {
207 let d = driver()?;
208 let cu = d.cu_graph_add_empty_node()?;
209 let mut node: CUgraphNode = core::ptr::null_mut();
210 let deps: Vec<CUgraphNode> = dependencies.iter().map(|n| n.raw).collect();
211 let (deps_ptr, deps_len) = deps_raw(&deps);
212 check(unsafe { cu(&mut node, self.inner.handle, deps_ptr, deps_len) })?;
213 Ok(GraphNode { raw: node })
214 }
215
216 pub unsafe fn add_kernel_node(
227 &self,
228 dependencies: &[GraphNode],
229 function: &Function,
230 grid: impl Into<Dim3>,
231 block: impl Into<Dim3>,
232 shared_mem_bytes: u32,
233 args: &mut [*mut core::ffi::c_void],
234 ) -> Result<GraphNode> { unsafe {
235 let d = driver()?;
236 let cu = d.cu_graph_add_kernel_node()?;
237 let grid = grid.into();
238 let block = block.into();
239 let params = CUDA_KERNEL_NODE_PARAMS {
240 func: function.as_raw(),
241 grid_dim_x: grid.x,
242 grid_dim_y: grid.y,
243 grid_dim_z: grid.z,
244 block_dim_x: block.x,
245 block_dim_y: block.y,
246 block_dim_z: block.z,
247 shared_mem_bytes,
248 kernel_params: if args.is_empty() {
249 core::ptr::null_mut()
250 } else {
251 args.as_mut_ptr()
252 },
253 extra: core::ptr::null_mut(),
254 kern: core::ptr::null_mut(),
255 ctx: core::ptr::null_mut(),
256 };
257 let deps: Vec<CUgraphNode> = dependencies.iter().map(|n| n.raw).collect();
258 let (deps_ptr, deps_len) = deps_raw(&deps);
259 let mut node: CUgraphNode = core::ptr::null_mut();
260 check(cu(
261 &mut node,
262 self.inner.handle,
263 deps_ptr,
264 deps_len,
265 ¶ms,
266 ))?;
267 Ok(GraphNode { raw: node })
268 }}
269
270 pub fn add_memset_u32_node(
274 &self,
275 dependencies: &[GraphNode],
276 dst: CUdeviceptr,
277 value: u32,
278 count: usize,
279 ) -> Result<GraphNode> {
280 let d = driver()?;
281 let cu = d.cu_graph_add_memset_node()?;
282 let params = CUDA_MEMSET_NODE_PARAMS {
283 dst,
284 pitch: 0,
285 value,
286 element_size: 4,
287 width: count,
288 height: 1,
289 };
290 let deps: Vec<CUgraphNode> = dependencies.iter().map(|n| n.raw).collect();
291 let (deps_ptr, deps_len) = deps_raw(&deps);
292 let mut node: CUgraphNode = core::ptr::null_mut();
293 check(unsafe {
294 cu(
295 &mut node,
296 self.inner.handle,
297 deps_ptr,
298 deps_len,
299 ¶ms,
300 self.inner.context.as_raw(),
301 )
302 })?;
303 Ok(GraphNode { raw: node })
304 }
305
306 pub fn clone_graph(&self) -> Result<Self> {
309 let d = driver()?;
310 let cu = d.cu_graph_clone()?;
311 let mut out: CUgraph = core::ptr::null_mut();
312 check(unsafe { cu(&mut out, self.inner.handle) })?;
313 Ok(Self {
314 inner: Arc::new(GraphInner {
315 handle: out,
316 context: self.inner.context.clone(),
317 owned: true,
318 }),
319 })
320 }
321
322 pub fn add_memcpy_node(
324 &self,
325 dependencies: &[GraphNode],
326 params: &CUDA_MEMCPY3D,
327 ) -> Result<GraphNode> {
328 let d = driver()?;
329 let cu = d.cu_graph_add_memcpy_node()?;
330 let deps: Vec<CUgraphNode> = dependencies.iter().map(|n| n.raw).collect();
331 let (deps_ptr, deps_len) = deps_raw(&deps);
332 let mut node: CUgraphNode = core::ptr::null_mut();
333 check(unsafe {
334 cu(
335 &mut node,
336 self.inner.handle,
337 deps_ptr,
338 deps_len,
339 params,
340 self.inner.context.as_raw(),
341 )
342 })?;
343 Ok(GraphNode { raw: node })
344 }
345
346 pub unsafe fn add_host_node(
356 &self,
357 dependencies: &[GraphNode],
358 fn_: unsafe extern "C" fn(*mut core::ffi::c_void),
359 user_data: *mut core::ffi::c_void,
360 ) -> Result<GraphNode> { unsafe {
361 let d = driver()?;
362 let cu = d.cu_graph_add_host_node()?;
363 let params = CUDA_HOST_NODE_PARAMS {
364 fn_: Some(fn_),
365 user_data,
366 };
367 let deps: Vec<CUgraphNode> = dependencies.iter().map(|n| n.raw).collect();
368 let (deps_ptr, deps_len) = deps_raw(&deps);
369 let mut node: CUgraphNode = core::ptr::null_mut();
370 check(cu(
371 &mut node,
372 self.inner.handle,
373 deps_ptr,
374 deps_len,
375 ¶ms,
376 ))?;
377 Ok(GraphNode { raw: node })
378 }}
379
380 pub fn add_child_graph_node(
383 &self,
384 dependencies: &[GraphNode],
385 child: &Graph,
386 ) -> Result<GraphNode> {
387 let d = driver()?;
388 let cu = d.cu_graph_add_child_graph_node()?;
389 let deps: Vec<CUgraphNode> = dependencies.iter().map(|n| n.raw).collect();
390 let (deps_ptr, deps_len) = deps_raw(&deps);
391 let mut node: CUgraphNode = core::ptr::null_mut();
392 check(unsafe {
393 cu(
394 &mut node,
395 self.inner.handle,
396 deps_ptr,
397 deps_len,
398 child.as_raw(),
399 )
400 })?;
401 Ok(GraphNode { raw: node })
402 }
403
404 pub fn add_event_record_node(
406 &self,
407 dependencies: &[GraphNode],
408 event: &Event,
409 ) -> Result<GraphNode> {
410 let d = driver()?;
411 let cu = d.cu_graph_add_event_record_node()?;
412 let deps: Vec<CUgraphNode> = dependencies.iter().map(|n| n.raw).collect();
413 let (deps_ptr, deps_len) = deps_raw(&deps);
414 let mut node: CUgraphNode = core::ptr::null_mut();
415 check(unsafe {
416 cu(
417 &mut node,
418 self.inner.handle,
419 deps_ptr,
420 deps_len,
421 event.as_raw(),
422 )
423 })?;
424 Ok(GraphNode { raw: node })
425 }
426
427 pub fn add_event_wait_node(
430 &self,
431 dependencies: &[GraphNode],
432 event: &Event,
433 ) -> Result<GraphNode> {
434 let d = driver()?;
435 let cu = d.cu_graph_add_event_wait_node()?;
436 let deps: Vec<CUgraphNode> = dependencies.iter().map(|n| n.raw).collect();
437 let (deps_ptr, deps_len) = deps_raw(&deps);
438 let mut node: CUgraphNode = core::ptr::null_mut();
439 check(unsafe {
440 cu(
441 &mut node,
442 self.inner.handle,
443 deps_ptr,
444 deps_len,
445 event.as_raw(),
446 )
447 })?;
448 Ok(GraphNode { raw: node })
449 }
450
451 pub fn add_mem_alloc_node(
456 &self,
457 dependencies: &[GraphNode],
458 device: &crate::Device,
459 bytesize: usize,
460 ) -> Result<(GraphNode, CUdeviceptr)> {
461 let d = driver()?;
462 let cu = d.cu_graph_add_mem_alloc_node()?;
463 let mut params = CUDA_MEM_ALLOC_NODE_PARAMS {
464 pool_props: CUmemPoolProps {
465 alloc_type: CUmemAllocationType::PINNED,
466 handle_types: CUmemAllocationHandleType::NONE,
467 location: CUmemLocation {
468 type_: CUmemLocationType::DEVICE,
469 id: device.as_raw().0,
470 },
471 ..Default::default()
472 },
473 access_descs: core::ptr::null(),
474 access_desc_count: 0,
475 bytesize,
476 dptr: CUdeviceptr(0),
477 };
478 let deps: Vec<CUgraphNode> = dependencies.iter().map(|n| n.raw).collect();
479 let (deps_ptr, deps_len) = deps_raw(&deps);
480 let mut node: CUgraphNode = core::ptr::null_mut();
481 check(unsafe {
482 cu(
483 &mut node,
484 self.inner.handle,
485 deps_ptr,
486 deps_len,
487 &mut params,
488 )
489 })?;
490 Ok((GraphNode { raw: node }, params.dptr))
491 }
492
493 pub fn add_mem_free_node(
497 &self,
498 dependencies: &[GraphNode],
499 dptr: CUdeviceptr,
500 ) -> Result<GraphNode> {
501 let d = driver()?;
502 let cu = d.cu_graph_add_mem_free_node()?;
503 let deps: Vec<CUgraphNode> = dependencies.iter().map(|n| n.raw).collect();
504 let (deps_ptr, deps_len) = deps_raw(&deps);
505 let mut node: CUgraphNode = core::ptr::null_mut();
506 check(unsafe { cu(&mut node, self.inner.handle, deps_ptr, deps_len, dptr) })?;
507 Ok(GraphNode { raw: node })
508 }
509
510 pub fn add_batch_mem_op_node(
517 &self,
518 dependencies: &[GraphNode],
519 ops: &mut [baracuda_cuda_sys::types::CUstreamBatchMemOpParams],
520 ) -> Result<GraphNode> {
521 let d = driver()?;
522 let cu = d.cu_graph_add_batch_mem_op_node()?;
523 let params = baracuda_cuda_sys::types::CUDA_BATCH_MEM_OP_NODE_PARAMS {
524 ctx: self.inner.context.as_raw(),
525 count: ops.len() as core::ffi::c_uint,
526 param_array: ops.as_mut_ptr(),
527 flags: 0,
528 };
529 let deps: Vec<CUgraphNode> = dependencies.iter().map(|n| n.raw).collect();
530 let (deps_ptr, deps_len) = deps_raw(&deps);
531 let mut node: CUgraphNode = core::ptr::null_mut();
532 check(unsafe { cu(&mut node, self.inner.handle, deps_ptr, deps_len, ¶ms) })?;
533 Ok(GraphNode { raw: node })
534 }
535
536 pub fn add_dependencies(&self, from: &[GraphNode], to: &[GraphNode]) -> Result<()> {
539 assert_eq!(from.len(), to.len(), "add_dependencies: length mismatch");
540 if from.is_empty() {
541 return Ok(());
542 }
543 let d = driver()?;
544 let cu = d.cu_graph_add_dependencies()?;
545 let f: Vec<CUgraphNode> = from.iter().map(|n| n.raw).collect();
546 let t: Vec<CUgraphNode> = to.iter().map(|n| n.raw).collect();
547 check(unsafe { cu(self.inner.handle, f.as_ptr(), t.as_ptr(), f.len()) })
548 }
549
550 pub fn remove_dependencies(&self, from: &[GraphNode], to: &[GraphNode]) -> Result<()> {
552 assert_eq!(from.len(), to.len(), "remove_dependencies: length mismatch");
553 if from.is_empty() {
554 return Ok(());
555 }
556 let d = driver()?;
557 let cu = d.cu_graph_remove_dependencies()?;
558 let f: Vec<CUgraphNode> = from.iter().map(|n| n.raw).collect();
559 let t: Vec<CUgraphNode> = to.iter().map(|n| n.raw).collect();
560 check(unsafe { cu(self.inner.handle, f.as_ptr(), t.as_ptr(), f.len()) })
561 }
562
563 pub fn debug_dot_print(&self, path: &str, flags: u32) -> Result<()> {
566 let d = driver()?;
567 let cu = d.cu_graph_debug_dot_print()?;
568 let c_path = std::ffi::CString::new(path).map_err(|_| {
569 crate::error::Error::Loader(baracuda_core::LoaderError::SymbolNotFound {
570 library: "cuda-driver",
571 symbol: "cuGraphDebugDotPrint(path contained a NUL byte)",
572 })
573 })?;
574 check(unsafe { cu(self.inner.handle, c_path.as_ptr(), flags) })
575 }
576
577 pub fn conditional_handle(
582 &self,
583 default_launch_value: u32,
584 flags: u32,
585 ) -> Result<CUgraphConditionalHandle> {
586 let d = driver()?;
587 let cu = d.cu_graph_conditional_handle_create()?;
588 let mut h: CUgraphConditionalHandle = 0;
589 check(unsafe {
590 cu(
591 &mut h,
592 self.inner.handle,
593 self.inner.context.as_raw(),
594 default_launch_value,
595 flags,
596 )
597 })?;
598 Ok(h)
599 }
600
601 pub fn add_conditional_node(
608 &self,
609 dependencies: &[GraphNode],
610 handle: CUgraphConditionalHandle,
611 type_: i32,
612 size: u32,
613 ) -> Result<(GraphNode, Graph)> {
614 let d = driver()?;
615 let cu = d.cu_graph_add_node()?;
616 let mut body: CUgraph = core::ptr::null_mut();
617 let cond = CUDA_CONDITIONAL_NODE_PARAMS {
618 handle,
619 type_,
620 size,
621 body_graph_out: &mut body,
622 ctx: self.inner.context.as_raw(),
623 };
624 let mut params = CUgraphNodeParams {
625 type_: CUgraphNodeType::CONDITIONAL,
626 ..Default::default()
627 };
628 unsafe {
632 let dst = params.payload.as_mut_ptr() as *mut CUDA_CONDITIONAL_NODE_PARAMS;
633 dst.write(cond);
634 }
635 let deps: Vec<CUgraphNode> = dependencies.iter().map(|n| n.raw).collect();
636 let (deps_ptr, deps_len) = deps_raw(&deps);
637 let mut node: CUgraphNode = core::ptr::null_mut();
638 check(unsafe {
639 cu(
640 &mut node,
641 self.inner.handle,
642 deps_ptr,
643 core::ptr::null(),
644 deps_len,
645 &mut params,
646 )
647 })?;
648 let body_graph = Graph {
651 inner: Arc::new(GraphInner {
652 handle: body,
653 context: self.inner.context.clone(),
654 owned: false,
655 }),
656 };
657 Ok((GraphNode { raw: node }, body_graph))
658 }
659
660 pub fn edges(&self) -> Result<(Vec<GraphNode>, Vec<GraphNode>)> {
662 let d = driver()?;
663 let cu = d.cu_graph_get_edges()?;
664 let mut count: usize = 0;
666 check(unsafe {
667 cu(
668 self.inner.handle,
669 core::ptr::null_mut(),
670 core::ptr::null_mut(),
671 &mut count,
672 )
673 })?;
674 let mut from = vec![core::ptr::null_mut(); count];
675 let mut to = vec![core::ptr::null_mut(); count];
676 if count > 0 {
677 check(unsafe {
678 cu(
679 self.inner.handle,
680 from.as_mut_ptr(),
681 to.as_mut_ptr(),
682 &mut count,
683 )
684 })?;
685 }
686 Ok((
687 from.into_iter().map(|raw| GraphNode { raw }).collect(),
688 to.into_iter().map(|raw| GraphNode { raw }).collect(),
689 ))
690 }
691}
692
693fn deps_raw(deps: &[CUgraphNode]) -> (*const CUgraphNode, usize) {
694 if deps.is_empty() {
695 (core::ptr::null(), 0)
696 } else {
697 (deps.as_ptr(), deps.len())
698 }
699}
700
701#[derive(Copy, Clone, Debug)]
707pub struct GraphNode {
708 raw: CUgraphNode,
709}
710
711impl GraphNode {
712 #[inline]
714 pub fn as_raw(&self) -> CUgraphNode {
715 self.raw
716 }
717
718 pub fn node_type(&self) -> Result<core::ffi::c_int> {
721 let d = driver()?;
722 let cu = d.cu_graph_node_get_type()?;
723 let mut t: core::ffi::c_int = 0;
724 check(unsafe { cu(self.raw, &mut t) })?;
725 Ok(t)
726 }
727
728 pub fn dependencies(&self) -> Result<Vec<GraphNode>> {
730 let d = driver()?;
731 let cu = d.cu_graph_node_get_dependencies()?;
732 let mut count: usize = 0;
733 check(unsafe { cu(self.raw, core::ptr::null_mut(), &mut count) })?;
734 let mut out = vec![core::ptr::null_mut(); count];
735 if count > 0 {
736 check(unsafe { cu(self.raw, out.as_mut_ptr(), &mut count) })?;
737 }
738 Ok(out.into_iter().map(|raw| GraphNode { raw }).collect())
739 }
740
741 pub fn dependent_nodes(&self) -> Result<Vec<GraphNode>> {
743 let d = driver()?;
744 let cu = d.cu_graph_node_get_dependent_nodes()?;
745 let mut count: usize = 0;
746 check(unsafe { cu(self.raw, core::ptr::null_mut(), &mut count) })?;
747 let mut out = vec![core::ptr::null_mut(); count];
748 if count > 0 {
749 check(unsafe { cu(self.raw, out.as_mut_ptr(), &mut count) })?;
750 }
751 Ok(out.into_iter().map(|raw| GraphNode { raw }).collect())
752 }
753
754 pub fn kernel_params(&self) -> Result<CUDA_KERNEL_NODE_PARAMS> {
756 let d = driver()?;
757 let cu = d.cu_graph_kernel_node_get_params()?;
758 let mut p = CUDA_KERNEL_NODE_PARAMS::default();
759 check(unsafe { cu(self.raw, &mut p) })?;
760 Ok(p)
761 }
762
763 pub unsafe fn set_kernel_params(&self, params: &CUDA_KERNEL_NODE_PARAMS) -> Result<()> { unsafe {
772 let d = driver()?;
773 let cu = d.cu_graph_kernel_node_set_params()?;
774 check(cu(self.raw, params))
775 }}
776
777 pub unsafe fn set_params(&self, params: &mut CUgraphNodeParams) -> Result<()> { unsafe {
788 let d = driver()?;
789 let cu = d.cu_graph_node_set_params()?;
790 check(cu(self.raw, params))
791 }}
792
793 pub fn memset_params(&self) -> Result<CUDA_MEMSET_NODE_PARAMS> {
795 let d = driver()?;
796 let cu = d.cu_graph_memset_node_get_params()?;
797 let mut p = CUDA_MEMSET_NODE_PARAMS::default();
798 check(unsafe { cu(self.raw, &mut p) })?;
799 Ok(p)
800 }
801
802 pub fn set_memset_params(&self, params: &CUDA_MEMSET_NODE_PARAMS) -> Result<()> {
804 let d = driver()?;
805 let cu = d.cu_graph_memset_node_set_params()?;
806 check(unsafe { cu(self.raw, params) })
807 }
808
809 pub fn mem_free_ptr(&self) -> Result<CUdeviceptr> {
812 let d = driver()?;
813 let cu = d.cu_graph_mem_free_node_get_params()?;
814 let mut p = CUdeviceptr(0);
815 check(unsafe { cu(self.raw, &mut p) })?;
816 Ok(p)
817 }
818
819 pub fn mem_alloc_params(&self) -> Result<CUDA_MEM_ALLOC_NODE_PARAMS> {
822 let d = driver()?;
823 let cu = d.cu_graph_mem_alloc_node_get_params()?;
824 let mut p = CUDA_MEM_ALLOC_NODE_PARAMS::default();
825 check(unsafe { cu(self.raw, &mut p) })?;
826 Ok(p)
827 }
828
829 pub fn memcpy_params(&self) -> Result<CUDA_MEMCPY3D> {
831 let d = driver()?;
832 let cu = d.cu_graph_memcpy_node_get_params()?;
833 let mut p = CUDA_MEMCPY3D::default();
834 check(unsafe { cu(self.raw, &mut p) })?;
835 Ok(p)
836 }
837
838 pub fn set_memcpy_params(&self, params: &CUDA_MEMCPY3D) -> Result<()> {
840 let d = driver()?;
841 let cu = d.cu_graph_memcpy_node_set_params()?;
842 check(unsafe { cu(self.raw, params) })
843 }
844
845 pub unsafe fn destroy(self) -> Result<()> { unsafe {
854 let d = driver()?;
855 let cu = d.cu_graph_destroy_node()?;
856 check(cu(self.raw))
857 }}
858}
859
/// Flag constants for [`Graph::instantiate_with_flags`].
///
/// This module only re-exports
/// `baracuda_cuda_sys::types::CUgraphInstantiate_flags`.
pub mod instantiate_flags {
    pub use baracuda_cuda_sys::types::CUgraphInstantiate_flags::*;
}
864
865impl Drop for GraphInner {
866 fn drop(&mut self) {
867 if !self.owned || self.handle.is_null() {
868 return;
869 }
870 if let Ok(d) = driver() {
871 if let Ok(cu) = d.cu_graph_destroy() {
872 let _ = unsafe { cu(self.handle) };
873 }
874 }
875 }
876}
877
878#[derive(Clone)]
880pub struct GraphExec {
881 inner: Arc<GraphExecInner>,
882}
883
884struct GraphExecInner {
885 handle: CUgraphExec,
886 #[allow(dead_code)]
887 context: Context,
888}
889
890unsafe impl Send for GraphExecInner {}
891unsafe impl Sync for GraphExecInner {}
892
893impl core::fmt::Debug for GraphExecInner {
894 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
895 f.debug_struct("GraphExec")
896 .field("handle", &self.handle)
897 .finish_non_exhaustive()
898 }
899}
900
901impl core::fmt::Debug for GraphExec {
902 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
903 self.inner.fmt(f)
904 }
905}
906
907impl GraphExec {
908 pub fn launch(&self, stream: &Stream) -> Result<()> {
911 let d = driver()?;
912 let cu = d.cu_graph_launch()?;
913 check(unsafe { cu(self.inner.handle, stream.as_raw()) })
914 }
915
916 #[inline]
918 pub fn as_raw(&self) -> CUgraphExec {
919 self.inner.handle
920 }
921
922 pub fn update(&self, new_template: &Graph) -> Result<UpdateResult> {
933 let d = driver()?;
934 let cu = d.cu_graph_exec_update()?;
935 let mut info = CUgraphExecUpdateResultInfo::default();
936 let rc = unsafe { cu(self.inner.handle, new_template.as_raw(), &mut info) };
940 if rc != baracuda_cuda_sys::CUresult::SUCCESS
941 && info.result == baracuda_cuda_sys::types::CUgraphExecUpdateResult::SUCCESS
942 {
943 return Err(crate::error::Error::Status { status: rc });
944 }
945 Ok(UpdateResult {
946 result: info.result,
947 error_node: if info.error_node.is_null() {
948 None
949 } else {
950 Some(GraphNode {
951 raw: info.error_node,
952 })
953 },
954 error_from_node: if info.error_from_node.is_null() {
955 None
956 } else {
957 Some(GraphNode {
958 raw: info.error_from_node,
959 })
960 },
961 })
962 }
963
964 pub unsafe fn set_kernel_node_params(
972 &self,
973 node: GraphNode,
974 params: &CUDA_KERNEL_NODE_PARAMS,
975 ) -> Result<()> { unsafe {
976 let d = driver()?;
977 let cu = d.cu_graph_exec_kernel_node_set_params()?;
978 check(cu(self.inner.handle, node.raw, params))
979 }}
980
981 pub fn set_memcpy_node_params(&self, node: GraphNode, params: &CUDA_MEMCPY3D) -> Result<()> {
983 let d = driver()?;
984 let cu = d.cu_graph_exec_memcpy_node_set_params()?;
985 check(unsafe {
986 cu(
987 self.inner.handle,
988 node.raw,
989 params,
990 self.inner.context.as_raw(),
991 )
992 })
993 }
994
995 pub fn set_memset_node_params(
997 &self,
998 node: GraphNode,
999 params: &CUDA_MEMSET_NODE_PARAMS,
1000 ) -> Result<()> {
1001 let d = driver()?;
1002 let cu = d.cu_graph_exec_memset_node_set_params()?;
1003 check(unsafe {
1004 cu(
1005 self.inner.handle,
1006 node.raw,
1007 params,
1008 self.inner.context.as_raw(),
1009 )
1010 })
1011 }
1012
1013 pub unsafe fn set_host_node_params(
1020 &self,
1021 node: GraphNode,
1022 fn_: unsafe extern "C" fn(*mut core::ffi::c_void),
1023 user_data: *mut core::ffi::c_void,
1024 ) -> Result<()> { unsafe {
1025 let d = driver()?;
1026 let cu = d.cu_graph_exec_host_node_set_params()?;
1027 let params = CUDA_HOST_NODE_PARAMS {
1028 fn_: Some(fn_),
1029 user_data,
1030 };
1031 check(cu(self.inner.handle, node.raw, ¶ms))
1032 }}
1033}
1034
1035impl Drop for GraphExecInner {
1036 fn drop(&mut self) {
1037 if let Ok(d) = driver() {
1038 if let Ok(cu) = d.cu_graph_exec_destroy() {
1039 let _ = unsafe { cu(self.handle) };
1040 }
1041 }
1042 }
1043}
1044
1045#[derive(Clone, Debug)]
1050pub struct UpdateResult {
1051 pub result: core::ffi::c_int,
1052 pub error_node: Option<GraphNode>,
1054 pub error_from_node: Option<GraphNode>,
1056}
1057
1058impl UpdateResult {
1059 pub fn is_success(&self) -> bool {
1061 self.result == baracuda_cuda_sys::types::CUgraphExecUpdateResult::SUCCESS
1062 }
1063}
1064
1065pub fn device_graph_mem_trim(device: &crate::Device) -> Result<()> {
1069 let d = driver()?;
1070 let cu = d.cu_device_graph_mem_trim()?;
1071 check(unsafe { cu(device.as_raw()) })
1072}
1073
1074pub fn device_graph_mem_attribute(device: &crate::Device, attr: i32) -> Result<u64> {
1077 let d = driver()?;
1078 let cu = d.cu_device_get_graph_mem_attribute()?;
1079 let mut v: u64 = 0;
1080 check(unsafe {
1081 cu(
1082 device.as_raw(),
1083 attr,
1084 &mut v as *mut u64 as *mut core::ffi::c_void,
1085 )
1086 })?;
1087 Ok(v)
1088}
1089
1090pub fn device_set_graph_mem_attribute(device: &crate::Device, attr: i32, value: u64) -> Result<()> {
1092 let d = driver()?;
1093 let cu = d.cu_device_set_graph_mem_attribute()?;
1094 let mut v = value;
1095 check(unsafe {
1096 cu(
1097 device.as_raw(),
1098 attr,
1099 &mut v as *mut u64 as *mut core::ffi::c_void,
1100 )
1101 })
1102}