Skip to main content

sp1_gpu_cudart/
task.rs

1use std::{
2    alloc::Layout,
3    ffi::c_void,
4    future::{Future, IntoFuture},
5    mem::MaybeUninit,
6    ops::Deref,
7    pin::Pin,
8    ptr::{self, NonNull},
9    sync::{
10        atomic::{AtomicUsize, Ordering},
11        Arc, OnceLock, Weak,
12    },
13    task::{Context, Poll},
14    time::Duration,
15};
16
17use futures::{future::MapOkOrElse, TryFutureExt};
18use pin_project::pin_project;
19use slop_alloc::{
20    mem::{CopyDirection, CopyError, DeviceMemory},
21    AllocError, Allocator, Backend, Buffer, Slice,
22};
23use slop_futures::queue::{AcquireWorkerError, TryAcquireWorkerError, Worker, WorkerQueue};
24use sp1_gpu_sys::runtime::{
25    cuda_device_get_default_mem_pool, cuda_mem_pool_set_release_threshold, CudaDevice, CudaMemPool,
26    CudaStreamHandle, Dim3, KernelPtr,
27};
28use thiserror::Error;
29use tokio::{sync::oneshot, task::JoinHandle};
30
31use crate::{DeviceCopy, ToDevice};
32
33use super::{
34    stream::{StreamRef, INTERVAL_MS},
35    sync::CudaSend,
36    CudaError, CudaEvent, CudaStream, IntoDevice, StreamCallbackFuture,
37};
38
/// Number of tasks a pool is built with when the builder was given no explicit capacity.
const DEFAULT_NUM_TASKS: usize = 64;

/// Process-wide task pool. It is filled either lazily by `global_task_pool` or explicitly by
/// `TaskPoolBuilder::build_global`, whichever happens first.
static GLOBAL_TASK_POOL: OnceLock<Arc<TaskPool>> = OnceLock::new();

