use std::{
convert::Infallible,
future::Future,
pin::Pin,
task::{Context, Poll},
};
use http::{Request, Response};
use http_body::Body;
use tower::{Layer, Service};
use crate::rpc::{RpcResilienceLayer, status_counts_as_failure};
/// Tower layer that runs every unary call of the wrapped service through an
/// [`RpcResilienceLayer`].
///
/// Use [`RpcUnaryResilienceLayer::new`] to also forward request metadata to the
/// metadata-aware (`observability` feature) path. Use
/// [`RpcUnaryResilienceLayer::resilience_only`] to always take the plain
/// resilience path.
#[derive(Debug, Clone)]
pub struct RpcUnaryResilienceLayer {
    /// Resilience policy; each produced service receives its own clone.
    resilience: RpcResilienceLayer,
    /// Whether produced services use the metadata-aware observability path.
    observe: bool,
}

impl RpcUnaryResilienceLayer {
    /// Creates a layer that applies `resilience` with observation enabled.
    pub fn new(resilience: RpcResilienceLayer) -> Self {
        let observe = true;
        Self { resilience, observe }
    }

    /// Creates a layer that applies `resilience` without the metadata-aware
    /// observability path.
    pub fn resilience_only(resilience: RpcResilienceLayer) -> Self {
        Self {
            observe: false,
            ..Self::new(resilience)
        }
    }
}
impl<S> Layer<S> for RpcUnaryResilienceLayer {
    type Service = RpcUnaryResilienceService<S>;

    /// Wraps `inner`, giving the new service its own copy of this layer's
    /// resilience policy and observe flag.
    fn layer(&self, inner: S) -> Self::Service {
        let Self {
            resilience,
            observe,
        } = self;
        RpcUnaryResilienceService {
            inner,
            resilience: resilience.clone(),
            observe: *observe,
        }
    }
}
/// Service produced by [`RpcUnaryResilienceLayer`]: each call to `inner` runs
/// through the resilience policy, and policy errors are returned as gRPC status
/// responses.
#[derive(Debug, Clone)]
pub struct RpcUnaryResilienceService<S> {
    /// The wrapped service.
    inner: S,
    /// Resilience policy applied to every call.
    resilience: RpcResilienceLayer,
    /// When `true` and the `observability` feature is enabled, calls go through
    /// the metadata-aware path.
    observe: bool,
}
/// Copies the request headers that are valid ASCII gRPC metadata into a
/// [`tonic::metadata::MetadataMap`] for the observability path.
///
/// A header is skipped if its name is not a valid ASCII metadata key (for
/// example `-bin` keys) or its value cannot be parsed as an ASCII metadata
/// value. A header that appears several times keeps all of its values.
#[cfg(feature = "observability")]
fn request_metadata(headers: &http::HeaderMap) -> tonic::metadata::MetadataMap {
    use tonic::metadata::{Ascii, MetadataKey, MetadataMap, MetadataValue};

    let mut metadata = MetadataMap::with_capacity(headers.len());
    for (name, value) in headers {
        let Ok(key) = name.as_str().parse::<MetadataKey<Ascii>>() else {
            continue;
        };
        let Some(value) = value
            .to_str()
            .ok()
            .and_then(|value| value.parse::<MetadataValue<Ascii>>().ok())
        else {
            continue;
        };
        // `append` rather than `insert`: `insert` would keep only the last
        // value of a repeated header.
        metadata.append(key, value);
    }
    metadata
}

