1use std::net::SocketAddr;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::time::Duration;
7
8use tokio::sync::{OwnedSemaphorePermit, Semaphore};
9use tokio::time::Instant;
10use tokio_rustls::TlsAcceptor;
11use tokio_util::sync::CancellationToken;
12use tracing::{Instrument, debug, info, info_span, warn};
13
14use crate::error::ServerError;
15use crate::handler::handle_conn;
16use crate::pool::ConnectionPool;
17use crate::rate_limit::RateLimiter;
18use crate::resolve::resolve_sockaddr;
19use crate::state::ServerState;
20use crate::tls::load_tls_config;
21use crate::util::{ConnectionGuard, ConnectionTracker, apply_tcp_options, create_listener};
22use trojan_auth::AuthBackend;
23use trojan_config::Config;
24use trojan_core::defaults;
25use trojan_metrics::{
26 ERROR_TLS_HANDSHAKE, record_connection_accepted, record_connection_closed,
27 record_connection_rejected, record_error, record_tls_handshake_duration,
28 set_connection_queue_depth,
29};
30
/// Maximum time [`run_with_shutdown`] waits for in-flight connections to
/// finish once the shutdown token has been cancelled.
pub const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(30);
33
/// Counter backing [`next_conn_id`]; the first id handed out is 1.
static CONN_ID: AtomicU64 = AtomicU64::new(1);

