use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Instant;
use tonic::{Request, Response, Status};
use tower::{Layer, Service};
use crate::grpc::interceptors::RequestIdExtension;
/// Tower [`Layer`] that wraps a service in [`LoggingService`], which logs the
/// outcome and latency of every request.
#[derive(Clone, Debug, Default)]
pub struct LoggingLayer;
/// Wraps any service `S` in a [`LoggingService`].
impl<S> Layer<S> for LoggingLayer {
    type Service = LoggingService<S>;

    fn layer(&self, service: S) -> Self::Service {
        LoggingService { inner: service }
    }
}
/// Service produced by [`LoggingLayer`]; forwards requests to `inner` and logs
/// each call's result and duration.
#[derive(Clone, Debug)]
pub struct LoggingService<S> {
    /// The wrapped service that actually handles requests.
    inner: S,
}
/// Logs the outcome and wall-clock latency of every request passing through
/// the wrapped service.
///
/// Successful calls are logged at `INFO`. Calls whose inner service returns a
/// [`Status`] error are logged at `WARN` together with the status code. The
/// response or error itself is returned unchanged.
impl<S, ReqBody> Service<Request<ReqBody>> for LoggingService<S>
where
    S: Service<Request<ReqBody>, Response = Response<tonic::body::Body>, Error = Status>
        + Clone
        + Send
        + 'static,
    S::Future: Send + 'static,
    ReqBody: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Readiness is delegated entirely to the inner service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
        // `poll_ready` was driven on `self.inner`, so that exact instance is the
        // one allowed to accept this request. A freshly cloned service may not
        // be ready yet. Move the ready instance into the future and leave the
        // clone behind for the next poll_ready/call cycle.
        let fresh = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, fresh);
        Box::pin(async move {
            let start = Instant::now();
            tracing::debug!("gRPC request started");
            let result = inner.call(req).await;
            let duration = start.elapsed();
            match &result {
                Ok(_) => {
                    tracing::info!(duration_ms = duration.as_millis(), "gRPC request completed");
                }
                Err(status) => {
                    tracing::warn!(
                        duration_ms = duration.as_millis(),
                        status_code = ?status.code(),
                        "gRPC request failed"
                    );
                }
            }
            result
        })
    }
}
/// Tower [`Layer`] that wraps a service in [`GrpcTracingService`], which runs
/// each request inside a `grpc_request` tracing span.
#[derive(Clone, Debug, Default)]
pub struct GrpcTracingLayer;
/// Wraps any service `S` in a [`GrpcTracingService`].
impl<S> Layer<S> for GrpcTracingLayer {
    type Service = GrpcTracingService<S>;

