phala_rocket_middleware/
request_tracer.rs1use std::fmt::{Display, Formatter};
2use std::sync::atomic::{AtomicU64, Ordering};
3
4use rocket::fairing::{Fairing, Info, Kind};
5use rocket::request::{FromRequest, Outcome};
6use rocket::{Data, Request, Response};
7
/// Rocket fairing that gives every request a sequential [`TraceId`].
///
/// IDs are drawn from an atomic counter, so a single tracer can be shared across
/// concurrently handled requests.
#[derive(Debug)]
pub struct RequestTracer {
    /// Next sequence number to hand out; advanced by `step` each time an ID is issued.
    sn: AtomicU64,
    /// Amount added to `sn` per issued ID.
    step: u64,
}
13
14impl Default for RequestTracer {
15 fn default() -> Self {
16 Self::new(1)
17 }
18}
19
20impl RequestTracer {
21 pub fn new(step: u64) -> Self {
23 Self {
24 sn: AtomicU64::new(0),
25 step,
26 }
27 }
28
29 fn next_id(&self) -> TraceId {
30 TraceId(self.sn.fetch_add(self.step, Ordering::Relaxed))
31 }
32}
33
/// Identifier attached to a single request by [`RequestTracer`].
///
/// `TraceId::default()` (value 0) is used as the fallback when no ID was cached for
/// a request. Note that 0 is also the first ID a freshly created tracer issues.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct TraceId(u64);
36
37impl Display for TraceId {
38 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
39 self.0.fmt(f)
40 }
41}
42
43impl TraceId {
44 pub fn id(&self) -> u64 {
45 self.0
46 }
47}
48
49#[rocket::async_trait]
50impl Fairing for RequestTracer {
51 fn info(&self) -> Info {
52 Info {
53 name: "Reqeust Tracer",
54 kind: Kind::Request | Kind::Response,
55 }
56 }
57
58 async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
59 let trace_id = self.next_id();
60 let _t = request.local_cache(|| trace_id);
61 }
62
63 async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
64 let trace = request.local_cache(TraceId::default);
65 response.set_raw_header("X-Request-Id", trace.id().to_string());
66 }
67}
68
69#[rocket::async_trait]
70impl<'r> FromRequest<'r> for TraceId {
71 type Error = ();
72
73 async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
74 let trace = request.local_cache(TraceId::default);
75 Outcome::Success(*trace)
76 }
77}