use std::collections::HashMap;
use std::sync::{
atomic::{AtomicU64, Ordering},
Mutex,
};
use std::time::{Duration, Instant};
/// Per-request metadata captured when a request enters the middleware stack.
#[derive(Debug, Clone)]
pub struct RequestContext {
    /// Identifier used to correlate log lines belonging to this request.
    pub request_id: String,
    /// Caller identity as supplied to the constructor (e.g. a peer address).
    pub client_id: String,
    /// Monotonic timestamp taken at construction; basis for [`elapsed`](Self::elapsed).
    pub started_at: Instant,
    /// Request path as supplied to the constructor.
    pub path: String,
    /// HTTP method, upper-cased at construction.
    pub method: String,
}

impl RequestContext {
    /// Creates a context with a generated ID of the form `req-<unix_ms>-<seq>`.
    ///
    /// `<seq>` is a process-wide counter, so contexts created within the same
    /// millisecond still receive distinct IDs. If the system clock reads earlier
    /// than the Unix epoch, `<unix_ms>` is `0`.
    pub fn new(path: &str, method: &str, client_id: &str) -> Self {
        // A millisecond timestamp alone collides for back-to-back or concurrent
        // requests; the sequence number disambiguates them.
        static NEXT_SEQ: AtomicU64 = AtomicU64::new(0);
        let ts_ms = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_millis();
        // Relaxed is sufficient: only the atomicity of the increment matters here.
        let seq = NEXT_SEQ.fetch_add(1, Ordering::Relaxed);
        Self::with_id(format!("req-{ts_ms}-{seq}"), path, method, client_id)
    }

    /// Creates a context using a caller-provided `request_id` (e.g. one propagated
    /// from an upstream header). `method` is upper-cased; the clock starts now.
    pub fn with_id(request_id: String, path: &str, method: &str, client_id: &str) -> Self {
        Self {
            request_id,
            client_id: client_id.to_owned(),
            started_at: Instant::now(),
            path: path.to_owned(),
            method: method.to_uppercase(),
        }
    }

    /// Whole milliseconds elapsed since construction.
    pub fn elapsed_ms(&self) -> u64 {
        // u128 -> u64 only loses bits after ~584 million years of uptime.
        self.started_at.elapsed().as_millis() as u64
    }

    /// Time elapsed since construction.
    pub fn elapsed(&self) -> Duration {
        self.started_at.elapsed()
    }
}
/// Produces request IDs shaped `<prefix>-<unix_ms>-<seq>`.
///
/// `<seq>` is a per-generator counter starting at 0, so IDs minted by one
/// generator are distinct even when several fall in the same millisecond.
pub struct RequestIdGen {
    counter: AtomicU64,
    prefix: String,
}

impl RequestIdGen {
    /// Builds a generator whose IDs begin with `prefix` followed by `-`.
    pub fn new(prefix: &str) -> Self {
        let counter = AtomicU64::new(0);
        let prefix = String::from(prefix);
        Self { counter, prefix }
    }

