Skip to main content

rs_zero/rpc/
interceptor.rs

1use tonic::{Request, Status, service::Interceptor};
2use uuid::Uuid;
3
/// Metadata key under which rs-zero RPC clients propagate the request id.
pub const REQUEST_ID_METADATA: &str = "x-request-id";

/// Request extension carrying a request id for an outgoing RPC.
///
/// When a request has no [`REQUEST_ID_METADATA`] entry yet,
/// [`request_id_interceptor`] copies this id into tonic metadata. It is checked
/// before every other request id source.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RpcRequestId(pub String);
11
12/// Runs a future with an RPC request id propagation context.
13pub async fn with_rpc_request_id<T>(
14    request_id: impl Into<String>,
15    future: impl std::future::Future<Output = T>,
16) -> T {
17    let request_id = request_id.into();
18    crate::rpc::RPC_REQUEST_ID_SCOPE
19        .scope(std::sync::Arc::new(request_id), future)
20        .await
21}
22
23/// Builds an interceptor that adds a request id when missing.
24pub fn request_id_interceptor() -> impl Interceptor {
25    |mut request: Request<()>| -> Result<Request<()>, Status> {
26        if !request.metadata().contains_key(REQUEST_ID_METADATA) {
27            let request_id = request
28                .extensions()
29                .get::<RpcRequestId>()
30                .map(|value| value.0.clone())
31                .or_else(|| {
32                    #[cfg(feature = "observability")]
33                    {
34                        request
35                            .extensions()
36                            .get::<crate::observability::CurrentRequestId>()
37                            .map(|value| value.0.clone())
38                    }
39                    #[cfg(not(feature = "observability"))]
40                    {
41                        None
42                    }
43                })
44                .or_else(crate::layer::context::current_request_id)
45                .unwrap_or_else(|| Uuid::new_v4().to_string());
46            let value = request_id
47                .parse()
48                .map_err(|_| Status::internal("invalid request id metadata"))?;
49            request.metadata_mut().insert(REQUEST_ID_METADATA, value);
50        }
51
52        Ok(request)
53    }
54}
55
/// Builds an interceptor that injects the current traceparent when available.
///
/// Requests that already carry `crate::observability::TRACEPARENT_HEADER`
/// metadata are passed through untouched. Otherwise:
///
/// - with the `otlp` feature, the current context is written by
///   `crate::observability::inject_current_context_metadata`;
/// - without it, `crate::observability::current_traceparent()` is consulted
///   and, when it yields a value, that value is inserted via
///   `crate::observability::insert_traceparent_metadata`.
///
/// # Errors
///
/// Fails the call with `Status::internal("invalid traceparent metadata")` if
/// writing the trace context into the metadata reports an error.
#[cfg(feature = "observability")]
pub fn trace_context_interceptor() -> impl Interceptor {
    |mut request: Request<()>| -> Result<Request<()>, Status> {
        let has_traceparent = request
            .metadata()
            .contains_key(crate::observability::TRACEPARENT_HEADER);
        if has_traceparent {
            return Ok(request);
        }

        #[cfg(feature = "otlp")]
        {
            crate::observability::inject_current_context_metadata(request.metadata_mut())
                .map_err(|_| Status::internal("invalid traceparent metadata"))?;
        }

        #[cfg(not(feature = "otlp"))]
        {
            if let Some(current) = crate::observability::current_traceparent() {
                crate::observability::insert_traceparent_metadata(
                    request.metadata_mut(),
                    &current,
                )
                .map_err(|_| Status::internal("invalid traceparent metadata"))?;
            }
        }

        Ok(request)
    }
}
83
84/// Builds an interceptor that adds `grpc-timeout` metadata when missing.
85pub fn deadline_interceptor(timeout: std::time::Duration) -> impl Interceptor {
86    move |mut request: Request<()>| -> Result<Request<()>, Status> {
87        if !request.metadata().contains_key("grpc-timeout") {
88            crate::rpc::deadline::insert_grpc_timeout(&mut request, timeout)
89                .map_err(|_| Status::internal("invalid grpc-timeout metadata"))?;
90        }
91        Ok(request)
92    }
93}
94
/// Builds a stable resilience key for RPC adapters.
///
/// The key is `service` and `method` joined by a single `:`, e.g.
/// `("hello", "Say")` becomes `"hello:Say"`.
pub fn rpc_resilience_key(service: &str, method: &str) -> String {
    [service, method].join(":")
}
99
100/// Maps a resilience rejection into a tonic unavailable status.
101pub fn resilience_rejection_status(reason: impl std::fmt::Display) -> Status {
102    Status::unavailable(reason.to_string())
103}
104
#[cfg(test)]
mod tests {
    use super::{
        REQUEST_ID_METADATA, RpcRequestId, deadline_interceptor, request_id_interceptor,
        resilience_rejection_status, rpc_resilience_key, with_rpc_request_id,
    };
    #[cfg(feature = "observability")]
    use crate::observability::CurrentRequestId;
    use tonic::{Request, service::Interceptor};

