1use crate::config::SiteConfig;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::fs;
5use std::path::Path;
6
7#[derive(Debug, Deserialize, Serialize, Clone)]
10pub struct ServerConfig {
11 pub server: ServerInfo,
13 pub sites: Vec<SiteConfig>,
15 #[serde(default)]
17 pub logging: LoggingConfig,
18 #[serde(default)]
20 pub performance: PerformanceConfig,
21 #[serde(default)]
23 pub security: SecurityConfig,
24 #[serde(default)]
26 pub management: ManagementConfig,
27}
28
29#[derive(Debug, Deserialize, Serialize, Clone)]
31pub struct ServerInfo {
32 pub name: String,
34 #[serde(default = "default_version")]
36 pub version: String,
37 #[serde(default)]
39 pub description: String,
40}
41
42#[derive(Debug, Deserialize, Serialize, Clone)]
44pub struct LoggingConfig {
45 #[serde(default = "default_log_level")]
47 pub level: String,
48 #[serde(default)]
50 pub access_log: Option<String>,
51 #[serde(default)]
53 pub error_log: Option<String>,
54 #[serde(default = "default_log_format")]
56 pub format: String,
57 #[serde(default)]
59 pub log_requests: bool,
60}
61
62#[derive(Debug, Deserialize, Serialize, Clone)]
64pub struct PerformanceConfig {
65 #[serde(default = "default_worker_threads")]
67 pub worker_threads: usize,
68 #[serde(default = "default_max_connections")]
70 pub max_connections: usize,
71 #[serde(default = "default_keep_alive_timeout")]
73 pub keep_alive_timeout: u64,
74 #[serde(default = "default_request_timeout")]
76 pub request_timeout: u64,
77 #[serde(default = "default_buffer_size")]
79 pub read_buffer_size: String,
80 #[serde(default = "default_buffer_size")]
82 pub write_buffer_size: String,
83}
84
85#[derive(Debug, Deserialize, Serialize, Clone)]
87pub struct SecurityConfig {
88 #[serde(default)]
90 pub hide_server_header: bool,
91 #[serde(default = "default_max_request_size")]
93 pub max_request_size: String,
94 #[serde(default)]
96 pub allowed_origins: Vec<String>,
97 #[serde(default)]
99 pub security_headers: HashMap<String, String>,
100 #[serde(default)]
102 pub rate_limiting: Option<RateLimitConfig>,
103}
104
105#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
107pub struct RateLimitConfig {
108 pub requests_per_minute: u32,
110 pub burst_size: u32,
112 #[serde(default)]
114 pub whitelist: Vec<String>,
115}
116
117#[derive(Debug, Deserialize, Serialize, Clone)]
119pub struct ManagementConfig {
120 #[serde(default = "default_management_enabled")]
122 pub enabled: bool,
123 #[serde(default = "default_management_host")]
125 pub host: String,
126 #[serde(default = "default_management_port")]
128 pub port: u16,
129 #[serde(default)]
131 pub api_key: Option<String>,
132}
133
134fn default_version() -> String {
136 env!("CARGO_PKG_VERSION").to_string()
137}
138
/// Fallback for `LoggingConfig::level`.
fn default_log_level() -> String {
    String::from("info")
}

/// Fallback for `LoggingConfig::format`.
fn default_log_format() -> String {
    String::from("combined")
}
146
147fn default_worker_threads() -> usize {
148 num_cpus::get().max(1)
149}
150
/// Fallback for `PerformanceConfig::max_connections`.
fn default_max_connections() -> usize {
    1_000
}

/// Fallback for `PerformanceConfig::keep_alive_timeout` (unit presumably seconds).
fn default_keep_alive_timeout() -> u64 {
    60
}

/// Fallback for `PerformanceConfig::request_timeout` (unit presumably seconds).
fn default_request_timeout() -> u64 {
    30
}

/// Fallback for both read and write buffer sizes in `PerformanceConfig`.
fn default_buffer_size() -> String {
    String::from("32KB")
}

/// Fallback for `SecurityConfig::max_request_size`.
fn default_max_request_size() -> String {
    String::from("10MB")
}
170
/// Fallback for `ManagementConfig::enabled`: the endpoint is off unless requested.
fn default_management_enabled() -> bool {
    false
}

/// Fallback for `ManagementConfig::host`: the IPv4 loopback address.
fn default_management_host() -> String {
    String::from("127.0.0.1")
}

