use bytes::Bytes;
use http::{HeaderMap, HeaderValue, Method, StatusCode};
use http_body_util::Full;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
use oxihttp_core::OxiHttpError;
/// Cross-Origin Resource Sharing (CORS) policy used by [`MiddlewarePipeline`].
///
/// A literal `"*"` entry in `allowed_origins` matches every origin.
#[derive(Clone, Debug)]
pub struct CorsConfig {
    /// Origins allowed to make cross-origin requests; `"*"` allows any origin.
    pub allowed_origins: Vec<String>,
    /// Methods advertised via `Access-Control-Allow-Methods`.
    pub allowed_methods: Vec<Method>,
    /// Request header names advertised via `Access-Control-Allow-Headers`.
    pub allowed_headers: Vec<String>,
    /// Response header names advertised via `Access-Control-Expose-Headers`.
    pub exposed_headers: Vec<String>,
    /// Whether `Access-Control-Allow-Credentials: true` may be emitted.
    pub allow_credentials: bool,
    /// Optional `Access-Control-Max-Age` value (seconds).
    pub max_age: Option<u64>,
}
impl CorsConfig {
    /// Builds a policy that accepts any origin, the common HTTP methods and any
    /// request header, without credentials, with a 24-hour preflight cache.
    pub fn permissive() -> Self {
        Self {
            allowed_origins: vec!["*".to_string()],
            allowed_methods: vec![
                Method::GET,
                Method::POST,
                Method::PUT,
                Method::DELETE,
                Method::PATCH,
                Method::HEAD,
                Method::OPTIONS,
            ],
            allowed_headers: vec!["*".to_string()],
            exposed_headers: Vec::new(),
            allow_credentials: false,
            max_age: Some(86400),
        }
    }

    /// Same as [`CorsConfig::permissive`] but restricted to `origins`.
    pub fn with_origins(origins: Vec<String>) -> Self {
        Self {
            allowed_origins: origins,
            ..Self::permissive()
        }
    }

    /// Writes the CORS response headers for a request whose `Origin` header is `origin`.
    ///
    /// The headers written are:
    /// - `Access-Control-Allow-Origin`: `*` for a wildcard policy without credentials;
    ///   otherwise the request origin, echoed only if it is allowed.
    /// - `Access-Control-Allow-Credentials`, `-Allow-Methods`, `-Allow-Headers`,
    ///   `-Expose-Headers` and `-Max-Age`, as configured.
    ///
    /// If the origin is not accepted, none of the `Access-Control-*` headers are written.
    ///
    /// Whenever the outcome depends on `Origin` (explicit allow-list, or wildcard
    /// with credentials), `Vary: Origin` is appended so shared caches key on it.
    pub fn apply_headers(&self, headers: &mut HeaderMap, origin: Option<&str>) {
        let wildcard = self.allowed_origins.iter().any(|a| a == "*");

        // Responses that differ per Origin must say so, otherwise a cache may replay
        // a response computed for one origin to a different origin.
        if !wildcard || self.allow_credentials {
            Self::add_vary_origin(headers);
        }

        // Browsers reject `Access-Control-Allow-Origin: *` on credentialed requests
        // (Fetch standard), so a credentialed wildcard policy echoes the concrete
        // request origin instead of the literal `*`.
        let origin_value = match origin {
            Some(o) if wildcard && self.allow_credentials => o,
            _ if wildcard => "*",
            Some(o) if self.allowed_origins.iter().any(|a| a == o) => o,
            _ => return,
        };
        if let Ok(val) = HeaderValue::from_str(origin_value) {
            headers.insert(http::header::ACCESS_CONTROL_ALLOW_ORIGIN, val);
        }

        // Never pair credentials with the literal wildcard. This branch is only
        // reachable without an Origin header, i.e. for a non-CORS request.
        if self.allow_credentials && origin_value != "*" {
            headers.insert(
                http::header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
                HeaderValue::from_static("true"),
            );
        }

        if !self.allowed_methods.is_empty() {
            let methods = self
                .allowed_methods
                .iter()
                .map(Method::as_str)
                .collect::<Vec<_>>()
                .join(", ");
            if let Ok(val) = HeaderValue::from_str(&methods) {
                headers.insert(http::header::ACCESS_CONTROL_ALLOW_METHODS, val);
            }
        }

        // NOTE(review): with credentials enabled, browsers treat `*` here as a literal
        // header name rather than a wildcard.
        if !self.allowed_headers.is_empty() {
            let hdrs = self.allowed_headers.join(", ");
            if let Ok(val) = HeaderValue::from_str(&hdrs) {
                headers.insert(http::header::ACCESS_CONTROL_ALLOW_HEADERS, val);
            }
        }

        if !self.exposed_headers.is_empty() {
            let hdrs = self.exposed_headers.join(", ");
            if let Ok(val) = HeaderValue::from_str(&hdrs) {
                headers.insert(http::header::ACCESS_CONTROL_EXPOSE_HEADERS, val);
            }
        }

        if let Some(max_age) = self.max_age {
            // Integers always form a valid header value; no string round-trip needed.
            headers.insert(http::header::ACCESS_CONTROL_MAX_AGE, HeaderValue::from(max_age));
        }
    }

