1#![warn(missing_docs)]
33#![allow(clippy::style)]
34
35mod grpc;
36mod headers;
37#[cfg(feature = "opentelemetry")]
38pub mod opentelemetry;
39
40use std::net::IpAddr;
41use core::{cmp, fmt, ptr, task};
42use core::pin::Pin;
43use core::future::Future;
44
45pub use tracing;
46
47pub const REQUEST_ID: http::HeaderName = http::HeaderName::from_static("x-request-id");
49pub type MakeSpan = fn() -> tracing::Span;
51pub type ExtractClientIp = fn(&http::request::Parts) -> Option<IpAddr>;
53
54#[inline]
55fn default_client_ip(_: &http::request::Parts) -> Option<IpAddr> {
56 None
57}
58
/// Application protocol of a request, as detected from its `content-type` header.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum Protocol {
    /// Plain HTTP (anything that is not gRPC).
    Http,
    /// gRPC: the `content-type` starts with `application/grpc`.
    Grpc,
}

impl Protocol {
    /// Classifies a raw `content-type` header value.
    ///
    /// Returns [`Protocol::Grpc`] when the value starts with `application/grpc`. The comparison
    /// ignores ASCII case, because media type and subtype tokens are case-insensitive. Any other
    /// value, including an empty one, yields [`Protocol::Http`].
    #[inline(always)]
    pub fn from_content_type(typ: &[u8]) -> Self {
        const GRPC_PREFIX: &[u8] = b"application/grpc";

        match typ.get(..GRPC_PREFIX.len()) {
            Some(prefix) if prefix.eq_ignore_ascii_case(GRPC_PREFIX) => Self::Grpc,
            _ => Self::Http,
        }
    }

    /// Returns the lowercase protocol name: `"http"` or `"grpc"`.
    #[inline(always)]
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Grpc => "grpc",
            Self::Http => "http"
        }
    }
}

impl fmt::Debug for Protocol {
    #[inline(always)]
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(self.as_str(), fmt)
    }
}

