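// pulseengine_mcp_security_middleware/profiles.rs

//! Security profiles for different deployment environments.
//!
//! Each [`SecurityProfile`] maps to a preset [`SecuritySettings`] via
//! [`SecuritySettings::for_profile`].
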
use crate::error::{SecurityError, SecurityResult};
use serde::{Deserialize, Serialize};
use std::time::Duration;

7/// Security profile defining the security level and features
8#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
9#[serde(rename_all = "lowercase")]
10pub enum SecurityProfile {
11    /// Development profile - permissive settings for local development
12    Development,
13
14    /// Staging profile - balanced security for testing environments
15    Staging,
16
17    /// Production profile - strict security for production deployments
18    Production,
19
20    /// Custom profile - user-defined security settings
21    Custom,
22}
23
24impl SecurityProfile {
25    /// Parse security profile from string
26    #[allow(clippy::should_implement_trait)]
27    pub fn from_str(s: &str) -> SecurityResult<Self> {
28        match s.to_lowercase().as_str() {
29            "development" | "dev" => Ok(Self::Development),
30            "staging" | "stage" => Ok(Self::Staging),
31            "production" | "prod" => Ok(Self::Production),
32            "custom" => Ok(Self::Custom),
33            _ => Err(SecurityError::config(format!(
34                "Invalid security profile: {s}"
35            ))),
36        }
37    }
38
39    /// Get the string representation
40    pub fn as_str(&self) -> &'static str {
41        match self {
42            Self::Development => "development",
43            Self::Staging => "staging",
44            Self::Production => "production",
45            Self::Custom => "custom",
46        }
47    }
48}
49
50impl std::str::FromStr for SecurityProfile {
51    type Err = SecurityError;
52
53    fn from_str(s: &str) -> Result<Self, Self::Err> {
54        Self::from_str(s)
55    }
56}
57
58impl std::fmt::Display for SecurityProfile {
59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60        write!(f, "{}", self.as_str())
61    }
62}
63
64/// Rate limiting configuration
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct RateLimitConfig {
67    /// Maximum requests per window
68    pub max_requests: u32,
69
70    /// Time window duration
71    pub window_duration: Duration,
72
73    /// Whether rate limiting is enabled
74    pub enabled: bool,
75}
76
77impl RateLimitConfig {
78    /// Create permissive rate limiting (high limits)
79    pub fn permissive() -> Self {
80        Self {
81            max_requests: 10000,
82            window_duration: Duration::from_secs(60),
83            enabled: false,
84        }
85    }
86
87    /// Create moderate rate limiting
88    pub fn moderate() -> Self {
89        Self {
90            max_requests: 1000,
91            window_duration: Duration::from_secs(60),
92            enabled: true,
93        }
94    }
95
96    /// Create strict rate limiting (low limits)
97    pub fn strict() -> Self {
98        Self {
99            max_requests: 100,
100            window_duration: Duration::from_secs(60),
101            enabled: true,
102        }
103    }
104}
105
106/// CORS (Cross-Origin Resource Sharing) configuration
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct CorsConfig {
109    /// Allowed origins
110    pub allowed_origins: Vec<String>,
111
112    /// Allow credentials
113    pub allow_credentials: bool,
114
115    /// Allowed methods
116    pub allowed_methods: Vec<String>,
117
118    /// Allowed headers
119    pub allowed_headers: Vec<String>,
120
121    /// Whether CORS is enabled
122    pub enabled: bool,
123}
124
125impl CorsConfig {
126    /// Create permissive CORS configuration
127    pub fn permissive() -> Self {
128        Self {
129            allowed_origins: vec!["*".to_string()],
130            allow_credentials: false,
131            allowed_methods: vec![
132                "GET".to_string(),
133                "POST".to_string(),
134                "PUT".to_string(),
135                "DELETE".to_string(),
136                "OPTIONS".to_string(),
137            ],
138            allowed_headers: vec!["*".to_string()],
139            enabled: true,
140        }
141    }
142
143    /// Create localhost-only CORS configuration
144    pub fn localhost_only() -> Self {
145        Self {
146            allowed_origins: vec![
147                "http://localhost:3000".to_string(),
148                "http://localhost:3001".to_string(),
149                "http://localhost:8080".to_string(),
150                "http://127.0.0.1:3000".to_string(),
151                "http://127.0.0.1:3001".to_string(),
152                "http://127.0.0.1:8080".to_string(),
153            ],
154            allow_credentials: true,
155            allowed_methods: vec!["GET".to_string(), "POST".to_string(), "OPTIONS".to_string()],
156            allowed_headers: vec![
157                "authorization".to_string(),
158                "content-type".to_string(),
159                "x-request-id".to_string(),
160            ],
161            enabled: true,
162        }
163    }
164
165    /// Create strict CORS configuration
166    pub fn strict() -> Self {
167        Self {
168            allowed_origins: Vec::new(), // Must be explicitly configured
169            allow_credentials: true,
170            allowed_methods: vec!["GET".to_string(), "POST".to_string(), "OPTIONS".to_string()],
171            allowed_headers: vec!["authorization".to_string(), "content-type".to_string()],
172            enabled: true,
173        }
174    }
175}
176
177/// Development security profile implementation
178pub struct DevelopmentProfile;
179
180impl DevelopmentProfile {
181    /// Create development security settings
182    pub fn security_settings() -> SecuritySettings {
183        SecuritySettings {
184            require_authentication: false, // Optional for dev
185            require_https: false,
186            enable_audit_logging: true, // Good for debugging
187            jwt_expiry_seconds: 86400,  // 24 hours - long for dev convenience
188            rate_limit: RateLimitConfig::permissive(),
189            cors: CorsConfig::permissive(),
190            auto_generate_keys: true,
191            validate_token_audience: false, // Relaxed for dev
192        }
193    }
194
195    /// Get recommended environment variables for development
196    pub fn recommended_env_vars() -> Vec<(&'static str, &'static str)> {
197        vec![
198            ("MCP_SECURITY_PROFILE", "development"),
199            ("MCP_API_KEY", "auto-generate"),
200            ("MCP_JWT_SECRET", "auto-generate"),
201            ("MCP_REQUIRE_HTTPS", "false"),
202            ("MCP_ENABLE_AUDIT_LOG", "true"),
203            ("MCP_CORS_ORIGIN", "*"),
204        ]
205    }
206}
207
208/// Staging security profile implementation
209pub struct StagingProfile;
210
211impl StagingProfile {
212    /// Create staging security settings
213    pub fn security_settings() -> SecuritySettings {
214        SecuritySettings {
215            require_authentication: true,
216            require_https: true, // HTTPS required for staging
217            enable_audit_logging: true,
218            jwt_expiry_seconds: 3600, // 1 hour
219            rate_limit: RateLimitConfig::moderate(),
220            cors: CorsConfig::localhost_only(),
221            auto_generate_keys: true,
222            validate_token_audience: true,
223        }
224    }
225
226    /// Get recommended environment variables for staging
227    pub fn recommended_env_vars() -> Vec<(&'static str, &'static str)> {
228        vec![
229            ("MCP_SECURITY_PROFILE", "staging"),
230            ("MCP_API_KEY", "auto-generate"),
231            ("MCP_JWT_SECRET", "auto-generate"),
232            ("MCP_REQUIRE_HTTPS", "true"),
233            ("MCP_ENABLE_AUDIT_LOG", "true"),
234            ("MCP_RATE_LIMIT", "1000/min"),
235            ("MCP_CORS_ORIGIN", "localhost"),
236        ]
237    }
238}
239
240/// Production security profile implementation
241pub struct ProductionProfile;
242
243impl ProductionProfile {
244    /// Create production security settings
245    pub fn security_settings() -> SecuritySettings {
246        SecuritySettings {
247            require_authentication: true,
248            require_https: true, // Mandatory HTTPS
249            enable_audit_logging: true,
250            jwt_expiry_seconds: 900, // 15 minutes - short for security
251            rate_limit: RateLimitConfig::strict(),
252            cors: CorsConfig::strict(),
253            auto_generate_keys: false, // Manual key management in production
254            validate_token_audience: true,
255        }
256    }
257
258    /// Get required environment variables for production
259    pub fn required_env_vars() -> Vec<&'static str> {
260        vec![
261            "MCP_API_KEY",
262            "MCP_JWT_SECRET",
263            "MCP_CORS_ORIGIN",
264            "MCP_ALLOWED_ORIGINS",
265        ]
266    }
267
268    /// Get recommended environment variables for production
269    pub fn recommended_env_vars() -> Vec<(&'static str, &'static str)> {
270        vec![
271            ("MCP_SECURITY_PROFILE", "production"),
272            ("MCP_REQUIRE_HTTPS", "true"),
273            ("MCP_ENABLE_AUDIT_LOG", "true"),
274            ("MCP_RATE_LIMIT", "100/min"),
275            ("MCP_JWT_EXPIRY", "900"), // 15 minutes
276        ]
277    }
278}
279
280/// Consolidated security settings
281#[derive(Debug, Clone, Serialize, Deserialize)]
282pub struct SecuritySettings {
283    /// Whether authentication is required
284    pub require_authentication: bool,
285
286    /// Whether HTTPS is required
287    pub require_https: bool,
288
289    /// Whether audit logging is enabled
290    pub enable_audit_logging: bool,
291
292    /// JWT token expiry in seconds
293    pub jwt_expiry_seconds: u64,
294
295    /// Rate limiting configuration
296    pub rate_limit: RateLimitConfig,
297
298    /// CORS configuration
299    pub cors: CorsConfig,
300
301    /// Whether to auto-generate keys if not provided
302    pub auto_generate_keys: bool,
303
304    /// Whether to validate JWT token audience
305    pub validate_token_audience: bool,
306}
307
308impl SecuritySettings {
309    /// Create settings for a specific profile
310    pub fn for_profile(profile: &SecurityProfile) -> Self {
311        match profile {
312            SecurityProfile::Development => DevelopmentProfile::security_settings(),
313            SecurityProfile::Staging => StagingProfile::security_settings(),
314            SecurityProfile::Production => ProductionProfile::security_settings(),
315            SecurityProfile::Custom => Self::default(), // User must customize
316        }
317    }
318
319    /// Validate the security settings
320    pub fn validate(&self) -> SecurityResult<()> {
321        // Check for inconsistencies
322        if self.require_authentication && !self.auto_generate_keys {
323            // In production, we need explicit keys
324            return Ok(()); // This will be validated elsewhere
325        }
326
327        if self.require_https
328            && self.cors.allowed_origins.contains(&"*".to_string())
329            && self.cors.allow_credentials
330        {
331            return Err(SecurityError::config(
332                "Cannot use wildcard origins with credentials over HTTPS",
333            ));
334        }
335
336        if self.jwt_expiry_seconds > 86400 * 7 {
337            // More than 1 week
338            tracing::warn!(
339                "JWT expiry is longer than 1 week, consider shorter expiry for security"
340            );
341        }
342
343        if self.jwt_expiry_seconds < 60 {
344            // Less than 1 minute
345            return Err(SecurityError::config(
346                "JWT expiry cannot be less than 1 minute",
347            ));
348        }
349
350        Ok(())
351    }
352
353    /// Get the security level description
354    pub fn security_level_description(&self) -> &'static str {
355        if !self.require_authentication {
356            "Minimal - Authentication disabled"
357        } else if !self.require_https {
358            "Low - HTTP allowed"
359        } else if self.auto_generate_keys {
360            "Medium - Auto-generated keys"
361        } else {
362            "High - Manual key management"
363        }
364    }
365}
366
367impl Default for SecuritySettings {
368    fn default() -> Self {
369        // Safe defaults that work but encourage explicit configuration
370        Self {
371            require_authentication: true,
372            require_https: true,
373            enable_audit_logging: true,
374            jwt_expiry_seconds: 3600, // 1 hour
375            rate_limit: RateLimitConfig::moderate(),
376            cors: CorsConfig::strict(),
377            auto_generate_keys: false,
378            validate_token_audience: true,
379        }
380    }
381}
382
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant, for table-driven checks.
    const ALL_PROFILES: [SecurityProfile; 4] = [
        SecurityProfile::Development,
        SecurityProfile::Staging,
        SecurityProfile::Production,
        SecurityProfile::Custom,
    ];

    #[test]
    fn test_security_profile_parsing() {
        let cases = [
            ("development", SecurityProfile::Development),
            ("dev", SecurityProfile::Development),
            ("staging", SecurityProfile::Staging),
            ("stage", SecurityProfile::Staging),
            ("production", SecurityProfile::Production),
            ("prod", SecurityProfile::Production),
            ("custom", SecurityProfile::Custom),
        ];
        for (input, expected) in cases {
            assert_eq!(SecurityProfile::from_str(input).unwrap(), expected, "input: {input}");
        }

        assert!(SecurityProfile::from_str("invalid").is_err());
        assert!(SecurityProfile::from_str("").is_err());
    }

    #[test]
    fn test_security_profile_parsing_is_case_insensitive() {
        assert_eq!(
            SecurityProfile::from_str("PROD").unwrap(),
            SecurityProfile::Production
        );
        assert_eq!(
            SecurityProfile::from_str("Staging").unwrap(),
            SecurityProfile::Staging
        );
    }

    #[test]
    fn test_security_profile_display() {
        assert_eq!(SecurityProfile::Development.to_string(), "development");
        assert_eq!(SecurityProfile::Staging.to_string(), "staging");
        assert_eq!(SecurityProfile::Production.to_string(), "production");
        assert_eq!(SecurityProfile::Custom.to_string(), "custom");
    }

    #[test]
    fn test_security_profile_display_round_trips_through_from_str() {
        for profile in ALL_PROFILES {
            let parsed: SecurityProfile = profile.to_string().parse().unwrap();
            assert_eq!(parsed, profile);
        }
    }

    #[test]
    fn test_development_profile() {
        let settings = DevelopmentProfile::security_settings();
        assert!(!settings.require_authentication);
        assert!(!settings.require_https);
        assert!(settings.enable_audit_logging);
        assert!(settings.auto_generate_keys);
        assert!(!settings.rate_limit.enabled);
    }

    #[test]
    fn test_staging_profile() {
        let settings = StagingProfile::security_settings();
        assert!(settings.require_authentication);
        assert!(settings.require_https);
        assert!(settings.enable_audit_logging);
        assert!(settings.auto_generate_keys);
        assert!(settings.rate_limit.enabled);
        assert_eq!(settings.jwt_expiry_seconds, 3600);
    }

    #[test]
    fn test_production_profile() {
        let settings = ProductionProfile::security_settings();
        assert!(settings.require_authentication);
        assert!(settings.require_https);
        assert!(settings.enable_audit_logging);
        assert!(!settings.auto_generate_keys); // Manual key management
        assert!(settings.rate_limit.enabled);
        assert_eq!(settings.jwt_expiry_seconds, 900); // 15 minutes
    }

    #[test]
    fn test_rate_limit_configs() {
        let one_minute = std::time::Duration::from_secs(60);

        let permissive = RateLimitConfig::permissive();
        assert!(!permissive.enabled);
        assert_eq!(permissive.max_requests, 10000);
        assert_eq!(permissive.window_duration, one_minute);

        let moderate = RateLimitConfig::moderate();
        assert!(moderate.enabled);
        assert_eq!(moderate.max_requests, 1000);
        assert_eq!(moderate.window_duration, one_minute);

        let strict = RateLimitConfig::strict();
        assert!(strict.enabled);
        assert_eq!(strict.max_requests, 100);
        assert_eq!(strict.window_duration, one_minute);
    }

    #[test]
    fn test_cors_configs() {
        let permissive = CorsConfig::permissive();
        assert_eq!(permissive.allowed_origins, vec!["*"]);
        assert!(!permissive.allow_credentials);

        let localhost = CorsConfig::localhost_only();
        assert!(localhost.allow_credentials);
        assert!(
            localhost
                .allowed_origins
                .contains(&"http://localhost:3000".to_string())
        );

        let strict = CorsConfig::strict();
        assert!(strict.allowed_origins.is_empty());
        assert!(strict.allow_credentials);
    }

    #[test]
    fn test_security_settings_validation() {
        let mut settings = SecuritySettings::default();
        assert!(settings.validate().is_ok());

        // With auto-generated keys, validate() must still apply the expiry and
        // CORS checks below.
        settings.auto_generate_keys = true;
        settings.jwt_expiry_seconds = 30; // Too short (less than 60 seconds)
        assert!(settings.validate().is_err());

        settings.jwt_expiry_seconds = 3600; // Valid again
        assert!(settings.validate().is_ok());

        // Wildcard origin + credentials is rejected when HTTPS is required.
        settings.cors.allowed_origins = vec!["*".to_string()];
        settings.cors.allow_credentials = true;
        settings.require_https = true;
        assert!(settings.validate().is_err());
    }

    #[test]
    fn test_builtin_profiles_pass_validation() {
        for profile in ALL_PROFILES {
            let settings = SecuritySettings::for_profile(&profile);
            assert!(settings.validate().is_ok(), "profile {profile} should validate");
        }
    }

    #[test]
    fn test_security_settings_for_profiles() {
        let dev_settings = SecuritySettings::for_profile(&SecurityProfile::Development);
        assert!(!dev_settings.require_authentication);

        let prod_settings = SecuritySettings::for_profile(&SecurityProfile::Production);
        assert!(prod_settings.require_authentication);
        assert!(prod_settings.require_https);
    }
}