    /// Builds an empty `204 No Content` preflight response carrying the CORS
    /// headers produced by [`CorsConfig::apply_headers`] for `origin`.
    ///
    /// # Errors
    /// Returns [`OxiHttpError::Http`] if the response builder fails.
    pub fn preflight_response(
        &self,
        origin: Option<&str>,
    ) -> Result<hyper::Response<Full<Bytes>>, OxiHttpError> {
        let mut resp = hyper::Response::builder()
            .status(StatusCode::NO_CONTENT)
            .body(Full::new(Bytes::new()))
            .map_err(|e| OxiHttpError::Http(Arc::new(e)))?;
        self.apply_headers(resp.headers_mut(), origin);
        Ok(resp)
    }

    /// Appends `Vary: Origin` unless an existing `Vary` value already covers it
    /// (either `Origin` itself, case-insensitively, or `*`).
    fn add_vary_origin(headers: &mut HeaderMap) {
        let covered = headers
            .get_all(http::header::VARY)
            .iter()
            .any(|v| match v.to_str() {
                Ok(s) => s
                    .split(',')
                    .map(str::trim)
                    .any(|t| t == "*" || t.eq_ignore_ascii_case("origin")),
                Err(_) => false,
            });
        if !covered {
            headers.append(http::header::VARY, HeaderValue::from_static("Origin"));
        }
    }
}
impl Default for CorsConfig {
fn default() -> Self {
Self::permissive()
}
}
/// Upper bound on the declared (`Content-Length`) size of a request body.
#[derive(Clone, Copy, Debug)]
pub struct BodyLimitConfig {
    /// Largest accepted declared body length, in bytes.
    pub max_bytes: u64,
}
impl BodyLimitConfig {
pub fn new(max_bytes: u64) -> Self {
Self { max_bytes }
}
pub fn check_content_length(&self, content_length: Option<u64>) -> Result<(), OxiHttpError> {
if let Some(len) = content_length {
if len > self.max_bytes {
return Err(OxiHttpError::Body(format!(
"request body too large: {} bytes exceeds limit of {} bytes",
len, self.max_bytes
)));
}
}
Ok(())
}
}
/// Token-bucket rate limiter keyed by an arbitrary string.
///
/// Each key owns a bucket holding up to `max_tokens` tokens. The bucket refills
/// continuously at `refill_rate` tokens per second, and every allowed request
/// spends one token. Clones share the same bucket table.
#[derive(Clone)]
pub struct RateLimiter {
    inner: Arc<Mutex<RateLimiterInner>>,
}

/// Table size below which idle buckets are never pruned.
const MIN_PRUNE_THRESHOLD: usize = 1024;

struct RateLimiterInner {
    buckets: HashMap<String, TokenBucket>,
    max_tokens: u32,
    refill_rate: f64,
    /// Table size at which the next prune pass runs (see `RateLimiter::check`).
    prune_at: usize,
}

struct TokenBucket {
    tokens: f64,
    last_refill: Instant,
}

impl TokenBucket {
    /// Token count at `now` after refilling since `last_refill`, capped at `max_tokens`.
    fn refilled(&self, now: Instant, refill_rate: f64, max_tokens: f64) -> f64 {
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        (self.tokens + elapsed * refill_rate).min(max_tokens)
    }
}

