1use aes_gcm::{
4 Aes256Gcm, Nonce,
5 aead::{Aead, KeyInit},
6};
7
8use thiserror::Error;
9
10#[derive(Error, Debug)]
11pub enum AesError {
12 #[error("Encryption failed")]
13 EncryptionFailed,
14
15 #[error("Decryption failed")]
16 DecryptionFailed,
17
18 #[error("Invalid key length: expected 32, got {0}")]
19 InvalidKeyLength(usize),
20
21 #[error("Invalid nonce length: expected 12, got {0}")]
22 InvalidNonceLength(usize),
23
24 #[error("Ciphertext too short")]
25 CiphertextTooShort,
26}
27
28pub struct Aes256GcmCipher {
30 cipher: Aes256Gcm,
31}
32
33impl Aes256GcmCipher {
34 pub fn new(key: &[u8; 32]) -> Self {
36 let cipher = Aes256Gcm::new_from_slice(key).expect("key length is 32");
37 Self { cipher }
38 }
39
40 pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>, AesError> {
43 use rand::RngCore;
44
45 let mut nonce_bytes = [0u8; 12];
46 rand::rngs::OsRng.fill_bytes(&mut nonce_bytes);
47 let nonce = Nonce::from_slice(&nonce_bytes);
48
49 let ciphertext = self
50 .cipher
51 .encrypt(nonce, plaintext)
52 .map_err(|_| AesError::EncryptionFailed)?;
53
54 let mut result = Vec::with_capacity(12 + ciphertext.len());
55 result.extend_from_slice(&nonce_bytes);
56 result.extend_from_slice(&ciphertext);
57
58 Ok(result)
59 }
60
61 pub fn encrypt_with_nonce(
63 &self,
64 nonce: &[u8; 12],
65 plaintext: &[u8],
66 ) -> Result<Vec<u8>, AesError> {
67 let nonce = Nonce::from_slice(nonce);
68 self.cipher
69 .encrypt(nonce, plaintext)
70 .map_err(|_| AesError::EncryptionFailed)
71 }
72
73 pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>, AesError> {
75 if data.len() < 12 {
76 return Err(AesError::CiphertextTooShort);
77 }
78
79 let (nonce_bytes, ciphertext) = data.split_at(12);
80 let nonce = Nonce::from_slice(nonce_bytes);
81
82 self.cipher
83 .decrypt(nonce, ciphertext)
84 .map_err(|_| AesError::DecryptionFailed)
85 }
86
87 pub fn decrypt_with_nonce(
89 &self,
90 nonce: &[u8; 12],
91 ciphertext: &[u8],
92 ) -> Result<Vec<u8>, AesError> {
93 let nonce = Nonce::from_slice(nonce);
94 self.cipher
95 .decrypt(nonce, ciphertext)
96 .map_err(|_| AesError::DecryptionFailed)
97 }
98}
99
100pub fn derive_aes_key(shared_secret: &[u8; 32]) -> [u8; 32] {
102 use sha2::{Digest, Sha256};
103
104 let mut hasher = Sha256::new();
105 hasher.update(b"APFSDS-AES-KEY-DERIVE");
106 hasher.update(shared_secret);
107 hasher.finalize().into()
108}
109
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encrypt_decrypt() {
        let key = [0u8; 32];
        let cipher = Aes256GcmCipher::new(&key);

        let plaintext = b"Hello, APFSDS!";
        let encrypted = cipher.encrypt(plaintext).unwrap();
        let decrypted = cipher.decrypt(&encrypted).unwrap();

        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_decrypt_wrong_key() {
        let key1 = [0u8; 32];
        let key2 = [1u8; 32];

        let cipher1 = Aes256GcmCipher::new(&key1);
        let cipher2 = Aes256GcmCipher::new(&key2);

        let encrypted = cipher1.encrypt(b"secret").unwrap();
        let result = cipher2.decrypt(&encrypted);

        assert!(result.is_err());
    }

    #[test]
    fn test_key_derivation() {
        let shared_secret = [42u8; 32];
        let key1 = derive_aes_key(&shared_secret);
        let key2 = derive_aes_key(&shared_secret);

        assert_eq!(key1, key2);

        let other_secret = [43u8; 32];
        let key3 = derive_aes_key(&other_secret);
        assert_ne!(key1, key3);
    }

    /// An empty message still carries a 12-byte nonce and a 16-byte tag.
    #[test]
    fn test_empty_plaintext_round_trip() {
        let cipher = Aes256GcmCipher::new(&[7u8; 32]);

        let encrypted = cipher.encrypt(b"").unwrap();
        assert_eq!(encrypted.len(), 12 + 16);
        assert!(cipher.decrypt(&encrypted).unwrap().is_empty());
    }

    /// Each call must draw a fresh nonce, so identical plaintexts differ.
    #[test]
    fn test_fresh_nonce_per_encryption() {
        let cipher = Aes256GcmCipher::new(&[7u8; 32]);

        let first = cipher.encrypt(b"same message").unwrap();
        let second = cipher.encrypt(b"same message").unwrap();

        assert_ne!(&first[..12], &second[..12]);
        assert_ne!(first, second);
    }

    /// Flipping a single bit of the tag must make authentication fail.
    #[test]
    fn test_tampered_ciphertext_rejected() {
        let cipher = Aes256GcmCipher::new(&[7u8; 32]);

        let mut encrypted = cipher.encrypt(b"integrity matters").unwrap();
        let last = encrypted.len() - 1;
        encrypted[last] ^= 0x01;

        assert!(matches!(
            cipher.decrypt(&encrypted),
            Err(AesError::DecryptionFailed)
        ));
    }

    /// Inputs too short to hold a nonce are rejected without touching the AEAD.
    #[test]
    fn test_short_input_rejected() {
        let cipher = Aes256GcmCipher::new(&[7u8; 32]);

        assert!(matches!(
            cipher.decrypt(&[0u8; 0]),
            Err(AesError::CiphertextTooShort)
        ));
        assert!(matches!(
            cipher.decrypt(&[0u8; 11]),
            Err(AesError::CiphertextTooShort)
        ));
        // Long enough for a nonce but not for a tag: must never succeed.
        assert!(cipher.decrypt(&[0u8; 20]).is_err());
    }

    /// The framed format is exactly `nonce || encrypt_with_nonce(nonce, pt)`.
    #[test]
    fn test_explicit_nonce_matches_framed_format() {
        let cipher = Aes256GcmCipher::new(&[9u8; 32]);
        let plaintext = b"framed vs explicit";

        let framed = cipher.encrypt(plaintext).unwrap();
        let (head, body) = framed.split_at(12);
        let mut nonce = [0u8; 12];
        nonce.copy_from_slice(head);

        let decrypted = cipher.decrypt_with_nonce(&nonce, body).unwrap();
        assert_eq!(decrypted.as_slice(), plaintext.as_slice());

        let reencrypted = cipher.encrypt_with_nonce(&nonce, plaintext).unwrap();
        assert_eq!(reencrypted.as_slice(), body);
    }
}