// bws_web_server/config/server.rs

1use crate::config::SiteConfig;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::fs;
5use std::path::Path;
6
7#[derive(Debug, Deserialize, Serialize, Clone)]
8pub struct ServerConfig {
9    pub server: ServerInfo,
10    pub sites: Vec<SiteConfig>,
11    #[serde(default)]
12    pub logging: LoggingConfig,
13    #[serde(default)]
14    pub performance: PerformanceConfig,
15    #[serde(default)]
16    pub security: SecurityConfig,
17}
18
19#[derive(Debug, Deserialize, Serialize, Clone)]
20pub struct ServerInfo {
21    pub name: String,
22    #[serde(default = "default_version")]
23    pub version: String,
24    #[serde(default)]
25    pub description: String,
26}
27
28#[derive(Debug, Deserialize, Serialize, Clone)]
29pub struct LoggingConfig {
30    #[serde(default = "default_log_level")]
31    pub level: String,
32    #[serde(default)]
33    pub access_log: Option<String>,
34    #[serde(default)]
35    pub error_log: Option<String>,
36    #[serde(default = "default_log_format")]
37    pub format: String,
38    #[serde(default)]
39    pub log_requests: bool,
40}
41
42#[derive(Debug, Deserialize, Serialize, Clone)]
43pub struct PerformanceConfig {
44    #[serde(default = "default_worker_threads")]
45    pub worker_threads: usize,
46    #[serde(default = "default_max_connections")]
47    pub max_connections: usize,
48    #[serde(default = "default_keep_alive_timeout")]
49    pub keep_alive_timeout: u64,
50    #[serde(default = "default_request_timeout")]
51    pub request_timeout: u64,
52    #[serde(default = "default_buffer_size")]
53    pub read_buffer_size: String,
54    #[serde(default = "default_buffer_size")]
55    pub write_buffer_size: String,
56}
57
58#[derive(Debug, Deserialize, Serialize, Clone)]
59pub struct SecurityConfig {
60    #[serde(default)]
61    pub hide_server_header: bool,
62    #[serde(default = "default_max_request_size")]
63    pub max_request_size: String,
64    #[serde(default)]
65    pub allowed_origins: Vec<String>,
66    #[serde(default)]
67    pub security_headers: HashMap<String, String>,
68    #[serde(default)]
69    pub rate_limiting: Option<RateLimitConfig>,
70}
71
72#[derive(Debug, Deserialize, Serialize, Clone)]
73pub struct RateLimitConfig {
74    pub requests_per_minute: u32,
75    pub burst_size: u32,
76    #[serde(default)]
77    pub whitelist: Vec<String>,
78}
79
80// Default value functions
81fn default_version() -> String {
82    env!("CARGO_PKG_VERSION").to_string()
83}
84
/// Default logging level when `level` is not given.
fn default_log_level() -> String {
    String::from("info")
}

/// Default log line format when `format` is not given.
fn default_log_format() -> String {
    String::from("combined")
}
92
93fn default_worker_threads() -> usize {
94    num_cpus::get().max(1)
95}
96
/// Default connection limit.
fn default_max_connections() -> usize {
    1_000
}

/// Default keep-alive timeout value.
fn default_keep_alive_timeout() -> u64 {
    60
}

/// Default request timeout value.
fn default_request_timeout() -> u64 {
    30
}

/// Default size string shared by the read and write buffers.
fn default_buffer_size() -> String {
    String::from("32KB")
}

