Skip to main content

oxibonsai_runtime/
middleware.rs

1//! Request middleware: context injection, logging, CORS, and idempotency caching.
2//!
3//! This module provides building blocks for production-grade HTTP middleware:
4//!
5//! - [`RequestContext`] — per-request metadata injected at the entry point
6//! - [`RequestIdGen`] — atomic, monotonically increasing request ID generator
7//! - [`RequestLogger`] — structured request/response logging with optional body capture
8//! - [`CorsConfig`] — configurable CORS policy with header generation helpers
9//! - [`IdempotencyCache`] — idempotency-key cache for safe request deduplication
10//!
11//! # Example
12//!
13//! ```
14//! use oxibonsai_runtime::middleware::{RequestContext, RequestLogger, CorsConfig};
15//!
16//! let ctx = RequestContext::new("/v1/chat/completions", "POST", "10.0.0.1");
17//! let logger = RequestLogger::new();
18//! logger.log_request(&ctx);
19//! logger.log_response(&ctx, 200, 512);
20//!
21//! let cors = CorsConfig::default();
22//! assert!(cors.is_origin_allowed("*"));
23//! ```
24
25use std::collections::HashMap;
26use std::sync::{
27    atomic::{AtomicU64, Ordering},
28    Mutex,
29};
30use std::time::{Duration, Instant};
31
32// ─── RequestContext ──────────────────────────────────────────────────────────
33
/// Per-request context injected by middleware at the entry point.
///
/// Carries metadata needed for logging, tracing, and metrics throughout
/// the request lifetime.
#[derive(Debug, Clone)]
pub struct RequestContext {
    /// Unique request identifier: `"req-{timestamp_ms}-{sequence}"` when built
    /// by [`RequestContext::new`], or whatever the caller handed to
    /// [`RequestContext::with_id`].
    pub request_id: String,
    /// Caller identity — typically an IP address or API key prefix.
    pub client_id: String,
    /// Monotonic instant captured when the context was constructed; the
    /// reference point for [`RequestContext::elapsed`].
    pub started_at: Instant,
    /// Request path (e.g. `"/v1/chat/completions"`).
    pub path: String,
    /// HTTP method, upper-cased by the constructors (e.g. `"POST"`).
    pub method: String,
}

impl RequestContext {
    /// Create a new context with an auto-generated request ID.
    ///
    /// The ID has the form `"req-{timestamp_ms}-{sequence}"`. The sequence is
    /// a process-wide atomic counter, so two contexts created within the same
    /// millisecond still receive distinct IDs. Callers that want a custom
    /// prefix can use a [`RequestIdGen`] together with
    /// [`RequestContext::with_id`].
    pub fn new(path: &str, method: &str, client_id: &str) -> Self {
        // The millisecond clock alone is not unique under load; the counter
        // disambiguates IDs minted in the same millisecond.
        static NEXT_SEQUENCE: AtomicU64 = AtomicU64::new(0);

        let ts_ms = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            // A clock set before the epoch degrades to timestamp 0 instead of panicking.
            .unwrap_or_default()
            .as_millis();
        let sequence = NEXT_SEQUENCE.fetch_add(1, Ordering::Relaxed);

        Self::with_id(format!("req-{ts_ms}-{sequence}"), path, method, client_id)
    }

    /// Create a context with an explicit request ID (used with [`RequestIdGen`]).
    ///
    /// `method` is stored upper-cased; `path` and `client_id` are copied as given.
    pub fn with_id(request_id: String, path: &str, method: &str, client_id: &str) -> Self {
        Self {
            request_id,
            client_id: client_id.to_owned(),
            started_at: Instant::now(),
            path: path.to_owned(),
            method: method.to_uppercase(),
        }
    }

    /// Elapsed time since the context was created, in whole milliseconds.
    ///
    /// Saturates at `u64::MAX` rather than silently truncating the 128-bit
    /// millisecond count.
    pub fn elapsed_ms(&self) -> u64 {
        u64::try_from(self.elapsed().as_millis()).unwrap_or(u64::MAX)
    }

