1#![warn(missing_docs)]
29#![allow(clippy::style)]
30
31mod grpc;
32mod headers;
33
34use std::net::IpAddr;
35use core::{cmp, fmt, ptr, task};
36use core::pin::Pin;
37use core::future::Future;
38
39pub use tracing;
40
41pub const REQUEST_ID: http::HeaderName = http::HeaderName::from_static("x-request-id");
43pub type MakeSpan = fn() -> tracing::Span;
45pub type ExtractClientIp = fn(&http::request::Parts) -> Option<IpAddr>;
47
48#[inline]
49fn default_client_ip(_: &http::request::Parts) -> Option<IpAddr> {
50 None
51}
52
/// Wire protocol of a request, inferred from its `content-type` header.
#[derive(Clone, Copy, Eq, PartialEq)]
pub enum Protocol {
    /// Plain HTTP request.
    Http,
    /// gRPC request (`content-type` starting with `application/grpc`).
    Grpc,
}
63
64impl Protocol {
65 #[inline(always)]
66 pub fn from_content_type(typ: &[u8]) -> Self {
68 if typ.starts_with(b"application/grpc") {
69 Self::Grpc
70 } else {
71 Self::Http
72 }
73 }
74
75 #[inline(always)]
76 pub const fn as_str(&self) -> &'static str {
78 match self {
79 Self::Grpc => "grpc",
80 Self::Http => "http"
81 }
82 }
83}
84
85impl fmt::Debug for Protocol {
86 #[inline(always)]
87 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
88 fmt::Debug::fmt(self.as_str(), fmt)
89 }
90}
91
92impl fmt::Display for Protocol {
93 #[inline(always)]
94 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
95 fmt::Display::fmt(self.as_str(), fmt)
96 }
97}
98
/// Inline storage for a request id; longer ids are truncated to this size.
type RequestIdBuffer = [u8; 64];

