1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
mod add_data;
mod cors;
mod set_header;
#[cfg(feature = "tracing")]
mod tracing;
pub use add_data::AddData;
pub use cors::Cors;
pub use set_header::SetHeader;
#[cfg(feature = "tracing")]
pub use self::tracing::Tracing;
use crate::endpoint::Endpoint;
/// Turns an endpoint of type `E` into a new endpoint, `Self::Output`
/// (for example, one that wraps `E` and post-processes its response).
///
/// `transform` takes `self` by value, so a middleware value is consumed when
/// it is applied.
pub trait Middleware<E>
where
    E: Endpoint,
{
    /// The endpoint type produced by [`Middleware::transform`].
    type Output: Endpoint;

    /// Consumes this middleware and builds the new endpoint around `ep`.
    fn transform(self, ep: E) -> Self::Output;
}
/// Adapter that lets a closure act as a [`Middleware`]; build one with
/// [`make`].
pub struct FnMiddleware<T>(
    // The wrapped closure, invoked with the inner endpoint on `transform`.
    T,
);
impl<T, E, E2> Middleware<E> for FnMiddleware<T>
where
T: Fn(E) -> E2,
E: Endpoint,
E2: Endpoint,
{
type Output = E2;
fn transform(self, ep: E) -> Self::Output {
(self.0)(ep)
}
}
/// Wraps `f` in a [`FnMiddleware`].
///
/// The result implements [`Middleware`] when `f` is a closure that takes an
/// endpoint and returns a new endpoint.
pub fn make<T>(f: T) -> FnMiddleware<T> {
    FnMiddleware::<T>(f)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        handler,
        http::{header::HeaderName, HeaderValue},
        EndpointExt, IntoResponse, Request, Response,
    };

    /// A closure passed to `make` must produce middleware whose endpoint both
    /// adds a response header and still runs the wrapped handler.
    #[tokio::test]
    async fn test_make() {
        // Inner handler whose body the wrapped endpoint must preserve.
        #[handler(internal)]
        fn index() -> &'static str {
            "abc"
        }

        /// Test-only endpoint: forwards to `ep`, then inserts one header into
        /// the response it produced.
        struct AddHeader<E> {
            ep: E,
            header: HeaderName,
            value: HeaderValue,
        }

        #[async_trait::async_trait]
        impl<E: Endpoint> Endpoint for AddHeader<E> {
            type Output = Response;

            async fn call(&self, req: Request) -> Self::Output {
                let mut response = self.ep.call(req).await.into_response();
                let (name, value) = (self.header.clone(), self.value.clone());
                response.headers_mut().insert(name, value);
                response
            }
        }

        let ep = index.with(make(|ep| AddHeader {
            ep,
            header: HeaderName::from_static("hello"),
            value: HeaderValue::from_static("world"),
        }));

        let mut resp = ep.call(Request::default()).await;

        // The header added by the middleware must be present...
        let hello = resp
            .headers()
            .get(HeaderName::from_static("hello"))
            .cloned();
        assert_eq!(hello, Some(HeaderValue::from_static("world")));

        // ...and the inner handler's body must pass through unchanged.
        let body = resp.take_body().into_string().await.unwrap();
        assert_eq!(body, "abc");
    }
}