1use aes_gcm::{
2 Aes256Gcm, Nonce,
3 aead::{Aead, AeadCore, KeyInit},
4};
5use base64::{Engine as _, engine::general_purpose};
6use pbkdf2::pbkdf2_hmac;
7use rand::{RngCore, rngs::OsRng};
8use sha2::Sha256;
9use std::fmt;
10use zeroize::ZeroizeOnDrop;
11
/// Number of PBKDF2-HMAC-SHA256 rounds used to stretch a password into an AES key.
///
/// NOTE(review): the iteration count is not stored with the encrypted payload (which is only
/// `nonce || ciphertext`), so changing this value makes previously encrypted data
/// undecryptable without a migration. It is also below current OWASP guidance for
/// PBKDF2-HMAC-SHA256; raise it only together with a versioned payload format.
const PBKDF2_ITERATIONS: u32 = 100_000;

/// Size in bytes of the random PBKDF2 salt.
const SALT_LENGTH: usize = 32;

/// Size in bytes of the derived key: 256 bits, as required by AES-256.
const KEY_LENGTH: usize = 256 / 8;

/// Size in bytes of the 96-bit AES-GCM nonce that prefixes every encrypted payload.
const NONCE_LENGTH: usize = 96 / 8;

// Validation messages reported through `EncryptionError::InvalidData`.
const ERROR_EMPTY_CONNECTION_STRING: &str = "Connection string cannot be empty";
const ERROR_EMPTY_CLIENT_SECRET: &str = "Client secret cannot be empty";
const ERROR_EMPTY_PASSWORD: &str = "Password cannot be empty";
const ERROR_EMPTY_ENCRYPTED_DATA: &str = "Encrypted data cannot be empty";
const ERROR_ENCRYPTED_DATA_TOO_SHORT: &str = "Encrypted data too short";
23
/// Errors produced while encrypting or decrypting secrets.
///
/// Every variant carries a human-readable detail message; `Display` renders it as
/// `"<category>: <detail>"`.
#[derive(Debug)]
pub enum EncryptionError {
    /// Input was rejected before or instead of producing a result (e.g. empty input,
    /// malformed base64, or a salt/payload of the wrong length).
    InvalidData(String),
    /// The AES-GCM encryption step reported an error.
    EncryptionFailed(String),
    /// The AES-GCM decryption step failed (e.g. wrong password or salt, modified data),
    /// or the decrypted bytes were not valid UTF-8.
    DecryptionFailed(String),
    /// The cipher could not be constructed from the derived key.
    KeyDerivation(String),
}

impl fmt::Display for EncryptionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the category label once, then format label and detail uniformly.
        let (label, detail) = match self {
            Self::InvalidData(detail) => ("Invalid data", detail),
            Self::EncryptionFailed(detail) => ("Encryption failed", detail),
            Self::DecryptionFailed(detail) => ("Decryption failed", detail),
            Self::KeyDerivation(detail) => ("Key derivation failed", detail),
        };
        write!(f, "{label}: {detail}")
    }
}

