hyperdriver/service/
timeout.rs

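//! Middleware which applies a timeout to requests.
//!
//! Supports a custom error type: the caller supplies a function that builds the
//! error returned for a request whose response does not arrive in time.
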
5use std::time::Duration;
6
7/// Layer to apply a timeout to requests, with a custom error type.
8pub struct TimeoutLayer<E> {
9    error: Box<fn() -> E>,
10    timeout: Duration,
11}
12
13impl<E> TimeoutLayer<E> {
14    /// Create a new `TimeoutLayer` with the provided error function and timeout.
15    pub fn new(error: fn() -> E, timeout: Duration) -> Self {
16        Self {
17            error: Box::new(error),
18            timeout,
19        }
20    }
21}
22
23impl<E> Clone for TimeoutLayer<E> {
24    fn clone(&self) -> Self {
25        Self {
26            error: self.error.clone(),
27            timeout: self.timeout,
28        }
29    }
30}
31
32impl<E> std::fmt::Debug for TimeoutLayer<E> {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        f.debug_struct("TimeoutLayer")
35            .field("timeout", &self.timeout)
36            .finish()
37    }
38}
39
40impl<S, E> tower::layer::Layer<S> for TimeoutLayer<E> {
41    type Service = Timeout<S, E>;
42
43    fn layer(&self, inner: S) -> Self::Service {
44        Timeout::new(inner, self.timeout, self.error.clone())
45    }
46}
47
48/// Applies a timeout to requests, with a custom error type.
49pub struct Timeout<S, E> {
50    inner: S,
51    timeout: Duration,
52    error: Box<fn() -> E>,
53}
54
55impl<S, E> Timeout<S, E> {
56    /// Create a new `Timeout` with the provided inner service, timeout, and error function.
57    pub fn new(inner: S, timeout: Duration, error: Box<fn() -> E>) -> Self {
58        Self {
59            inner,
60            timeout,
61            error,
62        }
63    }
64}
65
66impl<S, E> Clone for Timeout<S, E>
67where
68    S: Clone,
69{
70    fn clone(&self) -> Self {
71        Self {
72            inner: self.inner.clone(),
73            timeout: self.timeout,
74            error: self.error.clone(),
75        }
76    }
77}
78
79impl<S, E> std::fmt::Debug for Timeout<S, E>
80where
81    S: std::fmt::Debug,
82{
83    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84        f.debug_struct("Timeout")
85            .field("inner", &self.inner)
86            .field("timeout", &self.timeout)
87            .finish()
88    }
89}
90
91impl<S, E, Req> tower::Service<Req> for Timeout<S, E>
92where
93    S: tower::Service<Req, Error = E>,
94{
95    type Response = S::Response;
96    type Error = E;
97    type Future = self::future::TimeoutFuture<S::Future, S::Response, E>;
98
99    fn poll_ready(
100        &mut self,
101        cx: &mut std::task::Context<'_>,
102    ) -> std::task::Poll<Result<(), Self::Error>> {
103        self.inner.poll_ready(cx).map_err(Into::into)
104    }
105
106    fn call(&mut self, req: Req) -> Self::Future {
107        self::future::TimeoutFuture::new(self.inner.call(req), self.error.clone(), self.timeout)
108    }
109}
110
111mod future {
112
113    use std::{future::Future, marker::PhantomData, task::Poll};
114
115    use pin_project::pin_project;
116
117    #[derive(Debug)]
118    #[pin_project]
119    pub struct TimeoutFuture<F, R, E> {
120        #[pin]
121        inner: F,
122        error: Box<fn() -> E>,
123        response: PhantomData<fn() -> R>,
124
125        #[pin]
126        timeout: tokio::time::Sleep,
127    }
128
129    impl<F, R, E> TimeoutFuture<F, R, E> {
130        pub fn new(inner: F, error: Box<fn() -> E>, timeout: std::time::Duration) -> Self {
131            Self {
132                inner,
133                error,
134                response: PhantomData,
135                timeout: tokio::time::sleep(timeout),
136            }
137        }
138    }
139
140    impl<F, R, E> Future for TimeoutFuture<F, R, E>
141    where
142        F: Future<Output = Result<R, E>>,
143    {
144        type Output = Result<R, E>;
145
146        fn poll(
147            self: std::pin::Pin<&mut Self>,
148            cx: &mut std::task::Context<'_>,
149        ) -> Poll<Self::Output> {
150            let this = self.project();
151
152            match this.inner.poll(cx) {
153                Poll::Ready(response) => return Poll::Ready(response),
154                Poll::Pending => {}
155            }
156
157            match this.timeout.poll(cx) {
158                Poll::Ready(()) => Poll::Ready(Err((this.error)())),
159                Poll::Pending => Poll::Pending,
160            }
161        }
162    }
163}