async_priority_limiter/
limiter.rs

1pub mod builder;
2
3use crate::{
4    BoxFuture,
5    blocks::Blocks,
6    ingress::Ingress,
7    intervals::Intervals,
8    task::Task,
9    traits::{Key, Priority, TaskResult},
10    worker::Worker,
11};
12
13use std::{collections::BinaryHeap, sync::Arc, time::Duration};
14use tokio::{
15    sync::{Mutex, RwLock, oneshot},
16    time::Instant,
17};
18
19#[derive(Debug)]
20pub struct Limiter<K: Key, P: Priority, T: TaskResult> {
21    tasks: Arc<Mutex<BinaryHeap<Task<K, P, T>>>>,
22    ingress: Ingress<K, P, T>,
23    workers: Mutex<Vec<Worker>>,
24    blocks: Arc<RwLock<Blocks<K>>>,
25    intervals: Arc<RwLock<Intervals<K>>>,
26}
27
28impl<K: Key, P: Priority, T: TaskResult> AsRef<Limiter<K, P, T>> for Limiter<K, P, T> {
29    fn as_ref(&self) -> &Limiter<K, P, T> {
30        self
31    }
32}
33
34impl<P: Priority, T: TaskResult> Limiter<String, P, T> {
35    pub fn new<K: Key>(concurrent_tasks: usize) -> Limiter<K, P, T> {
36        Limiter::new_with(concurrent_tasks, Default::default(), Default::default())
37    }
38
39    pub fn new_with<K: Key>(
40        concurrent_tasks: usize,
41        blocks: Blocks<K>,
42        intervals: Intervals<K>,
43    ) -> Limiter<K, P, T> {
44        let tasks: Arc<Mutex<BinaryHeap<Task<K, P, T>>>> = Default::default();
45        let blocks: Arc<RwLock<Blocks<K>>> = Arc::new(RwLock::new(blocks));
46        let intervals: Arc<RwLock<Intervals<K>>> = Arc::new(RwLock::new(intervals));
47        let ingress = Ingress::spawn(tasks.clone());
48        let workers = Mutex::new(
49            (0..concurrent_tasks)
50                .map(|_| ingress.spawn_worker(tasks.clone(), blocks.clone(), intervals.clone()))
51                .collect(),
52        );
53
54        Limiter {
55            tasks,
56            blocks,
57            intervals,
58            ingress,
59            workers,
60        }
61    }
62}
63
64impl<K: Key, P: Priority, T: TaskResult> Limiter<K, P, T> {
65    pub async fn get_default_block_duration(&self) -> Option<Duration> {
66        self.blocks.write().await.get_default()
67    }
68
69    pub async fn get_block_duration_by_key(&self, key: &K) -> Option<Duration> {
70        self.blocks.write().await.get_by_key(key)
71    }
72
73    pub async fn set_default_block_until_at_least(&self, instant: Instant) {
74        self.blocks.write().await.set_default_at_least(instant);
75    }
76
77    pub async fn set_block_by_key_until_at_least(&self, instant: Instant, key: K) {
78        self.blocks.write().await.set_at_least_by_key(instant, key);
79    }
80
81    pub async fn set_default_block_until(&self, instant: Option<Instant>) {
82        self.blocks.write().await.set_default(instant);
83    }
84
85    pub async fn set_block_by_key_until(&self, instant: Option<Instant>, key: K) {
86        self.blocks.write().await.set_by_key(instant, key);
87    }
88
89    pub async fn set_default_interval_at_least(&self, interval: Duration) {
90        self.intervals.write().await.set_default_at_least(interval);
91    }
92
93    pub async fn set_interval_by_key_at_least(&self, interval: Duration, key: K) {
94        self.intervals
95            .write()
96            .await
97            .set_at_least_by_key(interval, key);
98    }
99
100    pub async fn set_default_interval(&self, interval: Option<Duration>) {
101        self.intervals.write().await.set_default(interval);
102    }
103
104    pub async fn set_interval_by_key(&self, interval: Option<Duration>, key: K) {
105        self.intervals.write().await.set_by_key(interval, key);
106    }
107
108    pub async fn set_concurrent_tasks(&self, concurrent_tasks: usize) {
109        let mut guard = self.workers.lock().await;
110        let len = guard.len();
111
112        match len.cmp(&concurrent_tasks) {
113            std::cmp::Ordering::Less => {
114                for _ in len..concurrent_tasks {
115                    guard.push(self.ingress.spawn_worker(
116                        self.tasks.clone(),
117                        self.blocks.clone(),
118                        self.intervals.clone(),
119                    ));
120                }
121            }
122            std::cmp::Ordering::Equal => {}
123            std::cmp::Ordering::Greater => {
124                guard.drain(concurrent_tasks..);
125            }
126        }
127    }
128
129    pub async fn queue<J: Future<Output = T> + Send + 'static>(
130        &self,
131        job: J,
132        priority: P,
133    ) -> BoxFuture<T> {
134        let (reply_sender, reply_receiver) = oneshot::channel();
135
136        self.ingress
137            .send(Task::new(job, priority, reply_sender))
138            .await;
139
140        Box::pin(async move { reply_receiver.await.expect("reply_sender should not drop") })
141    }
142
143    pub async fn queue_by_key<J: Future<Output = T> + Send + 'static>(
144        &self,
145        job: J,
146        priority: P,
147        key: K,
148    ) -> BoxFuture<T> {
149        let (reply_sender, reply_receiver) = oneshot::channel();
150
151        self.ingress
152            .send(Task::new_with_key(job, priority, reply_sender, key))
153            .await;
154
155        Box::pin(async move { reply_receiver.await.expect("reply_sender should not drop") })
156    }
157}
158
#[cfg(test)]
mod tests {
    use crate::limiter::builder::LimiterBuilder;

    use super::*;
    use futures::future::join_all;
    use std::sync::Arc;
    use tokio::sync::Mutex;

    /// Queues jobs while the limiter has no workers, then starts two workers. Checks that:
    /// - every caller receives its own job's output;
    /// - the jobs themselves ran highest priority first.
    #[tokio::test]
    async fn it_should_work() {
        use Prio::*;

        #[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord)]
        enum Prio {
            Low,
            Mid,
            High,
        }

        // With zero workers nothing can run yet, so every job is queued before any
        // priority decision is made.
        let limiter = LimiterBuilder::new::<String>(0).build();

        let acc: Arc<Mutex<Vec<Prio>>> = Default::default();

        // (priority, value returned by the job), in submission order.
        let jobs = [
            (High, 1),
            (Mid, 2),
            (Low, 3),
            (Low, 4),
            (Mid, 5),
            (High, 6),
            (Mid, 7),
            (Low, 8),
            (High, 9),
        ];

        let mut futures = Vec::with_capacity(jobs.len());
        for (priority, value) in jobs {
            let results = acc.clone();
            let job = async move {
                results.lock().await.push(priority);
                value
            };
            futures.push(limiter.queue(job, priority).await);
        }

        limiter
            .set_default_interval(Some(Duration::from_millis(100)))
            .await;

        limiter.set_concurrent_tasks(2).await;

        let outputs = join_all(futures).await;
        let acc = acc.lock().await.clone();

        // Each reply future resolves to the output of the job it was created for...
        assert_eq!(outputs, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);

        // ...while execution itself followed priority order.
        assert_eq!(acc, [High, High, High, Mid, Mid, Mid, Low, Low, Low]);
    }
}