Skip to main content

rs_zero/rpc/
interceptor.rs

1use tonic::{Request, Status, service::Interceptor};
2use uuid::Uuid;
3
/// Name of the gRPC metadata entry that carries the request id.
///
/// [`request_id_interceptor`] checks for this key and inserts it when the
/// outgoing request does not already have one.
pub const REQUEST_ID_METADATA: &str = "x-request-id";
6
/// Request extension value holding a request id to forward on an outgoing RPC.
///
/// Insert it into a request's extensions before the call;
/// [`request_id_interceptor`] reads it from there and copies the wrapped string
/// into the `x-request-id` metadata entry when that entry is not already set.
///
/// Equality is by the wrapped string, so ids can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RpcRequestId(pub String);
11
12/// Runs a future with an RPC request id propagation context.
13pub async fn with_rpc_request_id<T>(
14    request_id: impl Into<String>,
15    future: impl std::future::Future<Output = T>,
16) -> T {
17    let request_id = request_id.into();
18    crate::rpc::RPC_REQUEST_ID_SCOPE
19        .scope(std::sync::Arc::new(request_id), future)
20        .await
21}
22
23/// Builds an interceptor that adds a request id when missing.
24pub fn request_id_interceptor() -> impl Interceptor {
25    |mut request: Request<()>| -> Result<Request<()>, Status> {
26        if !request.metadata().contains_key(REQUEST_ID_METADATA) {
27            let request_id = request
28                .extensions()
29                .get::<RpcRequestId>()
30                .map(|value| value.0.clone())
31                .or_else(|| {
32                    #[cfg(feature = "observability")]
33                    {
34                        request
35                            .extensions()
36                            .get::<crate::observability::CurrentRequestId>()
37                            .map(|value| value.0.clone())
38                    }
39                    #[cfg(not(feature = "observability"))]
40                    {
41                        None
42                    }
43                })
44                .or_else(|| {
45                    crate::rpc::RPC_REQUEST_ID_SCOPE
46                        .try_with(|value| value.to_string())
47                        .ok()
48                })
49                .unwrap_or_else(|| Uuid::new_v4().to_string());
50            let value = request_id
51                .parse()
52                .map_err(|_| Status::internal("invalid request id metadata"))?;
53            request.metadata_mut().insert(REQUEST_ID_METADATA, value);
54        }
55
56        Ok(request)
57    }
58}
59
/// Builds an interceptor that injects the current traceparent when available.
///
/// A request that already carries `crate::observability::TRACEPARENT_HEADER`
/// passes through untouched. With the `otlp` feature the current context is
/// injected through `inject_current_context_metadata`. Without it, the value
/// from `current_traceparent()` (if any) is written with
/// `insert_traceparent_metadata`.
///
/// # Errors
///
/// The interceptor returns `Status::internal` when either helper reports a
/// failure.
#[cfg(feature = "observability")]
pub fn trace_context_interceptor() -> impl Interceptor {
    |mut request: Request<()>| -> Result<Request<()>, Status> {
        let already_traced = request
            .metadata()
            .contains_key(crate::observability::TRACEPARENT_HEADER);
        if already_traced {
            return Ok(request);
        }

        #[cfg(feature = "otlp")]
        {
            let metadata = request.metadata_mut();
            crate::observability::inject_current_context_metadata(metadata)
                .map_err(|_| Status::internal("invalid traceparent metadata"))?;
        }

        #[cfg(not(feature = "otlp"))]
        {
            if let Some(traceparent) = crate::observability::current_traceparent() {
                let metadata = request.metadata_mut();
                crate::observability::insert_traceparent_metadata(metadata, &traceparent)
                    .map_err(|_| Status::internal("invalid traceparent metadata"))?;
            }
        }

        Ok(request)
    }
}
87
88/// Builds an interceptor that adds `grpc-timeout` metadata when missing.
89pub fn deadline_interceptor(timeout: std::time::Duration) -> impl Interceptor {
90    move |mut request: Request<()>| -> Result<Request<()>, Status> {
91        if !request.metadata().contains_key("grpc-timeout") {
92            crate::rpc::deadline::insert_grpc_timeout(&mut request, timeout)
93                .map_err(|_| Status::internal("invalid grpc-timeout metadata"))?;
94        }
95        Ok(request)
96    }
97}
98
/// Builds a stable resilience key for RPC adapters.
///
/// The key is the service name and the method name separated by a single `:`,
/// e.g. `hello:Say`.
pub fn rpc_resilience_key(service: &str, method: &str) -> String {
    [service, method].join(":")
}
103
104/// Maps a resilience rejection into a tonic unavailable status.
105pub fn resilience_rejection_status(reason: impl std::fmt::Display) -> Status {
106    Status::unavailable(reason.to_string())
107}
108
/// Unit tests for the interceptors and resilience helpers in this module.
#[cfg(test)]
mod tests {
    use super::{
        REQUEST_ID_METADATA, RpcRequestId, deadline_interceptor, request_id_interceptor,
        resilience_rejection_status, rpc_resilience_key, with_rpc_request_id,
    };
    #[cfg(feature = "observability")]
    use crate::observability::CurrentRequestId;
    use tonic::{Request, service::Interceptor};

