1use parking_lot::RwLock;
10use serde::{Deserialize, Serialize};
11use std::net::SocketAddr;
12use std::path::{Path, PathBuf};
13use std::sync::Arc;
14use std::time::Duration;
15use thiserror::Error;
16use tracing::{info, warn};
17
/// Errors produced while loading, parsing, validating, or saving a [`ServerConfig`].
#[derive(Error, Debug)]
pub enum ConfigError {
    /// An I/O failure on the configuration file.
    ///
    /// Produced for read failures in [`ServerConfig::from_file`] and also for
    /// write failures in [`ServerConfig::save_to_file`] (via `?` on
    /// `std::fs::write`), despite the "read" wording of the message.
    #[error("Failed to read configuration file: {0}")]
    ReadFile(#[from] std::io::Error),

    /// The file contents could not be deserialized into the configuration schema.
    #[error("Failed to parse TOML: {0}")]
    ParseToml(#[from] toml::de::Error),

    /// A semantic check in [`ServerConfig::validate`] failed; also used when
    /// serialization fails in [`ServerConfig::save_to_file`].
    #[error("Validation error: {0}")]
    Validation(String),

    /// A socket address failed to parse.
    // NOTE(review): `validate` maps address parse failures to `Validation`,
    // so the visible code does not appear to construct this variant directly.
    #[error("Invalid socket address: {0}")]
    InvalidAddress(#[from] std::net::AddrParseError),
}

/// Shorthand result type for configuration operations.
pub type ConfigResult<T> = Result<T, ConfigError>;
35
/// Top-level server configuration, deserialized from a TOML file.
///
/// The `server`, `storage`, `network`, `logging`, and `metrics` tables carry no
/// `#[serde(default)]`, so they must be present in the file. `cluster`, `auth`,
/// and `authz` may be omitted.
///
/// Field order here is also the table order written by
/// [`ServerConfig::save_to_file`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
    /// Core process settings (listen address, data directory, limits).
    pub server: ServerSettings,

    /// Storage engine, WAL, and compaction settings.
    pub storage: StorageSettings,

    /// TLS and connection-level network settings.
    pub network: NetworkSettings,

    /// Cluster membership; `None` when the `[cluster]` table is absent.
    #[serde(default)]
    pub cluster: Option<ClusterSettings>,

    /// Log level, format, and file output settings.
    pub logging: LoggingSettings,

    /// Metrics export settings.
    pub metrics: MetricsSettings,

    /// Authentication settings; falls back to `AuthSettings::default()` when omitted.
    #[serde(default)]
    pub auth: AuthSettings,

    /// Authorization settings; falls back to `AuthorizationSettings::default()` when omitted.
    #[serde(default)]
    pub authz: AuthorizationSettings,
}
66
/// Process-level settings from the `[server]` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerSettings {
    /// Listen address; must parse as a `SocketAddr` (checked by `validate`).
    pub bind_address: String,

    /// Data directory; must be non-empty (checked by `validate`).
    pub data_dir: PathBuf,

    /// PID file path; defaults to `/var/run/amaters-server.pid`.
    #[serde(default = "default_pid_file")]
    pub pid_file: PathBuf,

    /// Connection limit; defaults to 1000. Hot-reloadable.
    #[serde(default = "default_max_connections")]
    pub max_connections: usize,

    /// Graceful-shutdown timeout in seconds (see `ServerConfig::shutdown_timeout`);
    /// defaults to 30.
    #[serde(default = "default_shutdown_timeout")]
    pub shutdown_timeout_secs: u64,
}
88
/// Storage engine settings from the `[storage]` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageSettings {
    /// Engine name; `validate` accepts only `"memory"` or `"lsm"`. Defaults to `"lsm"`.
    #[serde(default = "default_storage_engine")]
    pub engine: String,

    /// Write-ahead-log settings; defaults to `WalSettings::default()`.
    #[serde(default)]
    pub wal: WalSettings,

    /// Memtable size (megabytes, per the field name); defaults to 64.
    #[serde(default = "default_memtable_size")]
    pub memtable_size_mb: usize,

    /// Block cache size (megabytes, per the field name); defaults to 256.
    #[serde(default = "default_block_cache_size")]
    pub block_cache_size_mb: usize,

