use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use std::time::Duration;
use thiserror::Error;
/// Errors that can occur while loading, saving, or validating a server configuration.
#[derive(Error, Debug)]
pub enum ConfigError {
/// An I/O operation on the configuration file failed. Produced via `From<std::io::Error>`,
/// so it is also what `ServerConfig::save_to_file` returns when the write fails.
#[error("Failed to read configuration file: {0}")]
ReadFile(#[from] std::io::Error),
/// The file contents could not be deserialized from TOML.
#[error("Failed to parse TOML: {0}")]
ParseToml(#[from] toml::de::Error),
/// A semantic check failed; the payload is a human-readable description.
/// `ServerConfig::save_to_file` also reports serialization failures through this variant.
#[error("Validation error: {0}")]
Validation(String),
/// A socket address failed to parse. Available through `From<AddrParseError>`.
/// NOTE(review): `ServerConfig::validate` in this file reports address problems as
/// `Validation` instead, so nothing visible here constructs this variant.
#[error("Invalid socket address: {0}")]
InvalidAddress(#[from] std::net::AddrParseError),
}
/// Convenience alias for results whose error type is [`ConfigError`].
pub type ConfigResult<T> = Result<T, ConfigError>;
/// Root of the server configuration, deserialized from (and serialized to) TOML.
///
/// The `server`, `storage`, `network`, `logging` and `metrics` sections carry no
/// `#[serde(default)]`, so they must be present in the input. `cluster`, `auth` and
/// `authz` may be omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
/// Listener address, data directory, and process-level limits.
pub server: ServerSettings,
/// Storage engine selection and tuning.
pub storage: StorageSettings,
/// TLS and connection timing settings.
pub network: NetworkSettings,
/// Optional cluster section; `None` when the section is absent.
#[serde(default)]
pub cluster: Option<ClusterSettings>,
/// Log level, format, and file output settings.
pub logging: LoggingSettings,
/// Metrics endpoint settings.
pub metrics: MetricsSettings,
/// Authentication settings; falls back to `AuthSettings::default()` when absent.
#[serde(default)]
pub auth: AuthSettings,
/// Authorization settings; falls back to `AuthorizationSettings::default()` when absent.
#[serde(default)]
pub authz: AuthorizationSettings,
}
/// Core process settings for the server.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerSettings {
/// Address to listen on (required). `ServerConfig::validate` requires it to parse as a
/// `SocketAddr`, i.e. an IP address and port rather than a host name.
pub bind_address: String,
/// Data directory (required). `ServerConfig::validate` rejects an empty path.
pub data_dir: PathBuf,
/// PID file path. Defaults to `/var/run/amaters-server.pid`.
#[serde(default = "default_pid_file")]
pub pid_file: PathBuf,
/// Connection limit. Defaults to 1000.
#[serde(default = "default_max_connections")]
pub max_connections: usize,
/// Shutdown timeout in seconds. Defaults to 30; exposed as a `Duration` by
/// `ServerConfig::shutdown_timeout`.
#[serde(default = "default_shutdown_timeout")]
pub shutdown_timeout_secs: u64,
}
/// Storage engine selection and tuning. Every field has a default, but the enclosing
/// section itself is required by `ServerConfig`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageSettings {
/// Engine name. Defaults to `"lsm"`; `ServerConfig::validate` accepts only `"memory"`
/// or `"lsm"`.
#[serde(default = "default_storage_engine")]
pub engine: String,
/// Write-ahead-log settings; falls back to `WalSettings::default()`.
#[serde(default)]
pub wal: WalSettings,
/// Memtable size (presumably megabytes, per the `_mb` suffix). Defaults to 64.
#[serde(default = "default_memtable_size")]
pub memtable_size_mb: usize,
/// Block cache size (presumably megabytes, per the `_mb` suffix). Defaults to 256.
#[serde(default = "default_block_cache_size")]
pub block_cache_size_mb: usize,
/// Compaction settings; falls back to `CompactionSettings::default()`.
#[serde(default)]
pub compaction: CompactionSettings,
}
/// Write-ahead-log settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WalSettings {
/// Whether the WAL is enabled. Defaults to `true`.
#[serde(default = "default_true")]
pub enabled: bool,
/// WAL directory. Defaults to the relative path `wal`.
/// NOTE(review): what the relative path is resolved against is not visible in this file.
#[serde(default = "default_wal_dir")]
pub dir: PathBuf,
/// Segment size (presumably megabytes, per the `_mb` suffix). Defaults to 64.
#[serde(default = "default_wal_segment_size")]
pub segment_size_mb: usize,
/// Sync mode name. Defaults to `"interval"`; the accepted values are not validated
/// in this file.
#[serde(default = "default_sync_mode")]
pub sync_mode: String,
}
/// Compaction settings. None of these values are checked by `ServerConfig::validate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompactionSettings {
/// Strategy name. Defaults to `"leveled"`.
#[serde(default = "default_compaction_strategy")]
pub strategy: String,
/// Number of levels. Defaults to 7.
#[serde(default = "default_num_levels")]
pub num_levels: usize,
/// Level multiplier. Defaults to 10.
#[serde(default = "default_level_multiplier")]
pub level_multiplier: usize,
/// Maximum concurrent compactions. Defaults to 4.
#[serde(default = "default_max_compactions")]
pub max_concurrent: usize,
}
/// TLS and connection timing settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkSettings {
/// Whether TLS is enabled. Defaults to `false`. Can be overridden by the
/// `AMATERS_TLS_ENABLED` environment variable in `ServerConfig::apply_env_overrides`.
#[serde(default = "default_false")]
pub tls_enabled: bool,
/// Certificate file; `ServerConfig::validate` requires it when `tls_enabled` is set.
pub tls_cert: Option<PathBuf>,
/// Private key file; `ServerConfig::validate` requires it when `tls_enabled` is set.
pub tls_key: Option<PathBuf>,
/// CA file; required by `ServerConfig::validate` when TLS is enabled and
/// `require_client_cert` is set.
pub tls_ca: Option<PathBuf>,
/// Whether clients must present a certificate. Defaults to `false`.
#[serde(default = "default_false")]
pub require_client_cert: bool,
/// Connection timeout in seconds. Defaults to 30; exposed as a `Duration` by
/// `ServerConfig::connection_timeout`.
#[serde(default = "default_connection_timeout")]
pub connection_timeout_secs: u64,
/// Keepalive interval in seconds. Defaults to 60; exposed as a `Duration` by
/// `ServerConfig::keepalive_interval`.
#[serde(default = "default_keepalive_interval")]
pub keepalive_interval_secs: u64,
}
/// Cluster settings, present only when the configuration contains a cluster section.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterSettings {
/// Whether clustering is enabled. Defaults to `true` once the section is present.
#[serde(default = "default_true")]
pub enabled: bool,
/// Identifier of this node (required).
pub node_id: u64,
/// Peer list (required). `ServerConfig::validate` rejects an empty list while
/// `enabled` is set. NOTE(review): the entry format is not validated in this file.
pub peers: Vec<String>,
/// Heartbeat interval in milliseconds. Defaults to 100.
#[serde(default = "default_heartbeat_interval")]
pub heartbeat_interval_ms: u64,
/// Election timeout in milliseconds. Defaults to 300.
#[serde(default = "default_election_timeout")]
pub election_timeout_ms: u64,
}
/// Logging settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoggingSettings {
/// Log level. Defaults to `"info"`. `ServerConfig::validate` accepts `trace`, `debug`,
/// `info`, `warn` and `error`, compared case-insensitively. Can be overridden by the
/// `AMATERS_LOG_LEVEL` environment variable.
#[serde(default = "default_log_level")]
pub level: String,
/// Output format name. Defaults to `"pretty"`; not validated in this file.
#[serde(default = "default_log_format")]
pub format: String,
/// Whether file logging is enabled. Defaults to `false`.
#[serde(default = "default_false")]
pub file_enabled: bool,
/// Log file path. NOTE(review): `validate` does not require it even when
/// `file_enabled` is set.
pub file_path: Option<PathBuf>,
/// Rotation settings; falls back to `LogRotationSettings::default()`.
#[serde(default)]
pub rotation: LogRotationSettings,
}
/// Log rotation settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogRotationSettings {
/// Whether rotation is enabled. Defaults to `true`.
#[serde(default = "default_true")]
pub enabled: bool,
/// Maximum log size (presumably megabytes, per the `_mb` suffix). Defaults to 100.
#[serde(default = "default_log_max_size")]
pub max_size_mb: usize,
/// Maximum number of backups. Defaults to 10.
#[serde(default = "default_log_max_backups")]
pub max_backups: usize,
}
/// Metrics endpoint settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricsSettings {
/// Whether metrics are enabled. Defaults to `true`.
#[serde(default = "default_true")]
pub enabled: bool,
/// Metrics address. Defaults to `127.0.0.1:9090`. `ServerConfig::validate` requires it
/// to parse as a `SocketAddr` regardless of `enabled`.
#[serde(default = "default_metrics_address")]
pub bind_address: String,
/// Export interval in seconds. Defaults to 60.
#[serde(default = "default_metrics_interval")]
pub export_interval_secs: u64,
}
/// Authentication settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthSettings {
/// Master switch. Defaults to `false`; the checks in `ServerConfig::validate` for the
/// fields below only run when this is `true`.
#[serde(default = "default_false")]
pub enabled: bool,
/// Names of the active methods. Defaults to `["mtls"]`. `ServerConfig::validate`
/// requires at least one of `"mtls"`, `"jwt"`, `"api_key"` to be both listed here and
/// enabled in its own section.
#[serde(default = "default_auth_methods")]
pub methods: Vec<String>,
/// Mutual-TLS settings; falls back to `MtlsSettings::default()`.
#[serde(default)]
pub mtls: MtlsSettings,
/// JWT settings; falls back to `JwtSettings::default()`.
#[serde(default)]
pub jwt: JwtSettings,
/// API-key settings; falls back to `ApiKeySettings::default()`.
#[serde(default)]
pub api_key: ApiKeySettings,
/// Whether unauthenticated requests are rejected. Defaults to `true`.
#[serde(default = "default_true")]
pub reject_unauthenticated: bool,
}
/// Mutual-TLS authentication settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MtlsSettings {
/// Whether mTLS authentication is enabled. Defaults to `false`.
#[serde(default = "default_false")]
pub enabled: bool,
/// Directory of CA certificates; `ServerConfig::validate` requires it when both
/// authentication and mTLS are enabled.
pub ca_certs_dir: Option<PathBuf>,
/// Optional certificate revocation list path.
pub crl_path: Option<PathBuf>,
/// Whether the certificate common name is verified. Defaults to `true`.
#[serde(default = "default_true")]
pub verify_cn: bool,
/// Allowed organizations. Defaults to an empty list.
/// NOTE(review): whether an empty list means "allow all" is not visible in this file.
#[serde(default)]
pub allowed_organizations: Vec<String>,
}
/// JSON Web Token authentication settings.
///
/// `Debug` is implemented by hand (below) so that the signing secret is never written
/// to logs when this struct, or anything containing it, is debug-formatted.
#[derive(Clone, Serialize, Deserialize)]
pub struct JwtSettings {
    /// Whether JWT authentication is enabled. Defaults to `false`.
    #[serde(default = "default_false")]
    pub enabled: bool,
    /// Shared secret; `ServerConfig::validate` requires it for `HS256` when both
    /// authentication and JWT are enabled.
    pub secret: Option<String>,
    /// Public key path; `ServerConfig::validate` requires it for `RS256` when both
    /// authentication and JWT are enabled.
    pub public_key_path: Option<PathBuf>,
    /// Signing algorithm. Defaults to `"HS256"`; `ServerConfig::validate` accepts only
    /// `"HS256"` and `"RS256"`.
    #[serde(default = "default_jwt_algorithm")]
    pub algorithm: String,
    /// Token expiration in seconds. Defaults to 3600.
    #[serde(default = "default_jwt_expiration")]
    pub expiration_secs: u64,
    /// Optional issuer; not checked by `ServerConfig::validate`.
    pub issuer: Option<String>,
    /// Optional audience; not checked by `ServerConfig::validate`.
    pub audience: Option<String>,
}

