use std::{
future::Future,
marker::PhantomData,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Duration,
};
use axum::{
body::Body,
http::{header::HeaderValue, Request, Response, StatusCode},
response::IntoResponse,
};
use iso8601_duration::Duration as IsoDuration;
use pin_project::pin_project;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::{
task::futures::TaskLocalFuture,
time::{sleep_until, Instant, Sleep},
};
use tower::{BoxError, Layer, Service};
use tracing::warn;
use crate::{errors, layers::ext::Deadline};
tokio::task_local! {
    /// Deadline of the request whose response future is currently being polled, if any.
    ///
    /// `TimeoutService::call` scopes the inner service's response future with this value
    /// (`None` when the request has no deadline), so code polled inside that future can
    /// read it without access to the request itself.
    pub static CURRENT_DEADLINE: Option<Deadline>;
}
/// Errors produced by the timeout middleware in this module.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum TimeoutError {
    /// The request's deadline passed before the wrapped service produced a response.
    #[error("Request timed out")]
    TimedOut,
}
impl TimeoutError {
fn http_status(&self) -> StatusCode {
match self {
Self::TimedOut => StatusCode::GATEWAY_TIMEOUT,
}
}
}
impl IntoResponse for TimeoutError {
    /// Renders the error through the `problemdetails` crate.
    ///
    /// The status comes from `http_status`, the problem type is
    /// `errors::TAG_UXUM_TIMEOUT`, and the title is the error's `Display` text.
    fn into_response(self) -> Response<Body> {
        let status = self.http_status();
        let title = self.to_string();
        let problem = problemdetails::new(status)
            .with_type(errors::TAG_UXUM_TIMEOUT)
            .with_title(title);
        problem.into_response()
    }
}
/// Configuration for the per-handler timeout middleware.
///
/// A [`TimeoutLayer`] built from it (see `make_layer`) computes each request's deadline
/// from the `X-Timeout` header and/or the default timeout; `get_deadline` holds the exact
/// rules.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[non_exhaustive]
pub struct HandlerTimeoutConfig {
    /// Whether to read the request timeout from the `X-Timeout` header (parsed as an
    /// ISO 8601 duration).
    ///
    /// When the field is omitted during deserialization, its value comes from
    /// `crate::util::default_true`; the `Default` impl sets it to `true`.
    #[serde(default = "crate::util::default_true")]
    pub use_x_timeout: bool,
    /// Timeout used when the `X-Timeout` header is absent, disabled, or fails to parse.
    ///
    /// Also accepted under the key `timeout`; (de)serialized with `humantime_serde`.
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        alias = "timeout",
        with = "humantime_serde"
    )]
    pub default_timeout: Option<Duration>,
    /// Lower bound for the request timeout, checked by `get_deadline`.
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        with = "humantime_serde"
    )]
    pub min_timeout: Option<Duration>,
    /// Upper bound for the request timeout, checked by `get_deadline`.
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        with = "humantime_serde"
    )]
    pub max_timeout: Option<Duration>,
}
impl Default for HandlerTimeoutConfig {
    /// `X-Timeout` handling enabled; no default, minimum or maximum timeout.
    fn default() -> Self {
        Self {
            max_timeout: None,
            min_timeout: None,
            default_timeout: None,
            use_x_timeout: true,
        }
    }
}
impl HandlerTimeoutConfig {
pub fn is_default(&self) -> bool {
*self == Self::default()
}
pub fn make_layer<S>(&self) -> Option<TimeoutLayer<S>> {
if self.use_x_timeout || self.default_timeout.is_some() {
Some(self.into())
} else {
None
}
}
pub fn get_deadline(&self, timeout_header: Option<&HeaderValue>) -> Option<Instant> {
if self.use_x_timeout && timeout_header.is_some() {
timeout_header.and_then(|h| match h.to_str() {
Ok(s) => match s.parse::<IsoDuration>() {
Ok(d) => d.to_std(),
Err(error) => {
warn!(?error, "unable to parse X-Timeout");
self.default_timeout
}
},
Err(error) => {
warn!(%error, "invalid X-Timeout value");
self.default_timeout
}
})
} else {
self.default_timeout
}
.and_then(|dur| match self.min_timeout {
Some(min) if min > dur => {
warn!(?dur, "duration is shorter than minimum");
None
}
_ => match self.max_timeout {
Some(max) if max < dur => {
warn!(?dur, "duration is longer than maximum");
None
}
_ => Some(Instant::now() + dur),
},
})
}
}
/// Tower [`Layer`] that wraps services in a [`TimeoutService`] configured from a
/// [`HandlerTimeoutConfig`].
///
/// The type parameter `S` only names the service type being wrapped; no `S` value is
/// ever stored.
pub struct TimeoutLayer<S> {
    config: HandlerTimeoutConfig,
    // `fn() -> S` rather than `S`: the layer never owns an `S`, so its auto traits
    // (`Send`/`Sync`) and drop-check behavior should not depend on `S`.
    _phantom_service: PhantomData<fn() -> S>,
}

