bolt_web/
lib.rs

1use std::{convert::Infallible, net::SocketAddr, sync::Arc};
2
3use hyper::{
4    Request,
5    body::Incoming,
6    server::conn::{http1, http2},
7    service::service_fn,
8};
9use hyper_util::rt::{TokioExecutor, TokioIo};
10
11use tokio::{
12    io::{AsyncRead, AsyncWrite},
13    net::TcpListener,
14};
15
16use crate::{
17    client::Client,
18    group::Group,
19    middleware::error::DefaultErrorHandler,
20    request::RequestBody,
21    response::ResponseWriter,
22    router::Router,
23    types::{BoltResult, ErrorHandler, Handler, Method, Middleware, Mode},
24};
25
26pub mod client;
27mod group;
28pub mod macros;
29pub mod middleware;
30pub mod request;
31pub mod response;
32mod router;
33pub mod types;
34pub use async_trait;
35pub use bolt_web_macro::main;
36pub use paste;
37pub use tokio;
38
39trait Io: AsyncRead + AsyncWrite + Unpin {}
40impl<T: AsyncRead + AsyncWrite + Unpin> Io for T {}
41
/// The application object: holds the route table, the error handler, and a
/// client instance.
///
/// Routes and middleware are registered through the methods on `Bolt`, and
/// `Bolt::run` serves them.
#[allow(dead_code)]
pub struct Bolt {
    /// Routes and middleware registered on this application.
    router: Router,
    /// Handler invoked by `run` whenever a response has been flagged as an error.
    error_handler: Arc<dyn ErrorHandler>,
    /// Client created in `new`.
    // NOTE(review): not read anywhere in this file — presumably used from
    // another module; confirm before removing.
    client: Client,
}
48
49#[allow(unused_variables)]
50#[allow(dead_code)]
51impl Bolt {
52    pub fn new() -> Self {
53        Self {
54            router: Router::new(),
55            error_handler: Arc::new(DefaultErrorHandler),
56            client: Client::new(),
57        }
58    }
59
60    pub async fn run(&self, addr: &str, mode: Mode, tls: Option<(&str, &str)>) -> BoltResult<()> {
61        println!("⚡ A high performance & minimalist web framework in rust.");
62        println!(
63            r#"
64    __          ____
65   / /_  ____  / / /_
66  / __ \/ __ \/ / __/
67 / /_/ / /_/ / / /_  
68/_.___/\____/_/\__/  v0.2
69"#
70        );
71
72        println!(
73            ">> Server running on {}://{}",
74            if tls.is_some() { "https" } else { "http" },
75            addr
76        );
77
78        let addr: SocketAddr = addr.parse().unwrap();
79
80        let listener = TcpListener::bind(addr).await?;
81        let router = Arc::new(self.router.clone());
82
83        let tls_acceptor = if let Some((pkcs12_path, password)) = tls {
84            let pkcs12 = std::fs::read(pkcs12_path)?;
85            let identity = tokio_native_tls::native_tls::Identity::from_pkcs12(&pkcs12, password)?;
86            Some(Arc::new(tokio_native_tls::TlsAcceptor::from(
87                tokio_native_tls::native_tls::TlsAcceptor::builder(identity).build()?,
88            )))
89        } else {
90            None
91        };
92
93        loop {
94            let (stream, _) = listener.accept().await?;
95
96            let io: Box<dyn Io + Send> = if let Some(ref acceptor) = tls_acceptor {
97                Box::new(acceptor.accept(stream).await?)
98            } else {
99                Box::new(stream)
100            };
101
102            let io = TokioIo::new(io);
103
104            let router = router.clone();
105            let error_handler = self.error_handler.clone();
106
107            let service = service_fn(move |req: Request<Incoming>| {
108                let router = router.clone();
109
110                let error_handler = error_handler.clone();
111
112                async move {
113                    let mut req_body = RequestBody::new(req);
114                    let mut res_body = ResponseWriter::new();
115
116                    let method = match *req_body.method() {
117                        hyper::Method::GET => Method::GET,
118                        hyper::Method::POST => Method::POST,
119                        hyper::Method::PUT => Method::PUT,
120                        hyper::Method::PATCH => Method::PATCH,
121                        hyper::Method::DELETE => Method::DELETE,
122                        hyper::Method::OPTIONS => Method::OPTIONS,
123                        hyper::Method::HEAD => Method::HEAD,
124                        hyper::Method::TRACE => Method::TRACE,
125                        _ => {
126                            res_body.status(405);
127                            res_body.send("Method Not Allowed");
128                            return Ok::<_, Infallible>(res_body.into_response());
129                        }
130                    };
131
132                    let path = req_body.path().to_string();
133
134                    let mws = router.collect_middleware(&path, method);
135
136                    for mw in mws {
137                        mw.run(&mut req_body, &mut res_body).await;
138
139                        if res_body.has_error() {
140                            break;
141                        }
142                    }
143
144                    if !res_body.has_error() {
145                        if let Some((handler, params)) = router.find(&path, method) {
146                            req_body.set_params(params);
147
148                            handler.handle(&mut req_body, &mut res_body).await;
149                        } else {
150                            let method_str = match *req_body.method() {
151                                hyper::Method::GET => "GET",
152                                hyper::Method::POST => "POST",
153                                hyper::Method::PUT => "PUT",
154                                hyper::Method::PATCH => "PATCH",
155                                hyper::Method::DELETE => "DELETE",
156                                hyper::Method::OPTIONS => "OPTIONS",
157                                hyper::Method::HEAD => "HEAD",
158                                hyper::Method::TRACE => "TRACE",
159                                _ => "UNKNOWN",
160                            };
161
162                            res_body.error(404, &format!("Not Found {} {}", method_str, path));
163                        }
164                    }
165
166                    if res_body.has_error() {
167                        let msg = res_body.body.clone();
168                        error_handler.run(msg, &mut res_body).await;
169                    }
170
171                    req_body.cleanup().await;
172
173                    if req_body.log {
174                        println!(
175                            "[LOG] method={} path={} status={}",
176                            req_body.method(),
177                            path,
178                            res_body.status,
179                        );
180                    }
181
182                    res_body.strip_header("X-Internal-Request-Start");
183
184                    Ok::<_, Infallible>(res_body.into_response())
185                }
186            });
187
188            match mode {
189                Mode::Http1 => {
190                    tokio::task::spawn(async move {
191                        if let Err(err) = http1::Builder::new().serve_connection(io, service).await
192                        {
193                            eprintln!("Error serving connection: {}", err);
194                        }
195                    });
196                }
197
198                Mode::Http2 => {
199                    tokio::task::spawn(async move {
200                        if let Err(err) = http2::Builder::new(TokioExecutor::new())
201                            .serve_connection(io, service)
202                            .await
203                        {
204                            eprintln!("Error serving connection: {}", err);
205                        }
206                    });
207                }
208            }
209        }
210    }
211
212    fn add_route<H>(&mut self, method: Method, path: &str, handler: H)
213    where
214        H: Handler + 'static,
215    {
216        self.router.insert(path, method, handler);
217    }
218
219    pub fn get<H>(&mut self, path: &str, handler: H)
220    where
221        H: Handler + 'static,
222    {
223        self.add_route(Method::GET, path, handler);
224    }
225
226    pub fn post<H>(&mut self, path: &str, handler: H)
227    where
228        H: Handler + 'static,
229    {
230        self.add_route(Method::POST, path, handler);
231    }
232
233    pub fn put<H>(&mut self, path: &str, handler: H)
234    where
235        H: Handler + 'static,
236    {
237        self.add_route(Method::PUT, path, handler);
238    }
239
240    pub fn patch<H>(&mut self, path: &str, handler: H)
241    where
242        H: Handler + 'static,
243    {
244        self.add_route(Method::PATCH, path, handler);
245    }
246
247    pub fn delete<H>(&mut self, path: &str, handler: H)
248    where
249        H: Handler + 'static,
250    {
251        self.add_route(Method::DELETE, path, handler);
252    }
253
254    pub fn group<'a>(&'a mut self, path: &str) -> Group<'a> {
255        Group {
256            prefix: path.to_string(),
257            app: self,
258        }
259    }
260
261    pub fn middleware<M>(&mut self, path: &str, method: Option<Method>, middleware_fn: M)
262    where
263        M: Middleware + 'static,
264    {
265        let mw = Arc::new(middleware_fn);
266        let full_path = path.to_string();
267
268        match method {
269            Some(m) => self.router.insert_middleware(&full_path, m, mw),
270            None => {
271                for m in [
272                    Method::GET,
273                    Method::POST,
274                    Method::PUT,
275                    Method::PATCH,
276                    Method::DELETE,
277                    Method::OPTIONS,
278                    Method::HEAD,
279                    Method::TRACE,
280                ] {
281                    self.router.insert_middleware(&full_path, m, mw.clone());
282                }
283            }
284        }
285    }
286
287    pub fn set_error_handler<E>(&mut self, handler: E)
288    where
289        E: ErrorHandler + 'static,
290    {
291        self.error_handler = Arc::new(handler);
292    }
293}