use elif_http::middleware::MiddlewarePipeline;
use crate::{
middleware::{cors::CorsMiddleware, csrf::CsrfMiddleware},
config::{CorsConfig, CsrfConfig},
};
#[derive(Debug, Default)]
pub struct SecurityMiddlewareBuilder {
cors_config: Option<CorsConfig>,
csrf_config: Option<CsrfConfig>,
}
impl SecurityMiddlewareBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn with_cors(mut self, config: CorsConfig) -> Self {
self.cors_config = Some(config);
self
}
pub fn with_cors_permissive(mut self) -> Self {
self.cors_config = Some(CorsConfig::default());
self
}
pub fn with_csrf(mut self, config: CsrfConfig) -> Self {
self.csrf_config = Some(config);
self
}
pub fn with_csrf_default(mut self) -> Self {
self.csrf_config = Some(CsrfConfig::default());
self
}
pub fn build(self) -> MiddlewarePipeline {
let mut pipeline = MiddlewarePipeline::new();
if let Some(cors_config) = self.cors_config {
let cors_middleware = CorsMiddleware::new(cors_config);
pipeline = pipeline.add(cors_middleware);
}
if let Some(csrf_config) = self.csrf_config {
let csrf_middleware = CsrfMiddleware::new(csrf_config);
pipeline = pipeline.add(csrf_middleware);
}
pipeline
}
}
pub fn basic_security_pipeline() -> MiddlewarePipeline {
SecurityMiddlewareBuilder::new()
.with_cors_permissive()
.with_csrf_default()
.build()
}
pub fn strict_security_pipeline(allowed_origins: Vec<String>) -> MiddlewarePipeline {
use std::collections::HashSet;
let cors_config = CorsConfig {
allowed_origins: Some(allowed_origins.into_iter().collect::<HashSet<_>>()),
allow_credentials: true,
max_age: Some(300), ..CorsConfig::default()
};
let csrf_config = CsrfConfig {
secure_cookie: true,
token_lifetime: 3600, ..CsrfConfig::default()
};
SecurityMiddlewareBuilder::new()
.with_cors(cors_config)
.with_csrf(csrf_config)
.build()
}
pub fn development_security_pipeline() -> MiddlewarePipeline {
let cors_config = CorsConfig {
allowed_origins: None, allow_credentials: false,
..CorsConfig::default()
};
let csrf_config = CsrfConfig {
secure_cookie: false, token_lifetime: 7200, ..CsrfConfig::default()
};
SecurityMiddlewareBuilder::new()
.with_cors(cors_config)
.with_csrf(csrf_config)
.build()
}
#[cfg(test)]
mod tests {
use super::*;
use axum::{extract::Request, http::Method, body::Body};
#[tokio::test]
async fn test_basic_security_pipeline() {
let pipeline = basic_security_pipeline();
assert_eq!(pipeline.len(), 2);
assert_eq!(pipeline.names(), vec!["CorsMiddleware", "CsrfMiddleware"]);
}
#[tokio::test]
async fn test_security_middleware_builder() {
let cors_config = CorsConfig::default();
let csrf_config = CsrfConfig::default();
let pipeline = SecurityMiddlewareBuilder::new()
.with_cors(cors_config)
.with_csrf(csrf_config)
.build();
assert_eq!(pipeline.len(), 2);
assert!(pipeline.names().contains(&"CorsMiddleware"));
assert!(pipeline.names().contains(&"CsrfMiddleware"));
}
#[tokio::test]
async fn test_cors_only_pipeline() {
let pipeline = SecurityMiddlewareBuilder::new()
.with_cors_permissive()
.build();
assert_eq!(pipeline.len(), 1);
assert_eq!(pipeline.names(), vec!["CorsMiddleware"]);
}
#[tokio::test]
async fn test_csrf_only_pipeline() {
let pipeline = SecurityMiddlewareBuilder::new()
.with_csrf_default()
.build();
assert_eq!(pipeline.len(), 1);
assert_eq!(pipeline.names(), vec!["CsrfMiddleware"]);
}
#[tokio::test]
async fn test_security_pipeline_processing() {
let pipeline = basic_security_pipeline();
let request = Request::builder()
.method(Method::GET)
.uri("/")
.header("Origin", "https://example.com")
.body(Body::empty())
.unwrap();
let result = pipeline.process_request(request).await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_strict_security_pipeline() {
let allowed_origins = vec!["https://trusted.com".to_string()];
let pipeline = strict_security_pipeline(allowed_origins);
assert_eq!(pipeline.len(), 2);
let request = Request::builder()
.method(Method::GET)
.uri("/")
.header("Origin", "https://trusted.com")
.body(Body::empty())
.unwrap();
let result = pipeline.process_request(request).await;
assert!(result.is_ok());
let request = Request::builder()
.method(Method::GET)
.uri("/")
.header("Origin", "https://evil.com")
.body(Body::empty())
.unwrap();
let result = pipeline.process_request(request).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_development_security_pipeline() {
let pipeline = development_security_pipeline();
assert_eq!(pipeline.len(), 2);
let request = Request::builder()
.method(Method::GET)
.uri("/")
.header("Origin", "http://localhost:3000")
.body(Body::empty())
.unwrap();
let result = pipeline.process_request(request).await;
assert!(result.is_ok());
}
}