proxy_protocol_codec/v2/codec/
decode.rs1use core::cmp::min;
4use core::net::{Ipv4Addr, Ipv6Addr};
5use core::num::NonZeroUsize;
6
7use slicur::Reader;
8
9use crate::v2::model::{
10 AddressPair, Command, ExtensionRef, Family, Protocol, ADDR_INET6_SIZE, ADDR_INET_SIZE, ADDR_UNIX_SIZE,
11 BYTE_VERSION, HEADER_SIZE,
12};
13use crate::v2::Header;
14
/// Stateless decoder for PROXY protocol v2 headers.
///
/// All functionality is exposed through the associated function
/// [`HeaderDecoder::decode`]; the type itself carries no data.
#[derive(Debug)]
pub struct HeaderDecoder;

21const MASK_HI: u8 = 0xF0;
24
25const MASK_LO: u8 = 0x0F;
28
29const COMMAND_LOCAL: u8 = Command::Local as u8;
31
32const COMMAND_PROXY: u8 = Command::Proxy as u8;
34
35const FAMILY_UNSPECIFIED: u8 = Family::Unspecified as u8;
37
38const FAMILY_INET: u8 = Family::Inet as u8;
40
41const FAMILY_INET6: u8 = Family::Inet6 as u8;
43
44const FAMILY_UNIX: u8 = Family::Unix as u8;
46
47const PROTOCOL_UNSPECIFIED: u8 = Protocol::Unspecified as u8;
49
50const PROTOCOL_STREAM: u8 = Protocol::Stream as u8;
52
53const PROTOCOL_DGRAM: u8 = Protocol::Dgram as u8;
55
56impl HeaderDecoder {
57 #[allow(clippy::missing_panics_doc)]
58 pub fn decode<'a>(buf: &'a [u8]) -> Result<Decoded<'a>, DecodeError> {
81 {
83 let magic_length = min(Header::MAGIC.len(), buf.len());
84
85 if buf[..magic_length] != Header::MAGIC[..magic_length] {
86 return Ok(Decoded::None);
87 }
88 }
89
90 match HEADER_SIZE.checked_sub(buf.len()).and_then(NonZeroUsize::new) {
92 None => {}
93 Some(remaining_bytes) => {
94 #[cfg(feature = "feat-nightly")]
96 core::hint::cold_path();
97
98 return Ok(Decoded::Partial(remaining_bytes));
99 }
100 }
101
102 match buf[12] & MASK_HI {
104 BYTE_VERSION => {}
105 v => {
106 #[cfg(feature = "feat-nightly")]
107 core::hint::cold_path();
108
109 return Err(DecodeError::InvalidVersion(v));
110 }
111 };
112
113 let command = match buf[12] & MASK_LO {
115 COMMAND_LOCAL => Command::Local,
116 COMMAND_PROXY => Command::Proxy,
117 c => {
118 #[cfg(feature = "feat-nightly")]
119 core::hint::cold_path();
120
121 return Err(DecodeError::InvalidCommand(c));
122 }
123 };
124
125 let addr_family = match buf[13] & MASK_HI {
127 FAMILY_UNSPECIFIED => Family::Unspecified,
128 FAMILY_INET => Family::Inet,
129 FAMILY_INET6 => Family::Inet6,
130 FAMILY_UNIX => Family::Unix,
131 f => {
132 #[cfg(feature = "feat-nightly")]
133 core::hint::cold_path();
134
135 return Err(DecodeError::InvalidFamily(f));
136 }
137 };
138
139 let protocol = match buf[13] & MASK_LO {
141 PROTOCOL_UNSPECIFIED => Protocol::Unspecified,
142 PROTOCOL_STREAM => Protocol::Stream,
143 PROTOCOL_DGRAM => Protocol::Dgram,
144 p => {
145 #[cfg(feature = "feat-nightly")]
146 core::hint::cold_path();
147
148 return Err(DecodeError::InvalidProtocol(p));
149 }
150 };
151
152 let remaining_len = u16::from_be_bytes([buf[14], buf[15]]);
154
155 let payload = match HEADER_SIZE
157 .checked_add(remaining_len as usize)
158 .ok_or(DecodeError::MalformedData)?
159 .checked_sub(buf.len())
160 .map(NonZeroUsize::new)
161 {
162 Some(None) => &buf[HEADER_SIZE..],
163 Some(Some(remaining_bytes)) => return Ok(Decoded::Partial(remaining_bytes)),
164 None => {
165 #[cfg(feature = "feat-nightly")]
166 core::hint::cold_path();
167
168 return Err(DecodeError::TrailingData);
170 }
171 };
172
173 let (address_pair, extensions) = match addr_family {
174 Family::Unspecified => (AddressPair::Unspecified, payload),
175 Family::Inet => {
176 if payload.len() < ADDR_INET_SIZE {
177 #[cfg(feature = "feat-nightly")]
178 core::hint::cold_path();
179
180 return Err(DecodeError::MalformedData);
181 }
182
183 (
184 AddressPair::Inet {
185 src_ip: Ipv4Addr::from(TryInto::<[u8; 4]>::try_into(&payload[0..4]).unwrap()),
186 dst_ip: Ipv4Addr::from(TryInto::<[u8; 4]>::try_into(&payload[4..8]).unwrap()),
187 src_port: u16::from_be_bytes([payload[8], payload[9]]),
188 dst_port: u16::from_be_bytes([payload[10], payload[11]]),
189 },
190 &payload[ADDR_INET_SIZE..],
191 )
192 }
193 Family::Inet6 => {
194 if payload.len() < ADDR_INET6_SIZE {
195 #[cfg(feature = "feat-nightly")]
196 core::hint::cold_path();
197
198 return Err(DecodeError::MalformedData);
199 }
200
201 (
202 AddressPair::Inet6 {
203 src_ip: Ipv6Addr::from(TryInto::<[u8; 16]>::try_into(&payload[0..16]).unwrap()),
204 dst_ip: Ipv6Addr::from(TryInto::<[u8; 16]>::try_into(&payload[16..32]).unwrap()),
205 src_port: u16::from_be_bytes([payload[32], payload[33]]),
206 dst_port: u16::from_be_bytes([payload[34], payload[35]]),
207 },
208 &payload[ADDR_INET6_SIZE..],
209 )
210 }
211 Family::Unix => {
212 if payload.len() < ADDR_UNIX_SIZE {
213 #[cfg(feature = "feat-nightly")]
214 core::hint::cold_path();
215
216 return Err(DecodeError::MalformedData);
217 }
218
219 (
220 AddressPair::Unix {
221 src_addr: payload[0..108].try_into().unwrap(),
222 dst_addr: payload[108..216].try_into().unwrap(),
223 },
224 &payload[ADDR_UNIX_SIZE..],
225 )
226 }
227 };
228
229 match command {
230 Command::Local => Ok(Decoded::Some(DecodedHeader {
231 header: Header::new_local(),
232 extensions: DecodedExtensions::const_from(extensions),
233 })),
234 Command::Proxy => Ok(Decoded::Some(DecodedHeader {
235 header: Header::new_proxy(protocol, address_pair),
236 extensions: DecodedExtensions::const_from(extensions),
237 })),
238 }
239 }
240}
241
242#[allow(clippy::large_enum_variant)]
243#[derive(Debug)]
244pub enum Decoded<'a> {
246 Some(DecodedHeader<'a>),
248
249 Partial(NonZeroUsize),
251
252 None,
254}
255
256#[derive(Debug)]
257pub struct DecodedHeader<'a> {
259 pub header: Header,
261
262 pub extensions: DecodedExtensions<'a>,
264}
265
266wrapper_lite::wrapper! {
267 #[wrapper_impl(Deref)]
268 #[derive(Debug)]
269 pub DecodedExtensions<'a>(&'a [u8])
275}
276
277impl<'a> IntoIterator for DecodedExtensions<'a> {
278 type IntoIter = DecodedExtensionsIter<'a>;
279 type Item = Result<ExtensionRef<'a>, DecodeError>;
280
281 fn into_iter(self) -> Self::IntoIter {
282 DecodedExtensionsIter {
283 inner: Reader::init(self.inner),
284 }
285 }
286}
287
288#[derive(Debug)]
289pub struct DecodedExtensionsIter<'a> {
291 inner: Reader<'a>,
292}
293
294impl<'a> Iterator for DecodedExtensionsIter<'a> {
295 type Item = Result<ExtensionRef<'a>, DecodeError>;
296
297 fn next(&mut self) -> Option<Self::Item> {
298 ExtensionRef::decode(&mut self.inner).transpose()
299 }
300}
301
302#[derive(Debug)]
303#[derive(thiserror::Error)]
304pub enum DecodeError {
306 #[error("Invalid PROXY Protocol version: {0}")]
307 InvalidVersion(u8),
309
310 #[error("Invalid PROXY Protocol command: {0}")]
311 InvalidCommand(u8),
313
314 #[error("Invalid proxy address family: {0}")]
315 InvalidFamily(u8),
317
318 #[error("Invalid proxy transport protocol: {0}")]
319 InvalidProtocol(u8),
321
322 #[error("Trailing data after the header")]
323 TrailingData,
325
326 #[error("Malformed data")]
327 MalformedData,
330}