/// Runs every request through the configured [`RpcResilienceLayer`].
///
/// The service always answers with an HTTP response, so `Error` is
/// [`Infallible`]. Errors from the resilience policy (for example a request
/// timeout) are encoded as gRPC status responses via
/// [`tonic::Status::into_http`].
impl<S, B> Service<Request<B>> for RpcUnaryResilienceService<S>
where
    S: Service<Request<B>, Response = Response<tonic::body::Body>, Error = Infallible>
        + Clone
        + Send
        + 'static,
    S::Future: Send + 'static,
    B: Body + Send + 'static,
{
    type Response = Response<tonic::body::Body>;
    type Error = Infallible;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Readiness is exactly the inner service's readiness.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Sends `request` to the inner service under the resilience policy.
    ///
    /// The request path is used as the method name. When observation is
    /// enabled (and the `observability` feature is compiled in), the request
    /// headers are also passed along as gRPC metadata.
    fn call(&mut self, request: Request<B>) -> Self::Future {
        // Tower only guarantees that the instance `poll_ready` was called on is
        // ready. Move that instance into the future and leave a fresh clone in
        // `self` for the next `poll_ready`/`call` cycle. Calling a clone instead
        // would bypass any per-instance readiness the caller just acquired.
        let fresh = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, fresh);
        let resilience = self.resilience.clone();
        let observe = self.observe;
        let method = request.uri().path().to_owned();

        // Only the observability path uses the metadata, so skip the header
        // scan when the result would be thrown away.
        #[cfg(feature = "observability")]
        let metadata = observe.then(|| request_metadata(request.headers()));
        #[cfg(not(feature = "observability"))]
        let _ = observe;

        Box::pin(async move {
            let call = || async move {
                let response = match inner.call(request).await {
                    Ok(response) => response,
                    Err(never) => match never {},
                };
                // A `grpc-status` response header that `status_counts_as_failure`
                // accepts is returned to the resilience policy as `Err`.
                if let Some(status) = tonic::Status::from_header_map(response.headers())
                    && status_counts_as_failure(&status)
                {
                    return Err(status);
                }
                Ok(response)
            };

            #[cfg(feature = "observability")]
            let result = match metadata {
                Some(metadata) => {
                    resilience
                        .run_unary_with_metadata(&method, &metadata, call)
                        .await
                }
                None => resilience.run_unary_inner_public(&method, call).await,
            };
            #[cfg(not(feature = "observability"))]
            let result = resilience.run_unary_inner_public(&method, call).await;

            Ok(result.unwrap_or_else(|status| status.into_http::<tonic::body::Body>()))
        })
    }
}

#[cfg(test)]
mod tests {
    use std::{convert::Infallible, task::Poll};

    use http::{Request, Response};
    use tower::{Layer, Service};

    use super::RpcUnaryResilienceLayer;
    use crate::rpc::{RpcResilienceConfig, RpcResilienceLayer};

    /// Inner service that takes 20ms to answer.
    #[derive(Clone)]
    struct SlowService;

    impl Service<Request<tonic::body::Body>> for SlowService {
        type Response = Response<tonic::body::Body>;
        type Error = Infallible;
        type Future =
            std::pin::Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        fn poll_ready(
            &mut self,
            _cx: &mut std::task::Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _request: Request<tonic::body::Body>) -> Self::Future {
            Box::pin(async move {
                tokio::time::sleep(std::time::Duration::from_millis(20)).await;
                Ok(Response::new(tonic::body::Body::empty()))
            })
        }
    }

    /// Inner service whose readiness is per-instance state. Clones start
    /// not-ready, and `call` on an instance that was not polled ready answers
    /// `503 Service Unavailable`.
    #[derive(Default)]
    struct ReadinessTrackingService {
        ready: bool,
    }

    impl Clone for ReadinessTrackingService {
        fn clone(&self) -> Self {
            Self { ready: false }
        }
    }

    impl Service<Request<tonic::body::Body>> for ReadinessTrackingService {
        type Response = Response<tonic::body::Body>;
        type Error = Infallible;
        type Future = std::future::Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(
            &mut self,
            _cx: &mut std::task::Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            self.ready = true;
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _request: Request<tonic::body::Body>) -> Self::Future {
            let was_ready = std::mem::take(&mut self.ready);
            let mut response = Response::new(tonic::body::Body::empty());
            if !was_ready {
                *response.status_mut() = http::StatusCode::SERVICE_UNAVAILABLE;
            }
            std::future::ready(Ok(response))
        }
    }

    #[tokio::test]
    async fn unary_layer_maps_timeout_without_manual_helper() {
        let layer = RpcUnaryResilienceLayer::new(RpcResilienceLayer::new(
            "hello",
            RpcResilienceConfig {
                request_timeout: std::time::Duration::from_millis(1),
                ..RpcResilienceConfig::default()
            },
        ));
        let mut service = layer.layer(SlowService);
        let response = service
            .call(Request::new(tonic::body::Body::empty()))
            .await
            .expect("infallible");
        assert_eq!(response.status(), http::StatusCode::OK);
        assert_eq!(
            response
                .headers()
                .get("grpc-status")
                .and_then(|value| value.to_str().ok()),
            Some("4")
        );
    }

    #[tokio::test]
    async fn unary_layer_calls_the_instance_that_was_polled_ready() {
        let layer = RpcUnaryResilienceLayer::new(RpcResilienceLayer::new(
            "hello",
            RpcResilienceConfig::default(),
        ));
        let mut service = layer.layer(ReadinessTrackingService::default());
        std::future::poll_fn(|cx| {
            Service::<Request<tonic::body::Body>>::poll_ready(&mut service, cx)
        })
        .await
        .expect("infallible");
        let response = service
            .call(Request::new(tonic::body::Body::empty()))
            .await
            .expect("infallible");
        assert_eq!(response.status(), http::StatusCode::OK);
        assert!(response.headers().get("grpc-status").is_none());
    }
}