1use std::io;
7use std::path::PathBuf;
8use std::sync::Arc;
9
10use clap::Parser;
11use tracing::{info, warn};
12use tracing_subscriber::{EnvFilter, fmt, layer::SubscriberExt, util::SubscriberInitExt};
13use trojan_auth::{
14 AuthBackend, MemoryAuth, ReloadableAuth,
15 http::{Codec, HttpAuth},
16};
17use trojan_config::{CliOverrides, LoggingConfig, apply_overrides, load_config, validate_config};
18
19use crate::{CancellationToken, run_with_shutdown};
20
21#[derive(Parser, Debug, Clone)]
23#[command(name = "trojan-server", version, about = "Trojan server in Rust")]
24pub struct ServerArgs {
25 #[arg(short, long, default_value = "config.toml")]
27 pub config: PathBuf,
28
29 #[command(flatten)]
30 pub overrides: CliOverrides,
31}
32
33pub async fn run(args: ServerArgs) -> Result<(), Box<dyn std::error::Error>> {
38 let mut config = load_config(&args.config)?;
39 apply_overrides(&mut config, &args.overrides);
40 validate_config(&config)?;
41
42 init_tracing(&config.logging);
43
44 let shutdown = CancellationToken::new();
46 let shutdown_signal = shutdown.clone();
47
48 tokio::spawn(async move {
49 shutdown_signal_handler().await;
50 info!("shutdown signal received");
51 shutdown_signal.cancel();
52 });
53
54 let auth = Arc::new(ReloadableAuth::new(build_auth(&config.auth)));
56
57 #[cfg(unix)]
59 {
60 let config_path = args.config.clone();
61 let overrides = args.overrides.clone();
62 let auth_reload = auth.clone();
63 tokio::spawn(async move {
64 reload_signal_handler(config_path, overrides, auth_reload).await;
65 });
66 }
67
68 run_with_shutdown(config, auth, shutdown).await?;
69 Ok(())
70}
71
72async fn shutdown_signal_handler() {
74 let ctrl_c = async {
75 if let Err(e) = tokio::signal::ctrl_c().await {
76 warn!("failed to listen for Ctrl+C: {}", e);
77 std::future::pending::<()>().await;
79 }
80 };
81
82 #[cfg(unix)]
83 let terminate = async {
84 match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
85 Ok(mut sig) => {
86 sig.recv().await;
87 }
88 Err(e) => {
89 warn!("failed to listen for SIGTERM: {}", e);
90 std::future::pending::<()>().await;
92 }
93 }
94 };
95
96 #[cfg(not(unix))]
97 let terminate = std::future::pending::<()>();
98
99 tokio::select! {
100 _ = ctrl_c => {}
101 _ = terminate => {}
102 }
103}
104
105#[cfg(unix)]
107async fn reload_signal_handler(
108 config_path: PathBuf,
109 overrides: CliOverrides,
110 auth: Arc<ReloadableAuth>,
111) {
112 use tokio::signal::unix::{SignalKind, signal};
113
114 let mut sighup = match signal(SignalKind::hangup()) {
115 Ok(sig) => sig,
116 Err(e) => {
117 warn!(
118 "failed to install SIGHUP handler: {}, config reload disabled",
119 e
120 );
121 return;
122 }
123 };
124
125 loop {
126 sighup.recv().await;
127 info!("SIGHUP received, reloading configuration");
128
129 match reload_config(&config_path, &overrides, &auth) {
130 Ok(()) => info!("configuration reloaded successfully"),
131 Err(e) => warn!("failed to reload configuration: {}", e),
132 }
133 }
134}
135
136#[cfg(unix)]
138fn reload_config(
139 config_path: &PathBuf,
140 overrides: &CliOverrides,
141 auth: &Arc<ReloadableAuth>,
142) -> Result<(), Box<dyn std::error::Error>> {
143 let mut config = load_config(config_path)?;
144 apply_overrides(&mut config, overrides);
145 validate_config(&config)?;
146
147 let new_auth = build_auth(&config.auth);
149 auth.reload(new_auth);
150 info!(
151 password_count = config.auth.passwords.len(),
152 user_count = config.auth.users.len(),
153 http = config.auth.http_url.is_some(),
154 "auth reloaded"
155 );
156
157 Ok(())
161}
162
163fn build_auth(auth: &trojan_config::AuthConfig) -> Box<dyn AuthBackend> {
168 if let Some(ref url) = auth.http_url {
169 let codec = match auth.http_codec.as_deref() {
170 Some("json") => Codec::Json,
171 _ => Codec::Bincode,
172 };
173 info!(url = %url, codec = ?codec, "using HTTP auth backend");
174 Box::new(HttpAuth::new(
175 url.clone(),
176 codec,
177 auth.http_node_token.clone(),
178 ))
179 } else {
180 let mut mem = MemoryAuth::new();
181 for pw in &auth.passwords {
182 mem.add_password(pw, None);
183 }
184 for u in &auth.users {
185 mem.add_password(&u.password, Some(u.id.clone()));
186 }
187 Box::new(mem)
188 }
189}
190
191fn init_tracing(config: &LoggingConfig) {
199 let base_level = config.level.as_deref().unwrap_or("info");
201 let mut filter_str = base_level.to_string();
202
203 for (module, level) in &config.filters {
204 filter_str.push(',');
205 filter_str.push_str(module);
206 filter_str.push('=');
207 filter_str.push_str(level);
208 }
209
210 let filter = EnvFilter::try_new(&filter_str).unwrap_or_else(|_| EnvFilter::new("info"));
211
212 let format = config.format.as_deref().unwrap_or("pretty");
213 let output = config.output.as_deref().unwrap_or("stderr");
214
215 match (format, output) {
217 ("json", "stdout") => {
218 tracing_subscriber::registry()
219 .with(filter)
220 .with(fmt::layer().json().with_writer(io::stdout))
221 .init();
222 }
223 ("json", _) => {
224 tracing_subscriber::registry()
225 .with(filter)
226 .with(fmt::layer().json().with_writer(io::stderr))
227 .init();
228 }
229 ("compact", "stdout") => {
230 tracing_subscriber::registry()
231 .with(filter)
232 .with(fmt::layer().compact().with_writer(io::stdout))
233 .init();
234 }
235 ("compact", _) => {
236 tracing_subscriber::registry()
237 .with(filter)
238 .with(fmt::layer().compact().with_writer(io::stderr))
239 .init();
240 }
241 (_, "stdout") => {
242 tracing_subscriber::registry()
244 .with(filter)
245 .with(fmt::layer().with_writer(io::stdout))
246 .init();
247 }
248 _ => {
249 tracing_subscriber::registry()
251 .with(filter)
252 .with(fmt::layer().with_writer(io::stderr))
253 .init();
254 }
255 }
256}