bolt_web/
lib.rs

1use std::{convert::Infallible, net::SocketAddr, sync::Arc};
2
3use hyper::{
4    Request,
5    body::Incoming,
6    server::conn::{http1, http2},
7    service::service_fn,
8};
9use hyper_util::rt::{TokioExecutor, TokioIo};
10
11use tokio::{
12    io::{AsyncRead, AsyncWrite},
13    net::TcpListener,
14};
15
16use crate::{
17    client::Client,
18    group::Group,
19    middleware::error::DefaultErrorHandler,
20    request::RequestBody,
21    response::ResponseWriter,
22    router::Router,
23    types::{BoltResult, ErrorHandler, Handler, Method, Middleware, Mode},
24};
25
26pub mod client;
27mod group;
28pub mod macros;
29pub mod middleware;
30pub mod request;
31pub mod response;
32mod router;
33pub mod types;
34pub use bolt_web_macro::main;
35pub use paste;
36pub use tokio;
37
38trait Io: AsyncRead + AsyncWrite + Unpin {}
39impl<T: AsyncRead + AsyncWrite + Unpin> Io for T {}
40
41#[allow(dead_code)]
42pub struct Bolt {
43    router: Router,
44    error_handler: Arc<dyn ErrorHandler>,
45    client: Client,
46}
47
48#[allow(unused_variables)]
49#[allow(dead_code)]
50impl Bolt {
51    pub fn new() -> Self {
52        Self {
53            router: Router::new(),
54            error_handler: Arc::new(DefaultErrorHandler),
55            client: Client::new(),
56        }
57    }
58
59    pub async fn run(&self, addr: &str, mode: Mode, tls: Option<(&str, &str)>) -> BoltResult<()> {
60        println!("⚡ A high performance & minimalist web framework in rust.");
61        println!(
62            r#"
63    __          ____
64   / /_  ____  / / /_
65  / __ \/ __ \/ / __/
66 / /_/ / /_/ / / /_  
67/_.___/\____/_/\__/  v0.2
68"#
69        );
70
71        println!(
72            ">> Server running on {}://{}",
73            if tls.is_some() { "https" } else { "http" },
74            addr
75        );
76
77        let addr: SocketAddr = addr.parse().unwrap();
78
79        let listener = TcpListener::bind(addr).await?;
80        let router = Arc::new(self.router.clone());
81
82        let tls_acceptor = if let Some((pkcs12_path, password)) = tls {
83            let pkcs12 = std::fs::read(pkcs12_path)?;
84            let identity = tokio_native_tls::native_tls::Identity::from_pkcs12(&pkcs12, password)?;
85            Some(Arc::new(tokio_native_tls::TlsAcceptor::from(
86                tokio_native_tls::native_tls::TlsAcceptor::builder(identity).build()?,
87            )))
88        } else {
89            None
90        };
91
92        loop {
93            let (stream, _) = listener.accept().await?;
94
95            let io: Box<dyn Io + Send> = if let Some(ref acceptor) = tls_acceptor {
96                Box::new(acceptor.accept(stream).await?)
97            } else {
98                Box::new(stream)
99            };
100
101            let io = TokioIo::new(io);
102
103            let router = router.clone();
104            let error_handler = self.error_handler.clone();
105
106            let service = service_fn(move |req: Request<Incoming>| {
107                let router = router.clone();
108
109                let error_handler = error_handler.clone();
110
111                async move {
112                    let mut req_body = RequestBody::new(req);
113                    let mut res_body = ResponseWriter::new();
114
115                    let method = match *req_body.method() {
116                        hyper::Method::GET => Method::GET,
117                        hyper::Method::POST => Method::POST,
118                        hyper::Method::PUT => Method::PUT,
119                        hyper::Method::PATCH => Method::PATCH,
120                        hyper::Method::DELETE => Method::DELETE,
121                        hyper::Method::OPTIONS => Method::OPTIONS,
122                        hyper::Method::HEAD => Method::HEAD,
123                        hyper::Method::TRACE => Method::TRACE,
124                        _ => {
125                            res_body.status(405);
126                            res_body.send("Method Not Allowed");
127                            return Ok::<_, Infallible>(res_body.into_response());
128                        }
129                    };
130
131                    let path = req_body.path().to_string();
132
133                    let mws = router.collect_middleware(&path, method);
134
135                    for mw in mws {
136                        mw.run(&mut req_body, &mut res_body).await;
137
138                        if res_body.has_error() {
139                            break;
140                        }
141                    }
142
143                    if !res_body.has_error() {
144                        if let Some((handler, params)) = router.find(&path, method) {
145                            req_body.set_params(params);
146
147                            handler.handle(&mut req_body, &mut res_body).await;
148                        } else {
149                            let method_str = match *req_body.method() {
150                                hyper::Method::GET => "GET",
151                                hyper::Method::POST => "POST",
152                                hyper::Method::PUT => "PUT",
153                                hyper::Method::PATCH => "PATCH",
154                                hyper::Method::DELETE => "DELETE",
155                                hyper::Method::OPTIONS => "OPTIONS",
156                                hyper::Method::HEAD => "HEAD",
157                                hyper::Method::TRACE => "TRACE",
158                                _ => "UNKNOWN",
159                            };
160
161                            res_body.error(404, &format!("Not Found {} {}", method_str, path));
162                        }
163                    }
164
165                    if res_body.has_error() {
166                        let msg = res_body.body.clone();
167                        error_handler.run(msg, &mut res_body).await;
168                    }
169
170                    req_body.cleanup().await;
171
172                    if req_body.log {
173                        println!(
174                            "[LOG] method={} path={} status={}",
175                            req_body.method(),
176                            path,
177                            res_body.status,
178                        );
179                    }
180
181                    res_body.strip_header("X-Internal-Request-Start");
182
183                    Ok::<_, Infallible>(res_body.into_response())
184                }
185            });
186
187            match mode {
188                Mode::Http1 => {
189                    tokio::task::spawn(async move {
190                        if let Err(err) = http1::Builder::new().serve_connection(io, service).await
191                        {
192                            eprintln!("Error serving connection: {}", err);
193                        }
194                    });
195                }
196
197                Mode::Http2 => {
198                    tokio::task::spawn(async move {
199                        if let Err(err) = http2::Builder::new(TokioExecutor::new())
200                            .serve_connection(io, service)
201                            .await
202                        {
203                            eprintln!("Error serving connection: {}", err);
204                        }
205                    });
206                }
207            }
208        }
209    }
210
211    fn add_route<H>(&mut self, method: Method, path: &str, handler: H)
212    where
213        H: Handler + 'static,
214    {
215        self.router.insert(path, method, handler);
216    }
217
218    pub fn get<H>(&mut self, path: &str, handler: H)
219    where
220        H: Handler + 'static,
221    {
222        self.add_route(Method::GET, path, handler);
223    }
224
225    pub fn post<H>(&mut self, path: &str, handler: H)
226    where
227        H: Handler + 'static,
228    {
229        self.add_route(Method::POST, path, handler);
230    }
231
232    pub fn put<H>(&mut self, path: &str, handler: H)
233    where
234        H: Handler + 'static,
235    {
236        self.add_route(Method::PUT, path, handler);
237    }
238
239    pub fn patch<H>(&mut self, path: &str, handler: H)
240    where
241        H: Handler + 'static,
242    {
243        self.add_route(Method::PATCH, path, handler);
244    }
245
246    pub fn delete<H>(&mut self, path: &str, handler: H)
247    where
248        H: Handler + 'static,
249    {
250        self.add_route(Method::DELETE, path, handler);
251    }
252
253    pub fn group<'a>(&'a mut self, path: &str) -> Group<'a> {
254        Group {
255            prefix: path.to_string(),
256            app: self,
257        }
258    }
259
260    pub fn middleware<M>(&mut self, path: &str, method: Option<Method>, middleware_fn: M)
261    where
262        M: Middleware + 'static,
263    {
264        let mw = Arc::new(middleware_fn);
265        let full_path = path.to_string();
266
267        match method {
268            Some(m) => self.router.insert_middleware(&full_path, m, mw),
269            None => {
270                for m in [
271                    Method::GET,
272                    Method::POST,
273                    Method::PUT,
274                    Method::PATCH,
275                    Method::DELETE,
276                    Method::OPTIONS,
277                    Method::HEAD,
278                    Method::TRACE,
279                ] {
280                    self.router.insert_middleware(&full_path, m, mw.clone());
281                }
282            }
283        }
284    }
285
286    pub fn set_error_handler(&mut self, handler: Arc<dyn ErrorHandler>) {
287        self.error_handler = handler;
288    }
289}