1use crate::error::{SecurityError, SecurityResult};
4use serde::{Deserialize, Serialize};
5use std::time::Duration;
6
7#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
9#[serde(rename_all = "lowercase")]
10pub enum SecurityProfile {
11    Development,
13
14    Staging,
16
17    Production,
19
20    Custom,
22}
23
24impl SecurityProfile {
25    #[allow(clippy::should_implement_trait)]
27    pub fn from_str(s: &str) -> SecurityResult<Self> {
28        match s.to_lowercase().as_str() {
29            "development" | "dev" => Ok(Self::Development),
30            "staging" | "stage" => Ok(Self::Staging),
31            "production" | "prod" => Ok(Self::Production),
32            "custom" => Ok(Self::Custom),
33            _ => Err(SecurityError::config(format!(
34                "Invalid security profile: {s}"
35            ))),
36        }
37    }
38
39    pub fn as_str(&self) -> &'static str {
41        match self {
42            Self::Development => "development",
43            Self::Staging => "staging",
44            Self::Production => "production",
45            Self::Custom => "custom",
46        }
47    }
48}
49
50impl std::str::FromStr for SecurityProfile {
51    type Err = SecurityError;
52
53    fn from_str(s: &str) -> Result<Self, Self::Err> {
54        Self::from_str(s)
55    }
56}
57
58impl std::fmt::Display for SecurityProfile {
59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60        write!(f, "{}", self.as_str())
61    }
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct RateLimitConfig {
67    pub max_requests: u32,
69
70    pub window_duration: Duration,
72
73    pub enabled: bool,
75}
76
77impl RateLimitConfig {
78    pub fn permissive() -> Self {
80        Self {
81            max_requests: 10000,
82            window_duration: Duration::from_secs(60),
83            enabled: false,
84        }
85    }
86
87    pub fn moderate() -> Self {
89        Self {
90            max_requests: 1000,
91            window_duration: Duration::from_secs(60),
92            enabled: true,
93        }
94    }
95
96    pub fn strict() -> Self {
98        Self {
99            max_requests: 100,
100            window_duration: Duration::from_secs(60),
101            enabled: true,
102        }
103    }
104}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct CorsConfig {
109    pub allowed_origins: Vec<String>,
111
112    pub allow_credentials: bool,
114
115    pub allowed_methods: Vec<String>,
117
118    pub allowed_headers: Vec<String>,
120
121    pub enabled: bool,
123}
124
125impl CorsConfig {
126    pub fn permissive() -> Self {
128        Self {
129            allowed_origins: vec!["*".to_string()],
130            allow_credentials: false,
131            allowed_methods: vec![
132                "GET".to_string(),
133                "POST".to_string(),
134                "PUT".to_string(),
135                "DELETE".to_string(),
136                "OPTIONS".to_string(),
137            ],
138            allowed_headers: vec!["*".to_string()],
139            enabled: true,
140        }
141    }
142
143    pub fn localhost_only() -> Self {
145        Self {
146            allowed_origins: vec![
147                "http://localhost:3000".to_string(),
148                "http://localhost:3001".to_string(),
149                "http://localhost:8080".to_string(),
150                "http://127.0.0.1:3000".to_string(),
151                "http://127.0.0.1:3001".to_string(),
152                "http://127.0.0.1:8080".to_string(),
153            ],
154            allow_credentials: true,
155            allowed_methods: vec!["GET".to_string(), "POST".to_string(), "OPTIONS".to_string()],
156            allowed_headers: vec![
157                "authorization".to_string(),
158                "content-type".to_string(),
159                "x-request-id".to_string(),
160            ],
161            enabled: true,
162        }
163    }
164
165    pub fn strict() -> Self {
167        Self {
168            allowed_origins: Vec::new(), allow_credentials: true,
170            allowed_methods: vec!["GET".to_string(), "POST".to_string(), "OPTIONS".to_string()],
171            allowed_headers: vec!["authorization".to_string(), "content-type".to_string()],
172            enabled: true,
173        }
174    }
175}
176
177pub struct DevelopmentProfile;
179
180impl DevelopmentProfile {
181    pub fn security_settings() -> SecuritySettings {
183        SecuritySettings {
184            require_authentication: false, require_https: false,
186            enable_audit_logging: true, jwt_expiry_seconds: 86400,  rate_limit: RateLimitConfig::permissive(),
189            cors: CorsConfig::permissive(),
190            auto_generate_keys: true,
191            validate_token_audience: false, }
193    }
194
195    pub fn recommended_env_vars() -> Vec<(&'static str, &'static str)> {
197        vec![
198            ("MCP_SECURITY_PROFILE", "development"),
199            ("MCP_API_KEY", "auto-generate"),
200            ("MCP_JWT_SECRET", "auto-generate"),
201            ("MCP_REQUIRE_HTTPS", "false"),
202            ("MCP_ENABLE_AUDIT_LOG", "true"),
203            ("MCP_CORS_ORIGIN", "*"),
204        ]
205    }
206}
207
208pub struct StagingProfile;
210
211impl StagingProfile {
212    pub fn security_settings() -> SecuritySettings {
214        SecuritySettings {
215            require_authentication: true,
216            require_https: true, enable_audit_logging: true,
218            jwt_expiry_seconds: 3600, rate_limit: RateLimitConfig::moderate(),
220            cors: CorsConfig::localhost_only(),
221            auto_generate_keys: true,
222            validate_token_audience: true,
223        }
224    }
225
226    pub fn recommended_env_vars() -> Vec<(&'static str, &'static str)> {
228        vec![
229            ("MCP_SECURITY_PROFILE", "staging"),
230            ("MCP_API_KEY", "auto-generate"),
231            ("MCP_JWT_SECRET", "auto-generate"),
232            ("MCP_REQUIRE_HTTPS", "true"),
233            ("MCP_ENABLE_AUDIT_LOG", "true"),
234            ("MCP_RATE_LIMIT", "1000/min"),
235            ("MCP_CORS_ORIGIN", "localhost"),
236        ]
237    }
238}
239
240pub struct ProductionProfile;
242
243impl ProductionProfile {
244    pub fn security_settings() -> SecuritySettings {
246        SecuritySettings {
247            require_authentication: true,
248            require_https: true, enable_audit_logging: true,
250            jwt_expiry_seconds: 900, rate_limit: RateLimitConfig::strict(),
252            cors: CorsConfig::strict(),
253            auto_generate_keys: false, validate_token_audience: true,
255        }
256    }
257
258    pub fn required_env_vars() -> Vec<&'static str> {
260        vec![
261            "MCP_API_KEY",
262            "MCP_JWT_SECRET",
263            "MCP_CORS_ORIGIN",
264            "MCP_ALLOWED_ORIGINS",
265        ]
266    }
267
268    pub fn recommended_env_vars() -> Vec<(&'static str, &'static str)> {
270        vec![
271            ("MCP_SECURITY_PROFILE", "production"),
272            ("MCP_REQUIRE_HTTPS", "true"),
273            ("MCP_ENABLE_AUDIT_LOG", "true"),
274            ("MCP_RATE_LIMIT", "100/min"),
275            ("MCP_JWT_EXPIRY", "900"), ]
277    }
278}
279
280#[derive(Debug, Clone, Serialize, Deserialize)]
282pub struct SecuritySettings {
283    pub require_authentication: bool,
285
286    pub require_https: bool,
288
289    pub enable_audit_logging: bool,
291
292    pub jwt_expiry_seconds: u64,
294
295    pub rate_limit: RateLimitConfig,
297
298    pub cors: CorsConfig,
300
301    pub auto_generate_keys: bool,
303
304    pub validate_token_audience: bool,
306}
307
308impl SecuritySettings {
309    pub fn for_profile(profile: &SecurityProfile) -> Self {
311        match profile {
312            SecurityProfile::Development => DevelopmentProfile::security_settings(),
313            SecurityProfile::Staging => StagingProfile::security_settings(),
314            SecurityProfile::Production => ProductionProfile::security_settings(),
315            SecurityProfile::Custom => Self::default(), }
317    }
318
319    pub fn validate(&self) -> SecurityResult<()> {
321        if self.require_authentication && !self.auto_generate_keys {
323            return Ok(()); }
326
327        if self.require_https
328            && self.cors.allowed_origins.contains(&"*".to_string())
329            && self.cors.allow_credentials
330        {
331            return Err(SecurityError::config(
332                "Cannot use wildcard origins with credentials over HTTPS",
333            ));
334        }
335
336        if self.jwt_expiry_seconds > 86400 * 7 {
337            tracing::warn!(
339                "JWT expiry is longer than 1 week, consider shorter expiry for security"
340            );
341        }
342
343        if self.jwt_expiry_seconds < 60 {
344            return Err(SecurityError::config(
346                "JWT expiry cannot be less than 1 minute",
347            ));
348        }
349
350        Ok(())
351    }
352
353    pub fn security_level_description(&self) -> &'static str {
355        if !self.require_authentication {
356            "Minimal - Authentication disabled"
357        } else if !self.require_https {
358            "Low - HTTP allowed"
359        } else if self.auto_generate_keys {
360            "Medium - Auto-generated keys"
361        } else {
362            "High - Manual key management"
363        }
364    }
365}
366
367impl Default for SecuritySettings {
368    fn default() -> Self {
369        Self {
371            require_authentication: true,
372            require_https: true,
373            enable_audit_logging: true,
374            jwt_expiry_seconds: 3600, rate_limit: RateLimitConfig::moderate(),
376            cors: CorsConfig::strict(),
377            auto_generate_keys: false,
378            validate_token_audience: true,
379        }
380    }
381}
382
#[cfg(test)]
mod tests {
    use super::*;

