Skip to main content

gatel_core/server/
proxy_protocol.rs

1//! PROXY protocol v1/v2 parser for extracting the real client address from
2//! load balancers, CDNs, or other proxies that prepend the PROXY protocol
3//! header to TCP connections.
4//!
5//! # Protocol overview
6//!
7//! - **v1** (text): `PROXY TCP4 192.168.1.1 192.168.1.2 12345 80\r\n`
8//! - **v2** (binary): 12-byte signature + version/command + family/transport + length + addresses
9//!
10//! This module also provides [`PrefixedStream`], a wrapper around `TcpStream` that
11//! prepends buffered bytes (the leftover data after the PROXY header) so the
12//! rest of the connection can be read normally.
13
14use std::io;
15use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
16use std::pin::Pin;
17use std::task::{Context, Poll};
18
19use pin_project_lite::pin_project;
20use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
21use tokio::net::TcpStream;
22use tracing::{debug, warn};
23
24use crate::ProxyError;
25
/// The 12-byte signature that identifies a PROXY protocol v2 header.
const V2_SIGNATURE: &[u8; 12] = b"\r\n\r\n\0\r\nQUIT\n";

/// The maximum length of a PROXY protocol v1 line, including the trailing CRLF.
///
/// The spec limits a v1 line to 107 bytes. The worst case is a `PROXY UNKNOWN`
/// line carrying two full-length IPv6 addresses and two 5-digit ports. The
/// spec's 108-byte buffer recommendation only adds room for a C-style trailing
/// NUL, which this parser does not need.
const V1_MAX_LINE_LEN: usize = 107;

