lore_cli/cloud/
encryption.rs1use aes_gcm::{
9 aead::{Aead, KeyInit},
10 Aes256Gcm, Nonce,
11};
12use argon2::{password_hash::SaltString, Argon2, PasswordHasher};
13use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
14use rand::RngCore;
15
16use super::CloudError;
17
/// Length in bytes of the symmetric key: AES-256 (`Aes256Gcm`) takes a 256-bit key.
pub const KEY_SIZE: usize = 256 / 8;

/// Length in bytes of the per-message AES-GCM nonce (96 bits), which `encrypt_data`
/// stores in front of every ciphertext it produces.
pub const NONCE_SIZE: usize = 96 / 8;

/// Length in bytes (128 bits) of the salts produced by `generate_salt`.
// Its only visible user, `generate_salt`, is itself `#[allow(dead_code)]`, so this
// constant needs the same allowance.
#[allow(dead_code)]
pub const SALT_SIZE: usize = 128 / 8;
27
28pub fn derive_key(passphrase: &str, salt: &[u8]) -> Result<Vec<u8>, CloudError> {
43 let salt_string = SaltString::encode_b64(salt)
46 .map_err(|e| CloudError::EncryptionError(format!("Invalid salt: {e}")))?;
47
48 let argon2 = Argon2::default();
50
51 let hash = argon2
53 .hash_password(passphrase.as_bytes(), &salt_string)
54 .map_err(|e| CloudError::EncryptionError(format!("Key derivation failed: {e}")))?;
55
56 let hash_output = hash
58 .hash
59 .ok_or_else(|| CloudError::EncryptionError("No hash output".to_string()))?;
60
61 let key_bytes = hash_output.as_bytes();
63 if key_bytes.len() < KEY_SIZE {
64 return Err(CloudError::EncryptionError(
65 "Derived key too short".to_string(),
66 ));
67 }
68
69 Ok(key_bytes[..KEY_SIZE].to_vec())
70}
71
72#[allow(dead_code)]
77pub fn generate_salt() -> Vec<u8> {
78 let mut salt = vec![0u8; SALT_SIZE];
79 rand::thread_rng().fill_bytes(&mut salt);
80 salt
81}
82
83pub fn encrypt_data(data: &[u8], key: &[u8]) -> Result<Vec<u8>, CloudError> {
97 if key.len() != KEY_SIZE {
98 return Err(CloudError::EncryptionError(format!(
99 "Invalid key size: expected {KEY_SIZE}, got {}",
100 key.len()
101 )));
102 }
103
104 let mut nonce_bytes = [0u8; NONCE_SIZE];
106 rand::thread_rng().fill_bytes(&mut nonce_bytes);
107 let nonce = Nonce::from_slice(&nonce_bytes);
108
109 let cipher = Aes256Gcm::new_from_slice(key)
111 .map_err(|e| CloudError::EncryptionError(format!("Cipher creation failed: {e}")))?;
112
113 let ciphertext = cipher
115 .encrypt(nonce, data)
116 .map_err(|e| CloudError::EncryptionError(format!("Encryption failed: {e}")))?;
117
118 let mut result = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
120 result.extend_from_slice(&nonce_bytes);
121 result.extend_from_slice(&ciphertext);
122
123 Ok(result)
124}
125
126pub fn decrypt_data(data: &[u8], key: &[u8]) -> Result<Vec<u8>, CloudError> {
139 if key.len() != KEY_SIZE {
140 return Err(CloudError::EncryptionError(format!(
141 "Invalid key size: expected {KEY_SIZE}, got {}",
142 key.len()
143 )));
144 }
145
146 if data.len() < NONCE_SIZE {
147 return Err(CloudError::EncryptionError(
148 "Encrypted data too short".to_string(),
149 ));
150 }
151
152 let (nonce_bytes, ciphertext) = data.split_at(NONCE_SIZE);
154 let nonce = Nonce::from_slice(nonce_bytes);
155
156 let cipher = Aes256Gcm::new_from_slice(key)
158 .map_err(|e| CloudError::EncryptionError(format!("Cipher creation failed: {e}")))?;
159
160 let plaintext = cipher
162 .decrypt(nonce, ciphertext)
163 .map_err(|e| CloudError::EncryptionError(format!("Decryption failed: {e}")))?;
164
165 Ok(plaintext)
166}
167
168pub fn encode_base64(data: &[u8]) -> String {
170 BASE64.encode(data)
171}
172
173pub fn decode_base64(data: &str) -> Result<Vec<u8>, CloudError> {
175 BASE64
176 .decode(data)
177 .map_err(|e| CloudError::EncryptionError(format!("Base64 decode failed: {e}")))
178}
179
180pub fn encode_key_hex(key: &[u8]) -> String {
182 hex::encode(key)
183}
184
185pub fn decode_key_hex(hex_str: &str) -> Result<Vec<u8>, CloudError> {
187 hex::decode(hex_str).map_err(|e| CloudError::EncryptionError(format!("Hex decode failed: {e}")))
188}
189
/// Minimal hexadecimal codec used to print and parse keys.
mod hex {
    /// Lowercase hex digits, indexed by nibble value.
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    /// Encodes `data` as a lowercase hex string, two characters per byte.
    pub fn encode(data: &[u8]) -> String {
        let mut out = String::with_capacity(data.len() * 2);
        for &byte in data {
            out.push(char::from(DIGITS[usize::from(byte >> 4)]));
            out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
        }
        out
    }