    const ALL_PROFILES: [SecurityProfile; 4] = [
        SecurityProfile::Development,
        SecurityProfile::Staging,
        SecurityProfile::Production,
        SecurityProfile::Custom,
    ];

    #[test]
    fn test_security_profile_parsing() {
        assert_eq!(
            SecurityProfile::from_str("development").unwrap(),
            SecurityProfile::Development
        );
        assert_eq!(
            SecurityProfile::from_str("dev").unwrap(),
            SecurityProfile::Development
        );
        assert_eq!(
            SecurityProfile::from_str("staging").unwrap(),
            SecurityProfile::Staging
        );
        assert_eq!(
            SecurityProfile::from_str("production").unwrap(),
            SecurityProfile::Production
        );

        assert!(SecurityProfile::from_str("invalid").is_err());
    }

    #[test]
    fn test_security_profile_aliases_and_case() {
        assert_eq!(
            SecurityProfile::from_str("stage").unwrap(),
            SecurityProfile::Staging
        );
        assert_eq!(
            SecurityProfile::from_str("prod").unwrap(),
            SecurityProfile::Production
        );
        assert_eq!(
            SecurityProfile::from_str("custom").unwrap(),
            SecurityProfile::Custom
        );
        assert_eq!(
            SecurityProfile::from_str("PRODUCTION").unwrap(),
            SecurityProfile::Production
        );
        assert_eq!(
            SecurityProfile::from_str("Dev").unwrap(),
            SecurityProfile::Development
        );
    }

