1use crate::error::{SecurityError, SecurityResult};
4use serde::{Deserialize, Serialize};
5use std::time::Duration;
6
7#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
9#[serde(rename_all = "lowercase")]
10pub enum SecurityProfile {
11 Development,
13
14 Staging,
16
17 Production,
19
20 Custom,
22}
23
24impl SecurityProfile {
25 #[allow(clippy::should_implement_trait)]
27 pub fn from_str(s: &str) -> SecurityResult<Self> {
28 match s.to_lowercase().as_str() {
29 "development" | "dev" => Ok(Self::Development),
30 "staging" | "stage" => Ok(Self::Staging),
31 "production" | "prod" => Ok(Self::Production),
32 "custom" => Ok(Self::Custom),
33 _ => Err(SecurityError::config(format!(
34 "Invalid security profile: {s}"
35 ))),
36 }
37 }
38
39 pub fn as_str(&self) -> &'static str {
41 match self {
42 Self::Development => "development",
43 Self::Staging => "staging",
44 Self::Production => "production",
45 Self::Custom => "custom",
46 }
47 }
48}
49
50impl std::str::FromStr for SecurityProfile {
51 type Err = SecurityError;
52
53 fn from_str(s: &str) -> Result<Self, Self::Err> {
54 Self::from_str(s)
55 }
56}
57
58impl std::fmt::Display for SecurityProfile {
59 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 write!(f, "{}", self.as_str())
61 }
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct RateLimitConfig {
67 pub max_requests: u32,
69
70 pub window_duration: Duration,
72
73 pub enabled: bool,
75}
76
77impl RateLimitConfig {
78 pub fn permissive() -> Self {
80 Self {
81 max_requests: 10000,
82 window_duration: Duration::from_secs(60),
83 enabled: false,
84 }
85 }
86
87 pub fn moderate() -> Self {
89 Self {
90 max_requests: 1000,
91 window_duration: Duration::from_secs(60),
92 enabled: true,
93 }
94 }
95
96 pub fn strict() -> Self {
98 Self {
99 max_requests: 100,
100 window_duration: Duration::from_secs(60),
101 enabled: true,
102 }
103 }
104}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct CorsConfig {
109 pub allowed_origins: Vec<String>,
111
112 pub allow_credentials: bool,
114
115 pub allowed_methods: Vec<String>,
117
118 pub allowed_headers: Vec<String>,
120
121 pub enabled: bool,
123}
124
125impl CorsConfig {
126 pub fn permissive() -> Self {
128 Self {
129 allowed_origins: vec!["*".to_string()],
130 allow_credentials: false,
131 allowed_methods: vec![
132 "GET".to_string(),
133 "POST".to_string(),
134 "PUT".to_string(),
135 "DELETE".to_string(),
136 "OPTIONS".to_string(),
137 ],
138 allowed_headers: vec!["*".to_string()],
139 enabled: true,
140 }
141 }
142
143 pub fn localhost_only() -> Self {
145 Self {
146 allowed_origins: vec![
147 "http://localhost:3000".to_string(),
148 "http://localhost:3001".to_string(),
149 "http://localhost:8080".to_string(),
150 "http://127.0.0.1:3000".to_string(),
151 "http://127.0.0.1:3001".to_string(),
152 "http://127.0.0.1:8080".to_string(),
153 ],
154 allow_credentials: true,
155 allowed_methods: vec!["GET".to_string(), "POST".to_string(), "OPTIONS".to_string()],
156 allowed_headers: vec![
157 "authorization".to_string(),
158 "content-type".to_string(),
159 "x-request-id".to_string(),
160 ],
161 enabled: true,
162 }
163 }
164
165 pub fn strict() -> Self {
167 Self {
168 allowed_origins: Vec::new(), allow_credentials: true,
170 allowed_methods: vec!["GET".to_string(), "POST".to_string(), "OPTIONS".to_string()],
171 allowed_headers: vec!["authorization".to_string(), "content-type".to_string()],
172 enabled: true,
173 }
174 }
175}
176
177pub struct DevelopmentProfile;
179
180impl DevelopmentProfile {
181 pub fn security_settings() -> SecuritySettings {
183 SecuritySettings {
184 require_authentication: false, require_https: false,
186 enable_audit_logging: true, jwt_expiry_seconds: 86400, rate_limit: RateLimitConfig::permissive(),
189 cors: CorsConfig::permissive(),
190 auto_generate_keys: true,
191 validate_token_audience: false, }
193 }
194
195 pub fn recommended_env_vars() -> Vec<(&'static str, &'static str)> {
197 vec![
198 ("MCP_SECURITY_PROFILE", "development"),
199 ("MCP_API_KEY", "auto-generate"),
200 ("MCP_JWT_SECRET", "auto-generate"),
201 ("MCP_REQUIRE_HTTPS", "false"),
202 ("MCP_ENABLE_AUDIT_LOG", "true"),
203 ("MCP_CORS_ORIGIN", "*"),
204 ]
205 }
206}
207
208pub struct StagingProfile;
210
211impl StagingProfile {
212 pub fn security_settings() -> SecuritySettings {
214 SecuritySettings {
215 require_authentication: true,
216 require_https: true, enable_audit_logging: true,
218 jwt_expiry_seconds: 3600, rate_limit: RateLimitConfig::moderate(),
220 cors: CorsConfig::localhost_only(),
221 auto_generate_keys: true,
222 validate_token_audience: true,
223 }
224 }
225
226 pub fn recommended_env_vars() -> Vec<(&'static str, &'static str)> {
228 vec![
229 ("MCP_SECURITY_PROFILE", "staging"),
230 ("MCP_API_KEY", "auto-generate"),
231 ("MCP_JWT_SECRET", "auto-generate"),
232 ("MCP_REQUIRE_HTTPS", "true"),
233 ("MCP_ENABLE_AUDIT_LOG", "true"),
234 ("MCP_RATE_LIMIT", "1000/min"),
235 ("MCP_CORS_ORIGIN", "localhost"),
236 ]
237 }
238}
239
240pub struct ProductionProfile;
242
243impl ProductionProfile {
244 pub fn security_settings() -> SecuritySettings {
246 SecuritySettings {
247 require_authentication: true,
248 require_https: true, enable_audit_logging: true,
250 jwt_expiry_seconds: 900, rate_limit: RateLimitConfig::strict(),
252 cors: CorsConfig::strict(),
253 auto_generate_keys: false, validate_token_audience: true,
255 }
256 }
257
258 pub fn required_env_vars() -> Vec<&'static str> {
260 vec![
261 "MCP_API_KEY",
262 "MCP_JWT_SECRET",
263 "MCP_CORS_ORIGIN",
264 "MCP_ALLOWED_ORIGINS",
265 ]
266 }
267
268 pub fn recommended_env_vars() -> Vec<(&'static str, &'static str)> {
270 vec![
271 ("MCP_SECURITY_PROFILE", "production"),
272 ("MCP_REQUIRE_HTTPS", "true"),
273 ("MCP_ENABLE_AUDIT_LOG", "true"),
274 ("MCP_RATE_LIMIT", "100/min"),
275 ("MCP_JWT_EXPIRY", "900"), ]
277 }
278}
279
280#[derive(Debug, Clone, Serialize, Deserialize)]
282pub struct SecuritySettings {
283 pub require_authentication: bool,
285
286 pub require_https: bool,
288
289 pub enable_audit_logging: bool,
291
292 pub jwt_expiry_seconds: u64,
294
295 pub rate_limit: RateLimitConfig,
297
298 pub cors: CorsConfig,
300
301 pub auto_generate_keys: bool,
303
304 pub validate_token_audience: bool,
306}
307
308impl SecuritySettings {
309 pub fn for_profile(profile: &SecurityProfile) -> Self {
311 match profile {
312 SecurityProfile::Development => DevelopmentProfile::security_settings(),
313 SecurityProfile::Staging => StagingProfile::security_settings(),
314 SecurityProfile::Production => ProductionProfile::security_settings(),
315 SecurityProfile::Custom => Self::default(), }
317 }
318
319 pub fn validate(&self) -> SecurityResult<()> {
321 if self.require_authentication && !self.auto_generate_keys {
323 return Ok(()); }
326
327 if self.require_https
328 && self.cors.allowed_origins.contains(&"*".to_string())
329 && self.cors.allow_credentials
330 {
331 return Err(SecurityError::config(
332 "Cannot use wildcard origins with credentials over HTTPS",
333 ));
334 }
335
336 if self.jwt_expiry_seconds > 86400 * 7 {
337 tracing::warn!(
339 "JWT expiry is longer than 1 week, consider shorter expiry for security"
340 );
341 }
342
343 if self.jwt_expiry_seconds < 60 {
344 return Err(SecurityError::config(
346 "JWT expiry cannot be less than 1 minute",
347 ));
348 }
349
350 Ok(())
351 }
352
353 pub fn security_level_description(&self) -> &'static str {
355 if !self.require_authentication {
356 "Minimal - Authentication disabled"
357 } else if !self.require_https {
358 "Low - HTTP allowed"
359 } else if self.auto_generate_keys {
360 "Medium - Auto-generated keys"
361 } else {
362 "High - Manual key management"
363 }
364 }
365}
366
367impl Default for SecuritySettings {
368 fn default() -> Self {
369 Self {
371 require_authentication: true,
372 require_https: true,
373 enable_audit_logging: true,
374 jwt_expiry_seconds: 3600, rate_limit: RateLimitConfig::moderate(),
376 cors: CorsConfig::strict(),
377 auto_generate_keys: false,
378 validate_token_audience: true,
379 }
380 }
381}
382
#[cfg(test)]
mod tests {
    use super::*;