impl std::fmt::Debug for JwtSettings {
    /// Formats every field like the derived implementation would, except that a
    /// configured `secret` is shown as `Some("<redacted>")`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep the Some/None distinction visible without revealing the value.
        let secret = self.secret.as_ref().map(|_| "<redacted>");
        f.debug_struct("JwtSettings")
            .field("enabled", &self.enabled)
            .field("secret", &secret)
            .field("public_key_path", &self.public_key_path)
            .field("algorithm", &self.algorithm)
            .field("expiration_secs", &self.expiration_secs)
            .field("issuer", &self.issuer)
            .field("audience", &self.audience)
            .finish()
    }
}
/// API-key authentication settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiKeySettings {
/// Whether API-key authentication is enabled. Defaults to `false`.
#[serde(default = "default_false")]
pub enabled: bool,
/// Keys file; `ServerConfig::validate` requires it when both authentication and
/// API-key auth are enabled.
pub keys_file: Option<PathBuf>,
/// Header name carrying the key. Defaults to `X-API-Key`.
#[serde(default = "default_api_key_header")]
pub header_name: String,
/// Whether keys are hashed. Defaults to `true`.
/// NOTE(review): where the hashing happens is not visible in this file.
#[serde(default = "default_true")]
pub hash_keys: bool,
}
/// Authorization settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthorizationSettings {
/// Whether authorization is enabled. Defaults to `true`.
#[serde(default = "default_true")]
pub enabled: bool,
/// Default role name. Defaults to `"user"`.
#[serde(default = "default_user_role")]
pub default_role: String,
/// Optional roles file.
pub roles_file: Option<PathBuf>,
/// Optional policies file.
pub policies_file: Option<PathBuf>,
/// Whether collection permissions are enabled. Defaults to `true`.
#[serde(default = "default_true")]
pub collection_permissions: bool,
/// Default permission mode. Defaults to `"deny-by-default"`. While `enabled` is set,
/// `ServerConfig::validate` accepts only `"deny-by-default"` or `"allow-by-default"`.
#[serde(default = "default_permission_mode")]
pub default_mode: String,
/// Whether auditing is enabled. Defaults to `true`.
#[serde(default = "default_true")]
pub audit_enabled: bool,
/// Optional audit log path.
pub audit_log_path: Option<PathBuf>,
}
/// Serde default for `ServerSettings::pid_file`.
fn default_pid_file() -> PathBuf {
    "/var/run/amaters-server.pid".into()
}

