1pub mod builder;
2
3use crate::{
4 BoxFuture,
5 blocks::Blocks,
6 ingress::Ingress,
7 intervals::Intervals,
8 task::Task,
9 traits::{Key, Priority, TaskResult},
10 worker::Worker,
11};
12
13use std::{collections::BinaryHeap, sync::Arc, time::Duration};
14use tokio::{
15 sync::{Mutex, RwLock, oneshot},
16 time::Instant,
17};
18
19#[derive(Debug)]
20pub struct Limiter<K: Key, P: Priority, T: TaskResult> {
21 tasks: Arc<Mutex<BinaryHeap<Task<K, P, T>>>>,
22 ingress: Ingress<K, P, T>,
23 workers: Mutex<Vec<Worker>>,
24 blocks: Arc<RwLock<Blocks<K>>>,
25 intervals: Arc<RwLock<Intervals<K>>>,
26}
27
28impl<P: Priority, T: TaskResult> Limiter<String, P, T> {
29 pub fn new<K: Key>(concurrent_tasks: usize) -> Limiter<K, P, T> {
30 Limiter::new_with(concurrent_tasks, Default::default(), Default::default())
31 }
32
33 pub fn new_with<K: Key>(
34 concurrent_tasks: usize,
35 blocks: Blocks<K>,
36 intervals: Intervals<K>,
37 ) -> Limiter<K, P, T> {
38 let tasks: Arc<Mutex<BinaryHeap<Task<K, P, T>>>> = Default::default();
39 let blocks: Arc<RwLock<Blocks<K>>> = Arc::new(RwLock::new(blocks));
40 let intervals: Arc<RwLock<Intervals<K>>> = Arc::new(RwLock::new(intervals));
41 let ingress = Ingress::spawn(tasks.clone());
42 let workers = Mutex::new(
43 (0..concurrent_tasks)
44 .map(|_| ingress.spawn_worker(tasks.clone(), blocks.clone(), intervals.clone()))
45 .collect(),
46 );
47
48 Limiter {
49 tasks,
50 blocks,
51 intervals,
52 ingress,
53 workers,
54 }
55 }
56}
57
58impl<K: Key, P: Priority, T: TaskResult> Limiter<K, P, T> {
59 pub async fn get_default_block_duration(&self) -> Option<Duration> {
60 self.blocks.write().await.get_default()
61 }
62
63 pub async fn get_block_duration_by_key(&self, key: &K) -> Option<Duration> {
64 self.blocks.write().await.get_by_key(key)
65 }
66
67 pub async fn set_default_block_until_at_least(&self, instant: Instant) {
68 self.blocks.write().await.set_default_at_least(instant);
69 }
70
71 pub async fn set_block_by_key_until_at_least(&self, instant: Instant, key: K) {
72 self.blocks.write().await.set_at_least_by_key(instant, key);
73 }
74
75 pub async fn set_default_block_until(&self, instant: Option<Instant>) {
76 self.blocks.write().await.set_default(instant);
77 }
78
79 pub async fn set_block_by_key_until(&self, instant: Option<Instant>, key: K) {
80 self.blocks.write().await.set_by_key(instant, key);
81 }
82
83 pub async fn set_default_interval_at_least(&self, interval: Duration) {
84 self.intervals.write().await.set_default_at_least(interval);
85 }
86
87 pub async fn set_interval_by_key_at_least(&self, interval: Duration, key: K) {
88 self.intervals
89 .write()
90 .await
91 .set_at_least_by_key(interval, key);
92 }
93
94 pub async fn set_default_interval(&self, interval: Option<Duration>) {
95 self.intervals.write().await.set_default(interval);
96 }
97
98 pub async fn set_interval_by_key(&self, interval: Option<Duration>, key: K) {
99 self.intervals.write().await.set_by_key(interval, key);
100 }
101
102 pub async fn set_concurrent_tasks(&self, concurrent_tasks: usize) {
103 let mut guard = self.workers.lock().await;
104 let len = guard.len();
105
106 match len.cmp(&concurrent_tasks) {
107 std::cmp::Ordering::Less => {
108 for _ in len..concurrent_tasks {
109 guard.push(self.ingress.spawn_worker(
110 self.tasks.clone(),
111 self.blocks.clone(),
112 self.intervals.clone(),
113 ));
114 }
115 }
116 std::cmp::Ordering::Equal => {}
117 std::cmp::Ordering::Greater => {
118 guard.drain(concurrent_tasks..);
119 }
120 }
121 }
122
123 pub async fn queue<J: Future<Output = T> + Send + 'static>(
124 &self,
125 job: J,
126 priority: P,
127 ) -> BoxFuture<T> {
128 let (reply_sender, reply_receiver) = oneshot::channel();
129
130 self.ingress
131 .send(Task::new(job, priority, reply_sender))
132 .await;
133
134 Box::pin(async move { reply_receiver.await.expect("reply_sender should not drop") })
135 }
136
137 pub async fn queue_by_key<J: Future<Output = T> + Send + 'static>(
138 &self,
139 job: J,
140 priority: P,
141 key: K,
142 ) -> BoxFuture<T> {
143 let (reply_sender, reply_receiver) = oneshot::channel();
144
145 self.ingress
146 .send(Task::new_with_key(job, priority, reply_sender, key))
147 .await;
148
149 Box::pin(async move { reply_receiver.await.expect("reply_sender should not drop") })
150 }
151}
152
#[cfg(test)]
mod tests {
    use crate::limiter::builder::LimiterBuilder;

    use super::*;
    use futures::future::join_all;
    use std::sync::Arc;
    use tokio::sync::Mutex;

    /// Queues nine jobs on a limiter built with `0` (presumably zero
    /// concurrent tasks), then sets a 100 ms default interval and two
    /// concurrent tasks, and checks both the values returned per submission
    /// and the order in which the jobs actually ran.
    #[tokio::test]
    async fn it_should_work() {
        use Prio::*;

        #[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord)]
        enum Prio {
            Low,
            Mid,
            High,
        }

        let limiter = LimiterBuilder::new::<String>(0).build();

        // Records the priority of each job at the moment it runs.
        let executed: Arc<Mutex<Vec<Prio>>> = Default::default();

        // (priority, value the job returns), listed in submission order.
        let submissions = [
            (High, 1),
            (Mid, 2),
            (Low, 3),
            (Low, 4),
            (Mid, 5),
            (High, 6),
            (Mid, 7),
            (Low, 8),
            (High, 9),
        ];

        let mut pending = Vec::with_capacity(submissions.len());
        for (priority, value) in submissions {
            let log = executed.clone();
            let job = async move {
                log.lock().await.push(priority);
                value
            };
            // Each submission is awaited before the next one is made.
            pending.push(limiter.queue(job, priority).await);
        }

        limiter
            .set_default_interval(Some(Duration::from_millis(100)))
            .await;

        limiter.set_concurrent_tasks(2).await;

        let order = join_all(pending).await;
        let executed = executed.lock().await.clone();

        // Results come back in submission order...
        assert_eq!(order, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);

        // ...while execution follows priority, highest first.
        assert_eq!(executed, [High, High, High, Mid, Mid, Mid, Low, Low, Low]);
    }
}