1use self::pipeline::{Pipeline, PipelineBuilder};
4use hyper_util::{rt::TokioIo, server::graceful::GracefulShutdown};
5use std::net::IpAddr;
6
7use crate::{
8 http::request::request_body_limit::RequestBodyLimit,
9 server::Server
10};
11
12use std::{
13 future::Future,
14 io::Error,
15 net::SocketAddr,
16 sync::{Arc, Weak}
17};
18
19use tokio::{
20 io::self,
21 net::{TcpListener, TcpStream},
22 signal,
23 sync::watch
24};
25
26
27#[cfg(feature = "di")]
28use crate::di::{Container, ContainerBuilder};
29
30#[cfg(feature = "tls")]
31use tokio_rustls::TlsAcceptor;
32
33#[cfg(feature = "tls")]
34use crate::tls::TlsConfig;
35
36#[cfg(feature = "tracing")]
37use crate::tracing::TracingConfig;
38
39#[cfg(feature = "middleware")]
40use crate::http::CorsConfig;
41
42#[cfg(feature = "jwt-auth")]
43use crate::auth::bearer::{BearerAuthConfig, BearerTokenService};
44
45#[cfg(feature = "static-files")]
46pub use self::env::HostEnv;
47
48#[cfg(feature = "static-files")]
49pub mod env;
50pub mod router;
51pub(crate) mod pipeline;
52pub(crate) mod scope;
53
54pub(super) const GRACEFUL_SHUTDOWN_TIMEOUT: u64 = 10;
55const DEFAULT_PORT: u16 = 7878;
56
/// Builder and entry point for a Volga web server.
///
/// Configure it with the builder-style methods (e.g. `bind`, `with_body_limit`)
/// and start it with `run` / `run_blocking`. Fields gated behind a cargo
/// feature only exist when that feature is enabled.
#[derive(Debug)]
pub struct App {
    /// Dependency-injection container builder; `build()` is called on it when
    /// the app is converted into an `AppInstance`.
    #[cfg(feature = "di")]
    pub(super) container: ContainerBuilder,

    /// Optional TLS configuration. When set, the startup URL is reported as
    /// `https://` and a `TlsAcceptor` is built from it at startup.
    #[cfg(feature = "tls")]
    pub(super) tls_config: Option<TlsConfig>,

    /// Optional tracing configuration (not read in this section; presumably
    /// consumed elsewhere in the crate — confirm).
    #[cfg(feature = "tracing")]
    pub(super) tracing_config: Option<TracingConfig>,

    /// Optional CORS configuration (not read in this section; presumably
    /// consumed by the middleware setup — confirm).
    #[cfg(feature = "middleware")]
    pub(super) cors_config: Option<CorsConfig>,

    /// Host environment, moved unchanged into the running `AppInstance`.
    #[cfg(feature = "static-files")]
    pub(super) host_env: HostEnv,

    /// Optional bearer-auth configuration, converted into a
    /// `BearerTokenService` at startup.
    #[cfg(feature = "jwt-auth")]
    pub(super) auth_config: Option<BearerAuthConfig>,

    /// Request pipeline under construction; built into a `Pipeline` at
    /// startup and used to list the registered routes in the greeter.
    pub(super) pipeline: PipelineBuilder,

    /// Socket address the TCP listener binds to.
    connection: Connection,

    /// Request body limit passed on to the running instance.
    body_limit: RequestBodyLimit,

    /// Whether `TCP_NODELAY` is set on each accepted connection.
    no_delay: bool,

    /// Whether the startup banner is printed (debug builds only).
    show_greeter: bool,
}
133
/// The socket address an `App` listens on.
///
/// Create one from a string such as `"127.0.0.1:5000"`, from an `(ip, port)`
/// tuple, or via `Connection::default()`. It is a thin wrapper around a
/// [`SocketAddr`], so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Connection {
    socket: SocketAddr
}
139
140impl Default for Connection {
141 fn default() -> Self {
142 #[cfg(target_os = "windows")]
143 let ip = [127, 0, 0, 1];
144 #[cfg(not(target_os = "windows"))]
145 let ip = [0, 0, 0, 0];
146 let socket = (ip, DEFAULT_PORT).into();
147 Self { socket }
148 }
149}
150
151impl From<&str> for Connection {
152 fn from(s: &str) -> Self {
153 if let Ok(socket) = s.parse::<SocketAddr>() {
154 Self { socket }
155 } else {
156 Self::default()
157 }
158 }
159}
160
161impl<I: Into<IpAddr>> From<(I, u16)> for Connection {
162 fn from(value: (I, u16)) -> Self {
163 Self { socket: SocketAddr::from(value) }
164 }
165}
166
/// Runtime form of an `App`, produced via `TryFrom<App>` when the server starts.
///
/// The accept loop wraps it in an `Arc` and gives each connection task only a
/// `Weak` reference to it.
pub(crate) struct AppInstance {
    /// TLS acceptor built from the app's TLS config; `None` when no TLS config
    /// was set.
    #[cfg(feature = "tls")]
    pub(super) acceptor: Option<TlsAcceptor>,

