async_priority_limiter/
limiter.rs

1use crate::{
2    BoxFuture,
3    auto_traits::{Key, Priority, TaskResult},
4    blocks::Blocks,
5    ingress::Ingress,
6    intervals::Intervals,
7    task::Task,
8    worker::Worker,
9};
10
11use std::{collections::BinaryHeap, sync::Arc, time::Duration};
12use tokio::{
13    sync::{Mutex, RwLock, oneshot},
14    time::Instant,
15};
16
17pub struct Limiter<K: Key, P: Priority, T: TaskResult> {
18    tasks: Arc<Mutex<BinaryHeap<Task<K, P, T>>>>,
19    ingress: Ingress<K, P, T>,
20    workers: Mutex<Vec<Worker>>,
21    blocks: Arc<RwLock<Blocks<K>>>,
22    intervals: Arc<RwLock<Intervals<K>>>,
23}
24
25impl<P: Priority, T: TaskResult> Default for Limiter<String, P, T> {
26    fn default() -> Self {
27        Self::new(1)
28    }
29}
30
31impl<P: Priority, T: TaskResult> Limiter<String, P, T> {
32    pub fn new<K: Key>(concurrent: usize) -> Limiter<K, P, T> {
33        let tasks: Arc<Mutex<BinaryHeap<Task<K, P, T>>>> = Default::default();
34        let blocks: Arc<RwLock<Blocks<K>>> = Default::default();
35        let intervals: Arc<RwLock<Intervals<K>>> = Default::default();
36        let ingress = Ingress::spawn(tasks.clone());
37        let workers = Mutex::new(
38            (0..concurrent)
39                .map(|_| ingress.spawn_worker(tasks.clone(), blocks.clone(), intervals.clone()))
40                .collect(),
41        );
42
43        Limiter {
44            tasks,
45            blocks,
46            intervals,
47            ingress,
48            workers,
49        }
50    }
51}
52
53impl<K: Key, P: Priority, T: TaskResult> Limiter<K, P, T> {
54    pub async fn get_default_block_duration(&self) -> Option<Duration> {
55        self.blocks.write().await.get_default()
56    }
57
58    pub async fn get_block_duration_by_key(&self, key: &K) -> Option<Duration> {
59        self.blocks.write().await.get_by_key(key)
60    }
61
62    pub async fn set_default_block_until_at_least(&self, instant: Instant) {
63        self.blocks.write().await.set_default_at_least(instant);
64    }
65
66    pub async fn set_block_by_key_until_at_least(&self, instant: Instant, key: K) {
67        self.blocks.write().await.set_at_least_by_key(instant, key);
68    }
69
70    pub async fn set_default_block_until(&self, instant: Option<Instant>) {
71        self.blocks.write().await.set_default(instant);
72    }
73
74    pub async fn set_block_by_key_until(&self, instant: Option<Instant>, key: K) {
75        self.blocks.write().await.set_by_key(instant, key);
76    }
77
78    pub async fn set_default_interval_at_least(&self, interval: Duration) {
79        self.intervals.write().await.set_default_at_least(interval);
80    }
81
82    pub async fn set_interval_by_key_at_least(&self, interval: Duration, key: K) {
83        self.intervals
84            .write()
85            .await
86            .set_at_least_by_key(interval, key);
87    }
88
89    pub async fn set_default_interval(&self, interval: Option<Duration>) {
90        self.intervals.write().await.set_default(interval);
91    }
92
93    pub async fn set_interval_by_key(&self, interval: Option<Duration>, key: K) {
94        self.intervals.write().await.set_by_key(interval, key);
95    }
96
97    pub async fn set_concurrent_tasks(&self, concurrent_tasks: usize) {
98        let mut guard = self.workers.lock().await;
99        let len = guard.len();
100
101        match len.cmp(&concurrent_tasks) {
102            std::cmp::Ordering::Less => {
103                for _ in len..concurrent_tasks {
104                    guard.push(self.ingress.spawn_worker(
105                        self.tasks.clone(),
106                        self.blocks.clone(),
107                        self.intervals.clone(),
108                    ));
109                }
110            }
111            std::cmp::Ordering::Equal => {}
112            std::cmp::Ordering::Greater => {
113                guard.drain(concurrent_tasks..);
114            }
115        }
116    }
117
118    pub async fn queue<J: Future<Output = T> + Send + 'static>(
119        &self,
120        job: J,
121        priority: P,
122    ) -> BoxFuture<T> {
123        let (reply_sender, reply_receiver) = oneshot::channel();
124
125        self.ingress
126            .send(Task::new(job, priority, reply_sender))
127            .await;
128
129        Box::pin(async move { reply_receiver.await.expect("reply_sender should not drop") })
130    }
131
132    pub async fn queue_by_key<J: Future<Output = T> + Send + 'static>(
133        &self,
134        job: J,
135        priority: P,
136        key: K,
137    ) -> BoxFuture<T> {
138        let (reply_sender, reply_receiver) = oneshot::channel();
139
140        self.ingress
141            .send(Task::new_with_key(job, priority, reply_sender, key))
142            .await;
143
144        Box::pin(async move { reply_receiver.await.expect("reply_sender should not drop") })
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use futures::future::join_all;
    use std::sync::Arc;
    use tokio::sync::Mutex;

    /// Queues jobs while no workers exist, then starts two workers. It checks
    /// that every reply resolves to its own value and that jobs ran grouped
    /// by descending priority.
    #[tokio::test]
    async fn it_should_work() {
        #[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord)]
        enum Prio {
            Low,
            Mid,
            High,
        }

        use Prio::*;

        // (priority, value the job returns), in submission order.
        let submissions = [
            (High, 1),
            (Mid, 2),
            (Low, 3),
            (Low, 4),
            (Mid, 5),
            (High, 6),
            (Mid, 7),
            (Low, 8),
            (High, 9),
        ];

        // Zero workers: everything stays queued until concurrency is raised.
        let limiter = Limiter::new::<String>(0);
        let acc: Arc<Mutex<Vec<Prio>>> = Default::default();

        let mut futures = Vec::with_capacity(submissions.len());
        for (prio, value) in submissions {
            let results = Arc::clone(&acc);
            let job = async move {
                results.lock().await.push(prio);
                value
            };
            futures.push(limiter.queue(job, prio).await);
        }

        limiter
            .set_default_interval(Some(Duration::from_millis(100)))
            .await;

        limiter.set_concurrent_tasks(2).await;

        let order = join_all(futures).await;
        let acc = acc.lock().await.clone();

        assert_eq!(order, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);

        assert_eq!(acc, [High, High, High, Mid, Mid, Mid, Low, Low, Low]);
    }
}