1use aes_gcm::{
20 aead::{Aead, KeyInit},
21 Aes256Gcm, Nonce,
22};
23use hkdf::Hkdf;
24use rand::RngCore;
25use sha2::{Digest, Sha256};
26
/// A 256-bit key (32 bytes). CHK content hashes share this type because the
/// hash itself is the key handed back by `encrypt_chk`.
pub type EncryptionKey = [u8; 32];

/// AES-GCM nonce length in bytes (96 bits).
const NONCE_SIZE: usize = 96 / 8;

/// AES-GCM authentication tag length in bytes (128 bits).
const TAG_SIZE: usize = 128 / 8;

/// HKDF salt used when deriving a CHK encryption key from a content hash.
const CHK_SALT: &[u8] = b"hashtree-chk";

39#[derive(Debug, thiserror::Error)]
41pub enum CryptoError {
42 #[error("Encryption failed: {0}")]
43 EncryptionFailed(String),
44 #[error("Decryption failed: {0}")]
45 DecryptionFailed(String),
46 #[error("Encrypted data too short")]
47 DataTooShort,
48 #[error("Invalid key length")]
49 InvalidKeyLength,
50 #[error("Key derivation failed")]
51 KeyDerivationFailed,
52}
53
54fn derive_key(content_hash: &[u8; 32]) -> Result<[u8; 32], CryptoError> {
56 let hk = Hkdf::<Sha256>::new(Some(CHK_SALT), content_hash);
57
58 let mut key = [0u8; 32];
59 hk.expand(b"encryption-key", &mut key)
60 .map_err(|_| CryptoError::KeyDerivationFailed)?;
61
62 Ok(key)
63}
64
65pub fn generate_key() -> EncryptionKey {
67 let mut key = [0u8; 32];
68 rand::thread_rng().fill_bytes(&mut key);
69 key
70}
71
72pub fn content_hash(data: &[u8]) -> EncryptionKey {
74 let hash = Sha256::digest(data);
75 let mut result = [0u8; 32];
76 result.copy_from_slice(&hash);
77 result
78}
79
80pub fn encrypt_chk(plaintext: &[u8]) -> Result<(Vec<u8>, EncryptionKey), CryptoError> {
91 let chash = content_hash(plaintext);
92 let key = derive_key(&chash)?;
93 let zero_nonce = [0u8; NONCE_SIZE];
94
95 let cipher = Aes256Gcm::new_from_slice(&key)
96 .map_err(|e| CryptoError::EncryptionFailed(e.to_string()))?;
97
98 let ciphertext = cipher
99 .encrypt(Nonce::from_slice(&zero_nonce), plaintext)
100 .map_err(|e| CryptoError::EncryptionFailed(e.to_string()))?;
101
102 Ok((ciphertext, chash))
103}
104
105pub fn decrypt_chk(ciphertext: &[u8], key: &EncryptionKey) -> Result<Vec<u8>, CryptoError> {
109 if ciphertext.len() < TAG_SIZE {
110 return Err(CryptoError::DataTooShort);
111 }
112
113 let enc_key = derive_key(key)?;
114 let zero_nonce = [0u8; NONCE_SIZE];
115
116 let cipher = Aes256Gcm::new_from_slice(&enc_key)
117 .map_err(|e| CryptoError::DecryptionFailed(e.to_string()))?;
118
119 cipher
120 .decrypt(Nonce::from_slice(&zero_nonce), ciphertext)
121 .map_err(|e| CryptoError::DecryptionFailed(e.to_string()))
122}
123
124pub fn encrypt(plaintext: &[u8], key: &EncryptionKey) -> Result<Vec<u8>, CryptoError> {
128 let cipher = Aes256Gcm::new_from_slice(key)
129 .map_err(|e| CryptoError::EncryptionFailed(e.to_string()))?;
130
131 let mut nonce_bytes = [0u8; NONCE_SIZE];
132 rand::thread_rng().fill_bytes(&mut nonce_bytes);
133 let nonce = Nonce::from_slice(&nonce_bytes);
134
135 let ciphertext = cipher
136 .encrypt(nonce, plaintext)
137 .map_err(|e| CryptoError::EncryptionFailed(e.to_string()))?;
138
139 let mut result = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
140 result.extend_from_slice(&nonce_bytes);
141 result.extend_from_slice(&ciphertext);
142
143 Ok(result)
144}
145
146pub fn decrypt(encrypted: &[u8], key: &EncryptionKey) -> Result<Vec<u8>, CryptoError> {
150 if encrypted.len() < NONCE_SIZE + TAG_SIZE {
151 return Err(CryptoError::DataTooShort);
152 }
153
154 let cipher = Aes256Gcm::new_from_slice(key)
155 .map_err(|e| CryptoError::DecryptionFailed(e.to_string()))?;
156
157 let nonce = Nonce::from_slice(&encrypted[..NONCE_SIZE]);
158 let ciphertext = &encrypted[NONCE_SIZE..];
159
160 cipher
161 .decrypt(nonce, ciphertext)
162 .map_err(|e| CryptoError::DecryptionFailed(e.to_string()))
163}
164
165pub fn could_be_encrypted(data: &[u8]) -> bool {
167 data.len() >= NONCE_SIZE + TAG_SIZE
168}
169
170pub fn encrypted_size(plaintext_size: usize) -> usize {
172 NONCE_SIZE + plaintext_size + TAG_SIZE
173}
174
175pub fn encrypted_size_chk(plaintext_size: usize) -> usize {
177 plaintext_size + TAG_SIZE
178}
179
180pub fn plaintext_size(encrypted_size: usize) -> usize {
182 encrypted_size.saturating_sub(NONCE_SIZE + TAG_SIZE)
183}
184
185pub fn key_to_hex(key: &EncryptionKey) -> String {
187 hex::encode(key)
188}
189
190pub fn key_from_hex(hex_str: &str) -> Result<EncryptionKey, CryptoError> {
192 let bytes = hex::decode(hex_str).map_err(|_| CryptoError::InvalidKeyLength)?;
193 if bytes.len() != 32 {
194 return Err(CryptoError::InvalidKeyLength);
195 }
196 let mut key = [0u8; 32];
197 key.copy_from_slice(&bytes);
198 Ok(key)
199}
200
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_chk_encrypt_decrypt() {
        let plaintext = b"Hello, World!";

        let (ciphertext, key) = encrypt_chk(plaintext).unwrap();
        let decrypted = decrypt_chk(&ciphertext, &key).unwrap();

        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_chk_deterministic() {
        let plaintext = b"Same content produces same ciphertext";

        let (ciphertext1, key1) = encrypt_chk(plaintext).unwrap();
        let (ciphertext2, key2) = encrypt_chk(plaintext).unwrap();

        // Identical content must give an identical key and ciphertext.
        assert_eq!(key1, key2);
        assert_eq!(ciphertext1, ciphertext2);
    }

    #[test]
    fn test_chk_different_content() {
        let (ciphertext1, key1) = encrypt_chk(b"Content A").unwrap();
        let (ciphertext2, key2) = encrypt_chk(b"Content B").unwrap();

        assert_ne!(key1, key2);
        assert_ne!(ciphertext1, ciphertext2);
    }

    #[test]
    fn test_chk_key_is_content_hash() {
        let plaintext = b"Key equals hash";
        let (_, key) = encrypt_chk(plaintext).unwrap();
        assert_eq!(key, content_hash(plaintext));
    }

    #[test]
    fn test_content_hash_known_vector() {
        // Standard SHA-256 test vector for "abc".
        let expected =
            key_from_hex("ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
                .unwrap();
        assert_eq!(content_hash(b"abc"), expected);
    }

    #[test]
    fn test_chk_wrong_key_fails() {
        let (ciphertext, _key) = encrypt_chk(b"Secret data").unwrap();
        let wrong_key = generate_key();

        let result = decrypt_chk(&ciphertext, &wrong_key);
        assert!(result.is_err());
    }

    #[test]
    fn test_chk_too_short() {
        let key = content_hash(b"anything");
        let result = decrypt_chk(&[0u8; TAG_SIZE - 1], &key);
        assert!(matches!(result, Err(CryptoError::DataTooShort)));
    }

    #[test]
    fn test_non_chk_encrypt_decrypt() {
        let key = generate_key();
        let plaintext = b"Hello, World!";

        let encrypted = encrypt(plaintext, &key).unwrap();
        let decrypted = decrypt(&encrypted, &key).unwrap();

        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_non_chk_random_nonce() {
        let key = generate_key();
        let plaintext = b"Same content";

        let encrypted1 = encrypt(plaintext, &key).unwrap();
        let encrypted2 = encrypt(plaintext, &key).unwrap();

        // Random nonces make repeated encryptions differ...
        assert_ne!(encrypted1, encrypted2);

        // ...while both still decrypt to the original plaintext.
        assert_eq!(decrypt(&encrypted1, &key).unwrap(), plaintext);
        assert_eq!(decrypt(&encrypted2, &key).unwrap(), plaintext);
    }

    #[test]
    fn test_non_chk_wrong_key_fails() {
        let encrypted = encrypt(b"Secret data", &generate_key()).unwrap();
        assert!(decrypt(&encrypted, &generate_key()).is_err());
    }

    #[test]
    fn test_non_chk_tampered_data_fails() {
        let key = generate_key();
        let mut encrypted = encrypt(b"Important data", &key).unwrap();

        // Flip a bit in the nonce prefix; authentication must fail.
        encrypted[0] ^= 0x01;

        assert!(decrypt(&encrypted, &key).is_err());
    }

    #[test]
    fn test_non_chk_too_short() {
        let key = generate_key();
        let result = decrypt(&[0u8; NONCE_SIZE + TAG_SIZE - 1], &key);
        assert!(matches!(result, Err(CryptoError::DataTooShort)));
    }

    #[test]
    fn test_empty_data() {
        let (ciphertext, key) = encrypt_chk(b"").unwrap();
        let decrypted = decrypt_chk(&ciphertext, &key).unwrap();
        assert_eq!(decrypted, b"");
    }

    #[test]
    fn test_large_data() {
        let plaintext = vec![0u8; 1024 * 1024];
        let (ciphertext, key) = encrypt_chk(&plaintext).unwrap();
        let decrypted = decrypt_chk(&ciphertext, &key).unwrap();

        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_key_hex_roundtrip() {
        let key = generate_key();
        let hex_str = key_to_hex(&key);
        let key2 = key_from_hex(&hex_str).unwrap();
        assert_eq!(key, key2);
    }

    #[test]
    fn test_key_from_hex_rejects_bad_input() {
        // Too short, too long, odd length and non-hex characters all fail.
        for bad in [
            "00".repeat(31),
            "00".repeat(33),
            "0".repeat(63),
            "zz".repeat(32),
        ] {
            assert!(matches!(
                key_from_hex(&bad),
                Err(CryptoError::InvalidKeyLength)
            ));
        }
    }

    #[test]
    fn test_encrypted_size_chk() {
        let plaintext = b"Test data";
        let (ciphertext, _) = encrypt_chk(plaintext).unwrap();
        assert_eq!(ciphertext.len(), encrypted_size_chk(plaintext.len()));
    }

    #[test]
    fn test_encrypted_size_matches_output() {
        let key = generate_key();
        let plaintext = b"Test data";
        let encrypted = encrypt(plaintext, &key).unwrap();

        assert_eq!(encrypted.len(), encrypted_size(plaintext.len()));
        assert_eq!(plaintext_size(encrypted.len()), plaintext.len());
        assert!(could_be_encrypted(&encrypted));
    }

    #[test]
    fn test_size_helpers_at_boundaries() {
        assert_eq!(plaintext_size(0), 0);
        assert_eq!(plaintext_size(NONCE_SIZE + TAG_SIZE - 1), 0);
        assert!(!could_be_encrypted(&[0u8; NONCE_SIZE + TAG_SIZE - 1]));
        assert!(could_be_encrypted(&[0u8; NONCE_SIZE + TAG_SIZE]));
    }

    #[test]
    fn test_tampered_data_fails() {
        let (mut ciphertext, key) = encrypt_chk(b"Important data").unwrap();

        // Corrupt the final byte, which falls inside the authentication tag.
        if let Some(byte) = ciphertext.last_mut() {
            *byte ^= 0xFF;
        }

        let result = decrypt_chk(&ciphertext, &key);
        assert!(result.is_err());
    }
}