// trojan_server/server.rs

1//! Main server loop and connection handling.
2
3use std::net::SocketAddr;
4use std::sync::Arc;
5use std::time::Duration;
6
7use tokio::sync::{OwnedSemaphorePermit, Semaphore};
8use tokio::time::Instant;
9use tokio_rustls::TlsAcceptor;
10use tokio_util::sync::CancellationToken;
11use tracing::{debug, info, warn};
12
13use crate::error::ServerError;
14use crate::handler::handle_conn;
15use crate::pool::ConnectionPool;
16use crate::rate_limit::RateLimiter;
17use crate::resolve::resolve_sockaddr;
18use crate::state::ServerState;
19use crate::tls::load_tls_config;
20use crate::util::{create_listener, ConnectionGuard, ConnectionTracker};
21use trojan_auth::AuthBackend;
22use trojan_config::Config;
23use trojan_core::defaults;
24use trojan_metrics::{
25    record_connection_accepted, record_connection_closed, record_connection_rejected,
26    record_error, record_tls_handshake_duration, set_connection_queue_depth, ERROR_TLS_HANDSHAKE,
27};
28
/// How long graceful shutdown waits for in-flight connections to drain
/// before giving up and returning anyway.
pub const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(30);
31
/// Run the server with a cancellation token for graceful shutdown.
///
/// Startup: loads the TLS configuration, parses `server.listen`, resolves
/// the fallback address, optionally starts a fallback connection pool and a
/// per-IP rate limiter, and binds a listener with the configured backlog.
///
/// Accept loop: each accepted connection is checked against the rate
/// limiter, then against the global `max_connections` semaphore; rejections
/// close the socket immediately and are recorded in metrics. Accepted
/// connections get a spawned task that performs the TLS handshake (with a
/// timeout) and then delegates to `handle_conn`.
///
/// Shutdown: once `shutdown` is cancelled the accept loop exits, the rate
/// limiter cleanup task is stopped, and active connections are given
/// [`DEFAULT_SHUTDOWN_TIMEOUT`] to drain before the function returns.
///
/// # Errors
///
/// Returns an error if TLS loading, listen-address parsing, fallback
/// resolution, or listener creation fails, or if `accept()` returns an I/O
/// error inside the loop.
pub async fn run_with_shutdown(
    config: Config,
    auth: impl AuthBackend + 'static,
    shutdown: CancellationToken,
) -> Result<(), ServerError> {
    let tls_config = load_tls_config(&config.tls)?;
    let acceptor = TlsAcceptor::from(Arc::new(tls_config));

    let listen: SocketAddr = config
        .server
        .listen
        .parse()
        .map_err(|_| ServerError::Config("invalid listen address".into()))?;

    let fallback_addr = resolve_sockaddr(&config.server.fallback).await?;

    // Initialize fallback connection pool if configured
    let fallback_pool: Option<Arc<ConnectionPool>> =
        config.server.fallback_pool.as_ref().map(|pool_cfg| {
            info!(
                max_idle = pool_cfg.max_idle,
                max_age_secs = pool_cfg.max_age_secs,
                fill_batch = pool_cfg.fill_batch,
                fill_delay_ms = pool_cfg.fill_delay_ms,
                "fallback connection pool enabled"
            );
            let pool = Arc::new(ConnectionPool::new(
                fallback_addr,
                pool_cfg.max_idle,
                pool_cfg.max_age_secs,
                pool_cfg.fill_batch,
                pool_cfg.fill_delay_ms,
            ));
            // Use max_age_secs as cleanup interval
            pool.start_cleanup_task(Duration::from_secs(pool_cfg.max_age_secs));
            pool
        });

    // Extract resource limits with defaults
    let (relay_buffer_size, tcp_send_buffer, tcp_recv_buffer, connection_backlog) =
        match &config.server.resource_limits {
            Some(rl) => {
                info!(
                    relay_buffer = rl.relay_buffer_size,
                    tcp_send_buffer = rl.tcp_send_buffer,
                    tcp_recv_buffer = rl.tcp_recv_buffer,
                    connection_backlog = rl.connection_backlog,
                    "resource limits configured"
                );
                (
                    rl.relay_buffer_size,
                    rl.tcp_send_buffer,
                    rl.tcp_recv_buffer,
                    rl.connection_backlog,
                )
            }
            None => (
                defaults::DEFAULT_RELAY_BUFFER_SIZE,
                defaults::DEFAULT_TCP_SEND_BUFFER,
                defaults::DEFAULT_TCP_RECV_BUFFER,
                defaults::DEFAULT_CONNECTION_BACKLOG,
            ),
        };

    // Shared state cloned into every connection-handler task.
    let state = Arc::new(ServerState {
        fallback_addr,
        max_udp_payload: config.server.max_udp_payload,
        max_udp_buffer_bytes: config.server.max_udp_buffer_bytes,
        max_header_bytes: config.server.max_header_bytes,
        tcp_idle_timeout: Duration::from_secs(config.server.tcp_idle_timeout_secs),
        udp_idle_timeout: Duration::from_secs(config.server.udp_timeout_secs),
        fallback_pool,
        relay_buffer_size,
        tcp_send_buffer,
        tcp_recv_buffer,
    });
    let auth = Arc::new(auth);
    // Tracks live connections so shutdown can wait for them to drain.
    let tracker = ConnectionTracker::new();

    // Connection limiter (None = unlimited)
    let conn_limit: Option<Arc<Semaphore>> = config.server.max_connections.map(|n| {
        info!("max_connections set to {}", n);
        Arc::new(Semaphore::new(n))
    });

    // Rate limiter (None = disabled)
    let rate_limiter: Option<Arc<RateLimiter>> = config.server.rate_limit.as_ref().map(|rl| {
        info!(
            max_per_ip = rl.max_connections_per_ip,
            window_secs = rl.window_secs,
            "rate limiting enabled"
        );
        let limiter = Arc::new(RateLimiter::new(rl.max_connections_per_ip, rl.window_secs));
        limiter.start_cleanup_task(Duration::from_secs(rl.cleanup_interval_secs));
        limiter
    });

    // Create listener with custom backlog using socket2
    let listener = create_listener(listen, connection_backlog)?;
    info!(address = %listen, backlog = connection_backlog, "listening");

    loop {
        tokio::select! {
            // `biased` polls branches in declaration order, so a pending
            // shutdown is observed before another connection is accepted.
            biased;

            _ = shutdown.cancelled() => {
                info!("shutdown signal received, stopping accept loop");
                break;
            }

            result = listener.accept() => {
                // NOTE(review): an accept() error terminates the whole server
                // via `?` — confirm that is intended for transient errors
                // (e.g. EMFILE) rather than retrying.
                let (tcp, peer) = result?;

                // Update connection queue depth metric (based on semaphore usage)
                if let Some(ref sem) = conn_limit {
                    let available = sem.available_permits();
                    set_connection_queue_depth(available as f64);
                }

                // Check rate limit first
                if let Some(ref limiter) = rate_limiter {
                    let ip = peer.ip();
                    if !limiter.check_and_increment(ip) {
                        debug!(peer = %peer, reason = "rate_limit", "connection rejected");
                        record_connection_rejected("rate_limit");
                        drop(tcp);
                        continue;
                    }
                }

                // Try to acquire connection permit. try_acquire_owned never
                // blocks: at capacity the connection is rejected outright.
                let permit: Option<OwnedSemaphorePermit> = match &conn_limit {
                    Some(sem) => match sem.clone().try_acquire_owned() {
                        Ok(p) => Some(p),
                        Err(_) => {
                            debug!(peer = %peer, reason = "max_connections", "connection rejected");
                            record_connection_rejected("max_connections");
                            drop(tcp); // close immediately
                            continue;
                        }
                    },
                    None => None,
                };

                debug!(peer = %peer, "new connection");

                let acceptor = acceptor.clone();
                let state = state.clone();
                let auth = auth.clone();
                tracker.increment();
                let guard = ConnectionGuard::new(tracker.clone());

                tokio::spawn(async move {
                    let _guard = guard; // ensure decrement on drop
                    let _permit = permit; // hold permit until connection closes
                    record_connection_accepted();
                    let start = Instant::now();

                    let result = async {
                        // Measure TLS handshake duration with timeout
                        let tls_start = Instant::now();
                        let tls_timeout = Duration::from_secs(defaults::DEFAULT_TLS_HANDSHAKE_TIMEOUT_SECS);
                        match tokio::time::timeout(tls_timeout, acceptor.accept(tcp)).await {
                            Ok(Ok(tls)) => {
                                let tls_duration = tls_start.elapsed().as_secs_f64();
                                record_tls_handshake_duration(tls_duration);
                                debug!(peer = %peer, duration_ms = tls_duration * 1000.0, "TLS handshake completed");
                                handle_conn(tls, state, auth, peer).await
                            }
                            // Handshake failure/timeout is logged and counted in
                            // metrics but deliberately returned as Ok(()) so it
                            // is not double-reported as a connection error below.
                            Ok(Err(err)) => {
                                record_error(ERROR_TLS_HANDSHAKE);
                                warn!(peer = %peer, error = %err, "TLS handshake failed");
                                Ok(())
                            }
                            Err(_) => {
                                record_error(ERROR_TLS_HANDSHAKE);
                                warn!(peer = %peer, timeout_secs = tls_timeout.as_secs(), "TLS handshake timed out");
                                Ok(())
                            }
                        }
                    }
                    .await;

                    // Always paired with record_connection_accepted above,
                    // even when the handshake never completed.
                    let duration_secs = start.elapsed().as_secs_f64();
                    record_connection_closed(duration_secs);

                    if let Err(ref err) = result {
                        record_error(err.error_type());
                        warn!(peer = %peer, duration_secs, error = %err, "connection closed with error");
                    } else {
                        debug!(peer = %peer, duration_secs, "connection closed");
                    }
                });
            }
        }
    }

    // Shutdown rate limiter cleanup task
    if let Some(ref limiter) = rate_limiter {
        limiter.shutdown();
    }

    // Graceful drain: wait for active connections
    let active = tracker.count();
    if active > 0 {
        info!("waiting for {} active connections to drain", active);
        if tracker.wait_for_zero(DEFAULT_SHUTDOWN_TIMEOUT).await {
            info!("all connections drained");
        } else {
            warn!(
                "shutdown timeout, {} connections still active",
                tracker.count()
            );
        }
    }

    info!("server stopped");
    Ok(())
}
252
253/// Run the server (blocking until error, no graceful shutdown).
254/// For backward compatibility with existing code.
255pub async fn run(config: Config, auth: impl AuthBackend + 'static) -> Result<(), ServerError> {
256    run_with_shutdown(config, auth, CancellationToken::new()).await
257}