proxy_protocol_rs/
stream.rs1use std::fmt;
16use std::io::Cursor;
17use std::net::{IpAddr, SocketAddr};
18use std::pin::Pin;
19use std::task::{Context, Poll};
20
21use tokio::io::{self, AsyncRead, AsyncWrite, ReadBuf};
22use tokio::net::TcpStream;
23
24use crate::types::ProxyInfo;
25
26#[derive(Debug, Clone)]
31pub struct ProxyConnectInfo {
32 pub client_addr: SocketAddr,
34 pub peer_addr: SocketAddr,
36 pub proxy_info: Option<ProxyInfo>,
38}
39
40#[derive(Debug)]
45pub struct ProxiedStream {
46 inner: TcpStream,
47 leftover: Cursor<Vec<u8>>,
48 proxy_info: Option<ProxyInfo>,
49 peer_addr: SocketAddr,
50}
51
52impl ProxiedStream {
53 pub(crate) fn new(
54 inner: TcpStream,
55 leftover: Vec<u8>,
56 proxy_info: Option<ProxyInfo>,
57 peer_addr: SocketAddr,
58 ) -> Self {
59 Self {
60 inner,
61 leftover: Cursor::new(leftover),
62 proxy_info,
63 peer_addr,
64 }
65 }
66
67 pub fn proxy_info(&self) -> Option<&ProxyInfo> {
69 self.proxy_info.as_ref()
70 }
71
72 pub fn client_addr(&self) -> SocketAddr {
74 self.proxy_info
75 .as_ref()
76 .and_then(|info| info.source_inet())
77 .unwrap_or(self.peer_addr)
78 }
79
80 pub fn peer_addr(&self) -> SocketAddr {
82 self.peer_addr
83 }
84
85 pub fn inner(&self) -> &TcpStream {
87 &self.inner
88 }
89
90 pub fn connect_info(&self) -> ProxyConnectInfo {
94 ProxyConnectInfo {
95 client_addr: self.client_addr(),
96 peer_addr: self.peer_addr,
97 proxy_info: self.proxy_info.clone(),
98 }
99 }
100}
101
102impl ProxyConnectInfo {
103 pub fn client_ip(&self) -> IpAddr {
105 self.client_addr.ip()
106 }
107
108 pub fn is_proxied(&self) -> bool {
110 self.proxy_info.is_some()
111 }
112}
113
114impl fmt::Display for ProxyConnectInfo {
115 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116 if self.is_proxied() {
117 write!(f, "{} via {}", self.client_addr, self.peer_addr)
118 } else {
119 write!(f, "{} (direct)", self.client_addr)
120 }
121 }
122}
123
124impl From<&ProxiedStream> for ProxyConnectInfo {
125 fn from(stream: &ProxiedStream) -> Self {
126 stream.connect_info()
127 }
128}
129
130impl From<SocketAddr> for ProxyConnectInfo {
131 fn from(addr: SocketAddr) -> Self {
132 Self {
133 client_addr: addr,
134 peer_addr: addr,
135 proxy_info: None,
136 }
137 }
138}
139
140impl From<(ProxyInfo, SocketAddr)> for ProxyConnectInfo {
141 fn from((info, peer_addr): (ProxyInfo, SocketAddr)) -> Self {
142 let client_addr = info.source_inet().unwrap_or(peer_addr);
143 Self {
144 client_addr,
145 peer_addr,
146 proxy_info: Some(info),
147 }
148 }
149}
150
151impl AsyncRead for ProxiedStream {
152 fn poll_read(
153 self: Pin<&mut Self>,
154 cx: &mut Context<'_>,
155 buf: &mut ReadBuf<'_>,
156 ) -> Poll<io::Result<()>> {
157 let this = self.get_mut();
158
159 let leftover_data = this.leftover.get_ref();
161 let leftover_pos = this.leftover.position() as usize;
162 let leftover_remaining = leftover_data.len() - leftover_pos;
163
164 if leftover_remaining > 0 {
165 let to_copy = leftover_remaining.min(buf.remaining());
166 buf.put_slice(&leftover_data[leftover_pos..leftover_pos + to_copy]);
167 this.leftover.set_position((leftover_pos + to_copy) as u64);
168 return Poll::Ready(Ok(()));
169 }
170
171 Pin::new(&mut this.inner).poll_read(cx, buf)
173 }
174}
175
176impl AsyncWrite for ProxiedStream {
177 fn poll_write(
178 self: Pin<&mut Self>,
179 cx: &mut Context<'_>,
180 buf: &[u8],
181 ) -> Poll<io::Result<usize>> {
182 Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
183 }
184
185 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
186 Pin::new(&mut self.get_mut().inner).poll_flush(cx)
187 }
188
189 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
190 Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
191 }
192}