    #[test]
    fn interceptor_sets_request_id() {
        let mut interceptor = request_id_interceptor();
        let request = interceptor.call(Request::new(())).expect("request");

        assert!(request.metadata().contains_key(REQUEST_ID_METADATA));
    }

    // The interceptor must never overwrite an id that is already in metadata,
    // even when other sources (here: an extension) are available.
    #[test]
    fn interceptor_keeps_existing_request_id() {
        let mut interceptor = request_id_interceptor();
        let mut request = Request::new(());
        request.metadata_mut().insert(
            REQUEST_ID_METADATA,
            "req-existing-1".parse().expect("valid metadata value"),
        );
        request
            .extensions_mut()
            .insert(RpcRequestId("req-extension-2".to_string()));

        let request = interceptor.call(request).expect("request");

        assert_eq!(
            request
                .metadata()
                .get(REQUEST_ID_METADATA)
                .expect("request id"),
            "req-existing-1"
        );
    }

    // A request id that cannot be encoded as metadata fails the call rather
    // than being silently dropped.
    #[test]
    fn interceptor_rejects_invalid_request_id() {
        let mut interceptor = request_id_interceptor();
        let mut request = Request::new(());
        request
            .extensions_mut()
            .insert(RpcRequestId("req\ninvalid".to_string()));

        let status = interceptor
            .call(request)
            .err()
            .expect("invalid request id must be rejected");

        assert_eq!(status.code(), tonic::Code::Internal);
    }

    #[cfg(feature = "observability")]
    #[test]
    fn interceptor_uses_observability_current_request_id() {
        let mut interceptor = request_id_interceptor();
        let mut request = Request::new(());
        request
            .extensions_mut()
            .insert(CurrentRequestId("req-current-1".to_string()));

        let request = interceptor.call(request).expect("request");

        assert_eq!(
            request
                .metadata()
                .get(REQUEST_ID_METADATA)
                .expect("request id"),
            "req-current-1"
        );
    }

    // `RpcRequestId` is consulted before `CurrentRequestId`.
    #[cfg(feature = "observability")]
    #[test]
    fn interceptor_prefers_rpc_request_id_over_current_request_id() {
        let mut interceptor = request_id_interceptor();
        let mut request = Request::new(());
        request
            .extensions_mut()
            .insert(CurrentRequestId("req-current-2".to_string()));
        request
            .extensions_mut()
            .insert(RpcRequestId("req-extension-3".to_string()));

        let request = interceptor.call(request).expect("request");

        assert_eq!(
            request
                .metadata()
                .get(REQUEST_ID_METADATA)
                .expect("request id"),
            "req-extension-3"
        );
    }

    #[tokio::test]
    async fn interceptor_uses_scoped_request_id() {
        let mut interceptor = request_id_interceptor();
        let request = with_rpc_request_id("req-scoped-1", async {
            interceptor.call(Request::new(())).expect("request")
        })
        .await;

        assert_eq!(
            request
                .metadata()
                .get(REQUEST_ID_METADATA)
                .expect("request id"),
            "req-scoped-1"
        );
    }

    #[test]
    fn interceptor_uses_request_extension_id() {
        let mut interceptor = request_id_interceptor();
        let mut request = Request::new(());
        request
            .extensions_mut()
            .insert(RpcRequestId("req-extension-1".to_string()));

        let request = interceptor.call(request).expect("request");

        assert_eq!(
            request
                .metadata()
                .get(REQUEST_ID_METADATA)
                .expect("request id"),
            "req-extension-1"
        );
    }

    #[test]
    fn interceptor_sets_grpc_timeout() {
        let mut interceptor = deadline_interceptor(std::time::Duration::from_millis(30));
        let request = interceptor.call(Request::new(())).expect("request");

        assert!(request.metadata().contains_key("grpc-timeout"));
    }

    // A deadline already attached to the request wins over the configured one.
    #[test]
    fn interceptor_keeps_existing_grpc_timeout() {
        let mut interceptor = deadline_interceptor(std::time::Duration::from_millis(30));
        let mut request = Request::new(());
        request
            .metadata_mut()
            .insert("grpc-timeout", "5S".parse().expect("valid metadata value"));

        let request = interceptor.call(request).expect("request");

        assert_eq!(
            request
                .metadata()
                .get("grpc-timeout")
                .expect("grpc-timeout"),
            "5S"
        );
    }

    #[test]
    fn rpc_resilience_helpers_are_stable() {
        assert_eq!(rpc_resilience_key("hello", "Say"), "hello:Say");
        assert_eq!(
            resilience_rejection_status("open").code(),
            tonic::Code::Unavailable
        );
    }
}