    /// Dependency-injection container built from the app's `ContainerBuilder`.
    #[cfg(feature = "di")]
    container: Container,

    /// Host environment carried over unchanged from the app.
    #[cfg(feature = "static-files")]
    pub(super) host_env: HostEnv,

    /// Bearer token service converted from the app's auth config, if one was set.
    #[cfg(feature = "jwt-auth")]
    pub(super) bearer_token_service: Option<BearerTokenService>,

    /// Graceful-shutdown handle; awaited (bounded by `GRACEFUL_SHUTDOWN_TIMEOUT`)
    /// in `AppInstance::shutdown`.
    pub(super) graceful_shutdown: GracefulShutdown,

    /// Request body limit copied from the app.
    pub(super) body_limit: RequestBodyLimit,

    /// Request pipeline built from the app's `PipelineBuilder`.
    pipeline: Pipeline,
}
194
195impl TryFrom<App> for AppInstance {
196 type Error = Error;
197
198 fn try_from(app: App) -> Result<Self, Self::Error> {
199 #[cfg(feature = "tls")]
200 let acceptor = {
201 let tls_config = app.tls_config
202 .map(|config| config.build())
203 .transpose()?;
204 tls_config
205 .map(|config| TlsAcceptor::from(Arc::new(config)))
206 };
207 #[cfg(feature = "jwt-auth")]
208 let bearer_token_service = app.auth_config.map(Into::into);
209
210 let app_instance = Self {
211 body_limit: app.body_limit,
212 pipeline: app.pipeline.build(),
213 graceful_shutdown: GracefulShutdown::new(),
214 #[cfg(feature = "static-files")]
215 host_env: app.host_env,
216 #[cfg(feature = "di")]
217 container: app.container.build(),
218 #[cfg(feature = "jwt-auth")]
219 bearer_token_service,
220 #[cfg(feature = "tls")]
221 acceptor
222 };
223 Ok(app_instance)
224 }
225}
226
227impl AppInstance {
228 #[inline]
230 async fn shutdown(self) {
231 tokio::select! {
232 _ = self.graceful_shutdown.shutdown() => {
233 #[cfg(feature = "tracing")]
234 tracing::info!("shutting down the server...");
235 },
236 _ = tokio::time::sleep(std::time::Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT)) => {
237 #[cfg(feature = "tracing")]
238 tracing::warn!("timed out wait for all connections to close");
239 }
240 }
241 }
242}
243
244impl Default for App {
245 fn default() -> Self {
246 Self::new()
247 }
248}
249
250impl App {
252 pub fn new() -> Self {
261 Self {
262 #[cfg(feature = "di")]
263 container: ContainerBuilder::new(),
264 #[cfg(feature = "tls")]
265 tls_config: None,
266 #[cfg(feature = "tracing")]
267 tracing_config: None,
268 #[cfg(feature = "middleware")]
269 cors_config: None,
270 #[cfg(feature = "static-files")]
271 host_env: HostEnv::default(),
272 #[cfg(feature = "jwt-auth")]
273 auth_config: None,
274 pipeline: PipelineBuilder::new(),
275 connection: Default::default(),
276 body_limit: Default::default(),
277 no_delay: false,
278 #[cfg(debug_assertions)]
279 show_greeter: true,
280 #[cfg(not(debug_assertions))]
281 show_greeter: false,
282 }
283 }
284
285 pub fn bind<S: Into<Connection>>(mut self, socket: S) -> Self {
295 self.connection = socket.into();
296 self
297 }
298
299 pub fn with_body_limit(mut self, limit: usize) -> Self {
303 self.body_limit = RequestBodyLimit::Enabled(limit);
304 self
305 }
306
307 pub fn without_body_limit(mut self) -> Self {
309 self.body_limit = RequestBodyLimit::Disabled;
310 self
311 }
312
313 pub fn with_no_delay(mut self) -> Self {
321 self.no_delay = true;
322 self
323 }
324
325 pub fn without_greeter(mut self) -> Self {
329 self.show_greeter = false;
330 self
331 }
332
333 pub fn run_blocking(self) {
350 if tokio::runtime::Handle::try_current().is_ok() {
351 panic!("`App::run_blocking()` cannot be called inside an existing Tokio runtime. Use `run().await` instead.");
352 }
353
354 let runtime = match tokio::runtime::Builder::new_multi_thread()
355 .enable_all()
356 .build()
357 {
358 Ok(rt) => rt,
359 Err(err) => {
360 #[cfg(feature = "tracing")]
361 tracing::error!("failed to start the runtime: {err:#}");
362 #[cfg(not(feature = "tracing"))]
363 eprintln!("failed to start the runtime: {err:#}");
364 return;
365 }
366 };
367
368 runtime.block_on(async {
369 if let Err(err) = self.run().await {
370 #[cfg(feature = "tracing")]
371 tracing::error!("failed to run the server: {err:#}");
372 #[cfg(not(feature = "tracing"))]
373 eprintln!("failed to run the server: {err:#}");
374 }
375 });
376 }
377
378 #[cfg(feature = "middleware")]
401 pub fn run(mut self) -> impl Future<Output = io::Result<()>> {
402 self.use_endpoints();
403 self.run_internal()
404 }
405
406 #[cfg(not(feature = "middleware"))]
429 pub fn run(self) -> impl Future<Output = io::Result<()>> {
430 self.run_internal()
431 }
432
433 #[inline]
434 async fn run_internal(self) -> io::Result<()> {
435 let socket = self.connection.socket;
436 let no_delay = self.no_delay;
437 let tcp_listener = TcpListener::bind(socket).await?;
438
439 #[cfg(debug_assertions)]
440 self.print_welcome();
441
442 #[cfg(feature = "tracing")]
443 {
444 #[cfg(feature = "tls")]
445 if self.tls_config.is_some() {
446 tracing::info!("listening on: https://{socket}")
447 } else {
448 tracing::info!("listening on: http://{socket}")
449 };
450 #[cfg(not(feature = "tls"))]
451 tracing::info!("listening on: http://{socket}");
452 }
453
454 let (shutdown_tx, shutdown_rx) = watch::channel::<()>(());
455 let shutdown_tx = Arc::new(shutdown_tx);
456 Self::shutdown_signal(shutdown_rx);
457
458 #[cfg(feature = "tls")]
459 let redirection_config = self.tls_config
460 .as_ref()
461 .map(|config| config.https_redirection_config);
462
463 let app_instance: Arc<AppInstance> = Arc::new(self.try_into()?);
464
465 #[cfg(feature = "tls")]
466 if let Some(redirection_config) = redirection_config
467 && redirection_config.enabled {
468 Self::run_https_redirection_middleware(
469 socket,
470 redirection_config.http_port,
471 shutdown_tx.clone());
472 }
473
474 loop {
475 let (stream, _) = tokio::select! {
476 Ok(connection) = tcp_listener.accept() => connection,
477 _ = shutdown_tx.closed() => break,
478 };
479 if let Err(_err) = stream.set_nodelay(no_delay) {
480 #[cfg(feature = "tracing")]
481 tracing::warn!("failed to set TCP_NODELAY on incoming connection: {_err:#}");
482 }
483 let instance = Arc::downgrade(&app_instance);
484 tokio::spawn(Self::handle_connection(stream, instance));
485 }
486
487 drop(tcp_listener);
488
489 if let Some(app_instance) = Arc::into_inner(app_instance) {
490 app_instance.shutdown().await;
491 }
492 Ok(())
493 }
494
495 #[inline]
496 fn shutdown_signal(shutdown_rx: watch::Receiver<()>) {
497 tokio::spawn(async move {
498 match signal::ctrl_c().await {
499 Ok(_) => (),
500 #[cfg(feature = "tracing")]
501 Err(err) => tracing::error!("unable to listen for shutdown signal: {err:#}"),
502 #[cfg(not(feature = "tracing"))]
503 Err(_) => ()
504 }
505 #[cfg(feature = "tracing")]
506 tracing::trace!("shutdown signal received, not accepting new requests");
507 drop(shutdown_rx);
508 });
509 }
510
511 #[inline]
512 async fn handle_connection(stream: TcpStream, app_instance: Weak<AppInstance>) {
513 #[cfg(not(feature = "tls"))]
514 Server::new(TokioIo::new(stream)).serve(app_instance).await;
515
516 #[cfg(feature = "tls")]
517 if let Some(acceptor) = app_instance.upgrade().and_then(|app| app.acceptor()) {
518 let stream = match acceptor.accept(stream).await {
519 Ok(tls_stream) => tls_stream,
520 Err(_err) => {
521 #[cfg(feature = "tracing")]
522 tracing::error!("failed to perform tls handshake: {_err:#}");
523 return;
524 }
525 };
526 let io = TokioIo::new(stream);
527 Server::new(io).serve(app_instance).await;
528 } else {
529 let io = TokioIo::new(stream);
530 Server::new(io).serve(app_instance).await;
531 };
532 }
533
534 #[cfg(debug_assertions)]
535 fn print_welcome(&self) {
536 if !self.show_greeter {
537 return;
538 }
539
540 let version = env!("CARGO_PKG_VERSION");
541 let addr = self.connection.socket;
542
543 #[cfg(not(feature = "tls"))]
544 let url = format!("http://{addr}");
545 #[cfg(feature = "tls")]
546 let url = if self.tls_config.is_some() {
547 format!("https://{addr}")
548 } else {
549 format!("http://{addr}")
550 };
551
552 println!();
553 println!("\x1b[1;34m╭───────────────────────────────────────────────╮");
554 println!("│ 🚀 Welcome to Volga v{version:<5} │");
555 println!("│ Listening on: {url:<28}│");
556 println!("╰───────────────────────────────────────────────╯\x1b[0m");
557
558 let routes = self.pipeline
559 .endpoints()
560 .collect();
561 println!("{routes}");
562 }
563}
564
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;
    use crate::http::request::request_body_limit::RequestBodyLimit;
    use crate::App;
    use crate::app::{AppInstance, Connection};

    /// Socket that `Connection::default()` is expected to produce on this platform.
    fn expected_default_socket() -> SocketAddr {
        let ip = if cfg!(target_os = "windows") {
            [127, 0, 0, 1]
        } else {
            [0, 0, 0, 0]
        };
        SocketAddr::from((ip, 7878))
    }

    #[test]
    fn it_creates_connection_with_default_socket() {
        let connection = Connection::default();

        assert_eq!(connection.socket, expected_default_socket());
    }

    #[test]
    fn it_creates_connection_with_specified_socket() {
        let connection: Connection = "127.0.0.1:5000".into();

        assert_eq!(connection.socket, SocketAddr::from(([127, 0, 0, 1], 5000)));
    }

    #[test]
    fn it_creates_default_connection_from_empty_str() {
        let connection: Connection = "".into();

        assert_eq!(connection.socket, expected_default_socket());
    }

    #[test]
    fn it_creates_default_connection_from_invalid_str() {
        let connection: Connection = "not a socket address".into();

        assert_eq!(connection.socket, expected_default_socket());
    }

    #[test]
    fn it_creates_connection_with_specified_socket_from_tuple() {
        let connection: Connection = ([127, 0, 0, 1], 5000).into();

        assert_eq!(connection.socket, SocketAddr::from(([127, 0, 0, 1], 5000)));
    }

    #[test]
    fn it_creates_app_with_default_socket() {
        let app = App::new();

        assert_eq!(app.connection.socket, expected_default_socket());
    }

    #[test]
    fn it_binds_app_to_socket() {
        let app = App::new().bind("127.0.0.1:5001");

        assert_eq!(app.connection.socket, SocketAddr::from(([127, 0, 0, 1], 5001)));
    }

    #[test]
    fn it_sets_default_body_limit() {
        let app = App::new();
        let RequestBodyLimit::Enabled(limit) = app.body_limit else {
            panic!("expected the default body limit to be enabled");
        };

        assert_eq!(limit, 5242880);
    }

    #[test]
    fn it_sets_body_limit() {
        let app = App::new().with_body_limit(10);
        let RequestBodyLimit::Enabled(limit) = app.body_limit else {
            panic!("expected the body limit to be enabled");
        };

        assert_eq!(limit, 10);
    }

    #[test]
    fn it_disables_body_limit() {
        let app = App::new().without_body_limit();

        assert!(matches!(app.body_limit, RequestBodyLimit::Disabled));
    }

    #[test]
    fn it_does_not_set_no_delay_by_default() {
        let app = App::new();

        assert!(!app.no_delay);
    }

    #[test]
    fn it_enables_no_delay() {
        let app = App::new().with_no_delay();

        assert!(app.no_delay);
    }

    #[test]
    fn it_shows_greeter_by_default_only_in_debug_builds() {
        let app = App::new();

        assert_eq!(app.show_greeter, cfg!(debug_assertions));
    }

    #[test]
    fn it_disables_greeter() {
        let app = App::new().without_greeter();

        assert!(!app.show_greeter);
    }

    #[test]
    fn it_converts_into_app_instance() {
        let app = App::default();

        let app_instance: AppInstance = app.try_into().unwrap();
        let RequestBodyLimit::Enabled(limit) = app_instance.body_limit else {
            panic!("expected the instance body limit to be enabled");
        };

        assert_eq!(limit, 5242880);
    }

    #[test]
    fn it_debugs_connection() {
        let connection: Connection = ([127, 0, 0, 1], 5000).into();

        assert_eq!(format!("{connection:?}"), "Connection { socket: 127.0.0.1:5000 }");
    }
}