1pub mod builder;
2
3use crate::{
4 BoxFuture,
5 blocks::Blocks,
6 ingress::Ingress,
7 intervals::Intervals,
8 task::Task,
9 traits::{Key, Priority, TaskResult},
10 worker::Worker,
11};
12
13use std::{collections::BinaryHeap, sync::Arc, time::Duration};
14use tokio::{
15 sync::{Mutex, RwLock, oneshot},
16 time::Instant,
17};
18
19#[derive(Debug)]
20pub struct Limiter<K: Key, P: Priority, T: TaskResult> {
21 tasks: Arc<Mutex<BinaryHeap<Task<K, P, T>>>>,
22 ingress: Ingress<K, P, T>,
23 workers: Mutex<Vec<Worker>>,
24 blocks: Arc<RwLock<Blocks<K>>>,
25 intervals: Arc<RwLock<Intervals<K>>>,
26}
27
28impl<K: Key, P: Priority, T: TaskResult> AsRef<Limiter<K, P, T>> for Limiter<K, P, T> {
29 fn as_ref(&self) -> &Limiter<K, P, T> {
30 self
31 }
32}
33
34impl<P: Priority, T: TaskResult> Limiter<String, P, T> {
35 pub fn new<K: Key>(concurrent_tasks: usize) -> Limiter<K, P, T> {
36 Limiter::new_with(concurrent_tasks, Default::default(), Default::default())
37 }
38
39 pub fn new_with<K: Key>(
40 concurrent_tasks: usize,
41 blocks: Blocks<K>,
42 intervals: Intervals<K>,
43 ) -> Limiter<K, P, T> {
44 let tasks: Arc<Mutex<BinaryHeap<Task<K, P, T>>>> = Default::default();
45 let blocks: Arc<RwLock<Blocks<K>>> = Arc::new(RwLock::new(blocks));
46 let intervals: Arc<RwLock<Intervals<K>>> = Arc::new(RwLock::new(intervals));
47 let ingress = Ingress::spawn(tasks.clone());
48 let workers = Mutex::new(
49 (0..concurrent_tasks)
50 .map(|_| ingress.spawn_worker(tasks.clone(), blocks.clone(), intervals.clone()))
51 .collect(),
52 );
53
54 Limiter {
55 tasks,
56 blocks,
57 intervals,
58 ingress,
59 workers,
60 }
61 }
62}
63
64impl<K: Key, P: Priority, T: TaskResult> Limiter<K, P, T> {
65 pub async fn get_default_block_duration(&self) -> Option<Duration> {
66 self.blocks.read().await.get_default()
67 }
68
69 pub async fn get_block_duration_by_key(&self, key: &K) -> Option<Duration> {
70 self.blocks.read().await.get_by_key(key)
71 }
72
73 pub async fn set_default_block_until_at_least(&self, instant: Instant) {
74 self.blocks.write().await.set_default_at_least(instant);
75 }
76
77 pub async fn set_block_by_key_until_at_least(&self, instant: Instant, key: K) {
78 self.blocks.write().await.set_at_least_by_key(instant, key);
79 }
80
81 pub async fn set_default_block_until(&self, instant: Option<Instant>) {
82 self.blocks.write().await.set_default(instant);
83 }
84
85 pub async fn set_block_by_key_until(&self, instant: Option<Instant>, key: K) {
86 self.blocks.write().await.set_by_key(instant, key);
87 }
88
89 pub async fn set_default_interval_at_least(&self, interval: Duration) {
90 self.intervals.write().await.set_default_at_least(interval);
91 }
92
93 pub async fn set_interval_by_key_at_least(&self, interval: Duration, key: K) {
94 self.intervals
95 .write()
96 .await
97 .set_at_least_by_key(interval, key);
98 }
99
100 pub async fn set_default_interval(&self, interval: Option<Duration>) {
101 self.intervals.write().await.set_default(interval);
102 }
103
104 pub async fn set_interval_by_key(&self, interval: Option<Duration>, key: K) {
105 self.intervals.write().await.set_by_key(interval, key);
106 }
107
108 pub async fn set_concurrent_tasks(&self, concurrent_tasks: usize) {
109 let mut guard = self.workers.lock().await;
110 let len = guard.len();
111
112 match len.cmp(&concurrent_tasks) {
113 std::cmp::Ordering::Less => {
114 for _ in len..concurrent_tasks {
115 guard.push(self.ingress.spawn_worker(
116 self.tasks.clone(),
117 self.blocks.clone(),
118 self.intervals.clone(),
119 ));
120 }
121 }
122 std::cmp::Ordering::Equal => {}
123 std::cmp::Ordering::Greater => {
124 guard.drain(concurrent_tasks..);
125 }
126 }
127 }
128
129 pub async fn queue<J: Future<Output = T> + Send + 'static>(
130 &self,
131 job: J,
132 priority: P,
133 ) -> BoxFuture<T> {
134 let (reply_sender, reply_receiver) = oneshot::channel();
135
136 self.ingress
137 .send(Task::new(job, priority, reply_sender))
138 .await;
139
140 Box::pin(async move { reply_receiver.await.expect("reply_sender should not drop") })
141 }
142
143 pub async fn queue_by_key<J: Future<Output = T> + Send + 'static>(
144 &self,
145 job: J,
146 priority: P,
147 key: K,
148 ) -> BoxFuture<T> {
149 let (reply_sender, reply_receiver) = oneshot::channel();
150
151 self.ingress
152 .send(Task::new_with_key(job, priority, reply_sender, key))
153 .await;
154
155 Box::pin(async move { reply_receiver.await.expect("reply_sender should not drop") })
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use crate::limiter::builder::LimiterBuilder;
    use futures::future::join_all;
    use std::sync::Arc;
    use tokio::sync::Mutex;

    /// Jobs are queued while there are no workers. The test then sets a 100 ms default
    /// interval and enables two workers.
    ///
    /// It checks two things:
    /// - every future receives its own job's value;
    /// - execution order follows priority (all `High`, then `Mid`, then `Low`).
    #[tokio::test]
    async fn it_should_work() {
        #[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord)]
        enum Prio {
            Low,
            Mid,
            High,
        }
        use Prio::*;

        let limiter = LimiterBuilder::new::<String>(0).build();
        let acc: Arc<Mutex<Vec<Prio>>> = Default::default();

        // (priority, value the job returns), in submission order.
        let submissions = [
            (High, 1),
            (Mid, 2),
            (Low, 3),
            (Low, 4),
            (Mid, 5),
            (High, 6),
            (Mid, 7),
            (Low, 8),
            (High, 9),
        ];

        let mut futures = Vec::with_capacity(submissions.len());
        for (prio, value) in submissions {
            let results = Arc::clone(&acc);
            let job = async move {
                results.lock().await.push(prio);
                value
            };
            futures.push(limiter.queue(job, prio).await);
        }

        limiter
            .set_default_interval(Some(Duration::from_millis(100)))
            .await;
        limiter.set_concurrent_tasks(2).await;

        let order = join_all(futures).await;
        let executed = acc.lock().await.clone();

        assert_eq!(order, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
        assert_eq!(executed, [High, High, High, Mid, Mid, Mid, Low, Low, Low]);
    }

    /// A panicking job never replies, so awaiting its result re-panics with the
    /// reply-channel message.
    #[tokio::test]
    #[should_panic(expected = "reply_sender should not drop")]
    async fn panic_in_task_causes_cascading_panic() {
        let limiter = Limiter::new::<String>(1);

        let reply = limiter
            .queue(
                async {
                    panic!("Intentional panic in task");
                },
                1,
            )
            .await;

        reply.await;
    }

    /// With zero workers a job is still accepted, and it completes once a worker is added.
    #[tokio::test]
    async fn zero_concurrent_tasks_still_queues() {
        let limiter = Limiter::new::<String>(0);
        let started = Instant::now();

        let pending = limiter.queue(async { 42 }, 1).await;
        tokio::time::sleep(Duration::from_millis(10)).await;
        limiter.set_concurrent_tasks(1).await;

        assert_eq!(pending.await, 42);
        assert!(started.elapsed() >= Duration::from_millis(10));
    }

    /// Growing the pool from 1 to 5 workers mid-run lets the remaining jobs finish well
    /// before a single worker could have run all of them.
    #[tokio::test]
    async fn dynamic_worker_scaling() {
        let limiter = Limiter::new::<String>(1);
        let counter = Arc::new(Mutex::new(0));

        let submissions: Vec<_> = (0..10)
            .map(|_| {
                let counter = Arc::clone(&counter);
                let job = async move {
                    *counter.lock().await += 1;
                    tokio::time::sleep(Duration::from_millis(50)).await;
                };
                limiter.queue(job, 1)
            })
            .collect();
        let replies = join_all(submissions).await;

        let start = Instant::now();
        tokio::time::sleep(Duration::from_millis(100)).await;
        limiter.set_concurrent_tasks(5).await;

        join_all(replies).await;
        let elapsed = start.elapsed();

        assert_eq!(*counter.lock().await, 10);
        assert!(elapsed < Duration::from_millis(500));
    }

    /// Concurrent reads of the block settings all succeed promptly.
    #[tokio::test]
    async fn read_locks_dont_block_reads() {
        let limiter: Limiter<String, i32, ()> = Limiter::new(1);
        let deadline = Instant::now() + Duration::from_secs(10);
        limiter.set_default_block_until(Some(deadline)).await;

        let start = Instant::now();
        let (r1, r2, r3) = tokio::join!(
            limiter.get_default_block_duration(),
            limiter.get_default_block_duration(),
            limiter.get_default_block_duration(),
        );
        let elapsed = start.elapsed();

        assert!([r1, r2, r3].iter().all(Option::is_some));
        assert!(elapsed < Duration::from_millis(10));
    }
}