use std::{collections::BTreeSet, ffi::OsString, io::Cursor, sync::LazyLock};
use clap::{ArgAction, ArgMatches, CommandFactory, Parser, parser::ValueSource};
use config::{ConfigError, Map, Source, Value, ValueKind};
use getset::{CopyGetters, Getters};
use libmoshpit::PathDefaults;
use vergen_pretty::{Pretty, vergen_pretty_env};
/// Text printed for `--version`: the crate version, a blank line, and then the
/// build-environment report rendered by `vergen_pretty`.
///
/// Built lazily on first access and cached for the rest of the process.
static LONG_VERSION: LazyLock<String> = LazyLock::new(|| {
    // Render the report into an in-memory buffer first.
    let mut sink = Cursor::new(Vec::<u8>::new());
    Pretty::builder()
        .env(vergen_pretty_env!())
        .build()
        .display(&mut sink)
        .expect("writing to Vec never fails");

    // Version header, blank separator line, then the (lossily decoded) report.
    let mut text = String::from(concat!(env!("CARGO_PKG_VERSION"), "\n\n"));
    text.push_str(&String::from_utf8_lossy(sink.get_ref()));
    text
});
// Command-line arguments, parsed through clap's derive `Parser`.
//
// Getters are generated by `getset`: `get_copy` fields are returned by value,
// `get` fields by reference, all with `pub(crate)` visibility.
//
// NOTE: the comments in this struct are deliberately plain `//` comments.
// clap's derive turns `///` doc comments into help/about text, so doc comments
// here could interfere with the explicit `help = ...` / `about` settings.
#[derive(Clone, CopyGetters, Debug, Getters, Parser)]
#[command(author, version, about, long_version = LONG_VERSION.as_str(), long_about = None)]
pub(crate) struct Cli {
// Number of times `-v` / `--verbose` was given; cannot be combined with `quiet`.
#[clap(
short,
long,
action = ArgAction::Count,
help = "Turn up logging verbosity (multiple will turn it up more)",
conflicts_with = "quiet",
)]
#[getset(get_copy = "pub(crate)")]
verbose: u8,
// Number of times `-q` / `--quiet` was given; cannot be combined with `verbose`.
#[clap(
short,
long,
action = ArgAction::Count,
help = "Turn down logging verbosity (multiple will turn it down more)",
conflicts_with = "verbose",
)]
#[getset(get_copy = "pub(crate)")]
quiet: u8,
// Boolean flag. There is no `getset` attribute on this field, so it is read
// directly (the tests in this file access `cli.enable_std_output`).
#[clap(short, long, help = "Enable logging to stdout/stderr")]
enable_std_output: bool,
// Optional config file location; published under the `config_path` key by
// the `Source` implementation.
#[clap(short, long, help = "Specify the absolute path to the config file")]
#[getset(get = "pub(crate)")]
config_absolute_path: Option<String>,
// Optional tracing output location; published under the `tracing_path` key
// by the `Source` implementation.
#[clap(
short,
long,
help = "Specify the absolute path to the tracing output file"
)]
#[getset(get = "pub(crate)")]
tracing_absolute_path: Option<String>,
// Optional private key location (`-p`).
#[clap(
short,
long,
help = "Specify the absolute path to the private key file"
)]
#[getset(get = "pub(crate)")]
private_key_path: Option<String>,
// Optional public key location. The short flag is set explicitly to `-k`;
// presumably because the derived `-p` is already taken by `private_key_path`.
#[clap(
short = 'k',
long,
help = "Specify the absolute path to the public key file"
)]
#[getset(get = "pub(crate)")]
public_key_path: Option<String>,
// Optional delay in milliseconds; long flag only.
#[clap(
long,
value_name = "MILLIS",
help = "Extra delay (ms) after peer discovery before sending terminal data"
)]
#[getset(get_copy = "pub(crate)")]
warmup_delay_ms: Option<u64>,
// Optional delay in microseconds; long flag only.
// NOTE(review): the help text advertises a default of 1000, but no
// `default_value` is set here — presumably the default is applied wherever
// the merged configuration is consumed; confirm.
#[clap(
long,
value_name = "MICROS",
help = "Min inter-packet delay (µs) between diff chunks [default: 1000]"
)]
#[getset(get_copy = "pub(crate)")]
pacing_delay_us: Option<u64>,
// Optional minimum protocol version; long flag only.
#[clap(
long,
value_name = "VERSION",
help = "Minimum wire protocol version to accept (older clients are rejected)"
)]
#[getset(get_copy = "pub(crate)")]
min_protocol_version: Option<u16>,
// Always populated: clap falls back to "xterm-256color" when the flag is
// absent. The `Source` implementation only publishes it when it was given
// explicitly on the command line.
#[clap(
long,
value_name = "TERM",
default_value = "xterm-256color",
help = "TERM environment variable for spawned shells"
)]
#[getset(get = "pub(crate)")]
term_type: String,
// The four `*_algos` options below each hold a raw comma-separated list.
// `build_algo_table` splits them into arrays under `preferred_algorithms`.
#[clap(
long,
value_name = "ALGOS",
help = "Ordered KEX algorithms to prefer, comma-separated [supported: x25519-sha256 (default), ml-kem-768-sha256, ml-kem-512-sha256, ml-kem-1024-sha256, p384-sha384, p256-sha256]"
)]
#[getset(get = "pub(crate)")]
kex_algos: Option<String>,
#[clap(
long,
value_name = "ALGOS",
help = "Ordered AEAD algorithms to prefer, comma-separated [supported: aes256-gcm-siv (default), aes256-gcm, chacha20-poly1305, aes128-gcm-siv]"
)]
#[getset(get = "pub(crate)")]
aead_algos: Option<String>,
#[clap(
long,
value_name = "ALGOS",
help = "Ordered MAC algorithms to prefer, comma-separated [supported: hmac-sha512 (default), hmac-sha256]"
)]
#[getset(get = "pub(crate)")]
mac_algos: Option<String>,
#[clap(
long,
value_name = "ALGOS",
help = "Ordered KDF algorithms to prefer, comma-separated [supported: hkdf-sha256 (default), hkdf-sha384, hkdf-sha512]"
)]
#[getset(get = "pub(crate)")]
kdf_algos: Option<String>,
// Not a command-line option (`skip`): `parse_argv` fills this with the ids
// of the arguments whose value came from the command line, and `collect`
// uses it to decide which settings to publish.
#[clap(skip)]
#[getset(get = "pub(crate)")]
explicit_args: BTreeSet<String>,
}
impl Cli {
    /// Parses `argv` into a [`Cli`] and records which arguments were supplied
    /// explicitly on the command line.
    ///
    /// The argument vector is parsed exactly once. The resulting `ArgMatches`
    /// is used both to build the struct and to compute `explicit_args`, so the
    /// two can never disagree and the input iterator is not cloned.
    ///
    /// # Errors
    ///
    /// Returns the clap error produced while matching or converting the
    /// arguments. This includes the "errors" clap uses to signal `--help` and
    /// `--version`.
    pub(crate) fn parse_argv<I, T>(argv: I) -> clap::error::Result<Self>
    where
        I: IntoIterator<Item = T> + Clone,
        T: Into<OsString> + Clone,
    {
        let matches = <Cli as CommandFactory>::command().try_get_matches_from(argv)?;
        // `from_arg_matches` leaves `matches` untouched, so the value sources
        // can still be inspected afterwards. Errors are formatted against the
        // command, mirroring what the derived `try_parse_from` does.
        let mut cli = <Cli as clap::FromArgMatches>::from_arg_matches(&matches)
            .map_err(|err| err.format(&mut <Cli as CommandFactory>::command()))?;
        cli.explicit_args = explicit_command_line_ids(&matches);
        Ok(cli)
    }
}
/// Returns the ids in `matches` whose value source is the command line.
///
/// Ids whose value came from anywhere else (for example a declared default)
/// are left out.
fn explicit_command_line_ids(matches: &ArgMatches) -> BTreeSet<String> {
    matches
        .ids()
        .filter_map(|id| {
            let name = id.as_str();
            let from_command_line =
                matches.value_source(name) == Some(ValueSource::CommandLine);
            from_command_line.then(|| name.to_string())
        })
        .collect()
}
/// Builds the `preferred_algorithms` table from the comma-separated algorithm
/// lists that were supplied.
///
/// Each list that is `Some` becomes an array of trimmed strings under its
/// category key (`kex`, `aead`, `mac`, `kdf`). Categories that are `None` are
/// omitted. Returns `None` when no category was supplied at all.
fn build_algo_table(
    kex: Option<&str>,
    aead: Option<&str>,
    mac: Option<&str>,
    kdf: Option<&str>,
) -> Option<Map<String, Value>> {
    let table: Map<String, Value> = [("kex", kex), ("aead", aead), ("mac", mac), ("kdf", kdf)]
        .into_iter()
        .filter_map(|(category, list)| {
            // Skip categories that were not supplied.
            let names: Vec<Value> = list?
                .split(',')
                .map(|name| Value::new(None, ValueKind::String(name.trim().to_string())))
                .collect();
            Some((
                category.to_string(),
                Value::new(None, ValueKind::Array(names)),
            ))
        })
        .collect();

    if table.is_empty() { None } else { Some(table) }
}
impl Source for Cli {
fn clone_into_box(&self) -> Box<dyn Source + Send + Sync> {
Box::new((*self).clone())
}
fn collect(&self) -> Result<Map<String, Value>, ConfigError> {
let mut map = Map::new();
let origin = String::from("command line");
let on = |id: &str| self.explicit_args.contains(id);
if on("verbose") {
let _old = map.insert(
"verbose".to_string(),
Value::new(Some(&origin), ValueKind::U64(u8::into(self.verbose))),
);
}
if on("quiet") {
let _old = map.insert(
"quiet".to_string(),
Value::new(Some(&origin), ValueKind::U64(u8::into(self.quiet))),
);
}
if on("enable_std_output") {
let _old = map.insert(
"enable_std_output".to_string(),
Value::new(Some(&origin), ValueKind::Boolean(self.enable_std_output)),
);
}
if on("config_absolute_path")
&& let Some(config_path) = &self.config_absolute_path
{
let _old = map.insert(
"config_path".to_string(),
Value::new(Some(&origin), ValueKind::String(config_path.clone())),
);
}
if on("tracing_absolute_path")
&& let Some(tracing_path) = &self.tracing_absolute_path
{
let _old = map.insert(
"tracing_path".to_string(),
Value::new(Some(&origin), ValueKind::String(tracing_path.clone())),
);
}
if on("private_key_path")
&& let Some(private_key_path) = &self.private_key_path
{
let _old = map.insert(
"private_key_path".to_string(),
Value::new(Some(&origin), ValueKind::String(private_key_path.clone())),
);
}
if on("public_key_path")
&& let Some(public_key_path) = &self.public_key_path
{
let _old = map.insert(
"public_key_path".to_string(),
Value::new(Some(&origin), ValueKind::String(public_key_path.clone())),
);
}
if on("warmup_delay_ms")
&& let Some(warmup_delay_ms) = self.warmup_delay_ms
{
let _old = map.insert(
"warmup_delay_ms".to_string(),
Value::new(Some(&origin), ValueKind::U64(warmup_delay_ms)),
);
}
if on("pacing_delay_us")
&& let Some(pacing_delay_us) = self.pacing_delay_us
{
let _old = map.insert(
"pacing_delay_us".to_string(),
Value::new(Some(&origin), ValueKind::U64(pacing_delay_us)),
);
}
if on("min_protocol_version")
&& let Some(min_protocol_version) = self.min_protocol_version
{
let _old = map.insert(
"min_protocol_version".to_string(),
Value::new(
Some(&origin),
ValueKind::U64(u64::from(min_protocol_version)),
),
);
}
if on("term_type") {
let _old = map.insert(
"term_type".to_string(),
Value::new(Some(&origin), ValueKind::String(self.term_type.clone())),
);
}
if let Some(table) = build_algo_table(
self.kex_algos.as_deref().filter(|_| on("kex_algos")),
self.aead_algos.as_deref().filter(|_| on("aead_algos")),
self.mac_algos.as_deref().filter(|_| on("mac_algos")),
self.kdf_algos.as_deref().filter(|_| on("kdf_algos")),
) {
let _old = map.insert(
"preferred_algorithms".to_string(),
Value::new(Some(&origin), ValueKind::Table(table)),
);
}
Ok(map)
}
}
impl PathDefaults for Cli {
    /// The package name in ASCII upper case.
    fn env_prefix(&self) -> String {
        env!("CARGO_PKG_NAME").to_ascii_uppercase()
    }

