1use std::net::SocketAddr;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::time::Duration;
7
8use tokio::sync::{OwnedSemaphorePermit, Semaphore};
9use tokio::time::Instant;
10use tokio_rustls::TlsAcceptor;
11use tokio_util::sync::CancellationToken;
12use tracing::{Instrument, debug, info, info_span, warn};
13
14use crate::error::ServerError;
15use crate::handler::handle_conn;
16use crate::pool::ConnectionPool;
17use crate::rate_limit::RateLimiter;
18use crate::resolve::resolve_sockaddr;
19use crate::state::ServerState;
20use crate::tls::load_tls_config;
21use crate::util::{ConnectionGuard, ConnectionTracker, apply_tcp_options, create_listener};
22use trojan_auth::AuthBackend;
23use trojan_config::Config;
24use trojan_core::defaults;
25use trojan_dns::DnsResolver;
26use trojan_metrics::{
27 ERROR_TLS_HANDSHAKE, record_connection_accepted, record_connection_closed,
28 record_connection_rejected, record_error, record_tls_handshake_duration,
29 set_connection_queue_depth,
30};
31
/// How long the shutdown path waits for in-flight connections to drain
/// after the accept loop has stopped.
pub const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(30);

/// Process-wide counter backing [`next_conn_id`]. Starts at 1 so that id 0
/// never shows up in connection tracing spans.
static CONN_ID: AtomicU64 = AtomicU64::new(1);

