tower_http/trace/mod.rs
1//! Middleware that adds high level [tracing] to a [`Service`].
2//!
3//! # Example
4//!
5//! Adding tracing to your service can be as simple as:
6//!
7//! ```rust
8//! use http::{Request, Response};
9//! use tower::{ServiceBuilder, ServiceExt, Service};
10//! use tower_http::trace::TraceLayer;
11//! use std::convert::Infallible;
12//! use http_body_util::Full;
13//! use bytes::Bytes;
14//!
15//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> {
16//! Ok(Response::new(Full::default()))
17//! }
18//!
19//! # #[tokio::main]
20//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
21//! // Setup tracing
22//! tracing_subscriber::fmt::init();
23//!
24//! let mut service = ServiceBuilder::new()
25//! .layer(TraceLayer::new_for_http())
26//! .service_fn(handle);
27//!
28//! let request = Request::new(Full::from("foo"));
29//!
30//! let response = service
31//! .ready()
32//! .await?
33//! .call(request)
34//! .await?;
35//! # Ok(())
36//! # }
37//! ```
38//!
39//! If you run this application with `RUST_LOG=tower_http=trace cargo run` you should see logs like:
40//!
41//! ```text
42//! Mar 05 20:50:28.523 DEBUG request{method=GET path="/foo"}: tower_http::trace::on_request: started processing request
43//! Mar 05 20:50:28.524 DEBUG request{method=GET path="/foo"}: tower_http::trace::on_response: finished processing request latency=1 ms status=200
44//! ```
45//!
46//! # Customization
47//!
48//! [`Trace`] comes with good defaults but also supports customizing many aspects of the output.
49//!
50//! The default behaviour supports some customization:
51//!
52//! ```rust
53//! use http::{Request, Response, HeaderMap, StatusCode};
54//! use http_body_util::Full;
55//! use bytes::Bytes;
56//! use tower::ServiceBuilder;
57//! use tracing::Level;
58//! use tower_http::{
59//! LatencyUnit,
60//! trace::{TraceLayer, DefaultMakeSpan, DefaultOnRequest, DefaultOnResponse},
61//! };
62//! use std::time::Duration;
63//! # use tower::{ServiceExt, Service};
64//! # use std::convert::Infallible;
65//!
66//! # async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> {
67//! # Ok(Response::new(Full::from("foo")))
68//! # }
69//! # #[tokio::main]
70//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
71//! # tracing_subscriber::fmt::init();
72//! #
73//! let service = ServiceBuilder::new()
74//! .layer(
75//! TraceLayer::new_for_http()
76//! .make_span_with(
77//! DefaultMakeSpan::new().include_headers(true)
78//! )
79//! .on_request(
80//! DefaultOnRequest::new().level(Level::INFO)
81//! )
82//! .on_response(
83//! DefaultOnResponse::new()
84//! .level(Level::INFO)
85//! .latency_unit(LatencyUnit::Micros)
86//! )
87//! // and so on for `on_eos`, `on_body_chunk`, and `on_failure`
88//! )
89//! .service_fn(handle);
90//! # let mut service = service;
91//! # let response = service
92//! # .ready()
93//! # .await?
94//! # .call(Request::new(Full::from("foo")))
95//! # .await?;
96//! # Ok(())
97//! # }
98//! ```
99//!
100//! However for maximum control you can provide callbacks:
101//!
102//! ```rust
103//! use http::{Request, Response, HeaderMap, StatusCode};
104//! use http_body_util::Full;
105//! use bytes::Bytes;
106//! use tower::ServiceBuilder;
107//! use tower_http::{classify::ServerErrorsFailureClass, trace::TraceLayer};
108//! use std::time::Duration;
109//! use tracing::Span;
110//! # use tower::{ServiceExt, Service};
111//! # use std::convert::Infallible;
112//!
113//! # async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> {
114//! # Ok(Response::new(Full::from("foo")))
115//! # }
116//! # #[tokio::main]
117//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
118//! # tracing_subscriber::fmt::init();
119//! #
120//! let service = ServiceBuilder::new()
121//! .layer(
122//! TraceLayer::new_for_http()
123//! .make_span_with(|request: &Request<Full<Bytes>>| {
124//! tracing::debug_span!("http-request")
125//! })
126//! .on_request(|request: &Request<Full<Bytes>>, _span: &Span| {
127//! tracing::debug!("started {} {}", request.method(), request.uri().path())
128//! })
129//! .on_response(|response: &Response<Full<Bytes>>, latency: Duration, _span: &Span| {
130//! tracing::debug!("response generated in {:?}", latency)
131//! })
132//! .on_body_chunk(|chunk: &Bytes, latency: Duration, _span: &Span| {
133//! tracing::debug!("sending {} bytes", chunk.len())
134//! })
135//! .on_eos(|trailers: Option<&HeaderMap>, stream_duration: Duration, _span: &Span| {
136//! tracing::debug!("stream closed after {:?}", stream_duration)
137//! })
138//! .on_failure(|error: ServerErrorsFailureClass, latency: Duration, _span: &Span| {
139//! tracing::debug!("something went wrong")
140//! })
141//! )
142//! .service_fn(handle);
143//! # let mut service = service;
144//! # let response = service
145//! # .ready()
146//! # .await?
147//! # .call(Request::new(Full::from("foo")))
148//! # .await?;
149//! # Ok(())
150//! # }
151//! ```
152//!
153//! ## Disabling something
154//!
155//! Setting the behaviour to `()` will disable that particular step:
156//!
157//! ```rust
158//! use http::StatusCode;
159//! use tower::ServiceBuilder;
160//! use tower_http::{classify::ServerErrorsFailureClass, trace::TraceLayer};
161//! use std::time::Duration;
162//! use tracing::Span;
163//! # use tower::{ServiceExt, Service};
164//! # use http_body_util::Full;
165//! # use bytes::Bytes;
166//! # use http::{Response, Request};
167//! # use std::convert::Infallible;
168//!
169//! # async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> {
170//! # Ok(Response::new(Full::from("foo")))
171//! # }
172//! # #[tokio::main]
173//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
174//! # tracing_subscriber::fmt::init();
175//! #
176//! let service = ServiceBuilder::new()
177//! .layer(
178//! // This configuration will only emit events on failures
179//! TraceLayer::new_for_http()
180//! .on_request(())
181//! .on_response(())
182//! .on_body_chunk(())
183//! .on_eos(())
184//! .on_failure(|error: ServerErrorsFailureClass, latency: Duration, _span: &Span| {
185//! tracing::debug!("something went wrong")
186//! })
187//! )
188//! .service_fn(handle);
189//! # let mut service = service;
190//! # let response = service
191//! # .ready()
192//! # .await?
193//! # .call(Request::new(Full::from("foo")))
194//! # .await?;
195//! # Ok(())
196//! # }
197//! ```
198//!
199//! # When the callbacks are called
200//!
201//! ### `on_request`
202//!
203//! The `on_request` callback is called when the request arrives at the
204//! middleware in [`Service::call`] just prior to passing the request to the
205//! inner service.
206//!
207//! ### `on_response`
208//!
209//! The `on_response` callback is called when the inner service's response
210//! future completes with `Ok(response)` regardless of whether the response is
211//! classified as a success or a failure.
212//!
213//! For example if you're using [`ServerErrorsAsFailures`] as your classifier
214//! and the inner service responds with `500 Internal Server Error` then the
215//! `on_response` callback is still called. `on_failure` would _also_ be called
216//! in this case since the response was classified as a failure.
217//!
218//! ### `on_body_chunk`
219//!
220//! The `on_body_chunk` callback is called when the response body produces a new
221//! chunk, that is when [`Body::poll_frame`] returns a data frame.
222//!
223//! `on_body_chunk` is called even if the chunk is empty.
224//!
225//! ### `on_eos`
226//!
227//! The `on_eos` callback is called when a streaming response body ends, that is
228//! when [`Body::poll_frame`] returns a trailers frame or `None`, or the body reports end-of-stream after a data frame.
229//!
230//! `on_eos` is called even if the trailers produced are `None`.
231//!
232//! ### `on_failure`
233//!
234//! The `on_failure` callback is called when:
235//!
236//! - The inner [`Service`]'s response future resolves to an error.
237//! - A response is classified as a failure.
238//! - [`Body::poll_frame`] returns an error.
239//! - An end-of-stream is classified as a failure.
240//!
241//! # Recording fields on the span
242//!
243//! All callbacks receive a reference to the [tracing] [`Span`], corresponding to this request,
244//! produced by the closure passed to [`TraceLayer::make_span_with`]. It can be used to [record
245//! field values][record] that weren't known when the span was created.
246//!
247//! ```rust
248//! use http::{Request, Response, HeaderMap, StatusCode};
249//! use http_body_util::Full;
250//! use bytes::Bytes;
251//! use tower::ServiceBuilder;
252//! use tower_http::trace::TraceLayer;
253//! use tracing::Span;
254//! use std::time::Duration;
255//! # use std::convert::Infallible;
256//!
257//! # async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> {
258//! # Ok(Response::new(Full::from("foo")))
259//! # }
260//! # #[tokio::main]
261//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
262//! # tracing_subscriber::fmt::init();
263//! #
264//! let service = ServiceBuilder::new()
265//! .layer(
266//! TraceLayer::new_for_http()
267//! .make_span_with(|request: &Request<Full<Bytes>>| {
268//! tracing::debug_span!(
269//! "http-request",
270//! status_code = tracing::field::Empty,
271//! )
272//! })
273//! .on_response(|response: &Response<Full<Bytes>>, _latency: Duration, span: &Span| {
274//! span.record("status_code", &tracing::field::display(response.status()));
275//!
276//! tracing::debug!("response generated")
277//! })
278//! )
279//! .service_fn(handle);
280//! # Ok(())
281//! # }
282//! ```
283//!
284//! # Providing classifiers
285//!
286//! Tracing requires determining if a response is a success or failure. [`MakeClassifier`] is used
287//! to create a classifier for the incoming request. See the docs for [`MakeClassifier`] and
288//! [`ClassifyResponse`] for more details on classification.
289//!
290//! A [`MakeClassifier`] can be provided when creating a [`TraceLayer`]:
291//!
292//! ```rust
293//! use http::{Request, Response};
294//! use http_body_util::Full;
295//! use bytes::Bytes;
296//! use tower::ServiceBuilder;
297//! use tower_http::{
298//! trace::TraceLayer,
299//! classify::{
300//! MakeClassifier, ClassifyResponse, ClassifiedResponse, NeverClassifyEos,
301//! SharedClassifier,
302//! },
303//! };
304//! use std::convert::Infallible;
305//!
306//! # async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> {
307//! # Ok(Response::new(Full::from("foo")))
308//! # }
309//! # #[tokio::main]
310//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
311//! # tracing_subscriber::fmt::init();
312//! #
313//! // Our `MakeClassifier` that always creates `MyClassifier` classifiers.
314//! #[derive(Copy, Clone)]
315//! struct MyMakeClassify;
316//!
317//! impl MakeClassifier for MyMakeClassify {
318//! type Classifier = MyClassifier;
319//! type FailureClass = &'static str;
320//! type ClassifyEos = NeverClassifyEos<&'static str>;
321//!
322//! fn make_classifier<B>(&self, req: &Request<B>) -> Self::Classifier {
323//! MyClassifier
324//! }
325//! }
326//!
327//! // A classifier that classifies failures as `"something went wrong..."`.
328//! #[derive(Copy, Clone)]
329//! struct MyClassifier;
330//!
331//! impl ClassifyResponse for MyClassifier {
332//! type FailureClass = &'static str;
333//! type ClassifyEos = NeverClassifyEos<&'static str>;
334//!
335//! fn classify_response<B>(
336//! self,
337//! res: &Response<B>
338//! ) -> ClassifiedResponse<Self::FailureClass, Self::ClassifyEos> {
339//! // Classify based on the status code.
340//! if res.status().is_server_error() {
341//! ClassifiedResponse::Ready(Err("something went wrong..."))
342//! } else {
343//! ClassifiedResponse::Ready(Ok(()))
344//! }
345//! }
346//!
347//! fn classify_error<E>(self, error: &E) -> Self::FailureClass
348//! where
349//! E: std::fmt::Display + 'static,
350//! {
351//! "something went wrong..."
352//! }
353//! }
354//!
355//! let service = ServiceBuilder::new()
356//! // Create a trace layer that uses our classifier.
357//! .layer(TraceLayer::new(MyMakeClassify))
358//! .service_fn(handle);
359//!
360//! // Since `MyClassifier` is `Clone` we can also use `SharedClassifier`
361//! // to avoid having to define a separate `MakeClassifier`.
362//! let service = ServiceBuilder::new()
363//! .layer(TraceLayer::new(SharedClassifier::new(MyClassifier)))
364//! .service_fn(handle);
365//! # Ok(())
366//! # }
367//! ```
368//!
369//! [`TraceLayer`] comes with convenience methods for using common classifiers:
370//!
371//! - [`TraceLayer::new_for_http`] classifies based on the status code. It doesn't consider
372//! streaming responses.
373//! - [`TraceLayer::new_for_grpc`] classifies based on the gRPC protocol and supports streaming
374//! responses.
375//!
376//! [tracing]: https://crates.io/crates/tracing
377//! [`Service`]: tower_service::Service
378//! [`Service::call`]: tower_service::Service::call
379//! [`MakeClassifier`]: crate::classify::MakeClassifier
380//! [`ClassifyResponse`]: crate::classify::ClassifyResponse
381//! [record]: https://docs.rs/tracing/latest/tracing/span/struct.Span.html#method.record
382//! [`TraceLayer::make_span_with`]: crate::trace::TraceLayer::make_span_with
383//! [`Span`]: tracing::Span
384//! [`ServerErrorsAsFailures`]: crate::classify::ServerErrorsAsFailures
385//! [`Body::poll_frame`]: http_body::Body::poll_frame
386
387use std::{fmt, time::Duration};
388
389use tracing::Level;
390
391pub use self::{
392 body::ResponseBody,
393 future::ResponseFuture,
394 layer::TraceLayer,
395 make_span::{DefaultMakeSpan, MakeSpan},
396 on_body_chunk::{DefaultOnBodyChunk, OnBodyChunk},
397 on_eos::{DefaultOnEos, OnEos},
398 on_failure::{DefaultOnFailure, OnFailure},
399 on_request::{DefaultOnRequest, OnRequest},
400 on_response::{DefaultOnResponse, OnResponse},
401 service::Trace,
402};
403use crate::{
404 classify::{GrpcErrorsAsFailures, ServerErrorsAsFailures, SharedClassifier},
405 LatencyUnit,
406};
407
408/// MakeClassifier for HTTP requests.
409pub type HttpMakeClassifier = SharedClassifier<ServerErrorsAsFailures>;
410
411/// MakeClassifier for gRPC requests.
412pub type GrpcMakeClassifier = SharedClassifier<GrpcErrorsAsFailures>;
413
// Emits a `tracing` event at a level that is only known at run time.
//
// Every rule matches on the level expression and re-invokes `tracing::event!`
// with the corresponding level written out as a constant (presumably because
// `tracing::event!` needs its level at compile time). The four rules differ
// only in which of the optional `target:` / `parent:` prefixes they forward.
macro_rules! event_dynamic_lvl {
    // Both an explicit target and an explicit parent span.
    ( target: $tgt:expr, parent: $span:expr, $level:expr, $($fields:tt)* ) => {
        match $level {
            tracing::Level::ERROR => {
                tracing::event!(target: $tgt, parent: $span, tracing::Level::ERROR, $($fields)*);
            }
            tracing::Level::WARN => {
                tracing::event!(target: $tgt, parent: $span, tracing::Level::WARN, $($fields)*);
            }
            tracing::Level::INFO => {
                tracing::event!(target: $tgt, parent: $span, tracing::Level::INFO, $($fields)*);
            }
            tracing::Level::DEBUG => {
                tracing::event!(target: $tgt, parent: $span, tracing::Level::DEBUG, $($fields)*);
            }
            tracing::Level::TRACE => {
                tracing::event!(target: $tgt, parent: $span, tracing::Level::TRACE, $($fields)*);
            }
        }
    };
    // Explicit target only.
    ( target: $tgt:expr, $level:expr, $($fields:tt)* ) => {
        match $level {
            tracing::Level::ERROR => {
                tracing::event!(target: $tgt, tracing::Level::ERROR, $($fields)*);
            }
            tracing::Level::WARN => {
                tracing::event!(target: $tgt, tracing::Level::WARN, $($fields)*);
            }
            tracing::Level::INFO => {
                tracing::event!(target: $tgt, tracing::Level::INFO, $($fields)*);
            }
            tracing::Level::DEBUG => {
                tracing::event!(target: $tgt, tracing::Level::DEBUG, $($fields)*);
            }
            tracing::Level::TRACE => {
                tracing::event!(target: $tgt, tracing::Level::TRACE, $($fields)*);
            }
        }
    };
    // Explicit parent span only.
    ( parent: $span:expr, $level:expr, $($fields:tt)* ) => {
        match $level {
            tracing::Level::ERROR => {
                tracing::event!(parent: $span, tracing::Level::ERROR, $($fields)*);
            }
            tracing::Level::WARN => {
                tracing::event!(parent: $span, tracing::Level::WARN, $($fields)*);
            }
            tracing::Level::INFO => {
                tracing::event!(parent: $span, tracing::Level::INFO, $($fields)*);
            }
            tracing::Level::DEBUG => {
                tracing::event!(parent: $span, tracing::Level::DEBUG, $($fields)*);
            }
            tracing::Level::TRACE => {
                tracing::event!(parent: $span, tracing::Level::TRACE, $($fields)*);
            }
        }
    };
    // Neither target nor parent: just a level followed by the event fields.
    // This rule must stay last because it is the most general one.
    ( $level:expr, $($fields:tt)* ) => {
        match $level {
            tracing::Level::ERROR => {
                tracing::event!(tracing::Level::ERROR, $($fields)*);
            }
            tracing::Level::WARN => {
                tracing::event!(tracing::Level::WARN, $($fields)*);
            }
            tracing::Level::INFO => {
                tracing::event!(tracing::Level::INFO, $($fields)*);
            }
            tracing::Level::DEBUG => {
                tracing::event!(tracing::Level::DEBUG, $($fields)*);
            }
            tracing::Level::TRACE => {
                tracing::event!(tracing::Level::TRACE, $($fields)*);
            }
        }
    };
}
492
493mod body;
494mod future;
495mod layer;
496mod make_span;
497mod on_body_chunk;
498mod on_eos;
499mod on_failure;
500mod on_request;
501mod on_response;
502mod service;
503
504const DEFAULT_MESSAGE_LEVEL: Level = Level::DEBUG;
505const DEFAULT_ERROR_LEVEL: Level = Level::ERROR;
506
507struct Latency {
508 unit: LatencyUnit,
509 duration: Duration,
510}
511
512impl fmt::Display for Latency {
513 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
514 match self.unit {
515 LatencyUnit::Seconds => write!(f, "{} s", self.duration.as_secs_f64()),
516 LatencyUnit::Millis => write!(f, "{} ms", self.duration.as_millis()),
517 LatencyUnit::Micros => write!(f, "{} μs", self.duration.as_micros()),
518 LatencyUnit::Nanos => write!(f, "{} ns", self.duration.as_nanos()),
519 }
520 }
521}
522
523#[cfg(test)]
524mod tests {
525 use super::*;
526 use crate::classify::ServerErrorsFailureClass;
527 use crate::test_helpers::Body;
528 use bytes::Bytes;
529 use http::{HeaderMap, Request, Response};
530 use once_cell::sync::Lazy;
531 use std::{
532 sync::atomic::{AtomicU32, Ordering},
533 time::Duration,
534 };
535 use tower::{BoxError, Service, ServiceBuilder, ServiceExt};
536 use tracing::Span;
537
538 #[tokio::test]
539 async fn unary_request() {
540 static ON_REQUEST_COUNT: Lazy<AtomicU32> = Lazy::new(|| AtomicU32::new(0));
541 static ON_RESPONSE_COUNT: Lazy<AtomicU32> = Lazy::new(|| AtomicU32::new(0));
542 static ON_BODY_CHUNK_COUNT: Lazy<AtomicU32> = Lazy::new(|| AtomicU32::new(0));
543 static ON_EOS: Lazy<AtomicU32> = Lazy::new(|| AtomicU32::new(0));
544 static ON_FAILURE: Lazy<AtomicU32> = Lazy::new(|| AtomicU32::new(0));
545
546 let trace_layer = TraceLayer::new_for_http()
547 .make_span_with(|_req: &Request<Body>| {
548 tracing::info_span!("test-span", foo = tracing::field::Empty)
549 })
550 .on_request(|_req: &Request<Body>, span: &Span| {
551 span.record("foo", 42);
552 ON_REQUEST_COUNT.fetch_add(1, Ordering::SeqCst);
553 })
554 .on_response(|_res: &Response<Body>, _latency: Duration, _span: &Span| {
555 ON_RESPONSE_COUNT.fetch_add(1, Ordering::SeqCst);
556 })
557 .on_body_chunk(|_chunk: &Bytes, _latency: Duration, _span: &Span| {
558 ON_BODY_CHUNK_COUNT.fetch_add(1, Ordering::SeqCst);
559 })
560 .on_eos(
561 |_trailers: Option<&HeaderMap>, _latency: Duration, _span: &Span| {
562 ON_EOS.fetch_add(1, Ordering::SeqCst);
563 },
564 )
565 .on_failure(
566 |_class: ServerErrorsFailureClass, _latency: Duration, _span: &Span| {
567 ON_FAILURE.fetch_add(1, Ordering::SeqCst);
568 },
569 );
570
571 let mut svc = ServiceBuilder::new().layer(trace_layer).service_fn(echo);
572
573 let res = svc
574 .ready()
575 .await
576 .unwrap()
577 .call(Request::new(Body::from("foobar")))
578 .await
579 .unwrap();
580
581 assert_eq!(1, ON_REQUEST_COUNT.load(Ordering::SeqCst), "request");
582 assert_eq!(1, ON_RESPONSE_COUNT.load(Ordering::SeqCst), "request");
583 assert_eq!(0, ON_BODY_CHUNK_COUNT.load(Ordering::SeqCst), "body chunk");
584 assert_eq!(0, ON_EOS.load(Ordering::SeqCst), "eos");
585 assert_eq!(0, ON_FAILURE.load(Ordering::SeqCst), "failure");
586
587 crate::test_helpers::to_bytes(res.into_body())
588 .await
589 .unwrap();
590 assert_eq!(1, ON_BODY_CHUNK_COUNT.load(Ordering::SeqCst), "body chunk");
591 assert_eq!(1, ON_EOS.load(Ordering::SeqCst), "eos");
592 assert_eq!(0, ON_FAILURE.load(Ordering::SeqCst), "failure");
593 }
594
595 #[tokio::test]
596 async fn streaming_response() {
597 static ON_REQUEST_COUNT: Lazy<AtomicU32> = Lazy::new(|| AtomicU32::new(0));
598 static ON_RESPONSE_COUNT: Lazy<AtomicU32> = Lazy::new(|| AtomicU32::new(0));
599 static ON_BODY_CHUNK_COUNT: Lazy<AtomicU32> = Lazy::new(|| AtomicU32::new(0));
600 static ON_EOS: Lazy<AtomicU32> = Lazy::new(|| AtomicU32::new(0));
601 static ON_FAILURE: Lazy<AtomicU32> = Lazy::new(|| AtomicU32::new(0));
602
603 let trace_layer = TraceLayer::new_for_http()
604 .on_request(|_req: &Request<Body>, _span: &Span| {
605 ON_REQUEST_COUNT.fetch_add(1, Ordering::SeqCst);
606 })
607 .on_response(|_res: &Response<Body>, _latency: Duration, _span: &Span| {
608 ON_RESPONSE_COUNT.fetch_add(1, Ordering::SeqCst);
609 })
610 .on_body_chunk(|_chunk: &Bytes, _latency: Duration, _span: &Span| {
611 ON_BODY_CHUNK_COUNT.fetch_add(1, Ordering::SeqCst);
612 })
613 .on_eos(
614 |_trailers: Option<&HeaderMap>, _latency: Duration, _span: &Span| {
615 ON_EOS.fetch_add(1, Ordering::SeqCst);
616 },
617 )
618 .on_failure(
619 |_class: ServerErrorsFailureClass, _latency: Duration, _span: &Span| {
620 ON_FAILURE.fetch_add(1, Ordering::SeqCst);
621 },
622 );
623
624 let mut svc = ServiceBuilder::new()
625 .layer(trace_layer)
626 .service_fn(streaming_body);
627
628 let res = svc
629 .ready()
630 .await
631 .unwrap()
632 .call(Request::new(Body::empty()))
633 .await
634 .unwrap();
635
636 assert_eq!(1, ON_REQUEST_COUNT.load(Ordering::SeqCst), "request");
637 assert_eq!(1, ON_RESPONSE_COUNT.load(Ordering::SeqCst), "request");
638 assert_eq!(0, ON_BODY_CHUNK_COUNT.load(Ordering::SeqCst), "body chunk");
639 assert_eq!(0, ON_EOS.load(Ordering::SeqCst), "eos");
640 assert_eq!(0, ON_FAILURE.load(Ordering::SeqCst), "failure");
641
642 crate::test_helpers::to_bytes(res.into_body())
643 .await
644 .unwrap();
645 assert_eq!(3, ON_BODY_CHUNK_COUNT.load(Ordering::SeqCst), "body chunk");
646 assert_eq!(1, ON_EOS.load(Ordering::SeqCst), "eos");
647 assert_eq!(0, ON_FAILURE.load(Ordering::SeqCst), "failure");
648 }
649
650 #[tokio::test]
651 async fn classify_eos_on_trailers_success() {
652 static ON_EOS: Lazy<AtomicU32> = Lazy::new(|| AtomicU32::new(0));
653 static ON_FAILURE: Lazy<AtomicU32> = Lazy::new(|| AtomicU32::new(0));
654
655 let trace_layer = TraceLayer::new(TestClassify::new(false))
656 .on_eos(
657 |_trailers: Option<&HeaderMap>, _latency: Duration, _span: &Span| {
658 ON_EOS.fetch_add(1, Ordering::SeqCst);
659 },
660 )
661 .on_failure(|_class: &'static str, _latency: Duration, _span: &Span| {
662 ON_FAILURE.fetch_add(1, Ordering::SeqCst);
663 });
664
665 let mut svc = ServiceBuilder::new()
666 .layer(trace_layer)
667 .service_fn(body_with_trailers);
668
669 let res = svc
670 .ready()
671 .await
672 .unwrap()
673 .call(Request::new(Body::empty()))
674 .await
675 .unwrap();
676
677 crate::test_helpers::to_bytes(res.into_body())
678 .await
679 .unwrap();
680 assert_eq!(1, ON_EOS.load(Ordering::SeqCst), "eos");
681 assert_eq!(0, ON_FAILURE.load(Ordering::SeqCst), "failure");
682 }
683
684 #[tokio::test]
685 async fn classify_eos_on_trailers_failure() {
686 static ON_EOS: Lazy<AtomicU32> = Lazy::new(|| AtomicU32::new(0));
687 static ON_FAILURE: Lazy<AtomicU32> = Lazy::new(|| AtomicU32::new(0));
688
689 let trace_layer = TraceLayer::new(TestClassify::new(true))
690 .on_eos(
691 |_trailers: Option<&HeaderMap>, _latency: Duration, _span: &Span| {
692 ON_EOS.fetch_add(1, Ordering::SeqCst);
693 },
694 )
695 .on_failure(|_class: &'static str, _latency: Duration, _span: &Span| {
696 ON_FAILURE.fetch_add(1, Ordering::SeqCst);
697 });
698
699 let mut svc = ServiceBuilder::new()
700 .layer(trace_layer)
701 .service_fn(body_with_trailers);
702
703 let res = svc
704 .ready()
705 .await
706 .unwrap()
707 .call(Request::new(Body::empty()))
708 .await
709 .unwrap();
710
711 crate::test_helpers::to_bytes(res.into_body())
712 .await
713 .unwrap();
714 assert_eq!(1, ON_EOS.load(Ordering::SeqCst), "eos");
715 assert_eq!(1, ON_FAILURE.load(Ordering::SeqCst), "failure");
716 }
717
718 #[tokio::test]
719 async fn classify_eos_on_empty_stream() {
720 static ON_EOS: Lazy<AtomicU32> = Lazy::new(|| AtomicU32::new(0));
721 static ON_FAILURE: Lazy<AtomicU32> = Lazy::new(|| AtomicU32::new(0));
722
723 let trace_layer = TraceLayer::new(TestClassify::new(true))
724 .on_eos(
725 |_trailers: Option<&HeaderMap>, _latency: Duration, _span: &Span| {
726 ON_EOS.fetch_add(1, Ordering::SeqCst);
727 },
728 )
729 .on_failure(|_class: &'static str, _latency: Duration, _span: &Span| {
730 ON_FAILURE.fetch_add(1, Ordering::SeqCst);
731 });
732
733 let mut svc = ServiceBuilder::new()
734 .layer(trace_layer)
735 .service_fn(streaming_body);
736
737 let res = svc
738 .ready()
739 .await
740 .unwrap()
741 .call(Request::new(Body::empty()))
742 .await
743 .unwrap();
744
745 crate::test_helpers::to_bytes(res.into_body())
746 .await
747 .unwrap();
748 assert_eq!(1, ON_EOS.load(Ordering::SeqCst), "eos");
749 assert_eq!(1, ON_FAILURE.load(Ordering::SeqCst), "failure");
750 }
751
752 #[tokio::test]
753 async fn on_eos_fires_for_content_length_body() {
754 // Simulates the scenario where a consumer stops polling after receiving
755 // all bytes (as hyper does when Content-Length is exact). We poll only
756 // the data frame and never poll to None.
757 use http_body_util::BodyExt;
758
759 static ON_BODY_CHUNK_COUNT: Lazy<AtomicU32> = Lazy::new(|| AtomicU32::new(0));
760 static ON_EOS: Lazy<AtomicU32> = Lazy::new(|| AtomicU32::new(0));
761
762 let trace_layer = TraceLayer::new_for_http()
763 .on_body_chunk(|_chunk: &Bytes, _latency: Duration, _span: &Span| {
764 ON_BODY_CHUNK_COUNT.fetch_add(1, Ordering::SeqCst);
765 })
766 .on_eos(
767 |_trailers: Option<&HeaderMap>, _latency: Duration, _span: &Span| {
768 ON_EOS.fetch_add(1, Ordering::SeqCst);
769 },
770 );
771
772 let mut svc = ServiceBuilder::new().layer(trace_layer).service_fn(echo);
773
774 let res = svc
775 .ready()
776 .await
777 .unwrap()
778 .call(Request::new(Body::from("hello")))
779 .await
780 .unwrap();
781
782 let mut body = res.into_body();
783
784 // Poll only the data frame (simulating a content-length aware consumer)
785 let frame = body.frame().await.unwrap().unwrap();
786 assert!(frame.data_ref().is_some());
787
788 // on_eos should have fired immediately after the data frame since
789 // is_end_stream() is true for Full bodies after yielding their data.
790 assert_eq!(1, ON_BODY_CHUNK_COUNT.load(Ordering::SeqCst), "body chunk");
791 assert_eq!(1, ON_EOS.load(Ordering::SeqCst), "eos");
792 }
793
794 #[tokio::test]
795 async fn on_eos_fires_for_streaming_body_on_none() {
796 // Streaming bodies (no content-length) don't report is_end_stream()
797 // until polled to None. Verify on_eos still fires via the None path.
798 static ON_EOS: Lazy<AtomicU32> = Lazy::new(|| AtomicU32::new(0));
799
800 let trace_layer = TraceLayer::new_for_http().on_eos(
801 |_trailers: Option<&HeaderMap>, _latency: Duration, _span: &Span| {
802 ON_EOS.fetch_add(1, Ordering::SeqCst);
803 },
804 );
805
806 let mut svc = ServiceBuilder::new()
807 .layer(trace_layer)
808 .service_fn(streaming_body);
809
810 let res = svc
811 .ready()
812 .await
813 .unwrap()
814 .call(Request::new(Body::empty()))
815 .await
816 .unwrap();
817
818 crate::test_helpers::to_bytes(res.into_body())
819 .await
820 .unwrap();
821 assert_eq!(1, ON_EOS.load(Ordering::SeqCst), "eos");
822 }
823
824 #[tokio::test]
825 async fn on_eos_not_called_twice() {
826 // When is_end_stream() fires on_eos after a data frame, a subsequent
827 // poll returning None must not fire on_eos again.
828 static ON_EOS: Lazy<AtomicU32> = Lazy::new(|| AtomicU32::new(0));
829
830 let trace_layer = TraceLayer::new_for_http().on_eos(
831 |_trailers: Option<&HeaderMap>, _latency: Duration, _span: &Span| {
832 ON_EOS.fetch_add(1, Ordering::SeqCst);
833 },
834 );
835
836 let mut svc = ServiceBuilder::new().layer(trace_layer).service_fn(echo);
837
838 let res = svc
839 .ready()
840 .await
841 .unwrap()
842 .call(Request::new(Body::from("hello")))
843 .await
844 .unwrap();
845
846 // Consume the body fully (polls data frame then None)
847 crate::test_helpers::to_bytes(res.into_body())
848 .await
849 .unwrap();
850
851 // on_eos should fire exactly once, not twice
852 assert_eq!(1, ON_EOS.load(Ordering::SeqCst), "eos");
853 }
854
855 async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> {
856 Ok(Response::new(req.into_body()))
857 }
858
859 async fn streaming_body(_req: Request<Body>) -> Result<Response<Body>, BoxError> {
860 use futures_util::stream::iter;
861
862 let stream = iter(vec![
863 Ok::<_, BoxError>(Bytes::from("one")),
864 Ok::<_, BoxError>(Bytes::from("two")),
865 Ok::<_, BoxError>(Bytes::from("three")),
866 ]);
867
868 let body = Body::from_stream(stream);
869
870 Ok(Response::new(body))
871 }
872
873 async fn body_with_trailers(_req: Request<Body>) -> Result<Response<Body>, BoxError> {
874 let mut trailers = HeaderMap::new();
875 trailers.insert("x-test-trailer", "value".parse().unwrap());
876 let body = Body::new(Body::from(Bytes::from("data")).with_trailers(trailers));
877 Ok(Response::new(body))
878 }
879
880 #[derive(Clone)]
881 struct TestClassify {
882 reject: bool,
883 }
884
885 impl TestClassify {
886 fn new(reject: bool) -> Self {
887 Self { reject }
888 }
889 }
890
891 impl crate::classify::MakeClassifier for TestClassify {
892 type FailureClass = &'static str;
893 type ClassifyEos = TestClassifyEos;
894 type Classifier = TestClassifyResponse;
895
896 fn make_classifier<B>(&self, _req: &Request<B>) -> Self::Classifier {
897 TestClassifyResponse {
898 reject: self.reject,
899 }
900 }
901 }
902
903 #[derive(Clone)]
904 struct TestClassifyResponse {
905 reject: bool,
906 }
907
908 impl crate::classify::ClassifyResponse for TestClassifyResponse {
909 type FailureClass = &'static str;
910 type ClassifyEos = TestClassifyEos;
911
912 fn classify_response<B>(
913 self,
914 _res: &Response<B>,
915 ) -> crate::classify::ClassifiedResponse<Self::FailureClass, Self::ClassifyEos> {
916 crate::classify::ClassifiedResponse::RequiresEos(TestClassifyEos {
917 reject: self.reject,
918 })
919 }
920
921 fn classify_error<E>(self, _error: &E) -> Self::FailureClass
922 where
923 E: std::fmt::Display + 'static,
924 {
925 "error"
926 }
927 }
928
929 #[derive(Clone)]
930 struct TestClassifyEos {
931 reject: bool,
932 }
933
934 impl crate::classify::ClassifyEos for TestClassifyEos {
935 type FailureClass = &'static str;
936
937 fn classify_eos(self, _trailers: Option<&HeaderMap>) -> Result<(), Self::FailureClass> {
938 if self.reject {
939 Err("classified as failure")
940 } else {
941 Ok(())
942 }
943 }
944
945 fn classify_error<E>(self, _error: &E) -> Self::FailureClass
946 where
947 E: std::fmt::Display + 'static,
948 {
949 "error"
950 }
951 }
952
/// Regression test for https://github.com/tower-rs/tower-http/issues/655
///
/// Scenario: a subscriber filter turns the request span off while leaving events on.
/// Entering a disabled span (`Span::enter()`) does nothing, so events that depend on the
/// ambient span context end up with no parent at all.
///
/// Expected behaviour: the events name the request span explicitly (`parent: span`), so
/// they keep referring to it even while it is disabled, and a subscriber that records
/// disabled spans still sees the right parent relationship.
#[test]
fn events_have_span_context_when_span_is_disabled() {
    use std::sync::{Arc, Mutex};
    use tracing::subscriber::with_default;
    use tracing_subscriber::{layer::SubscriberExt, registry::LookupSpan, Layer as _};

    /// Per-layer filter that rejects spans and lets everything else (the events) through,
    /// imitating a per-layer EnvFilter that disables the request span's callsite.
    struct DisableSpansFilter;

    impl<S: tracing::Subscriber> tracing_subscriber::layer::Filter<S> for DisableSpansFilter {
        fn enabled(
            &self,
            meta: &tracing::Metadata<'_>,
            _cx: &tracing_subscriber::layer::Context<'_, S>,
        ) -> bool {
            // Spans off, events on.
            !meta.is_span()
        }
    }

    /// Collects one `(message, has_any_parent)` entry per observed event.
    #[derive(Clone)]
    struct RecordingLayer {
        seen: Arc<Mutex<Vec<(String, bool)>>>,
    }

    impl<S> tracing_subscriber::Layer<S> for RecordingLayer
    where
        S: tracing::Subscriber + for<'a> LookupSpan<'a>,
    {
        fn on_event(
            &self,
            event: &tracing::Event<'_>,
            ctx: tracing_subscriber::layer::Context<'_, S>,
        ) {
            let mut message = String::new();
            event.record(&mut MessageVisitor(&mut message));

            // Either kind of parent counts: an explicit one or one taken from the context.
            let explicit_parent = event.parent().is_some();
            let has_parent = explicit_parent || ctx.event_span(event).is_some();

            self.seen.lock().unwrap().push((message, has_parent));
        }
    }

    /// Field visitor that keeps only the debug rendering of the `message` field.
    struct MessageVisitor<'a>(&'a mut String);
    impl tracing::field::Visit for MessageVisitor<'_> {
        fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
            if field.name() != "message" {
                return;
            }
            *self.0 = format!("{:?}", value);
        }
    }

    let recorded = Arc::new(Mutex::new(Vec::new()));
    let recorder = RecordingLayer {
        seen: Arc::clone(&recorded),
    };
    let subscriber =
        tracing_subscriber::registry().with(recorder.with_filter(DisableSpansFilter));

    // `with_default` restores the previous subscriber even if the closure panics, so this
    // test cannot leak its subscriber into other tests.
    with_default(subscriber, || {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();

        runtime.block_on(async {
            let mut service = ServiceBuilder::new()
                .layer(TraceLayer::new_for_http())
                .service_fn(echo);

            let ready = service.ready().await.unwrap();
            let response = ready.call(Request::new(Body::from("test"))).await.unwrap();

            let body = response.into_body();
            crate::test_helpers::to_bytes(body).await.unwrap();
        });
    });

    let recorded = recorded.lock().unwrap();
    // Only the events emitted by the default on_request / on_response callbacks matter here.
    let markers = ["started processing request", "finished processing request"];
    let request_events: Vec<_> = recorded
        .iter()
        .filter(|(msg, _)| markers.iter().any(|marker| msg.contains(*marker)))
        .collect();

    assert!(
        request_events.len() >= 2,
        "expected on_request and on_response events to fire"
    );

    // Without an explicit parent these events lose all span context once the request span
    // is disabled; with it they still point at the (disabled) span.
    for (msg, has_parent) in &request_events {
        assert!(
            *has_parent,
            "event {:?} has no span context. When the request span is \
             disabled by a filter, events must still reference it via \
             explicit parent so subscribers can associate them correctly.",
            msg
        );
    }
}
1079}