hyperproxy/
lib.rs

1//! A PROXYv2 wrapper for hyper and tonic.
2
3#[macro_use]
4extern crate log;
5
6#[cfg(feature = "track_conn_count")]
7use std::sync::{
8    atomic::{AtomicU64, Ordering},
9    Arc,
10};
11use std::{
12    convert::TryInto,
13    io::{self, ErrorKind},
14    mem::MaybeUninit,
15    net::{IpAddr, Ipv4Addr, SocketAddr},
16    pin::Pin,
17    task::{Context, Poll},
18};
19
20use futures::Future;
21use tokio::{
22    io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf},
23    net::tcp::{OwnedReadHalf, OwnedWriteHalf},
24};
25
26mod wrapped_incoming;
27pub use wrapped_incoming::WrappedIncoming;
28
/// Accept/Reject mode for accepting connections.
///
/// Comparable with `==` and usable as a map/set key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ProxyMode {
    /// Disable PROXYv2 (if sent, PROXYv2 data will be passed through)
    None,
    /// PROXYv2 data is parsed if present, otherwise the original address is used
    Accept,
    /// PROXYv2 data is required or the connection will be rejected
    Require,
}
39
/// Length of the fixed PROXYv2 prefix: the signature, then one version/command byte,
/// one family/protocol byte and a two-byte address-block length.
const PROXY_PACKET_HEADER_LEN: usize = PROXY_SIGNATURE.len() + 4;
/// Largest address block decoded into a fixed buffer (two 108-byte UNIX socket paths).
const PROXY_PACKET_MAX_PROXY_ADDR_SIZE: usize = 2 * 108;
/// The 12-byte PROXYv2 signature: `\r\n\r\n\0\r\nQUIT\n`.
const PROXY_SIGNATURE: [u8; 12] = *b"\r\n\r\n\0\r\nQUIT\n";
/// Version expected in the high nibble of the first byte after the signature.
const PROXY_PROTOCOL_VERSION: u8 = 2;
46
/// A PROXYv2 command, taken from the low nibble of the first byte after the signature.
///
/// Discriminants equal the on-the-wire values.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Command {
    /// `LOCAL` (wire value 0): per the PROXYv2 specification the connection was opened
    /// by the proxy itself and the address block must be ignored.
    Local = 0,
    /// `PROXY` (wire value 1): the connection is relayed on behalf of the client whose
    /// addresses follow the fixed header.
    Proxy = 1,
}
54
55impl Command {
56    fn from_u8(from: u8) -> Option<Self> {
57        match from {
58            0 => Some(Command::Local),
59            1 => Some(Command::Proxy),
60            _ => None,
61        }
62    }
63}
64
/// A PROXYv2 address family, taken from the high nibble of the second byte after the
/// signature.
///
/// Discriminants equal the on-the-wire values.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Family {
    /// `AF_UNSPEC` (wire value 0): no address is decoded.
    Unspecified = 0,
    /// `AF_INET` (wire value 1): IPv4 source/destination addresses and ports.
    Ipv4 = 1,
    /// `AF_INET6` (wire value 2): IPv6 source/destination addresses and ports.
    Ipv6 = 2,
    /// `AF_UNIX` (wire value 3): UNIX socket paths; recognised but not decoded.
    Unix = 3,
}
74
75impl Family {
76    fn from_u8(from: u8) -> Option<Self> {
77        match from {
78            0 => Some(Family::Unspecified),
79            1 => Some(Family::Ipv4),
80            2 => Some(Family::Ipv6),
81            3 => Some(Family::Unix),
82            _ => None,
83        }
84    }
85
86    fn len(&self) -> Option<usize> {
87        match self {
88            Family::Unspecified => None,
89            Family::Ipv4 => Some(12),
90            Family::Ipv6 => Some(36),
91            Family::Unix => Some(216),
92        }
93    }
94}
95
/// A PROXYv2 transport protocol, taken from the low nibble of the second byte after the
/// signature.
///
/// Discriminants equal the on-the-wire values.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Protocol {
    /// `UNSPEC` (wire value 0): unknown or unspecified transport.
    Unspecified = 0,
    /// `STREAM` (wire value 1): a stream transport such as TCP.
    Stream = 1,
    /// `DGRAM` (wire value 2): a datagram transport such as UDP.
    Datagram = 2,
}
104
105impl Protocol {
106    fn from_u8(from: u8) -> Option<Self> {
107        match from {
108            0 => Some(Protocol::Unspecified),
109            1 => Some(Protocol::Stream),
110            2 => Some(Protocol::Datagram),
111            _ => None,
112        }
113    }
114}
115
116#[derive(PartialEq, Debug)]
117struct ProxyInfo {
118    command: Command,
119    family: Family,
120    protocol: Protocol,
121    discovered_dest: Option<SocketAddr>,
122    discovered_src: Option<SocketAddr>,
123}
124
125#[derive(PartialEq, Debug)]
126enum ProxyResult {
127    Proxy(ProxyInfo),
128    SignatureBytes([u8; PROXY_SIGNATURE.len()]),
129}
130
131/// A wrapper over [`hyper::server::conn::AddrStream`] that grabs PROXYv2 information
132pub struct WrappedStream {
133    remote_addr: SocketAddr,
134    inner_write: Pin<Box<OwnedWriteHalf>>,
135    inner_read: Option<Pin<Box<OwnedReadHalf>>>,
136    #[cfg(feature = "track_conn_count")]
137    conn_count: Arc<AtomicU64>,
138    pending_read_proxy: Option<
139        Pin<
140            Box<
141                dyn Future<Output = io::Result<(ProxyResult, Pin<Box<OwnedReadHalf>>)>>
142                    + Send
143                    + Sync
144                    + 'static,
145            >,
146        >,
147    >,
148    info: Option<ProxyInfo>,
149    #[cfg(feature = "tonic")]
150    connect_info: std::sync::Arc<std::sync::RwLock<Option<SocketAddr>>>,
151    fused_error: bool,
152    proxy_mode: ProxyMode,
153}
154
/// Connection metadata that tonic attaches to each request served over a
/// [`WrappedStream`].
///
/// It shares its address slot with the originating stream, so an address recorded by the
/// stream after this value was created is still visible through
/// [`TcpConnectInfo::remote_addr`].
#[cfg(feature = "tonic")]
#[derive(Clone)]
pub struct TcpConnectInfo {
    /// Slot shared with the stream's `connect_info` field.
    inner: std::sync::Arc<std::sync::RwLock<Option<SocketAddr>>>,
}
160
#[cfg(feature = "tonic")]
impl TcpConnectInfo {
    /// Returns the client address recorded for the connection, or `None` if none has
    /// been recorded yet.
    ///
    /// # Panics
    ///
    /// Panics if the shared lock has been poisoned.
    pub fn remote_addr(&self) -> Option<SocketAddr> {
        let slot = self.inner.read().unwrap();
        *slot
    }
}
167
/// Exposes a [`TcpConnectInfo`] for every [`WrappedStream`] served by tonic.
#[cfg(feature = "tonic")]
impl tonic::transport::server::Connected for WrappedStream {
    type ConnectInfo = TcpConnectInfo;

