// bmart/workers.rs

1// TODO logs
2use crate::Error;
3use std::collections::BTreeMap;
4use std::fmt;
5use std::sync::Arc;
6use std::time::Duration;
7use tokio::sync::{mpsc, Notify};
8use tokio::task;
9use tokio::time::{sleep_until, Instant};
10
/// Message passed to `Error::duplicate` when a worker ID is already registered.
const ERR_DUPLICATE_WORKER_ID: &str = "Duplicate worker ID";
/// Message passed to `Error::not_found` when no worker has the requested ID.
const ERR_WORKER_NOT_FOUND: &str = "Worker not found";

14#[derive(Debug)]
15pub struct Scheduler {
16    interval: Duration,
17    trigger: Arc<Notify>,
18}
19
20impl Scheduler {
21    pub fn new(trigger: Arc<Notify>, interval: Duration) -> Self {
22        Self { interval, trigger }
23    }
24    pub async fn run(&mut self) {
25        let mut t = Instant::now();
26        loop {
27            t += self.interval;
28            sleep_until(t).await;
29            self.trigger.notify_waiters();
30        }
31    }
32    pub async fn run_instant(&mut self) {
33        let mut t = Instant::now();
34        loop {
35            self.trigger.notify_waiters();
36            t += self.interval;
37            sleep_until(t).await;
38        }
39    }
40}
41
42pub struct WorkerFactory {
43    schedulers: BTreeMap<String, task::JoinHandle<()>>,
44}
45
46impl Default for WorkerFactory {
47    fn default() -> Self {
48        Self::new()
49    }
50}
51
52impl WorkerFactory {
53    #[must_use]
54    pub fn new() -> Self {
55        Self {
56            schedulers: BTreeMap::new(),
57        }
58    }
59
60    /// # Errors
61    ///
62    /// Will return `Err` if the worker already exists
63    pub fn create_scheduler(
64        &mut self,
65        worker_id: &str,
66        trigger: Arc<Notify>,
67        interval: Duration,
68        instant: bool,
69    ) -> Result<(), Error> {
70        self._create_scheduler(worker_id, trigger, interval, false, instant)
71    }
72
73    /// # Errors
74    ///
75    /// Will return `Err` if failed to recreate the worker
76    pub fn recreate_scheduler(
77        &mut self,
78        worker_id: &str,
79        trigger: Arc<Notify>,
80        interval: Duration,
81        instant: bool,
82    ) -> Result<(), Error> {
83        self._create_scheduler(worker_id, trigger, interval, true, instant)
84    }
85
86    fn _create_scheduler(
87        &mut self,
88        worker_id: &str,
89        trigger: Arc<Notify>,
90        interval: Duration,
91        recreate: bool,
92        instant: bool,
93    ) -> Result<(), Error> {
94        if self.schedulers.contains_key(worker_id) {
95            if recreate {
96                let _r = self.destroy_scheduler(worker_id);
97            } else {
98                return Err(Error::duplicate(ERR_DUPLICATE_WORKER_ID));
99            }
100        }
101        let mut scheduler = Scheduler::new(trigger, interval);
102        let fut = if instant {
103            tokio::spawn(async move {
104                scheduler.run_instant().await;
105            })
106        } else {
107            tokio::spawn(async move {
108                scheduler.run().await;
109            })
110        };
111        self.schedulers.insert(worker_id.to_owned(), fut);
112        Ok(())
113    }
114
115    /// # Errors
116    ///
117    /// Will return `Err` if the worker does not exist
118    pub fn destroy_scheduler(&mut self, worker_id: &str) -> Result<(), Error> {
119        self.schedulers.remove(worker_id).map_or(
120            Err(Error::not_found(ERR_WORKER_NOT_FOUND)),
121            |fut| {
122                fut.abort();
123                Ok(())
124            },
125        )
126    }
127}
128
129pub struct TaskWorker<F, Fut, T>
130where
131    F: FnMut(T) -> Fut,
132    Fut: std::future::Future<Output = ()>,
133    T: Sync + fmt::Debug,
134{
135    func: F,
136    rx: mpsc::Receiver<T>,
137}
138
139impl<F, Fut, T> TaskWorker<F, Fut, T>
140where
141    F: FnMut(T) -> Fut,
142    Fut: std::future::Future<Output = ()>,
143    T: Sync + fmt::Debug,
144{
145    pub fn new(func: F, buf: usize) -> (Self, mpsc::Sender<T>) {
146        let (tx, rx) = mpsc::channel(buf);
147        (Self { func, rx }, tx)
148    }
149
150    pub async fn run(&mut self) {
151        while let Some(v) = self.rx.recv().await {
152            (self.func)(v).await;
153        }
154    }
155}