mas_tower/
trace_context.rs

// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use http::Request;
16use opentelemetry::propagation::Injector;
17use opentelemetry_http::HeaderInjector;
18use tower::{Layer, Service};
19use tracing::Span;
20use tracing_opentelemetry::OpenTelemetrySpanExt;
21
22/// A trait to get an [`Injector`] from a request.
23trait AsInjector {
24    type Injector<'a>: Injector
25    where
26        Self: 'a;
27
28    fn as_injector(&mut self) -> Self::Injector<'_>;
29}
30
31impl<B> AsInjector for Request<B> {
32    type Injector<'a> = HeaderInjector<'a> where Self: 'a;
33
34    fn as_injector(&mut self) -> Self::Injector<'_> {
35        HeaderInjector(self.headers_mut())
36    }
37}
38
/// A [`Layer`] that adds a trace context to the request.
///
/// Services wrapped by this layer become [`TraceContextService`]s.
#[derive(Debug, Clone, Copy, Default)]
pub struct TraceContextLayer {
    // Zero-sized private field: outside this module the struct cannot be
    // built with a literal, only through `new` or `Default`.
    _private: (),
}

impl TraceContextLayer {
    /// Build a [`TraceContextLayer`]; equivalent to `TraceContextLayer::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self { _private: () }
    }
}
52
53impl<S> Layer<S> for TraceContextLayer {
54    type Service = TraceContextService<S>;
55
56    fn layer(&self, inner: S) -> Self::Service {
57        TraceContextService::new(inner)
58    }
59}
60
/// A [`Service`] that adds a trace context to the request.
///
/// Each request has the OpenTelemetry context of the current `tracing` span
/// injected into it before being forwarded to the wrapped service.
#[derive(Debug, Clone)]
pub struct TraceContextService<S> {
    /// The wrapped service that receives every request after injection.
    inner: S,
}

impl<S> TraceContextService<S> {
    /// Create a new [`TraceContextService`] wrapping `inner`.
    ///
    /// This is a pure constructor, so discarding its result is almost
    /// certainly a mistake; hence `#[must_use]`, as on `TraceContextLayer::new`.
    #[must_use]
    pub fn new(inner: S) -> Self {
        Self { inner }
    }
}
73
74impl<S, R> Service<R> for TraceContextService<S>
75where
76    S: Service<R>,
77    R: AsInjector,
78{
79    type Response = S::Response;
80    type Error = S::Error;
81    type Future = S::Future;
82
83    fn poll_ready(
84        &mut self,
85        cx: &mut std::task::Context<'_>,
86    ) -> std::task::Poll<Result<(), Self::Error>> {
87        self.inner.poll_ready(cx)
88    }
89
90    fn call(&mut self, mut req: R) -> Self::Future {
91        // Get the `opentelemetry` context out of the `tracing` span.
92        let context = Span::current().context();
93
94        // Inject the trace context into the request. The block is there to ensure that
95        // the injector is dropped before calling the inner service, to avoid borrowing
96        // issues.
97        {
98            let mut injector = req.as_injector();
99            opentelemetry::global::get_text_map_propagator(|propagator| {
100                propagator.inject_context(&context, &mut injector);
101            });
102        }
103
104        self.inner.call(req)
105    }
106}