1use serde::{Deserialize, Serialize};
2use std::net::SocketAddr;
3
4#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct Config {
7 pub server: ServerConfig,
9 pub log: LogConfig,
11 pub cors: CorsConfig,
13 #[serde(default)]
15 #[cfg(feature = "database")]
16 pub database: Option<DatabaseConfig>,
17 #[serde(default)]
19 #[cfg(feature = "redis")]
20 pub redis: Option<RedisConfig>,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct ServerConfig {
26 pub addr: String,
28 pub port: u16,
30 pub workers: Option<usize>,
32 #[serde(default)]
34 pub context_path: Option<String>,
35}
36
37impl ServerConfig {
38 pub fn socket_addr(&self) -> Result<SocketAddr, std::net::AddrParseError> {
40 format!("{}:{}", self.addr, self.port).parse()
41 }
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct LogConfig {
47 pub level: String,
49 pub json: bool,
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct CorsConfig {
56 pub allowed_origins: Vec<String>,
58 pub allowed_methods: Vec<String>,
60 pub allowed_headers: Vec<String>,
62 pub allow_credentials: bool,
64}
65
/// Database connection settings; compiled only with the `database` feature.
#[cfg(feature = "database")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseConfig {
    /// Connection URL.
    pub url: String,
    /// Maximum pool size; falls back to `default_max_connections()` when omitted.
    #[serde(default = "default_max_connections")]
    pub max_connections: u32,
    /// Minimum pool size; falls back to `default_min_connections()` when omitted.
    #[serde(default = "default_min_connections")]
    pub min_connections: u32,
}
79
/// Serde default for `DatabaseConfig::max_connections`.
#[cfg(feature = "database")]
fn default_max_connections() -> u32 {
    const MAX_CONNECTIONS: u32 = 100;
    MAX_CONNECTIONS
}
84
/// Serde default for `DatabaseConfig::min_connections`.
#[cfg(feature = "database")]
fn default_min_connections() -> u32 {
    const MIN_CONNECTIONS: u32 = 10;
    MIN_CONNECTIONS
}
89
/// Redis connection settings; compiled only with the `redis` feature.
///
/// `Debug` is implemented by hand so that formatting the configuration with
/// `{:?}` (for example when logging it at startup) never prints the password.
/// NOTE(review): `url` is printed as-is; if it can embed credentials
/// (`redis://:secret@host`), consider redacting it as well.
#[cfg(feature = "redis")]
#[derive(Clone, Serialize, Deserialize)]
pub struct RedisConfig {
    /// Connection URL.
    pub url: String,
    /// Optional password; `None` when omitted.
    #[serde(default)]
    pub password: Option<String>,
    /// Connection pool size; falls back to `default_pool_size()` when omitted.
    #[serde(default = "default_pool_size")]
    pub pool_size: usize,
}

#[cfg(feature = "redis")]
impl std::fmt::Debug for RedisConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RedisConfig")
            .field("url", &self.url)
            // Show only whether a password is set, never its value.
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("pool_size", &self.pool_size)
            .finish()
    }
}
103
/// Serde default for `RedisConfig::pool_size`.
#[cfg(feature = "redis")]
fn default_pool_size() -> usize {
    const POOL_SIZE: usize = 10;
    POOL_SIZE
}
108
109impl Default for Config {
110 fn default() -> Self {
111 Self {
112 server: ServerConfig {
113 addr: "127.0.0.1".to_string(),
114 port: 3000,
115 workers: None,
116 context_path: None,
117 },
118 log: LogConfig {
119 level: "info".to_string(),
120 json: false,
121 },
122 cors: CorsConfig {
123 allowed_origins: vec!["*".to_string()],
124 allowed_methods: vec![
125 "GET".to_string(),
126 "POST".to_string(),
127 "PUT".to_string(),
128 "DELETE".to_string(),
129 "PATCH".to_string(),
130 "OPTIONS".to_string(),
131 ],
132 allowed_headers: vec!["*".to_string()],
133 allow_credentials: false,
135 },
136 #[cfg(feature = "database")]
137 database: None,
138 #[cfg(feature = "redis")]
139 redis: None,
140 }
141 }
142}
143
144impl Config {
145 fn find_project_root() -> Option<std::path::PathBuf> {
170 if let Ok(exe_path) = std::env::current_exe() {
173 if let Some(exe_name) = exe_path.file_stem().and_then(|s| s.to_str()) {
175 if let Some(exe_dir) = exe_path.parent() {
177 let mut path = exe_dir.to_path_buf();
178 loop {
179 if let Some(parent) = path.parent() {
181 let project_dir = parent.join(exe_name);
182 if project_dir.join(".env").exists() {
184 return Some(project_dir);
185 }
186 let cargo_toml = project_dir.join("Cargo.toml");
188 if cargo_toml.exists() {
189 if let Ok(content) = std::fs::read_to_string(&cargo_toml) {
190 if !content.contains("[workspace]") {
191 return Some(project_dir);
192 }
193 }
194 }
195 }
196 if path.join(".env").exists() {
198 return Some(path);
199 }
200 match path.parent() {
202 Some(parent) => path = parent.to_path_buf(),
203 None => break,
204 }
205 }
206 }
207 }
208 }
209
210 if let Ok(mut current_dir) = std::env::current_dir() {
212 loop {
213 if current_dir.join(".env").exists() {
214 let cargo_toml = current_dir.join("Cargo.toml");
216 if cargo_toml.exists() {
217 if let Ok(content) = std::fs::read_to_string(&cargo_toml) {
218 if content.contains("[workspace]") {
219 return Some(current_dir);
221 }
222 }
223 }
224 return Some(current_dir);
225 }
226 match current_dir.parent() {
227 Some(parent) => current_dir = parent.to_path_buf(),
228 None => break,
229 }
230 }
231 }
232
233 None
234 }
235
236 pub fn from_env() -> Result<Self, config::ConfigError> {
237 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
242 let env_path = std::path::Path::new(&manifest_dir).join(".env");
243 if env_path.exists() {
244 if dotenvy::from_path(&env_path).is_ok() {
245 eprintln!(
246 "✓ 从 CARGO_MANIFEST_DIR 加载 .env 文件: {}",
247 env_path.display()
248 );
249 return Self::load_config_from_env();
250 }
251 }
252 }
253
254 if let Some(project_root) = Self::find_project_root() {
256 let env_path = project_root.join(".env");
257 if env_path.exists() {
258 if dotenvy::from_path(&env_path).is_ok() {
259 eprintln!("✓ 从项目根目录加载 .env 文件: {}", env_path.display());
260 return Self::load_config_from_env();
261 }
262 }
263 }
264
265 match dotenvy::dotenv() {
267 Ok(path) => {
268 eprintln!("✓ 从当前工作目录向上查找加载 .env 文件: {}", path.display());
269 }
270 Err(_) => {
271 eprintln!("⚠ 未找到 .env 文件,将使用环境变量和默认配置");
272 }
273 }
274
275 Self::load_config_from_env()
276 }
277
278 fn load_config_from_env() -> Result<Self, config::ConfigError> {
280 let mut default_origins = vec!["*".to_string()];
282 let mut default_methods = vec![
283 "GET".to_string(),
284 "POST".to_string(),
285 "PUT".to_string(),
286 "DELETE".to_string(),
287 "PATCH".to_string(),
288 "OPTIONS".to_string(),
289 ];
290 let mut default_headers = vec!["*".to_string()];
291
292 if let Ok(origins_str) = std::env::var("APP__CORS__ALLOWED_ORIGINS") {
294 default_origins = origins_str
295 .split(',')
296 .map(|s| s.trim().to_string())
297 .collect();
298 }
299
300 if let Ok(methods_str) = std::env::var("APP__CORS__ALLOWED_METHODS") {
301 default_methods = methods_str
302 .split(',')
303 .map(|s| s.trim().to_string())
304 .collect();
305 }
306
307 if let Ok(headers_str) = std::env::var("APP__CORS__ALLOWED_HEADERS") {
308 default_headers = headers_str
309 .split(',')
310 .map(|s| s.trim().to_string())
311 .collect();
312 }
313
314 let origins_backup = std::env::var("APP__CORS__ALLOWED_ORIGINS").ok();
316 let methods_backup = std::env::var("APP__CORS__ALLOWED_METHODS").ok();
317 let headers_backup = std::env::var("APP__CORS__ALLOWED_HEADERS").ok();
318
319 if origins_backup.is_some() {
320 std::env::remove_var("APP__CORS__ALLOWED_ORIGINS");
321 }
322 if methods_backup.is_some() {
323 std::env::remove_var("APP__CORS__ALLOWED_METHODS");
324 }
325 if headers_backup.is_some() {
326 std::env::remove_var("APP__CORS__ALLOWED_HEADERS");
327 }
328
329 let builder = config::Config::builder()
330 .set_default("server.addr", "127.0.0.1")?
331 .set_default("server.port", 3000)?
332 .set_default("log.level", "info")?
333 .set_default("log.json", false)?
334 .set_default("cors.allowed_origins", default_origins.clone())?
335 .set_default("cors.allowed_methods", default_methods.clone())?
336 .set_default("cors.allowed_headers", default_headers.clone())?
337 .set_default("cors.allow_credentials", false)?
338 .add_source(config::Environment::with_prefix("APP").separator("__"));
339
340 let config = builder.build()?;
341 let result: Config = config.try_deserialize()?;
342
343 Ok(result)
344 }
345}