async_cuda_core/runtime/future.rs
1use std::pin::Pin;
2use std::sync::{Arc, Condvar, Mutex};
3use std::task::{Context, Poll, Waker};
4
5use crate::error::Error;
6use crate::runtime::thread_local::RUNTIME_THREAD_LOCAL;
7use crate::runtime::work::Work;
8use crate::stream::Stream;
9
10type Result<T> = std::result::Result<T, Error>;
11
/// Boxed, sendable, run-once unit of work that may borrow data for `'closure`.
///
/// This is the type-erased form in which a future keeps its work until it is handed to the
/// runtime.
pub type Closure<'closure> = Box<dyn FnOnce() + Send + 'closure>;
14
15/// Future for CUDA operations.
16///
17/// Note that this future abstracts over two different asynchronousprimitives: dedicated-thread
18/// semantics, and stream asynchrony.
19///
20/// # Dedicated-thread semantics
21///
22/// In this crate, all operations that use CUDA internally are off-loaded to a dedicated thread (the
23/// runtime). This improves CUDA's ability to parallelize without being interrupted by the OS
24/// scheduler or being affected by starvation when under load.
25///
26/// # Stream asynchrony
27///
28/// CUDA has internal asynchrony as well. Lots of CUDA operations are asynchronous with respect to
29/// the host with regards to the stream they are bound to.
30///
31/// It is important to understand that most of the operations in this crate do *NOT* actually wait
32/// for the CUDA asynchronous operation to complete. Instead, the operation is started and then the
33/// future becomes ready. This means that if the caller must still synchronize the underlying CUDA
34/// stream.
35///
36/// # Usage
37///
38/// To create a [`Future`], move the closure into with `Future::new`:
39///
40/// ```
41/// # use async_cuda_core::runtime::Future;
42/// # tokio_test::block_on(async {
43/// let future = Future::new(move || {
44/// ()
45/// });
46/// let return_value = future.await;
47/// assert_eq!(return_value, ());
48/// # })
49/// ```
50pub struct Future<'closure, T> {
51 shared: Arc<Mutex<Shared<'closure, T>>>,
52 completed: Arc<Condvar>,
53 _phantom: std::marker::PhantomData<&'closure ()>,
54}
55
56impl<'closure, T> Future<'closure, T> {
57 /// Wrap the provided function in this future. It will be sent to the runtime thread and
58 /// executed there. The future resolves once the call on the runtime completes.
59 ///
60 /// # Arguments
61 ///
62 /// * `call` - Closure that contains relevant function call.
63 ///
64 /// # Example
65 ///
66 /// ```
67 /// # use async_cuda_core::runtime::Future;
68 /// # tokio_test::block_on(async {
69 /// let return_value = Future::new(|| ()).await;
70 /// assert_eq!(return_value, ());
71 /// })
72 /// ```
73 #[inline]
74 pub fn new<F>(call: F) -> Self
75 where
76 F: FnOnce() -> T + Send + 'closure,
77 T: Send + 'closure,
78 {
79 let shared = Arc::new(Mutex::new(Shared::new()));
80 let completed = Arc::new(Condvar::new());
81 let closure = Box::new({
82 let shared = shared.clone();
83 let completed = completed.clone();
84 move || {
85 let return_value = call();
86 let mut shared = shared.lock().unwrap();
87 match shared.state {
88 State::Running => {
89 shared.complete(return_value);
90 // If the future was cancelled before the function finished, the drop
91 // function is now waiting for us to finish. Notify it here.
92 completed.notify_all();
93 // If the future is still active, then this will wake the executor and
94 // cause it to poll the future again. Since we changed the state to
95 // `State::Completed`, the future will return a result.
96 if let Some(waker) = shared.waker.take() {
97 waker.wake();
98 }
99 }
100 _ => {
101 panic!("unexpected state");
102 }
103 }
104 }
105 });
106
107 shared.lock().unwrap().initialize(closure);
108
109 Self {
110 shared,
111 completed,
112 _phantom: Default::default(),
113 }
114 }
115}
116
117impl<'closure, T> std::future::Future for Future<'closure, T> {
118 type Output = T;
119
120 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
121 let mut shared = self.shared.lock().unwrap();
122 match shared.state {
123 State::New => Poll::Pending,
124 // This is the first time that the future is polled. We take the function out and
125 // enqueue it on the runtime, then change the state from `Initialized` to `Running`.
126 State::Initialized => {
127 shared.running(cx.waker().clone());
128 let closure: Box<dyn FnOnce() + Send + 'closure> =
129 shared.closure.take().expect("initialized without function");
130 let closure: Box<dyn FnOnce() + Send + 'static> = unsafe {
131 // SAFETY: This is safe because in `drop` we make sure to wait for the runtime
132 // thread closure to complete if it still exists. This ensures that the closure
133 // cannot outlive this future object. Because of this, we can simply erase the
134 // `'closure` lifetime bound here and pretend it is `'static`.
135 std::mem::transmute(closure)
136 };
137 RUNTIME_THREAD_LOCAL.with(|runtime| {
138 runtime.enqueue(Work::new(closure)).expect("runtime broken");
139 });
140
141 Poll::Pending
142 }
143 // The future is still running.
144 State::Running => Poll::Pending,
145 // The future has completed and a return value is available. We take out the return
146 // value and change the state from `Completed` to `Done`.
147 State::Completed => {
148 shared.done();
149 Poll::Ready(shared.return_value.take().unwrap())
150 }
151 // It is illegal to poll a future after it has become ready before.
152 State::Done => {
153 panic!("future polled after completion");
154 }
155 }
156 }
157}
158
159impl<'closure, T> Drop for Future<'closure, T> {
160 fn drop(&mut self) {
161 let mut shared = self.shared.lock().unwrap();
162 // SAFETY:
163 //
164 // Only if the state is `State::Running` there is a chance that the closure is currently
165 // used and active. In that case we must wait for it to finish because we promised that the
166 // closure outlives the future.
167 //
168 // Note that no race conditions can occur here because we currently have the lock on the
169 // state and it is only released when waiting for the condition variable later. And even
170 // after, the state is guaranteed only to change from the runtime thread i.e. the only
171 // allowed state change is `State::Running` -> `State::Completed`.
172 if let State::Running = shared.state {
173 // SAFETY: This is where we wait for the closure to finish on the runtime thread. Since
174 // the only allowed state change at this point it `State::Running` ->
175 // `State::Completed`, we only need to check for that one.
176 while !matches!(shared.state, State::Completed) {
177 shared = self.completed.wait(shared).unwrap();
178 }
179 }
180 }
181}
182
183/// Future for the stream synchronization operation.
184///
185/// Unlike the generic [`Future`] provided by this crate, this variant only becomes ready after all
186/// operations on the given stream have completed.
187///
188/// # Usage
189///
190/// ```ignore
191/// let null_stream = Stream::null();
192/// let result = SynchronizeFuture::new(&null_stream).await;
193/// ```
194pub struct SynchronizeFuture<'closure>(Future<'closure, Result<()>>);
195
196impl<'closure> SynchronizeFuture<'closure> {
197 /// Create future that becomes ready only when all currently scheduled work on the given stream
198 /// has completed.
199 ///
200 /// # Arguments
201 ///
202 /// * `stream` - Reference to stream to synchronize.
203 ///
204 /// # Example
205 ///
206 /// ```ignore
207 /// let stream = Stream::new();
208 /// SynchronizeFuture::new(&stream).await.unwrap();
209 /// ```
210 #[inline]
211 pub(crate) fn new(stream: &'closure Stream) -> Self {
212 let shared = Arc::new(Mutex::new(Shared::new()));
213 let completed = Arc::new(Condvar::new());
214
215 // Create a closure that will be sent to the runtime thread and then executed in the
216 // dedicated thread.
217 let closure = Box::new({
218 let shared = shared.clone();
219 let completed = completed.clone();
220 move || {
221 let callback = {
222 let shared = shared.clone();
223 let completed = completed.clone();
224 // Create a closure that will be executed after all work on the current CUDA
225 // stream has completed. This closure will wake the future and make it ready.
226 move || Self::complete(shared, completed, Ok(()))
227 };
228 if let Err(err) = stream.inner().add_callback(callback) {
229 // If for some reason CUDA can't add the callback, we must still ready the
230 // future or it will never complete.
231 Self::complete(shared, completed, Err(err));
232 }
233 }
234 });
235
236 shared.lock().unwrap().initialize(closure);
237
238 Self(Future {
239 shared,
240 completed,
241 _phantom: Default::default(),
242 })
243 }
244
245 /// Set the future's shared state to reflect that the function has completed with the given
246 /// return value.
247 ///
248 /// # Arguments
249 ///
250 /// * `shared` - Closure's shared state.
251 /// * `return_value` - Closure's return value.
252 #[inline]
253 fn complete(
254 shared: Arc<Mutex<Shared<Result<()>>>>,
255 completed: Arc<Condvar>,
256 return_value: Result<()>,
257 ) {
258 if let Ok(mut shared) = shared.lock() {
259 match shared.state {
260 State::Running => {
261 shared.complete(return_value);
262 // If the future was cancelled before the function finished, the drop
263 // function is now waiting for us to finish. Notify it here.
264 completed.notify_all();
265 // If the future is still active, then this will wake the executor and
266 // cause it to poll the future again. Since we changed the state to
267 // `State::Completed`, the future will return a result.
268 if let Some(waker) = shared.waker.take() {
269 waker.wake();
270 }
271 }
272 _ => {
273 panic!("unexpected state");
274 }
275 }
276 }
277 }
278}
279
280impl<'closure> std::future::Future for SynchronizeFuture<'closure> {
281 type Output = Result<()>;
282
283 #[inline]
284 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
285 Pin::new(&mut self.0).poll(cx)
286 }
287}
288
289/// Share state between the future and the closure that is sent over to the runtime.
290struct Shared<'closure, T> {
291 /// Current future state.
292 state: State,
293 /// Closure to execute on runtime.
294 closure: Option<Closure<'closure>>,
295 /// Waker that can be used to wake the future.
296 waker: Option<Waker>,
297 /// Return value of future.
298 return_value: Option<T>,
299}
300
/// Lifecycle of a [`Future`], stored in its shared state.
///
/// Transitions only move forward: `New` -> `Initialized` -> `Running` -> `Completed` -> `Done`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum State {
    /// Shared state has been created, but no closure has been assigned to it yet.
    New,
    /// A closure has been assigned, but the future has not been polled, so the closure has not
    /// been sent to the runtime yet.
    Initialized,
    /// The future has been polled and its closure was sent to the runtime. The stored waker will
    /// be woken once the work completes.
    Running,
    /// The work has completed and a result is waiting to be taken.
    Completed,
    /// The result has been taken out; the future must not be polled again.
    Done,
}
315
316impl<'closure, T> Shared<'closure, T> {
317 /// Create new [`Future`] shared state.
318 fn new() -> Self {
319 Shared {
320 state: State::New,
321 closure: None,
322 waker: None,
323 return_value: None,
324 }
325 }
326
327 /// Initialize state and move function closure into shared state.
328 ///
329 /// # Arguments
330 ///
331 /// * `closure` - The function closure. Stil unscheduled at this point.
332 ///
333 /// # Safety
334 ///
335 /// This state change may only be performed from the thread that holds the future.
336 #[inline]
337 fn initialize(&mut self, closure: Closure<'closure>) {
338 self.closure = Some(closure);
339 self.state = State::Initialized;
340 }
341
342 /// Set running state and store waker.
343 ///
344 /// # Arguments
345 ///
346 /// * `waker` - Waker that can be used by runtime to wake future.
347 ///
348 /// # Safety
349 ///
350 /// This state change may only be performed from the thread that holds the future.
351 #[inline]
352 fn running(&mut self, waker: Waker) {
353 self.waker = Some(waker);
354 self.state = State::Running;
355 }
356
357 /// Complete state and set return value.
358 ///
359 /// # Arguments
360 ///
361 /// * `return_value` - Function closure return value.
362 ///
363 /// # Safety
364 ///
365 /// This state change may only be performed from the runtime thread.
366 #[inline]
367 fn complete(&mut self, return_value: T) {
368 self.return_value = Some(return_value);
369 self.state = State::Completed;
370 }
371
372 /// Set done state.
373 ///
374 /// # Safety
375 ///
376 /// This state change may only be performed from the runtime thread.
377 #[inline]
378 fn done(&mut self) {
379 self.state = State::Done;
380 }
381}
382
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    use super::*;

    #[tokio::test]
    async fn test_future() {
        assert!(Future::new(|| true).await);
    }

    #[tokio::test]
    async fn test_future_order() {
        let first_future_completed = Arc::new(AtomicBool::new(false));
        Future::new({
            let first_future_completed = first_future_completed.clone();
            move || {
                first_future_completed.store(true, Ordering::Relaxed);
            }
        })
        .await;
        assert!(
            Future::new({
                let first_future_completed = first_future_completed.clone();
                move || first_future_completed.load(Ordering::Relaxed)
            })
            .await
        );
    }

    #[tokio::test]
    async fn test_future_order_simple() {
        let mut first_future_completed = false;
        Future::new(|| first_future_completed = true).await;
        assert!(Future::new(|| first_future_completed).await);
    }

    #[tokio::test]
    async fn test_future_outlives_closure() {
        // NOTE: the wrapper starts by awaiting a 1 ms sleep, and the zero timeout gives up after
        // polling it once. The CUDA future is therefore normally never polled here, so this
        // mostly exercises the "dropped before first poll" path.
        // `test_future_drop_waits_for_running_closure` covers "dropped while running".
        let mut count_completed = 0;
        let mut count_cancelled = 0;
        for _ in 0..1_000 {
            let mut start_of_closure = false;
            let mut end_of_closure = false;
            let future = Future::new(|| {
                start_of_closure = true;
                std::thread::sleep(std::time::Duration::from_millis(1));
                end_of_closure = true;
            });
            let future_with_small_delay = async {
                tokio::time::sleep(std::time::Duration::from_millis(1)).await;
                future.await
            };
            let _ =
                tokio::time::timeout(std::time::Duration::from_nanos(0), future_with_small_delay)
                    .await;
            assert!((start_of_closure && end_of_closure) || (!start_of_closure && !end_of_closure));
            if end_of_closure {
                count_completed += 1;
            } else {
                count_cancelled += 1;
            }
        }
        println!("num completed: {count_completed}");
        println!("num cancelled: {count_cancelled}");
    }

    #[tokio::test]
    async fn test_future_outlives_closure_manual() {
        let mut start_of_closure = false;
        let mut end_of_closure = false;
        let future = Future::new(|| {
            start_of_closure = true;
            std::thread::sleep(std::time::Duration::from_nanos(1000));
            end_of_closure = true;
        });
        let future_with_small_delay = async {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            future.await
        };
        let _ = tokio::time::timeout(std::time::Duration::ZERO, future_with_small_delay).await;
        assert!(!start_of_closure && !end_of_closure)
    }

    #[tokio::test]
    async fn test_future_does_not_run_if_cancelled_before_polling() {
        let mut start_of_closure = false;
        let mut end_of_closure = false;
        let future = Future::new(|| {
            start_of_closure = true;
            std::thread::sleep(std::time::Duration::from_nanos(1000));
            end_of_closure = true;
        });
        drop(future);
        assert!(!start_of_closure && !end_of_closure)
    }

    #[tokio::test]
    async fn test_future_drop_waits_for_running_closure() {
        let mut start_of_closure = false;
        let mut end_of_closure = false;
        let mut future = Future::new(|| {
            start_of_closure = true;
            std::thread::sleep(std::time::Duration::from_millis(10));
            end_of_closure = true;
        });
        // A zero timeout still polls `future` once. That sends the closure to the runtime
        // thread; the timeout then gives up because its deadline has already passed.
        let _ = tokio::time::timeout(std::time::Duration::ZERO, &mut future).await;
        // Dropping the future while the closure is (most likely) still sleeping must block until
        // the closure is done. Otherwise the closure would outlive the borrowed flags.
        drop(future);
        assert!(start_of_closure && end_of_closure);
    }

    #[tokio::test]
    async fn test_synchronization_future() {
        let stream = crate::Stream::new().await.unwrap();
        assert!(SynchronizeFuture::new(&stream).await.is_ok());
    }
}