graphql_starter/
timeout.rs1use std::{
5 future::Future,
6 pin::Pin,
7 task::{Context, Poll},
8 time::Duration,
9};
10
11use axum::response::{IntoResponse, Response};
12use error_info::ErrorInfo;
13use http::Request;
14use pin_project_lite::pin_project;
15use tokio::time::Sleep;
16use tower::{Layer, Service};
17
18use crate::error::{ApiError, Error};
19
/// Tower [`Layer`] that applies a per-request timeout to the services it wraps.
///
/// When a request is not answered within `timeout`, the caller receives the error
/// described by `response` (converted through `ApiError`) instead of the inner reply.
///
/// Trait bounds live on the `impl` blocks that need them rather than on the struct,
/// matching [`Timeout`].
#[derive(Debug, Clone, Copy)]
pub struct TimeoutLayer<T> {
    /// Maximum time a single request may take.
    timeout: Duration,
    /// Error reported when the timeout elapses.
    response: T,
}
28
29impl<T> TimeoutLayer<T>
30where
31 T: ErrorInfo + Send + Sync + Copy + 'static,
32{
33 pub fn new(timeout: Duration, response: T) -> Self {
35 TimeoutLayer { timeout, response }
36 }
37}
38
39impl<T, S> Layer<S> for TimeoutLayer<T>
40where
41 T: ErrorInfo + Send + Sync + Copy + 'static,
42{
43 type Service = Timeout<S, T>;
44
45 fn layer(&self, inner: S) -> Self::Service {
46 Timeout::new(inner, self.timeout, self.response)
47 }
48}
49
/// Middleware service produced by [`TimeoutLayer`].
///
/// It races every call to `inner` against a timer of length `timeout`.
#[derive(Debug, Clone, Copy)]
pub struct Timeout<S, T> {
    /// The wrapped service.
    inner: S,
    /// How long a single call may run before the timeout reply is produced instead.
    timeout: Duration,
    /// Error value used to build the timeout reply.
    response: T,
}
62
63impl<S, T> Timeout<S, T>
64where
65 T: ErrorInfo + Send + Sync + Copy + 'static,
66{
67 pub fn new(inner: S, timeout: Duration, response: T) -> Self {
69 Self {
70 inner,
71 timeout,
72 response,
73 }
74 }
75}
76
77impl<S, T, ReqBody> Service<Request<ReqBody>> for Timeout<S, T>
78where
79 S: Service<Request<ReqBody>, Response = Response>,
80 T: ErrorInfo + Send + Sync + Copy + 'static,
81{
82 type Error = S::Error;
83 type Future = ResponseFuture<S::Future, T>;
84 type Response = S::Response;
85
86 #[inline]
87 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
88 self.inner.poll_ready(cx)
89 }
90
91 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
92 let sleep = tokio::time::sleep(self.timeout);
93 ResponseFuture {
94 inner: self.inner.call(req),
95 sleep,
96 response: self.response,
97 }
98 }
99}
100
101pin_project! {
102 pub struct ResponseFuture<F,T> {
104 #[pin]
105 inner: F,
106 #[pin]
107 sleep: Sleep,
108 #[pin]
109 response: T,
110 }
111}
112
113impl<F, T, E> Future for ResponseFuture<F, T>
114where
115 F: Future<Output = Result<Response, E>>,
116 T: ErrorInfo + Send + Sync + Copy + 'static,
117{
118 type Output = Result<Response, E>;
119
120 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
121 let this = self.project();
122
123 if this.sleep.poll(cx).is_ready() {
124 let err = ApiError::from_err(Error::new(*this.response));
125 return Poll::Ready(Ok(err.into_response()));
126 }
127
128 this.inner.poll(cx)
129 }
130}