1use std::io;
7use std::path::PathBuf;
8use std::sync::Arc;
9
10use clap::Parser;
11use tracing::{info, warn};
12use tracing_subscriber::{EnvFilter, fmt, layer::SubscriberExt, util::SubscriberInitExt};
13use trojan_auth::{MemoryAuth, ReloadableAuth};
14use trojan_config::{CliOverrides, LoggingConfig, apply_overrides, load_config, validate_config};
15
16use crate::{CancellationToken, run_with_shutdown};
17
18#[derive(Parser, Debug, Clone)]
20#[command(name = "trojan-server", version, about = "Trojan server in Rust")]
21pub struct ServerArgs {
22 #[arg(short, long, default_value = "config.toml")]
24 pub config: PathBuf,
25
26 #[command(flatten)]
27 pub overrides: CliOverrides,
28}
29
30pub async fn run(args: ServerArgs) -> Result<(), Box<dyn std::error::Error>> {
35 let mut config = load_config(&args.config)?;
36 apply_overrides(&mut config, &args.overrides);
37 validate_config(&config)?;
38
39 init_tracing(&config.logging);
40
41 if let Some(listen) = &config.metrics.listen {
42 match trojan_metrics::init_metrics_server(listen) {
43 Ok(_handle) => info!(
44 "metrics server listening on {} (/metrics, /health, /ready)",
45 listen
46 ),
47 Err(e) => warn!("failed to start metrics server: {}", e),
48 }
49 }
50
51 let shutdown = CancellationToken::new();
53 let shutdown_signal = shutdown.clone();
54
55 tokio::spawn(async move {
56 shutdown_signal_handler().await;
57 info!("shutdown signal received");
58 shutdown_signal.cancel();
59 });
60
61 let auth = Arc::new(ReloadableAuth::new(MemoryAuth::from_passwords(
63 &config.auth.passwords,
64 )));
65
66 #[cfg(unix)]
68 {
69 let config_path = args.config.clone();
70 let overrides = args.overrides.clone();
71 let auth_reload = auth.clone();
72 tokio::spawn(async move {
73 reload_signal_handler(config_path, overrides, auth_reload).await;
74 });
75 }
76
77 run_with_shutdown(config, auth, shutdown).await?;
78 Ok(())
79}
80
81async fn shutdown_signal_handler() {
83 let ctrl_c = async {
84 if let Err(e) = tokio::signal::ctrl_c().await {
85 warn!("failed to listen for Ctrl+C: {}", e);
86 std::future::pending::<()>().await;
88 }
89 };
90
91 #[cfg(unix)]
92 let terminate = async {
93 match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
94 Ok(mut sig) => {
95 sig.recv().await;
96 }
97 Err(e) => {
98 warn!("failed to listen for SIGTERM: {}", e);
99 std::future::pending::<()>().await;
101 }
102 }
103 };
104
105 #[cfg(not(unix))]
106 let terminate = std::future::pending::<()>();
107
108 tokio::select! {
109 _ = ctrl_c => {}
110 _ = terminate => {}
111 }
112}
113
114#[cfg(unix)]
116async fn reload_signal_handler(
117 config_path: PathBuf,
118 overrides: CliOverrides,
119 auth: Arc<ReloadableAuth>,
120) {
121 use tokio::signal::unix::{SignalKind, signal};
122
123 let mut sighup = match signal(SignalKind::hangup()) {
124 Ok(sig) => sig,
125 Err(e) => {
126 warn!(
127 "failed to install SIGHUP handler: {}, config reload disabled",
128 e
129 );
130 return;
131 }
132 };
133
134 loop {
135 sighup.recv().await;
136 info!("SIGHUP received, reloading configuration");
137
138 match reload_config(&config_path, &overrides, &auth) {
139 Ok(()) => info!("configuration reloaded successfully"),
140 Err(e) => warn!("failed to reload configuration: {}", e),
141 }
142 }
143}
144
145#[cfg(unix)]
147fn reload_config(
148 config_path: &PathBuf,
149 overrides: &CliOverrides,
150 auth: &Arc<ReloadableAuth>,
151) -> Result<(), Box<dyn std::error::Error>> {
152 let mut config = load_config(config_path)?;
153 apply_overrides(&mut config, overrides);
154 validate_config(&config)?;
155
156 let new_auth = MemoryAuth::from_passwords(&config.auth.passwords);
158 auth.reload(new_auth);
159 info!(
160 password_count = config.auth.passwords.len(),
161 "auth passwords reloaded"
162 );
163
164 Ok(())
168}
169
170fn init_tracing(config: &LoggingConfig) {
178 let base_level = config.level.as_deref().unwrap_or("info");
180 let mut filter_str = base_level.to_string();
181
182 for (module, level) in &config.filters {
183 filter_str.push(',');
184 filter_str.push_str(module);
185 filter_str.push('=');
186 filter_str.push_str(level);
187 }
188
189 let filter = EnvFilter::try_new(&filter_str).unwrap_or_else(|_| EnvFilter::new("info"));
190
191 let format = config.format.as_deref().unwrap_or("pretty");
192 let output = config.output.as_deref().unwrap_or("stderr");
193
194 match (format, output) {
196 ("json", "stdout") => {
197 tracing_subscriber::registry()
198 .with(filter)
199 .with(fmt::layer().json().with_writer(io::stdout))
200 .init();
201 }
202 ("json", _) => {
203 tracing_subscriber::registry()
204 .with(filter)
205 .with(fmt::layer().json().with_writer(io::stderr))
206 .init();
207 }
208 ("compact", "stdout") => {
209 tracing_subscriber::registry()
210 .with(filter)
211 .with(fmt::layer().compact().with_writer(io::stdout))
212 .init();
213 }
214 ("compact", _) => {
215 tracing_subscriber::registry()
216 .with(filter)
217 .with(fmt::layer().compact().with_writer(io::stderr))
218 .init();
219 }
220 (_, "stdout") => {
221 tracing_subscriber::registry()
223 .with(filter)
224 .with(fmt::layer().with_writer(io::stdout))
225 .init();
226 }
227 _ => {
228 tracing_subscriber::registry()
230 .with(filter)
231 .with(fmt::layer().with_writer(io::stderr))
232 .init();
233 }
234 }
235}