rs-zero 0.2.3

Rust-first microservice framework inspired by go-zero engineering practices
Documentation
use http::{HeaderMap, HeaderValue, header::InvalidHeaderValue};

/// Standard correlation header used by rs-zero REST and RPC adapters.
///
/// Read by [`request_id_from_headers`] (HTTP) and, with the `rpc` feature,
/// by `request_id_from_metadata` (tonic metadata).
pub const REQUEST_ID_HEADER: &str = "x-request-id";
/// W3C Trace Context header.
///
/// Values read from this header (HTTP or tonic metadata) are only returned to
/// callers after passing this module's traceparent format check.
pub const TRACEPARENT_HEADER: &str = "traceparent";

/// Request extension value set by REST middleware for downstream clients.
///
/// Wraps the request id string as-is; this type itself performs no
/// validation or normalization of the contained value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CurrentRequestId(pub String);

/// Returns the request id carried in the [`REQUEST_ID_HEADER`] HTTP header.
///
/// Surrounding whitespace is trimmed. A missing header, a value that is not
/// visible ASCII, or a value that is blank after trimming all yield `None`.
pub fn request_id_from_headers(headers: &HeaderMap) -> Option<String> {
    let raw = headers.get(REQUEST_ID_HEADER)?.to_str().ok()?;
    let trimmed = raw.trim();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed.to_owned())
    }
}

/// Returns the [`TRACEPARENT_HEADER`] value from HTTP headers when it is a
/// well-formed W3C traceparent.
///
/// The value is trimmed before validation. The trimmed text is what gets
/// returned.
pub fn traceparent_from_headers(headers: &HeaderMap) -> Option<String> {
    let candidate = headers.get(TRACEPARENT_HEADER)?.to_str().ok()?.trim();
    is_valid_traceparent(candidate).then(|| candidate.to_owned())
}

/// Stores `traceparent` under [`TRACEPARENT_HEADER`] in `headers`.
///
/// The header is written with `insert`, so any value already present for the
/// key is replaced. The string is only checked for being a legal HTTP header
/// value, not for W3C traceparent well-formedness.
///
/// # Errors
///
/// Returns [`InvalidHeaderValue`] when `traceparent` cannot be represented as
/// an HTTP header value. In that case `headers` is left untouched.
pub fn insert_traceparent_header(
    headers: &mut HeaderMap,
    traceparent: &str,
) -> Result<(), InvalidHeaderValue> {
    HeaderValue::from_str(traceparent).map(|value| {
        headers.insert(TRACEPARENT_HEADER, value);
    })
}

/// Returns the request id carried under [`REQUEST_ID_HEADER`] in tonic
/// metadata.
///
/// Surrounding whitespace is trimmed. A missing key, a non-text value, or a
/// value that is blank after trimming all yield `None`.
#[cfg(feature = "rpc")]
pub fn request_id_from_metadata(metadata: &tonic::metadata::MetadataMap) -> Option<String> {
    let raw = metadata.get(REQUEST_ID_HEADER)?.to_str().ok()?;
    let trimmed = raw.trim();
    (!trimmed.is_empty()).then(|| trimmed.to_owned())
}

/// Returns the [`TRACEPARENT_HEADER`] value from tonic metadata when it is a
/// well-formed W3C traceparent.
///
/// The value is trimmed before validation. The trimmed text is what gets
/// returned.
#[cfg(feature = "rpc")]
pub fn traceparent_from_metadata(metadata: &tonic::metadata::MetadataMap) -> Option<String> {
    let candidate = metadata.get(TRACEPARENT_HEADER)?.to_str().ok()?.trim();
    if is_valid_traceparent(candidate) {
        Some(candidate.to_owned())
    } else {
        None
    }
}

