use std::collections::HashMap;
use std::fmt;
use std::str::{self, FromStr};
use std::sync::Arc;
use std::time::SystemTime;
use reqwest::header::HeaderMap;
use reqwest::{Client, Response, StatusCode};
use secrecy::{ExposeSecret, SecretString};
use tokio::sync::{Mutex, RwLock};
use tokio::time::{sleep, Duration};
use tracing::{debug, instrument};
pub use super::routing::RatelimitingBucket;
use super::{HttpError, LightMethod, Request};
use crate::internal::prelude::*;
/// Details about a ratelimit that caused a request to wait, handed to the callback registered
/// through `Ratelimiter::set_ratelimit_callback`.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct RatelimitInfo {
/// How long the request sleeps before it is sent (pre-emptive wait) or retried.
pub timeout: std::time::Duration,
/// The bucket's request limit as currently tracked; a fixed value of 50 for global ratelimits.
pub limit: i64,
/// The method of the request that is being delayed.
pub method: LightMethod,
/// The route path of the request that is being delayed.
pub path: String,
/// `true` when the wait was caused by the global ratelimit rather than a per-bucket one.
pub global: bool,
}
/// Sends HTTP requests while tracking ratelimits, both the global one and one [`Ratelimit`] per
/// [`RatelimitingBucket`].
pub struct Ratelimiter {
/// Client used to execute the built requests.
client: Client,
/// Global ratelimit lock: `perform` acquires it before every attempt, and again when a response
/// reports a global ratelimit.
global: Arc<Mutex<()>>,
/// Per-bucket ratelimit state, inserted on demand (`or_default`) by `perform`.
routes: Arc<RwLock<HashMap<RatelimitingBucket, Arc<Mutex<Ratelimit>>>>>,
/// Token passed to request building; wrapped in `SecretString`, presumably to keep it out of
/// debug output.
token: SecretString,
/// When `true`, reset times come from `x-ratelimit-reset` (absolute) instead of
/// `x-ratelimit-reset-after` (relative to the local clock).
absolute_ratelimits: bool,
/// Invoked with a [`RatelimitInfo`] before every ratelimit-induced sleep.
ratelimit_callback: Box<dyn Fn(RatelimitInfo) + Send + Sync>,
}
impl fmt::Debug for Ratelimiter {
    /// Formats every field; the callback is not `Debug`, so a fixed placeholder is shown for it.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Exhaustive destructuring: adding a field to `Ratelimiter` forces this impl to be updated.
        let Self {
            client,
            global,
            routes,
            token,
            absolute_ratelimits,
            ratelimit_callback: _,
        } = self;

        let mut builder = f.debug_struct("Ratelimiter");
        builder.field("client", client);
        builder.field("global", global);
        builder.field("routes", routes);
        builder.field("token", token);
        builder.field("absolute_ratelimits", absolute_ratelimits);
        builder.field("ratelimit_callback", &"Fn(RatelimitInfo)");
        builder.finish()
    }
}
impl Ratelimiter {
    /// Creates a ratelimiter that executes requests with `client` and authenticates them with
    /// `token`.
    ///
    /// The ratelimit callback starts out as a no-op and absolute ratelimits are disabled; see
    /// [`Self::set_ratelimit_callback`] and [`Self::set_absolute_ratelimits`].
    #[must_use]
    pub fn new(client: Client, token: impl Into<String>) -> Self {
        Self::new_(client, token.into())
    }

    // Non-generic inner constructor, so its body is not monomorphized per token type.
    fn new_(client: Client, token: String) -> Self {
        Self {
            client,
            global: Arc::default(),
            routes: Arc::default(),
            token: SecretString::new(token),
            ratelimit_callback: Box::new(|_| {}),
            absolute_ratelimits: false,
        }
    }

    /// Replaces the callback that is invoked with a [`RatelimitInfo`] whenever a request is about
    /// to sleep because of a ratelimit.
    pub fn set_ratelimit_callback(
        &mut self,
        ratelimit_callback: Box<dyn Fn(RatelimitInfo) + Send + Sync>,
    ) {
        self.ratelimit_callback = ratelimit_callback;
    }

    /// Chooses how bucket reset times are computed.
    ///
    /// When `true`, the reset time is read from the `x-ratelimit-reset` header as seconds since the
    /// Unix epoch. When `false` (the default), it is the local current time plus the
    /// `x-ratelimit-reset-after` header.
    pub fn set_absolute_ratelimits(&mut self, absolute_ratelimits: bool) {
        self.absolute_ratelimits = absolute_ratelimits;
    }

