use bytes::Bytes;
use http::{Request, Response};
use http_body::Body as HttpBody;
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tower_service::Service;
use crate::body::DebugBody;
use crate::layer::DebugConfig;
/// Wrapper around an inner service `S` paired with a [`DebugConfig`].
///
/// The `Service` implementation for this type logs each request and its
/// response through `tracing` and wraps the response body in a `DebugBody`.
#[derive(Debug, Clone)]
pub struct DebugService<S> {
/// The wrapped service that requests are forwarded to.
inner: S,
/// Logging settings; `call` reads `log_headers`, `log_response_frames`
/// and `max_body_bytes` from it.
config: DebugConfig,
}
impl<S> DebugService<S> {
pub fn new(inner: S, config: DebugConfig) -> Self {
Self { inner, config }
}
}
/// Logging middleware: forwards every request to the inner service, emits
/// `tracing` events for the request and the response, and wraps the response
/// body in a [`DebugBody`].
impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for DebugService<S>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
    S::Future: Send + 'static,
    S::Error: std::fmt::Display + Send + 'static,
    ReqBody: HttpBody<Data = Bytes> + Send + 'static,
    ReqBody::Error: std::fmt::Display,
    ResBody: HttpBody<Data = Bytes> + Send + 'static,
    ResBody::Error: std::fmt::Display + Send,
{
    type Response = Response<DebugBody<ResBody>>;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Readiness is delegated unchanged to the inner service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Logs the request, awaits the inner service, logs the outcome and
    /// returns the response with its body wrapped in a [`DebugBody`].
    ///
    /// Errors from the inner service are logged and passed through untouched.
    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
        /// Credential-bearing headers whose values must never reach the logs.
        const CREDENTIAL_HEADERS: [&str; 3] = ["authorization", "proxy-authorization", "cookie"];
        /// Placeholder logged instead of a credential value.
        const REDACTED: &str = "<redacted>";

        let config = self.config.clone();
        // Leave a fresh clone in `self` and move the instance that
        // `poll_ready` was driven on into the future, so the call is made on
        // the instance whose readiness was actually polled.
        let fresh = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, fresh);

        Box::pin(async move {
            let start = tokio::time::Instant::now();
            // The URI path is logged under the `method` field.
            let method = req.uri().path().to_string();
            let http_method = req.method().clone();
            tracing::info!(
                method = %method,
                http_method = %http_method,
                "→ gRPC request"
            );

            if config.log_headers {
                let headers = req.headers();
                let content_type = headers
                    .get("content-type")
                    .and_then(|v| v.to_str().ok())
                    .unwrap_or("unknown");
                let authority = req
                    .uri()
                    .authority()
                    .map(|a| a.to_string())
                    .unwrap_or_default();
                let user_agent = headers
                    .get("user-agent")
                    .and_then(|v| v.to_str().ok())
                    .unwrap_or("unknown");
                let grpc_timeout = headers.get("grpc-timeout").and_then(|v| v.to_str().ok());
                tracing::debug!(
                    method = %method,
                    content_type = content_type,
                    authority = %authority,
                    user_agent = user_agent,
                    grpc_timeout = ?grpc_timeout,
                    "→ gRPC request headers"
                );

                // Everything not already reported above (or protocol-level
                // noise) is treated as custom metadata.
                let custom_metadata: Vec<String> = headers
                    .iter()
                    .filter(|(name, _)| {
                        let n = name.as_str();
                        !n.starts_with(':')
                            && !matches!(
                                n,
                                "content-type"
                                    | "user-agent"
                                    | "te"
                                    | "grpc-timeout"
                                    | "grpc-encoding"
                                    | "grpc-accept-encoding"
                            )
                    })
                    .map(|(name, value)| {
                        let name = name.as_str();
                        // Never print credentials: honour the `sensitive`
                        // flag on the value and redact well-known
                        // credential headers regardless of that flag.
                        let shown = if value.is_sensitive() || CREDENTIAL_HEADERS.contains(&name) {
                            REDACTED
                        } else {
                            value.to_str().unwrap_or("<binary>")
                        };
                        format!("{}={}", name, shown)
                    })
                    .collect();
                if !custom_metadata.is_empty() {
                    tracing::debug!(
                        method = %method,
                        metadata = ?custom_metadata,
                        "→ gRPC custom metadata"
                    );
                }
            }

            let outcome = inner.call(req).await;
            // Time until the response head (or error) is available; the body
            // has not been consumed at this point.
            let elapsed_ms = start.elapsed().as_millis() as u64;

            let resp = match outcome {
                Ok(resp) => resp,
                Err(e) => {
                    tracing::error!(
                        method = %method,
                        elapsed_ms = elapsed_ms,
                        error = %e,
                        "← gRPC call failed"
                    );
                    return Err(e);
                }
            };

            let status = resp.status();
            let grpc_status = resp
                .headers()
                .get("grpc-status")
                .and_then(|v| v.to_str().ok());
            match grpc_status {
                Some("0") => {
                    tracing::info!(
                        method = %method,
                        http_status = %status,
                        grpc_status = %"0",
                        elapsed_ms = elapsed_ms,
                        "← gRPC response"
                    );
                }
                Some(gs) => {
                    let grpc_message = resp
                        .headers()
                        .get("grpc-message")
                        .and_then(|v| v.to_str().ok())
                        .unwrap_or("");
                    tracing::warn!(
                        method = %method,
                        http_status = %status,
                        grpc_status = %gs,
                        grpc_message = grpc_message,
                        elapsed_ms = elapsed_ms,
                        "← gRPC response (error)"
                    );
                }
                // No `grpc-status` header (or it is not valid visible ASCII).
                None => {
                    tracing::info!(
                        method = %method,
                        http_status = %status,
                        elapsed_ms = elapsed_ms,
                        "← gRPC response (status in trailers)"
                    );
                }
            }

            if config.log_headers {
                tracing::debug!(
                    method = %method,
                    headers = ?resp.headers(),
                    "← gRPC response headers"
                );
            }

            // Hand the body to `DebugBody`, passing along the path and the
            // frame-logging settings from the config.
            let (parts, body) = resp.into_parts();
            let debug_body = DebugBody::new(
                body,
                method,
                config.log_response_frames,
                config.max_body_bytes,
            );
            Ok(Response::from_parts(parts, debug_body))
        })
    }
}