1use std::fmt::Display;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4use std::time::SystemTime;
5
6use futures::{Future, FutureExt};
7use http::{header, Request, Response};
8use http_body::Body;
9use opentelemetry::trace::{FutureExt as OtelFutureExt, Status, TraceContextExt, Tracer};
10use opentelemetry::{global, Context as OtelContext, Key};
11use opentelemetry_datadog::new_pipeline;
12use opentelemetry_semantic_conventions::trace::{
13 HTTP_CLIENT_IP, HTTP_FLAVOR, HTTP_HOST, HTTP_METHOD, HTTP_SCHEME, HTTP_STATUS_CODE, HTTP_URL,
14 HTTP_USER_AGENT,
15};
16use tower::{Layer, Service};
17
18pub use opentelemetry_datadog::ApiVersion;
19
20pub fn init(service_name: &str, endpoint: &str, version: ApiVersion) {
22 let _tracer = new_pipeline()
23 .with_service_name(service_name)
24 .with_version(version)
25 .with_agent_endpoint(endpoint)
26 .install_batch(opentelemetry::runtime::Tokio)
27 .expect("failed to initialize tracing pipeline");
28}
29
/// Tower [`Layer`] that wraps a service in the [`DDTrace`] tracing middleware.
#[derive(Debug, Clone)]
pub struct DDTraceLayer {
    /// Operation name copied into every [`DDTrace`] this layer produces.
    operation: String,
}
34
35impl DDTraceLayer {
36 pub fn new(operation: String) -> DDTraceLayer {
37 DDTraceLayer { operation }
38 }
39}
40
41impl<S> Layer<S> for DDTraceLayer {
42 type Service = DDTrace<S>;
43
44 fn layer(&self, inner: S) -> Self::Service {
45 DDTrace::new(inner, &self.operation[..])
46 }
47}
48
/// Middleware that records a trace span around each call to the wrapped
/// service `S`.
#[derive(Debug, Clone)]
pub struct DDTrace<S> {
    /// The wrapped service that actually handles the request.
    inner: S,
    /// Name passed to `global::tracer` when a request is traced.
    operation: String,
}
54
55impl<S> DDTrace<S> {
56 pub fn new(inner: S, operation: &str) -> Self {
57 DDTrace {
58 inner,
59 operation: operation.to_string(),
60 }
61 }
62}
63
64impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for DDTrace<S>
65where
66 S: Service<Request<ReqBody>, Response = Response<ResBody>> + 'static,
67 S::Error: Display + 'static,
68 S::Future: Send,
69 ReqBody: 'static,
70 ResBody: Body + 'static,
71{
72 type Response = S::Response;
73 type Error = S::Error;
74 #[allow(clippy::type_complexity)]
75 type Future =
76 Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
77
78 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
79 self.inner.poll_ready(cx)
80 }
81
82 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
83 let method = req.method().to_string();
84 let path = req.uri().path().to_owned();
85 let url = req.uri().to_owned().to_string();
86 let version = format!("{:?}", req.version());
87 let user_agent = req
88 .headers()
89 .get(header::USER_AGENT)
90 .map_or("", |v| v.to_str().unwrap_or(""))
91 .to_string();
92 let host = req
93 .headers()
94 .get(header::HOST)
95 .map_or("", |v| v.to_str().unwrap_or(""))
96 .to_string();
97 let scheme = req
98 .uri()
99 .scheme()
100 .map_or_else(|| "http".to_string(), |v| v.to_string());
101 let client_ip = parse_x_forwarded_for(req.headers())
102 .unwrap_or("")
103 .to_string();
104
105 let operation = self.operation.clone();
106 let start_time = SystemTime::now();
107
108 let tracer = global::tracer(operation);
109 let span = tracer
110 .span_builder(path)
111 .with_attributes(vec![
112 HTTP_URL.string(url),
113 HTTP_METHOD.string(method),
114 HTTP_FLAVOR.string(version),
115 HTTP_USER_AGENT.string(user_agent),
116 HTTP_HOST.string(host),
117 HTTP_SCHEME.string(scheme),
118 HTTP_CLIENT_IP.string(client_ip),
119 ])
120 .with_start_time(start_time)
121 .start(&tracer);
122
123 let cx = OtelContext::current_with_span(span);
124 let fut = self
125 .inner
126 .call(req)
127 .with_context(cx.clone())
128 .map(move |res| match res {
129 Ok(ok_res) => {
130 let span = cx.span();
131 span.set_attribute(HTTP_STATUS_CODE.i64(ok_res.status().as_u16().into()));
132 if ok_res.status().is_server_error() {
133 span.set_status(Status::error(
134 ok_res
135 .status()
136 .canonical_reason()
137 .map(|s| s.to_string())
138 .unwrap_or_default(),
139 ));
140 }
141 span.end();
142 Ok(ok_res)
143 }
144 Err(err_res) => {
145 let span = cx.span();
146 span.set_attribute(HTTP_STATUS_CODE.i64(500));
147 span.set_attribute(Key::new("error.msg").string(err_res.to_string()));
148 span.set_status(Status::error(err_res.to_string()));
149 span.end();
150 Err(err_res)
151 }
152 });
153 Box::pin(fut)
154 }
155}
156
157fn parse_x_forwarded_for(headers: &header::HeaderMap) -> Option<&str> {
158 let v = headers.get("X-Forwarded-For")?;
159 let v = v.to_str().ok()?;
160 let mut ips = v.split(',');
161 Some(ips.next()?.trim())
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    /// The left-most entry of a multi-hop `X-Forwarded-For` list is returned,
    /// without the whitespace that follows the separating commas.
    #[test]
    fn test_parse_x_forwarded_for() {
        let forwarded = "203.0.113.195, 203.0.113.194, 203.0.113.193";
        let mut headers = header::HeaderMap::new();
        headers.insert("X-Forwarded-For", forwarded.parse().unwrap());

        let client_ip = parse_x_forwarded_for(&headers);
        assert_eq!(client_ip, Some("203.0.113.195"));
    }
}