bws_web_server/config/
site.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::path::Path;
4
/// Configuration for one site (virtual host) served by the web server.
///
/// `name`, `hostname`, `port` and `static_dir` have no serde default and must be present
/// when deserializing; every other field falls back to its type's `Default` when omitted.
/// Call [`SiteConfig::validate`] after loading to check the values.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct SiteConfig {
    /// Site name, used in log output; `validate` rejects an empty name.
    pub name: String,
    /// Host name of the site; `validate` checks it with a basic hostname syntax check.
    pub hostname: String,
    /// Listening port; `validate` rejects 0.
    pub port: u16,
    /// Directory holding the site's static files; `validate` only warns if it is missing.
    pub static_dir: String,
    /// NOTE(review): presumably marks the fallback site for unmatched hosts — confirm.
    #[serde(default)]
    pub default: bool,
    /// NOTE(review): presumably disables static file serving for this site — confirm.
    #[serde(default)]
    pub api_only: bool,
    /// Extra header name/value pairs (presumably added to every response — confirm).
    #[serde(default)]
    pub headers: HashMap<String, String>,
    /// NOTE(review): presumably redirects plain-HTTP requests to HTTPS — confirm.
    #[serde(default)]
    pub redirect_to_https: bool,
    /// Index file names to try; when empty, `get_index_files` uses `index.html`, `index.htm`.
    #[serde(default)]
    pub index_files: Vec<String>,
    /// Error page paths keyed by HTTP status; `validate` requires codes in 100..=999 and
    /// non-empty paths.
    #[serde(default)]
    pub error_pages: HashMap<u16, String>,
    /// Response compression settings.
    #[serde(default)]
    pub compression: CompressionConfig,
    /// HTTP caching settings.
    #[serde(default)]
    pub cache: CacheConfig,
    /// CORS settings.
    #[serde(default)]
    pub access_control: AccessControlConfig,
    /// TLS settings.
    #[serde(default)]
    pub ssl: SiteSslConfig,
    /// Reverse-proxy settings.
    #[serde(default)]
    pub proxy: ProxyConfig,
}
34
35#[derive(Debug, Deserialize, Serialize, Clone, Default)]
36pub struct SiteSslConfig {
37    #[serde(default)]
38    pub enabled: bool,
39    #[serde(default)]
40    pub auto_cert: bool,
41    #[serde(default)]
42    pub domains: Vec<String>, // Additional domains beyond hostname
43    #[serde(default)]
44    pub cert_file: Option<String>, // Manual certificate file
45    #[serde(default)]
46    pub key_file: Option<String>, // Manual key file
47    #[serde(default)]
48    pub acme: Option<SiteAcmeConfig>,
49}
50
51#[derive(Debug, Deserialize, Serialize, Clone)]
52pub struct SiteAcmeConfig {
53    #[serde(default)]
54    pub enabled: bool,
55    #[serde(default)]
56    pub email: String,
57    #[serde(default)]
58    pub staging: bool,
59    #[serde(default)]
60    pub challenge_dir: String,
61}
62
63#[derive(Debug, Deserialize, Serialize, Clone, Default)]
64pub struct ProxyConfig {
65    #[serde(default)]
66    pub enabled: bool,
67    #[serde(default)]
68    pub upstreams: Vec<UpstreamConfig>,
69    #[serde(default)]
70    pub routes: Vec<ProxyRoute>,
71    #[serde(default)]
72    pub health_check: HealthCheckConfig,
73    #[serde(default)]
74    pub load_balancing: LoadBalancingConfig,
75    #[serde(default)]
76    pub timeout: TimeoutConfig,
77    #[serde(default)]
78    pub headers: ProxyHeadersConfig,
79}
80
/// A named backend that proxy routes can forward to.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct UpstreamConfig {
    /// Name that `ProxyRoute::upstream` refers to.
    pub name: String,
    /// Backend URL (its format is not validated in this module).
    pub url: String,
    /// Relative weight, 1 when omitted; presumably used by the "weighted" method — confirm.
    #[serde(default = "default_weight")]
    pub weight: u32,
    /// Optional connection limit, `None` when omitted (presumably concurrent connections — confirm).
    #[serde(default)]
    pub max_conns: Option<u32>,
}
90
/// Maps a request path to a named upstream.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ProxyRoute {
    /// Request path this route applies to (presumably a prefix match — confirm in the proxy).
    pub path: String,
    pub upstream: String, // References upstream name
    /// NOTE(review): presumably strips `path` from the forwarded URI — confirm.
    #[serde(default)]
    pub strip_prefix: bool,
    /// Optional rewrite target for the forwarded request; `None` when omitted.
    #[serde(default)]
    pub rewrite_target: Option<String>,
    #[serde(default)]
    pub websocket: bool, // Enable WebSocket proxying for this route
}
102
103#[derive(Debug, Deserialize, Serialize, Clone)]
104pub struct HealthCheckConfig {
105    #[serde(default)]
106    pub enabled: bool,
107    #[serde(default = "default_health_path")]
108    pub path: String,
109    #[serde(default = "default_health_interval")]
110    pub interval: u64, // seconds
111    #[serde(default = "default_health_timeout")]
112    pub timeout: u64, // seconds
113    #[serde(default = "default_health_retries")]
114    pub retries: u32,
115}
116
117#[derive(Debug, Deserialize, Serialize, Clone)]
118pub struct LoadBalancingConfig {
119    #[serde(default = "default_lb_method")]
120    pub method: String, // "round_robin", "least_conn", "weighted"
121    #[serde(default)]
122    pub sticky_sessions: bool,
123}
124
125#[derive(Debug, Deserialize, Serialize, Clone)]
126pub struct TimeoutConfig {
127    #[serde(default = "default_connect_timeout")]
128    pub connect: u64, // seconds
129    #[serde(default = "default_read_timeout")]
130    pub read: u64, // seconds
131    #[serde(default = "default_write_timeout")]
132    pub write: u64, // seconds
133}
134
135#[derive(Debug, Deserialize, Serialize, Clone)]
136pub struct ProxyHeadersConfig {
137    #[serde(default)]
138    pub preserve_host: bool,
139    #[serde(default)]
140    pub add_forwarded: bool,
141    #[serde(default)]
142    pub add_x_forwarded: bool,
143    #[serde(default)]
144    pub remove: Vec<String>,
145    #[serde(default)]
146    pub add: HashMap<String, String>,
147}
148
/// Default relative weight of an upstream.
fn default_weight() -> u32 {
    1
}

