#[cfg(feature = "feat-alloc")]
use alloc::vec::Vec;
use core::cmp::min;
use core::iter::FusedIterator;
use core::net::{Ipv4Addr, Ipv6Addr};
use core::num::NonZeroUsize;
use slicur::Reader;
use crate::v2::model::{
AddressPair, Command, ExtensionRef, Family, Protocol, ADDR_INET6_SIZE, ADDR_INET_SIZE, ADDR_UNIX_SIZE,
BYTE_VERSION, HEADER_SIZE,
};
use crate::v2::Header;
/// Stateless entry point for decoding PROXY protocol v2 headers.
///
/// The type carries no data; all functionality is exposed through the
/// associated function [`HeaderDecoder::decode`].
#[derive(Debug)]
pub struct HeaderDecoder;
// Bytes 12 and 13 of the header each pack two 4-bit fields. The masks below
// isolate one half of such a byte *without shifting it*, so the constants
// that follow are compared against the masked value as-is.

/// Keeps the upper four bits of a byte (version in byte 12, address family in byte 13).
const MASK_HI: u8 = 0xF0;
/// Keeps the lower four bits of a byte (command in byte 12, transport protocol in byte 13).
const MASK_LO: u8 = 0x0F;

// `u8` views of the enum discriminants, so they can be used as `match` patterns
// against the masked header bytes.

