use std::{
io,
net::{IpAddr, SocketAddr},
pin::{Pin, pin},
task::{Context, Poll},
};
use anyhow::{Context as _, anyhow};
use ppp::v2;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
use super::Error;
use crate::http::AsyncReadWrite;
/// Number of signature bytes compared against `v2::PROTOCOL_PREFIX`.
const PREFIX_LEN: usize = 12;
/// Offset of the big-endian `u16` length field. Per the PROXY v2 specification the
/// signature is followed by one version/command byte and one family/protocol byte.
const LENGTH_INDEX: usize = PREFIX_LEN + 2;
/// Size of the fixed part of a v2 header (signature, two single-byte fields, length field).
/// This many bytes are always read before deciding whether a header is present.
const MINIMUM_LEN: usize = LENGTH_INDEX + 2;
/// Capacity of the stack buffer used for headers; longer headers are read into a heap buffer.
const BUFFER_LEN: usize = 512;
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ProxyHeader {
pub src: SocketAddr,
pub dst: SocketAddr,
}
impl TryFrom<v2::Addresses> for ProxyHeader {
type Error = Error;
fn try_from(value: v2::Addresses) -> Result<Self, Self::Error> {
let (src, dst) = match value {
v2::Addresses::IPv4(v) => (
SocketAddr::new(IpAddr::V4(v.source_address), v.source_port),
SocketAddr::new(IpAddr::V4(v.destination_address), v.destination_port),
),
v2::Addresses::IPv6(v) => (
SocketAddr::new(IpAddr::V6(v.source_address), v.source_port),
SocketAddr::new(IpAddr::V6(v.destination_address), v.destination_port),
),
_ => return Err(Error::Generic(anyhow!("unsupported address type"))),
};
Ok(Self { src, dst })
}
}
/// Stream wrapper that can replay bytes already consumed from the wrapped connection.
///
/// Bytes held in `data` are returned by reads before `inner` is read again.
#[derive(Debug)]
pub(super) struct ProxyProtocolStream<T: AsyncReadWrite> {
    /// The wrapped connection.
    inner: T,
    /// Bytes to hand out before reading from `inner`, if any.
    data: Option<Vec<u8>>,
}

impl<T: AsyncReadWrite> ProxyProtocolStream<T> {
    /// Wraps `inner`, replaying `data` (when present) ahead of the inner stream's bytes.
    pub const fn new(inner: T, data: Option<Vec<u8>>) -> Self {
        Self { inner, data }
    }

