use crate::error::{AllSourceError, Result};
use serde::{Deserialize, Serialize};
use std::{
fs,
path::{Component, Path, PathBuf},
};
/// Lexically vets a config-file path before it is opened.
///
/// The checks run in this order, and the first one that fails decides the error:
/// 1. the path must not be empty;
/// 2. its raw encoded bytes must not contain a NUL (`0`) byte;
/// 3. none of its components may be `..` (`Component::ParentDir`).
///
/// This only inspects the path value. It does not touch the filesystem,
/// resolve symlinks, or canonicalize.
///
/// # Errors
///
/// Returns `AllSourceError::ValidationError` describing the violated rule.
fn validate_config_path(path: &Path) -> Result<()> {
    let reject = |reason: &str| -> Result<()> {
        Err(AllSourceError::ValidationError(reason.to_string()))
    };

    let raw = path.as_os_str();
    if raw.is_empty() {
        return reject("config path must not be empty");
    }

    let has_nul = raw.as_encoded_bytes().iter().any(|&byte| byte == 0);
    if has_nul {
        return reject("config path contains a null byte");
    }

    let climbs_out = path
        .components()
        .any(|component| component == Component::ParentDir);
    if climbs_out {
        return reject("config path must not contain '..' components");
    }

    Ok(())
}
/// Top-level AllSource configuration, grouping one section per subsystem.
///
/// `Default` is derived, so `Config::default()` builds every section from that
/// section's own `Default` impl.
///
/// NOTE: no `#[serde(default)]` is applied. A TOML file passed to
/// `Config::from_file` must therefore contain every section and every
/// non-`Option` field. Omitted `Option` fields deserialize as `None`, not as
/// the `Default` impl's value.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Config {
/// `[server]` table: host, port, connection limits, CORS.
pub server: ServerConfig,
/// `[storage]` table: data/WAL directories, batching, compression, retention.
pub storage: StorageConfig,
/// `[auth]` table: JWT secret/expiry, API-key expiry, password and session settings.
pub auth: AuthConfig,
/// `[rate_limit]` table.
pub rate_limit: RateLimitConfigFile,
/// `[backup]` table.
pub backup: BackupConfigFile,
/// `[metrics]` table.
pub metrics: MetricsConfig,
/// `[logging]` table.
pub logging: LoggingConfig,
}
/// `[server]` section of [`Config`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
    /// Host/interface string for the server. Default `"0.0.0.0"`.
    /// Overridable via `ALLSOURCE_HOST` (falling back to `HOST`).
    pub host: String,
    /// Server port. Default `3900`. Overridable via `ALLSOURCE_PORT` (falling
    /// back to `PORT`). `Config::validate` rejects `0`.
    pub port: u16,
    /// Optional worker count; `None` by default (presumably "let the runtime
    /// decide" — confirm with the server bootstrap code).
    pub workers: Option<usize>,
    /// Connection cap. Default `10_000`.
    pub max_connections: usize,
    /// Request timeout in seconds (per the name). Default `30`.
    pub request_timeout_secs: u64,
    /// Whether CORS is enabled. Default `true`.
    pub cors_enabled: bool,
    /// Allowed CORS origins. Default `["*"]`.
    pub cors_origins: Vec<String>,
}

