#![allow(clippy::use_self)]
use std::{
cmp::{Ord, Ordering, PartialOrd},
convert::TryFrom,
fmt,
net::Ipv4Addr,
net::Ipv6Addr,
};
#[cfg(feature = "serde-config")]
use serde::{Deserialize, Serialize};
use enum_as_inner::EnumAsInner;
use crate::error::*;
use crate::rr::Name;
use crate::serialize::binary::*;
/// RDATA of an SVCB record: a priority, a target name and a list of
/// key/value service parameters.
///
/// The free functions `read` and `emit` in this module both reject parameter
/// lists whose keys are not in strictly increasing order.
#[cfg_attr(feature = "serde-config", derive(Deserialize, Serialize))]
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
pub struct SVCB {
/// SvcPriority; the first u16 of the RDATA.
svc_priority: u16,
/// TargetName; encoded directly after the priority.
target_name: Name,
/// SvcParams as (key, value) pairs, kept in the order they were supplied or read.
svc_params: Vec<(SvcParamKey, SvcParamValue)>,
}
impl SVCB {
pub fn new(
svc_priority: u16,
target_name: Name,
svc_params: Vec<(SvcParamKey, SvcParamValue)>,
) -> Self {
Self {
svc_priority,
target_name,
svc_params,
}
}
pub fn svc_priority(&self) -> u16 {
self.svc_priority
}
pub fn target_name(&self) -> &Name {
&self.target_name
}
pub fn svc_params(&self) -> &[(SvcParamKey, SvcParamValue)] {
&self.svc_params
}
}
/// Key of an SVCB service parameter.
///
/// The numeric values below are the ones used by the `From<u16>` /
/// `From<SvcParamKey> for u16` conversions in this module; ordering (`Ord`)
/// compares those numeric values.
#[cfg_attr(feature = "serde-config", derive(Deserialize, Serialize))]
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum SvcParamKey {
/// Key 0, presentation name "mandatory".
Mandatory,
/// Key 1, presentation name "alpn".
Alpn,
/// Key 2, presentation name "no-default-alpn".
NoDefaultAlpn,
/// Key 3, presentation name "port".
Port,
/// Key 4, presentation name "ipv4hint".
Ipv4Hint,
/// Key 5, presentation name "echconfig".
EchConfig,
/// Key 6, presentation name "ipv6hint".
Ipv6Hint,
/// A key number in 65280..=65534, presented as "key<number>".
Key(u16),
/// Key 65535, presented as "key65535".
Key65535,
/// Any other key number, carried through unchanged.
Unknown(u16),
}
impl From<u16> for SvcParamKey {
fn from(val: u16) -> Self {
match val {
0 => Self::Mandatory,
1 => Self::Alpn,
2 => Self::NoDefaultAlpn,
3 => Self::Port,
4 => Self::Ipv4Hint,
5 => Self::EchConfig,
6 => Self::Ipv6Hint,
65280..=65534 => Self::Key(val),
65535 => Self::Key65535,
_ => Self::Unknown(val),
}
}
}
impl From<SvcParamKey> for u16 {
    /// Returns the numeric key; `Key` and `Unknown` yield the number they carry.
    fn from(key: SvcParamKey) -> Self {
        let code = match key {
            SvcParamKey::Key(n) | SvcParamKey::Unknown(n) => n,
            SvcParamKey::Key65535 => 65535,
            SvcParamKey::Ipv6Hint => 6,
            SvcParamKey::EchConfig => 5,
            SvcParamKey::Ipv4Hint => 4,
            SvcParamKey::Port => 3,
            SvcParamKey::NoDefaultAlpn => 2,
            SvcParamKey::Alpn => 1,
            SvcParamKey::Mandatory => 0,
        };
        code
    }
}
impl<'r> BinDecodable<'r> for SvcParamKey {
    /// Reads a u16 key number and converts it with `From<u16>`; every number
    /// maps to some key, so only a short read can fail.
    fn read(decoder: &mut BinDecoder<'r>) -> ProtoResult<Self> {
        let raw = decoder.read_u16()?.unverified(/*every u16 maps to a key*/);
        Ok(Self::from(raw))
    }
}
impl BinEncodable for SvcParamKey {
    /// Writes the key as its u16 number.
    fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
        let raw = u16::from(*self);
        encoder.emit_u16(raw)
    }
}
impl fmt::Display for SvcParamKey {
    /// Writes the key's presentation-format name.
    ///
    /// Keys without a mnemonic (`Key` and `Unknown`) use the generic
    /// `key<number>` form (RFC 9460 §2.1). `FromStr` accepts exactly that form,
    /// so keys produced by `From<u16>` survive a `to_string()` / `parse()`
    /// round trip.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        match *self {
            Self::Mandatory => f.write_str("mandatory"),
            Self::Alpn => f.write_str("alpn"),
            Self::NoDefaultAlpn => f.write_str("no-default-alpn"),
            Self::Port => f.write_str("port"),
            Self::Ipv4Hint => f.write_str("ipv4hint"),
            Self::EchConfig => f.write_str("echconfig"),
            Self::Ipv6Hint => f.write_str("ipv6hint"),
            Self::Key65535 => f.write_str("key65535"),
            // "unknown<n>" is not valid presentation syntax and cannot be parsed
            // back; unrecognized keys must be written as "key<n>".
            Self::Key(val) | Self::Unknown(val) => write!(f, "key{}", val),
        }
    }
}
impl std::str::FromStr for SvcParamKey {
    type Err = ProtoError;

