// hyperproxy/wrapped_incoming.rs

use std::net::SocketAddr;
use std::pin::Pin;
#[cfg(feature = "track_conn_count")]
use std::sync::atomic::{AtomicU64, Ordering};
#[cfg(feature = "track_conn_count")]
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;

use futures::Stream;
use hyper::server::accept::Accept;
use hyper::server::conn::AddrIncoming;

use crate::{ProxyMode, WrappedStream};

16pub struct WrappedIncoming {
17    inner: AddrIncoming,
18    #[cfg(feature = "track_conn_count")]
19    conn_count: Arc<AtomicU64>,
20    proxy_mode: ProxyMode,
21}
22
23impl WrappedIncoming {
24    pub fn new(
25        addr: SocketAddr,
26        nodelay: bool,
27        keepalive: Option<Duration>,
28        proxy_mode: ProxyMode,
29    ) -> hyper::Result<Self> {
30        let mut inner = AddrIncoming::bind(&addr)?;
31        inner.set_nodelay(nodelay);
32        inner.set_keepalive(keepalive);
33        Ok(WrappedIncoming {
34            inner,
35            #[cfg(feature = "track_conn_count")]
36            conn_count: Arc::new(AtomicU64::new(0)),
37            proxy_mode,
38        })
39    }
40
41    #[cfg(feature = "track_conn_count")]
42    pub fn get_conn_count(&self) -> Arc<AtomicU64> {
43        self.conn_count.clone()
44    }
45}
46
47impl Accept for WrappedIncoming {
48    type Conn = WrappedStream;
49
50    type Error = std::io::Error;
51
52    fn poll_accept(
53        self: Pin<&mut Self>,
54        cx: &mut Context<'_>,
55    ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
56        self.poll_next(cx)
57    }
58}
59
60impl Stream for WrappedIncoming {
61    type Item = Result<WrappedStream, std::io::Error>;
62
63    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
64        match Pin::new(&mut self.inner).poll_accept(cx) {
65            Poll::Ready(Some(Ok(stream))) => {
66                #[cfg(feature = "track_conn_count")]
67                self.conn_count.fetch_add(1, Ordering::SeqCst);
68                let remote_addr = stream.remote_addr();
69                let (read, write) = stream.into_inner().into_split();
70                Poll::Ready(Some(Ok(WrappedStream {
71                    remote_addr,
72                    inner_read: Some(Box::pin(read)),
73                    inner_write: Box::pin(write),
74                    #[cfg(feature = "track_conn_count")]
75                    conn_count: self.conn_count.clone(),
76                    #[cfg(feature = "tonic")]
77                    connect_info: Default::default(),
78                    pending_read_proxy: None,
79                    fused_error: false,
80                    info: None,
81                    proxy_mode: self.proxy_mode,
82                })))
83            }
84            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
85            Poll::Ready(None) => Poll::Ready(None),
86            Poll::Pending => Poll::Pending,
87        }
88    }
89}