1use std::{borrow::Cow, pin::Pin};
7
8use actix_web::{
9 dev::{Service, ServiceRequest, ServiceResponse, Transform},
10 http::{
11 header::{self, HeaderMap, CONTENT_LENGTH},
12 Version,
13 },
14 Error,
15};
16use futures_util::{
17 future::{ok, Ready},
18 Future, FutureExt as _,
19};
20use highlightio::Highlight;
21use opentelemetry::{
22 global,
23 propagation::Extractor,
24 trace::{FutureExt as _, SpanKind, Status, TraceContextExt, Tracer as _, TracerProvider as _},
25 KeyValue,
26};
27use opentelemetry_semantic_conventions::trace::{
28 CLIENT_ADDRESS, CLIENT_SOCKET_ADDRESS, HTTP_REQUEST_BODY_SIZE, HTTP_REQUEST_METHOD,
29 HTTP_RESPONSE_STATUS_CODE, HTTP_ROUTE, NETWORK_PROTOCOL_VERSION, SERVER_ADDRESS, SERVER_PORT,
30 URL_PATH, URL_QUERY, URL_SCHEME, USER_AGENT_ORIGINAL,
31};
32
33pub mod highlight {
34 pub use highlightio::*;
35}
36
37struct RequestHeaderCarrier<'a> {
38 headers: &'a HeaderMap,
39}
40
41impl<'a> RequestHeaderCarrier<'a> {
42 fn new(headers: &'a HeaderMap) -> Self {
43 RequestHeaderCarrier { headers }
44 }
45}
46
47impl<'a> Extractor for RequestHeaderCarrier<'a> {
48 fn get(&self, key: &str) -> Option<&str> {
49 self.headers.get(key).and_then(|v| v.to_str().ok())
50 }
51
52 fn keys(&self) -> Vec<&str> {
53 self.headers.keys().map(|header| header.as_str()).collect()
54 }
55}
56
57#[derive(Clone)]
58pub struct HighlightActix {
59 highlight: Highlight,
60}
61
62impl HighlightActix {
63 pub fn new(h: &Highlight) -> Self {
64 HighlightActix {
65 highlight: h.clone(),
66 }
67 }
68}
69
70pub struct HighlightMiddleware<S> {
71 tracer: global::BoxedTracer,
72 service: S,
73 inner: HighlightActix,
74}
75
76impl<S, B> Transform<S, ServiceRequest> for HighlightActix
77where
78 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
79 S::Future: 'static,
80{
81 type Response = ServiceResponse<B>;
82 type Error = Error;
83 type Transform = HighlightMiddleware<S>;
84 type InitError = ();
85 type Future = Ready<Result<Self::Transform, Self::InitError>>;
86
87 fn new_transform(&self, service: S) -> Self::Future {
88 ok(HighlightMiddleware {
89 tracer: global::tracer_provider().versioned_tracer(
90 "highlight-actix",
91 Some(env!("CARGO_PKG_VERSION")),
92 Some(opentelemetry_semantic_conventions::SCHEMA_URL),
93 None,
94 ),
95 service,
96 inner: self.clone(),
97 })
98 }
99}
100
101fn req_to_attrs(req: &ServiceRequest, http_route: &str, h: &Highlight) -> Vec<KeyValue> {
102 let mut attributes: Vec<KeyValue> = Vec::with_capacity(16);
103
104 let conn_info = req.connection_info();
105 let remote_addr = conn_info.realip_remote_addr();
106
107 attributes.push(KeyValue::new(HTTP_ROUTE, http_route.to_owned()));
108
109 if let Some(remote_addr) = remote_addr {
110 attributes.push(KeyValue::new(CLIENT_ADDRESS, remote_addr.to_string()));
111 }
112
113 if let Some(peer_addr) = req.peer_addr().map(|socket| socket.ip().to_string()) {
114 if Some(peer_addr.as_str()) != remote_addr {
115 attributes.push(KeyValue::new(CLIENT_SOCKET_ADDRESS, peer_addr));
117 }
118 }
119 let mut host_parts = conn_info.host().split_terminator(':');
120 if let Some(host) = host_parts.next() {
121 attributes.push(KeyValue::new(SERVER_ADDRESS, host.to_string()));
122 }
123 if let Some(port) = host_parts.next().and_then(|port| port.parse::<i64>().ok()) {
124 if port != 80 && port != 443 {
125 attributes.push(KeyValue::new(SERVER_PORT, port));
126 }
127 }
128 if let Some(path_query) = req.uri().path_and_query() {
129 if path_query.path() != "/" {
130 attributes.push(KeyValue::new(URL_PATH, path_query.path().to_string()));
131 }
132 if let Some(query) = path_query.query() {
133 attributes.push(KeyValue::new(URL_QUERY, query.to_string()));
134 }
135 }
136 attributes.push(KeyValue::new(URL_SCHEME, conn_info.scheme().to_owned()));
137
138 attributes.push(KeyValue::new(
139 HTTP_REQUEST_METHOD,
140 req.method().as_str().to_owned(),
141 ));
142 attributes.push(KeyValue::new::<_, String>(
143 NETWORK_PROTOCOL_VERSION,
144 match req.version() {
145 Version::HTTP_09 => "0.9".into(),
146 Version::HTTP_10 => "1.0".into(),
147 Version::HTTP_11 => "1.1".into(),
148 Version::HTTP_2 => "2".into(),
149 Version::HTTP_3 => "3".into(),
150 other => format!("{:?}", other).into(),
151 },
152 ));
153
154 if let Some(size) = req
155 .headers()
156 .get(CONTENT_LENGTH)
157 .and_then(|len| len.to_str().ok().and_then(|s| s.parse::<i64>().ok()))
158 .filter(|&len| len > 0)
159 {
160 attributes.push(KeyValue::new(HTTP_REQUEST_BODY_SIZE, size));
161 }
162
163 if let Some(ua) = req
164 .headers()
165 .get(header::USER_AGENT)
166 .and_then(|s| s.to_str().ok())
167 {
168 attributes.push(KeyValue::new(USER_AGENT_ORIGINAL, ua.to_string()));
169 }
170
171 attributes.push(KeyValue::new("highlight.project_id", h.project_id()));
172
173 if let Some(hr) = req.headers().get("x-highlight-request") {
174 if let Some((session_id, trace_id)) = hr.to_str().ok().and_then(|x| x.split_once("/")) {
175 attributes.push(KeyValue::new("highlight.session_id", session_id.to_owned()));
176 attributes.push(KeyValue::new("highlight.trace_id", trace_id.to_owned()));
177 }
178 }
179
180 attributes
181}
182
183impl<S, B> Service<ServiceRequest> for HighlightMiddleware<S>
184where
185 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
186 S::Future: 'static,
187{
188 type Response = ServiceResponse<B>;
189 type Error = Error;
190 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
191
192 fn poll_ready(
193 &self,
194 ctx: &mut core::task::Context<'_>,
195 ) -> std::task::Poll<Result<(), Self::Error>> {
196 let res = self.service.poll_ready(ctx);
197 res
198 }
199
200 fn call(&self, mut req: ServiceRequest) -> Self::Future {
201 let parent_context = global::get_text_map_propagator(|propagator| {
202 propagator.extract(&RequestHeaderCarrier::new(req.headers_mut()))
203 });
204
205 let http_route: Cow<'static, str> = req
206 .match_pattern()
207 .map(Into::into)
208 .unwrap_or_else(|| "default".into());
209
210 let mut builder = self.tracer.span_builder(http_route.clone());
211 builder.span_kind = Some(SpanKind::Server);
212 builder.attributes = Some(req_to_attrs(&req, &http_route, &self.inner.highlight));
213
214 let span = self.tracer.build_with_context(builder, &parent_context);
215 let cx = parent_context.with_span(span);
216
217 let fut = self
218 .service
219 .call(req)
220 .with_context(cx.clone())
221 .map(move |res| match res {
222 Ok(ok_res) => {
223 let span = cx.span();
224 span.set_attribute(
225 HTTP_RESPONSE_STATUS_CODE.i64(ok_res.status().as_u16() as i64),
226 );
227 if ok_res.status().is_server_error() {
228 if let Some(e) = ok_res.response().error() {
229 span.record_error(&e);
230 }
231
232
233 span.set_status(Status::error(
234 ok_res
235 .status()
236 .canonical_reason()
237 .map(ToString::to_string)
238 .unwrap_or_default(),
239 ));
240 };
241 span.end();
242 Ok(ok_res)
243 }
244 Err(err) => {
245 let span = cx.span();
246 span.record_error(&err);
247 span.set_status(Status::error(format!("{:?}", err)));
248 span.end();
249 Err(err)
250 }
251 });
252
253 Box::pin(fut)
254 }
255}