use crate::ssh::SshClient;
use crate::tools::ToolsError;
#[derive(Debug, Clone)]
pub struct DnsQuery {
pub name: String,
pub record_type: DnsRecordType,
}
/// DNS record types this module knows how to query.
///
/// `Hash` is derived alongside `Eq` so the type can key a `HashMap`/`HashSet`
/// (e.g. when grouping answers per record type).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DnsRecordType {
    /// IPv4 address record.
    A,
    /// IPv6 address record.
    AAAA,
    /// Canonical-name (alias) record.
    CNAME,
    /// Mail-exchanger record.
    MX,
    /// Text record.
    TXT,
    /// Name-server record.
    NS,
}
impl DnsRecordType {
fn dig_arg(self) -> &'static str {
match self {
DnsRecordType::A => "A",
DnsRecordType::AAAA => "AAAA",
DnsRecordType::CNAME => "CNAME",
DnsRecordType::MX => "MX",
DnsRecordType::TXT => "TXT",
DnsRecordType::NS => "NS",
}
}
}
#[derive(Debug, Clone)]
pub struct DnsAnswer {
pub perspective: String,
pub query: String,
pub record_type: DnsRecordType,
pub answers: Vec<String>,
pub error: Option<String>,
pub elapsed_ms: u64,
}
pub async fn dns_resolve_local(query: &DnsQuery) -> DnsAnswer {
let started = std::time::Instant::now();
let perspective = "local".to_string();
match query.record_type {
DnsRecordType::A | DnsRecordType::AAAA => {
let target = format!("{}:0", query.name);
match tokio::net::lookup_host(target).await {
Ok(iter) => {
let want_v6 = query.record_type == DnsRecordType::AAAA;
let answers: Vec<String> = iter
.filter(|sa| sa.is_ipv6() == want_v6)
.map(|sa| sa.ip().to_string())
.collect();
DnsAnswer {
perspective,
query: query.name.clone(),
record_type: query.record_type,
answers,
error: None,
elapsed_ms: started.elapsed().as_millis() as u64,
}
}
Err(e) => DnsAnswer {
perspective,
query: query.name.clone(),
record_type: query.record_type,
answers: vec![],
error: Some(e.to_string()),
elapsed_ms: started.elapsed().as_millis() as u64,
},
}
}
_ => {
let dig_cmd = format!(
"dig +short {} {}",
shell_safe(&query.name),
query.record_type.dig_arg()
);
match tokio::process::Command::new("sh")
.arg("-c")
.arg(&dig_cmd)
.output()
.await
{
Ok(out) if out.status.success() => DnsAnswer {
perspective,
query: query.name.clone(),
record_type: query.record_type,
answers: parse_dig_lines(&String::from_utf8_lossy(&out.stdout)),
error: None,
elapsed_ms: started.elapsed().as_millis() as u64,
},
Ok(out) => DnsAnswer {
perspective,
query: query.name.clone(),
record_type: query.record_type,
answers: vec![],
error: Some(String::from_utf8_lossy(&out.stderr).to_string()),
elapsed_ms: started.elapsed().as_millis() as u64,
},
Err(e) => DnsAnswer {
perspective,
query: query.name.clone(),
record_type: query.record_type,
answers: vec![],
error: Some(e.to_string()),
elapsed_ms: started.elapsed().as_millis() as u64,
},
}
}
}
}
pub async fn dns_resolve_remote(
client: &SshClient,
perspective_label: &str,
query: &DnsQuery,
) -> DnsAnswer {
let started = std::time::Instant::now();
let cmd = format!(
"dig +short {} {}",
shell_safe(&query.name),
query.record_type.dig_arg()
);
match client.execute_command_full(&cmd).await {
Ok(out) if out.is_success() => DnsAnswer {
perspective: perspective_label.to_string(),
query: query.name.clone(),
record_type: query.record_type,
answers: parse_dig_lines(&out.stdout),
error: None,
elapsed_ms: started.elapsed().as_millis() as u64,
},
Ok(out) => {
let answers = parse_dig_lines(&out.stdout);
let err = if answers.is_empty() && !out.stderr.trim().is_empty() {
Some(out.stderr.trim().to_string())
} else if answers.is_empty() && out.exit_code.unwrap_or(0) != 0 {
Some(format!(
"dig exited {}",
out.exit_code
.map(|c| c.to_string())
.unwrap_or_else(|| "?".into())
))
} else {
None
};
DnsAnswer {
perspective: perspective_label.to_string(),
query: query.name.clone(),
record_type: query.record_type,
answers,
error: err,
elapsed_ms: started.elapsed().as_millis() as u64,
}
}
Err(e) => DnsAnswer {
perspective: perspective_label.to_string(),
query: query.name.clone(),
record_type: query.record_type,
answers: vec![],
error: Some(e.to_string()),
elapsed_ms: started.elapsed().as_millis() as u64,
},
}
}
/// Extracts answer lines from `dig +short` output.
///
/// Each line is trimmed. Blank lines and lines beginning with `;` (dig's own
/// diagnostics, e.g. `;; Truncated, retrying in TCP mode.`) are skipped.
fn parse_dig_lines(out: &str) -> Vec<String> {
    let mut answers = Vec::new();
    for raw in out.lines() {
        let line = raw.trim();
        if line.is_empty() || line.starts_with(';') {
            continue;
        }
        answers.push(line.to_owned());
    }
    answers
}
/// Returns a copy of `name` keeping only ASCII letters, digits, `.`, `-` and `_`.
///
/// Every shell metacharacter and whitespace is removed, so the result can be
/// interpolated into a shell command line unquoted. Non-ASCII characters are
/// dropped too. The result may be empty.
fn shell_safe(name: &str) -> String {
    let mut cleaned = name.to_owned();
    cleaned.retain(|ch: char| ch.is_ascii_alphanumeric() || ch == '.' || ch == '-' || ch == '_');
    cleaned
}
/// Identity function over `ToolsError`; nothing in the code shown here calls it.
///
/// NOTE(review): `ToolsError` is not otherwise referenced in this module. This
/// function presumably exists only to keep `use crate::tools::ToolsError` from being
/// reported as an unused import. Consider removing both if nothing else needs them.
#[allow(dead_code)] // intentionally never called; see the note above
fn _ensure_error_unused(e: ToolsError) -> ToolsError {
    std::convert::identity(e)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `+short` answers are kept; dig's `;;` diagnostic lines are not.
    #[test]
    fn parses_dig_short_output() {
        let dig_output = "1.2.3.4\n5.6.7.8\n;; Truncated, retrying in TCP mode.\n";
        let expected = ["1.2.3.4", "5.6.7.8"];
        assert_eq!(parse_dig_lines(dig_output), expected);
    }

    /// Shell metacharacters and whitespace are removed; `.` and `-` survive.
    #[test]
    fn shell_safe_strips_metachars() {
        let cases = [
            ("example.com", "example.com"),
            ("ex; rm -rf /", "exrm-rf"),
            ("foo$(bad)", "foobad"),
        ];
        for &(input, want) in cases.iter() {
            assert_eq!(shell_safe(input), want, "input: {:?}", input);
        }
    }
}