impl RateLimiter {
    /// Creates a limiter with buckets of `max_tokens` capacity, refilled at
    /// `refill_rate` tokens per second. New keys start with a full bucket.
    pub fn new(max_tokens: u32, refill_rate: f64) -> Self {
        Self {
            inner: Arc::new(Mutex::new(RateLimiterInner {
                buckets: HashMap::new(),
                max_tokens,
                refill_rate,
                prune_at: MIN_PRUNE_THRESHOLD,
            })),
        }
    }

    /// Consumes one token from `key`'s bucket.
    ///
    /// Returns `true` if a token was available (request allowed), `false` otherwise.
    pub async fn check(&self, key: &str) -> bool {
        let mut guard = self.inner.lock().await;
        let inner = &mut *guard;
        let now = Instant::now();
        let max_tokens = f64::from(inner.max_tokens);
        let refill_rate = inner.refill_rate;

        // Keys may be client-controlled (the pipeline keys on `X-Forwarded-For`), so
        // without pruning the table grows by one entry per distinct value forever.
        // A bucket that has refilled to capacity is indistinguishable from a fresh
        // one, so dropping it never changes a decision. Pruning runs only once the
        // table has doubled since the previous pass, keeping the amortized cost O(1).
        if inner.buckets.len() >= inner.prune_at && !inner.buckets.contains_key(key) {
            inner
                .buckets
                .retain(|_, bucket| bucket.refilled(now, refill_rate, max_tokens) < max_tokens);
            inner.prune_at = inner
                .buckets
                .len()
                .saturating_mul(2)
                .max(MIN_PRUNE_THRESHOLD);
        }

        let bucket = inner
            .buckets
            .entry(key.to_string())
            .or_insert(TokenBucket {
                tokens: max_tokens,
                last_refill: now,
            });
        bucket.tokens = bucket.refilled(now, refill_rate, max_tokens);
        bucket.last_refill = now;
        if bucket.tokens >= 1.0 {
            bucket.tokens -= 1.0;
            true
        } else {
            false
        }
    }

    /// Builds a `429 Too Many Requests` response with the body `Too Many Requests`.
    ///
    /// # Errors
    /// Returns [`OxiHttpError::Http`] if the response builder fails.
    pub fn too_many_requests() -> Result<hyper::Response<Full<Bytes>>, OxiHttpError> {
        hyper::Response::builder()
            .status(StatusCode::TOO_MANY_REQUESTS)
            .body(Full::new(Bytes::from("Too Many Requests")))
            .map_err(|e| OxiHttpError::Http(Arc::new(e)))
    }
}
/// Prints just the type name: the bucket table sits behind an async mutex,
/// which cannot be locked from a synchronous `fmt`.
impl std::fmt::Debug for RateLimiter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("RateLimiter")
    }
}
/// Configured per-request time budget.
///
/// NOTE(review): nothing in this file enforces it. The caller is presumably
/// expected to wrap request handling in a timeout and reply with
/// [`TimeoutConfig::timeout_response`] — confirm at the call site.
#[derive(Clone, Copy, Debug)]
pub struct TimeoutConfig {
    /// The time budget.
    pub duration: Duration,
}
impl TimeoutConfig {
pub fn new(duration: Duration) -> Self {
Self { duration }
}
pub fn timeout_response() -> Result<hyper::Response<Full<Bytes>>, OxiHttpError> {
hyper::Response::builder()
.status(StatusCode::REQUEST_TIMEOUT)
.body(Full::new(Bytes::from("Request Timeout")))
.map_err(|e| OxiHttpError::Http(Arc::new(e)))
}
}
/// Collection of optional request/response middleware stages.
///
/// Each stage is enabled through its `with_*` builder; disabled stages are `None`.
#[derive(Clone)]
pub struct MiddlewarePipeline {
    /// CORS policy used for preflight replies and response headers.
    pub cors: Option<CorsConfig>,
    /// Maximum accepted declared request body size.
    pub body_limit: Option<BodyLimitConfig>,
    /// Shared token-bucket limiter.
    pub rate_limiter: Option<RateLimiter>,
    /// Request time budget.
    pub timeout: Option<TimeoutConfig>,
    /// Set of the CORS allowed methods, filled by `with_cors`.
    // NOTE(review): never read in this file — candidate for removal or for a
    // method check.
    allowed_methods: HashSet<Method>,
}
impl MiddlewarePipeline {
    /// Creates a pipeline with every stage disabled.
    #[must_use]
    pub fn new() -> Self {
        Self {
            cors: None,
            body_limit: None,
            rate_limiter: None,
            timeout: None,
            allowed_methods: HashSet::new(),
        }
    }

