use std::sync::Arc;
use crate::app::AppInner;
use crate::body::ReqBody;
use crate::error::{Error, Result};
use crate::response::Response;
use crate::router::BoxFuture;
pub mod body_limit;
pub mod compression;
pub mod cors;
pub mod https_redirect;
pub mod proxy_headers;
pub mod request_id;
pub mod security_headers;
pub mod timeout;
pub mod trace;
pub mod trusted_host;
pub use body_limit::BodyLimit;
pub use compression::Compression;
pub use cors::Cors;
pub use https_redirect::HttpsRedirect;
pub use proxy_headers::ProxyHeaders;
pub use request_id::RequestId;
pub use security_headers::SecurityHeaders;
pub use timeout::Timeout;
pub use trace::Trace;
pub use trusted_host::TrustedHost;
/// The HTTP request type that flows through the middleware chain, carrying a [`ReqBody`] body.
pub type Request = http::Request<ReqBody>;
/// What `resolve_duplicates` does when a middleware is registered whose
/// [`Middleware::name`] equals that of a middleware already kept in the stack.
///
/// The policy consulted is the one reported by the *newly registered* entry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DuplicatePolicy {
/// Keep both registrations silently.
Allow,
/// Keep both registrations and print a warning to stderr.
Warn,
/// Abort stack resolution with an error carrying the code `DUPLICATE_MIDDLEWARE`.
Reject,
/// Overwrite the first earlier registration with the same name, keeping that
/// earlier position in the stack.
Replace,
}
/// A request/response interceptor in the middleware chain.
///
/// Implementors receive the request together with a [`Next`] handle for the rest
/// of the chain. Calling `next.run(request)` continues toward the application;
/// returning without calling it short-circuits, and the route handler is never reached.
pub trait Middleware: Send + Sync + 'static {
/// Processes `request`, optionally delegating to `next`.
///
/// The returned future is `'static`, so it must own everything it uses; copy or
/// clone any needed fields out of `self` before building the `async move` block.
fn handle(&self, request: Request, next: Next) -> BoxFuture<'static, Result<Response>>;
/// Identity used to detect duplicate registrations.
///
/// Defaults to [`std::any::type_name`] of the implementing type. The standard
/// library does not guarantee the exact format of that string, so override this
/// when a stable identity is required.
fn name(&self) -> &'static str {
std::any::type_name::<Self>()
}
/// How to treat this middleware when one with the same [`name`](Middleware::name)
/// is already registered. Defaults to [`DuplicatePolicy::Allow`].
fn duplicate_policy(&self) -> DuplicatePolicy {
DuplicatePolicy::Allow
}
}
/// Cursor over the remaining middleware for one request.
///
/// A `Next` is passed to every [`Middleware::handle`] call. [`Next::run`] takes
/// `self` by value: it invokes the middleware at the cursor's position, or, once
/// the stack is exhausted, dispatches the request to the application.
pub struct Next {
    /// Application and middleware stack, shared by all cursors of this chain.
    state: Arc<NextState>,
    /// Position in `state.stack` of the middleware to run next.
    index: usize,
}

/// Data shared by every [`Next`] cursor created from one call to [`Next::new`].
struct NextState {
    app: Arc<AppInner>,
    stack: Arc<[Arc<dyn Middleware>]>,
}

impl Next {
    /// Creates a cursor positioned at the first middleware of `stack`.
    pub(crate) fn new(app: Arc<AppInner>, stack: Arc<[Arc<dyn Middleware>]>) -> Self {
        let shared = Arc::new(NextState { app, stack });
        Self {
            index: 0,
            state: shared,
        }
    }

