use aes_gcm::{
    aead::{Aead, KeyInit, OsRng},
    Aes256Gcm, Nonce,
};
use argon2::{
    password_hash::{PasswordHasher, SaltString},
    Algorithm, Argon2, Params, Version,
};
use rand::RngCore;
use thiserror::Error;
15
16#[derive(Debug, Error)]
17pub enum CryptoError {
18 #[error("Encryption failed: {0}")]
19 EncryptionFailed(String),
20 #[error("Decryption failed: {0}")]
21 DecryptionFailed(String),
22 #[error("Key derivation failed: {0}")]
23 KeyDerivationFailed(String),
24 #[error("Invalid passphrase")]
25 InvalidPassphrase,
26}
27
28pub type Result<T> = std::result::Result<T, CryptoError>;
29
/// Tunable inputs for passphrase-based key derivation.
///
/// [`KdfParams::default`] leaves `salt` empty, so callers are expected to
/// supply one (for example from [`generate_salt`]).
#[derive(Debug, Clone)]
pub struct KdfParams {
    /// Salt bytes hashed together with the passphrase.
    pub salt: Vec<u8>,
    /// Argon2 memory cost.
    pub memory_cost: u32,
    /// Argon2 time cost (number of passes).
    pub time_cost: u32,
    /// Argon2 degree of parallelism.
    pub parallelism: u32,
}

impl Default for KdfParams {
    /// An empty salt with `memory_cost = 19456`, `time_cost = 2` and
    /// `parallelism = 1`, the same cost values the argon2 crate uses by default.
    fn default() -> Self {
        const MEMORY_COST: u32 = 19_456;
        const TIME_COST: u32 = 2;
        const PARALLELISM: u32 = 1;

        KdfParams {
            salt: Vec::new(),
            memory_cost: MEMORY_COST,
            time_cost: TIME_COST,
            parallelism: PARALLELISM,
        }
    }
}
49
/// Result of [`encrypt`]: the random nonce together with the AES-256-GCM
/// ciphertext. Both parts are required by [`decrypt`].
#[derive(Debug, Clone)]
pub struct EncryptedData {
    /// Per-message nonce; [`decrypt`] rejects anything that is not 12 bytes.
    pub nonce: Vec<u8>,
    /// Ciphertext bytes exactly as produced by the AEAD `encrypt` call.
    pub ciphertext: Vec<u8>,
}
56
57pub fn derive_key(passphrase: &str, params: &KdfParams) -> Result<[u8; 32]> {
59 if passphrase.is_empty() {
60 return Err(CryptoError::InvalidPassphrase);
61 }
62
63 let argon2 = Argon2::default();
64
65 let salt_string = SaltString::encode_b64(¶ms.salt)
67 .map_err(|e| CryptoError::KeyDerivationFailed(e.to_string()))?;
68
69 let password_hash = argon2
71 .hash_password(passphrase.as_bytes(), &salt_string)
72 .map_err(|e| CryptoError::KeyDerivationFailed(e.to_string()))?;
73
74 let hash_bytes = password_hash
76 .hash
77 .ok_or_else(|| CryptoError::KeyDerivationFailed("No hash generated".to_string()))?;
78
79 let hash_slice = hash_bytes.as_bytes();
80 if hash_slice.len() < 32 {
81 return Err(CryptoError::KeyDerivationFailed(
82 "Hash too short".to_string(),
83 ));
84 }
85
86 let mut key = [0u8; 32];
87 key.copy_from_slice(&hash_slice[..32]);
88 Ok(key)
89}
90
91pub fn generate_salt() -> Vec<u8> {
93 let mut salt = vec![0u8; 16];
94 OsRng.fill_bytes(&mut salt);
95 salt
96}
97
98pub fn encrypt(plaintext: &[u8], key: &[u8; 32]) -> Result<EncryptedData> {
100 let cipher = Aes256Gcm::new(key.into());
101
102 let mut nonce_bytes = [0u8; 12];
104 OsRng.fill_bytes(&mut nonce_bytes);
105 let nonce = Nonce::from_slice(&nonce_bytes);
106
107 let ciphertext = cipher
109 .encrypt(nonce, plaintext)
110 .map_err(|e| CryptoError::EncryptionFailed(e.to_string()))?;
111
112 Ok(EncryptedData {
113 nonce: nonce_bytes.to_vec(),
114 ciphertext,
115 })
116}
117
118pub fn decrypt(encrypted: &EncryptedData, key: &[u8; 32]) -> Result<Vec<u8>> {
120 if encrypted.nonce.len() != 12 {
121 return Err(CryptoError::DecryptionFailed(
122 "Invalid nonce length".to_string(),
123 ));
124 }
125
126 let cipher = Aes256Gcm::new(key.into());
127 let nonce = Nonce::from_slice(&encrypted.nonce);
128
129 let plaintext = cipher
130 .decrypt(nonce, encrypted.ciphertext.as_ref())
131 .map_err(|e| CryptoError::DecryptionFailed(e.to_string()))?;
132
133 Ok(plaintext)
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    /// Default KDF parameters paired with a freshly generated random salt.
    fn random_salt_params() -> KdfParams {
        KdfParams {
            salt: generate_salt(),
            ..Default::default()
        }
    }

    #[test]
    fn test_derive_key_deterministic() {
        let params = KdfParams {
            salt: (1..=16).collect(),
            memory_cost: 19456,
            time_cost: 2,
            parallelism: 1,
        };

        let first = derive_key("test_password", &params).unwrap();
        let second = derive_key("test_password", &params).unwrap();

        assert_eq!(
            first, second,
            "Same passphrase and salt should produce same key"
        );
    }

    #[test]
    fn test_derive_key_empty_passphrase() {
        let err = derive_key("", &KdfParams::default()).unwrap_err();
        assert!(matches!(err, CryptoError::InvalidPassphrase));
    }

    #[test]
    fn test_encrypt_decrypt_round_trip() {
        let message: &[u8] = b"Hello, World!";
        let key = derive_key("test_password", &random_salt_params()).unwrap();

        let sealed = encrypt(message, &key).unwrap();
        let opened = decrypt(&sealed, &key).unwrap();

        assert_eq!(message, opened.as_slice());
    }

    #[test]
    fn test_decrypt_wrong_key() {
        let message: &[u8] = b"Hello, World!";
        let right_key = derive_key("password1", &random_salt_params()).unwrap();
        let wrong_key = derive_key("password2", &random_salt_params()).unwrap();

        let sealed = encrypt(message, &right_key).unwrap();
        let err = decrypt(&sealed, &wrong_key).unwrap_err();

        assert!(matches!(err, CryptoError::DecryptionFailed(_)));
    }

    #[test]
    fn test_generate_salt_unique() {
        let (a, b) = (generate_salt(), generate_salt());
        assert_ne!(a, b, "Generated salts should be unique");
    }

    #[test]
    fn test_nonce_uniqueness() {
        let key = [0u8; 32];
        let first = encrypt(b"test", &key).unwrap();
        let second = encrypt(b"test", &key).unwrap();

        assert_ne!(first.nonce, second.nonce, "Nonces should be unique");
    }
}