    /// Parses a presentation-format key name.
    ///
    /// Mnemonic names map to their variants; anything else must be `key`
    /// followed by a decimal u16, which is converted with `From<u16>`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let key = match s {
            "mandatory" => Self::Mandatory,
            "alpn" => Self::Alpn,
            "no-default-alpn" => Self::NoDefaultAlpn,
            "port" => Self::Port,
            "ipv4hint" => Self::Ipv4Hint,
            "echconfig" => Self::EchConfig,
            "ipv6hint" => Self::Ipv6Hint,
            "key65535" => Self::Key65535,
            generic => {
                let digits = generic.strip_prefix("key").ok_or_else(|| {
                    ProtoError::from(ProtoErrorKind::Msg(format!(
                        "bad formatted key ({}), expected key1234",
                        generic
                    )))
                })?;
                let number = digits.parse::<u16>()?;
                Self::from(number)
            }
        };
        Ok(key)
    }
}
impl Ord for SvcParamKey {
    /// Orders keys by their numeric value.
    fn cmp(&self, other: &Self) -> Ordering {
        let lhs: u16 = (*self).into();
        let rhs: u16 = (*other).into();
        lhs.cmp(&rhs)
    }
}
impl PartialOrd for SvcParamKey {
    /// Always `Some`; delegates to the total order of `Ord::cmp`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
/// Decoded value of a service parameter; `SvcParamValue::read` picks the
/// variant from the `SvcParamKey` that precedes the value.
#[cfg_attr(feature = "serde-config", derive(Deserialize, Serialize))]
#[derive(Debug, PartialEq, Eq, Hash, Clone, EnumAsInner)]
pub enum SvcParamValue {
/// Value for `mandatory`: a non-empty list of keys.
Mandatory(Mandatory),
/// Value for `alpn`: a non-empty list of protocol identifiers.
Alpn(Alpn),
/// Value for `no-default-alpn`, which carries no data.
NoDefaultAlpn,
/// Value for `port`: a single u16.
Port(u16),
/// Value for `ipv4hint`: a list of IPv4 addresses.
Ipv4Hint(IpHint<Ipv4Addr>),
/// Value for `echconfig`: opaque, u16-length-prefixed bytes.
EchConfig(EchConfig),
/// Value for `ipv6hint`: a list of IPv6 addresses.
Ipv6Hint(IpHint<Ipv6Addr>),
/// Raw bytes for keys without a dedicated decoder (`Key`, `Key65535`, `Unknown`).
Unknown(Unknown),
}
impl SvcParamValue {
    /// Reads the length-prefixed value of a parameter whose `key` was already read.
    ///
    /// The value is decoded from a sub-decoder limited to exactly the declared
    /// length. Parsers that consume "everything remaining" (`Mandatory`, `Alpn`,
    /// `IpHint`, `Unknown`) therefore stop at the end of this parameter.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - the declared length exceeds the bytes left in `decoder`;
    /// - a `no-default-alpn` value is not empty;
    /// - a `port` value is not exactly two bytes;
    /// - the key-specific parser rejects the value.
    fn read(key: SvcParamKey, decoder: &mut BinDecoder<'_>) -> ProtoResult<Self> {
        let len: usize = decoder
            .read_u16()?
            .verify_unwrap(|len| *len as usize <= decoder.len())
            .map(|len| len as usize)
            .map_err(|u| {
                ProtoError::from(format!(
                    "length of SvcParamValue ({}) exceeds remainder in RDATA ({})",
                    u,
                    decoder.len()
                ))
            })?;

        let param_data = decoder.read_slice(len)?.unverified(/*length checked above*/);
        let mut decoder = BinDecoder::new(param_data);

        let value = match key {
            SvcParamKey::Mandatory => Self::Mandatory(Mandatory::read(&mut decoder)?),
            SvcParamKey::Alpn => Self::Alpn(Alpn::read(&mut decoder)?),
            SvcParamKey::NoDefaultAlpn => {
                // RFC 9460 §7.1.1: the value of no-default-alpn must be empty.
                if len > 0 {
                    return Err(ProtoError::from("NoDefaultAlpn expects no value"));
                }
                Self::NoDefaultAlpn
            }
            SvcParamKey::Port => {
                // RFC 9460 §7.2: the value is exactly one 16-bit port. Extra bytes
                // would otherwise be silently dropped and lost on re-encoding.
                if len != 2 {
                    return Err(ProtoError::from(format!(
                        "Port expects a 2 byte value, got {} bytes",
                        len
                    )));
                }
                let port = decoder.read_u16()?.unverified(/*any u16 is a valid port*/);
                Self::Port(port)
            }
            SvcParamKey::Ipv4Hint => Self::Ipv4Hint(IpHint::<Ipv4Addr>::read(&mut decoder)?),
            SvcParamKey::EchConfig => Self::EchConfig(EchConfig::read(&mut decoder)?),
            SvcParamKey::Ipv6Hint => Self::Ipv6Hint(IpHint::<Ipv6Addr>::read(&mut decoder)?),
            SvcParamKey::Key(_) | SvcParamKey::Key65535 | SvcParamKey::Unknown(_) => {
                Self::Unknown(Unknown::read(&mut decoder)?)
            }
        };

        Ok(value)
    }
}
impl BinEncodable for SvcParamValue {
    /// Writes a u16 length followed by the value bytes.
    ///
    /// The length slot is reserved first and back-filled once the value has
    /// been written, so each value type only needs to emit its own bytes.
    fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
        let length_slot = encoder.place::<u16>()?;

