use crate::dns_string::DnsString;
use crate::{
read_exact, read_u16_be, read_u32_be, write_bytes, write_u16_be, write_u32_be, DnsClass,
DnsError, DnsName, DnsType,
};
use core::fmt::{Debug, Formatter};
use fixed_buffer::FixedBuf;
use std::convert::TryFrom;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
/// Succeeds only when `buf` has no unread bytes left.
///
/// Used after parsing fixed-layout record data to reject records that carry
/// trailing bytes.
///
/// # Errors
/// Returns [`DnsError::RecordHasAdditionalBytes`] when `buf` is not empty.
fn check_empty<const SIZE: usize>(buf: &FixedBuf<SIZE>) -> Result<(), DnsError> {
    if !buf.is_empty() {
        return Err(DnsError::RecordHasAdditionalBytes);
    }
    Ok(())
}
#[derive(Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum DnsRecord {
A(DnsName, std::net::Ipv4Addr),
AAAA(DnsName, std::net::Ipv6Addr),
CNAME(DnsName, DnsName),
NS(DnsName, DnsName),
TXT(DnsName, Vec<DnsString>),
Unknown(DnsName, DnsType),
}
impl DnsRecord {
    /// TTL, in seconds, that [`DnsRecord::write`] puts on every record.
    const DEFAULT_TTL_SECONDS: u32 = 300;

    /// Reads a length-prefixed RDATA section from `buf` and returns a copy of it.
    ///
    /// The section is a big-endian `u16` length followed by that many bytes. The
    /// bytes are copied into their own buffer, so record parsing can check that it
    /// consumed exactly this record's data.
    ///
    /// # Errors
    /// Propagates any error from reading the length prefix. Returns
    /// [`DnsError::Truncated`] when `buf` holds fewer bytes than the prefix announces.
    pub fn read_rdata<const N: usize>(buf: &mut FixedBuf<N>) -> Result<FixedBuf<65535>, DnsError> {
        let len = usize::from(read_u16_be(buf)?);
        if buf.len() < len {
            return Err(DnsError::Truncated);
        }
        let borrowed_rdata = buf.read_bytes(len);
        let mut rdata: FixedBuf<65535> = FixedBuf::new();
        // A u16 length always fits in a 65535-byte buffer, so this cannot fail.
        rdata
            .write_bytes(borrowed_rdata)
            .map_err(|_| DnsError::Unreachable(file!(), line!()))?;
        Ok(rdata)
    }

    /// Writes `bytes` to `out` as an RDATA section: a big-endian `u16` length
    /// followed by the bytes themselves.
    ///
    /// # Errors
    /// Returns [`DnsError::Internal`] when `bytes` is longer than 65535 bytes. Also
    /// propagates any error from `write_u16_be`/`write_bytes` on `out`.
    pub fn write_rdata<const N: usize>(
        bytes: &[u8],
        out: &mut FixedBuf<N>,
    ) -> Result<(), DnsError> {
        // Oversized rdata is reachable through the public API (e.g. a TXT record with
        // many strings), so report it descriptively instead of as `Unreachable`.
        let len = u16::try_from(bytes.len()).map_err(|_| {
            DnsError::Internal(format!(
                "rdata is {} bytes, longer than the 65535-byte maximum",
                bytes.len()
            ))
        })?;
        write_u16_be(out, len)?;
        write_bytes(out, bytes)?;
        Ok(())
    }

    /// Builds an `A` record from a textual name and IPv4 address.
    ///
    /// # Errors
    /// Returns a message when `name` is not a valid DNS name, when `ipv4_addr` does
    /// not parse as an IP address, or when it parses as an IPv6 address.
    pub fn new_a(name: &str, ipv4_addr: &str) -> Result<Self, String> {
        let dns_name = DnsName::new(name)?;
        let ip_addr: IpAddr = ipv4_addr
            .parse()
            .map_err(|e| format!("failed parsing {ipv4_addr:?} as an IP address: {e}"))?;
        match ip_addr {
            IpAddr::V4(addr) => Ok(Self::A(dns_name, addr)),
            IpAddr::V6(addr) => Err(format!(
                "cannot create an A record with ipv6 address {addr:?}"
            )),
        }
    }

