server/
encryption.rs

1use aes_gcm::{
2    Aes256Gcm, Nonce,
3    aead::{Aead, AeadCore, KeyInit},
4};
5use base64::{Engine as _, engine::general_purpose};
6use pbkdf2::pbkdf2_hmac;
7use rand::{RngCore, rngs::OsRng};
8use sha2::Sha256;
9use std::fmt;
10use zeroize::ZeroizeOnDrop;
11
/// PBKDF2-HMAC-SHA256 work factor used to stretch a password into an AES key.
///
/// NOTE(review): the OWASP Password Storage Cheat Sheet currently recommends 600,000
/// iterations for PBKDF2-HMAC-SHA256. Changing this value changes every derived key, so data
/// encrypted under the old count would stop decrypting — raising it needs a versioned
/// ciphertext format or a re-encryption migration.
const PBKDF2_ITERATIONS: u32 = 100_000;
/// Length in bytes of the random PBKDF2 salt held by each `AesEncryption` instance.
const SALT_LENGTH: usize = 32;
/// AES-256 key length in bytes.
const KEY_LENGTH: usize = 32;
/// AES-GCM nonce length in bytes (96 bits); the nonce is stored as the ciphertext prefix.
const NONCE_LENGTH: usize = 12;

// Messages carried inside `EncryptionError::InvalidData` when input validation fails.
const ERROR_EMPTY_CONNECTION_STRING: &str = "Connection string cannot be empty";
const ERROR_EMPTY_CLIENT_SECRET: &str = "Client secret cannot be empty";
const ERROR_EMPTY_PASSWORD: &str = "Password cannot be empty";
const ERROR_EMPTY_ENCRYPTED_DATA: &str = "Encrypted data cannot be empty";
const ERROR_ENCRYPTED_DATA_TOO_SHORT: &str = "Encrypted data too short";
23
/// Failure modes of the encryption helpers in this module.
///
/// Each variant carries a human-readable detail string; `Display` prefixes it with a short
/// label describing the failing stage.
#[derive(Debug)]
pub enum EncryptionError {
    /// Caller-supplied input was empty, malformed, or the wrong size.
    InvalidData(String),
    /// The AEAD encryption step reported an error.
    EncryptionFailed(String),
    /// Authentication/decryption failed (wrong password, wrong salt, tampered data) or the
    /// recovered plaintext was not valid UTF-8.
    DecryptionFailed(String),
    /// The cipher could not be constructed from the derived key.
    KeyDerivation(String),
}

impl fmt::Display for EncryptionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (label, detail) = match self {
            Self::InvalidData(detail) => ("Invalid data", detail),
            Self::EncryptionFailed(detail) => ("Encryption failed", detail),
            Self::DecryptionFailed(detail) => ("Decryption failed", detail),
            Self::KeyDerivation(detail) => ("Key derivation failed", detail),
        };
        write!(f, "{label}: {detail}")
    }
}

