1use aes::cipher::{BlockDecryptMut, BlockEncryptMut, KeyIvInit, block_padding::Pkcs7};
28use hmac::{Hmac, Mac};
29use rand::RngCore;
30use sha2::Sha256;
31
32use crate::encryption::{EncryptionError, EncryptionType};
33
/// Algorithm version byte: written as the first byte of every ciphertext and also mixed into
/// the MAC input.
const VERSION_BYTE: u8 = 0x01;

/// AES block size in bytes; also the PKCS#7 padding granularity.
const AES_BLOCK_SIZE: usize = 16;

/// AES-256 key size in bytes. Also used as the required CEK length and as the size of each
/// derived sub-key (an HMAC-SHA256 output is 32 bytes).
const AES_KEY_SIZE: usize = 32;

/// HMAC-SHA256 tag size in bytes; the full, untruncated tag is stored in the ciphertext.
const MAC_SIZE: usize = 32;

/// AES-CBC initialization vector size in bytes.
const IV_SIZE: usize = 16;

/// Smallest well-formed ciphertext: version byte + MAC + IV + one AES block. PKCS#7 always
/// adds at least one byte of padding, so even an empty plaintext encrypts to one full block.
const MIN_CIPHERTEXT_SIZE: usize = 1 + MAC_SIZE + IV_SIZE + AES_BLOCK_SIZE;

/// Salt for deriving the encryption sub-key. It is fed to HMAC-SHA256 as UTF-16LE code units,
/// so the text must match byte-for-byte or the derived keys will not match the expected vectors.
const ENCRYPTION_KEY_SALT: &str = "Microsoft SQL Server cell encryption key with encryption algorithm:AEAD_AES_256_CBC_HMAC_SHA256 and key length:256";
/// Salt for deriving the MAC sub-key (fed to HMAC-SHA256 as UTF-16LE code units).
const MAC_KEY_SALT: &str = "Microsoft SQL Server cell MAC key with encryption algorithm:AEAD_AES_256_CBC_HMAC_SHA256 and key length:256";
/// Salt for deriving the deterministic-IV sub-key (fed to HMAC-SHA256 as UTF-16LE code units).
const IV_KEY_SALT: &str = "Microsoft SQL Server cell IV key with encryption algorithm:AEAD_AES_256_CBC_HMAC_SHA256 and key length:256";

