use bytes::{Buf, BufMut as _, BytesMut};
use snafu::{ensure, Snafu};
use std::net::{Ipv4Addr, Ipv6Addr};
/// Errors produced while parsing a PROXY protocol version 2 header.
#[derive(Debug, Snafu)]
#[cfg_attr(test, derive(PartialEq, Eq))]
pub enum ParseError {
    /// The buffer ended before all bytes required by the header were read.
    #[snafu(display("an unexpected eof was hit"))]
    UnexpectedEof,
    /// The low nibble of the version/command byte is not a known command.
    #[snafu(display("invalid command: {}", cmd))]
    UnknownCommand { cmd: u8 },
    /// The high nibble of the family/transport byte is not a known address family.
    #[snafu(display("invalid address family: {}", family))]
    UnknownAddressFamily { family: u8 },
    /// The low nibble of the family/transport byte is not a known transport protocol.
    #[snafu(display("invalid transport protocol: {}", protocol))]
    UnknownTransportProtocol { protocol: u8 },
    /// The header's length field (`given`) is smaller than the address block
    /// the address family requires (`needs`).
    #[snafu(display("insufficient length specified: {}, requires minimum {}", given, needs))]
    InsufficientLengthSpecified { given: usize, needs: usize },
}
/// Command carried in the low nibble of the version/command byte.
///
/// The explicit discriminants are the on-the-wire nibble values.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ProxyCommand {
    /// Wire value `0x0` (LOCAL).
    Local = 0x0,
    /// Wire value `0x1` (PROXY).
    Proxy = 0x1,
}
/// Transport protocol carried in the low nibble of the family/transport byte.
///
/// The explicit discriminants are the on-the-wire nibble values.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ProxyTransportProtocol {
    /// Wire value `0x0` (UNSPEC).
    Unspec = 0x0,
    /// Wire value `0x1` (STREAM).
    Stream = 0x1,
    /// Wire value `0x2` (DGRAM).
    Datagram = 0x2,
}
/// Source and destination endpoints carried in a version 2 header.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum ProxyAddresses {
    /// No address information (address family `0x0`).
    Unspec,
    /// IPv4 endpoints (address family `0x1`). A port is `None` when the
    /// header carries no ports.
    Ipv4 {
        source: (Ipv4Addr, Option<u16>),
        destination: (Ipv4Addr, Option<u16>),
    },
    /// IPv6 endpoints (address family `0x2`). A port is `None` when the
    /// header carries no ports.
    Ipv6 {
        source: (Ipv6Addr, Option<u16>),
        destination: (Ipv6Addr, Option<u16>),
    },
    /// Unix socket endpoints (address family `0x3`), kept as the raw
    /// 108-byte address buffers exactly as they appear on the wire.
    Unix {
        source: [u8; 108],
        destination: [u8; 108],
    },
}
/// Address family from the high nibble of the family/transport byte.
///
/// Internal to this module; it decides how the address block is read.
/// Derives the same value-type traits as the public sibling enums so it can
/// be copied freely and printed in diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ProxyAddressFamily {
    /// Wire value `0x0` (AF_UNSPEC).
    Unspec,
    /// Wire value `0x1` (AF_INET).
    Inet,
    /// Wire value `0x2` (AF_INET6).
    Inet6,
    /// Wire value `0x3` (AF_UNIX).
    Unix,
}
/// Parses the remainder of a PROXY protocol version 2 header from `buf`.
///
/// `buf` must start at the version/command byte, i.e. directly after the
/// 12-byte signature. The signature is not checked here (presumably the
/// caller has already matched and consumed it), and neither is the version
/// nibble of the first byte.
///
/// The address block is read in the order mandated by the specification:
/// source address, destination address, source port, destination port.
/// Ports are only read when the transport protocol is not `Unspec`. Bytes
/// covered by the declared length beyond the address block (e.g. TLV
/// extensions) are skipped without being interpreted. On success exactly the
/// header's bytes are consumed from `buf`.
///
/// # Errors
///
/// - [`ParseError::UnexpectedEof`] if `buf` ends before the header does,
///   including when the declared length exceeds the bytes available.
/// - [`ParseError::UnknownCommand`], [`ParseError::UnknownAddressFamily`] and
///   [`ParseError::UnknownTransportProtocol`] for unrecognised nibbles.
/// - [`ParseError::InsufficientLengthSpecified`] if the declared length is
///   too short for the address family (plus ports, when present).
pub(crate) fn parse(buf: &mut impl Buf) -> Result<super::ProxyHeader, ParseError> {
    /// Size of one AF_UNIX address in the address block.
    const UNIX_ADDRESS_LENGTH: usize = 108;

    /// Reads the source then destination port, or `(None, None)` when the
    /// header carries no ports. The caller guarantees the bytes are present.
    fn read_ports(buf: &mut impl Buf, has_ports: bool) -> (Option<u16>, Option<u16>) {
        if has_ports {
            let source = buf.get_u16();
            let destination = buf.get_u16();
            (Some(source), Some(destination))
        } else {
            (None, None)
        }
    }

    // `get_u8` panics on an empty buffer, so guard the very first read too.
    ensure!(buf.has_remaining(), UnexpectedEof);
    // Low nibble: command (the high nibble is the protocol version).
    let command = match buf.get_u8() & 0x0F {
        0 => ProxyCommand::Local,
        1 => ProxyCommand::Proxy,
        cmd => return UnknownCommand { cmd }.fail(),
    };

    // One family/transport byte followed by the big-endian u16 length.
    ensure!(buf.remaining() >= 3, UnexpectedEof);
    let family_and_transport = buf.get_u8();
    let address_family = match family_and_transport >> 4 {
        0 => ProxyAddressFamily::Unspec,
        1 => ProxyAddressFamily::Inet,
        2 => ProxyAddressFamily::Inet6,
        3 => ProxyAddressFamily::Unix,
        family => return UnknownAddressFamily { family }.fail(),
    };
    let transport_protocol = match family_and_transport & 0x0F {
        0 => ProxyTransportProtocol::Unspec,
        1 => ProxyTransportProtocol::Stream,
        2 => ProxyTransportProtocol::Datagram,
        protocol => return UnknownTransportProtocol { protocol }.fail(),
    };
    let length = usize::from(buf.get_u16());

    // Ports are only present when a transport protocol is specified.
    let has_ports = transport_protocol != ProxyTransportProtocol::Unspec;
    let port_length = if has_ports { 2 * 2 } else { 0 };
    // Number of bytes the address block must occupy for this family.
    let needs = match address_family {
        ProxyAddressFamily::Unspec => 0,
        ProxyAddressFamily::Inet => 4 * 2 + port_length,
        ProxyAddressFamily::Inet6 => 16 * 2 + port_length,
        ProxyAddressFamily::Unix => UNIX_ADDRESS_LENGTH * 2,
    };
    ensure!(
        length >= needs,
        InsufficientLengthSpecified {
            given: length,
            needs,
        }
    );
    // Require the *whole* declared length up front: the trailing `advance`
    // below would otherwise panic on a truncated buffer.
    ensure!(buf.remaining() >= length, UnexpectedEof);

    let addresses = match address_family {
        ProxyAddressFamily::Unspec => ProxyAddresses::Unspec,
        ProxyAddressFamily::Inet => {
            let mut source = [0u8; 4];
            let mut destination = [0u8; 4];
            buf.copy_to_slice(&mut source);
            buf.copy_to_slice(&mut destination);
            let (source_port, destination_port) = read_ports(&mut *buf, has_ports);
            ProxyAddresses::Ipv4 {
                source: (Ipv4Addr::from(source), source_port),
                destination: (Ipv4Addr::from(destination), destination_port),
            }
        }
        ProxyAddressFamily::Inet6 => {
            let mut source = [0u8; 16];
            let mut destination = [0u8; 16];
            buf.copy_to_slice(&mut source);
            buf.copy_to_slice(&mut destination);
            let (source_port, destination_port) = read_ports(&mut *buf, has_ports);
            ProxyAddresses::Ipv6 {
                source: (Ipv6Addr::from(source), source_port),
                destination: (Ipv6Addr::from(destination), destination_port),
            }
        }
        ProxyAddressFamily::Unix => {
            let mut source = [0u8; UNIX_ADDRESS_LENGTH];
            let mut destination = [0u8; UNIX_ADDRESS_LENGTH];
            buf.copy_to_slice(&mut source);
            buf.copy_to_slice(&mut destination);
            ProxyAddresses::Unix {
                source,
                destination,
            }
        }
    };

    // Exactly `needs` bytes were read above; skip whatever else the declared
    // length covers (e.g. TLV extensions) without interpreting it.
    buf.advance(length - needs);

    Ok(super::ProxyHeader::Version2 {
        command,
        transport_protocol,
        addresses,
    })
}

