Skip to main content

baracuda_driver/
user_object.rs

1//! CUDA Graph user objects (CUDA 12.0+).
2//!
3//! Graphs frequently hold references to external resources (allocators,
4//! file handles, RAII guards) that must be kept alive for the graph's
5//! lifetime. A [`UserObject`] is a refcounted handle + destructor that
6//! you can *attach* to a graph via [`Graph::retain_user_object`]; when
7//! the graph releases the last reference, the destructor runs.
8//!
9//! The Rust safe wrapper owns a `Box<dyn FnOnce() + Send>` trampoline so
10//! idiomatic `move`-closures work as destructors.
11
12use core::ffi::c_void;
13
14use baracuda_cuda_sys::{driver, CUuserObject};
15
16use crate::error::{check, Result};
17
18/// A refcounted user object. Drop releases one reference.
19pub struct UserObject {
20    handle: CUuserObject,
21}
22
23unsafe impl Send for UserObject {}
24unsafe impl Sync for UserObject {}
25
26impl core::fmt::Debug for UserObject {
27    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
28        f.debug_struct("UserObject")
29            .field("handle", &self.handle)
30            .finish_non_exhaustive()
31    }
32}
33
/// C-ABI destructor that `UserObject::new` registers with the driver.
///
/// A null `user_data` is ignored. Otherwise the pointer is turned back into
/// the `Box<Box<dyn FnOnce() + Send>>` it came from and the closure inside is
/// called once, consuming it.
///
/// # Safety
///
/// `user_data` must be null or a pointer produced by `Box::into_raw` on a
/// `Box<Box<dyn FnOnce() + Send>>` that has not been reclaimed yet.
unsafe extern "C" fn destroy_trampoline(user_data: *mut c_void) {
    let slot = user_data as *mut Box<dyn FnOnce() + Send>;
    if !slot.is_null() {
        // SAFETY: `user_data` was `Box::into_raw`'d by `UserObject::new`, so
        // rebuilding the outer box here takes back sole ownership of it.
        let outer = unsafe { Box::from_raw(slot) };
        // Move the closure out of the outer box so it can be called by value.
        let callback: Box<dyn FnOnce() + Send> = *outer;
        callback();
    }
}
43
44impl UserObject {
45    /// Create a user object whose destructor is `destroy`. `initial_refcount`
46    /// is typically 1; the CUDA API requires at least 1.
47    pub fn new<F>(destroy: F, initial_refcount: u32) -> Result<Self>
48    where
49        F: FnOnce() + Send + 'static,
50    {
51        let boxed: Box<Box<dyn FnOnce() + Send>> = Box::new(Box::new(destroy));
52        let raw = Box::into_raw(boxed) as *mut c_void;
53        let d = driver()?;
54        let cu = d.cu_user_object_create()?;
55        let mut object: CUuserObject = core::ptr::null_mut();
56        // CUDA currently requires flags == CU_USER_OBJECT_NO_DESTRUCTOR_SYNC (= 1);
57        // the destructor may run on any CUDA-internal thread.
58        const CU_USER_OBJECT_NO_DESTRUCTOR_SYNC: core::ffi::c_uint = 1;
59        let rc = unsafe {
60            cu(
61                &mut object,
62                raw,
63                Some(destroy_trampoline),
64                initial_refcount,
65                CU_USER_OBJECT_NO_DESTRUCTOR_SYNC,
66            )
67        };
68        if rc != baracuda_cuda_sys::CUresult::SUCCESS {
69            // Reclaim the box so we don't leak the closure.
70            drop(unsafe { Box::from_raw(raw as *mut Box<dyn FnOnce() + Send>) });
71            return Err(crate::error::Error::Status { status: rc });
72        }
73        Ok(Self { handle: object })
74    }
75
76    /// Add `count` references to this user object's refcount.
77    pub fn retain(&self, count: u32) -> Result<()> {
78        let d = driver()?;
79        let cu = d.cu_user_object_retain()?;
80        check(unsafe { cu(self.handle, count) })
81    }
82
83    /// Drop `count` references (runs destructor if this was the last).
84    pub fn release(&self, count: u32) -> Result<()> {
85        let d = driver()?;
86        let cu = d.cu_user_object_release()?;
87        check(unsafe { cu(self.handle, count) })
88    }
89
90    #[inline]
91    pub fn as_raw(&self) -> CUuserObject {
92        self.handle
93    }
94}
95
96impl Drop for UserObject {
97    fn drop(&mut self) {
98        if self.handle.is_null() {
99            return;
100        }
101        if let Ok(d) = driver() {
102            if let Ok(cu) = d.cu_user_object_release() {
103                let _ = unsafe { cu(self.handle, 1) };
104            }
105        }
106    }
107}
108
109// Extend Graph with user-object retention.
110impl crate::Graph {
111    /// Have this graph take `count` references to `object`. When the
112    /// graph is destroyed (or when [`release_user_object`](Self::release_user_object)
113    /// is called), those references are dropped.
114    ///
115    /// `flags` is reserved (pass 0) in CUDA 12.x; CUDA 13 adds
116    /// `CU_GRAPH_USER_OBJECT_MOVE` = 1.
117    pub fn retain_user_object(&self, object: &UserObject, count: u32, flags: u32) -> Result<()> {
118        let d = driver()?;
119        let cu = d.cu_graph_retain_user_object()?;
120        check(unsafe { cu(self.as_raw(), object.as_raw(), count, flags) })
121    }
122
123    pub fn release_user_object(&self, object: &UserObject, count: u32) -> Result<()> {
124        let d = driver()?;
125        let cu = d.cu_graph_release_user_object()?;
126        check(unsafe { cu(self.as_raw(), object.as_raw(), count) })
127    }
128}