rs-zero 0.2.4

Rust-first microservice framework inspired by go-zero engineering practices
Documentation
#[cfg(feature = "rpc")]
use tonic::metadata::MetadataMap;

#[cfg(all(feature = "observability", feature = "rpc"))]
use crate::observability::insert_traceparent_metadata;
#[cfg(feature = "observability")]
use crate::observability::{
    CorrelationContext, TRACEPARENT_HEADER, insert_traceparent_header, span_id_from_traceparent,
    trace_id_from_traceparent,
};
#[cfg(feature = "rpc")]
use crate::rpc::{REQUEST_ID_METADATA, RpcRequestId};

/// Low-cardinality request context shared by REST/RPC Tower layers.
///
/// Carries the service name, transport, route pattern and method of a request
/// together with optional correlation identifiers: a request id, a W3C
/// `traceparent` value and the trace/span ids derived from it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RequestContext {
    /// Name of the service the request belongs to.
    service: String,
    /// Transport name.
    transport: &'static str,
    /// Route or RPC method pattern.
    route: String,
    /// HTTP or RPC method.
    method: String,
    /// Request id, when one has been set.
    request_id: Option<String>,
    /// Raw `traceparent` value, when one has been set.
    traceparent: Option<String>,
    /// Trace id derived from `traceparent`, when available.
    trace_id: Option<String>,
    /// Span id derived from `traceparent`, when available.
    span_id: Option<String>,
}

impl RequestContext {
    /// Creates a context from explicit low-cardinality parts.
    ///
    /// All correlation identifiers (request id, traceparent, trace id and
    /// span id) start out unset; use [`Self::with_request_id`] and
    /// [`Self::with_traceparent`] to attach them.
    pub fn new(
        service: impl Into<String>,
        transport: &'static str,
        route: impl Into<String>,
        method: impl Into<String>,
    ) -> Self {
        Self {
            service: service.into(),
            transport,
            route: route.into(),
            method: method.into(),
            request_id: None,
            traceparent: None,
            trace_id: None,
            span_id: None,
        }
    }

    /// Builds an HTTP request context from headers and a route pattern.
    ///
    /// Extraction is delegated to `CorrelationContext::from_http_headers`; the
    /// resulting correlation values are copied into the new context.
    #[cfg(feature = "observability")]
    pub fn from_http_headers(
        service: Option<&str>,
        method: impl Into<String>,
        route: Option<&str>,
        headers: &http::HeaderMap,
    ) -> Self {
        let correlation = CorrelationContext::from_http_headers(service, method, route, headers);
        Self::from_correlation(correlation)
    }

    /// Builds a gRPC request context from tonic metadata and a method pattern.
    ///
    /// Extraction is delegated to `CorrelationContext::from_rpc_metadata`; the
    /// resulting correlation values are copied into the new context.
    #[cfg(all(feature = "observability", feature = "rpc"))]
    pub fn from_tonic_metadata(
        service: impl Into<String>,
        method: impl Into<String>,
        metadata: &MetadataMap,
    ) -> Self {
        let correlation = CorrelationContext::from_rpc_metadata(service, method, metadata);
        Self::from_correlation(correlation)
    }

    /// Sets a request id.
    ///
    /// Empty or whitespace-only ids are ignored and leave any previously set
    /// request id untouched. A non-blank id is stored exactly as given.
    #[must_use]
    pub fn with_request_id(mut self, request_id: impl Into<String>) -> Self {
        let request_id = request_id.into();
        if !request_id.trim().is_empty() {
            self.request_id = Some(request_id);
        }
        self
    }

    /// Sets a W3C traceparent and derives trace/span ids when valid.
    ///
    /// Empty or whitespace-only values are ignored, mirroring
    /// [`Self::with_request_id`], so a blank `traceparent` can neither replace
    /// an existing one nor be propagated by the `inject_*` methods. A
    /// non-blank value is stored exactly as given. With the `observability`
    /// feature enabled the trace and span ids are replaced by whatever the
    /// traceparent parsers return for the new value.
    #[must_use]
    pub fn with_traceparent(mut self, traceparent: impl Into<String>) -> Self {
        let traceparent = traceparent.into();
        if traceparent.trim().is_empty() {
            return self;
        }
        #[cfg(feature = "observability")]
        {
            self.trace_id = trace_id_from_traceparent(&traceparent).map(ToOwned::to_owned);
            self.span_id = span_id_from_traceparent(&traceparent).map(ToOwned::to_owned);
        }
        self.traceparent = Some(traceparent);
        self
    }

    /// Returns the service name.
    pub fn service(&self) -> &str {
        &self.service
    }

