// singleton_task/context.rs
use std::{
    sync::{
        Arc, Mutex,
        atomic::{AtomicU32, Ordering},
    },
    task::{Poll, Waker},
    thread,
};

use log::{trace, warn};
use tokio::select;
use tokio_util::sync::CancellationToken;

use crate::{TError, TaskError, rt};

16#[derive(Clone)]
17pub struct Context<E: TError> {
18 id: u32,
19 inner: Arc<Mutex<ContextInner<E>>>,
20 cancel: CancellationToken,
21}
22
23impl<E: TError> Context<E> {
24 pub fn id(&self) -> u32 {
25 self.id
26 }
27
28 pub(crate) fn set_state(&self, state: State) -> Result<(), &'static str> {
29 self.inner.lock().unwrap().set_state(state)
30 }
31
32 pub fn wait_for(&self, state: State) -> FutureTaskState<E> {
33 FutureTaskState::new(self.clone(), state)
34 }
35
36 pub fn stop(&self) -> FutureTaskState<E> {
37 self.stop_with_result(Some(TaskError::Cancelled))
38 }
39
40 pub fn is_active(&self) -> bool {
41 !self.cancel.is_cancelled()
42 }
43
44 pub fn stop_with_result(&self, res: Option<TaskError<E>>) -> FutureTaskState<E> {
45 let fur = self.wait_for(State::Stopped);
46 let mut g = self.inner.lock().unwrap();
47 if g.state >= State::Stopping {
48 return fur;
49 }
50 let _ = g.set_state(State::Stopping);
51 g.error = res;
52 g.wake_all();
53 drop(g);
54 self.cancel.cancel();
55 fur
56 }
57
58 pub fn spawn<F>(&self, fut: F)
59 where
60 F: Future + Send + 'static,
61 {
62 let mut g = self.inner.lock().unwrap();
63 g.spawn(self, fut);
64 }
65
66 pub(crate) fn work_done(&self) {
67 let mut g = self.inner.lock().unwrap();
68 g.work_count -= 1;
69 trace!("[{:>6}] work count {}", self.id, g.work_count);
70 if g.work_count == 0 {
71 let _ = g.set_state(State::Stopped);
72 }
73 }
74}
75
76impl<E: TError> Default for Context<E> {
77 fn default() -> Self {
78 static TASK_ID: AtomicU32 = AtomicU32::new(1);
79 let id = TASK_ID.fetch_add(1, Ordering::SeqCst);
80
81 Self {
82 id,
83 inner: Arc::new(Mutex::new(ContextInner {
84 id,
85 work_count: 1,
86 ..Default::default()
87 })),
88 cancel: CancellationToken::new(),
89 }
90 }
91}
92
93struct ContextInner<E: TError> {
94 error: Option<TaskError<E>>,
95 state: State,
96 wakers: Vec<Waker>,
97 work_count: u32,
98 id: u32,
99}
100
101impl<E: TError> ContextInner<E> {
102 fn wake_all(&mut self) {
103 for waker in self.wakers.iter() {
104 waker.wake_by_ref();
105 }
106 self.wakers.clear();
107 }
108
109 fn set_state(&mut self, state: State) -> Result<(), &'static str> {
110 if state < self.state {
111 return Err("state is not allowed");
112 }
113 trace!("[{:>6}] [{:?}]=>[{:?}]", self.id, self.state, state);
114 self.state = state;
115 self.wake_all();
116 Ok(())
117 }
118
119 fn spawn<F>(&mut self, ctx: &Context<E>, fur: F)
120 where
121 F: Future + Send + 'static,
122 {
123 let ctx = ctx.clone();
124 if matches!(self.state, State::Stopping | State::Stopped) {
125 return;
126 }
127
128 self.work_count += 1;
129 trace!("[{:>6}] work count {}", ctx.id, self.work_count);
130 thread::spawn(move || {
131 rt().block_on(async move {
132 select! {
133 _ = fur =>{}
134 _ = ctx.cancel.cancelled() => {}
135 _ = ctx.wait_for(State::Stopping) => {}
136 }
137 ctx.work_done();
138 });
139 });
140 }
141}
142
143impl<E: TError> Default for ContextInner<E> {
144 fn default() -> Self {
145 Self {
146 id: 0,
147 error: None,
148 state: State::default(),
149 wakers: Default::default(),
150 work_count: 0,
151 }
152 }
153}
154
/// Lifecycle of a task, declared in the order it progresses.
///
/// The derived `Ord` follows declaration order. State transitions rely on it
/// to refuse moving backwards.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord)]
pub enum State {
    /// Initial state of a freshly created context.
    #[default]
    Idle,
    Preparing,
    Running,
    /// A stop has been requested; no new work is accepted.
    Stopping,
    /// Terminal state; entered once no outstanding work remains.
    Stopped,
}

170pub struct FutureTaskState<E: TError> {
171 ctx: Context<E>,
172 want: State,
173}
174impl<E: TError> FutureTaskState<E> {
175 fn new(ctx: Context<E>, want: State) -> Self {
176 Self { ctx, want }
177 }
178}
179
180impl<E: TError> Future for FutureTaskState<E> {
181 type Output = Result<(), TaskError<E>>;
182
183 fn poll(
184 self: std::pin::Pin<&mut Self>,
185 cx: &mut std::task::Context<'_>,
186 ) -> std::task::Poll<Self::Output> {
187 let mut g = self.ctx.inner.lock().unwrap();
188 if g.state >= self.want {
189 Poll::Ready(match g.error.clone() {
190 Some(e) => Err(e),
191 None => Ok(()),
192 })
193 } else {
194 g.wakers.push(cx.waker().clone());
195 Poll::Pending
196 }
197 }
198}