/// Fallback for `ManagementConfig::port`.
fn default_management_port() -> u16 {
    7_654
}
182
183impl Default for LoggingConfig {
184 fn default() -> Self {
185 Self {
186 level: default_log_level(),
187 access_log: None,
188 error_log: None,
189 format: default_log_format(),
190 log_requests: true,
191 }
192 }
193}
194
195impl Default for PerformanceConfig {
196 fn default() -> Self {
197 Self {
198 worker_threads: default_worker_threads(),
199 max_connections: default_max_connections(),
200 keep_alive_timeout: default_keep_alive_timeout(),
201 request_timeout: default_request_timeout(),
202 read_buffer_size: default_buffer_size(),
203 write_buffer_size: default_buffer_size(),
204 }
205 }
206}
207
208impl Default for SecurityConfig {
209 fn default() -> Self {
210 let mut security_headers = HashMap::new();
211 security_headers.insert("X-Frame-Options".to_string(), "DENY".to_string());
212 security_headers.insert("X-Content-Type-Options".to_string(), "nosniff".to_string());
213 security_headers.insert("X-XSS-Protection".to_string(), "1; mode=block".to_string());
214 security_headers.insert(
215 "Referrer-Policy".to_string(),
216 "strict-origin-when-cross-origin".to_string(),
217 );
218
219 Self {
220 hide_server_header: false,
221 max_request_size: default_max_request_size(),
222 allowed_origins: vec![],
223 security_headers,
224 rate_limiting: None,
225 }
226 }
227}
228
229impl Default for ManagementConfig {
230 fn default() -> Self {
231 Self {
232 enabled: default_management_enabled(),
233 host: default_management_host(),
234 port: default_management_port(),
235 api_key: None,
236 }
237 }
238}
239
240impl ServerConfig {
241 pub fn load_from_file(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
242 let content = fs::read_to_string(path)?;
243 let mut config: ServerConfig = toml::from_str(&content)?;
244
245 config.post_process()?;
247
248 config.validate()?;
250
251 Ok(config)
252 }
253
254 pub fn save_to_file(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
255 let content = toml::to_string_pretty(self)?;
256
257 if let Some(parent) = Path::new(path).parent() {
259 fs::create_dir_all(parent)?;
260 }
261
262 fs::write(path, content)?;
263 log::info!("Configuration saved to {}", path);
264 Ok(())
265 }
266
267 pub fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
268 if self.server.name.is_empty() {
270 return Err("Server name cannot be empty".into());
271 }
272
273 if self.sites.is_empty() {
275 return Err("At least one site must be configured".into());
276 }
277
278 let mut default_sites = 0;
279 let mut used_hostname_ports = std::collections::HashSet::new();
280
281 for (i, site) in self.sites.iter().enumerate() {
282 site.validate().map_err(|e| format!("Site {}: {}", i, e))?;
283
284 if site.default {
285 default_sites += 1;
286 }
287
288 for hostname in site.get_all_hostnames() {
291 let hostname_port_key = (hostname, site.port);
292 if used_hostname_ports.contains(&hostname_port_key) {
293 return Err(format!(
294 "Duplicate hostname:port combination: {}:{}. Each hostname must be unique per port.",
295 hostname, site.port
296 )
297 .into());
298 }
299 used_hostname_ports.insert(hostname_port_key);
300 }
301 }
302
303 if default_sites == 0 {
304 if self.sites.len() == 1 {
306 return Err(
308 "Internal error: single site was not marked as default during post-processing"
309 .into(),
310 );
311 } else {
312 return Err("At least one site must be marked as default when multiple sites are configured".into());
313 }
314 }
315 if default_sites > 1 {
316 return Err("Only one site can be marked as default".into());
317 }
318
319 for site in &self.sites {
321 if site.ssl.enabled {
322 if site.ssl.auto_cert {
323 if let Some(acme) = &site.ssl.acme {
324 if acme.email.is_empty() {
325 return Err(format!(
326 "Site '{}': ACME email is required when auto_cert is enabled",
327 site.name
328 )
329 .into());
330 }
331 } else {
332 return Err(format!(
333 "Site '{}': ACME configuration is required when auto_cert is enabled",
334 site.name
335 )
336 .into());
337 }
338 } else if site.ssl.cert_file.is_none() || site.ssl.key_file.is_none() {
339 return Err(format!(
340 "Site '{}': Manual SSL requires both cert_file and key_file",
341 site.name
342 )
343 .into());
344 }
345 }
346 }
347
348 self.performance.validate()?;
350
351 self.security.validate()?;
353
354 Ok(())
355 }
356
357 fn post_process(&mut self) -> Result<(), Box<dyn std::error::Error>> {
358 if self.sites.len() == 1 && !self.sites[0].default {
361 log::info!(
362 "Automatically setting single site '{}' as default",
363 self.sites[0].name
364 );
365 self.sites[0].default = true;
366 }
367
368 self.sites.sort_by(|a, b| match (a.default, b.default) {
370 (true, false) => std::cmp::Ordering::Less,
371 (false, true) => std::cmp::Ordering::Greater,
372 _ => a.name.cmp(&b.name),
373 });
374
375 Ok(())
379 }
380
381 pub fn find_site_by_host_port(&self, host: &str, port: u16) -> Option<&SiteConfig> {
382 for site in &self.sites {
384 if site.handles_hostname_port(host, port) {
385 return Some(site);
386 }
387 }
388
389 for site in &self.sites {
391 if site.port == port {
392 return Some(site);
393 }
394 }
395
396 self.sites.iter().find(|site| site.default)
398 }
399
400 pub fn get_ssl_domains(&self) -> Vec<String> {
401 self.sites
402 .iter()
403 .filter_map(|site| {
404 if site.ssl.enabled {
405 Some(
406 site.get_all_ssl_domains()
407 .into_iter()
408 .map(|s| s.to_string())
409 .collect::<Vec<_>>(),
410 )
411 } else {
412 None
413 }
414 })
415 .flatten()
416 .collect()
417 }
418
419 pub fn get_site_by_domain(&self, domain: &str) -> Option<&SiteConfig> {
420 self.sites.iter().find(|site| {
421 site.handles_hostname(domain) || site.ssl.domains.contains(&domain.to_string())
422 })
423 }
424
425 pub fn reload_from_file(&mut self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
426 let new_config = Self::load_from_file(path)?;
427 *self = new_config;
428 log::info!("Configuration reloaded from {}", path);
429 Ok(())
430 }
431}
432
433impl PerformanceConfig {
434 fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
435 if self.worker_threads == 0 {
436 return Err("Worker threads must be greater than 0".into());
437 }
438
439 if self.max_connections == 0 {
440 return Err("Max connections must be greater than 0".into());
441 }
442
443 if self.keep_alive_timeout == 0 {
444 return Err("Keep alive timeout must be greater than 0".into());
445 }
446
447 if self.request_timeout == 0 {
448 return Err("Request timeout must be greater than 0".into());
449 }
450
451 self.parse_buffer_size(&self.read_buffer_size)
453 .map_err(|_| "Invalid read buffer size format")?;
454 self.parse_buffer_size(&self.write_buffer_size)
455 .map_err(|_| "Invalid write buffer size format")?;
456
457 Ok(())
458 }
459
460 pub fn parse_buffer_size(&self, size_str: &str) -> Result<usize, Box<dyn std::error::Error>> {
461 let size_str = size_str.trim().to_uppercase();
462
463 if let Some(value) = size_str.strip_suffix("KB") {
464 Ok(value.parse::<usize>()? * 1024)
465 } else if let Some(value) = size_str.strip_suffix("MB") {
466 Ok(value.parse::<usize>()? * 1024 * 1024)
467 } else if let Some(value) = size_str.strip_suffix("GB") {
468 Ok(value.parse::<usize>()? * 1024 * 1024 * 1024)
469 } else if let Some(value) = size_str.strip_suffix("B") {
470 Ok(value.parse::<usize>()?)
471 } else {
472 Ok(size_str.parse::<usize>()?)
474 }
475 }
476}
477
478impl SecurityConfig {
479 fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
480 self.parse_size(&self.max_request_size)
482 .map_err(|_| "Invalid max request size format")?;
483
484 if let Some(rate_limit) = &self.rate_limiting {
486 if rate_limit.requests_per_minute == 0 {
487 return Err("Rate limit requests per minute must be greater than 0".into());
488 }
489 if rate_limit.burst_size == 0 {
490 return Err("Rate limit burst size must be greater than 0".into());
491 }
492 }
493
494 Ok(())
495 }
496
497 pub fn parse_size(&self, size_str: &str) -> Result<usize, Box<dyn std::error::Error>> {
498 let size_str = size_str.trim().to_uppercase();
499
500 if let Some(value) = size_str.strip_suffix("KB") {
501 Ok(value.parse::<usize>()? * 1024)
502 } else if let Some(value) = size_str.strip_suffix("MB") {
503 Ok(value.parse::<usize>()? * 1024 * 1024)
504 } else if let Some(value) = size_str.strip_suffix("GB") {
505 Ok(value.parse::<usize>()? * 1024 * 1024 * 1024)
506 } else if let Some(value) = size_str.strip_suffix("B") {
507 Ok(value.parse::<usize>()?)
508 } else {
509 Ok(size_str.parse::<usize>()?)
511 }
512 }
513}
514
#[cfg(test)]
mod tests {
    //! Unit tests for size parsing, default-site post-processing, security
    //! validation and the TOML save/load round trip.

