1use std::{
2 io,
3 net::{IpAddr, SocketAddr},
4 pin::{Pin, pin},
5 task::{Context, Poll},
6};
7
8use anyhow::{Context as _, anyhow};
9use ppp::v2;
10use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
11
12use super::Error;
13use crate::http::AsyncReadWrite;
14
/// Length of the PROXY protocol v2 signature (`v2::PROTOCOL_PREFIX`).
const PREFIX_LEN: usize = 12;
/// Offset of the big-endian `u16` header-length field, which follows the signature and two
/// single-byte fields (version/command and family/protocol in the PROXY v2 layout).
const LENGTH_INDEX: usize = PREFIX_LEN + 2;
/// Size of the fixed part of a v2 header: everything up to and including the length field.
const MINIMUM_LEN: usize = LENGTH_INDEX + 2;
/// Size of the stack buffer used while reading a header; longer headers go to the heap.
const BUFFER_LEN: usize = 512;
23
/// Endpoints advertised by a PROXY protocol v2 header.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProxyHeader {
    /// Source address and port reported by the proxy (the original client).
    pub src: SocketAddr,
    /// Destination address and port reported by the proxy.
    pub dst: SocketAddr,
}
30
31impl TryFrom<v2::Addresses> for ProxyHeader {
32 type Error = Error;
33
34 fn try_from(value: v2::Addresses) -> Result<Self, Self::Error> {
35 let (src, dst) = match value {
36 v2::Addresses::IPv4(v) => (
37 SocketAddr::new(IpAddr::V4(v.source_address), v.source_port),
38 SocketAddr::new(IpAddr::V4(v.destination_address), v.destination_port),
39 ),
40 v2::Addresses::IPv6(v) => (
41 SocketAddr::new(IpAddr::V6(v.source_address), v.source_port),
42 SocketAddr::new(IpAddr::V6(v.destination_address), v.destination_port),
43 ),
44 _ => return Err(Error::Generic(anyhow!("unsupported address type"))),
45 };
46
47 Ok(Self { src, dst })
48 }
49}
50
51#[derive(Debug)]
53pub(super) struct ProxyProtocolStream<T: AsyncReadWrite> {
54 inner: T,
55 data: Option<Vec<u8>>,
56}
57
58impl<T: AsyncReadWrite> ProxyProtocolStream<T> {
59 pub const fn new(inner: T, data: Option<Vec<u8>>) -> Self {
60 Self { inner, data }
61 }
62
63 pub async fn accept(mut stream: T) -> Result<(Self, Option<ProxyHeader>), Error> {
64 let mut buf = [0; BUFFER_LEN];
65
66 stream
71 .read_exact(&mut buf[..MINIMUM_LEN])
72 .await
73 .context("unable to read prefix")?;
74
75 if &buf[..PREFIX_LEN] != v2::PROTOCOL_PREFIX {
78 return Ok((Self::new(stream, Some(buf[..MINIMUM_LEN].to_vec())), None));
79 }
80
81 let len = u16::from_be_bytes([buf[LENGTH_INDEX], buf[LENGTH_INDEX + 1]]) as usize;
83 let full_len = MINIMUM_LEN + len;
84
85 #[allow(unused_assignments)]
89 let mut dyn_buf = Vec::new();
90 let hdr = if full_len > BUFFER_LEN {
91 dyn_buf = vec![0; full_len];
92 dyn_buf[..MINIMUM_LEN].copy_from_slice(&buf[..MINIMUM_LEN]);
93 stream
94 .read_exact(&mut dyn_buf[MINIMUM_LEN..full_len])
95 .await
96 .context("unable to read proxy header")?;
97
98 dyn_buf.as_slice()
99 } else {
100 stream
102 .read_exact(&mut buf[MINIMUM_LEN..full_len])
103 .await
104 .context("unable to read proxy header")?;
105
106 &buf
107 };
108
109 let hdr = v2::Header::try_from(hdr).context("unable to parse header")?;
111 let hdr = ProxyHeader::try_from(hdr.addresses)?;
112
113 Ok((Self::new(stream, None), Some(hdr)))
114 }
115}
116
117impl<T: AsyncReadWrite> AsyncRead for ProxyProtocolStream<T> {
118 fn poll_read(
119 mut self: Pin<&mut Self>,
120 cx: &mut Context<'_>,
121 buf: &mut ReadBuf<'_>,
122 ) -> Poll<io::Result<()>> {
123 if let Some(mut v) = self.data.take() {
124 let buf_avail = buf.remaining();
125
126 if v.len() <= buf_avail {
128 buf.put_slice(&v);
129 return Poll::Ready(Ok(()));
130 }
131
132 buf.put_slice(&v[..buf_avail]);
134 v.rotate_left(buf_avail);
136 v.truncate(v.len() - buf_avail);
138 self.data.replace(v);
141
142 return Poll::Ready(Ok(()));
143 }
144
145 pin!(&mut self.inner).poll_read(cx, buf)
146 }
147}
148
149impl<T: AsyncReadWrite> AsyncWrite for ProxyProtocolStream<T> {
150 fn poll_write(
151 mut self: Pin<&mut Self>,
152 cx: &mut Context<'_>,
153 buf: &[u8],
154 ) -> Poll<io::Result<usize>> {
155 pin!(&mut self.inner).poll_write(cx, buf)
156 }
157
158 fn poll_shutdown(
159 mut self: Pin<&mut Self>,
160 cx: &mut Context<'_>,
161 ) -> Poll<Result<(), io::Error>> {
162 pin!(&mut self.inner).poll_shutdown(cx)
163 }
164
165 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
166 pin!(&mut self.inner).poll_flush(cx)
167 }
168}
169
#[cfg(test)]
mod test {
    use std::net::{Ipv4Addr, SocketAddrV4};

