1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191
use cpp::cpp;
use crate::ffi::ptr::DevicePtr;
use crate::ffi::result;
type Result<T> = std::result::Result<T, crate::error::Error>;
/// Blocking (synchronous) CUDA stream wrapper that backs [`crate::Stream`].
///
/// See [`crate::Stream`] for the full documentation.
pub struct Stream {
    /// Raw `cudaStream_t` handle. A null handle denotes the CUDA default stream.
    internal: DevicePtr,
}
// SAFETY: `Stream` holds nothing but a raw CUDA stream handle, and the CUDA API is thread-safe,
// so ownership of a `Stream` may move to another thread.
unsafe impl Send for Stream {}

// SAFETY: Shared access only ever hands the raw handle to the CUDA API, which is thread-safe,
// so `&Stream` may be used from several threads at once.
unsafe impl Sync for Stream {}
impl Stream {
    /// Returns a handle to the CUDA default stream, represented by a null `cudaStream_t`.
    ///
    /// Dropping the returned value destroys nothing, because the [`Drop`] impl skips null
    /// handles.
    pub fn null() -> Self {
        Self {
            // SAFETY: This is safe because a null pointer for stream indicates the default
            // stream in CUDA and all functions accept this.
            internal: unsafe { DevicePtr::null() },
        }
    }

    /// Creates a new CUDA stream with `cudaStreamCreate`.
    ///
    /// # Errors
    ///
    /// The CUDA return code is converted into a [`Result`] by the crate's `result!` macro.
    pub fn new() -> Result<Self> {
        let mut ptr: *mut std::ffi::c_void = std::ptr::null_mut();
        let ptr_ptr = std::ptr::addr_of_mut!(ptr);
        let ret = cpp!(unsafe [
            ptr_ptr as "void**"
        ] -> i32 as "std::int32_t" {
            return cudaStreamCreate((cudaStream_t*) ptr_ptr);
        });
        result!(
            ret,
            Stream {
                internal: ptr.into()
            }
        )
    }

    /// Blocks the calling thread on `cudaStreamSynchronize` for this stream.
    ///
    /// # Errors
    ///
    /// The CUDA return code is converted into a [`Result`] by the crate's `result!` macro.
    pub fn synchronize(&self) -> Result<()> {
        let ptr = self.internal.as_ptr();
        let ret = cpp!(unsafe [
            ptr as "void*"
        ] -> i32 as "std::int32_t" {
            return cudaStreamSynchronize((cudaStream_t) ptr);
        });
        result!(ret)
    }

    /// Enqueues the host closure `f` on this stream using `cudaStreamAddCallback`.
    ///
    /// `f` is invoked later by the native `cuda_ffi_Callback` trampoline, after this method has
    /// already returned and possibly on another thread. It therefore must be `Send`, and it
    /// must be `'static`. A closure borrowing from the caller's stack could otherwise run after
    /// the borrowed data has been freed.
    ///
    /// # Errors
    ///
    /// The CUDA return code is converted into a [`Result`] by the crate's `result!` macro. On
    /// error, the boxed closure is deliberately not reclaimed here. CUDA may surface errors from
    /// earlier asynchronous work through this call, so the callback might still have been
    /// registered, and freeing it could cause a use-after-free. The closure may leak instead.
    pub fn add_callback(&self, f: impl FnOnce() + Send + 'static) -> Result<()> {
        let ptr = self.internal.as_ptr();
        // `Box<dyn FnOnce()>` is a fat pointer and cannot travel through a `void*`. Boxing it
        // a second time yields a thin pointer that can. The trampoline rebuilds exactly this
        // `Box<Box<dyn FnOnce()>>` type.
        let f_boxed = Box::new(f) as Box<dyn FnOnce()>;
        let f_boxed2 = Box::new(f_boxed);
        let user_data = Box::into_raw(f_boxed2) as *mut std::ffi::c_void;
        let ret = cpp!(unsafe [
            ptr as "void*",
            user_data as "void*"
        ] -> i32 as "std::int32_t" {
            return cudaStreamAddCallback(
                (cudaStream_t) ptr,
                cuda_ffi_Callback,
                user_data,
                0
            );
        });
        result!(ret)
    }