    /// Returns a shared handle to the per-bucket ratelimit state.
    #[must_use]
    pub fn routes(&self) -> Arc<RwLock<HashMap<RatelimitingBucket, Arc<Mutex<Ratelimit>>>>> {
        Arc::clone(&self.routes)
    }

    /// Sends `req`, waiting out ratelimits and retrying until a final response is obtained.
    ///
    /// Each attempt first waits for any global ratelimit that another request is sleeping
    /// through, then runs the bucket's [`Ratelimit::pre_hook`] (which may sleep pre-emptively).
    /// After the response arrives:
    /// - requests without a ratelimiting bucket return immediately;
    /// - if `x-ratelimit-global` is present, the global lock is held while sleeping for
    ///   `retry-after` seconds, then the request is retried;
    /// - otherwise [`Ratelimit::post_hook`] updates the bucket and decides whether to retry. A
    ///   header it cannot parse also causes a retry.
    ///
    /// A `retry-after` value that is not a valid duration (negative, NaN, infinite or too large)
    /// is treated like a missing header, so the response is returned as-is.
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be built or executed, or if the `retry-after` header
    /// of a global-ratelimit response is not valid UTF-8 or not a number.
    #[instrument]
    pub async fn perform(&self, req: Request<'_>) -> Result<Response> {
        loop {
            // Barrier: waits while another request holds the global lock during a global
            // ratelimit sleep (see below), then releases it right away.
            drop(self.global.lock().await);

            let ratelimiting_bucket = req.route.ratelimiting_bucket();
            let bucket =
                Arc::clone(self.routes.write().await.entry(ratelimiting_bucket).or_default());

            // The bucket guard is a temporary, so it is held through `pre_hook` (including any
            // pre-emptive sleep) and released before the request is sent.
            bucket.lock().await.pre_hook(&req, &self.ratelimit_callback).await;

            let request = req.clone().build(&self.client, self.token.expose_secret(), None)?;
            let response = self.client.execute(request.build()?).await?;

            if ratelimiting_bucket.is_none() {
                return Ok(response);
            }

            let redo = if response.headers().get("x-ratelimit-global").is_some() {
                // Hold the global lock for the entire sleep so that every other request waits at
                // the barrier above instead of running into the global ratelimit again. Dropping
                // the guard immediately would make the lock ineffective.
                let _global_guard = self.global.lock().await;

                // `Duration::from_secs_f64` panics on negative, NaN, infinite or overflowing
                // input, so such values are treated like a missing header.
                let wait = parse_header::<f64>(response.headers(), "retry-after")?
                    .and_then(|secs| Some((secs, Duration::try_from_secs_f64(secs).ok()?)));

                Ok(if let Some((retry_after, delay)) = wait {
                    debug!(
                        "Ratelimited on route {:?} for {:?}s",
                        ratelimiting_bucket, retry_after
                    );
                    (self.ratelimit_callback)(RatelimitInfo {
                        timeout: delay,
                        // NOTE(review): hard-coded; presumably Discord's global limit of 50
                        // requests per second.
                        limit: 50,
                        method: req.method,
                        path: req.route.path().to_string(),
                        global: true,
                    });
                    sleep(delay).await;
                    true
                } else {
                    false
                })
                // `_global_guard` is released here, before the next iteration locks again.
            } else {
                bucket
                    .lock()
                    .await
                    .post_hook(&response, &req, &self.ratelimit_callback, self.absolute_ratelimits)
                    .await
            };

            // A header that `post_hook` failed to parse is treated as a reason to retry.
            if !redo.unwrap_or(true) {
                return Ok(response);
            }
        }
    }
}
/// Ratelimit state of a single bucket: refreshed from response headers in `post_hook` and
/// consumed locally in `pre_hook`.
#[derive(Debug)]
pub struct Ratelimit {
/// Request limit, taken from `x-ratelimit-limit`.
limit: i64,
/// Requests left before the reset; taken from `x-ratelimit-remaining` and decremented by
/// `pre_hook`.
remaining: i64,
/// When the bucket resets, if known.
reset: Option<SystemTime>,
/// The last `x-ratelimit-reset-after` value received, if any.
reset_after: Option<Duration>,
}
impl Ratelimit {
    /// Updates the bucket before a request is sent, sleeping first if the bucket is exhausted.
    ///
    /// - If `limit` is 0, nothing is done.
    /// - If no reset time is known yet, `remaining` is refilled to `limit`.
    /// - If the reset time has already passed, `remaining` is decremented unless it is already 0.
    /// - If `remaining` is 0 and the reset time is still ahead, `ratelimit_callback` is invoked and
    ///   the task sleeps until the reset time; `remaining` is left unchanged.
    /// - Otherwise one request is consumed from `remaining`.
    #[instrument(skip(ratelimit_callback))]
    pub async fn pre_hook(
        &mut self,
        req: &Request<'_>,
        ratelimit_callback: &(dyn Fn(RatelimitInfo) + Send + Sync),
    ) {
        if self.limit() == 0 {
            return;
        }

        let Some(reset) = self.reset else {
            self.remaining = self.limit;
            return;
        };

        // `duration_since` fails when `reset` lies in the past.
        let Ok(delay) = reset.duration_since(SystemTime::now()) else {
            if self.remaining() != 0 {
                self.remaining -= 1;
            }
            return;
        };

        if self.remaining() == 0 {
            debug!(
                "Pre-emptive ratelimit on route {:?} for {}ms",
                req.route.ratelimiting_bucket(),
                delay.as_millis(),
            );
            ratelimit_callback(RatelimitInfo {
                timeout: delay,
                limit: self.limit,
                method: req.method,
                path: req.route.path().to_string(),
                global: false,
            });
            sleep(delay).await;
            return;
        }

        self.remaining -= 1;
    }