    /// Builds an `AAAA` record from a textual name and IPv6 address.
    ///
    /// # Errors
    /// Returns a message when `name` is not a valid DNS name, when `ipv6_addr` does
    /// not parse as an IP address, or when it parses as an IPv4 address.
    pub fn new_aaaa(name: &str, ipv6_addr: &str) -> Result<Self, String> {
        let dns_name = DnsName::new(name)?;
        let ip_addr: IpAddr = ipv6_addr
            .parse()
            .map_err(|e| format!("failed parsing {ipv6_addr:?} as an IP address: {e}"))?;
        match ip_addr {
            IpAddr::V4(addr) => Err(format!(
                "cannot create an AAAA record with ipv4 address {addr:?}"
            )),
            IpAddr::V6(addr) => Ok(Self::AAAA(dns_name, addr)),
        }
    }

    /// Builds a `CNAME` record pointing `name` at `target`.
    ///
    /// # Errors
    /// Returns a message when either name is not a valid DNS name.
    pub fn new_cname(name: &str, target: &str) -> Result<Self, String> {
        let dns_name = DnsName::new(name)?;
        let dns_name_target = DnsName::new(target)?;
        Ok(Self::CNAME(dns_name, dns_name_target))
    }

    /// Builds an `NS` record delegating `name` to the name server `target`.
    ///
    /// # Errors
    /// Returns a message when either name is not a valid DNS name.
    pub fn new_ns(name: &str, target: &str) -> Result<Self, String> {
        let dns_name = DnsName::new(name)?;
        let dns_name_target = DnsName::new(target)?;
        Ok(Self::NS(dns_name, dns_name_target))
    }

    /// Builds a `TXT` record holding the single string `content`.
    ///
    /// # Errors
    /// Returns a message when `name` is not a valid DNS name or `content` is not a
    /// valid DNS string.
    pub fn new_txt(name: &str, content: &str) -> Result<Self, String> {
        let dns_name = DnsName::new(name)?;
        let dns_string = DnsString::new(content)?;
        Ok(Self::TXT(dns_name, vec![dns_string]))
    }

    /// Builds a `TXT` record holding one string per entry of `lines`, in order.
    ///
    /// # Errors
    /// Returns a message when `name` is not a valid DNS name, or the first error
    /// from an entry of `lines` that is not a valid DNS string.
    pub fn new_txt_multi(name: &str, lines: &[&str]) -> Result<Self, String> {
        let dns_name = DnsName::new(name)?;
        let dns_strings = lines
            .iter()
            .map(|line| DnsString::new(line))
            .collect::<Result<Vec<_>, _>>()?;
        Ok(Self::TXT(dns_name, dns_strings))
    }

    /// Returns the record's owner name.
    #[must_use]
    pub fn name(&self) -> &DnsName {
        match self {
            DnsRecord::A(dns_name, _)
            | DnsRecord::AAAA(dns_name, _)
            | DnsRecord::CNAME(dns_name, _)
            | DnsRecord::NS(dns_name, _)
            | DnsRecord::TXT(dns_name, _)
            | DnsRecord::Unknown(dns_name, _) => dns_name,
        }
    }

    /// Returns the record's type.
    ///
    /// For `Unknown` records this is always `DnsType::Unknown` carrying the stored
    /// type's numeric code, even when the stored type is a named one such as MX.
    #[must_use]
    pub fn typ(&self) -> DnsType {
        match self {
            DnsRecord::A(..) => DnsType::A,
            DnsRecord::AAAA(..) => DnsType::AAAA,
            DnsRecord::CNAME(..) => DnsType::CNAME,
            DnsRecord::NS(..) => DnsType::NS,
            DnsRecord::TXT(..) => DnsType::TXT,
            DnsRecord::Unknown(_, typ) => DnsType::Unknown(typ.num()),
        }
    }

