mod cli;
mod display;
mod dnssec;
mod error;
mod resolver;
mod transport;
#[cfg(feature = "tui")]
mod tui;
use std::time::Duration;
use clap::Parser;
use hickory_proto::rr::RecordType;
use indicatif::{ProgressBar, ProgressStyle};
use cli::args::{Args, OutputFormat};
use cli::output::{
json::JsonRenderer, plain::PlainRenderer, short::ShortRenderer, table::ColoredRenderer, Render,
};
use futures_util::future::join_all;
#[tokio::main]
async fn main() {
let args = Args::parse();
let renderer: Box<dyn Render> = if args.short {
Box::new(ShortRenderer)
} else {
match args.output {
OutputFormat::Json => Box::new(JsonRenderer),
OutputFormat::Plain => Box::new(PlainRenderer),
OutputFormat::Colored => Box::new(ColoredRenderer),
}
};
let stdin_mode = args.domain.as_deref() == Some("-")
|| (args.domain.is_none()
&& args.reverse.is_none()
&& args.file.is_none()
&& !std::io::IsTerminal::is_terminal(&std::io::stdin()));
if args.domain.is_none() && args.reverse.is_none() && args.file.is_none() && !stdin_mode {
eprintln!("Error: missing domain. Provide a domain name, use -x <IP> for reverse lookup, -f <file> for batch, or pipe domains via stdin.");
std::process::exit(1);
}
#[cfg(feature = "tui")]
if args.tui {
let spinner = make_spinner();
let (domain, rtypes) = match resolve_effective_args(&args, None) {
Ok(v) => v,
Err(e) => {
eprintln!("Error: {e}");
std::process::exit(1);
}
};
let mut rtypes_iter = rtypes.into_iter();
let record_type = rtypes_iter.next().unwrap_or(RecordType::A);
if rtypes_iter.next().is_some() {
eprintln!("Note: TUI mode uses only the first record type");
}
spinner.set_message(format!("Loading data for {}...", domain));
let opts = match build_query_opts(&args, &domain, record_type).await {
Ok(o) => o,
Err(e) => {
spinner.finish_and_clear();
eprintln!("Error: {e}");
std::process::exit(1);
}
};
let resolver_ip = server_ip_from_args(&args);
let (records_res, dnssec_res, trace_res) = tokio::join!(
resolver::standard::query(&opts),
dnssec::build_chain(&domain, record_type, resolver_ip, args.verbose),
resolver::iterative::trace(&domain, record_type, resolver_ip),
);
spinner.finish_and_clear();
let records = records_res.unwrap_or_else(|e| {
eprintln!("Warning: DNS query failed: {e}");
std::process::exit(1);
});
let dnssec_chain = dnssec_res.unwrap_or_else(|e| {
eprintln!("Warning: DNSSEC chain failed: {e}");
std::process::exit(1);
});
let trace = trace_res.unwrap_or_else(|e| {
eprintln!("Warning: trace failed: {e}");
std::process::exit(1);
});
if let Err(e) = tui::run(domain, records, dnssec_chain, trace).await {
eprintln!("TUI error: {e}");
std::process::exit(1);
}
return;
}
if let Some(ref path) = args.file {
use std::io::BufRead;
let file = match std::fs::File::open(path) {
Ok(f) => f,
Err(e) => {
eprintln!("Error: cannot open file '{}': {e}", path.display());
std::process::exit(1);
}
};
let domains: Vec<String> = std::io::BufReader::new(file)
.lines()
.map_while(Result::ok)
.map(|l| l.trim().to_string())
.filter(|l| !l.is_empty() && !l.starts_with('#'))
.collect();
if domains.is_empty() {
eprintln!("Error: no domains found in '{}'", path.display());
std::process::exit(1);
}
let mut any_failed = false;
for domain in &domains {
if let Err(e) = cli::args::validate_domain(domain) {
eprintln!("Error: invalid domain '{domain}': {e}");
any_failed = true;
continue;
}
if !run_once(&args, &*renderer, Some(domain.as_str())).await {
any_failed = true;
}
}
if any_failed {
std::process::exit(1);
}
return;
}
if stdin_mode {
use std::io::BufRead;
let domains: Vec<String> = std::io::stdin()
.lock()
.lines()
.map_while(Result::ok)
.map(|l| l.trim().to_string())
.filter(|l| !l.is_empty() && !l.starts_with('#'))
.collect();
if domains.is_empty() {
eprintln!("Error: no domains read from stdin");
std::process::exit(1);
}
let mut any_failed = false;
for domain in &domains {
if let Err(e) = cli::args::validate_domain(domain) {
eprintln!("Error: invalid domain '{domain}': {e}");
any_failed = true;
continue;
}
if !run_once(&args, &*renderer, Some(domain.as_str())).await {
any_failed = true;
}
}
if any_failed {
std::process::exit(1);
}
return;
}
let mut iteration = 0u32;
loop {
iteration = iteration.saturating_add(1);
if iteration > 1 && std::io::IsTerminal::is_terminal(&std::io::stdout()) {
print!("\x1b[2J\x1b[H");
}
run_once(&args, &*renderer, None).await;
match args.watch {
Some(secs) => {
eprintln!("\n Refreshing in {secs}s — Ctrl+C to stop");
tokio::time::sleep(Duration::from_secs(secs)).await;
}
None => break,
}
}
}
/// Stops and erases `spinner`, prints `msg` on stderr with an `Error: ` prefix, and returns
/// `false` so that callers can write `return bail(..)` on a failure path.
fn bail(spinner: &ProgressBar, msg: &dyn std::fmt::Display) -> bool {
    // Clear first so the message does not share a line with the spinner frame.
    spinner.finish_and_clear();
    eprintln!("Error: {}", msg);
    false
}
async fn run_once(args: &Args, renderer: &dyn Render, domain_override: Option<&str>) -> bool {
let spinner = make_spinner();
let resolver_ip = server_ip_from_args(args);
let (domain, record_types) = match resolve_effective_args(args, domain_override) {
Ok(v) => v,
Err(e) => return bail(&spinner, &e),
};
let primary_type = record_types[0];
if args.axfr {
let server = match &args.server {
None => return bail(&spinner, &"--axfr requires -s <server>"),
Some(s) => {
let addr_str = parse_server_addr(s);
match addr_str.parse::<std::net::SocketAddr>() {
Ok(a) => a,
Err(e) => return bail(&spinner, &format!("invalid --server address '{s}': {e}")),
}
}
};
spinner.set_message(format!("Fetching zone {} via AXFR from {}...", domain, server));
match resolver::zone_transfer::axfr(&domain, server, args.timeout).await {
Ok(result) => {
spinner.finish_and_clear();
print!("{}", renderer.render_records(&result));
true
}
Err(e) => bail(&spinner, &e),
}
} else if !args.compare.is_empty() {
if args.compare.len() == 1 {
let compare_addr = &args.compare[0];
spinner.set_message(format!("Comparing {} against {}...", domain, compare_addr));
let opts_left = match build_query_opts(args, &domain, primary_type).await {
Ok(o) => o,
Err(e) => return bail(&spinner, &e),
};
let compare_addr_str = parse_server_addr(compare_addr);
let compare_sock = match compare_addr_str.parse::<std::net::SocketAddr>() {
Ok(a) => a,
Err(e) => {
return bail(
&spinner,
&format!("invalid --compare address '{compare_addr}': {e}"),
)
}
};
let opts_right = resolver::QueryOptions {
domain: domain.clone(),
record_type: primary_type,
server: Some(compare_sock),
transport: None,
validate_dnssec: args.dnssec,
force_tcp: false,
no_recurse: args.no_recurse,
timeout_secs: args.timeout,
ipv4_only: args.ipv4_only,
ipv6_only: args.ipv6_only,
};
let (left_res, right_res) = tokio::join!(
resolver::standard::query(&opts_left),
resolver::standard::query(&opts_right),
);
spinner.finish_and_clear();
match (left_res, right_res) {
(Ok(left), Ok(right)) => {
let cmp = resolver::DnsComparison {
domain,
record_type: primary_type.to_string(),
left,
right,
};
print!("{}", renderer.render_compare(&cmp));
true
}
(Err(e1), Err(e2)) => {
eprintln!("Error (left): {e1}");
eprintln!("Error (right): {e2}");
false
}
(Err(e), _) | (_, Err(e)) => {
eprintln!("Error: {e}");
false
}
}
} else {
let compare_addrs = args.compare.iter().map(|a| a.as_str()).collect::<Vec<_>>();
spinner.set_message(format!(
"Querying {} across {} servers...",
domain,
compare_addrs.len() + 1
));
let opts_primary = match build_query_opts(args, &domain, primary_type).await {
Ok(o) => o,
Err(e) => return bail(&spinner, &e),
};
let mut all_opts = vec![opts_primary];
for addr in &args.compare {
let addr_str = parse_server_addr(addr);
let sock = match addr_str.parse::<std::net::SocketAddr>() {
Ok(a) => a,
Err(e) => return bail(&spinner, &format!("invalid --compare address '{addr}': {e}")),
};
all_opts.push(resolver::QueryOptions {
domain: domain.clone(),
record_type: primary_type,
server: Some(sock),
transport: None,
validate_dnssec: args.dnssec,
force_tcp: false,
no_recurse: args.no_recurse,
timeout_secs: args.timeout,
ipv4_only: args.ipv4_only,
ipv6_only: args.ipv6_only,
});
}
let results = join_all(all_opts.iter().map(|o| resolver::standard::query(o))).await;
spinner.finish_and_clear();
let mut query_results = Vec::new();
for (i, result) in results.into_iter().enumerate() {
match result {
Ok(r) => query_results.push(r),
Err(e) => eprintln!("Warning: server {i} failed: {e}"),
}
}
if query_results.is_empty() {
spinner.finish_and_clear();
eprintln!("Error: all servers failed");
return false;
}
let multi = resolver::DnsMultiQuery {
domain,
record_type: primary_type.to_string(),
results: query_results,
};
print!("{}", renderer.render_multi(&multi));
true
}
} else if args.trace {
spinner.set_message(format!("Tracing resolution path for {}...", domain));
match resolver::iterative::trace(&domain, primary_type, resolver_ip).await {
Ok(trace) => {
spinner.finish_and_clear();
print!("{}", renderer.render_trace(&trace));
true
}
Err(e) => bail(&spinner, &e),
}
} else if args.dnssec {
spinner.set_message(format!("Validating DNSSEC chain for {}...", domain));
match dnssec::build_chain(&domain, primary_type, resolver_ip, args.verbose).await {
Ok(chain) => {
spinner.finish_and_clear();
print!("{}", renderer.render_dnssec(&chain));
true
}
Err(e) => bail(&spinner, &e),
}
} else {
spinner.set_message(format!("Querying {}...", domain));
let mut cleared = false;
let mut success = true;
for &rtype in &record_types {
let opts = match build_query_opts(args, &domain, rtype).await {
Ok(o) => o,
Err(e) => return bail(&spinner, &e),
};
match resolver::standard::query(&opts).await {
Ok(result) => {
if !cleared {
spinner.finish_and_clear();
cleared = true;
}
print!("{}", renderer.render_records(&result));
}
Err(e) => {
if !cleared {
spinner.finish_and_clear();
cleared = true;
}
eprintln!("Error: {e}");
success = false;
}
}
}
if !cleared {
spinner.finish_and_clear();
}
success
}
}
/// Creates the progress spinner shown while a lookup is in flight.
///
/// The spinner uses the `{spinner:.cyan} {msg}` template and advances on its own every 80 ms
/// until it is finished by the caller.
fn make_spinner() -> ProgressBar {
    let style = ProgressStyle::default_spinner()
        .template("{spinner:.cyan} {msg}")
        .expect("spinner template is a valid literal");
    let tick_interval = Duration::from_millis(80);

    let spinner = ProgressBar::new_spinner();
    spinner.set_style(style);
    spinner.enable_steady_tick(tick_interval);
    spinner
}
async fn build_query_opts(
args: &Args,
domain: &str,
record_type: RecordType,
) -> error::Result<resolver::QueryOptions> {
use std::net::SocketAddr;
let transport = if let Some(url) = &args.doh {
let (config, label) = transport::doh::build_doh_config(url).await?;
Some((config, label))
} else if let Some(addr) = &args.dot {
let (config, label) = transport::dot::build_dot_config(addr).await?;
Some((config, label))
} else if let Some(addr) = &args.doq {
let (config, label) = transport::doq::build_doq_config(addr).await?;
Some((config, label))
} else {
None
};
let server: Option<SocketAddr> = if transport.is_none() {
match &args.server {
None => None,
Some(s) => {
let addr_str = parse_server_addr(s);
match addr_str.parse::<SocketAddr>() {
Ok(addr) => Some(addr),
Err(e) => {
return Err(crate::error::ShoheError::Parse(format!(
"Invalid --server address '{s}': {e}. \
Use IP:PORT (e.g. 8.8.8.8:53) or bare IP."
)));
}
}
}
}
} else {
None
};
Ok(resolver::QueryOptions {
domain: domain.to_string(),
record_type,
server,
transport,
validate_dnssec: args.dnssec,
force_tcp: args.tcp,
no_recurse: args.no_recurse,
timeout_secs: args.timeout,
ipv4_only: args.ipv4_only,
ipv6_only: args.ipv6_only,
})
}
fn resolve_effective_args(
args: &Args,
domain_override: Option<&str>,
) -> error::Result<(String, Vec<RecordType>)> {
if let Some(ip_str) = &args.reverse {
let ptr_domain = ip_to_ptr_domain(ip_str)?;
Ok((ptr_domain, vec![RecordType::PTR]))
} else {
let domain = domain_override
.map(str::to_string)
.or_else(|| args.domain.clone())
.expect("domain is required when -x is not set and no override given");
let types = args.record_types.iter().map(|r| r.to_record_type()).collect();
Ok((domain, types))
}
}
fn ip_to_ptr_domain(ip_str: &str) -> error::Result<String> {
match ip_str.trim().parse::<std::net::IpAddr>() {
Ok(std::net::IpAddr::V4(v4)) => {
let o = v4.octets();
Ok(format!("{}.{}.{}.{}.in-addr.arpa", o[3], o[2], o[1], o[0]))
}
Ok(std::net::IpAddr::V6(v6)) => {
let nibbles: String = v6
.octets()
.iter()
.rev()
.flat_map(|b| {
let lo = b & 0xf;
let hi = b >> 4;
[format!("{lo:x}."), format!("{hi:x}.")]
})
.collect();
Ok(format!("{nibbles}ip6.arpa"))
}
Err(_) => Err(crate::error::ShoheError::Parse(format!(
"'-x' expects an IP address (e.g. 1.1.1.1 or 2606:4700::1), got: '{ip_str}'"
))),
}
}
/// Extracts just the IP address of `args.server`, if one was given and it parses.
///
/// Returns `None` when no server is configured or when the value is not a valid socket
/// address after normalisation by `parse_server_addr`; parse errors are not reported here.
fn server_ip_from_args(args: &Args) -> Option<std::net::IpAddr> {
    let raw = args.server.as_ref()?;
    let sock: std::net::SocketAddr = parse_server_addr(raw).parse().ok()?;
    Some(sock.ip())
}
/// Normalises a user-supplied server address into a string that `SocketAddr` can parse,
/// adding the default DNS port 53 when no port is present.
///
/// | input                  | output                    |
/// |------------------------|---------------------------|
/// | `8.8.8.8`              | `8.8.8.8:53`              |
/// | `8.8.8.8:5353`         | `8.8.8.8:5353`            |
/// | `2001:db8::1`          | `[2001:db8::1]:53`        |
/// | `[2001:db8::1]`        | `[2001:db8::1]:53`        |
/// | `[2001:db8::1]:853`    | `[2001:db8::1]:853`       |
///
/// Surrounding whitespace is ignored. No validation happens here; invalid input is returned
/// in a form that the subsequent `parse::<SocketAddr>()` rejects.
fn parse_server_addr(s: &str) -> String {
    const DEFAULT_PORT: u16 = 53;

    let s = s.trim();
    if s.starts_with('[') {
        // Bracketed IPv6: a port, if any, follows the closing bracket. `SocketAddr` refuses
        // a bracketed address without a port, so supply the default in that case.
        return if s.ends_with(']') {
            format!("{s}:{DEFAULT_PORT}")
        } else {
            s.to_string()
        };
    }
    match s.matches(':').count() {
        // Bare IPv4 address (or host-like token) without a port.
        0 => format!("{s}:{DEFAULT_PORT}"),
        // Exactly one colon: already `host:port`.
        1 => s.to_string(),
        // Several colons without brackets: a bare IPv6 address.
        _ => format!("[{s}]:{DEFAULT_PORT}"),
    }
}