    /// Elapsed time since the context was created as a [`Duration`].
    pub fn elapsed(&self) -> Duration {
        self.started_at.elapsed()
    }
}
92
93// ─── RequestIdGen ────────────────────────────────────────────────────────────
94
/// Thread-safe generator of request IDs backed by an atomic counter.
///
/// Every ID looks like `"{prefix}-{timestamp_ms}-{counter}"`, for example
/// `"oxibonsai-1714000000000-42"`. The counter starts at zero for each
/// generator and grows by one per call, while the millisecond timestamp
/// keeps IDs from different process runs apart.
pub struct RequestIdGen {
    counter: AtomicU64,
    prefix: String,
}

impl RequestIdGen {
    /// Build a generator whose IDs all begin with `prefix`.
    pub fn new(prefix: &str) -> Self {
        let prefix = String::from(prefix);
        let counter = AtomicU64::new(0);
        Self { counter, prefix }
    }

    /// Produce the next request ID.
    ///
    /// Format: `"{prefix}-{timestamp_ms}-{counter}"`
    pub fn next(&self) -> String {
        let Self { counter, prefix } = self;

        // Milliseconds since the Unix epoch; a clock set before the epoch yields 0.
        let since_epoch = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default();
        let millis = since_epoch.as_millis();

        // `fetch_add` hands back the value *before* the increment.
        let sequence = counter.fetch_add(1, Ordering::Relaxed);

        format!("{prefix}-{millis}-{sequence}")
    }
}
127
128// ─── RequestLogger ───────────────────────────────────────────────────────────
129
130/// Middleware that logs each request and its corresponding response.
131///
132/// Emits structured log lines via [`tracing`]. Body logging is opt-in and
133/// truncated at `max_body_log_bytes` to avoid flooding logs with large payloads.
134pub struct RequestLogger {
135    /// Whether to include body content in log output.
136    pub log_bodies: bool,
137    /// Maximum number of body bytes to include in a log line.
138    pub max_body_log_bytes: usize,
139}
140
141impl RequestLogger {
142    /// Create a logger that does not log bodies.
143    pub fn new() -> Self {
144        Self {
145            log_bodies: false,
146            max_body_log_bytes: 0,
147        }
148    }
149
150    /// Create a logger that includes up to `max_bytes` of body content.
151    pub fn with_body_logging(max_bytes: usize) -> Self {
152        Self {
153            log_bodies: true,
154            max_body_log_bytes: max_bytes,
155        }
156    }
157
158    /// Log an incoming request.
159    pub fn log_request(&self, ctx: &RequestContext) {
160        let line = Self::format_request_line(ctx);
161        tracing::info!(target: "oxibonsai::middleware", "{line}");
162    }
163
164    /// Log an outgoing response.
165    pub fn log_response(&self, ctx: &RequestContext, status: u16, body_bytes: usize) {
166        let elapsed_ms = ctx.elapsed_ms();
167        let line = Self::format_response_line(ctx, status, elapsed_ms);
168        if self.log_bodies && body_bytes > 0 {
169            tracing::info!(
170                target: "oxibonsai::middleware",
171                "{line} body_bytes={body_bytes}"
172            );
173        } else {
174            tracing::info!(target: "oxibonsai::middleware", "{line}");
175        }
176    }
177
178    /// Format an incoming-request log line.
179    ///
180    /// Output: `"[{request_id}] {method} {path} from {client_id}"`
181    pub fn format_request_line(ctx: &RequestContext) -> String {
182        format!(
183            "[{}] {} {} from {}",
184            ctx.request_id, ctx.method, ctx.path, ctx.client_id
185        )
186    }
187
188    /// Format an outgoing-response log line.
189    ///
190    /// Output: `"[{request_id}] {status} in {elapsed_ms}ms"`
191    pub fn format_response_line(ctx: &RequestContext, status: u16, elapsed_ms: u64) -> String {
192        format!("[{}] {} in {}ms", ctx.request_id, status, elapsed_ms)
193    }
194}
195
196impl Default for RequestLogger {
197    fn default() -> Self {
198        Self::new()
199    }
200}
201
202// ─── CorsConfig ──────────────────────────────────────────────────────────────
203
/// Cross-Origin Resource Sharing (CORS) policy configuration.
///
/// Used to generate `Access-Control-*` headers for preflight and main requests.
///
/// `Access-Control-Allow-Origin` may carry only a single origin or `*`. When
/// more than one origin is configured, build the headers per request with
/// [`CorsConfig::access_control_headers_for`], which echoes the matching
/// request origin.
#[derive(Debug, Clone)]
pub struct CorsConfig {
    /// Allowed origins. Use `["*"]` to permit all origins.
    pub allowed_origins: Vec<String>,
    /// Allowed HTTP methods.
    pub allowed_methods: Vec<String>,
    /// Allowed request headers.
    pub allowed_headers: Vec<String>,
    /// `Access-Control-Max-Age` in seconds (how long browsers may cache the preflight).
    pub max_age_secs: u64,
    /// Whether to allow credentials (cookies, auth headers).
    pub allow_credentials: bool,
}

