use core::fmt;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "__dnssec")]
use crate::dnssec::{Algorithm, SupportedAlgorithms};
use crate::{
error::*,
rr::{
DNSClass, Name, RData, Record, RecordType,
rdata::{
OPT,
opt::{EdnsCode, EdnsOption},
},
},
serialize::binary::{BinEncodable, BinEncoder},
};
/// EDNS pseudo-header state, carried on the wire as an OPT resource record.
///
/// The conversions to and from [`Record`] in this module pack `rcode_high`,
/// `version` and `flags` into the record's 32-bit TTL and carry
/// `max_payload` in the record's class field.
#[derive(Debug, PartialEq, Eq, Clone)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct Edns {
    /// Stored in the top 8 bits (31..24) of the OPT record TTL.
    rcode_high: u8,
    /// Stored in TTL bits 23..16.
    version: u8,
    /// Stored in the low 16 bits of the TTL (see [`EdnsFlags`]).
    flags: EdnsFlags,
    /// Carried in the OPT record's class field; `set_max_payload` raises
    /// values below 512 to 512.
    max_payload: u16,
    /// Option list, stored as the record's `RData::OPT` data.
    options: OPT,
}
impl Default for Edns {
    /// Returns an `Edns` with `rcode_high` and `version` zeroed, default
    /// flags, the default option set, and a `max_payload` of 512.
    fn default() -> Self {
        let flags = EdnsFlags::default();
        let options = OPT::default();
        Self {
            max_payload: 512,
            flags,
            options,
            version: 0,
            rcode_high: 0,
        }
    }
}
impl Edns {
    /// Creates an `Edns` populated with its [`Default`] values.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the `rcode_high` byte.
    pub fn rcode_high(&self) -> u8 {
        self.rcode_high
    }

    /// Returns the EDNS version number.
    pub fn version(&self) -> u8 {
        self.version
    }

    /// Borrows the EDNS flags.
    pub fn flags(&self) -> &EdnsFlags {
        &self.flags
    }

    /// Mutably borrows the EDNS flags.
    pub fn flags_mut(&mut self) -> &mut EdnsFlags {
        &mut self.flags
    }

    /// Returns the maximum payload size.
    pub fn max_payload(&self) -> u16 {
        self.max_payload
    }

    /// Looks up the option stored under `code`, if present.
    pub fn option(&self, code: EdnsCode) -> Option<&EdnsOption> {
        self.options.get(code)
    }

    /// Borrows the full option set.
    pub fn options(&self) -> &OPT {
        &self.options
    }

    /// Mutably borrows the full option set.
    pub fn options_mut(&mut self) -> &mut OPT {
        &mut self.options
    }

    /// Overwrites `rcode_high`; returns `self` so setters can be chained.
    pub fn set_rcode_high(&mut self, rcode_high: u8) -> &mut Self {
        self.rcode_high = rcode_high;
        self
    }

    /// Overwrites the EDNS version; returns `self` so setters can be chained.
    pub fn set_version(&mut self, version: u8) -> &mut Self {
        self.version = version;
        self
    }

    /// Turns on the DNSSEC OK flag and inserts the default `DAU` option
    /// (see [`Edns::set_default_algorithms`]).
    #[cfg(feature = "__dnssec")]
    pub fn enable_dnssec(&mut self) {
        self.set_dnssec_ok(true).set_default_algorithms();
    }

    /// Inserts a `DAU` option listing each of RSASHA256, RSASHA512,
    /// ECDSAP256SHA256, ECDSAP384SHA384 and ED25519 for which
    /// `is_supported()` returns true; returns `self` for chaining.
    #[cfg(feature = "__dnssec")]
    pub fn set_default_algorithms(&mut self) -> &mut Self {
        let mut supported = SupportedAlgorithms::new();
        let candidates = [
            Algorithm::RSASHA256,
            Algorithm::RSASHA512,
            Algorithm::ECDSAP256SHA256,
            Algorithm::ECDSAP384SHA384,
            Algorithm::ED25519,
        ];
        for candidate in candidates {
            if !candidate.is_supported() {
                continue;
            }
            supported.set(candidate);
        }
        self.options.insert(EdnsOption::DAU(supported));
        self
    }