impl Default for ServerConfig {
    /// Permissive defaults: all interfaces, port 3900, CORS open to every origin.
    fn default() -> Self {
        let wildcard_origin = vec![String::from("*")];
        ServerConfig {
            host: String::from("0.0.0.0"),
            port: 3900,
            workers: None,
            max_connections: 10_000,
            request_timeout_secs: 30,
            cors_enabled: true,
            cors_origins: wildcard_origin,
        }
    }
}
impl ServerConfig {
pub fn from_env() -> Self {
let mut config = Self::default();
if let Ok(host) = std::env::var("ALLSOURCE_HOST").or_else(|_| std::env::var("HOST")) {
config.host = host;
}
if let Ok(port) = std::env::var("ALLSOURCE_PORT").or_else(|_| std::env::var("PORT"))
&& let Ok(p) = port.parse::<u16>()
{
config.port = p;
}
config
}
}
/// `[storage]` section of [`Config`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageConfig {
    /// Data directory. Default `./data`. Overridable via `ALLSOURCE_DATA_DIR`.
    /// `Config::validate` rejects an empty path.
    pub data_dir: PathBuf,
    /// Write-ahead-log directory (per the name). Default `./wal`.
    pub wal_dir: PathBuf,
    /// Batch size. Default `1000` (unit depends on the storage engine — confirm).
    pub batch_size: usize,
    /// Compression codec. Default [`CompressionType::Lz4`].
    pub compression: CompressionType,
    /// Optional retention in days. `None` by default (presumably "keep forever" — confirm).
    pub retention_days: Option<u32>,
    /// Optional storage cap in GB. `None` by default (presumably "no cap" — confirm).
    pub max_storage_gb: Option<u32>,
}

impl Default for StorageConfig {
    /// Relative `./data` and `./wal` directories, LZ4 compression, no retention or size limits.
    fn default() -> Self {
        StorageConfig {
            data_dir: "./data".into(),
            wal_dir: "./wal".into(),
            batch_size: 1_000,
            compression: CompressionType::Lz4,
            retention_days: None,
            max_storage_gb: None,
        }
    }
}
/// Compression codec selected by `StorageConfig::compression`.
///
/// Serialized in lowercase: `"none"`, `"lz4"`, `"gzip"`, `"snappy"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum CompressionType {
/// No compression.
None,
/// LZ4; the `StorageConfig` default.
Lz4,
/// Gzip.
Gzip,
/// Snappy.
Snappy,
}
/// `[auth]` section of [`Config`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthConfig {
    /// Secret for JWTs. The default is the placeholder `CHANGE_ME_IN_PRODUCTION`,
    /// which `Config::validate` warns about. Override via `ALLSOURCE_JWT_SECRET`.
    pub jwt_secret: String,
    /// JWT lifetime in hours (per the name). Default `24`.
    pub jwt_expiry_hours: i64,
    /// API-key lifetime in days. Default `Some(90)`. If the field is omitted from
    /// a TOML file, it deserializes as `None`, not `Some(90)`.
    pub api_key_expiry_days: Option<i64>,
    /// Minimum password length. Default `8`.
    pub password_min_length: usize,
    /// Whether email verification is required. Default `false`.
    pub require_email_verification: bool,
    /// Session timeout in minutes (per the name). Default `60`.
    pub session_timeout_minutes: u64,
}

impl Default for AuthConfig {
    /// Development-friendly defaults. The JWT secret is a placeholder that must be replaced.
    fn default() -> Self {
        let placeholder_secret = String::from("CHANGE_ME_IN_PRODUCTION");
        AuthConfig {
            jwt_secret: placeholder_secret,
            jwt_expiry_hours: 24,
            api_key_expiry_days: Some(90),
            password_min_length: 8,
            require_email_verification: false,
            session_timeout_minutes: 60,
        }
    }
}
/// `[rate_limit]` section of [`Config`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfigFile {
    /// Whether rate limiting is on. Default `true`.
    pub enabled: bool,
    /// Tier applied by default. Default [`RateLimitTier::Professional`].
    pub default_tier: RateLimitTier,
    /// Optional requests-per-minute value. `None` by default (presumably used
    /// with `RateLimitTier::Custom` — confirm with the rate limiter).
    pub requests_per_minute: Option<u32>,
    /// Optional burst size. `None` by default.
    pub burst_size: Option<u32>,
}

