1use crate::config::SiteConfig;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::fs;
5use std::path::Path;
6
/// Root of the server configuration as read by `ServerConfig::load_from_file`
/// and written by `ServerConfig::save_to_file`.
///
/// `server` and `sites` carry no serde default, so they must appear in the
/// input. The remaining sections fall back to their `Default` impls when the
/// whole section is omitted.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ServerConfig {
    /// Identity of this server instance; `validate` rejects an empty `name`.
    pub server: ServerInfo,
    /// Configured sites. `validate` requires at least one site and exactly
    /// one site flagged as default.
    pub sites: Vec<SiteConfig>,
    /// Logging settings.
    #[serde(default)]
    pub logging: LoggingConfig,
    /// Tuning settings; checked by `PerformanceConfig::validate`.
    #[serde(default)]
    pub performance: PerformanceConfig,
    /// Security settings; checked by `SecurityConfig::validate`.
    #[serde(default)]
    pub security: SecurityConfig,
    /// Management endpoint settings.
    #[serde(default)]
    pub management: ManagementConfig,
}
20
/// Descriptive metadata about the server instance.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ServerInfo {
    /// Server name. Required in the input; `ServerConfig::validate` rejects
    /// an empty string.
    pub name: String,
    /// Version string. When omitted it is filled from the crate's
    /// `CARGO_PKG_VERSION`, captured at compile time.
    #[serde(default = "default_version")]
    pub version: String,
    /// Free-form description; an empty string when omitted.
    #[serde(default)]
    pub description: String,
}
29
30#[derive(Debug, Deserialize, Serialize, Clone)]
31pub struct LoggingConfig {
32 #[serde(default = "default_log_level")]
33 pub level: String,
34 #[serde(default)]
35 pub access_log: Option<String>,
36 #[serde(default)]
37 pub error_log: Option<String>,
38 #[serde(default = "default_log_format")]
39 pub format: String,
40 #[serde(default)]
41 pub log_requests: bool,
42}
43
/// Tuning settings.
///
/// Every field has a serde default. `PerformanceConfig::validate` rejects a
/// zero value for any numeric field and an unparsable buffer size.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct PerformanceConfig {
    /// Worker thread count; defaults to `num_cpus::get()`, floored at 1.
    #[serde(default = "default_worker_threads")]
    pub worker_threads: usize,
    /// Connection limit; defaults to 1000.
    #[serde(default = "default_max_connections")]
    pub max_connections: usize,
    /// Keep-alive timeout; defaults to 60.
    // NOTE(review): unit is presumably seconds — confirm at the point of use.
    #[serde(default = "default_keep_alive_timeout")]
    pub keep_alive_timeout: u64,
    /// Request timeout; defaults to 30.
    // NOTE(review): unit is presumably seconds — confirm at the point of use.
    #[serde(default = "default_request_timeout")]
    pub request_timeout: u64,
    /// Read buffer size as text (e.g. `"32KB"`, the default), in the format
    /// accepted by `parse_buffer_size`.
    #[serde(default = "default_buffer_size")]
    pub read_buffer_size: String,
    /// Write buffer size as text (e.g. `"32KB"`, the default), in the format
    /// accepted by `parse_buffer_size`.
    #[serde(default = "default_buffer_size")]
    pub write_buffer_size: String,
}
59
60#[derive(Debug, Deserialize, Serialize, Clone)]
61pub struct SecurityConfig {
62 #[serde(default)]
63 pub hide_server_header: bool,
64 #[serde(default = "default_max_request_size")]
65 pub max_request_size: String,
66 #[serde(default)]
67 pub allowed_origins: Vec<String>,
68 #[serde(default)]
69 pub security_headers: HashMap<String, String>,
70 #[serde(default)]
71 pub rate_limiting: Option<RateLimitConfig>,
72}
73
/// Rate-limiting parameters, held in `SecurityConfig::rate_limiting`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct RateLimitConfig {
    /// Required; `SecurityConfig::validate` rejects 0.
    pub requests_per_minute: u32,
    /// Required; `SecurityConfig::validate` rejects 0.
    pub burst_size: u32,
    /// Empty when omitted.
    // NOTE(review): presumably clients exempt from limiting — confirm the entry format at the point of use.
    #[serde(default)]
    pub whitelist: Vec<String>,
}
81
/// Management endpoint settings. Every field has a serde default.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ManagementConfig {
    /// `true` when omitted.
    #[serde(default = "default_management_enabled")]
    pub enabled: bool,
    /// Host string; `"127.0.0.1"` when omitted.
    #[serde(default = "default_management_host")]
    pub host: String,
    /// Port number; `7654` when omitted.
    #[serde(default = "default_management_port")]
    pub port: u16,
    /// Optional API key; `None` when omitted.
    // NOTE(review): how a missing key is treated is not visible here — confirm in the management handler.
    #[serde(default)]
    pub api_key: Option<String>,
}
93
94fn default_version() -> String {
96 env!("CARGO_PKG_VERSION").to_string()
97}
98
/// Serde default for `LoggingConfig::level`.
fn default_log_level() -> String {
    String::from("info")
}
102
/// Serde default for `LoggingConfig::format`.
fn default_log_format() -> String {
    String::from("combined")
}
106
107fn default_worker_threads() -> usize {
108 num_cpus::get().max(1)
109}
110
/// Serde default for `PerformanceConfig::max_connections`.
fn default_max_connections() -> usize {
    const DEFAULT: usize = 1_000;
    DEFAULT
}
114
/// Serde default for `PerformanceConfig::keep_alive_timeout`.
fn default_keep_alive_timeout() -> u64 {
    const DEFAULT: u64 = 60;
    DEFAULT
}
118
/// Serde default for `PerformanceConfig::request_timeout`.
fn default_request_timeout() -> u64 {
    const DEFAULT: u64 = 30;
    DEFAULT
}
122
/// Serde default shared by `PerformanceConfig::read_buffer_size` and
/// `PerformanceConfig::write_buffer_size`.
fn default_buffer_size() -> String {
    String::from("32KB")
}
126
/// Serde default for `SecurityConfig::max_request_size`.
fn default_max_request_size() -> String {
    String::from("10MB")
}
130
/// Serde default for `ManagementConfig::enabled`.
fn default_management_enabled() -> bool {
    const ENABLED_BY_DEFAULT: bool = true;
    ENABLED_BY_DEFAULT
}
134
/// Serde default for `ManagementConfig::host`.
fn default_management_host() -> String {
    String::from("127.0.0.1")
}
138
/// Serde default for `ManagementConfig::port`.
fn default_management_port() -> u16 {
    const DEFAULT: u16 = 7654;
    DEFAULT
}
142
143impl Default for LoggingConfig {
144 fn default() -> Self {
145 Self {
146 level: default_log_level(),
147 access_log: None,
148 error_log: None,
149 format: default_log_format(),
150 log_requests: true,
151 }
152 }
153}
154
155impl Default for PerformanceConfig {
156 fn default() -> Self {
157 Self {
158 worker_threads: default_worker_threads(),
159 max_connections: default_max_connections(),
160 keep_alive_timeout: default_keep_alive_timeout(),
161 request_timeout: default_request_timeout(),
162 read_buffer_size: default_buffer_size(),
163 write_buffer_size: default_buffer_size(),
164 }
165 }
166}
167
168impl Default for SecurityConfig {
169 fn default() -> Self {
170 let mut security_headers = HashMap::new();
171 security_headers.insert("X-Frame-Options".to_string(), "DENY".to_string());
172 security_headers.insert("X-Content-Type-Options".to_string(), "nosniff".to_string());
173 security_headers.insert("X-XSS-Protection".to_string(), "1; mode=block".to_string());
174 security_headers.insert(
175 "Referrer-Policy".to_string(),
176 "strict-origin-when-cross-origin".to_string(),
177 );
178
179 Self {
180 hide_server_header: false,
181 max_request_size: default_max_request_size(),
182 allowed_origins: vec![],
183 security_headers,
184 rate_limiting: None,
185 }
186 }
187}
188
189impl Default for ManagementConfig {
190 fn default() -> Self {
191 Self {
192 enabled: default_management_enabled(),
193 host: default_management_host(),
194 port: default_management_port(),
195 api_key: None,
196 }
197 }
198}
199
200impl ServerConfig {
201 pub fn load_from_file(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
202 let content = fs::read_to_string(path)?;
203 let mut config: ServerConfig = toml::from_str(&content)?;
204
205 config.post_process()?;
207
208 config.validate()?;
210
211 Ok(config)
212 }
213
214 pub fn save_to_file(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
215 let content = toml::to_string_pretty(self)?;
216
217 if let Some(parent) = Path::new(path).parent() {
219 fs::create_dir_all(parent)?;
220 }
221
222 fs::write(path, content)?;
223 log::info!("Configuration saved to {}", path);
224 Ok(())
225 }
226
227 pub fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
228 if self.server.name.is_empty() {
230 return Err("Server name cannot be empty".into());
231 }
232
233 if self.sites.is_empty() {
235 return Err("At least one site must be configured".into());
236 }
237
238 let mut default_sites = 0;
239 let mut used_hostname_ports = std::collections::HashSet::new();
240
241 for (i, site) in self.sites.iter().enumerate() {
242 site.validate().map_err(|e| format!("Site {}: {}", i, e))?;
243
244 if site.default {
245 default_sites += 1;
246 }
247
248 for hostname in site.get_all_hostnames() {
251 let hostname_port_key = (hostname, site.port);
252 if used_hostname_ports.contains(&hostname_port_key) {
253 return Err(format!(
254 "Duplicate hostname:port combination: {}:{}. Each hostname must be unique per port.",
255 hostname, site.port
256 )
257 .into());
258 }
259 used_hostname_ports.insert(hostname_port_key);
260 }
261 }
262
263 if default_sites == 0 {
264 if self.sites.len() == 1 {
266 return Err(
268 "Internal error: single site was not marked as default during post-processing"
269 .into(),
270 );
271 } else {
272 return Err("At least one site must be marked as default when multiple sites are configured".into());
273 }
274 }
275 if default_sites > 1 {
276 return Err("Only one site can be marked as default".into());
277 }
278
279 for site in &self.sites {
281 if site.ssl.enabled {
282 if site.ssl.auto_cert {
283 if let Some(acme) = &site.ssl.acme {
284 if acme.email.is_empty() {
285 return Err(format!(
286 "Site '{}': ACME email is required when auto_cert is enabled",
287 site.name
288 )
289 .into());
290 }
291 } else {
292 return Err(format!(
293 "Site '{}': ACME configuration is required when auto_cert is enabled",
294 site.name
295 )
296 .into());
297 }
298 } else if site.ssl.cert_file.is_none() || site.ssl.key_file.is_none() {
299 return Err(format!(
300 "Site '{}': Manual SSL requires both cert_file and key_file",
301 site.name
302 )
303 .into());
304 }
305 }
306 }
307
308 self.performance.validate()?;
310
311 self.security.validate()?;
313
314 Ok(())
315 }
316
317 fn post_process(&mut self) -> Result<(), Box<dyn std::error::Error>> {
318 if self.sites.len() == 1 && !self.sites[0].default {
321 log::info!(
322 "Automatically setting single site '{}' as default",
323 self.sites[0].name
324 );
325 self.sites[0].default = true;
326 }
327
328 self.sites.sort_by(|a, b| match (a.default, b.default) {
330 (true, false) => std::cmp::Ordering::Less,
331 (false, true) => std::cmp::Ordering::Greater,
332 _ => a.name.cmp(&b.name),
333 });
334
335 Ok(())
339 }
340
341 pub fn find_site_by_host_port(&self, host: &str, port: u16) -> Option<&SiteConfig> {
342 for site in &self.sites {
344 if site.handles_hostname_port(host, port) {
345 return Some(site);
346 }
347 }
348
349 for site in &self.sites {
351 if site.port == port {
352 return Some(site);
353 }
354 }
355
356 self.sites.iter().find(|site| site.default)
358 }
359
360 pub fn get_ssl_domains(&self) -> Vec<String> {
361 self.sites
362 .iter()
363 .filter_map(|site| {
364 if site.ssl.enabled {
365 Some(
366 site.get_all_ssl_domains()
367 .into_iter()
368 .map(|s| s.to_string())
369 .collect::<Vec<_>>(),
370 )
371 } else {
372 None
373 }
374 })
375 .flatten()
376 .collect()
377 }
378
379 pub fn get_site_by_domain(&self, domain: &str) -> Option<&SiteConfig> {
380 self.sites.iter().find(|site| {
381 site.handles_hostname(domain) || site.ssl.domains.contains(&domain.to_string())
382 })
383 }
384
385 pub fn reload_from_file(&mut self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
386 let new_config = Self::load_from_file(path)?;
387 *self = new_config;
388 log::info!("Configuration reloaded from {}", path);
389 Ok(())
390 }
391}
392
393impl PerformanceConfig {
394 fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
395 if self.worker_threads == 0 {
396 return Err("Worker threads must be greater than 0".into());
397 }
398
399 if self.max_connections == 0 {
400 return Err("Max connections must be greater than 0".into());
401 }
402
403 if self.keep_alive_timeout == 0 {
404 return Err("Keep alive timeout must be greater than 0".into());
405 }
406
407 if self.request_timeout == 0 {
408 return Err("Request timeout must be greater than 0".into());
409 }
410
411 self.parse_buffer_size(&self.read_buffer_size)
413 .map_err(|_| "Invalid read buffer size format")?;
414 self.parse_buffer_size(&self.write_buffer_size)
415 .map_err(|_| "Invalid write buffer size format")?;
416
417 Ok(())
418 }
419
420 pub fn parse_buffer_size(&self, size_str: &str) -> Result<usize, Box<dyn std::error::Error>> {
421 let size_str = size_str.trim().to_uppercase();
422
423 if let Some(value) = size_str.strip_suffix("KB") {
424 Ok(value.parse::<usize>()? * 1024)
425 } else if let Some(value) = size_str.strip_suffix("MB") {
426 Ok(value.parse::<usize>()? * 1024 * 1024)
427 } else if let Some(value) = size_str.strip_suffix("GB") {
428 Ok(value.parse::<usize>()? * 1024 * 1024 * 1024)
429 } else if let Some(value) = size_str.strip_suffix("B") {
430 Ok(value.parse::<usize>()?)
431 } else {
432 Ok(size_str.parse::<usize>()?)
434 }
435 }
436}
437
438impl SecurityConfig {
439 fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
440 self.parse_size(&self.max_request_size)
442 .map_err(|_| "Invalid max request size format")?;
443
444 if let Some(rate_limit) = &self.rate_limiting {
446 if rate_limit.requests_per_minute == 0 {
447 return Err("Rate limit requests per minute must be greater than 0".into());
448 }
449 if rate_limit.burst_size == 0 {
450 return Err("Rate limit burst size must be greater than 0".into());
451 }
452 }
453
454 Ok(())
455 }
456
457 pub fn parse_size(&self, size_str: &str) -> Result<usize, Box<dyn std::error::Error>> {
458 let size_str = size_str.trim().to_uppercase();
459
460 if let Some(value) = size_str.strip_suffix("KB") {
461 Ok(value.parse::<usize>()? * 1024)
462 } else if let Some(value) = size_str.strip_suffix("MB") {
463 Ok(value.parse::<usize>()? * 1024 * 1024)
464 } else if let Some(value) = size_str.strip_suffix("GB") {
465 Ok(value.parse::<usize>()? * 1024 * 1024 * 1024)
466 } else if let Some(value) = size_str.strip_suffix("B") {
467 Ok(value.parse::<usize>()?)
468 } else {
469 Ok(size_str.parse::<usize>()?)
471 }
472 }
473}
474
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Builds a minimal site fixture; every field not taken as a parameter
    /// uses the same value in all tests.
    fn sample_site(
        name: &str,
        hostname: &str,
        port: u16,
        static_dir: &str,
        is_default: bool,
    ) -> SiteConfig {
        SiteConfig {
            name: name.to_string(),
            hostname: hostname.to_string(),
            hostnames: vec![],
            port,
            static_dir: static_dir.to_string(),
            default: is_default,
            api_only: false,
            headers: HashMap::new(),
            redirect_to_https: false,
            index_files: vec!["index.html".to_string()],
            error_pages: HashMap::new(),
            compression: Default::default(),
            cache: Default::default(),
            access_control: Default::default(),
            ssl: Default::default(),
            proxy: Default::default(),
        }
    }

    /// Wraps `sites` in a server configuration whose other sections are
    /// all defaults.
    fn sample_config(sites: Vec<SiteConfig>) -> ServerConfig {
        ServerConfig {
            server: ServerInfo {
                name: "test-server".to_string(),
                version: "1.0.0".to_string(),
                description: "Test server".to_string(),
            },
            sites,
            logging: LoggingConfig::default(),
            performance: PerformanceConfig::default(),
            security: SecurityConfig::default(),
            management: ManagementConfig::default(),
        }
    }

    #[test]
    fn test_buffer_size_parsing() {
        let config = PerformanceConfig::default();

        assert_eq!(config.parse_buffer_size("1024").unwrap(), 1024);
        assert_eq!(config.parse_buffer_size("1KB").unwrap(), 1024);
        assert_eq!(config.parse_buffer_size("1MB").unwrap(), 1024 * 1024);
        assert_eq!(config.parse_buffer_size("1GB").unwrap(), 1024 * 1024 * 1024);
        assert_eq!(config.parse_buffer_size("500B").unwrap(), 500);

        // Units are case-insensitive and outer whitespace is ignored.
        assert_eq!(config.parse_buffer_size(" 2kb ").unwrap(), 2048);

        assert!(config.parse_buffer_size("invalid").is_err());
        assert!(config.parse_buffer_size("").is_err());
    }

    #[test]
    fn test_request_size_parsing() {
        let config = SecurityConfig::default();

        assert_eq!(config.parse_size("10MB").unwrap(), 10 * 1024 * 1024);
        assert_eq!(config.parse_size("2048").unwrap(), 2048);
        assert!(config.parse_size("invalid").is_err());
    }

    #[test]
    fn test_automatic_default_site() {
        let mut config = sample_config(vec![sample_site(
            "single-site",
            "localhost",
            8080,
            "/tmp/static",
            false,
        )]);

        assert!(!config.sites[0].default);

        // A lone site is promoted to default by post-processing.
        config.post_process().unwrap();
        assert!(config.sites[0].default);

        assert!(config.validate().is_ok());

        config.sites.push(sample_site(
            "second-site",
            "example.com",
            8081,
            "/tmp/static2",
            false,
        ));

        config.sites[0].default = false;

        // With several sites nothing is promoted automatically.
        config.post_process().unwrap();
        assert!(!config.sites[0].default);
        assert!(!config.sites[1].default);

        assert!(config.validate().is_err());
    }

    #[test]
    fn test_security_config_validation() {
        let mut config = SecurityConfig::default();
        assert!(config.validate().is_ok());

        config.max_request_size = "invalid".to_string();
        assert!(config.validate().is_err());

        config.max_request_size = "10MB".to_string();
        config.rate_limiting = Some(RateLimitConfig {
            requests_per_minute: 0,
            burst_size: 10,
            whitelist: vec![],
        });
        assert!(config.validate().is_err());

        config.rate_limiting = Some(RateLimitConfig {
            requests_per_minute: 10,
            burst_size: 0,
            whitelist: vec![],
        });
        assert!(config.validate().is_err());

        config.rate_limiting = Some(RateLimitConfig {
            requests_per_minute: 10,
            burst_size: 5,
            whitelist: vec![],
        });
        assert!(config.validate().is_ok());
    }

    // Save and load are synchronous, so no async runtime is needed here.
    #[test]
    fn test_config_save_load() {
        let config = sample_config(vec![sample_site(
            "test-site",
            "localhost",
            8080,
            "/tmp/static",
            true,
        )]);

        let temp_file = NamedTempFile::new().unwrap();
        let path = temp_file.path().to_str().unwrap();

        config.save_to_file(path).unwrap();

        let loaded_config = ServerConfig::load_from_file(path).unwrap();
        assert_eq!(config.server.name, loaded_config.server.name);
        assert_eq!(config.sites.len(), loaded_config.sites.len());
        assert_eq!(config.sites[0].name, loaded_config.sites[0].name);
    }
}