Skip to main content

rs_zero/layer/
context.rs

1#[cfg(feature = "rpc")]
2use tonic::metadata::MetadataMap;
3
4#[cfg(all(feature = "observability", feature = "rpc"))]
5use crate::observability::insert_traceparent_metadata;
6#[cfg(feature = "observability")]
7use crate::observability::{
8    CorrelationContext, TRACEPARENT_HEADER, insert_traceparent_header, span_id_from_traceparent,
9    trace_id_from_traceparent,
10};
11#[cfg(feature = "rpc")]
12use crate::rpc::{REQUEST_ID_METADATA, RpcRequestId};
13
/// Low-cardinality request context shared by REST/RPC Tower layers.
///
/// Bundles the service name, transport label, and route/method pattern with
/// the optional correlation identifiers (request id and W3C trace context)
/// that the layers propagate.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RequestContext {
    /// Service name.
    service: String,
    /// Transport label supplied by the caller.
    transport: &'static str,
    /// Route pattern or RPC method pattern.
    route: String,
    /// HTTP or RPC method.
    method: String,
    /// Request id, when one is known.
    request_id: Option<String>,
    /// Raw W3C `traceparent` value, when one is known.
    traceparent: Option<String>,
    /// Trace id derived from the traceparent (or copied from a correlation context).
    trace_id: Option<String>,
    /// Span id derived from the traceparent (or copied from a correlation context).
    span_id: Option<String>,
}
26
27impl RequestContext {
28    /// Creates a context from explicit low-cardinality parts.
29    pub fn new(
30        service: impl Into<String>,
31        transport: &'static str,
32        route: impl Into<String>,
33        method: impl Into<String>,
34    ) -> Self {
35        Self {
36            service: service.into(),
37            transport,
38            route: route.into(),
39            method: method.into(),
40            request_id: None,
41            traceparent: None,
42            trace_id: None,
43            span_id: None,
44        }
45    }
46
47    /// Builds an HTTP request context from headers and a route pattern.
48    #[cfg(feature = "observability")]
49    pub fn from_http_headers(
50        service: Option<&str>,
51        method: impl Into<String>,
52        route: Option<&str>,
53        headers: &http::HeaderMap,
54    ) -> Self {
55        let correlation = CorrelationContext::from_http_headers(service, method, route, headers);
56        Self::from_correlation(correlation)
57    }
58
59    /// Builds a gRPC request context from tonic metadata and a method pattern.
60    #[cfg(all(feature = "observability", feature = "rpc"))]
61    pub fn from_tonic_metadata(
62        service: impl Into<String>,
63        method: impl Into<String>,
64        metadata: &MetadataMap,
65    ) -> Self {
66        let correlation = CorrelationContext::from_rpc_metadata(service, method, metadata);
67        Self::from_correlation(correlation)
68    }
69
70    /// Sets a request id.
71    pub fn with_request_id(mut self, request_id: impl Into<String>) -> Self {
72        let request_id = request_id.into();
73        if !request_id.trim().is_empty() {
74            self.request_id = Some(request_id);
75        }
76        self
77    }
78
79    /// Sets a W3C traceparent and derives trace/span ids when valid.
80    pub fn with_traceparent(mut self, traceparent: impl Into<String>) -> Self {
81        let traceparent = traceparent.into();
82        #[cfg(feature = "observability")]
83        {
84            self.trace_id = trace_id_from_traceparent(&traceparent).map(ToOwned::to_owned);
85            self.span_id = span_id_from_traceparent(&traceparent).map(ToOwned::to_owned);
86        }
87        self.traceparent = Some(traceparent);
88        self
89    }
90
91    /// Returns the service name.
92    pub fn service(&self) -> &str {
93        &self.service
94    }
95
96    /// Returns the transport name.
97    pub fn transport(&self) -> &'static str {
98        self.transport
99    }
100
101    /// Returns the route or RPC method pattern.
102    pub fn route(&self) -> &str {
103        &self.route
104    }
105
106    /// Returns the HTTP or RPC method.
107    pub fn method(&self) -> &str {
108        &self.method
109    }
110
111    /// Returns the request id.
112    pub fn request_id(&self) -> Option<&str> {
113        self.request_id.as_deref()
114    }
115
116    /// Returns the traceparent.
117    pub fn traceparent(&self) -> Option<&str> {
118        self.traceparent.as_deref()
119    }
120
121    /// Returns the trace id.
122    pub fn trace_id(&self) -> Option<&str> {
123        self.trace_id.as_deref()
124    }
125
126    /// Returns the span id.
127    pub fn span_id(&self) -> Option<&str> {
128        self.span_id.as_deref()
129    }
130
131    /// Inserts context values into HTTP headers when they are missing.
132    #[cfg(feature = "observability")]
133    pub fn inject_http_headers(
134        &self,
135        headers: &mut http::HeaderMap,
136    ) -> Result<(), http::header::InvalidHeaderValue> {
137        if let Some(request_id) = self.request_id()
138            && !headers.contains_key(crate::observability::REQUEST_ID_HEADER)
139        {
140            headers.insert(
141                crate::observability::REQUEST_ID_HEADER,
142                http::HeaderValue::from_str(request_id)?,
143            );
144        }
145        if let Some(traceparent) = self.traceparent()
146            && !headers.contains_key(TRACEPARENT_HEADER)
147        {
148            insert_traceparent_header(headers, traceparent)?;
149        }
150        Ok(())
151    }
152
153    /// Inserts context values into tonic metadata when they are missing.
154    #[cfg(feature = "rpc")]
155    pub fn inject_tonic_metadata(
156        &self,
157        metadata: &mut MetadataMap,
158    ) -> Result<(), tonic::metadata::errors::InvalidMetadataValue> {
159        if let Some(request_id) = self.request_id()
160            && !metadata.contains_key(REQUEST_ID_METADATA)
161        {
162            metadata.insert(REQUEST_ID_METADATA, request_id.parse()?);
163        }
164        #[cfg(feature = "observability")]
165        if let Some(traceparent) = self.traceparent()
166            && !metadata.contains_key(TRACEPARENT_HEADER)
167        {
168            insert_traceparent_metadata(metadata, traceparent)?;
169        }
170        Ok(())
171    }
172
173    /// Inserts context values into request extensions for downstream layers.
174    #[cfg(feature = "rpc")]
175    pub fn insert_tonic_extensions<T>(&self, request: &mut tonic::Request<T>) {
176        if let Some(request_id) = self.request_id() {
177            request
178                .extensions_mut()
179                .insert(RpcRequestId(request_id.to_string()));
180        }
181        #[cfg(feature = "observability")]
182        if let Some(request_id) = self.request_id() {
183            request
184                .extensions_mut()
185                .insert(crate::observability::CurrentRequestId(
186                    request_id.to_string(),
187                ));
188        }
189    }
190
191    #[cfg(feature = "observability")]
192    fn from_correlation(correlation: CorrelationContext) -> Self {
193        let mut context = Self::new(
194            correlation.service().to_string(),
195            correlation.transport(),
196            correlation.route().to_string(),
197            correlation.method().to_string(),
198        );
199        if let Some(request_id) = correlation.request_id() {
200            context.request_id = Some(request_id.to_string());
201        }
202        if let Some(traceparent) = correlation.traceparent() {
203            context.traceparent = Some(traceparent.to_string());
204        }
205        if let Some(trace_id) = correlation.trace_id() {
206            context.trace_id = Some(trace_id.to_string());
207        }
208        if let Some(span_id) = correlation.span_id() {
209            context.span_id = Some(span_id.to_string());
210        }
211        context
212    }
213}
214
/// Returns the request id from the current task-local scope, if one is set.
///
/// With the `rpc` feature this reads `crate::rpc::RPC_REQUEST_ID_SCOPE` and
/// yields `None` when no id is available. Without that feature there is no
/// scope to read, so the result is always `None`.
pub fn current_request_id() -> Option<String> {
    #[cfg(feature = "rpc")]
    fn read_scope() -> Option<String> {
        let lookup = crate::rpc::RPC_REQUEST_ID_SCOPE.try_with(|id| id.to_string());
        lookup.ok()
    }

    #[cfg(not(feature = "rpc"))]
    fn read_scope() -> Option<String> {
        None
    }

    read_scope()
}
229
/// Runs `future` with `request_id` made available to outgoing RPC layers.
///
/// With the `rpc` feature the future is awaited through
/// `crate::rpc::with_rpc_request_id`. Without that feature nothing can
/// observe the id, so it is converted and discarded immediately and the
/// future is awaited unchanged. In both cases the future's output is returned.
pub async fn scope_request_id<T>(
    request_id: impl Into<String>,
    future: impl std::future::Future<Output = T>,
) -> T {
    #[cfg(not(feature = "rpc"))]
    {
        // No task-local to populate: convert and drop the id right away.
        let _: String = request_id.into();
        future.await
    }

    #[cfg(feature = "rpc")]
    {
        let scoped = crate::rpc::with_rpc_request_id(request_id, future);
        scoped.await
    }
}
245}