    /// Enables CORS handling with `config`.
    #[must_use]
    pub fn with_cors(mut self, config: CorsConfig) -> Self {
        self.allowed_methods = config.allowed_methods.iter().cloned().collect();
        self.cors = Some(config);
        self
    }

    /// Rejects requests whose declared `Content-Length` exceeds `max_bytes`.
    #[must_use]
    pub fn with_body_limit(mut self, max_bytes: u64) -> Self {
        self.body_limit = Some(BodyLimitConfig::new(max_bytes));
        self
    }

    /// Enables rate limiting through `limiter`.
    #[must_use]
    pub fn with_rate_limiter(mut self, limiter: RateLimiter) -> Self {
        self.rate_limiter = Some(limiter);
        self
    }

    /// Records a per-request time budget.
    ///
    /// It is not enforced by [`MiddlewarePipeline::pre_handle`].
    #[must_use]
    pub fn with_timeout(mut self, duration: Duration) -> Self {
        self.timeout = Some(TimeoutConfig::new(duration));
        self
    }

    /// Runs the request-side stages.
    ///
    /// Returns `Some(response)` to short-circuit the handler, or `None` to continue.
    /// Stages run in this order:
    /// 1. CORS: any `OPTIONS` request is answered with a preflight response.
    /// 2. Rate limit: keyed on the raw `X-Forwarded-For` header value; a denied
    ///    request gets `429`.
    /// 3. Body limit: an oversized declared `Content-Length` gets `413`.
    ///
    /// The rate-limit key has two caveats:
    /// - `X-Forwarded-For` is client-supplied unless a trusted proxy rewrites it.
    /// - All requests without it share the single `"unknown"` bucket.
    pub async fn pre_handle(
        &self,
        req: &hyper::Request<hyper::body::Incoming>,
    ) -> Option<Result<hyper::Response<Full<Bytes>>, OxiHttpError>> {
        if req.method() == Method::OPTIONS {
            if let Some(cors) = &self.cors {
                let origin = req
                    .headers()
                    .get(http::header::ORIGIN)
                    .and_then(|v| v.to_str().ok());
                return Some(cors.preflight_response(origin));
            }
        }

        if let Some(limiter) = &self.rate_limiter {
            // Borrow the header text directly; `check` only needs `&str`.
            let key = req
                .headers()
                .get("x-forwarded-for")
                .and_then(|v| v.to_str().ok())
                .unwrap_or("unknown");
            if !limiter.check(key).await {
                return Some(RateLimiter::too_many_requests());
            }
        }

        if let Some(body_limit) = &self.body_limit {
            // An unparsable Content-Length is treated like a missing one.
            let content_length = req
                .headers()
                .get(http::header::CONTENT_LENGTH)
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<u64>().ok());
            if let Err(e) = body_limit.check_content_length(content_length) {
                return Some(
                    hyper::Response::builder()
                        .status(StatusCode::PAYLOAD_TOO_LARGE)
                        .body(Full::new(Bytes::from(e.to_string())))
                        .map_err(|e| OxiHttpError::Http(Arc::new(e))),
                );
            }
        }

        None
    }

    /// Runs the response-side stages: adds CORS headers for `origin` when CORS is enabled.
    pub fn post_handle(&self, resp: &mut hyper::Response<Full<Bytes>>, origin: Option<&str>) {
        if let Some(cors) = &self.cors {
            cors.apply_headers(resp.headers_mut(), origin);
        }
    }
}
impl Default for MiddlewarePipeline {
fn default() -> Self {
Self::new()
}
}
/// Summary view of the pipeline:
/// - CORS and rate limiting are shown as presence flags.
/// - Body limit and timeout are printed in full.
/// - `allowed_methods` is omitted.
impl std::fmt::Debug for MiddlewarePipeline {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_cors = self.cors.is_some();
        let has_rate_limiter = self.rate_limiter.is_some();
        let mut out = f.debug_struct("MiddlewarePipeline");
        out.field("cors", &has_cors);
        out.field("body_limit", &self.body_limit);
        out.field("rate_limiter", &has_rate_limiter);
        out.field("timeout", &self.timeout);
        out.finish()
    }
}