use serde::{Deserialize, Serialize};
use std::env;
use std::net::{SocketAddr, IpAddr};
use std::str::FromStr;
use thiserror::Error;
/// Errors produced while loading or validating configuration.
#[derive(Debug, Error)]
pub enum ConfigError {
    /// A required environment variable is not set. Holds the full variable
    /// name including any prefix (e.g. `DATABASE_URL` with the default prefix).
    #[error("Missing required environment variable: {0}")]
    MissingEnvVar(String),
    /// A setting has a value that was rejected by a semantic check, such as the
    /// `ConfigLoader::validate` implementations in this file.
    #[error("Invalid configuration value for {key}: {message}")]
    InvalidValue { key: String, message: String },
    /// A setting could not be parsed into its target type.
    ///
    /// The underlying error is exposed through `Error::source` (`#[source]`)
    /// and is also interpolated into the display message.
    /// NOTE(review): reporters that walk the source chain (e.g. anyhow-style
    /// `{:#}` output) will therefore print the underlying error twice.
    #[error("Failed to parse {key}: {source}")]
    ParseError {
        /// Identifies the setting that failed to parse (e.g. `"host"`).
        key: String,
        #[source]
        source: Box<dyn std::error::Error + Send + Sync>,
    },
}
/// Result alias used by the configuration loaders in this file.
pub type ConfigResult<T> = Result<T, ConfigError>;
/// Server listen settings plus the list of CORS origins.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
    /// Address to bind. Expected to be an IP literal: both `bind_addr` and
    /// `ConfigLoader::validate` parse it with `IpAddr::from_str`.
    pub host: String,
    /// TCP port; `ConfigLoader::validate` rejects `0`.
    pub port: u16,
    /// Origins read from the comma-separated `CORS_ORIGINS` variable.
    /// NOTE(review): presumably the allow-list handed to a CORS layer — confirm
    /// with the caller.
    pub cors_origins: Vec<String>,
}
impl ServerConfig {
    /// Loads the server configuration using the default `SERVER_` prefix.
    ///
    /// See [`ServerConfig::from_env_with_prefix`] for the variables read and
    /// their defaults.
    ///
    /// # Errors
    ///
    /// Same as [`ServerConfig::from_env_with_prefix`].
    pub fn from_env() -> ConfigResult<Self> {
        Self::from_env_with_prefix("SERVER_")
    }

    /// Loads the server configuration from environment variables.
    ///
    /// Variables read:
    ///
    /// - `{prefix}HOST` — stored as-is; defaults to `"0.0.0.0"` when unset or
    ///   not valid Unicode. Use [`ConfigLoader::validate`] or
    ///   [`ServerConfig::bind_addr`] to check that it is an IP address.
    /// - `{prefix}PORT` — parsed as `u16` after trimming whitespace; defaults to
    ///   `8080` when unset or blank.
    /// - `CORS_ORIGINS` — **not** prefixed. Comma-separated; each entry is
    ///   trimmed and empty entries are dropped. Defaults to
    ///   `"http://localhost:3000"` when unset or not valid Unicode.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::ParseError`] if `{prefix}PORT` is set to a
    /// non-blank value that is not valid Unicode or is not a valid `u16`.
    /// The value is not silently replaced by the default port.
    pub fn from_env_with_prefix(prefix: &str) -> ConfigResult<Self> {
        let host = env::var(format!("{}HOST", prefix))
            .unwrap_or_else(|_| "0.0.0.0".to_string());
        let port: u16 = Self::parse_var_or(&format!("{}PORT", prefix), 8080)?;
        // CORS_ORIGINS is read without `prefix`; existing callers and tests set
        // the unprefixed name, so this is kept unchanged.
        let cors_origins = env::var("CORS_ORIGINS")
            .unwrap_or_else(|_| "http://localhost:3000".to_string())
            .split(',')
            .map(|s| s.trim().to_string())
            .filter(|s| !s.is_empty())
            .collect();
        Ok(Self {
            host,
            port,
            cors_origins,
        })
    }

    /// Parses `host` as an IP address and pairs it with `port`.
    ///
    /// Hostnames such as `"localhost"` are not resolved; only IPv4/IPv6
    /// literals are accepted.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::ParseError`] with key `"host"` if `host` is not
    /// an IP literal.
    pub fn bind_addr(&self) -> ConfigResult<SocketAddr> {
        let ip = IpAddr::from_str(&self.host).map_err(|e| ConfigError::ParseError {
            key: "host".to_string(),
            source: Box::new(e),
        })?;
        Ok(SocketAddr::new(ip, self.port))
    }

