// trojan_server/cli.rs

1//! CLI module for trojan-server.
2//!
3//! This module provides the command-line interface that can be used either
4//! as a standalone binary or as a subcommand of the main trojan-rs CLI.
5
6use std::io;
7use std::path::PathBuf;
8use std::sync::Arc;
9
10use clap::Parser;
11use tracing::{info, warn};
12use tracing_subscriber::{EnvFilter, fmt, layer::SubscriberExt, util::SubscriberInitExt};
13use trojan_auth::{MemoryAuth, ReloadableAuth};
14use trojan_config::{CliOverrides, LoggingConfig, apply_overrides, load_config, validate_config};
15
16use crate::{CancellationToken, run_with_shutdown};
17
/// Trojan server CLI arguments.
//
// Maintainer notes in this struct use plain `//` comments on purpose: clap's
// derive picks up `///` comments as `--help` text, so rewording them would
// change what users see.
#[derive(Parser, Debug, Clone)]
#[command(name = "trojan-server", version, about = "Trojan server in Rust")]
pub struct ServerArgs {
    /// Config file path (json/jsonc/yaml/toml)
    // Exposed as `-c` / `--config`; defaults to `config.toml` when omitted.
    #[arg(short, long, default_value = "config.toml")]
    pub config: PathBuf,