/// Returns a fresh connection id for tagging per-connection logs and spans.
///
/// Ids increase by one per call and are unique for the life of the process
/// (barring `u64` wrap-around).
#[inline]
fn next_conn_id() -> u64 {
    // `fetch_add` is a single atomic read-modify-write regardless of the
    // ordering chosen, so uniqueness holds even with `Relaxed`; no other
    // memory is synchronised through this counter.
    CONN_ID.fetch_add(1, Ordering::Relaxed)
}
42
43pub async fn run_with_shutdown(
45 config: Config,
46 auth: impl AuthBackend + 'static,
47 shutdown: CancellationToken,
48) -> Result<(), ServerError> {
49 let tls_config = load_tls_config(&config.tls)?;
50 let acceptor = TlsAcceptor::from(Arc::new(tls_config));
51
52 let listen: SocketAddr = config
53 .server
54 .listen
55 .parse()
56 .map_err(|_| ServerError::Config("invalid listen address".into()))?;
57
58 let fallback_addr = resolve_sockaddr(&config.server.fallback).await?;
59
60 let fallback_pool: Option<Arc<ConnectionPool>> =
62 config.server.fallback_pool.as_ref().map(|pool_cfg| {
63 info!(
64 max_idle = pool_cfg.max_idle,
65 max_age_secs = pool_cfg.max_age_secs,
66 fill_batch = pool_cfg.fill_batch,
67 fill_delay_ms = pool_cfg.fill_delay_ms,
68 "fallback connection pool enabled"
69 );
70 let pool = Arc::new(ConnectionPool::new(
71 fallback_addr,
72 pool_cfg.max_idle,
73 pool_cfg.max_age_secs,
74 pool_cfg.fill_batch,
75 pool_cfg.fill_delay_ms,
76 ));
77 pool.start_cleanup_task(Duration::from_secs(pool_cfg.max_age_secs));
79 pool
80 });
81
82 let (relay_buffer_size, tcp_send_buffer, tcp_recv_buffer, connection_backlog) =
84 match &config.server.resource_limits {
85 Some(rl) => {
86 info!(
87 relay_buffer = rl.relay_buffer_size,
88 tcp_send_buffer = rl.tcp_send_buffer,
89 tcp_recv_buffer = rl.tcp_recv_buffer,
90 connection_backlog = rl.connection_backlog,
91 "resource limits configured"
92 );
93 (
94 rl.relay_buffer_size,
95 rl.tcp_send_buffer,
96 rl.tcp_recv_buffer,
97 rl.connection_backlog,
98 )
99 }
100 None => (
101 defaults::DEFAULT_RELAY_BUFFER_SIZE,
102 defaults::DEFAULT_TCP_SEND_BUFFER,
103 defaults::DEFAULT_TCP_RECV_BUFFER,
104 defaults::DEFAULT_CONNECTION_BACKLOG,
105 ),
106 };
107
108 #[cfg(feature = "analytics")]
110 let analytics = if config.analytics.enabled {
111 match trojan_analytics::init(config.analytics.clone()).await {
112 Ok(collector) => {
113 info!("analytics enabled, sending to ClickHouse");
114 Some(collector)
115 }
116 Err(e) => {
117 warn!("failed to init analytics: {}, disabled", e);
118 None
119 }
120 }
121 } else {
122 debug!("analytics disabled in config");
123 None
124 };
125
126 let tcp_cfg = &config.server.tcp;
128 info!(
129 no_delay = tcp_cfg.no_delay,
130 keepalive_secs = tcp_cfg.keepalive_secs,
131 reuse_port = tcp_cfg.reuse_port,
132 fast_open = tcp_cfg.fast_open,
133 "TCP options configured"
134 );
135
136 let state = Arc::new(ServerState {
137 fallback_addr,
138 max_udp_payload: config.server.max_udp_payload,
139 max_udp_buffer_bytes: config.server.max_udp_buffer_bytes,
140 max_header_bytes: config.server.max_header_bytes,
141 tcp_idle_timeout: Duration::from_secs(config.server.tcp_idle_timeout_secs),
142 udp_idle_timeout: Duration::from_secs(config.server.udp_timeout_secs),
143 fallback_pool,
144 relay_buffer_size,
145 tcp_send_buffer,
146 tcp_recv_buffer,
147 tcp_config: config.server.tcp.clone(),
148 websocket: config.websocket.clone(),
149 #[cfg(feature = "analytics")]
150 analytics,
151 });
152 let auth = Arc::new(auth);
153 let tracker = ConnectionTracker::new();
154
155 let conn_limit: Option<Arc<Semaphore>> = config.server.max_connections.map(|n| {
157 info!("max_connections set to {}", n);
158 Arc::new(Semaphore::new(n))
159 });
160
161 let rate_limiter: Option<Arc<RateLimiter>> = config.server.rate_limit.as_ref().map(|rl| {
163 info!(
164 max_per_ip = rl.max_connections_per_ip,
165 window_secs = rl.window_secs,
166 "rate limiting enabled"
167 );
168 let limiter = Arc::new(RateLimiter::new(rl.max_connections_per_ip, rl.window_secs));
169 limiter.start_cleanup_task(Duration::from_secs(rl.cleanup_interval_secs));
170 limiter
171 });
172
173 let listener = create_listener(listen, connection_backlog, &config.server.tcp)?;
175 info!(address = %listen, backlog = connection_backlog, "listening");
176
177 #[cfg(feature = "ws")]
178 if config.websocket.enabled && config.websocket.mode == "split" {
179 let ws_listen = config.websocket.listen.clone().unwrap_or_default();
180 let ws_addr: SocketAddr = ws_listen
181 .parse()
182 .map_err(|_| ServerError::Config("invalid websocket.listen address".into()))?;
183 let ws_listener = create_listener(ws_addr, connection_backlog, &config.server.tcp)?;
184 let ws_acceptor = acceptor.clone();
185 let ws_state = state.clone();
186 let ws_auth = auth.clone();
187 let ws_tracker = tracker.clone();
188 let ws_conn_limit = conn_limit.clone();
189 let ws_rate_limiter = rate_limiter.clone();
190 let ws_shutdown = shutdown.clone();
191
192 info!(address = %ws_addr, "websocket split listener started");
193 tokio::spawn(async move {
194 loop {
195 tokio::select! {
196 biased;
197 _ = ws_shutdown.cancelled() => break,
198 result = ws_listener.accept() => {
199 let (tcp, peer) = match result {
200 Ok(v) => v,
201 Err(_) => continue,
202 };
203
204 if let Err(e) = apply_tcp_options(&tcp, &ws_state.tcp_config) {
206 tracing::debug!(error = %e, "failed to apply TCP options");
207 }
208
209 if let Some(ref limiter) = ws_rate_limiter {
210 let ip = peer.ip();
211 if !limiter.check_and_increment(ip) {
212 record_connection_rejected("rate_limit");
213 drop(tcp);
214 continue;
215 }
216 }
217
218 let permit: Option<OwnedSemaphorePermit> = match &ws_conn_limit {
219 Some(sem) => match sem.clone().try_acquire_owned() {
220 Ok(p) => Some(p),
221 Err(_) => {
222 record_connection_rejected("max_connections");
223 drop(tcp);
224 continue;
225 }
226 },
227 None => None,
228 };
229
230 let conn_id = next_conn_id();
231 let acceptor = ws_acceptor.clone();
232 let state = ws_state.clone();
233 let auth = ws_auth.clone();
234 ws_tracker.increment();
235 let guard = ConnectionGuard::new(ws_tracker.clone());
236
237 let span = info_span!("conn", id = conn_id, peer = %peer, transport = "ws");
238 tokio::spawn(
239 async move {
240 let _guard = guard;
241 let _permit = permit;
242 record_connection_accepted();
243 let start = Instant::now();
244
245 let result = async {
246 let tls_start = Instant::now();
247 let tls_timeout =
248 Duration::from_secs(defaults::DEFAULT_TLS_HANDSHAKE_TIMEOUT_SECS);
249 match tokio::time::timeout(tls_timeout, acceptor.accept(tcp)).await
250 {
251 Ok(Ok(tls)) => {
252 let tls_duration = tls_start.elapsed().as_secs_f64();
253 record_tls_handshake_duration(tls_duration);
254 crate::handler::handle_ws_only(tls, state, auth, peer).await
255 }
256 Ok(Err(err)) => {
257 record_error(ERROR_TLS_HANDSHAKE);
258 warn!(error = %err, "TLS handshake failed");
259 Ok(())
260 }
261 Err(_) => {
262 record_error(ERROR_TLS_HANDSHAKE);
263 warn!(
264 timeout_secs = tls_timeout.as_secs(),
265 "TLS handshake timed out"
266 );
267 Ok(())
268 }
269 }
270 }
271 .await;
272
273 let duration_secs = start.elapsed().as_secs_f64();
274 record_connection_closed(duration_secs);
275
276 if let Err(ref err) = result {
277 warn!(error = %err, "connection error");
278 }
279 }
280 .instrument(span),
281 );
282 }
283 }
284 }
285 });
286 }
287
288 #[cfg(not(feature = "ws"))]
289 if config.websocket.enabled {
290 warn!("websocket.enabled=true but ws feature is disabled; ignoring websocket");
291 }
292
293 loop {
294 tokio::select! {
295 biased;
296
297 _ = shutdown.cancelled() => {
298 info!("shutdown signal received, stopping accept loop");
299 break;
300 }
301
302 result = listener.accept() => {
303 let (tcp, peer) = result?;
304
305 if let Err(e) = apply_tcp_options(&tcp, &state.tcp_config) {
307 debug!(error = %e, "failed to apply TCP options");
308 }
309
310 if let Some(ref sem) = conn_limit {
312 let available = sem.available_permits();
313 set_connection_queue_depth(available as f64);
314 }
315
316 if let Some(ref limiter) = rate_limiter {
318 let ip = peer.ip();
319 if !limiter.check_and_increment(ip) {
320 debug!(peer = %peer, reason = "rate_limit", "connection rejected");
321 record_connection_rejected("rate_limit");
322 drop(tcp);
323 continue;
324 }
325 }
326
327 let permit: Option<OwnedSemaphorePermit> = match &conn_limit {
329 Some(sem) => match sem.clone().try_acquire_owned() {
330 Ok(p) => Some(p),
331 Err(_) => {
332 debug!(peer = %peer, reason = "max_connections", "connection rejected");
333 record_connection_rejected("max_connections");
334 drop(tcp); continue;
336 }
337 },
338 None => None,
339 };
340
341 let conn_id = next_conn_id();
342 debug!(conn_id, peer = %peer, "new connection");
343
344 let acceptor = acceptor.clone();
345 let state = state.clone();
346 let auth = auth.clone();
347 tracker.increment();
348 let guard = ConnectionGuard::new(tracker.clone());
349
350 let span = info_span!("conn", id = conn_id, peer = %peer);
351 tokio::spawn(
352 async move {
353 let _guard = guard; let _permit = permit; record_connection_accepted();
356 let start = Instant::now();
357
358 let result = async {
359 let tls_start = Instant::now();
361 let tls_timeout =
362 Duration::from_secs(defaults::DEFAULT_TLS_HANDSHAKE_TIMEOUT_SECS);
363 match tokio::time::timeout(tls_timeout, acceptor.accept(tcp)).await {
364 Ok(Ok(tls)) => {
365 let tls_duration = tls_start.elapsed().as_secs_f64();
366 record_tls_handshake_duration(tls_duration);
367 debug!(duration_ms = tls_duration * 1000.0, "TLS handshake completed");
368 handle_conn(tls, state, auth, peer).await
369 }
370 Ok(Err(err)) => {
371 record_error(ERROR_TLS_HANDSHAKE);
372 warn!(error = %err, "TLS handshake failed");
373 Ok(())
374 }
375 Err(_) => {
376 record_error(ERROR_TLS_HANDSHAKE);
377 warn!(timeout_secs = tls_timeout.as_secs(), "TLS handshake timed out");
378 Ok(())
379 }
380 }
381 }
382 .await;
383
384 let duration_secs = start.elapsed().as_secs_f64();
385 record_connection_closed(duration_secs);
386
387 if let Err(ref err) = result {
388 record_error(err.error_type());
389 warn!(duration_secs, error = %err, "connection closed with error");
390 } else {
391 debug!(duration_secs, "connection closed");
392 }
393 }
394 .instrument(span),
395 );
396 }
397 }
398 }
399
400 if let Some(ref limiter) = rate_limiter {
402 limiter.shutdown();
403 }
404
405 let active = tracker.count();
407 if active > 0 {
408 info!("waiting for {} active connections to drain", active);
409 if tracker.wait_for_zero(DEFAULT_SHUTDOWN_TIMEOUT).await {
410 info!("all connections drained");
411 } else {
412 warn!(
413 "shutdown timeout, {} connections still active",
414 tracker.count()
415 );
416 }
417 }
418
419 info!("server stopped");
420 Ok(())
421}
422
423pub async fn run(config: Config, auth: impl AuthBackend + 'static) -> Result<(), ServerError> {
426 run_with_shutdown(config, auth, CancellationToken::new()).await
427}