rs-zero 0.2.3

Rust-first microservice framework inspired by go-zero engineering practices
Documentation
use std::{
    convert::Infallible,
    future::Future,
    pin::Pin,
    task::{Context, Poll},
};

use http::{Request, Response};
use http_body::Body;
use tower::{Layer, Service};

use crate::rpc::{RpcResilienceLayer, status_counts_as_failure};

/// [`Layer`] that wraps tonic unary services with an [`RpcResilienceLayer`].
///
/// Every service produced by this layer receives its own clone of the helper.
#[derive(Clone, Debug)]
pub struct RpcUnaryResilienceLayer {
    /// Resilience helper cloned into each service built by [`Layer::layer`].
    resilience: RpcResilienceLayer,
}

impl RpcUnaryResilienceLayer {
    /// Builds the layer around an existing RPC resilience helper.
    ///
    /// The helper is stored unchanged; it is cloned whenever the layer wraps a service.
    pub fn new(resilience: RpcResilienceLayer) -> Self {
        Self { resilience }
    }
}

impl<S> Layer<S> for RpcUnaryResilienceLayer {
    type Service = RpcUnaryResilienceService<S>;

    /// Wraps `inner`, handing it a clone of this layer's resilience helper.
    fn layer(&self, inner: S) -> Self::Service {
        let resilience = self.resilience.clone();
        RpcUnaryResilienceService { inner, resilience }
    }
}

/// [`Service`] built by [`RpcUnaryResilienceLayer`].
///
/// Forwards unary calls to `inner` while applying the policy held in `resilience`.
#[derive(Clone, Debug)]
pub struct RpcUnaryResilienceService<S> {
    /// The wrapped service that actually handles requests.
    inner: S,
    /// Resilience helper applied around every call to `inner`.
    resilience: RpcResilienceLayer,
}

/// Runs each request through [`RpcResilienceLayer::run_unary`].
///
/// Calls are keyed by the request URI path (for gRPC, `/package.Service/Method`).
/// Any `Err(status)` that comes out of `run_unary` is turned back into an HTTP
/// response with `Status::into_http`, so this service never fails. That covers
/// statuses the inner service put in its response headers, and statuses
/// `run_unary` produces itself (the tests expect `grpc-status: 4` on timeout).
impl<S, B> Service<Request<B>> for RpcUnaryResilienceService<S>
where
    S: Service<Request<B>, Response = Response<tonic::body::Body>, Error = Infallible>
        + Clone
        + Send
        + 'static,
    S::Future: Send + 'static,
    B: Body + Send + 'static,
{
    type Response = Response<tonic::body::Body>;
    type Error = Infallible;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Delegates readiness to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Dispatches `request` to the wrapped service under the resilience policy.
    fn call(&mut self, request: Request<B>) -> Self::Future {
        let method = request.uri().path().to_owned();
        // `poll_ready` was driven on `self.inner`, so that exact instance must serve
        // this request. Leave a fresh clone behind for the next poll_ready/call cycle
        // and move the readied one into the future. Calling an un-polled clone breaks
        // the tower `Service` contract for inner services that reserve capacity in
        // `poll_ready`.
        let fresh = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, fresh);
        let resilience = self.resilience.clone();

        Box::pin(async move {
            let result = resilience
                .run_unary(&method, || async move {
                    let response = match inner.call(request).await {
                        Ok(response) => response,
                        Err(never) => match never {},
                    };
                    // Only a `grpc-status` carried in the response *headers* is
                    // inspected. A status delivered later in trailers is not
                    // classified by this layer.
                    if let Some(status) = tonic::Status::from_header_map(response.headers()) {
                        if status_counts_as_failure(&status) {
                            // Surface as `Err` so the resilience helper records a failure.
                            return Err(status);
                        }
                    }
                    Ok(response)
                })
                .await;

            // The error type is `Infallible`, so every failure status is re-encoded
            // as a gRPC error response instead of being propagated.
            Ok(result.unwrap_or_else(|status| status.into_http::<tonic::body::Body>()))
        })
    }
}

#[cfg(test)]
mod tests {
    use std::{convert::Infallible, task::Poll};

    use http::{Request, Response};
    use tower::{Layer, Service};

    use super::RpcUnaryResilienceLayer;
    use crate::rpc::{RpcResilienceConfig, RpcResilienceLayer};

    /// Inner service that always succeeds with an empty body after a 20 ms delay.
    #[derive(Clone)]
    struct SlowService;

    impl Service<Request<tonic::body::Body>> for SlowService {
        type Response = Response<tonic::body::Body>;
        type Error = Infallible;
        type Future =
            std::pin::Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        fn poll_ready(
            &mut self,
            _cx: &mut std::task::Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _request: Request<tonic::body::Body>) -> Self::Future {
            Box::pin(async move {
                tokio::time::sleep(std::time::Duration::from_millis(20)).await;
                Ok(Response::new(tonic::body::Body::empty()))
            })
        }
    }

    #[tokio::test]
    async fn unary_layer_maps_timeout_without_manual_helper() {
        let layer = RpcUnaryResilienceLayer::new(RpcResilienceLayer::new(
            "hello",
            RpcResilienceConfig {
                request_timeout: std::time::Duration::from_millis(1),
                ..RpcResilienceConfig::default()
            },
        ));
        let mut service = layer.layer(SlowService);

        // Honor the tower contract: drive `poll_ready` before `call`. The trait is
        // named explicitly because the request body type cannot be inferred here.
        std::future::poll_fn(|cx| {
            Service::<Request<tonic::body::Body>>::poll_ready(&mut service, cx)
        })
        .await
        .expect("infallible");

        let response = service
            .call(Request::new(tonic::body::Body::empty()))
            .await
            .expect("infallible");

        // A 1 ms timeout against a 20 ms handler should surface as DEADLINE_EXCEEDED
        // (gRPC code 4), carried in headers on an HTTP 200 response.
        assert_eq!(response.status(), http::StatusCode::OK);
        assert_eq!(
            response
                .headers()
                .get("grpc-status")
                .and_then(|value| value.to_str().ok()),
            Some("4")
        );
    }
}