1use std::{convert::Infallible, net::SocketAddr, sync::Arc};
2
3use hyper::{
4 Request,
5 body::Incoming,
6 server::conn::{http1, http2},
7 service::service_fn,
8};
9use hyper_util::rt::{TokioExecutor, TokioIo};
10
11use tokio::{
12 io::{AsyncRead, AsyncWrite},
13 net::TcpListener,
14};
15
16use crate::{
17 client::Client,
18 group::Group,
19 middleware::error::DefaultErrorHandler,
20 request::RequestBody,
21 response::ResponseWriter,
22 router::Router,
23 types::{BoltError, ErrorHandler, Handler, Method, Middleware, Mode},
24};
25
26pub mod client;
27mod group;
28pub mod macros;
29pub mod middleware;
30pub mod request;
31pub mod response;
32mod router;
33pub mod types;
34pub use bolt_web_macro::main;
35pub use paste;
36pub use tokio;
37
38trait Io: AsyncRead + AsyncWrite + Unpin {}
39impl<T: AsyncRead + AsyncWrite + Unpin> Io for T {}
40
41#[allow(dead_code)]
42pub struct Bolt {
43 router: Router,
44 error_handler: Arc<dyn ErrorHandler>,
45 client: Client,
46}
47
48#[allow(unused_variables)]
49#[allow(dead_code)]
50impl Bolt {
51 pub fn new() -> Self {
52 Self {
53 router: Router::new(),
54 error_handler: Arc::new(DefaultErrorHandler),
55 client: Client::new(),
56 }
57 }
58
59 pub async fn run(
60 &self,
61 addr: &str,
62 mode: Mode,
63 tls: Option<(&str, &str)>,
64 ) -> Result<(), BoltError> {
65 println!("⚡ A high performance & minimalist web framework in rust.");
66 println!(
67 r#"
68 __ ____
69 / /_ ____ / / /_
70 / __ \/ __ \/ / __/
71 / /_/ / /_/ / / /_
72/_.___/\____/_/\__/ v0.2.0
73"#
74 );
75
76 println!(
77 ">> Server running on {}://{}",
78 if tls.is_some() { "https" } else { "http" },
79 addr
80 );
81
82 let addr: SocketAddr = addr.parse().unwrap();
83
84 let listener = TcpListener::bind(addr).await?;
85 let router = Arc::new(self.router.clone());
86
87 let tls_acceptor = if let Some((pkcs12_path, password)) = tls {
88 let pkcs12 = std::fs::read(pkcs12_path)?;
89 let identity = tokio_native_tls::native_tls::Identity::from_pkcs12(&pkcs12, password)?;
90 Some(Arc::new(tokio_native_tls::TlsAcceptor::from(
91 tokio_native_tls::native_tls::TlsAcceptor::builder(identity).build()?,
92 )))
93 } else {
94 None
95 };
96
97 loop {
98 let (stream, _) = listener.accept().await?;
99
100 let io: Box<dyn Io + Send> = if let Some(ref acceptor) = tls_acceptor {
101 Box::new(acceptor.accept(stream).await?)
102 } else {
103 Box::new(stream)
104 };
105
106 let io = TokioIo::new(io);
107
108 let router = router.clone();
109 let error_handler = self.error_handler.clone();
110
111 let service = service_fn(move |req: Request<Incoming>| {
112 let router = router.clone();
113
114 let error_handler = error_handler.clone();
115
116 async move {
117 let mut req_body = RequestBody::new(req);
118 let mut res_body = ResponseWriter::new();
119
120 let method = match *req_body.method() {
121 hyper::Method::GET => Method::GET,
122 hyper::Method::POST => Method::POST,
123 hyper::Method::PUT => Method::PUT,
124 hyper::Method::PATCH => Method::PATCH,
125 hyper::Method::DELETE => Method::DELETE,
126 hyper::Method::OPTIONS => Method::OPTIONS,
127 hyper::Method::HEAD => Method::HEAD,
128 hyper::Method::TRACE => Method::TRACE,
129 _ => {
130 res_body.status(405);
131 res_body.send("Method Not Allowed");
132 return Ok::<_, Infallible>(res_body.into_response());
133 }
134 };
135
136 let path = req_body.path().to_string();
137
138 let mws = router.collect_middleware(&path, method);
139
140 for mw in mws {
141 mw.run(&mut req_body, &mut res_body).await;
142
143 if res_body.has_error() {
144 break;
145 }
146 }
147
148 if !res_body.has_error() {
149 if let Some((handler, params)) = router.find(&path, method) {
150 req_body.set_params(params);
151
152 handler.handle(&mut req_body, &mut res_body).await;
153 } else {
154 let method_str = match *req_body.method() {
155 hyper::Method::GET => "GET",
156 hyper::Method::POST => "POST",
157 hyper::Method::PUT => "PUT",
158 hyper::Method::PATCH => "PATCH",
159 hyper::Method::DELETE => "DELETE",
160 hyper::Method::OPTIONS => "OPTIONS",
161 hyper::Method::HEAD => "HEAD",
162 hyper::Method::TRACE => "TRACE",
163 _ => "UNKNOWN",
164 };
165
166 res_body.error(404, &format!("Not Found {} {}", method_str, path));
167 }
168 }
169
170 if res_body.has_error() {
171 let msg = res_body.body.clone();
172 error_handler.run(msg, &mut res_body).await;
173 }
174
175 req_body.cleanup().await;
176
177 if req_body.log {
178 println!(
179 "[LOG] method={} path={} status={}",
180 req_body.method(),
181 path,
182 res_body.status,
183 );
184 }
185
186 res_body.strip_header("X-Internal-Request-Start");
187
188 Ok::<_, Infallible>(res_body.into_response())
189 }
190 });
191
192 match mode {
193 Mode::Http1 => {
194 tokio::task::spawn(async move {
195 if let Err(err) = http1::Builder::new().serve_connection(io, service).await
196 {
197 eprintln!("Error serving connection: {}", err);
198 }
199 });
200 }
201
202 Mode::Http2 => {
203 tokio::task::spawn(async move {
204 if let Err(err) = http2::Builder::new(TokioExecutor::new())
205 .serve_connection(io, service)
206 .await
207 {
208 eprintln!("Error serving connection: {}", err);
209 }
210 });
211 }
212 }
213 }
214 }
215
216 fn add_route<H>(&mut self, method: Method, path: &str, handler: H)
217 where
218 H: Handler + 'static,
219 {
220 self.router.insert(path, method, handler);
221 }
222
223 pub fn get<H>(&mut self, path: &str, handler: H)
224 where
225 H: Handler + 'static,
226 {
227 self.add_route(Method::GET, path, handler);
228 }
229
230 pub fn post<H>(&mut self, path: &str, handler: H)
231 where
232 H: Handler + 'static,
233 {
234 self.add_route(Method::POST, path, handler);
235 }
236
237 pub fn put<H>(&mut self, path: &str, handler: H)
238 where
239 H: Handler + 'static,
240 {
241 self.add_route(Method::PUT, path, handler);
242 }
243
244 pub fn patch<H>(&mut self, path: &str, handler: H)
245 where
246 H: Handler + 'static,
247 {
248 self.add_route(Method::PATCH, path, handler);
249 }
250
251 pub fn delete<H>(&mut self, path: &str, handler: H)
252 where
253 H: Handler + 'static,
254 {
255 self.add_route(Method::DELETE, path, handler);
256 }
257
258 pub fn group<'a>(&'a mut self, path: &str) -> Group<'a> {
259 Group {
260 prefix: path.to_string(),
261 app: self,
262 }
263 }
264
265 pub fn middleware<M>(&mut self, path: &str, method: Option<Method>, middleware_fn: M)
266 where
267 M: Middleware + 'static,
268 {
269 let mw = Arc::new(middleware_fn);
270 let full_path = path.to_string();
271
272 match method {
273 Some(m) => self.router.insert_middleware(&full_path, m, mw),
274 None => {
275 for m in [
276 Method::GET,
277 Method::POST,
278 Method::PUT,
279 Method::PATCH,
280 Method::DELETE,
281 Method::OPTIONS,
282 Method::HEAD,
283 Method::TRACE,
284 ] {
285 self.router.insert_middleware(&full_path, m, mw.clone());
286 }
287 }
288 }
289 }
290
291 pub fn set_error_handler(&mut self, handler: Arc<dyn ErrorHandler>) {
292 self.error_handler = handler;
293 }
294}