1use std::path::PathBuf;
2
3use serde::{Deserialize, Serialize};
4
5use config::{Config as ConfigBuilder, ConfigError, Environment, File, FileFormat};
6
7use super::{ConnectionConfig, HttpConfig, LoggingConfig, RtuConfig, TcpConfig};
8
9#[derive(Default, Debug, Clone, Serialize, Deserialize)]
11#[serde(deny_unknown_fields)]
12pub struct Config {
13 pub tcp: TcpConfig,
15
16 pub rtu: RtuConfig,
18
19 pub http: HttpConfig,
21
22 pub logging: LoggingConfig,
24
25 pub connection: ConnectionConfig,
27}
28
29impl Config {
30 pub const CONFIG_DIR: &'static str = "config";
32
33 const ENV_PREFIX: &'static str = "MODBUS_RELAY";
35
36 pub fn new() -> Result<Self, ConfigError> {
43 let environment = std::env::var("RUN_MODE").unwrap_or_else(|_| "development".into());
44
45 let defaults = Config::default();
47
48 let mut builder = ConfigBuilder::builder();
49
50 builder = builder
52 .set_default("tcp.bind_addr", defaults.tcp.bind_addr)?
54 .set_default("tcp.bind_port", defaults.tcp.bind_port)?
55 .set_default("rtu.device", defaults.rtu.device)?
57 .set_default("rtu.baud_rate", defaults.rtu.baud_rate)?
58 .set_default("rtu.data_bits", defaults.rtu.data_bits.to_string())?
59 .set_default("rtu.parity", defaults.rtu.parity.to_string())?
60 .set_default("rtu.stop_bits", defaults.rtu.stop_bits.to_string())?
61 .set_default("rtu.flush_after_write", defaults.rtu.flush_after_write)?
62 .set_default("rtu.rts_type", defaults.rtu.rts_type.to_string())?
63 .set_default("rtu.rts_delay_us", defaults.rtu.rts_delay_us)?
64 .set_default(
65 "rtu.transaction_timeout",
66 format!("{}s", defaults.rtu.transaction_timeout.as_secs()),
67 )?
68 .set_default(
69 "rtu.serial_timeout",
70 format!("{}s", defaults.rtu.serial_timeout.as_secs()),
71 )?
72 .set_default("rtu.max_frame_size", defaults.rtu.max_frame_size)?
73 .set_default("http.enabled", defaults.http.enabled)?
75 .set_default("http.bind_addr", defaults.http.bind_addr)?
76 .set_default("http.bind_port", defaults.http.bind_port)?
77 .set_default("http.metrics_enabled", defaults.http.metrics_enabled)?
78 .set_default("logging.log_dir", defaults.logging.log_dir)?
80 .set_default("logging.trace_frames", defaults.logging.trace_frames)?
81 .set_default("logging.level", defaults.logging.level)?
82 .set_default("logging.format", defaults.logging.format)?
83 .set_default(
84 "logging.include_location",
85 defaults.logging.include_location,
86 )?
87 .set_default("logging.thread_ids", defaults.logging.thread_ids)?
88 .set_default("logging.thread_names", defaults.logging.thread_names)?
89 .set_default(
91 "connection.max_connections",
92 defaults.connection.max_connections,
93 )?
94 .set_default(
95 "connection.idle_timeout",
96 format!("{}s", defaults.connection.idle_timeout.as_secs()),
97 )?
98 .set_default(
99 "connection.connect_timeout",
100 format!("{}s", defaults.connection.connect_timeout.as_secs()),
101 )?
102 .set_default(
103 "connection.per_ip_limits",
104 defaults.connection.per_ip_limits,
105 )?
106 .set_default(
108 "connection.backoff.initial_interval",
109 format!(
110 "{}s",
111 defaults.connection.backoff.initial_interval.as_secs()
112 ),
113 )?
114 .set_default(
115 "connection.backoff.max_interval",
116 format!("{}s", defaults.connection.backoff.max_interval.as_secs()),
117 )?
118 .set_default(
119 "connection.backoff.multiplier",
120 defaults.connection.backoff.multiplier,
121 )?
122 .set_default(
123 "connection.backoff.max_retries",
124 defaults.connection.backoff.max_retries,
125 )?;
126
127 let config = builder
128 .add_source(File::new(
130 &format!("{}/default", Self::CONFIG_DIR),
131 FileFormat::Yaml,
132 ))
133 .add_source(
135 File::new(
136 &format!("{}/{}", Self::CONFIG_DIR, environment),
137 FileFormat::Yaml,
138 )
139 .required(false),
140 )
141 .add_source(
143 File::new(&format!("{}/local", Self::CONFIG_DIR), FileFormat::Yaml).required(false),
144 )
145 .add_source(
147 Environment::with_prefix(Self::ENV_PREFIX)
148 .prefix_separator("_")
149 .separator("__")
150 .try_parsing(true),
151 )
152 .build()?;
153
154 let config = config.try_deserialize()?;
156 Self::validate(&config)?;
157
158 Ok(config)
159 }
160
161 pub fn from_file(path: PathBuf) -> Result<Self, ConfigError> {
163 let config = ConfigBuilder::builder()
164 .add_source(File::from(path))
166 .add_source(
168 Environment::with_prefix(Self::ENV_PREFIX)
169 .separator("_")
170 .try_parsing(true),
171 )
172 .build()?;
173
174 let config = config.try_deserialize()?;
175 Self::validate(&config)?;
176
177 Ok(config)
178 }
179
180 pub fn validate(config: &Self) -> Result<(), ConfigError> {
182 fn validation_error(msg: &str) -> ConfigError {
184 ConfigError::Message(msg.to_string())
185 }
186
187 if config.tcp.bind_addr.is_empty() {
189 return Err(validation_error("TCP bind address must not be empty"));
190 }
191 if config.tcp.bind_port == 0 {
192 return Err(validation_error("TCP port must be non-zero"));
193 }
194
195 if config.rtu.device.is_empty() {
197 return Err(validation_error("RTU device must not be empty"));
198 }
199 if config.rtu.baud_rate == 0 {
200 return Err(validation_error("RTU baud rate must be non-zero"));
201 }
202
203 if config.rtu.transaction_timeout.is_zero() {
205 return Err(validation_error("Transaction timeout must be non-zero"));
206 }
207 if config.rtu.serial_timeout.is_zero() {
208 return Err(validation_error("Serial timeout must be non-zero"));
209 }
210 if config.rtu.max_frame_size == 0 {
211 return Err(validation_error("Max frame size must be non-zero"));
212 }
213
214 match config.logging.level.to_lowercase().as_str() {
216 "error" | "warn" | "info" | "debug" | "trace" => {}
217 _ => return Err(validation_error("Invalid log level")),
218 }
219
220 match config.logging.format.to_lowercase().as_str() {
222 "pretty" | "json" => {}
223 _ => return Err(validation_error("Invalid log format")),
224 }
225
226 if config.connection.max_connections == 0 {
228 return Err(validation_error("Maximum connections must be non-zero"));
229 }
230 if config.connection.idle_timeout.is_zero() {
231 return Err(validation_error("Idle timeout must be non-zero"));
232 }
233 if config.connection.connect_timeout.is_zero() {
234 return Err(validation_error("Connect timeout must be non-zero"));
235 }
236 if let Some(limit) = config.connection.per_ip_limits {
237 if limit == 0 {
238 return Err(validation_error("Per IP connection limit must be non-zero"));
239 }
240 if limit > config.connection.max_connections {
241 return Err(validation_error(
242 "Per IP connection limit cannot exceed maximum connections",
243 ));
244 }
245 }
246 if config.connection.backoff.initial_interval.is_zero() {
248 return Err(validation_error(
249 "Backoff initial interval must be non-zero",
250 ));
251 }
252 if config.connection.backoff.max_interval.is_zero() {
253 return Err(validation_error("Backoff max interval must be non-zero"));
254 }
255 if config.connection.backoff.multiplier <= 0.0 {
256 return Err(validation_error("Backoff multiplier must be positive"));
257 }
258 if config.connection.backoff.max_retries == 0 {
259 return Err(validation_error("Backoff max retries must be non-zero"));
260 }
261
262 Ok(())
263 }
264}
265
#[cfg(test)]
mod tests {
    use crate::{DataBits, Parity, RtsType, StopBits};

