use crate::resource::{HealthState, MIN_HEALTH_INTERVAL, Resource};
use crate::resource_lifecycle::{
run_initial_health_checks, shutdown_resources, spawn_health_tasks,
};
use crate::runtime_state::{
RuntimeConfig, RuntimeInner, install_runtime, teardown_runtime, wait_for_tasks,
wait_for_tasks_timeout,
};
use crate::tls::CertStore;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
pub use tokio::runtime::Handle as TokioHandle;
pub(crate) use crate::runtime_state::{
cancel_channel, check_cancel, current_runtime, has_runtime, shutdown_notify, shutdown_signal,
};
/// Hands `future` to `crate::runtime_state::on_cancel`.
///
/// NOTE(review): when (and whether) the future is driven is decided by
/// `runtime_state::on_cancel`; see that function for the exact contract.
pub fn on_cancel<F: std::future::Future<Output = ()> + Send + 'static>(future: F) {
    crate::runtime_state::on_cancel(future);
}
/// Drives `f` to completion through `crate::runtime_state::block_on` and
/// returns the future's output.
pub fn block_on<F>(f: F) -> F::Output
where
    F: std::future::Future,
{
    crate::runtime_state::block_on(f)
}
/// Requests a shutdown of the current runtime by delegating to
/// `crate::runtime_state::request_shutdown`.
pub fn request_shutdown() {
    crate::runtime_state::request_shutdown()
}
pub fn tokio_handle() -> tokio::runtime::Handle {
crate::runtime_state::tokio_handle()
}
/// Reports whether the runtime is shutting down, as answered by
/// `crate::runtime_state::is_shutting_down`.
pub fn is_shutting_down() -> bool {
    crate::runtime_state::is_shutting_down()
}
// Diagnostic formatting for `RuntimeBuilder`. Feature-gated integrations are
// reported only as presence flags so that credentials (such as the DNS-01 API
// token, or a credential-bearing OTel URL) never appear in debug output.
impl std::fmt::Debug for RuntimeBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Any manually supplied TLS source counts: cert/key paths as well as a
        // pre-built cert store.
        let has_tls = self.tls_cert_store.is_some()
            || self.tls_cert_path.is_some()
            || self.tls_key_path.is_some();
        let mut out = f.debug_struct("RuntimeBuilder");
        out.field("worker_threads", &self.config.worker_threads)
            .field("shutdown_timeout", &self.config.shutdown_timeout)
            .field("keepalive_timeout", &self.config.keepalive_timeout)
            .field("tracing_enabled", &self.config.tracing_enabled)
            .field("metrics_enabled", &self.config.metrics_enabled)
            .field("health_interval", &self.config.health_interval)
            .field("connection_limit", &self.config.connection_limit)
            .field("resource_count", &self.resources.len())
            .field("tls_cert_path", &self.tls_cert_path)
            .field("tls_key_path", &self.tls_key_path)
            .field("has_tls", &has_tls);
        #[cfg(feature = "acme")]
        out.field("has_acme", &self.acme_config.is_some());
        #[cfg(feature = "dns01")]
        out.field("has_dns01", &self.dns01_setup.is_some());
        #[cfg(feature = "otel")]
        out.field("has_otel_endpoint", &self.otel_endpoint.is_some());
        out.finish()
    }
}
/// Configures and launches the runtime.
///
/// Obtain one with [`builder`], chain configuration methods, then call
/// [`RuntimeBuilder::run`]. Every configuration method consumes the builder and
/// returns the updated one, so ignoring a returned builder silently discards the
/// configuration; `#[must_use]` turns that mistake into a compiler warning.
#[must_use = "a RuntimeBuilder does nothing until `.run(...)` is called"]
pub struct RuntimeBuilder {
    /// Runtime settings accumulated by the configuration methods.
    config: RuntimeConfig,
    /// Resources registered via [`RuntimeBuilder::resource`].
    resources: Vec<Box<dyn Resource>>,
    /// Certificate path set by [`RuntimeBuilder::tls_cert`].
    tls_cert_path: Option<std::path::PathBuf>,
    /// Private-key path set by [`RuntimeBuilder::tls_key`].
    tls_key_path: Option<std::path::PathBuf>,
    /// Certificate store set by [`RuntimeBuilder::tls_resolver`].
    tls_cert_store: Option<CertStore>,
    /// ACME configuration set by `tls_auto`.
    #[cfg(feature = "acme")]
    acme_config: Option<crate::acme::AcmeConfig>,
    /// DNS-01 ACME setup (including the API token) set by `tls_auto_dns01`.
    #[cfg(feature = "dns01")]
    dns01_setup: Option<crate::dns01::Dns01Setup>,
    /// OpenTelemetry exporter endpoint set by `otel_endpoint`.
    #[cfg(feature = "otel")]
    otel_endpoint: Option<Box<str>>,
}
impl RuntimeBuilder {
    /// Creates a builder holding `RuntimeConfig::default()` with no resources,
    /// TLS sources, or optional integrations configured.
    fn new() -> Self {
        Self {
            config: RuntimeConfig::default(),
            resources: Vec::new(),
            tls_cert_path: None,
            tls_key_path: None,
            tls_cert_store: None,
            #[cfg(feature = "acme")]
            acme_config: None,
            #[cfg(feature = "dns01")]
            dns01_setup: None,
            #[cfg(feature = "otel")]
            otel_endpoint: None,
        }
    }