    /// Runs the rest of the chain for `request`.
    ///
    /// If a middleware remains at the current position, it is handed the request
    /// together with a cursor advanced by one. Otherwise the request goes to the
    /// application's `dispatch` and its response is wrapped in `Ok`.
    pub fn run(self, request: Request) -> BoxFuture<'static, Result<Response>> {
        let Next { state, index } = self;
        if let Some(current) = state.stack.get(index).cloned() {
            let rest = Next {
                state,
                index: index + 1,
            };
            return current.handle(request, rest);
        }
        // Every middleware has run: hand the request to the application itself.
        let app = state.app.clone();
        Box::pin(async move { Ok(app.dispatch(request).await) })
    }
}
pub(crate) fn resolve_duplicates(
middleware: Vec<Arc<dyn Middleware>>,
) -> Result<Vec<Arc<dyn Middleware>>> {
let mut resolved: Vec<Arc<dyn Middleware>> = Vec::with_capacity(middleware.len());
for entry in middleware {
let name = entry.name();
let existing = resolved.iter().position(|m| m.name() == name);
match (existing, entry.duplicate_policy()) {
(None, _) | (Some(_), DuplicatePolicy::Allow) => resolved.push(entry),
(Some(_), DuplicatePolicy::Warn) => {
eprintln!(
"tork: middleware `{}` registered more than once",
short_name(name)
);
resolved.push(entry);
}
(Some(index), DuplicatePolicy::Replace) => resolved[index] = entry,
(Some(_), DuplicatePolicy::Reject) => {
let short = short_name(name);
return Err(Error::internal(format!(
"Duplicate middleware detected: {short}\n\
{short} middleware can only be registered once per scope.\n\
Already registered at app level."
))
.with_code("DUPLICATE_MIDDLEWARE"));
}
}
}
Ok(resolved)
}
/// Shortens a fully qualified type name for use in diagnostics.
///
/// Drops the leading module path but keeps any generic arguments, e.g.
/// `tork::middleware::Cors` becomes `Cors` and `app::Wrap<app::Inner>` becomes
/// `Wrap<app::Inner>`. Only `::` separators before the first `<` are considered,
/// so paths inside generic arguments are left untouched. A name without a path
/// is returned unchanged.
fn short_name(name: &str) -> &str {
    // Restrict the search to the outer path. Splitting the whole string on `::`
    // would return the tail of the last generic argument (e.g. `Inner>`).
    let outer_path = name.find('<').map_or(name, |lt| &name[..lt]);
    match outer_path.rfind("::") {
        // `outer_path` is a prefix of `name`, so `sep` is a valid index into `name`.
        Some(sep) => &name[sep + "::".len()..],
        None => name,
    }
}
// Unit tests for the middleware chain (`Next`) and for `resolve_duplicates`.
#[cfg(test)]
mod tests {
use super::*;
use crate::app::{App, AppInner};
use crate::body::box_body;
use crate::constants::TEXT_PLAIN_UTF8;
use crate::extract::RequestContext;
use crate::response::bytes_response;
use crate::router::{HandlerFn, Route, Router};
use crate::{Method, StatusCode};
use bytes::Bytes;
use http::HeaderValue;
use http_body_util::{BodyExt, Full};
/// Test middleware that appends its label to the `x-mark` response header
/// after the inner chain has produced a response.
struct Mark(&'static str);
impl Middleware for Mark {
fn handle(&self, request: Request, next: Next) -> BoxFuture<'static, Result<Response>> {
// Copy the label out of `self` so the `'static` future does not borrow `self`.
let header = self.0;
Box::pin(async move {
let mut response = next.run(request).await?;
response
.headers_mut()
.append("x-mark", HeaderValue::from_static(header));
Ok(response)
})
}
}
/// Test middleware that never calls `next` and always fails with a forbidden error.
struct ShortCircuit;
impl Middleware for ShortCircuit {
fn handle(&self, _request: Request, _next: Next) -> BoxFuture<'static, Result<Response>> {
Box::pin(async { Err(crate::Error::forbidden("blocked")) })
}
}
/// Handler that always responds `200 OK` with a plain-text body of `pong`.
fn pong_handler() -> HandlerFn {
std::sync::Arc::new(
|_ctx: RequestContext| -> BoxFuture<'static, Result<Response>> {
Box::pin(async {
Ok(bytes_response(
StatusCode::OK,
TEXT_PLAIN_UTF8,
Bytes::from_static(b"pong"),
))
})
},
)
}
/// Builds an app with a single `GET /` route served by [`pong_handler`], applying
/// each builder closure in order (typically to register middleware).
fn app_with(middlewares: Vec<Box<dyn FnOnce(App) -> App>>) -> std::sync::Arc<AppInner> {
let mut app = App::new().include_router(Router::new().route(Route::new(
Method::GET,
"/",
pong_handler(),
)));
for add in middlewares {
app = add(app);
}
std::sync::Arc::new(app.build().unwrap())
}
/// A `GET /` request with an empty body.
fn request() -> Request {
http::Request::builder()
.method(Method::GET)
.uri("/")
.body(box_body(Full::new(Bytes::new())))
.unwrap()
}
/// Collects a response body and decodes it as UTF-8, panicking on failure.
async fn body_string(response: Response) -> String {
let bytes = response.into_body().collect().await.unwrap().to_bytes();
String::from_utf8(bytes.to_vec()).unwrap()
}
#[tokio::test]
async fn chain_runs_outermost_first_and_reaches_dispatch() {
let app = app_with(vec![
Box::new(|a: App| a.middleware(Mark("outer"))),
Box::new(|a: App| a.middleware(Mark("inner"))),
]);
let response = app.handle(request()).await;
assert_eq!(response.status(), StatusCode::OK);
let marks: Vec<_> = response
.headers()
.get_all("x-mark")
.iter()
.map(|v| v.to_str().unwrap().to_owned())
.collect();
// `Mark` appends only after the inner chain returns, so the innermost
// middleware's header is appended first. The first-registered middleware
// ("outer") is expected to wrap the later one, per this test's name.
assert_eq!(marks, vec!["inner", "outer"]);
assert_eq!(body_string(response).await, "pong");
}
#[tokio::test]
async fn middleware_can_short_circuit() {
let app = app_with(vec![Box::new(|a: App| a.middleware(ShortCircuit))]);
let response = app.handle(request()).await;
// The `forbidden` error from `ShortCircuit` is expected to surface as a 403 response.
assert_eq!(response.status(), StatusCode::FORBIDDEN);
}
/// Pass-through middleware with a configurable name and duplicate policy,
/// used to exercise `resolve_duplicates` directly.
struct Policy {
name: &'static str,
policy: DuplicatePolicy,
}
impl Middleware for Policy {
fn handle(&self, request: Request, next: Next) -> BoxFuture<'static, Result<Response>> {
next.run(request)
}
fn name(&self) -> &'static str {
self.name
}
fn duplicate_policy(&self) -> DuplicatePolicy {
self.policy
}
}
/// Convenience constructor for a boxed-as-`Arc` [`Policy`] middleware.
fn policy(name: &'static str, policy: DuplicatePolicy) -> std::sync::Arc<dyn Middleware> {
std::sync::Arc::new(Policy { name, policy })
}
#[test]
fn resolve_duplicates_applies_each_policy() {
// Allow: both registrations survive.
let allowed = resolve_duplicates(vec![
policy("a", DuplicatePolicy::Allow),
policy("a", DuplicatePolicy::Allow),
])
.unwrap();
assert_eq!(allowed.len(), 2);
// Replace: the second registration overwrites the first.
let replaced = resolve_duplicates(vec![
policy("b", DuplicatePolicy::Replace),
policy("b", DuplicatePolicy::Replace),
])
.unwrap();
assert_eq!(replaced.len(), 1);
// Reject: a duplicate name is an error.
assert!(resolve_duplicates(vec![
policy("c", DuplicatePolicy::Reject),
policy("c", DuplicatePolicy::Reject)
])
.is_err());
// Distinct names never trigger a policy, even under Reject.
let distinct = resolve_duplicates(vec![
policy("x", DuplicatePolicy::Reject),
policy("y", DuplicatePolicy::Reject),
])
.unwrap();
assert_eq!(distinct.len(), 2);
}
}