#![deny(warnings)]
#![warn(unused_extern_crates)]
#![deny(clippy::todo)]
#![deny(clippy::unimplemented)]
#![deny(clippy::unwrap_used)]
#![deny(clippy::expect_used)]
#![deny(clippy::panic)]
#![deny(clippy::unreachable)]
#![deny(clippy::await_holding_lock)]
#![deny(clippy::needless_pass_by_value)]
#![deny(clippy::trivially_copy_pass_by_ref)]
use crate::parse::{parse_proxy_hdr_v1, parse_proxy_hdr_v2};
use std::num::NonZeroUsize;
#[cfg(any(test, feature = "tokio"))]
use crate::parse::{V1_MAX_LEN, V1_MIN_LEN};
/// `1` as a [`NonZeroUsize`]: the fallback "need one more byte" amount reported
/// when the v1 parser cannot say how much input is missing.
///
/// `NonZeroUsize::MIN` is exactly 1 and avoids an `expect` call, which this crate
/// denies via `clippy::expect_used`.
const NZ_ONE: NonZeroUsize = NonZeroUsize::MIN;
mod parse;
/// Transport protocol and address family carried by a PROXY header.
///
/// The discriminant values follow a nibble pattern: the high nibble picks the
/// address family (1 = IPv4, 2 = IPv6) and the low nibble the transport
/// (1 = TCP, 2 = UDP). They appear to mirror the v2 family/protocol byte.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Protocol {
    /// Unspecified / unknown protocol.
    Unspec = 0x00,
    /// TCP over IPv4.
    TcpV4 = 0x11,
    /// UDP over IPv4.
    UdpV4 = 0x12,
    /// TCP over IPv6.
    TcpV6 = 0x21,
    /// UDP over IPv6.
    UdpV6 = 0x22,
}
/// Command field of a PROXY v2 header.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Command {
    /// The connection was not proxied; the address block is not meaningful.
    Local = 0x00,
    /// The connection was proxied on behalf of another peer.
    Proxy = 0x01,
}
/// Source/destination socket pair decoded from a PROXY header.
#[derive(Clone, Debug, Eq, PartialEq)]
enum Address {
    /// No address information was present.
    None,
    /// An IPv4 source and destination.
    V4 {
        src: std::net::SocketAddrV4,
        dst: std::net::SocketAddrV4,
    },
    /// An IPv6 source and destination.
    V6 {
        src: std::net::SocketAddrV6,
        dst: std::net::SocketAddrV6,
    },
}
/// The peer addresses conveyed by a PROXY header, as produced by
/// [`ProxyHdrV1::to_remote_addr`] and [`ProxyHdrV2::to_remote_addr`].
///
/// In every address-carrying variant `src` is the source address and `dst`
/// the destination address taken from the header.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RemoteAddress {
    /// The v2 header carried the `LOCAL` command; no proxied addresses apply.
    Local,
    /// The header's protocol and address family did not form a usable pair
    /// (for example an unspecified protocol, or IPv6 addresses with a v4 protocol).
    Invalid,
    /// TCP over IPv4.
    TcpV4 {
        src: std::net::SocketAddrV4,
        dst: std::net::SocketAddrV4,
    },
    /// UDP over IPv4.
    UdpV4 {
        src: std::net::SocketAddrV4,
        dst: std::net::SocketAddrV4,
    },
    /// TCP over IPv6.
    TcpV6 {
        src: std::net::SocketAddrV6,
        dst: std::net::SocketAddrV6,
    },
    /// UDP over IPv6.
    UdpV6 {
        src: std::net::SocketAddrV6,
        dst: std::net::SocketAddrV6,
    },
}
/// Errors returned by `ProxyHdrV1::parse` and `ProxyHdrV2::parse`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// More input is required before the header can be parsed.
    Incomplete {
        /// Number of further bytes the parser asked for.
        need: NonZeroUsize,
    },
    /// The input is not a valid PROXY header.
    Invalid,
    /// More input is required, but the parser could not say how much
    /// (only reported by the v2 parser).
    UnableToComplete,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Incomplete { need } => {
                write!(f, "incomplete proxy header: {need} more byte(s) needed")
            }
            Self::Invalid => f.write_str("invalid proxy header"),
            Self::UnableToComplete => f.write_str("proxy header could not be completed"),
        }
    }
}

