1use bytes::{Buf, BufMut as _, BytesMut};
2use snafu::{ensure, OptionExt as _, ResultExt as _, Snafu};
3use std::{
4 io::Write as _,
5 net::{AddrParseError, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6},
6 str::{FromStr as _, Utf8Error},
7};
8
/// Carriage return: first byte of the CRLF that terminates a version 1 header.
const CR: u8 = b'\r';
/// Line feed: second byte of the CRLF that terminates a version 1 header.
const LF: u8 = b'\n';
11
12#[derive(Debug, Snafu)]
13#[cfg_attr(test, derive(PartialEq, Eq))]
14pub enum ParseError {
15 #[snafu(display("an unexpected eof was hit"))]
16 UnexpectedEof,
17
18 #[snafu(display("an illegal address family was presented"))]
19 IllegalAddressFamily,
20
21 #[snafu(display("the given input is not valid ascii text"))]
22 NonAscii { source: Utf8Error },
23
24 #[snafu(display("the given input misses an address"))]
25 MissingAddress,
26
27 #[snafu(display("invalid ip address"))]
28 InvalidAddress { source: AddrParseError },
29
30 #[snafu(display("invalid port"))]
31 InvalidPort,
32
33 #[snafu(display("illegal header ending"))]
34 IllegalHeaderEnding,
35}
36
37#[derive(Debug, Snafu)]
38pub enum EncodeError {
39 #[snafu(display("could not write to the buffer"))]
40 StdIo { source: std::io::Error },
41}
42
/// The endpoints carried by a PROXY protocol version 1 header.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ProxyAddresses {
    /// An `UNKNOWN` header. Any data after the family token is ignored when
    /// parsing.
    Unknown,
    /// A `TCP4` header carrying IPv4 endpoints.
    Ipv4 {
        /// Source endpoint; its address and port are written first.
        source: SocketAddrV4,
        /// Destination endpoint; its address and port are written second.
        destination: SocketAddrV4,
    },
    /// A `TCP6` header carrying IPv6 endpoints.
    Ipv6 {
        /// Source endpoint; its address and port are written first.
        source: SocketAddrV6,
        /// Destination endpoint; its address and port are written second.
        destination: SocketAddrV6,
    },
}
55
/// Returns the index of the first occurrence of `needle` in `haystack`,
/// which is also the number of bytes preceding it, or `None` if `needle`
/// does not occur.
fn count_till_first(haystack: &[u8], needle: u8) -> Option<usize> {
    haystack.iter().position(|&byte| byte == needle)
}
65
66pub(crate) fn parse(buf: &mut impl Buf) -> Result<super::ProxyHeader, ParseError> {
67 ensure!(buf.remaining() >= 4, UnexpectedEof);
68
69 let step = buf.get_u8();
70
71 #[derive(PartialEq, Eq)]
72 enum ProxyAddressFamily {
73 Tcp4,
74 Tcp6,
75 Unknown,
76 }
77
78 let address_family = match step {
79 b'T' => {
80 buf.advance(2);
82 let version = buf.get_u8();
83 match version {
84 b'4' => ProxyAddressFamily::Tcp4,
85 b'6' => ProxyAddressFamily::Tcp6,
86 _ => return IllegalAddressFamily.fail(),
87 }
88 }
89 b'U' => {
90 ensure!(buf.remaining() >= 6, UnexpectedEof); buf.advance(6);
93 ProxyAddressFamily::Unknown
94 }
95 _ => return IllegalAddressFamily.fail(),
96 };
97
98 if address_family == ProxyAddressFamily::Unknown {
99 let mut cr = false;
101 loop {
102 ensure!(buf.has_remaining(), UnexpectedEof);
103 let b = buf.get_u8();
104 if cr && b == LF {
105 break;
106 }
107 cr = b == CR;
108 }
109 return Ok(super::ProxyHeader::Version1 {
110 addresses: ProxyAddresses::Unknown,
111 });
112 }
113
114 ensure!(buf.remaining() >= 8, UnexpectedEof);
116 buf.advance(1); let count = count_till_first(buf.chunk(), b' ').context(MissingAddress)?;
119 let source = &buf.chunk()[..count];
120 let source = std::str::from_utf8(source).context(NonAscii)?;
121 let source = match address_family {
122 ProxyAddressFamily::Tcp4 => IpAddr::V4(Ipv4Addr::from_str(source).context(InvalidAddress)?),
123 ProxyAddressFamily::Tcp6 => IpAddr::V6(Ipv6Addr::from_str(source).context(InvalidAddress)?),
124 ProxyAddressFamily::Unknown => unreachable!("unknown should have its own branch"),
125 };
126 buf.advance(count);
127
128 ensure!(buf.remaining() >= 8, UnexpectedEof);
130 buf.advance(1); let count = count_till_first(buf.chunk(), b' ').context(MissingAddress)?;
133 let destination = &buf.chunk()[..count];
134 let destination = std::str::from_utf8(destination).context(NonAscii)?;
135 let destination = match address_family {
136 ProxyAddressFamily::Tcp4 => {
137 IpAddr::V4(Ipv4Addr::from_str(destination).context(InvalidAddress)?)
138 }
139 ProxyAddressFamily::Tcp6 => {
140 IpAddr::V6(Ipv6Addr::from_str(destination).context(InvalidAddress)?)
141 }
142 ProxyAddressFamily::Unknown => unreachable!("unknown should have its own branch"),
143 };
144 buf.advance(count);
145
146 ensure!(buf.remaining() >= 2, UnexpectedEof);
148 buf.advance(1);
149
150 let count = count_till_first(buf.chunk(), b' ').context(InvalidPort)?;
151 let source_port = &buf.chunk()[..count];
152 let source_port = std::str::from_utf8(source_port).context(NonAscii)?;
153 ensure!(
154 !source_port.starts_with('0') || source_port == "0",
156 InvalidPort,
157 );
158 let source_port: u16 = source_port.parse().ok().context(InvalidPort)?;
159 buf.advance(count);
160
161 ensure!(buf.remaining() >= 4, UnexpectedEof);
163 buf.advance(1);
164
165 let count = count_till_first(buf.chunk(), CR).context(InvalidPort)?;
167 let destination_port = &buf.chunk()[..count];
168 let destination_port = std::str::from_utf8(destination_port).context(NonAscii)?;
169 ensure!(
170 !destination_port.starts_with('0') || destination_port == "0",
172 InvalidPort,
173 );
174 let destination_port: u16 = destination_port.parse().ok().context(InvalidPort)?;
175 buf.advance(count);
176
177 ensure!(buf.get_u8() == CR, IllegalHeaderEnding);
178 ensure!(buf.get_u8() == LF, IllegalHeaderEnding);
179
180 let addresses = match (source, destination) {
181 (IpAddr::V4(source), IpAddr::V4(destination)) => ProxyAddresses::Ipv4 {
182 source: SocketAddrV4::new(source, source_port),
183 destination: SocketAddrV4::new(destination, destination_port),
184 },
185 (IpAddr::V6(source), IpAddr::V6(destination)) => ProxyAddresses::Ipv6 {
186 source: SocketAddrV6::new(source, source_port, 0, 0),
187 destination: SocketAddrV6::new(destination, destination_port, 0, 0),
188 },
189 _ => unreachable!(),
191 };
192
193 Ok(super::ProxyHeader::Version1 {
194 addresses,
195 })
196}
197
198pub(crate) fn encode(addresses: ProxyAddresses) -> Result<BytesMut, EncodeError> {
199 if let ProxyAddresses::Unknown = addresses {
200 return Ok(BytesMut::from(&b"PROXY UNKNOWN\r\n"[..]));
201 }
202
203 let mut buf = BytesMut::with_capacity(107).writer();
205 buf.write_all(&b"PROXY TCP"[..]).context(StdIo)?;
206
207 match addresses {
208 ProxyAddresses::Ipv4 {
209 source,
210 destination,
211 } => {
212 buf.write(&b"4 "[..]).context(StdIo)?;
213 write!(
214 buf,
215 "{} {} {} {}\r\n",
216 source.ip(),
217 destination.ip(),
218 source.port(),
219 destination.port(),
220 )
221 .context(StdIo)?;
222 }
223 ProxyAddresses::Ipv6 {
224 source,
225 destination,
226 } => {
227 buf.write(&b"6 "[..]).context(StdIo)?;
228 write!(
229 buf,
230 "{} {} {} {}\r\n",
231 source.ip(),
232 destination.ip(),
233 source.port(),
234 destination.port(),
235 )
236 .context(StdIo)?;
237 }
238 ProxyAddresses::Unknown => unreachable!(),
239 }
240
241 Ok(buf.into_inner())
242}
243
#[cfg(test)]
mod parse_tests {
    use super::*;
    use crate::ProxyHeader;
    use bytes::Bytes;
    use pretty_assertions::assert_eq;
    use rand::prelude::*;
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};

    #[test]
    fn test_valid_unknown_cases() {
        let unknown = Ok(ProxyHeader::Version1 {
            addresses: ProxyAddresses::Unknown,
        });
        assert_eq!(parse(&mut &b"UNKNOWN\r\n"[..]), unknown);
        assert_eq!(
            parse(&mut &b"UNKNOWN this is bogus data!\r\r\r\n"[..]),
            unknown,
        );
        assert_eq!(
            parse(&mut &b"UNKNOWN 192.168.0.1 192.168.1.1 123 321\r\n"[..]),
            unknown,
        );

        let mut random = [0u8; 128];
        rand::thread_rng().fill_bytes(&mut random);
        // A random CR followed by LF would end the header early and leave the
        // remaining random bytes unconsumed, making the final check flaky.
        for byte in random.iter_mut().filter(|byte| **byte == CR) {
            *byte = b' ';
        }
        let mut header = b"UNKNOWN ".to_vec();
        header.extend(&random[..]);
        header.extend(b"\r\n");
        let mut buf = Bytes::from(header);
        assert_eq!(parse(&mut buf), unknown);
        assert!(!buf.has_remaining());
    }

    #[test]
    fn test_valid_ipv4_cases() {
        fn valid(
            (a, b, c, d): (u8, u8, u8, u8),
            e: u16,
            (f, g, h, i): (u8, u8, u8, u8),
            j: u16,
        ) -> ProxyHeader {
            ProxyHeader::Version1 {
                addresses: ProxyAddresses::Ipv4 {
                    source: SocketAddrV4::new(Ipv4Addr::new(a, b, c, d), e),
                    destination: SocketAddrV4::new(Ipv4Addr::new(f, g, h, i), j),
                },
            }
        }
        assert_eq!(
            parse(&mut &b"TCP4 192.168.201.102 1.2.3.4 0 65535\r\n"[..]),
            Ok(valid((192, 168, 201, 102), 0, (1, 2, 3, 4), 65535)),
        );
        assert_eq!(
            parse(&mut &b"TCP4 0.0.0.0 0.0.0.0 0 0\r\n"[..]),
            Ok(valid((0, 0, 0, 0), 0, (0, 0, 0, 0), 0)),
        );
        assert_eq!(
            parse(&mut &b"TCP4 255.255.255.255 255.255.255.255 65535 65535\r\n"[..]),
            Ok(valid(
                (255, 255, 255, 255),
                65535,
                (255, 255, 255, 255),
                65535,
            )),
        );
    }

    #[test]
    fn test_valid_ipv6_cases() {
        fn valid(
            (a, b, c, d, e, f, g, h): (u16, u16, u16, u16, u16, u16, u16, u16),
            i: u16,
            (j, k, l, m, n, o, p, q): (u16, u16, u16, u16, u16, u16, u16, u16),
            r: u16,
        ) -> ProxyHeader {
            ProxyHeader::Version1 {
                addresses: ProxyAddresses::Ipv6 {
                    source: SocketAddrV6::new(Ipv6Addr::new(a, b, c, d, e, f, g, h), i, 0, 0),
                    destination: SocketAddrV6::new(Ipv6Addr::new(j, k, l, m, n, o, p, q), r, 0, 0),
                },
            }
        }
        assert_eq!(
            parse(&mut &b"TCP6 ab:ce:ef:01:23:45:67:89 ::1 0 65535\r\n"[..]),
            Ok(valid(
                (0xAB, 0xCE, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89),
                0,
                (0, 0, 0, 0, 0, 0, 0, 1),
                65535,
            )),
        );
        assert_eq!(
            parse(&mut &b"TCP6 :: :: 0 0\r\n"[..]),
            Ok(valid(
                (0, 0, 0, 0, 0, 0, 0, 0),
                0,
                (0, 0, 0, 0, 0, 0, 0, 0),
                0,
            )),
        );
        assert_eq!(
            parse(
                &mut &b"TCP6 ff:ff:ff:ff:ff:ff:ff:ff ff:ff:ff:ff:ff:ff:ff:ff 65535 65535\r\n"[..],
            ),
            Ok(valid(
                (0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF),
                65535,
                (0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF),
                65535,
            )),
        );
    }

    #[test]
    fn test_trailing_data_is_left_in_buffer() {
        let mut buf = &b"TCP4 1.2.3.4 5.6.7.8 1 2\r\nGET / HTTP/1.0\r\n"[..];
        assert_eq!(
            parse(&mut buf),
            Ok(ProxyHeader::Version1 {
                addresses: ProxyAddresses::Ipv4 {
                    source: SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), 1),
                    destination: SocketAddrV4::new(Ipv4Addr::new(5, 6, 7, 8), 2),
                },
            }),
        );
        assert_eq!(buf, &b"GET / HTTP/1.0\r\n"[..]);
    }

    #[test]
    fn test_invalid_cases() {
        assert_eq!(
            parse(&mut &b"UNKNOWN \r"[..]),
            Err(ParseError::UnexpectedEof)
        );
        assert_eq!(
            parse(&mut &b"UNKNOWN \r\t\t\r"[..]),
            Err(ParseError::UnexpectedEof),
        );
        assert_eq!(
            parse(&mut &b"UNKNOWN\r\r\r\r\rHello, world!"[..]),
            Err(ParseError::UnexpectedEof),
        );
        assert_eq!(
            parse(&mut &b"UNKNOWN\nGET /index.html HTTP/1.0"[..]),
            Err(ParseError::UnexpectedEof),
        );
        assert_eq!(
            parse(&mut &b"UNKNOWN\n"[..]),
            Err(ParseError::UnexpectedEof)
        );
    }

    #[test]
    fn test_invalid_tcp_cases() {
        assert_eq!(parse(&mut &b"TCP"[..]), Err(ParseError::UnexpectedEof));
        assert_eq!(
            parse(&mut &b"HELLO WORLD\r\n"[..]),
            Err(ParseError::IllegalAddressFamily),
        );
        assert_eq!(
            parse(&mut &b"TCP5 1.2.3.4 5.6.7.8 1 2\r\n"[..]),
            Err(ParseError::IllegalAddressFamily),
        );
        assert!(matches!(
            parse(&mut &b"TCP4 ::1 ::1 1 2\r\n"[..]),
            Err(ParseError::InvalidAddress { .. })
        ));
        assert!(matches!(
            parse(&mut &b"TCP6 1.2.3.4 5.6.7.8 1 2\r\n"[..]),
            Err(ParseError::InvalidAddress { .. })
        ));
        assert_eq!(
            parse(&mut &b"TCP4 1.2.3.4 5.6.7.8 01 2\r\n"[..]),
            Err(ParseError::InvalidPort),
        );
        assert_eq!(
            parse(&mut &b"TCP4 1.2.3.4 5.6.7.8 1 65536\r\n"[..]),
            Err(ParseError::InvalidPort),
        );
        assert_eq!(
            parse(&mut &b"TCP4 1.2.3.4 5.6.7.8 1 22\n"[..]),
            Err(ParseError::InvalidPort),
        );
        assert_eq!(
            parse(&mut &b"TCP4 1.2.3.4 5.6.7.8 1 2\r\r"[..]),
            Err(ParseError::IllegalHeaderEnding),
        );
    }

    #[test]
    fn test_crlf() {
        assert_eq!(CR, b'\r');
        assert_eq!(LF, b'\n');
    }
}
388
#[cfg(test)]
mod encode_tests {
    use super::*;
    use bytes::Bytes;
    use pretty_assertions::assert_eq;
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};

    /// Builds an IPv6 socket address with zero flow info and scope id.
    fn v6(ip: Ipv6Addr, port: u16) -> SocketAddrV6 {
        SocketAddrV6::new(ip, port, 0, 0)
    }

    #[test]
    fn test_unknown() {
        let encoded = encode(ProxyAddresses::Unknown).expect("encoding UNKNOWN must succeed");
        assert_eq!(encoded, &b"PROXY UNKNOWN\r\n"[..]);
    }

    #[test]
    fn test_tcp4() {
        let addresses = ProxyAddresses::Ipv4 {
            source: SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), 987),
            destination: SocketAddrV4::new(Ipv4Addr::new(255, 254, 253, 252), 12345),
        };
        let encoded = encode(addresses).expect("encoding TCP4 must succeed");
        assert_eq!(
            encoded,
            Bytes::from_static(&b"PROXY TCP4 1.2.3.4 255.254.253.252 987 12345\r\n"[..]),
        );
    }

    #[test]
    fn test_tcp6() {
        let source = v6(Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8), 987);
        let cases: [(Ipv6Addr, &'static [u8]); 3] = [
            (
                Ipv6Addr::new(65535, 65534, 65533, 65532, 0, 1, 2, 3),
                &b"PROXY TCP6 1:2:3:4:5:6:7:8 ffff:fffe:fffd:fffc:0:1:2:3 987 12345\r\n"[..],
            ),
            (
                Ipv6Addr::new(65535, 65534, 0, 0, 0, 1, 2, 3),
                &b"PROXY TCP6 1:2:3:4:5:6:7:8 ffff:fffe::1:2:3 987 12345\r\n"[..],
            ),
            (
                Ipv6Addr::new(0, 0, 0, 0, 0, 1, 2, 3),
                &b"PROXY TCP6 1:2:3:4:5:6:7:8 ::1:2:3 987 12345\r\n"[..],
            ),
        ];

        for (destination_ip, expected) in cases.iter().copied() {
            let encoded = encode(ProxyAddresses::Ipv6 {
                source,
                destination: v6(destination_ip, 12345),
            })
            .expect("encoding TCP6 must succeed");
            assert_eq!(encoded, Bytes::from_static(expected));
        }
    }
}