pulseengine_mcp_security_middleware/
lib.rs1pub mod auth;
86pub mod config;
87pub mod error;
88pub mod middleware;
89pub mod profiles;
90pub mod utils;
91
92pub use auth::{ApiKeyValidator, AuthContext, TokenValidator};
94pub use config::SecurityConfig;
95pub use error::{SecurityError, SecurityResult};
96pub use middleware::{SecurityMiddleware, mcp_auth_middleware, mcp_rate_limit_middleware};
97pub use profiles::SecurityProfile;
98pub use profiles::{DevelopmentProfile, ProductionProfile, StagingProfile};
99pub use utils::{SecureRandom, generate_api_key, generate_jwt_secret};
100
101pub const VERSION: &str = env!("CARGO_PKG_VERSION");
103
104pub fn dev_security() -> SecurityConfig {
116 SecurityConfig::development()
117}
118
119pub fn prod_security() -> SecurityConfig {
131 SecurityConfig::production()
132}
133
134pub fn env_security() -> SecurityResult<SecurityConfig> {
147 SecurityConfig::from_env()
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::ffi::OsString;

    /// Environment variable from which `env_security` takes the profile name.
    const PROFILE_VAR: &str = "MCP_SECURITY_PROFILE";

    /// Sets [`PROFILE_VAR`] while alive and puts back its previous value (or
    /// removes it) when dropped.
    ///
    /// Doing the restore in `Drop` means a failing assertion, which unwinds,
    /// still cleans up instead of leaking the test's value into later tests or
    /// clobbering a value the developer had set in their shell.
    struct ProfileVarGuard {
        previous: Option<OsString>,
    }

    impl ProfileVarGuard {
        #[must_use = "dropping the guard immediately restores the variable"]
        fn set(value: &str) -> Self {
            let previous = env::var_os(PROFILE_VAR);
            // SAFETY: `set_var` is unsafe because another thread may read the
            // environment concurrently. The guard is only used from
            // `#[serial_test::serial]` tests, so env mutations never overlap.
            // NOTE(review): the non-serial tests in this module may still run
            // in parallel; this assumes none of them read the environment.
            unsafe { env::set_var(PROFILE_VAR, value) };
            Self { previous }
        }
    }

    impl Drop for ProfileVarGuard {
        fn drop(&mut self) {
            // SAFETY: same reasoning as in `ProfileVarGuard::set`.
            unsafe {
                match self.previous.take() {
                    Some(value) => env::set_var(PROFILE_VAR, value),
                    None => env::remove_var(PROFILE_VAR),
                }
            }
        }
    }

    #[test]
    fn test_dev_security_creation() {
        let config = dev_security();
        assert_eq!(config.profile, SecurityProfile::Development);
    }

    #[test]
    fn test_prod_security_creation() {
        let config = prod_security();
        assert_eq!(config.profile, SecurityProfile::Production);
    }

    #[test]
    fn test_version_format() {
        // Cargo requires package versions to be SemVer, so whatever precedes a
        // pre-release (`-`) or build-metadata (`+`) suffix is MAJOR.MINOR.PATCH.
        let core = VERSION.split(['-', '+']).next().unwrap_or(VERSION);
        let parts: Vec<&str> = core.split('.').collect();
        assert_eq!(parts.len(), 3, "VERSION should be MAJOR.MINOR.PATCH: {VERSION}");
        assert!(
            parts
                .iter()
                .all(|part| !part.is_empty() && part.chars().all(|c| c.is_ascii_digit())),
            "VERSION components should be numeric: {VERSION}"
        );
    }

    #[test]
    #[serial_test::serial]
    fn test_env_security_with_invalid_profile() {
        let _guard = ProfileVarGuard::set("invalid");
        assert!(env_security().is_err(), "Should fail with invalid profile");
    }

    #[test]
    #[serial_test::serial]
    fn test_env_security_with_valid_profiles() {
        let cases = [
            ("development", SecurityProfile::Development),
            ("staging", SecurityProfile::Staging),
            ("production", SecurityProfile::Production),
        ];

        for (name, expected) in cases {
            // Scoped to one iteration: the guard restores the variable before
            // the next profile is set, even if an assertion below fails.
            let _guard = ProfileVarGuard::set(name);
            match env_security() {
                Ok(config) => assert_eq!(config.profile, expected, "Wrong profile for {name}"),
                Err(err) => panic!("Should succeed with profile {name}: {err}"),
            }
        }
    }

    #[test]
    fn test_version_constant() {
        assert!(VERSION.contains('.'), "Version should contain dots");
        assert!(
            VERSION.chars().any(char::is_numeric),
            "Version should contain numbers"
        );
    }

    #[test]
    fn test_module_exports() {
        // Resolve everything through crate-root paths on purpose: this test
        // exists to catch a re-export being dropped from `lib.rs`.
        let _config: crate::SecurityConfig = crate::dev_security();
        let _prod_config: crate::SecurityConfig = crate::prod_security();

        let _dev_profile = crate::SecurityProfile::Development;
        let _staging_profile = crate::SecurityProfile::Staging;
        let _prod_profile = crate::SecurityProfile::Production;
    }

    #[test]
    fn test_error_constructors() {
        let cases = [
            (SecurityError::config("test config error"), "test config error"),
            (SecurityError::auth("test auth error"), "test auth error"),
            (SecurityError::authz("test authz error"), "test authz error"),
            (SecurityError::invalid_token("test token error"), "test token error"),
            (SecurityError::jwt("test jwt error"), "test jwt error"),
            (SecurityError::random("test random error"), "test random error"),
            (SecurityError::crypto("test crypto error"), "test crypto error"),
            (SecurityError::http("test http error"), "test http error"),
            (SecurityError::internal("test internal error"), "test internal error"),
        ];

        for (err, message) in cases {
            let rendered = err.to_string();
            assert!(
                rendered.contains(message),
                "{rendered:?} should contain {message:?}"
            );
        }
    }

    #[test]
    fn test_security_config_methods() {
        let dev_config = SecurityConfig::development();
        assert_eq!(dev_config.profile, SecurityProfile::Development);

        let staging_config = SecurityConfig::staging();
        assert_eq!(staging_config.profile, SecurityProfile::Staging);

        let prod_config = SecurityConfig::production();
        assert_eq!(prod_config.profile, SecurityProfile::Production);

        let config = SecurityConfig::development()
            .with_api_key("test_key")
            .with_jwt_secret("test_secret")
            .with_jwt_issuer("test_issuer")
            .with_jwt_audience("test_audience");

        // `as_deref` turns a missing value into a readable `None` mismatch
        // instead of an `unwrap` panic.
        assert_eq!(config.api_key.as_deref(), Some("test_key"));
        assert_eq!(config.jwt_secret.as_deref(), Some("test_secret"));
        assert_eq!(config.jwt_issuer, "test_issuer");
        assert_eq!(config.jwt_audience, "test_audience");
    }

    #[test]
    fn test_additional_utility_functions() {
        use crate::utils::{generate_request_id, generate_session_id};

        let session1 = generate_session_id();
        let session2 = generate_session_id();
        assert_ne!(session1, session2);
        assert!(session1.len() > 10);

        let request1 = generate_request_id();
        let request2 = generate_request_id();
        assert_ne!(request1, request2);
        assert!(request1.len() > 10);

        let b64_str1 = SecureRandom::base64_string(32);
        let b64_str2 = SecureRandom::base64_string(32);
        assert_ne!(b64_str1, b64_str2);
        assert!(b64_str1.len() > 40);

        // The URL-safe variant must not use the standard alphabet's `+` and `/`.
        let b64_url_str1 = SecureRandom::base64_url_string(32);
        let b64_url_str2 = SecureRandom::base64_url_string(32);
        assert_ne!(b64_url_str1, b64_url_str2);
        assert!(b64_url_str1.len() > 40);
        assert!(!b64_url_str1.contains('+'));
        assert!(!b64_url_str1.contains('/'));
    }
}