    /// The config path given on the command line, if any.
    fn config_absolute_path(&self) -> Option<String> {
        self.config_absolute_path.clone()
    }

    /// The package name.
    fn default_file_path(&self) -> String {
        String::from(env!("CARGO_PKG_NAME"))
    }

    /// The package name.
    fn default_file_name(&self) -> String {
        String::from(env!("CARGO_PKG_NAME"))
    }

    /// The tracing output path given on the command line, if any.
    fn tracing_absolute_path(&self) -> Option<String> {
        self.tracing_absolute_path.clone()
    }

    /// `<package name>/logs`.
    fn default_tracing_path(&self) -> String {
        String::from(concat!(env!("CARGO_PKG_NAME"), "/logs"))
    }

    /// The package name.
    fn default_tracing_file_name(&self) -> String {
        String::from(env!("CARGO_PKG_NAME"))
    }
}
#[cfg(test)]
mod test {
    use config::Source as _;

    use super::Cli;

    /// Parses `args` (program name first), panicking if parsing fails.
    fn parse(args: &[&str]) -> Cli {
        Cli::parse_argv(args).expect("args parse")
    }

    /// Parses `args` and returns the map the CLI contributes as a config source.
    fn collected(args: &[&str]) -> config::Map<String, config::Value> {
        parse(args).collect().expect("collect should succeed")
    }