/// HMAC-SHA256, used for sub-key derivation, deterministic IV derivation and the ciphertext MAC.
type HmacSha256 = Hmac<Sha256>;
/// AES-256 in CBC mode, encryption direction.
type Aes256CbcEnc = cbc::Encryptor<aes::Aes256>;
/// AES-256 in CBC mode, decryption direction.
type Aes256CbcDec = cbc::Decryptor<aes::Aes256>;
65
/// Sub-keys derived from a 32-byte column encryption key (CEK) for
/// `AEAD_AES_256_CBC_HMAC_SHA256`.
///
/// Each sub-key is `HMAC-SHA256(CEK, salt)` for a distinct salt string encoded as UTF-16LE
/// (see [`DerivedKeys::derive`]).
///
/// Key material is wiped when a value is dropped. With the `zeroize` feature this is done by
/// the `zeroize` derives; without it, a hand-written `Drop` impl in this module does it.
/// `Clone` duplicates the key bytes, and each clone is wiped independently when it drops.
/// No `Debug` impl is derived here, so the keys cannot be printed through a derived `{:?}`.
#[derive(Clone)]
#[cfg_attr(feature = "zeroize", derive(zeroize::Zeroize, zeroize::ZeroizeOnDrop))]
pub struct DerivedKeys {
    /// AES-256 key used for CBC encryption and decryption.
    enc_key: [u8; AES_KEY_SIZE],
    /// HMAC-SHA256 key used to compute the MAC over the version byte, IV and ciphertext.
    mac_key: [u8; AES_KEY_SIZE],
    /// HMAC-SHA256 key used to derive IVs for deterministic encryption.
    iv_key: [u8; AES_KEY_SIZE],
}
84
85impl DerivedKeys {
86 pub fn derive(cek: &[u8]) -> Result<Self, EncryptionError> {
99 if cek.len() != AES_KEY_SIZE {
100 return Err(EncryptionError::ConfigurationError(format!(
101 "CEK must be {} bytes, got {}",
102 AES_KEY_SIZE,
103 cek.len()
104 )));
105 }
106
107 let enc_key = Self::derive_key(cek, ENCRYPTION_KEY_SALT)?;
108 let mac_key = Self::derive_key(cek, MAC_KEY_SALT)?;
109 let iv_key = Self::derive_key(cek, IV_KEY_SALT)?;
110
111 Ok(Self {
112 enc_key,
113 mac_key,
114 iv_key,
115 })
116 }
117
118 fn derive_key(cek: &[u8], salt: &str) -> Result<[u8; AES_KEY_SIZE], EncryptionError> {
120 let mut mac = HmacSha256::new_from_slice(cek)
121 .map_err(|e| EncryptionError::EncryptionFailed(format!("HMAC init failed: {e}")))?;
122
123 for unit in salt.encode_utf16() {
124 mac.update(&unit.to_le_bytes());
125 }
126
127 let result = mac.finalize().into_bytes();
128 let mut key = [0u8; AES_KEY_SIZE];
129 key.copy_from_slice(&result);
130 Ok(key)
131 }
132
133 pub fn generate_iv(
138 &self,
139 encryption_type: EncryptionType,
140 plaintext: &[u8],
141 ) -> Result<[u8; IV_SIZE], EncryptionError> {
142 match encryption_type {
143 EncryptionType::Randomized => {
144 let mut iv = [0u8; IV_SIZE];
145 rand::thread_rng().fill_bytes(&mut iv);
146 Ok(iv)
147 }
148 EncryptionType::Deterministic => {
149 let mut mac = HmacSha256::new_from_slice(&self.iv_key).map_err(|e| {
151 EncryptionError::EncryptionFailed(format!("HMAC init failed: {e}"))
152 })?;
153 mac.update(plaintext);
154 let result = mac.finalize().into_bytes();
155 let mut iv = [0u8; IV_SIZE];
156 iv.copy_from_slice(&result[..IV_SIZE]);
157 Ok(iv)
158 }
159 }
160 }
161}
162
163#[cfg(not(feature = "zeroize"))]
166impl Drop for DerivedKeys {
167 fn drop(&mut self) {
168 self.enc_key.fill(0);
172 self.mac_key.fill(0);
173 self.iv_key.fill(0);
174 }
175}
176
/// Authenticated encryptor/decryptor for SQL Server cell values using
/// `AEAD_AES_256_CBC_HMAC_SHA256`.
///
/// Holds the sub-keys derived once from the column encryption key; construct it with
/// [`AeadEncryptor::new`].
pub struct AeadEncryptor {
    /// Encryption, MAC and IV sub-keys derived from the CEK.
    keys: DerivedKeys,
}
183
184impl AeadEncryptor {
185 pub fn new(cek: &[u8]) -> Result<Self, EncryptionError> {
195 let keys = DerivedKeys::derive(cek)?;
196 Ok(Self { keys })
197 }
198
199 pub fn encrypt(
214 &self,
215 plaintext: &[u8],
216 encryption_type: EncryptionType,
217 ) -> Result<Vec<u8>, EncryptionError> {
218 let iv = self.keys.generate_iv(encryption_type, plaintext)?;
220
221 let padded_len = ((plaintext.len() / AES_BLOCK_SIZE) + 1) * AES_BLOCK_SIZE;
223 let mut cipher_buf = vec![0u8; padded_len];
224 cipher_buf[..plaintext.len()].copy_from_slice(plaintext);
225
226 let cipher = Aes256CbcEnc::new_from_slices(&self.keys.enc_key, &iv)
228 .map_err(|e| EncryptionError::EncryptionFailed(format!("AES init failed: {e}")))?;
229
230 let ciphertext = cipher
231 .encrypt_padded_mut::<Pkcs7>(&mut cipher_buf, plaintext.len())
232 .map_err(|e| {
233 EncryptionError::EncryptionFailed(format!("AES encryption failed: {e}"))
234 })?;
235
236 let mac = self.compute_mac(&iv, ciphertext)?;
238
239 let mut output = Vec::with_capacity(1 + MAC_SIZE + IV_SIZE + ciphertext.len());
241 output.push(VERSION_BYTE);
242 output.extend_from_slice(&mac);
243 output.extend_from_slice(&iv);
244 output.extend_from_slice(ciphertext);
245
246 Ok(output)
247 }
248
249 pub fn decrypt(&self, ciphertext: &[u8]) -> Result<Vec<u8>, EncryptionError> {
267 if ciphertext.len() < MIN_CIPHERTEXT_SIZE {
269 return Err(EncryptionError::DecryptionFailed(format!(
270 "Ciphertext too short: {} bytes, minimum {}",
271 ciphertext.len(),
272 MIN_CIPHERTEXT_SIZE
273 )));
274 }
275
276 if ciphertext[0] != VERSION_BYTE {
278 return Err(EncryptionError::DecryptionFailed(format!(
279 "Invalid version byte: expected {:#04x}, got {:#04x}",
280 VERSION_BYTE, ciphertext[0]
281 )));
282 }
283
284 let stored_mac = &ciphertext[1..1 + MAC_SIZE];
286 let iv = &ciphertext[1 + MAC_SIZE..1 + MAC_SIZE + IV_SIZE];
287 let encrypted_data = &ciphertext[1 + MAC_SIZE + IV_SIZE..];
288
289 let computed_mac = self.compute_mac(iv, encrypted_data)?;
291 if !constant_time_compare(stored_mac, &computed_mac) {
292 return Err(EncryptionError::DecryptionFailed(
293 "MAC verification failed: data may be tampered".into(),
294 ));
295 }
296
297 let cipher = Aes256CbcDec::new_from_slices(&self.keys.enc_key, iv)
299 .map_err(|e| EncryptionError::DecryptionFailed(format!("AES init failed: {e}")))?;
300
301 let mut buf = encrypted_data.to_vec();
302 let plaintext = cipher.decrypt_padded_mut::<Pkcs7>(&mut buf).map_err(|e| {
303 EncryptionError::DecryptionFailed(format!("AES decryption failed: {e}"))
304 })?;
305
306 Ok(plaintext.to_vec())
307 }
308
309 fn compute_mac(&self, iv: &[u8], ciphertext: &[u8]) -> Result<[u8; MAC_SIZE], EncryptionError> {
313 let mut mac = HmacSha256::new_from_slice(&self.keys.mac_key)
314 .map_err(|e| EncryptionError::EncryptionFailed(format!("HMAC init failed: {e}")))?;
315
316 mac.update(&[VERSION_BYTE]);
317 mac.update(iv);
318 mac.update(ciphertext);
319 mac.update(&[1u8]); let result = mac.finalize().into_bytes();
322 let mut output = [0u8; MAC_SIZE];
323 output.copy_from_slice(&result);
324 Ok(output)
325 }
326}
327
/// Compares two byte slices for equality without a data-dependent early exit.
///
/// Every byte pair is XOR-ed into one accumulator, and the loop never stops at the first
/// difference. Slices of different length compare unequal immediately. In this file both
/// arguments are MACs of `MAC_SIZE` bytes, so the length itself is not secret.
fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len()
        && a
            .iter()
            .zip(b)
            .fold(0u8, |acc, (&left, &right)| acc | (left ^ right))
            == 0
}
340
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Sequential CEK `00 01 02 .. 1f`; the input for the key-derivation spec vectors below.
    fn test_cek() -> [u8; 32] {
        [
            0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d,
            0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b,
            0x1c, 0x1d, 0x1e, 0x1f,
        ]
    }

    #[test]
    fn test_key_derivation() {
        let cek = test_cek();
        let keys = DerivedKeys::derive(&cek).unwrap();

        // Derived keys must be non-trivial...
        assert!(!keys.enc_key.iter().all(|&b| b == 0));
        assert!(!keys.mac_key.iter().all(|&b| b == 0));
        assert!(!keys.iv_key.iter().all(|&b| b == 0));

        // ...and pairwise distinct, since each uses a different salt.
        assert_ne!(keys.enc_key, keys.mac_key);
        assert_ne!(keys.mac_key, keys.iv_key);
        assert_ne!(keys.enc_key, keys.iv_key);
    }

    /// Pins the derived sub-keys for `test_cek()` to known expected bytes.
    #[test]
    fn test_key_derivation_spec_vectors() {
        let keys = DerivedKeys::derive(&test_cek()).unwrap();

        let spec_enc_key = [
            0x6c, 0x00, 0x21, 0xc6, 0xbd, 0xb8, 0x6c, 0xa2, 0xbc, 0x0f, 0x82, 0x42, 0x9c, 0x9d,
            0x32, 0x33, 0xc7, 0xc9, 0xb8, 0x5c, 0x2b, 0xba, 0x43, 0xcb, 0xb2, 0xc8, 0xae, 0xa6,
            0xfa, 0x83, 0x01, 0x1f,
        ];
        let spec_mac_key = [
            0xa9, 0x35, 0x1d, 0xf2, 0xfd, 0x2a, 0x87, 0x57, 0x99, 0xd7, 0x9b, 0x04, 0xe6, 0x11,
            0x28, 0x71, 0xed, 0x46, 0x27, 0xa8, 0x36, 0xb3, 0x2c, 0xa1, 0x05, 0xf5, 0x18, 0xa3,
            0xe6, 0x3a, 0x16, 0x4f,
        ];
        let spec_iv_key = [
            0x7b, 0x1e, 0xe9, 0xe7, 0x32, 0x24, 0x48, 0xdb, 0x99, 0x9d, 0x5f, 0xc9, 0x29, 0x47,
            0xb3, 0x6d, 0x7c, 0x03, 0x49, 0x21, 0xec, 0xc5, 0xf9, 0x8e, 0x08, 0x8f, 0xc8, 0x7b,
            0x81, 0x74, 0xb1, 0x2e,
        ];

        assert_eq!(keys.enc_key, spec_enc_key);
        assert_eq!(keys.mac_key, spec_mac_key);
        assert_eq!(keys.iv_key, spec_iv_key);
    }

    #[test]
    fn test_key_derivation_invalid_length() {
        let short_cek = [0u8; 16];
        let result = DerivedKeys::derive(&short_cek);
        assert!(result.is_err());

        // Too long is rejected as well, and the error surfaces through `AeadEncryptor::new`.
        assert!(DerivedKeys::derive(&[0u8; AES_KEY_SIZE + 1]).is_err());
        assert!(AeadEncryptor::new(&[]).is_err());
    }

    #[test]
    fn test_encrypt_decrypt_randomized() {
        let cek = test_cek();
        let encryptor = AeadEncryptor::new(&cek).unwrap();

        let plaintext = b"Hello, SQL Server Always Encrypted!";
        let ciphertext = encryptor
            .encrypt(plaintext, EncryptionType::Randomized)
            .unwrap();

        assert!(ciphertext.len() >= MIN_CIPHERTEXT_SIZE);

        assert_eq!(ciphertext[0], VERSION_BYTE);

        let decrypted = encryptor.decrypt(&ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_decrypt_deterministic() {
        let cek = test_cek();
        let encryptor = AeadEncryptor::new(&cek).unwrap();

        let plaintext = b"Deterministic encryption test";

        let ciphertext1 = encryptor
            .encrypt(plaintext, EncryptionType::Deterministic)
            .unwrap();
        let ciphertext2 = encryptor
            .encrypt(plaintext, EncryptionType::Deterministic)
            .unwrap();

        assert_eq!(ciphertext1, ciphertext2);

        let decrypted = encryptor.decrypt(&ciphertext1).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_deterministic_stable_across_instances_and_bound_to_key() {
        let plaintext = b"Deterministic encryption test";

        // Two encryptors built from the same CEK agree byte-for-byte.
        let first = AeadEncryptor::new(&test_cek()).unwrap();
        let second = AeadEncryptor::new(&test_cek()).unwrap();
        let from_first = first
            .encrypt(plaintext, EncryptionType::Deterministic)
            .unwrap();
        let from_second = second
            .encrypt(plaintext, EncryptionType::Deterministic)
            .unwrap();
        assert_eq!(from_first, from_second);

        // A different CEK yields a different ciphertext that the first key cannot authenticate.
        let mut other_cek = test_cek();
        other_cek[0] ^= 0xFF;
        let other = AeadEncryptor::new(&other_cek).unwrap();
        let from_other = other
            .encrypt(plaintext, EncryptionType::Deterministic)
            .unwrap();
        assert_ne!(from_first, from_other);
        assert!(first.decrypt(&from_other).is_err());
    }

    #[test]
    fn test_deterministic_differs_for_different_plaintexts() {
        let encryptor = AeadEncryptor::new(&test_cek()).unwrap();

        let ciphertext1 = encryptor
            .encrypt(b"value one", EncryptionType::Deterministic)
            .unwrap();
        let ciphertext2 = encryptor
            .encrypt(b"value two", EncryptionType::Deterministic)
            .unwrap();

        assert_ne!(ciphertext1, ciphertext2);
    }

    #[test]
    fn test_randomized_produces_different_ciphertext() {
        let cek = test_cek();
        let encryptor = AeadEncryptor::new(&cek).unwrap();

        let plaintext = b"Same plaintext";

        let ciphertext1 = encryptor
            .encrypt(plaintext, EncryptionType::Randomized)
            .unwrap();
        let ciphertext2 = encryptor
            .encrypt(plaintext, EncryptionType::Randomized)
            .unwrap();

        assert_ne!(ciphertext1, ciphertext2);

        assert_eq!(
            encryptor.decrypt(&ciphertext1).unwrap(),
            encryptor.decrypt(&ciphertext2).unwrap()
        );
    }

    #[test]
    fn test_decrypt_tampered_data() {
        let cek = test_cek();
        let encryptor = AeadEncryptor::new(&cek).unwrap();

        let plaintext = b"Original data";
        let mut ciphertext = encryptor
            .encrypt(plaintext, EncryptionType::Randomized)
            .unwrap();

        // Flip every bit of the last ciphertext byte.
        let last_idx = ciphertext.len() - 1;
        ciphertext[last_idx] ^= 0xFF;

        let result = encryptor.decrypt(&ciphertext);
        assert!(result.is_err());
    }

    #[test]
    fn test_decrypt_tampered_mac_iv_and_body() {
        let encryptor = AeadEncryptor::new(&test_cek()).unwrap();
        let ciphertext = encryptor
            .encrypt(b"Original data", EncryptionType::Randomized)
            .unwrap();

        // First byte of the MAC, of the IV, and of the encrypted body.
        for idx in [1, 1 + MAC_SIZE, 1 + MAC_SIZE + IV_SIZE] {
            let mut tampered = ciphertext.clone();
            tampered[idx] ^= 0x01;
            let err = encryptor.decrypt(&tampered).unwrap_err();
            assert!(err.to_string().contains("MAC verification failed"));
        }
    }

    #[test]
    fn test_decrypt_truncated_ciphertext() {
        let encryptor = AeadEncryptor::new(&test_cek()).unwrap();
        let ciphertext = encryptor
            .encrypt(b"Hello, SQL Server Always Encrypted!", EncryptionType::Randomized)
            .unwrap();

        // Dropping the final byte keeps the input above the minimum size but must still fail.
        let truncated = &ciphertext[..ciphertext.len() - 1];
        assert!(truncated.len() >= MIN_CIPHERTEXT_SIZE);
        assert!(encryptor.decrypt(truncated).is_err());
    }

    #[test]
    fn test_decrypt_invalid_version() {
        let cek = test_cek();
        let encryptor = AeadEncryptor::new(&cek).unwrap();

        let plaintext = b"Test data";
        let mut ciphertext = encryptor
            .encrypt(plaintext, EncryptionType::Randomized)
            .unwrap();

        ciphertext[0] = 0x02;

        let result = encryptor.decrypt(&ciphertext);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Invalid version byte")
        );
    }

    #[test]
    fn test_decrypt_too_short() {
        let cek = test_cek();
        let encryptor = AeadEncryptor::new(&cek).unwrap();

        let short_data = vec![0u8; 10];
        let result = encryptor.decrypt(&short_data);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("too short"));
    }

    #[test]
    fn test_empty_plaintext() {
        let cek = test_cek();
        let encryptor = AeadEncryptor::new(&cek).unwrap();

        let plaintext = b"";
        let ciphertext = encryptor
            .encrypt(plaintext, EncryptionType::Randomized)
            .unwrap();

        // An empty input still encrypts to one full padding block: exactly the minimum size.
        assert_eq!(ciphertext.len(), MIN_CIPHERTEXT_SIZE);

        let decrypted = encryptor.decrypt(&ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_ciphertext_length_is_block_aligned() {
        let encryptor = AeadEncryptor::new(&test_cek()).unwrap();

        // Lengths on both sides of each block boundary exercise 16 bytes of padding vs. 1 byte.
        for len in [0usize, 1, 15, 16, 17, 31, 32, 33] {
            let plaintext = vec![0xA5u8; len];
            let ciphertext = encryptor
                .encrypt(&plaintext, EncryptionType::Randomized)
                .unwrap();

            let body_len = ciphertext.len() - (1 + MAC_SIZE + IV_SIZE);
            assert_eq!(body_len, (len / AES_BLOCK_SIZE + 1) * AES_BLOCK_SIZE);
            assert_eq!(encryptor.decrypt(&ciphertext).unwrap(), plaintext);
        }
    }

    #[test]
    fn test_large_plaintext() {
        let cek = test_cek();
        let encryptor = AeadEncryptor::new(&cek).unwrap();

        let plaintext: Vec<u8> = (0..10240).map(|i| (i % 256) as u8).collect();
        let ciphertext = encryptor
            .encrypt(&plaintext, EncryptionType::Randomized)
            .unwrap();

        let decrypted = encryptor.decrypt(&ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_constant_time_compare() {
        let a = [1, 2, 3, 4, 5];
        let b = [1, 2, 3, 4, 5];
        let c = [1, 2, 3, 4, 6];

        assert!(constant_time_compare(&a, &b));
        assert!(!constant_time_compare(&a, &c));
        assert!(!constant_time_compare(&a, &[1, 2, 3]));
    }
}