    /// Compaction tuning; defaults to `CompactionSettings::default()`. Hot-reloadable.
    #[serde(default)]
    pub compaction: CompactionSettings,
}
112
/// Write-ahead-log settings from the `[storage.wal]` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WalSettings {
    /// Whether the WAL is enabled; defaults to `true`.
    #[serde(default = "default_true")]
    pub enabled: bool,

    /// WAL directory; defaults to the relative path `wal`.
    // NOTE(review): whether this is resolved against `server.data_dir` is
    // decided by the consumer, not here — confirm.
    #[serde(default = "default_wal_dir")]
    pub dir: PathBuf,

    /// Segment size (megabytes, per the field name); defaults to 64.
    #[serde(default = "default_wal_segment_size")]
    pub segment_size_mb: usize,

    /// Sync policy name; defaults to `"interval"`. Not checked by `validate`.
    #[serde(default = "default_sync_mode")]
    pub sync_mode: String,
}
132
/// Compaction tuning from the `[storage.compaction]` table (hot-reloadable).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompactionSettings {
    /// Strategy name; defaults to `"leveled"`. Not checked by `validate`.
    #[serde(default = "default_compaction_strategy")]
    pub strategy: String,

    /// Number of levels; defaults to 7.
    #[serde(default = "default_num_levels")]
    pub num_levels: usize,

    /// Size multiplier between levels; defaults to 10.
    #[serde(default = "default_level_multiplier")]
    pub level_multiplier: usize,

    /// Maximum concurrent compactions; defaults to 4.
    #[serde(default = "default_max_compactions")]
    pub max_concurrent: usize,
}
152
/// Network and TLS settings from the `[network]` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkSettings {
    /// Enables TLS; defaults to `false`. When `true`, `validate` requires
    /// `tls_cert` and `tls_key`.
    #[serde(default = "default_false")]
    pub tls_enabled: bool,

    /// Server certificate path; `None` when omitted.
    pub tls_cert: Option<PathBuf>,

    /// Server private key path; `None` when omitted.
    pub tls_key: Option<PathBuf>,

    /// CA file used to verify client certificates; required by `validate`
    /// when TLS is on and `require_client_cert` is set.
    pub tls_ca: Option<PathBuf>,

    /// Requires clients to present a certificate; defaults to `false`.
    #[serde(default = "default_false")]
    pub require_client_cert: bool,

    /// Connection timeout in seconds (see `ServerConfig::connection_timeout`);
    /// defaults to 30.
    #[serde(default = "default_connection_timeout")]
    pub connection_timeout_secs: u64,

    /// Keepalive interval in seconds (see `ServerConfig::keepalive_interval`);
    /// defaults to 60.
    #[serde(default = "default_keepalive_interval")]
    pub keepalive_interval_secs: u64,
}
181
/// Cluster membership settings from the optional `[cluster]` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterSettings {
    /// Enables clustering; defaults to `true` once the table is present.
    /// When `true`, `validate` requires a non-empty `peers` list.
    #[serde(default = "default_true")]
    pub enabled: bool,

    /// This node's identifier (required). Changing it requires a restart.
    pub node_id: u64,

    /// Peer addresses (required field; format not checked here).
    pub peers: Vec<String>,

    /// Heartbeat interval in milliseconds (per the field name); defaults to 100.
    #[serde(default = "default_heartbeat_interval")]
    pub heartbeat_interval_ms: u64,

    /// Election timeout in milliseconds (per the field name); defaults to 300.
    #[serde(default = "default_election_timeout")]
    pub election_timeout_ms: u64,
}
203
/// Logging settings from the `[logging]` table (hot-reloadable).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoggingSettings {
    /// Log level; `validate` accepts trace/debug/info/warn/error
    /// (case-insensitive). Defaults to `"info"`.
    #[serde(default = "default_log_level")]
    pub level: String,

    /// Output format name; defaults to `"pretty"`. Not checked by `validate`.
    #[serde(default = "default_log_format")]
    pub format: String,

    /// Enables file output; defaults to `false`.
    #[serde(default = "default_false")]
    pub file_enabled: bool,

    /// Log file path; `None` when omitted.
    pub file_path: Option<PathBuf>,

    /// Rotation policy; defaults to `LogRotationSettings::default()`.
    #[serde(default)]
    pub rotation: LogRotationSettings,
}
226
/// Log file rotation policy from the `[logging.rotation]` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogRotationSettings {
    /// Enables rotation; defaults to `true`.
    #[serde(default = "default_true")]
    pub enabled: bool,

    /// Rotation size threshold (megabytes, per the field name); defaults to 100.
    #[serde(default = "default_log_max_size")]
    pub max_size_mb: usize,

    /// Number of rotated files to keep; defaults to 10.
    #[serde(default = "default_log_max_backups")]
    pub max_backups: usize,
}
242
/// Metrics export settings from the `[metrics]` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricsSettings {
    /// Enables metrics; defaults to `true`. Hot-reloadable.
    #[serde(default = "default_true")]
    pub enabled: bool,

    /// Metrics listen address; must parse as a `SocketAddr` (checked by
    /// `validate`). Defaults to `127.0.0.1:9090`.
    #[serde(default = "default_metrics_address")]
    pub bind_address: String,

    /// Export interval in seconds (per the field name); defaults to 60. Hot-reloadable.
    #[serde(default = "default_metrics_interval")]
    pub export_interval_secs: u64,
}
258
/// Authentication settings from the optional `[auth]` table.
///
/// When `enabled`, `validate` requires at least one method that is both listed
/// in `methods` and has its own `enabled` flag set.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthSettings {
    /// Master switch; defaults to `false`.
    #[serde(default = "default_false")]
    pub enabled: bool,

    /// Method names considered by `validate`: `"mtls"`, `"jwt"`, `"api_key"`.
    /// Defaults to `["mtls"]`.
    #[serde(default = "default_auth_methods")]
    pub methods: Vec<String>,

    /// Mutual-TLS authentication settings.
    #[serde(default)]
    pub mtls: MtlsSettings,

    /// JWT authentication settings.
    #[serde(default)]
    pub jwt: JwtSettings,

    /// API-key authentication settings.
    #[serde(default)]
    pub api_key: ApiKeySettings,

    /// Reject unauthenticated requests; defaults to `true`.
    #[serde(default = "default_true")]
    pub reject_unauthenticated: bool,
}
286
/// Mutual-TLS authentication settings from the `[auth.mtls]` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MtlsSettings {
    /// Enables mTLS; defaults to `false`. When auth is on, `validate`
    /// requires `ca_certs_dir` if this is set.
    #[serde(default = "default_false")]
    pub enabled: bool,

    /// Directory of trusted CA certificates.
    pub ca_certs_dir: Option<PathBuf>,

    /// Certificate revocation list path.
    pub crl_path: Option<PathBuf>,

    /// Verify the certificate CN; defaults to `true`.
    #[serde(default = "default_true")]
    pub verify_cn: bool,

    /// Allowed certificate organizations; defaults to empty.
    #[serde(default)]
    pub allowed_organizations: Vec<String>,
}
308
309#[derive(Debug, Clone, Serialize, Deserialize)]
311pub struct JwtSettings {
312 #[serde(default = "default_false")]
314 pub enabled: bool,
315
316 pub secret: Option<String>,
318
319 pub public_key_path: Option<PathBuf>,
321
322 pub ec_public_key_path: Option<PathBuf>,
324
325 pub ed_public_key_path: Option<PathBuf>,
327
328 #[serde(default = "default_jwt_algorithm")]
330 pub algorithm: String,
331
332 #[serde(default = "default_jwt_expiration")]
334 pub expiration_secs: u64,
335
336 pub issuer: Option<String>,
338
339 pub audience: Option<String>,
341}
342
/// API-key authentication settings from the `[auth.api_key]` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiKeySettings {
    /// Enables API-key auth; defaults to `false`. When auth is on, `validate`
    /// requires `keys_file` if this is set.
    #[serde(default = "default_false")]
    pub enabled: bool,

    /// Path to the key store file.
    pub keys_file: Option<PathBuf>,

    /// Request header carrying the key; defaults to `"X-API-Key"`.
    #[serde(default = "default_api_key_header")]
    pub header_name: String,

    /// Whether stored keys are hashed; defaults to `true`.
    #[serde(default = "default_true")]
    pub hash_keys: bool,
}
361
/// Authorization settings from the optional `[authz]` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthorizationSettings {
    /// Enables authorization; defaults to `true`.
    #[serde(default = "default_true")]
    pub enabled: bool,

    /// Role name assigned by default; defaults to `"user"`.
    #[serde(default = "default_user_role")]
    pub default_role: String,

    /// Role definitions file.
    pub roles_file: Option<PathBuf>,

    /// Policy definitions file.
    pub policies_file: Option<PathBuf>,

    /// Enables per-collection permissions; defaults to `true`.
    #[serde(default = "default_true")]
    pub collection_permissions: bool,

    /// `"deny-by-default"` or `"allow-by-default"` (checked by `validate`
    /// when `enabled`). Defaults to `"deny-by-default"`.
    #[serde(default = "default_permission_mode")]
    pub default_mode: String,

    /// Enables audit logging; defaults to `true`.
    #[serde(default = "default_true")]
    pub audit_enabled: bool,

    /// Audit log path.
    pub audit_log_path: Option<PathBuf>,
}
394
/// Serde default for `server.pid_file`.
fn default_pid_file() -> PathBuf {
    "/var/run/amaters-server.pid".into()
}

