luminal/runtime/
join_handle.rs

1//! Join handle implementation
2//!
3//! This module provides the implementation of `JoinHandle`,
4//! which is used to await the completion of async tasks.
5
6#[cfg(feature = "std")]
7use std::{future::Future, pin::Pin, task::{Context, Poll}};
8
9#[cfg(not(feature = "std"))]
10use core::{future::Future, pin::Pin, task::{Context, Poll}};
11
12#[cfg(feature = "std")]
13use crossbeam_channel::{Receiver, TryRecvError};
14
15#[cfg(not(feature = "std"))]
16use heapless::mpmc::MpMcQueue;
17
18#[cfg(not(feature = "std"))]
19use alloc::sync::Arc;
20
21use super::task::TaskId;
22
/// A handle to a spawned task that resolves to the task's output
///
/// Comparable to tokio's `JoinHandle`: keep it around to wait for a task and
/// collect the value it produced. Because the handle implements `Future`, it
/// can be `.await`ed from async code.
///
/// # Type Parameters
///
/// * `T` - The value the task yields on completion
///
/// # Examples
///
/// ```
/// use luminal::Runtime;
///
/// let rt = Runtime::new().unwrap();
/// let handle = rt.spawn(async { 42 });
///
/// // Await the handle
/// let result = rt.block_on(async {
///     handle.await
/// });
/// assert_eq!(result, 42);
/// ```
#[cfg(feature = "std")]
pub struct JoinHandle<T> {
    /// Identifier of the task this handle belongs to
    // Not read anywhere in this file, hence the lint allowance.
    #[allow(dead_code)]
    pub(crate) id: TaskId,

    /// Receiving end of the channel on which the task's result is delivered
    pub(crate) receiver: Receiver<T>,
}
56
57/// Handle for awaiting the completion of an asynchronous task (no_std version)
58///
59/// Similar to tokio's JoinHandle, this allows waiting for a task to complete
60/// and retrieving its result. It implements `Future` so it can be awaited
61/// in async contexts.
62///
63/// In no_std environments, this uses a bounded queue with a default capacity of 16.
64#[cfg(not(feature = "std"))]
65pub struct JoinHandle<T> {
66    /// The unique identifier of the task
67    #[allow(dead_code)]
68    pub(crate) id: TaskId,
69
70    /// Channel for receiving the task's result when it completes
71    /// Uses a shared bounded queue for no_std environments
72    pub(crate) receiver: Arc<MpMcQueue<T, 16>>,
73}
74
#[cfg(feature = "std")]
impl<T: Send + 'static> Future for JoinHandle<T> {
    type Output = T;

    /// Polls the result channel for the task's output
    ///
    /// Performs a single non-blocking `try_recv` on the channel. A received
    /// value completes the future. When the channel is merely empty, the
    /// waker is invoked before `Poll::Pending` is returned: the channel has
    /// no asynchronous readiness notification, so requesting another poll is
    /// the only way to honour the `Future` contract and avoid hanging on
    /// executors that re-poll only after a wake-up.
    ///
    /// # Parameters
    ///
    /// * `self` - Pinned mutable reference to self
    /// * `cx` - Poll context; its waker is used to request a re-poll while
    ///   the result is not yet available
    ///
    /// # Returns
    ///
    /// `Poll::Ready(T)` once the task's result has been received, otherwise
    /// `Poll::Pending`. If the channel is disconnected without a result, a
    /// warning is printed to stderr and the handle stays pending.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.receiver.try_recv() {
            Ok(result) => Poll::Ready(result),
            Err(TryRecvError::Empty) => {
                // Nothing will wake us when the sender delivers, so ask the
                // executor to poll again instead of returning a bare Pending.
                cx.waker().wake_by_ref();
                Poll::Pending
            }
            Err(TryRecvError::Disconnected) => {
                // The channel is empty and every sender is gone (the task
                // panicked or was dropped before sending), so no result can
                // arrive any more. `Output = T` offers no way to report that
                // as a value, and panicking would change behaviour callers
                // may rely on, so the warning is logged and the handle stays
                // pending.
                // NOTE(review): this never resolves; a future API revision
                // should surface the failure as an error instead.
                eprintln!("WARNING: Task channel disconnected unexpectedly");
                // The waker is deliberately not invoked here: another poll
                // could only repeat the warning.
                Poll::Pending
            }
        }
    }
}
108
109#[cfg(not(feature = "std"))]
110impl<T: Send + 'static> Future for JoinHandle<T> {
111    type Output = T;
112
113    /// Poll method implementation for JoinHandle (no_std version)
114    ///
115    /// This checks if the result is ready by attempting to receive
116    /// from the heapless queue. If the result is available, it's returned as Ready.
117    /// If not, it returns Pending to be polled again later.
118    ///
119    /// # Parameters
120    ///
121    /// * `self` - Pinned mutable reference to self
122    /// * `_cx` - Context for the poll (not used in this implementation)
123    ///
124    /// # Returns
125    ///
126    /// `Poll::Ready(T)` when the task completes, or `Poll::Pending` if still running
127    fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
128        #[cfg(feature = "std")]
129        use std::task::Poll;
130        #[cfg(not(feature = "std"))]
131        use core::task::Poll;
132
133        // Try to dequeue a result from the bounded queue
134        match self.receiver.dequeue() {
135            Some(result) => Poll::Ready(result),
136            None => Poll::Pending,
137        }
138    }
139}