1use crate::config::SiteConfig;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::fs;
5use std::path::Path;
6
/// Top-level server configuration, as parsed from TOML by `ServerConfig::load_from_file`.
///
/// `server` and `sites` are required. The `logging`, `performance` and `security`
/// sections are optional; when a section is absent its `Default` impl is used.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ServerConfig {
    /// The `[server]` table: name, version and description.
    pub server: ServerInfo,
    /// Configured sites; `validate` requires at least one, and exactly one marked `default`.
    pub sites: Vec<SiteConfig>,
    /// Optional `[logging]` section.
    #[serde(default)]
    pub logging: LoggingConfig,
    /// Optional `[performance]` section.
    #[serde(default)]
    pub performance: PerformanceConfig,
    /// Optional `[security]` section.
    #[serde(default)]
    pub security: SecurityConfig,
}
18
/// Identity of this server instance (the `[server]` table).
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ServerInfo {
    /// Server name; `ServerConfig::validate` rejects an empty value.
    pub name: String,
    /// Version string; defaults to this crate's `CARGO_PKG_VERSION` captured at build time.
    #[serde(default = "default_version")]
    pub version: String,
    /// Free-form description; defaults to an empty string.
    #[serde(default)]
    pub description: String,
}
27
28#[derive(Debug, Deserialize, Serialize, Clone)]
29pub struct LoggingConfig {
30 #[serde(default = "default_log_level")]
31 pub level: String,
32 #[serde(default)]
33 pub access_log: Option<String>,
34 #[serde(default)]
35 pub error_log: Option<String>,
36 #[serde(default = "default_log_format")]
37 pub format: String,
38 #[serde(default)]
39 pub log_requests: bool,
40}
41
42#[derive(Debug, Deserialize, Serialize, Clone)]
43pub struct PerformanceConfig {
44 #[serde(default = "default_worker_threads")]
45 pub worker_threads: usize,
46 #[serde(default = "default_max_connections")]
47 pub max_connections: usize,
48 #[serde(default = "default_keep_alive_timeout")]
49 pub keep_alive_timeout: u64,
50 #[serde(default = "default_request_timeout")]
51 pub request_timeout: u64,
52 #[serde(default = "default_buffer_size")]
53 pub read_buffer_size: String,
54 #[serde(default = "default_buffer_size")]
55 pub write_buffer_size: String,
56}
57
58#[derive(Debug, Deserialize, Serialize, Clone)]
59pub struct SecurityConfig {
60 #[serde(default)]
61 pub hide_server_header: bool,
62 #[serde(default = "default_max_request_size")]
63 pub max_request_size: String,
64 #[serde(default)]
65 pub allowed_origins: Vec<String>,
66 #[serde(default)]
67 pub security_headers: HashMap<String, String>,
68 #[serde(default)]
69 pub rate_limiting: Option<RateLimitConfig>,
70}
71
/// Rate-limiting parameters (the `rate_limiting` table inside `[security]`).
///
/// `SecurityConfig::validate` rejects a zero `requests_per_minute` or `burst_size`.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct RateLimitConfig {
    /// Allowed requests per minute; must be greater than 0.
    pub requests_per_minute: u32,
    /// Burst size; must be greater than 0.
    /// NOTE(review): presumably the token-bucket capacity — confirm with the limiter.
    pub burst_size: u32,
    /// Exempt entries; empty by default.
    /// NOTE(review): entry format (IP, CIDR, hostname?) is not defined here — confirm.
    #[serde(default)]
    pub whitelist: Vec<String>,
}
79
80fn default_version() -> String {
82 env!("CARGO_PKG_VERSION").to_string()
83}
84
/// Default for `LoggingConfig::level`.
fn default_log_level() -> String {
    String::from("info")
}

/// Default for `LoggingConfig::format`.
fn default_log_format() -> String {
    String::from("combined")
}
92
93fn default_worker_threads() -> usize {
94 num_cpus::get().max(1)
95}
96
/// Default for `PerformanceConfig::max_connections`.
fn default_max_connections() -> usize {
    1_000
}

