1use bytes::Bytes;
7use http::{HeaderMap, HeaderValue, Method, StatusCode};
8use http_body_util::Full;
9use std::collections::{HashMap, HashSet};
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12use tokio::sync::Mutex;
13
14use oxihttp_core::OxiHttpError;
15
16#[derive(Debug, Clone)]
22pub struct CorsConfig {
23 pub allowed_origins: Vec<String>,
25 pub allowed_methods: Vec<Method>,
27 pub allowed_headers: Vec<String>,
29 pub exposed_headers: Vec<String>,
31 pub allow_credentials: bool,
33 pub max_age: Option<u64>,
35}
36
37impl CorsConfig {
38 pub fn permissive() -> Self {
40 Self {
41 allowed_origins: vec!["*".to_string()],
42 allowed_methods: vec![
43 Method::GET,
44 Method::POST,
45 Method::PUT,
46 Method::DELETE,
47 Method::PATCH,
48 Method::HEAD,
49 Method::OPTIONS,
50 ],
51 allowed_headers: vec!["*".to_string()],
52 exposed_headers: Vec::new(),
53 allow_credentials: false,
54 max_age: Some(86400),
55 }
56 }
57
58 pub fn with_origins(origins: Vec<String>) -> Self {
60 Self {
61 allowed_origins: origins,
62 ..Self::permissive()
63 }
64 }
65
66 pub fn apply_headers(&self, headers: &mut HeaderMap, origin: Option<&str>) {
68 let origin_value = if self.allowed_origins.contains(&"*".to_string()) {
69 "*"
70 } else if let Some(o) = origin {
71 if self.allowed_origins.iter().any(|a| a == o) {
72 o
73 } else {
74 return; }
76 } else {
77 return;
78 };
79
80 if let Ok(val) = HeaderValue::from_str(origin_value) {
81 headers.insert(http::header::ACCESS_CONTROL_ALLOW_ORIGIN, val);
82 }
83
84 if self.allow_credentials {
85 headers.insert(
86 http::header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
87 HeaderValue::from_static("true"),
88 );
89 }
90
91 if !self.allowed_methods.is_empty() {
92 let methods: String = self
93 .allowed_methods
94 .iter()
95 .map(|m| m.as_str())
96 .collect::<Vec<_>>()
97 .join(", ");
98 if let Ok(val) = HeaderValue::from_str(&methods) {
99 headers.insert(http::header::ACCESS_CONTROL_ALLOW_METHODS, val);
100 }
101 }
102
103 if !self.allowed_headers.is_empty() {
104 let hdrs = self.allowed_headers.join(", ");
105 if let Ok(val) = HeaderValue::from_str(&hdrs) {
106 headers.insert(http::header::ACCESS_CONTROL_ALLOW_HEADERS, val);
107 }
108 }
109
110 if !self.exposed_headers.is_empty() {
111 let hdrs = self.exposed_headers.join(", ");
112 if let Ok(val) = HeaderValue::from_str(&hdrs) {
113 headers.insert(http::header::ACCESS_CONTROL_EXPOSE_HEADERS, val);
114 }
115 }
116
117 if let Some(max_age) = self.max_age {
118 if let Ok(val) = HeaderValue::from_str(&max_age.to_string()) {
119 headers.insert(http::header::ACCESS_CONTROL_MAX_AGE, val);
120 }
121 }
122 }
123
124 pub fn preflight_response(
127 &self,
128 origin: Option<&str>,
129 ) -> Result<hyper::Response<Full<Bytes>>, OxiHttpError> {
130 let mut resp = hyper::Response::builder()
131 .status(StatusCode::NO_CONTENT)
132 .body(Full::new(Bytes::new()))
133 .map_err(|e| OxiHttpError::Http(Arc::new(e)))?;
134 self.apply_headers(resp.headers_mut(), origin);
135 Ok(resp)
136 }
137}
138
139impl Default for CorsConfig {
140 fn default() -> Self {
141 Self::permissive()
142 }
143}
144
/// Upper bound on the size of a request body that will be accepted.
#[derive(Clone, Copy, Debug)]
pub struct BodyLimitConfig {
    /// Largest declared body size, in bytes, that is still accepted.
    pub max_bytes: u64,
}
155
156impl BodyLimitConfig {
157 pub fn new(max_bytes: u64) -> Self {
159 Self { max_bytes }
160 }
161
162 pub fn check_content_length(&self, content_length: Option<u64>) -> Result<(), OxiHttpError> {
165 if let Some(len) = content_length {
166 if len > self.max_bytes {
167 return Err(OxiHttpError::Body(format!(
168 "request body too large: {} bytes exceeds limit of {} bytes",
169 len, self.max_bytes
170 )));
171 }
172 }
173 Ok(())
174 }
175}
176
177#[derive(Clone)]
183pub struct RateLimiter {
184 inner: Arc<Mutex<RateLimiterInner>>,
185}
186
187struct RateLimiterInner {
188 buckets: HashMap<String, TokenBucket>,
190 max_tokens: u32,
192 refill_rate: f64,
194}
195
/// Token-bucket state for a single rate-limit key.
struct TokenBucket {
    /// Tokens currently available; fractional because refill is proportional to time.
    tokens: f64,
    /// Instant at which `tokens` was last brought up to date.
    last_refill: Instant,
}
200
201impl RateLimiter {
202 pub fn new(max_tokens: u32, refill_rate: f64) -> Self {
207 Self {
208 inner: Arc::new(Mutex::new(RateLimiterInner {
209 buckets: HashMap::new(),
210 max_tokens,
211 refill_rate,
212 })),
213 }
214 }
215
216 pub async fn check(&self, key: &str) -> bool {
221 let mut inner = self.inner.lock().await;
222 let now = Instant::now();
223 let max_tokens = inner.max_tokens;
224 let refill_rate = inner.refill_rate;
225
226 let bucket = inner.buckets.entry(key.to_string()).or_insert(TokenBucket {
227 tokens: max_tokens as f64,
228 last_refill: now,
229 });
230
231 let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
233 bucket.tokens = (bucket.tokens + elapsed * refill_rate).min(max_tokens as f64);
234 bucket.last_refill = now;
235
236 if bucket.tokens >= 1.0 {
237 bucket.tokens -= 1.0;
238 true
239 } else {
240 false
241 }
242 }
243
244 pub fn too_many_requests() -> Result<hyper::Response<Full<Bytes>>, OxiHttpError> {
246 hyper::Response::builder()
247 .status(StatusCode::TOO_MANY_REQUESTS)
248 .body(Full::new(Bytes::from("Too Many Requests")))
249 .map_err(|e| OxiHttpError::Http(Arc::new(e)))
250 }
251}
252
253impl std::fmt::Debug for RateLimiter {
254 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
255 f.debug_struct("RateLimiter").finish()
256 }
257}
258
/// Time budget configured for the timeout middleware.
#[derive(Clone, Copy, Debug)]
pub struct TimeoutConfig {
    /// The configured timeout length.
    pub duration: Duration,
}
269
270impl TimeoutConfig {
271 pub fn new(duration: Duration) -> Self {
273 Self { duration }
274 }
275
276 pub fn timeout_response() -> Result<hyper::Response<Full<Bytes>>, OxiHttpError> {
278 hyper::Response::builder()
279 .status(StatusCode::REQUEST_TIMEOUT)
280 .body(Full::new(Bytes::from("Request Timeout")))
281 .map_err(|e| OxiHttpError::Http(Arc::new(e)))
282 }
283}
284
285#[derive(Clone)]
291pub struct MiddlewarePipeline {
292 pub cors: Option<CorsConfig>,
294 pub body_limit: Option<BodyLimitConfig>,
296 pub rate_limiter: Option<RateLimiter>,
298 pub timeout: Option<TimeoutConfig>,
300 allowed_methods: HashSet<Method>,
302}
303
304impl MiddlewarePipeline {
305 pub fn new() -> Self {
307 Self {
308 cors: None,
309 body_limit: None,
310 rate_limiter: None,
311 timeout: None,
312 allowed_methods: HashSet::new(),
313 }
314 }
315
316 pub fn with_cors(mut self, config: CorsConfig) -> Self {
318 self.allowed_methods = config.allowed_methods.iter().cloned().collect();
319 self.cors = Some(config);
320 self
321 }
322
323 pub fn with_body_limit(mut self, max_bytes: u64) -> Self {
325 self.body_limit = Some(BodyLimitConfig::new(max_bytes));
326 self
327 }
328
329 pub fn with_rate_limiter(mut self, limiter: RateLimiter) -> Self {
331 self.rate_limiter = Some(limiter);
332 self
333 }
334
335 pub fn with_timeout(mut self, duration: Duration) -> Self {
337 self.timeout = Some(TimeoutConfig::new(duration));
338 self
339 }
340
341 pub async fn pre_handle(
347 &self,
348 req: &hyper::Request<hyper::body::Incoming>,
349 ) -> Option<Result<hyper::Response<Full<Bytes>>, OxiHttpError>> {
350 if req.method() == Method::OPTIONS {
352 if let Some(ref cors) = self.cors {
353 let origin = req
354 .headers()
355 .get(http::header::ORIGIN)
356 .and_then(|v| v.to_str().ok());
357 return Some(cors.preflight_response(origin));
358 }
359 }
360
361 if let Some(ref limiter) = self.rate_limiter {
363 let key = req
364 .headers()
365 .get("x-forwarded-for")
366 .and_then(|v| v.to_str().ok())
367 .unwrap_or("unknown")
368 .to_string();
369 if !limiter.check(&key).await {
370 return Some(RateLimiter::too_many_requests());
371 }
372 }
373
374 if let Some(ref body_limit) = self.body_limit {
376 let content_length = req
377 .headers()
378 .get(http::header::CONTENT_LENGTH)
379 .and_then(|v| v.to_str().ok())
380 .and_then(|s| s.parse::<u64>().ok());
381 if let Err(e) = body_limit.check_content_length(content_length) {
382 return Some(
383 hyper::Response::builder()
384 .status(StatusCode::PAYLOAD_TOO_LARGE)
385 .body(Full::new(Bytes::from(e.to_string())))
386 .map_err(|e| OxiHttpError::Http(Arc::new(e))),
387 );
388 }
389 }
390
391 None
392 }
393
394 pub fn post_handle(&self, resp: &mut hyper::Response<Full<Bytes>>, origin: Option<&str>) {
396 if let Some(ref cors) = self.cors {
397 cors.apply_headers(resp.headers_mut(), origin);
398 }
399 }
400}
401
402impl Default for MiddlewarePipeline {
403 fn default() -> Self {
404 Self::new()
405 }
406}
407
408impl std::fmt::Debug for MiddlewarePipeline {
409 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
410 f.debug_struct("MiddlewarePipeline")
411 .field("cors", &self.cors.is_some())
412 .field("body_limit", &self.body_limit)
413 .field("rate_limiter", &self.rate_limiter.is_some())
414 .field("timeout", &self.timeout)
415 .finish()
416 }
417}