use clap::Parser;
use hickory_client::rr::RecordType;
use std::{net::IpAddr, str::FromStr, time::Duration};
// Command-line arguments, parsed with clap's derive API.
//
// NOTE: plain `//` comments are used on purpose throughout this struct: clap
// turns `///` doc comments into `--help` text, so adding them here would
// change the program's output. `version` and `about` in `#[command]` are
// filled from the crate's Cargo metadata.
#[expect(clippy::struct_excessive_bools, reason = "Those are flags, not states")]
#[derive(Parser, Debug, Clone, PartialEq, Eq)]
#[command(version, about)]
pub struct Args {
    // Positional and required (no `#[arg]`, non-`Option` type): the domain to query.
    pub domain: String,
    // `-c` / `--no-positive-cache`: boolean flag, off by default.
    // Presumably disables caching of positive answers — confirm with the resolver code.
    #[arg(short = 'c', long)]
    pub no_positive_cache: bool,
    // `-C` / `--negative-cache`: boolean flag, off by default.
    // Presumably enables caching of negative answers — confirm.
    #[arg(short = 'C', long)]
    pub negative_cache: bool,
    // `-e` / `--no-edns0`: boolean flag, off by default.
    // Presumably disables EDNS0 on outgoing queries — confirm.
    #[arg(short = 'e', long)]
    pub no_edns0: bool,
    // `-o` / `--overview`: boolean flag, off by default (meaning not visible here).
    #[arg(short = 'o', long)]
    pub overview: bool,
    // `-q` / `--query-type`: record type to query, parsed with
    // `RecordType::from_str`; defaults to `A`.
    #[arg(short = 'q', long, default_value = "A", value_parser = RecordType::from_str)]
    pub query_type: RecordType,
    // `-r` / `--retries`: defaults to 3.
    #[arg(short = 'r', long, default_value = "3")]
    pub retries: usize,
    // `-s` / `--server`: server to query, kept as text (a hostname or an
    // address); defaults to `a.root-servers.net`.
    #[arg(short, long, default_value = "a.root-servers.net")]
    pub server: String,
    // `-t` / `--timeout`: whole seconds, parsed by `parse_duration`; defaults to 5.
    #[arg(short = 't', long, default_value = "5", value_parser = parse_duration)]
    pub timeout: Duration,
    // `-S` / `--source-address`: optional local IP address. `Args::validate`
    // derives `ipv4`/`ipv6` from its family and rejects a conflicting family flag.
    #[arg(short = 'S', long)]
    pub source_address: Option<IpAddr>,
    // `-6` / `--ipv6`: IPv6-only queries (per the wording of `validate`'s errors).
    #[arg(short = '6', long)]
    pub ipv6: bool,
    // `-4` / `--ipv4`: IPv4-only queries (per the wording of `validate`'s errors).
    #[arg(short = '4', long)]
    pub ipv4: bool,
    // `-T` / `--tcp`: boolean flag, off by default.
    // Presumably forces TCP transport — confirm.
    #[arg(short = 'T', long)]
    pub tcp: bool,
}
impl Args {
    /// Checks that the address-family options agree with each other and
    /// derives the address family from `--source-address` when one is given.
    ///
    /// On success, an IPv4 source address sets `ipv4 = true` and an IPv6
    /// source address sets `ipv6 = true`; without a source address the flags
    /// are left as the user passed them.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message, without mutating `self`, when:
    /// - an IPv4 source address is combined with `--ipv6`,
    /// - an IPv6 source address is combined with `--ipv4`,
    /// - `--ipv4` and `--ipv6` are both given (no address family would remain).
    pub fn validate(&mut self) -> Result<(), String> {
        match self.source_address {
            Some(IpAddr::V4(ip)) if self.ipv6 => {
                return Err(format!(
                    "Cannot use IPv6 only queries with an ipv4 source address ({ip})"
                ));
            }
            Some(IpAddr::V6(ip)) if self.ipv4 => {
                return Err(format!(
                    "Cannot use IPv4 only queries with an ipv6 source address ({ip})"
                ));
            }
            Some(IpAddr::V4(_)) => self.ipv4 = true,
            Some(IpAddr::V6(_)) => self.ipv6 = true,
            None => {}
        }
        // Checked after the source-address rules so their more specific
        // messages win. At this point both flags can only be set if the user
        // passed `-4` and `-6` together, which leaves no usable family.
        if self.ipv4 && self.ipv6 {
            return Err("Cannot use IPv4 only and IPv6 only queries together".to_owned());
        }
        Ok(())
    }
}
/// Parses a `--timeout` value given as a whole number of seconds.
///
/// # Errors
///
/// Returns a message that names the rejected input, the reason it failed to
/// parse, and the expected format. For example, `"5s"` yields
/// `Invalid duration: 5s (invalid digit found in string; expected a whole number of seconds)`.
fn parse_duration(src: &str) -> Result<Duration, String> {
    src.parse::<u64>()
        .map(Duration::from_secs)
        // Keep the `ParseIntError` reason and state the expected unit, so a
        // user who typed e.g. `5s` or `1.5` can see what to change.
        .map_err(|err| {
            format!("Invalid duration: {src} ({err}; expected a whole number of seconds)")
        })
}
#[cfg(test)]
mod tests {
    use super::*;
    use clap::Parser;
    use std::net::IpAddr;