    use super::*;
    use anyhow::Error;
    use mock_io::tokio::MockStream;
    use tokio::io::AsyncWriteExt;

    /// Application bytes sent after the PROXY header in the `accept` tests.
    const PAYLOAD: &[u8] = b"foobar foobaz foobar";

    /// Builder for a v2 `PROXY` header advertising 1.1.1.1:31337 -> 2.2.2.2:443 over a stream.
    fn ipv4_proxy_builder() -> v2::Builder {
        v2::Builder::with_addresses(
            v2::Version::Two | v2::Command::Proxy,
            v2::Protocol::Stream,
            v2::IPv4::new([1, 1, 1, 1], [2, 2, 2, 2], 31337, 443),
        )
    }

    /// The `ProxyHeader` that `accept` should report for headers from `ipv4_proxy_builder`.
    fn expected_ipv4_header() -> ProxyHeader {
        ProxyHeader {
            src: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 1, 1, 1), 31337)),
            dst: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(2, 2, 2, 2), 443)),
        }
    }

    #[tokio::test]
    async fn test_proxy_protocol_stream() -> Result<(), Error> {
        // No replay data: bytes come straight from the inner stream.
        let (recv, mut send) = MockStream::pair();
        tokio::task::spawn(async move {
            let _ = send.write(b"foobar").await.unwrap();
        });
        let mut stream = ProxyProtocolStream::new(recv, None);
        let mut whole = vec![0; 6];
        stream.read_exact(&mut whole).await.unwrap();
        assert_eq!(whole, b"foobar");

        // Replay data precedes the inner stream's bytes in one large read.
        let (recv, mut send) = MockStream::pair();
        tokio::task::spawn(async move {
            let _ = send.write(b"foobar").await.unwrap();
        });
        let mut stream = ProxyProtocolStream::new(recv, Some(b"deadbeef".to_vec()));
        let mut whole = vec![0; 14];
        stream.read_exact(&mut whole).await.unwrap();
        assert_eq!(whole, b"deadbeeffoobar");

        // Small reads split the replay data, straddle the boundary, then reach the inner stream.
        let (recv, mut send) = MockStream::pair();
        tokio::task::spawn(async move {
            let _ = send.write(b"foobar").await.unwrap();
        });
        let mut stream = ProxyProtocolStream::new(recv, Some(b"deadbeef".to_vec()));
        let chunks = [
            (6, &b"deadbe"[..]),
            (3, &b"eff"[..]),
            (3, &b"oob"[..]),
            (2, &b"ar"[..]),
        ];
        for (len, want) in chunks {
            let mut chunk = vec![0; len];
            stream.read_exact(&mut chunk).await.unwrap();
            assert_eq!(chunk, want);
        }
        // Once everything has been consumed, the mock reports an error on further reads.
        let mut tail = vec![0; 2];
        assert!(stream.read(&mut tail).await.is_err());

        Ok(())
    }

    #[tokio::test]
    async fn test_proxy_protocol_accept_with_proxy_header() -> Result<(), Error> {
        let mut wire = ipv4_proxy_builder().build()?;
        wire.extend_from_slice(PAYLOAD);

        let (recv, mut send) = MockStream::pair();
        tokio::task::spawn(async move {
            let written = send.write(&wire).await.unwrap();
            assert_eq!(written, wire.len());
        });

        let (mut stream, header) = ProxyProtocolStream::accept(recv).await?;
        assert_eq!(header.unwrap(), expected_ipv4_header());

        let mut body = vec![0; PAYLOAD.len()];
        stream.read_exact(&mut body).await?;
        assert_eq!(body, PAYLOAD);

        Ok(())
    }

    #[tokio::test]
    async fn test_proxy_protocol_accept_with_long_proxy_header() -> Result<(), Error> {
        // Enough TLVs to push the header well past the stack buffer.
        let builder = (0..7000).fold(ipv4_proxy_builder(), |builder, _| {
            builder.write_tlv(v2::Type::NoOp, &b"foobar"[..]).unwrap()
        });
        let mut wire = builder.build()?;
        wire.extend_from_slice(PAYLOAD);

        let (recv, mut send) = MockStream::pair();
        tokio::task::spawn(async move {
            let written = send.write(&wire).await.unwrap();
            assert_eq!(written, wire.len());
        });

        let (mut stream, header) = ProxyProtocolStream::accept(recv).await?;
        assert_eq!(header.unwrap(), expected_ipv4_header());

        let mut body = vec![0; PAYLOAD.len()];
        stream.read_exact(&mut body).await?;
        assert_eq!(body, PAYLOAD);

        Ok(())
    }

    #[tokio::test]
    async fn test_proxy_protocol_accept_without_proxy_header() -> Result<(), Error> {
        let (recv, mut send) = MockStream::pair();
        tokio::task::spawn(async move {
            let _ = send.write(PAYLOAD).await.unwrap();
        });

        let (mut stream, header) = ProxyProtocolStream::accept(recv).await?;
        assert!(header.is_none());

        // The sniffed bytes are replayed ahead of the rest of the stream.
        for want in [&b"foobar foo"[..], &b"baz foobar"[..]] {
            let mut chunk = vec![0; 10];
            stream.read_exact(&mut chunk).await?;
            assert_eq!(chunk, want);
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_proxy_protocol_accept_with_invalid_header() {
        // A valid signature followed by bytes that do not form a well-formed header.
        let mut wire = v2::PROTOCOL_PREFIX.to_vec();
        wire.extend_from_slice(PAYLOAD);

        let (recv, mut send) = MockStream::pair();
        tokio::task::spawn(async move {
            let written = send.write(&wire).await.unwrap();
            assert_eq!(written, wire.len());
        });

        assert!(ProxyProtocolStream::accept(recv).await.is_err());
    }
}