phala_rocket_middleware/
request_tracer.rs

1use std::fmt::{Display, Formatter};
2use std::sync::atomic::{AtomicU64, Ordering};
3
4use rocket::fairing::{Fairing, Info, Kind};
5use rocket::request::{FromRequest, Outcome};
6use rocket::{Data, Request, Response};
7
8/// Set a unique trace id for each request.
9pub struct RequestTracer {
10    sn: AtomicU64,
11    step: u64,
12}
13
14impl Default for RequestTracer {
15    fn default() -> Self {
16        Self::new(1)
17    }
18}
19
20impl RequestTracer {
21    /// Create a new RequestTracer with a given step.
22    pub fn new(step: u64) -> Self {
23        Self {
24            sn: AtomicU64::new(0),
25            step,
26        }
27    }
28
29    fn next_id(&self) -> TraceId {
30        TraceId(self.sn.fetch_add(self.step, Ordering::Relaxed))
31    }
32}
33
/// Per-request trace id assigned by [`RequestTracer`].
///
/// `Display` formats it as its plain decimal value. `TraceId::default()` has id `0`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct TraceId(u64);

impl Display for TraceId {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Delegate to `u64`'s impl so width, fill and alignment flags are honored.
        Display::fmt(&self.0, f)
    }
}

impl TraceId {
    /// Raw numeric value of this trace id.
    pub fn id(&self) -> u64 {
        self.0
    }
}
48
49#[rocket::async_trait]
50impl Fairing for RequestTracer {
51    fn info(&self) -> Info {
52        Info {
53            name: "Reqeust Tracer",
54            kind: Kind::Request | Kind::Response,
55        }
56    }
57
58    async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
59        let trace_id = self.next_id();
60        let _t = request.local_cache(|| trace_id);
61    }
62
63    async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
64        let trace = request.local_cache(TraceId::default);
65        response.set_raw_header("X-Request-Id", trace.id().to_string());
66    }
67}
68
69#[rocket::async_trait]
70impl<'r> FromRequest<'r> for TraceId {
71    type Error = ();
72
73    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
74        let trace = request.local_cache(TraceId::default);
75        Outcome::Success(*trace)
76    }
77}