/// Encodes a complete PROXY protocol version 2 header, signature included.
///
/// The address block is written in the order mandated by the specification:
/// source address, destination address, source port, destination port (the
/// AF_UNIX block has no ports). Ports are written when either side carries
/// one; a missing port on the other side is written as `0`, so the length
/// field always matches the bytes that follow it.
///
/// Whether ports are written depends only on the `Option`s in `addresses`, not
/// on `transport_protocol`; [`parse`] only reads ports when the transport
/// protocol is not `Unspec`.
pub(crate) fn encode(
    command: ProxyCommand,
    transport_protocol: ProxyTransportProtocol,
    addresses: ProxyAddresses,
) -> BytesMut {
    const SIG: [u8; 12] = [
        0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
    ];
    /// Signature + version/command byte + family/transport byte + u16 length.
    const FIXED_HEADER_LENGTH: usize = 16;
    /// Largest address block this function emits (two AF_UNIX addresses).
    const MAX_ADDRESS_LENGTH: usize = 2 * 108;

    /// Writes both ports if either is present, substituting 0 for a missing one.
    fn put_ports(buf: &mut BytesMut, source: Option<u16>, destination: Option<u16>) {
        if source.is_some() || destination.is_some() {
            buf.put_u16(source.unwrap_or(0));
            buf.put_u16(destination.unwrap_or(0));
        }
    }

    // High nibble: protocol version 2; low nibble: command.
    let ver_cmd: u8 = (2 << 4)
        | match command {
            ProxyCommand::Local => 0,
            ProxyCommand::Proxy => 1,
        };
    // High nibble: address family; low nibble: transport protocol.
    let family: u8 = match addresses {
        ProxyAddresses::Unspec => 0,
        ProxyAddresses::Ipv4 { .. } => 1,
        ProxyAddresses::Ipv6 { .. } => 2,
        ProxyAddresses::Unix { .. } => 3,
    };
    let transport: u8 = match transport_protocol {
        ProxyTransportProtocol::Unspec => 0,
        ProxyTransportProtocol::Stream => 1,
        ProxyTransportProtocol::Datagram => 2,
    };

    let mut buf = BytesMut::with_capacity(FIXED_HEADER_LENGTH + MAX_ADDRESS_LENGTH);
    buf.put_slice(&SIG);
    buf.put_u8(ver_cmd);
    buf.put_u8((family << 4) | transport);
    // Length placeholder, patched once the address block has been written so
    // the two can never disagree.
    buf.put_u16(0);

    match addresses {
        ProxyAddresses::Unspec => {}
        ProxyAddresses::Ipv4 {
            source,
            destination,
        } => {
            buf.put_slice(&source.0.octets());
            buf.put_slice(&destination.0.octets());
            put_ports(&mut buf, source.1, destination.1);
        }
        ProxyAddresses::Ipv6 {
            source,
            destination,
        } => {
            buf.put_slice(&source.0.octets());
            buf.put_slice(&destination.0.octets());
            put_ports(&mut buf, source.1, destination.1);
        }
        ProxyAddresses::Unix {
            source,
            destination,
        } => {
            buf.put_slice(&source);
            buf.put_slice(&destination);
        }
    }

    let length = buf.len() - FIXED_HEADER_LENGTH;
    debug_assert!(length <= MAX_ADDRESS_LENGTH);
    // At most MAX_ADDRESS_LENGTH (216) bytes, so the cast cannot truncate.
    let length = length as u16;
    buf[FIXED_HEADER_LENGTH - 2..FIXED_HEADER_LENGTH].copy_from_slice(&length.to_be_bytes());
    buf
}

