use serde::{Deserialize, Serialize};
use std::path::PathBuf;
/// Serde default for `NodeConfig::heartbeat_interval_ms`: 150 ms.
fn default_heartbeat_ms() -> u64 {
    const HEARTBEAT_INTERVAL_MS: u64 = 150;
    HEARTBEAT_INTERVAL_MS
}
/// Serde default for `NodeConfig::election_timeout_ms`: 300 ms.
///
/// This is exactly twice the default heartbeat (150 ms). That is the smallest
/// value `NodeConfig::validate` accepts alongside the default heartbeat.
fn default_election_timeout_ms() -> u64 {
    const ELECTION_TIMEOUT_MS: u64 = 300;
    ELECTION_TIMEOUT_MS
}
/// Serde default for `NodeConfig::compaction_threshold`: 10 000.
fn default_compaction_threshold() -> usize {
    const COMPACTION_THRESHOLD: usize = 10_000;
    COMPACTION_THRESHOLD
}
/// Serde default for `NodeConfig::metrics_addr`: all interfaces, port 9091.
fn default_metrics_addr() -> String {
    String::from("0.0.0.0:9091")
}
/// Serde default for `NodeConfig::key_retention_count`: 3.
fn default_key_retention_count() -> usize {
    const KEY_RETENTION_COUNT: usize = 3;
    KEY_RETENTION_COUNT
}
/// Static configuration for a single cluster node, deserialized from TOML.
///
/// Only `bind_addr` and `node_id` are required. Every other field falls back
/// to a serde default when omitted. Every field can also be overridden by an
/// `AMATERS_*` environment variable (see `NodeConfig::apply_env_overrides`).
/// Semantic constraints are checked separately by `NodeConfig::validate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeConfig {
    /// Bind address for this node, e.g. `"0.0.0.0:7001"`. Required.
    /// `validate` requires it to be non-empty and to contain a `':'`.
    /// Env override: `AMATERS_BIND_ADDR`.
    pub bind_addr: String,
    /// Numeric node identifier. Required.
    /// `validate` rejects 0, which is reserved as a sentinel.
    /// Env override: `AMATERS_NODE_ID`.
    pub node_id: u64,
    /// Peer entries; empty when omitted.
    /// The tests use entries like `"2=10.0.0.2:7001"`, presumably
    /// `<node_id>=<addr>`. NOTE(review): the format is not checked in this
    /// file, so confirm it with the consumer.
    /// Env override: `AMATERS_PEERS` (comma-separated; entries are trimmed and
    /// empty ones dropped).
    #[serde(default)]
    pub peers: Vec<String>,
    /// Heartbeat interval in milliseconds; defaults to 150. Must be > 0.
    /// Env override: `AMATERS_HEARTBEAT_INTERVAL_MS`.
    #[serde(default = "default_heartbeat_ms")]
    pub heartbeat_interval_ms: u64,
    /// Election timeout in milliseconds; defaults to 300.
    /// `validate` requires it to be at least `2 * heartbeat_interval_ms`.
    /// Env override: `AMATERS_ELECTION_TIMEOUT_MS`.
    #[serde(default = "default_election_timeout_ms")]
    pub election_timeout_ms: u64,
    /// Compaction threshold; defaults to 10 000.
    /// NOTE(review): the unit is not visible here. It is presumably a count of
    /// log entries; confirm with the compaction code.
    /// Env override: `AMATERS_COMPACTION_THRESHOLD`.
    #[serde(default = "default_compaction_threshold")]
    pub compaction_threshold: usize,
    /// Optional data directory; `None` when omitted. How `None` is handled is
    /// decided outside this file.
    /// Env override: `AMATERS_DATA_DIR`.
    #[serde(default)]
    pub data_dir: Option<PathBuf>,
    /// Address for the metrics endpoint; defaults to `"0.0.0.0:9091"`.
    /// Not validated.
    /// Env override: `AMATERS_METRICS_ADDR`.
    #[serde(default = "default_metrics_addr")]
    pub metrics_addr: String,
    /// Interval between automatic key rotations, in seconds.
    /// Defaults to `None`, which the tests describe as "manual rotation only".
    /// Env override: `AMATERS_KEY_ROTATION_INTERVAL_SECS`.
    #[serde(default)]
    pub key_rotation_interval_secs: Option<u64>,
    /// Key retention count; defaults to 3.
    /// Presumably the number of previous keys kept after rotation; TODO confirm.
    /// Env override: `AMATERS_KEY_RETENTION_COUNT`.
    #[serde(default = "default_key_retention_count")]
    pub key_retention_count: usize,
}
impl NodeConfig {
    /// Reads the TOML file at `path`, parses it with [`NodeConfig::from_toml`],
    /// then applies `AMATERS_*` environment overrides via
    /// [`NodeConfig::apply_env_overrides`].
    ///
    /// The result is **not** validated; call [`NodeConfig::validate`] afterwards.
    ///
    /// # Errors
    ///
    /// - [`ConfigError::Io`] if the file cannot be read.
    /// - [`ConfigError::TomlParse`] if the contents do not deserialize into a
    ///   `NodeConfig`, e.g. malformed TOML, a missing `bind_addr`/`node_id`,
    ///   or a value of the wrong type.
    pub fn load(path: &std::path::Path) -> Result<Self, ConfigError> {
        let raw = std::fs::read_to_string(path)?;
        let mut cfg = Self::from_toml(&raw)?;
        cfg.apply_env_overrides();
        Ok(cfg)
    }

