1use crate::error::{ProtocolError, Result};
19use crate::utils::timeout;
20use serde::{Deserialize, Serialize};
21use std::fs::File;
22use std::io::Read;
23use std::path::Path;
24use std::time::Duration;
25use tracing::Level;
26
/// Protocol version number.
pub const PROTOCOL_VERSION: u8 = 1;

/// Magic byte sequence: the ASCII characters `NPRO` (`0x4E 0x50 0x52 0x4F`).
pub const MAGIC_BYTES: [u8; 4] = *b"NPRO";

/// Default for `TransportConfig::max_payload_size`: `16 * 1024 * 1024` (16 MiB).
pub const MAX_PAYLOAD_SIZE: usize = 16 << 20;

/// Default for `TransportConfig::compression_enabled`.
pub const ENABLE_COMPRESSION: bool = false;

/// Default for `TransportConfig::encryption_enabled`.
pub const ENABLE_ENCRYPTION: bool = true;
41
42#[derive(Debug, Clone, Deserialize, Serialize, Default)]
44pub struct NetworkConfig {
45 #[serde(default)]
47 pub server: ServerConfig,
48
49 #[serde(default)]
51 pub client: ClientConfig,
52
53 #[serde(default)]
55 pub transport: TransportConfig,
56
57 #[serde(default)]
59 pub logging: LoggingConfig,
60}
61
62impl NetworkConfig {
65 pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
67 let mut file = File::open(path)
68 .map_err(|e| ProtocolError::ConfigError(format!("Failed to open config file: {e}")))?;
69
70 let mut contents = String::new();
71 file.read_to_string(&mut contents)
72 .map_err(|e| ProtocolError::ConfigError(format!("Failed to read config file: {e}")))?;
73
74 Self::from_toml(&contents)
75 }
76
77 pub fn from_toml(content: &str) -> Result<Self> {
79 toml::from_str::<Self>(content)
80 .map_err(|e| ProtocolError::ConfigError(format!("Failed to parse TOML: {e}")))
81 }
82
83 pub fn from_env() -> Result<Self> {
85 let mut config = Self::default();
87
88 if let Ok(addr) = std::env::var("NETWORK_PROTOCOL_SERVER_ADDRESS") {
90 config.server.address = addr;
91 }
92
93 if let Ok(capacity) = std::env::var("NETWORK_PROTOCOL_BACKPRESSURE_LIMIT") {
94 if let Ok(val) = capacity.parse::<usize>() {
95 config.server.backpressure_limit = val;
96 }
97 }
98
99 if let Ok(timeout) = std::env::var("NETWORK_PROTOCOL_CONNECTION_TIMEOUT_MS") {
100 if let Ok(val) = timeout.parse::<u64>() {
101 config.server.connection_timeout = Duration::from_millis(val);
102 config.client.connection_timeout = Duration::from_millis(val);
103 }
104 }
105
106 if let Ok(heartbeat) = std::env::var("NETWORK_PROTOCOL_HEARTBEAT_INTERVAL_MS") {
107 if let Ok(val) = heartbeat.parse::<u64>() {
108 config.server.heartbeat_interval = Duration::from_millis(val);
109 }
110 }
111
112 Ok(config)
115 }
116
117 pub fn default_with_overrides<F>(mutator: F) -> Self
119 where
120 F: FnOnce(&mut Self),
121 {
122 let mut config = Self::default();
123 mutator(&mut config);
124 config
125 }
126
127 pub fn example_config() -> String {
129 toml::to_string_pretty(&Self::default())
130 .unwrap_or_else(|_| String::from("# Failed to generate example config"))
131 }
132
133 pub fn save_to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
135 let content = toml::to_string_pretty(self)
136 .map_err(|e| ProtocolError::ConfigError(format!("Failed to serialize config: {e}")))?;
137
138 std::fs::write(path, content)
139 .map_err(|e| ProtocolError::ConfigError(format!("Failed to write config file: {e}")))?;
140
141 Ok(())
142 }
143}
144
145#[derive(Debug, Clone, Deserialize, Serialize)]
147pub struct ServerConfig {
148 pub address: String,
150
151 pub backpressure_limit: usize,
153
154 #[serde(with = "duration_serde")]
156 pub connection_timeout: Duration,
157
158 #[serde(with = "duration_serde")]
160 pub heartbeat_interval: Duration,
161
162 #[serde(with = "duration_serde")]
164 pub shutdown_timeout: Duration,
165
166 pub max_connections: usize,
168}
169
170impl Default for ServerConfig {
171 fn default() -> Self {
172 Self {
173 address: String::from("127.0.0.1:9000"),
174 backpressure_limit: 32,
175 connection_timeout: timeout::DEFAULT_TIMEOUT,
176 heartbeat_interval: timeout::KEEPALIVE_INTERVAL,
177 shutdown_timeout: timeout::SHUTDOWN_TIMEOUT,
178 max_connections: 1000,
179 }
180 }
181}
182
183#[derive(Debug, Clone, Deserialize, Serialize)]
185pub struct ClientConfig {
186 pub address: String,
188
189 #[serde(with = "duration_serde")]
191 pub connection_timeout: Duration,
192
193 #[serde(with = "duration_serde")]
195 pub operation_timeout: Duration,
196
197 #[serde(with = "duration_serde")]
199 pub response_timeout: Duration,
200
201 #[serde(with = "duration_serde")]
203 pub heartbeat_interval: Duration,
204
205 pub auto_reconnect: bool,
207
208 pub max_reconnect_attempts: u32,
210
211 #[serde(with = "duration_serde")]
213 pub reconnect_delay: Duration,
214}
215
216impl Default for ClientConfig {
217 fn default() -> Self {
218 Self {
219 address: String::from("127.0.0.1:9000"),
220 connection_timeout: timeout::DEFAULT_TIMEOUT,
221 operation_timeout: Duration::from_secs(3),
222 response_timeout: Duration::from_secs(30),
223 heartbeat_interval: timeout::KEEPALIVE_INTERVAL,
224 auto_reconnect: true,
225 max_reconnect_attempts: 3,
226 reconnect_delay: Duration::from_secs(1),
227 }
228 }
229}
230
231#[derive(Debug, Clone, Deserialize, Serialize)]
233pub struct TransportConfig {
234 pub compression_enabled: bool,
236
237 pub encryption_enabled: bool,
239
240 pub max_payload_size: usize,
242
243 pub compression_level: i32,
245
246 #[serde(default)]
249 pub compression_threshold_bytes: usize,
250}
251
252impl Default for TransportConfig {
253 fn default() -> Self {
254 Self {
255 compression_enabled: ENABLE_COMPRESSION,
256 encryption_enabled: ENABLE_ENCRYPTION,
257 max_payload_size: MAX_PAYLOAD_SIZE,
258 compression_level: 6, compression_threshold_bytes: 512,
260 }
261 }
262}
263
264#[derive(Debug, Clone, Deserialize, Serialize)]
266pub struct LoggingConfig {
267 pub app_name: String,
269
270 #[serde(with = "log_level_serde")]
272 pub log_level: Level,
273
274 pub log_to_console: bool,
276
277 pub log_to_file: bool,
279
280 pub log_file_path: Option<String>,
282
283 pub json_format: bool,
285}
286
287impl Default for LoggingConfig {
288 fn default() -> Self {
289 Self {
290 app_name: String::from("network-protocol"),
291 log_level: Level::INFO,
292 log_to_console: true,
293 log_to_file: false,
294 log_file_path: None,
295 json_format: false,
296 }
297 }
298}
299
300mod duration_serde {
302 use serde::{Deserialize, Deserializer, Serialize, Serializer};
303 use std::time::Duration;
304
305 pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
306 where
307 S: Serializer,
308 {
309 let millis = duration.as_millis() as u64;
310 millis.serialize(serializer)
311 }
312
313 pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
314 where
315 D: Deserializer<'de>,
316 {
317 let millis = u64::deserialize(deserializer)?;
318 Ok(Duration::from_millis(millis))
319 }
320}
321
322mod log_level_serde {
324 use serde::{Deserialize, Deserializer, Serialize, Serializer};
325 use std::str::FromStr;
326 use tracing::Level;
327
328 pub fn serialize<S>(level: &Level, serializer: S) -> Result<S::Ok, S::Error>
329 where
330 S: Serializer,
331 {
332 let level_str = match *level {
333 Level::TRACE => "trace",
334 Level::DEBUG => "debug",
335 Level::INFO => "info",
336 Level::WARN => "warn",
337 Level::ERROR => "error",
338 };
339 level_str.serialize(serializer)
340 }
341
342 pub fn deserialize<'de, D>(deserializer: D) -> Result<Level, D::Error>
343 where
344 D: Deserializer<'de>,
345 {
346 let level_str = String::deserialize(deserializer)?;
347 Level::from_str(&level_str)
348 .map_err(|_| serde::de::Error::custom(format!("Invalid log level: {level_str}")))
349 }
350}