use reqwest_rate_limit::{ResponseMiddleware, governor::Quota};
use std::num::NonZeroU32;
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::time::sleep;
/// Rate-limit condition recorded for the most recent rate-limited response.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RateLimitState {
    /// The primary limit was hit (classified as such when `x-ratelimit-remaining` is `0`).
    PrimaryRateLimit,
    /// A secondary limit was hit; `retries` counts consecutive secondary-limit responses.
    SecondaryRateLimit {
        retries: u32,
    },
}
/// Response middleware that converts `403`/`429` responses into typed [`RateLimitError`]s
/// based on the `x-ratelimit-*` and `retry-after` headers, tracking consecutive secondary
/// rate limits across calls.
///
/// Clones share the same tracked state, because it lives behind an `Arc`.
#[derive(Clone)]
pub struct RateLimitResponseMiddleware {
    /// Number of consecutive secondary rate limits tolerated before giving up.
    max_secondary_retries: u32,
    /// Outcome of the most recent rate-limited response; `None` once a response that is
    /// not rate limited has been seen.
    state: Arc<Mutex<Option<RateLimitState>>>,
}
/// Errors produced by [`RateLimitResponseMiddleware`].
///
/// Every variant except [`RateLimitError::Transport`] describes a rate-limited response.
/// Where present, the field says how long to wait (`retry_after`, in seconds) or until when
/// (`x_ratelimit_reset`, in seconds since the Unix epoch).
#[derive(Debug)]
pub enum RateLimitError {
    /// The request failed before a response was received.
    Transport(reqwest::Error),
    /// The primary limit was hit; it resets at `x_ratelimit_reset` (`0` if the header was
    /// missing).
    PrimaryRateLimit {
        x_ratelimit_reset: u64,
    },
    /// A secondary limit was hit and the server supplied `retry-after`.
    SecondaryRateLimitRetryAfter {
        retry_after: u64,
    },
    /// A secondary limit was hit with no requests remaining; wait until `x_ratelimit_reset`.
    SecondaryRateLimitReset {
        x_ratelimit_reset: u64,
    },
    /// A secondary limit was hit with no more specific guidance; wait `retry_after` seconds.
    SecondaryRateLimitWait {
        retry_after: u64,
    },
    /// A repeated secondary limit; back off for `retry_after` seconds on this `attempt`.
    SecondaryRateLimitExponentialBackoff {
        retry_after: u64,
        attempt: u32,
    },
    /// The configured number of consecutive secondary-limit retries was exceeded.
    SecondaryRateLimitExhausted {
        retries: u32,
    },
}

impl std::fmt::Display for RateLimitError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Transport(err) => write!(f, "transport error: {err}"),
            Self::PrimaryRateLimit { x_ratelimit_reset } => write!(
                f,
                "primary rate limit exceeded; resets at unix time {x_ratelimit_reset}"
            ),
            Self::SecondaryRateLimitRetryAfter { retry_after } => write!(
                f,
                "secondary rate limit exceeded; retry after {retry_after}s"
            ),
            Self::SecondaryRateLimitReset { x_ratelimit_reset } => write!(
                f,
                "secondary rate limit exceeded; retry after unix time {x_ratelimit_reset}"
            ),
            Self::SecondaryRateLimitWait { retry_after } => write!(
                f,
                "secondary rate limit exceeded; wait {retry_after}s before retrying"
            ),
            Self::SecondaryRateLimitExponentialBackoff {
                retry_after,
                attempt,
            } => write!(
                f,
                "secondary rate limit exceeded (attempt {attempt}); back off for {retry_after}s"
            ),
            Self::SecondaryRateLimitExhausted { retries } => write!(
                f,
                "secondary rate limit retries exhausted after {retries} attempts"
            ),
        }
    }
}

impl std::error::Error for RateLimitError {
    /// Exposes the underlying `reqwest` error for transport failures.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Transport(err) => Some(err),
            _ => None,
        }
    }
}
impl RateLimitResponseMiddleware {
    /// Wait, in seconds, used for a secondary limit with no more specific guidance; also the
    /// base of the exponential backoff.
    const MIN_SECONDARY_WAIT_SECS: u64 = 60;
    /// Upper bound, in seconds, on the exponential backoff delay.
    const MAX_SECONDARY_BACKOFF_SECS: u64 = 60 * 60;
    /// Secondary-retry budget used by the `Default` impl.
    const DEFAULT_MAX_SECONDARY_RETRIES: u32 = 5;

    /// Creates a middleware that gives up after `max_secondary_retries` consecutive
    /// secondary rate-limit responses.
    pub fn new(max_secondary_retries: u32) -> Self {
        Self {
            state: Arc::new(Mutex::new(None)),
            max_secondary_retries,
        }
    }

