1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
use crate::{Context, DynTargetHandler, Model, Next, Status, TargetHandler};
use std::future::Future;
use std::sync::Arc;

/// A composable chain of middlewares for model `M`.
///
/// Holds the composed handler behind an `Arc`. Middlewares are appended with
/// `join`, and the resulting chain is obtained with `handler`.
pub struct Middleware<M: Model>(Arc<DynTargetHandler<M, Next>>);

impl<M: Model> Middleware<M> {
    /// Creates an empty middleware chain.
    ///
    /// The initial handler only calls `next()`, so an empty chain passes every
    /// request straight through to whatever follows it.
    pub fn new() -> Self {
        Self(Arc::new(|_ctx, next| next()))
    }

    /// Appends `middleware` to the end of the chain and returns `self` for chaining.
    ///
    /// Middlewares run in the order they were joined. The previously composed
    /// chain runs first, and its `next` invokes `middleware`. The `next` that
    /// `middleware` receives is the `next` passed to the whole chain.
    pub fn join<F>(
        &mut self,
        middleware: impl 'static + Sync + Send + Fn(Context<M>, Next) -> F,
    ) -> &mut Self
    where
        F: 'static + Future<Output = Result<(), Status>> + Send,
    {
        let current = self.0.clone();
        let next_middleware: Arc<DynTargetHandler<M, Next>> =
            Arc::from(Box::new(middleware).dynamic());
        self.0 = Arc::new(move |ctx, next| {
            // The composed handler may be called many times, while the boxed
            // `next` below consumes what it captures. So every call takes its
            // own handle on the appended middleware and its own context clone.
            let next_middleware = next_middleware.clone();
            let ctx_cloned = ctx.clone();
            let next = Box::new(move || next_middleware(ctx_cloned, next));
            current(ctx, next)
        });
        self
    }

    /// Returns a boxed handler that runs the whole chain.
    ///
    /// The returned handler holds its own `Arc` to the chain as it is now.
    /// A later `join` replaces `self`'s handler and does not affect handlers
    /// already returned.
    pub fn handler(&self) -> Box<DynTargetHandler<M, Next>> {
        let handler = self.0.clone();
        Box::new(move |ctx, next| handler(ctx, next))
    }
}

impl<M: Model> Default for Middleware<M> {
    /// Returns an empty middleware chain; equivalent to [`Middleware::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl<M: Model> Clone for Middleware<M> {
    fn clone(&self) -> Self {
        Self(self.0.clone())
    }
}

#[cfg(test)]
mod tests {
    use super::Middleware;
    use crate::App;
    use async_std::task::spawn;
    use futures::lock::Mutex;
    use http::StatusCode;
    use std::sync::Arc;

    /// Joined middlewares must run in join order on the way in and in reverse
    /// order on the way out. The log should therefore read 0..=99, then 99..=0.
    #[tokio::test]
    async fn middleware_order() -> Result<(), Box<dyn std::error::Error>> {
        let vector = Arc::new(Mutex::new(Vec::<usize>::new()));
        let mut middleware = Middleware::<()>::new();
        for i in 0..100 {
            let vec = vector.clone();
            middleware.join(move |_ctx, next| {
                let vec = vec.clone();
                async move {
                    vec.lock().await.push(i);
                    next().await?;
                    vec.lock().await.push(i);
                    Ok(())
                }
            });
        }
        let (addr, server) = App::new(()).gate(middleware.handler()).run_local()?;
        spawn(server);
        let resp = reqwest::get(&format!("http://{}", addr)).await?;
        assert_eq!(StatusCode::OK, resp.status());
        // Compare the whole log under a single lock. Extra or missing entries
        // then fail with a readable diff instead of slipping through.
        let expected: Vec<usize> = (0..100).chain((0..100).rev()).collect();
        assert_eq!(expected, *vector.lock().await);
        Ok(())
    }
}