/// Discriminant of [`Command::Local`], matched against the low nibble of byte 12.
const COMMAND_LOCAL: u8 = Command::Local as u8;
/// Discriminant of [`Command::Proxy`], matched against the low nibble of byte 12.
const COMMAND_PROXY: u8 = Command::Proxy as u8;
/// Discriminant of [`Family::Unspecified`], matched against the masked high nibble of byte 13.
const FAMILY_UNSPECIFIED: u8 = Family::Unspecified as u8;
/// Discriminant of [`Family::Inet`], matched against the masked high nibble of byte 13.
const FAMILY_INET: u8 = Family::Inet as u8;
/// Discriminant of [`Family::Inet6`], matched against the masked high nibble of byte 13.
const FAMILY_INET6: u8 = Family::Inet6 as u8;
/// Discriminant of [`Family::Unix`], matched against the masked high nibble of byte 13.
const FAMILY_UNIX: u8 = Family::Unix as u8;
/// Discriminant of [`Protocol::Unspecified`], matched against the low nibble of byte 13.
const PROTOCOL_UNSPECIFIED: u8 = Protocol::Unspecified as u8;
/// Discriminant of [`Protocol::Stream`], matched against the low nibble of byte 13.
const PROTOCOL_STREAM: u8 = Protocol::Stream as u8;
/// Discriminant of [`Protocol::Dgram`], matched against the low nibble of byte 13.
const PROTOCOL_DGRAM: u8 = Protocol::Dgram as u8;
impl HeaderDecoder {
#[allow(clippy::missing_panics_doc, reason = "XXX")]
#[allow(clippy::too_many_lines, reason = "XXX")]
pub fn decode(buf: &[u8]) -> Result<Decoded<'_>, DecodeError> {
{
let magic_length = min(Header::MAGIC.len(), buf.len());
if buf[..magic_length] != Header::MAGIC[..magic_length] {
return Ok(Decoded::None);
}
}
match HEADER_SIZE.checked_sub(buf.len()).and_then(NonZeroUsize::new) {
None => {}
Some(remaining_bytes) => {
return Ok(Decoded::Partial(remaining_bytes));
}
}
match buf[12] & MASK_HI {
BYTE_VERSION => {}
v => {
return Err(DecodeError::InvalidVersion(v));
}
};
let command = match buf[12] & MASK_LO {
COMMAND_LOCAL => Command::Local,
COMMAND_PROXY => Command::Proxy,
c => {
return Err(DecodeError::InvalidCommand(c));
}
};
let addr_family = match buf[13] & MASK_HI {
FAMILY_UNSPECIFIED => Family::Unspecified,
FAMILY_INET => Family::Inet,
FAMILY_INET6 => Family::Inet6,
FAMILY_UNIX => Family::Unix,
f => {
return Err(DecodeError::InvalidFamily(f));
}
};
let protocol = match buf[13] & MASK_LO {
PROTOCOL_UNSPECIFIED => Protocol::Unspecified,
PROTOCOL_STREAM => Protocol::Stream,
PROTOCOL_DGRAM => Protocol::Dgram,
p => {
return Err(DecodeError::InvalidProtocol(p));
}
};
let remaining_len = u16::from_be_bytes([buf[14], buf[15]]);
let payload = match HEADER_SIZE
.checked_add(remaining_len as usize)
.ok_or(DecodeError::MalformedData)?
.checked_sub(buf.len())
.map(NonZeroUsize::new)
{
Some(None) => &buf[HEADER_SIZE..],
Some(Some(remaining_bytes)) => return Ok(Decoded::Partial(remaining_bytes)),
None => {
return Err(DecodeError::TrailingData);
}
};
let (address_pair, extensions) = match addr_family {
Family::Unspecified => (AddressPair::Unspecified, payload),
Family::Inet => {
if payload.len() < ADDR_INET_SIZE {
return Err(DecodeError::MalformedData);
}
(
AddressPair::Inet {
src_ip: Ipv4Addr::from(TryInto::<[u8; 4]>::try_into(&payload[0..4]).unwrap()),
dst_ip: Ipv4Addr::from(TryInto::<[u8; 4]>::try_into(&payload[4..8]).unwrap()),
src_port: u16::from_be_bytes([payload[8], payload[9]]),
dst_port: u16::from_be_bytes([payload[10], payload[11]]),
},
&payload[ADDR_INET_SIZE..],
)
}
Family::Inet6 => {
if payload.len() < ADDR_INET6_SIZE {
return Err(DecodeError::MalformedData);
}
(
AddressPair::Inet6 {
src_ip: Ipv6Addr::from(TryInto::<[u8; 16]>::try_into(&payload[0..16]).unwrap()),
dst_ip: Ipv6Addr::from(TryInto::<[u8; 16]>::try_into(&payload[16..32]).unwrap()),
src_port: u16::from_be_bytes([payload[32], payload[33]]),
dst_port: u16::from_be_bytes([payload[34], payload[35]]),
},
&payload[ADDR_INET6_SIZE..],
)
}
Family::Unix => {
if payload.len() < ADDR_UNIX_SIZE {
return Err(DecodeError::MalformedData);
}
(
AddressPair::Unix {
src_addr: payload[0..108].try_into().unwrap(),
dst_addr: payload[108..216].try_into().unwrap(),
},
&payload[ADDR_UNIX_SIZE..],
)
}
};
match command {
Command::Local => Ok(Decoded::Some(DecodedHeader {
header: Header::new_local(),
extensions: DecodedExtensions::const_from(extensions),
})),
Command::Proxy => Ok(Decoded::Some(DecodedHeader {
header: Header::new_proxy(protocol, address_pair),
extensions: DecodedExtensions::const_from(extensions),
})),
}
}
}
/// Outcome of a successful call to [`HeaderDecoder::decode`].
#[allow(clippy::large_enum_variant, reason = "XXX")]
#[derive(Debug)]
pub enum Decoded<'a> {
/// A complete header was decoded; the value borrows from the input buffer.
Some(DecodedHeader<'a>),
/// The input looks like a header so far but is incomplete. Holds the
/// (non-zero) number of additional bytes required: up to the end of the
/// fixed-size part while that part is incomplete, otherwise up to the end of
/// the payload length announced by the header.
Partial(NonZeroUsize),
/// The available bytes do not match the start of the protocol signature, so
/// the input does not contain a header.
None,
}
/// A fully decoded header together with its still-undecoded trailing bytes.
#[derive(Debug)]
pub struct DecodedHeader<'a> {
/// The decoded header, built with `Header::new_local` or `Header::new_proxy`
/// depending on the command found on the wire.
pub header: Header,
/// The payload bytes that follow the address block. They are decoded lazily
/// by iterating over this value.
pub extensions: DecodedExtensions<'a>,
}
// Borrowed view over the raw bytes that follow the address block of a decoded
// header. Nothing is parsed until the value is iterated.
//
// Code in this file relies on two things that are not declared here and are
// presumably generated by `wrapper_lite::wrapper!` (confirm against the macro
// documentation): the `const_from` constructor used by `HeaderDecoder::decode`
// and the `inner` field read by the `IntoIterator` implementation.
wrapper_lite::wrapper! {
#[wrapper_impl(AsRef<[u8]>)]
#[derive(Debug)]
pub struct DecodedExtensions<'a>(&'a [u8]);
}
impl<'a> DecodedExtensions<'a> {
    /// Decodes every extension and gathers them into a `Vec`.
    ///
    /// Only available with the `feat-alloc` feature.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error produced while decoding an
    /// extension; extensions decoded before that point are dropped.
    #[cfg(feature = "feat-alloc")]
    pub fn collect(self) -> Result<Vec<ExtensionRef<'a>>, DecodeError> {
        let mut decoded = Vec::new();
        for extension in self {
            decoded.push(extension?);
        }
        Ok(decoded)
    }
}
impl<'a> IntoIterator for DecodedExtensions<'a> {
    type Item = Result<ExtensionRef<'a>, DecodeError>;
    type IntoIter = DecodedExtensionsIter<'a>;

