1use futures_util::FutureExt;
2
3use std::{
4 convert::Infallible, net::SocketAddr, panic::AssertUnwindSafe, pin::Pin, sync::Arc,
5 time::Duration,
6};
7
8use hyper::{
9 Request,
10 body::Incoming,
11 server::conn::{http1, http2},
12 service::service_fn,
13};
14use hyper_util::rt::{TokioExecutor, TokioIo};
15
16use tokio::{
17 io::{AsyncRead, AsyncWrite},
18 net::TcpListener,
19 sync::Semaphore,
20};
21use tokio_rustls::TlsAcceptor;
22
23use crate::{
24 client::Client,
25 error::DefaultErrorHandler,
26 group::Group,
27 headers::LimitReader,
28 http::StatusCode,
29 request::RequestBody,
30 response::ResponseWriter,
31 router::Router,
32 tls::tls_config,
33 types::{BoltError, ErrorHandler, Handler, Method, Middleware, Mode},
34};
35
36pub mod client;
37mod error;
38mod group;
39mod headers;
40pub mod http;
41pub mod macros;
42pub mod request;
43pub mod response;
44mod router;
45mod tls;
46pub mod types;
47pub use async_trait;
48pub use bolt_web_macro::main;
49pub use paste;
50pub use tokio;
51
52trait Io: AsyncRead + AsyncWrite + Unpin {}
53impl<T: AsyncRead + AsyncWrite + Unpin> Io for T {}
54
55#[allow(dead_code)]
56pub struct App {
57 router: Router,
58 error_handler: Arc<dyn ErrorHandler>,
59 client: Client,
60 timeout: u64,
61 connection_limit: u64,
62 read_timeout: u64,
63 header_limit: usize,
64}
65
66#[allow(unused_variables)]
67#[allow(dead_code)]
68impl App {
69 pub fn new() -> Self {
70 Self {
71 router: Router::new(),
72 error_handler: Arc::new(DefaultErrorHandler),
73 client: Client::new(),
74 timeout: 30,
75 connection_limit: 100,
76 read_timeout: 10,
77 header_limit: 32 * 1024,
78 }
79 }
80
81 pub fn set_timeout(&mut self, seconds: u64) {
82 self.timeout = seconds;
83 }
84
85 pub fn set_connection_limit(&mut self, limit: u64) {
86 self.connection_limit = limit;
87 }
88
89 pub fn set_read_timeout(&mut self, seconds: u64) {
90 self.read_timeout = seconds;
91 }
92
93 pub fn set_header_limit(&mut self, bytes: usize) {
94 self.header_limit = bytes;
95 }
96
97 fn add_route<H>(&mut self, method: Method, path: &str, handler: H)
98 where
99 H: Handler + 'static,
100 {
101 self.router.insert(path, method, handler);
102 }
103
104 pub fn get<H>(&mut self, path: &str, handler: H)
105 where
106 H: Handler + 'static,
107 {
108 self.add_route(Method::GET, path, handler);
109 }
110
111 pub fn post<H>(&mut self, path: &str, handler: H)
112 where
113 H: Handler + 'static,
114 {
115 self.add_route(Method::POST, path, handler);
116 }
117
118 pub fn put<H>(&mut self, path: &str, handler: H)
119 where
120 H: Handler + 'static,
121 {
122 self.add_route(Method::PUT, path, handler);
123 }
124
125 pub fn patch<H>(&mut self, path: &str, handler: H)
126 where
127 H: Handler + 'static,
128 {
129 self.add_route(Method::PATCH, path, handler);
130 }
131
132 pub fn delete<H>(&mut self, path: &str, handler: H)
133 where
134 H: Handler + 'static,
135 {
136 self.add_route(Method::DELETE, path, handler);
137 }
138
139 pub fn group<'a>(&'a mut self, path: &str) -> Group<'a> {
140 Group {
141 prefix: path.to_string(),
142 app: self,
143 }
144 }
145
146 pub fn middleware<M>(&mut self, path: &str, method: Option<Method>, middleware_fn: M)
147 where
148 M: Middleware + 'static,
149 {
150 let mw: Arc<M> = Arc::new(middleware_fn);
151 let full_path = path.to_string();
152
153 match method {
154 Some(m) => self.router.insert_middleware(&full_path, m, mw),
155 None => {
156 for m in [
157 Method::GET,
158 Method::POST,
159 Method::PUT,
160 Method::PATCH,
161 Method::DELETE,
162 Method::OPTIONS,
163 Method::HEAD,
164 Method::TRACE,
165 ] {
166 self.router.insert_middleware(&full_path, m, mw.clone());
167 }
168 }
169 }
170 }
171
172 pub fn set_error_handler<E>(&mut self, handler: E)
173 where
174 E: ErrorHandler + 'static,
175 {
176 self.error_handler = Arc::new(handler);
177 }
178
179 pub async fn run(&self, addr: &str, mode: Mode) -> Result<(), BoltError> {
180 println!("⚡ A high performance & minimalist web framework in rust.");
181 println!(
182 r#"
183 __ ____
184 / /_ ____ / / /_
185 / __ \/ __ \/ / __/
186 / /_/ / /_/ / / /_
187/_.___/\____/_/\__/ v0.2.0
188"#
189 );
190
191 println!(">> Server running on http://{}", addr);
192
193 let addr: SocketAddr = addr.parse().unwrap();
194
195 let listener = TcpListener::bind(addr).await?;
196 let router = Arc::new(self.router.clone());
197 let error_handler = self.error_handler.clone();
198 let active = Arc::new(Semaphore::new(self.connection_limit as usize));
199
200 self.server_loop(
201 router,
202 error_handler,
203 listener,
204 mode,
205 None,
206 self.timeout,
207 self.read_timeout,
208 Box::pin(tokio::signal::ctrl_c().map(|_| ())),
209 active,
210 )
211 .await
212 }
213
214 pub async fn run_tls(
215 &self,
216 addr: &str,
217 mode: Mode,
218 tls: Option<(&str, &str)>,
219 ) -> Result<(), BoltError> {
220 println!("⚡ A high performance & minimalist web framework in rust.");
221 println!(
222 "{}",
223 r#"
224 __ ____
225 / /_ ____ / / /_
226 / __ \/ __ \/ / __/
227 / /_/ / /_/ / / /_
228/_.___/\____/_/\__/ v0.2.0
229"#
230 );
231
232 let addr: SocketAddr = addr.parse().unwrap();
233 let listener = TcpListener::bind(addr).await?;
234
235 let tls_acceptor: Option<Arc<TlsAcceptor>> = if let Some((cert, key)) = tls {
236 let cfg = tls_config(cert, key)?;
237 Some(Arc::new(TlsAcceptor::from(cfg)))
238 } else {
239 None
240 };
241
242 println!(
243 ">> Server running on {}://{}",
244 if tls_acceptor.is_some() {
245 "https"
246 } else {
247 "http"
248 },
249 addr
250 );
251
252 let router: Arc<Router> = Arc::new(self.router.clone());
253 let error_handler = self.error_handler.clone();
254 let active = Arc::new(Semaphore::new(self.connection_limit as usize));
255
256 self.server_loop(
257 router,
258 error_handler,
259 listener,
260 mode,
261 tls_acceptor,
262 self.timeout,
263 self.read_timeout,
264 Box::pin(tokio::signal::ctrl_c().map(|_| ())),
265 active,
266 )
267 .await
268 }
269
270 async fn server_loop(
271 &self,
272 router: Arc<Router>,
273 error_handler: Arc<dyn ErrorHandler>,
274 listener: TcpListener,
275 mode: Mode,
276 tls_acceptor: Option<Arc<TlsAcceptor>>,
277 timeout: u64,
278 read_timeout: u64,
279 mut shutdown: Pin<Box<dyn Future<Output = ()> + Send>>,
280 active: Arc<Semaphore>,
281 ) -> Result<(), BoltError> {
282 loop {
283 tokio::select! {
284 _ = &mut shutdown => {
285 println!(">> Shutdown signal received. Stopping server...");
286 break;
287 }
288
289 accept_res = listener.accept() => {
290 let (stream, remote_addr) = match accept_res {
291 Ok(v) => v,
292 Err(e) => {
293 eprintln!("Accept error: {}", e);
294 continue;
295 }
296 };
297
298 let permit = match active.clone().try_acquire_owned() {
299 Ok(p) => p,
300 Err(_) => {
301 eprintln!("Connection limit reached — dropping client");
302 continue;
303 }
304 };
305
306 let io: Box<dyn Io + Send> = if let Some(ref acceptor) = tls_acceptor {
307 match acceptor.accept(stream).await {
308 Ok(c) => Box::new(c),
309 Err(e) => {
310 eprintln!("TLS error: {}", e);
311 continue;
312 }
313 }
314 } else {
315 Box::new(stream)
316 };
317
318 let limited = LimitReader::new(io, self.header_limit);
319 let io = TokioIo::new(limited);
320
321 let router = router.clone();
322 let error_handler = error_handler.clone();
323
324 let service = service_fn(move |req: Request<Incoming>| {
325 let router = router.clone();
326 let error_handler = error_handler.clone();
327 let remote_addr = remote_addr.clone();
328 let timeout = timeout;
329
330 async move {
331 let handler_future = tokio::time::timeout(
332 Duration::from_secs(timeout),
333 async {
334 let inner = AssertUnwindSafe(async move {
335 let mut req_body = RequestBody::new(req, remote_addr);
336 let mut res_body = ResponseWriter::new();
337
338 let method = match *req_body.method() {
339 hyper::Method::GET => Method::GET,
340 hyper::Method::POST => Method::POST,
341 hyper::Method::PUT => Method::PUT,
342 hyper::Method::PATCH => Method::PATCH,
343 hyper::Method::DELETE => Method::DELETE,
344 hyper::Method::OPTIONS => Method::OPTIONS,
345 hyper::Method::HEAD => Method::HEAD,
346 hyper::Method::TRACE => Method::TRACE,
347 _ => {
348 res_body.status(StatusCode::MethodNotAllowed)
349 .send("Method Not Allowed");
350 return res_body;
351 }
352 };
353
354 let path = req_body.path().to_string();
355 for mw in router.collect_middleware(&path, method) {
356 mw.run(&mut req_body, &mut res_body).await;
357 if res_body.has_error() { break; }
358 }
359
360 if !res_body.has_error() {
361 if let Some((handler, params)) = router.find(&path, method) {
362 req_body.set_params(params);
363 handler.run(&mut req_body, &mut res_body).await;
364 } else {
365 res_body.error(
366 StatusCode::NotFound,
367 &format!("Not Found {} {}", req_body.method(), path),
368 );
369 }
370 }
371
372 if res_body.has_error() {
373 let msg = res_body.body.clone();
374 error_handler.run(msg, &mut res_body).await;
375 }
376
377 req_body.cleanup().await;
378 res_body
379 })
380 .catch_unwind()
381 .await;
382
383 match inner {
384 Ok(r) => r,
385 Err(_) => {
386 let mut res = ResponseWriter::new();
387 res.error(StatusCode::InternalServerError, "Internal Server Error");
388 res
389 }
390 }
391 }
392 ).await;
393
394 let res_body = match handler_future {
395 Ok(res) => res,
396 Err(_) => {
397 let mut res = ResponseWriter::new();
398 res.error(StatusCode::RequestTimeout, "Request Timeout");
399 res
400 }
401 };
402
403 Ok::<_, Infallible>(res_body.into_response())
404 }
405 });
406
407 let permit = permit;
408
409 match mode {
410 Mode::Http1 => {
411 tokio::spawn(async move {
412 let _permit = permit;
413
414 let result = tokio::time::timeout(
415 Duration::from_secs(read_timeout),
416 async {
417 http1::Builder::new()
418 .serve_connection(io, service)
419 .await
420 }
421 ).await;
422
423 match result {
424 Ok(Ok(_)) => {}
425 Ok(Err(e)) => eprintln!("Connection error: {}", e),
426 Err(_) => eprintln!("Slowloris: read timeout — closing connection"),
427 }
428 });
429 }
430
431 Mode::Http2 => {
432 tokio::spawn(async move {
433 let _permit = permit;
434
435 let result = tokio::time::timeout(
436 Duration::from_secs(read_timeout),
437 async {
438 http2::Builder::new(TokioExecutor::new())
439 .serve_connection(io, service)
440 .await
441 }
442 ).await;
443
444 match result {
445 Ok(Ok(_)) => {}
446 Ok(Err(e)) => eprintln!("Connection error: {}", e),
447 Err(_) => eprintln!("Slowloris: read timeout — closing connection"),
448 }
449 });
450 }
451 }
452 }
453 }
454 }
455
456 Ok(())
457 }
458}