1use std::net::SocketAddr;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::time::Duration;
7
8use tokio::sync::{OwnedSemaphorePermit, Semaphore};
9use tokio::time::Instant;
10use tokio_rustls::TlsAcceptor;
11use tokio_util::sync::CancellationToken;
12use tracing::{Instrument, debug, info, info_span, warn};
13
14use crate::error::ServerError;
15use crate::handler::handle_conn;
16use crate::pool::ConnectionPool;
17use crate::rate_limit::RateLimiter;
18use crate::resolve::resolve_sockaddr;
19use crate::state::ServerState;
20use crate::tls::load_tls_config;
21use crate::util::{ConnectionGuard, ConnectionTracker, apply_tcp_options, create_listener};
22use trojan_auth::AuthBackend;
23use trojan_config::Config;
24use trojan_core::defaults;
25use trojan_metrics::{
26 ERROR_TLS_HANDSHAKE, record_connection_accepted, record_connection_closed,
27 record_connection_rejected, record_error, record_tls_handshake_duration,
28 set_connection_queue_depth,
29};
30
/// Upper bound on how long `run_with_shutdown` waits for active connections to
/// drain once its accept loop has stopped.
pub const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(30);
33
/// Process-wide counter backing [`next_conn_id`]; the first id handed out is 1.
static CONN_ID: AtomicU64 = AtomicU64::new(1);

