// trojan_server/server.rs
1//! Main server loop and connection handling.
2
3use std::net::SocketAddr;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::time::Duration;
7
8use tokio::sync::{OwnedSemaphorePermit, Semaphore};
9use tokio::time::Instant;
10use tokio_rustls::TlsAcceptor;
11use tokio_util::sync::CancellationToken;
12use tracing::{Instrument, debug, info, info_span, warn};
13
14use crate::error::ServerError;
15use crate::handler::handle_conn;
16use crate::pool::ConnectionPool;
17use crate::rate_limit::RateLimiter;
18use crate::resolve::resolve_sockaddr;
19use crate::state::ServerState;
20use crate::tls::load_tls_config;
21use crate::util::{ConnectionGuard, ConnectionTracker, create_listener};
22use trojan_auth::AuthBackend;
23use trojan_config::Config;
24use trojan_core::defaults;
25use trojan_metrics::{
26    ERROR_TLS_HANDSHAKE, record_connection_accepted, record_connection_closed,
27    record_connection_rejected, record_error, record_tls_handshake_duration,
28    set_connection_queue_depth,
29};
30
31/// Default graceful shutdown timeout.
32pub const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(30);
33
/// Global connection ID counter; starts at 1 so 0 never appears as an ID.
static CONN_ID: AtomicU64 = AtomicU64::new(1);

