twilight_http_ratelimiting/
lib.rs

1#![doc = include_str!("../README.md")]
2#![warn(
3    clippy::missing_const_for_fn,
4    clippy::missing_docs_in_private_items,
5    clippy::pedantic,
6    missing_docs,
7    unsafe_code
8)]
9#![allow(clippy::module_name_repetitions, clippy::must_use_candidate)]
10
11mod actor;
12
13use std::{
14    future::Future,
15    hash::{Hash as _, Hasher},
16    pin::Pin,
17    task::{Context, Poll},
18    time::{Duration, Instant},
19};
20use tokio::sync::{mpsc, oneshot};
21
/// Window in which the global request limit applies.
///
/// The window opens with the first globally limited request; once it has
/// elapsed, the remaining count is reset back to the global limit count.
pub const GLOBAL_LIMIT_PERIOD: Duration = Duration::from_millis(1_000);
25
/// Panic message used whenever the actor task can no longer be reached,
/// phrased so the user knows what to do about it.
const ACTOR_PANIC_MESSAGE: &str =
    "actor task panicked: report its panic message to the maintainers";
29
/// HTTP request [method] of a Discord API request.
///
/// [method]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub enum Method {
    /// `DELETE`: delete a resource.
    Delete,
    /// `GET`: retrieve a resource.
    Get,
    /// `PATCH`: update a resource.
    Patch,
    /// `POST`: create a resource.
    Post,
    /// `PUT`: replace a resource.
    Put,
}

impl Method {
    /// Uppercase name of the method, e.g. `"GET"` for [`Method::Get`].
    pub const fn name(self) -> &'static str {
        match self {
            Self::Delete => "DELETE",
            Self::Get => "GET",
            Self::Patch => "PATCH",
            Self::Post => "POST",
            Self::Put => "PUT",
        }
    }
}
60
61/// Rate limited endpoint.
62///
63/// The rate limiter dynamically supports new or unknown API paths, but is consequently unable to
64/// catch invalid arguments. Invalidly structured endpoints may be permitted at an improper time.
65///
66/// # Example
67///
68/// ```no_run
69/// # let rt = tokio::runtime::Builder::new_current_thread()
70/// #     .enable_time()
71/// #     .build()
72/// #     .unwrap();
73/// # rt.block_on(async {
74/// # let rate_limiter = twilight_http_ratelimiting::RateLimiter::default();
75/// use twilight_http_ratelimiting::{Endpoint, Method};
76///
77/// let url = "https://discord.com/api/v10/guilds/745809834183753828/audit-logs?limit=10";
78/// let endpoint = Endpoint {
79///     method: Method::Get,
80///     path: String::from("guilds/745809834183753828/audit-logs"),
81/// };
82/// let permit = rate_limiter.acquire(endpoint).await;
83/// let headers = unimplemented!("GET {url}");
84/// permit.complete(headers);
85/// # });
86/// ```
87#[derive(Clone, Debug, Eq, Hash, PartialEq)]
88pub struct Endpoint {
89    /// Method of the endpoint.
90    pub method: Method,
91    /// API path of the endpoint.
92    ///
93    /// Should not start with a slash (`/`) or include query parameters (`?key=value`).
94    pub path: String,
95}
96
97impl Endpoint {
98    /// Whether the endpoint is properly structured.
99    pub(crate) fn is_valid(&self) -> bool {
100        !self.path.as_bytes().starts_with(b"/") && !self.path.as_bytes().contains(&b'?')
101    }
102
103    /// Whether the endpoint is an interaction.
104    pub(crate) fn is_interaction(&self) -> bool {
105        self.path.as_bytes().starts_with(b"webhooks")
106            || self.path.as_bytes().starts_with(b"interactions")
107    }
108
109    /// Feeds the top-level resources of this endpoint into the given [`Hasher`].
110    ///
111    /// Top-level resources represent the bucket namespace in which they are unique.
112    ///
113    /// Top-level resources are currently:
114    /// - `channels/<channel_id>`
115    /// - `guilds/<guild_id>`
116    /// - `webhooks/<webhook_id>`
117    /// - `webhooks/<webhook_id>/<webhook_token>`
118    pub(crate) fn hash_resources(&self, state: &mut impl Hasher) {
119        let mut segments = self.path.as_bytes().split(|&s| s == b'/');
120        match segments.next().unwrap_or_default() {
121            b"channels" => {
122                if let Some(s) = segments.next() {
123                    "channels".hash(state);
124                    s.hash(state);
125                }
126            }
127            b"guilds" => {
128                if let Some(s) = segments.next() {
129                    "guilds".hash(state);
130                    s.hash(state);
131                }
132            }
133            b"webhooks" => {
134                if let Some(s) = segments.next() {
135                    "webhooks".hash(state);
136                    s.hash(state);
137                }
138                if let Some(s) = segments.next() {
139                    s.hash(state);
140                }
141            }
142            _ => {}
143        }
144    }
145}
146
/// Rate limit headers parsed from a user response.
///
/// A `limit` of zero marks the [`Bucket`] as exhausted until `reset_at`
/// elapses.
///
/// # Global limits
///
/// Please open an issue if the [`RateLimiter`] ever exceeds the global limit.
///
/// # Shared limits
///
/// Shared limits do not count towards the invalid request limit, so reporting
/// them is optional. To preemptively exhaust the bucket until `Reset-After`
/// anyway, complete the [`Permit`] with [`RateLimitHeaders::shared`].
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct RateLimitHeaders {
    /// Bucket identifier.
    pub bucket: Vec<u8>,
    /// Total number of requests until the bucket becomes exhausted.
    pub limit: u16,
    /// Number of requests left until the bucket becomes exhausted.
    pub remaining: u16,
    /// Point in time at which the bucket resets.
    pub reset_at: Instant,
}

