#![allow(clippy::use_self)]
use alloc::vec::Vec;
use core::fmt;
use core::hash::{Hash, Hasher};
use core::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use core::str::FromStr;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use tracing::warn;
use crate::{
error::{ProtoError, ProtoResult},
rr::{RData, RecordData, RecordDataDecodable, RecordType},
serialize::binary::{
BinDecodable, BinDecoder, BinEncodable, BinEncoder, DecodeError, RDataEncoding, Restrict,
},
};
#[cfg(feature = "__dnssec")]
use crate::dnssec::SupportedAlgorithms;
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[derive(Default, Debug, Clone, Ord, PartialOrd)]
#[non_exhaustive]
pub struct OPT {
pub options: Vec<(EdnsCode, EdnsOption)>,
}
impl OPT {
pub fn new(options: Vec<(EdnsCode, EdnsOption)>) -> Self {
Self { options }
}
pub fn get(&self, code: EdnsCode) -> Option<&EdnsOption> {
self.options
.iter()
.find_map(|(c, option)| if code == *c { Some(option) } else { None })
}
pub fn get_all(&self, code: EdnsCode) -> Vec<&EdnsOption> {
self.options
.iter()
.filter_map(|(c, option)| if code == *c { Some(option) } else { None })
.collect()
}
pub fn insert(&mut self, option: EdnsOption) {
self.options.push(((&option).into(), option));
}
pub fn remove(&mut self, option: EdnsCode) {
self.options.retain(|(c, _)| *c != option)
}
}
impl PartialEq for OPT {
fn eq(&self, other: &Self) -> bool {
let matching_elements_count = self
.options
.iter()
.filter(|entry| other.options.contains(entry))
.count();
matching_elements_count == self.options.len()
&& matching_elements_count == other.options.len()
}
}
impl Eq for OPT {}
impl Hash for OPT {
fn hash<H: Hasher>(&self, state: &mut H) {
let mut sorted = self.options.clone();
sorted.sort();
sorted.hash(state);
}
}
impl AsMut<Vec<(EdnsCode, EdnsOption)>> for OPT {
    /// Mutable access to the underlying `(code, option)` list.
    fn as_mut(&mut self) -> &mut Vec<(EdnsCode, EdnsOption)> {
        let Self { options } = self;
        options
    }
}
impl AsRef<[(EdnsCode, EdnsOption)]> for OPT {
    /// Borrows the options as a slice of `(code, option)` pairs.
    fn as_ref(&self) -> &[(EdnsCode, EdnsOption)] {
        self.options.as_slice()
    }
}
impl BinEncodable for OPT {
    /// Writes each option, in list order, as:
    /// 1. its 16-bit code,
    /// 2. its 16-bit length (from `EdnsOption::len`),
    /// 3. the option's own encoding.
    fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
        let mut rdata_encoder = encoder.with_rdata_behavior(RDataEncoding::Other);
        for (code, option) in &self.options {
            let wire_code: u16 = (*code).into();
            rdata_encoder.emit_u16(wire_code)?;
            rdata_encoder.emit_u16(option.len())?;
            option.emit(&mut rdata_encoder)?;
        }
        Ok(())
    }
}
impl<'r> RecordDataDecodable<'r> for OPT {
    /// Decodes the options contained in `length` bytes of OPT rdata.
    ///
    /// Parsing is a small state machine (`OptReadState`):
    /// 1. Read a 16-bit option code.
    /// 2. Read a 16-bit option length.
    /// 3. Read that many data bytes, one per loop iteration.
    ///
    /// The code and collected bytes are then converted with `EdnsOption::try_from`.
    /// Zero-length options are converted immediately, without entering the data state.
    ///
    /// # Errors
    ///
    /// - `DecodeError::IncorrectRDataLengthRead` if an option's declared length exceeds the
    ///   total rdata length.
    /// - Any error from the underlying decoder reads.
    /// - Any error from converting an option's bytes, e.g. a malformed client-subnet payload.
    ///
    /// If the rdata ends part-way through an option, a warning is logged and *all* options are
    /// discarded. An empty `OPT` is returned rather than an error.
    fn read_data(decoder: &mut BinDecoder<'r>, length: Restrict<u16>) -> Result<Self, DecodeError> {
        let mut state: OptReadState = OptReadState::ReadCode;
        let mut options: Vec<(EdnsCode, EdnsOption)> = Vec::new();
        let start_idx = decoder.index();
        let rdata_length = length.map(|u| u as usize).unverified();
        // Keep going until `rdata_length` bytes have been consumed since this rdata started.
        // NOTE(review): the u16 reads in the ReadCode/Code states are not bounded by the bytes
        // remaining in this rdata. A single trailing byte therefore causes a read one byte past
        // its end. Confirm callers tolerate this.
        while rdata_length > decoder.index() - start_idx {
            match state {
                OptReadState::ReadCode => {
                    state = OptReadState::Code {
                        code: EdnsCode::from(
                            decoder.read_u16()?.unverified(),
                        ),
                    };
                }
                OptReadState::Code { code } => {
                    // Reject an option whose declared length exceeds the entire rdata.
                    // This is only checked against the whole rdata, not the bytes still
                    // remaining. An option that overruns the rest of the rdata is caught by the
                    // incomplete-state check after the loop, or by a decoder error if the input
                    // buffer runs out first.
                    let length = decoder
                        .read_u16()?
                        .map(|u| u as usize)
                        .verify_unwrap(|u| *u <= rdata_length)
                        .map_err(|opt_len| DecodeError::IncorrectRDataLengthRead {
                            read: rdata_length,
                            len: opt_len,
                        })?;
                    state = if length == 0 {
                        // No data bytes follow, so convert the option right away.
                        options.push((code, (code, &[] as &[u8]).try_into()?));
                        OptReadState::ReadCode
                    } else {
                        OptReadState::Data {
                            code,
                            length,
                            collected: Vec::<u8>::with_capacity(length),
                        }
                    };
                }
                OptReadState::Data {
                    code,
                    length,
                    mut collected,
                } => {
                    // One byte per iteration, so the rdata bound is re-checked for every byte.
                    collected.push(decoder.pop()?.unverified());
                    if length == collected.len() {
                        options.push((code, (code, &collected as &[u8]).try_into()?));
                        state = OptReadState::ReadCode;
                    } else {
                        state = OptReadState::Data {
                            code,
                            length,
                            collected,
                        };
                    }
                }
            }
        }
        // Ending anywhere other than between options means the rdata was truncated or
        // malformed. Discard everything rather than return a partial option list.
        if state != OptReadState::ReadCode {
            warn!("incomplete or poorly formatted EDNS options: {:?}", state);
            options.clear();
        }
        Ok(Self::new(options))
    }
}
impl RecordData for OPT {
    /// Borrows the `OPT` out of `data` when it is the `RData::OPT` variant.
    fn try_borrow(data: &RData) -> Option<&Self> {
        if let RData::OPT(opt) = data {
            Some(opt)
        } else {
            None
        }
    }

