Skip to main content

lore_cli/cloud/
encryption.rs

1//! End-to-end encryption for cloud sync.
2//!
3//! Provides passphrase-based key derivation using Argon2id and symmetric
4//! encryption using AES-256-GCM. Session message content is encrypted
5//! before upload and decrypted after download, ensuring that the cloud
6//! service cannot read session contents.
7
8use aes_gcm::{
9    aead::{Aead, KeyInit},
10    Aes256Gcm, Nonce,
11};
12use argon2::{password_hash::SaltString, Argon2, PasswordHasher};
13use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
14use rand::RngCore;
15
16use super::CloudError;
17
/// Length in bytes of the AES-256-GCM key (256 bits).
pub const KEY_SIZE: usize = 256 / 8;

/// Length in bytes of the AES-GCM nonce (96 bits); `encrypt_data` prepends
/// one to every ciphertext.
pub const NONCE_SIZE: usize = 96 / 8;

/// Length in bytes of the random salt used for Argon2id key derivation.
// Presumably nothing outside tests reads this yet, hence the lint allowance.
#[allow(dead_code)]
pub const SALT_SIZE: usize = 128 / 8;
27
28/// Derives an encryption key from a passphrase and salt using Argon2id.
29///
30/// Uses Argon2id with secure default parameters suitable for key derivation.
31/// The same passphrase and salt will always produce the same key, allowing
32/// encryption and decryption across sessions.
33///
34/// # Arguments
35///
36/// * `passphrase` - The user's passphrase
37/// * `salt` - A random salt (should be stored and reused for the same account)
38///
39/// # Returns
40///
41/// A 32-byte key suitable for AES-256-GCM encryption.
42pub fn derive_key(passphrase: &str, salt: &[u8]) -> Result<Vec<u8>, CloudError> {
43    // Convert salt bytes to the format expected by Argon2
44    // SaltString::encode_b64 expects raw bytes and encodes them as base64
45    let salt_string = SaltString::encode_b64(salt)
46        .map_err(|e| CloudError::EncryptionError(format!("Invalid salt: {e}")))?;
47
48    // Use Argon2id with default secure parameters
49    let argon2 = Argon2::default();
50
51    // Hash the password to derive the key
52    let hash = argon2
53        .hash_password(passphrase.as_bytes(), &salt_string)
54        .map_err(|e| CloudError::EncryptionError(format!("Key derivation failed: {e}")))?;
55
56    // Extract the hash bytes (the output is the derived key)
57    let hash_output = hash
58        .hash
59        .ok_or_else(|| CloudError::EncryptionError("No hash output".to_string()))?;
60
61    // Take the first KEY_SIZE bytes as the encryption key
62    let key_bytes = hash_output.as_bytes();
63    if key_bytes.len() < KEY_SIZE {
64        return Err(CloudError::EncryptionError(
65            "Derived key too short".to_string(),
66        ));
67    }
68
69    Ok(key_bytes[..KEY_SIZE].to_vec())
70}
71
72/// Generates a random salt for key derivation.
73///
74/// The salt should be stored (in config) and reused for the same account
75/// to ensure consistent key derivation.
76#[allow(dead_code)]
77pub fn generate_salt() -> Vec<u8> {
78    let mut salt = vec![0u8; SALT_SIZE];
79    rand::thread_rng().fill_bytes(&mut salt);
80    salt
81}
82
83/// Encrypts data using AES-256-GCM.
84///
85/// The nonce is prepended to the ciphertext, so the output format is:
86/// `nonce (12 bytes) || ciphertext || tag (16 bytes)`
87///
88/// # Arguments
89///
90/// * `data` - The plaintext data to encrypt
91/// * `key` - The 32-byte encryption key
92///
93/// # Returns
94///
95/// The encrypted data with prepended nonce, suitable for base64 encoding.
96pub fn encrypt_data(data: &[u8], key: &[u8]) -> Result<Vec<u8>, CloudError> {
97    if key.len() != KEY_SIZE {
98        return Err(CloudError::EncryptionError(format!(
99            "Invalid key size: expected {KEY_SIZE}, got {}",
100            key.len()
101        )));
102    }
103
104    // Generate a random nonce
105    let mut nonce_bytes = [0u8; NONCE_SIZE];
106    rand::thread_rng().fill_bytes(&mut nonce_bytes);
107    let nonce = Nonce::from_slice(&nonce_bytes);
108
109    // Create the cipher
110    let cipher = Aes256Gcm::new_from_slice(key)
111        .map_err(|e| CloudError::EncryptionError(format!("Cipher creation failed: {e}")))?;
112
113    // Encrypt the data
114    let ciphertext = cipher
115        .encrypt(nonce, data)
116        .map_err(|e| CloudError::EncryptionError(format!("Encryption failed: {e}")))?;
117
118    // Prepend nonce to ciphertext
119    let mut result = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
120    result.extend_from_slice(&nonce_bytes);
121    result.extend_from_slice(&ciphertext);
122
123    Ok(result)
124}
125
126/// Decrypts data that was encrypted with `encrypt_data`.
127///
128/// Expects the input format: `nonce (12 bytes) || ciphertext || tag (16 bytes)`
129///
130/// # Arguments
131///
132/// * `data` - The encrypted data with prepended nonce
133/// * `key` - The 32-byte encryption key
134///
135/// # Returns
136///
137/// The decrypted plaintext data.
138pub fn decrypt_data(data: &[u8], key: &[u8]) -> Result<Vec<u8>, CloudError> {
139    if key.len() != KEY_SIZE {
140        return Err(CloudError::EncryptionError(format!(
141            "Invalid key size: expected {KEY_SIZE}, got {}",
142            key.len()
143        )));
144    }
145
146    if data.len() < NONCE_SIZE {
147        return Err(CloudError::EncryptionError(
148            "Encrypted data too short".to_string(),
149        ));
150    }
151
152    // Extract nonce and ciphertext
153    let (nonce_bytes, ciphertext) = data.split_at(NONCE_SIZE);
154    let nonce = Nonce::from_slice(nonce_bytes);
155
156    // Create the cipher
157    let cipher = Aes256Gcm::new_from_slice(key)
158        .map_err(|e| CloudError::EncryptionError(format!("Cipher creation failed: {e}")))?;
159
160    // Decrypt the data
161    let plaintext = cipher
162        .decrypt(nonce, ciphertext)
163        .map_err(|e| CloudError::EncryptionError(format!("Decryption failed: {e}")))?;
164
165    Ok(plaintext)
166}
167
168/// Encodes binary data as base64.
169pub fn encode_base64(data: &[u8]) -> String {
170    BASE64.encode(data)
171}
172
173/// Decodes base64 data to binary.
174pub fn decode_base64(data: &str) -> Result<Vec<u8>, CloudError> {
175    BASE64
176        .decode(data)
177        .map_err(|e| CloudError::EncryptionError(format!("Base64 decode failed: {e}")))
178}
179
180/// Encodes a key as hexadecimal for storage.
181pub fn encode_key_hex(key: &[u8]) -> String {
182    hex::encode(key)
183}
184
185/// Decodes a hexadecimal key.
186pub fn decode_key_hex(hex_str: &str) -> Result<Vec<u8>, CloudError> {
187    hex::decode(hex_str).map_err(|e| CloudError::EncryptionError(format!("Hex decode failed: {e}")))
188}
189
/// Minimal hexadecimal codec for storing keys as text, kept local to avoid
/// pulling in an extra dependency.
mod hex {
    /// Lowercase hex digits, indexed by nibble value.
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    /// Encodes `data` as lowercase hexadecimal, two characters per byte.
    pub fn encode(data: &[u8]) -> String {
        let mut out = String::with_capacity(data.len() * 2);
        for &byte in data {
            out.push(char::from(DIGITS[usize::from(byte >> 4)]));
            out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
        }
        out
    }