impl Default for CorsConfig {
    /// Permissive default: any origin, `GET`/`POST`/`OPTIONS`, the
    /// `Content-Type` and `Authorization` headers, a one-hour preflight
    /// cache, and no credentials.
    fn default() -> Self {
        Self {
            allowed_origins: vec!["*".to_string()],
            allowed_methods: vec!["GET".to_string(), "POST".to_string(), "OPTIONS".to_string()],
            allowed_headers: vec!["Content-Type".to_string(), "Authorization".to_string()],
            max_age_secs: 3600,
            allow_credentials: false,
        }
    }
}

impl CorsConfig {
    /// Returns `true` if the given `origin` is permitted by this policy.
    ///
    /// An entry of `"*"` in `allowed_origins` permits all origins; otherwise
    /// the origin must match one of the entries exactly.
    pub fn is_origin_allowed(&self, origin: &str) -> bool {
        self.allowed_origins.iter().any(|o| o == "*" || o == origin)
    }

    /// Generate `Access-Control-*` response headers as `(name, value)` pairs
    /// without knowledge of the request's `Origin`.
    ///
    /// Returns headers suitable for both preflight (`OPTIONS`) and actual responses.
    ///
    /// `Access-Control-Allow-Origin` is `*` when the wildcard is configured.
    /// Otherwise it is the *first* configured origin, because the header
    /// cannot hold a list. It is omitted when no origin is configured.
    /// Policies with several origins should call
    /// [`CorsConfig::access_control_headers_for`] instead.
    pub fn access_control_headers(&self) -> Vec<(String, String)> {
        let origin_value = if self.allows_any_origin() {
            Some("*".to_owned())
        } else {
            self.allowed_origins.first().cloned()
        };
        self.build_headers(origin_value, false)
    }

    /// Generate `Access-Control-*` response headers for a request that carried
    /// the given `Origin` header value.
    ///
    /// Returns an empty list when `request_origin` is not permitted, so a
    /// disallowed origin receives no CORS headers at all. With a wildcard
    /// policy the origin value is `*`. With an explicit list the request
    /// origin is echoed back and `Vary: Origin` is added so caches keep
    /// per-origin responses apart.
    pub fn access_control_headers_for(&self, request_origin: &str) -> Vec<(String, String)> {
        if !self.is_origin_allowed(request_origin) {
            return Vec::new();
        }
        if self.allows_any_origin() {
            self.build_headers(Some("*".to_owned()), false)
        } else {
            self.build_headers(Some(request_origin.to_owned()), true)
        }
    }

    /// `true` when the wildcard entry `"*"` is part of the policy.
    fn allows_any_origin(&self) -> bool {
        self.allowed_origins.iter().any(|o| o == "*")
    }

    /// Assemble the header list around an already-chosen origin value.
    fn build_headers(&self, origin_value: Option<String>, vary_on_origin: bool) -> Vec<(String, String)> {
        let mut headers = Vec::with_capacity(6);

        if let Some(value) = origin_value {
            headers.push(("Access-Control-Allow-Origin".to_owned(), value));
        }

        headers.push((
            "Access-Control-Allow-Methods".to_owned(),
            self.allowed_methods.join(", "),
        ));

        headers.push((
            "Access-Control-Allow-Headers".to_owned(),
            self.allowed_headers.join(", "),
        ));

        headers.push((
            "Access-Control-Max-Age".to_owned(),
            self.max_age_secs.to_string(),
        ));

        if self.allow_credentials {
            headers.push((
                "Access-Control-Allow-Credentials".to_owned(),
                "true".to_owned(),
            ));
        }

        if vary_on_origin {
            headers.push(("Vary".to_owned(), "Origin".to_owned()));
        }

        headers
    }
}
279
280// ─── IdempotencyCache ────────────────────────────────────────────────────────
281
/// Cached entry for a previously processed idempotent request.
struct CachedResponse {
    /// HTTP status code of the stored response.
    status: u16,
    /// Raw response body.
    body: Vec<u8>,
    /// When the entry was stored; compared against the cache TTL.
    created_at: Instant,
}