impl std::error::Error for Error {}
/// A parsed PROXY protocol version 2 (binary) header.
///
/// Obtain one with [`ProxyHdrV2::parse`] and convert it into a
/// [`RemoteAddress`] with [`ProxyHdrV2::to_remote_addr`].
#[derive(Clone, Debug)]
pub struct ProxyHdrV2 {
    // Whether the connection is `LOCAL` or `PROXY`.
    command: Command,
    // Transport protocol / address family byte.
    protocol: Protocol,
    // Decoded source and destination addresses, if any.
    address: Address,
}
impl ProxyHdrV2 {
pub fn parse(input_data: &[u8]) -> Result<(usize, Self), Error> {
match parse_proxy_hdr_v2(input_data) {
Ok((remainder, hdr)) => {
let took = input_data.len() - remainder.len();
Ok((took, hdr))
}
Err(nom::Err::Incomplete(nom::Needed::Size(need))) => Err(Error::Incomplete { need }),
Err(nom::Err::Incomplete(nom::Needed::Unknown)) => Err(Error::UnableToComplete),
Err(nom::Err::Error(err)) => {
tracing::error!(?err);
Err(Error::Invalid)
}
Err(nom::Err::Failure(err)) => {
tracing::error!(?err, "parser failure handling proxy v2 header");
Err(Error::Invalid)
}
}
}
pub fn to_remote_addr(self) -> RemoteAddress {
match (self.command, self.protocol, self.address) {
(Command::Local, _, _) => RemoteAddress::Local,
(Command::Proxy, Protocol::TcpV4, Address::V4 { src, dst }) => {
RemoteAddress::TcpV4 { src, dst }
}
(Command::Proxy, Protocol::UdpV4, Address::V4 { src, dst }) => {
RemoteAddress::UdpV4 { src, dst }
}
(Command::Proxy, Protocol::TcpV6, Address::V6 { src, dst }) => {
RemoteAddress::TcpV6 { src, dst }
}
(Command::Proxy, Protocol::UdpV6, Address::V6 { src, dst }) => {
RemoteAddress::UdpV6 { src, dst }
}
_ => RemoteAddress::Invalid,
}
}
}
/// A parsed PROXY protocol version 1 (text, e.g. `PROXY TCP4 ...\r\n`) header.
///
/// Obtain one with [`ProxyHdrV1::parse`] and convert it into a
/// [`RemoteAddress`] with [`ProxyHdrV1::to_remote_addr`].
#[derive(Clone, Debug)]
pub struct ProxyHdrV1 {
    // Transport protocol / address family named in the header.
    protocol: Protocol,
    // Decoded source and destination addresses.
    address: Address,
}
impl ProxyHdrV1 {
pub fn parse(input_data: &[u8]) -> Result<(usize, Self), Error> {
match parse_proxy_hdr_v1(input_data) {
Ok((remainder, hdr)) => {
let took = input_data.len() - remainder.len();
Ok((took, hdr))
}
Err(nom::Err::Incomplete(nom::Needed::Size(need))) => Err(Error::Incomplete { need }),
Err(nom::Err::Incomplete(nom::Needed::Unknown)) => {
Err(Error::Incomplete { need: NZ_ONE })
}
Err(nom::Err::Error(err)) => {
tracing::error!(?err);
Err(Error::Invalid)
}
Err(nom::Err::Failure(err)) => {
tracing::error!(?err, "parser failure handling proxy v1 header");
Err(Error::Invalid)
}
}
}
pub fn to_remote_addr(self) -> RemoteAddress {
match (self.protocol, self.address) {
(Protocol::TcpV4, Address::V4 { src, dst }) => RemoteAddress::TcpV4 { src, dst },
(Protocol::UdpV4, Address::V4 { src, dst }) => RemoteAddress::UdpV4 { src, dst },
(Protocol::TcpV6, Address::V6 { src, dst }) => RemoteAddress::TcpV6 { src, dst },
(Protocol::UdpV6, Address::V6 { src, dst }) => RemoteAddress::UdpV6 { src, dst },
_ => RemoteAddress::Invalid,
}
}
}
/// Errors produced while reading a PROXY header from an async stream.
#[cfg(any(feature = "tokio", test))]
#[derive(Debug)]
pub enum AsyncReadError {
    /// Reading from the underlying stream failed (including hitting EOF early).
    Io(std::io::Error),
    /// The bytes read do not form a valid header.
    Invalid,
    /// The header could not be read to completion.
    UnableToComplete,
    /// The header would exceed the size the reader is willing to buffer.
    RequestTooLarge,
    /// The parser consumed a different number of bytes than were read from the stream.
    InconsistentRead,
}

