1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
//! Middleware for controlling request timeouts.
//!
//! If a request does not complete within the specified timeout, it is aborted and, by default, a
//! `503 Service Unavailable` response is sent.
//!
//! This middleware can be used to mitigate slow network attacks.
//!
//! # Example
//!
//! ```no_run
//! use std::time::Duration;
//!
//! use salvo_core::prelude::*;
//! use salvo_extra::timeout::Timeout;
//!
//! #[handler]
//! async fn fast() -> &'static str {
//!     "hello"
//! }
//! #[handler]
//! async fn slow() -> &'static str {
//!     tokio::time::sleep(Duration::from_secs(6)).await;
//!     "hello"
//! }
//!
//! #[tokio::main]
//! async fn main() {
//!     let router = Router::new()
//!         .hoop(Timeout::new(Duration::from_secs(5)))
//!         .push(Router::with_path("slow").get(slow))
//!         .push(Router::with_path("fast").get(fast));
//!
//!     let acceptor = TcpListener::new("0.0.0.0:5800").bind().await;
//!     Server::new(acceptor).serve(router).await;
//! }
//! ```

use std::time::Duration;

use salvo_core::http::headers::{Connection, HeaderMapExt};
use salvo_core::http::{Request, Response, StatusError};
use salvo_core::{async_trait, Depot, FlowCtrl, Handler};

/// Middleware for controlling request timeout.
/// 
/// View [module level documentation](index.html) for more details.
pub struct Timeout {
    value: Duration,
    error: Box<dyn Fn() -> StatusError + Send + Sync + 'static>,
}
impl Timeout {
    /// Create a new `Timeout` that aborts the remaining handlers after `value` has elapsed.
    ///
    /// By default, a `503 Service Unavailable` error is rendered when the timeout fires; use
    /// [`Timeout::error`] to customize it.
    #[inline]
    pub fn new(value: Duration) -> Self {
        // 503 is used instead of 408: if a 408 error code is returned, the browser may resend the
        // request multiple times. In most cases, this behavior is undesirable.
        // https://github.com/tower-rs/tower-http/issues/300
        Timeout {
            value,
            error: Box::new(|| StatusError::service_unavailable().brief("Server process the request timeout.")),
        }
    }

    /// Customize the error returned when the timeout elapses.
    ///
    /// By default, a `503 Service Unavailable` error is returned. You can set this function to
    /// produce other errors, such as `408 Request Timeout`, but be aware that the 408 status
    /// code may cause the browser to automatically resend the request multiple times.
    ///
    /// The function is invoked each time a request times out.
    #[must_use]
    pub fn error(mut self, error: impl Fn() -> StatusError + Send + Sync + 'static) -> Self {
        self.error = Box::new(error);
        self
    }
}
#[async_trait]
impl Handler for Timeout {
    /// Run the remaining handlers, aborting them if they exceed the configured timeout.
    ///
    /// On timeout, the in-flight handler chain is dropped, a `Connection: close` header is
    /// inserted, the configured error is rendered into the response, and the rest of the flow
    /// is skipped.
    #[inline]
    async fn handle(&self, req: &mut Request, depot: &mut Depot, res: &mut Response, ctrl: &mut FlowCtrl) {
        tokio::select! {
            // Poll the handler chain before the timer. Without `biased;`, `select!` picks a
            // random ready branch, so a handler that completes in the same wake-up as the
            // deadline could have its finished response replaced by the timeout error.
            biased;
            _ = ctrl.call_next(req, depot, res) => {},
            _ = tokio::time::sleep(self.value) => {
                res.headers_mut().typed_insert(Connection::close());
                res.render((self.error)());
                ctrl.skip_rest();
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};

    use super::*;

    #[tokio::test]
    async fn test_timeout_handler() {
        #[handler]
        async fn fast() -> &'static str {
            "hello"
        }
        #[handler]
        async fn slow() -> &'static str {
            tokio::time::sleep(Duration::from_secs(6)).await;
            "hello"
        }

        let service = Service::new(
            Router::new()
                .hoop(Timeout::new(Duration::from_secs(5)))
                .push(Router::with_path("slow").get(slow))
                .push(Router::with_path("fast").get(fast)),
        );

        // `slow` overruns the 5s deadline and must yield the timeout error body;
        // `fast` finishes in time and must yield its own body.
        let cases = [("slow", "timeout"), ("fast", "hello")];
        for (route, needle) in cases {
            let url = format!("http://127.0.0.1:5801/{route}");
            let body = TestClient::get(url.as_str())
                .send(&service)
                .await
                .take_string()
                .await
                .unwrap();
            assert!(
                body.contains(needle),
                "GET /{route} returned {body:?}, expected it to contain {needle:?}"
            );
        }
    }
}