use std::collections::HashMap;
use std::error::Error;
use std::num::NonZeroUsize;
use std::path::PathBuf;
use std::str::FromStr;
use anyhow::anyhow;
use chrono::Utc;
use clap::{AppSettings, ArgEnum, Parser};
use serde::{Deserialize, Serialize};
/// Parses one `KEY=VALUE` argument into a typed `(key, value)` pair.
///
/// The input is split at the *first* `=`. Everything before it is parsed as `T`.
/// Everything after it, which may itself contain further `=` characters, is parsed as `U`.
///
/// # Errors
/// Returns an error when the input has no `=`, or when either side fails to parse.
fn parse_key_val<T, U>(s: &str) -> Result<(T, U), anyhow::Error>
where
    T: std::str::FromStr,
    T::Err: Error + Send + Sync + 'static,
    U: std::str::FromStr,
    U::Err: Error + Send + Sync + 'static,
{
    let (raw_key, raw_value) = s
        .split_once('=')
        .ok_or_else(|| anyhow!("invalid KEY=value: no `=` found in `{}`", s))?;
    // Parse the key first, then the value, so the reported error is the same as before.
    let key: T = raw_key.parse()?;
    let value: U = raw_value.parse()?;
    Ok((key, value))
}
/// Length of a benchmark phase or sampling period.
///
/// Used for the `run` subcommand's warmup, run duration and sampling interval.
/// From the command line it is parsed by `FromStr`: an integer becomes `Count`,
/// otherwise the text is parsed as a time duration.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum Interval {
    /// A fixed number of cycles.
    Count(u64),
    /// A wall-clock duration.
    Time(tokio::time::Duration),
    /// No limit. `FromStr` never produces this variant.
    Unbounded,
}
impl Interval {
    /// Returns `true` for a non-zero cycle count or a non-zero duration.
    ///
    /// `Unbounded` yields `false`. Callers that must tell "unbounded" apart from
    /// "empty" should also consult [`Interval::is_bounded`].
    pub fn is_not_zero(&self) -> bool {
        match *self {
            Interval::Count(cycles) => cycles != 0,
            Interval::Time(span) => !span.is_zero(),
            Interval::Unbounded => false,
        }
    }

    /// Returns `true` for every variant except `Unbounded`.
    pub fn is_bounded(&self) -> bool {
        match self {
            Interval::Unbounded => false,
            Interval::Count(_) | Interval::Time(_) => true,
        }
    }

    /// Returns the cycle count for `Count`, or `None` for the other variants.
    pub fn count(&self) -> Option<u64> {
        match *self {
            Interval::Count(cycles) => Some(cycles),
            Interval::Time(_) | Interval::Unbounded => None,
        }
    }

    /// Returns the duration in (fractional) seconds for `Time`, or `None` for the other variants.
    pub fn seconds(&self) -> Option<f32> {
        match *self {
            Interval::Time(span) => Some(span.as_secs_f32()),
            Interval::Count(_) | Interval::Unbounded => None,
        }
    }
}
impl FromStr for Interval {
    type Err = String;

    /// Parses an integer as a cycle count; anything else is tried as a time duration.
    ///
    /// The integer check runs first, so a bare number like `"10"` means 10 cycles.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if let Ok(cycles) = s.parse::<u64>() {
            return Ok(Interval::Count(cycles));
        }
        parse_duration::parse(s)
            .map(Interval::Time)
            .map_err(|_| "Required integer number of cycles or time duration".to_string())
    }
}
#[derive(Parser, Debug, Serialize, Deserialize)]
pub struct ConnectionConf {
#[clap(
short('c'),
long("connections"),
default_value = "1",
value_name = "COUNT"
)]
pub count: NonZeroUsize,
#[clap(name = "addresses", default_value = "localhost")]
pub addresses: Vec<String>,
#[clap(long, env("CASSANDRA_USER"), default_value = "")]
pub user: String,
#[clap(long, env("CASSANDRA_PASSWORD"), default_value = "")]
pub password: String,
#[clap(long("ssl"))]
pub ssl: bool,
#[clap(long("ssl-ca"), value_name = "PATH")]
pub ssl_ca_cert_file: Option<PathBuf>,
#[clap(long("ssl-cert"), value_name = "PATH")]
pub ssl_cert_file: Option<PathBuf>,
#[clap(long("ssl-key"), value_name = "PATH")]
pub ssl_key_file: Option<PathBuf>,
#[clap(long("consistency"), required = false, default_value = "LOCAL_QUORUM")]
pub consistency: Consistency,
}
/// Consistency level requested for queries, selectable on the command line.
///
/// Each variant maps one-to-one onto the driver's level of the same name.
/// `LocalQuorum` is the `Default`.
#[derive(ArgEnum, Clone, Copy, Default, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum Consistency {
    // Variant comments use `//` because clap's `ArgEnum` derive may turn `///` docs
    // into possible-value help text.
    Any,
    One,
    Two,
    Three,
    Quorum,
    All,
    LocalOne,
    // Default level; it matches the `--consistency` CLI default "LOCAL_QUORUM".
    #[default]
    LocalQuorum,
    EachQuorum,
}
impl Consistency {
pub fn scylla_consistency(&self) -> scylla::frame::types::Consistency {
match self {
Self::Any => scylla::frame::types::Consistency::Any,
Self::One => scylla::frame::types::Consistency::One,
Self::Two => scylla::frame::types::Consistency::Two,
Self::Three => scylla::frame::types::Consistency::Three,
Self::Quorum => scylla::frame::types::Consistency::Quorum,
Self::All => scylla::frame::types::Consistency::All,
Self::LocalOne => scylla::frame::types::Consistency::LocalOne,
Self::LocalQuorum => scylla::frame::types::Consistency::LocalQuorum,
Self::EachQuorum => scylla::frame::types::Consistency::EachQuorum,
}
}
}
impl FromStr for Consistency {
    type Err = String;

