dd_trace_layer/
lib.rs

use std::fmt::Display;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::SystemTime;

use futures::{Future, FutureExt};
use http::{header, Request, Response};
use http_body::Body;
use opentelemetry::trace::{FutureExt as OtelFutureExt, SpanKind, Status, TraceContextExt, Tracer};
use opentelemetry::{global, Context as OtelContext, Key};
use opentelemetry_datadog::new_pipeline;
use opentelemetry_semantic_conventions::trace::{
    HTTP_CLIENT_IP, HTTP_FLAVOR, HTTP_HOST, HTTP_METHOD, HTTP_SCHEME, HTTP_STATUS_CODE, HTTP_URL,
    HTTP_USER_AGENT,
};
use tower::{Layer, Service};
17
18pub use opentelemetry_datadog::ApiVersion;
19
20/// Initialize the Datadog exporter
21pub fn init(service_name: &str, endpoint: &str, version: ApiVersion) {
22    let _tracer = new_pipeline()
23        .with_service_name(service_name)
24        .with_version(version)
25        .with_agent_endpoint(endpoint)
26        .install_batch(opentelemetry::runtime::Tokio)
27        .expect("failed to initialize tracing pipeline");
28}
29
/// Tower layer that wraps each service it is applied to in a [`DDTrace`].
///
/// `operation` is forwarded to every wrapped service, where it names the
/// tracer requested from the global provider on each call.
#[derive(Clone, Debug)]
pub struct DDTraceLayer {
    operation: String,
}

impl DDTraceLayer {
    /// Creates a layer whose wrapped services trace under `operation`.
    pub fn new(operation: String) -> DDTraceLayer {
        Self { operation }
    }
}
40
41impl<S> Layer<S> for DDTraceLayer {
42    type Service = DDTrace<S>;
43
44    fn layer(&self, inner: S) -> Self::Service {
45        DDTrace::new(inner, &self.operation[..])
46    }
47}
48
/// Tower service middleware that records one OpenTelemetry span per HTTP
/// request handled by the wrapped service `inner`.
///
/// `operation` is the name passed to `opentelemetry::global::tracer` when a
/// request is traced.
#[derive(Clone, Debug)]
pub struct DDTrace<S> {
    inner: S,
    operation: String,
}

impl<S> DDTrace<S> {
    /// Wraps `inner`, tracing its requests with the tracer named `operation`.
    pub fn new(inner: S, operation: &str) -> Self {
        let operation = String::from(operation);
        Self { inner, operation }
    }
}
63
64impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for DDTrace<S>
65where
66    S: Service<Request<ReqBody>, Response = Response<ResBody>> + 'static,
67    S::Error: Display + 'static,
68    S::Future: Send,
69    ReqBody: 'static,
70    ResBody: Body + 'static,
71{
72    type Response = S::Response;
73    type Error = S::Error;
74    #[allow(clippy::type_complexity)]
75    type Future =
76        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
77
78    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
79        self.inner.poll_ready(cx)
80    }
81
82    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
83        let method = req.method().to_string();
84        let path = req.uri().path().to_owned();
85        let url = req.uri().to_owned().to_string();
86        let version = format!("{:?}", req.version());
87        let user_agent = req
88            .headers()
89            .get(header::USER_AGENT)
90            .map_or("", |v| v.to_str().unwrap_or(""))
91            .to_string();
92        let host = req
93            .headers()
94            .get(header::HOST)
95            .map_or("", |v| v.to_str().unwrap_or(""))
96            .to_string();
97        let scheme = req
98            .uri()
99            .scheme()
100            .map_or_else(|| "http".to_string(), |v| v.to_string());
101        let client_ip = parse_x_forwarded_for(req.headers())
102            .unwrap_or("")
103            .to_string();
104
105        let operation = self.operation.clone();
106        let start_time = SystemTime::now();
107
108        let tracer = global::tracer(operation);
109        let span = tracer
110            .span_builder(path)
111            .with_attributes(vec![
112                HTTP_URL.string(url),
113                HTTP_METHOD.string(method),
114                HTTP_FLAVOR.string(version),
115                HTTP_USER_AGENT.string(user_agent),
116                HTTP_HOST.string(host),
117                HTTP_SCHEME.string(scheme),
118                HTTP_CLIENT_IP.string(client_ip),
119            ])
120            .with_start_time(start_time)
121            .start(&tracer);
122
123        let cx = OtelContext::current_with_span(span);
124        let fut = self
125            .inner
126            .call(req)
127            .with_context(cx.clone())
128            .map(move |res| match res {
129                Ok(ok_res) => {
130                    let span = cx.span();
131                    span.set_attribute(HTTP_STATUS_CODE.i64(ok_res.status().as_u16().into()));
132                    if ok_res.status().is_server_error() {
133                        span.set_status(Status::error(
134                            ok_res
135                                .status()
136                                .canonical_reason()
137                                .map(|s| s.to_string())
138                                .unwrap_or_default(),
139                        ));
140                    }
141                    span.end();
142                    Ok(ok_res)
143                }
144                Err(err_res) => {
145                    let span = cx.span();
146                    span.set_attribute(HTTP_STATUS_CODE.i64(500));
147                    span.set_attribute(Key::new("error.msg").string(err_res.to_string()));
148                    span.set_status(Status::error(err_res.to_string()));
149                    span.end();
150                    Err(err_res)
151                }
152            });
153        Box::pin(fut)
154    }
155}
156
157fn parse_x_forwarded_for(headers: &header::HeaderMap) -> Option<&str> {
158    let v = headers.get("X-Forwarded-For")?;
159    let v = v.to_str().ok()?;
160    let mut ips = v.split(',');
161    Some(ips.next()?.trim())
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a header map holding a single `X-Forwarded-For` value.
    fn forwarded_for(value: &'static str) -> header::HeaderMap {
        let mut map = header::HeaderMap::new();
        map.insert("X-Forwarded-For", header::HeaderValue::from_static(value));
        map
    }

    #[test]
    fn test_parse_x_forwarded_for() {
        let map = forwarded_for("203.0.113.195, 203.0.113.194, 203.0.113.193");
        assert_eq!(parse_x_forwarded_for(&map), Some("203.0.113.195"));
    }

    #[test]
    fn test_parse_x_forwarded_for_single_entry() {
        let map = forwarded_for("203.0.113.195");
        assert_eq!(parse_x_forwarded_for(&map), Some("203.0.113.195"));
    }

    #[test]
    fn test_parse_x_forwarded_for_trims_whitespace() {
        let map = forwarded_for("  203.0.113.195  ,10.0.0.1");
        assert_eq!(parse_x_forwarded_for(&map), Some("203.0.113.195"));
    }

    #[test]
    fn test_parse_x_forwarded_for_missing_header() {
        let map = header::HeaderMap::new();
        assert_eq!(parse_x_forwarded_for(&map), None);
    }
}