/// Serde default for `server.max_connections`.
fn default_max_connections() -> usize {
    1_000
}

/// Serde default for `server.shutdown_timeout_secs`.
fn default_shutdown_timeout() -> u64 {
    30
}

/// Serde default for `storage.engine`.
fn default_storage_engine() -> String {
    String::from("lsm")
}

/// Serde default for `storage.memtable_size_mb`.
fn default_memtable_size() -> usize {
    64
}

/// Serde default for `storage.block_cache_size_mb`.
fn default_block_cache_size() -> usize {
    256
}

/// Serde default for `storage.wal.dir` (a relative path).
fn default_wal_dir() -> PathBuf {
    "wal".into()
}

/// Serde default for `storage.wal.segment_size_mb`.
fn default_wal_segment_size() -> usize {
    64
}

/// Serde default for `storage.wal.sync_mode`.
fn default_sync_mode() -> String {
    String::from("interval")
}

/// Serde default for `storage.compaction.strategy`.
fn default_compaction_strategy() -> String {
    String::from("leveled")
}

/// Serde default for `storage.compaction.num_levels`.
fn default_num_levels() -> usize {
    7
}

/// Serde default for `storage.compaction.level_multiplier`.
fn default_level_multiplier() -> usize {
    10
}

/// Serde default for `storage.compaction.max_concurrent`.
fn default_max_compactions() -> usize {
    4
}
447
/// Serde default for `network.connection_timeout_secs`.
fn default_connection_timeout() -> u64 {
    30
}

/// Serde default for `network.keepalive_interval_secs`.
fn default_keepalive_interval() -> u64 {
    60
}

/// Serde default for `cluster.heartbeat_interval_ms`.
fn default_heartbeat_interval() -> u64 {
    100
}

/// Serde default for `cluster.election_timeout_ms`.
fn default_election_timeout() -> u64 {
    300
}

/// Serde default for `logging.level`.
fn default_log_level() -> String {
    String::from("info")
}

/// Serde default for `logging.format`.
fn default_log_format() -> String {
    String::from("pretty")
}

/// Serde default for `logging.rotation.max_size_mb`.
fn default_log_max_size() -> usize {
    100
}

/// Serde default for `logging.rotation.max_backups`.
fn default_log_max_backups() -> usize {
    10
}

/// Serde default for `metrics.bind_address` (loopback only).
fn default_metrics_address() -> String {
    String::from("127.0.0.1:9090")
}

/// Serde default for `metrics.export_interval_secs`.
fn default_metrics_interval() -> u64 {
    60
}
487
/// Serde helper: a `bool` field that is `true` when omitted.
fn default_true() -> bool {
    true
}

/// Serde helper: a `bool` field that is `false` when omitted
/// (the same value as `bool::default()`).
fn default_false() -> bool {
    bool::default()
}
495
/// Serde default for `auth.methods`: mutual TLS only.
fn default_auth_methods() -> Vec<String> {
    Vec::from([String::from("mtls")])
}

/// Serde default for `auth.jwt.algorithm`.
fn default_jwt_algorithm() -> String {
    String::from("HS256")
}

/// Serde default for `auth.jwt.expiration_secs`: one hour.
fn default_jwt_expiration() -> u64 {
    60 * 60
}

/// Serde default for `auth.api_key.header_name`.
fn default_api_key_header() -> String {
    String::from("X-API-Key")
}

/// Serde default for `authz.default_role`.
fn default_user_role() -> String {
    String::from("user")
}

