Skip to main content

trojan_server/
server.rs

1//! Main server loop and connection handling.
2
3use std::net::SocketAddr;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::time::Duration;
7
8use tokio::sync::{OwnedSemaphorePermit, Semaphore};
9use tokio::time::Instant;
10use tokio_rustls::TlsAcceptor;
11use tokio_util::sync::CancellationToken;
12use tracing::{Instrument, debug, info, info_span, warn};
13
14use crate::error::ServerError;
15use crate::handler::handle_conn;
16use crate::pool::ConnectionPool;
17use crate::rate_limit::RateLimiter;
18use crate::resolve::resolve_sockaddr;
19use crate::state::ServerState;
20use crate::tls::load_tls_config;
21use crate::util::{ConnectionGuard, ConnectionTracker, apply_tcp_options, create_listener};
22use trojan_auth::AuthBackend;
23use trojan_config::Config;
24use trojan_core::defaults;
25use trojan_dns::DnsResolver;
26use trojan_metrics::{
27    ERROR_TLS_HANDSHAKE, record_connection_accepted, record_connection_closed,
28    record_connection_rejected, record_error, record_tls_handshake_duration,
29    set_connection_queue_depth,
30};
31
/// How long [`run_with_shutdown`] waits for in-flight connections to finish
/// once shutdown has been requested, before giving up on the drain.
pub const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(30);

/// Process-wide counter backing [`next_conn_id`]; the first ID handed out is 1.
static CONN_ID: AtomicU64 = AtomicU64::new(1);