/// Request deduplication cache keyed on client-supplied idempotency keys.
///
/// When a client sends the same idempotency key twice, the second request
/// receives the cached response without re-executing the operation.
/// Entries expire after `ttl` and are lazily evicted: stale entries are
/// hidden from [`IdempotencyCache::get`] immediately but only removed by
/// [`IdempotencyCache::evict_expired`] or by an insert at capacity.
pub struct IdempotencyCache {
    cache: Mutex<HashMap<String, CachedResponse>>,
    max_entries: usize,
    ttl: Duration,
}

impl IdempotencyCache {
    /// Create a new cache holding at most `max_entries` responses, each valid for `ttl`.
    pub fn new(max_entries: usize, ttl: Duration) -> Self {
        Self {
            cache: Mutex::new(HashMap::new()),
            max_entries,
            ttl,
        }
    }

    /// Acquire the map, recovering the guard if the mutex was poisoned.
    ///
    /// Every critical section performs a single map operation, so the map is
    /// consistent even after another thread panicked while holding the lock.
    /// Recovering keeps one crashed request from failing every later one.
    fn lock(&self) -> std::sync::MutexGuard<'_, HashMap<String, CachedResponse>> {
        self.cache
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Look up a previously cached response by idempotency key.
    ///
    /// Returns `(status_code, body)` if a fresh entry exists; `None` if the
    /// key is unknown or its entry is older than the TTL.
    pub fn get(&self, key: &str) -> Option<(u16, Vec<u8>)> {
        let cache = self.lock();
        cache
            .get(key)
            .filter(|entry| entry.created_at.elapsed() < self.ttl)
            .map(|entry| (entry.status, entry.body.clone()))
    }

    /// Store a response under the given idempotency key.
    ///
    /// A key that is already present is always replaced, and its age is
    /// reset, because replacing cannot grow the cache. For a new key at
    /// capacity, expired entries are evicted first. If the cache is still
    /// full after eviction, the insert is silently dropped to prevent
    /// unbounded memory growth.
    pub fn insert(&self, key: &str, status: u16, body: Vec<u8>) {
        let mut cache = self.lock();

        // Decide before eviction: an expired entry for `key` may be removed
        // below, but storing it again still cannot exceed the limit.
        let replaces_existing = cache.contains_key(key);

        // Evict expired entries once capacity is reached.
        if cache.len() >= self.max_entries {
            let ttl = self.ttl;
            cache.retain(|_, v| v.created_at.elapsed() < ttl);
        }

        if replaces_existing || cache.len() < self.max_entries {
            cache.insert(
                key.to_owned(),
                CachedResponse {
                    status,
                    body,
                    created_at: Instant::now(),
                },
            );
        }
    }

    /// Remove all expired entries from the cache.
    pub fn evict_expired(&self) {
        let ttl = self.ttl;
        self.lock().retain(|_, v| v.created_at.elapsed() < ttl);
    }

    /// Return the number of entries currently in the cache (including stale ones).
    pub fn len(&self) -> usize {
        self.lock().len()
    }

    /// Returns `true` if the cache contains no entries (stale ones count as entries).
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
370
371// ─── Tests ───────────────────────────────────────────────────────────────────
372
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::thread;

    /// A freshly built context reports next to no elapsed time.
    #[test]
    fn test_request_context_elapsed() {
        let fresh = RequestContext::new("/health", "GET", "10.0.0.1");
        let ceiling = Duration::from_millis(500);

        let millis = fresh.elapsed_ms();
        assert!(millis < 500, "elapsed should be <500ms at creation");
        assert!(fresh.elapsed() < ceiling);
    }

    /// One hundred consecutive IDs from a single generator never repeat.
    #[test]
    fn test_request_id_gen_unique() {
        let id_gen = RequestIdGen::new("test");
        let wanted = 100;

        let mut seen = HashSet::new();
        for _ in 0..wanted {
            seen.insert(id_gen.next());
        }

        assert_eq!(seen.len(), wanted, "all generated IDs must be unique");
    }

