use clap::{Parser, ValueEnum};
/// Command-line arguments for `shohei`.
///
/// The transport selectors (`--doh`, `--dot`, `--doq`, `--tcp`) are mutually exclusive, and
/// the input sources (positional domain, `-x/--reverse`, `-f/--file`) conflict with each other.
#[derive(Parser, Debug)]
#[command(
    name = "shohei",
    version,
    about = "Next-generation DNS diagnostic CLI with DNSSEC chain-of-trust visualization",
    long_about = "shohei queries DNS records and can visualize the full DNSSEC chain of trust,\niterative resolution path, and supports modern transports like DoH and DoT."
)]
pub struct Args {
    /// Domain name to query, or `-` to read names from standard input.
    #[arg(value_parser = validate_domain_or_stdin)]
    pub domain: Option<String>,

    /// Perform a reverse (PTR) lookup for an IP address.
    #[arg(short = 'x', long = "reverse", value_name = "IP", conflicts_with = "domain")]
    pub reverse: Option<String>,

    /// Enable verbose output.
    #[arg(short = 'v', long = "verbose")]
    pub verbose: bool,

    /// Record type(s) to query (case-insensitive).
    // `num_args = 1..` makes `-t` consume every following non-flag value, so a domain written
    // after `-t` is parsed as a record type; put the domain first or separate it with `--`.
    #[arg(
        long = "type",
        short = 't',
        value_enum,
        num_args = 1..,
        default_values = ["a"],
        ignore_case = true
    )]
    pub record_types: Vec<RType>,

    /// Show the DNSSEC chain of trust.
    #[arg(long, short = 'd')]
    pub dnssec: bool,

    /// Show the iterative resolution path.
    #[arg(long)]
    pub trace: bool,

    /// Query over DNS-over-HTTPS using this URL.
    // DoH and DoT cannot both be the transport; `--doq` and `--tcp` already exclude both of
    // them, so exclude the remaining DoH/DoT pair here as well.
    #[arg(long, value_name = "URL", conflicts_with = "dot")]
    pub doh: Option<String>,

    /// Query over DNS-over-TLS using this server address.
    #[arg(long, value_name = "IP:PORT")]
    pub dot: Option<String>,

    /// Query over DNS-over-QUIC using this server address.
    #[arg(long, value_name = "IP:PORT", conflicts_with_all = ["doh", "dot"])]
    pub doq: Option<String>,

    /// DNS server to query.
    #[arg(long, short = 's', value_name = "ADDR")]
    pub server: Option<String>,

    /// Output format.
    #[arg(long, short = 'o', value_enum, default_value = "colored")]
    pub output: OutputFormat,

    /// Print short (terse) output.
    #[arg(long)]
    pub short: bool,

    /// Repeat the query every SECS seconds.
    #[arg(long, value_name = "SECS", value_parser = clap::value_parser!(u64).range(1..))]
    pub watch: Option<u64>,

    /// Additional server to compare answers against (may be repeated).
    #[arg(long, value_name = "ADDR", action = clap::ArgAction::Append)]
    pub compare: Vec<String>,

    /// Request a zone transfer (AXFR) from the server given with `--server`.
    #[arg(long, conflicts_with_all = ["dnssec", "trace", "watch", "doq", "doh", "dot"], requires = "server")]
    pub axfr: bool,

    /// Send queries over plain TCP to the server given with `--server`.
    #[arg(long, conflicts_with_all = ["doh", "dot", "doq"], requires = "server")]
    pub tcp: bool,

    /// Do not request recursion from the server.
    #[arg(long)]
    pub no_recurse: bool,

    /// Query timeout in seconds (1-60).
    #[arg(long, value_name = "SECS", default_value = "5", value_parser = clap::value_parser!(u64).range(1..=60))]
    pub timeout: u64,

    /// Read domain names to query from FILE.
    #[arg(short = 'f', long = "file", value_name = "FILE", conflicts_with_all = ["domain", "reverse"])]
    pub file: Option<std::path::PathBuf>,

    /// Use IPv4 only.
    #[arg(short = '4', conflicts_with = "ipv6_only")]
    pub ipv4_only: bool,

    /// Use IPv6 only.
    #[arg(short = '6', conflicts_with = "ipv4_only")]
    pub ipv6_only: bool,

    /// Launch the interactive terminal UI.
    #[cfg(feature = "tui")]
    #[arg(long)]
    pub tui: bool,
}
/// DNS record types selectable with `-t/--type`.
///
/// Each variant is converted to a `hickory_proto` record type by `RType::to_record_type`.
// A fieldless enum is cheap to copy; `Copy`/`PartialEq`/`Eq`/`Hash` let callers pass record
// types by value, compare them, and deduplicate them (matching `OutputFormat`'s `PartialEq`).
// Variants deliberately carry no `///` docs: clap would surface them as possible-value help.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, ValueEnum)]
pub enum RType {
    A,
    Aaaa,
    Mx,
    Ns,
    Txt,
    Cname,
    Soa,
    Ptr,
    Srv,
    Https,
    Svcb,
    Naptr,
    Dnskey,
    Ds,
    Rrsig,
    Caa,
    Tlsa,
    Sshfp,
    Nsec,
    Nsec3,
    Any,
}
impl RType {
    /// Converts this command-line record-type choice into the matching
    /// `hickory_proto::rr::RecordType`.
    ///
    /// The mapping is one-to-one: every variant has exactly one counterpart.
    pub fn to_record_type(&self) -> hickory_proto::rr::RecordType {
        use hickory_proto::rr::RecordType as Rt;

        match self {
            Self::A => Rt::A,
            Self::Aaaa => Rt::AAAA,
            Self::Mx => Rt::MX,
            Self::Ns => Rt::NS,
            Self::Txt => Rt::TXT,
            Self::Cname => Rt::CNAME,
            Self::Soa => Rt::SOA,
            Self::Ptr => Rt::PTR,
            Self::Srv => Rt::SRV,
            Self::Https => Rt::HTTPS,
            Self::Svcb => Rt::SVCB,
            Self::Naptr => Rt::NAPTR,
            Self::Dnskey => Rt::DNSKEY,
            Self::Ds => Rt::DS,
            Self::Rrsig => Rt::RRSIG,
            Self::Caa => Rt::CAA,
            Self::Tlsa => Rt::TLSA,
            Self::Sshfp => Rt::SSHFP,
            Self::Nsec => Rt::NSEC,
            Self::Nsec3 => Rt::NSEC3,
            Self::Any => Rt::ANY,
        }
    }
}
/// Value parser for the positional domain argument.
///
/// The literal `-` is the stdin placeholder and is returned as-is without validation;
/// every other value is delegated to [`validate_domain`].
fn validate_domain_or_stdin(s: &str) -> std::result::Result<String, String> {
    match s {
        "-" => Ok(String::from(s)),
        name => validate_domain(name),
    }
}
/// Checks that `s` is a syntactically plausible domain name.
///
/// One trailing dot (the fully-qualified form, e.g. `example.com.`) is allowed and is excluded
/// from the checks. On success the input is returned unchanged, including any trailing dot.
///
/// # Errors
///
/// Returns a human-readable message if, after removing a single trailing dot, the name:
/// - is empty (this includes the bare root `.`),
/// - is longer than 253 bytes,
/// - contains an empty label (e.g. `a..b`, `.example.com`, `example.com..`), or
/// - has a label longer than 63 bytes.
///
/// Label contents are not otherwise restricted.
pub fn validate_domain(s: &str) -> std::result::Result<String, String> {
    // Remove at most ONE trailing dot. `trim_end_matches('.')` would strip every trailing dot
    // and wrongly accept names such as `example.com..`, which end in an empty label.
    let trimmed = s.strip_suffix('.').unwrap_or(s);
    if trimmed.is_empty() {
        return Err("domain name cannot be empty".to_string());
    }
    // `len()` counts UTF-8 bytes; for ASCII names this equals the character count.
    if trimmed.len() > 253 {
        return Err(format!(
            "domain name too long ({} chars, RFC 1035 max 253)",
            trimmed.len()
        ));
    }
    for label in trimmed.split('.') {
        if label.is_empty() {
            return Err("domain name contains an empty label".to_string());
        }
        if label.len() > 63 {
            return Err(format!(
                "label '{label}' too long ({} chars, RFC 1035 max 63)",
                label.len()
            ));
        }
    }
    Ok(s.to_string())
}
/// Output format selected with `-o/--output` (defaults to `colored`).
// Fieldless, so `Copy` is free; `Eq` completes the existing `PartialEq` derive.
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
pub enum OutputFormat {
    Colored,
    Plain,
    Json,
}