#[cfg(any(feature = "tokio", test))]
impl std::fmt::Display for AsyncReadError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Io(err) => write!(f, "i/o error while reading proxy header: {err}"),
            Self::Invalid => f.write_str("invalid proxy header"),
            Self::UnableToComplete => f.write_str("proxy header could not be completed"),
            Self::RequestTooLarge => f.write_str("proxy header exceeds the size limit"),
            Self::InconsistentRead => f.write_str(
                "proxy header parser consumed a different number of bytes than were read",
            ),
        }
    }
}

#[cfg(any(feature = "tokio", test))]
impl std::error::Error for AsyncReadError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(err) => Some(err),
            Self::Invalid
            | Self::UnableToComplete
            | Self::RequestTooLarge
            | Self::InconsistentRead => None,
        }
    }
}

#[cfg(any(feature = "tokio", test))]
impl From<std::io::Error> for AsyncReadError {
    fn from(err: std::io::Error) -> Self {
        Self::Io(err)
    }
}
#[cfg(any(feature = "tokio", test))]
impl ProxyHdrV2 {
    /// Reads a PROXY protocol v2 header from `stream` and parses it.
    ///
    /// A fixed 16-byte prefix is read first. If the parser then reports how many
    /// more bytes it needs, exactly that many are read and the header is parsed
    /// again. On success the stream is handed back together with the header, so
    /// the caller can keep reading the proxied data.
    ///
    /// # Errors
    ///
    /// * [`AsyncReadError::Io`] if either read from `stream` fails.
    /// * [`AsyncReadError::RequestTooLarge`] if the header would exceed 512 bytes.
    /// * [`AsyncReadError::Invalid`] if the bytes are not a valid v2 header.
    /// * [`AsyncReadError::UnableToComplete`] if the header still cannot be completed.
    /// * [`AsyncReadError::InconsistentRead`] if the parser consumed a different
    ///   number of bytes than were read.
    pub async fn parse_from_read<S>(mut stream: S) -> Result<(S, Self), AsyncReadError>
    where
        S: tokio::io::AsyncReadExt + std::marker::Unpin,
    {
        use tracing::{debug, error};

        // Size of the first read; whatever the parser still needs is fetched afterwards.
        const PREFIX_LEN: usize = 16;
        // Upper bound on the total header size this reader is willing to buffer.
        const HDR_SIZE_LIMIT: usize = 512;

        let mut buf = vec![0; PREFIX_LEN];
        let mut took = stream
            .read_exact(&mut buf)
            .await
            .map_err(AsyncReadError::Io)?;

        // Either the prefix alone is a complete header, or the parser says how
        // many extra bytes must be fetched.
        let extra = match ProxyHdrV2::parse(&buf) {
            Ok((_, hdr)) => return Ok((stream, hdr)),
            Err(Error::Incomplete { need }) => usize::from(need),
            Err(Error::Invalid) => {
                debug!(proxy_binary_dump = %hex::encode(&buf));
                error!("proxy v2 header was invalid");
                return Err(AsyncReadError::Invalid);
            }
            Err(Error::UnableToComplete) => {
                debug!(proxy_binary_dump = %hex::encode(&buf));
                error!("proxy v2 header was incomplete");
                return Err(AsyncReadError::UnableToComplete);
            }
        };

        let total_len = buf.len() + extra;
        if total_len > HDR_SIZE_LIMIT {
            error!(
                "proxy v2 header request was larger than {} bytes, refusing to proceed.",
                HDR_SIZE_LIMIT
            );
            return Err(AsyncReadError::RequestTooLarge);
        }
        buf.resize(total_len, 0);

        took += stream
            .read_exact(&mut buf[PREFIX_LEN..])
            .await
            .map_err(AsyncReadError::Io)?;

        match ProxyHdrV2::parse(&buf) {
            Ok((hdr_took, hdr)) => {
                if hdr_took == took {
                    Ok((stream, hdr))
                } else {
                    error!("proxy v2 header read an inconsistent amount from stream.");
                    Err(AsyncReadError::InconsistentRead)
                }
            }
            Err(Error::Incomplete { .. }) => {
                error!("proxy v2 header could not be read to the end.");
                Err(AsyncReadError::UnableToComplete)
            }
            Err(Error::Invalid) => {
                debug!(proxy_binary_dump = %hex::encode(&buf));
                error!("proxy v2 header was invalid");
                Err(AsyncReadError::Invalid)
            }
            Err(Error::UnableToComplete) => {
                debug!(proxy_binary_dump = %hex::encode(&buf));
                error!("proxy v2 header was incomplete");
                Err(AsyncReadError::UnableToComplete)
            }
        }
    }
}
#[cfg(any(feature = "tokio", test))]
impl ProxyHdrV1 {
    /// Reads a PROXY protocol v1 (text) header from `stream` and parses it.
    ///
    /// `V1_MIN_LEN` bytes are read first. After that, only as many bytes as the
    /// parser asks for are read at each step. On success the parser must have
    /// consumed exactly what was read, so the returned stream is positioned at
    /// the first byte after the header.
    ///
    /// # Errors
    ///
    /// * [`AsyncReadError::Io`] if reading from `stream` fails.
    /// * [`AsyncReadError::Invalid`] if the header is malformed or would not fit
    ///   in the `V1_MAX_LEN + 1` byte buffer.
    /// * [`AsyncReadError::InconsistentRead`] if the parser consumed a different
    ///   number of bytes than were read.
    /// * [`AsyncReadError::UnableToComplete`] if the parser reports it cannot complete.
    pub async fn parse_from_read<S>(mut stream: S) -> Result<(S, Self), AsyncReadError>
    where
        S: tokio::io::AsyncReadExt + std::marker::Unpin,
    {
        use tracing::{debug, error};

        // One byte of headroom beyond the longest accepted header.
        let mut buf = [0; V1_MAX_LEN + 1];
        let mut took = stream
            .read_exact(&mut buf[..V1_MIN_LEN])
            .await
            .map_err(AsyncReadError::Io)?;

        // Invariant: `took <= buf.len()`. `read_exact` fills its slice exactly,
        // and every slice below is bounds-checked before it is taken.
        loop {
            match ProxyHdrV1::parse(&buf[..took]) {
                Ok((hdr_took, _)) if hdr_took != took => {
                    error!("proxy v1 header read an inconsistent amount from stream.");
                    return Err(AsyncReadError::InconsistentRead);
                }
                Ok((_, hdr)) => return Ok((stream, hdr)),
                Err(Error::Incomplete { need }) => {
                    // Check the bounds *before* slicing. A peer that never ends
                    // its line (or a large `need`) must produce an error, not a
                    // slice-index panic.
                    let end = match took.checked_add(need.get()) {
                        Some(end) if end <= buf.len() => end,
                        _ => {
                            debug!(proxy_binary_dump = %hex::encode(&buf[..took]));
                            error!("proxy v1 header read over ran the buffer allocation.");
                            return Err(AsyncReadError::Invalid);
                        }
                    };
                    took += stream
                        .read_exact(&mut buf[took..end])
                        .await
                        .map_err(AsyncReadError::Io)?;
                }
                Err(Error::Invalid) => {
                    debug!(proxy_binary_dump = %hex::encode(&buf[..took]));
                    error!("proxy v1 header was invalid");
                    return Err(AsyncReadError::Invalid);
                }
                Err(Error::UnableToComplete) => {
                    debug!(proxy_binary_dump = %hex::encode(&buf[..took]));
                    error!("proxy v1 header was incomplete");
                    return Err(AsyncReadError::UnableToComplete);
                }
            }
        }
    }
}
#[cfg(test)]
// Test code is expected to panic on failure, so the crate-wide denials of
// `unwrap_used`/`expect_used` (meant for library code) are lifted here.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use crate::{Address, Command, Protocol, ProxyHdrV1, ProxyHdrV2};
    use std::net::SocketAddrV4;
    use std::str::FromStr;

    #[tokio::test]
    async fn proxyv1_stream_parse() {
        let _ = tracing_subscriber::fmt::try_init();
        let data = "PROXY TCP4 91.221.138.33 91.221.138.106 47780 636\r\n";
        let (_, hdr) = ProxyHdrV1::parse_from_read(data.as_bytes())
            .await
            .expect("should parse v1 header");
        tracing::debug!(?hdr);
        assert_eq!(hdr.protocol, Protocol::TcpV4);
        assert_eq!(
            hdr.address,
            Address::V4 {
                src: SocketAddrV4::from_str("91.221.138.33:47780").expect("valid addr"),
                dst: SocketAddrV4::from_str("91.221.138.106:636").expect("valid addr"),
            }
        );
    }

    #[tokio::test]
    async fn proxyv2_stream_parse() {
        let _ = tracing_subscriber::fmt::try_init();
        let sample = hex::decode("0d0a0d0a000d0a515549540a2111000cac180c76ac180b8fcdcb027d")
            .expect("valid hex");
        let (_, hdr) = ProxyHdrV2::parse_from_read(sample.as_slice())
            .await
            .expect("should parse v4 addr");
        tracing::debug!(?hdr);
        assert_eq!(hdr.command, Command::Proxy);
        assert_eq!(hdr.protocol, Protocol::TcpV4);
        assert_eq!(
            hdr.address,
            Address::V4 {
                src: SocketAddrV4::from_str("172.24.12.118:52683").expect("valid addr"),
                dst: SocketAddrV4::from_str("172.24.11.143:637").expect("valid addr"),
            }
        );
    }

    /// Stream-level tests that feed headers through a `tokio::io::duplex` pipe in
    /// small chunks. They check that the bytes following the header are left
    /// unread in the returned stream.
    #[cfg(feature = "tokio")]
    mod async_stream_tests {
        use super::*;
        use std::net::{SocketAddrV4, SocketAddrV6};
        use std::str::FromStr;
        use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};

        /// Writes `data` to `writer` in pieces of the given sizes, yielding between
        /// pieces so the reader observes partial input. Anything left after the
        /// listed sizes is written in one final piece. `writer` is dropped on
        /// return, which closes the write half.
        async fn write_in_chunks<W>(mut writer: W, data: &[u8], chunk_sizes: &[usize])
        where
            W: AsyncWrite + Unpin,
        {
            let mut offset = 0;
            for &size in chunk_sizes {
                if offset >= data.len() {
                    break;
                }
                let end = (offset + size).min(data.len());
                writer
                    .write_all(&data[offset..end])
                    .await
                    .expect("chunk write should succeed");
                tokio::task::yield_now().await;
                offset = end;
            }
            if offset < data.len() {
                writer
                    .write_all(&data[offset..])
                    .await
                    .expect("final write should succeed");
            }
        }

        #[tokio::test]
        async fn tokio_stream_parse_v2_chunks() {
            let _ = tracing_subscriber::fmt::try_init();
            let sample = hex::decode("0d0a0d0a000d0a515549540a2111000cac180c76ac180b8fcdcb027d")
                .expect("valid hex");
            let payload = b"hello";
            let mut full = sample;
            full.extend_from_slice(payload);
            let (client, server) = tokio::io::duplex(32);
            let writer = tokio::spawn(async move {
                write_in_chunks(server, &full, &[5, 3, 1, 7, 2]).await;
            });
            let (mut stream, hdr) = ProxyHdrV2::parse_from_read(client)
                .await
                .expect("should parse v2 from stream");
            let mut extra = vec![0; payload.len()];
            stream
                .read_exact(&mut extra)
                .await
                .expect("should read extra payload");
            writer.await.expect("writer task should finish");
            assert_eq!(extra.as_slice(), payload);
            assert_eq!(hdr.command, Command::Proxy);
            assert_eq!(hdr.protocol, Protocol::TcpV4);
            assert_eq!(
                hdr.address,
                Address::V4 {
                    src: SocketAddrV4::from_str("172.24.12.118:52683").expect("valid addr"),
                    dst: SocketAddrV4::from_str("172.24.11.143:637").expect("valid addr"),
                }
            );
        }

        #[tokio::test]
        async fn tokio_stream_parse_v1_chunks() {
            let _ = tracing_subscriber::fmt::try_init();
            let header = b"PROXY TCP6 ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\n";
            let payload = b"more_data";
            let mut full = header.to_vec();
            full.extend_from_slice(payload);
            let (client, server) = tokio::io::duplex(64);
            let writer = tokio::spawn(async move {
                write_in_chunks(server, &full, &[4, 1, 8, 2, 3, 5, 1]).await;
            });
            let (mut stream, hdr) = ProxyHdrV1::parse_from_read(client)
                .await
                .expect("should parse v1 from stream");
            let mut extra = vec![0; payload.len()];
            stream
                .read_exact(&mut extra)
                .await
                .expect("should read extra payload");
            writer.await.expect("writer task should finish");
            assert_eq!(extra.as_slice(), payload);
            assert_eq!(hdr.protocol, Protocol::TcpV6);
            assert_eq!(
                hdr.address,
                Address::V6 {
                    src: SocketAddrV6::from_str("[ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff]:65535")
                        .expect("valid addr"),
                    dst: SocketAddrV6::from_str("[ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff]:65535")
                        .expect("valid addr"),
                }
            );
        }
    }
}