use alloc::{borrow::ToOwned, boxed::Box};
use core::{cmp::Ordering, convert::TryFrom, fmt};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "__dnssec")]
use crate::dnssec::{Proof, Proven};
#[cfg(test)]
use crate::rr::rdata::A;
use crate::{
error::ProtoResult,
rr::{Name, RData, RecordData, RecordType, dns_class::DNSClass},
serialize::binary::{
BinDecodable, BinDecoder, BinEncodable, BinEncoder, DecodeError, Restrict,
},
};
/// Mask for the most significant bit of the CLASS field.
///
/// mDNS (RFC 6762 §10.2) uses this bit as the "cache-flush" flag rather than as part of the
/// class value, so it is split off on read and OR-ed back in on write.
#[cfg(feature = "mdns")]
const MDNS_ENABLE_CACHE_FLUSH: u16 = 0x8000;
/// A DNS resource record: owner name, class, TTL and record data.
///
/// `R` defaults to [`RData`]. Any other `RecordData` type can be used when the kind of data
/// is known statically.
///
/// `Eq` is derived, but `PartialEq` is implemented by hand and does not compare the TTL.
#[non_exhaustive]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[derive(Eq, Debug, Clone)]
pub struct Record<R: RecordData = RData> {
/// Owner name of the record.
pub name: Name,
/// Class of the record (e.g. `DNSClass::IN`).
pub dns_class: DNSClass,
/// Time-to-live, as written to and read from the TTL field of the wire format.
pub ttl: u32,
/// Record data. The record type is not stored separately; it is derived from this value.
pub data: R,
/// mDNS cache-flush flag, carried in the top bit of the CLASS field on the wire.
#[cfg(feature = "mdns")]
pub mdns_cache_flush: bool,
/// DNSSEC proof status attached to this record. It is never written to the wire, and
/// decoding initialises it to `Proof::default()`.
#[cfg(feature = "__dnssec")]
pub proof: Proof,
}
impl Record {
    /// Test-only placeholder record: root owner name, class `IN`, TTL 0 and `A 0.0.0.0` data.
    #[cfg(test)]
    pub(crate) fn stub() -> Self {
        let root = Name::from_ascii(".").unwrap();
        Self::from_rdata(root, 0, RData::A(A::new(0, 0, 0, 0)))
    }
}
impl Record {
    /// Creates a class-`IN` record whose data is `RData::Update0(rr_type)`.
    ///
    /// Such a record names a record type but carries no RDATA of its own. It is the same
    /// value that decoding produces for a zero-length RDATA.
    pub fn update0(name: Name, ttl: u32, rr_type: RecordType) -> Self {
        Self::from_rdata(name, ttl, RData::Update0(rr_type))
    }

    /// Borrows this record with its data viewed as the concrete type `T`.
    ///
    /// Returns `None` when `T::try_borrow` rejects the record's data.
    pub fn try_borrow<T: RecordData>(&self) -> Option<RecordRef<'_, T>> {
        RecordRef::<T>::try_from(self).ok()
    }
}
impl<R: RecordData> Record<R> {
    /// Creates a class-`IN` record from an owner name, TTL and record data.
    ///
    /// The mDNS cache-flush flag (if compiled in) starts cleared. The DNSSEC proof (if
    /// compiled in) starts at `Proof::default()`.
    pub fn from_rdata(name: Name, ttl: u32, rdata: R) -> Self {
        let dns_class = DNSClass::IN;
        Self {
            name,
            dns_class,
            ttl,
            data: rdata,
            #[cfg(feature = "mdns")]
            mdns_cache_flush: false,
            #[cfg(feature = "__dnssec")]
            proof: Proof::default(),
        }
    }

    /// Converts the record data with `f` and keeps every other field unchanged.
    ///
    /// Returns `None` if `f` returns `None`.
    pub fn map<N: RecordData>(self, f: impl FnOnce(R) -> Option<N>) -> Option<Record<N>> {
        let data = f(self.data)?;
        Some(Record {
            name: self.name,
            dns_class: self.dns_class,
            ttl: self.ttl,
            data,
            #[cfg(feature = "mdns")]
            mdns_cache_flush: self.mdns_cache_flush,
            #[cfg(feature = "__dnssec")]
            proof: self.proof,
        })
    }

    /// Converts the typed data back into the general [`RData`] enum and keeps every other
    /// field unchanged.
    pub fn into_record_of_rdata(self) -> Record<RData> {
        let data = R::into_rdata(self.data);
        Record {
            name: self.name,
            dns_class: self.dns_class,
            ttl: self.ttl,
            data,
            #[cfg(feature = "mdns")]
            mdns_cache_flush: self.mdns_cache_flush,
            #[cfg(feature = "__dnssec")]
            proof: self.proof,
        }
    }

