Skip to main content

async_priority_limiter/
limiter.rs

1pub mod builder;
2
3use crate::{
4    BoxFuture,
5    blocks::Blocks,
6    ingress::Ingress,
7    intervals::Intervals,
8    task::Task,
9    traits::{Key, Priority, TaskResult},
10    worker::Worker,
11};
12
13use std::{collections::BinaryHeap, sync::Arc, time::Duration};
14use tokio::{
15    sync::{Mutex, RwLock, oneshot},
16    time::Instant,
17};
18
19#[derive(Debug)]
20pub struct Limiter<K: Key, P: Priority, T: TaskResult> {
21    tasks: Arc<Mutex<BinaryHeap<Task<K, P, T>>>>,
22    ingress: Ingress<K, P, T>,
23    workers: Mutex<Vec<Worker>>,
24    blocks: Arc<RwLock<Blocks<K>>>,
25    intervals: Arc<RwLock<Intervals<K>>>,
26}
27
28impl<K: Key, P: Priority, T: TaskResult> AsRef<Limiter<K, P, T>> for Limiter<K, P, T> {
29    fn as_ref(&self) -> &Limiter<K, P, T> {
30        self
31    }
32}
33
34impl<P: Priority, T: TaskResult> Limiter<String, P, T> {
35    pub fn new<K: Key>(concurrent_tasks: usize) -> Limiter<K, P, T> {
36        Limiter::new_with(concurrent_tasks, Default::default(), Default::default())
37    }
38
39    pub fn new_with<K: Key>(
40        concurrent_tasks: usize,
41        blocks: Blocks<K>,
42        intervals: Intervals<K>,
43    ) -> Limiter<K, P, T> {
44        let tasks: Arc<Mutex<BinaryHeap<Task<K, P, T>>>> = Default::default();
45        let blocks: Arc<RwLock<Blocks<K>>> = Arc::new(RwLock::new(blocks));
46        let intervals: Arc<RwLock<Intervals<K>>> = Arc::new(RwLock::new(intervals));
47        let ingress = Ingress::spawn(tasks.clone());
48        let workers = Mutex::new(
49            (0..concurrent_tasks)
50                .map(|_| ingress.spawn_worker(tasks.clone(), blocks.clone(), intervals.clone()))
51                .collect(),
52        );
53
54        Limiter {
55            tasks,
56            blocks,
57            intervals,
58            ingress,
59            workers,
60        }
61    }
62}
63
64impl<K: Key, P: Priority, T: TaskResult> Limiter<K, P, T> {
65    pub async fn get_default_block_duration(&self) -> Option<Duration> {
66        self.blocks.read().await.get_default()
67    }
68
69    pub async fn get_block_duration_by_key(&self, key: &K) -> Option<Duration> {
70        self.blocks.read().await.get_by_key(key)
71    }
72
73    pub async fn set_default_block_until_at_least(&self, instant: Instant) {
74        self.blocks.write().await.set_default_at_least(instant);
75    }
76
77    pub async fn set_block_by_key_until_at_least(&self, instant: Instant, key: K) {
78        self.blocks.write().await.set_at_least_by_key(instant, key);
79    }
80
81    pub async fn set_default_block_until(&self, instant: Option<Instant>) {
82        self.blocks.write().await.set_default(instant);
83    }
84
85    pub async fn set_block_by_key_until(&self, instant: Option<Instant>, key: K) {
86        self.blocks.write().await.set_by_key(instant, key);
87    }
88
89    pub async fn set_default_interval_at_least(&self, interval: Duration) {
90        self.intervals.write().await.set_default_at_least(interval);
91    }
92
93    pub async fn set_interval_by_key_at_least(&self, interval: Duration, key: K) {
94        self.intervals
95            .write()
96            .await
97            .set_at_least_by_key(interval, key);
98    }
99
100    pub async fn set_default_interval(&self, interval: Option<Duration>) {
101        self.intervals.write().await.set_default(interval);
102    }
103
104    pub async fn set_interval_by_key(&self, interval: Option<Duration>, key: K) {
105        self.intervals.write().await.set_by_key(interval, key);
106    }
107
108    pub async fn set_concurrent_tasks(&self, concurrent_tasks: usize) {
109        let mut guard = self.workers.lock().await;
110        let len = guard.len();
111
112        match len.cmp(&concurrent_tasks) {
113            std::cmp::Ordering::Less => {
114                for _ in len..concurrent_tasks {
115                    guard.push(self.ingress.spawn_worker(
116                        self.tasks.clone(),
117                        self.blocks.clone(),
118                        self.intervals.clone(),
119                    ));
120                }
121            }
122            std::cmp::Ordering::Equal => {}
123            std::cmp::Ordering::Greater => {
124                guard.drain(concurrent_tasks..);
125            }
126        }
127    }
128
129    pub async fn queue<J: Future<Output = T> + Send + 'static>(
130        &self,
131        job: J,
132        priority: P,
133    ) -> BoxFuture<T> {
134        let (reply_sender, reply_receiver) = oneshot::channel();
135
136        self.ingress
137            .send(Task::new(job, priority, reply_sender))
138            .await;
139
140        Box::pin(async move { reply_receiver.await.expect("reply_sender should not drop") })
141    }
142
143    pub async fn queue_by_key<J: Future<Output = T> + Send + 'static>(
144        &self,
145        job: J,
146        priority: P,
147        key: K,
148    ) -> BoxFuture<T> {
149        let (reply_sender, reply_receiver) = oneshot::channel();
150
151        self.ingress
152            .send(Task::new_with_key(job, priority, reply_sender, key))
153            .await;
154
155        Box::pin(async move { reply_receiver.await.expect("reply_sender should not drop") })
156    }
157}
158
#[cfg(test)]
mod tests {
    use crate::limiter::builder::LimiterBuilder;

