ppaass_agent/
server.rs

1use crate::{
2    command::AgentServerCommand, config::AgentServerConfig, error::AgentServerError,
3    event::AgentServerEvent, publish_server_event,
4};
5use crate::{
6    crypto::AgentServerRsaCryptoFetcher,
7    proxy::ProxyConnectionFactory,
8    tunnel::dispatcher::{ClientDispatcher, Tunnel},
9};
10
11use std::{net::SocketAddr, sync::atomic::AtomicBool};
12use std::{
13    sync::{
14        atomic::{AtomicU64, Ordering::Relaxed},
15        Arc,
16    },
17    time::Duration,
18};
19
20use anyhow::Result;
21use ppaass_protocol::message::values::address::PpaassUnifiedAddress;
22use tokio::sync::mpsc::{channel, Receiver, Sender};
23
24use tokio::{net::TcpListener, time::interval};
25
26use tokio_io_timeout::TimeoutStream;
27use tokio_tfo::{TfoListener, TfoStream};
28use tracing::{debug, error, info};
29
/// Capacity of the bounded channel carrying server events out of the server task.
const AGENT_SEVER_EVENT_CHANNEL_BUF: usize = 1 << 10;
/// Capacity of the bounded channel carrying commands into the server task.
const AGENT_SEVER_COMMAND_CHANNEL_BUF: usize = 1 << 10;
/// Bytes per megabyte (2^20); divisor used to turn byte counters into MB figures.
const ONE_MB: u64 = 1 << 20;
33
34pub struct AgentServer {
35    config: Arc<AgentServerConfig>,
36    client_dispatcher: ClientDispatcher<AgentServerRsaCryptoFetcher>,
37}
38
39impl AgentServer {
40    pub fn new(config: Arc<AgentServerConfig>) -> Result<Self> {
41        let rsa_crypto_fetcher = AgentServerRsaCryptoFetcher::new(&config)?;
42        let proxy_connection_factory =
43            ProxyConnectionFactory::new(config.clone(), rsa_crypto_fetcher)?;
44        let client_dispatcher = ClientDispatcher::new(config.clone(), proxy_connection_factory);
45        Ok(Self {
46            config,
47            client_dispatcher,
48        })
49    }
50
51    pub fn start(self) -> (Sender<AgentServerCommand>, Receiver<AgentServerEvent>) {
52        let (server_event_tx, server_event_rx) = channel(AGENT_SEVER_EVENT_CHANNEL_BUF);
53        let (server_command_tx, mut server_command_rx) = channel(AGENT_SEVER_COMMAND_CHANNEL_BUF);
54        let config = self.config;
55        let client_dispatcher = self.client_dispatcher;
56        let upload_bytes_amount: Arc<AtomicU64> = Default::default();
57        let download_bytes_amount: Arc<AtomicU64> = Default::default();
58        let stopped_status: Arc<AtomicBool> = Default::default();
59        tokio::spawn(async move {
60            let agent_server_bind_addr = if config.ipv6() {
61                format!("::1:{}", config.port())
62            } else {
63                format!("0.0.0.0:{}", config.port())
64            };
65            info!("Agent server start to serve request on address: {agent_server_bind_addr}.");
66            let tcp_listener = match TcpListener::bind(&agent_server_bind_addr).await {
67                Ok(tcp_listener) => tcp_listener,
68                Err(e) => {
69                    error!(
70                        "Fail to listen tcp port {} because of error: {e:?}",
71                        config.port()
72                    );
73                    publish_server_event(
74                        &server_event_tx,
75                        AgentServerEvent::ServerStartFail {
76                            listening_port: config.port(),
77                            reason: format!("Fail to listen tcp port: {}", config.port()),
78                        },
79                    )
80                    .await;
81                    return;
82                }
83            };
84            let tcp_listener = match TfoListener::from_tokio(tcp_listener) {
85                Ok(tcp_listener) => tcp_listener,
86                Err(e) => {
87                    error!("Fail to use fast open on tcp listener because of error: {e:?}");
88                    return;
89                }
90            };
91
92            publish_server_event(
93                &server_event_tx,
94                AgentServerEvent::ServerStartSuccess(config.port()),
95            )
96            .await;
97
98            // Start the network state event task
99            let upload_bytes_amount = upload_bytes_amount.clone();
100            let download_bytes_amount = download_bytes_amount.clone();
101            let server_event_tx = server_event_tx.clone();
102            let server_event_tick_interval_val = config.server_signal_tick_interval();
103            let mb_per_second_div_base = (server_event_tick_interval_val * ONE_MB) as f64;
104            let mut server_event_tick_interval =
105                interval(Duration::from_secs(server_event_tick_interval_val));
106
107            loop {
108                let upload_bytes_amount_pre_val = upload_bytes_amount.fetch_add(0, Relaxed);
109                let download_bytes_amount_pre_val = download_bytes_amount.fetch_add(0, Relaxed);
110                tokio::select! {
111                    // Listening to server command
112                    server_command = server_command_rx.recv() => {
113                        match server_command {
114                            Some(server_command) => {
115                                match server_command {
116                                    AgentServerCommand::Stop => {
117                                        info!("Agent server stopped because of receive stop command.");
118                                        stopped_status.swap(true, Relaxed);
119                                        publish_server_event(&server_event_tx, AgentServerEvent::ServerStopSuccess).await;
120                                        return;
121                                    },
122                                }
123                            },
124                            None => {
125                                info!("Agent server stopped because of no command tx.");
126                                publish_server_event(&server_event_tx, AgentServerEvent::ServerStopSuccess).await;
127                                return;
128                            },
129                        }
130                    }
131                    // Send network
132                    _ = server_event_tick_interval.tick() => {
133                        let upload_bytes_amount_current_val =
134                            upload_bytes_amount.fetch_add(0, Relaxed);
135                        let download_bytes_amount_current_val =
136                            download_bytes_amount.fetch_add(0, Relaxed);
137
138                        let upload_mb_per_second =
139                            (upload_bytes_amount_current_val - upload_bytes_amount_pre_val) as f64
140                                / mb_per_second_div_base;
141
142                        let download_mb_per_second = (download_bytes_amount_current_val
143                            - download_bytes_amount_pre_val)
144                            as f64
145                            / mb_per_second_div_base;
146
147                        publish_server_event(
148                            &server_event_tx,
149                            AgentServerEvent::NetworkState {
150                                upload_mb_amount: upload_bytes_amount_current_val as f64
151                                    / ONE_MB as f64,
152                                upload_mb_per_second,
153                                download_mb_amount: download_bytes_amount_current_val as f64
154                                    / ONE_MB as f64,
155                                download_mb_per_second,
156                            },
157                        )
158                        .await;
159                    }
160                    // Accepting client connection
161                    client_accept_result = Self::accept_client_connection(&config, &tcp_listener) => {
162                        match client_accept_result{
163                            Ok((client_tcp_stream, client_socket_address)) => {
164                                debug!("Accept client tcp connection on address: {client_socket_address}");
165                                Self::handle_client_connection(
166                                    client_tcp_stream,
167                                    client_socket_address.into(),
168                                    client_dispatcher.clone(),
169                                    server_event_tx.clone(),
170                                    upload_bytes_amount.clone(),
171                                    download_bytes_amount.clone(),
172                                    stopped_status.clone()
173                                );
174                            }
175                            Err(e) => {
176                                error!("Agent server fail to accept client connection because of error: {e:?}");
177                                continue;
178                            }
179                        }
180                    }
181                }
182            }
183        });
184        (server_command_tx, server_event_rx)
185    }
186
187    async fn accept_client_connection(
188        config: &AgentServerConfig,
189        tcp_listener: &TfoListener,
190    ) -> Result<(TimeoutStream<TfoStream>, SocketAddr), AgentServerError> {
191        let (client_tcp_stream, client_socket_address) = tcp_listener.accept().await?;
192        client_tcp_stream.set_nodelay(true)?;
193        let mut client_tcp_stream = TimeoutStream::new(client_tcp_stream);
194        client_tcp_stream.set_read_timeout(Some(Duration::from_secs(
195            config.client_connection_read_timeout(),
196        )));
197        client_tcp_stream.set_write_timeout(Some(Duration::from_secs(
198            config.client_connection_write_timeout(),
199        )));
200        Ok((client_tcp_stream, client_socket_address))
201    }
202
203    fn handle_client_connection(
204        client_tcp_stream: TimeoutStream<TfoStream>,
205        client_socket_address: PpaassUnifiedAddress,
206        client_dispatcher: ClientDispatcher<AgentServerRsaCryptoFetcher>,
207        server_event_tx: Sender<AgentServerEvent>,
208        upload_bytes_amount: Arc<AtomicU64>,
209        download_bytes_amount: Arc<AtomicU64>,
210        stopped_status: Arc<AtomicBool>,
211    ) {
212        tokio::spawn(async move {
213            let tunnel = match client_dispatcher
214                .dispatch(
215                    client_tcp_stream,
216                    &client_socket_address,
217                    &server_event_tx,
218                    upload_bytes_amount,
219                    download_bytes_amount,
220                )
221                .await
222            {
223                Ok(tunnel) => tunnel,
224                Err(e) => {
225                    error!("Fail to dispatch client connection [{client_socket_address}] to tunnel because of error: {e:?}");
226                    return;
227                }
228            };
229
230            match tunnel {
231                Tunnel::Socks5(tunnel) => {
232                    if let Err(e) = tunnel.process(&server_event_tx, stopped_status).await {
233                        error!("Fail to process socks5 tunnel for client connection [{client_socket_address}] because of error: {e:?}");
234                    }
235                }
236                Tunnel::Http(tunnel) => {
237                    if let Err(e) = tunnel.process(&server_event_tx, stopped_status).await {
238                        error!("Fail to process http tunnel for client connection [{client_socket_address}] because of error: {e:?}");
239                    }
240                }
241            }
242        });
243    }
244}