// luminal/runtime/join_handle.rs
//! Join handle implementation
//!
//! This module provides `JoinHandle`, the future returned when a task is
//! spawned; awaiting it yields the task's result once the task completes.
5
6#[cfg(feature = "std")]
7use std::{future::Future, pin::Pin, task::{Context, Poll}};
8
9#[cfg(not(feature = "std"))]
10use core::{future::Future, pin::Pin, task::{Context, Poll}};
11
12#[cfg(feature = "std")]
13use crossbeam_channel::{Receiver, TryRecvError};
14
15#[cfg(not(feature = "std"))]
16use heapless::mpmc::MpMcQueue;
17
18#[cfg(not(feature = "std"))]
19use alloc::sync::Arc;
20
21use super::task::TaskId;
22
/// Awaitable handle to a task spawned on the runtime.
///
/// Modelled on tokio's `JoinHandle`: awaiting the handle yields the value
/// produced by the spawned task. Because it implements [`Future`], it can be
/// `.await`ed from any async context.
///
/// # Type Parameters
///
/// * `T` - The value produced by the spawned task
///
/// # Examples
///
/// ```
/// use luminal::Runtime;
///
/// let rt = Runtime::new().unwrap();
/// let handle = rt.spawn(async { 42 });
///
/// // Await the handle
/// let result = rt.block_on(async {
///     handle.await
/// });
/// assert_eq!(result, 42);
/// ```
#[cfg(feature = "std")]
pub struct JoinHandle<T> {
    /// Identifier of the task this handle belongs to.
    // `id` is never read in this module; the allow keeps the dead-code lint
    // quiet. NOTE(review): confirm whether other modules rely on it.
    #[allow(dead_code)]
    pub(crate) id: TaskId,

    /// Receiving end of the channel on which the finished task's output is
    /// delivered.
    pub(crate) receiver: Receiver<T>,
}
56
57/// Handle for awaiting the completion of an asynchronous task (no_std version)
58///
59/// Similar to tokio's JoinHandle, this allows waiting for a task to complete
60/// and retrieving its result. It implements `Future` so it can be awaited
61/// in async contexts.
62///
63/// In no_std environments, this uses a bounded queue with a default capacity of 16.
64#[cfg(not(feature = "std"))]
65pub struct JoinHandle<T> {
66 /// The unique identifier of the task
67 #[allow(dead_code)]
68 pub(crate) id: TaskId,
69
70 /// Channel for receiving the task's result when it completes
71 /// Uses a shared bounded queue for no_std environments
72 pub(crate) receiver: Arc<MpMcQueue<T, 16>>,
73}
74
#[cfg(feature = "std")]
impl<T: Send + 'static> Future for JoinHandle<T> {
    type Output = T;

    /// Polls for the spawned task's result.
    ///
    /// Performs a non-blocking receive on the result channel:
    ///
    /// * a value is available: returns `Poll::Ready(value)`;
    /// * the channel is empty: asks the executor to poll again, then returns
    ///   `Poll::Pending`;
    /// * the channel is disconnected: logs a warning and returns
    ///   `Poll::Pending` without requesting another poll.
    ///
    /// # Wake-ups
    ///
    /// This handle does not store the waker from `cx`, so nothing else will
    /// wake it when the result is sent. The `Future` contract requires a
    /// future that returns `Pending` to arrange to be woken, so the empty case
    /// calls `wake_by_ref()`. The executor will therefore re-poll the handle
    /// (a cooperative busy-poll) instead of possibly never polling it again.
    ///
    /// # Disconnection
    ///
    /// crossbeam reports `Disconnected` only when the channel is empty and
    /// every sender is gone, so no result can ever arrive (for example, the
    /// task panicked or was dropped). `Output` is a bare `T`, so there is no
    /// error value to return. The handle stays pending, which lets an
    /// enclosing timeout or cancellation still fire. It deliberately does not
    /// request another wake-up, which would spin forever.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.receiver.try_recv() {
            Ok(result) => Poll::Ready(result),
            Err(TryRecvError::Empty) => {
                // No one else holds our waker; schedule a re-poll ourselves.
                cx.waker().wake_by_ref();
                Poll::Pending
            }
            Err(TryRecvError::Disconnected) => {
                eprintln!("WARNING: Task channel disconnected unexpectedly");
                Poll::Pending
            }
        }
    }
}
108
109#[cfg(not(feature = "std"))]
110impl<T: Send + 'static> Future for JoinHandle<T> {
111 type Output = T;
112
113 /// Poll method implementation for JoinHandle (no_std version)
114 ///
115 /// This checks if the result is ready by attempting to receive
116 /// from the heapless queue. If the result is available, it's returned as Ready.
117 /// If not, it returns Pending to be polled again later.
118 ///
119 /// # Parameters
120 ///
121 /// * `self` - Pinned mutable reference to self
122 /// * `_cx` - Context for the poll (not used in this implementation)
123 ///
124 /// # Returns
125 ///
126 /// `Poll::Ready(T)` when the task completes, or `Poll::Pending` if still running
127 fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
128 #[cfg(feature = "std")]
129 use std::task::Poll;
130 #[cfg(not(feature = "std"))]
131 use core::task::Poll;
132
133 // Try to dequeue a result from the bounded queue
134 match self.receiver.dequeue() {
135 Some(result) => Poll::Ready(result),
136 None => Poll::Pending,
137 }
138 }
139}