#[cfg(test)]
mod parse_tests {
    use super::*;
    use crate::ProxyHeader;
    use bytes::{Bytes, BytesMut};
    use pretty_assertions::assert_eq;
    use rand::prelude::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    #[test]
    fn test_unspec() {
        assert_eq!(
            parse(&mut &[0u8; 16][..]),
            Ok(ProxyHeader::Version2 {
                command: ProxyCommand::Local,
                addresses: ProxyAddresses::Unspec,
                transport_protocol: ProxyTransportProtocol::Unspec,
            }),
        );
        let mut prefix = BytesMut::from(&[1u8][..]);
        prefix.reserve(16);
        prefix.extend_from_slice(&[0u8; 16][..]);
        assert_eq!(
            parse(&mut prefix),
            Ok(ProxyHeader::Version2 {
                command: ProxyCommand::Proxy,
                addresses: ProxyAddresses::Unspec,
                transport_protocol: ProxyTransportProtocol::Unspec,
            }),
        );
    }

    #[test]
    fn test_ipv4() {
        // Address block layout: source address, destination address,
        // source port, destination port.
        assert_eq!(
            parse(
                &mut &[
                    1u8,
                    (1 << 4) | 1,
                    0,
                    15,
                    // source address
                    127, 0, 0, 1,
                    // destination address
                    192, 168, 0, 1,
                    // source port
                    255, 255,
                    // destination port
                    1, 1,
                    // covered by the length, skipped without being interpreted
                    69, 0, 0,
                ][..]
            ),
            Ok(ProxyHeader::Version2 {
                command: ProxyCommand::Proxy,
                transport_protocol: ProxyTransportProtocol::Stream,
                addresses: ProxyAddresses::Ipv4 {
                    source: (Ipv4Addr::new(127, 0, 0, 1), Some(65535)),
                    destination: (Ipv4Addr::new(192, 168, 0, 1), Some(257)),
                },
            })
        );
        let mut data = Bytes::from_static(
            &[
                0u8,
                (1 << 4) | 2,
                0,
                12,
                // source address
                0, 0, 0, 0,
                // destination address
                255, 255, 255, 255,
                // source port
                0, 0,
                // destination port
                255, 0,
                // not part of the header: must be left in the buffer
                1, 2, 3, 4,
            ][..],
        );
        assert_eq!(
            parse(&mut data),
            Ok(ProxyHeader::Version2 {
                command: ProxyCommand::Local,
                transport_protocol: ProxyTransportProtocol::Datagram,
                addresses: ProxyAddresses::Ipv4 {
                    source: (Ipv4Addr::new(0, 0, 0, 0), Some(0)),
                    destination: (Ipv4Addr::new(255, 255, 255, 255), Some(255 << 8)),
                },
            })
        );
        assert!(data.remaining() == 4);
    }