    use super::*;
    use tempfile::NamedTempFile;

    /// Builds a site whose optional sections are all left at their defaults.
    fn make_site(
        name: &str,
        hostname: &str,
        port: u16,
        static_dir: &str,
        default: bool,
    ) -> SiteConfig {
        SiteConfig {
            name: name.to_string(),
            hostname: hostname.to_string(),
            hostnames: vec![],
            port,
            static_dir: static_dir.to_string(),
            default,
            api_only: false,
            headers: HashMap::new(),
            redirect_to_https: false,
            index_files: vec!["index.html".to_string()],
            error_pages: HashMap::new(),
            compression: Default::default(),
            cache: Default::default(),
            access_control: Default::default(),
            ssl: Default::default(),
            proxy: Default::default(),
        }
    }

    /// Wraps `sites` in a server config whose other sections use their defaults.
    fn make_config(sites: Vec<SiteConfig>) -> ServerConfig {
        ServerConfig {
            server: ServerInfo {
                name: "test-server".to_string(),
                version: "1.0.0".to_string(),
                description: "Test server".to_string(),
            },
            sites,
            logging: LoggingConfig::default(),
            performance: PerformanceConfig::default(),
            security: SecurityConfig::default(),
            management: ManagementConfig::default(),
        }
    }

    #[test]
    fn test_buffer_size_parsing() {
        let config = PerformanceConfig::default();

        assert_eq!(config.parse_buffer_size("1024").unwrap(), 1024);
        assert_eq!(config.parse_buffer_size("1KB").unwrap(), 1024);
        assert_eq!(config.parse_buffer_size("1MB").unwrap(), 1024 * 1024);
        assert_eq!(config.parse_buffer_size("1GB").unwrap(), 1024 * 1024 * 1024);
        assert_eq!(config.parse_buffer_size("500B").unwrap(), 500);

        assert!(config.parse_buffer_size("invalid").is_err());
        assert!(config.parse_buffer_size("").is_err());
    }

