1use std::io;
7use std::path::PathBuf;
8use std::sync::Arc;
9
10use clap::Parser;
11use tracing::{info, warn};
12use tracing_subscriber::{EnvFilter, fmt, layer::SubscriberExt, util::SubscriberInitExt};
13use trojan_auth::{MemoryAuth, ReloadableAuth};
14use trojan_config::{CliOverrides, LoggingConfig, apply_overrides, load_config, validate_config};
15
16use crate::{CancellationToken, run_with_shutdown};
17
// Command-line arguments accepted by the `trojan-server` binary.
//
// NOTE: plain `//` comments are used on purpose. clap's derive turns `///`
// doc comments into `--help` text, so switching to doc comments here would
// change the CLI's user-visible output.
#[derive(Parser, Debug, Clone)]
#[command(name = "trojan-server", version, about = "Trojan server in Rust")]
pub struct ServerArgs {
    // Path to the configuration file (`-c` / `--config`); defaults to
    // `config.toml`. Also re-read on SIGHUP by the reload task in `run`.
    #[arg(short, long, default_value = "config.toml")]
    pub config: PathBuf,

    // Extra CLI flags flattened in from `CliOverrides`; `run` applies them on
    // top of the loaded file via `apply_overrides` before validation.
    #[command(flatten)]
    pub overrides: CliOverrides,
}
29
30pub async fn run(args: ServerArgs) -> Result<(), Box<dyn std::error::Error>> {
35 let mut config = load_config(&args.config)?;
36 apply_overrides(&mut config, &args.overrides);
37 validate_config(&config)?;
38
39 init_tracing(&config.logging);
40
41 if let Some(listen) = &config.metrics.listen {
42 match trojan_metrics::init_metrics_server(listen) {
43 Ok(_handle) => info!(
44 "metrics server listening on {} (/metrics, /health, /ready)",
45 listen
46 ),
47 Err(e) => warn!("failed to start metrics server: {}", e),
48 }
49 }
50
51 let shutdown = CancellationToken::new();
53 let shutdown_signal = shutdown.clone();
54
55 tokio::spawn(async move {
56 shutdown_signal_handler().await;
57 info!("shutdown signal received");
58 shutdown_signal.cancel();
59 });
60
61 let auth = Arc::new(ReloadableAuth::new(build_memory_auth(&config.auth)));
63
64 #[cfg(unix)]
66 {
67 let config_path = args.config.clone();
68 let overrides = args.overrides.clone();
69 let auth_reload = auth.clone();
70 tokio::spawn(async move {
71 reload_signal_handler(config_path, overrides, auth_reload).await;
72 });
73 }
74
75 run_with_shutdown(config, auth, shutdown).await?;
76 Ok(())
77}
78
79async fn shutdown_signal_handler() {
81 let ctrl_c = async {
82 if let Err(e) = tokio::signal::ctrl_c().await {
83 warn!("failed to listen for Ctrl+C: {}", e);
84 std::future::pending::<()>().await;
86 }
87 };
88
89 #[cfg(unix)]
90 let terminate = async {
91 match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
92 Ok(mut sig) => {
93 sig.recv().await;
94 }
95 Err(e) => {
96 warn!("failed to listen for SIGTERM: {}", e);
97 std::future::pending::<()>().await;
99 }
100 }
101 };
102
103 #[cfg(not(unix))]
104 let terminate = std::future::pending::<()>();
105
106 tokio::select! {
107 _ = ctrl_c => {}
108 _ = terminate => {}
109 }
110}
111
112#[cfg(unix)]
114async fn reload_signal_handler(
115 config_path: PathBuf,
116 overrides: CliOverrides,
117 auth: Arc<ReloadableAuth>,
118) {
119 use tokio::signal::unix::{SignalKind, signal};
120
121 let mut sighup = match signal(SignalKind::hangup()) {
122 Ok(sig) => sig,
123 Err(e) => {
124 warn!(
125 "failed to install SIGHUP handler: {}, config reload disabled",
126 e
127 );
128 return;
129 }
130 };
131
132 loop {
133 sighup.recv().await;
134 info!("SIGHUP received, reloading configuration");
135
136 match reload_config(&config_path, &overrides, &auth) {
137 Ok(()) => info!("configuration reloaded successfully"),
138 Err(e) => warn!("failed to reload configuration: {}", e),
139 }
140 }
141}
142
143#[cfg(unix)]
145fn reload_config(
146 config_path: &PathBuf,
147 overrides: &CliOverrides,
148 auth: &Arc<ReloadableAuth>,
149) -> Result<(), Box<dyn std::error::Error>> {
150 let mut config = load_config(config_path)?;
151 apply_overrides(&mut config, overrides);
152 validate_config(&config)?;
153
154 let new_auth = build_memory_auth(&config.auth);
156 auth.reload(new_auth);
157 info!(
158 password_count = config.auth.passwords.len(),
159 user_count = config.auth.users.len(),
160 "auth reloaded"
161 );
162
163 Ok(())
167}
168
169fn build_memory_auth(auth: &trojan_config::AuthConfig) -> MemoryAuth {
171 let mut mem = MemoryAuth::new();
172 for pw in &auth.passwords {
173 mem.add_password(pw, None);
174 }
175 for u in &auth.users {
176 mem.add_password(&u.password, Some(u.id.clone()));
177 }
178 mem
179}
180
181fn init_tracing(config: &LoggingConfig) {
189 let base_level = config.level.as_deref().unwrap_or("info");
191 let mut filter_str = base_level.to_string();
192
193 for (module, level) in &config.filters {
194 filter_str.push(',');
195 filter_str.push_str(module);
196 filter_str.push('=');
197 filter_str.push_str(level);
198 }
199
200 let filter = EnvFilter::try_new(&filter_str).unwrap_or_else(|_| EnvFilter::new("info"));
201
202 let format = config.format.as_deref().unwrap_or("pretty");
203 let output = config.output.as_deref().unwrap_or("stderr");
204
205 match (format, output) {
207 ("json", "stdout") => {
208 tracing_subscriber::registry()
209 .with(filter)
210 .with(fmt::layer().json().with_writer(io::stdout))
211 .init();
212 }
213 ("json", _) => {
214 tracing_subscriber::registry()
215 .with(filter)
216 .with(fmt::layer().json().with_writer(io::stderr))
217 .init();
218 }
219 ("compact", "stdout") => {
220 tracing_subscriber::registry()
221 .with(filter)
222 .with(fmt::layer().compact().with_writer(io::stdout))
223 .init();
224 }
225 ("compact", _) => {
226 tracing_subscriber::registry()
227 .with(filter)
228 .with(fmt::layer().compact().with_writer(io::stderr))
229 .init();
230 }
231 (_, "stdout") => {
232 tracing_subscriber::registry()
234 .with(filter)
235 .with(fmt::layer().with_writer(io::stdout))
236 .init();
237 }
238 _ => {
239 tracing_subscriber::registry()
241 .with(filter)
242 .with(fmt::layer().with_writer(io::stderr))
243 .init();
244 }
245 }
246}