proxy_protocol_codec/v2/codec/
decode.rs1#[cfg(feature = "feat-alloc")]
4use alloc::vec::Vec;
5use core::cmp::min;
6use core::iter::FusedIterator;
7use core::net::{Ipv4Addr, Ipv6Addr};
8use core::num::NonZeroUsize;
9
10use slicur::Reader;
11
12use crate::v2::model::{
13 AddressPair, Command, ExtensionRef, Family, Protocol, ADDR_INET6_SIZE, ADDR_INET_SIZE, ADDR_UNIX_SIZE,
14 BYTE_VERSION, HEADER_SIZE,
15};
16use crate::v2::Header;
17
/// Stateless decoder for PROXY protocol v2 headers.
///
/// All functionality lives in the associated function
/// [`HeaderDecoder::decode`]; the type itself carries no data.
#[derive(Debug)]
pub struct HeaderDecoder;
/// Selects the high nibble of a header byte (version in byte 12, address
/// family in byte 13).
const MASK_HI: u8 = 0b1111_0000;

/// Selects the low nibble of a header byte (command in byte 12, transport
/// protocol in byte 13); the exact complement of [`MASK_HI`].
const MASK_LO: u8 = !MASK_HI;
32const COMMAND_LOCAL: u8 = Command::Local as u8;
34
35const COMMAND_PROXY: u8 = Command::Proxy as u8;
37
38const FAMILY_UNSPECIFIED: u8 = Family::Unspecified as u8;
40
41const FAMILY_INET: u8 = Family::Inet as u8;
43
44const FAMILY_INET6: u8 = Family::Inet6 as u8;
46
47const FAMILY_UNIX: u8 = Family::Unix as u8;
49
50const PROTOCOL_UNSPECIFIED: u8 = Protocol::Unspecified as u8;
52
53const PROTOCOL_STREAM: u8 = Protocol::Stream as u8;
55
56const PROTOCOL_DGRAM: u8 = Protocol::Dgram as u8;
58
59impl HeaderDecoder {
60 #[allow(clippy::missing_panics_doc, reason = "XXX")]
61 #[allow(clippy::too_many_lines, reason = "XXX")]
62 pub fn decode(buf: &[u8]) -> Result<Decoded<'_>, DecodeError> {
92 {
94 let magic_length = min(Header::MAGIC.len(), buf.len());
95
96 if buf[..magic_length] != Header::MAGIC[..magic_length] {
97 return Ok(Decoded::None);
98 }
99 }
100
101 match HEADER_SIZE.checked_sub(buf.len()).and_then(NonZeroUsize::new) {
103 None => {}
104 Some(remaining_bytes) => {
105 return Ok(Decoded::Partial(remaining_bytes));
108 }
109 }
110
111 match buf[12] & MASK_HI {
113 BYTE_VERSION => {}
114 v => {
115 return Err(DecodeError::InvalidVersion(v));
116 }
117 };
118
119 let command = match buf[12] & MASK_LO {
121 COMMAND_LOCAL => Command::Local,
122 COMMAND_PROXY => Command::Proxy,
123 c => {
124 return Err(DecodeError::InvalidCommand(c));
125 }
126 };
127
128 let addr_family = match buf[13] & MASK_HI {
130 FAMILY_UNSPECIFIED => Family::Unspecified,
131 FAMILY_INET => Family::Inet,
132 FAMILY_INET6 => Family::Inet6,
133 FAMILY_UNIX => Family::Unix,
134 f => {
135 return Err(DecodeError::InvalidFamily(f));
136 }
137 };
138
139 let protocol = match buf[13] & MASK_LO {
141 PROTOCOL_UNSPECIFIED => Protocol::Unspecified,
142 PROTOCOL_STREAM => Protocol::Stream,
143 PROTOCOL_DGRAM => Protocol::Dgram,
144 p => {
145 return Err(DecodeError::InvalidProtocol(p));
146 }
147 };
148
149 let remaining_len = u16::from_be_bytes([buf[14], buf[15]]);
151
152 let payload = match HEADER_SIZE
154 .checked_add(remaining_len as usize)
155 .ok_or(DecodeError::MalformedData)?
156 .checked_sub(buf.len())
157 .map(NonZeroUsize::new)
158 {
159 Some(None) => &buf[HEADER_SIZE..],
160 Some(Some(remaining_bytes)) => return Ok(Decoded::Partial(remaining_bytes)),
161 None => {
162 return Err(DecodeError::TrailingData);
164 }
165 };
166
167 let (address_pair, extensions) = match addr_family {
168 Family::Unspecified => (AddressPair::Unspecified, payload),
169 Family::Inet => {
170 if payload.len() < ADDR_INET_SIZE {
171 return Err(DecodeError::MalformedData);
172 }
173
174 (
175 AddressPair::Inet {
176 src_ip: Ipv4Addr::from(TryInto::<[u8; 4]>::try_into(&payload[0..4]).unwrap()),
177 dst_ip: Ipv4Addr::from(TryInto::<[u8; 4]>::try_into(&payload[4..8]).unwrap()),
178 src_port: u16::from_be_bytes([payload[8], payload[9]]),
179 dst_port: u16::from_be_bytes([payload[10], payload[11]]),
180 },
181 &payload[ADDR_INET_SIZE..],
182 )
183 }
184 Family::Inet6 => {
185 if payload.len() < ADDR_INET6_SIZE {
186 return Err(DecodeError::MalformedData);
187 }
188
189 (
190 AddressPair::Inet6 {
191 src_ip: Ipv6Addr::from(TryInto::<[u8; 16]>::try_into(&payload[0..16]).unwrap()),
192 dst_ip: Ipv6Addr::from(TryInto::<[u8; 16]>::try_into(&payload[16..32]).unwrap()),
193 src_port: u16::from_be_bytes([payload[32], payload[33]]),
194 dst_port: u16::from_be_bytes([payload[34], payload[35]]),
195 },
196 &payload[ADDR_INET6_SIZE..],
197 )
198 }
199 Family::Unix => {
200 if payload.len() < ADDR_UNIX_SIZE {
201 return Err(DecodeError::MalformedData);
202 }
203
204 (
205 AddressPair::Unix {
206 src_addr: payload[0..108].try_into().unwrap(),
207 dst_addr: payload[108..216].try_into().unwrap(),
208 },
209 &payload[ADDR_UNIX_SIZE..],
210 )
211 }
212 };
213
214 match command {
215 Command::Local => Ok(Decoded::Some(DecodedHeader {
216 header: Header::new_local(),
217 extensions: DecodedExtensions::const_from(extensions),
218 })),
219 Command::Proxy => Ok(Decoded::Some(DecodedHeader {
220 header: Header::new_proxy(protocol, address_pair),
221 extensions: DecodedExtensions::const_from(extensions),
222 })),
223 }
224 }
225}
226
227#[allow(clippy::large_enum_variant, reason = "XXX")]
228#[derive(Debug)]
229pub enum Decoded<'a> {
231 Some(DecodedHeader<'a>),
233
234 Partial(NonZeroUsize),
236
237 None,
239}
240
241#[derive(Debug)]
242pub struct DecodedHeader<'a> {
244 pub header: Header,
246
247 pub extensions: DecodedExtensions<'a>,
249}
250
// `DecodedExtensions` borrows the bytes that follow the address block of a
// decoded header. They are decoded lazily through `IntoIterator` (or
// `collect`, with the `feat-alloc` feature).
// NOTE(review): `wrapper_impl(AsRef<[u8]>)` presumably implements
// `AsRef<[u8]>` for access to the raw bytes; the macro also appears to
// generate the `inner` field and `const_from` constructor used in this file.
wrapper_lite::wrapper! {
    #[wrapper_impl(AsRef<[u8]>)]
    #[derive(Debug)]
    pub struct DecodedExtensions<'a>(&'a [u8]);
}
262impl<'a> DecodedExtensions<'a> {
263 #[cfg(feature = "feat-alloc")]
264 pub fn collect(self) -> Result<Vec<ExtensionRef<'a>>, DecodeError> {
271 self.into_iter().collect()
272 }
273}
274
275impl<'a> IntoIterator for DecodedExtensions<'a> {
276 type IntoIter = DecodedExtensionsIter<'a>;
277 type Item = Result<ExtensionRef<'a>, DecodeError>;
278
279 fn into_iter(self) -> Self::IntoIter {
280 DecodedExtensionsIter {
281 inner: Some(Reader::init(self.inner)),
282 }
283 }
284}
285
286#[derive(Debug)]
287pub struct DecodedExtensionsIter<'a> {
296 inner: Option<Reader<'a>>,
297}
298
299impl<'a> Iterator for DecodedExtensionsIter<'a> {
300 type Item = Result<ExtensionRef<'a>, DecodeError>;
301
302 fn next(&mut self) -> Option<Self::Item> {
303 match self.inner.as_mut() {
304 Some(reader) => match ExtensionRef::decode(reader) {
305 Ok(Some(extension)) => Some(Ok(extension)),
306 Ok(None) => {
307 self.inner = None;
309
310 None
311 }
312 Err(err) => {
313 self.inner = None;
315
316 Some(Err(err))
317 }
318 },
319 None => None,
320 }
321 }
322}
323
324impl FusedIterator for DecodedExtensionsIter<'_> {}
325
326#[allow(clippy::module_name_repetitions, reason = "XXX")]
327#[derive(Debug)]
328#[derive(thiserror::Error)]
329pub enum DecodeError {
331 #[error("Invalid PROXY Protocol version: {0}")]
332 InvalidVersion(u8),
334
335 #[error("Invalid PROXY Protocol command: {0}")]
336 InvalidCommand(u8),
338
339 #[error("Invalid proxy address family: {0}")]
340 InvalidFamily(u8),
342
343 #[error("Invalid proxy transport protocol: {0}")]
344 InvalidProtocol(u8),
346
347 #[error("Trailing data after the header")]
348 TrailingData,
350
351 #[error("Malformed data")]
352 MalformedData,
355}