    /// Returns a handle that shares this stream's address slot.
    ///
    /// The `Arc` is cloned rather than the address, so an address the read path stores
    /// after the PROXYv2 header is resolved is still observable through the handle.
    fn connect_info(&self) -> Self::ConnectInfo {
        let inner = std::sync::Arc::clone(&self.connect_info);
        TcpConnectInfo { inner }
    }
}
177
/// Returns the client address recorded for the connection that carried `request`.
///
/// This is the PROXYv2-reported source when a header was received, otherwise the TCP
/// peer address. It is `None` if no address has been recorded yet.
///
/// # Panics
///
/// Panics if `request` carries no [`TcpConnectInfo`] extension, for example when tonic
/// was not serving connections wrapped by this crate.
#[cfg(feature = "tonic")]
pub fn tonic_remote_addr<T>(request: &tonic::Request<T>) -> Option<SocketAddr> {
    let connect_info = request
        .extensions()
        .get::<TcpConnectInfo>()
        .expect("missing TCP connect info (was hyperproxy inline with tonic?)");
    connect_info.remote_addr()
}
186
/// Lets axum's connect-info machinery obtain a client [`SocketAddr`] from a
/// [`WrappedStream`].
#[cfg(feature = "axum")]
impl<'a> axum::extract::connect_info::Connected<&'a WrappedStream> for SocketAddr {
    /// Returns [`WrappedStream::source`]: the PROXYv2-reported source if a header has
    /// already been parsed, otherwise the TCP peer address.
    // NOTE(review): the address is captured at the moment axum calls this. If that is
    // before the stream's first read, no PROXYv2 header has been parsed yet and the
    // proxy's own address is returned. Confirm when axum invokes this, and whether
    // `WrappedIncoming` resolves the header beforehand.
    fn connect_info(target: &'a WrappedStream) -> Self {
        WrappedStream::source(target)
    }
}
193
/// Copies `from` into a fixed-size array.
///
/// # Panics
///
/// Panics if `from.len() != SIZE`; callers slice exactly `SIZE` bytes.
fn to_array<const SIZE: usize>(from: &[u8]) -> [u8; SIZE] {
    let mut array = [0u8; SIZE];
    array.copy_from_slice(from);
    array
}
197
198async fn read_proxy<R: AsyncRead + Unpin>(mut read: R) -> io::Result<(ProxyResult, R)> {
199    let mut signature = [0u8; PROXY_SIGNATURE.len()];
200    read.read_exact(&mut signature[..]).await?;
201    if signature != PROXY_SIGNATURE {
202        return Ok((ProxyResult::SignatureBytes(signature), read));
203    }
204
205    // 4 bytes
206    let mut header = [0u8; PROXY_PACKET_HEADER_LEN - PROXY_SIGNATURE.len()];
207    read.read_exact(&mut header[..]).await?;
208
209    let version = (header[0] & 0xf0) >> 4;
210    if version != PROXY_PROTOCOL_VERSION {
211        debug!("invalid proxy protocol version: {}", version);
212        return Err(io::Error::new(
213            ErrorKind::InvalidData,
214            "invalid proxy protocol version",
215        ));
216    }
217    let command = header[0] & 0x0f;
218    let command = match Command::from_u8(command) {
219        Some(c) => c,
220        None => {
221            debug!("invalid proxy protocol command: {}", command);
222            return Err(io::Error::new(
223                ErrorKind::InvalidData,
224                "invalid proxy protocol command",
225            ));
226        }
227    };
228
229    let family = (header[1] & 0xf0) >> 4;
230    let family = match Family::from_u8(family) {
231        None => {
232            debug!("invalid proxy family: {}", family);
233            return Err(io::Error::new(
234                ErrorKind::InvalidData,
235                "invalid proxy family",
236            ));
237        }
238        Some(family) => {
239            trace!("PROXY family: {:?}", family);
240            family
241        }
242    };
243
244    let protocol = header[1] & 0x0f;
245    let protocol = match Protocol::from_u8(protocol) {
246        None => {
247            debug!("invalid proxy protocol: {}", protocol);
248            return Err(io::Error::new(
249                ErrorKind::InvalidData,
250                "invalid proxy protocol",
251            ));
252        }
253        Some(protocol) => {
254            trace!("PROXY protocol: {:?}", protocol);
255            protocol
256        }
257    };
258
259    let len = u16::from_be_bytes([header[2], header[3]]) as usize;
260    let target_len = if matches!(command, Command::Local) {
261        None
262    } else {
263        family.len()
264    };
265
266    if let Some(target_len) = target_len {
267        if len < target_len {
268            debug!("invalid proxy address length: {}", target_len);
269            return Err(io::Error::new(
270                ErrorKind::InvalidData,
271                "invalid proxy address length",
272            ));
273        }
274    }
275
276    let mut raw =
277        unsafe { MaybeUninit::<[u8; PROXY_PACKET_MAX_PROXY_ADDR_SIZE]>::uninit().assume_init() };
278    read.read_exact(&mut raw[..len]).await?;
279    let raw = &raw[..len];
280
281    let mut discovered_src = None;
282    let mut discovered_dest = None;
283
284    match family {
285        Family::Unspecified => {
286            debug!("unspecified PROXY family data: {:?}", raw);
287        }
288        Family::Ipv4 => {
289            let src_addr = IpAddr::V4(Ipv4Addr::from(to_array(&raw[..4])));
290            let dest_addr = IpAddr::V4(Ipv4Addr::from(to_array(&raw[4..8])));
291            let src_port = u16::from_be_bytes((&raw[8..10]).try_into().unwrap());
292            let dest_port = u16::from_be_bytes((&raw[10..12]).try_into().unwrap());
293            discovered_src = Some(SocketAddr::new(src_addr, src_port));
294            discovered_dest = Some(SocketAddr::new(dest_addr, dest_port));
295        }
296        Family::Ipv6 => {
297            let src_addr = IpAddr::V6(to_array(&raw[..16]).into());
298            let dest_addr = IpAddr::V6(to_array(&raw[16..32]).into());
299            let src_port = u16::from_be_bytes((&raw[32..34]).try_into().unwrap());
300            let dest_port = u16::from_be_bytes((&raw[34..36]).try_into().unwrap());
301            discovered_src = Some(SocketAddr::new(src_addr, src_port));
302            discovered_dest = Some(SocketAddr::new(dest_addr, dest_port));
303        }
304        Family::Unix => {
305            warn!("unsupported UNIX PROXY family, ignored.");
306        }
307    }
308
309    Ok((
310        ProxyResult::Proxy(ProxyInfo {
311            command,
312            family,
313            protocol,
314            discovered_dest,
315            discovered_src,
316        }),
317        read,
318    ))
319}
320
321impl AsyncRead for WrappedStream {
322    #[inline]
323    fn poll_read(
324        mut self: Pin<&mut Self>,
325        cx: &mut Context<'_>,
326        buf: &mut ReadBuf<'_>,
327    ) -> Poll<io::Result<()>> {
328        if self.fused_error {
329            return Poll::Ready(Err(io::Error::new(
330                ErrorKind::Unsupported,
331                "called read after error",
332            )));
333        }
334        if matches!(self.proxy_mode, ProxyMode::None) {
335            return self
336                .inner_read
337                .as_mut()
338                .unwrap()
339                .as_mut()
340                .poll_read(cx, buf);
341        }
342        assert!(buf.remaining() >= PROXY_SIGNATURE.len());
343
344        if self.pending_read_proxy.is_none() {
345            self.pending_read_proxy = Some(Box::pin(read_proxy(self.inner_read.take().unwrap())));
346        }
347        let output = self.pending_read_proxy.as_mut().unwrap().as_mut().poll(cx);
348        match output {
349            Poll::Ready(Err(e)) => {
350                self.fused_error = true;
351                self.pending_read_proxy = None;
352                Poll::Ready(Err(e))
353            }
354            Poll::Ready(Ok((ProxyResult::SignatureBytes(bytes), stream))) => {
355                if matches!(self.proxy_mode, ProxyMode::Require) {
356                    return Poll::Ready(Err(io::Error::new(
357                        ErrorKind::InvalidData,
358                        "required a PROXYv2 header, none found",
359                    )));
360                }
361                self.proxy_mode = ProxyMode::None;
362                buf.put_slice(&bytes[..]);
363                self.pending_read_proxy = None;
364                self.inner_read = Some(stream);
365                #[cfg(feature = "tonic")]
366                {
367                    *self.connect_info.write().unwrap() = Some(self.source());
368                }
369                self.inner_read
370                    .as_mut()
371                    .unwrap()
372                    .as_mut()
373                    .poll_read(cx, buf)
374            }
375            Poll::Ready(Ok((ProxyResult::Proxy(info), stream))) => {
376                self.proxy_mode = ProxyMode::None;
377                self.info = Some(info);
378                self.pending_read_proxy = None;
379                self.inner_read = Some(stream);
380                #[cfg(feature = "tonic")]
381                {
382                    *self.connect_info.write().unwrap() = Some(self.source());
383                }
384                self.inner_read
385                    .as_mut()
386                    .unwrap()
387                    .as_mut()
388                    .poll_read(cx, buf)
389            }
390            Poll::Pending => Poll::Pending,
391        }
392    }
393}
394
395impl WrappedStream {
396    /// Returns `true` if PROXYv2 information was sent
397    pub fn was_proxied(&self) -> bool {
398        self.info.is_some()
399    }
400
401    /// PROXYv2 reported command or None
402    pub fn command(&self) -> Option<Command> {
403        self.info.as_ref().map(|x| x.command)
404    }
405
406    /// PROXYv2 reported family or None
407    pub fn family(&self) -> Option<Family> {
408        self.info.as_ref().map(|x| x.family)
409    }
410
411    /// PROXYv2 reported protocol or None
412    pub fn protocol(&self) -> Option<Protocol> {
413        self.info.as_ref().map(|x| x.protocol)
414    }
415
416    /// PROXYv2 reported destination or None
417    pub fn destination(&self) -> Option<SocketAddr> {
418        self.info.as_ref().map(|x| x.discovered_dest).flatten()
419    }
420
421    /// PROXYv2 reported source or original address if none
422    pub fn source(&self) -> SocketAddr {
423        self.info
424            .as_ref()
425            .map(|x| x.discovered_src)
426            .flatten()
427            .unwrap_or_else(|| self.remote_addr)
428    }
429
430    /// The actual source that connected to us
431    pub fn original_source(&self) -> SocketAddr {
432        self.remote_addr
433    }
434}
435
436impl AsyncWrite for WrappedStream {
437    #[inline]
438    fn poll_write(
439        mut self: Pin<&mut Self>,
440        cx: &mut Context<'_>,
441        buf: &[u8],
442    ) -> Poll<io::Result<usize>> {
443        self.inner_write.as_mut().poll_write(cx, buf)
444    }
445
446    #[inline]
447    fn poll_write_vectored(
448        mut self: Pin<&mut Self>,
449        cx: &mut Context<'_>,
450        bufs: &[io::IoSlice<'_>],
451    ) -> Poll<io::Result<usize>> {
452        self.inner_write.as_mut().poll_write_vectored(cx, bufs)
453    }
454
455    #[inline]
456    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
457        self.inner_write.as_mut().poll_flush(cx)
458    }
459
460    #[inline]
461    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
462        self.inner_write.as_mut().poll_shutdown(cx)
463    }
464
465    #[inline]
466    fn is_write_vectored(&self) -> bool {
467        self.inner_write.is_write_vectored()
468    }
469}
470
/// Releases this connection's slot in the shared connection counter.
#[cfg(feature = "track_conn_count")]
impl Drop for WrappedStream {
    fn drop(&mut self) {
        // NOTE(review): the matching increment is not in this file; presumably it happens
        // where streams are created (e.g. `WrappedIncoming`). Confirm.
        AtomicU64::fetch_sub(&self.conn_count, 1, Ordering::SeqCst);
    }
}
477
#[cfg(test)]
mod tests {
    use super::*;

    /// An IPv4/TCP PROXYv2 header whose address block is followed by TLVs (types 0x03
    /// and 0x04) decodes to the expected source and destination.
    #[tokio::test]
    async fn test_parse() {
        let packet = hex::decode("0d0a0d0a000d0a515549540a21110054ffffffffac1f1cd1898801bb030004508978bb04003e0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000").unwrap();
        let (parsed, _rest) = read_proxy(&packet[..]).await.unwrap();
        let expected = ProxyResult::Proxy(ProxyInfo {
            command: Command::Proxy,
            family: Family::Ipv4,
            protocol: Protocol::Stream,
            discovered_dest: Some("172.31.28.209:443".parse().unwrap()),
            discovered_src: Some("255.255.255.255:35208".parse().unwrap()),
        });
        assert_eq!(parsed, expected);
    }
}