    /// Starts decoding the wrapped bytes from their beginning.
    fn into_iter(self) -> Self::IntoIter {
        let reader = Reader::init(self.inner);
        DecodedExtensionsIter { inner: Some(reader) }
    }
}
/// Iterator that decodes one extension per call to `next`.
#[derive(Debug)]
pub struct DecodedExtensionsIter<'a> {
// Cursor over the remaining extension bytes. Set to `None` once decoding has
// reported the end of the data or an error, after which the iterator only
// yields `None`.
inner: Option<Reader<'a>>,
}
impl<'a> Iterator for DecodedExtensionsIter<'a> {
    type Item = Result<ExtensionRef<'a>, DecodeError>;

    /// Decodes and returns the next extension.
    ///
    /// Yields `Some(Err(_))` once if decoding fails, and `None` from then on;
    /// also yields `None` forever after the decoder reports the end of data.
    fn next(&mut self) -> Option<Self::Item> {
        // Already finished: nothing left to read.
        let reader = self.inner.as_mut()?;

        // `Ok(Some(x))` -> `Some(Ok(x))`, `Ok(None)` -> `None`, `Err(e)` -> `Some(Err(e))`.
        let item = ExtensionRef::decode(reader).transpose();

        // Anything other than a successfully decoded extension ends the iteration.
        if !matches!(item, Some(Ok(_))) {
            self.inner = None;
        }
        item
    }
}
// `next` drops its reader before returning `None` (and after returning an
// error), and returns `None` whenever the reader is absent, so the iterator
// keeps returning `None` once it has done so.
impl FusedIterator for DecodedExtensionsIter<'_> {}
/// Errors reported while decoding a header or its extensions.
#[allow(clippy::module_name_repetitions, reason = "XXX")]
#[derive(Debug)]
#[derive(thiserror::Error)]
pub enum DecodeError {
/// The version field (upper four bits of byte 12) is not the supported
/// version. Carries the masked, unshifted field value.
#[error("Invalid PROXY Protocol version: {0}")]
InvalidVersion(u8),
/// The command field (lower four bits of byte 12) is not a known command.
/// Carries the field value.
#[error("Invalid PROXY Protocol command: {0}")]
InvalidCommand(u8),
/// The address family field (upper four bits of byte 13) is not a known
/// family. Carries the masked, unshifted field value.
#[error("Invalid proxy address family: {0}")]
InvalidFamily(u8),
/// The transport protocol field (lower four bits of byte 13) is not a known
/// protocol. Carries the field value.
#[error("Invalid proxy transport protocol: {0}")]
InvalidProtocol(u8),
/// The buffer holds more bytes than the fixed-size part plus the payload
/// length announced by the header.
#[error("Trailing data after the header")]
TrailingData,
/// The header is structurally invalid: the announced payload is too short
/// for the address block of the announced family, or the total length does
/// not fit in `usize`.
// NOTE(review): extension decoding returns this error type as well and may
// use this variant too; confirm against `ExtensionRef::decode`.
#[error("Malformed data")]
MalformedData,
}