    /// Lowers the TTL by `offset`, stopping at zero. Returns `self` so calls can be chained.
    pub fn decrement_ttl(&mut self, offset: u32) -> &mut Self {
        let reduced = self.ttl.saturating_sub(offset);
        self.ttl = reduced;
        self
    }

    /// Record type, derived from the record data.
    #[inline]
    pub fn record_type(&self) -> RecordType {
        self.data.record_type()
    }
}
impl<R: RecordData> BinEncodable for Record<R> {
fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
self.name.emit(encoder)?;
self.record_type().emit(encoder)?;
#[cfg(not(feature = "mdns"))]
self.dns_class.emit(encoder)?;
#[cfg(feature = "mdns")]
{
if self.mdns_cache_flush {
encoder.emit_u16(u16::from(self.dns_class) | MDNS_ENABLE_CACHE_FLUSH)?;
} else {
self.dns_class.emit(encoder)?;
}
}
encoder.emit_u32(self.ttl)?;
let place = encoder.place::<u16>()?;
if !self.data.is_update() {
self.data.emit(encoder)?;
}
let len = encoder.len_since_place(&place);
assert!(len <= u16::MAX as usize);
place.replace(encoder, len as u16)?;
Ok(())
}
}
impl<'r> BinDecodable<'r> for Record<RData> {
    /// Reads one resource record in wire format: NAME, TYPE, CLASS, TTL, RDLENGTH, RDATA.
    ///
    /// Decoding applies these rules:
    /// - For `OPT` records the owner must be the root name, and the CLASS field is decoded
    ///   with `DNSClass::for_opt`.
    /// - With the `mdns` feature, the top bit of CLASS is split off into `mdns_cache_flush`.
    /// - An RDLENGTH of zero yields `RData::Update0(record_type)`. Otherwise exactly RDLENGTH
    ///   bytes are handed to `RData::read`.
    /// - The DNSSEC proof, if compiled in, starts at `Proof::default()`.
    ///
    /// # Errors
    ///
    /// - `DecodeError::EdnsNameNotRoot` if an `OPT` record has a non-root owner name.
    /// - `DecodeError::IncorrectRDataLengthRead` if RDLENGTH exceeds the remaining bytes.
    /// - Any error raised while decoding the individual fields or the RDATA.
    fn read(decoder: &mut BinDecoder<'r>) -> Result<Self, DecodeError> {
        let name = Name::read(decoder)?;
        let record_type = RecordType::read(decoder)?;

        #[cfg(feature = "mdns")]
        let mut mdns_cache_flush = false;

        let dns_class = if record_type == RecordType::OPT {
            // EDNS (OPT) records must be owned by the root name; their CLASS field is not a
            // real class and is interpreted by `DNSClass::for_opt`.
            if !name.is_root() {
                return Err(DecodeError::EdnsNameNotRoot(Box::new(name)));
            }
            DNSClass::for_opt(decoder.read_u16()?.unverified())
        } else {
            #[cfg(not(feature = "mdns"))]
            {
                DNSClass::read(decoder)?
            }
            #[cfg(feature = "mdns")]
            {
                let raw_class = decoder.read_u16()?.unverified();
                // The top bit is the mDNS cache-flush flag, not part of the class value.
                mdns_cache_flush = raw_class & MDNS_ENABLE_CACHE_FLUSH != 0;
                DNSClass::from(raw_class & !MDNS_ENABLE_CACHE_FLUSH)
            }
        };

        let ttl = decoder.read_u32()?.unverified();

        // RDLENGTH must not run past the end of the available input.
        let rd_length = decoder
            .read_u16()?
            .verify_unwrap(|len| usize::from(*len) <= decoder.len())
            .map_err(|len| DecodeError::IncorrectRDataLengthRead {
                read: decoder.len(),
                len: usize::from(len),
            })?;

        let data = match rd_length {
            // No RDATA at all: keep only the record type.
            0 => RData::Update0(record_type),
            len => RData::read(decoder, record_type, Restrict::new(len))?,
        };

        Ok(Self {
            name,
            dns_class,
            ttl,
            data,
            #[cfg(feature = "mdns")]
            mdns_cache_flush,
            #[cfg(feature = "__dnssec")]
            proof: Proof::default(),
        })
    }
}
impl<R: RecordData> fmt::Display for Record<R> {
    /// Formats the record as space-separated `NAME TTL CLASS TYPE RDATA`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        write!(
            f,
            "{} {} {} {} {}",
            self.name,
            self.ttl,
            self.dns_class,
            self.record_type(),
            self.data,
        )
    }
}
impl<R: RecordData> PartialEq for Record<R> {
    /// Two records are equal when their owner name, class and data are equal.
    ///
    /// The TTL is left out, following RFC 2136 §1.1.1, which explicitly excludes the TTL
    /// from RR equality. The record type is covered by the data it is derived from. The
    /// mDNS/DNSSEC metadata fields are not compared either.
    fn eq(&self, other: &Self) -> bool {
        if self.name != other.name {
            return false;
        }
        if self.dns_class != other.dns_class {
            return false;
        }
        self.data == other.data
    }
}
/// Compares field `$field` of `$lhs` and `$rhs`. If they differ, it returns that ordering
/// from the enclosing function; otherwise execution falls through to the next statement.
macro_rules! compare_or_equal {
    ($lhs:ident, $rhs:ident, $field:ident) => {
        match ($lhs).$field.cmp(&($rhs).$field) {
            Ordering::Equal => {}
            unequal => return unequal,
        }
    };
}
impl Ord for Record {
    /// Orders records by owner name, then record type, then class, then record data.
    ///
    /// The TTL is deliberately not part of the ordering. `PartialEq` ignores it, and `Ord`
    /// must agree with `PartialEq`: records that compare `==` must compare
    /// `Ordering::Equal`. Otherwise sorted collections (`BTreeSet`, `BTreeMap`) and
    /// `binary_search` give answers that contradict `==`. As a result, records differing
    /// only in TTL (or in mDNS/DNSSEC metadata) compare `Ordering::Equal`.
    fn cmp(&self, other: &Self) -> Ordering {
        compare_or_equal!(self, other, name);
        // The record type is derived from the data rather than stored as a field, so it
        // cannot go through `compare_or_equal!`, which accesses fields directly.
        match self.record_type().cmp(&other.record_type()) {
            Ordering::Equal => {}
            unequal => return unequal,
        }
        compare_or_equal!(self, other, dns_class);
        compare_or_equal!(self, other, data);
        Ordering::Equal
    }
}
impl PartialOrd<Self> for Record {
    /// Delegates to [`Ord::cmp`], so every pair of records is comparable.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
/// Wraps an owned record together with the proof it already carries.
#[cfg(feature = "__dnssec")]
impl From<Record> for Proven<Record> {
    fn from(record: Record) -> Self {
        // `Proof` is `Copy`, so it is read out before `record` is moved into the wrapper.
        Self::new(record.proof, record)
    }
}
/// Wraps a borrowed record together with a copy of the proof it carries.
#[cfg(feature = "__dnssec")]
impl<'a> From<&'a Record> for Proven<&'a Record> {
    fn from(record: &'a Record) -> Self {
        Self::new(record.proof, record)
    }
}
/// A borrowed view of a [`Record`] whose data has been narrowed to the concrete type `R`.
///
/// The owner name and the data are borrowed from the source record; the remaining fields
/// are copied.
pub struct RecordRef<'a, R: RecordData> {
    name: &'a Name,
    data: &'a R,
    dns_class: DNSClass,
    ttl: u32,
    #[cfg(feature = "mdns")]
    mdns_cache_flush: bool,
    #[cfg(feature = "__dnssec")]
    proof: Proof,
}
// Implemented by hand rather than derived: a derive would add `R: Clone`/`R: Copy` bounds,
// but the view only holds references and copyable scalars.
impl<R: RecordData> Copy for RecordRef<'_, R> {}

