roa_core/
group.rs

1use std::sync::Arc;
2
3use crate::{async_trait, Context, Endpoint, Middleware, Next, Result};
4
5/// A set of method to chain middleware/endpoint to middleware
6/// or make middleware shared.
7pub trait MiddlewareExt<S>: Sized + for<'a> Middleware<'a, S> {
8    /// Chain two middlewares.
9    fn chain<M>(self, next: M) -> Chain<Self, M>
10    where
11        M: for<'a> Middleware<'a, S>,
12    {
13        Chain(self, next)
14    }
15
16    /// Chain an endpoint to a middleware.
17    fn end<E>(self, next: E) -> Chain<Self, E>
18    where
19        E: for<'a> Endpoint<'a, S>,
20    {
21        Chain(self, next)
22    }
23
24    /// Make middleware shared.
25    fn shared(self) -> Shared<S>
26    where
27        S: 'static,
28    {
29        Shared(Arc::new(self))
30    }
31}
32
33/// Extra methods of endpoint.
34pub trait EndpointExt<S>: Sized + for<'a> Endpoint<'a, S> {
35    /// Box an endpoint.
36    fn boxed(self) -> Boxed<S>
37    where
38        S: 'static,
39    {
40        Boxed(Box::new(self))
41    }
42}
43
44impl<S, T> MiddlewareExt<S> for T where T: for<'a> Middleware<'a, S> {}
45impl<S, T> EndpointExt<S> for T where T: for<'a> Endpoint<'a, S> {}
46
/// A middleware composing and executing other middlewares in a stack-like manner.
///
/// `Chain(outer, inner)`: `outer` (field `0`) is always a middleware and runs
/// first. It receives the future of `inner` (field `1`) as its `next`.
/// `inner` may be another middleware or an endpoint.
pub struct Chain<T, U>(T, U);
49
/// Shared middleware.
///
/// A type-erased middleware stored behind an [`Arc`], so every clone of this
/// handle refers to the same middleware instance.
pub struct Shared<S>(Arc<dyn for<'a> Middleware<'a, S>>);
52
/// Boxed endpoint.
///
/// A type-erased endpoint stored behind a [`Box`].
pub struct Boxed<S>(Box<dyn for<'a> Endpoint<'a, S>>);
55
/// Run `self.0` as the outer middleware with `self.1` as its `next`.
///
/// `self.1.handle(..)` is called first only to build the inner future. Assuming
/// `handle` returns a lazy future (as `async_trait` methods do), `self.1` does
/// not actually run until `self.0` awaits `next`, and never runs if `self.0`
/// never awaits it.
#[async_trait(?Send)]
impl<'a, S, T, U> Middleware<'a, S> for Chain<T, U>
where
    U: Middleware<'a, S>,
    T: for<'b> Middleware<'b, S>,
{
    #[inline]
    async fn handle(&'a self, ctx: &'a mut Context<S>, next: Next<'a>) -> Result {
        // Both the outer and the inner middleware need `&'a mut Context<S>`.
        // The borrow checker cannot express that, so a second mutable
        // reference is manufactured from a raw pointer.
        let ptr = ctx as *mut Context<S>;
        // SAFETY (intended by the design, not checked by the compiler): the
        // inner future is only polled while `self.0` is suspended on
        // `next.await`, so the two `&mut Context<S>` are never *used*
        // concurrently.
        // NOTE(review): both `&mut` references are nevertheless live at the same
        // time, which breaks Rust's `&mut` uniqueness rule and is likely to be
        // rejected by Miri (Stacked Borrows). Soundness also relies on `self.0`
        // not touching `ctx` while `next` is partially polled. TODO confirm
        // this is a documented `Middleware` invariant.
        let mut next = self.1.handle(unsafe { &mut *ptr }, next);
        self.0.handle(ctx, &mut next).await
    }
}
69
70#[async_trait(?Send)]
71impl<'a, S> Middleware<'a, S> for Shared<S>
72where
73    S: 'static,
74{
75    #[inline]
76    async fn handle(&'a self, ctx: &'a mut Context<S>, next: Next<'a>) -> Result {
77        self.0.handle(ctx, next).await
78    }
79}
80
81impl<S> Clone for Shared<S> {
82    #[inline]
83    fn clone(&self) -> Self {
84        Self(self.0.clone())
85    }
86}
87
88#[async_trait(?Send)]
89impl<'a, S> Endpoint<'a, S> for Boxed<S>
90where
91    S: 'static,
92{
93    #[inline]
94    async fn call(&'a self, ctx: &'a mut Context<S>) -> Result {
95        self.0.call(ctx).await
96    }
97}
98
/// Run the middleware `self.0` in front of the endpoint `self.1`.
///
/// `self.1.call(..)` is called first only to build the endpoint's future, which
/// is handed to `self.0` as its `next`. Assuming `call` returns a lazy future
/// (as `async_trait` methods do), the endpoint runs only when `self.0` awaits
/// `next`, and never runs if `self.0` never awaits it.
#[async_trait(?Send)]
impl<'a, S, T, U> Endpoint<'a, S> for Chain<T, U>
where
    U: Endpoint<'a, S>,
    T: for<'b> Middleware<'b, S>,
{
    #[inline]
    async fn call(&'a self, ctx: &'a mut Context<S>) -> Result {
        // The middleware and the endpoint both need `&'a mut Context<S>`, so a
        // second mutable reference is manufactured from a raw pointer.
        let ptr = ctx as *mut Context<S>;
        // SAFETY (intended by the design, not checked by the compiler): the
        // endpoint future is only polled while `self.0` is suspended on
        // `next.await`, so the two `&mut Context<S>` are never *used*
        // concurrently.
        // NOTE(review): both `&mut` references are nevertheless live at the same
        // time, which breaks Rust's `&mut` uniqueness rule and is likely to be
        // rejected by Miri (Stacked Borrows). TODO confirm this is an accepted
        // trade-off.
        let mut next = self.1.call(unsafe { &mut *ptr });
        self.0.handle(ctx, &mut next).await
    }
}
112
#[cfg(all(test, feature = "runtime"))]
mod tests {
    use std::sync::Arc;

