cortexai_encryption/
aes_cipher.rs1#[cfg(feature = "aes")]
4use aes_gcm::{
5 aead::{Aead, KeyInit},
6 Aes256Gcm, Nonce,
7};
8use rand::RngCore;
9
10use crate::error::{CryptoError, CryptoResult};
11use crate::key::EncryptionKey;
12use crate::traits::Cipher;
13
/// AES-256 in Galois/Counter Mode, exposed through the crate's [`Cipher`] trait.
///
/// Every call to `encrypt` draws a fresh random nonce and emits
/// `nonce (12 bytes) || aes_gcm output (ciphertext followed by a 16-byte tag)`.
/// `decrypt` expects exactly that framing.
#[cfg(feature = "aes")]
pub struct Aes256GcmCipher {
    /// Keyed AES-256-GCM instance from the `aes_gcm` crate.
    cipher: Aes256Gcm,
}
30
#[cfg(feature = "aes")]
impl Aes256GcmCipher {
    /// Required key length in bytes (a 256-bit AES key).
    pub const KEY_SIZE: usize = 32;
    /// Length in bytes of the nonce prepended to every ciphertext.
    pub const NONCE_SIZE: usize = 12;
    /// Length in bytes of the GCM authentication tag.
    pub const TAG_SIZE: usize = 16;

    /// Creates a cipher keyed with `key`.
    ///
    /// # Errors
    ///
    /// Returns [`CryptoError::InvalidKeyLength`] when `key` is not exactly
    /// [`Self::KEY_SIZE`] bytes long.
    pub fn new(key: &EncryptionKey) -> CryptoResult<Self> {
        let invalid_length = || CryptoError::InvalidKeyLength {
            expected: Self::KEY_SIZE,
            got: key.len(),
        };

        if key.len() != Self::KEY_SIZE {
            return Err(invalid_length());
        }

        // `new_from_slice` only rejects keys of the wrong length, which the
        // check above already rules out. If it ever fails anyway, report it as
        // the key-setup problem it is rather than as an encryption failure.
        let cipher = Aes256Gcm::new_from_slice(key.as_bytes()).map_err(|_| invalid_length())?;

        Ok(Self { cipher })
    }

    /// Returns a fresh random nonce filled from `rand::thread_rng()`.
    ///
    /// NOTE: with randomly generated 96-bit nonces, NIST SP 800-38D (§8.3)
    /// limits a single key to 2^32 encryptions; rotate keys well before any
    /// one key approaches that volume.
    fn generate_nonce() -> [u8; Self::NONCE_SIZE] {
        let mut nonce = [0u8; Self::NONCE_SIZE];
        rand::thread_rng().fill_bytes(&mut nonce);
        nonce
    }
}
62
#[cfg(feature = "aes")]
impl Cipher for Aes256GcmCipher {
    /// Encrypts `plaintext` under a freshly generated random nonce.
    ///
    /// The returned buffer is `nonce || aes_gcm output`. When
    /// `associated_data` is given it is authenticated (not encrypted), and the
    /// same bytes must be passed to `decrypt`.
    fn encrypt(&self, plaintext: &[u8], associated_data: Option<&[u8]>) -> CryptoResult<Vec<u8>> {
        use aes_gcm::aead::Payload;

        let nonce_bytes = Self::generate_nonce();

        // `None` means "no AAD", which `aead` encodes as an empty AAD slice;
        // the plain `&[u8]` -> `Payload` conversion does exactly the same.
        let payload = Payload {
            msg: plaintext,
            aad: associated_data.unwrap_or(&[]),
        };

        let sealed = self
            .cipher
            .encrypt(Nonce::from_slice(&nonce_bytes), payload)
            .map_err(|e| CryptoError::EncryptionFailed(e.to_string()))?;

        let mut framed = Vec::with_capacity(nonce_bytes.len() + sealed.len());
        framed.extend_from_slice(&nonce_bytes);
        framed.extend_from_slice(&sealed);
        Ok(framed)
    }

    /// Splits off the leading nonce, then authenticates and decrypts the rest.
    ///
    /// Inputs shorter than a nonce plus a tag are rejected with
    /// `InvalidCiphertext`; any authentication failure (wrong key, wrong AAD,
    /// tampering) surfaces as `DecryptionFailed`.
    fn decrypt(&self, ciphertext: &[u8], associated_data: Option<&[u8]>) -> CryptoResult<Vec<u8>> {
        use aes_gcm::aead::Payload;

        let minimum_len = Self::NONCE_SIZE + Self::TAG_SIZE;
        if ciphertext.len() < minimum_len {
            return Err(CryptoError::InvalidCiphertext(
                "ciphertext too short".to_string(),
            ));
        }

        let (nonce_bytes, sealed) = ciphertext.split_at(Self::NONCE_SIZE);
        let payload = Payload {
            msg: sealed,
            aad: associated_data.unwrap_or(&[]),
        };

        self.cipher
            .decrypt(Nonce::from_slice(nonce_bytes), payload)
            .map_err(|e| CryptoError::DecryptionFailed(e.to_string()))
    }

