use super::{Response, StatusCode};
use crate::{
api_error::ApiError,
error::{Error, ErrorType},
};
use hyper::{client::ResponseFuture as HyperResponseFuture, StatusCode as HyperStatusCode};
use std::{
future::Future,
marker::PhantomData,
mem,
pin::Pin,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
task::{Context, Poll},
time::Duration,
};
use tokio::time::{self, Timeout};
use twilight_http_ratelimiting::{ticket::TicketSender, RatelimitHeaders, WaitForTicketFuture};
/// Value a [`ResponseFuture`] resolves to: a typed [`Response`] on success,
/// or the crate's [`Error`] on failure.
type Output<T> = Result<Response<T>, Error>;
/// Outcome of polling a single stage of a [`ResponseFuture`].
enum InnerPoll<T> {
/// The stage finished and produced the next stage; the driving loop stores
/// it and polls again without returning to the caller.
Advance(ResponseFutureStage),
/// The stage is not finished; the carried stage is stored back and
/// `Poll::Pending` is returned to the caller.
Pending(ResponseFutureStage),
/// The whole future has finished with this output.
Ready(Output<T>),
}
/// Stage that collects the body of a non-successful response so it can be
/// turned into an error.
struct Chunking {
/// Future resolving to the raw body bytes, or to an error if collecting the
/// body failed.
future: Pin<Box<dyn Future<Output = Result<Vec<u8>, Error>> + Send + Sync + 'static>>,
/// HTTP status code of the response whose body is being collected.
status: HyperStatusCode,
}
impl Chunking {
fn poll<T>(mut self, cx: &mut Context<'_>) -> InnerPoll<T> {
let bytes = match Pin::new(&mut self.future).poll(cx) {
Poll::Ready(Ok(bytes)) => bytes,
Poll::Ready(Err(source)) => return InnerPoll::Ready(Err(source)),
Poll::Pending => return InnerPoll::Pending(ResponseFutureStage::Chunking(self)),
};
let error = match crate::json::from_bytes::<ApiError>(&bytes) {
Ok(error) => error,
Err(source) => {
return InnerPoll::Ready(Err(Error {
kind: ErrorType::Parsing { body: bytes },
source: Some(Box::new(source)),
}));
}
};
InnerPoll::Ready(Err(Error {
kind: ErrorType::Response {
body: bytes,
error,
status: StatusCode::new(self.status.as_u16()),
},
source: None,
}))
}
}
/// Stage for a future that was created already holding an error.
struct Failed {
/// Error returned as soon as the stage is polled.
source: Error,
}
impl Failed {
    /// Resolve immediately with the stored error; the task context is unused.
    fn poll<T>(self, _cx: &mut Context<'_>) -> InnerPoll<T> {
        let Self { source } = self;

        InnerPoll::Ready(Err(source))
    }
}
/// Stage in which the HTTP request has been sent and its response is awaited.
struct InFlight {
/// The hyper response future wrapped in a tokio timeout.
future: Pin<Box<Timeout<HyperResponseFuture>>>,
/// Shared flag that is set to `true` when the response status is
/// 401 Unauthorized.
invalid_token: Option<Arc<AtomicBool>>,
/// Ticket sender used to report the response's ratelimit headers; `None`
/// when the stage was not reached through the ratelimit queue.
tx: Option<TicketSender>,
}
impl InFlight {
fn poll<T>(mut self, cx: &mut Context<'_>) -> InnerPoll<T> {
let resp = match Pin::new(&mut self.future).poll(cx) {
Poll::Ready(Ok(Ok(resp))) => resp,
Poll::Ready(Ok(Err(source))) => {
return InnerPoll::Ready(Err(Error {
kind: ErrorType::RequestError,
source: Some(Box::new(source)),
}))
}
Poll::Ready(Err(source)) => {
return InnerPoll::Ready(Err(Error {
kind: ErrorType::RequestTimedOut,
source: Some(Box::new(source)),
}))
}
Poll::Pending => return InnerPoll::Pending(ResponseFutureStage::InFlight(self)),
};
if resp.status() == HyperStatusCode::UNAUTHORIZED {
if let Some(invalid_token) = self.invalid_token {
invalid_token.store(true, Ordering::Relaxed);
}
}
if let Some(tx) = self.tx {
let headers = resp
.headers()
.iter()
.map(|(key, value)| (key.as_str(), value.as_bytes()));
match RatelimitHeaders::from_pairs(headers) {
Ok(v) => {
let _res = tx.headers(Some(v));
}
Err(source) => {
tracing::warn!("header parsing failed: {source:?}; {resp:?}");
let _res = tx.headers(None);
}
}
}
let status = resp.status();
if status.is_success() {
#[cfg(feature = "decompression")]
let mut resp = resp;
#[cfg(feature = "decompression")]
resp.headers_mut().remove(hyper::header::CONTENT_LENGTH);
return InnerPoll::Ready(Ok(Response::new(resp)));
}
match status {
HyperStatusCode::TOO_MANY_REQUESTS => {
tracing::warn!("429 response: {resp:?}");
}
HyperStatusCode::SERVICE_UNAVAILABLE => {
return InnerPoll::Ready(Err(Error {
kind: ErrorType::ServiceUnavailable { response: resp },
source: None,
}));
}
_ => {}
}
let fut = async {
Response::<()>::new(resp)
.bytes()
.await
.map_err(|source| Error {
kind: ErrorType::ChunkingResponse,
source: Some(Box::new(source)),
})
};
InnerPoll::Advance(ResponseFutureStage::Chunking(Chunking {
future: Box::pin(fut),
status,
}))
}
}
/// Stage in which the request waits for a ratelimit ticket before being sent.
struct RatelimitQueue {
/// Flag forwarded unchanged to the [`InFlight`] stage.
invalid_token: Option<Arc<AtomicBool>>,
/// The hyper response future that is wrapped in a timeout once the ticket
/// has been obtained.
response_future: HyperResponseFuture,
/// Timeout applied to `response_future` when moving to the in-flight stage.
timeout: Duration,
/// Optional check run once after the ticket is obtained; returning `false`
/// cancels the request.
pre_flight_check: Option<Box<dyn FnOnce() -> bool + Send + 'static>>,
/// Future resolving to the ticket sender used to report ratelimit headers.
wait_for_sender: WaitForTicketFuture,
}
impl RatelimitQueue {
    /// Wait for a ratelimit ticket, then put the request in flight.
    ///
    /// - While the ticket is not available the stage hands itself back via
    ///   [`InnerPoll::Pending`].
    /// - A failure to obtain the ticket resolves with
    ///   [`ErrorType::RatelimiterTicket`].
    /// - If a pre-flight check is installed and returns `false`, the future
    ///   resolves with [`ErrorType::RequestCanceled`].
    /// - Otherwise the stored response future is wrapped in the configured
    ///   timeout and the future advances to the [`InFlight`] stage.
    fn poll<T>(mut self, cx: &mut Context<'_>) -> InnerPoll<T> {
        let ticket = match Pin::new(&mut self.wait_for_sender).poll(cx) {
            Poll::Pending => return InnerPoll::Pending(ResponseFutureStage::RatelimitQueue(self)),
            Poll::Ready(Err(source)) => {
                return InnerPoll::Ready(Err(Error {
                    kind: ErrorType::RatelimiterTicket,
                    source: Some(source),
                }));
            }
            Poll::Ready(Ok(ticket)) => ticket,
        };

        let Self {
            invalid_token,
            response_future,
            timeout,
            pre_flight_check,
            ..
        } = self;

        // No installed check means the request is always allowed to proceed.
        let allowed = pre_flight_check.map_or(true, |check| check());

        if !allowed {
            return InnerPoll::Ready(Err(Error {
                kind: ErrorType::RequestCanceled,
                source: None,
            }));
        }

        let in_flight = InFlight {
            future: Box::pin(time::timeout(timeout, response_future)),
            invalid_token,
            tx: Some(ticket),
        };

        InnerPoll::Advance(ResponseFutureStage::InFlight(in_flight))
    }
}
/// The stages a [`ResponseFuture`] can be in.
enum ResponseFutureStage {
/// Collecting the body of a non-successful response.
Chunking(Chunking),
/// The future has resolved; it is also the placeholder left in place while
/// a stage is being polled. Polling in this stage panics.
Completed,
/// The future was created with an error and resolves to it when polled.
Failed(Failed),
/// The request has been sent and the response is awaited.
InFlight(InFlight),
/// Waiting for a ratelimit ticket before sending the request.
RatelimitQueue(RatelimitQueue),
}
/// Future that drives an HTTP request through its stages and resolves to a
/// typed [`Response`] or an [`Error`].
///
/// # Panics
///
/// Polling the future again after it has resolved panics.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ResponseFuture<T> {
/// Marker for the response's body type; no `T` value is stored.
phantom: PhantomData<T>,
/// Current stage of the request.
stage: ResponseFutureStage,
}
impl<T> ResponseFuture<T> {
    /// Create a future whose request is already in flight.
    ///
    /// No ratelimit ticket sender is attached, so response headers are not
    /// reported to a ratelimiter.
    pub(crate) const fn new(
        future: Pin<Box<Timeout<HyperResponseFuture>>>,
        invalid_token: Option<Arc<AtomicBool>>,
    ) -> Self {
        Self {
            stage: ResponseFutureStage::InFlight(InFlight {
                tx: None,
                invalid_token,
                future,
            }),
            phantom: PhantomData,
        }
    }

