use std::{
collections::{HashMap, VecDeque},
sync::Arc,
time::{Duration, Instant},
};
use reqwest::{header, Client, Method, Response};
use serde::{de::DeserializeOwned, Serialize};
use tokio::sync::RwLock;
use tracing::{debug, error, warn};
use crate::error::{DiscordError, Result};
/// Base URL prepended to relative endpoint paths.
const API_BASE: &str = "https://discord.com/api/v10";
/// Number of attempts `do_request_with_referer` and `post_multipart` make per call.
const DEFAULT_RETRY_COUNT: u32 = 5;
/// A 429 asking for a longer wait than this (in seconds) is returned to the caller as
/// `DiscordError::RateLimited` instead of being slept through (see `do_request_with_referer`).
const MAX_RETRY_AFTER_SECONDS: u64 = 30;
/// Sliding window over which `InvalidRequestTracker` counts recorded 401/403/429 responses.
const INVALID_REQUEST_WINDOW: Duration = Duration::from_secs(10 * 60);
/// Count within the window at which `InvalidRequestTracker::record` starts logging warnings.
const INVALID_REQUEST_WARN_THRESHOLD: usize = 7000;
/// Count within the window at which `InvalidRequestTracker::record` logs an error (its message
/// warns of a Cloudflare 1015 IP ban). NOTE(review): presumably chosen to sit below Discord's
/// documented invalid-request limit — confirm against the current Discord docs.
const INVALID_REQUEST_ERROR_THRESHOLD: usize = 9500;
/// Sliding-window counter of invalid responses; `DiscordHttpClient` records its 401, 403 and
/// 429 responses here.
///
/// Clones share the same window because the state lives behind an `Arc`.
#[derive(Clone, Default)]
pub struct InvalidRequestTracker {
    /// Timestamps of recorded responses, oldest at the front.
    inner: Arc<RwLock<VecDeque<Instant>>>,
}

impl InvalidRequestTracker {
    /// Records one invalid response and returns how many were recorded within the last
    /// `INVALID_REQUEST_WINDOW`, including this one.
    ///
    /// Logs a warning once the count reaches `INVALID_REQUEST_WARN_THRESHOLD` and an error
    /// once it reaches `INVALID_REQUEST_ERROR_THRESHOLD`.
    pub async fn record(&self) -> usize {
        let count = {
            let mut timestamps = self.inner.write().await;
            // Read the clock only after the lock is held so pushes stay in chronological
            // order; front-only pruning would otherwise leave an expired entry stuck behind
            // a newer one pushed by a concurrent recorder.
            let now = Instant::now();
            Self::prune_expired(&mut timestamps, now);
            timestamps.push_back(now);
            timestamps.len()
        };
        // Log after releasing the write lock so a slow subscriber doesn't stall other callers.
        if count >= INVALID_REQUEST_ERROR_THRESHOLD {
            tracing::error!(
                invalid_request_count = count,
                "invalid-request budget critical: Cloudflare 1015 IP ban imminent (>= {})",
                INVALID_REQUEST_ERROR_THRESHOLD
            );
        } else if count >= INVALID_REQUEST_WARN_THRESHOLD {
            tracing::warn!(
                invalid_request_count = count,
                "invalid-request budget elevated (>= {})",
                INVALID_REQUEST_WARN_THRESHOLD
            );
        }
        count
    }

    /// Returns how many invalid responses fall within the last `INVALID_REQUEST_WINDOW`,
    /// discarding expired entries as a side effect (hence the write lock).
    pub async fn count(&self) -> usize {
        let mut timestamps = self.inner.write().await;
        Self::prune_expired(&mut timestamps, Instant::now());
        timestamps.len()
    }

