rs-zero 0.2.7

Rust-first microservice framework inspired by go-zero engineering practices
Documentation
use std::{
    convert::Infallible,
    future::Future,
    pin::Pin,
    task::{Context, Poll},
};

use http::{Request, Response};
use http_body::Body;
use tower::{Layer, Service};

use crate::rpc::{RpcResilienceLayer, status_counts_as_failure};

/// Tower layer that applies [`RpcResilienceLayer`] to tonic unary services.
///
/// Every service produced by this layer receives its own clone of the resilience helper.
/// Construct it with [`RpcUnaryResilienceLayer::new`] or
/// [`RpcUnaryResilienceLayer::resilience_only`].
#[derive(Debug, Clone)]
pub struct RpcUnaryResilienceLayer {
    /// Resilience helper cloned into each service produced by [`Layer::layer`].
    resilience: RpcResilienceLayer,
    /// Copied into each produced service; `true` from `new`, `false` from `resilience_only`.
    /// When set (and the `observability` feature is enabled) calls go through the
    /// metadata-aware resilience path.
    observe: bool,
}

impl RpcUnaryResilienceLayer {
    /// Creates a unary resilience layer that also takes the metadata-aware path
    /// (`run_unary_with_metadata`) when the `observability` feature is enabled.
    pub fn new(resilience: RpcResilienceLayer) -> Self {
        Self::with_observe_flag(resilience, true)
    }

    /// Creates a layer that only applies resilience decisions, never the metadata-aware path.
    pub fn resilience_only(resilience: RpcResilienceLayer) -> Self {
        Self::with_observe_flag(resilience, false)
    }

    /// Shared constructor behind the two public entry points.
    fn with_observe_flag(resilience: RpcResilienceLayer, observe: bool) -> Self {
        Self {
            resilience,
            observe,
        }
    }
}

impl<S> Layer<S> for RpcUnaryResilienceLayer {
    type Service = RpcUnaryResilienceService<S>;

    /// Wraps `inner`, handing the new service its own copy of the resilience helper and the
    /// layer's observe flag.
    fn layer(&self, inner: S) -> Self::Service {
        let Self {
            resilience,
            observe,
        } = self;
        RpcUnaryResilienceService {
            inner,
            resilience: resilience.clone(),
            observe: *observe,
        }
    }
}

/// Tower service produced by [`RpcUnaryResilienceLayer`].
///
/// Runs each request to the wrapped service through the resilience helper and renders any
/// resulting error status back into an HTTP response, so the service's error type is
/// `Infallible`.
#[derive(Debug, Clone)]
pub struct RpcUnaryResilienceService<S> {
    /// The wrapped inner service.
    inner: S,
    /// Resilience helper cloned from the producing layer.
    resilience: RpcResilienceLayer,
    /// Copied from the producing layer; selects the metadata-aware path when the
    /// `observability` feature is enabled.
    observe: bool,
}

