1use std::sync::Arc;
2
3use crate::{async_trait, Context, Endpoint, Middleware, Next, Result};
4
5pub trait MiddlewareExt<S>: Sized + for<'a> Middleware<'a, S> {
8 fn chain<M>(self, next: M) -> Chain<Self, M>
10 where
11 M: for<'a> Middleware<'a, S>,
12 {
13 Chain(self, next)
14 }
15
16 fn end<E>(self, next: E) -> Chain<Self, E>
18 where
19 E: for<'a> Endpoint<'a, S>,
20 {
21 Chain(self, next)
22 }
23
24 fn shared(self) -> Shared<S>
26 where
27 S: 'static,
28 {
29 Shared(Arc::new(self))
30 }
31}
32
33pub trait EndpointExt<S>: Sized + for<'a> Endpoint<'a, S> {
35 fn boxed(self) -> Boxed<S>
37 where
38 S: 'static,
39 {
40 Boxed(Box::new(self))
41 }
42}
43
44impl<S, T> MiddlewareExt<S> for T where T: for<'a> Middleware<'a, S> {}
45impl<S, T> EndpointExt<S> for T where T: for<'a> Endpoint<'a, S> {}
46
/// Two handlers composed together, built by `MiddlewareExt::chain` /
/// `MiddlewareExt::end`: `.0` runs first and receives `.1` as its `next`.
pub struct Chain<T, U>(T, U);

/// A type-erased middleware behind an `Arc`; cloning only clones the `Arc`.
pub struct Shared<S>(Arc<dyn for<'a> Middleware<'a, S>>);

/// A type-erased endpoint behind a `Box`.
pub struct Boxed<S>(Box<dyn for<'a> Endpoint<'a, S>>);
55
56#[async_trait(?Send)]
57impl<'a, S, T, U> Middleware<'a, S> for Chain<T, U>
58where
59 U: Middleware<'a, S>,
60 T: for<'b> Middleware<'b, S>,
61{
62 #[inline]
63 async fn handle(&'a self, ctx: &'a mut Context<S>, next: Next<'a>) -> Result {
64 let ptr = ctx as *mut Context<S>;
65 let mut next = self.1.handle(unsafe { &mut *ptr }, next);
66 self.0.handle(ctx, &mut next).await
67 }
68}
69
70#[async_trait(?Send)]
71impl<'a, S> Middleware<'a, S> for Shared<S>
72where
73 S: 'static,
74{
75 #[inline]
76 async fn handle(&'a self, ctx: &'a mut Context<S>, next: Next<'a>) -> Result {
77 self.0.handle(ctx, next).await
78 }
79}
80
81impl<S> Clone for Shared<S> {
82 #[inline]
83 fn clone(&self) -> Self {
84 Self(self.0.clone())
85 }
86}
87
88#[async_trait(?Send)]
89impl<'a, S> Endpoint<'a, S> for Boxed<S>
90where
91 S: 'static,
92{
93 #[inline]
94 async fn call(&'a self, ctx: &'a mut Context<S>) -> Result {
95 self.0.call(ctx).await
96 }
97}
98
99#[async_trait(?Send)]
100impl<'a, S, T, U> Endpoint<'a, S> for Chain<T, U>
101where
102 U: Endpoint<'a, S>,
103 T: for<'b> Middleware<'b, S>,
104{
105 #[inline]
106 async fn call(&'a self, ctx: &'a mut Context<S>) -> Result {
107 let ptr = ctx as *mut Context<S>;
108 let mut next = self.1.call(unsafe { &mut *ptr });
109 self.0.handle(ctx, &mut next).await
110 }
111}
112
#[cfg(all(test, feature = "runtime"))]
mod tests {
    use std::sync::Arc;

    use futures::lock::Mutex;
    use http::StatusCode;

    use crate::{async_trait, App, Context, Middleware, Next, Request, Status};

    /// Test middleware that pushes `data` into a shared vector once before
    /// and once after awaiting `next`, so the vector records the call order.
    struct Pusher {
        data: usize,
        vector: Arc<Mutex<Vec<usize>>>,
    }

    impl Pusher {
        fn new(data: usize, vector: Arc<Mutex<Vec<usize>>>) -> Self {
            Self { data, vector }
        }
    }

    #[async_trait(?Send)]
    impl<'a> Middleware<'a, ()> for Pusher {
        async fn handle(&'a self, _ctx: &'a mut Context, next: Next<'a>) -> Result<(), Status> {
            self.vector.lock().await.push(self.data);
            next.await?;
            self.vector.lock().await.push(self.data);
            Ok(())
        }
    }

    #[tokio::test]
    async fn middleware_order() -> Result<(), Box<dyn std::error::Error>> {
        let vector = Arc::new(Mutex::new(Vec::new()));
        let service = App::new()
            .gate(Pusher::new(0, vector.clone()))
            .gate(Pusher::new(1, vector.clone()))
            .gate(Pusher::new(2, vector.clone()))
            .gate(Pusher::new(3, vector.clone()))
            .gate(Pusher::new(4, vector.clone()))
            .gate(Pusher::new(5, vector.clone()))
            .gate(Pusher::new(6, vector.clone()))
            .gate(Pusher::new(7, vector.clone()))
            .gate(Pusher::new(8, vector.clone()))
            .gate(Pusher::new(9, vector.clone()))
            .end(())
            .http_service();
        let resp = service.serve(Request::default()).await;
        assert_eq!(StatusCode::OK, resp.status);
        // Expect 0..=9 on the way in, then 9..=0 on the way out. Comparing the
        // whole vector in one locked read also catches unexpected extra
        // entries, which per-index checks would miss.
        let expected: Vec<usize> = (0..10).chain((0..10).rev()).collect();
        assert_eq!(expected, *vector.lock().await);
        Ok(())
    }
}