use crate::{context::Context, error::Result, response::Response};
use http_body_util::Full;
use hyper::{body::Bytes, Response as HyperResponse};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
/// Continuation handed to a middleware.
///
/// It is a boxed, one-shot (`FnOnce`) closure that consumes the [`Context`] and returns a
/// pinned, boxed `Send` future resolving to `Result<Response>`. `MiddlewareChain` builds one
/// that runs the remaining middleware and, once those are exhausted, the handler. A
/// middleware that returns without calling it short-circuits everything after it.
pub type Next<'a> = Box<
dyn FnOnce(Context) -> Pin<Box<dyn Future<Output = Result<Response>> + Send + 'a>> + Send + 'a,
>;
/// Shared, type-erased middleware.
///
/// A middleware is a `Send + Sync` callable that receives the request [`Context`] and the
/// [`Next`] continuation and returns a pinned, boxed `Send` future resolving to
/// `Result<Response>`. The `for<'a>` bound means it must accept a continuation of any
/// lifetime. It is stored behind an `Arc`, so cloning only bumps a reference count.
pub type BoxedMiddleware = Arc<
dyn for<'a> Fn(Context, Next<'a>) -> Pin<Box<dyn Future<Output = Result<Response>> + Send + 'a>>
+ Send
+ Sync,
>;
/// Conversion into a [`BoxedMiddleware`].
pub trait IntoMiddleware {
/// Consumes `self` and returns it as a shared, type-erased middleware.
fn into_middleware(self) -> BoxedMiddleware;
}
/// Blanket implementation: any `Send + Sync + 'static` callable with the middleware call
/// signature can be converted into a [`BoxedMiddleware`].
impl<F> IntoMiddleware for F
where
F: for<'a> Fn(Context, Next<'a>) -> Pin<Box<dyn Future<Output = Result<Response>> + Send + 'a>>
+ Send
+ Sync
+ 'static,
{
/// Wraps the callable in an `Arc`; the result is coerced to the trait-object alias.
fn into_middleware(self) -> BoxedMiddleware {
Arc::new(self)
}
}
/// Ordered list of middleware that is run around a final handler.
pub struct MiddlewareChain {
// Middleware in execution order: index 0 runs first (outermost).
middleware: Vec<BoxedMiddleware>,
}
impl MiddlewareChain {
    /// Creates a chain with no middleware registered.
    pub fn new() -> Self {
        let middleware = Vec::new();
        Self { middleware }
    }

    /// Appends `middleware` to the end of the chain, so it runs after (inside) every
    /// middleware pushed before it.
    pub fn push(&mut self, middleware: BoxedMiddleware) {
        self.middleware.push(middleware);
    }

    /// Runs the chain for one request.
    ///
    /// The middleware are invoked in the order they were pushed, each receiving a
    /// continuation for the rest of the chain; `handler` is invoked once every
    /// middleware has called its continuation. The chain is consumed by the call.
    pub async fn execute<F, Fut>(self, ctx: Context, handler: F) -> Result<Response>
    where
        F: FnOnce(Context) -> Fut + Send + 'static,
        Fut: Future<Output = Result<Response>> + Send + 'static,
    {
        let terminal = Box::new(handler);
        self.execute_at(ctx, 0, terminal).await
    }

