Skip to main content

oxihttp_server/
middleware.rs

1//! Middleware types for the OxiHTTP server.
2//!
3//! Provides CORS, body size limits, rate limiting, timeouts, and logging
4//! as composable middleware layers.
5
6use bytes::Bytes;
7use http::{HeaderMap, HeaderValue, Method, StatusCode};
8use http_body_util::Full;
9use std::collections::{HashMap, HashSet};
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12use tokio::sync::Mutex;
13
14use oxihttp_core::OxiHttpError;
15
16// ---------------------------------------------------------------------------
17// CORS Middleware
18// ---------------------------------------------------------------------------
19
20/// Configuration for Cross-Origin Resource Sharing (CORS).
21#[derive(Debug, Clone)]
22pub struct CorsConfig {
23    /// Allowed origins. Use `["*"]` to allow all.
24    pub allowed_origins: Vec<String>,
25    /// Allowed HTTP methods.
26    pub allowed_methods: Vec<Method>,
27    /// Allowed request headers.
28    pub allowed_headers: Vec<String>,
29    /// Headers exposed to the client.
30    pub exposed_headers: Vec<String>,
31    /// Whether to allow credentials (cookies, auth headers).
32    pub allow_credentials: bool,
33    /// Max age for preflight cache (in seconds).
34    pub max_age: Option<u64>,
35}
36
37impl CorsConfig {
38    /// Create a permissive CORS config (allow all origins, common methods).
39    pub fn permissive() -> Self {
40        Self {
41            allowed_origins: vec!["*".to_string()],
42            allowed_methods: vec![
43                Method::GET,
44                Method::POST,
45                Method::PUT,
46                Method::DELETE,
47                Method::PATCH,
48                Method::HEAD,
49                Method::OPTIONS,
50            ],
51            allowed_headers: vec!["*".to_string()],
52            exposed_headers: Vec::new(),
53            allow_credentials: false,
54            max_age: Some(86400),
55        }
56    }
57
58    /// Create a CORS config that allows specific origins.
59    pub fn with_origins(origins: Vec<String>) -> Self {
60        Self {
61            allowed_origins: origins,
62            ..Self::permissive()
63        }
64    }
65
66    /// Apply CORS headers to a response.
67    pub fn apply_headers(&self, headers: &mut HeaderMap, origin: Option<&str>) {
68        let origin_value = if self.allowed_origins.contains(&"*".to_string()) {
69            "*"
70        } else if let Some(o) = origin {
71            if self.allowed_origins.iter().any(|a| a == o) {
72                o
73            } else {
74                return; // Origin not allowed, don't set headers
75            }
76        } else {
77            return;
78        };
79
80        if let Ok(val) = HeaderValue::from_str(origin_value) {
81            headers.insert(http::header::ACCESS_CONTROL_ALLOW_ORIGIN, val);
82        }
83
84        if self.allow_credentials {
85            headers.insert(
86                http::header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
87                HeaderValue::from_static("true"),
88            );
89        }
90
91        if !self.allowed_methods.is_empty() {
92            let methods: String = self
93                .allowed_methods
94                .iter()
95                .map(|m| m.as_str())
96                .collect::<Vec<_>>()
97                .join(", ");
98            if let Ok(val) = HeaderValue::from_str(&methods) {
99                headers.insert(http::header::ACCESS_CONTROL_ALLOW_METHODS, val);
100            }
101        }
102
103        if !self.allowed_headers.is_empty() {
104            let hdrs = self.allowed_headers.join(", ");
105            if let Ok(val) = HeaderValue::from_str(&hdrs) {
106                headers.insert(http::header::ACCESS_CONTROL_ALLOW_HEADERS, val);
107            }
108        }
109
110        if !self.exposed_headers.is_empty() {
111            let hdrs = self.exposed_headers.join(", ");
112            if let Ok(val) = HeaderValue::from_str(&hdrs) {
113                headers.insert(http::header::ACCESS_CONTROL_EXPOSE_HEADERS, val);
114            }
115        }
116
117        if let Some(max_age) = self.max_age {
118            if let Ok(val) = HeaderValue::from_str(&max_age.to_string()) {
119                headers.insert(http::header::ACCESS_CONTROL_MAX_AGE, val);
120            }
121        }
122    }
123
124    /// Handle a preflight (OPTIONS) request, returning a 204 No Content response
125    /// with appropriate CORS headers.
126    pub fn preflight_response(
127        &self,
128        origin: Option<&str>,
129    ) -> Result<hyper::Response<Full<Bytes>>, OxiHttpError> {
130        let mut resp = hyper::Response::builder()
131            .status(StatusCode::NO_CONTENT)
132            .body(Full::new(Bytes::new()))
133            .map_err(|e| OxiHttpError::Http(Arc::new(e)))?;
134        self.apply_headers(resp.headers_mut(), origin);
135        Ok(resp)
136    }
137}
138
139impl Default for CorsConfig {
140    fn default() -> Self {
141        Self::permissive()
142    }
143}
144
145// ---------------------------------------------------------------------------
146// Body Size Limit
147// ---------------------------------------------------------------------------
148
/// Upper bound on the size of an incoming request body.
#[derive(Debug, Clone, Copy)]
pub struct BodyLimitConfig {
    /// Largest accepted body, in bytes (inclusive).
    pub max_bytes: u64,
}
155
156impl BodyLimitConfig {
157    /// Create a body limit config with the given maximum size.
158    pub fn new(max_bytes: u64) -> Self {
159        Self { max_bytes }
160    }
161
162    /// Check if a content-length exceeds the limit.
163    /// Returns `Ok(())` if within limits, `Err` with a 413 status otherwise.
164    pub fn check_content_length(&self, content_length: Option<u64>) -> Result<(), OxiHttpError> {
165        if let Some(len) = content_length {
166            if len > self.max_bytes {
167                return Err(OxiHttpError::Body(format!(
168                    "request body too large: {} bytes exceeds limit of {} bytes",
169                    len, self.max_bytes
170                )));
171            }
172        }
173        Ok(())
174    }
175}
176
177// ---------------------------------------------------------------------------
178// Rate Limiting (Token Bucket)
179// ---------------------------------------------------------------------------
180
181/// Rate limiter using the token bucket algorithm.
182#[derive(Clone)]
183pub struct RateLimiter {
184    inner: Arc<Mutex<RateLimiterInner>>,
185}
186
187struct RateLimiterInner {
188    /// Buckets keyed by IP address or route identifier.
189    buckets: HashMap<String, TokenBucket>,
190    /// Maximum tokens per bucket.
191    max_tokens: u32,
192    /// Token refill rate (tokens per second).
193    refill_rate: f64,
194}
195
196struct TokenBucket {
197    tokens: f64,
198    last_refill: Instant,
199}
200
201impl RateLimiter {
202    /// Create a new rate limiter.
203    ///
204    /// - `max_tokens`: Maximum number of tokens (burst capacity).
205    /// - `refill_rate`: Tokens added per second.
206    pub fn new(max_tokens: u32, refill_rate: f64) -> Self {
207        Self {
208            inner: Arc::new(Mutex::new(RateLimiterInner {
209                buckets: HashMap::new(),
210                max_tokens,
211                refill_rate,
212            })),
213        }
214    }
215
216    /// Check if a request from the given key is allowed.
217    ///
218    /// Returns `true` if the request is allowed (token consumed),
219    /// `false` if rate-limited (429 should be returned).
220    pub async fn check(&self, key: &str) -> bool {
221        let mut inner = self.inner.lock().await;
222        let now = Instant::now();
223        let max_tokens = inner.max_tokens;
224        let refill_rate = inner.refill_rate;
225
226        let bucket = inner.buckets.entry(key.to_string()).or_insert(TokenBucket {
227            tokens: max_tokens as f64,
228            last_refill: now,
229        });
230
231        // Refill tokens based on elapsed time
232        let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
233        bucket.tokens = (bucket.tokens + elapsed * refill_rate).min(max_tokens as f64);
234        bucket.last_refill = now;
235
236        if bucket.tokens >= 1.0 {
237            bucket.tokens -= 1.0;
238            true
239        } else {
240            false
241        }
242    }
243
244    /// Build a 429 Too Many Requests response.
245    pub fn too_many_requests() -> Result<hyper::Response<Full<Bytes>>, OxiHttpError> {
246        hyper::Response::builder()
247            .status(StatusCode::TOO_MANY_REQUESTS)
248            .body(Full::new(Bytes::from("Too Many Requests")))
249            .map_err(|e| OxiHttpError::Http(Arc::new(e)))
250    }
251}
252
253impl std::fmt::Debug for RateLimiter {
254    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
255        f.debug_struct("RateLimiter").finish()
256    }
257}
258
259// ---------------------------------------------------------------------------
260// Request Timeout
261// ---------------------------------------------------------------------------
262
/// Time budget for processing a single request.
#[derive(Debug, Clone, Copy)]
pub struct TimeoutConfig {
    /// Longest a request may take to process.
    pub duration: Duration,
}
269
270impl TimeoutConfig {
271    /// Create a timeout config.
272    pub fn new(duration: Duration) -> Self {
273        Self { duration }
274    }
275
276    /// Build a 408 Request Timeout response.
277    pub fn timeout_response() -> Result<hyper::Response<Full<Bytes>>, OxiHttpError> {
278        hyper::Response::builder()
279            .status(StatusCode::REQUEST_TIMEOUT)
280            .body(Full::new(Bytes::from("Request Timeout")))
281            .map_err(|e| OxiHttpError::Http(Arc::new(e)))
282    }
283}
284
285// ---------------------------------------------------------------------------
286// Middleware Pipeline
287// ---------------------------------------------------------------------------
288
289/// The middleware pipeline configuration for a server.
290#[derive(Clone)]
291pub struct MiddlewarePipeline {
292    /// CORS configuration (applied to all responses).
293    pub cors: Option<CorsConfig>,
294    /// Body size limit (checked before handler).
295    pub body_limit: Option<BodyLimitConfig>,
296    /// Rate limiter (checked before handler).
297    pub rate_limiter: Option<RateLimiter>,
298    /// Request timeout.
299    pub timeout: Option<TimeoutConfig>,
300    /// Allowed methods for CORS preflight (derived from CORS config).
301    allowed_methods: HashSet<Method>,
302}
303
304impl MiddlewarePipeline {
305    /// Create an empty middleware pipeline.
306    pub fn new() -> Self {
307        Self {
308            cors: None,
309            body_limit: None,
310            rate_limiter: None,
311            timeout: None,
312            allowed_methods: HashSet::new(),
313        }
314    }
315
316    /// Add CORS middleware.
317    pub fn with_cors(mut self, config: CorsConfig) -> Self {
318        self.allowed_methods = config.allowed_methods.iter().cloned().collect();
319        self.cors = Some(config);
320        self
321    }
322
323    /// Add body size limit middleware.
324    pub fn with_body_limit(mut self, max_bytes: u64) -> Self {
325        self.body_limit = Some(BodyLimitConfig::new(max_bytes));
326        self
327    }
328
329    /// Add rate limiting middleware.
330    pub fn with_rate_limiter(mut self, limiter: RateLimiter) -> Self {
331        self.rate_limiter = Some(limiter);
332        self
333    }
334
335    /// Add request timeout middleware.
336    pub fn with_timeout(mut self, duration: Duration) -> Self {
337        self.timeout = Some(TimeoutConfig::new(duration));
338        self
339    }
340
341    /// Run pre-handler middleware checks.
342    ///
343    /// Returns `Some(response)` if middleware short-circuits (e.g. CORS preflight,
344    /// rate limit exceeded, body too large). Returns `None` if the request should
345    /// proceed to the handler.
346    pub async fn pre_handle(
347        &self,
348        req: &hyper::Request<hyper::body::Incoming>,
349    ) -> Option<Result<hyper::Response<Full<Bytes>>, OxiHttpError>> {
350        // CORS preflight
351        if req.method() == Method::OPTIONS {
352            if let Some(ref cors) = self.cors {
353                let origin = req
354                    .headers()
355                    .get(http::header::ORIGIN)
356                    .and_then(|v| v.to_str().ok());
357                return Some(cors.preflight_response(origin));
358            }
359        }
360
361        // Rate limiting
362        if let Some(ref limiter) = self.rate_limiter {
363            let key = req
364                .headers()
365                .get("x-forwarded-for")
366                .and_then(|v| v.to_str().ok())
367                .unwrap_or("unknown")
368                .to_string();
369            if !limiter.check(&key).await {
370                return Some(RateLimiter::too_many_requests());
371            }
372        }
373
374        // Body size limit
375        if let Some(ref body_limit) = self.body_limit {
376            let content_length = req
377                .headers()
378                .get(http::header::CONTENT_LENGTH)
379                .and_then(|v| v.to_str().ok())
380                .and_then(|s| s.parse::<u64>().ok());
381            if let Err(e) = body_limit.check_content_length(content_length) {
382                return Some(
383                    hyper::Response::builder()
384                        .status(StatusCode::PAYLOAD_TOO_LARGE)
385                        .body(Full::new(Bytes::from(e.to_string())))
386                        .map_err(|e| OxiHttpError::Http(Arc::new(e))),
387                );
388            }
389        }
390
391        None
392    }
393
394    /// Apply post-handler middleware (e.g. CORS headers) to a response.
395    pub fn post_handle(&self, resp: &mut hyper::Response<Full<Bytes>>, origin: Option<&str>) {
396        if let Some(ref cors) = self.cors {
397            cors.apply_headers(resp.headers_mut(), origin);
398        }
399    }
400}
401
402impl Default for MiddlewarePipeline {
403    fn default() -> Self {
404        Self::new()
405    }
406}
407
408impl std::fmt::Debug for MiddlewarePipeline {
409    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
410        f.debug_struct("MiddlewarePipeline")
411            .field("cors", &self.cors.is_some())
412            .field("body_limit", &self.body_limit)
413            .field("rate_limiter", &self.rate_limiter.is_some())
414            .field("timeout", &self.timeout)
415            .finish()
416    }
417}