roa_core/
middleware.rs

1use std::future::Future;
2
3use http::header::LOCATION;
4use http::{StatusCode, Uri};
5
6use crate::{async_trait, throw, Context, Result, Status};
7
8/// ### Middleware
9///
10/// #### Build-in middlewares
11///
12/// - Functional middleware
13///
14/// A functional middleware is an async function with signature:
15/// `async fn(&mut Context, Next<'_>) -> Result`.
16///
17/// ```rust
18/// use roa_core::{App, Context, Next, Result};
19///
20/// async fn middleware(ctx: &mut Context, next: Next<'_>) -> Result {
21///     next.await
22/// }
23///
24/// let app = App::new().gate(middleware);
25/// ```
26///
27/// - Blank middleware
28///
29/// `()` is a blank middleware, it just calls the next middleware or endpoint.
30///
31/// ```rust
32/// let app = roa_core::App::new().gate(());
33/// ```
34///
35/// #### Custom middleware
36///
37/// You can implement custom `Middleware` for other types.
38///
39/// ```rust
40/// use roa_core::{App, Middleware, Context, Next, Result, async_trait};
41/// use std::sync::Arc;
42/// use std::time::Instant;
43///
44///
45/// struct Logger;
46///
47/// #[async_trait(?Send)]
48/// impl <'a> Middleware<'a> for Logger {
49///     async fn handle(&'a self, ctx: &'a mut Context, next: Next<'a>) -> Result {
50///         let start = Instant::now();
51///         let result = next.await;
52///         println!("time elapsed: {}ms", start.elapsed().as_millis());
53///         result
54///     }
55/// }
56///
57/// let app = App::new().gate(Logger);
58/// ```
59#[async_trait(?Send)]
60pub trait Middleware<'a, S = ()>: 'static + Sync + Send {
61    /// Handle context and next, return status.
62    async fn handle(&'a self, ctx: &'a mut Context<S>, next: Next<'a>) -> Result;
63}
64
65#[async_trait(?Send)]
66impl<'a, S, T, F> Middleware<'a, S> for T
67where
68    S: 'a,
69    T: 'static + Send + Sync + Fn(&'a mut Context<S>, Next<'a>) -> F,
70    F: 'a + Future<Output = Result>,
71{
72    #[inline]
73    async fn handle(&'a self, ctx: &'a mut Context<S>, next: Next<'a>) -> Result {
74        (self)(ctx, next).await
75    }
76}
77
78/// ### Endpoint
79///
80/// An endpoint is a request handler.
81///
82/// #### Build-in endpoint
83///
84/// There are some build-in endpoints.
85///
86/// - Functional endpoint
87///
88/// A normal functional endpoint is an async function with signature:
89/// `async fn(&mut Context) -> Result`.
90///
91/// ```rust
92/// use roa_core::{App, Context, Result};
93///
94/// async fn endpoint(ctx: &mut Context) -> Result {
95///     Ok(())
96/// }
97///
98/// let app = App::new().end(endpoint);
99/// ```
100/// - Ok endpoint
101///
102/// `()` is an endpoint always return `Ok(())`
103///
104/// ```rust
105/// let app = roa_core::App::new().end(());
106/// ```
107///
108/// - Status endpoint
109///
110/// `Status` is an endpoint always return `Err(Status)`
111///
112/// ```rust
113/// use roa_core::{App, status};
114/// use roa_core::http::StatusCode;
115/// let app = App::new().end(status!(StatusCode::BAD_REQUEST));
116/// ```
117///
118/// - String endpoint
119///
120/// Write string to body.
121///
122/// ```rust
123/// use roa_core::App;
124///
125/// let app = App::new().end("Hello, world"); // static slice
126/// let app = App::new().end("Hello, world".to_owned()); // string
127/// ```
128///
129/// - Redirect endpoint
130///
131/// Redirect to an uri.
132///
133/// ```rust
134/// use roa_core::App;
135/// use roa_core::http::Uri;
136///
137/// let app = App::new().end("/target".parse::<Uri>().unwrap());
138/// ```
139///
140/// #### Custom endpoint
141///
142/// You can implement custom `Endpoint` for your types.
143///
144/// ```rust
145/// use roa_core::{App, Endpoint, Context, Next, Result, async_trait};
146///
147/// fn is_endpoint(endpoint: impl for<'a> Endpoint<'a>) {
148/// }
149///
150/// struct Service;
151///
152/// #[async_trait(?Send)]
153/// impl <'a> Endpoint<'a> for Service {
154///     async fn call(&'a self, ctx: &'a mut Context) -> Result {
155///         Ok(())
156///     }
157/// }
158///
159/// let app = App::new().end(Service);
160/// ```
161#[async_trait(?Send)]
162pub trait Endpoint<'a, S = ()>: 'static + Sync + Send {
163    /// Call this endpoint.
164    async fn call(&'a self, ctx: &'a mut Context<S>) -> Result;
165}
166
167#[async_trait(?Send)]
168impl<'a, S, T, F> Endpoint<'a, S> for T
169where
170    S: 'a,
171    T: 'static + Send + Sync + Fn(&'a mut Context<S>) -> F,
172    F: 'a + Future<Output = Result>,
173{
174    #[inline]
175    async fn call(&'a self, ctx: &'a mut Context<S>) -> Result {
176        (self)(ctx).await
177    }
178}
179
180/// blank middleware.
181#[async_trait(?Send)]
182impl<'a, S> Middleware<'a, S> for () {
183    #[allow(clippy::trivially_copy_pass_by_ref)]
184    #[inline]
185    async fn handle(&'a self, _ctx: &'a mut Context<S>, next: Next<'a>) -> Result {
186        next.await
187    }
188}
189
190/// ok endpoint, always return Ok(())
191#[async_trait(?Send)]
192impl<'a, S> Endpoint<'a, S> for () {
193    #[allow(clippy::trivially_copy_pass_by_ref)]
194    #[inline]
195    async fn call(&'a self, _ctx: &'a mut Context<S>) -> Result {
196        Ok(())
197    }
198}
199
200/// status endpoint.
201#[async_trait(?Send)]
202impl<'a, S> Endpoint<'a, S> for Status {
203    #[inline]
204    async fn call(&'a self, _ctx: &'a mut Context<S>) -> Result {
205        Err(self.clone())
206    }
207}
208
209/// String endpoint.
210#[async_trait(?Send)]
211impl<'a, S> Endpoint<'a, S> for String {
212    #[inline]
213    #[allow(clippy::ptr_arg)]
214    async fn call(&'a self, ctx: &'a mut Context<S>) -> Result {
215        ctx.resp.write(self.clone());
216        Ok(())
217    }
218}
219
220/// Static slice endpoint.
221#[async_trait(?Send)]
222impl<'a, S> Endpoint<'a, S> for &'static str {
223    #[inline]
224    async fn call(&'a self, ctx: &'a mut Context<S>) -> Result {
225        ctx.resp.write(*self);
226        Ok(())
227    }
228}
229
230/// Redirect endpoint.
231#[async_trait(?Send)]
232impl<'a, S> Endpoint<'a, S> for Uri {
233    #[inline]
234    async fn call(&'a self, ctx: &'a mut Context<S>) -> Result {
235        ctx.resp.headers.insert(LOCATION, self.to_string().parse()?);
236        throw!(StatusCode::PERMANENT_REDIRECT)
237    }
238}
239
240/// Type of the second parameter in a middleware,
241/// an alias for `&mut (dyn Unpin + Future<Output = Result>)`
242///
243/// Developer of middleware can jump to next middleware by calling `next.await`.
244///
245/// ### Example
246///
247/// ```rust
248/// use roa_core::{App, Context, Result, Status, MiddlewareExt, Next};
249/// use roa_core::http::StatusCode;
250///
251/// let app = App::new()
252///     .gate(first)
253///     .gate(second)
254///     .gate(third)
255///     .end(end);
256/// async fn first(ctx: &mut Context, next: Next<'_>) -> Result {
257///     assert!(ctx.store("id", "1").is_none());
258///     next.await?;
259///     assert_eq!("5", *ctx.load::<&'static str>("id").unwrap());
260///     Ok(())
261/// }
262/// async fn second(ctx: &mut Context, next: Next<'_>) -> Result {
263///     assert_eq!("1", *ctx.load::<&'static str>("id").unwrap());
264///     assert_eq!("1", *ctx.store("id", "2").unwrap());
265///     next.await?;
266///     assert_eq!("4", *ctx.store("id", "5").unwrap());
267///     Ok(())
268/// }
269/// async fn third(ctx: &mut Context, next: Next<'_>) -> Result {
270///     assert_eq!("2", *ctx.store("id", "3").unwrap());
271///     next.await?; // next is none; do nothing
272///     assert_eq!("3", *ctx.store("id", "4").unwrap());
273///     Ok(())
274/// }
275///
276/// async fn end(ctx: &mut Context) -> Result {
277///     assert_eq!("3", *ctx.load::<&'static str>("id").unwrap());
278///     Ok(())
279/// }
280/// ```
281///
282/// ### Error Handling
283///
284/// You can catch or straightly throw a Error returned by next.
285///
286/// ```rust
287/// use roa_core::{App, Context, Result, Status, MiddlewareExt, Next, status};
288/// use roa_core::http::StatusCode;
289///         
290/// let app = App::new()
291///     .gate(catch)
292///     .gate(gate)
293///     .end(status!(StatusCode::IM_A_TEAPOT, "I'm a teapot!"));
294///
295/// async fn catch(ctx: &mut Context, next: Next<'_>) -> Result {
296///     // catch
297///     if let Err(err) = next.await {
298///         // teapot is ok
299///         if err.status_code != StatusCode::IM_A_TEAPOT {
300///             return Err(err);
301///         }
302///     }
303///     Ok(())
304/// }
305/// async fn gate(ctx: &mut Context, next: Next<'_>) -> Result {
306///     next.await?; // just throw
307///     unreachable!()
308/// }
309/// ```
310///
311pub type Next<'a> = &'a mut (dyn Unpin + Future<Output = Result>);
312
#[cfg(test)]
mod tests {
    use futures::{AsyncReadExt, TryStreamExt};
    use http::header::LOCATION;
    use http::{StatusCode, Uri};

