1use std::io;
15use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
16use std::pin::Pin;
17use std::task::{Context, Poll};
18
19use pin_project_lite::pin_project;
20use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
21use tokio::net::TcpStream;
22use tracing::{debug, warn};
23
24use crate::ProxyError;
25
/// Twelve-byte magic sequence that opens every binary PROXY protocol v2 header
/// (`\r\n\r\n\0\r\nQUIT\n`).
const V2_SIGNATURE: &[u8; 12] = b"\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A";

/// Number of bytes the v1 parser will buffer while looking for the terminating CRLF
/// before it rejects the header as too long.
const V1_MAX_LINE_LEN: usize = 108;

/// Upper bound on a v2 header (16-byte fixed part plus address block); also the size of
/// the buffer used while detecting and reading a header.
const MAX_HEADER_SIZE: usize = 536;
34
/// Client and destination addresses recovered from a PROXY protocol header.
///
/// Derives `PartialEq`/`Eq` so callers (and tests) can compare parsed headers directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProxyProtocolHeader {
    /// Address of the original client, as reported by the upstream proxy.
    pub src_addr: SocketAddr,
    /// Address the client originally connected to, when the header carries one.
    pub dst_addr: Option<SocketAddr>,
}
43
44pub async fn parse_proxy_protocol(
54 stream: &mut TcpStream,
55) -> Result<(Option<ProxyProtocolHeader>, Vec<u8>), ProxyError> {
56 let mut buf = vec![0u8; MAX_HEADER_SIZE];
60 let mut total_read = 0;
61
62 while total_read < 16 {
64 let n = stream.read(&mut buf[total_read..]).await?;
65 if n == 0 {
66 return Err(ProxyError::Internal(
67 "connection closed before PROXY protocol header".into(),
68 ));
69 }
70 total_read += n;
71 }
72
73 if buf[..12] == *V2_SIGNATURE {
75 return parse_v2(&mut buf, total_read, stream).await;
76 }
77
78 if buf.starts_with(b"PROXY ") {
80 return parse_v1(&mut buf, total_read, stream).await;
81 }
82
83 debug!("no PROXY protocol header detected, passing through");
85 let prefix = buf[..total_read].to_vec();
86 Ok((None, prefix))
87}
88
89#[allow(clippy::ptr_arg)]
91async fn parse_v1(
92 buf: &mut Vec<u8>,
93 mut total_read: usize,
94 stream: &mut TcpStream,
95) -> Result<(Option<ProxyProtocolHeader>, Vec<u8>), ProxyError> {
96 loop {
98 if let Some(pos) = buf[..total_read].windows(2).position(|w| w == b"\r\n") {
99 let line = std::str::from_utf8(&buf[..pos])
100 .map_err(|_| ProxyError::Internal("PROXY v1 header is not valid UTF-8".into()))?;
101
102 let header = parse_v1_line(line)?;
103 let remaining = buf[pos + 2..total_read].to_vec();
104
105 debug!(
106 src = %header.src_addr,
107 dst = ?header.dst_addr,
108 "parsed PROXY protocol v1 header"
109 );
110
111 return Ok((Some(header), remaining));
112 }
113
114 if total_read >= V1_MAX_LINE_LEN {
115 return Err(ProxyError::Internal(
116 "PROXY v1 header too long (no CRLF found)".into(),
117 ));
118 }
119
120 let n = stream.read(&mut buf[total_read..]).await?;
122 if n == 0 {
123 return Err(ProxyError::Internal(
124 "connection closed while reading PROXY v1 header".into(),
125 ));
126 }
127 total_read += n;
128 }
129}
130
131fn parse_v1_line(line: &str) -> Result<ProxyProtocolHeader, ProxyError> {
133 let parts: Vec<&str> = line.split_whitespace().collect();
134
135 if parts.is_empty() || parts[0] != "PROXY" {
137 return Err(ProxyError::Internal("invalid PROXY v1 header".into()));
138 }
139
140 if parts.len() >= 2 && parts[1] == "UNKNOWN" {
142 return Err(ProxyError::Internal(
143 "PROXY v1 UNKNOWN protocol — no address info".into(),
144 ));
145 }
146
147 if parts.len() < 6 {
148 return Err(ProxyError::Internal(format!(
149 "PROXY v1 header has too few fields: {line}"
150 )));
151 }
152
153 let proto = parts[1]; let src_ip_str = parts[2];
155 let dst_ip_str = parts[3];
156 let src_port: u16 = parts[4].parse().map_err(|_| {
157 ProxyError::Internal(format!("invalid source port in PROXY v1: {}", parts[4]))
158 })?;
159 let dst_port: u16 = parts[5].parse().map_err(|_| {
160 ProxyError::Internal(format!(
161 "invalid destination port in PROXY v1: {}",
162 parts[5]
163 ))
164 })?;
165
166 let src_ip: IpAddr = match proto {
167 "TCP4" => src_ip_str.parse::<Ipv4Addr>().map(IpAddr::V4),
168 "TCP6" => src_ip_str.parse::<Ipv6Addr>().map(IpAddr::V6),
169 _ => {
170 return Err(ProxyError::Internal(format!(
171 "unknown protocol in PROXY v1: {proto}"
172 )));
173 }
174 }
175 .map_err(|_| ProxyError::Internal(format!("invalid source IP in PROXY v1: {src_ip_str}")))?;
176
177 let dst_ip: IpAddr = match proto {
178 "TCP4" => dst_ip_str.parse::<Ipv4Addr>().map(IpAddr::V4),
179 "TCP6" => dst_ip_str.parse::<Ipv6Addr>().map(IpAddr::V6),
180 _ => unreachable!(), }
182 .map_err(|_| {
183 ProxyError::Internal(format!("invalid destination IP in PROXY v1: {dst_ip_str}"))
184 })?;
185
186 Ok(ProxyProtocolHeader {
187 src_addr: SocketAddr::new(src_ip, src_port),
188 dst_addr: Some(SocketAddr::new(dst_ip, dst_port)),
189 })
190}
191
192async fn parse_v2(
194 buf: &mut Vec<u8>,
195 mut total_read: usize,
196 stream: &mut TcpStream,
197) -> Result<(Option<ProxyProtocolHeader>, Vec<u8>), ProxyError> {
198 let ver_cmd = buf[12];
204 let version = (ver_cmd >> 4) & 0x0F;
205 let command = ver_cmd & 0x0F;
206
207 if version != 2 {
208 return Err(ProxyError::Internal(format!(
209 "unsupported PROXY v2 version: {version}"
210 )));
211 }
212
213 let fam_proto = buf[13];
214 let family = (fam_proto >> 4) & 0x0F;
215 let _transport = fam_proto & 0x0F;
216
217 let addr_len = u16::from_be_bytes([buf[14], buf[15]]) as usize;
218 let total_header_len = 16 + addr_len;
219
220 if total_header_len > MAX_HEADER_SIZE {
221 return Err(ProxyError::Internal(format!(
222 "PROXY v2 header too large: {total_header_len} bytes"
223 )));
224 }
225
226 if buf.len() < total_header_len {
228 buf.resize(total_header_len, 0);
229 }
230 while total_read < total_header_len {
231 let n = stream.read(&mut buf[total_read..total_header_len]).await?;
232 if n == 0 {
233 return Err(ProxyError::Internal(
234 "connection closed while reading PROXY v2 header".into(),
235 ));
236 }
237 total_read += n;
238 }
239
240 if command == 0x00 {
242 debug!("PROXY v2 LOCAL command (no address info)");
243 let remaining = buf[total_header_len..total_read].to_vec();
244 return Ok((None, remaining));
245 }
246
247 if command != 0x01 {
248 warn!(command, "unknown PROXY v2 command");
249 let remaining = buf[total_header_len..total_read].to_vec();
250 return Ok((None, remaining));
251 }
252
253 let addr_data = &buf[16..16 + addr_len];
254
255 let header = match family {
256 0x01 => {
258 if addr_len < 12 {
259 return Err(ProxyError::Internal("PROXY v2 IPv4 addr too short".into()));
260 }
261 let src_ip = Ipv4Addr::new(addr_data[0], addr_data[1], addr_data[2], addr_data[3]);
262 let dst_ip = Ipv4Addr::new(addr_data[4], addr_data[5], addr_data[6], addr_data[7]);
263 let src_port = u16::from_be_bytes([addr_data[8], addr_data[9]]);
264 let dst_port = u16::from_be_bytes([addr_data[10], addr_data[11]]);
265
266 ProxyProtocolHeader {
267 src_addr: SocketAddr::new(IpAddr::V4(src_ip), src_port),
268 dst_addr: Some(SocketAddr::new(IpAddr::V4(dst_ip), dst_port)),
269 }
270 }
271 0x02 => {
273 if addr_len < 36 {
274 return Err(ProxyError::Internal("PROXY v2 IPv6 addr too short".into()));
275 }
276 let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_data[0..16]).unwrap());
277 let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_data[16..32]).unwrap());
278 let src_port = u16::from_be_bytes([addr_data[32], addr_data[33]]);
279 let dst_port = u16::from_be_bytes([addr_data[34], addr_data[35]]);
280
281 ProxyProtocolHeader {
282 src_addr: SocketAddr::new(IpAddr::V6(src_ip), src_port),
283 dst_addr: Some(SocketAddr::new(IpAddr::V6(dst_ip), dst_port)),
284 }
285 }
286 0x00 => {
288 debug!("PROXY v2 AF_UNSPEC — no address info");
289 let remaining = buf[total_header_len..total_read].to_vec();
290 return Ok((None, remaining));
291 }
292 _ => {
293 warn!(family, "unknown PROXY v2 address family");
294 let remaining = buf[total_header_len..total_read].to_vec();
295 return Ok((None, remaining));
296 }
297 };
298
299 debug!(
300 src = %header.src_addr,
301 dst = ?header.dst_addr,
302 "parsed PROXY protocol v2 header"
303 );
304
305 let remaining = buf[total_header_len..total_read].to_vec();
306 Ok((Some(header), remaining))
307}
308
pin_project! {
    /// A `TcpStream` wrapper whose reads first yield the bytes in `prefix`, then continue
    /// with the socket; writes are forwarded to the socket.
    ///
    /// Used to replay bytes that were consumed from the socket while looking for a
    /// PROXY protocol header, so the application sees an unmodified byte stream.
    pub struct PrefixedStream {
        // Bytes to hand out before any read reaches `inner`.
        prefix: Vec<u8>,
        // Index of the next not-yet-returned byte in `prefix`.
        offset: usize,
        // The underlying socket; pinned so its poll methods can be called through the
        // projection.
        #[pin]
        inner: TcpStream,
    }
}
326
327impl PrefixedStream {
328 pub fn new(prefix: Vec<u8>, inner: TcpStream) -> Self {
330 Self {
331 prefix,
332 offset: 0,
333 inner,
334 }
335 }
336}
337
338impl AsyncRead for PrefixedStream {
339 fn poll_read(
340 self: Pin<&mut Self>,
341 cx: &mut Context<'_>,
342 buf: &mut ReadBuf<'_>,
343 ) -> Poll<io::Result<()>> {
344 let this = self.project();
345
346 if *this.offset < this.prefix.len() {
348 let remaining = &this.prefix[*this.offset..];
349 let to_copy = remaining.len().min(buf.remaining());
350 buf.put_slice(&remaining[..to_copy]);
351 *this.offset += to_copy;
352 return Poll::Ready(Ok(()));
353 }
354
355 this.inner.poll_read(cx, buf)
357 }
358}
359
360impl AsyncWrite for PrefixedStream {
361 fn poll_write(
362 self: Pin<&mut Self>,
363 cx: &mut Context<'_>,
364 buf: &[u8],
365 ) -> Poll<io::Result<usize>> {
366 self.project().inner.poll_write(cx, buf)
367 }
368
369 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
370 self.project().inner.poll_flush(cx)
371 }
372
373 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
374 self.project().inner.poll_shutdown(cx)
375 }
376}
377
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a socket-address literal used as an expected value in these tests.
    fn addr(text: &str) -> SocketAddr {
        text.parse().expect("test address literal must be valid")
    }

    #[test]
    fn parse_v1_tcp4_line() {
        let parsed = parse_v1_line("PROXY TCP4 192.168.1.100 10.0.0.1 56324 443")
            .expect("well-formed TCP4 line must parse");
        assert_eq!(parsed.src_addr, addr("192.168.1.100:56324"));
        assert_eq!(parsed.dst_addr, Some(addr("10.0.0.1:443")));
    }

    #[test]
    fn parse_v1_tcp6_line() {
        let parsed = parse_v1_line("PROXY TCP6 2001:db8::1 2001:db8::2 56324 443")
            .expect("well-formed TCP6 line must parse");
        assert_eq!(parsed.src_addr, addr("[2001:db8::1]:56324"));
        assert_eq!(parsed.dst_addr, Some(addr("[2001:db8::2]:443")));
    }

    #[test]
    fn parse_v1_unknown_is_error() {
        assert!(parse_v1_line("PROXY UNKNOWN").is_err());
    }

    #[test]
    fn parse_v1_too_few_fields() {
        assert!(parse_v1_line("PROXY TCP4 1.2.3.4 5.6.7.8 1234").is_err());
    }

    #[test]
    fn v2_signature_constant() {
        assert_eq!(V2_SIGNATURE.len(), 12);
        assert_eq!(V2_SIGNATURE.first(), Some(&b'\r'));
        assert_eq!(V2_SIGNATURE.last(), Some(&b'\n'));
    }
}