use super::{BytesFuture, Response};
use crate::{
api_error::ApiError,
client::connector::Connector,
error::{Error, ErrorType},
};
use http::{HeaderMap, HeaderValue, Request, StatusCode, header};
use http_body_util::Full;
use hyper::body::Bytes;
use hyper_util::client::legacy::{Client as HyperClient, ResponseFuture as HyperResponseFuture};
use std::{
future::{Future, Ready, ready},
marker::PhantomData,
pin::Pin,
sync::{
Arc,
atomic::{AtomicBool, Ordering},
},
task::{Context, Poll, ready},
time::{Duration, Instant},
};
use tokio::time::{self, Timeout};
use twilight_http_ratelimiting::{Endpoint, Permit, PermitFuture, RateLimitHeaders, RateLimiter};
fn parse_ratelimit_headers(
headers: &HeaderMap,
) -> Result<Option<RateLimitHeaders>, Box<dyn std::error::Error>> {
match headers
.get(RateLimitHeaders::SCOPE)
.map(HeaderValue::as_bytes)
{
Some(b"global") => {
tracing::info!("globally rate limited");
Ok(None)
}
Some(b"shared") => {
let bucket = headers
.get(RateLimitHeaders::BUCKET)
.ok_or("missing bucket header")?
.as_bytes()
.to_vec();
let retry_after = headers
.get(header::RETRY_AFTER)
.ok_or("missing retry-after header")?
.to_str()?
.parse()?;
Ok(Some(RateLimitHeaders::shared(bucket, retry_after)))
}
Some(b"user") => {
let bucket = headers
.get(RateLimitHeaders::BUCKET)
.ok_or("missing bucket header")?
.as_bytes()
.to_vec();
let limit = headers
.get(RateLimitHeaders::LIMIT)
.ok_or("missing limit header")?
.to_str()?
.parse()?;
let remaining = headers
.get(RateLimitHeaders::REMAINING)
.ok_or("missing remaining header")?
.to_str()?
.parse()?;
let reset_after = headers
.get(RateLimitHeaders::RESET_AFTER)
.ok_or("missing reset-after header")?
.to_str()?
.parse()?;
Ok(Some(RateLimitHeaders {
bucket,
limit,
remaining,
reset_at: Instant::now() + Duration::from_secs_f32(reset_after),
}))
}
_ => Ok(None),
}
}
/// The stage a [`ResponseFuture`] is currently waiting on.
enum ResponseStageFuture {
    /// Waiting for a rate limit permit before the request is sent.
    RateLimitPermit(PermitFuture),
    /// Waiting for the HTTP response, bounded by a timeout.
    Response {
        /// The timed request future.
        fut: Pin<Box<Timeout<HyperResponseFuture>>>,
        /// Permit to complete with the response's rate limit headers; `None`
        /// when no rate limiter is in use or it has already been completed.
        permit: Option<Permit>,
    },
    /// Collecting the body of an unsuccessful response to build an error.
    Error {
        /// Future yielding the response body.
        fut: BytesFuture,
        /// Status code of the unsuccessful response.
        status: StatusCode,
    },
}
struct PermitFutureGenerator {
rate_limiter: RateLimiter,
endpoint: Endpoint,
}
impl PermitFutureGenerator {
fn generate(&self) -> PermitFuture {
self.rate_limiter.acquire(self.endpoint.clone())
}
}
/// Creates timed request futures for a single, repeatable request.
///
/// Each call sends a clone of the stored request, so the request can be resent
/// any number of times.
struct TimedResponseFutureGenerator {
    /// HTTP client the request is sent through.
    client: HyperClient<Connector, Full<Bytes>>,
    /// Request that is cloned for every attempt.
    request: Request<Full<Bytes>>,
    /// Maximum duration of a single attempt.
    timeout: Duration,
}

