use http::{Request, Response};
use pin_project_lite::pin_project;
use std::{
error::Error,
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tower::{Layer, Service};
use tracing::Span;
use tracing_opentelemetry_instrumentation_sdk::http as otel_http;
/// Predicate that decides whether a request gets a tracing span.
///
/// `OtelAxumService::call` invokes it with the request URI path; a span is
/// created only when it returns `true`.
pub type Filter = fn(&str) -> bool;
/// Accessor that borrows a string slice out of a value of type `T`.
///
/// `OtelAxumService::call` applies it to the `P` value found in the request
/// extensions to obtain the route recorded on the span.
pub type AsStr<T> = fn(&T) -> &str;
/// Configuration for the tracing middleware; `Layer::layer` copies these
/// settings into each `OtelAxumService` it builds.
///
/// `P` is the type looked up in the request extensions to find the matched
/// route.
#[derive(Debug, Clone)]
pub struct OtelAxumLayer<P> {
/// Turns the `P` request extension into the route string.
matched_path_as_str: AsStr<P>,
/// Optional path predicate; `None` (the default from `new`) traces every request.
filter: Option<Filter>,
/// When `true`, the response future injects the current context into the
/// headers of successful responses. `new` sets this to `false`.
inject_context: bool,
}
impl<P> OtelAxumLayer<P> {
pub fn new(matched_path_as_str: AsStr<P>) -> Self {
OtelAxumLayer {
matched_path_as_str,
filter: None,
inject_context: false,
}
}
pub fn filter(self, filter: Filter) -> Self {
OtelAxumLayer {
filter: Some(filter),
..self
}
}
pub fn inject_context(self, inject_context: bool) -> Self {
OtelAxumLayer {
inject_context,
..self
}
}
}
impl<S, P> Layer<S> for OtelAxumLayer<P> {
    type Service = OtelAxumService<S, P>;

    /// Wraps `inner` in an `OtelAxumService` carrying a copy of this layer's
    /// settings.
    fn layer(&self, inner: S) -> Self::Service {
        // Every setting is `Copy` (fn pointers and a bool), so they can be
        // bound by value straight out of the shared reference.
        let &Self {
            matched_path_as_str,
            filter,
            inject_context,
        } = self;
        OtelAxumService {
            inner,
            matched_path_as_str,
            filter,
            inject_context,
        }
    }
}
/// Middleware service produced by `OtelAxumLayer`; wraps `inner` and creates a
/// tracing span for each request that passes the filter.
#[derive(Debug, Clone)]
pub struct OtelAxumService<S, P> {
/// The wrapped service that actually handles the request.
inner: S,
/// Turns the `P` request extension into the route string.
matched_path_as_str: AsStr<P>,
/// Optional path predicate; when `None`, every request gets a span.
filter: Option<Filter>,
/// Copied into each `ResponseFuture`, where it controls whether the current
/// context is injected into the response headers.
inject_context: bool,
}
/// Middleware implementation: creates a span per request (subject to the
/// optional path filter), calls the inner service inside that span, and
/// returns a `ResponseFuture` that keeps the span for the rest of the request.
impl<S, B, B2, P> Service<Request<B>> for OtelAxumService<S, P>
where
    S: Service<Request<B>, Response = Response<B2>> + Clone + Send + 'static,
    S::Error: Error + 'static,
    S::Future: Send + 'static,
    B: Send + 'static,
    P: Send + Sync + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = ResponseFuture<S::Future>;

    /// Readiness is delegated to the wrapped service unchanged.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Builds the request span and calls the inner service within it.
    ///
    /// When a filter is configured and rejects the request path, a disabled
    /// span (`Span::none()`) is used instead.
    fn call(&mut self, req: Request<B>) -> Self::Future {
        use tracing_opentelemetry::OpenTelemetrySpanExt;
        let span = if self.filter.is_none_or(|f| f(req.uri().path())) {
            let span = otel_http::http_server::make_span_from_request(&req);
            // The route comes from the `P` request extension; it is empty when
            // that extension is absent.
            let matched_path = req.extensions().get::<P>();
            let route = matched_path.map_or("", self.matched_path_as_str);
            let method = req.method();
            span.record("http.route", route);
            // `trim` removes the trailing space left behind when `route` is empty.
            span.record("otel.name", format!("{method} {route}").trim());
            if let Err(err) = span.set_parent(otel_http::extract_context(req.headers())) {
                tracing::warn!(?err, "span context cannot be set");
            };
            span
        } else {
            tracing::Span::none()
        };
        let future = {
            // The guard must be bound to a named variable: `let _ = span.enter()`
            // drops it immediately, so the inner service would be called outside
            // the span. `_enter` lives until the end of this block.
            let _enter = span.enter();
            self.inner.call(req)
        };
        ResponseFuture {
            inner: future,
            inject_context: self.inject_context,
            span,
        }
    }
}
pin_project! {
/// Future returned by `OtelAxumService::call`.
///
/// Bundles the inner service's future with the span created for the request
/// and the context-injection flag.
pub struct ResponseFuture<F> {
// The inner service's future; `#[pin]` lets `poll` obtain it pinned through `project()`.
#[pin]
pub(crate) inner: F,
// When true, `poll` injects the current context into the headers of an `Ok` response.
pub(crate) inject_context: bool,
// Span entered around every poll of `inner` and updated once the result is ready.
pub(crate) span: Span,
}
}
impl<Fut, ResBody, E> Future for ResponseFuture<Fut>
where
    Fut: Future<Output = Result<Response<ResBody>, E>>,
    E: std::error::Error + 'static,
{
    type Output = Result<Response<ResBody>, E>;

    /// Polls the inner future inside the request span.
    ///
    /// Once the inner future is ready, its result is handed to
    /// `update_span_from_response_or_error`; if context injection is enabled
    /// and the result is `Ok`, the current context is injected into the
    /// response headers before the result is returned.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let projected = self.project();
        // Keep the span entered for the whole poll, including the bookkeeping below.
        let _entered = projected.span.enter();

        let mut outcome = match projected.inner.poll(cx) {
            Poll::Ready(outcome) => outcome,
            Poll::Pending => return Poll::Pending,
        };

        otel_http::http_server::update_span_from_response_or_error(projected.span, &outcome);

        if *projected.inject_context
            && let Ok(response) = outcome.as_mut()
        {
            let context = tracing_opentelemetry_instrumentation_sdk::find_current_context();
            otel_http::inject_context(&context, response.headers_mut());
        }

        Poll::Ready(outcome)
    }
}