bws_web_server/config/
site.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::path::Path;
4
5#[derive(Debug, Deserialize, Serialize, Clone)]
6pub struct SiteConfig {
7    pub name: String,
8    pub hostname: String,
9    pub port: u16,
10    pub static_dir: String,
11    #[serde(default)]
12    pub default: bool,
13    #[serde(default)]
14    pub api_only: bool,
15    #[serde(default)]
16    pub headers: HashMap<String, String>,
17    #[serde(default)]
18    pub redirect_to_https: bool,
19    #[serde(default)]
20    pub index_files: Vec<String>,
21    #[serde(default)]
22    pub error_pages: HashMap<u16, String>,
23    #[serde(default)]
24    pub compression: CompressionConfig,
25    #[serde(default)]
26    pub cache: CacheConfig,
27    #[serde(default)]
28    pub access_control: AccessControlConfig,
29    #[serde(default)]
30    pub ssl: SiteSslConfig,
31    #[serde(default)]
32    pub proxy: ProxyConfig,
33}
34
35#[derive(Debug, Deserialize, Serialize, Clone, Default)]
36pub struct SiteSslConfig {
37    #[serde(default)]
38    pub enabled: bool,
39    #[serde(default)]
40    pub auto_cert: bool,
41    #[serde(default)]
42    pub domains: Vec<String>, // Additional domains beyond hostname
43    #[serde(default)]
44    pub cert_file: Option<String>, // Manual certificate file
45    #[serde(default)]
46    pub key_file: Option<String>, // Manual key file
47    #[serde(default)]
48    pub acme: Option<SiteAcmeConfig>,
49}
50
51#[derive(Debug, Deserialize, Serialize, Clone)]
52pub struct SiteAcmeConfig {
53    #[serde(default)]
54    pub enabled: bool,
55    #[serde(default)]
56    pub email: String,
57    #[serde(default)]
58    pub staging: bool,
59    #[serde(default)]
60    pub challenge_dir: String,
61}
62
63#[derive(Debug, Deserialize, Serialize, Clone, Default)]
64pub struct ProxyConfig {
65    #[serde(default)]
66    pub enabled: bool,
67    #[serde(default)]
68    pub upstreams: Vec<UpstreamConfig>,
69    #[serde(default)]
70    pub routes: Vec<ProxyRoute>,
71    #[serde(default)]
72    pub health_check: HealthCheckConfig,
73    #[serde(default)]
74    pub load_balancing: LoadBalancingConfig,
75    #[serde(default)]
76    pub timeout: TimeoutConfig,
77    #[serde(default)]
78    pub headers: ProxyHeadersConfig,
79}
80
81#[derive(Debug, Deserialize, Serialize, Clone)]
82pub struct UpstreamConfig {
83    pub name: String,
84    pub url: String,
85    #[serde(default = "default_weight")]
86    pub weight: u32,
87    #[serde(default)]
88    pub max_conns: Option<u32>,
89}
90
91#[derive(Debug, Deserialize, Serialize, Clone)]
92pub struct ProxyRoute {
93    pub path: String,
94    pub upstream: String, // References upstream name
95    #[serde(default)]
96    pub strip_prefix: bool,
97    #[serde(default)]
98    pub rewrite_target: Option<String>,
99}
100
101#[derive(Debug, Deserialize, Serialize, Clone)]
102pub struct HealthCheckConfig {
103    #[serde(default)]
104    pub enabled: bool,
105    #[serde(default = "default_health_path")]
106    pub path: String,
107    #[serde(default = "default_health_interval")]
108    pub interval: u64, // seconds
109    #[serde(default = "default_health_timeout")]
110    pub timeout: u64, // seconds
111    #[serde(default = "default_health_retries")]
112    pub retries: u32,
113}
114
115#[derive(Debug, Deserialize, Serialize, Clone)]
116pub struct LoadBalancingConfig {
117    #[serde(default = "default_lb_method")]
118    pub method: String, // "round_robin", "least_conn", "weighted"
119    #[serde(default)]
120    pub sticky_sessions: bool,
121}
122
123#[derive(Debug, Deserialize, Serialize, Clone)]
124pub struct TimeoutConfig {
125    #[serde(default = "default_connect_timeout")]
126    pub connect: u64, // seconds
127    #[serde(default = "default_read_timeout")]
128    pub read: u64, // seconds
129    #[serde(default = "default_write_timeout")]
130    pub write: u64, // seconds
131}
132
133#[derive(Debug, Deserialize, Serialize, Clone)]
134pub struct ProxyHeadersConfig {
135    #[serde(default)]
136    pub preserve_host: bool,
137    #[serde(default)]
138    pub add_forwarded: bool,
139    #[serde(default)]
140    pub add_x_forwarded: bool,
141    #[serde(default)]
142    pub remove: Vec<String>,
143    #[serde(default)]
144    pub add: HashMap<String, String>,
145}
146
/// Serde default for `UpstreamConfig::weight`.
fn default_weight() -> u32 {
    const DEFAULT_WEIGHT: u32 = 1;
    DEFAULT_WEIGHT
}
/// Serde default for `HealthCheckConfig::path`.
fn default_health_path() -> String {
    String::from("/health")
}
/// Serde default for `HealthCheckConfig::interval`, in seconds.
fn default_health_interval() -> u64 {
    const INTERVAL_SECS: u64 = 30;
    INTERVAL_SECS
}
/// Serde default for `HealthCheckConfig::timeout`, in seconds.
fn default_health_timeout() -> u64 {
    const TIMEOUT_SECS: u64 = 5;
    TIMEOUT_SECS
}
/// Serde default for `HealthCheckConfig::retries`.
fn default_health_retries() -> u32 {
    const RETRIES: u32 = 3;
    RETRIES
}
/// Serde default for `LoadBalancingConfig::method`.
fn default_lb_method() -> String {
    String::from("round_robin")
}
/// Serde default for `TimeoutConfig::connect`, in seconds.
fn default_connect_timeout() -> u64 {
    const CONNECT_SECS: u64 = 10;
    CONNECT_SECS
}
/// Serde default for `TimeoutConfig::read`, in seconds.
fn default_read_timeout() -> u64 {
    const READ_SECS: u64 = 30;
    READ_SECS
}
/// Serde default for `TimeoutConfig::write`, in seconds.
fn default_write_timeout() -> u64 {
    const WRITE_SECS: u64 = 30;
    WRITE_SECS
}
174
175impl Default for SiteAcmeConfig {
176    fn default() -> Self {
177        Self {
178            enabled: false,
179            email: String::new(),
180            staging: false,
181            challenge_dir: "./acme-challenges".to_string(),
182        }
183    }
184}
185
186impl Default for HealthCheckConfig {
187    fn default() -> Self {
188        Self {
189            enabled: false,
190            path: default_health_path(),
191            interval: default_health_interval(),
192            timeout: default_health_timeout(),
193            retries: default_health_retries(),
194        }
195    }
196}
197
198impl Default for LoadBalancingConfig {
199    fn default() -> Self {
200        Self {
201            method: default_lb_method(),
202            sticky_sessions: false,
203        }
204    }
205}
206
207impl Default for TimeoutConfig {
208    fn default() -> Self {
209        Self {
210            connect: default_connect_timeout(),
211            read: default_read_timeout(),
212            write: default_write_timeout(),
213        }
214    }
215}
216
217impl Default for ProxyHeadersConfig {
218    fn default() -> Self {
219        Self {
220            preserve_host: true,
221            add_forwarded: true,
222            add_x_forwarded: true,
223            remove: Vec::new(),
224            add: HashMap::new(),
225        }
226    }
227}
228
229#[derive(Debug, Deserialize, Serialize, Clone)]
230pub struct CompressionConfig {
231    #[serde(default)]
232    pub enabled: bool,
233    #[serde(default = "default_compression_types")]
234    pub types: Vec<String>,
235    #[serde(default = "default_compression_level")]
236    pub level: u32,
237    #[serde(default = "default_min_size")]
238    pub min_size: usize,
239}
240
241#[derive(Debug, Deserialize, Serialize, Clone)]
242pub struct CacheConfig {
243    #[serde(default)]
244    pub enabled: bool,
245    #[serde(default = "default_cache_control")]
246    pub cache_control: String,
247    #[serde(default)]
248    pub etag_enabled: bool,
249    #[serde(default)]
250    pub last_modified_enabled: bool,
251    #[serde(default)]
252    pub max_age_static: u32,
253    #[serde(default)]
254    pub max_age_dynamic: u32,
255}
256
257#[derive(Debug, Deserialize, Serialize, Clone)]
258pub struct AccessControlConfig {
259    #[serde(default)]
260    pub allow_methods: Vec<String>,
261    #[serde(default)]
262    pub allow_headers: Vec<String>,
263    #[serde(default)]
264    pub allow_origins: Vec<String>,
265    #[serde(default)]
266    pub allow_credentials: bool,
267    #[serde(default)]
268    pub max_age: u32,
269}
270
// Default value functions

