Skip to main content

network_protocol/
config.rs

//! # Configuration Management
//!
//! Centralized configuration for the network protocol library.
//!
//! This module provides structured configuration for servers and clients,
//! including connection parameters, timeouts, compression settings, and security options.
//!
//! ## Configuration Sources
//! - TOML files via `NetworkConfig::from_file()` (or TOML strings via `NetworkConfig::from_toml()`)
//! - Direct instantiation with defaults
//! - Environment-specific overrides (environment variables via `NetworkConfig::from_env()`)
//!
//! ## Security Considerations
//! - Default compression threshold (512 bytes) balances performance and CPU
//! - Recommended timeout values prevent slowloris attacks
//! - TLS settings enforce modern cryptography (TLS 1.2+)
17
18use crate::error::{ProtocolError, Result};
19use crate::utils::timeout;
20use serde::{Deserialize, Serialize};
21use std::fs::File;
22use std::io::Read;
23use std::path::Path;
24use std::time::Duration;
25use tracing::Level;
26
/// Protocol version this library currently supports.
pub const PROTOCOL_VERSION: u8 = 1;

/// Four-byte marker identifying protocol packets: ASCII "NPRO" (0x4E 0x50 0x52 0x4F).
pub const MAGIC_BYTES: [u8; 4] = *b"NPRO";

/// Upper bound on payload size in bytes (16 MiB).
pub const MAX_PAYLOAD_SIZE: usize = 16 << 20;

/// Default for whether compression is enabled.
pub const ENABLE_COMPRESSION: bool = false;

