bws_web_server/config/
server.rs

1use crate::config::SiteConfig;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::fs;
5use std::path::Path;
6
7#[derive(Debug, Deserialize, Serialize, Clone)]
8pub struct ServerConfig {
9    pub server: ServerInfo,
10    pub sites: Vec<SiteConfig>,
11    #[serde(default)]
12    pub logging: LoggingConfig,
13    #[serde(default)]
14    pub performance: PerformanceConfig,
15    #[serde(default)]
16    pub security: SecurityConfig,
17}
18
19#[derive(Debug, Deserialize, Serialize, Clone)]
20pub struct ServerInfo {
21    pub name: String,
22    #[serde(default = "default_version")]
23    pub version: String,
24    #[serde(default)]
25    pub description: String,
26}
27
28#[derive(Debug, Deserialize, Serialize, Clone)]
29pub struct LoggingConfig {
30    #[serde(default = "default_log_level")]
31    pub level: String,
32    #[serde(default)]
33    pub access_log: Option<String>,
34    #[serde(default)]
35    pub error_log: Option<String>,
36    #[serde(default = "default_log_format")]
37    pub format: String,
38    #[serde(default)]
39    pub log_requests: bool,
40}
41
42#[derive(Debug, Deserialize, Serialize, Clone)]
43pub struct PerformanceConfig {
44    #[serde(default = "default_worker_threads")]
45    pub worker_threads: usize,
46    #[serde(default = "default_max_connections")]
47    pub max_connections: usize,
48    #[serde(default = "default_keep_alive_timeout")]
49    pub keep_alive_timeout: u64,
50    #[serde(default = "default_request_timeout")]
51    pub request_timeout: u64,
52    #[serde(default = "default_buffer_size")]
53    pub read_buffer_size: String,
54    #[serde(default = "default_buffer_size")]
55    pub write_buffer_size: String,
56}
57
58#[derive(Debug, Deserialize, Serialize, Clone)]
59pub struct SecurityConfig {
60    #[serde(default)]
61    pub hide_server_header: bool,
62    #[serde(default = "default_max_request_size")]
63    pub max_request_size: String,
64    #[serde(default)]
65    pub allowed_origins: Vec<String>,
66    #[serde(default)]
67    pub security_headers: HashMap<String, String>,
68    #[serde(default)]
69    pub rate_limiting: Option<RateLimitConfig>,
70}
71
72#[derive(Debug, Deserialize, Serialize, Clone)]
73pub struct RateLimitConfig {
74    pub requests_per_minute: u32,
75    pub burst_size: u32,
76    #[serde(default)]
77    pub whitelist: Vec<String>,
78}
79
80// Default value functions
81fn default_version() -> String {
82    env!("CARGO_PKG_VERSION").to_string()
83}
84
/// Fallback log level used by `LoggingConfig`.
fn default_log_level() -> String {
    String::from("info")
}
88
/// Fallback log format name used by `LoggingConfig`.
fn default_log_format() -> String {
    String::from("combined")
}
92
93fn default_worker_threads() -> usize {
94    num_cpus::get().max(1)
95}
96
/// Fallback connection cap used by `PerformanceConfig`.
fn default_max_connections() -> usize {
    1_000
}
100
/// Fallback keep-alive timeout used by `PerformanceConfig` (unit presumably
/// seconds — confirm with the consumer).
fn default_keep_alive_timeout() -> u64 {
    60_u64
}
104
/// Fallback request timeout used by `PerformanceConfig` (unit presumably
/// seconds — confirm with the consumer).
fn default_request_timeout() -> u64 {
    30_u64
}
108
/// Fallback read/write buffer size string used by `PerformanceConfig`.
fn default_buffer_size() -> String {
    "32KB".to_owned()
}
112
/// Fallback maximum request size string used by `SecurityConfig`.
fn default_max_request_size() -> String {
    "10MB".to_owned()
}
116
117impl Default for LoggingConfig {
118    fn default() -> Self {
119        Self {
120            level: default_log_level(),
121            access_log: None,
122            error_log: None,
123            format: default_log_format(),
124            log_requests: true,
125        }
126    }
127}
128
129impl Default for PerformanceConfig {
130    fn default() -> Self {
131        Self {
132            worker_threads: default_worker_threads(),
133            max_connections: default_max_connections(),
134            keep_alive_timeout: default_keep_alive_timeout(),
135            request_timeout: default_request_timeout(),
136            read_buffer_size: default_buffer_size(),
137            write_buffer_size: default_buffer_size(),
138        }
139    }
140}
141
142impl Default for SecurityConfig {
143    fn default() -> Self {
144        let mut security_headers = HashMap::new();
145        security_headers.insert("X-Frame-Options".to_string(), "DENY".to_string());
146        security_headers.insert("X-Content-Type-Options".to_string(), "nosniff".to_string());
147        security_headers.insert("X-XSS-Protection".to_string(), "1; mode=block".to_string());
148        security_headers.insert(
149            "Referrer-Policy".to_string(),
150            "strict-origin-when-cross-origin".to_string(),
151        );
152
153        Self {
154            hide_server_header: false,
155            max_request_size: default_max_request_size(),
156            allowed_origins: vec![],
157            security_headers,
158            rate_limiting: None,
159        }
160    }
161}
162
163impl ServerConfig {
164    pub fn load_from_file(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
165        let content = fs::read_to_string(path)?;
166        let mut config: ServerConfig = toml::from_str(&content)?;
167
168        // Validate configuration
169        config.validate()?;
170
171        // Post-process configuration
172        config.post_process()?;
173
174        Ok(config)
175    }
176
177    pub fn save_to_file(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
178        let content = toml::to_string_pretty(self)?;
179
180        // Ensure parent directory exists
181        if let Some(parent) = Path::new(path).parent() {
182            fs::create_dir_all(parent)?;
183        }
184
185        fs::write(path, content)?;
186        log::info!("Configuration saved to {}", path);
187        Ok(())
188    }
189
190    pub fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
191        // Validate server info
192        if self.server.name.is_empty() {
193            return Err("Server name cannot be empty".into());
194        }
195
196        // Validate sites
197        if self.sites.is_empty() {
198            return Err("At least one site must be configured".into());
199        }
200
201        let mut default_sites = 0;
202        let mut used_ports = std::collections::HashSet::new();
203
204        for (i, site) in self.sites.iter().enumerate() {
205            site.validate().map_err(|e| format!("Site {}: {}", i, e))?;
206
207            if site.default {
208                default_sites += 1;
209            }
210
211            // Check for port conflicts (same hostname and port)
212            let port_key = (&site.hostname, site.port);
213            if used_ports.contains(&port_key) {
214                return Err(format!(
215                    "Duplicate hostname:port combination: {}:{}",
216                    site.hostname, site.port
217                )
218                .into());
219            }
220            used_ports.insert(port_key);
221        }
222
223        if default_sites == 0 {
224            return Err("At least one site must be marked as default".into());
225        }
226        if default_sites > 1 {
227            return Err("Only one site can be marked as default".into());
228        }
229
230        // Validate that each site has proper SSL configuration if enabled
231        for site in &self.sites {
232            if site.ssl.enabled {
233                if site.ssl.auto_cert {
234                    if let Some(acme) = &site.ssl.acme {
235                        if acme.email.is_empty() {
236                            return Err(format!(
237                                "Site '{}': ACME email is required when auto_cert is enabled",
238                                site.name
239                            )
240                            .into());
241                        }
242                    } else {
243                        return Err(format!(
244                            "Site '{}': ACME configuration is required when auto_cert is enabled",
245                            site.name
246                        )
247                        .into());
248                    }
249                } else if site.ssl.cert_file.is_none() || site.ssl.key_file.is_none() {
250                    return Err(format!(
251                        "Site '{}': Manual SSL requires both cert_file and key_file",
252                        site.name
253                    )
254                    .into());
255                }
256            }
257        }
258
259        // Validate performance configuration
260        self.performance.validate()?;
261
262        // Validate security configuration
263        self.security.validate()?;
264
265        Ok(())
266    }
267
268    fn post_process(&mut self) -> Result<(), Box<dyn std::error::Error>> {
269        // Sort sites by priority (default site first, then by name)
270        self.sites.sort_by(|a, b| match (a.default, b.default) {
271            (true, false) => std::cmp::Ordering::Less,
272            (false, true) => std::cmp::Ordering::Greater,
273            _ => a.name.cmp(&b.name),
274        });
275
276        // No additional post-processing needed for per-site SSL
277        // Each site manages its own SSL configuration
278
279        Ok(())
280    }
281
282    pub fn find_site_by_host_port(&self, host: &str, port: u16) -> Option<&SiteConfig> {
283        // First try to match both hostname and port exactly
284        for site in &self.sites {
285            if site.hostname == host && site.port == port {
286                return Some(site);
287            }
288        }
289
290        // Then try to match just the port (for cases where hostname might not match exactly)
291        for site in &self.sites {
292            if site.port == port {
293                return Some(site);
294            }
295        }
296
297        // Finally, return the default site if no match
298        self.sites.iter().find(|site| site.default)
299    }
300
301    pub fn get_ssl_domains(&self) -> Vec<String> {
302        self.sites
303            .iter()
304            .filter_map(|site| {
305                if site.ssl.enabled {
306                    let mut domains = vec![site.hostname.clone()];
307                    domains.extend(site.ssl.domains.clone());
308                    Some(domains)
309                } else {
310                    None
311                }
312            })
313            .flatten()
314            .collect()
315    }
316
317    pub fn get_site_by_domain(&self, domain: &str) -> Option<&SiteConfig> {
318        self.sites
319            .iter()
320            .find(|site| site.hostname == domain || site.ssl.domains.contains(&domain.to_string()))
321    }
322
323    pub fn reload_from_file(&mut self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
324        let new_config = Self::load_from_file(path)?;
325        *self = new_config;
326        log::info!("Configuration reloaded from {}", path);
327        Ok(())
328    }
329}
330
331impl PerformanceConfig {
332    fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
333        if self.worker_threads == 0 {
334            return Err("Worker threads must be greater than 0".into());
335        }
336
337        if self.max_connections == 0 {
338            return Err("Max connections must be greater than 0".into());
339        }
340
341        if self.keep_alive_timeout == 0 {
342            return Err("Keep alive timeout must be greater than 0".into());
343        }
344
345        if self.request_timeout == 0 {
346            return Err("Request timeout must be greater than 0".into());
347        }
348
349        // Validate buffer sizes
350        self.parse_buffer_size(&self.read_buffer_size)
351            .map_err(|_| "Invalid read buffer size format")?;
352        self.parse_buffer_size(&self.write_buffer_size)
353            .map_err(|_| "Invalid write buffer size format")?;
354
355        Ok(())
356    }
357
358    pub fn parse_buffer_size(&self, size_str: &str) -> Result<usize, Box<dyn std::error::Error>> {
359        let size_str = size_str.trim().to_uppercase();
360
361        if let Some(value) = size_str.strip_suffix("KB") {
362            Ok(value.parse::<usize>()? * 1024)
363        } else if let Some(value) = size_str.strip_suffix("MB") {
364            Ok(value.parse::<usize>()? * 1024 * 1024)
365        } else if let Some(value) = size_str.strip_suffix("GB") {
366            Ok(value.parse::<usize>()? * 1024 * 1024 * 1024)
367        } else if let Some(value) = size_str.strip_suffix("B") {
368            Ok(value.parse::<usize>()?)
369        } else {
370            // Assume bytes if no suffix
371            Ok(size_str.parse::<usize>()?)
372        }
373    }
374}
375
376impl SecurityConfig {
377    fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
378        // Validate max request size format
379        self.parse_size(&self.max_request_size)
380            .map_err(|_| "Invalid max request size format")?;
381
382        // Validate rate limiting configuration
383        if let Some(rate_limit) = &self.rate_limiting {
384            if rate_limit.requests_per_minute == 0 {
385                return Err("Rate limit requests per minute must be greater than 0".into());
386            }
387            if rate_limit.burst_size == 0 {
388                return Err("Rate limit burst size must be greater than 0".into());
389            }
390        }
391
392        Ok(())
393    }
394
395    pub fn parse_size(&self, size_str: &str) -> Result<usize, Box<dyn std::error::Error>> {
396        let size_str = size_str.trim().to_uppercase();
397
398        if let Some(value) = size_str.strip_suffix("KB") {
399            Ok(value.parse::<usize>()? * 1024)
400        } else if let Some(value) = size_str.strip_suffix("MB") {
401            Ok(value.parse::<usize>()? * 1024 * 1024)
402        } else if let Some(value) = size_str.strip_suffix("GB") {
403            Ok(value.parse::<usize>()? * 1024 * 1024 * 1024)
404        } else if let Some(value) = size_str.strip_suffix("B") {
405            Ok(value.parse::<usize>()?)
406        } else {
407            // Assume bytes if no suffix
408            Ok(size_str.parse::<usize>()?)
409        }
410    }
411}
412
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    #[test]
    fn test_buffer_size_parsing() {
        let config = PerformanceConfig::default();

        assert_eq!(config.parse_buffer_size("1024").unwrap(), 1024);
        assert_eq!(config.parse_buffer_size("1KB").unwrap(), 1024);
        assert_eq!(config.parse_buffer_size("1MB").unwrap(), 1024 * 1024);
        assert_eq!(config.parse_buffer_size("1GB").unwrap(), 1024 * 1024 * 1024);
        assert_eq!(config.parse_buffer_size("500B").unwrap(), 500);

        assert!(config.parse_buffer_size("invalid").is_err());
        assert!(config.parse_buffer_size("").is_err());
    }

    #[test]
    fn test_security_config_validation() {
        let mut config = SecurityConfig::default();
        assert!(config.validate().is_ok());

        config.max_request_size = "invalid".to_string();
        assert!(config.validate().is_err());

        config.max_request_size = "10MB".to_string();
        config.rate_limiting = Some(RateLimitConfig {
            requests_per_minute: 0,
            burst_size: 10,
            whitelist: vec![],
        });
        assert!(config.validate().is_err());
    }

    /// The built-in defaults must always pass their own validation.
    #[test]
    fn test_default_sections_are_valid() {
        assert!(PerformanceConfig::default().validate().is_ok());
        assert!(SecurityConfig::default().validate().is_ok());
        assert!(LoggingConfig::default().log_requests);
    }

    /// Round-trips a minimal configuration through `save_to_file` / `load_from_file`.
    ///
    /// Nothing here is asynchronous, so a plain `#[test]` is used rather than
    /// spinning up an async runtime.
    #[test]
    fn test_config_save_load() {
        let config = ServerConfig {
            server: ServerInfo {
                name: "test-server".to_string(),
                version: "1.0.0".to_string(),
                description: "Test server".to_string(),
            },
            sites: vec![SiteConfig {
                name: "test-site".to_string(),
                hostname: "localhost".to_string(),
                port: 8080,
                static_dir: "/tmp/static".to_string(),
                default: true,
                api_only: false,
                headers: HashMap::new(),
                redirect_to_https: false,
                index_files: vec!["index.html".to_string()],
                error_pages: HashMap::new(),
                compression: Default::default(),
                cache: Default::default(),
                access_control: Default::default(),
                ssl: Default::default(),
                proxy: Default::default(),
            }],
            logging: LoggingConfig::default(),
            performance: PerformanceConfig::default(),
            security: SecurityConfig::default(),
        };

        let temp_file = NamedTempFile::new().unwrap();
        let path = temp_file
            .path()
            .to_str()
            .expect("temporary file path should be valid UTF-8");

        config.save_to_file(path).unwrap();
        let loaded_config = ServerConfig::load_from_file(path).unwrap();

        assert_eq!(config.server.name, loaded_config.server.name);
        assert_eq!(config.server.version, loaded_config.server.version);
        assert_eq!(config.sites.len(), loaded_config.sites.len());
        assert_eq!(loaded_config.sites[0].hostname, "localhost");
        assert_eq!(loaded_config.sites[0].port, 8080);
        assert!(loaded_config.sites[0].default);
        assert_eq!(
            config.logging.log_requests,
            loaded_config.logging.log_requests
        );
        assert_eq!(
            config.security.security_headers,
            loaded_config.security.security_headers
        );
    }
}