    // Flags declared by `CliOverrides` are flattened into this command.
    // `run` applies them on top of the loaded file via `apply_overrides`.
    #[command(flatten)]
    pub overrides: CliOverrides,
}
29
30/// Run the trojan server with the given arguments.
31///
32/// This is the main entry point for the server CLI, used by both the
33/// standalone binary and the unified trojan-rs CLI.
34pub async fn run(args: ServerArgs) -> Result<(), Box<dyn std::error::Error>> {
35    let mut config = load_config(&args.config)?;
36    apply_overrides(&mut config, &args.overrides);
37    validate_config(&config)?;
38
39    init_tracing(&config.logging);
40
41    if let Some(listen) = &config.metrics.listen {
42        match trojan_metrics::init_metrics_server(listen) {
43            Ok(_handle) => info!(
44                "metrics server listening on {} (/metrics, /health, /ready)",
45                listen
46            ),
47            Err(e) => warn!("failed to start metrics server: {}", e),
48        }
49    }
50
51    // Set up graceful shutdown on SIGTERM/SIGINT
52    let shutdown = CancellationToken::new();
53    let shutdown_signal = shutdown.clone();
54
55    tokio::spawn(async move {
56        shutdown_signal_handler().await;
57        info!("shutdown signal received");
58        shutdown_signal.cancel();
59    });
60
61    // Create reloadable auth backend
62    let auth = Arc::new(ReloadableAuth::new(MemoryAuth::from_passwords(
63        &config.auth.passwords,
64    )));
65
66    // Set up SIGHUP handler for config reload
67    #[cfg(unix)]
68    {
69        let config_path = args.config.clone();
70        let overrides = args.overrides.clone();
71        let auth_reload = auth.clone();
72        tokio::spawn(async move {
73            reload_signal_handler(config_path, overrides, auth_reload).await;
74        });
75    }
76
77    run_with_shutdown(config, auth, shutdown).await?;
78    Ok(())
79}
80
81/// Wait for shutdown signals (SIGTERM, SIGINT).
82async fn shutdown_signal_handler() {
83    let ctrl_c = async {
84        if let Err(e) = tokio::signal::ctrl_c().await {
85            warn!("failed to listen for Ctrl+C: {}", e);
86            // Fall back to waiting forever
87            std::future::pending::<()>().await;
88        }
89    };
90
91    #[cfg(unix)]
92    let terminate = async {
93        match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
94            Ok(mut sig) => {
95                sig.recv().await;
96            }
97            Err(e) => {
98                warn!("failed to listen for SIGTERM: {}", e);
99                // Fall back to waiting forever
100                std::future::pending::<()>().await;
101            }
102        }
103    };
104
105    #[cfg(not(unix))]
106    let terminate = std::future::pending::<()>();
107
108    tokio::select! {
109        _ = ctrl_c => {}
110        _ = terminate => {}
111    }
112}
113
114/// Handle SIGHUP for config reload (Unix only).
115#[cfg(unix)]
116async fn reload_signal_handler(
117    config_path: PathBuf,
118    overrides: CliOverrides,
119    auth: Arc<ReloadableAuth>,
120) {
121    use tokio::signal::unix::{SignalKind, signal};
122
123    let mut sighup = match signal(SignalKind::hangup()) {
124        Ok(sig) => sig,
125        Err(e) => {
126            warn!(
127                "failed to install SIGHUP handler: {}, config reload disabled",
128                e
129            );
130            return;
131        }
132    };
133
134    loop {
135        sighup.recv().await;
136        info!("SIGHUP received, reloading configuration");
137
138        match reload_config(&config_path, &overrides, &auth) {
139            Ok(()) => info!("configuration reloaded successfully"),
140            Err(e) => warn!("failed to reload configuration: {}", e),
141        }
142    }
143}
144
145/// Reload configuration from file.
146#[cfg(unix)]
147fn reload_config(
148    config_path: &PathBuf,
149    overrides: &CliOverrides,
150    auth: &Arc<ReloadableAuth>,
151) -> Result<(), Box<dyn std::error::Error>> {
152    let mut config = load_config(config_path)?;
153    apply_overrides(&mut config, overrides);
154    validate_config(&config)?;
155
156    // Reload auth passwords
157    let new_auth = MemoryAuth::from_passwords(&config.auth.passwords);
158    auth.reload(new_auth);
159    info!(
160        password_count = config.auth.passwords.len(),
161        "auth passwords reloaded"
162    );
163
164    // Note: TLS certificates and other settings require server restart
165    // Future enhancement: implement TLS cert hot-reload via rustls ResolvesServerCert
166
167    Ok(())
168}
169
170/// Initialize tracing subscriber with the given logging configuration.
171///
172/// Supports:
173/// - `level`: Base log level (trace, debug, info, warn, error)
174/// - `format`: Output format (json, pretty, compact). Default: pretty
175/// - `output`: Output target (stdout, stderr). Default: stderr
176/// - `filters`: Per-module log level overrides
177fn init_tracing(config: &LoggingConfig) {
178    // Build the env filter from base level and per-module filters
179    let base_level = config.level.as_deref().unwrap_or("info");
180    let mut filter_str = base_level.to_string();
181
182    for (module, level) in &config.filters {
183        filter_str.push(',');
184        filter_str.push_str(module);
185        filter_str.push('=');
186        filter_str.push_str(level);
187    }
188
189    let filter = EnvFilter::try_new(&filter_str).unwrap_or_else(|_| EnvFilter::new("info"));
190
191    let format = config.format.as_deref().unwrap_or("pretty");
192    let output = config.output.as_deref().unwrap_or("stderr");
193
194    // Create the subscriber based on format and output
195    match (format, output) {
196        ("json", "stdout") => {
197            tracing_subscriber::registry()
198                .with(filter)
199                .with(fmt::layer().json().with_writer(io::stdout))
200                .init();
201        }
202        ("json", _) => {
203            tracing_subscriber::registry()
204                .with(filter)
205                .with(fmt::layer().json().with_writer(io::stderr))
206                .init();
207        }
208        ("compact", "stdout") => {
209            tracing_subscriber::registry()
210                .with(filter)
211                .with(fmt::layer().compact().with_writer(io::stdout))
212                .init();
213        }
214        ("compact", _) => {
215            tracing_subscriber::registry()
216                .with(filter)
217                .with(fmt::layer().compact().with_writer(io::stderr))
218                .init();
219        }
220        (_, "stdout") => {
221            // pretty is default
222            tracing_subscriber::registry()
223                .with(filter)
224                .with(fmt::layer().with_writer(io::stdout))
225                .init();
226        }
227        _ => {
228            // pretty to stderr is default
229            tracing_subscriber::registry()
230                .with(filter)
231                .with(fmt::layer().with_writer(io::stderr))
232                .init();
233        }
234    }
235}