impl TimedResponseFutureGenerator {
    /// Sends a clone of the request, wrapped in a `timeout`-bounded future.
    fn generate(&self) -> Pin<Box<Timeout<HyperResponseFuture>>> {
        let in_flight = self.client.request(self.request.clone());
        Box::pin(time::timeout(self.timeout, in_flight))
    }
}
/// Future that drives an HTTP request and resolves to a [`Response`].
///
/// Holds either the in-progress request state or an error that was already
/// known when the future was created.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ResponseFuture<T>(Result<Inner<T>, Ready<Error>>);
impl<T> ResponseFuture<T> {
    /// Creates the future for `request`.
    ///
    /// When a `rate_limiter` is given, the future first waits for a permit for
    /// `endpoint`. Otherwise the `timeout`-bounded request is created right
    /// away. The `invalid_token` flag, if any, is set when the API answers with
    /// `401 Unauthorized`.
    pub(crate) fn new(
        client: HyperClient<Connector, Full<Bytes>>,
        invalid_token: Option<Arc<AtomicBool>>,
        request: Request<Full<Bytes>>,
        span: tracing::Span,
        timeout: Duration,
        rate_limiter: Option<RateLimiter>,
        endpoint: Endpoint,
    ) -> Self {
        let response_generator = TimedResponseFutureGenerator {
            client,
            request,
            timeout,
        };
        let permit_generator =
            rate_limiter.map(|rate_limiter| PermitFutureGenerator { rate_limiter, endpoint });

        // Exactly one of the two generators is used for the initial stage.
        let stage = match &permit_generator {
            Some(generator) => ResponseStageFuture::RateLimitPermit(generator.generate()),
            None => ResponseStageFuture::Response {
                fut: response_generator.generate(),
                permit: None,
            },
        };

        let inner = Inner {
            invalid_token,
            permit_generator,
            phantom: PhantomData,
            pre_flight_check: None,
            response_generator,
            span,
            stage,
        };

        Self(Ok(inner))
    }

    /// Registers a predicate that runs every time a rate limit permit has been
    /// acquired, right before the request is sent.
    ///
    /// If the predicate returns `false`, the future resolves to an error of
    /// kind `ErrorType::RequestCanceled`.
    ///
    /// Returns `true` if the predicate was registered. Returns `false` in any
    /// of these cases:
    ///
    /// - this future already holds an error;
    /// - no rate limiter is configured;
    /// - a predicate was already registered.
    pub fn set_pre_flight<P>(&mut self, predicate: P) -> bool
    where
        P: Fn() -> bool + Send + 'static,
    {
        let Ok(inner) = &mut self.0 else {
            return false;
        };

        if inner.permit_generator.is_none() || inner.pre_flight_check.is_some() {
            return false;
        }

        inner.pre_flight_check = Some(Box::new(predicate));
        true
    }

