1#[cfg(feature = "tls")]
2use std::path::Path;
3use std::sync::Arc;
4
5use hyper::http;
6use hyper::server::conn::Http;
7use hyper::service::service_fn;
8use lazy_static::lazy_static;
9use tokio::net::{ToSocketAddrs};
10use tokio::net::TcpListener;
11
12use crate::errors::errors::Error;
13use crate::middleware::{Middleware, WithState};
14use crate::request::request::Request;
15use crate::response::response::Response;
16use crate::router::router::Router;
17use crate::{
18 endpoint::{Endpoint, RouterEndpoint},
19};
20
21lazy_static! {
22 pub static ref SERVER_ID: String = format!("rok {}", env!("CARGO_PKG_VERSION"));
23}
24
25pub struct App {
26 router: Router,
27}
28
29impl App {
30 pub fn new() -> App {
31 App {
32 router: Router::new(),
33 }
34 }
35
36 pub fn with_state<T>(state: T) -> App
37 where
38 T: Send + Sync + 'static + Clone,
39 {
40 let mut app = App::new();
41
42 app.middleware(WithState::new(state));
43 app
44 }
45
46 pub fn merge(
47 &mut self,
48 prefix: impl AsRef<str>,
49 router: Router,
50 ) -> Result<(), crate::errors::errors::Error> {
51 self.router.merge(prefix, router)
52 }
53
54 pub fn register(&mut self, method: http::Method, path: impl AsRef<str>, ep: impl Endpoint) {
55 self.router.register(method, path, ep)
56 }
57
58 pub fn options(&mut self, path: impl AsRef<str>, ep: impl Endpoint) {
59 self.register(http::Method::OPTIONS, path, ep)
60 }
61
62 pub fn get(&mut self, path: impl AsRef<str>, ep: impl Endpoint) {
63 self.register(http::Method::GET, path, ep)
64 }
65
66 pub fn head(&mut self, path: impl AsRef<str>, ep: impl Endpoint) {
67 self.register(http::Method::HEAD, path, ep)
68 }
69
70 pub fn post(&mut self, path: impl AsRef<str>, ep: impl Endpoint) {
71 self.register(http::Method::POST, path, ep)
72 }
73
74 pub fn put(&mut self, path: impl AsRef<str>, ep: impl Endpoint) {
75 self.register(http::Method::PUT, path, ep)
76 }
77
78 pub fn delete(&mut self, path: impl AsRef<str>, ep: impl Endpoint) {
79 self.register(http::Method::DELETE, path, ep)
80 }
81
82 pub fn trace(&mut self, path: impl AsRef<str>, ep: impl Endpoint) {
83 self.register(http::Method::TRACE, path, ep)
84 }
85
86 pub fn connect(&mut self, path: impl AsRef<str>, ep: impl Endpoint) {
87 self.register(http::Method::CONNECT, path, ep)
88 }
89
90 pub fn patch(&mut self, path: impl AsRef<str>, ep: impl Endpoint) {
91 self.register(http::Method::PATCH, path, ep)
92 }
93
94 pub fn middleware(&mut self, m: impl Middleware) -> &mut Self {
95 self.router.middleware(m);
96 self
97 }
98
99 pub fn handle_not_found(&mut self, ep: impl Endpoint) -> &mut Self {
100 self.router.set_not_found_handler(ep);
101 self
102 }
103
104 pub async fn respond(self, req: impl Into<Request>) -> Response {
105 let req = req.into();
106 let App { router } = self;
107
108 let router = Arc::new(router.finalize());
109
110 let endpoint = RouterEndpoint::new(router);
111 endpoint.call(req).await
112 }
113
114 pub async fn run(self, addr: impl ToSocketAddrs) -> Result<(), Error> {
115 let App { router } = self;
116
117 let router = Arc::new(router.finalize());
118
119 let server = Http::new();
120
121 let listener = TcpListener::bind(addr).await.unwrap();
122 while let Ok((socket, remote_addr)) = listener.accept().await {
123 let server = server.clone();
124 let router = router.clone();
125
126 tokio::spawn(async move {
127 let router = router.clone();
128
129 let ret = server.serve_connection(
130 socket,
131 service_fn(|req| {
132 let router = router.clone();
133 let req = Request::new(req, Some(remote_addr));
134
135 async move {
136 let endpoint = RouterEndpoint::new(router);
137 let resp = endpoint.call(req).await;
138 Ok::<_, Error>(resp.into())
139 }
140 }),
141 );
142
143 if let Err(e) = ret.await {
144 tracing::error!("serve_connection error: {:?}", e);
145 }
146 });
147 }
148
149 Ok(())
150 }
151
152 #[cfg(feature = "tls")]
153 pub async fn run_with_tls(
154 self,
155 addr: impl ToSocketAddrs,
156 cert: impl AsRef<Path>,
157 key: impl AsRef<Path>,
158 ) -> Result<(), Error> {
159 let App { router } = self;
160
161 let router = Arc::new(router.finalize());
162
163 let server = Http::new();
164
165 let tls_acceptor = crate::tls::new_tls_acceptor(cert, key)?;
166
167 let listener = TcpListener::bind(addr).await.unwrap();
168 while let Ok((socket, remote_addr)) = listener.accept().await {
169 let tls_acceptor = tls_acceptor.clone();
170 let server = server.clone();
171 let router = router.clone();
172
173 tokio::spawn(async move {
174 let tls_acceptor = tls_acceptor.clone();
175 let router = router.clone();
176
177 match tls_acceptor.accept(socket).await {
178 Ok(stream) => {
179 let ret = server.serve_connection(
180 stream,
181 service_fn(|req| {
182 let router = router.clone();
183 let req = Request::new(req, Some(remote_addr));
184
185 async move {
186 let endpoint = RouterEndpoint::new(router);
187 let resp = endpoint.call(req).await;
188 Ok::<_, Error>(resp.into())
189 }
190 }),
191 );
192
193 if let Err(e) = ret.await {
194 tracing::error!("serve_connection error: {:?}", e);
195 }
196 }
197 Err(err) => {
198 tracing::error!("tls accept failed, {:?}", err);
199 }
200 }
201 });
202 }
203
204 Ok(())
205 }
206}
207
208impl Default for App {
209 fn default() -> Self {
210 Self::new()
211 }
212}
213
214pub fn server_id() -> &'static str {
215 &SERVER_ID
216}
217