lore_cli/sync/
encryption.rs1use aes_gcm::{
12 aead::{Aead, KeyInit},
13 Aes256Gcm, Nonce,
14};
15use argon2::{password_hash::SaltString, Argon2, PasswordHasher};
16use rand::RngCore;
17
18use super::SyncError;
19
/// Length in bytes of the AES-256 key accepted by [`encrypt_data`] / [`decrypt_data`]
/// and returned by [`derive_key`].
pub const KEY_SIZE: usize = 256 / 8;

/// Length in bytes of the AES-GCM nonce that prefixes every encrypted payload.
pub const NONCE_SIZE: usize = 96 / 8;

/// Length in bytes of the salts returned by [`generate_salt`].
pub const SALT_SIZE: usize = 128 / 8;
28
29pub fn derive_key(passphrase: &str, salt: &[u8]) -> Result<Vec<u8>, SyncError> {
44 let salt_string = SaltString::encode_b64(salt)
47 .map_err(|e| SyncError::Encryption(format!("Invalid salt: {e}")))?;
48
49 let argon2 = Argon2::default();
51
52 let hash = argon2
54 .hash_password(passphrase.as_bytes(), &salt_string)
55 .map_err(|e| SyncError::Encryption(format!("Key derivation failed: {e}")))?;
56
57 let hash_output = hash
59 .hash
60 .ok_or_else(|| SyncError::Encryption("No hash output".to_string()))?;
61
62 let key_bytes = hash_output.as_bytes();
64 if key_bytes.len() < KEY_SIZE {
65 return Err(SyncError::Encryption("Derived key too short".to_string()));
66 }
67
68 Ok(key_bytes[..KEY_SIZE].to_vec())
69}
70
71pub fn generate_salt() -> Vec<u8> {
76 let mut salt = vec![0u8; SALT_SIZE];
77 rand::thread_rng().fill_bytes(&mut salt);
78 salt
79}
80
81pub fn encrypt_data(data: &[u8], key: &[u8]) -> Result<Vec<u8>, SyncError> {
95 if key.len() != KEY_SIZE {
96 return Err(SyncError::Encryption(format!(
97 "Invalid key size: expected {KEY_SIZE}, got {}",
98 key.len()
99 )));
100 }
101
102 let mut nonce_bytes = [0u8; NONCE_SIZE];
104 rand::thread_rng().fill_bytes(&mut nonce_bytes);
105 let nonce = Nonce::from_slice(&nonce_bytes);
106
107 let cipher = Aes256Gcm::new_from_slice(key)
109 .map_err(|e| SyncError::Encryption(format!("Cipher creation failed: {e}")))?;
110
111 let ciphertext = cipher
113 .encrypt(nonce, data)
114 .map_err(|e| SyncError::Encryption(format!("Encryption failed: {e}")))?;
115
116 let mut result = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
118 result.extend_from_slice(&nonce_bytes);
119 result.extend_from_slice(&ciphertext);
120
121 Ok(result)
122}
123
124pub fn decrypt_data(data: &[u8], key: &[u8]) -> Result<Vec<u8>, SyncError> {
137 if key.len() != KEY_SIZE {
138 return Err(SyncError::Encryption(format!(
139 "Invalid key size: expected {KEY_SIZE}, got {}",
140 key.len()
141 )));
142 }
143
144 if data.len() < NONCE_SIZE {
145 return Err(SyncError::Encryption(
146 "Encrypted data too short".to_string(),
147 ));
148 }
149
150 let (nonce_bytes, ciphertext) = data.split_at(NONCE_SIZE);
152 let nonce = Nonce::from_slice(nonce_bytes);
153
154 let cipher = Aes256Gcm::new_from_slice(key)
156 .map_err(|e| SyncError::Encryption(format!("Cipher creation failed: {e}")))?;
157
158 let plaintext = cipher
160 .decrypt(nonce, ciphertext)
161 .map_err(|e| SyncError::Encryption(format!("Decryption failed: {e}")))?;
162
163 Ok(plaintext)
164}
165
166pub fn encode_key_hex(key: &[u8]) -> String {
168 hex::encode(key)
169}
170
171pub fn decode_key_hex(hex_str: &str) -> Result<Vec<u8>, SyncError> {
173 hex::decode(hex_str).map_err(|e| SyncError::Encryption(format!("Hex decode failed: {e}")))
174}
175
/// Minimal hexadecimal encoding/decoding used for textual key storage.
mod hex {
    /// Lowercase hex digits, indexed by nibble value.
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    /// Encodes `data` as lowercase hex, two characters per byte.
    pub fn encode(data: &[u8]) -> String {
        // Pre-size the output instead of allocating a temporary String per byte.
        let mut out = String::with_capacity(data.len() * 2);
        for &byte in data {
            out.push(char::from(DIGITS[usize::from(byte >> 4)]));
            out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
        }
        out
    }

