rama_http/layer/trace/
service.rs

1use super::{
2    DefaultMakeSpan, DefaultOnBodyChunk, DefaultOnEos, DefaultOnFailure, DefaultOnRequest,
3    DefaultOnResponse, GrpcMakeClassifier, HttpMakeClassifier, MakeSpan, OnBodyChunk, OnEos,
4    OnFailure, OnRequest, OnResponse, ResponseBody,
5};
6use crate::dep::http_body::Body as HttpBody;
7use crate::layer::classify::{
8    ClassifiedResponse, ClassifyResponse, GrpcErrorsAsFailures, MakeClassifier,
9    ServerErrorsAsFailures, SharedClassifier,
10};
11use crate::{Request, Response};
12use rama_core::{Context, Service};
13use rama_utils::macros::define_inner_service_accessors;
14use std::{fmt, time::Instant};
15
/// Middleware that adds high level [tracing] to a [`Service`].
///
/// See the [module docs](crate::layer::trace) for an example.
///
/// Every hook is its own type parameter, so replacing one through the builder-style
/// methods yields a new `Trace` type rather than mutating this one in place.
///
/// [tracing]: https://crates.io/crates/tracing
/// [`Service`]: rama_core::Service
pub struct Trace<
    S,
    M,
    MakeSpan = DefaultMakeSpan,
    OnRequest = DefaultOnRequest,
    OnResponse = DefaultOnResponse,
    OnBodyChunk = DefaultOnBodyChunk,
    OnEos = DefaultOnEos,
    OnFailure = DefaultOnFailure,