    // With no arguments every option keeps its default.
    #[test]
    fn cli_defaults() {
        let cli = parse(&["mps"]);
        assert_eq!(cli.verbose(), 0);
        assert_eq!(cli.quiet(), 0);
        assert!(!cli.enable_std_output);
        assert!(cli.config_absolute_path().is_none());
        assert!(cli.tracing_absolute_path().is_none());
        assert!(cli.private_key_path().is_none());
        assert!(cli.public_key_path().is_none());
        assert_eq!(cli.term_type(), "xterm-256color");
    }

    #[test]
    fn cli_verbose() {
        assert_eq!(parse(&["mps", "-vv"]).verbose(), 2);
    }

    #[test]
    fn cli_quiet() {
        assert_eq!(parse(&["mps", "-qq"]).quiet(), 2);
    }

    #[test]
    fn cli_private_key_path() {
        let cli = parse(&["mps", "-p", "/tmp/key"]);
        assert_eq!(cli.private_key_path().as_deref(), Some("/tmp/key"));
    }

    #[test]
    fn cli_public_key_path() {
        let cli = parse(&["mps", "-k", "/tmp/key.pub"]);
        assert_eq!(cli.public_key_path().as_deref(), Some("/tmp/key.pub"));
    }

    #[test]
    fn cli_term_type_default() {
        let cli = parse(&["mps"]);
        assert_eq!(cli.term_type(), "xterm-256color");
    }

