1#![deny(warnings)]
2#![warn(unused_extern_crates)]
3#![deny(clippy::todo)]
4#![deny(clippy::unimplemented)]
5#![deny(clippy::unwrap_used)]
6#![deny(clippy::expect_used)]
7#![deny(clippy::panic)]
8#![deny(clippy::unreachable)]
9#![deny(clippy::await_holding_lock)]
10#![deny(clippy::needless_pass_by_value)]
11#![deny(clippy::trivially_copy_pass_by_ref)]
12
13use crate::parse::{parse_proxy_hdr_v1, parse_proxy_hdr_v2};
14use std::num::NonZeroUsize;
15
16#[cfg(any(test, feature = "tokio"))]
17use crate::parse::{V1_MAX_LEN, V1_MIN_LEN};
18
/// The value `1` as a `NonZeroUsize`, used when a parser cannot say how many
/// more bytes it needs. `NonZeroUsize::MIN` is infallibly 1, so no
/// `expect` (denied crate-wide) is required.
const NZ_ONE: NonZeroUsize = NonZeroUsize::MIN;
20
21mod parse;
22
/// Transport protocol and address family of a proxied connection.
///
/// The discriminants follow the PROXY v2 family/protocol byte: the upper
/// nibble is the address family (1 = IPv4, 2 = IPv6) and the lower nibble the
/// transport (1 = stream/TCP, 2 = datagram/UDP). `Unspec` (0x00) means no
/// usable address information.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
enum Protocol {
    /// Unknown / unspecified protocol.
    Unspec = 0x00,
    /// TCP over IPv4.
    TcpV4 = 0x11,
    /// UDP over IPv4.
    UdpV4 = 0x12,
    /// TCP over IPv6.
    TcpV6 = 0x21,
    /// UDP over IPv6.
    UdpV6 = 0x22,
}
34
/// Command carried in the low nibble of a PROXY v2 version/command byte.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
enum Command {
    /// Connection was not relayed on behalf of a client; addresses are ignored.
    Local = 0x00,
    /// Connection was relayed; the header carries the original endpoints.
    Proxy = 0x01,
}
41
/// Source and destination endpoints decoded from a PROXY header.
#[derive(Clone, Debug, Eq, PartialEq)]
enum Address {
    /// The header carried no address information.
    None,
    /// IPv4 endpoint pair.
    V4 {
        /// Source endpoint as reported by the header.
        src: std::net::SocketAddrV4,
        /// Destination endpoint as reported by the header.
        dst: std::net::SocketAddrV4,
    },
    /// IPv6 endpoint pair.
    V6 {
        /// Source endpoint as reported by the header.
        src: std::net::SocketAddrV6,
        /// Destination endpoint as reported by the header.
        dst: std::net::SocketAddrV6,
    },
}
58
/// The remote endpoints described by a parsed PROXY header.
///
/// Produced by `ProxyHdrV1::to_remote_addr` / `ProxyHdrV2::to_remote_addr`.
/// Derives `PartialEq`, `Eq` and `Hash` so callers can compare results or use
/// them as map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RemoteAddress {
    /// A v2 header with the `LOCAL` command; any address information in it
    /// is ignored.
    Local,
    /// The header's protocol and address family do not form a supported
    /// combination (for example an unspecified protocol).
    Invalid,
    /// TCP over IPv4 with the source and destination endpoints from the header.
    TcpV4 {
        src: std::net::SocketAddrV4,
        dst: std::net::SocketAddrV4,
    },
    /// UDP over IPv4 with the source and destination endpoints from the header.
    UdpV4 {
        src: std::net::SocketAddrV4,
        dst: std::net::SocketAddrV4,
    },
    /// TCP over IPv6 with the source and destination endpoints from the header.
    TcpV6 {
        src: std::net::SocketAddrV6,
        dst: std::net::SocketAddrV6,
    },
    /// UDP over IPv6 with the source and destination endpoints from the header.
    UdpV6 {
        src: std::net::SocketAddrV6,
        dst: std::net::SocketAddrV6,
    },
}
80
/// Error returned by the synchronous `parse` functions of the header types.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// The input ended before a complete header; `need` more bytes were
    /// requested by the parser.
    Incomplete { need: NonZeroUsize },
    /// The input is not a valid PROXY header.
    Invalid,
    /// More input is needed but the parser could not report how much.
    UnableToComplete,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::Incomplete { need } => {
                write!(f, "incomplete proxy header, {need} more byte(s) required")
            }
            Error::Invalid => f.write_str("invalid proxy header"),
            Error::UnableToComplete => {
                f.write_str("incomplete proxy header, required length unknown")
            }
        }
    }
}