impl RateLimitHeaders {
    /// Lowercased name for the bucket header.
    pub const BUCKET: &'static str = "x-ratelimit-bucket";

    /// Lowercased name for the limit header.
    pub const LIMIT: &'static str = "x-ratelimit-limit";

    /// Lowercased name for the remaining header.
    pub const REMAINING: &'static str = "x-ratelimit-remaining";

    /// Lowercased name for the reset-after header.
    pub const RESET_AFTER: &'static str = "x-ratelimit-reset-after";

    /// Lowercased name for the scope header.
    pub const SCOPE: &'static str = "x-ratelimit-scope";

    /// Emulates a shared resource limit as a user limit.
    ///
    /// Both `limit` and `remaining` are zero, and `reset_at` lies
    /// `retry_after` seconds from now.
    pub fn shared(bucket: Vec<u8>, retry_after: u16) -> Self {
        let reset_at = Instant::now() + Duration::from_secs(u64::from(retry_after));

        Self {
            bucket,
            limit: 0,
            remaining: 0,
            reset_at,
        }
    }
}
199
200/// Permit to send a Discord HTTP API request to the acquired endpoint.
201#[derive(Debug)]
202#[must_use = "dropping the permit immediately cancels itself"]
203pub struct Permit(oneshot::Sender<Option<RateLimitHeaders>>);
204
205impl Permit {
206    /// Update the [`RateLimiter`] based on the response headers.
207    ///
208    /// Non-completed permits are regarded as cancelled, so only call this
209    /// on receiving a response.
210    #[allow(clippy::missing_panics_doc)]
211    pub fn complete(self, headers: Option<RateLimitHeaders>) {
212        assert!(self.0.send(headers).is_ok(), "{ACTOR_PANIC_MESSAGE}");
213    }
214}
215
216/// Future that completes when a permit is ready.
217#[derive(Debug)]
218#[must_use = "futures do nothing unless you `.await` or poll them"]
219pub struct PermitFuture(oneshot::Receiver<oneshot::Sender<Option<RateLimitHeaders>>>);
220
221impl Future for PermitFuture {
222    type Output = Permit;
223
224    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
225        #[allow(clippy::match_wild_err_arm)]
226        Pin::new(&mut self.0).poll(cx).map(|r| match r {
227            Ok(sender) => Permit(sender),
228            Err(_) => panic!("{ACTOR_PANIC_MESSAGE}"),
229        })
230    }
231}
232
233/// Future that completes when a permit is ready or cancelled.
234#[derive(Debug)]
235#[must_use = "futures do nothing unless you `.await` or poll them"]
236pub struct MaybePermitFuture(oneshot::Receiver<oneshot::Sender<Option<RateLimitHeaders>>>);
237
238impl Future for MaybePermitFuture {
239    type Output = Option<Permit>;
240
241    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
242        Pin::new(&mut self.0).poll(cx).map(|r| r.ok().map(Permit))
243    }
244}
245
/// Rate limit state shared by one or more paths, derived from previously
/// received [`RateLimitHeaders`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub struct Bucket {
    /// Total number of permits granted before the bucket becomes exhausted.
    pub limit: u16,
    /// Number of permits left before the bucket becomes exhausted.
    pub remaining: u16,
    /// Point in time at which the bucket resets.
    pub reset_at: Instant,
}
258
259/// Actor run closure pre-enqueue for early [`MaybePermitFuture`] cancellation.
260type Predicate = Box<dyn FnOnce(Option<Bucket>) -> bool + Send>;
261
262/// Discord HTTP client API rate limiter.
263///
264/// The [`RateLimiter`] runs an associated actor task to concurrently handle permit
265/// requests and responses.
266///
267/// Cloning a [`RateLimiter`] increments just the amount of senders for the actor.
268/// The actor completes when there are no senders and non-completed permits left.
269#[derive(Clone, Debug)]
270pub struct RateLimiter {
271    /// Actor message sender.
272    tx: mpsc::UnboundedSender<(actor::Message, Option<Predicate>)>,
273}
274
275impl RateLimiter {
276    /// Create a new [`RateLimiter`] with a custom global limit.
277    pub fn new(global_limit: u16) -> Self {
278        let (tx, rx) = mpsc::unbounded_channel();
279        tokio::spawn(actor::runner(global_limit, rx));
280
281        Self { tx }
282    }
283
284    /// Await a single permit for this endpoint.
285    ///
286    /// Permits are queued per endpoint in the order they were requested.
287    #[allow(clippy::missing_panics_doc)]
288    pub fn acquire(&self, endpoint: Endpoint) -> PermitFuture {
289        let (notifier, rx) = oneshot::channel();
290        let message = actor::Message { endpoint, notifier };
291        assert!(
292            self.tx.send((message, None)).is_ok(),
293            "{ACTOR_PANIC_MESSAGE}"
294        );
295
296        PermitFuture(rx)
297    }
298
299    /// Await a single permit for this endpoint, but only if the predicate evaluates
300    /// to `true`.
301    ///
302    /// Permits are queued per endpoint in the order they were requested.
303    ///
304    /// Note that the predicate is asynchronously called in the actor task.
305    ///
306    /// # Example
307    ///
308    /// ```no_run
309    /// # let rt = tokio::runtime::Builder::new_current_thread()
310    /// #     .enable_time()
311    /// #     .build()
312    /// #     .unwrap();
313    /// # rt.block_on(async {
314    /// # let rate_limiter = twilight_http_ratelimiting::RateLimiter::default();
315    /// use twilight_http_ratelimiting::{Endpoint, Method};
316    ///
317    /// let endpoint = Endpoint {
318    ///     method: Method::Get,
319    ///     path: String::from("applications/@me"),
320    /// };
321    /// if let Some(permit) = rate_limiter
322    ///     .acquire_if(endpoint, |b| b.is_none_or(|b| b.remaining > 10))
323    ///     .await
324    /// {
325    ///     let headers = unimplemented!("GET /applications/@me");
326    ///     permit.complete(headers);
327    /// }
328    /// # });
329    /// ```
330    #[allow(clippy::missing_panics_doc)]
331    pub fn acquire_if<P>(&self, endpoint: Endpoint, predicate: P) -> MaybePermitFuture
332    where
333        P: FnOnce(Option<Bucket>) -> bool + Send + 'static,
334    {
335        let (notifier, rx) = oneshot::channel();
336        let message = actor::Message { endpoint, notifier };
337        assert!(
338            self.tx.send((message, Some(Box::new(predicate)))).is_ok(),
339            "{ACTOR_PANIC_MESSAGE}"
340        );
341
342        MaybePermitFuture(rx)
343    }
344
345    /// Retrieve the [`Bucket`] for this endpoint.
346    ///
347    /// The bucket is internally retrieved via [`acquire_if`][Self::acquire_if].
348    #[allow(clippy::missing_panics_doc)]
349    pub async fn bucket(&self, endpoint: Endpoint) -> Option<Bucket> {
350        let (tx, rx) = oneshot::channel();
351        self.acquire_if(endpoint, |bucket| {
352            // Ignore cancellation error.
353            _ = tx.send(bucket);
354            false
355        })
356        .await;
357
358        #[allow(clippy::match_wild_err_arm)]
359        match rx.await {
360            Ok(bucket) => bucket,
361            Err(_) => panic!("{ACTOR_PANIC_MESSAGE}"),
362        }
363    }
364}
365
366impl Default for RateLimiter {
367    /// Create a new [`RateLimiter`] with Discord's default global limit.
368    ///
369    /// Currently this is `50`.
370    fn default() -> Self {
371        Self::new(50)
372    }
373}
374
#[cfg(test)]
mod tests {
    use super::{
        Bucket, Endpoint, MaybePermitFuture, Method, Permit, PermitFuture, RateLimitHeaders,
        RateLimiter,
    };
    use static_assertions::assert_impl_all;
    use std::{
        fmt::Debug,
        future::Future,
        hash::{DefaultHasher, Hash, Hasher as _},
        time::{Duration, Instant},
    };
    use tokio::task;

