use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
/// Top-level application configuration.
///
/// `server`, `log` and `cors` are required when deserializing (no `#[serde(default)]`).
/// The backend sections are compiled in only when the matching Cargo feature is enabled and
/// deserialize to `None` when the section is absent (`#[serde(default)]`).
///
/// NOTE(review): `Config::from_env` registers defaults for `nacos.*` and `kafka.brokers`, so
/// when loaded that way those sections presumably come back `Some` even if no related
/// variable is set — confirm this is intended.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
/// Server bind address and worker settings.
pub server: ServerConfig,
/// Logging settings.
pub log: LogConfig,
/// CORS policy settings.
pub cors: CorsConfig,
/// Database settings (features `mysql`, `postgres` or `sqlite`).
#[serde(default)]
#[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))]
pub database: Option<DatabaseConfig>,
/// Redis settings (feature `redis`).
#[serde(default)]
#[cfg(feature = "redis")]
pub redis: Option<RedisConfig>,
/// Nacos settings (feature `nacos`).
#[serde(default)]
#[cfg(feature = "nacos")]
pub nacos: Option<NacosConfig>,
/// Kafka settings (feature `kafka`).
#[serde(default)]
#[cfg(feature = "kafka")]
pub kafka: Option<KafkaConfig>,
}
/// Address, port and worker settings for the server.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
/// Host part of the bind address; combined with `port` by `ServerConfig::socket_addr`.
pub addr: String,
/// Port to bind.
pub port: u16,
/// Worker count; presumably `None` lets the runtime choose — confirm against server startup.
pub workers: Option<usize>,
/// Optional path prefix; presumably routes are mounted under it — confirm against the router.
#[serde(default)]
pub context_path: Option<String>,
}
impl ServerConfig {
    /// Builds the socket address to bind from `addr` and `port`.
    ///
    /// `addr` may be a bare IPv4 or IPv6 literal (`"0.0.0.0"`, `"::1"`) or any form accepted by
    /// `SocketAddr`'s `host:port` parser once the port is appended (e.g. bracketed `"[::1]"`).
    /// Surrounding whitespace is ignored. Host names are not resolved.
    ///
    /// # Errors
    /// Returns [`std::net::AddrParseError`] when `addr` is not an IP literal.
    pub fn socket_addr(&self) -> Result<SocketAddr, std::net::AddrParseError> {
        let host = self.addr.trim();
        match host.parse::<std::net::IpAddr>() {
            // Building from the parsed IP keeps bare IPv6 literals working; naive
            // "host:port" concatenation would turn "::1" into the unparsable "::1:3000".
            Ok(ip) => Ok(SocketAddr::new(ip, self.port)),
            // Not a bare IP (e.g. "[::1]" or a scoped "[fe80::1%2]"): keep the original
            // "host:port" parsing so every previously accepted form still works.
            Err(_) => format!("{}:{}", host, self.port).parse(),
        }
    }
}
/// Logging settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogConfig {
/// Log level name (the built-in default is `"info"`).
pub level: String,
/// Presumably selects JSON-formatted output when `true` — confirm against the logger setup.
pub json: bool,
/// Timezone setting; defaults to `8` via `default_log_timezone`.
/// NOTE(review): presumably an offset from UTC in hours — TODO confirm.
#[serde(default = "default_log_timezone")]
pub timezone: i32,
/// File output settings; `None` when the `file` section is absent.
#[serde(default)]
pub file: Option<FileLogConfig>,
}
/// Serde default for `LogConfig::timezone`.
///
/// NOTE(review): presumably a UTC offset in hours (UTC+8) — confirm against the logger setup.
fn default_log_timezone() -> i32 {
    const DEFAULT_TIMEZONE: i32 = 8;
    DEFAULT_TIMEZONE
}
/// Settings for writing logs to files. Every field has a serde default, so an empty section
/// is valid.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FileLogConfig {
/// Directory for log files (default `"./logs"`).
#[serde(default = "default_log_directory")]
pub directory: String,
/// Log file name (default `"app"`).
#[serde(default = "default_log_filename")]
pub filename: String,
/// Output format (default `"plain"`).
#[serde(default = "default_log_format")]
pub format: String,
/// Size limit in megabytes, per the field name (default `100`); presumably per file — confirm.
#[serde(default = "default_log_size_limit_mb")]
pub size_limit_mb: u64,
/// Maximum number of log files (default `10`); presumably older files are removed beyond
/// this count — confirm.
#[serde(default = "default_log_count_limit")]
pub count_limit: u32,
/// Rotation schedule (default `"daily"`).
#[serde(default = "default_log_rotation")]
pub rotation: String,
}
/// Serde default for `FileLogConfig::directory`.
fn default_log_directory() -> String {
    String::from("./logs")
}

