nio_task/
lib.rs

1#![doc = include_str!("../README.md")]
2#![allow(unsafe_op_in_unsafe_fn)]
3
4mod abort;
5mod blocking;
6mod error;
7mod id;
8mod join;
9mod raw;
10mod state;
11mod task;
12mod thin_arc;
13mod waker;
14
15use crate::{raw::*, thin_arc::ThinArc};
16
17pub use abort::AbortHandle;
18pub use blocking::BlockingTask;
19pub use error::JoinError;
20pub use id::{TaskId, id};
21pub use join::JoinHandle;
22
23use state::*;
24use std::{
25    cell::UnsafeCell,
26    fmt::{self, Debug},
27    future::Future,
28    marker::PhantomData,
29    mem::ManuallyDrop,
30};
31
32pub trait Scheduler<M = ()>: 'static {
33    fn schedule(&self, task: Task<M>);
34}
35
36impl<F, M> Scheduler<M> for F
37where
38    F: Fn(Task<M>) + 'static,
39{
40    fn schedule(&self, runnable: Task<M>) {
41        self(runnable)
42    }
43}
44
45pub struct Task<M = ()> {
46    raw: Option<RawTask>,
47    _meta: PhantomData<M>,
48}
49
50unsafe impl<M> Send for Task<M> {}
51unsafe impl<M> Sync for Task<M> {}
52
53impl<M> std::panic::UnwindSafe for Task<M> {}
54impl<M> std::panic::RefUnwindSafe for Task<M> {}
55
56impl<M> Drop for Task<M> {
57    fn drop(&mut self) {
58        if let Some(raw) = self.raw.take() {
59            unsafe { raw.drop_task() };
60        }
61    }
62}
63
64pub struct Metadata<M = ()> {
65    raw: RawTask,
66    _meta: PhantomData<M>,
67}
68
69impl<M: Debug> Debug for Metadata<M> {
70    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71        self.get().fmt(f)
72    }
73}
74
75impl<M> Metadata<M> {
76    pub fn id(&self) -> TaskId {
77        TaskId::new(&self.raw)
78    }
79
80    pub fn get(&self) -> &M {
81        unsafe { &*self.raw.metadata().cast::<M>() }
82    }
83
84    pub fn get_mut(&mut self) -> &mut M {
85        unsafe { &mut *self.raw.metadata().cast::<M>() }
86    }
87}
88
89#[derive(Debug)]
90pub enum Status<M> {
91    Yielded(Task<M>),
92    Pending,
93    Complete(Metadata<M>),
94}
95
96impl Task {
97    pub fn new<F, S>(future: F, scheduler: S) -> (Task, JoinHandle<F::Output>)
98    where
99        S: Scheduler<()> + Send,
100        F: Future + Send + 'static,
101        F::Output: Send + 'static,
102    {
103        unsafe { Self::new_unchecked((), future, scheduler) }
104    }
105
106    pub fn new_local<F, S>(future: F, scheduler: S) -> (Task, JoinHandle<F::Output>)
107    where
108        S: Scheduler<()> + Send,
109        F: Future + 'static,
110        F::Output: 'static,
111    {
112        Self::new_local_with((), future, scheduler)
113    }
114}
115
116impl<M> Task<M> {
117    pub(crate) fn from_raw(raw: RawTask) -> Self {
118        Self {
119            raw: Some(raw),
120            _meta: PhantomData,
121        }
122    }
123
124    pub fn metadata(&self) -> &M {
125        unsafe { &*self.raw.as_ref().unwrap_unchecked().metadata().cast() }
126    }
127
128    pub fn metadata_mut(&mut self) -> &mut M {
129        unsafe { &mut *self.raw.as_ref().unwrap_unchecked().metadata().cast() }
130    }
131
132    pub unsafe fn new_unchecked<F, S>(
133        meta: M,
134        future: F,
135        scheduler: S,
136    ) -> (Task<M>, JoinHandle<F::Output>)
137    where
138        M: 'static,
139        S: Scheduler<M>,
140        F: Future + 'static,
141    {
142        let (raw, join) = ThinArc::new(Box::new(RawTaskHeader {
143            header: Header::new(),
144            data: task::RawTaskInner {
145                future: UnsafeCell::new(Fut::Future(future)),
146                meta: UnsafeCell::new(meta),
147                scheduler,
148            },
149        }));
150        (Task::from_raw(raw), JoinHandle::new(join))
151    }
152
153    pub fn new_with<F, S>(meta: M, future: F, scheduler: S) -> (Task<M>, JoinHandle<F::Output>)
154    where
155        M: 'static + Send,
156        S: Scheduler<M> + Send,
157        F: Future + Send + 'static,
158        F::Output: Send + 'static,
159    {
160        unsafe { Self::new_unchecked(meta, future, scheduler) }
161    }
162
163    pub fn new_local_with<F, S>(
164        meta: M,
165        future: F,
166        scheduler: S,
167    ) -> (Task<M>, JoinHandle<F::Output>)
168    where
169        M: 'static + Send,
170        S: Scheduler<M> + Send,
171        F: Future + 'static,
172        F::Output: 'static,
173    {
174        use std::{
175            mem::ManuallyDrop,
176            pin::Pin,
177            task::{Context, Poll},
178            thread::{self, ThreadId},
179        };
180
181        #[inline]
182        fn thread_id() -> ThreadId {
183            std::thread_local! {
184                static ID: ThreadId = thread::current().id();
185            }
186            ID.try_with(|id| *id)
187                .unwrap_or_else(|_| thread::current().id())
188        }
189
190        struct Checked<F> {
191            id: ThreadId,
192            inner: ManuallyDrop<F>,
193        }
194
195        impl<F> Drop for Checked<F> {
196            fn drop(&mut self) {
197                assert!(
198                    self.id == thread_id(),
199                    "local task dropped by a thread that didn't spawn it"
200                );
201                unsafe {
202                    ManuallyDrop::drop(&mut self.inner);
203                }
204            }
205        }
206
207        impl<F: Future> Future for Checked<F> {
208            type Output = F::Output;
209            fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
210                unsafe {
211                    let me = self.get_unchecked_mut();
212                    assert!(
213                        me.id == thread_id(),
214                        "local task polled by a thread that didn't spawn it"
215                    );
216                    Pin::new_unchecked(&mut *me.inner).poll(cx)
217                }
218            }
219        }
220
221        let future = Checked {
222            id: thread_id(),
223            inner: ManuallyDrop::new(future),
224        };
225
226        unsafe { Self::new_unchecked(meta, future, scheduler) }
227    }
228
229    #[inline]
230    pub fn poll(mut self) -> Status<M> {
231        let raw = unsafe { self.raw.take().unwrap_unchecked() };
232        // Don't increase ref-counter
233        let waker = raw.clone_without_ref_inc();
234        // Don't decrease ref-counter
235        let waker = ManuallyDrop::new(raw.waker(waker));
236
237        // SAFETY: `Task` does not implement `Clone` and we have owned access
238        match unsafe { raw.poll(&waker) } {
239            PollStatus::Yield => Status::Yielded(Task::from_raw(raw)),
240            PollStatus::Pending => Status::Pending,
241            PollStatus::Complete => Status::Complete(Metadata {
242                raw,
243                _meta: PhantomData,
244            }),
245        }
246    }
247
248    #[inline]
249    pub fn schedule(mut self) {
250        unsafe {
251            let raw = self.raw.take().unwrap_unchecked();
252            raw.schedule(raw.clone());
253        }
254    }
255
256    #[inline]
257    pub fn id(&self) -> TaskId {
258        TaskId::new(unsafe { self.raw.as_ref().unwrap_unchecked() })
259    }
260}
261
262impl<M: Debug> fmt::Debug for Task<M> {
263    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
264        f.debug_struct("Task")
265            .field("id", &self.id())
266            .field("state", &self.raw.as_ref().unwrap().header().state.load())
267            .field("metadata", self.metadata())
268            .finish()
269    }
270}