    /// Every profile variant, for table-driven tests.
    fn all_profiles() -> [SecurityProfile; 4] {
        [
            SecurityProfile::Development,
            SecurityProfile::Staging,
            SecurityProfile::Production,
            SecurityProfile::Custom,
        ]
    }

    /// Looks up the value recommended for `key` in a `(name, value)` list.
    fn env_value<'a>(vars: &[(&'a str, &'a str)], key: &str) -> Option<&'a str> {
        vars.iter()
            .find(|(name, _)| *name == key)
            .map(|(_, value)| *value)
    }

    #[test]
    fn test_security_profile_parsing() {
        assert_eq!(
            SecurityProfile::from_str("development").unwrap(),
            SecurityProfile::Development
        );
        assert_eq!(
            SecurityProfile::from_str("dev").unwrap(),
            SecurityProfile::Development
        );
        assert_eq!(
            SecurityProfile::from_str("staging").unwrap(),
            SecurityProfile::Staging
        );
        assert_eq!(
            SecurityProfile::from_str("production").unwrap(),
            SecurityProfile::Production
        );

        assert!(SecurityProfile::from_str("invalid").is_err());
    }

    #[test]
    fn test_security_profile_aliases_ignore_case() {
        assert_eq!(
            SecurityProfile::from_str("stage").unwrap(),
            SecurityProfile::Staging
        );
        assert_eq!(
            SecurityProfile::from_str("prod").unwrap(),
            SecurityProfile::Production
        );
        assert_eq!(
            SecurityProfile::from_str("custom").unwrap(),
            SecurityProfile::Custom
        );
        assert_eq!(
            SecurityProfile::from_str("PRODUCTION").unwrap(),
            SecurityProfile::Production
        );
        assert_eq!(
            SecurityProfile::from_str("Dev").unwrap(),
            SecurityProfile::Development
        );
        assert!(SecurityProfile::from_str("").is_err());
    }

