// singleton_task/context.rs

use std::{
    sync::{
        Arc, Mutex,
        atomic::{AtomicU32, Ordering},
    },
    task::{Poll, Waker},
};

use log::trace;
use tokio::select;
use tokio_util::sync::CancellationToken;

use crate::{TError, TaskError};
15#[derive(Clone)]
16pub struct Context<E: TError> {
17    id: u32,
18    inner: Arc<Mutex<ContextInner<E>>>,
19    cancel: CancellationToken,
20}
21
22impl<E: TError> Context<E> {
23    pub fn id(&self) -> u32 {
24        self.id
25    }
26
27    pub(crate) fn set_state(&self, state: State) -> Result<(), &'static str> {
28        self.inner.lock().unwrap().set_state(state)
29    }
30
31    pub fn wait_for(&self, state: State) -> FutureTaskState<E> {
32        FutureTaskState::new(self.clone(), state)
33    }
34
35    pub fn stop(&self) -> FutureTaskState<E> {
36        self.stop_with_result(Some(TaskError::Cancelled))
37    }
38
39    pub fn is_active(&self) -> bool {
40        !self.cancel.is_cancelled()
41    }
42
43    pub fn stop_with_result(&self, res: Option<TaskError<E>>) -> FutureTaskState<E> {
44        let fur = self.wait_for(State::Stopped);
45        let mut g = self.inner.lock().unwrap();
46        if g.state >= State::Stopping {
47            return fur;
48        }
49        let _ = g.set_state(State::Stopping);
50        g.error = res;
51        g.wake_all();
52        drop(g);
53        self.cancel.cancel();
54        fur
55    }
56
57    pub fn spawn<F>(&self, fut: F)
58    where
59        F: Future + Send + 'static,
60    {
61        let mut g = self.inner.lock().unwrap();
62        g.spawn(self, fut);
63    }
64
65    pub(crate) fn work_done(&self) {
66        let mut g = self.inner.lock().unwrap();
67        g.work_count -= 1;
68        trace!("[{:>6}] work count {}", self.id, g.work_count);
69        if g.work_count == 1 && g.state == State::Running {
70            let _ = g.set_state(State::Stopping);
71        }
72
73        if g.work_count == 0 {
74            let _ = g.set_state(State::Stopped);
75        }
76    }
77}
78
79impl<E: TError> Default for Context<E> {
80    fn default() -> Self {
81        static TASK_ID: AtomicU32 = AtomicU32::new(1);
82        let id = TASK_ID.fetch_add(1, Ordering::SeqCst);
83
84        Self {
85            id,
86            inner: Arc::new(Mutex::new(ContextInner {
87                id,
88                work_count: 1,
89                ..Default::default()
90            })),
91            cancel: CancellationToken::new(),
92        }
93    }
94}
95
96struct ContextInner<E: TError> {
97    error: Option<TaskError<E>>,
98    state: State,
99    wakers: Vec<Waker>,
100    work_count: u32,
101    id: u32,
102}
103
104impl<E: TError> ContextInner<E> {
105    fn wake_all(&mut self) {
106        for waker in self.wakers.iter() {
107            waker.wake_by_ref();
108        }
109        self.wakers.clear();
110    }
111
112    fn set_state(&mut self, state: State) -> Result<(), &'static str> {
113        if state < self.state {
114            return Err("state is not allowed");
115        }
116        trace!("[{:>6}] [{:?}]=>[{:?}]", self.id, self.state, state);
117        self.state = state;
118        self.wake_all();
119        Ok(())
120    }
121
122    fn spawn<F>(&mut self, ctx: &Context<E>, fur: F)
123    where
124        F: Future + Send + 'static,
125    {
126        let ctx = ctx.clone();
127        if matches!(self.state, State::Stopping | State::Stopped) {
128            return;
129        }
130
131        self.work_count += 1;
132        trace!("[{:>6}] work count {}", ctx.id, self.work_count);
133
134        tokio::spawn(async move {
135            select! {
136                _ = fur =>{}
137                _ = ctx.cancel.cancelled() => {}
138                _ = ctx.wait_for(State::Stopping) => {}
139            }
140            ctx.work_done();
141        });
142    }
143}
144
145impl<E: TError> Default for ContextInner<E> {
146    fn default() -> Self {
147        Self {
148            id: 0,
149            error: None,
150            state: State::default(),
151            wakers: Default::default(),
152            work_count: 0,
153        }
154    }
155}
156
/// Lifecycle stages of a task context.
///
/// Variants are declared in lifecycle order, and the derived `Ord` follows
/// declaration order. Transitions are accepted only in non-decreasing order,
/// and "has reached at least X" checks compare with this ordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default)]
pub enum State {
    /// Initial state of a newly created context (the `Default`).
    #[default]
    Idle,
    Preparing,
    Running,
    /// Shutdown has begun; new work is no longer accepted.
    Stopping,
    /// Final state, entered when the tracked work count reaches zero.
    Stopped,
}
172pub struct FutureTaskState<E: TError> {
173    ctx: Context<E>,
174    want: State,
175}
176impl<E: TError> FutureTaskState<E> {
177    fn new(ctx: Context<E>, want: State) -> Self {
178        Self { ctx, want }
179    }
180}
181
182impl<E: TError> Future for FutureTaskState<E> {
183    type Output = Result<(), TaskError<E>>;
184
185    fn poll(
186        self: std::pin::Pin<&mut Self>,
187        cx: &mut std::task::Context<'_>,
188    ) -> std::task::Poll<Self::Output> {
189        let mut g = self.ctx.inner.lock().unwrap();
190        if g.state >= self.want {
191            Poll::Ready(match g.error.clone() {
192                Some(e) => Err(e),
193                None => Ok(()),
194            })
195        } else {
196            g.wakers.push(cx.waker().clone());
197            Poll::Pending
198        }
199    }
200}