    /// Parses a configuration from a TOML string. Omitted optional fields get
    /// their serde defaults.
    ///
    /// No environment overrides are applied and no validation is performed.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::TomlParse`] if deserialization fails.
    pub fn from_toml(toml_str: &str) -> Result<Self, ConfigError> {
        Ok(toml::from_str(toml_str)?)
    }

    /// Overwrites fields from `AMATERS_*` environment variables that are set.
    ///
    /// | Variable                              | Field                        |
    /// |---------------------------------------|------------------------------|
    /// | `AMATERS_BIND_ADDR`                   | `bind_addr`                  |
    /// | `AMATERS_NODE_ID`                     | `node_id`                    |
    /// | `AMATERS_PEERS`                       | `peers`                      |
    /// | `AMATERS_HEARTBEAT_INTERVAL_MS`       | `heartbeat_interval_ms`      |
    /// | `AMATERS_ELECTION_TIMEOUT_MS`         | `election_timeout_ms`        |
    /// | `AMATERS_COMPACTION_THRESHOLD`        | `compaction_threshold`       |
    /// | `AMATERS_DATA_DIR`                    | `data_dir`                   |
    /// | `AMATERS_METRICS_ADDR`                | `metrics_addr`               |
    /// | `AMATERS_KEY_ROTATION_INTERVAL_SECS`  | `key_rotation_interval_secs` |
    /// | `AMATERS_KEY_RETENTION_COUNT`         | `key_retention_count`        |
    ///
    /// This is best-effort. Variables that are unset or not valid Unicode are
    /// ignored. Numeric variables that fail to parse (e.g. `AMATERS_NODE_ID=abc`)
    /// are also ignored, and the existing value is kept. `AMATERS_PEERS` is
    /// split on `,`; each entry is trimmed and empty entries are dropped.
    pub fn apply_env_overrides(&mut self) {
        if let Some(v) = Self::env_var("AMATERS_BIND_ADDR") {
            self.bind_addr = v;
        }
        if let Some(n) = Self::env_parse("AMATERS_NODE_ID") {
            self.node_id = n;
        }
        if let Some(v) = Self::env_var("AMATERS_PEERS") {
            self.peers = v
                .split(',')
                .map(str::trim)
                .filter(|s| !s.is_empty())
                .map(str::to_string)
                .collect();
        }
        if let Some(n) = Self::env_parse("AMATERS_HEARTBEAT_INTERVAL_MS") {
            self.heartbeat_interval_ms = n;
        }
        if let Some(n) = Self::env_parse("AMATERS_ELECTION_TIMEOUT_MS") {
            self.election_timeout_ms = n;
        }
        if let Some(n) = Self::env_parse("AMATERS_COMPACTION_THRESHOLD") {
            self.compaction_threshold = n;
        }
        if let Some(v) = Self::env_var("AMATERS_DATA_DIR") {
            self.data_dir = Some(PathBuf::from(v));
        }
        if let Some(v) = Self::env_var("AMATERS_METRICS_ADDR") {
            self.metrics_addr = v;
        }
        if let Some(n) = Self::env_parse("AMATERS_KEY_ROTATION_INTERVAL_SECS") {
            self.key_rotation_interval_secs = Some(n);
        }
        if let Some(n) = Self::env_parse("AMATERS_KEY_RETENTION_COUNT") {
            self.key_retention_count = n;
        }
    }

