// actix_web_opentelemetry/middleware/trace.rs

1use std::{borrow::Cow, rc::Rc, task::Poll};
2
3use actix_web::{
4    dev::{Service, ServiceRequest, ServiceResponse, Transform},
5    http::header::HeaderMap,
6    Error,
7};
8use futures_util::future::{ok, FutureExt as _, LocalBoxFuture, Ready};
9use opentelemetry::{
10    global::{self},
11    propagation::Extractor,
12    trace::{FutureExt as OtelFutureExt, SpanKind, Status, TraceContextExt, Tracer},
13    KeyValue,
14};
15use opentelemetry_semantic_conventions::trace::HTTP_RESPONSE_STATUS_CODE;
16
17use super::{get_scope, route_formatter::RouteFormatter};
18use crate::util::trace_attributes_from_request;
19
20/// Request tracing middleware.
21///
22/// # Examples:
23///
24/// ```no_run
25/// use actix_web::{web, App, HttpServer};
26/// use actix_web_opentelemetry::RequestTracing;
27/// use opentelemetry::global;
28/// use opentelemetry_sdk::trace::SdkTracerProvider;
29///
30/// async fn index() -> &'static str {
31///     "Hello world!"
32/// }
33///
34/// #[actix_web::main]
35/// async fn main() -> std::io::Result<()> {
36///     // Install an OpenTelemetry trace pipeline.
37///     // Swap for https://docs.rs/opentelemetry-jaeger or other compatible
38///     // exporter to send trace information to your collector.
39///     let exporter = opentelemetry_stdout::SpanExporter::default();
40///
41///     // Configure your tracer provider with your exporter(s)
42///     let provider = SdkTracerProvider::builder()
43///         .with_simple_exporter(exporter)
44///         .build();
45///     global::set_tracer_provider(provider);
46///
47///     HttpServer::new(|| {
48///         App::new()
49///             .wrap(RequestTracing::new())
50///             .service(web::resource("/").to(index))
51///     })
52///     .bind("127.0.0.1:8080")?
53///     .run()
54///     .await
55/// }
56///```
57#[derive(Default, Debug)]
58pub struct RequestTracing {
59    route_formatter: Option<Rc<dyn RouteFormatter + 'static>>,
60}
61
62impl RequestTracing {
63    /// Actix web middleware to trace each request in an OpenTelemetry span.
64    pub fn new() -> RequestTracing {
65        RequestTracing::default()
66    }
67
68    /// Actix web middleware to trace each request in an OpenTelemetry span with
69    /// formatted routes.
70    ///
71    /// # Examples
72    ///
73    /// ```no_run
74    /// use actix_web::{web, App, HttpServer};
75    /// use actix_web_opentelemetry::{RouteFormatter, RequestTracing};
76    ///
77    /// # #[actix_web::main]
78    /// # async fn main() -> std::io::Result<()> {
79    ///
80    ///
81    /// #[derive(Debug)]
82    /// struct MyLowercaseFormatter;
83    ///
84    /// impl RouteFormatter for MyLowercaseFormatter {
85    ///     fn format(&self, path: &str) -> String {
86    ///         path.to_lowercase()
87    ///     }
88    /// }
89    ///
90    /// // report /users/{id} as /users/:id
91    /// HttpServer::new(move || {
92    ///     App::new()
93    ///         .wrap(RequestTracing::with_formatter(MyLowercaseFormatter))
94    ///         .service(web::resource("/users/{id}").to(|| async { "ok" }))
95    /// })
96    /// .bind("127.0.0.1:8080")?
97    /// .run()
98    /// .await
99    /// # }
100    /// ```
101    pub fn with_formatter<T: RouteFormatter + 'static>(route_formatter: T) -> Self {
102        RequestTracing {
103            route_formatter: Some(Rc::new(route_formatter)),
104        }
105    }
106}
107
108impl<S, B> Transform<S, ServiceRequest> for RequestTracing
109where
110    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
111    S::Future: 'static,
112    B: 'static,
113{
114    type Response = ServiceResponse<B>;
115    type Error = Error;
116    type Transform = RequestTracingMiddleware<S>;
117    type InitError = ();
118    type Future = Ready<Result<Self::Transform, Self::InitError>>;
119
120    fn new_transform(&self, service: S) -> Self::Future {
121        ok(RequestTracingMiddleware::new(
122            global::tracer_with_scope(get_scope()),
123            service,
124            self.route_formatter.clone(),
125        ))
126    }
127}
128
129/// Request tracing middleware
130#[derive(Debug)]
131pub struct RequestTracingMiddleware<S> {
132    tracer: global::BoxedTracer,
133    service: S,
134    route_formatter: Option<Rc<dyn RouteFormatter>>,
135}
136
137impl<S, B> RequestTracingMiddleware<S>
138where
139    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
140    S::Future: 'static,
141    B: 'static,
142{
143    fn new(
144        tracer: global::BoxedTracer,
145        service: S,
146        route_formatter: Option<Rc<dyn RouteFormatter>>,
147    ) -> Self {
148        RequestTracingMiddleware {
149            tracer,
150            service,
151            route_formatter,
152        }
153    }
154}
155
156impl<S, B> Service<ServiceRequest> for RequestTracingMiddleware<S>
157where
158    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
159    S::Future: 'static,
160    B: 'static,
161{
162    type Response = ServiceResponse<B>;
163    type Error = Error;
164    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
165
166    fn poll_ready(&self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
167        self.service.poll_ready(cx)
168    }
169
170    fn call(&self, mut req: ServiceRequest) -> Self::Future {
171        let parent_context = global::get_text_map_propagator(|propagator| {
172            propagator.extract(&RequestHeaderCarrier::new(req.headers_mut()))
173        });
174        let mut http_route: Cow<'static, str> = req
175            .match_pattern()
176            .map(Into::into)
177            .unwrap_or_else(|| "default".into());
178        if let Some(formatter) = &self.route_formatter {
179            http_route = formatter.format(&http_route).into();
180        }
181
182        let mut builder = self.tracer.span_builder(http_route.clone());
183        builder.span_kind = Some(SpanKind::Server);
184        builder.attributes = Some(trace_attributes_from_request(&req, &http_route));
185
186        let span = self.tracer.build_with_context(builder, &parent_context);
187        let cx = parent_context.with_span(span);
188
189        #[cfg(feature = "sync-middleware")]
190        let attachment = cx.clone().attach();
191
192        let fut = self
193            .service
194            .call(req)
195            .with_context(cx.clone())
196            .map(move |res| match res {
197                Ok(ok_res) => {
198                    let span = cx.span();
199                    span.set_attribute(KeyValue::new(
200                        HTTP_RESPONSE_STATUS_CODE,
201                        ok_res.status().as_u16() as i64,
202                    ));
203                    if ok_res.status().is_server_error() {
204                        span.set_status(Status::error(
205                            ok_res
206                                .status()
207                                .canonical_reason()
208                                .map(ToString::to_string)
209                                .unwrap_or_default(),
210                        ));
211                    };
212                    span.end();
213                    Ok(ok_res)
214                }
215                Err(err) => {
216                    let span = cx.span();
217                    span.set_status(Status::error(format!("{:?}", err)));
218                    span.end();
219                    Err(err)
220                }
221            });
222
223        #[cfg(feature = "sync-middleware")]
224        drop(attachment);
225
226        Box::pin(fut)
227    }
228}
229
230struct RequestHeaderCarrier<'a> {
231    headers: &'a HeaderMap,
232}
233
234impl<'a> RequestHeaderCarrier<'a> {
235    fn new(headers: &'a HeaderMap) -> Self {
236        RequestHeaderCarrier { headers }
237    }
238}
239
240impl Extractor for RequestHeaderCarrier<'_> {
241    fn get(&self, key: &str) -> Option<&str> {
242        self.headers.get(key).and_then(|v| v.to_str().ok())
243    }
244
245    fn keys(&self) -> Vec<&str> {
246        self.headers.keys().map(|header| header.as_str()).collect()
247    }
248}