    /// Reads one resource record from `buf`.
    ///
    /// The TTL is read and discarded. Records of MX, PTR, SOA, ANY and unrecognized
    /// types have their rdata skipped and are returned as [`DnsRecord::Unknown`].
    ///
    /// # Errors
    /// - [`DnsError::InvalidClass`] for classes other than Internet or Any.
    /// - [`DnsError::Truncated`] when the rdata runs past the end of `buf`.
    /// - [`DnsError::RecordHasAdditionalBytes`] when A/AAAA/CNAME/NS rdata has
    ///   trailing bytes.
    /// - Any error from parsing the name, type, class, TTL, or record data.
    pub fn read<const N: usize>(buf: &mut FixedBuf<N>) -> Result<Self, DnsError> {
        let name = DnsName::read(buf)?;
        let typ = DnsType::read(buf)?;
        let class = DnsClass::read(buf)?;
        if class != DnsClass::Internet && class != DnsClass::Any {
            return Err(DnsError::InvalidClass);
        }
        let _ttl_seconds = read_u32_be(buf)?;
        let mut rdata = Self::read_rdata(buf)?;
        let record = match typ {
            DnsType::A => {
                let octets: [u8; 4] = read_exact(&mut rdata)?;
                check_empty(&rdata)?;
                DnsRecord::A(name, Ipv4Addr::from(octets))
            }
            DnsType::AAAA => {
                let octets: [u8; 16] = read_exact(&mut rdata)?;
                check_empty(&rdata)?;
                DnsRecord::AAAA(name, Ipv6Addr::from(octets))
            }
            DnsType::CNAME => {
                let target = DnsName::read(&mut rdata)?;
                check_empty(&rdata)?;
                DnsRecord::CNAME(name, target)
            }
            DnsType::NS => {
                let target = DnsName::read(&mut rdata)?;
                check_empty(&rdata)?;
                DnsRecord::NS(name, target)
            }
            DnsType::TXT => {
                let strings = DnsString::read_multiple(&mut rdata)?;
                DnsRecord::TXT(name, strings)
            }
            DnsType::MX | DnsType::PTR | DnsType::SOA | DnsType::ANY | DnsType::Unknown(_) => {
                DnsRecord::Unknown(name, typ)
            }
        };
        Ok(record)
    }

