1use std::sync::atomic::{AtomicU64, Ordering};
18use std::sync::{Arc, Mutex};
19use std::time::Instant;
20
21use axum::body::Body;
22use axum::extract::Request;
23use axum::http::HeaderValue;
24use axum::middleware::Next;
25use axum::response::Response;
26use serde::{Deserialize, Serialize};
27
/// Hands out sequentially numbered request IDs of the form `"{prefix}-{n}"`.
///
/// IDs are unique per generator instance. The counter is atomic, so a shared
/// reference is enough to allocate IDs from several threads at once.
#[derive(Debug)]
pub struct RequestIdGenerator {
    /// Next sequence number to hand out; equivalently, the number of IDs issued so far.
    counter: AtomicU64,
    /// Text placed before the `-` in every generated ID.
    prefix: String,
}
38
39impl RequestIdGenerator {
40 pub fn new() -> Self {
42 RequestIdGenerator {
43 counter: AtomicU64::new(0),
44 prefix: "axr".to_string(),
45 }
46 }
47
48 pub fn with_prefix(prefix: &str) -> Self {
50 RequestIdGenerator {
51 counter: AtomicU64::new(0),
52 prefix: prefix.to_string(),
53 }
54 }
55
56 pub fn next_id(&self) -> String {
58 let n = self.counter.fetch_add(1, Ordering::Relaxed);
59 format!("{}-{}", self.prefix, n)
60 }
61
62 pub fn count(&self) -> u64 {
64 self.counter.load(Ordering::Relaxed)
65 }
66}
67
68impl Default for RequestIdGenerator {
69 fn default() -> Self {
70 Self::new()
71 }
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct MiddlewareConfig {
79 pub enabled: bool,
81 pub slow_threshold_ms: u64,
84 pub inject_request_id: bool,
86 pub inject_response_time: bool,
88}
89
90impl Default for MiddlewareConfig {
91 fn default() -> Self {
92 MiddlewareConfig {
93 enabled: true,
94 slow_threshold_ms: 5000,
95 inject_request_id: true,
96 inject_response_time: true,
97 }
98 }
99}
100
101impl MiddlewareConfig {
102 pub fn disabled() -> Self {
104 MiddlewareConfig {
105 enabled: false,
106 slow_threshold_ms: 0,
107 inject_request_id: false,
108 inject_response_time: false,
109 }
110 }
111}
112
113#[derive(Debug, Clone, Deserialize)]
117pub struct MiddlewareUpdate {
118 pub enabled: Option<bool>,
119 pub slow_threshold_ms: Option<u64>,
120 pub inject_request_id: Option<bool>,
121 pub inject_response_time: Option<bool>,
122}
123
124pub fn apply_update(config: &mut MiddlewareConfig, update: &MiddlewareUpdate) -> Vec<String> {
126 let mut changes = Vec::new();
127
128 if let Some(enabled) = update.enabled {
129 if enabled != config.enabled {
130 config.enabled = enabled;
131 changes.push("enabled".to_string());
132 }
133 }
134 if let Some(threshold) = update.slow_threshold_ms {
135 if threshold != config.slow_threshold_ms {
136 config.slow_threshold_ms = threshold;
137 changes.push("slow_threshold_ms".to_string());
138 }
139 }
140 if let Some(inject_id) = update.inject_request_id {
141 if inject_id != config.inject_request_id {
142 config.inject_request_id = inject_id;
143 changes.push("inject_request_id".to_string());
144 }
145 }
146 if let Some(inject_time) = update.inject_response_time {
147 if inject_time != config.inject_response_time {
148 config.inject_response_time = inject_time;
149 changes.push("inject_response_time".to_string());
150 }
151 }
152
153 changes
154}
155
/// Serializable per-request metadata record.
///
/// NOTE(review): nothing visible in this file constructs `RequestMeta` outside
/// the tests; `request_middleware_fn` reports through `request_logger.record`
/// instead. Confirm whether other code still produces it.
#[derive(Debug, Clone, Serialize)]
pub struct RequestMeta {
    /// ID assigned to the request (e.g. `axr-42`).
    pub request_id: String,
    /// HTTP method as text, e.g. `GET`.
    pub method: String,
    /// Request path.
    pub path: String,
    /// HTTP status code of the response.
    pub status: u16,
    /// Measured latency in microseconds.
    pub latency_us: u64,
    /// Measured latency in milliseconds (presumably the same measurement as
    /// `latency_us`, in a coarser unit — confirm with the producer).
    pub latency_ms: u64,
    /// Client attribution key; presumably the output of
    /// `client_key_from_headers` (`"anonymous"` without credentials).
    pub client_key: String,
    /// Whether the request counted as slow; presumably compared against
    /// `MiddlewareConfig::slow_threshold_ms` — confirm with the producer.
    pub slow: bool,
}
178
/// Bundle of the ID generator, middleware config, and application state,
/// generic over the application state type `S`.
///
/// NOTE(review): `request_middleware_fn` in this file does not use this type;
/// it reads the ID generator and config from `crate::axon_server::ServerState`.
pub struct MiddlewareState<S> {
    /// Source of request IDs.
    pub id_generator: RequestIdGenerator,
    /// Middleware configuration, shared and mutable behind a mutex.
    pub config: Arc<Mutex<MiddlewareConfig>>,
    /// Application state of caller-chosen type `S`, shared behind a mutex.
    pub server_state: Arc<Mutex<S>>,
}
187
188fn client_key_from_headers(headers: &axum::http::HeaderMap) -> String {
191 headers
192 .get("authorization")
193 .and_then(|v| v.to_str().ok())
194 .map(|v| v.to_string())
195 .unwrap_or_else(|| "anonymous".to_string())
196}
197
198pub async fn request_middleware_fn(
209 state: axum::extract::State<Arc<Mutex<crate::axon_server::ServerState>>>,
210 request: Request<Body>,
211 next: Next,
212) -> Response {
213 let start = Instant::now();
214
215 let method = request.method().to_string();
217 let path = request.uri().path().to_string();
218 let client_key = client_key_from_headers(request.headers());
219
220 let (enabled, slow_threshold_ms, inject_id, inject_time, request_id) = {
222 let s = state.lock().unwrap();
223 let cfg = &s.middleware_config;
224 let id = s.request_id_gen.next_id();
225 (cfg.enabled, cfg.slow_threshold_ms, cfg.inject_request_id, cfg.inject_response_time, id)
226 };
227
228 let mut response = next.run(request).await;
230
231 if !enabled {
232 return response;
233 }
234
235 let elapsed = start.elapsed();
237 let _latency_us = elapsed.as_micros() as u64;
238 let latency_ms = elapsed.as_millis() as u64;
239 let status = response.status().as_u16();
240 let _slow = slow_threshold_ms > 0 && latency_ms >= slow_threshold_ms;
241
242 {
244 let mut s = state.lock().unwrap();
245 s.request_logger.record(&method, &path, status, elapsed, &client_key);
246 }
247
248 if inject_id {
250 if let Ok(val) = HeaderValue::from_str(&request_id) {
251 response.headers_mut().insert("x-request-id", val);
252 }
253 }
254 if inject_time {
255 if let Ok(val) = HeaderValue::from_str(&format!("{}ms", latency_ms)) {
256 response.headers_mut().insert("x-response-time", val);
257 }
258 }
259
260 response
261}
262
/// Serializable snapshot of the middleware's request count together with its
/// configuration.
#[derive(Debug, Clone, Serialize)]
pub struct MiddlewareStats {
    /// Total number of requests.
    /// NOTE(review): presumably taken from `RequestIdGenerator::count()` —
    /// confirm at the construction site.
    pub total_requests: u64,
    /// Middleware configuration reported alongside the count.
    pub config: MiddlewareConfig,
}
273
// Unit tests for the pure parts of this module: request-ID generation, config
// defaults and serde round-trips, update application, and header-derived
// client keys. `request_middleware_fn` itself is not exercised here; it needs
// a `ServerState` and an axum `Next`.
#[cfg(test)]
mod tests {
    use super::*;

    // IDs count up from zero and `count` reflects how many were issued.
    #[test]
    fn request_id_generator_sequential() {
        let gen = RequestIdGenerator::new();
        assert_eq!(gen.next_id(), "axr-0");
        assert_eq!(gen.next_id(), "axr-1");
        assert_eq!(gen.next_id(), "axr-2");
        assert_eq!(gen.count(), 3);
    }

    #[test]
    fn request_id_generator_custom_prefix() {
        let gen = RequestIdGenerator::with_prefix("req");
        assert_eq!(gen.next_id(), "req-0");
        assert_eq!(gen.next_id(), "req-1");
    }

    // `Default` must behave like `new` (the `axr` prefix).
    #[test]
    fn request_id_generator_default() {
        let gen = RequestIdGenerator::default();
        assert_eq!(gen.next_id(), "axr-0");
    }

    #[test]
    fn default_config() {
        let cfg = MiddlewareConfig::default();
        assert!(cfg.enabled);
        assert_eq!(cfg.slow_threshold_ms, 5000);
        assert!(cfg.inject_request_id);
        assert!(cfg.inject_response_time);
    }

    #[test]
    fn disabled_config() {
        let cfg = MiddlewareConfig::disabled();
        assert!(!cfg.enabled);
        assert_eq!(cfg.slow_threshold_ms, 0);
        assert!(!cfg.inject_request_id);
        assert!(!cfg.inject_response_time);
    }

    // Pins the JSON field names of the serialized config.
    #[test]
    fn config_serializable() {
        let cfg = MiddlewareConfig::default();
        let json = serde_json::to_value(&cfg).unwrap();
        assert_eq!(json["enabled"], true);
        assert_eq!(json["slow_threshold_ms"], 5000);
        assert_eq!(json["inject_request_id"], true);
        assert_eq!(json["inject_response_time"], true);
    }

    #[test]
    fn config_deserializable() {
        let json = serde_json::json!({
            "enabled": false,
            "slow_threshold_ms": 1000,
            "inject_request_id": false,
            "inject_response_time": true,
        });
        let cfg: MiddlewareConfig = serde_json::from_value(json).unwrap();
        assert!(!cfg.enabled);
        assert_eq!(cfg.slow_threshold_ms, 1000);
        assert!(!cfg.inject_request_id);
        assert!(cfg.inject_response_time);
    }

    // `None` fields are skipped; only differing `Some` values are reported.
    #[test]
    fn apply_update_changes_tracked() {
        let mut cfg = MiddlewareConfig::default();
        let update = MiddlewareUpdate {
            enabled: None,
            slow_threshold_ms: Some(2000),
            inject_request_id: Some(false),
            inject_response_time: None,
        };
        let changes = apply_update(&mut cfg, &update);
        assert_eq!(changes.len(), 2);
        assert!(changes.contains(&"slow_threshold_ms".to_string()));
        assert!(changes.contains(&"inject_request_id".to_string()));
        assert_eq!(cfg.slow_threshold_ms, 2000);
        assert!(!cfg.inject_request_id);
    }

    // Supplying the current values must not be reported as a change.
    #[test]
    fn apply_update_no_op_when_same() {
        let mut cfg = MiddlewareConfig::default();
        let update = MiddlewareUpdate {
            enabled: Some(true),
            slow_threshold_ms: Some(5000),
            inject_request_id: Some(true),
            inject_response_time: Some(true),
        };
        let changes = apply_update(&mut cfg, &update);
        assert!(changes.is_empty());
    }

    #[test]
    fn apply_update_all_fields() {
        let mut cfg = MiddlewareConfig::default();
        let update = MiddlewareUpdate {
            enabled: Some(false),
            slow_threshold_ms: Some(100),
            inject_request_id: Some(false),
            inject_response_time: Some(false),
        };
        let changes = apply_update(&mut cfg, &update);
        assert_eq!(changes.len(), 4);
        assert!(!cfg.enabled);
        assert_eq!(cfg.slow_threshold_ms, 100);
        assert!(!cfg.inject_request_id);
        assert!(!cfg.inject_response_time);
    }

    #[test]
    fn request_meta_serializable() {
        let meta = RequestMeta {
            request_id: "axr-42".to_string(),
            method: "POST".to_string(),
            path: "/v1/deploy".to_string(),
            status: 200,
            latency_us: 1500,
            latency_ms: 1,
            client_key: "token_abc".to_string(),
            slow: false,
        };
        let json = serde_json::to_value(&meta).unwrap();
        assert_eq!(json["request_id"], "axr-42");
        assert_eq!(json["method"], "POST");
        assert_eq!(json["path"], "/v1/deploy");
        assert_eq!(json["status"], 200);
        assert_eq!(json["latency_us"], 1500);
        assert_eq!(json["slow"], false);
    }

    // 6_000_000 us == 6000 ms: one value above the default 5000 ms threshold.
    #[test]
    fn request_meta_slow_flag() {
        let meta = RequestMeta {
            request_id: "axr-99".to_string(),
            method: "GET".to_string(),
            path: "/v1/health".to_string(),
            status: 200,
            latency_us: 6_000_000,
            latency_ms: 6000,
            client_key: "anonymous".to_string(),
            slow: true,
        };
        let json = serde_json::to_value(&meta).unwrap();
        assert_eq!(json["slow"], true);
        assert_eq!(json["latency_ms"], 6000);
    }

    // The config is nested under `config` in the serialized stats.
    #[test]
    fn middleware_stats_serializable() {
        let stats = MiddlewareStats {
            total_requests: 42,
            config: MiddlewareConfig::default(),
        };
        let json = serde_json::to_value(&stats).unwrap();
        assert_eq!(json["total_requests"], 42);
        assert_eq!(json["config"]["enabled"], true);
    }

    // Missing header falls back to "anonymous"; a present one is returned verbatim.
    #[test]
    fn client_key_extraction() {
        let mut headers = axum::http::HeaderMap::new();
        assert_eq!(client_key_from_headers(&headers), "anonymous");

        headers.insert("authorization", HeaderValue::from_static("Bearer token123"));
        assert_eq!(client_key_from_headers(&headers), "Bearer token123");
    }
}