impl Default for RateLimitConfigFile {
    /// Enabled, `Professional` tier, no explicit rate/burst overrides.
    fn default() -> Self {
        RateLimitConfigFile {
            default_tier: RateLimitTier::Professional,
            enabled: true,
            burst_size: None,
            requests_per_minute: None,
        }
    }
}
/// Rate-limit tier named in `RateLimitConfigFile::default_tier`.
///
/// Serialized in lowercase: `"free"`, `"professional"`, `"unlimited"`, `"custom"`.
/// The actual limits per tier are defined outside this file (not visible here).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum RateLimitTier {
/// Free tier.
Free,
/// Professional tier; the `RateLimitConfigFile` default.
Professional,
/// Unlimited tier.
Unlimited,
/// Custom tier (presumably driven by `requests_per_minute`/`burst_size` — confirm).
Custom,
}
/// `[backup]` section of [`Config`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackupConfigFile {
    /// Whether backups are enabled. Default `false`.
    pub enabled: bool,
    /// Backup directory. Default `./backups`.
    pub backup_dir: PathBuf,
    /// Optional cron-style schedule string (per the name). `None` by default.
    pub schedule_cron: Option<String>,
    /// Number of backups to retain (per the name — confirm). Default `7`.
    pub retention_count: usize,
    /// Compression level. Default `6`. The valid range is not checked here.
    pub compression_level: u8,
    /// Whether to verify after a backup. Default `true`.
    pub verify_after_backup: bool,
}

impl Default for BackupConfigFile {
    /// Backups disabled, written to `./backups`, keeping 7, level-6 compression, verified.
    fn default() -> Self {
        BackupConfigFile {
            enabled: false,
            backup_dir: "./backups".into(),
            schedule_cron: None,
            retention_count: 7,
            compression_level: 6,
            verify_after_backup: true,
        }
    }
}
/// `[metrics]` section of [`Config`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricsConfig {
    /// Whether metrics are enabled. Default `true`.
    pub enabled: bool,
    /// Metrics endpoint path. Default `"/metrics"`.
    pub endpoint: String,
    /// Optional push interval in seconds (per the name). `None` by default.
    pub push_interval_secs: Option<u64>,
    /// Optional push-gateway URL. `None` by default.
    pub push_gateway_url: Option<String>,
}

impl Default for MetricsConfig {
    /// Metrics enabled on `/metrics`, with no push configuration.
    fn default() -> Self {
        MetricsConfig {
            enabled: true,
            endpoint: String::from("/metrics"),
            push_interval_secs: None,
            push_gateway_url: None,
        }
    }
}
/// `[logging]` section of [`Config`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoggingConfig {
    /// Minimum log level. Default [`LogLevel::Info`].
    pub level: LogLevel,
    /// Output format. Default [`LogFormat::Pretty`].
    pub format: LogFormat,
    /// Output destination. Default [`LogOutput::Stdout`].
    pub output: LogOutput,
    /// Optional log file path. `None` by default (presumably required for
    /// `LogOutput::File`/`Both` — this is not checked here).
    pub file_path: Option<PathBuf>,
    /// Rotation size in MB (per the name). Default `Some(100)`. If the field is
    /// omitted from a TOML file, it deserializes as `None`.
    pub rotate_size_mb: Option<u64>,
}

