1use parking_lot::RwLock;
10use serde::{Deserialize, Serialize};
11use std::net::SocketAddr;
12use std::path::{Path, PathBuf};
13use std::sync::Arc;
14use std::time::Duration;
15use thiserror::Error;
16use tracing::{info, warn};
17
/// Errors produced while loading, parsing, validating, or saving a
/// [`ServerConfig`].
#[derive(Error, Debug)]
pub enum ConfigError {
    /// An I/O error. Despite the name, this wraps any `std::io::Error`
    /// converted with `?`, including write failures from
    /// `ServerConfig::save_to_file`.
    #[error("Failed to read configuration file: {0}")]
    ReadFile(#[from] std::io::Error),

    /// The file contents could not be deserialized from TOML into the schema.
    #[error("Failed to parse TOML: {0}")]
    ParseToml(#[from] toml::de::Error),

    /// A semantic rule was violated (see `ServerConfig::validate`); also used
    /// for serialization failures in `ServerConfig::save_to_file`.
    #[error("Validation error: {0}")]
    Validation(String),

    /// A socket address failed to parse. `ServerConfig::validate` reports bad
    /// addresses as `Validation` instead, so this variant presumably only
    /// arises via the `From<AddrParseError>` conversion elsewhere.
    #[error("Invalid socket address: {0}")]
    InvalidAddress(#[from] std::net::AddrParseError),
}
33
/// Result alias whose error type is [`ConfigError`].
pub type ConfigResult<T> = Result<T, ConfigError>;
35
/// Top-level server configuration, deserialized from a TOML file.
///
/// The `server`, `storage`, `network`, `logging` and `metrics` tables are
/// required. The remaining sections fall back to their `Default`
/// implementations when omitted, and `cluster` falls back to `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
    /// Core process settings: bind address, data directory, PID file, limits.
    pub server: ServerSettings,

    /// Storage engine, WAL and compaction settings.
    pub storage: StorageSettings,

    /// TLS and connection-level settings.
    pub network: NetworkSettings,

    /// Optional cluster membership; `None` when the `[cluster]` table is absent.
    #[serde(default)]
    pub cluster: Option<ClusterSettings>,

    /// Log level, format and file output.
    pub logging: LoggingSettings,

    /// Metrics export settings.
    pub metrics: MetricsSettings,

    /// Authentication settings (mTLS, JWT, API keys); disabled by default.
    #[serde(default)]
    pub auth: AuthSettings,

    /// Authorization (roles / policies / audit) settings.
    #[serde(default)]
    pub authz: AuthorizationSettings,

    /// Connection, request-rate, memory and query limits.
    #[serde(default)]
    pub resource_limits: ResourceLimits,

    /// Circuit cache sizing.
    #[serde(default)]
    pub circuit_cache: CircuitCacheSettings,

    /// Request, idle, shutdown and keep-alive timeouts, in milliseconds.
    #[serde(default)]
    pub timeouts: TimeoutConfig,
}
78
/// Core process settings (`[server]` table).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerSettings {
    /// Address to bind, as `host:port`; `validate` requires it to parse as a
    /// `SocketAddr`. Changing it requires a restart.
    pub bind_address: String,

    /// Data directory; `validate` rejects an empty path. Changing it requires
    /// a restart.
    pub data_dir: PathBuf,

    /// PID file location. Defaults to `/var/run/amaters-server.pid`.
    #[serde(default = "default_pid_file")]
    pub pid_file: PathBuf,

    /// Maximum number of connections. Defaults to 1000. Can be changed by a
    /// hot reload (reported as the `rate_limit` section).
    #[serde(default = "default_max_connections")]
    pub max_connections: usize,

    /// Shutdown timeout in seconds (exposed as a `Duration` by
    /// `ServerConfig::shutdown_timeout`). Defaults to 30.
    #[serde(default = "default_shutdown_timeout")]
    pub shutdown_timeout_secs: u64,
}
100
/// Storage-engine settings (`[storage]` table).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageSettings {
    /// Engine name; `validate` accepts only `"memory"` or `"lsm"`. Defaults to
    /// `"lsm"`. Changing it requires a restart.
    #[serde(default = "default_storage_engine")]
    pub engine: String,

    /// Write-ahead-log settings (`[storage.wal]`).
    #[serde(default)]
    pub wal: WalSettings,

    /// Memtable size in MB. Defaults to 64.
    #[serde(default = "default_memtable_size")]
    pub memtable_size_mb: usize,

    /// Block cache size in MB. Defaults to 256.
    #[serde(default = "default_block_cache_size")]
    pub block_cache_size_mb: usize,

