use crate::{context::Context, error::Result, response::Response};
use http_body_util::Full;
use hyper::{body::Bytes, Response as HyperResponse};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
/// Continuation handed to a middleware: calling it with the (possibly
/// modified) [`Context`] runs the remaining middleware and, after the last
/// one, the route handler.
///
/// It is a boxed `FnOnce`, so a middleware can invoke it at most once; a
/// middleware that never calls it prevents the rest of the chain (and the
/// handler) from running.
pub type Next<'a> = Box<
dyn FnOnce(Context) -> Pin<Box<dyn Future<Output = Result<Response>> + Send + 'a>> + Send + 'a,
>;
/// Type-erased middleware as stored in a [`MiddlewareChain`].
///
/// A middleware receives the request [`Context`] plus the [`Next`]
/// continuation and returns a boxed `Send` future resolving to the response.
/// The `for<'a>` bound means it must accept a `Next` of any lifetime, and the
/// future it returns may live exactly as long as that `Next`. The [`Arc`]
/// wrapper makes cloning a refcount bump, and `Send + Sync` lets one instance
/// be shared across threads.
pub type BoxedMiddleware = Arc<
dyn for<'a> Fn(Context, Next<'a>) -> Pin<Box<dyn Future<Output = Result<Response>> + Send + 'a>>
+ Send
+ Sync,
>;
/// Conversion into a type-erased, shareable [`BoxedMiddleware`].
pub trait IntoMiddleware {
    /// Consumes `self` and returns it behind an [`Arc`] as a [`BoxedMiddleware`].
    fn into_middleware(self) -> BoxedMiddleware;
}

/// Any `Send + Sync + 'static` callable that has the middleware signature for
/// every lifetime of [`Next`] converts by simply being wrapped in an [`Arc`].
impl<M> IntoMiddleware for M
where
    M: Send + Sync + 'static,
    M: for<'a> Fn(
        Context,
        Next<'a>,
    ) -> Pin<Box<dyn Future<Output = Result<Response>> + Send + 'a>>,
{
    fn into_middleware(self) -> BoxedMiddleware {
        // Unsizing coercion from `Arc<M>` to the trait-object alias happens at return.
        Arc::<M>::new(self)
    }
}
/// An ordered list of middleware wrapped around a final request handler.
///
/// Middleware run in the order they were [`push`](MiddlewareChain::push)ed:
/// the first one pushed is the outermost layer, and the handler given to
/// [`execute`](MiddlewareChain::execute) runs only once the last middleware
/// calls its [`Next`].
pub struct MiddlewareChain {
    middleware: Vec<BoxedMiddleware>,
}

impl MiddlewareChain {
    /// Creates an empty chain; executing it calls the handler directly.
    #[must_use]
    pub fn new() -> Self {
        Self {
            middleware: Vec::new(),
        }
    }

    /// Appends `middleware` as the new innermost layer: it runs after every
    /// middleware pushed before it and before the handler.
    pub fn push(&mut self, middleware: BoxedMiddleware) {
        self.middleware.push(middleware);
    }

    /// Runs `ctx` through every middleware in order and finally through
    /// `handler`, consuming the chain.
    ///
    /// # Errors
    ///
    /// Returns whatever error a middleware or the handler produces; this
    /// method adds no errors of its own.
    pub async fn execute<F, Fut>(self, ctx: Context, handler: F) -> Result<Response>
    where
        F: FnOnce(Context) -> Fut + Send + 'static,
        Fut: Future<Output = Result<Response>> + Send + 'static,
    {
        self.execute_at(ctx, 0, handler).await
    }

    /// Runs the middleware at `index`, or `handler` once `index` is past the
    /// end, giving the middleware a [`Next`] that resumes at `index + 1`.
    ///
    /// The future is boxed because this function recurses through the `Next`
    /// closure; an unboxed recursive future would have an infinite size.
    fn execute_at<F, Fut>(
        self,
        ctx: Context,
        index: usize,
        handler: F,
    ) -> Pin<Box<dyn Future<Output = Result<Response>> + Send>>
    where
        F: FnOnce(Context) -> Fut + Send + 'static,
        Fut: Future<Output = Result<Response>> + Send + 'static,
    {
        Box::pin(async move {
            // Clone the Arc out in its own statement so no borrow of `self`
            // is alive when `self` is moved into the continuation below.
            let current = self.middleware.get(index).cloned();
            match current {
                // Past the last middleware: hand off to the route handler.
                None => handler(ctx).await,
                Some(current) => {
                    let next_index = index + 1;
                    // `self` and `handler` move into the continuation, so the
                    // rest of the chain runs only if this middleware calls it.
                    let next: Next =
                        Box::new(move |ctx| self.execute_at(ctx, next_index, handler));
                    current(ctx, next).await
                }
            }
        })
    }
}