impl std::error::Error for Error {}
87
88#[derive(Debug, Clone)]
89pub struct ProxyHdrV2 {
90 command: Command,
91 protocol: Protocol,
92 address: Address,
95}
96
97impl ProxyHdrV2 {
98 pub fn parse(input_data: &[u8]) -> Result<(usize, Self), Error> {
99 match parse_proxy_hdr_v2(input_data) {
100 Ok((remainder, hdr)) => {
101 let took = input_data.len() - remainder.len();
102 Ok((took, hdr))
103 }
104 Err(nom::Err::Incomplete(nom::Needed::Size(need))) => Err(Error::Incomplete { need }),
105 Err(nom::Err::Incomplete(nom::Needed::Unknown)) => Err(Error::UnableToComplete),
107
108 Err(nom::Err::Error(err)) => {
109 tracing::error!(?err);
110 Err(Error::Invalid)
111 }
112 Err(nom::Err::Failure(err)) => {
113 tracing::error!(?err, "parser failure handling proxy v2 header");
114 Err(Error::Invalid)
115 }
116 }
117 }
118
119 pub fn to_remote_addr(self) -> RemoteAddress {
120 match (self.command, self.protocol, self.address) {
121 (Command::Local, _, _) => RemoteAddress::Local,
122 (Command::Proxy, Protocol::TcpV4, Address::V4 { src, dst }) => {
123 RemoteAddress::TcpV4 { src, dst }
124 }
125 (Command::Proxy, Protocol::UdpV4, Address::V4 { src, dst }) => {
126 RemoteAddress::UdpV4 { src, dst }
127 }
128 (Command::Proxy, Protocol::TcpV6, Address::V6 { src, dst }) => {
129 RemoteAddress::TcpV6 { src, dst }
130 }
131 (Command::Proxy, Protocol::UdpV6, Address::V6 { src, dst }) => {
132 RemoteAddress::UdpV6 { src, dst }
133 }
134 _ => RemoteAddress::Invalid,
135 }
136 }
137}
138
139#[derive(Debug, Clone)]
140pub struct ProxyHdrV1 {
141 protocol: Protocol,
142 address: Address,
143}
144
145impl ProxyHdrV1 {
146 pub fn parse(input_data: &[u8]) -> Result<(usize, Self), Error> {
147 match parse_proxy_hdr_v1(input_data) {
148 Ok((remainder, hdr)) => {
149 let took = input_data.len() - remainder.len();
150 Ok((took, hdr))
151 }
152 Err(nom::Err::Incomplete(nom::Needed::Size(need))) => Err(Error::Incomplete { need }),
153 Err(nom::Err::Incomplete(nom::Needed::Unknown)) => {
155 Err(Error::Incomplete { need: NZ_ONE })
156 }
157
158 Err(nom::Err::Error(err)) => {
159 tracing::error!(?err);
160 Err(Error::Invalid)
161 }
162 Err(nom::Err::Failure(err)) => {
163 tracing::error!(?err, "parser failure handling proxy v1 header");
164 Err(Error::Invalid)
165 }
166 }
167 }
168
169 pub fn to_remote_addr(self) -> RemoteAddress {
170 match (self.protocol, self.address) {
171 (Protocol::TcpV4, Address::V4 { src, dst }) => RemoteAddress::TcpV4 { src, dst },
172 (Protocol::UdpV4, Address::V4 { src, dst }) => RemoteAddress::UdpV4 { src, dst },
173 (Protocol::TcpV6, Address::V6 { src, dst }) => RemoteAddress::TcpV6 { src, dst },
174 (Protocol::UdpV6, Address::V6 { src, dst }) => RemoteAddress::UdpV6 { src, dst },
175 _ => RemoteAddress::Invalid,
176 }
177 }
178}
179
/// Error returned by the async `parse_from_read` functions.
#[cfg(any(feature = "tokio", test))]
#[derive(Debug)]
pub enum AsyncReadError {
    /// Reading from the underlying stream failed (exposed via `source()`).
    Io(std::io::Error),
    /// The bytes read are not a valid PROXY header.
    Invalid,
    /// The header could not be read to completion.
    UnableToComplete,
    /// The header would exceed the size the reader is willing to buffer.
    RequestTooLarge,
    /// The parser consumed a different number of bytes than were read.
    InconsistentRead,
}

