Skip to main content

lore_cli/sync/
encryption.rs

1//! End-to-end encryption for git-ref sync.
2//!
3//! Provides passphrase-based key derivation using Argon2id and symmetric
4//! encryption using AES-256-GCM. Session content is encrypted before being
5//! written into the lore store ref and decrypted after being read back,
6//! ensuring that anyone with repository access but without the passphrase
7//! cannot read the reasoning history.
8//!
9//! Failures are reported via [`SyncError`].
10
11use aes_gcm::{
12    aead::{Aead, KeyInit},
13    Aes256Gcm, Nonce,
14};
15use argon2::{password_hash::SaltString, Argon2, PasswordHasher};
16use rand::RngCore;
17
18use super::SyncError;
19
/// Length of the symmetric key in bytes: AES-256 uses a 256-bit key.
pub const KEY_SIZE: usize = 256 / 8;

/// Length of the AES-GCM nonce in bytes: 96 bits.
pub const NONCE_SIZE: usize = 96 / 8;

/// Length of the key-derivation salt in bytes: 128 bits.
pub const SALT_SIZE: usize = 128 / 8;
28
29/// Derives an encryption key from a passphrase and salt using Argon2id.
30///
31/// Uses Argon2id with secure default parameters suitable for key derivation.
32/// The same passphrase and salt will always produce the same key, allowing
33/// encryption and decryption across machines that share the salt and passphrase.
34///
35/// # Arguments
36///
37/// * `passphrase` - The user's passphrase
38/// * `salt` - A random salt (stored in the ref tree at `meta/salt` and reused)
39///
40/// # Returns
41///
42/// A 32-byte key suitable for AES-256-GCM encryption.
43pub fn derive_key(passphrase: &str, salt: &[u8]) -> Result<Vec<u8>, SyncError> {
44    // Convert salt bytes to the format expected by Argon2.
45    // SaltString::encode_b64 expects raw bytes and encodes them as base64.
46    let salt_string = SaltString::encode_b64(salt)
47        .map_err(|e| SyncError::Encryption(format!("Invalid salt: {e}")))?;
48
49    // Use Argon2id with default secure parameters.
50    let argon2 = Argon2::default();
51
52    // Hash the password to derive the key.
53    let hash = argon2
54        .hash_password(passphrase.as_bytes(), &salt_string)
55        .map_err(|e| SyncError::Encryption(format!("Key derivation failed: {e}")))?;
56
57    // Extract the hash bytes (the output is the derived key).
58    let hash_output = hash
59        .hash
60        .ok_or_else(|| SyncError::Encryption("No hash output".to_string()))?;
61
62    // Take the first KEY_SIZE bytes as the encryption key.
63    let key_bytes = hash_output.as_bytes();
64    if key_bytes.len() < KEY_SIZE {
65        return Err(SyncError::Encryption("Derived key too short".to_string()));
66    }
67
68    Ok(key_bytes[..KEY_SIZE].to_vec())
69}
70
71/// Generates a random salt for key derivation.
72///
73/// The salt is stored in the ref tree at `meta/salt` (it is not secret) so
74/// other machines derive the same key from the same passphrase.
75pub fn generate_salt() -> Vec<u8> {
76    let mut salt = vec![0u8; SALT_SIZE];
77    rand::thread_rng().fill_bytes(&mut salt);
78    salt
79}
80
81/// Encrypts data using AES-256-GCM.
82///
83/// The nonce is prepended to the ciphertext, so the output format is:
84/// `nonce (12 bytes) || ciphertext || tag (16 bytes)`
85///
86/// # Arguments
87///
88/// * `data` - The plaintext data to encrypt
89/// * `key` - The 32-byte encryption key
90///
91/// # Returns
92///
93/// The encrypted data with prepended nonce, suitable for writing as a git blob.
94pub fn encrypt_data(data: &[u8], key: &[u8]) -> Result<Vec<u8>, SyncError> {
95    if key.len() != KEY_SIZE {
96        return Err(SyncError::Encryption(format!(
97            "Invalid key size: expected {KEY_SIZE}, got {}",
98            key.len()
99        )));
100    }
101
102    // Generate a random nonce.
103    let mut nonce_bytes = [0u8; NONCE_SIZE];
104    rand::thread_rng().fill_bytes(&mut nonce_bytes);
105    let nonce = Nonce::from_slice(&nonce_bytes);
106
107    // Create the cipher.
108    let cipher = Aes256Gcm::new_from_slice(key)
109        .map_err(|e| SyncError::Encryption(format!("Cipher creation failed: {e}")))?;
110
111    // Encrypt the data.
112    let ciphertext = cipher
113        .encrypt(nonce, data)
114        .map_err(|e| SyncError::Encryption(format!("Encryption failed: {e}")))?;
115
116    // Prepend nonce to ciphertext.
117    let mut result = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
118    result.extend_from_slice(&nonce_bytes);
119    result.extend_from_slice(&ciphertext);
120
121    Ok(result)
122}
123
124/// Decrypts data that was encrypted with [`encrypt_data`].
125///
126/// Expects the input format: `nonce (12 bytes) || ciphertext || tag (16 bytes)`
127///
128/// # Arguments
129///
130/// * `data` - The encrypted data with prepended nonce
131/// * `key` - The 32-byte encryption key
132///
133/// # Returns
134///
135/// The decrypted plaintext data.
136pub fn decrypt_data(data: &[u8], key: &[u8]) -> Result<Vec<u8>, SyncError> {
137    if key.len() != KEY_SIZE {
138        return Err(SyncError::Encryption(format!(
139            "Invalid key size: expected {KEY_SIZE}, got {}",
140            key.len()
141        )));
142    }
143
144    if data.len() < NONCE_SIZE {
145        return Err(SyncError::Encryption(
146            "Encrypted data too short".to_string(),
147        ));
148    }
149
150    // Extract nonce and ciphertext.
151    let (nonce_bytes, ciphertext) = data.split_at(NONCE_SIZE);
152    let nonce = Nonce::from_slice(nonce_bytes);
153
154    // Create the cipher.
155    let cipher = Aes256Gcm::new_from_slice(key)
156        .map_err(|e| SyncError::Encryption(format!("Cipher creation failed: {e}")))?;
157
158    // Decrypt the data.
159    let plaintext = cipher
160        .decrypt(nonce, ciphertext)
161        .map_err(|e| SyncError::Encryption(format!("Decryption failed: {e}")))?;
162
163    Ok(plaintext)
164}
165
166/// Encodes a key as hexadecimal for storage.
167pub fn encode_key_hex(key: &[u8]) -> String {
168    hex::encode(key)
169}
170
171/// Decodes a hexadecimal key.
172pub fn decode_key_hex(hex_str: &str) -> Result<Vec<u8>, SyncError> {
173    hex::decode(hex_str).map_err(|e| SyncError::Encryption(format!("Hex decode failed: {e}")))
174}
175
// Minimal hex encoding helper to avoid an extra dependency.
mod hex {
    /// Lowercase digit for each 4-bit value.
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    /// Encodes `data` as lowercase hexadecimal, two digits per byte.
    pub fn encode(data: &[u8]) -> String {
        // One allocation of the exact final size instead of a `String` per byte.
        let mut out = String::with_capacity(data.len() * 2);
        for &byte in data {
            out.push(char::from(DIGITS[usize::from(byte >> 4)]));
            out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
        }
        out
    }

