use crate::metrics::route_formatter::RouteFormatter;
use actix_http::header::HeaderMap;
use actix_service::{Service, Transform};
use actix_web::dev::{ServiceRequest, ServiceResponse};
use actix_web::Error;
use futures::future::{ok, Ready};
use futures::FutureExt as _;
use futures_core::future::LocalBoxFuture;
use opentelemetry::global;
use opentelemetry::propagation::Extractor;
use opentelemetry::trace::{FutureExt, SpanKind, Status, TraceContextExt, Tracer, TracerProvider};
use opentelemetry_semantic_conventions::trace::HTTP_RESPONSE_STATUS_CODE;
use std::borrow::Cow;
use std::rc::Rc;
use std::task::Poll;
#[derive(Default, Debug)]
pub struct RequestTracing {
route_formatter: Option<Rc<dyn RouteFormatter + 'static>>,
}
impl RequestTracing {
pub fn new() -> RequestTracing {
RequestTracing::default()
}
pub fn with_formatter<T: RouteFormatter + 'static>(route_formatter: T) -> Self {
RequestTracing {
route_formatter: Some(Rc::new(route_formatter)),
}
}
}
impl<S, B> Transform<S, ServiceRequest> for RequestTracing
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<B>;
type Error = Error;
type Transform = RequestTracingMiddleware<S>;
type InitError = ();
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
ok(RequestTracingMiddleware::new(
global::tracer_provider().versioned_tracer(
"unleash-edge",
Some(env!("CARGO_PKG_VERSION")),
Some("https://opentelemetry.io/schema/1.0.0"),
None,
),
service,
self.route_formatter.clone(),
))
}
}
#[derive(Debug)]
pub struct RequestTracingMiddleware<S> {
tracer: global::BoxedTracer,
service: S,
route_formatter: Option<Rc<dyn RouteFormatter>>,
}
impl<S, B> RequestTracingMiddleware<S>
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static,
B: 'static,
{
fn new(
tracer: global::BoxedTracer,
service: S,
route_formatter: Option<Rc<dyn RouteFormatter>>,
) -> Self {
RequestTracingMiddleware {
tracer,
service,
route_formatter,
}
}
}
impl<S, B> Service<ServiceRequest> for RequestTracingMiddleware<S>
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<B>;
type Error = Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx)
}
fn call(&self, mut req: ServiceRequest) -> Self::Future {
let parent_context = global::get_text_map_propagator(|propagator| {
propagator.extract(&RequestHeaderCarrier::new(req.headers_mut()))
});
let mut http_route: Cow<'static, str> = req
.match_pattern()
.map(Into::into)
.unwrap_or_else(|| "default".into());
if let Some(formatter) = &self.route_formatter {
http_route = formatter.format(&http_route).into();
}
let mut builder = self.tracer.span_builder(http_route.clone());
builder.span_kind = Some(SpanKind::Server);
builder.attributes = Some(
crate::metrics::actix_web_metrics::trace_attributes_from_request(&req, &http_route),
);
let span = self.tracer.build_with_context(builder, &parent_context);
let cx = parent_context.with_span(span);
#[cfg(feature = "sync-middleware")]
let attachment = cx.clone().attach();
let fut = self
.service
.call(req)
.with_context(cx.clone())
.map(move |res| match res {
Ok(ok_res) => {
let span = cx.span();
span.set_attribute(
HTTP_RESPONSE_STATUS_CODE.i64(ok_res.status().as_u16() as i64),
);
if ok_res.status().is_server_error() {
span.set_status(Status::error(
ok_res
.status()
.canonical_reason()
.map(ToString::to_string)
.unwrap_or_default(),
));
};
span.end();
Ok(ok_res)
}
Err(err) => {
let span = cx.span();
span.set_status(Status::error(format!("{:?}", err)));
span.end();
Err(err)
}
});
#[cfg(feature = "sync-middleware")]
drop(attachment);
Box::pin(fut)
}
}
struct RequestHeaderCarrier<'a> {
headers: &'a HeaderMap,
}
impl<'a> RequestHeaderCarrier<'a> {
fn new(headers: &'a HeaderMap) -> Self {
RequestHeaderCarrier { headers }
}
}
impl<'a> Extractor for RequestHeaderCarrier<'a> {
fn get(&self, key: &str) -> Option<&str> {
self.headers.get(key).and_then(|v| v.to_str().ok())
}
fn keys(&self) -> Vec<&str> {
self.headers.keys().map(|header| header.as_str()).collect()
}
}