    /// Compaction tuning (`[storage.compaction]`); can be changed by a hot
    /// reload.
    #[serde(default)]
    pub compaction: CompactionSettings,
}
124
/// Write-ahead-log settings (`[storage.wal]` table).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WalSettings {
    /// Whether the WAL is enabled. Defaults to `true`.
    #[serde(default = "default_true")]
    pub enabled: bool,

    /// WAL directory. Defaults to the relative path `wal`.
    /// NOTE(review): presumably resolved against `server.data_dir` — confirm.
    #[serde(default = "default_wal_dir")]
    pub dir: PathBuf,

    /// Segment size in MB. Defaults to 64.
    #[serde(default = "default_wal_segment_size")]
    pub segment_size_mb: usize,

    /// Sync mode name. Defaults to `"interval"`; not checked by `validate`.
    #[serde(default = "default_sync_mode")]
    pub sync_mode: String,
}
144
/// Compaction tuning (`[storage.compaction]` table). Every field is compared by
/// `diff`, and changes are applied by a hot reload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompactionSettings {
    /// Strategy name. Defaults to `"leveled"`; not checked by `validate`.
    #[serde(default = "default_compaction_strategy")]
    pub strategy: String,

    /// Number of levels. Defaults to 7.
    #[serde(default = "default_num_levels")]
    pub num_levels: usize,

    /// Level multiplier (presumably the size ratio between adjacent levels).
    /// Defaults to 10.
    #[serde(default = "default_level_multiplier")]
    pub level_multiplier: usize,

    /// Maximum concurrent compactions. Defaults to 4.
    #[serde(default = "default_max_compactions")]
    pub max_concurrent: usize,
}
164
/// TLS and connection settings (`[network]` table).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkSettings {
    /// Enables TLS. Defaults to `false`. When `true`, `validate` requires
    /// `tls_cert` and `tls_key`.
    #[serde(default = "default_false")]
    pub tls_enabled: bool,

    /// Server certificate path. Changing it requires a restart.
    pub tls_cert: Option<PathBuf>,

    /// Server private-key path. Changing it requires a restart.
    pub tls_key: Option<PathBuf>,

    /// CA file; `validate` requires it when TLS is on and
    /// `require_client_cert` is set.
    pub tls_ca: Option<PathBuf>,

    /// Require clients to present a certificate. Defaults to `false`.
    #[serde(default = "default_false")]
    pub require_client_cert: bool,

    /// Connection timeout in seconds. Defaults to 30.
    #[serde(default = "default_connection_timeout")]
    pub connection_timeout_secs: u64,

    /// Keep-alive interval in seconds. Defaults to 60.
    #[serde(default = "default_keepalive_interval")]
    pub keepalive_interval_secs: u64,
}
193
/// Cluster membership (`[cluster]` table). The whole table is optional in
/// `ServerConfig`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterSettings {
    /// Defaults to `true` when the table is present. When enabled, `validate`
    /// requires a non-empty `peers` list.
    #[serde(default = "default_true")]
    pub enabled: bool,

    /// This node's id; required. Changing it requires a restart.
    pub node_id: u64,

    /// Peer list; required. The entry format is not validated here.
    pub peers: Vec<String>,

    /// Heartbeat interval in milliseconds. Defaults to 100.
    #[serde(default = "default_heartbeat_interval")]
    pub heartbeat_interval_ms: u64,

    /// Election timeout in milliseconds. Defaults to 300.
    #[serde(default = "default_election_timeout")]
    pub election_timeout_ms: u64,
}
215
/// Logging settings (`[logging]` table). Every field is compared by `diff`, so
/// the whole section can be hot-reloaded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoggingSettings {
    /// Log level; `validate` accepts trace/debug/info/warn/error
    /// case-insensitively. Defaults to `"info"`.
    #[serde(default = "default_log_level")]
    pub level: String,

    /// Output format name. Defaults to `"pretty"`; not checked by `validate`.
    #[serde(default = "default_log_format")]
    pub format: String,

    /// Also write logs to a file. Defaults to `false`.
    #[serde(default = "default_false")]
    pub file_enabled: bool,

    /// Log file path (presumably used when `file_enabled` is set).
    pub file_path: Option<PathBuf>,

    /// Log-file rotation settings.
    #[serde(default)]
    pub rotation: LogRotationSettings,
}
238
/// Log-file rotation settings (`[logging.rotation]` table).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogRotationSettings {
    /// Whether rotation is enabled. Defaults to `true`.
    #[serde(default = "default_true")]
    pub enabled: bool,

    /// Size threshold in MB. Defaults to 100.
    #[serde(default = "default_log_max_size")]
    pub max_size_mb: usize,

    /// Number of rotated files to keep. Defaults to 10.
    #[serde(default = "default_log_max_backups")]
    pub max_backups: usize,
}
254
/// Metrics export settings (`[metrics]` table).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricsSettings {
    /// Enables metrics. Defaults to `true`; can be changed by a hot reload.
    #[serde(default = "default_true")]
    pub enabled: bool,

    /// Address as `host:port`. `validate` requires it to parse as a
    /// `SocketAddr`, even when metrics are disabled. Defaults to
    /// `127.0.0.1:9090`. Not compared by `diff`, so a reload ignores edits.
    #[serde(default = "default_metrics_address")]
    pub bind_address: String,

    /// Export interval in seconds. Defaults to 60; can be changed by a hot
    /// reload.
    #[serde(default = "default_metrics_interval")]
    pub export_interval_secs: u64,
}
270
/// Authentication settings (`[auth]` table).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthSettings {
    /// Master switch. Defaults to `false`; the auth checks in `validate` run
    /// only when this is `true`.
    #[serde(default = "default_false")]
    pub enabled: bool,

    /// Method names (`"mtls"`, `"jwt"`, `"api_key"`). `validate` counts a
    /// method only if it is listed here *and* its own section is enabled.
    /// Defaults to `["mtls"]`.
    #[serde(default = "default_auth_methods")]
    pub methods: Vec<String>,

    /// Mutual-TLS settings (`[auth.mtls]`).
    #[serde(default)]
    pub mtls: MtlsSettings,

    /// JWT settings (`[auth.jwt]`).
    #[serde(default)]
    pub jwt: JwtSettings,

    /// API-key settings (`[auth.api_key]`).
    #[serde(default)]
    pub api_key: ApiKeySettings,

    /// Reject unauthenticated requests (per the name; enforced outside this
    /// module). Defaults to `true`.
    #[serde(default = "default_true")]
    pub reject_unauthenticated: bool,
}
298
/// Mutual-TLS authentication settings (`[auth.mtls]` table).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MtlsSettings {
    /// Defaults to `false`. When enabled and `auth.enabled` is set, `validate`
    /// requires `ca_certs_dir`.
    #[serde(default = "default_false")]
    pub enabled: bool,

    /// Directory of trusted CA certificates.
    pub ca_certs_dir: Option<PathBuf>,

    /// Optional certificate revocation list path.
    pub crl_path: Option<PathBuf>,

    /// Verify the certificate common name (per the name). Defaults to `true`.
    #[serde(default = "default_true")]
    pub verify_cn: bool,

    /// Allowed organizations; defaults to empty.
    /// NOTE(review): presumably empty means "no restriction" — confirm.
    #[serde(default)]
    pub allowed_organizations: Vec<String>,
}
320
321#[derive(Debug, Clone, Serialize, Deserialize)]
323pub struct JwtSettings {
324 #[serde(default = "default_false")]
326 pub enabled: bool,
327
328 pub secret: Option<String>,
330
331 pub public_key_path: Option<PathBuf>,
333
334 pub ec_public_key_path: Option<PathBuf>,
336
337 pub ed_public_key_path: Option<PathBuf>,
339
340 #[serde(default = "default_jwt_algorithm")]
342 pub algorithm: String,
343
344 #[serde(default = "default_jwt_expiration")]
346 pub expiration_secs: u64,
347
348 pub issuer: Option<String>,
350
351 pub audience: Option<String>,
353}
354
/// API-key authentication settings (`[auth.api_key]` table).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiKeySettings {
    /// Defaults to `false`. When enabled and `auth.enabled` is set, `validate`
    /// requires `keys_file`.
    #[serde(default = "default_false")]
    pub enabled: bool,

    /// File holding the accepted keys.
    pub keys_file: Option<PathBuf>,

    /// Request header carrying the key. Defaults to `X-API-Key`.
    #[serde(default = "default_api_key_header")]
    pub header_name: String,

    /// Defaults to `true`.
    /// NOTE(review): presumably "stored keys are hashed" — confirm with the
    /// key-loading code.
    #[serde(default = "default_true")]
    pub hash_keys: bool,
}
373
/// Authorization settings (`[authz]` table).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthorizationSettings {
    /// Defaults to `true`. When enabled, `validate` checks `default_mode`.
    #[serde(default = "default_true")]
    pub enabled: bool,

    /// Role assigned by default. Defaults to `"user"`.
    #[serde(default = "default_user_role")]
    pub default_role: String,

    /// Optional roles definition file.
    pub roles_file: Option<PathBuf>,

    /// Optional policies file.
    pub policies_file: Option<PathBuf>,

    /// Per-collection permissions switch (per the name). Defaults to `true`.
    #[serde(default = "default_true")]
    pub collection_permissions: bool,

    /// `"deny-by-default"` (the default) or `"allow-by-default"`; any other
    /// value is rejected by `validate` when authorization is enabled.
    #[serde(default = "default_permission_mode")]
    pub default_mode: String,

    /// Audit switch. Defaults to `true`.
    #[serde(default = "default_true")]
    pub audit_enabled: bool,

    /// Optional audit log path.
    pub audit_log_path: Option<PathBuf>,
}
406
// Serde `default = "..."` providers. `Default` impls reuse them so that a
// missing TOML field and a programmatic default always agree.

// ---- [server] ----

/// Default PID file location.
fn default_pid_file() -> PathBuf {
    "/var/run/amaters-server.pid".into()
}

