// bws_web_server/config/site.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::path::Path;
4
5#[derive(Debug, Deserialize, Serialize, Clone)]
6pub struct SiteConfig {
7    pub name: String,
8    pub hostname: String,
9    pub port: u16,
10    pub static_dir: String,
11    #[serde(default)]
12    pub default: bool,
13    #[serde(default)]
14    pub api_only: bool,
15    #[serde(default)]
16    pub headers: HashMap<String, String>,
17    #[serde(default)]
18    pub redirect_to_https: bool,
19    #[serde(default)]
20    pub index_files: Vec<String>,
21    #[serde(default)]
22    pub error_pages: HashMap<u16, String>,
23    #[serde(default)]
24    pub compression: CompressionConfig,
25    #[serde(default)]
26    pub cache: CacheConfig,
27    #[serde(default)]
28    pub access_control: AccessControlConfig,
29    #[serde(default)]
30    pub ssl: SiteSslConfig,
31    #[serde(default)]
32    pub proxy: ProxyConfig,
33}
34
35#[derive(Debug, Deserialize, Serialize, Clone, Default)]
36pub struct SiteSslConfig {
37    #[serde(default)]
38    pub enabled: bool,
39    #[serde(default)]
40    pub auto_cert: bool,
41    #[serde(default)]
42    pub domains: Vec<String>, // Additional domains beyond hostname
43    #[serde(default)]
44    pub cert_file: Option<String>, // Manual certificate file
45    #[serde(default)]
46    pub key_file: Option<String>, // Manual key file
47    #[serde(default)]
48    pub acme: Option<SiteAcmeConfig>,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize, Default)]
52pub struct SiteAcmeConfig {
53    #[serde(default)]
54    pub enabled: bool,
55    #[serde(default)]
56    pub email: String,
57    #[serde(default)]
58    pub staging: bool,
59    #[serde(default)]
60    pub challenge_dir: Option<String>, // Make optional for automatic management
61}
62
63#[derive(Debug, Deserialize, Serialize, Clone, Default)]
64pub struct ProxyConfig {
65    #[serde(default)]
66    pub enabled: bool,
67    #[serde(default)]
68    pub upstreams: Vec<UpstreamConfig>,
69    #[serde(default)]
70    pub routes: Vec<ProxyRoute>,
71    #[serde(default)]
72    pub health_check: HealthCheckConfig,
73    #[serde(default)]
74    pub load_balancing: LoadBalancingConfig,
75    #[serde(default)]
76    pub timeout: TimeoutConfig,
77    #[serde(default)]
78    pub headers: ProxyHeadersConfig,
79}
80
81#[derive(Debug, Deserialize, Serialize, Clone)]
82pub struct UpstreamConfig {
83    pub name: String,
84    pub url: String,
85    #[serde(default = "default_weight")]
86    pub weight: u32,
87    #[serde(default)]
88    pub max_conns: Option<u32>,
89}
90
91#[derive(Debug, Deserialize, Serialize, Clone)]
92pub struct ProxyRoute {
93    pub path: String,
94    pub upstream: String, // References upstream name
95    #[serde(default)]
96    pub strip_prefix: bool,
97    #[serde(default)]
98    pub rewrite_target: Option<String>,
99    #[serde(default)]
100    pub websocket: bool, // Enable WebSocket proxying for this route
101}
102
103#[derive(Debug, Deserialize, Serialize, Clone)]
104pub struct HealthCheckConfig {
105    #[serde(default)]
106    pub enabled: bool,
107    #[serde(default = "default_health_path")]
108    pub path: String,
109    #[serde(default = "default_health_interval")]
110    pub interval: u64, // seconds
111    #[serde(default = "default_health_timeout")]
112    pub timeout: u64, // seconds
113    #[serde(default = "default_health_retries")]
114    pub retries: u32,
115}
116
117#[derive(Debug, Deserialize, Serialize, Clone)]
118pub struct LoadBalancingConfig {
119    #[serde(default = "default_lb_method")]
120    pub method: String, // "round_robin", "least_conn", "weighted"
121    #[serde(default)]
122    pub sticky_sessions: bool,
123}
124
125#[derive(Debug, Deserialize, Serialize, Clone)]
126pub struct TimeoutConfig {
127    #[serde(default = "default_connect_timeout")]
128    pub connect: u64, // seconds
129    #[serde(default = "default_read_timeout")]
130    pub read: u64, // seconds
131    #[serde(default = "default_write_timeout")]
132    pub write: u64, // seconds
133}
134
135#[derive(Debug, Deserialize, Serialize, Clone)]
136pub struct ProxyHeadersConfig {
137    #[serde(default)]
138    pub preserve_host: bool,
139    #[serde(default)]
140    pub add_forwarded: bool,
141    #[serde(default)]
142    pub add_x_forwarded: bool,
143    #[serde(default)]
144    pub remove: Vec<String>,
145    #[serde(default)]
146    pub add: HashMap<String, String>,
147}
148
/// Default upstream weight.
fn default_weight() -> u32 {
    1
}

