1use std::collections::HashMap;
26use std::sync::{
27 atomic::{AtomicU64, Ordering},
28 Mutex,
29};
30use std::time::{Duration, Instant};
31
/// Per-request metadata captured when a request enters the middleware stack.
///
/// Holds the identifiers and the monotonic start time that the logging
/// helpers in this module read.
#[derive(Debug, Clone)]
pub struct RequestContext {
    /// Identifier used to correlate log lines that belong to one request.
    pub request_id: String,
    /// Caller identity exactly as passed to the constructor.
    pub client_id: String,
    /// Monotonic timestamp taken when the context was created.
    pub started_at: Instant,
    /// Request path exactly as passed to the constructor.
    pub path: String,
    /// Request method, normalised to upper case.
    pub method: String,
}

impl RequestContext {
    /// Creates a context with a freshly generated request ID.
    ///
    /// The ID has the form `req-{unix_ms}-{seq}`, where `seq` comes from a
    /// process-wide counter. A bare millisecond timestamp would collide
    /// whenever two requests arrive in the same millisecond; the counter
    /// keeps IDs unique within the process.
    pub fn new(path: &str, method: &str, client_id: &str) -> Self {
        // Only atomicity of the increment matters for uniqueness, so Relaxed
        // ordering is sufficient.
        static NEXT_SEQ: AtomicU64 = AtomicU64::new(0);

        let ts_ms = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_millis();
        let seq = NEXT_SEQ.fetch_add(1, Ordering::Relaxed);
        Self::with_id(format!("req-{ts_ms}-{seq}"), path, method, client_id)
    }

    /// Creates a context that uses the caller-supplied `request_id`
    /// (for example one propagated from an upstream header).
    ///
    /// `method` is upper-cased; `path` and `client_id` are copied verbatim.
    pub fn with_id(request_id: String, path: &str, method: &str, client_id: &str) -> Self {
        Self {
            request_id,
            client_id: client_id.to_owned(),
            started_at: Instant::now(),
            path: path.to_owned(),
            method: method.to_uppercase(),
        }
    }

    /// Milliseconds elapsed since the context was created.
    ///
    /// Saturates at `u64::MAX` instead of silently truncating the `u128`
    /// millisecond count.
    pub fn elapsed_ms(&self) -> u64 {
        u64::try_from(self.started_at.elapsed().as_millis()).unwrap_or(u64::MAX)
    }

    /// Time elapsed since the context was created.
    pub fn elapsed(&self) -> Duration {
        self.started_at.elapsed()
    }
}
92
/// Thread-safe generator of request identifiers.
///
/// Each ID has the form `{prefix}-{unix_ms}-{counter}`. The counter is
/// private to this generator and increases by one per call, so two IDs from
/// the same generator never share a counter value.
pub struct RequestIdGen {
    /// Number of IDs handed out so far.
    counter: AtomicU64,
    /// Leading component of every generated ID.
    prefix: String,
}

impl RequestIdGen {
    /// Builds a generator whose IDs start with `prefix`; the counter starts at 0.
    pub fn new(prefix: &str) -> Self {
        let owned_prefix = String::from(prefix);
        RequestIdGen {
            prefix: owned_prefix,
            counter: AtomicU64::new(0),
        }
    }