    /// Decodes hexadecimal text (upper- or lowercase digits) into bytes.
    ///
    /// # Errors
    ///
    /// Returns a message if the input has an odd byte length or contains
    /// anything other than ASCII hex digits. The input is processed as bytes,
    /// so non-ASCII text yields an error rather than a char-boundary panic.
    /// A sign character such as `+` is rejected, which `u8::from_str_radix`
    /// would have accepted.
    pub fn decode(s: &str) -> Result<Vec<u8>, String> {
        let pairs = s.as_bytes().chunks_exact(2);
        if !pairs.remainder().is_empty() {
            return Err("Hex string has odd length".to_string());
        }

        pairs
            .enumerate()
            .map(|(i, pair)| -> Result<u8, String> {
                let high = nibble(pair[0], 2 * i)?;
                let low = nibble(pair[1], 2 * i + 1)?;
                Ok((high << 4) | low)
            })
            .collect()
    }

    /// Converts one ASCII hex digit to its value (0..=15); `offset` is only
    /// used to locate the bad byte in the error message.
    fn nibble(digit: u8, offset: usize) -> Result<u8, String> {
        // Bytes >= 0x80 map to non-ASCII chars, which `to_digit` rejects.
        char::from(digit)
            .to_digit(16)
            .map(|value| value as u8) // value < 16, so the cast is lossless
            .ok_or_else(|| format!("Invalid hex: invalid digit at byte offset {offset}"))
    }
}
207
#[cfg(test)]
mod tests {
    use super::*;