/// Serde default for `FileLogConfig::filename`.
fn default_log_filename() -> String {
    String::from("app")
}

/// Serde default for `FileLogConfig::format`.
fn default_log_format() -> String {
    String::from("plain")
}

/// Serde default for `FileLogConfig::size_limit_mb`.
fn default_log_size_limit_mb() -> u64 {
    100_u64
}

/// Serde default for `FileLogConfig::count_limit`.
fn default_log_count_limit() -> u32 {
    10_u32
}

/// Serde default for `FileLogConfig::rotation`.
fn default_log_rotation() -> String {
    String::from("daily")
}
/// Cross-origin resource sharing (CORS) policy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorsConfig {
/// Origins allowed to make cross-origin requests (`"*"` in the built-in defaults).
pub allowed_origins: Vec<String>,
/// HTTP methods allowed for cross-origin requests.
pub allowed_methods: Vec<String>,
/// Request headers allowed for cross-origin requests (`"*"` in the built-in defaults).
pub allowed_headers: Vec<String>,
/// Whether credentials are allowed (`false` in the built-in defaults).
pub allow_credentials: bool,
}
/// Database connection settings; compiled only with the `mysql`, `postgres` or `sqlite` feature.
#[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseConfig {
/// Connection URL.
pub url: String,
/// Maximum number of connections (default `100`).
#[serde(default = "default_max_connections")]
pub max_connections: u32,
/// Minimum number of connections (default `10`).
#[serde(default = "default_min_connections")]
pub min_connections: u32,
}
/// Serde default for `DatabaseConfig::max_connections`.
#[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))]
fn default_max_connections() -> u32 {
    100_u32
}

/// Serde default for `DatabaseConfig::min_connections`.
#[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))]
fn default_min_connections() -> u32 {
    10_u32
}
/// Redis connection settings; compiled only with the `redis` feature.
#[cfg(feature = "redis")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RedisConfig {
/// Connection URL.
pub url: String,
/// Optional password; `None` when absent.
#[serde(default)]
pub password: Option<String>,
/// Pool size (default `10`).
#[serde(default = "default_pool_size")]
pub pool_size: usize,
}
/// Serde default for `RedisConfig::pool_size`.
#[cfg(feature = "redis")]
fn default_pool_size() -> usize {
    10_usize
}
/// Nacos client settings; compiled only with the `nacos` feature.
///
/// Naming and config operations can use different namespaces and groups. The `effective_*`
/// methods resolve each per-purpose override against the shared `namespace` / `group_name`.
#[cfg(feature = "nacos")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NacosConfig {
/// Nacos server addresses (default `["127.0.0.1:8848"]`).
#[serde(default = "default_nacos_server_addrs")]
pub server_addrs: Vec<String>,
/// Shared namespace, used when `naming_namespace` / `config_namespace` are not set.
pub namespace: Option<String>,
/// Namespace override for naming operations.
#[serde(default)]
pub naming_namespace: Option<String>,
/// Namespace override for config operations.
#[serde(default)]
pub config_namespace: Option<String>,
/// Username (default `Some("nacos")`).
/// NOTE(review): because the default is `Some`, omitting this field does not yield `None`.
#[serde(default = "default_nacos_username")]
pub username: Option<String>,
/// Password (default `Some("nacos")`); same caveat as `username`.
#[serde(default = "default_nacos_password")]
pub password: Option<String>,
/// Service name (default: empty string).
#[serde(default)]
pub service_name: String,
/// Shared group (default `"DEFAULT_GROUP"`), used when `naming_group` / `config_group` are
/// not set.
#[serde(default = "default_nacos_group")]
pub group_name: String,
/// Group override for naming operations.
#[serde(default)]
pub naming_group: Option<String>,
/// Group override for config operations.
#[serde(default)]
pub config_group: Option<String>,
/// Optional service IP; presumably the address registered with Nacos — confirm.
#[serde(default)]
pub service_ip: Option<String>,
/// Optional service port; presumably the registered port — confirm.
/// NOTE(review): typed `u32`, so values above 65535 are accepted here.
#[serde(default)]
pub service_port: Option<u32>,
/// Health-check path (default `Some("/health")`).
#[serde(default = "default_nacos_health_check_path")]
pub health_check_path: Option<String>,
/// Optional string key/value metadata.
#[serde(default)]
pub metadata: Option<std::collections::HashMap<String, String>>,
/// Services to subscribe to (default empty). Deserialized through
/// `deserialize_string_or_vec`, which presumably accepts a single string or a list — see
/// that helper.
#[serde(
default,
deserialize_with = "crate::utils::serde_helpers::deserialize_string_or_vec"
)]
pub subscribe_services: Vec<String>,
/// Config entries to subscribe to (default empty).
#[serde(default)]
pub subscribe_configs: Vec<NacosConfigItem>,
}
#[cfg(feature = "nacos")]
impl NacosConfig {
    /// Namespace for naming operations: `naming_namespace` if set, otherwise `namespace`.
    pub fn effective_naming_namespace(&self) -> Option<&String> {
        match self.naming_namespace {
            Some(ref ns) => Some(ns),
            None => self.namespace.as_ref(),
        }
    }

