use tonic::{Request, Status, service::Interceptor};
use uuid::Uuid;
/// gRPC metadata key under which the request id is sent on outgoing calls.
pub const REQUEST_ID_METADATA: &str = "x-request-id";

/// Request extension carrying an explicit request id for an outgoing RPC.
///
/// When a request has no `x-request-id` metadata yet, `request_id_interceptor`
/// checks this extension first, before any other source of request id.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RpcRequestId(pub String);
/// Runs `future` with `request_id` installed in
/// `crate::rpc::RPC_REQUEST_ID_SCOPE` and returns the future's output.
///
/// The id is converted to a `String` and wrapped in an `Arc` before being
/// handed to `RPC_REQUEST_ID_SCOPE.scope(..)`. NOTE(review): presumably that
/// scope is a task-local, so the id is only observable while `future` runs
/// inside it — confirm against its definition.
pub async fn with_rpc_request_id<T>(
    request_id: impl Into<String>,
    future: impl std::future::Future<Output = T>,
) -> T {
    let scoped_id = std::sync::Arc::new(request_id.into());
    crate::rpc::RPC_REQUEST_ID_SCOPE.scope(scoped_id, future).await
}
pub fn request_id_interceptor() -> impl Interceptor {
|mut request: Request<()>| -> Result<Request<()>, Status> {
if !request.metadata().contains_key(REQUEST_ID_METADATA) {
let request_id = request
.extensions()
.get::<RpcRequestId>()
.map(|value| value.0.clone())
.or_else(|| {
#[cfg(feature = "observability")]
{
request
.extensions()
.get::<crate::observability::CurrentRequestId>()
.map(|value| value.0.clone())
}
#[cfg(not(feature = "observability"))]
{
None
}
})
.or_else(crate::layer::context::current_request_id)
.unwrap_or_else(|| Uuid::new_v4().to_string());
let value = request_id
.parse()
.map_err(|_| Status::internal("invalid request id metadata"))?;
request.metadata_mut().insert(REQUEST_ID_METADATA, value);
}
Ok(request)
}
}
/// Returns a client interceptor that attaches the current trace context as
/// `crate::observability::TRACEPARENT_HEADER` metadata on outgoing requests.
///
/// Requests that already carry that header are passed through unchanged.
/// Otherwise the context is added in one of two ways:
/// - with the `otlp` feature, `inject_current_context_metadata` writes it into
///   the request metadata;
/// - without `otlp`, `current_traceparent()` is consulted, and any value it
///   yields is inserted via `insert_traceparent_metadata`.
///
/// A failure from either helper is reported as
/// `Status::internal("invalid traceparent metadata")`.
#[cfg(feature = "observability")]
pub fn trace_context_interceptor() -> impl Interceptor {
    |mut request: Request<()>| -> Result<Request<()>, Status> {
        let already_traced = request
            .metadata()
            .contains_key(crate::observability::TRACEPARENT_HEADER);
        if already_traced {
            return Ok(request);
        }

        #[cfg(feature = "otlp")]
        {
            crate::observability::inject_current_context_metadata(request.metadata_mut())
                .map_err(|_| Status::internal("invalid traceparent metadata"))?;
        }

        #[cfg(not(feature = "otlp"))]
        {
            if let Some(traceparent) = crate::observability::current_traceparent() {
                let metadata = request.metadata_mut();
                crate::observability::insert_traceparent_metadata(metadata, &traceparent)
                    .map_err(|_| Status::internal("invalid traceparent metadata"))?;
            }
        }

        Ok(request)
    }
}
/// Returns a client interceptor that adds a `grpc-timeout` entry for `timeout`
/// to outgoing requests that do not already have one.
///
/// An existing `grpc-timeout` entry is left untouched, so a deadline set
/// explicitly on the request takes precedence over this default. The entry is
/// written by `crate::rpc::deadline::insert_grpc_timeout`. If that helper
/// fails, the call fails with `Status::internal("invalid grpc-timeout metadata")`.
pub fn deadline_interceptor(timeout: std::time::Duration) -> impl Interceptor {
    const GRPC_TIMEOUT_HEADER: &str = "grpc-timeout";
    move |mut request: Request<()>| -> Result<Request<()>, Status> {
        let has_deadline = request.metadata().contains_key(GRPC_TIMEOUT_HEADER);
        if has_deadline {
            return Ok(request);
        }
        crate::rpc::deadline::insert_grpc_timeout(&mut request, timeout)
            .map_err(|_| Status::internal("invalid grpc-timeout metadata"))?;
        Ok(request)
    }
}
/// Builds the lookup key for per-method resilience state, formatted as
/// `service:method`.
pub fn rpc_resilience_key(service: &str, method: &str) -> String {
    let mut key = String::with_capacity(service.len() + 1 + method.len());
    key.push_str(service);
    key.push(':');
    key.push_str(method);
    key
}
pub fn resilience_rejection_status(reason: impl std::fmt::Display) -> Status {
Status::unavailable(reason.to_string())
}
// Unit tests for the client-side interceptors and resilience helpers.
#[cfg(test)]
mod tests {
    use super::{
        REQUEST_ID_METADATA, RpcRequestId, deadline_interceptor, request_id_interceptor,
        resilience_rejection_status, rpc_resilience_key, with_rpc_request_id,
    };
    #[cfg(feature = "observability")]
    use crate::observability::CurrentRequestId;
    use tonic::{Request, service::Interceptor};