    use futures::lock::Mutex;
    use http::StatusCode;

    use crate::{async_trait, App, Context, Middleware, Next, Request, Status};

    /// Test middleware that records `data` into a shared vector both before
    /// and after awaiting the rest of the chain.
    struct Pusher {
        data: usize,
        vector: Arc<Mutex<Vec<usize>>>,
    }

    impl Pusher {
        fn new(data: usize, vector: Arc<Mutex<Vec<usize>>>) -> Self {
            Self { data, vector }
        }
    }

    #[async_trait(?Send)]
    impl<'a> Middleware<'a, ()> for Pusher {
        async fn handle(&'a self, _ctx: &'a mut Context, next: Next<'a>) -> Result<(), Status> {
            self.vector.lock().await.push(self.data);
            next.await?;
            self.vector.lock().await.push(self.data);
            Ok(())
        }
    }

    /// Gates must run in registration order on the way in and in reverse
    /// order on the way out: 0, 1, ..., 9, 9, ..., 1, 0, and nothing else.
    #[tokio::test]
    async fn middleware_order() {
        let vector = Arc::new(Mutex::new(Vec::new()));
        let service = App::new()
            .gate(Pusher::new(0, vector.clone()))
            .gate(Pusher::new(1, vector.clone()))
            .gate(Pusher::new(2, vector.clone()))
            .gate(Pusher::new(3, vector.clone()))
            .gate(Pusher::new(4, vector.clone()))
            .gate(Pusher::new(5, vector.clone()))
            .gate(Pusher::new(6, vector.clone()))
            .gate(Pusher::new(7, vector.clone()))
            .gate(Pusher::new(8, vector.clone()))
            .gate(Pusher::new(9, vector.clone()))
            .end(())
            .http_service();
        let resp = service.serve(Request::default()).await;
        assert_eq!(StatusCode::OK, resp.status);

        // Compare the whole trace under a single lock so that missing or extra
        // entries are caught too, not just the first twenty positions.
        let expected: Vec<usize> = (0..10).chain((0..10).rev()).collect();
        assert_eq!(expected, *vector.lock().await);
    }
}