use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use opentelemetry::propagation::{Extractor, TextMapPropagator};
use opentelemetry_sdk::propagation::TraceContextPropagator;
use tower::{Layer, Service};
/// Copies the caller's current trace context into the metadata of `req`.
///
/// The `traceparent` and `tracestate` values are obtained from
/// `crate::current_traceparent()` and `crate::current_tracestate()`. Each one is
/// inserted under its own metadata key only when it is present *and* can be
/// converted into a metadata value; anything else is skipped silently so that
/// tracing problems never fail the request.
///
/// # Errors
///
/// This function always returns `Ok`; the `Result` wrapper is part of the
/// existing signature.
pub fn inject_trace_context(
    mut req: tonic::Request<()>,
) -> Result<tonic::Request<()>, tonic::Status> {
    // `traceparent`: skipped when absent or not representable as metadata.
    if let Some(raw_parent) = crate::current_traceparent() {
        if let Ok(header) = tonic::metadata::MetadataValue::try_from(raw_parent.as_str()) {
            req.metadata_mut().insert("traceparent", header);
        }
    }

    // `tracestate`: handled independently of `traceparent`, same best-effort rule.
    if let Some(raw_state) = crate::current_tracestate() {
        if let Ok(header) = tonic::metadata::MetadataValue::try_from(raw_state.as_str()) {
            req.metadata_mut().insert("tracestate", header);
        }
    }

    Ok(req)
}
/// Read-only adapter that lets an OpenTelemetry propagator read from a tonic
/// [`tonic::metadata::MetadataMap`].
struct MetadataExtractor<'a>(&'a tonic::metadata::MetadataMap);

impl<'a> Extractor for MetadataExtractor<'a> {
    /// Returns the value stored under `key`, or `None` when the key is missing
    /// or its value cannot be viewed as a string.
    fn get(&self, key: &str) -> Option<&str> {
        let value = self.0.get(key)?;
        value.to_str().ok()
    }

    /// Returns the names of all metadata keys, both ASCII and binary.
    fn keys(&self) -> Vec<&str> {
        /// Unwraps either kind of key reference to its textual name.
        fn key_name(key: tonic::metadata::KeyRef<'_>) -> &str {
            match key {
                tonic::metadata::KeyRef::Ascii(ascii) => ascii.as_str(),
                tonic::metadata::KeyRef::Binary(binary) => binary.as_str(),
            }
        }

        self.0.keys().map(key_name).collect()
    }
}
/// Extracts W3C trace context from the metadata of `req` and hands it to
/// downstream code through the request extensions.
///
/// The metadata is read with [`TraceContextPropagator`], and the resulting
/// `opentelemetry::Context` is inserted into the request's extensions. Code that
/// receives the request can retrieve it with
/// `request.extensions().get::<opentelemetry::Context>()` and use it as the
/// parent of its own spans.
///
/// The context is intentionally *not* attached here. `Context::attach` returns a
/// guard that restores the previous context when it is dropped, and no guard can
/// outlive this synchronous function. The previous `parent_cx.attach();` statement
/// dropped its guard at the end of the statement, so it had no lasting effect.
///
/// When the metadata carries no usable `traceparent`, whatever context the
/// propagator falls back to is stored instead.
///
/// # Errors
///
/// This function always returns `Ok`; the `Result` wrapper is part of the
/// existing signature.
pub fn extract_trace_context(
    mut req: tonic::Request<()>,
) -> Result<tonic::Request<()>, tonic::Status> {
    let propagator = TraceContextPropagator::new();
    let parent_cx = propagator.extract(&MetadataExtractor(req.metadata()));

    // Carry the extracted context with the request itself so it survives past
    // the end of this function.
    req.extensions_mut().insert(parent_cx);

    Ok(req)
}
/// Tower [`Layer`] that wraps a service in [`GrpcDurationService`].
///
/// The layer itself holds no state.
#[derive(Clone, Copy, Debug, Default)]
pub struct GrpcDurationLayer;

impl<S> Layer<S> for GrpcDurationLayer {
    type Service = GrpcDurationService<S>;