/// Default for `PerformanceConfig::keep_alive_timeout`.
/// NOTE(review): unit presumably seconds — confirm where the value is consumed.
fn default_keep_alive_timeout() -> u64 {
    60
}

/// Default for `PerformanceConfig::request_timeout`.
/// NOTE(review): unit presumably seconds — confirm where the value is consumed.
fn default_request_timeout() -> u64 {
    30
}

/// Default for both `PerformanceConfig` buffer sizes, as a size string.
fn default_buffer_size() -> String {
    String::from("32KB")
}

/// Default for `SecurityConfig::max_request_size`, as a size string.
fn default_max_request_size() -> String {
    String::from("10MB")
}
116
117impl Default for LoggingConfig {
118 fn default() -> Self {
119 Self {
120 level: default_log_level(),
121 access_log: None,
122 error_log: None,
123 format: default_log_format(),
124 log_requests: true,
125 }
126 }
127}
128
129impl Default for PerformanceConfig {
130 fn default() -> Self {
131 Self {
132 worker_threads: default_worker_threads(),
133 max_connections: default_max_connections(),
134 keep_alive_timeout: default_keep_alive_timeout(),
135 request_timeout: default_request_timeout(),
136 read_buffer_size: default_buffer_size(),
137 write_buffer_size: default_buffer_size(),
138 }
139 }
140}
141
142impl Default for SecurityConfig {
143 fn default() -> Self {
144 let mut security_headers = HashMap::new();
145 security_headers.insert("X-Frame-Options".to_string(), "DENY".to_string());
146 security_headers.insert("X-Content-Type-Options".to_string(), "nosniff".to_string());
147 security_headers.insert("X-XSS-Protection".to_string(), "1; mode=block".to_string());
148 security_headers.insert(
149 "Referrer-Policy".to_string(),
150 "strict-origin-when-cross-origin".to_string(),
151 );
152
153 Self {
154 hide_server_header: false,
155 max_request_size: default_max_request_size(),
156 allowed_origins: vec![],
157 security_headers,
158 rate_limiting: None,
159 }
160 }
161}
162
163impl ServerConfig {
164 pub fn load_from_file(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
165 let content = fs::read_to_string(path)?;
166 let mut config: ServerConfig = toml::from_str(&content)?;
167
168 config.validate()?;
170
171 config.post_process()?;
173
174 Ok(config)
175 }
176
177 pub fn save_to_file(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
178 let content = toml::to_string_pretty(self)?;
179
180 if let Some(parent) = Path::new(path).parent() {
182 fs::create_dir_all(parent)?;
183 }
184
185 fs::write(path, content)?;
186 log::info!("Configuration saved to {}", path);
187 Ok(())
188 }
189
190 pub fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
191 if self.server.name.is_empty() {
193 return Err("Server name cannot be empty".into());
194 }
195
196 if self.sites.is_empty() {
198 return Err("At least one site must be configured".into());
199 }
200
201 let mut default_sites = 0;
202 let mut used_ports = std::collections::HashSet::new();
203
204 for (i, site) in self.sites.iter().enumerate() {
205 site.validate().map_err(|e| format!("Site {}: {}", i, e))?;
206
207 if site.default {
208 default_sites += 1;
209 }
210
211 let port_key = (&site.hostname, site.port);
213 if used_ports.contains(&port_key) {
214 return Err(format!(
215 "Duplicate hostname:port combination: {}:{}",
216 site.hostname, site.port
217 )
218 .into());
219 }
220 used_ports.insert(port_key);
221 }
222
223 if default_sites == 0 {
224 return Err("At least one site must be marked as default".into());
225 }
226 if default_sites > 1 {
227 return Err("Only one site can be marked as default".into());
228 }
229
230 for site in &self.sites {
232 if site.ssl.enabled {
233 if site.ssl.auto_cert {
234 if let Some(acme) = &site.ssl.acme {
235 if acme.email.is_empty() {
236 return Err(format!(
237 "Site '{}': ACME email is required when auto_cert is enabled",
238 site.name
239 )
240 .into());
241 }
242 } else {
243 return Err(format!(
244 "Site '{}': ACME configuration is required when auto_cert is enabled",
245 site.name
246 )
247 .into());
248 }
249 } else if site.ssl.cert_file.is_none() || site.ssl.key_file.is_none() {
250 return Err(format!(
251 "Site '{}': Manual SSL requires both cert_file and key_file",
252 site.name
253 )
254 .into());
255 }
256 }
257 }
258
259 self.performance.validate()?;
261
262 self.security.validate()?;
264
265 Ok(())
266 }
267
268 fn post_process(&mut self) -> Result<(), Box<dyn std::error::Error>> {
269 self.sites.sort_by(|a, b| match (a.default, b.default) {
271 (true, false) => std::cmp::Ordering::Less,
272 (false, true) => std::cmp::Ordering::Greater,
273 _ => a.name.cmp(&b.name),
274 });
275
276 Ok(())
280 }
281
282 pub fn find_site_by_host_port(&self, host: &str, port: u16) -> Option<&SiteConfig> {
283 for site in &self.sites {
285 if site.hostname == host && site.port == port {
286 return Some(site);
287 }
288 }
289
290 for site in &self.sites {
292 if site.port == port {
293 return Some(site);
294 }
295 }
296
297 self.sites.iter().find(|site| site.default)
299 }
300
301 pub fn get_ssl_domains(&self) -> Vec<String> {
302 self.sites
303 .iter()
304 .filter_map(|site| {
305 if site.ssl.enabled {
306 let mut domains = vec![site.hostname.clone()];
307 domains.extend(site.ssl.domains.clone());
308 Some(domains)
309 } else {
310 None
311 }
312 })
313 .flatten()
314 .collect()
315 }
316
317 pub fn get_site_by_domain(&self, domain: &str) -> Option<&SiteConfig> {
318 self.sites
319 .iter()
320 .find(|site| site.hostname == domain || site.ssl.domains.contains(&domain.to_string()))
321 }
322
323 pub fn reload_from_file(&mut self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
324 let new_config = Self::load_from_file(path)?;
325 *self = new_config;
326 log::info!("Configuration reloaded from {}", path);
327 Ok(())
328 }
329}
330
331impl PerformanceConfig {
332 fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
333 if self.worker_threads == 0 {
334 return Err("Worker threads must be greater than 0".into());
335 }
336
337 if self.max_connections == 0 {
338 return Err("Max connections must be greater than 0".into());
339 }
340
341 if self.keep_alive_timeout == 0 {
342 return Err("Keep alive timeout must be greater than 0".into());
343 }
344
345 if self.request_timeout == 0 {
346 return Err("Request timeout must be greater than 0".into());
347 }
348
349 self.parse_buffer_size(&self.read_buffer_size)
351 .map_err(|_| "Invalid read buffer size format")?;
352 self.parse_buffer_size(&self.write_buffer_size)
353 .map_err(|_| "Invalid write buffer size format")?;
354
355 Ok(())
356 }
357
358 pub fn parse_buffer_size(&self, size_str: &str) -> Result<usize, Box<dyn std::error::Error>> {
359 let size_str = size_str.trim().to_uppercase();
360
361 if let Some(value) = size_str.strip_suffix("KB") {
362 Ok(value.parse::<usize>()? * 1024)
363 } else if let Some(value) = size_str.strip_suffix("MB") {
364 Ok(value.parse::<usize>()? * 1024 * 1024)
365 } else if let Some(value) = size_str.strip_suffix("GB") {
366 Ok(value.parse::<usize>()? * 1024 * 1024 * 1024)
367 } else if let Some(value) = size_str.strip_suffix("B") {
368 Ok(value.parse::<usize>()?)
369 } else {
370 Ok(size_str.parse::<usize>()?)
372 }
373 }
374}
375
376impl SecurityConfig {
377 fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
378 self.parse_size(&self.max_request_size)
380 .map_err(|_| "Invalid max request size format")?;
381
382 if let Some(rate_limit) = &self.rate_limiting {
384 if rate_limit.requests_per_minute == 0 {
385 return Err("Rate limit requests per minute must be greater than 0".into());
386 }
387 if rate_limit.burst_size == 0 {
388 return Err("Rate limit burst size must be greater than 0".into());
389 }
390 }
391
392 Ok(())
393 }
394
395 pub fn parse_size(&self, size_str: &str) -> Result<usize, Box<dyn std::error::Error>> {
396 let size_str = size_str.trim().to_uppercase();
397
398 if let Some(value) = size_str.strip_suffix("KB") {
399 Ok(value.parse::<usize>()? * 1024)
400 } else if let Some(value) = size_str.strip_suffix("MB") {
401 Ok(value.parse::<usize>()? * 1024 * 1024)
402 } else if let Some(value) = size_str.strip_suffix("GB") {
403 Ok(value.parse::<usize>()? * 1024 * 1024 * 1024)
404 } else if let Some(value) = size_str.strip_suffix("B") {
405 Ok(value.parse::<usize>()?)
406 } else {
407 Ok(size_str.parse::<usize>()?)
409 }
410 }
411}
412
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    #[test]
    fn test_buffer_size_parsing() {
        let config = PerformanceConfig::default();

        assert_eq!(config.parse_buffer_size("1024").unwrap(), 1024);
        assert_eq!(config.parse_buffer_size("1KB").unwrap(), 1024);
        assert_eq!(config.parse_buffer_size("1MB").unwrap(), 1024 * 1024);
        assert_eq!(config.parse_buffer_size("1GB").unwrap(), 1024 * 1024 * 1024);
        assert_eq!(config.parse_buffer_size("500B").unwrap(), 500);

        assert!(config.parse_buffer_size("invalid").is_err());
        assert!(config.parse_buffer_size("").is_err());
    }

    #[test]
    fn test_security_config_validation() {
        let mut config = SecurityConfig::default();
        assert!(config.validate().is_ok());

        config.max_request_size = "invalid".to_string();
        assert!(config.validate().is_err());

        config.max_request_size = "10MB".to_string();
        config.rate_limiting = Some(RateLimitConfig {
            requests_per_minute: 0,
            burst_size: 10,
            whitelist: vec![],
        });
        assert!(config.validate().is_err());
    }

    /// Pins the values produced by the section `Default` impls.
    #[test]
    fn test_section_defaults() {
        let logging = LoggingConfig::default();
        assert_eq!(logging.level, "info");
        assert_eq!(logging.format, "combined");
        assert!(logging.log_requests);

        let security = SecurityConfig::default();
        assert_eq!(security.max_request_size, "10MB");
        assert_eq!(
            security.security_headers.get("X-Frame-Options").map(String::as_str),
            Some("DENY")
        );
    }

    // Plain #[test]: nothing here awaits, so no async runtime is needed.
    #[test]
    fn test_config_save_load() {
        use crate::config::SiteConfig;

        let config = ServerConfig {
            server: ServerInfo {
                name: "test-server".to_string(),
                version: "1.0.0".to_string(),
                description: "Test server".to_string(),
            },
            sites: vec![SiteConfig {
                name: "test-site".to_string(),
                hostname: "localhost".to_string(),
                port: 8080,
                static_dir: "/tmp/static".to_string(),
                default: true,
                api_only: false,
                headers: HashMap::new(),
                redirect_to_https: false,
                index_files: vec!["index.html".to_string()],
                error_pages: HashMap::new(),
                compression: Default::default(),
                cache: Default::default(),
                access_control: Default::default(),
                ssl: Default::default(),
                proxy: Default::default(),
            }],
            logging: LoggingConfig::default(),
            performance: PerformanceConfig::default(),
            security: SecurityConfig::default(),
        };

        // Keep `temp_file` alive for the whole test so the path stays valid.
        let temp_file = NamedTempFile::new().unwrap();
        let path = temp_file.path().to_str().unwrap();

        config.save_to_file(path).unwrap();

        let loaded_config = ServerConfig::load_from_file(path).unwrap();
        assert_eq!(config.server.name, loaded_config.server.name);
        assert_eq!(config.sites.len(), loaded_config.sites.len());
        assert_eq!(config.sites[0].hostname, loaded_config.sites[0].hostname);
        assert_eq!(config.sites[0].port, loaded_config.sites[0].port);
    }
}
492}