network_protocol/
config.rs

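//! # Configuration Management
//!
//! Centralized configuration for the network protocol library.
//!
//! This module provides structured configuration for servers and clients,
//! including connection parameters, timeouts, compression and encryption settings,
//! and logging options.
//!
//! ## Configuration Sources
//! - TOML files via `NetworkConfig::from_file()`, or TOML strings via `NetworkConfig::from_toml()`
//! - Direct instantiation with defaults (`NetworkConfig::default()`), optionally adjusted
//!   through `NetworkConfig::default_with_overrides()`
//! - Environment-variable overrides via `NetworkConfig::from_env()`
//!
//! ## Security Considerations
//! - The default compression threshold (512 bytes) lets small payloads bypass compression,
//!   trading a little bandwidth for lower CPU usage
//! - Keeping connection timeouts bounded helps mitigate slowloris-style attacks
//! - TLS settings are expected to enforce modern cryptography (TLS 1.2+); they are not
//!   part of the structures defined in this module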
17
18use crate::error::{ProtocolError, Result};
19use crate::utils::timeout;
20use serde::{Deserialize, Serialize};
21use std::fs::File;
22use std::io::Read;
23use std::path::Path;
24use std::time::Duration;
25use tracing::Level;
26
/// Protocol version currently spoken by this library.
pub const PROTOCOL_VERSION: u8 = 1;

/// Magic bytes used to identify protocol packets: ASCII `"NPRO"` (0x4E 0x50 0x52 0x4F).
pub const MAGIC_BYTES: [u8; 4] = *b"NPRO";

/// Largest allowed payload, in bytes: 16 MiB (16 * 1024 * 1024).
pub const MAX_PAYLOAD_SIZE: usize = 16 << 20;

/// Default for [`TransportConfig::compression_enabled`]: compression is opt-in.
pub const ENABLE_COMPRESSION: bool = false;

/// Default for [`TransportConfig::encryption_enabled`]: encryption is on unless disabled.
pub const ENABLE_ENCRYPTION: bool = true;
41
42/// Main network configuration structure that contains all configurable settings
43#[derive(Debug, Clone, Deserialize, Serialize, Default)]
44pub struct NetworkConfig {
45    /// Server-specific configuration
46    #[serde(default)]
47    pub server: ServerConfig,
48
49    /// Client-specific configuration
50    #[serde(default)]
51    pub client: ClientConfig,
52
53    /// Transport configuration
54    #[serde(default)]
55    pub transport: TransportConfig,
56
57    /// Logging configuration
58    #[serde(default)]
59    pub logging: LoggingConfig,
60}
61
62// Default implementation is now derived
63
64impl NetworkConfig {
65    /// Load configuration from a TOML file
66    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
67        let mut file = File::open(path)
68            .map_err(|e| ProtocolError::ConfigError(format!("Failed to open config file: {e}")))?;
69
70        let mut contents = String::new();
71        file.read_to_string(&mut contents)
72            .map_err(|e| ProtocolError::ConfigError(format!("Failed to read config file: {e}")))?;
73
74        Self::from_toml(&contents)
75    }
76
77    /// Load configuration from TOML string
78    pub fn from_toml(content: &str) -> Result<Self> {
79        toml::from_str::<Self>(content)
80            .map_err(|e| ProtocolError::ConfigError(format!("Failed to parse TOML: {e}")))
81    }
82
83    /// Load configuration from environment variables
84    pub fn from_env() -> Result<Self> {
85        // Start with defaults
86        let mut config = Self::default();
87
88        // Override with environment variables
89        if let Ok(addr) = std::env::var("NETWORK_PROTOCOL_SERVER_ADDRESS") {
90            config.server.address = addr;
91        }
92
93        if let Ok(capacity) = std::env::var("NETWORK_PROTOCOL_BACKPRESSURE_LIMIT") {
94            if let Ok(val) = capacity.parse::<usize>() {
95                config.server.backpressure_limit = val;
96            }
97        }
98
99        if let Ok(timeout) = std::env::var("NETWORK_PROTOCOL_CONNECTION_TIMEOUT_MS") {
100            if let Ok(val) = timeout.parse::<u64>() {
101                config.server.connection_timeout = Duration::from_millis(val);
102                config.client.connection_timeout = Duration::from_millis(val);
103            }
104        }
105
106        if let Ok(heartbeat) = std::env::var("NETWORK_PROTOCOL_HEARTBEAT_INTERVAL_MS") {
107            if let Ok(val) = heartbeat.parse::<u64>() {
108                config.server.heartbeat_interval = Duration::from_millis(val);
109            }
110        }
111
112        // Add more environment variables as needed
113
114        Ok(config)
115    }
116
117    /// Apply overrides to the default configuration
118    pub fn default_with_overrides<F>(mutator: F) -> Self
119    where
120        F: FnOnce(&mut Self),
121    {
122        let mut config = Self::default();
123        mutator(&mut config);
124        config
125    }
126
127    /// Generate example configuration file content
128    pub fn example_config() -> String {
129        toml::to_string_pretty(&Self::default())
130            .unwrap_or_else(|_| String::from("# Failed to generate example config"))
131    }
132
133    /// Save configuration to a file
134    pub fn save_to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
135        let content = toml::to_string_pretty(self)
136            .map_err(|e| ProtocolError::ConfigError(format!("Failed to serialize config: {e}")))?;
137
138        std::fs::write(path, content)
139            .map_err(|e| ProtocolError::ConfigError(format!("Failed to write config file: {e}")))?;
140
141        Ok(())
142    }
143}
144
145/// Server-specific configuration
146#[derive(Debug, Clone, Deserialize, Serialize)]
147pub struct ServerConfig {
148    /// Server listen address (e.g., "127.0.0.1:9000")
149    pub address: String,
150
151    /// Maximum number of messages in the backpressure queue
152    pub backpressure_limit: usize,
153
154    /// Timeout for client connections
155    #[serde(with = "duration_serde")]
156    pub connection_timeout: Duration,
157
158    /// Interval for sending heartbeat messages
159    #[serde(with = "duration_serde")]
160    pub heartbeat_interval: Duration,
161
162    /// Timeout for graceful server shutdown
163    #[serde(with = "duration_serde")]
164    pub shutdown_timeout: Duration,
165
166    /// Maximum number of concurrent connections
167    pub max_connections: usize,
168}
169
170impl Default for ServerConfig {
171    fn default() -> Self {
172        Self {
173            address: String::from("127.0.0.1:9000"),
174            backpressure_limit: 32,
175            connection_timeout: timeout::DEFAULT_TIMEOUT,
176            heartbeat_interval: timeout::KEEPALIVE_INTERVAL,
177            shutdown_timeout: timeout::SHUTDOWN_TIMEOUT,
178            max_connections: 1000,
179        }
180    }
181}
182
183/// Client-specific configuration
184#[derive(Debug, Clone, Deserialize, Serialize)]
185pub struct ClientConfig {
186    /// Target server address
187    pub address: String,
188
189    /// Timeout for connection attempts
190    #[serde(with = "duration_serde")]
191    pub connection_timeout: Duration,
192
193    /// Timeout for individual operations
194    #[serde(with = "duration_serde")]
195    pub operation_timeout: Duration,
196
197    /// Timeout for waiting for response messages
198    #[serde(with = "duration_serde")]
199    pub response_timeout: Duration,
200
201    /// Interval for sending heartbeat messages
202    #[serde(with = "duration_serde")]
203    pub heartbeat_interval: Duration,
204
205    /// Whether to automatically reconnect on connection loss
206    pub auto_reconnect: bool,
207
208    /// Maximum number of reconnect attempts before giving up
209    pub max_reconnect_attempts: u32,
210
211    /// Delay between reconnect attempts
212    #[serde(with = "duration_serde")]
213    pub reconnect_delay: Duration,
214}
215
216impl Default for ClientConfig {
217    fn default() -> Self {
218        Self {
219            address: String::from("127.0.0.1:9000"),
220            connection_timeout: timeout::DEFAULT_TIMEOUT,
221            operation_timeout: Duration::from_secs(3),
222            response_timeout: Duration::from_secs(30),
223            heartbeat_interval: timeout::KEEPALIVE_INTERVAL,
224            auto_reconnect: true,
225            max_reconnect_attempts: 3,
226            reconnect_delay: Duration::from_secs(1),
227        }
228    }
229}
230
231/// Transport configuration
232#[derive(Debug, Clone, Deserialize, Serialize)]
233pub struct TransportConfig {
234    /// Whether to enable compression
235    pub compression_enabled: bool,
236
237    /// Whether to enable encryption
238    pub encryption_enabled: bool,
239
240    /// Maximum allowed payload size in bytes
241    pub max_payload_size: usize,
242
243    /// Compression level (when compression is enabled)
244    pub compression_level: i32,
245
246    /// Minimum payload size (bytes) before compression is applied
247    /// Payloads smaller than this threshold should bypass compression to reduce overhead
248    #[serde(default)]
249    pub compression_threshold_bytes: usize,
250}
251
252impl Default for TransportConfig {
253    fn default() -> Self {
254        Self {
255            compression_enabled: ENABLE_COMPRESSION,
256            encryption_enabled: ENABLE_ENCRYPTION,
257            max_payload_size: MAX_PAYLOAD_SIZE,
258            compression_level: 6, // Default compression level (medium)
259            compression_threshold_bytes: 512,
260        }
261    }
262}
263
264/// Logging configuration
265#[derive(Debug, Clone, Deserialize, Serialize)]
266pub struct LoggingConfig {
267    /// Application name for logs
268    pub app_name: String,
269
270    /// Log level
271    #[serde(with = "log_level_serde")]
272    pub log_level: Level,
273
274    /// Whether to log to console
275    pub log_to_console: bool,
276
277    /// Whether to log to file
278    pub log_to_file: bool,
279
280    /// Path to log file (if log_to_file is true)
281    pub log_file_path: Option<String>,
282
283    /// Whether to use JSON formatting for logs
284    pub json_format: bool,
285}
286
287impl Default for LoggingConfig {
288    fn default() -> Self {
289        Self {
290            app_name: String::from("network-protocol"),
291            log_level: Level::INFO,
292            log_to_console: true,
293            log_to_file: false,
294            log_file_path: None,
295            json_format: false,
296        }
297    }
298}
299
300/// Helper module for Duration serialization/deserialization
301mod duration_serde {
302    use serde::{Deserialize, Deserializer, Serialize, Serializer};
303    use std::time::Duration;
304
305    pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
306    where
307        S: Serializer,
308    {
309        let millis = duration.as_millis() as u64;
310        millis.serialize(serializer)
311    }
312
313    pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
314    where
315        D: Deserializer<'de>,
316    {
317        let millis = u64::deserialize(deserializer)?;
318        Ok(Duration::from_millis(millis))
319    }
320}
321
322/// Helper module for tracing::Level serialization/deserialization
323mod log_level_serde {
324    use serde::{Deserialize, Deserializer, Serialize, Serializer};
325    use std::str::FromStr;
326    use tracing::Level;
327
328    pub fn serialize<S>(level: &Level, serializer: S) -> Result<S::Ok, S::Error>
329    where
330        S: Serializer,
331    {
332        let level_str = match *level {
333            Level::TRACE => "trace",
334            Level::DEBUG => "debug",
335            Level::INFO => "info",
336            Level::WARN => "warn",
337            Level::ERROR => "error",
338        };
339        level_str.serialize(serializer)
340    }
341
342    pub fn deserialize<'de, D>(deserializer: D) -> Result<Level, D::Error>
343    where
344        D: Deserializer<'de>,
345    {
346        let level_str = String::deserialize(deserializer)?;
347        Level::from_str(&level_str)
348            .map_err(|_| serde::de::Error::custom(format!("Invalid log level: {level_str}")))
349    }
350}