    /// Human-readable algorithm identifier.
    fn algorithm(&self) -> &'static str {
        "AES-256-GCM"
    }

    /// Key length in bytes; always `Aes256GcmCipher::KEY_SIZE`.
    fn key_size(&self) -> usize {
        Self::KEY_SIZE
    }

    /// Nonce length in bytes; always `Aes256GcmCipher::NONCE_SIZE`.
    fn nonce_size(&self) -> usize {
        Self::NONCE_SIZE
    }

    /// Authentication tag length in bytes; always `Aes256GcmCipher::TAG_SIZE`.
    fn tag_size(&self) -> usize {
        Self::TAG_SIZE
    }
}
140
#[cfg(all(test, feature = "aes"))]
mod tests {
    use super::*;

    /// Builds a cipher with a freshly generated, correctly sized key.
    fn fresh_cipher() -> Aes256GcmCipher {
        let key = EncryptionKey::generate(Aes256GcmCipher::KEY_SIZE);
        Aes256GcmCipher::new(&key).unwrap()
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let cipher = fresh_cipher();

        let plaintext = b"Hello, World! This is a secret message.";
        let ciphertext = cipher.encrypt(plaintext, None).unwrap();
        let decrypted = cipher.decrypt(&ciphertext, None).unwrap();

        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_encrypt_decrypt_with_aad() {
        let cipher = fresh_cipher();

        let plaintext = b"Secret data";
        let aad = b"session-id-12345";

        let ciphertext = cipher.encrypt(plaintext, Some(aad)).unwrap();
        let decrypted = cipher.decrypt(&ciphertext, Some(aad)).unwrap();

        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_aad_mismatch_fails() {
        let cipher = fresh_cipher();

        let plaintext = b"Secret data";
        let aad1 = b"session-id-12345";
        let aad2 = b"session-id-67890";

        let ciphertext = cipher.encrypt(plaintext, Some(aad1)).unwrap();
        let result = cipher.decrypt(&ciphertext, Some(aad2));

        assert!(result.is_err());
    }

    #[test]
    fn test_missing_aad_on_decrypt_fails() {
        let cipher = fresh_cipher();

        let ciphertext = cipher
            .encrypt(b"Secret data", Some(b"session-id-12345"))
            .unwrap();

        assert!(cipher.decrypt(&ciphertext, None).is_err());
    }

    #[test]
    fn test_empty_aad_is_equivalent_to_none() {
        let cipher = fresh_cipher();

        let plaintext = b"Secret data";
        let ciphertext = cipher.encrypt(plaintext, Some(&[])).unwrap();
        let decrypted = cipher.decrypt(&ciphertext, None).unwrap();

        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        let cipher = fresh_cipher();

        let plaintext = b"Secret data";
        let mut ciphertext = cipher.encrypt(plaintext, None).unwrap();

        // Flip every bit of the final byte (part of the authentication tag).
        if let Some(byte) = ciphertext.last_mut() {
            *byte ^= 0xFF;
        }

        let result = cipher.decrypt(&ciphertext, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_different_keys_fail() {
        let cipher1 = fresh_cipher();
        let cipher2 = fresh_cipher();

        let plaintext = b"Secret data";
        let ciphertext = cipher1.encrypt(plaintext, None).unwrap();

        let result = cipher2.decrypt(&ciphertext, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_key_length() {
        let key = EncryptionKey::generate(16);
        let result = Aes256GcmCipher::new(&key);

        assert!(matches!(
            result,
            Err(CryptoError::InvalidKeyLength {
                expected: 32,
                got: 16
            })
        ));
    }

    #[test]
    fn test_ciphertext_too_short_is_rejected() {
        let cipher = fresh_cipher();

        let too_short = vec![0u8; Aes256GcmCipher::NONCE_SIZE + Aes256GcmCipher::TAG_SIZE - 1];
        let result = cipher.decrypt(&too_short, None);

        assert!(matches!(result, Err(CryptoError::InvalidCiphertext(_))));
    }

    #[test]
    fn test_nonce_is_fresh_per_encryption() {
        let cipher = fresh_cipher();

        let plaintext = b"same message twice";
        let first = cipher.encrypt(plaintext, None).unwrap();
        let second = cipher.encrypt(plaintext, None).unwrap();

        // Distinct random nonces make both the nonce prefix and the whole
        // output differ even for identical plaintexts.
        assert_ne!(
            &first[..Aes256GcmCipher::NONCE_SIZE],
            &second[..Aes256GcmCipher::NONCE_SIZE]
        );
        assert_ne!(first, second);
    }

    #[test]
    fn test_ciphertext_format() {
        let cipher = fresh_cipher();

        let plaintext = b"Test";
        let ciphertext = cipher.encrypt(plaintext, None).unwrap();

        // nonce || encrypted payload || tag
        assert_eq!(
            ciphertext.len(),
            Aes256GcmCipher::NONCE_SIZE + plaintext.len() + Aes256GcmCipher::TAG_SIZE
        );
    }

    #[test]
    fn test_empty_plaintext() {
        let cipher = fresh_cipher();

        let plaintext = b"";
        let ciphertext = cipher.encrypt(plaintext, None).unwrap();
        let decrypted = cipher.decrypt(&ciphertext, None).unwrap();

        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_large_plaintext() {
        let cipher = fresh_cipher();

        // 1 MiB of data.
        let plaintext = vec![0xABu8; 1024 * 1024];
        let ciphertext = cipher.encrypt(&plaintext, None).unwrap();
        let decrypted = cipher.decrypt(&ciphertext, None).unwrap();

        assert_eq!(plaintext, decrypted);
    }
}