    #[test]
    fn test_default_values() {
        let args = Args::try_parse_from(["test", "example.com"]).unwrap();
        assert_eq!(args.domain, "example.com");
        assert!(!args.no_positive_cache);
        assert!(!args.negative_cache);
        assert!(!args.no_edns0);
        assert!(!args.overview);
        assert_eq!(args.query_type, RecordType::A);
        assert_eq!(args.retries, 3);
        assert_eq!(args.server, "a.root-servers.net");
        assert_eq!(args.timeout, Duration::from_secs(5));
        assert!(args.source_address.is_none());
        assert!(!args.ipv6);
        assert!(!args.ipv4);
        assert!(!args.tcp);
    }

    #[test]
    fn test_missing_domain() {
        // The positional domain is required.
        assert!(Args::try_parse_from(["test"]).is_err());
    }

    #[test]
    fn test_all_flags() {
        let args = Args::try_parse_from([
            "test",
            "-c",
            "-C",
            "-e",
            "-o",
            "-q",
            "NS",
            "-r",
            "5",
            "-s",
            "8.8.8.8",
            "-t",
            "10",
            "-S",
            "192.168.0.1",
            "-6",
            "-T",
            "example.com",
        ])
        .unwrap();
        assert!(args.no_positive_cache);
        assert!(args.negative_cache);
        assert!(args.no_edns0);
        assert!(args.overview);
        assert_eq!(args.query_type, RecordType::NS);
        assert_eq!(args.retries, 5);
        assert_eq!(args.server, "8.8.8.8");
        assert_eq!(args.timeout, Duration::from_secs(10));
        assert_eq!(
            args.source_address,
            Some(IpAddr::from_str("192.168.0.1").unwrap())
        );
        assert!(args.ipv6);
        assert!(!args.ipv4);
        assert!(args.tcp);
    }

    #[test]
    fn test_ipv4_flag() {
        let args = Args::try_parse_from(["test", "example.com", "-4"]).unwrap();
        assert!(args.ipv4);
        assert!(!args.ipv6);
    }

    #[test]
    fn test_with_server_override() {
        let args = Args::try_parse_from(["test", "-s", "1.1.1.1", "example.com"]).unwrap();
        assert_eq!(args.server, "1.1.1.1");
    }

    #[test]
    fn test_with_query_type() {
        let args = Args::try_parse_from(["test", "example.com", "-q", "AAAA"]).unwrap();
        assert_eq!(args.query_type, RecordType::AAAA);
    }

    #[test]
    fn test_invalid_query_type() {
        let result = Args::try_parse_from(["test", "example.com", "-q", "INVALID"]);
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_timeout() {
        assert!(Args::try_parse_from(["test", "example.com", "-t", "abc"]).is_err());
        assert!(Args::try_parse_from(["test", "example.com", "-t", "-1"]).is_err());
    }

    #[test]
    fn test_invalid_source_address() {
        let result = Args::try_parse_from(["test", "example.com", "-S", "not-an-ip"]);
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_without_source_address() {
        let mut args = Args::try_parse_from(["test", "example.com", "-4"]).unwrap();
        assert!(args.validate().is_ok());
        assert!(args.ipv4);
        assert!(!args.ipv6);
    }

    #[test]
    fn test_with_source_address_v4() {
        let mut args = Args::try_parse_from(["test", "example.com", "-S", "1.1.1.1"]).unwrap();
        let validated = args.validate();
        assert!(validated.is_ok());
        assert_eq!(args.source_address, Some("1.1.1.1".parse().unwrap()));
        assert!(args.ipv4);
        assert!(!args.ipv6);
    }

    #[test]
    fn test_with_source_address_v4_and_matching_ipv4() {
        let mut args =
            Args::try_parse_from(["test", "example.com", "-4", "-S", "1.1.1.1"]).unwrap();
        assert!(args.validate().is_ok());
        assert!(args.ipv4);
        assert!(!args.ipv6);
    }

    #[test]
    fn test_with_source_address_v6() {
        let mut args =
            Args::try_parse_from(["test", "example.com", "-S", "2001:db8::1"]).unwrap();
        let validated = args.validate();
        assert!(validated.is_ok());
        assert_eq!(args.source_address, Some("2001:db8::1".parse().unwrap()));
        assert!(!args.ipv4);
        assert!(args.ipv6);
    }

    #[test]
    fn test_with_source_address_v4_and_ipv6() {
        let mut args =
            Args::try_parse_from(["test", "example.com", "-6", "-S", "1.1.1.1"]).unwrap();
        let validated = args.validate();
        assert!(validated.is_err());
        assert_eq!(
            validated.unwrap_err(),
            "Cannot use IPv6 only queries with an ipv4 source address (1.1.1.1)"
        );
    }

    #[test]
    fn test_with_source_address_v6_and_ipv4() {
        let mut args =
            Args::try_parse_from(["test", "example.com", "-4", "-S", "2001:db8::1"]).unwrap();
        let validated = args.validate();
        assert!(validated.is_err());
        assert_eq!(
            validated.unwrap_err(),
            "Cannot use IPv4 only queries with an ipv6 source address (2001:db8::1)"
        );
    }
}