/// Serde default for `ServerSettings::max_connections`.
fn default_max_connections() -> usize {
    const MAX_CONNECTIONS: usize = 1_000;
    MAX_CONNECTIONS
}

/// Serde default for `ServerSettings::shutdown_timeout_secs`.
fn default_shutdown_timeout() -> u64 {
    const SHUTDOWN_TIMEOUT_SECS: u64 = 30;
    SHUTDOWN_TIMEOUT_SECS
}
/// Serde default for `StorageSettings::engine`.
fn default_storage_engine() -> String {
    String::from("lsm")
}

/// Serde default for `StorageSettings::memtable_size_mb`.
fn default_memtable_size() -> usize {
    const MEMTABLE_SIZE_MB: usize = 64;
    MEMTABLE_SIZE_MB
}

/// Serde default for `StorageSettings::block_cache_size_mb`.
fn default_block_cache_size() -> usize {
    const BLOCK_CACHE_SIZE_MB: usize = 256;
    BLOCK_CACHE_SIZE_MB
}

/// Serde default for `WalSettings::dir`.
fn default_wal_dir() -> PathBuf {
    "wal".into()
}

/// Serde default for `WalSettings::segment_size_mb`.
fn default_wal_segment_size() -> usize {
    const WAL_SEGMENT_SIZE_MB: usize = 64;
    WAL_SEGMENT_SIZE_MB
}

