bws_web_server/config/
server.rs

1use crate::config::SiteConfig;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::fs;
5use std::path::Path;
6
7/// Top-level server configuration for BWS.
8/// Contains global server info, site definitions, and all major subsystems.
9#[derive(Debug, Deserialize, Serialize, Clone)]
10pub struct ServerConfig {
11    /// General server info (name, version, description)
12    pub server: ServerInfo,
13    /// List of all configured sites (virtual hosts)
14    pub sites: Vec<SiteConfig>,
15    /// Logging configuration
16    #[serde(default)]
17    pub logging: LoggingConfig,
18    /// Performance tuning configuration
19    #[serde(default)]
20    pub performance: PerformanceConfig,
21    /// Security-related configuration
22    #[serde(default)]
23    pub security: SecurityConfig,
24    /// Management API configuration
25    #[serde(default)]
26    pub management: ManagementConfig,
27}
28
29/// Information about the server (name, version, description)
30#[derive(Debug, Deserialize, Serialize, Clone)]
31pub struct ServerInfo {
32    /// Server name
33    pub name: String,
34    /// Server version
35    #[serde(default = "default_version")]
36    pub version: String,
37    /// Server description
38    #[serde(default)]
39    pub description: String,
40}
41
42/// Logging configuration for the server
43#[derive(Debug, Deserialize, Serialize, Clone)]
44pub struct LoggingConfig {
45    /// Log level (e.g., info, debug, warn)
46    #[serde(default = "default_log_level")]
47    pub level: String,
48    /// Path to access log file (optional)
49    #[serde(default)]
50    pub access_log: Option<String>,
51    /// Path to error log file (optional)
52    #[serde(default)]
53    pub error_log: Option<String>,
54    /// Log format (e.g., combined, json)
55    #[serde(default = "default_log_format")]
56    pub format: String,
57    /// Whether to log all requests
58    #[serde(default)]
59    pub log_requests: bool,
60}
61
62/// Performance tuning configuration for the server
63#[derive(Debug, Deserialize, Serialize, Clone)]
64pub struct PerformanceConfig {
65    /// Number of worker threads
66    #[serde(default = "default_worker_threads")]
67    pub worker_threads: usize,
68    /// Maximum concurrent connections
69    #[serde(default = "default_max_connections")]
70    pub max_connections: usize,
71    /// Keep-alive timeout in seconds
72    #[serde(default = "default_keep_alive_timeout")]
73    pub keep_alive_timeout: u64,
74    /// Request timeout in seconds
75    #[serde(default = "default_request_timeout")]
76    pub request_timeout: u64,
77    /// Read buffer size (e.g., "32KB")
78    #[serde(default = "default_buffer_size")]
79    pub read_buffer_size: String,
80    /// Write buffer size (e.g., "32KB")
81    #[serde(default = "default_buffer_size")]
82    pub write_buffer_size: String,
83}
84
85/// Security-related configuration for the server
86#[derive(Debug, Deserialize, Serialize, Clone)]
87pub struct SecurityConfig {
88    /// Whether to hide the server header in responses
89    #[serde(default)]
90    pub hide_server_header: bool,
91    /// Maximum allowed request size (e.g., "10MB")
92    #[serde(default = "default_max_request_size")]
93    pub max_request_size: String,
94    /// List of allowed CORS origins
95    #[serde(default)]
96    pub allowed_origins: Vec<String>,
97    /// Custom security headers to add to responses
98    #[serde(default)]
99    pub security_headers: HashMap<String, String>,
100    /// Optional rate limiting configuration
101    #[serde(default)]
102    pub rate_limiting: Option<RateLimitConfig>,
103}
104
105/// Rate limiting configuration
106#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
107pub struct RateLimitConfig {
108    /// Allowed requests per minute
109    pub requests_per_minute: u32,
110    /// Allowed burst size
111    pub burst_size: u32,
112    /// Whitelisted IPs or CIDRs
113    #[serde(default)]
114    pub whitelist: Vec<String>,
115}
116
117/// Management API configuration
118#[derive(Debug, Deserialize, Serialize, Clone)]
119pub struct ManagementConfig {
120    /// Whether the management API is enabled
121    #[serde(default = "default_management_enabled")]
122    pub enabled: bool,
123    /// Host for the management API
124    #[serde(default = "default_management_host")]
125    pub host: String,
126    /// Port for the management API
127    #[serde(default = "default_management_port")]
128    pub port: u16,
129    /// Optional API key for authentication
130    #[serde(default)]
131    pub api_key: Option<String>,
132}
133
134// Default value functions
135fn default_version() -> String {
136    env!("CARGO_PKG_VERSION").to_string()
137}
138
/// Log level used when none is configured.
fn default_log_level() -> String {
    String::from("info")
}
142
/// Log format used when none is configured.
fn default_log_format() -> String {
    "combined".to_owned()
}
146
147fn default_worker_threads() -> usize {
148    num_cpus::get().max(1)
149}
150
/// Concurrent connection limit used when none is configured.
fn default_max_connections() -> usize {
    1_000
}
154
/// Keep-alive timeout, in seconds, used when none is configured.
fn default_keep_alive_timeout() -> u64 {
    60_u64
}
158
/// Request timeout, in seconds, used when none is configured.
fn default_request_timeout() -> u64 {
    30_u64
}
162
/// Read/write buffer size string used when none is configured.
fn default_buffer_size() -> String {
    "32KB".into()
}
166
/// Maximum request size string used when none is configured.
fn default_max_request_size() -> String {
    String::from("10MB")
}
170
/// The management API is opt-in: disabled unless explicitly enabled.
fn default_management_enabled() -> bool {
    let enabled = false;
    enabled
}
174
/// Management API host used when none is configured: the IPv4 loopback address.
fn default_management_host() -> String {
    std::net::Ipv4Addr::LOCALHOST.to_string()
}
178
/// Management API port used when none is configured.
fn default_management_port() -> u16 {
    7654_u16
}
182
183impl Default for LoggingConfig {
184    fn default() -> Self {
185        Self {
186            level: default_log_level(),
187            access_log: None,
188            error_log: None,
189            format: default_log_format(),
190            log_requests: true,
191        }
192    }
193}
194
195impl Default for PerformanceConfig {
196    fn default() -> Self {
197        Self {
198            worker_threads: default_worker_threads(),
199            max_connections: default_max_connections(),
200            keep_alive_timeout: default_keep_alive_timeout(),
201            request_timeout: default_request_timeout(),
202            read_buffer_size: default_buffer_size(),
203            write_buffer_size: default_buffer_size(),
204        }
205    }
206}
207
208impl Default for SecurityConfig {
209    fn default() -> Self {
210        let mut security_headers = HashMap::new();
211        security_headers.insert("X-Frame-Options".to_string(), "DENY".to_string());
212        security_headers.insert("X-Content-Type-Options".to_string(), "nosniff".to_string());
213        security_headers.insert("X-XSS-Protection".to_string(), "1; mode=block".to_string());
214        security_headers.insert(
215            "Referrer-Policy".to_string(),
216            "strict-origin-when-cross-origin".to_string(),
217        );
218
219        Self {
220            hide_server_header: false,
221            max_request_size: default_max_request_size(),
222            allowed_origins: vec![],
223            security_headers,
224            rate_limiting: None,
225        }
226    }
227}
228
229impl Default for ManagementConfig {
230    fn default() -> Self {
231        Self {
232            enabled: default_management_enabled(),
233            host: default_management_host(),
234            port: default_management_port(),
235            api_key: None,
236        }
237    }
238}
239
240impl ServerConfig {
241    pub fn load_from_file(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
242        let content = fs::read_to_string(path)?;
243        let mut config: ServerConfig = toml::from_str(&content)?;
244
245        // Post-process configuration first (to set automatic defaults)
246        config.post_process()?;
247
248        // Then validate the processed configuration
249        config.validate()?;
250
251        Ok(config)
252    }
253
254    pub fn save_to_file(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
255        let content = toml::to_string_pretty(self)?;
256
257        // Ensure parent directory exists
258        if let Some(parent) = Path::new(path).parent() {
259            fs::create_dir_all(parent)?;
260        }
261
262        fs::write(path, content)?;
263        log::info!("Configuration saved to {}", path);
264        Ok(())
265    }
266
267    pub fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
268        // Validate server info
269        if self.server.name.is_empty() {
270            return Err("Server name cannot be empty".into());
271        }
272
273        // Validate sites
274        if self.sites.is_empty() {
275            return Err("At least one site must be configured".into());
276        }
277
278        let mut default_sites = 0;
279        let mut used_hostname_ports = std::collections::HashSet::new();
280
281        for (i, site) in self.sites.iter().enumerate() {
282            site.validate().map_err(|e| format!("Site {}: {}", i, e))?;
283
284            if site.default {
285                default_sites += 1;
286            }
287
288            // Check for hostname:port conflicts across all hostnames for this site
289            // Allow multiple sites on same port with different hostnames (virtual hosting)
290            for hostname in site.get_all_hostnames() {
291                let hostname_port_key = (hostname, site.port);
292                if used_hostname_ports.contains(&hostname_port_key) {
293                    return Err(format!(
294                        "Duplicate hostname:port combination: {}:{}. Each hostname must be unique per port.",
295                        hostname, site.port
296                    )
297                    .into());
298                }
299                used_hostname_ports.insert(hostname_port_key);
300            }
301        }
302
303        if default_sites == 0 {
304            // This should not happen after post_process, but let's be defensive
305            if self.sites.len() == 1 {
306                // Single site should have been auto-marked as default in post_process
307                return Err(
308                    "Internal error: single site was not marked as default during post-processing"
309                        .into(),
310                );
311            } else {
312                return Err("At least one site must be marked as default when multiple sites are configured".into());
313            }
314        }
315        if default_sites > 1 {
316            return Err("Only one site can be marked as default".into());
317        }
318
319        // Validate that each site has proper SSL configuration if enabled
320        for site in &self.sites {
321            if site.ssl.enabled {
322                if site.ssl.auto_cert {
323                    if let Some(acme) = &site.ssl.acme {
324                        if acme.email.is_empty() {
325                            return Err(format!(
326                                "Site '{}': ACME email is required when auto_cert is enabled",
327                                site.name
328                            )
329                            .into());
330                        }
331                    } else {
332                        return Err(format!(
333                            "Site '{}': ACME configuration is required when auto_cert is enabled",
334                            site.name
335                        )
336                        .into());
337                    }
338                } else if site.ssl.cert_file.is_none() || site.ssl.key_file.is_none() {
339                    return Err(format!(
340                        "Site '{}': Manual SSL requires both cert_file and key_file",
341                        site.name
342                    )
343                    .into());
344                }
345            }
346        }
347
348        // Validate performance configuration
349        self.performance.validate()?;
350
351        // Validate security configuration
352        self.security.validate()?;
353
354        Ok(())
355    }
356
357    fn post_process(&mut self) -> Result<(), Box<dyn std::error::Error>> {
358        // If there's only one site and no site is explicitly marked as default,
359        // automatically make the single site the default
360        if self.sites.len() == 1 && !self.sites[0].default {
361            log::info!(
362                "Automatically setting single site '{}' as default",
363                self.sites[0].name
364            );
365            self.sites[0].default = true;
366        }
367
368        // Sort sites by priority (default site first, then by name)
369        self.sites.sort_by(|a, b| match (a.default, b.default) {
370            (true, false) => std::cmp::Ordering::Less,
371            (false, true) => std::cmp::Ordering::Greater,
372            _ => a.name.cmp(&b.name),
373        });
374
375        // No additional post-processing needed for per-site SSL
376        // Each site manages its own SSL configuration
377
378        Ok(())
379    }
380
381    pub fn find_site_by_host_port(&self, host: &str, port: u16) -> Option<&SiteConfig> {
382        // First try to match both hostname and port exactly using the new method
383        for site in &self.sites {
384            if site.handles_hostname_port(host, port) {
385                return Some(site);
386            }
387        }
388
389        // Then try to match just the port (for cases where hostname might not match exactly)
390        for site in &self.sites {
391            if site.port == port {
392                return Some(site);
393            }
394        }
395
396        // Finally, return the default site if no match
397        self.sites.iter().find(|site| site.default)
398    }
399
400    pub fn get_ssl_domains(&self) -> Vec<String> {
401        self.sites
402            .iter()
403            .filter_map(|site| {
404                if site.ssl.enabled {
405                    Some(
406                        site.get_all_ssl_domains()
407                            .into_iter()
408                            .map(|s| s.to_string())
409                            .collect::<Vec<_>>(),
410                    )
411                } else {
412                    None
413                }
414            })
415            .flatten()
416            .collect()
417    }
418
419    pub fn get_site_by_domain(&self, domain: &str) -> Option<&SiteConfig> {
420        self.sites.iter().find(|site| {
421            site.handles_hostname(domain) || site.ssl.domains.contains(&domain.to_string())
422        })
423    }
424
425    pub fn reload_from_file(&mut self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
426        let new_config = Self::load_from_file(path)?;
427        *self = new_config;
428        log::info!("Configuration reloaded from {}", path);
429        Ok(())
430    }
431}
432
433impl PerformanceConfig {
434    fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
435        if self.worker_threads == 0 {
436            return Err("Worker threads must be greater than 0".into());
437        }
438
439        if self.max_connections == 0 {
440            return Err("Max connections must be greater than 0".into());
441        }
442
443        if self.keep_alive_timeout == 0 {
444            return Err("Keep alive timeout must be greater than 0".into());
445        }
446
447        if self.request_timeout == 0 {
448            return Err("Request timeout must be greater than 0".into());
449        }
450
451        // Validate buffer sizes
452        self.parse_buffer_size(&self.read_buffer_size)
453            .map_err(|_| "Invalid read buffer size format")?;
454        self.parse_buffer_size(&self.write_buffer_size)
455            .map_err(|_| "Invalid write buffer size format")?;
456
457        Ok(())
458    }
459
460    pub fn parse_buffer_size(&self, size_str: &str) -> Result<usize, Box<dyn std::error::Error>> {
461        let size_str = size_str.trim().to_uppercase();
462
463        if let Some(value) = size_str.strip_suffix("KB") {
464            Ok(value.parse::<usize>()? * 1024)
465        } else if let Some(value) = size_str.strip_suffix("MB") {
466            Ok(value.parse::<usize>()? * 1024 * 1024)
467        } else if let Some(value) = size_str.strip_suffix("GB") {
468            Ok(value.parse::<usize>()? * 1024 * 1024 * 1024)
469        } else if let Some(value) = size_str.strip_suffix("B") {
470            Ok(value.parse::<usize>()?)
471        } else {
472            // Assume bytes if no suffix
473            Ok(size_str.parse::<usize>()?)
474        }
475    }
476}
477
478impl SecurityConfig {
479    fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
480        // Validate max request size format
481        self.parse_size(&self.max_request_size)
482            .map_err(|_| "Invalid max request size format")?;
483
484        // Validate rate limiting configuration
485        if let Some(rate_limit) = &self.rate_limiting {
486            if rate_limit.requests_per_minute == 0 {
487                return Err("Rate limit requests per minute must be greater than 0".into());
488            }
489            if rate_limit.burst_size == 0 {
490                return Err("Rate limit burst size must be greater than 0".into());
491            }
492        }
493
494        Ok(())
495    }
496
497    pub fn parse_size(&self, size_str: &str) -> Result<usize, Box<dyn std::error::Error>> {
498        let size_str = size_str.trim().to_uppercase();
499
500        if let Some(value) = size_str.strip_suffix("KB") {
501            Ok(value.parse::<usize>()? * 1024)
502        } else if let Some(value) = size_str.strip_suffix("MB") {
503            Ok(value.parse::<usize>()? * 1024 * 1024)
504        } else if let Some(value) = size_str.strip_suffix("GB") {
505            Ok(value.parse::<usize>()? * 1024 * 1024 * 1024)
506        } else if let Some(value) = size_str.strip_suffix("B") {
507            Ok(value.parse::<usize>()?)
508        } else {
509            // Assume bytes if no suffix
510            Ok(size_str.parse::<usize>()?)
511        }
512    }
513}
514
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::SiteConfig;
    use tempfile::NamedTempFile;

    /// Builds a plain static site with every optional section at its default.
    fn test_site(
        name: &str,
        hostname: &str,
        port: u16,
        static_dir: &str,
        default: bool,
    ) -> SiteConfig {
        SiteConfig {
            name: name.to_string(),
            hostname: hostname.to_string(),
            hostnames: vec![],
            port,
            static_dir: static_dir.to_string(),
            default,
            api_only: false,
            headers: HashMap::new(),
            redirect_to_https: false,
            index_files: vec!["index.html".to_string()],
            error_pages: HashMap::new(),
            compression: Default::default(),
            cache: Default::default(),
            access_control: Default::default(),
            ssl: Default::default(),
            proxy: Default::default(),
        }
    }

    /// Wraps `sites` in a server configuration whose other sections use defaults.
    fn test_server(sites: Vec<SiteConfig>) -> ServerConfig {
        ServerConfig {
            server: ServerInfo {
                name: "test-server".to_string(),
                version: "1.0.0".to_string(),
                description: "Test server".to_string(),
            },
            sites,
            logging: LoggingConfig::default(),
            performance: PerformanceConfig::default(),
            security: SecurityConfig::default(),
            management: ManagementConfig::default(),
        }
    }

    #[test]
    fn test_buffer_size_parsing() {
        let config = PerformanceConfig::default();

        assert_eq!(config.parse_buffer_size("1024").unwrap(), 1024);
        assert_eq!(config.parse_buffer_size("1KB").unwrap(), 1024);
        assert_eq!(config.parse_buffer_size("1MB").unwrap(), 1024 * 1024);
        assert_eq!(config.parse_buffer_size("1GB").unwrap(), 1024 * 1024 * 1024);
        assert_eq!(config.parse_buffer_size("500B").unwrap(), 500);

        assert!(config.parse_buffer_size("invalid").is_err());
        assert!(config.parse_buffer_size("").is_err());
    }

    #[test]
    fn test_automatic_default_site() {
        // A single site without an explicit default=true is promoted to default.
        let mut config = test_server(vec![test_site(
            "single-site",
            "localhost",
            8080,
            "/tmp/static",
            false,
        )]);

        assert!(!config.sites[0].default);
        config.post_process().unwrap();
        assert!(config.sites[0].default);
        assert!(config.validate().is_ok());

        // With multiple sites, no site is promoted automatically.
        config.sites.push(test_site(
            "second-site",
            "example.com",
            8081,
            "/tmp/static2",
            false,
        ));
        config.sites[0].default = false;

        config.post_process().unwrap();
        assert!(config.sites.iter().all(|site| !site.default));

        // No default site among several is a validation error.
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_security_config_validation() {
        let mut config = SecurityConfig::default();
        assert!(config.validate().is_ok());

        config.max_request_size = "invalid".to_string();
        assert!(config.validate().is_err());

        config.max_request_size = "10MB".to_string();
        config.rate_limiting = Some(RateLimitConfig {
            requests_per_minute: 0,
            burst_size: 10,
            whitelist: vec![],
        });
        assert!(config.validate().is_err());
    }

    // Plain synchronous test: nothing here awaits, so no async runtime is needed.
    #[test]
    fn test_config_save_load() {
        let config = test_server(vec![test_site(
            "test-site",
            "localhost",
            8080,
            "/tmp/static",
            true,
        )]);

        let temp_file = NamedTempFile::new().unwrap();
        let path = temp_file.path().to_str().unwrap();

        config.save_to_file(path).unwrap();

        let loaded_config = ServerConfig::load_from_file(path).unwrap();
        assert_eq!(config.server.name, loaded_config.server.name);
        assert_eq!(config.server.version, loaded_config.server.version);
        assert_eq!(config.sites.len(), loaded_config.sites.len());
        assert!(loaded_config.sites[0].default);
    }
}