modbus_relay/config/
relay.rs

1use std::path::PathBuf;
2
3use serde::{Deserialize, Serialize};
4
5use config::{Config as ConfigBuilder, ConfigError, Environment, File, FileFormat};
6
7use super::{ConnectionConfig, HttpConfig, LoggingConfig, RtuConfig, TcpConfig};
8
/// Main application configuration.
///
/// Groups the per-subsystem sections (TCP, RTU, HTTP, logging and connection
/// management) into one value that can be serialized and deserialized with
/// serde. Because of `#[serde(deny_unknown_fields)]`, deserialization fails on
/// any key that is not one of the fields below, so a misspelled section name is
/// reported as an error rather than silently ignored.
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Config {
    /// TCP server configuration (`tcp` section)
    pub tcp: TcpConfig,

    /// RTU client configuration (`rtu` section)
    pub rtu: RtuConfig,

    /// HTTP API configuration (`http` section)
    pub http: HttpConfig,

    /// Logging configuration (`logging` section)
    pub logging: LoggingConfig,

    /// Connection management configuration (`connection` section)
    pub connection: ConnectionConfig,
}
28
29impl Config {
    /// Default configuration directory.
    ///
    /// `new` looks in this directory for the `default`, run-mode specific and
    /// `local` YAML files.
    pub const CONFIG_DIR: &'static str = "config";

    /// Environment variable prefix.
    ///
    /// Only environment variables whose names start with this prefix are read
    /// as configuration overrides.
    const ENV_PREFIX: &'static str = "MODBUS_RELAY";
35
36    /// Build configuration using the following priority (highest to lowest):
37    /// 1. Environment variables (MODBUS_RELAY_*)
38    /// 2. Local configuration file (config/local.yaml)
39    /// 3. Environment specific file (config/{env}.yaml)
40    /// 4. Default configuration (config/default.yaml)
41    /// 5. Built-in defaults
42    pub fn new() -> Result<Self, ConfigError> {
43        let environment = std::env::var("RUN_MODE").unwrap_or_else(|_| "development".into());
44
45        // Start with built-in defaults
46        let defaults = Config::default();
47
48        let mut builder = ConfigBuilder::builder();
49
50        // Set defaults for each field manually
51        builder = builder
52            // TCP configuration
53            .set_default("tcp.bind_addr", defaults.tcp.bind_addr)?
54            .set_default("tcp.bind_port", defaults.tcp.bind_port)?
55            // RTU configuration
56            .set_default("rtu.device", defaults.rtu.device)?
57            .set_default("rtu.baud_rate", defaults.rtu.baud_rate)?
58            .set_default("rtu.data_bits", defaults.rtu.data_bits.to_string())?
59            .set_default("rtu.parity", defaults.rtu.parity.to_string())?
60            .set_default("rtu.stop_bits", defaults.rtu.stop_bits.to_string())?
61            .set_default("rtu.flush_after_write", defaults.rtu.flush_after_write)?
62            .set_default("rtu.rts_type", defaults.rtu.rts_type.to_string())?
63            .set_default("rtu.rts_delay_us", defaults.rtu.rts_delay_us)?
64            .set_default(
65                "rtu.transaction_timeout",
66                format!("{}s", defaults.rtu.transaction_timeout.as_secs()),
67            )?
68            .set_default(
69                "rtu.serial_timeout",
70                format!("{}s", defaults.rtu.serial_timeout.as_secs()),
71            )?
72            .set_default("rtu.max_frame_size", defaults.rtu.max_frame_size)?
73            // HTTP configuration
74            .set_default("http.enabled", defaults.http.enabled)?
75            .set_default("http.bind_addr", defaults.http.bind_addr)?
76            .set_default("http.bind_port", defaults.http.bind_port)?
77            .set_default("http.metrics_enabled", defaults.http.metrics_enabled)?
78            // Logging configuration
79            .set_default("logging.log_dir", defaults.logging.log_dir)?
80            .set_default("logging.trace_frames", defaults.logging.trace_frames)?
81            .set_default("logging.level", defaults.logging.level)?
82            .set_default("logging.format", defaults.logging.format)?
83            .set_default(
84                "logging.include_location",
85                defaults.logging.include_location,
86            )?
87            .set_default("logging.thread_ids", defaults.logging.thread_ids)?
88            .set_default("logging.thread_names", defaults.logging.thread_names)?
89            // Connection configuration
90            .set_default(
91                "connection.max_connections",
92                defaults.connection.max_connections,
93            )?
94            .set_default(
95                "connection.idle_timeout",
96                format!("{}s", defaults.connection.idle_timeout.as_secs()),
97            )?
98            .set_default(
99                "connection.connect_timeout",
100                format!("{}s", defaults.connection.connect_timeout.as_secs()),
101            )?
102            .set_default(
103                "connection.per_ip_limits",
104                defaults.connection.per_ip_limits,
105            )?
106            // Connection backoff configuration
107            .set_default(
108                "connection.backoff.initial_interval",
109                format!(
110                    "{}s",
111                    defaults.connection.backoff.initial_interval.as_secs()
112                ),
113            )?
114            .set_default(
115                "connection.backoff.max_interval",
116                format!("{}s", defaults.connection.backoff.max_interval.as_secs()),
117            )?
118            .set_default(
119                "connection.backoff.multiplier",
120                defaults.connection.backoff.multiplier,
121            )?
122            .set_default(
123                "connection.backoff.max_retries",
124                defaults.connection.backoff.max_retries,
125            )?;
126
127        let config = builder
128            // Load default config file
129            .add_source(File::new(
130                &format!("{}/default", Self::CONFIG_DIR),
131                FileFormat::Yaml,
132            ))
133            // Load environment specific config
134            .add_source(
135                File::new(
136                    &format!("{}/{}", Self::CONFIG_DIR, environment),
137                    FileFormat::Yaml,
138                )
139                .required(false),
140            )
141            // Load local overrides
142            .add_source(
143                File::new(&format!("{}/local", Self::CONFIG_DIR), FileFormat::Yaml).required(false),
144            )
145            // Add environment variables
146            .add_source(
147                Environment::with_prefix(Self::ENV_PREFIX)
148                    .prefix_separator("_")
149                    .separator("__")
150                    .try_parsing(true),
151            )
152            .build()?;
153
154        // Deserialize and validate
155        let config = config.try_deserialize()?;
156        Self::validate(&config)?;
157
158        Ok(config)
159    }
160
161    /// Load configuration from a specific file
162    pub fn from_file(path: PathBuf) -> Result<Self, ConfigError> {
163        let config = ConfigBuilder::builder()
164            // Load the specified config file
165            .add_source(File::from(path))
166            // Add env vars as overrides
167            .add_source(
168                Environment::with_prefix(Self::ENV_PREFIX)
169                    .separator("_")
170                    .try_parsing(true),
171            )
172            .build()?;
173
174        let config = config.try_deserialize()?;
175        Self::validate(&config)?;
176
177        Ok(config)
178    }
179
180    /// Validate configuration
181    pub fn validate(config: &Self) -> Result<(), ConfigError> {
182        // Helper to convert validation errors
183        fn validation_error(msg: &str) -> ConfigError {
184            ConfigError::Message(msg.to_string())
185        }
186
187        // Validate TCP configuration
188        if config.tcp.bind_addr.is_empty() {
189            return Err(validation_error("TCP bind address must not be empty"));
190        }
191        if config.tcp.bind_port == 0 {
192            return Err(validation_error("TCP port must be non-zero"));
193        }
194
195        // Validate RTU configuration
196        if config.rtu.device.is_empty() {
197            return Err(validation_error("RTU device must not be empty"));
198        }
199        if config.rtu.baud_rate == 0 {
200            return Err(validation_error("RTU baud rate must be non-zero"));
201        }
202
203        // Validate connection configuration
204        if config.rtu.transaction_timeout.is_zero() {
205            return Err(validation_error("Transaction timeout must be non-zero"));
206        }
207        if config.rtu.serial_timeout.is_zero() {
208            return Err(validation_error("Serial timeout must be non-zero"));
209        }
210        if config.rtu.max_frame_size == 0 {
211            return Err(validation_error("Max frame size must be non-zero"));
212        }
213
214        // Validate log level
215        match config.logging.level.to_lowercase().as_str() {
216            "error" | "warn" | "info" | "debug" | "trace" => {}
217            _ => return Err(validation_error("Invalid log level")),
218        }
219
220        // Validate log format
221        match config.logging.format.to_lowercase().as_str() {
222            "pretty" | "json" => {}
223            _ => return Err(validation_error("Invalid log format")),
224        }
225
226        // Validate connection configuration
227        if config.connection.max_connections == 0 {
228            return Err(validation_error("Maximum connections must be non-zero"));
229        }
230        if config.connection.idle_timeout.is_zero() {
231            return Err(validation_error("Idle timeout must be non-zero"));
232        }
233        if config.connection.connect_timeout.is_zero() {
234            return Err(validation_error("Connect timeout must be non-zero"));
235        }
236        if let Some(limit) = config.connection.per_ip_limits {
237            if limit == 0 {
238                return Err(validation_error("Per IP connection limit must be non-zero"));
239            }
240            if limit > config.connection.max_connections {
241                return Err(validation_error(
242                    "Per IP connection limit cannot exceed maximum connections",
243                ));
244            }
245        }
246        // Validate backoff configuration
247        if config.connection.backoff.initial_interval.is_zero() {
248            return Err(validation_error(
249                "Backoff initial interval must be non-zero",
250            ));
251        }
252        if config.connection.backoff.max_interval.is_zero() {
253            return Err(validation_error("Backoff max interval must be non-zero"));
254        }
255        if config.connection.backoff.multiplier <= 0.0 {
256            return Err(validation_error("Backoff multiplier must be positive"));
257        }
258        if config.connection.backoff.max_retries == 0 {
259            return Err(validation_error("Backoff max retries must be non-zero"));
260        }
261
262        Ok(())
263    }
264}
265
266#[cfg(test)]
267mod tests {
268    use crate::{DataBits, Parity, RtsType, StopBits};
269
270    use super::*;
271    use std::{fs, time::Duration};
272    use tempfile::tempdir;
273
274    #[test]
275    #[serial_test::serial]
276    fn test_default_config() {
277        let config = Config::new().unwrap();
278        assert_eq!(config.tcp.bind_port, 502);
279        assert_eq!(config.tcp.bind_addr, "127.0.0.1");
280    }
281
282    #[test]
283    #[serial_test::serial]
284    fn test_env_override() {
285        unsafe { std::env::set_var("MODBUS_RELAY_TCP__BIND_PORT", "5000") };
286        let config = Config::new().unwrap();
287        assert_eq!(config.tcp.bind_port, 5000);
288        unsafe { std::env::remove_var("MODBUS_RELAY_TCP__BIND_PORT") };
289    }
290
291    #[test]
292    #[serial_test::serial]
293    fn test_file_config() {
294        let dir = tempdir().unwrap();
295        let config_path = dir.path().join("config.yaml");
296
297        fs::write(
298            &config_path,
299            r#"
300            tcp:
301              bind_port: 9000
302              bind_addr: "192.168.1.100"
303              keep_alive: "60s"
304            rtu:
305              device: "/dev/ttyAMA0"
306              baud_rate: 9600
307              data_bits: 8
308              parity: "none"
309              stop_bits: "one"
310              flush_after_write: true
311              rts_type: "down"
312              rts_delay_us: 3500
313              transaction_timeout: "5s"
314              serial_timeout: "1s"
315              max_frame_size: 256
316            http:
317              enabled: false
318              bind_addr: "192.168.1.100"
319              bind_port: 9080
320              metrics_enabled: false
321            logging:
322              log_dir: "logs"
323              trace_frames: false
324              level: "trace"
325              format: "pretty"
326              include_location: false
327              thread_ids: false
328              thread_names: true
329            connection:
330              max_connections: 100
331              idle_timeout: "60s"
332              error_timeout: "300s"
333              connect_timeout: "5s"
334              per_ip_limits: 10
335              backoff:
336                # Initial wait time
337                initial_interval: "100ms"
338                # Maximum wait time
339                max_interval: "30s"
340                # Multiplier for each subsequent attempt
341                multiplier: 2.0
342                # Maximum number of attempts
343                max_retries: 5
344            "#,
345        )
346        .unwrap();
347
348        let config = Config::from_file(config_path).unwrap();
349        assert_eq!(config.tcp.bind_port, 9000);
350        assert_eq!(config.tcp.bind_addr, "192.168.1.100");
351        assert_eq!(config.tcp.keep_alive, Duration::from_secs(60));
352        assert_eq!(config.rtu.device, "/dev/ttyAMA0");
353        assert_eq!(config.rtu.baud_rate, 9600);
354        assert_eq!(config.rtu.data_bits, DataBits::new(8).unwrap());
355        assert_eq!(config.rtu.parity, Parity::None);
356        assert_eq!(config.rtu.stop_bits, StopBits::One);
357        assert!(config.rtu.flush_after_write);
358        assert_eq!(config.rtu.rts_type, RtsType::Down);
359        assert_eq!(config.rtu.rts_delay_us, 3500);
360        assert_eq!(config.rtu.transaction_timeout, Duration::from_secs(5));
361        assert_eq!(config.rtu.serial_timeout, Duration::from_secs(1));
362        assert_eq!(config.rtu.max_frame_size, 256);
363        assert!(!config.http.enabled);
364        assert_eq!(config.http.bind_addr, "192.168.1.100");
365        assert_eq!(config.http.bind_port, 9080);
366        assert!(!config.http.metrics_enabled);
367        assert_eq!(config.logging.log_dir, "logs");
368        assert!(!config.logging.trace_frames);
369        assert_eq!(config.logging.level, "trace");
370        assert_eq!(config.logging.format, "pretty");
371        assert!(!config.logging.include_location);
372        assert!(!config.logging.thread_ids);
373        assert!(config.logging.thread_names);
374        assert_eq!(config.connection.max_connections, 100);
375        assert_eq!(config.connection.idle_timeout, Duration::from_secs(60));
376        assert_eq!(config.connection.error_timeout, Duration::from_secs(300));
377        assert_eq!(config.connection.connect_timeout, Duration::from_secs(5));
378        assert_eq!(config.connection.per_ip_limits, Some(10));
379        assert_eq!(
380            config.connection.backoff.initial_interval,
381            Duration::from_millis(100)
382        );
383        assert_eq!(
384            config.connection.backoff.max_interval,
385            Duration::from_secs(30)
386        );
387        assert_eq!(config.connection.backoff.multiplier, 2.0);
388        assert_eq!(config.connection.backoff.max_retries, 5);
389    }
390
391    #[test]
392    #[serial_test::serial]
393    fn test_validation() {
394        unsafe { std::env::set_var("MODBUS_RELAY_TCP__BIND_PORT", "0") };
395        assert!(Config::new().is_err());
396        unsafe { std::env::remove_var("MODBUS_RELAY_TCP__BIND_PORT") };
397    }
398}