    /// Always `RecordType::OPT`.
    fn record_type(&self) -> RecordType {
        RecordType::OPT
    }

    /// Wraps `self` in `RData::OPT`.
    fn into_rdata(self) -> RData {
        RData::OPT(self)
    }
}
impl fmt::Display for OPT {
    /// Renders exactly the `Debug` representation, passing the formatter (and its flags)
    /// through unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <Self as fmt::Debug>::fmt(self, f)
    }
}
/// Parser state used by `OPT::read_data` while walking the option list.
#[derive(Debug, PartialEq, Eq)]
enum OptReadState {
    /// Between options: the next two bytes are an option code.
    ReadCode,
    /// An option code has been read; the next two bytes are its length.
    Code {
        /// Code of the option being read.
        code: EdnsCode,
    },
    /// Collecting the data bytes of an option.
    Data {
        /// Code of the option being read.
        code: EdnsCode,
        /// Declared data length of the option.
        length: usize,
        /// Data bytes read so far.
        collected: Vec<u8>,
    },
}
/// Code identifying an EDNS option. When an `OPT` is encoded, it is written as a `u16`
/// before each option's length.
///
/// Wire values for the named variants are given on each variant, per the `From<u16>` and
/// `From<EdnsCode> for u16` conversions. Every other value, including 4, is carried in
/// `Unknown`. The derived `Ord` follows declaration order, not wire value.
///
/// NOTE(review): nothing prevents building `Unknown(n)` for an `n` that has a named variant
/// (e.g. `Unknown(8)`). Such a value encodes like the named variant but does not compare
/// equal to it.
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[derive(Hash, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
#[non_exhaustive]
pub enum EdnsCode {
    /// Wire value 0.
    Zero,
    /// Wire value 1.
    LLQ,
    /// Wire value 2.
    UL,
    /// Wire value 3; decoded into `EdnsOption::NSID`.
    NSID,
    /// Wire value 5; decoded into `EdnsOption::DAU` when the `__dnssec` feature is enabled.
    DAU,
    /// Wire value 6.
    DHU,
    /// Wire value 7.
    N3U,
    /// Wire value 8; decoded into `EdnsOption::Subnet`.
    Subnet,
    /// Wire value 9.
    Expire,
    /// Wire value 10.
    Cookie,
    /// Wire value 11.
    Keepalive,
    /// Wire value 12.
    Padding,
    /// Wire value 13.
    Chain,
    /// Any other option code, kept verbatim.
    Unknown(u16),
}
impl From<u16> for EdnsCode {
fn from(value: u16) -> Self {
match value {
0 => Self::Zero,
1 => Self::LLQ,
2 => Self::UL,
3 => Self::NSID,
5 => Self::DAU,
6 => Self::DHU,
7 => Self::N3U,
8 => Self::Subnet,
9 => Self::Expire,
10 => Self::Cookie,
11 => Self::Keepalive,
12 => Self::Padding,
13 => Self::Chain,
_ => Self::Unknown(value),
}
}
}
impl From<EdnsCode> for u16 {
    /// Returns the wire value of `value`; `Unknown(n)` yields `n` unchanged.
    fn from(value: EdnsCode) -> Self {
        match value {
            EdnsCode::Unknown(raw) => raw,
            EdnsCode::Zero => 0,
            EdnsCode::LLQ => 1,
            EdnsCode::UL => 2,
            EdnsCode::NSID => 3,
            EdnsCode::DAU => 5,
            EdnsCode::DHU => 6,
            EdnsCode::N3U => 7,
            EdnsCode::Subnet => 8,
            EdnsCode::Expire => 9,
            EdnsCode::Cookie => 10,
            EdnsCode::Keepalive => 11,
            EdnsCode::Padding => 12,
            EdnsCode::Chain => 13,
        }
    }
}
/// A decoded EDNS option value.
///
/// Options whose code has a typed representation here are parsed into it by
/// `TryFrom<(EdnsCode, &[u8])>`. All others are kept as `Unknown(code, raw_bytes)`.
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[derive(Debug, PartialOrd, PartialEq, Eq, Clone, Hash, Ord)]
#[non_exhaustive]
pub enum EdnsOption {
    /// DAU option data (code 5); only available with the `__dnssec` feature.
    #[cfg(feature = "__dnssec")]
    DAU(SupportedAlgorithms),
    /// Client subnet option data (code 8).
    Subnet(ClientSubnet),
    /// NSID option data (code 3).
    NSID(NSIDPayload),
    /// Raw option code and data bytes for an option without a typed variant.
    Unknown(u16, Vec<u8>),
}
impl EdnsOption {
pub fn len(&self) -> u16 {
match self {
#[cfg(feature = "__dnssec")]
EdnsOption::DAU(algorithms) => algorithms.len(),
EdnsOption::Subnet(subnet) => subnet.len(),
EdnsOption::NSID(payload) => payload.as_ref().len() as u16, EdnsOption::Unknown(_, data) => data.len() as u16, }
}
pub fn is_empty(&self) -> bool {
match self {
#[cfg(feature = "__dnssec")]
EdnsOption::DAU(algorithms) => algorithms.is_empty(),
EdnsOption::Subnet(subnet) => subnet.is_empty(),
EdnsOption::NSID(payload) => payload.as_ref().is_empty(),
EdnsOption::Unknown(_, data) => data.is_empty(),
}
}
}
impl BinEncodable for EdnsOption {
    /// Writes only the option's data bytes. `OPT`'s encoder writes the code and length before
    /// calling this.
    fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
        // Typed options encode themselves; raw payloads fall through to a single byte copy.
        let raw: &[u8] = match self {
            #[cfg(feature = "__dnssec")]
            Self::DAU(algorithms) => return algorithms.emit(encoder),
            Self::Subnet(subnet) => return subnet.emit(encoder),
            Self::NSID(payload) => payload.as_ref(),
            Self::Unknown(_, data) => data,
        };
        encoder.emit_vec(raw)
    }
}
impl<'a> TryFrom<(EdnsCode, &'a [u8])> for EdnsOption {
type Error = DecodeError;
fn try_from(value: (EdnsCode, &'a [u8])) -> Result<Self, Self::Error> {
Ok(match value.0 {
#[cfg(feature = "__dnssec")]
EdnsCode::DAU => Self::DAU(value.1.into()),
EdnsCode::Subnet => Self::Subnet(value.1.try_into()?),
EdnsCode::NSID => Self::NSID(value.1.try_into()?),
_ => Self::Unknown(value.0.into(), value.1.to_vec()),
})
}
}
impl<'a> TryFrom<&'a EdnsOption> for Vec<u8> {
    type Error = ProtoError;