    /// Sets the number of tokio worker threads.
    ///
    /// `0` is accepted here but rejected by [`RuntimeBuilder::run`].
    pub fn worker_threads(mut self, n: usize) -> Self {
        self.config.worker_threads = n;
        self
    }

    /// Sets the shutdown timeout, passed through `crate::time::clamp_duration`
    /// with a 100 ms minimum.
    ///
    /// The timeout bounds the wait for outstanding tasks when a shutdown has
    /// been requested, as well as the final shutdown of the tokio runtime.
    pub fn shutdown_timeout(mut self, timeout: Duration) -> Self {
        const MIN: Duration = Duration::from_millis(100);
        self.config.shutdown_timeout =
            crate::time::clamp_duration(timeout, MIN, "shutdown_timeout");
        self
    }

    /// Sets the keepalive timeout, passed through `crate::time::clamp_duration`
    /// with a 100 ms minimum.
    pub fn keepalive_timeout(mut self, timeout: Duration) -> Self {
        const MIN: Duration = Duration::from_millis(100);
        self.config.keepalive_timeout =
            crate::time::clamp_duration(timeout, MIN, "keepalive_timeout");
        self
    }

    /// Sets the interval handed to the resource health tasks; values below
    /// `MIN_HEALTH_INTERVAL` are raised to it.
    pub fn health_interval(mut self, interval: Duration) -> Self {
        self.config.health_interval = interval.max(MIN_HEALTH_INTERVAL);
        self
    }

    /// Sets `connection_limit` in the runtime configuration.
    ///
    /// `0` is accepted here but rejected by [`RuntimeBuilder::run`].
    pub fn connection_limit(mut self, n: usize) -> Self {
        self.config.connection_limit = Some(n);
        self
    }

    /// Registers a resource.
    ///
    /// While the runtime runs, registered resources get an initial health check
    /// and health tasks driven by `health_interval`. They are shut down after the
    /// closure passed to `run` returns.
    pub fn resource(mut self, r: impl Resource) -> Self {
        self.resources.push(Box::new(r));
        self
    }

    /// Sets `tracing_enabled` in the runtime configuration.
    pub fn with_tracing(mut self) -> Self {
        self.config.tracing_enabled = true;
        self
    }

    /// Enables metrics.
    ///
    /// `run` installs a Prometheus recorder (at most once per process) and
    /// stores its handle on the runtime.
    pub fn with_metrics(mut self) -> Self {
        self.config.metrics_enabled = true;
        self
    }

    /// Sets `profiling_enabled` in the runtime configuration.
    #[cfg(feature = "profiling")]
    pub fn with_profiling(mut self) -> Self {
        self.config.profiling_enabled = true;
        self
    }