    /// Returns the value of env var `key`, or `None` if it is unset or not
    /// valid Unicode.
    fn env_var(key: &str) -> Option<String> {
        std::env::var(key).ok()
    }

    /// Returns env var `key` parsed as `T`. Returns `None` if the variable is
    /// unset, not valid Unicode, or fails to parse.
    fn env_parse<T: std::str::FromStr>(key: &str) -> Option<T> {
        Self::env_var(key)?.parse().ok()
    }

    /// Extracts the subset of settings carried by [`DynamicConfig`]:
    /// `heartbeat_interval_ms` and `compaction_threshold`.
    pub fn dynamic(&self) -> DynamicConfig {
        DynamicConfig {
            heartbeat_interval_ms: self.heartbeat_interval_ms,
            compaction_threshold: self.compaction_threshold,
        }
    }

    /// Checks constraints that deserialization cannot express. Returns every
    /// violation found; an empty `Vec` means the config is valid.
    ///
    /// Rules:
    /// - `bind_addr` is non-empty and contains a `':'`.
    /// - `node_id` is non-zero (0 is reserved as a sentinel).
    /// - `heartbeat_interval_ms` is non-zero.
    /// - `election_timeout_ms >= 2 * heartbeat_interval_ms`. This is only
    ///   checked when the heartbeat is non-zero, so a zero heartbeat yields a
    ///   single error.
    ///
    /// `peers`, `metrics_addr`, `data_dir` and the key settings are not checked.
    pub fn validate(&self) -> Vec<ConfigError> {
        let mut errors = Vec::new();

        if self.bind_addr.is_empty() {
            errors.push(Self::invalid("bind_addr", "must not be empty"));
        } else if !self.bind_addr.contains(':') {
            errors.push(Self::invalid(
                "bind_addr",
                "must contain a ':' separator (e.g. \"0.0.0.0:7001\")",
            ));
        }

        if self.node_id == 0 {
            errors.push(Self::invalid(
                "node_id",
                "must be > 0 (0 is reserved as a sentinel)",
            ));
        }

        if self.heartbeat_interval_ms == 0 {
            errors.push(Self::invalid("heartbeat_interval_ms", "must be > 0"));
        } else {
            // Compare in u128. In u64, `2 * heartbeat` overflows for values
            // above u64::MAX / 2, which AMATERS_HEARTBEAT_INTERVAL_MS can
            // supply. Debug builds would panic; release builds would wrap to a
            // small minimum and wrongly accept the config.
            let min_election_timeout = 2 * u128::from(self.heartbeat_interval_ms);
            if u128::from(self.election_timeout_ms) < min_election_timeout {
                errors.push(Self::invalid(
                    "election_timeout_ms",
                    format!(
                        "must be >= 2 * heartbeat_interval_ms (got {}, need >= {})",
                        self.election_timeout_ms, min_election_timeout,
                    ),
                ));
            }
        }

        errors
    }