impl<R: RecordData> Clone for RecordRef<'_, R> {
    fn clone(&self) -> Self {
        let copy: Self = *self;
        copy
    }
}
impl<R: RecordData> RecordRef<'_, R> {
    /// Builds an owned [`Record`] by cloning the borrowed name and data.
    pub fn to_owned(&self) -> Record<R> {
        let Self {
            name,
            data,
            dns_class,
            ttl,
            #[cfg(feature = "mdns")]
            mdns_cache_flush,
            #[cfg(feature = "__dnssec")]
            proof,
        } = *self;

        Record {
            name: name.to_owned(),
            dns_class,
            ttl,
            data: R::clone(data),
            #[cfg(feature = "mdns")]
            mdns_cache_flush,
            #[cfg(feature = "__dnssec")]
            proof,
        }
    }

    /// Owner name of the record.
    #[inline]
    pub fn name(&self) -> &Name {
        self.name
    }

    /// Record type, derived from the borrowed data.
    #[inline]
    pub fn record_type(&self) -> RecordType {
        self.data.record_type()
    }

    /// Class of the record.
    #[inline]
    pub fn dns_class(&self) -> DNSClass {
        self.dns_class
    }

    /// Time-to-live copied from the source record.
    #[inline]
    pub fn ttl(&self) -> u32 {
        self.ttl
    }

    /// Borrowed, concretely typed record data.
    #[inline]
    pub fn data(&self) -> &R {
        self.data
    }

    /// Whether the mDNS cache-flush bit was set on the source record.
    #[cfg(feature = "mdns")]
    #[inline]
    pub fn mdns_cache_flush(&self) -> bool {
        self.mdns_cache_flush
    }

    /// DNSSEC proof copied from the source record.
    #[cfg(feature = "__dnssec")]
    #[inline]
    pub fn proof(&self) -> Proof {
        self.proof
    }
}
impl<'a, R: RecordData> TryFrom<&'a Record> for RecordRef<'a, R> {
    /// On failure, the original record reference is handed back unchanged.
    type Error = &'a Record;

    /// Narrows `record`'s data to `R` using `R::try_borrow`.
    ///
    /// # Errors
    ///
    /// Returns `Err(record)` when `R::try_borrow` rejects the record's data.
    fn try_from(record: &'a Record) -> Result<Self, Self::Error> {
        let data = R::try_borrow(&record.data).ok_or(record)?;
        Ok(Self {
            name: &record.name,
            dns_class: record.dns_class,
            ttl: record.ttl,
            data,
            #[cfg(feature = "mdns")]
            mdns_cache_flush: record.mdns_cache_flush,
            #[cfg(feature = "__dnssec")]
            proof: record.proof,
        })
    }
}
#[cfg(test)]
mod tests {
    #![allow(clippy::dbg_macro, clippy::print_stdout)]