    /// Writes this record to `out` in wire format, with class Internet and a fixed
    /// TTL of 300 seconds.
    ///
    /// # Errors
    /// Returns [`DnsError::Internal`] for [`DnsRecord::Unknown`] records, without
    /// writing anything to `out`. Also propagates errors from encoding the record
    /// or writing to `out`.
    pub fn write<const N: usize>(&self, out: &mut FixedBuf<N>) -> Result<(), DnsError> {
        // Reject unsupported records before touching `out`, so a failed call does
        // not leave a half-written record header in the caller's buffer.
        if matches!(self, DnsRecord::Unknown(..)) {
            return Err(DnsError::Internal(format!("cannot write record {self:?}")));
        }
        self.name().write(out)?;
        self.typ().write(out)?;
        DnsClass::Internet.write(out)?;
        write_u32_be(out, Self::DEFAULT_TTL_SECONDS)?;
        match self {
            DnsRecord::A(_, ipv4_addr) => Self::write_rdata(&ipv4_addr.octets(), out)?,
            DnsRecord::AAAA(_, ipv6_addr) => Self::write_rdata(&ipv6_addr.octets(), out)?,
            DnsRecord::CNAME(_, target_name) | DnsRecord::NS(_, target_name) => {
                Self::write_rdata(target_name.as_bytes()?.readable(), out)?;
            }
            DnsRecord::TXT(_, strings) => {
                // The rdata is the concatenation of every string's encoded bytes.
                let mut rdata = Vec::new();
                for string in strings {
                    rdata.extend_from_slice(string.as_bytes()?.readable());
                }
                Self::write_rdata(&rdata, out)?;
            }
            // Already rejected by the early return above.
            DnsRecord::Unknown(..) => return Err(DnsError::Unreachable(file!(), line!())),
        }
        Ok(())
    }
}
impl Debug for DnsRecord {
    /// Formats the record compactly as `DnsRecord::<TYPE>(<name>,<data>)`.
    ///
    /// TXT strings are shown as a quoted list, e.g. `DnsRecord::TXT(a.b,['s1', 's2'])`.
    /// A TXT record with no strings is shown as `[]`, which distinguishes it from a
    /// record holding one empty string (`['']`).
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), core::fmt::Error> {
        match self {
            DnsRecord::A(name, addr) => write!(f, "DnsRecord::A({name},{addr})"),
            DnsRecord::AAAA(name, addr) => write!(f, "DnsRecord::AAAA({name},{addr})"),
            DnsRecord::CNAME(name, target) => write!(f, "DnsRecord::CNAME({name},{target})"),
            DnsRecord::NS(name, target) => write!(f, "DnsRecord::NS({name},{target})"),
            DnsRecord::TXT(name, strings) => {
                write!(f, "DnsRecord::TXT({name},[")?;
                for (index, string) in strings.iter().enumerate() {
                    if index > 0 {
                        f.write_str(", ")?;
                    }
                    // Quote each string individually so an empty list prints as `[]`.
                    write!(f, "'{string}'")?;
                }
                f.write_str("])")
            }
            DnsRecord::Unknown(name, typ) => write!(f, "DnsRecord::Unknown({name},{typ})"),
        }
    }
}
#[cfg(test)]
#[test]
fn test_dns_record() {
    // Shorthand for building names that are known to be valid.
    let name = |text: &str| DnsName::new(text).unwrap();
    let ipv4 = Ipv4Addr::new(1, 2, 3, 4);
    let ipv6 = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0);

    // Each convenience constructor must build the same value as the matching variant.
    assert_eq!(
        DnsRecord::A(name("a.b"), ipv4),
        DnsRecord::new_a("a.b", "1.2.3.4").unwrap()
    );
    assert_eq!(
        DnsRecord::AAAA(name("a.b"), ipv6),
        DnsRecord::new_aaaa("a.b", "2001:db8::").unwrap()
    );
    assert_eq!(
        DnsRecord::CNAME(name("a.b"), name("c.d")),
        DnsRecord::new_cname("a.b", "c.d").unwrap()
    );
    assert_eq!(
        DnsRecord::NS(name("a.b"), name("c.d")),
        DnsRecord::new_ns("a.b", "c.d").unwrap()
    );
    assert_eq!(
        DnsRecord::TXT(name("a.b"), vec![DnsString::new("s1").unwrap()]),
        DnsRecord::new_txt("a.b", "s1").unwrap()
    );

    // Debug output uses the compact `DnsRecord::<TYPE>(<name>,<data>)` form.
    let debug = |record: DnsRecord| format!("{record:?}");
    assert_eq!("DnsRecord::A(a.b,1.2.3.4)", debug(DnsRecord::A(name("a.b"), ipv4)));
    assert_eq!("DnsRecord::AAAA(a.b,2001:db8::)", debug(DnsRecord::AAAA(name("a.b"), ipv6)));
    assert_eq!(
        "DnsRecord::CNAME(a.b,c.d)",
        debug(DnsRecord::CNAME(name("a.b"), name("c.d")))
    );
    assert_eq!("DnsRecord::NS(a.b,c.d)", debug(DnsRecord::NS(name("a.b"), name("c.d"))));
    assert_eq!("DnsRecord::TXT(a.b,['s1'])", debug(DnsRecord::new_txt("a.b", "s1").unwrap()));
    assert_eq!(
        "DnsRecord::TXT(a.b,['s1', 's2'])",
        debug(DnsRecord::new_txt_multi("a.b", &["s1", "s2"]).unwrap())
    );
}