1#[macro_use]
4extern crate log;
5
6#[cfg(feature = "track_conn_count")]
7use std::sync::{
8 atomic::{AtomicU64, Ordering},
9 Arc,
10};
11use std::{
12 convert::TryInto,
13 io::{self, ErrorKind},
14 mem::MaybeUninit,
15 net::{IpAddr, Ipv4Addr, SocketAddr},
16 pin::Pin,
17 task::{Context, Poll},
18};
19
20use futures::Future;
21use tokio::{
22 io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf},
23 net::tcp::{OwnedReadHalf, OwnedWriteHalf},
24};
25
26mod wrapped_incoming;
27pub use wrapped_incoming::WrappedIncoming;
28
/// How a [`WrappedStream`] treats a PROXY protocol v2 header at the start of the stream.
#[derive(Debug, Clone, Copy)]
pub enum ProxyMode {
    /// Never look for a header; all bytes are passed through untouched.
    None,
    /// Parse a header if one is present, otherwise pass the probed bytes through.
    Accept,
    /// Fail the first read unless the stream starts with a header.
    Require,
}
39
/// Size of the fixed PROXY v2 header: the 12-byte signature followed by the
/// version/command byte, the family/protocol byte and a 2-byte length.
const PROXY_PACKET_HEADER_LEN: usize = PROXY_SIGNATURE.len() + 4;
/// Largest address block defined by the spec (two 108-byte UNIX socket paths).
const PROXY_PACKET_MAX_PROXY_ADDR_SIZE: usize = 2 * 108;
/// The fixed 12-byte PROXY protocol v2 signature (`\r\n\r\n\0\r\nQUIT\n`).
const PROXY_SIGNATURE: [u8; 12] = *b"\r\n\r\n\0\r\nQUIT\n";
/// Only version 2 of the binary header is understood.
const PROXY_PROTOCOL_VERSION: u8 = 2;
46
/// PROXY v2 command, carried in the low nibble of the version/command byte.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Command {
    /// Nibble value `0x0`.
    Local = 0,
    /// Nibble value `0x1`.
    Proxy = 1,
}

impl Command {
    /// Decodes a command nibble; anything other than 0 or 1 yields `None`.
    fn from_u8(from: u8) -> Option<Self> {
        Some(match from {
            0 => Self::Local,
            1 => Self::Proxy,
            _ => return None,
        })
    }
}
64
/// Address family, carried in the high nibble of the family/protocol byte.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Family {
    /// Nibble value 0: no address information.
    Unspecified = 0,
    /// Nibble value 1: IPv4 addresses and ports.
    Ipv4 = 1,
    /// Nibble value 2: IPv6 addresses and ports.
    Ipv6 = 2,
    /// Nibble value 3: UNIX socket paths.
    Unix = 3,
}

impl Family {
    /// Decodes a family nibble; values above 3 yield `None`.
    fn from_u8(from: u8) -> Option<Self> {
        let family = match from {
            0 => Self::Unspecified,
            1 => Self::Ipv4,
            2 => Self::Ipv6,
            3 => Self::Unix,
            _ => return None,
        };
        Some(family)
    }

    /// Size in bytes of this family's address block, or `None` for `Unspecified`,
    /// which defines no fixed block.
    fn len(&self) -> Option<usize> {
        let bytes = match *self {
            Self::Unspecified => return None,
            // src addr (4) + dst addr (4) + src port (2) + dst port (2)
            Self::Ipv4 => 12,
            // src addr (16) + dst addr (16) + src port (2) + dst port (2)
            Self::Ipv6 => 36,
            // two 108-byte socket paths, per the PROXY v2 spec
            Self::Unix => 216,
        };
        Some(bytes)
    }
}
95
/// Transport protocol, carried in the low nibble of the family/protocol byte.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Protocol {
    /// Nibble value 0.
    Unspecified = 0,
    /// Nibble value 1 (stream transport).
    Stream = 1,
    /// Nibble value 2 (datagram transport).
    Datagram = 2,
}