/// Runs every unary call through the wrapped [`RpcResilienceLayer`].
///
/// The error type is [`Infallible`]: error statuses produced by the resilience helper, or
/// failure-class statuses returned by the inner service, are rendered back into gRPC HTTP
/// responses instead of being returned as `Err`.
impl<S, B> Service<Request<B>> for RpcUnaryResilienceService<S>
where
    S: Service<Request<B>, Response = Response<tonic::body::Body>, Error = Infallible>
        + Clone
        + Send
        + 'static,
    S::Future: Send + 'static,
    B: Body + Send + 'static,
{
    type Response = Response<tonic::body::Body>;
    type Error = Infallible;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Readiness is delegated to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Dispatches `request` to the wrapped service under the resilience policy.
    ///
    /// The request URI path is used as the method name. If the inner response carries a
    /// `grpc-status` header that [`status_counts_as_failure`] accepts, it is reported to the
    /// resilience helper as an error. Any error status that comes back is converted with
    /// `tonic::Status::into_http`.
    fn call(&mut self, request: Request<B>) -> Self::Future {
        let method = request.uri().path().to_string();
        // `poll_ready` was driven on `self.inner`, so that exact instance is the one allowed to
        // accept this request. Move it into the future and leave a fresh clone behind for the
        // next `poll_ready`/`call` cycle. Calling an un-polled clone would bypass whatever
        // readiness/back-pressure the inner service reserved.
        let replacement = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, replacement);
        let resilience = self.resilience.clone();
        let observe = self.observe;

        // Request metadata is only consumed by the observability path, so the header
        // conversion is skipped whenever its result would be discarded.
        #[cfg(feature = "observability")]
        let metadata = observe.then(|| request_metadata(request.headers()));
        #[cfg(not(feature = "observability"))]
        let _ = observe;

        Box::pin(async move {
            let call = || async move {
                let response = match inner.call(request).await {
                    Ok(response) => response,
                    Err(never) => match never {},
                };
                // Only the response *headers* are inspected here, not trailers.
                if let Some(status) = tonic::Status::from_header_map(response.headers())
                    && status_counts_as_failure(&status)
                {
                    return Err(status);
                }
                Ok(response)
            };

            #[cfg(feature = "observability")]
            let result = match &metadata {
                Some(metadata) => {
                    resilience
                        .run_unary_with_metadata(&method, metadata, call)
                        .await
                }
                None => resilience.run_unary_inner_public(&method, call).await,
            };

            #[cfg(not(feature = "observability"))]
            let result = resilience.run_unary_inner_public(&method, call).await;

            Ok(match result {
                Ok(response) => response,
                Err(status) => status.into_http::<tonic::body::Body>(),
            })
        })
    }
}

/// Converts request headers into gRPC metadata for the observability path.
///
/// Headers whose name does not parse as an ASCII metadata key, or whose value does not
/// parse as an ASCII metadata value, are skipped. `insert` replaces earlier entries, so only
/// the last value of a repeated header is kept.
#[cfg(feature = "observability")]
fn request_metadata(headers: &http::HeaderMap) -> tonic::metadata::MetadataMap {
    let mut metadata = tonic::metadata::MetadataMap::new();
    for (key, value) in headers {
        let Some((key, value)) = key
            .as_str()
            .parse::<tonic::metadata::MetadataKey<tonic::metadata::Ascii>>()
            .ok()
            .zip(value.to_str().ok().and_then(|value| value.parse().ok()))
        else {
            continue;
        };
        metadata.insert(key, value);
    }
    metadata
}

#[cfg(test)]
mod tests {
    use std::{
        convert::Infallible,
        pin::Pin,
        task::{Context, Poll},
        time::Duration,
    };

    use http::{Request, Response, StatusCode};
    use tower::{Layer, Service};

    use super::RpcUnaryResilienceLayer;
    use crate::rpc::{RpcResilienceConfig, RpcResilienceLayer};

    type UnaryResponse = Response<tonic::body::Body>;

    /// Inner service that always succeeds, but only after a 20 ms delay — longer than the
    /// 1 ms timeout configured in the test below.
    #[derive(Clone)]
    struct SlowService;

    impl Service<Request<tonic::body::Body>> for SlowService {
        type Response = UnaryResponse;
        type Error = Infallible;
        type Future = Pin<Box<dyn Future<Output = Result<UnaryResponse, Infallible>> + Send>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _request: Request<tonic::body::Body>) -> Self::Future {
            Box::pin(async move {
                tokio::time::sleep(Duration::from_millis(20)).await;
                Ok(Response::new(tonic::body::Body::empty()))
            })
        }
    }

    #[tokio::test]
    async fn unary_layer_maps_timeout_without_manual_helper() {
        let config = RpcResilienceConfig {
            request_timeout: Duration::from_millis(1),
            ..RpcResilienceConfig::default()
        };
        let layer = RpcUnaryResilienceLayer::new(RpcResilienceLayer::new("hello", config));
        let mut service = layer.layer(SlowService);

        let response = service
            .call(Request::new(tonic::body::Body::empty()))
            .await
            .expect("infallible");

        // The timeout is reported gRPC-style: HTTP 200 carrying `grpc-status: 4`
        // (DEADLINE_EXCEEDED) in the headers.
        let grpc_status = response
            .headers()
            .get("grpc-status")
            .and_then(|value| value.to_str().ok());
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(grpc_status, Some("4"));
    }
}