use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Instant;
use tower_layer::Layer;
use tower_service::Service;
use tracing::{info, warn};
/// Controls which requests `LoggingService` emits log events for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LogVerbosity {
/// No events are emitted; requests are forwarded straight to the inner service.
Off,
/// Events are emitted only for failed requests or for requests whose latency
/// exceeds the configured slow threshold.
Brief,
/// An event is emitted for every request.
Detailed,
}
/// Layer configuration that wraps an inner service in a `LoggingService`.
#[derive(Debug, Clone)]
pub struct LoggingLayer {
/// Which requests produce log events.
pub verbosity: LogVerbosity,
/// Latency, in milliseconds, above which a successful request is logged
/// under `LogVerbosity::Brief`.
pub slow_threshold_ms: u64,
}
impl LoggingLayer {
    /// Builds a layer with the requested `verbosity` and a slow-request
    /// threshold of 100 milliseconds.
    pub fn new(verbosity: LogVerbosity) -> Self {
        let slow_threshold_ms = 100;
        Self {
            verbosity,
            slow_threshold_ms,
        }
    }

    /// Consumes the layer and returns it with the slow-request threshold set
    /// to `ms` milliseconds; the verbosity is carried over unchanged.
    pub fn with_slow_threshold(self, ms: u64) -> Self {
        Self {
            slow_threshold_ms: ms,
            ..self
        }
    }
}
/// Wraps any inner service `S` in a [`LoggingService`] that carries a copy of
/// this layer's settings.
impl<S> Layer<S> for LoggingLayer {
    type Service = LoggingService<S>;