    #[test]
    fn test_ipv6() {
        let mut data = vec![1u8, (2 << 4) | 2, 0, 39];
        data.extend_from_slice(&[255; 16]); // source address
        data.extend_from_slice(&[0; 16]); // destination address
        data.extend_from_slice(&[255, 255, 1, 1]); // source port, destination port
        data.extend_from_slice(&[69, 0, 0]); // covered by the length, skipped
        assert_eq!(
            parse(&mut &data[..]),
            Ok(ProxyHeader::Version2 {
                command: ProxyCommand::Proxy,
                transport_protocol: ProxyTransportProtocol::Datagram,
                addresses: ProxyAddresses::Ipv6 {
                    source: (
                        Ipv6Addr::new(65535, 65535, 65535, 65535, 65535, 65535, 65535, 65535),
                        Some(65535),
                    ),
                    destination: (Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0), Some(257)),
                },
            })
        );
        let mut data = Bytes::from_static(
            &[
                0u8,
                (2 << 4) | 1,
                0,
                36,
                // source address
                81, 92, 0, 52, 83, 12, 255, 68, 19, 5, 111, 200, 54, 90, 55, 66,
                // destination address
                255, 255, 255, 255, 0, 0, 0, 0, 123, 123, 69, 69, 21, 21, 42, 42,
                // source port
                123, 0,
                // destination port
                255, 255,
                // not part of the header: must be left in the buffer
                1, 2, 3, 4,
            ][..],
        );
        assert_eq!(
            parse(&mut data),
            Ok(ProxyHeader::Version2 {
                command: ProxyCommand::Local,
                transport_protocol: ProxyTransportProtocol::Stream,
                addresses: ProxyAddresses::Ipv6 {
                    source: (
                        Ipv6Addr::new(20828, 52, 21260, 65348, 4869, 28616, 13914, 14146),
                        Some(31488),
                    ),
                    destination: (
                        Ipv6Addr::new(65535, 65535, 0, 0, 31611, 17733, 5397, 10794),
                        Some(65535),
                    ),
                },
            })
        );
        assert!(data.remaining() == 4);
    }

    #[test]
    fn test_invalid_data() {
        let mut data = [0u8; 200];
        rand::thread_rng().fill_bytes(&mut data);
        data[0] = 99;
        assert!(parse(&mut &data[..]).is_err());
        assert_eq!(parse(&mut &[0u8][..]), Err(ParseError::UnexpectedEof));
        assert_eq!(
            parse(&mut &[1u8, (1 << 4) | 1, 0, 3][..]),
            Err(ParseError::InsufficientLengthSpecified {
                given: 3,
                needs: 4 * 2 + 2 * 2,
            }),
        );
    }

    #[test]
    fn test_truncated_data_does_not_panic() {
        // Empty buffer: the very first read must be guarded.
        assert_eq!(parse(&mut &[0u8; 0][..]), Err(ParseError::UnexpectedEof));
        // Declared length (20) exceeds the available address bytes (12).
        assert_eq!(
            parse(&mut &[1u8, (1 << 4) | 1, 0, 20, 127, 0, 0, 1, 192, 168, 0, 1, 0, 80, 0, 80][..]),
            Err(ParseError::UnexpectedEof),
        );
        // Unix family: declared length (224) exceeds the 216 bytes available.
        let mut unix = vec![1u8, (3 << 4) | 1, 0, 224];
        unix.extend_from_slice(&[0; 216]);
        assert_eq!(parse(&mut &unix[..]), Err(ParseError::UnexpectedEof));
    }

    #[test]
    fn test_round_trip() {
        let headers = [
            (
                ProxyCommand::Proxy,
                ProxyTransportProtocol::Stream,
                ProxyAddresses::Ipv4 {
                    source: (Ipv4Addr::new(10, 1, 2, 3), Some(51234)),
                    destination: (Ipv4Addr::new(192, 168, 0, 1), Some(443)),
                },
            ),
            (
                ProxyCommand::Proxy,
                ProxyTransportProtocol::Datagram,
                ProxyAddresses::Ipv6 {
                    source: (Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8), Some(8192)),
                    destination: (Ipv6Addr::new(9, 10, 11, 12, 13, 14, 15, 16), Some(53)),
                },
            ),
            (
                ProxyCommand::Local,
                ProxyTransportProtocol::Unspec,
                ProxyAddresses::Unix {
                    source: [b'a'; 108],
                    destination: [b'b'; 108],
                },
            ),
            (
                ProxyCommand::Local,
                ProxyTransportProtocol::Unspec,
                ProxyAddresses::Unspec,
            ),
        ];
        for &(command, transport_protocol, addresses) in &headers {
            let mut encoded = encode(command, transport_protocol, addresses).freeze();
            encoded.advance(12); // the signature is not part of what `parse` reads
            assert_eq!(
                parse(&mut encoded),
                Ok(ProxyHeader::Version2 {
                    command,
                    transport_protocol,
                    addresses,
                }),
            );
            assert_eq!(encoded.remaining(), 0);
        }
    }
}