/// Hands out the next connection id.
///
/// Returns the counter's value from before the increment, so successive calls
/// yield consecutive ids. The id is only used as a label, so the increment uses
/// `Ordering::Relaxed` and does not order any other memory accesses.
#[inline]
fn next_conn_id() -> u64 {
    AtomicU64::fetch_add(&CONN_ID, 1, Ordering::Relaxed)
}
42
43pub async fn run_with_shutdown(
45 config: Config,
46 auth: impl AuthBackend + 'static,
47 shutdown: CancellationToken,
48) -> Result<(), ServerError> {
49 let tls_config = load_tls_config(&config.tls)?;
50 let acceptor = TlsAcceptor::from(Arc::new(tls_config));
51
52 let listen: SocketAddr = config
53 .server
54 .listen
55 .parse()
56 .map_err(|_| ServerError::Config("invalid listen address".into()))?;
57
58 let fallback_addr =
59 resolve_sockaddr(&config.server.fallback, config.server.tcp.prefer_ipv4).await?;
60
61 let fallback_pool: Option<Arc<ConnectionPool>> =
63 config.server.fallback_pool.as_ref().map(|pool_cfg| {
64 info!(
65 max_idle = pool_cfg.max_idle,
66 max_age_secs = pool_cfg.max_age_secs,
67 fill_batch = pool_cfg.fill_batch,
68 fill_delay_ms = pool_cfg.fill_delay_ms,
69 "fallback connection pool enabled"
70 );
71 let pool = Arc::new(ConnectionPool::new(
72 fallback_addr,
73 pool_cfg.max_idle,
74 pool_cfg.max_age_secs,
75 pool_cfg.fill_batch,
76 pool_cfg.fill_delay_ms,
77 ));
78 pool.start_cleanup_task(Duration::from_secs(pool_cfg.max_age_secs));
80 pool
81 });
82
83 let (relay_buffer_size, tcp_send_buffer, tcp_recv_buffer, connection_backlog) =
85 match &config.server.resource_limits {
86 Some(rl) => {
87 info!(
88 relay_buffer = rl.relay_buffer_size,
89 tcp_send_buffer = rl.tcp_send_buffer,
90 tcp_recv_buffer = rl.tcp_recv_buffer,
91 connection_backlog = rl.connection_backlog,
92 "resource limits configured"
93 );
94 (
95 rl.relay_buffer_size,
96 rl.tcp_send_buffer,
97 rl.tcp_recv_buffer,
98 rl.connection_backlog,
99 )
100 }
101 None => (
102 defaults::DEFAULT_RELAY_BUFFER_SIZE,
103 defaults::DEFAULT_TCP_SEND_BUFFER,
104 defaults::DEFAULT_TCP_RECV_BUFFER,
105 defaults::DEFAULT_CONNECTION_BACKLOG,
106 ),
107 };
108
109 #[cfg(feature = "analytics")]
111 let analytics = if config.analytics.enabled {
112 match trojan_analytics::init(config.analytics.clone()).await {
113 Ok(collector) => {
114 info!("analytics enabled, sending to ClickHouse");
115 Some(collector)
116 }
117 Err(e) => {
118 warn!("failed to init analytics: {}, disabled", e);
119 None
120 }
121 }
122 } else {
123 debug!("analytics disabled in config");
124 None
125 };
126
127 #[cfg(feature = "rules")]
129 let rule_engine = if !config.server.rules.is_empty() {
130 match crate::rules::build_rule_engine(&config.server) {
131 Ok(engine) => {
132 info!(
133 rule_sets = engine.rule_set_count(),
134 rules = engine.rule_count(),
135 "rule engine initialized"
136 );
137 Some(Arc::new(trojan_rules::HotRuleEngine::new(engine)))
138 }
139 Err(e) => {
140 return Err(ServerError::Rules(format!("failed to init rules: {e}")));
141 }
142 }
143 } else {
144 debug!("no routing rules configured");
145 None
146 };
147
148 #[cfg(feature = "rules")]
150 if let Some(ref hot_engine) = rule_engine
151 && crate::rules::has_http_providers(&config.server)
152 {
153 let interval_secs = crate::rules::http_update_interval(&config.server).unwrap_or(3600); let engine_ref = hot_engine.clone();
155 let server_cfg = config.server.clone();
156 let update_shutdown = shutdown.clone();
157 info!(interval_secs, "starting background rule update task");
158 tokio::spawn(async move {
159 rule_update_loop(engine_ref, server_cfg, interval_secs, update_shutdown).await;
160 });
161 }
162
163 #[cfg(feature = "rules")]
165 let outbounds = {
166 let mut map = std::collections::HashMap::new();
167 for (name, outbound_cfg) in &config.server.outbounds {
168 match crate::outbound::Outbound::from_config(name, outbound_cfg) {
169 Ok(outbound) => {
170 info!(name = %name, "outbound connector configured");
171 map.insert(name.clone(), Arc::new(outbound));
172 }
173 Err(e) => {
174 return Err(ServerError::Config(format!("outbound '{name}': {e}")));
175 }
176 }
177 }
178 map
179 };
180
181 #[cfg(feature = "geoip")]
184 #[allow(unused_variables)]
185 let (geoip_server, geoip_metrics, geoip_analytics) =
186 load_geoip_databases(&config, &shutdown).await;
187
188 if let Some(ref listen) = config.metrics.listen {
190 #[cfg(feature = "rules")]
191 let extra_routes = rule_engine
192 .as_ref()
193 .map(|engine| crate::debug_api::debug_routes(engine.clone()));
194 #[cfg(not(feature = "rules"))]
195 let extra_routes: Option<axum::Router> = None;
196
197 match trojan_metrics::init_metrics_server(listen, extra_routes) {
198 Ok(_handle) => {
199 #[cfg(feature = "rules")]
200 let endpoints = if rule_engine.is_some() {
201 "/metrics, /health, /ready, /debug/rules/match"
202 } else {
203 "/metrics, /health, /ready"
204 };
205 #[cfg(not(feature = "rules"))]
206 let endpoints = "/metrics, /health, /ready";
207 info!("metrics server listening on {} ({})", listen, endpoints);
208 }
209 Err(e) => warn!("failed to start metrics server: {}", e),
210 }
211 }
212
213 let tcp_cfg = &config.server.tcp;
215 info!(
216 no_delay = tcp_cfg.no_delay,
217 keepalive_secs = tcp_cfg.keepalive_secs,
218 reuse_port = tcp_cfg.reuse_port,
219 fast_open = tcp_cfg.fast_open,
220 "TCP options configured"
221 );
222
223 let state = Arc::new(ServerState {
224 fallback_addr,
225 max_udp_payload: config.server.max_udp_payload,
226 max_udp_buffer_bytes: config.server.max_udp_buffer_bytes,
227 max_header_bytes: config.server.max_header_bytes,
228 tcp_idle_timeout: Duration::from_secs(config.server.tcp_idle_timeout_secs),
229 udp_idle_timeout: Duration::from_secs(config.server.udp_timeout_secs),
230 fallback_pool,
231 relay_buffer_size,
232 tcp_send_buffer,
233 tcp_recv_buffer,
234 tcp_config: config.server.tcp.clone(),
235 websocket: config.websocket.clone(),
236 #[cfg(feature = "analytics")]
237 analytics,
238 #[cfg(feature = "rules")]
239 rule_engine,
240 #[cfg(feature = "rules")]
241 outbounds,
242 #[cfg(feature = "geoip")]
243 geoip_metrics,
244 #[cfg(all(feature = "geoip", feature = "analytics"))]
245 geoip_analytics,
246 });
247 let auth = Arc::new(auth);
248 let tracker = ConnectionTracker::new();
249
250 let conn_limit: Option<Arc<Semaphore>> = config.server.max_connections.map(|n| {
252 info!("max_connections set to {}", n);
253 Arc::new(Semaphore::new(n))
254 });
255
256 let rate_limiter: Option<Arc<RateLimiter>> = config.server.rate_limit.as_ref().map(|rl| {
258 info!(
259 max_per_ip = rl.max_connections_per_ip,
260 window_secs = rl.window_secs,
261 "rate limiting enabled"
262 );
263 let limiter = Arc::new(RateLimiter::new(rl.max_connections_per_ip, rl.window_secs));
264 limiter.start_cleanup_task(Duration::from_secs(rl.cleanup_interval_secs));
265 limiter
266 });
267
268 let listener = create_listener(listen, connection_backlog, &config.server.tcp)?;
270 info!(address = %listen, backlog = connection_backlog, "listening");
271
272 #[cfg(feature = "ws")]
273 if config.websocket.enabled && config.websocket.mode == "split" {
274 let ws_listen = config.websocket.listen.clone().unwrap_or_default();
275 let ws_addr: SocketAddr = ws_listen
276 .parse()
277 .map_err(|_| ServerError::Config("invalid websocket.listen address".into()))?;
278 let ws_listener = create_listener(ws_addr, connection_backlog, &config.server.tcp)?;
279 let ws_acceptor = acceptor.clone();
280 let ws_state = state.clone();
281 let ws_auth = auth.clone();
282 let ws_tracker = tracker.clone();
283 let ws_conn_limit = conn_limit.clone();
284 let ws_rate_limiter = rate_limiter.clone();
285 let ws_shutdown = shutdown.clone();
286
287 info!(address = %ws_addr, "websocket split listener started");
288 tokio::spawn(async move {
289 loop {
290 tokio::select! {
291 biased;
292 _ = ws_shutdown.cancelled() => break,
293 result = ws_listener.accept() => {
294 let (tcp, peer) = match result {
295 Ok(v) => v,
296 Err(_) => continue,
297 };
298
299 if let Err(e) = apply_tcp_options(&tcp, &ws_state.tcp_config) {
301 tracing::debug!(error = %e, "failed to apply TCP options");
302 }
303
304 if let Some(ref limiter) = ws_rate_limiter {
305 let ip = peer.ip();
306 if !limiter.check_and_increment(ip) {
307 record_connection_rejected("rate_limit");
308 drop(tcp);
309 continue;
310 }
311 }
312
313 let permit: Option<OwnedSemaphorePermit> = match &ws_conn_limit {
314 Some(sem) => match sem.clone().try_acquire_owned() {
315 Ok(p) => Some(p),
316 Err(_) => {
317 record_connection_rejected("max_connections");
318 drop(tcp);
319 continue;
320 }
321 },
322 None => None,
323 };
324
325 let conn_id = next_conn_id();
326 let acceptor = ws_acceptor.clone();
327 let state = ws_state.clone();
328 let auth = ws_auth.clone();
329 ws_tracker.increment();
330 let guard = ConnectionGuard::new(ws_tracker.clone());
331
332 let span = info_span!("conn", id = conn_id, peer = %peer, transport = "ws");
333 tokio::spawn(
334 async move {
335 let _guard = guard;
336 let _permit = permit;
337 record_connection_accepted();
338 let start = Instant::now();
339
340 let result = async {
341 let tls_start = Instant::now();
342 let tls_timeout =
343 Duration::from_secs(defaults::DEFAULT_TLS_HANDSHAKE_TIMEOUT_SECS);
344 match tokio::time::timeout(tls_timeout, acceptor.accept(tcp)).await
345 {
346 Ok(Ok(tls)) => {
347 let tls_duration = tls_start.elapsed().as_secs_f64();
348 record_tls_handshake_duration(tls_duration);
349 crate::handler::handle_ws_only(tls, state, auth, peer).await
350 }
351 Ok(Err(err)) => {
352 record_error(ERROR_TLS_HANDSHAKE);
353 warn!(error = %err, "TLS handshake failed");
354 Ok(())
355 }
356 Err(_) => {
357 record_error(ERROR_TLS_HANDSHAKE);
358 warn!(
359 timeout_secs = tls_timeout.as_secs(),
360 "TLS handshake timed out"
361 );
362 Ok(())
363 }
364 }
365 }
366 .await;
367
368 let duration_secs = start.elapsed().as_secs_f64();
369 record_connection_closed(duration_secs);
370
371 if let Err(ref err) = result {
372 warn!(error = %err, "connection error");
373 }
374 }
375 .instrument(span),
376 );
377 }
378 }
379 }
380 });
381 }
382
383 #[cfg(not(feature = "ws"))]
384 if config.websocket.enabled {
385 warn!("websocket.enabled=true but ws feature is disabled; ignoring websocket");
386 }
387
388 loop {
389 tokio::select! {
390 biased;
391
392 _ = shutdown.cancelled() => {
393 info!("shutdown signal received, stopping accept loop");
394 break;
395 }
396
397 result = listener.accept() => {
398 let (tcp, peer) = result?;
399
400 if let Err(e) = apply_tcp_options(&tcp, &state.tcp_config) {
402 debug!(error = %e, "failed to apply TCP options");
403 }
404
405 if let Some(ref sem) = conn_limit {
407 let available = sem.available_permits();
408 set_connection_queue_depth(available as f64);
409 }
410
411 if let Some(ref limiter) = rate_limiter {
413 let ip = peer.ip();
414 if !limiter.check_and_increment(ip) {
415 debug!(peer = %peer, reason = "rate_limit", "connection rejected");
416 record_connection_rejected("rate_limit");
417 drop(tcp);
418 continue;
419 }
420 }
421
422 let permit: Option<OwnedSemaphorePermit> = match &conn_limit {
424 Some(sem) => match sem.clone().try_acquire_owned() {
425 Ok(p) => Some(p),
426 Err(_) => {
427 debug!(peer = %peer, reason = "max_connections", "connection rejected");
428 record_connection_rejected("max_connections");
429 drop(tcp); continue;
431 }
432 },
433 None => None,
434 };
435
436 let conn_id = next_conn_id();
437 debug!(conn_id, peer = %peer, "new connection");
438
439 let acceptor = acceptor.clone();
440 let state = state.clone();
441 let auth = auth.clone();
442 tracker.increment();
443 let guard = ConnectionGuard::new(tracker.clone());
444
445 let span = info_span!("conn", id = conn_id, peer = %peer);
446 tokio::spawn(
447 async move {
448 let _guard = guard; let _permit = permit; record_connection_accepted();
451 let start = Instant::now();
452
453 let result = async {
454 let tls_start = Instant::now();
456 let tls_timeout =
457 Duration::from_secs(defaults::DEFAULT_TLS_HANDSHAKE_TIMEOUT_SECS);
458 match tokio::time::timeout(tls_timeout, acceptor.accept(tcp)).await {
459 Ok(Ok(tls)) => {
460 let tls_duration = tls_start.elapsed().as_secs_f64();
461 record_tls_handshake_duration(tls_duration);
462 debug!(duration_ms = tls_duration * 1000.0, "TLS handshake completed");
463 handle_conn(tls, state, auth, peer).await
464 }
465 Ok(Err(err)) => {
466 record_error(ERROR_TLS_HANDSHAKE);
467 warn!(error = %err, "TLS handshake failed");
468 Ok(())
469 }
470 Err(_) => {
471 record_error(ERROR_TLS_HANDSHAKE);
472 warn!(timeout_secs = tls_timeout.as_secs(), "TLS handshake timed out");
473 Ok(())
474 }
475 }
476 }
477 .await;
478
479 let duration_secs = start.elapsed().as_secs_f64();
480 record_connection_closed(duration_secs);
481
482 if let Err(ref err) = result {
483 record_error(err.error_type());
484 warn!(duration_secs, error = %err, "connection closed with error");
485 } else {
486 debug!(duration_secs, "connection closed");
487 }
488 }
489 .instrument(span),
490 );
491 }
492 }
493 }
494
495 if let Some(ref limiter) = rate_limiter {
497 limiter.shutdown();
498 }
499
500 let active = tracker.count();
502 if active > 0 {
503 info!("waiting for {} active connections to drain", active);
504 if tracker.wait_for_zero(DEFAULT_SHUTDOWN_TIMEOUT).await {
505 info!("all connections drained");
506 } else {
507 warn!(
508 "shutdown timeout, {} connections still active",
509 tracker.count()
510 );
511 }
512 }
513
514 info!("server stopped");
515 Ok(())
516}
517
518pub async fn run(config: Config, auth: impl AuthBackend + 'static) -> Result<(), ServerError> {
521 run_with_shutdown(config, auth, CancellationToken::new()).await
522}
523
524#[cfg(feature = "geoip")]
532#[allow(unused_variables)]
533async fn load_geoip_databases(
534 config: &Config,
535 shutdown: &CancellationToken,
536) -> (
537 Option<Arc<trojan_rules::geoip_db::GeoipDb>>,
538 Option<Arc<trojan_rules::geoip_db::GeoipDb>>,
539 Option<Arc<trojan_rules::geoip_db::GeoipDb>>,
540) {
541 use std::collections::HashMap;
542 use trojan_rules::geoip_db::GeoipDb;
543
544 type Key = (Option<String>, Option<String>, String);
546 let mut loaded: HashMap<Key, Arc<GeoipDb>> = HashMap::new();
547
548 let mut auto_update_configs: Vec<(trojan_config::GeoipConfig, Arc<GeoipDb>)> = Vec::new();
550
551 async fn load_or_share(
553 cfg: &trojan_config::GeoipConfig,
554 loaded: &mut HashMap<Key, Arc<GeoipDb>>,
555 ) -> Option<Arc<GeoipDb>> {
556 let key: Key = (cfg.path.clone(), cfg.url.clone(), cfg.source.clone());
557 if let Some(existing) = loaded.get(&key) {
558 return Some(existing.clone());
559 }
560 match trojan_rules::geoip_db::load_geoip(cfg).await {
561 Ok(db) => {
562 let arc = Arc::new(db);
563 loaded.insert(key, arc.clone());
564 Some(arc)
565 }
566 Err(e) => {
567 warn!(source = %cfg.source, error = %e, "failed to load GeoIP database");
568 None
569 }
570 }
571 }
572
573 let server_geoip = if let Some(cfg) = config.server.geoip.as_ref() {
575 load_or_share(cfg, &mut loaded).await
576 } else {
577 None
578 };
579
580 let metrics_geoip = if let Some(cfg) = config.metrics.geoip.as_ref() {
582 let result = load_or_share(cfg, &mut loaded).await;
583 if let Some(ref db) = result
584 && cfg.auto_update
585 && cfg.path.is_none()
586 {
587 auto_update_configs.push((cfg.clone(), db.clone()));
588 }
589 result
590 } else {
591 server_geoip.clone() };
593
594 #[cfg(feature = "analytics")]
596 let analytics_geoip = if let Some(cfg) = config.analytics.geoip.as_ref() {
597 let result = load_or_share(cfg, &mut loaded).await;
598 if let Some(ref db) = result
599 && cfg.auto_update
600 && cfg.path.is_none()
601 {
602 auto_update_configs.push((cfg.clone(), db.clone()));
603 }
604 result
605 } else {
606 None
607 };
608 #[cfg(not(feature = "analytics"))]
609 let analytics_geoip: Option<Arc<GeoipDb>> = None;
610
611 if !loaded.is_empty() {
612 info!(
613 databases = loaded.len(),
614 "GeoIP databases loaded (deduplicated)"
615 );
616 }
617
618 {
620 let mut seen_ptrs = std::collections::HashSet::new();
622 for (cfg, db) in auto_update_configs {
623 let ptr = Arc::as_ptr(&db) as usize;
624 if !seen_ptrs.insert(ptr) {
625 continue; }
627 let cancel = shutdown.clone();
628 let source = cfg.source.clone();
629 info!(source = %source, "spawning GeoIP auto-update task");
630 let swappable = Arc::new(arc_swap::ArcSwap::from(db));
631 tokio::spawn(trojan_rules::geoip_db::geoip_auto_update_task(
632 cfg,
633 swappable,
634 cancel,
635 move |success| {
636 if success {
637 trojan_metrics::record_rule_update();
638 } else {
639 trojan_metrics::record_rule_update_error();
640 }
641 },
642 ));
643 }
644 }
645
646 (server_geoip, metrics_geoip, analytics_geoip)
647}
648
/// Keeps `engine` up to date by rebuilding the rule engine from `server_config`.
///
/// Performs one rebuild immediately, then one every `interval_secs` seconds until
/// `shutdown` is cancelled. A successful rebuild is swapped into `engine` and
/// recorded with `record_rule_update`; a failed one leaves the current rules in
/// place and is recorded with `record_rule_update_error`.
///
/// An `interval_secs` of `0` is replaced by a one-hour period (with a warning),
/// because `tokio::time::interval` panics when given a zero period.
#[cfg(feature = "rules")]
async fn rule_update_loop(
    engine: Arc<trojan_rules::HotRuleEngine>,
    server_config: trojan_config::ServerConfig,
    interval_secs: u64,
    shutdown: CancellationToken,
) {
    use std::time::Duration;
    use trojan_metrics::{record_rule_update, record_rule_update_error};

    /// Period used when the caller passes `interval_secs == 0`.
    const FALLBACK_INTERVAL_SECS: u64 = 3600;

    // Guard against a zero period, which would panic in `tokio::time::interval`
    // below and silently end this task right after the initial fetch.
    let interval_secs = if interval_secs == 0 {
        warn!(
            fallback_secs = FALLBACK_INTERVAL_SECS,
            "rule update interval is 0, using fallback interval"
        );
        FALLBACK_INTERVAL_SECS
    } else {
        interval_secs
    };

    // Initial fetch, run once before the periodic schedule starts.
    match crate::rules::build_rule_engine_async(&server_config).await {
        Ok(new_engine) => {
            info!(
                rule_sets = new_engine.rule_set_count(),
                rules = new_engine.rule_count(),
                "initial rule fetch completed, engine updated"
            );
            engine.update(new_engine);
            record_rule_update();
        }
        Err(e) => {
            warn!(error = %e, "initial rule fetch failed, keeping startup rules");
            record_rule_update_error();
        }
    }

    let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));
    // The first tick of a tokio interval completes immediately; consume it so the
    // next rebuild happens one full period after the initial fetch.
    interval.tick().await;

    loop {
        tokio::select! {
            biased;
            _ = shutdown.cancelled() => {
                debug!("rule update task shutting down");
                return;
            }
            _ = interval.tick() => {
                debug!("starting scheduled rule update");
                match crate::rules::build_rule_engine_async(&server_config).await {
                    Ok(new_engine) => {
                        info!(
                            rule_sets = new_engine.rule_set_count(),
                            rules = new_engine.rule_count(),
                            "rule update completed, engine swapped"
                        );
                        engine.update(new_engine);
                        record_rule_update();
                    }
                    Err(e) => {
                        warn!(error = %e, "rule update failed, keeping current rules");
                        record_rule_update_error();
                    }
                }
            }
        }
    }
}