    /// Produces the option's data bytes, without the code or length prefix.
    fn try_from(value: &'a EdnsOption) -> Result<Self, Self::Error> {
        let bytes: Self = match value {
            #[cfg(feature = "__dnssec")]
            EdnsOption::DAU(algorithms) => algorithms.into(),
            EdnsOption::Subnet(subnet) => Vec::<u8>::try_from(subnet)?,
            EdnsOption::NSID(payload) => payload.as_ref().to_vec(),
            EdnsOption::Unknown(_, data) => data.clone(),
        };
        Ok(bytes)
    }
}
impl<'a> From<&'a EdnsOption> for EdnsCode {
fn from(value: &'a EdnsOption) -> Self {
match value {
#[cfg(feature = "__dnssec")]
EdnsOption::DAU(..) => Self::DAU,
EdnsOption::Subnet(..) => Self::Subnet,
EdnsOption::NSID(..) => Self::NSID,
EdnsOption::Unknown(code, _) => (*code).into(),
}
}
}
/// Client subnet EDNS option data (option code 8).
///
/// The encoding is:
/// 1. a 16-bit family (1 for IPv4, 2 for IPv6),
/// 2. the source and scope prefix lengths (one byte each),
/// 3. the first `ceil(source_prefix / 8)` octets of `address`.
///
/// Field order determines the derived `Ord`.
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[derive(Debug, PartialOrd, PartialEq, Eq, Clone, Copy, Hash, Ord)]
pub struct ClientSubnet {
    /// Client address; only the leading octets covered by `source_prefix` are encoded.
    address: IpAddr,
    /// Number of leading address bits that are significant.
    source_prefix: u8,
    /// Scope prefix length, encoded verbatim.
    scope_prefix: u8,
}
impl ClientSubnet {
pub fn new(address: IpAddr, source_prefix: u8, scope_prefix: u8) -> Self {
Self {
address,
source_prefix,
scope_prefix,
}
}
pub fn len(&self) -> u16 {
2 + 1 + 1 + self.addr_len()
}
#[inline]
pub fn is_empty(&self) -> bool {
false
}
pub fn addr(&self) -> IpAddr {
self.address
}
pub fn set_addr(&mut self, addr: IpAddr) {
self.address = addr;
}
pub fn source_prefix(&self) -> u8 {
self.source_prefix
}
pub fn set_source_prefix(&mut self, source_prefix: u8) {
self.source_prefix = source_prefix;
}
pub fn scope_prefix(&self) -> u8 {
self.scope_prefix
}
pub fn set_scope_prefix(&mut self, scope_prefix: u8) {
self.scope_prefix = scope_prefix;
}
fn addr_len(&self) -> u16 {
let source_prefix = self.source_prefix as u16;
source_prefix / 8
+ if !source_prefix.is_multiple_of(8) {
1
} else {
0
}
}
}
impl BinEncodable for ClientSubnet {
    /// Encodes the client-subnet option data as:
    /// 1. the 16-bit family (1 = IPv4, 2 = IPv6),
    /// 2. the source prefix length,
    /// 3. the scope prefix length,
    /// 4. the first `ceil(source_prefix / 8)` address octets.
    ///
    /// Any address bits beyond the source prefix in the last octet are cleared, as RFC 7871
    /// requires. Receivers may reject the option (FORMERR) if those bits are set.
    ///
    /// # Errors
    ///
    /// Returns an error, without writing anything, if the source prefix needs more octets than
    /// the address family has (e.g. a prefix above 32 for IPv4). Encoder errors are propagated.
    fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
        // Copy the address into a 16-byte scratch buffer so both families share one code path.
        let (family, mut octets, octets_len): (u16, [u8; 16], usize) = match self.address {
            IpAddr::V4(ip) => {
                let mut buf = [0u8; 16];
                buf[..4].copy_from_slice(&ip.octets());
                (1, buf, 4)
            }
            IpAddr::V6(ip) => (2, ip.octets(), 16),
        };

