pulseengine_mcp_security_middleware/
utils.rs

1//! Utility functions for security operations
2
3use crate::error::{SecurityError, SecurityResult};
4use base64::{Engine as _, engine::general_purpose};
5use rand::{Rng, distributions::Alphanumeric};
6use sha2::{Digest, Sha256};
7use std::time::{SystemTime, UNIX_EPOCH};
8
9/// Secure random generator for security operations
10pub struct SecureRandom;
11
12impl SecureRandom {
13    /// Generate cryptographically secure random bytes
14    pub fn bytes(length: usize) -> Vec<u8> {
15        let mut rng = rand::thread_rng();
16        (0..length).map(|_| rng.r#gen()).collect()
17    }
18
19    /// Generate cryptographically secure random string
20    pub fn string(length: usize) -> String {
21        rand::thread_rng()
22            .sample_iter(&Alphanumeric)
23            .take(length)
24            .map(char::from)
25            .collect()
26    }
27
28    /// Generate base64-encoded random string
29    pub fn base64_string(byte_length: usize) -> String {
30        let bytes = Self::bytes(byte_length);
31        general_purpose::STANDARD.encode(bytes)
32    }
33
34    /// Generate URL-safe base64-encoded random string
35    pub fn base64_url_string(byte_length: usize) -> String {
36        let bytes = Self::bytes(byte_length);
37        general_purpose::URL_SAFE_NO_PAD.encode(bytes)
38    }
39}
40
41/// Generate a secure API key
42///
43/// API keys are prefixed with "mcp_" and contain 32 bytes of random data
44/// encoded in base64-url format for URL safety.
45///
46/// # Example
47/// ```rust
48/// use pulseengine_mcp_security_middleware::generate_api_key;
49///
50/// let api_key = generate_api_key();
51/// assert!(api_key.starts_with("mcp_"));
52/// assert!(api_key.len() > 20); // At least 20 characters
53/// ```
54pub fn generate_api_key() -> String {
55    let random_part = SecureRandom::base64_url_string(32);
56    format!("mcp_{random_part}")
57}
58
59/// Generate a secure JWT secret
60///
61/// JWT secrets are 64 bytes of cryptographically secure random data
62/// encoded in base64 format.
63///
64/// # Example
65/// ```rust
66/// use pulseengine_mcp_security_middleware::generate_jwt_secret;
67///
68/// let secret = generate_jwt_secret();
69/// assert!(secret.len() >= 64); // At least 64 characters for security
70/// ```
71pub fn generate_jwt_secret() -> String {
72    SecureRandom::base64_string(64)
73}
74
75/// Hash an API key for storage
76///
77/// Uses SHA-256 to hash API keys for secure storage. The original key
78/// should never be stored, only the hash.
79pub fn hash_api_key(api_key: &str) -> String {
80    let mut hasher = Sha256::new();
81    hasher.update(api_key.as_bytes());
82    let result = hasher.finalize();
83    general_purpose::STANDARD.encode(result)
84}
85
86/// Verify an API key against its hash
87///
88/// Compares the hash of the provided API key with the stored hash.
89pub fn verify_api_key(api_key: &str, stored_hash: &str) -> bool {
90    let computed_hash = hash_api_key(api_key);
91    computed_hash == stored_hash
92}
93
/// Get current Unix timestamp
///
/// Returns the whole number of seconds elapsed since `UNIX_EPOCH`.
///
/// # Panics
/// Panics with "Time went backwards" if the system clock reports a time
/// earlier than the Unix epoch.
pub fn current_timestamp() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("Time went backwards");
    since_epoch.as_secs()
}
101
102/// Validate that a string looks like a valid API key
103pub fn validate_api_key_format(key: &str) -> SecurityResult<()> {
104    if !key.starts_with("mcp_") {
105        return Err(SecurityError::invalid_token(
106            "API key must start with 'mcp_'",
107        ));
108    }
109
110    if key.len() < 20 {
111        return Err(SecurityError::invalid_token("API key too short"));
112    }
113
114    if key.len() > 200 {
115        return Err(SecurityError::invalid_token("API key too long"));
116    }
117
118    // Check that it contains only valid base64-url characters after prefix
119    let key_part = &key[4..]; // Skip "mcp_" prefix
120    for c in key_part.chars() {
121        if !c.is_alphanumeric() && c != '-' && c != '_' {
122            return Err(SecurityError::invalid_token(
123                "Invalid characters in API key",
124            ));
125        }
126    }
127
128    Ok(())
129}
130
131/// Generate a secure session ID
132pub fn generate_session_id() -> String {
133    format!("sess_{}", SecureRandom::base64_url_string(32))
134}
135
136/// Generate a secure request ID for tracing
137pub fn generate_request_id() -> String {
138    format!("req_{}", SecureRandom::base64_url_string(16))
139}
140
/// Safe comparison function to prevent timing attacks
///
/// Compares `a` and `b` byte-by-byte without stopping at the first
/// difference: every byte pair is XOR-ed and the results are OR-ed together.
/// Strings of different lengths return `false` immediately, so only their
/// length (not their contents) can be inferred from timing.
pub fn secure_compare(a: &str, b: &str) -> bool {
    let (lhs, rhs) = (a.as_bytes(), b.as_bytes());
    if lhs.len() != rhs.len() {
        return false;
    }

    let diff = lhs
        .iter()
        .zip(rhs)
        .fold(0u8, |acc, (x, y)| acc | (x ^ y));
    diff == 0
}
154
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generate_api_key() {
        let key1 = generate_api_key();
        let key2 = generate_api_key();

        // Keys should be different
        assert_ne!(key1, key2);

        // Both should start with prefix
        assert!(key1.starts_with("mcp_"));
        assert!(key2.starts_with("mcp_"));

        // Should have reasonable length
        assert!(key1.len() > 20);
        assert!(key2.len() > 20);
    }

    #[test]
    fn test_generate_jwt_secret() {
        let secret1 = generate_jwt_secret();
        let secret2 = generate_jwt_secret();

        // Secrets should be different
        assert_ne!(secret1, secret2);

        // Should have sufficient length for security
        assert!(secret1.len() >= 64);
        assert!(secret2.len() >= 64);
    }

    #[test]
    fn test_hash_and_verify_api_key() {
        let api_key = generate_api_key();
        let hash = hash_api_key(&api_key);

        // Verification should succeed
        assert!(verify_api_key(&api_key, &hash));

        // Wrong key should fail
        let wrong_key = generate_api_key();
        assert!(!verify_api_key(&wrong_key, &hash));
    }

    #[test]
    fn test_validate_api_key_format() {
        // Valid key should pass
        let valid_key = generate_api_key();
        assert!(validate_api_key_format(&valid_key).is_ok());

        // Invalid keys should fail
        assert!(validate_api_key_format("invalid").is_err());
        assert!(validate_api_key_format("api_too_short").is_err());
        assert!(validate_api_key_format("mcp_").is_err());
    }

    #[test]
    fn test_secure_compare() {
        assert!(secure_compare("hello", "hello"));
        assert!(!secure_compare("hello", "world"));
        assert!(!secure_compare("hello", "hello world"));
        assert!(!secure_compare("", "hello"));
    }

    #[test]
    fn test_session_id_generation() {
        let id1 = generate_session_id();
        let id2 = generate_session_id();

        assert_ne!(id1, id2);
        assert!(id1.starts_with("sess_"));
        assert!(id2.starts_with("sess_"));
    }

    #[test]
    fn test_request_id_generation() {
        let id1 = generate_request_id();
        let id2 = generate_request_id();

        assert_ne!(id1, id2);
        assert!(id1.starts_with("req_"));
        assert!(id2.starts_with("req_"));
    }

    #[test]
    fn test_current_timestamp() {
        let ts1 = current_timestamp();
        std::thread::sleep(std::time::Duration::from_millis(1));
        let ts2 = current_timestamp();

        assert!(ts2 >= ts1);
    }

    #[test]
    fn test_secure_random() {
        let bytes1 = SecureRandom::bytes(32);
        let bytes2 = SecureRandom::bytes(32);

        assert_eq!(bytes1.len(), 32);
        assert_eq!(bytes2.len(), 32);
        assert_ne!(bytes1, bytes2);

        let string1 = SecureRandom::string(20);
        let string2 = SecureRandom::string(20);

        assert_eq!(string1.len(), 20);
        assert_eq!(string2.len(), 20);
        assert_ne!(string1, string2);
    }

    #[test]
    fn test_secure_random_edge_cases() {
        // Test zero-length
        let bytes = SecureRandom::bytes(0);
        assert_eq!(bytes.len(), 0);

        let string = SecureRandom::string(0);
        assert_eq!(string.len(), 0);

        // Test single byte/char
        let bytes = SecureRandom::bytes(1);
        assert_eq!(bytes.len(), 1);

        let string = SecureRandom::string(1);
        assert_eq!(string.len(), 1);
    }

    #[test]
    fn test_validate_api_key_format_edge_cases() {
        // Test boundaries
        assert!(validate_api_key_format("").is_err());
        assert!(validate_api_key_format("a").is_err()); // Too short
        assert!(validate_api_key_format("ab").is_err()); // Too short

        // Test without proper prefix
        assert!(validate_api_key_format("abc12345678901234567890").is_err());

        // Test with proper prefix but too short
        assert!(validate_api_key_format("mcp_abc").is_err());

        // Test exactly minimum length with prefix
        assert!(validate_api_key_format("mcp_1234567890123456").is_ok());

        // Test with different character types
        assert!(validate_api_key_format("mcp_123456789012345678").is_ok());
        assert!(validate_api_key_format("mcp_ABCDEFGHIJ1234567890").is_ok());
        assert!(validate_api_key_format("mcp_abcdefghij1234567890").is_ok());
        assert!(validate_api_key_format("mcp_a1B2c3D4e1234567890").is_ok());

        // Test whitespace
        assert!(validate_api_key_format("mcp_abc def1234567890").is_err());
        assert!(validate_api_key_format(" mcp_abcdef1234567890").is_err());
        assert!(validate_api_key_format("mcp_abcdef1234567890 ").is_err());
    }

    #[test]
    fn test_hash_and_verify_consistency() {
        let api_key = "test_key_12345";
        let hash1 = hash_api_key(api_key);
        let hash2 = hash_api_key(api_key);

        // Hashes should be the same (deterministic)
        assert_eq!(hash1, hash2);

        // Both should verify correctly
        assert!(verify_api_key(api_key, &hash1));
        assert!(verify_api_key(api_key, &hash2));

        // Wrong key should not verify
        assert!(!verify_api_key("wrong_key", &hash1));
        assert!(!verify_api_key("wrong_key", &hash2));
    }

    #[test]
    fn test_secure_compare_edge_cases() {
        // Test empty strings
        assert!(secure_compare("", ""));
        assert!(!secure_compare("", "a"));
        assert!(!secure_compare("a", ""));

        // Test same content
        assert!(secure_compare("hello", "hello"));

        // Test different lengths
        assert!(!secure_compare("short", "longer_string"));
        assert!(!secure_compare("longer_string", "short"));
    }

    #[test]
    fn test_timestamp_consistency() {
        let time1 = current_timestamp();
        let time2 = current_timestamp();

        // Should be very close in time
        assert!(time2 >= time1);
        // Timestamps are whole seconds; two back-to-back calls may straddle a
        // second boundary, so allow a difference of at most one second.
        assert!(time2 - time1 <= 1);
    }
}