use super::{
    DefaultMakeSpan, DefaultOnBodyChunk, DefaultOnEos, DefaultOnFailure, DefaultOnRequest,
    DefaultOnResponse, GrpcMakeClassifier, HttpMakeClassifier, MakeSpan, OnBodyChunk, OnEos,
    OnFailure, OnRequest, OnResponse, ResponseBody,
};
use crate::dep::http_body::Body as HttpBody;
use crate::layer::classify::{
    ClassifiedResponse, ClassifyResponse, GrpcErrorsAsFailures, MakeClassifier,
    ServerErrorsAsFailures, SharedClassifier,
};
use crate::{Request, Response};
use rama_core::{Context, Service};
use rama_utils::macros::define_inner_service_accessors;
use std::{fmt, time::Instant};
use tracing::Instrument;
15
16pub struct Trace<
23 S,
24 M,
25 MakeSpan = DefaultMakeSpan,
26 OnRequest = DefaultOnRequest,
27 OnResponse = DefaultOnResponse,
28 OnBodyChunk = DefaultOnBodyChunk,
29 OnEos = DefaultOnEos,
30 OnFailure = DefaultOnFailure,
31> {
32 pub(crate) inner: S,
33 pub(crate) make_classifier: M,
34 pub(crate) make_span: MakeSpan,
35 pub(crate) on_request: OnRequest,
36 pub(crate) on_response: OnResponse,
37 pub(crate) on_body_chunk: OnBodyChunk,
38 pub(crate) on_eos: OnEos,
39 pub(crate) on_failure: OnFailure,
40}
41
42impl<
43 S: fmt::Debug,
44 M: fmt::Debug,
45 MakeSpan: fmt::Debug,
46 OnRequest: fmt::Debug,
47 OnResponse: fmt::Debug,
48 OnBodyChunk: fmt::Debug,
49 OnEos: fmt::Debug,
50 OnFailure: fmt::Debug,
51 > fmt::Debug for Trace<S, M, MakeSpan, OnRequest, OnResponse, OnBodyChunk, OnEos, OnFailure>
52{
53 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
54 f.debug_struct("TraceLayer")
55 .field("inner", &self.inner)
56 .field("make_classifier", &self.make_classifier)
57 .field("make_span", &self.make_span)
58 .field("on_request", &self.on_request)
59 .field("on_response", &self.on_response)
60 .field("on_body_chunk", &self.on_body_chunk)
61 .field("on_eos", &self.on_eos)
62 .field("on_failure", &self.on_failure)
63 .finish()
64 }
65}
66
67impl<
68 S: Clone,
69 M: Clone,
70 MakeSpan: Clone,
71 OnRequest: Clone,
72 OnResponse: Clone,
73 OnBodyChunk: Clone,
74 OnEos: Clone,
75 OnFailure: Clone,
76 > Clone for Trace<S, M, MakeSpan, OnRequest, OnResponse, OnBodyChunk, OnEos, OnFailure>
77{
78 fn clone(&self) -> Self {
79 Self {
80 inner: self.inner.clone(),
81 make_classifier: self.make_classifier.clone(),
82 make_span: self.make_span.clone(),
83 on_request: self.on_request.clone(),
84 on_response: self.on_response.clone(),
85 on_body_chunk: self.on_body_chunk.clone(),
86 on_eos: self.on_eos.clone(),
87 on_failure: self.on_failure.clone(),
88 }
89 }
90}
91
92impl<S, M> Trace<S, M> {
93 pub fn new(inner: S, make_classifier: M) -> Self
95 where
96 M: MakeClassifier,
97 {
98 Self {
99 inner,
100 make_classifier,
101 make_span: DefaultMakeSpan::new(),
102 on_request: DefaultOnRequest::default(),
103 on_response: DefaultOnResponse::default(),
104 on_body_chunk: DefaultOnBodyChunk::default(),
105 on_eos: DefaultOnEos::default(),
106 on_failure: DefaultOnFailure::default(),
107 }
108 }
109}
110
111impl<S, M, MakeSpan, OnRequest, OnResponse, OnBodyChunk, OnEos, OnFailure>
112 Trace<S, M, MakeSpan, OnRequest, OnResponse, OnBodyChunk, OnEos, OnFailure>
113{
114 define_inner_service_accessors!();
115
116 pub fn on_request<NewOnRequest>(
122 self,
123 new_on_request: NewOnRequest,
124 ) -> Trace<S, M, MakeSpan, NewOnRequest, OnResponse, OnBodyChunk, OnEos, OnFailure> {
125 Trace {
126 on_request: new_on_request,
127 inner: self.inner,
128 on_failure: self.on_failure,
129 on_eos: self.on_eos,
130 on_body_chunk: self.on_body_chunk,
131 make_span: self.make_span,
132 on_response: self.on_response,
133 make_classifier: self.make_classifier,
134 }
135 }
136
137 pub fn on_response<NewOnResponse>(
143 self,
144 new_on_response: NewOnResponse,
145 ) -> Trace<S, M, MakeSpan, OnRequest, NewOnResponse, OnBodyChunk, OnEos, OnFailure> {
146 Trace {
147 on_response: new_on_response,
148 inner: self.inner,
149 on_request: self.on_request,
150 on_failure: self.on_failure,
151 on_body_chunk: self.on_body_chunk,
152 on_eos: self.on_eos,
153 make_span: self.make_span,
154 make_classifier: self.make_classifier,
155 }
156 }
157
158 pub fn on_body_chunk<NewOnBodyChunk>(
164 self,
165 new_on_body_chunk: NewOnBodyChunk,
166 ) -> Trace<S, M, MakeSpan, OnRequest, OnResponse, NewOnBodyChunk, OnEos, OnFailure> {
167 Trace {
168 on_body_chunk: new_on_body_chunk,
169 on_eos: self.on_eos,
170 make_span: self.make_span,
171 inner: self.inner,
172 on_failure: self.on_failure,
173 on_request: self.on_request,
174 on_response: self.on_response,
175 make_classifier: self.make_classifier,
176 }
177 }
178
179 pub fn on_eos<NewOnEos>(
185 self,
186 new_on_eos: NewOnEos,
187 ) -> Trace<S, M, MakeSpan, OnRequest, OnResponse, OnBodyChunk, NewOnEos, OnFailure> {
188 Trace {
189 on_eos: new_on_eos,
190 make_span: self.make_span,
191 inner: self.inner,
192 on_failure: self.on_failure,
193 on_request: self.on_request,
194 on_body_chunk: self.on_body_chunk,
195 on_response: self.on_response,
196 make_classifier: self.make_classifier,
197 }
198 }
199
200 pub fn on_failure<NewOnFailure>(
206 self,
207 new_on_failure: NewOnFailure,
208 ) -> Trace<S, M, MakeSpan, OnRequest, OnResponse, OnBodyChunk, OnEos, NewOnFailure> {
209 Trace {
210 on_failure: new_on_failure,
211 inner: self.inner,
212 make_span: self.make_span,
213 on_body_chunk: self.on_body_chunk,
214 on_request: self.on_request,
215 on_eos: self.on_eos,
216 on_response: self.on_response,
217 make_classifier: self.make_classifier,
218 }
219 }
220
221 pub fn make_span_with<NewMakeSpan>(
228 self,
229 new_make_span: NewMakeSpan,
230 ) -> Trace<S, M, NewMakeSpan, OnRequest, OnResponse, OnBodyChunk, OnEos, OnFailure> {
231 Trace {
232 make_span: new_make_span,
233 inner: self.inner,
234 on_failure: self.on_failure,
235 on_request: self.on_request,
236 on_body_chunk: self.on_body_chunk,
237 on_response: self.on_response,
238 on_eos: self.on_eos,
239 make_classifier: self.make_classifier,
240 }
241 }
242}
243
244impl<S>
245 Trace<
246 S,
247 HttpMakeClassifier,
248 DefaultMakeSpan,
249 DefaultOnRequest,
250 DefaultOnResponse,
251 DefaultOnBodyChunk,
252 DefaultOnEos,
253 DefaultOnFailure,
254 >
255{
256 pub fn new_for_http(inner: S) -> Self {
259 Self {
260 inner,
261 make_classifier: SharedClassifier::new(ServerErrorsAsFailures::default()),
262 make_span: DefaultMakeSpan::new(),
263 on_request: DefaultOnRequest::default(),
264 on_response: DefaultOnResponse::default(),
265 on_body_chunk: DefaultOnBodyChunk::default(),
266 on_eos: DefaultOnEos::default(),
267 on_failure: DefaultOnFailure::default(),
268 }
269 }
270}
271
272impl<S>
273 Trace<
274 S,
275 GrpcMakeClassifier,
276 DefaultMakeSpan,
277 DefaultOnRequest,
278 DefaultOnResponse,
279 DefaultOnBodyChunk,
280 DefaultOnEos,
281 DefaultOnFailure,
282 >
283{
284 pub fn new_for_grpc(inner: S) -> Self {
287 Self {
288 inner,
289 make_classifier: SharedClassifier::new(GrpcErrorsAsFailures::default()),
290 make_span: DefaultMakeSpan::new(),
291 on_request: DefaultOnRequest::default(),
292 on_response: DefaultOnResponse::default(),
293 on_body_chunk: DefaultOnBodyChunk::default(),
294 on_eos: DefaultOnEos::default(),
295 on_failure: DefaultOnFailure::default(),
296 }
297 }
298}
299
300impl<
301 S,
302 State,
303 ReqBody,
304 ResBody,
305 M,
306 OnRequestT,
307 OnResponseT,
308 OnFailureT,
309 OnBodyChunkT,
310 OnEosT,
311 MakeSpanT,
312 > Service<State, Request<ReqBody>>
313 for Trace<S, M, MakeSpanT, OnRequestT, OnResponseT, OnBodyChunkT, OnEosT, OnFailureT>
314where
315 S: Service<State, Request<ReqBody>, Response = Response<ResBody>, Error: fmt::Display>,
316 State: Clone + Send + Sync + 'static,
317 ReqBody: HttpBody + Send + 'static,
318 ResBody: HttpBody<Error: fmt::Display> + Send + Sync + 'static,
319 M: MakeClassifier<Classifier: Clone>,
320 MakeSpanT: MakeSpan<ReqBody>,
321 OnRequestT: OnRequest<ReqBody>,
322 OnResponseT: OnResponse<ResBody> + Clone,
323 OnBodyChunkT: OnBodyChunk<ResBody::Data> + Clone,
324 OnEosT: OnEos + Clone,
325 OnFailureT: OnFailure<M::FailureClass> + Clone,
326{
327 type Response =
328 Response<ResponseBody<ResBody, M::ClassifyEos, OnBodyChunkT, OnEosT, OnFailureT>>;
329 type Error = S::Error;
330
331 async fn serve(
332 &self,
333 ctx: Context<State>,
334 req: Request<ReqBody>,
335 ) -> Result<Self::Response, Self::Error> {
336 let start = Instant::now();
337
338 let span = self.make_span.make_span(&req);
339
340 let classifier = self.make_classifier.make_classifier(&req);
341
342 let result = {
343 let _guard = span.enter();
344 self.on_request.on_request(&req, &span);
345 self.inner.serve(ctx, req)
346 }
347 .await;
348 let latency = start.elapsed();
349
350 match result {
351 Ok(res) => {
352 let classification = classifier.classify_response(&res);
353
354 self.on_response.clone().on_response(&res, latency, &span);
355
356 match classification {
357 ClassifiedResponse::Ready(classification) => {
358 if let Err(failure_class) = classification {
359 self.on_failure.on_failure(failure_class, latency, &span);
360 }
361
362 let span = span.clone();
363 let res = res.map(|body| ResponseBody {
364 inner: body,
365 classify_eos: None,
366 on_eos: None,
367 on_body_chunk: self.on_body_chunk.clone(),
368 on_failure: Some(self.on_failure.clone()),
369 start,
370 span,
371 });
372
373 Ok(res)
374 }
375 ClassifiedResponse::RequiresEos(classify_eos) => {
376 let span = span.clone();
377 let res = res.map(|body| ResponseBody {
378 inner: body,
379 classify_eos: Some(classify_eos),
380 on_eos: Some((self.on_eos.clone(), Instant::now())),
381 on_body_chunk: self.on_body_chunk.clone(),
382 on_failure: Some(self.on_failure.clone()),
383 start,
384 span,
385 });
386
387 Ok(res)
388 }
389 }
390 }
391 Err(err) => {
392 let failure_class: <M as MakeClassifier>::FailureClass =
393 classifier.classify_error(&err);
394 self.on_failure.on_failure(failure_class, latency, &span);
395
396 Err(err)
397 }
398 }
399 }
400}