    assert_impl_all!(Bucket: Clone, Copy, Debug, Eq, Hash, PartialEq, Send, Sync);
    assert_impl_all!(Endpoint: Clone, Debug, Eq, Hash, PartialEq, Send, Sync);
    assert_impl_all!(MaybePermitFuture: Debug, Future<Output = Option<Permit>>);
    assert_impl_all!(Method: Clone, Copy, Debug, Eq, Hash, PartialEq, Send, Sync);
    assert_impl_all!(Permit: Debug, Send, Sync);
    assert_impl_all!(PermitFuture: Debug, Future<Output = Permit>);
    assert_impl_all!(RateLimitHeaders: Clone, Debug, Eq, Hash, PartialEq, Send, Sync);
    assert_impl_all!(RateLimiter: Clone, Debug, Default, Send, Sync);

    /// Endpoint shared by the rate limiter tests.
    fn application_endpoint() -> Endpoint {
        Endpoint {
            method: Method::Get,
            path: String::from("applications/@me"),
        }
    }

    /// Finishes a fresh [`DefaultHasher`] after `f` has written into it.
    fn with_hasher(f: impl FnOnce(&mut DefaultHasher)) -> u64 {
        let mut hasher = DefaultHasher::new();
        f(&mut hasher);
        hasher.finish()
    }

