bolt_web/
lib.rs

1use std::{convert::Infallible, net::SocketAddr, sync::Arc};
2
3use hyper::{
4    Request,
5    body::Incoming,
6    server::conn::{http1, http2},
7    service::service_fn,
8};
9use hyper_util::rt::{TokioExecutor, TokioIo};
10
11use tokio::{
12    io::{AsyncRead, AsyncWrite},
13    net::TcpListener,
14};
15
16use crate::{
17    client::Client,
18    group::Group,
19    middleware::error::DefaultErrorHandler,
20    request::RequestBody,
21    response::ResponseWriter,
22    router::Router,
23    types::{BoltError, ErrorHandler, Handler, Method, Middleware, Mode},
24};
25
26pub mod client;
27mod group;
28pub mod macros;
29pub mod middleware;
30pub mod request;
31pub mod response;
32mod router;
33pub mod types;
34pub use bolt_web_macro::main;
35pub use num_cpus;
36pub use paste;
37pub use tokio;
38
39trait Io: AsyncRead + AsyncWrite + Unpin {}
40impl<T: AsyncRead + AsyncWrite + Unpin> Io for T {}
41
42#[allow(dead_code)]
43pub struct Bolt {
44    router: Router,
45    error_handler: Arc<dyn ErrorHandler>,
46    client: Client,
47}
48
49#[allow(unused_variables)]
50#[allow(dead_code)]
51impl Bolt {
52    pub fn new() -> Self {
53        Self {
54            router: Router::new(),
55            error_handler: Arc::new(DefaultErrorHandler),
56            client: Client::new(),
57        }
58    }
59
60    pub async fn run(
61        &self,
62        addr: &str,
63        mode: Mode,
64        tls: Option<(&str, &str)>,
65    ) -> Result<(), BoltError> {
66        println!("⚡ A high performance & minimalist web framework in rust.");
67        println!(
68            r#"
69    __          ____
70   / /_  ____  / / /_
71  / __ \/ __ \/ / __/
72 / /_/ / /_/ / / /_  
73/_.___/\____/_/\__/  v0.2
74"#
75        );
76
77        println!(
78            ">> Server running on {}://{}",
79            if tls.is_some() { "https" } else { "http" },
80            addr
81        );
82
83        let addr: SocketAddr = addr.parse().unwrap();
84
85        let listener = TcpListener::bind(addr).await?;
86        let router = Arc::new(self.router.clone());
87
88        let tls_acceptor = if let Some((pkcs12_path, password)) = tls {
89            let pkcs12 = std::fs::read(pkcs12_path)?;
90            let identity = tokio_native_tls::native_tls::Identity::from_pkcs12(&pkcs12, password)?;
91            Some(Arc::new(tokio_native_tls::TlsAcceptor::from(
92                tokio_native_tls::native_tls::TlsAcceptor::builder(identity).build()?,
93            )))
94        } else {
95            None
96        };
97
98        loop {
99            let (stream, _) = listener.accept().await?;
100
101            let io: Box<dyn Io + Send> = if let Some(ref acceptor) = tls_acceptor {
102                Box::new(acceptor.accept(stream).await?)
103            } else {
104                Box::new(stream)
105            };
106
107            let io = TokioIo::new(io);
108
109            let router = router.clone();
110            let error_handler = self.error_handler.clone();
111
112            let service = service_fn(move |req: Request<Incoming>| {
113                let router = router.clone();
114
115                let error_handler = error_handler.clone();
116
117                async move {
118                    let mut req_body = RequestBody::new(req);
119                    let mut res_body = ResponseWriter::new();
120
121                    let method = match *req_body.method() {
122                        hyper::Method::GET => Method::GET,
123                        hyper::Method::POST => Method::POST,
124                        hyper::Method::PUT => Method::PUT,
125                        hyper::Method::PATCH => Method::PATCH,
126                        hyper::Method::DELETE => Method::DELETE,
127                        hyper::Method::OPTIONS => Method::OPTIONS,
128                        hyper::Method::HEAD => Method::HEAD,
129                        hyper::Method::TRACE => Method::TRACE,
130                        _ => {
131                            res_body.status(405);
132                            res_body.send("Method Not Allowed");
133                            return Ok::<_, Infallible>(res_body.into_response());
134                        }
135                    };
136
137                    let path = req_body.path().to_string();
138
139                    let mws = router.collect_middleware(&path, method);
140
141                    for mw in mws {
142                        mw.run(&mut req_body, &mut res_body).await;
143
144                        if res_body.has_error() {
145                            break;
146                        }
147                    }
148
149                    if !res_body.has_error() {
150                        if let Some((handler, params)) = router.find(&path, method) {
151                            req_body.set_params(params);
152
153                            handler.handle(&mut req_body, &mut res_body).await;
154                        } else {
155                            let method_str = match *req_body.method() {
156                                hyper::Method::GET => "GET",
157                                hyper::Method::POST => "POST",
158                                hyper::Method::PUT => "PUT",
159                                hyper::Method::PATCH => "PATCH",
160                                hyper::Method::DELETE => "DELETE",
161                                hyper::Method::OPTIONS => "OPTIONS",
162                                hyper::Method::HEAD => "HEAD",
163                                hyper::Method::TRACE => "TRACE",
164                                _ => "UNKNOWN",
165                            };
166
167                            res_body.error(404, &format!("Not Found {} {}", method_str, path));
168                        }
169                    }
170
171                    if res_body.has_error() {
172                        let msg = res_body.body.clone();
173                        error_handler.run(msg, &mut res_body).await;
174                    }
175
176                    req_body.cleanup().await;
177
178                    if req_body.log {
179                        println!(
180                            "[LOG] method={} path={} status={}",
181                            req_body.method(),
182                            path,
183                            res_body.status,
184                        );
185                    }
186
187                    res_body.strip_header("X-Internal-Request-Start");
188
189                    Ok::<_, Infallible>(res_body.into_response())
190                }
191            });
192
193            match mode {
194                Mode::Http1 => {
195                    tokio::task::spawn(async move {
196                        if let Err(err) = http1::Builder::new().serve_connection(io, service).await
197                        {
198                            eprintln!("Error serving connection: {}", err);
199                        }
200                    });
201                }
202
203                Mode::Http2 => {
204                    tokio::task::spawn(async move {
205                        if let Err(err) = http2::Builder::new(TokioExecutor::new())
206                            .serve_connection(io, service)
207                            .await
208                        {
209                            eprintln!("Error serving connection: {}", err);
210                        }
211                    });
212                }
213            }
214        }
215    }
216
217    fn add_route<H>(&mut self, method: Method, path: &str, handler: H)
218    where
219        H: Handler + 'static,
220    {
221        self.router.insert(path, method, handler);
222    }
223
224    pub fn get<H>(&mut self, path: &str, handler: H)
225    where
226        H: Handler + 'static,
227    {
228        self.add_route(Method::GET, path, handler);
229    }
230
231    pub fn post<H>(&mut self, path: &str, handler: H)
232    where
233        H: Handler + 'static,
234    {
235        self.add_route(Method::POST, path, handler);
236    }
237
238    pub fn put<H>(&mut self, path: &str, handler: H)
239    where
240        H: Handler + 'static,
241    {
242        self.add_route(Method::PUT, path, handler);
243    }
244
245    pub fn patch<H>(&mut self, path: &str, handler: H)
246    where
247        H: Handler + 'static,
248    {
249        self.add_route(Method::PATCH, path, handler);
250    }
251
252    pub fn delete<H>(&mut self, path: &str, handler: H)
253    where
254        H: Handler + 'static,
255    {
256        self.add_route(Method::DELETE, path, handler);
257    }
258
259    pub fn group<'a>(&'a mut self, path: &str) -> Group<'a> {
260        Group {
261            prefix: path.to_string(),
262            app: self,
263        }
264    }
265
266    pub fn middleware<M>(&mut self, path: &str, method: Option<Method>, middleware_fn: M)
267    where
268        M: Middleware + 'static,
269    {
270        let mw = Arc::new(middleware_fn);
271        let full_path = path.to_string();
272
273        match method {
274            Some(m) => self.router.insert_middleware(&full_path, m, mw),
275            None => {
276                for m in [
277                    Method::GET,
278                    Method::POST,
279                    Method::PUT,
280                    Method::PATCH,
281                    Method::DELETE,
282                    Method::OPTIONS,
283                    Method::HEAD,
284                    Method::TRACE,
285                ] {
286                    self.router.insert_middleware(&full_path, m, mw.clone());
287                }
288            }
289        }
290    }
291
292    pub fn set_error_handler(&mut self, handler: Arc<dyn ErrorHandler>) {
293        self.error_handler = handler;
294    }
295}