/// Serde default for `WalSettings::sync_mode`.
fn default_sync_mode() -> String {
    String::from("interval")
}

/// Serde default for `CompactionSettings::strategy`.
fn default_compaction_strategy() -> String {
    String::from("leveled")
}

/// Serde default for `CompactionSettings::num_levels`.
fn default_num_levels() -> usize {
    const NUM_LEVELS: usize = 7;
    NUM_LEVELS
}

/// Serde default for `CompactionSettings::level_multiplier`.
fn default_level_multiplier() -> usize {
    const LEVEL_MULTIPLIER: usize = 10;
    LEVEL_MULTIPLIER
}

/// Serde default for `CompactionSettings::max_concurrent`.
fn default_max_compactions() -> usize {
    const MAX_CONCURRENT_COMPACTIONS: usize = 4;
    MAX_CONCURRENT_COMPACTIONS
}
/// Serde default for `NetworkSettings::connection_timeout_secs`.
fn default_connection_timeout() -> u64 {
    const CONNECTION_TIMEOUT_SECS: u64 = 30;
    CONNECTION_TIMEOUT_SECS
}

/// Serde default for `NetworkSettings::keepalive_interval_secs`.
fn default_keepalive_interval() -> u64 {
    const KEEPALIVE_INTERVAL_SECS: u64 = 60;
    KEEPALIVE_INTERVAL_SECS
}

/// Serde default for `ClusterSettings::heartbeat_interval_ms`.
fn default_heartbeat_interval() -> u64 {
    const HEARTBEAT_INTERVAL_MS: u64 = 100;
    HEARTBEAT_INTERVAL_MS
}

/// Serde default for `ClusterSettings::election_timeout_ms`.
fn default_election_timeout() -> u64 {
    const ELECTION_TIMEOUT_MS: u64 = 300;
    ELECTION_TIMEOUT_MS
}
/// Serde default for `LoggingSettings::level`.
fn default_log_level() -> String {
    String::from("info")
}