    /// Updates the bucket from a response's ratelimit headers and reports whether the request
    /// should be retried.
    ///
    /// `x-ratelimit-limit` and `x-ratelimit-remaining` overwrite `limit` and `remaining`. The
    /// reset time is taken from `x-ratelimit-reset` (seconds since the Unix epoch) when
    /// `absolute_ratelimits` is set, and otherwise from the current time plus
    /// `x-ratelimit-reset-after`. Time values that cannot be represented (negative, NaN, infinite
    /// or overflowing) are ignored instead of panicking.
    ///
    /// Returns `Ok(true)` only for a `429 Too Many Requests` response that carries a valid
    /// `retry-after` header. In that case `ratelimit_callback` is invoked and the task sleeps for
    /// that many seconds before returning.
    ///
    /// # Errors
    ///
    /// Returns an error if one of the headers above is present but is not valid UTF-8 or does not
    /// parse as the expected number type.
    #[instrument(skip(ratelimit_callback))]
    pub async fn post_hook(
        &mut self,
        response: &Response,
        req: &Request<'_>,
        ratelimit_callback: &(dyn Fn(RatelimitInfo) + Send + Sync),
        absolute_ratelimits: bool,
    ) -> Result<bool> {
        let headers = response.headers();

        if let Some(limit) = parse_header(headers, "x-ratelimit-limit")? {
            self.limit = limit;
        }
        if let Some(remaining) = parse_header(headers, "x-ratelimit-remaining")? {
            self.remaining = remaining;
        }

        if absolute_ratelimits {
            if let Some(reset) = parse_header::<f64>(headers, "x-ratelimit-reset")? {
                // `Duration::from_secs_f64` and `SystemTime + Duration` panic on invalid or
                // overflowing input; the fallible variants let such values be skipped.
                if let Some(reset) = Duration::try_from_secs_f64(reset)
                    .ok()
                    .and_then(|since_epoch| std::time::UNIX_EPOCH.checked_add(since_epoch))
                {
                    self.reset = Some(reset);
                }
            }
        }

        if let Some(reset_after) = parse_header::<f64>(headers, "x-ratelimit-reset-after")? {
            if let Ok(reset_after) = Duration::try_from_secs_f64(reset_after) {
                if !absolute_ratelimits {
                    if let Some(reset) = SystemTime::now().checked_add(reset_after) {
                        self.reset = Some(reset);
                    }
                }
                self.reset_after = Some(reset_after);
            }
        }

        if response.status() != StatusCode::TOO_MANY_REQUESTS {
            return Ok(false);
        }

        let Some(retry_after) = parse_header::<f64>(headers, "retry-after")? else {
            return Ok(false);
        };
        // An unrepresentable `retry-after` is treated like a missing one rather than panicking.
        let Ok(delay) = Duration::try_from_secs_f64(retry_after) else {
            return Ok(false);
        };

        debug!(
            "Ratelimited on route {:?} for {:?}s",
            req.route.ratelimiting_bucket(),
            retry_after
        );
        ratelimit_callback(RatelimitInfo {
            timeout: delay,
            limit: self.limit,
            method: req.method,
            path: req.route.path().to_string(),
            global: false,
        });
        sleep(delay).await;
        Ok(true)
    }

