singleton_task/
lib.rs

1pub use std::sync::mpsc::{Receiver, SyncSender};
2use std::{
3    error::Error,
4    fmt::Display,
5    sync::{Arc, OnceLock, mpsc::sync_channel},
6    thread,
7};
8
9use context::{FutureTaskState, State};
10pub use futures::{FutureExt, future::LocalBoxFuture};
11use log::{trace, warn};
12
13mod context;
14mod task_chan;
15
16pub use context::Context;
17use task_chan::{TaskReceiver, TaskSender, task_channel};
18use tokio::{runtime::Runtime, select};
19
20static RT: OnceLock<Runtime> = OnceLock::new();
21
/// Bound for task error types: a cloneable, thread-transferable, `'static` [`Error`].
///
/// Implement it explicitly on your error type (`impl TError for MyError {}`).
pub trait TError
where
    Self: Error + Clone + Send + 'static,
{
}
23
24#[derive(Debug, Clone)]
25pub enum TaskError<E: TError> {
26    Cancelled,
27    Error(E),
28}
29
30impl<E: TError> From<E> for TaskError<E> {
31    fn from(value: E) -> Self {
32        Self::Error(value)
33    }
34}
35
36impl<E: TError> Display for TaskError<E> {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        match self {
39            Self::Cancelled => write!(f, "Cancelled"),
40            Self::Error(e) => write!(f, "{}", e),
41        }
42    }
43}
44
45pub trait TaskBuilder {
46    type Output: Send + 'static;
47    type Error: TError;
48    type Task: Task<Self::Error>;
49
50    fn build(self, tx: SyncSender<Self::Output>) -> Self::Task;
51    fn channel_size(&self) -> usize {
52        10
53    }
54}
55
56pub trait Task<E: TError>: Send + 'static {
57    fn on_start(&mut self, ctx: Context<E>) -> LocalBoxFuture<'_, Result<(), E>> {
58        drop(ctx);
59        async {
60            trace!("on_start");
61            Ok(())
62        }
63        .boxed_local()
64    }
65    fn on_stop(&mut self, ctx: Context<E>) -> LocalBoxFuture<'_, Result<(), E>> {
66        drop(ctx);
67        async {
68            trace!("on_stop");
69            Ok(())
70        }
71        .boxed_local()
72    }
73}
74
75struct TaskBox<E: TError> {
76    task: Box<dyn Task<E>>,
77    ctx: Context<E>,
78}
79
80struct WaitingTask<E: TError> {
81    task: TaskBox<E>,
82}
83
84#[derive(Clone)]
85pub struct SingletonTask<E: TError> {
86    tx: TaskSender<E>,
87    _drop: Arc<TaskDrop<E>>,
88}
89
90impl<E: TError> SingletonTask<E> {
91    pub fn new() -> Self {
92        let (tx, rx) = task_channel::<E>();
93
94        thread::spawn(move || Self::work_deal_start(rx));
95
96        Self {
97            _drop: Arc::new(TaskDrop { tx: tx.clone() }),
98            tx,
99        }
100    }
101
102    fn work_deal_start(rx: TaskReceiver<E>) {
103        while let Some(next) = rx.recv() {
104            let id = next.task.ctx.id();
105            if let Err(e) = Self::work_start_task(next) {
106                warn!("task [{}] error: {}", id, e);
107            }
108        }
109    }
110
111    fn work_start_task(next: WaitingTask<E>) -> Result<(), TaskError<E>> {
112        trace!("run task {}", next.task.ctx.id());
113        let ctx = next.task.ctx.clone();
114        let mut task = next.task.task;
115        match rt().block_on(async {
116            select! {
117                res = task.on_start(ctx.clone()) => res.map_err(|e|e.into()),
118                res = ctx.wait_for(State::Stopping) => res
119            }
120        }) {
121            Ok(_) => {
122                if ctx.set_state(State::Running).is_err() {
123                    return Err(TaskError::Cancelled);
124                };
125            }
126            Err(e) => {
127                ctx.stop_with_result(Some(e));
128            }
129        }
130
131        rt().block_on(async {
132            let _ = ctx.wait_for(State::Stopping).await;
133            let _ = task.on_stop(ctx.clone()).await;
134            ctx.work_done();
135            let _ = ctx.wait_for(State::Stopped).await;
136        });
137
138        Ok(())
139    }
140
141    pub async fn start<T: TaskBuilder<Error = E>>(
142        &self,
143        task_builder: T,
144    ) -> Result<TaskHandle<T::Output, E>, TaskError<E>> {
145        let channel_size = task_builder.channel_size();
146        let (tx, rx) = sync_channel::<T::Output>(channel_size);
147        let task = Box::new(task_builder.build(tx));
148        let task_box = TaskBox {
149            task,
150            ctx: Context::default(),
151        };
152        let ctx = task_box.ctx.clone();
153
154        self.tx.send(WaitingTask { task: task_box });
155
156        ctx.wait_for(State::Running).await?;
157
158        Ok(TaskHandle { rx, ctx })
159    }
160}
161
162struct TaskDrop<E: TError> {
163    tx: TaskSender<E>,
164}
165impl<E: TError> Drop for TaskDrop<E> {
166    fn drop(&mut self) {
167        self.tx.stop();
168    }
169}
170
171impl<E: TError> Default for SingletonTask<E> {
172    fn default() -> Self {
173        Self::new()
174    }
175}
176
177pub struct TaskHandle<T, E: TError> {
178    pub rx: Receiver<T>,
179    pub ctx: Context<E>,
180}
181
182impl<T, E: TError> TaskHandle<T, E> {
183    pub fn stop(self) -> FutureTaskState<E> {
184        self.ctx.stop()
185    }
186    pub fn wait_for_stopped(self) -> impl Future<Output = Result<(), TaskError<E>>> {
187        self.ctx.wait_for(State::Stopped)
188    }
189
190    pub fn recv(&self) -> Result<T, std::sync::mpsc::RecvError> {
191        self.rx.recv()
192    }
193}
194
195fn rt() -> &'static Runtime {
196    RT.get_or_init(|| {
197        tokio::runtime::Builder::new_current_thread()
198            .enable_all()
199            .build()
200            .unwrap()
201    })
202}
203
#[cfg(test)]
mod test {
    use log::LevelFilter;

