/// Typed lookup of values stored in a gRPC request's extensions.
pub trait GrpcRequestExt {
    /// Returns an owned copy of the value of type `C` attached to the request,
    /// such as an `Arc<InjectionContext>`, or `None` when no value of that type
    /// is present.
    fn get_di_context<C>(&self) -> Option<C>
    where
        C: Clone + Send + Sync + 'static;
}
/// Looks the requested type up in the request's `tonic` extensions map.
impl<M> GrpcRequestExt for tonic::Request<M> {
    /// Clones the extension of type `C`, if one was inserted into this request.
    /// The request itself is left untouched.
    fn get_di_context<C>(&self) -> Option<C>
    where
        C: Clone + Send + Sync + 'static,
    {
        let extensions = self.extensions();
        extensions.get::<C>().cloned()
    }
}
pub fn sanitize_di_error(error: &reinhardt_di::DiError) -> tonic::Status {
tracing::error!(
error = %error,
"DI resolution failed"
);
tonic::Status::internal("Internal server error")
}
/// Builds the client-safe gRPC status used when the DI context is absent from
/// the request extensions.
///
/// The specific cause is logged server-side. The client only ever sees a
/// generic `Internal` error.
pub fn sanitize_missing_context() -> tonic::Status {
    // Opaque text for the client; the real reason goes to the log below.
    const CLIENT_MESSAGE: &str = "Internal server error";

    tracing::error!("DI context not found in request extensions");
    tonic::Status::new(tonic::Code::Internal, CLIENT_MESSAGE)
}
#[cfg(test)]
mod tests {
    use super::*;
    use reinhardt_di::{DiError, InjectionContext, SingletonScope};
    use rstest::rstest;
    use std::sync::Arc;
    use tonic::{Code, Request, Status};

    /// Asserts the two properties every sanitized status shares: the `Internal`
    /// code and the fixed, detail-free client message.
    fn assert_generic_internal(status: &Status) {
        assert_eq!(status.code(), Code::Internal);
        assert_eq!(status.message(), "Internal server error");
    }

    #[rstest]
    fn grpc_request_ext_extracts_di_context() {
        // Arrange: a request carrying a shared injection context.
        let ctx = Arc::new(InjectionContext::builder(SingletonScope::new()).build());
        let mut request = Request::new(());
        request.extensions_mut().insert(Arc::clone(&ctx));

        // Act
        let extracted = request
            .get_di_context::<Arc<InjectionContext>>()
            .expect("DI context should be extractable from request extensions after insertion");

        // Assert: the lookup yields the same allocation, not a deep copy.
        assert!(Arc::ptr_eq(&ctx, &extracted));
    }

    #[rstest]
    fn grpc_request_ext_returns_none_for_missing_context() {
        let request = Request::new(());
        assert!(request.get_di_context::<Arc<InjectionContext>>().is_none());
    }

    #[rstest]
    fn sanitize_di_error_returns_generic_message() {
        let status = sanitize_di_error(&DiError::NotFound(
            "my_app::services::DatabasePool".to_string(),
        ));
        assert_generic_internal(&status);
        assert!(
            !status.message().contains("DatabasePool"),
            "Type name should not be exposed in client error"
        );
        assert!(
            !status.message().contains("my_app"),
            "Module path should not be exposed in client error"
        );
    }

    #[rstest]
    fn sanitize_di_error_hides_type_mismatch_details() {
        let status = sanitize_di_error(&DiError::TypeMismatch {
            expected: "my_app::db::PostgresPool".to_string(),
            actual: "my_app::db::SqlitePool".to_string(),
        });
        assert_generic_internal(&status);
        for leaked in ["PostgresPool", "SqlitePool"] {
            assert!(!status.message().contains(leaked));
        }
    }

    #[rstest]
    fn sanitize_di_error_hides_circular_dependency_details() {
        let status = sanitize_di_error(&DiError::CircularDependency(
            "my_app::ServiceA -> my_app::ServiceB -> my_app::ServiceA".to_string(),
        ));
        assert_generic_internal(&status);
        for leaked in ["ServiceA", "ServiceB"] {
            assert!(!status.message().contains(leaked));
        }
    }

    #[rstest]
    fn sanitize_missing_context_returns_generic_message() {
        let status = sanitize_missing_context();
        assert_generic_internal(&status);
        for internal_detail in ["DI", "context", "extensions"] {
            assert!(!status.message().contains(internal_detail));
        }
    }
}