#[cfg(test)]
mod encode_tests {
    use super::*;
    use bytes::{Bytes, BytesMut};
    use pretty_assertions::assert_eq;
    use std::net::{Ipv4Addr, Ipv6Addr};

    const SIG: [u8; 12] = [
        0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
    ];

    fn signed(buf: &[u8]) -> Bytes {
        let mut bytes = BytesMut::from(&SIG[..]);
        bytes.extend_from_slice(buf);
        bytes.freeze()
    }

    #[test]
    fn test_unspec() {
        assert_eq!(
            encode(
                ProxyCommand::Local,
                ProxyTransportProtocol::Unspec,
                ProxyAddresses::Unspec,
            ),
            signed(&[(2 << 4) | 0, 0, 0, 0][..]),
        );
        assert_eq!(
            encode(
                ProxyCommand::Proxy,
                ProxyTransportProtocol::Unspec,
                ProxyAddresses::Unspec,
            ),
            signed(&[(2 << 4) | 1, 0, 0, 0][..]),
        );
        assert_eq!(
            encode(
                ProxyCommand::Proxy,
                ProxyTransportProtocol::Unspec,
                ProxyAddresses::Ipv4 {
                    source: (Ipv4Addr::new(1, 2, 3, 4), Some(65535)),
                    destination: (Ipv4Addr::new(192, 168, 0, 1), Some(9012)),
                },
            ),
            signed(
                &[
                    (2 << 4) | 1,
                    (1 << 4) | 0,
                    0,
                    12,
                    // source address
                    1, 2, 3, 4,
                    // destination address
                    192, 168, 0, 1,
                    // source port
                    255, 255,
                    // destination port
                    (9012u16 >> 8) as u8,
                    9012u16 as u8,
                ][..]
            ),
        );
    }