        match *self {
            Self::Mandatory(ref keys) => keys.emit(encoder)?,
            Self::Alpn(ref protocols) => protocols.emit(encoder)?,
            Self::NoDefaultAlpn => {}
            Self::Port(port) => encoder.emit_u16(port)?,
            Self::Ipv4Hint(ref hints) => hints.emit(encoder)?,
            Self::EchConfig(ref config) => config.emit(encoder)?,
            Self::Ipv6Hint(ref hints) => hints.emit(encoder)?,
            Self::Unknown(ref raw) => raw.emit(encoder)?,
        }

        let written = encoder.len_since_place(&length_slot);
        let len = u16::try_from(written)
            .map_err(|_| ProtoError::from("Total length of SvcParamValue exceeds u16::MAX"))?;
        length_slot.replace(encoder, len)?;
        Ok(())
    }
}
impl fmt::Display for SvcParamValue {
    /// Writes the presentation form of the value; `NoDefaultAlpn` writes nothing.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        match self {
            Self::NoDefaultAlpn => Ok(()),
            Self::Mandatory(inner) => write!(f, "{}", inner),
            Self::Alpn(inner) => write!(f, "{}", inner),
            Self::Port(inner) => write!(f, "{}", inner),
            Self::Ipv4Hint(inner) => write!(f, "{}", inner),
            Self::EchConfig(inner) => write!(f, "{}", inner),
            Self::Ipv6Hint(inner) => write!(f, "{}", inner),
            Self::Unknown(inner) => write!(f, "{}", inner),
        }
    }
}
/// Value of the `mandatory` parameter: a list of keys, encoded as consecutive
/// u16 key numbers. Both decoding and encoding reject an empty list.
#[cfg_attr(feature = "serde-config", derive(Deserialize, Serialize))]
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
#[repr(transparent)]
pub struct Mandatory(pub Vec<SvcParamKey>);
impl<'r> BinDecodable<'r> for Mandatory {
    /// Reads u16 keys until the decoder has no bytes left.
    ///
    /// # Errors
    ///
    /// Fails if a key cannot be read or if no key is present at all.
    fn read(decoder: &mut BinDecoder<'r>) -> ProtoResult<Self> {
        let mut required = Vec::with_capacity(1);
        loop {
            if decoder.peek().is_none() {
                break;
            }
            let key = SvcParamKey::read(decoder)?;
            required.push(key);
        }

        if required.is_empty() {
            Err(ProtoError::from("Mandatory expects at least one value"))
        } else {
            Ok(Self(required))
        }
    }
}
impl BinEncodable for Mandatory {
    /// Writes each key as its u16 number.
    ///
    /// # Errors
    ///
    /// Fails if the list is empty (mirroring `Mandatory::read`) or a key cannot
    /// be written.
    fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
        if self.0.is_empty() {
            // The message previously said "Alpn", a copy-paste error that
            // pointed callers at the wrong parameter.
            return Err(ProtoError::from("Mandatory expects at least one value"));
        }
        for key in self.0.iter() {
            key.emit(encoder)?;
        }
        Ok(())
    }
}
impl fmt::Display for Mandatory {
    /// Writes the keys as a comma-separated list with no trailing comma, as
    /// required by the RFC 9460 value-list syntax (Appendix A.1).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        for (i, key) in self.0.iter().enumerate() {
            if i > 0 {
                f.write_str(",")?;
            }
            write!(f, "{}", key)?;
        }
        Ok(())
    }
}
/// Value of the `alpn` parameter: protocol identifiers, each encoded as a
/// length-prefixed character-string and decoded as UTF-8. Both decoding and
/// encoding reject an empty list.
#[cfg_attr(feature = "serde-config", derive(Deserialize, Serialize))]
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
#[repr(transparent)]
pub struct Alpn(pub Vec<String>);
impl<'r> BinDecodable<'r> for Alpn {
    /// Reads character-strings until the decoder has no bytes left.
    ///
    /// # Errors
    ///
    /// Fails on a malformed character-string, on non-UTF-8 data, or if no
    /// protocol identifier is present.
    fn read(decoder: &mut BinDecoder<'r>) -> ProtoResult<Self> {
        let mut protocols = Vec::with_capacity(1);
        while decoder.peek().is_some() {
            let raw = decoder.read_character_data()?.unverified(/*checked as UTF-8 below*/);
            protocols.push(String::from_utf8(raw.to_vec())?);
        }

        if protocols.is_empty() {
            Err(ProtoError::from("Alpn expects at least one value"))
        } else {
            Ok(Self(protocols))
        }
    }
}
impl BinEncodable for Alpn {
    /// Writes each protocol identifier as a character-string.
    ///
    /// # Errors
    ///
    /// Fails for an empty list, or if an identifier cannot be encoded.
    fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
        let protocols = &self.0;
        if protocols.is_empty() {
            return Err(ProtoError::from("Alpn expects at least one value"));
        }