/// Hand out the next connection ID (used to tag per-connection tracing spans).
///
/// `Ordering::Relaxed` suffices: the atomic read-modify-write alone makes every
/// returned value distinct, and no other memory is published through this counter.
#[inline]
fn next_conn_id() -> u64 {
    CONN_ID.fetch_add(1, Ordering::Relaxed)
}
43
44/// Run the server with a cancellation token for graceful shutdown.
45pub async fn run_with_shutdown(
46    config: Config,
47    auth: impl AuthBackend + 'static,
48    shutdown: CancellationToken,
49) -> Result<(), ServerError> {
50    let tls_config = load_tls_config(&config.tls)?;
51    let acceptor = TlsAcceptor::from(Arc::new(tls_config));
52
53    let listen: SocketAddr = config
54        .server
55        .listen
56        .parse()
57        .map_err(|_| ServerError::Config("invalid listen address".into()))?;
58
59    // Build DNS resolver from config.
60    // Backward compatibility: preserve legacy `server.tcp.prefer_ipv4` behavior.
61    let mut dns_config = config.dns.clone();
62    if config.server.tcp.prefer_ipv4 && !dns_config.prefer_ipv4 {
63        dns_config.prefer_ipv4 = true;
64        info!(
65            "server.tcp.prefer_ipv4 is deprecated; mapped to dns.prefer_ipv4 for backward compatibility"
66        );
67    }
68    let dns_resolver = DnsResolver::new(&dns_config)
69        .map_err(|e| ServerError::Config(format!("dns resolver: {e}")))?;
70    info!(
71        dns = ?dns_config.strategy,
72        prefer_ipv4 = dns_config.prefer_ipv4,
73        "dns resolver initialized"
74    );
75
76    let fallback_addr = resolve_sockaddr(&config.server.fallback, &dns_resolver).await?;
77
78    // Initialize fallback connection pool if configured
79    let fallback_pool: Option<Arc<ConnectionPool>> =
80        config.server.fallback_pool.as_ref().map(|pool_cfg| {
81            info!(
82                max_idle = pool_cfg.max_idle,
83                max_age_secs = pool_cfg.max_age_secs,
84                fill_batch = pool_cfg.fill_batch,
85                fill_delay_ms = pool_cfg.fill_delay_ms,
86                "fallback connection pool enabled"
87            );
88            let pool = Arc::new(ConnectionPool::new(
89                fallback_addr,
90                pool_cfg.max_idle,
91                pool_cfg.max_age_secs,
92                pool_cfg.fill_batch,
93                pool_cfg.fill_delay_ms,
94            ));
95            // Use max_age_secs as cleanup interval
96            pool.start_cleanup_task(Duration::from_secs(pool_cfg.max_age_secs));
97            pool
98        });
99
100    // Extract resource limits with defaults
101    let (relay_buffer_size, tcp_send_buffer, tcp_recv_buffer, connection_backlog) =
102        match &config.server.resource_limits {
103            Some(rl) => {
104                info!(
105                    relay_buffer = rl.relay_buffer_size,
106                    tcp_send_buffer = rl.tcp_send_buffer,
107                    tcp_recv_buffer = rl.tcp_recv_buffer,
108                    connection_backlog = rl.connection_backlog,
109                    "resource limits configured"
110                );
111                (
112                    rl.relay_buffer_size,
113                    rl.tcp_send_buffer,
114                    rl.tcp_recv_buffer,
115                    rl.connection_backlog,
116                )
117            }
118            None => (
119                defaults::DEFAULT_RELAY_BUFFER_SIZE,
120                defaults::DEFAULT_TCP_SEND_BUFFER,
121                defaults::DEFAULT_TCP_RECV_BUFFER,
122                defaults::DEFAULT_CONNECTION_BACKLOG,
123            ),
124        };
125
126    // Initialize analytics if feature enabled and configured
127    #[cfg(feature = "analytics")]
128    let analytics = if config.analytics.enabled {
129        match trojan_analytics::init(config.analytics.clone()).await {
130            Ok(collector) => {
131                info!("analytics enabled, sending to ClickHouse");
132                Some(collector)
133            }
134            Err(e) => {
135                warn!("failed to init analytics: {}, disabled", e);
136                None
137            }
138        }
139    } else {
140        debug!("analytics disabled in config");
141        None
142    };
143
144    // Initialize rule engine if feature enabled and rules configured
145    #[cfg(feature = "rules")]
146    let rule_engine = if !config.server.rules.is_empty() {
147        match crate::rules::build_rule_engine(&config.server) {
148            Ok(engine) => {
149                info!(
150                    rule_sets = engine.rule_set_count(),
151                    rules = engine.rule_count(),
152                    "rule engine initialized"
153                );
154                Some(Arc::new(trojan_rules::HotRuleEngine::new(engine)))
155            }
156            Err(e) => {
157                return Err(ServerError::Rules(format!("failed to init rules: {e}")));
158            }
159        }
160    } else {
161        debug!("no routing rules configured");
162        None
163    };
164
165    // Spawn background rule update task for HTTP providers
166    #[cfg(feature = "rules")]
167    if let Some(ref hot_engine) = rule_engine
168        && crate::rules::has_http_providers(&config.server)
169    {
170        let interval_secs = crate::rules::http_update_interval(&config.server).unwrap_or(3600); // default: 1 hour
171        let engine_ref = hot_engine.clone();
172        let server_cfg = config.server.clone();
173        let update_shutdown = shutdown.clone();
174        info!(interval_secs, "starting background rule update task");
175        tokio::spawn(async move {
176            rule_update_loop(engine_ref, server_cfg, interval_secs, update_shutdown).await;
177        });
178    }
179
180    // Build outbound connectors from config
181    #[cfg(feature = "rules")]
182    let outbounds = {
183        let mut map = std::collections::HashMap::new();
184        for (name, outbound_cfg) in &config.server.outbounds {
185            match crate::outbound::Outbound::from_config(name, outbound_cfg) {
186                Ok(outbound) => {
187                    info!(name = %name, "outbound connector configured");
188                    map.insert(name.clone(), Arc::new(outbound));
189                }
190                Err(e) => {
191                    return Err(ServerError::Config(format!("outbound '{name}': {e}")));
192                }
193            }
194        }
195        map
196    };
197
198    // Load GeoIP databases with deduplication.
199    // geoip_server is used indirectly (metrics fallback shares it).
200    #[cfg(feature = "geoip")]
201    #[allow(unused_variables)]
202    let (geoip_server, geoip_metrics, geoip_analytics) =
203        load_geoip_databases(&config, &shutdown).await;
204
205    // Start metrics server (with debug routes if rules feature is enabled)
206    if let Some(ref listen) = config.metrics.listen {
207        #[cfg(feature = "rules")]
208        let extra_routes = rule_engine
209            .as_ref()
210            .map(|engine| crate::debug_api::debug_routes(engine.clone()));
211        #[cfg(not(feature = "rules"))]
212        let extra_routes: Option<axum::Router> = None;
213
214        match trojan_metrics::init_metrics_server(listen, extra_routes) {
215            Ok(_handle) => {
216                #[cfg(feature = "rules")]
217                let endpoints = if rule_engine.is_some() {
218                    "/metrics, /health, /ready, /debug/rules/match"
219                } else {
220                    "/metrics, /health, /ready"
221                };
222                #[cfg(not(feature = "rules"))]
223                let endpoints = "/metrics, /health, /ready";
224                info!("metrics server listening on {} ({})", listen, endpoints);
225            }
226            Err(e) => warn!("failed to start metrics server: {}", e),
227        }
228    }
229
230    // Log TCP options
231    let tcp_cfg = &config.server.tcp;
232    info!(
233        no_delay = tcp_cfg.no_delay,
234        keepalive_secs = tcp_cfg.keepalive_secs,
235        reuse_port = tcp_cfg.reuse_port,
236        fast_open = tcp_cfg.fast_open,
237        "TCP options configured"
238    );
239
240    let state = Arc::new(ServerState {
241        fallback_addr,
242        max_udp_payload: config.server.max_udp_payload,
243        max_udp_buffer_bytes: config.server.max_udp_buffer_bytes,
244        max_header_bytes: config.server.max_header_bytes,
245        tcp_idle_timeout: Duration::from_secs(config.server.tcp_idle_timeout_secs),
246        udp_idle_timeout: Duration::from_secs(config.server.udp_timeout_secs),
247        fallback_pool,
248        relay_buffer_size,
249        tcp_send_buffer,
250        tcp_recv_buffer,
251        tcp_config: config.server.tcp.clone(),
252        websocket: config.websocket.clone(),
253        dns_resolver,
254        #[cfg(feature = "analytics")]
255        analytics,
256        #[cfg(feature = "rules")]
257        rule_engine,
258        #[cfg(feature = "rules")]
259        outbounds,
260        #[cfg(feature = "geoip")]
261        geoip_metrics,
262        #[cfg(all(feature = "geoip", feature = "analytics"))]
263        geoip_analytics,
264    });
265    let auth = Arc::new(auth);
266    let tracker = ConnectionTracker::new();
267
268    // Connection limiter (None = unlimited)
269    let conn_limit: Option<Arc<Semaphore>> = config.server.max_connections.map(|n| {
270        info!("max_connections set to {}", n);
271        Arc::new(Semaphore::new(n))
272    });
273
274    // Rate limiter (None = disabled)
275    let rate_limiter: Option<Arc<RateLimiter>> = config.server.rate_limit.as_ref().map(|rl| {
276        info!(
277            max_per_ip = rl.max_connections_per_ip,
278            window_secs = rl.window_secs,
279            "rate limiting enabled"
280        );
281        let limiter = Arc::new(RateLimiter::new(rl.max_connections_per_ip, rl.window_secs));
282        limiter.start_cleanup_task(Duration::from_secs(rl.cleanup_interval_secs));
283        limiter
284    });
285
286    // Create listener with custom backlog and TCP options using socket2
287    let listener = create_listener(listen, connection_backlog, &config.server.tcp)?;
288    info!(address = %listen, backlog = connection_backlog, "listening");
289
290    #[cfg(feature = "ws")]
291    if config.websocket.enabled && config.websocket.mode == "split" {
292        let ws_listen = config.websocket.listen.clone().unwrap_or_default();
293        let ws_addr: SocketAddr = ws_listen
294            .parse()
295            .map_err(|_| ServerError::Config("invalid websocket.listen address".into()))?;
296        let ws_listener = create_listener(ws_addr, connection_backlog, &config.server.tcp)?;
297        let ws_acceptor = acceptor.clone();
298        let ws_state = state.clone();
299        let ws_auth = auth.clone();
300        let ws_tracker = tracker.clone();
301        let ws_conn_limit = conn_limit.clone();
302        let ws_rate_limiter = rate_limiter.clone();
303        let ws_shutdown = shutdown.clone();
304
305        info!(address = %ws_addr, "websocket split listener started");
306        tokio::spawn(async move {
307            loop {
308                tokio::select! {
309                    biased;
310                    _ = ws_shutdown.cancelled() => break,
311                    result = ws_listener.accept() => {
312                        let (tcp, peer) = match result {
313                            Ok(v) => v,
314                            Err(_) => continue,
315                        };
316
317                        // Apply TCP socket options
318                        if let Err(e) = apply_tcp_options(&tcp, &ws_state.tcp_config) {
319                            tracing::debug!(error = %e, "failed to apply TCP options");
320                        }
321
322                        if let Some(ref limiter) = ws_rate_limiter {
323                            let ip = peer.ip();
324                            if !limiter.check_and_increment(ip) {
325                                record_connection_rejected("rate_limit");
326                                drop(tcp);
327                                continue;
328                            }
329                        }
330
331                        let permit: Option<OwnedSemaphorePermit> = match &ws_conn_limit {
332                            Some(sem) => match sem.clone().try_acquire_owned() {
333                                Ok(p) => Some(p),
334                                Err(_) => {
335                                    record_connection_rejected("max_connections");
336                                    drop(tcp);
337                                    continue;
338                                }
339                            },
340                            None => None,
341                        };
342
343                        let conn_id = next_conn_id();
344                        let acceptor = ws_acceptor.clone();
345                        let state = ws_state.clone();
346                        let auth = ws_auth.clone();
347                        ws_tracker.increment();
348                        let guard = ConnectionGuard::new(ws_tracker.clone());
349
350                        let span = info_span!("conn", id = conn_id, peer = %peer, transport = "ws");
351                        tokio::spawn(
352                            async move {
353                                let _guard = guard;
354                                let _permit = permit;
355                                record_connection_accepted();
356                                let start = Instant::now();
357
358                                let result = async {
359                                    let tls_start = Instant::now();
360                                    let tls_timeout =
361                                        Duration::from_secs(defaults::DEFAULT_TLS_HANDSHAKE_TIMEOUT_SECS);
362                                    match tokio::time::timeout(tls_timeout, acceptor.accept(tcp)).await
363                                    {
364                                        Ok(Ok(tls)) => {
365                                            let tls_duration = tls_start.elapsed().as_secs_f64();
366                                            record_tls_handshake_duration(tls_duration);
367                                            crate::handler::handle_ws_only(tls, state, auth, peer).await
368                                        }
369                                        Ok(Err(err)) => {
370                                            record_error(ERROR_TLS_HANDSHAKE);
371                                            warn!(error = %err, "TLS handshake failed");
372                                            Ok(())
373                                        }
374                                        Err(_) => {
375                                            record_error(ERROR_TLS_HANDSHAKE);
376                                            warn!(
377                                                timeout_secs = tls_timeout.as_secs(),
378                                                "TLS handshake timed out"
379                                            );
380                                            Ok(())
381                                        }
382                                    }
383                                }
384                                .await;
385
386                                let duration_secs = start.elapsed().as_secs_f64();
387                                record_connection_closed(duration_secs);
388
389                                if let Err(ref err) = result {
390                                    warn!(error = %err, "connection error");
391                                }
392                            }
393                            .instrument(span),
394                        );
395                    }
396                }
397            }
398        });
399    }
400
401    #[cfg(not(feature = "ws"))]
402    if config.websocket.enabled {
403        warn!("websocket.enabled=true but ws feature is disabled; ignoring websocket");
404    }
405
406    loop {
407        tokio::select! {
408            biased;
409
410            _ = shutdown.cancelled() => {
411                info!("shutdown signal received, stopping accept loop");
412                break;
413            }
414
415            result = listener.accept() => {
416                let (tcp, peer) = result?;
417
418                // Apply TCP socket options (no_delay, keepalive)
419                if let Err(e) = apply_tcp_options(&tcp, &state.tcp_config) {
420                    debug!(error = %e, "failed to apply TCP options");
421                }
422
423                // Update connection queue depth metric (based on semaphore usage)
424                if let Some(ref sem) = conn_limit {
425                    let available = sem.available_permits();
426                    set_connection_queue_depth(available as f64);
427                }
428
429                // Check rate limit first
430                if let Some(ref limiter) = rate_limiter {
431                    let ip = peer.ip();
432                    if !limiter.check_and_increment(ip) {
433                        debug!(peer = %peer, reason = "rate_limit", "connection rejected");
434                        record_connection_rejected("rate_limit");
435                        drop(tcp);
436                        continue;
437                    }
438                }
439
440                // Try to acquire connection permit
441                let permit: Option<OwnedSemaphorePermit> = match &conn_limit {
442                    Some(sem) => match sem.clone().try_acquire_owned() {
443                        Ok(p) => Some(p),
444                        Err(_) => {
445                            debug!(peer = %peer, reason = "max_connections", "connection rejected");
446                            record_connection_rejected("max_connections");
447                            drop(tcp); // close immediately
448                            continue;
449                        }
450                    },
451                    None => None,
452                };
453
454                let conn_id = next_conn_id();
455                debug!(conn_id, peer = %peer, "new connection");
456
457                let acceptor = acceptor.clone();
458                let state = state.clone();
459                let auth = auth.clone();
460                tracker.increment();
461                let guard = ConnectionGuard::new(tracker.clone());
462
463                let span = info_span!("conn", id = conn_id, peer = %peer);
464                tokio::spawn(
465                    async move {
466                        let _guard = guard; // ensure decrement on drop
467                        let _permit = permit; // hold permit until connection closes
468                        record_connection_accepted();
469                        let start = Instant::now();
470
471                        let result = async {
472                            // Measure TLS handshake duration with timeout
473                            let tls_start = Instant::now();
474                            let tls_timeout =
475                                Duration::from_secs(defaults::DEFAULT_TLS_HANDSHAKE_TIMEOUT_SECS);
476                            match tokio::time::timeout(tls_timeout, acceptor.accept(tcp)).await {
477                                Ok(Ok(tls)) => {
478                                    let tls_duration = tls_start.elapsed().as_secs_f64();
479                                    record_tls_handshake_duration(tls_duration);
480                                    debug!(duration_ms = tls_duration * 1000.0, "TLS handshake completed");
481                                    handle_conn(tls, state, auth, peer).await
482                                }
483                                Ok(Err(err)) => {
484                                    record_error(ERROR_TLS_HANDSHAKE);
485                                    warn!(error = %err, "TLS handshake failed");
486                                    Ok(())
487                                }
488                                Err(_) => {
489                                    record_error(ERROR_TLS_HANDSHAKE);
490                                    warn!(timeout_secs = tls_timeout.as_secs(), "TLS handshake timed out");
491                                    Ok(())
492                                }
493                            }
494                        }
495                        .await;
496
497                        let duration_secs = start.elapsed().as_secs_f64();
498                        record_connection_closed(duration_secs);
499
500                        if let Err(ref err) = result {
501                            record_error(err.error_type());
502                            warn!(duration_secs, error = %err, "connection closed with error");
503                        } else {
504                            debug!(duration_secs, "connection closed");
505                        }
506                    }
507                    .instrument(span),
508                );
509            }
510        }
511    }
512
513    // Shutdown rate limiter cleanup task
514    if let Some(ref limiter) = rate_limiter {
515        limiter.shutdown();
516    }
517
518    // Graceful drain: wait for active connections
519    let active = tracker.count();
520    if active > 0 {
521        info!("waiting for {} active connections to drain", active);
522        if tracker.wait_for_zero(DEFAULT_SHUTDOWN_TIMEOUT).await {
523            info!("all connections drained");
524        } else {
525            warn!(
526                "shutdown timeout, {} connections still active",
527                tracker.count()
528            );
529        }
530    }
531
532    info!("server stopped");
533    Ok(())
534}
535
536/// Run the server (blocking until error, no graceful shutdown).
537/// For backward compatibility with existing code.
538pub async fn run(config: Config, auth: impl AuthBackend + 'static) -> Result<(), ServerError> {
539    run_with_shutdown(config, auth, CancellationToken::new()).await
540}
541
/// Load the server, metrics and analytics GeoIP databases, sharing identical sources.
///
/// Returns `(server_geoip, metrics_geoip, analytics_geoip)`. Two configs with the
/// same `(path, url, source)` triple get the same `Arc` and are loaded only once.
/// When `metrics.geoip` is unset, metrics reuses the server database. Without the
/// `analytics` feature, or when `analytics.geoip` is unset, the analytics slot is
/// `None`. A database that fails to load is logged and yields `None`.
///
/// For the metrics and analytics configs, a background auto-update task is spawned
/// per distinct database when `auto_update = true` and no local `path` is set.
/// The task stops when `shutdown` is cancelled.
///
/// NOTE(review): `server.geoip` is never queued for auto-update, even when it sets
/// `auto_update = true`. Confirm this is intended.
///
/// NOTE(review): each update task receives a freshly created `ArcSwap` wrapping the
/// loaded `Arc`, but the `Arc`s returned to the caller are not read through that
/// `ArcSwap`. Unless `GeoipDb` or the task propagate refreshed data some other way,
/// updates may never reach metrics/analytics lookups. Confirm.
#[cfg(feature = "geoip")]
#[allow(unused_variables)]
async fn load_geoip_databases(
    config: &Config,
    shutdown: &CancellationToken,
) -> (
    Option<Arc<trojan_rules::geoip_db::GeoipDb>>,
    Option<Arc<trojan_rules::geoip_db::GeoipDb>>,
    Option<Arc<trojan_rules::geoip_db::GeoipDb>>,
) {
    use std::collections::{HashMap, HashSet};
    use trojan_rules::geoip_db::GeoipDb;

    // A database is identified by its (path, url, source) triple.
    type DbKey = (Option<String>, Option<String>, String);
    type Cache = HashMap<DbKey, Arc<GeoipDb>>;
    type RefreshQueue = Vec<(trojan_config::GeoipConfig, Arc<GeoipDb>)>;

    // Return the cached database for `cfg`, loading and caching it on first use.
    async fn fetch_cached(
        cfg: &trojan_config::GeoipConfig,
        cache: &mut Cache,
    ) -> Option<Arc<GeoipDb>> {
        let key: DbKey = (cfg.path.clone(), cfg.url.clone(), cfg.source.clone());
        if let Some(hit) = cache.get(&key) {
            return Some(Arc::clone(hit));
        }
        let outcome = trojan_rules::geoip_db::load_geoip(cfg).await;
        match outcome {
            Err(e) => {
                warn!(source = %cfg.source, error = %e, "failed to load GeoIP database");
                None
            }
            Ok(db) => {
                let shared = Arc::new(db);
                cache.insert(key, Arc::clone(&shared));
                Some(shared)
            }
        }
    }

    // Like `fetch_cached`, and additionally queue the database for auto-update
    // when it is downloaded (no local `path`) and `auto_update` is enabled.
    async fn fetch_with_refresh(
        cfg: &trojan_config::GeoipConfig,
        cache: &mut Cache,
        refresh: &mut RefreshQueue,
    ) -> Option<Arc<GeoipDb>> {
        let db = fetch_cached(cfg, cache).await?;
        if cfg.auto_update && cfg.path.is_none() {
            refresh.push((cfg.clone(), Arc::clone(&db)));
        }
        Some(db)
    }

    let mut cache = Cache::new();
    let mut refresh = RefreshQueue::new();

    // Rule-matching database; also the metrics fallback.
    let server_geoip = match config.server.geoip.as_ref() {
        Some(cfg) => fetch_cached(cfg, &mut cache).await,
        None => None,
    };

    let metrics_geoip = match config.metrics.geoip.as_ref() {
        Some(cfg) => fetch_with_refresh(cfg, &mut cache, &mut refresh).await,
        None => server_geoip.clone(),
    };

    #[cfg(feature = "analytics")]
    let analytics_geoip = match config.analytics.geoip.as_ref() {
        Some(cfg) => fetch_with_refresh(cfg, &mut cache, &mut refresh).await,
        None => None,
    };
    #[cfg(not(feature = "analytics"))]
    let analytics_geoip: Option<Arc<GeoipDb>> = None;

    if !cache.is_empty() {
        info!(
            databases = cache.len(),
            "GeoIP databases loaded (deduplicated)"
        );
    }

    // At most one update task per distinct database; the Arc address identifies it.
    let mut spawned_for: HashSet<usize> = HashSet::new();
    for (cfg, db) in refresh {
        if !spawned_for.insert(Arc::as_ptr(&db) as usize) {
            continue;
        }
        info!(source = %cfg.source, "spawning GeoIP auto-update task");
        let cancel = shutdown.clone();
        let swappable = Arc::new(arc_swap::ArcSwap::from(db));
        tokio::spawn(trojan_rules::geoip_db::geoip_auto_update_task(
            cfg,
            swappable,
            cancel,
            move |success| match success {
                true => trojan_metrics::record_rule_update(),
                false => trojan_metrics::record_rule_update_error(),
            },
        ));
    }

    (server_geoip, metrics_geoip, analytics_geoip)
}
666
/// Background task that periodically re-fetches HTTP rule-sets and hot-swaps the engine.
///
/// One fetch runs immediately, replacing whatever the engine was built from at
/// startup. After that, one fetch runs every `interval_secs` seconds until
/// `shutdown` is cancelled. A failed fetch is logged and counted via
/// `record_rule_update_error`, and the engine keeps its current rules.
///
/// `interval_secs` is clamped to at least 1, because `tokio::time::interval`
/// panics on a zero period. Missed ticks are delayed rather than burst, so a fetch
/// that outlasts the interval is not followed by back-to-back refetches.
/// Cancellation is observed between fetches; an in-progress fetch is not
/// interrupted.
#[cfg(feature = "rules")]
async fn rule_update_loop(
    engine: Arc<trojan_rules::HotRuleEngine>,
    server_config: trojan_config::ServerConfig,
    interval_secs: u64,
    shutdown: CancellationToken,
) {
    use tokio::time::MissedTickBehavior;

    // Initial fetch (immediate) to replace any cache-only startup data.
    refresh_rules(&engine, &server_config, true).await;

    let mut interval = tokio::time::interval(Duration::from_secs(interval_secs.max(1)));
    interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
    interval.tick().await; // the first tick completes immediately; consume it

    loop {
        tokio::select! {
            biased;
            _ = shutdown.cancelled() => {
                debug!("rule update task shutting down");
                return;
            }
            _ = interval.tick() => {
                debug!("starting scheduled rule update");
                refresh_rules(&engine, &server_config, false).await;
            }
        }
    }
}