    /// Sets or clears the DNSSEC OK flag; returns `self` for chaining.
    pub fn set_dnssec_ok(&mut self, dnssec_ok: bool) -> &mut Self {
        self.flags.dnssec_ok = dnssec_ok;
        self
    }

    /// Sets the maximum payload size, raising anything below 512 to 512;
    /// returns `self` for chaining.
    pub fn set_max_payload(&mut self, max_payload: u16) -> &mut Self {
        self.max_payload = if max_payload < 512 { 512 } else { max_payload };
        self
    }
}
impl<'a> From<&'a Record> for Edns {
    /// Decodes EDNS state from an OPT resource record.
    ///
    /// The record's TTL is unpacked as `rcode_high` (bits 31..24), `version`
    /// (bits 23..16) and the flags word (bits 15..0). The record's class
    /// value becomes `max_payload`, raised to at least 512 so decoded values
    /// obey the same floor as [`Edns::set_max_payload`]. `NULL` or `Update0`
    /// data (no option payload) yields `OPT::default()`.
    ///
    /// # Panics
    ///
    /// Panics if `value` is not an OPT record, or if its data is any variant
    /// other than `OPT`, `NULL` or `Update0`.
    fn from(value: &'a Record) -> Self {
        assert!(value.record_type() == RecordType::OPT);

        let ttl = value.ttl;
        let rcode_high = ((ttl & 0xFF00_0000u32) >> 24) as u8;
        let version = ((ttl & 0x00FF_0000u32) >> 16) as u8;
        let flags = EdnsFlags::from((ttl & 0x0000_FFFFu32) as u16);
        // RFC 6891 §6.2.3: payload sizes below 512 MUST be treated as 512.
        // Without this clamp a record built with a small class value would
        // produce an `Edns` violating the invariant kept by `set_max_payload`.
        let max_payload = u16::from(value.dns_class).max(512);

        let options = match &value.data {
            // No option data was present in the OPT record.
            RData::Update0(..) | RData::NULL(..) => OPT::default(),
            RData::OPT(option_data) => option_data.clone(),
            // The record type was asserted to be OPT above, so any other data
            // variant indicates a construction bug, not malformed input.
            _ => panic!("rr_type doesn't match the RData: {:?}", value.data),
        };

        Self {
            rcode_high,
            version,
            flags,
            max_payload,
            options,
        }
    }
}
impl<'a> From<&'a Edns> for Record {
    /// Builds an OPT record owned by the root name from `value`.
    ///
    /// `rcode_high` occupies TTL bits 31..24, `version` bits 23..16 and the
    /// flags word bits 15..0. The class is `DNSClass::for_opt(max_payload)`
    /// and the record data is a clone of the option set.
    fn from(value: &'a Edns) -> Self {
        let flags_word = u32::from(u16::from(value.flags));
        let ttl = (u32::from(value.rcode_high()) << 24)
            | (u32::from(value.version()) << 16)
            | flags_word;
        let rdata = RData::OPT(value.options().clone());

        let mut opt_record = Self::from_rdata(Name::root(), ttl, rdata);
        opt_record.dns_class = DNSClass::for_opt(value.max_payload());
        opt_record
    }
}
impl BinEncodable for Edns {
    /// Writes this EDNS state to `encoder` in OPT record wire layout.
    ///
    /// The output is, in order: a single zero byte (the empty/root owner
    /// name), the OPT record type, the class from `DNSClass::for_opt`, the
    /// 32-bit TTL word (`rcode_high`, `version`, flags), and a 16-bit length
    /// followed by the encoded options.
    ///
    /// # Panics
    ///
    /// Panics if the encoded options exceed `u16::MAX` bytes.
    fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
        let ttl_word = (u32::from(self.rcode_high()) << 24)
            | (u32::from(self.version()) << 16)
            | u32::from(u16::from(self.flags));

