amaters_net/
logging_layer.rs1use std::future::Future;
29use std::pin::Pin;
30use std::task::{Context, Poll};
31use std::time::Instant;
32
33use tower_layer::Layer;
34use tower_service::Service;
35use tracing::{info, warn};
36
/// How much per-request logging a [`LoggingService`] emits.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LogVerbosity {
    /// Emit nothing: requests are handed straight to the wrapped service.
    Off,
    /// Emit an event only for failed requests, or for successful ones whose
    /// latency exceeds the configured slow threshold.
    Brief,
    /// Emit an event for every request.
    Detailed,
}
50
51#[derive(Debug, Clone)]
55pub struct LoggingLayer {
56 pub verbosity: LogVerbosity,
58 pub slow_threshold_ms: u64,
60}
61
62impl LoggingLayer {
63 pub fn new(verbosity: LogVerbosity) -> Self {
66 Self {
67 verbosity,
68 slow_threshold_ms: 100,
69 }
70 }
71
72 pub fn with_slow_threshold(mut self, ms: u64) -> Self {
74 self.slow_threshold_ms = ms;
75 self
76 }
77}
78
79impl<S> Layer<S> for LoggingLayer {
80 type Service = LoggingService<S>;
81
82 fn layer(&self, inner: S) -> Self::Service {
83 LoggingService {
84 inner,
85 verbosity: self.verbosity,
86 slow_threshold_ms: self.slow_threshold_ms,
87 }
88 }
89}
90
91#[derive(Clone)]
95pub struct LoggingService<S> {
96 inner: S,
97 verbosity: LogVerbosity,
98 slow_threshold_ms: u64,
99}
100
101impl<S, ReqBody, ResBody> Service<http::Request<ReqBody>> for LoggingService<S>
102where
103 S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>> + Clone + Send + 'static,
104 S::Future: Send + 'static,
105 S::Error: std::fmt::Display + Send + 'static,
106 ReqBody: Send + 'static,
107 ResBody: Send + 'static,
108{
109 type Response = http::Response<ResBody>;
110 type Error = S::Error;
111 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
112
113 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
114 self.inner.poll_ready(cx)
115 }
116
117 fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
118 if self.verbosity == LogVerbosity::Off {
120 let mut inner = self.inner.clone();
121 std::mem::swap(&mut self.inner, &mut inner);
122 return Box::pin(inner.call(req));
123 }
124
125 let method = req.uri().path().to_owned();
126 let verbosity = self.verbosity;
127 let slow_threshold_ms = self.slow_threshold_ms;
128
129 let mut inner = self.inner.clone();
130 std::mem::swap(&mut self.inner, &mut inner);
131
132 let start = Instant::now();
133
134 Box::pin(async move {
135 let result = inner.call(req).await;
136 let latency_ms = start.elapsed().as_millis() as u64;
137
138 let is_error = result.is_err();
139 let status_code = result
140 .as_ref()
141 .ok()
142 .map(|r| r.status().as_u16())
143 .unwrap_or(0);
144
145 let should_log = match verbosity {
146 LogVerbosity::Off => false,
147 LogVerbosity::Brief => is_error || latency_ms > slow_threshold_ms,
148 LogVerbosity::Detailed => true,
149 };
150
151 if should_log {
152 if is_error {
153 warn!(
154 method = %method,
155 latency_ms = latency_ms,
156 status_code = status_code,
157 "gRPC request error"
158 );
159 } else {
160 info!(
161 method = %method,
162 latency_ms = latency_ms,
163 status_code = status_code,
164 "gRPC request completed"
165 );
166 }
167 }
168
169 result
170 })
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::task::{Context, Poll};
    use std::time::Duration;

    use tower_service::Service as _;
    use tracing_test::traced_test;

    /// Builds a request with the given URI path and an empty body.
    fn make_req(path: &str) -> http::Request<String> {
        http::Request::builder()
            .uri(path)
            .body(String::new())
            .expect("request builder should not fail")
    }

    /// Inner service that always succeeds with an empty 200 response.
    #[derive(Clone)]
    struct OkService;

    impl Service<http::Request<String>> for OkService {
        type Response = http::Response<String>;
        type Error = Infallible;
        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: http::Request<String>) -> Self::Future {
            Box::pin(async { Ok(http::Response::new(String::new())) })
        }
    }

    /// Inner service that always fails with a string error.
    #[derive(Clone)]
    struct ErrService;

    impl Service<http::Request<String>> for ErrService {
        type Response = http::Response<String>;
        type Error = String;
        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: http::Request<String>) -> Self::Future {
            Box::pin(async { Err("simulated error".to_owned()) })
        }
    }

    /// Inner service that succeeds only after its response future has blocked
    /// for a few milliseconds, so the request is measurably slow.
    #[derive(Clone)]
    struct SlowService;

    impl Service<http::Request<String>> for SlowService {
        type Response = http::Response<String>;
        type Error = Infallible;
        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: http::Request<String>) -> Self::Future {
            Box::pin(async {
                // A short blocking sleep is enough here: the test only needs
                // wall-clock time to pass, not other tasks to make progress.
                std::thread::sleep(Duration::from_millis(5));
                Ok(http::Response::new(String::new()))
            })
        }
    }

    /// Inner service that counts how many times `call` is invoked.
    #[derive(Clone)]
    struct CountingService {
        count: Arc<AtomicU32>,
    }

    impl CountingService {
        fn new() -> (Self, Arc<AtomicU32>) {
            let count = Arc::new(AtomicU32::new(0));
            (
                Self {
                    count: count.clone(),
                },
                count,
            )
        }
    }

    impl Service<http::Request<String>> for CountingService {
        type Response = http::Response<String>;
        type Error = Infallible;
        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: http::Request<String>) -> Self::Future {
            self.count.fetch_add(1, Ordering::Relaxed);
            Box::pin(async { Ok(http::Response::new(String::new())) })
        }
    }

    #[traced_test]
    #[tokio::test]
    async fn test_logging_layer_off_emits_nothing() {
        let layer = LoggingLayer::new(LogVerbosity::Off);
        let (counting, count) = CountingService::new();
        let mut svc = layer.layer(counting);

        svc.call(make_req("/pkg.Svc/Method"))
            .await
            .expect("should succeed");

        assert_eq!(
            count.load(Ordering::Relaxed),
            1,
            "inner service called once"
        );
        assert!(
            !logs_contain("gRPC request"),
            "Off verbosity must not emit any gRPC log events"
        );
    }

    #[traced_test]
    #[tokio::test]
    async fn test_logging_layer_brief_skips_fast_success() {
        let layer = LoggingLayer::new(LogVerbosity::Brief).with_slow_threshold(10_000);
        let mut svc = layer.layer(OkService);

        let result = svc.call(make_req("/fast/Method")).await;
        assert!(result.is_ok(), "should succeed");
        assert!(
            !logs_contain("gRPC request"),
            "Brief verbosity must not emit for fast success"
        );
    }

    #[traced_test]
    #[tokio::test]
    async fn test_logging_layer_brief_emits_on_slow_success() {
        // Any request taking at least 1 ms exceeds a zero threshold.
        let layer = LoggingLayer::new(LogVerbosity::Brief).with_slow_threshold(0);
        let mut svc = layer.layer(SlowService);

        let result = svc.call(make_req("/slow/Method")).await;
        assert!(result.is_ok(), "should succeed");
        assert!(
            logs_contain("gRPC request completed"),
            "Brief verbosity must emit for a success slower than the threshold"
        );
    }

    #[traced_test]
    #[tokio::test]
    async fn test_logging_layer_brief_emits_on_error() {
        let layer = LoggingLayer::new(LogVerbosity::Brief);
        let mut svc = layer.layer(ErrService);

        let result = svc.call(make_req("/fail/Method")).await;
        assert!(result.is_err(), "error should propagate");
        assert!(
            logs_contain("gRPC request error"),
            "Brief verbosity must emit a warn on error"
        );
    }

    #[traced_test]
    #[tokio::test]
    async fn test_logging_layer_detailed_emits_always() {
        let layer = LoggingLayer::new(LogVerbosity::Detailed);
        let mut svc = layer.layer(OkService);

        let result = svc.call(make_req("/always/Method")).await;
        assert!(result.is_ok(), "should succeed with Detailed verbosity");
        assert!(
            logs_contain("gRPC request completed"),
            "Detailed verbosity must emit an info for every request"
        );
    }

    #[traced_test]
    #[tokio::test]
    async fn test_logging_layer_records_method_and_latency() {
        let layer = LoggingLayer::new(LogVerbosity::Detailed);
        let mut svc = layer.layer(OkService);

        let res = svc
            .call(make_req("/amaters.AqlService/ExecuteQuery"))
            .await
            .expect("should succeed");

        assert_eq!(
            res.status(),
            http::StatusCode::OK,
            "status should be 200 OK"
        );
        assert!(
            logs_contain("/amaters.AqlService/ExecuteQuery"),
            "log event must carry the request path"
        );
        assert!(
            logs_contain("latency_ms"),
            "log event must carry the latency field"
        );
        assert!(
            logs_contain("status_code=200"),
            "log event must carry the HTTP status code"
        );
    }

    #[test]
    fn test_logging_layer_builder_defaults() {
        let layer = LoggingLayer::new(LogVerbosity::Brief);
        assert_eq!(layer.verbosity, LogVerbosity::Brief);
        assert_eq!(layer.slow_threshold_ms, 100);
    }

    #[test]
    fn test_logging_layer_with_slow_threshold_overrides() {
        let layer = LoggingLayer::new(LogVerbosity::Brief).with_slow_threshold(500);
        assert_eq!(layer.slow_threshold_ms, 500);
    }
}