    /// Install a check that runs once after the ratelimit ticket has been
    /// obtained, immediately before the request is put in flight.
    ///
    /// If the check returns `false` the future resolves with
    /// [`ErrorType::RequestCanceled`]. Any previously installed check is
    /// replaced.
    ///
    /// Returns `true` if the check was installed, or `false` if the future is
    /// not waiting in the ratelimit queue (in which case the check is
    /// dropped without being called).
    pub fn set_pre_flight(
        &mut self,
        pre_flight: Box<dyn FnOnce() -> bool + Send + 'static>,
    ) -> bool {
        match &mut self.stage {
            ResponseFutureStage::RatelimitQueue(queue) => {
                queue.pre_flight_check = Some(pre_flight);

                true
            }
            ResponseFutureStage::Chunking(_)
            | ResponseFutureStage::Completed
            | ResponseFutureStage::Failed(_)
            | ResponseFutureStage::InFlight(_) => false,
        }
    }

    /// Create a future that resolves to the given error when first polled.
    pub(crate) const fn error(source: Error) -> Self {
        Self {
            stage: ResponseFutureStage::Failed(Failed { source }),
            phantom: PhantomData,
        }
    }

    /// Create a future that first waits for a ratelimit ticket and then sends
    /// the request with the given timeout.
    ///
    /// No pre-flight check is installed initially; see
    /// [`set_pre_flight`](Self::set_pre_flight).
    pub(crate) fn ratelimit(
        invalid_token: Option<Arc<AtomicBool>>,
        response_future: HyperResponseFuture,
        timeout: Duration,
        wait_for_sender: WaitForTicketFuture,
    ) -> Self {
        let queue = RatelimitQueue {
            invalid_token,
            response_future,
            timeout,
            pre_flight_check: None,
            wait_for_sender,
        };

        Self {
            phantom: PhantomData,
            stage: ResponseFutureStage::RatelimitQueue(queue),
        }
    }
}
impl<T: Unpin> Future for ResponseFuture<T> {
    type Output = Output<T>;