/// Rebuild the rule engine from `server_config` and hot-swap it into `engine`.
///
/// On success the new engine replaces the old one and `record_rule_update` is
/// called. On failure the current rules are kept and `record_rule_update_error`
/// is called. `initial` only selects the log wording for the first fetch versus
/// scheduled ones.
#[cfg(feature = "rules")]
async fn refresh_rules(
    engine: &trojan_rules::HotRuleEngine,
    server_config: &trojan_config::ServerConfig,
    initial: bool,
) {
    use trojan_metrics::{record_rule_update, record_rule_update_error};

    match crate::rules::build_rule_engine_async(server_config).await {
        Ok(new_engine) => {
            if initial {
                info!(
                    rule_sets = new_engine.rule_set_count(),
                    rules = new_engine.rule_count(),
                    "initial rule fetch completed, engine updated"
                );
            } else {
                info!(
                    rule_sets = new_engine.rule_set_count(),
                    rules = new_engine.rule_count(),
                    "rule update completed, engine swapped"
                );
            }
            engine.update(new_engine);
            record_rule_update();
        }
        Err(e) => {
            if initial {
                warn!(error = %e, "initial rule fetch failed, keeping startup rules");
            } else {
                warn!(error = %e, "rule update failed, keeping current rules");
            }
            record_rule_update_error();
        }
    }
}