Skip to main content

sp1_gpu_cudart/
stream.rs

1use super::{CudaError, CudaEvent};
2use slop_alloc::mem::{CopyDirection, CopyError, DeviceMemory};
3use slop_alloc::{AllocError, Allocator};
4use sp1_gpu_sys::runtime::{
5    cuda_event_record, cuda_free_async, cuda_launch_host_function, cuda_malloc_async,
6    cuda_mem_copy_device_to_device_async, cuda_mem_copy_device_to_host_async,
7    cuda_mem_copy_host_to_device_async, cuda_mem_set_async, cuda_stream_create,
8    cuda_stream_destroy, cuda_stream_query, cuda_stream_synchronize, cuda_stream_wait_event,
9    CudaStreamHandle, Dim3, KernelPtr, DEFAULT_STREAM,
10};
11use std::{
12    alloc::Layout,
13    ffi::c_void,
14    future::{Future, IntoFuture},
15    ops::Deref,
16    pin::Pin,
17    ptr::{self, NonNull},
18    sync::{Arc, Mutex},
19    task::{Context, Poll, Waker},
20    time::Duration,
21};
22use tokio::time::Interval;
23
/// Period, in milliseconds, of the fallback timer that [`StreamCallbackFuture`] uses to
/// re-poll and re-query its stream in case the host callback has not woken the task yet.
pub(crate) const INTERVAL_MS: u64 = 2_000;
25
26#[derive(Debug, PartialEq, Eq, Hash)]
27#[repr(transparent)]
28pub struct CudaStream(pub(crate) CudaStreamHandle);
29
30unsafe impl Send for CudaStream {}
31unsafe impl Sync for CudaStream {}
32
33impl Drop for CudaStream {
34    fn drop(&mut self) {
35        if self.0 != unsafe { DEFAULT_STREAM } {
36            // We unwrap because any cuda error should throw here.
37            CudaError::result_from_ffi(unsafe { cuda_stream_destroy(self.0) }).unwrap();
38        }
39    }
40}
41
42impl CudaStream {
43    #[inline]
44    pub(crate) fn create() -> Result<Self, CudaError> {
45        let mut ptr = CudaStreamHandle(ptr::null_mut());
46        CudaError::result_from_ffi(unsafe {
47            cuda_stream_create(&mut ptr as *mut CudaStreamHandle)
48        })?;
49        Ok(Self(ptr))
50    }
51
52    /// # Safety
53    ///
54    /// TODO
55    #[inline]
56    unsafe fn launch_host_fn(
57        &self,
58        host_fn: Option<unsafe extern "C" fn(*mut c_void)>,
59        data: *const c_void,
60    ) -> Result<(), CudaError> {
61        CudaError::result_from_ffi(unsafe { cuda_launch_host_function(self.0, host_fn, data) })
62    }
63
64    /// # Safety
65    ///
66    /// This function launch is asynchronous when called with the non-default stream. The caller
67    /// must ensure that the data read to and written by the kernel remains valid throughout its
68    /// execution.
69    #[inline]
70    pub unsafe fn launch_kernel(
71        &self,
72        kernel: KernelPtr,
73        grid_dim: impl Into<Dim3>,
74        block_dim: impl Into<Dim3>,
75        args: &[*mut c_void],
76        shared_mem: usize,
77    ) -> Result<(), CudaError> {
78        CudaError::result_from_ffi(sp1_gpu_sys::runtime::cuda_launch_kernel(
79            kernel,
80            grid_dim.into(),
81            block_dim.into(),
82            args.as_ptr() as *mut *mut c_void,
83            shared_mem,
84            self.0,
85        ))
86    }
87
88    #[inline]
89    fn query(&self) -> Result<(), CudaError> {
90        CudaError::result_from_ffi(unsafe { cuda_stream_query(self.0) })
91    }
92
93    #[inline]
94    fn record(&self, event: &CudaEvent) -> Result<(), CudaError> {
95        CudaError::result_from_ffi(unsafe { cuda_event_record(event.0, self.0) })
96    }
97
98    /// # Safety
99    ///
100    /// This function is marked unsafe because it requires the caller to ensure that the event is
101    /// valid and that the stream is valid.
102    #[inline]
103    unsafe fn wait(&self, event: &CudaEvent) -> Result<(), CudaError> {
104        CudaError::result_from_ffi(cuda_stream_wait_event(self.0, event.0))
105    }
106
107    #[inline]
108    fn synchronize(&self) -> Result<(), CudaError> {
109        CudaError::result_from_ffi(unsafe { cuda_stream_synchronize(self.0) })
110    }
111}
112
113impl Default for CudaStream {
114    fn default() -> Self {
115        Self(unsafe { DEFAULT_STREAM })
116    }
117}
118
119/// State shared between the future and the CUDA callback
120struct CallbackState<S> {
121    // Holding the stream to prevent it from being dropped
122    task: Option<S>,
123    done: bool,
124    result: Result<(), CudaError>,
125    waker: Option<Waker>,
126}
127
128/// A future that completes once the GPU has completed all work queued in `stream` so far.
129///
130/// This future uses a callback to the host to check if the GPU has completed all work. This is
131/// useful for waiting for the GPU to finish work before continuing on the host and avoiding
132/// busy-waiting.
133pub struct StreamCallbackFuture<S> {
134    shared: Arc<Mutex<CallbackState<S>>>,
135    interval: Pin<Box<Interval>>,
136}
137
138// /// A future that completes once the GPU has completed all work queued in `stream` so far.
139// ///
140// /// This future uses a busy-wait loop to check if the GPU has completed all work. This is useful for
141// /// a future waiting on stream completion with minimal overhead.
142// #[repr(transparent)]
143// pub struct StreamSpinFuture {
144//     stream: CudaStream,
145// }
146
147pub trait StreamRef {
148    unsafe fn stream(&self) -> &CudaStream;
149
150    /// # Safety
151    ///
152    /// TODO
153    #[inline]
154    unsafe fn launch_host_fn_uncheked(
155        &self,
156        host_fn: Option<unsafe extern "C" fn(*mut c_void)>,
157        data: *const c_void,
158    ) -> Result<(), CudaError> {
159        self.stream().launch_host_fn(host_fn, data)
160    }
161
162    #[inline]
163    unsafe fn query(&self) -> Result<(), CudaError> {
164        self.stream().query()
165    }
166
167    #[inline]
168    unsafe fn record_unchecked(&self, event: &CudaEvent) -> Result<(), CudaError> {
169        self.stream().record(event)
170    }
171
172    /// # Safety
173    ///
174    /// This function is marked unsafe because it requires the caller to ensure that the event is
175    /// valid and that the stream is valid.
176    #[inline]
177    unsafe fn wait_unchecked(&self, event: &CudaEvent) -> Result<(), CudaError> {
178        self.stream().wait(event)
179    }
180
181    #[inline]
182    unsafe fn stream_synchronize(&self) -> Result<(), CudaError> {
183        self.stream().synchronize()
184    }
185}
186
187impl StreamRef for CudaStream {
188    #[inline]
189    unsafe fn stream(&self) -> &CudaStream {
190        self
191    }
192}
193
194impl<S> StreamRef for Arc<S>
195where
196    S: StreamRef + ?Sized,
197{
198    #[inline]
199    unsafe fn stream(&self) -> &CudaStream {
200        self.as_ref().stream()
201    }
202}
203
204impl<S> StreamCallbackFuture<S> {
205    /// Creates a new future that completes once the GPU has completed
206    /// all work queued in `stream` so far.
207    pub fn new(task: S) -> Self
208    where
209        S: StreamRef,
210    {
211        // 1) Create an Arc<Mutex<...>> for the shared state
212        let shared = Arc::new(Mutex::new(CallbackState {
213            task: None,
214            done: false,
215            result: Ok(()),
216            waker: None,
217        }));
218
219        // 2) Convert Arc to a raw pointer for CUDA, leaking one Arc so  that the context is not
220        //    dropped before the callback is called.
221        let ptr = Arc::into_raw(shared.clone()) as *mut c_void;
222
223        // 3) Enqueue the callback on the given stream
224        //    This means "when the GPU finishes all prior tasks in `stream`,
225        //    call `my_host_callback(ptr)`"
226        let launch_result = unsafe { task.stream().launch_host_fn(Some(waker_callback::<S>), ptr) };
227
228        shared.lock().unwrap().task = Some(task);
229
230        if let Err(e) = launch_result {
231            let mut state = shared.lock().unwrap();
232            state.result = Err(e);
233            state.done = true;
234        }
235
236        let interval = Box::pin(tokio::time::interval(Duration::from_millis(INTERVAL_MS)));
237
238        Self { shared, interval }
239    }
240}
241
242unsafe extern "C" fn waker_callback<S>(user_data: *mut c_void)
243where
244    S: StreamRef,
245{
246    // Convert the raw pointer back to our Arc<Mutex<CallbackState>>
247    let shared = Arc::<Mutex<CallbackState<S>>>::from_raw(user_data as *const _);
248    let mut state = shared.lock().unwrap();
249
250    // Mark GPU done
251    state.done = true;
252
253    // If we have a waker, wake it so poll() is called again
254    if let Some(ref waker) = state.waker {
255        waker.wake_by_ref();
256    }
257}
258
259impl<S> Future for StreamCallbackFuture<S>
260where
261    S: StreamRef,
262{
263    type Output = Result<(), CudaError>;
264
265    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
266        let mut state = self.shared.lock().unwrap();
267
268        // If the stream is done, return the result
269        if state.done {
270            // GPU has reached the callback
271            return Poll::Ready(state.result);
272        }
273
274        //  If not done, check the stream's status
275        match unsafe { state.task.as_ref().unwrap().stream().query() } {
276            Ok(()) => {
277                state.done = true;
278                state.result = Ok(());
279                return Poll::Ready(Ok(()));
280            }
281            Err(CudaError::NotReady) => {
282                // Stream is not done yet, so we need to wait for it.
283            }
284            Err(e) => {
285                // Got an error from the stream, so we need to return it.
286                state.done = true;
287                state.result = Err(e);
288                return Poll::Ready(Err(e));
289            }
290        }
291
292        // Not done yet, store the waker so we can wake it later
293        state.waker = Some(cx.waker().clone());
294        drop(state);
295
296        // Poll the interval to check if we need to wake up again
297        match self.interval.as_mut().poll_tick(cx) {
298            Poll::Ready(_) => {
299                // The time has passed, so we need to schedule another poll
300                cx.waker().wake_by_ref();
301                Poll::Pending
302            }
303            Poll::Pending => {
304                // The time has not passed yet, so we need to wait for it or for the callback.
305                Poll::Pending
306            }
307        }
308    }
309}
310
311impl IntoFuture for CudaStream {
312    type Output = Result<(), CudaError>;
313    type IntoFuture = StreamCallbackFuture<Self>;
314
315    fn into_future(self) -> Self::IntoFuture {
316        StreamCallbackFuture::new(self)
317    }
318}
319
320unsafe impl Allocator for CudaStream {
321    #[inline]
322    unsafe fn allocate(&self, layout: Layout) -> Result<ptr::NonNull<[u8]>, AllocError> {
323        let mut ptr: *mut c_void = ptr::null_mut();
324        unsafe {
325            CudaError::result_from_ffi(cuda_malloc_async(
326                &mut ptr as *mut *mut c_void,
327                layout.size(),
328                self.0,
329            ))
330            .map_err(|_| AllocError)?;
331        };
332        let ptr = ptr as *mut u8;
333        Ok(NonNull::slice_from_raw_parts(NonNull::new_unchecked(ptr), layout.size()))
334    }
335
336    #[inline]
337    unsafe fn deallocate(&self, ptr: NonNull<u8>, _layout: Layout) {
338        unsafe {
339            CudaError::result_from_ffi(cuda_free_async(ptr.as_ptr() as *mut c_void, self.0))
340                .unwrap()
341        }
342    }
343}
344
345impl DeviceMemory for CudaStream {
346    #[inline]
347    unsafe fn copy_nonoverlapping(
348        &self,
349        src: *const u8,
350        dst: *mut u8,
351        size: usize,
352        direction: CopyDirection,
353    ) -> Result<(), CopyError> {
354        let maybe_err = match direction {
355            CopyDirection::HostToDevice => cuda_mem_copy_host_to_device_async(
356                dst as *mut c_void,
357                src as *const c_void,
358                size,
359                self.0,
360            ),
361            CopyDirection::DeviceToHost => cuda_mem_copy_device_to_host_async(
362                dst as *mut c_void,
363                src as *const c_void,
364                size,
365                self.0,
366            ),
367            CopyDirection::DeviceToDevice => cuda_mem_copy_device_to_device_async(
368                dst as *mut c_void,
369                src as *const c_void,
370                size,
371                self.0,
372            ),
373        };
374        CudaError::result_from_ffi(maybe_err).map_err(|_| CopyError)
375    }
376
377    #[inline]
378    unsafe fn write_bytes(&self, dst: *mut u8, value: u8, size: usize) -> Result<(), CopyError> {
379        unsafe {
380            CudaError::result_from_ffi(cuda_mem_set_async(dst as *mut c_void, value, size, self.0))
381                .map_err(|_| CopyError)
382        }
383    }
384}
385
386#[derive(Debug, PartialEq, Eq, Hash)]
387pub struct UnsafeCudaStream(CudaStream);
388
389impl UnsafeCudaStream {
390    #[allow(dead_code)]
391    pub fn create() -> Result<Self, CudaError> {
392        Ok(Self(CudaStream::create()?))
393    }
394}
395
396impl Deref for UnsafeCudaStream {
397    type Target = CudaStream;
398
399    fn deref(&self) -> &Self::Target {
400        &self.0
401    }
402}
403
404impl StreamRef for UnsafeCudaStream {
405    #[inline]
406    unsafe fn stream(&self) -> &CudaStream {
407        &self.0
408    }
409}
410
411impl IntoFuture for UnsafeCudaStream {
412    type Output = Result<(), CudaError>;
413    type IntoFuture = StreamCallbackFuture<Self>;
414
415    fn into_future(self) -> Self::IntoFuture {
416        StreamCallbackFuture::new(self)
417    }
418}
419
420unsafe impl Allocator for UnsafeCudaStream {
421    #[inline]
422    unsafe fn allocate(&self, layout: Layout) -> Result<ptr::NonNull<[u8]>, AllocError> {
423        self.0.allocate(layout)
424    }
425
426    #[inline]
427    unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
428        self.0.deallocate(ptr, layout)
429    }
430}
431
432impl DeviceMemory for UnsafeCudaStream {
433    #[inline]
434    unsafe fn copy_nonoverlapping(
435        &self,
436        src: *const u8,
437        dst: *mut u8,
438        size: usize,
439        direction: CopyDirection,
440    ) -> Result<(), CopyError> {
441        self.0.copy_nonoverlapping(src, dst, size, direction)
442    }
443
444    #[inline]
445    unsafe fn write_bytes(&self, dst: *mut u8, value: u8, size: usize) -> Result<(), CopyError> {
446        self.0.write_bytes(dst, value, size)
447    }
448}