1use futures_util::FutureExt;
2
3use std::{
4 convert::Infallible, net::SocketAddr, panic::AssertUnwindSafe, pin::Pin, sync::Arc,
5 time::Duration,
6};
7
8use hyper::{
9 Request,
10 body::Incoming,
11 server::conn::{http1, http2},
12 service::service_fn,
13};
14use hyper_util::rt::{TokioExecutor, TokioIo};
15
16use tokio::{
17 io::{AsyncRead, AsyncWrite},
18 net::TcpListener,
19 sync::Semaphore,
20};
21use tokio_rustls::TlsAcceptor;
22
23use crate::{
24 client::Client,
25 error::DefaultErrorHandler,
26 group::Group,
27 headers::LimitReader,
28 http::StatusCode,
29 request::RequestBody,
30 response::ResponseWriter,
31 router::Router,
32 tls::tls_config,
33 types::{BoltError, ErrorHandler, Handler, Method, Middleware, Mode},
34};
35
36pub mod client;
37mod error;
38mod group;
39mod headers;
40pub mod http;
41pub mod macros;
42pub mod request;
43pub mod response;
44mod router;
45mod tls;
46pub mod types;
47pub use bolt_web_macro::main;
48pub use paste;
49pub use tokio;
50
51trait Io: AsyncRead + AsyncWrite + Unpin {}
52impl<T: AsyncRead + AsyncWrite + Unpin> Io for T {}
53
54#[allow(dead_code)]
55pub struct App {
56 router: Router,
57 error_handler: Arc<dyn ErrorHandler>,
58 client: Client,
59 timeout: u64,
60 connection_limit: u64,
61 read_timeout: u64,
62 header_limit: usize,
63}
64
65#[allow(unused_variables)]
66#[allow(dead_code)]
67impl App {
68 pub fn new() -> Self {
69 Self {
70 router: Router::new(),
71 error_handler: Arc::new(DefaultErrorHandler),
72 client: Client::new(),
73 timeout: 30,
74 connection_limit: 100,
75 read_timeout: 10,
76 header_limit: 32 * 1024,
77 }
78 }
79
80 pub fn set_timeout(&mut self, seconds: u64) {
81 self.timeout = seconds;
82 }
83
84 pub fn set_connection_limit(&mut self, limit: u64) {
85 self.connection_limit = limit;
86 }
87
88 pub fn set_read_timeout(&mut self, seconds: u64) {
89 self.read_timeout = seconds;
90 }
91
92 pub fn set_header_limit(&mut self, bytes: usize) {
93 self.header_limit = bytes;
94 }
95
96 fn add_route<H>(&mut self, method: Method, path: &str, handler: H)
97 where
98 H: Handler + 'static,
99 {
100 self.router.insert(path, method, handler);
101 }
102
103 pub fn get<H>(&mut self, path: &str, handler: H)
104 where
105 H: Handler + 'static,
106 {
107 self.add_route(Method::GET, path, handler);
108 }
109
110 pub fn post<H>(&mut self, path: &str, handler: H)
111 where
112 H: Handler + 'static,
113 {
114 self.add_route(Method::POST, path, handler);
115 }
116
117 pub fn put<H>(&mut self, path: &str, handler: H)
118 where
119 H: Handler + 'static,
120 {
121 self.add_route(Method::PUT, path, handler);
122 }
123
124 pub fn patch<H>(&mut self, path: &str, handler: H)
125 where
126 H: Handler + 'static,
127 {
128 self.add_route(Method::PATCH, path, handler);
129 }
130
131 pub fn delete<H>(&mut self, path: &str, handler: H)
132 where
133 H: Handler + 'static,
134 {
135 self.add_route(Method::DELETE, path, handler);
136 }
137
138 pub fn group<'a>(&'a mut self, path: &str) -> Group<'a> {
139 Group {
140 prefix: path.to_string(),
141 app: self,
142 }
143 }
144
145 pub fn middleware<M>(&mut self, path: &str, method: Option<Method>, middleware_fn: M)
146 where
147 M: Middleware + 'static,
148 {
149 let mw: Arc<M> = Arc::new(middleware_fn);
150 let full_path = path.to_string();
151
152 match method {
153 Some(m) => self.router.insert_middleware(&full_path, m, mw),
154 None => {
155 for m in [
156 Method::GET,
157 Method::POST,
158 Method::PUT,
159 Method::PATCH,
160 Method::DELETE,
161 Method::OPTIONS,
162 Method::HEAD,
163 Method::TRACE,
164 ] {
165 self.router.insert_middleware(&full_path, m, mw.clone());
166 }
167 }
168 }
169 }
170
171 pub fn set_error_handler<E>(&mut self, handler: E)
172 where
173 E: ErrorHandler + 'static,
174 {
175 self.error_handler = Arc::new(handler);
176 }
177
178 pub async fn run(&self, addr: &str, mode: Mode) -> Result<(), BoltError> {
179 println!("⚡ A high performance & minimalist web framework in rust.");
180 println!(
181 r#"
182 __ ____
183 / /_ ____ / / /_
184 / __ \/ __ \/ / __/
185 / /_/ / /_/ / / /_
186/_.___/\____/_/\__/ v0.2.0
187"#
188 );
189
190 println!(">> Server running on http://{}", addr);
191
192 let addr: SocketAddr = addr.parse().unwrap();
193
194 let listener = TcpListener::bind(addr).await?;
195 let router = Arc::new(self.router.clone());
196 let error_handler = self.error_handler.clone();
197 let active = Arc::new(Semaphore::new(self.connection_limit as usize));
198
199 self.server_loop(
200 router,
201 error_handler,
202 listener,
203 mode,
204 None,
205 self.timeout,
206 self.read_timeout,
207 Box::pin(tokio::signal::ctrl_c().map(|_| ())),
208 active,
209 )
210 .await
211 }
212
213 pub async fn run_tls(
214 &self,
215 addr: &str,
216 mode: Mode,
217 tls: Option<(&str, &str)>,
218 ) -> Result<(), BoltError> {
219 println!("⚡ A high performance & minimalist web framework in rust.");
220 println!(
221 "{}",
222 r#"
223 __ ____
224 / /_ ____ / / /_
225 / __ \/ __ \/ / __/
226 / /_/ / /_/ / / /_
227/_.___/\____/_/\__/ v0.2.0
228"#
229 );
230
231 let addr: SocketAddr = addr.parse().unwrap();
232 let listener = TcpListener::bind(addr).await?;
233
234 let tls_acceptor: Option<Arc<TlsAcceptor>> = if let Some((cert, key)) = tls {
235 let cfg = tls_config(cert, key)?;
236 Some(Arc::new(TlsAcceptor::from(cfg)))
237 } else {
238 None
239 };
240
241 println!(
242 ">> Server running on {}://{}",
243 if tls_acceptor.is_some() {
244 "https"
245 } else {
246 "http"
247 },
248 addr
249 );
250
251 let router: Arc<Router> = Arc::new(self.router.clone());
252 let error_handler = self.error_handler.clone();
253 let active = Arc::new(Semaphore::new(self.connection_limit as usize));
254
255 self.server_loop(
256 router,
257 error_handler,
258 listener,
259 mode,
260 tls_acceptor,
261 self.timeout,
262 self.read_timeout,
263 Box::pin(tokio::signal::ctrl_c().map(|_| ())),
264 active,
265 )
266 .await
267 }
268
269 async fn server_loop(
270 &self,
271 router: Arc<Router>,
272 error_handler: Arc<dyn ErrorHandler>,
273 listener: TcpListener,
274 mode: Mode,
275 tls_acceptor: Option<Arc<TlsAcceptor>>,
276 timeout: u64,
277 read_timeout: u64,
278 mut shutdown: Pin<Box<dyn Future<Output = ()> + Send>>,
279 active: Arc<Semaphore>,
280 ) -> Result<(), BoltError> {
281 loop {
282 tokio::select! {
283 _ = &mut shutdown => {
284 println!(">> Shutdown signal received. Stopping server...");
285 break;
286 }
287
288 accept_res = listener.accept() => {
289 let (stream, remote_addr) = match accept_res {
290 Ok(v) => v,
291 Err(e) => {
292 eprintln!("Accept error: {}", e);
293 continue;
294 }
295 };
296
297 let permit = match active.clone().try_acquire_owned() {
298 Ok(p) => p,
299 Err(_) => {
300 eprintln!("Connection limit reached — dropping client");
301 continue;
302 }
303 };
304
305 let io: Box<dyn Io + Send> = if let Some(ref acceptor) = tls_acceptor {
306 match acceptor.accept(stream).await {
307 Ok(c) => Box::new(c),
308 Err(e) => {
309 eprintln!("TLS error: {}", e);
310 continue;
311 }
312 }
313 } else {
314 Box::new(stream)
315 };
316
317 let limited = LimitReader::new(io, self.header_limit);
318 let io = TokioIo::new(limited);
319
320 let router = router.clone();
321 let error_handler = error_handler.clone();
322
323 let service = service_fn(move |req: Request<Incoming>| {
324 let router = router.clone();
325 let error_handler = error_handler.clone();
326 let remote_addr = remote_addr.clone();
327 let timeout = timeout;
328
329 async move {
330 let handler_future = tokio::time::timeout(
331 Duration::from_secs(timeout),
332 async {
333 let inner = AssertUnwindSafe(async move {
334 let mut req_body = RequestBody::new(req, remote_addr);
335 let mut res_body = ResponseWriter::new();
336
337 let method = match *req_body.method() {
338 hyper::Method::GET => Method::GET,
339 hyper::Method::POST => Method::POST,
340 hyper::Method::PUT => Method::PUT,
341 hyper::Method::PATCH => Method::PATCH,
342 hyper::Method::DELETE => Method::DELETE,
343 hyper::Method::OPTIONS => Method::OPTIONS,
344 hyper::Method::HEAD => Method::HEAD,
345 hyper::Method::TRACE => Method::TRACE,
346 _ => {
347 res_body.status(StatusCode::MethodNotAllowed)
348 .send("Method Not Allowed");
349 return res_body;
350 }
351 };
352
353 let path = req_body.path().to_string();
354 for mw in router.collect_middleware(&path, method) {
355 mw.run(&mut req_body, &mut res_body).await;
356 if res_body.has_error() { break; }
357 }
358
359 if !res_body.has_error() {
360 if let Some((handler, params)) = router.find(&path, method) {
361 req_body.set_params(params);
362 handler.run(&mut req_body, &mut res_body).await;
363 } else {
364 res_body.error(
365 StatusCode::NotFound,
366 &format!("Not Found {} {}", req_body.method(), path),
367 );
368 }
369 }
370
371 if res_body.has_error() {
372 let msg = res_body.body.clone();
373 error_handler.run(msg, &mut res_body).await;
374 }
375
376 req_body.cleanup().await;
377 res_body
378 })
379 .catch_unwind()
380 .await;
381
382 match inner {
383 Ok(r) => r,
384 Err(_) => {
385 let mut res = ResponseWriter::new();
386 res.error(StatusCode::InternalServerError, "Internal Server Error");
387 res
388 }
389 }
390 }
391 ).await;
392
393 let res_body = match handler_future {
394 Ok(res) => res,
395 Err(_) => {
396 let mut res = ResponseWriter::new();
397 res.error(StatusCode::RequestTimeout, "Request Timeout");
398 res
399 }
400 };
401
402 Ok::<_, Infallible>(res_body.into_response())
403 }
404 });
405
406 let permit = permit;
407
408 match mode {
409 Mode::Http1 => {
410 tokio::spawn(async move {
411 let _permit = permit;
412
413 let result = tokio::time::timeout(
414 Duration::from_secs(read_timeout),
415 async {
416 http1::Builder::new()
417 .serve_connection(io, service)
418 .await
419 }
420 ).await;
421
422 match result {
423 Ok(Ok(_)) => {}
424 Ok(Err(e)) => eprintln!("Connection error: {}", e),
425 Err(_) => eprintln!("Slowloris: read timeout — closing connection"),
426 }
427 });
428 }
429
430 Mode::Http2 => {
431 tokio::spawn(async move {
432 let _permit = permit;
433
434 let result = tokio::time::timeout(
435 Duration::from_secs(read_timeout),
436 async {
437 http2::Builder::new(TokioExecutor::new())
438 .serve_connection(io, service)
439 .await
440 }
441 ).await;
442
443 match result {
444 Ok(Ok(_)) => {}
445 Ok(Err(e)) => eprintln!("Connection error: {}", e),
446 Err(_) => eprintln!("Slowloris: read timeout — closing connection"),
447 }
448 });
449 }
450 }
451 }
452 }
453 }
454
455 Ok(())
456 }
457}