1use std::{convert::Infallible, net::SocketAddr, sync::Arc};
2
3use hyper::{
4 Request,
5 body::Incoming,
6 server::conn::{http1, http2},
7 service::service_fn,
8};
9use hyper_util::rt::{TokioExecutor, TokioIo};
10
11use tokio::{
12 io::{AsyncRead, AsyncWrite},
13 net::TcpListener,
14};
15
16use crate::{
17 client::Client,
18 group::Group,
19 middleware::error::DefaultErrorHandler,
20 request::RequestBody,
21 response::ResponseWriter,
22 router::Router,
23 types::{BoltError, ErrorHandler, Handler, Method, Middleware, Mode},
24};
25
26pub mod client;
27mod group;
28pub mod macros;
29pub mod middleware;
30pub mod request;
31pub mod response;
32mod router;
33pub mod types;
34pub use bolt_web_macro::main;
35pub use num_cpus;
36pub use paste;
37pub use tokio;
38
39trait Io: AsyncRead + AsyncWrite + Unpin {}
40impl<T: AsyncRead + AsyncWrite + Unpin> Io for T {}
41
42#[allow(dead_code)]
43pub struct Bolt {
44 router: Router,
45 error_handler: Arc<dyn ErrorHandler>,
46 client: Client,
47}
48
49#[allow(unused_variables)]
50#[allow(dead_code)]
51impl Bolt {
52 pub fn new() -> Self {
53 Self {
54 router: Router::new(),
55 error_handler: Arc::new(DefaultErrorHandler),
56 client: Client::new(),
57 }
58 }
59
60 pub async fn run(
61 &self,
62 addr: &str,
63 mode: Mode,
64 tls: Option<(&str, &str)>,
65 ) -> Result<(), BoltError> {
66 println!("⚡ A high performance & minimalist web framework in rust.");
67 println!(
68 r#"
69 __ ____
70 / /_ ____ / / /_
71 / __ \/ __ \/ / __/
72 / /_/ / /_/ / / /_
73/_.___/\____/_/\__/ v0.2
74"#
75 );
76
77 println!(
78 ">> Server running on {}://{}",
79 if tls.is_some() { "https" } else { "http" },
80 addr
81 );
82
83 let addr: SocketAddr = addr.parse().unwrap();
84
85 let listener = TcpListener::bind(addr).await?;
86 let router = Arc::new(self.router.clone());
87
88 let tls_acceptor = if let Some((pkcs12_path, password)) = tls {
89 let pkcs12 = std::fs::read(pkcs12_path)?;
90 let identity = tokio_native_tls::native_tls::Identity::from_pkcs12(&pkcs12, password)?;
91 Some(Arc::new(tokio_native_tls::TlsAcceptor::from(
92 tokio_native_tls::native_tls::TlsAcceptor::builder(identity).build()?,
93 )))
94 } else {
95 None
96 };
97
98 loop {
99 let (stream, _) = listener.accept().await?;
100
101 let io: Box<dyn Io + Send> = if let Some(ref acceptor) = tls_acceptor {
102 Box::new(acceptor.accept(stream).await?)
103 } else {
104 Box::new(stream)
105 };
106
107 let io = TokioIo::new(io);
108
109 let router = router.clone();
110 let error_handler = self.error_handler.clone();
111
112 let service = service_fn(move |req: Request<Incoming>| {
113 let router = router.clone();
114
115 let error_handler = error_handler.clone();
116
117 async move {
118 let mut req_body = RequestBody::new(req);
119 let mut res_body = ResponseWriter::new();
120
121 let method = match *req_body.method() {
122 hyper::Method::GET => Method::GET,
123 hyper::Method::POST => Method::POST,
124 hyper::Method::PUT => Method::PUT,
125 hyper::Method::PATCH => Method::PATCH,
126 hyper::Method::DELETE => Method::DELETE,
127 hyper::Method::OPTIONS => Method::OPTIONS,
128 hyper::Method::HEAD => Method::HEAD,
129 hyper::Method::TRACE => Method::TRACE,
130 _ => {
131 res_body.status(405);
132 res_body.send("Method Not Allowed");
133 return Ok::<_, Infallible>(res_body.into_response());
134 }
135 };
136
137 let path = req_body.path().to_string();
138
139 let mws = router.collect_middleware(&path, method);
140
141 for mw in mws {
142 mw.run(&mut req_body, &mut res_body).await;
143
144 if res_body.has_error() {
145 break;
146 }
147 }
148
149 if !res_body.has_error() {
150 if let Some((handler, params)) = router.find(&path, method) {
151 req_body.set_params(params);
152
153 handler.handle(&mut req_body, &mut res_body).await;
154 } else {
155 let method_str = match *req_body.method() {
156 hyper::Method::GET => "GET",
157 hyper::Method::POST => "POST",
158 hyper::Method::PUT => "PUT",
159 hyper::Method::PATCH => "PATCH",
160 hyper::Method::DELETE => "DELETE",
161 hyper::Method::OPTIONS => "OPTIONS",
162 hyper::Method::HEAD => "HEAD",
163 hyper::Method::TRACE => "TRACE",
164 _ => "UNKNOWN",
165 };
166
167 res_body.error(404, &format!("Not Found {} {}", method_str, path));
168 }
169 }
170
171 if res_body.has_error() {
172 let msg = res_body.body.clone();
173 error_handler.run(msg, &mut res_body).await;
174 }
175
176 req_body.cleanup().await;
177
178 if req_body.log {
179 println!(
180 "[LOG] method={} path={} status={}",
181 req_body.method(),
182 path,
183 res_body.status,
184 );
185 }
186
187 res_body.strip_header("X-Internal-Request-Start");
188
189 Ok::<_, Infallible>(res_body.into_response())
190 }
191 });
192
193 match mode {
194 Mode::Http1 => {
195 tokio::task::spawn(async move {
196 if let Err(err) = http1::Builder::new().serve_connection(io, service).await
197 {
198 eprintln!("Error serving connection: {}", err);
199 }
200 });
201 }
202
203 Mode::Http2 => {
204 tokio::task::spawn(async move {
205 if let Err(err) = http2::Builder::new(TokioExecutor::new())
206 .serve_connection(io, service)
207 .await
208 {
209 eprintln!("Error serving connection: {}", err);
210 }
211 });
212 }
213 }
214 }
215 }
216
217 fn add_route<H>(&mut self, method: Method, path: &str, handler: H)
218 where
219 H: Handler + 'static,
220 {
221 self.router.insert(path, method, handler);
222 }
223
224 pub fn get<H>(&mut self, path: &str, handler: H)
225 where
226 H: Handler + 'static,
227 {
228 self.add_route(Method::GET, path, handler);
229 }
230
231 pub fn post<H>(&mut self, path: &str, handler: H)
232 where
233 H: Handler + 'static,
234 {
235 self.add_route(Method::POST, path, handler);
236 }
237
238 pub fn put<H>(&mut self, path: &str, handler: H)
239 where
240 H: Handler + 'static,
241 {
242 self.add_route(Method::PUT, path, handler);
243 }
244
245 pub fn patch<H>(&mut self, path: &str, handler: H)
246 where
247 H: Handler + 'static,
248 {
249 self.add_route(Method::PATCH, path, handler);
250 }
251
252 pub fn delete<H>(&mut self, path: &str, handler: H)
253 where
254 H: Handler + 'static,
255 {
256 self.add_route(Method::DELETE, path, handler);
257 }
258
259 pub fn group<'a>(&'a mut self, path: &str) -> Group<'a> {
260 Group {
261 prefix: path.to_string(),
262 app: self,
263 }
264 }
265
266 pub fn middleware<M>(&mut self, path: &str, method: Option<Method>, middleware_fn: M)
267 where
268 M: Middleware + 'static,
269 {
270 let mw = Arc::new(middleware_fn);
271 let full_path = path.to_string();
272
273 match method {
274 Some(m) => self.router.insert_middleware(&full_path, m, mw),
275 None => {
276 for m in [
277 Method::GET,
278 Method::POST,
279 Method::PUT,
280 Method::PATCH,
281 Method::DELETE,
282 Method::OPTIONS,
283 Method::HEAD,
284 Method::TRACE,
285 ] {
286 self.router.insert_middleware(&full_path, m, mw.clone());
287 }
288 }
289 }
290 }
291
292 pub fn set_error_handler(&mut self, handler: Arc<dyn ErrorHandler>) {
293 self.error_handler = handler;
294 }
295}