/// Content types compressed by default (serde default for `CompressionConfig::types`).
fn default_compression_types() -> Vec<String> {
    const TYPES: [&str; 8] = [
        "text/html",
        "text/css",
        "text/javascript",
        "application/javascript",
        "application/json",
        "text/xml",
        "application/xml",
        "text/plain",
    ];
    TYPES.iter().map(|mime| mime.to_string()).collect()
}
284
/// Serde default for `CompressionConfig::level`.
fn default_compression_level() -> u32 {
    const LEVEL: u32 = 6;
    LEVEL
}
288
/// Serde default for `CompressionConfig::min_size`: 1KB.
fn default_min_size() -> usize {
    const ONE_KIB: usize = 1 << 10;
    ONE_KIB
}
292
/// Serde default for `CacheConfig::cache_control`.
fn default_cache_control() -> String {
    String::from("public, max-age=3600")
}
296
297impl Default for CompressionConfig {
298    fn default() -> Self {
299        Self {
300            enabled: true,
301            types: default_compression_types(),
302            level: default_compression_level(),
303            min_size: default_min_size(),
304        }
305    }
306}
307
308impl Default for CacheConfig {
309    fn default() -> Self {
310        Self {
311            enabled: true,
312            cache_control: default_cache_control(),
313            etag_enabled: true,
314            last_modified_enabled: true,
315            max_age_static: 3600, // 1 hour for static files
316            max_age_dynamic: 300, // 5 minutes for dynamic content
317        }
318    }
319}
320
321impl Default for AccessControlConfig {
322    fn default() -> Self {
323        Self {
324            allow_methods: vec!["GET".to_string(), "HEAD".to_string(), "OPTIONS".to_string()],
325            allow_headers: vec!["Content-Type".to_string(), "Authorization".to_string()],
326            allow_origins: vec!["*".to_string()],
327            allow_credentials: false,
328            max_age: 86400, // 24 hours
329        }
330    }
331}
332
333impl SiteConfig {
334    pub fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
335        // Validate required fields
336        if self.name.is_empty() {
337            return Err("Site name cannot be empty".into());
338        }
339
340        if self.hostname.is_empty() {
341            return Err("Site hostname cannot be empty".into());
342        }
343
344        if self.port == 0 {
345            return Err("Site port must be greater than 0".into());
346        }
347
348        if self.static_dir.is_empty() {
349            return Err("Site static_dir cannot be empty".into());
350        }
351
352        // Validate static directory exists (or can be created)
353        let static_path = Path::new(&self.static_dir);
354        if !static_path.exists() {
355            log::warn!(
356                "Static directory does not exist for site '{}': {}",
357                self.name,
358                self.static_dir
359            );
360        }
361
362        // Validate hostname format (basic check)
363        if !self.is_valid_hostname() {
364            return Err(format!("Invalid hostname format: {}", self.hostname).into());
365        }
366
367        // Validate port range
368        if self.port < 1 {
369            return Err(format!("Invalid port number: {}", self.port).into());
370        }
371
372        // Validate SSL configuration
373        if self.ssl.enabled {
374            if self.ssl.auto_cert {
375                if let Some(acme) = &self.ssl.acme {
376                    if acme.email.is_empty() {
377                        return Err("ACME email is required when auto_cert is enabled".into());
378                    }
379                } else {
380                    return Err("ACME configuration is required when auto_cert is enabled".into());
381                }
382            } else if self.ssl.cert_file.is_none() || self.ssl.key_file.is_none() {
383                return Err("Manual SSL requires both cert_file and key_file".into());
384            }
385        }
386
387        // Validate index files
388        for index_file in &self.index_files {
389            if index_file.is_empty() {
390                return Err("Index file name cannot be empty".into());
391            }
392        }
393
394        // Validate error pages
395        for (status_code, error_page) in &self.error_pages {
396            if *status_code < 100 || *status_code > 999 {
397                return Err(format!("Invalid HTTP status code: {}", status_code).into());
398            }
399            if error_page.is_empty() {
400                return Err(
401                    format!("Error page path cannot be empty for status {}", status_code).into(),
402                );
403            }
404        }
405
406        // Validate compression configuration
407        self.compression.validate()?;
408
409        // Validate cache configuration
410        self.cache.validate()?;
411
412        // Validate access control configuration
413        self.access_control.validate()?;
414
415        Ok(())
416    }
417
418    fn is_valid_hostname(&self) -> bool {
419        // Basic hostname validation
420        if self.hostname.is_empty() || self.hostname.len() > 253 {
421            return false;
422        }
423
424        // Allow localhost and IP addresses for development
425        if self.hostname == "localhost"
426            || self.hostname.starts_with("127.")
427            || self.hostname.starts_with("0.0.0.0")
428        {
429            return true;
430        }
431
432        // Basic domain name validation
433        self.hostname.split('.').all(|label| {
434            !label.is_empty()
435                && label.len() <= 63
436                && label.chars().all(|c| c.is_alphanumeric() || c == '-')
437                && !label.starts_with('-')
438                && !label.ends_with('-')
439        })
440    }
441
442    pub fn get_ssl_domain(&self) -> Option<&str> {
443        if self.ssl.enabled {
444            // Return primary domain (hostname) plus any additional domains
445            Some(&self.hostname)
446        } else {
447            None
448        }
449    }
450
451    pub fn get_all_ssl_domains(&self) -> Vec<&str> {
452        if self.ssl.enabled {
453            let mut domains = vec![self.hostname.as_str()];
454            for domain in &self.ssl.domains {
455                domains.push(domain.as_str());
456            }
457            domains
458        } else {
459            Vec::new()
460        }
461    }
462
463    pub fn is_ssl_enabled(&self) -> bool {
464        self.ssl.enabled
465    }
466
467    pub fn get_index_files(&self) -> Vec<&str> {
468        if self.index_files.is_empty() {
469            vec!["index.html", "index.htm"]
470        } else {
471            self.index_files.iter().map(|s| s.as_str()).collect()
472        }
473    }
474
475    pub fn get_error_page(&self, status_code: u16) -> Option<&str> {
476        self.error_pages.get(&status_code).map(|s| s.as_str())
477    }
478
479    pub fn should_compress(&self, content_type: &str, content_length: usize) -> bool {
480        self.compression.enabled
481            && content_length >= self.compression.min_size
482            && self
483                .compression
484                .types
485                .iter()
486                .any(|t| content_type.starts_with(t))
487    }
488
489    pub fn get_cache_headers(&self, is_static: bool) -> Vec<(String, String)> {
490        let mut headers = Vec::new();
491
492        if self.cache.enabled {
493            let max_age = if is_static {
494                self.cache.max_age_static
495            } else {
496                self.cache.max_age_dynamic
497            };
498
499            headers.push((
500                "Cache-Control".to_string(),
501                format!("public, max-age={}", max_age),
502            ));
503
504            if self.cache.etag_enabled {
505                // ETag would be calculated based on file content
506                // This is a placeholder - actual implementation would calculate based on file
507                headers.push(("ETag".to_string(), "\"placeholder\"".to_string()));
508            }
509
510            if self.cache.last_modified_enabled {
511                // Last-Modified would be based on file modification time
512                // This is a placeholder - actual implementation would use file mtime
513                headers.push(("Last-Modified".to_string(), "placeholder".to_string()));
514            }
515        }
516
517        headers
518    }
519
520    pub fn get_cors_headers(&self) -> Vec<(String, String)> {
521        let mut headers = Vec::new();
522
523        if !self.access_control.allow_origins.is_empty() {
524            headers.push((
525                "Access-Control-Allow-Origin".to_string(),
526                self.access_control.allow_origins.join(", "),
527            ));
528        }
529
530        if !self.access_control.allow_methods.is_empty() {
531            headers.push((
532                "Access-Control-Allow-Methods".to_string(),
533                self.access_control.allow_methods.join(", "),
534            ));
535        }
536
537        if !self.access_control.allow_headers.is_empty() {
538            headers.push((
539                "Access-Control-Allow-Headers".to_string(),
540                self.access_control.allow_headers.join(", "),
541            ));
542        }
543
544        if self.access_control.allow_credentials {
545            headers.push((
546                "Access-Control-Allow-Credentials".to_string(),
547                "true".to_string(),
548            ));
549        }
550
551        if self.access_control.max_age > 0 {
552            headers.push((
553                "Access-Control-Max-Age".to_string(),
554                self.access_control.max_age.to_string(),
555            ));
556        }
557
558        headers
559    }
560
561    pub fn url(&self) -> String {
562        let protocol = if self.is_ssl_enabled() {
563            "https"
564        } else {
565            "http"
566        };
567        let port_suffix = match (self.is_ssl_enabled(), self.port) {
568            (true, 443) | (false, 80) => String::new(),
569            _ => format!(":{}", self.port),
570        };
571        format!("{}://{}{}", protocol, self.hostname, port_suffix)
572    }
573}
574
575impl CompressionConfig {
576    fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
577        if self.level > 9 {
578            return Err("Compression level must be between 0 and 9".into());
579        }
580
581        if self.types.is_empty() && self.enabled {
582            return Err("Compression types cannot be empty when compression is enabled".into());
583        }
584
585        Ok(())
586    }
587}
588
589impl CacheConfig {
590    fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
591        // Cache configuration is generally permissive
592        // Just ensure max_age values are reasonable
593        if self.max_age_static > 365 * 24 * 3600 {
594            log::warn!("Static cache max_age is very large (> 1 year)");
595        }
596
597        if self.max_age_dynamic > 24 * 3600 {
598            log::warn!("Dynamic cache max_age is very large (> 1 day)");
599        }
600
601        Ok(())
602    }
603}
604
605impl AccessControlConfig {
606    fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
607        // Validate HTTP methods
608        let valid_methods = [
609            "GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS", "PATCH", "TRACE", "CONNECT",
610        ];
611
612        for method in &self.allow_methods {
613            if !valid_methods.contains(&method.as_str()) {
614                return Err(format!("Invalid HTTP method: {}", method).into());
615            }
616        }
617
618        // Max age should be reasonable
619        if self.max_age > 7 * 24 * 3600 {
620            log::warn!("CORS max_age is very large (> 1 week)");
621        }
622
623        Ok(())
624    }
625}
626
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the baseline site shared by all tests: port 8080, `/tmp` as static
    /// directory, every optional section at its default.
    fn sample_site(hostname: &str) -> SiteConfig {
        SiteConfig {
            name: "test".to_string(),
            hostname: hostname.to_string(),
            port: 8080,
            static_dir: "/tmp".to_string(),
            default: false,
            api_only: false,
            headers: HashMap::new(),
            redirect_to_https: false,
            index_files: vec![],
            error_pages: HashMap::new(),
            compression: CompressionConfig::default(),
            cache: CacheConfig::default(),
            access_control: AccessControlConfig::default(),
            ssl: SiteSslConfig::default(),
            proxy: ProxyConfig::default(),
        }
    }

    #[test]
    fn test_site_config_validation() {
        let mut site = sample_site("example.com");

        assert!(site.validate().is_ok());

        // Test invalid hostname
        site.hostname = "".to_string();
        assert!(site.validate().is_err());

        // Test invalid port (0 is the only invalid value a `u16` can hold)
        site.hostname = "example.com".to_string();
        site.port = 0;
        assert!(site.validate().is_err());
    }

    #[test]
    fn test_hostname_validation() {
        let site = sample_site("localhost");

        assert!(site.is_valid_hostname());

        let mut invalid_site = site.clone();
        invalid_site.hostname = "invalid..hostname".to_string();
        assert!(!invalid_site.is_valid_hostname());

        invalid_site.hostname = "-invalid.hostname".to_string();
        assert!(!invalid_site.is_valid_hostname());
    }

    #[test]
    fn test_compression_config() {
        let site = sample_site("example.com");

        assert!(site.should_compress("text/html", 2048));
        assert!(!site.should_compress("text/html", 512)); // Below min_size
        assert!(!site.should_compress("image/png", 2048)); // Not in types list
    }

    #[test]
    fn test_site_url_generation() {
        let mut site = sample_site("example.com");

        assert_eq!(site.url(), "http://example.com:8080");

        site.ssl.enabled = true;
        site.port = 443;
        assert_eq!(site.url(), "https://example.com");

        site.port = 8443;
        assert_eq!(site.url(), "https://example.com:8443");
    }
}