brainos_storage/
encryption.rs1use aes_gcm::{
8 aead::{Aead, KeyInit, OsRng},
9 Aes256Gcm, Key, Nonce,
10};
11use argon2::Argon2;
12use thiserror::Error;
13
/// Length in bytes of the random AES-GCM nonce stored in front of every ciphertext (96 bits).
const NONCE_SIZE: usize = 96 / 8;

/// Length in bytes of the random salts returned by `Encryptor::generate_salt` (128 bits).
const SALT_SIZE: usize = 128 / 8;

20#[derive(Debug, Error)]
22pub enum EncryptionError {
23 #[error("Encryption failed: {0}")]
24 EncryptFailed(String),
25
26 #[error("Decryption failed: {0}")]
27 DecryptFailed(String),
28
29 #[error("Key derivation failed: {0}")]
30 KeyDerivation(String),
31
32 #[error("Invalid data format")]
33 InvalidFormat,
34}
35
36#[derive(Clone)]
41pub struct Encryptor {
42 key: Key<Aes256Gcm>,
43}
44
45impl Encryptor {
46 pub fn from_key(key: [u8; 32]) -> Self {
48 Self {
49 key: Key::<Aes256Gcm>::from(key),
50 }
51 }
52
53 pub fn from_passphrase(passphrase: &str, salt: &[u8]) -> Result<Self, EncryptionError> {
57 let mut key = [0u8; 32];
58
59 Argon2::default()
60 .hash_password_into(passphrase.as_bytes(), salt, &mut key)
61 .map_err(|e| EncryptionError::KeyDerivation(e.to_string()))?;
62
63 Ok(Self::from_key(key))
64 }
65
66 pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>, EncryptionError> {
70 use aes_gcm::aead::rand_core::RngCore;
71
72 let cipher = Aes256Gcm::new(&self.key);
73
74 let mut nonce_bytes = [0u8; NONCE_SIZE];
76 OsRng.fill_bytes(&mut nonce_bytes);
77 let nonce = Nonce::from_slice(&nonce_bytes);
78
79 let ciphertext = cipher
80 .encrypt(nonce, plaintext)
81 .map_err(|e| EncryptionError::EncryptFailed(e.to_string()))?;
82
83 let mut result = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
85 result.extend_from_slice(&nonce_bytes);
86 result.extend_from_slice(&ciphertext);
87
88 Ok(result)
89 }
90
91 pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>, EncryptionError> {
93 if data.len() < NONCE_SIZE {
94 return Err(EncryptionError::InvalidFormat);
95 }
96
97 let (nonce_bytes, ciphertext) = data.split_at(NONCE_SIZE);
98 let nonce = Nonce::from_slice(nonce_bytes);
99 let cipher = Aes256Gcm::new(&self.key);
100
101 cipher
102 .decrypt(nonce, ciphertext)
103 .map_err(|e| EncryptionError::DecryptFailed(e.to_string()))
104 }
105
106 pub fn encrypt_string(&self, plaintext: &str) -> Result<String, EncryptionError> {
108 use base64::Engine;
109 let encrypted = self.encrypt(plaintext.as_bytes())?;
110 Ok(base64::engine::general_purpose::STANDARD.encode(encrypted))
111 }
112
113 pub fn decrypt_string(&self, encoded: &str) -> Result<String, EncryptionError> {
115 use base64::Engine;
116 let data = base64::engine::general_purpose::STANDARD
117 .decode(encoded)
118 .map_err(|_e| EncryptionError::InvalidFormat)?;
119 let decrypted = self.decrypt(&data)?;
120 String::from_utf8(decrypted)
121 .map_err(|e| EncryptionError::DecryptFailed(format!("Invalid UTF-8: {e}")))
122 }
123
124 pub fn generate_salt() -> [u8; SALT_SIZE] {
126 use aes_gcm::aead::rand_core::RngCore;
127 let mut salt = [0u8; SALT_SIZE];
128 OsRng.fill_bytes(&mut salt);
129 salt
130 }
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic encryptor with a fixed key for reproducible tests.
    fn test_encryptor() -> Encryptor {
        Encryptor::from_key([42u8; 32])
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let enc = test_encryptor();
        let plaintext = b"Hello, Brain!";
        let ciphertext = enc.encrypt(plaintext).unwrap();
        let decrypted = enc.decrypt(&ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_empty_plaintext_roundtrip() {
        // Smallest valid message: nonce + tag with no ciphertext body.
        let enc = test_encryptor();
        let ciphertext = enc.encrypt(b"").unwrap();
        assert!(enc.decrypt(&ciphertext).unwrap().is_empty());
    }

    #[test]
    fn test_encrypt_produces_different_nonces() {
        let enc = test_encryptor();
        let a = enc.encrypt(b"same data").unwrap();
        let b = enc.encrypt(b"same data").unwrap();
        assert_ne!(a, b);
        assert_eq!(enc.decrypt(&a).unwrap(), enc.decrypt(&b).unwrap());
    }

    #[test]
    fn test_decrypt_wrong_key_fails() {
        let enc1 = Encryptor::from_key([1u8; 32]);
        let enc2 = Encryptor::from_key([2u8; 32]);
        let ciphertext = enc1.encrypt(b"secret").unwrap();
        assert!(enc2.decrypt(&ciphertext).is_err());
    }

    #[test]
    fn test_decrypt_truncated_fails() {
        let enc = test_encryptor();
        assert!(matches!(
            enc.decrypt(&[0u8; 5]),
            Err(EncryptionError::InvalidFormat)
        ));
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        // Flipping a single bit must be caught by GCM authentication.
        let enc = test_encryptor();
        let mut ciphertext = enc.encrypt(b"integrity matters").unwrap();
        *ciphertext.last_mut().unwrap() ^= 0x01;
        assert!(matches!(
            enc.decrypt(&ciphertext),
            Err(EncryptionError::DecryptFailed(_))
        ));
    }

    #[test]
    fn test_string_roundtrip() {
        let enc = test_encryptor();
        let original = "Keshav likes Rust";
        let encrypted = enc.encrypt_string(original).unwrap();
        let decrypted = enc.decrypt_string(&encrypted).unwrap();
        assert_eq!(decrypted, original);
    }

    #[test]
    fn test_decrypt_string_rejects_invalid_base64() {
        let enc = test_encryptor();
        assert!(matches!(
            enc.decrypt_string("not base64!"),
            Err(EncryptionError::InvalidFormat)
        ));
    }

    #[test]
    fn test_passphrase_derivation() {
        let salt = Encryptor::generate_salt();
        let enc = Encryptor::from_passphrase("my-strong-passphrase", &salt).unwrap();
        let ciphertext = enc.encrypt(b"test data").unwrap();

        // Same passphrase + salt must reproduce the same key.
        let enc2 = Encryptor::from_passphrase("my-strong-passphrase", &salt).unwrap();
        assert_eq!(enc2.decrypt(&ciphertext).unwrap(), b"test data");

        // A different passphrase must yield a key that fails authentication.
        let enc3 = Encryptor::from_passphrase("wrongpassphrase", &salt).unwrap();
        assert!(enc3.decrypt(&ciphertext).is_err());
    }
}