    #[test]
    fn test_ipv4() {
        assert_eq!(
            encode(
                ProxyCommand::Proxy,
                ProxyTransportProtocol::Stream,
                ProxyAddresses::Ipv4 {
                    source: (Ipv4Addr::new(1, 2, 3, 4), Some(65535)),
                    destination: (Ipv4Addr::new(192, 168, 0, 1), Some(9012)),
                },
            ),
            signed(
                &[
                    (2 << 4) | 1,
                    (1 << 4) | 1,
                    0,
                    12,
                    // source address
                    1, 2, 3, 4,
                    // destination address
                    192, 168, 0, 1,
                    // source port
                    255, 255,
                    // destination port
                    (9012u16 >> 8) as u8,
                    9012u16 as u8,
                ][..]
            ),
        );
        assert_eq!(
            encode(
                ProxyCommand::Local,
                ProxyTransportProtocol::Datagram,
                ProxyAddresses::Ipv4 {
                    source: (Ipv4Addr::new(255, 255, 255, 255), None),
                    destination: (Ipv4Addr::new(192, 168, 0, 1), None),
                },
            ),
            signed(
                &[
                    (2 << 4) | 0,
                    (1 << 4) | 2,
                    0,
                    8,
                    255, 255, 255, 255,
                    192, 168, 0, 1,
                ][..]
            ),
        );
        // A port on only one side still yields a consistent 12-byte block.
        assert_eq!(
            encode(
                ProxyCommand::Proxy,
                ProxyTransportProtocol::Stream,
                ProxyAddresses::Ipv4 {
                    source: (Ipv4Addr::new(10, 0, 0, 1), Some(443)),
                    destination: (Ipv4Addr::new(10, 0, 0, 2), None),
                },
            ),
            signed(
                &[
                    (2 << 4) | 1,
                    (1 << 4) | 1,
                    0,
                    12,
                    10, 0, 0, 1,
                    10, 0, 0, 2,
                    (443u16 >> 8) as u8,
                    443u16 as u8,
                    0,
                    0,
                ][..]
            ),
        );
    }

    #[test]
    fn test_ipv6() {
        assert_eq!(
            encode(
                ProxyCommand::Local,
                ProxyTransportProtocol::Datagram,
                ProxyAddresses::Ipv6 {
                    source: (Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8), Some(8192)),
                    destination: (
                        Ipv6Addr::new(65535, 65535, 32767, 32766, 111, 222, 333, 444),
                        Some(0),
                    ),
                }
            ),
            signed(
                &[
                    (2 << 4) | 0,
                    (2 << 4) | 2,
                    0,
                    36,
                    // source address
                    0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, 8,
                    // destination address
                    255,
                    255,
                    255,
                    255,
                    (32767u16 >> 8) as u8,
                    32767u16 as u8,
                    (32766u16 >> 8) as u8,
                    32766u16 as u8,
                    0,
                    111,
                    0,
                    222,
                    (333u16 >> 8) as u8,
                    333u16 as u8,
                    (444u16 >> 8) as u8,
                    444u16 as u8,
                    // source port
                    (8192u16 >> 8) as u8,
                    8192u16 as u8,
                    // destination port
                    0,
                    0,
                ][..]
            ),
        );
    }

    #[test]
    fn test_unix() {
        let mut expected: Vec<u8> = vec![(2 << 4) | 1, (3 << 4) | 1, 0, 216];
        expected.extend_from_slice(&[1; 108]);
        expected.extend_from_slice(&[2; 108]);
        assert_eq!(
            encode(
                ProxyCommand::Proxy,
                ProxyTransportProtocol::Stream,
                ProxyAddresses::Unix {
                    source: [1; 108],
                    destination: [2; 108],
                },
            ),
            signed(&expected),
        );
    }
}