/// Maximum header size to read before giving up.
///
/// v2 headers may carry TLVs after the addresses, but we cap the total at 536
/// bytes (the default IPv4 TCP MSS) so one header can never make us buffer an
/// unbounded amount of data. This is also the size of the initial read buffer.
const MAX_HEADER_SIZE: usize = 536;
34
/// Parsed result of a PROXY protocol header.
///
/// Only built when the header actually carries IP endpoints. Address-less
/// cases, such as a v2 `LOCAL` command, are reported as `None` by the parsers
/// instead of producing a header.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProxyProtocolHeader {
    /// The real source (client) address, as reported by the proxy.
    pub src_addr: SocketAddr,
    /// The destination address the client connected to, as reported by the
    /// proxy. The v1 and v2 parsers in this module always set it; `None` is
    /// reserved for headers that carry no destination.
    pub dst_addr: Option<SocketAddr>,
}
43
44/// Parse the PROXY protocol header from the beginning of a TCP stream.
45///
46/// Returns the parsed header and a [`PrefixedStream`] that wraps the original
47/// stream with any leftover bytes prepended so subsequent reads see the actual
48/// application data.
49///
50/// If the stream does not start with a PROXY protocol header, returns `None`
51/// and the stream is returned as-is inside a `PrefixedStream` with the peeked
52/// bytes prepended.
53pub async fn parse_proxy_protocol(
54    stream: &mut TcpStream,
55) -> Result<(Option<ProxyProtocolHeader>, Vec<u8>), ProxyError> {
56    // Read enough bytes to determine the protocol version.
57    // We need at least 16 bytes to check for v2 (12 signature + 4 header).
58    // For v1, we need "PROXY " (6 bytes) as a prefix.
59    let mut buf = vec![0u8; MAX_HEADER_SIZE];
60    let mut total_read = 0;
61
62    // Read the initial bytes. We need at least 16 for v2 detection.
63    while total_read < 16 {
64        let n = stream.read(&mut buf[total_read..]).await?;
65        if n == 0 {
66            return Err(ProxyError::Internal(
67                "connection closed before PROXY protocol header".into(),
68            ));
69        }
70        total_read += n;
71    }
72
73    // Check for v2 signature.
74    if buf[..12] == *V2_SIGNATURE {
75        return parse_v2(&mut buf, total_read, stream).await;
76    }
77
78    // Check for v1 prefix.
79    if buf.starts_with(b"PROXY ") {
80        return parse_v1(&mut buf, total_read, stream).await;
81    }
82
83    // Not a PROXY protocol header — return all buffered bytes as prefix.
84    debug!("no PROXY protocol header detected, passing through");
85    let prefix = buf[..total_read].to_vec();
86    Ok((None, prefix))
87}
88
89/// Parse a PROXY protocol v1 (text) header.
90#[allow(clippy::ptr_arg)]
91async fn parse_v1(
92    buf: &mut Vec<u8>,
93    mut total_read: usize,
94    stream: &mut TcpStream,
95) -> Result<(Option<ProxyProtocolHeader>, Vec<u8>), ProxyError> {
96    // Read until we find \r\n or hit the max line length.
97    loop {
98        if let Some(pos) = buf[..total_read].windows(2).position(|w| w == b"\r\n") {
99            let line = std::str::from_utf8(&buf[..pos])
100                .map_err(|_| ProxyError::Internal("PROXY v1 header is not valid UTF-8".into()))?;
101
102            let header = parse_v1_line(line)?;
103            let remaining = buf[pos + 2..total_read].to_vec();
104
105            debug!(
106                src = %header.src_addr,
107                dst = ?header.dst_addr,
108                "parsed PROXY protocol v1 header"
109            );
110
111            return Ok((Some(header), remaining));
112        }
113
114        if total_read >= V1_MAX_LINE_LEN {
115            return Err(ProxyError::Internal(
116                "PROXY v1 header too long (no CRLF found)".into(),
117            ));
118        }
119
120        // Read more data.
121        let n = stream.read(&mut buf[total_read..]).await?;
122        if n == 0 {
123            return Err(ProxyError::Internal(
124                "connection closed while reading PROXY v1 header".into(),
125            ));
126        }
127        total_read += n;
128    }
129}
130
131/// Parse a v1 header line like `PROXY TCP4 192.168.1.1 192.168.1.2 12345 80`.
132fn parse_v1_line(line: &str) -> Result<ProxyProtocolHeader, ProxyError> {
133    let parts: Vec<&str> = line.split_whitespace().collect();
134
135    // Must start with "PROXY"
136    if parts.is_empty() || parts[0] != "PROXY" {
137        return Err(ProxyError::Internal("invalid PROXY v1 header".into()));
138    }
139
140    // PROXY UNKNOWN is valid — means we don't know the addresses.
141    if parts.len() >= 2 && parts[1] == "UNKNOWN" {
142        return Err(ProxyError::Internal(
143            "PROXY v1 UNKNOWN protocol — no address info".into(),
144        ));
145    }
146
147    if parts.len() < 6 {
148        return Err(ProxyError::Internal(format!(
149            "PROXY v1 header has too few fields: {line}"
150        )));
151    }
152
153    let proto = parts[1]; // TCP4 or TCP6
154    let src_ip_str = parts[2];
155    let dst_ip_str = parts[3];
156    let src_port: u16 = parts[4].parse().map_err(|_| {
157        ProxyError::Internal(format!("invalid source port in PROXY v1: {}", parts[4]))
158    })?;
159    let dst_port: u16 = parts[5].parse().map_err(|_| {
160        ProxyError::Internal(format!(
161            "invalid destination port in PROXY v1: {}",
162            parts[5]
163        ))
164    })?;
165
166    let src_ip: IpAddr = match proto {
167        "TCP4" => src_ip_str.parse::<Ipv4Addr>().map(IpAddr::V4),
168        "TCP6" => src_ip_str.parse::<Ipv6Addr>().map(IpAddr::V6),
169        _ => {
170            return Err(ProxyError::Internal(format!(
171                "unknown protocol in PROXY v1: {proto}"
172            )));
173        }
174    }
175    .map_err(|_| ProxyError::Internal(format!("invalid source IP in PROXY v1: {src_ip_str}")))?;
176
177    let dst_ip: IpAddr = match proto {
178        "TCP4" => dst_ip_str.parse::<Ipv4Addr>().map(IpAddr::V4),
179        "TCP6" => dst_ip_str.parse::<Ipv6Addr>().map(IpAddr::V6),
180        _ => unreachable!(), // handled above
181    }
182    .map_err(|_| {
183        ProxyError::Internal(format!("invalid destination IP in PROXY v1: {dst_ip_str}"))
184    })?;
185
186    Ok(ProxyProtocolHeader {
187        src_addr: SocketAddr::new(src_ip, src_port),
188        dst_addr: Some(SocketAddr::new(dst_ip, dst_port)),
189    })
190}
191
192/// Parse a PROXY protocol v2 (binary) header.
193async fn parse_v2(
194    buf: &mut Vec<u8>,
195    mut total_read: usize,
196    stream: &mut TcpStream,
197) -> Result<(Option<ProxyProtocolHeader>, Vec<u8>), ProxyError> {
198    // Bytes 12-15:
199    //   byte 12: version (upper nibble) | command (lower nibble)
200    //   byte 13: address family (upper nibble) | transport protocol (lower nibble)
201    //   bytes 14-15: length of the address/TLV block (big-endian u16)
202
203    let ver_cmd = buf[12];
204    let version = (ver_cmd >> 4) & 0x0F;
205    let command = ver_cmd & 0x0F;
206
207    if version != 2 {
208        return Err(ProxyError::Internal(format!(
209            "unsupported PROXY v2 version: {version}"
210        )));
211    }
212
213    let fam_proto = buf[13];
214    let family = (fam_proto >> 4) & 0x0F;
215    let _transport = fam_proto & 0x0F;
216
217    let addr_len = u16::from_be_bytes([buf[14], buf[15]]) as usize;
218    let total_header_len = 16 + addr_len;
219
220    if total_header_len > MAX_HEADER_SIZE {
221        return Err(ProxyError::Internal(format!(
222            "PROXY v2 header too large: {total_header_len} bytes"
223        )));
224    }
225
226    // Ensure we have enough data.
227    if buf.len() < total_header_len {
228        buf.resize(total_header_len, 0);
229    }
230    while total_read < total_header_len {
231        let n = stream.read(&mut buf[total_read..total_header_len]).await?;
232        if n == 0 {
233            return Err(ProxyError::Internal(
234                "connection closed while reading PROXY v2 header".into(),
235            ));
236        }
237        total_read += n;
238    }
239
240    // Command 0x00 = LOCAL (health check etc.), 0x01 = PROXY.
241    if command == 0x00 {
242        debug!("PROXY v2 LOCAL command (no address info)");
243        let remaining = buf[total_header_len..total_read].to_vec();
244        return Ok((None, remaining));
245    }
246
247    if command != 0x01 {
248        warn!(command, "unknown PROXY v2 command");
249        let remaining = buf[total_header_len..total_read].to_vec();
250        return Ok((None, remaining));
251    }
252
253    let addr_data = &buf[16..16 + addr_len];
254
255    let header = match family {
256        // AF_INET (IPv4)
257        0x01 => {
258            if addr_len < 12 {
259                return Err(ProxyError::Internal("PROXY v2 IPv4 addr too short".into()));
260            }
261            let src_ip = Ipv4Addr::new(addr_data[0], addr_data[1], addr_data[2], addr_data[3]);
262            let dst_ip = Ipv4Addr::new(addr_data[4], addr_data[5], addr_data[6], addr_data[7]);
263            let src_port = u16::from_be_bytes([addr_data[8], addr_data[9]]);
264            let dst_port = u16::from_be_bytes([addr_data[10], addr_data[11]]);
265
266            ProxyProtocolHeader {
267                src_addr: SocketAddr::new(IpAddr::V4(src_ip), src_port),
268                dst_addr: Some(SocketAddr::new(IpAddr::V4(dst_ip), dst_port)),
269            }
270        }
271        // AF_INET6 (IPv6)
272        0x02 => {
273            if addr_len < 36 {
274                return Err(ProxyError::Internal("PROXY v2 IPv6 addr too short".into()));
275            }
276            let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_data[0..16]).unwrap());
277            let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_data[16..32]).unwrap());
278            let src_port = u16::from_be_bytes([addr_data[32], addr_data[33]]);
279            let dst_port = u16::from_be_bytes([addr_data[34], addr_data[35]]);
280
281            ProxyProtocolHeader {
282                src_addr: SocketAddr::new(IpAddr::V6(src_ip), src_port),
283                dst_addr: Some(SocketAddr::new(IpAddr::V6(dst_ip), dst_port)),
284            }
285        }
286        // AF_UNSPEC
287        0x00 => {
288            debug!("PROXY v2 AF_UNSPEC — no address info");
289            let remaining = buf[total_header_len..total_read].to_vec();
290            return Ok((None, remaining));
291        }
292        _ => {
293            warn!(family, "unknown PROXY v2 address family");
294            let remaining = buf[total_header_len..total_read].to_vec();
295            return Ok((None, remaining));
296        }
297    };
298
299    debug!(
300        src = %header.src_addr,
301        dst = ?header.dst_addr,
302        "parsed PROXY protocol v2 header"
303    );
304
305    let remaining = buf[total_header_len..total_read].to_vec();
306    Ok((Some(header), remaining))
307}
308
309// ---------------------------------------------------------------------------
310// PrefixedStream
311// ---------------------------------------------------------------------------
312
pin_project! {
    /// A wrapper around a `TcpStream` that prepends buffered bytes.
    ///
    /// After parsing the PROXY protocol header, there may be leftover bytes
    /// in our read buffer that belong to the actual application data. This
    /// stream serves those bytes first, then delegates to the inner stream.
    /// Writes, flushes and shutdowns go straight to the inner stream; the
    /// prefix only affects reads.
    pub struct PrefixedStream {
        // Bytes to hand out to readers before anything from `inner`.
        prefix: Vec<u8>,
        // How many bytes of `prefix` have already been returned by reads.
        offset: usize,
        // The underlying connection; pinned so `poll_*` calls can be forwarded
        // through the pin projection.
        #[pin]
        inner: TcpStream,
    }
}
326
327impl PrefixedStream {
328    /// Create a new `PrefixedStream` with the given prefix bytes and inner stream.
329    pub fn new(prefix: Vec<u8>, inner: TcpStream) -> Self {
330        Self {
331            prefix,
332            offset: 0,
333            inner,
334        }
335    }
336}
337
338impl AsyncRead for PrefixedStream {
339    fn poll_read(
340        self: Pin<&mut Self>,
341        cx: &mut Context<'_>,
342        buf: &mut ReadBuf<'_>,
343    ) -> Poll<io::Result<()>> {
344        let this = self.project();
345
346        // Serve from the prefix buffer first.
347        if *this.offset < this.prefix.len() {
348            let remaining = &this.prefix[*this.offset..];
349            let to_copy = remaining.len().min(buf.remaining());
350            buf.put_slice(&remaining[..to_copy]);
351            *this.offset += to_copy;
352            return Poll::Ready(Ok(()));
353        }
354
355        // Prefix exhausted — read from the inner stream.
356        this.inner.poll_read(cx, buf)
357    }
358}
359
360impl AsyncWrite for PrefixedStream {
361    fn poll_write(
362        self: Pin<&mut Self>,
363        cx: &mut Context<'_>,
364        buf: &[u8],
365    ) -> Poll<io::Result<usize>> {
366        self.project().inner.poll_write(cx, buf)
367    }
368
369    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
370        self.project().inner.poll_flush(cx)
371    }
372
373    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
374        self.project().inner.poll_shutdown(cx)
375    }
376}
377
378// ---------------------------------------------------------------------------
379// Tests
380// ---------------------------------------------------------------------------
381
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_v1_tcp4_line() {
        let header = parse_v1_line("PROXY TCP4 192.168.1.100 10.0.0.1 56324 443").unwrap();
        assert_eq!(
            header.src_addr,
            "192.168.1.100:56324".parse::<SocketAddr>().unwrap()
        );
        assert_eq!(
            header.dst_addr,
            Some("10.0.0.1:443".parse::<SocketAddr>().unwrap())
        );
    }

    #[test]
    fn parse_v1_tcp6_line() {
        let header = parse_v1_line("PROXY TCP6 2001:db8::1 2001:db8::2 56324 443").unwrap();
        assert_eq!(
            header.src_addr,
            "[2001:db8::1]:56324".parse::<SocketAddr>().unwrap()
        );
        assert_eq!(
            header.dst_addr,
            Some("[2001:db8::2]:443".parse::<SocketAddr>().unwrap())
        );
    }

    #[test]
    fn parse_v1_accepts_port_extremes() {
        let header = parse_v1_line("PROXY TCP4 0.0.0.0 255.255.255.255 0 65535").unwrap();
        assert_eq!(header.src_addr, "0.0.0.0:0".parse::<SocketAddr>().unwrap());
        assert_eq!(
            header.dst_addr,
            Some("255.255.255.255:65535".parse::<SocketAddr>().unwrap())
        );
    }

    #[test]
    fn parse_v1_unknown_is_error() {
        // parse_v1_line can only build a header with concrete addresses.
        let result = parse_v1_line("PROXY UNKNOWN");
        assert!(result.is_err());
    }

    #[test]
    fn parse_v1_too_few_fields() {
        let result = parse_v1_line("PROXY TCP4 1.2.3.4 5.6.7.8 1234");
        assert!(result.is_err());
    }

    #[test]
    fn parse_v1_rejects_missing_keyword() {
        assert!(parse_v1_line("").is_err());
        assert!(parse_v1_line("HELLO TCP4 1.2.3.4 5.6.7.8 1234 80").is_err());
    }

    #[test]
    fn parse_v1_rejects_unknown_protocol() {
        assert!(parse_v1_line("PROXY UDP4 1.2.3.4 5.6.7.8 1234 80").is_err());
    }

    #[test]
    fn parse_v1_rejects_address_family_mismatch() {
        assert!(parse_v1_line("PROXY TCP4 2001:db8::1 2001:db8::2 1234 80").is_err());
        assert!(parse_v1_line("PROXY TCP6 1.2.3.4 5.6.7.8 1234 80").is_err());
        assert!(parse_v1_line("PROXY TCP4 1.2.3.4 not-an-ip 1234 80").is_err());
    }

    #[test]
    fn parse_v1_rejects_bad_ports() {
        assert!(parse_v1_line("PROXY TCP4 1.2.3.4 5.6.7.8 65536 80").is_err());
        assert!(parse_v1_line("PROXY TCP4 1.2.3.4 5.6.7.8 1234 http").is_err());
    }

    #[test]
    fn v2_signature_constant() {
        // \r\n\r\n\0\r\nQUIT\n as spelled out in the PROXY protocol spec.
        let expected: [u8; 12] = [
            0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
        ];
        assert_eq!(*V2_SIGNATURE, expected);
    }
}