impl fmt::Display for Protocol {
    #[inline(always)]
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(self.as_str(), fmt)
    }
}
104
105type RequestIdBuffer = [u8; 64];
106
107#[derive(Clone)]
108pub struct RequestId {
112 buffer: RequestIdBuffer,
113 len: u8,
114}
115
116impl RequestId {
117 fn from_bytes(bytes: &[u8]) -> Self {
118 let mut buffer: RequestIdBuffer = [0; 64];
119
120 let len = cmp::min(buffer.len(), bytes.len());
121
122 unsafe {
123 ptr::copy_nonoverlapping(bytes.as_ptr(), buffer.as_mut_ptr(), len)
124 };
125
126 Self {
127 buffer,
128 len: len as _,
129 }
130 }
131
132 fn from_uuid(uuid: uuid::Uuid) -> Self {
133 let mut buffer: RequestIdBuffer = [0; 64];
134 let uuid = uuid.as_hyphenated();
135 let len = uuid.encode_lower(&mut buffer).len();
136
137 Self {
138 buffer,
139 len: len as _,
140 }
141 }
142
143 #[inline]
144 pub const fn as_bytes(&self) -> &[u8] {
146 unsafe {
147 core::slice::from_raw_parts(self.buffer.as_ptr(), self.len as _)
148 }
149 }
150
151 #[inline(always)]
152 pub const fn as_str(&self) -> Option<&str> {
154 match core::str::from_utf8(self.as_bytes()) {
155 Ok(header) => Some(header),
156 Err(_) => None,
157 }
158 }
159}
160
161impl fmt::Debug for RequestId {
162 #[inline(always)]
163 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
164 match self.as_str() {
165 Some(id) => fmt::Debug::fmt(id, fmt),
166 None => fmt::Debug::fmt(self.as_bytes(), fmt),
167 }
168 }
169}
170
171impl fmt::Display for RequestId {
172 #[inline(always)]
173 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
174 match self.as_str() {
175 Some(id) => fmt::Display::fmt(id, fmt),
176 None => fmt::Display::fmt("<non-utf8>", fmt),
177 }
178 }
179}
180
/// Defines `pub fn $fn() -> tracing::Span`, which creates the request span used by this crate.
///
/// Invocation forms:
/// - `make_request_spanner!(my_span("span name", level))`
/// - `make_request_spanner!(my_span("span name", level, <extra span! field tokens>))`: the
///   extra tokens are appended verbatim after the built-in fields of the `tracing::span!` call.
///
/// The generated function has the [`MakeSpan`] signature, so it can be passed to
/// [`HttpRequestLayer::new`]. The span gets `span.kind = "server"`. Every request and response
/// field is declared up front as `field::Empty`, so that [`RequestSpan::new`],
/// [`HttpRequestService`] and [`ResponseFut`] can fill them in later via `Span::record`.
#[macro_export]
macro_rules! make_request_spanner {
    // No extra fields: forward to the general arm with an empty field list.
    ($fn:ident($name:literal, $level:expr)) => {
        $crate::make_request_spanner!($fn($name, $level,));
    };
    ($fn:ident($name:literal, $level:expr, $($fields:tt)*)) => {
        // NOTE(review): `tracing::span!` captures the macro call site's file/line itself, so
        // `#[track_caller]` likely has no effect on the span metadata — confirm before relying on it.
        #[track_caller]
        pub fn $fn() -> $crate::tracing::Span {
            use $crate::tracing::field;

            $crate::tracing::span!(
                $level,
                $name,
                span.kind = "server",
                // Request fields, recorded when the span is set up for a request.
                http.request.method = field::Empty,
                url.path = field::Empty,
                url.query = field::Empty,
                url.scheme = field::Empty,
                http.request_id = field::Empty,
                user_agent.original = field::Empty,
                // Only recorded when the layer is configured with headers to inspect.
                http.headers = field::Empty,
                network.protocol.name = field::Empty,
                network.protocol.version = field::Empty,
                client.address = field::Empty,
                // Response fields, recorded when the inner service's future completes.
                http.response.status_code = field::Empty,
                error.type = field::Empty,
                error.message = field::Empty,
                $(
                    $fields
                )*
            )
        }
    };
}
253
/// Metadata about the current request.
///
/// [`HttpRequestService`] inserts it into the request extensions before calling the inner
/// service, so handlers can read it from there.
#[derive(Clone, Debug)]
pub struct RequestInfo {
    /// Protocol detected from the request's `content-type` header.
    pub protocol: Protocol,
    /// Request id: taken from the [`REQUEST_ID`] header, or generated when none was supplied.
    pub request_id: RequestId,
    /// Client IP as returned by the configured [`ExtractClientIp`] callback, if any.
    pub client_ip: Option<IpAddr>,
}
266
267pub struct RequestSpan {
271 pub span: tracing::Span,
273 pub info: RequestInfo,
275}
276
277impl RequestSpan {
278 pub fn new(span: tracing::Span, extract_client_ip: ExtractClientIp, parts: &http::request::Parts) -> Self {
280 let _entered = span.enter();
281
282 let client_ip = (extract_client_ip)(parts);
283 let protocol = parts.headers
284 .get(http::header::CONTENT_TYPE)
285 .map_or(Protocol::Http, |content_type| Protocol::from_content_type(content_type.as_bytes()));
286
287 let request_id = if let Some(request_id) = parts.headers.get(REQUEST_ID) {
288 RequestId::from_bytes(request_id.as_bytes())
289 } else {
290 RequestId::from_uuid(uuid::Uuid::new_v4())
291 };
292
293 if let Some(user_agent) = parts.headers.get(http::header::USER_AGENT).and_then(|header| header.to_str().ok()) {
294 span.record("user_agent.original", user_agent);
295 }
296 span.record("http.request.method", parts.method.as_str());
297 span.record("url.path", parts.uri.path());
298 if let Some(query) = parts.uri.query() {
299 span.record("url.query", query);
300 }
301 if let Some(scheme) = parts.uri.scheme() {
302 span.record("url.scheme", scheme.as_str());
303 }
304 if let Some(request_id) = request_id.as_str() {
305 span.record("http.request_id", &request_id);
306 } else {
307 span.record("http.request_id", request_id.as_bytes());
308 }
309 if let Some(client_ip) = client_ip {
310 span.record("client.address", tracing::field::display(client_ip));
311 }
312 span.record("network.protocol.name", protocol.as_str());
313 if let Protocol::Http = protocol {
314 match parts.version {
315 http::Version::HTTP_09 => span.record("network.protocol.version", 0.9),
316 http::Version::HTTP_10 => span.record("network.protocol.version", 1.0),
317 http::Version::HTTP_11 => span.record("network.protocol.version", 1.1),
318 http::Version::HTTP_2 => span.record("network.protocol.version", 2),
319 http::Version::HTTP_3 => span.record("network.protocol.version", 3),
320 _ => span.record("network.protocol.version", 0),
322 };
323 }
324
325 drop(_entered);
326
327 Self {
328 span,
329 info: RequestInfo {
330 protocol,
331 request_id,
332 client_ip
333 }
334 }
335 }
336}
337
338#[derive(Clone)]
339pub struct HttpRequestLayer {
341 make_span: MakeSpan,
342 inspect_headers: &'static [&'static http::HeaderName],
343 extract_client_ip: ExtractClientIp,
344}
345
346impl HttpRequestLayer {
347 #[inline]
348 pub fn new(make_span: MakeSpan) -> Self {
350 Self {
351 make_span,
352 inspect_headers: &[],
353 extract_client_ip: default_client_ip
354 }
355 }
356
357 #[inline]
358 pub fn with_inspect_headers(mut self, inspect_headers: &'static [&'static http::HeaderName]) -> Self {
362 self.inspect_headers = inspect_headers;
363 self
364 }
365
366 pub fn with_extract_client_ip(mut self, extract_client_ip: ExtractClientIp) -> Self {
370 self.extract_client_ip = extract_client_ip;
371 self
372 }
373}
374
375impl<S> tower_layer::Layer<S> for HttpRequestLayer {
376 type Service = HttpRequestService<S>;
377 #[inline(always)]
378 fn layer(&self, inner: S) -> Self::Service {
379 HttpRequestService {
380 layer: self.clone(),
381 inner,
382 }
383 }
384}
385
386pub struct HttpRequestService<S> {
388 layer: HttpRequestLayer,
389 inner: S
390}
391
392impl<ReqBody, ResBody, S: tower_service::Service<http::Request<ReqBody>, Response = http::Response<ResBody>>> tower_service::Service<http::Request<ReqBody>> for HttpRequestService<S> where S::Error: std::error::Error {
393 type Response = S::Response;
394 type Error = S::Error;
395 type Future = ResponseFut<S::Future>;
396
397 #[inline(always)]
398 fn poll_ready(&mut self, ctx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
399 self.inner.poll_ready(ctx)
400 }
401
402 fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
403 let (parts, body) = req.into_parts();
404 let RequestSpan { span, info } = RequestSpan::new((self.layer.make_span)(), self.layer.extract_client_ip, &parts);
405
406 let mut req = http::Request::from_parts(parts, body);
407 #[cfg(feature = "opentelemetry")]
408 opentelemetry::on_request(&span, &req);
409
410 let _entered = span.enter();
411 if !self.layer.inspect_headers.is_empty() {
412 span.record("http.headers", tracing::field::debug(headers::InspectHeaders {
413 header_list: self.layer.inspect_headers,
414 headers: req.headers()
415 }));
416 }
417 let request_id = info.request_id.clone();
418 let protocol = info.protocol;
419 req.extensions_mut().insert(info);
420
421 let inner = self.inner.call(req);
422
423 drop(_entered);
424 ResponseFut {
425 inner,
426 span,
427 protocol,
428 request_id
429 }
430 }
431}
432
/// Future returned by [`HttpRequestService`].
///
/// Each poll of the inner service's future runs with the request span entered. On completion the
/// outcome is recorded on the span:
/// - `http.response.status_code` is always recorded.
/// - `error.type` and `error.message` are recorded when the inner service fails.
/// - Successful responses also get the [`REQUEST_ID`] header set to the request id.
pub struct ResponseFut<F> {
    /// Future of the wrapped service; structurally pinned (see the SAFETY note in `poll`).
    inner: F,
    /// Request span, entered for the duration of each `poll`.
    span: tracing::Span,
    /// Protocol of the request; selects HTTP or gRPC status reporting.
    protocol: Protocol,
    /// Id written into the response's [`REQUEST_ID`] header.
    request_id: RequestId,
}