    use super::*;
    use std::{fs, time::Duration};
    use tempfile::tempdir;

    /// Environment variable that overrides `tcp.bind_port`.
    const BIND_PORT_VAR: &str = "MODBUS_RELAY_TCP__BIND_PORT";

    /// Sets an environment variable on creation and removes it on drop.
    ///
    /// Removal also happens while unwinding from a failed assertion, so one
    /// failing test cannot leak its override into the tests that run after it.
    struct EnvVarGuard {
        key: &'static str,
    }

    impl EnvVarGuard {
        /// Sets `key` to `value` and returns the guard that will remove it.
        fn set(key: &'static str, value: &str) -> Self {
            // SAFETY: every test in this module is `#[serial_test::serial]`, so
            // these tests do not touch the environment concurrently with each
            // other. Assumes no other concurrently running test reads or
            // writes the process environment.
            unsafe { std::env::set_var(key, value) };
            Self { key }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: same reasoning as in `EnvVarGuard::set`.
            unsafe { std::env::remove_var(self.key) };
        }
    }

    /// Values loaded without any environment override.
    #[test]
    #[serial_test::serial]
    fn test_default_config() {
        let config = Config::new().unwrap();
        assert_eq!(config.tcp.bind_port, 502);
        assert_eq!(config.tcp.bind_addr, "127.0.0.1");
    }