    #[test]
    fn test_security_profile_from_str_trait_round_trips() {
        for profile in all_profiles() {
            let parsed: SecurityProfile = profile.to_string().parse().unwrap();
            assert_eq!(parsed, profile);
            assert_eq!(profile.as_str(), profile.to_string());
        }
        assert!("invalid".parse::<SecurityProfile>().is_err());
    }

    #[test]
    fn test_security_profile_display() {
        assert_eq!(SecurityProfile::Development.to_string(), "development");
        assert_eq!(SecurityProfile::Staging.to_string(), "staging");
        assert_eq!(SecurityProfile::Production.to_string(), "production");
        assert_eq!(SecurityProfile::Custom.to_string(), "custom");
    }

    #[test]
    fn test_development_profile() {
        let settings = DevelopmentProfile::security_settings();
        assert!(!settings.require_authentication);
        assert!(!settings.require_https);
        assert!(settings.enable_audit_logging);
        assert!(settings.auto_generate_keys);
        assert!(!settings.rate_limit.enabled);
    }

    #[test]
    fn test_staging_profile() {
        let settings = StagingProfile::security_settings();
        assert!(settings.require_authentication);
        assert!(settings.require_https);
        assert!(settings.enable_audit_logging);
        assert!(settings.auto_generate_keys);
        assert!(settings.rate_limit.enabled);
        assert_eq!(settings.jwt_expiry_seconds, 3600);
    }

    #[test]
    fn test_production_profile() {
        let settings = ProductionProfile::security_settings();
        assert!(settings.require_authentication);
        assert!(settings.require_https);
        assert!(settings.enable_audit_logging);
        assert!(!settings.auto_generate_keys);
        assert!(settings.rate_limit.enabled);
        assert_eq!(settings.jwt_expiry_seconds, 900);
    }

