twilight_http/response/
future.rs

1use super::{BytesFuture, Response};
2use crate::{
3    api_error::ApiError,
4    client::connector::Connector,
5    error::{Error, ErrorType},
6};
7use http::{HeaderMap, HeaderValue, Request, StatusCode, header};
8use http_body_util::Full;
9use hyper::body::Bytes;
10use hyper_util::client::legacy::{Client as HyperClient, ResponseFuture as HyperResponseFuture};
11use std::{
12    future::{Future, Ready, ready},
13    marker::PhantomData,
14    pin::Pin,
15    sync::{
16        Arc,
17        atomic::{AtomicBool, Ordering},
18    },
19    task::{Context, Poll, ready},
20    time::{Duration, Instant},
21};
22use tokio::time::{self, Timeout};
23use twilight_http_ratelimiting::{Endpoint, Permit, PermitFuture, RateLimitHeaders, RateLimiter};
24
25/// Parse ratelimit headers from a map of headers.
26///
27/// # Errors
28///
29/// Errors if a required header is missing or if a header value is of an
30/// invalid type.
31fn parse_ratelimit_headers(
32    headers: &HeaderMap,
33) -> Result<Option<RateLimitHeaders>, Box<dyn std::error::Error>> {
34    match headers
35        .get(RateLimitHeaders::SCOPE)
36        .map(HeaderValue::as_bytes)
37    {
38        Some(b"global") => {
39            tracing::info!("globally rate limited");
40
41            Ok(None)
42        }
43        Some(b"shared") => {
44            let bucket = headers
45                .get(RateLimitHeaders::BUCKET)
46                .ok_or("missing bucket header")?
47                .as_bytes()
48                .to_vec();
49            let retry_after = headers
50                .get(header::RETRY_AFTER)
51                .ok_or("missing retry-after header")?
52                .to_str()?
53                .parse()?;
54
55            Ok(Some(RateLimitHeaders::shared(bucket, retry_after)))
56        }
57        Some(b"user") => {
58            let bucket = headers
59                .get(RateLimitHeaders::BUCKET)
60                .ok_or("missing bucket header")?
61                .as_bytes()
62                .to_vec();
63            let limit = headers
64                .get(RateLimitHeaders::LIMIT)
65                .ok_or("missing limit header")?
66                .to_str()?
67                .parse()?;
68            let remaining = headers
69                .get(RateLimitHeaders::REMAINING)
70                .ok_or("missing remaining header")?
71                .to_str()?
72                .parse()?;
73            let reset_after = headers
74                .get(RateLimitHeaders::RESET_AFTER)
75                .ok_or("missing reset-after header")?
76                .to_str()?
77                .parse()?;
78
79            Ok(Some(RateLimitHeaders {
80                bucket,
81                limit,
82                remaining,
83                reset_at: Instant::now() + Duration::from_secs_f32(reset_after),
84            }))
85        }
86        _ => Ok(None),
87    }
88}
89
90/// Sub-futures of [`ResponseFuture`].
91enum ResponseStageFuture {
92    /// Future that completes with an error response body.
93    Error {
94        /// Inner response body future.
95        fut: BytesFuture,
96        /// Erroneous response status code.
97        status: StatusCode,
98    },
99    /// Future that completes when a rate limit permit is ready.
100    RateLimitPermit(PermitFuture),
101    /// Future that completes with a response or timeout.
102    Response {
103        /// Inner timed response future.
104        fut: Pin<Box<Timeout<HyperResponseFuture>>>,
105        /// Optional rate limit permit.
106        permit: Option<Permit>,
107    },
108}
109
110/// [`PermitFuture`] generator.
111struct PermitFutureGenerator {
112    /// Rate limiter to acquire permits from.
113    rate_limiter: RateLimiter,
114    /// Rate limiter endpoint to acquire permits for.
115    endpoint: Endpoint,
116}
117
118impl PermitFutureGenerator {
119    /// Generates a permit future.
120    fn generate(&self) -> PermitFuture {
121        self.rate_limiter.acquire(self.endpoint.clone())
122    }
123}
124
125/// [`Timeout<HyperResponseFuture>`] generator.
126struct TimedResponseFutureGenerator {
127    /// HTTP client to send requests from.
128    client: HyperClient<Connector, Full<Bytes>>,
129    /// HTTP request to send.
130    request: Request<Full<Bytes>>,
131    /// Duration after which the request times out.
132    timeout: Duration,
133}
134
135impl TimedResponseFutureGenerator {
136    /// Generates a timeout response future.
137    fn generate(&self) -> Pin<Box<Timeout<HyperResponseFuture>>> {
138        Box::pin(time::timeout(
139            self.timeout,
140            self.client.request(self.request.clone()),
141        ))
142    }
143}
144
145/// Future that completes when a [`Response`] is received.
146///
147/// # Rate limits
148///
149/// Requests that exceed a rate limit are automatically and immediately retried
150/// until they succeed or fail with another error. If configured without a
151/// [`RateLimiter`], care must be taken that an external service intercepts and
152/// delays these retry requests.
153///
154/// # Canceling a response future pre-flight
155///
156/// Response futures can be canceled pre-flight via
157/// [`ResponseFuture::set_pre_flight`]. This allows you to cancel requests that
158/// are no longer necessary once they have been cleared by the ratelimit queue,
159/// which may be necessary in scenarios where requests are being spammed. Refer
160/// to its documentation for more information.
161///
162/// # Errors
163///
164/// Returns an [`ErrorType::Parsing`] error type if the request failed and the
165/// error in the response body could not be deserialized.
166///
167/// Returns an [`ErrorType::RequestCanceled`] error type if the request was
168/// canceled by the user.
169///
170/// Returns an [`ErrorType::RequestError`] error type if creating the request
171/// failed.
172///
173/// Returns an [`ErrorType::RequestTimedOut`] error type if the request timed
174/// out. The timeout value is configured via [`ClientBuilder::timeout`].
175///
176/// Returns an [`ErrorType::Response`] error type if the request failed.
177///
178/// [`ClientBuilder::timeout`]: crate::client::ClientBuilder::timeout
179/// [`ErrorType::Json`]: crate::error::ErrorType::Json
180/// [`ErrorType::Parsing`]: crate::error::ErrorType::Parsing
181/// [`ErrorType::RequestCanceled`]: crate::error::ErrorType::RequestCanceled
182/// [`ErrorType::RequestError`]: crate::error::ErrorType::RequestError
183/// [`ErrorType::RequestTimedOut`]: crate::error::ErrorType::RequestTimedOut
184/// [`ErrorType::Response`]: crate::error::ErrorType::Response
185/// [`Response`]: super::Response
186#[must_use = "futures do nothing unless you `.await` or poll them"]
187pub struct ResponseFuture<T>(Result<Inner<T>, Ready<Error>>);
188
189impl<T> ResponseFuture<T> {
190    pub(crate) fn new(
191        client: HyperClient<Connector, Full<Bytes>>,
192        invalid_token: Option<Arc<AtomicBool>>,
193        request: Request<Full<Bytes>>,
194        span: tracing::Span,
195        timeout: Duration,
196        rate_limiter: Option<RateLimiter>,
197        endpoint: Endpoint,
198    ) -> Self {
199        let permit_generator = rate_limiter.map(|rate_limiter| PermitFutureGenerator {
200            rate_limiter,
201            endpoint,
202        });
203        let response_generator = TimedResponseFutureGenerator {
204            client,
205            request,
206            timeout,
207        };
208        let stage = permit_generator.as_ref().map_or_else(
209            || ResponseStageFuture::Response {
210                fut: response_generator.generate(),
211                permit: None,
212            },
213            |generator| ResponseStageFuture::RateLimitPermit(generator.generate()),
214        );
215        Self(Ok(Inner {
216            invalid_token,
217            permit_generator,
218            phantom: PhantomData,
219            pre_flight_check: None,
220            response_generator,
221            span,
222            stage,
223        }))
224    }
225
226    /// Set a function to call after clearing the ratelimiter but prior to
227    /// sending the request to determine if the request is still valid.
228    ///
229    /// This function will be a no-op if the request has failed, has already
230    /// passed the ratelimiter, or if there is no ratelimiter configured.
231    ///
232    /// Returns whether the pre flight function was set.
233    ///
234    /// # Examples
235    ///
236    /// Delete a message, but immediately before sending the request check if
237    /// the request should still be sent:
238    ///
239    /// ```no_run
240    /// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
241    /// use std::{
242    ///     collections::HashSet,
243    ///     env,
244    ///     future::IntoFuture,
245    ///     sync::{Arc, Mutex},
246    /// };
247    /// use twilight_http::{Client, error::ErrorType};
248    /// use twilight_model::id::Id;
249    ///
250    /// let channel_id = Id::new(1);
251    /// let message_id = Id::new(2);
252    ///
253    /// let channels_ignored = {
254    ///     let mut map = HashSet::new();
255    ///     map.insert(channel_id);
256    ///
257    ///     Arc::new(Mutex::new(map))
258    /// };
259    ///
260    /// let client = Client::new(env::var("DISCORD_TOKEN")?);
261    /// let mut req = client.delete_message(channel_id, message_id).into_future();
262    ///
263    /// let channels_ignored_clone = channels_ignored.clone();
264    /// req.set_pre_flight(move || {
265    ///     // imagine you have some logic here to external state that checks
266    ///     // whether the request should still be performed
267    ///     let channels_ignored = channels_ignored_clone.lock().expect("channels poisoned");
268    ///
269    ///     !channels_ignored.contains(&channel_id)
270    /// });
271    ///
272    /// // the pre-flight check will cancel the request
273    /// assert!(matches!(
274    ///     req.await.unwrap_err().kind(),
275    ///     ErrorType::RequestCanceled,
276    /// ));
277    /// # Ok(()) }
278    /// ```
279    pub fn set_pre_flight<P>(&mut self, predicate: P) -> bool
280    where
281        P: Fn() -> bool + Send + 'static,
282    {
283        if let Ok(inner) = &mut self.0
284            && inner.permit_generator.is_some()
285            && inner.pre_flight_check.is_none()
286        {
287            inner.pre_flight_check = Some(Box::new(predicate));
288
289            true
290        } else {
291            false
292        }
293    }
294
295    /// Creates a future that is immediately ready with an error.
296    pub(crate) fn error(source: Error) -> Self {
297        Self(Err(ready(source)))
298    }
299}
300
301impl<T: Unpin> Future for ResponseFuture<T> {
302    type Output = Result<Response<T>, Error>;
303
304    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
305        let inner = match &mut self.0 {
306            Ok(inner) => inner,
307            Err(err) => return Pin::new(err).poll(cx).map(Err),
308        };
309
310        let _entered = inner.span.enter();
311
312        loop {
313            match &mut inner.stage {
314                ResponseStageFuture::Error { fut, status } => {
315                    let body = ready!(Pin::new(fut).poll(cx)).map_err(|source| Error {
316                        kind: ErrorType::RequestError,
317                        source: Some(Box::new(source)),
318                    })?;
319
320                    return Poll::Ready(Err(match crate::json::from_bytes::<ApiError>(&body) {
321                        Ok(error) => Error {
322                            kind: ErrorType::Response {
323                                body,
324                                error,
325                                status: super::StatusCode::new(status.as_u16()),
326                            },
327                            source: None,
328                        },
329                        Err(source) => Error {
330                            kind: ErrorType::Parsing { body },
331                            source: Some(Box::new(source)),
332                        },
333                    }));
334                }
335                ResponseStageFuture::RateLimitPermit(fut) => {
336                    let permit = ready!(Pin::new(fut).poll(cx));
337                    if inner
338                        .pre_flight_check
339                        .as_ref()
340                        .is_some_and(|check| !check())
341                    {
342                        return Poll::Ready(Err(Error {
343                            kind: ErrorType::RequestCanceled,
344                            source: None,
345                        }));
346                    }
347
348                    inner.stage = ResponseStageFuture::Response {
349                        fut: inner.response_generator.generate(),
350                        permit: Some(permit),
351                    };
352                }
353                ResponseStageFuture::Response { fut, permit } => {
354                    let response = ready!(Pin::new(fut).poll(cx))
355                        .map_err(|source| Error {
356                            kind: ErrorType::RequestTimedOut,
357                            source: Some(Box::new(source)),
358                        })?
359                        .map_err(|source| Error {
360                            kind: ErrorType::RequestError,
361                            source: Some(Box::new(source)),
362                        })?;
363
364                    if response.status() == StatusCode::UNAUTHORIZED
365                        && let Some(invalid) = &inner.invalid_token
366                    {
367                        invalid.store(true, Ordering::Relaxed);
368                    }
369
370                    if let Some(permit) = permit.take() {
371                        match parse_ratelimit_headers(response.headers()) {
372                            Ok(v) => permit.complete(v),
373                            Err(source) => {
374                                tracing::warn!("header parsing failed: {source}; {response:?}");
375
376                                permit.complete(None);
377                            }
378                        }
379                    }
380
381                    if response.status().is_success() {
382                        #[cfg(feature = "decompression")]
383                        let mut response = response;
384                        // Inaccurate since end-users can only access the decompressed body.
385                        #[cfg(feature = "decompression")]
386                        response.headers_mut().remove(header::CONTENT_LENGTH);
387
388                        return Poll::Ready(Ok(Response::new(response)));
389                    } else if response.status() == StatusCode::TOO_MANY_REQUESTS {
390                        inner.stage = match &inner.permit_generator {
391                            Some(generator) => {
392                                ResponseStageFuture::RateLimitPermit(generator.generate())
393                            }
394                            None => ResponseStageFuture::Response {
395                                fut: inner.response_generator.generate(),
396                                permit: None,
397                            },
398                        };
399                    } else {
400                        inner.stage = ResponseStageFuture::Error {
401                            status: response.status(),
402                            fut: Response::<()>::new(response).bytes(),
403                        };
404                    }
405                }
406            }
407        }
408    }
409}
410
411/// Internal response future fields.
412struct Inner<T> {
413    /// Whether the client's token is invalidated.
414    invalid_token: Option<Arc<AtomicBool>>,
415    /// Optional [`PermitFuture`] generator, if registered.
416    permit_generator: Option<PermitFutureGenerator>,
417    phantom: PhantomData<T>,
418    /// Predicate to check after completing [`ResponseStageFuture::RateLimitPermit`].
419    pre_flight_check: Option<Box<dyn Fn() -> bool + Send + 'static>>,
420    /// [`Timeout<HyperResponseFuture>`] generator.
421    response_generator: TimedResponseFutureGenerator,
422    /// This future's span.
423    span: tracing::Span,
424    /// This future's current stage.
425    stage: ResponseStageFuture,
426}