roa_core/middleware.rs
1use std::future::Future;
2
3use http::header::LOCATION;
4use http::{StatusCode, Uri};
5
6use crate::{async_trait, throw, Context, Result, Status};
7
8/// ### Middleware
9///
10/// #### Build-in middlewares
11///
12/// - Functional middleware
13///
14/// A functional middleware is an async function with signature:
15/// `async fn(&mut Context, Next<'_>) -> Result`.
16///
17/// ```rust
18/// use roa_core::{App, Context, Next, Result};
19///
20/// async fn middleware(ctx: &mut Context, next: Next<'_>) -> Result {
21/// next.await
22/// }
23///
24/// let app = App::new().gate(middleware);
25/// ```
26///
27/// - Blank middleware
28///
29/// `()` is a blank middleware, it just calls the next middleware or endpoint.
30///
31/// ```rust
32/// let app = roa_core::App::new().gate(());
33/// ```
34///
35/// #### Custom middleware
36///
37/// You can implement custom `Middleware` for other types.
38///
39/// ```rust
40/// use roa_core::{App, Middleware, Context, Next, Result, async_trait};
41/// use std::sync::Arc;
42/// use std::time::Instant;
43///
44///
45/// struct Logger;
46///
47/// #[async_trait(?Send)]
48/// impl <'a> Middleware<'a> for Logger {
49/// async fn handle(&'a self, ctx: &'a mut Context, next: Next<'a>) -> Result {
50/// let start = Instant::now();
51/// let result = next.await;
52/// println!("time elapsed: {}ms", start.elapsed().as_millis());
53/// result
54/// }
55/// }
56///
57/// let app = App::new().gate(Logger);
58/// ```
59#[async_trait(?Send)]
60pub trait Middleware<'a, S = ()>: 'static + Sync + Send {
61 /// Handle context and next, return status.
62 async fn handle(&'a self, ctx: &'a mut Context<S>, next: Next<'a>) -> Result;
63}
64
65#[async_trait(?Send)]
66impl<'a, S, T, F> Middleware<'a, S> for T
67where
68 S: 'a,
69 T: 'static + Send + Sync + Fn(&'a mut Context<S>, Next<'a>) -> F,
70 F: 'a + Future<Output = Result>,
71{
72 #[inline]
73 async fn handle(&'a self, ctx: &'a mut Context<S>, next: Next<'a>) -> Result {
74 (self)(ctx, next).await
75 }
76}
77
78/// ### Endpoint
79///
80/// An endpoint is a request handler.
81///
82/// #### Build-in endpoint
83///
84/// There are some build-in endpoints.
85///
86/// - Functional endpoint
87///
88/// A normal functional endpoint is an async function with signature:
89/// `async fn(&mut Context) -> Result`.
90///
91/// ```rust
92/// use roa_core::{App, Context, Result};
93///
94/// async fn endpoint(ctx: &mut Context) -> Result {
95/// Ok(())
96/// }
97///
98/// let app = App::new().end(endpoint);
99/// ```
100/// - Ok endpoint
101///
102/// `()` is an endpoint always return `Ok(())`
103///
104/// ```rust
105/// let app = roa_core::App::new().end(());
106/// ```
107///
108/// - Status endpoint
109///
110/// `Status` is an endpoint always return `Err(Status)`
111///
112/// ```rust
113/// use roa_core::{App, status};
114/// use roa_core::http::StatusCode;
115/// let app = App::new().end(status!(StatusCode::BAD_REQUEST));
116/// ```
117///
118/// - String endpoint
119///
120/// Write string to body.
121///
122/// ```rust
123/// use roa_core::App;
124///
125/// let app = App::new().end("Hello, world"); // static slice
126/// let app = App::new().end("Hello, world".to_owned()); // string
127/// ```
128///
129/// - Redirect endpoint
130///
131/// Redirect to an uri.
132///
133/// ```rust
134/// use roa_core::App;
135/// use roa_core::http::Uri;
136///
137/// let app = App::new().end("/target".parse::<Uri>().unwrap());
138/// ```
139///
140/// #### Custom endpoint
141///
142/// You can implement custom `Endpoint` for your types.
143///
144/// ```rust
145/// use roa_core::{App, Endpoint, Context, Next, Result, async_trait};
146///
147/// fn is_endpoint(endpoint: impl for<'a> Endpoint<'a>) {
148/// }
149///
150/// struct Service;
151///
152/// #[async_trait(?Send)]
153/// impl <'a> Endpoint<'a> for Service {
154/// async fn call(&'a self, ctx: &'a mut Context) -> Result {
155/// Ok(())
156/// }
157/// }
158///
159/// let app = App::new().end(Service);
160/// ```
161#[async_trait(?Send)]
162pub trait Endpoint<'a, S = ()>: 'static + Sync + Send {
163 /// Call this endpoint.
164 async fn call(&'a self, ctx: &'a mut Context<S>) -> Result;
165}
166
167#[async_trait(?Send)]
168impl<'a, S, T, F> Endpoint<'a, S> for T
169where
170 S: 'a,
171 T: 'static + Send + Sync + Fn(&'a mut Context<S>) -> F,
172 F: 'a + Future<Output = Result>,
173{
174 #[inline]
175 async fn call(&'a self, ctx: &'a mut Context<S>) -> Result {
176 (self)(ctx).await
177 }
178}
179
180/// blank middleware.
181#[async_trait(?Send)]
182impl<'a, S> Middleware<'a, S> for () {
183 #[allow(clippy::trivially_copy_pass_by_ref)]
184 #[inline]
185 async fn handle(&'a self, _ctx: &'a mut Context<S>, next: Next<'a>) -> Result {
186 next.await
187 }
188}
189
190/// ok endpoint, always return Ok(())
191#[async_trait(?Send)]
192impl<'a, S> Endpoint<'a, S> for () {
193 #[allow(clippy::trivially_copy_pass_by_ref)]
194 #[inline]
195 async fn call(&'a self, _ctx: &'a mut Context<S>) -> Result {
196 Ok(())
197 }
198}
199
200/// status endpoint.
201#[async_trait(?Send)]
202impl<'a, S> Endpoint<'a, S> for Status {
203 #[inline]
204 async fn call(&'a self, _ctx: &'a mut Context<S>) -> Result {
205 Err(self.clone())
206 }
207}
208
209/// String endpoint.
210#[async_trait(?Send)]
211impl<'a, S> Endpoint<'a, S> for String {
212 #[inline]
213 #[allow(clippy::ptr_arg)]
214 async fn call(&'a self, ctx: &'a mut Context<S>) -> Result {
215 ctx.resp.write(self.clone());
216 Ok(())
217 }
218}
219
220/// Static slice endpoint.
221#[async_trait(?Send)]
222impl<'a, S> Endpoint<'a, S> for &'static str {
223 #[inline]
224 async fn call(&'a self, ctx: &'a mut Context<S>) -> Result {
225 ctx.resp.write(*self);
226 Ok(())
227 }
228}
229
230/// Redirect endpoint.
231#[async_trait(?Send)]
232impl<'a, S> Endpoint<'a, S> for Uri {
233 #[inline]
234 async fn call(&'a self, ctx: &'a mut Context<S>) -> Result {
235 ctx.resp.headers.insert(LOCATION, self.to_string().parse()?);
236 throw!(StatusCode::PERMANENT_REDIRECT)
237 }
238}
239
240/// Type of the second parameter in a middleware,
241/// an alias for `&mut (dyn Unpin + Future<Output = Result>)`
242///
243/// Developer of middleware can jump to next middleware by calling `next.await`.
244///
245/// ### Example
246///
247/// ```rust
248/// use roa_core::{App, Context, Result, Status, MiddlewareExt, Next};
249/// use roa_core::http::StatusCode;
250///
251/// let app = App::new()
252/// .gate(first)
253/// .gate(second)
254/// .gate(third)
255/// .end(end);
256/// async fn first(ctx: &mut Context, next: Next<'_>) -> Result {
257/// assert!(ctx.store("id", "1").is_none());
258/// next.await?;
259/// assert_eq!("5", *ctx.load::<&'static str>("id").unwrap());
260/// Ok(())
261/// }
262/// async fn second(ctx: &mut Context, next: Next<'_>) -> Result {
263/// assert_eq!("1", *ctx.load::<&'static str>("id").unwrap());
264/// assert_eq!("1", *ctx.store("id", "2").unwrap());
265/// next.await?;
266/// assert_eq!("4", *ctx.store("id", "5").unwrap());
267/// Ok(())
268/// }
269/// async fn third(ctx: &mut Context, next: Next<'_>) -> Result {
270/// assert_eq!("2", *ctx.store("id", "3").unwrap());
271/// next.await?; // next is none; do nothing
272/// assert_eq!("3", *ctx.store("id", "4").unwrap());
273/// Ok(())
274/// }
275///
276/// async fn end(ctx: &mut Context) -> Result {
277/// assert_eq!("3", *ctx.load::<&'static str>("id").unwrap());
278/// Ok(())
279/// }
280/// ```
281///
282/// ### Error Handling
283///
284/// You can catch or straightly throw a Error returned by next.
285///
286/// ```rust
287/// use roa_core::{App, Context, Result, Status, MiddlewareExt, Next, status};
288/// use roa_core::http::StatusCode;
289///
290/// let app = App::new()
291/// .gate(catch)
292/// .gate(gate)
293/// .end(status!(StatusCode::IM_A_TEAPOT, "I'm a teapot!"));
294///
295/// async fn catch(ctx: &mut Context, next: Next<'_>) -> Result {
296/// // catch
297/// if let Err(err) = next.await {
298/// // teapot is ok
299/// if err.status_code != StatusCode::IM_A_TEAPOT {
300/// return Err(err);
301/// }
302/// }
303/// Ok(())
304/// }
305/// async fn gate(ctx: &mut Context, next: Next<'_>) -> Result {
306/// next.await?; // just throw
307/// unreachable!()
308/// }
309/// ```
310///
311pub type Next<'a> = &'a mut (dyn Unpin + Future<Output = Result>);
312
#[cfg(test)]
mod tests {
    use futures::{AsyncReadExt, TryStreamExt};
    use http::header::LOCATION;
    use http::{StatusCode, Uri};