    use crate::{status, App, Request};

    const HELLO: &str = "Hello, world";

    #[tokio::test]
    async fn status_endpoint() {
        let app = App::new().end(status!(StatusCode::BAD_REQUEST));
        let service = app.http_service();
        let response = service.serve(Request::default()).await;
        assert_eq!(StatusCode::BAD_REQUEST, response.status);
    }

    #[tokio::test]
    async fn string_endpoint() {
        let app = App::new().end(HELLO.to_owned());
        let service = app.http_service();
        let response = service.serve(Request::default()).await;
        let mut text = String::new();
        response
            .body
            .into_async_read()
            .read_to_string(&mut text)
            .await
            .unwrap();
        assert_eq!(HELLO, text);
    }

    #[tokio::test]
    async fn static_slice_endpoint() {
        let app = App::new().end(HELLO);
        let service = app.http_service();
        let response = service.serve(Request::default()).await;
        let mut text = String::new();
        response
            .body
            .into_async_read()
            .read_to_string(&mut text)
            .await
            .unwrap();
        assert_eq!(HELLO, text);
    }

    #[tokio::test]
    async fn redirect_endpoint() {
        let target: Uri = "/target".parse().unwrap();
        let app = App::new().end(target);
        let service = app.http_service();
        let response = service.serve(Request::default()).await;
        assert_eq!(StatusCode::PERMANENT_REDIRECT, response.status);
        let location = response.headers[LOCATION].to_str().unwrap();
        assert_eq!("/target", location);
    }
}