/// Default connection cap.
fn default_max_connections() -> usize {
    1_000
}

/// Default shutdown timeout, in seconds.
fn default_shutdown_timeout() -> u64 {
    30
}

// ---- [storage] ----

/// Default storage engine name.
fn default_storage_engine() -> String {
    String::from("lsm")
}

/// Default memtable size, in MB.
fn default_memtable_size() -> usize {
    64
}

/// Default block cache size, in MB.
fn default_block_cache_size() -> usize {
    256
}

/// Default WAL directory (relative path).
fn default_wal_dir() -> PathBuf {
    "wal".into()
}

/// Default WAL segment size, in MB.
fn default_wal_segment_size() -> usize {
    64
}

/// Default WAL sync mode name.
fn default_sync_mode() -> String {
    String::from("interval")
}

/// Default compaction strategy name.
fn default_compaction_strategy() -> String {
    String::from("leveled")
}

/// Default number of compaction levels.
fn default_num_levels() -> usize {
    7
}

/// Default level multiplier.
fn default_level_multiplier() -> usize {
    10
}

/// Default number of concurrent compactions.
fn default_max_compactions() -> usize {
    4
}

// ---- [network] ----

/// Default connection timeout, in seconds.
fn default_connection_timeout() -> u64 {
    30
}

/// Default keep-alive interval, in seconds.
fn default_keepalive_interval() -> u64 {
    60
}

// ---- [cluster] ----

/// Default heartbeat interval, in milliseconds.
fn default_heartbeat_interval() -> u64 {
    100
}

/// Default election timeout, in milliseconds.
fn default_election_timeout() -> u64 {
    300
}

// ---- [logging] ----

/// Default log level.
fn default_log_level() -> String {
    String::from("info")
}

/// Default log output format.
fn default_log_format() -> String {
    String::from("pretty")
}

/// Default rotation size threshold, in MB.
fn default_log_max_size() -> usize {
    100
}

/// Default number of rotated log files kept.
fn default_log_max_backups() -> usize {
    10
}

// ---- [metrics] ----

/// Default metrics bind address.
fn default_metrics_address() -> String {
    String::from("127.0.0.1:9090")
}

/// Default metrics export interval, in seconds.
fn default_metrics_interval() -> u64 {
    60
}

// ---- shared boolean providers ----

/// Serde default for flags that start enabled.
fn default_true() -> bool {
    true
}

/// Serde default for flags that start disabled.
fn default_false() -> bool {
    false
}

// ---- [auth] / [authz] ----

/// Default authentication methods list.
fn default_auth_methods() -> Vec<String> {
    vec![String::from("mtls")]
}

/// Default JWT signing algorithm.
fn default_jwt_algorithm() -> String {
    String::from("HS256")
}

/// Default JWT lifetime, in seconds (one hour).
fn default_jwt_expiration() -> u64 {
    3600
}

/// Default API-key request header.
fn default_api_key_header() -> String {
    String::from("X-API-Key")
}

/// Default role for authorization.
fn default_user_role() -> String {
    String::from("user")
}

/// Default authorization mode.
fn default_permission_mode() -> String {
    String::from("deny-by-default")
}
531
// ---- [resource_limits] ----

/// Default per-client connection cap.
fn default_max_connections_per_client() -> usize {
    10
}

/// Default global request-rate cap, in requests per second.
fn default_max_rps_global() -> u64 {
    10_000
}

/// Default cap on concurrently active queries.
fn default_max_active_queries() -> usize {
    1_000
}

// ---- [circuit_cache] ----

/// Default circuit cache capacity, in entries.
fn default_circuit_cache_max_entries() -> usize {
    1_000
}

/// Default circuit cache TTL, in seconds (five minutes).
fn default_circuit_cache_ttl_secs() -> u64 {
    5 * 60
}

// ---- [timeouts] (all milliseconds) ----

/// Default request timeout: 30 s.
fn default_request_timeout_ms() -> u64 {
    30 * 1_000
}

/// Default idle-connection timeout: 60 s.
fn default_idle_connection_timeout_ms() -> u64 {
    60 * 1_000
}

/// Default graceful-shutdown timeout: 5 s.
fn default_graceful_shutdown_timeout_ms() -> u64 {
    5 * 1_000
}