> {
    /// The wrapped service that every request is forwarded to.
    pub(crate) inner: S,
    /// Produces the classifier used to decide whether a response or error is a failure;
    /// invoked once per request.
    pub(crate) make_classifier: M,
    /// Creates the span for each incoming request.
    pub(crate) make_span: MakeSpan,
    /// Hook called with the request and its span before the inner service is invoked.
    pub(crate) on_request: OnRequest,
    /// Hook called with the response, the measured latency and the span once the inner
    /// service has produced a response.
    pub(crate) on_response: OnResponse,
    /// Hook cloned into the returned response body wrapper.
    pub(crate) on_body_chunk: OnBodyChunk,
    /// Hook cloned into the returned response body wrapper when classification has to wait
    /// for the end of the stream.
    pub(crate) on_eos: OnEos,
    /// Hook called when a response or error is classified as a failure; also cloned into
    /// the returned response body wrapper.
    pub(crate) on_failure: OnFailure,
}
41
42impl<
43        S: fmt::Debug,
44        M: fmt::Debug,
45        MakeSpan: fmt::Debug,
46        OnRequest: fmt::Debug,
47        OnResponse: fmt::Debug,
48        OnBodyChunk: fmt::Debug,
49        OnEos: fmt::Debug,
50        OnFailure: fmt::Debug,
51    > fmt::Debug for Trace<S, M, MakeSpan, OnRequest, OnResponse, OnBodyChunk, OnEos, OnFailure>
52{
53    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
54        f.debug_struct("TraceLayer")
55            .field("inner", &self.inner)
56            .field("make_classifier", &self.make_classifier)
57            .field("make_span", &self.make_span)
58            .field("on_request", &self.on_request)
59            .field("on_response", &self.on_response)
60            .field("on_body_chunk", &self.on_body_chunk)
61            .field("on_eos", &self.on_eos)
62            .field("on_failure", &self.on_failure)
63            .finish()
64    }
65}
66
67impl<
68        S: Clone,
69        M: Clone,
70        MakeSpan: Clone,
71        OnRequest: Clone,
72        OnResponse: Clone,
73        OnBodyChunk: Clone,
74        OnEos: Clone,
75        OnFailure: Clone,
76    > Clone for Trace<S, M, MakeSpan, OnRequest, OnResponse, OnBodyChunk, OnEos, OnFailure>
77{
78    fn clone(&self) -> Self {
79        Self {
80            inner: self.inner.clone(),
81            make_classifier: self.make_classifier.clone(),
82            make_span: self.make_span.clone(),
83            on_request: self.on_request.clone(),
84            on_response: self.on_response.clone(),
85            on_body_chunk: self.on_body_chunk.clone(),
86            on_eos: self.on_eos.clone(),
87            on_failure: self.on_failure.clone(),
88        }
89    }
90}
91
92impl<S, M> Trace<S, M> {
93    /// Create a new [`Trace`] using the given [`MakeClassifier`].
94    pub fn new(inner: S, make_classifier: M) -> Self
95    where
96        M: MakeClassifier,
97    {
98        Self {
99            inner,
100            make_classifier,
101            make_span: DefaultMakeSpan::new(),
102            on_request: DefaultOnRequest::default(),
103            on_response: DefaultOnResponse::default(),
104            on_body_chunk: DefaultOnBodyChunk::default(),
105            on_eos: DefaultOnEos::default(),
106            on_failure: DefaultOnFailure::default(),
107        }
108    }
109}
110
111impl<S, M, MakeSpan, OnRequest, OnResponse, OnBodyChunk, OnEos, OnFailure>
112    Trace<S, M, MakeSpan, OnRequest, OnResponse, OnBodyChunk, OnEos, OnFailure>
113{
114    define_inner_service_accessors!();
115
116    /// Customize what to do when a request is received.
117    ///
118    /// `NewOnRequest` is expected to implement [`OnRequest`].
119    ///
120    /// [`OnRequest`]: super::OnRequest
121    pub fn on_request<NewOnRequest>(
122        self,
123        new_on_request: NewOnRequest,
124    ) -> Trace<S, M, MakeSpan, NewOnRequest, OnResponse, OnBodyChunk, OnEos, OnFailure> {
125        Trace {
126            on_request: new_on_request,
127            inner: self.inner,
128            on_failure: self.on_failure,
129            on_eos: self.on_eos,
130            on_body_chunk: self.on_body_chunk,
131            make_span: self.make_span,
132            on_response: self.on_response,
133            make_classifier: self.make_classifier,
134        }
135    }
136
137    /// Customize what to do when a response has been produced.
138    ///
139    /// `NewOnResponse` is expected to implement [`OnResponse`].
140    ///
141    /// [`OnResponse`]: super::OnResponse
142    pub fn on_response<NewOnResponse>(
143        self,
144        new_on_response: NewOnResponse,
145    ) -> Trace<S, M, MakeSpan, OnRequest, NewOnResponse, OnBodyChunk, OnEos, OnFailure> {
146        Trace {
147            on_response: new_on_response,
148            inner: self.inner,
149            on_request: self.on_request,
150            on_failure: self.on_failure,
151            on_body_chunk: self.on_body_chunk,
152            on_eos: self.on_eos,
153            make_span: self.make_span,
154            make_classifier: self.make_classifier,
155        }
156    }
157
158    /// Customize what to do when a body chunk has been sent.
159    ///
160    /// `NewOnBodyChunk` is expected to implement [`OnBodyChunk`].
161    ///
162    /// [`OnBodyChunk`]: super::OnBodyChunk
163    pub fn on_body_chunk<NewOnBodyChunk>(
164        self,
165        new_on_body_chunk: NewOnBodyChunk,
166    ) -> Trace<S, M, MakeSpan, OnRequest, OnResponse, NewOnBodyChunk, OnEos, OnFailure> {
167        Trace {
168            on_body_chunk: new_on_body_chunk,
169            on_eos: self.on_eos,
170            make_span: self.make_span,
171            inner: self.inner,
172            on_failure: self.on_failure,
173            on_request: self.on_request,
174            on_response: self.on_response,
175            make_classifier: self.make_classifier,
176        }
177    }
178
179    /// Customize what to do when a streaming response has closed.
180    ///
181    /// `NewOnEos` is expected to implement [`OnEos`].
182    ///
183    /// [`OnEos`]: super::OnEos
184    pub fn on_eos<NewOnEos>(
185        self,
186        new_on_eos: NewOnEos,
187    ) -> Trace<S, M, MakeSpan, OnRequest, OnResponse, OnBodyChunk, NewOnEos, OnFailure> {
188        Trace {
189            on_eos: new_on_eos,
190            make_span: self.make_span,
191            inner: self.inner,
192            on_failure: self.on_failure,
193            on_request: self.on_request,
194            on_body_chunk: self.on_body_chunk,
195            on_response: self.on_response,
196            make_classifier: self.make_classifier,
197        }
198    }
199
200    /// Customize what to do when a response has been classified as a failure.
201    ///
202    /// `NewOnFailure` is expected to implement [`OnFailure`].
203    ///
204    /// [`OnFailure`]: super::OnFailure
205    pub fn on_failure<NewOnFailure>(
206        self,
207        new_on_failure: NewOnFailure,
208    ) -> Trace<S, M, MakeSpan, OnRequest, OnResponse, OnBodyChunk, OnEos, NewOnFailure> {
209        Trace {
210            on_failure: new_on_failure,
211            inner: self.inner,
212            make_span: self.make_span,
213            on_body_chunk: self.on_body_chunk,
214            on_request: self.on_request,
215            on_eos: self.on_eos,
216            on_response: self.on_response,
217            make_classifier: self.make_classifier,
218        }
219    }
220
221    /// Customize how to make [`Span`]s that all request handling will be wrapped in.
222    ///
223    /// `NewMakeSpan` is expected to implement [`MakeSpan`].
224    ///
225    /// [`MakeSpan`]: super::MakeSpan
226    /// [`Span`]: tracing::Span
227    pub fn make_span_with<NewMakeSpan>(
228        self,
229        new_make_span: NewMakeSpan,
230    ) -> Trace<S, M, NewMakeSpan, OnRequest, OnResponse, OnBodyChunk, OnEos, OnFailure> {
231        Trace {
232            make_span: new_make_span,
233            inner: self.inner,
234            on_failure: self.on_failure,
235            on_request: self.on_request,
236            on_body_chunk: self.on_body_chunk,
237            on_response: self.on_response,
238            on_eos: self.on_eos,
239            make_classifier: self.make_classifier,
240        }
241    }
242}
243
244impl<S>
245    Trace<
246        S,
247        HttpMakeClassifier,
248        DefaultMakeSpan,
249        DefaultOnRequest,
250        DefaultOnResponse,
251        DefaultOnBodyChunk,
252        DefaultOnEos,
253        DefaultOnFailure,
254    >
255{
256    /// Create a new [`Trace`] using [`ServerErrorsAsFailures`] which supports classifying
257    /// regular HTTP responses based on the status code.
258    pub fn new_for_http(inner: S) -> Self {
259        Self {
260            inner,
261            make_classifier: SharedClassifier::new(ServerErrorsAsFailures::default()),
262            make_span: DefaultMakeSpan::new(),
263            on_request: DefaultOnRequest::default(),
264            on_response: DefaultOnResponse::default(),
265            on_body_chunk: DefaultOnBodyChunk::default(),
266            on_eos: DefaultOnEos::default(),
267            on_failure: DefaultOnFailure::default(),
268        }
269    }
270}
271
272impl<S>
273    Trace<
274        S,
275        GrpcMakeClassifier,
276        DefaultMakeSpan,
277        DefaultOnRequest,
278        DefaultOnResponse,
279        DefaultOnBodyChunk,
280        DefaultOnEos,
281        DefaultOnFailure,
282    >
283{
284    /// Create a new [`Trace`] using [`GrpcErrorsAsFailures`] which supports classifying
285    /// gRPC responses and streams based on the `grpc-status` header.
286    pub fn new_for_grpc(inner: S) -> Self {
287        Self {
288            inner,
289            make_classifier: SharedClassifier::new(GrpcErrorsAsFailures::default()),
290            make_span: DefaultMakeSpan::new(),
291            on_request: DefaultOnRequest::default(),
292            on_response: DefaultOnResponse::default(),
293            on_body_chunk: DefaultOnBodyChunk::default(),
294            on_eos: DefaultOnEos::default(),
295            on_failure: DefaultOnFailure::default(),
296        }
297    }
298}
299
300impl<
301        S,
302        State,
303        ReqBody,
304        ResBody,
305        M,
306        OnRequestT,
307        OnResponseT,
308        OnFailureT,
309        OnBodyChunkT,
310        OnEosT,
311        MakeSpanT,
312    > Service<State, Request<ReqBody>>
313    for Trace<S, M, MakeSpanT, OnRequestT, OnResponseT, OnBodyChunkT, OnEosT, OnFailureT>
314where
315    S: Service<State, Request<ReqBody>, Response = Response<ResBody>, Error: fmt::Display>,
316    State: Clone + Send + Sync + 'static,
317    ReqBody: HttpBody + Send + 'static,
318    ResBody: HttpBody<Error: fmt::Display> + Send + Sync + 'static,
319    M: MakeClassifier<Classifier: Clone>,
320    MakeSpanT: MakeSpan<ReqBody>,
321    OnRequestT: OnRequest<ReqBody>,
322    OnResponseT: OnResponse<ResBody> + Clone,
323    OnBodyChunkT: OnBodyChunk<ResBody::Data> + Clone,
324    OnEosT: OnEos + Clone,
325    OnFailureT: OnFailure<M::FailureClass> + Clone,
326{
327    type Response =
328        Response<ResponseBody<ResBody, M::ClassifyEos, OnBodyChunkT, OnEosT, OnFailureT>>;
329    type Error = S::Error;
330
331    async fn serve(
332        &self,
333        ctx: Context<State>,
334        req: Request<ReqBody>,
335    ) -> Result<Self::Response, Self::Error> {
336        let start = Instant::now();
337
338        let span = self.make_span.make_span(&req);
339
340        let classifier = self.make_classifier.make_classifier(&req);
341
342        let result = {
343            let _guard = span.enter();
344            self.on_request.on_request(&req, &span);
345            self.inner.serve(ctx, req)
346        }
347        .await;
348        let latency = start.elapsed();
349
350        match result {
351            Ok(res) => {
352                let classification = classifier.classify_response(&res);
353
354                self.on_response.clone().on_response(&res, latency, &span);
355
356                match classification {
357                    ClassifiedResponse::Ready(classification) => {
358                        if let Err(failure_class) = classification {
359                            self.on_failure.on_failure(failure_class, latency, &span);
360                        }
361
362                        let span = span.clone();
363                        let res = res.map(|body| ResponseBody {
364                            inner: body,
365                            classify_eos: None,
366                            on_eos: None,
367                            on_body_chunk: self.on_body_chunk.clone(),
368                            on_failure: Some(self.on_failure.clone()),
369                            start,
370                            span,
371                        });
372
373                        Ok(res)
374                    }
375                    ClassifiedResponse::RequiresEos(classify_eos) => {
376                        let span = span.clone();
377                        let res = res.map(|body| ResponseBody {
378                            inner: body,
379                            classify_eos: Some(classify_eos),
380                            on_eos: Some((self.on_eos.clone(), Instant::now())),
381                            on_body_chunk: self.on_body_chunk.clone(),
382                            on_failure: Some(self.on_failure.clone()),
383                            start,
384                            span,
385                        });
386
387                        Ok(res)
388                    }
389                }
390            }
391            Err(err) => {
392                let failure_class: <M as MakeClassifier>::FailureClass =
393                    classifier.classify_error(&err);
394                self.on_failure.on_failure(failure_class, latency, &span);
395
396                Err(err)
397            }
398        }
399    }
400}