#![allow(clippy::uninlined_format_args)]
use clap::{Args, Parser, Subcommand, ValueHint};
use mtop::ping::{Bundle, DnsPinger};
use mtop::{profile, sig};
use mtop_client::dns::{DnsClient, Flags, Message, MessageId, Name, Question, Record, RecordClass, RecordType};
use std::fmt::Write;
use std::io::Cursor;
use std::net::SocketAddr;
use std::num::NonZeroU64;
use std::path::PathBuf;
use std::process::ExitCode;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::runtime::Handle;
use tracing::{Instrument, Level};
/// Record type used by the subcommands when `--rtype` is not given.
const DEFAULT_RECORD_TYPE: RecordType = RecordType::A;
/// Record class used by the subcommands when `--rclass` is not given.
const DEFAULT_RECORD_CLASS: RecordClass = RecordClass::INET;
/// Default value of the `--resolv-conf` option of the `ping` and `query` subcommands.
const DEFAULT_RESOLV_CONF: &str = "/etc/resolv.conf";
/// Smallest `--interval-secs` value, in seconds, that `parse_interval` accepts.
const MIN_PING_INTERVAL_SECS: f64 = 0.1;
// Top-level command line options of the `dns` binary.
//
// NOTE: plain `//` comments are used here on purpose. clap's derive macros turn `///`
// doc comments into help text, so adding them would change the `--help` output.
#[derive(Debug, Parser)]
#[command(name = "dns", version = clap::crate_version!())]
struct DnsConfig {
// Level handed to the console tracing subscriber that `main` installs. Defaults to INFO.
#[arg(long, default_value_t = Level::INFO)]
log_level: Level,
// Optional path; when present `main` passes it to `profile::Writer::finish` after the
// selected subcommand has run.
#[arg(long, value_hint = ValueHint::FilePath)]
profile_output: Option<PathBuf>,
// The subcommand to run.
#[command(subcommand)]
mode: Action,
}
// Subcommands of the `dns` binary. `main` dispatches each variant to the matching
// `run_*` function.
//
// NOTE: plain `//` comments are deliberate; `///` would be picked up by clap as help text.
#[derive(Debug, Subcommand)]
enum Action {
// Run a `DnsPinger` for a name and print timing and error counters (`run_ping`).
Ping(PingCommand),
// Resolve a name once and print the response as text or as wire bytes (`run_query`).
Query(QueryCommand),
// Decode a wire-format message from stdin and print it as text (`run_read`).
Read(ReadCommand),
// Build a query message and write its wire encoding to stdout (`run_write`).
Write(WriteCommand),
}
// Options of the `ping` subcommand, consumed by `run_ping`.
//
// NOTE: plain `//` comments are deliberate; `///` would be picked up by clap as help text.
#[derive(Debug, Args)]
struct PingCommand {
// Interval in seconds, converted to a `Duration` and given to `DnsPinger::new`.
// Validated by `parse_interval`; defaults to 1.0.
#[arg(long, value_parser = parse_interval, default_value_t = 1.0)]
interval_secs: f64,
// Passed unchanged to `DnsPinger::run`; defaults to 0.
// NOTE(review): 0 presumably means "no limit" -- confirm against `DnsPinger`.
#[arg(long, default_value_t = 0)]
count: u64,
// Path handed to `mtop::dns::new_client`; defaults to `DEFAULT_RESOLV_CONF`.
#[arg(long, default_value = DEFAULT_RESOLV_CONF, value_hint = ValueHint::FilePath)]
resolv_conf: PathBuf,
// Optional socket address handed to `mtop::dns::new_client`.
// NOTE(review): presumably overrides the servers from `resolv_conf` -- confirm.
#[arg(long, value_hint = ValueHint::Hostname)]
nameserver: Option<SocketAddr>,
// Optional non-zero timeout in whole seconds, converted to a `Duration` for
// `mtop::dns::new_client`.
#[arg(long)]
nameserver_timeout_secs: Option<NonZeroU64>,
// Record type to query; defaults to `DEFAULT_RECORD_TYPE`.
#[arg(long, default_value_t = DEFAULT_RECORD_TYPE)]
rtype: RecordType,
// Record class to query; defaults to `DEFAULT_RECORD_CLASS`.
#[arg(long, default_value_t = DEFAULT_RECORD_CLASS)]
rclass: RecordClass,
// Name to query (required; no `long`/`short`, so clap treats it as positional).
#[arg(required = true)]
name: Name,
}
fn parse_interval(s: &str) -> Result<f64, String> {
match s.parse() {
Ok(v) if v >= MIN_PING_INTERVAL_SECS => Ok(v),
Ok(_) => Err(format!("must be at least {}", MIN_PING_INTERVAL_SECS)),
Err(e) => Err(e.to_string()),
}
}
// Options of the `query` subcommand, consumed by `run_query`.
//
// NOTE: plain `//` comments are deliberate; `///` would be picked up by clap as help text.
#[derive(Debug, Args)]
struct QueryCommand {
// Path handed to `mtop::dns::new_client`; defaults to `DEFAULT_RESOLV_CONF`.
#[arg(long, default_value = DEFAULT_RESOLV_CONF, value_hint = ValueHint::FilePath)]
resolv_conf: PathBuf,
// Optional socket address handed to `mtop::dns::new_client`.
// NOTE(review): presumably overrides the servers from `resolv_conf` -- confirm.
#[arg(long, value_hint = ValueHint::Hostname)]
nameserver: Option<SocketAddr>,
// Optional non-zero timeout in whole seconds, converted to a `Duration` for
// `mtop::dns::new_client`.
#[arg(long)]
nameserver_timeout_secs: Option<NonZeroU64>,
// When set, the response is written to stdout in wire format instead of as text.
#[arg(long, default_value_t = false)]
raw: bool,
// Record type to query; defaults to `DEFAULT_RECORD_TYPE`.
#[arg(long, default_value_t = DEFAULT_RECORD_TYPE)]
rtype: RecordType,
// Record class to query; defaults to `DEFAULT_RECORD_CLASS`.
#[arg(long, default_value_t = DEFAULT_RECORD_CLASS)]
rclass: RecordClass,
// Name to query (required; no `long`/`short`, so clap treats it as positional).
#[arg(required = true)]
name: Name,
}
// Options of the `read` subcommand. It currently has none: `run_read` takes its input
// from stdin. (Plain `//` comment on purpose; `///` would become clap help text.)
#[derive(Debug, Args)]
struct ReadCommand {}
// Options of the `write` subcommand, consumed by `run_write` to build a single question.
//
// NOTE: plain `//` comments are deliberate; `///` would be picked up by clap as help text.
#[derive(Debug, Args)]
struct WriteCommand {
// Record type of the question; defaults to `DEFAULT_RECORD_TYPE`.
#[arg(long, default_value_t = DEFAULT_RECORD_TYPE)]
rtype: RecordType,
// Record class of the question; defaults to `DEFAULT_RECORD_CLASS`.
#[arg(long, default_value_t = DEFAULT_RECORD_CLASS)]
rclass: RecordClass,
// Name of the question (required; positional). `run_write` converts it with `to_fqdn`.
#[arg(required = true)]
name: Name,
}
/// Program entry point.
///
/// Parses the command line, installs the console tracing subscriber, runs the selected
/// subcommand and returns its exit code. When `--profile-output` was given, the path is
/// handed to the profile writer once the subcommand has finished.
///
/// # Panics
///
/// Panics if the console subscriber cannot be created or installed.
#[tokio::main]
async fn main() -> ExitCode {
    let opts = DnsConfig::parse();

    let subscriber = mtop::tracing::console_subscriber(opts.log_level).expect("failed to setup console logging");
    tracing::subscriber::set_global_default(subscriber).expect("failed to initialize console logging");

    // Created before the subcommand runs and finished only after it has returned.
    let profiler = profile::Writer::default();

    let exit_code = match &opts.mode {
        Action::Ping(args) => run_ping(args).await,
        Action::Query(args) => run_query(args).await,
        Action::Read(args) => run_read(args).await,
        Action::Write(args) => run_write(args).await,
    };

    if let Some(path) = opts.profile_output {
        profiler.finish(path);
    }

    exit_code
}
/// Run the `ping` subcommand.
///
/// Builds a DNS client from the resolver options, runs a [`DnsPinger`] for the requested
/// name with the configured interval and count, and prints the collected results.
/// Always returns [`ExitCode::SUCCESS`].
async fn run_ping(cmd: &PingCommand) -> ExitCode {
    let timeout = cmd.nameserver_timeout_secs.map(|secs| Duration::from_secs(secs.get()));
    let resolver = mtop::dns::new_client(&cmd.resolv_conf, cmd.nameserver, timeout)
        .instrument(tracing::span!(Level::INFO, "dns::new_client"))
        .await;

    // The same flag is shared with the signal helper and the pinger.
    // NOTE(review): presumably the helper sets it on interrupt so the pinger stops -- confirm.
    let interrupted = Arc::new(AtomicBool::new(false));
    sig::wait_for_interrupt(Handle::current(), Arc::clone(&interrupted)).await;

    let interval = Duration::from_secs_f64(cmd.interval_secs);
    let pinger = DnsPinger::new(resolver, interval, Arc::clone(&interrupted));
    let outcome = pinger.run(cmd.name.clone(), cmd.rtype, cmd.rclass, cmd.count).await;

    print_ping_results(&outcome);
    ExitCode::SUCCESS
}
async fn run_query(cmd: &QueryCommand) -> ExitCode {
let client = mtop::dns::new_client(
&cmd.resolv_conf,
cmd.nameserver,
cmd.nameserver_timeout_secs.map(|v| Duration::from_secs(v.get())),
)
.instrument(tracing::span!(Level::INFO, "dns::new_client"))
.await;
let response = match client
.resolve(MessageId::random(), cmd.name.clone(), cmd.rtype, cmd.rclass)
.instrument(tracing::span!(Level::INFO, "dns_client.resolve"))
.await
{
Ok(r) => r,
Err(e) => {
tracing::error!(message = "unable to perform DNS query", name = %cmd.name, err = %e);
return ExitCode::FAILURE;
}
};
if cmd.raw {
write_binary_message(&response).await
} else {
write_text_message(&response).await
}
}
/// Run the `read` subcommand.
///
/// Reads all of stdin, decodes it as a single wire-format DNS message and writes the
/// message to stdout as text.
///
/// Returns [`ExitCode::FAILURE`] when stdin cannot be read or the bytes are not a valid
/// message (both are logged). Otherwise the exit code of `write_text_message` is
/// returned, so a failed write to stdout is reported to the caller instead of being
/// masked as success.
async fn run_read(_: &ReadCommand) -> ExitCode {
    let mut buf = Vec::new();
    let mut input = tokio::io::stdin();
    // `buf` starts empty, so after a successful read it holds exactly the bytes read.
    if let Err(e) = input.read_to_end(&mut buf).await {
        tracing::error!(message = "unable to read message from stdin", err = %e);
        return ExitCode::FAILURE;
    }

    let msg = match Message::read_network_bytes(Cursor::new(buf.as_slice())) {
        Ok(m) => m,
        Err(e) => {
            tracing::error!(message = "malformed message", err = %e);
            return ExitCode::FAILURE;
        }
    };

    write_text_message(&msg).await
}
/// Run the `write` subcommand.
///
/// Builds a message with a random id, the query and recursion-desired flags set, and a
/// single question for the fully qualified form of the requested name, then writes its
/// wire encoding to stdout. Returns the exit code of `write_binary_message`.
async fn run_write(cmd: &WriteCommand) -> ExitCode {
    let id = MessageId::random();
    let fqdn = cmd.name.clone().to_fqdn();

    let flags = Flags::default().set_query().set_recursion_desired();
    let header = Message::new(id, flags);
    let question = Question::new(fqdn, cmd.rtype).set_qclass(cmd.rclass);
    let query = header.add_question(question);

    write_binary_message(&query).await
}
/// Encode `msg` to DNS wire format and write the bytes to stdout.
///
/// Returns [`ExitCode::FAILURE`] (after logging the error) when the message cannot be
/// encoded or stdout cannot be written, and [`ExitCode::SUCCESS`] otherwise.
async fn write_binary_message(msg: &Message) -> ExitCode {
    let mut buf = Vec::new();
    if let Err(e) = msg.write_network_bytes(&mut buf) {
        tracing::error!(message = "unable to encode message to wire format", err = %e);
        return ExitCode::FAILURE;
    }

    let mut out = tokio::io::stdout();
    // tokio may complete `write_all` before the underlying write to stdout has actually
    // run. Flushing waits for it, so a failure (e.g. a closed pipe) is reported here
    // instead of being lost.
    let written = match out.write_all(&buf).await {
        Ok(()) => out.flush().await,
        Err(e) => Err(e),
    };
    if let Err(e) = written {
        tracing::error!(message = "unable to write wire encoded message to stdout", err = %e);
        return ExitCode::FAILURE;
    }

    ExitCode::SUCCESS
}
/// Render `msg` as human readable text and write it to stdout.
///
/// The text consists of the header followed by the question, answer, authority and
/// additional sections, in that order.
///
/// Returns [`ExitCode::FAILURE`] (after logging the error) when stdout cannot be
/// written, and [`ExitCode::SUCCESS`] otherwise.
async fn write_text_message(msg: &Message) -> ExitCode {
    let mut buf = String::new();
    format_header(&mut buf, msg);
    format_question(&mut buf, msg.questions());
    format_answer(&mut buf, msg.answers());
    format_authority(&mut buf, msg.authority());
    format_extra(&mut buf, msg.extra());

    let mut out = tokio::io::stdout();
    // tokio may complete `write_all` before the underlying write to stdout has actually
    // run. Flushing waits for it, so a failure (e.g. a closed pipe) is reported here
    // instead of being lost.
    let written = match out.write_all(buf.as_bytes()).await {
        Ok(()) => out.flush().await,
        Err(e) => Err(e),
    };
    if let Err(e) = written {
        tracing::error!(message = "unable to write human readable message to stdout", err = %e);
        return ExitCode::FAILURE;
    }

    ExitCode::SUCCESS
}
/// Append the two header lines for `msg` to `buf`: one with the opcode, response code
/// and message id, and one with the debug rendering of the flags.
fn format_header(buf: &mut String, msg: &Message) {
    // Results of writing into the `String` are intentionally discarded.
    let _ = writeln!(
        buf,
        ";; >>HEADER<< opcode: {opcode:?}, status: {status:?}, id: {id}",
        opcode = msg.flags().get_op_code(),
        status = msg.flags().get_response_code(),
        id = msg.id(),
    );
    let _ = writeln!(buf, ";; flags: {flags:?}", flags = msg.flags());
}
/// Append the question section to `buf`: a section banner followed by one line per
/// question with its name, class and type.
fn format_question(buf: &mut String, questions: &[Question]) {
    buf.push_str(";; QUESTION SECTION:");
    buf.push('\n');

    for question in questions {
        let (name, class, kind) = (question.name(), question.qclass(), question.qtype());
        // The result of writing into the `String` is intentionally discarded.
        let _ = writeln!(buf, "; {}\t\t\t{}\t{}", name, class, kind);
    }
}
/// Append the authority section to `buf`: a section banner followed by `records`.
fn format_authority(buf: &mut String, records: &[Record]) {
    buf.push_str(";; AUTHORITY SECTION:");
    buf.push('\n');
    format_records(buf, records);
}
/// Append the answer section to `buf`: a section banner followed by `records`.
fn format_answer(buf: &mut String, records: &[Record]) {
    buf.push_str(";; ANSWER SECTION:");
    buf.push('\n');
    format_records(buf, records);
}
/// Append the additional section to `buf`: a section banner followed by `records`.
fn format_extra(buf: &mut String, records: &[Record]) {
    buf.push_str(";; ADDITIONAL SECTION:");
    buf.push('\n');
    format_records(buf, records);
}
/// Append one line per record to `buf`, with the name, TTL, class, type and record data
/// separated by tabs.
fn format_records(buf: &mut String, records: &[Record]) {
    for record in records {
        let name = record.name();
        let ttl = record.ttl();
        let class = record.rclass();
        let kind = record.rtype();
        let data = record.rdata();
        // The result of writing into the `String` is intentionally discarded.
        let _ = writeln!(buf, "{}\t\t{}\t{}\t{}\t{}", name, ttl, class, kind, data);
    }
}
/// Print the ping results to stdout as two `key=value` lines: one with the timing
/// figures in seconds (six decimal places) and one with the query and error counters.
fn print_ping_results(results: &Bundle) {
    let timing = &results.timing;
    let min = timing.min.as_secs_f64();
    let max = timing.max.as_secs_f64();
    let avg = timing.avg.as_secs_f64();
    let std_dev = timing.std_dev.as_secs_f64();
    println!(
        "type=timing min={:.6}s max={:.6}s avg={:.6}s stddev={:.6}s",
        min, max, avg, std_dev,
    );

    println!(
        "type=queries total={} protocol_errors={} fatal_errors={}",
        results.total, results.protocol_errors, results.fatal_errors,
    );
}