/// Serde default for `authz.default_mode`.
fn default_permission_mode() -> String {
    String::from("deny-by-default")
}
519
520impl Default for ServerConfig {
521 fn default() -> Self {
522 Self {
523 server: ServerSettings {
524 bind_address: "0.0.0.0:7878".to_string(),
525 data_dir: PathBuf::from("./data"),
526 pid_file: default_pid_file(),
527 max_connections: default_max_connections(),
528 shutdown_timeout_secs: default_shutdown_timeout(),
529 },
530 storage: StorageSettings {
531 engine: default_storage_engine(),
532 wal: WalSettings::default(),
533 memtable_size_mb: default_memtable_size(),
534 block_cache_size_mb: default_block_cache_size(),
535 compaction: CompactionSettings::default(),
536 },
537 network: NetworkSettings {
538 tls_enabled: false,
539 tls_cert: None,
540 tls_key: None,
541 tls_ca: None,
542 require_client_cert: false,
543 connection_timeout_secs: default_connection_timeout(),
544 keepalive_interval_secs: default_keepalive_interval(),
545 },
546 cluster: None,
547 logging: LoggingSettings {
548 level: default_log_level(),
549 format: default_log_format(),
550 file_enabled: false,
551 file_path: None,
552 rotation: LogRotationSettings::default(),
553 },
554 metrics: MetricsSettings {
555 enabled: true,
556 bind_address: default_metrics_address(),
557 export_interval_secs: default_metrics_interval(),
558 },
559 auth: AuthSettings::default(),
560 authz: AuthorizationSettings::default(),
561 }
562 }
563}
564
565impl Default for WalSettings {
566 fn default() -> Self {
567 Self {
568 enabled: true,
569 dir: default_wal_dir(),
570 segment_size_mb: default_wal_segment_size(),
571 sync_mode: default_sync_mode(),
572 }
573 }
574}
575
576impl Default for CompactionSettings {
577 fn default() -> Self {
578 Self {
579 strategy: default_compaction_strategy(),
580 num_levels: default_num_levels(),
581 level_multiplier: default_level_multiplier(),
582 max_concurrent: default_max_compactions(),
583 }
584 }
585}
586
587impl Default for LogRotationSettings {
588 fn default() -> Self {
589 Self {
590 enabled: true,
591 max_size_mb: default_log_max_size(),
592 max_backups: default_log_max_backups(),
593 }
594 }
595}
596
597impl Default for AuthSettings {
598 fn default() -> Self {
599 Self {
600 enabled: false,
601 methods: default_auth_methods(),
602 mtls: MtlsSettings::default(),
603 jwt: JwtSettings::default(),
604 api_key: ApiKeySettings::default(),
605 reject_unauthenticated: true,
606 }
607 }
608}
609
610impl Default for MtlsSettings {
611 fn default() -> Self {
612 Self {
613 enabled: false,
614 ca_certs_dir: None,
615 crl_path: None,
616 verify_cn: true,
617 allowed_organizations: Vec::new(),
618 }
619 }
620}
621
622impl Default for JwtSettings {
623 fn default() -> Self {
624 Self {
625 enabled: false,
626 secret: None,
627 public_key_path: None,
628 ec_public_key_path: None,
629 ed_public_key_path: None,
630 algorithm: default_jwt_algorithm(),
631 expiration_secs: default_jwt_expiration(),
632 issuer: None,
633 audience: None,
634 }
635 }
636}
637
638impl Default for ApiKeySettings {
639 fn default() -> Self {
640 Self {
641 enabled: false,
642 keys_file: None,
643 header_name: default_api_key_header(),
644 hash_keys: true,
645 }
646 }
647}
648
649impl Default for AuthorizationSettings {
650 fn default() -> Self {
651 Self {
652 enabled: true,
653 default_role: default_user_role(),
654 roles_file: None,
655 policies_file: None,
656 collection_permissions: true,
657 default_mode: default_permission_mode(),
658 audit_enabled: true,
659 audit_log_path: None,
660 }
661 }
662}
663
664impl ServerConfig {
665 pub fn from_file(path: impl AsRef<Path>) -> ConfigResult<Self> {
667 let contents = std::fs::read_to_string(path)?;
668 let config: ServerConfig = toml::from_str(&contents)?;
669 config.validate()?;
670 Ok(config)
671 }
672
673 pub fn from_file_with_env(path: impl AsRef<Path>) -> ConfigResult<Self> {
675 let mut config = Self::from_file(path)?;
676 config.apply_env_overrides();
677 config.validate()?;
678 Ok(config)
679 }
680
681 pub fn apply_env_overrides(&mut self) {
683 if let Ok(bind) = std::env::var("AMATERS_BIND_ADDRESS") {
684 self.server.bind_address = bind;
685 }
686 if let Ok(data_dir) = std::env::var("AMATERS_DATA_DIR") {
687 self.server.data_dir = PathBuf::from(data_dir);
688 }
689 if let Ok(log_level) = std::env::var("AMATERS_LOG_LEVEL") {
690 self.logging.level = log_level;
691 }
692 if let Ok(tls_enabled) = std::env::var("AMATERS_TLS_ENABLED") {
693 self.network.tls_enabled = tls_enabled.parse().unwrap_or(false);
694 }
695 }
696
697 pub fn validate(&self) -> ConfigResult<()> {
699 let _: SocketAddr = self
701 .server
702 .bind_address
703 .parse()
704 .map_err(|e| ConfigError::Validation(format!("Invalid bind address: {}", e)))?;
705
706 if self.server.data_dir.as_os_str().is_empty() {
708 return Err(ConfigError::Validation(
709 "Data directory cannot be empty".to_string(),
710 ));
711 }
712
713 match self.storage.engine.as_str() {
715 "memory" | "lsm" => {}
716 other => {
717 return Err(ConfigError::Validation(format!(
718 "Invalid storage engine: {}. Must be 'memory' or 'lsm'",
719 other
720 )));
721 }
722 }
723
724 if self.network.tls_enabled {
726 if self.network.tls_cert.is_none() {
727 return Err(ConfigError::Validation(
728 "TLS enabled but no certificate file specified".to_string(),
729 ));
730 }
731 if self.network.tls_key.is_none() {
732 return Err(ConfigError::Validation(
733 "TLS enabled but no key file specified".to_string(),
734 ));
735 }
736 if self.network.require_client_cert && self.network.tls_ca.is_none() {
737 return Err(ConfigError::Validation(
738 "Client certificate required but no CA file specified".to_string(),
739 ));
740 }
741 }
742
743 if let Some(ref cluster) = self.cluster {
745 if cluster.enabled && cluster.peers.is_empty() {
746 return Err(ConfigError::Validation(
747 "Cluster enabled but no peers specified".to_string(),
748 ));
749 }
750 }
751
752 match self.logging.level.to_lowercase().as_str() {
754 "trace" | "debug" | "info" | "warn" | "error" => {}
755 other => {
756 return Err(ConfigError::Validation(format!(
757 "Invalid log level: {}. Must be one of: trace, debug, info, warn, error",
758 other
759 )));
760 }
761 }
762
763 let _: SocketAddr = self
765 .metrics
766 .bind_address
767 .parse()
768 .map_err(|e| ConfigError::Validation(format!("Invalid metrics address: {}", e)))?;
769
770 if self.auth.enabled {
772 let has_enabled_method = (self.auth.mtls.enabled
774 && self.auth.methods.contains(&"mtls".to_string()))
775 || (self.auth.jwt.enabled && self.auth.methods.contains(&"jwt".to_string()))
776 || (self.auth.api_key.enabled
777 && self.auth.methods.contains(&"api_key".to_string()));
778
779 if !has_enabled_method {
780 return Err(ConfigError::Validation(
781 "Authentication enabled but no valid auth methods configured".to_string(),
782 ));
783 }
784
785 if self.auth.jwt.enabled {
787 match self.auth.jwt.algorithm.as_str() {
788 "HS256" => {
789 if self.auth.jwt.secret.is_none() {
790 return Err(ConfigError::Validation(
791 "JWT HS256 enabled but no secret key provided".to_string(),
792 ));
793 }
794 }
795 "RS256" => {
796 if self.auth.jwt.public_key_path.is_none() {
797 return Err(ConfigError::Validation(
798 "JWT RS256 enabled but no public key path provided".to_string(),
799 ));
800 }
801 }
802 other => {
803 return Err(ConfigError::Validation(format!(
804 "Invalid JWT algorithm: {}. Supported: HS256, RS256",
805 other
806 )));
807 }
808 }
809 }
810
811 if self.auth.api_key.enabled && self.auth.api_key.keys_file.is_none() {
813 return Err(ConfigError::Validation(
814 "API key auth enabled but no keys file specified".to_string(),
815 ));
816 }
817
818 if self.auth.mtls.enabled && self.auth.mtls.ca_certs_dir.is_none() {
820 return Err(ConfigError::Validation(
821 "mTLS enabled but no CA certificates directory specified".to_string(),
822 ));
823 }
824 }
825
826 if self.authz.enabled {
828 match self.authz.default_mode.as_str() {
829 "deny-by-default" | "allow-by-default" => {}
830 other => {
831 return Err(ConfigError::Validation(format!(
832 "Invalid authorization default mode: {}. Must be 'deny-by-default' or 'allow-by-default'",
833 other
834 )));
835 }
836 }
837 }
838
839 Ok(())
840 }
841
842 pub fn shutdown_timeout(&self) -> Duration {
844 Duration::from_secs(self.server.shutdown_timeout_secs)
845 }
846
847 pub fn connection_timeout(&self) -> Duration {
849 Duration::from_secs(self.network.connection_timeout_secs)
850 }
851
852 pub fn keepalive_interval(&self) -> Duration {
854 Duration::from_secs(self.network.keepalive_interval_secs)
855 }
856
857 pub fn save_to_file(&self, path: impl AsRef<Path>) -> ConfigResult<()> {
859 let contents = toml::to_string_pretty(self)
860 .map_err(|e| ConfigError::Validation(format!("Failed to serialize config: {}", e)))?;
861 std::fs::write(path, contents)?;
862 Ok(())
863 }
864
865 pub fn example() -> Self {
867 Self::default()
868 }
869}
870
/// Configuration areas that can be applied to a running server without restart.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReloadableSection {
    /// The `[logging]` table.
    Logging,
    /// Metrics `enabled` and `export_interval_secs`.
    Metrics,
    /// The `[storage.compaction]` table.
    Compaction,
    /// Connection limiting (`server.max_connections`).
    RateLimit,
}

