use std::pin::Pin;
use std::task::{Context, Poll};
use opentelemetry::trace::FutureExt as _;
use opentelemetry::trace::WithContext;
use opentelemetry_http::{HeaderExtractor, HeaderInjector};
use pin_project_lite::pin_project;
use tower::{Layer, Service};
/// Tower [`Layer`] for the receiving side of an HTTP hop.
///
/// Wraps a service in [`HttpServerPropagator`], which reads the trace context
/// from each incoming request's headers using the globally configured
/// text-map propagator.
#[derive(Debug, Clone, Default)]
pub struct HttpServerPropagatorLayer {
    // Private unit field: the struct cannot be built with a literal outside this
    // module (presumably to leave room for future configuration fields).
    _private: (),
}

impl HttpServerPropagatorLayer {
    /// Creates a server-side propagation layer; equivalent to `Default::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
impl<S> Layer<S> for HttpServerPropagatorLayer {
    type Service = HttpServerPropagator<S>;

    fn layer(&self, service: S) -> Self::Service {
        // The layer carries no configuration, so wrapping is just a move.
        HttpServerPropagator::<S>::new(service)
    }
}
/// Tower [`Layer`] for the sending side of an HTTP hop.
///
/// Wraps a service in [`HttpClientPropagator`], which writes the current trace
/// context into each outgoing request's headers using the globally configured
/// text-map propagator.
#[derive(Debug, Clone, Default)]
pub struct HttpClientPropagatorLayer {
    // Private unit field: the struct cannot be built with a literal outside this
    // module (presumably to leave room for future configuration fields).
    _private: (),
}

impl HttpClientPropagatorLayer {
    /// Creates a client-side propagation layer; equivalent to `Default::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
impl<S> Layer<S> for HttpClientPropagatorLayer {
    type Service = HttpClientPropagator<S>;

