1pub mod builder;
2
3use crate::{
4 BoxFuture,
5 blocks::Blocks,
6 ingress::Ingress,
7 intervals::Intervals,
8 task::Task,
9 traits::{Key, Priority, TaskResult},
10 worker::Worker,
11};
12
13use std::{collections::BinaryHeap, sync::Arc, time::Duration};
14use tokio::{
15 sync::{Mutex, RwLock, oneshot},
16 time::Instant,
17};
18
19#[derive(Debug)]
20pub struct Limiter<K: Key, P: Priority, T: TaskResult> {
21 tasks: Arc<Mutex<BinaryHeap<Task<K, P, T>>>>,
22 ingress: Ingress<K, P, T>,
23 workers: Mutex<Vec<Worker>>,
24 blocks: Arc<RwLock<Blocks<K>>>,
25 intervals: Arc<RwLock<Intervals<K>>>,
26}
27
28impl<K: Key, P: Priority, T: TaskResult> AsRef<Limiter<K, P, T>> for Limiter<K, P, T> {
29 fn as_ref(&self) -> &Limiter<K, P, T> {
30 self
31 }
32}
33
34impl<P: Priority, T: TaskResult> Limiter<String, P, T> {
35 pub fn new<K: Key>(concurrent_tasks: usize) -> Limiter<K, P, T> {
36 Limiter::new_with(concurrent_tasks, Default::default(), Default::default())
37 }
38
39 pub fn new_with<K: Key>(
40 concurrent_tasks: usize,
41 blocks: Blocks<K>,
42 intervals: Intervals<K>,
43 ) -> Limiter<K, P, T> {
44 let tasks: Arc<Mutex<BinaryHeap<Task<K, P, T>>>> = Default::default();
45 let blocks: Arc<RwLock<Blocks<K>>> = Arc::new(RwLock::new(blocks));
46 let intervals: Arc<RwLock<Intervals<K>>> = Arc::new(RwLock::new(intervals));
47 let ingress = Ingress::spawn(tasks.clone());
48 let workers = Mutex::new(
49 (0..concurrent_tasks)
50 .map(|_| ingress.spawn_worker(tasks.clone(), blocks.clone(), intervals.clone()))
51 .collect(),
52 );
53
54 Limiter {
55 tasks,
56 blocks,
57 intervals,
58 ingress,
59 workers,
60 }
61 }
62}
63
64impl<K: Key, P: Priority, T: TaskResult> Limiter<K, P, T> {
65 pub async fn get_default_block_duration(&self) -> Option<Duration> {
66 self.blocks.write().await.get_default()
67 }
68
69 pub async fn get_block_duration_by_key(&self, key: &K) -> Option<Duration> {
70 self.blocks.write().await.get_by_key(key)
71 }
72
73 pub async fn set_default_block_until_at_least(&self, instant: Instant) {
74 self.blocks.write().await.set_default_at_least(instant);
75 }
76
77 pub async fn set_block_by_key_until_at_least(&self, instant: Instant, key: K) {
78 self.blocks.write().await.set_at_least_by_key(instant, key);
79 }
80
81 pub async fn set_default_block_until(&self, instant: Option<Instant>) {
82 self.blocks.write().await.set_default(instant);
83 }
84
85 pub async fn set_block_by_key_until(&self, instant: Option<Instant>, key: K) {
86 self.blocks.write().await.set_by_key(instant, key);
87 }
88
89 pub async fn set_default_interval_at_least(&self, interval: Duration) {
90 self.intervals.write().await.set_default_at_least(interval);
91 }
92
93 pub async fn set_interval_by_key_at_least(&self, interval: Duration, key: K) {
94 self.intervals
95 .write()
96 .await
97 .set_at_least_by_key(interval, key);
98 }
99
100 pub async fn set_default_interval(&self, interval: Option<Duration>) {
101 self.intervals.write().await.set_default(interval);
102 }
103
104 pub async fn set_interval_by_key(&self, interval: Option<Duration>, key: K) {
105 self.intervals.write().await.set_by_key(interval, key);
106 }
107
108 pub async fn set_concurrent_tasks(&self, concurrent_tasks: usize) {
109 let mut guard = self.workers.lock().await;
110 let len = guard.len();
111
112 match len.cmp(&concurrent_tasks) {
113 std::cmp::Ordering::Less => {
114 for _ in len..concurrent_tasks {
115 guard.push(self.ingress.spawn_worker(
116 self.tasks.clone(),
117 self.blocks.clone(),
118 self.intervals.clone(),
119 ));
120 }
121 }
122 std::cmp::Ordering::Equal => {}
123 std::cmp::Ordering::Greater => {
124 guard.drain(concurrent_tasks..);
125 }
126 }
127 }
128
129 pub async fn queue<J: Future<Output = T> + Send + 'static>(
130 &self,
131 job: J,
132 priority: P,
133 ) -> BoxFuture<T> {
134 let (reply_sender, reply_receiver) = oneshot::channel();
135
136 self.ingress
137 .send(Task::new(job, priority, reply_sender))
138 .await;
139
140 Box::pin(async move { reply_receiver.await.expect("reply_sender should not drop") })
141 }
142
143 pub async fn queue_by_key<J: Future<Output = T> + Send + 'static>(
144 &self,
145 job: J,
146 priority: P,
147 key: K,
148 ) -> BoxFuture<T> {
149 let (reply_sender, reply_receiver) = oneshot::channel();
150
151 self.ingress
152 .send(Task::new_with_key(job, priority, reply_sender, key))
153 .await;
154
155 Box::pin(async move { reply_receiver.await.expect("reply_sender should not drop") })
156 }
157}
158
#[cfg(test)]
mod tests {
    use crate::limiter::builder::LimiterBuilder;

    use super::*;
    use futures::future::join_all;
    use std::sync::Arc;
    use tokio::sync::Mutex;

    /// Queues nine jobs while the limiter has no workers, then configures a 100 ms
    /// default interval and two workers. Results must come back in submission
    /// order, while the execution log must be grouped High, then Mid, then Low.
    #[tokio::test]
    async fn it_should_work() {
        use Prio::*;

        #[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord)]
        enum Prio {
            Low,
            Mid,
            High,
        }

        let limiter = LimiterBuilder::new::<String>(0).build();

        let acc: Arc<Mutex<Vec<Prio>>> = Default::default();

        // (priority, value the job returns), in submission order.
        let jobs = [
            (High, 1),
            (Mid, 2),
            (Low, 3),
            (Low, 4),
            (Mid, 5),
            (High, 6),
            (Mid, 7),
            (Low, 8),
            (High, 9),
        ];

        let mut futures = Vec::with_capacity(jobs.len());
        for (prio, value) in jobs {
            let log = Arc::clone(&acc);
            let job = async move {
                log.lock().await.push(prio);
                value
            };
            futures.push(limiter.queue(job, prio).await);
        }

        limiter
            .set_default_interval(Some(Duration::from_millis(100)))
            .await;

        limiter.set_concurrent_tasks(2).await;

        let order = join_all(futures).await;
        let acc = acc.lock().await.clone();

        assert_eq!(order, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);

        assert_eq!(acc, [High, High, High, Mid, Mid, Mid, Low, Low, Low]);
    }
}