    #[test]
    fn cli_term_type_custom() {
        let cli = parse(&["mps", "--term-type", "screen-256color"]);
        assert_eq!(cli.term_type(), "screen-256color");
    }

    #[test]
    fn cli_term_type_various_values() {
        for term in ["xterm", "screen", "tmux-256color", "linux", "vt100"] {
            let cli = parse(&["mps", "--term-type", term]);
            assert_eq!(cli.term_type(), term);
        }
    }

    // Nothing given explicitly means nothing is published, defaults included.
    #[test]
    fn cli_source_collect() {
        let map = collected(&["mps"]);
        for key in [
            "verbose",
            "quiet",
            "enable_std_output",
            "term_type",
            "private_key_path",
            "public_key_path",
            "config_path",
            "tracing_path",
        ] {
            assert!(!map.contains_key(key), "{key} should be absent");
        }
    }

    #[test]
    fn cli_source_collect_emits_explicit_flags() {
        let map = collected(&["mps", "-vv", "--enable-std-output", "--term-type", "screen"]);
        for key in ["verbose", "enable_std_output", "term_type"] {
            assert!(map.contains_key(key), "{key} should be present");
        }
        assert!(!map.contains_key("quiet"));
    }

    #[test]
    fn cli_source_collect_with_paths() {
        let map = collected(&[
            "mps",
            "-p",
            "/tmp/priv",
            "-k",
            "/tmp/pub",
            "-c",
            "/tmp/config.toml",
            "-t",
            "/tmp/trace.log",
            "--term-type",
            "tmux-256color",
        ]);
        for key in [
            "private_key_path",
            "public_key_path",
            "config_path",
            "tracing_path",
            "term_type",
        ] {
            assert!(map.contains_key(key), "{key} should be present");
        }
        let term_type = map.get("term_type").expect("term_type should be in map");
        assert_eq!(
            term_type.clone().into_string().ok(),
            Some("tmux-256color".to_string())
        );
    }