/// Default keep-alive interval: 15 s.
fn default_keep_alive_interval_ms() -> u64 {
    15 * 1_000
}
567
568#[derive(Debug, Clone, Serialize, Deserialize)]
570pub struct ResourceLimits {
571 #[serde(default = "default_max_connections_per_client")]
573 pub max_connections_per_client: usize,
574 #[serde(default = "default_max_rps_global")]
576 pub max_requests_per_second_global: u64,
577 #[serde(default)]
579 pub max_memory_bytes: Option<u64>,
580 #[serde(default = "default_max_active_queries")]
582 pub max_active_queries: usize,
583}
584
585impl Default for ResourceLimits {
586 fn default() -> Self {
587 Self {
588 max_connections_per_client: default_max_connections_per_client(),
589 max_requests_per_second_global: default_max_rps_global(),
590 max_memory_bytes: None,
591 max_active_queries: default_max_active_queries(),
592 }
593 }
594}
595
596#[derive(Debug, Clone, Serialize, Deserialize)]
598pub struct CircuitCacheSettings {
599 #[serde(default = "default_circuit_cache_max_entries")]
601 pub max_entries: usize,
602 #[serde(default = "default_circuit_cache_ttl_secs")]
604 pub ttl_secs: u64,
605}
606
607impl Default for CircuitCacheSettings {
608 fn default() -> Self {
609 Self {
610 max_entries: default_circuit_cache_max_entries(),
611 ttl_secs: default_circuit_cache_ttl_secs(),
612 }
613 }
614}
615
616#[derive(Debug, Clone, Serialize, Deserialize)]
618pub struct TimeoutConfig {
619 #[serde(default = "default_request_timeout_ms")]
621 pub request_timeout_ms: u64,
622 #[serde(default = "default_idle_connection_timeout_ms")]
624 pub idle_connection_timeout_ms: u64,
625 #[serde(default = "default_graceful_shutdown_timeout_ms")]
627 pub graceful_shutdown_timeout_ms: u64,
628 #[serde(default = "default_keep_alive_interval_ms")]
630 pub keep_alive_interval_ms: u64,
631}
632
633impl Default for TimeoutConfig {
634 fn default() -> Self {
635 Self {
636 request_timeout_ms: default_request_timeout_ms(),
637 idle_connection_timeout_ms: default_idle_connection_timeout_ms(),
638 graceful_shutdown_timeout_ms: default_graceful_shutdown_timeout_ms(),
639 keep_alive_interval_ms: default_keep_alive_interval_ms(),
640 }
641 }
642}
643
644impl Default for ServerConfig {
645 fn default() -> Self {
646 Self {
647 server: ServerSettings {
648 bind_address: "0.0.0.0:7878".to_string(),
649 data_dir: PathBuf::from("./data"),
650 pid_file: default_pid_file(),
651 max_connections: default_max_connections(),
652 shutdown_timeout_secs: default_shutdown_timeout(),
653 },
654 storage: StorageSettings {
655 engine: default_storage_engine(),
656 wal: WalSettings::default(),
657 memtable_size_mb: default_memtable_size(),
658 block_cache_size_mb: default_block_cache_size(),
659 compaction: CompactionSettings::default(),
660 },
661 network: NetworkSettings {
662 tls_enabled: false,
663 tls_cert: None,
664 tls_key: None,
665 tls_ca: None,
666 require_client_cert: false,
667 connection_timeout_secs: default_connection_timeout(),
668 keepalive_interval_secs: default_keepalive_interval(),
669 },
670 cluster: None,
671 logging: LoggingSettings {
672 level: default_log_level(),
673 format: default_log_format(),
674 file_enabled: false,
675 file_path: None,
676 rotation: LogRotationSettings::default(),
677 },
678 metrics: MetricsSettings {
679 enabled: true,
680 bind_address: default_metrics_address(),
681 export_interval_secs: default_metrics_interval(),
682 },
683 auth: AuthSettings::default(),
684 authz: AuthorizationSettings::default(),
685 resource_limits: ResourceLimits::default(),
686 circuit_cache: CircuitCacheSettings::default(),
687 timeouts: TimeoutConfig::default(),
688 }
689 }
690}
691
692impl Default for WalSettings {
693 fn default() -> Self {
694 Self {
695 enabled: true,
696 dir: default_wal_dir(),
697 segment_size_mb: default_wal_segment_size(),
698 sync_mode: default_sync_mode(),
699 }
700 }
701}
702
703impl Default for CompactionSettings {
704 fn default() -> Self {
705 Self {
706 strategy: default_compaction_strategy(),
707 num_levels: default_num_levels(),
708 level_multiplier: default_level_multiplier(),
709 max_concurrent: default_max_compactions(),
710 }
711 }
712}
713
714impl Default for LogRotationSettings {
715 fn default() -> Self {
716 Self {
717 enabled: true,
718 max_size_mb: default_log_max_size(),
719 max_backups: default_log_max_backups(),
720 }
721 }
722}
723
724impl Default for AuthSettings {
725 fn default() -> Self {
726 Self {
727 enabled: false,
728 methods: default_auth_methods(),
729 mtls: MtlsSettings::default(),
730 jwt: JwtSettings::default(),
731 api_key: ApiKeySettings::default(),
732 reject_unauthenticated: true,
733 }
734 }
735}
736
737impl Default for MtlsSettings {
738 fn default() -> Self {
739 Self {
740 enabled: false,
741 ca_certs_dir: None,
742 crl_path: None,
743 verify_cn: true,
744 allowed_organizations: Vec::new(),
745 }
746 }
747}
748
749impl Default for JwtSettings {
750 fn default() -> Self {
751 Self {
752 enabled: false,
753 secret: None,
754 public_key_path: None,
755 ec_public_key_path: None,
756 ed_public_key_path: None,
757 algorithm: default_jwt_algorithm(),
758 expiration_secs: default_jwt_expiration(),
759 issuer: None,
760 audience: None,
761 }
762 }
763}
764
765impl Default for ApiKeySettings {
766 fn default() -> Self {
767 Self {
768 enabled: false,
769 keys_file: None,
770 header_name: default_api_key_header(),
771 hash_keys: true,
772 }
773 }
774}
775
776impl Default for AuthorizationSettings {
777 fn default() -> Self {
778 Self {
779 enabled: true,
780 default_role: default_user_role(),
781 roles_file: None,
782 policies_file: None,
783 collection_permissions: true,
784 default_mode: default_permission_mode(),
785 audit_enabled: true,
786 audit_log_path: None,
787 }
788 }
789}
790
791impl ServerConfig {
792 pub fn from_file(path: impl AsRef<Path>) -> ConfigResult<Self> {
794 let contents = std::fs::read_to_string(path)?;
795 let config: ServerConfig = toml::from_str(&contents)?;
796 config.validate()?;
797 Ok(config)
798 }
799
800 pub fn from_file_with_env(path: impl AsRef<Path>) -> ConfigResult<Self> {
802 let mut config = Self::from_file(path)?;
803 config.apply_env_overrides();
804 config.validate()?;
805 Ok(config)
806 }
807
808 pub fn apply_env_overrides(&mut self) {
810 if let Ok(bind) = std::env::var("AMATERS_BIND_ADDRESS") {
811 self.server.bind_address = bind;
812 }
813 if let Ok(data_dir) = std::env::var("AMATERS_DATA_DIR") {
814 self.server.data_dir = PathBuf::from(data_dir);
815 }
816 if let Ok(log_level) = std::env::var("AMATERS_LOG_LEVEL") {
817 self.logging.level = log_level;
818 }
819 if let Ok(tls_enabled) = std::env::var("AMATERS_TLS_ENABLED") {
820 self.network.tls_enabled = tls_enabled.parse().unwrap_or(false);
821 }
822 }
823
824 pub fn validate(&self) -> ConfigResult<()> {
826 let _: SocketAddr = self
828 .server
829 .bind_address
830 .parse()
831 .map_err(|e| ConfigError::Validation(format!("Invalid bind address: {}", e)))?;
832
833 if self.server.data_dir.as_os_str().is_empty() {
835 return Err(ConfigError::Validation(
836 "Data directory cannot be empty".to_string(),
837 ));
838 }
839
840 match self.storage.engine.as_str() {
842 "memory" | "lsm" => {}
843 other => {
844 return Err(ConfigError::Validation(format!(
845 "Invalid storage engine: {}. Must be 'memory' or 'lsm'",
846 other
847 )));
848 }
849 }
850
851 if self.network.tls_enabled {
853 if self.network.tls_cert.is_none() {
854 return Err(ConfigError::Validation(
855 "TLS enabled but no certificate file specified".to_string(),
856 ));
857 }
858 if self.network.tls_key.is_none() {
859 return Err(ConfigError::Validation(
860 "TLS enabled but no key file specified".to_string(),
861 ));
862 }
863 if self.network.require_client_cert && self.network.tls_ca.is_none() {
864 return Err(ConfigError::Validation(
865 "Client certificate required but no CA file specified".to_string(),
866 ));
867 }
868 }
869
870 if let Some(ref cluster) = self.cluster {
872 if cluster.enabled && cluster.peers.is_empty() {
873 return Err(ConfigError::Validation(
874 "Cluster enabled but no peers specified".to_string(),
875 ));
876 }
877 }
878
879 match self.logging.level.to_lowercase().as_str() {
881 "trace" | "debug" | "info" | "warn" | "error" => {}
882 other => {
883 return Err(ConfigError::Validation(format!(
884 "Invalid log level: {}. Must be one of: trace, debug, info, warn, error",
885 other
886 )));
887 }
888 }
889
890 let _: SocketAddr = self
892 .metrics
893 .bind_address
894 .parse()
895 .map_err(|e| ConfigError::Validation(format!("Invalid metrics address: {}", e)))?;
896
897 if self.auth.enabled {
899 let has_enabled_method = (self.auth.mtls.enabled
901 && self.auth.methods.contains(&"mtls".to_string()))
902 || (self.auth.jwt.enabled && self.auth.methods.contains(&"jwt".to_string()))
903 || (self.auth.api_key.enabled
904 && self.auth.methods.contains(&"api_key".to_string()));
905
906 if !has_enabled_method {
907 return Err(ConfigError::Validation(
908 "Authentication enabled but no valid auth methods configured".to_string(),
909 ));
910 }
911
912 if self.auth.jwt.enabled {
914 match self.auth.jwt.algorithm.as_str() {
915 "HS256" => {
916 if self.auth.jwt.secret.is_none() {
917 return Err(ConfigError::Validation(
918 "JWT HS256 enabled but no secret key provided".to_string(),
919 ));
920 }
921 }
922 "RS256" => {
923 if self.auth.jwt.public_key_path.is_none() {
924 return Err(ConfigError::Validation(
925 "JWT RS256 enabled but no public key path provided".to_string(),
926 ));
927 }
928 }
929 other => {
930 return Err(ConfigError::Validation(format!(
931 "Invalid JWT algorithm: {}. Supported: HS256, RS256",
932 other
933 )));
934 }
935 }
936 }
937
938 if self.auth.api_key.enabled && self.auth.api_key.keys_file.is_none() {
940 return Err(ConfigError::Validation(
941 "API key auth enabled but no keys file specified".to_string(),
942 ));
943 }
944
945 if self.auth.mtls.enabled && self.auth.mtls.ca_certs_dir.is_none() {
947 return Err(ConfigError::Validation(
948 "mTLS enabled but no CA certificates directory specified".to_string(),
949 ));
950 }
951 }
952
953 if self.authz.enabled {
955 match self.authz.default_mode.as_str() {
956 "deny-by-default" | "allow-by-default" => {}
957 other => {
958 return Err(ConfigError::Validation(format!(
959 "Invalid authorization default mode: {}. Must be 'deny-by-default' or 'allow-by-default'",
960 other
961 )));
962 }
963 }
964 }
965
966 if self.timeouts.request_timeout_ms >= self.timeouts.idle_connection_timeout_ms {
968 return Err(ConfigError::Validation(
969 "request_timeout_ms must be less than idle_connection_timeout_ms".to_string(),
970 ));
971 }
972
973 Ok(())
974 }
975
976 pub fn shutdown_timeout(&self) -> Duration {
978 Duration::from_secs(self.server.shutdown_timeout_secs)
979 }
980
981 pub fn connection_timeout(&self) -> Duration {
983 Duration::from_secs(self.network.connection_timeout_secs)
984 }
985
986 pub fn keepalive_interval(&self) -> Duration {
988 Duration::from_secs(self.network.keepalive_interval_secs)
989 }
990
991 pub fn save_to_file(&self, path: impl AsRef<Path>) -> ConfigResult<()> {
993 let contents = toml::to_string_pretty(self)
994 .map_err(|e| ConfigError::Validation(format!("Failed to serialize config: {}", e)))?;
995 std::fs::write(path, contents)?;
996 Ok(())
997 }
998
999 pub fn example() -> Self {
1001 Self::default()
1002 }
1003}
1004
/// Configuration sections that a hot reload applies to the running server.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReloadableSection {
    /// The `[logging]` section, including rotation.
    Logging,
    /// `metrics.enabled` and `metrics.export_interval_secs`.
    Metrics,
    /// The `[storage.compaction]` section.
    Compaction,
    /// `server.max_connections`.
    RateLimit,
}

