bolt_web/
lib.rs

1use futures_util::FutureExt;
2
3use std::{
4    convert::Infallible, net::SocketAddr, panic::AssertUnwindSafe, pin::Pin, sync::Arc,
5    time::Duration,
6};
7
8use hyper::{
9    Request,
10    body::Incoming,
11    server::conn::{http1, http2},
12    service::service_fn,
13};
14use hyper_util::rt::{TokioExecutor, TokioIo};
15
16use tokio::{
17    io::{AsyncRead, AsyncWrite},
18    net::TcpListener,
19    sync::Semaphore,
20};
21use tokio_rustls::TlsAcceptor;
22
23use crate::{
24    client::Client,
25    error::DefaultErrorHandler,
26    group::Group,
27    headers::LimitReader,
28    http::StatusCode,
29    request::RequestBody,
30    response::ResponseWriter,
31    router::Router,
32    tls::tls_config,
33    types::{BoltError, ErrorHandler, Handler, Method, Middleware, Mode},
34};
35
36pub mod client;
37mod error;
38mod group;
39mod headers;
40pub mod http;
41pub mod macros;
42pub mod request;
43pub mod response;
44mod router;
45mod tls;
46pub mod types;
47pub use bolt_web_macro::main;
48pub use paste;
49pub use tokio;
50
51trait Io: AsyncRead + AsyncWrite + Unpin {}
52impl<T: AsyncRead + AsyncWrite + Unpin> Io for T {}
53
54#[allow(dead_code)]
55pub struct App {
56    router: Router,
57    error_handler: Arc<dyn ErrorHandler>,
58    client: Client,
59    timeout: u64,
60    connection_limit: u64,
61    read_timeout: u64,
62    header_limit: usize,
63}
64
65#[allow(unused_variables)]
66#[allow(dead_code)]
67impl App {
68    pub fn new() -> Self {
69        Self {
70            router: Router::new(),
71            error_handler: Arc::new(DefaultErrorHandler),
72            client: Client::new(),
73            timeout: 30,
74            connection_limit: 100,
75            read_timeout: 10,
76            header_limit: 32 * 1024,
77        }
78    }
79
80    pub fn set_timeout(&mut self, seconds: u64) {
81        self.timeout = seconds;
82    }
83
84    pub fn set_connection_limit(&mut self, limit: u64) {
85        self.connection_limit = limit;
86    }
87
88    pub fn set_read_timeout(&mut self, seconds: u64) {
89        self.read_timeout = seconds;
90    }
91
92    pub fn set_header_limit(&mut self, bytes: usize) {
93        self.header_limit = bytes;
94    }
95
96    fn add_route<H>(&mut self, method: Method, path: &str, handler: H)
97    where
98        H: Handler + 'static,
99    {
100        self.router.insert(path, method, handler);
101    }
102
103    pub fn get<H>(&mut self, path: &str, handler: H)
104    where
105        H: Handler + 'static,
106    {
107        self.add_route(Method::GET, path, handler);
108    }
109
110    pub fn post<H>(&mut self, path: &str, handler: H)
111    where
112        H: Handler + 'static,
113    {
114        self.add_route(Method::POST, path, handler);
115    }
116
117    pub fn put<H>(&mut self, path: &str, handler: H)
118    where
119        H: Handler + 'static,
120    {
121        self.add_route(Method::PUT, path, handler);
122    }
123
124    pub fn patch<H>(&mut self, path: &str, handler: H)
125    where
126        H: Handler + 'static,
127    {
128        self.add_route(Method::PATCH, path, handler);
129    }
130
131    pub fn delete<H>(&mut self, path: &str, handler: H)
132    where
133        H: Handler + 'static,
134    {
135        self.add_route(Method::DELETE, path, handler);
136    }
137
138    pub fn group<'a>(&'a mut self, path: &str) -> Group<'a> {
139        Group {
140            prefix: path.to_string(),
141            app: self,
142        }
143    }
144
145    pub fn middleware<M>(&mut self, path: &str, method: Option<Method>, middleware_fn: M)
146    where
147        M: Middleware + 'static,
148    {
149        let mw: Arc<M> = Arc::new(middleware_fn);
150        let full_path = path.to_string();
151
152        match method {
153            Some(m) => self.router.insert_middleware(&full_path, m, mw),
154            None => {
155                for m in [
156                    Method::GET,
157                    Method::POST,
158                    Method::PUT,
159                    Method::PATCH,
160                    Method::DELETE,
161                    Method::OPTIONS,
162                    Method::HEAD,
163                    Method::TRACE,
164                ] {
165                    self.router.insert_middleware(&full_path, m, mw.clone());
166                }
167            }
168        }
169    }
170
171    pub fn set_error_handler<E>(&mut self, handler: E)
172    where
173        E: ErrorHandler + 'static,
174    {
175        self.error_handler = Arc::new(handler);
176    }
177
178    pub async fn run(&self, addr: &str, mode: Mode) -> Result<(), BoltError> {
179        println!("⚡ A high performance & minimalist web framework in rust.");
180        println!(
181            r#"
182    __          ____
183   / /_  ____  / / /_
184  / __ \/ __ \/ / __/
185 / /_/ / /_/ / / /_  
186/_.___/\____/_/\__/  v0.2.0
187"#
188        );
189
190        println!(">> Server running on http://{}", addr);
191
192        let addr: SocketAddr = addr.parse().unwrap();
193
194        let listener = TcpListener::bind(addr).await?;
195        let router = Arc::new(self.router.clone());
196        let error_handler = self.error_handler.clone();
197        let active = Arc::new(Semaphore::new(self.connection_limit as usize));
198
199        self.server_loop(
200            router,
201            error_handler,
202            listener,
203            mode,
204            None,
205            self.timeout,
206            self.read_timeout,
207            Box::pin(tokio::signal::ctrl_c().map(|_| ())),
208            active,
209        )
210        .await
211    }
212
213    pub async fn run_tls(
214        &self,
215        addr: &str,
216        mode: Mode,
217        tls: Option<(&str, &str)>,
218    ) -> Result<(), BoltError> {
219        println!("⚡ A high performance & minimalist web framework in rust.");
220        println!(
221            "{}",
222            r#"
223    __          ____
224   / /_  ____  / / /_
225  / __ \/ __ \/ / __/
226 / /_/ / /_/ / / /_  
227/_.___/\____/_/\__/  v0.2.0
228"#
229        );
230
231        let addr: SocketAddr = addr.parse().unwrap();
232        let listener = TcpListener::bind(addr).await?;
233
234        let tls_acceptor: Option<Arc<TlsAcceptor>> = if let Some((cert, key)) = tls {
235            let cfg = tls_config(cert, key)?;
236            Some(Arc::new(TlsAcceptor::from(cfg)))
237        } else {
238            None
239        };
240
241        println!(
242            ">> Server running on {}://{}",
243            if tls_acceptor.is_some() {
244                "https"
245            } else {
246                "http"
247            },
248            addr
249        );
250
251        let router: Arc<Router> = Arc::new(self.router.clone());
252        let error_handler = self.error_handler.clone();
253        let active = Arc::new(Semaphore::new(self.connection_limit as usize));
254
255        self.server_loop(
256            router,
257            error_handler,
258            listener,
259            mode,
260            tls_acceptor,
261            self.timeout,
262            self.read_timeout,
263            Box::pin(tokio::signal::ctrl_c().map(|_| ())),
264            active,
265        )
266        .await
267    }
268
269    async fn server_loop(
270        &self,
271        router: Arc<Router>,
272        error_handler: Arc<dyn ErrorHandler>,
273        listener: TcpListener,
274        mode: Mode,
275        tls_acceptor: Option<Arc<TlsAcceptor>>,
276        timeout: u64,
277        read_timeout: u64,
278        mut shutdown: Pin<Box<dyn Future<Output = ()> + Send>>,
279        active: Arc<Semaphore>,
280    ) -> Result<(), BoltError> {
281        loop {
282            tokio::select! {
283                _ = &mut shutdown => {
284                    println!(">> Shutdown signal received. Stopping server...");
285                    break;
286                }
287
288                accept_res = listener.accept() => {
289                    let (stream, remote_addr) = match accept_res {
290                        Ok(v) => v,
291                        Err(e) => {
292                            eprintln!("Accept error: {}", e);
293                            continue;
294                        }
295                    };
296
297                    let permit = match active.clone().try_acquire_owned() {
298                        Ok(p) => p,
299                        Err(_) => {
300                            eprintln!("Connection limit reached — dropping client");
301                            continue;
302                        }
303                    };
304
305                    let io: Box<dyn Io + Send> = if let Some(ref acceptor) = tls_acceptor {
306                        match acceptor.accept(stream).await {
307                            Ok(c) => Box::new(c),
308                            Err(e) => {
309                                eprintln!("TLS error: {}", e);
310                                continue;
311                            }
312                        }
313                    } else {
314                        Box::new(stream)
315                    };
316
317                    let limited = LimitReader::new(io, self.header_limit);
318                    let io = TokioIo::new(limited);
319
320                    let router = router.clone();
321                    let error_handler = error_handler.clone();
322
323                    let service = service_fn(move |req: Request<Incoming>| {
324                        let router = router.clone();
325                        let error_handler = error_handler.clone();
326                        let remote_addr = remote_addr.clone();
327                        let timeout = timeout;
328
329                        async move {
330                            let handler_future = tokio::time::timeout(
331                                Duration::from_secs(timeout),
332                                async {
333                                    let inner = AssertUnwindSafe(async move {
334                                        let mut req_body = RequestBody::new(req, remote_addr);
335                                        let mut res_body = ResponseWriter::new();
336
337                                        let method = match *req_body.method() {
338                                            hyper::Method::GET => Method::GET,
339                                            hyper::Method::POST => Method::POST,
340                                            hyper::Method::PUT => Method::PUT,
341                                            hyper::Method::PATCH => Method::PATCH,
342                                            hyper::Method::DELETE => Method::DELETE,
343                                            hyper::Method::OPTIONS => Method::OPTIONS,
344                                            hyper::Method::HEAD => Method::HEAD,
345                                            hyper::Method::TRACE => Method::TRACE,
346                                            _ => {
347                                                res_body.status(StatusCode::MethodNotAllowed)
348                                                        .send("Method Not Allowed");
349                                                return res_body;
350                                            }
351                                        };
352
353                                        let path = req_body.path().to_string();
354                                        for mw in router.collect_middleware(&path, method) {
355                                            mw.run(&mut req_body, &mut res_body).await;
356                                            if res_body.has_error() { break; }
357                                        }
358
359                                        if !res_body.has_error() {
360                                            if let Some((handler, params)) = router.find(&path, method) {
361                                                req_body.set_params(params);
362                                                handler.run(&mut req_body, &mut res_body).await;
363                                            } else {
364                                                res_body.error(
365                                                    StatusCode::NotFound,
366                                                    &format!("Not Found {} {}", req_body.method(), path),
367                                                );
368                                            }
369                                        }
370
371                                        if res_body.has_error() {
372                                            let msg = res_body.body.clone();
373                                            error_handler.run(msg, &mut res_body).await;
374                                        }
375
376                                        req_body.cleanup().await;
377                                        res_body
378                                    })
379                                    .catch_unwind()
380                                    .await;
381
382                                    match inner {
383                                        Ok(r) => r,
384                                        Err(_) => {
385                                            let mut res = ResponseWriter::new();
386                                            res.error(StatusCode::InternalServerError, "Internal Server Error");
387                                            res
388                                        }
389                                    }
390                                }
391                            ).await;
392
393                            let res_body = match handler_future {
394                                Ok(res) => res,
395                                Err(_) => {
396                                    let mut res = ResponseWriter::new();
397                                    res.error(StatusCode::RequestTimeout, "Request Timeout");
398                                    res
399                                }
400                            };
401
402                            Ok::<_, Infallible>(res_body.into_response())
403                        }
404                    });
405
406                    let permit = permit;
407
408                    match mode {
409                        Mode::Http1 => {
410                            tokio::spawn(async move {
411                                let _permit = permit;
412
413                                let result = tokio::time::timeout(
414                                    Duration::from_secs(read_timeout),
415                                    async {
416                                        http1::Builder::new()
417                                            .serve_connection(io, service)
418                                            .await
419                                    }
420                                ).await;
421
422                                match result {
423                                    Ok(Ok(_)) => {}
424                                    Ok(Err(e)) => eprintln!("Connection error: {}", e),
425                                    Err(_) => eprintln!("Slowloris: read timeout — closing connection"),
426                                }
427                            });
428                        }
429
430                        Mode::Http2 => {
431                            tokio::spawn(async move {
432                                let _permit = permit;
433
434                                let result = tokio::time::timeout(
435                                    Duration::from_secs(read_timeout),
436                                    async {
437                                        http2::Builder::new(TokioExecutor::new())
438                                            .serve_connection(io, service)
439                                            .await
440                                    }
441                                ).await;
442
443                                match result {
444                                    Ok(Ok(_)) => {}
445                                    Ok(Err(e)) => eprintln!("Connection error: {}", e),
446                                    Err(_) => eprintln!("Slowloris: read timeout — closing connection"),
447                                }
448                            });
449                        }
450                    }
451                }
452            }
453        }
454
455        Ok(())
456    }
457}