use clap::{Args, Parser, Subcommand, ValueHint};
use mtop::bench::{Bencher, Percent, Summary};
use mtop::check::{Bundle, Checker};
use mtop::profile;
use mtop_client::{Discovery, MemcachedClient, Meta, MtopError, Timeout, TlsConfig, Value};
use rustls_pki_types::{InvalidDnsNameError, ServerName};
use std::num::{NonZeroU64, NonZeroUsize};
use std::path::PathBuf;
use std::process::ExitCode;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::time::Duration;
use std::{env, io, slice};
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
use tokio::runtime::Handle;
use tracing::{Instrument, Level};
// CLI configuration for the `mc` tool, parsed by clap from flags and `MC_*`
// environment variables.
// NOTE: plain `//` comments are used deliberately throughout these clap
// types; `///` doc comments would be picked up by clap derive and change the
// rendered --help output.
#[derive(Debug, Parser)]
#[command(name = "mc", version = clap::crate_version!())]
struct McConfig {
// Log verbosity for console output.
#[arg(long, env = "MC_LOG_LEVEL", default_value_t = Level::INFO)]
log_level: Level,
// resolv.conf used to configure the DNS client for host resolution.
#[arg(long, env = "MC_RESOLV_CONF", default_value = "/etc/resolv.conf", value_hint = ValueHint::FilePath)]
resolv_conf: PathBuf,
// Memcached host to connect to, as host:port.
#[arg(long, env = "MC_HOST", default_value = "localhost:11211", value_hint = ValueHint::Hostname)]
host: String,
// Timeout, in seconds, applied to DNS resolution and client operations.
#[arg(long, env = "MC_TIMEOUT_SECS", default_value_t = NonZeroU64::new(30).unwrap())]
timeout_secs: NonZeroU64,
// Number of connections the memcached client is created with.
#[arg(long, env = "MC_CONNECTIONS", default_value_t = NonZeroU64::new(4).unwrap())]
connections: NonZeroU64,
// When set, profiling output is written to this path before exit.
#[arg(long, env = "MC_PROFILE_OUTPUT", value_hint = ValueHint::FilePath)]
profile_output: Option<PathBuf>,
// Enable TLS; the remaining tls_* options are only consulted when true.
#[arg(long, env = "MC_TLS_ENABLED")]
tls_enabled: bool,
// Optional CA bundle used to validate the server certificate.
#[arg(long, env = "MC_TLS_CA", value_hint = ValueHint::FilePath)]
tls_ca: Option<PathBuf>,
// Optional SNI name to present; parsed/validated at argument-parse time.
#[arg(long, env = "MC_TLS_SERVER_NAME", value_parser = parse_server_name)]
tls_server_name: Option<ServerName<'static>>,
// Client certificate for mutual TLS; cert and key must be given together
// (enforced by the `requires` constraints).
#[arg(long, env = "MC_TLS_CERT", requires = "tls_key", value_hint = ValueHint::FilePath)]
tls_cert: Option<PathBuf>,
#[arg(long, env = "MC_TLS_KEY", requires = "tls_cert", value_hint = ValueHint::FilePath)]
tls_key: Option<PathBuf>,
// Subcommand selecting which operation to run.
#[command(subcommand)]
mode: Action,
}
impl TryInto<TlsConfig> for &McConfig {
type Error = ();
fn try_into(self) -> Result<TlsConfig, Self::Error> {
if self.tls_enabled {
Ok(TlsConfig {
ca_path: self.tls_ca.clone(),
cert_path: self.tls_cert.clone(),
key_path: self.tls_key.clone(),
server_name: self.tls_server_name.clone(),
})
} else {
Err(())
}
}
}
/// Parse a TLS server name (SNI) from its string form, converting the
/// borrowed parse result into an owned, `'static` value so it can be stored
/// in the parsed CLI configuration.
fn parse_server_name(s: &str) -> Result<ServerName<'static>, InvalidDnsNameError> {
    let name = ServerName::try_from(s)?;
    Ok(name.to_owned())
}
// Subcommand dispatched from main(); one variant per supported operation.
// NOTE: `//` comments are intentional — `///` would become clap help text.
#[derive(Debug, Subcommand)]
enum Action {
Add(AddCommand),
Bench(BenchCommand),
Check(CheckCommand),
Decr(DecrCommand),
Delete(DeleteCommand),
FlushAll(FlushAllCommand),
Get(GetCommand),
Incr(IncrCommand),
Keys(KeysCommand),
Replace(ReplaceCommand),
Set(SetCommand),
Touch(TouchCommand),
}
// Arguments for `mc add`: store an item whose payload is read from stdin.
// Field order is significant: `key` then `ttl_secs` are positional.
#[derive(Debug, Args)]
struct AddCommand {
// Key to store the item under.
#[arg(required = true)]
key: String,
// Item TTL in seconds (0 presumably means no expiry — memcached semantics,
// confirm against the server).
#[arg(required = true)]
ttl_secs: u32,
}
// Arguments for `mc bench`: tunables for the load-generation benchmark.
#[derive(Debug, Args)]
struct BenchCommand {
// Total benchmark duration in seconds.
#[arg(long, env = "MC_BENCH_TIME_SECS", default_value_t = NonZeroU64::new(60).unwrap())]
time_secs: NonZeroU64,
// Fraction of operations that are writes (0.05 = 5%).
#[arg(long, env = "MC_BENCH_WRITE_PERCENT", default_value_t = Percent::unchecked(0.05))]
write_percent: Percent,
// Number of concurrent benchmark workers.
#[arg(long, env = "MC_BENCH_CONCURRENCY", default_value_t = NonZeroUsize::new(1).unwrap())]
concurrency: NonZeroUsize,
// Delay between operations, in milliseconds.
#[arg(long, env = "MC_BENCH_DELAY_MILLIS", default_value_t = NonZeroU64::new(100).unwrap())]
delay_millis: NonZeroU64,
// TTL applied to items written by the benchmark.
#[arg(long, env = "MC_BENCH_TTL_SECS", default_value_t = 300)]
ttl_secs: u32,
}
// Arguments for `mc check`: duration and pacing of repeated health checks.
#[derive(Debug, Args)]
struct CheckCommand {
// Total time to run checks for, in seconds.
#[arg(long, env = "MC_CHECK_TIME_SECS", default_value_t = NonZeroU64::new(60).unwrap())]
time_secs: NonZeroU64,
// Delay between individual checks, in milliseconds.
#[arg(long, env = "MC_CHECK_DELAY_MILLIS", default_value_t = NonZeroU64::new(100).unwrap())]
delay_millis: NonZeroU64,
}
// Arguments for `mc decr`: decrement a numeric counter. Positional order:
// key then delta.
#[derive(Debug, Args)]
struct DecrCommand {
// Key of the counter to decrement.
#[arg(required = true)]
key: String,
// Amount to decrement by.
#[arg(required = true)]
delta: u64,
}
// Arguments for `mc delete`: remove a single item by key.
#[derive(Debug, Args)]
struct DeleteCommand {
// Key of the item to delete.
#[arg(required = true)]
key: String,
}
// Arguments for `mc flush-all`: schedule a cache flush on every server.
#[derive(Debug, Args)]
struct FlushAllCommand {
// Optional delay before the flush takes effect, in seconds; immediate
// when omitted.
#[arg(long, env = "MC_FLUSH_ALL_WAIT_SECS")]
wait_secs: Option<NonZeroU64>,
}
// Arguments for `mc get`: fetch one item and write its payload to stdout.
#[derive(Debug, Args)]
struct GetCommand {
// Key of the item to fetch.
#[arg(required = true)]
key: String,
}
// Arguments for `mc incr`: increment a numeric counter. Positional order:
// key then delta.
#[derive(Debug, Args)]
struct IncrCommand {
// Key of the counter to increment.
#[arg(required = true)]
key: String,
// Amount to increment by.
#[arg(required = true)]
delta: u64,
}
// Arguments for `mc keys`: list all keys across servers.
#[derive(Debug, Args)]
struct KeysCommand {
// When true, print expiry and size alongside each key (tab separated).
#[arg(long, env = "MC_KEYS_DETAILS")]
details: bool,
}
// Arguments for `mc replace`: overwrite an existing item with a payload read
// from stdin. Positional order: key then ttl_secs.
#[derive(Debug, Args)]
struct ReplaceCommand {
// Key of the item to replace.
#[arg(required = true)]
key: String,
// New TTL for the item, in seconds.
#[arg(required = true)]
ttl_secs: u32,
}
// Arguments for `mc set`: unconditionally store an item with a payload read
// from stdin. Positional order: key then ttl_secs.
#[derive(Debug, Args)]
struct SetCommand {
// Key to store the item under.
#[arg(required = true)]
key: String,
// Item TTL in seconds.
#[arg(required = true)]
ttl_secs: u32,
}
// Arguments for `mc touch`: update an existing item's TTL without changing
// its value. Positional order: key then ttl_secs.
#[derive(Debug, Args)]
struct TouchCommand {
// Key of the item to touch.
#[arg(required = true)]
key: String,
// New TTL for the item, in seconds.
#[arg(required = true)]
ttl_secs: u32,
}
// Entrypoint: parse CLI options, initialize logging, resolve the target host
// via DNS, build the memcached client, verify connectivity, then dispatch to
// the handler for the chosen subcommand. The process exit code reflects the
// handler's result.
#[tokio::main]
async fn main() -> ExitCode {
let opts = McConfig::parse();
// Logging comes up first so every later failure path can report itself.
let console_subscriber =
mtop::tracing::console_subscriber(opts.log_level).expect("failed to setup console logging");
tracing::subscriber::set_global_default(console_subscriber).expect("failed to initialize console logging");
let dns_client = mtop::dns::new_client(&opts.resolv_conf, None, None).await;
let discovery = Discovery::new(dns_client);
let timeout = Duration::from_secs(opts.timeout_secs.get());
// The single configured host may resolve to multiple servers.
let servers = match mtop::discovery::resolve_single(&opts.host, &discovery, timeout).await {
Ok(v) => v,
Err(e) => {
tracing::error!(message = "unable to resolve host names", hosts = ?opts.host, err = %e);
return ExitCode::FAILURE;
}
};
// The TLS conversion fails when --tls-enabled is not set; `.ok()` maps
// that to `None`, i.e. a plaintext client.
let client = match mtop::discovery::new_client(&servers, opts.connections.get(), (&opts).try_into().ok()).await {
Ok(v) => v,
Err(e) => {
tracing::error!(message = "unable to initialize memcached client", host = opts.host, err = %e);
return ExitCode::FAILURE;
}
};
// Fail fast if no resolved server responds to a ping.
if let Err(e) = connect(&client, timeout).await {
tracing::error!(message = "unable to connect", host = opts.host, err = %e);
return ExitCode::FAILURE;
};
// NOTE(review): the profiling writer appears to collect implicitly and is
// only flushed below when --profile-output is set — confirm the
// profile::Writer contract.
let profiling = profile::Writer::default();
let code = match &opts.mode {
Action::Add(cmd) => run_add(&opts, cmd, &client).await,
Action::Bench(cmd) => run_bench(&opts, cmd, client).await,
Action::Check(cmd) => run_check(&opts, cmd, client, discovery).await,
Action::Decr(cmd) => run_decr(&opts, cmd, &client).await,
Action::Delete(cmd) => run_delete(&opts, cmd, &client).await,
Action::FlushAll(cmd) => run_flush_all(&opts, cmd, &client).await,
Action::Get(cmd) => run_get(&opts, cmd, &client).await,
Action::Incr(cmd) => run_incr(&opts, cmd, &client).await,
Action::Keys(cmd) => run_keys(&opts, cmd, &client).await,
Action::Replace(cmd) => run_replace(&opts, cmd, &client).await,
Action::Set(cmd) => run_set(&opts, cmd, &client).await,
Action::Touch(cmd) => run_touch(&opts, cmd, &client).await,
};
if let Some(p) = opts.profile_output {
profiling.finish(p);
}
code
}
/// Verify connectivity by pinging every server through the client, returning
/// the first per-server error if any server failed to respond.
async fn connect(client: &MemcachedClient, timeout: Duration) -> Result<(), MtopError> {
    let results = client
        .ping()
        .timeout(timeout, "client.ping")
        .instrument(tracing::span!(Level::INFO, "client.ping"))
        .await?;
    match results.errors.into_iter().next() {
        Some((_server, err)) => Err(err),
        None => Ok(()),
    }
}
/// Issue a memcached `add` for `cmd.key` with a payload read from stdin,
/// logging failures. Returns a non-zero exit code on read or protocol error.
async fn run_add(opts: &McConfig, cmd: &AddCommand, client: &MemcachedClient) -> ExitCode {
    let payload = match read_input().await {
        Ok(v) => v,
        Err(e) => {
            tracing::error!(message = "unable to read item data from stdin", err = %e);
            return ExitCode::FAILURE;
        }
    };
    let result = client
        .add(&cmd.key, 0, cmd.ttl_secs, &payload)
        .timeout(Duration::from_secs(opts.timeout_secs.get()), "client.add")
        .instrument(tracing::span!(Level::INFO, "client.add"))
        .await;
    match result {
        Ok(_) => ExitCode::SUCCESS,
        Err(e) => {
            tracing::error!(message = "unable to add item", key = cmd.key, host = opts.host, err = %e);
            ExitCode::FAILURE
        }
    }
}
/// Run a load-generation benchmark against the server and print per-worker
/// throughput results. Runs until the configured time elapses or an
/// interrupt signal sets the shared stop flag.
async fn run_bench(opts: &McConfig, cmd: &BenchCommand, client: MemcachedClient) -> ExitCode {
    // Shared flag flipped by the interrupt handler so workers can stop early.
    let stop = Arc::new(AtomicBool::new(false));
    mtop::sig::wait_for_interrupt(Handle::current(), stop.clone()).await;
    let bencher = Bencher::new(
        client,
        Handle::current(),
        Duration::from_millis(cmd.delay_millis.get()),
        Duration::from_secs(opts.timeout_secs.get()),
        // u64::from is the lossless, idiomatic u32 -> u64 widening (no `as`).
        Duration::from_secs(u64::from(cmd.ttl_secs)),
        cmd.write_percent,
        cmd.concurrency.get(),
        stop.clone(),
    );
    // Use .get() for the NonZeroU64 -> u64 conversion, consistent with every
    // other duration construction in this file (previously `.into()`).
    let measurements = bencher.run(Duration::from_secs(cmd.time_secs.get())).await;
    print_bench_results(&measurements);
    ExitCode::SUCCESS
}
/// Run periodic health checks against the host until the configured time
/// elapses (or an interrupt arrives), print aggregate timings, and fail if
/// any check failed.
async fn run_check(opts: &McConfig, cmd: &CheckCommand, client: MemcachedClient, resolver: Discovery) -> ExitCode {
    let stop = Arc::new(AtomicBool::new(false));
    mtop::sig::wait_for_interrupt(Handle::current(), stop.clone()).await;
    let delay = Duration::from_millis(cmd.delay_millis.get());
    let timeout = Duration::from_secs(opts.timeout_secs.get());
    let checker = Checker::new(client, resolver, delay, timeout, stop.clone());
    let results = checker.run(&opts.host, Duration::from_secs(cmd.time_secs.get())).await;
    print_check_results(&results);
    if results.failures.total > 0 {
        ExitCode::FAILURE
    } else {
        ExitCode::SUCCESS
    }
}
/// Decrement the counter stored under `cmd.key` by `cmd.delta`, logging any
/// failure. Returns a non-zero exit code when the operation fails.
async fn run_decr(opts: &McConfig, cmd: &DecrCommand, client: &MemcachedClient) -> ExitCode {
    let result = client
        .decr(&cmd.key, cmd.delta)
        .timeout(Duration::from_secs(opts.timeout_secs.get()), "client.decr")
        .instrument(tracing::span!(Level::INFO, "client.decr"))
        .await;
    match result {
        Ok(_) => ExitCode::SUCCESS,
        Err(e) => {
            tracing::error!(message = "unable to decrement value", key = cmd.key, host = opts.host, err = %e);
            ExitCode::FAILURE
        }
    }
}
/// Delete the item stored under `cmd.key`, logging any failure. Returns a
/// non-zero exit code when the operation fails.
async fn run_delete(opts: &McConfig, cmd: &DeleteCommand, client: &MemcachedClient) -> ExitCode {
    let result = client
        .delete(&cmd.key)
        .timeout(Duration::from_secs(opts.timeout_secs.get()), "client.delete")
        .instrument(tracing::span!(Level::INFO, "client.delete"))
        .await;
    match result {
        Ok(_) => ExitCode::SUCCESS,
        Err(e) => {
            tracing::error!(message = "unable to delete item", key = cmd.key, host = opts.host, err = %e);
            ExitCode::FAILURE
        }
    }
}
/// Schedule a cache flush on every server, optionally delayed by
/// `wait_secs`. Logs the outcome per server and fails if any server
/// reported an error.
async fn run_flush_all(opts: &McConfig, cmd: &FlushAllCommand, client: &MemcachedClient) -> ExitCode {
    let delay = cmd.wait_secs.map(|secs| Duration::from_secs(secs.get()));
    let response = match client
        .flush_all(delay)
        .timeout(Duration::from_secs(opts.timeout_secs.get()), "client.flush_all")
        .instrument(tracing::span!(Level::INFO, "client.flush_all"))
        .await
    {
        Ok(r) => r,
        Err(e) => {
            tracing::error!(message = "unable to flush caches", host = opts.host, err = %e);
            return ExitCode::FAILURE;
        }
    };
    // Sort both result sets by server id so the output ordering is stable.
    let mut flushed: Vec<_> = response.values.into_iter().collect();
    flushed.sort_by(|a, b| a.0.cmp(&b.0));
    let mut failed: Vec<_> = response.errors.into_iter().collect();
    failed.sort_by(|a, b| a.0.cmp(&b.0));
    let any_failed = !failed.is_empty();
    for (id, _) in flushed {
        tracing::info!(message = "scheduled cache flush", host = %id);
    }
    for (id, e) in failed {
        tracing::error!(message = "unable to flush cache for server", host = %id, err = %e);
    }
    if any_failed {
        ExitCode::FAILURE
    } else {
        ExitCode::SUCCESS
    }
}
/// Fetch a single item and write its raw payload to stdout. A missing key
/// prints nothing; the exit code reflects only per-server fetch errors.
async fn run_get(opts: &McConfig, cmd: &GetCommand, client: &MemcachedClient) -> ExitCode {
    let response = match client
        .get(slice::from_ref(&cmd.key))
        .timeout(Duration::from_secs(opts.timeout_secs.get()), "client.get")
        .instrument(tracing::span!(Level::INFO, "client.get"))
        .await
    {
        Ok(r) => r,
        Err(e) => {
            tracing::error!(message = "unable to get item", key = cmd.key, host = opts.host, err = %e);
            return ExitCode::FAILURE;
        }
    };
    // Output problems are only a warning: the fetch itself succeeded.
    if let Some(value) = response.values.get(&cmd.key) {
        if let Err(e) = print_data(value).await {
            tracing::warn!(message = "error writing output", err = %e);
        }
    }
    for (id, e) in response.errors.iter() {
        tracing::error!(message = "error fetching value", server = %id, err = %e);
    }
    if response.errors.is_empty() {
        ExitCode::SUCCESS
    } else {
        ExitCode::FAILURE
    }
}
/// Increment the counter stored under `cmd.key` by `cmd.delta`, logging any
/// failure. Returns a non-zero exit code when the operation fails.
async fn run_incr(opts: &McConfig, cmd: &IncrCommand, client: &MemcachedClient) -> ExitCode {
    let result = client
        .incr(&cmd.key, cmd.delta)
        .timeout(Duration::from_secs(opts.timeout_secs.get()), "client.incr")
        .instrument(tracing::span!(Level::INFO, "client.incr"))
        .await;
    match result {
        Ok(_) => ExitCode::SUCCESS,
        Err(e) => {
            tracing::error!(message = "unable to increment value", key = cmd.key, host = opts.host, err = %e);
            ExitCode::FAILURE
        }
    }
}
/// List every key on every server, optionally with per-key details. Fails if
/// any server returned an error while listing.
async fn run_keys(opts: &McConfig, cmd: &KeysCommand, client: &MemcachedClient) -> ExitCode {
    let response = match client
        .metas()
        .timeout(Duration::from_secs(opts.timeout_secs.get()), "client.metas")
        .instrument(tracing::span!(Level::INFO, "client.metas"))
        .await
    {
        Ok(r) => r,
        Err(e) => {
            tracing::error!(message = "unable to list keys", host = opts.host, err = %e);
            return ExitCode::FAILURE;
        }
    };
    let any_failed = !response.errors.is_empty();
    // Flatten per-server metadata into one sorted list for display.
    let mut metas = response.values.into_values().flatten().collect::<Vec<Meta>>();
    metas.sort();
    if let Err(e) = print_keys(&metas, cmd.details).await {
        tracing::warn!(message = "error writing output", err = %e);
    }
    for (id, e) in response.errors.iter() {
        tracing::error!(message = "error fetching metas", server = %id, err = %e);
    }
    if any_failed {
        ExitCode::FAILURE
    } else {
        ExitCode::SUCCESS
    }
}
/// Issue a memcached `replace` for `cmd.key` with a payload read from stdin,
/// logging failures. Returns a non-zero exit code on read or protocol error.
async fn run_replace(opts: &McConfig, cmd: &ReplaceCommand, client: &MemcachedClient) -> ExitCode {
    let payload = match read_input().await {
        Ok(v) => v,
        Err(e) => {
            tracing::error!(message = "unable to read item data from stdin", err = %e);
            return ExitCode::FAILURE;
        }
    };
    let result = client
        .replace(&cmd.key, 0, cmd.ttl_secs, &payload)
        .timeout(Duration::from_secs(opts.timeout_secs.get()), "client.replace")
        .instrument(tracing::span!(Level::INFO, "client.replace"))
        .await;
    match result {
        Ok(_) => ExitCode::SUCCESS,
        Err(e) => {
            tracing::error!(message = "unable to replace item", key = cmd.key, host = opts.host, err = %e);
            ExitCode::FAILURE
        }
    }
}
/// Issue a memcached `set` for `cmd.key` with a payload read from stdin,
/// logging failures. Returns a non-zero exit code on read or protocol error.
async fn run_set(opts: &McConfig, cmd: &SetCommand, client: &MemcachedClient) -> ExitCode {
    let payload = match read_input().await {
        Ok(v) => v,
        Err(e) => {
            tracing::error!(message = "unable to read item data from stdin", err = %e);
            return ExitCode::FAILURE;
        }
    };
    let result = client
        .set(&cmd.key, 0, cmd.ttl_secs, &payload)
        .timeout(Duration::from_secs(opts.timeout_secs.get()), "client.set")
        .instrument(tracing::span!(Level::INFO, "client.set"))
        .await;
    match result {
        Ok(_) => ExitCode::SUCCESS,
        Err(e) => {
            tracing::error!(message = "unable to set item", key = cmd.key, host = opts.host, err = %e);
            ExitCode::FAILURE
        }
    }
}
/// Update the TTL of the item stored under `cmd.key` without changing its
/// value, logging any failure. Returns a non-zero exit code on error.
async fn run_touch(opts: &McConfig, cmd: &TouchCommand, client: &MemcachedClient) -> ExitCode {
    let result = client
        .touch(&cmd.key, cmd.ttl_secs)
        .timeout(Duration::from_secs(opts.timeout_secs.get()), "client.touch")
        .instrument(tracing::span!(Level::INFO, "client.touch"))
        .await;
    match result {
        Ok(_) => ExitCode::SUCCESS,
        Err(e) => {
            tracing::error!(message = "unable to touch item", key = cmd.key, host = opts.host, err = %e);
            ExitCode::FAILURE
        }
    }
}
/// Read the entire contents of stdin into a byte buffer; used as the item
/// payload for write commands (add, replace, set).
///
/// # Errors
/// Returns any I/O error raised while reading stdin.
async fn read_input() -> io::Result<Vec<u8>> {
    let mut buf = Vec::new();
    // No BufReader: read_to_end already reads in large chunks directly into
    // `buf`, so an intermediate buffer would only add an extra copy.
    tokio::io::stdin().read_to_end(&mut buf).await?;
    Ok(buf)
}
/// Write an item's raw payload bytes to stdout, flushing before returning so
/// nothing is lost when the process exits.
async fn print_data(val: &Value) -> io::Result<()> {
    let mut writer = BufWriter::new(tokio::io::stdout());
    writer.write_all(&val.data).await?;
    writer.flush().await
}
/// Write one line per key to stdout: either the key alone, or (with
/// `show_details`) key, expiry, and size separated by tabs.
async fn print_keys(metas: &[Meta], show_details: bool) -> io::Result<()> {
    let mut writer = BufWriter::new(tokio::io::stdout());
    for meta in metas {
        let line = if show_details {
            format!("{}\t{}\t{}\n", meta.key, meta.expires, meta.size)
        } else {
            format!("{}\n", meta.key)
        };
        writer.write_all(line.as_bytes()).await?;
    }
    writer.flush().await
}
/// Print one line of benchmark throughput numbers per worker to stdout.
fn print_bench_results(results: &[Summary]) {
    for summary in results.iter() {
        println!(
            "worker={} gets={} gets_time={:.3} gets_per_second={:.0} sets={} sets_time={:.3} sets_per_second={:.0}",
            summary.worker,
            summary.gets,
            summary.gets_time.as_secs_f64(),
            summary.gets_per_sec(),
            summary.sets,
            summary.sets_time.as_secs_f64(),
            summary.sets_per_sec(),
        );
    }
}
/// Print a summary line to stdout for each phase of the health check:
/// overall, DNS, connection, set, and get timings plus failure counts.
fn print_check_results(results: &Bundle) {
    // The five sections were previously five copies of the same println!.
    // A local macro keeps the format string in one place; a helper function
    // is not possible here because the timing struct's type is not nameable
    // from this file.
    macro_rules! print_timing {
        ($label:expr, $timing:expr, $failures:expr) => {
            println!(
                "type={} min={:.6}s max={:.6}s avg={:.6}s stddev={:.6}s failures={}",
                $label,
                $timing.min.as_secs_f64(),
                $timing.max.as_secs_f64(),
                $timing.avg.as_secs_f64(),
                $timing.std_dev.as_secs_f64(),
                $failures,
            );
        };
    }
    print_timing!("overall", results.overall, results.failures.total);
    print_timing!("dns", results.dns, results.failures.dns);
    print_timing!("connection", results.connections, results.failures.connections);
    print_timing!("set", results.sets, results.failures.sets);
    print_timing!("get", results.gets, results.failures.gets);
}