bolt_web/
lib.rs

1use std::{convert::Infallible, net::SocketAddr, sync::Arc};
2
3use hyper::{
4    Request,
5    body::Incoming,
6    server::conn::{http1, http2},
7    service::service_fn,
8};
9use hyper_util::rt::{TokioExecutor, TokioIo};
10use tokio::{
11    io::{AsyncRead, AsyncWrite},
12    net::TcpListener,
13};
14
15use crate::{
16    client::Client,
17    group::Group,
18    middleware::error::DefaultErrorHandler,
19    request::RequestBody,
20    response::ResponseWriter,
21    router::Router,
22    types::{ErrorHandler, Handler, Method, Middleware, Mode},
23};
24
25pub mod client;
26mod group;
27pub mod macros;
28pub mod middleware;
29pub mod request;
30pub mod response;
31mod router;
32pub mod types;
33pub use async_trait;
34pub use paste;
35
36trait Io: AsyncRead + AsyncWrite + Unpin {}
37impl<T: AsyncRead + AsyncWrite + Unpin> Io for T {}
38
39#[allow(dead_code)]
40pub struct Bolt {
41    router: Router,
42    error_handler: Arc<dyn ErrorHandler>,
43    client: Client,
44}
45
46#[allow(dead_code)]
47impl Bolt {
48    pub fn new() -> Self {
49        Self {
50            router: Router::new(),
51            error_handler: Arc::new(DefaultErrorHandler),
52            client: Client::new(),
53        }
54    }
55
56    pub async fn run(
57        &self,
58        addr: &str,
59        mode: Mode,
60        tls: Option<(&str, &str)>,
61    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
62        println!("⚡ A high performance & minimalist web framework in rust.");
63        println!(
64            r#"
65    __          ____
66   / /_  ____  / / /_
67  / __ \/ __ \/ / __/
68 / /_/ / /_/ / / /_  
69/_.___/\____/_/\__/  v0.1.9
70"#
71        );
72
73        println!(
74            "=> Server running on {}://{}",
75            if tls.is_some() { "https" } else { "http" },
76            addr
77        );
78
79        let addr: SocketAddr = addr.parse().unwrap();
80
81        let listener = TcpListener::bind(addr).await?;
82        let router = Arc::new(self.router.clone());
83
84        let tls_acceptor = if let Some((pkcs12_path, password)) = tls {
85            let pkcs12 = std::fs::read(pkcs12_path)?;
86            let identity = tokio_native_tls::native_tls::Identity::from_pkcs12(&pkcs12, password)?;
87            Some(Arc::new(tokio_native_tls::TlsAcceptor::from(
88                tokio_native_tls::native_tls::TlsAcceptor::builder(identity).build()?,
89            )))
90        } else {
91            None
92        };
93
94        loop {
95            let (stream, _) = listener.accept().await?;
96
97            let io: Box<dyn Io + Send> = if let Some(ref acceptor) = tls_acceptor {
98                Box::new(acceptor.accept(stream).await?)
99            } else {
100                Box::new(stream)
101            };
102
103            let io = TokioIo::new(io);
104
105            let router = router.clone();
106            let error_handler = self.error_handler.clone();
107
108            let service = service_fn(move |req: Request<Incoming>| {
109                let router = router.clone();
110
111                let error_handler = error_handler.clone();
112
113                async move {
114                    let mut req_body = RequestBody::new(req);
115                    let mut res_body = ResponseWriter::new();
116
117                    let method = match *req_body.method() {
118                        hyper::Method::GET => Method::GET,
119                        hyper::Method::POST => Method::POST,
120                        hyper::Method::PUT => Method::PUT,
121                        hyper::Method::PATCH => Method::PATCH,
122                        hyper::Method::DELETE => Method::DELETE,
123                        hyper::Method::OPTIONS => Method::OPTIONS,
124                        hyper::Method::HEAD => Method::HEAD,
125                        hyper::Method::TRACE => Method::TRACE,
126                        _ => {
127                            res_body.status(405);
128                            res_body.send("Method Not Allowed");
129                            return Ok::<_, Infallible>(res_body.into_response());
130                        }
131                    };
132
133                    let path = req_body.path().to_string();
134
135                    let mws = router.collect_middleware(&path, method);
136
137                    for mw in mws {
138                        mw.run(&mut req_body, &mut res_body).await;
139
140                        if res_body.has_error() {
141                            break;
142                        }
143                    }
144
145                    if !res_body.has_error() {
146                        if let Some((handler, params)) = router.find(&path, method) {
147                            req_body.set_params(params);
148
149                            handler.handle(&mut req_body, &mut res_body).await;
150                        } else {
151                            let method_str = match *req_body.method() {
152                                hyper::Method::GET => "GET",
153                                hyper::Method::POST => "POST",
154                                hyper::Method::PUT => "PUT",
155                                hyper::Method::PATCH => "PATCH",
156                                hyper::Method::DELETE => "DELETE",
157                                hyper::Method::OPTIONS => "OPTIONS",
158                                hyper::Method::HEAD => "HEAD",
159                                hyper::Method::TRACE => "TRACE",
160                                _ => "UNKNOWN",
161                            };
162
163                            res_body.error(404, &format!("Not Found {} {}", method_str, path));
164                        }
165                    }
166
167                    if res_body.has_error() {
168                        let msg = res_body.body.clone();
169                        error_handler.run(msg, &mut res_body).await;
170                    }
171
172                    Ok::<_, Infallible>(res_body.into_response())
173                }
174            });
175
176            match mode {
177                Mode::Http1 => {
178                    tokio::task::spawn(async move {
179                        if let Err(err) = http1::Builder::new().serve_connection(io, service).await
180                        {
181                            eprintln!("Error serving connection: {}", err);
182                        }
183                    });
184                }
185
186                Mode::Http2 => {
187                    tokio::task::spawn(async move {
188                        if let Err(err) = http2::Builder::new(TokioExecutor::new())
189                            .serve_connection(io, service)
190                            .await
191                        {
192                            eprintln!("Error serving connection: {}", err);
193                        }
194                    });
195                }
196            }
197        }
198    }
199
200    fn add_route<H>(&mut self, method: Method, path: &str, handler: H)
201    where
202        H: Handler + 'static,
203    {
204        self.router.insert(path, method, handler);
205    }
206
207    pub fn get<H>(&mut self, path: &str, handler: H)
208    where
209        H: Handler + 'static,
210    {
211        self.add_route(Method::GET, path, handler);
212    }
213
214    pub fn post<H>(&mut self, path: &str, handler: H)
215    where
216        H: Handler + 'static,
217    {
218        self.add_route(Method::POST, path, handler);
219    }
220
221    pub fn put<H>(&mut self, path: &str, handler: H)
222    where
223        H: Handler + 'static,
224    {
225        self.add_route(Method::PUT, path, handler);
226    }
227
228    pub fn patch<H>(&mut self, path: &str, handler: H)
229    where
230        H: Handler + 'static,
231    {
232        self.add_route(Method::PATCH, path, handler);
233    }
234
235    pub fn delete<H>(&mut self, path: &str, handler: H)
236    where
237        H: Handler + 'static,
238    {
239        self.add_route(Method::DELETE, path, handler);
240    }
241
242    pub fn group(&mut self, path: &str) -> Group {
243        Group {
244            prefix: path.to_string(),
245            app: self,
246        }
247    }
248
249    pub fn middleware<M>(&mut self, path: &str, method: Option<Method>, middleware_fn: M)
250    where
251        M: Middleware + 'static,
252    {
253        let mw = Arc::new(middleware_fn);
254        let full_path = path.to_string();
255
256        match method {
257            Some(m) => self.router.insert_middleware(&full_path, m, mw),
258            None => {
259                for m in [
260                    Method::GET,
261                    Method::POST,
262                    Method::PUT,
263                    Method::PATCH,
264                    Method::DELETE,
265                    Method::OPTIONS,
266                    Method::HEAD,
267                    Method::TRACE,
268                ] {
269                    self.router.insert_middleware(&full_path, m, mw.clone());
270                }
271            }
272        }
273    }
274
275    pub fn set_error_handler(&mut self, handler: Arc<dyn ErrorHandler>) {
276        self.error_handler = handler;
277    }
278}