1use crate::error::{ProtocolError, Result};
19use crate::utils::timeout;
20use serde::{Deserialize, Serialize};
21use std::fs::File;
22use std::io::Read;
23use std::path::Path;
24use std::time::Duration;
25use tracing::Level;
26
/// Protocol version identifier.
pub const PROTOCOL_VERSION: u8 = 1;

/// Four-byte magic marker: the ASCII bytes `N`, `P`, `R`, `O`.
pub const MAGIC_BYTES: [u8; 4] = *b"NPRO";

/// Default payload size limit in bytes (16 MiB).
pub const MAX_PAYLOAD_SIZE: usize = 16 << 20;

/// Default for `TransportConfig::compression_enabled`.
pub const ENABLE_COMPRESSION: bool = false;

/// Default for `TransportConfig::encryption_enabled`.
pub const ENABLE_ENCRYPTION: bool = true;
41
/// Top-level configuration grouping the server, client, transport and
/// logging sections.
///
/// Each section carries `#[serde(default)]`, so a section missing from the
/// serialized input is filled in from that section type's `Default` impl.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct NetworkConfig {
    /// Server section.
    #[serde(default)]
    pub server: ServerConfig,

    /// Client section.
    #[serde(default)]
    pub client: ClientConfig,

    /// Transport section (payload limit, compression, encryption flags).
    #[serde(default)]
    pub transport: TransportConfig,

    /// Logging section.
    #[serde(default)]
    pub logging: LoggingConfig,
}
61
62impl NetworkConfig {
65 pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
67 let mut file = File::open(path)
68 .map_err(|e| ProtocolError::ConfigError(format!("Failed to open config file: {e}")))?;
69
70 let mut contents = String::new();
71 file.read_to_string(&mut contents)
72 .map_err(|e| ProtocolError::ConfigError(format!("Failed to read config file: {e}")))?;
73
74 Self::from_toml(&contents)
75 }
76
77 pub fn from_toml(content: &str) -> Result<Self> {
79 toml::from_str::<Self>(content)
80 .map_err(|e| ProtocolError::ConfigError(format!("Failed to parse TOML: {e}")))
81 }
82
83 pub fn from_env() -> Result<Self> {
85 let mut config = Self::default();
87
88 if let Ok(addr) = std::env::var("NETWORK_PROTOCOL_SERVER_ADDRESS") {
90 config.server.address = addr;
91 }
92
93 if let Ok(capacity) = std::env::var("NETWORK_PROTOCOL_BACKPRESSURE_LIMIT") {
94 if let Ok(val) = capacity.parse::<usize>() {
95 config.server.backpressure_limit = val;
96 }
97 }
98
99 if let Ok(timeout) = std::env::var("NETWORK_PROTOCOL_CONNECTION_TIMEOUT_MS") {
100 if let Ok(val) = timeout.parse::<u64>() {
101 config.server.connection_timeout = Duration::from_millis(val);
102 config.client.connection_timeout = Duration::from_millis(val);
103 }
104 }
105
106 if let Ok(heartbeat) = std::env::var("NETWORK_PROTOCOL_HEARTBEAT_INTERVAL_MS") {
107 if let Ok(val) = heartbeat.parse::<u64>() {
108 config.server.heartbeat_interval = Duration::from_millis(val);
109 }
110 }
111
112 Ok(config)
115 }
116
117 pub fn default_with_overrides<F>(mutator: F) -> Self
119 where
120 F: FnOnce(&mut Self),
121 {
122 let mut config = Self::default();
123 mutator(&mut config);
124 config
125 }
126
127 pub fn example_config() -> String {
129 toml::to_string_pretty(&Self::default())
130 .unwrap_or_else(|_| String::from("# Failed to generate example config"))
131 }
132
133 pub fn save_to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
135 let content = toml::to_string_pretty(self)
136 .map_err(|e| ProtocolError::ConfigError(format!("Failed to serialize config: {e}")))?;
137
138 std::fs::write(path, content)
139 .map_err(|e| ProtocolError::ConfigError(format!("Failed to write config file: {e}")))?;
140
141 Ok(())
142 }
143
144 pub fn validate(&self) -> Vec<String> {
148 let mut errors = Vec::new();
149
150 errors.extend(self.server.validate());
152
153 errors.extend(self.client.validate());
155
156 errors.extend(self.transport.validate());
158
159 errors.extend(self.logging.validate());
161
162 errors
163 }
164
165 pub fn validate_strict(&self) -> Result<()> {
167 let errors = self.validate();
168 if errors.is_empty() {
169 Ok(())
170 } else {
171 Err(ProtocolError::ConfigError(format!(
172 "Configuration validation failed:\n - {}",
173 errors.join("\n - ")
174 )))
175 }
176 }
177}
178
/// Server section of the configuration.
///
/// No field has a serde default, so when this section is present in the
/// input every field must be supplied. Duration fields go through
/// `duration_serde`, which stores them as an integer number of milliseconds.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ServerConfig {
    /// Socket address as text; `validate` requires it to parse as a
    /// `std::net::SocketAddr` (IP and port).
    pub address: String,

    /// Backpressure limit; `validate` requires a value in `1..=1_000_000`.
    pub backpressure_limit: usize,

    /// Connection timeout (milliseconds when serialized).
    #[serde(with = "duration_serde")]
    pub connection_timeout: Duration,

    /// Heartbeat interval (milliseconds when serialized).
    #[serde(with = "duration_serde")]
    pub heartbeat_interval: Duration,

    /// Shutdown timeout (milliseconds when serialized).
    #[serde(with = "duration_serde")]
    pub shutdown_timeout: Duration,

    /// Maximum number of connections; `validate` rejects 0.
    pub max_connections: usize,
}
203
204impl Default for ServerConfig {
205 fn default() -> Self {
206 Self {
207 address: String::from("127.0.0.1:9000"),
208 backpressure_limit: 32,
209 connection_timeout: timeout::DEFAULT_TIMEOUT,
210 heartbeat_interval: timeout::KEEPALIVE_INTERVAL,
211 shutdown_timeout: timeout::SHUTDOWN_TIMEOUT,
212 max_connections: 1000,
213 }
214 }
215}
216
217impl ServerConfig {
218 pub fn validate(&self) -> Vec<String> {
220 let mut errors = Vec::new();
221
222 if self.address.is_empty() {
224 errors.push("Server address cannot be empty".to_string());
225 } else if self.address.parse::<std::net::SocketAddr>().is_err() {
226 errors.push(format!(
227 "Invalid server address format: '{}' (expected format: '0.0.0.0:8080')",
228 self.address
229 ));
230 }
231
232 if self.backpressure_limit == 0 {
234 errors.push("Backpressure limit must be greater than 0".to_string());
235 } else if self.backpressure_limit > 1_000_000 {
236 errors.push(format!(
237 "Backpressure limit too large: {} (max recommended: 1,000,000)",
238 self.backpressure_limit
239 ));
240 }
241
242 if self.connection_timeout.as_millis() < 100 {
244 errors.push("Connection timeout too short (minimum: 100ms)".to_string());
245 } else if self.connection_timeout.as_secs() > 300 {
246 errors.push("Connection timeout too long (maximum: 300s)".to_string());
247 }
248
249 if self.heartbeat_interval.as_millis() < 100 {
251 errors.push("Heartbeat interval too short (minimum: 100ms)".to_string());
252 } else if self.heartbeat_interval.as_secs() > 3600 {
253 errors.push("Heartbeat interval too long (maximum: 1 hour)".to_string());
254 }
255
256 if self.shutdown_timeout.as_secs() < 1 {
258 errors.push("Shutdown timeout too short (minimum: 1s)".to_string());
259 } else if self.shutdown_timeout.as_secs() > 60 {
260 errors.push("Shutdown timeout too long (maximum: 60s)".to_string());
261 }
262
263 if self.max_connections == 0 {
265 errors.push("Max connections must be greater than 0".to_string());
266 } else if self.max_connections > 100_000 {
267 errors.push(format!(
268 "Max connections very high: {} (ensure system resources can support this)",
269 self.max_connections
270 ));
271 }
272
273 errors
274 }
275}
276
/// Client section of the configuration.
///
/// No field has a serde default, so when this section is present in the
/// input every field must be supplied. Duration fields go through
/// `duration_serde`, which stores them as an integer number of milliseconds.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ClientConfig {
    /// Address as text; checked by `validate`.
    pub address: String,

    /// Connection timeout (milliseconds when serialized).
    #[serde(with = "duration_serde")]
    pub connection_timeout: Duration,

    /// Operation timeout (milliseconds when serialized).
    #[serde(with = "duration_serde")]
    pub operation_timeout: Duration,

    /// Response timeout (milliseconds when serialized).
    #[serde(with = "duration_serde")]
    pub response_timeout: Duration,

    /// Heartbeat interval (milliseconds when serialized). `validate` does
    /// not check this field.
    #[serde(with = "duration_serde")]
    pub heartbeat_interval: Duration,

    /// Whether automatic reconnection is enabled.
    pub auto_reconnect: bool,

    /// Maximum number of reconnect attempts; `validate` rejects 0 while
    /// `auto_reconnect` is true.
    pub max_reconnect_attempts: u32,

    /// Reconnect delay (milliseconds when serialized).
    #[serde(with = "duration_serde")]
    pub reconnect_delay: Duration,
}
309
310impl Default for ClientConfig {
311 fn default() -> Self {
312 Self {
313 address: String::from("127.0.0.1:9000"),
314 connection_timeout: timeout::DEFAULT_TIMEOUT,
315 operation_timeout: Duration::from_secs(3),
316 response_timeout: Duration::from_secs(30),
317 heartbeat_interval: timeout::KEEPALIVE_INTERVAL,
318 auto_reconnect: true,
319 max_reconnect_attempts: 3,
320 reconnect_delay: Duration::from_secs(1),
321 }
322 }
323}
324
325impl ClientConfig {
326 pub fn validate(&self) -> Vec<String> {
328 let mut errors = Vec::new();
329
330 if self.address.is_empty() {
332 errors.push("Client address cannot be empty".to_string());
333 } else if self.address.parse::<std::net::SocketAddr>().is_err() {
334 errors.push(format!(
335 "Invalid client address format: '{}' (expected format: 'example.com:8080')",
336 self.address
337 ));
338 }
339
340 if self.connection_timeout.as_millis() < 100 {
342 errors.push("Connection timeout too short (minimum: 100ms)".to_string());
343 }
344
345 if self.operation_timeout.as_millis() < 10 {
346 errors.push("Operation timeout too short (minimum: 10ms)".to_string());
347 }
348
349 if self.response_timeout.as_millis() < 100 {
350 errors.push("Response timeout too short (minimum: 100ms)".to_string());
351 }
352
353 if self.auto_reconnect && self.max_reconnect_attempts == 0 {
355 errors.push(
356 "Max reconnect attempts must be greater than 0 when auto_reconnect is enabled"
357 .to_string(),
358 );
359 }
360
361 if self.reconnect_delay.as_millis() < 10 {
362 errors.push("Reconnect delay too short (minimum: 10ms)".to_string());
363 } else if self.reconnect_delay.as_secs() > 60 {
364 errors.push("Reconnect delay too long (maximum: 60s)".to_string());
365 }
366
367 errors
368 }
369}
370
/// Transport section of the configuration.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TransportConfig {
    /// Whether compression is enabled.
    pub compression_enabled: bool,

    /// Whether encryption is enabled.
    pub encryption_enabled: bool,

    /// Maximum payload size in bytes.
    pub max_payload_size: usize,

    /// Compression level; `validate` accepts 1 through 22 when compression
    /// is enabled.
    pub compression_level: i32,

    /// Compression threshold in bytes.
    ///
    /// NOTE(review): the bare `#[serde(default)]` fills a missing value with
    /// `usize::default()`, i.e. 0, whereas `TransportConfig::default()` uses
    /// 512 — confirm the mismatch is intended.
    #[serde(default)]
    pub compression_threshold_bytes: usize,
}
391
392impl Default for TransportConfig {
393 fn default() -> Self {
394 Self {
395 compression_enabled: ENABLE_COMPRESSION,
396 encryption_enabled: ENABLE_ENCRYPTION,
397 max_payload_size: MAX_PAYLOAD_SIZE,
398 compression_level: 6, compression_threshold_bytes: 512,
400 }
401 }
402}
403
404impl TransportConfig {
405 pub fn validate(&self) -> Vec<String> {
407 let mut errors = Vec::new();
408
409 if self.max_payload_size == 0 {
411 errors.push("Max payload size cannot be 0".to_string());
412 } else if self.max_payload_size < 1024 {
413 errors.push("Max payload size too small (minimum: 1 KB)".to_string());
414 } else if self.max_payload_size > 100 * 1024 * 1024 {
415 errors.push(format!(
416 "Max payload size too large: {} bytes (maximum recommended: 100 MB)",
417 self.max_payload_size
418 ));
419 }
420
421 if self.compression_enabled {
423 if self.compression_level < 1 || self.compression_level > 22 {
424 errors.push(format!(
425 "Invalid compression level: {} (valid range: 1-22)",
426 self.compression_level
427 ));
428 }
429
430 if self.compression_threshold_bytes > self.max_payload_size {
431 errors.push(
432 "Compression threshold cannot be larger than max payload size".to_string(),
433 );
434 }
435 }
436
437 if !self.encryption_enabled {
439 errors.push(
440 "WARNING: Encryption is disabled - not recommended for production".to_string(),
441 );
442 }
443
444 errors
445 }
446}
447
/// Logging section of the configuration.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LoggingConfig {
    /// Application name; `validate` requires it to be non-empty and at most
    /// 64 in length.
    pub app_name: String,

    /// Log level, (de)serialized through `log_level_serde` as a string such
    /// as `"info"`.
    #[serde(with = "log_level_serde")]
    pub log_level: Level,

    /// Whether console output is enabled.
    pub log_to_console: bool,

    /// Whether file output is enabled.
    pub log_to_file: bool,

    /// Path of the log file; `validate` requires it to be set when
    /// `log_to_file` is true.
    pub log_file_path: Option<String>,

    /// Whether JSON output format is selected.
    pub json_format: bool,
}
470
471impl Default for LoggingConfig {
472 fn default() -> Self {
473 Self {
474 app_name: String::from("network-protocol"),
475 log_level: Level::INFO,
476 log_to_console: true,
477 log_to_file: false,
478 log_file_path: None,
479 json_format: false,
480 }
481 }
482}
483
484impl LoggingConfig {
485 pub fn validate(&self) -> Vec<String> {
487 let mut errors = Vec::new();
488
489 if self.app_name.is_empty() {
491 errors.push("Application name cannot be empty".to_string());
492 } else if self.app_name.len() > 64 {
493 errors.push(format!(
494 "Application name too long: {} characters (maximum: 64)",
495 self.app_name.len()
496 ));
497 }
498
499 if self.log_to_file {
501 if let Some(ref path) = self.log_file_path {
502 if let Some(parent) = std::path::Path::new(path).parent() {
504 if !parent.as_os_str().is_empty() && !parent.exists() {
505 errors.push(format!(
506 "Log file directory does not exist: {}",
507 parent.display()
508 ));
509 }
510 }
511 } else {
512 errors.push("log_file_path must be specified when log_to_file is true".to_string());
513 }
514 }
515
516 if !self.log_to_console && !self.log_to_file {
518 errors
519 .push("At least one logging output (console or file) must be enabled".to_string());
520 }
521
522 errors
523 }
524}
525
526mod duration_serde {
528 use serde::{Deserialize, Deserializer, Serialize, Serializer};
529 use std::time::Duration;
530
531 pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
532 where
533 S: Serializer,
534 {
535 let millis = duration.as_millis() as u64;
536 millis.serialize(serializer)
537 }
538
539 pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
540 where
541 D: Deserializer<'de>,
542 {
543 let millis = u64::deserialize(deserializer)?;
544 Ok(Duration::from_millis(millis))
545 }
546}
547
548mod log_level_serde {
550 use serde::{Deserialize, Deserializer, Serialize, Serializer};
551 use std::str::FromStr;
552 use tracing::Level;
553
554 pub fn serialize<S>(level: &Level, serializer: S) -> Result<S::Ok, S::Error>
555 where
556 S: Serializer,
557 {
558 let level_str = match *level {
559 Level::TRACE => "trace",
560 Level::DEBUG => "debug",
561 Level::INFO => "info",
562 Level::WARN => "warn",
563 Level::ERROR => "error",
564 };
565 level_str.serialize(serializer)
566 }
567
568 pub fn deserialize<'de, D>(deserializer: D) -> Result<Level, D::Error>
569 where
570 D: Deserializer<'de>,
571 {
572 let level_str = String::deserialize(deserializer)?;
573 Level::from_str(&level_str)
574 .map_err(|_| serde::de::Error::custom(format!("Invalid log level: {level_str}")))
575 }
576}