use std::path::Path;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use parking_lot::RwLock;
use tracing::info;
pub mod env;
pub use env::{detect_profile, parse_args, merge_env_overrides};
/// Root application configuration.
///
/// Deserialized from TOML by `ConfigManager::load_file`. Every field carries a serde
/// default, so any key or whole section may be omitted from the file. Field order is
/// also the order in which serde serializes the fields (e.g. in `generate_default`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppConfig {
/// Application name; defaults to `"Alun"`.
#[serde(default = "default_app_name")]
pub app_name: String,
/// Profile name; defaults to `"dev"`, but `ConfigManager::load_file` always overwrites
/// it with the profile returned by `detect_profile()`.
#[serde(default = "default_profile")]
pub profile: String,
/// HTTP server settings (`[server]`).
#[serde(default)]
pub server: ServerConfig,
/// Logging settings (`[log]`).
#[serde(default)]
pub log: LogConfig,
/// Relational database settings (`[database]`).
#[serde(default)]
pub database: DatabaseConfig,
/// Redis settings (`[redis]`).
#[serde(default)]
pub redis: RedisConfig,
/// Cache settings (`[cache]`).
#[serde(default)]
pub cache: CacheConfig,
/// Middleware toggles and their settings (`[middleware]`).
#[serde(default)]
pub middleware: MiddlewareConfig,
/// Routing settings (`[router]`).
#[serde(default)]
pub router: RouterConfig,
/// Plugin list and per-plugin settings (`[plugins]`).
#[serde(default)]
pub plugins: PluginsConfig,
/// Upload storage settings (`[upload]`).
#[serde(default)]
pub upload: UploadConfig,
/// Download directory settings (`[download]`).
#[serde(default)]
pub download: DownloadConfig,
/// Template directory settings (`[template]`).
#[serde(default)]
pub template: TemplateConfig,
/// Static-file serving settings (`[static_files]`).
#[serde(default)]
pub static_files: StaticConfig,
/// Free-form extra settings not covered by the typed sections; `merge_configs`
/// merges these key by key.
#[serde(default)]
pub custom: HashMap<String, serde_json::Value>,
}
/// HTTP server settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
/// Listen setting; defaults to `"8023"`. NOTE(review): stored as a string; whether a
/// bare port or a full `host:port` is accepted is decided by the server code (not shown).
#[serde(default = "default_listen")]
pub listen: String,
}
/// Logging settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogConfig {
/// Log level name; defaults to `"info"`.
#[serde(default = "default_log_level")]
pub level: String,
/// Output format name; defaults to `"text"`.
#[serde(default = "default_log_format")]
pub format: String,
/// Optional log directory; `None` by default. Presumably file logging is only set up
/// when this is present — confirm in the logger initialisation.
#[serde(default)]
pub dir: Option<String>,
/// Prefix for log file names; defaults to `"alun"`.
#[serde(default = "default_log_prefix")]
pub file_prefix: String,
}
/// Relational database connection settings (`[database]` section).
///
/// NOTE(review): `enabled` has a serde default of `true` (`default_true`), but
/// `DatabaseConfig::default()` sets it to `false`. As a result:
/// - a `[database]` section without an `enabled` key deserializes as enabled;
/// - an absent section deserializes as disabled;
/// - `generate_default`'s output is disabled.
/// Confirm which default is intended.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseConfig {
/// Whether the database is used; see the NOTE above about its default.
#[serde(default = "default_true")]
pub enabled: bool,
/// Database kind, serialized under the TOML key `type`; defaults to `"postgres"`.
#[serde(default = "default_db_type")]
pub r#type: String,
/// Server host; defaults to `"localhost"`.
#[serde(default = "default_host")]
pub host: String,
/// Optional port; a missing key deserializes to `None` (serde's handling of `Option`).
pub port: Option<u16>,
/// Database name; empty by default.
#[serde(default)]
pub name: String,
/// Login user; empty by default.
#[serde(default)]
pub user: String,
/// Login password; empty by default.
#[serde(default)]
pub password: String,
/// Marks `password` as encrypted. No decryption happens in this module; presumably the
/// connection code handles it — confirm.
#[serde(default)]
pub password_encrypted: bool,
/// Pool upper bound; defaults to 10.
#[serde(default = "default_pool_size")]
pub max_connections: u32,
/// Pool lower bound; defaults to 2.
#[serde(default = "default_min_idle")]
pub min_connections: u32,
/// Connection timeout; defaults to 10. The unit is not stated here — presumably seconds.
#[serde(default = "default_timeout")]
pub connect_timeout: u64,
/// Whether SQL statements are logged; off by default.
#[serde(default)]
pub sql_logging: bool,
/// Slow-query threshold in milliseconds (per the name); 0 by default. TODO confirm how
/// the query layer treats 0.
#[serde(default)]
pub slow_query_ms: u64,
/// Schema migration settings.
#[serde(default)]
pub migration: MigrationConfig,
}
/// Redis connection settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RedisConfig {
/// Whether Redis is used; off by default.
#[serde(default)]
pub enabled: bool,
/// Connection URL; defaults to `"redis://127.0.0.1:6379"`.
#[serde(default = "default_redis_url")]
pub url: String,
/// Pool upper bound; defaults to 10 (shares `default_pool_size` with the database).
#[serde(default = "default_pool_size")]
pub max_connections: u32,
}
/// Application cache settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
/// Cache backend name, serialized under the TOML key `type`; defaults to `"local"`.
#[serde(default = "default_cache_type")]
pub r#type: String,
/// Maximum number of entries (presumably); defaults to 10000.
#[serde(default = "default_cache_capacity")]
pub max_capacity: u64,
/// Default entry TTL; defaults to 3600 (unit not stated here — presumably seconds).
#[serde(default = "default_ttl")]
pub default_ttl: u64,
}
/// Toggles and settings for the HTTP middleware stack.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MiddlewareConfig {
/// Enables request-id middleware; off by default.
#[serde(default)]
pub request_id: bool,
/// Enables request logging; off by default.
#[serde(default)]
pub request_log: bool,
/// Options for the request logger.
#[serde(default)]
pub request_log_config: RequestLogConfig,
/// Authentication (JWT) settings.
#[serde(default)]
pub auth: AuthMiddlewareConfig,
/// CORS settings.
#[serde(default)]
pub cors: CorsConfig,
/// Response compression settings.
#[serde(default)]
pub compression: CompressConfig,
/// Rate limiting settings.
#[serde(default)]
pub rate_limit: RateLimitConfig,
/// Security response header settings.
#[serde(default)]
pub security_headers: SecurityHeadersConfig,
/// Path/method permission rules.
#[serde(default)]
pub permission: PermissionConfig,
}
/// Options for the request logger.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestLogConfig {
/// Paths not to log; empty by default.
#[serde(default)]
pub exclude_paths: Vec<String>,
/// Whether request duration is logged; on by default.
#[serde(default = "default_true")]
pub log_duration: bool,
}
/// Authentication middleware settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthMiddlewareConfig {
/// Enables authentication; off by default.
#[serde(default)]
pub enabled: bool,
/// Paths that bypass authentication; empty by default.
#[serde(default)]
pub ignore_paths: Vec<String>,
/// JWT signing secret; empty by default. NOTE(review): nothing here rejects an empty
/// secret when `enabled` is true — confirm the auth layer validates it.
#[serde(default)]
pub jwt_secret: String,
/// Access token lifetime in seconds (per the name); defaults to 7200.
#[serde(default = "default_access_token_expire")]
pub access_token_expire_secs: u64,
/// Refresh token lifetime in seconds (per the name); defaults to 604800.
#[serde(default = "default_refresh_token_expire")]
pub refresh_token_expire_secs: u64,
}
/// CORS middleware settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorsConfig {
/// Enables CORS handling; off by default.
#[serde(default)]
pub enabled: bool,
/// Allowed origins; empty by default.
#[serde(default)]
pub allow_origins: Vec<String>,
/// Allowed methods; empty by default.
#[serde(default)]
pub allow_methods: Vec<String>,
/// Allowed request headers; empty by default.
#[serde(default)]
pub allow_headers: Vec<String>,
/// Whether credentials are allowed; on by default.
#[serde(default = "default_true")]
pub allow_credentials: bool,
/// Preflight cache lifetime in seconds (per the name); defaults to 86400.
#[serde(default = "default_cors_max_age")]
pub max_age_secs: u64,
}
/// Response compression settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressConfig {
/// Enables compression; off by default.
#[serde(default)]
pub enabled: bool,
/// Compression level; defaults to 6.
#[serde(default = "default_compress_level")]
pub level: u32,
}
/// Rate limiting settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
/// Enables rate limiting; off by default.
#[serde(default)]
pub enabled: bool,
/// Requests allowed per window; defaults to 100.
#[serde(default = "default_rate_limit_requests")]
pub requests_per_window: u64,
/// Window length in seconds (per the name); defaults to 60.
#[serde(default = "default_rate_limit_window")]
pub window_secs: u64,
}
/// Security response header settings.
///
/// Each boolean presumably toggles the correspondingly named response header, and each
/// `*_value` field holds that header's value. The middleware that emits them is not
/// shown here. Everything is on by default except `permissions_policy`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityHeadersConfig {
/// Master switch; on by default.
#[serde(default = "default_true")]
pub enabled: bool,
/// Content-type sniffing protection toggle; on by default.
#[serde(default = "default_true")]
pub nosniff: bool,
/// Frame-options toggle; on by default.
#[serde(default = "default_true")]
pub frame_options: bool,
/// HSTS toggle; on by default.
#[serde(default = "default_true")]
pub hsts: bool,
/// HSTS max-age in seconds (per the name); defaults to 31536000.
#[serde(default = "default_hsts_max_age")]
pub hsts_max_age_secs: u64,
/// Whether HSTS covers subdomains; on by default.
#[serde(default = "default_true")]
pub hsts_include_subdomains: bool,
/// Content-Security-Policy toggle; on by default.
#[serde(default = "default_true")]
pub csp: bool,
/// CSP value; defaults to `default-src 'self'`.
#[serde(default = "default_csp_value")]
pub csp_value: String,
/// Referrer-Policy toggle; on by default.
#[serde(default = "default_true")]
pub referrer_policy: bool,
/// Referrer-Policy value; defaults to `strict-origin-when-cross-origin`.
#[serde(default = "default_referrer_policy_value")]
pub referrer_policy_value: String,
/// Permissions-Policy toggle; off by default.
#[serde(default)]
pub permissions_policy: bool,
/// Permissions-Policy value; defaults to disabling camera, microphone and geolocation.
#[serde(default = "default_permissions_policy_value")]
pub permissions_policy_value: String,
}
/// Permission-check middleware settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PermissionConfig {
/// Enables permission checks; off by default.
#[serde(default)]
pub enabled: bool,
/// Rules; empty by default.
#[serde(default)]
pub rules: Vec<PermissionRule>,
}
/// One permission rule. `path` and `permission` have no serde default, so both keys are
/// required in every rule entry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PermissionRule {
/// Path the rule applies to (matching semantics are defined by the middleware, not shown).
pub path: String,
/// HTTP methods the rule applies to; empty by default. Presumably empty means all
/// methods — confirm.
#[serde(default)]
pub methods: Vec<String>,
/// Permission identifier required by this rule.
pub permission: String,
}
/// Routing settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouterConfig {
/// Route prefix; empty by default.
#[serde(default)]
pub prefix: String,
/// Not-found response settings.
#[serde(default)]
pub not_found: NotFoundConfig,
}
/// Not-found response settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NotFoundConfig {
/// Whether the custom not-found handler is used; on by default.
#[serde(default = "default_true")]
pub enabled: bool,
/// Not-found message; defaults to `default_not_found_msg()`.
#[serde(default = "default_not_found_msg")]
pub message: String,
}
/// Database schema migration settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MigrationConfig {
/// Enables migrations; off by default.
#[serde(default)]
pub enabled: bool,
/// Migration directory; defaults to `"migrations"`.
#[serde(default = "default_migration_path")]
pub path: String,
/// Whether migrations run automatically; off by default.
#[serde(default)]
pub auto_migrate: bool,
}
/// Upload storage settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UploadConfig {
/// Upload directory; defaults to `"uploads"`.
#[serde(default = "default_upload_path")]
pub path: String,
/// Maximum upload size in megabytes (per the name); defaults to 10.
#[serde(default = "default_max_size")]
pub max_size_mb: u64,
}
/// Download directory settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DownloadConfig {
/// Download directory; defaults to `"downloads"`.
#[serde(default = "default_download_path")]
pub path: String,
}
/// Template directory settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemplateConfig {
/// Template directory; defaults to `"templates"`.
#[serde(default = "default_template_path")]
pub path: String,
}
/// Static-file serving settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StaticConfig {
/// Static file directory; defaults to `"static"`.
#[serde(default = "default_static_path")]
pub path: String,
/// Enables static-file serving; off by default.
#[serde(default)]
pub enabled: bool,
}
/// Plugin list and per-plugin settings. Omitted sub-sections fall back to each
/// sub-type's `Default` impl.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PluginsConfig {
/// Names of enabled plugins; empty by default (presumably matched against plugin
/// identifiers elsewhere — confirm).
#[serde(default)]
pub enabled: Vec<String>,
/// Notification (SMTP) plugin settings.
#[serde(default)]
pub notification: NotificationConfig,
/// Async task plugin settings.
#[serde(default)]
pub async_task: AsyncTaskConfig,
/// Scheduler plugin settings.
#[serde(default)]
pub scheduler: SchedulerConfig,
}
/// SMTP settings for the notification plugin.
///
/// `Default` is implemented by hand rather than derived, so that
/// `NotificationConfig::default()` agrees with the serde defaults. A derived impl
/// would give `smtp_port == 0` instead of 587. That value would then be used
/// whenever the `[plugins.notification]` section is omitted, and written out by
/// `generate_default`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NotificationConfig {
    /// Enables the notification plugin; off by default.
    #[serde(default)]
    pub enabled: bool,
    /// SMTP server host; empty by default.
    #[serde(default)]
    pub smtp_host: String,
    /// SMTP server port; defaults to 587.
    #[serde(default = "default_smtp_port")]
    pub smtp_port: u16,
    /// SMTP login user; empty by default.
    #[serde(default)]
    pub smtp_user: String,
    /// SMTP login password; empty by default.
    #[serde(default)]
    pub smtp_pass: String,
    /// Sender address; empty by default.
    #[serde(default)]
    pub from_email: String,
    /// Sender display name; empty by default.
    #[serde(default)]
    pub from_name: String,
}