    /// Reads an optional PROXY protocol v2 header from the start of `stream`.
    ///
    /// The first [`MINIMUM_LEN`] bytes are always read. If they do not begin with the v2
    /// signature, they are replayed through the returned stream and no header is returned.
    /// Otherwise the rest of the header (as declared by its length field) is consumed and
    /// parsed, and the returned stream starts at the first byte after the header.
    ///
    /// A header whose command is `LOCAL` yields `None`. The PROXY protocol specification
    /// requires the receiver to accept such connections (e.g. proxy health checks) and to use
    /// the real connection endpoints, ignoring the address block.
    ///
    /// # Errors
    ///
    /// Fails if the stream errors or ends before the required bytes are read, if the header
    /// cannot be parsed, or if a `PROXY` header carries an address family that
    /// [`ProxyHeader`] cannot represent.
    pub async fn accept(mut stream: T) -> Result<(Self, Option<ProxyHeader>), Error> {
        // The byte after the signature holds the protocol version in its high nibble and the
        // command in its low nibble (0x0 = LOCAL, 0x1 = PROXY).
        const COMMAND_MASK: u8 = 0x0F;
        const COMMAND_LOCAL: u8 = 0x00;

        let mut buf = [0; BUFFER_LEN];
        stream
            .read_exact(&mut buf[..MINIMUM_LEN])
            .await
            .context("unable to read prefix")?;

        if &buf[..PREFIX_LEN] != v2::PROTOCOL_PREFIX {
            // Not a PROXY v2 connection: give the already-consumed bytes back to the reader.
            return Ok((Self::new(stream, Some(buf[..MINIMUM_LEN].to_vec())), None));
        }

        let is_local = (buf[PREFIX_LEN] & COMMAND_MASK) == COMMAND_LOCAL;
        let len = usize::from(u16::from_be_bytes([buf[LENGTH_INDEX], buf[LENGTH_INDEX + 1]]));
        let full_len = MINIMUM_LEN + len;

        // Headers that fit in the stack buffer are read in place. The 16-bit length field
        // allows larger ones, which spill into a heap buffer allocated only in that case.
        let mut heap_buf = Vec::new();
        let raw: &[u8] = if full_len > BUFFER_LEN {
            heap_buf.extend_from_slice(&buf[..MINIMUM_LEN]);
            heap_buf.resize(full_len, 0);
            stream
                .read_exact(&mut heap_buf[MINIMUM_LEN..])
                .await
                .context("unable to read proxy header")?;
            &heap_buf
        } else {
            stream
                .read_exact(&mut buf[MINIMUM_LEN..full_len])
                .await
                .context("unable to read proxy header")?;
            // Hand the parser only the header bytes, not the unused tail of the stack buffer.
            &buf[..full_len]
        };

        let hdr = v2::Header::try_from(raw).context("unable to parse header")?;
        if is_local {
            // LOCAL: the header is consumed, but its addresses must not be used.
            return Ok((Self::new(stream, None), None));
        }
        let hdr = ProxyHeader::try_from(hdr.addresses)?;
        Ok((Self::new(stream, None), Some(hdr)))
    }
}
impl<T: AsyncReadWrite> AsyncRead for ProxyProtocolStream<T> {
    /// Serves buffered replay bytes first, then reads from the inner stream.
    ///
    /// Replay bytes are copied into `buf` up to its remaining capacity, and any leftover is
    /// kept for the next call. An empty replay buffer is discarded and the read falls through
    /// to the inner stream. Returning `Ready(Ok(()))` without filling any bytes would be
    /// interpreted by callers as end-of-stream.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        if let Some(mut pending) = self.data.take() {
            if !pending.is_empty() {
                let n = pending.len().min(buf.remaining());
                buf.put_slice(&pending[..n]);
                pending.drain(..n);
                if !pending.is_empty() {
                    self.data = Some(pending);
                }
                return Poll::Ready(Ok(()));
            }
        }
        pin!(&mut self.inner).poll_read(cx, buf)
    }
}
/// Writes bypass the replay buffer and go straight to the inner stream.
impl<T: AsyncReadWrite> AsyncWrite for ProxyProtocolStream<T> {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        pin!(&mut self.inner).poll_write(cx, buf)
    }

    /// Forwarded so that vectored writes reach the inner stream intact. The trait's default
    /// writes only the first non-empty slice through `poll_write`.
    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        pin!(&mut self.inner).poll_write_vectored(cx, bufs)
    }

    /// Reports the inner stream's vectored-write support instead of the default `false`.
    fn is_write_vectored(&self) -> bool {
        self.inner.is_write_vectored()
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        pin!(&mut self.inner).poll_flush(cx)
    }

    fn poll_shutdown(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), io::Error>> {
        pin!(&mut self.inner).poll_shutdown(cx)
    }
}
#[cfg(test)]
mod test {
use std::net::{Ipv4Addr, SocketAddrV4};
use super::*;
// Shadows the parent module's `Error` so the tests can return `Result<(), anyhow::Error>`.
use anyhow::Error;
use mock_io::tokio::MockStream;
use tokio::io::AsyncWriteExt;
/// Exercises `ProxyProtocolStream::new` directly: no replay buffer, a replay buffer read in
/// one go, and a replay buffer drained across several small reads.
#[tokio::test]
async fn test_proxy_protocol_stream() -> Result<(), Error> {
// Case 1: no replay buffer, so every byte comes from the inner stream.
let (recv, mut send) = MockStream::pair();
tokio::task::spawn(async move {
let _ = send.write(b"foobar").await.unwrap();
});
let mut s = ProxyProtocolStream::new(recv, None);
let mut buf = vec![0; 6];
s.read_exact(&mut buf).await.unwrap();
assert_eq!(buf, b"foobar");
// Case 2: replay bytes ("deadbeef") must come before the inner stream's bytes.
let (recv, mut send) = MockStream::pair();
tokio::task::spawn(async move {
let _ = send.write(b"foobar").await.unwrap();
});
let mut s = ProxyProtocolStream::new(recv, Some(b"deadbeef".to_vec()));
let mut buf = vec![0; 14];
s.read_exact(&mut buf).await.unwrap();
assert_eq!(buf, b"deadbeeffoobar");
// Case 3: small reads drain the replay buffer piecewise; the "eff" read spans the
// boundary between the replay buffer and the inner stream.
let (recv, mut send) = MockStream::pair();
tokio::task::spawn(async move {
let _ = send.write(b"foobar").await.unwrap();
});
let mut s = ProxyProtocolStream::new(recv, Some(b"deadbeef".to_vec()));
let mut buf = vec![0; 6];
s.read_exact(&mut buf).await.unwrap();
assert_eq!(buf, b"deadbe");
let mut buf = vec![0; 3];
s.read_exact(&mut buf).await.unwrap();
assert_eq!(buf, b"eff");
let mut buf = vec![0; 3];
s.read_exact(&mut buf).await.unwrap();
assert_eq!(buf, b"oob");
let mut buf = vec![0; 2];
s.read_exact(&mut buf).await.unwrap();
assert_eq!(buf, b"ar");
// NOTE(review): relies on the mock reporting an error (not EOF) once all data is
// consumed and the writer half is gone — mock_io behaviour, confirm if this flakes.
assert!(s.read(&mut buf).await.is_err());
Ok(())
}
/// A valid v2 PROXY header is consumed, its IPv4 addresses are returned, and the payload
/// that follows is readable from the returned stream.
#[tokio::test]
async fn test_proxy_protocol_accept_with_proxy_header() -> Result<(), Error> {
let addrs = v2::IPv4::new([1, 1, 1, 1], [2, 2, 2, 2], 31337, 443);
let mut hdr = v2::Builder::with_addresses(
v2::Version::Two | v2::Command::Proxy,
v2::Protocol::Stream,
addrs,
)
.build()?;
// Application payload sent right after the header.
hdr.extend_from_slice(&b"foobar foobaz foobar"[..]);
let (recv, mut send) = MockStream::pair();
tokio::task::spawn(async move {
let n = send.write(&hdr).await.unwrap();
assert_eq!(n, hdr.len());
});
let (mut stream, addr) = ProxyProtocolStream::accept(recv).await?;
let addr = addr.unwrap();
assert_eq!(
addr,
ProxyHeader {
src: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 1, 1, 1), 31337)),
dst: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(2, 2, 2, 2), 443)),
}
);
let mut buf = vec![0; 20];
stream.read_exact(&mut buf).await?;
assert_eq!(buf, &b"foobar foobaz foobar"[..]);
Ok(())
}
/// Same as above, but 7000 NoOp TLVs push the header well past `BUFFER_LEN`, so
/// `accept` must take its heap-buffer path.
#[tokio::test]
async fn test_proxy_protocol_accept_with_long_proxy_header() -> Result<(), Error> {
let addrs = v2::IPv4::new([1, 1, 1, 1], [2, 2, 2, 2], 31337, 443);
let mut hdr = v2::Builder::with_addresses(
v2::Version::Two | v2::Command::Proxy,
v2::Protocol::Stream,
addrs,
);
for _ in 0..7000 {
hdr = hdr.write_tlv(v2::Type::NoOp, &b"foobar"[..]).unwrap();
}
let mut hdr = hdr.build()?;
hdr.extend_from_slice(&b"foobar foobaz foobar"[..]);
let (recv, mut send) = MockStream::pair();
tokio::task::spawn(async move {
let n = send.write(&hdr).await.unwrap();
assert_eq!(n, hdr.len());
});
let (mut stream, addr) = ProxyProtocolStream::accept(recv).await?;
let addr = addr.unwrap();
assert_eq!(
addr,
ProxyHeader {
src: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 1, 1, 1), 31337)),
dst: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(2, 2, 2, 2), 443)),
}
);
let mut buf = vec![0; 20];
stream.read_exact(&mut buf).await?;
assert_eq!(buf, &b"foobar foobaz foobar"[..]);
Ok(())
}
/// Input without the v2 signature yields no header, and the bytes `accept` consumed while
/// probing are replayed, so the caller sees the stream unchanged.
#[tokio::test]
async fn test_proxy_protocol_accept_without_proxy_header() -> Result<(), Error> {
let (recv, mut send) = MockStream::pair();
tokio::task::spawn(async move {
let _ = send.write(&b"foobar foobaz foobar"[..]).await.unwrap();
});
let (mut stream, addr) = ProxyProtocolStream::accept(recv).await?;
assert!(addr.is_none());
let mut buf = vec![0; 10];
stream.read_exact(&mut buf).await?;
assert_eq!(buf, &b"foobar foo"[..]);
let mut buf = vec![0; 10];
stream.read_exact(&mut buf).await?;
assert_eq!(buf, &b"baz foobar"[..]);
Ok(())
}
/// The v2 signature followed by arbitrary bytes (not a well-formed header) must make
/// `accept` return an error.
#[tokio::test]
async fn test_proxy_protocol_accept_with_invalid_header() {
let mut hdr = v2::PROTOCOL_PREFIX.to_vec();
hdr.extend_from_slice(&b"foobar foobaz foobar"[..]);
let (recv, mut send) = MockStream::pair();
tokio::task::spawn(async move {
let n = send.write(&hdr).await.unwrap();
assert_eq!(n, hdr.len());
});
let res = ProxyProtocolStream::accept(recv).await;
assert!(res.is_err());
}
}