/// Default health-check request path.
fn default_health_path() -> String {
    String::from("/health")
}

/// Default interval between health checks, in seconds.
fn default_health_interval() -> u64 {
    30
}

/// Default health-check timeout, in seconds.
fn default_health_timeout() -> u64 {
    5
}

/// Default number of health-check retries.
fn default_health_retries() -> u32 {
    3
}

/// Default load-balancing method.
fn default_lb_method() -> String {
    String::from("round_robin")
}

/// Default upstream connect timeout, in seconds.
fn default_connect_timeout() -> u64 {
    10
}

/// Default upstream read timeout, in seconds.
fn default_read_timeout() -> u64 {
    30
}

/// Default upstream write timeout, in seconds.
fn default_write_timeout() -> u64 {
    30
}
176
177impl Default for SiteAcmeConfig {
178    fn default() -> Self {
179        Self {
180            enabled: false,
181            email: String::new(),
182            staging: false,
183            challenge_dir: "./acme-challenges".to_string(),
184        }
185    }
186}
187
188impl Default for HealthCheckConfig {
189    fn default() -> Self {
190        Self {
191            enabled: false,
192            path: default_health_path(),
193            interval: default_health_interval(),
194            timeout: default_health_timeout(),
195            retries: default_health_retries(),
196        }
197    }
198}
199
200impl Default for LoadBalancingConfig {
201    fn default() -> Self {
202        Self {
203            method: default_lb_method(),
204            sticky_sessions: false,
205        }
206    }
207}
208
209impl Default for TimeoutConfig {
210    fn default() -> Self {
211        Self {
212            connect: default_connect_timeout(),
213            read: default_read_timeout(),
214            write: default_write_timeout(),
215        }
216    }
217}
218
219impl Default for ProxyHeadersConfig {
220    fn default() -> Self {
221        Self {
222            preserve_host: true,
223            add_forwarded: true,
224            add_x_forwarded: true,
225            remove: Vec::new(),
226            add: HashMap::new(),
227        }
228    }
229}
230
231#[derive(Debug, Deserialize, Serialize, Clone)]
232pub struct CompressionConfig {
233    #[serde(default)]
234    pub enabled: bool,
235    #[serde(default = "default_compression_types")]
236    pub types: Vec<String>,
237    #[serde(default = "default_compression_level")]
238    pub level: u32,
239    #[serde(default = "default_min_size")]
240    pub min_size: usize,
241}
242
243#[derive(Debug, Deserialize, Serialize, Clone)]
244pub struct CacheConfig {
245    #[serde(default)]
246    pub enabled: bool,
247    #[serde(default = "default_cache_control")]
248    pub cache_control: String,
249    #[serde(default)]
250    pub etag_enabled: bool,
251    #[serde(default)]
252    pub last_modified_enabled: bool,
253    #[serde(default)]
254    pub max_age_static: u32,
255    #[serde(default)]
256    pub max_age_dynamic: u32,
257}
258
/// CORS settings for a site.
///
/// NOTE(review): keys missing inside a present `access_control` section use the field
/// types' defaults (empty lists, `max_age = 0`), not the values from
/// `AccessControlConfig::default` (GET/HEAD/OPTIONS, origin `*`, 86400). Aligning them via
/// a container-level `#[serde(default)]` would widen `allow_origins` to `*` for such
/// configs, so confirm the intended behaviour before changing it.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct AccessControlConfig {
    /// Values for `Access-Control-Allow-Methods`; validation accepts standard method names only.
    #[serde(default)]
    pub allow_methods: Vec<String>,
    /// Values for `Access-Control-Allow-Headers`.
    #[serde(default)]
    pub allow_headers: Vec<String>,
    /// Values for `Access-Control-Allow-Origin` (joined with ", " when emitted).
    #[serde(default)]
    pub allow_origins: Vec<String>,
    /// Emit `Access-Control-Allow-Credentials: true` when set.
    #[serde(default)]
    pub allow_credentials: bool,
    /// `Access-Control-Max-Age` in seconds; 0 omits the header.
    #[serde(default)]
    pub max_age: u32,
}
272
// Default values for the compression and cache settings.

