nio_task/
lib.rs

1#![doc = include_str!("../README.md")]
2#![allow(unsafe_op_in_unsafe_fn)]
3
4mod blocking;
5mod error;
6mod id;
7mod join;
8mod raw;
9mod state;
10mod waker;
11
12use crate::raw::*;
13
14pub use blocking::BlockingTask;
15pub use error::JoinError;
16pub use id::Id;
17pub use join::JoinHandle;
18
19use state::*;
20use std::{
21    cell::UnsafeCell,
22    fmt,
23    future::Future,
24    marker::PhantomData,
25    mem::ManuallyDrop,
26    panic::{AssertUnwindSafe, catch_unwind},
27    pin::Pin,
28    sync::Arc,
29    task::{Context, Poll, Wake, Waker},
30};
31
/// Receives tasks that need to be polled.
///
/// `schedule` is the hand-off point used whenever a task has to run again:
/// [`Task::schedule`], a wake-up through the task's [`Waker`], and an abort request
/// all end in a call to it. Nothing else in this file polls a task, so an
/// implementation is expected to eventually call [`Task::poll`] on what it receives.
///
/// Any `Fn(Task<M>) + 'static` closure implements this trait.
pub trait Scheduler<M>: 'static {
    /// Takes ownership of `task` so it can be polled later.
    fn schedule(&self, task: Task<M>);
}
35
36impl<F, M> Scheduler<M> for F
37where
38    F: Fn(Task<M>) + 'static,
39{
40    fn schedule(&self, runnable: Task<M>) {
41        self(runnable)
42    }
43}
44
/// Shared allocation behind every handle to one task.
///
/// `Task::new_with` / `Task::new_local_with` put it in an `Arc`. The `Task`, its
/// `JoinHandle`, and every `Waker` produced by `RawTaskVTable::waker` point at it.
struct RawTaskInner<F: Future, S: Scheduler<M>, M> {
    /// Task state (`state`) and join-handle waker slot (`join_waker`) driven by the
    /// `transition_*` calls in the vtable implementation.
    header: Header,
    /// Holds the future while it runs. Once it finishes, holds its
    /// `Result<F::Output, JoinError>` (written via `set_output`, read via `take_output`).
    future: UnsafeCell<Fut<F, F::Output>>,
    /// User metadata, read and written through raw pointers by `Task::metadata*`.
    meta: UnsafeCell<M>,
    /// Receives the task every time it has to be (re)scheduled.
    scheduler: S,
}

// SAFETY: Implemented unconditionally so that `Arc<RawTaskInner<..>>` meets the
// `Send + Sync` bounds of `Waker::from`, which `RawTaskVTable::waker` relies on.
// Access to the `UnsafeCell` fields is presumably serialized by the state machine in
// `header`, not by the type system.
// NOTE(review): `Task::new_local*` accept `!Send` futures, so these impls are only
// sound if such tasks are never polled or dropped off their creating thread. Nothing
// in this file enforces that; confirm the intended contract.
unsafe impl<F: Future, S: Scheduler<M>, M> Send for RawTaskInner<F, S, M> {}
unsafe impl<F: Future, S: Scheduler<M>, M> Sync for RawTaskInner<F, S, M> {}
54
/// An owned, schedulable handle to a spawned future.
///
/// `Task` is not `Clone`. New handles are only created by the constructors and by
/// the schedule/wake paths, and [`Task::poll`] consumes the handle it is given.
pub struct Task<M = ()> {
    /// Type-erased reference to the shared `RawTaskInner`.
    raw: RawTask,
    /// Remembers the metadata type `M` that was erased from `raw`.
    _meta: PhantomData<M>,
}

// SAFETY: Asserted for every `M` and for every future type.
// NOTE(review): tasks built by `new_local*` may wrap `!Send` futures, and `Sync` hands
// out `&M` (via `metadata`) without requiring `M: Sync`. Moving such a task to another
// thread is not prevented by the type system; confirm this is an intended contract.
unsafe impl<M> Send for Task<M> {}
unsafe impl<M> Sync for Task<M> {}

