bws_web_server/config/server.rs

1use crate::config::SiteConfig;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::fs;
5use std::path::Path;
6
7#[derive(Debug, Deserialize, Serialize, Clone)]
8pub struct ServerConfig {
9    pub server: ServerInfo,
10    pub sites: Vec<SiteConfig>,
11    #[serde(default)]
12    pub logging: LoggingConfig,
13    #[serde(default)]
14    pub performance: PerformanceConfig,
15    #[serde(default)]
16    pub security: SecurityConfig,
17    #[serde(default)]
18    pub management: ManagementConfig,
19}
20
21#[derive(Debug, Deserialize, Serialize, Clone)]
22pub struct ServerInfo {
23    pub name: String,
24    #[serde(default = "default_version")]
25    pub version: String,
26    #[serde(default)]
27    pub description: String,
28}
29
30#[derive(Debug, Deserialize, Serialize, Clone)]
31pub struct LoggingConfig {
32    #[serde(default = "default_log_level")]
33    pub level: String,
34    #[serde(default)]
35    pub access_log: Option<String>,
36    #[serde(default)]
37    pub error_log: Option<String>,
38    #[serde(default = "default_log_format")]
39    pub format: String,
40    #[serde(default)]
41    pub log_requests: bool,
42}
43
44#[derive(Debug, Deserialize, Serialize, Clone)]
45pub struct PerformanceConfig {
46    #[serde(default = "default_worker_threads")]
47    pub worker_threads: usize,
48    #[serde(default = "default_max_connections")]
49    pub max_connections: usize,
50    #[serde(default = "default_keep_alive_timeout")]
51    pub keep_alive_timeout: u64,
52    #[serde(default = "default_request_timeout")]
53    pub request_timeout: u64,
54    #[serde(default = "default_buffer_size")]
55    pub read_buffer_size: String,
56    #[serde(default = "default_buffer_size")]
57    pub write_buffer_size: String,
58}
59
60#[derive(Debug, Deserialize, Serialize, Clone)]
61pub struct SecurityConfig {
62    #[serde(default)]
63    pub hide_server_header: bool,
64    #[serde(default = "default_max_request_size")]
65    pub max_request_size: String,
66    #[serde(default)]
67    pub allowed_origins: Vec<String>,
68    #[serde(default)]
69    pub security_headers: HashMap<String, String>,
70    #[serde(default)]
71    pub rate_limiting: Option<RateLimitConfig>,
72}
73
74#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
75pub struct RateLimitConfig {
76    pub requests_per_minute: u32,
77    pub burst_size: u32,
78    #[serde(default)]
79    pub whitelist: Vec<String>,
80}
81
82#[derive(Debug, Deserialize, Serialize, Clone)]
83pub struct ManagementConfig {
84    #[serde(default = "default_management_enabled")]
85    pub enabled: bool,
86    #[serde(default = "default_management_host")]
87    pub host: String,
88    #[serde(default = "default_management_port")]
89    pub port: u16,
90    #[serde(default)]
91    pub api_key: Option<String>,
92}
93
94// Default value functions
95fn default_version() -> String {
96    env!("CARGO_PKG_VERSION").to_string()
97}
98
/// Serde default for `LoggingConfig::level`.
fn default_log_level() -> String {
    String::from("info")
}
102
/// Serde default for `LoggingConfig::format`.
fn default_log_format() -> String {
    String::from("combined")
}
106
107fn default_worker_threads() -> usize {
108    num_cpus::get().max(1)
109}
110
/// Serde default for `PerformanceConfig::max_connections`.
fn default_max_connections() -> usize {
    1_000
}
114
/// Serde default for `PerformanceConfig::keep_alive_timeout` (unit presumably seconds — confirm at use site).
fn default_keep_alive_timeout() -> u64 {
    60_u64
}
118
/// Serde default for `PerformanceConfig::request_timeout` (unit presumably seconds — confirm at use site).
fn default_request_timeout() -> u64 {
    30_u64
}
122
/// Serde default shared by the read and write buffer sizes of `PerformanceConfig`.
fn default_buffer_size() -> String {
    String::from("32KB")
}
126
/// Serde default for `SecurityConfig::max_request_size`.
fn default_max_request_size() -> String {
    String::from("10MB")
}
130
/// Serde default for `ManagementConfig::enabled`: the management API is on unless disabled.
fn default_management_enabled() -> bool {
    true
}
134
/// Serde default for `ManagementConfig::host` (the IPv4 loopback address).
fn default_management_host() -> String {
    String::from("127.0.0.1")
}
138
/// Serde default for `ManagementConfig::port`.
fn default_management_port() -> u16 {
    let port: u16 = 7654;
    port
}
142
143impl Default for LoggingConfig {
144    fn default() -> Self {
145        Self {
146            level: default_log_level(),
147            access_log: None,
148            error_log: None,
149            format: default_log_format(),
150            log_requests: true,
151        }
152    }
153}
154
155impl Default for PerformanceConfig {
156    fn default() -> Self {
157        Self {
158            worker_threads: default_worker_threads(),
159            max_connections: default_max_connections(),
160            keep_alive_timeout: default_keep_alive_timeout(),
161            request_timeout: default_request_timeout(),
162            read_buffer_size: default_buffer_size(),
163            write_buffer_size: default_buffer_size(),
164        }
165    }
166}
167
168impl Default for SecurityConfig {
169    fn default() -> Self {
170        let mut security_headers = HashMap::new();
171        security_headers.insert("X-Frame-Options".to_string(), "DENY".to_string());
172        security_headers.insert("X-Content-Type-Options".to_string(), "nosniff".to_string());
173        security_headers.insert("X-XSS-Protection".to_string(), "1; mode=block".to_string());
174        security_headers.insert(
175            "Referrer-Policy".to_string(),
176            "strict-origin-when-cross-origin".to_string(),
177        );
178
179        Self {
180            hide_server_header: false,
181            max_request_size: default_max_request_size(),
182            allowed_origins: vec![],
183            security_headers,
184            rate_limiting: None,
185        }
186    }
187}
188
189impl Default for ManagementConfig {
190    fn default() -> Self {
191        Self {
192            enabled: default_management_enabled(),
193            host: default_management_host(),
194            port: default_management_port(),
195            api_key: None,
196        }
197    }
198}
199
200impl ServerConfig {
201    pub fn load_from_file(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
202        let content = fs::read_to_string(path)?;
203        let mut config: ServerConfig = toml::from_str(&content)?;
204
205        // Post-process configuration first (to set automatic defaults)
206        config.post_process()?;
207
208        // Then validate the processed configuration
209        config.validate()?;
210
211        Ok(config)
212    }
213
214    pub fn save_to_file(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
215        let content = toml::to_string_pretty(self)?;
216
217        // Ensure parent directory exists
218        if let Some(parent) = Path::new(path).parent() {
219            fs::create_dir_all(parent)?;
220        }
221
222        fs::write(path, content)?;
223        log::info!("Configuration saved to {}", path);
224        Ok(())
225    }
226
227    pub fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
228        // Validate server info
229        if self.server.name.is_empty() {
230            return Err("Server name cannot be empty".into());
231        }
232
233        // Validate sites
234        if self.sites.is_empty() {
235            return Err("At least one site must be configured".into());
236        }
237
238        let mut default_sites = 0;
239        let mut used_hostname_ports = std::collections::HashSet::new();
240
241        for (i, site) in self.sites.iter().enumerate() {
242            site.validate().map_err(|e| format!("Site {}: {}", i, e))?;
243
244            if site.default {
245                default_sites += 1;
246            }
247
248            // Check for hostname:port conflicts across all hostnames for this site
249            // Allow multiple sites on same port with different hostnames (virtual hosting)
250            for hostname in site.get_all_hostnames() {
251                let hostname_port_key = (hostname, site.port);
252                if used_hostname_ports.contains(&hostname_port_key) {
253                    return Err(format!(
254                        "Duplicate hostname:port combination: {}:{}. Each hostname must be unique per port.",
255                        hostname, site.port
256                    )
257                    .into());
258                }
259                used_hostname_ports.insert(hostname_port_key);
260            }
261        }
262
263        if default_sites == 0 {
264            // This should not happen after post_process, but let's be defensive
265            if self.sites.len() == 1 {
266                // Single site should have been auto-marked as default in post_process
267                return Err(
268                    "Internal error: single site was not marked as default during post-processing"
269                        .into(),
270                );
271            } else {
272                return Err("At least one site must be marked as default when multiple sites are configured".into());
273            }
274        }
275        if default_sites > 1 {
276            return Err("Only one site can be marked as default".into());
277        }
278
279        // Validate that each site has proper SSL configuration if enabled
280        for site in &self.sites {
281            if site.ssl.enabled {
282                if site.ssl.auto_cert {
283                    if let Some(acme) = &site.ssl.acme {
284                        if acme.email.is_empty() {
285                            return Err(format!(
286                                "Site '{}': ACME email is required when auto_cert is enabled",
287                                site.name
288                            )
289                            .into());
290                        }
291                    } else {
292                        return Err(format!(
293                            "Site '{}': ACME configuration is required when auto_cert is enabled",
294                            site.name
295                        )
296                        .into());
297                    }
298                } else if site.ssl.cert_file.is_none() || site.ssl.key_file.is_none() {
299                    return Err(format!(
300                        "Site '{}': Manual SSL requires both cert_file and key_file",
301                        site.name
302                    )
303                    .into());
304                }
305            }
306        }
307
308        // Validate performance configuration
309        self.performance.validate()?;
310
311        // Validate security configuration
312        self.security.validate()?;
313
314        Ok(())
315    }
316
317    fn post_process(&mut self) -> Result<(), Box<dyn std::error::Error>> {
318        // If there's only one site and no site is explicitly marked as default,
319        // automatically make the single site the default
320        if self.sites.len() == 1 && !self.sites[0].default {
321            log::info!(
322                "Automatically setting single site '{}' as default",
323                self.sites[0].name
324            );
325            self.sites[0].default = true;
326        }
327
328        // Sort sites by priority (default site first, then by name)
329        self.sites.sort_by(|a, b| match (a.default, b.default) {
330            (true, false) => std::cmp::Ordering::Less,
331            (false, true) => std::cmp::Ordering::Greater,
332            _ => a.name.cmp(&b.name),
333        });
334
335        // No additional post-processing needed for per-site SSL
336        // Each site manages its own SSL configuration
337
338        Ok(())
339    }
340
341    pub fn find_site_by_host_port(&self, host: &str, port: u16) -> Option<&SiteConfig> {
342        // First try to match both hostname and port exactly using the new method
343        for site in &self.sites {
344            if site.handles_hostname_port(host, port) {
345                return Some(site);
346            }
347        }
348
349        // Then try to match just the port (for cases where hostname might not match exactly)
350        for site in &self.sites {
351            if site.port == port {
352                return Some(site);
353            }
354        }
355
356        // Finally, return the default site if no match
357        self.sites.iter().find(|site| site.default)
358    }
359
360    pub fn get_ssl_domains(&self) -> Vec<String> {
361        self.sites
362            .iter()
363            .filter_map(|site| {
364                if site.ssl.enabled {
365                    Some(
366                        site.get_all_ssl_domains()
367                            .into_iter()
368                            .map(|s| s.to_string())
369                            .collect::<Vec<_>>(),
370                    )
371                } else {
372                    None
373                }
374            })
375            .flatten()
376            .collect()
377    }
378
379    pub fn get_site_by_domain(&self, domain: &str) -> Option<&SiteConfig> {
380        self.sites.iter().find(|site| {
381            site.handles_hostname(domain) || site.ssl.domains.contains(&domain.to_string())
382        })
383    }
384
385    pub fn reload_from_file(&mut self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
386        let new_config = Self::load_from_file(path)?;
387        *self = new_config;
388        log::info!("Configuration reloaded from {}", path);
389        Ok(())
390    }
391}
392
393impl PerformanceConfig {
394    fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
395        if self.worker_threads == 0 {
396            return Err("Worker threads must be greater than 0".into());
397        }
398
399        if self.max_connections == 0 {
400            return Err("Max connections must be greater than 0".into());
401        }
402
403        if self.keep_alive_timeout == 0 {
404            return Err("Keep alive timeout must be greater than 0".into());
405        }
406
407        if self.request_timeout == 0 {
408            return Err("Request timeout must be greater than 0".into());
409        }
410
411        // Validate buffer sizes
412        self.parse_buffer_size(&self.read_buffer_size)
413            .map_err(|_| "Invalid read buffer size format")?;
414        self.parse_buffer_size(&self.write_buffer_size)
415            .map_err(|_| "Invalid write buffer size format")?;
416
417        Ok(())
418    }
419
420    pub fn parse_buffer_size(&self, size_str: &str) -> Result<usize, Box<dyn std::error::Error>> {
421        let size_str = size_str.trim().to_uppercase();
422
423        if let Some(value) = size_str.strip_suffix("KB") {
424            Ok(value.parse::<usize>()? * 1024)
425        } else if let Some(value) = size_str.strip_suffix("MB") {
426            Ok(value.parse::<usize>()? * 1024 * 1024)
427        } else if let Some(value) = size_str.strip_suffix("GB") {
428            Ok(value.parse::<usize>()? * 1024 * 1024 * 1024)
429        } else if let Some(value) = size_str.strip_suffix("B") {
430            Ok(value.parse::<usize>()?)
431        } else {
432            // Assume bytes if no suffix
433            Ok(size_str.parse::<usize>()?)
434        }
435    }
436}
437
438impl SecurityConfig {
439    fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
440        // Validate max request size format
441        self.parse_size(&self.max_request_size)
442            .map_err(|_| "Invalid max request size format")?;
443
444        // Validate rate limiting configuration
445        if let Some(rate_limit) = &self.rate_limiting {
446            if rate_limit.requests_per_minute == 0 {
447                return Err("Rate limit requests per minute must be greater than 0".into());
448            }
449            if rate_limit.burst_size == 0 {
450                return Err("Rate limit burst size must be greater than 0".into());
451            }
452        }
453
454        Ok(())
455    }
456
457    pub fn parse_size(&self, size_str: &str) -> Result<usize, Box<dyn std::error::Error>> {
458        let size_str = size_str.trim().to_uppercase();
459
460        if let Some(value) = size_str.strip_suffix("KB") {
461            Ok(value.parse::<usize>()? * 1024)
462        } else if let Some(value) = size_str.strip_suffix("MB") {
463            Ok(value.parse::<usize>()? * 1024 * 1024)
464        } else if let Some(value) = size_str.strip_suffix("GB") {
465            Ok(value.parse::<usize>()? * 1024 * 1024 * 1024)
466        } else if let Some(value) = size_str.strip_suffix("B") {
467            Ok(value.parse::<usize>()?)
468        } else {
469            // Assume bytes if no suffix
470            Ok(size_str.parse::<usize>()?)
471        }
472    }
473}
474
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::SiteConfig;
    use tempfile::NamedTempFile;

    /// Builds a minimal static site.
    ///
    /// Fields not passed in get the values the tests previously spelled out
    /// inline: no extra hostnames, plain HTTP, `index.html`, and default
    /// nested sections.
    fn make_site(name: &str, hostname: &str, port: u16, static_dir: &str, is_default: bool) -> SiteConfig {
        SiteConfig {
            name: name.to_string(),
            hostname: hostname.to_string(),
            hostnames: vec![],
            port,
            static_dir: static_dir.to_string(),
            default: is_default,
            api_only: false,
            headers: HashMap::new(),
            redirect_to_https: false,
            index_files: vec!["index.html".to_string()],
            error_pages: HashMap::new(),
            compression: Default::default(),
            cache: Default::default(),
            access_control: Default::default(),
            ssl: Default::default(),
            proxy: Default::default(),
        }
    }

    /// Wraps `sites` in a server config whose other sections use their defaults.
    fn make_config(sites: Vec<SiteConfig>) -> ServerConfig {
        ServerConfig {
            server: ServerInfo {
                name: "test-server".to_string(),
                version: "1.0.0".to_string(),
                description: "Test server".to_string(),
            },
            sites,
            logging: LoggingConfig::default(),
            performance: PerformanceConfig::default(),
            security: SecurityConfig::default(),
            management: ManagementConfig::default(),
        }
    }

    #[test]
    fn test_buffer_size_parsing() {
        let config = PerformanceConfig::default();

        assert_eq!(config.parse_buffer_size("1024").unwrap(), 1024);
        assert_eq!(config.parse_buffer_size("1KB").unwrap(), 1024);
        assert_eq!(config.parse_buffer_size("1MB").unwrap(), 1024 * 1024);
        assert_eq!(config.parse_buffer_size("1GB").unwrap(), 1024 * 1024 * 1024);
        assert_eq!(config.parse_buffer_size("500B").unwrap(), 500);

        assert!(config.parse_buffer_size("invalid").is_err());
        assert!(config.parse_buffer_size("").is_err());
    }

    #[test]
    fn test_automatic_default_site() {
        // A single site without an explicit `default = true` is promoted by post_process.
        let mut config = make_config(vec![make_site(
            "single-site",
            "localhost",
            8080,
            "/tmp/static",
            false,
        )]);

        assert!(!config.sites[0].default);
        config.post_process().unwrap();
        assert!(config.sites[0].default);
        assert!(config.validate().is_ok());

        // With multiple sites the automatic promotion must not apply.
        config
            .sites
            .push(make_site("second-site", "example.com", 8081, "/tmp/static2", false));
        config.sites[0].default = false;

        config.post_process().unwrap();
        assert!(!config.sites[0].default);
        assert!(!config.sites[1].default);

        // No site is marked default, so validation must fail.
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_security_config_validation() {
        let mut config = SecurityConfig::default();
        assert!(config.validate().is_ok());

        config.max_request_size = "invalid".to_string();
        assert!(config.validate().is_err());

        config.max_request_size = "10MB".to_string();
        config.rate_limiting = Some(RateLimitConfig {
            requests_per_minute: 0,
            burst_size: 10,
            whitelist: vec![],
        });
        assert!(config.validate().is_err());
    }

    // A plain #[test]: nothing here awaits, so no async runtime is needed.
    #[test]
    fn test_config_save_load() {
        let config = make_config(vec![make_site(
            "test-site",
            "localhost",
            8080,
            "/tmp/static",
            true,
        )]);

        let temp_file = NamedTempFile::new().unwrap();
        let path = temp_file.path().to_str().unwrap();

        config.save_to_file(path).unwrap();

        let loaded_config = ServerConfig::load_from_file(path).unwrap();
        assert_eq!(config.server.name, loaded_config.server.name);
        assert_eq!(config.sites.len(), loaded_config.sites.len());
    }
}