1use aes_gcm::{
2 aead::{Aead, KeyInit, OsRng},
3 Aes256Gcm, Key, Nonce,
4};
5use argon2::password_hash::{rand_core::RngCore, SaltString};
6use argon2::{Argon2, PasswordHash, PasswordHasher as Argon2Hasher, PasswordVerifier};
7use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
8use std::collections::HashMap;
9use std::sync::RwLock;
10
/// AES-256-GCM encryption provider holding a set of named 32-byte keys, one of which is active.
///
/// `Debug` is implemented by hand (rather than derived) so that key material is never printed:
/// only the registered key ids and the active id are shown.
pub struct EncryptionProvider {
    /// Key id -> raw AES-256 key bytes. Every insertion path (`new`, `add_key`) checks that the
    /// key is exactly 32 bytes long.
    keys: RwLock<HashMap<String, Vec<u8>>>,
    /// Id of the key `encrypt` uses; `set_active_key` only accepts ids present in `keys`.
    active_key_id: RwLock<String>,
}

impl std::fmt::Debug for EncryptionProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("EncryptionProvider");
        // `try_read` instead of `read`: formatting must never block, and a thread that already
        // holds one of these locks would otherwise deadlock or panic while printing.
        match self.keys.try_read() {
            Ok(keys) => {
                // Sorted so the output is deterministic regardless of HashMap iteration order.
                let mut ids: Vec<&str> = keys.keys().map(String::as_str).collect();
                ids.sort_unstable();
                out.field("key_ids", &ids);
            }
            Err(_) => {
                out.field("key_ids", &"<unavailable>");
            }
        }
        match self.active_key_id.try_read() {
            Ok(active) => {
                out.field("active_key_id", &active.as_str());
            }
            Err(_) => {
                out.field("active_key_id", &"<unavailable>");
            }
        }
        out.finish()
    }
}
19
20impl EncryptionProvider {
21 pub fn new(master_key: &[u8]) -> Result<Self, EncryptionError> {
23 if master_key.len() != 32 {
24 return Err(EncryptionError::InvalidKeySize {
25 expected: 32,
26 actual: master_key.len(),
27 });
28 }
29
30 let mut keys = HashMap::new();
31 let key_id = "master".to_string();
32 keys.insert(key_id.clone(), master_key.to_vec());
33
34 Ok(Self {
35 keys: RwLock::new(keys),
36 active_key_id: RwLock::new(key_id),
37 })
38 }
39
40 pub fn generate_key() -> Vec<u8> {
42 let mut key = vec![0u8; 32];
43 OsRng.fill_bytes(&mut key);
44 key
45 }
46
47 pub fn add_key(&self, key_id: String, key: Vec<u8>) -> Result<(), EncryptionError> {
49 if key.len() != 32 {
50 return Err(EncryptionError::InvalidKeySize {
51 expected: 32,
52 actual: key.len(),
53 });
54 }
55
56 let mut keys = self.keys.write().map_err(|_| EncryptionError::LockError)?;
57 keys.insert(key_id, key);
58 Ok(())
59 }
60
61 pub fn set_active_key(&self, key_id: String) -> Result<(), EncryptionError> {
63 let keys = self.keys.read().map_err(|_| EncryptionError::LockError)?;
64 if !keys.contains_key(&key_id) {
65 return Err(EncryptionError::KeyNotFound(key_id));
66 }
67
68 let mut active = self
69 .active_key_id
70 .write()
71 .map_err(|_| EncryptionError::LockError)?;
72 *active = key_id;
73 Ok(())
74 }
75
76 pub fn encrypt(&self, plaintext: &[u8]) -> Result<EncryptedData, EncryptionError> {
78 let active_key_id = self
79 .active_key_id
80 .read()
81 .map_err(|_| EncryptionError::LockError)?;
82 self.encrypt_with_key(&active_key_id, plaintext)
83 }
84
85 pub fn encrypt_with_key(
87 &self,
88 key_id: &str,
89 plaintext: &[u8],
90 ) -> Result<EncryptedData, EncryptionError> {
91 let keys = self.keys.read().map_err(|_| EncryptionError::LockError)?;
92 let key = keys
93 .get(key_id)
94 .ok_or_else(|| EncryptionError::KeyNotFound(key_id.to_string()))?;
95
96 let mut nonce_bytes = [0u8; 12];
98 OsRng.fill_bytes(&mut nonce_bytes);
99 let nonce = Nonce::from_slice(&nonce_bytes);
100
101 let key_obj = Key::<Aes256Gcm>::from_slice(key);
103 let cipher = Aes256Gcm::new(key_obj);
104 let ciphertext = cipher
105 .encrypt(nonce, plaintext)
106 .map_err(|e| EncryptionError::EncryptionFailed(e.to_string()))?;
107
108 Ok(EncryptedData {
109 key_id: key_id.to_string(),
110 nonce: nonce_bytes.to_vec(),
111 ciphertext,
112 })
113 }
114
115 pub fn decrypt(&self, encrypted: &EncryptedData) -> Result<Vec<u8>, EncryptionError> {
117 let keys = self.keys.read().map_err(|_| EncryptionError::LockError)?;
118 let key = keys
119 .get(&encrypted.key_id)
120 .ok_or_else(|| EncryptionError::KeyNotFound(encrypted.key_id.clone()))?;
121
122 if encrypted.nonce.len() != 12 {
123 return Err(EncryptionError::InvalidNonceSize {
124 expected: 12,
125 actual: encrypted.nonce.len(),
126 });
127 }
128
129 let nonce = Nonce::from_slice(&encrypted.nonce);
130 let key_obj = Key::<Aes256Gcm>::from_slice(key);
131 let cipher = Aes256Gcm::new(key_obj);
132
133 cipher
134 .decrypt(nonce, encrypted.ciphertext.as_ref())
135 .map_err(|e| EncryptionError::DecryptionFailed(e.to_string()))
136 }
137
138 pub fn encrypt_string(&self, plaintext: &str) -> Result<String, EncryptionError> {
140 let encrypted = self.encrypt(plaintext.as_bytes())?;
141 Ok(encrypted.to_base64())
142 }
143
144 pub fn decrypt_string(&self, ciphertext: &str) -> Result<String, EncryptionError> {
146 let encrypted = EncryptedData::from_base64(ciphertext)?;
147 let plaintext = self.decrypt(&encrypted)?;
148 String::from_utf8(plaintext).map_err(|e| EncryptionError::InvalidUtf8(e.to_string()))
149 }
150
151 pub fn reencrypt(
153 &self,
154 encrypted: &EncryptedData,
155 new_key_id: &str,
156 ) -> Result<EncryptedData, EncryptionError> {
157 let plaintext = self.decrypt(encrypted)?;
158 self.encrypt_with_key(new_key_id, &plaintext)
159 }
160}
161
/// One AES-GCM ciphertext plus the metadata needed to decrypt it.
///
/// None of these fields are secret, so deriving `Debug` and `PartialEq` is safe.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EncryptedData {
    /// Id of the provider key the data was encrypted under.
    pub key_id: String,
    /// Nonce used for this encryption; the provider generates 12 bytes.
    pub nonce: Vec<u8>,
    /// Ciphertext as returned by the AEAD cipher (includes the authentication tag).
    pub ciphertext: Vec<u8>,
}
172
173impl EncryptedData {
174 pub fn to_base64(&self) -> String {
176 format!(
177 "{}:{}:{}",
178 self.key_id,
179 BASE64.encode(&self.nonce),
180 BASE64.encode(&self.ciphertext)
181 )
182 }
183
184 pub fn from_base64(encoded: &str) -> Result<Self, EncryptionError> {
186 let parts: Vec<&str> = encoded.split(':').collect();
187 if parts.len() != 3 {
188 return Err(EncryptionError::InvalidFormat);
189 }
190
191 let key_id = parts[0].to_string();
192 let nonce = BASE64
193 .decode(parts[1])
194 .map_err(|e| EncryptionError::Base64Error(e.to_string()))?;
195 let ciphertext = BASE64
196 .decode(parts[2])
197 .map_err(|e| EncryptionError::Base64Error(e.to_string()))?;
198
199 Ok(Self {
200 key_id,
201 nonce,
202 ciphertext,
203 })
204 }
205}
206
207#[derive(Debug)]
209pub struct KeyRotationManager {
210 provider: EncryptionProvider,
211}
212
213impl KeyRotationManager {
214 pub fn new(provider: EncryptionProvider) -> Self {
216 Self { provider }
217 }
218
219 pub fn rotate_key(&self, new_key_id: String, new_key: Vec<u8>) -> Result<(), EncryptionError> {
221 self.provider.add_key(new_key_id.clone(), new_key)?;
223
224 self.provider.set_active_key(new_key_id)?;
226
227 Ok(())
228 }
229
230 pub fn reencrypt_data(
232 &self,
233 old_encrypted: &EncryptedData,
234 ) -> Result<EncryptedData, EncryptionError> {
235 let active_key_id = self
236 .provider
237 .active_key_id
238 .read()
239 .map_err(|_| EncryptionError::LockError)?;
240 self.provider.reencrypt(old_encrypted, &active_key_id)
241 }
242}
243
244#[derive(Debug, Clone)]
246pub struct PasswordHashingService {
247 argon2: Argon2<'static>,
248}
249
250impl Default for PasswordHashingService {
251 fn default() -> Self {
252 Self::new()
253 }
254}
255
256impl PasswordHashingService {
257 pub fn new() -> Self {
259 Self {
260 argon2: Argon2::default(),
261 }
262 }
263
264 pub fn hash_password(&self, password: &str) -> Result<String, EncryptionError> {
266 let salt = SaltString::generate(&mut OsRng);
267
268 let password_hash = self
269 .argon2
270 .hash_password(password.as_bytes(), &salt)
271 .map_err(|e| EncryptionError::HashingFailed(e.to_string()))?;
272
273 Ok(password_hash.to_string())
274 }
275
276 pub fn verify_password(&self, password: &str, hash: &str) -> Result<bool, EncryptionError> {
278 let parsed_hash =
279 PasswordHash::new(hash).map_err(|e| EncryptionError::InvalidHash(e.to_string()))?;
280
281 match self
282 .argon2
283 .verify_password(password.as_bytes(), &parsed_hash)
284 {
285 Ok(()) => Ok(true),
286 Err(argon2::password_hash::Error::Password) => Ok(false),
287 Err(e) => Err(EncryptionError::VerificationFailed(e.to_string())),
288 }
289 }
290}
291
292#[derive(Debug, Clone, thiserror::Error)]
294pub enum EncryptionError {
295 #[error("Invalid key size: expected {expected}, got {actual}")]
296 InvalidKeySize { expected: usize, actual: usize },
297
298 #[error("Invalid nonce size: expected {expected}, got {actual}")]
299 InvalidNonceSize { expected: usize, actual: usize },
300
301 #[error("Key not found: {0}")]
302 KeyNotFound(String),
303
304 #[error("Encryption failed: {0}")]
305 EncryptionFailed(String),
306
307 #[error("Decryption failed: {0}")]
308 DecryptionFailed(String),
309
310 #[error("Invalid UTF-8: {0}")]
311 InvalidUtf8(String),
312
313 #[error("Invalid format")]
314 InvalidFormat,
315
316 #[error("Base64 error: {0}")]
317 Base64Error(String),
318
319 #[error("Lock error")]
320 LockError,
321
322 #[error("Hashing failed: {0}")]
323 HashingFailed(String),
324
325 #[error("Invalid hash: {0}")]
326 InvalidHash(String),
327
328 #[error("Verification failed: {0}")]
329 VerificationFailed(String),
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a provider whose only (active) key is a freshly generated `"master"` key.
    fn new_provider() -> EncryptionProvider {
        EncryptionProvider::new(&EncryptionProvider::generate_key()).unwrap()
    }

    #[test]
    fn test_generate_key() {
        let key = EncryptionProvider::generate_key();
        assert_eq!(key.len(), 32);
        // Two independent 256-bit draws colliding would indicate a broken RNG.
        assert_ne!(key, EncryptionProvider::generate_key());
    }

    #[test]
    fn test_encrypt_decrypt() {
        let provider = new_provider();
        let plaintext = b"Hello, World!";

        let encrypted = provider.encrypt(plaintext).unwrap();
        assert_eq!(encrypted.key_id, "master");
        assert_eq!(encrypted.nonce.len(), 12);
        assert_ne!(encrypted.ciphertext.as_slice(), &plaintext[..]);

        let decrypted = provider.decrypt(&encrypted).unwrap();
        assert_eq!(plaintext, decrypted.as_slice());
    }

    #[test]
    fn test_encrypt_uses_fresh_nonce() {
        let provider = new_provider();
        let first = provider.encrypt(b"same input").unwrap();
        let second = provider.encrypt(b"same input").unwrap();
        assert_ne!(first.nonce, second.nonce);
        assert_ne!(first.ciphertext, second.ciphertext);
    }

    #[test]
    fn test_encrypt_decrypt_string() {
        let provider = new_provider();
        let plaintext = "Hello, World!";

        let encrypted = provider.encrypt_string(plaintext).unwrap();
        let decrypted = provider.decrypt_string(&encrypted).unwrap();

        assert_eq!(plaintext, decrypted);
    }

    #[test]
    fn test_decrypt_string_rejects_invalid_utf8() {
        let provider = new_provider();
        let encoded = provider.encrypt(&[0xFF, 0xFE]).unwrap().to_base64();
        assert!(matches!(
            provider.decrypt_string(&encoded),
            Err(EncryptionError::InvalidUtf8(_))
        ));
    }

    #[test]
    fn test_decrypt_rejects_tampered_ciphertext() {
        let provider = new_provider();
        let mut encrypted = provider.encrypt(b"integrity").unwrap();
        encrypted.ciphertext[0] ^= 0x01;
        assert!(matches!(
            provider.decrypt(&encrypted),
            Err(EncryptionError::DecryptionFailed(_))
        ));
    }

    #[test]
    fn test_decrypt_with_wrong_key_fails() {
        let provider = new_provider();
        provider
            .add_key("other".to_string(), EncryptionProvider::generate_key())
            .unwrap();
        let mut encrypted = provider.encrypt(b"data").unwrap();
        encrypted.key_id = "other".to_string();
        assert!(matches!(
            provider.decrypt(&encrypted),
            Err(EncryptionError::DecryptionFailed(_))
        ));
    }

    #[test]
    fn test_decrypt_rejects_bad_nonce_size() {
        let provider = new_provider();
        let mut encrypted = provider.encrypt(b"data").unwrap();
        encrypted.nonce.truncate(5);
        assert!(matches!(
            provider.decrypt(&encrypted),
            Err(EncryptionError::InvalidNonceSize {
                expected: 12,
                actual: 5
            })
        ));
    }

    #[test]
    fn test_encrypted_data_base64() {
        let data = EncryptedData {
            key_id: "test".to_string(),
            nonce: vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
            ciphertext: vec![13, 14, 15, 16],
        };

        let encoded = data.to_base64();
        let decoded = EncryptedData::from_base64(&encoded).unwrap();

        assert_eq!(data.key_id, decoded.key_id);
        assert_eq!(data.nonce, decoded.nonce);
        assert_eq!(data.ciphertext, decoded.ciphertext);
    }

    #[test]
    fn test_from_base64_rejects_malformed_input() {
        assert!(matches!(
            EncryptedData::from_base64("no-separators"),
            Err(EncryptionError::InvalidFormat)
        ));
        assert!(matches!(
            EncryptedData::from_base64("id:AAAA"),
            Err(EncryptionError::InvalidFormat)
        ));
        assert!(matches!(
            EncryptedData::from_base64("id:!!!:AAAA"),
            Err(EncryptionError::Base64Error(_))
        ));
    }

    #[test]
    fn test_key_rotation() {
        let provider = new_provider();

        let plaintext = b"Secret data";
        let encrypted1 = provider.encrypt(plaintext).unwrap();

        let key2 = EncryptionProvider::generate_key();
        provider.add_key("key2".to_string(), key2).unwrap();
        provider.set_active_key("key2".to_string()).unwrap();

        // New encryptions pick up the newly activated key.
        assert_eq!(provider.encrypt(plaintext).unwrap().key_id, "key2");

        let encrypted2 = provider.reencrypt(&encrypted1, "key2").unwrap();

        // Data under the old key remains decryptable after rotation.
        let decrypted1 = provider.decrypt(&encrypted1).unwrap();
        let decrypted2 = provider.decrypt(&encrypted2).unwrap();

        assert_eq!(plaintext, decrypted1.as_slice());
        assert_eq!(plaintext, decrypted2.as_slice());
        assert_eq!(encrypted2.key_id, "key2");
    }

    #[test]
    fn test_invalid_key_size() {
        let short_key = vec![0u8; 16];
        assert!(matches!(
            EncryptionProvider::new(&short_key),
            Err(EncryptionError::InvalidKeySize {
                expected: 32,
                actual: 16
            })
        ));

        let provider = new_provider();
        assert!(matches!(
            provider.add_key("short".to_string(), short_key),
            Err(EncryptionError::InvalidKeySize {
                expected: 32,
                actual: 16
            })
        ));
    }

    #[test]
    fn test_key_not_found() {
        let provider = new_provider();

        let result = provider.encrypt_with_key("nonexistent", b"data");
        assert!(matches!(
            result,
            Err(EncryptionError::KeyNotFound(ref id)) if id == "nonexistent"
        ));

        assert!(matches!(
            provider.set_active_key("nonexistent".to_string()),
            Err(EncryptionError::KeyNotFound(_))
        ));
    }

    #[test]
    fn test_password_hasher() {
        let hasher = PasswordHashingService::new();
        let password = "my_secure_password";

        let hash = hasher.hash_password(password).unwrap();
        assert!(hasher.verify_password(password, &hash).unwrap());
        assert!(!hasher.verify_password("wrong_password", &hash).unwrap());
    }

    #[test]
    fn test_verify_password_rejects_malformed_hash() {
        let hasher = PasswordHashingService::new();
        assert!(matches!(
            hasher.verify_password("pw", "not-a-phc-string"),
            Err(EncryptionError::InvalidHash(_))
        ));
    }

    #[test]
    fn test_password_hashing_produces_different_hashes() {
        let hasher = PasswordHashingService::new();
        let password = "same_password";

        let hash1 = hasher.hash_password(password).unwrap();
        let hash2 = hasher.hash_password(password).unwrap();

        // Random salts make repeated hashes of the same password differ.
        assert_ne!(hash1, hash2);

        assert!(hasher.verify_password(password, &hash1).unwrap());
        assert!(hasher.verify_password(password, &hash2).unwrap());
    }

    #[test]
    fn test_key_rotation_manager() {
        let provider = new_provider();
        let manager = KeyRotationManager::new(provider);

        let plaintext = b"Test data";
        let encrypted = manager.provider.encrypt(plaintext).unwrap();

        let key2 = EncryptionProvider::generate_key();
        manager.rotate_key("new_key".to_string(), key2).unwrap();

        let reencrypted = manager.reencrypt_data(&encrypted).unwrap();
        let decrypted = manager.provider.decrypt(&reencrypted).unwrap();

        assert_eq!(plaintext, decrypted.as_slice());
        assert_eq!(reencrypted.key_id, "new_key");
        // The original ciphertext is still readable with the retained old key.
        assert_eq!(
            plaintext,
            manager.provider.decrypt(&encrypted).unwrap().as_slice()
        );
    }
}
469}