        encoder.emit(0)?;
        RecordType::OPT.emit(encoder)?;
        DNSClass::for_opt(self.max_payload()).emit(encoder)?;
        encoder.emit_u32(ttl_word)?;

        // Reserve the length slot, write the options, then back-fill the
        // number of bytes they took.
        let length_slot = encoder.place::<u16>()?;
        self.options.emit(encoder)?;
        let written = encoder.len_since_place(&length_slot);
        assert!(written <= u16::MAX as usize);
        length_slot.replace(encoder, written as u16)?;

        Ok(())
    }
}
impl fmt::Display for Edns {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
let version = self.version;
let dnssec_ok = self.flags.dnssec_ok;
let z_flags = self.flags.z;
let max_payload = self.max_payload;
write!(
f,
"version: {version} dnssec_ok: {dnssec_ok} z_flags: {z_flags} max_payload: {max_payload} opts: {opts_len}",
opts_len = self.options().as_ref().len()
)
}
}
/// The 16-bit EDNS flags word, split into the DNSSEC OK bit and the remaining
/// bits.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct EdnsFlags {
    /// The most significant bit of the flags word (mask `0x8000`).
    pub dnssec_ok: bool,
    /// The remaining low 15 bits (mask `0x7FFF`). Bit 15 of this value is
    /// ignored when converting back to `u16`; `dnssec_ok` decides that bit.
    pub z: u16,
}
impl From<u16> for EdnsFlags {
    /// Splits a raw flags word: bit 15 becomes `dnssec_ok`, and the lower
    /// 15 bits become `z`.
    fn from(flags: u16) -> Self {
        const DO_MASK: u16 = 0x8000;
        let dnssec_ok = (flags & DO_MASK) != 0;
        let z = flags & !DO_MASK;
        Self { dnssec_ok, z }
    }
}
impl From<EdnsFlags> for u16 {
    /// Packs the flags into a 16-bit word. `dnssec_ok` becomes bit 15 and
    /// the low 15 bits come from `z`; bit 15 of `z` is ignored.
    fn from(flags: EdnsFlags) -> Self {
        let do_bit = u16::from(flags.dnssec_ok) << 15;
        do_bit | (flags.z & 0x7FFF)
    }
}
/// Default maximum payload length: 1232.
///
/// NOTE(review): `Edns::default()` uses 512 rather than this constant. The
/// value 1232 presumably follows the DNS Flag Day 2020 UDP buffer-size
/// recommendation — confirm before relying on that rationale.
pub const DEFAULT_MAX_PAYLOAD_LEN: u16 = 1232;
#[cfg(all(test, feature = "__dnssec"))]
mod tests {
    use super::*;

    /// Round-trips a fully populated `Edns` through `Record` and back and
    /// checks every field survives. Then checks that removing the `DAU`
    /// option makes it unreachable through `option`.
    #[test]
    fn test_encode_decode() {
        let mut original = Edns::new();
        *original.flags_mut() = EdnsFlags {
            dnssec_ok: true,
            z: 1,
        };
        original
            .set_max_payload(0x8008)
            .set_version(0x40)
            .set_rcode_high(0x01);
        original
            .options_mut()
            .insert(EdnsOption::DAU(SupportedAlgorithms::all()));

        let record: Record = (&original).into();
        let decoded: Edns = (&record).into();

        assert_eq!(original.flags(), decoded.flags());
        assert_eq!(original.max_payload(), decoded.max_payload());
        assert_eq!(original.version(), decoded.version());
        assert_eq!(original.rcode_high(), decoded.rcode_high());
        assert_eq!(original.options(), decoded.options());

        original
            .options_mut()
            .insert(EdnsOption::DAU(SupportedAlgorithms::all()));
        original.options_mut().remove(EdnsCode::DAU);
        assert!(original.option(EdnsCode::DAU).is_none());
    }
}