// rama_http/layer/timeout/service.rs

1use super::TimeoutBody;
2use crate::{Request, Response, StatusCode};
3use rama_core::{Context, Layer, Service};
4use rama_utils::macros::define_inner_service_accessors;
5use std::fmt;
6use std::time::Duration;
7
8/// Layer that applies the [`Timeout`] middleware which apply a timeout to requests.
9///
10/// See the [module docs](super) for an example.
11#[derive(Debug, Clone)]
12pub struct TimeoutLayer {
13    timeout: Duration,
14}
15
16impl TimeoutLayer {
17    /// Creates a new [`TimeoutLayer`].
18    pub const fn new(timeout: Duration) -> Self {
19        TimeoutLayer { timeout }
20    }
21}
22
23impl<S> Layer<S> for TimeoutLayer {
24    type Service = Timeout<S>;
25
26    fn layer(&self, inner: S) -> Self::Service {
27        Timeout::new(inner, self.timeout)
28    }
29}
30
31/// Middleware which apply a timeout to requests.
32///
33/// If the request does not complete within the specified timeout it will be aborted and a `408
34/// Request Timeout` response will be sent.
35///
36/// See the [module docs](super) for an example.
37pub struct Timeout<S> {
38    inner: S,
39    timeout: Duration,
40}
41
42impl<S> Timeout<S> {
43    /// Creates a new [`Timeout`].
44    pub const fn new(inner: S, timeout: Duration) -> Self {
45        Self { inner, timeout }
46    }
47
48    define_inner_service_accessors!();
49}
50
51impl<S: fmt::Debug> fmt::Debug for Timeout<S> {
52    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53        f.debug_struct("Timeout")
54            .field("inner", &self.inner)
55            .field("timeout", &self.timeout)
56            .finish()
57    }
58}
59
60impl<S: Clone> Clone for Timeout<S> {
61    fn clone(&self) -> Self {
62        Timeout {
63            inner: self.inner.clone(),
64            timeout: self.timeout,
65        }
66    }
67}
68
69impl<S: Copy> Copy for Timeout<S> {}
70
71impl<S, State, ReqBody, ResBody> Service<State, Request<ReqBody>> for Timeout<S>
72where
73    S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
74    ReqBody: Send + 'static,
75    ResBody: Default + Send + 'static,
76    State: Clone + Send + Sync + 'static,
77{
78    type Response = S::Response;
79    type Error = S::Error;
80
81    async fn serve(
82        &self,
83        ctx: Context<State>,
84        req: Request<ReqBody>,
85    ) -> Result<Self::Response, Self::Error> {
86        tokio::select! {
87            res = self.inner.serve(ctx, req) => res,
88            _ = tokio::time::sleep(self.timeout) => {
89                let mut res = Response::new(ResBody::default());
90                *res.status_mut() = StatusCode::REQUEST_TIMEOUT;
91                Ok(res)
92            }
93        }
94    }
95}
96
97/// Applies a [`TimeoutBody`] to the request body.
98#[derive(Clone, Debug)]
99pub struct RequestBodyTimeoutLayer {
100    timeout: Duration,
101}
102
103impl RequestBodyTimeoutLayer {
104    /// Creates a new [`RequestBodyTimeoutLayer`].
105    pub fn new(timeout: Duration) -> Self {
106        Self { timeout }
107    }
108}
109
110impl<S> Layer<S> for RequestBodyTimeoutLayer {
111    type Service = RequestBodyTimeout<S>;
112
113    fn layer(&self, inner: S) -> Self::Service {
114        RequestBodyTimeout::new(inner, self.timeout)
115    }
116}
117
118/// Applies a [`TimeoutBody`] to the request body.
119#[derive(Clone, Debug)]
120pub struct RequestBodyTimeout<S> {
121    inner: S,
122    timeout: Duration,
123}
124
125impl<S> RequestBodyTimeout<S> {
126    /// Creates a new [`RequestBodyTimeout`].
127    pub fn new(service: S, timeout: Duration) -> Self {
128        Self {
129            inner: service,
130            timeout,
131        }
132    }
133
134    /// Returns a new [`Layer`] that wraps services with a [`RequestBodyTimeoutLayer`] middleware.
135    ///
136    /// [`Layer`]: tower_layer::Layer
137    pub fn layer(timeout: Duration) -> RequestBodyTimeoutLayer {
138        RequestBodyTimeoutLayer::new(timeout)
139    }
140
141    define_inner_service_accessors!();
142}
143
144impl<S, State, ReqBody> Service<State, Request<ReqBody>> for RequestBodyTimeout<S>
145where
146    S: Service<State, Request<TimeoutBody<ReqBody>>>,
147    ReqBody: Send + 'static,
148    State: Clone + Send + Sync + 'static,
149{
150    type Response = S::Response;
151    type Error = S::Error;
152
153    async fn serve(
154        &self,
155        ctx: Context<State>,
156        req: Request<ReqBody>,
157    ) -> Result<Self::Response, Self::Error> {
158        let req = req.map(|body| TimeoutBody::new(self.timeout, body));
159        self.inner.serve(ctx, req).await
160    }
161}
162
163/// Applies a [`TimeoutBody`] to the response body.
164#[derive(Clone)]
165pub struct ResponseBodyTimeoutLayer {
166    timeout: Duration,
167}
168
169impl ResponseBodyTimeoutLayer {
170    /// Creates a new [`ResponseBodyTimeoutLayer`].
171    pub fn new(timeout: Duration) -> Self {
172        Self { timeout }
173    }
174}
175
176impl<S> Layer<S> for ResponseBodyTimeoutLayer {
177    type Service = ResponseBodyTimeout<S>;
178
179    fn layer(&self, inner: S) -> Self::Service {
180        ResponseBodyTimeout::new(inner, self.timeout)
181    }
182}
183
184/// Applies a [`TimeoutBody`] to the response body.
185#[derive(Clone)]
186pub struct ResponseBodyTimeout<S> {
187    inner: S,
188    timeout: Duration,
189}
190
191impl<S, State, ReqBody, ResBody> Service<State, Request<ReqBody>> for ResponseBodyTimeout<S>
192where
193    S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
194    ReqBody: Send + 'static,
195    ResBody: Default + Send + 'static,
196    State: Clone + Send + Sync + 'static,
197{
198    type Response = Response<TimeoutBody<ResBody>>;
199    type Error = S::Error;
200
201    async fn serve(
202        &self,
203        ctx: Context<State>,
204        req: Request<ReqBody>,
205    ) -> Result<Self::Response, Self::Error> {
206        let res = self.inner.serve(ctx, req).await?;
207        let res = res.map(|body| TimeoutBody::new(self.timeout, body));
208        Ok(res)
209    }
210}
211
212impl<S> ResponseBodyTimeout<S> {
213    /// Creates a new [`ResponseBodyTimeout`].
214    pub fn new(service: S, timeout: Duration) -> Self {
215        Self {
216            inner: service,
217            timeout,
218        }
219    }
220
221    /// Returns a new [`Layer`] that wraps services with a [`ResponseBodyTimeoutLayer`] middleware.
222    ///
223    /// [`Layer`]: tower_layer::Layer
224    pub fn layer(timeout: Duration) -> ResponseBodyTimeoutLayer {
225        ResponseBodyTimeoutLayer::new(timeout)
226    }
227
228    define_inner_service_accessors!();
229}