/// Serde default for `LoggingSettings::format`.
fn default_log_format() -> String {
    String::from("pretty")
}

/// Serde default for `LogRotationSettings::max_size_mb`.
fn default_log_max_size() -> usize {
    const LOG_MAX_SIZE_MB: usize = 100;
    LOG_MAX_SIZE_MB
}

/// Serde default for `LogRotationSettings::max_backups`.
fn default_log_max_backups() -> usize {
    const LOG_MAX_BACKUPS: usize = 10;
    LOG_MAX_BACKUPS
}

/// Serde default for `MetricsSettings::bind_address`.
fn default_metrics_address() -> String {
    String::from("127.0.0.1:9090")
}

/// Serde default for `MetricsSettings::export_interval_secs`.
fn default_metrics_interval() -> u64 {
    const METRICS_INTERVAL_SECS: u64 = 60;
    METRICS_INTERVAL_SECS
}
/// Serde default for boolean fields that are on unless configured otherwise.
fn default_true() -> bool {
    !default_false()
}

/// Serde default for boolean fields that are off unless configured otherwise.
fn default_false() -> bool {
    bool::default()
}
/// Serde default for `AuthSettings::methods`: a single `"mtls"` entry.
fn default_auth_methods() -> Vec<String> {
    vec![String::from("mtls")]
}

/// Serde default for `JwtSettings::algorithm`.
fn default_jwt_algorithm() -> String {
    String::from("HS256")
}

/// Serde default for `JwtSettings::expiration_secs`: one hour.
fn default_jwt_expiration() -> u64 {
    const SECS_PER_HOUR: u64 = 60 * 60;
    SECS_PER_HOUR
}

/// Serde default for `ApiKeySettings::header_name`.
fn default_api_key_header() -> String {
    String::from("X-API-Key")
}

/// Serde default for `AuthorizationSettings::default_role`.
fn default_user_role() -> String {
    String::from("user")
}

/// Serde default for `AuthorizationSettings::default_mode`.
fn default_permission_mode() -> String {
    String::from("deny-by-default")
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
server: ServerSettings {
bind_address: "0.0.0.0:7878".to_string(),
data_dir: PathBuf::from("./data"),
pid_file: default_pid_file(),
max_connections: default_max_connections(),
shutdown_timeout_secs: default_shutdown_timeout(),
},
storage: StorageSettings {
engine: default_storage_engine(),
wal: WalSettings::default(),
memtable_size_mb: default_memtable_size(),
block_cache_size_mb: default_block_cache_size(),
compaction: CompactionSettings::default(),
},
network: NetworkSettings {
tls_enabled: false,
tls_cert: None,
tls_key: None,
tls_ca: None,
require_client_cert: false,
connection_timeout_secs: default_connection_timeout(),
keepalive_interval_secs: default_keepalive_interval(),
},
cluster: None,
logging: LoggingSettings {
level: default_log_level(),
format: default_log_format(),
file_enabled: false,
file_path: None,
rotation: LogRotationSettings::default(),
},
metrics: MetricsSettings {
enabled: true,
bind_address: default_metrics_address(),
export_interval_secs: default_metrics_interval(),
},
auth: AuthSettings::default(),
authz: AuthorizationSettings::default(),
}
}
}
impl Default for WalSettings {
fn default() -> Self {
Self {
enabled: true,
dir: default_wal_dir(),
segment_size_mb: default_wal_segment_size(),
sync_mode: default_sync_mode(),
}
}
}
impl Default for CompactionSettings {
    /// Uses the same per-field values serde uses for missing fields.
    fn default() -> Self {
        let strategy = default_compaction_strategy();
        let num_levels = default_num_levels();
        let level_multiplier = default_level_multiplier();
        let max_concurrent = default_max_compactions();
        Self {
            strategy,
            num_levels,
            level_multiplier,
            max_concurrent,
        }
    }
}
impl Default for LogRotationSettings {
fn default() -> Self {
Self {
enabled: true,
max_size_mb: default_log_max_size(),
max_backups: default_log_max_backups(),
}
}
}
impl Default for AuthSettings {
fn default() -> Self {
Self {
enabled: false,
methods: default_auth_methods(),
mtls: MtlsSettings::default(),
jwt: JwtSettings::default(),
api_key: ApiKeySettings::default(),
reject_unauthenticated: true,
}
}
}
impl Default for MtlsSettings {
fn default() -> Self {
Self {
enabled: false,
ca_certs_dir: None,
crl_path: None,
verify_cn: true,
allowed_organizations: Vec::new(),
}
}
}
impl Default for JwtSettings {
fn default() -> Self {
Self {
enabled: false,
secret: None,
public_key_path: None,
algorithm: default_jwt_algorithm(),
expiration_secs: default_jwt_expiration(),
issuer: None,
audience: None,
}
}
}
impl Default for ApiKeySettings {
fn default() -> Self {
Self {
enabled: false,
keys_file: None,
header_name: default_api_key_header(),
hash_keys: true,
}
}
}
impl Default for AuthorizationSettings {
fn default() -> Self {
Self {
enabled: true,
default_role: default_user_role(),
roles_file: None,
policies_file: None,
collection_permissions: true,
default_mode: default_permission_mode(),
audit_enabled: true,
audit_log_path: None,
}
}
}
impl ServerConfig {
    /// Loads a configuration from the TOML file at `path` and validates it.
    ///
    /// # Errors
    /// - [`ConfigError::ReadFile`] if the file cannot be read.
    /// - [`ConfigError::ParseToml`] if the contents do not deserialize into this schema.
    /// - [`ConfigError::Validation`] if [`ServerConfig::validate`] rejects the values.
    pub fn from_file(path: impl AsRef<Path>) -> ConfigResult<Self> {
        let config = Self::parse_file(path.as_ref())?;
        config.validate()?;
        Ok(config)
    }

