1use crate::runtime::execution::ExecutionState;
9use crate::runtime::task::TaskId;
10use crate::runtime::thread;
11use std::error::Error;
12use std::fmt::{Display, Formatter};
13use std::future::Future;
14use std::pin::Pin;
15use std::result::Result;
16use std::sync::Arc;
17use std::task::{Context, Poll, Waker};
18
19pub mod batch_semaphore;
20
21fn spawn_inner<F>(fut: F) -> JoinHandle<F::Output>
22where
23 F: Future + 'static,
24 F::Output: 'static,
25{
26 let stack_size = ExecutionState::with(|s| s.config.stack_size);
27 let inner = Arc::new(std::sync::Mutex::new(JoinHandleInner::default()));
28 let task_id = ExecutionState::spawn_future(Wrapper::new(fut, inner.clone()), stack_size, None);
29
30 thread::switch();
31
32 JoinHandle { task_id, inner }
33}
34
35pub fn spawn<F>(fut: F) -> JoinHandle<F::Output>
37where
38 F: Future + Send + 'static,
39 F::Output: Send + 'static,
40{
41 spawn_inner(fut)
42}
43
44pub fn spawn_local<F>(fut: F) -> JoinHandle<F::Output>
47where
48 F: Future + 'static,
49 F::Output: 'static,
50{
51 spawn_inner(fut)
52}
53
54#[derive(Debug, Clone)]
56pub struct AbortHandle {
57 task_id: TaskId,
58}
59
60impl AbortHandle {
61 pub fn abort(&self) {
63 ExecutionState::try_with(|state| {
64 if !state.is_finished() {
65 let task = state.get_mut(self.task_id);
66 task.abort();
67 }
68 });
69 }
70
71 pub fn is_finished(&self) -> bool {
76 ExecutionState::with(|state| {
77 let task = state.get(self.task_id);
78 task.finished()
79 })
80 }
81}
82
83unsafe impl Send for AbortHandle {}
84unsafe impl Sync for AbortHandle {}
85
86#[derive(Debug)]
88pub struct JoinHandle<T> {
89 task_id: TaskId,
90 inner: std::sync::Arc<std::sync::Mutex<JoinHandleInner<T>>>,
91}
92
93#[derive(Debug)]
94struct JoinHandleInner<T> {
95 result: Option<Result<T, JoinError>>,
96 waker: Option<Waker>,
97}
98
99impl<T> Default for JoinHandleInner<T> {
100 fn default() -> Self {
101 JoinHandleInner {
102 result: None,
103 waker: None,
104 }
105 }
106}
107
108impl<T> JoinHandle<T> {
109 pub fn abort(&self) {
111 ExecutionState::try_with(|state| {
112 if !state.is_finished() {
113 let task = state.get_mut(self.task_id);
114 task.abort();
115 }
116 });
117 }
118
119 pub fn is_finished(&self) -> bool {
124 ExecutionState::with(|state| {
125 let task = state.get(self.task_id);
126 task.finished()
127 })
128 }
129
130 pub fn abort_handle(&self) -> AbortHandle {
132 AbortHandle { task_id: self.task_id }
133 }
134}
135
/// Reason a spawned task did not produce an output.
#[derive(Debug)]
pub enum JoinError {
    /// The task was cancelled.
    Cancelled,
}

impl Display for JoinError {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Exhaustive on purpose: a new variant must be given its own message.
        let message = match self {
            JoinError::Cancelled => "task was cancelled",
        };
        f.write_str(message)
    }
}

impl Error for JoinError {}
153
154impl<T> Drop for JoinHandle<T> {
155 fn drop(&mut self) {
156 self.abort();
157 }
158}
159
160impl<T> Future for JoinHandle<T> {
161 type Output = Result<T, JoinError>;
162
163 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
164 let mut lock = self.inner.lock().unwrap();
165 if let Some(result) = lock.result.take() {
166 Poll::Ready(result)
167 } else {
168 lock.waker = Some(cx.waker().clone());
169 Poll::Pending
170 }
171 }
172}
173
174struct Wrapper<F: Future> {
179 future: Pin<Box<F>>,
180 inner: std::sync::Arc<std::sync::Mutex<JoinHandleInner<F::Output>>>,
181}
182
183impl<F> Wrapper<F>
184where
185 F: Future + 'static,
186 F::Output: 'static,
187{
188 fn new(future: F, inner: std::sync::Arc<std::sync::Mutex<JoinHandleInner<F::Output>>>) -> Self {
189 Self {
190 future: Box::pin(future),
191 inner,
192 }
193 }
194}
195
196impl<F> Future for Wrapper<F>
197where
198 F: Future + 'static,
199 F::Output: 'static,
200{
201 type Output = ();
202
203 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
204 match self.future.as_mut().poll(cx) {
205 Poll::Ready(result) => {
206 if ExecutionState::try_with(|state| state.is_finished()).unwrap_or(true) {
210 return Poll::Ready(());
211 }
212
213 while let Some(local) = ExecutionState::with(|state| state.current_mut().pop_local()) {
219 drop(local);
220 }
221
222 let mut lock = self.inner.lock().unwrap();
223 lock.result = Some(Ok(result));
224
225 if let Some(waker) = lock.waker.take() {
226 waker.wake();
227 }
228
229 Poll::Ready(())
230 }
231 Poll::Pending => Poll::Pending,
232 }
233 }
234}
235
236pub fn block_on<F: Future>(future: F) -> F::Output {
238 let mut future = Box::pin(future);
239 let waker = ExecutionState::with(|state| state.current_mut().waker());
240 let cx = &mut Context::from_waker(&waker);
241
242 loop {
243 match future.as_mut().poll(cx) {
244 Poll::Ready(result) => break result,
245 Poll::Pending => {
246 ExecutionState::with(|state| state.current_mut().sleep_unless_woken());
247 }
248 }
249
250 thread::switch();
251 }
252}
253
254pub async fn yield_now() {
258 struct YieldNow {
260 yielded: bool,
261 }
262
263 impl Future for YieldNow {
264 type Output = ();
265
266 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
267 if self.yielded {
268 return Poll::Ready(());
269 }
270
271 self.yielded = true;
272 cx.waker().wake_by_ref();
273 ExecutionState::request_yield();
274 Poll::Pending
275 }
276 }
277
278 YieldNow { yielded: false }.await
279}