twilight_http_ratelimiting/
lib.rs1#![doc = include_str!("../README.md")]
2#![warn(
3 clippy::missing_const_for_fn,
4 clippy::missing_docs_in_private_items,
5 clippy::pedantic,
6 missing_docs,
7 unsafe_code
8)]
9#![allow(clippy::module_name_repetitions, clippy::must_use_candidate)]
10
11mod actor;
12
13use std::{
14 future::Future,
15 hash::{Hash as _, Hasher},
16 pin::Pin,
17 task::{Context, Poll},
18 time::{Duration, Instant},
19};
20use tokio::sync::{mpsc, oneshot};
21
/// Period of the global rate limit (one second).
pub const GLOBAL_LIMIT_PERIOD: Duration = Duration::from_millis(1_000);
25
/// HTTP method of a rate-limited request.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Method {
    /// `DELETE` method.
    Delete,
    /// `GET` method.
    Get,
    /// `PATCH` method.
    Patch,
    /// `POST` method.
    Post,
    /// `PUT` method.
    Put,
}

impl Method {
    /// Returns the method's uppercase HTTP name, e.g. `"GET"`.
    pub const fn name(self) -> &'static str {
        match self {
            Self::Delete => "DELETE",
            Self::Get => "GET",
            Self::Patch => "PATCH",
            Self::Post => "POST",
            Self::Put => "PUT",
        }
    }
}
56
57#[derive(Clone, Debug, Eq, Hash, PartialEq)]
84pub struct Endpoint {
85 pub method: Method,
87 pub path: String,
91}
92
93impl Endpoint {
94 pub(crate) fn is_valid(&self) -> bool {
96 !self.path.as_bytes().starts_with(b"/") && !self.path.as_bytes().contains(&b'?')
97 }
98
99 pub(crate) fn is_interaction(&self) -> bool {
101 self.path.as_bytes().starts_with(b"webhooks")
102 || self.path.as_bytes().starts_with(b"interactions")
103 }
104
105 pub(crate) fn hash_resources(&self, state: &mut impl Hasher) {
115 let mut segments = self.path.as_bytes().split(|&s| s == b'/');
116 match segments.next().unwrap_or_default() {
117 b"channels" => {
118 if let Some(s) = segments.next() {
119 "channels".hash(state);
120 s.hash(state);
121 }
122 }
123 b"guilds" => {
124 if let Some(s) = segments.next() {
125 "guilds".hash(state);
126 s.hash(state);
127 }
128 }
129 b"webhooks" => {
130 if let Some(s) = segments.next() {
131 "webhooks".hash(state);
132 s.hash(state);
133 }
134 if let Some(s) = segments.next() {
135 s.hash(state);
136 }
137 }
138 _ => {}
139 }
140 }
141}
142
/// Rate limit information for a single response.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct RateLimitHeaders {
    /// Bucket identifier bytes (presumably the `x-ratelimit-bucket` value).
    pub bucket: Vec<u8>,
    /// Total number of requests allowed in the bucket.
    pub limit: u16,
    /// Number of requests left in the bucket.
    pub remaining: u16,
    /// Instant at which the bucket resets.
    pub reset_at: Instant,
}

impl RateLimitHeaders {
    /// Name of the header carrying the bucket identifier.
    pub const BUCKET: &'static str = "x-ratelimit-bucket";

    /// Name of the header carrying the bucket's request limit.
    pub const LIMIT: &'static str = "x-ratelimit-limit";

    /// Name of the header carrying the remaining request count.
    pub const REMAINING: &'static str = "x-ratelimit-remaining";

    /// Name of the header carrying the time until reset.
    pub const RESET_AFTER: &'static str = "x-ratelimit-reset-after";

    /// Name of the header carrying the rate limit scope.
    pub const SCOPE: &'static str = "x-ratelimit-scope";