    /// Decodes a string of hex digits (either case) into bytes.
    ///
    /// Works on the raw bytes of `s` rather than slicing the `&str`. Byte-index
    /// slicing panics when a multi-byte UTF-8 character straddles a pair boundary,
    /// so here non-ASCII input becomes an `Err` instead. Only plain hex digits are
    /// accepted; `u8::from_str_radix` would also have accepted a leading `+`.
    ///
    /// # Errors
    ///
    /// Returns a message if `s` has an odd byte length or contains any byte that is
    /// not an ASCII hex digit.
    pub fn decode(s: &str) -> Result<Vec<u8>, String> {
        let pairs = s.as_bytes().chunks_exact(2);
        if !pairs.remainder().is_empty() {
            return Err("Hex string has odd length".to_string());
        }

        pairs
            .enumerate()
            .map(|(index, pair)| -> Result<u8, String> {
                let offset = index * 2;
                // `chunks_exact(2)` guarantees every pair has exactly two bytes.
                let high = nibble(pair[0]).ok_or_else(|| invalid_digit(offset))?;
                let low = nibble(pair[1]).ok_or_else(|| invalid_digit(offset + 1))?;
                Ok((high << 4) | low)
            })
            .collect()
    }

    /// Maps one ASCII hex digit to its value (0..=15), or `None` for any other byte.
    fn nibble(byte: u8) -> Option<u8> {
        match byte {
            b'0'..=b'9' => Some(byte - b'0'),
            b'a'..=b'f' => Some(byte - b'a' + 10),
            b'A'..=b'F' => Some(byte - b'A' + 10),
            _ => None,
        }
    }

