// singleton_task/context.rs

1use std::{
2    sync::{
3        Arc, Mutex,
4        atomic::{AtomicU32, Ordering},
5    },
6    task::{Poll, Waker},
7    thread,
8};
9
10use log::trace;
11use tokio::select;
12
13use crate::{TError, TaskError, rt};
14
15#[derive(Clone)]
16pub struct Context<E: TError> {
17    id: u32,
18    inner: Arc<Mutex<ContextInner<E>>>,
19}
20
21impl<E: TError> Context<E> {
22    pub fn id(&self) -> u32 {
23        self.id
24    }
25
26    pub(crate) fn set_state(&self, state: State) -> Result<(), &'static str> {
27        self.inner.lock().unwrap().set_state(state)
28    }
29
30    pub fn wait_for(&self, state: State) -> FutureTaskState<E> {
31        FutureTaskState::new(self.clone(), state)
32    }
33
34    pub fn stop(&self) -> FutureTaskState<E> {
35        self.stop_with_result(Some(TaskError::Cancelled))
36    }
37
38    pub fn stop_with_result(&self, res: Option<TaskError<E>>) -> FutureTaskState<E> {
39        let fur = self.wait_for(State::Stopped);
40        let mut g = self.inner.lock().unwrap();
41        if g.state >= State::Stopping {
42            return fur;
43        }
44        let _ = g.set_state(State::Stopping);
45        g.error = res;
46        g.wake_all();
47        fur
48    }
49
50    pub fn spawn<F>(&self, fut: F)
51    where
52        F: Future + Send + 'static,
53    {
54        let mut g = self.inner.lock().unwrap();
55        g.spawn(self, fut);
56    }
57
58    pub(crate) fn work_done(&self) {
59        let mut g = self.inner.lock().unwrap();
60        g.work_count -= 1;
61        trace!("[{:>6}] work count {}", self.id, g.work_count);
62        if g.work_count == 0 {
63            let _ = g.set_state(State::Stopped);
64        }
65    }
66}
67
68impl<E: TError> Default for Context<E> {
69    fn default() -> Self {
70        static TASK_ID: AtomicU32 = AtomicU32::new(1);
71        let id = TASK_ID.fetch_add(1, Ordering::SeqCst);
72
73        Self {
74            id,
75            inner: Arc::new(Mutex::new(ContextInner {
76                id,
77                work_count: 1,
78                ..Default::default()
79            })),
80        }
81    }
82}
83
84struct ContextInner<E: TError> {
85    error: Option<TaskError<E>>,
86    state: State,
87    wakers: Vec<Waker>,
88    work_count: u32,
89    id: u32,
90}
91
92impl<E: TError> ContextInner<E> {
93    fn wake_all(&mut self) {
94        for waker in self.wakers.iter() {
95            waker.wake_by_ref();
96        }
97        self.wakers.clear();
98    }
99
100    fn set_state(&mut self, state: State) -> Result<(), &'static str> {
101        if state < self.state {
102            return Err("state is not allowed");
103        }
104        trace!("[{:>6}] [{:?}]=>[{:?}]", self.id, self.state, state);
105        self.state = state;
106        self.wake_all();
107        Ok(())
108    }
109
110    fn spawn<F>(&mut self, ctx: &Context<E>, fur: F)
111    where
112        F: Future + Send + 'static,
113    {
114        let ctx = ctx.clone();
115        if matches!(self.state, State::Stopping | State::Stopped) {
116            return;
117        }
118
119        self.work_count += 1;
120        trace!("[{:>6}] work count {}", ctx.id, self.work_count);
121        thread::spawn(move || {
122            rt().block_on(async move {
123                select! {
124                    _ = fur =>{}
125                    _ = ctx.wait_for(State::Stopping) => {}
126                }
127                ctx.work_done();
128            });
129        });
130    }
131}
132
133impl<E: TError> Default for ContextInner<E> {
134    fn default() -> Self {
135        Self {
136            id: 0,
137            error: None,
138            state: State::default(),
139            wakers: Default::default(),
140            work_count: 0,
141        }
142    }
143}
144
/// Lifecycle stage of a task.
///
/// Variants are declared in progression order, so the derived `Ord` reflects
/// progress; `ContextInner::set_state` rejects moving to an earlier variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default)]
pub enum State {
    /// Initial state of a newly created context.
    #[default]
    Idle,
    Preparing,
    Running,
    /// A stop was requested; futures started with `Context::spawn` are
    /// cancelled on entering this state.
    Stopping,
    /// Terminal state; entered automatically once outstanding work reaches zero.
    Stopped,
}
159
160pub struct FutureTaskState<E: TError> {
161    ctx: Context<E>,
162    want: State,
163}
164impl<E: TError> FutureTaskState<E> {
165    fn new(ctx: Context<E>, want: State) -> Self {
166        Self { ctx, want }
167    }
168}
169
170impl<E: TError> Future for FutureTaskState<E> {
171    type Output = Result<(), TaskError<E>>;
172
173    fn poll(
174        self: std::pin::Pin<&mut Self>,
175        cx: &mut std::task::Context<'_>,
176    ) -> std::task::Poll<Self::Output> {
177        let mut g = self.ctx.inner.lock().unwrap();
178        if g.state >= self.want {
179            Poll::Ready(match g.error.clone() {
180                Some(e) => Err(e),
181                None => Ok(()),
182            })
183        } else {
184            g.wakers.push(cx.waker().clone());
185            Poll::Pending
186        }
187    }
188}