pub mod basic_auth;
pub mod cors;
pub mod gzip;
pub mod logger;
use anyhow::Error;
use futures::Future;
use hyper::Body;
use std::convert::TryFrom;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::Mutex;
use super::handler::RequestHandler;
use crate::config::Config;
use self::basic_auth::make_basic_auth_middleware;
use self::cors::make_cors_middleware;
use self::gzip::make_gzip_compression_middleware;
use self::logger::make_logger_middleware;
/// Request handle shared between the request handler and the middleware chain.
///
/// The request sits behind a `tokio::sync::Mutex`, so middleware can lock it across `.await`.
pub type Request<T> = Arc<Mutex<http::Request<T>>>;
/// Response handle produced by the request handler and passed to every after-middleware.
pub type Response<T> = Arc<Mutex<http::Response<T>>>;
/// Outcome of a single middleware step.
///
/// `Ok(())` lets the chain continue. `Err(response)` stops the chain, and `response` is
/// returned from `Middleware::handle` in place of the handler's response.
///
/// Note: this alias shadows `std::result::Result` inside this module, which is why
/// fallible functions here spell out `std::result::Result` explicitly.
pub type Result = std::result::Result<(), http::Response<Body>>;
/// Boxed async middleware that runs before the request handler.
///
/// It receives a shared handle to the incoming request.
pub type MiddlewareBefore =
Box<dyn Fn(Request<Body>) -> Pin<Box<dyn Future<Output = Result> + Send + Sync>> + Send + Sync>;
/// Boxed async middleware that runs after the request handler.
///
/// It receives shared handles to the request and to the handler's response.
pub type MiddlewareAfter = Box<
dyn Fn(Request<Body>, Response<Body>) -> Pin<Box<dyn Future<Output = Result> + Send + Sync>>
+ Send
+ Sync,
>;
/// Ordered chain of middleware wrapped around a request handler.
///
/// Usually built from the server configuration through `TryFrom<Arc<Config>>` and executed
/// by `Middleware::handle`. `Default` yields an empty chain that forwards straight to the
/// handler.
#[derive(Default)]
pub struct Middleware {
/// Middleware run, in registration order, before the request handler.
before: Vec<MiddlewareBefore>,
/// Middleware run, in registration order, after the request handler.
after: Vec<MiddlewareAfter>,
}
impl Middleware {
    /// Registers `middleware` to run before the request handler.
    ///
    /// Before-middlewares run in registration order. If one returns `Err(response)`, the
    /// handler and all remaining middleware are skipped and `response` is returned.
    pub fn before(&mut self, middleware: MiddlewareBefore) {
        self.before.push(middleware);
    }

    /// Registers `middleware` to run after the request handler.
    ///
    /// After-middlewares run in registration order with shared handles to the request and
    /// the handler's response. If one returns `Err(response)`, the remaining middleware are
    /// skipped and `response` replaces the handler's response.
    pub fn after(&mut self, middleware: MiddlewareAfter) {
        self.after.push(middleware);
    }

    /// Runs `request` through the before-middlewares, `handler`, and the after-middlewares.
    ///
    /// Returns the first short-circuit response produced by a middleware. If no middleware
    /// short-circuits, it returns the handler's response, possibly modified by the
    /// after-middlewares.
    ///
    /// # Panics
    ///
    /// Panics if another handle to the response is still alive once the chain finishes.
    /// That happens when the handler or a middleware stored a clone of it instead of only
    /// using it inside its own future.
    pub async fn handle(
        &self,
        request: http::Request<Body>,
        handler: Arc<dyn RequestHandler + Send + Sync>,
    ) -> http::Response<Body> {
        let request = Arc::new(Mutex::new(request));

        for middleware in &self.before {
            if let Err(err) = middleware(Arc::clone(&request)).await {
                return err;
            }
        }

        let response = handler.handle(Arc::clone(&request)).await;

        for middleware in &self.after {
            if let Err(err) = middleware(Arc::clone(&request), Arc::clone(&response)).await {
                return err;
            }
        }

        // Every middleware future has been awaited and dropped. This should therefore be the
        // last handle to the response, which lets us move it out of the shared `Arc`.
        Arc::try_unwrap(response)
            .expect("There's one or more reference/s being hold by a middleware chain.")
            .into_inner()
    }
}
impl TryFrom<Arc<Config>> for Middleware {
    type Error = Error;

    /// Builds the middleware chain enabled by `config`.
    ///
    /// - `basic_auth` is set: the basic-auth middleware is registered before the handler.
    /// - `cors` is set: the CORS middleware is registered after the handler.
    /// - `compression` is set with `gzip` enabled: gzip compression is registered after the
    ///   handler.
    /// - `logger` is `Some(true)`: the logger is registered after the handler.
    ///
    /// After-middlewares are registered, and therefore run, in the order listed above.
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok`.
    fn try_from(config: Arc<Config>) -> std::result::Result<Self, Self::Error> {
        let mut middleware = Middleware::default();

        // The factories receive owned configs, and `config` is shared behind an `Arc`, so
        // these options are cloned rather than moved out.
        if let Some(basic_auth_config) = config.basic_auth.clone() {
            middleware.before(make_basic_auth_middleware(basic_auth_config));
        }

        if let Some(cors_config) = config.cors.clone() {
            middleware.after(make_cors_middleware(cors_config));
        }

        // Only the `gzip` flag is needed, so borrow the compression config instead of
        // cloning the whole thing.
        if matches!(&config.compression, Some(compression) if compression.gzip) {
            middleware.after(make_gzip_compression_middleware());
        }

        // An absent `logger` setting means logging is disabled.
        if config.logger.unwrap_or(false) {
            middleware.after(make_logger_middleware());
        }

        Ok(middleware)
    }
}