    /// Parses header `name` as a base-10 `u64`.
    ///
    /// Returns `None` if the header is missing, is not visible ASCII, or is not a number.
    fn parse_header_u64(headers: &http::HeaderMap, name: &str) -> Option<u64> {
        headers
            .get(name)
            .and_then(|v| v.to_str().ok())
            .and_then(|v| v.parse::<u64>().ok())
    }

    /// Backoff for the `attempt`-th consecutive secondary limit.
    ///
    /// The formula is `MIN_SECONDARY_WAIT_SECS * 2^(attempt - 1)`, capped at
    /// `MAX_SECONDARY_BACKOFF_SECS`. The exponent is clamped to 16 so the shift can never
    /// overflow.
    fn secondary_backoff_seconds(attempt: u32) -> u64 {
        let exp = attempt.saturating_sub(1).min(16);
        let backoff = Self::MIN_SECONDARY_WAIT_SECS.saturating_mul(1u64 << exp);
        backoff.min(Self::MAX_SECONDARY_BACKOFF_SECS)
    }

    /// Blocks the calling thread until the wall clock reaches `epoch_seconds` (seconds
    /// since the Unix epoch). Returns immediately if that time has already passed.
    ///
    /// `ResponseMiddleware::on_response` is synchronous, so the wait cannot be `.await`ed.
    /// `Handle::block_on` is also not an option, because it panics when called from within
    /// an async execution context. Instead:
    /// - on a multi-threaded Tokio runtime the thread sleep runs inside
    ///   `tokio::task::block_in_place`, so the runtime can move other tasks off this worker;
    /// - otherwise (current-thread runtime, or no runtime) it is a plain thread sleep. On a
    ///   current-thread runtime this stalls the whole runtime for the duration.
    fn sleep_until_epoch_seconds(epoch_seconds: u64) {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or(Duration::ZERO)
            .as_secs();
        let wait_secs = epoch_seconds.saturating_sub(now);
        if wait_secs == 0 {
            return;
        }
        let wait = Duration::from_secs(wait_secs);
        match tokio::runtime::Handle::try_current() {
            Ok(handle)
                if matches!(
                    handle.runtime_flavor(),
                    tokio::runtime::RuntimeFlavor::MultiThread
                ) =>
            {
                tokio::task::block_in_place(|| std::thread::sleep(wait));
            }
            _ => std::thread::sleep(wait),
        }
    }

