rok/
server.rs

1#[cfg(feature = "tls")]
2use std::path::Path;
3use std::sync::Arc;
4
5use hyper::http;
6use hyper::server::conn::Http;
7use hyper::service::service_fn;
8use lazy_static::lazy_static;
9use tokio::net::{ToSocketAddrs};
10use tokio::net::TcpListener;
11
12use crate::errors::errors::Error;
13use crate::middleware::{Middleware, WithState};
14use crate::request::request::Request;
15use crate::response::response::Response;
16use crate::router::router::Router;
17use crate::{
18    endpoint::{Endpoint, RouterEndpoint},
19};
20
21lazy_static! {
22    pub static ref SERVER_ID: String = format!("rok {}", env!("CARGO_PKG_VERSION"));
23}
24
25pub struct App {
26    router: Router,
27}
28
29impl App {
30    pub fn new() -> App {
31        App {
32            router: Router::new(),
33        }
34    }
35
36    pub fn with_state<T>(state: T) -> App
37        where
38            T: Send + Sync + 'static + Clone,
39    {
40        let mut app = App::new();
41
42        app.middleware(WithState::new(state));
43        app
44    }
45
46    pub fn merge(
47        &mut self,
48        prefix: impl AsRef<str>,
49        router: Router,
50    ) -> Result<(), crate::errors::errors::Error> {
51        self.router.merge(prefix, router)
52    }
53
54    pub fn register(&mut self, method: http::Method, path: impl AsRef<str>, ep: impl Endpoint) {
55        self.router.register(method, path, ep)
56    }
57
58    pub fn options(&mut self, path: impl AsRef<str>, ep: impl Endpoint) {
59        self.register(http::Method::OPTIONS, path, ep)
60    }
61
62    pub fn get(&mut self, path: impl AsRef<str>, ep: impl Endpoint) {
63        self.register(http::Method::GET, path, ep)
64    }
65
66    pub fn head(&mut self, path: impl AsRef<str>, ep: impl Endpoint) {
67        self.register(http::Method::HEAD, path, ep)
68    }
69
70    pub fn post(&mut self, path: impl AsRef<str>, ep: impl Endpoint) {
71        self.register(http::Method::POST, path, ep)
72    }
73
74    pub fn put(&mut self, path: impl AsRef<str>, ep: impl Endpoint) {
75        self.register(http::Method::PUT, path, ep)
76    }
77
78    pub fn delete(&mut self, path: impl AsRef<str>, ep: impl Endpoint) {
79        self.register(http::Method::DELETE, path, ep)
80    }
81
82    pub fn trace(&mut self, path: impl AsRef<str>, ep: impl Endpoint) {
83        self.register(http::Method::TRACE, path, ep)
84    }
85
86    pub fn connect(&mut self, path: impl AsRef<str>, ep: impl Endpoint) {
87        self.register(http::Method::CONNECT, path, ep)
88    }
89
90    pub fn patch(&mut self, path: impl AsRef<str>, ep: impl Endpoint) {
91        self.register(http::Method::PATCH, path, ep)
92    }
93
94    pub fn middleware(&mut self, m: impl Middleware) -> &mut Self {
95        self.router.middleware(m);
96        self
97    }
98
99    pub fn handle_not_found(&mut self, ep: impl Endpoint) -> &mut Self {
100        self.router.set_not_found_handler(ep);
101        self
102    }
103
104    pub async fn respond(self, req: impl Into<Request>) -> Response {
105        let req = req.into();
106        let App { router } = self;
107
108        let router = Arc::new(router.finalize());
109
110        let endpoint = RouterEndpoint::new(router);
111        endpoint.call(req).await
112    }
113
114    pub async fn run(self, addr: impl ToSocketAddrs) -> Result<(), Error> {
115        let App { router } = self;
116
117        let router = Arc::new(router.finalize());
118
119        let server = Http::new();
120
121        let listener = TcpListener::bind(addr).await.unwrap();
122        while let Ok((socket, remote_addr)) = listener.accept().await {
123            let server = server.clone();
124            let router = router.clone();
125
126            tokio::spawn(async move {
127                let router = router.clone();
128
129                let ret = server.serve_connection(
130                    socket,
131                    service_fn(|req| {
132                        let router = router.clone();
133                        let req = Request::new(req, Some(remote_addr));
134
135                        async move {
136                            let endpoint = RouterEndpoint::new(router);
137                            let resp = endpoint.call(req).await;
138                            Ok::<_, Error>(resp.into())
139                        }
140                    }),
141                );
142
143                if let Err(e) = ret.await {
144                    tracing::error!("serve_connection error: {:?}", e);
145                }
146            });
147        }
148
149        Ok(())
150    }
151
152    #[cfg(feature = "tls")]
153    pub async fn run_with_tls(
154        self,
155        addr: impl ToSocketAddrs,
156        cert: impl AsRef<Path>,
157        key: impl AsRef<Path>,
158    ) -> Result<(), Error> {
159        let App { router } = self;
160
161        let router = Arc::new(router.finalize());
162
163        let server = Http::new();
164
165        let tls_acceptor = crate::tls::new_tls_acceptor(cert, key)?;
166
167        let listener = TcpListener::bind(addr).await.unwrap();
168        while let Ok((socket, remote_addr)) = listener.accept().await {
169            let tls_acceptor = tls_acceptor.clone();
170            let server = server.clone();
171            let router = router.clone();
172
173            tokio::spawn(async move {
174                let tls_acceptor = tls_acceptor.clone();
175                let router = router.clone();
176
177                match tls_acceptor.accept(socket).await {
178                    Ok(stream) => {
179                        let ret = server.serve_connection(
180                            stream,
181                            service_fn(|req| {
182                                let router = router.clone();
183                                let req = Request::new(req, Some(remote_addr));
184
185                                async move {
186                                    let endpoint = RouterEndpoint::new(router);
187                                    let resp = endpoint.call(req).await;
188                                    Ok::<_, Error>(resp.into())
189                                }
190                            }),
191                        );
192
193                        if let Err(e) = ret.await {
194                            tracing::error!("serve_connection error: {:?}", e);
195                        }
196                    }
197                    Err(err) => {
198                        tracing::error!("tls accept failed, {:?}", err);
199                    }
200                }
201            });
202        }
203
204        Ok(())
205    }
206}
207
208impl Default for App {
209    fn default() -> Self {
210        Self::new()
211    }
212}
213
214pub fn server_id() -> &'static str {
215    &SERVER_ID
216}
217