    /// Pops entries older than `INVALID_REQUEST_WINDOW` (relative to `now`) off the front of
    /// the queue, relying on `record` keeping it in chronological order.
    fn prune_expired(timestamps: &mut VecDeque<Instant>, now: Instant) {
        while let Some(&oldest) = timestamps.front() {
            if now.duration_since(oldest) <= INVALID_REQUEST_WINDOW {
                break;
            }
            timestamps.pop_front();
        }
    }
}
/// Details handed to the rate-limit callback (`Ratelimit::callback`) right before the client
/// sleeps because of a rate limit.
#[derive(Debug, Clone)]
pub struct RatelimitInfo {
/// How long the client is about to wait.
pub timeout: std::time::Duration,
/// Bucket request limit (`x-ratelimit-limit`), or 0 when it is not known.
pub limit: u64,
/// HTTP method of the delayed request.
pub method: reqwest::Method,
/// Endpoint of the delayed request, exactly as passed to the client.
pub path: String,
/// `true` when the wait is for a global (not per-route) rate limit.
pub global: bool,
}
/// Parsed rate-limit headers of a 429 response, produced by
/// `DiscordHttpClient::extract_rate_limit_info`. Internal; not to be confused with the public
/// callback payload `RatelimitInfo`.
struct RateLimitInfo {
/// Seconds to wait, from `retry-after` (falls back to 5.0 when missing or unparsable).
retry_after: f64,
/// `x-ratelimit-bucket` value, if present.
bucket: Option<String>,
/// Whether `x-ratelimit-global` was exactly `"true"`.
global: bool,
/// `x-ratelimit-scope` value, if present.
scope: Option<String>,
}
/// Cached state of one route's rate-limit bucket, keyed by `Ratelimit::get_route_key`.
#[derive(Debug, Clone)]
pub struct RatelimitingBucket {
/// Requests left before the bucket resets (`x-ratelimit-remaining`).
pub remaining: u64,
/// Bucket capacity (`x-ratelimit-limit`).
pub limit: u64,
/// Reset time in Unix epoch seconds; `Ratelimit::pre_hook` compares it with the local clock.
pub reset_at: f64,
/// Seconds until reset as reported by `x-ratelimit-reset-after`, if present.
pub reset_after: Option<f64>,
}
/// Client-side rate limiter. All fields are behind `Arc`s, so clones share the same state.
#[derive(Clone, Default)]
pub struct Ratelimit {
/// Per-route bucket state, keyed by `Ratelimit::get_route_key`.
pub buckets: std::sync::Arc<dashmap::DashMap<String, RatelimitingBucket>>,
/// Optional hook invoked right before the client sleeps because of a rate limit.
pub callback: Option<std::sync::Arc<dyn Fn(RatelimitInfo) + Send + Sync>>,
/// Held for the duration of a global-429 sleep; other requests lock and immediately release
/// it before sending, so they queue behind the global limit.
pub global: std::sync::Arc<tokio::sync::Mutex<()>>,
}
impl Ratelimit {
pub fn get_route_key(method: &Method, endpoint: &str) -> String {
let path = if let Some(stripped) = endpoint.split("/api/v10/").nth(1) { stripped } else { endpoint.trim_start_matches('/') };
let path = path.split('?').next().unwrap_or(path);
let parts: Vec<&str> = path.split('/').collect();
let mut route = String::from(method.as_str());
if parts.is_empty() {
return route;
}
let mut iter = parts.iter().peekable();
while let Some(&part) = iter.next() {
route.push('/');
match part {
"channels" | "guilds" | "webhooks" => {
route.push_str(part);
if let Some(&id) = iter.next() {
route.push('/');
route.push_str(id);
}
}
_ => {
if part.chars().all(|c| c.is_ascii_digit()) && part.len() > 10 {
route.push_str("{id}");
} else {
route.push_str(part);
}
}
}
}
route
}
pub async fn pre_hook(&self, method: &Method, endpoint: &str, route_key: &str) {
let wait_info = {
if let Some(bucket) = self.buckets.get(route_key) {
if bucket.remaining == 0 {
let now = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap_or_default().as_secs_f64();
if bucket.reset_at > now {
Some((bucket.reset_at - now, bucket.limit))
} else {
None
}
} else {
None
}
} else {
None
}
};
if let Some((secs, limit)) = wait_info {
if secs > 0.0 {
tracing::debug!("Preemptive rate limit hit for {}, waiting {:.3}s", route_key, secs);
if let Some(cb) = &self.callback {
cb(RatelimitInfo {
timeout: std::time::Duration::from_secs_f64(secs),
limit,
method: method.clone(),
path: endpoint.to_string(),
global: false,
});
}
tokio::time::sleep(std::time::Duration::from_secs_f64(secs)).await;
}
}
}
pub fn post_hook(&self, route_key: &str, headers: &reqwest::header::HeaderMap) {
let remaining = headers.get("x-ratelimit-remaining").and_then(|h| h.to_str().ok()).and_then(|s| s.parse::<u64>().ok());
let limit = headers.get("x-ratelimit-limit").and_then(|h| h.to_str().ok()).and_then(|s| s.parse::<u64>().ok());
let reset_at = headers.get("x-ratelimit-reset").and_then(|h| h.to_str().ok()).and_then(|s| s.parse::<f64>().ok());
let reset_after = headers.get("x-ratelimit-reset-after").and_then(|h| h.to_str().ok()).and_then(|s| s.parse::<f64>().ok());
if let (Some(remaining), Some(limit), Some(reset_at)) = (remaining, limit, reset_at) {
self.buckets.insert(route_key.to_string(), RatelimitingBucket { remaining, limit, reset_at, reset_after });
}
}
pub fn record_429(&self, route_key: &str, headers: &reqwest::header::HeaderMap, fallback_retry_after: f64) {
let remaining = headers.get("x-ratelimit-remaining").and_then(|h| h.to_str().ok()).and_then(|s| s.parse::<u64>().ok()).unwrap_or(0);
let limit = headers.get("x-ratelimit-limit").and_then(|h| h.to_str().ok()).and_then(|s| s.parse::<u64>().ok()).unwrap_or_else(|| self.buckets.get(route_key).map(|b| b.limit).unwrap_or(0));
let reset_after = headers.get("x-ratelimit-reset-after").and_then(|h| h.to_str().ok()).and_then(|s| s.parse::<f64>().ok());
let now = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap_or_default().as_secs_f64();
let reset_at = headers.get("x-ratelimit-reset").and_then(|h| h.to_str().ok()).and_then(|s| s.parse::<f64>().ok()).unwrap_or_else(|| now + reset_after.unwrap_or(fallback_retry_after));
self.buckets.insert(route_key.to_string(), RatelimitingBucket { remaining, limit, reset_at, reset_after });
}
}
/// HTTP client for the Discord REST API with client-side rate limiting, retries and
/// invalid-request tracking. Clones share the rate-limit state and the invalid-request window
/// (both are `Arc`-backed).
#[derive(Clone)]
pub struct DiscordHttpClient {
/// Underlying reqwest client (connection pool, proxy, user agent).
client: Client,
/// Sent verbatim as the `Authorization` header.
token: String,
/// Extra headers added to every request; any `authorization` entry is skipped.
custom_headers: HashMap<String, String>,
/// Shared bucket state, global lock and callback.
ratelimit: Ratelimit,
/// When `true`, bucket bookkeeping is skipped and `do_request_with_referer` returns 429s as
/// `DiscordError::RateLimited` instead of retrying.
ratelimiter_disabled: bool,
/// Value for the `X-Super-Properties` header, if set.
super_properties_b64: Option<String>,
/// Value for the `X-Discord-Locale` header (the constructors set `en-US`).
discord_locale: Option<String>,
/// Value for the `X-Discord-Timezone` header, if set.
discord_timezone: Option<String>,
/// Window of recent 401/403/429 responses.
invalid_requests: InvalidRequestTracker,
}
impl DiscordHttpClient {
pub fn new(token: impl Into<String>, proxy: Option<String>, ratelimiter_disabled: bool) -> Self {
let mut builder = Client::builder().timeout(Duration::from_secs(30)).pool_max_idle_per_host(10).pool_idle_timeout(Duration::from_secs(90)).tcp_keepalive(Duration::from_secs(60)).tcp_nodelay(true).user_agent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36").http1_only();
if let Some(proxy_url) = proxy {
if let Ok(p) = reqwest::Proxy::all(&proxy_url) {
builder = builder.proxy(p);
}
}
let client = builder.build().unwrap_or_else(|e| panic!("Failed to create HTTP client: {}", e));
Self {
client,
token: token.into(),
custom_headers: HashMap::new(),
ratelimit: Ratelimit::default(),
ratelimiter_disabled,
super_properties_b64: None,
discord_locale: Some("en-US".to_string()),
discord_timezone: None,
invalid_requests: InvalidRequestTracker::default(),
}
}
pub fn with_headers(headers: HashMap<String, String>, proxy: Option<String>, ratelimiter_disabled: bool) -> Option<Self> {
let token = headers.iter().find(|(k, _)| k.to_lowercase() == "authorization").map(|(_, v)| v.clone())?;
let mut builder = Client::builder().timeout(Duration::from_secs(30)).pool_max_idle_per_host(10).pool_idle_timeout(Duration::from_secs(90)).tcp_keepalive(Duration::from_secs(60)).tcp_nodelay(true).user_agent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36").http1_only();
if let Some(proxy_url) = proxy {
if let Ok(p) = reqwest::Proxy::all(&proxy_url) {
builder = builder.proxy(p);
}
}
let client = builder.build().unwrap_or_else(|e| panic!("Failed to create HTTP client: {}", e));
Some(Self {
client,
token,
custom_headers: headers,
ratelimit: Ratelimit::default(),
ratelimiter_disabled,
super_properties_b64: None,
discord_locale: Some("en-US".to_string()),
discord_timezone: None,
invalid_requests: InvalidRequestTracker::default(),
})
}
pub fn set_super_properties_b64(&mut self, value: Option<String>) {
self.super_properties_b64 = value;
}
pub fn set_discord_locale(&mut self, locale: Option<String>) {
self.discord_locale = locale;
}
pub fn set_discord_timezone(&mut self, tz: Option<String>) {
self.discord_timezone = tz;
}
pub async fn invalid_request_count(&self) -> usize {
self.invalid_requests.count().await
}
pub fn token(&self) -> &str {
&self.token
}
pub fn set_ratelimit_callback(&mut self, callback: std::sync::Arc<dyn Fn(RatelimitInfo) + Send + Sync>) {
self.ratelimit.callback = Some(callback);
}
pub async fn get<T: DeserializeOwned>(&self, route: crate::route::Route<'_>) -> Result<T> {
self.request(Method::GET, route, None::<()>).await
}
pub async fn post<T: DeserializeOwned, B: Serialize>(&self, route: crate::route::Route<'_>, body: B) -> Result<T> {
self.request(Method::POST, route, Some(body)).await
}
pub async fn patch<T: DeserializeOwned, B: Serialize>(&self, route: crate::route::Route<'_>, body: B) -> Result<T> {
self.request(Method::PATCH, route, Some(body)).await
}
pub async fn put<T: DeserializeOwned, B: Serialize>(&self, route: crate::route::Route<'_>, body: B) -> Result<T> {
self.request(Method::PUT, route, Some(body)).await
}
pub async fn delete(&self, route: crate::route::Route<'_>) -> Result<()> {
self.request_no_response(Method::DELETE, route, None::<()>).await
}
pub async fn delete_with_response<T: DeserializeOwned>(&self, route: crate::route::Route<'_>) -> Result<T> {
self.request(Method::DELETE, route, None::<()>).await
}
pub async fn post_empty(&self, route: crate::route::Route<'_>) -> Result<()> {
self.request_no_response(Method::POST, route, None::<()>).await
}
pub async fn post_no_response<B: Serialize>(&self, route: crate::route::Route<'_>, body: B) -> Result<()> {
self.request_no_response(Method::POST, route, Some(body)).await
}
pub async fn patch_no_response<B: Serialize>(&self, route: crate::route::Route<'_>, body: B) -> Result<()> {
self.request_no_response(Method::PATCH, route, Some(body)).await
}
pub async fn post_with_referer<T: DeserializeOwned, B: Serialize>(&self, route: crate::route::Route<'_>, body: B, referer: &str) -> Result<T> {
self.request_with_referer(Method::POST, route, Some(body), Some(referer)).await
}
pub async fn post_multipart<T: DeserializeOwned>(&self, route: crate::route::Route<'_>, payload_json: serde_json::Value, attachments: Vec<crate::types::CreateAttachment>) -> Result<T> {
use reqwest::multipart::{Form, Part};
let path_cow = route.path();
let endpoint = path_cow.as_ref();
let url = if endpoint.starts_with("http") { endpoint.to_string() } else { format!("{}/{}", API_BASE, endpoint.trim_start_matches('/')) };
let route_key = Ratelimit::get_route_key(&Method::POST, endpoint);
for attempt in 0..DEFAULT_RETRY_COUNT {
if !self.ratelimiter_disabled {
drop(self.ratelimit.global.lock().await);
self.ratelimit.pre_hook(&Method::POST, endpoint, &route_key).await;
}
let mut form = Form::new().part("payload_json", Part::text(payload_json.to_string()).mime_str("application/json").unwrap_or_else(|_| Part::text(payload_json.to_string())));
for (i, att) in attachments.iter().enumerate() {
let part = Part::bytes(att.data.clone()).file_name(att.filename.clone()).mime_str(&att.mime_type).unwrap_or_else(|_| Part::bytes(att.data.clone()).file_name(att.filename.clone()));
form = form.part(format!("files[{}]", i), part);
}
let mut request = self.client.post(&url).multipart(form);
request = self.apply_common_headers(request, None);
for (key, value) in &self.custom_headers {
if key.to_lowercase() != "authorization" {
request = request.header(key.as_str(), value.as_str());
}
}
debug!("Multipart request attempt {}/{}: POST {}", attempt + 1, DEFAULT_RETRY_COUNT, url);
match request.send().await {
Ok(response) => {
let status = response.status();
if !self.ratelimiter_disabled {
self.ratelimit.post_hook(&route_key, response.headers());
}
if status.is_success() {
let text = response.text().await?;
return serde_json::from_str(&text).map_err(|e| {
error!(error = %e, body = %text, "Failed to parse multipart response");
DiscordError::Json(e)
});
}
if status.as_u16() == 429 {
self.invalid_requests.record().await;
let info = Self::extract_rate_limit_info(response.headers());
if !self.ratelimiter_disabled {
self.ratelimit.record_429(&route_key, response.headers(), info.retry_after);
}
let wait_time = if info.retry_after < 2.0 { 2.0 } else { info.retry_after };
warn!("Rate limited on multipart, waiting {:.2}s", wait_time);
tokio::time::sleep(Duration::from_secs_f64(wait_time)).await;
continue;
}
if status.as_u16() == 401 || status.as_u16() == 403 {
self.invalid_requests.record().await;
}
let error_body = response.text().await.unwrap_or_default();
return Err(if status.is_server_error() { DiscordError::ServiceError { status: status.as_u16(), body: error_body } } else { DiscordError::UnexpectedStatusCode { status: status.as_u16(), body: error_body } });
}
Err(e) => {
if attempt < DEFAULT_RETRY_COUNT - 1 {
tokio::time::sleep(Duration::from_secs(2)).await;
continue;
}
return Err(DiscordError::Http(e));
}
}
}
Err(DiscordError::MaxRetriesExceeded)
}
pub async fn post_raw_multipart<T: DeserializeOwned>(&self, path: String, form: reqwest::multipart::Form) -> Result<T> {
let url = if path.starts_with("http") { path.clone() } else { format!("{}/{}", API_BASE, path.trim_start_matches('/')) };
let route_key = Ratelimit::get_route_key(&Method::POST, &path);
if !self.ratelimiter_disabled {
drop(self.ratelimit.global.lock().await);
self.ratelimit.pre_hook(&Method::POST, &path, &route_key).await;
}
let mut request = self.client.post(&url).multipart(form);
request = self.apply_common_headers(request, None);
for (key, value) in &self.custom_headers {
if key.to_lowercase() != "authorization" {
request = request.header(key.as_str(), value.as_str());
}
}
debug!("Raw multipart POST: {}", url);
let response = request.send().await?;
let status = response.status();
if !self.ratelimiter_disabled {
self.ratelimit.post_hook(&route_key, response.headers());
}
if status.is_success() {
let text = response.text().await?;
return serde_json::from_str(&text).map_err(DiscordError::Json);
}
let code = status.as_u16();
if code == 401 || code == 403 || code == 429 {
self.invalid_requests.record().await;
if code == 429 && !self.ratelimiter_disabled {
let info = Self::extract_rate_limit_info(response.headers());
self.ratelimit.record_429(&route_key, response.headers(), info.retry_after);
}
}
let error_body = response.text().await.unwrap_or_default();
Err(if status.is_server_error() { DiscordError::ServiceError { status: status.as_u16(), body: error_body } } else { DiscordError::UnexpectedStatusCode { status: status.as_u16(), body: error_body } })
}
async fn request<T: DeserializeOwned, B: Serialize>(&self, method: Method, route: crate::route::Route<'_>, body: Option<B>) -> Result<T> {
self.request_with_referer(method, route, body, None).await
}
async fn request_with_referer<T: DeserializeOwned, B: Serialize>(&self, method: Method, route: crate::route::Route<'_>, body: Option<B>, referer: Option<&str>) -> Result<T> {
let response = self.do_request_with_referer(method, route, body, referer, None).await?;
let text = response.text().await?;
serde_json::from_str(&text).map_err(|e| {
error!(error = %e, body = %text, "Failed to parse response");
DiscordError::Json(e)
})
}
async fn request_no_response<B: Serialize>(&self, method: Method, route: crate::route::Route<'_>, body: Option<B>) -> Result<()> {
self.do_request_with_referer(method, route, body, None, None).await?;
Ok(())
}
pub async fn request_optional<T: DeserializeOwned, B: Serialize>(&self, method: Method, route: crate::route::Route<'_>, body: Option<B>) -> Result<Option<T>> {
let response = self.do_request_with_referer(method, route, body, None, None).await?;
if response.status() == reqwest::StatusCode::NO_CONTENT {
return Ok(None);
}
let text = response.text().await?;
if text.is_empty() {
return Ok(None);
}
serde_json::from_str(&text).map(Some).map_err(|e| {
error!(error = %e, body = %text, "Failed to parse response");
DiscordError::Json(e)
})
}
pub async fn put_optional<T: DeserializeOwned, B: Serialize>(&self, route: crate::route::Route<'_>, body: B) -> Result<Option<T>> {
self.request_optional(Method::PUT, route, Some(body)).await
}
pub async fn get_bytes(&self, route: crate::route::Route<'_>) -> Result<Vec<u8>> {
let response = self.do_request_with_referer(Method::GET, route, None::<()>, None, None).await?;
let bytes = response.bytes().await?;
Ok(bytes.to_vec())
}
pub async fn request_with_reason<T: DeserializeOwned, B: Serialize>(&self, method: Method, route: crate::route::Route<'_>, body: Option<B>, reason: Option<&str>) -> Result<T> {
let response = self.do_request_with_referer(method, route, body, None, reason).await?;
let text = response.text().await?;
serde_json::from_str(&text).map_err(|e| {
error!(error = %e, body = %text, "Failed to parse response");
DiscordError::Json(e)
})
}
pub async fn request_with_reason_no_response<B: Serialize>(&self, method: Method, route: crate::route::Route<'_>, body: Option<B>, reason: Option<&str>) -> Result<()> {
self.do_request_with_referer(method, route, body, None, reason).await?;
Ok(())
}
fn auth_header_value(&self) -> &str {
if self.token.starts_with("Bot ") || self.token.starts_with("Bearer ") {
&self.token
} else {
&self.token
}
}
fn apply_common_headers(&self, mut request: reqwest::RequestBuilder, reason: Option<&str>) -> reqwest::RequestBuilder {
let auth = self.auth_header_value();
request = request.header(header::AUTHORIZATION, auth);
if let Some(sp) = self.super_properties_b64.as_deref() {
request = request.header("X-Super-Properties", sp);
}
if let Some(loc) = self.discord_locale.as_deref() {
request = request.header("X-Discord-Locale", loc);
}
if let Some(tz) = self.discord_timezone.as_deref() {
request = request.header("X-Discord-Timezone", tz);
}
if let Some(raw_reason) = reason {
let encoded = urlencoding::encode(raw_reason).into_owned();
request = request.header("X-Audit-Log-Reason", encoded);
}
request
}
async fn do_request_with_referer<B: Serialize>(&self, method: Method, route: crate::route::Route<'_>, body: Option<B>, referer: Option<&str>, reason: Option<&str>) -> Result<Response> {
let path_cow = route.path();
let endpoint = path_cow.as_ref();
let url = if endpoint.starts_with("http") { endpoint.to_string() } else { format!("{}/{}", API_BASE, endpoint.trim_start_matches('/')) };
let route_key = Ratelimit::get_route_key(&method, endpoint);
for attempt in 0..DEFAULT_RETRY_COUNT {
if !self.ratelimiter_disabled {
drop(self.ratelimit.global.lock().await);
self.ratelimit.pre_hook(&method, endpoint, &route_key).await;
}
let mut request = self.client.request(method.clone(), &url);
request = self.apply_common_headers(request, reason);
if let Some(ref_url) = referer {
request = request.header(header::REFERER, ref_url);
}
for (key, value) in &self.custom_headers {
if key.to_lowercase() != "authorization" {
request = request.header(key.as_str(), value.as_str());
}
}
if let Some(ref b) = body {
request = request.json(b);
}
debug!("Request attempt {}/{}: {} {}", attempt + 1, DEFAULT_RETRY_COUNT, method, url);
let result = request.send().await;
match result {
Ok(response) => {
let status = response.status();
if !self.ratelimiter_disabled {
self.ratelimit.post_hook(&route_key, response.headers());
}
if status.is_success() {
return Ok(response);
}
if status.as_u16() == 429 {
self.invalid_requests.record().await;
let info = Self::extract_rate_limit_info(response.headers());
if !self.ratelimiter_disabled {
self.ratelimit.record_429(&route_key, response.headers(), info.retry_after);
}
if self.ratelimiter_disabled {
return Err(DiscordError::RateLimited { retry_after: info.retry_after, bucket: info.bucket, global: info.global, scope: info.scope });
}
if info.retry_after > MAX_RETRY_AFTER_SECONDS as f64 {
return Err(DiscordError::RateLimited { retry_after: info.retry_after, bucket: info.bucket, global: info.global, scope: info.scope });
}
let wait_time = if info.retry_after < 2.0 { 2.0 } else { info.retry_after };
warn!("Rate limited (global: {}), waiting {:.2} seconds", info.global, wait_time);
let _global_guard = if info.global { Some(self.ratelimit.global.lock().await) } else { None };
if let Some(cb) = &self.ratelimit.callback {
let limit = response.headers().get("x-ratelimit-limit").and_then(|h| h.to_str().ok()).and_then(|s| s.parse::<u64>().ok()).unwrap_or(0);
cb(RatelimitInfo {
timeout: Duration::from_secs_f64(wait_time),
limit,
method: method.clone(),
path: endpoint.to_string(),
global: info.global,
});
}
tokio::time::sleep(Duration::from_secs_f64(wait_time)).await;
continue;
}
if status.as_u16() == 401 || status.as_u16() == 403 {
self.invalid_requests.record().await;
}
let error_body = match response.text().await {
Ok(text) => text,
Err(e) => {
return Err(DiscordError::Http(e));
}
};
warn!(
discord.status = status.as_u16(),
discord.method = %method,
discord.url = %url,
discord.body_len = error_body.len(),
discord.body = %error_body.chars().take(500).collect::<String>(),
"Discord API non-success response"
);
if error_body.contains("verify your account") {
return Err(DiscordError::VerificationRequired);
}
if error_body.contains("captcha_key") {
let service = serde_json::from_str::<serde_json::Value>(&error_body).ok().and_then(|v| v["captcha_service"].as_str().map(|s| s.to_string())).unwrap_or_else(|| "unknown".to_string());
return Err(DiscordError::CaptchaRequired { service });
}
if status.as_u16() == 401 {
if error_body.contains("401: Unauthorized") || error_body.contains("token") {
warn!(discord.body = %error_body, "Mapped 401 -> InvalidToken");
return Err(DiscordError::InvalidToken);
}
warn!(discord.body = %error_body, "Mapped 401 -> AuthenticationFailed");
return Err(DiscordError::AuthenticationFailed);
}
if status.as_u16() == 403 {
let permission = serde_json::from_str::<serde_json::Value>(&error_body).ok().and_then(|v| v["message"].as_str().map(|s| s.to_string())).unwrap_or_else(|| "unknown".to_string());
return Err(DiscordError::PermissionDenied { permission });
}
if status.as_u16() == 404 {
let resource_type = Self::extract_resource_type(&url);
let id = Self::extract_resource_id(&url);
return Err(DiscordError::NotFound { resource_type, id });
}
if status.as_u16() == 400 {
return Err(DiscordError::InvalidRequest(error_body));
}
error!(status = %status, error_body = %error_body, "HTTP error");
if status.is_server_error() && attempt < DEFAULT_RETRY_COUNT - 1 {
tokio::time::sleep(Duration::from_secs(2)).await;
continue;
}
if status.is_server_error() {
return Err(DiscordError::ServiceError { status: status.as_u16(), body: error_body });
}
return Err(DiscordError::UnexpectedStatusCode { status: status.as_u16(), body: error_body });
}
Err(e) => {
error!(error = %e, "Request error");
if attempt < DEFAULT_RETRY_COUNT - 1 {
tokio::time::sleep(Duration::from_secs(2)).await;
continue;
}
return Err(DiscordError::Http(e));
}
}
}
Err(DiscordError::MaxRetriesExceeded)
}
fn extract_rate_limit_info(headers: &header::HeaderMap) -> RateLimitInfo {
let retry_after = headers.get("retry-after").and_then(|h| h.to_str().ok()).and_then(|s| s.parse::<f64>().ok()).unwrap_or(5.0);
let bucket = headers.get("x-ratelimit-bucket").and_then(|h| h.to_str().ok()).map(|s| s.to_string());
let global = headers.get("x-ratelimit-global").and_then(|h| h.to_str().ok()).map(|s| s == "true").unwrap_or(false);
let scope = headers.get("x-ratelimit-scope").and_then(|h| h.to_str().ok()).map(|s| s.to_string());
RateLimitInfo { retry_after, bucket, global, scope }
}
fn extract_resource_type(url: &str) -> String {
let parts: Vec<&str> = url.split('/').collect();
for part in &parts {
match *part {
"channels" => return "channel".to_string(),
"guilds" => return "guild".to_string(),
"users" => return "user".to_string(),
"messages" => return "message".to_string(),
"members" => return "member".to_string(),
"roles" => return "role".to_string(),
"invites" => return "invite".to_string(),
"webhooks" => return "webhook".to_string(),
"emojis" => return "emoji".to_string(),
_ => continue,
}
}
"resource".to_string()
}
fn extract_resource_id(url: &str) -> String {
let parts: Vec<&str> = url.split('/').collect();
for part in parts.iter().rev() {
if part.chars().all(|c| c.is_ascii_digit()) && part.len() > 10 {
return (*part).to_string();
}
}
"unknown".to_string()
}
}
#[cfg(test)]
mod tests {
    use reqwest::header::{HeaderMap, HeaderValue};
    use super::*;