/// Content types compressed by default (matched as content-type prefixes).
fn default_compression_types() -> Vec<String> {
    [
        "text/html",
        "text/css",
        "text/javascript",
        "application/javascript",
        "application/json",
        "text/xml",
        "application/xml",
        "text/plain",
    ]
    .iter()
    .map(|mime| mime.to_string())
    .collect()
}

/// Default compression level.
fn default_compression_level() -> u32 {
    6
}

/// Default minimum content length, in bytes, before compressing (1 KiB).
fn default_min_size() -> usize {
    1 << 10
}

/// Default `Cache-Control` directive.
fn default_cache_control() -> String {
    String::from("public, max-age=3600")
}
298
299impl Default for CompressionConfig {
300    fn default() -> Self {
301        Self {
302            enabled: true,
303            types: default_compression_types(),
304            level: default_compression_level(),
305            min_size: default_min_size(),
306        }
307    }
308}
309
310impl Default for CacheConfig {
311    fn default() -> Self {
312        Self {
313            enabled: true,
314            cache_control: default_cache_control(),
315            etag_enabled: true,
316            last_modified_enabled: true,
317            max_age_static: 3600, // 1 hour for static files
318            max_age_dynamic: 300, // 5 minutes for dynamic content
319        }
320    }
321}
322
323impl Default for AccessControlConfig {
324    fn default() -> Self {
325        Self {
326            allow_methods: vec!["GET".to_string(), "HEAD".to_string(), "OPTIONS".to_string()],
327            allow_headers: vec!["Content-Type".to_string(), "Authorization".to_string()],
328            allow_origins: vec!["*".to_string()],
329            allow_credentials: false,
330            max_age: 86400, // 24 hours
331        }
332    }
333}
334
335impl SiteConfig {
336    pub fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
337        // Validate required fields
338        if self.name.is_empty() {
339            return Err("Site name cannot be empty".into());
340        }
341
342        if self.hostname.is_empty() {
343            return Err("Site hostname cannot be empty".into());
344        }
345
346        if self.port == 0 {
347            return Err("Site port must be greater than 0".into());
348        }
349
350        if self.static_dir.is_empty() {
351            return Err("Site static_dir cannot be empty".into());
352        }
353
354        // Validate static directory exists (or can be created)
355        let static_path = Path::new(&self.static_dir);
356        if !static_path.exists() {
357            log::warn!(
358                "Static directory does not exist for site '{}': {}",
359                self.name,
360                self.static_dir
361            );
362        }
363
364        // Validate hostname format (basic check)
365        if !self.is_valid_hostname() {
366            return Err(format!("Invalid hostname format: {}", self.hostname).into());
367        }
368
369        // Validate port range
370        if self.port < 1 {
371            return Err(format!("Invalid port number: {}", self.port).into());
372        }
373
374        // Validate SSL configuration
375        if self.ssl.enabled {
376            if self.ssl.auto_cert {
377                if let Some(acme) = &self.ssl.acme {
378                    if acme.email.is_empty() {
379                        return Err("ACME email is required when auto_cert is enabled".into());
380                    }
381                } else {
382                    return Err("ACME configuration is required when auto_cert is enabled".into());
383                }
384            } else if self.ssl.cert_file.is_none() || self.ssl.key_file.is_none() {
385                return Err("Manual SSL requires both cert_file and key_file".into());
386            }
387        }
388
389        // Validate index files
390        for index_file in &self.index_files {
391            if index_file.is_empty() {
392                return Err("Index file name cannot be empty".into());
393            }
394        }
395
396        // Validate error pages
397        for (status_code, error_page) in &self.error_pages {
398            if *status_code < 100 || *status_code > 999 {
399                return Err(format!("Invalid HTTP status code: {}", status_code).into());
400            }
401            if error_page.is_empty() {
402                return Err(
403                    format!("Error page path cannot be empty for status {}", status_code).into(),
404                );
405            }
406        }
407
408        // Validate compression configuration
409        self.compression.validate()?;
410
411        // Validate cache configuration
412        self.cache.validate()?;
413
414        // Validate access control configuration
415        self.access_control.validate()?;
416
417        Ok(())
418    }
419
420    fn is_valid_hostname(&self) -> bool {
421        // Basic hostname validation
422        if self.hostname.is_empty() || self.hostname.len() > 253 {
423            return false;
424        }
425
426        // Allow localhost and IP addresses for development
427        if self.hostname == "localhost"
428            || self.hostname.starts_with("127.")
429            || self.hostname.starts_with("0.0.0.0")
430        {
431            return true;
432        }
433
434        // Basic domain name validation
435        self.hostname.split('.').all(|label| {
436            !label.is_empty()
437                && label.len() <= 63
438                && label.chars().all(|c| c.is_alphanumeric() || c == '-')
439                && !label.starts_with('-')
440                && !label.ends_with('-')
441        })
442    }
443
444    pub fn get_ssl_domain(&self) -> Option<&str> {
445        if self.ssl.enabled {
446            // Return primary domain (hostname) plus any additional domains
447            Some(&self.hostname)
448        } else {
449            None
450        }
451    }
452
453    pub fn get_all_ssl_domains(&self) -> Vec<&str> {
454        if self.ssl.enabled {
455            let mut domains = vec![self.hostname.as_str()];
456            for domain in &self.ssl.domains {
457                domains.push(domain.as_str());
458            }
459            domains
460        } else {
461            Vec::new()
462        }
463    }
464
465    pub fn is_ssl_enabled(&self) -> bool {
466        self.ssl.enabled
467    }
468
469    pub fn get_index_files(&self) -> Vec<&str> {
470        if self.index_files.is_empty() {
471            vec!["index.html", "index.htm"]
472        } else {
473            self.index_files.iter().map(|s| s.as_str()).collect()
474        }
475    }
476
477    pub fn get_error_page(&self, status_code: u16) -> Option<&str> {
478        self.error_pages.get(&status_code).map(|s| s.as_str())
479    }
480
481    pub fn should_compress(&self, content_type: &str, content_length: usize) -> bool {
482        self.compression.enabled
483            && content_length >= self.compression.min_size
484            && self
485                .compression
486                .types
487                .iter()
488                .any(|t| content_type.starts_with(t))
489    }
490
491    pub fn get_cache_headers(&self, is_static: bool) -> Vec<(String, String)> {
492        let mut headers = Vec::new();
493
494        if self.cache.enabled {
495            let max_age = if is_static {
496                self.cache.max_age_static
497            } else {
498                self.cache.max_age_dynamic
499            };
500
501            headers.push((
502                "Cache-Control".to_string(),
503                format!("public, max-age={}", max_age),
504            ));
505
506            if self.cache.etag_enabled {
507                // ETag would be calculated based on file content
508                // This is a placeholder - actual implementation would calculate based on file
509                headers.push(("ETag".to_string(), "\"placeholder\"".to_string()));
510            }
511
512            if self.cache.last_modified_enabled {
513                // Last-Modified would be based on file modification time
514                // This is a placeholder - actual implementation would use file mtime
515                headers.push(("Last-Modified".to_string(), "placeholder".to_string()));
516            }
517        }
518
519        headers
520    }
521
522    pub fn get_cors_headers(&self) -> Vec<(String, String)> {
523        let mut headers = Vec::new();
524
525        if !self.access_control.allow_origins.is_empty() {
526            headers.push((
527                "Access-Control-Allow-Origin".to_string(),
528                self.access_control.allow_origins.join(", "),
529            ));
530        }
531
532        if !self.access_control.allow_methods.is_empty() {
533            headers.push((
534                "Access-Control-Allow-Methods".to_string(),
535                self.access_control.allow_methods.join(", "),
536            ));
537        }
538
539        if !self.access_control.allow_headers.is_empty() {
540            headers.push((
541                "Access-Control-Allow-Headers".to_string(),
542                self.access_control.allow_headers.join(", "),
543            ));
544        }
545
546        if self.access_control.allow_credentials {
547            headers.push((
548                "Access-Control-Allow-Credentials".to_string(),
549                "true".to_string(),
550            ));
551        }
552
553        if self.access_control.max_age > 0 {
554            headers.push((
555                "Access-Control-Max-Age".to_string(),
556                self.access_control.max_age.to_string(),
557            ));
558        }
559
560        headers
561    }
562
563    pub fn url(&self) -> String {
564        let protocol = if self.is_ssl_enabled() {
565            "https"
566        } else {
567            "http"
568        };
569        let port_suffix = match (self.is_ssl_enabled(), self.port) {
570            (true, 443) | (false, 80) => String::new(),
571            _ => format!(":{}", self.port),
572        };
573        format!("{}://{}{}", protocol, self.hostname, port_suffix)
574    }
575}
576
577impl CompressionConfig {
578    fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
579        if self.level > 9 {
580            return Err("Compression level must be between 0 and 9".into());
581        }
582
583        if self.types.is_empty() && self.enabled {
584            return Err("Compression types cannot be empty when compression is enabled".into());
585        }
586
587        Ok(())
588    }
589}
590
591impl CacheConfig {
592    fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
593        // Cache configuration is generally permissive
594        // Just ensure max_age values are reasonable
595        if self.max_age_static > 365 * 24 * 3600 {
596            log::warn!("Static cache max_age is very large (> 1 year)");
597        }
598
599        if self.max_age_dynamic > 24 * 3600 {
600            log::warn!("Dynamic cache max_age is very large (> 1 day)");
601        }
602
603        Ok(())
604    }
605}
606
607impl AccessControlConfig {
608    fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
609        // Validate HTTP methods
610        let valid_methods = [
611            "GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS", "PATCH", "TRACE", "CONNECT",
612        ];
613
614        for method in &self.allow_methods {
615            if !valid_methods.contains(&method.as_str()) {
616                return Err(format!("Invalid HTTP method: {}", method).into());
617            }
618        }
619
620        // Max age should be reasonable
621        if self.max_age > 7 * 24 * 3600 {
622            log::warn!("CORS max_age is very large (> 1 week)");
623        }
624
625        Ok(())
626    }
627}
628
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a valid site for `hostname` with every optional section at its default.
    fn site_with_hostname(hostname: &str) -> SiteConfig {
        SiteConfig {
            name: "test".to_string(),
            hostname: hostname.to_string(),
            port: 8080,
            static_dir: "/tmp".to_string(),
            default: false,
            api_only: false,
            headers: HashMap::new(),
            redirect_to_https: false,
            index_files: vec![],
            error_pages: HashMap::new(),
            compression: CompressionConfig::default(),
            cache: CacheConfig::default(),
            access_control: AccessControlConfig::default(),
            ssl: SiteSslConfig::default(),
            proxy: ProxyConfig::default(),
        }
    }

    #[test]
    fn test_site_config_validation() {
        let mut site = site_with_hostname("example.com");
        assert!(site.validate().is_ok());

        // An empty hostname is rejected.
        site.hostname = String::new();
        assert!(site.validate().is_err());

        // Port 0 is rejected.
        site.hostname = "example.com".to_string();
        site.port = 0;
        assert!(site.validate().is_err());
    }

    #[test]
    fn test_index_file_and_error_page_validation() {
        let mut site = site_with_hostname("example.com");
        site.index_files = vec![String::new()];
        assert!(site.validate().is_err());

        let mut site = site_with_hostname("example.com");
        site.error_pages.insert(42, "/errors/42.html".to_string());
        assert!(site.validate().is_err());

        let mut site = site_with_hostname("example.com");
        site.error_pages.insert(404, String::new());
        assert!(site.validate().is_err());

        let mut site = site_with_hostname("example.com");
        site.error_pages.insert(404, "/errors/404.html".to_string());
        assert!(site.validate().is_ok());
        assert_eq!(site.get_error_page(404), Some("/errors/404.html"));
        assert_eq!(site.get_error_page(500), None);
    }

    #[test]
    fn test_ssl_validation() {
        let mut site = site_with_hostname("example.com");
        site.ssl.enabled = true;
        // Manual SSL without certificate/key files is incomplete.
        assert!(site.validate().is_err());

        site.ssl.cert_file = Some("cert.pem".to_string());
        site.ssl.key_file = Some("key.pem".to_string());
        assert!(site.validate().is_ok());

        // auto_cert requires an ACME section...
        site.ssl.auto_cert = true;
        assert!(site.validate().is_err());

        // ...with a contact email.
        site.ssl.acme = Some(SiteAcmeConfig::default());
        assert!(site.validate().is_err());

        site.ssl.acme = Some(SiteAcmeConfig {
            email: "admin@example.com".to_string(),
            ..SiteAcmeConfig::default()
        });
        assert!(site.validate().is_ok());
    }

    #[test]
    fn test_hostname_validation() {
        for valid in &["localhost", "example.com", "sub.example.com", "my-site.example", "127.0.0.1"] {
            assert!(site_with_hostname(valid).is_valid_hostname(), "{}", valid);
        }

        let long_label = format!("{}.com", "a".repeat(64));
        for invalid in &[
            "",
            "invalid..hostname",
            "-invalid.hostname",
            "bad_host.com",
            long_label.as_str(),
        ] {
            assert!(!site_with_hostname(invalid).is_valid_hostname(), "{}", invalid);
        }
    }

    #[test]
    fn test_compression_config() {
        let site = site_with_hostname("example.com");

        assert!(site.should_compress("text/html", 2048));
        assert!(site.should_compress("text/html; charset=utf-8", 2048));
        assert!(!site.should_compress("text/html", 512)); // Below min_size
        assert!(!site.should_compress("image/png", 2048)); // Not in types list
    }

    #[test]
    fn test_index_files() {
        let mut site = site_with_hostname("example.com");
        assert_eq!(site.get_index_files(), vec!["index.html", "index.htm"]);

        site.index_files = vec!["home.html".to_string()];
        assert_eq!(site.get_index_files(), vec!["home.html"]);
    }

    #[test]
    fn test_ssl_domains() {
        let mut site = site_with_hostname("example.com");
        site.ssl.domains = vec!["www.example.com".to_string()];
        assert_eq!(site.get_ssl_domain(), None);
        assert!(site.get_all_ssl_domains().is_empty());

        site.ssl.enabled = true;
        assert_eq!(site.get_ssl_domain(), Some("example.com"));
        assert_eq!(
            site.get_all_ssl_domains(),
            vec!["example.com", "www.example.com"]
        );
    }

    #[test]
    fn test_default_cors_headers() {
        let site = site_with_hostname("example.com");
        let headers = site.get_cors_headers();
        let expected = [
            ("Access-Control-Allow-Origin", "*"),
            ("Access-Control-Allow-Methods", "GET, HEAD, OPTIONS"),
            ("Access-Control-Allow-Headers", "Content-Type, Authorization"),
            ("Access-Control-Max-Age", "86400"),
        ];
        assert_eq!(headers.len(), expected.len());
        for ((name, value), (exp_name, exp_value)) in headers.iter().zip(expected.iter()) {
            assert_eq!(name, exp_name);
            assert_eq!(value, exp_value);
        }
    }

    #[test]
    fn test_cache_control_header() {
        let mut site = site_with_hostname("example.com");
        assert_eq!(
            site.get_cache_headers(true)[0],
            ("Cache-Control".to_string(), "public, max-age=3600".to_string())
        );
        assert_eq!(site.get_cache_headers(false)[0].1, "public, max-age=300");

        site.cache.enabled = false;
        assert!(site.get_cache_headers(true).is_empty());
    }

    #[test]
    fn test_site_url_generation() {
        let mut site = site_with_hostname("example.com");
        assert_eq!(site.url(), "http://example.com:8080");

        site.ssl.enabled = true;
        site.port = 443;
        assert_eq!(site.url(), "https://example.com");

        site.port = 8443;
        assert_eq!(site.url(), "https://example.com:8443");
    }
}