/// Default health check path.
fn default_health_path() -> String {
    String::from("/health")
}

/// Default health check interval, in seconds.
fn default_health_interval() -> u64 {
    30
}

/// Default health check timeout, in seconds.
fn default_health_timeout() -> u64 {
    5
}

/// Default number of health check retries.
fn default_health_retries() -> u32 {
    3
}

/// Default load balancing method name.
fn default_lb_method() -> String {
    String::from("round_robin")
}

/// Default proxy connect timeout, in seconds.
fn default_connect_timeout() -> u64 {
    10
}

/// Default proxy read timeout, in seconds.
fn default_read_timeout() -> u64 {
    30
}

/// Default proxy write timeout, in seconds.
fn default_write_timeout() -> u64 {
    30
}
176
177impl Default for HealthCheckConfig {
178    fn default() -> Self {
179        Self {
180            enabled: false,
181            path: default_health_path(),
182            interval: default_health_interval(),
183            timeout: default_health_timeout(),
184            retries: default_health_retries(),
185        }
186    }
187}
188
189impl Default for LoadBalancingConfig {
190    fn default() -> Self {
191        Self {
192            method: default_lb_method(),
193            sticky_sessions: false,
194        }
195    }
196}
197
198impl Default for TimeoutConfig {
199    fn default() -> Self {
200        Self {
201            connect: default_connect_timeout(),
202            read: default_read_timeout(),
203            write: default_write_timeout(),
204        }
205    }
206}
207
208impl Default for ProxyHeadersConfig {
209    fn default() -> Self {
210        Self {
211            preserve_host: true,
212            add_forwarded: true,
213            add_x_forwarded: true,
214            remove: Vec::new(),
215            add: HashMap::new(),
216        }
217    }
218}
219
220#[derive(Debug, Deserialize, Serialize, Clone)]
221pub struct CompressionConfig {
222    #[serde(default)]
223    pub enabled: bool,
224    #[serde(default = "default_compression_types")]
225    pub types: Vec<String>,
226    #[serde(default = "default_compression_level")]
227    pub level: u32,
228    #[serde(default = "default_min_size")]
229    pub min_size: usize,
230}
231
232#[derive(Debug, Deserialize, Serialize, Clone)]
233pub struct CacheConfig {
234    #[serde(default)]
235    pub enabled: bool,
236    #[serde(default = "default_cache_control")]
237    pub cache_control: String,
238    #[serde(default)]
239    pub etag_enabled: bool,
240    #[serde(default)]
241    pub last_modified_enabled: bool,
242    #[serde(default)]
243    pub max_age_static: u32,
244    #[serde(default)]
245    pub max_age_dynamic: u32,
246}
247
248#[derive(Debug, Deserialize, Serialize, Clone)]
249pub struct AccessControlConfig {
250    #[serde(default)]
251    pub allow_methods: Vec<String>,
252    #[serde(default)]
253    pub allow_headers: Vec<String>,
254    #[serde(default)]
255    pub allow_origins: Vec<String>,
256    #[serde(default)]
257    pub allow_credentials: bool,
258    #[serde(default)]
259    pub max_age: u32,
260}
261
// Default value functions

