fbc_starter/
config.rs

1use serde::{Deserialize, Serialize};
2use std::net::SocketAddr;
3
4/// 应用配置
5#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct Config {
7    /// 服务器配置
8    pub server: ServerConfig,
9    /// 日志配置
10    pub log: LogConfig,
11    /// CORS 配置
12    pub cors: CorsConfig,
13    /// 数据库配置(可选,需要启用 mysql/postgres/sqlite 任一特性)
14    #[serde(default)]
15    #[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))]
16    pub database: Option<DatabaseConfig>,
17    /// Redis 配置(可选,需要启用 redis 特性)
18    #[serde(default)]
19    #[cfg(feature = "redis")]
20    pub redis: Option<RedisConfig>,
21    /// Nacos 配置(可选,需要启用 nacos 特性)
22    #[serde(default)]
23    #[cfg(feature = "nacos")]
24    pub nacos: Option<NacosConfig>,
25    /// Kafka 配置(可选,需要启用 kafka 特性)
26    #[serde(default)]
27    #[cfg(feature = "kafka")]
28    pub kafka: Option<KafkaConfig>,
29}
30
31/// 服务器配置
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct ServerConfig {
34    /// 监听地址
35    pub addr: String,
36    /// 端口
37    pub port: u16,
38    /// 工作线程数(0 表示使用默认值)
39    pub workers: Option<usize>,
40    /// 上下文路径(可选),例如 "/api",如果不配置则为空
41    #[serde(default)]
42    pub context_path: Option<String>,
43}
44
45impl ServerConfig {
46    /// 获取完整的 SocketAddr
47    pub fn socket_addr(&self) -> Result<SocketAddr, std::net::AddrParseError> {
48        format!("{}:{}", self.addr, self.port).parse()
49    }
50}
51
52/// 日志配置
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct LogConfig {
55    /// 日志级别 (trace, debug, info, warn, error)
56    pub level: String,
57    /// 是否使用 JSON 格式
58    pub json: bool,
59}
60
61/// CORS 配置
62#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct CorsConfig {
64    /// 允许的源(* 表示允许所有)
65    pub allowed_origins: Vec<String>,
66    /// 允许的方法
67    pub allowed_methods: Vec<String>,
68    /// 允许的请求头
69    pub allowed_headers: Vec<String>,
70    /// 是否允许凭证
71    pub allow_credentials: bool,
72}
73
74/// 数据库配置(需要启用 mysql/postgres/sqlite 任一特性)
75#[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))]
76#[derive(Debug, Clone, Serialize, Deserialize)]
77pub struct DatabaseConfig {
78    /// 数据库 URL(例如:postgres://user:password@localhost/dbname)
79    pub url: String,
80    /// 最大连接数
81    #[serde(default = "default_max_connections")]
82    pub max_connections: u32,
83    /// 最小连接数
84    #[serde(default = "default_min_connections")]
85    pub min_connections: u32,
86}
87
88#[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))]
89fn default_max_connections() -> u32 {
90    100
91}
92
93#[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))]
94fn default_min_connections() -> u32 {
95    10
96}
97
98/// Redis 配置(需要启用 redis 特性)
99#[cfg(feature = "redis")]
100#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct RedisConfig {
102    /// Redis URL(例如:redis://127.0.0.1:6379 或 redis://:password@127.0.0.1:6379)
103    pub url: String,
104    /// Redis 密码(可选,如果 URL 中已包含密码则不需要)
105    #[serde(default)]
106    pub password: Option<String>,
107    /// 连接池大小
108    #[serde(default = "default_pool_size")]
109    pub pool_size: usize,
110}
111
112#[cfg(feature = "redis")]
113fn default_pool_size() -> usize {
114    10
115}
116
117/// Nacos 配置(需要启用 nacos 特性)
118#[cfg(feature = "nacos")]
119#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct NacosConfig {
121    /// Nacos 服务器地址列表(例如:["http://127.0.0.1:8848"])
122    #[serde(default = "default_nacos_server_addrs")]
123    pub server_addrs: Vec<String>,
124    /// 命名空间(可选)
125    pub namespace: Option<String>,
126    /// 用户名(可选,用于认证,默认为 "nacos")
127    #[serde(default = "default_nacos_username")]
128    pub username: Option<String>,
129    /// 密码(可选,用于认证,默认为 "nacos")
130    #[serde(default = "default_nacos_password")]
131    pub password: Option<String>,
132    /// 服务名称(用于服务注册,如果为空则使用环境变量 CARGO_PKG_NAME)
133    #[serde(default)]
134    pub service_name: String,
135    /// 服务组名(可选,默认为 DEFAULT_GROUP)
136    #[serde(default = "default_nacos_group")]
137    pub group_name: String,
138    /// 服务 IP(可选,默认使用服务器配置的地址)
139    #[serde(default)]
140    pub service_ip: Option<String>,
141    /// 服务端口(可选,默认使用服务器配置的端口)
142    #[serde(default)]
143    pub service_port: Option<u32>,
144    /// 健康检查路径(可选,默认为 "/health")
145    #[serde(default = "default_nacos_health_check_path")]
146    pub health_check_path: Option<String>,
147    /// 元数据(可选)
148    #[serde(default)]
149    pub metadata: Option<std::collections::HashMap<String, String>>,
150    /// 订阅的服务列表(可选,用于服务发现)
151    /// 环境变量支持逗号分隔:APP__NACOS__SUBSCRIBE_SERVICES=im-server,user-service
152    #[serde(
153        default,
154        deserialize_with = "crate::utils::serde_helpers::deserialize_string_or_vec"
155    )]
156    pub subscribe_services: Vec<String>,
157    /// 订阅的配置列表(可选,用于配置管理)
158    #[serde(default)]
159    pub subscribe_configs: Vec<NacosConfigItem>,
160}
161
162/// Nacos 配置项
163#[cfg(feature = "nacos")]
164#[derive(Debug, Clone, Serialize, Deserialize)]
165pub struct NacosConfigItem {
166    /// 配置的 Data ID
167    pub data_id: String,
168    /// 配置的 Group(可选,默认为 DEFAULT_GROUP)
169    #[serde(default = "default_nacos_group")]
170    pub group: String,
171    /// 命名空间(可选)
172    #[serde(default = "default_nacos_namespace")]
173    pub namespace: String,
174}
175
176#[cfg(feature = "nacos")]
177fn default_nacos_group() -> String {
178    "DEFAULT_GROUP".to_string()
179}
180
181#[cfg(feature = "nacos")]
182fn default_nacos_server_addrs() -> Vec<String> {
183    vec!["127.0.0.1:8848".to_string()]
184}
185
186#[cfg(feature = "nacos")]
187fn default_nacos_health_check_path() -> Option<String> {
188    Some("/health".to_string())
189}
190
191#[cfg(feature = "nacos")]
192fn default_nacos_namespace() -> String {
193    "public".to_string()
194}
195
196#[cfg(feature = "nacos")]
197fn default_nacos_username() -> Option<String> {
198    Some("nacos".to_string())
199}
200
201#[cfg(feature = "nacos")]
202fn default_nacos_password() -> Option<String> {
203    Some("nacos".to_string())
204}
205
206/// Kafka 配置(需要启用 kafka 特性)
207#[cfg(feature = "kafka")]
208#[derive(Debug, Clone, Serialize, Deserialize)]
209pub struct KafkaConfig {
210    /// Kafka 集群地址(例如:localhost:9092 或 10.0.0.1:9092,10.0.0.2:9092)
211    pub brokers: String,
212    /// 生产者配置(可选,需要启用 producer 特性)
213    #[serde(default)]
214    pub producer: Option<KafkaProducerConfig>,
215    /// 消费者配置(可选,需要启用 consumer 特性)
216    #[serde(default)]
217    pub consumer: Option<KafkaConsumerConfig>,
218}
219
220/// Kafka 生产者配置
221#[cfg(feature = "kafka")]
222#[derive(Debug, Clone, Serialize, Deserialize)]
223pub struct KafkaProducerConfig {
224    /// 生产者重试次数
225    #[serde(default = "default_producer_retries")]
226    pub retries: i32,
227    /// 是否启用幂等性
228    #[serde(default = "default_producer_idempotence")]
229    pub enable_idempotence: bool,
230    /// ACK 模式 (all, 1, 0)
231    #[serde(default = "default_producer_acks")]
232    pub acks: String,
233}
234
235/// Kafka 消费者配置
236#[cfg(feature = "kafka")]
237#[derive(Debug, Clone, Serialize, Deserialize)]
238pub struct KafkaConsumerConfig {
239    /// 是否自动提交偏移量
240    #[serde(default = "default_consumer_auto_commit")]
241    pub enable_auto_commit: bool,
242}
243
244// Kafka 生产者默认值
245#[cfg(feature = "kafka")]
246fn default_producer_retries() -> i32 {
247    3
248}
249
250#[cfg(feature = "kafka")]
251fn default_producer_idempotence() -> bool {
252    true
253}
254
255#[cfg(feature = "kafka")]
256fn default_producer_acks() -> String {
257    "all".to_string()
258}
259
260// Kafka 消费者默认值
261#[cfg(feature = "kafka")]
262fn default_consumer_auto_commit() -> bool {
263    false
264}
265
266impl Default for Config {
267    fn default() -> Self {
268        Self {
269            server: ServerConfig {
270                addr: "127.0.0.1".to_string(),
271                port: 3000,
272                workers: None,
273                context_path: None,
274            },
275            log: LogConfig {
276                level: "info".to_string(),
277                json: false,
278            },
279            cors: CorsConfig {
280                allowed_origins: vec!["*".to_string()],
281                allowed_methods: vec![
282                    "GET".to_string(),
283                    "POST".to_string(),
284                    "PUT".to_string(),
285                    "DELETE".to_string(),
286                    "PATCH".to_string(),
287                    "OPTIONS".to_string(),
288                ],
289                allowed_headers: vec!["*".to_string()],
290                // 注意:当 allowed_origins 或 allowed_headers 为 * 时,allow_credentials 会自动设置为 false
291                allow_credentials: false,
292            },
293            #[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))]
294            database: None,
295            #[cfg(feature = "redis")]
296            redis: None,
297            #[cfg(feature = "nacos")]
298            nacos: None,
299            #[cfg(feature = "kafka")]
300            kafka: None,
301        }
302    }
303}
304
305impl Config {
306    /// 获取本机 IP 地址
307    /// 返回第一个非回环的 IPv4 地址,如果获取失败则返回 None
308    fn get_local_ip() -> Option<String> {
309        match local_ip_address::local_ip() {
310            Ok(ip) => {
311                // 只返回 IPv4 地址,跳过回环地址
312                if ip.is_ipv4() && !ip.is_loopback() {
313                    Some(ip.to_string())
314                } else {
315                    None
316                }
317            }
318            Err(_) => None,
319        }
320    }
321
322    /// 从 .env 文件和环境变量加载配置
323    ///
324    /// 配置项命名规则:
325    /// - APP__SERVER__ADDR -> server.addr (如果不配置,自动获取本机 IP,获取不到则使用 127.0.0.1)
326    /// - APP__SERVER__PORT -> server.port
327    /// - APP__SERVER__CONTEXT_PATH -> server.context_path (可选,例如 "/api")
328    /// - APP__LOG__LEVEL -> log.level
329    /// - APP__LOG__JSON -> log.json
330    /// - APP__CORS__ALLOWED_ORIGINS -> cors.allowed_origins (逗号分隔)
331    /// - APP__CORS__ALLOWED_METHODS -> cors.allowed_methods (逗号分隔)
332    /// - APP__CORS__ALLOWED_HEADERS -> cors.allowed_headers (逗号分隔)
333    /// - APP__CORS__ALLOW_CREDENTIALS -> cors.allow_credentials
334    /// - APP__DATABASE__URL -> database.url (可选,需要启用 database 特性)
335    /// - APP__DATABASE__MAX_CONNECTIONS -> database.max_connections (可选,默认 100)
336    /// - APP__DATABASE__MIN_CONNECTIONS -> database.min_connections (可选,默认 10)
337    /// - APP__REDIS__URL -> redis.url (可选,需要启用 redis 特性)
338    /// - APP__REDIS__PASSWORD -> redis.password (可选,如果 URL 中已包含密码则不需要)
339    /// - APP__REDIS__POOL_SIZE -> redis.pool_size (可选,默认 10)
340    /// 查找项目根目录(通过查找 Cargo.toml 或 .env 文件)
341    ///
342    /// 查找策略(按优先级):
343    /// 1. 从可执行文件路径推断项目目录(例如 target/debug/im-server -> im-server/)
344    /// 2. 从可执行文件所在目录向上查找 .env 文件
345    /// 3. 从当前工作目录向上查找 .env 文件
346    fn find_project_root() -> Option<std::path::PathBuf> {
347        // 策略 1: 从可执行文件路径推断项目目录
348        // 例如:/path/to/hula-server/target/debug/im-server -> /path/to/hula-server/im-server/
349        if let Ok(exe_path) = std::env::current_exe() {
350            // 获取可执行文件名(例如 "im-server")
351            if let Some(exe_name) = exe_path.file_stem().and_then(|s| s.to_str()) {
352                // 从可执行文件路径向上查找,直到找到 workspace 根目录或项目根目录
353                if let Some(exe_dir) = exe_path.parent() {
354                    let mut path = exe_dir.to_path_buf();
355                    loop {
356                        // 检查当前目录的父目录是否包含与可执行文件同名的目录
357                        if let Some(parent) = path.parent() {
358                            let project_dir = parent.join(exe_name);
359                            // 如果找到同名目录且包含 .env 文件,这就是项目根目录
360                            if project_dir.join(".env").exists() {
361                                return Some(project_dir);
362                            }
363                            // 如果找到同名目录且包含 Cargo.toml(非 workspace),这也是项目根目录
364                            let cargo_toml = project_dir.join("Cargo.toml");
365                            if cargo_toml.exists() {
366                                if let Ok(content) = std::fs::read_to_string(&cargo_toml) {
367                                    if !content.contains("[workspace]") {
368                                        return Some(project_dir);
369                                    }
370                                }
371                            }
372                        }
373                        // 检查当前目录是否有 .env 文件
374                        if path.join(".env").exists() {
375                            return Some(path);
376                        }
377                        // 向上查找
378                        match path.parent() {
379                            Some(parent) => path = parent.to_path_buf(),
380                            None => break,
381                        }
382                    }
383                }
384            }
385        }
386
387        // 策略 2: 从当前工作目录向上查找 .env 文件
388        if let Ok(mut current_dir) = std::env::current_dir() {
389            loop {
390                if current_dir.join(".env").exists() {
391                    // 检查是否是 workspace 根目录
392                    let cargo_toml = current_dir.join("Cargo.toml");
393                    if cargo_toml.exists() {
394                        if let Ok(content) = std::fs::read_to_string(&cargo_toml) {
395                            if content.contains("[workspace]") {
396                                // 这是 workspace 根目录,但找到了 .env,返回当前目录
397                                return Some(current_dir);
398                            }
399                        }
400                    }
401                    return Some(current_dir);
402                }
403                match current_dir.parent() {
404                    Some(parent) => current_dir = parent.to_path_buf(),
405                    None => break,
406                }
407            }
408        }
409
410        None
411    }
412
413    pub fn from_env() -> Result<Self, config::ConfigError> {
414        // 加载 .env 文件(如果存在)
415        // 优先从项目根目录加载,确保当库被其他项目使用时能正确找到 .env 文件
416
417        // 方法 1: 使用 CARGO_MANIFEST_DIR 环境变量(编译时设置,最可靠)
418        if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
419            let env_path = std::path::Path::new(&manifest_dir).join(".env");
420            if env_path.exists() {
421                if dotenvy::from_path(&env_path).is_ok() {
422                    eprintln!(
423                        "✓ 从 CARGO_MANIFEST_DIR 加载 .env 文件: {}",
424                        env_path.display()
425                    );
426                    return Self::load_config_from_env();
427                }
428            }
429        }
430
431        // 方法 2: 从项目根目录加载(通过查找逻辑)
432        if let Some(project_root) = Self::find_project_root() {
433            let env_path = project_root.join(".env");
434            if env_path.exists() {
435                if dotenvy::from_path(&env_path).is_ok() {
436                    eprintln!("✓ 从项目根目录加载 .env 文件: {}", env_path.display());
437                    return Self::load_config_from_env();
438                }
439            }
440        }
441
442        // 方法 3: 使用 dotenvy::dotenv() 从当前工作目录向上查找(备用方案)
443        match dotenvy::dotenv() {
444            Ok(path) => {
445                eprintln!("✓ 从当前工作目录向上查找加载 .env 文件: {}", path.display());
446            }
447            Err(_) => {
448                eprintln!("⚠ 未找到 .env 文件,将使用环境变量和默认配置");
449            }
450        }
451
452        Self::load_config_from_env()
453    }
454
455    /// 从环境变量加载配置(内部方法)
456    fn load_config_from_env() -> Result<Self, config::ConfigError> {
457        // 先手动处理数组类型的配置项,设置默认值
458        let mut default_origins = vec!["*".to_string()];
459        let mut default_methods = vec![
460            "GET".to_string(),
461            "POST".to_string(),
462            "PUT".to_string(),
463            "DELETE".to_string(),
464            "PATCH".to_string(),
465            "OPTIONS".to_string(),
466        ];
467        let mut default_headers = vec!["*".to_string()];
468
469        // 从环境变量读取数组配置
470        if let Ok(origins_str) = std::env::var("APP__CORS__ALLOWED_ORIGINS") {
471            default_origins = origins_str
472                .split(',')
473                .map(|s| s.trim().to_string())
474                .collect();
475        }
476
477        if let Ok(methods_str) = std::env::var("APP__CORS__ALLOWED_METHODS") {
478            default_methods = methods_str
479                .split(',')
480                .map(|s| s.trim().to_string())
481                .collect();
482        }
483
484        if let Ok(headers_str) = std::env::var("APP__CORS__ALLOWED_HEADERS") {
485            default_headers = headers_str
486                .split(',')
487                .map(|s| s.trim().to_string())
488                .collect();
489        }
490
491        // 临时移除这些环境变量,避免 config crate 尝试解析它们
492        let origins_backup = std::env::var("APP__CORS__ALLOWED_ORIGINS").ok();
493        let methods_backup = std::env::var("APP__CORS__ALLOWED_METHODS").ok();
494        let headers_backup = std::env::var("APP__CORS__ALLOWED_HEADERS").ok();
495
496        if origins_backup.is_some() {
497            std::env::remove_var("APP__CORS__ALLOWED_ORIGINS");
498        }
499        if methods_backup.is_some() {
500            std::env::remove_var("APP__CORS__ALLOWED_METHODS");
501        }
502        if headers_backup.is_some() {
503            std::env::remove_var("APP__CORS__ALLOWED_HEADERS");
504        }
505
506        // 如果未配置 APP__SERVER__ADDR,则自动获取本机 IP
507        // 注意:如果环境变量 APP__SERVER__ADDR 存在,config crate 会优先使用环境变量的值,set_default 的值不会被使用
508        let default_server_addr = if std::env::var("APP__SERVER__ADDR").is_ok() {
509            // 环境变量已存在,set_default 的值不会被使用,但 API 要求提供一个值
510            // 这里返回任意值都可以,因为不会被使用
511            "127.0.0.1".to_string()
512        } else {
513            // 环境变量不存在,尝试获取本机 IP 作为默认值
514            match Self::get_local_ip() {
515                Some(ip) => {
516                    eprintln!("✓ 自动获取本机 IP 地址: {}", ip);
517                    ip
518                }
519                None => {
520                    eprintln!("⚠ 无法获取本机 IP 地址,将使用 127.0.0.1");
521                    "127.0.0.1".to_string()
522                }
523            }
524        };
525
526        let builder = config::Config::builder()
527            .set_default("server.addr", default_server_addr.as_str())?
528            .set_default("server.port", 3000)?
529            .set_default("log.level", "info")?
530            .set_default("log.json", false)?
531            .set_default("cors.allowed_origins", default_origins.clone())?
532            .set_default("cors.allowed_methods", default_methods.clone())?
533            .set_default("cors.allowed_headers", default_headers.clone())?
534            .set_default("cors.allow_credentials", false)?;
535
536        // Nacos 配置默认值
537        #[cfg(feature = "nacos")]
538        let builder = builder
539            .set_default("nacos.server_addrs", default_nacos_server_addrs())?
540            .set_default("nacos.service_name", String::new())?
541            .set_default("nacos.group_name", default_nacos_group())?
542            .set_default("nacos.namespace", default_nacos_namespace())?
543            .set_default("nacos.username", default_nacos_username())?
544            .set_default("nacos.password", default_nacos_password())?
545            .set_default("nacos.health_check_path", default_nacos_health_check_path())?;
546
547        // Kafka 配置默认值
548        #[cfg(feature = "kafka")]
549        let builder = builder.set_default("kafka.brokers", "localhost:9092")?;
550
551        #[cfg(feature = "producer")]
552        let builder = builder
553            .set_default("kafka.producer.retries", default_producer_retries())?
554            .set_default(
555                "kafka.producer.enable_idempotence",
556                default_producer_idempotence(),
557            )?
558            .set_default("kafka.producer.acks", default_producer_acks())?;
559
560        #[cfg(feature = "consumer")]
561        let builder = builder.set_default(
562            "kafka.consumer.enable_auto_commit",
563            default_consumer_auto_commit(),
564        )?;
565
566        let builder = builder.add_source(config::Environment::with_prefix("APP").separator("__"));
567
568        let config = builder.build()?;
569        let result: Config = config.try_deserialize()?;
570
571        Ok(result)
572    }
573}