    /// Namespace for config operations: `config_namespace` if set, otherwise `namespace`.
    pub fn effective_config_namespace(&self) -> Option<&String> {
        if let Some(ns) = &self.config_namespace {
            Some(ns)
        } else {
            self.namespace.as_ref()
        }
    }

    /// Group for naming operations: `naming_group` if set, otherwise `group_name`.
    pub fn effective_naming_group(&self) -> &str {
        match &self.naming_group {
            Some(group) => group.as_str(),
            None => self.group_name.as_str(),
        }
    }

    /// Group for config operations: `config_group` if set, otherwise `group_name`.
    pub fn effective_config_group(&self) -> &str {
        match &self.config_group {
            Some(group) => group.as_str(),
            None => self.group_name.as_str(),
        }
    }

    /// Returns `true` when naming and config resolve to different namespaces. This includes
    /// the case where exactly one of them resolves to `None`.
    pub fn is_namespace_separated(&self) -> bool {
        let (naming, config) = (
            self.effective_naming_namespace(),
            self.effective_config_namespace(),
        );
        naming != config
    }
}
/// One entry of `NacosConfig::subscribe_configs`; compiled only with the `nacos` feature.
#[cfg(feature = "nacos")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NacosConfigItem {
/// Config data id.
pub data_id: String,
/// Group (default `"DEFAULT_GROUP"`).
#[serde(default = "default_nacos_group")]
pub group: String,
/// Namespace (default `"public"`).
#[serde(default = "default_nacos_namespace")]
pub namespace: String,
}
/// Default Nacos group name.
#[cfg(feature = "nacos")]
fn default_nacos_group() -> String {
    String::from("DEFAULT_GROUP")
}

/// Default Nacos server list: a single server on the local host.
#[cfg(feature = "nacos")]
fn default_nacos_server_addrs() -> Vec<String> {
    let local = String::from("127.0.0.1:8848");
    vec![local]
}

/// Default value for `NacosConfig::health_check_path`.
#[cfg(feature = "nacos")]
fn default_nacos_health_check_path() -> Option<String> {
    Some(String::from("/health"))
}

/// Default Nacos namespace name.
#[cfg(feature = "nacos")]
fn default_nacos_namespace() -> String {
    String::from("public")
}

/// Default value for `NacosConfig::username`.
#[cfg(feature = "nacos")]
fn default_nacos_username() -> Option<String> {
    Some(String::from("nacos"))
}