    /// Mints the next ID.
    ///
    /// If the system clock reads earlier than the Unix epoch, the timestamp
    /// component is `0`.
    pub fn next(&self) -> String {
        let now = std::time::SystemTime::now();
        let ts_ms = match now.duration_since(std::time::UNIX_EPOCH) {
            Ok(since_epoch) => since_epoch.as_millis(),
            Err(_) => 0,
        };
        // fetch_add is atomic regardless of ordering, so each caller gets a
        // distinct value; no other memory is synchronised through it.
        let counter = self.counter.fetch_add(1, Ordering::Relaxed);
        format!("{}-{ts_ms}-{counter}", self.prefix)
    }
}
pub struct RequestLogger {
pub log_bodies: bool,
pub max_body_log_bytes: usize,
}
impl RequestLogger {
pub fn new() -> Self {
Self {
log_bodies: false,
max_body_log_bytes: 0,
}
}
pub fn with_body_logging(max_bytes: usize) -> Self {
Self {
log_bodies: true,
max_body_log_bytes: max_bytes,
}
}
pub fn log_request(&self, ctx: &RequestContext) {
let line = Self::format_request_line(ctx);
tracing::info!(target: "oxibonsai::middleware", "{line}");
}
pub fn log_response(&self, ctx: &RequestContext, status: u16, body_bytes: usize) {
let elapsed_ms = ctx.elapsed_ms();
let line = Self::format_response_line(ctx, status, elapsed_ms);
if self.log_bodies && body_bytes > 0 {
tracing::info!(
target: "oxibonsai::middleware",
"{line} body_bytes={body_bytes}"
);
} else {
tracing::info!(target: "oxibonsai::middleware", "{line}");
}
}
pub fn format_request_line(ctx: &RequestContext) -> String {
format!(
"[{}] {} {} from {}",
ctx.request_id, ctx.method, ctx.path, ctx.client_id
)
}
pub fn format_response_line(ctx: &RequestContext, status: u16, elapsed_ms: u64) -> String {
format!("[{}] {} in {}ms", ctx.request_id, status, elapsed_ms)
}
}
impl Default for RequestLogger {
fn default() -> Self {
Self::new()
}
}
/// Cross-Origin Resource Sharing (CORS) policy.
#[derive(Debug, Clone)]
pub struct CorsConfig {
    /// Origins permitted to make cross-origin requests; an entry of `"*"` allows any.
    pub allowed_origins: Vec<String>,
    /// Methods advertised in `Access-Control-Allow-Methods`.
    pub allowed_methods: Vec<String>,
    /// Request headers advertised in `Access-Control-Allow-Headers`.
    pub allowed_headers: Vec<String>,
    /// Value of `Access-Control-Max-Age`, in seconds.
    pub max_age_secs: u64,
    /// Whether to emit `Access-Control-Allow-Credentials: true`.
    pub allow_credentials: bool,
}

impl Default for CorsConfig {
    /// Permissive defaults: any origin, `GET`/`POST`/`OPTIONS`, the `Content-Type`
    /// and `Authorization` headers, a one-hour preflight cache, and no credentials.
    fn default() -> Self {
        Self {
            allowed_origins: vec!["*".to_string()],
            allowed_methods: vec!["GET".to_string(), "POST".to_string(), "OPTIONS".to_string()],
            allowed_headers: vec!["Content-Type".to_string(), "Authorization".to_string()],
            max_age_secs: 3600,
            allow_credentials: false,
        }
    }
}

impl CorsConfig {
    /// Returns `true` if `origin` is listed verbatim or a `"*"` entry is present.
    pub fn is_origin_allowed(&self, origin: &str) -> bool {
        self.allowed_origins.iter().any(|o| o == "*" || o == origin)
    }

    /// Whether `allowed_origins` contains the `"*"` wildcard.
    fn allows_any_origin(&self) -> bool {
        self.allowed_origins.iter().any(|o| o == "*")
    }

    /// Builds CORS headers without knowing the request's `Origin`.
    ///
    /// `Access-Control-Allow-Origin` is `*` when a wildcard is configured, otherwise
    /// the allowed origins joined with `", "`. Browsers accept only a single origin
    /// there, and reject `*` on credentialed requests. Prefer
    /// [`access_control_headers_for_origin`](Self::access_control_headers_for_origin)
    /// when more than one origin is configured or `allow_credentials` is set.
    pub fn access_control_headers(&self) -> Vec<(String, String)> {
        let origin_value = if self.allows_any_origin() {
            "*".to_owned()
        } else {
            self.allowed_origins.join(", ")
        };
        let mut headers = Vec::with_capacity(5);
        headers.push(("Access-Control-Allow-Origin".to_owned(), origin_value));
        self.push_policy_headers(&mut headers);
        headers
    }

    /// Builds CORS headers for a request that sent `Origin: request_origin`.
    ///
    /// Returns `None` when the origin is not allowed (the caller should then omit
    /// CORS headers). If only a wildcard applies and credentials are disabled,
    /// `Access-Control-Allow-Origin` is `*`. Otherwise the request's origin is echoed
    /// back together with `Vary: Origin`, so shared caches key the response on the
    /// request's origin.
    ///
    /// With `"*"` configured *and* `allow_credentials` set, echoing grants
    /// credentialed access to every origin; list explicit origins in that case.
    pub fn access_control_headers_for_origin(
        &self,
        request_origin: &str,
    ) -> Option<Vec<(String, String)>> {
        if !self.is_origin_allowed(request_origin) {
            return None;
        }
        // `*` is only usable when the wildcard applies and the request carries no
        // credentials; every other case needs the concrete origin.
        let echo_origin = self.allow_credentials || !self.allows_any_origin();
        let mut headers = Vec::with_capacity(6);
        if echo_origin {
            headers.push((
                "Access-Control-Allow-Origin".to_owned(),
                request_origin.to_owned(),
            ));
            headers.push(("Vary".to_owned(), "Origin".to_owned()));
        } else {
            headers.push(("Access-Control-Allow-Origin".to_owned(), "*".to_owned()));
        }
        self.push_policy_headers(&mut headers);
        Some(headers)
    }

