1use futures_util::FutureExt;
2
3use std::{
4 convert::Infallible, net::SocketAddr, panic::AssertUnwindSafe, pin::Pin, sync::Arc,
5 time::Duration,
6};
7
8use hyper::{
9 Request,
10 body::Incoming,
11 server::conn::{http1, http2},
12 service::service_fn,
13};
14use hyper_util::rt::{TokioExecutor, TokioIo};
15
16use tokio::{
17 io::{AsyncRead, AsyncWrite},
18 net::TcpListener,
19 sync::Semaphore,
20};
21use tokio_rustls::TlsAcceptor;
22
23use crate::{
24 client::Client,
25 error::DefaultErrorHandler,
26 group::Group,
27 headers::LimitReader,
28 http::StatusCode,
29 request::RequestBody,
30 response::ResponseWriter,
31 router::Router,
32 tls::tls_config,
33 types::{BoltError, ErrorHandler, Handler, Method, Middleware, Mode},
34};
35
36pub mod client;
37mod error;
38mod group;
39mod headers;
40pub mod http;
41pub mod macros;
42pub mod request;
43pub mod response;
44mod router;
45mod tls;
46pub mod types;
47pub use async_trait;
48pub use bolt_web_macro::main;
49pub use paste;
50pub use tokio;
51
52trait Io: AsyncRead + AsyncWrite + Unpin {}
53impl<T: AsyncRead + AsyncWrite + Unpin> Io for T {}
54
55#[allow(dead_code)]
56pub struct App {
57 router: Router,
58 error_handler: Arc<dyn ErrorHandler>,
59 client: Client,
60 timeout: u64,
61 connection_limit: u64,
62 header_limit: usize,
63}
64
65#[allow(unused_variables)]
66#[allow(dead_code)]
67impl App {
68 pub fn new() -> Self {
69 Self {
70 router: Router::new(),
71 error_handler: Arc::new(DefaultErrorHandler),
72 client: Client::new(),
73 timeout: 30,
74 connection_limit: 100,
75 header_limit: 32 * 1024,
76 }
77 }
78
79 pub fn set_timeout(&mut self, seconds: u64) {
80 self.timeout = seconds;
81 }
82
83 pub fn set_connection_limit(&mut self, limit: u64) {
84 self.connection_limit = limit;
85 }
86
87 pub fn set_header_limit(&mut self, bytes: usize) {
88 self.header_limit = bytes;
89 }
90
91 fn add_route<H>(&mut self, method: Method, path: &str, handler: H)
92 where
93 H: Handler + 'static,
94 {
95 self.router.insert(path, method, handler);
96 }
97
98 pub fn get<H>(&mut self, path: &str, handler: H)
99 where
100 H: Handler + 'static,
101 {
102 self.add_route(Method::GET, path, handler);
103 }
104
105 pub fn post<H>(&mut self, path: &str, handler: H)
106 where
107 H: Handler + 'static,
108 {
109 self.add_route(Method::POST, path, handler);
110 }
111
112 pub fn put<H>(&mut self, path: &str, handler: H)
113 where
114 H: Handler + 'static,
115 {
116 self.add_route(Method::PUT, path, handler);
117 }
118
119 pub fn patch<H>(&mut self, path: &str, handler: H)
120 where
121 H: Handler + 'static,
122 {
123 self.add_route(Method::PATCH, path, handler);
124 }
125
126 pub fn delete<H>(&mut self, path: &str, handler: H)
127 where
128 H: Handler + 'static,
129 {
130 self.add_route(Method::DELETE, path, handler);
131 }
132
133 pub fn group<'a>(&'a mut self, path: &str) -> Group<'a> {
134 Group {
135 prefix: path.to_string(),
136 app: self,
137 }
138 }
139
140 pub fn middleware<M>(&mut self, path: &str, method: Option<Method>, middleware_fn: M)
141 where
142 M: Middleware + 'static,
143 {
144 let mw: Arc<M> = Arc::new(middleware_fn);
145 let full_path = path.to_string();
146
147 match method {
148 Some(m) => self.router.insert_middleware(&full_path, m, mw),
149 None => {
150 for m in [
151 Method::GET,
152 Method::POST,
153 Method::PUT,
154 Method::PATCH,
155 Method::DELETE,
156 Method::OPTIONS,
157 Method::HEAD,
158 Method::TRACE,
159 ] {
160 self.router.insert_middleware(&full_path, m, mw.clone());
161 }
162 }
163 }
164 }
165
166 pub fn set_error_handler<E>(&mut self, handler: E)
167 where
168 E: ErrorHandler + 'static,
169 {
170 self.error_handler = Arc::new(handler);
171 }
172
173 pub async fn run(&self, addr: &str, mode: Mode) -> Result<(), BoltError> {
174 println!("⚡ A high performance & minimalist web framework in rust.");
175 println!(
176 r#"
177 __ ____
178 / /_ ____ / / /_
179 / __ \/ __ \/ / __/
180 / /_/ / /_/ / / /_
181/_.___/\____/_/\__/ v0.3
182"#
183 );
184
185 println!(">> Server running on http://{}", addr);
186
187 let addr: SocketAddr = addr.parse().unwrap();
188
189 let listener = TcpListener::bind(addr).await?;
190 let router = Arc::new(self.router.clone());
191 let error_handler = self.error_handler.clone();
192 let active = Arc::new(Semaphore::new(self.connection_limit as usize));
193
194 self.server_loop(
195 router,
196 error_handler,
197 listener,
198 mode,
199 None,
200 self.timeout,
201 Box::pin(tokio::signal::ctrl_c().map(|_| ())),
202 active,
203 )
204 .await
205 }
206
207 pub async fn run_tls(
208 &self,
209 addr: &str,
210 mode: Mode,
211 tls: Option<(&str, &str)>,
212 ) -> Result<(), BoltError> {
213 println!("⚡ A high performance & minimalist web framework in rust.");
214 println!(
215 "{}",
216 r#"
217 __ ____
218 / /_ ____ / / /_
219 / __ \/ __ \/ / __/
220 / /_/ / /_/ / / /_
221/_.___/\____/_/\__/ v0.3
222"#
223 );
224
225 let addr: SocketAddr = addr.parse().unwrap();
226 let listener = TcpListener::bind(addr).await?;
227
228 let tls_acceptor: Option<Arc<TlsAcceptor>> = if let Some((cert, key)) = tls {
229 let cfg = tls_config(cert, key)?;
230 Some(Arc::new(TlsAcceptor::from(cfg)))
231 } else {
232 None
233 };
234
235 println!(
236 ">> Server running on {}://{}",
237 if tls_acceptor.is_some() {
238 "https"
239 } else {
240 "http"
241 },
242 addr
243 );
244
245 let router: Arc<Router> = Arc::new(self.router.clone());
246 let error_handler = self.error_handler.clone();
247 let active = Arc::new(Semaphore::new(self.connection_limit as usize));
248
249 self.server_loop(
250 router,
251 error_handler,
252 listener,
253 mode,
254 tls_acceptor,
255 self.timeout,
256 Box::pin(tokio::signal::ctrl_c().map(|_| ())),
257 active,
258 )
259 .await
260 }
261
262 async fn server_loop(
263 &self,
264 router: Arc<Router>,
265 error_handler: Arc<dyn ErrorHandler>,
266 listener: TcpListener,
267 mode: Mode,
268 tls_acceptor: Option<Arc<TlsAcceptor>>,
269 timeout: u64,
270 mut shutdown: Pin<Box<dyn Future<Output = ()> + Send>>,
271 active: Arc<Semaphore>,
272 ) -> Result<(), BoltError> {
273 loop {
274 tokio::select! {
275 _ = &mut shutdown => {
276 println!(">> Shutdown signal received. Stopping server...");
277 break;
278 }
279
280 accept_res = listener.accept() => {
281 let (stream, remote_addr) = match accept_res {
282 Ok(v) => v,
283 Err(e) => {
284 eprintln!("Accept error: {}", e);
285 continue;
286 }
287 };
288
289 let permit = match active.clone().try_acquire_owned() {
290 Ok(p) => p,
291 Err(_) => {
292 eprintln!("Connection limit reached — dropping client");
293 continue;
294 }
295 };
296
297 let io: Box<dyn Io + Send> = if let Some(ref acceptor) = tls_acceptor {
298 match acceptor.accept(stream).await {
299 Ok(c) => Box::new(c),
300 Err(e) => {
301 eprintln!("TLS error: {}", e);
302 continue;
303 }
304 }
305 } else {
306 Box::new(stream)
307 };
308
309 let limited = LimitReader::new(io, self.header_limit);
310 let io = TokioIo::new(limited);
311
312 let router = router.clone();
313 let error_handler = error_handler.clone();
314
315 let service = service_fn(move |req: Request<Incoming>| {
316 let router = router.clone();
317 let error_handler = error_handler.clone();
318 let remote_addr = remote_addr.clone();
319 let timeout = timeout;
320
321 async move {
322 let handler_future = tokio::time::timeout(
323 Duration::from_secs(timeout),
324 async {
325 let inner = AssertUnwindSafe(async move {
326 let mut req_body = RequestBody::new(req, remote_addr);
327 let mut res_body = ResponseWriter::new();
328
329 let method = match *req_body.method() {
330 hyper::Method::GET => Method::GET,
331 hyper::Method::POST => Method::POST,
332 hyper::Method::PUT => Method::PUT,
333 hyper::Method::PATCH => Method::PATCH,
334 hyper::Method::DELETE => Method::DELETE,
335 hyper::Method::OPTIONS => Method::OPTIONS,
336 hyper::Method::HEAD => Method::HEAD,
337 hyper::Method::TRACE => Method::TRACE,
338 _ => {
339 res_body.status(StatusCode::MethodNotAllowed)
340 .send("Method Not Allowed");
341 return res_body;
342 }
343 };
344
345 let path = req_body.path().to_string();
346 for mw in router.collect_middleware(&path, method) {
347 mw.run(&mut req_body, &mut res_body).await;
348 if res_body.has_error() { break; }
349 }
350
351 if !res_body.has_error() {
352 if let Some((handler, params)) = router.find(&path, method) {
353 req_body.set_params(params);
354 handler.run(&mut req_body, &mut res_body).await;
355 } else {
356 res_body.error(
357 StatusCode::NotFound,
358 &format!("Not Found {} {}", req_body.method(), path),
359 );
360 }
361 }
362
363 if res_body.has_error() {
364 let msg = res_body.body.clone();
365 error_handler.run(msg, &mut res_body).await;
366 }
367
368 req_body.cleanup().await;
369 res_body
370 })
371 .catch_unwind()
372 .await;
373
374 match inner {
375 Ok(r) => r,
376 Err(_) => {
377 let mut res = ResponseWriter::new();
378 res.error(StatusCode::InternalServerError, "Internal Server Error");
379 res
380 }
381 }
382 }
383 ).await;
384
385 let res_body = match handler_future {
386 Ok(res) => res,
387 Err(_) => {
388 let mut res = ResponseWriter::new();
389 res.error(StatusCode::RequestTimeout, "Request Timeout");
390 res
391 }
392 };
393
394 Ok::<_, Infallible>(res_body.into_response())
395 }
396 });
397
398 let permit = permit;
399
400 match mode {
401 Mode::Http1 => {
402 tokio::spawn(async move {
403 let _permit = permit;
404
405 if let Err(e) = http1::Builder::new()
406 .serve_connection(io, service)
407 .await
408 {
409 eprintln!("Connection error: {}", e);
410 }
411 });
412 }
413
414 Mode::Http2 => {
415 tokio::spawn(async move {
416 let _permit = permit;
417
418 if let Err(e) = http2::Builder::new(TokioExecutor::new())
419 .serve_connection(io, service)
420 .await
421 {
422 eprintln!("Connection error: {}", e);
423 }
424 });
425 }
426 }
427 }
428 }
429 }
430
431 Ok(())
432 }
433}