// Presumably asserted so executors can move a `Task` into `catch_unwind` without
// wrapping it in `AssertUnwindSafe`; `poll` already catches panics from the future.
impl<M> std::panic::UnwindSafe for Task<M> {}
impl<M> std::panic::RefUnwindSafe for Task<M> {}
65
66pub struct Metadata<M>(Task<M>);
67
68impl<M> Metadata<M> {
69    pub fn get(&self) -> &M {
70        self.0.metadata()
71    }
72
73    pub fn get_mut(&mut self) -> &mut M {
74        self.0.metadata_mut()
75    }
76}
77
78pub enum Status<M> {
79    Yielded(Task<M>),
80    Pending,
81    Complete(Metadata<M>),
82}
83
84impl Task {
85    pub fn new<F, S>(future: F, scheduler: S) -> (Self, JoinHandle<F::Output>)
86    where
87        F: Future + Send + 'static,
88        F::Output: Send,
89        S: Scheduler<()>,
90    {
91        Self::new_with((), future, scheduler)
92    }
93
94    pub fn new_local<F, S>(future: F, scheduler: S) -> (Self, JoinHandle<F::Output>)
95    where
96        F: Future + 'static,
97        F::Output: 'static,
98        S: Scheduler<()>,
99    {
100        Self::new_local_with((), future, scheduler)
101    }
102}
103
104impl<M> Task<M> {
105    pub fn metadata(&self) -> &M {
106        unsafe { &*self.raw.metadata().cast::<M>() }
107    }
108
109    pub fn metadata_mut(&mut self) -> &mut M {
110        unsafe { &mut *self.raw.metadata().cast::<M>() }
111    }
112
113    pub fn new_with<F, S>(meta: M, future: F, scheduler: S) -> (Self, JoinHandle<F::Output>)
114    where
115        M: 'static + Send,
116        F: Future + Send + 'static,
117        F::Output: Send,
118        S: Scheduler<M>,
119    {
120        let raw = Arc::new(RawTaskInner {
121            header: Header::new(),
122            future: UnsafeCell::new(Fut::Future(future)),
123            meta: UnsafeCell::new(meta),
124            scheduler,
125        });
126        let join_handle = JoinHandle::new(raw.clone());
127        (
128            Self {
129                raw,
130                _meta: PhantomData,
131            },
132            join_handle,
133        )
134    }
135
136    pub fn new_local_with<F, S>(meta: M, future: F, scheduler: S) -> (Self, JoinHandle<F::Output>)
137    where
138        M: 'static + Send,
139        F: Future + 'static,
140        F::Output: 'static,
141        S: Scheduler<M>,
142    {
143        let raw = Arc::new(RawTaskInner {
144            header: Header::new(),
145            future: UnsafeCell::new(Fut::Future(future)),
146            meta: UnsafeCell::new(meta),
147            scheduler,
148        });
149        let join_handle = JoinHandle::new(raw.clone());
150        (
151            Self {
152                raw,
153                _meta: PhantomData,
154            },
155            join_handle,
156        )
157    }
158
159    #[inline]
160    pub fn poll(self) -> Status<M> {
161        // Don't increase ref-counter
162        let raw = unsafe { Arc::from_raw(Arc::as_ptr(&self.raw)) };
163        // Don't decrease ref-counter
164        let waker = ManuallyDrop::new(raw.waker());
165
166        // SAFETY: `Task` does not implement `Clone` and we have owned access
167        match unsafe { self.raw.poll(&waker) } {
168            PollStatus::Yield => Status::Yielded(self),
169            PollStatus::Pending => Status::Pending,
170            PollStatus::Complete => Status::Complete(Metadata(self)),
171        }
172    }
173
174    #[inline]
175    pub fn schedule(self) {
176        unsafe { self.raw.schedule() }
177    }
178
179    #[inline]
180    pub fn id(&self) -> Id {
181        Id::new(&self.raw)
182    }
183}
184
185impl<F, S, M> RawTaskVTable for RawTaskInner<F, S, M>
186where
187    M: 'static,
188    F: Future + 'static,
189    S: Scheduler<M>,
190{
191    #[inline]
192    fn waker(self: Arc<Self>) -> Waker {
193        Waker::from(self)
194    }
195
196    #[inline]
197    fn header(&self) -> &Header {
198        &self.header
199    }
200
201    unsafe fn metadata(&self) -> *mut () {
202        self.meta.get().cast()
203    }
204
205    /// State transitions:
206    ///
207    /// ```markdown
208    /// NOTIFIED -> RUNNING -> ( SLEEP? -> NOTIFIED -> RUNNING )* -> COMPLETE?
209    /// ```
210    unsafe fn poll(&self, waker: &Waker) -> PollStatus {
211        let is_cancelled = self.header.transition_to_running_and_check_if_cancelled();
212
213        let has_output = catch_unwind(AssertUnwindSafe(|| {
214            let poll_result = unsafe {
215                let fut = match &mut *self.future.get() {
216                    Fut::Future(fut) => Pin::new_unchecked(fut),
217                    _ => unreachable!(),
218                };
219                // Polling may panic, but we catch it in outer layer.
220                fut.poll(&mut Context::from_waker(waker))
221            };
222            let result = match poll_result {
223                Poll::Ready(val) => Ok(val),
224                Poll::Pending if is_cancelled => Err(JoinError::cancelled()),
225                Poll::Pending => return false,
226            };
227            // Droping `Fut::Future` may also panic, but we catch it in outer layer
228            unsafe {
229                (*self.future.get()).set_output(result);
230            }
231            true
232        }));
233
234        match has_output {
235            Ok(false) => return self.header.transition_to_sleep(),
236            Ok(true) => {}
237            Err(err) => unsafe { (*self.future.get()).set_output(Err(JoinError::panic(err))) },
238        }
239        if !self
240            .header
241            .transition_to_complete_and_notify_output_if_intrested()
242        {
243            // Receiver is not interested in the output, So we can drop it.
244            // Droping `Fut::Output` may panic
245            let _ = catch_unwind(AssertUnwindSafe(|| unsafe { (*self.future.get()).drop() }));
246        }
247        PollStatus::Complete
248    }
249
250    unsafe fn schedule(self: Arc<Self>) {
251        self.scheduler.schedule(Task {
252            raw: self.clone(),
253            _meta: PhantomData,
254        });
255    }
256
257    unsafe fn abort_task(self: Arc<Self>) {
258        if self.header.transition_to_abort() {
259            self.schedule()
260        }
261    }
262
263    unsafe fn read_output(&self, dst: *mut (), waker: &Waker) {
264        if self.header.can_read_output_or_notify_when_readable(waker) {
265            *(dst as *mut _) = Poll::Ready((*self.future.get()).take_output());
266        }
267    }
268
269    unsafe fn drop_join_handler(&self) {
270        let is_task_complete = self.header.state.unset_waker_and_interested();
271        if is_task_complete {
272            // If the task is complete then waker is droped by the executor.
273            // We just only need to drop the output.
274            let _ = catch_unwind(AssertUnwindSafe(|| unsafe {
275                (*self.future.get()).drop();
276            }));
277        } else {
278            *self.header.join_waker.get() = None;
279        }
280    }
281}
282
283impl<F, S, M> RawTaskInner<F, S, M>
284where
285    M: 'static,
286    F: Future + 'static,
287    S: Scheduler<M>,
288{
289    unsafe fn schedule_by_ref(self: &Arc<Self>) {
290        self.scheduler.schedule(Task {
291            raw: self.clone(),
292            _meta: PhantomData,
293        });
294    }
295}
296
297impl<F, S, M> Wake for RawTaskInner<F, S, M>
298where
299    M: 'static,
300    F: Future + 'static,
301    S: Scheduler<M>,
302{
303    fn wake(self: Arc<Self>) {
304        unsafe {
305            if self.header.transition_to_notified() {
306                self.schedule();
307            }
308        }
309    }
310
311    fn wake_by_ref(self: &Arc<Self>) {
312        unsafe {
313            if self.header.transition_to_notified() {
314                self.schedule_by_ref();
315            }
316        }
317    }
318}
319
320impl<M> fmt::Debug for Task<M> {
321    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
322        f.debug_struct("Task")
323            .field("id", &self.id())
324            .field("state", &self.raw.header().state.load())
325            .finish()
326    }
327}