impl Protocol {
    /// Decodes a protocol nibble by table lookup; values above 2 yield `None`.
    fn from_u8(from: u8) -> Option<Self> {
        // Table index equals the wire value.
        [Self::Unspecified, Self::Stream, Self::Datagram]
            .get(usize::from(from))
            .copied()
    }
}
115
116#[derive(PartialEq, Debug)]
117struct ProxyInfo {
118 command: Command,
119 family: Family,
120 protocol: Protocol,
121 discovered_dest: Option<SocketAddr>,
122 discovered_src: Option<SocketAddr>,
123}
124
125#[derive(PartialEq, Debug)]
126enum ProxyResult {
127 Proxy(ProxyInfo),
128 SignatureBytes([u8; PROXY_SIGNATURE.len()]),
129}
130
131pub struct WrappedStream {
133 remote_addr: SocketAddr,
134 inner_write: Pin<Box<OwnedWriteHalf>>,
135 inner_read: Option<Pin<Box<OwnedReadHalf>>>,
136 #[cfg(feature = "track_conn_count")]
137 conn_count: Arc<AtomicU64>,
138 pending_read_proxy: Option<
139 Pin<
140 Box<
141 dyn Future<Output = io::Result<(ProxyResult, Pin<Box<OwnedReadHalf>>)>>
142 + Send
143 + Sync
144 + 'static,
145 >,
146 >,
147 >,
148 info: Option<ProxyInfo>,
149 #[cfg(feature = "tonic")]
150 connect_info: std::sync::Arc<std::sync::RwLock<Option<SocketAddr>>>,
151 fused_error: bool,
152 proxy_mode: ProxyMode,
153}
154
/// Connection metadata handed to tonic through request extensions.
///
/// It shares its address slot with the [`WrappedStream`] it came from. The value
/// becomes available once the stream's first read has resolved the PROXY header.
#[cfg(feature = "tonic")]
#[derive(Clone)]
pub struct TcpConnectInfo {
    inner: std::sync::Arc<std::sync::RwLock<Option<SocketAddr>>>,
}

#[cfg(feature = "tonic")]
impl TcpConnectInfo {
    /// Returns the recorded client address, or `None` if nothing has been recorded yet.
    ///
    /// # Panics
    ///
    /// Panics if the shared lock has been poisoned.
    pub fn remote_addr(&self) -> Option<SocketAddr> {
        let recorded = *self.inner.read().unwrap();
        recorded
    }
}
167
#[cfg(feature = "tonic")]
impl tonic::transport::server::Connected for WrappedStream {
    type ConnectInfo = TcpConnectInfo;