    /// Get readonly reference to internal [`DevicePtr`].
    #[inline(always)]
    pub fn as_internal(&self) -> &DevicePtr {
        &self.internal
    }

    /// Get mutable reference to internal [`DevicePtr`].
    #[inline(always)]
    pub fn as_mut_internal(&mut self) -> &mut DevicePtr {
        &mut self.internal
    }
}
impl Drop for Stream {
    /// Synchronizes and then destroys the underlying CUDA stream.
    ///
    /// The default (null) stream is left untouched. Return codes from both CUDA calls are
    /// discarded, because `drop` cannot report them.
    fn drop(&mut self) {
        if !self.internal.is_null() {
            // SAFETY: `take` leaves `self.internal` holding a null pointer. That is fine here
            // because `self` is not used again after this point.
            let mut handle = unsafe { self.internal.take() };
            let stream = handle.as_mut_ptr();
            // Wait for every operation pending on the stream before destroying it, so that the
            // stream is not torn down with work still queued. Destroy runs even if
            // synchronization reports an error.
            let _status = cpp!(unsafe [
                stream as "void*"
            ] -> i32 as "std::int32_t" {
                cudaStreamSynchronize((cudaStream_t) stream);
                return cudaStreamDestroy((cudaStream_t) stream);
            });
        }
    }
}
cpp! {{
    /// Native trampoline that lets CUDA invoke a boxed Rust closure over the FFI.
    ///
    /// # Arguments
    ///
    /// * `stream` - The CUDA stream on which the callback was scheduled (unused).
    /// * `status` - The CUDA status value (this could represent an error from an earlier async
    ///   CUDA call). It is forwarded to Rust but currently not inspected; the closure runs
    ///   regardless of its value.
    /// * `user_data` - The pointer produced by `Box::into_raw` on a `Box<Box<dyn FnOnce()>>`
    ///   when the callback was added.
    ///
    /// # Example
    ///
    /// It can be used like so:
    ///
    /// ```cpp
    /// return cudaStreamAddCallback(
    ///     stream,
    ///     cuda_ffi_Callback,
    ///     user_data,
    ///     0
    /// );
    /// ```
    static void cuda_ffi_Callback(
        __attribute__((unused)) cudaStream_t stream,
        cudaError_t status,
        void* user_data
    ) {
        rust!(cuda_ffi_Callback_internal [
            status : i32 as "std::int32_t",
            user_data : *mut std::ffi::c_void as "void*"
        ] {
            // SAFETY: `user_data` was created by `Box::into_raw` on a `Box<Box<dyn FnOnce()>>`,
            // so casting it back to that exact type and reinstating the box is sound. This
            // assumes CUDA invokes each registered callback once; the box is dropped after use.
            let function = unsafe { Box::from_raw(user_data as *mut Box<dyn FnOnce()>) };
            // This function is called from C++ code. A panic must not unwind across that
            // boundary, so catch it and abort explicitly instead.
            if std::panic::catch_unwind(std::panic::AssertUnwindSafe(function)).is_err() {
                std::process::abort();
            }
        });
    }
}}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new() {
        assert!(Stream::new().is_ok());
    }

    #[test]
    fn test_synchronize() {
        let stream = Stream::new().unwrap();
        assert!(stream.synchronize().is_ok());
    }

    #[test]
    fn test_synchronize_null_stream() {
        let stream = Stream::null();
        assert!(stream.synchronize().is_ok());
    }

    /// A callback added to a stream must eventually be invoked exactly as enqueued.
    #[test]
    fn test_add_callback() {
        let stream = Stream::new().unwrap();
        let (tx, rx) = std::sync::mpsc::channel();
        assert!(stream
            .add_callback(move || {
                // Ignore send errors so the callback never panics across the FFI boundary.
                let _ = tx.send(());
            })
            .is_ok());
        assert!(stream.synchronize().is_ok());
        assert!(rx
            .recv_timeout(std::time::Duration::from_secs(10))
            .is_ok());
    }
}