    #[test]
    fn cli_min_protocol_version_absent_by_default() {
        let cli = parse(&["mps"]);
        assert!(cli.min_protocol_version().is_none());
        let map = cli.collect().expect("collect should succeed");
        assert!(!map.contains_key("min_protocol_version"));
    }

    #[test]
    fn cli_min_protocol_version_collected() {
        let cli = parse(&["mps", "--min-protocol-version", "3"]);
        assert_eq!(cli.min_protocol_version(), Some(3));
        let map = cli.collect().expect("collect should succeed");
        let value = map
            .get("min_protocol_version")
            .expect("min_protocol_version should be in map");
        assert_eq!(value.clone().into_uint().ok(), Some(3));
    }

    #[test]
    fn cli_source_collect_optional_delays() {
        let map = collected(&["mps", "--warmup-delay-ms", "50", "--pacing-delay-us", "250"]);
        let warmup = map
            .get("warmup_delay_ms")
            .expect("warmup_delay_ms should be in map");
        assert_eq!(warmup.clone().into_uint().ok(), Some(50));
        let pacing = map
            .get("pacing_delay_us")
            .expect("pacing_delay_us should be in map");
        assert_eq!(pacing.clone().into_uint().ok(), Some(250));
    }

    #[test]
    fn cli_source_collect_algo_table() {
        let map = collected(&[
            "mps",
            "--kex-algos",
            "x25519-sha256, ml-kem-768-sha256",
            "--aead-algos",
            "aes256-gcm-siv",
            "--mac-algos",
            "hmac-sha512",
            "--kdf-algos",
            "hkdf-sha256",
        ]);
        let table = map
            .get("preferred_algorithms")
            .expect("preferred_algorithms should be in map")
            .clone()
            .into_table()
            .expect("preferred_algorithms should be a table");
        for category in ["kex", "aead", "mac", "kdf"] {
            assert!(table.contains_key(category), "{category} should be present");
        }
    }

    #[test]
    fn cli_path_defaults() {
        use libmoshpit::PathDefaults as _;

        let cli = parse(&["mps"]);
        assert_eq!(cli.env_prefix(), "MOSHPITS");
        assert_eq!(cli.default_file_path(), "moshpits");
        assert_eq!(cli.default_file_name(), "moshpits");
        assert_eq!(cli.default_tracing_path(), "moshpits/logs");
        assert_eq!(cli.default_tracing_file_name(), "moshpits");
        assert!(cli.config_absolute_path().is_none());
        assert!(cli.tracing_absolute_path().is_none());
    }
}