/// Stores `traceparent` under [`TRACEPARENT_HEADER`] in tonic metadata.
///
/// The entry is written with `insert`, so any value already present for the
/// key is replaced. The string is only checked for being a legal ASCII
/// metadata value, not for W3C traceparent well-formedness.
///
/// # Errors
///
/// Returns `InvalidMetadataValue` when `traceparent` cannot be parsed as an
/// ASCII metadata value. In that case `metadata` is left untouched.
#[cfg(feature = "rpc")]
pub fn insert_traceparent_metadata(
    metadata: &mut tonic::metadata::MetadataMap,
    traceparent: &str,
) -> Result<(), tonic::metadata::errors::InvalidMetadataValue> {
    traceparent
        .parse::<tonic::metadata::MetadataValue<tonic::metadata::Ascii>>()
        .map(|value| {
            metadata.insert(TRACEPARENT_HEADER, value);
        })
}

/// Returns the trace id of the current tracing span's OpenTelemetry context.
///
/// Yields `None` when the `otlp` feature is disabled, or when the current
/// span has no valid OpenTelemetry span context. A trace id is never
/// fabricated.
pub fn current_trace_id() -> Option<String> {
    #[cfg(feature = "otlp")]
    {
        use opentelemetry::trace::TraceContextExt;
        use tracing_opentelemetry::OpenTelemetrySpanExt;

        let otel_context = tracing::Span::current().context();
        let active_span = otel_context.span();
        let found = active_span
            .span_context()
            .is_valid()
            .then(|| active_span.span_context().trace_id().to_string());
        if found.is_some() {
            return found;
        }
    }

    None
}

/// Returns the span id of the current tracing span's OpenTelemetry context.
///
/// Yields `None` when the `otlp` feature is disabled, or when the current
/// span has no valid OpenTelemetry span context. A span id is never
/// fabricated.
pub fn current_span_id() -> Option<String> {
    #[cfg(feature = "otlp")]
    {
        use opentelemetry::trace::TraceContextExt;
        use tracing_opentelemetry::OpenTelemetrySpanExt;

        let otel_context = tracing::Span::current().context();
        let active_span = otel_context.span();
        let found = active_span
            .span_context()
            .is_valid()
            .then(|| active_span.span_context().span_id().to_string());
        if found.is_some() {
            return found;
        }
    }

    None
}

/// Returns a W3C traceparent value for the current span when OTLP context is active.
///
/// The value has the form `00-{trace_id}-{span_id}-{flags}`. `flags` is `01`
/// when the active span context is sampled and `00` otherwise. Emitting a
/// hard-coded `01` would tell downstream services to sample traces whose
/// parent was never recorded.
///
/// Returns `None` when the `otlp` feature is disabled or when the current
/// span has no valid OpenTelemetry span context.
pub fn current_traceparent() -> Option<String> {
    #[cfg(feature = "otlp")]
    {
        use opentelemetry::trace::TraceContextExt;
        use tracing_opentelemetry::OpenTelemetrySpanExt;

        // Read the context once so the ids and the sampling decision all come
        // from the same span context.
        let context = tracing::Span::current().context();
        let span = context.span();
        let span_context = span.span_context();
        if span_context.is_valid() {
            return Some(format_traceparent(
                span_context.trace_id(),
                span_context.span_id(),
                span_context.is_sampled(),
            ));
        }
    }

    None
}

/// Formats a version-`00` traceparent from its already-encoded id parts.
///
/// Only the sampled bit is emitted as a trace flag, so the trailing field is
/// always `01` (sampled) or `00` (not sampled).
#[cfg_attr(not(feature = "otlp"), allow(dead_code))]
fn format_traceparent(
    trace_id: impl std::fmt::Display,
    span_id: impl std::fmt::Display,
    sampled: bool,
) -> String {
    let flags = if sampled { "01" } else { "00" };
    format!("00-{trace_id}-{span_id}-{flags}")
}

/// Builds an OpenTelemetry parent context from the traceparent in HTTP headers.
///
/// The globally registered text-map propagator is only consulted after the
/// traceparent header passes this module's format check. `None` is returned
/// if the header is missing or malformed, or if the propagator does not
/// produce a valid span context.
#[cfg(feature = "otlp")]
pub fn opentelemetry_context_from_headers(headers: &HeaderMap) -> Option<opentelemetry::Context> {
    use opentelemetry::{global, trace::TraceContextExt};

    // Bail out early on an absent or malformed traceparent.
    traceparent_from_headers(headers)?;

    let extractor = HeaderMapExtractor { headers };
    Some(global::get_text_map_propagator(|propagator| propagator.extract(&extractor)))
        .filter(|parent| parent.span().span_context().is_valid())
}