    /// Appends the origin-independent headers: methods, headers, max-age and,
    /// when enabled, credentials.
    fn push_policy_headers(&self, headers: &mut Vec<(String, String)>) {
        headers.push((
            "Access-Control-Allow-Methods".to_owned(),
            self.allowed_methods.join(", "),
        ));
        headers.push((
            "Access-Control-Allow-Headers".to_owned(),
            self.allowed_headers.join(", "),
        ));
        headers.push((
            "Access-Control-Max-Age".to_owned(),
            self.max_age_secs.to_string(),
        ));
        if self.allow_credentials {
            headers.push((
                "Access-Control-Allow-Credentials".to_owned(),
                "true".to_owned(),
            ));
        }
    }
}
/// A response remembered by [`IdempotencyCache`].
struct CachedResponse {
    status: u16,
    body: Vec<u8>,
    /// When the entry was stored; compared against the cache TTL.
    created_at: Instant,
}

/// Thread-safe, bounded cache of responses keyed by idempotency key, with a TTL.
///
/// Entries older than `ttl` are never returned by [`get`](Self::get). They are
/// physically removed only by [`evict_expired`](Self::evict_expired), or when
/// [`insert`](Self::insert) finds the cache full.
pub struct IdempotencyCache {
    cache: Mutex<HashMap<String, CachedResponse>>,
    max_entries: usize,
    ttl: Duration,
}

impl IdempotencyCache {
    /// Creates an empty cache holding at most `max_entries` responses, each valid
    /// for `ttl` after insertion. A `max_entries` of 0 disables caching.
    pub fn new(max_entries: usize, ttl: Duration) -> Self {
        Self {
            cache: Mutex::new(HashMap::new()),
            max_entries,
            ttl,
        }
    }

    /// Locks the map, recovering it if a previous holder panicked.
    ///
    /// No critical section in this type can leave the map half-updated, so a
    /// poisoned lock carries no broken invariant. Recovering keeps one panicking
    /// request from turning every later cache access into a panic.
    fn lock(&self) -> std::sync::MutexGuard<'_, HashMap<String, CachedResponse>> {
        self.cache
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Returns `(status, body)` stored under `key` if it exists and has not expired.
    pub fn get(&self, key: &str) -> Option<(u16, Vec<u8>)> {
        let cache = self.lock();
        cache
            .get(key)
            .filter(|entry| entry.created_at.elapsed() < self.ttl)
            .map(|entry| (entry.status, entry.body.clone()))
    }

    /// Stores a response under `key`, replacing any previous entry and restarting
    /// its TTL.
    ///
    /// When adding a new key to a full cache, expired entries are purged first. If
    /// the cache is still full, the oldest entry is evicted, on the assumption that
    /// recent requests are the ones most likely to be retried. That eviction scans
    /// the whole map, so it is O(`max_entries`).
    pub fn insert(&self, key: &str, status: u16, body: Vec<u8>) {
        if self.max_entries == 0 {
            return;
        }
        let ttl = self.ttl;
        let mut cache = self.lock();
        // Overwriting an existing key never grows the map, so it needs no room.
        if !cache.contains_key(key) && cache.len() >= self.max_entries {
            cache.retain(|_, v| v.created_at.elapsed() < ttl);
            if cache.len() >= self.max_entries {
                let oldest = cache
                    .iter()
                    .min_by_key(|(_, v)| v.created_at)
                    .map(|(k, _)| k.clone());
                if let Some(oldest) = oldest {
                    cache.remove(&oldest);
                }
            }
        }
        cache.insert(
            key.to_owned(),
            CachedResponse {
                status,
                body,
                created_at: Instant::now(),
            },
        );
    }

