use clap::Parser;
use eyre::{Result, WrapErr as _, bail};
use hickory_proto::rr::RecordType;
use std::{net::IpAddr, str::FromStr as _, time::Duration};
// Command-line arguments, parsed with clap's derive API.
//
// NOTE: field comments here are plain `//` on purpose — clap turns `///` doc
// comments into `--help` text, so switching them would change the CLI output.
// Field order also matters: it drives the derived `Debug` output (snapshotted in
// the tests below) and the order options appear in `--help`.
#[expect(clippy::struct_excessive_bools, reason = "Those are flags, not states")]
#[derive(Parser, Debug, Clone, PartialEq, Eq)]
#[command(version, about)]
pub struct Args {
    // Required positional argument: the domain name to query.
    pub domain: String,
    // `-c` / `--no-positive-cache`.
    // NOTE(review): presumably disables caching of positive answers — confirm
    // against the resolver code that reads this flag.
    #[arg(short = 'c', long)]
    pub no_positive_cache: bool,
    // `-C` / `--negative-cache`.
    // NOTE(review): presumably enables caching of negative answers — confirm.
    #[arg(short = 'C', long)]
    pub negative_cache: bool,
    // `-e` / `--no-edns0`.
    // NOTE(review): presumably turns off EDNS(0) on outgoing queries — confirm.
    #[arg(short = 'e', long)]
    pub no_edns0: bool,
    // `-o` / `--overview`; how the overview is rendered is decided by the caller.
    #[arg(short = 'o', long)]
    pub overview: bool,
    // `-q` / `--query-type`: record type to ask for. Parsed case-insensitively by
    // `parse_record_type` (input is upper-cased first); defaults to `A`.
    #[arg(short = 'q', long, default_value = "A", value_parser = parse_record_type)]
    pub query_type: RecordType,
    // `-r` / `--retries`, default 3.
    // NOTE(review): whether this is per server or in total is decided elsewhere.
    #[arg(short = 'r', long, default_value = "3")]
    pub retries: usize,
    // `-s` / `--server` (short flag derived from the field name), default
    // `a.root-servers.net`. Presumably the first server queried — confirm.
    #[arg(short, long, default_value = "a.root-servers.net")]
    pub server: String,
    // `-t` / `--timeout`: a whole number of seconds, converted by `parse_duration`;
    // defaults to 5 seconds.
    #[arg(short = 't', long, default_value = "5", value_parser = parse_duration)]
    pub timeout: Duration,
    // `-S` / `--source-address`: optional local address. `Args::validate` forces
    // `ipv4`/`ipv6` to match this address's family and rejects contradictions.
    #[arg(short = 'S', long)]
    pub source_address: Option<IpAddr>,
    // `-6` / `--ipv6`: IPv6-only queries (wording taken from `validate`'s errors).
    #[arg(short = '6', long)]
    pub ipv6: bool,
    // `-4` / `--ipv4`: IPv4-only queries (wording taken from `validate`'s errors).
    #[arg(short = '4', long)]
    pub ipv4: bool,
    // `-T` / `--tcp`. NOTE(review): presumably sends queries over TCP — confirm.
    #[arg(short = 'T', long)]
    pub tcp: bool,
}
/// Clap value parser for `--query-type`: accepts record type names in any case.
///
/// The input is upper-cased with ASCII rules before parsing, so `"mx"`, `"Mx"`
/// and `"MX"` all yield the same [`RecordType`].
///
/// # Errors
///
/// Returns the `hickory_proto` parse error, converted into an [`eyre::Report`],
/// when the upper-cased name is not a recognised record type.
fn parse_record_type(s: &str) -> Result<RecordType> {
    let normalized = s.to_ascii_uppercase();
    normalized.parse::<RecordType>().map_err(Into::into)
}
impl Args {
    /// Checks that the address-family options agree with each other and normalizes them.
    ///
    /// `-4` requests IPv4-only queries and `-6` requests IPv6-only queries, so the
    /// two are mutually exclusive. A `--source-address` pins the family: an IPv4
    /// source address sets `ipv4 = true`, and an IPv6 one sets `ipv6 = true`.
    /// After a successful call, the `ipv4`/`ipv6` flags therefore reflect the
    /// family of any source address that was given.
    ///
    /// # Errors
    ///
    /// Returns an error when both `-4` and `-6` are set. It also returns an error
    /// when the family of `--source-address` contradicts an explicit `-4`/`-6` flag.
    pub fn validate(&mut self) -> Result<()> {
        // "IPv4 only" and "IPv6 only" at the same time is contradictory. Reject it
        // up front so it is caught even when no source address is supplied.
        if self.ipv4 && self.ipv6 {
            bail!("Cannot use IPv4 only and IPv6 only queries at the same time");
        }
        match self.source_address {
            Some(IpAddr::V4(ip)) => {
                if self.ipv6 {
                    bail!("Cannot use IPv6 only queries with an ipv4 source address ({ip})");
                }
                // A v4 source address can only reach v4 servers.
                self.ipv4 = true;
            }
            Some(IpAddr::V6(ip)) => {
                if self.ipv4 {
                    bail!("Cannot use IPv4 only queries with an ipv6 source address ({ip})");
                }
                // A v6 source address can only reach v6 servers.
                self.ipv6 = true;
            }
            None => {}
        }
        Ok(())
    }
}
/// Clap value parser for `--timeout`: reads a whole number of seconds.
///
/// `"5"` becomes `Duration::from_secs(5)`. Fractions and unit suffixes are not
/// accepted.
///
/// # Errors
///
/// Fails with `Invalid duration: <src>` (the `u64` parse error is kept as the
/// source) when `src` is not a valid unsigned 64-bit integer.
fn parse_duration(src: &str) -> Result<Duration> {
    let seconds = src
        .parse::<u64>()
        .wrap_err_with(|| format!("Invalid duration: {src}"))?;
    Ok(Duration::from_secs(seconds))
}
#[cfg(test)]
mod tests {
    // Snapshot tests for argument parsing and validation. Each case stores its
    // `Debug` output with insta under an explicit snapshot name. Changing a case
    // name, a field order in `Args`, or an error message requires re-accepting
    // the matching snapshot.
    #![allow(clippy::expect_used, clippy::unwrap_used, reason = "test")]
    use super::*;
    use insta::assert_debug_snapshot;
    use rstest::rstest;
    // Successful parses. Snapshot name: `args_<case name>`.
    // Only `Args::parse_from` is run here, not `validate()`.
    #[rstest]
    #[case("default_values", vec!["test", "example.com"])]
    // NOTE: this combines `-S 192.168.0.1` (IPv4) with `-6`. That parses fine, but
    // `Args::validate` would reject it; this case only exercises the parser.
    #[case("all_flags", vec![
        "test",
        "-c", // no_positive_cache
        "-C", // negative_cache
        "-e", // edns0 disabled
        "-o", // overview enabled
        "-q",
        "NS", // query_type: NS
        "-r",
        "5", // retries: 5
        "-s",
        "8.8.8.8", // server: 8.8.8.8
        "-t",
        "10", // timeout: 10 seconds
        "-S",
        "192.168.0.1", // source_address: 192.168.0.1
        "-6", // force IPv6
        "-T", // use TCP
        "example.com",
    ])]
    #[case("ipv4_flag", vec!["test", "example.com", "-4"])]
    #[case("server_override", vec!["test", "-s", "1.1.1.1", "example.com"])]
    #[case("query_type", vec!["test", "example.com", "-q", "AAAA"])]
    #[case("source_v4", vec!["test", "example.com", "-S", "1.1.1.1"])]
    #[case("source_v6", vec!["test", "example.com", "-S", "2001:db8::1"])]
    #[trace]
    fn args(#[case] name: &str, #[case] input: Vec<&str>) {
        let args = Args::parse_from(input);
        assert_debug_snapshot!(format!("args_{name}"), args);
    }
    // `-q` is case-insensitive: mixed-case spellings go through `parse_record_type`.
    // Snapshot name: `record_<input as given>`.
    #[rstest]
    #[case("soa")]
    #[case("SoA")]
    #[case("a")]
    #[case("A")]
    #[case("aAaA")]
    #[case("Mx")]
    #[case("dS")]
    #[trace]
    fn record_type(#[case] record: &str) {
        let args = Args::parse_from(["test", "-q", record, "example.com"]);
        assert_debug_snapshot!(format!("record_{record}"), args);
    }
    // Inputs clap must refuse. The `clap::Error` debug output is snapshotted as
    // `bad_<case name>`. The IP cases use an out-of-range octet and a non-hex digit.
    #[rstest]
    #[case("invalid_query_type", vec!["test", "example.com", "-q", "INVALID"])]
    #[case("invalid_retries", vec!["test", "example.com", "-r", "INVALID"])]
    #[case("invalid_ipv4", vec!["test", "example.com", "-S", "5432.5432.234.12"])]
    #[case("invalid_ipv6", vec!["test", "example.com", "-S", "2a0x::1"])]
    #[trace]
    fn bad_args(#[case] name: &str, #[case] input: Vec<&str>) {
        let args = Args::try_parse_from(input).unwrap_err();
        assert_debug_snapshot!(format!("bad_{name}"), args);
    }
    // Inputs that parse but fail `Args::validate`, because the source address
    // family contradicts the `-4`/`-6` flag. Snapshot name: `not_valid_<case name>`.
    #[rstest]
    #[case("source_v4_plus_ipv6_flag", vec!["test", "example.com", "-6", "-S", "1.1.1.1"])]
    #[case("source_v6_plus_ipv4_flag", vec!["test", "example.com", "-4", "-S", "2001:db8::1"])]
    #[trace]
    fn not_valid(#[case] name: &str, #[case] input: Vec<&str>) {
        let mut args = Args::parse_from(input);
        let validated = args.validate().unwrap_err();
        assert_debug_snapshot!(format!("not_valid_{name}"), validated);
    }
}