    #[test]
    fn test_security_profile_from_str_trait() {
        assert_eq!(
            "stage".parse::<SecurityProfile>().unwrap(),
            SecurityProfile::Staging
        );
        assert!("nope".parse::<SecurityProfile>().is_err());
    }

    #[test]
    fn test_security_profile_round_trip() {
        for profile in ALL_PROFILES {
            assert_eq!(SecurityProfile::from_str(profile.as_str()).unwrap(), profile);
        }
    }

    #[test]
    fn test_security_profile_display() {
        assert_eq!(SecurityProfile::Development.to_string(), "development");
        assert_eq!(SecurityProfile::Staging.to_string(), "staging");
        assert_eq!(SecurityProfile::Production.to_string(), "production");
        assert_eq!(SecurityProfile::Custom.to_string(), "custom");
    }

    #[test]
    fn test_development_profile() {
        let settings = DevelopmentProfile::security_settings();
        assert!(!settings.require_authentication);
        assert!(!settings.require_https);
        assert!(settings.enable_audit_logging);
        assert!(settings.auto_generate_keys);
        assert!(!settings.rate_limit.enabled);
    }

    #[test]
    fn test_staging_profile() {
        let settings = StagingProfile::security_settings();
        assert!(settings.require_authentication);
        assert!(settings.require_https);
        assert!(settings.enable_audit_logging);
        assert!(settings.auto_generate_keys);
        assert!(settings.rate_limit.enabled);
        assert_eq!(settings.jwt_expiry_seconds, 3600);
    }