    /// Gives tonic a handle that shares this stream's connect-info slot, so it sees
    /// the address written after the PROXY header is processed.
    fn connect_info(&self) -> Self::ConnectInfo {
        let inner = std::sync::Arc::clone(&self.connect_info);
        TcpConnectInfo { inner }
    }
}
177
/// Returns the client address recorded for the connection that carried `request`.
///
/// # Panics
///
/// Panics if the request has no [`TcpConnectInfo`] extension. This presumably means
/// the tonic server was not running on top of [`WrappedStream`].
#[cfg(feature = "tonic")]
pub fn tonic_remote_addr<T>(request: &tonic::Request<T>) -> Option<SocketAddr> {
    let info = request
        .extensions()
        .get::<TcpConnectInfo>()
        .expect("missing TCP connect info (was hyperproxy inline with tonic?)");
    info.remote_addr()
}
186
/// Lets axum's connect-info machinery obtain the client address via
/// [`WrappedStream::source`], which prefers the PROXY-header address when present.
#[cfg(feature = "axum")]
impl<'a> axum::extract::connect_info::Connected<&'a WrappedStream> for SocketAddr {
    fn connect_info(target: &'a WrappedStream) -> Self {
        WrappedStream::source(target)
    }
}
193
/// Copies `from` into a fixed-size array.
///
/// # Panics
///
/// Panics if `from.len() != SIZE`; callers slice exactly `SIZE` bytes.
fn to_array<const SIZE: usize>(from: &[u8]) -> [u8; SIZE] {
    let array: [u8; SIZE] = TryInto::try_into(from).unwrap();
    array
}
197
198async fn read_proxy<R: AsyncRead + Unpin>(mut read: R) -> io::Result<(ProxyResult, R)> {
199 let mut signature = [0u8; PROXY_SIGNATURE.len()];
200 read.read_exact(&mut signature[..]).await?;
201 if signature != PROXY_SIGNATURE {
202 return Ok((ProxyResult::SignatureBytes(signature), read));
203 }
204
205 let mut header = [0u8; PROXY_PACKET_HEADER_LEN - PROXY_SIGNATURE.len()];
207 read.read_exact(&mut header[..]).await?;
208
209 let version = (header[0] & 0xf0) >> 4;
210 if version != PROXY_PROTOCOL_VERSION {
211 debug!("invalid proxy protocol version: {}", version);
212 return Err(io::Error::new(
213 ErrorKind::InvalidData,
214 "invalid proxy protocol version",
215 ));
216 }
217 let command = header[0] & 0x0f;
218 let command = match Command::from_u8(command) {
219 Some(c) => c,
220 None => {
221 debug!("invalid proxy protocol command: {}", command);
222 return Err(io::Error::new(
223 ErrorKind::InvalidData,
224 "invalid proxy protocol command",
225 ));
226 }
227 };
228
229 let family = (header[1] & 0xf0) >> 4;
230 let family = match Family::from_u8(family) {
231 None => {
232 debug!("invalid proxy family: {}", family);
233 return Err(io::Error::new(
234 ErrorKind::InvalidData,
235 "invalid proxy family",
236 ));
237 }
238 Some(family) => {
239 trace!("PROXY family: {:?}", family);
240 family
241 }
242 };
243
244 let protocol = header[1] & 0x0f;
245 let protocol = match Protocol::from_u8(protocol) {
246 None => {
247 debug!("invalid proxy protocol: {}", protocol);
248 return Err(io::Error::new(
249 ErrorKind::InvalidData,
250 "invalid proxy protocol",
251 ));
252 }
253 Some(protocol) => {
254 trace!("PROXY protocol: {:?}", protocol);
255 protocol
256 }
257 };
258
259 let len = u16::from_be_bytes([header[2], header[3]]) as usize;
260 let target_len = if matches!(command, Command::Local) {
261 None
262 } else {
263 family.len()
264 };
265
266 if let Some(target_len) = target_len {
267 if len < target_len {
268 debug!("invalid proxy address length: {}", target_len);
269 return Err(io::Error::new(
270 ErrorKind::InvalidData,
271 "invalid proxy address length",
272 ));
273 }
274 }
275
276 let mut raw =
277 unsafe { MaybeUninit::<[u8; PROXY_PACKET_MAX_PROXY_ADDR_SIZE]>::uninit().assume_init() };
278 read.read_exact(&mut raw[..len]).await?;
279 let raw = &raw[..len];
280
281 let mut discovered_src = None;
282 let mut discovered_dest = None;
283
284 match family {
285 Family::Unspecified => {
286 debug!("unspecified PROXY family data: {:?}", raw);
287 }
288 Family::Ipv4 => {
289 let src_addr = IpAddr::V4(Ipv4Addr::from(to_array(&raw[..4])));
290 let dest_addr = IpAddr::V4(Ipv4Addr::from(to_array(&raw[4..8])));
291 let src_port = u16::from_be_bytes((&raw[8..10]).try_into().unwrap());
292 let dest_port = u16::from_be_bytes((&raw[10..12]).try_into().unwrap());
293 discovered_src = Some(SocketAddr::new(src_addr, src_port));
294 discovered_dest = Some(SocketAddr::new(dest_addr, dest_port));
295 }
296 Family::Ipv6 => {
297 let src_addr = IpAddr::V6(to_array(&raw[..16]).into());
298 let dest_addr = IpAddr::V6(to_array(&raw[16..32]).into());
299 let src_port = u16::from_be_bytes((&raw[32..34]).try_into().unwrap());
300 let dest_port = u16::from_be_bytes((&raw[34..36]).try_into().unwrap());
301 discovered_src = Some(SocketAddr::new(src_addr, src_port));
302 discovered_dest = Some(SocketAddr::new(dest_addr, dest_port));
303 }
304 Family::Unix => {
305 warn!("unsupported UNIX PROXY family, ignored.");
306 }
307 }
308
309 Ok((
310 ProxyResult::Proxy(ProxyInfo {
311 command,
312 family,
313 protocol,
314 discovered_dest,
315 discovered_src,
316 }),
317 read,
318 ))
319}
320
321impl AsyncRead for WrappedStream {
322 #[inline]
323 fn poll_read(
324 mut self: Pin<&mut Self>,
325 cx: &mut Context<'_>,
326 buf: &mut ReadBuf<'_>,
327 ) -> Poll<io::Result<()>> {
328 if self.fused_error {
329 return Poll::Ready(Err(io::Error::new(
330 ErrorKind::Unsupported,
331 "called read after error",
332 )));
333 }
334 if matches!(self.proxy_mode, ProxyMode::None) {
335 return self
336 .inner_read
337 .as_mut()
338 .unwrap()
339 .as_mut()
340 .poll_read(cx, buf);
341 }
342 assert!(buf.remaining() >= PROXY_SIGNATURE.len());
343
344 if self.pending_read_proxy.is_none() {
345 self.pending_read_proxy = Some(Box::pin(read_proxy(self.inner_read.take().unwrap())));
346 }
347 let output = self.pending_read_proxy.as_mut().unwrap().as_mut().poll(cx);
348 match output {
349 Poll::Ready(Err(e)) => {
350 self.fused_error = true;
351 self.pending_read_proxy = None;
352 Poll::Ready(Err(e))
353 }
354 Poll::Ready(Ok((ProxyResult::SignatureBytes(bytes), stream))) => {
355 if matches!(self.proxy_mode, ProxyMode::Require) {
356 return Poll::Ready(Err(io::Error::new(
357 ErrorKind::InvalidData,
358 "required a PROXYv2 header, none found",
359 )));
360 }
361 self.proxy_mode = ProxyMode::None;
362 buf.put_slice(&bytes[..]);
363 self.pending_read_proxy = None;
364 self.inner_read = Some(stream);
365 #[cfg(feature = "tonic")]
366 {
367 *self.connect_info.write().unwrap() = Some(self.source());
368 }
369 self.inner_read
370 .as_mut()
371 .unwrap()
372 .as_mut()
373 .poll_read(cx, buf)
374 }
375 Poll::Ready(Ok((ProxyResult::Proxy(info), stream))) => {
376 self.proxy_mode = ProxyMode::None;
377 self.info = Some(info);
378 self.pending_read_proxy = None;
379 self.inner_read = Some(stream);
380 #[cfg(feature = "tonic")]
381 {
382 *self.connect_info.write().unwrap() = Some(self.source());
383 }
384 self.inner_read
385 .as_mut()
386 .unwrap()
387 .as_mut()
388 .poll_read(cx, buf)
389 }
390 Poll::Pending => Poll::Pending,
391 }
392 }
393}
394
395impl WrappedStream {
396 pub fn was_proxied(&self) -> bool {
398 self.info.is_some()
399 }
400
401 pub fn command(&self) -> Option<Command> {
403 self.info.as_ref().map(|x| x.command)
404 }
405
406 pub fn family(&self) -> Option<Family> {
408 self.info.as_ref().map(|x| x.family)
409 }
410
411 pub fn protocol(&self) -> Option<Protocol> {
413 self.info.as_ref().map(|x| x.protocol)
414 }
415
416 pub fn destination(&self) -> Option<SocketAddr> {
418 self.info.as_ref().map(|x| x.discovered_dest).flatten()
419 }
420
421 pub fn source(&self) -> SocketAddr {
423 self.info
424 .as_ref()
425 .map(|x| x.discovered_src)
426 .flatten()
427 .unwrap_or_else(|| self.remote_addr)
428 }
429
430 pub fn original_source(&self) -> SocketAddr {
432 self.remote_addr
433 }
434}
435
436impl AsyncWrite for WrappedStream {
437 #[inline]
438 fn poll_write(
439 mut self: Pin<&mut Self>,
440 cx: &mut Context<'_>,
441 buf: &[u8],
442 ) -> Poll<io::Result<usize>> {
443 self.inner_write.as_mut().poll_write(cx, buf)
444 }
445
446 #[inline]
447 fn poll_write_vectored(
448 mut self: Pin<&mut Self>,
449 cx: &mut Context<'_>,
450 bufs: &[io::IoSlice<'_>],
451 ) -> Poll<io::Result<usize>> {
452 self.inner_write.as_mut().poll_write_vectored(cx, bufs)
453 }
454
455 #[inline]
456 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
457 self.inner_write.as_mut().poll_flush(cx)
458 }
459
460 #[inline]
461 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
462 self.inner_write.as_mut().poll_shutdown(cx)
463 }
464
465 #[inline]
466 fn is_write_vectored(&self) -> bool {
467 self.inner_write.is_write_vectored()
468 }
469}
470
#[cfg(feature = "track_conn_count")]
impl Drop for WrappedStream {
    /// Decrements the shared connection counter.
    ///
    /// NOTE(review): presumably balanced by an increment where the stream is created
    /// (not shown here).
    fn drop(&mut self) {
        let counter: &AtomicU64 = &self.conn_count;
        counter.fetch_sub(1, Ordering::SeqCst);
    }
}
477
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a raw PROXY v2 packet: signature, version/command byte,
    /// family/protocol byte, big-endian length, then `body`.
    fn packet(ver_cmd: u8, fam_proto: u8, body: &[u8]) -> Vec<u8> {
        let mut raw = PROXY_SIGNATURE.to_vec();
        raw.push(ver_cmd);
        raw.push(fam_proto);
        raw.extend_from_slice(&(body.len() as u16).to_be_bytes());
        raw.extend_from_slice(body);
        raw
    }

    #[tokio::test]
    async fn test_parse() {
        let raw = hex::decode("0d0a0d0a000d0a515549540a21110054ffffffffac1f1cd1898801bb030004508978bb04003e0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000").unwrap();
        assert_eq!(
            read_proxy(&raw[..]).await.unwrap().0,
            ProxyResult::Proxy(ProxyInfo {
                command: Command::Proxy,
                family: Family::Ipv4,
                protocol: Protocol::Stream,
                discovered_dest: Some("172.31.28.209:443".parse().unwrap()),
                discovered_src: Some("255.255.255.255:35208".parse().unwrap()),
            })
        );
    }

    #[tokio::test]
    async fn test_parse_ipv6() {
        let src: std::net::Ipv6Addr = "::1".parse().unwrap();
        let dest: std::net::Ipv6Addr = "2001:db8::1".parse().unwrap();
        let mut body = Vec::new();
        body.extend_from_slice(&src.octets());
        body.extend_from_slice(&dest.octets());
        body.extend_from_slice(&1234u16.to_be_bytes());
        body.extend_from_slice(&443u16.to_be_bytes());
        // 0x21: version 2 / PROXY command; 0x21: IPv6 family / stream protocol.
        let raw = packet(0x21, 0x21, &body);
        assert_eq!(
            read_proxy(&raw[..]).await.unwrap().0,
            ProxyResult::Proxy(ProxyInfo {
                command: Command::Proxy,
                family: Family::Ipv6,
                protocol: Protocol::Stream,
                discovered_dest: Some("[2001:db8::1]:443".parse().unwrap()),
                discovered_src: Some("[::1]:1234".parse().unwrap()),
            })
        );
    }

    #[tokio::test]
    async fn test_non_proxy_returns_signature_bytes() {
        let input = b"GET / HTTP/1.1\r\n\r\n";
        let (result, rest) = read_proxy(&input[..]).await.unwrap();
        assert_eq!(result, ProxyResult::SignatureBytes(*b"GET / HTTP/1"));
        // Only the probed bytes are consumed; the remainder is left for the caller.
        assert_eq!(rest, &b".1\r\n\r\n"[..]);
    }

    #[tokio::test]
    async fn test_invalid_version_is_rejected() {
        // 0x11: version 1 (unsupported) / PROXY command.
        let raw = packet(0x11, 0x11, &[0u8; 12]);
        let err = read_proxy(&raw[..]).await.unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }
}