use crate::config::Config;
use crate::config::topology::Topology;
use crate::hot_reload::client::HotReloadClient;
use crate::observability::LogFilterHttpExporter;
use anyhow::Context;
use anyhow::{Result, anyhow};
use clap::{Parser, crate_version};
use metrics_exporter_prometheus::PrometheusBuilder;
use rustls::crypto::aws_lc_rs::default_provider;
use std::env;
use std::net::SocketAddr;
use std::time::Duration;
use tokio::runtime::{self, Runtime};
use tokio::signal::unix::{SignalKind, signal};
use tokio::sync::watch;
use tracing::{error, info, warn};
use tracing_appender::non_blocking::{NonBlocking, WorkerGuard};
use tracing_subscriber::filter::Directive;
use tracing_subscriber::fmt::Layer;
use tracing_subscriber::fmt::format::DefaultFields;
use tracing_subscriber::fmt::format::Format;
use tracing_subscriber::fmt::format::Full;
use tracing_subscriber::fmt::format::Json;
use tracing_subscriber::fmt::format::JsonFields;
use tracing_subscriber::layer::Layered;
use tracing_subscriber::reload::Handle;
use tracing_subscriber::{EnvFilter, Registry};
// Command line options for the shotover process, parsed with clap's derive API.
//
// NOTE: plain `//` comments are used here on purpose. clap's derive picks up
// `///` doc comments as `--help` text, so adding them would change the
// program's output.
#[derive(Parser, Clone)]
#[clap(version = crate_version!(), author = "Instaclustr")]
struct ConfigOpts {
// Path of the topology file, loaded with `Topology::from_file`.
#[clap(short, long, default_value = "config/topology.yaml")]
pub topology_file: String,
// Path of the main config file, loaded with `Config::from_file`.
#[clap(short, long, default_value = "config/config.yaml")]
pub config_file: String,
// Number of tokio worker threads; when absent the runtime builder's own
// default is left untouched (see `Shotover::create_runtime`).
#[clap(long)]
pub core_threads: Option<usize>,
// Passed to the tokio runtime builder's `thread_stack_size`.
// The default 2097152 is 2 * 1024 * 1024.
#[clap(long, default_value = "2097152")]
pub stack_size: usize,
// Selects human readable or JSON log output.
#[arg(long, value_enum, default_value = "human")]
pub log_format: LogFormat,
// When set, this value is used both to build the `HotReloadClient` and as the
// argument to `start_hot_reload_server` (see `Shotover::run_inner`).
#[clap(long)]
pub hotreload_socket: Option<String>,
// Converted to a `Duration` in seconds and passed to
// `request_shutdown_old_instance` when running as a hot reload client.
#[clap(long, default_value = "60")]
pub hotreload_gradual_shutdown_seconds: u64,
}
// Output format of the tracing subscriber, selectable with `--log-format`.
//
// NOTE: `//` comments are used deliberately; `///` doc comments on a
// `clap::ValueEnum` are picked up by clap as help text.
#[derive(clap::ValueEnum, Clone, Copy)]
enum LogFormat {
// Default `tracing_subscriber::fmt()` output.
Human,
// `tracing_subscriber::fmt().json()` output; also installs the crate's
// tracing panic handler (see `TracingState::new`).
Json,
}
/// A fully configured shotover instance.
///
/// Built by [`Shotover::new`] from the command line and the config/topology
/// files, and consumed by [`Shotover::run_block`].
pub struct Shotover {
/// Tokio runtime on which the chains, signal handling and the observability
/// exporter are executed.
runtime: Runtime,
/// Topology loaded from the file given by `--topology-file`.
topology: Topology,
/// Configuration loaded from the file given by `--config-file`.
config: Config,
/// Tracing state; held until `run_block` drops it explicitly just before the
/// process exits.
tracing: TracingState,
/// Value of `--hotreload-socket`, if any.
hotreload_socket: Option<String>,
/// `--hotreload-gradual-shutdown-seconds` converted to a [`Duration`].
hotreload_gradual_shutdown_duration: Duration,
}
impl Shotover {
/// Parses the command line and builds a ready-to-run [`Shotover`].
///
/// This function does not return an error. If start-up fails, the error is
/// logged through a minimal runtime and tracing setup created only for that
/// purpose, and the process exits with status 1.
///
/// # Panics
///
/// Panics if the default rustls crypto provider cannot be installed, or if
/// the fallback runtime / tracing state needed to report a start-up error
/// cannot be created. In the latter case the panic message includes the
/// original start-up error.
#[expect(clippy::new_without_default)]
pub fn new() -> Self {
    // Only provide a value when the operator has not already chosen one.
    if std::env::var("RUST_LIB_BACKTRACE").is_err() {
        std::env::set_var("RUST_LIB_BACKTRACE", "0");
    }
    default_provider().install_default().unwrap();

    let opts = ConfigOpts::parse();
    // Copied out before `opts` is moved; needed on the error path below.
    let log_format = opts.log_format;

    match Shotover::new_inner(opts) {
        Ok(shotover) => shotover,
        Err(err) => {
            // Inner scope: the temporary runtime and tracing state must be
            // dropped before `process::exit`, which does not run destructors.
            {
                // `with_context` + `format!` is required here: a plain string
                // literal passed to `context` would print `{err:?}` verbatim
                // and lose the original error.
                let rt = Runtime::new()
                    .with_context(|| {
                        format!("Failed to create runtime while trying to report {err:?}")
                    })
                    .unwrap();
                let _guard = rt.enter();
                let _tracing_state = TracingState::new("error", log_format)
                    .with_context(|| {
                        format!("Failed to create TracingState while trying to report {err:?}")
                    })
                    .unwrap();
                tracing::error!("{:?}", err.context("Failed to start shotover"));
            }
            std::process::exit(1);
        }
    }
}
/// Fallible part of [`Shotover::new`].
///
/// Loads the config and topology files, initialises tracing at the
/// configured `main_log_level`, creates the tokio runtime and, if configured,
/// starts the observability interface.
///
/// # Errors
///
/// Returns an error if either file cannot be loaded, if tracing cannot be
/// set up, or if the observability interface fails to start.
fn new_inner(params: ConfigOpts) -> Result<Self> {
    let config = Config::from_file(params.config_file)?;
    // Borrow the path: `params` is still needed for the remaining fields.
    let topology = Topology::from_file(&params.topology_file)?;
    let tracing = TracingState::new(config.main_log_level.as_str(), params.log_format)?;
    let runtime = Shotover::create_runtime(params.stack_size, params.core_threads);
    let hotreload_socket = params.hotreload_socket;

    Shotover::start_observability_interface(&runtime, &config, &tracing)?;

    Ok(Shotover {
        runtime,
        topology,
        config,
        tracing,
        hotreload_socket,
        hotreload_gradual_shutdown_duration: Duration::from_secs(
            params.hotreload_gradual_shutdown_seconds,
        ),
    })
}
/// Starts the metrics / log-filter HTTP exporter when
/// `config.observability_interface` is set; otherwise does nothing.
///
/// Installs a Prometheus recorder as the global `metrics` recorder and
/// spawns the exporter on `runtime`.
///
/// # Errors
///
/// Returns an error if the global recorder cannot be installed or if the
/// configured interface does not parse as a [`SocketAddr`].
fn start_observability_interface(
    runtime: &Runtime,
    config: &Config,
    tracing: &TracingState,
) -> Result<()> {
    let Some(address) = &config.observability_interface else {
        return Ok(());
    };

    let builder = PrometheusBuilder::new()
        .set_quantiles(&[0.0, 0.1, 0.5, 0.9, 0.95, 0.99, 0.999, 1.0])
        .unwrap();
    let recorder = builder.build_recorder();
    let metrics_handle = recorder.handle();
    metrics::set_global_recorder(recorder)?;

    let listen_on: SocketAddr = address.parse()?;
    let log_filter_handle = tracing.handle.clone();
    let exporter = LogFilterHttpExporter::new(metrics_handle, listen_on, log_filter_handle);
    runtime.spawn(exporter.async_run());

    Ok(())
}
/// Runs the topology's chains until every source has finished.
///
/// When a hot reload client could be created from `hotreload_socket`, it
/// first requests the listening sockets from the existing shotover instance
/// and, once the chains are running, asks that instance to shut down. When
/// `hotreload_socket` is set, a hot reload server is then started for the
/// new sources.
///
/// # Errors
///
/// Returns an error if the hot reload handoff fails or if the chains fail to
/// start. A failure to request shutdown of the old instance is only logged.
async fn run_inner(
    topology: Topology,
    config: Config,
    hotreload_socket: Option<String>,
    hotreload_gradual_shutdown_duration: Duration,
    trigger_shutdown_rx: watch::Receiver<bool>,
) -> Result<()> {
    let hotreload_client = hotreload_socket.clone().and_then(HotReloadClient::new);

    let hotreload_listeners = match hotreload_client.as_ref() {
        Some(client) => {
            info!("Hot reload CLIENT mode - requesting socket handoff from existing shotover");
            client
                .perform_hot_reloading()
                .await
                .context("Hot reload client failed")?
        }
        None => std::collections::HashMap::new(),
    };

    info!("Starting Shotover {}", crate_version!());
    info!(configuration = ?config);
    info!(topology = ?topology);

    let sources = topology
        .run_chains(trigger_shutdown_rx, hotreload_listeners)
        .await?;

    // The new chains are up, so the previous instance can be asked to stop.
    if let Some(client) = hotreload_client.as_ref() {
        let shutdown_request = client
            .request_shutdown_old_instance(hotreload_gradual_shutdown_duration)
            .await;
        if let Err(e) = shutdown_request {
            warn!(
                "Failed to send shutdown request to old shotover instance: {}",
                e
            );
        }
    }

    if let Some(socket_path) = hotreload_socket {
        info!("Starting hot reload server at: {}", socket_path);
        crate::hot_reload::server::start_hot_reload_server(socket_path, &sources);
    }

    let pending_joins = sources.into_iter().map(|source| source.join());
    futures::future::join_all(pending_joins).await;
    Ok(())
}
pub fn run_block(self) -> ! {
let Shotover {
runtime,
topology,
config,
tracing,
hotreload_socket,
hotreload_gradual_shutdown_duration,
} = self;
let (trigger_shutdown_tx, trigger_shutdown_rx) = tokio::sync::watch::channel(false);
let (mut interrupt, mut terminate) = runtime.block_on(async {
(
signal(SignalKind::interrupt()).unwrap(),
signal(SignalKind::terminate()).unwrap(),
)
});
runtime.spawn(async move {
tokio::select! {
_ = interrupt.recv() => {
info!("received SIGINT");
},
_ = terminate.recv() => {
info!("received SIGTERM");
},
};
trigger_shutdown_tx.send(true).unwrap();
});
let code = match runtime.block_on(Shotover::run_inner(
topology,
config,
hotreload_socket,
hotreload_gradual_shutdown_duration,
trigger_shutdown_rx,
)) {
Ok(()) => {
info!("Shotover was shutdown cleanly.");
0
}
Err(err) => {
error!("{:?}", err.context("Failed to start shotover"));
1
}
};
std::mem::drop(tracing);
std::mem::drop(runtime);
std::process::exit(code);
}
/// Builds the multi-threaded tokio runtime used by shotover.
///
/// Worker threads are named `shotover-worker` and get a stack of
/// `stack_size`. `worker_threads` is applied only when it is `Some`;
/// otherwise the builder's default is kept.
///
/// # Panics
///
/// Panics if the runtime cannot be built.
fn create_runtime(stack_size: usize, worker_threads: Option<usize>) -> Runtime {
    let mut builder = runtime::Builder::new_multi_thread();
    builder.enable_all();
    builder.thread_name("shotover-worker");
    builder.thread_stack_size(stack_size);

    if let Some(count) = worker_threads {
        builder.worker_threads(count);
    }

    builder.build().unwrap()
}
}
/// Owns the pieces of the tracing setup that must outlive initialisation.
struct TracingState {
/// Guard returned by `tracing_appender::non_blocking`; stored so that it lives
/// exactly as long as this `TracingState`.
_guard: WorkerGuard,
/// Handle for swapping the active `EnvFilter`; cloned into the
/// `LogFilterHttpExporter` by `Shotover::start_observability_interface`.
handle: ReloadHandle,
}
/// Builds an [`EnvFilter`] from several optional, comma separated directive
/// strings.
///
/// `None` entries are skipped. Each remaining string is split on `,`, every
/// piece is trimmed, and empty pieces are ignored. Directives are added to
/// the filter in the order they appear.
///
/// # Errors
///
/// Returns an error of the form `"<parse error>: <directive>"` for the first
/// piece that does not parse as a [`Directive`]. No filter is built in that
/// case.
fn try_parse_log_directives(directives: &[Option<&str>]) -> Result<EnvFilter> {
    // Parse everything first so that an invalid piece aborts before any
    // directive is applied.
    let mut parsed: Vec<Directive> = Vec::new();
    for group in directives.iter().flatten() {
        for raw in group.split(',') {
            let candidate = raw.trim();
            if candidate.is_empty() {
                continue;
            }
            let directive = candidate
                .parse::<Directive>()
                .map_err(|e| anyhow!("{}: {}", e, candidate))?;
            parsed.push(directive);
        }
    }

    let mut filter = EnvFilter::default();
    for directive in parsed {
        filter = filter.add_directive(directive);
    }
    Ok(filter)
}
impl TracingState {
pub fn new(log_level: &str, format: LogFormat) -> Result<Self> {
let (non_blocking, guard) = tracing_appender::non_blocking(std::io::stdout());
let overrides = env::var(EnvFilter::DEFAULT_ENV).ok();
let env_filter = try_parse_log_directives(&[Some(log_level), overrides.as_deref()])?;
let handle = match format {
LogFormat::Json => {
let builder = tracing_subscriber::fmt()
.json()
.with_writer(non_blocking)
.with_env_filter(env_filter)
.with_filter_reloading();
let handle = ReloadHandle::Json(builder.reload_handle());
builder.init();
handle
}
LogFormat::Human => {
let builder = tracing_subscriber::fmt()
.with_writer(non_blocking)
.with_env_filter(env_filter)
.with_filter_reloading();
let handle = ReloadHandle::Human(builder.reload_handle());
builder.init();
handle
}
};
if let LogFormat::Json = format {
crate::tracing_panic_handler::setup();
}
Ok(TracingState {
_guard: guard,
handle,
})
}
}
/// Names the subscriber type that sits underneath the reloadable filter:
/// a `Registry` with a `fmt` layer using field formatter `A`, event format
/// `B`, and a `NonBlocking` writer. Only used to spell out the `Handle` types
/// in [`ReloadHandle`].
type Formatter<A, B> = Layered<Layer<Registry, A, Format<B>, NonBlocking>, Registry>;
/// Handle for replacing the active `EnvFilter` at runtime.
///
/// There is one variant per log format because the underlying `Handle` type
/// depends on the formatter in use.
#[derive(Clone)]
pub(crate) enum ReloadHandle {
/// Handle for the JSON subscriber (`LogFormat::Json`).
Json(Handle<EnvFilter, Formatter<JsonFields, Json>>),
/// Handle for the human readable subscriber (`LogFormat::Human`).
Human(Handle<EnvFilter, Formatter<DefaultFields, Full>>),
}
impl ReloadHandle {
    /// Replaces the currently active filter with `filter`.
    ///
    /// # Errors
    ///
    /// Returns the underlying reload error, wrapped in an `anyhow::Error`.
    pub fn reload(&self, filter: EnvFilter) -> Result<()> {
        let outcome = match self {
            Self::Json(handle) => handle.reload(filter),
            Self::Human(handle) => handle.reload(filter),
        };
        outcome.map_err(|e| anyhow!(e))
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Covers both outcomes of `try_parse_log_directives`: several directive
    /// groups (including a `None`) merged into one filter, and the error
    /// reported for an invalid directive.
    #[test]
    fn test_try_parse_log_directives() {
        let groups = [
            Some("info,short=warn,error"),
            None,
            Some("debug"),
            Some("alongname=trace"),
        ];
        let filter = try_parse_log_directives(&groups).unwrap();
        assert_eq!(filter.to_string(), "alongname=trace,short=warn,debug");

        // `bad=blah` is not a valid level, so parsing must fail and name it.
        let Err(e) = try_parse_log_directives(&[Some("good=info,bad=blah,warn")]) else {
            panic!()
        };
        assert_eq!(e.to_string(), "invalid filter directive: bad=blah");
    }
}