    /// Loads a configuration from `path`, applies environment overrides, then validates.
    ///
    /// Validation runs once, on the effective configuration (file values with the
    /// overrides applied), so an override may repair a value that would be invalid in
    /// the file alone.
    ///
    /// # Errors
    /// Same as [`ServerConfig::from_file`].
    pub fn from_file_with_env(path: impl AsRef<Path>) -> ConfigResult<Self> {
        let mut config = Self::parse_file(path.as_ref())?;
        config.apply_env_overrides();
        config.validate()?;
        Ok(config)
    }

    /// Overrides selected settings from environment variables, when they are set:
    ///
    /// - `AMATERS_BIND_ADDRESS` replaces `server.bind_address`
    /// - `AMATERS_DATA_DIR` replaces `server.data_dir`
    /// - `AMATERS_LOG_LEVEL` replaces `logging.level`
    /// - `AMATERS_TLS_ENABLED` replaces `network.tls_enabled`
    ///
    /// `AMATERS_TLS_ENABLED` accepts `true`/`false`, `1`/`0`, `yes`/`no`, `on`/`off`
    /// (case-insensitive, surrounding whitespace ignored). An unrecognised value leaves
    /// the current setting untouched, so a typo can never silently turn TLS off.
    ///
    /// This method does not validate; call [`ServerConfig::validate`] afterwards.
    pub fn apply_env_overrides(&mut self) {
        if let Ok(bind) = std::env::var("AMATERS_BIND_ADDRESS") {
            self.server.bind_address = bind;
        }
        if let Ok(data_dir) = std::env::var("AMATERS_DATA_DIR") {
            self.server.data_dir = PathBuf::from(data_dir);
        }
        if let Ok(log_level) = std::env::var("AMATERS_LOG_LEVEL") {
            self.logging.level = log_level;
        }
        if let Some(tls_enabled) = std::env::var("AMATERS_TLS_ENABLED")
            .ok()
            .as_deref()
            .and_then(Self::parse_env_bool)
        {
            self.network.tls_enabled = tls_enabled;
        }
    }

    /// Checks the configuration for semantic errors, stopping at the first problem.
    ///
    /// Sections are checked in this order: server, storage, network (TLS), cluster,
    /// logging, metrics, authentication, authorization.
    ///
    /// # Errors
    /// Returns [`ConfigError::Validation`] describing the first failed check.
    pub fn validate(&self) -> ConfigResult<()> {
        self.validate_server()?;
        self.validate_storage()?;
        self.validate_network()?;
        self.validate_cluster()?;
        self.validate_logging()?;
        self.validate_metrics()?;
        self.validate_auth()?;
        self.validate_authz()
    }

    /// Returns `server.shutdown_timeout_secs` as a [`Duration`].
    pub fn shutdown_timeout(&self) -> Duration {
        Duration::from_secs(self.server.shutdown_timeout_secs)
    }

    /// Returns `network.connection_timeout_secs` as a [`Duration`].
    pub fn connection_timeout(&self) -> Duration {
        Duration::from_secs(self.network.connection_timeout_secs)
    }

    /// Returns `network.keepalive_interval_secs` as a [`Duration`].
    pub fn keepalive_interval(&self) -> Duration {
        Duration::from_secs(self.network.keepalive_interval_secs)
    }

