1use std::{convert::Infallible, net::SocketAddr, sync::Arc};
2
3use hyper::{
4 Request,
5 body::Incoming,
6 server::conn::{http1, http2},
7 service::service_fn,
8};
9use hyper_util::rt::{TokioExecutor, TokioIo};
10
11use tokio::{
12 io::{AsyncRead, AsyncWrite},
13 net::TcpListener,
14};
15
16use crate::{
17 client::Client,
18 group::Group,
19 middleware::error::DefaultErrorHandler,
20 request::RequestBody,
21 response::ResponseWriter,
22 router::Router,
23 types::{BoltResult, ErrorHandler, Handler, Method, Middleware, Mode},
24};
25
26pub mod client;
27mod group;
28pub mod macros;
29pub mod middleware;
30pub mod request;
31pub mod response;
32mod router;
33pub mod types;
34pub use async_trait;
35pub use bolt_web_macro::main;
36pub use paste;
37pub use tokio;
38
39trait Io: AsyncRead + AsyncWrite + Unpin {}
40impl<T: AsyncRead + AsyncWrite + Unpin> Io for T {}
41
42#[allow(dead_code)]
43pub struct Bolt {
44 router: Router,
45 error_handler: Arc<dyn ErrorHandler>,
46 client: Client,
47}
48
49#[allow(unused_variables)]
50#[allow(dead_code)]
51impl Bolt {
52 pub fn new() -> Self {
53 Self {
54 router: Router::new(),
55 error_handler: Arc::new(DefaultErrorHandler),
56 client: Client::new(),
57 }
58 }
59
60 pub async fn run(&self, addr: &str, mode: Mode, tls: Option<(&str, &str)>) -> BoltResult<()> {
61 println!("⚡ A high performance & minimalist web framework in rust.");
62 println!(
63 r#"
64 __ ____
65 / /_ ____ / / /_
66 / __ \/ __ \/ / __/
67 / /_/ / /_/ / / /_
68/_.___/\____/_/\__/ v0.2
69"#
70 );
71
72 println!(
73 ">> Server running on {}://{}",
74 if tls.is_some() { "https" } else { "http" },
75 addr
76 );
77
78 let addr: SocketAddr = addr.parse().unwrap();
79
80 let listener = TcpListener::bind(addr).await?;
81 let router = Arc::new(self.router.clone());
82
83 let tls_acceptor = if let Some((pkcs12_path, password)) = tls {
84 let pkcs12 = std::fs::read(pkcs12_path)?;
85 let identity = tokio_native_tls::native_tls::Identity::from_pkcs12(&pkcs12, password)?;
86 Some(Arc::new(tokio_native_tls::TlsAcceptor::from(
87 tokio_native_tls::native_tls::TlsAcceptor::builder(identity).build()?,
88 )))
89 } else {
90 None
91 };
92
93 loop {
94 let (stream, _) = listener.accept().await?;
95
96 let io: Box<dyn Io + Send> = if let Some(ref acceptor) = tls_acceptor {
97 Box::new(acceptor.accept(stream).await?)
98 } else {
99 Box::new(stream)
100 };
101
102 let io = TokioIo::new(io);
103
104 let router = router.clone();
105 let error_handler = self.error_handler.clone();
106
107 let service = service_fn(move |req: Request<Incoming>| {
108 let router = router.clone();
109
110 let error_handler = error_handler.clone();
111
112 async move {
113 let mut req_body = RequestBody::new(req);
114 let mut res_body = ResponseWriter::new();
115
116 let method = match *req_body.method() {
117 hyper::Method::GET => Method::GET,
118 hyper::Method::POST => Method::POST,
119 hyper::Method::PUT => Method::PUT,
120 hyper::Method::PATCH => Method::PATCH,
121 hyper::Method::DELETE => Method::DELETE,
122 hyper::Method::OPTIONS => Method::OPTIONS,
123 hyper::Method::HEAD => Method::HEAD,
124 hyper::Method::TRACE => Method::TRACE,
125 _ => {
126 res_body.status(405);
127 res_body.send("Method Not Allowed");
128 return Ok::<_, Infallible>(res_body.into_response());
129 }
130 };
131
132 let path = req_body.path().to_string();
133
134 let mws = router.collect_middleware(&path, method);
135
136 for mw in mws {
137 mw.run(&mut req_body, &mut res_body).await;
138
139 if res_body.has_error() {
140 break;
141 }
142 }
143
144 if !res_body.has_error() {
145 if let Some((handler, params)) = router.find(&path, method) {
146 req_body.set_params(params);
147
148 handler.handle(&mut req_body, &mut res_body).await;
149 } else {
150 let method_str = match *req_body.method() {
151 hyper::Method::GET => "GET",
152 hyper::Method::POST => "POST",
153 hyper::Method::PUT => "PUT",
154 hyper::Method::PATCH => "PATCH",
155 hyper::Method::DELETE => "DELETE",
156 hyper::Method::OPTIONS => "OPTIONS",
157 hyper::Method::HEAD => "HEAD",
158 hyper::Method::TRACE => "TRACE",
159 _ => "UNKNOWN",
160 };
161
162 res_body.error(404, &format!("Not Found {} {}", method_str, path));
163 }
164 }
165
166 if res_body.has_error() {
167 let msg = res_body.body.clone();
168 error_handler.run(msg, &mut res_body).await;
169 }
170
171 req_body.cleanup().await;
172
173 if req_body.log {
174 println!(
175 "[LOG] method={} path={} status={}",
176 req_body.method(),
177 path,
178 res_body.status,
179 );
180 }
181
182 res_body.strip_header("X-Internal-Request-Start");
183
184 Ok::<_, Infallible>(res_body.into_response())
185 }
186 });
187
188 match mode {
189 Mode::Http1 => {
190 tokio::task::spawn(async move {
191 if let Err(err) = http1::Builder::new().serve_connection(io, service).await
192 {
193 eprintln!("Error serving connection: {}", err);
194 }
195 });
196 }
197
198 Mode::Http2 => {
199 tokio::task::spawn(async move {
200 if let Err(err) = http2::Builder::new(TokioExecutor::new())
201 .serve_connection(io, service)
202 .await
203 {
204 eprintln!("Error serving connection: {}", err);
205 }
206 });
207 }
208 }
209 }
210 }
211
212 fn add_route<H>(&mut self, method: Method, path: &str, handler: H)
213 where
214 H: Handler + 'static,
215 {
216 self.router.insert(path, method, handler);
217 }
218
219 pub fn get<H>(&mut self, path: &str, handler: H)
220 where
221 H: Handler + 'static,
222 {
223 self.add_route(Method::GET, path, handler);
224 }
225
226 pub fn post<H>(&mut self, path: &str, handler: H)
227 where
228 H: Handler + 'static,
229 {
230 self.add_route(Method::POST, path, handler);
231 }
232
233 pub fn put<H>(&mut self, path: &str, handler: H)
234 where
235 H: Handler + 'static,
236 {
237 self.add_route(Method::PUT, path, handler);
238 }
239
240 pub fn patch<H>(&mut self, path: &str, handler: H)
241 where
242 H: Handler + 'static,
243 {
244 self.add_route(Method::PATCH, path, handler);
245 }
246
247 pub fn delete<H>(&mut self, path: &str, handler: H)
248 where
249 H: Handler + 'static,
250 {
251 self.add_route(Method::DELETE, path, handler);
252 }
253
254 pub fn group<'a>(&'a mut self, path: &str) -> Group<'a> {
255 Group {
256 prefix: path.to_string(),
257 app: self,
258 }
259 }
260
261 pub fn middleware<M>(&mut self, path: &str, method: Option<Method>, middleware_fn: M)
262 where
263 M: Middleware + 'static,
264 {
265 let mw = Arc::new(middleware_fn);
266 let full_path = path.to_string();
267
268 match method {
269 Some(m) => self.router.insert_middleware(&full_path, m, mw),
270 None => {
271 for m in [
272 Method::GET,
273 Method::POST,
274 Method::PUT,
275 Method::PATCH,
276 Method::DELETE,
277 Method::OPTIONS,
278 Method::HEAD,
279 Method::TRACE,
280 ] {
281 self.router.insert_middleware(&full_path, m, mw.clone());
282 }
283 }
284 }
285 }
286
287 pub fn set_error_handler<E>(&mut self, handler: E)
288 where
289 E: ErrorHandler + 'static,
290 {
291 self.error_handler = Arc::new(handler);
292 }
293}