    use super::*;

    /// Minimal error type for exercising the generic task machinery.
    #[derive(Debug, Clone)]
    enum Error1 {
        _A,
    }

    impl TError for Error1 {}
    impl Error for Error1 {}
    impl Display for Error1 {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{:?}", self)
        }
    }

    /// Task whose `on_start` succeeds immediately; `on_stop` uses the trait default.
    struct Task1 {
        _a: i32,
    }

    impl Task<Error1> for Task1 {
        fn on_start(&mut self, _ctx: Context<Error1>) -> LocalBoxFuture<'_, Result<(), Error1>> {
            async {
                trace!("on_start 1");
                Ok(())
            }
            .boxed_local()
        }
    }

    /// Builder for [`Task1`]; its output channel is unused.
    struct Task1Builder {}

    impl TaskBuilder for Task1Builder {
        type Output = u32;
        type Error = Error1;
        type Task = Task1;

        fn build(self, _tx: SyncSender<u32>) -> Self::Task {
            Task1 { _a: 1 }
        }
    }

    /// Installs the test logger.
    ///
    /// `try_init` is used instead of `init`: `init` panics if a global logger is already set,
    /// which happens as soon as a second test in this binary calls it.
    fn init_logger() {
        let _ = env_logger::builder()
            .is_test(true)
            .filter_level(LevelFilter::Trace)
            .try_init();
    }

    #[tokio::test]
    async fn test_task() {
        init_logger();

        let st = SingletonTask::<Error1>::new();
        let _handle = st
            .start(Task1Builder {})
            .await
            .expect("Task1::on_start succeeds, so start should reach Running");
    }
}