use clap::{Parser, ValueHint};
use mtop::queue::{BlockingStatsQueue, Host, StatsQueue};
use mtop::ui::{TAILWIND, Theme};
use mtop_client::{Discovery, MemcachedClient, MtopError, Server, Timeout, TlsConfig};
use rustls_pki_types::{InvalidDnsNameError, ServerName};
use std::env;
use std::num::NonZeroU64;
use std::path::PathBuf;
use std::process::ExitCode;
use std::sync::Arc;
use std::time::Duration;
use tokio::runtime::Handle;
use tokio::task;
use tracing::instrument::WithSubscriber;
use tracing::{Instrument, Level};
// NOTE(review): 1073 ms looks deliberately chosen to be slightly longer than one
// second; the reason is not visible in this file — confirm before changing it.
/// Period of the `tokio::time::interval` that drives the background stats refresh in `main`.
const STATS_INTERVAL: Duration = Duration::from_millis(1073);
// NOTE(review): presumably the number of measurements retained per host — confirm
// against the `StatsQueue` implementation.
/// Value passed to `StatsQueue::new` when `main` creates the shared measurement queue.
const NUM_MEASUREMENTS: usize = 10;
/// Terminal UI that periodically polls memcached servers and displays their statistics
// The doc comments in this struct are not only documentation: clap's derive macro
// uses them as the `--help` text for the command and for every option. Keep each one
// short and written for the end user.
#[derive(Debug, Parser)]
#[command(name = "mtop", version = clap::crate_version!())]
struct MtopConfig {
    /// Logging verbosity, used for both console output and the log file
    #[arg(long, env = "MTOP_LOG_LEVEL", default_value_t = Level::INFO)]
    log_level: Level,

    /// Path to the resolv.conf file used to configure the DNS client that resolves the given hosts
    #[arg(long, env = "MTOP_RESOLV_CONF", default_value = "/etc/resolv.conf", value_hint = ValueHint::FilePath)]
    resolv_conf: PathBuf,

    /// Timeout, in seconds, for resolving host names and for each request sent to the servers. Must be greater than zero
    #[arg(long, env = "MTOP_TIMEOUT_SECS", default_value_t = NonZeroU64::new(5).unwrap())]
    timeout_secs: NonZeroU64,

    /// Number of connections the memcached client is configured with. Must be greater than zero
    #[arg(long, env = "MTOP_CONNECTIONS", default_value_t = NonZeroU64::new(2).unwrap())]
    connections: NonZeroU64,

    /// File that the background stats polling task writes its log messages to
    #[arg(long, env = "MTOP_LOG_FILE", default_value = default_log_file().into_os_string(), value_hint = ValueHint::FilePath)]
    log_file: PathBuf,

    /// Theme used by the UI
    #[arg(long, env = "MTOP_THEME", default_value_t = TAILWIND)]
    theme: Theme,

    /// Enable TLS for the memcached client. The other --tls-* options have no effect unless this is set
    #[arg(long, env = "MTOP_TLS_ENABLED")]
    tls_enabled: bool,

    /// Path to the certificate authority (CA) file to use for TLS
    #[arg(long, env = "MTOP_TLS_CA", value_hint = ValueHint::FilePath)]
    tls_ca: Option<PathBuf>,

    /// Server name to use for TLS
    #[arg(long, env = "MTOP_TLS_SERVER_NAME", value_parser = parse_server_name)]
    tls_server_name: Option<ServerName<'static>>,

    /// Path to the certificate file to use for TLS. Must be given together with --tls-key
    #[arg(long, env = "MTOP_TLS_CERT", requires = "tls_key", value_hint = ValueHint::FilePath)]
    tls_cert: Option<PathBuf>,

    /// Path to the private key file to use for TLS. Must be given together with --tls-cert
    #[arg(long, env = "MTOP_TLS_KEY", requires = "tls_cert", value_hint = ValueHint::FilePath)]
    tls_key: Option<PathBuf>,

