use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use http::{Request, Response};
use opentelemetry::global;
use opentelemetry_http::HeaderExtractor;
use pin_project::pin_project;
use tower::{Layer, Service};
use tracing::{debug_span, instrument::Instrumented, Instrument, Span};
use tracing_opentelemetry::OpenTelemetrySpanExt;
/// Tower [`Layer`] that wraps a service in [`ExtractPropagation`].
///
/// The wrapped service handles each request inside a `request` span whose
/// parent is the context extracted by the global OpenTelemetry text-map
/// propagator from the request headers. The layer itself carries no state.
#[derive(Clone, Copy, Debug, Default)]
pub struct ExtractPropagationLayer;
impl<S> Layer<S> for ExtractPropagationLayer {
    type Service = ExtractPropagation<S>;

    /// Wraps `service` in [`ExtractPropagation`]; nothing from the (stateless)
    /// layer needs to be copied into the resulting middleware.
    fn layer(&self, service: S) -> Self::Service {
        ExtractPropagation::<S> { inner: service }
    }
}
/// Middleware service produced by [`ExtractPropagationLayer`].
///
/// For each request it creates a `request` span, parents that span to the
/// OpenTelemetry context extracted from the request headers, and instruments
/// the inner service's response future with it.
#[derive(Clone, Debug)]
pub struct ExtractPropagation<S> {
    /// The wrapped service that actually handles requests.
    inner: S,
}
#[pin_project]
pub struct ExtractPropagationFuture<F> {
#[pin]
response_future: F,
}
impl<F, Body, Error> Future for ExtractPropagationFuture<F>
where
F: Future<Output = Result<Response<Body>, Error>>,
{
type Output = Result<Response<Body>, Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
match this.response_future.poll(cx) {
Poll::Ready(result) => match result {
Ok(response) => {
Span::current().record("http.status_code", response.status().as_u16());
Poll::Ready(Ok(response))
}
other => Poll::Ready(other),
},
Poll::Pending => Poll::Pending,
}
}
}
impl<S, Body, ResponseBody> Service<Request<Body>> for ExtractPropagation<S>
where
S: Service<Request<Body>, Response = Response<ResponseBody>> + Send + 'static,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = ExtractPropagationFuture<Instrumented<S::Future>>;
fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
let span = debug_span!(
"request",
http.uri = %req.uri(),
http.method = %req.method(),
http.status_code = tracing::field::Empty,
);
let parent_context = global::get_text_map_propagator(|propagator| {
propagator.extract(&HeaderExtractor(req.headers()))
});
span.set_parent(parent_context);
let response_future = self.inner.call(req).instrument(span);
ExtractPropagationFuture { response_future }
}
}