    use crate::{status, App, Request};

    const HELLO: &str = "Hello, world";

    /// A `Status` endpoint responds with the status it was built from.
    #[tokio::test]
    async fn status_endpoint() {
        let app = App::new().end(status!(StatusCode::BAD_REQUEST));
        let service = app.http_service();
        let resp = service.serve(Request::default()).await;
        assert_eq!(StatusCode::BAD_REQUEST, resp.status);
    }

    /// An owned `String` endpoint writes its content to the response body.
    #[tokio::test]
    async fn string_endpoint() {
        let app = App::new().end(HELLO.to_owned());
        let service = app.http_service();
        let resp = service.serve(Request::default()).await;

        let mut data = String::new();
        resp.body
            .into_async_read()
            .read_to_string(&mut data)
            .await
            .unwrap();
        assert_eq!(HELLO, data);
    }

    /// A `&'static str` endpoint writes its content to the response body.
    #[tokio::test]
    async fn static_slice_endpoint() {
        let app = App::new().end(HELLO);
        let service = app.http_service();
        let resp = service.serve(Request::default()).await;

        let mut data = String::new();
        resp.body
            .into_async_read()
            .read_to_string(&mut data)
            .await
            .unwrap();
        assert_eq!(HELLO, data);
    }

    /// A `Uri` endpoint responds with a permanent redirect whose `Location`
    /// header carries the target.
    #[tokio::test]
    async fn redirect_endpoint() {
        let target = "/target".parse::<Uri>().unwrap();
        let app = App::new().end(target);
        let service = app.http_service();
        let resp = service.serve(Request::default()).await;
        assert_eq!(StatusCode::PERMANENT_REDIRECT, resp.status);
        assert_eq!("/target", resp.headers[LOCATION].to_str().unwrap());
    }
}
369}