bolt_web/
lib.rs

1use std::{convert::Infallible, net::SocketAddr, sync::Arc};
2
3use hyper::{
4    Request,
5    body::Incoming,
6    server::conn::{http1, http2},
7    service::service_fn,
8};
9use hyper_util::rt::{TokioExecutor, TokioIo};
10
11use tokio::{
12    io::{AsyncRead, AsyncWrite},
13    net::TcpListener,
14};
15
16use crate::{
17    client::Client,
18    group::Group,
19    middleware::error::DefaultErrorHandler,
20    request::RequestBody,
21    response::ResponseWriter,
22    router::Router,
23    types::{BoltError, ErrorHandler, Handler, Method, Middleware, Mode},
24};
25
26pub mod client;
27mod group;
28pub mod macros;
29pub mod middleware;
30pub mod request;
31pub mod response;
32mod router;
33pub mod types;
34pub use bolt_web_macro::main;
35pub use paste;
36pub use tokio;
37
38trait Io: AsyncRead + AsyncWrite + Unpin {}
39impl<T: AsyncRead + AsyncWrite + Unpin> Io for T {}
40
41#[allow(dead_code)]
42pub struct Bolt {
43    router: Router,
44    error_handler: Arc<dyn ErrorHandler>,
45    client: Client,
46}
47
48#[allow(unused_variables)]
49#[allow(dead_code)]
50impl Bolt {
51    pub fn new() -> Self {
52        Self {
53            router: Router::new(),
54            error_handler: Arc::new(DefaultErrorHandler),
55            client: Client::new(),
56        }
57    }
58
59    pub async fn run(
60        &self,
61        addr: &str,
62        mode: Mode,
63        tls: Option<(&str, &str)>,
64    ) -> Result<(), BoltError> {
65        println!("⚡ A high performance & minimalist web framework in rust.");
66        println!(
67            r#"
68    __          ____
69   / /_  ____  / / /_
70  / __ \/ __ \/ / __/
71 / /_/ / /_/ / / /_  
72/_.___/\____/_/\__/  v0.2.0
73"#
74        );
75
76        println!(
77            ">> Server running on {}://{}",
78            if tls.is_some() { "https" } else { "http" },
79            addr
80        );
81
82        let addr: SocketAddr = addr.parse().unwrap();
83
84        let listener = TcpListener::bind(addr).await?;
85        let router = Arc::new(self.router.clone());
86
87        let tls_acceptor = if let Some((pkcs12_path, password)) = tls {
88            let pkcs12 = std::fs::read(pkcs12_path)?;
89            let identity = tokio_native_tls::native_tls::Identity::from_pkcs12(&pkcs12, password)?;
90            Some(Arc::new(tokio_native_tls::TlsAcceptor::from(
91                tokio_native_tls::native_tls::TlsAcceptor::builder(identity).build()?,
92            )))
93        } else {
94            None
95        };
96
97        loop {
98            let (stream, _) = listener.accept().await?;
99
100            let io: Box<dyn Io + Send> = if let Some(ref acceptor) = tls_acceptor {
101                Box::new(acceptor.accept(stream).await?)
102            } else {
103                Box::new(stream)
104            };
105
106            let io = TokioIo::new(io);
107
108            let router = router.clone();
109            let error_handler = self.error_handler.clone();
110
111            let service = service_fn(move |req: Request<Incoming>| {
112                let router = router.clone();
113
114                let error_handler = error_handler.clone();
115
116                async move {
117                    let mut req_body = RequestBody::new(req);
118                    let mut res_body = ResponseWriter::new();
119
120                    let method = match *req_body.method() {
121                        hyper::Method::GET => Method::GET,
122                        hyper::Method::POST => Method::POST,
123                        hyper::Method::PUT => Method::PUT,
124                        hyper::Method::PATCH => Method::PATCH,
125                        hyper::Method::DELETE => Method::DELETE,
126                        hyper::Method::OPTIONS => Method::OPTIONS,
127                        hyper::Method::HEAD => Method::HEAD,
128                        hyper::Method::TRACE => Method::TRACE,
129                        _ => {
130                            res_body.status(405);
131                            res_body.send("Method Not Allowed");
132                            return Ok::<_, Infallible>(res_body.into_response());
133                        }
134                    };
135
136                    let path = req_body.path().to_string();
137
138                    let mws = router.collect_middleware(&path, method);
139
140                    for mw in mws {
141                        mw.run(&mut req_body, &mut res_body).await;
142
143                        if res_body.has_error() {
144                            break;
145                        }
146                    }
147
148                    if !res_body.has_error() {
149                        if let Some((handler, params)) = router.find(&path, method) {
150                            req_body.set_params(params);
151
152                            handler.handle(&mut req_body, &mut res_body).await;
153                        } else {
154                            let method_str = match *req_body.method() {
155                                hyper::Method::GET => "GET",
156                                hyper::Method::POST => "POST",
157                                hyper::Method::PUT => "PUT",
158                                hyper::Method::PATCH => "PATCH",
159                                hyper::Method::DELETE => "DELETE",
160                                hyper::Method::OPTIONS => "OPTIONS",
161                                hyper::Method::HEAD => "HEAD",
162                                hyper::Method::TRACE => "TRACE",
163                                _ => "UNKNOWN",
164                            };
165
166                            res_body.error(404, &format!("Not Found {} {}", method_str, path));
167                        }
168                    }
169
170                    if res_body.has_error() {
171                        let msg = res_body.body.clone();
172                        error_handler.run(msg, &mut res_body).await;
173                    }
174
175                    req_body.cleanup().await;
176
177                    if req_body.log {
178                        println!(
179                            "[LOG] method={} path={} status={}",
180                            req_body.method(),
181                            path,
182                            res_body.status,
183                        );
184                    }
185
186                    res_body.strip_header("X-Internal-Request-Start");
187
188                    Ok::<_, Infallible>(res_body.into_response())
189                }
190            });
191
192            match mode {
193                Mode::Http1 => {
194                    tokio::task::spawn(async move {
195                        if let Err(err) = http1::Builder::new().serve_connection(io, service).await
196                        {
197                            eprintln!("Error serving connection: {}", err);
198                        }
199                    });
200                }
201
202                Mode::Http2 => {
203                    tokio::task::spawn(async move {
204                        if let Err(err) = http2::Builder::new(TokioExecutor::new())
205                            .serve_connection(io, service)
206                            .await
207                        {
208                            eprintln!("Error serving connection: {}", err);
209                        }
210                    });
211                }
212            }
213        }
214    }
215
216    fn add_route<H>(&mut self, method: Method, path: &str, handler: H)
217    where
218        H: Handler + 'static,
219    {
220        self.router.insert(path, method, handler);
221    }
222
223    pub fn get<H>(&mut self, path: &str, handler: H)
224    where
225        H: Handler + 'static,
226    {
227        self.add_route(Method::GET, path, handler);
228    }
229
230    pub fn post<H>(&mut self, path: &str, handler: H)
231    where
232        H: Handler + 'static,
233    {
234        self.add_route(Method::POST, path, handler);
235    }
236
237    pub fn put<H>(&mut self, path: &str, handler: H)
238    where
239        H: Handler + 'static,
240    {
241        self.add_route(Method::PUT, path, handler);
242    }
243
244    pub fn patch<H>(&mut self, path: &str, handler: H)
245    where
246        H: Handler + 'static,
247    {
248        self.add_route(Method::PATCH, path, handler);
249    }
250
251    pub fn delete<H>(&mut self, path: &str, handler: H)
252    where
253        H: Handler + 'static,
254    {
255        self.add_route(Method::DELETE, path, handler);
256    }
257
258    pub fn group<'a>(&'a mut self, path: &str) -> Group<'a> {
259        Group {
260            prefix: path.to_string(),
261            app: self,
262        }
263    }
264
265    pub fn middleware<M>(&mut self, path: &str, method: Option<Method>, middleware_fn: M)
266    where
267        M: Middleware + 'static,
268    {
269        let mw = Arc::new(middleware_fn);
270        let full_path = path.to_string();
271
272        match method {
273            Some(m) => self.router.insert_middleware(&full_path, m, mw),
274            None => {
275                for m in [
276                    Method::GET,
277                    Method::POST,
278                    Method::PUT,
279                    Method::PATCH,
280                    Method::DELETE,
281                    Method::OPTIONS,
282                    Method::HEAD,
283                    Method::TRACE,
284                ] {
285                    self.router.insert_middleware(&full_path, m, mw.clone());
286                }
287            }
288        }
289    }
290
291    pub fn set_error_handler(&mut self, handler: Arc<dyn ErrorHandler>) {
292        self.error_handler = handler;
293    }
294}