    /// Builds headers for an exhausted bucket: `limit` and `remaining` are
    /// both zero and the bucket resets `retry_after` seconds from now.
    pub fn shared(bucket: Vec<u8>, retry_after: u16) -> Self {
        let wait = Duration::from_secs(u64::from(retry_after));
        let reset_at = Instant::now() + wait;

        Self {
            bucket,
            limit: 0,
            remaining: 0,
            reset_at,
        }
    }
}
195
196#[derive(Debug)]
198#[must_use = "dropping the permit immediately cancels itself"]
199pub struct Permit(oneshot::Sender<Option<RateLimitHeaders>>);
200
201impl Permit {
202 #[allow(clippy::missing_panics_doc)]
207 pub fn complete(self, headers: Option<RateLimitHeaders>) {
208 self.0.send(headers).expect("actor is alive");
209 }
210}
211
212#[derive(Debug)]
214#[must_use = "futures do nothing unless you `.await` or poll them"]
215pub struct PermitFuture(oneshot::Receiver<oneshot::Sender<Option<RateLimitHeaders>>>);
216
217impl Future for PermitFuture {
218 type Output = Permit;
219
220 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
221 Pin::new(&mut self.0)
222 .poll(cx)
223 .map(|r| Permit(r.expect("actor is alive")))
224 }
225}
226
227#[derive(Debug)]
229#[must_use = "futures do nothing unless you `.await` or poll them"]
230pub struct MaybePermitFuture(oneshot::Receiver<oneshot::Sender<Option<RateLimitHeaders>>>);
231
232impl Future for MaybePermitFuture {
233 type Output = Option<Permit>;
234
235 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
236 Pin::new(&mut self.0).poll(cx).map(|r| r.ok().map(Permit))
237 }
238}
239
/// Snapshot of a rate limit bucket, as passed to [`RateLimiter::acquire_if`]
/// predicates.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub struct Bucket {
    /// Total number of requests allowed in the bucket.
    pub limit: u16,
    /// Number of requests left in the bucket.
    pub remaining: u16,
    /// Instant at which the bucket resets.
    pub reset_at: Instant,
}

/// Boxed predicate deciding, from the endpoint's bucket (`None` if there is
/// none), whether a conditional acquire should be granted.
type Predicate = Box<dyn FnOnce(Option<Bucket>) -> bool + Send>;
255
256#[derive(Clone, Debug)]
264pub struct RateLimiter {
265 tx: mpsc::UnboundedSender<(actor::Message, Option<Predicate>)>,
267}
268
269impl RateLimiter {
270 pub fn new(global_limit: u16) -> Self {
272 let (tx, rx) = mpsc::unbounded_channel();
273 tokio::spawn(actor::runner(global_limit, rx));
274
275 Self { tx }
276 }
277
278 #[allow(clippy::missing_panics_doc)]
282 pub fn acquire(&self, endpoint: Endpoint) -> PermitFuture {
283 let (tx, rx) = oneshot::channel();
284 self.tx
285 .send((
286 actor::Message {
287 endpoint,
288 notifier: tx,
289 },
290 None,
291 ))
292 .expect("actor is alive");
293
294 PermitFuture(rx)
295 }
296
297 #[allow(clippy::missing_panics_doc)]
329 pub fn acquire_if<P>(&self, endpoint: Endpoint, predicate: P) -> MaybePermitFuture
330 where
331 P: FnOnce(Option<Bucket>) -> bool + Send + 'static,
332 {
333 let (tx, rx) = oneshot::channel();
334 self.tx
335 .send((
336 actor::Message {
337 endpoint,
338 notifier: tx,
339 },
340 Some(Box::new(predicate)),
341 ))
342 .expect("actor is alive");
343
344 MaybePermitFuture(rx)
345 }
346
347 #[allow(clippy::missing_panics_doc)]
351 pub async fn bucket(&self, endpoint: Endpoint) -> Option<Bucket> {
352 let (tx, rx) = oneshot::channel();
353 self.acquire_if(endpoint, |bucket| {
354 _ = tx.send(bucket);
355 false
356 })
357 .await;
358
359 rx.await.expect("actor is alive")
360 }
361}
362
363impl Default for RateLimiter {
364 fn default() -> Self {
368 Self::new(50)
369 }
370}
371
#[cfg(test)]
mod tests {
    use super::{
        Bucket, Endpoint, MaybePermitFuture, Method, Permit, PermitFuture, RateLimitHeaders,
        RateLimiter,
    };
    use static_assertions::assert_impl_all;
    use std::{
        fmt::Debug,
        future::Future,
        hash::{DefaultHasher, Hash, Hasher as _},
        time::{Duration, Instant},
    };
    use tokio::task;

    // Public types must keep these trait implementations; the futures must
    // be `Send` so they can be awaited inside spawned tasks.
    assert_impl_all!(Bucket: Clone, Copy, Debug, Eq, Hash, PartialEq, Send, Sync);
    assert_impl_all!(Endpoint: Clone, Debug, Eq, Hash, PartialEq, Send, Sync);
    assert_impl_all!(MaybePermitFuture: Debug, Future<Output = Option<Permit>>, Send);
    assert_impl_all!(Method: Clone, Copy, Debug, Eq, Hash, PartialEq, Send, Sync);
    assert_impl_all!(Permit: Debug, Send, Sync);
    assert_impl_all!(PermitFuture: Debug, Future<Output = Permit>, Send);
    assert_impl_all!(RateLimitHeaders: Clone, Debug, Eq, Hash, PartialEq, Send, Sync);
    assert_impl_all!(RateLimiter: Clone, Debug, Default, Send, Sync);

