async_cuda_core/runtime/future.rs

use std::pin::Pin;
use std::sync::{Arc, Condvar, Mutex};
use std::task::{Context, Poll, Waker};

use crate::error::Error;
use crate::runtime::thread_local::RUNTIME_THREAD_LOCAL;
use crate::runtime::work::Work;
use crate::stream::Stream;

type Result<T> = std::result::Result<T, Error>;

/// Represents a closure that can be executed in the runtime.
pub type Closure<'closure> = Box<dyn FnOnce() + Send + 'closure>;

/// Future for CUDA operations.
///
/// Note that this future abstracts over two different asynchronous primitives: dedicated-thread
/// semantics, and stream asynchrony.
///
/// # Dedicated-thread semantics
///
/// In this crate, all operations that use CUDA internally are off-loaded to a dedicated thread (the
/// runtime). This improves CUDA's ability to parallelize without being interrupted by the OS
/// scheduler or being affected by starvation when under load.
///
/// # Stream asynchrony
///
/// CUDA has internal asynchrony as well: many CUDA operations are asynchronous with respect to
/// the host, relative to the stream they are bound to.
///
/// It is important to understand that most of the operations in this crate do *NOT* actually wait
/// for the CUDA asynchronous operation to complete. Instead, the operation is started and then the
/// future becomes ready. This means that the caller must still synchronize the underlying CUDA
/// stream.
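///
/// A sketch of that pattern (illustrative only: `some_device_operation` and the `synchronize`
/// call are stand-ins for the corresponding stream-bound operations in this crate):
///
/// ```ignore
/// let stream = Stream::new().await?;
/// // This only *starts* the operation on the stream; the future becomes ready right away.
/// some_device_operation(&stream).await?;
/// // Wait for all work on the stream, including the operation above, to actually finish.
/// stream.synchronize().await?;
/// ```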
///
/// # Usage
///
/// To create a [`Future`], move the closure into it with `Future::new`:
///
/// ```
/// # use async_cuda_core::runtime::Future;
/// # tokio_test::block_on(async {
/// let future = Future::new(move || {
///     ()
/// });
/// let return_value = future.await;
/// assert_eq!(return_value, ());
/// # })
/// ```
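///
/// The closure may also borrow from the enclosing scope and return a value; the future guarantees
/// that the closure does not outlive it (see the [`Drop`] implementation below):
///
/// ```
/// # use async_cuda_core::runtime::Future;
/// # tokio_test::block_on(async {
/// let values = vec![1, 2, 3];
/// let sum = Future::new(|| values.iter().sum::<i32>()).await;
/// assert_eq!(sum, 6);
/// # })
/// ```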
pub struct Future<'closure, T> {
    shared: Arc<Mutex<Shared<'closure, T>>>,
    completed: Arc<Condvar>,
    _phantom: std::marker::PhantomData<&'closure ()>,
}

impl<'closure, T> Future<'closure, T> {
    /// Wrap the provided function in this future. It will be sent to the runtime thread and
    /// executed there. The future resolves once the call on the runtime completes.
    ///
    /// # Arguments
    ///
    /// * `call` - Closure that contains the relevant function call.
    ///
    /// # Example
    ///
    /// ```
    /// # use async_cuda_core::runtime::Future;
    /// # tokio_test::block_on(async {
    /// let return_value = Future::new(|| ()).await;
    /// assert_eq!(return_value, ());
    /// # })
    /// ```
    #[inline]
    pub fn new<F>(call: F) -> Self
    where
        F: FnOnce() -> T + Send + 'closure,
        T: Send + 'closure,
    {
        let shared = Arc::new(Mutex::new(Shared::new()));
        let completed = Arc::new(Condvar::new());
        let closure = Box::new({
            let shared = shared.clone();
            let completed = completed.clone();
            move || {
                let return_value = call();
                let mut shared = shared.lock().unwrap();
                match shared.state {
                    State::Running => {
                        shared.complete(return_value);
                        // If the future was cancelled before the function finished, the drop
                        // function is now waiting for us to finish. Notify it here.
                        completed.notify_all();
                        // If the future is still active, then this will wake the executor and
                        // cause it to poll the future again. Since we changed the state to
                        // `State::Completed`, the future will return a result.
                        if let Some(waker) = shared.waker.take() {
                            waker.wake();
                        }
                    }
                    _ => {
                        panic!("unexpected state");
                    }
                }
            }
        });

        shared.lock().unwrap().initialize(closure);

        Self {
            shared,
            completed,
            _phantom: Default::default(),
        }
    }
}

impl<'closure, T> std::future::Future for Future<'closure, T> {
    type Output = T;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let mut shared = self.shared.lock().unwrap();
        match shared.state {
            State::New => Poll::Pending,
            // This is the first time that the future is polled. We take the function out and
            // enqueue it on the runtime, then change the state from `Initialized` to `Running`.
            State::Initialized => {
                shared.running(cx.waker().clone());
                let closure: Box<dyn FnOnce() + Send + 'closure> =
                    shared.closure.take().expect("initialized without function");
                let closure: Box<dyn FnOnce() + Send + 'static> = unsafe {
                    // SAFETY: This is safe because in `drop` we make sure to wait for the runtime
                    // thread closure to complete if it still exists. This ensures that the closure
                    // cannot outlive this future object. Because of this, we can simply erase the
                    // `'closure` lifetime bound here and pretend it is `'static`.
                    std::mem::transmute(closure)
                };
                RUNTIME_THREAD_LOCAL.with(|runtime| {
                    runtime.enqueue(Work::new(closure)).expect("runtime broken");
                });

                Poll::Pending
            }
            // The future is still running.
            State::Running => Poll::Pending,
            // The future has completed and a return value is available. We take out the return
            // value and change the state from `Completed` to `Done`.
            State::Completed => {
                shared.done();
                Poll::Ready(shared.return_value.take().unwrap())
            }
            // It is illegal to poll a future again after it has already returned a ready value.
            State::Done => {
                panic!("future polled after completion");
            }
        }
    }
}

impl<'closure, T> Drop for Future<'closure, T> {
    fn drop(&mut self) {
        let mut shared = self.shared.lock().unwrap();
        // SAFETY:
        //
        // Only if the state is `State::Running` is there a chance that the closure is currently
        // in use. In that case we must wait for it to finish because we promised that the closure
        // does not outlive the future.
        //
        // Note that no race conditions can occur here because we currently hold the lock on the
        // shared state and it is only released while waiting on the condition variable later. And
        // even then, the state is guaranteed to change only from the runtime thread, i.e. the only
        // allowed state change is `State::Running` -> `State::Completed`.
        if let State::Running = shared.state {
            // SAFETY: This is where we wait for the closure to finish on the runtime thread. Since
            // the only allowed state change at this point is `State::Running` ->
            // `State::Completed`, we only need to check for that one.
            while !matches!(shared.state, State::Completed) {
                shared = self.completed.wait(shared).unwrap();
            }
        }
    }
}

/// Future for the stream synchronization operation.
///
/// Unlike the generic [`Future`] provided by this crate, this variant only becomes ready after all
/// operations on the given stream have completed.
///
/// # Usage
///
/// ```ignore
/// let null_stream = Stream::null();
/// let result = SynchronizeFuture::new(&null_stream).await;
/// ```
pub struct SynchronizeFuture<'closure>(Future<'closure, Result<()>>);

impl<'closure> SynchronizeFuture<'closure> {
    /// Create a future that becomes ready only when all currently scheduled work on the given
    /// stream has completed.
    ///
    /// # Arguments
    ///
    /// * `stream` - Reference to stream to synchronize.
    ///
    /// # Example
    ///
    /// ```ignore
    /// let stream = Stream::new();
    /// SynchronizeFuture::new(&stream).await.unwrap();
    /// ```
    #[inline]
    pub(crate) fn new(stream: &'closure Stream) -> Self {
        let shared = Arc::new(Mutex::new(Shared::new()));
        let completed = Arc::new(Condvar::new());

        // Create a closure that will be sent to the runtime thread and then executed in the
        // dedicated thread.
        let closure = Box::new({
            let shared = shared.clone();
            let completed = completed.clone();
            move || {
                let callback = {
                    let shared = shared.clone();
                    let completed = completed.clone();
                    // Create a closure that will be executed after all work on the current CUDA
                    // stream has completed. This closure will wake the future and make it ready.
                    move || Self::complete(shared, completed, Ok(()))
                };
                if let Err(err) = stream.inner().add_callback(callback) {
                    // If for some reason CUDA can't add the callback, we must still make the
                    // future ready or it will never complete.
                    Self::complete(shared, completed, Err(err));
                }
            }
        });

        shared.lock().unwrap().initialize(closure);

        Self(Future {
            shared,
            completed,
            _phantom: Default::default(),
        })
    }

    /// Set the future's shared state to reflect that the function has completed with the given
    /// return value.
    ///
    /// # Arguments
    ///
    /// * `shared` - Closure's shared state.
    /// * `completed` - Condition variable used to notify a future that is waiting in `drop`.
    /// * `return_value` - Closure's return value.
    #[inline]
    fn complete(
        shared: Arc<Mutex<Shared<Result<()>>>>,
        completed: Arc<Condvar>,
        return_value: Result<()>,
    ) {
        if let Ok(mut shared) = shared.lock() {
            match shared.state {
                State::Running => {
                    shared.complete(return_value);
                    // If the future was cancelled before the function finished, the drop
                    // function is now waiting for us to finish. Notify it here.
                    completed.notify_all();
                    // If the future is still active, then this will wake the executor and
                    // cause it to poll the future again. Since we changed the state to
                    // `State::Completed`, the future will return a result.
                    if let Some(waker) = shared.waker.take() {
                        waker.wake();
                    }
                }
                _ => {
                    panic!("unexpected state");
                }
            }
        }
    }
}

impl<'closure> std::future::Future for SynchronizeFuture<'closure> {
    type Output = Result<()>;

    #[inline]
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        Pin::new(&mut self.0).poll(cx)
    }
}

/// Shared state between the future and the closure that is sent to the runtime.
struct Shared<'closure, T> {
    /// Current future state.
    state: State,
    /// Closure to execute on the runtime.
    closure: Option<Closure<'closure>>,
    /// Waker that can be used to wake the future.
    waker: Option<Waker>,
    /// Return value of the future.
    return_value: Option<T>,
}

/// State of the [`Future`].
///
/// The only allowed transitions are, in order:
/// `New` -> `Initialized` (via `initialize`) -> `Running` (first poll) -> `Completed` (closure
/// finished on the runtime) -> `Done` (return value taken out in `poll`).
#[derive(Debug, Copy, Clone, PartialEq)]
enum State {
    /// Future has been created but not yet polled.
    New,
    /// Future has been assigned a closure and has internal state. It has not yet been polled.
    Initialized,
    /// Future has been polled and is scheduled. It is running and the waker will wake it up at some
    /// point.
    Running,
    /// Future has completed and has a result.
    Completed,
    /// Future is done and the result has been taken out.
    Done,
}

impl<'closure, T> Shared<'closure, T> {
    /// Create new [`Future`] shared state.
    fn new() -> Self {
        Shared {
            state: State::New,
            closure: None,
            waker: None,
            return_value: None,
        }
    }

    /// Initialize state and move the function closure into the shared state.
    ///
    /// # Arguments
    ///
    /// * `closure` - The function closure. Still unscheduled at this point.
    ///
    /// # Safety
    ///
    /// This state change may only be performed from the thread that holds the future.
    #[inline]
    fn initialize(&mut self, closure: Closure<'closure>) {
        self.closure = Some(closure);
        self.state = State::Initialized;
    }

    /// Set running state and store the waker.
    ///
    /// # Arguments
    ///
    /// * `waker` - Waker that can be used by the runtime to wake the future.
    ///
    /// # Safety
    ///
    /// This state change may only be performed from the thread that holds the future.
    #[inline]
    fn running(&mut self, waker: Waker) {
        self.waker = Some(waker);
        self.state = State::Running;
    }

    /// Set completed state and store the return value.
    ///
    /// # Arguments
    ///
    /// * `return_value` - Function closure return value.
    ///
    /// # Safety
    ///
    /// This state change may only be performed from the runtime thread.
    #[inline]
    fn complete(&mut self, return_value: T) {
        self.return_value = Some(return_value);
        self.state = State::Completed;
    }

    /// Set done state.
    ///
    /// # Safety
    ///
    /// This state change may only be performed from the thread that holds the future (it happens
    /// in `poll` when the return value is taken out).
    #[inline]
    fn done(&mut self) {
        self.state = State::Done;
    }
}

#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    use super::*;

    #[tokio::test]
    async fn test_future() {
        assert!(Future::new(|| true).await);
    }

    #[tokio::test]
    async fn test_future_order() {
        let first_future_completed = Arc::new(AtomicBool::new(false));
        Future::new({
            let first_future_completed = first_future_completed.clone();
            move || {
                first_future_completed.store(true, Ordering::Relaxed);
            }
        })
        .await;
        assert!(
            Future::new({
                let first_future_completed = first_future_completed.clone();
                move || first_future_completed.load(Ordering::Relaxed)
            })
            .await
        );
    }

    #[tokio::test]
    async fn test_future_order_simple() {
        let mut first_future_completed = false;
        Future::new(|| first_future_completed = true).await;
        assert!(Future::new(|| first_future_completed).await);
    }

    #[tokio::test]
    async fn test_future_outlives_closure() {
        let mut count_completed = 0;
        let mut count_cancelled = 0;
        for _ in 0..1_000 {
            let mut start_of_closure = false;
            let mut end_of_closure = false;
            let future = Future::new(|| {
                start_of_closure = true;
                std::thread::sleep(std::time::Duration::from_millis(1));
                end_of_closure = true;
            });
            let future_with_small_delay = async {
                tokio::time::sleep(std::time::Duration::from_millis(1)).await;
                future.await
            };
            let _ =
                tokio::time::timeout(std::time::Duration::from_nanos(0), future_with_small_delay)
                    .await;
            assert!((start_of_closure && end_of_closure) || (!start_of_closure && !end_of_closure));
            if end_of_closure {
                count_completed += 1;
            } else {
                count_cancelled += 1;
            }
        }
        println!("num completed: {count_completed}");
        println!("num cancelled: {count_cancelled}");
    }

    #[tokio::test]
    async fn test_future_outlives_closure_manual() {
        let mut start_of_closure = false;
        let mut end_of_closure = false;
        let future = Future::new(|| {
            start_of_closure = true;
            std::thread::sleep(std::time::Duration::from_nanos(1000));
            end_of_closure = true;
        });
        let future_with_small_delay = async {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            future.await
        };
        let _ = tokio::time::timeout(std::time::Duration::ZERO, future_with_small_delay).await;
        assert!(!start_of_closure && !end_of_closure);
    }

    #[tokio::test]
    async fn test_future_does_not_run_if_cancelled_before_polling() {
        let mut start_of_closure = false;
        let mut end_of_closure = false;
        let future = Future::new(|| {
            start_of_closure = true;
            std::thread::sleep(std::time::Duration::from_nanos(1000));
            end_of_closure = true;
        });
        drop(future);
        assert!(!start_of_closure && !end_of_closure);
    }

    #[tokio::test]
    async fn test_synchronization_future() {
        let stream = crate::Stream::new().await.unwrap();
        assert!(SynchronizeFuture::new(&stream).await.is_ok());
    }
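
    // Same pattern as `test_future` above: check that a non-`Copy` value produced inside the
    // closure is moved out of the runtime thread and returned through the future.
    #[tokio::test]
    async fn test_future_returns_owned_value() {
        let message = Future::new(|| String::from("hello")).await;
        assert_eq!(message, "hello");
    }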
}