    use alloc::vec::Vec;
    use core::cmp::Ordering;
    use core::str::FromStr;
    #[cfg(feature = "std")]
    use std::println;

    use super::*;
    use crate::rr::Name;
    use crate::rr::dns_class::DNSClass;
    use crate::rr::rdata::{A, AAAA};
    use crate::rr::record_data::RData;

    /// Encoding a record and decoding the bytes must give back an equal record.
    #[test]
    fn test_emit_and_read() {
        let original = Record::from_rdata(
            Name::from_str("www.example.com.").unwrap(),
            5,
            RData::A(A::new(192, 168, 0, 1)),
        );

        let mut wire: Vec<u8> = Vec::with_capacity(512);
        original.emit(&mut BinEncoder::new(&mut wire)).unwrap();

        let decoded = Record::read(&mut BinDecoder::new(&wire)).unwrap();
        assert_eq!(decoded, original);
    }

    /// Changing any single ordering key to a larger value must sort after the base record.
    #[test]
    fn test_order() {
        let base = Record::from_rdata(
            Name::from_str("www.example.com").unwrap(),
            5,
            RData::A(A::new(192, 168, 0, 1)),
        );

        let mut by_name = base.clone();
        by_name.name = Name::from_str("zzz.example.com").unwrap();

        let mut by_type = base.clone().into_record_of_rdata();
        by_type.data = RData::AAAA(AAAA::new(0, 0, 0, 0, 0, 0, 0, 0));

        let mut by_class = base.clone();
        by_class.dns_class = DNSClass::NONE;

        let mut by_rdata = base.clone();
        by_rdata.data = RData::A(A::new(192, 168, 0, 255));

        assert_eq!(base.clone(), base.clone());
        for greater in [&by_name, &by_type, &by_class, &by_rdata] {
            #[cfg(feature = "std")]
            println!("r, g: {base:?}, {greater:?}");
            assert_eq!(base.cmp(greater), Ordering::Less);
        }
    }

    /// The cache-flush flag must be written into, and read back from, the top CLASS bit.
    #[cfg(feature = "mdns")]
    #[test]
    fn test_mdns_cache_flush_bit_handling() {
        // The stub's owner is the root name (a single byte), followed by the two-byte TYPE,
        // so CLASS starts right after them.
        const RR_CLASS_OFFSET: usize = 1 + size_of::<u16>();

        let mut flushed = Record::<RData>::stub();
        flushed.mdns_cache_flush = true;

        let mut wire: Vec<u8> = Vec::with_capacity(512);
        {
            let mut encoder = BinEncoder::new(&mut wire);
            flushed.emit(&mut encoder).unwrap();
            let class_bytes = encoder.slice_of(RR_CLASS_OFFSET, RR_CLASS_OFFSET + 2);
            assert_eq!(class_bytes, &[0x80, 0x01]);
        }

        let decoded = Record::<RData>::read(&mut BinDecoder::new(&wire)).unwrap();
        assert_eq!(decoded.dns_class, DNSClass::IN);
        assert!(decoded.mdns_cache_flush);
    }
}