rs_zero/rpc/
interceptor.rs1use tonic::{Request, Status, service::Interceptor};
2use uuid::Uuid;
3
/// Metadata key under which outgoing RPCs carry their request ID.
pub const REQUEST_ID_METADATA: &str = "x-request-id";
6
/// Request extension holding an explicit request ID for an outgoing RPC.
///
/// [`request_id_interceptor`] consults this extension first when the request
/// does not already carry `x-request-id` metadata.
#[derive(Clone, Debug)]
pub struct RpcRequestId(pub String);
11
12pub async fn with_rpc_request_id<T>(
14 request_id: impl Into<String>,
15 future: impl std::future::Future<Output = T>,
16) -> T {
17 let request_id = request_id.into();
18 crate::rpc::RPC_REQUEST_ID_SCOPE
19 .scope(std::sync::Arc::new(request_id), future)
20 .await
21}
22
23pub fn request_id_interceptor() -> impl Interceptor {
25 |mut request: Request<()>| -> Result<Request<()>, Status> {
26 if !request.metadata().contains_key(REQUEST_ID_METADATA) {
27 let request_id = request
28 .extensions()
29 .get::<RpcRequestId>()
30 .map(|value| value.0.clone())
31 .or_else(|| {
32 #[cfg(feature = "observability")]
33 {
34 request
35 .extensions()
36 .get::<crate::observability::CurrentRequestId>()
37 .map(|value| value.0.clone())
38 }
39 #[cfg(not(feature = "observability"))]
40 {
41 None
42 }
43 })
44 .or_else(|| {
45 crate::rpc::RPC_REQUEST_ID_SCOPE
46 .try_with(|value| value.to_string())
47 .ok()
48 })
49 .unwrap_or_else(|| Uuid::new_v4().to_string());
50 let value = request_id
51 .parse()
52 .map_err(|_| Status::internal("invalid request id metadata"))?;
53 request.metadata_mut().insert(REQUEST_ID_METADATA, value);
54 }
55
56 Ok(request)
57 }
58}
59
/// Returns an interceptor that attaches the current trace context to outgoing
/// requests that do not already carry the `TRACEPARENT_HEADER` metadata key.
///
/// With the `otlp` feature the context is written by
/// `inject_current_context_metadata`. Without it, the value returned by
/// `current_traceparent` (if any) is inserted via
/// `insert_traceparent_metadata`. Requests that already have the header pass
/// through unchanged.
///
/// # Errors
///
/// Fails with `Status::internal("invalid traceparent metadata")` when the
/// trace context cannot be written into the request metadata.
#[cfg(feature = "observability")]
pub fn trace_context_interceptor() -> impl Interceptor {
    |mut request: Request<()>| -> Result<Request<()>, Status> {
        let has_traceparent = request
            .metadata()
            .contains_key(crate::observability::TRACEPARENT_HEADER);
        if has_traceparent {
            return Ok(request);
        }

        #[cfg(feature = "otlp")]
        {
            let metadata = request.metadata_mut();
            crate::observability::inject_current_context_metadata(metadata)
                .map_err(|_| Status::internal("invalid traceparent metadata"))?;
        }

        #[cfg(not(feature = "otlp"))]
        {
            let Some(traceparent) = crate::observability::current_traceparent() else {
                return Ok(request);
            };
            let metadata = request.metadata_mut();
            crate::observability::insert_traceparent_metadata(metadata, &traceparent)
                .map_err(|_| Status::internal("invalid traceparent metadata"))?;
        }

        Ok(request)
    }
}
87
88pub fn deadline_interceptor(timeout: std::time::Duration) -> impl Interceptor {
90 move |mut request: Request<()>| -> Result<Request<()>, Status> {
91 if !request.metadata().contains_key("grpc-timeout") {
92 crate::rpc::deadline::insert_grpc_timeout(&mut request, timeout)
93 .map_err(|_| Status::internal("invalid grpc-timeout metadata"))?;
94 }
95 Ok(request)
96 }
97}
98
/// Joins `service` and `method` with a `:` to form an RPC resilience key,
/// e.g. `("hello", "Say")` becomes `"hello:Say"`.
pub fn rpc_resilience_key(service: &str, method: &str) -> String {
    [service, method].join(":")
}
103
104pub fn resilience_rejection_status(reason: impl std::fmt::Display) -> Status {
106 Status::unavailable(reason.to_string())
107}
108
#[cfg(test)]
mod tests {
    use super::{
        REQUEST_ID_METADATA, RpcRequestId, deadline_interceptor, request_id_interceptor,
        resilience_rejection_status, rpc_resilience_key, with_rpc_request_id,
    };
    #[cfg(feature = "observability")]
    use crate::observability::CurrentRequestId;
    use tonic::{Request, service::Interceptor};

    /// Returns the `x-request-id` metadata value, failing the test if absent.
    fn request_id_of(request: &Request<()>) -> &tonic::metadata::AsciiMetadataValue {
        request
            .metadata()
            .get(REQUEST_ID_METADATA)
            .expect("request id")
    }

    #[test]
    fn interceptor_sets_request_id() {
        let intercepted = request_id_interceptor()
            .call(Request::new(()))
            .expect("request");

        assert!(intercepted.metadata().contains_key(REQUEST_ID_METADATA));
    }

    #[cfg(feature = "observability")]
    #[test]
    fn interceptor_uses_observability_current_request_id() {
        let mut outgoing = Request::new(());
        outgoing
            .extensions_mut()
            .insert(CurrentRequestId(String::from("req-current-1")));

        let intercepted = request_id_interceptor().call(outgoing).expect("request");

        assert_eq!(request_id_of(&intercepted), "req-current-1");
    }

    #[tokio::test]
    async fn interceptor_uses_scoped_request_id() {
        let mut interceptor = request_id_interceptor();
        let intercepted = with_rpc_request_id("req-scoped-1", async move {
            interceptor.call(Request::new(())).expect("request")
        })
        .await;

        assert_eq!(request_id_of(&intercepted), "req-scoped-1");
    }

    #[test]
    fn interceptor_uses_request_extension_id() {
        let mut outgoing = Request::new(());
        outgoing
            .extensions_mut()
            .insert(RpcRequestId(String::from("req-extension-1")));

        let intercepted = request_id_interceptor().call(outgoing).expect("request");

        assert_eq!(request_id_of(&intercepted), "req-extension-1");
    }

    #[test]
    fn interceptor_sets_grpc_timeout() {
        let timeout = std::time::Duration::from_millis(30);
        let intercepted = deadline_interceptor(timeout)
            .call(Request::new(()))
            .expect("request");

        assert!(intercepted.metadata().contains_key("grpc-timeout"));
    }

    #[test]
    fn rpc_resilience_helpers_are_stable() {
        let key = rpc_resilience_key("hello", "Say");
        let status = resilience_rejection_status("open");

        assert_eq!(key, "hello:Say");
        assert_eq!(status.code(), tonic::Code::Unavailable);
    }
}