#[cfg(any(feature = "tokio", test))]
impl std::fmt::Display for AsyncReadError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The wrapped I/O error is reported through `source()`, not repeated
        // here, so error-chain printers do not show it twice.
        let msg = match self {
            AsyncReadError::Io(_) => "i/o error while reading proxy header",
            AsyncReadError::Invalid => "invalid proxy header",
            AsyncReadError::UnableToComplete => "proxy header could not be read to completion",
            AsyncReadError::RequestTooLarge => "proxy header exceeds the maximum supported size",
            AsyncReadError::InconsistentRead => {
                "proxy header parse consumed a different amount than was read"
            }
        };
        f.write_str(msg)
    }
}

#[cfg(any(feature = "tokio", test))]
impl std::error::Error for AsyncReadError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            AsyncReadError::Io(err) => Some(err),
            AsyncReadError::Invalid
            | AsyncReadError::UnableToComplete
            | AsyncReadError::RequestTooLarge
            | AsyncReadError::InconsistentRead => None,
        }
    }
}
189
#[cfg(any(feature = "tokio", test))]
impl ProxyHdrV2 {
    /// Reads a PROXY protocol v2 (binary) header from `stream` and returns the
    /// stream together with the parsed header.
    ///
    /// Exactly 16 bytes are read first. If the parser then asks for more, that
    /// additional amount is read with a single `read_exact`, provided the total
    /// stays within 512 bytes. The buffer is then parsed a second and final time.
    ///
    /// # Errors
    ///
    /// * [`AsyncReadError::Io`] if a read from `stream` fails.
    /// * [`AsyncReadError::RequestTooLarge`] if the header would exceed 512 bytes.
    /// * [`AsyncReadError::Invalid`] if the bytes are not a valid v2 header.
    /// * [`AsyncReadError::UnableToComplete`] if the header is still incomplete
    ///   after the second read.
    /// * [`AsyncReadError::InconsistentRead`] if the parser consumed a different
    ///   number of bytes than were read.
    pub async fn parse_from_read<S>(mut stream: S) -> Result<(S, Self), AsyncReadError>
    where
        S: tokio::io::AsyncReadExt + std::marker::Unpin,
    {
        use tracing::{debug, error};

        // Largest total header size this reader is willing to buffer.
        const HDR_SIZE_LIMIT: usize = 512;
        // Bytes read before the parser is consulted for the first time.
        const PREFIX_LEN: usize = 16;

        let mut header_bytes = vec![0; PREFIX_LEN];
        let mut bytes_read = stream
            .read_exact(&mut header_bytes)
            .await
            .map_err(AsyncReadError::Io)?;

        // First pass: the prefix is either a complete header already, or the
        // parser reports how many more bytes it needs.
        let extra = match ProxyHdrV2::parse(&header_bytes) {
            Ok((_, hdr)) => return Ok((stream, hdr)),
            Err(Error::Incomplete { need }) => need.get(),
            Err(Error::Invalid) => {
                debug!(proxy_binary_dump = %hex::encode(&header_bytes));
                error!("proxy v2 header was invalid");
                return Err(AsyncReadError::Invalid);
            }
            Err(Error::UnableToComplete) => {
                debug!(proxy_binary_dump = %hex::encode(&header_bytes));
                error!("proxy v2 header was incomplete");
                return Err(AsyncReadError::UnableToComplete);
            }
        };

        let full_len = header_bytes.len() + extra;
        if full_len > HDR_SIZE_LIMIT {
            error!(
                "proxy v2 header request was larger than {} bytes, refusing to proceed.",
                HDR_SIZE_LIMIT
            );
            return Err(AsyncReadError::RequestTooLarge);
        }
        header_bytes.resize(full_len, 0);

        bytes_read += stream
            .read_exact(&mut header_bytes[PREFIX_LEN..])
            .await
            .map_err(AsyncReadError::Io)?;

        // Second pass: the buffer now holds everything the parser asked for,
        // so anything short of a complete header is an error.
        match ProxyHdrV2::parse(&header_bytes) {
            Err(Error::Incomplete { .. }) => {
                error!("proxy v2 header could not be read to the end.");
                Err(AsyncReadError::UnableToComplete)
            }
            Err(Error::Invalid) => {
                debug!(proxy_binary_dump = %hex::encode(&header_bytes));
                error!("proxy v2 header was invalid");
                Err(AsyncReadError::Invalid)
            }
            Err(Error::UnableToComplete) => {
                debug!(proxy_binary_dump = %hex::encode(&header_bytes));
                error!("proxy v2 header was incomplete");
                Err(AsyncReadError::UnableToComplete)
            }
            Ok((consumed, _)) if consumed != bytes_read => {
                error!("proxy v2 header read an inconsistent amount from stream.");
                Err(AsyncReadError::InconsistentRead)
            }
            Ok((_, hdr)) => Ok((stream, hdr)),
        }
    }
}
273
#[cfg(any(feature = "tokio", test))]
impl ProxyHdrV1 {
    /// Reads a PROXY protocol v1 (text) header from `stream` and returns the
    /// stream together with the parsed header.
    ///
    /// Starts by reading `V1_MIN_LEN` bytes, then repeatedly reads exactly as
    /// many further bytes as the parser requests. The stream is therefore never
    /// read beyond what the parser asked for. All bytes are held in a fixed
    /// buffer of `V1_MAX_LEN + 1` bytes.
    ///
    /// # Errors
    ///
    /// * [`AsyncReadError::Io`] if a read from `stream` fails, including hitting
    ///   end-of-stream before the requested bytes arrive.
    /// * [`AsyncReadError::RequestTooLarge`] if the parser asks for more bytes
    ///   than fit in the buffer.
    /// * [`AsyncReadError::Invalid`] / [`AsyncReadError::UnableToComplete`] if
    ///   the parser rejects the buffered bytes.
    /// * [`AsyncReadError::InconsistentRead`] if the parser consumed a different
    ///   number of bytes than were read.
    pub async fn parse_from_read<S>(mut stream: S) -> Result<(S, Self), AsyncReadError>
    where
        S: tokio::io::AsyncReadExt + std::marker::Unpin,
    {
        use tracing::{debug, error};

        let mut buf = [0; V1_MAX_LEN + 1];

        // Invariant: `took <= buf.len()` and `buf[..took]` holds every byte read.
        let mut took = stream
            .read_exact(&mut buf[..V1_MIN_LEN])
            .await
            .map_err(AsyncReadError::Io)?;

        loop {
            match ProxyHdrV1::parse(&buf[..took]) {
                Ok((hdr_took, _)) if hdr_took != took => {
                    error!("proxy v1 header read an inconsistent amount from stream.");
                    return Err(AsyncReadError::InconsistentRead);
                }
                Ok((_, hdr)) => return Ok((stream, hdr)),
                Err(Error::Incomplete { need }) => {
                    // Bound the next read *before* slicing: indexing past the
                    // end of `buf` would panic. `saturating_add` also keeps an
                    // absurdly large `need` from overflowing.
                    let end = took.saturating_add(need.get());
                    if end > buf.len() {
                        debug!(proxy_binary_dump = %hex::encode(&buf[..took]));
                        error!(
                            "proxy v1 header request was larger than {} bytes, refusing to proceed.",
                            buf.len()
                        );
                        return Err(AsyncReadError::RequestTooLarge);
                    }
                    took += stream
                        .read_exact(&mut buf[took..end])
                        .await
                        .map_err(AsyncReadError::Io)?;
                }
                Err(Error::Invalid) => {
                    debug!(proxy_binary_dump = %hex::encode(&buf[..took]));
                    error!("proxy v1 header was invalid");
                    return Err(AsyncReadError::Invalid);
                }
                Err(Error::UnableToComplete) => {
                    debug!(proxy_binary_dump = %hex::encode(&buf[..took]));
                    error!("proxy v1 header was incomplete");
                    return Err(AsyncReadError::UnableToComplete);
                }
            }
        }
    }
}
335
#[cfg(test)]
mod tests {
    use crate::{Address, Command, Protocol, ProxyHdrV1, ProxyHdrV2};
    use std::net::SocketAddrV4;
    use std::str::FromStr;