    /// Drive the request through its stages until one of them is pending or
    /// the final output is available.
    ///
    /// The current stage is taken out of `self` (leaving `Completed` as a
    /// placeholder), polled by value, and then either replaced by the next
    /// stage, stored back unchanged, or left as `Completed`.
    ///
    /// # Panics
    ///
    /// Panics if polled again after the future has resolved.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();

        loop {
            let current = mem::replace(&mut this.stage, ResponseFutureStage::Completed);

            let outcome = match current {
                ResponseFutureStage::RatelimitQueue(queue) => queue.poll(cx),
                ResponseFutureStage::InFlight(in_flight) => in_flight.poll(cx),
                ResponseFutureStage::Chunking(chunking) => chunking.poll(cx),
                ResponseFutureStage::Failed(failed) => failed.poll(cx),
                ResponseFutureStage::Completed => panic!("future already completed"),
            };

            match outcome {
                InnerPoll::Ready(output) => {
                    this.stage = ResponseFutureStage::Completed;

                    break Poll::Ready(output);
                }
                InnerPoll::Pending(same) => {
                    this.stage = same;

                    break Poll::Pending;
                }
                // The next stage may already be able to make progress, so go
                // around the loop instead of returning to the executor.
                InnerPoll::Advance(next) => this.stage = next,
            }
        }
    }
}