    /// The bucket's request limit (`i64::MAX` for a default-constructed bucket).
    #[inline]
    #[must_use]
    pub const fn limit(&self) -> i64 {
        self.limit
    }

    /// Requests left in the bucket before its reset.
    #[inline]
    #[must_use]
    pub const fn remaining(&self) -> i64 {
        self.remaining
    }

    /// When the bucket resets, if known.
    #[inline]
    #[must_use]
    pub const fn reset(&self) -> Option<SystemTime> {
        self.reset
    }

    /// The last `x-ratelimit-reset-after` duration received, if any.
    #[inline]
    #[must_use]
    pub const fn reset_after(&self) -> Option<Duration> {
        self.reset_after
    }
}
impl Default for Ratelimit {
    /// A bucket with no information yet: `limit` and `remaining` start at `i64::MAX` and neither
    /// a reset time nor a reset-after duration is known.
    fn default() -> Self {
        let unlimited = i64::MAX;
        Self {
            reset: None,
            reset_after: None,
            limit: unlimited,
            remaining: unlimited,
        }
    }
}
fn parse_header<T: FromStr>(headers: &HeaderMap, header: &str) -> Result<Option<T>> {
let Some(header) = headers.get(header) else { return Ok(None) };
let unicode =
str::from_utf8(header.as_bytes()).map_err(|_| Error::from(HttpError::RateLimitUtf8))?;
let num = unicode.parse().map_err(|_| Error::from(HttpError::RateLimitI64F64))?;
Ok(Some(num))
}
#[cfg(test)]
mod tests {
    use std::error::Error as StdError;
    use std::result::Result as StdResult;

    use reqwest::header::{HeaderMap, HeaderName, HeaderValue};

    use super::parse_header;
    use crate::error::Error;
    use crate::http::HttpError;

    type Result<T> = StdResult<T, Box<dyn StdError>>;

    /// Builds a header map with well-formed ratelimit headers plus two malformed entries
    /// (`x-bad-num`, `x-bad-unicode`) used by the error-path test.
    fn headers() -> HeaderMap {
        let pairs = &[
            (HeaderName::from_static("x-ratelimit-limit"), HeaderValue::from_static("5")),
            (HeaderName::from_static("x-ratelimit-remaining"), HeaderValue::from_static("4")),
            (
                HeaderName::from_static("x-ratelimit-reset"),
                HeaderValue::from_static("1560704880.423"),
            ),
            (HeaderName::from_static("x-bad-num"), HeaderValue::from_static("abc")),
            (
                HeaderName::from_static("x-bad-unicode"),
                // Accepted as a header value, but not valid UTF-8.
                HeaderValue::from_bytes(&[255, 255, 255, 255])
                    .expect("0xFF bytes are allowed in header values"),
            ),
        ];

        let mut map = HeaderMap::with_capacity(pairs.len());
        for (name, val) in pairs {
            map.insert(name, val.clone());
        }
        map
    }

    #[test]
    // Exact float comparison is intended: the expected value is the same literal the header holds.
    #[allow(clippy::float_cmp)]
    fn test_parse_header_good() -> Result<()> {
        let headers = headers();
        assert_eq!(parse_header::<i64>(&headers, "x-ratelimit-limit")?.unwrap(), 5);
        assert_eq!(parse_header::<i64>(&headers, "x-ratelimit-remaining")?.unwrap(), 4);
        assert_eq!(parse_header::<f64>(&headers, "x-ratelimit-reset")?.unwrap(), 1_560_704_880.423);
        Ok(())
    }

    #[test]
    fn test_parse_header_missing() -> Result<()> {
        let headers = headers();
        assert!(parse_header::<i64>(&headers, "x-ratelimit-reset-after")?.is_none());
        assert!(parse_header::<f64>(&headers, "retry-after")?.is_none());
        Ok(())
    }

    #[test]
    fn test_parse_header_errors() {
        let headers = headers();
        assert!(matches!(
            parse_header::<i64>(&headers, "x-bad-num").unwrap_err(),
            Error::Http(HttpError::RateLimitI64F64)
        ));
        assert!(matches!(
            parse_header::<i64>(&headers, "x-bad-unicode").unwrap_err(),
            Error::Http(HttpError::RateLimitUtf8)
        ));
    }
}