    /// A complete v1 header supplied in one slice parses to the expected
    /// TCP/IPv4 endpoints.
    #[tokio::test]
    async fn proxyv1_stream_parse() {
        let _ = tracing_subscriber::fmt::try_init();

        let data = "PROXY TCP4 91.221.138.33 91.221.138.106 47780 636\r\n";

        let (_, hdr) = ProxyHdrV1::parse_from_read(data.as_bytes())
            .await
            .expect("should parse v1 header");

        tracing::debug!(?hdr);

        assert_eq!(hdr.protocol, Protocol::TcpV4);
        assert_eq!(
            hdr.address,
            Address::V4 {
                src: SocketAddrV4::from_str("91.221.138.33:47780").expect("valid addr"),
                dst: SocketAddrV4::from_str("91.221.138.106:636").expect("valid addr"),
            }
        );
    }

    /// A complete binary v2 header (PROXY command, TCP over IPv4) supplied in
    /// one slice parses to the expected endpoints.
    #[tokio::test]
    async fn proxyv2_stream_parse() {
        let _ = tracing_subscriber::fmt::try_init();

        let sample = hex::decode("0d0a0d0a000d0a515549540a2111000cac180c76ac180b8fcdcb027d")
            .expect("valid hex");

        let (_, hdr) = ProxyHdrV2::parse_from_read(sample.as_slice())
            .await
            .expect("should parse v4 addr");

        tracing::debug!(?hdr);

        assert_eq!(hdr.command, Command::Proxy);
        assert_eq!(hdr.protocol, Protocol::TcpV4);
        assert_eq!(
            hdr.address,
            Address::V4 {
                src: SocketAddrV4::from_str("172.24.12.118:52683").expect("valid addr"),
                dst: SocketAddrV4::from_str("172.24.11.143:637").expect("valid addr"),
            }
        );
    }