    /// Current local time in Unix epoch seconds, for comparing against `reset_at`.
    fn unix_now() -> f64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs_f64()
    }

    #[test]
    fn test_extract_rate_limit_info() {
        let mut headers = HeaderMap::new();
        headers.insert("retry-after", HeaderValue::from_static("12.5"));
        headers.insert("x-ratelimit-bucket", HeaderValue::from_static("test-bucket"));
        headers.insert("x-ratelimit-global", HeaderValue::from_static("true"));
        headers.insert("x-ratelimit-scope", HeaderValue::from_static("shared"));
        let info = DiscordHttpClient::extract_rate_limit_info(&headers);
        assert_eq!(info.retry_after, 12.5);
        assert_eq!(info.bucket.unwrap(), "test-bucket");
        assert!(info.global);
        assert_eq!(info.scope.unwrap(), "shared");
    }

    #[test]
    fn test_extract_rate_limit_info_defaults() {
        let headers = HeaderMap::new();
        let info = DiscordHttpClient::extract_rate_limit_info(&headers);
        assert_eq!(info.retry_after, 5.0);
        assert!(info.bucket.is_none());
        assert!(!info.global);
        assert!(info.scope.is_none());
    }

    // Pruning needs ten minutes of wall-clock time, so only counting is exercised here.
    #[tokio::test]
    async fn invalid_request_tracker_counts_records() {
        let tracker = InvalidRequestTracker::default();
        assert_eq!(tracker.count().await, 0);
        let n1 = tracker.record().await;
        let n2 = tracker.record().await;
        assert_eq!(n1, 1);
        assert_eq!(n2, 2);
        assert_eq!(tracker.count().await, 2);
    }

    #[test]
    fn record_429_seeds_bucket_with_reset() {
        let rl = Ratelimit::default();
        let mut headers = HeaderMap::new();
        headers.insert("x-ratelimit-limit", HeaderValue::from_static("5"));
        headers.insert("x-ratelimit-reset-after", HeaderValue::from_static("3.5"));
        rl.record_429("GET /test", &headers, 1.0);
        let bucket = rl.buckets.get("GET /test").unwrap();
        assert_eq!(bucket.remaining, 0);
        assert_eq!(bucket.limit, 5);
        assert_eq!(bucket.reset_after, Some(3.5));
        // With no absolute reset header, the reset is derived from reset-after and lies ahead.
        assert!(bucket.reset_at > unix_now());
    }

    #[test]
    fn post_hook_records_bucket_headers() {
        let rl = Ratelimit::default();
        let mut headers = HeaderMap::new();
        headers.insert("x-ratelimit-remaining", HeaderValue::from_static("0"));
        headers.insert("x-ratelimit-limit", HeaderValue::from_static("5"));
        headers.insert("x-ratelimit-reset", HeaderValue::from_static("9999999999.5"));
        headers.insert("x-ratelimit-reset-after", HeaderValue::from_static("2.0"));
        rl.post_hook("GET/test", &headers);
        let bucket = rl.buckets.get("GET/test").unwrap();
        assert_eq!(bucket.remaining, 0);
        assert_eq!(bucket.limit, 5);
        assert_eq!(bucket.reset_after, Some(2.0));
    }

    #[test]
    fn post_hook_ignores_responses_without_ratelimit_headers() {
        let rl = Ratelimit::default();
        rl.post_hook("GET/test", &HeaderMap::new());
        assert!(rl.buckets.get("GET/test").is_none());
    }

    #[test]
    fn route_key_keeps_major_ids_and_masks_other_snowflakes() {
        let key = Ratelimit::get_route_key(
            &reqwest::Method::GET,
            "/channels/123456789012345678/messages/987654321098765432?limit=5",
        );
        assert_eq!(key, "GET/channels/123456789012345678/messages/{id}");
        let key = Ratelimit::get_route_key(
            &reqwest::Method::PATCH,
            "https://discord.com/api/v10/guilds/111111111111111111/members/222222222222222222",
        );
        assert_eq!(key, "PATCH/guilds/111111111111111111/members/{id}");
    }

    #[test]
    fn not_found_helpers_extract_type_and_id() {
        let url = "https://discord.com/api/v10/guilds/123456789012345678";
        assert_eq!(DiscordHttpClient::extract_resource_type(url), "guild");
        assert_eq!(DiscordHttpClient::extract_resource_id(url), "123456789012345678");
        assert_eq!(
            DiscordHttpClient::extract_resource_id("https://discord.com/api/v10/users/@me"),
            "unknown"
        );
    }
}