    #[test]
    fn test_rate_limit_configs() {
        let permissive = RateLimitConfig::permissive();
        assert!(!permissive.enabled);
        assert_eq!(permissive.max_requests, 10000);

        let moderate = RateLimitConfig::moderate();
        assert!(moderate.enabled);
        assert_eq!(moderate.max_requests, 1000);

        let strict = RateLimitConfig::strict();
        assert!(strict.enabled);
        assert_eq!(strict.max_requests, 100);

        for config in [permissive, moderate, strict] {
            assert_eq!(config.window_duration, std::time::Duration::from_secs(60));
        }
    }

    #[test]
    fn test_cors_configs() {
        let permissive = CorsConfig::permissive();
        assert_eq!(permissive.allowed_origins, vec!["*"]);
        assert!(!permissive.allow_credentials);

        let localhost = CorsConfig::localhost_only();
        assert!(localhost.allow_credentials);
        assert!(
            localhost
                .allowed_origins
                .contains(&"http://localhost:3000".to_string())
        );
        assert_eq!(localhost.allowed_origins.len(), 6);

        let strict = CorsConfig::strict();
        assert!(strict.allowed_origins.is_empty());
        assert!(strict.allow_credentials);
    }

    #[test]
    fn test_security_settings_validation() {
        let mut settings = SecuritySettings::default();
        assert!(settings.validate().is_ok());

        settings.auto_generate_keys = true;
        settings.jwt_expiry_seconds = 30;
        assert!(settings.validate().is_err());

        settings.jwt_expiry_seconds = 3600;
        assert!(settings.validate().is_ok());

        settings.cors.allowed_origins = vec!["*".to_string()];
        settings.cors.allow_credentials = true;
        settings.require_https = true;
        assert!(settings.validate().is_err());
    }

    #[test]
    fn test_builtin_profiles_pass_validation() {
        for profile in all_profiles() {
            let settings = SecuritySettings::for_profile(&profile);
            assert!(
                settings.validate().is_ok(),
                "{profile} preset failed validation"
            );
        }
    }

    #[test]
    fn test_security_settings_for_profiles() {
        let dev_settings = SecuritySettings::for_profile(&SecurityProfile::Development);
        assert!(!dev_settings.require_authentication);

        let staging_settings = SecuritySettings::for_profile(&SecurityProfile::Staging);
        assert!(staging_settings.require_authentication);
        assert!(staging_settings.auto_generate_keys);

        let prod_settings = SecuritySettings::for_profile(&SecurityProfile::Production);
        assert!(prod_settings.require_authentication);
        assert!(prod_settings.require_https);

        let custom_settings = SecuritySettings::for_profile(&SecurityProfile::Custom);
        let default_settings = SecuritySettings::default();
        assert_eq!(
            custom_settings.require_authentication,
            default_settings.require_authentication
        );
        assert_eq!(
            custom_settings.auto_generate_keys,
            default_settings.auto_generate_keys
        );
        assert_eq!(
            custom_settings.jwt_expiry_seconds,
            default_settings.jwt_expiry_seconds
        );
    }

    #[test]
    fn test_security_level_description() {
        assert_eq!(
            DevelopmentProfile::security_settings().security_level_description(),
            "Minimal - Authentication disabled"
        );
        assert_eq!(
            StagingProfile::security_settings().security_level_description(),
            "Medium - Auto-generated keys"
        );
        assert_eq!(
            ProductionProfile::security_settings().security_level_description(),
            "High - Manual key management"
        );

        let mut http_allowed = SecuritySettings::default();
        http_allowed.require_https = false;
        assert_eq!(
            http_allowed.security_level_description(),
            "Low - HTTP allowed"
        );
    }

    #[test]
    fn test_recommended_env_vars_match_presets() {
        assert_eq!(
            env_value(&DevelopmentProfile::recommended_env_vars(), "MCP_SECURITY_PROFILE"),
            Some(SecurityProfile::Development.as_str())
        );
        assert_eq!(
            env_value(&StagingProfile::recommended_env_vars(), "MCP_SECURITY_PROFILE"),
            Some(SecurityProfile::Staging.as_str())
        );
        assert_eq!(
            env_value(&ProductionProfile::recommended_env_vars(), "MCP_SECURITY_PROFILE"),
            Some(SecurityProfile::Production.as_str())
        );

        let production_expiry = ProductionProfile::security_settings()
            .jwt_expiry_seconds
            .to_string();
        assert_eq!(
            env_value(&ProductionProfile::recommended_env_vars(), "MCP_JWT_EXPIRY"),
            Some(production_expiry.as_str())
        );

        let required = ProductionProfile::required_env_vars();
        assert!(required.contains(&"MCP_API_KEY"));
        assert!(required.contains(&"MCP_JWT_SECRET"));
    }
}