Skip to main content

axon/
request_middleware.rs

1//! Request Middleware — automatic request ID, timing, and logging for AxonServer.
2//!
3//! Provides an axum middleware layer that intercepts every request to:
4//!   - Generate a unique sequential request ID (`X-Request-Id` header)
5//!   - Time the request duration (start → response)
6//!   - Auto-record to the RequestLogger (method, path, status, latency, client key)
7//!   - Tag slow requests above a configurable threshold
8//!
9//! This replaces manual `request_logger.record()` calls in individual handlers
10//! and provides consistent observability across all endpoints.
11//!
12//! Configuration:
13//!   - `MiddlewareConfig` — enabled flag, slow request threshold
14//!   - `GET /v1/middleware` — view current middleware configuration
15//!   - `PUT /v1/middleware` — update middleware settings at runtime
16
17use std::sync::atomic::{AtomicU64, Ordering};
18use std::sync::{Arc, Mutex};
19use std::time::Instant;
20
21use axum::body::Body;
22use axum::extract::Request;
23use axum::http::HeaderValue;
24use axum::middleware::Next;
25use axum::response::Response;
26use serde::{Deserialize, Serialize};
27
28// ── Request ID generator ────────────────────────────────────────────────
29
/// Lock-free source of sequential request IDs.
///
/// Each call to [`RequestIdGenerator::next_id`] yields `"{prefix}-{n}"`, where
/// `n` starts at 0 and grows by one per call. The counter is an `AtomicU64`,
/// so one generator can be shared by reference across threads without two
/// callers ever receiving the same number. Nothing in this type resets it.
pub struct RequestIdGenerator {
    counter: AtomicU64,
    prefix: String,
}

impl RequestIdGenerator {
    /// Builds a generator that uses the default `"axr"` prefix.
    pub fn new() -> Self {
        Self::with_prefix("axr")
    }

    /// Builds a generator whose IDs begin with `prefix`, e.g. `"req"` gives `"req-0"`.
    pub fn with_prefix(prefix: &str) -> Self {
        Self {
            prefix: prefix.to_owned(),
            counter: AtomicU64::new(0),
        }
    }

    /// Hands out the next ID and advances the counter.
    ///
    /// `Relaxed` is enough here. `fetch_add` is atomic under every ordering,
    /// so each caller gets a distinct value. The counter is used only for
    /// numbering, not to publish other data.
    pub fn next_id(&self) -> String {
        let seq = self.counter.fetch_add(1, Ordering::Relaxed);
        format!("{prefix}-{seq}", prefix = self.prefix)
    }

    /// How many IDs have been handed out so far. This equals the next sequence number.
    pub fn count(&self) -> u64 {
        self.counter.load(Ordering::Relaxed)
    }
}