/// Builds an OpenTelemetry parent context from a single traceparent string.
///
/// Malformed input is rejected before the global text-map propagator is
/// consulted. `None` is also returned when the propagator does not yield a
/// valid span context.
#[cfg(feature = "otlp")]
pub fn opentelemetry_context_from_traceparent(traceparent: &str) -> Option<opentelemetry::Context> {
    use opentelemetry::{global, trace::TraceContextExt};

    is_valid_traceparent(traceparent)
        .then_some(TraceParentExtractor { traceparent })
        .map(|extractor| {
            global::get_text_map_propagator(|propagator| propagator.extract(&extractor))
        })
        .filter(|parent| parent.span().span_context().is_valid())
}

/// Attaches the OpenTelemetry parent context found in HTTP headers to `span`.
///
/// Returns `true` only when a valid parent context was extracted and
/// `set_parent` succeeded. Otherwise `span` is left as-is and `false` is
/// returned.
#[cfg(feature = "otlp")]
pub fn set_span_parent_from_headers(span: &tracing::Span, headers: &HeaderMap) -> bool {
    use tracing_opentelemetry::OpenTelemetrySpanExt;

    match opentelemetry_context_from_headers(headers) {
        Some(parent) => span.set_parent(parent).is_ok(),
        None => false,
    }
}

/// Builds an OpenTelemetry parent context from the traceparent in tonic metadata.
///
/// The globally registered text-map propagator is only consulted after the
/// traceparent entry passes this module's format check. `None` is returned
/// if the entry is missing or malformed, or if the propagator does not
/// produce a valid span context.
#[cfg(all(feature = "otlp", feature = "rpc"))]
pub fn opentelemetry_context_from_metadata(
    metadata: &tonic::metadata::MetadataMap,
) -> Option<opentelemetry::Context> {
    use opentelemetry::{global, trace::TraceContextExt};

    // Bail out early on an absent or malformed traceparent.
    traceparent_from_metadata(metadata)?;

    let extractor = MetadataMapExtractor { metadata };
    Some(global::get_text_map_propagator(|propagator| propagator.extract(&extractor)))
        .filter(|parent| parent.span().span_context().is_valid())
}

/// Attaches the OpenTelemetry parent context found in tonic metadata to `span`.
///
/// Returns `true` only when a valid parent context was extracted and
/// `set_parent` succeeded. Otherwise `span` is left as-is and `false` is
/// returned.
#[cfg(all(feature = "otlp", feature = "rpc"))]
pub fn set_span_parent_from_metadata(
    span: &tracing::Span,
    metadata: &tonic::metadata::MetadataMap,
) -> bool {
    use tracing_opentelemetry::OpenTelemetrySpanExt;

    match opentelemetry_context_from_metadata(metadata) {
        Some(parent) => span.set_parent(parent).is_ok(),
        None => false,
    }
}

/// Writes the current span's OpenTelemetry context into tonic metadata.
///
/// The globally registered text-map propagator is used for injection.
///
/// Returns `Ok(false)` without touching `metadata` when the current span has
/// no valid span context. Otherwise it returns `Ok(true)` when `metadata`
/// holds a traceparent entry after injection, and `Ok(false)` when it does
/// not.
///
/// # Errors
///
/// Returns the `InvalidMetadataValue` recorded when the propagator emitted a
/// value that could not be encoded as ASCII metadata. Entries written before
/// that failure remain in `metadata`.
#[cfg(all(feature = "otlp", feature = "rpc"))]
pub fn inject_current_context_metadata(
    metadata: &mut tonic::metadata::MetadataMap,
) -> Result<bool, tonic::metadata::errors::InvalidMetadataValue> {
    use opentelemetry::{global, trace::TraceContextExt};
    use tracing_opentelemetry::OpenTelemetrySpanExt;

    let current = tracing::Span::current().context();
    let has_active_span = current.span().span_context().is_valid();
    if !has_active_span {
        return Ok(false);
    }

    let mut injector = MetadataMapInjector { metadata, invalid_value: None };
    global::get_text_map_propagator(|propagator| {
        propagator.inject_context(&current, &mut injector);
    });

    match injector.invalid_value {
        Some(error) => Err(error),
        None => Ok(injector.metadata.contains_key(TRACEPARENT_HEADER)),
    }
}