    /// Length of the AES-GCM authentication tag appended by `encrypt_data`.
    const TAG_LEN: usize = 16;

    /// Deterministic key for tests that only exercise AES-GCM framing.
    ///
    /// Argon2id is deliberately expensive, so only the tests that are about
    /// key derivation (plus one end-to-end round trip) pay for `derive_key`.
    fn fixed_key(fill: u8) -> Vec<u8> {
        vec![fill; KEY_SIZE]
    }

    #[test]
    fn test_generate_salt_length() {
        let salt = generate_salt();
        assert_eq!(salt.len(), SALT_SIZE);
    }

    #[test]
    fn test_generate_salt_randomness() {
        let salt1 = generate_salt();
        let salt2 = generate_salt();
        assert_ne!(salt1, salt2);
    }

    #[test]
    fn test_derive_key_consistency() {
        let passphrase = "test passphrase";
        let salt = generate_salt();

        let key1 = derive_key(passphrase, &salt).unwrap();
        let key2 = derive_key(passphrase, &salt).unwrap();

        assert_eq!(key1, key2);
        assert_eq!(key1.len(), KEY_SIZE);
    }

    #[test]
    fn test_derive_key_different_passphrases() {
        let salt = generate_salt();

        let key1 = derive_key("passphrase1", &salt).unwrap();
        let key2 = derive_key("passphrase2", &salt).unwrap();

        assert_ne!(key1, key2);
    }

