fbc_starter/
config.rs

1use serde::{Deserialize, Serialize};
2use std::net::SocketAddr;
3
4/// 应用配置
5#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct Config {
7    /// 服务器配置
8    pub server: ServerConfig,
9    /// 日志配置
10    pub log: LogConfig,
11    /// CORS 配置
12    pub cors: CorsConfig,
13    /// 数据库配置(可选,需要启用 database 特性)
14    #[serde(default)]
15    #[cfg(feature = "database")]
16    pub database: Option<DatabaseConfig>,
17    /// Redis 配置(可选,需要启用 redis 特性)
18    #[serde(default)]
19    #[cfg(feature = "redis")]
20    pub redis: Option<RedisConfig>,
21}
22
23/// 服务器配置
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct ServerConfig {
26    /// 监听地址
27    pub addr: String,
28    /// 端口
29    pub port: u16,
30    /// 工作线程数(0 表示使用默认值)
31    pub workers: Option<usize>,
32    /// 上下文路径(可选),例如 "/api",如果不配置则为空
33    #[serde(default)]
34    pub context_path: Option<String>,
35}
36
37impl ServerConfig {
38    /// 获取完整的 SocketAddr
39    pub fn socket_addr(&self) -> Result<SocketAddr, std::net::AddrParseError> {
40        format!("{}:{}", self.addr, self.port).parse()
41    }
42}
43
44/// 日志配置
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct LogConfig {
47    /// 日志级别 (trace, debug, info, warn, error)
48    pub level: String,
49    /// 是否使用 JSON 格式
50    pub json: bool,
51}
52
53/// CORS 配置
54#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct CorsConfig {
56    /// 允许的源(* 表示允许所有)
57    pub allowed_origins: Vec<String>,
58    /// 允许的方法
59    pub allowed_methods: Vec<String>,
60    /// 允许的请求头
61    pub allowed_headers: Vec<String>,
62    /// 是否允许凭证
63    pub allow_credentials: bool,
64}
65
66/// 数据库配置(需要启用 database 特性)
67#[cfg(feature = "database")]
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct DatabaseConfig {
70    /// 数据库 URL(例如:postgres://user:password@localhost/dbname)
71    pub url: String,
72    /// 最大连接数
73    #[serde(default = "default_max_connections")]
74    pub max_connections: u32,
75    /// 最小连接数
76    #[serde(default = "default_min_connections")]
77    pub min_connections: u32,
78}
79
80#[cfg(feature = "database")]
81fn default_max_connections() -> u32 {
82    100
83}
84
85#[cfg(feature = "database")]
86fn default_min_connections() -> u32 {
87    10
88}
89
90/// Redis 配置(需要启用 redis 特性)
91#[cfg(feature = "redis")]
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct RedisConfig {
94    /// Redis URL(例如:redis://127.0.0.1:6379 或 redis://:password@127.0.0.1:6379)
95    pub url: String,
96    /// Redis 密码(可选,如果 URL 中已包含密码则不需要)
97    #[serde(default)]
98    pub password: Option<String>,
99    /// 连接池大小
100    #[serde(default = "default_pool_size")]
101    pub pool_size: usize,
102}
103
104#[cfg(feature = "redis")]
105fn default_pool_size() -> usize {
106    10
107}
108
109impl Default for Config {
110    fn default() -> Self {
111        Self {
112            server: ServerConfig {
113                addr: "127.0.0.1".to_string(),
114                port: 3000,
115                workers: None,
116                context_path: None,
117            },
118            log: LogConfig {
119                level: "info".to_string(),
120                json: false,
121            },
122            cors: CorsConfig {
123                allowed_origins: vec!["*".to_string()],
124                allowed_methods: vec![
125                    "GET".to_string(),
126                    "POST".to_string(),
127                    "PUT".to_string(),
128                    "DELETE".to_string(),
129                    "PATCH".to_string(),
130                    "OPTIONS".to_string(),
131                ],
132                allowed_headers: vec!["*".to_string()],
133                // 注意:当 allowed_origins 或 allowed_headers 为 * 时,allow_credentials 会自动设置为 false
134                allow_credentials: false,
135            },
136            #[cfg(feature = "database")]
137            database: None,
138            #[cfg(feature = "redis")]
139            redis: None,
140        }
141    }
142}
143
144impl Config {
145    /// 从 .env 文件和环境变量加载配置
146    ///
147    /// 配置项命名规则:
148    /// - APP__SERVER__ADDR -> server.addr
149    /// - APP__SERVER__PORT -> server.port
150    /// - APP__SERVER__CONTEXT_PATH -> server.context_path (可选,例如 "/api")
151    /// - APP__LOG__LEVEL -> log.level
152    /// - APP__LOG__JSON -> log.json
153    /// - APP__CORS__ALLOWED_ORIGINS -> cors.allowed_origins (逗号分隔)
154    /// - APP__CORS__ALLOWED_METHODS -> cors.allowed_methods (逗号分隔)
155    /// - APP__CORS__ALLOWED_HEADERS -> cors.allowed_headers (逗号分隔)
156    /// - APP__CORS__ALLOW_CREDENTIALS -> cors.allow_credentials
157    /// - APP__DATABASE__URL -> database.url (可选,需要启用 database 特性)
158    /// - APP__DATABASE__MAX_CONNECTIONS -> database.max_connections (可选,默认 100)
159    /// - APP__DATABASE__MIN_CONNECTIONS -> database.min_connections (可选,默认 10)
160    /// - APP__REDIS__URL -> redis.url (可选,需要启用 redis 特性)
161    /// - APP__REDIS__PASSWORD -> redis.password (可选,如果 URL 中已包含密码则不需要)
162    /// - APP__REDIS__POOL_SIZE -> redis.pool_size (可选,默认 10)
163    /// 查找项目根目录(通过查找 Cargo.toml 或 .env 文件)
164    ///
165    /// 查找策略(按优先级):
166    /// 1. 从可执行文件路径推断项目目录(例如 target/debug/im-server -> im-server/)
167    /// 2. 从可执行文件所在目录向上查找 .env 文件
168    /// 3. 从当前工作目录向上查找 .env 文件
169    fn find_project_root() -> Option<std::path::PathBuf> {
170        // 策略 1: 从可执行文件路径推断项目目录
171        // 例如:/path/to/hula-server/target/debug/im-server -> /path/to/hula-server/im-server/
172        if let Ok(exe_path) = std::env::current_exe() {
173            // 获取可执行文件名(例如 "im-server")
174            if let Some(exe_name) = exe_path.file_stem().and_then(|s| s.to_str()) {
175                // 从可执行文件路径向上查找,直到找到 workspace 根目录或项目根目录
176                if let Some(exe_dir) = exe_path.parent() {
177                    let mut path = exe_dir.to_path_buf();
178                    loop {
179                        // 检查当前目录的父目录是否包含与可执行文件同名的目录
180                        if let Some(parent) = path.parent() {
181                            let project_dir = parent.join(exe_name);
182                            // 如果找到同名目录且包含 .env 文件,这就是项目根目录
183                            if project_dir.join(".env").exists() {
184                                return Some(project_dir);
185                            }
186                            // 如果找到同名目录且包含 Cargo.toml(非 workspace),这也是项目根目录
187                            let cargo_toml = project_dir.join("Cargo.toml");
188                            if cargo_toml.exists() {
189                                if let Ok(content) = std::fs::read_to_string(&cargo_toml) {
190                                    if !content.contains("[workspace]") {
191                                        return Some(project_dir);
192                                    }
193                                }
194                            }
195                        }
196                        // 检查当前目录是否有 .env 文件
197                        if path.join(".env").exists() {
198                            return Some(path);
199                        }
200                        // 向上查找
201                        match path.parent() {
202                            Some(parent) => path = parent.to_path_buf(),
203                            None => break,
204                        }
205                    }
206                }
207            }
208        }
209
210        // 策略 2: 从当前工作目录向上查找 .env 文件
211        if let Ok(mut current_dir) = std::env::current_dir() {
212            loop {
213                if current_dir.join(".env").exists() {
214                    // 检查是否是 workspace 根目录
215                    let cargo_toml = current_dir.join("Cargo.toml");
216                    if cargo_toml.exists() {
217                        if let Ok(content) = std::fs::read_to_string(&cargo_toml) {
218                            if content.contains("[workspace]") {
219                                // 这是 workspace 根目录,但找到了 .env,返回当前目录
220                                return Some(current_dir);
221                            }
222                        }
223                    }
224                    return Some(current_dir);
225                }
226                match current_dir.parent() {
227                    Some(parent) => current_dir = parent.to_path_buf(),
228                    None => break,
229                }
230            }
231        }
232
233        None
234    }
235
236    pub fn from_env() -> Result<Self, config::ConfigError> {
237        // 加载 .env 文件(如果存在)
238        // 优先从项目根目录加载,确保当库被其他项目使用时能正确找到 .env 文件
239
240        // 方法 1: 使用 CARGO_MANIFEST_DIR 环境变量(编译时设置,最可靠)
241        if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
242            let env_path = std::path::Path::new(&manifest_dir).join(".env");
243            if env_path.exists() {
244                if dotenvy::from_path(&env_path).is_ok() {
245                    eprintln!(
246                        "✓ 从 CARGO_MANIFEST_DIR 加载 .env 文件: {}",
247                        env_path.display()
248                    );
249                    return Self::load_config_from_env();
250                }
251            }
252        }
253
254        // 方法 2: 从项目根目录加载(通过查找逻辑)
255        if let Some(project_root) = Self::find_project_root() {
256            let env_path = project_root.join(".env");
257            if env_path.exists() {
258                if dotenvy::from_path(&env_path).is_ok() {
259                    eprintln!("✓ 从项目根目录加载 .env 文件: {}", env_path.display());
260                    return Self::load_config_from_env();
261                }
262            }
263        }
264
265        // 方法 3: 使用 dotenvy::dotenv() 从当前工作目录向上查找(备用方案)
266        match dotenvy::dotenv() {
267            Ok(path) => {
268                eprintln!("✓ 从当前工作目录向上查找加载 .env 文件: {}", path.display());
269            }
270            Err(_) => {
271                eprintln!("⚠ 未找到 .env 文件,将使用环境变量和默认配置");
272            }
273        }
274
275        Self::load_config_from_env()
276    }
277
278    /// 从环境变量加载配置(内部方法)
279    fn load_config_from_env() -> Result<Self, config::ConfigError> {
280        // 先手动处理数组类型的配置项,设置默认值
281        let mut default_origins = vec!["*".to_string()];
282        let mut default_methods = vec![
283            "GET".to_string(),
284            "POST".to_string(),
285            "PUT".to_string(),
286            "DELETE".to_string(),
287            "PATCH".to_string(),
288            "OPTIONS".to_string(),
289        ];
290        let mut default_headers = vec!["*".to_string()];
291
292        // 从环境变量读取数组配置
293        if let Ok(origins_str) = std::env::var("APP__CORS__ALLOWED_ORIGINS") {
294            default_origins = origins_str
295                .split(',')
296                .map(|s| s.trim().to_string())
297                .collect();
298        }
299
300        if let Ok(methods_str) = std::env::var("APP__CORS__ALLOWED_METHODS") {
301            default_methods = methods_str
302                .split(',')
303                .map(|s| s.trim().to_string())
304                .collect();
305        }
306
307        if let Ok(headers_str) = std::env::var("APP__CORS__ALLOWED_HEADERS") {
308            default_headers = headers_str
309                .split(',')
310                .map(|s| s.trim().to_string())
311                .collect();
312        }
313
314        // 临时移除这些环境变量,避免 config crate 尝试解析它们
315        let origins_backup = std::env::var("APP__CORS__ALLOWED_ORIGINS").ok();
316        let methods_backup = std::env::var("APP__CORS__ALLOWED_METHODS").ok();
317        let headers_backup = std::env::var("APP__CORS__ALLOWED_HEADERS").ok();
318
319        if origins_backup.is_some() {
320            std::env::remove_var("APP__CORS__ALLOWED_ORIGINS");
321        }
322        if methods_backup.is_some() {
323            std::env::remove_var("APP__CORS__ALLOWED_METHODS");
324        }
325        if headers_backup.is_some() {
326            std::env::remove_var("APP__CORS__ALLOWED_HEADERS");
327        }
328
329        let builder = config::Config::builder()
330            .set_default("server.addr", "127.0.0.1")?
331            .set_default("server.port", 3000)?
332            .set_default("log.level", "info")?
333            .set_default("log.json", false)?
334            .set_default("cors.allowed_origins", default_origins.clone())?
335            .set_default("cors.allowed_methods", default_methods.clone())?
336            .set_default("cors.allowed_headers", default_headers.clone())?
337            .set_default("cors.allow_credentials", false)?
338            .add_source(config::Environment::with_prefix("APP").separator("__"));
339
340        let config = builder.build()?;
341        let result: Config = config.try_deserialize()?;
342
343        Ok(result)
344    }
345}