    /// Decodes a hex string (digits `0-9`, `a-f`, `A-F`) into bytes.
    ///
    /// The input is processed as raw bytes in pairs. Any non-hex byte, including one
    /// that belongs to a multi-byte UTF-8 character, is reported as an error rather
    /// than causing a panic from slicing through a char boundary. A `+` sign is also
    /// rejected, which `u8::from_str_radix` would otherwise accept.
    ///
    /// # Errors
    ///
    /// Returns an error message if the input has an odd byte length or contains a
    /// byte that is not a hex digit.
    pub fn decode(s: &str) -> Result<Vec<u8>, String> {
        let pairs = s.as_bytes().chunks_exact(2);
        if !pairs.remainder().is_empty() {
            return Err("Hex string has odd length".to_string());
        }

        pairs
            .map(|pair| -> Result<u8, String> {
                // `chunks_exact(2)` guarantees every `pair` has exactly two bytes.
                Ok((nibble(pair[0])? << 4) | nibble(pair[1])?)
            })
            .collect()
    }

    /// Converts one ASCII hex digit to its value in `0..16`.
    fn nibble(digit: u8) -> Result<u8, String> {
        match digit {
            b'0'..=b'9' => Ok(digit - b'0'),
            b'a'..=b'f' => Ok(digit - b'a' + 10),
            b'A'..=b'F' => Ok(digit - b'A' + 10),
            _ => Err("Invalid hex: invalid digit found in string".to_string()),
        }
    }
}
207
#[cfg(test)]
mod tests {
    //! Unit tests for key derivation, AES-GCM round-trips, and the base64/hex helpers.

    use super::*;

    #[test]
    fn test_generate_salt_length() {
        let salt = generate_salt();
        assert_eq!(salt.len(), SALT_SIZE);
    }

    #[test]
    fn test_generate_salt_randomness() {
        let salt1 = generate_salt();
        let salt2 = generate_salt();
        // Two independent random draws colliding is negligibly unlikely.
        assert_ne!(salt1, salt2);
    }

    #[test]
    fn test_derive_key_consistency() {
        let passphrase = "test passphrase";
        let salt = generate_salt();

        let key1 = derive_key(passphrase, &salt).unwrap();
        let key2 = derive_key(passphrase, &salt).unwrap();

        assert_eq!(key1, key2);
        assert_eq!(key1.len(), KEY_SIZE);
    }

    #[test]
    fn test_derive_key_different_passphrases() {
        let salt = generate_salt();

        let key1 = derive_key("passphrase1", &salt).unwrap();
        let key2 = derive_key("passphrase2", &salt).unwrap();

        assert_ne!(key1, key2);
    }