        let addr_len = usize::from(self.addr_len());
        if addr_len > octets_len {
            return Err(ProtoError::Message(
                "Invalid addr length for encode EcsOption",
            ));
        }

        // Zero the host bits in the final (partial) octet so only prefix bits go on the wire.
        let trailing_bits = self.source_prefix % 8;
        if trailing_bits != 0 {
            octets[addr_len - 1] &= 0xFF_u8 << (8 - trailing_bits);
        }

        encoder.emit_u16(family)?;
        encoder.emit_u8(self.source_prefix)?;
        encoder.emit_u8(self.scope_prefix)?;
        encoder.emit_vec(&octets[..addr_len])
    }
}
impl<'a> BinDecodable<'a> for ClientSubnet {
    /// Decodes client-subnet option data.
    ///
    /// The layout is:
    /// 1. a 16-bit family (1 = IPv4, 2 = IPv6),
    /// 2. the source and scope prefix lengths (one byte each),
    /// 3. `ceil(source_prefix / 8)` address octets.
    ///
    /// Octets beyond those read are left as zero in the resulting address.
    ///
    /// # Errors
    ///
    /// - `DecodeError::UnknownAddressFamily` for a family other than 1 or 2.
    /// - `DecodeError::IncorrectRDataLengthRead` if the source prefix needs more octets than the
    ///   address family has.
    /// - Any error from the underlying decoder reads.
    fn read(decoder: &mut BinDecoder<'a>) -> Result<Self, DecodeError> {
        // Reads both prefix lengths, then the truncated address into an `N`-octet buffer.
        fn read_prefixes_and_octets<const N: usize>(
            decoder: &mut BinDecoder<'_>,
        ) -> Result<(u8, u8, [u8; N]), DecodeError> {
            let source_prefix = decoder.read_u8()?.unverified();
            let scope_prefix = decoder.read_u8()?.unverified();
            let needed = usize::from(source_prefix).div_ceil(8);
            if needed > N {
                return Err(DecodeError::IncorrectRDataLengthRead {
                    read: N,
                    len: needed,
                });
            }
            let mut octets = [0u8; N];
            for octet in &mut octets[..needed] {
                *octet = decoder.read_u8()?.unverified();
            }
            Ok((source_prefix, scope_prefix, octets))
        }

        match decoder.read_u16()?.unverified() {
            1 => {
                let (source_prefix, scope_prefix, octets) =
                    read_prefixes_and_octets::<4>(decoder)?;
                Ok(Self {
                    address: IpAddr::V4(Ipv4Addr::from(octets)),
                    source_prefix,
                    scope_prefix,
                })
            }
            2 => {
                let (source_prefix, scope_prefix, octets) =
                    read_prefixes_and_octets::<16>(decoder)?;
                Ok(Self {
                    address: IpAddr::V6(Ipv6Addr::from(octets)),
                    source_prefix,
                    scope_prefix,
                })
            }
            family => Err(DecodeError::UnknownAddressFamily(family)),
        }
    }
}
impl<'a> TryFrom<&'a ClientSubnet> for Vec<u8> {
    type Error = ProtoError;

