// cortexai_encryption/aes_cipher.rs

//! AES-256-GCM authenticated encryption (AEAD) implementation.

// Everything in this module that uses these names is compiled only with the
// `aes` feature, so the imports are gated the same way to avoid unused-import
// warnings (or unresolved-crate errors) when the feature is disabled.
#[cfg(feature = "aes")]
use aes_gcm::{
    aead::{Aead, KeyInit},
    Aes256Gcm, Nonce,
};
#[cfg(feature = "aes")]
use rand::RngCore;

#[cfg(feature = "aes")]
use crate::error::{CryptoError, CryptoResult};
#[cfg(feature = "aes")]
use crate::key::EncryptionKey;
#[cfg(feature = "aes")]
use crate::traits::Cipher;

/// AES-256-GCM cipher implementation.
///
/// Provides authenticated encryption with associated data (AEAD).
/// - Key size: 256 bits (32 bytes)
/// - Nonce size: 96 bits (12 bytes)
/// - Tag size: 128 bits (16 bytes)
///
/// # Ciphertext Format
///
/// ```text
/// [nonce: 12 bytes][ciphertext + tag: variable]
/// ```
///
/// # Nonces
///
/// Every call to `encrypt` draws a fresh random 96-bit nonce and prepends it to
/// the output. With randomly generated 96-bit nonces, NIST SP 800-38D limits a
/// single key to 2^32 encryptions; rotate keys well before that volume.
#[cfg(feature = "aes")]
pub struct Aes256GcmCipher {
    /// AES-256-GCM instance keyed from the caller's key in `new`.
    cipher: Aes256Gcm,
}

/// Redacted `Debug` output: only the type name is printed, so formatting a
/// cipher (e.g. in logs or error context) can never expose key-derived state.
#[cfg(feature = "aes")]
impl std::fmt::Debug for Aes256GcmCipher {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Aes256GcmCipher").finish_non_exhaustive()
    }
}

#[cfg(feature = "aes")]
impl Aes256GcmCipher {
    /// Key size in bytes (256 bits).
    pub const KEY_SIZE: usize = 32;
    /// Nonce size in bytes (96 bits).
    pub const NONCE_SIZE: usize = 12;
    /// Authentication tag size in bytes (128 bits).
    pub const TAG_SIZE: usize = 16;

    /// Create a new AES-256-GCM cipher with the given key.
    ///
    /// # Errors
    ///
    /// Returns [`CryptoError::InvalidKeyLength`] if `key` is not exactly
    /// [`Self::KEY_SIZE`] bytes long.
    pub fn new(key: &EncryptionKey) -> CryptoResult<Self> {
        let key_length_error = || CryptoError::InvalidKeyLength {
            expected: Self::KEY_SIZE,
            got: key.len(),
        };

        if key.len() != Self::KEY_SIZE {
            return Err(key_length_error());
        }

        // `new_from_slice` can only fail with `InvalidLength`, which the guard above
        // already rules out. Should it ever fire anyway, report it as the key-length
        // problem it is rather than as an encryption failure.
        let cipher = Aes256Gcm::new_from_slice(key.as_bytes()).map_err(|_| key_length_error())?;

        Ok(Self { cipher })
    }

    /// Draw a fresh random 96-bit nonce from `rand::thread_rng()`.
    ///
    /// A new nonce is generated per message; callers must never reuse one under
    /// the same key, which is why nonces are not accepted from outside.
    fn generate_nonce() -> [u8; Self::NONCE_SIZE] {
        let mut nonce = [0u8; Self::NONCE_SIZE];
        rand::thread_rng().fill_bytes(&mut nonce);
        nonce
    }
}

