ic_bn_lib/http/server/proxy_protocol.rs

1use std::{
2    io,
3    net::{IpAddr, SocketAddr},
4    pin::{Pin, pin},
5    task::{Context, Poll},
6};
7
8use anyhow::{Context as _, anyhow};
9use ppp::v2;
10use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
11
12use super::Error;
13use crate::http::AsyncReadWrite;
14
/// Length in bytes of the signature that opens every Proxy Protocol v2 header.
const PREFIX_LEN: usize = 12;
/// Offset of the big-endian `u16` length field: it follows the signature and the two
/// single-byte info fields (version/command and family/protocol).
const LENGTH_INDEX: usize = PREFIX_LEN + 2;
/// Size in bytes of the fixed part of a header: signature, info bytes and length field.
const MINIMUM_LEN: usize = LENGTH_INDEX + 2;
/// Size of the stack buffer used to read a header; longer headers are read into the heap.
const BUFFER_LEN: usize = 512;
23
/// Connection endpoints announced by the proxy in a Proxy Protocol header.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ProxyHeader {
    /// Address of the original client, as reported by the proxy.
    pub src: SocketAddr,
    /// Address the original client connected to, as reported by the proxy.
    pub dst: SocketAddr,
}
30
31impl TryFrom<v2::Addresses> for ProxyHeader {
32    type Error = Error;
33
34    fn try_from(value: v2::Addresses) -> Result<Self, Self::Error> {
35        let (src, dst) = match value {
36            v2::Addresses::IPv4(v) => (
37                SocketAddr::new(IpAddr::V4(v.source_address), v.source_port),
38                SocketAddr::new(IpAddr::V4(v.destination_address), v.destination_port),
39            ),
40            v2::Addresses::IPv6(v) => (
41                SocketAddr::new(IpAddr::V6(v.source_address), v.source_port),
42                SocketAddr::new(IpAddr::V6(v.destination_address), v.destination_port),
43            ),
44            _ => return Err(Error::Generic(anyhow!("unsupported address type"))),
45        };
46
47        Ok(Self { src, dst })
48    }
49}
50
51/// Async Read+Write wrapper that appends some data before the wrapped stream
52#[derive(Debug)]
53pub(super) struct ProxyProtocolStream<T: AsyncReadWrite> {
54    inner: T,
55    data: Option<Vec<u8>>,
56}
57
58impl<T: AsyncReadWrite> ProxyProtocolStream<T> {
59    pub const fn new(inner: T, data: Option<Vec<u8>>) -> Self {
60        Self { inner, data }
61    }
62
63    pub async fn accept(mut stream: T) -> Result<(Self, Option<ProxyHeader>), Error> {
64        let mut buf = [0; BUFFER_LEN];
65
66        // Try to read the first part of proxy protocol header into a buffer.
67        // We assume that incoming requests are at least MINIMUM_LEN long,
68        // which is Ok since even the smallest HTTP request should be longer.
69        // That's not counting TLS handshake if we're running in TLS mode.
70        stream
71            .read_exact(&mut buf[..MINIMUM_LEN])
72            .await
73            .context("unable to read prefix")?;
74
75        // If the prefix doesn't match the proxy protocol signature - then we
76        // assume that we have no proxy protocol and just bypass the traffic.
77        if &buf[..PREFIX_LEN] != v2::PROTOCOL_PREFIX {
78            return Ok((Self::new(stream, Some(buf[..MINIMUM_LEN].to_vec())), None));
79        }
80
81        // Parse the header length
82        let len = u16::from_be_bytes([buf[LENGTH_INDEX], buf[LENGTH_INDEX + 1]]) as usize;
83        let full_len = MINIMUM_LEN + len;
84
85        // Switch to dynamic buffer if the header is too long.
86        // v2 has no maximum length (up to 2^16)
87        // TODO should we limit this even lower to avoid abuse?
88        #[allow(unused_assignments)]
89        let mut dyn_buf = Vec::new();
90        let hdr = if full_len > BUFFER_LEN {
91            dyn_buf = vec![0; full_len];
92            dyn_buf[..MINIMUM_LEN].copy_from_slice(&buf[..MINIMUM_LEN]);
93            stream
94                .read_exact(&mut dyn_buf[MINIMUM_LEN..full_len])
95                .await
96                .context("unable to read proxy header")?;
97
98            dyn_buf.as_slice()
99        } else {
100            // Otherwise just read into stack allocated buffer
101            stream
102                .read_exact(&mut buf[MINIMUM_LEN..full_len])
103                .await
104                .context("unable to read proxy header")?;
105
106            &buf
107        };
108
109        // Parse the header
110        let hdr = v2::Header::try_from(hdr).context("unable to parse header")?;
111        let hdr = ProxyHeader::try_from(hdr.addresses)?;
112
113        Ok((Self::new(stream, None), Some(hdr)))
114    }
115}
116
117impl<T: AsyncReadWrite> AsyncRead for ProxyProtocolStream<T> {
118    fn poll_read(
119        mut self: Pin<&mut Self>,
120        cx: &mut Context<'_>,
121        buf: &mut ReadBuf<'_>,
122    ) -> Poll<io::Result<()>> {
123        if let Some(mut v) = self.data.take() {
124            let buf_avail = buf.remaining();
125
126            // If there's enough space - just write there
127            if v.len() <= buf_avail {
128                buf.put_slice(&v);
129                return Poll::Ready(Ok(()));
130            }
131
132            // Otherwise write as much as we can
133            buf.put_slice(&v[..buf_avail]);
134            // Shift the buffer left
135            v.rotate_left(buf_avail);
136            // Truncate it
137            v.truncate(v.len() - buf_avail);
138            // Put it back.
139            // This helps avoid reallocating the Vec between read calls.
140            self.data.replace(v);
141
142            return Poll::Ready(Ok(()));
143        }
144
145        pin!(&mut self.inner).poll_read(cx, buf)
146    }
147}
148
149impl<T: AsyncReadWrite> AsyncWrite for ProxyProtocolStream<T> {
150    fn poll_write(
151        mut self: Pin<&mut Self>,
152        cx: &mut Context<'_>,
153        buf: &[u8],
154    ) -> Poll<io::Result<usize>> {
155        pin!(&mut self.inner).poll_write(cx, buf)
156    }
157
158    fn poll_shutdown(
159        mut self: Pin<&mut Self>,
160        cx: &mut Context<'_>,
161    ) -> Poll<Result<(), io::Error>> {
162        pin!(&mut self.inner).poll_shutdown(cx)
163    }
164
165    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
166        pin!(&mut self.inner).poll_flush(cx)
167    }
168}
169
#[cfg(test)]
mod test {
    use std::net::{Ipv4Addr, SocketAddrV4};