    /// Serializes the configuration as pretty-printed TOML and writes it to `path`.
    ///
    /// The configuration is written as-is; it is not validated first.
    ///
    /// # Errors
    /// - [`ConfigError::Validation`] if serialization fails.
    /// - [`ConfigError::ReadFile`] if the file cannot be written (the variant wraps any
    ///   `std::io::Error`).
    pub fn save_to_file(&self, path: impl AsRef<Path>) -> ConfigResult<()> {
        let contents = toml::to_string_pretty(self)
            .map_err(|e| ConfigError::Validation(format!("Failed to serialize config: {}", e)))?;
        std::fs::write(path, contents)?;
        Ok(())
    }

    /// Returns an example configuration; currently identical to `ServerConfig::default()`.
    pub fn example() -> Self {
        Self::default()
    }

    /// Reads and deserializes the TOML file at `path` without validating it.
    fn parse_file(path: &Path) -> ConfigResult<Self> {
        let contents = std::fs::read_to_string(path)?;
        let config: Self = toml::from_str(&contents)?;
        Ok(config)
    }

    /// Interprets an environment-variable value as a boolean flag.
    ///
    /// Returns `None` for anything that is not a recognised spelling.
    fn parse_env_bool(raw: &str) -> Option<bool> {
        match raw.trim().to_ascii_lowercase().as_str() {
            "true" | "1" | "yes" | "on" => Some(true),
            "false" | "0" | "no" | "off" => Some(false),
            _ => None,
        }
    }

    /// Shorthand for failing a check with a [`ConfigError::Validation`] message.
    fn invalid(message: impl Into<String>) -> ConfigResult<()> {
        Err(ConfigError::Validation(message.into()))
    }

    /// The bind address must be a socket address and the data directory non-empty.
    fn validate_server(&self) -> ConfigResult<()> {
        self.server
            .bind_address
            .parse::<SocketAddr>()
            .map_err(|e| ConfigError::Validation(format!("Invalid bind address: {}", e)))?;
        if self.server.data_dir.as_os_str().is_empty() {
            return Self::invalid("Data directory cannot be empty");
        }
        Ok(())
    }

    /// The storage engine must be one of the known engine names.
    fn validate_storage(&self) -> ConfigResult<()> {
        match self.storage.engine.as_str() {
            "memory" | "lsm" => Ok(()),
            other => Self::invalid(format!(
                "Invalid storage engine: {}. Must be 'memory' or 'lsm'",
                other
            )),
        }
    }

    /// With TLS enabled, a certificate and key are required, plus a CA file when
    /// client certificates are required.
    fn validate_network(&self) -> ConfigResult<()> {
        let network = &self.network;
        if !network.tls_enabled {
            return Ok(());
        }
        if network.tls_cert.is_none() {
            return Self::invalid("TLS enabled but no certificate file specified");
        }
        if network.tls_key.is_none() {
            return Self::invalid("TLS enabled but no key file specified");
        }
        if network.require_client_cert && network.tls_ca.is_none() {
            return Self::invalid("Client certificate required but no CA file specified");
        }
        Ok(())
    }

    /// An enabled cluster section must list at least one peer.
    fn validate_cluster(&self) -> ConfigResult<()> {
        match &self.cluster {
            Some(cluster) if cluster.enabled && cluster.peers.is_empty() => {
                Self::invalid("Cluster enabled but no peers specified")
            }
            _ => Ok(()),
        }
    }

    /// The log level must be a known level name (compared case-insensitively).
    fn validate_logging(&self) -> ConfigResult<()> {
        let level = self.logging.level.to_lowercase();
        if matches!(
            level.as_str(),
            "trace" | "debug" | "info" | "warn" | "error"
        ) {
            Ok(())
        } else {
            Self::invalid(format!(
                "Invalid log level: {}. Must be one of: trace, debug, info, warn, error",
                level
            ))
        }
    }

    /// The metrics address must be a socket address, whether or not metrics are enabled.
    fn validate_metrics(&self) -> ConfigResult<()> {
        self.metrics
            .bind_address
            .parse::<SocketAddr>()
            .map_err(|e| ConfigError::Validation(format!("Invalid metrics address: {}", e)))?;
        Ok(())
    }