    /// Decodes a hexadecimal string (upper- or lowercase) into bytes.
    ///
    /// Returns an error message if the input has an odd number of bytes or
    /// contains anything other than `0-9`, `a-f`, `A-F`. Never panics,
    /// whatever the input.
    pub fn decode(s: &str) -> Result<Vec<u8>, String> {
        // Work on raw bytes: slicing the `&str` at arbitrary byte offsets
        // would panic when a multi-byte character straddles a pair boundary.
        let pairs = s.as_bytes().chunks_exact(2);
        if !pairs.remainder().is_empty() {
            return Err("Hex string has odd length".to_string());
        }

        pairs
            .enumerate()
            .map(|(index, pair)| {
                let high = hex_value(pair[0], index * 2)?;
                let low = hex_value(pair[1], index * 2 + 1)?;
                Ok((high << 4) | low)
            })
            .collect()
    }

    /// Maps one ASCII hex digit to its value, reporting `offset` on failure.
    ///
    /// Only real hex digits are accepted; in particular a `+` sign, which
    /// `u8::from_str_radix` would tolerate, is rejected.
    fn hex_value(byte: u8, offset: usize) -> Result<u8, String> {
        match byte {
            b'0'..=b'9' => Ok(byte - b'0'),
            b'a'..=b'f' => Ok(byte - b'a' + 10),
            b'A'..=b'F' => Ok(byte - b'A' + 10),
            _ => Err(format!("Invalid hex: invalid digit at byte offset {offset}")),
        }
    }
}
193
// Unit tests for salt generation, key derivation, the encrypt/decrypt
// round trip, and the hex helpers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generate_salt_length() {
        let salt = generate_salt();
        assert_eq!(salt.len(), SALT_SIZE);
    }

    #[test]
    fn test_generate_salt_randomness() {
        let salt1 = generate_salt();
        let salt2 = generate_salt();
        assert_ne!(salt1, salt2);
    }

    #[test]
    fn test_derive_key_consistency() {
        let passphrase = "test passphrase";
        let salt = generate_salt();

        let key1 = derive_key(passphrase, &salt).unwrap();
        let key2 = derive_key(passphrase, &salt).unwrap();

        assert_eq!(key1, key2);
        assert_eq!(key1.len(), KEY_SIZE);
    }

    #[test]
    fn test_derive_key_different_passphrases() {
        let salt = generate_salt();

        let key1 = derive_key("passphrase1", &salt).unwrap();
        let key2 = derive_key("passphrase2", &salt).unwrap();

        assert_ne!(key1, key2);
    }

    #[test]
    fn test_derive_key_different_salts() {
        let passphrase = "test passphrase";
        let salt1 = generate_salt();
        let salt2 = generate_salt();

        let key1 = derive_key(passphrase, &salt1).unwrap();
        let key2 = derive_key(passphrase, &salt2).unwrap();

        assert_ne!(key1, key2);
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let passphrase = "test passphrase";
        let salt = generate_salt();
        let key = derive_key(passphrase, &salt).unwrap();

        let plaintext = b"Hello, World! This is a test message.";
        let encrypted = encrypt_data(plaintext, &key).unwrap();
        let decrypted = decrypt_data(&encrypted, &key).unwrap();

        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_produces_different_ciphertext() {
        let salt = generate_salt();
        let key = derive_key("passphrase", &salt).unwrap();

        let plaintext = b"test data";
        let encrypted1 = encrypt_data(plaintext, &key).unwrap();
        let encrypted2 = encrypt_data(plaintext, &key).unwrap();

        // Different nonces should produce different ciphertexts.
        assert_ne!(encrypted1, encrypted2);
    }

    #[test]
    fn test_decrypt_with_wrong_key_fails() {
        let salt = generate_salt();
        let key1 = derive_key("passphrase1", &salt).unwrap();
        let key2 = derive_key("passphrase2", &salt).unwrap();

        let plaintext = b"secret data";
        let encrypted = encrypt_data(plaintext, &key1).unwrap();

        let result = decrypt_data(&encrypted, &key2);
        assert!(result.is_err());
    }

    #[test]
    fn test_decrypt_with_corrupted_data_fails() {
        let salt = generate_salt();
        let key = derive_key("passphrase", &salt).unwrap();

        let plaintext = b"secret data";
        let mut encrypted = encrypt_data(plaintext, &key).unwrap();

        // Corrupt the ciphertext.
        if let Some(byte) = encrypted.get_mut(NONCE_SIZE + 5) {
            *byte ^= 0xFF;
        }

        let result = decrypt_data(&encrypted, &key);
        assert!(result.is_err());
    }

    #[test]
    fn test_encrypt_data_invalid_key_size() {
        let short_key = vec![0u8; 16]; // Too short
        let result = encrypt_data(b"data", &short_key);
        assert!(result.is_err());
    }

    #[test]
    fn test_decrypt_data_invalid_key_size() {
        // Mirror of the encrypt-side check: a 16-byte key must be rejected.
        let short_key = vec![0u8; 16]; // Too short
        let result = decrypt_data(&[0u8; 64], &short_key);
        assert!(result.is_err());
    }

    #[test]
    fn test_decrypt_data_too_short() {
        let salt = generate_salt();
        let key = derive_key("passphrase", &salt).unwrap();

        let short_data = vec![0u8; 5]; // Shorter than nonce
        let result = decrypt_data(&short_data, &key);
        assert!(result.is_err());
    }

    #[test]
    fn test_hex_roundtrip() {
        let data = vec![0u8, 1, 2, 255, 128, 64];
        let encoded = encode_key_hex(&data);
        let decoded = decode_key_hex(&encoded).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_hex_encode() {
        assert_eq!(hex::encode(&[0, 255, 128]), "00ff80");
    }

    #[test]
    fn test_hex_decode_invalid() {
        // Even length, so this reaches the digit check instead of being
        // caught by the odd-length guard.
        assert!(hex::decode("zz").is_err());
        assert!(hex::decode("xyz").is_err()); // Odd length and non-hex
        assert!(hex::decode("abc").is_err()); // Odd length
    }

    #[test]
    fn test_encrypt_empty_data() {
        let salt = generate_salt();
        let key = derive_key("passphrase", &salt).unwrap();

        let plaintext = b"";
        let encrypted = encrypt_data(plaintext, &key).unwrap();
        let decrypted = decrypt_data(&encrypted, &key).unwrap();

        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_large_data() {
        let salt = generate_salt();
        let key = derive_key("passphrase", &salt).unwrap();

        // 1 MB of data
        let plaintext: Vec<u8> = (0..1_000_000).map(|i| (i % 256) as u8).collect();
        let encrypted = encrypt_data(&plaintext, &key).unwrap();
        let decrypted = decrypt_data(&encrypted, &key).unwrap();

        assert_eq!(decrypted, plaintext);
    }
}