// bolt_web/lib.rs

1use futures_util::FutureExt;
2
3use std::{
4    convert::Infallible, net::SocketAddr, panic::AssertUnwindSafe, pin::Pin, sync::Arc,
5    time::Duration,
6};
7
8use hyper::{
9    Request,
10    body::Incoming,
11    server::conn::{http1, http2},
12    service::service_fn,
13};
14use hyper_util::rt::{TokioExecutor, TokioIo};
15
16use tokio::{
17    io::{AsyncRead, AsyncWrite},
18    net::TcpListener,
19    sync::Semaphore,
20};
21use tokio_rustls::TlsAcceptor;
22
23use crate::{
24    client::Client,
25    error::DefaultErrorHandler,
26    group::Group,
27    headers::LimitReader,
28    http::StatusCode,
29    request::RequestBody,
30    response::ResponseWriter,
31    router::Router,
32    tls::tls_config,
33    types::{BoltError, ErrorHandler, Handler, Method, Middleware, Mode},
34};
35
36pub mod client;
37mod error;
38mod group;
39mod headers;
40pub mod http;
41pub mod macros;
42pub mod request;
43pub mod response;
44mod router;
45mod tls;
46pub mod types;
47pub use async_trait;
48pub use bolt_web_macro::main;
49pub use paste;
50pub use tokio;
51
52trait Io: AsyncRead + AsyncWrite + Unpin {}
53impl<T: AsyncRead + AsyncWrite + Unpin> Io for T {}
54
55#[allow(dead_code)]
56pub struct App {
57    router: Router,
58    error_handler: Arc<dyn ErrorHandler>,
59    client: Client,
60    timeout: u64,
61    connection_limit: u64,
62    read_timeout: u64,
63    header_limit: usize,
64}
65
66#[allow(unused_variables)]
67#[allow(dead_code)]
68impl App {
69    pub fn new() -> Self {
70        Self {
71            router: Router::new(),
72            error_handler: Arc::new(DefaultErrorHandler),
73            client: Client::new(),
74            timeout: 30,
75            connection_limit: 100,
76            read_timeout: 10,
77            header_limit: 32 * 1024,
78        }
79    }
80
81    pub fn set_timeout(&mut self, seconds: u64) {
82        self.timeout = seconds;
83    }
84
85    pub fn set_connection_limit(&mut self, limit: u64) {
86        self.connection_limit = limit;
87    }
88
89    pub fn set_read_timeout(&mut self, seconds: u64) {
90        self.read_timeout = seconds;
91    }
92
93    pub fn set_header_limit(&mut self, bytes: usize) {
94        self.header_limit = bytes;
95    }
96
97    fn add_route<H>(&mut self, method: Method, path: &str, handler: H)
98    where
99        H: Handler + 'static,
100    {
101        self.router.insert(path, method, handler);
102    }
103
104    pub fn get<H>(&mut self, path: &str, handler: H)
105    where
106        H: Handler + 'static,
107    {
108        self.add_route(Method::GET, path, handler);
109    }
110
111    pub fn post<H>(&mut self, path: &str, handler: H)
112    where
113        H: Handler + 'static,
114    {
115        self.add_route(Method::POST, path, handler);
116    }
117
118    pub fn put<H>(&mut self, path: &str, handler: H)
119    where
120        H: Handler + 'static,
121    {
122        self.add_route(Method::PUT, path, handler);
123    }
124
125    pub fn patch<H>(&mut self, path: &str, handler: H)
126    where
127        H: Handler + 'static,
128    {
129        self.add_route(Method::PATCH, path, handler);
130    }
131
132    pub fn delete<H>(&mut self, path: &str, handler: H)
133    where
134        H: Handler + 'static,
135    {
136        self.add_route(Method::DELETE, path, handler);
137    }
138
139    pub fn group<'a>(&'a mut self, path: &str) -> Group<'a> {
140        Group {
141            prefix: path.to_string(),
142            app: self,
143        }
144    }
145
146    pub fn middleware<M>(&mut self, path: &str, method: Option<Method>, middleware_fn: M)
147    where
148        M: Middleware + 'static,
149    {
150        let mw: Arc<M> = Arc::new(middleware_fn);
151        let full_path = path.to_string();
152
153        match method {
154            Some(m) => self.router.insert_middleware(&full_path, m, mw),
155            None => {
156                for m in [
157                    Method::GET,
158                    Method::POST,
159                    Method::PUT,
160                    Method::PATCH,
161                    Method::DELETE,
162                    Method::OPTIONS,
163                    Method::HEAD,
164                    Method::TRACE,
165                ] {
166                    self.router.insert_middleware(&full_path, m, mw.clone());
167                }
168            }
169        }
170    }
171
172    pub fn set_error_handler<E>(&mut self, handler: E)
173    where
174        E: ErrorHandler + 'static,
175    {
176        self.error_handler = Arc::new(handler);
177    }
178
179    pub async fn run(&self, addr: &str, mode: Mode) -> Result<(), BoltError> {
180        println!("⚡ A high performance & minimalist web framework in rust.");
181        println!(
182            r#"
183    __          ____
184   / /_  ____  / / /_
185  / __ \/ __ \/ / __/
186 / /_/ / /_/ / / /_  
187/_.___/\____/_/\__/  v0.2.0
188"#
189        );
190
191        println!(">> Server running on http://{}", addr);
192
193        let addr: SocketAddr = addr.parse().unwrap();
194
195        let listener = TcpListener::bind(addr).await?;
196        let router = Arc::new(self.router.clone());
197        let error_handler = self.error_handler.clone();
198        let active = Arc::new(Semaphore::new(self.connection_limit as usize));
199
200        self.server_loop(
201            router,
202            error_handler,
203            listener,
204            mode,
205            None,
206            self.timeout,
207            self.read_timeout,
208            Box::pin(tokio::signal::ctrl_c().map(|_| ())),
209            active,
210        )
211        .await
212    }
213
214    pub async fn run_tls(
215        &self,
216        addr: &str,
217        mode: Mode,
218        tls: Option<(&str, &str)>,
219    ) -> Result<(), BoltError> {
220        println!("⚡ A high performance & minimalist web framework in rust.");
221        println!(
222            "{}",
223            r#"
224    __          ____
225   / /_  ____  / / /_
226  / __ \/ __ \/ / __/
227 / /_/ / /_/ / / /_  
228/_.___/\____/_/\__/  v0.2.0
229"#
230        );
231
232        let addr: SocketAddr = addr.parse().unwrap();
233        let listener = TcpListener::bind(addr).await?;
234
235        let tls_acceptor: Option<Arc<TlsAcceptor>> = if let Some((cert, key)) = tls {
236            let cfg = tls_config(cert, key)?;
237            Some(Arc::new(TlsAcceptor::from(cfg)))
238        } else {
239            None
240        };
241
242        println!(
243            ">> Server running on {}://{}",
244            if tls_acceptor.is_some() {
245                "https"
246            } else {
247                "http"
248            },
249            addr
250        );
251
252        let router: Arc<Router> = Arc::new(self.router.clone());
253        let error_handler = self.error_handler.clone();
254        let active = Arc::new(Semaphore::new(self.connection_limit as usize));
255
256        self.server_loop(
257            router,
258            error_handler,
259            listener,
260            mode,
261            tls_acceptor,
262            self.timeout,
263            self.read_timeout,
264            Box::pin(tokio::signal::ctrl_c().map(|_| ())),
265            active,
266        )
267        .await
268    }
269
270    async fn server_loop(
271        &self,
272        router: Arc<Router>,
273        error_handler: Arc<dyn ErrorHandler>,
274        listener: TcpListener,
275        mode: Mode,
276        tls_acceptor: Option<Arc<TlsAcceptor>>,
277        timeout: u64,
278        read_timeout: u64,
279        mut shutdown: Pin<Box<dyn Future<Output = ()> + Send>>,
280        active: Arc<Semaphore>,
281    ) -> Result<(), BoltError> {
282        loop {
283            tokio::select! {
284                _ = &mut shutdown => {
285                    println!(">> Shutdown signal received. Stopping server...");
286                    break;
287                }
288
289                accept_res = listener.accept() => {
290                    let (stream, remote_addr) = match accept_res {
291                        Ok(v) => v,
292                        Err(e) => {
293                            eprintln!("Accept error: {}", e);
294                            continue;
295                        }
296                    };
297
298                    let permit = match active.clone().try_acquire_owned() {
299                        Ok(p) => p,
300                        Err(_) => {
301                            eprintln!("Connection limit reached — dropping client");
302                            continue;
303                        }
304                    };
305
306                    let io: Box<dyn Io + Send> = if let Some(ref acceptor) = tls_acceptor {
307                        match acceptor.accept(stream).await {
308                            Ok(c) => Box::new(c),
309                            Err(e) => {
310                                eprintln!("TLS error: {}", e);
311                                continue;
312                            }
313                        }
314                    } else {
315                        Box::new(stream)
316                    };
317
318                    let limited = LimitReader::new(io, self.header_limit);
319                    let io = TokioIo::new(limited);
320
321                    let router = router.clone();
322                    let error_handler = error_handler.clone();
323
324                    let service = service_fn(move |req: Request<Incoming>| {
325                        let router = router.clone();
326                        let error_handler = error_handler.clone();
327                        let remote_addr = remote_addr.clone();
328                        let timeout = timeout;
329
330                        async move {
331                            let handler_future = tokio::time::timeout(
332                                Duration::from_secs(timeout),
333                                async {
334                                    let inner = AssertUnwindSafe(async move {
335                                        let mut req_body = RequestBody::new(req, remote_addr);
336                                        let mut res_body = ResponseWriter::new();
337
338                                        let method = match *req_body.method() {
339                                            hyper::Method::GET => Method::GET,
340                                            hyper::Method::POST => Method::POST,
341                                            hyper::Method::PUT => Method::PUT,
342                                            hyper::Method::PATCH => Method::PATCH,
343                                            hyper::Method::DELETE => Method::DELETE,
344                                            hyper::Method::OPTIONS => Method::OPTIONS,
345                                            hyper::Method::HEAD => Method::HEAD,
346                                            hyper::Method::TRACE => Method::TRACE,
347                                            _ => {
348                                                res_body.status(StatusCode::MethodNotAllowed)
349                                                        .send("Method Not Allowed");
350                                                return res_body;
351                                            }
352                                        };
353
354                                        let path = req_body.path().to_string();
355                                        for mw in router.collect_middleware(&path, method) {
356                                            mw.run(&mut req_body, &mut res_body).await;
357                                            if res_body.has_error() { break; }
358                                        }
359
360                                        if !res_body.has_error() {
361                                            if let Some((handler, params)) = router.find(&path, method) {
362                                                req_body.set_params(params);
363                                                handler.run(&mut req_body, &mut res_body).await;
364                                            } else {
365                                                res_body.error(
366                                                    StatusCode::NotFound,
367                                                    &format!("Not Found {} {}", req_body.method(), path),
368                                                );
369                                            }
370                                        }
371
372                                        if res_body.has_error() {
373                                            let msg = res_body.body.clone();
374                                            error_handler.run(msg, &mut res_body).await;
375                                        }
376
377                                        req_body.cleanup().await;
378                                        res_body
379                                    })
380                                    .catch_unwind()
381                                    .await;
382
383                                    match inner {
384                                        Ok(r) => r,
385                                        Err(_) => {
386                                            let mut res = ResponseWriter::new();
387                                            res.error(StatusCode::InternalServerError, "Internal Server Error");
388                                            res
389                                        }
390                                    }
391                                }
392                            ).await;
393
394                            let res_body = match handler_future {
395                                Ok(res) => res,
396                                Err(_) => {
397                                    let mut res = ResponseWriter::new();
398                                    res.error(StatusCode::RequestTimeout, "Request Timeout");
399                                    res
400                                }
401                            };
402
403                            Ok::<_, Infallible>(res_body.into_response())
404                        }
405                    });
406
407                    let permit = permit;
408
409                    match mode {
410                        Mode::Http1 => {
411                            tokio::spawn(async move {
412                                let _permit = permit;
413
414                                let result = tokio::time::timeout(
415                                    Duration::from_secs(read_timeout),
416                                    async {
417                                        http1::Builder::new()
418                                            .serve_connection(io, service)
419                                            .await
420                                    }
421                                ).await;
422
423                                match result {
424                                    Ok(Ok(_)) => {}
425                                    Ok(Err(e)) => eprintln!("Connection error: {}", e),
426                                    Err(_) => eprintln!("Slowloris: read timeout — closing connection"),
427                                }
428                            });
429                        }
430
431                        Mode::Http2 => {
432                            tokio::spawn(async move {
433                                let _permit = permit;
434
435                                let result = tokio::time::timeout(
436                                    Duration::from_secs(read_timeout),
437                                    async {
438                                        http2::Builder::new(TokioExecutor::new())
439                                            .serve_connection(io, service)
440                                            .await
441                                    }
442                                ).await;
443
444                                match result {
445                                    Ok(Ok(_)) => {}
446                                    Ok(Err(e)) => eprintln!("Connection error: {}", e),
447                                    Err(_) => eprintln!("Slowloris: read timeout — closing connection"),
448                                }
449                            });
450                        }
451                    }
452                }
453            }
454        }
455
456        Ok(())
457    }
458}