    use super::*;
    use anyhow::Error;
    use mock_io::tokio::MockStream;
    use tokio::io::AsyncWriteExt;

    // Writers use `write_all`: a single `write` may accept only part of the buffer,
    // and an assertion failing inside a detached task would never be observed.

    #[tokio::test]
    async fn test_proxy_protocol_stream() -> Result<(), Error> {
        // Try big enough buffer w/o data
        let (recv, mut send) = MockStream::pair();
        tokio::task::spawn(async move {
            send.write_all(b"foobar").await.unwrap();
        });
        let mut s = ProxyProtocolStream::new(recv, None);
        let mut buf = vec![0; 6];
        s.read_exact(&mut buf).await.unwrap();
        assert_eq!(buf, b"foobar");

        // Try big enough buffer with data
        let (recv, mut send) = MockStream::pair();
        tokio::task::spawn(async move {
            send.write_all(b"foobar").await.unwrap();
        });
        let mut s = ProxyProtocolStream::new(recv, Some(b"deadbeef".to_vec()));
        let mut buf = vec![0; 14];
        s.read_exact(&mut buf).await.unwrap();
        assert_eq!(buf, b"deadbeeffoobar");

        // Try smaller buffers
        let (recv, mut send) = MockStream::pair();
        tokio::task::spawn(async move {
            send.write_all(b"foobar").await.unwrap();
        });
        let mut s = ProxyProtocolStream::new(recv, Some(b"deadbeef".to_vec()));
        let mut buf = vec![0; 6];
        s.read_exact(&mut buf).await.unwrap();
        assert_eq!(buf, b"deadbe");
        let mut buf = vec![0; 3];
        s.read_exact(&mut buf).await.unwrap();
        assert_eq!(buf, b"eff");
        let mut buf = vec![0; 3];
        s.read_exact(&mut buf).await.unwrap();
        assert_eq!(buf, b"oob");
        let mut buf = vec![0; 2];
        s.read_exact(&mut buf).await.unwrap();
        assert_eq!(buf, b"ar");
        assert!(s.read(&mut buf).await.is_err());

        Ok(())
    }