impl std::fmt::Display for ReloadableSection {
    /// Writes the section's snake_case name. Width and alignment flags are
    /// ignored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Logging => "logging",
            Self::Metrics => "metrics",
            Self::Compaction => "compaction",
            Self::RateLimit => "rate_limit",
        };
        f.write_str(name)
    }
}
1028
/// Settings whose changes are detected on reload but only take effect after a
/// restart.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NonReloadableSection {
    /// `server.bind_address`.
    BindAddress,
    /// A listening port. Not produced by `diff` in this module.
    Port,
    /// `network.tls_cert`.
    TlsCertPath,
    /// `network.tls_key`.
    TlsKeyPath,
    /// `storage.engine`.
    StorageEngine,
    /// `server.data_dir`.
    DataDir,
    /// `cluster.node_id`.
    ClusterNodeId,
}

impl std::fmt::Display for NonReloadableSection {
    /// Writes the setting's snake_case name. Width and alignment flags are
    /// ignored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::BindAddress => "bind_address",
            Self::Port => "port",
            Self::TlsCertPath => "tls_cert_path",
            Self::TlsKeyPath => "tls_key_path",
            Self::StorageEngine => "storage_engine",
            Self::DataDir => "data_dir",
            Self::ClusterNodeId => "cluster_node_id",
        };
        f.write_str(name)
    }
}
1061
1062#[derive(Debug, Clone, Default)]
1064pub struct ConfigDiff {
1065 pub reloadable_changes: Vec<ReloadableSection>,
1067 pub non_reloadable_changes: Vec<NonReloadableSection>,
1069}
1070
1071impl ConfigDiff {
1072 pub fn is_empty(&self) -> bool {
1074 self.reloadable_changes.is_empty() && self.non_reloadable_changes.is_empty()
1075 }
1076
1077 pub fn has_non_reloadable_changes(&self) -> bool {
1079 !self.non_reloadable_changes.is_empty()
1080 }
1081}
1082
1083pub fn diff(old: &ServerConfig, new: &ServerConfig) -> ConfigDiff {
1085 let mut result = ConfigDiff::default();
1086
1087 if old.logging.level != new.logging.level
1089 || old.logging.format != new.logging.format
1090 || old.logging.file_enabled != new.logging.file_enabled
1091 || old.logging.file_path != new.logging.file_path
1092 || old.logging.rotation.enabled != new.logging.rotation.enabled
1093 || old.logging.rotation.max_size_mb != new.logging.rotation.max_size_mb
1094 || old.logging.rotation.max_backups != new.logging.rotation.max_backups
1095 {
1096 result.reloadable_changes.push(ReloadableSection::Logging);
1097 }
1098
1099 if old.metrics.export_interval_secs != new.metrics.export_interval_secs
1100 || old.metrics.enabled != new.metrics.enabled
1101 {
1102 result.reloadable_changes.push(ReloadableSection::Metrics);
1103 }
1104
1105 if old.storage.compaction.strategy != new.storage.compaction.strategy
1106 || old.storage.compaction.num_levels != new.storage.compaction.num_levels
1107 || old.storage.compaction.level_multiplier != new.storage.compaction.level_multiplier
1108 || old.storage.compaction.max_concurrent != new.storage.compaction.max_concurrent
1109 {
1110 result
1111 .reloadable_changes
1112 .push(ReloadableSection::Compaction);
1113 }
1114
1115 if old.server.max_connections != new.server.max_connections {
1116 result.reloadable_changes.push(ReloadableSection::RateLimit);
1117 }
1118
1119 if old.server.bind_address != new.server.bind_address {
1121 result
1122 .non_reloadable_changes
1123 .push(NonReloadableSection::BindAddress);
1124 }
1125
1126 if old.server.data_dir != new.server.data_dir {
1127 result
1128 .non_reloadable_changes
1129 .push(NonReloadableSection::DataDir);
1130 }
1131
1132 if old.storage.engine != new.storage.engine {
1133 result
1134 .non_reloadable_changes
1135 .push(NonReloadableSection::StorageEngine);
1136 }
1137
1138 if old.network.tls_cert != new.network.tls_cert {
1139 result
1140 .non_reloadable_changes
1141 .push(NonReloadableSection::TlsCertPath);
1142 }
1143
1144 if old.network.tls_key != new.network.tls_key {
1145 result
1146 .non_reloadable_changes
1147 .push(NonReloadableSection::TlsKeyPath);
1148 }
1149
1150 if let (Some(old_cluster), Some(new_cluster)) = (&old.cluster, &new.cluster) {
1151 if old_cluster.node_id != new_cluster.node_id {
1152 result
1153 .non_reloadable_changes
1154 .push(NonReloadableSection::ClusterNodeId);
1155 }
1156 }
1157
1158 result
1159}
1160
/// Outcome of a configuration reload attempt.
#[derive(Debug, Clone)]
pub struct ReloadReport {
    /// Sections whose new values were applied, in the order they were applied.
    pub sections_updated: Vec<ReloadableSection>,
    /// Settings that changed but were left untouched because they need a
    /// restart.
    pub sections_skipped: Vec<NonReloadableSection>,
    /// Human-readable errors, e.g. a validation failure or a missing stored
    /// path.
    pub errors: Vec<String>,
    /// `false` when the reload was rejected before anything was applied.
    /// Skipped restart-only changes alone leave this `true`.
    pub success: bool,
}
1173
1174impl std::fmt::Display for ReloadReport {
1175 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1176 if self.success {
1177 write!(f, "Config reload successful. ")?;
1178 } else {
1179 write!(f, "Config reload failed. ")?;
1180 }
1181 if !self.sections_updated.is_empty() {
1182 write!(f, "Updated: ")?;
1183 for (i, s) in self.sections_updated.iter().enumerate() {
1184 if i > 0 {
1185 write!(f, ", ")?;
1186 }
1187 write!(f, "{}", s)?;
1188 }
1189 write!(f, ". ")?;
1190 }
1191 if !self.sections_skipped.is_empty() {
1192 write!(f, "Skipped (restart required): ")?;
1193 for (i, s) in self.sections_skipped.iter().enumerate() {
1194 if i > 0 {
1195 write!(f, ", ")?;
1196 }
1197 write!(f, "{}", s)?;
1198 }
1199 write!(f, ". ")?;
1200 }
1201 for err in &self.errors {
1202 write!(f, "Error: {}. ", err)?;
1203 }
1204 Ok(())
1205 }
1206}
1207
/// A [`ServerConfig`] shared behind `Arc<RwLock<_>>`, so every clone observes
/// hot reloads, together with the file path used by path-less reloads.
#[derive(Clone)]
pub struct ReloadableConfig {
    /// The live configuration; reloads overwrite individual sections in place.
    inner: Arc<RwLock<ServerConfig>>,
    /// File used by `reload_from_stored_path` / `manual_reload`; `None` until
    /// set.
    config_path: Arc<RwLock<Option<PathBuf>>>,
}
1218
1219impl ReloadableConfig {
1220 pub fn new(config: ServerConfig) -> Self {
1222 Self {
1223 inner: Arc::new(RwLock::new(config)),
1224 config_path: Arc::new(RwLock::new(None)),
1225 }
1226 }
1227
1228 pub fn from_file(path: &str) -> ConfigResult<Self> {
1230 let config = ServerConfig::from_file(path)?;
1231 let rc = Self::new(config);
1232 *rc.config_path.write() = Some(PathBuf::from(path));
1233 Ok(rc)
1234 }
1235
1236 pub fn set_config_path(&self, path: PathBuf) {
1238 *self.config_path.write() = Some(path);
1239 }
1240
1241 pub fn read(&self) -> parking_lot::RwLockReadGuard<'_, ServerConfig> {
1243 self.inner.read()
1244 }
1245
1246 pub fn snapshot(&self) -> ServerConfig {
1248 self.inner.read().clone()
1249 }
1250
1251 pub fn reload_from_file(&self, path: &str) -> ConfigResult<ReloadReport> {
1257 let contents = std::fs::read_to_string(path)?;
1259 let new_config: ServerConfig = toml::from_str(&contents)?;
1260
1261 if let Err(e) = new_config.validate() {
1263 return Ok(ReloadReport {
1264 sections_updated: Vec::new(),
1265 sections_skipped: Vec::new(),
1266 errors: vec![format!("Validation failed: {}", e)],
1267 success: false,
1268 });
1269 }
1270
1271 self.apply_reload(new_config)
1272 }
1273
1274 pub fn reload_from_stored_path(&self) -> ConfigResult<ReloadReport> {
1276 let path = self.config_path.read().clone();
1277 match path {
1278 Some(p) => {
1279 let path_str = p.to_string_lossy().to_string();
1280 self.reload_from_file(&path_str)
1281 }
1282 None => Ok(ReloadReport {
1283 sections_updated: Vec::new(),
1284 sections_skipped: Vec::new(),
1285 errors: vec!["No config file path set for reload".to_string()],
1286 success: false,
1287 }),
1288 }
1289 }
1290
1291 fn apply_reload(&self, new_config: ServerConfig) -> ConfigResult<ReloadReport> {
1293 let mut report = ReloadReport {
1294 sections_updated: Vec::new(),
1295 sections_skipped: Vec::new(),
1296 errors: Vec::new(),
1297 success: true,
1298 };
1299
1300 let config_diff = {
1301 let current = self.inner.read();
1302 diff(¤t, &new_config)
1303 };
1304
1305 if config_diff.is_empty() {
1306 info!("Config reload: no changes detected");
1307 return Ok(report);
1308 }
1309
1310 for section in &config_diff.non_reloadable_changes {
1312 warn!(
1313 "Config reload: section '{}' changed but requires restart - skipping",
1314 section
1315 );
1316 report.sections_skipped.push(*section);
1317 }
1318
1319 if !config_diff.reloadable_changes.is_empty() {
1321 let mut current = self.inner.write();
1322
1323 for section in &config_diff.reloadable_changes {
1324 match section {
1325 ReloadableSection::Logging => {
1326 current.logging = new_config.logging.clone();
1327 info!("Config reload: updated logging settings");
1328 }
1329 ReloadableSection::Metrics => {
1330 current.metrics.export_interval_secs =
1332 new_config.metrics.export_interval_secs;
1333 current.metrics.enabled = new_config.metrics.enabled;
1334 info!("Config reload: updated metrics settings");
1335 }
1336 ReloadableSection::Compaction => {
1337 current.storage.compaction = new_config.storage.compaction.clone();
1338 info!("Config reload: updated compaction settings");
1339 }
1340 ReloadableSection::RateLimit => {
1341 current.server.max_connections = new_config.server.max_connections;
1342 info!("Config reload: updated rate limit settings");
1343 }
1344 }
1345 report.sections_updated.push(*section);
1346 }
1347 }
1348
1349 Ok(report)
1350 }
1351
1352 pub fn manual_reload(&self) -> ConfigResult<ReloadReport> {
1354 self.reload_from_stored_path()
1355 }
1356}
1357
1358impl std::fmt::Debug for ReloadableConfig {
1359 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1360 f.debug_struct("ReloadableConfig")
1361 .field("config", &*self.inner.read())
1362 .field("config_path", &*self.config_path.read())
1363 .finish()
1364 }
1365}
1366
#[cfg(test)]
mod tests {
    // Unit tests for `ServerConfig` defaults/validation and for hot-reloading
    // through `ReloadableConfig`.
    use super::*;
    use std::env;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serializes tests that mutate the process environment or load configs
    /// from disk.
    ///
    /// The test harness runs tests on parallel threads. Mutating the environment
    /// while another thread reads it is a data race on some platforms, which is
    /// why `env::set_var` / `env::remove_var` are `unsafe`.
    /// NOTE(review): loading/reloading is assumed to possibly consult
    /// `AMATERS_*` overrides. Confirm whether `from_file` / `reload_from_file`
    /// read the environment.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_LOCK`], ignoring poisoning so one failed test does not
    /// cascade into failures of every later test that needs the lock.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Sets environment variables for the lifetime of the guard.
    ///
    /// On drop (including while unwinding from a failed assertion) each variable
    /// is restored to its previous value, or removed if it was previously unset.
    struct EnvVarGuard {
        saved: Vec<(&'static str, Option<OsString>)>,
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvVarGuard {
        fn set(vars: &[(&'static str, &str)]) -> Self {
            let lock = lock_env();
            let mut saved = Vec::with_capacity(vars.len());
            for &(key, value) in vars {
                saved.push((key, env::var_os(key)));
                // SAFETY: `ENV_LOCK` is held, so no other test in this module
                // reads or writes the environment concurrently.
                unsafe { env::set_var(key, value) };
            }
            Self { saved, _lock: lock }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // Restore in reverse order so a key set twice ends up at its
            // original value.
            for (key, previous) in self.saved.drain(..).rev() {
                // SAFETY: `_lock` is still held; struct fields are dropped only
                // after this method returns.
                unsafe {
                    match previous {
                        Some(value) => env::set_var(key, value),
                        None => env::remove_var(key),
                    }
                }
            }
        }
    }

    /// A config file in the system temp directory that is deleted on drop.
    ///
    /// The file name includes the process id so concurrent `cargo test`
    /// processes sharing one temp directory do not clobber each other's files.
    /// Holding the guard also holds [`ENV_LOCK`] for the duration of the test.
    struct TempConfig {
        path: PathBuf,
        _lock: MutexGuard<'static, ()>,
    }

    impl TempConfig {
        /// Serializes `config` to a uniquely named temp file.
        fn create(config: &ServerConfig, name: &str) -> Self {
            let lock = lock_env();
            let path = env::temp_dir().join(format!(
                "amaters_config_test_{}_{}.toml",
                std::process::id(),
                name
            ));
            config
                .save_to_file(&path)
                .expect("Failed to save temp config");
            Self { path, _lock: lock }
        }

        /// The file path as `&str`, for APIs that take string paths.
        fn path_str(&self) -> &str {
            self.path.to_str().expect("path should be valid utf-8")
        }
    }

    impl Drop for TempConfig {
        fn drop(&mut self) {
            // Best-effort cleanup: the file may never have been written.
            let _ = std::fs::remove_file(&self.path);
        }
    }

    #[test]
    fn test_default_config() {
        let config = ServerConfig::default();
        assert_eq!(config.server.bind_address, "0.0.0.0:7878");
        assert_eq!(config.storage.engine, "lsm");
        assert_eq!(config.logging.level, "info");
    }

    #[test]
    fn test_config_validation() {
        let config = ServerConfig::default();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_invalid_bind_address() {
        let mut config = ServerConfig::default();
        config.server.bind_address = "invalid".to_string();
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_invalid_storage_engine() {
        let mut config = ServerConfig::default();
        config.storage.engine = "invalid".to_string();
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_tls_validation() {
        let mut config = ServerConfig::default();
        // Enabling TLS on the default config must fail validation (presumably
        // because no certificate/key paths are configured by default).
        config.network.tls_enabled = true;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_env_overrides() {
        let _env = EnvVarGuard::set(&[
            ("AMATERS_BIND_ADDRESS", "127.0.0.1:9999"),
            ("AMATERS_LOG_LEVEL", "debug"),
        ]);

        let mut config = ServerConfig::default();
        config.apply_env_overrides();

        assert_eq!(config.server.bind_address, "127.0.0.1:9999");
        assert_eq!(config.logging.level, "debug");
    }

    #[test]
    fn test_save_and_load() {
        let config = ServerConfig::default();
        let tmp = TempConfig::create(&config, "save_and_load");

        let loaded = ServerConfig::from_file(&tmp.path).expect("Failed to load config");
        assert_eq!(config.server.bind_address, loaded.server.bind_address);
    }

    #[test]
    fn test_reload_logging_section() {
        let config = ServerConfig::default();
        let tmp = TempConfig::create(&config, "reload_logging");

        let reloadable = ReloadableConfig::new(config);
        reloadable.set_config_path(tmp.path.clone());

        // Change only reloadable logging fields on disk.
        let mut new_config = reloadable.snapshot();
        new_config.logging.level = "debug".to_string();
        new_config.logging.format = "json".to_string();
        new_config
            .save_to_file(&tmp.path)
            .expect("Failed to save modified config");

        let report = reloadable
            .reload_from_file(tmp.path_str())
            .expect("Reload should succeed");

        assert!(report.success);
        assert!(
            report
                .sections_updated
                .contains(&ReloadableSection::Logging)
        );
        assert_eq!(reloadable.read().logging.level, "debug");
        assert_eq!(reloadable.read().logging.format, "json");
    }

    #[test]
    fn test_reload_metrics_section() {
        let config = ServerConfig::default();
        let tmp = TempConfig::create(&config, "reload_metrics");

        let reloadable = ReloadableConfig::new(config);

        let mut new_config = reloadable.snapshot();
        new_config.metrics.export_interval_secs = 120;
        new_config
            .save_to_file(&tmp.path)
            .expect("Failed to save modified config");

        let report = reloadable
            .reload_from_file(tmp.path_str())
            .expect("Reload should succeed");

        assert!(report.success);
        assert!(
            report
                .sections_updated
                .contains(&ReloadableSection::Metrics)
        );
        assert_eq!(reloadable.read().metrics.export_interval_secs, 120);
    }

    #[test]
    fn test_non_reloadable_section_skipped() {
        let config = ServerConfig::default();
        let tmp = TempConfig::create(&config, "reload_non_reloadable");

        let reloadable = ReloadableConfig::new(config);

        // Mix a non-reloadable change (bind address) with a reloadable one.
        let mut new_config = reloadable.snapshot();
        new_config.server.bind_address = "127.0.0.1:9999".to_string();
        new_config.logging.level = "warn".to_string();
        new_config
            .save_to_file(&tmp.path)
            .expect("Failed to save modified config");

        let report = reloadable
            .reload_from_file(tmp.path_str())
            .expect("Reload should succeed");

        assert!(report.success);
        // The reloadable section is applied...
        assert!(
            report
                .sections_updated
                .contains(&ReloadableSection::Logging)
        );
        assert_eq!(reloadable.read().logging.level, "warn");
        // ...while the bind address change is reported as skipped and not applied.
        assert!(
            report
                .sections_skipped
                .contains(&NonReloadableSection::BindAddress)
        );
        assert_eq!(reloadable.read().server.bind_address, "0.0.0.0:7878");
    }

    #[test]
    fn test_invalid_config_rejected() {
        let config = ServerConfig::default();
        let tmp = TempConfig::create(&config, "reload_invalid");

        let reloadable = ReloadableConfig::new(config);

        // Write the file directly so the invalid address bypasses any
        // validation `save_to_file` might perform.
        let mut new_config = reloadable.snapshot();
        new_config.server.bind_address = "not-an-address".to_string();
        let contents = toml::to_string_pretty(&new_config).expect("Failed to serialize config");
        std::fs::write(&tmp.path, contents).expect("Failed to write config");

        let report = reloadable
            .reload_from_file(tmp.path_str())
            .expect("Reload should return report");

        assert!(!report.success);
        assert!(!report.errors.is_empty());
        // The live config must be untouched after a rejected reload.
        assert_eq!(reloadable.read().server.bind_address, "0.0.0.0:7878");
    }

    #[test]
    fn test_config_diff_detection() {
        let old = ServerConfig::default();
        let mut new = old.clone();

        // Identical configs produce an empty diff.
        let d = diff(&old, &new);
        assert!(d.is_empty());

        // Logging is a reloadable change.
        new.logging.level = "error".to_string();
        let d = diff(&old, &new);
        assert!(d.reloadable_changes.contains(&ReloadableSection::Logging));
        assert!(!d.has_non_reloadable_changes());

        // Bind address is a non-reloadable change.
        new.server.bind_address = "127.0.0.1:1234".to_string();
        let d = diff(&old, &new);
        assert!(d.has_non_reloadable_changes());
        assert!(
            d.non_reloadable_changes
                .contains(&NonReloadableSection::BindAddress)
        );

        // Compaction strategy maps to the Compaction section.
        new.storage.compaction.strategy = "tiered".to_string();
        let d = diff(&old, &new);
        assert!(
            d.reloadable_changes
                .contains(&ReloadableSection::Compaction)
        );

        // max_connections maps to the RateLimit section.
        new.server.max_connections = 5000;
        let d = diff(&old, &new);
        assert!(d.reloadable_changes.contains(&ReloadableSection::RateLimit));
    }

    #[test]
    fn test_reload_report_contents() {
        let config = ServerConfig::default();
        let tmp = TempConfig::create(&config, "reload_report");

        let reloadable = ReloadableConfig::new(config);

        // Two reloadable changes (logging, metrics) and one non-reloadable one.
        let mut new_config = reloadable.snapshot();
        new_config.logging.level = "trace".to_string();
        new_config.metrics.export_interval_secs = 30;
        new_config.server.bind_address = "127.0.0.1:5555".to_string();
        new_config
            .save_to_file(&tmp.path)
            .expect("Failed to save modified config");

        let report = reloadable
            .reload_from_file(tmp.path_str())
            .expect("Reload should succeed");

        assert!(report.success);
        assert_eq!(report.sections_updated.len(), 2);
        assert_eq!(report.sections_skipped.len(), 1);
        assert!(report.errors.is_empty());

        // The Display impl should mention both updated and skipped sections.
        let display = format!("{}", report);
        assert!(display.contains("Updated"));
        assert!(display.contains("Skipped"));
    }

    #[test]
    fn test_concurrent_reads_during_reload() {
        let config = ServerConfig::default();
        let tmp = TempConfig::create(&config, "reload_concurrent");

        let reloadable = ReloadableConfig::new(config);

        // Readers hammer the config on other threads while the main thread reloads.
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let rc = reloadable.clone();
                std::thread::spawn(move || {
                    for _ in 0..100 {
                        let _level = rc.read().logging.level.clone();
                    }
                })
            })
            .collect();

        let mut new_config = reloadable.snapshot();
        new_config.logging.level = "debug".to_string();
        new_config
            .save_to_file(&tmp.path)
            .expect("Failed to save modified config");

        let _report = reloadable
            .reload_from_file(tmp.path_str())
            .expect("Reload should succeed");

        for h in handles {
            h.join().expect("Reader thread should not panic");
        }

        assert_eq!(reloadable.read().logging.level, "debug");
    }

    #[test]
    fn test_multiple_sequential_reloads() {
        let config = ServerConfig::default();
        let tmp = TempConfig::create(&config, "reload_sequential");

        let reloadable = ReloadableConfig::new(config);

        let levels = ["debug", "warn", "error", "trace", "info"];
        for level in &levels {
            let mut new_config = reloadable.snapshot();
            new_config.logging.level = level.to_string();
            new_config
                .save_to_file(&tmp.path)
                .expect("Failed to save modified config");

            let report = reloadable
                .reload_from_file(tmp.path_str())
                .expect("Reload should succeed");

            assert!(report.success);
            assert_eq!(reloadable.read().logging.level, *level);
        }
    }

    #[test]
    fn test_reload_no_stored_path() {
        let _lock = lock_env();
        let config = ServerConfig::default();
        let reloadable = ReloadableConfig::new(config);

        // With no stored path, the failure is reported in the report, not as Err.
        let report = reloadable
            .reload_from_stored_path()
            .expect("Should return report");

        assert!(!report.success);
        assert!(!report.errors.is_empty());
    }

    #[test]
    fn test_reloadable_config_from_file() {
        let config = ServerConfig::default();
        let tmp = TempConfig::create(&config, "reload_from_file");

        let reloadable =
            ReloadableConfig::from_file(tmp.path_str()).expect("Should load from file");

        assert_eq!(reloadable.read().server.bind_address, "0.0.0.0:7878");
    }

    #[test]
    fn test_manual_reload() {
        let config = ServerConfig::default();
        let tmp = TempConfig::create(&config, "reload_manual");

        let reloadable = ReloadableConfig::new(config);
        reloadable.set_config_path(tmp.path.clone());

        let mut new_config = reloadable.snapshot();
        new_config.logging.level = "error".to_string();
        new_config
            .save_to_file(&tmp.path)
            .expect("Failed to save modified config");

        let report = reloadable
            .manual_reload()
            .expect("Manual reload should succeed");
        assert!(report.success);
        assert_eq!(reloadable.read().logging.level, "error");
    }

    #[test]
    fn test_resource_limits_defaults() {
        let config = ServerConfig::default();
        assert_eq!(config.resource_limits.max_connections_per_client, 10);
        assert_eq!(
            config.resource_limits.max_requests_per_second_global,
            10_000
        );
        assert!(config.resource_limits.max_memory_bytes.is_none());
        assert_eq!(config.resource_limits.max_active_queries, 1000);
    }

    #[test]
    fn test_circuit_cache_defaults() {
        let config = ServerConfig::default();
        assert_eq!(config.circuit_cache.max_entries, 1000);
        assert_eq!(config.circuit_cache.ttl_secs, 300);
    }

    #[test]
    fn test_timeout_config_defaults() {
        let config = ServerConfig::default();
        assert_eq!(config.timeouts.request_timeout_ms, 30_000);
        assert_eq!(config.timeouts.idle_connection_timeout_ms, 60_000);
        assert_eq!(config.timeouts.keep_alive_interval_ms, 15_000);
    }

    #[test]
    fn test_timeout_validation_ordering() {
        let mut config = ServerConfig::default();
        // A request timeout longer than the idle-connection timeout is rejected.
        config.timeouts.request_timeout_ms = 60_000;
        config.timeouts.idle_connection_timeout_ms = 30_000;
        assert!(config.validate().is_err());

        // The properly ordered pair is accepted.
        config.timeouts.request_timeout_ms = 30_000;
        config.timeouts.idle_connection_timeout_ms = 60_000;
        assert!(config.validate().is_ok());
    }
}