#[cfg(feature = "aes")]
impl Cipher for Aes256GcmCipher {
    /// Encrypt `plaintext` under a freshly generated random nonce.
    ///
    /// `associated_data` is authenticated but not encrypted. `None` is the same
    /// as an empty AAD, because a bare slice converts into an AEAD `Payload`
    /// with `aad: b""`. The returned buffer is `nonce || ciphertext || tag`.
    ///
    /// # Errors
    ///
    /// Returns [`CryptoError::EncryptionFailed`] if the AEAD implementation
    /// rejects the input.
    fn encrypt(&self, plaintext: &[u8], associated_data: Option<&[u8]>) -> CryptoResult<Vec<u8>> {
        use aes_gcm::aead::Payload;

        let nonce_bytes = Self::generate_nonce();
        let payload = Payload {
            msg: plaintext,
            aad: associated_data.unwrap_or_default(),
        };

        // `sealed` is the ciphertext with the 16-byte tag already appended.
        let sealed = self
            .cipher
            .encrypt(Nonce::from_slice(&nonce_bytes), payload)
            .map_err(|e| CryptoError::EncryptionFailed(e.to_string()))?;

        // Prepend the nonce so `decrypt` can recover it from the message itself.
        let mut result = Vec::with_capacity(Self::NONCE_SIZE + sealed.len());
        result.extend_from_slice(&nonce_bytes);
        result.extend_from_slice(&sealed);

        Ok(result)
    }

    /// Decrypt a buffer produced by `encrypt` (`nonce || ciphertext || tag`).
    ///
    /// The same `associated_data` used at encryption time must be supplied.
    /// `None` and `Some(&[])` are interchangeable.
    ///
    /// # Errors
    ///
    /// - [`CryptoError::InvalidCiphertext`] if the input is shorter than
    ///   nonce + tag.
    /// - [`CryptoError::DecryptionFailed`] if authentication fails (wrong key,
    ///   wrong AAD, or tampered data).
    fn decrypt(&self, ciphertext: &[u8], associated_data: Option<&[u8]>) -> CryptoResult<Vec<u8>> {
        use aes_gcm::aead::Payload;

        // Reject early: a valid message always carries a full nonce and tag, and
        // `split_at` below would panic on an input shorter than the nonce.
        if ciphertext.len() < Self::NONCE_SIZE + Self::TAG_SIZE {
            return Err(CryptoError::InvalidCiphertext(
                "ciphertext too short".to_string(),
            ));
        }

        let (nonce_bytes, sealed) = ciphertext.split_at(Self::NONCE_SIZE);
        let payload = Payload {
            msg: sealed,
            aad: associated_data.unwrap_or_default(),
        };

        self.cipher
            .decrypt(Nonce::from_slice(nonce_bytes), payload)
            .map_err(|e| CryptoError::DecryptionFailed(e.to_string()))
    }

    fn algorithm(&self) -> &'static str {
        "AES-256-GCM"
    }

    fn key_size(&self) -> usize {
        Self::KEY_SIZE
    }

    fn nonce_size(&self) -> usize {
        Self::NONCE_SIZE
    }

    fn tag_size(&self) -> usize {
        Self::TAG_SIZE
    }
}

#[cfg(all(test, feature = "aes"))]
mod tests {
    use super::*;