impl std::error::Error for EncryptionError {}
44
45#[derive(ZeroizeOnDrop)]
46struct SecureKey([u8; KEY_LENGTH]);
47
48impl SecureKey {
49 fn new(key: [u8; KEY_LENGTH]) -> Self {
50 Self(key)
51 }
52
53 fn as_bytes(&self) -> &[u8; KEY_LENGTH] {
54 &self.0
55 }
56}
57
58pub struct AesEncryption {
60 salt: [u8; SALT_LENGTH],
61}
62
63impl AesEncryption {
64 pub fn new() -> Self {
65 let mut salt = [0u8; SALT_LENGTH];
66 OsRng.fill_bytes(&mut salt);
67 Self { salt }
68 }
69
70 pub fn with_salt(salt: [u8; SALT_LENGTH]) -> Self {
71 Self { salt }
72 }
73
74 pub fn salt_base64(&self) -> String {
75 general_purpose::STANDARD.encode(self.salt)
76 }
77
78 pub fn from_salt_base64(salt_b64: &str) -> Result<Self, EncryptionError> {
79 let salt_bytes = general_purpose::STANDARD
80 .decode(salt_b64)
81 .map_err(|e| EncryptionError::InvalidData(format!("Invalid salt base64: {e}")))?;
82
83 if salt_bytes.len() != SALT_LENGTH {
84 return Err(EncryptionError::InvalidData(format!(
85 "Salt length must be {} bytes, got {}",
86 SALT_LENGTH,
87 salt_bytes.len()
88 )));
89 }
90
91 let mut salt = [0u8; SALT_LENGTH];
92 salt.copy_from_slice(&salt_bytes);
93 Ok(Self::with_salt(salt))
94 }
95
96 fn derive_key(&self, password: &str) -> Result<SecureKey, EncryptionError> {
97 let mut key = [0u8; KEY_LENGTH];
98 pbkdf2_hmac::<Sha256>(password.as_bytes(), &self.salt, PBKDF2_ITERATIONS, &mut key);
99 Ok(SecureKey::new(key))
100 }
101
102 pub fn encrypt(
103 &self,
104 plaintext: &str,
105 password: &str,
106 empty_error: &str,
107 ) -> Result<String, EncryptionError> {
108 if plaintext.trim().is_empty() {
109 return Err(EncryptionError::InvalidData(empty_error.to_string()));
110 }
111
112 if password.trim().is_empty() {
113 return Err(EncryptionError::InvalidData(
114 ERROR_EMPTY_PASSWORD.to_string(),
115 ));
116 }
117
118 let key = self.derive_key(password)?;
119
120 let cipher = Aes256Gcm::new_from_slice(key.as_bytes())
121 .map_err(|e| EncryptionError::KeyDerivation(format!("Invalid key: {e}")))?;
122
123 let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
124
125 let ciphertext = cipher.encrypt(&nonce, plaintext.as_bytes()).map_err(|e| {
126 EncryptionError::EncryptionFailed(format!("AES-GCM encryption failed: {e}"))
127 })?;
128
129 let mut combined = Vec::with_capacity(NONCE_LENGTH + ciphertext.len());
131 combined.extend_from_slice(&nonce);
132 combined.extend_from_slice(&ciphertext);
133
134 Ok(general_purpose::STANDARD.encode(combined))
135 }
136
137 pub fn decrypt(&self, encrypted: &str, password: &str) -> Result<String, EncryptionError> {
138 if encrypted.trim().is_empty() {
139 return Err(EncryptionError::InvalidData(
140 ERROR_EMPTY_ENCRYPTED_DATA.to_string(),
141 ));
142 }
143
144 if password.trim().is_empty() {
145 return Err(EncryptionError::InvalidData(
146 ERROR_EMPTY_PASSWORD.to_string(),
147 ));
148 }
149
150 let combined = general_purpose::STANDARD
151 .decode(encrypted)
152 .map_err(|e| EncryptionError::InvalidData(format!("Invalid base64: {e}")))?;
153
154 if combined.len() < NONCE_LENGTH {
155 return Err(EncryptionError::InvalidData(
156 ERROR_ENCRYPTED_DATA_TOO_SHORT.to_string(),
157 ));
158 }
159
160 let (nonce_bytes, ciphertext) = combined.split_at(NONCE_LENGTH);
161
162 let nonce = Nonce::from_slice(nonce_bytes);
163
164 let key = self.derive_key(password)?;
165
166 let cipher = Aes256Gcm::new_from_slice(key.as_bytes())
167 .map_err(|e| EncryptionError::KeyDerivation(format!("Invalid key: {e}")))?;
168
169 let plaintext = cipher.decrypt(nonce, ciphertext).map_err(|e| {
170 EncryptionError::DecryptionFailed(format!("AES-GCM decryption failed: {e}"))
171 })?;
172
173 String::from_utf8(plaintext)
174 .map_err(|e| EncryptionError::DecryptionFailed(format!("Invalid UTF-8: {e}")))
175 }
176}
177
178impl Default for AesEncryption {
179 fn default() -> Self {
180 Self::new()
181 }
182}
183
184pub struct ConnectionStringEncryption {
185 inner: AesEncryption,
186}
187
188impl ConnectionStringEncryption {
189 pub fn new() -> Self {
190 Self {
191 inner: AesEncryption::new(),
192 }
193 }
194
195 pub fn with_salt(salt: [u8; SALT_LENGTH]) -> Self {
196 Self {
197 inner: AesEncryption::with_salt(salt),
198 }
199 }
200
201 pub fn salt_base64(&self) -> String {
202 self.inner.salt_base64()
203 }
204
205 pub fn from_salt_base64(salt_b64: &str) -> Result<Self, EncryptionError> {
206 Ok(Self {
207 inner: AesEncryption::from_salt_base64(salt_b64)?,
208 })
209 }
210
211 pub fn encrypt_connection_string(
212 &self,
213 plaintext: &str,
214 password: &str,
215 ) -> Result<String, EncryptionError> {
216 self.inner
217 .encrypt(plaintext, password, ERROR_EMPTY_CONNECTION_STRING)
218 }
219
220 pub fn decrypt_connection_string(
221 &self,
222 encrypted: &str,
223 password: &str,
224 ) -> Result<String, EncryptionError> {
225 self.inner.decrypt(encrypted, password)
226 }
227}
228
229impl Default for ConnectionStringEncryption {
230 fn default() -> Self {
231 Self::new()
232 }
233}
234
235pub struct ClientSecretEncryption {
236 inner: AesEncryption,
237}
238
239impl ClientSecretEncryption {
240 pub fn new() -> Self {
241 Self {
242 inner: AesEncryption::new(),
243 }
244 }
245
246 pub fn with_salt(salt: [u8; SALT_LENGTH]) -> Self {
247 Self {
248 inner: AesEncryption::with_salt(salt),
249 }
250 }
251
252 pub fn salt_base64(&self) -> String {
253 self.inner.salt_base64()
254 }
255
256 pub fn from_salt_base64(salt_b64: &str) -> Result<Self, EncryptionError> {
257 Ok(Self {
258 inner: AesEncryption::from_salt_base64(salt_b64)?,
259 })
260 }
261
262 pub fn encrypt_client_secret(
263 &self,
264 plaintext: &str,
265 password: &str,
266 ) -> Result<String, EncryptionError> {
267 self.inner
268 .encrypt(plaintext, password, ERROR_EMPTY_CLIENT_SECRET)
269 }
270
271 pub fn decrypt_client_secret(
272 &self,
273 encrypted: &str,
274 password: &str,
275 ) -> Result<String, EncryptionError> {
276 self.inner.decrypt(encrypted, password)
277 }
278}
279
280impl Default for ClientSecretEncryption {
281 fn default() -> Self {
282 Self::new()
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use base64::{Engine as _, engine::general_purpose};

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let encryption = ConnectionStringEncryption::new();
        let plaintext = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=RootManageSharedAccessKey;SharedAccessKey=test123";
        let password = "test_password_123";

        let encrypted = encryption
            .encrypt_connection_string(plaintext, password)
            .expect("Encryption should succeed");

        let decrypted = encryption
            .decrypt_connection_string(&encrypted, password)
            .expect("Decryption should succeed");

        assert_eq!(plaintext, decrypted);
    }

    #[test]
    fn test_wrong_password_fails() {
        let encryption = ConnectionStringEncryption::new();
        let plaintext = "test connection string";
        let password = "correct_password";
        let wrong_password = "wrong_password";

        let encrypted = encryption
            .encrypt_connection_string(plaintext, password)
            .expect("Encryption should succeed");

        let result = encryption.decrypt_connection_string(&encrypted, wrong_password);
        assert!(
            matches!(result, Err(EncryptionError::DecryptionFailed(_))),
            "unexpected result: {result:?}"
        );
    }

    #[test]
    fn test_empty_inputs() {
        let encryption = ConnectionStringEncryption::new();

        assert!(encryption.encrypt_connection_string("", "password").is_err());
        assert!(encryption.encrypt_connection_string("   ", "password").is_err());
        assert!(encryption.encrypt_connection_string("data", "").is_err());
        assert!(encryption.encrypt_connection_string("data", " \t").is_err());
        assert!(encryption.decrypt_connection_string("", "password").is_err());
        assert!(encryption.decrypt_connection_string("data", "").is_err());
    }

    #[test]
    fn test_salt_persistence() {
        let salt_b64 = "dGVzdF9zYWx0XzEyMzQ1Njc4OTBfYWJjZGVmZ2hpams=";
        let encryption1 = ConnectionStringEncryption::from_salt_base64(salt_b64)
            .expect("Should create from base64 salt");
        let encryption2 = ConnectionStringEncryption::from_salt_base64(salt_b64)
            .expect("Should create from same base64 salt");

        let plaintext = "test connection string";
        let password = "test_password";

        let encrypted1 = encryption1
            .encrypt_connection_string(plaintext, password)
            .expect("Encryption 1 should succeed");

        let decrypted2 = encryption2
            .decrypt_connection_string(&encrypted1, password)
            .expect("Decryption 2 should succeed");

        assert_eq!(plaintext, decrypted2);
    }

    #[test]
    fn test_salt_base64_roundtrip() {
        let original = ConnectionStringEncryption::new();
        let restored = ConnectionStringEncryption::from_salt_base64(&original.salt_base64())
            .expect("Exported salt should be importable");

        assert_eq!(original.salt_base64(), restored.salt_base64());
    }

    #[test]
    fn test_invalid_salt_rejected() {
        assert!(ConnectionStringEncryption::from_salt_base64("not base64!").is_err());

        let short_salt = general_purpose::STANDARD.encode([0u8; SALT_LENGTH - 1]);
        assert!(ConnectionStringEncryption::from_salt_base64(&short_salt).is_err());
    }

    #[test]
    fn test_different_salts_are_not_interchangeable() {
        // Same password, different salts: the derived keys differ, so a ciphertext made
        // with one salt must not authenticate under the other. (Comparing ciphertexts
        // alone proves nothing, since the random nonce makes them differ anyway.)
        let encryption1 = ConnectionStringEncryption::with_salt([1u8; SALT_LENGTH]);
        let encryption2 = ConnectionStringEncryption::with_salt([2u8; SALT_LENGTH]);
        let password = "test_password";

        let encrypted1 = encryption1
            .encrypt_connection_string("test connection string", password)
            .expect("Encryption 1 should succeed");

        assert!(
            encryption2
                .decrypt_connection_string(&encrypted1, password)
                .is_err(),
            "Ciphertext must not decrypt under a different salt"
        );
    }

    #[test]
    fn test_repeated_encryption_uses_fresh_nonce() {
        let encryption = ConnectionStringEncryption::with_salt([7u8; SALT_LENGTH]);
        let plaintext = "test connection string";
        let password = "test_password";

        let encrypted1 = encryption
            .encrypt_connection_string(plaintext, password)
            .expect("Encryption 1 should succeed");
        let encrypted2 = encryption
            .encrypt_connection_string(plaintext, password)
            .expect("Encryption 2 should succeed");

        assert_ne!(
            encrypted1, encrypted2,
            "Each encryption should use a fresh random nonce"
        );
    }

    #[test]
    fn test_tampered_ciphertext_rejected() {
        let encryption = ConnectionStringEncryption::new();
        let password = "test_password";

        let encrypted = encryption
            .encrypt_connection_string("test connection string", password)
            .expect("Encryption should succeed");

        let mut bytes = general_purpose::STANDARD
            .decode(&encrypted)
            .expect("Ciphertext should be valid base64");
        // Flip one bit of the authentication tag.
        *bytes.last_mut().expect("Ciphertext should not be empty") ^= 0x01;
        let tampered = general_purpose::STANDARD.encode(&bytes);

        assert!(encryption.decrypt_connection_string(&tampered, password).is_err());
    }

    #[test]
    fn test_truncated_or_invalid_payload_rejected() {
        let encryption = ConnectionStringEncryption::new();

        let too_short = general_purpose::STANDARD.encode([0u8; NONCE_LENGTH - 1]);
        let result = encryption.decrypt_connection_string(&too_short, "test_password");
        assert!(
            matches!(result, Err(EncryptionError::InvalidData(_))),
            "unexpected result: {result:?}"
        );

        assert!(
            encryption
                .decrypt_connection_string("%%%%", "test_password")
                .is_err()
        );
    }

    #[test]
    fn test_client_secret_encrypt_decrypt_roundtrip() {
        let encryption = ClientSecretEncryption::new();
        let plaintext = "secret_client_value_123";
        let password = "test_password_456";

        let encrypted = encryption
            .encrypt_client_secret(plaintext, password)
            .expect("Encryption should succeed");

        let decrypted = encryption
            .decrypt_client_secret(&encrypted, password)
            .expect("Decryption should succeed");

        assert_eq!(plaintext, decrypted);
    }

    #[test]
    fn test_client_secret_wrong_password_fails() {
        let encryption = ClientSecretEncryption::new();
        let plaintext = "test client secret";
        let password = "correct_password";
        let wrong_password = "wrong_password";

        let encrypted = encryption
            .encrypt_client_secret(plaintext, password)
            .expect("Encryption should succeed");

        let result = encryption.decrypt_client_secret(&encrypted, wrong_password);
        assert!(result.is_err());
    }

    #[test]
    fn test_client_secret_empty_inputs() {
        let encryption = ClientSecretEncryption::new();

        assert!(encryption.encrypt_client_secret("", "password").is_err());
        assert!(encryption.encrypt_client_secret("data", "").is_err());
        assert!(encryption.decrypt_client_secret("", "password").is_err());
        assert!(encryption.decrypt_client_secret("data", "").is_err());
    }

    #[test]
    fn test_client_secret_salt_persistence() {
        let salt_b64 = "J+CP5+9lfcD/SndIFvvdIEnltiA4UVtsraLndlzXSVk=";
        let encryption1 = ClientSecretEncryption::from_salt_base64(salt_b64)
            .expect("Should create from base64 salt");
        let encryption2 = ClientSecretEncryption::from_salt_base64(salt_b64)
            .expect("Should create from same base64 salt");

        let plaintext = "test client secret";
        let password = "test_password";

        let encrypted1 = encryption1
            .encrypt_client_secret(plaintext, password)
            .expect("Encryption 1 should succeed");

        let decrypted2 = encryption2
            .decrypt_client_secret(&encrypted1, password)
            .expect("Decryption 2 should succeed");

        assert_eq!(plaintext, decrypted2);
    }

    #[test]
    fn test_client_secret_different_salts_are_not_interchangeable() {
        let encryption1 = ClientSecretEncryption::with_salt([3u8; SALT_LENGTH]);
        let encryption2 = ClientSecretEncryption::with_salt([4u8; SALT_LENGTH]);
        let password = "test_password";

        let encrypted1 = encryption1
            .encrypt_client_secret("test client secret", password)
            .expect("Encryption 1 should succeed");

        assert!(
            encryption2
                .decrypt_client_secret(&encrypted1, password)
                .is_err(),
            "Ciphertext must not decrypt under a different salt"
        );
    }
}