    #[test]
    fn interceptor_sets_request_id() {
        let mut interceptor = request_id_interceptor();
        let request = interceptor.call(Request::new(())).expect("request");
        assert!(request.metadata().contains_key(REQUEST_ID_METADATA));
    }

    #[cfg(feature = "observability")]
    #[test]
    fn interceptor_uses_observability_current_request_id() {
        let mut interceptor = request_id_interceptor();
        let mut request = Request::new(());
        request
            .extensions_mut()
            .insert(CurrentRequestId("req-current-1".to_string()));
        let request = interceptor.call(request).expect("request");
        assert_eq!(
            request
                .metadata()
                .get(REQUEST_ID_METADATA)
                .expect("request id"),
            "req-current-1"
        );
    }

    // The explicit `RpcRequestId` extension must take precedence over the
    // server-side `CurrentRequestId` extension.
    #[cfg(feature = "observability")]
    #[test]
    fn interceptor_prefers_rpc_request_id_over_current_request_id() {
        let mut interceptor = request_id_interceptor();
        let mut request = Request::new(());
        request
            .extensions_mut()
            .insert(CurrentRequestId("req-current-2".to_string()));
        request
            .extensions_mut()
            .insert(RpcRequestId("req-extension-2".to_string()));
        let request = interceptor.call(request).expect("request");
        assert_eq!(
            request
                .metadata()
                .get(REQUEST_ID_METADATA)
                .expect("request id"),
            "req-extension-2"
        );
    }

    #[tokio::test]
    async fn interceptor_uses_scoped_request_id() {
        let mut interceptor = request_id_interceptor();
        let request = with_rpc_request_id("req-scoped-1", async {
            interceptor.call(Request::new(())).expect("request")
        })
        .await;
        assert_eq!(
            request
                .metadata()
                .get(REQUEST_ID_METADATA)
                .expect("request id"),
            "req-scoped-1"
        );
    }

    #[test]
    fn interceptor_uses_request_extension_id() {
        let mut interceptor = request_id_interceptor();
        let mut request = Request::new(());
        request
            .extensions_mut()
            .insert(RpcRequestId("req-extension-1".to_string()));
        let request = interceptor.call(request).expect("request");
        assert_eq!(
            request
                .metadata()
                .get(REQUEST_ID_METADATA)
                .expect("request id"),
            "req-extension-1"
        );
    }

    // A request id already present in metadata must never be overwritten,
    // even when an extension supplies a different one.
    #[test]
    fn interceptor_preserves_existing_request_id() {
        let mut interceptor = request_id_interceptor();
        let mut request = Request::new(());
        request.metadata_mut().insert(
            REQUEST_ID_METADATA,
            "req-existing-1".parse().expect("valid metadata value"),
        );
        request
            .extensions_mut()
            .insert(RpcRequestId("req-extension-ignored".to_string()));
        let request = interceptor.call(request).expect("request");
        assert_eq!(
            request
                .metadata()
                .get(REQUEST_ID_METADATA)
                .expect("request id"),
            "req-existing-1"
        );
    }

    #[test]
    fn interceptor_sets_grpc_timeout() {
        let mut interceptor = deadline_interceptor(std::time::Duration::from_millis(30));
        let request = interceptor.call(Request::new(())).expect("request");
        assert!(request.metadata().contains_key("grpc-timeout"));
    }

    // An explicit per-call deadline must win over the interceptor's default.
    #[test]
    fn interceptor_preserves_existing_grpc_timeout() {
        let mut interceptor = deadline_interceptor(std::time::Duration::from_millis(30));
        let mut request = Request::new(());
        request
            .metadata_mut()
            .insert("grpc-timeout", "5S".parse().expect("valid metadata value"));
        let request = interceptor.call(request).expect("request");
        assert_eq!(
            request.metadata().get("grpc-timeout").expect("grpc-timeout"),
            "5S"
        );
    }

    #[test]
    fn rpc_resilience_helpers_are_stable() {
        assert_eq!(rpc_resilience_key("hello", "Say"), "hello:Say");
        assert_eq!(
            resilience_rejection_status("open").code(),
            tonic::Code::Unavailable
        );
    }
}