1use std::net::SocketAddr;
4use std::sync::Arc;
5use std::time::Duration;
6
7use tokio::sync::{OwnedSemaphorePermit, Semaphore};
8use tokio::time::Instant;
9use tokio_rustls::TlsAcceptor;
10use tokio_util::sync::CancellationToken;
11use tracing::{debug, info, warn};
12
13use crate::error::ServerError;
14use crate::handler::handle_conn;
15use crate::pool::ConnectionPool;
16use crate::rate_limit::RateLimiter;
17use crate::resolve::resolve_sockaddr;
18use crate::state::ServerState;
19use crate::tls::load_tls_config;
20use crate::util::{create_listener, ConnectionGuard, ConnectionTracker};
21use trojan_auth::AuthBackend;
22use trojan_config::Config;
23use trojan_core::defaults;
24use trojan_metrics::{
25 record_connection_accepted, record_connection_closed, record_connection_rejected,
26 record_error, record_tls_handshake_duration, set_connection_queue_depth, ERROR_TLS_HANDSHAKE,
27};
28
/// How long [`run_with_shutdown`] waits for in-flight connections to finish
/// after the accept loop has stopped (thirty seconds).
pub const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(30_000);
31
32pub async fn run_with_shutdown(
34 config: Config,
35 auth: impl AuthBackend + 'static,
36 shutdown: CancellationToken,
37) -> Result<(), ServerError> {
38 let tls_config = load_tls_config(&config.tls)?;
39 let acceptor = TlsAcceptor::from(Arc::new(tls_config));
40
41 let listen: SocketAddr = config
42 .server
43 .listen
44 .parse()
45 .map_err(|_| ServerError::Config("invalid listen address".into()))?;
46
47 let fallback_addr = resolve_sockaddr(&config.server.fallback).await?;
48
49 let fallback_pool: Option<Arc<ConnectionPool>> =
51 config.server.fallback_pool.as_ref().map(|pool_cfg| {
52 info!(
53 max_idle = pool_cfg.max_idle,
54 max_age_secs = pool_cfg.max_age_secs,
55 fill_batch = pool_cfg.fill_batch,
56 fill_delay_ms = pool_cfg.fill_delay_ms,
57 "fallback connection pool enabled"
58 );
59 let pool = Arc::new(ConnectionPool::new(
60 fallback_addr,
61 pool_cfg.max_idle,
62 pool_cfg.max_age_secs,
63 pool_cfg.fill_batch,
64 pool_cfg.fill_delay_ms,
65 ));
66 pool.start_cleanup_task(Duration::from_secs(pool_cfg.max_age_secs));
68 pool
69 });
70
71 let (relay_buffer_size, tcp_send_buffer, tcp_recv_buffer, connection_backlog) =
73 match &config.server.resource_limits {
74 Some(rl) => {
75 info!(
76 relay_buffer = rl.relay_buffer_size,
77 tcp_send_buffer = rl.tcp_send_buffer,
78 tcp_recv_buffer = rl.tcp_recv_buffer,
79 connection_backlog = rl.connection_backlog,
80 "resource limits configured"
81 );
82 (
83 rl.relay_buffer_size,
84 rl.tcp_send_buffer,
85 rl.tcp_recv_buffer,
86 rl.connection_backlog,
87 )
88 }
89 None => (
90 defaults::DEFAULT_RELAY_BUFFER_SIZE,
91 defaults::DEFAULT_TCP_SEND_BUFFER,
92 defaults::DEFAULT_TCP_RECV_BUFFER,
93 defaults::DEFAULT_CONNECTION_BACKLOG,
94 ),
95 };
96
97 let state = Arc::new(ServerState {
98 fallback_addr,
99 max_udp_payload: config.server.max_udp_payload,
100 max_udp_buffer_bytes: config.server.max_udp_buffer_bytes,
101 max_header_bytes: config.server.max_header_bytes,
102 tcp_idle_timeout: Duration::from_secs(config.server.tcp_idle_timeout_secs),
103 udp_idle_timeout: Duration::from_secs(config.server.udp_timeout_secs),
104 fallback_pool,
105 relay_buffer_size,
106 tcp_send_buffer,
107 tcp_recv_buffer,
108 });
109 let auth = Arc::new(auth);
110 let tracker = ConnectionTracker::new();
111
112 let conn_limit: Option<Arc<Semaphore>> = config.server.max_connections.map(|n| {
114 info!("max_connections set to {}", n);
115 Arc::new(Semaphore::new(n))
116 });
117
118 let rate_limiter: Option<Arc<RateLimiter>> = config.server.rate_limit.as_ref().map(|rl| {
120 info!(
121 max_per_ip = rl.max_connections_per_ip,
122 window_secs = rl.window_secs,
123 "rate limiting enabled"
124 );
125 let limiter = Arc::new(RateLimiter::new(rl.max_connections_per_ip, rl.window_secs));
126 limiter.start_cleanup_task(Duration::from_secs(rl.cleanup_interval_secs));
127 limiter
128 });
129
130 let listener = create_listener(listen, connection_backlog)?;
132 info!(address = %listen, backlog = connection_backlog, "listening");
133
134 loop {
135 tokio::select! {
136 biased;
137
138 _ = shutdown.cancelled() => {
139 info!("shutdown signal received, stopping accept loop");
140 break;
141 }
142
143 result = listener.accept() => {
144 let (tcp, peer) = result?;
145
146 if let Some(ref sem) = conn_limit {
148 let available = sem.available_permits();
149 set_connection_queue_depth(available as f64);
150 }
151
152 if let Some(ref limiter) = rate_limiter {
154 let ip = peer.ip();
155 if !limiter.check_and_increment(ip) {
156 debug!(peer = %peer, reason = "rate_limit", "connection rejected");
157 record_connection_rejected("rate_limit");
158 drop(tcp);
159 continue;
160 }
161 }
162
163 let permit: Option<OwnedSemaphorePermit> = match &conn_limit {
165 Some(sem) => match sem.clone().try_acquire_owned() {
166 Ok(p) => Some(p),
167 Err(_) => {
168 debug!(peer = %peer, reason = "max_connections", "connection rejected");
169 record_connection_rejected("max_connections");
170 drop(tcp); continue;
172 }
173 },
174 None => None,
175 };
176
177 debug!(peer = %peer, "new connection");
178
179 let acceptor = acceptor.clone();
180 let state = state.clone();
181 let auth = auth.clone();
182 tracker.increment();
183 let guard = ConnectionGuard::new(tracker.clone());
184
185 tokio::spawn(async move {
186 let _guard = guard; let _permit = permit; record_connection_accepted();
189 let start = Instant::now();
190
191 let result = async {
192 let tls_start = Instant::now();
194 let tls_timeout = Duration::from_secs(defaults::DEFAULT_TLS_HANDSHAKE_TIMEOUT_SECS);
195 match tokio::time::timeout(tls_timeout, acceptor.accept(tcp)).await {
196 Ok(Ok(tls)) => {
197 let tls_duration = tls_start.elapsed().as_secs_f64();
198 record_tls_handshake_duration(tls_duration);
199 debug!(peer = %peer, duration_ms = tls_duration * 1000.0, "TLS handshake completed");
200 handle_conn(tls, state, auth, peer).await
201 }
202 Ok(Err(err)) => {
203 record_error(ERROR_TLS_HANDSHAKE);
204 warn!(peer = %peer, error = %err, "TLS handshake failed");
205 Ok(())
206 }
207 Err(_) => {
208 record_error(ERROR_TLS_HANDSHAKE);
209 warn!(peer = %peer, timeout_secs = tls_timeout.as_secs(), "TLS handshake timed out");
210 Ok(())
211 }
212 }
213 }
214 .await;
215
216 let duration_secs = start.elapsed().as_secs_f64();
217 record_connection_closed(duration_secs);
218
219 if let Err(ref err) = result {
220 record_error(err.error_type());
221 warn!(peer = %peer, duration_secs, error = %err, "connection closed with error");
222 } else {
223 debug!(peer = %peer, duration_secs, "connection closed");
224 }
225 });
226 }
227 }
228 }
229
230 if let Some(ref limiter) = rate_limiter {
232 limiter.shutdown();
233 }
234
235 let active = tracker.count();
237 if active > 0 {
238 info!("waiting for {} active connections to drain", active);
239 if tracker.wait_for_zero(DEFAULT_SHUTDOWN_TIMEOUT).await {
240 info!("all connections drained");
241 } else {
242 warn!(
243 "shutdown timeout, {} connections still active",
244 tracker.count()
245 );
246 }
247 }
248
249 info!("server stopped");
250 Ok(())
251}
252
253pub async fn run(config: Config, auth: impl AuthBackend + 'static) -> Result<(), ServerError> {
256 run_with_shutdown(config, auth, CancellationToken::new()).await
257}