impl Default for RequestIdGenerator {
    /// Equivalent to [`RequestIdGenerator::new`].
    fn default() -> Self {
        Self::new()
    }
}
73
74// ── Middleware configuration ─────────────────────────────────────────────
75
76/// Configuration for the request middleware layer.
77#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct MiddlewareConfig {
79    /// Whether the middleware is enabled.
80    pub enabled: bool,
81    /// Slow request threshold in milliseconds. Requests exceeding this
82    /// are tagged in the log. 0 = disabled.
83    pub slow_threshold_ms: u64,
84    /// Whether to inject X-Request-Id response header.
85    pub inject_request_id: bool,
86    /// Whether to inject X-Response-Time header (latency in ms).
87    pub inject_response_time: bool,
88}
89
90impl Default for MiddlewareConfig {
91    fn default() -> Self {
92        MiddlewareConfig {
93            enabled: true,
94            slow_threshold_ms: 5000,
95            inject_request_id: true,
96            inject_response_time: true,
97        }
98    }
99}
100
101impl MiddlewareConfig {
102    /// Disabled middleware — passes through without recording.
103    pub fn disabled() -> Self {
104        MiddlewareConfig {
105            enabled: false,
106            slow_threshold_ms: 0,
107            inject_request_id: false,
108            inject_response_time: false,
109        }
110    }
111}
112
113// ── Middleware update ────────────────────────────────────────────────────
114
115/// Partial update for middleware configuration.
116#[derive(Debug, Clone, Deserialize)]
117pub struct MiddlewareUpdate {
118    pub enabled: Option<bool>,
119    pub slow_threshold_ms: Option<u64>,
120    pub inject_request_id: Option<bool>,
121    pub inject_response_time: Option<bool>,
122}
123
124/// Apply a partial update to a MiddlewareConfig. Returns list of changed fields.
125pub fn apply_update(config: &mut MiddlewareConfig, update: &MiddlewareUpdate) -> Vec<String> {
126    let mut changes = Vec::new();
127
128    if let Some(enabled) = update.enabled {
129        if enabled != config.enabled {
130            config.enabled = enabled;
131            changes.push("enabled".to_string());
132        }
133    }
134    if let Some(threshold) = update.slow_threshold_ms {
135        if threshold != config.slow_threshold_ms {
136            config.slow_threshold_ms = threshold;
137            changes.push("slow_threshold_ms".to_string());
138        }
139    }
140    if let Some(inject_id) = update.inject_request_id {
141        if inject_id != config.inject_request_id {
142            config.inject_request_id = inject_id;
143            changes.push("inject_request_id".to_string());
144        }
145    }
146    if let Some(inject_time) = update.inject_response_time {
147        if inject_time != config.inject_response_time {
148            config.inject_response_time = inject_time;
149            changes.push("inject_response_time".to_string());
150        }
151    }
152
153    changes
154}
155
156// ── Request metadata ────────────────────────────────────────────────────
157
158/// Metadata captured for a single request by the middleware.
159#[derive(Debug, Clone, Serialize)]
160pub struct RequestMeta {
161    /// Unique request ID (e.g., "axr-42").
162    pub request_id: String,
163    /// HTTP method.
164    pub method: String,
165    /// Request path.
166    pub path: String,
167    /// Response status code.
168    pub status: u16,
169    /// Latency in microseconds.
170    pub latency_us: u64,
171    /// Latency in milliseconds (convenience).
172    pub latency_ms: u64,
173    /// Client identifier.
174    pub client_key: String,
175    /// Whether this was flagged as a slow request.
176    pub slow: bool,
177}
178
179// ── Middleware state ─────────────────────────────────────────────────────
180
181/// Shared state for the request middleware, held in an Arc for cloning.
182pub struct MiddlewareState<S> {
183    pub id_generator: RequestIdGenerator,
184    pub config: Arc<Mutex<MiddlewareConfig>>,
185    pub server_state: Arc<Mutex<S>>,
186}
187
188// ── Helper: extract client key from headers ─────────────────────────────
189
190fn client_key_from_headers(headers: &axum::http::HeaderMap) -> String {
191    headers
192        .get("authorization")
193        .and_then(|v| v.to_str().ok())
194        .map(|v| v.to_string())
195        .unwrap_or_else(|| "anonymous".to_string())
196}
197
198// ── Core middleware function ─────────────────────────────────────────────
199
200/// The request middleware function for use with `axum::middleware::from_fn`.
201///
202/// Extracts method/path/client, generates request ID, times the request,
203/// records to the RequestLogger, and injects response headers.
204///
205/// This is designed to be used with `axum::middleware::from_fn` in the
206/// router setup. The ServerState access is done via the shared state
207/// that axum provides.
208pub async fn request_middleware_fn(
209    state: axum::extract::State<Arc<Mutex<crate::axon_server::ServerState>>>,
210    request: Request<Body>,
211    next: Next,
212) -> Response {
213    let start = Instant::now();
214
215    // Extract request info before passing to handler
216    let method = request.method().to_string();
217    let path = request.uri().path().to_string();
218    let client_key = client_key_from_headers(request.headers());
219
220    // Read config and generate ID
221    let (enabled, slow_threshold_ms, inject_id, inject_time, request_id) = {
222        let s = state.lock().unwrap();
223        let cfg = &s.middleware_config;
224        let id = s.request_id_gen.next_id();
225        (cfg.enabled, cfg.slow_threshold_ms, cfg.inject_request_id, cfg.inject_response_time, id)
226    };
227
228    // Call the actual handler
229    let mut response = next.run(request).await;
230
231    if !enabled {
232        return response;
233    }
234
235    // Compute latency
236    let elapsed = start.elapsed();
237    let _latency_us = elapsed.as_micros() as u64;
238    let latency_ms = elapsed.as_millis() as u64;
239    let status = response.status().as_u16();
240    let _slow = slow_threshold_ms > 0 && latency_ms >= slow_threshold_ms;
241
242    // Record to request logger
243    {
244        let mut s = state.lock().unwrap();
245        s.request_logger.record(&method, &path, status, elapsed, &client_key);
246    }
247
248    // Inject response headers
249    if inject_id {
250        if let Ok(val) = HeaderValue::from_str(&request_id) {
251            response.headers_mut().insert("x-request-id", val);
252        }
253    }
254    if inject_time {
255        if let Ok(val) = HeaderValue::from_str(&format!("{}ms", latency_ms)) {
256            response.headers_mut().insert("x-response-time", val);
257        }
258    }
259
260    response
261}
262
263// ── Stats ───────────────────────────────────────────────────────────────
264
265/// Middleware statistics snapshot.
266#[derive(Debug, Clone, Serialize)]
267pub struct MiddlewareStats {
268    /// Total requests processed by the middleware.
269    pub total_requests: u64,
270    /// Current configuration.
271    pub config: MiddlewareConfig,
272}
273
274// ── Tests ────────────────────────────────────────────────────────────────
275
#[cfg(test)]
mod tests {
    use super::*;