    /// Sets the OpenTelemetry exporter endpoint. The exporter is initialised by
    /// `run` before the runtime starts.
    #[cfg(feature = "otel")]
    pub fn otel_endpoint(mut self, url: &str) -> Self {
        self.otel_endpoint = Some(Box::from(url));
        self
    }

    /// Sets the TLS certificate path.
    ///
    /// This option cannot be combined with `tls_auto` or `tls_auto_dns01`.
    pub fn tls_cert(mut self, path: &std::path::Path) -> Self {
        self.tls_cert_path = Some(path.to_path_buf());
        self
    }

    /// Sets the TLS private-key path.
    ///
    /// This option cannot be combined with `tls_auto` or `tls_auto_dns01`.
    pub fn tls_key(mut self, path: &std::path::Path) -> Self {
        self.tls_key_path = Some(path.to_path_buf());
        self
    }

    /// Supplies a pre-built certificate store.
    ///
    /// This option cannot be combined with `tls_auto` or `tls_auto_dns01`.
    pub fn tls_resolver(mut self, store: CertStore) -> Self {
        self.tls_cert_store = Some(store);
        self
    }

    /// Enables ACME-managed TLS.
    ///
    /// This option is mutually exclusive with the manual TLS options and with
    /// `tls_auto_dns01`.
    #[cfg(feature = "acme")]
    pub fn tls_auto(mut self, config: crate::acme::AcmeConfig) -> Self {
        self.acme_config = Some(config);
        self
    }

    /// Enables ACME TLS with DNS-01 validation for `domain`.
    ///
    /// DNS-01 setup is completed inside the tokio runtime before the closure
    /// passed to `run` executes.
    #[cfg(feature = "dns01")]
    pub fn tls_auto_dns01(
        mut self,
        acme: crate::dns01::AcmeDns01,
        api_token: Box<str>,
        domain: Box<str>,
    ) -> Self {
        self.dns01_setup = Some(crate::dns01::Dns01Setup {
            acme,
            api_token,
            domain,
        });
        self
    }

    /// Validates the configuration, resolves TLS, and runs `f` on the configured
    /// runtime, returning its result.
    ///
    /// # Errors
    /// - `RuntimeError::Tls` if mutually exclusive TLS options were combined.
    /// - `RuntimeError::InvalidArgument` if `worker_threads` or `connection_limit`
    ///   is 0.
    /// - Any error from TLS resolution, ACME or OTel initialisation, building the
    ///   tokio runtime, or DNS-01 setup.
    pub fn run<F, T>(self, f: F) -> Result<T, crate::RuntimeError>
    where
        F: FnOnce() -> T,
    {
        self.validate_tls_options()?;
        if self.config.worker_threads == 0 {
            return Err(crate::RuntimeError::InvalidArgument(
                "worker_threads must be at least 1".into(),
            ));
        }
        if self.config.connection_limit == Some(0) {
            return Err(crate::RuntimeError::InvalidArgument(
                "connection_limit must be at least 1".into(),
            ));
        }
        let mut config = self.config;
        let (tls_cfg, store) =
            crate::tls::resolve_tls(self.tls_cert_store, self.tls_cert_path, self.tls_key_path)?;
        config.tls_config = tls_cfg;
        config.cert_store = store;
        // Build the ACME state before starting the OTel exporter, so an invalid
        // ACME configuration cannot leave an exporter running.
        #[cfg(feature = "acme")]
        let acme_state = match self.acme_config {
            Some(acme_cfg) => {
                let (tls_cfg, state) = acme_cfg.build()?;
                config.tls_config = Some(tls_cfg);
                Some(state)
            }
            None => None,
        };
        #[cfg(feature = "otel")]
        let otel_started = match self.otel_endpoint {
            Some(ref endpoint) => {
                crate::http::otel::init_exporter(endpoint)?;
                true
            }
            None => false,
        };
        let result = run_inner_impl(
            config,
            self.resources.into_boxed_slice(),
            f,
            #[cfg(feature = "acme")]
            acme_state,
            #[cfg(feature = "dns01")]
            self.dns01_setup,
        );
        // `run_inner_impl` shuts the exporter down only on its success path. Its
        // early errors (tokio runtime build, DNS-01 setup) would otherwise leave the
        // exporter running.
        #[cfg(feature = "otel")]
        if otel_started && result.is_err() {
            crate::http::otel::shutdown_exporter();
        }
        result
    }