    /// Build a cipher around a freshly generated 256-bit key.
    fn fresh_cipher() -> Aes256GcmCipher {
        let key = EncryptionKey::generate(Aes256GcmCipher::KEY_SIZE);
        Aes256GcmCipher::new(&key).unwrap()
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let cipher = fresh_cipher();

        let plaintext = b"Hello, World! This is a secret message.";
        let ciphertext = cipher.encrypt(plaintext, None).unwrap();
        let decrypted = cipher.decrypt(&ciphertext, None).unwrap();

        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_encrypt_decrypt_with_aad() {
        let cipher = fresh_cipher();

        let plaintext = b"Secret data";
        let aad = b"session-id-12345";

        let ciphertext = cipher.encrypt(plaintext, Some(aad)).unwrap();
        let decrypted = cipher.decrypt(&ciphertext, Some(aad)).unwrap();

        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_aad_mismatch_fails() {
        let cipher = fresh_cipher();

        let plaintext = b"Secret data";
        let aad1 = b"session-id-12345";
        let aad2 = b"session-id-67890";

        let ciphertext = cipher.encrypt(plaintext, Some(aad1)).unwrap();
        let result = cipher.decrypt(&ciphertext, Some(aad2));

        assert!(result.is_err());
    }

    #[test]
    fn test_missing_aad_fails() {
        let cipher = fresh_cipher();

        let ciphertext = cipher
            .encrypt(b"Secret data", Some(b"session-id-12345"))
            .unwrap();

        assert!(cipher.decrypt(&ciphertext, None).is_err());
    }

    #[test]
    fn test_none_aad_matches_empty_aad() {
        let cipher = fresh_cipher();

        let plaintext = b"Secret data";
        let ciphertext = cipher.encrypt(plaintext, None).unwrap();
        let decrypted = cipher.decrypt(&ciphertext, Some(b"".as_slice())).unwrap();

        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        let cipher = fresh_cipher();

        let plaintext = b"Secret data";
        let mut ciphertext = cipher.encrypt(plaintext, None).unwrap();

        // Flip the last byte, which belongs to the authentication tag.
        if let Some(byte) = ciphertext.last_mut() {
            *byte ^= 0xFF;
        }

        let result = cipher.decrypt(&ciphertext, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_tampered_nonce_fails() {
        let cipher = fresh_cipher();

        let mut ciphertext = cipher.encrypt(b"Secret data", None).unwrap();
        // The first byte is part of the prepended nonce.
        ciphertext[0] ^= 0x01;

        assert!(cipher.decrypt(&ciphertext, None).is_err());
    }

    #[test]
    fn test_different_keys_fail() {
        let key1 = EncryptionKey::generate(Aes256GcmCipher::KEY_SIZE);
        let key2 = EncryptionKey::generate(Aes256GcmCipher::KEY_SIZE);

        let cipher1 = Aes256GcmCipher::new(&key1).unwrap();
        let cipher2 = Aes256GcmCipher::new(&key2).unwrap();

        let plaintext = b"Secret data";
        let ciphertext = cipher1.encrypt(plaintext, None).unwrap();

        let result = cipher2.decrypt(&ciphertext, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_key_length() {
        let key = EncryptionKey::generate(16); // Too short
        let result = Aes256GcmCipher::new(&key);

        assert!(matches!(
            result,
            Err(CryptoError::InvalidKeyLength {
                expected: 32,
                got: 16
            })
        ));
    }

    #[test]
    fn test_short_ciphertext_rejected() {
        let cipher = fresh_cipher();

        // One byte short of the minimum nonce + tag length.
        let too_short = vec![0u8; Aes256GcmCipher::NONCE_SIZE + Aes256GcmCipher::TAG_SIZE - 1];
        let result = cipher.decrypt(&too_short, None);

        assert!(matches!(result, Err(CryptoError::InvalidCiphertext(_))));
    }

    #[test]
    fn test_nonce_is_fresh_per_encryption() {
        let cipher = fresh_cipher();
        let n = Aes256GcmCipher::NONCE_SIZE;

        let plaintext = b"Same message twice";
        let first = cipher.encrypt(plaintext, None).unwrap();
        let second = cipher.encrypt(plaintext, None).unwrap();

        // Random 96-bit nonces collide with negligible probability.
        assert_ne!(&first[..n], &second[..n]);
        assert_ne!(first, second);
    }

    #[test]
    fn test_ciphertext_format() {
        let cipher = fresh_cipher();

        let plaintext = b"Test";
        let ciphertext = cipher.encrypt(plaintext, None).unwrap();

        // Ciphertext should be: nonce (12) + plaintext (4) + tag (16) = 32 bytes
        assert_eq!(
            ciphertext.len(),
            Aes256GcmCipher::NONCE_SIZE + plaintext.len() + Aes256GcmCipher::TAG_SIZE
        );
    }

    #[test]
    fn test_empty_plaintext() {
        let cipher = fresh_cipher();

        let plaintext = b"";
        let ciphertext = cipher.encrypt(plaintext, None).unwrap();
        let decrypted = cipher.decrypt(&ciphertext, None).unwrap();

        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_large_plaintext() {
        let cipher = fresh_cipher();

        let plaintext = vec![0xABu8; 1024 * 1024]; // 1 MB
        let ciphertext = cipher.encrypt(&plaintext, None).unwrap();
        let decrypted = cipher.decrypt(&ciphertext, None).unwrap();

        assert_eq!(plaintext, decrypted);
    }
}