    // Bindings avoid the name `gen`: it is a reserved keyword from Rust
    // edition 2024 onward and would stop these tests compiling there.

    #[test]
    fn request_id_generator_sequential() {
        let generator = RequestIdGenerator::new();
        assert_eq!(generator.next_id(), "axr-0");
        assert_eq!(generator.next_id(), "axr-1");
        assert_eq!(generator.next_id(), "axr-2");
        assert_eq!(generator.count(), 3);
    }

    #[test]
    fn request_id_generator_custom_prefix() {
        let generator = RequestIdGenerator::with_prefix("req");
        assert_eq!(generator.next_id(), "req-0");
        assert_eq!(generator.next_id(), "req-1");
    }

    #[test]
    fn request_id_generator_default() {
        let generator = RequestIdGenerator::default();
        assert_eq!(generator.next_id(), "axr-0");
    }

    #[test]
    fn request_id_generator_unique_across_threads() {
        const THREADS: usize = 4;
        const PER_THREAD: usize = 250;

        let generator = RequestIdGenerator::new();
        let ids: Vec<String> = std::thread::scope(|scope| {
            let mut workers = Vec::with_capacity(THREADS);
            for _ in 0..THREADS {
                workers.push(scope.spawn(|| {
                    (0..PER_THREAD)
                        .map(|_| generator.next_id())
                        .collect::<Vec<_>>()
                }));
            }
            workers
                .into_iter()
                .flat_map(|worker| worker.join().expect("id worker panicked"))
                .collect()
        });

        let unique: std::collections::HashSet<&String> = ids.iter().collect();
        assert_eq!(ids.len(), THREADS * PER_THREAD);
        assert_eq!(unique.len(), ids.len());
        assert_eq!(generator.count(), (THREADS * PER_THREAD) as u64);
    }

    #[test]
    fn default_config() {
        let cfg = MiddlewareConfig::default();
        assert!(cfg.enabled);
        assert_eq!(cfg.slow_threshold_ms, 5000);
        assert!(cfg.inject_request_id);
        assert!(cfg.inject_response_time);
    }

    #[test]
    fn disabled_config() {
        let cfg = MiddlewareConfig::disabled();
        assert!(!cfg.enabled);
        assert_eq!(cfg.slow_threshold_ms, 0);
        assert!(!cfg.inject_request_id);
        assert!(!cfg.inject_response_time);
    }

    #[test]
    fn config_serializable() {
        let cfg = MiddlewareConfig::default();
        let json = serde_json::to_value(&cfg).unwrap();
        assert_eq!(json["enabled"], true);
        assert_eq!(json["slow_threshold_ms"], 5000);
        assert_eq!(json["inject_request_id"], true);
        assert_eq!(json["inject_response_time"], true);
    }

    #[test]
    fn config_deserializable() {
        let json = serde_json::json!({
            "enabled": false,
            "slow_threshold_ms": 1000,
            "inject_request_id": false,
            "inject_response_time": true,
        });
        let cfg: MiddlewareConfig = serde_json::from_value(json).unwrap();
        assert!(!cfg.enabled);
        assert_eq!(cfg.slow_threshold_ms, 1000);
        assert!(!cfg.inject_request_id);
        assert!(cfg.inject_response_time);
    }

    #[test]
    fn middleware_update_accepts_partial_json() {
        // The PUT endpoint receives sparse bodies; absent fields must become None.
        let update: MiddlewareUpdate =
            serde_json::from_value(serde_json::json!({ "slow_threshold_ms": 250 })).unwrap();
        assert_eq!(update.enabled, None);
        assert_eq!(update.slow_threshold_ms, Some(250));
        assert_eq!(update.inject_request_id, None);
        assert_eq!(update.inject_response_time, None);
    }

