1pub mod health;
2pub(crate) mod management;
3
4#[cfg(feature = "tls")]
5pub(crate) mod tls;
6pub mod version;
7
8use crate::application::health::{AlwaysReadyAndAlive, HealthExt};
9use crate::application::version::{DefaultVersion, VersionExt};
10use crate::configuration::{AppConfig, Empty};
11use crate::error::Result;
12use crate::management::build_management_router;
13use crate::middleware::trace_request;
14use axum::middleware::from_fn;
15use axum::Router;
16use hyper::Server;
17use std::fmt::{Debug, Formatter};
18use std::net::SocketAddr;
19use std::sync::Arc;
20use tokio::signal;
21use tracing::info;
22
23pub struct Application<H = AlwaysReadyAndAlive, T = Empty, V = DefaultVersion> {
25    config: Arc<AppConfig<T>>,
26    health_indicator: H,
27    version: V,
28    router: Option<Router>,
29    metrics_callback: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
30    use_default_trace_layer: bool,
31}
32
33impl<H: Debug, T: Debug, V: Debug> Debug for Application<H, T, V> {
34    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
35        let Self {
36            config,
37            health_indicator,
38            router,
39            metrics_callback,
40            use_default_trace_layer,
41            version,
42        } = self;
43        f.debug_struct("Application")
44            .field("config", config)
45            .field("health_indicator", health_indicator)
46            .field("version", version)
47            .field("router", router)
48            .field("use_default_trace_layer", use_default_trace_layer)
49            .field(
50                "metrics_callback",
51                if metrics_callback.is_some() {
52                    &"Some"
53                } else {
54                    &"None"
55                },
56            )
57            .finish()
58    }
59}
60
61impl<T> Application<T> {
62    pub fn new(config: AppConfig<T>) -> Application<AlwaysReadyAndAlive, T, DefaultVersion> {
64        Application::<AlwaysReadyAndAlive, T, DefaultVersion> {
65            config: Arc::new(config),
66            health_indicator: AlwaysReadyAndAlive,
67            version: DefaultVersion,
68            router: None,
69            metrics_callback: None,
70            use_default_trace_layer: true,
71        }
72    }
73
74    pub fn new_from_arced(
76        config: Arc<AppConfig<T>>,
77    ) -> Application<AlwaysReadyAndAlive, T, DefaultVersion> {
78        Application::<AlwaysReadyAndAlive, T, DefaultVersion> {
79            config,
80            health_indicator: AlwaysReadyAndAlive,
81            version: DefaultVersion,
82            router: None,
83            metrics_callback: None,
84            use_default_trace_layer: true,
85        }
86    }
87}
88
89impl<H, T, V> Application<H, T, V> {
90    pub fn health_indicator<Hh: HealthExt>(self, health: Hh) -> Application<Hh, T, V> {
92        let Self {
93            config,
94            health_indicator: _,
95            router,
96            metrics_callback,
97            use_default_trace_layer,
98            version,
99        } = self;
100
101        Application::<Hh, T, V> {
102            config,
103            health_indicator: health,
104            router,
105            metrics_callback,
106            use_default_trace_layer,
107            version,
108        }
109    }
110
111    pub fn version<Vv: VersionExt<T>>(self, version: Vv) -> Application<H, T, Vv> {
113        let Self {
114            config,
115            health_indicator,
116            router,
117            metrics_callback,
118            use_default_trace_layer,
119            version: _,
120        } = self;
121
122        Application::<H, T, Vv> {
123            config,
124            health_indicator,
125            router,
126            metrics_callback,
127            use_default_trace_layer,
128            version,
129        }
130    }
131
132    #[must_use]
134    pub fn router(self, router: Router) -> Self {
135        Self {
136            router: Some(router),
137            ..self
138        }
139    }
140
141    #[must_use]
143    pub fn metrics_callback(self, metrics_callback: impl Fn() + Send + Sync + 'static) -> Self {
144        Self {
145            metrics_callback: Some(Arc::new(metrics_callback)),
146            ..self
147        }
148    }
149
150    #[must_use]
165    pub fn use_default_tracing_layer(self, use_default: bool) -> Self {
166        Self {
167            use_default_trace_layer: use_default,
168            ..self
169        }
170    }
171
172    pub async fn serve(self) -> Result<()>
174    where
175        H: HealthExt,
176        V: VersionExt<T>,
177        T: Send + Sync + 'static,
178    {
179        let (router, application_socket) = self.prepare_router();
180        run_service(&application_socket, router).await
181    }
182
183    #[cfg(feature = "tls")]
185    pub async fn serve_tls(self) -> Result<()>
186    where
187        H: HealthExt,
188        V: VersionExt<T>,
189        T: Send + Sync + 'static,
190    {
191        use crate::error::Error;
192        use futures_util::TryFutureExt;
193        use std::fmt;
194        use tokio::{fs, try_join};
195
196        fn cant_load<Arg: fmt::Display>(r#type: &str) -> impl FnOnce(Arg) -> Error + '_ {
197            move |error| Error::CustomError(format!("Cant load TLS {type}: `{error}`."))
198        }
199
200        let tls_handshake_timeout = self.config.tls.handshake_timeout;
201
202        let tls_cert_path = self
203            .config
204            .tls
205            .cert_path
206            .as_deref()
207            .ok_or_else(|| cant_load("certificate")("No path present."))?;
208
209        let tls_key_path = self
210            .config
211            .tls
212            .key_path
213            .as_deref()
214            .ok_or_else(|| cant_load("key")("No path present."))?;
215
216        let (tls_cert, tls_key) = try_join!(
217            fs::read(tls_cert_path).map_err(cant_load("certificate")),
218            fs::read(tls_key_path).map_err(cant_load("key"))
219        )?;
220
221        let (router, application_socket) = self.prepare_router();
222
223        tls::run_service(
224            &application_socket,
225            router,
226            tls_handshake_timeout,
227            tls_cert,
228            tls_key,
229        )
230        .await
231    }
232
233    fn prepare_router(self) -> (Router, SocketAddr)
234    where
235        H: HealthExt,
236        V: VersionExt<T>,
237        T: Send + Sync + 'static,
238    {
239        let app_router = self
240            .router
241            .map(|router| {
242                let service_name = self.config.observability_cfg.service_name.clone();
243                let component_name = self.config.observability_cfg.component_name.clone();
244
245                if self.use_default_trace_layer {
247                    router.layer(from_fn(move |req, next| {
248                        trace_request(req, next, service_name.clone(), component_name.clone())
249                    }))
250                } else {
251                    router
252                }
253            })
254            .unwrap_or_default();
255
256        let router = build_management_router(
257            &self.config,
258            self.health_indicator,
259            self.version,
260            self.metrics_callback,
261        )
262        .merge(app_router);
263
264        let application_socket = SocketAddr::new(self.config.host, self.config.port);
265        (router, application_socket)
266    }
267}
268
269#[allow(clippy::expect_used)]
270async fn shutdown_signal() {
271    let ctrl_c = async {
272        signal::ctrl_c()
273            .await
274            .expect("failed to install Ctrl+C handler");
275    };
276
277    #[cfg(unix)]
278    let terminate = async {
279        signal::unix::signal(signal::unix::SignalKind::terminate())
280            .expect("failed to install SIGTERM signal handler")
281            .recv()
282            .await;
283    };
284
285    #[cfg(not(unix))]
286    let terminate = std::future::pending::<()>();
287
288    tokio::select! {
289        _ = ctrl_c => {},
290        _ = terminate => {},
291    }
292
293    info!("Termination signal, starting shutdown...");
294}
295
296async fn run_service(socket: &SocketAddr, router: Router) -> Result<()> {
297    let app = router.into_make_service_with_connect_info::<SocketAddr>();
298    let server = Server::bind(socket).serve(app);
299
300    info!(target: "server", "Started: http://{socket}");
301
302    Ok(server.with_graceful_shutdown(shutdown_signal()).await?)
303}