rosu_render/request/future.rs

1use std::{
2    future::Future,
3    marker::PhantomData,
4    pin::{pin, Pin},
5    task::{Context, Poll},
6};
7
8use http_body_util::{combinators::Collect, BodyExt};
9use hyper::{body::Incoming, StatusCode};
10use hyper_util::client::legacy::ResponseFuture;
11use leaky_bucket::AcquireOwned;
12use pin_project::pin_project;
13use serde::de::DeserializeOwned;
14
15use crate::ClientError;
16
17use super::requestable::Requestable;
18
/// Future for a single request that resolves to a deserialized `T` or a
/// [`ClientError`].
///
/// It first waits on an optional ratelimit acquisition, then on the HTTP
/// response, and finally on the collected response body.
#[pin_project(project = OrdrFutureProj)]
pub struct OrdrFuture<T> {
    /// Pending ratelimit acquisition. Set by `new`, absent for futures built
    /// via `error`, and cleared once it completes so it is not polled again.
    #[pin]
    ratelimit: Option<AcquireOwned>,
    /// Current stage of the request.
    #[pin]
    state: OrdrFutureState<T>,
}
26
27impl<T> OrdrFuture<T> {
28    pub(crate) const fn new(fut: Pin<Box<ResponseFuture>>, ratelimit: AcquireOwned) -> Self {
29        Self {
30            ratelimit: Some(ratelimit),
31            state: OrdrFutureState::InFlight(InFlight {
32                fut,
33                phantom: PhantomData,
34            }),
35        }
36    }
37
38    pub(crate) const fn error(source: ClientError) -> Self {
39        Self {
40            ratelimit: None,
41            state: OrdrFutureState::Failed(Some(source)),
42        }
43    }
44
45    fn await_ratelimit(
46        mut ratelimit_opt: Pin<&mut Option<AcquireOwned>>,
47        cx: &mut Context<'_>,
48    ) -> Poll<()> {
49        if let Some(ratelimit) = ratelimit_opt.as_mut().as_pin_mut() {
50            match ratelimit.poll(cx) {
51                Poll::Ready(()) => ratelimit_opt.set(None),
52                Poll::Pending => return Poll::Pending,
53            }
54        }
55
56        Poll::Ready(())
57    }
58}
59
60impl<T: DeserializeOwned + Requestable> Future for OrdrFuture<T> {
61    type Output = Result<T, ClientError>;
62
63    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
64        let mut this = self.project();
65        let mut state = this.state.as_mut();
66
67        match state.as_mut().project() {
68            OrdrFutureStateProj::InFlight(in_flight) => {
69                if Self::await_ratelimit(this.ratelimit, cx).is_pending() {
70                    return Poll::Pending;
71                }
72
73                match in_flight.poll(cx) {
74                    Poll::Ready(Ok(chunking)) => {
75                        state.set(OrdrFutureState::Chunking(chunking));
76                        cx.waker().wake_by_ref();
77
78                        Poll::Pending
79                    }
80                    Poll::Ready(Err(err)) => {
81                        state.set(OrdrFutureState::Completed);
82
83                        Poll::Ready(Err(err))
84                    }
85                    Poll::Pending => Poll::Pending,
86                }
87            }
88            OrdrFutureStateProj::Chunking(chunking) => match chunking.poll(cx) {
89                Poll::Ready(res) => {
90                    state.set(OrdrFutureState::Completed);
91
92                    Poll::Ready(res)
93                }
94                Poll::Pending => Poll::Pending,
95            },
96            OrdrFutureStateProj::Failed(failed) => {
97                let err = failed.take().expect("error already taken");
98                state.set(OrdrFutureState::Completed);
99
100                Poll::Ready(Err(err))
101            }
102            OrdrFutureStateProj::Completed => panic!("future already completed"),
103        }
104    }
105}
106
/// Stages of an [`OrdrFuture`].
#[pin_project(project = OrdrFutureStateProj)]
enum OrdrFutureState<T> {
    /// The response has arrived; its body is being collected and converted.
    Chunking(#[pin] Chunking<T>),
    /// Output has already been yielded; polling again panics.
    Completed,
    /// Built via `OrdrFuture::error`. Holds the error that is yielded on the
    /// first poll.
    Failed(Option<ClientError>),
    /// Waiting on the ratelimit and then on the HTTP response.
    InFlight(#[pin] InFlight<T>),
}
114
/// Future that collects a response body and converts it into `T` or a
/// [`ClientError`].
#[pin_project]
struct Chunking<T> {
    /// Future collecting the full response body.
    #[pin]
    fut: Collect<Incoming>,
    /// Status code of the response. It decides between JSON deserialization
    /// and `Requestable::response_error`.
    status: StatusCode,
    /// Carries the output type `T` without storing a value.
    phantom: PhantomData<T>,
}
122
123impl<T: DeserializeOwned + Requestable> Future for Chunking<T> {
124    type Output = Result<T, ClientError>;
125
126    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
127        let this = self.project();
128
129        let bytes = match this.fut.poll(cx) {
130            Poll::Ready(Ok(collected)) => collected.to_bytes(),
131            Poll::Ready(Err(source)) => {
132                return Poll::Ready(Err(ClientError::ChunkingResponse { source }))
133            }
134            Poll::Pending => return Poll::Pending,
135        };
136
137        let res = if this.status.is_success() {
138            match serde_json::from_slice(&bytes) {
139                Ok(this) => Ok(this),
140                Err(source) => Err(ClientError::Parsing {
141                    body: bytes.into(),
142                    source,
143                }),
144            }
145        } else {
146            Err(<T as Requestable>::response_error(*this.status, bytes))
147        };
148
149        Poll::Ready(res)
150    }
151}
152
/// Future awaiting the HTTP response. It resolves to a [`Chunking`] future
/// for the response body.
#[pin_project]
struct InFlight<T> {
    /// Boxed response future for the request.
    #[pin]
    fut: Pin<Box<ResponseFuture>>,
    /// Carries the eventual output type `T` without storing a value.
    phantom: PhantomData<T>,
}
159
160impl<T: Requestable> Future for InFlight<T> {
161    type Output = Result<Chunking<T>, ClientError>;
162
163    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
164        let this = self.project();
165
166        let response = match this.fut.poll(cx) {
167            Poll::Ready(Ok(response)) => response,
168            Poll::Ready(Err(source)) => {
169                return Poll::Ready(Err(ClientError::RequestError { source }))
170            }
171            Poll::Pending => return Poll::Pending,
172        };
173
174        let status = response.status();
175
176        match status {
177            StatusCode::TOO_MANY_REQUESTS => warn!("429 response: {response:?}"),
178            StatusCode::SERVICE_UNAVAILABLE => {
179                return Poll::Ready(Err(ClientError::ServiceUnavailable {
180                    response: Box::new(response),
181                }))
182            }
183            _ => {}
184        }
185
186        Poll::Ready(Ok(Chunking {
187            fut: response.into_body().collect(),
188            status,
189            phantom: PhantomData,
190        }))
191    }
192}