// Implemented by hand because `#[derive(Clone)]` would add an unnecessary `S: Clone`
// bound. Layers are commonly required to be `Clone` by tower/axum layering APIs.
impl<S> Clone for TimeoutLayer<S> {
    fn clone(&self) -> Self {
        Self {
            config: self.config.clone(),
            _phantom_service: PhantomData,
        }
    }
}
impl<S> From<&HandlerTimeoutConfig> for TimeoutLayer<S> {
fn from(value: &HandlerTimeoutConfig) -> Self {
Self {
config: value.clone(),
_phantom_service: PhantomData,
}
}
}
impl<S> Layer<S> for TimeoutLayer<S>
where
    S: Service<Request<Body>>,
{
    type Service = TimeoutService<S>;

    /// Wraps `inner` in a [`TimeoutService`] that uses this layer's configuration.
    fn layer(&self, inner: S) -> Self::Service {
        let config = &self.config;
        TimeoutService::new(inner, config)
    }
}
/// Service produced by [`TimeoutLayer`].
///
/// It computes a per-request deadline and enforces it on the inner service's response
/// future.
#[derive(Clone, Debug)]
pub struct TimeoutService<S> {
    /// Configuration shared between clones of this service via `Arc`.
    config: Arc<HandlerTimeoutConfig>,
    /// The wrapped service.
    inner: S,
}
/// Name of the request header carrying a per-request timeout.
///
/// `HandlerTimeoutConfig::get_deadline` parses its value as an ISO 8601 duration.
pub(crate) const X_TIMEOUT: &str = "x-timeout";
impl<S> Service<Request<Body>> for TimeoutService<S>
where
    S: Service<Request<Body>>,
    S::Error: Into<BoxError>,
{
    type Response = S::Response;
    type Error = BoxError;
    type Future = TimeoutFuture<S::Future>;

    /// Delegates readiness to the wrapped service, boxing its error.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx).map_err(Into::into)
    }

    /// Computes the request deadline and stores it as a request extension when present.
    ///
    /// Returns a future that polls the inner response future inside a
    /// `CURRENT_DEADLINE` scope and enforces the deadline, if any.
    fn call(&mut self, mut req: Request<Body>) -> Self::Future {
        let timeout_header = req.headers().get(X_TIMEOUT);
        let deadline_at = self.config.get_deadline(timeout_header);
        let scoped = deadline_at.map(Deadline::from);
        if let Some(extension) = scoped {
            req.extensions_mut().insert(extension);
        }
        // The inner `call` itself runs outside the task-local scope; only polling the
        // returned future sees `CURRENT_DEADLINE`.
        let response = self.inner.call(req);
        TimeoutFuture::new(CURRENT_DEADLINE.scope(scoped, response), deadline_at)
    }
}
impl<S> TimeoutService<S> {
#[must_use]
pub fn new(inner: S, config: &HandlerTimeoutConfig) -> Self {
Self {
config: Arc::new(config.clone()),
inner,
}
}
}
/// Response future of [`TimeoutService`].
///
/// Both variants poll an inner future built with `CURRENT_DEADLINE.scope`. `Bounded`
/// additionally polls a timer and resolves to an error built from
/// [`TimeoutError::TimedOut`] when the timer completes while the inner future is still
/// pending. `Unbounded` just forwards the inner result.
// The pin projection type is named `Type`; the `Future` impl matches on it.
#[pin_project(project = Type)]
#[derive(Debug)]
pub enum TimeoutFuture<F> {
    /// Request with a deadline.
    Bounded {
        /// Inner response future wrapped in a `CURRENT_DEADLINE` scope.
        #[pin]
        inner: TaskLocalFuture<Option<Deadline>, F>,
        /// Timer that completes at the deadline.
        #[pin]
        sleep: Sleep,
    },
    /// Request without a deadline.
    Unbounded {
        /// Inner response future wrapped in a `CURRENT_DEADLINE` scope.
        #[pin]
        inner: TaskLocalFuture<Option<Deadline>, F>,
    },
}
impl<F, U, E> Future for TimeoutFuture<F>
where
    F: Future<Output = Result<U, E>>,
    E: Into<BoxError>,
{
    type Output = Result<U, BoxError>;

    /// Polls the inner future and, for bounded requests, the deadline timer.
    ///
    /// The inner future is polled first, so if it completes on the same poll in which
    /// the timer fires, its result wins.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let (inner, sleep) = match self.project() {
            Type::Unbounded { inner } => return inner.poll(cx).map_err(Into::into),
            Type::Bounded { inner, sleep } => (inner, sleep),
        };
        if let Poll::Ready(result) = inner.poll(cx) {
            return Poll::Ready(result.map_err(Into::into));
        }
        if sleep.poll(cx).is_pending() {
            return Poll::Pending;
        }
        warn!("request timed out");
        Poll::Ready(Err(TimeoutError::TimedOut.into()))
    }
}
impl<F> TimeoutFuture<F> {
    /// Wraps `inner`, arming a timer that expires at `deadline` when one is given.
    ///
    /// With `None`, the returned future simply forwards `inner`'s result.
    #[must_use]
    pub fn new(inner: TaskLocalFuture<Option<Deadline>, F>, deadline: Option<Instant>) -> Self {
        if let Some(expires_at) = deadline {
            Self::Bounded {
                sleep: sleep_until(expires_at),
                inner,
            }
        } else {
            Self::Unbounded { inner }
        }
    }
}