foxy/logging/
middleware.rs

1// This Source Code Form is subject to the terms of the Mozilla Public
2// License, v. 2.0. If a copy of the MPL was not distributed with this
3// file, You can obtain one at https://mozilla.org/MPL/2.0/.
4
5//! HTTP middleware for request/response logging with trace context.
6
7use crate::logging::config::LoggingConfig;
8use crate::logging::structured::{RequestInfo, generate_trace_id};
9use futures_util::ready;
10use hyper::{Request, Response};
11use slog_scope;
12use std::future::Future;
13use std::net::SocketAddr;
14use std::pin::Pin;
15use std::sync::Arc;
16use std::task::{Context, Poll};
17use std::time::Duration;
18
19/// Middleware for request/response logging with trace context
20#[derive(Debug, Clone)]
21pub struct LoggingMiddleware {
22    config: Arc<LoggingConfig>,
23}
24
25impl LoggingMiddleware {
26    /// Create a new logging middleware
27    #[must_use]
28    pub fn new(config: LoggingConfig) -> Self {
29        Self {
30            config: Arc::new(config),
31        }
32    }
33
34    /// Get a reference to the logging configuration
35    pub fn config(&self) -> &LoggingConfig {
36        &self.config
37    }
38
39    /// Process a request and add trace context
40    pub fn process<B>(
41        &self,
42        req: Request<B>,
43        remote_addr: Option<SocketAddr>,
44    ) -> (Request<B>, RequestInfo) {
45        let method = req.method().to_string();
46        let path = req.uri().path().to_string();
47        let remote_addr_str =
48            remote_addr.map_or_else(|| "unknown".to_string(), |addr| addr.to_string());
49
50        let user_agent = req
51            .headers()
52            .get(hyper::header::USER_AGENT)
53            .and_then(|h| h.to_str().ok())
54            .unwrap_or("unknown")
55            .to_string();
56
57        // Check for existing trace ID in headers if propagation is enabled
58        let trace_id = if self.config.propagate_trace_id {
59            req.headers()
60                .get(&self.config.trace_id_header)
61                .and_then(|h| h.to_str().ok())
62                .filter(|s| !s.is_empty())
63                .map_or_else(generate_trace_id, std::string::ToString::to_string)
64        } else {
65            generate_trace_id()
66        };
67
68        let request_info = RequestInfo {
69            trace_id,
70            method,
71            path,
72            remote_addr: remote_addr_str,
73            user_agent,
74            start_time_ms: std::time::SystemTime::now()
75                .duration_since(std::time::UNIX_EPOCH)
76                .unwrap_or_default()
77                .as_millis(),
78        };
79
80        // Log the incoming request with trace context
81        if self.config.structured {
82            let logger = slog_scope::logger();
83            slog::info!(logger, "Request received";
84                "trace_id" => &request_info.trace_id,
85                "method" => &request_info.method,
86                "path" => &request_info.path,
87                "remote_addr" => &request_info.remote_addr,
88                "user_agent" => &request_info.user_agent
89            );
90        } else {
91            log::info!(
92                "Request received: {} {} from {} (trace_id: {})",
93                request_info.method,
94                request_info.path,
95                request_info.remote_addr,
96                request_info.trace_id
97            );
98        }
99
100        (req, request_info)
101    }
102
103    /// Log the response with timing information
104    pub fn log_response<B>(
105        &self,
106        response: &Response<B>,
107        request_info: &RequestInfo,
108        upstream_duration: Option<Duration>,
109    ) {
110        let status = response.status().as_u16();
111        let elapsed_ms = request_info.elapsed_ms();
112        let upstream_ms = upstream_duration.map_or(0, |d| d.as_millis());
113        let internal_ms = elapsed_ms.saturating_sub(upstream_ms);
114
115        if self.config.structured {
116            let logger = slog_scope::logger();
117            slog::info!(logger, "Response completed";
118                "trace_id" => &request_info.trace_id,
119                "method" => &request_info.method,
120                "path" => &request_info.path,
121                "status" => status,
122                "elapsed_ms" => elapsed_ms,
123                "upstream_ms" => upstream_ms,
124                "internal_ms" => internal_ms
125            );
126        } else {
127            log::info!(
128                "[timing] {} {} -> {} | total={}ms upstream={}ms internal={}ms (trace_id: {})",
129                request_info.method,
130                request_info.path,
131                status,
132                elapsed_ms,
133                upstream_ms,
134                internal_ms,
135                request_info.trace_id
136            );
137        }
138    }
139}
140
/// Future that wraps a response future and adds a trace ID header to its
/// successful output.
///
/// The header is only written when `include_trace_id` is `true`; errors from
/// the wrapped future are passed through unchanged.
#[derive(Debug)]
#[must_use = "futures do nothing unless polled"]
pub struct TracedResponseFuture<F> {
    /// The wrapped response future that is driven to completion.
    inner: F,
    /// Trace ID written as the header value.
    trace_id: String,
    /// Name of the response header that carries the trace ID.
    trace_header: String,
    /// Whether the trace header is added to the response at all.
    include_trace_id: bool,
}
148
149impl<F, B, E> Future for TracedResponseFuture<F>
150where
151    F: Future<Output = Result<Response<B>, E>> + Unpin,
152{
153    type Output = Result<Response<B>, E>;
154
155    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
156        let result = ready!(Pin::new(&mut self.inner).poll(cx));
157
158        Poll::Ready(match result {
159            Ok(mut response) => {
160                // Add trace ID header to response if enabled
161                if self.include_trace_id {
162                    let header_name =
163                        hyper::header::HeaderName::from_bytes(self.trace_header.as_bytes())
164                            .unwrap_or_else(|_| {
165                                hyper::header::HeaderName::from_static("x-trace-id")
166                            });
167
168                    response.headers_mut().insert(
169                        header_name,
170                        hyper::header::HeaderValue::from_str(&self.trace_id).unwrap_or_else(|_| {
171                            hyper::header::HeaderValue::from_static("invalid-trace-id")
172                        }),
173                    );
174                }
175                Ok(response)
176            }
177            Err(e) => Err(e),
178        })
179    }
180}
181
182/// Extension trait for response futures to add trace context
183pub trait ResponseFutureExt: Sized {
184    /// Add trace ID header to the response
185    fn with_trace_id(
186        self,
187        trace_id: String,
188        trace_header: String,
189        include_trace_id: bool,
190    ) -> TracedResponseFuture<Self>;
191}
192
193impl<F, B, E> ResponseFutureExt for F
194where
195    F: Future<Output = Result<Response<B>, E>> + Unpin,
196{
197    fn with_trace_id(
198        self,
199        trace_id: String,
200        trace_header: String,
201        include_trace_id: bool,
202    ) -> TracedResponseFuture<Self> {
203        TracedResponseFuture {
204            inner: self,
205            trace_id,
206            trace_header,
207            include_trace_id,
208        }
209    }
210}