impl Default for MiddlewareChain {
    fn default() -> Self {
        Self::new()
    }
}
pub mod builtin {
use super::*;
use std::time::Instant;
use tracing::{error, info};
pub fn logger() -> BoxedMiddleware {
Arc::new(|ctx, next| {
Box::pin(async move {
let method = ctx.req.method().clone();
let path = ctx.req.path().to_string();
let start = Instant::now();
info!("--> {} {}", method, path);
let result = next(ctx).await;
let duration = start.elapsed();
match &result {
Ok(response) => {
info!(
"<-- {} {} {} ({:?})",
method,
path,
response.status().as_u16(),
duration
);
}
Err(err) => {
error!("<-- {} {} ERROR: {} ({:?})", method, path, err, duration);
}
}
result
})
})
}
pub struct Cors {
allow_origin: String,
allow_methods: Vec<String>,
allow_headers: Vec<String>,
}
impl Cors {
pub fn new() -> Self {
Self {
allow_origin: "*".to_string(),
allow_methods: vec!["GET".to_string(), "POST".to_string()],
allow_headers: vec!["Content-Type".to_string()],
}
}
pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
self.allow_origin = origin.into();
self
}
pub fn allow_methods(mut self, methods: Vec<impl Into<String>>) -> Self {
self.allow_methods = methods.into_iter().map(|m| m.into()).collect();
self
}
pub fn allow_headers(mut self, headers: Vec<impl Into<String>>) -> Self {
self.allow_headers = headers.into_iter().map(|h| h.into()).collect();
self
}
pub fn build(self) -> BoxedMiddleware {
let origin = self.allow_origin;
let methods = self.allow_methods.join(", ");
let headers = self.allow_headers.join(", ");
Arc::new(move |ctx, next| {
let origin = origin.clone();
let methods = methods.clone();
let headers = headers.clone();
Box::pin(async move {
if ctx.req.method() == "OPTIONS" {
let response = HyperResponse::builder()
.status(204)
.header("Access-Control-Allow-Origin", origin)
.header("Access-Control-Allow-Methods", methods)
.header("Access-Control-Allow-Headers", headers)
.body(Full::new(Bytes::new()))
.unwrap();
return Ok(response);
}
ctx.header("Access-Control-Allow-Origin", origin).await;
ctx.header("Access-Control-Allow-Methods", methods).await;
ctx.header("Access-Control-Allow-Headers", headers).await;
next(ctx).await
})
})
}
}
impl Default for Cors {
fn default() -> Self {
Self::new()
}
}
pub fn cors() -> BoxedMiddleware {
Cors::new().build()
}
pub fn powered_by() -> BoxedMiddleware {
Arc::new(|ctx, next| {
Box::pin(async move {
ctx.header("X-Powered-By", "Ultimo").await;
next(ctx).await
})
})
}
pub fn server_headers(name: impl Into<String>, include_version: bool) -> BoxedMiddleware {
let name = name.into();
let version = if include_version {
format!("{}/{}", name, env!("CARGO_PKG_VERSION"))
} else {
name.clone()
};
Arc::new(move |ctx, next| {
let powered_by = version.clone();
Box::pin(async move {
ctx.header("X-Powered-By", powered_by).await;
next(ctx).await
})
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pass-through middleware written as a named `fn`, used to check that
    /// ordinary functions satisfy the `IntoMiddleware` blanket impl.
    fn passthrough<'a>(
        ctx: Context,
        next: Next<'a>,
    ) -> Pin<Box<dyn Future<Output = Result<Response>> + Send + 'a>> {
        Box::pin(async move { next(ctx).await })
    }

    #[test]
    fn test_middleware_chain_creation() {
        let chain = MiddlewareChain::new();
        assert_eq!(chain.middleware.len(), 0);
    }

    #[test]
    fn test_middleware_chain_push() {
        let mut chain = MiddlewareChain::new();
        let middleware: BoxedMiddleware =
            Arc::new(|ctx, next| Box::pin(async move { next(ctx).await }));
        chain.push(Arc::clone(&middleware));
        assert_eq!(chain.middleware.len(), 1);
        chain.push(middleware);
        assert_eq!(chain.middleware.len(), 2);
        // Both entries are handles to the single allocation pushed twice.
        assert_eq!(Arc::strong_count(&chain.middleware[0]), 2);
    }

    #[test]
    fn test_fn_item_into_middleware() {
        let mut chain = MiddlewareChain::new();
        chain.push(passthrough.into_middleware());
        assert_eq!(chain.middleware.len(), 1);
    }

    #[test]
    fn test_cors_builder_creation() {
        let _cors = builtin::Cors::default();
        let _cors2 = builtin::Cors::new();
    }

    #[test]
    fn test_cors_builder_chaining() {
        let cors = builtin::Cors::new()
            .allow_origin("https://example.com")
            .allow_methods(vec!["GET", "POST"])
            .allow_headers(vec!["Authorization"]);
        let _middleware = cors.build();
    }

    #[test]
    fn test_cors_convenience_function() {
        let _cors = builtin::cors();
    }

    #[test]
    fn test_logger_convenience_function() {
        let _logger = builtin::logger();
    }

    #[test]
    fn test_powered_by_convenience_function() {
        let _powered_by = builtin::powered_by();
    }

    #[test]
    fn test_middleware_chain_default() {
        let chain1 = MiddlewareChain::default();
        let chain2 = MiddlewareChain::new();
        assert_eq!(chain1.middleware.len(), chain2.middleware.len());
    }

    #[test]
    fn test_boxed_middleware_creation() {
        // Forward the continuation's boxed future directly, without wrapping
        // it in another async block.
        let _middleware: BoxedMiddleware = Arc::new(|ctx, next| next(ctx));
    }

    #[test]
    fn test_middleware_passthrough() {
        let _passthrough: BoxedMiddleware =
            Arc::new(|ctx, next| Box::pin(async move { next(ctx).await }));
    }

    #[test]
    fn test_cors_multiple_methods() {
        let cors =
            builtin::Cors::new().allow_methods(vec!["GET", "POST", "PUT", "PATCH", "DELETE"]);
        let _middleware = cors.build();
    }

    #[test]
    fn test_cors_multiple_headers() {
        let cors = builtin::Cors::new().allow_headers(vec![
            "Content-Type",
            "Authorization",
            "X-Custom-Header",
        ]);
        let _middleware = cors.build();
    }

    #[test]
    fn test_cors_custom_origin() {
        let cors = builtin::Cors::new().allow_origin("https://app.example.com");
        let _middleware = cors.build();
    }

    #[test]
    fn test_cors_builder_defaults() {
        let cors = builtin::Cors::default();
        let _middleware = cors.build();
    }

    #[test]
    fn test_server_headers_builder() {
        let _middleware = builtin::server_headers("CustomServer", false);
        let _middleware_with_version = builtin::server_headers("Ultimo", true);
    }

    #[test]
    fn test_cors_origin_string_conversion() {
        let cors1 = builtin::Cors::new().allow_origin("https://example.com");
        let cors2 = builtin::Cors::new().allow_origin(String::from("https://test.com"));
        let _m1 = cors1.build();
        let _m2 = cors2.build();
    }

    #[test]
    fn test_cors_methods_string_conversion() {
        let cors = builtin::Cors::new().allow_methods(vec!["GET", "POST"]);
        let _middleware = cors.build();
    }

    #[test]
    fn test_cors_headers_string_conversion() {
        let cors = builtin::Cors::new().allow_headers(vec!["Content-Type"]);
        let _middleware = cors.build();
    }

    #[test]
    fn test_middleware_arc_clone() {
        let middleware: BoxedMiddleware =
            Arc::new(|ctx, next| Box::pin(async move { next(ctx).await }));
        let cloned = Arc::clone(&middleware);
        // Cloning shares the allocation rather than copying the closure.
        assert_eq!(Arc::strong_count(&middleware), 2);
        drop(cloned);
        assert_eq!(Arc::strong_count(&middleware), 1);
    }

    #[test]
    fn test_middleware_is_send_sync() {
        fn assert_send<T: Send>() {}
        fn assert_sync<T: Sync>() {}
        assert_send::<BoxedMiddleware>();
        assert_sync::<BoxedMiddleware>();
    }

    #[test]
    fn test_middleware_chain_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<MiddlewareChain>();
    }

    #[test]
    fn test_cors_struct_is_send_sync() {
        fn assert_send<T: Send>() {}
        fn assert_sync<T: Sync>() {}
        assert_send::<builtin::Cors>();
        assert_sync::<builtin::Cors>();
    }
}