graphql_starter/
timeout.rs

//! Based on <https://github.com/tower-rs/tower-http/blob/main/tower-http/src/timeout/service.rs>,
//! but allows customizing the response returned when a request times out.
3
4use std::{
5    future::Future,
6    pin::Pin,
7    task::{Context, Poll},
8    time::Duration,
9};
10
11use axum::response::{IntoResponse, Response};
12use error_info::ErrorInfo;
13use http::Request;
14use pin_project_lite::pin_project;
15use tokio::time::Sleep;
16use tower::{Layer, Service};
17
18use crate::error::{ApiError, Error};
19
/// Layer that applies the [`Timeout`] middleware, which applies a timeout to requests.
///
/// Unlike a fixed timeout response, the response sent when the timeout elapses is built from the
/// configured `response` error value.
///
/// See the [module docs](super) for an example.
// Trait bounds live on the `impl` blocks that need them rather than on the struct definition,
// matching `Timeout<S, T>` below.
#[derive(Debug, Clone, Copy)]
pub struct TimeoutLayer<T> {
    /// Maximum time a request's response future may take.
    timeout: Duration,
    /// Error value converted into the response when the timeout elapses.
    response: T,
}
28
29impl<T> TimeoutLayer<T>
30where
31    T: ErrorInfo + Send + Sync + Copy + 'static,
32{
33    /// Creates a new [`TimeoutLayer`].
34    pub fn new(timeout: Duration, response: T) -> Self {
35        TimeoutLayer { timeout, response }
36    }
37}
38
39impl<T, S> Layer<S> for TimeoutLayer<T>
40where
41    T: ErrorInfo + Send + Sync + Copy + 'static,
42{
43    type Service = Timeout<S, T>;
44
45    fn layer(&self, inner: S) -> Self::Service {
46        Timeout::new(inner, self.timeout, self.response)
47    }
48}
49
/// Middleware which applies a timeout to requests.
///
/// If the inner service's response future does not complete within the configured timeout, the
/// request is aborted and a response built from the configured `T` error (via `ApiError`) is
/// returned instead.
///
/// Only the future returned by `call` is timed; `poll_ready` is forwarded to the inner service
/// unchanged.
///
/// See the [module docs](super) for an example.
#[derive(Debug, Clone, Copy)]
pub struct Timeout<S, T> {
    /// The wrapped service.
    inner: S,
    /// Maximum time allowed for the inner service's response future.
    timeout: Duration,
    /// Error value converted into the response when the timeout elapses.
    response: T,
}
62
63impl<S, T> Timeout<S, T>
64where
65    T: ErrorInfo + Send + Sync + Copy + 'static,
66{
67    /// Creates a new [`Timeout`].
68    pub fn new(inner: S, timeout: Duration, response: T) -> Self {
69        Self {
70            inner,
71            timeout,
72            response,
73        }
74    }
75}
76
77impl<S, T, ReqBody> Service<Request<ReqBody>> for Timeout<S, T>
78where
79    S: Service<Request<ReqBody>, Response = Response>,
80    T: ErrorInfo + Send + Sync + Copy + 'static,
81{
82    type Error = S::Error;
83    type Future = ResponseFuture<S::Future, T>;
84    type Response = S::Response;
85
86    #[inline]
87    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
88        self.inner.poll_ready(cx)
89    }
90
91    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
92        let sleep = tokio::time::sleep(self.timeout);
93        ResponseFuture {
94            inner: self.inner.call(req),
95            sleep,
96            response: self.response,
97        }
98    }
99}
100
101pin_project! {
102    /// Response future for [`Timeout`].
103    pub struct ResponseFuture<F,T> {
104        #[pin]
105        inner: F,
106        #[pin]
107        sleep: Sleep,
108        #[pin]
109        response: T,
110    }
111}
112
113impl<F, T, E> Future for ResponseFuture<F, T>
114where
115    F: Future<Output = Result<Response, E>>,
116    T: ErrorInfo + Send + Sync + Copy + 'static,
117{
118    type Output = Result<Response, E>;
119
120    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
121        let this = self.project();
122
123        if this.sleep.poll(cx).is_ready() {
124            let err = ApiError::from_err(Error::new(*this.response));
125            return Poll::Ready(Ok(err.into_response()));
126        }
127
128        this.inner.poll(cx)
129    }
130}