/// Returns the 32-character trace-id field of `traceparent`.
///
/// Yields `None` when `traceparent` does not pass this module's format check.
/// The returned slice borrows from the input.
pub fn trace_id_from_traceparent(traceparent: &str) -> Option<&str> {
    is_valid_traceparent(traceparent)
        .then_some(traceparent)
        .and_then(|valid| valid.split('-').nth(1))
}

/// Returns the 16-character parent/span-id field of `traceparent`.
///
/// Yields `None` when `traceparent` does not pass this module's format check.
/// The returned slice borrows from the input.
pub fn span_id_from_traceparent(traceparent: &str) -> Option<&str> {
    is_valid_traceparent(traceparent)
        .then_some(traceparent)
        .and_then(|valid| valid.split('-').nth(2))
}

/// Returns `true` when `value` is a well-formed W3C `traceparent` header value.
///
/// Accepted shape: exactly four `-`-separated fields,
/// `version-trace_id-parent_id-flags`, made of 2, 32, 16 and 2 lowercase hex
/// digits respectively. The trace id and parent id must not be all zeros. The
/// reserved version `ff` is rejected.
///
/// Uppercase hex is rejected because the W3C grammar permits only lowercase
/// hex digits in every field.
fn is_valid_traceparent(value: &str) -> bool {
    let mut fields = value.split('-');
    // Pull exactly four fields and require that nothing follows them.
    let (Some(version), Some(trace_id), Some(span_id), Some(flags), None) = (
        fields.next(),
        fields.next(),
        fields.next(),
        fields.next(),
        fields.next(),
    ) else {
        return false;
    };

    is_lowercase_hex(version, 2)
        && version != "ff"
        && is_lowercase_hex(trace_id, 32)
        && is_lowercase_hex(span_id, 16)
        && is_lowercase_hex(flags, 2)
        // All-zero ids are explicitly invalid in the W3C format.
        && trace_id.bytes().any(|byte| byte != b'0')
        && span_id.bytes().any(|byte| byte != b'0')
}

/// Returns `true` when `field` is exactly `len` lowercase hexadecimal digits.
fn is_lowercase_hex(field: &str, len: usize) -> bool {
    field.len() == len
        && field
            .bytes()
            .all(|byte| matches!(byte, b'0'..=b'9' | b'a'..=b'f'))
}

/// Read-only adapter that exposes an HTTP [`HeaderMap`] to OpenTelemetry
/// text-map propagators.
#[cfg(feature = "otlp")]
struct HeaderMapExtractor<'a> {
    headers: &'a HeaderMap,
}

#[cfg(feature = "otlp")]
impl opentelemetry::propagation::Extractor for HeaderMapExtractor<'_> {
    /// Returns the first value stored for `key`, skipping values that are not
    /// visible ASCII text.
    fn get(&self, key: &str) -> Option<&str> {
        let value = self.headers.get(key)?;
        value.to_str().ok()
    }

    /// Lists every distinct header name present in the map.
    fn keys(&self) -> Vec<&str> {
        let mut names = Vec::with_capacity(self.headers.keys_len());
        names.extend(self.headers.keys().map(|name| name.as_str()));
        names
    }
}

/// Extractor that serves a single, already-validated traceparent string.
///
/// Only the traceparent key is answered; every other key yields `None`.
#[cfg(feature = "otlp")]
struct TraceParentExtractor<'a> {
    traceparent: &'a str,
}

