Skip to main content

baracuda_runtime/
user_object.rs

1//! Runtime-API graph user objects (CUDA 12.0+).
2//!
//! Refcounted RAII slot you can attach to a graph via
//! [`Graph::retain_user_object`](crate::Graph::retain_user_object); when the last
//! reference is released (by the graph or by any other holder), the destructor runs.
//! Mirrors the Driver-side wrapper.
6
7use core::ffi::c_void;
8
9use baracuda_cuda_sys::runtime::{cudaUserObject_t, runtime};
10
11use crate::error::{check, Result};
12
13/// A refcounted user object.
14pub struct UserObject {
15    handle: cudaUserObject_t,
16}
17
18unsafe impl Send for UserObject {}
19unsafe impl Sync for UserObject {}
20
21impl core::fmt::Debug for UserObject {
22    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
23        f.debug_struct("UserObject")
24            .field("handle", &self.handle)
25            .finish_non_exhaustive()
26    }
27}
28
/// C-ABI destructor callback passed to `cudaUserObjectCreate` by `UserObject::new`.
///
/// `user_data` must be either null (ignored) or a pointer produced by
/// `Box::into_raw(Box<Box<dyn FnOnce() + Send>>)`. Ownership of that allocation is
/// taken back here, and the closure is invoked exactly once.
///
/// A panic inside the user closure is caught rather than allowed to unwind across the
/// `extern "C"` boundary. Unwinding there is undefined behaviour on older toolchains
/// and a process abort on newer ones. The panic hook has already reported the panic
/// by the time it is caught.
unsafe extern "C" fn destroy_trampoline(user_data: *mut c_void) {
    if user_data.is_null() {
        return;
    }
    // SAFETY: per this function's contract, a non-null `user_data` came from
    // `Box::into_raw` on a `Box<Box<dyn FnOnce() + Send>>`, and the callback is invoked
    // at most once, so reclaiming ownership here is sound.
    let f: Box<Box<dyn FnOnce() + Send>> =
        unsafe { Box::from_raw(user_data as *mut Box<dyn FnOnce() + Send>) };
    // Never let a panic escape into the C caller; discarding the payload is deliberate.
    let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(move || (*f)()));
}
37
38impl UserObject {
39    /// Create a user object whose destructor is `destroy`.
40    /// `initial_refcount` must be >= 1.
41    pub fn new<F>(destroy: F, initial_refcount: u32) -> Result<Self>
42    where
43        F: FnOnce() + Send + 'static,
44    {
45        let boxed: Box<Box<dyn FnOnce() + Send>> = Box::new(Box::new(destroy));
46        let raw = Box::into_raw(boxed) as *mut c_void;
47        let r = runtime()?;
48        let cu = r.cuda_user_object_create()?;
49        let mut object: cudaUserObject_t = core::ptr::null_mut();
50        // CUDA requires flags == cudaGraphUserObjectMove (1) currently.
51        const CUDA_USER_OBJECT_NO_DESTRUCTOR_SYNC: core::ffi::c_uint = 1;
52        let rc = unsafe {
53            cu(
54                &mut object,
55                raw,
56                Some(destroy_trampoline),
57                initial_refcount,
58                CUDA_USER_OBJECT_NO_DESTRUCTOR_SYNC,
59            )
60        };
61        if rc != baracuda_cuda_sys::runtime::cudaError_t::Success {
62            drop(unsafe { Box::from_raw(raw as *mut Box<dyn FnOnce() + Send>) });
63            return Err(crate::error::Error::Status { status: rc });
64        }
65        Ok(Self { handle: object })
66    }
67
68    pub fn retain(&self, count: u32) -> Result<()> {
69        let r = runtime()?;
70        let cu = r.cuda_user_object_retain()?;
71        check(unsafe { cu(self.handle, count) })
72    }
73
74    pub fn release(&self, count: u32) -> Result<()> {
75        let r = runtime()?;
76        let cu = r.cuda_user_object_release()?;
77        check(unsafe { cu(self.handle, count) })
78    }
79
80    #[inline]
81    pub fn as_raw(&self) -> cudaUserObject_t {
82        self.handle
83    }
84}
85
86impl Drop for UserObject {
87    fn drop(&mut self) {
88        if self.handle.is_null() {
89            return;
90        }
91        if let Ok(r) = runtime() {
92            if let Ok(cu) = r.cuda_user_object_release() {
93                let _ = unsafe { cu(self.handle, 1) };
94            }
95        }
96    }
97}
98
99impl crate::Graph {
100    /// Have this graph retain `count` references to `object`.
101    pub fn retain_user_object(&self, object: &UserObject, count: u32, flags: u32) -> Result<()> {
102        let r = runtime()?;
103        let cu = r.cuda_graph_retain_user_object()?;
104        check(unsafe { cu(self.as_raw(), object.as_raw(), count, flags) })
105    }
106
107    pub fn release_user_object(&self, object: &UserObject, count: u32) -> Result<()> {
108        let r = runtime()?;
109        let cu = r.cuda_graph_release_user_object()?;
110        check(unsafe { cu(self.as_raw(), object.as_raw(), count) })
111    }
112}