pulseengine_mcp_security_middleware/
utils.rs1use crate::error::{SecurityError, SecurityResult};
4use base64::{Engine as _, engine::general_purpose};
5use rand::{Rng, distributions::Alphanumeric};
6use sha2::{Digest, Sha256};
7use std::time::{SystemTime, UNIX_EPOCH};
8
/// Stateless namespace for the random-value helpers in this module.
///
/// The type has no fields; every helper is an associated function, so it is
/// never necessary to construct a value. The common traits are derived so the
/// type can still be embedded in other `Debug`/`Clone`/`Default` types.
#[derive(Debug, Clone, Copy, Default)]
pub struct SecureRandom;
11
12impl SecureRandom {
13    pub fn bytes(length: usize) -> Vec<u8> {
15        let mut rng = rand::thread_rng();
16        (0..length).map(|_| rng.r#gen()).collect()
17    }
18
19    pub fn string(length: usize) -> String {
21        rand::thread_rng()
22            .sample_iter(&Alphanumeric)
23            .take(length)
24            .map(char::from)
25            .collect()
26    }
27
28    pub fn base64_string(byte_length: usize) -> String {
30        let bytes = Self::bytes(byte_length);
31        general_purpose::STANDARD.encode(bytes)
32    }
33
34    pub fn base64_url_string(byte_length: usize) -> String {
36        let bytes = Self::bytes(byte_length);
37        general_purpose::URL_SAFE_NO_PAD.encode(bytes)
38    }
39}
40
41pub fn generate_api_key() -> String {
55    let random_part = SecureRandom::base64_url_string(32);
56    format!("mcp_{random_part}")
57}
58
59pub fn generate_jwt_secret() -> String {
72    SecureRandom::base64_string(64)
73}
74
75pub fn hash_api_key(api_key: &str) -> String {
80    let mut hasher = Sha256::new();
81    hasher.update(api_key.as_bytes());
82    let result = hasher.finalize();
83    general_purpose::STANDARD.encode(result)
84}
85
86pub fn verify_api_key(api_key: &str, stored_hash: &str) -> bool {
90    let computed_hash = hash_api_key(api_key);
91    computed_hash == stored_hash
92}
93
/// Returns the current time as whole seconds since the Unix epoch.
///
/// # Panics
///
/// Panics with a "Time went backwards" message if the system clock reports
/// a time earlier than the Unix epoch.
pub fn current_timestamp() -> u64 {
    let now = SystemTime::now();
    match now.duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(err) => panic!("Time went backwards: {:?}", err),
    }
}
101
102pub fn validate_api_key_format(key: &str) -> SecurityResult<()> {
104    if !key.starts_with("mcp_") {
105        return Err(SecurityError::invalid_token(
106            "API key must start with 'mcp_'",
107        ));
108    }
109
110    if key.len() < 20 {
111        return Err(SecurityError::invalid_token("API key too short"));
112    }
113
114    if key.len() > 200 {
115        return Err(SecurityError::invalid_token("API key too long"));
116    }
117
118    let key_part = &key[4..]; for c in key_part.chars() {
121        if !c.is_alphanumeric() && c != '-' && c != '_' {
122            return Err(SecurityError::invalid_token(
123                "Invalid characters in API key",
124            ));
125        }
126    }
127
128    Ok(())
129}
130
131pub fn generate_session_id() -> String {
133    format!("sess_{}", SecureRandom::base64_url_string(32))
134}
135
136pub fn generate_request_id() -> String {
138    format!("req_{}", SecureRandom::base64_url_string(16))
139}
140
/// Compares two strings byte-for-byte without stopping at the first
/// difference.
///
/// Inputs of different byte length return `false` immediately. Otherwise the
/// XOR of every byte pair is OR-ed into one accumulator, and the strings are
/// equal exactly when that accumulator ends at zero.
pub fn secure_compare(a: &str, b: &str) -> bool {
    let (lhs, rhs) = (a.as_bytes(), b.as_bytes());
    if lhs.len() != rhs.len() {
        return false;
    }

    let difference = lhs
        .iter()
        .zip(rhs)
        .fold(0u8, |acc, (x, y)| acc | (x ^ y));

    difference == 0
}
154
#[cfg(test)]
mod tests {
    use super::*;

    /// Generated API keys differ from each other, carry the `mcp_` prefix,
    /// and are longer than 20 bytes.
    #[test]
    fn test_generate_api_key() {
        let first = generate_api_key();
        let second = generate_api_key();

        assert_ne!(first, second);

        for key in [&first, &second] {
            assert!(key.starts_with("mcp_"));
            assert!(key.len() > 20);
        }
    }

    /// Generated JWT secrets differ from each other and are at least 64
    /// bytes long.
    #[test]
    fn test_generate_jwt_secret() {
        let first = generate_jwt_secret();
        let second = generate_jwt_secret();

        assert_ne!(first, second);

        for secret in [&first, &second] {
            assert!(secret.len() >= 64);
        }
    }