    /// Encodes the client-subnet option data into a freshly allocated, tightly sized buffer.
    fn try_from(value: &'a ClientSubnet) -> Result<Self, Self::Error> {
        let mut buf = Self::with_capacity(usize::from(value.len()));
        {
            let mut encoder = BinEncoder::new(&mut buf);
            value.emit(&mut encoder)?;
        }
        buf.shrink_to_fit();
        Ok(buf)
    }
}
impl<'a> TryFrom<&'a [u8]> for ClientSubnet {
    type Error = DecodeError;

    /// Decodes client-subnet option data from the start of `value`.
    fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
        Self::read(&mut BinDecoder::new(value))
    }
}
impl From<ipnet::IpNet> for ClientSubnet {
    /// Uses `net.addr()` as the address and `net.prefix_len()` as the source prefix. The scope
    /// prefix is set to 0.
    fn from(net: ipnet::IpNet) -> Self {
        Self::new(net.addr(), net.prefix_len(), 0)
    }
}
impl FromStr for ClientSubnet {
    type Err = ipnet::AddrParseError;

    /// Parses CIDR notation such as `"172.1.1.1/24"` through `ipnet::IpNet`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let net: ipnet::IpNet = s.parse()?;
        Ok(Self::from(net))
    }
}
/// Raw NSID option data (option code 3).
///
/// `NSIDPayload::new` and `TryFrom<&[u8]>` reject payloads longer than `u16::MAX` bytes, so
/// values built through them always fit the 16-bit option length.
///
/// NOTE(review): the serde `Deserialize` derive builds the value directly and does not apply
/// that check.
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct NSIDPayload(Vec<u8>);
impl NSIDPayload {
    /// Wraps `data` as an NSID payload.
    ///
    /// # Errors
    ///
    /// Fails if `data` is longer than `u16::MAX` bytes, the largest length the 16-bit EDNS
    /// option length can describe.
    pub fn new(data: impl Into<Vec<u8>>) -> Result<Self, ProtoError> {
        let bytes: Vec<u8> = data.into();
        match u16::try_from(bytes.len()) {
            Ok(_) => Ok(Self(bytes)),
            Err(_) => Err(ProtoError::from("NSID EDNS payload too large")),
        }
    }
}
impl<'a> TryFrom<&'a [u8]> for NSIDPayload {
    type Error = DecodeError;