    /// Records the outcome of a response in the shared state.
    ///
    /// - `None` (a response that was not rate limited) clears the state.
    /// - A primary-limit error records `PrimaryRateLimit`.
    /// - Any secondary-limit error increments the consecutive-secondary counter. The counter
    ///   starts at 1, so it equals the attempt number, the same convention `on_response`
    ///   uses.
    /// - A transport error leaves the state unchanged.
    fn apply(&self, error: Option<&RateLimitError>) {
        // The state is plain `Copy` data that is always internally consistent, so a
        // poisoned lock (a panic elsewhere while holding it) is safe to recover from.
        let mut state = self
            .state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let Some(error) = error else {
            *state = None;
            return;
        };
        let new_state = match error {
            RateLimitError::PrimaryRateLimit { .. } => Some(RateLimitState::PrimaryRateLimit),
            RateLimitError::SecondaryRateLimitRetryAfter { .. }
            | RateLimitError::SecondaryRateLimitReset { .. }
            | RateLimitError::SecondaryRateLimitWait { .. }
            | RateLimitError::SecondaryRateLimitExponentialBackoff { .. }
            | RateLimitError::SecondaryRateLimitExhausted { .. } => {
                let retries = match *state {
                    Some(RateLimitState::SecondaryRateLimit { retries }) => {
                        retries.saturating_add(1)
                    }
                    _ => 1,
                };
                Some(RateLimitState::SecondaryRateLimit { retries })
            }
            RateLimitError::Transport(_) => *state,
        };
        *state = new_state;
    }
}
impl Default for RateLimitResponseMiddleware {
fn default() -> Self {
Self::new(Self::DEFAULT_MAX_SECONDARY_RETRIES)
}
}
impl From<reqwest::Error> for RateLimitError {
fn from(value: reqwest::Error) -> Self {
Self::Transport(value)
}
}
impl ResponseMiddleware for RateLimitResponseMiddleware {
type Error = RateLimitError;
fn on_response(
&self,
response: reqwest::Result<reqwest::Response>,
) -> Result<reqwest::Response, Self::Error> {
let response = response.map_err(RateLimitError::Transport)?;
let status = response.status();
if status != reqwest::StatusCode::FORBIDDEN
&& status != reqwest::StatusCode::TOO_MANY_REQUESTS
{
self.apply(None);
return Ok(response);
}
let headers = response.headers();
let remaining = Self::parse_header_u64(headers, "x-ratelimit-remaining");
let reset = Self::parse_header_u64(headers, "x-ratelimit-reset");
let retry_after = Self::parse_header_u64(headers, "retry-after");
let mut state = self.state.lock().unwrap();
let prev_secondary_retries = match *state {
Some(RateLimitState::SecondaryRateLimit { retries }) => Some(retries),
_ => None,
};
let is_secondary = if prev_secondary_retries.is_some() || retry_after.is_some() {
true
} else {
remaining != Some(0)
};
let error = if !is_secondary {
RateLimitError::PrimaryRateLimit {
x_ratelimit_reset: reset.unwrap_or(0),
}
} else {
let attempt = prev_secondary_retries.unwrap_or(0).saturating_add(1);
if attempt > self.max_secondary_retries {
RateLimitError::SecondaryRateLimitExhausted { retries: attempt }
} else if attempt > 1 {
let retry_after = Self::secondary_backoff_seconds(attempt);
RateLimitError::SecondaryRateLimitExponentialBackoff {
retry_after,
attempt,
}
} else if let Some(retry_after) = retry_after {
RateLimitError::SecondaryRateLimitRetryAfter { retry_after }
} else if remaining == Some(0) {
RateLimitError::SecondaryRateLimitReset {
x_ratelimit_reset: reset.unwrap_or(0),
}
} else {
RateLimitError::SecondaryRateLimitWait {
retry_after: Self::MIN_SECONDARY_WAIT_SECS,
}
}
};
*state = Some(match error {
RateLimitError::PrimaryRateLimit { .. } => RateLimitState::PrimaryRateLimit,
RateLimitError::SecondaryRateLimitRetryAfter { .. }
| RateLimitError::SecondaryRateLimitReset { .. }
| RateLimitError::SecondaryRateLimitWait { .. }
| RateLimitError::SecondaryRateLimitExponentialBackoff { .. }
| RateLimitError::SecondaryRateLimitExhausted { .. } => {
let retries = prev_secondary_retries.unwrap_or(0).saturating_add(1);
RateLimitState::SecondaryRateLimit { retries }
}
RateLimitError::Transport(_) => return Err(error),
});
match error {
RateLimitError::PrimaryRateLimit { x_ratelimit_reset }
| RateLimitError::SecondaryRateLimitReset { x_ratelimit_reset } => {
Self::sleep_until_epoch_seconds(x_ratelimit_reset);
}
_ => {}
}
Err(error)
}
}
/// Example entry point: sends one request through each integration style offered by
/// `reqwest_rate_limit`, using the same [`RateLimitResponseMiddleware`] instance for both.
///
/// Failures, including the rate-limit errors the middleware exists to produce, are reported
/// on stderr and reflected in the exit code instead of panicking.
#[tokio::main]
async fn main() -> std::process::ExitCode {
    // Client-side budget of 5,000 requests per hour for each limiter below.
    let per_hour = NonZeroU32::new(5_000).expect("5_000 is non-zero");
    let middleware = RateLimitResponseMiddleware::default();
    let mut failed = false;

    {
        // Per-call wiring: the limiter and middleware are supplied for this one request.
        let rate_limiter = governor::RateLimiter::direct(Quota::per_hour(per_hour));
        let reqwest_client = reqwest::Client::builder()
            .user_agent("reqwest-rate-limit-example")
            .build()
            .expect("failed to build reqwest client");
        let request = reqwest_client.get("https://api.github.com/rate_limit");
        let result = reqwest_rate_limit::send_with_rate_limiter_and_middleware(
            request,
            &rate_limiter,
            &middleware,
        )
        .await;
        if let Err(err) = result {
            eprintln!("GET https://api.github.com/rate_limit failed: {err:?}");
            failed = true;
        }
    }
    {
        // Client wiring: the middleware and limiter are installed once on the client.
        let rate_limiter = governor::RateLimiter::direct(Quota::per_hour(per_hour));
        let client = reqwest_rate_limit::Client::builder()
            .user_agent("reqwest-rate-limit-example")
            .response_middleware(middleware)
            .rate_limiter(Arc::new(rate_limiter))
            .build()
            .expect("failed to build rate-limited client");
        if let Err(err) = client.get("https://github.com").send().await {
            eprintln!("GET https://github.com failed: {err:?}");
            failed = true;
        }
    }

    if failed {
        std::process::ExitCode::FAILURE
    } else {
        std::process::ExitCode::SUCCESS
    }
}