1use crate::{
2 BoxFuture,
3 auto_traits::{Key, Priority, TaskResult},
4 blocks::Blocks,
5 ingress::Ingress,
6 intervals::Intervals,
7 task::Task,
8 worker::Worker,
9};
10
11use std::{collections::BinaryHeap, sync::Arc, time::Duration};
12use tokio::{
13 sync::{Mutex, RwLock, oneshot},
14 time::Instant,
15};
16
/// Queues prioritised jobs and runs them on a resizable pool of workers.
///
/// The task heap, blocks and intervals are shared (`Arc`) between the limiter and
/// every worker it spawns, so changes made through the limiter's setters are seen
/// by all workers.
///
/// NOTE(review): fields are dropped in declaration order; confirm whether the
/// relative drop order of `ingress` and `workers` matters before reordering them.
pub struct Limiter<K: Key, P: Priority, T: TaskResult> {
    /// Task heap shared with the ingress and every worker (presumably the pending
    /// tasks, ordered by `Task`'s `Ord` impl — confirm).
    tasks: Arc<Mutex<BinaryHeap<Task<K, P, T>>>>,
    /// Channel into which `queue`/`queue_by_key` send new tasks; also used to
    /// spawn workers.
    ingress: Ingress<K, P, T>,
    /// Spawned workers; grown or shrunk by `set_concurrent_tasks`.
    workers: Mutex<Vec<Worker>>,
    /// Default and per-key blocks, shared with every worker.
    blocks: Arc<RwLock<Blocks<K>>>,
    /// Default and per-key intervals, shared with every worker.
    intervals: Arc<RwLock<Intervals<K>>>,
}
24
25impl<P: Priority, T: TaskResult> Default for Limiter<String, P, T> {
26 fn default() -> Self {
27 Self::new(1)
28 }
29}
30
31impl<P: Priority, T: TaskResult> Limiter<String, P, T> {
32 pub fn new<K: Key>(concurrent: usize) -> Limiter<K, P, T> {
33 let tasks: Arc<Mutex<BinaryHeap<Task<K, P, T>>>> = Default::default();
34 let blocks: Arc<RwLock<Blocks<K>>> = Default::default();
35 let intervals: Arc<RwLock<Intervals<K>>> = Default::default();
36 let ingress = Ingress::spawn(tasks.clone());
37 let workers = Mutex::new(
38 (0..concurrent)
39 .map(|_| ingress.spawn_worker(tasks.clone(), blocks.clone(), intervals.clone()))
40 .collect(),
41 );
42
43 Limiter {
44 tasks,
45 blocks,
46 intervals,
47 ingress,
48 workers,
49 }
50 }
51}
52
53impl<K: Key, P: Priority, T: TaskResult> Limiter<K, P, T> {
54 pub async fn get_default_block_duration(&self) -> Option<Duration> {
55 self.blocks.write().await.get_default()
56 }
57
58 pub async fn get_block_duration_by_key(&self, key: &K) -> Option<Duration> {
59 self.blocks.write().await.get_by_key(key)
60 }
61
62 pub async fn set_default_block_until_at_least(&self, instant: Instant) {
63 self.blocks.write().await.set_default_at_least(instant);
64 }
65
66 pub async fn set_block_by_key_until_at_least(&self, instant: Instant, key: K) {
67 self.blocks.write().await.set_at_least_by_key(instant, key);
68 }
69
70 pub async fn set_default_block_until(&self, instant: Option<Instant>) {
71 self.blocks.write().await.set_default(instant);
72 }
73
74 pub async fn set_block_by_key_until(&self, instant: Option<Instant>, key: K) {
75 self.blocks.write().await.set_by_key(instant, key);
76 }
77
78 pub async fn set_default_interval_at_least(&self, interval: Duration) {
79 self.intervals.write().await.set_default_at_least(interval);
80 }
81
82 pub async fn set_interval_by_key_at_least(&self, interval: Duration, key: K) {
83 self.intervals
84 .write()
85 .await
86 .set_at_least_by_key(interval, key);
87 }
88
89 pub async fn set_default_interval(&self, interval: Option<Duration>) {
90 self.intervals.write().await.set_default(interval);
91 }
92
93 pub async fn set_interval_by_key(&self, interval: Option<Duration>, key: K) {
94 self.intervals.write().await.set_by_key(interval, key);
95 }
96
97 pub async fn set_concurrent_tasks(&self, concurrent_tasks: usize) {
98 let mut guard = self.workers.lock().await;
99 let len = guard.len();
100
101 match len.cmp(&concurrent_tasks) {
102 std::cmp::Ordering::Less => {
103 for _ in len..concurrent_tasks {
104 guard.push(self.ingress.spawn_worker(
105 self.tasks.clone(),
106 self.blocks.clone(),
107 self.intervals.clone(),
108 ));
109 }
110 }
111 std::cmp::Ordering::Equal => {}
112 std::cmp::Ordering::Greater => {
113 guard.drain(concurrent_tasks..);
114 }
115 }
116 }
117
118 pub async fn queue<J: Future<Output = T> + Send + 'static>(
119 &self,
120 job: J,
121 priority: P,
122 ) -> BoxFuture<T> {
123 let (reply_sender, reply_receiver) = oneshot::channel();
124
125 self.ingress
126 .send(Task::new(job, priority, reply_sender))
127 .await;
128
129 Box::pin(async move { reply_receiver.await.expect("reply_sender should not drop") })
130 }
131
132 pub async fn queue_by_key<J: Future<Output = T> + Send + 'static>(
133 &self,
134 job: J,
135 priority: P,
136 key: K,
137 ) -> BoxFuture<T> {
138 let (reply_sender, reply_receiver) = oneshot::channel();
139
140 self.ingress
141 .send(Task::new_with_key(job, priority, reply_sender, key))
142 .await;
143
144 Box::pin(async move { reply_receiver.await.expect("reply_sender should not drop") })
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use futures::future::join_all;
    use std::sync::Arc;
    use tokio::sync::Mutex;

    /// Queues a mix of prioritised jobs on a limiter with no workers, then enables
    /// two workers.
    ///
    /// Each caller must get back its own job's value, in queue order. The jobs
    /// themselves must have run highest-priority first.
    #[tokio::test]
    async fn it_should_work() {
        #[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord)]
        enum Prio {
            Low,
            Mid,
            High,
        }

        use Prio::*;

        // (value returned by the job, priority), in the exact order jobs are queued.
        let plan = [
            (1, High),
            (2, Mid),
            (3, Low),
            (4, Low),
            (5, Mid),
            (6, High),
            (7, Mid),
            (8, Low),
            (9, High),
        ];

        let limiter = Limiter::new::<String>(0);

        // Records the priority of each job at the moment it actually runs.
        let acc: Arc<Mutex<Vec<Prio>>> = Default::default();

        let mut replies = Vec::with_capacity(plan.len());
        for (value, prio) in plan {
            let results = Arc::clone(&acc);
            let job = async move {
                results.lock().await.push(prio);
                value
            };
            replies.push(limiter.queue(job, prio).await);
        }

        limiter
            .set_default_interval(Some(Duration::from_millis(100)))
            .await;

        limiter.set_concurrent_tasks(2).await;

        let order = join_all(replies).await;
        let ran = acc.lock().await.clone();

        assert_eq!(order, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);

        assert_eq!(ran, [High, High, High, Mid, Mid, Mid, Low, Low, Low]);
    }
}