1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
//! This is a modified `Timeout` service, copied from the tower codebase.
//! Unlike the original, this `Timeout` also enforces the timeout while waiting in
//! `poll_ready`, not only during `call`.
//! Middleware that applies a timeout to requests.
//!
//! If the response does not complete within the specified timeout, the response
//! will be aborted.

pub(crate) mod error;
pub(crate) mod future;
mod layer;

use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use std::time::Duration;

use futures::Future;
use tokio::time::Sleep;
use tower::Service;

use self::future::ResponseFuture;
pub(crate) use self::layer::TimeoutLayer;
pub(crate) use crate::plugins::traffic_shaping::timeout::error::Elapsed;

/// Middleware that enforces a time limit on requests to the wrapped service.
///
/// The limit is armed when readiness is first polled, so time spent waiting
/// for the inner service to become ready counts against it as well.
#[derive(Debug)]
pub(crate) struct Timeout<T> {
    /// The wrapped service.
    inner: T,
    /// Time budget for a request, measured from the first `poll_ready`.
    timeout: Duration,
    /// Timer armed by `poll_ready`; `None` until armed and after it is taken.
    sleep: Option<Pin<Box<Sleep>>>,
}

// ===== impl Timeout =====

impl<T> Timeout<T> {
    /// Wraps `inner` so that its requests are limited to `timeout`.
    ///
    /// No timer is started here; one is created lazily by `poll_ready`.
    pub(crate) fn new(inner: T, timeout: Duration) -> Self {
        Self {
            inner,
            timeout,
            sleep: None,
        }
    }
}

impl<S, Request> Service<Request> for Timeout<S>
where
    S: Service<Request>,
    S::Error: Into<tower::BoxError>,
{
    type Response = S::Response;
    type Error = tower::BoxError;
    type Future = ResponseFuture<S::Future>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        if self.sleep.is_none() {
            self.sleep = Some(Box::pin(tokio::time::sleep(self.timeout)));
        }
        match self.inner.poll_ready(cx) {
            Poll::Pending => {}
            Poll::Ready(r) => return Poll::Ready(r.map_err(Into::into)),
        };

        // Checking if we don't timeout on `poll_ready`
        if Pin::new(
            &mut self
                .sleep
                .as_mut()
                .expect("we can unwrap because we set it just before"),
        )
        .poll(cx)
        .is_ready()
        {
            tracing::trace!("timeout exceeded.");
            self.sleep = None;

            return Poll::Ready(Err(Elapsed::new().into()));
        }

        Poll::Pending
    }

    fn call(&mut self, request: Request) -> Self::Future {
        let response = self.inner.call(request);

        ResponseFuture::new(
            response,
            self.sleep
                .take()
                .expect("poll_ready must been called before"),
        )
    }
}