1use argon2::{
8 Algorithm, Argon2, Params, ParamsBuilder, PasswordHash, PasswordVerifier, Version,
9 password_hash::{PasswordHasher, SaltString},
10};
11use rand::RngCore;
12use thiserror::Error;
13use zeroize::Zeroizing;
14
15use crate::EncryptionKey;
16
// The DEFAULT_* values mirror the `Interactive` preset, which is what
// `PasswordKeyDerivation::default()` uses. `dead_code` is allowed so the build
// stays warning-free if nothing references these constants directly.

/// Default Argon2 memory cost, in KiB (64 MiB).
#[allow(dead_code)]
const DEFAULT_M_COST: u32 = 64 * 1024;

/// Default Argon2 time cost (number of passes).
#[allow(dead_code)]
const DEFAULT_T_COST: u32 = 3;

/// Default Argon2 degree of parallelism.
#[allow(dead_code)]
const DEFAULT_P_COST: u32 = 4;

/// Argon2 output length in bytes; every derived key has this size.
const OUTPUT_LENGTH: usize = 32;
31
32#[derive(Debug, Error)]
34pub enum PbkdfError {
35 #[error("Invalid password")]
36 InvalidPassword,
37
38 #[error("Invalid salt")]
39 InvalidSalt,
40
41 #[error("Argon2 error: {0}")]
42 Argon2Error(String),
43
44 #[error("Invalid parameters: {0}")]
45 InvalidParams(String),
46
47 #[error("Hash verification failed")]
48 VerificationFailed,
49}
50
/// Preset Argon2id cost levels, ordered from cheapest to most expensive.
///
/// Each preset is described as memory (`m_cost`, in KiB) / passes (`t_cost`) /
/// parallelism (`p_cost`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KeyDerivationStrength {
    /// 8 MiB, 1 pass, parallelism 1.
    Fast,

    /// 64 MiB, 3 passes, parallelism 4. Used by `PasswordKeyDerivation::default()`.
    Interactive,

    /// 256 MiB, 4 passes, parallelism 4.
    Moderate,

    /// 512 MiB, 5 passes, parallelism 4.
    Strong,

    /// 1 GiB, 10 passes, parallelism 8.
    Paranoid,
}
74
75impl KeyDerivationStrength {
76 fn params(&self) -> Result<Params, PbkdfError> {
78 let (m_cost, t_cost, p_cost) = match self {
79 Self::Fast => (8 * 1024, 1, 1), Self::Interactive => (64 * 1024, 3, 4), Self::Moderate => (256 * 1024, 4, 4), Self::Strong => (512 * 1024, 5, 4), Self::Paranoid => (1024 * 1024, 10, 8), };
85
86 ParamsBuilder::new()
87 .m_cost(m_cost)
88 .t_cost(t_cost)
89 .p_cost(p_cost)
90 .output_len(OUTPUT_LENGTH)
91 .build()
92 .map_err(|e| PbkdfError::Argon2Error(e.to_string()))
93 }
94}
95
96pub struct PasswordKeyDerivation {
98 params: Params,
99}
100
101impl Default for PasswordKeyDerivation {
102 fn default() -> Self {
103 Self::new(KeyDerivationStrength::Interactive)
104 }
105}
106
107impl PasswordKeyDerivation {
108 pub fn new(strength: KeyDerivationStrength) -> Self {
110 let params = strength.params().expect("Invalid parameters");
111 Self { params }
112 }
113
114 pub fn with_params(m_cost: u32, t_cost: u32, p_cost: u32) -> Result<Self, PbkdfError> {
116 let params = ParamsBuilder::new()
117 .m_cost(m_cost)
118 .t_cost(t_cost)
119 .p_cost(p_cost)
120 .output_len(OUTPUT_LENGTH)
121 .build()
122 .map_err(|e| PbkdfError::InvalidParams(e.to_string()))?;
123
124 Ok(Self { params })
125 }
126
127 pub fn derive_key(&self, password: &str) -> Result<(EncryptionKey, Vec<u8>), PbkdfError> {
131 if password.is_empty() {
132 return Err(PbkdfError::InvalidPassword);
133 }
134
135 let salt = SaltString::generate(&mut rand::thread_rng());
137
138 let key = self.derive_key_with_salt(password, salt.as_str())?;
140
141 Ok((key, salt.as_str().as_bytes().to_vec()))
142 }
143
144 pub fn derive_key_with_salt(
146 &self,
147 password: &str,
148 salt: &str,
149 ) -> Result<EncryptionKey, PbkdfError> {
150 if password.is_empty() {
151 return Err(PbkdfError::InvalidPassword);
152 }
153
154 let argon2 = Argon2::new(Algorithm::Argon2id, Version::V0x13, self.params.clone());
155
156 let password_bytes = Zeroizing::new(password.as_bytes().to_vec());
158
159 let salt_string = SaltString::from_b64(salt).map_err(|_| PbkdfError::InvalidSalt)?;
161
162 let hash = argon2
164 .hash_password(&password_bytes, &salt_string)
165 .map_err(|e| PbkdfError::Argon2Error(e.to_string()))?;
166
167 let hash_bytes = hash
169 .hash
170 .ok_or_else(|| PbkdfError::Argon2Error("No hash output".to_string()))?;
171
172 if hash_bytes.len() != OUTPUT_LENGTH {
173 return Err(PbkdfError::Argon2Error(format!(
174 "Invalid output length: {} (expected {})",
175 hash_bytes.len(),
176 OUTPUT_LENGTH
177 )));
178 }
179
180 let mut key = [0u8; 32];
181 key.copy_from_slice(hash_bytes.as_bytes());
182 Ok(key)
183 }
184
185 pub fn hash_password(&self, password: &str) -> Result<String, PbkdfError> {
187 if password.is_empty() {
188 return Err(PbkdfError::InvalidPassword);
189 }
190
191 let argon2 = Argon2::new(Algorithm::Argon2id, Version::V0x13, self.params.clone());
192 let salt = SaltString::generate(&mut rand::thread_rng());
193 let password_bytes = Zeroizing::new(password.as_bytes().to_vec());
194
195 let hash = argon2
196 .hash_password(&password_bytes, &salt)
197 .map_err(|e| PbkdfError::Argon2Error(e.to_string()))?;
198
199 Ok(hash.to_string())
200 }
201
202 pub fn verify_password(password: &str, hash: &str) -> Result<(), PbkdfError> {
204 if password.is_empty() {
205 return Err(PbkdfError::InvalidPassword);
206 }
207
208 let parsed_hash =
209 PasswordHash::new(hash).map_err(|e| PbkdfError::Argon2Error(e.to_string()))?;
210
211 let password_bytes = Zeroizing::new(password.as_bytes().to_vec());
212
213 Argon2::default()
215 .verify_password(&password_bytes, &parsed_hash)
216 .map_err(|_| PbkdfError::VerificationFailed)
217 }
218}
219
220pub fn derive_key_from_password(password: &str) -> Result<(EncryptionKey, Vec<u8>), PbkdfError> {
226 PasswordKeyDerivation::default().derive_key(password)
227}
228
229pub fn derive_key_with_salt(password: &str, salt: &str) -> Result<EncryptionKey, PbkdfError> {
231 PasswordKeyDerivation::default().derive_key_with_salt(password, salt)
232}
233
234pub fn generate_salt() -> Vec<u8> {
236 let mut salt = vec![0u8; 16];
237 rand::thread_rng().fill_bytes(&mut salt);
238 salt
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// Cheapest preset, used where a test does not depend on the default costs.
    fn fast_kdf() -> PasswordKeyDerivation {
        PasswordKeyDerivation::new(KeyDerivationStrength::Fast)
    }

    #[test]
    fn test_derive_key_from_password() {
        let password = "correct horse battery staple";
        let (key1, salt) = derive_key_from_password(password).unwrap();

        assert_eq!(key1.len(), 32);
        assert!(!salt.is_empty());

        // The returned salt is the B64 salt string, so it round-trips through &str.
        let salt_str = std::str::from_utf8(&salt).unwrap();
        let key2 = derive_key_with_salt(password, salt_str).unwrap();
        assert_eq!(key1, key2);
    }

    #[test]
    fn test_different_passwords_different_keys() {
        // Reuse one salt so the keys can only differ because of the password.
        // With two random salts the keys would differ regardless.
        let (key1, salt) = derive_key_from_password("password1").unwrap();
        let salt_str = std::str::from_utf8(&salt).unwrap();
        let key2 = derive_key_with_salt("password2", salt_str).unwrap();

        assert_ne!(key1, key2);
    }

    #[test]
    fn test_different_salts_different_keys() {
        let pbkdf = fast_kdf();
        let (key1, salt1) = pbkdf.derive_key("password").unwrap();
        let (key2, salt2) = pbkdf.derive_key("password").unwrap();

        assert_ne!(salt1, salt2);
        assert_ne!(key1, key2);
    }

    #[test]
    fn test_password_hashing() {
        let pbkdf = PasswordKeyDerivation::default();
        let password = "my secret password";

        let hash = pbkdf.hash_password(password).unwrap();
        assert!(hash.starts_with("$argon2id$"));

        assert!(PasswordKeyDerivation::verify_password(password, &hash).is_ok());

        assert!(matches!(
            PasswordKeyDerivation::verify_password("wrong password", &hash),
            Err(PbkdfError::VerificationFailed)
        ));
    }

    #[test]
    fn test_verify_hash_with_custom_params() {
        // verify_password is not tied to an instance, so a hash created with
        // non-default costs must still verify.
        let pbkdf = PasswordKeyDerivation::with_params(4096, 2, 1).unwrap();
        let hash = pbkdf.hash_password("custom").unwrap();

        assert!(PasswordKeyDerivation::verify_password("custom", &hash).is_ok());
        assert!(PasswordKeyDerivation::verify_password("other", &hash).is_err());
    }

    #[test]
    fn test_verify_malformed_hash() {
        assert!(matches!(
            PasswordKeyDerivation::verify_password("password", "not a phc string"),
            Err(PbkdfError::Argon2Error(_))
        ));
    }

    #[test]
    fn test_strength_levels() {
        let password = "test password";

        // Paranoid is omitted: its m_cost of 1024 * 1024 KiB needs about
        // 1 GiB of memory per derivation.
        for strength in &[
            KeyDerivationStrength::Fast,
            KeyDerivationStrength::Interactive,
            KeyDerivationStrength::Moderate,
            KeyDerivationStrength::Strong,
        ] {
            let pbkdf = PasswordKeyDerivation::new(*strength);
            let (key, salt) = pbkdf.derive_key(password).unwrap();

            assert_eq!(key.len(), 32);
            assert!(!salt.is_empty());
        }
    }

    #[test]
    fn test_empty_password() {
        assert!(matches!(
            derive_key_from_password(""),
            Err(PbkdfError::InvalidPassword)
        ));

        let pbkdf = fast_kdf();
        assert!(matches!(
            pbkdf.derive_key_with_salt("", "c29tZXNhbHQ"),
            Err(PbkdfError::InvalidPassword)
        ));
        assert!(matches!(
            pbkdf.hash_password(""),
            Err(PbkdfError::InvalidPassword)
        ));
        assert!(matches!(
            PasswordKeyDerivation::verify_password("", "irrelevant"),
            Err(PbkdfError::InvalidPassword)
        ));
    }

    #[test]
    fn test_invalid_salt() {
        let pbkdf = fast_kdf();
        assert!(matches!(
            pbkdf.derive_key_with_salt("password", "!!"),
            Err(PbkdfError::InvalidSalt)
        ));
    }

    #[test]
    fn test_custom_params() {
        let pbkdf = PasswordKeyDerivation::with_params(4096, 2, 1).unwrap();
        let (key, _) = pbkdf.derive_key("test").unwrap();
        assert_eq!(key.len(), 32);
    }

    #[test]
    fn test_invalid_params() {
        // A time cost of zero passes is rejected by argon2.
        assert!(matches!(
            PasswordKeyDerivation::with_params(4096, 0, 1),
            Err(PbkdfError::InvalidParams(_))
        ));
    }

    #[test]
    fn test_deterministic_derivation() {
        let password = "test password";
        let pbkdf = PasswordKeyDerivation::default();

        let (key0, salt1) = pbkdf.derive_key(password).unwrap();
        let salt_str = std::str::from_utf8(&salt1).unwrap();

        let key1 = pbkdf.derive_key_with_salt(password, salt_str).unwrap();
        let key2 = pbkdf.derive_key_with_salt(password, salt_str).unwrap();

        // The key returned by derive_key must match a re-derivation from its salt.
        assert_eq!(key0, key1);
        assert_eq!(key1, key2);
    }
}