/// Default maximum request size string.
fn default_max_request_size() -> String {
    String::from("10MB")
}
116
117impl Default for LoggingConfig {
118    fn default() -> Self {
119        Self {
120            level: default_log_level(),
121            access_log: None,
122            error_log: None,
123            format: default_log_format(),
124            log_requests: true,
125        }
126    }
127}
128
129impl Default for PerformanceConfig {
130    fn default() -> Self {
131        Self {
132            worker_threads: default_worker_threads(),
133            max_connections: default_max_connections(),
134            keep_alive_timeout: default_keep_alive_timeout(),
135            request_timeout: default_request_timeout(),
136            read_buffer_size: default_buffer_size(),
137            write_buffer_size: default_buffer_size(),
138        }
139    }
140}
141
142impl Default for SecurityConfig {
143    fn default() -> Self {
144        let mut security_headers = HashMap::new();
145        security_headers.insert("X-Frame-Options".to_string(), "DENY".to_string());
146        security_headers.insert("X-Content-Type-Options".to_string(), "nosniff".to_string());
147        security_headers.insert("X-XSS-Protection".to_string(), "1; mode=block".to_string());
148        security_headers.insert(
149            "Referrer-Policy".to_string(),
150            "strict-origin-when-cross-origin".to_string(),
151        );
152
153        Self {
154            hide_server_header: false,
155            max_request_size: default_max_request_size(),
156            allowed_origins: vec![],
157            security_headers,
158            rate_limiting: None,
159        }
160    }
161}
162
163impl ServerConfig {
164    pub fn load_from_file(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
165        let content = fs::read_to_string(path)?;
166        let mut config: ServerConfig = toml::from_str(&content)?;
167
168        // Validate configuration
169        config.validate()?;
170
171        // Post-process configuration
172        config.post_process()?;
173
174        Ok(config)
175    }
176
177    pub fn save_to_file(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
178        let content = toml::to_string_pretty(self)?;
179
180        // Ensure parent directory exists
181        if let Some(parent) = Path::new(path).parent() {
182            fs::create_dir_all(parent)?;
183        }
184
185        fs::write(path, content)?;
186        log::info!("Configuration saved to {}", path);
187        Ok(())
188    }
189
190    pub fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
191        // Validate server info
192        if self.server.name.is_empty() {
193            return Err("Server name cannot be empty".into());
194        }
195
196        // Validate sites
197        if self.sites.is_empty() {
198            return Err("At least one site must be configured".into());
199        }
200
201        let mut default_sites = 0;
202        let mut used_hostname_ports = std::collections::HashSet::new();
203
204        for (i, site) in self.sites.iter().enumerate() {
205            site.validate().map_err(|e| format!("Site {}: {}", i, e))?;
206
207            if site.default {
208                default_sites += 1;
209            }
210
211            // Check for hostname:port conflicts across all hostnames for this site
212            // Allow multiple sites on same port with different hostnames (virtual hosting)
213            for hostname in site.get_all_hostnames() {
214                let hostname_port_key = (hostname, site.port);
215                if used_hostname_ports.contains(&hostname_port_key) {
216                    return Err(format!(
217                        "Duplicate hostname:port combination: {}:{}. Each hostname must be unique per port.",
218                        hostname, site.port
219                    )
220                    .into());
221                }
222                used_hostname_ports.insert(hostname_port_key);
223            }
224        }
225
226        if default_sites == 0 {
227            return Err("At least one site must be marked as default".into());
228        }
229        if default_sites > 1 {
230            return Err("Only one site can be marked as default".into());
231        }
232
233        // Validate that each site has proper SSL configuration if enabled
234        for site in &self.sites {
235            if site.ssl.enabled {
236                if site.ssl.auto_cert {
237                    if let Some(acme) = &site.ssl.acme {
238                        if acme.email.is_empty() {
239                            return Err(format!(
240                                "Site '{}': ACME email is required when auto_cert is enabled",
241                                site.name
242                            )
243                            .into());
244                        }
245                    } else {
246                        return Err(format!(
247                            "Site '{}': ACME configuration is required when auto_cert is enabled",
248                            site.name
249                        )
250                        .into());
251                    }
252                } else if site.ssl.cert_file.is_none() || site.ssl.key_file.is_none() {
253                    return Err(format!(
254                        "Site '{}': Manual SSL requires both cert_file and key_file",
255                        site.name
256                    )
257                    .into());
258                }
259            }
260        }
261
262        // Validate performance configuration
263        self.performance.validate()?;
264
265        // Validate security configuration
266        self.security.validate()?;
267
268        Ok(())
269    }
270
271    fn post_process(&mut self) -> Result<(), Box<dyn std::error::Error>> {
272        // Sort sites by priority (default site first, then by name)
273        self.sites.sort_by(|a, b| match (a.default, b.default) {
274            (true, false) => std::cmp::Ordering::Less,
275            (false, true) => std::cmp::Ordering::Greater,
276            _ => a.name.cmp(&b.name),
277        });
278
279        // No additional post-processing needed for per-site SSL
280        // Each site manages its own SSL configuration
281
282        Ok(())
283    }
284
285    pub fn find_site_by_host_port(&self, host: &str, port: u16) -> Option<&SiteConfig> {
286        // First try to match both hostname and port exactly using the new method
287        for site in &self.sites {
288            if site.handles_hostname_port(host, port) {
289                return Some(site);
290            }
291        }
292
293        // Then try to match just the port (for cases where hostname might not match exactly)
294        for site in &self.sites {
295            if site.port == port {
296                return Some(site);
297            }
298        }
299
300        // Finally, return the default site if no match
301        self.sites.iter().find(|site| site.default)
302    }
303
304    pub fn get_ssl_domains(&self) -> Vec<String> {
305        self.sites
306            .iter()
307            .filter_map(|site| {
308                if site.ssl.enabled {
309                    Some(
310                        site.get_all_ssl_domains()
311                            .into_iter()
312                            .map(|s| s.to_string())
313                            .collect::<Vec<_>>(),
314                    )
315                } else {
316                    None
317                }
318            })
319            .flatten()
320            .collect()
321    }
322
323    pub fn get_site_by_domain(&self, domain: &str) -> Option<&SiteConfig> {
324        self.sites.iter().find(|site| {
325            site.handles_hostname(domain) || site.ssl.domains.contains(&domain.to_string())
326        })
327    }
328
329    pub fn reload_from_file(&mut self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
330        let new_config = Self::load_from_file(path)?;
331        *self = new_config;
332        log::info!("Configuration reloaded from {}", path);
333        Ok(())
334    }
335}
336
337impl PerformanceConfig {
338    fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
339        if self.worker_threads == 0 {
340            return Err("Worker threads must be greater than 0".into());
341        }
342
343        if self.max_connections == 0 {
344            return Err("Max connections must be greater than 0".into());
345        }
346
347        if self.keep_alive_timeout == 0 {
348            return Err("Keep alive timeout must be greater than 0".into());
349        }
350
351        if self.request_timeout == 0 {
352            return Err("Request timeout must be greater than 0".into());
353        }
354
355        // Validate buffer sizes
356        self.parse_buffer_size(&self.read_buffer_size)
357            .map_err(|_| "Invalid read buffer size format")?;
358        self.parse_buffer_size(&self.write_buffer_size)
359            .map_err(|_| "Invalid write buffer size format")?;
360
361        Ok(())
362    }
363
364    pub fn parse_buffer_size(&self, size_str: &str) -> Result<usize, Box<dyn std::error::Error>> {
365        let size_str = size_str.trim().to_uppercase();
366
367        if let Some(value) = size_str.strip_suffix("KB") {
368            Ok(value.parse::<usize>()? * 1024)
369        } else if let Some(value) = size_str.strip_suffix("MB") {
370            Ok(value.parse::<usize>()? * 1024 * 1024)
371        } else if let Some(value) = size_str.strip_suffix("GB") {
372            Ok(value.parse::<usize>()? * 1024 * 1024 * 1024)
373        } else if let Some(value) = size_str.strip_suffix("B") {
374            Ok(value.parse::<usize>()?)
375        } else {
376            // Assume bytes if no suffix
377            Ok(size_str.parse::<usize>()?)
378        }
379    }
380}
381
382impl SecurityConfig {
383    fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
384        // Validate max request size format
385        self.parse_size(&self.max_request_size)
386            .map_err(|_| "Invalid max request size format")?;
387
388        // Validate rate limiting configuration
389        if let Some(rate_limit) = &self.rate_limiting {
390            if rate_limit.requests_per_minute == 0 {
391                return Err("Rate limit requests per minute must be greater than 0".into());
392            }
393            if rate_limit.burst_size == 0 {
394                return Err("Rate limit burst size must be greater than 0".into());
395            }
396        }
397
398        Ok(())
399    }
400
401    pub fn parse_size(&self, size_str: &str) -> Result<usize, Box<dyn std::error::Error>> {
402        let size_str = size_str.trim().to_uppercase();
403
404        if let Some(value) = size_str.strip_suffix("KB") {
405            Ok(value.parse::<usize>()? * 1024)
406        } else if let Some(value) = size_str.strip_suffix("MB") {
407            Ok(value.parse::<usize>()? * 1024 * 1024)
408        } else if let Some(value) = size_str.strip_suffix("GB") {
409            Ok(value.parse::<usize>()? * 1024 * 1024 * 1024)
410        } else if let Some(value) = size_str.strip_suffix("B") {
411            Ok(value.parse::<usize>()?)
412        } else {
413            // Assume bytes if no suffix
414            Ok(size_str.parse::<usize>()?)
415        }
416    }
417}
418
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    #[test]
    fn test_buffer_size_parsing() {
        let config = PerformanceConfig::default();

        assert_eq!(config.parse_buffer_size("1024").unwrap(), 1024);
        assert_eq!(config.parse_buffer_size("1KB").unwrap(), 1024);
        assert_eq!(config.parse_buffer_size("1MB").unwrap(), 1024 * 1024);
        assert_eq!(config.parse_buffer_size("1GB").unwrap(), 1024 * 1024 * 1024);
        assert_eq!(config.parse_buffer_size("500B").unwrap(), 500);

        // Units are case-insensitive and surrounding whitespace is ignored.
        assert_eq!(config.parse_buffer_size("32kb").unwrap(), 32 * 1024);
        assert_eq!(config.parse_buffer_size("  2MB  ").unwrap(), 2 * 1024 * 1024);

        assert!(config.parse_buffer_size("invalid").is_err());
        assert!(config.parse_buffer_size("").is_err());
        // A unit without a number, or a negative number, is rejected.
        assert!(config.parse_buffer_size("KB").is_err());
        assert!(config.parse_buffer_size("-1KB").is_err());
    }

    #[test]
    fn test_performance_config_validation() {
        let mut config = PerformanceConfig::default();
        assert!(config.validate().is_ok());

        config.worker_threads = 0;
        assert!(config.validate().is_err());

        let mut config = PerformanceConfig::default();
        config.read_buffer_size = "lots".to_string();
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_security_config_validation() {
        let mut config = SecurityConfig::default();
        assert!(config.validate().is_ok());

        config.max_request_size = "invalid".to_string();
        assert!(config.validate().is_err());

        config.max_request_size = "10MB".to_string();
        config.rate_limiting = Some(RateLimitConfig {
            requests_per_minute: 0,
            burst_size: 10,
            whitelist: vec![],
        });
        assert!(config.validate().is_err());

        config.rate_limiting = Some(RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 0,
            whitelist: vec![],
        });
        assert!(config.validate().is_err());

        config.rate_limiting = Some(RateLimitConfig {
            requests_per_minute: 60,
            burst_size: 10,
            whitelist: vec![],
        });
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_default_security_headers() {
        let config = SecurityConfig::default();
        assert_eq!(
            config
                .security_headers
                .get("X-Frame-Options")
                .map(String::as_str),
            Some("DENY")
        );
        assert_eq!(
            config
                .security_headers
                .get("X-Content-Type-Options")
                .map(String::as_str),
            Some("nosniff")
        );
        assert_eq!(
            config
                .security_headers
                .get("Referrer-Policy")
                .map(String::as_str),
            Some("strict-origin-when-cross-origin")
        );
    }

    // Plain `#[test]`: saving and loading are synchronous, so no async runtime is needed.
    #[test]
    fn test_config_save_load() {
        let config = ServerConfig {
            server: ServerInfo {
                name: "test-server".to_string(),
                version: "1.0.0".to_string(),
                description: "Test server".to_string(),
            },
            sites: vec![SiteConfig {
                name: "test-site".to_string(),
                hostname: "localhost".to_string(),
                hostnames: vec![],
                port: 8080,
                static_dir: "/tmp/static".to_string(),
                default: true,
                api_only: false,
                headers: HashMap::new(),
                redirect_to_https: false,
                index_files: vec!["index.html".to_string()],
                error_pages: HashMap::new(),
                compression: Default::default(),
                cache: Default::default(),
                access_control: Default::default(),
                ssl: Default::default(),
                proxy: Default::default(),
            }],
            logging: LoggingConfig::default(),
            performance: PerformanceConfig::default(),
            security: SecurityConfig::default(),
        };

        let temp_file = NamedTempFile::new().expect("create temp file");
        let path = temp_file
            .path()
            .to_str()
            .expect("temp file path is valid UTF-8");

        config.save_to_file(path).expect("save config");
        let loaded = ServerConfig::load_from_file(path).expect("load config");

        assert_eq!(loaded.server.name, config.server.name);
        assert_eq!(loaded.sites.len(), config.sites.len());
        assert_eq!(loaded.sites[0].name, config.sites[0].name);
        assert_eq!(loaded.sites[0].port, config.sites[0].port);
        assert!(loaded.sites[0].default);
        assert_eq!(
            loaded.performance.max_connections,
            config.performance.max_connections
        );
        assert_eq!(
            loaded.security.security_headers,
            config.security.security_headers
        );
    }
}