    /// Removes every entry whose TTL has elapsed.
    pub fn evict_expired(&self) {
        let ttl = self.ttl;
        self.lock().retain(|_, v| v.created_at.elapsed() < ttl);
    }

    /// Number of stored entries, including expired ones not yet evicted.
    pub fn len(&self) -> usize {
        self.lock().len()
    }

    /// Returns `true` if no entries (expired or not) are stored.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::thread;
#[test]
fn test_request_context_elapsed() {
let ctx = RequestContext::new("/health", "GET", "10.0.0.1");
assert!(
ctx.elapsed_ms() < 500,
"elapsed should be <500ms at creation"
);
assert!(ctx.elapsed() < Duration::from_millis(500));
}
#[test]
fn test_request_id_gen_unique() {
let gen = RequestIdGen::new("test");
let ids: Vec<String> = (0..100).map(|_| gen.next()).collect();
let unique: std::collections::HashSet<&String> = ids.iter().collect();
assert_eq!(unique.len(), ids.len(), "all generated IDs must be unique");
}
#[test]
fn test_request_id_gen_prefix() {
let gen = RequestIdGen::new("oxibonsai");
let id = gen.next();
assert!(
id.starts_with("oxibonsai-"),
"ID should start with prefix; got {id}"
);
}
#[test]
fn test_request_logger_format_request_line() {
let mut ctx = RequestContext::new("/v1/chat/completions", "post", "1.2.3.4");
ctx.request_id = "req-42".to_owned();
let line = RequestLogger::format_request_line(&ctx);
assert_eq!(line, "[req-42] POST /v1/chat/completions from 1.2.3.4");
}
#[test]
fn test_request_logger_format_response_line() {
let mut ctx = RequestContext::new("/health", "GET", "127.0.0.1");
ctx.request_id = "req-99".to_owned();
let line = RequestLogger::format_response_line(&ctx, 200, 15);
assert_eq!(line, "[req-99] 200 in 15ms");
}
#[test]
fn test_cors_config_default_allows_all() {
let cors = CorsConfig::default();
assert!(cors.is_origin_allowed("https://example.com"));
assert!(cors.is_origin_allowed("null"));
assert!(cors.is_origin_allowed("*"));
}
#[test]
fn test_cors_config_specific_origin() {
let cors = CorsConfig {
allowed_origins: vec!["https://app.example.com".to_string()],
..Default::default()
};
assert!(cors.is_origin_allowed("https://app.example.com"));
assert!(!cors.is_origin_allowed("https://evil.example.com"));
}
#[test]
fn test_cors_access_control_headers() {
let cors = CorsConfig::default();
let headers = cors.access_control_headers();
let has_origin = headers
.iter()
.any(|(k, v)| k == "Access-Control-Allow-Origin" && v == "*");
assert!(has_origin, "should have wildcard Allow-Origin header");
let has_methods = headers
.iter()
.any(|(k, _)| k == "Access-Control-Allow-Methods");
assert!(has_methods);
let has_creds = headers
.iter()
.any(|(k, _)| k == "Access-Control-Allow-Credentials");
assert!(
!has_creds,
"should not include credentials header by default"
);
}
#[test]
fn test_idempotency_cache_insert_and_get() {
let cache = IdempotencyCache::new(100, Duration::from_secs(60));
cache.insert("key-1", 200, b"hello".to_vec());
let result = cache.get("key-1");
assert_eq!(result, Some((200, b"hello".to_vec())));
}
#[test]
fn test_idempotency_cache_miss() {
let cache = IdempotencyCache::new(100, Duration::from_secs(60));
assert!(cache.get("nonexistent-key").is_none());
}
#[test]
fn test_idempotency_cache_evicts_expired() {
let cache = IdempotencyCache::new(100, Duration::from_millis(10));
cache.insert("exp-key", 200, vec![]);
assert_eq!(cache.len(), 1);
thread::sleep(Duration::from_millis(20));
cache.evict_expired();
assert_eq!(cache.len(), 0, "expired entry should have been evicted");
}
#[test]
fn test_idempotency_cache_expired_returns_none() {
let cache = IdempotencyCache::new(100, Duration::from_millis(10));
cache.insert("ttl-key", 201, b"data".to_vec());
thread::sleep(Duration::from_millis(20));
assert!(
cache.get("ttl-key").is_none(),
"stale cache entry must not be returned"
);
}
}