bolt_web/
lib.rs

1use futures_util::FutureExt;
2
3use std::{
4    convert::Infallible, net::SocketAddr, panic::AssertUnwindSafe, pin::Pin, sync::Arc,
5    time::Duration,
6};
7
8use hyper::{
9    Request,
10    body::Incoming,
11    server::conn::{http1, http2},
12    service::service_fn,
13};
14use hyper_util::rt::{TokioExecutor, TokioIo};
15
16use tokio::{
17    io::{AsyncRead, AsyncWrite},
18    net::TcpListener,
19    sync::Semaphore,
20};
21use tokio_rustls::TlsAcceptor;
22
23use crate::{
24    client::Client,
25    error::DefaultErrorHandler,
26    group::Group,
27    headers::LimitReader,
28    http::StatusCode,
29    request::RequestBody,
30    response::ResponseWriter,
31    router::Router,
32    tls::tls_config,
33    types::{BoltError, ErrorHandler, Handler, Method, Middleware, Mode},
34};
35
36pub mod client;
37mod error;
38mod group;
39mod headers;
40pub mod http;
41pub mod macros;
42pub mod request;
43pub mod response;
44mod router;
45mod tls;
46pub mod types;
47pub use async_trait;
48pub use bolt_web_macro::main;
49pub use paste;
50pub use tokio;
51
52trait Io: AsyncRead + AsyncWrite + Unpin {}
53impl<T: AsyncRead + AsyncWrite + Unpin> Io for T {}
54
55#[allow(dead_code)]
56pub struct App {
57    router: Router,
58    error_handler: Arc<dyn ErrorHandler>,
59    client: Client,
60    timeout: u64,
61    connection_limit: u64,
62    header_limit: usize,
63}
64
65#[allow(unused_variables)]
66#[allow(dead_code)]
67impl App {
68    pub fn new() -> Self {
69        Self {
70            router: Router::new(),
71            error_handler: Arc::new(DefaultErrorHandler),
72            client: Client::new(),
73            timeout: 30,
74            connection_limit: 100,
75            header_limit: 32 * 1024,
76        }
77    }
78
79    pub fn set_timeout(&mut self, seconds: u64) {
80        self.timeout = seconds;
81    }
82
83    pub fn set_connection_limit(&mut self, limit: u64) {
84        self.connection_limit = limit;
85    }
86
87    pub fn set_header_limit(&mut self, bytes: usize) {
88        self.header_limit = bytes;
89    }
90
91    fn add_route<H>(&mut self, method: Method, path: &str, handler: H)
92    where
93        H: Handler + 'static,
94    {
95        self.router.insert(path, method, handler);
96    }
97
98    pub fn get<H>(&mut self, path: &str, handler: H)
99    where
100        H: Handler + 'static,
101    {
102        self.add_route(Method::GET, path, handler);
103    }
104
105    pub fn post<H>(&mut self, path: &str, handler: H)
106    where
107        H: Handler + 'static,
108    {
109        self.add_route(Method::POST, path, handler);
110    }
111
112    pub fn put<H>(&mut self, path: &str, handler: H)
113    where
114        H: Handler + 'static,
115    {
116        self.add_route(Method::PUT, path, handler);
117    }
118
119    pub fn patch<H>(&mut self, path: &str, handler: H)
120    where
121        H: Handler + 'static,
122    {
123        self.add_route(Method::PATCH, path, handler);
124    }
125
126    pub fn delete<H>(&mut self, path: &str, handler: H)
127    where
128        H: Handler + 'static,
129    {
130        self.add_route(Method::DELETE, path, handler);
131    }
132
133    pub fn group<'a>(&'a mut self, path: &str) -> Group<'a> {
134        Group {
135            prefix: path.to_string(),
136            app: self,
137        }
138    }
139
140    pub fn middleware<M>(&mut self, path: &str, method: Option<Method>, middleware_fn: M)
141    where
142        M: Middleware + 'static,
143    {
144        let mw: Arc<M> = Arc::new(middleware_fn);
145        let full_path = path.to_string();
146
147        match method {
148            Some(m) => self.router.insert_middleware(&full_path, m, mw),
149            None => {
150                for m in [
151                    Method::GET,
152                    Method::POST,
153                    Method::PUT,
154                    Method::PATCH,
155                    Method::DELETE,
156                    Method::OPTIONS,
157                    Method::HEAD,
158                    Method::TRACE,
159                ] {
160                    self.router.insert_middleware(&full_path, m, mw.clone());
161                }
162            }
163        }
164    }
165
166    pub fn set_error_handler<E>(&mut self, handler: E)
167    where
168        E: ErrorHandler + 'static,
169    {
170        self.error_handler = Arc::new(handler);
171    }
172
173    pub async fn run(&self, addr: &str, mode: Mode) -> Result<(), BoltError> {
174        println!("⚡ A high performance & minimalist web framework in rust.");
175        println!(
176            r#"
177    __          ____
178   / /_  ____  / / /_
179  / __ \/ __ \/ / __/
180 / /_/ / /_/ / / /_  
181/_.___/\____/_/\__/  v0.3
182"#
183        );
184
185        println!(">> Server running on http://{}", addr);
186
187        let addr: SocketAddr = addr.parse().unwrap();
188
189        let listener = TcpListener::bind(addr).await?;
190        let router = Arc::new(self.router.clone());
191        let error_handler = self.error_handler.clone();
192        let active = Arc::new(Semaphore::new(self.connection_limit as usize));
193
194        self.server_loop(
195            router,
196            error_handler,
197            listener,
198            mode,
199            None,
200            self.timeout,
201            Box::pin(tokio::signal::ctrl_c().map(|_| ())),
202            active,
203        )
204        .await
205    }
206
207    pub async fn run_tls(
208        &self,
209        addr: &str,
210        mode: Mode,
211        tls: Option<(&str, &str)>,
212    ) -> Result<(), BoltError> {
213        println!("⚡ A high performance & minimalist web framework in rust.");
214        println!(
215            "{}",
216            r#"
217    __          ____
218   / /_  ____  / / /_
219  / __ \/ __ \/ / __/
220 / /_/ / /_/ / / /_  
221/_.___/\____/_/\__/  v0.3
222"#
223        );
224
225        let addr: SocketAddr = addr.parse().unwrap();
226        let listener = TcpListener::bind(addr).await?;
227
228        let tls_acceptor: Option<Arc<TlsAcceptor>> = if let Some((cert, key)) = tls {
229            let cfg = tls_config(cert, key)?;
230            Some(Arc::new(TlsAcceptor::from(cfg)))
231        } else {
232            None
233        };
234
235        println!(
236            ">> Server running on {}://{}",
237            if tls_acceptor.is_some() {
238                "https"
239            } else {
240                "http"
241            },
242            addr
243        );
244
245        let router: Arc<Router> = Arc::new(self.router.clone());
246        let error_handler = self.error_handler.clone();
247        let active = Arc::new(Semaphore::new(self.connection_limit as usize));
248
249        self.server_loop(
250            router,
251            error_handler,
252            listener,
253            mode,
254            tls_acceptor,
255            self.timeout,
256            Box::pin(tokio::signal::ctrl_c().map(|_| ())),
257            active,
258        )
259        .await
260    }
261
262    async fn server_loop(
263        &self,
264        router: Arc<Router>,
265        error_handler: Arc<dyn ErrorHandler>,
266        listener: TcpListener,
267        mode: Mode,
268        tls_acceptor: Option<Arc<TlsAcceptor>>,
269        timeout: u64,
270        mut shutdown: Pin<Box<dyn Future<Output = ()> + Send>>,
271        active: Arc<Semaphore>,
272    ) -> Result<(), BoltError> {
273        loop {
274            tokio::select! {
275                _ = &mut shutdown => {
276                    println!(">> Shutdown signal received. Stopping server...");
277                    break;
278                }
279
280                accept_res = listener.accept() => {
281                    let (stream, remote_addr) = match accept_res {
282                        Ok(v) => v,
283                        Err(e) => {
284                            eprintln!("Accept error: {}", e);
285                            continue;
286                        }
287                    };
288
289                    let permit = match active.clone().try_acquire_owned() {
290                        Ok(p) => p,
291                        Err(_) => {
292                            eprintln!("Connection limit reached — dropping client");
293                            continue;
294                        }
295                    };
296
297                    let io: Box<dyn Io + Send> = if let Some(ref acceptor) = tls_acceptor {
298                        match acceptor.accept(stream).await {
299                            Ok(c) => Box::new(c),
300                            Err(e) => {
301                                eprintln!("TLS error: {}", e);
302                                continue;
303                            }
304                        }
305                    } else {
306                        Box::new(stream)
307                    };
308
309                    let limited = LimitReader::new(io, self.header_limit);
310                    let io = TokioIo::new(limited);
311
312                    let router = router.clone();
313                    let error_handler = error_handler.clone();
314
315                    let service = service_fn(move |req: Request<Incoming>| {
316                        let router = router.clone();
317                        let error_handler = error_handler.clone();
318                        let remote_addr = remote_addr.clone();
319                        let timeout = timeout;
320
321                        async move {
322                            let handler_future = tokio::time::timeout(
323                                Duration::from_secs(timeout),
324                                async {
325                                    let inner = AssertUnwindSafe(async move {
326                                        let mut req_body = RequestBody::new(req, remote_addr);
327                                        let mut res_body = ResponseWriter::new();
328
329                                        let method = match *req_body.method() {
330                                            hyper::Method::GET => Method::GET,
331                                            hyper::Method::POST => Method::POST,
332                                            hyper::Method::PUT => Method::PUT,
333                                            hyper::Method::PATCH => Method::PATCH,
334                                            hyper::Method::DELETE => Method::DELETE,
335                                            hyper::Method::OPTIONS => Method::OPTIONS,
336                                            hyper::Method::HEAD => Method::HEAD,
337                                            hyper::Method::TRACE => Method::TRACE,
338                                            _ => {
339                                                res_body.status(StatusCode::MethodNotAllowed)
340                                                        .send("Method Not Allowed");
341                                                return res_body;
342                                            }
343                                        };
344
345                                        let path = req_body.path().to_string();
346                                        for mw in router.collect_middleware(&path, method) {
347                                            mw.run(&mut req_body, &mut res_body).await;
348                                            if res_body.has_error() { break; }
349                                        }
350
351                                        if !res_body.has_error() {
352                                            if let Some((handler, params)) = router.find(&path, method) {
353                                                req_body.set_params(params);
354                                                handler.run(&mut req_body, &mut res_body).await;
355                                            } else {
356                                                res_body.error(
357                                                    StatusCode::NotFound,
358                                                    &format!("Not Found {} {}", req_body.method(), path),
359                                                );
360                                            }
361                                        }
362
363                                        if res_body.has_error() {
364                                            let msg = res_body.body.clone();
365                                            error_handler.run(msg, &mut res_body).await;
366                                        }
367
368                                        req_body.cleanup().await;
369                                        res_body
370                                    })
371                                    .catch_unwind()
372                                    .await;
373
374                                    match inner {
375                                        Ok(r) => r,
376                                        Err(_) => {
377                                            let mut res = ResponseWriter::new();
378                                            res.error(StatusCode::InternalServerError, "Internal Server Error");
379                                            res
380                                        }
381                                    }
382                                }
383                            ).await;
384
385                            let res_body = match handler_future {
386                                Ok(res) => res,
387                                Err(_) => {
388                                    let mut res = ResponseWriter::new();
389                                    res.error(StatusCode::RequestTimeout, "Request Timeout");
390                                    res
391                                }
392                            };
393
394                            Ok::<_, Infallible>(res_body.into_response())
395                        }
396                    });
397
398                    let permit = permit;
399
400                    match mode {
401                        Mode::Http1 => {
402                            tokio::spawn(async move {
403                                let _permit = permit;
404
405                                if let Err(e) = http1::Builder::new()
406                                    .serve_connection(io, service)
407                                    .await
408                                {
409                                    eprintln!("Connection error: {}", e);
410                                }
411                            });
412                        }
413
414                        Mode::Http2 => {
415                            tokio::spawn(async move {
416                                let _permit = permit;
417
418                            if let Err(e) = http2::Builder::new(TokioExecutor::new())
419                                .serve_connection(io, service)
420                                .await
421                            {
422                                eprintln!("Connection error: {}", e);
423                            }
424                            });
425                        }
426                    }
427                }
428            }
429        }
430
431        Ok(())
432    }
433}