/// Content types compressed by default.
fn default_compression_types() -> Vec<String> {
    [
        "text/html",
        "text/css",
        "text/javascript",
        "application/javascript",
        "application/json",
        "text/xml",
        "application/xml",
        "text/plain",
    ]
    .iter()
    .copied()
    .map(String::from)
    .collect()
}

/// Default compression level.
fn default_compression_level() -> u32 {
    6
}

/// Default minimum size for compression: 1KB.
fn default_min_size() -> usize {
    1024
}

/// Default `Cache-Control` value.
fn default_cache_control() -> String {
    String::from("public, max-age=3600")
}
287
288impl Default for CompressionConfig {
289    fn default() -> Self {
290        Self {
291            enabled: true,
292            types: default_compression_types(),
293            level: default_compression_level(),
294            min_size: default_min_size(),
295        }
296    }
297}
298
299impl Default for CacheConfig {
300    fn default() -> Self {
301        Self {
302            enabled: true,
303            cache_control: default_cache_control(),
304            etag_enabled: true,
305            last_modified_enabled: true,
306            max_age_static: 3600, // 1 hour for static files
307            max_age_dynamic: 300, // 5 minutes for dynamic content
308        }
309    }
310}
311
312impl Default for AccessControlConfig {
313    fn default() -> Self {
314        Self {
315            allow_methods: vec!["GET".to_string(), "HEAD".to_string(), "OPTIONS".to_string()],
316            allow_headers: vec!["Content-Type".to_string(), "Authorization".to_string()],
317            allow_origins: vec!["*".to_string()],
318            allow_credentials: false,
319            max_age: 86400, // 24 hours
320        }
321    }
322}
323
324impl SiteConfig {
325    pub fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
326        // Validate required fields
327        if self.name.is_empty() {
328            return Err("Site name cannot be empty".into());
329        }
330
331        if self.hostname.is_empty() {
332            return Err("Site hostname cannot be empty".into());
333        }
334
335        if self.port == 0 {
336            return Err("Site port must be greater than 0".into());
337        }
338
339        if self.static_dir.is_empty() {
340            return Err("Site static_dir cannot be empty".into());
341        }
342
343        // Validate static directory exists (or can be created)
344        let static_path = Path::new(&self.static_dir);
345        if !static_path.exists() {
346            log::warn!(
347                "Static directory does not exist for site '{}': {}",
348                self.name,
349                self.static_dir
350            );
351        }
352
353        // Validate hostname format (basic check)
354        if !self.is_valid_hostname() {
355            return Err(format!("Invalid hostname format: {}", self.hostname).into());
356        }
357
358        // Validate port range
359        if self.port < 1 {
360            return Err(format!("Invalid port number: {}", self.port).into());
361        }
362
363        // Validate SSL configuration
364        if self.ssl.enabled {
365            if self.ssl.auto_cert {
366                if let Some(acme) = &self.ssl.acme {
367                    if acme.email.is_empty() {
368                        return Err("ACME email is required when auto_cert is enabled".into());
369                    }
370                } else {
371                    return Err("ACME configuration is required when auto_cert is enabled".into());
372                }
373            } else if self.ssl.cert_file.is_none() || self.ssl.key_file.is_none() {
374                return Err("Manual SSL requires both cert_file and key_file".into());
375            }
376        }
377
378        // Validate index files
379        for index_file in &self.index_files {
380            if index_file.is_empty() {
381                return Err("Index file name cannot be empty".into());
382            }
383        }
384
385        // Validate error pages
386        for (status_code, error_page) in &self.error_pages {
387            if *status_code < 100 || *status_code > 999 {
388                return Err(format!("Invalid HTTP status code: {}", status_code).into());
389            }
390            if error_page.is_empty() {
391                return Err(
392                    format!("Error page path cannot be empty for status {}", status_code).into(),
393                );
394            }
395        }
396
397        // Validate compression configuration
398        self.compression.validate()?;
399
400        // Validate cache configuration
401        self.cache.validate()?;
402
403        // Validate access control configuration
404        self.access_control.validate()?;
405
406        Ok(())
407    }
408
409    fn is_valid_hostname(&self) -> bool {
410        // Basic hostname validation
411        if self.hostname.is_empty() || self.hostname.len() > 253 {
412            return false;
413        }
414
415        // Allow localhost and IP addresses for development
416        if self.hostname == "localhost"
417            || self.hostname.starts_with("127.")
418            || self.hostname.starts_with("0.0.0.0")
419        {
420            return true;
421        }
422
423        // Basic domain name validation
424        self.hostname.split('.').all(|label| {
425            !label.is_empty()
426                && label.len() <= 63
427                && label.chars().all(|c| c.is_alphanumeric() || c == '-')
428                && !label.starts_with('-')
429                && !label.ends_with('-')
430        })
431    }
432
433    pub fn get_ssl_domain(&self) -> Option<&str> {
434        if self.ssl.enabled {
435            // Return primary domain (hostname) plus any additional domains
436            Some(&self.hostname)
437        } else {
438            None
439        }
440    }
441
442    pub fn get_all_ssl_domains(&self) -> Vec<&str> {
443        if self.ssl.enabled {
444            let mut domains = vec![self.hostname.as_str()];
445            for domain in &self.ssl.domains {
446                domains.push(domain.as_str());
447            }
448            domains
449        } else {
450            Vec::new()
451        }
452    }
453
454    pub fn is_ssl_enabled(&self) -> bool {
455        self.ssl.enabled
456    }
457
458    pub fn get_index_files(&self) -> Vec<&str> {
459        if self.index_files.is_empty() {
460            vec!["index.html", "index.htm"]
461        } else {
462            self.index_files.iter().map(|s| s.as_str()).collect()
463        }
464    }
465
466    pub fn get_error_page(&self, status_code: u16) -> Option<&str> {
467        self.error_pages.get(&status_code).map(|s| s.as_str())
468    }
469
470    pub fn should_compress(&self, content_type: &str, content_length: usize) -> bool {
471        self.compression.enabled
472            && content_length >= self.compression.min_size
473            && self
474                .compression
475                .types
476                .iter()
477                .any(|t| content_type.starts_with(t))
478    }
479
480    pub fn get_cache_headers(&self, is_static: bool) -> Vec<(String, String)> {
481        let mut headers = Vec::new();
482
483        if self.cache.enabled {
484            let max_age = if is_static {
485                self.cache.max_age_static
486            } else {
487                self.cache.max_age_dynamic
488            };
489
490            headers.push((
491                "Cache-Control".to_string(),
492                format!("public, max-age={}", max_age),
493            ));
494
495            if self.cache.etag_enabled {
496                // ETag would be calculated based on file content
497                // This is a placeholder - actual implementation would calculate based on file
498                headers.push(("ETag".to_string(), "\"placeholder\"".to_string()));
499            }
500
501            if self.cache.last_modified_enabled {
502                // Last-Modified would be based on file modification time
503                // This is a placeholder - actual implementation would use file mtime
504                headers.push(("Last-Modified".to_string(), "placeholder".to_string()));
505            }
506        }
507
508        headers
509    }
510
511    pub fn get_cors_headers(&self) -> Vec<(String, String)> {
512        let mut headers = Vec::new();
513
514        if !self.access_control.allow_origins.is_empty() {
515            headers.push((
516                "Access-Control-Allow-Origin".to_string(),
517                self.access_control.allow_origins.join(", "),
518            ));
519        }
520
521        if !self.access_control.allow_methods.is_empty() {
522            headers.push((
523                "Access-Control-Allow-Methods".to_string(),
524                self.access_control.allow_methods.join(", "),
525            ));
526        }
527
528        if !self.access_control.allow_headers.is_empty() {
529            headers.push((
530                "Access-Control-Allow-Headers".to_string(),
531                self.access_control.allow_headers.join(", "),
532            ));
533        }
534
535        if self.access_control.allow_credentials {
536            headers.push((
537                "Access-Control-Allow-Credentials".to_string(),
538                "true".to_string(),
539            ));
540        }
541
542        if self.access_control.max_age > 0 {
543            headers.push((
544                "Access-Control-Max-Age".to_string(),
545                self.access_control.max_age.to_string(),
546            ));
547        }
548
549        headers
550    }
551
552    pub fn url(&self) -> String {
553        let protocol = if self.is_ssl_enabled() {
554            "https"
555        } else {
556            "http"
557        };
558        let port_suffix = match (self.is_ssl_enabled(), self.port) {
559            (true, 443) | (false, 80) => String::new(),
560            _ => format!(":{}", self.port),
561        };
562        format!("{}://{}{}", protocol, self.hostname, port_suffix)
563    }
564}
565
566impl CompressionConfig {
567    fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
568        if self.level > 9 {
569            return Err("Compression level must be between 0 and 9".into());
570        }
571
572        if self.types.is_empty() && self.enabled {
573            return Err("Compression types cannot be empty when compression is enabled".into());
574        }
575
576        Ok(())
577    }
578}
579
580impl CacheConfig {
581    fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
582        // Cache configuration is generally permissive
583        // Just ensure max_age values are reasonable
584        if self.max_age_static > 365 * 24 * 3600 {
585            log::warn!("Static cache max_age is very large (> 1 year)");
586        }
587
588        if self.max_age_dynamic > 24 * 3600 {
589            log::warn!("Dynamic cache max_age is very large (> 1 day)");
590        }
591
592        Ok(())
593    }
594}
595
596impl AccessControlConfig {
597    fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
598        // Validate HTTP methods
599        let valid_methods = [
600            "GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS", "PATCH", "TRACE", "CONNECT",
601        ];
602
603        for method in &self.allow_methods {
604            if !valid_methods.contains(&method.as_str()) {
605                return Err(format!("Invalid HTTP method: {}", method).into());
606            }
607        }
608
609        // Max age should be reasonable
610        if self.max_age > 7 * 24 * 3600 {
611            log::warn!("CORS max_age is very large (> 1 week)");
612        }
613
614        Ok(())
615    }
616}
617
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a plain-HTTP site on port 8080 serving `/tmp`, with every optional
    /// section left at its default, so each test only spells out what it changes.
    fn site_with_hostname(hostname: &str) -> SiteConfig {
        SiteConfig {
            name: "test".to_string(),
            hostname: hostname.to_string(),
            port: 8080,
            static_dir: "/tmp".to_string(),
            default: false,
            api_only: false,
            headers: HashMap::new(),
            redirect_to_https: false,
            index_files: vec![],
            error_pages: HashMap::new(),
            compression: CompressionConfig::default(),
            cache: CacheConfig::default(),
            access_control: AccessControlConfig::default(),
            ssl: SiteSslConfig::default(),
            proxy: ProxyConfig::default(),
        }
    }

    #[test]
    fn test_site_config_validation() {
        let mut site = site_with_hostname("example.com");
        assert!(site.validate().is_ok());

        // An empty hostname is rejected.
        site.hostname = String::new();
        assert!(site.validate().is_err());

        // Port 0 is rejected.
        site.hostname = "example.com".to_string();
        site.port = 0;
        assert!(site.validate().is_err());
    }

    #[test]
    fn test_hostname_validation() {
        let site = site_with_hostname("localhost");
        assert!(site.is_valid_hostname());

        // Empty label between two dots.
        let mut invalid_site = site.clone();
        invalid_site.hostname = "invalid..hostname".to_string();
        assert!(!invalid_site.is_valid_hostname());

        // Label starting with a hyphen.
        invalid_site.hostname = "-invalid.hostname".to_string();
        assert!(!invalid_site.is_valid_hostname());
    }

    #[test]
    fn test_compression_config() {
        let site = site_with_hostname("example.com");

        assert!(site.should_compress("text/html", 2048));
        assert!(!site.should_compress("text/html", 512)); // Below min_size
        assert!(!site.should_compress("image/png", 2048)); // Not in types list
    }

    #[test]
    fn test_site_url_generation() {
        let mut site = site_with_hostname("example.com");
        assert_eq!(site.url(), "http://example.com:8080");

        // The standard HTTPS port is omitted from the URL.
        site.ssl.enabled = true;
        site.port = 443;
        assert_eq!(site.url(), "https://example.com");

        site.port = 8443;
        assert_eq!(site.url(), "https://example.com:8443");
    }
}