/// Counter from which every built pool draws its identifier.
static POOL_ID: AtomicUsize = AtomicUsize::new(0);
44
/// Builder for a [`TaskPool`].
pub struct TaskPoolBuilder {
    /// Device whose default memory pool receives the release threshold at build time.
    device: CudaDevice,
    /// Release threshold applied to the device's default memory pool at build time.
    mem_release_threshold: u64,
    /// Requested number of tasks; `None` falls back to `DEFAULT_NUM_TASKS`.
    capacity: Option<usize>,
}
50
51pub(crate) fn global_task_pool() -> &'static Arc<TaskPool> {
52    GLOBAL_TASK_POOL.get_or_init(|| Arc::new(TaskPoolBuilder::new().build().unwrap()))
53}
54
/// Handle to a task spawned onto the tokio runtime through a task pool.
///
/// Converting the handle into a future (it implements `IntoFuture`) yields the task's result.
pub struct SpawnHandle<T> {
    /// Join handle of the tokio task that drives the spawned work.
    handle: JoinHandle<Result<T, CudaError>>,
}
58
59impl<T> SpawnHandle<T> {
60    pub fn abort(&self) {
61        self.handle.abort();
62    }
63}
64
/// Errors produced when awaiting a [`SpawnHandle`].
#[derive(Debug, Error)]
pub enum SpawnError {
    /// The tokio task backing the handle could not be joined.
    #[error("join handle panicked with error: {0}")]
    JoinError(#[from] tokio::task::JoinError),
    /// The spawned work completed with a CUDA error.
    #[error("cuda error: {0}")]
    CudaError(#[from] CudaError),
    /// A task could not be acquired from the pool.
    #[error("failed to acquire a task from the pool")]
    TaskSpawnError(#[from] TaskSpawnError),
}
74
75fn map_ok_value<T>(e: Result<T, CudaError>) -> Result<T, SpawnError> {
76    e.map_err(SpawnError::CudaError)
77}
78
79fn map_err_value<T>(e: tokio::task::JoinError) -> Result<T, SpawnError> {
80    Err(SpawnError::JoinError(e))
81}
82
83impl<T> IntoFuture for SpawnHandle<T> {
84    type Output = Result<T, SpawnError>;
85
86    type IntoFuture = MapOkOrElse<
87        JoinHandle<Result<T, CudaError>>,
88        fn(Result<T, CudaError>) -> Result<T, SpawnError>,
89        fn(tokio::task::JoinError) -> Result<T, SpawnError>,
90    >;
91
92    fn into_future(self) -> Self::IntoFuture {
93        self.handle.map_ok_or_else(map_err_value, map_ok_value)
94    }
95}
96
97pub fn spawn<F, Fut>(f: F) -> SpawnHandle<Fut::Output>
98where
99    F: FnOnce(TaskScope) -> Fut + Send + 'static,
100    Fut: Future + Send + 'static,
101    Fut::Output: Send + 'static,
102{
103    let pool = global_task_pool();
104    pool.spawn(f)
105}
106
107/// Run a task on the task pool.
108///
109/// The future returned by this function will wait for the task to finish.
110pub async fn run_in_place<F, Fut, R>(f: F) -> TaskHandle<R>
111where
112    F: FnOnce(TaskScope) -> Fut,
113    Fut: Future<Output = R>,
114{
115    let pool = global_task_pool();
116    pool.run(f).await
117}
118
119/// Run a task on the task pool.
120///
121/// The future returned by this function will wait for the task to finish.
122pub fn run_sync_in_place<F, R>(f: F) -> Result<R, CudaError>
123where
124    F: FnOnce(TaskScope) -> R,
125{
126    let pool = global_task_pool();
127    pool.run_sync(f)
128}
129
/// Errors that can occur while building a [`TaskPool`].
#[derive(Debug, Clone, Error)]
pub enum TaskPoolBuildError {
    /// Creating the CUDA stream for one of the pool's tasks failed.
    #[error("failed to create CUDA stream: {0}")]
    StreamCreationFailed(CudaError),

    /// Creating the end event for one of the pool's tasks failed.
    #[error("failed to create CUDA event: {0}")]
    EventCreationFailed(CudaError),

    /// A task could not be pushed back into the pool.
    #[error("failed to push task back into pool")]
    PushTaskFailed,
}
141
/// Errors that can occur while installing the global task pool.
#[derive(Debug, Clone, Error)]
pub enum GlobalTaskPoolBuildError {
    /// Building the pool itself failed.
    #[error("failed to build global task pool")]
    BuildFailed(#[from] TaskPoolBuildError),
    /// A global task pool had already been set.
    #[error("global task pool already initialized")]
    AlreadyInitialized,
}
149
150impl TaskPoolBuilder {
151    pub fn new() -> Self {
152        Self { capacity: None, device: CudaDevice(0), mem_release_threshold: u64::MAX }
153    }
154
155    pub fn num_tasks(mut self, num_tasks: usize) -> Self {
156        self.capacity = Some(num_tasks);
157        self
158    }
159
160    pub fn device(mut self, device: CudaDevice) -> Self {
161        assert!(device.0 == 0, "only device 0 is supported at the moment");
162        self.device = device;
163        self
164    }
165
166    /// Sets the memory release threshold for the associated device.
167    ///
168    /// # Warning
169    /// This setting will affect the memory release threshold for the entire device, not just the
170    /// current task pool being built.
171    pub fn mem_release_threshold(mut self, threshold: u64) -> Self {
172        self.mem_release_threshold = threshold;
173        self
174    }
175
176    fn allocate_new_id(&self) -> usize {
177        let id = POOL_ID.fetch_add(1, Ordering::Relaxed);
178        if id > usize::MAX / 2 {
179            std::process::abort();
180        }
181        id
182    }
183
184    pub fn build(self) -> Result<TaskPool, TaskPoolBuildError> {
185        let id = self.allocate_new_id();
186        let num_tasks = self.capacity.unwrap_or(DEFAULT_NUM_TASKS);
187
188        // Set the memory release threshold
189        unsafe {
190            let mut mem_pool = CudaMemPool(ptr::null_mut());
191            CudaError::result_from_ffi(cuda_device_get_default_mem_pool(
192                &mut mem_pool,
193                self.device,
194            ))
195            .unwrap();
196            CudaError::result_from_ffi(cuda_mem_pool_set_release_threshold(
197                mem_pool,
198                self.mem_release_threshold,
199            ))
200            .unwrap();
201        };
202
203        let mut tasks = Vec::with_capacity(num_tasks);
204        for (i, _) in (0..num_tasks).enumerate() {
205            let stream = CudaStream::create().map_err(TaskPoolBuildError::StreamCreationFailed)?;
206            let end_event = CudaEvent::create().map_err(TaskPoolBuildError::EventCreationFailed)?;
207            tasks.push(Task { owner_id: id, id: i, stream, end_event });
208        }
209        let inner = Arc::new(WorkerQueue::new(tasks));
210
211        Ok(TaskPool { inner })
212    }
213
214    pub fn build_global(self) -> Result<(), GlobalTaskPoolBuildError> {
215        let pool = self.build()?;
216        GLOBAL_TASK_POOL
217            .set(Arc::new(pool))
218            .map_err(|_| GlobalTaskPoolBuildError::AlreadyInitialized)
219    }
220}
221
222impl Default for TaskPoolBuilder {
223    fn default() -> Self {
224        Self::new()
225    }
226}
227
/// A pool of tasks, each owning a CUDA stream and an end event.
///
/// Cloning the pool clones the shared handle to the same underlying queue.
#[derive(Debug, Clone)]
pub struct TaskPool {
    /// Shared queue from which tasks are acquired.
    inner: Arc<WorkerQueue<Task>>,
}
232
/// A task that has been acquired from a pool's queue.
struct OwnedTask {
    /// The queue worker wrapping the acquired [`Task`].
    inner: Worker<Task>,
}
236
237impl std::fmt::Debug for OwnedTask {
238    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
239        write!(f, "OwnedTask {{ inner: {:?} }}", self.inner.deref())
240    }
241}
242
/// Error returned when waiting for a task from the pool's queue fails.
#[derive(Debug, Error)]
#[error("failed to acquire a task from the pool")]
pub enum TaskSpawnError {
    /// The queue reported an error while a worker was being acquired.
    AcquireError(#[from] AcquireWorkerError),
}
248
/// Error returned when a non-waiting attempt to take a task from the pool's queue fails.
#[derive(Debug, Error)]
#[error("failed to acquire a task from the pool")]
pub enum TrySpawnError {
    /// The queue's `try_pop` reported an error.
    TryAcquireError(#[from] TryAcquireWorkerError),
}
254
255impl TaskPool {
256    async fn task(inner: Arc<WorkerQueue<Task>>) -> Result<OwnedTask, TaskSpawnError> {
257        let worker = inner.clone().pop().await.map_err(TaskSpawnError::AcquireError)?;
258        Ok(OwnedTask { inner: worker })
259    }
260
261    fn try_task(inner: Arc<WorkerQueue<Task>>) -> Result<OwnedTask, TrySpawnError> {
262        let worker = inner.clone().try_pop().map_err(TrySpawnError::TryAcquireError)?;
263        Ok(OwnedTask { inner: worker })
264    }
265
266    /// Spawn a task on the task pool.
267    ///
268    /// This function will not block the current thread.
269    pub fn spawn<F, Fut>(&self, f: F) -> SpawnHandle<Fut::Output>
270    where
271        F: FnOnce(TaskScope) -> Fut + Send + 'static,
272        Fut: Future + Send + 'static,
273        Fut::Output: Send + 'static,
274    {
275        let queue = self.inner.clone();
276        let handle = tokio::spawn(async move {
277            let task = TaskPool::task(queue).await.expect("failed to acquire a task from the pool");
278            task.run(f).await.await
279        });
280        SpawnHandle { handle }
281    }
282
283    pub fn spawn_blocking<F, R>(&self, f: F) -> SpawnHandle<R>
284    where
285        F: FnOnce(TaskScope) -> R + Send + 'static,
286        R: Send + 'static,
287    {
288        let queue = self.inner.clone();
289        let handle = tokio::task::spawn_blocking(move || {
290            let task = TaskPool::try_task(queue).expect("failed to acquire a task from the pool");
291            let task = Arc::new(task);
292            task.run_sync(f)
293        });
294        SpawnHandle { handle }
295    }
296
297    /// Run a task on the task pool.
298    ///
299    /// The future returned by this function will wait for the task to finish.
300    pub async fn run<F, Fut, R>(&self, f: F) -> TaskHandle<R>
301    where
302        F: FnOnce(TaskScope) -> Fut,
303        Fut: Future<Output = R>,
304    {
305        let queue = self.inner.clone();
306        let task = TaskPool::task(queue).await.expect("failed to acquire a task from the pool");
307        task.run(f).await
308    }
309
310    pub fn run_sync<F, R>(&self, f: F) -> Result<R, CudaError>
311    where
312        F: FnOnce(TaskScope) -> R,
313    {
314        let queue = self.inner.clone();
315        let task = TaskPool::try_task(queue).expect("failed to acquire a task from the pool");
316        let task = Arc::new(task);
317        task.run_sync(f)
318    }
319}
320
/// Handle handed to task closures, giving access to the task's stream.
///
/// It holds only a weak reference to the [`OwnedTask`] it was created from.
#[derive(Debug)]
pub struct TaskScope(Weak<OwnedTask>);
323
324impl Clone for TaskScope {
325    fn clone(&self) -> Self {
326        TaskScope(self.0.clone())
327    }
328}
329
330impl Deref for TaskScope {
331    type Target = Task;
332
333    #[inline]
334    fn deref(&self) -> &Self::Target {
335        unsafe { &(*self.0.as_ptr()).inner }
336    }
337}
338
// Marks `TaskScope` as a `slop_alloc` backend.
// NOTE(review): the safety requirements of `Backend` are defined in `slop_alloc` and are not
// visible in this file — confirm they hold for `TaskScope`.
unsafe impl Backend for TaskScope {}
340
/// Host callback that sleeps the calling thread for a boxed [`Duration`].
///
/// # Safety
/// `ptr` must have been produced by `Box::into_raw` on a `Box<Duration>` and must not be used
/// again afterwards: the box is reclaimed and freed by this call.
unsafe extern "C" fn sleep(ptr: *mut c_void) {
    let boxed: Box<Duration> = unsafe { Box::from_raw(ptr.cast::<Duration>()) };
    let duration = *boxed;
    std::thread::sleep(duration);
}
345
346unsafe extern "C" fn sync_host(ptr: *mut c_void) {
347    let tx = unsafe { Box::from_raw(ptr as *mut oneshot::Sender<bool>) };
348    tx.send(true).unwrap();
349}
350
351impl TaskScope {
352    /// Allocates a buffer in this scope on the device.
353    ///
354    /// This call is not blocking. Upon successful completion, it will return a buffer with a memory
355    /// that is guaranteed to be available in the scope of the task but without any absolute
356    /// guarantee relative to the host or any other task.
357    ///
358    /// Other tasks may try to allocate memory concurrently. In order to guarantee enough memory
359    /// for all expected work, the user must ensure some limit on task calls by e.g. using a
360    /// semaphore.
361    #[inline]
362    pub fn alloc<T>(&self, capacity: usize) -> Buffer<T, Self> {
363        Buffer::with_capacity_in(capacity, self.clone())
364    }
365
366    /// Tries to allocate a buffer in this scope on the device.
367    #[inline]
368    pub fn try_alloc<T>(
369        &self,
370        capacity: usize,
371    ) -> Result<Buffer<T, Self>, slop_alloc::TryReserveError> {
372        Buffer::try_with_capacity_in(capacity, self.clone())
373    }
374
375    /// Launches a host function in this task.
376    ///
377    /// # Safety
378    ///
379    /// The function essentially executes an extern call in `C`. The safety assumption of an extern.  
380    /// The user must ensure the pointer is valid and that the data remains valid as this call will
381    /// be asynchronous.
382    #[inline]
383    pub unsafe fn launch_host_fn(
384        &self,
385        host_fn: unsafe extern "C" fn(*mut c_void),
386        data: *mut c_void,
387    ) -> Result<(), CudaError> {
388        self.launch_host_fn_uncheked(Some(host_fn), data)
389    }
390
391    /// Launches a kernel in this task.
392    ///
393    /// # Safety
394    /// The caller must ensure that:
395    /// - The kernel ptr is valid.
396    /// - The arguments are passed correctly across the FFI interface.
397    /// - The data lives whitin the scope of the current task.
398    pub unsafe fn launch_kernel(
399        &self,
400        kernel: KernelPtr,
401        grid_dim: impl Into<Dim3>,
402        block_dim: impl Into<Dim3>,
403        args: &[*mut c_void],
404        shared_mem: usize,
405    ) -> Result<(), CudaError> {
406        self.stream().launch_kernel(kernel, grid_dim, block_dim, args, shared_mem)
407    }
408
409    /// Sends the CUDA task to sleep for **at least** the given duration.
410    ///
411    /// This function will not block the calling host thread. The function does a small allocation
412    /// so the sleep time might be slightly longer than the given duration for very short times.
413    pub fn sleep(&self, time: Duration) {
414        let time_ptr = Box::into_raw(Box::new(time));
415        unsafe {
416            self.launch_host_fn(sleep, time_ptr as *mut c_void).unwrap();
417        }
418    }
419
420    /// Copies data between slices using CudaMemCpyAsync
421    ///
422    /// # Safety
423    /// The caller must ensure that the data is valid and that the data remains valid as this call
424    pub unsafe fn copy<T: DeviceCopy>(
425        &self,
426        dst: &mut Slice<T, Self>,
427        src: &Slice<T, Self>,
428    ) -> Result<(), CopyError> {
429        dst.copy_from_slice(src, self)
430    }
431
432    /// Waits for all work enqueued so far in this task to finish.
433    ///
434    /// This function can be useful in case there is work to be enqueued but for some reason this
435    /// work cannot be done using [Self::launch_host_fn].
436    pub async fn synchronize(&self) -> Result<(), CudaError> {
437        let (tx, mut rx) = oneshot::channel::<bool>();
438        let mut interval = tokio::time::interval(Duration::from_millis(INTERVAL_MS));
439
440        // Launch the host function to signal the main thread that the task is done
441        let tx = Box::new(tx);
442        let tx_ptr = Box::into_raw(tx);
443        unsafe {
444            self.launch_host_fn(sync_host, tx_ptr as *mut c_void)?;
445        }
446
447        // Wait for the host function to signal the main thread that the task is done while
448        // simultaneously polling the stream in the interval to catch any errors.
449        loop {
450            tokio::select! {
451                _ = interval.tick() => {
452                     match unsafe { self.stream().query() } {
453                        Ok(()) => {break;}
454                        Err(CudaError::NotReady) => {}
455                        Err(e) => {
456                            return Err(e);
457                        }
458
459                    }
460                }
461                _ = &mut rx => {
462                    break;
463                }
464            }
465        }
466
467        Ok(())
468    }
469
470    /// Joins this task into another task.
471    ///
472    /// The other task will wait for the current task to finish.
473    #[inline]
474    unsafe fn join(self, parent: &TaskScope) -> Result<(), CudaError> {
475        parent.stream.wait_unchecked(&self.end_event)
476    }
477
478    /// Copies data from the host to the device.
479    #[inline]
480    pub fn into_device<T: IntoDevice>(&self, data: T) -> Result<T::Output, CopyError> {
481        T::into_device_in(data, self)
482    }
483
484    #[inline]
485    pub fn to_device<T: ToDevice>(&self, data: &T) -> Result<T::Output, CopyError> {
486        T::to_device_in(data, self)
487    }
488
489    /// Waits for all work enqueued so far in this task to finish.
490    ///
491    /// This function can be useful in case there is work to be enqueued but for some reason this
492    /// work cannot be done using [Self::launch_host_fn].
493    #[inline]
494    pub fn synchronize_blocking(&self) -> Result<(), CudaError> {
495        // The access to the stream is safe and therefore synchronize is safe.
496        unsafe { self.stream_synchronize() }
497    }
498
499    /// # Safety
500    pub unsafe fn handle(&self) -> CudaStreamHandle {
501        self.stream.0
502    }
503
504    pub fn owner(&self) -> TaskPool {
505        TaskPool { inner: self.0.upgrade().unwrap().inner.owner().clone() }
506    }
507
508    fn owner_queue(&self) -> Arc<WorkerQueue<Task>> {
509        self.0.upgrade().unwrap().inner.owner().clone()
510    }
511
512    /// Spawns a new task from the current task pool.
513    ///
514    /// The task starting point will have a "happens before" relationship with the current task when
515    /// the spawn is called. The handle can be used to wait for the child task to finish.
516    pub fn spawn<F, Fut>(&self, f: F) -> SpawnHandle<Fut::Output>
517    where
518        F: FnOnce(TaskScope) -> Fut + Send + 'static,
519        Fut: Future + Send + 'static,
520        Fut::Output: CudaSend + 'static,
521    {
522        let parent = self.clone();
523        let handle = tokio::spawn(async move { parent.run_in_place(f).await });
524        SpawnHandle { handle }
525    }
526
527    /// Runs a task in place in a new stream.
528    ///
529    /// Awaiting this task will peform the device calls and synchronize the end of this task to
530    /// the parent, but does not do host synchronization.
531    pub async fn run_in_place<F, Fut>(&self, f: F) -> Result<Fut::Output, CudaError>
532    where
533        F: FnOnce(TaskScope) -> Fut,
534        Fut: Future,
535        Fut::Output: CudaSend,
536    {
537        let parent = self.clone();
538        let task = TaskPool::task(parent.owner_queue()).await.unwrap();
539        unsafe {
540            // Use the task's end event to synchronize the parent task.
541            // This is safe because this is the first time this task is being run so we know
542            // there are no other copies that record anything on this event at the same time.
543            parent.stream.record_unchecked(&task.inner.end_event)?;
544            task.inner.stream.wait_unchecked(&task.inner.end_event)?
545        };
546        let handle = task.run(f).await;
547        handle.join(&parent)
548    }
549}
550
551impl StreamRef for TaskScope {
552    #[inline]
553    unsafe fn stream(&self) -> &CudaStream {
554        &self.stream
555    }
556}
557
/// A unit of the task pool: a CUDA stream together with the event that marks the end of a run.
#[derive(Debug)]
pub struct Task {
    /// Identifier of the pool that created this task.
    pub(crate) owner_id: usize,
    /// Index of this task within its pool.
    pub(crate) id: usize,
    /// Stream on which the task's work is enqueued.
    pub(crate) stream: CudaStream,
    /// Event recorded on the stream when a run of the task finishes.
    end_event: CudaEvent,
}
565
566impl PartialEq for Task {
567    fn eq(&self, other: &Self) -> bool {
568        self.owner_id == other.owner_id && self.id == other.id
569    }
570}
571
572impl Eq for Task {}
573
574impl StreamRef for Task {
575    #[inline]
576    unsafe fn stream(&self) -> &CudaStream {
577        &self.stream
578    }
579}
580
581impl Drop for Task {
582    fn drop(&mut self) {
583        unsafe {
584            self.end_event.query().expect("attempting to drop a task that did not finish");
585            self.stream.query().expect("attempting to drop a task that did not finish");
586        }
587    }
588}
589
590impl IntoFuture for Task {
591    type Output = Result<(), CudaError>;
592    type IntoFuture = StreamCallbackFuture<Self>;
593
594    fn into_future(self) -> Self::IntoFuture {
595        StreamCallbackFuture::new(self)
596    }
597}
598
599unsafe impl Allocator for TaskScope {
600    #[inline]
601    unsafe fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
602        self.stream.allocate(layout)
603    }
604
605    #[inline]
606    unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
607        // SAFETY: the safety contract must be upheld by the caller
608        self.stream.deallocate(ptr, layout)
609    }
610}
611
612impl DeviceMemory for TaskScope {
613    #[inline]
614    unsafe fn copy_nonoverlapping(
615        &self,
616        src: *const u8,
617        dst: *mut u8,
618        size: usize,
619        direction: CopyDirection,
620    ) -> Result<(), CopyError> {
621        self.stream.copy_nonoverlapping(src, dst, size, direction)
622    }
623
624    #[inline]
625    unsafe fn write_bytes(&self, dst: *mut u8, value: u8, size: usize) -> Result<(), CopyError> {
626        self.stream.write_bytes(dst, value, size)
627    }
628}
629
630// // Implement CanCopyFrom for TaskScope to copy from CpuBackend
631// impl<T: DeviceCopy> CanCopyFrom<Buffer<T>, slop_alloc::CpuBackend> for TaskScope {
632//     type Output = Buffer<T, TaskScope>;
633
634//     fn copy_into(
635//         &self,
636//         value: Buffer<T>,
637//     ) -> impl std::future::Future<Output = Result<Self::Output, CopyError>> + Send + Sync {
638//         let result = DeviceBuffer::from_host(&value, self).map(|b| b.into_inner());
639//         std::future::ready(result)
640//     }
641// }
642
643// // Implement CanCopyFromRef for TaskScope to copy Point from CpuBackend
644// impl<T: DeviceCopy> CanCopyFromRef<Point<T>, slop_alloc::CpuBackend> for TaskScope {
645//     type Output = Point<T, TaskScope>;
646
647//     fn copy_to(
648//         &self,
649//         value: &Point<T>,
650//     ) -> impl std::future::Future<Output = Result<Self::Output, CopyError>> + Send + Sync {
651//         let result =
652//             DeviceBuffer::from_host(value.values(), self).map(|b| Point::new(b.into_inner()));
653//         std::future::ready(result)
654//     }
655// }
656
657// // Implement CanCopyIntoRef for TaskScope to copy Point to CpuBackend
658// impl<T: DeviceCopy> CanCopyIntoRef<Point<T, TaskScope>, slop_alloc::CpuBackend> for TaskScope {
659//     type Output = Point<T>;
660
661//     fn copy_to_dst(
662//         dst: &slop_alloc::CpuBackend,
663//         value: &Point<T, TaskScope>,
664//     ) -> impl std::future::Future<Output = Result<Self::Output, CopyError>> + Send + Sync {
665//         let _ = dst;
666//         let result =
667//             DeviceBuffer::from_raw(value.values().clone()).to_host().map(|v| Point::new(v.into()));
668//         std::future::ready(result)
669//     }
670// }
671
672impl OwnedTask {
673    fn is_finished(&self) -> Result<bool, CudaError> {
674        self.inner.end_event.query().map(|()| true).or_else(|e| match e {
675            CudaError::NotReady => Ok(false),
676            e => Err(e),
677        })
678    }
679
680    async fn run<F, Fut, R>(self, f: F) -> TaskHandle<R>
681    where
682        F: FnOnce(TaskScope) -> Fut,
683        Fut: Future<Output = R>,
684    {
685        let strong_ptr = Arc::new(self);
686        let scope = TaskScope(Arc::downgrade(&strong_ptr));
687        let value = f(scope.clone()).await;
688        unsafe { scope.stream.record_unchecked(&scope.end_event).unwrap() };
689        TaskHandle { task: strong_ptr, scope, value }
690    }
691
692    fn run_sync<F, R>(self: Arc<Self>, f: F) -> Result<R, CudaError>
693    where
694        F: FnOnce(TaskScope) -> R,
695    {
696        let scope = TaskScope(Arc::downgrade(&self));
697        let output = f(scope.clone());
698        unsafe {
699            scope.stream.record_unchecked(&scope.end_event)?;
700            scope.end_event.synchronize()?;
701        };
702        Ok(output)
703    }
704}
705
706impl StreamRef for OwnedTask {
707    #[inline]
708    unsafe fn stream(&self) -> &CudaStream {
709        self.inner.stream()
710    }
711}
712
713impl IntoFuture for TaskScope {
714    type Output = Result<(), CudaError>;
715    type IntoFuture = StreamCallbackFuture<Self>;
716
717    fn into_future(self) -> Self::IntoFuture {
718        StreamCallbackFuture::new(self)
719    }
720}
721
722pub struct TaskHandle<T> {
723    task: Arc<OwnedTask>,
724    scope: TaskScope,
725    value: T,
726}
727
728impl<T> TaskHandle<T> {
729    pub fn join(self, parent: &TaskScope) -> Result<T, CudaError>
730    where
731        T: CudaSend,
732    {
733        // See [TaskHandle::join] for the explanation of safety. Here this is a bit more complex,
734        // but the eventual panic still applies. This is enough in most cases.
735        unsafe {
736            self.scope.join(parent)?;
737            let value = self.value.send_to_scope(parent);
738            // Return the value to the caller.
739            Ok(value)
740        }
741    }
742
743    pub fn is_finished(&self) -> Result<bool, CudaError> {
744        self.task.is_finished()
745    }
746}
747
/// Future obtained by awaiting a [`TaskHandle`]: resolves to the task's value once the stream
/// callback future completes.
#[pin_project]
pub struct StreamHandleFuture<T> {
    /// Future over the owned task that this future polls.
    #[pin]
    callback: StreamCallbackFuture<Arc<OwnedTask>>,
    /// The task's value, moved out when the callback resolves successfully.
    // NOTE(review): `MaybeUninit` never drops its contents, so the value is not dropped if this
    // future is dropped before completing.
    value: MaybeUninit<T>,
}
754
755impl<T> Future for StreamHandleFuture<T> {
756    type Output = Result<T, CudaError>;
757
758    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
759        let this = self.project();
760        this.callback.poll(cx).map(|res| {
761            res.map(|_| {
762                let uinit = MaybeUninit::uninit();
763                let ret = std::mem::replace(this.value, uinit);
764                // We assume that JoinHandleFuture is created from a JoinHandle, so the value is
765                // always initialized.
766                unsafe { ret.assume_init() }
767            })
768        })
769    }
770}
771
772impl<T> IntoFuture for TaskHandle<T> {
773    type Output = Result<T, CudaError>;
774    type IntoFuture = StreamHandleFuture<T>;
775
776    #[inline]
777    fn into_future(self) -> Self::IntoFuture {
778        StreamHandleFuture {
779            callback: StreamCallbackFuture::new(self.task),
780            value: MaybeUninit::new(self.value),
781        }
782    }
783}
784
#[cfg(test)]
mod tests {

    use crate::TaskPoolBuilder;

    /// Spawning an empty task on the global pool completes successfully.
    #[tokio::test]
    async fn test_global_task_pool() {
        let handle = crate::spawn(|_| async {});
        handle.await.unwrap();
    }

    /// A pool with fewer workers than callers still runs every spawned task exactly once.
    #[tokio::test]
    async fn test_local_pool() {
        let num_workers = 10;
        let num_callers = 100;
        let pool = TaskPoolBuilder::new().num_tasks(num_workers).build().unwrap();

        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let handles: Vec<_> = (0..num_callers)
            .map(|_| {
                let tx = tx.clone();
                pool.spawn(|_| async move {
                    tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
                    tx.send(true).unwrap();
                })
            })
            .collect();
        // Drop the original sender so the receive loop ends once every task has reported.
        drop(tx);

        let mut count = 0;
        while let Some(flag) = rx.recv().await {
            assert!(flag);
            count += 1;
        }

        for handle in handles {
            handle.await.unwrap();
        }

        assert_eq!(count, num_callers);
    }
}