Skip to main content

trojan_server/
cli.rs

//! CLI module for trojan-server.
//!
//! This module provides the command-line interface that can be used either
//! as a standalone binary or as a subcommand of the main trojan-rs CLI.
5
6use std::io;
7use std::path::PathBuf;
8use std::sync::Arc;
9
10use clap::Parser;
11use tracing::{info, warn};
12use tracing_subscriber::{EnvFilter, fmt, layer::SubscriberExt, util::SubscriberInitExt};
13use trojan_auth::{
14    AuthBackend, MemoryAuth, ReloadableAuth,
15    http::{Codec, HttpAuth},
16};
17use trojan_config::{CliOverrides, LoggingConfig, apply_overrides, load_config, validate_config};
18
19use crate::{CancellationToken, run_with_shutdown};
20
21/// Trojan server CLI arguments.
22#[derive(Parser, Debug, Clone)]
23#[command(name = "trojan-server", version, about = "Trojan server in Rust")]
24pub struct ServerArgs {
25    /// Config file path (json/jsonc/yaml/toml)
26    #[arg(short, long, default_value = "config.toml")]
27    pub config: PathBuf,
28
29    #[command(flatten)]
30    pub overrides: CliOverrides,
31}
32
33/// Run the trojan server with the given arguments.
34///
35/// This is the main entry point for the server CLI, used by both the
36/// standalone binary and the unified trojan-rs CLI.
37pub async fn run(args: ServerArgs) -> Result<(), Box<dyn std::error::Error>> {
38    let mut config = load_config(&args.config)?;
39    apply_overrides(&mut config, &args.overrides);
40    validate_config(&config)?;
41
42    init_tracing(&config.logging);
43
44    // Set up graceful shutdown on SIGTERM/SIGINT
45    let shutdown = CancellationToken::new();
46    let shutdown_signal = shutdown.clone();
47
48    tokio::spawn(async move {
49        shutdown_signal_handler().await;
50        info!("shutdown signal received");
51        shutdown_signal.cancel();
52    });
53
54    // Create reloadable auth backend
55    let auth = Arc::new(ReloadableAuth::new(build_auth(&config.auth)));
56
57    // Set up SIGHUP handler for config reload
58    #[cfg(unix)]
59    {
60        let config_path = args.config.clone();
61        let overrides = args.overrides.clone();
62        let auth_reload = auth.clone();
63        tokio::spawn(async move {
64            reload_signal_handler(config_path, overrides, auth_reload).await;
65        });
66    }
67
68    run_with_shutdown(config, auth, shutdown).await?;
69    Ok(())
70}
71
72/// Wait for shutdown signals (SIGTERM, SIGINT).
73async fn shutdown_signal_handler() {
74    let ctrl_c = async {
75        if let Err(e) = tokio::signal::ctrl_c().await {
76            warn!("failed to listen for Ctrl+C: {}", e);
77            // Fall back to waiting forever
78            std::future::pending::<()>().await;
79        }
80    };
81
82    #[cfg(unix)]
83    let terminate = async {
84        match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
85            Ok(mut sig) => {
86                sig.recv().await;
87            }
88            Err(e) => {
89                warn!("failed to listen for SIGTERM: {}", e);
90                // Fall back to waiting forever
91                std::future::pending::<()>().await;
92            }
93        }
94    };
95
96    #[cfg(not(unix))]
97    let terminate = std::future::pending::<()>();
98
99    tokio::select! {
100        _ = ctrl_c => {}
101        _ = terminate => {}
102    }
103}
104
105/// Handle SIGHUP for config reload (Unix only).
106#[cfg(unix)]
107async fn reload_signal_handler(
108    config_path: PathBuf,
109    overrides: CliOverrides,
110    auth: Arc<ReloadableAuth>,
111) {
112    use tokio::signal::unix::{SignalKind, signal};
113
114    let mut sighup = match signal(SignalKind::hangup()) {
115        Ok(sig) => sig,
116        Err(e) => {
117            warn!(
118                "failed to install SIGHUP handler: {}, config reload disabled",
119                e
120            );
121            return;
122        }
123    };
124
125    loop {
126        sighup.recv().await;
127        info!("SIGHUP received, reloading configuration");
128
129        match reload_config(&config_path, &overrides, &auth) {
130            Ok(()) => info!("configuration reloaded successfully"),
131            Err(e) => warn!("failed to reload configuration: {}", e),
132        }
133    }
134}
135
136/// Reload configuration from file.
137#[cfg(unix)]
138fn reload_config(
139    config_path: &PathBuf,
140    overrides: &CliOverrides,
141    auth: &Arc<ReloadableAuth>,
142) -> Result<(), Box<dyn std::error::Error>> {
143    let mut config = load_config(config_path)?;
144    apply_overrides(&mut config, overrides);
145    validate_config(&config)?;
146
147    // Reload auth backend
148    let new_auth = build_auth(&config.auth);
149    auth.reload(new_auth);
150    info!(
151        password_count = config.auth.passwords.len(),
152        user_count = config.auth.users.len(),
153        http = config.auth.http_url.is_some(),
154        "auth reloaded"
155    );
156
157    // Note: TLS certificates and other settings require server restart
158    // Future enhancement: implement TLS cert hot-reload via rustls ResolvesServerCert
159
160    Ok(())
161}
162
163/// Build an auth backend from config.
164///
165/// If `http_url` is set, creates an [`HttpAuth`] backend that delegates to a
166/// remote auth-worker. Otherwise falls back to in-memory password auth.
167fn build_auth(auth: &trojan_config::AuthConfig) -> Box<dyn AuthBackend> {
168    if let Some(ref url) = auth.http_url {
169        let codec = match auth.http_codec.as_deref() {
170            Some("json") => Codec::Json,
171            _ => Codec::Bincode,
172        };
173        info!(url = %url, codec = ?codec, "using HTTP auth backend");
174        Box::new(HttpAuth::new(
175            url.clone(),
176            codec,
177            auth.http_node_token.clone(),
178        ))
179    } else {
180        let mut mem = MemoryAuth::new();
181        for pw in &auth.passwords {
182            mem.add_password(pw, None);
183        }
184        for u in &auth.users {
185            mem.add_password(&u.password, Some(u.id.clone()));
186        }
187        Box::new(mem)
188    }
189}
190
191/// Initialize tracing subscriber with the given logging configuration.
192///
193/// Supports:
194/// - `level`: Base log level (trace, debug, info, warn, error)
195/// - `format`: Output format (json, pretty, compact). Default: pretty
196/// - `output`: Output target (stdout, stderr). Default: stderr
197/// - `filters`: Per-module log level overrides
198fn init_tracing(config: &LoggingConfig) {
199    // Build the env filter from base level and per-module filters
200    let base_level = config.level.as_deref().unwrap_or("info");
201    let mut filter_str = base_level.to_string();
202
203    for (module, level) in &config.filters {
204        filter_str.push(',');
205        filter_str.push_str(module);
206        filter_str.push('=');
207        filter_str.push_str(level);
208    }
209
210    let filter = EnvFilter::try_new(&filter_str).unwrap_or_else(|_| EnvFilter::new("info"));
211
212    let format = config.format.as_deref().unwrap_or("pretty");
213    let output = config.output.as_deref().unwrap_or("stderr");
214
215    // Create the subscriber based on format and output
216    match (format, output) {
217        ("json", "stdout") => {
218            tracing_subscriber::registry()
219                .with(filter)
220                .with(fmt::layer().json().with_writer(io::stdout))
221                .init();
222        }
223        ("json", _) => {
224            tracing_subscriber::registry()
225                .with(filter)
226                .with(fmt::layer().json().with_writer(io::stderr))
227                .init();
228        }
229        ("compact", "stdout") => {
230            tracing_subscriber::registry()
231                .with(filter)
232                .with(fmt::layer().compact().with_writer(io::stdout))
233                .init();
234        }
235        ("compact", _) => {
236            tracing_subscriber::registry()
237                .with(filter)
238                .with(fmt::layer().compact().with_writer(io::stderr))
239                .init();
240        }
241        (_, "stdout") => {
242            // pretty is default
243            tracing_subscriber::registry()
244                .with(filter)
245                .with(fmt::layer().with_writer(io::stdout))
246                .init();
247        }
248        _ => {
249            // pretty to stderr is default
250            tracing_subscriber::registry()
251                .with(filter)
252                .with(fmt::layer().with_writer(io::stderr))
253                .init();
254        }
255    }
256}