    /// Reads environment variable `key` and parses it as `T`.
    ///
    /// Returns `default` when the variable is unset or blank after trimming.
    /// Returns [`ConfigError::ParseError`] when it is present but not valid
    /// Unicode, or when it does not parse.
    fn parse_var_or<T>(key: &str, default: T) -> ConfigResult<T>
    where
        T: FromStr,
        T::Err: std::error::Error + Send + Sync + 'static,
    {
        let raw = match env::var(key) {
            Ok(raw) => raw,
            Err(env::VarError::NotPresent) => return Ok(default),
            Err(err) => {
                return Err(ConfigError::ParseError {
                    key: key.to_string(),
                    source: Box::new(err),
                })
            }
        };
        let trimmed = raw.trim();
        // Treat `PORT=` (common in .env / compose files) the same as unset.
        if trimmed.is_empty() {
            return Ok(default);
        }
        trimmed.parse().map_err(|err| ConfigError::ParseError {
            key: key.to_string(),
            source: Box::new(err),
        })
    }
}
/// Database connection settings loaded from the environment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseConfig {
    /// Connection URL, read from `{prefix}URL` (required);
    /// `ConfigLoader::validate` rejects an empty string.
    pub url: String,
    /// Maximum number of connections (default 10); `ConfigLoader::validate`
    /// rejects `0`. NOTE(review): presumably a pool size limit — confirm.
    pub max_connections: u32,
    /// Connection timeout (default 30).
    /// NOTE(review): the unit is not fixed in this file — presumably seconds;
    /// confirm with whatever consumes this value.
    pub connect_timeout: u64,
}
impl DatabaseConfig {
pub fn from_env(prefix: Option<&str>) -> ConfigResult<Self> {
let prefix = prefix.unwrap_or("DATABASE_");
let url = env::var(format!("{}URL", prefix)).map_err(|_| {
ConfigError::MissingEnvVar(format!("{}URL", prefix))
})?;
let max_connections = env::var(format!("{}MAX_CONNECTIONS", prefix))
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(10);
let connect_timeout = env::var(format!("{}CONNECT_TIMEOUT", prefix))
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(30);
Ok(Self {
url,
max_connections,
connect_timeout,
})
}
}
/// A configuration section that can be built from the process environment and
/// then checked for semantic validity.
pub trait ConfigLoader
where
    Self: Sized,
{
    /// Builds the configuration from environment variables.
    fn from_env() -> ConfigResult<Self>;

    /// Checks the loaded values. The default implementation accepts anything.
    fn validate(&self) -> ConfigResult<()> {
        Ok(())
    }
}
impl ConfigLoader for ServerConfig {
fn from_env() -> ConfigResult<Self> {
Self::from_env()
}
fn validate(&self) -> ConfigResult<()> {
if self.port == 0 {
return Err(ConfigError::InvalidValue {
key: "port".to_string(),
message: "Port cannot be 0".to_string(),
});
}
if self.host.is_empty() {
return Err(ConfigError::InvalidValue {
key: "host".to_string(),
message: "Host cannot be empty".to_string(),
});
}
if IpAddr::from_str(&self.host).is_err() {
return Err(ConfigError::InvalidValue {
key: "host".to_string(),
message: format!("Invalid IP address: {}", self.host),
});
}
Ok(())
}
}
impl ConfigLoader for DatabaseConfig {
fn from_env() -> ConfigResult<Self> {
Self::from_env(None)
}
fn validate(&self) -> ConfigResult<()> {
if self.url.is_empty() {
return Err(ConfigError::InvalidValue {
key: "url".to_string(),
message: "Database URL cannot be empty".to_string(),
});
}
if self.max_connections == 0 {
return Err(ConfigError::InvalidValue {
key: "max_connections".to_string(),
message: "Max connections must be greater than 0".to_string(),
});
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes every test in this module that reads or writes environment
    /// variables.
    ///
    /// The default test harness runs tests on parallel threads, and the
    /// environment is process-global. Without this lock,
    /// `test_database_config_missing_url` can remove `DATABASE_URL` while
    /// `test_database_config_from_env` is between setting and reading it.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Variables touched by the `ServerConfig` env tests.
    const SERVER_VARS: [&str; 3] = ["SERVER_HOST", "SERVER_PORT", "CORS_ORIGINS"];

    /// Variables touched by the `DatabaseConfig` env tests.
    const DATABASE_VARS: [&str; 3] = [
        "DATABASE_URL",
        "DATABASE_MAX_CONNECTIONS",
        "DATABASE_CONNECT_TIMEOUT",
    ];

    /// Acquires [`ENV_LOCK`], ignoring poisoning so that one failing test does
    /// not make every other env-dependent test fail too.
    fn env_lock() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Sets each `(key, value)` pair. Call only while holding [`env_lock`].
    fn set_vars(vars: &[(&str, &str)]) {
        for &(key, value) in vars {
            // SAFETY: env mutation in this module only happens while ENV_LOCK is
            // held, so these tests do not race each other.
            // NOTE(review): any other test in the crate that touches the
            // environment must take the same lock for this to hold crate-wide.
            unsafe { env::set_var(key, value) };
        }
    }

    /// Removes each variable in `keys`. Call only while holding [`env_lock`].
    fn remove_vars(keys: &[&str]) {
        for &key in keys {
            // SAFETY: see `set_vars`.
            unsafe { env::remove_var(key) };
        }
    }

    #[test]
    fn test_server_config_defaults() {
        let _guard = env_lock();
        remove_vars(&SERVER_VARS);
        let config = ServerConfig::from_env().unwrap();
        assert_eq!(config.host, "0.0.0.0");
        assert_eq!(config.port, 8080);
        assert_eq!(config.cors_origins, vec!["http://localhost:3000"]);
    }

    #[test]
    fn test_server_config_custom_values() {
        let _guard = env_lock();
        remove_vars(&SERVER_VARS);
        set_vars(&[
            ("SERVER_HOST", "127.0.0.1"),
            ("SERVER_PORT", "3000"),
            ("CORS_ORIGINS", "http://example.com,http://test.com"),
        ]);
        let config = ServerConfig::from_env().unwrap();
        assert_eq!(config.host, "127.0.0.1");
        assert_eq!(config.port, 3000);
        assert_eq!(
            config.cors_origins,
            vec!["http://example.com", "http://test.com"]
        );
        remove_vars(&SERVER_VARS);
    }

    #[test]
    fn test_server_config_bind_addr() {
        let config = ServerConfig {
            host: "127.0.0.1".to_string(),
            port: 8080,
            cors_origins: vec![],
        };
        let addr = config.bind_addr().unwrap();
        assert_eq!(addr.to_string(), "127.0.0.1:8080");
    }

    #[test]
    fn test_server_config_validation() {
        let mut config = ServerConfig {
            host: "127.0.0.1".to_string(),
            port: 8080,
            cors_origins: vec![],
        };
        assert!(config.validate().is_ok());
        config.port = 0;
        assert!(config.validate().is_err());
        config.port = 8080;
        config.host = "invalid".to_string();
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_database_config_from_env() {
        let _guard = env_lock();
        set_vars(&[
            ("DATABASE_URL", "postgres://localhost/test"),
            ("DATABASE_MAX_CONNECTIONS", "20"),
            ("DATABASE_CONNECT_TIMEOUT", "60"),
        ]);
        let config = DatabaseConfig::from_env(None).unwrap();
        assert_eq!(config.url, "postgres://localhost/test");
        assert_eq!(config.max_connections, 20);
        assert_eq!(config.connect_timeout, 60);
        remove_vars(&DATABASE_VARS);
    }

    #[test]
    fn test_database_config_missing_url() {
        let _guard = env_lock();
        remove_vars(&["DATABASE_URL"]);
        let result = DatabaseConfig::from_env(None);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), ConfigError::MissingEnvVar(_)));
    }

    #[test]
    fn test_database_config_validation() {
        let mut config = DatabaseConfig {
            url: "postgres://localhost/test".to_string(),
            max_connections: 10,
            connect_timeout: 30,
        };
        assert!(config.validate().is_ok());
        config.url = String::new();
        assert!(config.validate().is_err());
        config.url = "postgres://localhost/test".to_string();
        config.max_connections = 0;
        assert!(config.validate().is_err());
    }
}