Skip to main content

trojan_server/
cli.rs

1//! CLI module for trojan-server.
2//!
3//! This module provides the command-line interface that can be used either
4//! as a standalone binary or as a subcommand of the main trojan-rs CLI.
5
6use std::io;
7use std::path::PathBuf;
8use std::sync::Arc;
9
10use clap::Parser;
11use tracing::{info, warn};
12use tracing_subscriber::{EnvFilter, fmt, layer::SubscriberExt, util::SubscriberInitExt};
13use trojan_auth::{MemoryAuth, ReloadableAuth};
14use trojan_config::{CliOverrides, LoggingConfig, apply_overrides, load_config, validate_config};
15
16use crate::{CancellationToken, run_with_shutdown};
17
18/// Trojan server CLI arguments.
19#[derive(Parser, Debug, Clone)]
20#[command(name = "trojan-server", version, about = "Trojan server in Rust")]
21pub struct ServerArgs {
22    /// Config file path (json/jsonc/yaml/toml)
23    #[arg(short, long, default_value = "config.toml")]
24    pub config: PathBuf,
25
26    #[command(flatten)]
27    pub overrides: CliOverrides,
28}
29
30/// Run the trojan server with the given arguments.
31///
32/// This is the main entry point for the server CLI, used by both the
33/// standalone binary and the unified trojan-rs CLI.
34pub async fn run(args: ServerArgs) -> Result<(), Box<dyn std::error::Error>> {
35    let mut config = load_config(&args.config)?;
36    apply_overrides(&mut config, &args.overrides);
37    validate_config(&config)?;
38
39    init_tracing(&config.logging);
40
41    if let Some(listen) = &config.metrics.listen {
42        match trojan_metrics::init_metrics_server(listen) {
43            Ok(_handle) => info!(
44                "metrics server listening on {} (/metrics, /health, /ready)",
45                listen
46            ),
47            Err(e) => warn!("failed to start metrics server: {}", e),
48        }
49    }
50
51    // Set up graceful shutdown on SIGTERM/SIGINT
52    let shutdown = CancellationToken::new();
53    let shutdown_signal = shutdown.clone();
54
55    tokio::spawn(async move {
56        shutdown_signal_handler().await;
57        info!("shutdown signal received");
58        shutdown_signal.cancel();
59    });
60
61    // Create reloadable auth backend
62    let auth = Arc::new(ReloadableAuth::new(build_memory_auth(&config.auth)));
63
64    // Set up SIGHUP handler for config reload
65    #[cfg(unix)]
66    {
67        let config_path = args.config.clone();
68        let overrides = args.overrides.clone();
69        let auth_reload = auth.clone();
70        tokio::spawn(async move {
71            reload_signal_handler(config_path, overrides, auth_reload).await;
72        });
73    }
74
75    run_with_shutdown(config, auth, shutdown).await?;
76    Ok(())
77}
78
79/// Wait for shutdown signals (SIGTERM, SIGINT).
80async fn shutdown_signal_handler() {
81    let ctrl_c = async {
82        if let Err(e) = tokio::signal::ctrl_c().await {
83            warn!("failed to listen for Ctrl+C: {}", e);
84            // Fall back to waiting forever
85            std::future::pending::<()>().await;
86        }
87    };
88
89    #[cfg(unix)]
90    let terminate = async {
91        match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
92            Ok(mut sig) => {
93                sig.recv().await;
94            }
95            Err(e) => {
96                warn!("failed to listen for SIGTERM: {}", e);
97                // Fall back to waiting forever
98                std::future::pending::<()>().await;
99            }
100        }
101    };
102
103    #[cfg(not(unix))]
104    let terminate = std::future::pending::<()>();
105
106    tokio::select! {
107        _ = ctrl_c => {}
108        _ = terminate => {}
109    }
110}
111
112/// Handle SIGHUP for config reload (Unix only).
113#[cfg(unix)]
114async fn reload_signal_handler(
115    config_path: PathBuf,
116    overrides: CliOverrides,
117    auth: Arc<ReloadableAuth>,
118) {
119    use tokio::signal::unix::{SignalKind, signal};
120
121    let mut sighup = match signal(SignalKind::hangup()) {
122        Ok(sig) => sig,
123        Err(e) => {
124            warn!(
125                "failed to install SIGHUP handler: {}, config reload disabled",
126                e
127            );
128            return;
129        }
130    };
131
132    loop {
133        sighup.recv().await;
134        info!("SIGHUP received, reloading configuration");
135
136        match reload_config(&config_path, &overrides, &auth) {
137            Ok(()) => info!("configuration reloaded successfully"),
138            Err(e) => warn!("failed to reload configuration: {}", e),
139        }
140    }
141}
142
143/// Reload configuration from file.
144#[cfg(unix)]
145fn reload_config(
146    config_path: &PathBuf,
147    overrides: &CliOverrides,
148    auth: &Arc<ReloadableAuth>,
149) -> Result<(), Box<dyn std::error::Error>> {
150    let mut config = load_config(config_path)?;
151    apply_overrides(&mut config, overrides);
152    validate_config(&config)?;
153
154    // Reload auth passwords + users
155    let new_auth = build_memory_auth(&config.auth);
156    auth.reload(new_auth);
157    info!(
158        password_count = config.auth.passwords.len(),
159        user_count = config.auth.users.len(),
160        "auth reloaded"
161    );
162
163    // Note: TLS certificates and other settings require server restart
164    // Future enhancement: implement TLS cert hot-reload via rustls ResolvesServerCert
165
166    Ok(())
167}
168
169/// Build a `MemoryAuth` from both `passwords` and `users` in the config.
170fn build_memory_auth(auth: &trojan_config::AuthConfig) -> MemoryAuth {
171    let mut mem = MemoryAuth::new();
172    for pw in &auth.passwords {
173        mem.add_password(pw, None);
174    }
175    for u in &auth.users {
176        mem.add_password(&u.password, Some(u.id.clone()));
177    }
178    mem
179}
180
181/// Initialize tracing subscriber with the given logging configuration.
182///
183/// Supports:
184/// - `level`: Base log level (trace, debug, info, warn, error)
185/// - `format`: Output format (json, pretty, compact). Default: pretty
186/// - `output`: Output target (stdout, stderr). Default: stderr
187/// - `filters`: Per-module log level overrides
188fn init_tracing(config: &LoggingConfig) {
189    // Build the env filter from base level and per-module filters
190    let base_level = config.level.as_deref().unwrap_or("info");
191    let mut filter_str = base_level.to_string();
192
193    for (module, level) in &config.filters {
194        filter_str.push(',');
195        filter_str.push_str(module);
196        filter_str.push('=');
197        filter_str.push_str(level);
198    }
199
200    let filter = EnvFilter::try_new(&filter_str).unwrap_or_else(|_| EnvFilter::new("info"));
201
202    let format = config.format.as_deref().unwrap_or("pretty");
203    let output = config.output.as_deref().unwrap_or("stderr");
204
205    // Create the subscriber based on format and output
206    match (format, output) {
207        ("json", "stdout") => {
208            tracing_subscriber::registry()
209                .with(filter)
210                .with(fmt::layer().json().with_writer(io::stdout))
211                .init();
212        }
213        ("json", _) => {
214            tracing_subscriber::registry()
215                .with(filter)
216                .with(fmt::layer().json().with_writer(io::stderr))
217                .init();
218        }
219        ("compact", "stdout") => {
220            tracing_subscriber::registry()
221                .with(filter)
222                .with(fmt::layer().compact().with_writer(io::stdout))
223                .init();
224        }
225        ("compact", _) => {
226            tracing_subscriber::registry()
227                .with(filter)
228                .with(fmt::layer().compact().with_writer(io::stderr))
229                .init();
230        }
231        (_, "stdout") => {
232            // pretty is default
233            tracing_subscriber::registry()
234                .with(filter)
235                .with(fmt::layer().with_writer(io::stdout))
236                .init();
237        }
238        _ => {
239            // pretty to stderr is default
240            tracing_subscriber::registry()
241                .with(filter)
242                .with(fmt::layer().with_writer(io::stderr))
243                .init();
244        }
245    }
246}