impl Default for LoggingConfig {
    /// Info-level, pretty-printed logs to stdout, no file, 100 MB rotation size.
    fn default() -> Self {
        LoggingConfig {
            output: LogOutput::Stdout,
            format: LogFormat::Pretty,
            level: LogLevel::Info,
            file_path: None,
            rotate_size_mb: Some(100),
        }
    }
}
/// Log verbosity used by `LoggingConfig::level`, from most to least verbose.
///
/// Serialized in lowercase: `"trace"`, `"debug"`, `"info"`, `"warn"`, `"error"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum LogLevel {
Trace,
Debug,
/// The `LoggingConfig` default.
Info,
Warn,
Error,
}
/// Log line format used by `LoggingConfig::format`.
///
/// Serialized in lowercase: `"json"`, `"pretty"`, `"compact"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum LogFormat {
Json,
/// The `LoggingConfig` default.
Pretty,
Compact,
}
/// Log destination used by `LoggingConfig::output`.
///
/// Serialized in lowercase: `"stdout"`, `"stderr"`, `"file"`, `"both"`.
/// NOTE(review): `File`/`Both` presumably use `LoggingConfig::file_path` — confirm
/// in the logging setup code. This file does not enforce that pairing.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum LogOutput {
/// The `LoggingConfig` default.
Stdout,
Stderr,
File,
/// Presumably a file plus a console stream — confirm.
Both,
}
impl Config {
    /// Reads and parses a TOML configuration file.
    ///
    /// The path is first checked by `validate_config_path`: it must be non-empty,
    /// contain no NUL bytes and have no `..` components.
    ///
    /// # Errors
    ///
    /// - `ValidationError` if the path is rejected or the TOML does not
    ///   deserialize into [`Config`].
    /// - `StorageError` if the file cannot be read.
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        let path_ref = path.as_ref();
        validate_config_path(path_ref)?;
        let content = fs::read_to_string(path_ref).map_err(|e| {
            AllSourceError::StorageError(format!("Failed to read config file: {e}"))
        })?;
        toml::from_str(&content)
            .map_err(|e| AllSourceError::ValidationError(format!("Invalid config format: {e}")))
    }

    /// Builds a configuration from [`Config::default`] plus environment overrides.
    ///
    /// The variables it reads are listed on `apply_env_overrides`.
    ///
    /// # Errors
    ///
    /// Returns `ValidationError("Invalid port number")` if `ALLSOURCE_PORT`/`PORT`
    /// is set but does not parse as a `u16`.
    pub fn from_env() -> Result<Self> {
        let mut config = Config::default();
        config.apply_env_overrides()?;
        Ok(config)
    }

    /// Loads configuration with precedence: environment variables > config file > defaults.
    ///
    /// - If `config_path` is `Some` and the path exists, it is parsed with
    ///   [`Config::from_file`].
    /// - If it is `Some` but the path does not exist, a warning is logged and the
    ///   defaults are used.
    /// - Environment overrides are then applied on top, and the result is checked
    ///   with [`Config::validate`].
    ///
    /// # Errors
    ///
    /// Propagates errors from reading or parsing the file, from an unparsable
    /// `ALLSOURCE_PORT`/`PORT` value, and from validation.
    pub fn load(config_path: Option<PathBuf>) -> Result<Self> {
        let mut config = match config_path {
            Some(path) if path.exists() => {
                tracing::info!("Loading config from: {}", path.display());
                Self::from_file(path)?
            }
            Some(path) => {
                tracing::warn!("Config file not found: {}, using defaults", path.display());
                Config::default()
            }
            None => Config::default(),
        };

        // Apply env vars directly to the loaded config. The previous approach
        // built an env-only Config and diffed it against the defaults, which had
        // two problems:
        // (a) an env value equal to a default (e.g. ALLSOURCE_PORT=3900) could not
        //     override a different value from the file;
        // (b) an invalid port made *every* env override, including
        //     ALLSOURCE_JWT_SECRET, disappear silently.
        config.apply_env_overrides()?;
        config.validate()?;
        Ok(config)
    }

    /// Overwrites fields for each environment variable that is set, in this order:
    ///
    /// - `ALLSOURCE_HOST`, or `HOST` as a fallback, sets `server.host`;
    /// - `ALLSOURCE_PORT`, or `PORT` as a fallback, sets `server.port`;
    /// - `ALLSOURCE_DATA_DIR` sets `storage.data_dir`;
    /// - `ALLSOURCE_JWT_SECRET` sets `auth.jwt_secret`.
    ///
    /// The fallback variable is consulted only when the primary one is unset or
    /// not valid Unicode.
    ///
    /// # Errors
    ///
    /// Returns `ValidationError("Invalid port number")` if the port value does not
    /// parse as a `u16`. By then the host may already have been overwritten.
    fn apply_env_overrides(&mut self) -> Result<()> {
        if let Ok(host) = std::env::var("ALLSOURCE_HOST").or_else(|_| std::env::var("HOST")) {
            self.server.host = host;
        }
        if let Ok(port) = std::env::var("ALLSOURCE_PORT").or_else(|_| std::env::var("PORT")) {
            self.server.port = port
                .parse()
                .map_err(|_| AllSourceError::ValidationError("Invalid port number".to_string()))?;
        }
        if let Ok(data_dir) = std::env::var("ALLSOURCE_DATA_DIR") {
            self.storage.data_dir = PathBuf::from(data_dir);
        }
        if let Ok(jwt_secret) = std::env::var("ALLSOURCE_JWT_SECRET") {
            self.auth.jwt_secret = jwt_secret;
        }
        Ok(())
    }

    /// Checks the configuration for values that cannot work.
    ///
    /// Logs a warning (without failing) when the JWT secret is still the
    /// `CHANGE_ME_IN_PRODUCTION` placeholder.
    ///
    /// # Errors
    ///
    /// Returns `ValidationError` if `server.port` is `0` or `storage.data_dir` is empty.
    pub fn validate(&self) -> Result<()> {
        if self.server.port == 0 {
            return Err(AllSourceError::ValidationError(
                "Server port cannot be 0".to_string(),
            ));
        }
        if self.auth.jwt_secret == "CHANGE_ME_IN_PRODUCTION" {
            tracing::warn!("⚠️ Using default JWT secret - INSECURE for production!");
        }
        if self.storage.data_dir.as_os_str().is_empty() {
            return Err(AllSourceError::ValidationError(
                "Data directory path cannot be empty".to_string(),
            ));
        }
        Ok(())
    }

    /// Serializes the configuration as pretty-printed TOML and writes it to `path`.
    ///
    /// The write uses `fs::write`, which creates the file or truncates an existing
    /// one; it is not atomic. Unlike [`Config::from_file`], the path is not passed
    /// through `validate_config_path`. Note that `auth.jwt_secret` is written in
    /// plain text.
    ///
    /// # Errors
    ///
    /// - `ValidationError` if serialization fails.
    /// - `StorageError` if the write fails.
    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let toml = toml::to_string_pretty(self).map_err(|e| {
            AllSourceError::ValidationError(format!("Failed to serialize config: {e}"))
        })?;
        fs::write(path.as_ref(), toml).map_err(|e| {
            AllSourceError::StorageError(format!("Failed to write config file: {e}"))
        })?;
        Ok(())
    }

    /// Returns [`Config::default`] rendered as pretty TOML, suitable as a starter file.
    ///
    /// Falls back to a single TOML comment line if serialization fails.
    pub fn example() -> String {
        toml::to_string_pretty(&Config::default())
            .unwrap_or_else(|_| String::from("# Failed to generate example config"))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = Config::default();
        assert_eq!(config.server.port, 3900);
        assert!(config.rate_limit.enabled);
    }

    #[test]
    fn test_config_validation() {
        let config = Config::default();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_invalid_port() {
        let mut config = Config::default();
        config.server.port = 0;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_validate_config_path_accepts_normal_paths() {
        assert!(validate_config_path(Path::new("config.toml")).is_ok());
        assert!(validate_config_path(Path::new("/etc/allsource/config.toml")).is_ok());
        assert!(validate_config_path(Path::new("./config/allsource.toml")).is_ok());
    }

    #[test]
    fn test_validate_config_path_rejects_traversal_and_nulls() {
        assert!(validate_config_path(Path::new("")).is_err());
        assert!(validate_config_path(Path::new("../secret.toml")).is_err());
        assert!(validate_config_path(Path::new("config/../../secret.toml")).is_err());
        // The test name promises NUL-byte rejection; exercise it explicitly.
        assert!(validate_config_path(Path::new("config\0.toml")).is_err());
    }

    #[test]
    fn test_config_serialization() {
        let config = Config::default();
        let toml = toml::to_string(&config).unwrap();
        let deserialized: Config = toml::from_str(&toml).unwrap();
        assert_eq!(config.server.port, deserialized.server.port);
    }

    #[test]
    fn test_example_config_parses() {
        // The generated example must be loadable by the same parser `from_file` uses.
        let parsed: Config = toml::from_str(&Config::example()).unwrap();
        assert_eq!(parsed.server.port, 3900);
        assert_eq!(parsed.storage.compression, CompressionType::Lz4);
        assert_eq!(parsed.logging.level, LogLevel::Info);
    }
}