#![cfg_attr(docsrs, feature(doc_cfg))]
#![cfg_attr(docsrs, allow(unused_attributes))]
mod util;
mod v1;
mod v2;
pub mod io;
use crate::util::{tlv, tlv_borrowed};
use std::borrow::Cow;
use std::fmt;
use std::net::SocketAddr;
/// Transport semantics of a proxied connection.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum Protocol {
/// Stream-oriented transport; this is what a v1 `TCP4`/`TCP6` header parses to.
Stream,
/// Datagram-oriented transport.
Datagram,
}
/// The pair of endpoints described by a PROXY header, plus the kind of
/// transport the proxied connection uses.
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct ProxiedAddress {
    /// Whether the proxied connection is stream- or datagram-based.
    pub protocol: Protocol,
    /// Address the proxied connection originates from.
    pub source: SocketAddr,
    /// Address the proxied connection was made to.
    pub destination: SocketAddr,
}

impl ProxiedAddress {
    /// Creates the address information for a [`Protocol::Stream`] connection.
    pub fn stream(source: SocketAddr, destination: SocketAddr) -> Self {
        Self::with_protocol(Protocol::Stream, source, destination)
    }

    /// Creates the address information for a [`Protocol::Datagram`] connection.
    pub fn datagram(source: SocketAddr, destination: SocketAddr) -> Self {
        Self::with_protocol(Protocol::Datagram, source, destination)
    }

