use core::time::Duration;
use tokio::time::Instant;
use crate::{
body::ResponseBody,
context::WebContext,
error::{GrpcError, GrpcStatus},
http::{
WebResponse,
const_header_name::GRPC_TIMEOUT,
const_header_value::GRPC,
header::{CONTENT_TYPE, HeaderValue},
},
service::{Service, ready::ReadyService},
};
/// Builder that wraps an inner service in a [`GrpcTimeoutService`].
///
/// When the inner service was constructed successfully it is wrapped; a construction
/// error from the inner service is passed through unchanged.
#[derive(Clone, Copy, Debug, Default)]
pub struct GrpcTimeout;

impl<S, E> Service<Result<S, E>> for GrpcTimeout {
    type Response = GrpcTimeoutService<S>;
    type Error = E;

    async fn call(&self, res: Result<S, E>) -> Result<Self::Response, Self::Error> {
        res.map(|service| GrpcTimeoutService { service })
    }
}
/// Service that bounds each request by the deadline in its `GRPC_TIMEOUT` header.
///
/// Produced by [`GrpcTimeout`]. Requests without a parsable timeout header are forwarded
/// to the inner service with no deadline.
#[derive(Clone, Debug)]
pub struct GrpcTimeoutService<S> {
    /// The wrapped service every request is forwarded to.
    service: S,
}
impl<'r, S, C, B> Service<WebContext<'r, C, B>> for GrpcTimeoutService<S>
where
    S: for<'r2> Service<WebContext<'r2, C, B>, Response = WebResponse, Error = crate::error::Error>,
{
    type Response = WebResponse;
    type Error = crate::error::Error;

    /// Forwards the request to the inner service, bounded by the request's `GRPC_TIMEOUT`
    /// header when present.
    ///
    /// When the header parses to a duration, the inner call is raced against a deadline of
    /// `now + duration`. If the deadline passes first, the inner future is dropped and a
    /// `DeadlineExceeded` gRPC status response is returned as `Ok`. A missing or unparsable
    /// header leaves the call unbounded.
    async fn call(&self, ctx: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
        // The timeout value is client-controlled (up to 99_999_999 hours). `Instant + Duration`
        // panics when the result is not representable, so use `checked_add` and treat an
        // unrepresentable deadline the same as having no deadline at all.
        let deadline = ctx
            .req()
            .headers()
            .get(GRPC_TIMEOUT)
            .and_then(parse_grpc_timeout)
            .and_then(|timeout| Instant::now().checked_add(timeout));

        let Some(deadline) = deadline else {
            return self.service.call(ctx).await;
        };

        match tokio::time::timeout_at(deadline, self.service.call(ctx)).await {
            Ok(res) => res,
            Err(_elapsed) => Ok(deadline_exceeded_response()),
        }
    }
}

/// Builds the response sent when a request's deadline elapses before the inner service
/// finishes.
///
/// The body is empty. The gRPC content type and the `DeadlineExceeded` status trailers
/// are placed directly in the response headers.
fn deadline_exceeded_response() -> WebResponse {
    let err = GrpcError::new(GrpcStatus::DeadlineExceeded, "deadline exceeded");
    let mut res = WebResponse::new(ResponseBody::empty());
    res.headers_mut().insert(CONTENT_TYPE, GRPC);
    res.headers_mut().extend(err.trailers());
    res
}
/// Readiness is delegated straight to the wrapped service; no deadline is applied here.
impl<S> ReadyService for GrpcTimeoutService<S>
where
    S: ReadyService,
{
    type Ready = S::Ready;

    #[inline]
    async fn ready(&self) -> Self::Ready {
        let inner_ready = self.service.ready();
        inner_ready.await
    }
}
/// Parses a `GRPC_TIMEOUT` header value into a [`Duration`].
///
/// The value must be 1 to 8 ASCII digits followed by exactly one unit byte:
/// `H` (hours), `M` (minutes), `S` (seconds), `m` (milliseconds), `u` (microseconds)
/// or `n` (nanoseconds). Units are case-sensitive. Any other shape yields `None`.
fn parse_grpc_timeout(value: &HeaderValue) -> Option<Duration> {
    // The final byte is the unit; everything before it is the numeric amount.
    let (&unit, digits) = value.as_bytes().split_last()?;

    // At most 8 digits means the amount is at most 99_999_999, which keeps the
    // hour-to-second multiplication below far inside u64 range.
    if !(1..=8).contains(&digits.len()) {
        return None;
    }

    // `to_digit(10)` only accepts '0'..='9', so any other byte aborts the fold with `None`.
    let amount = digits.iter().try_fold(0u64, |acc, &byte| {
        char::from(byte)
            .to_digit(10)
            .map(|digit| acc * 10 + u64::from(digit))
    })?;

    let duration = match unit {
        b'H' => Duration::from_secs(amount * 3600),
        b'M' => Duration::from_secs(amount * 60),
        b'S' => Duration::from_secs(amount),
        b'm' => Duration::from_millis(amount),
        b'u' => Duration::from_micros(amount),
        b'n' => Duration::from_nanos(amount),
        _ => return None,
    };
    Some(duration)
}
#[cfg(test)]
mod test {
    use super::*;

    /// Shorthand: parse a static header value.
    fn parse(value: &'static str) -> Option<Duration> {
        parse_grpc_timeout(&HeaderValue::from_static(value))
    }

    #[test]
    fn parse_timeout_values() {
        assert_eq!(parse("1H"), Some(Duration::from_secs(3600)));
        assert_eq!(parse("5M"), Some(Duration::from_secs(300)));
        assert_eq!(parse("10S"), Some(Duration::from_secs(10)));
        assert_eq!(parse("100m"), Some(Duration::from_millis(100)));
        assert_eq!(parse("5000u"), Some(Duration::from_micros(5000)));
        assert_eq!(parse("999n"), Some(Duration::from_nanos(999)));
    }

    #[test]
    fn parse_timeout_units_are_case_sensitive() {
        // `M` is minutes while `m` is milliseconds; lowercase `h` and `s` are not units.
        assert_eq!(parse("1M"), Some(Duration::from_secs(60)));
        assert_eq!(parse("1m"), Some(Duration::from_millis(1)));
        assert_eq!(parse("1h"), None);
        assert_eq!(parse("1s"), None);
    }

    #[test]
    fn parse_timeout_invalid() {
        assert_eq!(parse(""), None);
        assert_eq!(parse("H"), None);
        assert_eq!(parse("5"), None);
        assert_eq!(parse("5x"), None);
        assert_eq!(parse("abc"), None);
        assert_eq!(parse("123456789S"), None);
        // Signs and whitespace are not part of the digit run.
        assert_eq!(parse("+5S"), None);
        assert_eq!(parse("-5S"), None);
        assert_eq!(parse(" 5S"), None);
        assert_eq!(parse("5 S"), None);
    }

    #[test]
    fn parse_timeout_max_digits() {
        assert_eq!(parse("99999999S"), Some(Duration::from_secs(99_999_999)));
        // The largest accepted hour value must convert without overflowing.
        assert_eq!(
            parse("99999999H"),
            Some(Duration::from_secs(99_999_999 * 3600))
        );
        assert_eq!(parse("99999999n"), Some(Duration::from_nanos(99_999_999)));
    }
}