1use crate::config::SiteConfig;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::fs;
5use std::path::Path;
6
/// Top-level server configuration, normally read from a TOML file with
/// `ServerConfig::load_from_file`.
///
/// `server` and `sites` are required. `logging`, `performance`, and `security`
/// fall back to their `Default` implementations when the section is omitted.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ServerConfig {
    /// Identity of this server instance (name, version, description).
    pub server: ServerInfo,
    /// Configured sites. `ServerConfig::validate` requires at least one, with
    /// exactly one marked `default`.
    pub sites: Vec<SiteConfig>,
    /// Logging settings; `LoggingConfig::default()` when the section is absent.
    #[serde(default)]
    pub logging: LoggingConfig,
    /// Runtime tuning; `PerformanceConfig::default()` when the section is absent.
    #[serde(default)]
    pub performance: PerformanceConfig,
    /// Security settings; `SecurityConfig::default()` when the section is absent.
    #[serde(default)]
    pub security: SecurityConfig,
}
18
/// Descriptive information about the server (the `[server]` section).
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ServerInfo {
    /// Server name. Required by serde, and `ServerConfig::validate` rejects an
    /// empty string.
    pub name: String,
    /// Version string; defaults to this crate's `CARGO_PKG_VERSION`.
    #[serde(default = "default_version")]
    pub version: String,
    /// Free-form description; defaults to an empty string.
    #[serde(default)]
    pub description: String,
}
27
28#[derive(Debug, Deserialize, Serialize, Clone)]
29pub struct LoggingConfig {
30 #[serde(default = "default_log_level")]
31 pub level: String,
32 #[serde(default)]
33 pub access_log: Option<String>,
34 #[serde(default)]
35 pub error_log: Option<String>,
36 #[serde(default = "default_log_format")]
37 pub format: String,
38 #[serde(default)]
39 pub log_requests: bool,
40}
41
/// Runtime tuning knobs (the `[performance]` section).
///
/// Every field has a serde default, so any subset may be specified.
/// `PerformanceConfig::validate` rejects zero counts/timeouts and buffer sizes
/// that `PerformanceConfig::parse_buffer_size` cannot parse.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct PerformanceConfig {
    /// Number of worker threads; defaults to the `num_cpus::get()` count (at least 1).
    #[serde(default = "default_worker_threads")]
    pub worker_threads: usize,
    /// Connection limit; defaults to 1000. Must be non-zero.
    #[serde(default = "default_max_connections")]
    pub max_connections: usize,
    /// Keep-alive timeout; defaults to 60. Must be non-zero.
    /// NOTE(review): the unit is not visible in this file (presumably seconds).
    #[serde(default = "default_keep_alive_timeout")]
    pub keep_alive_timeout: u64,
    /// Request timeout; defaults to 30. Must be non-zero.
    /// NOTE(review): the unit is not visible in this file (presumably seconds).
    #[serde(default = "default_request_timeout")]
    pub request_timeout: u64,
    /// Read buffer size as a human-readable string such as `"32KB"`
    /// (binary multiples: 1KB = 1024 bytes); defaults to `"32KB"`.
    #[serde(default = "default_buffer_size")]
    pub read_buffer_size: String,
    /// Write buffer size, same format and default as `read_buffer_size`.
    #[serde(default = "default_buffer_size")]
    pub write_buffer_size: String,
}
57
/// Security-related settings (the `[security]` section).
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct SecurityConfig {
    /// Whether to suppress the server header; defaults to `false`.
    #[serde(default)]
    pub hide_server_header: bool,
    /// Maximum request size as a human-readable string such as `"10MB"`;
    /// defaults to `"10MB"` and is parsed by `SecurityConfig::parse_size`.
    #[serde(default = "default_max_request_size")]
    pub max_request_size: String,
    /// Allowed origins; defaults to empty.
    /// NOTE(review): presumably CORS origins — how they are matched is not visible here.
    #[serde(default)]
    pub allowed_origins: Vec<String>,
    /// Extra response headers (name -> value).
    ///
    /// NOTE(review): with `#[serde(default)]` this is an EMPTY map whenever a
    /// `[security]` section is present without `security_headers`. Only a fully
    /// omitted section gets the four headers inserted by
    /// `SecurityConfig::default()`. Confirm this asymmetry is intended.
    #[serde(default)]
    pub security_headers: HashMap<String, String>,
    /// Optional rate limiting; `None` disables it.
    #[serde(default)]
    pub rate_limiting: Option<RateLimitConfig>,
}
71
/// Rate-limiting parameters. Both limits are required by serde, and
/// `SecurityConfig::validate` rejects zero for either.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct RateLimitConfig {
    /// Sustained request allowance per minute.
    pub requests_per_minute: u32,
    /// Burst allowance.
    /// NOTE(review): exact semantics depend on the limiter implementation (not in this file).
    pub burst_size: u32,
    /// Entries exempt from rate limiting; defaults to empty.
    /// NOTE(review): presumably client IPs/CIDRs — confirm against the limiter.
    #[serde(default)]
    pub whitelist: Vec<String>,
}
79
80fn default_version() -> String {
82 env!("CARGO_PKG_VERSION").to_string()
83}
84
/// Serde default for `LoggingConfig::level`.
fn default_log_level() -> String {
    String::from("info")
}
88
/// Serde default for `LoggingConfig::format`.
fn default_log_format() -> String {
    "combined".to_owned()
}
92
93fn default_worker_threads() -> usize {
94 num_cpus::get().max(1)
95}
96
/// Serde default for `PerformanceConfig::max_connections`.
fn default_max_connections() -> usize {
    1_000
}
100
/// Serde default for `PerformanceConfig::keep_alive_timeout`.
/// NOTE(review): unit not visible in this file (presumably seconds).
fn default_keep_alive_timeout() -> u64 {
    60_u64
}
104
/// Serde default for `PerformanceConfig::request_timeout`.
/// NOTE(review): unit not visible in this file (presumably seconds).
fn default_request_timeout() -> u64 {
    30_u64
}
108
/// Serde default shared by the read and write buffer sizes of
/// `PerformanceConfig` (32 KiB, in the string form `parse_buffer_size` accepts).
fn default_buffer_size() -> String {
    String::from("32KB")
}
112
/// Serde default for `SecurityConfig::max_request_size` (10 MiB, in the string
/// form `SecurityConfig::parse_size` accepts).
fn default_max_request_size() -> String {
    "10MB".to_owned()
}
116
117impl Default for LoggingConfig {
118 fn default() -> Self {
119 Self {
120 level: default_log_level(),
121 access_log: None,
122 error_log: None,
123 format: default_log_format(),
124 log_requests: true,
125 }
126 }
127}
128
129impl Default for PerformanceConfig {
130 fn default() -> Self {
131 Self {
132 worker_threads: default_worker_threads(),
133 max_connections: default_max_connections(),
134 keep_alive_timeout: default_keep_alive_timeout(),
135 request_timeout: default_request_timeout(),
136 read_buffer_size: default_buffer_size(),
137 write_buffer_size: default_buffer_size(),
138 }
139 }
140}
141
142impl Default for SecurityConfig {
143 fn default() -> Self {
144 let mut security_headers = HashMap::new();
145 security_headers.insert("X-Frame-Options".to_string(), "DENY".to_string());
146 security_headers.insert("X-Content-Type-Options".to_string(), "nosniff".to_string());
147 security_headers.insert("X-XSS-Protection".to_string(), "1; mode=block".to_string());
148 security_headers.insert(
149 "Referrer-Policy".to_string(),
150 "strict-origin-when-cross-origin".to_string(),
151 );
152
153 Self {
154 hide_server_header: false,
155 max_request_size: default_max_request_size(),
156 allowed_origins: vec![],
157 security_headers,
158 rate_limiting: None,
159 }
160 }
161}
162
163impl ServerConfig {
164 pub fn load_from_file(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
165 let content = fs::read_to_string(path)?;
166 let mut config: ServerConfig = toml::from_str(&content)?;
167
168 config.validate()?;
170
171 config.post_process()?;
173
174 Ok(config)
175 }
176
177 pub fn save_to_file(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
178 let content = toml::to_string_pretty(self)?;
179
180 if let Some(parent) = Path::new(path).parent() {
182 fs::create_dir_all(parent)?;
183 }
184
185 fs::write(path, content)?;
186 log::info!("Configuration saved to {}", path);
187 Ok(())
188 }
189
190 pub fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
191 if self.server.name.is_empty() {
193 return Err("Server name cannot be empty".into());
194 }
195
196 if self.sites.is_empty() {
198 return Err("At least one site must be configured".into());
199 }
200
201 let mut default_sites = 0;
202 let mut used_hostname_ports = std::collections::HashSet::new();
203
204 for (i, site) in self.sites.iter().enumerate() {
205 site.validate().map_err(|e| format!("Site {}: {}", i, e))?;
206
207 if site.default {
208 default_sites += 1;
209 }
210
211 for hostname in site.get_all_hostnames() {
214 let hostname_port_key = (hostname, site.port);
215 if used_hostname_ports.contains(&hostname_port_key) {
216 return Err(format!(
217 "Duplicate hostname:port combination: {}:{}. Each hostname must be unique per port.",
218 hostname, site.port
219 )
220 .into());
221 }
222 used_hostname_ports.insert(hostname_port_key);
223 }
224 }
225
226 if default_sites == 0 {
227 return Err("At least one site must be marked as default".into());
228 }
229 if default_sites > 1 {
230 return Err("Only one site can be marked as default".into());
231 }
232
233 for site in &self.sites {
235 if site.ssl.enabled {
236 if site.ssl.auto_cert {
237 if let Some(acme) = &site.ssl.acme {
238 if acme.email.is_empty() {
239 return Err(format!(
240 "Site '{}': ACME email is required when auto_cert is enabled",
241 site.name
242 )
243 .into());
244 }
245 } else {
246 return Err(format!(
247 "Site '{}': ACME configuration is required when auto_cert is enabled",
248 site.name
249 )
250 .into());
251 }
252 } else if site.ssl.cert_file.is_none() || site.ssl.key_file.is_none() {
253 return Err(format!(
254 "Site '{}': Manual SSL requires both cert_file and key_file",
255 site.name
256 )
257 .into());
258 }
259 }
260 }
261
262 self.performance.validate()?;
264
265 self.security.validate()?;
267
268 Ok(())
269 }
270
271 fn post_process(&mut self) -> Result<(), Box<dyn std::error::Error>> {
272 self.sites.sort_by(|a, b| match (a.default, b.default) {
274 (true, false) => std::cmp::Ordering::Less,
275 (false, true) => std::cmp::Ordering::Greater,
276 _ => a.name.cmp(&b.name),
277 });
278
279 Ok(())
283 }
284
285 pub fn find_site_by_host_port(&self, host: &str, port: u16) -> Option<&SiteConfig> {
286 for site in &self.sites {
288 if site.handles_hostname_port(host, port) {
289 return Some(site);
290 }
291 }
292
293 for site in &self.sites {
295 if site.port == port {
296 return Some(site);
297 }
298 }
299
300 self.sites.iter().find(|site| site.default)
302 }
303
304 pub fn get_ssl_domains(&self) -> Vec<String> {
305 self.sites
306 .iter()
307 .filter_map(|site| {
308 if site.ssl.enabled {
309 Some(
310 site.get_all_ssl_domains()
311 .into_iter()
312 .map(|s| s.to_string())
313 .collect::<Vec<_>>(),
314 )
315 } else {
316 None
317 }
318 })
319 .flatten()
320 .collect()
321 }
322
323 pub fn get_site_by_domain(&self, domain: &str) -> Option<&SiteConfig> {
324 self.sites.iter().find(|site| {
325 site.handles_hostname(domain) || site.ssl.domains.contains(&domain.to_string())
326 })
327 }
328
329 pub fn reload_from_file(&mut self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
330 let new_config = Self::load_from_file(path)?;
331 *self = new_config;
332 log::info!("Configuration reloaded from {}", path);
333 Ok(())
334 }
335}
336
337impl PerformanceConfig {
338 fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
339 if self.worker_threads == 0 {
340 return Err("Worker threads must be greater than 0".into());
341 }
342
343 if self.max_connections == 0 {
344 return Err("Max connections must be greater than 0".into());
345 }
346
347 if self.keep_alive_timeout == 0 {
348 return Err("Keep alive timeout must be greater than 0".into());
349 }
350
351 if self.request_timeout == 0 {
352 return Err("Request timeout must be greater than 0".into());
353 }
354
355 self.parse_buffer_size(&self.read_buffer_size)
357 .map_err(|_| "Invalid read buffer size format")?;
358 self.parse_buffer_size(&self.write_buffer_size)
359 .map_err(|_| "Invalid write buffer size format")?;
360
361 Ok(())
362 }
363
364 pub fn parse_buffer_size(&self, size_str: &str) -> Result<usize, Box<dyn std::error::Error>> {
365 let size_str = size_str.trim().to_uppercase();
366
367 if let Some(value) = size_str.strip_suffix("KB") {
368 Ok(value.parse::<usize>()? * 1024)
369 } else if let Some(value) = size_str.strip_suffix("MB") {
370 Ok(value.parse::<usize>()? * 1024 * 1024)
371 } else if let Some(value) = size_str.strip_suffix("GB") {
372 Ok(value.parse::<usize>()? * 1024 * 1024 * 1024)
373 } else if let Some(value) = size_str.strip_suffix("B") {
374 Ok(value.parse::<usize>()?)
375 } else {
376 Ok(size_str.parse::<usize>()?)
378 }
379 }
380}
381
382impl SecurityConfig {
383 fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
384 self.parse_size(&self.max_request_size)
386 .map_err(|_| "Invalid max request size format")?;
387
388 if let Some(rate_limit) = &self.rate_limiting {
390 if rate_limit.requests_per_minute == 0 {
391 return Err("Rate limit requests per minute must be greater than 0".into());
392 }
393 if rate_limit.burst_size == 0 {
394 return Err("Rate limit burst size must be greater than 0".into());
395 }
396 }
397
398 Ok(())
399 }
400
401 pub fn parse_size(&self, size_str: &str) -> Result<usize, Box<dyn std::error::Error>> {
402 let size_str = size_str.trim().to_uppercase();
403
404 if let Some(value) = size_str.strip_suffix("KB") {
405 Ok(value.parse::<usize>()? * 1024)
406 } else if let Some(value) = size_str.strip_suffix("MB") {
407 Ok(value.parse::<usize>()? * 1024 * 1024)
408 } else if let Some(value) = size_str.strip_suffix("GB") {
409 Ok(value.parse::<usize>()? * 1024 * 1024 * 1024)
410 } else if let Some(value) = size_str.strip_suffix("B") {
411 Ok(value.parse::<usize>()?)
412 } else {
413 Ok(size_str.parse::<usize>()?)
415 }
416 }
417}
418
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    #[test]
    fn test_buffer_size_parsing() {
        let config = PerformanceConfig::default();

        assert_eq!(config.parse_buffer_size("1024").unwrap(), 1024);
        assert_eq!(config.parse_buffer_size("1KB").unwrap(), 1024);
        assert_eq!(config.parse_buffer_size("1MB").unwrap(), 1024 * 1024);
        assert_eq!(config.parse_buffer_size("1GB").unwrap(), 1024 * 1024 * 1024);
        assert_eq!(config.parse_buffer_size("500B").unwrap(), 500);

        assert!(config.parse_buffer_size("invalid").is_err());
        assert!(config.parse_buffer_size("").is_err());
    }

    #[test]
    fn test_performance_config_validation() {
        assert!(PerformanceConfig::default().validate().is_ok());

        let zero_workers = PerformanceConfig {
            worker_threads: 0,
            ..PerformanceConfig::default()
        };
        assert!(zero_workers.validate().is_err());

        let zero_timeout = PerformanceConfig {
            keep_alive_timeout: 0,
            ..PerformanceConfig::default()
        };
        assert!(zero_timeout.validate().is_err());

        let bad_buffer = PerformanceConfig {
            read_buffer_size: "lots".to_string(),
            ..PerformanceConfig::default()
        };
        assert!(bad_buffer.validate().is_err());
    }

    #[test]
    fn test_security_config_validation() {
        let mut config = SecurityConfig::default();
        assert!(config.validate().is_ok());

        config.max_request_size = "invalid".to_string();
        assert!(config.validate().is_err());

        config.max_request_size = "10MB".to_string();
        config.rate_limiting = Some(RateLimitConfig {
            requests_per_minute: 0,
            burst_size: 10,
            whitelist: vec![],
        });
        assert!(config.validate().is_err());

        config.rate_limiting = Some(RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 0,
            whitelist: vec![],
        });
        assert!(config.validate().is_err());

        config.rate_limiting = Some(RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 10,
            whitelist: vec![],
        });
        assert!(config.validate().is_ok());
    }

    // A plain #[test]: the body never awaits, so no async runtime is needed.
    #[test]
    fn test_config_save_load() {
        let config = ServerConfig {
            server: ServerInfo {
                name: "test-server".to_string(),
                version: "1.0.0".to_string(),
                description: "Test server".to_string(),
            },
            sites: vec![SiteConfig {
                name: "test-site".to_string(),
                hostname: "localhost".to_string(),
                hostnames: vec![],
                port: 8080,
                static_dir: "/tmp/static".to_string(),
                default: true,
                api_only: false,
                headers: HashMap::new(),
                redirect_to_https: false,
                index_files: vec!["index.html".to_string()],
                error_pages: HashMap::new(),
                compression: Default::default(),
                cache: Default::default(),
                access_control: Default::default(),
                ssl: Default::default(),
                proxy: Default::default(),
            }],
            logging: LoggingConfig::default(),
            performance: PerformanceConfig::default(),
            security: SecurityConfig::default(),
        };

        let temp_file = NamedTempFile::new().unwrap();
        let path = temp_file.path().to_str().unwrap();

        config.save_to_file(path).unwrap();

        let loaded_config = ServerConfig::load_from_file(path).unwrap();
        assert_eq!(config.server.name, loaded_config.server.name);
        assert_eq!(config.server.version, loaded_config.server.version);
        assert_eq!(config.sites.len(), loaded_config.sites.len());
        assert_eq!(config.sites[0].name, loaded_config.sites[0].name);
        assert_eq!(config.logging.level, loaded_config.logging.level);
        assert_eq!(
            config.performance.max_connections,
            loaded_config.performance.max_connections
        );
        assert_eq!(
            config.security.security_headers,
            loaded_config.security.security_headers
        );
    }
}