Skip to main content

trojan_server/
cli.rs

//! CLI module for trojan-server.
//!
//! This module provides the command-line interface, which can be used either
//! as a standalone binary or as a subcommand of the main trojan-rs CLI.
5
6use std::io;
7use std::path::PathBuf;
8use std::sync::Arc;
9use std::time::Duration;
10
11use clap::Parser;
12use tracing::{info, warn};
13use tracing_subscriber::{EnvFilter, fmt, layer::SubscriberExt, util::SubscriberInitExt};
14
15use trojan_auth::{
16    AuthBackend, MemoryAuth, ReloadableAuth,
17    http::{Codec, HttpAuth, HttpAuthConfig},
18};
19use trojan_config::{CliOverrides, LoggingConfig, apply_overrides, load_config, validate_config};
20
21use crate::{CancellationToken, run_with_shutdown};
22
23/// Trojan server CLI arguments.
24#[derive(Parser, Debug, Clone)]
25#[command(name = "trojan-server", version, about = "Trojan server in Rust")]
26pub struct ServerArgs {
27    /// Config file path (json/jsonc/yaml/toml)
28    #[arg(short, long, default_value = "config.toml")]
29    pub config: PathBuf,
30
31    #[command(flatten)]
32    pub overrides: CliOverrides,
33}
34
35/// Run the trojan server with the given arguments.
36///
37/// This is the main entry point for the server CLI, used by both the
38/// standalone binary and the unified trojan-rs CLI.
39pub async fn run(args: ServerArgs) -> Result<(), Box<dyn std::error::Error>> {
40    let mut config = load_config(&args.config)?;
41    apply_overrides(&mut config, &args.overrides);
42    validate_config(&config)?;
43
44    init_tracing(&config.logging);
45
46    // Set up graceful shutdown on SIGTERM/SIGINT
47    let shutdown = CancellationToken::new();
48    let shutdown_signal = shutdown.clone();
49
50    tokio::spawn(async move {
51        shutdown_signal_handler().await;
52        info!("shutdown signal received");
53        shutdown_signal.cancel();
54    });
55
56    // Create reloadable auth backend
57    let auth = Arc::new(ReloadableAuth::new(build_auth(&config.auth)));
58
59    // Set up SIGHUP handler for config reload
60    #[cfg(unix)]
61    {
62        let config_path = args.config.clone();
63        let overrides = args.overrides.clone();
64        let auth_reload = auth.clone();
65        tokio::spawn(async move {
66            reload_signal_handler(config_path, overrides, auth_reload).await;
67        });
68    }
69
70    run_with_shutdown(config, auth, shutdown).await?;
71    Ok(())
72}
73
74/// Wait for shutdown signals (SIGTERM, SIGINT).
75async fn shutdown_signal_handler() {
76    let ctrl_c = async {
77        if let Err(e) = tokio::signal::ctrl_c().await {
78            warn!("failed to listen for Ctrl+C: {}", e);
79            // Fall back to waiting forever
80            std::future::pending::<()>().await;
81        }
82    };
83
84    #[cfg(unix)]
85    let terminate = async {
86        match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
87            Ok(mut sig) => {
88                sig.recv().await;
89            }
90            Err(e) => {
91                warn!("failed to listen for SIGTERM: {}", e);
92                // Fall back to waiting forever
93                std::future::pending::<()>().await;
94            }
95        }
96    };
97
98    #[cfg(not(unix))]
99    let terminate = std::future::pending::<()>();
100
101    tokio::select! {
102        _ = ctrl_c => {}
103        _ = terminate => {}
104    }
105}
106
107/// Handle SIGHUP for config reload (Unix only).
108#[cfg(unix)]
109async fn reload_signal_handler(
110    config_path: PathBuf,
111    overrides: CliOverrides,
112    auth: Arc<ReloadableAuth>,
113) {
114    use tokio::signal::unix::{SignalKind, signal};
115
116    let mut sighup = match signal(SignalKind::hangup()) {
117        Ok(sig) => sig,
118        Err(e) => {
119            warn!(
120                "failed to install SIGHUP handler: {}, config reload disabled",
121                e
122            );
123            return;
124        }
125    };
126
127    loop {
128        sighup.recv().await;
129        info!("SIGHUP received, reloading configuration");
130
131        match reload_config(&config_path, &overrides, &auth) {
132            Ok(()) => info!("configuration reloaded successfully"),
133            Err(e) => warn!("failed to reload configuration: {}", e),
134        }
135    }
136}
137
138/// Reload configuration from file.
139#[cfg(unix)]
140fn reload_config(
141    config_path: &PathBuf,
142    overrides: &CliOverrides,
143    auth: &Arc<ReloadableAuth>,
144) -> Result<(), Box<dyn std::error::Error>> {
145    let mut config = load_config(config_path)?;
146    apply_overrides(&mut config, overrides);
147    validate_config(&config)?;
148
149    // Reload auth backend
150    let new_auth = build_auth(&config.auth);
151    auth.reload(new_auth);
152    info!(
153        password_count = config.auth.passwords.len(),
154        user_count = config.auth.users.len(),
155        http = config.auth.http_url.is_some(),
156        "auth reloaded"
157    );
158
159    // Note: TLS certificates and other settings require server restart
160    // Future enhancement: implement TLS cert hot-reload via rustls ResolvesServerCert
161
162    Ok(())
163}
164
165/// Build an auth backend from config.
166///
167/// If `http_url` is set, creates an [`HttpAuth`] backend that delegates to a
168/// remote auth-worker. Otherwise falls back to in-memory password auth.
169fn build_auth(auth: &trojan_config::AuthConfig) -> Box<dyn AuthBackend> {
170    if let Some(ref url) = auth.http_url {
171        let codec = match auth.http_codec.as_deref() {
172            Some("json") => Codec::Json,
173            _ => Codec::Bincode,
174        };
175        info!(
176            url = %url,
177            codec = ?codec,
178            cache_ttl = auth.http_cache_ttl_secs,
179            stale_ttl = auth.http_cache_stale_ttl_secs,
180            neg_cache_ttl = auth.http_cache_neg_ttl_secs,
181            "using HTTP auth backend"
182        );
183        let config = HttpAuthConfig {
184            base_url: url.clone(),
185            codec,
186            node_token: auth.http_node_token.clone(),
187            cache_ttl: Duration::from_secs(auth.http_cache_ttl_secs),
188            stale_ttl: Duration::from_secs(auth.http_cache_stale_ttl_secs),
189            neg_cache_ttl: Duration::from_secs(auth.http_cache_neg_ttl_secs),
190        };
191        Box::new(HttpAuth::new(config))
192    } else {
193        let mut mem = MemoryAuth::new();
194        for pw in &auth.passwords {
195            mem.add_password(pw, None);
196        }
197        for u in &auth.users {
198            mem.add_password(&u.password, Some(u.id.clone()));
199        }
200        Box::new(mem)
201    }
202}
203
204/// Initialize tracing subscriber with the given logging configuration.
205///
206/// Supports:
207/// - `level`: Base log level (trace, debug, info, warn, error)
208/// - `format`: Output format (json, pretty, compact). Default: pretty
209/// - `output`: Output target (stdout, stderr). Default: stderr
210/// - `filters`: Per-module log level overrides
211fn init_tracing(config: &LoggingConfig) {
212    // Build the env filter from base level and per-module filters
213    let base_level = config.level.as_deref().unwrap_or("info");
214    let mut filter_str = base_level.to_string();
215
216    for (module, level) in &config.filters {
217        filter_str.push(',');
218        filter_str.push_str(module);
219        filter_str.push('=');
220        filter_str.push_str(level);
221    }
222
223    let filter = EnvFilter::try_new(&filter_str).unwrap_or_else(|_| EnvFilter::new("info"));
224
225    let format = config.format.as_deref().unwrap_or("pretty");
226    let output = config.output.as_deref().unwrap_or("stderr");
227
228    // Create the subscriber based on format and output
229    match (format, output) {
230        ("json", "stdout") => {
231            tracing_subscriber::registry()
232                .with(filter)
233                .with(fmt::layer().json().with_writer(io::stdout))
234                .init();
235        }
236        ("json", _) => {
237            tracing_subscriber::registry()
238                .with(filter)
239                .with(fmt::layer().json().with_writer(io::stderr))
240                .init();
241        }
242        ("compact", "stdout") => {
243            tracing_subscriber::registry()
244                .with(filter)
245                .with(fmt::layer().compact().with_writer(io::stdout))
246                .init();
247        }
248        ("compact", _) => {
249            tracing_subscriber::registry()
250                .with(filter)
251                .with(fmt::layer().compact().with_writer(io::stderr))
252                .init();
253        }
254        (_, "stdout") => {
255            // pretty is default
256            tracing_subscriber::registry()
257                .with(filter)
258                .with(fmt::layer().with_writer(io::stdout))
259                .init();
260        }
261        _ => {
262            // pretty to stderr is default
263            tracing_subscriber::registry()
264                .with(filter)
265                .with(fmt::layer().with_writer(io::stderr))
266                .init();
267        }
268    }
269}