// mnemo_core/encryption.rs

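//! AES-256-GCM encryption for memory content at rest.
//!
//! Provides encrypt/decrypt operations for memory content before storage.
//! The 256-bit key is supplied directly as raw bytes, as a hex string, or via the
//! hex-encoded `MNEMO_ENCRYPTION_KEY` environment variable.
//! Encrypted output is laid out as `nonce(12) || ciphertext || tag(16)`.
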
use std::fmt;

use aes_gcm::aead::rand_core::RngCore;
use aes_gcm::{
    Aes256Gcm, Key, Nonce,
    aead::{Aead, KeyInit, OsRng},
};

use crate::error::{Error, Result};

/// AES-256-GCM encryption provider for at-rest memory content.
///
/// Holds the raw 256-bit key. The `Debug` output is redacted so the key cannot
/// leak through logging or panic messages. The key bytes are overwritten with
/// zeros when the value is dropped. This is best effort only: copies made by
/// moves, or by callers holding the original array, are not wiped.
pub struct ContentEncryption {
    key: [u8; 32],
}

impl fmt::Debug for ContentEncryption {
    /// Prints `ContentEncryption { .. }` without ever exposing key material.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ContentEncryption").finish_non_exhaustive()
    }
}

impl Drop for ContentEncryption {
    fn drop(&mut self) {
        self.key.fill(0);
        // Route the zeroed buffer through `black_box` so the optimizer is
        // discouraged from eliding the stores as dead writes. `black_box` is
        // best-effort, not a guarantee; use the `zeroize` crate if a hard
        // guarantee is ever required.
        std::hint::black_box(&mut self.key);
    }
}

19impl ContentEncryption {
20    /// Create from a 32-byte key.
21    pub fn new(key: [u8; 32]) -> Self {
22        Self { key }
23    }
24
25    /// Create from a hex-encoded key string (64 hex chars = 32 bytes).
26    pub fn from_hex(hex_key: &str) -> Result<Self> {
27        let bytes =
28            hex::decode(hex_key).map_err(|e| Error::Validation(format!("invalid hex key: {e}")))?;
29        if bytes.len() != 32 {
30            return Err(Error::Validation(format!(
31                "key must be 32 bytes, got {}",
32                bytes.len()
33            )));
34        }
35        let mut key = [0u8; 32];
36        key.copy_from_slice(&bytes);
37        Ok(Self { key })
38    }
39
40    /// Create from the `MNEMO_ENCRYPTION_KEY` environment variable.
41    pub fn from_env() -> Result<Self> {
42        let hex_key = std::env::var("MNEMO_ENCRYPTION_KEY")
43            .map_err(|_| Error::Validation("MNEMO_ENCRYPTION_KEY not set".to_string()))?;
44        Self::from_hex(&hex_key)
45    }
46
47    /// Encrypt plaintext content. Returns `nonce(12) || ciphertext+tag` as bytes.
48    ///
49    /// Uses AES-256-GCM with a random 12-byte nonce.
50    pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>> {
51        let key = Key::<Aes256Gcm>::from_slice(&self.key);
52        let cipher = Aes256Gcm::new(key);
53
54        let mut nonce_bytes = [0u8; 12];
55        OsRng.fill_bytes(&mut nonce_bytes);
56        let nonce = Nonce::from_slice(&nonce_bytes);
57
58        let ciphertext = cipher
59            .encrypt(nonce, plaintext)
60            .map_err(|e| Error::Internal(format!("encryption failed: {e}")))?;
61
62        let mut output = Vec::with_capacity(12 + ciphertext.len());
63        output.extend_from_slice(&nonce_bytes);
64        output.extend_from_slice(&ciphertext);
65        Ok(output)
66    }
67
68    /// Decrypt content encrypted by [`encrypt`].
69    pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
70        if data.len() < 28 {
71            // 12 nonce + 16 tag minimum
72            return Err(Error::Validation("encrypted data too short".to_string()));
73        }
74
75        let key = Key::<Aes256Gcm>::from_slice(&self.key);
76        let cipher = Aes256Gcm::new(key);
77
78        let nonce = Nonce::from_slice(&data[..12]);
79        let ciphertext = &data[12..];
80
81        cipher
82            .decrypt(nonce, ciphertext)
83            .map_err(|_| Error::Validation("decryption tag mismatch".to_string()))
84    }
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encryption_round_trip() {
        let key = [0x42u8; 32];
        let enc = ContentEncryption::new(key);

        let plaintext = b"Hello, encrypted world!";
        let encrypted = enc.encrypt(plaintext).unwrap();

        // The ciphertext body must not equal the plaintext.
        assert_ne!(&encrypted[12..encrypted.len() - 16], plaintext);

        let decrypted = enc.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encryption_from_hex() {
        let hex_key = "42".repeat(32);
        let enc = ContentEncryption::from_hex(&hex_key).unwrap();

        let plaintext = b"test";
        let encrypted = enc.encrypt(plaintext).unwrap();
        let decrypted = enc.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_invalid_hex_key_length() {
        let result = ContentEncryption::from_hex("abcd");
        assert!(result.is_err());
    }

    #[test]
    fn test_oversized_hex_key_rejected() {
        assert!(ContentEncryption::from_hex(&"42".repeat(33)).is_err());
    }

    #[test]
    fn test_invalid_hex_characters_rejected() {
        assert!(ContentEncryption::from_hex(&"zz".repeat(32)).is_err());
        // An odd number of hex digits cannot decode to whole bytes.
        assert!(ContentEncryption::from_hex("abc").is_err());
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        let key = [0x42u8; 32];
        let enc = ContentEncryption::new(key);

        let encrypted = enc.encrypt(b"secret data").unwrap();
        let mut tampered = encrypted.clone();
        tampered[15] ^= 0xff; // flip a byte in the ciphertext

        let result = enc.decrypt(&tampered);
        assert!(result.is_err());
    }

    #[test]
    fn test_wrong_key_fails() {
        let encrypted = ContentEncryption::new([0x01u8; 32])
            .encrypt(b"payload")
            .unwrap();
        let other = ContentEncryption::new([0x02u8; 32]);
        assert!(other.decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_short_input_rejected() {
        let enc = ContentEncryption::new([0x42u8; 32]);
        assert!(enc.decrypt(&[]).is_err());
        // One byte short of nonce(12) + tag(16).
        assert!(enc.decrypt(&[0u8; 27]).is_err());
    }

    #[test]
    fn test_ciphertext_layout_length() {
        let enc = ContentEncryption::new([0x11u8; 32]);
        let plaintext = b"layout check";
        let encrypted = enc.encrypt(plaintext).unwrap();
        // nonce(12) || ciphertext(len) || tag(16)
        assert_eq!(encrypted.len(), 12 + plaintext.len() + 16);
    }

    #[test]
    fn test_nonce_differs_between_encryptions() {
        let enc = ContentEncryption::new([0x22u8; 32]);
        let first = enc.encrypt(b"same input").unwrap();
        let second = enc.encrypt(b"same input").unwrap();
        // A repeated nonce under one key would break GCM confidentiality.
        assert_ne!(&first[..12], &second[..12]);
        assert_ne!(first, second);
        assert_eq!(enc.decrypt(&first).unwrap(), enc.decrypt(&second).unwrap());
    }

    #[test]
    fn test_aes_gcm_round_trip() {
        let key = [0xABu8; 32];
        let enc = ContentEncryption::new(key);

        // Test various sizes
        for size in [0, 1, 16, 100, 1024, 65536] {
            let plaintext: Vec<u8> = (0..size).map(|i| (i % 256) as u8).collect();
            let encrypted = enc.encrypt(&plaintext).unwrap();
            let decrypted = enc.decrypt(&encrypted).unwrap();
            assert_eq!(decrypted, plaintext, "round-trip failed for size {size}");
        }
    }

    #[test]
    fn test_aes_gcm_tamper_detection() {
        let key = [0xCDu8; 32];
        let enc = ContentEncryption::new(key);
        let encrypted = enc.encrypt(b"sensitive data").unwrap();

        // Tamper with nonce
        let mut tampered = encrypted.clone();
        tampered[0] ^= 0x01;
        assert!(enc.decrypt(&tampered).is_err());

        // Tamper with ciphertext body
        let mut tampered = encrypted.clone();
        tampered[14] ^= 0x01;
        assert!(enc.decrypt(&tampered).is_err());

        // Tamper with last byte (tag)
        let mut tampered = encrypted.clone();
        let last = tampered.len() - 1;
        tampered[last] ^= 0x01;
        assert!(enc.decrypt(&tampered).is_err());
    }
}
167        let last = tampered.len() - 1;
168        tampered[last] ^= 0x01;
169        assert!(enc.decrypt(&tampered).is_err());
170    }
171}