use std::{
fs::File,
io::{BufRead, BufReader},
net::{IpAddr, SocketAddr},
ops::Deref,
path::PathBuf,
sync::Arc,
time::Duration,
};
use clap::{ArgGroup, Parser};
use console::style;
use tokio::task::JoinSet;
use tokio::time::MissedTickBehavior;
use hickory_net::{DnsError, NetError, runtime::TokioRuntimeProvider};
use hickory_proto::rr::{Record, RecordData, RecordType};
use hickory_resolver::{
TokioResolver,
config::{
CLOUDFLARE, GOOGLE, NameServerConfig, ProtocolConfig, QUAD9, ResolverConfig, ResolverOpts,
},
lookup::Lookup,
};
/// Command line options for the resolve utility, a DNS lookup tool.
///
/// Exactly one input (a name on the command line or a file of names) is
/// required, and at most one query kind (record type, happy, reverse) may be
/// selected.
#[derive(Debug, Parser)]
#[clap(name = "resolve",
    group(ArgGroup::new("qtype").args(&["happy", "reverse", "ty"])),
    group(ArgGroup::new("input").required(true).args(&["domainname", "inputfile"]))
)]
struct Opts {
    /// Name to resolve; with --reverse this must be an IP address
    domainname: Option<String>,

    /// File to read queries from, one name (or IP address with --reverse) per line
    #[clap(
        short = 'f',
        long = "file",
        value_parser,
        value_name = "FILE",
        conflicts_with("domainname")
    )]
    inputfile: Option<PathBuf>,

    /// Record type to query for
    #[clap(short = 't', long = "type", default_value = "A")]
    ty: RecordType,

    /// Look up both A and AAAA records for the name
    #[clap(short = 'e', long = "happy", conflicts_with_all(&["reverse", "ty"]))]
    happy: bool,

    /// Perform a reverse lookup of the given IP address
    #[clap(short = 'r', long = "reverse", conflicts_with_all(&["happy", "ty"]))]
    reverse: bool,

    /// Start from the system resolver configuration
    #[clap(short = 's', long = "system")]
    system: bool,

    /// Add the Google name servers (also the fallback when no name server is
    /// selected and --system is not given)
    #[clap(long)]
    google: bool,

    /// Add the Cloudflare name servers
    #[clap(long)]
    cloudflare: bool,

    /// Add the Quad9 name servers
    #[clap(long)]
    quad9: bool,

    /// Name server socket address (IP:PORT) to query; separate several with commas
    #[clap(short = 'n', long, use_value_delimiter = true, value_delimiter(','))]
    nameserver: Vec<SocketAddr>,

    /// Local IP address to bind to when connecting to servers given with --nameserver
    #[clap(long)]
    bind: Option<IpAddr>,

    /// Keep name servers with IPv4 addresses (alone, drops the IPv6 ones)
    #[clap(long)]
    ipv4: bool,

    /// Keep name servers with IPv6 addresses (alone, drops the IPv4 ones)
    #[clap(long)]
    ipv6: bool,

    /// Keep UDP connections to the name servers (alone, drops TCP)
    #[clap(long)]
    udp: bool,

    /// Keep TCP connections to the name servers (alone, drops UDP)
    #[clap(long)]
    tcp: bool,

    // Logging options shared with the other utilities; flattened into this
    // command's own argument list. The level is read via `log_config.level()`.
    #[clap(flatten)]
    log_config: hickory_util::LogConfig,

    /// Seconds to wait between starting queries when reading names from a file
    #[clap(long, default_value = "1.0")]
    interval: f32,
}
/// Prints a single resource record on one tab-indented line.
///
/// The name, TTL, class and record type are styled blue; the record data is
/// printed unstyled. `r` may be any pointer-like value that dereferences to a
/// `Record<D>` (a plain reference, a box, ...).
fn print_record<D: RecordData, R: Deref<Target = Record<D>>>(r: &R) {
    // Dereference once up front so the fields can be read directly below.
    let record: &Record<D> = &**r;

    let name = style(&record.name).blue();
    let ttl = style(record.ttl).blue();
    let class = style(record.dns_class).blue();
    let ty = style(record.record_type()).blue();
    let rdata = &record.data;

    println!("\t{name} {ttl} {class} {ty} {rdata}");
}
/// Prints a successful lookup: a "Success" banner naming the query, followed
/// by every non-empty section of the response message (answer, authority,
/// additional), one record per line.
fn print_ok(lookup: Lookup) {
    let banner = style("Success").green();
    let query = style(lookup.query()).blue();
    println!("{banner} for query {query}");

    let message = lookup.message();

    // The three sections are printed the same way, in this fixed order.
    let sections = [
        ("ANSWER", &message.answers),
        ("AUTHORITY", &message.authorities),
        ("ADDITIONAL", &message.additionals),
    ];

    for (title, records) in sections {
        // Empty sections are skipped entirely, header included.
        if records.is_empty() {
            continue;
        }

        println!("\n;; {} SECTION:", style(title).yellow());
        for record in records {
            print_record(&record);
        }
    }
}
/// Prints a failed lookup.
///
/// A `NoRecordsFound` DNS error gets a styled banner naming the query and,
/// when the error carries one, the SOA record. Every other error is printed
/// with its `Debug` representation.
fn print_error(error: NetError) {
    match error {
        NetError::Dns(DnsError::NoRecordsFound(no_records)) => {
            let banner = style("NoRecordsFound").red();
            let query = style(&no_records.query).blue();
            println!("{banner} for query {query}");

            if let Some(soa) = no_records.soa.as_ref() {
                print_record(soa);
            }
        }
        other => println!("{other:?}"),
    }
}
fn print_result(result: Result<Lookup, NetError>) {
match result {
Ok(lookup) => print_ok(lookup),
Err(re) => print_error(re),
}
}
/// Announces the query that is about to be sent.
///
/// * `name` - the name (or IP address) being queried.
/// * `ty` - the record type; only shown for a plain typed lookup.
/// * `name_servers` - preformatted description of the name servers in use.
/// * `opts` - consulted for the `happy` and `reverse` flags, which select the
///   wording; `happy` takes precedence when both are set.
fn log_query(name: &str, ty: RecordType, name_servers: &str, opts: &Opts) {
    let target = style(name).yellow();
    let servers = style(name_servers).blue();

    match (opts.happy, opts.reverse) {
        // Dual-stack address lookup: the record type is implied.
        (true, _) => println!(
            "Querying for {target} {kind} from {servers}",
            kind = style("A+AAAA").yellow(),
        ),
        (false, true) => println!(
            "Querying {kind} for {target} from {servers}",
            kind = style("reverse").yellow(),
        ),
        (false, false) => println!(
            "Querying for {target} {kind} from {servers}",
            kind = style(ty).yellow(),
        ),
    }
}
/// Runs a single query against `resolver` and returns the lookup result.
///
/// The kind of query is selected by the flags, checked in this order:
///
/// * `happy` - an IP lookup of `name` via `lookup_ip`.
/// * `reverse` - `name` is parsed as an IP address (IPv4 or IPv6) and looked
///   up in reverse.
/// * otherwise - a lookup of `name` for record type `ty`.
///
/// # Errors
///
/// Returns the resolver's `NetError` when the lookup fails.
///
/// # Panics
///
/// Panics if `reverse` is set and `name` cannot be parsed as an IP address.
async fn execute_query(
    resolver: Arc<TokioResolver>,
    name: String,
    happy: bool,
    reverse: bool,
    ty: RecordType,
) -> Result<Lookup, NetError> {
    if happy {
        // `name` is already owned: hand it to the resolver instead of cloning it.
        Ok(resolver.lookup_ip(name).await?.into())
    } else if reverse {
        let addr = name
            .parse::<IpAddr>()
            .unwrap_or_else(|_| panic!("Could not parse {name} into an IP address"));
        Ok(resolver.reverse_lookup(addr).await?)
    } else {
        Ok(resolver.lookup(name, ty).await?)
    }
}
/// Entry point: builds a resolver from the command line options and runs the
/// requested lookup(s), printing each result.
///
/// With a name on the command line a single query is executed. With an input
/// file, one query per line is started, with at most one new query per
/// `--interval` seconds; results are printed as they complete and all
/// outstanding queries are awaited before returning.
///
/// # Errors
///
/// Returns an error if the system configuration cannot be read, the resolver
/// cannot be built, `--interval` is not a positive finite number of seconds,
/// the input file cannot be opened or read, or a query task fails to join.
#[tokio::main]
pub async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let opts: Opts = Opts::parse();
    hickory_util::logger(env!("CARGO_BIN_NAME"), opts.log_config.level());

    // Optionally start from the system resolver configuration.
    let (sys_config, sys_options): (Option<ResolverConfig>, Option<ResolverOpts>) =
        if opts.system {
            let (config, options) = hickory_resolver::system_conf::read_system_conf()?;
            (Some(config), Some(options))
        } else {
            (None, None)
        };

    // Explicitly requested name servers: honour the given port and the
    // optional local bind address on every connection.
    let mut name_servers = Vec::new();
    for socket_addr in &opts.nameserver {
        let mut config = NameServerConfig::udp_and_tcp(socket_addr.ip());
        config.trust_negative_responses = false;
        for conn in config.connections.iter_mut() {
            conn.port = socket_addr.port();
            conn.bind_addr = opts.bind.map(|ip| SocketAddr::new(ip, 0));
        }
        name_servers.push(config);
    }

    if opts.google {
        name_servers.extend(GOOGLE.udp_and_tcp());
    }
    if opts.cloudflare {
        name_servers.extend(CLOUDFLARE.udp_and_tcp());
    }
    if opts.quad9 {
        name_servers.extend(QUAD9.udp_and_tcp());
    }
    // Nothing selected and no system configuration: fall back to Google.
    if name_servers.is_empty() && sys_config.is_none() {
        name_servers.extend(GOOGLE.udp_and_tcp());
    }

    // Each flag of a pair only restricts when given without its counterpart;
    // giving both or neither keeps both.
    let ipv4 = opts.ipv4 || !opts.ipv6;
    let ipv6 = opts.ipv6 || !opts.ipv4;
    let udp = opts.udp || !opts.tcp;
    let tcp = opts.tcp || !opts.udp;

    name_servers.retain(|ns| (ipv4 && ns.ip.is_ipv4()) || (ipv6 && ns.ip.is_ipv6()));
    for ns in name_servers.iter_mut() {
        ns.connections.retain(|conn| {
            (udp && conn.protocol == ProtocolConfig::Udp)
                || (tcp && conn.protocol == ProtocolConfig::Tcp)
        });
    }

    let mut config = sys_config.unwrap_or_default();
    for ns in name_servers.iter() {
        config.add_name_server(ns.clone());
    }

    // From here on `name_servers` is only the human readable description
    // used when logging queries.
    let name_servers = config
        .name_servers()
        .iter()
        .map(|ns| format!("{ns:#?}"))
        .collect::<Vec<String>>()
        .join(", ");

    let mut options = sys_options.unwrap_or_default();
    if opts.happy {
        options.ip_strategy = hickory_resolver::config::LookupIpStrategy::Ipv4AndIpv6;
    }

    let mut resolver_builder =
        TokioResolver::builder_with_config(config, TokioRuntimeProvider::default());
    *resolver_builder.options_mut() = options;
    let resolver_arc = Arc::new(resolver_builder.build()?);

    if let Some(domainname) = &opts.domainname {
        log_query(domainname, opts.ty, &name_servers, &opts);
        let lookup = execute_query(
            resolver_arc,
            domainname.to_owned(),
            opts.happy,
            opts.reverse,
            opts.ty,
        )
        .await;
        print_result(lookup);
    } else {
        // Validate the interval instead of panicking: `from_secs_f32` panics
        // on negative, NaN or overflowing values and `tokio::time::interval`
        // panics on a zero period.
        let duration = Duration::try_from_secs_f32(opts.interval)
            .map_err(|e| format!("invalid --interval {}: {e}", opts.interval))?;
        if duration.is_zero() {
            return Err("--interval must be greater than zero".into());
        }

        let path = opts
            .inputfile
            .as_ref()
            .expect("the required `input` group guarantees a file when no name is given");
        let reader = BufReader::new(File::open(path)?);

        let mut taskset = JoinSet::new();
        let mut timer = tokio::time::interval(duration);
        timer.set_missed_tick_behavior(MissedTickBehavior::Burst);

        // A read failure stops issuing new queries, but is only reported
        // after the queries already in flight have been printed.
        let mut read_error = None;

        for line in reader.lines() {
            let name = match line {
                Ok(name) => name,
                Err(e) => {
                    read_error = Some(e);
                    break;
                }
            };

            let (happy, reverse, ty) = (opts.happy, opts.reverse, opts.ty);
            log_query(&name, ty, &name_servers, &opts);
            let resolver = resolver_arc.clone();
            taskset.spawn(async move { execute_query(resolver, name, happy, reverse, ty).await });

            // Print results as they arrive until the next tick allows the
            // following query to be started.
            loop {
                tokio::select! {
                    _ = timer.tick() => break,
                    lookup_opt = taskset.join_next() => match lookup_opt {
                        Some(lookup_rr) => {
                            print_result(lookup_rr?);
                        },
                        // Nothing in flight: just wait out the interval.
                        None => { timer.tick().await; break; }
                    }
                };
            }
        }

        // Dropping the JoinSet would abort the queries still in flight and
        // lose their results, so wait for all of them first.
        while let Some(lookup_rr) = taskset.join_next().await {
            print_result(lookup_rr?);
        }

        if let Some(e) = read_error {
            return Err(e.into());
        }
    }

    Ok(())
}