impl<ResBody, E: std::error::Error, F: Future<Output = Result<http::Response<ResBody>, E>>> Future for ResponseFut<F> {
    type Output = F::Output;

    fn poll(self: Pin<&mut Self>, ctx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
        // Manual pin projection: only `inner` is pinned; the other fields are borrowed or copied.
        // SAFETY: `inner` is never moved out of `self`. No `Drop` impl or explicit `Unpin` impl for
        // `ResponseFut` is visible in this file, so the auto `Unpin` impl requires `F: Unpin`.
        // NOTE(review): adding either impl would require revisiting this projection.
        let (fut, span, protocol, request_id) = unsafe {
            let this = self.get_unchecked_mut();
            (
                Pin::new_unchecked(&mut this.inner),
                &this.span,
                this.protocol,
                &this.request_id,
            )
        };
        // Keep the request span entered while the inner future runs and while results are recorded.
        let _entered = span.enter();
        match Future::poll(fut, ctx) {
            task::Poll::Ready(Ok(mut resp)) => {
                // Echo the request id; skipped if its bytes are not a valid header value.
                if let Ok(request_id) = http::HeaderValue::from_bytes(request_id.as_bytes()) {
                    resp.headers_mut().insert(REQUEST_ID, request_id);
                }
                // HTTP: the response status code.
                // gRPC: the code parsed from the `grpc-status` response header, or 2 if it is absent.
                // NOTE(review): gRPC usually sends `grpc-status` in trailers; headers only carry it for
                // trailers-only responses. Successful calls may therefore be recorded as 2 (presumably
                // gRPC UNKNOWN) — confirm this is intended.
                let status = match protocol {
                    Protocol::Http => resp.status().as_u16(),
                    Protocol::Grpc => match resp.headers().get("grpc-status") {
                        Some(status) => grpc::parse_grpc_status(status.as_bytes()),
                        None => 2,
                    }
                };
                span.record("http.response.status_code", status);

                #[cfg(feature = "opentelemetry")]
                opentelemetry::on_response_ok(&span, &mut resp);

                task::Poll::Ready(Ok(resp))
            }
            task::Poll::Ready(Err(error)) => {
                // The inner service failed without producing a response.
                // Report 500 for HTTP and 13 for gRPC (presumably gRPC INTERNAL).
                let status = match protocol {
                    Protocol::Http => 500u16,
                    Protocol::Grpc => 13,
                };
                span.record("http.response.status_code", status);
                // `error.type` is the Rust type name of `E`, not a per-value classification.
                span.record("error.type", core::any::type_name::<E>());
                span.record("error.message", tracing::field::display(&error));

                #[cfg(feature = "opentelemetry")]
                opentelemetry::on_response_error(&span, &error);

                task::Poll::Ready(Err(error))
            },
            task::Poll::Pending => task::Poll::Pending
        }
    }
}