vantus 0.3.0

Macro-first async Rust backend framework with explicit composition, typed extraction, and hardened HTTP defaults.
Documentation
use std::sync::Arc;

use crate::routing::{Handler, HandlerFuture, HandlerResult, RequestContext};

/// Future produced by [`Middleware::handle`] and [`Next::run`].
///
/// This is the same type as `HandlerFuture`, which lets [`Next::run`] return either a
/// middleware's future or the terminal handler's boxed future from a single function.
pub type MiddlewareFuture = HandlerFuture;

/// Coarse phase used to order middleware within a pipeline.
///
/// Variants are declared in execution order: the derived `Ord` sorts an earlier stage
/// first, so a middleware in an earlier stage runs before (and wraps) every middleware
/// in a later stage. `Hash` is derived so stages can key hash maps and sets.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum MiddlewareStage {
    Logging,
    Recovery,
    Auth,
    /// Stage used by middleware that do not override [`Middleware::stage`].
    Validation,
    Response,
}

/// Continuation handed to a [`Middleware`]; calling [`Next::run`] invokes the rest of
/// the chain (remaining middleware, then the handler).
#[derive(Clone)]
pub struct Next {
    /// The full, already-ordered pipeline, shared via `Arc` by every continuation
    /// derived from it.
    middlewares: Arc<[PipelineMiddleware]>,
    /// Position of the next middleware to run; at or past `middlewares.len()` only the
    /// handler remains.
    index: usize,
    /// Terminal handler invoked after the last middleware.
    handler: Handler,
}

impl Next {
    /// Creates a continuation positioned at `index` within `middlewares`.
    ///
    /// When `index` is at or past the end of `middlewares`, [`Next::run`] calls
    /// `handler` directly.
    pub(crate) fn new(
        middlewares: Arc<[PipelineMiddleware]>,
        index: usize,
        handler: Handler,
    ) -> Self {
        Self {
            middlewares,
            index,
            handler,
        }
    }

    /// Runs the remainder of the pipeline starting at this continuation's position.
    ///
    /// If a middleware remains at `index`, it is invoked with a fresh `Next` pointing at
    /// the following position, so that middleware decides whether (and when) the rest
    /// of the chain runs. Once every middleware has been passed, the terminal handler is
    /// called with `ctx`.
    pub fn run(&self, ctx: RequestContext) -> MiddlewareFuture {
        match self.middlewares.get(self.index) {
            Some(middleware) => {
                let next = Self::new(
                    Arc::clone(&self.middlewares),
                    self.index + 1,
                    self.handler.clone(),
                );
                // Borrow the entry instead of cloning it: `handle` only needs `&self`,
                // and the returned future cannot borrow from it because
                // `MiddlewareFuture` has no lifetime parameter.
                middleware.inner.handle(ctx, next)
            }
            None => {
                // End of the chain: hand the request to the route handler.
                let handler = self.handler.clone();
                Box::pin(async move { handler.call(ctx).await })
            }
        }
    }
}

/// A request interceptor placed in front of the route handler.
///
/// Implementations receive the request context and a [`Next`]. Calling `next.run(ctx)`
/// continues the chain; returning without calling it short-circuits the remaining
/// middleware and the handler.
pub trait Middleware: Send + Sync {
    /// Stage that determines this middleware's position in the pipeline.
    ///
    /// Queried when [`MiddlewareStack::execute`] assembles a pipeline. Defaults to
    /// [`MiddlewareStage::Validation`].
    fn stage(&self) -> MiddlewareStage {
        MiddlewareStage::Validation
    }

    /// Handles `ctx`, optionally delegating to the rest of the chain through `next`.
    fn handle(&self, ctx: RequestContext, next: Next) -> MiddlewareFuture;
}

#[derive(Clone, Default)]
pub struct MiddlewareStack {
    stack: Vec<Arc<dyn Middleware>>,
}

impl MiddlewareStack {
    /// Creates an empty stack.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the number of registered middleware.
    pub fn len(&self) -> usize {
        self.stack.len()
    }

    /// Returns `true` if no middleware has been registered.
    pub fn is_empty(&self) -> bool {
        self.stack.is_empty()
    }

    /// Registers `middleware` on this stack.
    ///
    /// Registration order only breaks ties between middleware of the same
    /// [`MiddlewareStage`]; see [`MiddlewareStack::execute`].
    pub fn add<M>(&mut self, middleware: M)
    where
        M: Middleware + 'static,
    {
        self.stack.push(Arc::new(middleware));
    }

    /// Runs `ctx` through this stack's middleware plus `route_stack`, then `handler`.
    ///
    /// Middleware run outermost-first in ascending order of
    /// `(stage, source, registration index)`:
    ///
    /// - Earlier [`MiddlewareStage`]s wrap later ones.
    /// - Within a stage, this stack's (global) middleware wrap route middleware.
    /// - Each group keeps its registration order.
    pub async fn execute(
        &self,
        route_stack: &[Arc<dyn Middleware>],
        ctx: RequestContext,
        handler: Handler,
    ) -> HandlerResult {
        // Fast path: with no middleware at all, skip building (and allocating) a
        // pipeline for this request and call the handler directly.
        if self.stack.is_empty() && route_stack.is_empty() {
            return handler.call(ctx).await;
        }

        let pipeline: Arc<[PipelineMiddleware]> = self.pipeline(route_stack).into();
        Next::new(pipeline, 0, handler).run(ctx).await
    }

    /// Merges global and route middleware into the order `execute` runs them in.
    fn pipeline(&self, route_stack: &[Arc<dyn Middleware>]) -> Vec<PipelineMiddleware> {
        let global = self.stack.iter().enumerate().map(|(index, middleware)| {
            PipelineMiddleware::new(Arc::clone(middleware), MiddlewareSource::Global, index)
        });
        let route = route_stack.iter().enumerate().map(|(index, middleware)| {
            PipelineMiddleware::new(Arc::clone(middleware), MiddlewareSource::Route, index)
        });

        let mut combined: Vec<PipelineMiddleware> = global.chain(route).collect();
        // `(source, registration_index)` is unique per entry, so the key is a strict
        // total order: an unstable sort yields exactly the same sequence as a stable one
        // while avoiding the stable sort's scratch allocation.
        combined.sort_unstable_by_key(|middleware| {
            (
                middleware.stage,
                middleware.source,
                middleware.registration_index,
            )
        });
        combined
    }
}

/// One middleware entry in an assembled pipeline, together with its sort key.
#[derive(Clone)]
pub struct PipelineMiddleware {
    /// The middleware itself.
    inner: Arc<dyn Middleware>,
    /// `inner.stage()`, captured once when the entry is built (`PipelineMiddleware::new`).
    stage: MiddlewareStage,
    /// Whether the entry came from the global stack or from the route's stack.
    source: MiddlewareSource,
    /// Position within its source stack when the pipeline was assembled.
    registration_index: usize,
}

impl PipelineMiddleware {
    fn new(
        inner: Arc<dyn Middleware>,
        source: MiddlewareSource,
        registration_index: usize,
    ) -> Self {
        Self {
            stage: inner.stage(),
            inner,
            source,
            registration_index,
        }
    }
}

/// Origin of a middleware within an assembled pipeline.
///
/// Declaration order is significant: the derived `Ord` places `Global` before `Route`,
/// so within a single stage global middleware run ahead of route middleware.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
enum MiddlewareSource {
    /// Registered on the [`MiddlewareStack`] itself.
    Global,
    /// Supplied per call through `route_stack` in [`MiddlewareStack::execute`].
    Route,
}