    /// Rejects TLS sources that cannot be combined.
    ///
    /// `tls_auto`, `tls_auto_dns01` and the manual options
    /// (`tls_cert`/`tls_key`/`tls_resolver`) are mutually exclusive.
    fn validate_tls_options(&self) -> Result<(), crate::RuntimeError> {
        let has_manual = self.tls_cert_path.is_some()
            || self.tls_key_path.is_some()
            || self.tls_cert_store.is_some();
        #[cfg(feature = "acme")]
        let has_acme = self.acme_config.is_some();
        #[cfg(not(feature = "acme"))]
        let has_acme = false;
        #[cfg(feature = "dns01")]
        let has_dns01 = self.dns01_setup.is_some();
        #[cfg(not(feature = "dns01"))]
        let has_dns01 = false;
        match (has_acme, has_dns01, has_manual) {
            (true, true, _) => Err(crate::RuntimeError::Tls(
                "tls_auto and tls_auto_dns01 are mutually exclusive".into(),
            )),
            (true, _, true) => Err(crate::RuntimeError::Tls(
                "tls_auto and tls_cert/tls_key/tls_resolver are mutually exclusive".into(),
            )),
            (_, true, true) => Err(crate::RuntimeError::Tls(
                "tls_auto_dns01 and tls_cert/tls_key/tls_resolver are mutually exclusive".into(),
            )),
            _ => Ok(()),
        }
    }
}
/// Starts configuring a runtime.
///
/// The returned builder holds the default configuration; finish it with
/// [`RuntimeBuilder::run`].
pub fn builder() -> RuntimeBuilder {
    RuntimeBuilder::new()
}
/// Runs `f` on a runtime tuned for tests: a 100 ms keepalive timeout and a 1 s
/// shutdown timeout, otherwise the defaults from [`builder`].
///
/// # Errors
/// Propagates any error from [`RuntimeBuilder::run`].
pub fn test<F, T>(f: F) -> Result<T, crate::RuntimeError>
where
    F: FnOnce() -> T,
{
    let test_runtime = builder()
        .keepalive_timeout(Duration::from_millis(100))
        .shutdown_timeout(Duration::from_secs(1));
    test_runtime.run(f)
}
/// Public entry point for async tests; forwards to `try_test_async`.
///
/// Hidden from the docs because it is not part of the supported API surface.
#[doc(hidden)]
pub fn __test_async<F, Fut, T>(f: F) -> Result<T, crate::RuntimeError>
where
    Fut: std::future::Future<Output = T>,
    F: FnOnce() -> Fut,
{
    try_test_async(f)
}
/// Runs the async test body produced by `f` on a fresh multi-threaded tokio
/// runtime, using a 100 ms keepalive timeout and a 1 s shutdown timeout.
///
/// After the body completes:
/// 1. outstanding tasks are awaited for at most the shutdown timeout;
/// 2. the global runtime state is torn down;
/// 3. the tokio runtime is shut down.
///
/// The global state is also torn down if the body panics, so a failing test
/// does not leave a runtime installed.
///
/// # Errors
/// Returns an error if the tokio runtime cannot be built.
fn try_test_async<F, Fut, T>(f: F) -> Result<T, crate::RuntimeError>
where
    F: FnOnce() -> Fut,
    Fut: std::future::Future<Output = T>,
{
    // Calls `teardown_runtime` when dropped: explicitly on the normal path, or
    // automatically while unwinding from a panicking test body.
    struct RuntimeTeardown;
    impl Drop for RuntimeTeardown {
        fn drop(&mut self) {
            teardown_runtime();
        }
    }

    let shutdown_timeout = Duration::from_secs(1);
    let tokio_rt = tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()?;
    let mut inner = RuntimeInner::with_config(RuntimeConfig {
        keepalive_timeout: Duration::from_millis(100),
        shutdown_timeout,
        ..RuntimeConfig::default()
    });
    inner.tokio_handle = Some(tokio_rt.handle().clone());
    let inner = Arc::new(inner);
    install_runtime(Arc::clone(&inner));
    let teardown = RuntimeTeardown;
    // Create the future inside the runtime context, so that any synchronous setup
    // `f` performs before returning its future can use tokio APIs such as
    // `tokio::spawn`.
    let result = tokio_rt.block_on(async move { f().await });
    wait_for_tasks_timeout(&inner, shutdown_timeout);
    drop(teardown);
    tokio_rt.shutdown_timeout(shutdown_timeout);
    Ok(result)
}
pub fn run<F, T>(f: F) -> Result<T, crate::RuntimeError>
where
F: FnOnce() -> T,
{
run_inner_impl(
RuntimeConfig::default(),
Vec::new().into_boxed_slice(),
f,
#[cfg(feature = "acme")]
None,
#[cfg(feature = "dns01")]
None,
)
}
/// Builds the tokio runtime, installs the global runtime state, runs `f`, and
/// tears everything down again.
///
/// Sequence:
/// 1. Build a multi-threaded tokio runtime with `config.worker_threads` workers.
/// 2. With `dns01`: complete DNS-01 setup and adopt its TLS config and cert store.
/// 3. Install the global `RuntimeInner` (metrics handle, tokio handle, health state).
/// 4. Inside the runtime:
///    - run the initial health checks;
///    - spawn the signal watcher, renewal and health tasks;
///    - call `f`;
///    - abort those background tasks.
/// 5. Wait for outstanding tasks. The wait is bounded by `shutdown_timeout` if a
///    shutdown was requested, and unbounded otherwise.
/// 6. Shut down resources (and the OTel exporter), tear down the global state,
///    then shut the tokio runtime down with `shutdown_timeout`.
///
/// If `f` panics, the global runtime state is still torn down while unwinding,
/// so a later runtime can be installed. Resources are not shut down in that case.
///
/// # Errors
/// Returns an error if the tokio runtime cannot be built or, with `dns01`, if
/// DNS-01 initialisation fails. Both happen before `f` runs.
fn run_inner_impl<F, T>(
    config: RuntimeConfig,
    resources: Box<[Box<dyn Resource>]>,
    f: F,
    #[cfg(feature = "acme")] acme_state: Option<crate::acme::AcmeState<std::io::Error>>,
    #[cfg(feature = "dns01")] dns01_setup: Option<crate::dns01::Dns01Setup>,
) -> Result<T, crate::RuntimeError>
where
    F: FnOnce() -> T,
{
    // Calls `teardown_runtime` when dropped: explicitly on the normal path, or
    // automatically while unwinding from a panicking `f`.
    struct RuntimeTeardown;
    impl Drop for RuntimeTeardown {
        fn drop(&mut self) {
            teardown_runtime();
        }
    }

    let shutdown_timeout = config.shutdown_timeout;
    let metrics_enabled = config.metrics_enabled;
    let tokio_rt = tokio::runtime::Builder::new_multi_thread()
        .worker_threads(config.worker_threads)
        .enable_all()
        .build()?;
    #[cfg(feature = "dns01")]
    let (config, dns01_state) = match dns01_setup {
        Some(setup) => {
            let state = tokio_rt.block_on(crate::dns01::init_dns01(setup))?;
            let mut cfg = config;
            cfg.tls_config = Some(state.tls_config.clone());
            cfg.cert_store = Some(state.store.clone());
            (cfg, Some(state))
        }
        None => (config, None),
    };
    let resources: Arc<[Box<dyn Resource>]> = resources.into();
    // One flag per resource, initially `true`; no health state without resources.
    let health_state: Option<HealthState> = if resources.is_empty() {
        None
    } else {
        Some(
            resources
                .iter()
                .map(|r| (Box::from(r.name()), AtomicBool::new(true)))
                .collect::<Vec<_>>()
                .into_boxed_slice()
                .into(),
        )
    };
    let health_interval = config.health_interval;
    let mut inner = RuntimeInner::with_config(config);
    inner.metrics_handle = install_metrics(metrics_enabled);
    inner.tokio_handle = Some(tokio_rt.handle().clone());
    inner.health_state = health_state.clone();
    let inner = Arc::new(inner);
    install_runtime(Arc::clone(&inner));
    // From here on the global runtime is installed; make sure it is removed even
    // if `f` panics.
    let teardown = RuntimeTeardown;
    let result = tokio_rt.block_on(async {
        if let Some(ref hs) = health_state {
            run_initial_health_checks(&resources, hs).await;
        }
        let signal_task = crate::signals::spawn_signal_watcher(
            inner.shutdown.clone(),
            inner.shutdown_notify.clone(),
        );
        #[cfg(feature = "acme")]
        let acme_task = acme_state.map(spawn_acme_renewal);
        #[cfg(feature = "dns01")]
        let dns01_task = dns01_state.map(|s| s.acme.spawn_renewal(s.provider, s.store));
        let health_tasks = spawn_health_tasks(
            &resources,
            &health_state,
            health_interval,
            &inner.shutdown_notify,
        );
        // `f` runs synchronously on the thread driving `block_on`.
        let value = f();
        signal_task.abort();
        for task in &health_tasks {
            task.abort();
        }
        #[cfg(feature = "acme")]
        if let Some(task) = acme_task {
            task.abort();
        }
        #[cfg(feature = "dns01")]
        if let Some(task) = dns01_task {
            task.abort();
        }
        value
    });
    // Bound the wait only when a shutdown was requested.
    if inner.shutdown.load(Ordering::Acquire) {
        wait_for_tasks_timeout(&inner, shutdown_timeout);
    } else {
        wait_for_tasks(&inner);
    }
    shutdown_resources(&resources);
    #[cfg(feature = "otel")]
    crate::http::otel::shutdown_exporter();
    drop(teardown);
    tokio_rt.shutdown_timeout(shutdown_timeout);
    Ok(result)
}
/// Builds a Prometheus recorder, tries to install it as the global `metrics`
/// recorder, and returns its render handle.
///
/// If installation fails (for example because a recorder is already installed),
/// a warning is logged and the handle is still returned.
fn init_prometheus_recorder() -> metrics_exporter_prometheus::PrometheusHandle {
    let builder = metrics_exporter_prometheus::PrometheusBuilder::new();
    let recorder = builder.build_recorder();
    let handle = recorder.handle();
    if let Err(e) = metrics::set_global_recorder(recorder) {
        tracing::warn!("failed to install global metrics recorder: {e}");
    }
    handle
}
/// Returns the process-wide Prometheus handle when `enabled`, or `None` otherwise.
///
/// The recorder is initialised lazily on the first enabled call. Later calls
/// clone the cached handle, so the global recorder is installed at most once.
fn install_metrics(enabled: bool) -> Option<metrics_exporter_prometheus::PrometheusHandle> {
    static PROMETHEUS_HANDLE: std::sync::OnceLock<metrics_exporter_prometheus::PrometheusHandle> =
        std::sync::OnceLock::new();
    if !enabled {
        return None;
    }
    let handle = PROMETHEUS_HANDLE.get_or_init(init_prometheus_recorder);
    Some(handle.clone())
}
/// Spawns a task that drains the ACME state stream and logs every event:
/// successes at `info`, errors at `warn`.
///
/// The task ends when the stream yields `None`. The returned handle lets the
/// caller abort it earlier.
#[cfg(feature = "acme")]
fn spawn_acme_renewal(
    state: crate::acme::AcmeState<std::io::Error>,
) -> tokio::task::JoinHandle<()> {
    use futures_util::StreamExt;
    let renewal_loop = async move {
        // Pin on the stack so `next()` can be called without requiring `Unpin`.
        let mut events = std::pin::pin!(state);
        while let Some(event) = events.next().await {
            match event {
                Err(err) => tracing::warn!("acme: {err}"),
                Ok(ok) => tracing::info!("acme: {ok:?}"),
            }
        }
    };
    tokio::spawn(renewal_loop)
}