    /// Creates a future that resolves to `source` on its first poll.
    pub(crate) fn error(source: Error) -> Self {
        let pending_error = ready(source);
        Self(Err(pending_error))
    }
}
// Drives the request state machine:
// `RateLimitPermit` -> `Response` -> success, retry, or `Error`.
impl<T: Unpin> Future for ResponseFuture<T> {
type Output = Result<Response<T>, Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// A future built via `ResponseFuture::error` only yields its stored error.
let inner = match &mut self.0 {
Ok(inner) => inner,
Err(err) => return Pin::new(err).poll(cx).map(Err),
};
// Keep the request's tracing span entered while this poll runs.
let _entered = inner.span.enter();
// Loop so a stage that completes immediately hands over to the next stage
// within the same poll.
loop {
match &mut inner.stage {
// Unsuccessful (non-2xx, non-429) response: read the body and turn it
// into an error. It becomes `ErrorType::Response` if the body
// deserializes as an `ApiError`, and `ErrorType::Parsing` otherwise.
ResponseStageFuture::Error { fut, status } => {
let body = ready!(Pin::new(fut).poll(cx)).map_err(|source| Error {
kind: ErrorType::RequestError,
source: Some(Box::new(source)),
})?;
return Poll::Ready(Err(match crate::json::from_bytes::<ApiError>(&body) {
Ok(error) => Error {
kind: ErrorType::Response {
body,
error,
status: super::StatusCode::new(status.as_u16()),
},
source: None,
},
Err(source) => Error {
kind: ErrorType::Parsing { body },
source: Some(Box::new(source)),
},
}));
}
// Waiting for a rate limit permit. Once it is granted, the optional
// pre-flight predicate may cancel the request. Otherwise the request
// is sent and the permit travels with it.
ResponseStageFuture::RateLimitPermit(fut) => {
let permit = ready!(Pin::new(fut).poll(cx));
if inner
.pre_flight_check
.as_ref()
.is_some_and(|check| !check())
{
return Poll::Ready(Err(Error {
kind: ErrorType::RequestCanceled,
source: None,
}));
}
inner.stage = ResponseStageFuture::Response {
fut: inner.response_generator.generate(),
permit: Some(permit),
};
}
// Waiting for the HTTP response. The outer error is the timeout
// elapsing; the inner error is the client failing.
ResponseStageFuture::Response { fut, permit } => {
let response = ready!(Pin::new(fut).poll(cx))
.map_err(|source| Error {
kind: ErrorType::RequestTimedOut,
source: Some(Box::new(source)),
})?
.map_err(|source| Error {
kind: ErrorType::RequestError,
source: Some(Box::new(source)),
})?;
// A 401 marks the shared token as invalid for whoever holds the flag.
if response.status() == StatusCode::UNAUTHORIZED
&& let Some(invalid) = &inner.invalid_token
{
invalid.store(true, Ordering::Relaxed);
}
// Always complete the permit, even when the headers cannot be
// parsed (then with `None`, after logging a warning).
if let Some(permit) = permit.take() {
match parse_ratelimit_headers(response.headers()) {
Ok(v) => permit.complete(v),
Err(source) => {
tracing::warn!("header parsing failed: {source}; {response:?}");
permit.complete(None);
}
}
}
if response.status().is_success() {
// With decompression enabled, the Content-Length header is removed,
// presumably because it describes the compressed body rather than
// the bytes the caller will read. TODO confirm.
#[cfg(feature = "decompression")]
let mut response = response;
#[cfg(feature = "decompression")]
response.headers_mut().remove(header::CONTENT_LENGTH);
return Poll::Ready(Ok(Response::new(response)));
} else if response.status() == StatusCode::TOO_MANY_REQUESTS {
// 429: resend the request. With a rate limiter, first wait for a
// new permit.
// NOTE(review): without a rate limiter the request is resent
// immediately, with no delay and no retry limit.
inner.stage = match &inner.permit_generator {
Some(generator) => {
ResponseStageFuture::RateLimitPermit(generator.generate())
}
None => ResponseStageFuture::Response {
fut: inner.response_generator.generate(),
permit: None,
},
};
} else {
// Any other status: collect the body to build the error.
inner.stage = ResponseStageFuture::Error {
status: response.status(),
fut: Response::<()>::new(response).bytes(),
};
}
}
}
}
}
}
/// State of a [`ResponseFuture`] that did not fail on creation.
// Field order is kept as-is: it determines the drop order.
struct Inner<T> {
    /// Flag set when the API answers with `401 Unauthorized`.
    invalid_token: Option<Arc<AtomicBool>>,
    /// Source of rate limit permits; `None` when no rate limiter is used.
    permit_generator: Option<PermitFutureGenerator>,
    /// Marks the response body type `T` without storing a value of it.
    phantom: PhantomData<T>,
    /// Predicate run after each acquired permit; returning `false` cancels the
    /// request.
    pre_flight_check: Option<Box<dyn Fn() -> bool + Send + 'static>>,
    /// Creates a fresh timed request for every attempt.
    response_generator: TimedResponseFutureGenerator,
    /// Tracing span entered while the future is polled.
    span: tracing::Span,
    /// The stage currently being polled.
    stage: ResponseStageFuture,
}