    /// Parses a consistency level case-insensitively.
    ///
    /// Accepts the full names (with or without underscores) and short aliases:
    /// `1`/`2`/`3`, `q`, `l1`, `lq`, `eq`. On failure the error message contains
    /// the lower-cased input.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.to_lowercase();
        let level = match normalized.as_str() {
            "any" => Self::Any,
            "one" | "1" => Self::One,
            "two" | "2" => Self::Two,
            "three" | "3" => Self::Three,
            "quorum" | "q" => Self::Quorum,
            "all" => Self::All,
            "local_one" | "localone" | "l1" => Self::LocalOne,
            "local_quorum" | "localquorum" | "lq" => Self::LocalQuorum,
            "each_quorum" | "eachquorum" | "eq" => Self::EachQuorum,
            unknown => return Err(format!("Unknown consistency level {unknown}")),
        };
        Ok(level)
    }
}
// Arguments of the schema subcommand.
// Plain `//` comments are used deliberately: clap turns `///` doc comments on derived
// items into CLI help text.
#[derive(Parser, Debug, Serialize, Deserialize)]
#[clap(
    setting(AppSettings::NextLineHelp),
    setting(AppSettings::DeriveDisplayOrder)
)]
pub struct SchemaCommand {
    // Workload parameters as repeated `-P KEY=VALUE`; each occurrence takes exactly one
    // value and is split by `parse_key_val`.
    #[clap(short('P'), parse(try_from_str = parse_key_val),
    number_of_values = 1, multiple_occurrences = true)]
    pub params: Vec<(String, String)>,
    // Required positional path to the workload file.
    #[clap(name = "workload", required = true, value_name = "PATH")]
    pub workload: PathBuf,
    // Shared connection options, flattened into this command's arguments.
    #[clap(flatten)]
    pub connection: ConnectionConf,
}
// Arguments of the load subcommand.
// Plain `//` comments are used deliberately: clap turns `///` doc comments on derived
// items into CLI help text.
#[derive(Parser, Debug, Serialize, Deserialize)]
#[clap(
    setting(AppSettings::NextLineHelp),
    setting(AppSettings::DeriveDisplayOrder)
)]
pub struct LoadCommand {
    // Optional rate (`-r`, `--rate`); presumably operations per second — TODO confirm.
    #[clap(short('r'), long, value_name = "COUNT")]
    pub rate: Option<f64>,
    // `-t` / `--threads`, default 1.
    #[clap(short('t'), long, default_value = "1", value_name = "COUNT")]
    pub threads: NonZeroUsize,
    // `--concurrency`, default 512. Note: unlike `RunCommand` there is no short flag here,
    // and the default differs (128 there).
    #[clap(long, default_value = "512", value_name = "COUNT")]
    pub concurrency: NonZeroUsize,
    // Workload parameters as repeated `-P KEY=VALUE`.
    #[clap(short('P'), parse(try_from_str = parse_key_val),
    number_of_values = 1, multiple_occurrences = true)]
    pub params: Vec<(String, String)>,
    // `-q` / `--quiet` flag.
    #[clap(short, long)]
    pub quiet: bool,
    // Required positional path to the workload file.
    #[clap(name = "workload", required = true, value_name = "PATH")]
    pub workload: PathBuf,
    // Shared connection options, flattened into this command's arguments.
    #[clap(flatten)]
    pub connection: ConnectionConf,
}
// Arguments of the run subcommand.
// Plain `//` comments are used deliberately: clap turns `///` doc comments on derived
// items into CLI help text.
#[derive(Parser, Debug, Serialize, Deserialize)]
#[clap(
    setting(AppSettings::NextLineHelp),
    setting(AppSettings::DeriveDisplayOrder)
)]
pub struct RunCommand {
    // Optional rate (`-r`, `--rate`); presumably operations per second — TODO confirm.
    #[clap(short('r'), long, value_name = "COUNT")]
    pub rate: Option<f64>,
    // Warmup length (`-w`, `--warmup`). The default "1" parses as `Interval::Count(1)`,
    // because `Interval::from_str` tries an integer before a duration.
    #[clap(
        short('w'),
        long("warmup"),
        default_value = "1",
        value_name = "TIME | COUNT"
    )]
    pub warmup_duration: Interval,
    // Measured run length (`-d`, `--duration`). The default "60s" is not an integer,
    // so it is parsed as a time duration.
    #[clap(
        short('d'),
        long("duration"),
        default_value = "60s",
        value_name = "TIME | COUNT"
    )]
    pub run_duration: Interval,
    // `-t` / `--threads`, default 1.
    #[clap(short('t'), long, default_value = "1", value_name = "COUNT")]
    pub threads: NonZeroUsize,
    // `-p` / `--concurrency`, default 128 (`LoadCommand` uses 512 and no short flag).
    #[clap(short('p'), long, default_value = "128", value_name = "COUNT")]
    pub concurrency: NonZeroUsize,
    // Sampling period (`-s`, `--sampling`), default "1s".
    #[clap(
        short('s'),
        long("sampling"),
        default_value = "1s",
        value_name = "TIME | COUNT"
    )]
    pub sampling_interval: Interval,
    // Free-form labels; `--tag` may be repeated, one value per occurrence.
    #[clap(long("tag"), number_of_values = 1, multiple_occurrences = true)]
    pub tags: Vec<String>,
    // Optional output path (`-o`, `--output`); excluded from serialization.
    #[clap(short('o'), long)]
    #[serde(skip)]
    pub output: Option<PathBuf>,
    // Optional baseline path (`-b`, `--baseline`).
    #[clap(short('b'), long, value_name = "PATH")]
    pub baseline: Option<PathBuf>,
    // Required positional path to the workload file; `RunCommand::name` derives from it.
    #[clap(name = "workload", required = true, value_name = "PATH")]
    pub workload: PathBuf,
    // Workload parameters as repeated `-P KEY=VALUE`; read back by `RunCommand::get_param`.
    #[clap(short('P'), parse(try_from_str = parse_key_val),
    number_of_values = 1, multiple_occurrences = true)]
    pub params: Vec<(String, String)>,
    // `-q` / `--quiet` flag.
    #[clap(short, long)]
    pub quiet: bool,
    // Shared connection options, flattened into this command's arguments.
    #[clap(flatten)]
    pub connection: ConnectionConf,
    // Hidden `--timestamp` option; when absent, `set_timestamp_if_empty` fills it with
    // the current UTC time in seconds.
    #[clap(hidden(true), long)]
    pub timestamp: Option<i64>,
    // Not a CLI argument (`clap(skip)`); presumably filled in after connecting — confirm.
    #[clap(skip)]
    pub cluster_name: Option<String>,
    // Not a CLI argument (`clap(skip)`); presumably filled in after connecting — confirm.
    #[clap(skip)]
    pub cass_version: Option<String>,
}
impl RunCommand {
    /// Sets `timestamp` to the current UTC time (whole seconds since the Unix epoch)
    /// unless one was already supplied, and returns the updated command.
    pub fn set_timestamp_if_empty(mut self) -> Self {
        if self.timestamp.is_none() {
            self.timestamp = Some(Utc::now().timestamp());
        }
        self
    }