/// Default value for `NacosConfig::password`.
#[cfg(feature = "nacos")]
fn default_nacos_password() -> Option<String> {
    Some(String::from("nacos"))
}
/// Kafka settings; compiled only with the `kafka` feature.
#[cfg(feature = "kafka")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KafkaConfig {
/// Broker address string; presumably a comma-separated `host:port` list — confirm.
pub brokers: String,
/// Producer settings; `None` when the section is absent.
#[serde(default)]
pub producer: Option<KafkaProducerConfig>,
/// Consumer settings; `None` when the section is absent.
#[serde(default)]
pub consumer: Option<KafkaConsumerConfig>,
}
/// Kafka producer settings; compiled only with the `kafka` feature.
#[cfg(feature = "kafka")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KafkaProducerConfig {
/// Retry count (default `3`).
#[serde(default = "default_producer_retries")]
pub retries: i32,
/// Idempotent-producer switch (default `true`).
#[serde(default = "default_producer_idempotence")]
pub enable_idempotence: bool,
/// Acknowledgement mode (default `"all"`).
#[serde(default = "default_producer_acks")]
pub acks: String,
}
/// Kafka consumer settings; compiled only with the `kafka` feature.
#[cfg(feature = "kafka")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KafkaConsumerConfig {
/// Auto-commit switch (default `false`); presumably maps to Kafka's `enable.auto.commit`.
#[serde(default = "default_consumer_auto_commit")]
pub enable_auto_commit: bool,
}
/// Serde default for `KafkaProducerConfig::retries`.
#[cfg(feature = "kafka")]
fn default_producer_retries() -> i32 {
    3_i32
}

/// Serde default for `KafkaProducerConfig::enable_idempotence`.
#[cfg(feature = "kafka")]
fn default_producer_idempotence() -> bool {
    !false
}

/// Serde default for `KafkaProducerConfig::acks`.
#[cfg(feature = "kafka")]
fn default_producer_acks() -> String {
    String::from("all")
}

/// Serde default for `KafkaConsumerConfig::enable_auto_commit`.
#[cfg(feature = "kafka")]
fn default_consumer_auto_commit() -> bool {
    !true
}
impl Default for Config {
fn default() -> Self {
Self {
server: ServerConfig {
addr: "127.0.0.1".to_string(),
port: 3000,
workers: None,
context_path: None,
},
log: LogConfig {
level: "info".to_string(),
json: false,
timezone: 8,
file: None,
},
cors: CorsConfig {
allowed_origins: vec!["*".to_string()],
allowed_methods: vec![
"GET".to_string(),
"POST".to_string(),
"PUT".to_string(),
"DELETE".to_string(),
"PATCH".to_string(),
"OPTIONS".to_string(),
],
allowed_headers: vec!["*".to_string()],
allow_credentials: false,
},
#[cfg(any(feature = "mysql", feature = "postgres", feature = "sqlite"))]
database: None,
#[cfg(feature = "redis")]
redis: None,
#[cfg(feature = "nacos")]
nacos: None,
#[cfg(feature = "kafka")]
kafka: None,
}
}
}
impl Config {
    /// Returns the address reported by `local_ip_address::local_ip()` as a string, but only
    /// when it is a non-loopback IPv4 address; lookup failures also yield `None`.
    fn get_local_ip() -> Option<String> {
        local_ip_address::local_ip()
            .ok()
            .filter(|ip| ip.is_ipv4() && !ip.is_loopback())
            .map(|ip| ip.to_string())
    }

    /// Locates the directory whose `.env` file should be loaded.
    ///
    /// Search order:
    /// 1. relative to the running executable (see [`Config::find_project_root_from_exe`]);
    /// 2. otherwise, the nearest ancestor of the current working directory (including the
    ///    working directory itself) that contains a `.env` file.
    ///
    /// Returns `None` when neither search matches.
    fn find_project_root() -> Option<std::path::PathBuf> {
        if let Some(root) = Self::find_project_root_from_exe() {
            return Some(root);
        }
        let cwd = std::env::current_dir().ok()?;
        cwd.ancestors()
            .find(|dir| dir.join(".env").exists())
            .map(std::path::Path::to_path_buf)
    }

    /// Walks up from the executable's directory. For each ancestor `dir`, it first checks
    /// `dir/../<exe_stem>` and accepts it if it contains a `.env` file or is a crate whose
    /// `Cargo.toml` has no `[workspace]` section. It then checks whether `dir` itself
    /// contains a `.env` file.
    ///
    /// Returns `None` if the executable path or its stem cannot be determined, or nothing
    /// matches.
    fn find_project_root_from_exe() -> Option<std::path::PathBuf> {
        let exe_path = std::env::current_exe().ok()?;
        let exe_name = exe_path.file_stem()?.to_str()?;
        let exe_dir = exe_path.parent()?;
        for dir in exe_dir.ancestors() {
            if let Some(parent) = dir.parent() {
                // A directory named after the binary next to `dir`; presumably the binary's
                // crate inside a workspace when running from `target/<profile>`.
                let project_dir = parent.join(exe_name);
                if project_dir.join(".env").exists() || Self::is_non_workspace_crate(&project_dir)
                {
                    return Some(project_dir);
                }
            }
            if dir.join(".env").exists() {
                return Some(dir.to_path_buf());
            }
        }
        None
    }