    #[cfg(all(test, feature = "tokio"))]
    mod async_stream_tests {
        use super::*;
        use std::net::{SocketAddrV4, SocketAddrV6};
        use std::str::FromStr;
        use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};

        /// Writes `data` to `writer` in pieces of the given sizes, yielding to
        /// the runtime after each piece to give the reader a chance to observe
        /// partial input. Anything left once `chunk_sizes` is exhausted is
        /// written in one final piece.
        // The crate denies `clippy::expect_used`; this helper is not itself a
        // `#[test]` function, so test-only lint exemptions may not cover it.
        #[allow(clippy::expect_used)]
        async fn write_in_chunks<W>(mut writer: W, data: &[u8], chunk_sizes: &[usize])
        where
            W: AsyncWrite + Unpin,
        {
            let mut offset = 0;
            for &size in chunk_sizes {
                if offset >= data.len() {
                    break;
                }
                let end = (offset + size).min(data.len());
                writer
                    .write_all(&data[offset..end])
                    .await
                    .expect("chunk write should succeed");
                tokio::task::yield_now().await;
                offset = end;
            }

            if offset < data.len() {
                writer
                    .write_all(&data[offset..])
                    .await
                    .expect("final write should succeed");
            }
        }

        /// A v2 header that arrives in small fragments is parsed, and the bytes
        /// following it are still readable from the returned stream.
        #[tokio::test]
        async fn tokio_stream_parse_v2_chunks() {
            let _ = tracing_subscriber::fmt::try_init();

            let sample = hex::decode("0d0a0d0a000d0a515549540a2111000cac180c76ac180b8fcdcb027d")
                .expect("valid hex");
            let payload = b"hello";
            let mut full = sample.clone();
            full.extend_from_slice(payload);

            let (client, server) = tokio::io::duplex(32);

            let writer = tokio::spawn(async move {
                write_in_chunks(server, &full, &[5, 3, 1, 7, 2]).await;
            });

            let (mut stream, hdr) = ProxyHdrV2::parse_from_read(client)
                .await
                .expect("should parse v2 from stream");

            let mut extra = vec![0; payload.len()];
            stream
                .read_exact(&mut extra)
                .await
                .expect("should read extra payload");

            writer.await.expect("writer task should finish");

            assert_eq!(extra.as_slice(), payload);
            assert_eq!(hdr.command, Command::Proxy);
            assert_eq!(hdr.protocol, Protocol::TcpV4);
            assert_eq!(
                hdr.address,
                Address::V4 {
                    src: SocketAddrV4::from_str("172.24.12.118:52683").expect("valid addr"),
                    dst: SocketAddrV4::from_str("172.24.11.143:637").expect("valid addr"),
                }
            );
        }

        /// A maximum-width TCP6 v1 header that arrives in small fragments is
        /// parsed, and the bytes following it are still readable from the
        /// returned stream.
        #[tokio::test]
        async fn tokio_stream_parse_v1_chunks() {
            let _ = tracing_subscriber::fmt::try_init();

            let header = b"PROXY TCP6 ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\n";
            let payload = b"more_data";
            let mut full = header.to_vec();
            full.extend_from_slice(payload);

            let (client, server) = tokio::io::duplex(64);

            let writer = tokio::spawn(async move {
                write_in_chunks(server, &full, &[4, 1, 8, 2, 3, 5, 1]).await;
            });

            let (mut stream, hdr) = ProxyHdrV1::parse_from_read(client)
                .await
                .expect("should parse v1 from stream");

            let mut extra = vec![0; payload.len()];
            stream
                .read_exact(&mut extra)
                .await
                .expect("should read extra payload");

            writer.await.expect("writer task should finish");

            assert_eq!(extra.as_slice(), payload);
            assert_eq!(hdr.protocol, Protocol::TcpV6);
            assert_eq!(
                hdr.address,
                Address::V6 {
                    src: SocketAddrV6::from_str("[ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff]:65535")
                        .expect("valid addr"),
                    dst: SocketAddrV6::from_str("[ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff]:65535")
                        .expect("valid addr"),
                }
            );
        }
    }
}