1use futures_util::FutureExt;
2
3use std::{
4 convert::Infallible, net::SocketAddr, panic::AssertUnwindSafe, pin::Pin, sync::Arc,
5 time::Duration,
6};
7
8use hyper::{
9 Request,
10 body::Incoming,
11 server::conn::{http1, http2},
12 service::service_fn,
13};
14use hyper_util::rt::{TokioExecutor, TokioIo};
15
16use tokio::{
17 io::{AsyncRead, AsyncWrite},
18 net::TcpListener,
19 sync::Semaphore,
20};
21use tokio_rustls::TlsAcceptor;
22
23use crate::{
24 client::Client,
25 error::DefaultErrorHandler,
26 group::Group,
27 headers::LimitReader,
28 http::StatusCode,
29 request::RequestBody,
30 response::ResponseWriter,
31 router::Router,
32 tls::tls_config,
33 types::{BoltError, ErrorHandler, Handler, Method, Middleware, Mode},
34};
35
36pub mod client;
37mod error;
38mod group;
39mod headers;
40pub mod http;
41pub mod macros;
42pub mod request;
43pub mod response;
44mod router;
45mod tls;
46pub mod types;
47pub use bolt_web_macro::main;
48
49trait Io: AsyncRead + AsyncWrite + Unpin {}
50impl<T: AsyncRead + AsyncWrite + Unpin> Io for T {}
51
52#[allow(dead_code)]
53pub struct App {
54 router: Router,
55 error_handler: Arc<dyn ErrorHandler>,
56 client: Client,
57 timeout: u64,
58 connection_limit: u64,
59 read_timeout: u64,
60 header_limit: usize,
61}
62
63#[allow(unused_variables)]
64#[allow(dead_code)]
65impl App {
66 pub fn new() -> Self {
67 Self {
68 router: Router::new(),
69 error_handler: Arc::new(DefaultErrorHandler),
70 client: Client::new(),
71 timeout: 30,
72 connection_limit: 100,
73 read_timeout: 10,
74 header_limit: 32 * 1024,
75 }
76 }
77
78 pub fn set_timeout(&mut self, seconds: u64) {
79 self.timeout = seconds;
80 }
81
82 pub fn set_connection_limit(&mut self, limit: u64) {
83 self.connection_limit = limit;
84 }
85
86 pub fn set_read_timeout(&mut self, seconds: u64) {
87 self.read_timeout = seconds;
88 }
89
90 pub fn set_header_limit(&mut self, bytes: usize) {
91 self.header_limit = bytes;
92 }
93
94 fn add_route<H>(&mut self, method: Method, path: &str, handler: H)
95 where
96 H: Handler + 'static,
97 {
98 self.router.insert(path, method, handler);
99 }
100
101 pub fn get<H>(&mut self, path: &str, handler: H)
102 where
103 H: Handler + 'static,
104 {
105 self.add_route(Method::GET, path, handler);
106 }
107
108 pub fn post<H>(&mut self, path: &str, handler: H)
109 where
110 H: Handler + 'static,
111 {
112 self.add_route(Method::POST, path, handler);
113 }
114
115 pub fn put<H>(&mut self, path: &str, handler: H)
116 where
117 H: Handler + 'static,
118 {
119 self.add_route(Method::PUT, path, handler);
120 }
121
122 pub fn patch<H>(&mut self, path: &str, handler: H)
123 where
124 H: Handler + 'static,
125 {
126 self.add_route(Method::PATCH, path, handler);
127 }
128
129 pub fn delete<H>(&mut self, path: &str, handler: H)
130 where
131 H: Handler + 'static,
132 {
133 self.add_route(Method::DELETE, path, handler);
134 }
135
136 pub fn group<'a>(&'a mut self, path: &str) -> Group<'a> {
137 Group {
138 prefix: path.to_string(),
139 app: self,
140 }
141 }
142
143 pub fn middleware<M>(&mut self, path: &str, method: Option<Method>, middleware_fn: M)
144 where
145 M: Middleware + 'static,
146 {
147 let mw: Arc<M> = Arc::new(middleware_fn);
148 let full_path = path.to_string();
149
150 match method {
151 Some(m) => self.router.insert_middleware(&full_path, m, mw),
152 None => {
153 for m in [
154 Method::GET,
155 Method::POST,
156 Method::PUT,
157 Method::PATCH,
158 Method::DELETE,
159 Method::OPTIONS,
160 Method::HEAD,
161 Method::TRACE,
162 ] {
163 self.router.insert_middleware(&full_path, m, mw.clone());
164 }
165 }
166 }
167 }
168
169 pub fn set_error_handler<E>(&mut self, handler: E)
170 where
171 E: ErrorHandler + 'static,
172 {
173 self.error_handler = Arc::new(handler);
174 }
175
176 pub async fn run(&self, addr: &str, mode: Mode) -> Result<(), BoltError> {
177 println!("⚡ A high performance & minimalist web framework in rust.");
178 println!(
179 r#"
180 __ ____
181 / /_ ____ / / /_
182 / __ \/ __ \/ / __/
183 / /_/ / /_/ / / /_
184/_.___/\____/_/\__/ v0.2.0
185"#
186 );
187
188 println!(">> Server running on http://{}", addr);
189
190 let addr: SocketAddr = addr.parse().unwrap();
191
192 let listener = TcpListener::bind(addr).await?;
193 let router = Arc::new(self.router.clone());
194 let error_handler = self.error_handler.clone();
195 let active = Arc::new(Semaphore::new(self.connection_limit as usize));
196
197 self.server_loop(
198 router,
199 error_handler,
200 listener,
201 mode,
202 None,
203 self.timeout,
204 self.read_timeout,
205 Box::pin(tokio::signal::ctrl_c().map(|_| ())),
206 active,
207 )
208 .await
209 }
210
211 pub async fn run_tls(
212 &self,
213 addr: &str,
214 mode: Mode,
215 tls: Option<(&str, &str)>,
216 ) -> Result<(), BoltError> {
217 println!("⚡ A high performance & minimalist web framework in rust.");
218 println!(
219 "{}",
220 r#"
221 __ ____
222 / /_ ____ / / /_
223 / __ \/ __ \/ / __/
224 / /_/ / /_/ / / /_
225/_.___/\____/_/\__/ v0.2.0
226"#
227 );
228
229 let addr: SocketAddr = addr.parse().unwrap();
230 let listener = TcpListener::bind(addr).await?;
231
232 let tls_acceptor: Option<Arc<TlsAcceptor>> = if let Some((cert, key)) = tls {
233 let cfg = tls_config(cert, key)?;
234 Some(Arc::new(TlsAcceptor::from(cfg)))
235 } else {
236 None
237 };
238
239 println!(
240 ">> Server running on {}://{}",
241 if tls_acceptor.is_some() {
242 "https"
243 } else {
244 "http"
245 },
246 addr
247 );
248
249 let router: Arc<Router> = Arc::new(self.router.clone());
250 let error_handler = self.error_handler.clone();
251 let active = Arc::new(Semaphore::new(self.connection_limit as usize));
252
253 self.server_loop(
254 router,
255 error_handler,
256 listener,
257 mode,
258 tls_acceptor,
259 self.timeout,
260 self.read_timeout,
261 Box::pin(tokio::signal::ctrl_c().map(|_| ())),
262 active,
263 )
264 .await
265 }
266
267 async fn server_loop(
268 &self,
269 router: Arc<Router>,
270 error_handler: Arc<dyn ErrorHandler>,
271 listener: TcpListener,
272 mode: Mode,
273 tls_acceptor: Option<Arc<TlsAcceptor>>,
274 timeout: u64,
275 read_timeout: u64,
276 mut shutdown: Pin<Box<dyn Future<Output = ()> + Send>>,
277 active: Arc<Semaphore>,
278 ) -> Result<(), BoltError> {
279 loop {
280 tokio::select! {
281 _ = &mut shutdown => {
282 println!(">> Shutdown signal received. Stopping server...");
283 break;
284 }
285
286 accept_res = listener.accept() => {
287 let (stream, remote_addr) = match accept_res {
288 Ok(v) => v,
289 Err(e) => {
290 eprintln!("Accept error: {}", e);
291 continue;
292 }
293 };
294
295 let permit = match active.clone().try_acquire_owned() {
296 Ok(p) => p,
297 Err(_) => {
298 eprintln!("Connection limit reached — dropping client");
299 continue;
300 }
301 };
302
303 let io: Box<dyn Io + Send> = if let Some(ref acceptor) = tls_acceptor {
304 match acceptor.accept(stream).await {
305 Ok(c) => Box::new(c),
306 Err(e) => {
307 eprintln!("TLS error: {}", e);
308 continue;
309 }
310 }
311 } else {
312 Box::new(stream)
313 };
314
315 let limited = LimitReader::new(io, self.header_limit);
316 let io = TokioIo::new(limited);
317
318 let router = router.clone();
319 let error_handler = error_handler.clone();
320
321 let service = service_fn(move |req: Request<Incoming>| {
322 let router = router.clone();
323 let error_handler = error_handler.clone();
324 let remote_addr = remote_addr.clone();
325 let timeout = timeout;
326
327 async move {
328 let handler_future = tokio::time::timeout(
329 Duration::from_secs(timeout),
330 async {
331 let inner = AssertUnwindSafe(async move {
332 let mut req_body = RequestBody::new(req, remote_addr);
333 let mut res_body = ResponseWriter::new();
334
335 let method = match *req_body.method() {
336 hyper::Method::GET => Method::GET,
337 hyper::Method::POST => Method::POST,
338 hyper::Method::PUT => Method::PUT,
339 hyper::Method::PATCH => Method::PATCH,
340 hyper::Method::DELETE => Method::DELETE,
341 hyper::Method::OPTIONS => Method::OPTIONS,
342 hyper::Method::HEAD => Method::HEAD,
343 hyper::Method::TRACE => Method::TRACE,
344 _ => {
345 res_body.status(StatusCode::MethodNotAllowed)
346 .send("Method Not Allowed");
347 return res_body;
348 }
349 };
350
351 let path = req_body.path().to_string();
352 for mw in router.collect_middleware(&path, method) {
353 mw.run(&mut req_body, &mut res_body).await;
354 if res_body.has_error() { break; }
355 }
356
357 if !res_body.has_error() {
358 if let Some((handler, params)) = router.find(&path, method) {
359 req_body.set_params(params);
360 handler.run(&mut req_body, &mut res_body).await;
361 } else {
362 res_body.error(
363 StatusCode::NotFound,
364 &format!("Not Found {} {}", req_body.method(), path),
365 );
366 }
367 }
368
369 if res_body.has_error() {
370 let msg = res_body.body.clone();
371 error_handler.run(msg, &mut res_body).await;
372 }
373
374 req_body.cleanup().await;
375 res_body
376 })
377 .catch_unwind()
378 .await;
379
380 match inner {
381 Ok(r) => r,
382 Err(_) => {
383 let mut res = ResponseWriter::new();
384 res.error(StatusCode::InternalServerError, "Internal Server Error");
385 res
386 }
387 }
388 }
389 ).await;
390
391 let res_body = match handler_future {
392 Ok(res) => res,
393 Err(_) => {
394 let mut res = ResponseWriter::new();
395 res.error(StatusCode::RequestTimeout, "Request Timeout");
396 res
397 }
398 };
399
400 Ok::<_, Infallible>(res_body.into_response())
401 }
402 });
403
404 let permit = permit;
405
406 match mode {
407 Mode::Http1 => {
408 tokio::spawn(async move {
409 let _permit = permit;
410
411 let result = tokio::time::timeout(
412 Duration::from_secs(read_timeout),
413 async {
414 http1::Builder::new()
415 .serve_connection(io, service)
416 .await
417 }
418 ).await;
419
420 match result {
421 Ok(Ok(_)) => {}
422 Ok(Err(e)) => eprintln!("Connection error: {}", e),
423 Err(_) => eprintln!("Slowloris: read timeout — closing connection"),
424 }
425 });
426 }
427
428 Mode::Http2 => {
429 tokio::spawn(async move {
430 let _permit = permit;
431
432 let result = tokio::time::timeout(
433 Duration::from_secs(read_timeout),
434 async {
435 http2::Builder::new(TokioExecutor::new())
436 .serve_connection(io, service)
437 .await
438 }
439 ).await;
440
441 match result {
442 Ok(Ok(_)) => {}
443 Ok(Err(e)) => eprintln!("Connection error: {}", e),
444 Err(_) => eprintln!("Slowloris: read timeout — closing connection"),
445 }
446 });
447 }
448 }
449 }
450 }
451 }
452
453 Ok(())
454 }
455}