bolt_web/
lib.rs

1use futures_util::FutureExt;
2
3use std::{
4    convert::Infallible, net::SocketAddr, panic::AssertUnwindSafe, pin::Pin, sync::Arc,
5    time::Duration,
6};
7
8use hyper::{
9    Request,
10    body::Incoming,
11    server::conn::{http1, http2},
12    service::service_fn,
13};
14use hyper_util::rt::{TokioExecutor, TokioIo};
15
16use tokio::{
17    io::{AsyncRead, AsyncWrite},
18    net::TcpListener,
19    sync::Semaphore,
20};
21use tokio_rustls::TlsAcceptor;
22
23use crate::{
24    client::Client,
25    error::DefaultErrorHandler,
26    group::Group,
27    headers::LimitReader,
28    http::StatusCode,
29    request::RequestBody,
30    response::ResponseWriter,
31    router::Router,
32    tls::tls_config,
33    types::{BoltError, ErrorHandler, Handler, Method, Middleware, Mode},
34};
35
36pub mod client;
37mod error;
38mod group;
39mod headers;
40pub mod http;
41pub mod macros;
42pub mod request;
43pub mod response;
44mod router;
45mod tls;
46pub mod types;
47pub use bolt_web_macro::main;
48
49trait Io: AsyncRead + AsyncWrite + Unpin {}
50impl<T: AsyncRead + AsyncWrite + Unpin> Io for T {}
51
52#[allow(dead_code)]
53pub struct App {
54    router: Router,
55    error_handler: Arc<dyn ErrorHandler>,
56    client: Client,
57    timeout: u64,
58    connection_limit: u64,
59    read_timeout: u64,
60    header_limit: usize,
61}
62
63#[allow(unused_variables)]
64#[allow(dead_code)]
65impl App {
66    pub fn new() -> Self {
67        Self {
68            router: Router::new(),
69            error_handler: Arc::new(DefaultErrorHandler),
70            client: Client::new(),
71            timeout: 30,
72            connection_limit: 100,
73            read_timeout: 10,
74            header_limit: 32 * 1024,
75        }
76    }
77
78    pub fn set_timeout(&mut self, seconds: u64) {
79        self.timeout = seconds;
80    }
81
82    pub fn set_connection_limit(&mut self, limit: u64) {
83        self.connection_limit = limit;
84    }
85
86    pub fn set_read_timeout(&mut self, seconds: u64) {
87        self.read_timeout = seconds;
88    }
89
90    pub fn set_header_limit(&mut self, bytes: usize) {
91        self.header_limit = bytes;
92    }
93
94    fn add_route<H>(&mut self, method: Method, path: &str, handler: H)
95    where
96        H: Handler + 'static,
97    {
98        self.router.insert(path, method, handler);
99    }
100
101    pub fn get<H>(&mut self, path: &str, handler: H)
102    where
103        H: Handler + 'static,
104    {
105        self.add_route(Method::GET, path, handler);
106    }
107
108    pub fn post<H>(&mut self, path: &str, handler: H)
109    where
110        H: Handler + 'static,
111    {
112        self.add_route(Method::POST, path, handler);
113    }
114
115    pub fn put<H>(&mut self, path: &str, handler: H)
116    where
117        H: Handler + 'static,
118    {
119        self.add_route(Method::PUT, path, handler);
120    }
121
122    pub fn patch<H>(&mut self, path: &str, handler: H)
123    where
124        H: Handler + 'static,
125    {
126        self.add_route(Method::PATCH, path, handler);
127    }
128
129    pub fn delete<H>(&mut self, path: &str, handler: H)
130    where
131        H: Handler + 'static,
132    {
133        self.add_route(Method::DELETE, path, handler);
134    }
135
136    pub fn group<'a>(&'a mut self, path: &str) -> Group<'a> {
137        Group {
138            prefix: path.to_string(),
139            app: self,
140        }
141    }
142
143    pub fn middleware<M>(&mut self, path: &str, method: Option<Method>, middleware_fn: M)
144    where
145        M: Middleware + 'static,
146    {
147        let mw: Arc<M> = Arc::new(middleware_fn);
148        let full_path = path.to_string();
149
150        match method {
151            Some(m) => self.router.insert_middleware(&full_path, m, mw),
152            None => {
153                for m in [
154                    Method::GET,
155                    Method::POST,
156                    Method::PUT,
157                    Method::PATCH,
158                    Method::DELETE,
159                    Method::OPTIONS,
160                    Method::HEAD,
161                    Method::TRACE,
162                ] {
163                    self.router.insert_middleware(&full_path, m, mw.clone());
164                }
165            }
166        }
167    }
168
169    pub fn set_error_handler<E>(&mut self, handler: E)
170    where
171        E: ErrorHandler + 'static,
172    {
173        self.error_handler = Arc::new(handler);
174    }
175
176    pub async fn run(&self, addr: &str, mode: Mode) -> Result<(), BoltError> {
177        println!("⚡ A high performance & minimalist web framework in rust.");
178        println!(
179            r#"
180    __          ____
181   / /_  ____  / / /_
182  / __ \/ __ \/ / __/
183 / /_/ / /_/ / / /_  
184/_.___/\____/_/\__/  v0.2.0
185"#
186        );
187
188        println!(">> Server running on http://{}", addr);
189
190        let addr: SocketAddr = addr.parse().unwrap();
191
192        let listener = TcpListener::bind(addr).await?;
193        let router = Arc::new(self.router.clone());
194        let error_handler = self.error_handler.clone();
195        let active = Arc::new(Semaphore::new(self.connection_limit as usize));
196
197        self.server_loop(
198            router,
199            error_handler,
200            listener,
201            mode,
202            None,
203            self.timeout,
204            self.read_timeout,
205            Box::pin(tokio::signal::ctrl_c().map(|_| ())),
206            active,
207        )
208        .await
209    }
210
211    pub async fn run_tls(
212        &self,
213        addr: &str,
214        mode: Mode,
215        tls: Option<(&str, &str)>,
216    ) -> Result<(), BoltError> {
217        println!("⚡ A high performance & minimalist web framework in rust.");
218        println!(
219            "{}",
220            r#"
221    __          ____
222   / /_  ____  / / /_
223  / __ \/ __ \/ / __/
224 / /_/ / /_/ / / /_  
225/_.___/\____/_/\__/  v0.2.0
226"#
227        );
228
229        let addr: SocketAddr = addr.parse().unwrap();
230        let listener = TcpListener::bind(addr).await?;
231
232        let tls_acceptor: Option<Arc<TlsAcceptor>> = if let Some((cert, key)) = tls {
233            let cfg = tls_config(cert, key)?;
234            Some(Arc::new(TlsAcceptor::from(cfg)))
235        } else {
236            None
237        };
238
239        println!(
240            ">> Server running on {}://{}",
241            if tls_acceptor.is_some() {
242                "https"
243            } else {
244                "http"
245            },
246            addr
247        );
248
249        let router: Arc<Router> = Arc::new(self.router.clone());
250        let error_handler = self.error_handler.clone();
251        let active = Arc::new(Semaphore::new(self.connection_limit as usize));
252
253        self.server_loop(
254            router,
255            error_handler,
256            listener,
257            mode,
258            tls_acceptor,
259            self.timeout,
260            self.read_timeout,
261            Box::pin(tokio::signal::ctrl_c().map(|_| ())),
262            active,
263        )
264        .await
265    }
266
267    async fn server_loop(
268        &self,
269        router: Arc<Router>,
270        error_handler: Arc<dyn ErrorHandler>,
271        listener: TcpListener,
272        mode: Mode,
273        tls_acceptor: Option<Arc<TlsAcceptor>>,
274        timeout: u64,
275        read_timeout: u64,
276        mut shutdown: Pin<Box<dyn Future<Output = ()> + Send>>,
277        active: Arc<Semaphore>,
278    ) -> Result<(), BoltError> {
279        loop {
280            tokio::select! {
281                _ = &mut shutdown => {
282                    println!(">> Shutdown signal received. Stopping server...");
283                    break;
284                }
285
286                accept_res = listener.accept() => {
287                    let (stream, remote_addr) = match accept_res {
288                        Ok(v) => v,
289                        Err(e) => {
290                            eprintln!("Accept error: {}", e);
291                            continue;
292                        }
293                    };
294
295                    let permit = match active.clone().try_acquire_owned() {
296                        Ok(p) => p,
297                        Err(_) => {
298                            eprintln!("Connection limit reached — dropping client");
299                            continue;
300                        }
301                    };
302
303                    let io: Box<dyn Io + Send> = if let Some(ref acceptor) = tls_acceptor {
304                        match acceptor.accept(stream).await {
305                            Ok(c) => Box::new(c),
306                            Err(e) => {
307                                eprintln!("TLS error: {}", e);
308                                continue;
309                            }
310                        }
311                    } else {
312                        Box::new(stream)
313                    };
314
315                    let limited = LimitReader::new(io, self.header_limit);
316                    let io = TokioIo::new(limited);
317
318                    let router = router.clone();
319                    let error_handler = error_handler.clone();
320
321                    let service = service_fn(move |req: Request<Incoming>| {
322                        let router = router.clone();
323                        let error_handler = error_handler.clone();
324                        let remote_addr = remote_addr.clone();
325                        let timeout = timeout;
326
327                        async move {
328                            let handler_future = tokio::time::timeout(
329                                Duration::from_secs(timeout),
330                                async {
331                                    let inner = AssertUnwindSafe(async move {
332                                        let mut req_body = RequestBody::new(req, remote_addr);
333                                        let mut res_body = ResponseWriter::new();
334
335                                        let method = match *req_body.method() {
336                                            hyper::Method::GET => Method::GET,
337                                            hyper::Method::POST => Method::POST,
338                                            hyper::Method::PUT => Method::PUT,
339                                            hyper::Method::PATCH => Method::PATCH,
340                                            hyper::Method::DELETE => Method::DELETE,
341                                            hyper::Method::OPTIONS => Method::OPTIONS,
342                                            hyper::Method::HEAD => Method::HEAD,
343                                            hyper::Method::TRACE => Method::TRACE,
344                                            _ => {
345                                                res_body.status(StatusCode::MethodNotAllowed)
346                                                        .send("Method Not Allowed");
347                                                return res_body;
348                                            }
349                                        };
350
351                                        let path = req_body.path().to_string();
352                                        for mw in router.collect_middleware(&path, method) {
353                                            mw.run(&mut req_body, &mut res_body).await;
354                                            if res_body.has_error() { break; }
355                                        }
356
357                                        if !res_body.has_error() {
358                                            if let Some((handler, params)) = router.find(&path, method) {
359                                                req_body.set_params(params);
360                                                handler.run(&mut req_body, &mut res_body).await;
361                                            } else {
362                                                res_body.error(
363                                                    StatusCode::NotFound,
364                                                    &format!("Not Found {} {}", req_body.method(), path),
365                                                );
366                                            }
367                                        }
368
369                                        if res_body.has_error() {
370                                            let msg = res_body.body.clone();
371                                            error_handler.run(msg, &mut res_body).await;
372                                        }
373
374                                        req_body.cleanup().await;
375                                        res_body
376                                    })
377                                    .catch_unwind()
378                                    .await;
379
380                                    match inner {
381                                        Ok(r) => r,
382                                        Err(_) => {
383                                            let mut res = ResponseWriter::new();
384                                            res.error(StatusCode::InternalServerError, "Internal Server Error");
385                                            res
386                                        }
387                                    }
388                                }
389                            ).await;
390
391                            let res_body = match handler_future {
392                                Ok(res) => res,
393                                Err(_) => {
394                                    let mut res = ResponseWriter::new();
395                                    res.error(StatusCode::RequestTimeout, "Request Timeout");
396                                    res
397                                }
398                            };
399
400                            Ok::<_, Infallible>(res_body.into_response())
401                        }
402                    });
403
404                    let permit = permit;
405
406                    match mode {
407                        Mode::Http1 => {
408                            tokio::spawn(async move {
409                                let _permit = permit;
410
411                                let result = tokio::time::timeout(
412                                    Duration::from_secs(read_timeout),
413                                    async {
414                                        http1::Builder::new()
415                                            .serve_connection(io, service)
416                                            .await
417                                    }
418                                ).await;
419
420                                match result {
421                                    Ok(Ok(_)) => {}
422                                    Ok(Err(e)) => eprintln!("Connection error: {}", e),
423                                    Err(_) => eprintln!("Slowloris: read timeout — closing connection"),
424                                }
425                            });
426                        }
427
428                        Mode::Http2 => {
429                            tokio::spawn(async move {
430                                let _permit = permit;
431
432                                let result = tokio::time::timeout(
433                                    Duration::from_secs(read_timeout),
434                                    async {
435                                        http2::Builder::new(TokioExecutor::new())
436                                            .serve_connection(io, service)
437                                            .await
438                                    }
439                                ).await;
440
441                                match result {
442                                    Ok(Ok(_)) => {}
443                                    Ok(Err(e)) => eprintln!("Connection error: {}", e),
444                                    Err(_) => eprintln!("Slowloris: read timeout — closing connection"),
445                                }
446                            });
447                        }
448                    }
449                }
450            }
451        }
452
453        Ok(())
454    }
455}