    #[tokio::test]
    async fn acquire_if() {
        let rate_limiter = RateLimiter::default();

        assert!(
            rate_limiter
                .acquire_if(application_endpoint(), |_| false)
                .await
                .is_none()
        );
        assert!(
            rate_limiter
                .acquire_if(application_endpoint(), |_| true)
                .await
                .is_some()
        );
    }

    #[tokio::test]
    async fn bucket() {
        let rate_limiter = RateLimiter::default();

        let limit = 2;
        let remaining = 1;
        let reset_at = Instant::now() + Duration::from_secs(1);
        let headers = RateLimitHeaders {
            bucket: vec![1, 2, 3],
            limit,
            remaining,
            reset_at,
        };

        rate_limiter
            .acquire(application_endpoint())
            .await
            .complete(Some(headers));
        task::yield_now().await;

        let bucket = rate_limiter.bucket(application_endpoint()).await.unwrap();
        assert_eq!(bucket.limit, limit);
        assert_eq!(bucket.remaining, remaining);
        assert!(
            bucket.reset_at.saturating_duration_since(reset_at) < Duration::from_millis(1)
                && reset_at.saturating_duration_since(bucket.reset_at) < Duration::from_millis(1)
        );
    }

    #[test]
    fn endpoint() {
        let invalid = Endpoint {
            method: Method::Get,
            path: String::from("/guilds/745809834183753828/audit-logs?limit=10"),
        };
        let delete_webhook = Endpoint {
            method: Method::Delete,
            path: String::from("webhooks/1"),
        };
        let interaction_response = Endpoint {
            method: Method::Post,
            path: String::from("interactions/1/abc/callback"),
        };
        let guild_audit_log = Endpoint {
            method: Method::Get,
            path: String::from("guilds/745809834183753828/audit-logs"),
        };
        let channel_messages = Endpoint {
            method: Method::Post,
            path: String::from("channels/1/messages"),
        };
        let webhook_message = Endpoint {
            method: Method::Patch,
            path: String::from("webhooks/1/abc/messages/2"),
        };

        assert!(!invalid.is_valid());
        assert!(delete_webhook.is_valid());
        assert!(interaction_response.is_valid());
        assert!(guild_audit_log.is_valid());

        assert!(delete_webhook.is_interaction());
        assert!(interaction_response.is_interaction());
        assert!(webhook_message.is_interaction());
        assert!(!guild_audit_log.is_interaction());
        assert!(!channel_messages.is_interaction());

        assert_eq!(
            with_hasher(|state| invalid.hash_resources(state)),
            with_hasher(|_| {})
        );
        assert_eq!(
            with_hasher(|state| delete_webhook.hash_resources(state)),
            with_hasher(|state| {
                "webhooks".hash(state);
                b"1".hash(state);
            })
        );
        assert_eq!(
            with_hasher(|state| webhook_message.hash_resources(state)),
            with_hasher(|state| {
                "webhooks".hash(state);
                b"1".hash(state);
                b"abc".hash(state);
            })
        );
        assert_eq!(
            with_hasher(|state| guild_audit_log.hash_resources(state)),
            with_hasher(|state| {
                "guilds".hash(state);
                b"745809834183753828".hash(state);
            })
        );
        assert_eq!(
            with_hasher(|state| channel_messages.hash_resources(state)),
            with_hasher(|state| {
                "channels".hash(state);
                b"1".hash(state);
            })
        );
        assert_eq!(
            with_hasher(|state| interaction_response.hash_resources(state)),
            with_hasher(|_| {})
        );
    }

    #[test]
    fn method_conversions() {
        assert_eq!("DELETE", Method::Delete.name());
        assert_eq!("GET", Method::Get.name());
        assert_eq!("PATCH", Method::Patch.name());
        assert_eq!("POST", Method::Post.name());
        assert_eq!("PUT", Method::Put.name());
    }

    #[test]
    fn shared_headers() {
        let before = Instant::now();
        let headers = RateLimitHeaders::shared(vec![1, 2, 3], 5);
        let after = Instant::now();

        assert_eq!(headers.bucket, vec![1, 2, 3]);
        assert_eq!(headers.limit, 0);
        assert_eq!(headers.remaining, 0);
        assert!(headers.reset_at >= before + Duration::from_secs(5));
        assert!(headers.reset_at <= after + Duration::from_secs(5));
    }
}