    /// Runs the middleware at position `index`, or the handler when `index` is past the
    /// end of the list.
    ///
    /// The returned future is boxed because the function is recursive: the continuation
    /// passed to the current middleware calls `execute_at` again for `index + 1`.
    fn execute_at<F, Fut>(
        self,
        ctx: Context,
        index: usize,
        handler: Box<F>,
    ) -> Pin<Box<dyn Future<Output = Result<Response>> + Send>>
    where
        F: FnOnce(Context) -> Fut + Send + 'static,
        Fut: Future<Output = Result<Response>> + Send + 'static,
    {
        Box::pin(async move {
            // Clone the `Arc` out of the list so `self` can be moved into the continuation.
            let current = self.middleware.get(index).cloned();
            match current {
                // Every middleware has run: hand the context to the handler.
                None => handler(ctx).await,
                Some(layer) => {
                    let next: Next =
                        Box::new(move |ctx| self.execute_at(ctx, index + 1, handler));
                    layer(ctx, next).await
                }
            }
        })
    }
}
/// The default chain is the empty chain produced by `MiddlewareChain::new`.
impl Default for MiddlewareChain {
fn default() -> Self {
Self::new()
}
}
pub mod builtin {
use super::*;
use std::time::Instant;
use tracing::{error, info};
pub fn logger() -> BoxedMiddleware {
Arc::new(|ctx, next| {
Box::pin(async move {
let method = ctx.req.method().clone();
let path = ctx.req.path().to_string();
let start = Instant::now();
info!("--> {} {}", method, path);
let result = next(ctx).await;
let duration = start.elapsed();
match &result {
Ok(response) => {
info!(
"<-- {} {} {} ({:?})",
method,
path,
response.status().as_u16(),
duration
);
}
Err(err) => {
error!("<-- {} {} ERROR: {} ({:?})", method, path, err, duration);
}
}
result
})
})
}
/// Builder for the CORS middleware; finish with `build`.
pub struct Cors {
// Sent verbatim as `Access-Control-Allow-Origin`.
allow_origin: String,
// Joined with ", " and sent as `Access-Control-Allow-Methods`.
allow_methods: Vec<String>,
// Joined with ", " and sent as `Access-Control-Allow-Headers`.
allow_headers: Vec<String>,
}
impl Cors {
pub fn new() -> Self {
Self {
allow_origin: "*".to_string(),
allow_methods: vec!["GET".to_string(), "POST".to_string()],
allow_headers: vec!["Content-Type".to_string()],
}
}
pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
self.allow_origin = origin.into();
self
}
pub fn allow_methods(mut self, methods: Vec<impl Into<String>>) -> Self {
self.allow_methods = methods.into_iter().map(|m| m.into()).collect();
self
}
pub fn allow_headers(mut self, headers: Vec<impl Into<String>>) -> Self {
self.allow_headers = headers.into_iter().map(|h| h.into()).collect();
self
}
pub fn build(self) -> BoxedMiddleware {
let origin = self.allow_origin;
let methods = self.allow_methods.join(", ");
let headers = self.allow_headers.join(", ");
Arc::new(move |ctx, next| {
let origin = origin.clone();
let methods = methods.clone();
let headers = headers.clone();
Box::pin(async move {
if ctx.req.method() == "OPTIONS" {
let response = HyperResponse::builder()
.status(204)
.header("Access-Control-Allow-Origin", origin)
.header("Access-Control-Allow-Methods", methods)
.header("Access-Control-Allow-Headers", headers)
.body(Full::new(Bytes::new()))
.unwrap();
return Ok(response);
}
ctx.header("Access-Control-Allow-Origin", origin).await;
ctx.header("Access-Control-Allow-Methods", methods).await;
ctx.header("Access-Control-Allow-Headers", headers).await;
next(ctx).await
})
})
}
}
/// The default configuration is the one produced by `Cors::new`.
impl Default for Cors {
fn default() -> Self {
Self::new()
}
}
/// CORS middleware built from the default [`Cors`] configuration.
pub fn cors() -> BoxedMiddleware {
    let config = Cors::default();
    config.build()
}
/// Middleware that passes `X-Powered-By: Ultimo` to `ctx.header` and then forwards the
/// request to `next`.
pub fn powered_by() -> BoxedMiddleware {
    const HEADER_NAME: &str = "X-Powered-By";
    const HEADER_VALUE: &str = "Ultimo";

    Arc::new(|ctx, next| {
        Box::pin(async move {
            ctx.header(HEADER_NAME, HEADER_VALUE).await;
            next(ctx).await
        })
    })
}
/// Middleware that passes a configurable `X-Powered-By` value to `ctx.header` and then
/// forwards the request to `next`.
///
/// * `name` - product name to advertise.
/// * `include_version` - when `true` the value is `"{name}/{version}"`, where the version
///   is this crate's `CARGO_PKG_VERSION` captured at compile time; when `false` the value
///   is `name` alone.
///
/// Note that, despite the function name, the only header set is `X-Powered-By`.
pub fn server_headers(name: impl Into<String>, include_version: bool) -> BoxedMiddleware {
    let name = name.into();
    // The value is computed once; `name` is moved (not cloned) when no version is wanted.
    let value = if include_version {
        format!("{}/{}", name, env!("CARGO_PKG_VERSION"))
    } else {
        name
    };
    Arc::new(move |ctx, next| {
        // Each request's future needs its own owned copy.
        let powered_by = value.clone();
        Box::pin(async move {
            ctx.header("X-Powered-By", powered_by).await;
            next(ctx).await
        })
    })
}
/// Builder for a middleware that adds security-related response headers.
///
/// An `Option` field set to `None` (or the `bool` set to `false`) means the corresponding
/// header is not emitted.
#[derive(Debug, Clone)]
pub struct SecurityHeaders {
// Value for `strict-transport-security`.
hsts: Option<String>,
// Value for `content-security-policy`.
csp: Option<String>,
// Value for `x-frame-options`.
frame_options: Option<String>,
// When true, `x-content-type-options: nosniff` is emitted.
content_type_options: bool,
// Value for `referrer-policy`.
referrer_policy: Option<String>,
// Value for `permissions-policy`.
permissions_policy: Option<String>,
}
/// Default header set: HSTS for one year including subdomains, framing denied,
/// `nosniff` enabled, a `strict-origin-when-cross-origin` referrer policy, a permissions
/// policy disabling geolocation, microphone and camera, and no content security policy.
impl Default for SecurityHeaders {
    fn default() -> Self {
        let hsts = String::from("max-age=31536000; includeSubDomains");
        let frame_options = String::from("DENY");
        let referrer_policy = String::from("strict-origin-when-cross-origin");
        let permissions_policy = String::from("geolocation=(), microphone=(), camera=()");

        Self {
            hsts: Some(hsts),
            csp: None,
            frame_options: Some(frame_options),
            content_type_options: true,
            referrer_policy: Some(referrer_policy),
            permissions_policy: Some(permissions_policy),
        }
    }
}
impl SecurityHeaders {
pub fn new() -> Self {
Self::default()
}
pub fn hsts(mut self, value: impl Into<String>) -> Self {
self.hsts = Some(value.into());
self
}
pub fn no_hsts(mut self) -> Self {
self.hsts = None;
self
}
pub fn csp(mut self, value: impl Into<String>) -> Self {
self.csp = Some(value.into());
self
}
pub fn frame_options(mut self, value: impl Into<String>) -> Self {
self.frame_options = Some(value.into());
self
}
pub fn referrer_policy(mut self, value: impl Into<String>) -> Self {
self.referrer_policy = Some(value.into());
self
}
pub fn permissions_policy(mut self, value: impl Into<String>) -> Self {
self.permissions_policy = Some(value.into());
self
}
pub fn no_content_type_options(mut self) -> Self {
self.content_type_options = false;
self
}
fn pairs(&self) -> Vec<(&'static str, String)> {
let mut out = Vec::new();
if let Some(v) = &self.hsts {
out.push(("strict-transport-security", v.clone()));
}
if let Some(v) = &self.csp {
out.push(("content-security-policy", v.clone()));
}
if let Some(v) = &self.frame_options {
out.push(("x-frame-options", v.clone()));
}
if self.content_type_options {
out.push(("x-content-type-options", "nosniff".to_string()));
}
if let Some(v) = &self.referrer_policy {
out.push(("referrer-policy", v.clone()));
}
if let Some(v) = &self.permissions_policy {
out.push(("permissions-policy", v.clone()));
}
out
}
pub fn build(self) -> BoxedMiddleware {
let pairs = Arc::new(self.pairs());
Arc::new(move |ctx, next| {
let pairs = pairs.clone();
Box::pin(async move {
let mut response = next(ctx).await?;
let headers = response.headers_mut();
for (name, value) in pairs.iter() {
let header_name = hyper::header::HeaderName::from_static(name);
if !headers.contains_key(&header_name) {
if let Ok(hv) = hyper::header::HeaderValue::from_str(value) {
headers.insert(header_name, hv);
}
}
}
Ok(response)
})
})
}
}
/// Security-headers middleware built from the default [`SecurityHeaders`] configuration.
pub fn security_headers() -> BoxedMiddleware {
    let config = SecurityHeaders::default();
    config.build()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles only when `T: Send`.
    fn assert_send<T: Send>() {}

    /// Compiles only when `T: Sync`.
    fn assert_sync<T: Sync>() {}

    /// Middleware that forwards straight to `next` without touching anything.
    fn passthrough() -> BoxedMiddleware {
        Arc::new(|ctx, next| Box::pin(async move { next(ctx).await }))
    }

    #[test]
    fn test_middleware_chain_creation() {
        let chain = MiddlewareChain::new();
        assert_eq!(chain.middleware.len(), 0);
    }

    #[test]
    fn test_middleware_chain_push() {
        let mut chain = MiddlewareChain::new();
        let middleware = passthrough();
        chain.push(middleware.clone());
        assert_eq!(chain.middleware.len(), 1);
        chain.push(middleware);
        assert_eq!(chain.middleware.len(), 2);
    }

    #[test]
    fn test_cors_builder_creation() {
        let _cors = builtin::Cors::default();
        let _cors2 = builtin::Cors::new();
    }

    #[test]
    fn test_cors_builder_chaining() {
        let cors = builtin::Cors::new()
            .allow_origin("https://example.com")
            .allow_methods(vec!["GET", "POST"])
            .allow_headers(vec!["Authorization"]);
        let _middleware = cors.build();
    }

    #[test]
    fn test_cors_convenience_function() {
        let _cors = builtin::cors();
    }

    #[test]
    fn test_logger_convenience_function() {
        let _logger = builtin::logger();
    }

    #[test]
    fn test_powered_by_convenience_function() {
        let _powered_by = builtin::powered_by();
    }

    #[test]
    fn test_middleware_chain_default() {
        let chain1 = MiddlewareChain::default();
        let chain2 = MiddlewareChain::new();
        assert_eq!(chain1.middleware.len(), chain2.middleware.len());
    }

    #[test]
    fn test_boxed_middleware_creation() {
        // A multi-statement closure body must coerce to the alias just like a one-liner.
        let middleware: BoxedMiddleware = Arc::new(|ctx, next| {
            Box::pin(async move {
                let result = next(ctx).await;
                drop(&result);
                result
            })
        });
        assert_eq!(Arc::strong_count(&middleware), 1);
    }

    #[test]
    fn test_middleware_passthrough() {
        let _passthrough = passthrough();
    }

    #[test]
    fn test_cors_multiple_methods() {
        let cors =
            builtin::Cors::new().allow_methods(vec!["GET", "POST", "PUT", "PATCH", "DELETE"]);
        let _middleware = cors.build();
    }

    #[test]
    fn test_cors_multiple_headers() {
        let cors = builtin::Cors::new().allow_headers(vec![
            "Content-Type",
            "Authorization",
            "X-Custom-Header",
        ]);
        let _middleware = cors.build();
    }

    #[test]
    fn test_cors_custom_origin() {
        let cors = builtin::Cors::new().allow_origin("https://app.example.com");
        let _middleware = cors.build();
    }

    #[test]
    fn test_cors_builder_defaults() {
        let cors = builtin::Cors::default();
        let _middleware = cors.build();
    }

    #[test]
    fn test_server_headers_builder() {
        let _middleware = builtin::server_headers("CustomServer", false);
        let _middleware_with_version = builtin::server_headers("Ultimo", true);
    }

    #[test]
    fn test_security_headers_builder() {
        let _default = builtin::security_headers();
        let _custom = builtin::SecurityHeaders::new()
            .no_hsts()
            .csp("default-src 'self'")
            .frame_options("SAMEORIGIN")
            .no_content_type_options()
            .build();
    }

    #[test]
    fn test_cors_origin_string_conversion() {
        let cors1 = builtin::Cors::new().allow_origin("https://example.com");
        let cors2 = builtin::Cors::new().allow_origin(String::from("https://test.com"));
        let _m1 = cors1.build();
        let _m2 = cors2.build();
    }

    #[test]
    fn test_cors_methods_string_conversion() {
        let cors = builtin::Cors::new().allow_methods(vec!["GET", "POST"]);
        let _middleware = cors.build();
    }

    #[test]
    fn test_cors_headers_string_conversion() {
        let cors = builtin::Cors::new().allow_headers(vec!["Content-Type"]);
        let _middleware = cors.build();
    }

    #[test]
    fn test_middleware_arc_clone() {
        // Comparing the counts of two handles to the same allocation is always true, so
        // assert the concrete values instead: cloning shares the allocation.
        let middleware = passthrough();
        assert_eq!(Arc::strong_count(&middleware), 1);

        let cloned = Arc::clone(&middleware);
        assert_eq!(Arc::strong_count(&middleware), 2);
        assert_eq!(Arc::strong_count(&cloned), 2);

        drop(cloned);
        assert_eq!(Arc::strong_count(&middleware), 1);
    }

    #[test]
    fn test_middleware_is_send_sync() {
        assert_send::<BoxedMiddleware>();
        assert_sync::<BoxedMiddleware>();
    }

    #[test]
    fn test_middleware_chain_is_send() {
        assert_send::<MiddlewareChain>();
    }

    #[test]
    fn test_cors_struct_is_send_sync() {
        assert_send::<builtin::Cors>();
        assert_sync::<builtin::Cors>();
    }
}