1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
//! Middleware to rate-limit requests built on [`reqwest_middleware`].
//!
//! You're expected to provide your own [`RateLimiter`] implementation.
//!
//! ## Example
//!
//! ```
//! use async_trait::async_trait;
//! use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
//!
//! struct RateLimiter;
//!
//! #[async_trait]
//! impl reqwest_ratelimit::RateLimiter for RateLimiter {
//!     async fn acquire_permit(&self) {
//!         // noop
//!     }
//! }
//!
//! async fn run() {
//!     let client = ClientBuilder::new(reqwest::Client::new())
//!         .with(reqwest_ratelimit::all(RateLimiter))
//!         .build();
//!
//!     client.get("https://crates.io").send().await.unwrap();
//! }
//! ```
use async_trait::async_trait;
use http::Extensions;
use reqwest::{Request, Response};
use reqwest_middleware::{Next, Result};

/// Decides when the next request may be sent.
///
/// A single limiter value is consulted for every request that goes through
/// [`Middleware`], so implementors must be `Send + Sync + 'static`.
#[async_trait]
pub trait RateLimiter: 'static + Send + Sync {
    /// Waits until the next request is allowed to proceed.
    ///
    /// [`Middleware`] awaits this before forwarding each request; an
    /// implementation that returns immediately lets requests through unthrottled.
    async fn acquire_permit(&self);
}

/// Creates a new [`Middleware`] rate-limiting all requests using the provided [`RateLimiter`].
///
/// When `R` implements [`RateLimiter`], every request handled by the returned
/// middleware first awaits [`RateLimiter::acquire_permit`] before it is passed on
/// to the rest of the chain. Register it on a `reqwest_middleware::ClientBuilder`
/// via `.with(...)`.
#[must_use = "the middleware does nothing unless it is registered on a client builder"]
pub fn all<R>(rate_limiter: R) -> Middleware<R> {
    Middleware { rate_limiter }
}

/// Request rate-limiting middleware.
///
/// Construct it with [`all`]. `Debug` and `Clone` are available whenever the
/// wrapped limiter type provides them.
#[derive(Debug, Clone)]
pub struct Middleware<R> {
    /// Limiter whose permit is awaited before each request is forwarded.
    rate_limiter: R,
}

#[async_trait]
impl<R> reqwest_middleware::Middleware for Middleware<R>
where
    R: RateLimiter,
{
    /// Waits for the wrapped [`RateLimiter`] to grant a permit, then hands the
    /// request and its extensions, unchanged, to the next middleware in the chain.
    async fn handle(
        &self,
        req: Request,
        extensions: &mut Extensions,
        next: Next<'_>,
    ) -> Result<Response> {
        // Throttle first: the request is not forwarded until the limiter resolves.
        let limiter = &self.rate_limiter;
        limiter.acquire_permit().await;

        next.run(req, extensions).await
    }
}