    /// Copies `value` into a new payload, rejecting slices longer than `u16::MAX` bytes.
    fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
        let max = usize::from(u16::MAX);
        if value.len() <= max {
            Ok(Self(value.to_vec()))
        } else {
            Err(DecodeError::IncorrectRDataLengthRead {
                read: value.len(),
                len: max,
            })
        }
    }
}
impl AsRef<[u8]> for NSIDPayload {
    /// Borrows the raw payload bytes.
    fn as_ref(&self) -> &[u8] {
        self.0.as_slice()
    }
}
#[cfg(test)]
mod tests {
    #![allow(clippy::dbg_macro, clippy::print_stdout)]

    // `DefaultHasher` is only imported with the `std` feature, so the hashing imports and the
    // test that uses them are gated the same way. This keeps the test module building without
    // `std`.
    #[cfg(feature = "std")]
    use core::hash::{Hash, Hasher};
    #[cfg(feature = "std")]
    use std::{hash::DefaultHasher, println};

    use super::*;

    #[test]
    #[cfg(feature = "__dnssec")]
    fn test() {
        let mut rdata = OPT::default();
        rdata.insert(EdnsOption::DAU(SupportedAlgorithms::all()));
        let mut bytes = Vec::new();
        let mut encoder: BinEncoder<'_> = BinEncoder::new(&mut bytes);
        assert!(rdata.emit(&mut encoder).is_ok());
        let bytes = encoder.into_bytes();
        #[cfg(feature = "std")]
        println!("bytes: {bytes:?}");
        let mut decoder: BinDecoder<'_> = BinDecoder::new(bytes);
        let restrict = Restrict::new(bytes.len() as u16);
        let read_rdata = OPT::read_data(&mut decoder, restrict).expect("Decoding error");
        assert_eq!(rdata, read_rdata);
    }