/// Request identifier stored inline, without heap allocation.
///
/// Taken from the incoming `x-request-id` header when present, otherwise a
/// generated UUID v4 in lowercase hyphenated form.
///
/// Invariant: `len <= buffer.len()`, and the bytes past `len` are zero. The
/// `unsafe` slice construction in `as_bytes` relies on the first part.
#[derive(Clone)]
pub struct RequestId {
    // Raw id bytes; only the first `len` are meaningful.
    buffer: RequestIdBuffer,
    // Number of valid bytes in `buffer` (at most 64, so it fits in a `u8`).
    len: u8,
}
109
110impl RequestId {
111 fn from_bytes(bytes: &[u8]) -> Self {
112 let mut buffer: RequestIdBuffer = [0; 64];
113
114 let len = cmp::min(buffer.len(), bytes.len());
115
116 unsafe {
117 ptr::copy_nonoverlapping(bytes.as_ptr(), buffer.as_mut_ptr(), len)
118 };
119
120 Self {
121 buffer,
122 len: len as _,
123 }
124 }
125
126 fn from_uuid(uuid: uuid::Uuid) -> Self {
127 let mut buffer: RequestIdBuffer = [0; 64];
128 let uuid = uuid.as_hyphenated();
129 let len = uuid.encode_lower(&mut buffer).len();
130
131 Self {
132 buffer,
133 len: len as _,
134 }
135 }
136
137 #[inline]
138 pub const fn as_bytes(&self) -> &[u8] {
140 unsafe {
141 core::slice::from_raw_parts(self.buffer.as_ptr(), self.len as _)
142 }
143 }
144
145 #[inline(always)]
146 pub const fn as_str(&self) -> Option<&str> {
148 match core::str::from_utf8(self.as_bytes()) {
149 Ok(header) => Some(header),
150 Err(_) => None,
151 }
152 }
153}
154
155impl fmt::Debug for RequestId {
156 #[inline(always)]
157 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
158 match self.as_str() {
159 Some(id) => fmt::Debug::fmt(id, fmt),
160 None => fmt::Debug::fmt(self.as_bytes(), fmt),
161 }
162 }
163}
164
165impl fmt::Display for RequestId {
166 #[inline(always)]
167 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
168 match self.as_str() {
169 Some(id) => fmt::Display::fmt(id, fmt),
170 None => fmt::Display::fmt("<non-utf8>", fmt),
171 }
172 }
173}
174
/// Defines `pub fn $fn() -> tracing::Span`, which creates the per-request span.
///
/// The span pre-declares, as empty, every field this crate records later.
/// `tracing` silently ignores `Span::record` for fields that were not declared
/// when the span was created. Extra field declarations may be appended after
/// the level.
///
/// ```ignore
/// make_request_spanner!(make_span("request", tracing::Level::INFO));
/// ```
#[macro_export]
macro_rules! make_request_spanner {
    ($fn:ident($name:literal, $level:expr)) => {
        $crate::make_request_spanner!($fn($name, $level,));
    };
    ($fn:ident($name:literal, $level:expr, $($fields:tt)*)) => {
        #[track_caller]
        pub fn $fn() -> $crate::tracing::Span {
            use $crate::tracing::field;

            $crate::tracing::span!(
                $level,
                $name,
                http.request.method = field::Empty,
                url.path = field::Empty,
                url.query = field::Empty,
                url.scheme = field::Empty,
                http.request_id = field::Empty,
                user_agent.original = field::Empty,
                http.headers = field::Empty,
                network.protocol.name = field::Empty,
                network.protocol.version = field::Empty,
                client.address = field::Empty,
                http.response.status_code = field::Empty,
                // Recorded when the inner service fails. It must be declared here,
                // otherwise `record("error.type", ..)` is a no-op. The name is quoted
                // because `type` is a reserved word.
                "error.type" = field::Empty,
                error.message = field::Empty,
                $(
                    $fields
                )*
            )
        }
    };
}
244
245#[derive(Clone, Debug)]
246pub struct RequestInfo {
250 pub protocol: Protocol,
252 pub request_id: RequestId,
254 pub client_ip: Option<IpAddr>,
256}
257
258pub struct RequestSpan {
262 pub span: tracing::Span,
264 pub info: RequestInfo,
266}
267
268impl RequestSpan {
269 pub fn new(span: tracing::Span, extract_client_ip: ExtractClientIp, parts: &http::request::Parts) -> Self {
271 let _entered = span.enter();
272
273 let client_ip = (extract_client_ip)(parts);
274 let protocol = parts.headers
275 .get(http::header::CONTENT_TYPE)
276 .map_or(Protocol::Http, |content_type| Protocol::from_content_type(content_type.as_bytes()));
277
278 let request_id = if let Some(request_id) = parts.headers.get(REQUEST_ID) {
279 RequestId::from_bytes(request_id.as_bytes())
280 } else {
281 RequestId::from_uuid(uuid::Uuid::new_v4())
282 };
283
284 if let Some(user_agent) = parts.headers.get(http::header::USER_AGENT).and_then(|header| header.to_str().ok()) {
285 span.record("user_agent.original", user_agent);
286 }
287 span.record("http.request.method", parts.method.as_str());
288 span.record("url.path", parts.uri.path());
289 if let Some(query) = parts.uri.query() {
290 span.record("url.query", query);
291 }
292 if let Some(scheme) = parts.uri.scheme() {
293 span.record("url.scheme", scheme.as_str());
294 }
295 if let Some(request_id) = request_id.as_str() {
296 span.record("http.request_id", &request_id);
297 } else {
298 span.record("http.request_id", request_id.as_bytes());
299 }
300 if let Some(client_ip) = client_ip {
301 span.record("client.address", tracing::field::display(client_ip));
302 }
303 span.record("network.protocol.name", protocol.as_str());
304 if let Protocol::Http = protocol {
305 match parts.version {
306 http::Version::HTTP_09 => span.record("network.protocol.version", 0.9),
307 http::Version::HTTP_10 => span.record("network.protocol.version", 1.0),
308 http::Version::HTTP_11 => span.record("network.protocol.version", 1.1),
309 http::Version::HTTP_2 => span.record("network.protocol.version", 2),
310 http::Version::HTTP_3 => span.record("network.protocol.version", 3),
311 _ => span.record("network.protocol.version", 0),
313 };
314 }
315
316 drop(_entered);
317
318 Self {
319 span,
320 info: RequestInfo {
321 protocol,
322 request_id,
323 client_ip
324 }
325 }
326 }
327}
328
329#[derive(Clone)]
330pub struct HttpRequestLayer {
332 make_span: MakeSpan,
333 inspect_headers: &'static [&'static http::HeaderName],
334 extract_client_ip: ExtractClientIp,
335}
336
337impl HttpRequestLayer {
338 #[inline]
339 pub fn new(make_span: MakeSpan) -> Self {
341 Self {
342 make_span,
343 inspect_headers: &[],
344 extract_client_ip: default_client_ip
345 }
346 }
347
348 #[inline]
349 pub fn with_inspect_headers(mut self, inspect_headers: &'static [&'static http::HeaderName]) -> Self {
353 self.inspect_headers = inspect_headers;
354 self
355 }
356
357 pub fn with_extract_client_ip(mut self, extract_client_ip: ExtractClientIp) -> Self {
361 self.extract_client_ip = extract_client_ip;
362 self
363 }
364}
365
366impl<S> tower_layer::Layer<S> for HttpRequestLayer {
367 type Service = HttpRequestService<S>;
368 #[inline(always)]
369 fn layer(&self, inner: S) -> Self::Service {
370 HttpRequestService {
371 layer: self.clone(),
372 inner,
373 }
374 }
375}
376
377pub struct HttpRequestService<S> {
379 layer: HttpRequestLayer,
380 inner: S
381}
382
383impl<ReqBody, ResBody, S: tower_service::Service<http::Request<ReqBody>, Response = http::Response<ResBody>>> tower_service::Service<http::Request<ReqBody>> for HttpRequestService<S> where S::Error: std::error::Error {
384 type Response = S::Response;
385 type Error = S::Error;
386 type Future = ResponseFut<S::Future>;
387
388 #[inline(always)]
389 fn poll_ready(&mut self, ctx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
390 self.inner.poll_ready(ctx)
391 }
392
393 fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
394 let (parts, body) = req.into_parts();
395 let RequestSpan { span, info } = RequestSpan::new((self.layer.make_span)(), self.layer.extract_client_ip, &parts);
396
397 let _entered = span.enter();
398 if !self.layer.inspect_headers.is_empty() {
399 span.record("http.headers", tracing::field::debug(headers::InspectHeaders {
400 header_list: self.layer.inspect_headers,
401 headers: &parts.headers
402 }));
403 }
404 let request_id = info.request_id.clone();
405 let protocol = info.protocol;
406 let mut req = http::Request::from_parts(parts, body);
407 req.extensions_mut().insert(info);
408 let inner = self.inner.call(req);
409
410 drop(_entered);
411 ResponseFut {
412 inner,
413 span,
414 protocol,
415 request_id
416 }
417 }
418}
419
420pub struct ResponseFut<F> {
422 inner: F,
423 span: tracing::Span,
424 protocol: Protocol,
425 request_id: RequestId,
426}
427
428impl<ResBody, E: std::error::Error, F: Future<Output = Result<http::Response<ResBody>, E>>> Future for ResponseFut<F> {
429 type Output = F::Output;
430
431 fn poll(self: Pin<&mut Self>, ctx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
432 let (fut, span, protocol, request_id) = unsafe {
433 let this = self.get_unchecked_mut();
434 (
435 Pin::new_unchecked(&mut this.inner),
436 &this.span,
437 this.protocol,
438 &this.request_id,
439 )
440 };
441 let _entered = span.enter();
442 match Future::poll(fut, ctx) {
443 task::Poll::Ready(Ok(mut resp)) => {
444 if let Ok(request_id) = http::HeaderValue::from_bytes(request_id.as_bytes()) {
445 resp.headers_mut().insert(REQUEST_ID, request_id);
446 }
447 let status = match protocol {
448 Protocol::Http => resp.status().as_u16(),
449 Protocol::Grpc => match resp.headers().get("grpc-status") {
450 Some(status) => grpc::parse_grpc_status(status.as_bytes()),
451 None => 2,
452 }
453 };
454 span.record("http.response.status_code", status);
455
456 task::Poll::Ready(Ok(resp))
457 }
458 task::Poll::Ready(Err(error)) => {
459 let status = match protocol {
460 Protocol::Http => 500u16,
461 Protocol::Grpc => 13,
462 };
463 span.record("http.response.status_code", status);
464 span.record("error.type", core::any::type_name::<E>());
465 span.record("error.message", tracing::field::display(&error));
466 task::Poll::Ready(Err(error))
467 },
468 task::Poll::Pending => task::Poll::Pending
469 }
470 }
471}