async_priority_limiter/
limiter.rs

1pub mod builder;
2
3use crate::{
4    BoxFuture,
5    blocks::Blocks,
6    ingress::Ingress,
7    intervals::Intervals,
8    task::Task,
9    traits::{Key, Priority, TaskResult},
10    worker::Worker,
11};
12
13use std::{collections::BinaryHeap, sync::Arc, time::Duration};
14use tokio::{
15    sync::{Mutex, RwLock, oneshot},
16    time::Instant,
17};
18
19#[derive(Debug)]
20pub struct Limiter<K: Key, P: Priority, T: TaskResult> {
21    tasks: Arc<Mutex<BinaryHeap<Task<K, P, T>>>>,
22    ingress: Ingress<K, P, T>,
23    workers: Mutex<Vec<Worker>>,
24    blocks: Arc<RwLock<Blocks<K>>>,
25    intervals: Arc<RwLock<Intervals<K>>>,
26}
27
28impl<P: Priority, T: TaskResult> Limiter<String, P, T> {
29    pub fn new<K: Key>(concurrent_tasks: usize) -> Limiter<K, P, T> {
30        Limiter::new_with(concurrent_tasks, Default::default(), Default::default())
31    }
32
33    pub fn new_with<K: Key>(
34        concurrent_tasks: usize,
35        blocks: Blocks<K>,
36        intervals: Intervals<K>,
37    ) -> Limiter<K, P, T> {
38        let tasks: Arc<Mutex<BinaryHeap<Task<K, P, T>>>> = Default::default();
39        let blocks: Arc<RwLock<Blocks<K>>> = Arc::new(RwLock::new(blocks));
40        let intervals: Arc<RwLock<Intervals<K>>> = Arc::new(RwLock::new(intervals));
41        let ingress = Ingress::spawn(tasks.clone());
42        let workers = Mutex::new(
43            (0..concurrent_tasks)
44                .map(|_| ingress.spawn_worker(tasks.clone(), blocks.clone(), intervals.clone()))
45                .collect(),
46        );
47
48        Limiter {
49            tasks,
50            blocks,
51            intervals,
52            ingress,
53            workers,
54        }
55    }
56}
57
58impl<K: Key, P: Priority, T: TaskResult> Limiter<K, P, T> {
59    pub async fn get_default_block_duration(&self) -> Option<Duration> {
60        self.blocks.write().await.get_default()
61    }
62
63    pub async fn get_block_duration_by_key(&self, key: &K) -> Option<Duration> {
64        self.blocks.write().await.get_by_key(key)
65    }
66
67    pub async fn set_default_block_until_at_least(&self, instant: Instant) {
68        self.blocks.write().await.set_default_at_least(instant);
69    }
70
71    pub async fn set_block_by_key_until_at_least(&self, instant: Instant, key: K) {
72        self.blocks.write().await.set_at_least_by_key(instant, key);
73    }
74
75    pub async fn set_default_block_until(&self, instant: Option<Instant>) {
76        self.blocks.write().await.set_default(instant);
77    }
78
79    pub async fn set_block_by_key_until(&self, instant: Option<Instant>, key: K) {
80        self.blocks.write().await.set_by_key(instant, key);
81    }
82
83    pub async fn set_default_interval_at_least(&self, interval: Duration) {
84        self.intervals.write().await.set_default_at_least(interval);
85    }
86
87    pub async fn set_interval_by_key_at_least(&self, interval: Duration, key: K) {
88        self.intervals
89            .write()
90            .await
91            .set_at_least_by_key(interval, key);
92    }
93
94    pub async fn set_default_interval(&self, interval: Option<Duration>) {
95        self.intervals.write().await.set_default(interval);
96    }
97
98    pub async fn set_interval_by_key(&self, interval: Option<Duration>, key: K) {
99        self.intervals.write().await.set_by_key(interval, key);
100    }
101
102    pub async fn set_concurrent_tasks(&self, concurrent_tasks: usize) {
103        let mut guard = self.workers.lock().await;
104        let len = guard.len();
105
106        match len.cmp(&concurrent_tasks) {
107            std::cmp::Ordering::Less => {
108                for _ in len..concurrent_tasks {
109                    guard.push(self.ingress.spawn_worker(
110                        self.tasks.clone(),
111                        self.blocks.clone(),
112                        self.intervals.clone(),
113                    ));
114                }
115            }
116            std::cmp::Ordering::Equal => {}
117            std::cmp::Ordering::Greater => {
118                guard.drain(concurrent_tasks..);
119            }
120        }
121    }
122
123    pub async fn queue<J: Future<Output = T> + Send + 'static>(
124        &self,
125        job: J,
126        priority: P,
127    ) -> BoxFuture<T> {
128        let (reply_sender, reply_receiver) = oneshot::channel();
129
130        self.ingress
131            .send(Task::new(job, priority, reply_sender))
132            .await;
133
134        Box::pin(async move { reply_receiver.await.expect("reply_sender should not drop") })
135    }
136
137    pub async fn queue_by_key<J: Future<Output = T> + Send + 'static>(
138        &self,
139        job: J,
140        priority: P,
141        key: K,
142    ) -> BoxFuture<T> {
143        let (reply_sender, reply_receiver) = oneshot::channel();
144
145        self.ingress
146            .send(Task::new_with_key(job, priority, reply_sender, key))
147            .await;
148
149        Box::pin(async move { reply_receiver.await.expect("reply_sender should not drop") })
150    }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use crate::limiter::builder::LimiterBuilder;
    use futures::future::join_all;
    use std::sync::Arc;
    use tokio::sync::Mutex;

    /// Queues jobs while no workers exist, then starts two workers.
    ///
    /// Two properties are checked: each returned future yields its own job's output, and jobs
    /// execute in priority order (all High, then all Mid, then all Low).
    #[tokio::test]
    async fn it_should_work() {
        #[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord)]
        enum Prio {
            Low,
            Mid,
            High,
        }
        use Prio::*;

        let limiter = LimiterBuilder::new::<String>(0).build();
        let executed: Arc<Mutex<Vec<Prio>>> = Arc::default();

        // Submission order. The integer is each job's output and identifies it in `order`.
        let submissions = [
            (High, 1),
            (Mid, 2),
            (Low, 3),
            (Low, 4),
            (Mid, 5),
            (High, 6),
            (Mid, 7),
            (Low, 8),
            (High, 9),
        ];

        let mut pending = Vec::with_capacity(submissions.len());
        for (prio, output) in submissions {
            let log = Arc::clone(&executed);
            let job = async move {
                log.lock().await.push(prio);
                output
            };
            pending.push(limiter.queue(job, prio).await);
        }

        limiter
            .set_default_interval(Some(Duration::from_millis(100)))
            .await;
        limiter.set_concurrent_tasks(2).await;

        let order = join_all(pending).await;
        let executed = executed.lock().await.clone();

        assert_eq!(order, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
        assert_eq!(executed, [High, High, High, Mid, Mid, Mid, Low, Low, Low]);
    }
}
303}