singleton_task/
lib.rs

1pub use std::sync::mpsc::{Receiver, SyncSender};
2use std::{
3    error::Error,
4    fmt::Display,
5    sync::{Arc, OnceLock, mpsc::sync_channel},
6    thread,
7};
8
9pub use async_trait::async_trait;
10use tokio::{runtime::Runtime, select};
11
12use context::{FutureTaskState, State};
13use log::{trace, warn};
14
15mod context;
16mod task_chan;
17
18pub use context::Context;
19use task_chan::{TaskReceiver, TaskSender, task_channel};
20
/// Process-wide Tokio runtime, initialised lazily by `rt()` and shared by every
/// worker thread that drives task lifecycles through `block_on`.
static RT: OnceLock<Runtime> = OnceLock::new();
22
/// Marker trait for the error type a task reports.
///
/// Any error that is cloneable, sendable across threads and free of borrowed
/// data can be used; implement it with an empty `impl TError for MyError {}`.
pub trait TError
where
    Self: Error + Clone + Send + 'static,
{
}
24
25#[derive(Debug, Clone)]
26pub enum TaskError<E: TError> {
27    Cancelled,
28    Error(E),
29}
30
31impl<E: TError> From<E> for TaskError<E> {
32    fn from(value: E) -> Self {
33        Self::Error(value)
34    }
35}
36
37impl<E: TError> Display for TaskError<E> {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        match self {
40            Self::Cancelled => write!(f, "Cancelled"),
41            Self::Error(e) => write!(f, "{}", e),
42        }
43    }
44}
45
46pub trait TaskBuilder {
47    type Output: Send + 'static;
48    type Error: TError;
49    type Task: Task<Self::Error>;
50
51    fn build(self, tx: SyncSender<Self::Output>) -> Self::Task;
52    fn channel_size(&self) -> usize {
53        10
54    }
55}
56
57#[async_trait]
58pub trait Task<E: TError>: Send + 'static {
59    async fn on_start(&mut self, ctx: Context<E>) -> Result<(), E> {
60        drop(ctx);
61        trace!("on_start");
62        Ok(())
63    }
64    async fn on_stop(&mut self, ctx: Context<E>) -> Result<(), E> {
65        drop(ctx);
66        trace!("on_stop");
67        Ok(())
68    }
69}
70
71struct TaskBox<E: TError> {
72    task: Box<dyn Task<E>>,
73    ctx: Context<E>,
74}
75
76struct WaitingTask<E: TError> {
77    task: TaskBox<E>,
78}
79
80#[derive(Clone)]
81pub struct SingletonTask<E: TError> {
82    tx: TaskSender<E>,
83    _drop: Arc<TaskDrop<E>>,
84}
85
86impl<E: TError> SingletonTask<E> {
87    pub fn new() -> Self {
88        let (tx, rx) = task_channel::<E>();
89
90        thread::spawn(move || Self::work_deal_start(rx));
91
92        Self {
93            _drop: Arc::new(TaskDrop { tx: tx.clone() }),
94            tx,
95        }
96    }
97
98    fn work_deal_start(rx: TaskReceiver<E>) {
99        while let Some(next) = rx.recv() {
100            let id = next.task.ctx.id();
101            if let Err(e) = Self::work_start_task(next) {
102                warn!("task [{}] error: {}", id, e);
103            }
104        }
105    }
106
107    fn work_start_task(next: WaitingTask<E>) -> Result<(), TaskError<E>> {
108        trace!("run task {}", next.task.ctx.id());
109        let ctx = next.task.ctx.clone();
110        let mut task = next.task.task;
111        match rt().block_on(async {
112            select! {
113                res = task.on_start(ctx.clone()) => res.map_err(|e|e.into()),
114                res = ctx.wait_for(State::Stopping) => res
115            }
116        }) {
117            Ok(_) => {
118                if ctx.set_state(State::Running).is_err() {
119                    return Err(TaskError::Cancelled);
120                };
121            }
122            Err(e) => {
123                ctx.stop_with_result(Some(e));
124            }
125        }
126
127        rt().block_on(async {
128            let _ = ctx.wait_for(State::Stopping).await;
129            let _ = task.on_stop(ctx.clone()).await;
130            ctx.work_done();
131            let _ = ctx.wait_for(State::Stopped).await;
132        });
133
134        Ok(())
135    }
136
137    pub async fn start<T: TaskBuilder<Error = E>>(
138        &self,
139        task_builder: T,
140    ) -> Result<TaskHandle<T::Output, E>, TaskError<E>> {
141        let channel_size = task_builder.channel_size();
142        let (tx, rx) = sync_channel::<T::Output>(channel_size);
143        let task = Box::new(task_builder.build(tx));
144        let task_box = TaskBox {
145            task,
146            ctx: Context::default(),
147        };
148        let ctx = task_box.ctx.clone();
149
150        self.tx.send(WaitingTask { task: task_box });
151
152        ctx.wait_for(State::Running).await?;
153
154        Ok(TaskHandle { rx, ctx })
155    }
156}
157
158struct TaskDrop<E: TError> {
159    tx: TaskSender<E>,
160}
161impl<E: TError> Drop for TaskDrop<E> {
162    fn drop(&mut self) {
163        self.tx.stop();
164    }
165}
166
167impl<E: TError> Default for SingletonTask<E> {
168    fn default() -> Self {
169        Self::new()
170    }
171}
172
173pub struct TaskHandle<T, E: TError> {
174    pub rx: Receiver<T>,
175    pub ctx: Context<E>,
176}
177
178impl<T, E: TError> TaskHandle<T, E> {
179    pub fn stop(self) -> FutureTaskState<E> {
180        self.ctx.stop()
181    }
182    pub fn wait_for_stopped(self) -> impl Future<Output = Result<(), TaskError<E>>> {
183        self.ctx.wait_for(State::Stopped)
184    }
185
186    pub fn recv(&self) -> Result<T, std::sync::mpsc::RecvError> {
187        self.rx.recv()
188    }
189}
190
191fn rt() -> &'static Runtime {
192    RT.get_or_init(|| {
193        tokio::runtime::Builder::new_current_thread()
194            .enable_all()
195            .build()
196            .unwrap()
197    })
198}