    use super::*;
    use futures::future::join_all;
    use std::sync::Arc;
    use tokio::sync::Mutex;

    /// Jobs queued while no worker exists must run strictly by priority once workers appear.
    /// Each caller must still receive its own job's output.
    #[tokio::test]
    async fn it_should_work() {
        #[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord)]
        enum Prio {
            Low,
            Mid,
            High,
        }

        use Prio::{High, Low, Mid};

        let limiter = LimiterBuilder::new::<String>(0).build();
        let acc: Arc<Mutex<Vec<Prio>>> = Default::default();

        // Each entry is (priority, value the job returns), submitted in exactly this order.
        let submissions = [
            (High, 1),
            (Mid, 2),
            (Low, 3),
            (Low, 4),
            (Mid, 5),
            (High, 6),
            (Mid, 7),
            (Low, 8),
            (High, 9),
        ];

        let mut pending = Vec::with_capacity(submissions.len());
        for (prio, value) in submissions {
            let log = Arc::clone(&acc);
            let job = async move {
                log.lock().await.push(prio);
                value
            };
            // Queue one job at a time so they reach the ingress in the listed order.
            pending.push(limiter.queue(job, prio).await);
        }

        limiter
            .set_default_interval(Some(Duration::from_millis(100)))
            .await;
        limiter.set_concurrent_tasks(2).await;

        let outputs = join_all(pending).await;
        let executed = acc.lock().await.clone();

        assert_eq!(outputs, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
        assert_eq!(executed, [High, High, High, Mid, Mid, Mid, Low, Low, Low]);
    }

    #[tokio::test]
    #[should_panic(expected = "reply_sender should not drop")]
    async fn panic_in_task_causes_cascading_panic() {
        let limiter = Limiter::new::<String>(1);

        let reply = limiter
            .queue(
                async {
                    panic!("Intentional panic in task");
                },
                1,
            )
            .await;

        // The job's panic drops its reply sender, so awaiting the reply panics too.
        reply.await;
    }

    #[tokio::test]
    async fn zero_concurrent_tasks_still_queues() {
        let limiter = Limiter::new::<String>(0);
        let start = Instant::now();

        let reply = limiter.queue(async { 42 }, 1).await;

        // No worker exists yet, so the job can only wait in the queue.
        tokio::time::sleep(Duration::from_millis(10)).await;

        // A single new worker is enough to run it.
        limiter.set_concurrent_tasks(1).await;

        assert_eq!(reply.await, 42);
        assert!(start.elapsed() >= Duration::from_millis(10));
    }

    #[tokio::test]
    async fn dynamic_worker_scaling() {
        let limiter = Limiter::new::<String>(1);
        let counter = Arc::new(Mutex::new(0));

        let submissions: Vec<_> = (0..10)
            .map(|_| {
                let counter = Arc::clone(&counter);
                limiter.queue(
                    async move {
                        *counter.lock().await += 1;
                        tokio::time::sleep(Duration::from_millis(50)).await;
                    },
                    1,
                )
            })
            .collect();
        let replies = join_all(submissions).await;

        // Timing starts once all ten jobs are queued. One worker alone would need ~10 * 50ms.
        let started = Instant::now();

        // Let the lone worker make some progress, then scale the pool up to five.
        tokio::time::sleep(Duration::from_millis(100)).await;
        limiter.set_concurrent_tasks(5).await;

        join_all(replies).await;
        let elapsed = started.elapsed();

        assert_eq!(*counter.lock().await, 10);
        // Scaling up must beat the ~500ms a single worker would need.
        assert!(elapsed < Duration::from_millis(500));
    }

    #[tokio::test]
    async fn read_locks_dont_block_reads() {
        let limiter: Limiter<String, i32, ()> = Limiter::new(1);
        let block_end = Instant::now() + Duration::from_secs(10);
        limiter.set_default_block_until(Some(block_end)).await;

        // Three reads polled together should all finish almost immediately.
        let started = Instant::now();
        let (r1, r2, r3) = tokio::join!(
            limiter.get_default_block_duration(),
            limiter.get_default_block_duration(),
            limiter.get_default_block_duration(),
        );
        let elapsed = started.elapsed();

        assert!(r1.is_some());
        assert!(r2.is_some());
        assert!(r3.is_some());
        assert!(elapsed < Duration::from_millis(10));
    }
}
398        assert!(r3.is_some());
399        // Should complete nearly instantly, not sequentially blocked
400        assert!(elapsed < Duration::from_millis(10));
401    }
402}