    /// Builds a [`ConfigError::Validation`] for `field` with the given `reason`.
    fn invalid(field: &str, reason: impl Into<String>) -> ConfigError {
        ConfigError::Validation {
            field: field.to_string(),
            reason: reason.into(),
        }
    }
}
/// Snapshot of the settings returned by `NodeConfig::dynamic`: the heartbeat
/// interval and the compaction threshold.
///
/// `PartialEq`/`Eq` are derived so callers can cheaply detect whether a
/// freshly loaded snapshot differs from the one currently in effect.
///
/// NOTE(review): the name suggests these values may change at runtime without
/// a restart. Confirm against the reload path.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DynamicConfig {
    /// Heartbeat interval in milliseconds, copied from
    /// `NodeConfig::heartbeat_interval_ms`.
    pub heartbeat_interval_ms: u64,
    /// Compaction threshold, copied from `NodeConfig::compaction_threshold`.
    pub compaction_threshold: usize,
}
/// Errors produced while loading or validating a [`NodeConfig`].
#[derive(Debug, thiserror::Error)]
pub enum ConfigError {
    /// The config file could not be read; produced by `NodeConfig::load`.
    /// Converts automatically from `std::io::Error` via `?`.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// The TOML text could not be deserialized into a `NodeConfig`; produced
    /// by `NodeConfig::from_toml`. Converts automatically from
    /// `toml::de::Error` via `?`.
    #[error("TOML parse error: {0}")]
    TomlParse(#[from] toml::de::Error),
    /// A semantic constraint failed. `field` names the offending field and
    /// `reason` describes the rule. `NodeConfig::validate` collects these into
    /// a `Vec` instead of stopping at the first one.
    #[error("Validation error: field '{field}' — {reason}")]
    Validation { field: String, reason: String },
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard};

    const MINIMAL_TOML: &str = r#"
bind_addr = "0.0.0.0:7001"
node_id = 1
"#;

    /// Serialises every test that reads or writes `AMATERS_*` variables.
    /// The harness runs tests on parallel threads and the process environment
    /// is global. Without this lock, one test's overrides can leak into another
    /// test's `load()`, and `set_var` can race with concurrent reads.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_LOCK`], ignoring poisoning so one failed test does not
    /// cascade into every other env-touching test.
    fn env_lock() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Sets environment variables and restores their previous values on drop,
    /// even when an assertion panics.
    struct EnvGuard {
        saved: Vec<(&'static str, Option<OsString>)>,
    }