    /// Returns `true` when `dir/Cargo.toml` can be read and does not contain `[workspace]`.
    fn is_non_workspace_crate(dir: &std::path::Path) -> bool {
        std::fs::read_to_string(dir.join("Cargo.toml"))
            .map(|manifest| !manifest.contains("[workspace]"))
            .unwrap_or(false)
    }

    /// Splits a comma-separated value into trimmed items, dropping empty ones, so a trailing
    /// comma does not produce an empty entry.
    fn split_comma_list(raw: &str) -> Vec<String> {
        raw.split(',')
            .map(str::trim)
            .filter(|item| !item.is_empty())
            .map(String::from)
            .collect()
    }

    /// Loads a `.env` file (if any) into the process environment, then builds the
    /// configuration with [`Config::load_config_from_env`].
    ///
    /// `.env` lookup order, stopping at the first file that loads successfully:
    /// 1. `$CARGO_MANIFEST_DIR/.env`, when that variable is set at runtime;
    /// 2. `.env` in the directory returned by [`Config::find_project_root`];
    /// 3. `dotenvy::dotenv()`'s own lookup.
    ///
    /// Not finding a `.env` file is not an error. Progress is reported on stderr.
    ///
    /// # Errors
    /// Propagates any [`config::ConfigError`] from [`Config::load_config_from_env`].
    pub fn from_env() -> Result<Self, config::ConfigError> {
        if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
            let env_path = std::path::Path::new(&manifest_dir).join(".env");
            if env_path.exists() && dotenvy::from_path(&env_path).is_ok() {
                eprintln!(
                    "✓ 从 CARGO_MANIFEST_DIR 加载 .env 文件: {}",
                    env_path.display()
                );
                return Self::load_config_from_env();
            }
        }
        if let Some(project_root) = Self::find_project_root() {
            let env_path = project_root.join(".env");
            if env_path.exists() && dotenvy::from_path(&env_path).is_ok() {
                eprintln!("✓ 从项目根目录加载 .env 文件: {}", env_path.display());
                return Self::load_config_from_env();
            }
        }
        match dotenvy::dotenv() {
            Ok(path) => {
                eprintln!("✓ 从当前工作目录向上查找加载 .env 文件: {}", path.display());
            }
            Err(_) => {
                eprintln!("⚠ 未找到 .env 文件,将使用环境变量和默认配置");
            }
        }
        Self::load_config_from_env()
    }

    /// Builds the configuration from built-in defaults overlaid with `APP__*` environment
    /// variables (`__` separates nesting levels, e.g. `APP__SERVER__PORT` -> `server.port`).
    ///
    /// The list-valued variables below are split on commas and used as the defaults for their
    /// keys:
    /// - `APP__CORS__ALLOWED_ORIGINS`
    /// - `APP__CORS__ALLOWED_METHODS`
    /// - `APP__CORS__ALLOWED_HEADERS`
    /// - `APP__NACOS__SERVER_ADDRS` (with the `nacos` feature)
    ///
    /// These variables are hidden from the process environment while the `config::Environment`
    /// source is collected. That source is configured without a list separator, so it would
    /// presumably yield each one as a single string. The original values are restored before
    /// this function returns, on success, on error, and during a panic unwind.
    ///
    /// When `APP__SERVER__ADDR` is unset, the default bind address is [`Config::get_local_ip`],
    /// falling back to `127.0.0.1`.
    ///
    /// NOTE(review): mutating the process environment is not thread-safe on every platform
    /// (`set_var` / `remove_var` are `unsafe` in Rust 2024). Call this before spawning threads.
    ///
    /// # Errors
    /// Returns a [`config::ConfigError`] if a default cannot be registered, the sources cannot
    /// be merged, or the merged values do not deserialize into [`Config`].
    fn load_config_from_env() -> Result<Self, config::ConfigError> {
        /// Removes an environment variable for its lifetime and puts the original value back on
        /// drop. This keeps the variable from being lost when a later `?` returns early.
        struct EnvVarGuard {
            key: &'static str,
            saved: Option<String>,
        }

        impl EnvVarGuard {
            fn take(key: &'static str) -> Self {
                let saved = std::env::var(key).ok();
                if saved.is_some() {
                    std::env::remove_var(key);
                }
                Self { key, saved }
            }

            fn value(&self) -> Option<&str> {
                self.saved.as_deref()
            }
        }

        impl Drop for EnvVarGuard {
            fn drop(&mut self) {
                if let Some(value) = self.saved.take() {
                    std::env::set_var(self.key, value);
                }
            }
        }

        // The guards live until the end of this function, so every exit path restores the vars.
        let origins_env = EnvVarGuard::take("APP__CORS__ALLOWED_ORIGINS");
        let methods_env = EnvVarGuard::take("APP__CORS__ALLOWED_METHODS");
        let headers_env = EnvVarGuard::take("APP__CORS__ALLOWED_HEADERS");
        #[cfg(feature = "nacos")]
        let nacos_addrs_env = EnvVarGuard::take("APP__NACOS__SERVER_ADDRS");

        let allowed_origins = origins_env
            .value()
            .map_or_else(|| vec!["*".to_string()], Self::split_comma_list);
        let allowed_methods = methods_env.value().map_or_else(
            || {
                ["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]
                    .iter()
                    .map(|method| method.to_string())
                    .collect()
            },
            Self::split_comma_list,
        );
        let allowed_headers = headers_env
            .value()
            .map_or_else(|| vec!["*".to_string()], Self::split_comma_list);
        #[cfg(feature = "nacos")]
        let nacos_addrs = nacos_addrs_env
            .value()
            .map_or_else(default_nacos_server_addrs, Self::split_comma_list);

        let default_server_addr = if std::env::var("APP__SERVER__ADDR").is_ok() {
            // Placeholder only: the environment source below supplies the real value.
            "127.0.0.1".to_string()
        } else if let Some(ip) = Self::get_local_ip() {
            eprintln!("✓ 自动获取本机 IP 地址: {}", ip);
            ip
        } else {
            eprintln!("⚠ 无法获取本机 IP 地址,将使用 127.0.0.1");
            "127.0.0.1".to_string()
        };

        let builder = config::Config::builder()
            .set_default("server.addr", default_server_addr.as_str())?
            .set_default("server.port", 3000)?
            .set_default("log.level", "info")?
            .set_default("log.json", false)?
            .set_default("log.timezone", 8)?
            .set_default("log.file.directory", "./logs")?
            .set_default("log.file.filename", "app")?
            .set_default("log.file.format", "plain")?
            .set_default("log.file.size_limit_mb", 100u64)?
            .set_default("log.file.count_limit", 10u32)?
            .set_default("log.file.rotation", "daily")?
            .set_default("cors.allowed_origins", allowed_origins)?
            .set_default("cors.allowed_methods", allowed_methods)?
            .set_default("cors.allowed_headers", allowed_headers)?
            .set_default("cors.allow_credentials", false)?;

        #[cfg(feature = "nacos")]
        let builder = builder
            .set_default("nacos.server_addrs", nacos_addrs)?
            .set_default("nacos.service_name", String::new())?
            .set_default("nacos.group_name", default_nacos_group())?
            .set_default("nacos.namespace", default_nacos_namespace())?
            .set_default("nacos.username", default_nacos_username())?
            .set_default("nacos.password", default_nacos_password())?
            .set_default("nacos.health_check_path", default_nacos_health_check_path())?;

        #[cfg(feature = "kafka")]
        let builder = builder.set_default("kafka.brokers", "localhost:9092")?;

        // The producer/consumer default fns are compiled only with `kafka`, so require it here.
        #[cfg(all(feature = "kafka", feature = "producer"))]
        let builder = builder
            .set_default("kafka.producer.retries", default_producer_retries())?
            .set_default(
                "kafka.producer.enable_idempotence",
                default_producer_idempotence(),
            )?
            .set_default("kafka.producer.acks", default_producer_acks())?;

        #[cfg(all(feature = "kafka", feature = "consumer"))]
        let builder = builder.set_default(
            "kafka.consumer.enable_auto_commit",
            default_consumer_auto_commit(),
        )?;

        builder
            .add_source(config::Environment::with_prefix("APP").separator("__"))
            .build()?
            .try_deserialize()
    }
}