1use std::{convert::Infallible, net::SocketAddr, sync::Arc};
2
3use hyper::{
4 Request,
5 body::Incoming,
6 server::conn::{http1, http2},
7 service::service_fn,
8};
9use hyper_util::rt::{TokioExecutor, TokioIo};
10
11use tokio::{
12 io::{AsyncRead, AsyncWrite},
13 net::TcpListener,
14};
15
16use crate::{
17 client::Client,
18 group::Group,
19 middleware::error::DefaultErrorHandler,
20 request::RequestBody,
21 response::ResponseWriter,
22 router::Router,
23 types::{BoltResult, ErrorHandler, Handler, Method, Middleware, Mode},
24};
25
26pub mod client;
27mod group;
28pub mod macros;
29pub mod middleware;
30pub mod request;
31pub mod response;
32mod router;
33pub mod types;
34pub use bolt_web_macro::main;
35pub use paste;
36pub use tokio;
37
38trait Io: AsyncRead + AsyncWrite + Unpin {}
39impl<T: AsyncRead + AsyncWrite + Unpin> Io for T {}
40
/// The application object: holds the routing table, the error handler and a client.
// NOTE(review): `client` is never read in this file, which is presumably why
// `dead_code` is allowed here — confirm before narrowing or removing the attribute.
#[allow(dead_code)]
pub struct Bolt {
    /// Routes and middleware; `run` clones this into an `Arc` shared by all connections.
    router: Router,
    /// Called by `run` after the middleware/handler phase when the response reports an error.
    error_handler: Arc<dyn ErrorHandler>,
    /// Client created by `Bolt::new`; not otherwise used in this file.
    client: Client,
}
47
48#[allow(unused_variables)]
49#[allow(dead_code)]
50impl Bolt {
51 pub fn new() -> Self {
52 Self {
53 router: Router::new(),
54 error_handler: Arc::new(DefaultErrorHandler),
55 client: Client::new(),
56 }
57 }
58
59 pub async fn run(&self, addr: &str, mode: Mode, tls: Option<(&str, &str)>) -> BoltResult<()> {
60 println!("⚡ A high performance & minimalist web framework in rust.");
61 println!(
62 r#"
63 __ ____
64 / /_ ____ / / /_
65 / __ \/ __ \/ / __/
66 / /_/ / /_/ / / /_
67/_.___/\____/_/\__/ v0.2
68"#
69 );
70
71 println!(
72 ">> Server running on {}://{}",
73 if tls.is_some() { "https" } else { "http" },
74 addr
75 );
76
77 let addr: SocketAddr = addr.parse().unwrap();
78
79 let listener = TcpListener::bind(addr).await?;
80 let router = Arc::new(self.router.clone());
81
82 let tls_acceptor = if let Some((pkcs12_path, password)) = tls {
83 let pkcs12 = std::fs::read(pkcs12_path)?;
84 let identity = tokio_native_tls::native_tls::Identity::from_pkcs12(&pkcs12, password)?;
85 Some(Arc::new(tokio_native_tls::TlsAcceptor::from(
86 tokio_native_tls::native_tls::TlsAcceptor::builder(identity).build()?,
87 )))
88 } else {
89 None
90 };
91
92 loop {
93 let (stream, _) = listener.accept().await?;
94
95 let io: Box<dyn Io + Send> = if let Some(ref acceptor) = tls_acceptor {
96 Box::new(acceptor.accept(stream).await?)
97 } else {
98 Box::new(stream)
99 };
100
101 let io = TokioIo::new(io);
102
103 let router = router.clone();
104 let error_handler = self.error_handler.clone();
105
106 let service = service_fn(move |req: Request<Incoming>| {
107 let router = router.clone();
108
109 let error_handler = error_handler.clone();
110
111 async move {
112 let mut req_body = RequestBody::new(req);
113 let mut res_body = ResponseWriter::new();
114
115 let method = match *req_body.method() {
116 hyper::Method::GET => Method::GET,
117 hyper::Method::POST => Method::POST,
118 hyper::Method::PUT => Method::PUT,
119 hyper::Method::PATCH => Method::PATCH,
120 hyper::Method::DELETE => Method::DELETE,
121 hyper::Method::OPTIONS => Method::OPTIONS,
122 hyper::Method::HEAD => Method::HEAD,
123 hyper::Method::TRACE => Method::TRACE,
124 _ => {
125 res_body.status(405);
126 res_body.send("Method Not Allowed");
127 return Ok::<_, Infallible>(res_body.into_response());
128 }
129 };
130
131 let path = req_body.path().to_string();
132
133 let mws = router.collect_middleware(&path, method);
134
135 for mw in mws {
136 mw.run(&mut req_body, &mut res_body).await;
137
138 if res_body.has_error() {
139 break;
140 }
141 }
142
143 if !res_body.has_error() {
144 if let Some((handler, params)) = router.find(&path, method) {
145 req_body.set_params(params);
146
147 handler.handle(&mut req_body, &mut res_body).await;
148 } else {
149 let method_str = match *req_body.method() {
150 hyper::Method::GET => "GET",
151 hyper::Method::POST => "POST",
152 hyper::Method::PUT => "PUT",
153 hyper::Method::PATCH => "PATCH",
154 hyper::Method::DELETE => "DELETE",
155 hyper::Method::OPTIONS => "OPTIONS",
156 hyper::Method::HEAD => "HEAD",
157 hyper::Method::TRACE => "TRACE",
158 _ => "UNKNOWN",
159 };
160
161 res_body.error(404, &format!("Not Found {} {}", method_str, path));
162 }
163 }
164
165 if res_body.has_error() {
166 let msg = res_body.body.clone();
167 error_handler.run(msg, &mut res_body).await;
168 }
169
170 req_body.cleanup().await;
171
172 if req_body.log {
173 println!(
174 "[LOG] method={} path={} status={}",
175 req_body.method(),
176 path,
177 res_body.status,
178 );
179 }
180
181 res_body.strip_header("X-Internal-Request-Start");
182
183 Ok::<_, Infallible>(res_body.into_response())
184 }
185 });
186
187 match mode {
188 Mode::Http1 => {
189 tokio::task::spawn(async move {
190 if let Err(err) = http1::Builder::new().serve_connection(io, service).await
191 {
192 eprintln!("Error serving connection: {}", err);
193 }
194 });
195 }
196
197 Mode::Http2 => {
198 tokio::task::spawn(async move {
199 if let Err(err) = http2::Builder::new(TokioExecutor::new())
200 .serve_connection(io, service)
201 .await
202 {
203 eprintln!("Error serving connection: {}", err);
204 }
205 });
206 }
207 }
208 }
209 }
210
211 fn add_route<H>(&mut self, method: Method, path: &str, handler: H)
212 where
213 H: Handler + 'static,
214 {
215 self.router.insert(path, method, handler);
216 }
217
218 pub fn get<H>(&mut self, path: &str, handler: H)
219 where
220 H: Handler + 'static,
221 {
222 self.add_route(Method::GET, path, handler);
223 }
224
225 pub fn post<H>(&mut self, path: &str, handler: H)
226 where
227 H: Handler + 'static,
228 {
229 self.add_route(Method::POST, path, handler);
230 }
231
232 pub fn put<H>(&mut self, path: &str, handler: H)
233 where
234 H: Handler + 'static,
235 {
236 self.add_route(Method::PUT, path, handler);
237 }
238
239 pub fn patch<H>(&mut self, path: &str, handler: H)
240 where
241 H: Handler + 'static,
242 {
243 self.add_route(Method::PATCH, path, handler);
244 }
245
246 pub fn delete<H>(&mut self, path: &str, handler: H)
247 where
248 H: Handler + 'static,
249 {
250 self.add_route(Method::DELETE, path, handler);
251 }
252
253 pub fn group<'a>(&'a mut self, path: &str) -> Group<'a> {
254 Group {
255 prefix: path.to_string(),
256 app: self,
257 }
258 }
259
260 pub fn middleware<M>(&mut self, path: &str, method: Option<Method>, middleware_fn: M)
261 where
262 M: Middleware + 'static,
263 {
264 let mw = Arc::new(middleware_fn);
265 let full_path = path.to_string();
266
267 match method {
268 Some(m) => self.router.insert_middleware(&full_path, m, mw),
269 None => {
270 for m in [
271 Method::GET,
272 Method::POST,
273 Method::PUT,
274 Method::PATCH,
275 Method::DELETE,
276 Method::OPTIONS,
277 Method::HEAD,
278 Method::TRACE,
279 ] {
280 self.router.insert_middleware(&full_path, m, mw.clone());
281 }
282 }
283 }
284 }
285
286 pub fn set_error_handler(&mut self, handler: Arc<dyn ErrorHandler>) {
287 self.error_handler = handler;
288 }
289}