    /// Returns the transport name.
    pub fn transport(&self) -> &'static str {
        self.transport
    }

    /// Returns the route or RPC method pattern.
    pub fn route(&self) -> &str {
        &self.route
    }

    /// Returns the HTTP or RPC method.
    pub fn method(&self) -> &str {
        &self.method
    }

    /// Returns the request id.
    pub fn request_id(&self) -> Option<&str> {
        self.request_id.as_deref()
    }

    /// Returns the traceparent.
    pub fn traceparent(&self) -> Option<&str> {
        self.traceparent.as_deref()
    }

    /// Returns the trace id.
    pub fn trace_id(&self) -> Option<&str> {
        self.trace_id.as_deref()
    }

    /// Returns the span id.
    pub fn span_id(&self) -> Option<&str> {
        self.span_id.as_deref()
    }

    /// Inserts context values into HTTP headers when they are missing.
    ///
    /// Headers that are already present are never overwritten.
    ///
    /// # Errors
    ///
    /// Returns an error when the request id or traceparent cannot be encoded
    /// as an HTTP header value.
    #[cfg(feature = "observability")]
    pub fn inject_http_headers(
        &self,
        headers: &mut http::HeaderMap,
    ) -> Result<(), http::header::InvalidHeaderValue> {
        // `filter` only consults the header map when a value exists, and its
        // shared borrow ends before the mutable insert below.
        let missing_request_id = self
            .request_id()
            .filter(|_| !headers.contains_key(crate::observability::REQUEST_ID_HEADER));
        if let Some(request_id) = missing_request_id {
            headers.insert(
                crate::observability::REQUEST_ID_HEADER,
                http::HeaderValue::from_str(request_id)?,
            );
        }

        let missing_traceparent = self
            .traceparent()
            .filter(|_| !headers.contains_key(TRACEPARENT_HEADER));
        if let Some(traceparent) = missing_traceparent {
            insert_traceparent_header(headers, traceparent)?;
        }
        Ok(())
    }

    /// Inserts context values into tonic metadata when they are missing.
    ///
    /// Metadata entries that are already present are never overwritten. The
    /// traceparent is only injected when the `observability` feature is
    /// enabled.
    ///
    /// # Errors
    ///
    /// Returns an error when the request id or traceparent cannot be encoded
    /// as a metadata value.
    #[cfg(feature = "rpc")]
    pub fn inject_tonic_metadata(
        &self,
        metadata: &mut MetadataMap,
    ) -> Result<(), tonic::metadata::errors::InvalidMetadataValue> {
        let missing_request_id = self
            .request_id()
            .filter(|_| !metadata.contains_key(REQUEST_ID_METADATA));
        if let Some(request_id) = missing_request_id {
            metadata.insert(REQUEST_ID_METADATA, request_id.parse()?);
        }

        #[cfg(feature = "observability")]
        {
            let missing_traceparent = self
                .traceparent()
                .filter(|_| !metadata.contains_key(TRACEPARENT_HEADER));
            if let Some(traceparent) = missing_traceparent {
                insert_traceparent_metadata(metadata, traceparent)?;
            }
        }
        Ok(())
    }

    /// Inserts context values into request extensions for downstream layers.
    ///
    /// When a request id is set it is stored as an `RpcRequestId` extension
    /// and, with the `observability` feature enabled, additionally as a
    /// `CurrentRequestId` extension. Nothing is inserted without a request id.
    #[cfg(feature = "rpc")]
    pub fn insert_tonic_extensions<T>(&self, request: &mut tonic::Request<T>) {
        if let Some(request_id) = self.request_id() {
            let extensions = request.extensions_mut();
            extensions.insert(RpcRequestId(request_id.to_string()));
            #[cfg(feature = "observability")]
            {
                extensions.insert(crate::observability::CurrentRequestId(
                    request_id.to_string(),
                ));
            }
        }
    }

    /// Copies every value exposed by a `CorrelationContext` into a new
    /// request context.
    #[cfg(feature = "observability")]
    fn from_correlation(correlation: CorrelationContext) -> Self {
        let mut context = Self::new(
            correlation.service().to_string(),
            correlation.transport(),
            correlation.route().to_string(),
            correlation.method().to_string(),
        );
        // Values are copied verbatim rather than routed through the `with_*`
        // builders, so whatever the correlation context extracted is kept.
        if let Some(request_id) = correlation.request_id() {
            context.request_id = Some(request_id.to_string());
        }
        if let Some(traceparent) = correlation.traceparent() {
            context.traceparent = Some(traceparent.to_string());
        }
        if let Some(trace_id) = correlation.trace_id() {
            context.trace_id = Some(trace_id.to_string());
        }
        if let Some(span_id) = correlation.span_id() {
            context.span_id = Some(span_id.to_string());
        }
        context
    }
}

/// Looks up the request id scoped to the current task, if any.
///
/// With the `rpc` feature enabled the value is read from
/// `crate::rpc::RPC_REQUEST_ID_SCOPE`; a failed lookup yields `None`. Without
/// the feature there is no scope to read from and the result is always `None`.
pub fn current_request_id() -> Option<String> {
    #[cfg(not(feature = "rpc"))]
    {
        None
    }

    #[cfg(feature = "rpc")]
    {
        let lookup = crate::rpc::RPC_REQUEST_ID_SCOPE.try_with(|scoped| scoped.to_string());
        lookup.ok()
    }
}

/// Awaits `future` with `request_id` made available to outgoing RPC layers.
///
/// With the `rpc` feature enabled the call is forwarded to
/// `crate::rpc::with_rpc_request_id`. Without it the id is converted and
/// discarded, and the future is awaited as-is. The future's output is
/// returned in both cases.
pub async fn scope_request_id<T>(
    request_id: impl Into<String>,
    future: impl std::future::Future<Output = T>,
) -> T {
    #[cfg(not(feature = "rpc"))]
    {
        // Nothing consumes the id without RPC support; convert and drop it
        // right away so the parameter is still used.
        let _: String = request_id.into();
        future.await
    }

    #[cfg(feature = "rpc")]
    {
        crate::rpc::with_rpc_request_id(request_id, future).await
    }
}