    /// With no id source available, some request id is still attached.
    #[test]
    fn interceptor_sets_request_id() {
        let mut interceptor = request_id_interceptor();
        let request = interceptor.call(Request::new(())).expect("request");

        assert!(request.metadata().contains_key(REQUEST_ID_METADATA));
    }

    /// An id already present in the metadata must not be replaced, even when
    /// another source (here a request extension) could supply one.
    #[test]
    fn interceptor_preserves_existing_request_id() {
        let mut interceptor = request_id_interceptor();
        let mut request = Request::new(());
        request.metadata_mut().insert(
            REQUEST_ID_METADATA,
            "req-existing-1".parse().expect("metadata value"),
        );
        request
            .extensions_mut()
            .insert(RpcRequestId("req-extension-ignored".to_string()));

        let request = interceptor.call(request).expect("request");

        assert_eq!(
            request
                .metadata()
                .get(REQUEST_ID_METADATA)
                .expect("request id"),
            "req-existing-1"
        );
    }

    /// The observability extension is used when no `RpcRequestId` is present.
    #[cfg(feature = "observability")]
    #[test]
    fn interceptor_uses_observability_current_request_id() {
        let mut interceptor = request_id_interceptor();
        let mut request = Request::new(());
        request
            .extensions_mut()
            .insert(CurrentRequestId("req-current-1".to_string()));

        let request = interceptor.call(request).expect("request");

        assert_eq!(
            request
                .metadata()
                .get(REQUEST_ID_METADATA)
                .expect("request id"),
            "req-current-1"
        );
    }

    /// The id installed by `with_rpc_request_id` is picked up inside its scope.
    #[tokio::test]
    async fn interceptor_uses_scoped_request_id() {
        let mut interceptor = request_id_interceptor();
        let request = with_rpc_request_id("req-scoped-1", async {
            interceptor.call(Request::new(())).expect("request")
        })
        .await;

        assert_eq!(
            request
                .metadata()
                .get(REQUEST_ID_METADATA)
                .expect("request id"),
            "req-scoped-1"
        );
    }

    /// An `RpcRequestId` request extension is copied into the metadata.
    #[test]
    fn interceptor_uses_request_extension_id() {
        let mut interceptor = request_id_interceptor();
        let mut request = Request::new(());
        request
            .extensions_mut()
            .insert(RpcRequestId("req-extension-1".to_string()));

        let request = interceptor.call(request).expect("request");

        assert_eq!(
            request
                .metadata()
                .get(REQUEST_ID_METADATA)
                .expect("request id"),
            "req-extension-1"
        );
    }

    /// The request extension is consulted before the surrounding scope.
    #[tokio::test]
    async fn request_extension_id_takes_precedence_over_scoped_id() {
        let mut interceptor = request_id_interceptor();
        let mut request = Request::new(());
        request
            .extensions_mut()
            .insert(RpcRequestId("req-extension-2".to_string()));

        let request = with_rpc_request_id("req-scoped-2", async {
            interceptor.call(request).expect("request")
        })
        .await;

        assert_eq!(
            request
                .metadata()
                .get(REQUEST_ID_METADATA)
                .expect("request id"),
            "req-extension-2"
        );
    }

    /// A timeout is attached when the request has none.
    #[test]
    fn interceptor_sets_grpc_timeout() {
        let mut interceptor = deadline_interceptor(std::time::Duration::from_millis(30));
        let request = interceptor.call(Request::new(())).expect("request");

        assert!(request.metadata().contains_key("grpc-timeout"));
    }

    /// A caller-supplied `grpc-timeout` must not be replaced.
    #[test]
    fn interceptor_preserves_existing_grpc_timeout() {
        let mut interceptor = deadline_interceptor(std::time::Duration::from_millis(30));
        let mut request = Request::new(());
        request
            .metadata_mut()
            .insert("grpc-timeout", "5S".parse().expect("metadata value"));

        let request = interceptor.call(request).expect("request");

        assert_eq!(
            request
                .metadata()
                .get("grpc-timeout")
                .expect("grpc-timeout"),
            "5S"
        );
    }

    /// The key format and the rejection status code are fixed.
    #[test]
    fn rpc_resilience_helpers_are_stable() {
        assert_eq!(rpc_resilience_key("hello", "Say"), "hello:Say");
        assert_eq!(
            resilience_rejection_status("open").code(),
            tonic::Code::Unavailable
        );
    }
}