impl std::fmt::Display for ReloadableSection {
    /// Writes the section's snake_case name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Logging => "logging",
            Self::Metrics => "metrics",
            Self::Compaction => "compaction",
            Self::RateLimit => "rate_limit",
        };
        f.write_str(name)
    }
}
894
/// Configuration values whose change only takes effect after a restart.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NonReloadableSection {
    /// `server.bind_address`.
    BindAddress,
    /// Listen port.
    Port,
    /// `network.tls_cert`.
    TlsCertPath,
    /// `network.tls_key`.
    TlsKeyPath,
    /// `storage.engine`.
    StorageEngine,
    /// `server.data_dir`.
    DataDir,
    /// `cluster.node_id`.
    ClusterNodeId,
}

impl std::fmt::Display for NonReloadableSection {
    /// Writes the value's snake_case name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use NonReloadableSection as S;
        f.write_str(match self {
            S::BindAddress => "bind_address",
            S::Port => "port",
            S::TlsCertPath => "tls_cert_path",
            S::TlsKeyPath => "tls_key_path",
            S::StorageEngine => "storage_engine",
            S::DataDir => "data_dir",
            S::ClusterNodeId => "cluster_node_id",
        })
    }
}
927
928#[derive(Debug, Clone, Default)]
930pub struct ConfigDiff {
931 pub reloadable_changes: Vec<ReloadableSection>,
933 pub non_reloadable_changes: Vec<NonReloadableSection>,
935}
936
937impl ConfigDiff {
938 pub fn is_empty(&self) -> bool {
940 self.reloadable_changes.is_empty() && self.non_reloadable_changes.is_empty()
941 }
942
943 pub fn has_non_reloadable_changes(&self) -> bool {
945 !self.non_reloadable_changes.is_empty()
946 }
947}
948
949pub fn diff(old: &ServerConfig, new: &ServerConfig) -> ConfigDiff {
951 let mut result = ConfigDiff::default();
952
953 if old.logging.level != new.logging.level
955 || old.logging.format != new.logging.format
956 || old.logging.file_enabled != new.logging.file_enabled
957 || old.logging.file_path != new.logging.file_path
958 || old.logging.rotation.enabled != new.logging.rotation.enabled
959 || old.logging.rotation.max_size_mb != new.logging.rotation.max_size_mb
960 || old.logging.rotation.max_backups != new.logging.rotation.max_backups
961 {
962 result.reloadable_changes.push(ReloadableSection::Logging);
963 }
964
965 if old.metrics.export_interval_secs != new.metrics.export_interval_secs
966 || old.metrics.enabled != new.metrics.enabled
967 {
968 result.reloadable_changes.push(ReloadableSection::Metrics);
969 }
970
971 if old.storage.compaction.strategy != new.storage.compaction.strategy
972 || old.storage.compaction.num_levels != new.storage.compaction.num_levels
973 || old.storage.compaction.level_multiplier != new.storage.compaction.level_multiplier
974 || old.storage.compaction.max_concurrent != new.storage.compaction.max_concurrent
975 {
976 result
977 .reloadable_changes
978 .push(ReloadableSection::Compaction);
979 }
980
981 if old.server.max_connections != new.server.max_connections {
982 result.reloadable_changes.push(ReloadableSection::RateLimit);
983 }
984
985 if old.server.bind_address != new.server.bind_address {
987 result
988 .non_reloadable_changes
989 .push(NonReloadableSection::BindAddress);
990 }
991
992 if old.server.data_dir != new.server.data_dir {
993 result
994 .non_reloadable_changes
995 .push(NonReloadableSection::DataDir);
996 }
997
998 if old.storage.engine != new.storage.engine {
999 result
1000 .non_reloadable_changes
1001 .push(NonReloadableSection::StorageEngine);
1002 }
1003
1004 if old.network.tls_cert != new.network.tls_cert {
1005 result
1006 .non_reloadable_changes
1007 .push(NonReloadableSection::TlsCertPath);
1008 }
1009
1010 if old.network.tls_key != new.network.tls_key {
1011 result
1012 .non_reloadable_changes
1013 .push(NonReloadableSection::TlsKeyPath);
1014 }
1015
1016 if let (Some(old_cluster), Some(new_cluster)) = (&old.cluster, &new.cluster) {
1017 if old_cluster.node_id != new_cluster.node_id {
1018 result
1019 .non_reloadable_changes
1020 .push(NonReloadableSection::ClusterNodeId);
1021 }
1022 }
1023
1024 result
1025}
1026
1027#[derive(Debug, Clone)]
1029pub struct ReloadReport {
1030 pub sections_updated: Vec<ReloadableSection>,
1032 pub sections_skipped: Vec<NonReloadableSection>,
1034 pub errors: Vec<String>,
1036 pub success: bool,
1038}
1039
1040impl std::fmt::Display for ReloadReport {
1041 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1042 if self.success {
1043 write!(f, "Config reload successful. ")?;
1044 } else {
1045 write!(f, "Config reload failed. ")?;
1046 }
1047 if !self.sections_updated.is_empty() {
1048 write!(f, "Updated: ")?;
1049 for (i, s) in self.sections_updated.iter().enumerate() {
1050 if i > 0 {
1051 write!(f, ", ")?;
1052 }
1053 write!(f, "{}", s)?;
1054 }
1055 write!(f, ". ")?;
1056 }
1057 if !self.sections_skipped.is_empty() {
1058 write!(f, "Skipped (restart required): ")?;
1059 for (i, s) in self.sections_skipped.iter().enumerate() {
1060 if i > 0 {
1061 write!(f, ", ")?;
1062 }
1063 write!(f, "{}", s)?;
1064 }
1065 write!(f, ". ")?;
1066 }
1067 for err in &self.errors {
1068 write!(f, "Error: {}. ", err)?;
1069 }
1070 Ok(())
1071 }
1072}
1073
/// Shared, lock-protected server configuration that can be reloaded at runtime.
///
/// Both fields are behind `Arc`, so clones share the same configuration and
/// stored path.
#[derive(Clone)]
pub struct ReloadableConfig {
    /// The live configuration.
    inner: Arc<RwLock<ServerConfig>>,
    /// File to re-read on `reload_from_stored_path`; `None` until set.
    config_path: Arc<RwLock<Option<PathBuf>>>,
}
1084
1085impl ReloadableConfig {
1086 pub fn new(config: ServerConfig) -> Self {
1088 Self {
1089 inner: Arc::new(RwLock::new(config)),
1090 config_path: Arc::new(RwLock::new(None)),
1091 }
1092 }
1093
1094 pub fn from_file(path: &str) -> ConfigResult<Self> {
1096 let config = ServerConfig::from_file(path)?;
1097 let rc = Self::new(config);
1098 *rc.config_path.write() = Some(PathBuf::from(path));
1099 Ok(rc)
1100 }
1101
1102 pub fn set_config_path(&self, path: PathBuf) {
1104 *self.config_path.write() = Some(path);
1105 }
1106
1107 pub fn read(&self) -> parking_lot::RwLockReadGuard<'_, ServerConfig> {
1109 self.inner.read()
1110 }
1111
1112 pub fn snapshot(&self) -> ServerConfig {
1114 self.inner.read().clone()
1115 }
1116
1117 pub fn reload_from_file(&self, path: &str) -> ConfigResult<ReloadReport> {
1123 let contents = std::fs::read_to_string(path)?;
1125 let new_config: ServerConfig = toml::from_str(&contents)?;
1126
1127 if let Err(e) = new_config.validate() {
1129 return Ok(ReloadReport {
1130 sections_updated: Vec::new(),
1131 sections_skipped: Vec::new(),
1132 errors: vec![format!("Validation failed: {}", e)],
1133 success: false,
1134 });
1135 }
1136
1137 self.apply_reload(new_config)
1138 }
1139
1140 pub fn reload_from_stored_path(&self) -> ConfigResult<ReloadReport> {
1142 let path = self.config_path.read().clone();
1143 match path {
1144 Some(p) => {
1145 let path_str = p.to_string_lossy().to_string();
1146 self.reload_from_file(&path_str)
1147 }
1148 None => Ok(ReloadReport {
1149 sections_updated: Vec::new(),
1150 sections_skipped: Vec::new(),
1151 errors: vec!["No config file path set for reload".to_string()],
1152 success: false,
1153 }),
1154 }
1155 }
1156
1157 fn apply_reload(&self, new_config: ServerConfig) -> ConfigResult<ReloadReport> {
1159 let mut report = ReloadReport {
1160 sections_updated: Vec::new(),
1161 sections_skipped: Vec::new(),
1162 errors: Vec::new(),
1163 success: true,
1164 };
1165
1166 let config_diff = {
1167 let current = self.inner.read();
1168 diff(¤t, &new_config)
1169 };
1170
1171 if config_diff.is_empty() {
1172 info!("Config reload: no changes detected");
1173 return Ok(report);
1174 }
1175
1176 for section in &config_diff.non_reloadable_changes {
1178 warn!(
1179 "Config reload: section '{}' changed but requires restart - skipping",
1180 section
1181 );
1182 report.sections_skipped.push(*section);
1183 }
1184
1185 if !config_diff.reloadable_changes.is_empty() {
1187 let mut current = self.inner.write();
1188
1189 for section in &config_diff.reloadable_changes {
1190 match section {
1191 ReloadableSection::Logging => {
1192 current.logging = new_config.logging.clone();
1193 info!("Config reload: updated logging settings");
1194 }
1195 ReloadableSection::Metrics => {
1196 current.metrics.export_interval_secs =
1198 new_config.metrics.export_interval_secs;
1199 current.metrics.enabled = new_config.metrics.enabled;
1200 info!("Config reload: updated metrics settings");
1201 }
1202 ReloadableSection::Compaction => {
1203 current.storage.compaction = new_config.storage.compaction.clone();
1204 info!("Config reload: updated compaction settings");
1205 }
1206 ReloadableSection::RateLimit => {
1207 current.server.max_connections = new_config.server.max_connections;
1208 info!("Config reload: updated rate limit settings");
1209 }
1210 }
1211 report.sections_updated.push(*section);
1212 }
1213 }
1214
1215 Ok(report)
1216 }
1217
1218 pub fn manual_reload(&self) -> ConfigResult<ReloadReport> {
1220 self.reload_from_stored_path()
1221 }
1222}
1223
1224impl std::fmt::Debug for ReloadableConfig {
1225 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1226 f.debug_struct("ReloadableConfig")
1227 .field("config", &*self.inner.read())
1228 .field("config_path", &*self.config_path.read())
1229 .finish()
1230 }
1231}
1232
1233#[cfg(test)]
1234mod tests {
1235 use super::*;
1236 use std::env;
1237
1238 #[test]
1239 fn test_default_config() {
1240 let config = ServerConfig::default();
1241 assert_eq!(config.server.bind_address, "0.0.0.0:7878");
1242 assert_eq!(config.storage.engine, "lsm");
1243 assert_eq!(config.logging.level, "info");
1244 }
1245
1246 #[test]
1247 fn test_config_validation() {
1248 let config = ServerConfig::default();
1249 assert!(config.validate().is_ok());
1250 }
1251
1252 #[test]
1253 fn test_invalid_bind_address() {
1254 let mut config = ServerConfig::default();
1255 config.server.bind_address = "invalid".to_string();
1256 assert!(config.validate().is_err());
1257 }
1258
1259 #[test]
1260 fn test_invalid_storage_engine() {
1261 let mut config = ServerConfig::default();
1262 config.storage.engine = "invalid".to_string();
1263 assert!(config.validate().is_err());
1264 }
1265
1266 #[test]
1267 fn test_tls_validation() {
1268 let mut config = ServerConfig::default();
1269 config.network.tls_enabled = true;
1270 assert!(config.validate().is_err()); }
1272
1273 #[test]
1274 fn test_env_overrides() {
1275 unsafe {
1276 env::set_var("AMATERS_BIND_ADDRESS", "127.0.0.1:9999");
1277 env::set_var("AMATERS_LOG_LEVEL", "debug");
1278 }
1279
1280 let mut config = ServerConfig::default();
1281 config.apply_env_overrides();
1282
1283 assert_eq!(config.server.bind_address, "127.0.0.1:9999");
1284 assert_eq!(config.logging.level, "debug");
1285
1286 unsafe {
1287 env::remove_var("AMATERS_BIND_ADDRESS");
1288 env::remove_var("AMATERS_LOG_LEVEL");
1289 }
1290 }
1291
1292 #[test]
1293 fn test_save_and_load() {
1294 let temp_dir = env::temp_dir();
1295 let config_path = temp_dir.join("test_config.toml");
1296
1297 let config = ServerConfig::default();
1298 config
1299 .save_to_file(&config_path)
1300 .expect("Failed to save config");
1301
1302 let loaded = ServerConfig::from_file(&config_path).expect("Failed to load config");
1303 assert_eq!(config.server.bind_address, loaded.server.bind_address);
1304
1305 std::fs::remove_file(&config_path).ok();
1306 }
1307
1308 fn save_temp_config(config: &ServerConfig, name: &str) -> PathBuf {
1312 let path = env::temp_dir().join(format!("amaters_reload_test_{}.toml", name));
1313 config
1314 .save_to_file(&path)
1315 .expect("Failed to save temp config");
1316 path
1317 }
1318
1319 #[test]
1320 fn test_reload_logging_section() {
1321 let config = ServerConfig::default();
1322 let path = save_temp_config(&config, "reload_logging");
1323
1324 let reloadable = ReloadableConfig::new(config);
1325 reloadable.set_config_path(path.clone());
1326
1327 let mut new_config = reloadable.snapshot();
1329 new_config.logging.level = "debug".to_string();
1330 new_config.logging.format = "json".to_string();
1331 new_config
1332 .save_to_file(&path)
1333 .expect("Failed to save modified config");
1334
1335 let report = reloadable
1336 .reload_from_file(path.to_str().expect("path should be valid utf-8"))
1337 .expect("Reload should succeed");
1338
1339 assert!(report.success);
1340 assert!(
1341 report
1342 .sections_updated
1343 .contains(&ReloadableSection::Logging)
1344 );
1345 assert_eq!(reloadable.read().logging.level, "debug");
1346 assert_eq!(reloadable.read().logging.format, "json");
1347
1348 std::fs::remove_file(&path).ok();
1349 }
1350
1351 #[test]
1352 fn test_reload_metrics_section() {
1353 let config = ServerConfig::default();
1354 let path = save_temp_config(&config, "reload_metrics");
1355
1356 let reloadable = ReloadableConfig::new(config);
1357
1358 let mut new_config = reloadable.snapshot();
1359 new_config.metrics.export_interval_secs = 120;
1360 new_config
1361 .save_to_file(&path)
1362 .expect("Failed to save modified config");
1363
1364 let report = reloadable
1365 .reload_from_file(path.to_str().expect("path should be valid utf-8"))
1366 .expect("Reload should succeed");
1367
1368 assert!(report.success);
1369 assert!(
1370 report
1371 .sections_updated
1372 .contains(&ReloadableSection::Metrics)
1373 );
1374 assert_eq!(reloadable.read().metrics.export_interval_secs, 120);
1375
1376 std::fs::remove_file(&path).ok();
1377 }
1378
/// A file that changes both a reloadable section (logging) and a
/// non-reloadable one (bind address) applies the former and skips the latter.
#[test]
fn test_non_reloadable_section_skipped() {
    let initial = ServerConfig::default();
    let cfg_path = save_temp_config(&initial, "reload_non_reloadable");
    let live = ReloadableConfig::new(initial);

    // Mix a reloadable and a non-reloadable change in the same file.
    let mut edited = live.snapshot();
    edited.server.bind_address = "127.0.0.1:9999".to_string();
    edited.logging.level = "warn".to_string();
    edited
        .save_to_file(&cfg_path)
        .expect("Failed to save modified config");

    let outcome = live
        .reload_from_file(cfg_path.to_str().expect("path should be valid utf-8"))
        .expect("Reload should succeed");

    assert!(outcome.success);

    // The logging change is reported as updated and takes effect...
    assert!(outcome.sections_updated.contains(&ReloadableSection::Logging));
    assert_eq!(live.read().logging.level, "warn");

    // ...while the bind address change is reported as skipped and the live
    // value keeps its original setting.
    assert!(outcome
        .sections_skipped
        .contains(&NonReloadableSection::BindAddress));
    assert_eq!(live.read().server.bind_address, "0.0.0.0:7878");

    std::fs::remove_file(&cfg_path).ok();
}
1417
/// A config file with an unparsable bind address is rejected: the reload
/// returns a failed report carrying errors, and the live config is unchanged.
#[test]
fn test_invalid_config_rejected() {
    let initial = ServerConfig::default();
    let cfg_path = save_temp_config(&initial, "reload_invalid");
    let live = ReloadableConfig::new(initial);

    // The file is written via `toml` + `std::fs::write` rather than
    // `save_to_file`. Presumably this bypasses any validation in
    // `save_to_file`; TODO confirm.
    let mut broken = live.snapshot();
    broken.server.bind_address = "not-an-address".to_string();
    let serialized = toml::to_string_pretty(&broken).expect("Failed to serialize config");
    std::fs::write(&cfg_path, serialized).expect("Failed to write config");

    // The failure comes back inside the report, not as an `Err`...
    let outcome = live
        .reload_from_file(cfg_path.to_str().expect("path should be valid utf-8"))
        .expect("Reload should return report");

    assert!(!outcome.success);
    assert!(!outcome.errors.is_empty());
    // ...and the live bind address is not modified.
    assert_eq!(live.read().server.bind_address, "0.0.0.0:7878");

    std::fs::remove_file(&cfg_path).ok();
}
1443
/// `diff` classifies each changed field as reloadable or non-reloadable.
///
/// Changes accumulate on `candidate`, and each step re-diffs against the
/// untouched `baseline`.
#[test]
fn test_config_diff_detection() {
    let baseline = ServerConfig::default();
    let mut candidate = baseline.clone();

    // Identical configs produce an empty diff.
    assert!(diff(&baseline, &candidate).is_empty());

    // Logging: reloadable only.
    candidate.logging.level = "error".to_string();
    let changes = diff(&baseline, &candidate);
    assert!(changes.reloadable_changes.contains(&ReloadableSection::Logging));
    assert!(!changes.has_non_reloadable_changes());

    // Bind address: non-reloadable.
    candidate.server.bind_address = "127.0.0.1:1234".to_string();
    let changes = diff(&baseline, &candidate);
    assert!(changes.has_non_reloadable_changes());
    assert!(changes
        .non_reloadable_changes
        .contains(&NonReloadableSection::BindAddress));

    // Compaction strategy: reloadable.
    candidate.storage.compaction.strategy = "tiered".to_string();
    let changes = diff(&baseline, &candidate);
    assert!(changes
        .reloadable_changes
        .contains(&ReloadableSection::Compaction));

    // max_connections is classified under the RateLimit section.
    candidate.server.max_connections = 5000;
    let changes = diff(&baseline, &candidate);
    assert!(changes.reloadable_changes.contains(&ReloadableSection::RateLimit));
}
1481
/// The reload report counts updated and skipped sections and renders both
/// via `Display`.
#[test]
fn test_reload_report_contents() {
    let initial = ServerConfig::default();
    let cfg_path = save_temp_config(&initial, "reload_report");
    let live = ReloadableConfig::new(initial);

    // Two reloadable edits (logging, metrics) plus one non-reloadable edit
    // (bind address).
    let mut edited = live.snapshot();
    edited.logging.level = "trace".to_string();
    edited.metrics.export_interval_secs = 30;
    edited.server.bind_address = "127.0.0.1:5555".to_string();
    edited
        .save_to_file(&cfg_path)
        .expect("Failed to save modified config");

    let outcome = live
        .reload_from_file(cfg_path.to_str().expect("path should be valid utf-8"))
        .expect("Reload should succeed");

    assert!(outcome.success);
    assert_eq!(outcome.sections_updated.len(), 2);
    assert_eq!(outcome.sections_skipped.len(), 1);
    assert!(outcome.errors.is_empty());

    // The `Display` output must contain both the "Updated" and "Skipped" labels.
    let rendered = outcome.to_string();
    assert!(rendered.contains("Updated"));
    assert!(rendered.contains("Skipped"));

    std::fs::remove_file(&cfg_path).ok();
}
1514
/// Readers continuously read the log level while a reload swaps in a new one.
///
/// All reader threads and the reloading thread rendezvous on a barrier, so the
/// reads actually overlap the reload. Without the barrier, the short reader
/// loops typically finish while the main thread is still doing file I/O,
/// before the reload starts. Every value a reader observes must be either the
/// pre-reload level or the new one.
#[test]
fn test_concurrent_reads_during_reload() {
    const READERS: usize = 10;
    const READS_PER_THREAD: usize = 100;

    let config = ServerConfig::default();
    let original_level = config.logging.level.clone();
    let path = save_temp_config(&config, "reload_concurrent");

    let reloadable = ReloadableConfig::new(config);

    // Write the new file before any thread can block on the barrier. If this
    // step fails, no reader is left waiting on a rendezvous that never happens.
    let mut new_config = reloadable.snapshot();
    new_config.logging.level = "debug".to_string();
    new_config
        .save_to_file(&path)
        .expect("Failed to save modified config");

    // One slot per reader, plus one for this thread, which performs the reload.
    let start = std::sync::Arc::new(std::sync::Barrier::new(READERS + 1));

    let handles: Vec<_> = (0..READERS)
        .map(|_| {
            let rc = reloadable.clone();
            let start = std::sync::Arc::clone(&start);
            let original_level = original_level.clone();
            std::thread::spawn(move || {
                start.wait();
                for _ in 0..READS_PER_THREAD {
                    let level = rc.read().logging.level.clone();
                    // A reader may see the value before or after the swap,
                    // never anything else.
                    assert!(
                        level == original_level || level == "debug",
                        "reader observed unexpected log level {:?}",
                        level
                    );
                }
            })
        })
        .collect();

    // Release the readers and reload at the same moment.
    start.wait();
    let report = reloadable
        .reload_from_file(path.to_str().expect("path should be valid utf-8"))
        .expect("Reload should succeed");
    assert!(report.success);

    for h in handles {
        h.join().expect("Reader thread should not panic");
    }

    assert_eq!(reloadable.read().logging.level, "debug");

    std::fs::remove_file(&path).ok();
}
1553
/// Repeatedly rewriting and reloading the same file applies each new log
/// level in turn.
#[test]
fn test_multiple_sequential_reloads() {
    let initial = ServerConfig::default();
    let cfg_path = save_temp_config(&initial, "reload_sequential");
    let live = ReloadableConfig::new(initial);

    for &level in ["debug", "warn", "error", "trace", "info"].iter() {
        // Start from the current live state so only the level differs.
        let mut edited = live.snapshot();
        edited.logging.level = level.to_string();
        edited
            .save_to_file(&cfg_path)
            .expect("Failed to save modified config");

        let outcome = live
            .reload_from_file(cfg_path.to_str().expect("path should be valid utf-8"))
            .expect("Reload should succeed");

        assert!(outcome.success);
        assert_eq!(live.read().logging.level, level);
    }

    std::fs::remove_file(&cfg_path).ok();
}
1579
/// Reloading from the stored path when none was ever set yields a failed
/// report with errors, rather than an `Err`.
#[test]
fn test_reload_no_stored_path() {
    // Built with `new` and no `set_config_path` call.
    let live = ReloadableConfig::new(ServerConfig::default());

    let outcome = live
        .reload_from_stored_path()
        .expect("Should return report");

    assert!(!outcome.success);
    assert!(!outcome.errors.is_empty());
}
1592
/// `ReloadableConfig::from_file` loads the settings saved on disk.
#[test]
fn test_reloadable_config_from_file() {
    let defaults = ServerConfig::default();
    let cfg_path = save_temp_config(&defaults, "reload_from_file");
    let cfg_str = cfg_path.to_str().expect("path should be valid utf-8");

    let live = ReloadableConfig::from_file(cfg_str).expect("Should load from file");
    assert_eq!(live.read().server.bind_address, "0.0.0.0:7878");

    std::fs::remove_file(&cfg_path).ok();
}
1606
/// After `set_config_path`, `manual_reload` picks up edits written to that file.
#[test]
fn test_manual_reload() {
    let initial = ServerConfig::default();
    let cfg_path = save_temp_config(&initial, "reload_manual");

    // Register the path. `manual_reload` is presumably reading from it,
    // since no path is passed to that call below.
    let live = ReloadableConfig::new(initial);
    live.set_config_path(cfg_path.clone());

    let mut edited = live.snapshot();
    edited.logging.level = "error".to_string();
    edited
        .save_to_file(&cfg_path)
        .expect("Failed to save modified config");

    let outcome = live.manual_reload().expect("Manual reload should succeed");
    assert!(outcome.success);
    assert_eq!(live.read().logging.level, "error");

    std::fs::remove_file(&cfg_path).ok();
}
1629}