impl Default for NotificationConfig {
    fn default() -> Self {
        Self {
            enabled: false,
            smtp_host: String::new(),
            smtp_port: default_smtp_port(),
            smtp_user: String::new(),
            smtp_pass: String::new(),
            from_email: String::new(),
            from_name: String::new(),
        }
    }
}

/// Default SMTP port (587).
fn default_smtp_port() -> u16 {
    587
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AsyncTaskConfig {
#[serde(default = "default_workers")]
pub workers: usize,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SchedulerConfig {
#[serde(default = "default_workers")]
pub workers: usize,
}
// Serde default providers, referenced by name from the `#[serde(default = "...")]`
// attributes above and reused by the hand-written `Default` impls.

// --- application / server / logging ---
fn default_app_name() -> String {
    String::from("Alun")
}
fn default_profile() -> String {
    String::from("dev")
}
fn default_listen() -> String {
    String::from("8023")
}
fn default_log_level() -> String {
    String::from("info")
}
fn default_log_format() -> String {
    String::from("text")
}
fn default_log_prefix() -> String {
    String::from("alun")
}

// --- database / pools / workers ---
fn default_db_type() -> String {
    String::from("postgres")
}
fn default_host() -> String {
    String::from("localhost")
}
fn default_true() -> bool {
    true
}
fn default_pool_size() -> u32 {
    10
}
fn default_min_idle() -> u32 {
    2
}
fn default_timeout() -> u64 {
    10
}
fn default_workers() -> usize {
    4
}

// --- redis / cache ---
fn default_redis_url() -> String {
    String::from("redis://127.0.0.1:6379")
}
fn default_cache_type() -> String {
    String::from("local")
}
fn default_cache_capacity() -> u64 {
    10_000
}
fn default_ttl() -> u64 {
    60 * 60
}

// --- middleware ---
fn default_access_token_expire() -> u64 {
    2 * 60 * 60
}
fn default_refresh_token_expire() -> u64 {
    7 * 24 * 60 * 60
}
fn default_cors_max_age() -> u64 {
    24 * 60 * 60
}
fn default_compress_level() -> u32 {
    6
}
fn default_rate_limit_requests() -> u64 {
    100
}
fn default_rate_limit_window() -> u64 {
    60
}
fn default_hsts_max_age() -> u64 {
    365 * 24 * 60 * 60
}
fn default_csp_value() -> String {
    String::from("default-src 'self'")
}
fn default_referrer_policy_value() -> String {
    String::from("strict-origin-when-cross-origin")
}
fn default_permissions_policy_value() -> String {
    String::from("camera=(), microphone=(), geolocation=()")
}

// --- paths / router / upload ---
fn default_migration_path() -> String {
    String::from("migrations")
}
fn default_upload_path() -> String {
    String::from("uploads")
}
fn default_download_path() -> String {
    String::from("downloads")
}
fn default_template_path() -> String {
    String::from("templates")
}
fn default_static_path() -> String {
    String::from("static")
}
fn default_not_found_msg() -> String {
    String::from("请求的资源不存在")
}
fn default_max_size() -> u64 {
    10
}
impl Default for AppConfig {
fn default() -> Self {
Self {
app_name: default_app_name(),
profile: default_profile(),
server: ServerConfig::default(),
log: LogConfig::default(),
database: DatabaseConfig::default(),
redis: RedisConfig::default(),
cache: CacheConfig::default(),
middleware: MiddlewareConfig::default(),
router: RouterConfig::default(),
plugins: PluginsConfig::default(),
upload: UploadConfig::default(),
download: DownloadConfig::default(),
template: TemplateConfig::default(),
static_files: StaticConfig::default(),
custom: HashMap::new(),
}
}
}
impl Default for ServerConfig { fn default() -> Self { Self { listen: default_listen() } } }
impl Default for LogConfig {
fn default() -> Self {
Self {
level: default_log_level(),
format: default_log_format(),
dir: None,
file_prefix: default_log_prefix(),
}
}
}
impl Default for DatabaseConfig {
fn default() -> Self {
Self {
enabled: false, r#type: default_db_type(), host: default_host(),
port: None, name: String::new(), user: String::new(), password: String::new(),
password_encrypted: false,
max_connections: default_pool_size(), min_connections: default_min_idle(),
connect_timeout: default_timeout(), sql_logging: false, slow_query_ms: 0,
migration: MigrationConfig::default(),
}
}
}
impl Default for RedisConfig { fn default() -> Self { Self { enabled: false, url: default_redis_url(), max_connections: default_pool_size() } } }
impl Default for CacheConfig { fn default() -> Self { Self { r#type: default_cache_type(), max_capacity: default_cache_capacity(), default_ttl: default_ttl() } } }
impl Default for MiddlewareConfig {
fn default() -> Self {
Self {
request_id: false, request_log: false,
request_log_config: RequestLogConfig::default(),
auth: AuthMiddlewareConfig::default(),
cors: CorsConfig::default(),
compression: CompressConfig::default(),
rate_limit: RateLimitConfig::default(),
security_headers: SecurityHeadersConfig::default(),
permission: PermissionConfig::default(),
}
}
}
impl Default for RequestLogConfig {
fn default() -> Self { Self { exclude_paths: vec![], log_duration: true } }
}
impl Default for AuthMiddlewareConfig { fn default() -> Self { Self { enabled: false, ignore_paths: vec![], jwt_secret: String::new(), access_token_expire_secs: default_access_token_expire(), refresh_token_expire_secs: default_refresh_token_expire() } } }
impl Default for CorsConfig { fn default() -> Self { Self { enabled: false, allow_origins: vec![], allow_methods: vec![], allow_headers: vec![], allow_credentials: true, max_age_secs: default_cors_max_age() } } }
impl Default for CompressConfig { fn default() -> Self { Self { enabled: false, level: default_compress_level() } } }
impl Default for RateLimitConfig { fn default() -> Self { Self { enabled: false, requests_per_window: default_rate_limit_requests(), window_secs: default_rate_limit_window() } } }
impl Default for SecurityHeadersConfig {
fn default() -> Self {
Self {
enabled: true,
nosniff: true, frame_options: true,
hsts: true, hsts_max_age_secs: default_hsts_max_age(),
hsts_include_subdomains: true,
csp: true, csp_value: default_csp_value(),
referrer_policy: true, referrer_policy_value: default_referrer_policy_value(),
permissions_policy: false, permissions_policy_value: default_permissions_policy_value(),
}
}
}
impl Default for PermissionConfig { fn default() -> Self { Self { enabled: false, rules: vec![] } } }
impl Default for PermissionRule { fn default() -> Self { Self { path: String::new(), methods: vec![], permission: String::new() } } }
impl Default for RouterConfig { fn default() -> Self { Self { prefix: String::new(), not_found: NotFoundConfig::default() } } }
impl Default for NotFoundConfig { fn default() -> Self { Self { enabled: true, message: default_not_found_msg() } } }
impl Default for MigrationConfig { fn default() -> Self { Self { enabled: false, path: default_migration_path(), auto_migrate: false } } }
impl Default for UploadConfig { fn default() -> Self { Self { path: default_upload_path(), max_size_mb: default_max_size() } } }
impl Default for DownloadConfig { fn default() -> Self { Self { path: default_download_path() } } }
impl Default for TemplateConfig { fn default() -> Self { Self { path: default_template_path() } } }
impl Default for StaticConfig { fn default() -> Self { Self { path: default_static_path(), enabled: false } } }
impl Default for PluginsConfig { fn default() -> Self { Self { enabled: vec![], notification: NotificationConfig::default(), async_task: AsyncTaskConfig::default(), scheduler: SchedulerConfig::default() } } }
/// Holds the configuration loaded at startup plus a runtime key/value store.
pub struct ConfigManager {
/// Configuration built by `ConfigManager::load`. No method in this file mutates it
/// (the field is `pub`, though).
pub static_config: AppConfig,
/// Runtime key/value entries, accessed through `get_dynamic` / `set_dynamic` /
/// `remove_dynamic` under a `parking_lot::RwLock`.
pub dynamic: RwLock<HashMap<String, serde_json::Value>>,
}
impl ConfigManager {
pub fn load(config_dir: Option<&str>) -> Self {
let dir = config_dir.unwrap_or("config");
let profile = detect_profile();
let mut cfg = Self::load_file(dir, &profile);
merge_env_overrides(&mut cfg);
info!("配置加载完成 profile={}, listen={}", cfg.profile, cfg.server.listen);
Self {
static_config: cfg,
dynamic: RwLock::new(HashMap::new()),
}
}
fn load_file(dir: &str, profile: &str) -> AppConfig {
let base_path = Path::new(dir).join("config.toml");
let mut cfg = if base_path.exists() {
let content = fs::read_to_string(&base_path)
.unwrap_or_else(|_| String::new());
toml::from_str(&content).unwrap_or_default()
} else {
AppConfig::default()
};
let profile_path = Path::new(dir).join(format!("config-{}.toml", profile));
if profile_path.exists() {
if let Ok(content) = fs::read_to_string(&profile_path) {
if let Ok(profile_cfg) = toml::from_str::<AppConfig>(&content) {
merge_configs(&mut cfg, &profile_cfg);
}
}
}
cfg.profile = profile.to_string();
cfg
}
pub fn get(&self) -> &AppConfig {
&self.static_config
}
pub fn get_dynamic(&self, key: &str) -> Option<serde_json::Value> {
self.dynamic.read().get(key).cloned()
}
pub fn set_dynamic(&self, key: &str, value: serde_json::Value) {
self.dynamic.write().insert(key.to_string(), value);
}
pub fn remove_dynamic(&self, key: &str) {
self.dynamic.write().remove(key);
}
pub fn generate_default(dir: &str) -> std::io::Result<()> {
let config_dir = Path::new(dir);
fs::create_dir_all(config_dir)?;
let cfg = AppConfig::default();
let toml_str = toml::to_string_pretty(&cfg)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
let header = r#"# Alun 默认配置文件
# 修改后保存即可生效(需重启服务)
#
# 使用 --gen-config 参数可重新生成此文件到 config/config.toml
# 多环境:创建 config/config-dev.toml, config/config-prod.toml
# 通过环境变量或命令行 --profile=prod 指定
"#;
fs::write(config_dir.join("config.toml"), format!("{}{}", header, toml_str))?;
info!("默认配置文件已生成到 {}/config.toml", dir);
Ok(())
}
}
fn merge_configs(base: &mut AppConfig, overlay: &AppConfig) {
if overlay.server.listen != default_listen() { base.server.listen = overlay.server.listen.clone(); }
if overlay.log.level != default_log_level() { base.log.level = overlay.log.level.clone(); }
if overlay.log.format != default_log_format() { base.log.format = overlay.log.format.clone(); }
if overlay.log.dir.is_some() { base.log.dir = overlay.log.dir.clone(); }
if overlay.log.file_prefix != default_log_prefix() { base.log.file_prefix = overlay.log.file_prefix.clone(); }
if overlay.database.host != default_host() || !overlay.database.name.is_empty() {
base.database = overlay.database.clone();
}
if overlay.redis.url != default_redis_url() { base.redis = overlay.redis.clone(); }
if overlay.cache.r#type != default_cache_type() { base.cache = overlay.cache.clone(); }
if overlay.router.prefix != String::new() { base.router.prefix = overlay.router.prefix.clone(); }
if overlay.router.not_found.message != default_not_found_msg() {
base.router.not_found.message = overlay.router.not_found.message.clone();
}
if !overlay.router.not_found.enabled {
base.router.not_found.enabled = false;
}
if overlay.upload.path != default_upload_path() { base.upload = overlay.upload.clone(); }
if overlay.download.path != default_download_path() { base.download = overlay.download.clone(); }
if overlay.template.path != default_template_path() { base.template = overlay.template.clone(); }
if overlay.static_files.path != default_static_path() { base.static_files = overlay.static_files.clone(); }
merge_middleware(&mut base.middleware, &overlay.middleware);
if !overlay.plugins.enabled.is_empty() {
base.plugins = overlay.plugins.clone();
}
for (k, v) in &overlay.custom { base.custom.insert(k.clone(), v.clone()); }
}
/// Overlays profile-level middleware settings onto `base`, field by field.
///
/// Same rule as `merge_configs`: a value overrides the base only when it differs from
/// `MiddlewareConfig::default()`. Comparing numeric fields against their defaults, rather
/// than against `0`, matters here. Keys the profile omitted arrive in `overlay` as their
/// serde defaults, and a `!= 0` test would reset the base's token lifetimes and rate
/// limits to those defaults whenever a profile file exists.
fn merge_middleware(base: &mut MiddlewareConfig, overlay: &MiddlewareConfig) {
    /// Copies `overlay` into `base` when it differs from `default`.
    fn pick<T: PartialEq + Clone>(base: &mut T, overlay: &T, default: &T) {
        if overlay != default {
            base.clone_from(overlay);
        }
    }

    let defaults = MiddlewareConfig::default();

    pick(&mut base.request_id, &overlay.request_id, &defaults.request_id);
    pick(&mut base.request_log, &overlay.request_log, &defaults.request_log);

    let (b, o, d) = (
        &mut base.request_log_config,
        &overlay.request_log_config,
        &defaults.request_log_config,
    );
    pick(&mut b.log_duration, &o.log_duration, &d.log_duration);
    pick(&mut b.exclude_paths, &o.exclude_paths, &d.exclude_paths);

    let (b, o, d) = (&mut base.auth, &overlay.auth, &defaults.auth);
    pick(&mut b.enabled, &o.enabled, &d.enabled);
    pick(&mut b.ignore_paths, &o.ignore_paths, &d.ignore_paths);
    pick(&mut b.jwt_secret, &o.jwt_secret, &d.jwt_secret);
    pick(&mut b.access_token_expire_secs, &o.access_token_expire_secs, &d.access_token_expire_secs);
    pick(&mut b.refresh_token_expire_secs, &o.refresh_token_expire_secs, &d.refresh_token_expire_secs);

    let (b, o, d) = (&mut base.cors, &overlay.cors, &defaults.cors);
    pick(&mut b.enabled, &o.enabled, &d.enabled);
    pick(&mut b.allow_origins, &o.allow_origins, &d.allow_origins);
    pick(&mut b.allow_methods, &o.allow_methods, &d.allow_methods);
    pick(&mut b.allow_headers, &o.allow_headers, &d.allow_headers);
    pick(&mut b.allow_credentials, &o.allow_credentials, &d.allow_credentials);
    pick(&mut b.max_age_secs, &o.max_age_secs, &d.max_age_secs);

    let (b, o, d) = (&mut base.compression, &overlay.compression, &defaults.compression);
    pick(&mut b.enabled, &o.enabled, &d.enabled);
    pick(&mut b.level, &o.level, &d.level);

    let (b, o, d) = (&mut base.rate_limit, &overlay.rate_limit, &defaults.rate_limit);
    pick(&mut b.enabled, &o.enabled, &d.enabled);
    pick(&mut b.requests_per_window, &o.requests_per_window, &d.requests_per_window);
    pick(&mut b.window_secs, &o.window_secs, &d.window_secs);

    let (b, o, d) = (
        &mut base.security_headers,
        &overlay.security_headers,
        &defaults.security_headers,
    );
    pick(&mut b.enabled, &o.enabled, &d.enabled);
    pick(&mut b.nosniff, &o.nosniff, &d.nosniff);
    pick(&mut b.frame_options, &o.frame_options, &d.frame_options);
    pick(&mut b.hsts, &o.hsts, &d.hsts);
    pick(&mut b.hsts_max_age_secs, &o.hsts_max_age_secs, &d.hsts_max_age_secs);
    pick(&mut b.hsts_include_subdomains, &o.hsts_include_subdomains, &d.hsts_include_subdomains);
    pick(&mut b.csp, &o.csp, &d.csp);
    pick(&mut b.csp_value, &o.csp_value, &d.csp_value);
    pick(&mut b.referrer_policy, &o.referrer_policy, &d.referrer_policy);
    pick(&mut b.referrer_policy_value, &o.referrer_policy_value, &d.referrer_policy_value);
    pick(&mut b.permissions_policy, &o.permissions_policy, &d.permissions_policy);
    pick(&mut b.permissions_policy_value, &o.permissions_policy_value, &d.permissions_policy_value);

    pick(&mut base.permission.enabled, &overlay.permission.enabled, &defaults.permission.enabled);
    // `PermissionRule` has no `PartialEq`, so a non-empty rule list replaces the base list.
    if !overlay.permission.rules.is_empty() {
        base.permission.rules = overlay.permission.rules.clone();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `AppConfig::default()` must serialize to TOML and contain the documented defaults
    /// for the listen setting and log level.
    #[test]
    fn test_default_config_serialization() {
        let rendered = toml::to_string_pretty(&AppConfig::default())
            .expect("default AppConfig must serialize to TOML");
        for needle in ["listen = \"8023\"", "level = \"info\""] {
            assert!(rendered.contains(needle), "missing `{}` in:\n{}", needle, rendered);
        }
    }
}