    /// With authentication enabled, at least one method must be both listed and
    /// enabled, and each enabled method must have the material it needs.
    fn validate_auth(&self) -> ConfigResult<()> {
        let auth = &self.auth;
        if !auth.enabled {
            return Ok(());
        }

        // Compare against the borrowed names directly instead of allocating a String
        // for every lookup.
        let listed = |name: &str| auth.methods.iter().any(|m| m.as_str() == name);
        let has_enabled_method = (auth.mtls.enabled && listed("mtls"))
            || (auth.jwt.enabled && listed("jwt"))
            || (auth.api_key.enabled && listed("api_key"));
        if !has_enabled_method {
            return Self::invalid("Authentication enabled but no valid auth methods configured");
        }

        if auth.jwt.enabled {
            match auth.jwt.algorithm.as_str() {
                "HS256" if auth.jwt.secret.is_none() => {
                    return Self::invalid("JWT HS256 enabled but no secret key provided");
                }
                "RS256" if auth.jwt.public_key_path.is_none() => {
                    return Self::invalid("JWT RS256 enabled but no public key path provided");
                }
                "HS256" | "RS256" => {}
                other => {
                    return Self::invalid(format!(
                        "Invalid JWT algorithm: {}. Supported: HS256, RS256",
                        other
                    ));
                }
            }
        }

        if auth.api_key.enabled && auth.api_key.keys_file.is_none() {
            return Self::invalid("API key auth enabled but no keys file specified");
        }
        if auth.mtls.enabled && auth.mtls.ca_certs_dir.is_none() {
            return Self::invalid("mTLS enabled but no CA certificates directory specified");
        }
        Ok(())
    }

    /// With authorization enabled, the default mode must be a known mode name.
    fn validate_authz(&self) -> ConfigResult<()> {
        if !self.authz.enabled {
            return Ok(());
        }
        match self.authz.default_mode.as_str() {
            "deny-by-default" | "allow-by-default" => Ok(()),
            other => Self::invalid(format!(
                "Invalid authorization default mode: {}. Must be 'deny-by-default' or 'allow-by-default'",
                other
            )),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::sync::{Mutex, MutexGuard};

    /// Serialises every test in this module that reads or writes the process
    /// environment, since the test harness runs tests on multiple threads.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Takes [`ENV_LOCK`], ignoring poisoning: a failed test must not cascade into
    /// unrelated failures.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn test_default_config() {
        let config = ServerConfig::default();
        assert_eq!(config.server.bind_address, "0.0.0.0:7878");
        assert_eq!(config.storage.engine, "lsm");
        assert_eq!(config.logging.level, "info");
    }

    #[test]
    fn test_config_validation() {
        let config = ServerConfig::default();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_invalid_bind_address() {
        let mut config = ServerConfig::default();
        config.server.bind_address = "invalid".to_string();
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_invalid_storage_engine() {
        let mut config = ServerConfig::default();
        config.storage.engine = "invalid".to_string();
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_tls_validation() {
        let mut config = ServerConfig::default();
        config.network.tls_enabled = true;
        // TLS is on but neither a certificate nor a key is configured.
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_env_overrides() {
        let _env_guard = lock_env();

        // SAFETY: mutating the environment is only sound while no other thread accesses
        // it; every test in this module that touches the environment holds `ENV_LOCK`.
        unsafe {
            env::set_var("AMATERS_BIND_ADDRESS", "127.0.0.1:9999");
            env::set_var("AMATERS_LOG_LEVEL", "debug");
        }

        let mut config = ServerConfig::default();
        config.apply_env_overrides();

        // Clean up before asserting so a failed assertion cannot leak the variables.
        // SAFETY: same invariant as above; `ENV_LOCK` is still held.
        unsafe {
            env::remove_var("AMATERS_BIND_ADDRESS");
            env::remove_var("AMATERS_LOG_LEVEL");
        }

        assert_eq!(config.server.bind_address, "127.0.0.1:9999");
        assert_eq!(config.logging.level, "debug");
    }

    #[test]
    fn test_save_and_load() {
        // `temp_dir` may read the environment, so take the lock for that call.
        let temp_dir = {
            let _env_guard = lock_env();
            env::temp_dir()
        };
        // Include the process id so concurrent test runs on one machine do not
        // overwrite each other's file.
        let config_path = temp_dir.join(format!(
            "amaters_test_config_{}.toml",
            std::process::id()
        ));

        let config = ServerConfig::default();
        config
            .save_to_file(&config_path)
            .expect("Failed to save config");

        // Remove the file before unwrapping so it is cleaned up even if loading failed.
        let loaded = ServerConfig::from_file(&config_path);
        std::fs::remove_file(&config_path).ok();
        let loaded = loaded.expect("Failed to load config");

        assert_eq!(config.server.bind_address, loaded.server.bind_address);
    }
}