1use aes_gcm::{
9 aead::{Aead, KeyInit, OsRng},
10 Aes256Gcm, Key, Nonce,
11};
12use anyhow::{anyhow, Result};
13use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
14use serde::{Deserialize, Serialize};
15use sha2::{Digest, Sha256};
16use std::path::Path;
17
/// Length in bytes of an AES-GCM nonce (96 bits).
const NONCE_SIZE: usize = 96 / 8;
/// Length in bytes of an AES-256 key (256 bits).
const KEY_SIZE: usize = 256 / 8;
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct EncryptedData {
28 pub ciphertext: String,
30 pub nonce: String,
32 pub algorithm: String,
34 pub kdf: KeyDerivation,
36 pub version: u8,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct KeyDerivation {
43 pub algorithm: String,
45 pub salt: String,
47 pub iterations: Option<u32>,
49 pub memory: Option<u32>,
51 pub time: Option<u32>,
53}
54
55pub struct EncryptionManager {
61 key: Key<Aes256Gcm>,
63 enabled: bool,
65}
66
67impl EncryptionManager {
68 pub fn new(password: &str, salt: &[u8]) -> Result<Self> {
70 let key = Self::derive_key(password, salt)?;
71 Ok(Self { key, enabled: true })
72 }
73
74 pub fn disabled() -> Self {
76 Self {
77 key: Key::<Aes256Gcm>::default(),
78 enabled: false,
79 }
80 }
81
82 pub fn is_enabled(&self) -> bool {
84 self.enabled
85 }
86
87 fn derive_key(password: &str, salt: &[u8]) -> Result<Key<Aes256Gcm>> {
89 let mut key = [0u8; KEY_SIZE];
91
92 pbkdf2::pbkdf2_hmac::<sha2::Sha256>(password.as_bytes(), salt, 100_000, &mut key);
93
94 Ok(*Key::<Aes256Gcm>::from_slice(&key))
95 }
96
97 pub fn encrypt(&self, plaintext: &[u8]) -> Result<EncryptedData> {
99 if !self.enabled {
100 return Err(anyhow!("Encryption is not enabled"));
101 }
102
103 let cipher = Aes256Gcm::new(&self.key);
104
105 let nonce_bytes: [u8; NONCE_SIZE] = rand::random();
107 let nonce = Nonce::from_slice(&nonce_bytes);
108
109 let ciphertext = cipher
111 .encrypt(nonce, plaintext)
112 .map_err(|e| anyhow!("Encryption failed: {}", e))?;
113
114 Ok(EncryptedData {
115 ciphertext: BASE64.encode(&ciphertext),
116 nonce: BASE64.encode(nonce_bytes),
117 algorithm: "AES-256-GCM".to_string(),
118 kdf: KeyDerivation {
119 algorithm: "PBKDF2-HMAC-SHA256".to_string(),
120 salt: String::new(), iterations: Some(100_000),
122 memory: None,
123 time: None,
124 },
125 version: 1,
126 })
127 }
128
129 pub fn decrypt(&self, encrypted: &EncryptedData) -> Result<Vec<u8>> {
131 if !self.enabled {
132 return Err(anyhow!("Encryption is not enabled"));
133 }
134
135 if encrypted.version != 1 {
136 return Err(anyhow!(
137 "Unsupported encryption version: {}",
138 encrypted.version
139 ));
140 }
141
142 let cipher = Aes256Gcm::new(&self.key);
143
144 let ciphertext = BASE64
146 .decode(&encrypted.ciphertext)
147 .map_err(|e| anyhow!("Invalid ciphertext encoding: {}", e))?;
148 let nonce_bytes = BASE64
149 .decode(&encrypted.nonce)
150 .map_err(|e| anyhow!("Invalid nonce encoding: {}", e))?;
151
152 if nonce_bytes.len() != NONCE_SIZE {
153 return Err(anyhow!("Invalid nonce size"));
154 }
155
156 let nonce = Nonce::from_slice(&nonce_bytes);
157
158 let plaintext = cipher
160 .decrypt(nonce, ciphertext.as_ref())
161 .map_err(|e| anyhow!("Decryption failed: {}", e))?;
162
163 Ok(plaintext)
164 }
165
166 pub fn encrypt_string(&self, plaintext: &str) -> Result<String> {
168 let encrypted = self.encrypt(plaintext.as_bytes())?;
169 Ok(serde_json::to_string(&encrypted)?)
170 }
171
172 pub fn decrypt_string(&self, encrypted_json: &str) -> Result<String> {
174 let encrypted: EncryptedData = serde_json::from_str(encrypted_json)?;
175 let plaintext = self.decrypt(&encrypted)?;
176 String::from_utf8(plaintext).map_err(|e| anyhow!("Invalid UTF-8: {}", e))
177 }
178}
179
180pub fn encrypt_messages(manager: &EncryptionManager, messages_json: &str) -> Result<String> {
186 if !manager.is_enabled() {
187 return Ok(messages_json.to_string());
188 }
189 manager.encrypt_string(messages_json)
190}
191
192pub fn decrypt_messages(manager: &EncryptionManager, encrypted_messages: &str) -> Result<String> {
194 if !manager.is_enabled() {
195 return Ok(encrypted_messages.to_string());
196 }
197
198 if encrypted_messages.starts_with('{') && encrypted_messages.contains("\"ciphertext\"") {
200 manager.decrypt_string(encrypted_messages)
201 } else {
202 Ok(encrypted_messages.to_string())
203 }
204}
205
206#[derive(Debug, Clone, Serialize, Deserialize)]
212pub struct EncryptionConfig {
213 pub enabled: bool,
215 pub salt: String,
217 pub password_hash: String,
219}
220
221impl EncryptionConfig {
222 pub fn new(password: &str) -> Self {
224 let salt: [u8; 32] = rand::random();
225 let password_hash = Self::hash_password(password, &salt);
226
227 Self {
228 enabled: true,
229 salt: BASE64.encode(salt),
230 password_hash,
231 }
232 }
233
234 pub fn disabled() -> Self {
236 Self {
237 enabled: false,
238 salt: String::new(),
239 password_hash: String::new(),
240 }
241 }
242
243 fn hash_password(password: &str, salt: &[u8]) -> String {
245 let mut hasher = Sha256::new();
246 hasher.update(password.as_bytes());
247 hasher.update(salt);
248 hasher.update(b"verification");
249 BASE64.encode(hasher.finalize())
250 }
251
252 pub fn verify_password(&self, password: &str) -> bool {
254 if !self.enabled {
255 return true;
256 }
257
258 if let Ok(salt) = BASE64.decode(&self.salt) {
259 let hash = Self::hash_password(password, &salt);
260 hash == self.password_hash
261 } else {
262 false
263 }
264 }
265
266 pub fn get_salt(&self) -> Result<Vec<u8>> {
268 BASE64
269 .decode(&self.salt)
270 .map_err(|e| anyhow!("Invalid salt: {}", e))
271 }
272
273 pub fn load(path: &Path) -> Result<Self> {
275 let content = std::fs::read_to_string(path)?;
276 Ok(serde_json::from_str(&content)?)
277 }
278
279 pub fn save(&self, path: &Path) -> Result<()> {
281 let content = serde_json::to_string_pretty(self)?;
282 if let Some(parent) = path.parent() {
283 std::fs::create_dir_all(parent)?;
284 }
285 std::fs::write(path, content)?;
286 Ok(())
287 }
288}
289
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encrypt_decrypt() {
        let password = "test_password_123";
        let salt = b"test_salt_12345678901234";

        let manager = EncryptionManager::new(password, salt).unwrap();

        let plaintext = "Hello, encrypted world!";
        let encrypted = manager.encrypt(plaintext.as_bytes()).unwrap();

        assert!(!encrypted.ciphertext.is_empty());
        assert!(!encrypted.nonce.is_empty());
        assert_eq!(encrypted.algorithm, "AES-256-GCM");

        let decrypted = manager.decrypt(&encrypted).unwrap();
        assert_eq!(String::from_utf8(decrypted).unwrap(), plaintext);
    }

    #[test]
    fn test_encrypt_decrypt_string() {
        let password = "secure_password";
        let salt = b"random_salt_value_here";

        let manager = EncryptionManager::new(password, salt).unwrap();

        let original = r#"{"role": "user", "content": "Secret message"}"#;
        let encrypted = manager.encrypt_string(original).unwrap();
        let decrypted = manager.decrypt_string(&encrypted).unwrap();

        assert_eq!(decrypted, original);
    }

    /// AES-GCM is authenticated, so tampered or malformed envelopes must fail.
    #[test]
    fn test_decrypt_rejects_malformed_envelopes() {
        let manager =
            EncryptionManager::new("test_password_123", b"test_salt_12345678901234").unwrap();
        let good = manager.encrypt(b"integrity matters").unwrap();

        let mut tampered = good.clone();
        let mut raw = BASE64.decode(&tampered.ciphertext).unwrap();
        raw[0] ^= 0x01;
        tampered.ciphertext = BASE64.encode(&raw);
        assert!(manager.decrypt(&tampered).is_err());

        let mut future_version = good.clone();
        future_version.version = 2;
        assert!(manager.decrypt(&future_version).is_err());

        let mut short_nonce = good.clone();
        short_nonce.nonce = BASE64.encode([0u8; 8]);
        assert!(manager.decrypt(&short_nonce).is_err());

        let mut bad_base64 = good.clone();
        bad_base64.ciphertext = "not base64!".to_string();
        assert!(manager.decrypt(&bad_base64).is_err());

        // The untouched envelope still decrypts.
        assert_eq!(manager.decrypt(&good).unwrap(), b"integrity matters");
    }

    #[test]
    fn test_wrong_password_is_rejected() {
        let salt = b"shared_salt_value";
        let writer = EncryptionManager::new("right_password", salt).unwrap();
        let reader = EncryptionManager::new("wrong_password", salt).unwrap();

        let encrypted = writer.encrypt(b"secret").unwrap();
        assert!(reader.decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_messages_round_trip_and_plaintext_passthrough() {
        let manager = EncryptionManager::new("secure_password", b"random_salt_value_here").unwrap();
        let messages = r#"[{"role": "user", "content": "Secret message"}]"#;

        let stored = encrypt_messages(&manager, messages).unwrap();
        assert_ne!(stored, messages);
        assert_eq!(decrypt_messages(&manager, &stored).unwrap(), messages);

        // Input that is not an envelope passes through unchanged.
        assert_eq!(decrypt_messages(&manager, messages).unwrap(), messages);

        let disabled = EncryptionManager::disabled();
        assert_eq!(encrypt_messages(&disabled, messages).unwrap(), messages);
        assert_eq!(decrypt_messages(&disabled, &stored).unwrap(), stored);
    }

    #[test]
    fn test_password_verification() {
        let config = EncryptionConfig::new("my_password");

        assert!(config.verify_password("my_password"));
        assert!(!config.verify_password("wrong_password"));
    }

    /// Configs written with the original SHA-256 verifier must keep working.
    #[test]
    fn test_legacy_sha256_verifier_still_accepted() {
        let salt = [7u8; 32];
        let mut hasher = Sha256::new();
        hasher.update(b"legacy_password");
        hasher.update(salt);
        hasher.update(b"verification");

        let config = EncryptionConfig {
            enabled: true,
            salt: BASE64.encode(salt),
            password_hash: BASE64.encode(hasher.finalize()),
        };

        assert!(config.verify_password("legacy_password"));
        assert!(!config.verify_password("other_password"));
    }

    #[test]
    fn test_disabled_encryption() {
        let manager = EncryptionManager::disabled();
        assert!(!manager.is_enabled());
        assert!(manager.encrypt(b"anything").is_err());

        let config = EncryptionConfig::disabled();
        assert!(!config.enabled);
        assert!(config.verify_password("any_password"));
    }
}
347}