    /// Memcached hosts to monitor. At least one is required
    #[arg(required = true, value_hint = ValueHint::Hostname)]
    hosts: Vec<String>,
}
impl TryInto<TlsConfig> for &MtopConfig {
type Error = ();
fn try_into(self) -> Result<TlsConfig, Self::Error> {
if self.tls_enabled {
Ok(TlsConfig {
ca_path: self.tls_ca.clone(),
cert_path: self.tls_cert.clone(),
key_path: self.tls_key.clone(),
server_name: self.tls_server_name.clone(),
})
} else {
Err(())
}
}
}
/// Value parser for `--tls-server-name`: converts the raw argument into an owned
/// `ServerName` so it can be stored in the `'static` config struct.
///
/// # Errors
///
/// Returns the `InvalidDnsNameError` produced by `ServerName::try_from` when the
/// argument is rejected.
fn parse_server_name(s: &str) -> Result<ServerName<'static>, InvalidDnsNameError> {
    // The parsed name borrows from `s`; detach it before returning.
    let borrowed = ServerName::try_from(s)?;
    Ok(borrowed.to_owned())
}
/// Default location of the log file: `mtop/mtop.log` inside the directory returned by
/// `std::env::temp_dir()`.
fn default_log_file() -> PathBuf {
    let mut path = env::temp_dir();
    path.push("mtop");
    path.push("mtop.log");
    path
}
#[tokio::main]
async fn main() -> ExitCode {
let opts = MtopConfig::parse();
let console_subscriber =
mtop::tracing::console_subscriber(opts.log_level).expect("failed to setup console logging");
tracing::subscriber::set_global_default(console_subscriber).expect("failed to initialize console logging");
let file_subscriber = match mtop::tracing::file_subscriber(opts.log_level, &opts.log_file).map(Arc::new) {
Ok(v) => v,
Err(e) => {
tracing::error!(message = "failed to initialize file logging", error = %e);
return ExitCode::FAILURE;
}
};
let timeout = Duration::from_secs(opts.timeout_secs.get());
let measurements = Arc::new(StatsQueue::new(NUM_MEASUREMENTS));
let dns_client = mtop::dns::new_client(&opts.resolv_conf, None, None).await;
let discovery = Discovery::new(dns_client);
let servers = match mtop::discovery::resolve(&opts.hosts, &discovery, timeout).await {
Ok(v) => v,
Err(e) => {
tracing::error!(message = "unable to resolve host names", hosts = ?opts.hosts, error = %e);
return ExitCode::FAILURE;
}
};
if servers.is_empty() {
tracing::error!(message = "resolving host names did not return any results", hosts = ?opts.hosts);
return ExitCode::FAILURE;
}
let client = match mtop::discovery::new_client(&servers, opts.connections.get(), (&opts).try_into().ok()).await {
Ok(v) => v,
Err(e) => {
tracing::error!(message = "unable to initialize memcached client", hosts = fmt_servers(&servers), err = %e);
return ExitCode::FAILURE;
}
};
let update_task = UpdateTask::new(client, measurements.clone(), timeout);
if let Err(e) = update_task.connect().await {
tracing::error!(message = "unable to connect to memcached servers", hosts = fmt_servers(&servers), err = %e);
return ExitCode::FAILURE;
}
task::spawn(
async move {
let mut interval = tokio::time::interval(STATS_INTERVAL);
loop {
let _ = interval.tick().await;
if let Err(e) = update_task
.update()
.instrument(tracing::span!(Level::INFO, "periodic.update"))
.await
{
tracing::error!(message = "unable to update server metrics", err = %e);
}
}
}
.with_subscriber(file_subscriber.clone()),
);
let ui_res = task::spawn_blocking(move || {
let mut term = mtop::ui::initialize_terminal()?;
mtop::ui::install_panic_handler();
let blocking_measurements = BlockingStatsQueue::new(measurements.clone(), Handle::current());
let hosts: Vec<Host> = servers.iter().map(|s| Host::from(s.id())).collect();
let app = mtop::ui::Application::new(&hosts, blocking_measurements, opts.theme);
mtop::ui::run(&mut term, app).and(mtop::ui::reset_terminal())
})
.await;
match ui_res {
Err(e) => {
tracing::error!(message = "unable to run UI in dedicated thread", err = %e);
ExitCode::FAILURE
}
Ok(Err(e)) => {
tracing::error!(message = "error setting up terminal or running UI", err = %e);
ExitCode::FAILURE
}
_ => ExitCode::SUCCESS,
}
}
/// Renders the IDs of `servers` as a bracketed, comma-separated list for log
/// messages, e.g. `[a, b]`; an empty slice yields `[]`.
fn fmt_servers(servers: &[Server]) -> String {
    let mut out = String::from("[");
    for (idx, server) in servers.iter().enumerate() {
        if idx > 0 {
            out.push_str(", ");
        }
        out.push_str(&server.id().to_string());
    }
    out.push(']');
    out
}
/// State needed to ping the memcached servers and to refresh their measurements:
/// the client to query, the queue results are inserted into, and the timeout
/// applied to each client call.
#[derive(Debug)]
struct UpdateTask {
// Client used for the ping, stats, slabs and items requests.
client: MemcachedClient,
// Shared queue that each refreshed per-host measurement is inserted into.
queue: Arc<StatsQueue>,
// Timeout applied to every request made through `client`.
timeout: Duration,
}
impl UpdateTask {
fn new(client: MemcachedClient, queue: Arc<StatsQueue>, timeout: Duration) -> Self {
UpdateTask { client, queue, timeout }
}
async fn connect(&self) -> Result<(), MtopError> {
let pings = self
.client
.ping()
.timeout(self.timeout, "client.ping")
.instrument(tracing::span!(Level::INFO, "client.ping"))
.await?;
if let Some((_server, err)) = pings.errors.into_iter().next() {
return Err(err);
}
Ok(())
}
async fn update(&self) -> Result<(), MtopError> {
let stats = self
.client
.stats()
.timeout(self.timeout, "client.stats")
.instrument(tracing::span!(Level::INFO, "client.stats"))
.await?;
let mut slabs = self
.client
.slabs()
.timeout(self.timeout, "client.slabs")
.instrument(tracing::span!(Level::INFO, "client.slabs"))
.await?;
let mut items = self
.client
.items()
.timeout(self.timeout, "client.items")
.instrument(tracing::span!(Level::INFO, "client.items"))
.await?;
for (id, stats) in stats.values {
let slabs = match slabs.values.remove(&id) {
Some(v) => v,
None => continue,
};
let items = match items.values.remove(&id) {
Some(v) => v,
None => continue,
};
self.queue
.insert(Host::from(&id), stats, slabs, items)
.instrument(tracing::span!(Level::INFO, "queue.insert"))
.await;
}
for (id, e) in stats.errors {
tracing::warn!(message = "error fetching stats", server = %id, err = %e);
}
for (id, e) in slabs.errors {
tracing::warn!(message = "error fetching slabs", server = %id, err = %e);
}
for (id, e) in items.errors {
tracing::warn!(message = "error fetching items", server = %id, err = %e);
}
Ok(())
}
}