fbc_starter/
config.rs

1use serde::{Deserialize, Serialize};
2use std::net::SocketAddr;
3
4/// 应用配置
5#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct Config {
7    /// 服务器配置
8    pub server: ServerConfig,
9    /// 日志配置
10    pub log: LogConfig,
11    /// CORS 配置
12    pub cors: CorsConfig,
13    /// 数据库配置(可选,需要启用 database 特性)
14    #[serde(default)]
15    #[cfg(feature = "database")]
16    pub database: Option<DatabaseConfig>,
17    /// Redis 配置(可选,需要启用 redis 特性)
18    #[serde(default)]
19    #[cfg(feature = "redis")]
20    pub redis: Option<RedisConfig>,
21}
22
23/// 服务器配置
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct ServerConfig {
26    /// 监听地址
27    pub addr: String,
28    /// 端口
29    pub port: u16,
30    /// 工作线程数(0 表示使用默认值)
31    pub workers: Option<usize>,
32    /// 上下文路径(可选),例如 "/api",如果不配置则为空
33    #[serde(default)]
34    pub context_path: Option<String>,
35}
36
37impl ServerConfig {
38    /// 获取完整的 SocketAddr
39    pub fn socket_addr(&self) -> Result<SocketAddr, std::net::AddrParseError> {
40        format!("{}:{}", self.addr, self.port).parse()
41    }
42}
43
44/// 日志配置
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct LogConfig {
47    /// 日志级别 (trace, debug, info, warn, error)
48    pub level: String,
49    /// 是否使用 JSON 格式
50    pub json: bool,
51}
52
53/// CORS 配置
54#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct CorsConfig {
56    /// 允许的源(* 表示允许所有)
57    pub allowed_origins: Vec<String>,
58    /// 允许的方法
59    pub allowed_methods: Vec<String>,
60    /// 允许的请求头
61    pub allowed_headers: Vec<String>,
62    /// 是否允许凭证
63    pub allow_credentials: bool,
64}
65
66/// 数据库配置(需要启用 database 特性)
67#[cfg(feature = "database")]
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct DatabaseConfig {
70    /// 数据库 URL(例如:postgres://user:password@localhost/dbname)
71    pub url: String,
72    /// 最大连接数
73    #[serde(default = "default_max_connections")]
74    pub max_connections: u32,
75    /// 最小连接数
76    #[serde(default = "default_min_connections")]
77    pub min_connections: u32,
78}
79
80#[cfg(feature = "database")]
81fn default_max_connections() -> u32 {
82    100
83}
84
85#[cfg(feature = "database")]
86fn default_min_connections() -> u32 {
87    10
88}
89
90/// Redis 配置(需要启用 redis 特性)
91#[cfg(feature = "redis")]
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct RedisConfig {
94    /// Redis URL(例如:redis://127.0.0.1:6379 或 redis://:password@127.0.0.1:6379)
95    pub url: String,
96    /// Redis 密码(可选,如果 URL 中已包含密码则不需要)
97    #[serde(default)]
98    pub password: Option<String>,
99    /// 连接池大小
100    #[serde(default = "default_pool_size")]
101    pub pool_size: usize,
102}
103
104#[cfg(feature = "redis")]
105fn default_pool_size() -> usize {
106    10
107}
108
109impl Default for Config {
110    fn default() -> Self {
111        Self {
112            server: ServerConfig {
113                addr: "0.0.0.0".to_string(),
114                port: 3000,
115                workers: None,
116                context_path: None,
117            },
118            log: LogConfig {
119                level: "info".to_string(),
120                json: false,
121            },
122            cors: CorsConfig {
123                allowed_origins: vec!["*".to_string()],
124                allowed_methods: vec![
125                    "GET".to_string(),
126                    "POST".to_string(),
127                    "PUT".to_string(),
128                    "DELETE".to_string(),
129                    "PATCH".to_string(),
130                    "OPTIONS".to_string(),
131                ],
132                allowed_headers: vec!["*".to_string()],
133                // 注意:当 allowed_origins 或 allowed_headers 为 * 时,allow_credentials 会自动设置为 false
134                allow_credentials: false,
135            },
136            #[cfg(feature = "database")]
137            database: None,
138            #[cfg(feature = "redis")]
139            redis: None,
140        }
141    }
142}
143
144impl Config {
145    /// 从 .env 文件和环境变量加载配置
146    ///
147    /// 配置项命名规则:
148    /// - APP__SERVER__ADDR -> server.addr
149    /// - APP__SERVER__PORT -> server.port
150    /// - APP__SERVER__CONTEXT_PATH -> server.context_path (可选,例如 "/api")
151    /// - APP__LOG__LEVEL -> log.level
152    /// - APP__LOG__JSON -> log.json
153    /// - APP__CORS__ALLOWED_ORIGINS -> cors.allowed_origins (逗号分隔)
154    /// - APP__CORS__ALLOWED_METHODS -> cors.allowed_methods (逗号分隔)
155    /// - APP__CORS__ALLOWED_HEADERS -> cors.allowed_headers (逗号分隔)
156    /// - APP__CORS__ALLOW_CREDENTIALS -> cors.allow_credentials
157    /// - APP__DATABASE__URL -> database.url (可选,需要启用 database 特性)
158    /// - APP__DATABASE__MAX_CONNECTIONS -> database.max_connections (可选,默认 100)
159    /// - APP__DATABASE__MIN_CONNECTIONS -> database.min_connections (可选,默认 10)
160    /// - APP__REDIS__URL -> redis.url (可选,需要启用 redis 特性)
161    /// - APP__REDIS__PASSWORD -> redis.password (可选,如果 URL 中已包含密码则不需要)
162    /// - APP__REDIS__POOL_SIZE -> redis.pool_size (可选,默认 10)
163    /// 查找项目根目录(通过查找 Cargo.toml 或 .env 文件)
164    fn find_project_root() -> Option<std::path::PathBuf> {
165        // 首先尝试从当前工作目录向上查找
166        if let Ok(mut current_dir) = std::env::current_dir() {
167            loop {
168                // 检查是否存在 Cargo.toml(项目根目录标识)
169                if current_dir.join("Cargo.toml").exists() {
170                    return Some(current_dir);
171                }
172                // 检查是否存在 .env 文件
173                if current_dir.join(".env").exists() {
174                    return Some(current_dir);
175                }
176                // 向上查找父目录
177                match current_dir.parent() {
178                    Some(parent) => current_dir = parent.to_path_buf(),
179                    None => break,
180                }
181            }
182        }
183
184        // 如果从当前工作目录找不到,尝试从可执行文件所在目录向上查找
185        if let Ok(exe_path) = std::env::current_exe() {
186            if let Some(exe_dir) = exe_path.parent() {
187                let mut current_dir = exe_dir.to_path_buf();
188                loop {
189                    // 检查是否存在 Cargo.toml
190                    if current_dir.join("Cargo.toml").exists() {
191                        return Some(current_dir);
192                    }
193                    // 检查是否存在 .env 文件
194                    if current_dir.join(".env").exists() {
195                        return Some(current_dir);
196                    }
197                    // 向上查找父目录
198                    match current_dir.parent() {
199                        Some(parent) => current_dir = parent.to_path_buf(),
200                        None => break,
201                    }
202                }
203            }
204        }
205
206        None
207    }
208
209    pub fn from_env() -> Result<Self, config::ConfigError> {
210        // 加载 .env 文件(如果存在)
211        // 优先从项目根目录加载,确保当库被其他项目使用时能正确找到 .env 文件
212
213        // 方法 1: 从项目根目录加载(最可靠)
214        if let Some(project_root) = Self::find_project_root() {
215            let env_path = project_root.join(".env");
216            if env_path.exists() {
217                if let Err(e) = dotenvy::from_path(&env_path) {
218                    tracing::debug!("从项目根目录加载 .env 文件失败: {},尝试其他方法", e);
219                } else {
220                    tracing::debug!("成功从项目根目录加载 .env 文件: {}", env_path.display());
221                    // 成功加载,继续后续配置
222                    return Self::load_config_from_env();
223                }
224            }
225        }
226
227        // 方法 2: 使用 dotenvy::dotenv() 从当前工作目录向上查找(备用方案)
228        if let Err(e) = dotenvy::dotenv() {
229            tracing::debug!("未找到 .env 文件: {},将使用环境变量和默认配置", e);
230        } else {
231            tracing::debug!("成功加载 .env 文件(从当前工作目录向上查找)");
232        }
233
234        Self::load_config_from_env()
235    }
236
237    /// 从环境变量加载配置(内部方法)
238    fn load_config_from_env() -> Result<Self, config::ConfigError> {
239        let builder = config::Config::builder()
240            .set_default("server.addr", "127.0.0.1")?
241            .set_default("server.port", 3000)?
242            .set_default("log.level", "info")?
243            .set_default("log.json", false)?
244            .set_default("cors.allowed_origins", vec!["*"])?
245            .set_default(
246                "cors.allowed_methods",
247                vec!["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"],
248            )?
249            .set_default("cors.allowed_headers", vec!["*"])?
250            // 注意:当 allowed_origins 或 allowed_headers 为 * 时,allow_credentials 会自动设置为 false
251            .set_default("cors.allow_credentials", false)?
252            // 从环境变量加载配置,使用 APP__ 前缀,__ 作为嵌套分隔符
253            .add_source(config::Environment::with_prefix("APP").separator("__"));
254
255        builder.build()?.try_deserialize()
256    }
257}