1use super::spinlock::SpinLock;
2
3use std::error::Error;
4use std::fmt::{Debug, Display, Formatter};
5use std::future::Future;
6use std::pin::Pin;
7use std::sync::Arc;
8use std::task::{Context, Poll, Wake, Waker};
9
/// Marker trait for values a [`Task`] can produce.
///
/// Results must be sendable across threads, cloneable (a finished result is
/// cloned out each time the task is polled), and free of borrowed data.
pub trait TaskType: Send + Clone + 'static {}

/// Blanket impl: every `Send + Clone + 'static` type is automatically a `TaskType`.
impl<T: Send + Clone + 'static> TaskType for T {}
12
/// Failure a task can report in place of its output.
#[derive(Debug, Clone, PartialEq)]
pub enum TaskError {
    /// The worker executor is unhealthy; displayed as "worker executor panicked".
    Unhealthy,
}

impl Display for TaskError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Map each variant to its fixed, user-facing message.
        let message = match self {
            Self::Unhealthy => "worker executor panicked",
        };
        f.write_str(message)
    }
}

/// `TaskError` carries no underlying cause, so the default `Error` methods suffice.
impl Error for TaskError {}
27
28pub(super) struct TaskInner<T>
29where
30 T: TaskType,
31{
32 pub(super) ready: bool,
33 pub(super) result: Option<Result<T, TaskError>>,
34 pub(super) future: Option<Pin<Box<dyn Future<Output = T> + Send>>>,
35}
36
37impl<T> Wake for Task<T>
38where
39 T: TaskType,
40{
41 fn wake(self: Arc<Self>) {
42 (self.resume)(self.clone());
43 }
44}
45
46pub struct Task<T>
47where
48 T: TaskType,
49{
50 pub(super) inner: SpinLock<TaskInner<T>>,
51 pub(super) resume: Box<dyn Fn(Arc<Task<T>>) + Send + Sync + 'static>,
52}
53
54#[derive(Debug, PartialEq)]
55pub enum PollError {
56 NotReady,
57 Error(TaskError),
58}
59
60impl<T> Task<T>
61where
62 T: TaskType,
63{
64 pub fn new<F>(resumer: F, future: Pin<Box<dyn Future<Output = T> + Send>>) -> Self
65 where
66 F: Fn(Arc<Task<T>>) + Send + Sync + 'static,
67 {
68 Task {
69 inner: SpinLock::new(TaskInner {
70 ready: false,
71 result: None,
72 future: Some(future),
73 }),
74 resume: Box::new(resumer),
75 }
76 }
77
78 pub fn poll(&self) -> Result<T, PollError> {
79 let guard = self.inner.lock().unwrap();
80 if guard.ready {
81 guard.result.clone().unwrap().map_err(PollError::Error)
82 } else {
83 Err(PollError::NotReady)
84 }
85 }
86
87 pub fn wait(&self) -> Result<T, TaskError> {
88 loop {
90 std::thread::yield_now();
91 let poll = self.poll();
92 if let Err(error) = poll {
93 match error {
94 PollError::NotReady => (),
95 PollError::Error(error) => return Err(error),
96 }
97 } else {
98 return Ok(poll.unwrap());
99 }
100 }
101 }
102
103 pub(super) fn progress(self: Arc<Self>) -> bool {
104 if let Some(mut future) = self.inner.lock().unwrap().future.take() {
105 let waker = Waker::from(self.clone());
106 let context = &mut Context::from_waker(&waker);
107 let poll = future.as_mut().poll(context);
108 let mut guard = self.inner.lock().unwrap();
109 if let Poll::Ready(value) = poll {
110 guard.ready = true;
111 guard.result = Some(Ok(value));
112 true
113 } else {
114 guard.future = Some(future);
115 false
116 }
117 } else {
118 false
119 }
120 }
121}