    #[test]
    fn test_read_empty_option_at_end_of_opt() {
        let bytes: Vec<u8> = vec![
            0x00, 0x0a, 0x00, 0x08, 0x0b, 0x64, 0xb4, 0xdc, 0xd7, 0xb0, 0xcc, 0x8f, 0x00, 0x08,
            0x00, 0x04, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00,
        ];
        let mut decoder: BinDecoder<'_> = BinDecoder::new(&bytes);
        let read_rdata = OPT::read_data(&mut decoder, Restrict::new(bytes.len() as u16));
        assert!(
            read_rdata.is_ok(),
            "error decoding: {:?}",
            read_rdata.unwrap_err()
        );
        let opt = read_rdata.unwrap();
        let options = vec![
            (
                EdnsCode::Subnet,
                EdnsOption::Subnet("0.0.0.0/0".parse().unwrap()),
            ),
            (
                EdnsCode::Cookie,
                EdnsOption::Unknown(10, vec![0x0b, 0x64, 0xb4, 0xdc, 0xd7, 0xb0, 0xcc, 0x8f]),
            ),
            (EdnsCode::Keepalive, EdnsOption::Unknown(11, vec![])),
        ];
        let options = OPT::new(options);
        assert_eq!(opt, options);
    }

    #[test]
    fn test_multiple_options_with_same_code() {
        let bytes: Vec<u8> = vec![
            0x00, 0x0f, 0x00, 0x02, 0x00, 0x06, 0x00, 0x0f, 0x00, 0x0f, 0x00, 0x09, 0x55, 0x6E,
            0x6B, 0x6E, 0x6F, 0x77, 0x6E, 0x20, 0x65, 0x72, 0x72, 0x6F, 0x72,
        ];
        let mut decoder: BinDecoder<'_> = BinDecoder::new(&bytes);
        let read_rdata = OPT::read_data(&mut decoder, Restrict::new(bytes.len() as u16));
        assert!(
            read_rdata.is_ok(),
            "error decoding: {:?}",
            read_rdata.unwrap_err()
        );
        let opt = read_rdata.unwrap();
        let options = vec![
            (
                EdnsCode::Unknown(15u16),
                EdnsOption::Unknown(15u16, vec![0x00, 0x06]),
            ),
            (
                EdnsCode::Unknown(15u16),
                EdnsOption::Unknown(
                    15u16,
                    vec![
                        0x00, 0x09, 0x55, 0x6E, 0x6B, 0x6E, 0x6F, 0x77, 0x6E, 0x20, 0x65, 0x72,
                        0x72, 0x6F, 0x72,
                    ],
                ),
            ),
        ];
        let options = OPT::new(options);
        assert_eq!(opt, options);
    }