    #[test]
    fn apply_update_changes_tracked() {
        let mut cfg = MiddlewareConfig::default();
        let update = MiddlewareUpdate {
            enabled: None,
            slow_threshold_ms: Some(2000),
            inject_request_id: Some(false),
            inject_response_time: None,
        };
        let changes = apply_update(&mut cfg, &update);
        assert_eq!(changes, ["slow_threshold_ms", "inject_request_id"]);
        assert_eq!(cfg.slow_threshold_ms, 2000);
        assert!(!cfg.inject_request_id);
    }

    #[test]
    fn apply_update_no_op_when_same() {
        let mut cfg = MiddlewareConfig::default();
        let update = MiddlewareUpdate {
            enabled: Some(true),
            slow_threshold_ms: Some(5000),
            inject_request_id: Some(true),
            inject_response_time: Some(true),
        };
        let changes = apply_update(&mut cfg, &update);
        assert!(changes.is_empty());
    }

    #[test]
    fn apply_update_empty_is_noop() {
        let mut cfg = MiddlewareConfig::default();
        let update = MiddlewareUpdate {
            enabled: None,
            slow_threshold_ms: None,
            inject_request_id: None,
            inject_response_time: None,
        };
        assert!(apply_update(&mut cfg, &update).is_empty());
        assert!(cfg.enabled);
        assert_eq!(cfg.slow_threshold_ms, 5000);
        assert!(cfg.inject_request_id);
        assert!(cfg.inject_response_time);
    }

    #[test]
    fn apply_update_all_fields() {
        let mut cfg = MiddlewareConfig::default();
        let update = MiddlewareUpdate {
            enabled: Some(false),
            slow_threshold_ms: Some(100),
            inject_request_id: Some(false),
            inject_response_time: Some(false),
        };
        let changes = apply_update(&mut cfg, &update);
        assert_eq!(
            changes,
            [
                "enabled",
                "slow_threshold_ms",
                "inject_request_id",
                "inject_response_time"
            ]
        );
        assert!(!cfg.enabled);
        assert_eq!(cfg.slow_threshold_ms, 100);
        assert!(!cfg.inject_request_id);
        assert!(!cfg.inject_response_time);
    }

    #[test]
    fn request_meta_serializable() {
        let meta = RequestMeta {
            request_id: "axr-42".to_string(),
            method: "POST".to_string(),
            path: "/v1/deploy".to_string(),
            status: 200,
            latency_us: 1500,
            latency_ms: 1,
            client_key: "token_abc".to_string(),
            slow: false,
        };
        let json = serde_json::to_value(&meta).unwrap();
        assert_eq!(json["request_id"], "axr-42");
        assert_eq!(json["method"], "POST");
        assert_eq!(json["path"], "/v1/deploy");
        assert_eq!(json["status"], 200);
        assert_eq!(json["latency_us"], 1500);
        assert_eq!(json["slow"], false);
    }

    #[test]
    fn request_meta_slow_flag() {
        let meta = RequestMeta {
            request_id: "axr-99".to_string(),
            method: "GET".to_string(),
            path: "/v1/health".to_string(),
            status: 200,
            latency_us: 6_000_000,
            latency_ms: 6000,
            client_key: "anonymous".to_string(),
            slow: true,
        };
        let json = serde_json::to_value(&meta).unwrap();
        assert_eq!(json["slow"], true);
        assert_eq!(json["latency_ms"], 6000);
    }

    #[test]
    fn middleware_stats_serializable() {
        let stats = MiddlewareStats {
            total_requests: 42,
            config: MiddlewareConfig::default(),
        };
        let json = serde_json::to_value(&stats).unwrap();
        assert_eq!(json["total_requests"], 42);
        assert_eq!(json["config"]["enabled"], true);
    }

    #[test]
    fn client_key_extraction() {
        let mut headers = axum::http::HeaderMap::new();
        assert_eq!(client_key_from_headers(&headers), "anonymous");

        headers.insert("authorization", HeaderValue::from_static("Bearer token123"));
        assert_eq!(client_key_from_headers(&headers), "Bearer token123");

        // Valid header bytes that are not visible ASCII cannot be read via
        // `to_str`, so they fall back to "anonymous".
        headers.insert(
            "authorization",
            HeaderValue::from_bytes(b"caf\xc3\xa9").unwrap(),
        );
        assert_eq!(client_key_from_headers(&headers), "anonymous");
    }
}
449}