// tcp_relay/relay.rs

1use net_relay::{Builder, Error};
2use std::net::SocketAddr;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::task::{Context, Poll};
6use std::time::Duration;
7use tcp_pool::net_pool::{Pool, debug, info, instrument_debug_span, tokio_spawn, warn2};
8
9/// tcp relay
10pub struct Relay<F, S, P = tcp_pool::Pool>
11where
12    F: Fn(Arc<P>, tokio::net::TcpStream, SocketAddr) -> S,
13    S: Future<Output = ()>,
14    P: tcp_pool::TcpPool + Pool,
15{
16    parts: net_relay::builder::Parts<P, F>,
17    pending: Option<Pin<Box<dyn Future<Output = Result<(), net_relay::Error>> + Send + 'static>>>,
18}
19
20impl<F, S, P> Relay<F, S, P>
21where
22    F: Fn(Arc<P>, tokio::net::TcpStream, SocketAddr) -> S,
23    S: Future<Output = ()>,
24    P: tcp_pool::TcpPool + Pool,
25{
26    pub fn build<B: FnOnce(Builder<P, F>) -> Builder<P, F>>(b: B) -> Result<Self, Error> {
27        let builder = Builder::new();
28        let parts = b(builder).build()?;
29        Ok(Relay {
30            parts,
31            pending: None,
32        })
33    }
34
35    pub fn bind_addrs(&self) -> &Vec<SocketAddr> {
36        &self.parts.bind_addrs
37    }
38
39    pub fn relay_fn(&self) -> Arc<F> {
40        self.parts.relay_fn.as_ref().unwrap().clone()
41    }
42
43    pub fn pool(&self) -> Arc<P> {
44        self.parts.pools[0].clone()
45    }
46
47    /// 设置最大连接数
48    pub fn set_max_conn(&self, max: Option<usize>) {
49        self.pool().set_max_conn(max)
50    }
51
52    /// 设置空闲连接保留时长
53    pub fn set_keepalive(&self, duration: Option<Duration>) {
54        self.pool().set_keepalive(duration)
55    }
56}
57
58impl<F, S, P> net_relay::Relay for Relay<F, S, P>
59where
60    F: Fn(Arc<P>, tokio::net::TcpStream, SocketAddr) -> S + Send + Sync + 'static,
61    S: Future<Output = ()> + Send + 'static,
62    P: tcp_pool::TcpPool + Pool + Send + 'static,
63{
64    fn poll_run(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
65        if self.pending.is_none() {
66            let tuple = (self.bind_addrs().clone(), self.pool(), self.relay_fn());
67
68            self.pending = Some(Box::pin(async move {
69                let listener = tokio::net::TcpListener::bind(tuple.0.as_slice()).await?;
70
71                info!(
72                    "[Tcp Relay] listen on: {:?}",
73                    listener.local_addr().unwrap()
74                );
75
76                loop {
77                    match listener.accept().await {
78                        Ok((client, addr)) => {
79                            let tuple = (tuple.1.clone(), tuple.2.clone());
80                            tokio_spawn! {
81                                instrument_debug_span! {
82                                    async move {
83                                        debug!("[Tcp Relay] connection accepted");
84                                        let res = tuple.1(tuple.0, client, addr).await;
85                                        debug!("[Tcp Relay] connection closed");
86                                        res
87                                    },
88                                    "new_tcp_stream",
89                                    address=addr.to_string()
90                                }
91                            };
92                        }
93                        Err(_e) => {
94                            warn2!("[Tcp Relay] accept from listen, error occurred: {:?}", _e);
95                        }
96                    }
97                }
98            }));
99        }
100
101        self.pending.as_mut().unwrap().as_mut().poll(cx)
102    }
103}
104
105pub async fn default_relay_fn<P: Pool + tcp_pool::TcpPool + Send>(
106    pool: Arc<P>,
107    mut client: tokio::net::TcpStream,
108    addr: SocketAddr,
109) {
110    let id = tcp_pool::net_pool::utils::socketaddr_to_hash_code(&addr);
111    let mut proxy = match pool.clone().get(&id.to_string()).await {
112        Err(_e) => {
113            warn2!(
114                "[Tcp Relay] get tcp stream from pool, error occurred: {:?}",
115                _e
116            );
117            return;
118        }
119        Ok(t) => t,
120    };
121
122    let proxy_mut: &mut tokio::net::TcpStream = &mut proxy;
123    let _cnt = tokio::io::copy_bidirectional(proxy_mut, &mut client).await;
124
125    debug!("[Tcp Relay] data exchange byte count: {:?}", _cnt);
126}