1use std::{
2 sync::{
3 Arc, Mutex,
4 atomic::{AtomicU32, Ordering},
5 },
6 task::{Poll, Waker},
7};
8
9use log::trace;
10use tokio::{runtime::Handle, select, task::JoinHandle};
11use tokio_util::sync::CancellationToken;
12
13use crate::{TError, TaskError};
14
15#[derive(Clone)]
16pub struct Context<E: TError> {
17 id: u32,
18 inner: Arc<Mutex<ContextInner<E>>>,
19 cancel: CancellationToken,
20}
21
22impl<E: TError> Context<E> {
23 pub fn id(&self) -> u32 {
24 self.id
25 }
26
27 pub(crate) fn set_state(&self, state: State) -> Result<(), &'static str> {
28 self.inner.lock().unwrap().set_state(state)
29 }
30
31 pub fn wait_for(&self, state: State) -> FutureTaskState<E> {
32 FutureTaskState::new(self.clone(), state)
33 }
34
35 pub fn stop(&self) -> FutureTaskState<E> {
36 self._stop(Some(TaskError::Cancelled))
37 }
38
39 pub fn is_active(&self) -> bool {
40 !self.cancel.is_cancelled()
41 }
42
43 fn _stop(&self, err: Option<TaskError<E>>) -> FutureTaskState<E> {
44 let fur = self.wait_for(State::Stopped);
45 let mut g = self.inner.lock().unwrap();
46 if g.state >= State::Stopping {
47 return fur;
48 }
49 let _ = g.set_state(State::Stopping);
50 g.error = err;
51 g.wake_all();
52 drop(g);
53 self.cancel.cancel();
54 fur
55 }
56
57 pub(crate) fn stop_with_terr(&self, err: TaskError<E>) -> FutureTaskState<E> {
58 self._stop(Some(err))
59 }
60
61 pub fn stop_with_err(&self, err: E) -> FutureTaskState<E> {
62 self._stop(Some(TaskError::Error(err)))
63 }
64
65 pub fn spawn<F>(&self, fut: F) -> JoinHandle<Result<F::Output, TaskError<E>>>
66 where
67 F: Future + Send + 'static,
68 F::Output: Send + 'static,
69 {
70 let mut g = self.inner.lock().unwrap();
71 g.spawn(self, fut)
72 }
73
74 pub fn spawn_blocking<F, R>(&self, f: F) -> JoinHandle<Result<R, TaskError<E>>>
75 where
76 F: FnOnce(&Context<E>) -> R + Send + 'static,
77 R: Send + 'static,
78 {
79 let mut g = self.inner.lock().unwrap();
80 g.spawn_blocking(self, f)
81 }
82
83 pub(crate) fn work_done(&self) {
84 let mut g = self.inner.lock().unwrap();
85 g.work_count -= 1;
86 trace!("[{:>6}] work count {}", self.id, g.work_count);
87 if g.work_count == 1 && g.state == State::Running {
88 let _ = g.set_state(State::Stopping);
89 }
90
91 if g.work_count == 0 {
92 let _ = g.set_state(State::Stopped);
93 }
94 }
95}
96
97impl<E: TError> Default for Context<E> {
98 fn default() -> Self {
99 static TASK_ID: AtomicU32 = AtomicU32::new(1);
100 let id = TASK_ID.fetch_add(1, Ordering::SeqCst);
101
102 Self {
103 id,
104 inner: Arc::new(Mutex::new(ContextInner {
105 id,
106 work_count: 1,
107 ..Default::default()
108 })),
109 cancel: CancellationToken::new(),
110 }
111 }
112}
113
114struct ContextInner<E: TError> {
115 error: Option<TaskError<E>>,
116 state: State,
117 wakers: Vec<Waker>,
118 work_count: u32,
119 id: u32,
120}
121
122impl<E: TError> ContextInner<E> {
123 fn wake_all(&mut self) {
124 for waker in self.wakers.iter() {
125 waker.wake_by_ref();
126 }
127 self.wakers.clear();
128 }
129
130 fn set_state(&mut self, state: State) -> Result<(), &'static str> {
131 if state < self.state {
132 return Err("state is not allowed");
133 }
134 trace!("[{:>6}] [{:?}]=>[{:?}]", self.id, self.state, state);
135 self.state = state;
136 self.wake_all();
137 Ok(())
138 }
139
140 fn spawn<F>(&mut self, ctx: &Context<E>, fur: F) -> JoinHandle<Result<F::Output, TaskError<E>>>
141 where
142 F: Future + Send + 'static,
143 F::Output: Send + 'static,
144 {
145 let ctx = ctx.clone();
146
147 self.work_count += 1;
148 trace!("[{:>6}] work count {}", ctx.id, self.work_count);
149 let handle = Handle::current();
150
151 handle.spawn(async move {
152 let mut res = Err(TaskError::Cancelled);
153 select! {
154 r = fur =>{
155 trace!("[{:>6}] exit: finish", ctx.id);
156 res = Ok(r);
157 }
158 _ = ctx.cancel.cancelled() => {
159 trace!("[{:>6}] exit: cancel token", ctx.id);
160 }
161 _ = ctx.wait_for(State::Stopping) => {
162 trace!("[{:>6}] exit: stopping", ctx.id);
163 }
164 }
165 ctx.work_done();
166 res
167 })
168 }
169
170 fn spawn_blocking<F, R>(
171 &mut self,
172 ctx: &Context<E>,
173 fur: F,
174 ) -> JoinHandle<Result<R, TaskError<E>>>
175 where
176 F: FnOnce(&Context<E>) -> R + Send + 'static,
177 R: Send + 'static,
178 {
179 let ctx = ctx.clone();
180
181 self.work_count += 1;
182 trace!("[{:>6}] work count {}", ctx.id, self.work_count);
183 let handle = Handle::current();
184
185 handle.spawn_blocking(move || {
186 if !ctx.is_active() {
187 return Err(TaskError::Cancelled);
188 }
189 let r = fur(&ctx);
190 ctx.work_done();
191 Ok(r)
192 })
193 }
194}
195
196impl<E: TError> Default for ContextInner<E> {
197 fn default() -> Self {
198 Self {
199 id: 0,
200 error: None,
201 state: State::default(),
202 wakers: Default::default(),
203 work_count: 0,
204 }
205 }
206}
207
/// Lifecycle of a context.
///
/// Variants are declared in lifecycle order, and the derived `Ord` follows declaration
/// order. State comparisons (`>=` when waiting, `<` when rejecting a transition) rely on
/// this order.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord)]
pub enum State {
    /// Initial state of a freshly created context.
    #[default]
    Idle,
    Preparing,
    Running,
    /// A stop has been requested, or only the context's own unit of work remains.
    Stopping,
    /// No outstanding work remains.
    Stopped,
}
222
223pub struct FutureTaskState<E: TError> {
224 ctx: Context<E>,
225 want: State,
226}
227impl<E: TError> FutureTaskState<E> {
228 fn new(ctx: Context<E>, want: State) -> Self {
229 Self { ctx, want }
230 }
231}
232
233impl<E: TError> Future for FutureTaskState<E> {
234 type Output = Result<(), TaskError<E>>;
235
236 fn poll(
237 self: std::pin::Pin<&mut Self>,
238 cx: &mut std::task::Context<'_>,
239 ) -> std::task::Poll<Self::Output> {
240 let mut g = self.ctx.inner.lock().unwrap();
241 if g.state >= self.want {
242 Poll::Ready(match g.error.clone() {
243 Some(e) => Err(e),
244 None => Ok(()),
245 })
246 } else {
247 g.wakers.push(cx.waker().clone());
248 Poll::Pending
249 }
250 }
251}