use std::{net::SocketAddr, time::Duration};
use axum_server::{service::MakeService, Handle as AxumHandle};
use futures::{stream::FuturesUnordered, StreamExt, TryFutureExt};
use opentelemetry::{metrics::MeterProvider as _, trace::TracerProvider as _};
use opentelemetry_sdk::{
metrics::SdkMeterProvider,
trace::{SdkTracerProvider, Tracer},
};
use thiserror::Error;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tracing_appender::non_blocking::WorkerGuard;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use crate::{
builder::server::ServerBuilder, config::AppConfig, crypto::ensure_default_crypto_provider,
errors::IoError, metrics::gather_runtime_metrics, notify::ServiceNotifier,
};
/// Errors produced while setting up, starting, running or stopping a [`Handle`].
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum HandleError {
    /// Error converted from the logging module (e.g. `make_registry` in `AppConfig::handle`).
    #[error(transparent)]
    Logging(#[from] crate::logging::LoggingError),
    /// Error converted from the tracing module (e.g. building the tracer provider).
    #[error(transparent)]
    Tracing(#[from] crate::tracing::TracingError),
    /// Error converted from the metrics module (e.g. building the meter provider).
    #[error(transparent)]
    Metrics(#[from] crate::metrics::MetricsError),
    /// Error converted from `ServerBuilderError`.
    // NOTE(review): presumably raised while building servers or spawning the
    // signal handler via `?`; confirm the exact call sites.
    #[error(transparent)]
    ServerBuilder(#[from] crate::builder::server::ServerBuilderError),
    /// The plain HTTP server future finished with an error.
    #[error("HTTP server error: {0}")]
    Server(IoError),
    /// Crypto provider initialization failed.
    // NOTE(review): not constructed anywhere in this section; confirm it is still used.
    #[error("Error initializing crypto provider")]
    InitTls,
    /// The HTTPS (TLS) server future finished with an error.
    #[error("HTTPS server error: {0}")]
    TlsServer(IoError),
    /// Joining a spawned server task failed (the task panicked or was cancelled).
    #[error("Server task error: {0}")]
    ServerTask(#[from] tokio::task::JoinError),
    /// `Handle::wait` was called while no server task was stored.
    #[error("No server is currently running")]
    NotRunning,
    /// Arbitrary boxed error; build it with [`HandleError::custom`].
    #[error("Custom error: {0}")]
    Custom(Box<dyn std::error::Error + Send + Sync>),
}
impl HandleError {
    /// Wraps any value convertible into a boxed, thread-safe error
    /// (an error type, a `String`, a `&str`, ...) as [`HandleError::Custom`].
    #[must_use]
    pub fn custom<T>(err: T) -> Self
    where
        T: Into<Box<dyn std::error::Error + Send + Sync>>,
    {
        let boxed: Box<dyn std::error::Error + Send + Sync> = err.into();
        Self::Custom(boxed)
    }
}
/// Owns the runtime state of an application: telemetry providers, background
/// helper tasks and the spawned HTTP/HTTPS server tasks.
///
/// Created by `AppConfig::handle`. Dropping it cancels `token` and flushes and
/// shuts down the OpenTelemetry providers (see the `Drop` impl).
// `dead_code` is allowed because some fields (`buf_guards`, `tracer`,
// `rt_metrics_task`) are never read in this section of the file.
// NOTE(review): they appear to be stored only to keep them alive for the
// lifetime of the handle; confirm, then narrow the `allow` to those fields.
#[allow(dead_code)]
#[non_exhaustive]
pub struct Handle {
    // Cancelled on drop; also handed to the runtime-metrics task in `AppConfig::handle`.
    token: CancellationToken,
    // Guards for `tracing_appender` non-blocking writers; dropping a guard flushes its writer.
    buf_guards: Vec<WorkerGuard>,
    // Tracer named "uxum", present only when tracing is configured.
    tracer: Option<Tracer>,
    // Flushed and shut down on drop.
    tracer_provider: Option<SdkTracerProvider>,
    // Flushed and shut down on drop; also installed as the global meter provider.
    metrics_provider: Option<SdkMeterProvider>,
    // Shared by the HTTP and HTTPS servers, so one shutdown call stops both.
    handle: AxumHandle,
    // Receives ready/shutdown notifications and provides the watchdog task.
    notify: ServiceNotifier,
    // Spawned lazily by `prepare`, at most once.
    service_watchdog: Option<JoinHandle<()>>,
    // Spawned lazily by `prepare`, at most once.
    signal_handler: Option<JoinHandle<()>>,
    // Plain HTTP server task; taken (set to `None`) when awaited or aborted.
    http_task: Option<JoinHandle<Result<(), HandleError>>>,
    // HTTPS server task, present only when a TLS config was supplied.
    https_task: Option<JoinHandle<Result<(), HandleError>>>,
    // Periodic runtime-metrics collection task, present only when metrics are configured.
    rt_metrics_task: Option<JoinHandle<()>>,
}
impl Drop for Handle {
    /// Stops background work and flushes OpenTelemetry state when the handle goes away.
    ///
    /// Order of operations: cancel `token`, then flush and shut down the meter
    /// provider, then flush and shut down the tracer provider. A destructor cannot
    /// return errors, so every failure is printed to stderr and teardown continues.
    fn drop(&mut self) {
        self.token.cancel();

        if let Some(meter_provider) = self.metrics_provider.take() {
            meter_provider.force_flush().unwrap_or_else(|err| {
                eprintln!("Error flushing metrics: {err}");
            });
            meter_provider.shutdown().unwrap_or_else(|err| {
                eprintln!("Error shutting down OTel metrics provider: {err}");
            });
        }

        if let Some(trace_provider) = self.tracer_provider.take() {
            trace_provider.force_flush().unwrap_or_else(|err| {
                eprintln!("Error flushing spans: {err}");
            });
            trace_provider.shutdown().unwrap_or_else(|err| {
                eprintln!("Error shutting down OTel tracing provider: {err}");
            });
        }
    }
}
impl Handle {
    /// Spawns the process-level helper tasks if they are not running yet.
    ///
    /// The signal handler (wired to the shared [`AxumHandle`]) and the service
    /// watchdog are each spawned at most once per `Handle`, so calling
    /// [`Handle::start`] repeatedly does not duplicate them.
    ///
    /// # Errors
    ///
    /// Propagates the error from `ServerBuilder::spawn_signal_handler`; in that
    /// case the watchdog is not spawned either.
    fn prepare(&mut self, server: &ServerBuilder) -> Result<(), HandleError> {
        if self.signal_handler.is_none() {
            self.signal_handler = Some(server.spawn_signal_handler(self.handle.clone())?);
        }
        if self.service_watchdog.is_none() {
            self.service_watchdog = Some(tokio::spawn(self.notify.watchdog_task()));
        }
        Ok(())
    }

    /// Builds the HTTPS server (when a TLS config is present) and the plain HTTP
    /// server, then spawns each of them as a task serving `app`.
    ///
    /// Both servers are built *before* anything is spawned. If building the HTTP
    /// server fails, no orphaned HTTPS task is left running while `start` reports
    /// an error.
    ///
    /// # Errors
    ///
    /// Propagates errors from `build_tls` / `build`.
    async fn start_servers<A>(&mut self, server: ServerBuilder, app: A) -> Result<(), HandleError>
    where
        A: MakeService<SocketAddr, http::Request<hyper::body::Incoming>>
            + tower::Service<SocketAddr>
            + Clone
            + Send
            + 'static,
        A::Response: tower::Service<http::Request<hyper::body::Incoming>>,
        A::MakeFuture: Send,
    {
        let tls_server = if server.has_tls_config() {
            ensure_default_crypto_provider();
            Some(server.clone().build_tls().await?)
        } else {
            None
        };
        let http_server = server.build().await?;

        if let Some(tls_server) = tls_server {
            self.https_task = Some(tokio::spawn(
                tls_server
                    .handle(self.handle.clone())
                    .serve(app.clone())
                    .map_err(|err| HandleError::TlsServer(err.into())),
            ));
        }
        self.http_task = Some(tokio::spawn(
            http_server
                .handle(self.handle.clone())
                .serve(app)
                .map_err(|err| HandleError::Server(err.into())),
        ));
        Ok(())
    }

    /// Spawns the helper tasks and the server tasks, then reports readiness via
    /// `ServiceNotifier::on_ready`.
    ///
    /// Calling this while servers are already running overwrites the stored task
    /// handles. The previous tasks are detached, not stopped.
    ///
    /// # Errors
    ///
    /// Returns the first setup error; readiness is not reported in that case.
    pub async fn start<A>(&mut self, server: ServerBuilder, app: A) -> Result<(), HandleError>
    where
        A: MakeService<SocketAddr, http::Request<hyper::body::Incoming>>
            + tower::Service<SocketAddr>
            + Clone
            + Send
            + 'static,
        A::Response: tower::Service<http::Request<hyper::body::Incoming>>,
        A::MakeFuture: Send,
    {
        self.prepare(&server)?;
        self.start_servers(server, app).await?;
        self.notify.on_ready();
        Ok(())
    }

    /// Reports shutdown, tells the servers to stop immediately and waits for both
    /// server tasks to finish.
    ///
    /// # Errors
    ///
    /// Returns the first error among the server tasks (HTTP first), after all of
    /// them have been awaited.
    pub async fn shutdown(&mut self) -> Result<(), HandleError> {
        self.notify.on_shutdown();
        self.handle.shutdown();
        self.join_servers().await
    }

    /// Reports shutdown, asks the servers to stop gracefully and waits for both
    /// server tasks to finish.
    ///
    /// `graceful` is passed to `axum_server::Handle::graceful_shutdown`.
    ///
    /// # Errors
    ///
    /// Returns the first error among the server tasks (HTTP first), after all of
    /// them have been awaited.
    pub async fn graceful_shutdown(
        &mut self,
        graceful: Option<Duration>,
    ) -> Result<(), HandleError> {
        self.notify.on_shutdown();
        self.handle.graceful_shutdown(graceful);
        self.join_servers().await
    }

    /// Reports shutdown and aborts the server tasks without waiting for them.
    ///
    /// The shared server handle is not signalled, and the aborted tasks are not
    /// awaited.
    pub fn abort(&mut self) {
        self.notify.on_shutdown();
        if let Some(task) = self.http_task.take() {
            task.abort();
        }
        if let Some(task) = self.https_task.take() {
            task.abort();
        }
    }

    /// Starts the servers (see [`Handle::start`]) and then waits for them to stop
    /// (see [`Handle::wait`]).
    ///
    /// # Errors
    ///
    /// Returns any error from starting or from the running servers.
    pub async fn run<A>(
        &mut self,
        server: ServerBuilder,
        app: A,
        graceful: Option<Duration>,
    ) -> Result<(), HandleError>
    where
        A: MakeService<SocketAddr, http::Request<hyper::body::Incoming>>
            + tower::Service<SocketAddr>
            + Clone
            + Send
            + 'static,
        A::Response: tower::Service<http::Request<hyper::body::Incoming>>,
        A::MakeFuture: Send,
    {
        self.start(server, app).await?;
        self.wait(graceful).await
    }

    /// Waits until the running server task(s) finish.
    ///
    /// When both HTTP and HTTPS are running, the first one to stop (cleanly or
    /// not) causes the other to be asked to stop gracefully with `graceful`.
    /// Both are then awaited.
    ///
    /// # Errors
    ///
    /// - [`HandleError::NotRunning`] if no server task is stored.
    /// - Otherwise the first error in completion order. An error from the server
    ///   that stopped second is now reported instead of being silently dropped,
    ///   unless the first one already failed.
    pub async fn wait(&mut self, graceful: Option<Duration>) -> Result<(), HandleError> {
        match (self.http_task.take(), self.https_task.take()) {
            (None, None) => Err(HandleError::NotRunning),
            (Some(task), None) | (None, Some(task)) => Self::flatten_join(task.await),
            (Some(http), Some(https)) => {
                let mut tasks = FuturesUnordered::new();
                tasks.push(http);
                tasks.push(https);
                let mut outcome: Result<(), HandleError> = Ok(());
                let mut shutdown_requested = false;
                while let Some(joined) = tasks.next().await {
                    if !shutdown_requested {
                        // One server stopped; take the other one down with it.
                        self.handle.graceful_shutdown(graceful);
                        shutdown_requested = true;
                    }
                    let result = Self::flatten_join(joined);
                    // Keep the earliest error: it is the most likely root cause.
                    if outcome.is_ok() {
                        outcome = result;
                    }
                }
                outcome
            }
        }
    }

    /// Takes and awaits every stored server task (HTTP, then HTTPS), returning the
    /// first error encountered. All tasks are awaited even if an earlier one failed.
    async fn join_servers(&mut self) -> Result<(), HandleError> {
        let tasks = self.http_task.take().into_iter().chain(self.https_task.take());
        let mut outcome: Result<(), HandleError> = Ok(());
        for task in tasks {
            let result = Self::flatten_join(task.await);
            if outcome.is_ok() {
                outcome = result;
            }
        }
        outcome
    }

    /// Collapses a joined server task's nested result: a join failure
    /// (panic/cancellation) becomes [`HandleError::ServerTask`].
    fn flatten_join(
        joined: Result<Result<(), HandleError>, tokio::task::JoinError>,
    ) -> Result<(), HandleError> {
        joined?
    }
}
impl AppConfig {
pub async fn handle(&mut self) -> Result<Handle, HandleError> {
let token = CancellationToken::new();
let (registry, buf_guards) = self.logging.make_registry()?;
let otel_res = self.otel_resource();
let (tracer, tracer_provider) = if let Some(tcfg) = self.tracing.as_mut() {
let tracer_provider = tcfg.build_provider(otel_res.clone()).await?;
let tracer = tracer_provider.tracer("uxum");
let layer = tcfg.build_layer(&tracer);
registry.with(layer).init();
opentelemetry::global::set_text_map_propagator(
opentelemetry_sdk::propagation::TraceContextPropagator::default(),
);
(Some(tracer), Some(tracer_provider))
} else {
registry.init();
(None, None)
};
let (metrics_provider, rt_metrics_task) = if let Some(mcfg) = self.metrics.as_ref() {
let (metrics_provider, prom_exporter) = mcfg.build_provider(otel_res).await?;
let meter = metrics_provider.meter("uxum");
let metrics_state = mcfg.build_state(&meter, prom_exporter);
let rt_task = tokio::spawn(gather_runtime_metrics(
metrics_state.clone(),
mcfg.runtime_metrics_interval,
token.clone(),
));
self.metrics_state = Some(metrics_state);
opentelemetry::global::set_meter_provider(metrics_provider.clone());
(Some(metrics_provider), Some(rt_task))
} else {
(None, None)
};
let handle = AxumHandle::new();
let notify = ServiceNotifier::new();
Ok(Handle {
token,
buf_guards,
tracer,
tracer_provider,
metrics_provider,
handle,
notify,
service_watchdog: None,
signal_handler: None,
http_task: None,
https_task: None,
rt_metrics_task,
})
}
}