impulse_server_kit/
startup.rs

1//! Startup module.
2//!
3//! In most cases, you just need to use `start` function:
4//!
5//! ```rust,ignore
6//! let (server, _) = start(app_state, app_config, router).await.unwrap();
7//! server.await
8//! ```
9
10use impulse_utils::errors::ServerError;
11use impulse_utils::prelude::MResult;
12use salvo::prelude::*;
13
14use salvo::conn::rustls::{Keycert, RustlsConfig};
15use salvo::server::ServerHandle;
16use std::future::Future;
17use std::pin::Pin;
18use std::process::Command;
19
20#[cfg(feature = "http3")]
21use salvo::http::HeaderValue;
22#[cfg(feature = "http3")]
23use salvo::http::header::ALT_SVC;
24
25use crate::setup::{GenericServerState, GenericSetup, StartupVariant};
26
/// Protocol versions handed to `RustlsConfig::tls_versions` by `tlsv13`: TLS 1.3 only.
static TLS13: &[&rustls::SupportedProtocolVersion] = &[&rustls::version::TLS13];
28
#[cfg(feature = "http3")]
#[handler]
/// HTTP2-to-HTTP3 switching header.
///
/// Sets an `Alt-Svc` header advertising HTTP/3 (`h3`) on the `server_port` taken from the
/// `GenericValues` stored in the depot, or on port 443 when those values (or the port) are
/// absent. The advertisement carries `ma=2592000` (30 days).
///
/// Usage is `router.hoop(h3_header)`.
pub async fn h3_header(depot: &mut Depot, res: &mut Response) {
  use crate::setup::GenericValues;

  // Fall back to the default HTTPS port instead of panicking when no port is configured.
  let server_port = depot
    .obtain::<GenericValues>()
    .ok()
    .and_then(|values| values.server_port)
    .unwrap_or(443);

  let alt_svc = HeaderValue::from_str(&format!(r##"h3=":{server_port}"; ma=2592000"##))
    .expect("`Alt-Svc` value built from a port number is always a valid header value");

  // `HeaderMap::insert` returns the previously stored value (`None` when the header was not
  // set yet), so its result is intentionally ignored rather than unwrapped.
  res.headers_mut().insert(ALT_SVC, alt_svc);
}
50
51fn tlsv13(certpath: impl AsRef<str>, keypath: impl AsRef<str>) -> MResult<RustlsConfig> {
52  Ok(
53    RustlsConfig::new(
54      Keycert::new()
55        .cert_from_path(certpath.as_ref())
56        .map_err(|e| ServerError::from_private(e).with_500())?
57        .key_from_path(keypath.as_ref())
58        .map_err(|e| ServerError::from_private(e).with_500())?,
59    )
60    .tls_versions(TLS13),
61  )
62}
63
#[cfg(feature = "otel")]
#[handler]
/// Default Server Kit OpenTelemetry metrics.
///
/// Installed by default with `get_root_router_autoinject` method.
///
/// Uses the global `sk_metrics` meter to record:
/// - `sk_requests`: counter, added once without attributes and once with the request
///   attributes plus `status`;
/// - `sk_request_duration`: histogram of the seconds spent in the rest of the handler chain,
///   recorded with the request attributes plus `status`;
/// - `sk_active_connections`: up/down counter, incremented before and decremented after the
///   rest of the chain, both without attributes and with the request attributes.
///
/// Request attributes are `host`, `path`, `method` and `user_agent` (`"unknown"` when absent).
pub async fn sk_default_metrics(req: &mut Request, depot: &mut Depot, res: &mut Response, ctrl: &mut FlowCtrl) {
  let meter = crate::otel::api::global::meter("sk_metrics");

  let request_counter = meter
    .u64_counter("sk_requests")
    .with_unit("1")
    .with_description("Total number of requests")
    .build();

  let request_duration = meter
    .f64_histogram("sk_request_duration")
    .with_unit("s")
    .with_description("HTTP request duration in seconds")
    .build();

  let active_connections = meter
    .i64_up_down_counter("sk_active_connections")
    .with_unit("1")
    .with_description("Number of active HTTP connections")
    .build();

  // NOTE(review): `path` and `user_agent` are unbounded attribute values and can blow up metric
  // cardinality; consider route templates and/or dropping `user_agent`.
  let attributes = vec![
    opentelemetry::KeyValue::new("host", req.uri().host().unwrap_or("unknown").to_string()),
    opentelemetry::KeyValue::new("path", req.uri().path().to_string()),
    opentelemetry::KeyValue::new("method", req.method().as_str().to_string()),
    opentelemetry::KeyValue::new("user_agent", req.header("user-agent").unwrap_or("unknown").to_string()),
  ];

  active_connections.add(1, &[]);
  active_connections.add(1, &attributes);
  let start = tokio::time::Instant::now();

  ctrl.call_next(req, depot, res).await;

  let duration = start.elapsed().as_secs_f64();

  // Balance *both* increments above so the attribute-less series returns to zero once the
  // request is done instead of growing by one per request.
  active_connections.add(-1, &[]);
  active_connections.add(-1, &attributes);

  let status = res.status_code.unwrap_or(StatusCode::OK).as_u16().to_string();
  let mut result_attributes = attributes;
  result_attributes.push(opentelemetry::KeyValue::new("status", status));

  request_counter.add(1, &[]);
  request_counter.add(1, &result_attributes);
  request_duration.record(duration, &result_attributes);
}
119
120/// Returns preconfigured router with app state and OpenTelemetry metrics injected.
121///
122/// To get your `app_config` inside handler/endpoint, call
123/// `depot.obtain::<YourAppConfigType>().unwrap()`.
124pub fn get_root_router_autoinject<T: GenericSetup + Send + Sync + Clone + 'static>(
125  app_state: &GenericServerState,
126  app_config: T,
127) -> Router {
128  #[allow(unused_mut)]
129  let mut router = Router::new().hoop(affix_state::inject(app_state.clone()).inject(app_config.clone()));
130
131  #[cfg(all(feature = "http3", feature = "acme"))]
132  if app_state.startup_variant == StartupVariant::QuinnAcme {
133    router = router.hoop(h3_header);
134  }
135
136  #[cfg(feature = "http3")]
137  if app_state.startup_variant == StartupVariant::Quinn || app_state.startup_variant == StartupVariant::QuinnOnly {
138    router = router.hoop(h3_header);
139  }
140
141  #[cfg(feature = "otel")]
142  if app_config.generic_values().tracing_options.otel_http_endpoint.is_some() {
143    router = router.hoop(sk_default_metrics);
144  }
145
146  router
147}
148
149/// Returns preconfigured root router to use.
150///
151/// Usually it installs application config and state in `affix_state` and installs `h3_header` for switching protocol to QUIC, if used.
152#[allow(unused_variables)]
153pub fn get_root_router(app_state: &GenericServerState) -> Router {
154  #[allow(unused_mut)]
155  let mut router = Router::new();
156
157  #[cfg(all(feature = "http3", feature = "acme"))]
158  if app_state.startup_variant == StartupVariant::QuinnAcme {
159    router = router.hoop(h3_header);
160  }
161
162  #[cfg(feature = "http3")]
163  if app_state.startup_variant == StartupVariant::Quinn || app_state.startup_variant == StartupVariant::QuinnOnly {
164    router = router.hoop(h3_header);
165  }
166
167  router
168}
169
#[cfg(any(feature = "oapi", feature = "acme"))]
#[allow(clippy::mut_from_ref, invalid_reference_casting)]
/// Reinterprets a shared reference as a mutable one through raw-pointer casts.
///
/// # Safety
///
/// NOTE(review): under Rust's aliasing rules, mutating data through a `&mut T` derived from a
/// `&T` (without an `UnsafeCell`) is undefined behavior no matter what the caller guarantees;
/// this is exactly what the silenced `invalid_reference_casting` lint reports. There is no
/// contract a caller can uphold to make this sound. Prefer APIs that hand out mutable access
/// from owned or uniquely-held data, such as `Arc::get_mut`.
unsafe fn make_mut<T>(reference: &T) -> &mut T {
  let const_ptr = reference as *const T;
  let mut_ptr = const_ptr as *mut T;
  // SAFETY: not actually upheld; see the `# Safety` section above.
  unsafe { &mut *mut_ptr }
}
177
/// Starts up HTTPS redirect server.
///
/// Listens for plain HTTP on `0.0.0.0:{listen_port}` and redirects requests to HTTPS on
/// `redirect_port` through salvo's `ForceHttps` middleware.
///
/// Example:
///
/// ```rust,ignore
/// let (server, _) = start(app_state, app_config, router).await.unwrap();
/// let (redirect, _) = start_force_https_redirect(80, 443).await.unwrap();
///
/// tracing::info!("Server is booted.");
///
/// tokio::select! {
///   _ = server   => tracing::info!("Server is shut down."),
///   _ = redirect => tracing::info!("Redirect is shut down."),
/// }
/// ```
///
/// # Errors
///
/// Returns an error when the listener can't be bound (e.g. the port is already in use).
#[cfg(feature = "force-https")]
pub async fn start_force_https_redirect(
  listen_port: u16,
  redirect_port: u16,
) -> MResult<(Pin<Box<dyn Future<Output = ()> + Send>>, ServerHandle)> {
  let service = Service::new(Router::new()).hoop(ForceHttps::new().https_port(redirect_port));
  // `try_bind` reports bind failures through the returned `MResult` instead of panicking.
  let acceptor = TcpListener::new(format!("0.0.0.0:{listen_port}"))
    .try_bind()
    .await
    .map_err(|e| ServerError::from_private(e).with_500())?;
  let server = Server::new(acceptor);
  let handle = server.handle();
  let server = Box::pin(server.serve(service));
  Ok((server, handle))
}
205
206/// Starts your application with provided service, if you predefined one by yourself.
207///
208/// For example, you can setup service with error catcher or any other middleware that
209/// `salvo` provides.
210pub async fn start_with_service(
211  app_state: GenericServerState,
212  app_config: &impl GenericSetup,
213  #[allow(unused_mut)] mut service: Service,
214) -> MResult<(Pin<Box<dyn Future<Output = ()> + Send>>, ServerHandle)> {
215  tracing::info!("Server is starting...");
216
217  rustls::crypto::aws_lc_rs::default_provider()
218    .install_default()
219    .map_err(|_| ServerError::from_private_str("Can't install default crypto provider!").with_500())?;
220  let app_config = app_config.generic_values();
221
222  if let Some(bin) = app_config.auto_migrate_bin.as_ref() {
223    Command::new(bin)
224      .spawn()
225      .map_err(|e| ServerError::from_private(e).with_500())?;
226  }
227
228  #[cfg(feature = "oapi")]
229  if app_config.allow_oapi_access.is_some_and(|v| v) {
230    let doc = OpenApi::new(
231      app_config.oapi_name.as_ref().unwrap(),
232      app_config.oapi_ver.as_ref().unwrap(),
233    )
234    .merge_router(&service.router);
235
236    let oapi_endpoint = if let Some(ftype) = app_config.oapi_frontend_type.as_ref() {
237      match ftype.as_str() {
238        "Scalar" => Some(
239          Scalar::new(format!("{}/openapi.json", app_config.oapi_api_addr.as_ref().unwrap()))
240            .title(format!(
241              "{} - API @ Scalar",
242              app_config.oapi_name.as_ref().unwrap_or(&app_config.app_name)
243            ))
244            .description(format!(
245              "{} - API",
246              app_config.oapi_name.as_ref().unwrap_or(&app_config.app_name)
247            ))
248            .into_router(app_config.oapi_api_addr.as_ref().unwrap()),
249        ),
250        "SwaggerUI" => Some(
251          SwaggerUi::new(format!("{}/openapi.json", app_config.oapi_api_addr.as_ref().unwrap()))
252            .title(format!(
253              "{} - API @ SwaggerUI",
254              app_config.oapi_name.as_ref().unwrap_or(&app_config.app_name)
255            ))
256            .description(format!(
257              "{} - API",
258              app_config.oapi_name.as_ref().unwrap_or(&app_config.app_name)
259            ))
260            .into_router(app_config.oapi_api_addr.as_ref().unwrap()),
261        ),
262        _ => None,
263      }
264    } else {
265      None
266    };
267
268    let mut router = Router::new();
269    router = router.push(doc.into_router(format!("{}/openapi.json", app_config.oapi_api_addr.as_ref().unwrap())));
270    if let Some(oapi) = oapi_endpoint {
271      router = router.push(oapi);
272    }
273
274    unsafe {
275      let service_router = make_mut(service.router.as_ref());
276      service_router.routers_mut().insert(0, router);
277    }
278
279    tracing::info!("API is available on {}", app_config.oapi_api_addr.as_ref().unwrap());
280  }
281
282  #[cfg(feature = "cors")]
283  if let Some(domain) = &app_config.allow_cors_domain {
284    let cors = salvo::cors::Cors::new()
285      .allow_origin(domain)
286      .allow_credentials(domain.as_str() != "*")
287      .allow_headers(vec![
288        "Authorization",
289        "Accept",
290        "Access-Control-Allow-Headers",
291        "Content-Type",
292        "Origin",
293        "X-Requested-With",
294        "Cookie",
295      ])
296      .expose_headers(vec!["Set-Cookie"])
297      .allow_methods(vec![
298        salvo::http::Method::GET,
299        salvo::http::Method::POST,
300        salvo::http::Method::PUT,
301        salvo::http::Method::PATCH,
302        salvo::http::Method::DELETE,
303        salvo::http::Method::OPTIONS,
304      ])
305      .into_handler();
306
307    service = service.hoop(cors);
308  }
309
310  let handle;
311
312  let server = match app_state.startup_variant {
313    StartupVariant::HttpLocalhost => {
314      let acceptor = TcpListener::new(format!("127.0.0.1:{}", app_config.server_port.unwrap()))
315        .bind()
316        .await;
317      let server = Server::new(acceptor);
318      handle = server.handle();
319      Box::pin(server.serve(service)) as Pin<Box<dyn Future<Output = ()> + Send>>
320    }
321    StartupVariant::UnsafeHttp => {
322      let acceptor = TcpListener::new(format!(
323        "{}:{}",
324        app_config.server_host.as_ref().unwrap(),
325        app_config.server_port.unwrap()
326      ))
327      .bind()
328      .await;
329      let server = Server::new(acceptor);
330      handle = server.handle();
331      Box::pin(server.serve(service))
332    }
333    #[cfg(feature = "acme")]
334    StartupVariant::HttpsAcme => {
335      let acceptor = TcpListener::new(format!(
336        "{}:{}",
337        app_config.server_host.as_ref().unwrap(),
338        app_config.server_port.unwrap()
339      ))
340      .acme()
341      .cache_path("tmp/letsencrypt")
342      .add_domain(app_config.acme_domain.as_ref().unwrap())
343      .bind()
344      .await;
345      let server = Server::new(acceptor);
346      handle = server.handle();
347      Box::pin(server.serve(service))
348    }
349    StartupVariant::HttpsOnly => {
350      let rustls_config = tlsv13(
351        app_config.ssl_crt_path.as_ref().unwrap(),
352        app_config.ssl_key_path.as_ref().unwrap(),
353      )?;
354      let listener = TcpListener::new(format!(
355        "{}:{}",
356        app_config.server_host.as_ref().unwrap(),
357        app_config.server_port.unwrap()
358      ))
359      .rustls(rustls_config)
360      .bind()
361      .await;
362
363      let server = Server::new(listener);
364      handle = server.handle();
365      Box::pin(server.serve(service))
366    }
367    #[cfg(all(feature = "http3", feature = "acme"))]
368    StartupVariant::QuinnAcme => {
369      let acceptor = TcpListener::new(format!(
370        "{}:{}",
371        app_config.server_host.as_ref().unwrap(),
372        app_config.server_port.unwrap()
373      ))
374      .acme()
375      .cache_path("tmp/letsencrypt")
376      .add_domain(app_config.acme_domain.as_ref().unwrap())
377      .quinn(format!(
378        "{}:{}",
379        app_config.server_host.as_ref().unwrap(),
380        app_config.server_port.unwrap()
381      ))
382      .bind()
383      .await;
384      let server = Server::new(acceptor);
385      handle = server.handle();
386      Box::pin(server.serve(service))
387    }
388    #[cfg(feature = "http3")]
389    StartupVariant::Quinn => {
390      let rustls_config = tlsv13(
391        app_config.ssl_crt_path.as_ref().unwrap(),
392        app_config.ssl_key_path.as_ref().unwrap(),
393      )?;
394      let listener = TcpListener::new(format!(
395        "{}:{}",
396        app_config.server_host.as_ref().unwrap(),
397        app_config.server_port.unwrap()
398      ))
399      .rustls(rustls_config.clone());
400
401      let quinn_config = rustls_config
402        .build_quinn_config()
403        .map_err(|e| ServerError::from_private(e).with_500())?;
404      let acceptor = QuinnListener::new(
405        quinn_config,
406        format!(
407          "{}:{}",
408          app_config.server_host.as_ref().unwrap(),
409          app_config.server_port.unwrap()
410        ),
411      )
412      .join(listener)
413      .bind()
414      .await;
415
416      let server = Server::new(acceptor);
417      handle = server.handle();
418      Box::pin(server.serve(service))
419    }
420    #[cfg(feature = "http3")]
421    StartupVariant::QuinnOnly => {
422      let quinn_config = tlsv13(
423        app_config.ssl_crt_path.as_ref().unwrap(),
424        app_config.ssl_key_path.as_ref().unwrap(),
425      )?
426      .build_quinn_config()
427      .map_err(|e| ServerError::from_private(e).with_500())?;
428      let acceptor = QuinnListener::new(
429        quinn_config,
430        format!(
431          "{}:{}",
432          app_config.server_host.as_ref().unwrap(),
433          app_config.server_port.unwrap()
434        ),
435      )
436      .bind()
437      .await;
438
439      let server = Server::new(acceptor);
440      handle = server.handle();
441      Box::pin(server.serve(service))
442    }
443  };
444
445  Ok((server, handle))
446}
447
448/// Starts the server according to the startup variant provided with the custom shutdown.
449pub async fn start_clean(
450  app_state: GenericServerState,
451  app_config: &impl GenericSetup,
452  router: Router,
453) -> MResult<(Pin<Box<dyn Future<Output = ()> + Send>>, ServerHandle)> {
454  start_with_service(app_state, app_config, Service::new(router)).await
455}
456
457/// Starts the server according to the startup variant provided.
458pub async fn start(
459  app_state: GenericServerState,
460  app_config: &impl GenericSetup,
461  router: Router,
462) -> MResult<(Pin<Box<dyn Future<Output = ()> + Send>>, ServerHandle)> {
463  let (fut, handle) = start_clean(app_state, app_config, router).await?;
464  let ctrl_c_handle = handle.clone();
465  tokio::spawn(async move { shutdown_signal(ctrl_c_handle).await });
466  Ok((fut, handle))
467}
468
469/// Signal to graceful shutdown.
470///
471/// Required to be manually awaited, if you start server with `start_clean`/`start_with_service` functions. Example:
472///
473/// ```rust,ignore
474/// let (server, handle) = start_clean(app_state, app_config, router).await.unwrap();
475/// let default_handle = tokio::spawn(async move { shutdown_signal(handle).await });
476///
477/// tracing::info!("Server is booted.");
478///
479/// tokio::select! {
480///   _ = server         => tracing::info!("Server is shutdowned."),
481///   _ = default_handle => std::process::exit(0),
482/// }
483/// ```
484///
485/// Graceful coroutine starts automatically with `start` function.
486pub async fn shutdown_signal(handle: ServerHandle) {
487  tokio::signal::ctrl_c().await.unwrap();
488  tracing::info!("Shutdown with Ctrl+C requested.");
489  handle.stop_graceful(None);
490}