        for protocol in protocols {
            encoder.emit_character_data(protocol)?;
        }
        Ok(())
    }
}
impl fmt::Display for Alpn {
    /// Writes the protocol identifiers as a comma-separated list with no
    /// trailing comma, as required by the RFC 9460 value-list syntax
    /// (Appendix A.1).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        for (i, alpn) in self.0.iter().enumerate() {
            if i > 0 {
                f.write_str(",")?;
            }
            write!(f, "{}", alpn)?;
        }
        Ok(())
    }
}
/// Value of the `echconfig` parameter: opaque bytes, encoded with a u16 length
/// prefix. `Display` and `Debug` show the bytes as base64.
#[cfg_attr(feature = "serde-config", derive(Deserialize, Serialize))]
#[derive(PartialEq, Eq, Hash, Clone)]
#[repr(transparent)]
pub struct EchConfig(pub Vec<u8>);
impl<'r> BinDecodable<'r> for EchConfig {
    /// Reads a u16 length followed by that many opaque bytes.
    ///
    /// The bytes are not validated here; interpreting them is left to the
    /// consumer.
    ///
    /// # Errors
    ///
    /// Fails if the embedded length is larger than the bytes remaining in
    /// `decoder`.
    fn read(decoder: &mut BinDecoder<'r>) -> ProtoResult<Self> {
        // The error used to blame u16::MAX, but a u16 length can never exceed
        // that; the real failure is a length that overruns the available data.
        let redundant_len = decoder
            .read_u16()?
            .map(|len| len as usize)
            .verify_unwrap(|len| *len <= decoder.len())
            .map_err(|len| {
                ProtoError::from(format!(
                    "length of ECH value ({}) exceeds remaining data ({})",
                    len,
                    decoder.len()
                ))
            })?;

        let data = decoder.read_vec(redundant_len)?.unverified(/*up to consumer to validate*/);
        Ok(Self(data))
    }
}
impl BinEncodable for EchConfig {
    /// Writes a u16 length followed by the raw bytes.
    ///
    /// # Errors
    ///
    /// Fails if the data is longer than `u16::MAX` bytes.
    fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
        let bytes = &self.0;
        match u16::try_from(bytes.len()) {
            Ok(len) => {
                encoder.emit_u16(len)?;
                encoder.emit_vec(bytes)?;
                Ok(())
            }
            Err(_) => Err(ProtoError::from(
                "ECH value length exceeds max size of u16::MAX",
            )),
        }
    }
}
impl fmt::Display for EchConfig {
    /// Writes the bytes as double-quoted standard base64.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        let encoded = data_encoding::BASE64.encode(&self.0);
        write!(f, "\"{}\"", encoded)
    }
}
impl fmt::Debug for EchConfig {
    /// Debug form: the base64 data wrapped as `"EchConfig (<base64>)"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        let encoded = data_encoding::BASE64.encode(&self.0);
        write!(f, "\"EchConfig ({})\"", encoded)
    }
}
/// Value of the `ipv4hint` / `ipv6hint` parameters: addresses encoded back to
/// back. The decoder in this file does not reject an empty list.
#[cfg_attr(feature = "serde-config", derive(Deserialize, Serialize))]
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
#[repr(transparent)]
pub struct IpHint<T>(pub Vec<T>);
impl<'r, T> BinDecodable<'r> for IpHint<T>
where
    T: BinDecodable<'r>,
{
    /// Reads addresses until the decoder has no bytes left; a trailing partial
    /// address makes `T::read` fail.
    fn read(decoder: &mut BinDecoder<'r>) -> ProtoResult<Self> {
        let mut hints = Vec::new();
        loop {
            if decoder.peek().is_none() {
                break;
            }
            let addr = T::read(decoder)?;
            hints.push(addr);
        }
        Ok(Self(hints))
    }
}
impl<T> BinEncodable for IpHint<T>
where
    T: BinEncodable,
{
    /// Writes every address in order, with no separators or count.
    fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
        self.0.iter().try_for_each(|addr| addr.emit(encoder))
    }
}
impl<T> fmt::Display for IpHint<T>
where
    T: fmt::Display,
{
    /// Writes the addresses as a comma-separated list with no trailing comma,
    /// as required by the RFC 9460 value-list syntax (Appendix A.1).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        for (i, ip) in self.0.iter().enumerate() {
            if i > 0 {
                f.write_str(",")?;
            }
            write!(f, "{}", ip)?;
        }
        Ok(())
    }
}
/// Raw value bytes for a parameter whose key has no dedicated decoder;
/// decoding takes every remaining byte of the value as-is.
#[cfg_attr(feature = "serde-config", derive(Deserialize, Serialize))]
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
#[repr(transparent)]
pub struct Unknown(pub Vec<u8>);
impl<'r> BinDecodable<'r> for Unknown {
fn read(decoder: &mut BinDecoder<'r>) -> ProtoResult<Self> {
let len = decoder.len();
let data = decoder.read_vec(len)?;
let unknowns = data.unverified().to_vec();
Ok(Self(unknowns))
}
}
impl BinEncodable for Unknown {
    /// Writes the raw value bytes, with no extra length prefix.
    ///
    /// The enclosing `SvcParamValue` already writes the value length, and
    /// `Unknown`'s decoder takes every remaining byte verbatim. Emitting a
    /// character-string here (1-byte length + data) would add a spurious
    /// leading byte on every round trip, and would fail for values longer than
    /// 255 bytes.
    fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
        encoder.emit_vec(&self.0)?;
        Ok(())
    }
}
impl fmt::Display for Unknown {
    /// Writes the bytes as a double-quoted string.
    ///
    /// Non-UTF-8 sequences are replaced lossily. A comma after the closing
    /// quote would make the presentation token malformed, so none is written.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        write!(f, "\"{}\"", String::from_utf8_lossy(&self.0))
    }
}
/// Decodes SVCB RDATA of `rdata_length` bytes starting at the decoder's position.
///
/// Reads the priority, then the target name, then key/value parameters for as
/// long as at least four RDATA bytes remain (a u16 key plus a u16 value length).
///
/// # Errors
///
/// Fails in any of these cases:
/// - a field cannot be decoded;
/// - decoding runs past `rdata_length`;
/// - the parameter keys are not strictly increasing.
pub fn read(decoder: &mut BinDecoder<'_>, rdata_length: Restrict<u16>) -> ProtoResult<SVCB> {
    let start_index = decoder.index();

    let svc_priority = decoder.read_u16()?.unverified(/*any u16 is valid*/);
    let target_name = Name::read(decoder)?;

    // RDATA bytes still unread after `consumed` bytes have been decoded.
    let unread = |consumed: usize| -> ProtoResult<usize> {
        Ok(rdata_length
            .map(|len| len as usize)
            .checked_sub(consumed)
            .map_err(|len| format!("Bad length for RDATA of SVCB: {}", len))?
            .unverified(/*non-negative by construction*/))
    };

    let mut svc_params: Vec<(SvcParamKey, SvcParamValue)> = Vec::new();
    while unread(decoder.index() - start_index)? >= 4 {
        let key = SvcParamKey::read(decoder)?;
        let value = SvcParamValue::read(key, decoder)?;

        let in_order = svc_params.last().map_or(true, |(prev, _)| *prev < key);
        if !in_order {
            return Err(ProtoError::from("SvcParams out of order"));
        }
        svc_params.push((key, value));
    }

    Ok(SVCB {
        svc_priority,
        target_name,
        svc_params,
    })
}
/// Encodes SVCB RDATA: the priority, the target name, then each parameter as
/// key followed by its length-prefixed value.
///
/// # Errors
///
/// Fails if any field cannot be encoded, or if the parameter keys are not
/// strictly increasing. The order check happens as parameters are written, so
/// earlier parameters may already be in the encoder when it fails.
pub fn emit(encoder: &mut BinEncoder<'_>, svcb: &SVCB) -> ProtoResult<()> {
    svcb.svc_priority.emit(encoder)?;
    svcb.target_name.emit(encoder)?;

    let mut previous: Option<SvcParamKey> = None;
    for (key, value) in &svcb.svc_params {
        if matches!(previous, Some(prev) if *key <= prev) {
            return Err(ProtoError::from("SvcParams out of order"));
        }
        key.emit(encoder)?;
        value.emit(encoder)?;
        previous = Some(*key);
    }

    Ok(())
}
impl fmt::Display for SVCB {
    /// Writes `<priority> <target>` followed by ` key=value` for each parameter.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        write!(f, "{} {}", self.svc_priority, self.target_name)?;
        self.svc_params
            .iter()
            .try_for_each(|(key, param)| write!(f, " {}={}", key, param))
    }
}
// Unit tests for the SVCB key mapping and the free `read` / `emit` functions.
#[cfg(test)]
mod tests {
use super::*;
// Numeric key -> SvcParamKey, including both ends of the 65280..=65534 range
// and the first value below it.
#[test]
fn read_svcb_key() {
assert_eq!(SvcParamKey::Mandatory, 0.into());
assert_eq!(SvcParamKey::Alpn, 1.into());
assert_eq!(SvcParamKey::NoDefaultAlpn, 2.into());
assert_eq!(SvcParamKey::Port, 3.into());
assert_eq!(SvcParamKey::Ipv4Hint, 4.into());
assert_eq!(SvcParamKey::EchConfig, 5.into());
assert_eq!(SvcParamKey::Ipv6Hint, 6.into());
assert_eq!(SvcParamKey::Key(65280), 65280.into());
assert_eq!(SvcParamKey::Key(65534), 65534.into());
assert_eq!(SvcParamKey::Key65535, 65535.into());
assert_eq!(SvcParamKey::Unknown(65279), 65279.into());
}
// SvcParamKey -> numeric key; the inverse of the test above.
#[test]
fn read_svcb_key_to_u16() {
assert_eq!(u16::from(SvcParamKey::Mandatory), 0);
assert_eq!(u16::from(SvcParamKey::Alpn), 1);
assert_eq!(u16::from(SvcParamKey::NoDefaultAlpn), 2);
assert_eq!(u16::from(SvcParamKey::Port), 3);
assert_eq!(u16::from(SvcParamKey::Ipv4Hint), 4);
assert_eq!(u16::from(SvcParamKey::EchConfig), 5);
assert_eq!(u16::from(SvcParamKey::Ipv6Hint), 6);
assert_eq!(u16::from(SvcParamKey::Key(65280)), 65280);
assert_eq!(u16::from(SvcParamKey::Key(65534)), 65534);
assert_eq!(u16::from(SvcParamKey::Key65535), 65535);
assert_eq!(u16::from(SvcParamKey::Unknown(65279)), 65279);
}
// Helper: emit `rdata`, read it back using the emitted length as the RDATA
// length, and require the decoded value to equal the original.
#[track_caller]
fn test_encode_decode(rdata: SVCB) {
let mut bytes = Vec::new();
let mut encoder: BinEncoder<'_> = BinEncoder::new(&mut bytes);
emit(&mut encoder, &rdata).expect("failed to emit SVCB");
let bytes = encoder.into_bytes();
println!("svcb: {}", rdata);
println!("bytes: {:?}", bytes);
let mut decoder: BinDecoder<'_> = BinDecoder::new(bytes);
let read_rdata =
read(&mut decoder, Restrict::new(bytes.len() as u16)).expect("failed to read back");
assert_eq!(rdata, read_rdata);
}
// Round trips: no params, a single alpn param, and mandatory + alpn in key order.
#[test]
fn test_encode_decode_svcb() {
test_encode_decode(SVCB::new(
0,
Name::from_utf8("www.example.com.").unwrap(),
vec![],
));
test_encode_decode(SVCB::new(
0,
Name::from_utf8(".").unwrap(),
vec![(
SvcParamKey::Alpn,
SvcParamValue::Alpn(Alpn(vec!["h2".to_string()])),
)],
));
test_encode_decode(SVCB::new(
0,
Name::from_utf8("example.com.").unwrap(),
vec![
(
SvcParamKey::Mandatory,
SvcParamValue::Mandatory(Mandatory(vec![SvcParamKey::Alpn])),
),
(
SvcParamKey::Alpn,
SvcParamValue::Alpn(Alpn(vec!["h2".to_string()])),
),
],
));
}
// Keys given in decreasing order (alpn=1 before mandatory=0) must make `emit`
// fail, which the helper turns into a panic.
#[test]
#[should_panic]
fn test_encode_decode_svcb_bad_order() {
test_encode_decode(SVCB::new(
0,
Name::from_utf8(".").unwrap(),
vec![
(
SvcParamKey::Alpn,
SvcParamValue::Alpn(Alpn(vec!["h2".to_string()])),
),
(
SvcParamKey::Mandatory,
SvcParamValue::Mandatory(Mandatory(vec![SvcParamKey::Alpn])),
),
],
));
}
// Regression input: parsing this malformed message must return an error
// instead of panicking.
#[test]
fn test_no_panic() {
const BUF: &[u8] = &[
255, 121, 0, 0, 0, 0, 40, 255, 255, 160, 160, 0, 0, 0, 64, 0, 1, 255, 158, 0, 0, 0, 8,
0, 0, 7, 7, 0, 0, 0, 0, 0, 0, 0,
];
assert!(crate::op::Message::from_vec(BUF).is_err());
}
}