    #[test]
    fn test_production_profile() {
        let settings = ProductionProfile::security_settings();
        assert!(settings.require_authentication);
        assert!(settings.require_https);
        assert!(settings.enable_audit_logging);
        assert!(!settings.auto_generate_keys);
        assert!(settings.rate_limit.enabled);
        assert_eq!(settings.jwt_expiry_seconds, 900);
    }

    #[test]
    fn test_rate_limit_configs() {
        let permissive = RateLimitConfig::permissive();
        assert!(!permissive.enabled);
        assert_eq!(permissive.max_requests, 10000);

        let strict = RateLimitConfig::strict();
        assert!(strict.enabled);
        assert_eq!(strict.max_requests, 100);
    }

    #[test]
    fn test_cors_configs() {
        let permissive = CorsConfig::permissive();
        assert_eq!(permissive.allowed_origins, vec!["*"]);
        assert!(!permissive.allow_credentials);

        let localhost = CorsConfig::localhost_only();
        assert!(localhost.allow_credentials);
        assert!(
            localhost
                .allowed_origins
                .contains(&"http://localhost:3000".to_string())
        );

        let strict = CorsConfig::strict();
        assert!(strict.allowed_origins.is_empty());
        assert!(strict.allow_credentials);
    }

    #[test]
    fn test_security_settings_validation() {
        let mut settings = SecuritySettings::default();
        assert!(settings.validate().is_ok());

        settings.auto_generate_keys = true;
        settings.jwt_expiry_seconds = 30;
        assert!(settings.validate().is_err());

        settings.jwt_expiry_seconds = 3600;
        assert!(settings.validate().is_ok());

        settings.cors.allowed_origins = vec!["*".to_string()];
        settings.cors.allow_credentials = true;
        settings.require_https = true;
        assert!(settings.validate().is_err());
    }

    #[test]
    fn test_builtin_profiles_validate() {
        for profile in ALL_PROFILES {
            let settings = SecuritySettings::for_profile(&profile);
            assert!(settings.validate().is_ok(), "{profile} preset should validate");
        }
    }

    #[test]
    fn test_security_settings_for_profiles() {
        let dev_settings = SecuritySettings::for_profile(&SecurityProfile::Development);
        assert!(!dev_settings.require_authentication);

        let prod_settings = SecuritySettings::for_profile(&SecurityProfile::Production);
        assert!(prod_settings.require_authentication);
        assert!(prod_settings.require_https);

        let custom_settings = SecuritySettings::for_profile(&SecurityProfile::Custom);
        let default_settings = SecuritySettings::default();
        assert_eq!(
            custom_settings.jwt_expiry_seconds,
            default_settings.jwt_expiry_seconds
        );
        assert_eq!(
            custom_settings.auto_generate_keys,
            default_settings.auto_generate_keys
        );
    }

    #[test]
    fn test_security_level_description() {
        assert_eq!(
            DevelopmentProfile::security_settings().security_level_description(),
            "Minimal - Authentication disabled"
        );
        assert_eq!(
            StagingProfile::security_settings().security_level_description(),
            "Medium - Auto-generated keys"
        );
        assert_eq!(
            ProductionProfile::security_settings().security_level_description(),
            "High - Manual key management"
        );

        let mut http_settings = SecuritySettings::default();
        http_settings.require_https = false;
        assert_eq!(
            http_settings.security_level_description(),
            "Low - HTTP allowed"
        );
    }
}