    #[tokio::test]
    async fn test_proxy_protocol_accept_with_proxy_header() -> Result<(), Error> {
        let addrs = v2::IPv4::new([1, 1, 1, 1], [2, 2, 2, 2], 31337, 443);
        let mut hdr = v2::Builder::with_addresses(
            v2::Version::Two | v2::Command::Proxy,
            v2::Protocol::Stream,
            addrs,
        )
        .build()?;
        hdr.extend_from_slice(&b"foobar foobaz foobar"[..]);

        let (recv, mut send) = MockStream::pair();
        tokio::task::spawn(async move {
            send.write_all(&hdr).await.unwrap();
        });

        let (mut stream, addr) = ProxyProtocolStream::accept(recv).await?;
        let addr = addr.unwrap();
        assert_eq!(
            addr,
            ProxyHeader {
                src: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 1, 1, 1), 31337)),
                dst: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(2, 2, 2, 2), 443)),
            }
        );

        let mut buf = vec![0; 20];
        stream.read_exact(&mut buf).await?;
        assert_eq!(buf, &b"foobar foobaz foobar"[..]);

        Ok(())
    }

    #[tokio::test]
    async fn test_proxy_protocol_accept_with_long_proxy_header() -> Result<(), Error> {
        let addrs = v2::IPv4::new([1, 1, 1, 1], [2, 2, 2, 2], 31337, 443);
        let mut hdr = v2::Builder::with_addresses(
            v2::Version::Two | v2::Command::Proxy,
            v2::Protocol::Stream,
            addrs,
        );
        for _ in 0..7000 {
            hdr = hdr.write_tlv(v2::Type::NoOp, &b"foobar"[..]).unwrap();
        }
        let mut hdr = hdr.build()?;
        hdr.extend_from_slice(&b"foobar foobaz foobar"[..]);

        let (recv, mut send) = MockStream::pair();
        tokio::task::spawn(async move {
            send.write_all(&hdr).await.unwrap();
        });

        let (mut stream, addr) = ProxyProtocolStream::accept(recv).await?;
        let addr = addr.unwrap();
        assert_eq!(
            addr,
            ProxyHeader {
                src: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 1, 1, 1), 31337)),
                dst: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(2, 2, 2, 2), 443)),
            }
        );

        let mut buf = vec![0; 20];
        stream.read_exact(&mut buf).await?;
        assert_eq!(buf, &b"foobar foobaz foobar"[..]);

        Ok(())
    }

    #[tokio::test]
    async fn test_proxy_protocol_accept_without_proxy_header() -> Result<(), Error> {
        let (recv, mut send) = MockStream::pair();
        tokio::task::spawn(async move {
            send.write_all(&b"foobar foobaz foobar"[..]).await.unwrap();
        });

        let (mut stream, addr) = ProxyProtocolStream::accept(recv).await?;
        assert!(addr.is_none());

        let mut buf = vec![0; 10];
        stream.read_exact(&mut buf).await?;
        assert_eq!(buf, &b"foobar foo"[..]);

        let mut buf = vec![0; 10];
        stream.read_exact(&mut buf).await?;
        assert_eq!(buf, &b"baz foobar"[..]);

        Ok(())
    }

    #[tokio::test]
    async fn test_proxy_protocol_accept_with_invalid_header() {
        // Create a valid prefix, but invalid header data after it
        let mut hdr = v2::PROTOCOL_PREFIX.to_vec();
        hdr.extend_from_slice(&b"foobar foobaz foobar"[..]);

        let (recv, mut send) = MockStream::pair();
        tokio::task::spawn(async move {
            send.write_all(&hdr).await.unwrap();
        });

        let res = ProxyProtocolStream::accept(recv).await;
        assert!(res.is_err());
    }

    #[tokio::test]
    async fn test_proxy_protocol_accept_with_unsupported_version() {
        // Valid signature and a consistent length field (12 bytes of IPv4 addresses),
        // but version nibble 1 instead of 2, so the header must be rejected.
        let mut hdr = v2::PROTOCOL_PREFIX.to_vec();
        hdr.extend_from_slice(&[0x11, 0x11, 0x00, 0x0c]);
        hdr.extend_from_slice(&[0; 12]);
        hdr.extend_from_slice(&b"foobar foobaz foobar"[..]);

        let (recv, mut send) = MockStream::pair();
        tokio::task::spawn(async move {
            send.write_all(&hdr).await.unwrap();
        });

        let res = ProxyProtocolStream::accept(recv).await;
        assert!(res.is_err());
    }
}