    /// Shared constructor behind [`Self::stream`] and [`Self::datagram`].
    fn with_protocol(protocol: Protocol, source: SocketAddr, destination: SocketAddr) -> Self {
        ProxiedAddress {
            protocol,
            source,
            destination,
        }
    }
}
pub struct Tlvs<'a> {
buf: &'a [u8],
}
impl<'a> Iterator for Tlvs<'a> {
type Item = Result<Tlv<'a>, Error>;
fn next(&mut self) -> Option<Self::Item> {
if self.buf.is_empty() {
return None;
}
let kind = self.buf[0];
match self
.buf
.get(1..3)
.map(|s| u16::from_be_bytes(s.try_into().unwrap()) as usize)
{
Some(u) if u + 3 <= self.buf.len() => {
let (ret, new) = self.buf.split_at(3 + u);
self.buf = new;
Some(Tlv::decode(kind, &ret[3..]))
}
_ => {
self.buf = &[];
Some(Err(Error::Invalid))
}
}
}
}
/// SSL/TLS details carried by an SSL TLV (type `0x20`).
///
/// The tuple fields are, in order: the client flag byte, the `verify` value,
/// and the raw bytes of the nested sub-TLVs.
#[derive(PartialEq, Eq, Clone)]
pub struct SslInfo<'a>(u8, u32, Cow<'a, [u8]>);
impl<'a> SslInfo<'a> {
pub fn new(
client_ssl: bool,
client_cert_conn: bool,
client_cert_sess: bool,
verify: u32,
) -> Self {
Self(
(client_ssl as u8) | (client_cert_conn as u8) << 1 | (client_cert_sess as u8) << 2,
verify,
Default::default(),
)
}
pub fn client_ssl(&self) -> bool {
self.0 & 0x01 != 0
}
pub fn client_cert_conn(&self) -> bool {
self.0 & 0x02 != 0
}
pub fn client_cert_sess(&self) -> bool {
self.0 & 0x04 != 0
}
pub fn verify(&self) -> u32 {
self.1
}
pub fn tlvs(&self) -> Tlvs<'_> {
Tlvs { buf: &self.2 }
}
pub fn version(&self) -> Option<&str> {
tlv_borrowed!(self, SslVersion)
}
pub fn cn(&self) -> Option<&str> {
tlv_borrowed!(self, SslCn)
}
pub fn cipher(&self) -> Option<&str> {
tlv_borrowed!(self, SslCipher)
}
pub fn sig_alg(&self) -> Option<&str> {
tlv_borrowed!(self, SslSigAlg)
}
pub fn key_alg(&self) -> Option<&str> {
tlv_borrowed!(self, SslKeyAlg)
}
pub fn into_owned(self) -> SslInfo<'static> {
SslInfo(self.0, self.1, Cow::Owned(self.2.into_owned()))
}
pub fn append_tlv(&mut self, tlv: Tlv<'_>) {
tlv.encode(self.2.to_mut());
}
}
impl fmt::Debug for SslInfo<'_> {
    /// Formats as a struct named `Ssl`, showing the decoded flags and the
    /// nested TLVs instead of the raw tuple fields.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let nested: Vec<_> = self.tlvs().collect();

        let mut out = f.debug_struct("Ssl");
        out.field("verify", &self.verify());
        out.field("client_ssl", &self.client_ssl());
        out.field("client_cert_conn", &self.client_cert_conn());
        out.field("client_cert_sess", &self.client_cert_sess());
        out.field("fields", &nested);
        out.finish()
    }
}
/// A single Type-Length-Value field of a PROXY header.
///
/// The type byte for each variant is the one returned by `Tlv::kind`.
#[non_exhaustive]
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum Tlv<'a> {
/// ALPN value as raw bytes (type `0x01`).
Alpn(Cow<'a, [u8]>),
/// Authority as UTF-8 text (type `0x02`).
Authority(Cow<'a, str>),
/// CRC32c value, stored big-endian in a four-byte payload (type `0x03`).
Crc32c(u32),
/// No-op field of the given payload length (type `0x04`); the payload is
/// written as zero bytes when encoding and only its length is kept when decoding.
Noop(usize),
/// Unique connection ID as raw bytes (type `0x05`).
UniqueId(Cow<'a, [u8]>),
/// SSL information, itself containing nested TLVs (type `0x20`).
Ssl(SslInfo<'a>),
/// Network namespace name as UTF-8 text (type `0x30`).
Netns(Cow<'a, str>),
/// SSL version as UTF-8 text (type `0x21`).
SslVersion(Cow<'a, str>),
/// SSL common name as UTF-8 text (type `0x22`).
SslCn(Cow<'a, str>),
/// SSL cipher as UTF-8 text (type `0x23`).
SslCipher(Cow<'a, str>),
/// SSL signature algorithm as UTF-8 text (type `0x24`).
SslSigAlg(Cow<'a, str>),
/// SSL key algorithm as UTF-8 text (type `0x25`).
SslKeyAlg(Cow<'a, str>),
/// Any other type byte, kept together with its undecoded payload.
Custom(u8, Cow<'a, [u8]>),
}
impl<'a> Tlv<'a> {
pub fn decode(kind: u8, data: &'a [u8]) -> Result<Tlv<'a>, Error> {
use std::str::from_utf8;
use Tlv::*;
match kind {
0x01 => Ok(Alpn(data.into())),
0x02 => Ok(Authority(
from_utf8(data).map_err(|_| Error::Invalid)?.into(),
)),
0x03 => Ok(Crc32c(u32::from_be_bytes(
data.try_into().map_err(|_| Error::Invalid)?,
))),
0x04 => Ok(Noop(data.len())),
0x05 => Ok(UniqueId(data.into())),
0x20 => Ok(Ssl(SslInfo(
*data.first().ok_or(Error::Invalid)?,
u32::from_be_bytes(
data.get(1..5)
.ok_or(Error::Invalid)?
.try_into()
.map_err(|_| Error::Invalid)?,
),
data.get(5..).ok_or(Error::Invalid)?.into(),
))),
0x21 => Ok(SslVersion(
from_utf8(data).map_err(|_| Error::Invalid)?.into(),
)),
0x22 => Ok(SslCn(from_utf8(data).map_err(|_| Error::Invalid)?.into())),
0x23 => Ok(SslCipher(
from_utf8(data).map_err(|_| Error::Invalid)?.into(),
)),
0x24 => Ok(SslSigAlg(
from_utf8(data).map_err(|_| Error::Invalid)?.into(),
)),
0x25 => Ok(SslKeyAlg(
from_utf8(data).map_err(|_| Error::Invalid)?.into(),
)),
0x30 => Ok(Netns(from_utf8(data).map_err(|_| Error::Invalid)?.into())),
a => Ok(Custom(a, data.into())),
}
}
pub fn kind(&self) -> u8 {
match self {
Tlv::Alpn(_) => 0x01,
Tlv::Authority(_) => 0x02,
Tlv::Crc32c(_) => 0x03,
Tlv::Noop(_) => 0x04,
Tlv::UniqueId(_) => 0x05,
Tlv::Ssl(_) => 0x20,
Tlv::Netns(_) => 0x30,
Tlv::SslVersion(_) => 0x21,
Tlv::SslCn(_) => 0x22,
Tlv::SslCipher(_) => 0x23,
Tlv::SslSigAlg(_) => 0x24,
Tlv::SslKeyAlg(_) => 0x25,
Tlv::Custom(a, _) => *a,
}
}
pub fn encode(&self, buf: &mut Vec<u8>) {
let initial = buf.len();
buf.extend_from_slice(&[self.kind(), 0, 0]);
match self {
Tlv::Alpn(v) => buf.extend_from_slice(v),
Tlv::Authority(v) => buf.extend_from_slice(v.as_bytes()),
Tlv::Crc32c(v) => buf.extend_from_slice(&v.to_be_bytes()),
Tlv::Noop(len) => {
buf.resize(buf.len() + len, 0);
}
Tlv::UniqueId(v) => buf.extend_from_slice(v),
Tlv::Ssl(v) => {
buf.push(v.0);
buf.extend_from_slice(&v.1.to_be_bytes());
buf.extend_from_slice(&v.2);
}
Tlv::Netns(v) => buf.extend_from_slice(v.as_bytes()),
Tlv::SslVersion(v) => buf.extend_from_slice(v.as_bytes()),
Tlv::SslCn(v) => buf.extend_from_slice(v.as_bytes()),
Tlv::SslCipher(v) => buf.extend_from_slice(v.as_bytes()),
Tlv::SslSigAlg(v) => buf.extend_from_slice(v.as_bytes()),
Tlv::SslKeyAlg(v) => buf.extend_from_slice(v.as_bytes()),
Tlv::Custom(_, v) => buf.extend_from_slice(v),
}
let len = buf.len() - initial - 3;
if len > u16::MAX as usize {
panic!("TLV field too long");
}
buf[initial + 1] = ((len >> 8) & 0xff) as u8;
buf[initial + 2] = (len & 0xff) as u8;
}
pub fn into_owned(self) -> Tlv<'static> {
match self {
Tlv::Alpn(v) => Tlv::Alpn(Cow::Owned(v.into_owned())),
Tlv::Authority(v) => Tlv::Authority(Cow::Owned(v.into_owned())),
Tlv::Crc32c(v) => Tlv::Crc32c(v),
Tlv::Noop(v) => Tlv::Noop(v),
Tlv::UniqueId(v) => Tlv::UniqueId(Cow::Owned(v.into_owned())),
Tlv::Ssl(v) => Tlv::Ssl(v.into_owned()),
Tlv::Netns(v) => Tlv::Netns(Cow::Owned(v.into_owned())),
Tlv::SslVersion(v) => Tlv::SslVersion(Cow::Owned(v.into_owned())),
Tlv::SslCn(v) => Tlv::SslCn(Cow::Owned(v.into_owned())),
Tlv::SslCipher(v) => Tlv::SslCipher(Cow::Owned(v.into_owned())),
Tlv::SslSigAlg(v) => Tlv::SslSigAlg(Cow::Owned(v.into_owned())),
Tlv::SslKeyAlg(v) => Tlv::SslKeyAlg(Cow::Owned(v.into_owned())),
Tlv::Custom(a, v) => Tlv::Custom(a, Cow::Owned(v.into_owned())),
}
}
}
/// Options accepted by `ProxyHeader::parse`.
#[derive(Debug, Copy, Clone)]
pub struct ParseConfig {
    /// Whether TLV fields should be included in the parsed header.
    ///
    /// NOTE(review): this flag is consumed by the `v2` decoder, which is not
    /// part of this file; confirm the exact effect there.
    pub include_tlvs: bool,
    /// Accept version 1 headers; when `false` they are rejected as invalid.
    pub allow_v1: bool,
    /// Accept version 2 headers; when `false` they are rejected as invalid.
    pub allow_v2: bool,
}

impl Default for ParseConfig {
    /// The default configuration enables every option.
    fn default() -> Self {
        ParseConfig {
            allow_v1: true,
            allow_v2: true,
            include_tlvs: true,
        }
    }
}
/// A PROXY protocol header.
///
/// The first tuple field is the proxied address, or `None` when the header
/// carries no address (see `with_local`); the second holds the encoded TLV
/// fields as raw bytes.
#[derive(Default, PartialEq, Eq, Clone)]
pub struct ProxyHeader<'a>(Option<ProxiedAddress>, Cow<'a, [u8]>);
impl<'a> ProxyHeader<'a> {
pub fn with_local() -> Self {
Default::default()
}
pub fn with_address(addr: ProxiedAddress) -> Self {
Self(Some(addr), Cow::Owned(Vec::new()))
}
pub fn with_tlvs<'b>(
addr: Option<ProxiedAddress>,
tlvs: impl IntoIterator<Item = Tlv<'b>>,
) -> Self {
let mut buf = Vec::with_capacity(64);
for tlv in tlvs {
tlv.encode(&mut buf);
}
Self(addr, Cow::Owned(buf))
}
pub fn parse(buf: &'a [u8], config: ParseConfig) -> Result<(Self, usize), Error> {
match buf.first() {
Some(b'P') if config.allow_v1 => v1::decode(buf),
Some(b'\r') if config.allow_v2 => v2::decode(buf, config),
None => Err(Error::BufferTooShort),
_ => Err(Error::Invalid),
}
}
pub fn proxied_address(&self) -> Option<&ProxiedAddress> {
self.0.as_ref()
}
pub fn tlvs(&self) -> Tlvs<'_> {
Tlvs { buf: &self.1 }
}
pub fn alpn(&self) -> Option<&[u8]> {
tlv_borrowed!(self, Alpn)
}
pub fn authority(&self) -> Option<&str> {
tlv_borrowed!(self, Authority)
}
pub fn crc32c(&self) -> Option<u32> {
tlv!(self, Crc32c)
}
pub fn unique_id(&self) -> Option<&[u8]> {
tlv_borrowed!(self, UniqueId)
}
pub fn ssl(&self) -> Option<SslInfo<'_>> {
tlv!(self, Ssl)
}
pub fn netns(&self) -> Option<&str> {
tlv_borrowed!(self, Netns)
}
pub fn into_owned(self) -> ProxyHeader<'static> {
ProxyHeader(self.0, Cow::Owned(self.1.into_owned()))
}
pub fn append_tlv(&mut self, tlv: Tlv<'_>) {
tlv.encode(self.1.to_mut());
}
pub fn encode_v1(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
v1::encode(self, buf)
}
pub fn encode_v2(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
v2::encode(self, buf)
}
pub fn encode_to_slice_v1(&self, buf: &mut [u8]) -> Result<usize, Error> {
let mut cursor = std::io::Cursor::new(buf);
v1::encode(self, &mut cursor)?;
Ok(cursor.position() as usize)
}
pub fn encode_to_slice_v2(&self, buf: &mut [u8]) -> Result<usize, Error> {
let mut cursor = std::io::Cursor::new(buf);
v2::encode(self, &mut cursor)?;
Ok(cursor.position() as usize)
}
}
impl fmt::Debug for ProxyHeader<'_> {
    /// Formats the header with its address and decoded TLV fields rather
    /// than the raw TLV bytes.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let decoded: Vec<_> = self.tlvs().collect();

        let mut out = f.debug_struct("ProxyHeader");
        out.field("address_info", &self.proxied_address());
        out.field("fields", &decoded);
        out.finish()
    }
}
/// Errors produced while parsing or encoding a PROXY header.
#[derive(Debug, PartialEq, Eq)]
pub enum Error {
    /// The input ends before a complete header could be read.
    BufferTooShort,
    /// The input is not a valid PROXY header.
    Invalid,
    /// Source and destination addresses belong to different address families.
    AddressFamilyMismatch,
    /// The PROXY header is too big.
    HeaderTooBig,
    /// TLV fields cannot be represented in a version 1 header.
    V1UnsupportedTlv,
    /// Version 1 headers only support TCP.
    V1UnsupportedProtocol,
}

impl fmt::Display for Error {
    /// Writes a short, fixed description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Error::BufferTooShort => "buffer too short",
            Error::Invalid => "invalid PROXY header",
            Error::AddressFamilyMismatch => {
                "source and destination address families do not match"
            }
            Error::HeaderTooBig => "PROXY header too big",
            Error::V1UnsupportedTlv => "TLV fields are not supported in v1 header",
            Error::V1UnsupportedProtocol => {
                "protocols other than TCP are not supported in v1 header"
            }
        };
        f.write_str(message)
    }
}

impl std::error::Error for Error {}
// Parser tests driven by fixed v1 (text) and v2 (binary) header fixtures.
#[cfg(test)]
mod tests {
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use super::*;
// v1 fixtures: one text line each, terminated by CRLF.
const V1_UNKNOWN: &[u8] = b"PROXY UNKNOWN\r\n";
const V1_TCPV4: &[u8] = b"PROXY TCP4 127.0.0.1 192.168.0.1 12345 443\r\n";
const V1_TCPV6: &[u8] = b"PROXY TCP6 2001:db8::1 ::1 12345 443\r\n";
// v2 header without a proxied address (parses to `None`), followed by
// 15 (0x000f) bytes of TLV data.
const V2_LOCAL: &[u8] =
b"\r\n\r\n\0\r\nQUIT\n \0\0\x0f\x03\0\x04\x88\x9d\xa1\xdf \0\x05\0\0\0\0\0";
// v2 header for a stream connection 127.0.0.1:12345 -> 192.168.0.1:443.
const V2_TCPV4: &[u8] = &[
13, 10, 13, 10, 0, 13, 10, 81, 85, 73, 84, 10, 33, 17, 0, 12, 127, 0, 0, 1, 192, 168, 0, 1,
48, 57, 1, 187,
];
// v2 header for a stream connection [2001:db8::1]:12345 -> [::1]:443.
const V2_TCPV6: &[u8] = &[
13, 10, 13, 10, 0, 13, 10, 81, 85, 73, 84, 10, 33, 33, 0, 36, 32, 1, 13, 184, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 48, 57, 1, 187,
];
// Same addresses as V2_TCPV4, followed by CRC32c, unique-id and SSL TLVs;
// the expected decoded values are asserted in
// `test_parse_proxy_header_with_tlvs`.
const V2_TCPV4_TLV: &[u8] = &[
13, 10, 13, 10, 0, 13, 10, 81, 85, 73, 84, 10, 33, 17, 0, 104, 127, 0, 0, 1, 192, 168, 0,
1, 48, 57, 1, 187, 3, 0, 4, 211, 153, 216, 216, 5, 0, 4, 49, 50, 51, 52, 32, 0, 75, 7, 0,
0, 0, 0, 33, 0, 7, 84, 76, 83, 118, 49, 46, 51, 34, 0, 9, 108, 111, 99, 97, 108, 104, 111,
115, 116, 37, 0, 7, 82, 83, 65, 52, 48, 57, 54, 36, 0, 10, 82, 83, 65, 45, 83, 72, 65, 50,
53, 54, 35, 0, 22, 84, 76, 83, 95, 65, 69, 83, 95, 50, 53, 54, 95, 71, 67, 77, 95, 83, 72,
65, 51, 56, 52,
];
// Every strict prefix of a valid header must report `BufferTooShort`
// (never `Invalid`), and the complete header must parse.
#[test]
fn test_parse_proxy_header_too_short() {
for case in [
V1_TCPV4,
V1_TCPV6,
V1_UNKNOWN,
V2_TCPV4,
V2_TCPV6,
V2_TCPV4_TLV,
V2_LOCAL,
]
.iter()
{
for i in 0..case.len() {
assert!(matches!(
ProxyHeader::parse(&case[..i], Default::default()),
Err(Error::BufferTooShort)
));
}
assert!(matches!(
ProxyHeader::parse(case, Default::default()),
Ok(_)
));
}
}
// A v1 line with an absurdly long address field must be rejected as
// `Invalid`.
#[test]
fn test_parse_proxy_header_v1_unterminated() {
let line = b"PROXY TCP4 THISISSTORYALLABOUTHOWMYLIFEGOTFLIPPEDTURNEDUPSIDEDOWNANDIDLIKETOTAKEAMINUTEJUSTSITRIGHTTHEREANDILLTELLYOUHOWIGOTTHEPRINCEOFAIR\r\n";
assert!(matches!(
ProxyHeader::parse(line, Default::default()),
Err(Error::Invalid)
));
}
// v1 TCP4/TCP6 lines parse to stream addresses, consume the whole line and
// carry no TLV bytes.
#[test]
fn test_parse_proxy_header_v1() {
let (res, consumed) = ProxyHeader::parse(V1_TCPV4, Default::default()).unwrap();
assert_eq!(consumed, V1_TCPV4.len());
assert_eq!(
res.0,
Some(ProxiedAddress {
protocol: Protocol::Stream,
source: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 12345),
destination: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1)), 443),
})
);
assert_eq!(res.1, vec![0; 0]);
let (res, consumed) = ProxyHeader::parse(V1_TCPV6, Default::default()).unwrap();
assert_eq!(consumed, V1_TCPV6.len());
assert_eq!(
res.0,
Some(ProxiedAddress {
protocol: Protocol::Stream,
source: SocketAddr::new(
IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)),
12345
),
destination: SocketAddr::new(
IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
443
),
})
);
assert_eq!(res.1, vec![0; 0]);
}
// v2 headers: the local fixture yields no address; the TCP fixtures yield
// the expected stream addresses and consume the whole buffer.
#[test]
fn test_parse_proxy_header_v2() {
let (res, consumed) = ProxyHeader::parse(V2_LOCAL, Default::default()).unwrap();
assert_eq!(consumed, V2_LOCAL.len());
assert_eq!(res.0, None);
let (res, consumed) = ProxyHeader::parse(V2_TCPV4, Default::default()).unwrap();
assert_eq!(consumed, V2_TCPV4.len());
assert_eq!(
res.0,
Some(ProxiedAddress {
protocol: Protocol::Stream,
source: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 12345),
destination: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1)), 443),
})
);
let (res, consumed) = ProxyHeader::parse(V2_TCPV6, Default::default()).unwrap();
assert_eq!(consumed, V2_TCPV6.len());
assert_eq!(
res.0,
Some(ProxiedAddress {
protocol: Protocol::Stream,
source: SocketAddr::new(
IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)),
12345
),
destination: SocketAddr::new(
IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
443
),
})
);
}
// TLVs are yielded in wire order, and the SSL TLV exposes its own nested
// TLVs through `SslInfo::tlvs`.
#[test]
fn test_parse_proxy_header_with_tlvs() {
let (res, _) = ProxyHeader::parse(
V2_TCPV4_TLV,
ParseConfig {
include_tlvs: true,
..Default::default()
},
)
.unwrap();
use Tlv::*;
let mut fields = res.tlvs();
assert_eq!(fields.next(), Some(Ok(Crc32c(0xd399d8d8))));
assert_eq!(fields.next(), Some(Ok(UniqueId(b"1234"[..].into()))));
let ssl = fields.next().unwrap().unwrap();
let ssl = match ssl {
Tlv::Ssl(ssl) => ssl,
_ => panic!("expected SSL TLV"),
};
// Flag byte is 7 (all three client bits set) and verify is 0.
assert!(ssl.verify() == 0);
assert!(ssl.client_ssl());
assert!(ssl.client_cert_conn());
assert!(ssl.client_cert_sess());
let mut f = ssl.tlvs();
assert_eq!(f.next(), Some(Ok(SslVersion("TLSv1.3".into()))));
assert_eq!(f.next(), Some(Ok(SslCn("localhost".into()))));
assert_eq!(f.next(), Some(Ok(SslKeyAlg("RSA4096".into()))));
assert_eq!(f.next(), Some(Ok(SslSigAlg("RSA-SHA256".into()))));
assert_eq!(
f.next(),
Some(Ok(SslCipher("TLS_AES_256_GCM_SHA384".into())))
);
assert!(f.next().is_none());
assert!(fields.next().is_none());
}
}