    /// Wraps `service` so that it is driven through a [`GrpcDurationService`].
    fn layer(&self, service: S) -> Self::Service {
        GrpcDurationService { inner: service }
    }
}
/// Service wrapper created by `GrpcDurationLayer`.
///
/// Its `Service` implementation forwards every request to the wrapped service and
/// reports how long the call took through
/// `crate::global_metrics().observe_grpc_duration`.
#[derive(Clone, Debug)]
pub struct GrpcDurationService<S> {
// The wrapped service that actually handles each request.
inner: S,
}
/// Forwards requests to the wrapped service and records the duration of each call.
impl<S, B> Service<http::Request<B>> for GrpcDurationService<S>
where
    S: Service<http::Request<B>> + Send + 'static,
    S::Future: Send + 'static,
    B: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<S::Response, S::Error>> + Send>>;

    /// Readiness is decided entirely by the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Calls the wrapped service and, once its future resolves, reports the
    /// elapsed time in seconds, passing the request's URI path along with it.
    ///
    /// The duration is reported for both `Ok` and `Err` outcomes, and the
    /// outcome is returned unchanged.
    fn call(&mut self, req: http::Request<B>) -> Self::Future {
        // Capture the path before the request is moved into the inner service.
        let route = req.uri().path().to_owned();
        let started_at = std::time::Instant::now();
        let pending = self.inner.call(req);

        let timed = async move {
            let outcome = pending.await;
            let seconds = started_at.elapsed().as_secs_f64();
            crate::global_metrics().observe_grpc_duration(&route, seconds);
            outcome
        };

        Box::pin(timed)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A well-formed W3C `traceparent` value: version `00`, a 32-hex-digit
    /// trace-id, a 16-hex-digit parent-id and a 2-hex-digit flags field.
    const VALID_TRACEPARENT: &str = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01";

    /// Metadata key used to check that interceptors leave existing entries alone.
    const MARKER_KEY: &str = "x-test-marker";

    /// Value stored under [`MARKER_KEY`].
    const MARKER_VALUE: &str = "marker";

    /// Builds an empty request that carries only the marker metadata entry.
    fn request_with_marker() -> tonic::Request<()> {
        let mut req = tonic::Request::new(());
        req.metadata_mut().insert(
            MARKER_KEY,
            tonic::metadata::MetadataValue::from_static(MARKER_VALUE),
        );
        req
    }

    /// Reads the marker metadata entry back as a string, if it is present.
    fn marker(req: &tonic::Request<()>) -> Option<&str> {
        req.metadata().get(MARKER_KEY).and_then(|v| v.to_str().ok())
    }

    #[test]
    fn inject_trace_context_with_no_span_passes_through() {
        let req = tonic::Request::new(());
        let result = inject_trace_context(req);
        assert!(result.is_ok());
    }

    #[test]
    fn extract_trace_context_with_no_header_passes_through() {
        let req = tonic::Request::new(());
        let result = extract_trace_context(req);
        assert!(result.is_ok());
    }

    #[test]
    fn extract_trace_context_with_valid_header() {
        let mut req = tonic::Request::new(());
        req.metadata_mut().insert(
            "traceparent",
            tonic::metadata::MetadataValue::from_static(VALID_TRACEPARENT),
        );
        let req = extract_trace_context(req).expect("extraction never fails");
        // Extraction only reads the header; it must still be on the request.
        assert_eq!(
            req.metadata()
                .get("traceparent")
                .and_then(|v| v.to_str().ok()),
            Some(VALID_TRACEPARENT),
        );
    }

    #[test]
    fn extract_trace_context_with_empty_header() {
        let mut req = tonic::Request::new(());
        req.metadata_mut().insert(
            "traceparent",
            tonic::metadata::MetadataValue::from_static(""),
        );
        let result = extract_trace_context(req);
        assert!(result.is_ok());
    }

    #[test]
    fn inject_trace_context_preserves_request() {
        let req = inject_trace_context(request_with_marker()).expect("injection never fails");
        assert_eq!(marker(&req), Some(MARKER_VALUE));
    }

    #[test]
    fn extract_trace_context_preserves_request() {
        let req = extract_trace_context(request_with_marker()).expect("extraction never fails");
        assert_eq!(marker(&req), Some(MARKER_VALUE));
    }
}