    fn layer(&self, service: S) -> Self::Service {
        GrpcTracingService { inner: service }
    }
}
/// Service produced by [`GrpcTracingLayer`]; forwards requests to `inner`
/// inside a per-request tracing span.
#[derive(Clone, Debug)]
pub struct GrpcTracingService<S> {
    /// The wrapped service that actually handles requests.
    inner: S,
}
/// Runs every request inside an `info`-level `grpc_request` span and logs the
/// outcome and latency inside that span.
///
/// The span carries `otel.kind`, `rpc.system`, `rpc.service`, `rpc.method` and
/// `request_id` fields. Success is logged at `INFO` and a [`Status`] error at
/// `WARN`. The response or error is returned unchanged.
impl<S, ReqBody> Service<Request<ReqBody>> for GrpcTracingService<S>
where
    S: Service<Request<ReqBody>, Response = Response<tonic::body::Body>, Error = Status>
        + Clone
        + Send
        + 'static,
    S::Future: Send + 'static,
    ReqBody: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Readiness is delegated entirely to the inner service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
        // Hand the instance that `poll_ready` prepared to the future and keep a
        // fresh clone for the next cycle; a newly cloned service may not be
        // ready to accept a request.
        let fresh = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, fresh);

        // Prefer the id stored in the `RequestIdExtension`, then the raw
        // `x-request-id` metadata entry, else the "unknown" placeholder.
        let request_id = req
            .extensions()
            .get::<RequestIdExtension>()
            .map(|ext| ext.0.clone())
            .or_else(|| {
                req.metadata()
                    .get("x-request-id")
                    .and_then(|v| v.to_str().ok())
                    .map(|s| s.to_string())
            })
            .unwrap_or_else(|| "unknown".to_string());

        // NOTE(review): `:path` is an HTTP/2 pseudo-header and is not a valid
        // header name, so it is likely never present in the metadata map. As a
        // result, `method` may always be "unknown" here. Confirm where the full
        // `/package.Service/Method` path can be read at this layer.
        let method = req
            .metadata()
            .get(":path")
            .and_then(|v| v.to_str().ok())
            .unwrap_or("unknown")
            .to_string();

        let span = tracing::info_span!(
            "grpc_request",
            otel.kind = "server",
            rpc.system = "grpc",
            rpc.service = %extract_service_name(&method),
            rpc.method = %extract_method_name(&method),
            request_id = %request_id,
        );

        let fut = async move {
            let start = Instant::now();
            tracing::debug!(method = %method, "gRPC request started");
            let result = inner.call(req).await;
            let duration = start.elapsed();
            match &result {
                Ok(response) => {
                    // NOTE(review): gRPC normally sends `grpc-status` in trailers,
                    // so this header lookup may fall back to "0" for most
                    // responses. Confirm.
                    let status = response
                        .metadata()
                        .get("grpc-status")
                        .and_then(|v| v.to_str().ok())
                        .unwrap_or("0");
                    tracing::info!(
                        duration_ms = duration.as_millis(),
                        grpc.status_code = status,
                        "gRPC request completed"
                    );
                }
                Err(status) => {
                    tracing::warn!(
                        duration_ms = duration.as_millis(),
                        grpc.status_code = ?status.code(),
                        error.message = %status.message(),
                        "gRPC request failed"
                    );
                }
            }
            result
        };

        // Attach the span to the future rather than holding a `Span::enter`
        // guard across `.await`. Such a guard leaves the span "entered" on the
        // polling thread while the task is suspended, producing wrong span
        // nesting. The guard is also `!Send`, which conflicts with the `Send`
        // bound on `Self::Future`.
        Box::pin(tracing::Instrument::instrument(fut, span))
    }
}
/// Returns the unqualified service name from a gRPC method path.
///
/// Leading slashes are ignored. The first `/`-separated segment is treated as
/// the fully-qualified service, and only the part after its last `.` is kept,
/// e.g. `"/example.v1.Greeter/SayHello"` gives `"Greeter"`. Input with no `/`
/// or `.` (after stripping leading slashes) comes back unchanged, so `""`
/// yields `""`.
fn extract_service_name(path: &str) -> &str {
    let trimmed = path.trim_start_matches('/');
    let qualified = trimmed
        .split_once('/')
        .map_or(trimmed, |(service, _)| service);
    qualified.rsplit_once('.').map_or(qualified, |(_, name)| name)
}
/// Returns the method name, i.e. the second `/`-separated segment, of a gRPC path.
///
/// Leading slashes are ignored, so `"/example.v1.Greeter/SayHello"` gives
/// `"SayHello"`. Any segments after the method are dropped. `"unknown"` is
/// returned when the path has no second segment.
fn extract_method_name(path: &str) -> &str {
    match path.trim_start_matches('/').split_once('/') {
        Some((_, rest)) => rest.split_once('/').map_or(rest, |(method, _)| method),
        None => "unknown",
    }
}
/// Tower [`Layer`] producing [`GrpcRateLimitService`].
///
/// Only the configuration's `enabled` flag is stored; no limiter state is
/// kept here.
#[cfg(feature = "governor")]
#[derive(Clone, Debug)]
pub struct GrpcRateLimitLayer {
    /// Copied from `LocalRateLimitConfig::enabled`.
    enabled: bool,
}
#[cfg(feature = "governor")]
impl GrpcRateLimitLayer {
    /// Builds the layer from the local rate-limit configuration.
    ///
    /// Only `config.enabled` is read; the rest of the config is dropped.
    pub fn new(config: crate::config::LocalRateLimitConfig) -> Self {
        let enabled = config.enabled;
        Self { enabled }
    }
}
#[cfg(feature = "governor")]
impl<S> Layer<S> for GrpcRateLimitLayer {
    type Service = GrpcRateLimitService<S>;

    /// Wraps `service`, copying this layer's `enabled` flag into it.
    fn layer(&self, service: S) -> Self::Service {
        let enabled = self.enabled;
        GrpcRateLimitService {
            enabled,
            inner: service,
        }
    }
}
/// Service produced by [`GrpcRateLimitLayer`].
#[cfg(feature = "governor")]
#[derive(Clone, Debug)]
pub struct GrpcRateLimitService<S> {
    /// The wrapped service that actually handles requests.
    inner: S,
    /// Whether rate limiting was enabled in configuration.
    enabled: bool,
}
/// Forwards every request to the inner service.
///
/// NOTE(review): no limiting is actually applied here. When `enabled` is set,
/// this only emits a `TRACE` event before forwarding the request unchanged.
/// TODO: wire in an actual limiter.
#[cfg(feature = "governor")]
impl<S, ReqBody> Service<Request<ReqBody>> for GrpcRateLimitService<S>
where
    S: Service<Request<ReqBody>, Response = Response<tonic::body::Body>, Error = Status>
        + Clone
        + Send
        + 'static,
    S::Future: Send + 'static,
    ReqBody: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Readiness is delegated entirely to the inner service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
        if self.enabled {
            tracing::trace!("Rate limiting enabled for gRPC request");
        }
        // Use the instance that `poll_ready` prepared; leave a fresh clone for
        // the next poll_ready/call cycle.
        let fresh = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, fresh);
        // The inner future is already `Send + 'static`, so it can be boxed
        // directly without an extra `async` wrapper.
        Box::pin(inner.call(req))
    }
}
#[cfg(test)]
mod tests {
    // Table-driven checks for the gRPC path-parsing helpers.
    use super::*;

    #[test]
    fn test_extract_service_name() {
        let cases = [
            ("/example.v1.Greeter/SayHello", "Greeter"),
            ("/mypackage.UserService/GetUser", "UserService"),
            ("/Service/Method", "Service"),
            ("invalid", "invalid"),
            ("", ""),
        ];
        for (path, expected) in cases {
            assert_eq!(extract_service_name(path), expected, "path = {path:?}");
        }
    }

    #[test]
    fn test_extract_method_name() {
        let cases = [
            ("/example.v1.Greeter/SayHello", "SayHello"),
            ("/mypackage.UserService/GetUser", "GetUser"),
            ("/Service/Method", "Method"),
            ("invalid", "unknown"),
        ];
        for (path, expected) in cases {
            assert_eq!(extract_method_name(path), expected, "path = {path:?}");
        }
    }
}