    /// Produces the next identifier.
    ///
    /// If the system clock reads earlier than the Unix epoch, the timestamp
    /// component is `0`.
    pub fn next(&self) -> String {
        let since_epoch =
            std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH);
        let millis = match since_epoch {
            Ok(elapsed) => elapsed.as_millis(),
            Err(_) => 0,
        };
        let seq = self.counter.fetch_add(1, Ordering::Relaxed);
        format!("{}-{}-{}", self.prefix, millis, seq)
    }
}
127
128pub struct RequestLogger {
135 pub log_bodies: bool,
137 pub max_body_log_bytes: usize,
139}
140
141impl RequestLogger {
142 pub fn new() -> Self {
144 Self {
145 log_bodies: false,
146 max_body_log_bytes: 0,
147 }
148 }
149
150 pub fn with_body_logging(max_bytes: usize) -> Self {
152 Self {
153 log_bodies: true,
154 max_body_log_bytes: max_bytes,
155 }
156 }
157
158 pub fn log_request(&self, ctx: &RequestContext) {
160 let line = Self::format_request_line(ctx);
161 tracing::info!(target: "oxibonsai::middleware", "{line}");
162 }
163
164 pub fn log_response(&self, ctx: &RequestContext, status: u16, body_bytes: usize) {
166 let elapsed_ms = ctx.elapsed_ms();
167 let line = Self::format_response_line(ctx, status, elapsed_ms);
168 if self.log_bodies && body_bytes > 0 {
169 tracing::info!(
170 target: "oxibonsai::middleware",
171 "{line} body_bytes={body_bytes}"
172 );
173 } else {
174 tracing::info!(target: "oxibonsai::middleware", "{line}");
175 }
176 }
177
178 pub fn format_request_line(ctx: &RequestContext) -> String {
182 format!(
183 "[{}] {} {} from {}",
184 ctx.request_id, ctx.method, ctx.path, ctx.client_id
185 )
186 }
187
188 pub fn format_response_line(ctx: &RequestContext, status: u16, elapsed_ms: u64) -> String {
192 format!("[{}] {} in {}ms", ctx.request_id, status, elapsed_ms)
193 }
194}
195
196impl Default for RequestLogger {
197 fn default() -> Self {
198 Self::new()
199 }
200}
201
/// Cross-Origin Resource Sharing (CORS) policy for HTTP responses.
///
/// An `allowed_origins` entry equal to `"*"` allows every origin.
#[derive(Debug, Clone)]
pub struct CorsConfig {
    /// Origins allowed to make cross-origin requests; `"*"` matches any origin.
    pub allowed_origins: Vec<String>,
    /// Methods advertised via `Access-Control-Allow-Methods`.
    pub allowed_methods: Vec<String>,
    /// Request headers advertised via `Access-Control-Allow-Headers`.
    pub allowed_headers: Vec<String>,
    /// Value of `Access-Control-Max-Age`, in seconds.
    pub max_age_secs: u64,
    /// When `true`, `Access-Control-Allow-Credentials: true` is emitted.
    pub allow_credentials: bool,
}

impl Default for CorsConfig {
    /// Permissive defaults: any origin; `GET`, `POST`, `OPTIONS`;
    /// `Content-Type` and `Authorization` headers; a one-hour max age;
    /// no credentials.
    fn default() -> Self {
        Self {
            allowed_origins: vec!["*".to_string()],
            allowed_methods: vec!["GET".to_string(), "POST".to_string(), "OPTIONS".to_string()],
            allowed_headers: vec!["Content-Type".to_string(), "Authorization".to_string()],
            max_age_secs: 3600,
            allow_credentials: false,
        }
    }
}

impl CorsConfig {
    /// Returns `true` if a `"*"` wildcard entry is configured.
    fn allows_any_origin(&self) -> bool {
        self.allowed_origins.iter().any(|o| o == "*")
    }

    /// Returns `true` if `origin` exactly matches a configured origin or a
    /// wildcard entry is present.
    pub fn is_origin_allowed(&self, origin: &str) -> bool {
        self.allowed_origins.iter().any(|o| o == "*" || o == origin)
    }

    /// Builds CORS response headers without knowing the requesting origin.
    ///
    /// `Access-Control-Allow-Origin` is `"*"` when a wildcard is configured;
    /// otherwise it is all configured origins joined with `", "`.
    ///
    /// NOTE: browsers accept only a single origin (or `"*"`) in that header,
    /// and they reject `"*"` on credentialed requests. This output is
    /// therefore only spec-compliant for a wildcard without credentials, or
    /// for exactly one configured origin. Prefer
    /// [`CorsConfig::access_control_headers_for`], which echoes the request's
    /// origin.
    pub fn access_control_headers(&self) -> Vec<(String, String)> {
        let origin_value = if self.allows_any_origin() {
            "*".to_owned()
        } else {
            self.allowed_origins.join(", ")
        };

        let mut headers = Vec::with_capacity(5);
        headers.push(("Access-Control-Allow-Origin".to_owned(), origin_value));
        self.push_policy_headers(&mut headers);
        headers
    }

    /// Builds spec-compliant CORS response headers for a request whose
    /// `Origin` header is `origin`.
    ///
    /// - Wildcard without credentials: `Access-Control-Allow-Origin: *`.
    /// - Otherwise an allowed, non-empty `origin` is echoed back. `Vary: Origin`
    ///   is always added so shared caches key the response on the origin.
    ///   A disallowed or empty origin gets no `Access-Control-Allow-Origin`
    ///   header, which makes the browser block the cross-origin read.
    ///
    /// Methods, headers, max-age and (when enabled) credentials headers are
    /// then appended, with the same values [`CorsConfig::access_control_headers`]
    /// uses.
    pub fn access_control_headers_for(&self, origin: &str) -> Vec<(String, String)> {
        let mut headers = Vec::with_capacity(6);

        if self.allows_any_origin() && !self.allow_credentials {
            headers.push(("Access-Control-Allow-Origin".to_owned(), "*".to_owned()));
        } else {
            // Credentialed requests may not use "*", and a list is never
            // valid, so reflect the single matching origin instead.
            if !origin.is_empty() && self.is_origin_allowed(origin) {
                headers.push(("Access-Control-Allow-Origin".to_owned(), origin.to_owned()));
            }
            headers.push(("Vary".to_owned(), "Origin".to_owned()));
        }

        self.push_policy_headers(&mut headers);
        headers
    }

