roa_core/
app.rs

1mod future;
2#[cfg(feature = "runtime")]
3mod runtime;
4mod stream;
5
6use std::convert::Infallible;
7use std::error::Error;
8use std::future::Future;
9use std::net::SocketAddr;
10use std::pin::Pin;
11use std::sync::Arc;
12use std::task::Poll;
13
14use future::SendFuture;
15use http::{Request as HttpRequest, Response as HttpResponse};
16use hyper::service::Service;
17use hyper::{Body as HyperBody, Server};
18use tokio::io::{AsyncRead, AsyncWrite};
19
20pub use self::stream::AddrStream;
21use crate::{
22    Accept, Chain, Context, Endpoint, Executor, Middleware, MiddlewareExt, Request, Response,
23    Spawn, State,
24};
25
26/// The Application of roa.
27/// ### Example
28/// ```rust,no_run
29/// use roa_core::{App, Context, Next, Result, MiddlewareExt};
30/// use tracing::info;
31/// use tokio::fs::File;
32///
33/// let app = App::new().gate(gate).end(end);
34/// async fn gate(ctx: &mut Context, next: Next<'_>) -> Result {
35///     info!("{} {}", ctx.method(), ctx.uri());
36///     next.await
37/// }
38///
39/// async fn end(ctx: &mut Context) -> Result {
40///     ctx.resp.write_reader(File::open("assets/welcome.html").await?);
41///     Ok(())
42/// }
43/// ```
44///
45/// ### State
46/// The `State` is designed to share data or handler between middlewares.
47/// The only type implementing `State` in this crate is `()`; you can implement a custom state if necessary.
48///
49/// ```rust
50/// use roa_core::{App, Context, Next, Result};
51/// use tracing::info;
52/// use futures::lock::Mutex;
53///
54/// use std::sync::Arc;
55/// use std::collections::HashMap;
56///
57/// #[derive(Clone)]
58/// struct State {
59///     id: u64,
60///     database: Arc<Mutex<HashMap<u64, String>>>,
61/// }
62///
63/// impl State {
64///     fn new() -> Self {
65///         Self {
66///             id: 0,
67///             database: Arc::new(Mutex::new(HashMap::new()))
68///         }
69///     }
70/// }
71///
72/// let app = App::state(State::new()).gate(gate).end(end);
73/// async fn gate(ctx: &mut Context<State>, next: Next<'_>) -> Result {
74///     ctx.id = 1;
75///     next.await
76/// }
77///
78/// async fn end(ctx: &mut Context<State>) -> Result {
79///     let id = ctx.id;
80///     ctx.database.lock().await.get(&id);
81///     Ok(())
82/// }
83/// ```
84///
85pub struct App<S, T> {
86    service: T,
87    exec: Executor,
88    state: S,
89}
90
91/// An implementation of hyper HttpService.
92pub struct HttpService<S, E> {
93    endpoint: Arc<E>,
94    remote_addr: SocketAddr,
95    exec: Executor,
96    pub(crate) state: S,
97}
98
99impl<S, T> App<S, T> {
100    /// Map app::service
101    fn map_service<U>(self, mapper: impl FnOnce(T) -> U) -> App<S, U> {
102        let Self {
103            exec,
104            state,
105            service,
106        } = self;
107        App {
108            service: mapper(service),
109            exec,
110            state,
111        }
112    }
113}
114
115impl<S> App<S, ()> {
116    /// Construct an application with custom runtime.
117    pub fn with_exec(state: S, exec: impl 'static + Send + Sync + Spawn) -> Self {
118        Self {
119            service: (),
120            exec: Executor(Arc::new(exec)),
121            state,
122        }
123    }
124}
125
126impl<S, T> App<S, T>
127where
128    T: for<'a> Middleware<'a, S>,
129{
130    /// Use a middleware.
131    pub fn gate<M>(self, middleware: M) -> App<S, Chain<T, M>>
132    where
133        M: for<'a> Middleware<'a, S>,
134    {
135        self.map_service(move |service| service.chain(middleware))
136    }
137
138    /// Set endpoint, then app can only be used to serve http request.
139    pub fn end<E>(self, endpoint: E) -> App<S, Arc<Chain<T, E>>>
140    where
141        E: for<'a> Endpoint<'a, S>,
142    {
143        self.map_service(move |service| Arc::new(service.end(endpoint)))
144    }
145}
146
147impl<S, E> App<S, Arc<E>>
148where
149    E: for<'a> Endpoint<'a, S>,
150{
151    /// Construct a hyper server by an incoming.
152    pub fn accept<I, IO>(self, incoming: I) -> Server<I, Self, Executor>
153    where
154        S: State,
155        IO: 'static + Send + Sync + Unpin + AsyncRead + AsyncWrite,
156        I: Accept<Conn = AddrStream<IO>>,
157        I::Error: Into<Box<dyn Error + Send + Sync>>,
158    {
159        Server::builder(incoming)
160            .executor(self.exec.clone())
161            .serve(self)
162    }
163
164    /// Make a fake http service for test.
165    #[cfg(test)]
166    pub fn http_service(&self) -> HttpService<S, E>
167    where
168        S: Clone,
169    {
170        let endpoint = self.service.clone();
171        let addr = ([127, 0, 0, 1], 0);
172        let state = self.state.clone();
173        let exec = self.exec.clone();
174        HttpService::new(endpoint, addr.into(), exec, state)
175    }
176}
177
// Expands to a `poll_ready` method that always reports the service as ready.
//
// Meant to be invoked inside an `impl` block that defines an associated
// `Error` type, since the generated signature refers to `Self::Error`.
macro_rules! impl_poll_ready {
    () => {
        #[inline]
        fn poll_ready(
            &mut self,
            _: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Result<(), Self::Error>> {
            std::task::Poll::Ready(Ok(()))
        }
    };
}
189
190type AppFuture<S, E> =
191    Pin<Box<dyn 'static + Future<Output = std::io::Result<HttpService<S, E>>> + Send>>;
192
193impl<S, E, IO> Service<&AddrStream<IO>> for App<S, Arc<E>>
194where
195    S: State,
196    E: for<'a> Endpoint<'a, S>,
197    IO: 'static + Send + Sync + Unpin + AsyncRead + AsyncWrite,
198{
199    type Response = HttpService<S, E>;
200    type Error = std::io::Error;
201    type Future = AppFuture<S, E>;
202    impl_poll_ready!();
203
204    #[inline]
205    fn call(&mut self, stream: &AddrStream<IO>) -> Self::Future {
206        let endpoint = self.service.clone();
207        let addr = stream.remote_addr;
208        let state = self.state.clone();
209        let exec = self.exec.clone();
210        Box::pin(async move { Ok(HttpService::new(endpoint, addr, exec, state)) })
211    }
212}
213
214type HttpFuture =
215    Pin<Box<dyn 'static + Future<Output = Result<HttpResponse<HyperBody>, Infallible>> + Send>>;
216
217impl<S, E> Service<HttpRequest<HyperBody>> for HttpService<S, E>
218where
219    S: State,
220    E: for<'a> Endpoint<'a, S>,
221{
222    type Response = HttpResponse<HyperBody>;
223    type Error = Infallible;
224    type Future = HttpFuture;
225    impl_poll_ready!();
226
227    #[inline]
228    fn call(&mut self, req: HttpRequest<HyperBody>) -> Self::Future {
229        let service = self.clone();
230        Box::pin(async move {
231            let serve_future = SendFuture(Box::pin(service.serve(req.into())));
232            Ok(serve_future.await.into())
233        })
234    }
235}
236
237impl<S, E> HttpService<S, E> {
238    pub fn new(endpoint: Arc<E>, remote_addr: SocketAddr, exec: Executor, state: S) -> Self {
239        Self {
240            endpoint,
241            remote_addr,
242            exec,
243            state,
244        }
245    }
246
247    /// Receive a request then return a response.
248    /// The entry point of http service.
249    pub async fn serve(self, req: Request) -> Response
250    where
251        S: 'static,
252        E: for<'a> Endpoint<'a, S>,
253    {
254        let Self {
255            endpoint,
256            remote_addr,
257            exec,
258            state,
259        } = self;
260        let mut ctx = Context::new(req, state, exec, remote_addr);
261        if let Err(status) = endpoint.call(&mut ctx).await {
262            ctx.resp.status = status.status_code;
263            if status.expose {
264                ctx.resp.write(status.message);
265            } else {
266                ctx.exec
267                    .spawn_blocking(move || tracing::error!("Uncaught status: {}", status))
268                    .await;
269            }
270        }
271        ctx.resp
272    }
273}
274
275impl<S: Clone, E> Clone for HttpService<S, E> {
276    fn clone(&self) -> Self {
277        Self {
278            endpoint: self.endpoint.clone(),
279            state: self.state.clone(),
280            exec: self.exec.clone(),
281            remote_addr: self.remote_addr,
282        }
283    }
284}
285
#[cfg(all(test, feature = "runtime"))]
mod tests {
    use http::StatusCode;

    use crate::{App, Request};

    /// An application whose endpoint is `()` answers a default request with `200 OK`.
    #[tokio::test]
    async fn gate_simple() -> Result<(), Box<dyn std::error::Error>> {
        let app = App::new().end(());
        let response = app.http_service().serve(Request::default()).await;
        assert_eq!(StatusCode::OK, response.status);
        Ok(())
    }
}