/// Default for whether encryption is enabled.
pub const ENABLE_ENCRYPTION: bool = true;
41
42/// Main network configuration structure that contains all configurable settings
43#[derive(Debug, Clone, Deserialize, Serialize, Default)]
44pub struct NetworkConfig {
45    /// Server-specific configuration
46    #[serde(default)]
47    pub server: ServerConfig,
48
49    /// Client-specific configuration
50    #[serde(default)]
51    pub client: ClientConfig,
52
53    /// Transport configuration
54    #[serde(default)]
55    pub transport: TransportConfig,
56
57    /// Logging configuration
58    #[serde(default)]
59    pub logging: LoggingConfig,
60}
61
62// Default implementation is now derived
63
64impl NetworkConfig {
65    /// Load configuration from a TOML file
66    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
67        let mut file = File::open(path)
68            .map_err(|e| ProtocolError::ConfigError(format!("Failed to open config file: {e}")))?;
69
70        let mut contents = String::new();
71        file.read_to_string(&mut contents)
72            .map_err(|e| ProtocolError::ConfigError(format!("Failed to read config file: {e}")))?;
73
74        Self::from_toml(&contents)
75    }
76
77    /// Load configuration from TOML string
78    pub fn from_toml(content: &str) -> Result<Self> {
79        toml::from_str::<Self>(content)
80            .map_err(|e| ProtocolError::ConfigError(format!("Failed to parse TOML: {e}")))
81    }
82
83    /// Load configuration from environment variables
84    pub fn from_env() -> Result<Self> {
85        // Start with defaults
86        let mut config = Self::default();
87
88        // Override with environment variables
89        if let Ok(addr) = std::env::var("NETWORK_PROTOCOL_SERVER_ADDRESS") {
90            config.server.address = addr;
91        }
92
93        if let Ok(capacity) = std::env::var("NETWORK_PROTOCOL_BACKPRESSURE_LIMIT") {
94            if let Ok(val) = capacity.parse::<usize>() {
95                config.server.backpressure_limit = val;
96            }
97        }
98
99        if let Ok(timeout) = std::env::var("NETWORK_PROTOCOL_CONNECTION_TIMEOUT_MS") {
100            if let Ok(val) = timeout.parse::<u64>() {
101                config.server.connection_timeout = Duration::from_millis(val);
102                config.client.connection_timeout = Duration::from_millis(val);
103            }
104        }
105
106        if let Ok(heartbeat) = std::env::var("NETWORK_PROTOCOL_HEARTBEAT_INTERVAL_MS") {
107            if let Ok(val) = heartbeat.parse::<u64>() {
108                config.server.heartbeat_interval = Duration::from_millis(val);
109            }
110        }
111
112        // Add more environment variables as needed
113
114        Ok(config)
115    }
116
117    /// Apply overrides to the default configuration
118    pub fn default_with_overrides<F>(mutator: F) -> Self
119    where
120        F: FnOnce(&mut Self),
121    {
122        let mut config = Self::default();
123        mutator(&mut config);
124        config
125    }
126
127    /// Generate example configuration file content
128    pub fn example_config() -> String {
129        toml::to_string_pretty(&Self::default())
130            .unwrap_or_else(|_| String::from("# Failed to generate example config"))
131    }
132
133    /// Save configuration to a file
134    pub fn save_to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
135        let content = toml::to_string_pretty(self)
136            .map_err(|e| ProtocolError::ConfigError(format!("Failed to serialize config: {e}")))?;
137
138        std::fs::write(path, content)
139            .map_err(|e| ProtocolError::ConfigError(format!("Failed to write config file: {e}")))?;
140
141        Ok(())
142    }
143
144    /// Validate the configuration for common issues and misconfigurations
145    ///
146    /// Returns a list of validation errors. Empty list means configuration is valid.
147    pub fn validate(&self) -> Vec<String> {
148        let mut errors = Vec::new();
149
150        // Validate server configuration
151        errors.extend(self.server.validate());
152
153        // Validate client configuration
154        errors.extend(self.client.validate());
155
156        // Validate transport configuration
157        errors.extend(self.transport.validate());
158
159        // Validate logging configuration
160        errors.extend(self.logging.validate());
161
162        errors
163    }
164
165    /// Validate and return Result - convenience method
166    pub fn validate_strict(&self) -> Result<()> {
167        let errors = self.validate();
168        if errors.is_empty() {
169            Ok(())
170        } else {
171            Err(ProtocolError::ConfigError(format!(
172                "Configuration validation failed:\n  - {}",
173                errors.join("\n  - ")
174            )))
175        }
176    }
177}
178
179/// Server-specific configuration
180#[derive(Debug, Clone, Deserialize, Serialize)]
181pub struct ServerConfig {
182    /// Server listen address (e.g., "127.0.0.1:9000")
183    pub address: String,
184
185    /// Maximum number of messages in the backpressure queue
186    pub backpressure_limit: usize,
187
188    /// Timeout for client connections
189    #[serde(with = "duration_serde")]
190    pub connection_timeout: Duration,
191
192    /// Interval for sending heartbeat messages
193    #[serde(with = "duration_serde")]
194    pub heartbeat_interval: Duration,
195
196    /// Timeout for graceful server shutdown
197    #[serde(with = "duration_serde")]
198    pub shutdown_timeout: Duration,
199
200    /// Maximum number of concurrent connections
201    pub max_connections: usize,
202}
203
204impl Default for ServerConfig {
205    fn default() -> Self {
206        Self {
207            address: String::from("127.0.0.1:9000"),
208            backpressure_limit: 32,
209            connection_timeout: timeout::DEFAULT_TIMEOUT,
210            heartbeat_interval: timeout::KEEPALIVE_INTERVAL,
211            shutdown_timeout: timeout::SHUTDOWN_TIMEOUT,
212            max_connections: 1000,
213        }
214    }
215}
216
217impl ServerConfig {
218    /// Validate server configuration
219    pub fn validate(&self) -> Vec<String> {
220        let mut errors = Vec::new();
221
222        // Validate address format
223        if self.address.is_empty() {
224            errors.push("Server address cannot be empty".to_string());
225        } else if self.address.parse::<std::net::SocketAddr>().is_err() {
226            errors.push(format!(
227                "Invalid server address format: '{}' (expected format: '0.0.0.0:8080')",
228                self.address
229            ));
230        }
231
232        // Validate backpressure limit
233        if self.backpressure_limit == 0 {
234            errors.push("Backpressure limit must be greater than 0".to_string());
235        } else if self.backpressure_limit > 1_000_000 {
236            errors.push(format!(
237                "Backpressure limit too large: {} (max recommended: 1,000,000)",
238                self.backpressure_limit
239            ));
240        }
241
242        // Validate connection timeout
243        if self.connection_timeout.as_millis() < 100 {
244            errors.push("Connection timeout too short (minimum: 100ms)".to_string());
245        } else if self.connection_timeout.as_secs() > 300 {
246            errors.push("Connection timeout too long (maximum: 300s)".to_string());
247        }
248
249        // Validate heartbeat interval
250        if self.heartbeat_interval.as_millis() < 100 {
251            errors.push("Heartbeat interval too short (minimum: 100ms)".to_string());
252        } else if self.heartbeat_interval.as_secs() > 3600 {
253            errors.push("Heartbeat interval too long (maximum: 1 hour)".to_string());
254        }
255
256        // Validate shutdown timeout
257        if self.shutdown_timeout.as_secs() < 1 {
258            errors.push("Shutdown timeout too short (minimum: 1s)".to_string());
259        } else if self.shutdown_timeout.as_secs() > 60 {
260            errors.push("Shutdown timeout too long (maximum: 60s)".to_string());
261        }
262
263        // Validate max connections
264        if self.max_connections == 0 {
265            errors.push("Max connections must be greater than 0".to_string());
266        } else if self.max_connections > 100_000 {
267            errors.push(format!(
268                "Max connections very high: {} (ensure system resources can support this)",
269                self.max_connections
270            ));
271        }
272
273        errors
274    }
275}
276
277/// Client-specific configuration
278#[derive(Debug, Clone, Deserialize, Serialize)]
279pub struct ClientConfig {
280    /// Target server address
281    pub address: String,
282
283    /// Timeout for connection attempts
284    #[serde(with = "duration_serde")]
285    pub connection_timeout: Duration,
286
287    /// Timeout for individual operations
288    #[serde(with = "duration_serde")]
289    pub operation_timeout: Duration,
290
291    /// Timeout for waiting for response messages
292    #[serde(with = "duration_serde")]
293    pub response_timeout: Duration,
294
295    /// Interval for sending heartbeat messages
296    #[serde(with = "duration_serde")]
297    pub heartbeat_interval: Duration,
298
299    /// Whether to automatically reconnect on connection loss
300    pub auto_reconnect: bool,
301
302    /// Maximum number of reconnect attempts before giving up
303    pub max_reconnect_attempts: u32,
304
305    /// Delay between reconnect attempts
306    #[serde(with = "duration_serde")]
307    pub reconnect_delay: Duration,
308}
309
310impl Default for ClientConfig {
311    fn default() -> Self {
312        Self {
313            address: String::from("127.0.0.1:9000"),
314            connection_timeout: timeout::DEFAULT_TIMEOUT,
315            operation_timeout: Duration::from_secs(3),
316            response_timeout: Duration::from_secs(30),
317            heartbeat_interval: timeout::KEEPALIVE_INTERVAL,
318            auto_reconnect: true,
319            max_reconnect_attempts: 3,
320            reconnect_delay: Duration::from_secs(1),
321        }
322    }
323}
324
325impl ClientConfig {
326    /// Validate client configuration
327    pub fn validate(&self) -> Vec<String> {
328        let mut errors = Vec::new();
329
330        // Validate address format
331        if self.address.is_empty() {
332            errors.push("Client address cannot be empty".to_string());
333        } else if self.address.parse::<std::net::SocketAddr>().is_err() {
334            errors.push(format!(
335                "Invalid client address format: '{}' (expected format: 'example.com:8080')",
336                self.address
337            ));
338        }
339
340        // Validate timeouts
341        if self.connection_timeout.as_millis() < 100 {
342            errors.push("Connection timeout too short (minimum: 100ms)".to_string());
343        }
344
345        if self.operation_timeout.as_millis() < 10 {
346            errors.push("Operation timeout too short (minimum: 10ms)".to_string());
347        }
348
349        if self.response_timeout.as_millis() < 100 {
350            errors.push("Response timeout too short (minimum: 100ms)".to_string());
351        }
352
353        // Validate reconnect settings
354        if self.auto_reconnect && self.max_reconnect_attempts == 0 {
355            errors.push(
356                "Max reconnect attempts must be greater than 0 when auto_reconnect is enabled"
357                    .to_string(),
358            );
359        }
360
361        if self.reconnect_delay.as_millis() < 10 {
362            errors.push("Reconnect delay too short (minimum: 10ms)".to_string());
363        } else if self.reconnect_delay.as_secs() > 60 {
364            errors.push("Reconnect delay too long (maximum: 60s)".to_string());
365        }
366
367        errors
368    }
369}
370
371/// Transport configuration
372#[derive(Debug, Clone, Deserialize, Serialize)]
373pub struct TransportConfig {
374    /// Whether to enable compression
375    pub compression_enabled: bool,
376
377    /// Whether to enable encryption
378    pub encryption_enabled: bool,
379
380    /// Maximum allowed payload size in bytes
381    pub max_payload_size: usize,
382
383    /// Compression level (when compression is enabled)
384    pub compression_level: i32,
385
386    /// Minimum payload size (bytes) before compression is applied
387    /// Payloads smaller than this threshold should bypass compression to reduce overhead
388    #[serde(default)]
389    pub compression_threshold_bytes: usize,
390}
391
392impl Default for TransportConfig {
393    fn default() -> Self {
394        Self {
395            compression_enabled: ENABLE_COMPRESSION,
396            encryption_enabled: ENABLE_ENCRYPTION,
397            max_payload_size: MAX_PAYLOAD_SIZE,
398            compression_level: 6, // Default compression level (medium)
399            compression_threshold_bytes: 512,
400        }
401    }
402}
403
404impl TransportConfig {
405    /// Validate transport configuration
406    pub fn validate(&self) -> Vec<String> {
407        let mut errors = Vec::new();
408
409        // Validate max payload size
410        if self.max_payload_size == 0 {
411            errors.push("Max payload size cannot be 0".to_string());
412        } else if self.max_payload_size < 1024 {
413            errors.push("Max payload size too small (minimum: 1 KB)".to_string());
414        } else if self.max_payload_size > 100 * 1024 * 1024 {
415            errors.push(format!(
416                "Max payload size too large: {} bytes (maximum recommended: 100 MB)",
417                self.max_payload_size
418            ));
419        }
420
421        // Validate compression settings
422        if self.compression_enabled {
423            if self.compression_level < 1 || self.compression_level > 22 {
424                errors.push(format!(
425                    "Invalid compression level: {} (valid range: 1-22)",
426                    self.compression_level
427                ));
428            }
429
430            if self.compression_threshold_bytes > self.max_payload_size {
431                errors.push(
432                    "Compression threshold cannot be larger than max payload size".to_string(),
433                );
434            }
435        }
436
437        // Warn if encryption is disabled
438        if !self.encryption_enabled {
439            errors.push(
440                "WARNING: Encryption is disabled - not recommended for production".to_string(),
441            );
442        }
443
444        errors
445    }
446}
447
448/// Logging configuration
449#[derive(Debug, Clone, Deserialize, Serialize)]
450pub struct LoggingConfig {
451    /// Application name for logs
452    pub app_name: String,
453
454    /// Log level
455    #[serde(with = "log_level_serde")]
456    pub log_level: Level,
457
458    /// Whether to log to console
459    pub log_to_console: bool,
460
461    /// Whether to log to file
462    pub log_to_file: bool,
463
464    /// Path to log file (if log_to_file is true)
465    pub log_file_path: Option<String>,
466
467    /// Whether to use JSON formatting for logs
468    pub json_format: bool,
469}
470
471impl Default for LoggingConfig {
472    fn default() -> Self {
473        Self {
474            app_name: String::from("network-protocol"),
475            log_level: Level::INFO,
476            log_to_console: true,
477            log_to_file: false,
478            log_file_path: None,
479            json_format: false,
480        }
481    }
482}
483
484impl LoggingConfig {
485    /// Validate logging configuration
486    pub fn validate(&self) -> Vec<String> {
487        let mut errors = Vec::new();
488
489        // Validate app name
490        if self.app_name.is_empty() {
491            errors.push("Application name cannot be empty".to_string());
492        } else if self.app_name.len() > 64 {
493            errors.push(format!(
494                "Application name too long: {} characters (maximum: 64)",
495                self.app_name.len()
496            ));
497        }
498
499        // Validate file logging configuration
500        if self.log_to_file {
501            if let Some(ref path) = self.log_file_path {
502                // Check if parent directory exists (if path is absolute)
503                if let Some(parent) = std::path::Path::new(path).parent() {
504                    if !parent.as_os_str().is_empty() && !parent.exists() {
505                        errors.push(format!(
506                            "Log file directory does not exist: {}",
507                            parent.display()
508                        ));
509                    }
510                }
511            } else {
512                errors.push("log_file_path must be specified when log_to_file is true".to_string());
513            }
514        }
515
516        // Validate at least one output is enabled
517        if !self.log_to_console && !self.log_to_file {
518            errors
519                .push("At least one logging output (console or file) must be enabled".to_string());
520        }
521
522        errors
523    }
524}
525
526/// Helper module for Duration serialization/deserialization
527mod duration_serde {
528    use serde::{Deserialize, Deserializer, Serialize, Serializer};
529    use std::time::Duration;
530
531    pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
532    where
533        S: Serializer,
534    {
535        let millis = duration.as_millis() as u64;
536        millis.serialize(serializer)
537    }
538
539    pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
540    where
541        D: Deserializer<'de>,
542    {
543        let millis = u64::deserialize(deserializer)?;
544        Ok(Duration::from_millis(millis))
545    }
546}
547
548/// Helper module for tracing::Level serialization/deserialization
549mod log_level_serde {
550    use serde::{Deserialize, Deserializer, Serialize, Serializer};
551    use std::str::FromStr;
552    use tracing::Level;
553
554    pub fn serialize<S>(level: &Level, serializer: S) -> Result<S::Ok, S::Error>
555    where
556        S: Serializer,
557    {
558        let level_str = match *level {
559            Level::TRACE => "trace",
560            Level::DEBUG => "debug",
561            Level::INFO => "info",
562            Level::WARN => "warn",
563            Level::ERROR => "error",
564        };
565        level_str.serialize(serializer)
566    }
567
568    pub fn deserialize<'de, D>(deserializer: D) -> Result<Level, D::Error>
569    where
570        D: Deserializer<'de>,
571    {
572        let level_str = String::deserialize(deserializer)?;
573        Level::from_str(&level_str)
574            .map_err(|_| serde::de::Error::custom(format!("Invalid log level: {level_str}")))
575    }
576}