    /// Appends the origin-independent headers: allowed methods, allowed
    /// headers, max age and, when enabled, credentials.
    fn push_policy_headers(&self, headers: &mut Vec<(String, String)>) {
        headers.push((
            "Access-Control-Allow-Methods".to_owned(),
            self.allowed_methods.join(", "),
        ));
        headers.push((
            "Access-Control-Allow-Headers".to_owned(),
            self.allowed_headers.join(", "),
        ));
        headers.push((
            "Access-Control-Max-Age".to_owned(),
            self.max_age_secs.to_string(),
        ));
        if self.allow_credentials {
            headers.push((
                "Access-Control-Allow-Credentials".to_owned(),
                "true".to_owned(),
            ));
        }
    }
}
279
/// A response remembered for one idempotency key.
struct CachedResponse {
    /// HTTP status code of the stored response.
    status: u16,
    /// Raw response body bytes.
    body: Vec<u8>,
    /// When the entry was (last) inserted; used for TTL checks.
    created_at: Instant,
}

/// Bounded, TTL-based store of responses keyed by idempotency key.
///
/// Entries older than `ttl` are never returned by [`IdempotencyCache::get`].
/// They are purged by [`IdempotencyCache::evict_expired`], or by an insert of
/// a new key that finds the cache full. All state sits behind an internal
/// `Mutex`, so the methods take `&self`.
pub struct IdempotencyCache {
    cache: Mutex<HashMap<String, CachedResponse>>,
    max_entries: usize,
    ttl: Duration,
}

impl IdempotencyCache {
    /// Creates an empty cache holding at most `max_entries` entries, each
    /// valid for `ttl` after insertion.
    pub fn new(max_entries: usize, ttl: Duration) -> Self {
        Self {
            cache: Mutex::new(HashMap::new()),
            max_entries,
            ttl,
        }
    }

    /// Locks the map, recovering the guard if a previous holder panicked.
    ///
    /// Every critical section here is a single map operation, so a poisoned
    /// lock still guards a usable map. Recovering prevents one panicking
    /// thread from turning every later cache access into a panic.
    fn lock(&self) -> std::sync::MutexGuard<'_, HashMap<String, CachedResponse>> {
        self.cache
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Returns a copy of the cached `(status, body)` for `key`, or `None` if
    /// the key is absent or its entry is older than the TTL.
    pub fn get(&self, key: &str) -> Option<(u16, Vec<u8>)> {
        let cache = self.lock();
        cache
            .get(key)
            .filter(|entry| entry.created_at.elapsed() < self.ttl)
            .map(|entry| (entry.status, entry.body.clone()))
    }

    /// Stores (or replaces) the response for `key` and restarts its TTL.
    ///
    /// Replacing an existing key never grows the map, so it always succeeds.
    /// A new key arriving at capacity first triggers a purge of expired
    /// entries. If the cache is still full of live entries, the new entry is
    /// dropped silently.
    pub fn insert(&self, key: &str, status: u16, body: Vec<u8>) {
        let mut cache = self.lock();

        if !cache.contains_key(key) && cache.len() >= self.max_entries {
            let ttl = self.ttl;
            cache.retain(|_, v| v.created_at.elapsed() < ttl);
            if cache.len() >= self.max_entries {
                return;
            }
        }

        cache.insert(
            key.to_owned(),
            CachedResponse {
                status,
                body,
                created_at: Instant::now(),
            },
        );
    }

    /// Removes every entry older than the TTL.
    pub fn evict_expired(&self) {
        let ttl = self.ttl;
        self.lock().retain(|_, v| v.created_at.elapsed() < ttl);
    }

    /// Number of stored entries, including expired ones not yet evicted.
    pub fn len(&self) -> usize {
        self.lock().len()
    }