    #[test]
    fn test_automatic_default_site() {
        let mut config = make_config(vec![make_site(
            "single-site",
            "localhost",
            8080,
            "/tmp/static",
            false,
        )]);
        assert!(!config.sites[0].default);

        // A lone site is promoted to default and the config validates.
        config.post_process().unwrap();
        assert!(config.sites[0].default);
        assert!(config.validate().is_ok());

        // With two sites and no default, nothing is promoted and validation fails.
        config.sites.push(make_site(
            "second-site",
            "example.com",
            8081,
            "/tmp/static2",
            false,
        ));
        config.sites[0].default = false;

        config.post_process().unwrap();
        assert!(!config.sites[0].default);
        assert!(!config.sites[1].default);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_security_config_validation() {
        let mut config = SecurityConfig::default();
        assert!(config.validate().is_ok());

        config.max_request_size = "invalid".to_string();
        assert!(config.validate().is_err());

        config.max_request_size = "10MB".to_string();
        config.rate_limiting = Some(RateLimitConfig {
            requests_per_minute: 0,
            burst_size: 10,
            whitelist: vec![],
        });
        assert!(config.validate().is_err());
    }

    // Plain `#[test]`: nothing here is async, so no runtime is needed.
    #[test]
    fn test_config_save_load() {
        let config = make_config(vec![make_site(
            "test-site",
            "localhost",
            8080,
            "/tmp/static",
            true,
        )]);

        let temp_file = NamedTempFile::new().unwrap();
        let path = temp_file.path().to_str().unwrap();

        config.save_to_file(path).unwrap();

        let loaded_config = ServerConfig::load_from_file(path).unwrap();
        assert_eq!(config.server.name, loaded_config.server.name);
        assert_eq!(config.sites.len(), loaded_config.sites.len());
    }
}