1use std::net::SocketAddr;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::time::Duration;
7
8use tokio::sync::{OwnedSemaphorePermit, Semaphore};
9use tokio::time::Instant;
10use tokio_rustls::TlsAcceptor;
11use tokio_util::sync::CancellationToken;
12use tracing::{Instrument, debug, info, info_span, warn};
13
14use crate::error::ServerError;
15use crate::handler::handle_conn;
16use crate::pool::ConnectionPool;
17use crate::rate_limit::RateLimiter;
18use crate::resolve::resolve_sockaddr;
19use crate::state::ServerState;
20use crate::tls::load_tls_config;
21use crate::util::{ConnectionGuard, ConnectionTracker, create_listener};
22use trojan_auth::AuthBackend;
23use trojan_config::Config;
24use trojan_core::defaults;
25use trojan_metrics::{
26 ERROR_TLS_HANDSHAKE, record_connection_accepted, record_connection_closed,
27 record_connection_rejected, record_error, record_tls_handshake_duration,
28 set_connection_queue_depth,
29};
30
/// Upper bound on how long `run_with_shutdown` waits for in-flight connections
/// to finish once its accept loop has stopped (30 seconds).
pub const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(30);
33
/// Process-wide counter backing [`next_conn_id`]; the first id handed out is 1.
static CONN_ID: AtomicU64 = AtomicU64::new(1);