    #[test]
    fn test_derive_key_different_salts() {
        let passphrase = "test passphrase";
        let salt1 = generate_salt();
        let salt2 = generate_salt();

        let key1 = derive_key(passphrase, &salt1).unwrap();
        let key2 = derive_key(passphrase, &salt2).unwrap();

        assert_ne!(key1, key2);
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        // End-to-end: passphrase -> derived key -> encrypt -> decrypt.
        let salt = generate_salt();
        let key = derive_key("test passphrase", &salt).unwrap();

        let plaintext = b"Hello, World! This is a test message.";
        let encrypted = encrypt_data(plaintext, &key).unwrap();
        let decrypted = decrypt_data(&encrypted, &key).unwrap();

        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_ciphertext_layout() {
        let plaintext = b"layout check";
        let encrypted = encrypt_data(plaintext, &fixed_key(7)).unwrap();
        assert_eq!(encrypted.len(), NONCE_SIZE + plaintext.len() + TAG_LEN);
    }

    #[test]
    fn test_encrypt_produces_different_ciphertext() {
        let key = fixed_key(7);

        let plaintext = b"test data";
        let encrypted1 = encrypt_data(plaintext, &key).unwrap();
        let encrypted2 = encrypt_data(plaintext, &key).unwrap();

        // Different nonces should produce different ciphertexts
        assert_ne!(encrypted1, encrypted2);
    }

    #[test]
    fn test_decrypt_with_wrong_key_fails() {
        let encrypted = encrypt_data(b"secret data", &fixed_key(1)).unwrap();
        assert!(decrypt_data(&encrypted, &fixed_key(2)).is_err());
    }

    #[test]
    fn test_decrypt_with_corrupted_data_fails() {
        let key = fixed_key(7);
        let mut encrypted = encrypt_data(b"secret data", &key).unwrap();

        // Index directly so the test fails loudly (rather than silently
        // skipping the corruption) if the output layout ever changes.
        encrypted[NONCE_SIZE + 5] ^= 0xFF;

        assert!(decrypt_data(&encrypted, &key).is_err());
    }

    #[test]
    fn test_decrypt_with_corrupted_nonce_fails() {
        let key = fixed_key(7);
        let mut encrypted = encrypt_data(b"secret data", &key).unwrap();
        encrypted[0] ^= 0x01;
        assert!(decrypt_data(&encrypted, &key).is_err());
    }

    #[test]
    fn test_decrypt_with_corrupted_tag_fails() {
        let key = fixed_key(7);
        let mut encrypted = encrypt_data(b"secret data", &key).unwrap();
        let last = encrypted.len() - 1;
        encrypted[last] ^= 0x01;
        assert!(decrypt_data(&encrypted, &key).is_err());
    }

    #[test]
    fn test_decrypt_truncated_data_fails() {
        let key = fixed_key(7);
        let encrypted = encrypt_data(b"secret data", &key).unwrap();
        assert!(decrypt_data(&encrypted[..encrypted.len() - 1], &key).is_err());
        assert!(decrypt_data(&encrypted[..NONCE_SIZE], &key).is_err());
    }

    #[test]
    fn test_encrypt_data_invalid_key_size() {
        let short_key = vec![0u8; 16]; // Too short
        let result = encrypt_data(b"data", &short_key);
        assert!(result.is_err());
    }

    #[test]
    fn test_decrypt_data_invalid_key_size() {
        let encrypted = encrypt_data(b"data", &fixed_key(7)).unwrap();
        assert!(decrypt_data(&encrypted, &fixed_key(7)[..16]).is_err());
    }

    #[test]
    fn test_decrypt_data_too_short() {
        let short_data = vec![0u8; 5]; // Shorter than nonce
        assert!(decrypt_data(&short_data, &fixed_key(7)).is_err());
    }

    #[test]
    fn test_base64_roundtrip() {
        let data = b"test binary data \x00\x01\x02";
        let encoded = encode_base64(data);
        let decoded = decode_base64(&encoded).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_hex_roundtrip() {
        let data = vec![0u8, 1, 2, 255, 128, 64];
        let encoded = encode_key_hex(&data);
        let decoded = decode_key_hex(&encoded).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_hex_encode() {
        assert_eq!(hex::encode(&[0, 255, 128]), "00ff80");
    }

    #[test]
    fn test_hex_decode_invalid() {
        assert!(hex::decode("xyz").is_err());
        assert!(hex::decode("abc").is_err()); // Odd length
    }

    #[test]
    fn test_encrypt_empty_data() {
        let key = fixed_key(7);

        let plaintext = b"";
        let encrypted = encrypt_data(plaintext, &key).unwrap();
        let decrypted = decrypt_data(&encrypted, &key).unwrap();

        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_large_data() {
        let key = fixed_key(7);

        // 1 MB of data
        let plaintext: Vec<u8> = (0..1_000_000).map(|i| (i % 256) as u8).collect();
        let encrypted = encrypt_data(&plaintext, &key).unwrap();
        let decrypted = decrypt_data(&encrypted, &key).unwrap();

        assert_eq!(decrypted, plaintext);
    }
}