proxy_protocol_codec/v2/codec/
decode.rs1#[cfg(feature = "feat-alloc")]
4use alloc::vec::Vec;
5use core::cmp::min;
6use core::iter::FusedIterator;
7use core::net::{Ipv4Addr, Ipv6Addr};
8use core::num::NonZeroUsize;
9
10use slicur::Reader;
11
12use crate::v2::model::{
13 AddressPair, Command, ExtensionRef, Family, Protocol, ADDR_INET6_SIZE, ADDR_INET_SIZE, ADDR_UNIX_SIZE,
14 BYTE_VERSION, HEADER_SIZE,
15};
16use crate::v2::Header;
17
18#[derive(Debug)]
19pub struct HeaderDecoder;
23
24const MASK_HI: u8 = 0xF0;
27
28const MASK_LO: u8 = 0x0F;
31
32const COMMAND_LOCAL: u8 = Command::Local as u8;
34
35const COMMAND_PROXY: u8 = Command::Proxy as u8;
37
38const FAMILY_UNSPECIFIED: u8 = Family::Unspecified as u8;
40
41const FAMILY_INET: u8 = Family::Inet as u8;
43
44const FAMILY_INET6: u8 = Family::Inet6 as u8;
46
47const FAMILY_UNIX: u8 = Family::Unix as u8;
49
50const PROTOCOL_UNSPECIFIED: u8 = Protocol::Unspecified as u8;
52
53const PROTOCOL_STREAM: u8 = Protocol::Stream as u8;
55
56const PROTOCOL_DGRAM: u8 = Protocol::Dgram as u8;
58
59impl HeaderDecoder {
60 #[allow(clippy::missing_panics_doc)]
61 pub fn decode<'a>(buf: &'a [u8]) -> Result<Decoded<'a>, DecodeError> {
87 {
89 let magic_length = min(Header::MAGIC.len(), buf.len());
90
91 if buf[..magic_length] != Header::MAGIC[..magic_length] {
92 return Ok(Decoded::None);
93 }
94 }
95
96 match HEADER_SIZE.checked_sub(buf.len()).and_then(NonZeroUsize::new) {
98 None => {}
99 Some(remaining_bytes) => {
100 #[cfg(feature = "feat-nightly")]
102 core::hint::cold_path();
103
104 return Ok(Decoded::Partial(remaining_bytes));
105 }
106 }
107
108 match buf[12] & MASK_HI {
110 BYTE_VERSION => {}
111 v => {
112 #[cfg(feature = "feat-nightly")]
113 core::hint::cold_path();
114
115 return Err(DecodeError::InvalidVersion(v));
116 }
117 };
118
119 let command = match buf[12] & MASK_LO {
121 COMMAND_LOCAL => Command::Local,
122 COMMAND_PROXY => Command::Proxy,
123 c => {
124 #[cfg(feature = "feat-nightly")]
125 core::hint::cold_path();
126
127 return Err(DecodeError::InvalidCommand(c));
128 }
129 };
130
131 let addr_family = match buf[13] & MASK_HI {
133 FAMILY_UNSPECIFIED => Family::Unspecified,
134 FAMILY_INET => Family::Inet,
135 FAMILY_INET6 => Family::Inet6,
136 FAMILY_UNIX => Family::Unix,
137 f => {
138 #[cfg(feature = "feat-nightly")]
139 core::hint::cold_path();
140
141 return Err(DecodeError::InvalidFamily(f));
142 }
143 };
144
145 let protocol = match buf[13] & MASK_LO {
147 PROTOCOL_UNSPECIFIED => Protocol::Unspecified,
148 PROTOCOL_STREAM => Protocol::Stream,
149 PROTOCOL_DGRAM => Protocol::Dgram,
150 p => {
151 #[cfg(feature = "feat-nightly")]
152 core::hint::cold_path();
153
154 return Err(DecodeError::InvalidProtocol(p));
155 }
156 };
157
158 let remaining_len = u16::from_be_bytes([buf[14], buf[15]]);
160
161 let payload = match HEADER_SIZE
163 .checked_add(remaining_len as usize)
164 .ok_or(DecodeError::MalformedData)?
165 .checked_sub(buf.len())
166 .map(NonZeroUsize::new)
167 {
168 Some(None) => &buf[HEADER_SIZE..],
169 Some(Some(remaining_bytes)) => return Ok(Decoded::Partial(remaining_bytes)),
170 None => {
171 #[cfg(feature = "feat-nightly")]
172 core::hint::cold_path();
173
174 return Err(DecodeError::TrailingData);
176 }
177 };
178
179 let (address_pair, extensions) = match addr_family {
180 Family::Unspecified => (AddressPair::Unspecified, payload),
181 Family::Inet => {
182 if payload.len() < ADDR_INET_SIZE {
183 #[cfg(feature = "feat-nightly")]
184 core::hint::cold_path();
185
186 return Err(DecodeError::MalformedData);
187 }
188
189 (
190 AddressPair::Inet {
191 src_ip: Ipv4Addr::from(TryInto::<[u8; 4]>::try_into(&payload[0..4]).unwrap()),
192 dst_ip: Ipv4Addr::from(TryInto::<[u8; 4]>::try_into(&payload[4..8]).unwrap()),
193 src_port: u16::from_be_bytes([payload[8], payload[9]]),
194 dst_port: u16::from_be_bytes([payload[10], payload[11]]),
195 },
196 &payload[ADDR_INET_SIZE..],
197 )
198 }
199 Family::Inet6 => {
200 if payload.len() < ADDR_INET6_SIZE {
201 #[cfg(feature = "feat-nightly")]
202 core::hint::cold_path();
203
204 return Err(DecodeError::MalformedData);
205 }
206
207 (
208 AddressPair::Inet6 {
209 src_ip: Ipv6Addr::from(TryInto::<[u8; 16]>::try_into(&payload[0..16]).unwrap()),
210 dst_ip: Ipv6Addr::from(TryInto::<[u8; 16]>::try_into(&payload[16..32]).unwrap()),
211 src_port: u16::from_be_bytes([payload[32], payload[33]]),
212 dst_port: u16::from_be_bytes([payload[34], payload[35]]),
213 },
214 &payload[ADDR_INET6_SIZE..],
215 )
216 }
217 Family::Unix => {
218 if payload.len() < ADDR_UNIX_SIZE {
219 #[cfg(feature = "feat-nightly")]
220 core::hint::cold_path();
221
222 return Err(DecodeError::MalformedData);
223 }
224
225 (
226 AddressPair::Unix {
227 src_addr: payload[0..108].try_into().unwrap(),
228 dst_addr: payload[108..216].try_into().unwrap(),
229 },
230 &payload[ADDR_UNIX_SIZE..],
231 )
232 }
233 };
234
235 match command {
236 Command::Local => Ok(Decoded::Some(DecodedHeader {
237 header: Header::new_local(),
238 extensions: DecodedExtensions::const_from(extensions),
239 })),
240 Command::Proxy => Ok(Decoded::Some(DecodedHeader {
241 header: Header::new_proxy(protocol, address_pair),
242 extensions: DecodedExtensions::const_from(extensions),
243 })),
244 }
245 }
246}
247
248#[allow(clippy::large_enum_variant)]
249#[derive(Debug)]
250pub enum Decoded<'a> {
252 Some(DecodedHeader<'a>),
254
255 Partial(NonZeroUsize),
257
258 None,
260}
261
262#[derive(Debug)]
263pub struct DecodedHeader<'a> {
265 pub header: Header,
267
268 pub extensions: DecodedExtensions<'a>,
270}
271
// Raw extension (TLV) bytes that follow the address block of a decoded header,
// borrowed from the buffer passed to `HeaderDecoder::decode`.
//
// NOTE(review): `wrapper_impl(Deref)` presumably makes the wrapper deref to the
// inner `&[u8]` (see the `wrapper_lite` docs). Iterate the wrapper (or call
// `collect`) to decode the individual extensions.
wrapper_lite::wrapper! {
    #[wrapper_impl(Deref)]
    #[derive(Debug)]
    pub DecodedExtensions<'a>(&'a [u8])
}
282
283impl<'a> DecodedExtensions<'a> {
284 #[cfg(feature = "feat-alloc")]
285 pub fn collect(self) -> Result<Vec<ExtensionRef<'a>>, DecodeError> {
288 self.into_iter().collect()
289 }
290}
291
292impl<'a> IntoIterator for DecodedExtensions<'a> {
293 type IntoIter = DecodedExtensionsIter<'a>;
294 type Item = Result<ExtensionRef<'a>, DecodeError>;
295
296 fn into_iter(self) -> Self::IntoIter {
297 DecodedExtensionsIter {
298 inner: Some(Reader::init(self.inner)),
299 }
300 }
301}
302
303#[derive(Debug)]
304pub struct DecodedExtensionsIter<'a> {
313 inner: Option<Reader<'a>>,
314}
315
316impl<'a> Iterator for DecodedExtensionsIter<'a> {
317 type Item = Result<ExtensionRef<'a>, DecodeError>;
318
319 fn next(&mut self) -> Option<Self::Item> {
320 match self.inner.as_mut() {
321 Some(reader) => match ExtensionRef::decode(reader) {
322 Ok(Some(extension)) => Some(Ok(extension)),
323 Ok(None) => {
324 self.inner = None;
326
327 None
328 }
329 Err(err) => {
330 self.inner = None;
332
333 Some(Err(err))
334 }
335 },
336 None => None,
337 }
338 }
339}
340
341impl<'a> FusedIterator for DecodedExtensionsIter<'a> {}
342
343#[derive(Debug)]
344#[derive(thiserror::Error)]
345pub enum DecodeError {
347 #[error("Invalid PROXY Protocol version: {0}")]
348 InvalidVersion(u8),
350
351 #[error("Invalid PROXY Protocol command: {0}")]
352 InvalidCommand(u8),
354
355 #[error("Invalid proxy address family: {0}")]
356 InvalidFamily(u8),
358
359 #[error("Invalid proxy transport protocol: {0}")]
360 InvalidProtocol(u8),
362
363 #[error("Trailing data after the header")]
364 TrailingData,
366
367 #[error("Malformed data")]
368 MalformedData,
371}