1use crate::{
2 command::AgentServerCommand, config::AgentServerConfig, error::AgentServerError,
3 event::AgentServerEvent, publish_server_event,
4};
5use crate::{
6 crypto::AgentServerRsaCryptoFetcher,
7 proxy::ProxyConnectionFactory,
8 tunnel::dispatcher::{ClientDispatcher, Tunnel},
9};
10
11use std::{net::SocketAddr, sync::atomic::AtomicBool};
12use std::{
13 sync::{
14 atomic::{AtomicU64, Ordering::Relaxed},
15 Arc,
16 },
17 time::Duration,
18};
19
20use anyhow::Result;
21use ppaass_protocol::message::values::address::PpaassUnifiedAddress;
22use tokio::sync::mpsc::{channel, Receiver, Sender};
23
24use tokio::{net::TcpListener, time::interval};
25
26use tokio_io_timeout::TimeoutStream;
27use tokio_tfo::{TfoListener, TfoStream};
28use tracing::{debug, error, info};
29
/// Capacity of the bounded channel that carries server events out of `AgentServer::start`.
const AGENT_SEVER_EVENT_CHANNEL_BUF: usize = 1 << 10;
/// Capacity of the bounded channel that carries commands into the running server task.
const AGENT_SEVER_COMMAND_CHANNEL_BUF: usize = 1 << 10;
/// Number of bytes in one mebibyte; used to convert byte counters into MB figures.
const ONE_MB: u64 = 1 << 20;
33
/// The agent server: owns the configuration and the dispatcher needed to accept client
/// connections and hand each of them to a tunnel.
pub struct AgentServer {
    /// Shared server configuration (listening port, IPv6 flag, timeouts, tick interval).
    config: Arc<AgentServerConfig>,
    /// Dispatcher that turns an accepted client connection into a `Tunnel`.
    client_dispatcher: ClientDispatcher<AgentServerRsaCryptoFetcher>,
}
38
39impl AgentServer {
40 pub fn new(config: Arc<AgentServerConfig>) -> Result<Self> {
41 let rsa_crypto_fetcher = AgentServerRsaCryptoFetcher::new(&config)?;
42 let proxy_connection_factory =
43 ProxyConnectionFactory::new(config.clone(), rsa_crypto_fetcher)?;
44 let client_dispatcher = ClientDispatcher::new(config.clone(), proxy_connection_factory);
45 Ok(Self {
46 config,
47 client_dispatcher,
48 })
49 }
50
51 pub fn start(self) -> (Sender<AgentServerCommand>, Receiver<AgentServerEvent>) {
52 let (server_event_tx, server_event_rx) = channel(AGENT_SEVER_EVENT_CHANNEL_BUF);
53 let (server_command_tx, mut server_command_rx) = channel(AGENT_SEVER_COMMAND_CHANNEL_BUF);
54 let config = self.config;
55 let client_dispatcher = self.client_dispatcher;
56 let upload_bytes_amount: Arc<AtomicU64> = Default::default();
57 let download_bytes_amount: Arc<AtomicU64> = Default::default();
58 let stopped_status: Arc<AtomicBool> = Default::default();
59 tokio::spawn(async move {
60 let agent_server_bind_addr = if config.ipv6() {
61 format!("::1:{}", config.port())
62 } else {
63 format!("0.0.0.0:{}", config.port())
64 };
65 info!("Agent server start to serve request on address: {agent_server_bind_addr}.");
66 let tcp_listener = match TcpListener::bind(&agent_server_bind_addr).await {
67 Ok(tcp_listener) => tcp_listener,
68 Err(e) => {
69 error!(
70 "Fail to listen tcp port {} because of error: {e:?}",
71 config.port()
72 );
73 publish_server_event(
74 &server_event_tx,
75 AgentServerEvent::ServerStartFail {
76 listening_port: config.port(),
77 reason: format!("Fail to listen tcp port: {}", config.port()),
78 },
79 )
80 .await;
81 return;
82 }
83 };
84 let tcp_listener = match TfoListener::from_tokio(tcp_listener) {
85 Ok(tcp_listener) => tcp_listener,
86 Err(e) => {
87 error!("Fail to use fast open on tcp listener because of error: {e:?}");
88 return;
89 }
90 };
91
92 publish_server_event(
93 &server_event_tx,
94 AgentServerEvent::ServerStartSuccess(config.port()),
95 )
96 .await;
97
98 let upload_bytes_amount = upload_bytes_amount.clone();
100 let download_bytes_amount = download_bytes_amount.clone();
101 let server_event_tx = server_event_tx.clone();
102 let server_event_tick_interval_val = config.server_signal_tick_interval();
103 let mb_per_second_div_base = (server_event_tick_interval_val * ONE_MB) as f64;
104 let mut server_event_tick_interval =
105 interval(Duration::from_secs(server_event_tick_interval_val));
106
107 loop {
108 let upload_bytes_amount_pre_val = upload_bytes_amount.fetch_add(0, Relaxed);
109 let download_bytes_amount_pre_val = download_bytes_amount.fetch_add(0, Relaxed);
110 tokio::select! {
111 server_command = server_command_rx.recv() => {
113 match server_command {
114 Some(server_command) => {
115 match server_command {
116 AgentServerCommand::Stop => {
117 info!("Agent server stopped because of receive stop command.");
118 stopped_status.swap(true, Relaxed);
119 publish_server_event(&server_event_tx, AgentServerEvent::ServerStopSuccess).await;
120 return;
121 },
122 }
123 },
124 None => {
125 info!("Agent server stopped because of no command tx.");
126 publish_server_event(&server_event_tx, AgentServerEvent::ServerStopSuccess).await;
127 return;
128 },
129 }
130 }
131 _ = server_event_tick_interval.tick() => {
133 let upload_bytes_amount_current_val =
134 upload_bytes_amount.fetch_add(0, Relaxed);
135 let download_bytes_amount_current_val =
136 download_bytes_amount.fetch_add(0, Relaxed);
137
138 let upload_mb_per_second =
139 (upload_bytes_amount_current_val - upload_bytes_amount_pre_val) as f64
140 / mb_per_second_div_base;
141
142 let download_mb_per_second = (download_bytes_amount_current_val
143 - download_bytes_amount_pre_val)
144 as f64
145 / mb_per_second_div_base;
146
147 publish_server_event(
148 &server_event_tx,
149 AgentServerEvent::NetworkState {
150 upload_mb_amount: upload_bytes_amount_current_val as f64
151 / ONE_MB as f64,
152 upload_mb_per_second,
153 download_mb_amount: download_bytes_amount_current_val as f64
154 / ONE_MB as f64,
155 download_mb_per_second,
156 },
157 )
158 .await;
159 }
160 client_accept_result = Self::accept_client_connection(&config, &tcp_listener) => {
162 match client_accept_result{
163 Ok((client_tcp_stream, client_socket_address)) => {
164 debug!("Accept client tcp connection on address: {client_socket_address}");
165 Self::handle_client_connection(
166 client_tcp_stream,
167 client_socket_address.into(),
168 client_dispatcher.clone(),
169 server_event_tx.clone(),
170 upload_bytes_amount.clone(),
171 download_bytes_amount.clone(),
172 stopped_status.clone()
173 );
174 }
175 Err(e) => {
176 error!("Agent server fail to accept client connection because of error: {e:?}");
177 continue;
178 }
179 }
180 }
181 }
182 }
183 });
184 (server_command_tx, server_event_rx)
185 }
186
187 async fn accept_client_connection(
188 config: &AgentServerConfig,
189 tcp_listener: &TfoListener,
190 ) -> Result<(TimeoutStream<TfoStream>, SocketAddr), AgentServerError> {
191 let (client_tcp_stream, client_socket_address) = tcp_listener.accept().await?;
192 client_tcp_stream.set_nodelay(true)?;
193 let mut client_tcp_stream = TimeoutStream::new(client_tcp_stream);
194 client_tcp_stream.set_read_timeout(Some(Duration::from_secs(
195 config.client_connection_read_timeout(),
196 )));
197 client_tcp_stream.set_write_timeout(Some(Duration::from_secs(
198 config.client_connection_write_timeout(),
199 )));
200 Ok((client_tcp_stream, client_socket_address))
201 }
202
203 fn handle_client_connection(
204 client_tcp_stream: TimeoutStream<TfoStream>,
205 client_socket_address: PpaassUnifiedAddress,
206 client_dispatcher: ClientDispatcher<AgentServerRsaCryptoFetcher>,
207 server_event_tx: Sender<AgentServerEvent>,
208 upload_bytes_amount: Arc<AtomicU64>,
209 download_bytes_amount: Arc<AtomicU64>,
210 stopped_status: Arc<AtomicBool>,
211 ) {
212 tokio::spawn(async move {
213 let tunnel = match client_dispatcher
214 .dispatch(
215 client_tcp_stream,
216 &client_socket_address,
217 &server_event_tx,
218 upload_bytes_amount,
219 download_bytes_amount,
220 )
221 .await
222 {
223 Ok(tunnel) => tunnel,
224 Err(e) => {
225 error!("Fail to dispatch client connection [{client_socket_address}] to tunnel because of error: {e:?}");
226 return;
227 }
228 };
229
230 match tunnel {
231 Tunnel::Socks5(tunnel) => {
232 if let Err(e) = tunnel.process(&server_event_tx, stopped_status).await {
233 error!("Fail to process socks5 tunnel for client connection [{client_socket_address}] because of error: {e:?}");
234 }
235 }
236 Tunnel::Http(tunnel) => {
237 if let Err(e) = tunnel.process(&server_event_tx, stopped_status).await {
238 error!("Fail to process http tunnel for client connection [{client_socket_address}] because of error: {e:?}");
239 }
240 }
241 }
242 });
243 }
244}