use rustapi_core::{
middleware::BoxedNext, middleware::MiddlewareLayer, Request, Response, ResponseBody,
};
use std::future::Future;
use std::pin::Pin;
use std::time::Duration;
/// Middleware that bounds how long the downstream middleware/handler chain may
/// take to produce a response.
///
/// If the chain has not produced a response when the configured duration
/// elapses, its pending future is dropped and a `408 Request Timeout` JSON
/// error response is returned instead.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TimeoutLayer {
    /// Maximum time the downstream chain is given to produce a response.
    timeout: Duration,
}
impl TimeoutLayer {
    /// Creates a layer that gives the downstream chain at most `timeout` to respond.
    #[must_use]
    pub const fn new(timeout: Duration) -> Self {
        Self { timeout }
    }

    /// Creates a layer whose timeout is `secs` whole seconds.
    #[must_use]
    pub const fn from_secs(secs: u64) -> Self {
        Self::new(Duration::from_secs(secs))
    }

    /// Creates a layer whose timeout is `millis` milliseconds.
    #[must_use]
    pub const fn from_millis(millis: u64) -> Self {
        Self::new(Duration::from_millis(millis))
    }

    /// Returns the configured timeout.
    #[must_use]
    pub const fn timeout(&self) -> Duration {
        self.timeout
    }
}
impl MiddlewareLayer for TimeoutLayer {
fn call(
&self,
req: Request,
next: BoxedNext,
) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
let timeout = self.timeout;
Box::pin(async move {
match tokio::time::timeout(timeout, next(req)).await {
Ok(response) => response,
Err(_) => {
http::Response::builder()
.status(408)
.header("Content-Type", "application/json")
.body(ResponseBody::Full(http_body_util::Full::new(bytes::Bytes::from(
serde_json::json!({
"error": {
"type": "request_timeout",
"message": format!("Request exceeded timeout of {}ms", timeout.as_millis())
}
})
.to_string(),
))))
.unwrap()
}
}
})
}
fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
Box::new(self.clone())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use rustapi_core::middleware::MiddlewareLayer;
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::time::sleep;

    /// Downstream handler that waits `delay` before answering `200 OK`.
    fn delayed_ok(delay: Duration) -> BoxedNext {
        Arc::new(move |_req: Request| {
            Box::pin(async move {
                sleep(delay).await;
                http::Response::builder()
                    .status(200)
                    .body(ResponseBody::Full(http_body_util::Full::new(
                        Bytes::from("OK"),
                    )))
                    .unwrap()
            }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
        })
    }

    /// A bodiless `GET /` request.
    fn get_root() -> Request {
        let req = http::Request::builder()
            .method("GET")
            .uri("/")
            .body(())
            .unwrap();
        Request::from_http_request(req, Bytes::new())
    }

    #[tokio::test]
    async fn timeout_fires_on_slow_request() {
        let layer = TimeoutLayer::from_millis(100);
        let response = layer
            .call(get_root(), delayed_ok(Duration::from_millis(200)))
            .await;
        assert_eq!(response.status(), 408);
        assert_eq!(response.headers()["content-type"], "application/json");
    }

    #[tokio::test]
    async fn timeout_allows_fast_request() {
        let layer = TimeoutLayer::from_millis(200);
        let response = layer
            .call(get_root(), delayed_ok(Duration::from_millis(50)))
            .await;
        assert_eq!(response.status(), 200);
    }

    #[test]
    fn constructors_agree() {
        assert_eq!(TimeoutLayer::from_secs(2).timeout, Duration::from_secs(2));
        assert_eq!(
            TimeoutLayer::from_millis(1500).timeout,
            Duration::from_millis(1500)
        );
        assert_eq!(
            TimeoutLayer::new(Duration::from_millis(7)).timeout,
            Duration::from_millis(7)
        );
    }
}