    impl EnvGuard {
        /// Sets each `(key, value)` pair. The `_lock` argument proves the caller
        /// holds `ENV_LOCK`.
        fn set(_lock: &MutexGuard<'static, ()>, vars: &[(&'static str, &str)]) -> Self {
            let saved = vars
                .iter()
                .map(|&(key, value)| {
                    let previous = std::env::var_os(key);
                    // SAFETY: the caller holds `ENV_LOCK`, and every test in
                    // this module that reads or writes the environment holds
                    // it too, so no other thread touches the environment now.
                    unsafe { std::env::set_var(key, value) };
                    (key, previous)
                })
                .collect();
            Self { saved }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            // Restore in reverse so a key listed twice ends at its original value.
            for (key, previous) in self.saved.drain(..).rev() {
                // SAFETY: same invariant as `EnvGuard::set`. Each test declares
                // the guard after its `ENV_LOCK` guard, so this drop runs while
                // the lock is still held.
                unsafe {
                    match previous {
                        Some(value) => std::env::set_var(key, value),
                        None => std::env::remove_var(key),
                    }
                }
            }
        }
    }

    fn has_field_error(errors: &[ConfigError], name: &str) -> bool {
        errors
            .iter()
            .any(|e| matches!(e, ConfigError::Validation { field, .. } if field == name))
    }

    #[test]
    fn test_config_from_toml() {
        let cfg = NodeConfig::from_toml(MINIMAL_TOML).expect("valid TOML");
        assert_eq!(cfg.bind_addr, "0.0.0.0:7001");
        assert_eq!(cfg.node_id, 1);
        assert!(cfg.peers.is_empty());
        assert_eq!(cfg.heartbeat_interval_ms, 150);
        assert_eq!(cfg.election_timeout_ms, 300);
        assert_eq!(cfg.compaction_threshold, 10_000);
        assert!(cfg.data_dir.is_none());
        assert_eq!(cfg.metrics_addr, "0.0.0.0:9091");
        assert!(
            cfg.key_rotation_interval_secs.is_none(),
            "rotation interval defaults to None (manual rotation only)"
        );
        assert_eq!(cfg.key_retention_count, 3);
    }

    #[test]
    fn test_config_encryption_fields_from_toml() {
        let toml = r#"
bind_addr = "0.0.0.0:7001"
node_id = 1
key_rotation_interval_secs = 86400
key_retention_count = 5
"#;
        let cfg = NodeConfig::from_toml(toml).expect("valid TOML");
        assert_eq!(cfg.key_rotation_interval_secs, Some(86_400));
        assert_eq!(cfg.key_retention_count, 5);
    }

    #[test]
    fn test_config_from_toml_full() {
        let toml = r#"
bind_addr = "127.0.0.1:8001"
node_id = 5
peers = ["2=10.0.0.2:7001", "3=10.0.0.3:7001"]
heartbeat_interval_ms = 50
election_timeout_ms = 200
compaction_threshold = 5000
data_dir = "/var/data/raft"
metrics_addr = "0.0.0.0:9999"
"#;
        let cfg = NodeConfig::from_toml(toml).expect("valid TOML");
        assert_eq!(cfg.bind_addr, "127.0.0.1:8001");
        assert_eq!(cfg.node_id, 5);
        assert_eq!(cfg.peers, vec!["2=10.0.0.2:7001", "3=10.0.0.3:7001"]);
        assert_eq!(cfg.heartbeat_interval_ms, 50);
        assert_eq!(cfg.election_timeout_ms, 200);
        assert_eq!(cfg.compaction_threshold, 5000);
        assert_eq!(cfg.data_dir, Some(PathBuf::from("/var/data/raft")));
        assert_eq!(cfg.metrics_addr, "0.0.0.0:9999");
    }

    #[test]
    fn test_config_missing_required_field_is_parse_error() {
        let err = NodeConfig::from_toml("node_id = 1\n").expect_err("bind_addr is required");
        assert!(
            matches!(err, ConfigError::TomlParse(_)),
            "expected TomlParse, got {err:?}"
        );
    }

    #[test]
    fn test_config_env_override() {
        let lock = env_lock();
        let mut cfg = NodeConfig::from_toml(MINIMAL_TOML).expect("valid TOML");
        let _env = EnvGuard::set(
            &lock,
            &[
                ("AMATERS_BIND_ADDR", "10.0.0.1:9000"),
                ("AMATERS_NODE_ID", "42"),
                ("AMATERS_PEERS", "2=10.0.0.2:7001,3=10.0.0.3:7001"),
                ("AMATERS_HEARTBEAT_INTERVAL_MS", "75"),
                ("AMATERS_ELECTION_TIMEOUT_MS", "400"),
                ("AMATERS_COMPACTION_THRESHOLD", "2000"),
                ("AMATERS_DATA_DIR", "/tmp/amaters-data"),
                ("AMATERS_METRICS_ADDR", "127.0.0.1:8080"),
                ("AMATERS_KEY_ROTATION_INTERVAL_SECS", "3600"),
                ("AMATERS_KEY_RETENTION_COUNT", "7"),
            ],
        );
        cfg.apply_env_overrides();
        assert_eq!(cfg.bind_addr, "10.0.0.1:9000");
        assert_eq!(cfg.node_id, 42);
        assert_eq!(cfg.peers, vec!["2=10.0.0.2:7001", "3=10.0.0.3:7001"]);
        assert_eq!(cfg.heartbeat_interval_ms, 75);
        assert_eq!(cfg.election_timeout_ms, 400);
        assert_eq!(cfg.compaction_threshold, 2000);
        assert_eq!(cfg.data_dir, Some(PathBuf::from("/tmp/amaters-data")));
        assert_eq!(cfg.metrics_addr, "127.0.0.1:8080");
        assert_eq!(cfg.key_rotation_interval_secs, Some(3600));
        assert_eq!(cfg.key_retention_count, 7);
    }

    #[test]
    fn test_config_env_override_ignores_unparseable_values() {
        let lock = env_lock();
        let mut cfg = NodeConfig::from_toml(MINIMAL_TOML).expect("valid TOML");
        let _env = EnvGuard::set(
            &lock,
            &[
                ("AMATERS_NODE_ID", "not-a-number"),
                ("AMATERS_HEARTBEAT_INTERVAL_MS", "-5"),
                ("AMATERS_PEERS", " 2=a:1 , ,3=b:2,"),
            ],
        );
        cfg.apply_env_overrides();
        assert_eq!(cfg.node_id, 1, "unparseable value keeps the TOML value");
        assert_eq!(cfg.heartbeat_interval_ms, 150, "negative value is ignored");
        assert_eq!(cfg.peers, vec!["2=a:1", "3=b:2"], "entries trimmed, empties dropped");
    }

    #[test]
    fn test_config_validation_missing_field() {
        let toml = r#"
bind_addr = "0.0.0.0:7001"
node_id = 0
"#;
        let cfg = NodeConfig::from_toml(toml).expect("parse should succeed");
        let errors = cfg.validate();
        assert!(
            !errors.is_empty(),
            "expected validation errors for node_id = 0"
        );
        assert!(
            has_field_error(&errors, "node_id"),
            "expected a Validation error for 'node_id'"
        );
    }

    #[test]
    fn test_config_validation_bind_addr() {
        let mut cfg = NodeConfig::from_toml(MINIMAL_TOML).expect("valid TOML");
        for bad in ["", "localhost"] {
            cfg.bind_addr = bad.to_string();
            assert!(
                has_field_error(&cfg.validate(), "bind_addr"),
                "expected a Validation error for bind_addr = {bad:?}"
            );
        }
    }

    #[test]
    fn test_config_validation_zero_heartbeat_reports_once() {
        let toml = r#"
bind_addr = "0.0.0.0:7001"
node_id = 1
heartbeat_interval_ms = 0
"#;
        let cfg = NodeConfig::from_toml(toml).expect("parse should succeed");
        let errors = cfg.validate();
        assert_eq!(errors.len(), 1, "got: {errors:?}");
        assert!(has_field_error(&errors, "heartbeat_interval_ms"));
    }

    #[test]
    fn test_config_validation_out_of_range() {
        let toml = r#"
bind_addr = "0.0.0.0:7001"
node_id = 1
heartbeat_interval_ms = 200
election_timeout_ms = 300
"#;
        let cfg = NodeConfig::from_toml(toml).expect("parse should succeed");
        let errors = cfg.validate();
        assert!(
            !errors.is_empty(),
            "expected validation error: election_timeout_ms 300 < 2*200 = 400"
        );
        assert!(
            has_field_error(&errors, "election_timeout_ms"),
            "expected a Validation error for 'election_timeout_ms'"
        );
    }

    #[test]
    fn test_config_validation_passes_for_valid_config() {
        let cfg = NodeConfig::from_toml(MINIMAL_TOML).expect("valid TOML");
        let errors = cfg.validate();
        assert!(
            errors.is_empty(),
            "expected no validation errors, got: {:?}",
            errors
        );
    }

    #[test]
    fn test_config_dynamic_extraction() {
        let toml = r#"
bind_addr = "0.0.0.0:7001"
node_id = 1
heartbeat_interval_ms = 100
compaction_threshold = 5000
"#;
        let cfg = NodeConfig::from_toml(toml).expect("valid TOML");
        let dyn_cfg = cfg.dynamic();
        assert_eq!(dyn_cfg.heartbeat_interval_ms, 100);
        assert_eq!(dyn_cfg.compaction_threshold, 5000);
    }

    #[test]
    fn test_config_load_from_file() {
        // `load()` applies env overrides, so keep env-mutating tests out.
        let _lock = env_lock();
        // Include the PID so concurrent test processes do not share the file.
        let path = std::env::temp_dir().join(format!(
            "amaters_cluster_test_config_load_{}.toml",
            std::process::id()
        ));
        std::fs::write(&path, MINIMAL_TOML).expect("write temp config");
        let parsed = std::fs::read_to_string(&path)
            .map_err(ConfigError::from)
            .and_then(|raw| NodeConfig::from_toml(&raw));
        let loaded = NodeConfig::load(&path);
        // Clean up before asserting so a failure does not leave the file behind.
        let _ = std::fs::remove_file(&path);

        let cfg = parsed.expect("parse TOML from file");
        assert_eq!(cfg.bind_addr, "0.0.0.0:7001");
        assert_eq!(cfg.node_id, 1);
        loaded.expect("load() must succeed for a valid file");
    }

    #[test]
    fn test_config_load_missing_file_is_io_error() {
        let path = std::env::temp_dir().join(format!(
            "amaters_cluster_test_config_missing_{}.toml",
            std::process::id()
        ));
        let _ = std::fs::remove_file(&path);
        let err = NodeConfig::load(&path).expect_err("missing file must fail");
        assert!(
            matches!(err, ConfigError::Io(_)),
            "expected Io error, got {err:?}"
        );
    }
}