    #[test]
    fn test_derive_key_different_salts() {
        let passphrase = "test passphrase";
        let salt1 = generate_salt();
        let salt2 = generate_salt();

        let key1 = derive_key(passphrase, &salt1).unwrap();
        let key2 = derive_key(passphrase, &salt2).unwrap();

        assert_ne!(key1, key2);
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let passphrase = "test passphrase";
        let salt = generate_salt();
        let key = derive_key(passphrase, &salt).unwrap();

        let plaintext = b"Hello, World! This is a test message.";
        let encrypted = encrypt_data(plaintext, &key).unwrap();
        let decrypted = decrypt_data(&encrypted, &key).unwrap();

        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_produces_different_ciphertext() {
        let salt = generate_salt();
        let key = derive_key("passphrase", &salt).unwrap();

        let plaintext = b"test data";
        let encrypted1 = encrypt_data(plaintext, &key).unwrap();
        let encrypted2 = encrypt_data(plaintext, &key).unwrap();

        // A fresh random nonce per call must change the output.
        assert_ne!(encrypted1, encrypted2);
    }

    #[test]
    fn test_decrypt_with_wrong_key_fails() {
        let salt = generate_salt();
        let key1 = derive_key("passphrase1", &salt).unwrap();
        let key2 = derive_key("passphrase2", &salt).unwrap();

        let plaintext = b"secret data";
        let encrypted = encrypt_data(plaintext, &key1).unwrap();

        let result = decrypt_data(&encrypted, &key2);
        assert!(result.is_err());
    }

    #[test]
    fn test_decrypt_with_corrupted_data_fails() {
        let salt = generate_salt();
        let key = derive_key("passphrase", &salt).unwrap();

        let plaintext = b"secret data";
        let mut encrypted = encrypt_data(plaintext, &key).unwrap();

        // Flip a byte inside the ciphertext (just past the nonce prefix).
        if let Some(byte) = encrypted.get_mut(NONCE_SIZE + 5) {
            *byte ^= 0xFF;
        }

        let result = decrypt_data(&encrypted, &key);
        assert!(result.is_err());
    }

    #[test]
    fn test_decrypt_truncated_ciphertext_fails() {
        let salt = generate_salt();
        let key = derive_key("passphrase", &salt).unwrap();

        let mut encrypted = encrypt_data(b"secret data", &key).unwrap();
        // Keep the nonce but cut off most of the ciphertext and the authentication tag.
        encrypted.truncate(NONCE_SIZE + 4);

        let result = decrypt_data(&encrypted, &key);
        assert!(result.is_err());
    }

    #[test]
    fn test_encrypt_data_invalid_key_size() {
        let short_key = vec![0u8; 16];
        let result = encrypt_data(b"data", &short_key);
        assert!(result.is_err());
    }

    #[test]
    fn test_decrypt_data_invalid_key_size() {
        let short_key = vec![0u8; 16];
        let result = decrypt_data(&[0u8; 64], &short_key);
        assert!(result.is_err());
    }

    #[test]
    fn test_decrypt_data_too_short() {
        let salt = generate_salt();
        let key = derive_key("passphrase", &salt).unwrap();

        let short_data = vec![0u8; 5];
        let result = decrypt_data(&short_data, &key);
        assert!(result.is_err());
    }

    #[test]
    fn test_base64_roundtrip() {
        let data = b"test binary data \x00\x01\x02";
        let encoded = encode_base64(data);
        let decoded = decode_base64(&encoded).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_decode_base64_invalid() {
        assert!(decode_base64("not base64!").is_err());
    }

    #[test]
    fn test_hex_roundtrip() {
        let data = vec![0u8, 1, 2, 255, 128, 64];
        let encoded = encode_key_hex(&data);
        let decoded = decode_key_hex(&encoded).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_decode_key_hex_odd_length() {
        assert!(decode_key_hex("abc").is_err());
    }

    #[test]
    fn test_hex_encode() {
        assert_eq!(hex::encode(&[0, 255, 128]), "00ff80");
    }

    #[test]
    fn test_hex_decode_invalid() {
        assert!(hex::decode("xyz").is_err());
        assert!(hex::decode("abc").is_err());
    }

    #[test]
    fn test_encrypt_empty_data() {
        let salt = generate_salt();
        let key = derive_key("passphrase", &salt).unwrap();

        let plaintext = b"";
        let encrypted = encrypt_data(plaintext, &key).unwrap();
        let decrypted = decrypt_data(&encrypted, &key).unwrap();

        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_large_data() {
        let salt = generate_salt();
        let key = derive_key("passphrase", &salt).unwrap();

        // 1 MB of a repeating 0..=255 byte pattern.
        let plaintext: Vec<u8> = (0..1_000_000).map(|i| (i % 256) as u8).collect();
        let encrypted = encrypt_data(&plaintext, &key).unwrap();
        let decrypted = decrypt_data(&encrypted, &key).unwrap();

        assert_eq!(decrypted, plaintext);
    }
}