use std::path::Path;
use std::time::Duration;
use serde::{Deserialize, Serialize};
/// Serde `default` helper for boolean settings that stay switched on unless
/// the configuration file explicitly sets them to `false`.
fn default_true() -> bool {
    true
}
use crate::auth::JwtConfig;
/// `[server]` section of the configuration file.
///
/// Every field has a serde default, so a partial or empty `[server]` table
/// still deserializes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
    /// Host address for the server; defaults to `"0.0.0.0"` (see `default_host`).
    #[serde(default = "default_host")]
    pub host: String,
    /// Port for the server; defaults to `8080` (see `default_port`).
    #[serde(default = "default_port")]
    pub port: u16,
    /// `[server.auth]` table; falls back to `AuthConfig::default()` when absent.
    #[serde(default)]
    pub auth: AuthConfig,
    /// `[server.cors]` table; falls back to `CorsConfig::default()` when absent.
    #[serde(default)]
    pub cors: CorsConfig,
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
host: default_host(),
port: default_port(),
auth: AuthConfig::default(),
cors: CorsConfig::default(),
}
}
}
/// `[server.auth]` table: which authentication mechanisms are switched on.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthConfig {
    /// API-key authentication toggle; defaults to `true` when omitted.
    #[serde(default = "default_true")]
    pub api_key_enabled: bool,
    /// JWT authentication toggle; defaults to `false` when omitted.
    #[serde(default)]
    pub jwt_enabled: bool,
    /// JWT settings from the `[server.auth.jwt]` table. A missing table
    /// deserializes as `None`, and `None` is left out of serialized output.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jwt: Option<JwtConfigSerde>,
}
impl Default for AuthConfig {
fn default() -> Self {
Self {
api_key_enabled: true,
jwt_enabled: false,
jwt: None,
}
}
}
/// File representation of the JWT settings, converted into [`JwtConfig`] via `From`.
///
/// `issuer`, `audience` and `jwks_url` have no serde default, so they are
/// required whenever the table is present; the remaining fields are optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtConfigSerde {
    /// Copied verbatim to `JwtConfig::issuer`.
    pub issuer: String,
    /// Copied verbatim to `JwtConfig::audience`.
    pub audience: String,
    /// Copied verbatim to `JwtConfig::jwks_url`.
    pub jwks_url: String,
    /// Copied to `JwtConfig::default_tenant_id`; defaults to `"default"`.
    #[serde(default = "default_tenant_id")]
    pub default_tenant_id: String,
    /// Presumably the token claim holding the tenant (verify in `crate::auth`);
    /// defaults to `"tenant_id"`.
    #[serde(default = "default_tenant_claim")]
    pub tenant_claim: String,
    /// Presumably the token claim holding the roles (verify in `crate::auth`);
    /// defaults to `"roles"`.
    #[serde(default = "default_roles_claim")]
    pub roles_claim: String,
    /// Cache TTL in whole seconds, converted to a `Duration` for
    /// `JwtConfig::jwks_cache_ttl`; defaults to `3600`.
    #[serde(default = "default_jwks_cache_ttl")]
    pub jwks_cache_ttl_seconds: u64,
}
impl From<JwtConfigSerde> for JwtConfig {
fn from(config: JwtConfigSerde) -> Self {
JwtConfig {
issuer: config.issuer,
audience: config.audience,
jwks_url: config.jwks_url,
default_tenant_id: config.default_tenant_id,
tenant_claim: config.tenant_claim,
roles_claim: config.roles_claim,
jwks_cache_ttl: Duration::from_secs(config.jwks_cache_ttl_seconds),
..Default::default()
}
}
}
/// `[server.cors]` table: cross-origin request settings.
///
/// NOTE(review): the defaults allow every origin (`"*"`) and every header
/// (`"*"`). Confirm this is acceptable outside development.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorsConfig {
    /// Allowed origins; defaults to `["*"]`.
    #[serde(default = "default_allowed_origins")]
    pub allowed_origins: Vec<String>,
    /// Allowed HTTP methods; defaults to GET, POST, PUT, DELETE and OPTIONS.
    #[serde(default = "default_allowed_methods")]
    pub allowed_methods: Vec<String>,
    /// Allowed request headers; defaults to `["*"]`.
    #[serde(default = "default_allowed_headers")]
    pub allowed_headers: Vec<String>,
}
impl Default for CorsConfig {
    /// Assembles the CORS policy from the same helpers serde uses for
    /// missing fields.
    fn default() -> Self {
        let allowed_origins = default_allowed_origins();
        let allowed_methods = default_allowed_methods();
        let allowed_headers = default_allowed_headers();
        Self {
            allowed_origins,
            allowed_methods,
            allowed_headers,
        }
    }
}
/// Default for [`ServerConfig::host`]: the IPv4 unspecified address.
fn default_host() -> String {
    String::from("0.0.0.0")
}
/// Default for [`ServerConfig::port`].
fn default_port() -> u16 {
    8_080
}
/// Default for `JwtConfigSerde::default_tenant_id`.
fn default_tenant_id() -> String {
    String::from("default")
}
/// Default for `JwtConfigSerde::tenant_claim`.
fn default_tenant_claim() -> String {
    String::from("tenant_id")
}
/// Default for `JwtConfigSerde::roles_claim`.
fn default_roles_claim() -> String {
    String::from("roles")
}
/// Default for `JwtConfigSerde::jwks_cache_ttl_seconds`: one hour in seconds.
fn default_jwks_cache_ttl() -> u64 {
    60 * 60
}
/// Default for `CorsConfig::allowed_origins`: a single wildcard entry.
fn default_allowed_origins() -> Vec<String> {
    vec![String::from("*")]
}
/// Default for `CorsConfig::allowed_methods`.
fn default_allowed_methods() -> Vec<String> {
    ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
        .iter()
        .copied()
        .map(String::from)
        .collect()
}
/// Default for `CorsConfig::allowed_headers`: a single wildcard entry.
fn default_allowed_headers() -> Vec<String> {
    vec![String::from("*")]
}
/// Root of the rustberg TOML configuration file.
///
/// Every section is optional: a missing table is filled from that section's
/// `Default` impl, and the derived `Default` for this struct does the same for
/// all sections at once.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct RustbergConfig {
    /// `[server]`: host, port, authentication and CORS.
    #[serde(default)]
    pub server: ServerConfig,
    /// `[tls]`: certificate settings.
    #[serde(default)]
    pub tls: TlsConfigFile,
    /// `[storage]`: storage backend settings.
    #[serde(default)]
    pub storage: StorageConfig,
    /// `[kms]`: key-management provider settings.
    #[serde(default)]
    pub kms: KmsConfigFile,
    /// `[rate_limit]`: request rate limiting and auth-failure lockout.
    #[serde(default)]
    pub rate_limit: RateLimitConfigFile,
    /// `[logging]`: log level and output format.
    #[serde(default)]
    pub logging: LoggingConfig,
}
/// `[tls]` table.
///
/// NOTE(review): `insecure_http` is inconsistent between the two ways of
/// getting defaults.
/// - A `[tls]` table that omits `insecure_http` gets `false`, because
///   `#[serde(default)]` uses `bool::default()`.
/// - An absent `[tls]` table gets `TlsConfigFile::default()`, which sets it
///   to `true`.
///
/// Confirm which default is intended.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsConfigFile {
    /// Whether TLS is enabled; defaults to `false`.
    #[serde(default)]
    pub enabled: bool,
    /// Certificate path. `validate` requires it when `enabled` is true, unless
    /// `insecure_http` is set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cert_path: Option<String>,
    /// Private-key path. `validate` requires it whenever TLS is enabled and
    /// `cert_path` is given.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub key_path: Option<String>,
    /// Presumably permits serving plain HTTP (confirm against the server code).
    /// See the struct-level note about its default.
    #[serde(default)]
    pub insecure_http: bool,
}
impl Default for TlsConfigFile {
fn default() -> Self {
Self {
enabled: false, cert_path: None,
key_path: None,
insecure_http: true, }
}
}
/// `[storage]` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageConfig {
    /// Storage backend; defaults to `"file:///var/lib/rustberg/data"`.
    /// NOTE(review): the default is a `file://` URL rather than a type name,
    /// so this appears to hold a location URI. Confirm against the consumer.
    #[serde(default = "default_storage_type")]
    pub backend: String,
    /// Optional warehouse location (the sample uses an `s3://` URL).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub warehouse_location: Option<String>,
    /// Optional AWS region for the storage backend.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub aws_region: Option<String>,
    /// Optional cache directory.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_dir: Option<String>,
}
impl Default for StorageConfig {
    /// Default backend location with every optional setting left unset.
    fn default() -> Self {
        let backend = default_storage_type();
        Self {
            backend,
            warehouse_location: None,
            aws_region: None,
            cache_dir: None,
        }
    }
}
/// Default for `StorageConfig::backend`: a local `file://` data directory.
fn default_storage_type() -> String {
    String::from("file:///var/lib/rustberg/data")
}
/// `[kms]` table: key-management provider selection and its settings.
///
/// `RustbergConfig::validate` accepts the providers `"env"`, `"aws-kms"`
/// (requires `aws_key_id`) and `"vault"` (requires `vault_address` and
/// `vault_key_name`). It rejects any other provider name.
///
/// NOTE(review): the `gcp_*` and `azure_*` fields exist, but no GCP/Azure
/// provider name currently passes validation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KmsConfigFile {
    /// Provider name; defaults to `"env"`.
    #[serde(default = "default_kms_type")]
    pub provider: String,
    /// AWS KMS key identifier (required by `validate` for `"aws-kms"`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub aws_key_id: Option<String>,
    /// AWS region for KMS.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub aws_region: Option<String>,
    /// Vault address (required by `validate` for `"vault"`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub vault_address: Option<String>,
    /// Vault key name (required by `validate` for `"vault"`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub vault_key_name: Option<String>,
    /// GCP project ID.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub gcp_project_id: Option<String>,
    /// GCP location.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub gcp_location: Option<String>,
    /// GCP key ring.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub gcp_key_ring: Option<String>,
    /// GCP key name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub gcp_key_name: Option<String>,
    /// Azure vault URL.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub azure_vault_url: Option<String>,
    /// Azure key name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub azure_key_name: Option<String>,
    /// Cache TTL in seconds; defaults to `300`.
    #[serde(default = "default_kms_cache_ttl")]
    pub cache_ttl_seconds: u64,
    /// Circuit-breaker toggle; defaults to `true`.
    #[serde(default = "default_true")]
    pub circuit_breaker_enabled: bool,
}
impl Default for KmsConfigFile {
fn default() -> Self {
Self {
provider: default_kms_type(),
aws_key_id: None,
aws_region: None,
vault_address: None,
vault_key_name: None,
gcp_project_id: None,
gcp_location: None,
gcp_key_ring: None,
gcp_key_name: None,
azure_vault_url: None,
azure_key_name: None,
cache_ttl_seconds: default_kms_cache_ttl(),
circuit_breaker_enabled: true,
}
}
}
/// Default for `KmsConfigFile::provider`.
fn default_kms_type() -> String {
    String::from("env")
}
/// Default for `KmsConfigFile::cache_ttl_seconds`: five minutes in seconds.
fn default_kms_cache_ttl() -> u64 {
    5 * 60
}
/// `[rate_limit]` table: request throttling and authentication-failure lockout.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfigFile {
    /// Rate-limiting toggle; defaults to `true`.
    #[serde(default = "default_true")]
    pub enabled: bool,
    /// Requests-per-second limit; defaults to `100`. `validate` rejects `0`
    /// while `enabled` is true.
    #[serde(default = "default_requests_per_second")]
    pub requests_per_second: u32,
    /// Burst size; defaults to `200`.
    #[serde(default = "default_burst_size")]
    pub burst_size: u32,
    /// Auth-failure tracking toggle; defaults to `true`.
    #[serde(default = "default_true")]
    pub track_auth_failures: bool,
    /// Failures allowed before lockout; defaults to `5`.
    #[serde(default = "default_max_auth_failures")]
    pub max_auth_failures: u32,
    /// Lockout duration in seconds; defaults to `300`.
    #[serde(default = "default_lockout_duration")]
    pub lockout_duration_seconds: u64,
    /// Presumably whether client addresses are taken from proxy headers
    /// (confirm against the rate limiter); defaults to `false`.
    #[serde(default)]
    pub trust_proxy_headers: bool,
}
impl Default for RateLimitConfigFile {
fn default() -> Self {
Self {
enabled: true,
requests_per_second: default_requests_per_second(),
burst_size: default_burst_size(),
track_auth_failures: true,
max_auth_failures: default_max_auth_failures(),
lockout_duration_seconds: default_lockout_duration(),
trust_proxy_headers: false, }
}
}
/// Default for `RateLimitConfigFile::requests_per_second`.
fn default_requests_per_second() -> u32 {
    100
}
/// Default for `RateLimitConfigFile::burst_size`.
fn default_burst_size() -> u32 {
    200
}
/// Default for `RateLimitConfigFile::max_auth_failures`.
fn default_max_auth_failures() -> u32 {
    5
}
/// Default for `RateLimitConfigFile::lockout_duration_seconds`: five minutes.
fn default_lockout_duration() -> u64 {
    5 * 60
}
/// `[logging]` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoggingConfig {
    /// Log level string (presumably a tracing level or filter, to be confirmed
    /// where logging is initialised); defaults to `"info"`.
    #[serde(default = "default_log_level")]
    pub level: String,
    /// JSON output toggle; defaults to `false`.
    #[serde(default)]
    pub json_format: bool,
    /// Span-event toggle; defaults to `true`.
    #[serde(default = "default_true")]
    pub with_span_events: bool,
}
impl Default for LoggingConfig {
fn default() -> Self {
Self {
level: default_log_level(),
json_format: false,
with_span_events: true,
}
}
}
/// Default for `LoggingConfig::level`.
fn default_log_level() -> String {
    String::from("info")
}
/// Errors produced while loading or validating a [`RustbergConfig`].
#[derive(Debug, thiserror::Error)]
pub enum ConfigError {
    /// The configuration file could not be read from disk.
    #[error("Failed to read config file: {0}")]
    ReadError(#[from] std::io::Error),
    /// The contents were not valid TOML for the configuration schema.
    #[error("Failed to parse config file: {0}")]
    ParseError(#[from] toml::de::Error),
    /// The file parsed, but `RustbergConfig::validate` rejected a setting.
    #[error("Invalid configuration: {0}")]
    ValidationError(String),
}
impl RustbergConfig {
pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
let path = path.as_ref();
let content = std::fs::read_to_string(path)?;
let config: RustbergConfig = toml::from_str(&content)?;
config.validate()?;
Ok(config)
}
pub fn parse_str(content: &str) -> Result<Self, ConfigError> {
let config: RustbergConfig = toml::from_str(content)?;
config.validate()?;
Ok(config)
}
pub fn load_or_default<P: AsRef<Path>>(path: P) -> Self {
match Self::from_file(path.as_ref()) {
Ok(config) => {
tracing::info!(path = %path.as_ref().display(), "Loaded configuration from file");
config
}
Err(e) => {
tracing::warn!(
path = %path.as_ref().display(),
error = %e,
"Failed to load config, using defaults"
);
Self::default()
}
}
}
pub fn discover() -> Self {
let search_paths = [
"rustberg.toml",
"/etc/rustberg/config.toml",
"config/rustberg.toml",
];
for path in search_paths {
if Path::new(path).exists() {
if let Ok(config) = Self::from_file(path) {
tracing::info!(path = %path, "Discovered configuration file");
return config;
}
}
}
tracing::debug!("No config file found, using defaults");
Self::default()
}
pub fn validate(&self) -> Result<(), ConfigError> {
if self.tls.enabled {
if self.tls.cert_path.is_none() && !self.tls.insecure_http {
return Err(ConfigError::ValidationError(
"TLS enabled but no cert_path provided and insecure_http is false".to_string(),
));
}
if self.tls.key_path.is_none() && self.tls.cert_path.is_some() {
return Err(ConfigError::ValidationError(
"TLS cert_path provided but no key_path".to_string(),
));
}
}
match self.kms.provider.as_str() {
"env" => { }
"aws-kms" => {
if self.kms.aws_key_id.is_none() {
return Err(ConfigError::ValidationError(
"AWS KMS provider requires aws_key_id".to_string(),
));
}
}
"vault" => {
if self.kms.vault_address.is_none() || self.kms.vault_key_name.is_none() {
return Err(ConfigError::ValidationError(
"Vault provider requires vault_address and vault_key_name".to_string(),
));
}
}
provider => {
return Err(ConfigError::ValidationError(format!(
"Unknown KMS provider: {}",
provider
)));
}
}
if self.rate_limit.enabled && self.rate_limit.requests_per_second == 0 {
return Err(ConfigError::ValidationError(
"requests_per_second must be > 0".to_string(),
));
}
Ok(())
}
pub fn to_toml(&self) -> Result<String, toml::ser::Error> {
toml::to_string_pretty(self)
}
pub fn sample() -> String {
let sample = Self {
server: ServerConfig {
host: "0.0.0.0".to_string(),
port: 8000,
auth: AuthConfig {
api_key_enabled: true,
jwt_enabled: false,
jwt: None,
},
cors: CorsConfig::default(),
},
tls: TlsConfigFile {
enabled: true,
cert_path: Some("/path/to/cert.pem".to_string()),
key_path: Some("/path/to/key.pem".to_string()),
insecure_http: false,
},
storage: StorageConfig {
backend: "file:///var/lib/rustberg/data".to_string(),
warehouse_location: Some("s3://my-bucket/warehouse".to_string()),
aws_region: None,
cache_dir: None,
},
kms: KmsConfigFile {
provider: "env".to_string(),
aws_key_id: None,
aws_region: None,
vault_address: None,
vault_key_name: None,
gcp_project_id: None,
gcp_location: None,
gcp_key_ring: None,
gcp_key_name: None,
azure_vault_url: None,
azure_key_name: None,
cache_ttl_seconds: 300,
circuit_breaker_enabled: true,
},
rate_limit: RateLimitConfigFile::default(),
logging: LoggingConfig::default(),
};
sample.to_toml().unwrap_or_default()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an otherwise-default configuration with the given KMS section.
    fn config_with_kms(kms: KmsConfigFile) -> RustbergConfig {
        RustbergConfig {
            kms,
            ..Default::default()
        }
    }

    /// Asserts that `result` is a `ValidationError` and returns its message.
    /// This stops validation tests from passing on an unrelated error.
    fn expect_validation_error(result: Result<(), ConfigError>) -> String {
        match result {
            Err(ConfigError::ValidationError(msg)) => msg,
            other => panic!("expected ConfigError::ValidationError, got {:?}", other),
        }
    }

    #[test]
    fn test_default_server_config() {
        let config = ServerConfig::default();
        assert_eq!(config.host, "0.0.0.0");
        assert_eq!(config.port, 8080);
        assert!(config.auth.api_key_enabled);
        assert!(!config.auth.jwt_enabled);
        assert!(config.auth.jwt.is_none());
    }

    #[test]
    fn test_default_config_is_valid() {
        RustbergConfig::default()
            .validate()
            .expect("built-in defaults must pass validation");
    }

    #[test]
    fn test_jwt_config_conversion() {
        let jwt_config_serde = JwtConfigSerde {
            issuer: "https://issuer.example.com".to_string(),
            audience: "rustberg-api".to_string(),
            jwks_url: "https://issuer.example.com/.well-known/jwks.json".to_string(),
            default_tenant_id: "test-tenant".to_string(),
            tenant_claim: "custom_tenant".to_string(),
            roles_claim: "custom_roles".to_string(),
            jwks_cache_ttl_seconds: 7200,
        };
        let jwt_config: JwtConfig = jwt_config_serde.into();
        assert_eq!(jwt_config.issuer, "https://issuer.example.com");
        assert_eq!(jwt_config.audience, "rustberg-api");
        assert_eq!(
            jwt_config.jwks_url,
            "https://issuer.example.com/.well-known/jwks.json"
        );
        assert_eq!(jwt_config.default_tenant_id, "test-tenant");
        assert_eq!(jwt_config.tenant_claim, "custom_tenant");
        assert_eq!(jwt_config.roles_claim, "custom_roles");
        assert_eq!(jwt_config.jwks_cache_ttl, Duration::from_secs(7200));
    }

    #[test]
    fn test_server_config_serialization() {
        let config = ServerConfig {
            host: "127.0.0.1".to_string(),
            port: 9000,
            auth: AuthConfig {
                api_key_enabled: true,
                jwt_enabled: true,
                jwt: Some(JwtConfigSerde {
                    issuer: "https://issuer.example.com".to_string(),
                    audience: "rustberg-api".to_string(),
                    jwks_url: "https://issuer.example.com/.well-known/jwks.json".to_string(),
                    default_tenant_id: "default".to_string(),
                    tenant_claim: "tenant_id".to_string(),
                    roles_claim: "roles".to_string(),
                    jwks_cache_ttl_seconds: 3600,
                }),
            },
            cors: CorsConfig::default(),
        };
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: ServerConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.host, "127.0.0.1");
        assert_eq!(deserialized.port, 9000);
        assert!(deserialized.auth.api_key_enabled);
        assert!(deserialized.auth.jwt_enabled);
        let jwt = deserialized
            .auth
            .jwt
            .expect("JWT settings should survive a JSON round-trip");
        assert_eq!(jwt.issuer, "https://issuer.example.com");
        assert_eq!(jwt.audience, "rustberg-api");
        assert_eq!(jwt.jwks_cache_ttl_seconds, 3600);
    }

    #[test]
    fn test_rustberg_config_from_toml() {
        let toml_content = r#"
[server]
host = "127.0.0.1"
port = 9000
[server.auth]
api_key_enabled = true
jwt_enabled = false
[tls]
enabled = false
insecure_http = true
[storage]
backend = "file:///tmp/rustberg"
[kms]
provider = "env"
[rate_limit]
enabled = true
requests_per_second = 50
[logging]
level = "debug"
"#;
        let config = RustbergConfig::parse_str(toml_content).unwrap();
        assert_eq!(config.server.host, "127.0.0.1");
        assert_eq!(config.server.port, 9000);
        assert!(!config.tls.enabled);
        assert!(config.storage.backend.starts_with("file://"));
        assert_eq!(config.kms.provider, "env");
        assert_eq!(config.rate_limit.requests_per_second, 50);
        assert_eq!(config.logging.level, "debug");
    }

    #[test]
    fn test_rustberg_config_validation_tls() {
        let config = RustbergConfig {
            tls: TlsConfigFile {
                enabled: true,
                cert_path: Some("/path/to/cert".to_string()),
                key_path: None,
                insecure_http: false,
            },
            ..Default::default()
        };
        let msg = expect_validation_error(config.validate());
        assert!(msg.contains("key_path"), "unexpected message: {msg}");
    }

    #[test]
    fn test_rustberg_config_validation_kms() {
        let config = config_with_kms(KmsConfigFile {
            provider: "aws-kms".to_string(),
            aws_key_id: None,
            ..Default::default()
        });
        let msg = expect_validation_error(config.validate());
        assert!(msg.contains("aws_key_id"), "unexpected message: {msg}");
    }

    #[test]
    fn test_rustberg_config_validation_vault() {
        let config = config_with_kms(KmsConfigFile {
            provider: "vault".to_string(),
            vault_address: Some("https://vault.example.com".to_string()),
            vault_key_name: None,
            ..Default::default()
        });
        let msg = expect_validation_error(config.validate());
        assert!(msg.contains("vault_key_name"), "unexpected message: {msg}");
    }

    #[test]
    fn test_rustberg_config_validation_unknown_kms_provider() {
        let config = config_with_kms(KmsConfigFile {
            provider: "not-a-provider".to_string(),
            ..Default::default()
        });
        let msg = expect_validation_error(config.validate());
        assert!(msg.contains("not-a-provider"), "unexpected message: {msg}");
    }

    #[test]
    fn test_rustberg_config_sample() {
        let sample = RustbergConfig::sample();
        assert!(sample.contains("[server]"));
        assert!(sample.contains("[tls]"));
        assert!(sample.contains("[storage]"));
        assert!(sample.contains("[kms]"));
        // The generated sample must be usable as-is.
        RustbergConfig::parse_str(&sample).expect("sample configuration should parse and validate");
    }

    #[test]
    fn test_rustberg_config_roundtrip() {
        let config = RustbergConfig::default();
        let toml_str = config.to_toml().unwrap();
        let parsed = RustbergConfig::parse_str(&toml_str).unwrap();
        assert_eq!(config.server.host, parsed.server.host);
        assert_eq!(config.server.port, parsed.server.port);
        assert_eq!(config.kms.provider, parsed.kms.provider);
        assert_eq!(config.tls.insecure_http, parsed.tls.insecure_http);
        assert_eq!(
            config.rate_limit.requests_per_second,
            parsed.rate_limit.requests_per_second
        );
    }
}