// Timeout middleware for `tower` services (hyperdriver/service/timeout.rs).

use std::time::Duration;

/// A [`tower::layer::Layer`] that wraps services in [`Timeout`].
///
/// Every request sent through a wrapped service is given `timeout` to
/// complete; if it is still pending when the timer elapses, it resolves to
/// the error built by the `error` constructor.
pub struct TimeoutLayer<E> {
    /// Builds the error returned for a timed-out request.
    ///
    /// NOTE(review): a `fn` pointer is already `Copy` and pointer-sized, so the
    /// `Box` only adds an allocation; it is kept because `Timeout::new`
    /// accepts a `Box<fn() -> E>`.
    error: Box<fn() -> E>,
    /// How long each request may stay pending.
    timeout: Duration,
}

impl<E> TimeoutLayer<E> {
    /// Creates a layer whose services fail a request with `error()` when it
    /// has not completed within `timeout`.
    pub fn new(error: fn() -> E, timeout: Duration) -> Self {
        Self {
            error: Box::new(error),
            timeout,
        }
    }
}

// Written by hand: `#[derive(Clone)]` would add an `E: Clone` bound, but only
// the function pointer and the `Duration` are cloned, neither of which needs it.
impl<E> Clone for TimeoutLayer<E> {
    fn clone(&self) -> Self {
        Self {
            error: self.error.clone(),
            timeout: self.timeout,
        }
    }
}

// Only the timeout is printed; the error constructor is omitted.
impl<E> std::fmt::Debug for TimeoutLayer<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TimeoutLayer")
            .field("timeout", &self.timeout)
            .finish()
    }
}

40impl<S, E> tower::layer::Layer<S> for TimeoutLayer<E> {
41 type Service = Timeout<S, E>;
42
43 fn layer(&self, inner: S) -> Self::Service {
44 Timeout::new(inner, self.timeout, self.error.clone())
45 }
46}
47
/// Service middleware that fails a request with a caller-supplied error when
/// the inner service does not respond within a fixed duration.
///
/// Can be built directly with [`Timeout::new`] or through [`TimeoutLayer`].
pub struct Timeout<S, E> {
    /// The wrapped service.
    inner: S,
    /// How long each request may stay pending.
    timeout: Duration,
    /// Builds the error returned for a timed-out request.
    error: Box<fn() -> E>,
}

impl<S, E> Timeout<S, E> {
    /// Wraps `inner` so that each request fails with `error()` if it has not
    /// completed within `timeout`.
    pub fn new(inner: S, timeout: Duration, error: Box<fn() -> E>) -> Self {
        Self {
            inner,
            timeout,
            error,
        }
    }
}

// Written by hand so that only `S` must be `Clone`; `#[derive(Clone)]` would
// also demand `E: Clone`.
impl<S, E> Clone for Timeout<S, E>
where
    S: Clone,
{
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            timeout: self.timeout,
            error: self.error.clone(),
        }
    }
}

// Prints the inner service and the timeout; the error constructor is omitted.
impl<S, E> std::fmt::Debug for Timeout<S, E>
where
    S: std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Timeout")
            .field("inner", &self.inner)
            .field("timeout", &self.timeout)
            .finish()
    }
}

91impl<S, E, Req> tower::Service<Req> for Timeout<S, E>
92where
93 S: tower::Service<Req, Error = E>,
94{
95 type Response = S::Response;
96 type Error = E;
97 type Future = self::future::TimeoutFuture<S::Future, S::Response, E>;
98
99 fn poll_ready(
100 &mut self,
101 cx: &mut std::task::Context<'_>,
102 ) -> std::task::Poll<Result<(), Self::Error>> {
103 self.inner.poll_ready(cx).map_err(Into::into)
104 }
105
106 fn call(&mut self, req: Req) -> Self::Future {
107 self::future::TimeoutFuture::new(self.inner.call(req), self.error.clone(), self.timeout)
108 }
109}
110
111mod future {
112
113 use std::{future::Future, marker::PhantomData, task::Poll};
114
115 use pin_project::pin_project;
116
117 #[derive(Debug)]
118 #[pin_project]
119 pub struct TimeoutFuture<F, R, E> {
120 #[pin]
121 inner: F,
122 error: Box<fn() -> E>,
123 response: PhantomData<fn() -> R>,
124
125 #[pin]
126 timeout: tokio::time::Sleep,
127 }
128
129 impl<F, R, E> TimeoutFuture<F, R, E> {
130 pub fn new(inner: F, error: Box<fn() -> E>, timeout: std::time::Duration) -> Self {
131 Self {
132 inner,
133 error,
134 response: PhantomData,
135 timeout: tokio::time::sleep(timeout),
136 }
137 }
138 }
139
140 impl<F, R, E> Future for TimeoutFuture<F, R, E>
141 where
142 F: Future<Output = Result<R, E>>,
143 {
144 type Output = Result<R, E>;
145
146 fn poll(
147 self: std::pin::Pin<&mut Self>,
148 cx: &mut std::task::Context<'_>,
149 ) -> Poll<Self::Output> {
150 let this = self.project();
151
152 match this.inner.poll(cx) {
153 Poll::Ready(response) => return Poll::Ready(response),
154 Poll::Pending => {}
155 }
156
157 match this.timeout.poll(cx) {
158 Poll::Ready(()) => Poll::Ready(Err((this.error)())),
159 Poll::Pending => Poll::Pending,
160 }
161 }
162 }
163}