drumbeat/sync/
task.rs

1use super::spinlock::SpinLock;
2
3use std::error::Error;
4use std::fmt::{Debug, Display, Formatter};
5use std::future::Future;
6use std::pin::Pin;
7use std::sync::Arc;
8use std::task::{Context, Poll, Wake, Waker};
9
/// Marker trait for the output type of a [`Task`].
///
/// `Clone` is needed because [`Task::poll`] hands back a copy of the stored
/// result rather than moving it out.
pub trait TaskType: Send + Clone + 'static {}

/// Any type satisfying the bounds is automatically a `TaskType`.
impl<T: Send + Clone + 'static> TaskType for T {}
12
/// Failure a [`Task`] can report in place of a value.
#[derive(Clone, Debug, PartialEq)]
pub enum TaskError {
  /// Displayed as "worker executor panicked".
  Unhealthy,
}

impl Display for TaskError {
  fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
    let message = match self {
      Self::Unhealthy => "worker executor panicked",
    };
    f.write_str(message)
  }
}

/// `TaskError` wraps no underlying cause, so the default `Error` methods suffice.
impl Error for TaskError {}
27
28pub(super) struct TaskInner<T>
29where
30  T: TaskType,
31{
32  pub(super) ready: bool,
33  pub(super) result: Option<Result<T, TaskError>>,
34  pub(super) future: Option<Pin<Box<dyn Future<Output = T> + Send>>>,
35}
36
37impl<T> Wake for Task<T>
38where
39  T: TaskType,
40{
41  fn wake(self: Arc<Self>) {
42    (self.resume)(self.clone());
43  }
44}
45
46pub struct Task<T>
47where
48  T: TaskType,
49{
50  pub(super) inner: SpinLock<TaskInner<T>>,
51  pub(super) resume: Box<dyn Fn(Arc<Task<T>>) + Send + Sync + 'static>,
52}
53
54#[derive(Debug, PartialEq)]
55pub enum PollError {
56  NotReady,
57  Error(TaskError),
58}
59
60impl<T> Task<T>
61where
62  T: TaskType,
63{
64  pub fn new<F>(resumer: F, future: Pin<Box<dyn Future<Output = T> + Send>>) -> Self
65  where
66    F: Fn(Arc<Task<T>>) + Send + Sync + 'static,
67  {
68    Task {
69      inner: SpinLock::new(TaskInner {
70        ready: false,
71        result: None,
72        future: Some(future),
73      }),
74      resume: Box::new(resumer),
75    }
76  }
77
78  pub fn poll(&self) -> Result<T, PollError> {
79    let guard = self.inner.lock().unwrap();
80    if guard.ready {
81      guard.result.clone().unwrap().map_err(PollError::Error)
82    } else {
83      Err(PollError::NotReady)
84    }
85  }
86
87  pub fn wait(&self) -> Result<T, TaskError> {
88    // TODO: Implement non busy wait version
89    loop {
90      std::thread::yield_now();
91      let poll = self.poll();
92      if let Err(error) = poll {
93        match error {
94          PollError::NotReady => (),
95          PollError::Error(error) => return Err(error),
96        }
97      } else {
98        return Ok(poll.unwrap());
99      }
100    }
101  }
102
103  pub(super) fn progress(self: Arc<Self>) -> bool {
104    if let Some(mut future) = self.inner.lock().unwrap().future.take() {
105      let waker = Waker::from(self.clone());
106      let context = &mut Context::from_waker(&waker);
107      let poll = future.as_mut().poll(context);
108      let mut guard = self.inner.lock().unwrap();
109      if let Poll::Ready(value) = poll {
110        guard.ready = true;
111        guard.result = Some(Ok(value));
112        true
113      } else {
114        guard.future = Some(future);
115        false
116      }
117    } else {
118      false
119    }
120  }
121}