impl std::error::Error for EncryptionError {}
44
45#[derive(ZeroizeOnDrop)]
46struct SecureKey([u8; KEY_LENGTH]);
47
48impl SecureKey {
49    fn new(key: [u8; KEY_LENGTH]) -> Self {
50        Self(key)
51    }
52
53    fn as_bytes(&self) -> &[u8; KEY_LENGTH] {
54        &self.0
55    }
56}
57
58/// Common encryption implementation for AES-256-GCM with PBKDF2 key derivation
59pub struct AesEncryption {
60    salt: [u8; SALT_LENGTH],
61}
62
63impl AesEncryption {
64    pub fn new() -> Self {
65        let mut salt = [0u8; SALT_LENGTH];
66        OsRng.fill_bytes(&mut salt);
67        Self { salt }
68    }
69
70    pub fn with_salt(salt: [u8; SALT_LENGTH]) -> Self {
71        Self { salt }
72    }
73
74    pub fn salt_base64(&self) -> String {
75        general_purpose::STANDARD.encode(self.salt)
76    }
77
78    pub fn from_salt_base64(salt_b64: &str) -> Result<Self, EncryptionError> {
79        let salt_bytes = general_purpose::STANDARD
80            .decode(salt_b64)
81            .map_err(|e| EncryptionError::InvalidData(format!("Invalid salt base64: {e}")))?;
82
83        if salt_bytes.len() != SALT_LENGTH {
84            return Err(EncryptionError::InvalidData(format!(
85                "Salt length must be {} bytes, got {}",
86                SALT_LENGTH,
87                salt_bytes.len()
88            )));
89        }
90
91        let mut salt = [0u8; SALT_LENGTH];
92        salt.copy_from_slice(&salt_bytes);
93        Ok(Self::with_salt(salt))
94    }
95
96    fn derive_key(&self, password: &str) -> Result<SecureKey, EncryptionError> {
97        let mut key = [0u8; KEY_LENGTH];
98        pbkdf2_hmac::<Sha256>(password.as_bytes(), &self.salt, PBKDF2_ITERATIONS, &mut key);
99        Ok(SecureKey::new(key))
100    }
101
102    pub fn encrypt(
103        &self,
104        plaintext: &str,
105        password: &str,
106        empty_error: &str,
107    ) -> Result<String, EncryptionError> {
108        if plaintext.trim().is_empty() {
109            return Err(EncryptionError::InvalidData(empty_error.to_string()));
110        }
111
112        if password.trim().is_empty() {
113            return Err(EncryptionError::InvalidData(
114                ERROR_EMPTY_PASSWORD.to_string(),
115            ));
116        }
117
118        let key = self.derive_key(password)?;
119
120        let cipher = Aes256Gcm::new_from_slice(key.as_bytes())
121            .map_err(|e| EncryptionError::KeyDerivation(format!("Invalid key: {e}")))?;
122
123        let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
124
125        let ciphertext = cipher.encrypt(&nonce, plaintext.as_bytes()).map_err(|e| {
126            EncryptionError::EncryptionFailed(format!("AES-GCM encryption failed: {e}"))
127        })?;
128
129        // Format: nonce + ciphertext, all base64 encoded
130        let mut combined = Vec::with_capacity(NONCE_LENGTH + ciphertext.len());
131        combined.extend_from_slice(&nonce);
132        combined.extend_from_slice(&ciphertext);
133
134        Ok(general_purpose::STANDARD.encode(combined))
135    }
136
137    pub fn decrypt(&self, encrypted: &str, password: &str) -> Result<String, EncryptionError> {
138        if encrypted.trim().is_empty() {
139            return Err(EncryptionError::InvalidData(
140                ERROR_EMPTY_ENCRYPTED_DATA.to_string(),
141            ));
142        }
143
144        if password.trim().is_empty() {
145            return Err(EncryptionError::InvalidData(
146                ERROR_EMPTY_PASSWORD.to_string(),
147            ));
148        }
149
150        let combined = general_purpose::STANDARD
151            .decode(encrypted)
152            .map_err(|e| EncryptionError::InvalidData(format!("Invalid base64: {e}")))?;
153
154        if combined.len() < NONCE_LENGTH {
155            return Err(EncryptionError::InvalidData(
156                ERROR_ENCRYPTED_DATA_TOO_SHORT.to_string(),
157            ));
158        }
159
160        let (nonce_bytes, ciphertext) = combined.split_at(NONCE_LENGTH);
161
162        let nonce = Nonce::from_slice(nonce_bytes);
163
164        let key = self.derive_key(password)?;
165
166        let cipher = Aes256Gcm::new_from_slice(key.as_bytes())
167            .map_err(|e| EncryptionError::KeyDerivation(format!("Invalid key: {e}")))?;
168
169        let plaintext = cipher.decrypt(nonce, ciphertext).map_err(|e| {
170            EncryptionError::DecryptionFailed(format!("AES-GCM decryption failed: {e}"))
171        })?;
172
173        String::from_utf8(plaintext)
174            .map_err(|e| EncryptionError::DecryptionFailed(format!("Invalid UTF-8: {e}")))
175    }
176}
177
178impl Default for AesEncryption {
179    fn default() -> Self {
180        Self::new()
181    }
182}
183
184pub struct ConnectionStringEncryption {
185    inner: AesEncryption,
186}
187
188impl ConnectionStringEncryption {
189    pub fn new() -> Self {
190        Self {
191            inner: AesEncryption::new(),
192        }
193    }
194
195    pub fn with_salt(salt: [u8; SALT_LENGTH]) -> Self {
196        Self {
197            inner: AesEncryption::with_salt(salt),
198        }
199    }
200
201    pub fn salt_base64(&self) -> String {
202        self.inner.salt_base64()
203    }
204
205    pub fn from_salt_base64(salt_b64: &str) -> Result<Self, EncryptionError> {
206        Ok(Self {
207            inner: AesEncryption::from_salt_base64(salt_b64)?,
208        })
209    }
210
211    pub fn encrypt_connection_string(
212        &self,
213        plaintext: &str,
214        password: &str,
215    ) -> Result<String, EncryptionError> {
216        self.inner
217            .encrypt(plaintext, password, ERROR_EMPTY_CONNECTION_STRING)
218    }
219
220    pub fn decrypt_connection_string(
221        &self,
222        encrypted: &str,
223        password: &str,
224    ) -> Result<String, EncryptionError> {
225        self.inner.decrypt(encrypted, password)
226    }
227}
228
229impl Default for ConnectionStringEncryption {
230    fn default() -> Self {
231        Self::new()
232    }
233}
234
235pub struct ClientSecretEncryption {
236    inner: AesEncryption,
237}
238
239impl ClientSecretEncryption {
240    pub fn new() -> Self {
241        Self {
242            inner: AesEncryption::new(),
243        }
244    }
245
246    pub fn with_salt(salt: [u8; SALT_LENGTH]) -> Self {
247        Self {
248            inner: AesEncryption::with_salt(salt),
249        }
250    }
251
252    pub fn salt_base64(&self) -> String {
253        self.inner.salt_base64()
254    }
255
256    pub fn from_salt_base64(salt_b64: &str) -> Result<Self, EncryptionError> {
257        Ok(Self {
258            inner: AesEncryption::from_salt_base64(salt_b64)?,
259        })
260    }
261
262    pub fn encrypt_client_secret(
263        &self,
264        plaintext: &str,
265        password: &str,
266    ) -> Result<String, EncryptionError> {
267        self.inner
268            .encrypt(plaintext, password, ERROR_EMPTY_CLIENT_SECRET)
269    }
270
271    pub fn decrypt_client_secret(
272        &self,
273        encrypted: &str,
274        password: &str,
275    ) -> Result<String, EncryptionError> {
276        self.inner.decrypt(encrypted, password)
277    }
278}
279
280impl Default for ClientSecretEncryption {
281    fn default() -> Self {
282        Self::new()
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use base64::{Engine as _, engine::general_purpose};

    // Two fixed 32-byte salts (base64) used to exercise salt persistence.
    const FIXED_SALT_A: &str = "dGVzdF9zYWx0XzEyMzQ1Njc4OTBfYWJjZGVmZ2hpams=";
    const FIXED_SALT_B: &str = "J+CP5+9lfcD/SndIFvvdIEnltiA4UVtsraLndlzXSVk=";

    fn b64_decode(s: &str) -> Vec<u8> {
        general_purpose::STANDARD
            .decode(s)
            .expect("test input is valid base64")
    }

    // ---------------------------------------------------------------- connection strings

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let encryption = ConnectionStringEncryption::new();
        let plaintext = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=RootManageSharedAccessKey;SharedAccessKey=test123";
        let password = "test_password_123";

        let encrypted = encryption
            .encrypt_connection_string(plaintext, password)
            .expect("Encryption should succeed");

        let decrypted = encryption
            .decrypt_connection_string(&encrypted, password)
            .expect("Decryption should succeed");

        assert_eq!(plaintext, decrypted);
    }

    #[test]
    fn test_wrong_password_fails() {
        let encryption = ConnectionStringEncryption::new();
        let encrypted = encryption
            .encrypt_connection_string("test connection string", "correct_password")
            .expect("Encryption should succeed");

        let result = encryption.decrypt_connection_string(&encrypted, "wrong_password");
        assert!(matches!(result, Err(EncryptionError::DecryptionFailed(_))));
    }

    #[test]
    fn test_empty_inputs() {
        let encryption = ConnectionStringEncryption::new();

        assert!(encryption.encrypt_connection_string("", "password").is_err());
        assert!(encryption.encrypt_connection_string("data", "").is_err());
        assert!(encryption.decrypt_connection_string("", "password").is_err());
        assert!(encryption.decrypt_connection_string("data", "").is_err());
    }

    #[test]
    fn test_whitespace_only_inputs_are_rejected() {
        let encryption = ConnectionStringEncryption::new();

        assert!(encryption.encrypt_connection_string(" \t\n", "password").is_err());
        assert!(encryption.encrypt_connection_string("data", "   ").is_err());
        assert!(encryption.decrypt_connection_string("   ", "password").is_err());
    }

    #[test]
    fn test_empty_plaintext_uses_connection_string_message() {
        let encryption = ConnectionStringEncryption::new();
        let err = encryption
            .encrypt_connection_string("", "password")
            .expect_err("empty plaintext must be rejected");
        assert_eq!(
            err.to_string(),
            format!("Invalid data: {ERROR_EMPTY_CONNECTION_STRING}")
        );
    }

    #[test]
    fn test_salt_persistence() {
        let encryption1 = ConnectionStringEncryption::from_salt_base64(FIXED_SALT_A)
            .expect("Should create from base64 salt");
        let encryption2 = ConnectionStringEncryption::from_salt_base64(FIXED_SALT_A)
            .expect("Should create from same base64 salt");

        let plaintext = "test connection string";
        let password = "test_password";

        let encrypted1 = encryption1
            .encrypt_connection_string(plaintext, password)
            .expect("Encryption 1 should succeed");

        let decrypted2 = encryption2
            .decrypt_connection_string(&encrypted1, password)
            .expect("Decryption 2 should succeed");

        assert_eq!(plaintext, decrypted2);
    }

    #[test]
    fn test_salt_base64_roundtrip() {
        let original = ConnectionStringEncryption::new();
        let restored = ConnectionStringEncryption::from_salt_base64(&original.salt_base64())
            .expect("salt produced by salt_base64 must be accepted");
        assert_eq!(original.salt_base64(), restored.salt_base64());

        let encrypted = original
            .encrypt_connection_string("payload", "pw")
            .expect("Encryption should succeed");
        let decrypted = restored
            .decrypt_connection_string(&encrypted, "pw")
            .expect("Decryption with restored salt should succeed");
        assert_eq!(decrypted, "payload");
    }

    #[test]
    fn test_same_instance_uses_fresh_nonce_per_message() {
        // Same salt and password: output must still differ because the nonce is random.
        let encryption = ConnectionStringEncryption::new();
        let first = encryption
            .encrypt_connection_string("same text", "pw")
            .expect("Encryption should succeed");
        let second = encryption
            .encrypt_connection_string("same text", "pw")
            .expect("Encryption should succeed");

        assert_ne!(first, second);
        assert_ne!(b64_decode(&first)[..NONCE_LENGTH], b64_decode(&second)[..NONCE_LENGTH]);
        for ciphertext in [&first, &second] {
            assert_eq!(
                encryption
                    .decrypt_connection_string(ciphertext, "pw")
                    .expect("Decryption should succeed"),
                "same text"
            );
        }
    }

    #[test]
    fn test_different_salts_are_not_interchangeable() {
        // Differing ciphertexts alone prove nothing (the nonce is random); check that each
        // instance gets its own salt and that the salt actually participates in the key.
        let encryption1 = ConnectionStringEncryption::new();
        let encryption2 = ConnectionStringEncryption::new();
        assert_ne!(encryption1.salt_base64(), encryption2.salt_base64());

        let encrypted = encryption1
            .encrypt_connection_string("test connection string", "test_password")
            .expect("Encryption should succeed");

        let result = encryption2.decrypt_connection_string(&encrypted, "test_password");
        assert!(matches!(result, Err(EncryptionError::DecryptionFailed(_))));
    }

    #[test]
    fn test_tampered_ciphertext_is_rejected() {
        let encryption = ConnectionStringEncryption::new();
        let encrypted = encryption
            .encrypt_connection_string("integrity matters", "pw")
            .expect("Encryption should succeed");

        let mut bytes = b64_decode(&encrypted);
        let last = bytes.len() - 1;
        bytes[last] ^= 0x01; // flip one bit of the authentication tag
        let tampered = general_purpose::STANDARD.encode(bytes);

        let result = encryption.decrypt_connection_string(&tampered, "pw");
        assert!(matches!(result, Err(EncryptionError::DecryptionFailed(_))));
    }

    #[test]
    fn test_truncated_data_is_rejected() {
        let encryption = ConnectionStringEncryption::new();
        let too_short = general_purpose::STANDARD.encode([0u8; NONCE_LENGTH - 1]);

        let result = encryption.decrypt_connection_string(&too_short, "pw");
        assert!(matches!(
            result,
            Err(EncryptionError::InvalidData(ref msg)) if msg == ERROR_ENCRYPTED_DATA_TOO_SHORT
        ));
    }

    #[test]
    fn test_invalid_base64_is_rejected() {
        let encryption = ConnectionStringEncryption::new();
        let result = encryption.decrypt_connection_string("not base64!", "pw");
        assert!(matches!(result, Err(EncryptionError::InvalidData(_))));
    }

    #[test]
    fn test_invalid_salts_are_rejected() {
        let wrong_length = general_purpose::STANDARD.encode([7u8; 16]);
        match ConnectionStringEncryption::from_salt_base64(&wrong_length) {
            Err(EncryptionError::InvalidData(msg)) => {
                assert_eq!(msg, "Salt length must be 32 bytes, got 16");
            }
            Err(other) => panic!("unexpected error: {other}"),
            Ok(_) => panic!("a 16-byte salt must be rejected"),
        }

        assert!(matches!(
            ConnectionStringEncryption::from_salt_base64("***"),
            Err(EncryptionError::InvalidData(_))
        ));
    }

    #[test]
    fn test_error_display() {
        assert_eq!(
            EncryptionError::InvalidData("x".to_string()).to_string(),
            "Invalid data: x"
        );
        assert_eq!(
            EncryptionError::DecryptionFailed("y".to_string()).to_string(),
            "Decryption failed: y"
        );
    }

    // ---------------------------------------------------------------- client secrets

    #[test]
    fn test_client_secret_encrypt_decrypt_roundtrip() {
        let encryption = ClientSecretEncryption::new();
        let plaintext = "secret_client_value_123";
        let password = "test_password_456";

        let encrypted = encryption
            .encrypt_client_secret(plaintext, password)
            .expect("Encryption should succeed");

        let decrypted = encryption
            .decrypt_client_secret(&encrypted, password)
            .expect("Decryption should succeed");

        assert_eq!(plaintext, decrypted);
    }

    #[test]
    fn test_client_secret_wrong_password_fails() {
        let encryption = ClientSecretEncryption::new();
        let encrypted = encryption
            .encrypt_client_secret("test client secret", "correct_password")
            .expect("Encryption should succeed");

        let result = encryption.decrypt_client_secret(&encrypted, "wrong_password");
        assert!(matches!(result, Err(EncryptionError::DecryptionFailed(_))));
    }

    #[test]
    fn test_client_secret_empty_inputs() {
        let encryption = ClientSecretEncryption::new();

        assert!(encryption.encrypt_client_secret("", "password").is_err());
        assert!(encryption.encrypt_client_secret("data", "").is_err());
        assert!(encryption.decrypt_client_secret("", "password").is_err());
        assert!(encryption.decrypt_client_secret("data", "").is_err());
    }

    #[test]
    fn test_client_secret_empty_plaintext_uses_client_secret_message() {
        let encryption = ClientSecretEncryption::new();
        let err = encryption
            .encrypt_client_secret("   ", "password")
            .expect_err("blank plaintext must be rejected");
        assert_eq!(
            err.to_string(),
            format!("Invalid data: {ERROR_EMPTY_CLIENT_SECRET}")
        );
    }

    #[test]
    fn test_client_secret_salt_persistence() {
        let encryption1 = ClientSecretEncryption::from_salt_base64(FIXED_SALT_B)
            .expect("Should create from base64 salt");
        let encryption2 = ClientSecretEncryption::from_salt_base64(FIXED_SALT_B)
            .expect("Should create from same base64 salt");

        let plaintext = "test client secret";
        let password = "test_password";

        let encrypted1 = encryption1
            .encrypt_client_secret(plaintext, password)
            .expect("Encryption 1 should succeed");

        let decrypted2 = encryption2
            .decrypt_client_secret(&encrypted1, password)
            .expect("Decryption 2 should succeed");

        assert_eq!(plaintext, decrypted2);
    }

    #[test]
    fn test_client_secret_different_salts_are_not_interchangeable() {
        let encryption1 = ClientSecretEncryption::new();
        let encryption2 = ClientSecretEncryption::new();
        assert_ne!(encryption1.salt_base64(), encryption2.salt_base64());

        let encrypted = encryption1
            .encrypt_client_secret("test client secret", "test_password")
            .expect("Encryption should succeed");

        let result = encryption2.decrypt_client_secret(&encrypted, "test_password");
        assert!(matches!(result, Err(EncryptionError::DecryptionFailed(_))));
    }
}