    /// A key verifies against its own hash; a different key does not.
    #[test]
    fn test_hash_and_verify_api_key() {
        let genuine = generate_api_key();
        let stored = hash_api_key(&genuine);
        assert!(verify_api_key(&genuine, &stored));

        let impostor = generate_api_key();
        assert!(!verify_api_key(&impostor, &stored));
    }

    /// A freshly generated key is well-formed; obviously malformed ones are
    /// rejected.
    #[test]
    fn test_validate_api_key_format() {
        let generated = generate_api_key();
        assert!(validate_api_key_format(&generated).is_ok());

        for bad in ["invalid", "api_too_short", "mcp_"] {
            assert!(validate_api_key_format(bad).is_err());
        }
    }

    /// Basic equal / unequal / different-length comparisons.
    #[test]
    fn test_secure_compare() {
        assert!(secure_compare("hello", "hello"));

        for (left, right) in [("hello", "world"), ("hello", "hello world"), ("", "hello")] {
            assert!(!secure_compare(left, right));
        }
    }

    /// Session ids differ from each other and carry the `sess_` prefix.
    #[test]
    fn test_session_id_generation() {
        let first = generate_session_id();
        let second = generate_session_id();

        assert_ne!(first, second);
        for id in [&first, &second] {
            assert!(id.starts_with("sess_"));
        }
    }

    /// Request ids differ from each other and carry the `req_` prefix.
    #[test]
    fn test_request_id_generation() {
        let first = generate_request_id();
        let second = generate_request_id();

        assert_ne!(first, second);
        for id in [&first, &second] {
            assert!(id.starts_with("req_"));
        }
    }

    /// A timestamp taken after a short pause is not earlier than the first.
    #[test]
    fn test_current_timestamp() {
        let before = current_timestamp();
        std::thread::sleep(std::time::Duration::from_millis(1));
        let after = current_timestamp();

        assert!(after >= before);
    }

    /// Random bytes and strings have the requested length and differ between
    /// calls.
    #[test]
    fn test_secure_random() {
        let (bytes_a, bytes_b) = (SecureRandom::bytes(32), SecureRandom::bytes(32));
        assert_eq!(bytes_a.len(), 32);
        assert_eq!(bytes_b.len(), 32);
        assert_ne!(bytes_a, bytes_b);

        let (text_a, text_b) = (SecureRandom::string(20), SecureRandom::string(20));
        assert_eq!(text_a.len(), 20);
        assert_eq!(text_b.len(), 20);
        assert_ne!(text_a, text_b);
    }

    /// Lengths 0 and 1 are honoured for both bytes and strings.
    #[test]
    fn test_secure_random_edge_cases() {
        for length in [0usize, 1] {
            assert_eq!(SecureRandom::bytes(length).len(), length);
            assert_eq!(SecureRandom::string(length).len(), length);
        }
    }

    /// Boundary cases for the format validator: missing prefix, length
    /// limits, mixed case, and embedded / leading / trailing whitespace.
    #[test]
    fn test_validate_api_key_format_edge_cases() {
        let rejected = [
            // Empty, tiny, or missing the prefix.
            "",
            "a",
            "ab",
            "abc12345678901234567890",
            // Correct prefix but too short.
            "mcp_abc",
            // Whitespace inside, before, and after the key.
            "mcp_abc def1234567890",
            " mcp_abcdef1234567890",
            "mcp_abcdef1234567890 ",
        ];
        for candidate in rejected {
            assert!(validate_api_key_format(candidate).is_err());
        }

        let accepted = [
            // Exactly the minimum length.
            "mcp_1234567890123456",
            "mcp_123456789012345678",
            "mcp_ABCDEFGHIJ1234567890",
            "mcp_abcdefghij1234567890",
            "mcp_a1B2c3D4e1234567890",
        ];
        for candidate in accepted {
            assert!(validate_api_key_format(candidate).is_ok());
        }
    }

    /// Hashing is deterministic, and verification agrees with it.
    #[test]
    fn test_hash_and_verify_consistency() {
        let api_key = "test_key_12345";
        let first = hash_api_key(api_key);
        let second = hash_api_key(api_key);

        assert_eq!(first, second);

        for stored in [&first, &second] {
            assert!(verify_api_key(api_key, stored));
        }
        for stored in [&first, &second] {
            assert!(!verify_api_key("wrong_key", stored));
        }
    }

    /// Empty strings and length mismatches in either direction.
    #[test]
    fn test_secure_compare_edge_cases() {
        assert!(secure_compare("", ""));
        assert!(!secure_compare("", "a"));
        assert!(!secure_compare("a", ""));

        assert!(secure_compare("hello", "hello"));

        let (short, long) = ("short", "longer_string");
        assert!(!secure_compare(short, long));
        assert!(!secure_compare(long, short));
    }

    /// Two back-to-back timestamps are ordered and close together.
    #[test]
    fn test_timestamp_consistency() {
        let earlier = current_timestamp();
        let later = current_timestamp();

        assert!(later >= earlier);
        assert!(later - earlier < 1000);
    }
}