    /// Generated IDs start with the configured prefix and a dash.
    #[test]
    fn test_request_id_gen_prefix() {
        let id = RequestIdGen::new("oxibonsai").next();
        let prefixed = id.starts_with("oxibonsai-");
        assert!(prefixed, "ID should start with prefix; got {id}");
    }

    /// The request line carries ID, upper-cased method, path, and client.
    #[test]
    fn test_request_logger_format_request_line() {
        let ctx = RequestContext {
            request_id: "req-42".to_owned(),
            ..RequestContext::new("/v1/chat/completions", "post", "1.2.3.4")
        };
        assert_eq!(
            RequestLogger::format_request_line(&ctx),
            "[req-42] POST /v1/chat/completions from 1.2.3.4"
        );
    }

    /// The response line carries ID, status, and the supplied latency.
    #[test]
    fn test_request_logger_format_response_line() {
        let ctx = RequestContext {
            request_id: "req-99".to_owned(),
            ..RequestContext::new("/health", "GET", "127.0.0.1")
        };
        assert_eq!(
            RequestLogger::format_response_line(&ctx, 200, 15),
            "[req-99] 200 in 15ms"
        );
    }

    /// The default (wildcard) policy accepts every origin string.
    #[test]
    fn test_cors_config_default_allows_all() {
        let cors = CorsConfig::default();
        for origin in ["https://example.com", "null", "*"] {
            assert!(cors.is_origin_allowed(origin));
        }
    }

    /// An explicit allow-list accepts only exact matches.
    #[test]
    fn test_cors_config_specific_origin() {
        let mut cors = CorsConfig::default();
        cors.allowed_origins = vec![String::from("https://app.example.com")];

        assert!(cors.is_origin_allowed("https://app.example.com"));
        assert!(!cors.is_origin_allowed("https://evil.example.com"));
    }

    /// Default headers: wildcard origin, a methods header, no credentials header.
    #[test]
    fn test_cors_access_control_headers() {
        let headers = CorsConfig::default().access_control_headers();
        let has_header = |name: &str| headers.iter().any(|(k, _)| k == name);

        // Wildcard Allow-Origin must be present.
        let has_origin = headers
            .iter()
            .filter(|(k, _)| k == "Access-Control-Allow-Origin")
            .any(|(_, v)| v == "*");
        assert!(has_origin, "should have wildcard Allow-Origin header");

        // Methods header must be present.
        assert!(has_header("Access-Control-Allow-Methods"));

        // Credentials are off by default, so the header must be absent.
        assert!(
            !has_header("Access-Control-Allow-Credentials"),
            "should not include credentials header by default"
        );
    }

    /// A stored response comes back unchanged.
    #[test]
    fn test_idempotency_cache_insert_and_get() {
        let cache = IdempotencyCache::new(100, Duration::from_secs(60));
        let body = b"hello".to_vec();

        cache.insert("key-1", 200, body.clone());

        assert_eq!(cache.get("key-1"), Some((200, body)));
    }

    /// Unknown keys yield `None`.
    #[test]
    fn test_idempotency_cache_miss() {
        let cache = IdempotencyCache::new(100, Duration::from_secs(60));
        let looked_up = cache.get("nonexistent-key");
        assert!(looked_up.is_none());
    }

    /// `evict_expired` removes entries older than the TTL.
    #[test]
    fn test_idempotency_cache_evicts_expired() {
        // A 10ms TTL lets the entry expire quickly.
        let ttl = Duration::from_millis(10);
        let cache = IdempotencyCache::new(100, ttl);

        cache.insert("exp-key", 200, Vec::new());
        assert_eq!(cache.len(), 1);

        thread::sleep(Duration::from_millis(20));
        cache.evict_expired();

        assert_eq!(cache.len(), 0, "expired entry should have been evicted");
    }

    /// `get` hides entries older than the TTL even before eviction runs.
    #[test]
    fn test_idempotency_cache_expired_returns_none() {
        let ttl = Duration::from_millis(10);
        let cache = IdempotencyCache::new(100, ttl);

        cache.insert("ttl-key", 201, b"data".to_vec());
        thread::sleep(Duration::from_millis(20));

        let stale = cache.get("ttl-key");
        assert!(stale.is_none(), "stale cache entry must not be returned");
    }
}