/// Hands out the next connection identifier.
///
/// Each call returns the current counter value and advances it by one, so
/// concurrent callers never observe the same id (until the `u64` wraps).
/// In this file the id is only attached to log events and tracing spans.
#[inline]
fn next_conn_id() -> u64 {
    // `Relaxed` is enough: the atomic read-modify-write alone guarantees
    // distinct values, and no other memory is published through this counter.
    CONN_ID.fetch_add(1, Ordering::Relaxed)
}
42
43pub async fn run_with_shutdown(
45 config: Config,
46 auth: impl AuthBackend + 'static,
47 shutdown: CancellationToken,
48) -> Result<(), ServerError> {
49 let tls_config = load_tls_config(&config.tls)?;
50 let acceptor = TlsAcceptor::from(Arc::new(tls_config));
51
52 let listen: SocketAddr = config
53 .server
54 .listen
55 .parse()
56 .map_err(|_| ServerError::Config("invalid listen address".into()))?;
57
58 let fallback_addr = resolve_sockaddr(&config.server.fallback).await?;
59
60 let fallback_pool: Option<Arc<ConnectionPool>> =
62 config.server.fallback_pool.as_ref().map(|pool_cfg| {
63 info!(
64 max_idle = pool_cfg.max_idle,
65 max_age_secs = pool_cfg.max_age_secs,
66 fill_batch = pool_cfg.fill_batch,
67 fill_delay_ms = pool_cfg.fill_delay_ms,
68 "fallback connection pool enabled"
69 );
70 let pool = Arc::new(ConnectionPool::new(
71 fallback_addr,
72 pool_cfg.max_idle,
73 pool_cfg.max_age_secs,
74 pool_cfg.fill_batch,
75 pool_cfg.fill_delay_ms,
76 ));
77 pool.start_cleanup_task(Duration::from_secs(pool_cfg.max_age_secs));
79 pool
80 });
81
82 let (relay_buffer_size, tcp_send_buffer, tcp_recv_buffer, connection_backlog) =
84 match &config.server.resource_limits {
85 Some(rl) => {
86 info!(
87 relay_buffer = rl.relay_buffer_size,
88 tcp_send_buffer = rl.tcp_send_buffer,
89 tcp_recv_buffer = rl.tcp_recv_buffer,
90 connection_backlog = rl.connection_backlog,
91 "resource limits configured"
92 );
93 (
94 rl.relay_buffer_size,
95 rl.tcp_send_buffer,
96 rl.tcp_recv_buffer,
97 rl.connection_backlog,
98 )
99 }
100 None => (
101 defaults::DEFAULT_RELAY_BUFFER_SIZE,
102 defaults::DEFAULT_TCP_SEND_BUFFER,
103 defaults::DEFAULT_TCP_RECV_BUFFER,
104 defaults::DEFAULT_CONNECTION_BACKLOG,
105 ),
106 };
107
108 #[cfg(feature = "analytics")]
110 let analytics = if config.analytics.enabled {
111 match trojan_analytics::init(config.analytics.clone()).await {
112 Ok(collector) => {
113 info!("analytics enabled, sending to ClickHouse");
114 Some(collector)
115 }
116 Err(e) => {
117 warn!("failed to init analytics: {}, disabled", e);
118 None
119 }
120 }
121 } else {
122 debug!("analytics disabled in config");
123 None
124 };
125
126 let state = Arc::new(ServerState {
127 fallback_addr,
128 max_udp_payload: config.server.max_udp_payload,
129 max_udp_buffer_bytes: config.server.max_udp_buffer_bytes,
130 max_header_bytes: config.server.max_header_bytes,
131 tcp_idle_timeout: Duration::from_secs(config.server.tcp_idle_timeout_secs),
132 udp_idle_timeout: Duration::from_secs(config.server.udp_timeout_secs),
133 fallback_pool,
134 relay_buffer_size,
135 tcp_send_buffer,
136 tcp_recv_buffer,
137 websocket: config.websocket.clone(),
138 #[cfg(feature = "analytics")]
139 analytics,
140 });
141 let auth = Arc::new(auth);
142 let tracker = ConnectionTracker::new();
143
144 let conn_limit: Option<Arc<Semaphore>> = config.server.max_connections.map(|n| {
146 info!("max_connections set to {}", n);
147 Arc::new(Semaphore::new(n))
148 });
149
150 let rate_limiter: Option<Arc<RateLimiter>> = config.server.rate_limit.as_ref().map(|rl| {
152 info!(
153 max_per_ip = rl.max_connections_per_ip,
154 window_secs = rl.window_secs,
155 "rate limiting enabled"
156 );
157 let limiter = Arc::new(RateLimiter::new(rl.max_connections_per_ip, rl.window_secs));
158 limiter.start_cleanup_task(Duration::from_secs(rl.cleanup_interval_secs));
159 limiter
160 });
161
162 let listener = create_listener(listen, connection_backlog)?;
164 info!(address = %listen, backlog = connection_backlog, "listening");
165
166 #[cfg(feature = "ws")]
167 if config.websocket.enabled && config.websocket.mode == "split" {
168 let ws_listen = config.websocket.listen.clone().unwrap_or_default();
169 let ws_addr: SocketAddr = ws_listen
170 .parse()
171 .map_err(|_| ServerError::Config("invalid websocket.listen address".into()))?;
172 let ws_listener = create_listener(ws_addr, connection_backlog)?;
173 let ws_acceptor = acceptor.clone();
174 let ws_state = state.clone();
175 let ws_auth = auth.clone();
176 let ws_tracker = tracker.clone();
177 let ws_conn_limit = conn_limit.clone();
178 let ws_rate_limiter = rate_limiter.clone();
179 let ws_shutdown = shutdown.clone();
180
181 info!(address = %ws_addr, "websocket split listener started");
182 tokio::spawn(async move {
183 loop {
184 tokio::select! {
185 biased;
186 _ = ws_shutdown.cancelled() => break,
187 result = ws_listener.accept() => {
188 let (tcp, peer) = match result {
189 Ok(v) => v,
190 Err(_) => continue,
191 };
192
193 if let Some(ref limiter) = ws_rate_limiter {
194 let ip = peer.ip();
195 if !limiter.check_and_increment(ip) {
196 record_connection_rejected("rate_limit");
197 drop(tcp);
198 continue;
199 }
200 }
201
202 let permit: Option<OwnedSemaphorePermit> = match &ws_conn_limit {
203 Some(sem) => match sem.clone().try_acquire_owned() {
204 Ok(p) => Some(p),
205 Err(_) => {
206 record_connection_rejected("max_connections");
207 drop(tcp);
208 continue;
209 }
210 },
211 None => None,
212 };
213
214 let conn_id = next_conn_id();
215 let acceptor = ws_acceptor.clone();
216 let state = ws_state.clone();
217 let auth = ws_auth.clone();
218 ws_tracker.increment();
219 let guard = ConnectionGuard::new(ws_tracker.clone());
220
221 let span = info_span!("conn", id = conn_id, peer = %peer, transport = "ws");
222 tokio::spawn(
223 async move {
224 let _guard = guard;
225 let _permit = permit;
226 record_connection_accepted();
227 let start = Instant::now();
228
229 let result = async {
230 let tls_start = Instant::now();
231 let tls_timeout =
232 Duration::from_secs(defaults::DEFAULT_TLS_HANDSHAKE_TIMEOUT_SECS);
233 match tokio::time::timeout(tls_timeout, acceptor.accept(tcp)).await
234 {
235 Ok(Ok(tls)) => {
236 let tls_duration = tls_start.elapsed().as_secs_f64();
237 record_tls_handshake_duration(tls_duration);
238 crate::handler::handle_ws_only(tls, state, auth, peer).await
239 }
240 Ok(Err(err)) => {
241 record_error(ERROR_TLS_HANDSHAKE);
242 warn!(error = %err, "TLS handshake failed");
243 Ok(())
244 }
245 Err(_) => {
246 record_error(ERROR_TLS_HANDSHAKE);
247 warn!(
248 timeout_secs = tls_timeout.as_secs(),
249 "TLS handshake timed out"
250 );
251 Ok(())
252 }
253 }
254 }
255 .await;
256
257 let duration_secs = start.elapsed().as_secs_f64();
258 record_connection_closed(duration_secs);
259
260 if let Err(ref err) = result {
261 warn!(error = %err, "connection error");
262 }
263 }
264 .instrument(span),
265 );
266 }
267 }
268 }
269 });
270 }
271
272 #[cfg(not(feature = "ws"))]
273 if config.websocket.enabled {
274 warn!("websocket.enabled=true but ws feature is disabled; ignoring websocket");
275 }
276
277 loop {
278 tokio::select! {
279 biased;
280
281 _ = shutdown.cancelled() => {
282 info!("shutdown signal received, stopping accept loop");
283 break;
284 }
285
286 result = listener.accept() => {
287 let (tcp, peer) = result?;
288
289 if let Some(ref sem) = conn_limit {
291 let available = sem.available_permits();
292 set_connection_queue_depth(available as f64);
293 }
294
295 if let Some(ref limiter) = rate_limiter {
297 let ip = peer.ip();
298 if !limiter.check_and_increment(ip) {
299 debug!(peer = %peer, reason = "rate_limit", "connection rejected");
300 record_connection_rejected("rate_limit");
301 drop(tcp);
302 continue;
303 }
304 }
305
306 let permit: Option<OwnedSemaphorePermit> = match &conn_limit {
308 Some(sem) => match sem.clone().try_acquire_owned() {
309 Ok(p) => Some(p),
310 Err(_) => {
311 debug!(peer = %peer, reason = "max_connections", "connection rejected");
312 record_connection_rejected("max_connections");
313 drop(tcp); continue;
315 }
316 },
317 None => None,
318 };
319
320 let conn_id = next_conn_id();
321 debug!(conn_id, peer = %peer, "new connection");
322
323 let acceptor = acceptor.clone();
324 let state = state.clone();
325 let auth = auth.clone();
326 tracker.increment();
327 let guard = ConnectionGuard::new(tracker.clone());
328
329 let span = info_span!("conn", id = conn_id, peer = %peer);
330 tokio::spawn(
331 async move {
332 let _guard = guard; let _permit = permit; record_connection_accepted();
335 let start = Instant::now();
336
337 let result = async {
338 let tls_start = Instant::now();
340 let tls_timeout =
341 Duration::from_secs(defaults::DEFAULT_TLS_HANDSHAKE_TIMEOUT_SECS);
342 match tokio::time::timeout(tls_timeout, acceptor.accept(tcp)).await {
343 Ok(Ok(tls)) => {
344 let tls_duration = tls_start.elapsed().as_secs_f64();
345 record_tls_handshake_duration(tls_duration);
346 debug!(duration_ms = tls_duration * 1000.0, "TLS handshake completed");
347 handle_conn(tls, state, auth, peer).await
348 }
349 Ok(Err(err)) => {
350 record_error(ERROR_TLS_HANDSHAKE);
351 warn!(error = %err, "TLS handshake failed");
352 Ok(())
353 }
354 Err(_) => {
355 record_error(ERROR_TLS_HANDSHAKE);
356 warn!(timeout_secs = tls_timeout.as_secs(), "TLS handshake timed out");
357 Ok(())
358 }
359 }
360 }
361 .await;
362
363 let duration_secs = start.elapsed().as_secs_f64();
364 record_connection_closed(duration_secs);
365
366 if let Err(ref err) = result {
367 record_error(err.error_type());
368 warn!(duration_secs, error = %err, "connection closed with error");
369 } else {
370 debug!(duration_secs, "connection closed");
371 }
372 }
373 .instrument(span),
374 );
375 }
376 }
377 }
378
379 if let Some(ref limiter) = rate_limiter {
381 limiter.shutdown();
382 }
383
384 let active = tracker.count();
386 if active > 0 {
387 info!("waiting for {} active connections to drain", active);
388 if tracker.wait_for_zero(DEFAULT_SHUTDOWN_TIMEOUT).await {
389 info!("all connections drained");
390 } else {
391 warn!(
392 "shutdown timeout, {} connections still active",
393 tracker.count()
394 );
395 }
396 }
397
398 info!("server stopped");
399 Ok(())
400}
401
402pub async fn run(config: Config, auth: impl AuthBackend + 'static) -> Result<(), ServerError> {
405 run_with_shutdown(config, auth, CancellationToken::new()).await
406}