    /// Looks up workload parameter `key` (given as `-P key=value`) and parses its value
    /// as an `i64`.
    ///
    /// Only the first occurrence of `key` is considered. Returns `None` when the key is
    /// absent or when that occurrence's value is not a valid integer.
    pub fn get_param(&self, key: &str) -> Option<i64> {
        self.params
            .iter()
            .find(|(k, _)| k == key)
            .and_then(|(_, v)| v.parse().ok())
    }

    /// Returns a short name for the workload: the file name of `workload` without its
    /// extension, converted lossily to UTF-8.
    ///
    /// `workload` is user-supplied, so it may have no file-name component (e.g. `..`
    /// or `/`). In that case the whole path is used instead of panicking.
    pub fn name(&self) -> String {
        self.workload
            .file_stem()
            .unwrap_or_else(|| self.workload.as_os_str())
            .to_string_lossy()
            .into_owned()
    }
}
// Arguments of the show subcommand: a positional report path plus an optional
// `-b` / `--baseline` path.
// Plain `//` comments are used deliberately: clap turns `///` doc comments on derived
// items into CLI help text.
#[derive(Parser, Debug)]
pub struct ShowCommand {
    #[clap(value_name = "PATH")]
    pub report: PathBuf,
    #[clap(short('b'), long, value_name = "PATH")]
    pub baseline: Option<PathBuf>,
}
// Arguments of the hdr subcommand: a positional report path, an optional
// `-o` / `--output` path and an optional `--tag` string.
// Plain `//` comments are used deliberately: clap turns `///` doc comments on derived
// items into CLI help text.
#[derive(Parser, Debug)]
pub struct HdrCommand {
    #[clap(value_name = "PATH")]
    pub report: PathBuf,
    #[clap(short('o'), long, value_name = "PATH")]
    pub output: Option<PathBuf>,
    #[clap(long, value_name = "STRING")]
    pub tag: Option<String>,
}
// Top-level subcommands; each variant carries that subcommand's parsed arguments.
// Variant comments use `//` because clap turns `///` docs into subcommand help text.
#[derive(Parser, Debug)]
// NOTE(review): presumably allowed because `RunCommand` is far larger than the other
// payloads and only one `Command` is built per invocation, so boxing buys nothing — confirm.
#[allow(clippy::large_enum_variant)]
pub enum Command {
    Schema(SchemaCommand),
    Load(LoadCommand),
    Run(RunCommand),
    Show(ShowCommand),
    Hdr(HdrCommand),
}
// Root command-line parser. It carries the application name, author, and a version
// string taken from `clap::crate_version!`, and dispatches to one `Command`.
// Plain `//` comments are used deliberately: a `///` comment here would become the
// CLI's `about` text.
#[derive(Parser, Debug)]
#[clap(
    name = "Cassandra Latency and Throughput Tester",
    author = "Piotr Kołaczkowski <pkolaczk@datastax.com>",
    version = clap::crate_version ! (),
)]
pub struct AppConfig {
    #[clap(subcommand)]
    pub command: Command,
}
/// Schema section of a workload configuration.
///
/// Both fields fall back to empty values when omitted. The whole section also has a
/// `Default` (used by `WorkloadConfig::schema`).
#[derive(Debug, Deserialize, Default)]
pub struct SchemaConfig {
    /// Script entries; presumably script source lines — TODO confirm with the consumer.
    #[serde(default)]
    pub script: Vec<String>,
    /// CQL text; empty when omitted.
    #[serde(default)]
    pub cql: String,
}
/// One named entry of the `load` section of a workload configuration.
#[derive(Debug, Deserialize)]
pub struct LoadConfig {
    /// Required: it has no serde default, so deserialization fails if it is missing.
    pub count: u64,
    /// Script entries; empty when omitted.
    #[serde(default)]
    pub script: Vec<String>,
    /// CQL text; empty when omitted.
    #[serde(default)]
    pub cql: String,
}
/// Default-value functions referenced from `#[serde(default = "...")]` attributes.
mod defaults {
    /// Value used for a run entry's `ratio` when the field is omitted.
    const RATIO: f64 = 1.0;

    /// Serde default for `RunConfig::ratio`.
    pub fn ratio() -> f64 {
        RATIO
    }
}
/// One named entry of the `run` section of a workload configuration.
#[derive(Debug, Deserialize)]
pub struct RunConfig {
    /// Defaults to 1.0 (see `defaults::ratio`) when omitted. Presumably a relative
    /// weight among run entries — TODO confirm.
    #[serde(default = "defaults::ratio")]
    pub ratio: f64,
    /// Script entries; empty when omitted.
    #[serde(default)]
    pub script: Vec<String>,
    /// CQL text; empty when omitted.
    #[serde(default)]
    pub cql: String,
}
/// Top-level workload configuration, deserialized with serde.
#[derive(Debug, Deserialize)]
pub struct WorkloadConfig {
    /// Schema section; `SchemaConfig::default()` when omitted.
    #[serde(default)]
    pub schema: SchemaConfig,
    /// Named load entries; empty map when omitted.
    #[serde(default)]
    pub load: HashMap<String, LoadConfig>,
    /// Named run entries. Required: it has no serde default.
    pub run: HashMap<String, RunConfig>,
    /// Extra name-to-string bindings; empty map when omitted.
    #[serde(default)]
    pub bindings: HashMap<String, String>,
}