/// Hand out the next unique connection ID.
///
/// `Relaxed` ordering is sufficient: each caller only needs a distinct
/// value, not any ordering relative to other memory operations.
#[inline]
fn next_conn_id() -> u64 {
    let previous = CONN_ID.fetch_add(1, Ordering::Relaxed);
    previous
}
42
/// Run the server with a cancellation token for graceful shutdown.
///
/// Startup sequence: load TLS material, parse the listen address,
/// resolve the fallback target, optionally build a fallback connection
/// pool, read resource limits (or defaults), optionally start analytics
/// (feature `analytics`), build shared [`ServerState`], set up the
/// optional global connection semaphore and per-IP rate limiter, create
/// the listener, optionally spawn a split WebSocket listener (feature
/// `ws`), then accept connections until `shutdown` is cancelled.
///
/// On cancellation the accept loop stops, the rate limiter's cleanup
/// task is shut down, and the server waits up to
/// [`DEFAULT_SHUTDOWN_TIMEOUT`] for active connections to drain.
///
/// # Errors
///
/// Returns an error if TLS config loading fails, the listen (or
/// websocket listen) address does not parse, fallback resolution fails,
/// a listener cannot be created, or the main accept loop hits an I/O
/// error (propagated via `?`).
pub async fn run_with_shutdown(
    config: Config,
    auth: impl AuthBackend + 'static,
    shutdown: CancellationToken,
) -> Result<(), ServerError> {
    let tls_config = load_tls_config(&config.tls)?;
    let acceptor = TlsAcceptor::from(Arc::new(tls_config));

    let listen: SocketAddr = config
        .server
        .listen
        .parse()
        .map_err(|_| ServerError::Config("invalid listen address".into()))?;

    // Resolve the fallback target once at startup; it is reused for the
    // whole server lifetime (both directly and by the pool below).
    let fallback_addr = resolve_sockaddr(&config.server.fallback).await?;

    // Initialize fallback connection pool if configured
    let fallback_pool: Option<Arc<ConnectionPool>> =
        config.server.fallback_pool.as_ref().map(|pool_cfg| {
            info!(
                max_idle = pool_cfg.max_idle,
                max_age_secs = pool_cfg.max_age_secs,
                fill_batch = pool_cfg.fill_batch,
                fill_delay_ms = pool_cfg.fill_delay_ms,
                "fallback connection pool enabled"
            );
            let pool = Arc::new(ConnectionPool::new(
                fallback_addr,
                pool_cfg.max_idle,
                pool_cfg.max_age_secs,
                pool_cfg.fill_batch,
                pool_cfg.fill_delay_ms,
            ));
            // Use max_age_secs as cleanup interval
            pool.start_cleanup_task(Duration::from_secs(pool_cfg.max_age_secs));
            pool
        });

    // Extract resource limits with defaults
    let (relay_buffer_size, tcp_send_buffer, tcp_recv_buffer, connection_backlog) =
        match &config.server.resource_limits {
            Some(rl) => {
                info!(
                    relay_buffer = rl.relay_buffer_size,
                    tcp_send_buffer = rl.tcp_send_buffer,
                    tcp_recv_buffer = rl.tcp_recv_buffer,
                    connection_backlog = rl.connection_backlog,
                    "resource limits configured"
                );
                (
                    rl.relay_buffer_size,
                    rl.tcp_send_buffer,
                    rl.tcp_recv_buffer,
                    rl.connection_backlog,
                )
            }
            None => (
                defaults::DEFAULT_RELAY_BUFFER_SIZE,
                defaults::DEFAULT_TCP_SEND_BUFFER,
                defaults::DEFAULT_TCP_RECV_BUFFER,
                defaults::DEFAULT_CONNECTION_BACKLOG,
            ),
        };

    // Initialize analytics if feature enabled and configured.
    // A failed init is downgraded to a warning, not a startup error.
    #[cfg(feature = "analytics")]
    let analytics = if config.analytics.enabled {
        match trojan_analytics::init(config.analytics.clone()).await {
            Ok(collector) => {
                info!("analytics enabled, sending to ClickHouse");
                Some(collector)
            }
            Err(e) => {
                warn!("failed to init analytics: {}, disabled", e);
                None
            }
        }
    } else {
        debug!("analytics disabled in config");
        None
    };

    // Immutable shared state handed to every connection handler.
    let state = Arc::new(ServerState {
        fallback_addr,
        max_udp_payload: config.server.max_udp_payload,
        max_udp_buffer_bytes: config.server.max_udp_buffer_bytes,
        max_header_bytes: config.server.max_header_bytes,
        tcp_idle_timeout: Duration::from_secs(config.server.tcp_idle_timeout_secs),
        udp_idle_timeout: Duration::from_secs(config.server.udp_timeout_secs),
        fallback_pool,
        relay_buffer_size,
        tcp_send_buffer,
        tcp_recv_buffer,
        websocket: config.websocket.clone(),
        #[cfg(feature = "analytics")]
        analytics,
    });
    let auth = Arc::new(auth);
    let tracker = ConnectionTracker::new();

    // Connection limiter (None = unlimited)
    let conn_limit: Option<Arc<Semaphore>> = config.server.max_connections.map(|n| {
        info!("max_connections set to {}", n);
        Arc::new(Semaphore::new(n))
    });

    // Rate limiter (None = disabled)
    let rate_limiter: Option<Arc<RateLimiter>> = config.server.rate_limit.as_ref().map(|rl| {
        info!(
            max_per_ip = rl.max_connections_per_ip,
            window_secs = rl.window_secs,
            "rate limiting enabled"
        );
        let limiter = Arc::new(RateLimiter::new(rl.max_connections_per_ip, rl.window_secs));
        limiter.start_cleanup_task(Duration::from_secs(rl.cleanup_interval_secs));
        limiter
    });

    // Create listener with custom backlog using socket2
    let listener = create_listener(listen, connection_backlog)?;
    info!(address = %listen, backlog = connection_backlog, "listening");

    // Optional dedicated WebSocket listener ("split" mode): a second
    // accept loop on its own address, sharing the same state, limits
    // and shutdown token as the main loop.
    #[cfg(feature = "ws")]
    if config.websocket.enabled && config.websocket.mode == "split" {
        let ws_listen = config.websocket.listen.clone().unwrap_or_default();
        let ws_addr: SocketAddr = ws_listen
            .parse()
            .map_err(|_| ServerError::Config("invalid websocket.listen address".into()))?;
        let ws_listener = create_listener(ws_addr, connection_backlog)?;
        let ws_acceptor = acceptor.clone();
        let ws_state = state.clone();
        let ws_auth = auth.clone();
        let ws_tracker = tracker.clone();
        let ws_conn_limit = conn_limit.clone();
        let ws_rate_limiter = rate_limiter.clone();
        let ws_shutdown = shutdown.clone();

        info!(address = %ws_addr, "websocket split listener started");
        tokio::spawn(async move {
            loop {
                tokio::select! {
                    // `biased` checks cancellation before accepting, so
                    // shutdown is observed promptly under accept load.
                    biased;
                    _ = ws_shutdown.cancelled() => break,
                    result = ws_listener.accept() => {
                        // Accept errors here are skipped, not fatal.
                        let (tcp, peer) = match result {
                            Ok(v) => v,
                            Err(_) => continue,
                        };

                        // Rate limit is checked before the semaphore, so a
                        // rejected-over-capacity peer still consumes a slot
                        // in its rate window.
                        if let Some(ref limiter) = ws_rate_limiter {
                            let ip = peer.ip();
                            if !limiter.check_and_increment(ip) {
                                record_connection_rejected("rate_limit");
                                drop(tcp);
                                continue;
                            }
                        }

                        // Non-blocking acquire: over capacity means reject
                        // immediately rather than queue.
                        let permit: Option<OwnedSemaphorePermit> = match &ws_conn_limit {
                            Some(sem) => match sem.clone().try_acquire_owned() {
                                Ok(p) => Some(p),
                                Err(_) => {
                                    record_connection_rejected("max_connections");
                                    drop(tcp);
                                    continue;
                                }
                            },
                            None => None,
                        };

                        let conn_id = next_conn_id();
                        let acceptor = ws_acceptor.clone();
                        let state = ws_state.clone();
                        let auth = ws_auth.clone();
                        ws_tracker.increment();
                        let guard = ConnectionGuard::new(ws_tracker.clone());

                        let span = info_span!("conn", id = conn_id, peer = %peer, transport = "ws");
                        tokio::spawn(
                            async move {
                                // Held for the task's lifetime: guard decrements
                                // the tracker, permit frees a semaphore slot.
                                let _guard = guard;
                                let _permit = permit;
                                record_connection_accepted();
                                let start = Instant::now();

                                let result = async {
                                    let tls_start = Instant::now();
                                    let tls_timeout =
                                        Duration::from_secs(defaults::DEFAULT_TLS_HANDSHAKE_TIMEOUT_SECS);
                                    // TLS failures/timeouts are logged and counted
                                    // but reported as Ok so they are not double-
                                    // logged as connection errors below.
                                    match tokio::time::timeout(tls_timeout, acceptor.accept(tcp)).await
                                    {
                                        Ok(Ok(tls)) => {
                                            let tls_duration = tls_start.elapsed().as_secs_f64();
                                            record_tls_handshake_duration(tls_duration);
                                            crate::handler::handle_ws_only(tls, state, auth, peer).await
                                        }
                                        Ok(Err(err)) => {
                                            record_error(ERROR_TLS_HANDSHAKE);
                                            warn!(error = %err, "TLS handshake failed");
                                            Ok(())
                                        }
                                        Err(_) => {
                                            record_error(ERROR_TLS_HANDSHAKE);
                                            warn!(
                                                timeout_secs = tls_timeout.as_secs(),
                                                "TLS handshake timed out"
                                            );
                                            Ok(())
                                        }
                                    }
                                }
                                .await;

                                let duration_secs = start.elapsed().as_secs_f64();
                                record_connection_closed(duration_secs);

                                if let Err(ref err) = result {
                                    warn!(error = %err, "connection error");
                                }
                            }
                            .instrument(span),
                        );
                    }
                }
            }
        });
    }

    #[cfg(not(feature = "ws"))]
    if config.websocket.enabled {
        warn!("websocket.enabled=true but ws feature is disabled; ignoring websocket");
    }

    // Main accept loop.
    loop {
        tokio::select! {
            // `biased` prefers the cancellation branch over accept.
            biased;

            _ = shutdown.cancelled() => {
                info!("shutdown signal received, stopping accept loop");
                break;
            }

            result = listener.accept() => {
                // NOTE(review): an accept error here aborts the whole server
                // via `?`, whereas the ws listener above skips and continues
                // — confirm the asymmetry is intended (e.g. EMFILE handling).
                let (tcp, peer) = result?;

                // Update connection queue depth metric (based on semaphore usage)
                // NOTE(review): this records *available* permits, not a queue
                // length — verify against the metric's intended definition.
                if let Some(ref sem) = conn_limit {
                    let available = sem.available_permits();
                    set_connection_queue_depth(available as f64);
                }

                // Check rate limit first
                if let Some(ref limiter) = rate_limiter {
                    let ip = peer.ip();
                    if !limiter.check_and_increment(ip) {
                        debug!(peer = %peer, reason = "rate_limit", "connection rejected");
                        record_connection_rejected("rate_limit");
                        drop(tcp);
                        continue;
                    }
                }

                // Try to acquire connection permit
                let permit: Option<OwnedSemaphorePermit> = match &conn_limit {
                    Some(sem) => match sem.clone().try_acquire_owned() {
                        Ok(p) => Some(p),
                        Err(_) => {
                            debug!(peer = %peer, reason = "max_connections", "connection rejected");
                            record_connection_rejected("max_connections");
                            drop(tcp); // close immediately
                            continue;
                        }
                    },
                    None => None,
                };

                let conn_id = next_conn_id();
                debug!(conn_id, peer = %peer, "new connection");

                let acceptor = acceptor.clone();
                let state = state.clone();
                let auth = auth.clone();
                // Increment before spawning so the drain phase can't miss a
                // connection that is spawned but not yet polled.
                tracker.increment();
                let guard = ConnectionGuard::new(tracker.clone());

                let span = info_span!("conn", id = conn_id, peer = %peer);
                tokio::spawn(
                    async move {
                        let _guard = guard; // ensure decrement on drop
                        let _permit = permit; // hold permit until connection closes
                        record_connection_accepted();
                        let start = Instant::now();

                        let result = async {
                            // Measure TLS handshake duration with timeout
                            let tls_start = Instant::now();
                            let tls_timeout =
                                Duration::from_secs(defaults::DEFAULT_TLS_HANDSHAKE_TIMEOUT_SECS);
                            // TLS failures/timeouts are counted and logged here,
                            // then mapped to Ok so the generic error path below
                            // doesn't report them a second time.
                            match tokio::time::timeout(tls_timeout, acceptor.accept(tcp)).await {
                                Ok(Ok(tls)) => {
                                    let tls_duration = tls_start.elapsed().as_secs_f64();
                                    record_tls_handshake_duration(tls_duration);
                                    debug!(duration_ms = tls_duration * 1000.0, "TLS handshake completed");
                                    handle_conn(tls, state, auth, peer).await
                                }
                                Ok(Err(err)) => {
                                    record_error(ERROR_TLS_HANDSHAKE);
                                    warn!(error = %err, "TLS handshake failed");
                                    Ok(())
                                }
                                Err(_) => {
                                    record_error(ERROR_TLS_HANDSHAKE);
                                    warn!(timeout_secs = tls_timeout.as_secs(), "TLS handshake timed out");
                                    Ok(())
                                }
                            }
                        }
                        .await;

                        let duration_secs = start.elapsed().as_secs_f64();
                        record_connection_closed(duration_secs);

                        if let Err(ref err) = result {
                            record_error(err.error_type());
                            warn!(duration_secs, error = %err, "connection closed with error");
                        } else {
                            debug!(duration_secs, "connection closed");
                        }
                    }
                    .instrument(span),
                );
            }
        }
    }

    // Shutdown rate limiter cleanup task
    if let Some(ref limiter) = rate_limiter {
        limiter.shutdown();
    }

    // Graceful drain: wait for active connections
    let active = tracker.count();
    if active > 0 {
        info!("waiting for {} active connections to drain", active);
        if tracker.wait_for_zero(DEFAULT_SHUTDOWN_TIMEOUT).await {
            info!("all connections drained");
        } else {
            // Remaining connections are abandoned, not forcibly aborted.
            warn!(
                "shutdown timeout, {} connections still active",
                tracker.count()
            );
        }
    }

    info!("server stopped");
    Ok(())
}
401
/// Run the server (blocking until error, no graceful shutdown).
/// For backward compatibility with existing code.
///
/// Equivalent to [`run_with_shutdown`] with a fresh, never-cancelled
/// [`CancellationToken`]: the server only stops if the accept loop
/// returns an error.
pub async fn run(config: Config, auth: impl AuthBackend + 'static) -> Result<(), ServerError> {
    run_with_shutdown(config, auth, CancellationToken::new()).await
}