Skip to main content

proxy_protocol_rs/
stream.rs

1// Copyright (C) 2025-2026 Michael S. Klishin and Contributors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt;
16use std::io::Cursor;
17use std::net::{IpAddr, SocketAddr};
18use std::pin::Pin;
19use std::task::{Context, Poll};
20
21use tokio::io::{self, AsyncRead, AsyncWrite, ReadBuf};
22use tokio::net::TcpStream;
23
24use crate::types::ProxyInfo;
25
26/// Connection metadata extracted from a Proxy Protocol header
27///
28/// Available without the `axum` feature; use
29/// [`ProxiedStream::connect_info`] to build one from an accepted stream
30#[derive(Debug, Clone)]
31pub struct ProxyConnectInfo {
32    /// The original client address (from a PP header or a TCP peer)
33    pub client_addr: SocketAddr,
34    /// The TCP peer address (the proxy's address)
35    pub peer_addr: SocketAddr,
36    /// Full Proxy Protocol info, if available
37    pub proxy_info: Option<ProxyInfo>,
38}
39
40/// A TCP stream with Proxy Protocol metadata attached
41///
42/// Implements `AsyncRead + AsyncWrite`, so it can be wrapped by a TLS
43/// acceptor for deployments that terminate TLS at the application
44#[derive(Debug)]
45pub struct ProxiedStream {
46    inner: TcpStream,
47    leftover: Cursor<Vec<u8>>,
48    proxy_info: Option<ProxyInfo>,
49    peer_addr: SocketAddr,
50}
51
52impl ProxiedStream {
53    pub(crate) fn new(
54        inner: TcpStream,
55        leftover: Vec<u8>,
56        proxy_info: Option<ProxyInfo>,
57        peer_addr: SocketAddr,
58    ) -> Self {
59        Self {
60            inner,
61            leftover: Cursor::new(leftover),
62            proxy_info,
63            peer_addr,
64        }
65    }
66
67    /// Parsed Proxy Protocol information
68    pub fn proxy_info(&self) -> Option<&ProxyInfo> {
69        self.proxy_info.as_ref()
70    }
71
72    /// Client address from a PP header, falling back to the TCP peer
73    pub fn client_addr(&self) -> SocketAddr {
74        self.proxy_info
75            .as_ref()
76            .and_then(|info| info.source_inet())
77            .unwrap_or(self.peer_addr)
78    }
79
80    /// The raw TCP peer address (the load balancer's IP)
81    pub fn peer_addr(&self) -> SocketAddr {
82        self.peer_addr
83    }
84
85    /// Access the inner `TcpStream`
86    pub fn inner(&self) -> &TcpStream {
87        &self.inner
88    }
89
90    /// Build a [`ProxyConnectInfo`] snapshot from this stream's metadata
91    ///
92    /// Useful for extracting connection info before wrapping with TLS
93    pub fn connect_info(&self) -> ProxyConnectInfo {
94        ProxyConnectInfo {
95            client_addr: self.client_addr(),
96            peer_addr: self.peer_addr,
97            proxy_info: self.proxy_info.clone(),
98        }
99    }
100}
101
102impl ProxyConnectInfo {
103    /// Client IP address without the port, for rate limiting and access control
104    pub fn client_ip(&self) -> IpAddr {
105        self.client_addr.ip()
106    }
107
108    /// Whether a Proxy Protocol header was present on this connection
109    pub fn is_proxied(&self) -> bool {
110        self.proxy_info.is_some()
111    }
112}
113
114impl fmt::Display for ProxyConnectInfo {
115    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116        if self.is_proxied() {
117            write!(f, "{} via {}", self.client_addr, self.peer_addr)
118        } else {
119            write!(f, "{} (direct)", self.client_addr)
120        }
121    }
122}
123
124impl From<&ProxiedStream> for ProxyConnectInfo {
125    fn from(stream: &ProxiedStream) -> Self {
126        stream.connect_info()
127    }
128}
129
130impl From<SocketAddr> for ProxyConnectInfo {
131    fn from(addr: SocketAddr) -> Self {
132        Self {
133            client_addr: addr,
134            peer_addr: addr,
135            proxy_info: None,
136        }
137    }
138}
139
140impl From<(ProxyInfo, SocketAddr)> for ProxyConnectInfo {
141    fn from((info, peer_addr): (ProxyInfo, SocketAddr)) -> Self {
142        let client_addr = info.source_inet().unwrap_or(peer_addr);
143        Self {
144            client_addr,
145            peer_addr,
146            proxy_info: Some(info),
147        }
148    }
149}
150
151impl AsyncRead for ProxiedStream {
152    fn poll_read(
153        self: Pin<&mut Self>,
154        cx: &mut Context<'_>,
155        buf: &mut ReadBuf<'_>,
156    ) -> Poll<io::Result<()>> {
157        let this = self.get_mut();
158
159        // Serve leftover bytes first
160        let leftover_data = this.leftover.get_ref();
161        let leftover_pos = this.leftover.position() as usize;
162        let leftover_remaining = leftover_data.len() - leftover_pos;
163
164        if leftover_remaining > 0 {
165            let to_copy = leftover_remaining.min(buf.remaining());
166            buf.put_slice(&leftover_data[leftover_pos..leftover_pos + to_copy]);
167            this.leftover.set_position((leftover_pos + to_copy) as u64);
168            return Poll::Ready(Ok(()));
169        }
170
171        // Delegate to inner TcpStream
172        Pin::new(&mut this.inner).poll_read(cx, buf)
173    }
174}
175
176impl AsyncWrite for ProxiedStream {
177    fn poll_write(
178        self: Pin<&mut Self>,
179        cx: &mut Context<'_>,
180        buf: &[u8],
181    ) -> Poll<io::Result<usize>> {
182        Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
183    }
184
185    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
186        Pin::new(&mut self.get_mut().inner).poll_flush(cx)
187    }
188
189    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
190        Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
191    }
192}