    /// An environment variable overrides the loaded TCP port.
    #[test]
    #[serial_test::serial]
    fn test_env_override() {
        let _port_override = EnvVarGuard::set(BIND_PORT_VAR, "5000");
        let config = Config::new().unwrap();
        assert_eq!(config.tcp.bind_port, 5000);
    }

    /// A complete YAML file is loaded field by field through `from_file`.
    #[test]
    #[serial_test::serial]
    fn test_file_config() {
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("config.yaml");

        fs::write(
            &config_path,
            r#"
            tcp:
              bind_port: 9000
              bind_addr: "192.168.1.100"
              keep_alive: "60s"
            rtu:
              device: "/dev/ttyAMA0"
              baud_rate: 9600
              data_bits: 8
              parity: "none"
              stop_bits: "one"
              flush_after_write: true
              rts_type: "down"
              rts_delay_us: 3500
              transaction_timeout: "5s"
              serial_timeout: "1s"
              max_frame_size: 256
            http:
              enabled: false
              bind_addr: "192.168.1.100"
              bind_port: 9080
              metrics_enabled: false
            logging:
              log_dir: "logs"
              trace_frames: false
              level: "trace"
              format: "pretty"
              include_location: false
              thread_ids: false
              thread_names: true
            connection:
              max_connections: 100
              idle_timeout: "60s"
              error_timeout: "300s"
              connect_timeout: "5s"
              per_ip_limits: 10
              backoff:
                # Initial wait time
                initial_interval: "100ms"
                # Maximum wait time
                max_interval: "30s"
                # Multiplier for each subsequent attempt
                multiplier: 2.0
                # Maximum number of attempts
                max_retries: 5
            "#,
        )
        .unwrap();

        let config = Config::from_file(config_path).unwrap();
        assert_eq!(config.tcp.bind_port, 9000);
        assert_eq!(config.tcp.bind_addr, "192.168.1.100");
        assert_eq!(config.tcp.keep_alive, Duration::from_secs(60));
        assert_eq!(config.rtu.device, "/dev/ttyAMA0");
        assert_eq!(config.rtu.baud_rate, 9600);
        assert_eq!(config.rtu.data_bits, DataBits::new(8).unwrap());
        assert_eq!(config.rtu.parity, Parity::None);
        assert_eq!(config.rtu.stop_bits, StopBits::One);
        assert!(config.rtu.flush_after_write);
        assert_eq!(config.rtu.rts_type, RtsType::Down);
        assert_eq!(config.rtu.rts_delay_us, 3500);
        assert_eq!(config.rtu.transaction_timeout, Duration::from_secs(5));
        assert_eq!(config.rtu.serial_timeout, Duration::from_secs(1));
        assert_eq!(config.rtu.max_frame_size, 256);
        assert!(!config.http.enabled);
        assert_eq!(config.http.bind_addr, "192.168.1.100");
        assert_eq!(config.http.bind_port, 9080);
        assert!(!config.http.metrics_enabled);
        assert_eq!(config.logging.log_dir, "logs");
        assert!(!config.logging.trace_frames);
        assert_eq!(config.logging.level, "trace");
        assert_eq!(config.logging.format, "pretty");
        assert!(!config.logging.include_location);
        assert!(!config.logging.thread_ids);
        assert!(config.logging.thread_names);
        assert_eq!(config.connection.max_connections, 100);
        assert_eq!(config.connection.idle_timeout, Duration::from_secs(60));
        assert_eq!(config.connection.error_timeout, Duration::from_secs(300));
        assert_eq!(config.connection.connect_timeout, Duration::from_secs(5));
        assert_eq!(config.connection.per_ip_limits, Some(10));
        assert_eq!(
            config.connection.backoff.initial_interval,
            Duration::from_millis(100)
        );
        assert_eq!(
            config.connection.backoff.max_interval,
            Duration::from_secs(30)
        );
        assert_eq!(config.connection.backoff.multiplier, 2.0);
        assert_eq!(config.connection.backoff.max_retries, 5);
    }

    /// A zero TCP port supplied through the environment is rejected.
    #[test]
    #[serial_test::serial]
    fn test_validation() {
        let _port_override = EnvVarGuard::set(BIND_PORT_VAR, "0");
        assert!(Config::new().is_err());
    }
}