/// Hand out the next unique connection id for log/span correlation.
///
/// `Relaxed` ordering is sufficient here: the counter only has to produce
/// unique values, it does not synchronize any other memory.
#[inline]
fn next_conn_id() -> u64 {
    CONN_ID.fetch_add(1, Ordering::Relaxed)
}
43
/// Run the server until `shutdown` is cancelled.
///
/// Startup order: load TLS, parse the listen address, build the DNS
/// resolver, resolve the fallback address, then construct optional
/// subsystems (fallback pool, analytics, rule engine, outbounds, GeoIP,
/// metrics endpoint, DDNS, WebSocket split listener — each behind its
/// cargo feature and/or config flag), build the shared `ServerState`, and
/// finally accept TCP connections in a loop. Each accepted connection is
/// handled on its own task; on cancellation the accept loop stops and
/// active connections get [`DEFAULT_SHUTDOWN_TIMEOUT`] to drain.
///
/// # Errors
///
/// Returns `ServerError` for invalid configuration (listen address, TLS,
/// DNS resolver, rules, outbounds, websocket listen address), or if
/// `listener.accept()` on the main listener fails.
pub async fn run_with_shutdown(
    config: Config,
    auth: impl AuthBackend + 'static,
    shutdown: CancellationToken,
) -> Result<(), ServerError> {
    let tls_config = load_tls_config(&config.tls)?;
    let acceptor = TlsAcceptor::from(Arc::new(tls_config));

    let listen: SocketAddr = config
        .server
        .listen
        .parse()
        .map_err(|_| ServerError::Config("invalid listen address".into()))?;

    // Back-compat shim: the deprecated `server.tcp.prefer_ipv4` flag is
    // folded into the newer `dns.prefer_ipv4` option rather than honoured
    // directly.
    let mut dns_config = config.dns.clone();
    if config.server.tcp.prefer_ipv4 && !dns_config.prefer_ipv4 {
        dns_config.prefer_ipv4 = true;
        info!(
            "server.tcp.prefer_ipv4 is deprecated; mapped to dns.prefer_ipv4 for backward compatibility"
        );
    }
    let dns_resolver = DnsResolver::new(&dns_config)
        .map_err(|e| ServerError::Config(format!("dns resolver: {e}")))?;
    info!(
        dns = ?dns_config.strategy,
        prefer_ipv4 = dns_config.prefer_ipv4,
        "dns resolver initialized"
    );

    // Fallback target is resolved once at startup, not per connection.
    let fallback_addr = resolve_sockaddr(&config.server.fallback, &dns_resolver).await?;

    // Optional pool of pre-opened connections to the fallback target; its
    // cleanup task is started immediately with the configured max age.
    let fallback_pool: Option<Arc<ConnectionPool>> =
        config.server.fallback_pool.as_ref().map(|pool_cfg| {
            info!(
                max_idle = pool_cfg.max_idle,
                max_age_secs = pool_cfg.max_age_secs,
                fill_batch = pool_cfg.fill_batch,
                fill_delay_ms = pool_cfg.fill_delay_ms,
                "fallback connection pool enabled"
            );
            let pool = Arc::new(ConnectionPool::new(
                fallback_addr,
                pool_cfg.max_idle,
                pool_cfg.max_age_secs,
                pool_cfg.fill_batch,
                pool_cfg.fill_delay_ms,
            ));
            pool.start_cleanup_task(Duration::from_secs(pool_cfg.max_age_secs));
            pool
        });

    // Resource limits fall back to crate defaults when the config section
    // is absent.
    let (relay_buffer_size, tcp_send_buffer, tcp_recv_buffer, connection_backlog) =
        match &config.server.resource_limits {
            Some(rl) => {
                info!(
                    relay_buffer = rl.relay_buffer_size,
                    tcp_send_buffer = rl.tcp_send_buffer,
                    tcp_recv_buffer = rl.tcp_recv_buffer,
                    connection_backlog = rl.connection_backlog,
                    "resource limits configured"
                );
                (
                    rl.relay_buffer_size,
                    rl.tcp_send_buffer,
                    rl.tcp_recv_buffer,
                    rl.connection_backlog,
                )
            }
            None => (
                defaults::DEFAULT_RELAY_BUFFER_SIZE,
                defaults::DEFAULT_TCP_SEND_BUFFER,
                defaults::DEFAULT_TCP_RECV_BUFFER,
                defaults::DEFAULT_CONNECTION_BACKLOG,
            ),
        };

    // Analytics init failure is deliberately non-fatal: the server runs
    // without analytics rather than refusing to start.
    #[cfg(feature = "analytics")]
    let analytics = if config.analytics.enabled {
        match trojan_analytics::init(config.analytics.clone()).await {
            Ok(collector) => {
                info!("analytics enabled, sending to ClickHouse");
                Some(collector)
            }
            Err(e) => {
                warn!("failed to init analytics: {}, disabled", e);
                None
            }
        }
    } else {
        debug!("analytics disabled in config");
        None
    };

    // Rule-engine init failure IS fatal, unlike analytics: bad routing
    // rules would otherwise silently misroute traffic.
    #[cfg(feature = "rules")]
    let rule_engine = if !config.server.rules.is_empty() {
        match crate::rules::build_rule_engine(&config.server) {
            Ok(engine) => {
                info!(
                    rule_sets = engine.rule_set_count(),
                    rules = engine.rule_count(),
                    "rule engine initialized"
                );
                Some(Arc::new(trojan_rules::HotRuleEngine::new(engine)))
            }
            Err(e) => {
                return Err(ServerError::Rules(format!("failed to init rules: {e}")));
            }
        }
    } else {
        debug!("no routing rules configured");
        None
    };

    // If any rule set comes from an HTTP provider, refresh it periodically
    // in the background (default: hourly).
    #[cfg(feature = "rules")]
    if let Some(ref hot_engine) = rule_engine
        && crate::rules::has_http_providers(&config.server)
    {
        let interval_secs = crate::rules::http_update_interval(&config.server).unwrap_or(3600);
        let engine_ref = hot_engine.clone();
        let server_cfg = config.server.clone();
        let update_shutdown = shutdown.clone();
        info!(interval_secs, "starting background rule update task");
        tokio::spawn(async move {
            rule_update_loop(engine_ref, server_cfg, interval_secs, update_shutdown).await;
        });
    }

    // Named outbound connectors referenced by routing rules. Any invalid
    // outbound is a fatal config error.
    #[cfg(feature = "rules")]
    let outbounds = {
        let mut map = std::collections::HashMap::new();
        for (name, outbound_cfg) in &config.server.outbounds {
            match crate::outbound::Outbound::from_config(name, outbound_cfg) {
                Ok(outbound) => {
                    info!(name = %name, "outbound connector configured");
                    map.insert(name.clone(), Arc::new(outbound));
                }
                Err(e) => {
                    return Err(ServerError::Config(format!("outbound '{name}': {e}")));
                }
            }
        }
        map
    };

    // `geoip_server` is not stored in ServerState below, hence the
    // allow(unused_variables) on the destructuring let.
    #[cfg(feature = "geoip")]
    #[allow(unused_variables)]
    let (geoip_server, geoip_metrics, geoip_analytics) =
        load_geoip_databases(&config, &shutdown).await;

    // Optional metrics/debug HTTP endpoint. `listen` here shadows the main
    // listen address only inside this if-let scope. Failure to start it is
    // non-fatal.
    if let Some(ref listen) = config.metrics.listen {
        #[cfg(feature = "rules")]
        let extra_routes = rule_engine
            .as_ref()
            .map(|engine| crate::debug_api::debug_routes(engine.clone()));
        #[cfg(not(feature = "rules"))]
        let extra_routes: Option<axum::Router> = None;

        match trojan_metrics::init_metrics_server(listen, extra_routes) {
            Ok(_handle) => {
                #[cfg(feature = "rules")]
                let endpoints = if rule_engine.is_some() {
                    "/metrics, /health, /ready, /debug/rules/match"
                } else {
                    "/metrics, /health, /ready"
                };
                #[cfg(not(feature = "rules"))]
                let endpoints = "/metrics, /health, /ready";
                info!("metrics server listening on {} ({})", listen, endpoints);
            }
            Err(e) => warn!("failed to start metrics server: {}", e),
        }
    }

    // Background dynamic-DNS updater, cancelled via the same shutdown token.
    #[cfg(feature = "ddns")]
    if config.ddns.enabled {
        let ddns_config = config.ddns.clone();
        let ddns_shutdown = shutdown.clone();
        info!("starting DDNS update task");
        tokio::spawn(async move {
            trojan_ddns::ddns_loop(ddns_config, ddns_shutdown).await;
        });
    }

    let tcp_cfg = &config.server.tcp;
    info!(
        no_delay = tcp_cfg.no_delay,
        keepalive_secs = tcp_cfg.keepalive_secs,
        reuse_port = tcp_cfg.reuse_port,
        fast_open = tcp_cfg.fast_open,
        "TCP options configured"
    );

    // Immutable shared state handed to every connection handler task.
    let state = Arc::new(ServerState {
        fallback_addr,
        max_udp_payload: config.server.max_udp_payload,
        max_udp_buffer_bytes: config.server.max_udp_buffer_bytes,
        max_header_bytes: config.server.max_header_bytes,
        tcp_idle_timeout: Duration::from_secs(config.server.tcp_idle_timeout_secs),
        udp_idle_timeout: Duration::from_secs(config.server.udp_timeout_secs),
        fallback_pool,
        relay_buffer_size,
        tcp_send_buffer,
        tcp_recv_buffer,
        tcp_config: config.server.tcp.clone(),
        websocket: config.websocket.clone(),
        dns_resolver,
        #[cfg(feature = "analytics")]
        analytics,
        #[cfg(feature = "rules")]
        rule_engine,
        #[cfg(feature = "rules")]
        outbounds,
        #[cfg(feature = "geoip")]
        geoip_metrics,
        #[cfg(all(feature = "geoip", feature = "analytics"))]
        geoip_analytics,
    });
    let auth = Arc::new(auth);
    let tracker = ConnectionTracker::new();

    // Global concurrent-connection cap, enforced via semaphore permits that
    // live as long as the handler task.
    let conn_limit: Option<Arc<Semaphore>> = config.server.max_connections.map(|n| {
        info!("max_connections set to {}", n);
        Arc::new(Semaphore::new(n))
    });

    // Per-IP windowed rate limiter with its own periodic cleanup task.
    let rate_limiter: Option<Arc<RateLimiter>> = config.server.rate_limit.as_ref().map(|rl| {
        info!(
            max_per_ip = rl.max_connections_per_ip,
            window_secs = rl.window_secs,
            "rate limiting enabled"
        );
        let limiter = Arc::new(RateLimiter::new(rl.max_connections_per_ip, rl.window_secs));
        limiter.start_cleanup_task(Duration::from_secs(rl.cleanup_interval_secs));
        limiter
    });

    let listener = create_listener(listen, connection_backlog, &config.server.tcp)?;
    info!(address = %listen, backlog = connection_backlog, "listening");

    // Optional dedicated WebSocket listener ("split" mode): its own accept
    // loop on a separate task, sharing the acceptor, state, limits and
    // shutdown token with the main loop below.
    #[cfg(feature = "ws")]
    if config.websocket.enabled && config.websocket.mode == "split" {
        let ws_listen = config.websocket.listen.clone().unwrap_or_default();
        let ws_addr: SocketAddr = ws_listen
            .parse()
            .map_err(|_| ServerError::Config("invalid websocket.listen address".into()))?;
        let ws_listener = create_listener(ws_addr, connection_backlog, &config.server.tcp)?;
        let ws_acceptor = acceptor.clone();
        let ws_state = state.clone();
        let ws_auth = auth.clone();
        let ws_tracker = tracker.clone();
        let ws_conn_limit = conn_limit.clone();
        let ws_rate_limiter = rate_limiter.clone();
        let ws_shutdown = shutdown.clone();

        info!(address = %ws_addr, "websocket split listener started");
        tokio::spawn(async move {
            loop {
                tokio::select! {
                    // `biased` makes the shutdown branch win when both are
                    // ready, so cancellation is never starved by accepts.
                    biased;
                    _ = ws_shutdown.cancelled() => break,
                    result = ws_listener.accept() => {
                        // Unlike the main loop, a transient accept error here
                        // is skipped rather than propagated (there is no Result
                        // channel out of this task).
                        let (tcp, peer) = match result {
                            Ok(v) => v,
                            Err(_) => continue,
                        };

                        if let Err(e) = apply_tcp_options(&tcp, &ws_state.tcp_config) {
                            tracing::debug!(error = %e, "failed to apply TCP options");
                        }

                        // Per-IP rate limit check; rejected sockets are dropped
                        // before any TLS work.
                        if let Some(ref limiter) = ws_rate_limiter {
                            let ip = peer.ip();
                            if !limiter.check_and_increment(ip) {
                                record_connection_rejected("rate_limit");
                                drop(tcp);
                                continue;
                            }
                        }

                        // Global connection cap. The owned permit is moved into
                        // the handler task and released on task end.
                        // NOTE(review): when this rejects, the rate limiter
                        // increment above is not undone — presumably fine for a
                        // windowed counter, but worth confirming.
                        let permit: Option<OwnedSemaphorePermit> = match &ws_conn_limit {
                            Some(sem) => match sem.clone().try_acquire_owned() {
                                Ok(p) => Some(p),
                                Err(_) => {
                                    record_connection_rejected("max_connections");
                                    drop(tcp);
                                    continue;
                                }
                            },
                            None => None,
                        };

                        let conn_id = next_conn_id();
                        let acceptor = ws_acceptor.clone();
                        let state = ws_state.clone();
                        let auth = ws_auth.clone();
                        // Increment the active-connection counter before
                        // spawning; the guard decrements it when the task ends.
                        ws_tracker.increment();
                        let guard = ConnectionGuard::new(ws_tracker.clone());

                        let span = info_span!("conn", id = conn_id, peer = %peer, transport = "ws");
                        tokio::spawn(
                            async move {
                                let _guard = guard;
                                let _permit = permit;
                                record_connection_accepted();
                                let start = Instant::now();

                                let result = async {
                                    // TLS handshake is bounded by a timeout so a
                                    // stalled client cannot pin the task.
                                    let tls_start = Instant::now();
                                    let tls_timeout =
                                        Duration::from_secs(defaults::DEFAULT_TLS_HANDSHAKE_TIMEOUT_SECS);
                                    match tokio::time::timeout(tls_timeout, acceptor.accept(tcp)).await
                                    {
                                        Ok(Ok(tls)) => {
                                            let tls_duration = tls_start.elapsed().as_secs_f64();
                                            record_tls_handshake_duration(tls_duration);
                                            crate::handler::handle_ws_only(tls, state, auth, peer).await
                                        }
                                        // Handshake failure/timeout is recorded
                                        // and logged but treated as a normal
                                        // (Ok) connection outcome.
                                        Ok(Err(err)) => {
                                            record_error(ERROR_TLS_HANDSHAKE);
                                            warn!(error = %err, "TLS handshake failed");
                                            Ok(())
                                        }
                                        Err(_) => {
                                            record_error(ERROR_TLS_HANDSHAKE);
                                            warn!(
                                                timeout_secs = tls_timeout.as_secs(),
                                                "TLS handshake timed out"
                                            );
                                            Ok(())
                                        }
                                    }
                                }
                                .await;

                                let duration_secs = start.elapsed().as_secs_f64();
                                record_connection_closed(duration_secs);

                                if let Err(ref err) = result {
                                    warn!(error = %err, "connection error");
                                }
                            }
                            .instrument(span),
                        );
                    }
                }
            }
        });
    }

    #[cfg(not(feature = "ws"))]
    if config.websocket.enabled {
        warn!("websocket.enabled=true but ws feature is disabled; ignoring websocket");
    }

    // Main accept loop: runs until the shutdown token is cancelled.
    loop {
        tokio::select! {
            // Prefer the shutdown branch when both futures are ready.
            biased;

            _ = shutdown.cancelled() => {
                info!("shutdown signal received, stopping accept loop");
                break;
            }

            result = listener.accept() => {
                // NOTE(review): an accept error here aborts the whole server
                // via `?`, whereas the ws listener above skips and continues —
                // confirm this asymmetry is intentional (e.g. EMFILE would
                // stop the server).
                let (tcp, peer) = result?;

                if let Err(e) = apply_tcp_options(&tcp, &state.tcp_config) {
                    debug!(error = %e, "failed to apply TCP options");
                }

                // Export remaining capacity as a gauge before admission checks.
                if let Some(ref sem) = conn_limit {
                    let available = sem.available_permits();
                    set_connection_queue_depth(available as f64);
                }

                // Per-IP rate limiting: reject before spending TLS work.
                if let Some(ref limiter) = rate_limiter {
                    let ip = peer.ip();
                    if !limiter.check_and_increment(ip) {
                        debug!(peer = %peer, reason = "rate_limit", "connection rejected");
                        record_connection_rejected("rate_limit");
                        drop(tcp);
                        continue;
                    }
                }

                // Global connection cap; the owned permit travels with the
                // handler task and is released when the task finishes.
                let permit: Option<OwnedSemaphorePermit> = match &conn_limit {
                    Some(sem) => match sem.clone().try_acquire_owned() {
                        Ok(p) => Some(p),
                        Err(_) => {
                            debug!(peer = %peer, reason = "max_connections", "connection rejected");
                            record_connection_rejected("max_connections");
                            drop(tcp);
                            continue;
                        }
                    },
                    None => None,
                };

                let conn_id = next_conn_id();
                debug!(conn_id, peer = %peer, "new connection");

                let acceptor = acceptor.clone();
                let state = state.clone();
                let auth = auth.clone();
                // Count the connection before spawning so shutdown drain sees
                // it; the guard decrements on task completion.
                tracker.increment();
                let guard = ConnectionGuard::new(tracker.clone());

                let span = info_span!("conn", id = conn_id, peer = %peer);
                tokio::spawn(
                    async move {
                        let _guard = guard;
                        let _permit = permit;
                        record_connection_accepted();
                        let start = Instant::now();

                        let result = async {
                            // Bounded TLS handshake, same policy as the ws path.
                            let tls_start = Instant::now();
                            let tls_timeout =
                                Duration::from_secs(defaults::DEFAULT_TLS_HANDSHAKE_TIMEOUT_SECS);
                            match tokio::time::timeout(tls_timeout, acceptor.accept(tcp)).await {
                                Ok(Ok(tls)) => {
                                    let tls_duration = tls_start.elapsed().as_secs_f64();
                                    record_tls_handshake_duration(tls_duration);
                                    debug!(duration_ms = tls_duration * 1000.0, "TLS handshake completed");
                                    handle_conn(tls, state, auth, peer).await
                                }
                                // Handshake failure/timeout is logged and
                                // counted but not surfaced as a handler error.
                                Ok(Err(err)) => {
                                    record_error(ERROR_TLS_HANDSHAKE);
                                    warn!(error = %err, "TLS handshake failed");
                                    Ok(())
                                }
                                Err(_) => {
                                    record_error(ERROR_TLS_HANDSHAKE);
                                    warn!(timeout_secs = tls_timeout.as_secs(), "TLS handshake timed out");
                                    Ok(())
                                }
                            }
                        }
                        .await;

                        let duration_secs = start.elapsed().as_secs_f64();
                        record_connection_closed(duration_secs);

                        if let Err(ref err) = result {
                            record_error(err.error_type());
                            warn!(duration_secs, error = %err, "connection closed with error");
                        } else {
                            debug!(duration_secs, "connection closed");
                        }
                    }
                    .instrument(span),
                );
            }
        }
    }

    // ---- graceful shutdown path ----
    if let Some(ref limiter) = rate_limiter {
        limiter.shutdown();
    }

    // Give in-flight connections a bounded window to finish before exiting.
    let active = tracker.count();
    if active > 0 {
        info!("waiting for {} active connections to drain", active);
        if tracker.wait_for_zero(DEFAULT_SHUTDOWN_TIMEOUT).await {
            info!("all connections drained");
        } else {
            warn!(
                "shutdown timeout, {} connections still active",
                tracker.count()
            );
        }
    }

    info!("server stopped");
    Ok(())
}
546
547pub async fn run(config: Config, auth: impl AuthBackend + 'static) -> Result<(), ServerError> {
550 run_with_shutdown(config, auth, CancellationToken::new()).await
551}
552
/// Load the GeoIP databases configured for the server, metrics, and
/// analytics subsystems, deduplicating identical sources so each distinct
/// `(path, url, source)` triple is loaded only once and shared via `Arc`.
///
/// Returns `(server_geoip, metrics_geoip, analytics_geoip)` in that order;
/// each slot is `None` if not configured or if loading failed (load
/// failures are logged, not fatal). When the metrics section has no GeoIP
/// config of its own it reuses the server database.
///
/// For metrics/analytics configs with `auto_update` set and no local
/// `path`, a background refresh task is spawned (deduplicated by database
/// pointer) that is cancelled through `shutdown`.
#[cfg(feature = "geoip")]
#[allow(unused_variables)]
async fn load_geoip_databases(
    config: &Config,
    shutdown: &CancellationToken,
) -> (
    Option<Arc<trojan_rules::geoip_db::GeoipDb>>,
    Option<Arc<trojan_rules::geoip_db::GeoipDb>>,
    Option<Arc<trojan_rules::geoip_db::GeoipDb>>,
) {
    use std::collections::HashMap;
    use trojan_rules::geoip_db::GeoipDb;

    // Dedup key: a database is identified by where it comes from.
    type Key = (Option<String>, Option<String>, String);
    let mut loaded: HashMap<Key, Arc<GeoipDb>> = HashMap::new();

    // Candidates for background auto-update, collected as we load.
    let mut auto_update_configs: Vec<(trojan_config::GeoipConfig, Arc<GeoipDb>)> = Vec::new();

    // Load `cfg`'s database, or return the already-loaded instance when the
    // same source was seen before. Failures are logged and yield `None`.
    async fn load_or_share(
        cfg: &trojan_config::GeoipConfig,
        loaded: &mut HashMap<Key, Arc<GeoipDb>>,
    ) -> Option<Arc<GeoipDb>> {
        let key: Key = (cfg.path.clone(), cfg.url.clone(), cfg.source.clone());
        if let Some(existing) = loaded.get(&key) {
            return Some(existing.clone());
        }
        match trojan_rules::geoip_db::load_geoip(cfg).await {
            Ok(db) => {
                let arc = Arc::new(db);
                loaded.insert(key, arc.clone());
                Some(arc)
            }
            Err(e) => {
                warn!(source = %cfg.source, error = %e, "failed to load GeoIP database");
                None
            }
        }
    }

    let server_geoip = if let Some(cfg) = config.server.geoip.as_ref() {
        load_or_share(cfg, &mut loaded).await
    } else {
        None
    };

    // Metrics: own config if present, otherwise share the server database.
    // Only databases fetched remotely (no local `path`) are auto-updated.
    let metrics_geoip = if let Some(cfg) = config.metrics.geoip.as_ref() {
        let result = load_or_share(cfg, &mut loaded).await;
        if let Some(ref db) = result
            && cfg.auto_update
            && cfg.path.is_none()
        {
            auto_update_configs.push((cfg.clone(), db.clone()));
        }
        result
    } else {
        server_geoip.clone()
    };

    #[cfg(feature = "analytics")]
    let analytics_geoip = if let Some(cfg) = config.analytics.geoip.as_ref() {
        let result = load_or_share(cfg, &mut loaded).await;
        if let Some(ref db) = result
            && cfg.auto_update
            && cfg.path.is_none()
        {
            auto_update_configs.push((cfg.clone(), db.clone()));
        }
        result
    } else {
        None
    };
    // Without the analytics feature the third slot is always empty.
    #[cfg(not(feature = "analytics"))]
    let analytics_geoip: Option<Arc<GeoipDb>> = None;

    if !loaded.is_empty() {
        info!(
            databases = loaded.len(),
            "GeoIP databases loaded (deduplicated)"
        );
    }

    {
        // Spawn at most one auto-update task per distinct database instance,
        // deduplicated by Arc pointer identity.
        let mut seen_ptrs = std::collections::HashSet::new();
        for (cfg, db) in auto_update_configs {
            let ptr = Arc::as_ptr(&db) as usize;
            if !seen_ptrs.insert(ptr) {
                continue;
            }
            let cancel = shutdown.clone();
            let source = cfg.source.clone();
            info!(source = %source, "spawning GeoIP auto-update task");
            // NOTE(review): this ArcSwap is created locally and only handed to
            // the update task — the Arc<GeoipDb> handles returned to callers
            // are not re-read through it, so it looks like refreshed databases
            // may not become visible to consumers. Confirm how
            // geoip_auto_update_task's swaps are intended to propagate.
            let swappable = Arc::new(arc_swap::ArcSwap::from(db));
            tokio::spawn(trojan_rules::geoip_db::geoip_auto_update_task(
                cfg,
                swappable,
                cancel,
                move |success| {
                    if success {
                        trojan_metrics::record_rule_update();
                    } else {
                        trojan_metrics::record_rule_update_error();
                    }
                },
            ));
        }
    }

    (server_geoip, metrics_geoip, analytics_geoip)
}
677
/// Background task that periodically rebuilds the routing rule engine from
/// its (HTTP-backed) providers and hot-swaps it into `engine`.
///
/// Performs one fetch immediately at startup, then refetches every
/// `interval_secs` until `shutdown` is cancelled. A failed fetch keeps the
/// currently active rules and records an error metric; it never tears down
/// the engine.
#[cfg(feature = "rules")]
async fn rule_update_loop(
    engine: Arc<trojan_rules::HotRuleEngine>,
    server_config: trojan_config::ServerConfig,
    interval_secs: u64,
    shutdown: CancellationToken,
) {
    use std::time::Duration;
    use trojan_metrics::{record_rule_update, record_rule_update_error};

    // Initial fetch right away so freshly fetched rules replace the
    // startup rules without waiting a full interval.
    match crate::rules::build_rule_engine_async(&server_config).await {
        Ok(new_engine) => {
            info!(
                rule_sets = new_engine.rule_set_count(),
                rules = new_engine.rule_count(),
                "initial rule fetch completed, engine updated"
            );
            engine.update(new_engine);
            record_rule_update();
        }
        Err(e) => {
            warn!(error = %e, "initial rule fetch failed, keeping startup rules");
            record_rule_update_error();
        }
    }

    let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));
    // tokio's interval fires immediately on the first tick; consume it so
    // the loop waits a full interval after the initial fetch above.
    interval.tick().await;
    loop {
        tokio::select! {
            // `biased` gives shutdown priority over the next tick.
            biased;
            _ = shutdown.cancelled() => {
                debug!("rule update task shutting down");
                return;
            }
            _ = interval.tick() => {
                debug!("starting scheduled rule update");
                match crate::rules::build_rule_engine_async(&server_config).await {
                    Ok(new_engine) => {
                        info!(
                            rule_sets = new_engine.rule_set_count(),
                            rules = new_engine.rule_count(),
                            "rule update completed, engine swapped"
                        );
                        engine.update(new_engine);
                        record_rule_update();
                    }
                    Err(e) => {
                        warn!(error = %e, "rule update failed, keeping current rules");
                        record_rule_update_error();
                    }
                }
            }
        }
    }
}