    /// Endpoint shared by the actor-based tests.
    fn app_endpoint() -> Endpoint {
        Endpoint {
            method: Method::Get,
            path: String::from("applications/@me"),
        }
    }

    #[tokio::test]
    async fn acquire_if() {
        let rate_limiter = RateLimiter::default();

        assert!(
            rate_limiter
                .acquire_if(app_endpoint(), |_| false)
                .await
                .is_none()
        );
        assert!(
            rate_limiter
                .acquire_if(app_endpoint(), |_| true)
                .await
                .is_some()
        );
    }

    #[tokio::test]
    async fn bucket() {
        let rate_limiter = RateLimiter::default();

        let limit = 2;
        let remaining = 1;
        let reset_at = Instant::now() + Duration::from_secs(1);
        let headers = RateLimitHeaders {
            bucket: vec![1, 2, 3],
            limit,
            remaining,
            reset_at,
        };

        rate_limiter
            .acquire(app_endpoint())
            .await
            .complete(Some(headers));
        task::yield_now().await;

        let bucket = rate_limiter.bucket(app_endpoint()).await.unwrap();
        assert_eq!(bucket.limit, limit);
        assert_eq!(bucket.remaining, remaining);
        assert!(
            bucket.reset_at.saturating_duration_since(reset_at) < Duration::from_millis(1)
                && reset_at.saturating_duration_since(bucket.reset_at) < Duration::from_millis(1)
        );
    }

    /// Runs `f` against a fresh hasher and returns the resulting digest.
    fn with_hasher(f: impl FnOnce(&mut DefaultHasher)) -> u64 {
        let mut hasher = DefaultHasher::new();
        f(&mut hasher);
        hasher.finish()
    }

    #[test]
    fn endpoint() {
        let invalid = Endpoint {
            method: Method::Get,
            path: String::from("/guilds/745809834183753828/audit-logs?limit=10"),
        };
        let delete_webhook = Endpoint {
            method: Method::Delete,
            path: String::from("webhooks/1"),
        };
        let interaction_response = Endpoint {
            method: Method::Post,
            path: String::from("interactions/1/abc/callback"),
        };
        let channel_messages = Endpoint {
            method: Method::Get,
            path: String::from("channels/381870553235193857/messages"),
        };
        let guild_audit_log = Endpoint {
            method: Method::Get,
            path: String::from("guilds/745809834183753828/audit-logs"),
        };
        let edit_webhook_message = Endpoint {
            method: Method::Patch,
            path: String::from("webhooks/1/abc/messages/2"),
        };

        assert!(!invalid.is_valid());
        assert!(delete_webhook.is_valid());
        assert!(interaction_response.is_valid());
        assert!(channel_messages.is_valid());
        assert!(guild_audit_log.is_valid());
        assert!(edit_webhook_message.is_valid());

        assert!(!invalid.is_interaction());
        assert!(delete_webhook.is_interaction());
        assert!(interaction_response.is_interaction());
        assert!(!channel_messages.is_interaction());
        assert!(!guild_audit_log.is_interaction());
        assert!(edit_webhook_message.is_interaction());

        assert_eq!(
            with_hasher(|state| invalid.hash_resources(state)),
            with_hasher(|_| {})
        );
        assert_eq!(
            with_hasher(|state| delete_webhook.hash_resources(state)),
            with_hasher(|state| {
                "webhooks".hash(state);
                b"1".hash(state);
            })
        );
        assert_eq!(
            with_hasher(|state| interaction_response.hash_resources(state)),
            with_hasher(|_| {})
        );
        assert_eq!(
            with_hasher(|state| channel_messages.hash_resources(state)),
            with_hasher(|state| {
                "channels".hash(state);
                b"381870553235193857".hash(state);
            })
        );
        assert_eq!(
            with_hasher(|state| guild_audit_log.hash_resources(state)),
            with_hasher(|state| {
                "guilds".hash(state);
                b"745809834183753828".hash(state);
            })
        );
        // Webhook buckets also include the token segment when present.
        assert_eq!(
            with_hasher(|state| edit_webhook_message.hash_resources(state)),
            with_hasher(|state| {
                "webhooks".hash(state);
                b"1".hash(state);
                b"abc".hash(state);
            })
        );
    }

    #[test]
    fn method_conversions() {
        assert_eq!("DELETE", Method::Delete.name());
        assert_eq!("GET", Method::Get.name());
        assert_eq!("PATCH", Method::Patch.name());
        assert_eq!("POST", Method::Post.name());
        assert_eq!("PUT", Method::Put.name());
    }
}