    /// Produces a `LoggingService` around `inner`, copying the layer's
    /// verbosity and slow-request threshold into it.
    fn layer(&self, inner: S) -> Self::Service {
        let &Self {
            verbosity,
            slow_threshold_ms,
        } = self;
        LoggingService {
            inner,
            verbosity,
            slow_threshold_ms,
        }
    }
}
/// Service wrapper produced by `LoggingLayer`; forwards requests to `inner`
/// and emits log events according to `verbosity`.
#[derive(Clone)]
pub struct LoggingService<S> {
// The wrapped service that actually handles each request.
inner: S,
// Which requests produce log events.
verbosity: LogVerbosity,
// Latency, in milliseconds, above which a successful request is logged
// under `LogVerbosity::Brief`.
slow_threshold_ms: u64,
}
/// Times every request handled by the inner service and reports the outcome
/// through `tracing`, according to the configured [`LogVerbosity`].
///
/// * Failed requests are logged at `warn` level with the message
///   `"gRPC request error"`, including the error's `Display` output.
/// * Successful requests are logged at `info` level with the message
///   `"gRPC request completed"`; under `Brief` only when the latency exceeds
///   the slow threshold, under `Detailed` always.
/// * With `Off`, the inner future is returned directly and nothing is logged.
impl<S, ReqBody, ResBody> Service<http::Request<ReqBody>> for LoggingService<S>
where
    S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>> + Clone + Send + 'static,
    S::Future: Send + 'static,
    S::Error: std::fmt::Display + Send + 'static,
    ReqBody: Send + 'static,
    ResBody: Send + 'static,
{
    type Response = http::Response<ResBody>;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Readiness is delegated entirely to the inner service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Forwards `req` to the inner service and returns a future that resolves
    /// to the inner result unchanged, logging the outcome on the way out.
    fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
        // Move the instance `poll_ready` was called on into the returned
        // future and leave a fresh clone in `self` for subsequent calls.
        let replacement = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, replacement);

        if self.verbosity == LogVerbosity::Off {
            // Nothing to measure or log: hand back the inner future as-is.
            return Box::pin(inner.call(req));
        }

        // The request path is reported under the `method` field.
        let method = req.uri().path().to_owned();
        let verbosity = self.verbosity;
        let slow_threshold_ms = self.slow_threshold_ms;
        let start = Instant::now();

        Box::pin(async move {
            let result = inner.call(req).await;
            // `as_millis` yields u128; the narrowing cast is intentional, as a
            // millisecond count only exceeds u64 after hundreds of millions of years.
            let latency_ms = start.elapsed().as_millis() as u64;

            match &result {
                Ok(response) => {
                    let should_log = match verbosity {
                        LogVerbosity::Off => false,
                        LogVerbosity::Brief => latency_ms > slow_threshold_ms,
                        LogVerbosity::Detailed => true,
                    };
                    if should_log {
                        info!(
                            method = %method,
                            latency_ms = latency_ms,
                            status_code = response.status().as_u16(),
                            "gRPC request completed"
                        );
                    }
                }
                Err(err) => {
                    // Errors are reported at every verbosity except `Off`.
                    if verbosity != LogVerbosity::Off {
                        warn!(
                            method = %method,
                            latency_ms = latency_ms,
                            // No response exists on failure; 0 keeps the field
                            // present for consumers of the existing log schema.
                            status_code = 0_u16,
                            error = %err,
                            "gRPC request error"
                        );
                    }
                }
            }

            result
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::task::{Context, Poll};
    use tower_service::Service as _;
    use tracing_test::traced_test;

    /// Builds an empty-bodied request whose URI is `path`.
    fn make_req(path: &str) -> http::Request<String> {
        http::Request::builder()
            .uri(path)
            .body(String::new())
            .expect("request builder should not fail")
    }

    /// Inner service that always succeeds with a default (200 OK) response.
    #[derive(Clone)]
    struct OkService;

    impl Service<http::Request<String>> for OkService {
        type Response = http::Response<String>;
        type Error = Infallible;
        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: http::Request<String>) -> Self::Future {
            Box::pin(async { Ok(http::Response::new(String::new())) })
        }
    }

    /// Inner service that always fails with a fixed error string.
    #[derive(Clone)]
    struct ErrService;

    impl Service<http::Request<String>> for ErrService {
        type Response = http::Response<String>;
        type Error = String;
        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: http::Request<String>) -> Self::Future {
            Box::pin(async { Err("simulated error".to_owned()) })
        }
    }

    /// Inner service that succeeds and counts how many times `call` ran; the
    /// counter is shared across clones.
    #[derive(Clone)]
    struct CountingService {
        count: Arc<AtomicU32>,
    }

    impl CountingService {
        /// Returns the service together with a handle to its call counter.
        fn new() -> (Self, Arc<AtomicU32>) {
            let count = Arc::new(AtomicU32::new(0));
            (
                Self {
                    count: count.clone(),
                },
                count,
            )
        }
    }

    impl Service<http::Request<String>> for CountingService {
        type Response = http::Response<String>;
        type Error = Infallible;
        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: http::Request<String>) -> Self::Future {
            self.count.fetch_add(1, Ordering::Relaxed);
            Box::pin(async { Ok(http::Response::new(String::new())) })
        }
    }

    #[traced_test]
    #[tokio::test]
    async fn test_logging_layer_off_emits_nothing() {
        let layer = LoggingLayer::new(LogVerbosity::Off);
        let (counting, count) = CountingService::new();
        let mut svc = layer.layer(counting);
        svc.call(make_req("/pkg.Svc/Method"))
            .await
            .expect("should succeed");
        assert_eq!(
            count.load(Ordering::Relaxed),
            1,
            "inner service called once"
        );
        assert!(
            !logs_contain("gRPC request"),
            "Off verbosity must not emit any gRPC log events"
        );
    }

    #[traced_test]
    #[tokio::test]
    async fn test_logging_layer_brief_skips_fast_success() {
        // A 10 s threshold guarantees the in-memory call counts as "fast".
        let layer = LoggingLayer::new(LogVerbosity::Brief).with_slow_threshold(10_000);
        let mut svc = layer.layer(OkService);
        let result = svc.call(make_req("/fast/Method")).await;
        assert!(result.is_ok(), "should succeed");
        assert!(
            !logs_contain("gRPC request"),
            "Brief verbosity must not emit for fast success"
        );
    }

    #[traced_test]
    #[tokio::test]
    async fn test_logging_layer_brief_emits_on_error() {
        let layer = LoggingLayer::new(LogVerbosity::Brief);
        let mut svc = layer.layer(ErrService);
        let result = svc.call(make_req("/fail/Method")).await;
        assert!(result.is_err(), "error should propagate");
        assert!(
            logs_contain("gRPC request error"),
            "Brief verbosity must emit a warn on error"
        );
    }

    #[traced_test]
    #[tokio::test]
    async fn test_logging_layer_detailed_emits_always() {
        let layer = LoggingLayer::new(LogVerbosity::Detailed);
        let mut svc = layer.layer(OkService);
        let result = svc.call(make_req("/always/Method")).await;
        assert!(result.is_ok(), "should succeed with Detailed verbosity");
        assert!(
            logs_contain("gRPC request completed"),
            "Detailed verbosity must emit an info for every request"
        );
    }

    // Log capture is required here so the test checks what its name promises:
    // that the emitted event carries the request path and a latency field.
    #[traced_test]
    #[tokio::test]
    async fn test_logging_layer_records_method_and_latency() {
        let layer = LoggingLayer::new(LogVerbosity::Detailed);
        let mut svc = layer.layer(OkService);
        let res = svc
            .call(make_req("/amaters.AqlService/ExecuteQuery"))
            .await
            .expect("should succeed");
        assert_eq!(
            res.status(),
            http::StatusCode::OK,
            "status should be 200 OK"
        );
        assert!(
            logs_contain("/amaters.AqlService/ExecuteQuery"),
            "logged event must record the request path as the method"
        );
        assert!(
            logs_contain("latency_ms"),
            "logged event must record the request latency"
        );
    }

    #[test]
    fn test_logging_layer_builder_defaults() {
        let layer = LoggingLayer::new(LogVerbosity::Brief);
        assert_eq!(layer.verbosity, LogVerbosity::Brief);
        assert_eq!(layer.slow_threshold_ms, 100);
    }

    #[test]
    fn test_logging_layer_with_slow_threshold_overrides() {
        let layer = LoggingLayer::new(LogVerbosity::Brief).with_slow_threshold(500);
        assert_eq!(layer.slow_threshold_ms, 500);
    }
}