// jerrycan_core/middleware.rs
use crate::extract::RequestCtx;
use crate::handler::BoxHandlerFn;
use crate::response::Response;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

11pub type MiddlewareFuture<'a> = Pin<Box<dyn Future<Output = Response> + Send + 'a>>;
13
14pub trait Middleware: Send + Sync + 'static {
17 fn handle<'a>(&'a self, ctx: &'a mut RequestCtx, next: Next<'a>) -> MiddlewareFuture<'a>;
18}
19
20pub struct Next<'a> {
22 pub(crate) chain: &'a [Arc<dyn Middleware>],
23 pub(crate) endpoint: &'a BoxHandlerFn,
24}
25
26impl<'a> Next<'a> {
27 pub fn run<'b>(self, ctx: &'b mut RequestCtx) -> MiddlewareFuture<'b>
30 where
31 'a: 'b,
32 {
33 let Next { chain, endpoint } = self;
34 match chain.split_first() {
35 Some((head, rest)) => head.handle(
36 ctx,
37 Next {
38 chain: rest,
39 endpoint,
40 },
41 ),
42 None => (**endpoint)(ctx), }
44 }
45}
46
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dep::{DepEnv, DepResolver};
    use crate::response::IntoResponse;
    use std::collections::HashMap;
    use std::sync::Mutex;

    /// Shared, ordered record of the events seen while the chain runs.
    type Log = Arc<Mutex<Vec<&'static str>>>;

    /// Test middleware: logs `enter`, runs the rest of the chain, then logs `exit`.
    struct Recorder {
        enter: &'static str,
        exit: &'static str,
        log: Log,
    }

    impl Recorder {
        /// Appends `event` to the shared log; the lock is released at the end
        /// of the statement, so it is never held across an `.await`.
        fn record(&self, event: &'static str) {
            self.log.lock().unwrap().push(event);
        }
    }

    /// Builds a type-erased [`Recorder`] that writes into `log`.
    fn recorder(enter: &'static str, exit: &'static str, log: &Log) -> Arc<dyn Middleware> {
        Arc::new(Recorder {
            enter,
            exit,
            log: Arc::clone(log),
        })
    }

    impl Middleware for Recorder {
        fn handle<'a>(
            &'a self,
            ctx: &'a mut RequestCtx,
            next: Next<'a>,
        ) -> MiddlewareFuture<'a> {
            Box::pin(async move {
                self.record(self.enter);
                let response = next.run(&mut *ctx).await;
                self.record(self.exit);
                response
            })
        }
    }

    #[tokio::test]
    async fn chain_runs_outside_in_then_inside_out() {
        let log: Log = Arc::new(Mutex::new(Vec::new()));

        // Terminal handler: records that it ran and answers "ok".
        let handler_log = Arc::clone(&log);
        let endpoint: BoxHandlerFn = Arc::new(move |_ctx: &mut RequestCtx| {
            let handler_log = Arc::clone(&handler_log);
            Box::pin(async move {
                handler_log.lock().unwrap().push("handler");
                "ok".into_response()
            })
        });

        // Outermost middleware first.
        let chain = vec![
            recorder("outer-in", "outer-out", &log),
            recorder("inner-in", "inner-out", &log),
        ];

        let (parts, ()) = http::Request::builder()
            .uri("/")
            .body(())
            .unwrap()
            .into_parts();
        let deps = DepResolver::new(Arc::new(DepEnv::default()), Arc::new(HashMap::new()));
        let mut ctx = RequestCtx::new(parts, bytes::Bytes::new(), deps);

        let next = Next {
            chain: &chain,
            endpoint: &endpoint,
        };
        let res = next.run(&mut ctx).await;

        assert_eq!(res.status(), http::StatusCode::OK);
        assert_eq!(
            *log.lock().unwrap(),
            ["outer-in", "inner-in", "handler", "inner-out", "outer-out"]
        );
    }
}