use std::error;
use std::fmt;
use std::io;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, UdpSocket};
use std::time::Duration;
/// Resolver addresses used by `DnsConfig::default()`, listed in the order
/// in which they are tried.
///
/// Entries are bare IPv4 addresses; the port is appended by the lookup code.
pub const DEFAULT_DNS_SERVERS: &[&str] = &["1.1.1.1", "1.0.0.1", "8.8.8.8", "8.8.4.4"];
/// Record type code for an IPv4 address (`A`) record.
pub const DNS_TYPE_A: u16 = 1;
/// Record type code for a mail exchanger (`MX`) record.
pub const DNS_TYPE_MX: u16 = 15;
/// Record type code for an IPv6 address (`AAAA`) record.
pub const DNS_TYPE_AAAA: u16 = 28;
/// Failures that can occur while building, sending, or parsing DNS messages.
#[derive(Debug)]
pub enum Error {
    /// An underlying socket operation failed.
    Io(io::Error),
    /// No configured server produced a response.
    Timeout,
    /// The response header carried a non-zero response code (the low four
    /// bits of the flags word); the code is stored in the variant.
    ServerError(u16),
    /// The response contained no record of the requested kind.
    NoRecordsFound,
    /// A packet (or a domain name to encode) violated the expected layout.
    MalformedPacket,
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Error::Io(err) => write!(f, "IO error: {}", err),
Error::Timeout => write!(f, "DNS query timed out"),
Error::ServerError(code) => write!(f, "DNS server returned error code: {}", code),
Error::NoRecordsFound => write!(f, "No DNS records found"),
Error::MalformedPacket => write!(f, "Malformed DNS packet"),
}
}
}
impl error::Error for Error {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
match self {
Error::Io(err) => Some(err),
_ => None,
}
}
}
impl From<io::Error> for Error {
fn from(err: io::Error) -> Self {
Error::Io(err)
}
}
/// The six 16-bit fields found in the first 12 bytes of a DNS message.
#[derive(Debug)]
pub struct DnsHeader {
    /// Message identifier (bytes 0-1).
    pub id: u16,
    /// Flags word (bytes 2-3); its low four bits hold the response code.
    pub flags: u16,
    /// Number of entries in the question section (bytes 4-5).
    pub questions: u16,
    /// Number of records in the answer section (bytes 6-7).
    pub answers: u16,
    /// Number of records in the authority section (bytes 8-9).
    pub authorities: u16,
    /// Number of records in the additional section (bytes 10-11).
    pub additionals: u16,
}
pub fn parse_dns_header(buffer: &[u8]) -> Result<DnsHeader, Error> {
if buffer.len() < 12 {
return Err(Error::MalformedPacket);
}
let header = DnsHeader {
id: u16::from_be_bytes([buffer[0], buffer[1]]),
flags: u16::from_be_bytes([buffer[2], buffer[3]]),
questions: u16::from_be_bytes([buffer[4], buffer[5]]),
answers: u16::from_be_bytes([buffer[6], buffer[7]]),
authorities: u16::from_be_bytes([buffer[8], buffer[9]]),
additionals: u16::from_be_bytes([buffer[10], buffer[11]]),
};
let rcode = header.flags & 0x0F;
if rcode != 0 {
return Err(Error::ServerError(rcode));
}
Ok(header)
}
pub fn skip_question(buffer: &[u8], mut pos: usize) -> Result<usize, Error> {
if pos >= buffer.len() {
return Err(Error::MalformedPacket);
}
while pos < buffer.len() {
let len = buffer[pos] as usize;
if len == 0 {
pos += 1;
break;
}
if len & 0xC0 == 0xC0 {
pos += 2;
break;
}
pos += len + 1;
if pos >= buffer.len() {
return Err(Error::MalformedPacket);
}
}
if pos + 4 > buffer.len() {
return Err(Error::MalformedPacket);
}
Ok(pos + 4)
}
/// Decoded payload of a resource record.
#[derive(Debug, Clone)]
pub enum RecordData {
    /// Mail exchanger: a 16-bit priority followed by the exchanger's name.
    MX {
        /// Priority value read from the first two payload bytes.
        priority: u16,
        /// Host name of the mail server, labels joined with `.`.
        server: String,
    },
    /// IPv4 address taken from a 4-byte payload.
    A(Ipv4Addr),
    /// IPv6 address taken from a 16-byte payload.
    AAAA(Ipv6Addr),
    /// Any record type whose payload is not decoded.
    Unknown,
}
/// One resource record: its numeric type code and the decoded payload.
#[derive(Debug)]
pub struct DnsRecord {
    /// Type code as it appeared on the wire (for example 1, 15 or 28).
    pub record_type: u16,
    /// Payload decoded according to `record_type`.
    pub data: RecordData,
}
pub fn parse_answer(buffer: &[u8], mut pos: usize) -> Result<(DnsRecord, usize), Error> {
if pos >= buffer.len() {
return Err(Error::MalformedPacket);
}
while pos < buffer.len() {
let len = buffer[pos] as usize;
if len == 0 {
pos += 1;
break;
}
if len & 0xC0 == 0xC0 {
pos += 2;
break;
}
pos += len + 1;
if pos >= buffer.len() {
return Err(Error::MalformedPacket);
}
}
if pos + 10 > buffer.len() {
return Err(Error::MalformedPacket);
}
let record_type = u16::from_be_bytes([buffer[pos], buffer[pos + 1]]);
pos += 4;
pos += 4;
let data_len = u16::from_be_bytes([buffer[pos], buffer[pos + 1]]) as usize;
pos += 2;
if pos + data_len > buffer.len() {
return Err(Error::MalformedPacket);
}
let data = match record_type {
DNS_TYPE_MX => {
if data_len < 2 {
return Err(Error::MalformedPacket);
}
let priority = u16::from_be_bytes([buffer[pos], buffer[pos + 1]]);
pos += 2;
let server = parse_dns_name(buffer, pos)?;
pos += data_len - 2;
RecordData::MX { priority, server }
}
DNS_TYPE_A => {
if data_len != 4 {
return Err(Error::MalformedPacket);
}
let ipv4 = Ipv4Addr::new(
buffer[pos],
buffer[pos + 1],
buffer[pos + 2],
buffer[pos + 3],
);
pos += 4;
RecordData::A(ipv4)
}
DNS_TYPE_AAAA => {
if data_len != 16 {
return Err(Error::MalformedPacket);
}
let mut ipv6_bytes = [0u8; 16];
ipv6_bytes.copy_from_slice(&buffer[pos..pos + 16]);
let ipv6 = Ipv6Addr::from(ipv6_bytes);
pos += 16;
RecordData::AAAA(ipv6)
}
_ => {
pos += data_len;
RecordData::Unknown
}
};
Ok((DnsRecord { record_type, data }, pos))
}
pub fn parse_dns_name(buffer: &[u8], mut pos: usize) -> Result<String, Error> {
if pos >= buffer.len() {
return Err(Error::MalformedPacket);
}
let mut name = String::new();
let mut first = true;
let mut jumps = 0;
const MAX_JUMPS: usize = 10;
loop {
if pos >= buffer.len() {
return Err(Error::MalformedPacket);
}
let len = buffer[pos] as usize;
pos += 1;
if len == 0 {
break;
}
if len & 0xC0 == 0xC0 {
if pos >= buffer.len() {
return Err(Error::MalformedPacket);
}
jumps += 1;
if jumps > MAX_JUMPS {
return Err(Error::MalformedPacket);
}
let offset = ((len & 0x3F) << 8) | buffer[pos] as usize;
if !first {
name.push('.');
}
let remainder = parse_dns_name(buffer, offset)?;
name.push_str(&remainder);
break;
}
if pos + len > buffer.len() {
return Err(Error::MalformedPacket);
}
if !first {
name.push('.');
}
first = false;
name.push_str(&String::from_utf8_lossy(&buffer[pos..pos + len]));
pos += len;
}
Ok(name)
}
/// A mail exchanger entry returned by an MX lookup.
#[derive(Debug, Clone, PartialEq)]
pub struct MxRecord {
    /// Priority value of the exchanger; MX results are sorted ascending by it.
    pub priority: u16,
    /// Host name of the mail server.
    pub server: String,
}
/// A mail server name paired with the addresses it resolved to.
#[derive(Debug, Clone, PartialEq)]
pub struct ServerIpRecord {
    /// Host name of the mail server.
    pub server: String,
    /// Addresses found for `server` (IPv4 and/or IPv6).
    pub ip_addresses: Vec<IpAddr>,
}
/// Resolver settings used by the lookup functions.
#[derive(Debug, Clone)]
pub struct DnsConfig {
    /// Server addresses to query, tried in order; port 53 is appended to each.
    pub servers: Vec<String>,
    /// Read timeout for each query, in seconds.
    pub timeout: u64,
}
impl Default for DnsConfig {
fn default() -> Self {
Self {
servers: DEFAULT_DNS_SERVERS.iter().map(|s| s.to_string()).collect(),
timeout: 5,
}
}
}
/// Builds a single-question DNS query packet for `domain`.
///
/// The packet consists of a fixed 12-byte header (ID 1, flags `0x0100`, one
/// question, no other records), the encoded name, `record_type`, and class
/// `0x0001`. Empty labels (for example from a trailing or doubled `.`) are
/// ignored.
///
/// # Errors
///
/// [`Error::MalformedPacket`] if a label is longer than 63 bytes or if the
/// encoded name (length bytes and terminating zero included) exceeds 255
/// bytes, the limits of the DNS wire format.
pub fn build_dns_query(domain: &str, record_type: u16) -> Result<Vec<u8>, Error> {
    const HEADER: [u8; 12] = [
        0x00, 0x01, // ID
        0x01, 0x00, // flags
        0x00, 0x01, // question count
        0x00, 0x00, // answer count
        0x00, 0x00, // authority count
        0x00, 0x00, // additional count
    ];
    const MAX_LABEL_LEN: usize = 63;
    const MAX_NAME_LEN: usize = 255;

    // Header + name (at most one length byte more than the text, plus the
    // terminator) + type and class.
    let mut packet = Vec::with_capacity(HEADER.len() + domain.len() + 2 + 4);
    packet.extend_from_slice(&HEADER);

    for label in domain.split('.').filter(|label| !label.is_empty()) {
        if label.len() > MAX_LABEL_LEN {
            return Err(Error::MalformedPacket);
        }
        // Cannot truncate: the length was just checked to be at most 63.
        packet.push(label.len() as u8);
        packet.extend_from_slice(label.as_bytes());
    }
    packet.push(0);

    // Reject names that no server could accept rather than sending them.
    if packet.len() - HEADER.len() > MAX_NAME_LEN {
        return Err(Error::MalformedPacket);
    }

    packet.extend_from_slice(&record_type.to_be_bytes());
    packet.extend_from_slice(&[0x00, 0x01]);
    Ok(packet)
}
pub fn parse_mx_records(buffer: &[u8]) -> Result<Vec<MxRecord>, Error> {
let mut records = Vec::new();
let header = parse_dns_header(buffer)?;
if header.answers == 0 {
return Err(Error::NoRecordsFound);
}
let mut pos = 12; for _ in 0..header.questions {
pos = skip_question(buffer, pos)?;
}
for _ in 0..header.answers {
let (record, new_pos) = parse_answer(buffer, pos)?;
pos = new_pos;
if record.record_type == DNS_TYPE_MX {
if let RecordData::MX { priority, server } = record.data {
records.push(MxRecord { priority, server });
}
}
}
if records.is_empty() {
return Err(Error::NoRecordsFound);
}
records.sort_by_key(|r| r.priority);
Ok(records)
}
pub fn parse_ip_records(buffer: &[u8]) -> Result<Vec<IpAddr>, Error> {
let mut ips = Vec::new();
let header = parse_dns_header(buffer)?;
let mut pos = 12; for _ in 0..header.questions {
pos = skip_question(buffer, pos)?;
}
for _ in 0..header.answers {
let (record, new_pos) = parse_answer(buffer, pos)?;
pos = new_pos;
match record.data {
RecordData::A(ipv4) => ips.push(IpAddr::V4(ipv4)),
RecordData::AAAA(ipv6) => ips.push(IpAddr::V6(ipv6)),
_ => {}
}
}
if ips.is_empty() {
return Err(Error::NoRecordsFound);
}
Ok(ips)
}
pub fn lookup_dns_records(
domain: &str,
record_type: u16,
config: Option<DnsConfig>,
) -> Result<Vec<u8>, Error> {
let config = config.unwrap_or_default();
let query = build_dns_query(domain, record_type)?;
for server in &config.servers {
let server_addr = format!("{}:53", server);
match try_dns_query(&server_addr, &query, config.timeout) {
Ok(response) => return Ok(response),
Err(Error::Timeout) | Err(Error::Io(_)) => continue, Err(e) => return Err(e),
}
}
Err(Error::Timeout)
}
fn try_dns_query(server: &str, query: &[u8], timeout: u64) -> Result<Vec<u8>, Error> {
let socket = UdpSocket::bind("0.0.0.0:0")?;
socket.set_read_timeout(Some(Duration::from_secs(timeout)))?;
socket.connect(server)?;
socket.send(query)?;
let mut buffer = [0; 512];
let size = socket.recv(&mut buffer)?;
Ok(buffer[..size].to_vec())
}
pub fn lookup_mx_records(domain: &str) -> Result<Vec<MxRecord>, Error> {
lookup_mx_records_with_config(domain, None)
}
/// Looks up the MX records of `domain`, sorted by ascending priority.
///
/// `config` selects the servers and timeout; `None` uses the defaults.
///
/// # Errors
///
/// Any error from sending the query or from parsing the response,
/// including [`Error::NoRecordsFound`] when the response has no MX record.
pub fn lookup_mx_records_with_config(
    domain: &str,
    config: Option<DnsConfig>,
) -> Result<Vec<MxRecord>, Error> {
    lookup_dns_records(domain, DNS_TYPE_MX, config)
        .and_then(|response| parse_mx_records(&response))
}
pub fn lookup_ip_addresses(hostname: &str) -> Result<Vec<IpAddr>, Error> {
lookup_ip_addresses_with_config(hostname, None)
}
pub fn lookup_ip_addresses_with_config(
hostname: &str,
config: Option<DnsConfig>,
) -> Result<Vec<IpAddr>, Error> {
let mut ips = Vec::new();
match lookup_dns_records(hostname, DNS_TYPE_A, config.clone()) {
Ok(response) => match parse_ip_records(&response) {
Ok(v4_ips) => ips.extend(v4_ips),
Err(Error::NoRecordsFound) => {} Err(e) => return Err(e),
},
Err(Error::NoRecordsFound) => {} Err(e) => return Err(e),
}
match lookup_dns_records(hostname, DNS_TYPE_AAAA, config) {
Ok(response) => match parse_ip_records(&response) {
Ok(v6_ips) => ips.extend(v6_ips),
Err(Error::NoRecordsFound) => {} Err(e) => return Err(e),
},
Err(Error::NoRecordsFound) => {} Err(e) => return Err(e),
}
if ips.is_empty() {
return Err(Error::NoRecordsFound);
}
Ok(ips)
}
pub fn resolve_mx_server_ips(domain: &str) -> Result<Vec<ServerIpRecord>, Error> {
resolve_mx_server_ips_with_config(domain, None)
}
pub fn resolve_mx_server_ips_with_config(
domain: &str,
config: Option<DnsConfig>,
) -> Result<Vec<ServerIpRecord>, Error> {
let mx_records = lookup_mx_records_with_config(domain, config.clone())?;
let mut server_ips = Vec::new();
for mx in mx_records {
match lookup_ip_addresses_with_config(&mx.server, config.clone()) {
Ok(ips) => {
server_ips.push(ServerIpRecord {
server: mx.server,
ip_addresses: ips,
});
}
Err(Error::NoRecordsFound) => {} Err(e) => return Err(e),
}
}
if server_ips.is_empty() {
return Err(Error::NoRecordsFound);
}
Ok(server_ips)
}