    fn layer(&self, service: S) -> Self::Service {
        // The layer carries no configuration, so wrapping is just a move.
        HttpClientPropagator::<S>::new(service)
    }
}
/// Service wrapper that extracts the trace context carried in an incoming
/// request's headers and attaches it to the inner service's response future.
///
/// Built by [`HttpServerPropagatorLayer`]. `Debug` is derived, like on the
/// layer, and is available whenever the wrapped service is `Debug`.
#[derive(Debug, Clone)]
pub struct HttpServerPropagator<S> {
    inner: S,
}

impl<S> HttpServerPropagator<S> {
    /// Wraps `inner`. Kept private; callers go through the layer.
    fn new(inner: S) -> Self {
        Self { inner }
    }
}
impl<B, S> Service<http::Request<B>> for HttpServerPropagator<S>
where
S: Service<http::Request<B>>,
{
type Error = S::Error;
type Response = S::Response;
type Future = WithContext<S::Future>;
fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: http::Request<B>) -> Self::Future {
let cx = opentelemetry::global::get_text_map_propagator(|p| {
p.extract(&HeaderExtractor(req.headers()))
});
self.inner.call(req).with_context(cx)
}
}
/// Service wrapper that injects the trace context into outgoing request
/// headers before forwarding the request to the inner service.
///
/// The injected context is the one that is current when the response future is
/// first polled (see [`HttpClientPropagatorFuture`]). Built by
/// [`HttpClientPropagatorLayer`]. `Debug` is derived, like on the layer, and is
/// available whenever the wrapped service is `Debug`.
#[derive(Debug, Clone)]
pub struct HttpClientPropagator<S> {
    inner: S,
}

impl<S> HttpClientPropagator<S> {
    /// Wraps `inner`. Kept private; callers go through the layer.
    fn new(inner: S) -> Self {
        Self { inner }
    }
}
impl<B, S> Service<http::Request<B>> for HttpClientPropagator<S>
where
    S: Service<http::Request<B>> + Clone,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = HttpClientPropagatorFuture<S, B>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Defers header injection and the inner `call` until the returned future
    /// is first polled.
    ///
    /// The instance that `poll_ready` drove to readiness is moved into the
    /// future, and a fresh clone takes its place in `self`. This follows the
    /// "be careful when cloning inner services" pattern from tower's `Service`
    /// docs.
    fn call(&mut self, req: http::Request<B>) -> Self::Future {
        let replacement = self.inner.clone();
        let service = Some(std::mem::replace(&mut self.inner, replacement));
        let request = Some(req);
        HttpClientPropagatorFuture::Pending { service, request }
    }
}
pin_project! {
    #[project = HttpClientPropagatorFutureProj]
    /// Response future of [`HttpClientPropagator`].
    ///
    /// Starts in `Pending`, which holds the ready inner service and the request.
    /// On the first poll it injects the current trace context into the request
    /// headers, calls the inner service, and switches to `Active`. `Active` then
    /// drives the inner service's future to completion.
    pub enum HttpClientPropagatorFuture<S, B>
    where
        S: Service<http::Request<B>>,
    {
        // Inner service not yet called. Both fields are `Some` until the first
        // poll takes them.
        Pending {
            service: Option<S>,
            request: Option<http::Request<B>>,
        },
        // Inner `call` has been issued; its future is pinned (structural projection).
        Active {
            #[pin]
            future: S::Future,
        },
    }
}
/// Runs the two-state machine of [`HttpClientPropagatorFuture`].
///
/// The first poll injects the trace context that is current *at that moment*
/// (not when `call` ran) into the request headers. Context attached around the
/// poll by outer layers, such as a lazily created client span, therefore
/// reaches the outgoing headers. The poll then issues the inner `call` and
/// polls the resulting future. Later polls go straight to that future.
///
/// # Panics
///
/// Panics with "polled after completion" if the future is found in `Pending`
/// with its fields already taken. This can only happen if an earlier poll
/// panicked midway through the `Pending` → `Active` transition.
impl<S, B> std::future::Future for HttpClientPropagatorFuture<S, B>
where
    S: Service<http::Request<B>>,
{
    type Output = Result<S::Response, S::Error>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        loop {
            match self.as_mut().project() {
                HttpClientPropagatorFutureProj::Active { future } => return future.poll(cx),
                HttpClientPropagatorFutureProj::Pending { service, request } => {
                    let mut svc = service.take().expect("polled after completion");
                    let mut req = request.take().expect("polled after completion");
                    opentelemetry::global::get_text_map_propagator(|propagator| {
                        propagator.inject(&mut HeaderInjector(req.headers_mut()))
                    });
                    let future = svc.call(req);
                    // Replace the spent `Pending` state; the next loop turn polls `future`.
                    self.set(HttpClientPropagatorFuture::Active { future });
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    // Each test creates a `TelemetryContext` and compares the recorded spans
    // with `assert_spans_snapshot!` inline snapshots. Presumably the context
    // installs a test tracer/propagator for the duration of the test; confirm in
    // `apollo_opentelemetry_test`.
    use super::*;
    use crate::tower::ServiceBuilderExt;
    use apollo_opentelemetry_test::{TelemetryContext, assert_spans_snapshot};
    use http::{Request, Response};
    use opentelemetry::InstrumentationScope;
    use opentelemetry::trace::SpanBuilder;
    use tower::ServiceBuilder;
    use tower::ServiceExt as _;

    /// Shared instrumentation scope named "test". It is built once and reused by
    /// the `traced` layers below.
    fn test_scope() -> &'static InstrumentationScope {
        static SCOPE: std::sync::LazyLock<InstrumentationScope> =
            std::sync::LazyLock::new(|| InstrumentationScope::builder("test").build());
        &SCOPE
    }

    /// A W3C `traceparent` header on the incoming request becomes the parent of
    /// the server span created by the inner `traced` layer.
    #[tokio::test]
    async fn test_server_propagator_extracts_parent_context() {
        let ctx = TelemetryContext::new();
        let (mut service, mut handle) = tower_test::mock::spawn_with(|inner| {
            ServiceBuilder::new()
                .layer(HttpServerPropagatorLayer::new())
                .traced(test_scope(), |_req: &Request<()>| {
                    SpanBuilder::from_name("server-span")
                        .with_kind(opentelemetry::trace::SpanKind::Server)
                })
                .service(inner)
        });
        // W3C traceparent fields: version 00, trace id, parent span id,
        // flags 01 (sampled).
        let traceparent = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01";
        let req = Request::builder()
            .header("traceparent", traceparent)
            .body(())
            .unwrap();
        assert!(service.poll_ready().is_ready());
        let response_fut = service.call(req);
        // Act as the mocked inner service: receive the forwarded request and answer it.
        let (_req, send_response) = handle.next_request().await.unwrap();
        send_response.send_response(Response::new("ok"));
        response_fut.await.unwrap();
        assert_spans_snapshot!(ctx, @r#"
- name: server-span
span_kind: Server
has_parent: true
is_sampled: true
"#);
    }

    /// The extracted context is current only while the response future is
    /// polled, not during the inner service's synchronous `call`.
    #[tokio::test]
    async fn test_server_propagator_context_only_active_during_poll() {
        use opentelemetry::trace::Tracer;
        use std::future::Future;
        use std::pin::Pin;
        use std::task::{Context, Poll};
        let ctx = TelemetryContext::new();
        // Inner service that starts one span synchronously in `call` and another
        // inside the returned future. The snapshot shows which of the two saw
        // the extracted remote parent.
        #[derive(Clone)]
        struct SpanInCallAndPoll;
        impl Service<Request<()>> for SpanInCallAndPoll {
            type Response = Response<&'static str>;
            type Error = std::convert::Infallible;
            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
            fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }
            fn call(&mut self, _req: Request<()>) -> Self::Future {
                let tracer = opentelemetry::global::tracer("test");
                let _span = tracer.start("span-in-call");
                Box::pin(async {
                    let tracer = opentelemetry::global::tracer("test");
                    let _span = tracer.start("span-in-poll");
                    Ok(Response::new("ok"))
                })
            }
        }
        let mut service = ServiceBuilder::new()
            .layer(HttpServerPropagatorLayer::new())
            .service(SpanInCallAndPoll);
        let traceparent = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01";
        let req = Request::builder()
            .header("traceparent", traceparent)
            .body(())
            .unwrap();
        service.ready().await.unwrap();
        service.call(req).await.unwrap();
        // Expected: only `span-in-poll` reports a parent. `span-in-call` has no
        // `has_parent` entry.
        assert_spans_snapshot!(ctx, @r#"
- name: span-in-call
span_kind: Internal
is_sampled: true
- name: span-in-poll
span_kind: Internal
has_parent: true
is_sampled: true
"#);
    }

    /// The client propagator writes a W3C `traceparent` header for the context
    /// that is current when its response future is first polled.
    #[tokio::test]
    async fn test_client_propagator_injects_context() {
        let ctx = TelemetryContext::new();
        let (mut service, mut handle) = tower_test::mock::spawn_with(|inner| {
            ServiceBuilder::new()
                .layer(HttpClientPropagatorLayer::new())
                .service(inner)
        });
        use opentelemetry::trace::{TraceContextExt, Tracer};
        let tracer = opentelemetry::global::tracer("test");
        let span = tracer.start("client-span");
        let cx = opentelemetry::Context::current_with_span(span);
        let _guard = cx.attach();
        assert!(service.poll_ready().is_ready());
        let response_fut = service.call(Request::builder().body(()).unwrap());
        // NOTE(review): injection happens on the first poll, which runs inside
        // the spawned task. This presumably relies on `#[tokio::test]`'s default
        // current-thread runtime polling that task on this same thread while
        // `_guard` keeps `cx` attached as the thread-local current context.
        // Under a multi-thread runtime the injected context may differ; confirm.
        let response_handle = tokio::spawn(response_fut);
        let (req, send_response) = handle.next_request().await.unwrap();
        let traceparent = req
            .headers()
            .get("traceparent")
            .expect("traceparent header should be injected");
        let traceparent = traceparent.to_str().unwrap();
        assert!(
            traceparent.starts_with("00-"),
            "traceparent should be in W3C format: {traceparent}"
        );
        send_response.send_response(Response::new("ok"));
        response_handle.await.unwrap().unwrap();
        // Detach `cx` from the current thread before taking the span snapshot.
        drop(_guard);
        assert_spans_snapshot!(ctx, @r#"
- name: client-span
span_kind: Internal
is_sampled: true
"#);
    }

    /// End-to-end check. The trace id from an incoming `traceparent` travels
    /// through the server propagator and the `traced` span to the `traceparent`
    /// header of an outgoing subrequest sent through the client propagator.
    #[tokio::test]
    async fn test_server_and_client_propagation_roundtrip() {
        let ctx = TelemetryContext::new();
        // Records the `traceparent` header seen by the downstream service.
        let captured_traceparent = std::sync::Arc::new(std::sync::Mutex::new(None));
        let captured_clone = captured_traceparent.clone();
        let subrequest_service = tower::service_fn(move |req: Request<()>| {
            let traceparent = req
                .headers()
                .get("traceparent")
                .map(|v| v.to_str().unwrap().to_string());
            *captured_clone.lock().unwrap() = traceparent;
            async { Ok::<_, std::convert::Infallible>(Response::new("subrequest-response")) }
        });
        let subrequest_service = ServiceBuilder::new()
            .layer(HttpClientPropagatorLayer::new())
            .service(subrequest_service);
        // Shared by the handler below. This is tokio's async mutex because its
        // guard is held across `.await`.
        let subrequest_service = std::sync::Arc::new(tokio::sync::Mutex::new(subrequest_service));
        // Handler for the incoming request. It issues one subrequest through the
        // client-propagating service from inside its own future.
        let make_subrequest = tower::service_fn(move |_req: Request<()>| {
            let svc = subrequest_service.clone();
            async move {
                let subreq = Request::builder().body(()).unwrap();
                let mut guard = svc.lock().await;
                guard.ready().await.unwrap();
                guard.call(subreq).await.unwrap();
                Ok::<_, std::convert::Infallible>(Response::new("ok"))
            }
        });
        let mut service = ServiceBuilder::new()
            .layer(HttpServerPropagatorLayer::new())
            .traced(test_scope(), |_req: &Request<()>| {
                SpanBuilder::from_name("handle-request")
                    .with_kind(opentelemetry::trace::SpanKind::Server)
            })
            .service(make_subrequest);
        let incoming_traceparent = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01";
        let req = Request::builder()
            .header("traceparent", incoming_traceparent)
            .body(())
            .unwrap();
        service.ready().await.unwrap();
        service.call(req).await.unwrap();
        let outgoing_traceparent = captured_traceparent.lock().unwrap();
        assert!(
            outgoing_traceparent.is_some(),
            "subrequest should have traceparent header"
        );
        let outgoing = outgoing_traceparent.as_ref().unwrap();
        // Only the trace id (second `-`-separated field) is compared.
        let incoming_trace_id = incoming_traceparent.split('-').nth(1).unwrap();
        let outgoing_trace_id = outgoing.split('-').nth(1).unwrap();
        assert_eq!(
            incoming_trace_id, outgoing_trace_id,
            "trace ID should be propagated through the service"
        );
        assert_spans_snapshot!(ctx, @r#"
- name: handle-request
span_kind: Server
has_parent: true
is_sampled: true
"#);
    }

    /// Here `traced` sits outside the client propagator. The span it creates
    /// must already be current when the propagator injects headers on the first
    /// poll, so the outgoing `traceparent` carries a non-zero parent span id.
    #[tokio::test]
    async fn test_client_propagator_captures_lazy_span_context() {
        let _ctx = TelemetryContext::new();
        let (mut service, mut handle) = tower_test::mock::spawn_with(|inner| {
            ServiceBuilder::new()
                .traced(test_scope(), |_req: &Request<()>| {
                    SpanBuilder::from_name("lazy-client-span")
                        .with_kind(opentelemetry::trace::SpanKind::Client)
                })
                .layer(HttpClientPropagatorLayer::new())
                .service(inner)
        });
        assert!(service.poll_ready().is_ready());
        let response_fut = service.call(Request::builder().body(()).unwrap());
        let response_handle = tokio::spawn(response_fut);
        let (req, send_response) = handle.next_request().await.unwrap();
        let traceparent = req
            .headers()
            .get("traceparent")
            .expect("traceparent should be injected even with lazy span creation");
        let tp = traceparent.to_str().unwrap();
        // The third `-`-separated field of a W3C traceparent is the parent span id.
        let parent_id = tp.split('-').nth(2).unwrap();
        assert_ne!(
            parent_id, "0000000000000000",
            "parent span ID should not be empty - lazy span should be captured"
        );
        send_response.send_response(Response::new("ok"));
        response_handle.await.unwrap().unwrap();
    }
}