#[cfg(feature = "otlp")]
impl opentelemetry::propagation::Extractor for TraceParentExtractor<'_> {
    /// Returns the stored value when `key` matches [`TRACEPARENT_HEADER`],
    /// compared ASCII case-insensitively.
    fn get(&self, key: &str) -> Option<&str> {
        if key.eq_ignore_ascii_case(TRACEPARENT_HEADER) {
            Some(self.traceparent)
        } else {
            None
        }
    }

    /// Advertises the traceparent key as the only key.
    fn keys(&self) -> Vec<&str> {
        Vec::from([TRACEPARENT_HEADER])
    }
}

/// Read-only adapter that exposes tonic metadata to OpenTelemetry text-map
/// propagators.
#[cfg(all(feature = "otlp", feature = "rpc"))]
struct MetadataMapExtractor<'a> {
    metadata: &'a tonic::metadata::MetadataMap,
}

#[cfg(all(feature = "otlp", feature = "rpc"))]
impl opentelemetry::propagation::Extractor for MetadataMapExtractor<'_> {
    /// Returns the ASCII value stored for `key` when it converts to text.
    fn get(&self, key: &str) -> Option<&str> {
        let value = self.metadata.get(key)?;
        value.to_str().ok()
    }

    /// Lists the ASCII metadata keys; binary keys are omitted.
    fn keys(&self) -> Vec<&str> {
        use tonic::metadata::KeyRef;

        let mut names = Vec::new();
        for entry in self.metadata.keys() {
            if let KeyRef::Ascii(name) = entry {
                names.push(name.as_str());
            }
        }
        names
    }
}

/// Adapter that lets OpenTelemetry propagators write into tonic metadata.
///
/// `Injector::set` cannot report failures, so the most recent value-encoding
/// error is parked in `invalid_value` for the caller to inspect afterwards.
#[cfg(all(feature = "otlp", feature = "rpc"))]
struct MetadataMapInjector<'a> {
    metadata: &'a mut tonic::metadata::MetadataMap,
    invalid_value: Option<tonic::metadata::errors::InvalidMetadataValue>,
}

#[cfg(all(feature = "otlp", feature = "rpc"))]
impl opentelemetry::propagation::Injector for MetadataMapInjector<'_> {
    /// Inserts `value` under `key` as ASCII metadata.
    ///
    /// Keys that are not valid ASCII metadata names are skipped silently.
    /// Values that cannot be encoded are recorded in `invalid_value` and not
    /// inserted.
    fn set(&mut self, key: &str, value: String) {
        use tonic::metadata::{Ascii, MetadataKey, MetadataValue};

        let Ok(name) = key.parse::<MetadataKey<Ascii>>() else {
            return;
        };
        let encoded = match MetadataValue::<Ascii>::try_from(value.as_str()) {
            Ok(encoded) => encoded,
            Err(error) => {
                self.invalid_value = Some(error);
                return;
            }
        };
        self.metadata.insert(name, encoded);
    }
}

#[cfg(test)]
mod tests {
    use http::HeaderMap;

    use super::{
        REQUEST_ID_HEADER, TRACEPARENT_HEADER, current_trace_id, request_id_from_headers,
        trace_id_from_traceparent, traceparent_from_headers,
    };

    #[test]
    fn extracts_request_id_from_headers() {
        let headers = {
            let mut map = HeaderMap::new();
            map.insert(REQUEST_ID_HEADER, "req-1".parse().expect("header"));
            map
        };

        let request_id = request_id_from_headers(&headers);
        assert_eq!(request_id, Some(String::from("req-1")));
    }

    #[test]
    fn trace_id_is_not_forged_without_active_context() {
        // No OpenTelemetry span is active here, so no trace id may be reported.
        assert_eq!(current_trace_id(), None);
    }

    #[test]
    fn extracts_valid_traceparent_from_headers() {
        let sample = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01";
        let headers = {
            let mut map = HeaderMap::new();
            map.insert(TRACEPARENT_HEADER, sample.parse().expect("traceparent"));
            map
        };

        let traceparent = traceparent_from_headers(&headers).expect("traceparent");
        let trace_id = trace_id_from_traceparent(&traceparent);
        assert_eq!(trace_id, Some("4bf92f3577b34da6a3ce929d0e0e4736"));
    }
}