    /// Returns `true` if no entries (live or expired) are stored.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
370
/// Unit tests for the request context, ID generator, logger formatting,
/// CORS configuration and idempotency cache.
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_request_context_elapsed() {
        let ctx = RequestContext::new("/health", "GET", "10.0.0.1");
        // A freshly created context should report a tiny elapsed time.
        assert!(
            ctx.elapsed_ms() < 500,
            "elapsed should be <500ms at creation"
        );
        assert!(ctx.elapsed() < Duration::from_millis(500));
    }

    #[test]
    fn test_request_id_gen_unique() {
        // `id_gen` rather than `gen`: `gen` is a reserved keyword in Rust 2024.
        let id_gen = RequestIdGen::new("test");
        let ids: Vec<String> = (0..100).map(|_| id_gen.next()).collect();
        let unique: std::collections::HashSet<&String> = ids.iter().collect();
        assert_eq!(unique.len(), ids.len(), "all generated IDs must be unique");
    }

    #[test]
    fn test_request_id_gen_prefix() {
        let id_gen = RequestIdGen::new("oxibonsai");
        let id = id_gen.next();
        assert!(
            id.starts_with("oxibonsai-"),
            "ID should start with prefix; got {id}"
        );
    }

    #[test]
    fn test_request_logger_format_request_line() {
        // The method is given in lower case to check upper-case normalisation.
        let mut ctx = RequestContext::new("/v1/chat/completions", "post", "1.2.3.4");
        ctx.request_id = "req-42".to_owned();
        let line = RequestLogger::format_request_line(&ctx);
        assert_eq!(line, "[req-42] POST /v1/chat/completions from 1.2.3.4");
    }

    #[test]
    fn test_request_logger_format_response_line() {
        let mut ctx = RequestContext::new("/health", "GET", "127.0.0.1");
        ctx.request_id = "req-99".to_owned();
        let line = RequestLogger::format_response_line(&ctx, 200, 15);
        assert_eq!(line, "[req-99] 200 in 15ms");
    }

    #[test]
    fn test_cors_config_default_allows_all() {
        let cors = CorsConfig::default();
        assert!(cors.is_origin_allowed("https://example.com"));
        assert!(cors.is_origin_allowed("null"));
        assert!(cors.is_origin_allowed("*"));
    }

    #[test]
    fn test_cors_config_specific_origin() {
        let cors = CorsConfig {
            allowed_origins: vec!["https://app.example.com".to_string()],
            ..Default::default()
        };
        assert!(cors.is_origin_allowed("https://app.example.com"));
        assert!(!cors.is_origin_allowed("https://evil.example.com"));
    }

    #[test]
    fn test_cors_access_control_headers() {
        let cors = CorsConfig::default();
        let headers = cors.access_control_headers();

        let has_origin = headers
            .iter()
            .any(|(k, v)| k == "Access-Control-Allow-Origin" && v == "*");
        assert!(has_origin, "should have wildcard Allow-Origin header");

        let has_methods = headers
            .iter()
            .any(|(k, _)| k == "Access-Control-Allow-Methods");
        assert!(has_methods);

        let has_creds = headers
            .iter()
            .any(|(k, _)| k == "Access-Control-Allow-Credentials");
        assert!(
            !has_creds,
            "should not include credentials header by default"
        );
    }

    #[test]
    fn test_idempotency_cache_insert_and_get() {
        let cache = IdempotencyCache::new(100, Duration::from_secs(60));
        cache.insert("key-1", 200, b"hello".to_vec());
        let result = cache.get("key-1");
        assert_eq!(result, Some((200, b"hello".to_vec())));
    }

    #[test]
    fn test_idempotency_cache_miss() {
        let cache = IdempotencyCache::new(100, Duration::from_secs(60));
        assert!(cache.get("nonexistent-key").is_none());
    }

    #[test]
    fn test_idempotency_cache_evicts_expired() {
        // The sleep is twice the TTL, so the entry is guaranteed to be stale.
        let cache = IdempotencyCache::new(100, Duration::from_millis(10));
        cache.insert("exp-key", 200, vec![]);
        assert_eq!(cache.len(), 1);

        thread::sleep(Duration::from_millis(20));
        cache.evict_expired();
        assert_eq!(cache.len(), 0, "expired entry should have been evicted");
    }

    #[test]
    fn test_idempotency_cache_expired_returns_none() {
        let cache = IdempotencyCache::new(100, Duration::from_millis(10));
        cache.insert("ttl-key", 201, b"data".to_vec());
        thread::sleep(Duration::from_millis(20));
        // `get` must ignore stale entries even if they have not been evicted.
        assert!(
            cache.get("ttl-key").is_none(),
            "stale cache entry must not be returned"
        );
    }
}