    /// Builds the error message for a non-hex byte found at byte `offset`.
    fn invalid_digit(offset: usize) -> String {
        format!("Invalid hex: invalid digit at byte offset {offset}")
    }
}
193
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed keys for tests that only exercise the cipher or payload framing.
    /// Argon2 is a deliberately expensive KDF, particularly in unoptimized test
    /// builds, so it is reserved for the tests that actually cover `derive_key`.
    const TEST_KEY: [u8; KEY_SIZE] = [0x42; KEY_SIZE];
    const OTHER_KEY: [u8; KEY_SIZE] = [0x24; KEY_SIZE];

    #[test]
    fn test_generate_salt_length() {
        let salt = generate_salt();
        assert_eq!(salt.len(), SALT_SIZE);
    }

    #[test]
    fn test_generate_salt_randomness() {
        // Two independent random 16-byte salts colliding is negligibly unlikely.
        let salt1 = generate_salt();
        let salt2 = generate_salt();
        assert_ne!(salt1, salt2);
    }

    #[test]
    fn test_derive_key_consistency() {
        let passphrase = "test passphrase";
        let salt = generate_salt();

        let key1 = derive_key(passphrase, &salt).unwrap();
        let key2 = derive_key(passphrase, &salt).unwrap();

        assert_eq!(key1, key2);
        assert_eq!(key1.len(), KEY_SIZE);
    }

    #[test]
    fn test_derive_key_different_passphrases() {
        let salt = generate_salt();

        let key1 = derive_key("passphrase1", &salt).unwrap();
        let key2 = derive_key("passphrase2", &salt).unwrap();

        assert_ne!(key1, key2);
    }

    #[test]
    fn test_derive_key_different_salts() {
        let passphrase = "test passphrase";
        let salt1 = generate_salt();
        let salt2 = generate_salt();

        let key1 = derive_key(passphrase, &salt1).unwrap();
        let key2 = derive_key(passphrase, &salt2).unwrap();

        assert_ne!(key1, key2);
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        // End-to-end: passphrase -> derived key -> encrypt -> decrypt.
        let passphrase = "test passphrase";
        let salt = generate_salt();
        let key = derive_key(passphrase, &salt).unwrap();

        let plaintext = b"Hello, World! This is a test message.";
        let encrypted = encrypt_data(plaintext, &key).unwrap();
        let decrypted = decrypt_data(&encrypted, &key).unwrap();

        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_produces_different_ciphertext() {
        let plaintext = b"test data";
        let encrypted1 = encrypt_data(plaintext, &TEST_KEY).unwrap();
        let encrypted2 = encrypt_data(plaintext, &TEST_KEY).unwrap();

        // A fresh random nonce per call makes identical inputs encrypt differently...
        assert_ne!(encrypted1, encrypted2);
        // ...while both still decrypt to the original plaintext.
        assert_eq!(decrypt_data(&encrypted1, &TEST_KEY).unwrap(), plaintext);
        assert_eq!(decrypt_data(&encrypted2, &TEST_KEY).unwrap(), plaintext);
    }

    #[test]
    fn test_decrypt_with_wrong_key_fails() {
        let encrypted = encrypt_data(b"secret data", &TEST_KEY).unwrap();

        let result = decrypt_data(&encrypted, &OTHER_KEY);
        assert!(result.is_err());
    }

    #[test]
    fn test_decrypt_with_corrupted_data_fails() {
        let mut encrypted = encrypt_data(b"secret data", &TEST_KEY).unwrap();

        // Flip a byte past the nonce. Direct indexing panics rather than silently
        // skipping the corruption if the payload were ever shorter than expected.
        encrypted[NONCE_SIZE + 5] ^= 0xFF;

        let result = decrypt_data(&encrypted, &TEST_KEY);
        assert!(result.is_err());
    }

    #[test]
    fn test_decrypt_with_corrupted_nonce_fails() {
        let mut encrypted = encrypt_data(b"secret data", &TEST_KEY).unwrap();
        encrypted[0] ^= 0x01;

        assert!(decrypt_data(&encrypted, &TEST_KEY).is_err());
    }

    #[test]
    fn test_decrypt_truncated_ciphertext_fails() {
        let encrypted = encrypt_data(b"secret data", &TEST_KEY).unwrap();

        // Contains the nonce but is missing the last ciphertext byte.
        assert!(decrypt_data(&encrypted[..encrypted.len() - 1], &TEST_KEY).is_err());
        // Exactly a nonce, with no ciphertext at all.
        assert!(decrypt_data(&encrypted[..NONCE_SIZE], &TEST_KEY).is_err());
    }

    #[test]
    fn test_encrypt_data_invalid_key_size() {
        let short_key = vec![0u8; 16];
        let result = encrypt_data(b"data", &short_key);
        assert!(result.is_err());
    }

    #[test]
    fn test_decrypt_data_invalid_key_size() {
        let encrypted = encrypt_data(b"data", &TEST_KEY).unwrap();

        assert!(decrypt_data(&encrypted, &TEST_KEY[..16]).is_err());
        let long_key = [0u8; KEY_SIZE + 1];
        assert!(decrypt_data(&encrypted, &long_key).is_err());
    }

    #[test]
    fn test_decrypt_data_too_short() {
        let short_data = vec![0u8; 5];
        let result = decrypt_data(&short_data, &TEST_KEY);
        assert!(result.is_err());
    }

    #[test]
    fn test_hex_roundtrip() {
        let data = vec![0u8, 1, 2, 255, 128, 64];
        let encoded = encode_key_hex(&data);
        let decoded = decode_key_hex(&encoded).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_hex_encode() {
        assert_eq!(hex::encode(&[0, 255, 128]), "00ff80");
    }

    #[test]
    fn test_hex_decode_invalid() {
        assert!(hex::decode("xyz").is_err());
        assert!(hex::decode("abc").is_err());
    }

    #[test]
    fn test_hex_decode_uppercase() {
        assert_eq!(decode_key_hex("00FF80").unwrap(), vec![0u8, 255, 128]);
    }

    #[test]
    fn test_encrypt_empty_data() {
        let plaintext = b"";
        let encrypted = encrypt_data(plaintext, &TEST_KEY).unwrap();
        let decrypted = decrypt_data(&encrypted, &TEST_KEY).unwrap();

        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_large_data() {
        let plaintext: Vec<u8> = (0..1_000_000).map(|i| (i % 256) as u8).collect();

        let encrypted = encrypt_data(&plaintext, &TEST_KEY).unwrap();
        let decrypted = decrypt_data(&encrypted, &TEST_KEY).unwrap();

        assert_eq!(decrypted, plaintext);
    }
}