pulseengine_mcp_security_middleware/
utils.rs1use crate::error::{SecurityError, SecurityResult};
4use base64::{Engine as _, engine::general_purpose};
5use rand::{Rng, distributions::Alphanumeric};
6use sha2::{Digest, Sha256};
7use std::time::{SystemTime, UNIX_EPOCH};
8
9pub struct SecureRandom;
11
12impl SecureRandom {
13 pub fn bytes(length: usize) -> Vec<u8> {
15 let mut rng = rand::thread_rng();
16 (0..length).map(|_| rng.r#gen()).collect()
17 }
18
19 pub fn string(length: usize) -> String {
21 rand::thread_rng()
22 .sample_iter(&Alphanumeric)
23 .take(length)
24 .map(char::from)
25 .collect()
26 }
27
28 pub fn base64_string(byte_length: usize) -> String {
30 let bytes = Self::bytes(byte_length);
31 general_purpose::STANDARD.encode(bytes)
32 }
33
34 pub fn base64_url_string(byte_length: usize) -> String {
36 let bytes = Self::bytes(byte_length);
37 general_purpose::URL_SAFE_NO_PAD.encode(bytes)
38 }
39}
40
41pub fn generate_api_key() -> String {
55 let random_part = SecureRandom::base64_url_string(32);
56 format!("mcp_{random_part}")
57}
58
59pub fn generate_jwt_secret() -> String {
72 SecureRandom::base64_string(64)
73}
74
75pub fn hash_api_key(api_key: &str) -> String {
80 let mut hasher = Sha256::new();
81 hasher.update(api_key.as_bytes());
82 let result = hasher.finalize();
83 general_purpose::STANDARD.encode(result)
84}
85
86pub fn verify_api_key(api_key: &str, stored_hash: &str) -> bool {
90 let computed_hash = hash_api_key(api_key);
91 computed_hash == stored_hash
92}
93
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// # Panics
///
/// Panics with "Time went backwards" if the system clock reports a time
/// earlier than the Unix epoch.
pub fn current_timestamp() -> u64 {
    let now = SystemTime::now();
    let since_epoch = now
        .duration_since(UNIX_EPOCH)
        .expect("Time went backwards");
    since_epoch.as_secs()
}
101
102pub fn validate_api_key_format(key: &str) -> SecurityResult<()> {
104 if !key.starts_with("mcp_") {
105 return Err(SecurityError::invalid_token(
106 "API key must start with 'mcp_'",
107 ));
108 }
109
110 if key.len() < 20 {
111 return Err(SecurityError::invalid_token("API key too short"));
112 }
113
114 if key.len() > 200 {
115 return Err(SecurityError::invalid_token("API key too long"));
116 }
117
118 let key_part = &key[4..]; for c in key_part.chars() {
121 if !c.is_alphanumeric() && c != '-' && c != '_' {
122 return Err(SecurityError::invalid_token(
123 "Invalid characters in API key",
124 ));
125 }
126 }
127
128 Ok(())
129}
130
131pub fn generate_session_id() -> String {
133 format!("sess_{}", SecureRandom::base64_url_string(32))
134}
135
136pub fn generate_request_id() -> String {
138 format!("req_{}", SecureRandom::base64_url_string(16))
139}
140
/// Compares two strings for byte-wise equality without stopping at the first
/// differing byte.
///
/// Inputs of different byte length are rejected immediately. For
/// equal-length inputs, the XOR of every byte pair is OR-ed into a single
/// accumulator, so the whole input is always walked; the strings are equal
/// exactly when the accumulator stays zero.
pub fn secure_compare(a: &str, b: &str) -> bool {
    let (lhs, rhs) = (a.as_bytes(), b.as_bytes());
    if lhs.len() != rhs.len() {
        return false;
    }

    let difference = lhs
        .iter()
        .zip(rhs)
        .fold(0u8, |acc, (x, y)| acc | (x ^ y));

    difference == 0
}
154
#[cfg(test)]
mod tests {
    use super::*;

    /// Generated API keys are distinct, prefixed with `mcp_`, and longer than 20 bytes.
    #[test]
    fn test_generate_api_key() {
        let keys = [generate_api_key(), generate_api_key()];

        assert_ne!(keys[0], keys[1]);

        for key in &keys {
            assert!(key.starts_with("mcp_"));
        }

        for key in &keys {
            assert!(key.len() > 20);
        }
    }

    /// Generated JWT secrets are distinct and at least 64 bytes long.
    #[test]
    fn test_generate_jwt_secret() {
        let secrets = [generate_jwt_secret(), generate_jwt_secret()];

        assert_ne!(secrets[0], secrets[1]);

        for secret in &secrets {
            assert!(secret.len() >= 64);
        }
    }

    /// A key verifies against its own hash and a different key does not.
    #[test]
    fn test_hash_and_verify_api_key() {
        let issued_key = generate_api_key();
        let stored_hash = hash_api_key(&issued_key);

        assert!(verify_api_key(&issued_key, &stored_hash));

        let other_key = generate_api_key();
        assert!(!verify_api_key(&other_key, &stored_hash));
    }

    /// Generated keys pass format validation; obviously malformed ones fail.
    #[test]
    fn test_validate_api_key_format() {
        let generated = generate_api_key();
        assert!(validate_api_key_format(&generated).is_ok());

        for rejected in ["invalid", "api_too_short", "mcp_"] {
            assert!(
                validate_api_key_format(rejected).is_err(),
                "expected {rejected:?} to be rejected"
            );
        }
    }

    /// Equal strings compare equal; differing content or length does not.
    #[test]
    fn test_secure_compare() {
        assert!(secure_compare("hello", "hello"));

        let mismatches = [("hello", "world"), ("hello", "hello world"), ("", "hello")];
        for (lhs, rhs) in mismatches {
            assert!(
                !secure_compare(lhs, rhs),
                "expected {lhs:?} and {rhs:?} to differ"
            );
        }
    }

    /// Session ids are distinct and prefixed with `sess_`.
    #[test]
    fn test_session_id_generation() {
        let ids = [generate_session_id(), generate_session_id()];

        assert_ne!(ids[0], ids[1]);
        for id in &ids {
            assert!(id.starts_with("sess_"));
        }
    }

    /// Request ids are distinct and prefixed with `req_`.
    #[test]
    fn test_request_id_generation() {
        let ids = [generate_request_id(), generate_request_id()];

        assert_ne!(ids[0], ids[1]);
        for id in &ids {
            assert!(id.starts_with("req_"));
        }
    }

    /// A later timestamp is never smaller than an earlier one.
    #[test]
    fn test_current_timestamp() {
        let before = current_timestamp();
        std::thread::sleep(std::time::Duration::from_millis(1));
        let after = current_timestamp();

        assert!(after >= before);
    }

    /// Random bytes and strings have the requested length and differ between calls.
    #[test]
    fn test_secure_random() {
        let (first_bytes, second_bytes) = (SecureRandom::bytes(32), SecureRandom::bytes(32));

        assert_eq!(first_bytes.len(), 32);
        assert_eq!(second_bytes.len(), 32);
        assert_ne!(first_bytes, second_bytes);

        let (first_string, second_string) = (SecureRandom::string(20), SecureRandom::string(20));

        assert_eq!(first_string.len(), 20);
        assert_eq!(second_string.len(), 20);
        assert_ne!(first_string, second_string);
    }

    /// Lengths of zero and one are honoured exactly.
    #[test]
    fn test_secure_random_edge_cases() {
        for length in [0, 1] {
            assert_eq!(SecureRandom::bytes(length).len(), length);
            assert_eq!(SecureRandom::string(length).len(), length);
        }
    }

    /// Table of boundary inputs for the format validator and the verdict each must get.
    #[test]
    fn test_validate_api_key_format_edge_cases() {
        let cases = [
            // Missing or wrong prefix.
            ("", false),
            ("a", false),
            ("ab", false),
            ("abc12345678901234567890", false),
            // Correct prefix but too short.
            ("mcp_abc", false),
            // Exactly at the minimum length.
            ("mcp_1234567890123456", true),
            // Comfortably valid keys.
            ("mcp_123456789012345678", true),
            ("mcp_ABCDEFGHIJ1234567890", true),
            ("mcp_abcdefghij1234567890", true),
            ("mcp_a1B2c3D4e1234567890", true),
            // Whitespace anywhere is rejected.
            ("mcp_abc def1234567890", false),
            (" mcp_abcdef1234567890", false),
            ("mcp_abcdef1234567890 ", false),
        ];

        for (input, should_pass) in cases {
            assert_eq!(
                validate_api_key_format(input).is_ok(),
                should_pass,
                "unexpected verdict for {input:?}"
            );
        }
    }

    /// Hashing is deterministic and verification agrees with it.
    #[test]
    fn test_hash_and_verify_consistency() {
        let api_key = "test_key_12345";
        let hashes = [hash_api_key(api_key), hash_api_key(api_key)];

        assert_eq!(hashes[0], hashes[1]);

        for hash in &hashes {
            assert!(verify_api_key(api_key, hash));
        }

        for hash in &hashes {
            assert!(!verify_api_key("wrong_key", hash));
        }
    }

    /// Empty inputs and length mismatches in either direction.
    #[test]
    fn test_secure_compare_edge_cases() {
        let cases = [
            ("", "", true),
            ("", "a", false),
            ("a", "", false),
            ("hello", "hello", true),
            ("short", "longer_string", false),
            ("longer_string", "short", false),
        ];

        for (lhs, rhs, expected) in cases {
            assert_eq!(
                secure_compare(lhs, rhs),
                expected,
                "unexpected result comparing {lhs:?} with {rhs:?}"
            );
        }
    }

    /// Back-to-back timestamps are ordered and close together.
    #[test]
    fn test_timestamp_consistency() {
        let first = current_timestamp();
        let second = current_timestamp();

        assert!(second >= first);

        let gap = second - first;
        assert!(gap < 1000);
    }
}