    #[test]
    fn test_write_client_subnet() {
        let expected_bytes: Vec<u8> = vec![0x00, 0x01, 0x18, 0x00, 0xac, 0x01, 0x01];
        let ecs: ClientSubnet = "172.1.1.1/24".parse().unwrap();
        let bytes = Vec::<u8>::try_from(&ecs).unwrap();
        #[cfg(feature = "std")]
        println!("bytes: {bytes:?}");
        assert_eq!(bytes, expected_bytes);
    }

    #[test]
    fn test_read_client_subnet() {
        let bytes: Vec<u8> = vec![0x00, 0x01, 0x18, 0x00, 0xac, 0x01, 0x01];
        let ecs = ClientSubnet::try_from(bytes.as_slice()).unwrap();
        assert_eq!(ecs, "172.1.1.0/24".parse().unwrap());
    }

    #[test]
    fn test_nsid_payload_too_large() {
        let err = NSIDPayload::try_from([0x00; (u16::MAX as usize) + 1].as_slice()).unwrap_err();
        assert!(
            matches!(err, DecodeError::IncorrectRDataLengthRead { .. }),
            "expected IncorrectRDataLengthRead, got {err}"
        );
    }

    #[test]
    fn test_nsid_payload_roundtrip() {
        let payload_in = EdnsOption::NSID([0xC0, 0xFF, 0xEE].as_slice().try_into().unwrap());
        let mut buf = Vec::new();
        let mut encoder = BinEncoder::new(&mut buf);
        payload_in.emit(&mut encoder).unwrap();
        let payload_out = EdnsOption::try_from((EdnsCode::NSID, buf.as_ref())).unwrap();
        assert_eq!(payload_in, payload_out);
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_eq_and_hash() {
        let options_1 = OPT::new(vec![
            (
                EdnsCode::Unknown(15u16),
                EdnsOption::Unknown(15u16, vec![0x00, 0x06]),
            ),
            (
                EdnsCode::Unknown(15u16),
                EdnsOption::Unknown(
                    15u16,
                    vec![
                        0x00, 0x09, 0x55, 0x6E, 0x6B, 0x6E, 0x6F, 0x77, 0x6E, 0x20, 0x65, 0x72,
                        0x72, 0x6F, 0x72,
                    ],
                ),
            ),
        ]);
        let options_2 = OPT::new(vec![
            (
                EdnsCode::Unknown(15u16),
                EdnsOption::Unknown(
                    15u16,
                    vec![
                        0x00, 0x09, 0x55, 0x6E, 0x6B, 0x6E, 0x6F, 0x77, 0x6E, 0x20, 0x65, 0x72,
                        0x72, 0x6F, 0x72,
                    ],
                ),
            ),
            (
                EdnsCode::Unknown(15u16),
                EdnsOption::Unknown(15u16, vec![0x00, 0x06]),
            ),
        ]);
        let options_3 = OPT::new(vec![
            (
                EdnsCode::Unknown(15u16),
                EdnsOption::Unknown(
                    15u16,
                    vec![
                        0xff, 0x09, 0x55, 0x6E, 0x6B, 0x6E, 0x6F, 0x77, 0x6E, 0x20, 0x65, 0x72,
                        0x72, 0x6F, 0x72,
                    ],
                ),
            ),
            (
                EdnsCode::Unknown(15u16),
                EdnsOption::Unknown(15u16, vec![0x00, 0x06]),
            ),
        ]);
        let mut hasher_1 = DefaultHasher::new();
        options_1.hash(&mut hasher_1);
        let hash_1 = hasher_1.finish();
        let mut hasher_2 = DefaultHasher::new();
        options_2.hash(&mut hasher_2);
        let hash_2 = hasher_2.finish();
        let mut hasher_3 = DefaultHasher::new();
        options_3.hash(&mut hasher_3);
        let hash_3 = hasher_3.finish();
        assert_eq!(options_1, options_2);
        assert_eq!(hash_1, hash_2);
        assert!(options_1 != options_3);
        assert!(hash_1 != hash_3);
    }
}