1use crate::{CryptoError, CryptoResult};
9use argon2::{
10 password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
11 Argon2, Params, Version,
12};
13use chacha20poly1305::{
14 aead::{Aead, KeyInit, OsRng},
15 XChaCha20Poly1305, XNonce,
16};
17use rand::RngCore;
18use std::fmt;
19use zeroize::{Zeroize, ZeroizeOnDrop};
20
/// Length in bytes of a key-encryption key (KEK). This is both the Argon2 output
/// length requested by `Kek::derive_from_password` and the XChaCha20-Poly1305 key size.
pub const KEK_SIZE: usize = 32;

/// Exact salt length in bytes required by `Kek::derive_from_password` and produced
/// by `generate_salt`.
pub const SALT_SIZE: usize = 16;

/// Length in bytes of the random XChaCha20-Poly1305 nonce that `Kek::encrypt_dek`
/// prepends to every encrypted DEK.
pub const NONCE_SIZE: usize = 24;
29
/// A 256-bit (`KEK_SIZE`-byte) key-encryption key used to wrap and unwrap
/// data-encryption keys (DEKs).
///
/// The key bytes are overwritten with zeros when a value is dropped
/// (`ZeroizeOnDrop`). Note that `Clone` creates an independent copy of the secret,
/// and each copy is wiped separately when it is dropped.
#[derive(Clone, Zeroize, ZeroizeOnDrop)]
pub struct Kek {
    // Raw secret key material; intentionally not exposed via `Debug`.
    key: [u8; KEK_SIZE],
}
35
36impl Kek {
37 pub fn derive_from_password(password: &str, salt: &[u8]) -> CryptoResult<Self> {
44 if salt.len() != SALT_SIZE {
45 return Err(CryptoError::InvalidKeySize {
46 expected: SALT_SIZE,
47 actual: salt.len(),
48 });
49 }
50
51 let params = Params::new(65536, 3, 4, Some(KEK_SIZE))
54 .map_err(|e| CryptoError::KeyDerivation(e.to_string()))?;
55
56 let argon2 = Argon2::new(
57 argon2::Algorithm::Argon2id,
58 Version::V0x13,
59 params,
60 );
61
62 let mut key = [0u8; KEK_SIZE];
64 argon2
65 .hash_password_into(password.as_bytes(), salt, &mut key)
66 .map_err(|e| CryptoError::KeyDerivation(e.to_string()))?;
67
68 Ok(Self { key })
69 }
70
71 pub fn from_bytes(bytes: [u8; KEK_SIZE]) -> Self {
73 Self { key: bytes }
74 }
75
76 pub fn as_bytes(&self) -> &[u8; KEK_SIZE] {
78 &self.key
79 }
80
81 pub fn encrypt_dek(&self, dek: &[u8]) -> CryptoResult<Vec<u8>> {
85 let cipher = XChaCha20Poly1305::new(&self.key.into());
86
87 let mut nonce_bytes = [0u8; NONCE_SIZE];
89 OsRng.fill_bytes(&mut nonce_bytes);
90 let nonce = XNonce::from_slice(&nonce_bytes);
91
92 let ciphertext = cipher
94 .encrypt(nonce, dek)
95 .map_err(|e| CryptoError::Encryption(e.to_string()))?;
96
97 let mut result = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
99 result.extend_from_slice(&nonce_bytes);
100 result.extend_from_slice(&ciphertext);
101
102 Ok(result)
103 }
104
105 pub fn decrypt_dek(&self, encrypted_dek: &[u8]) -> CryptoResult<Vec<u8>> {
109 if encrypted_dek.len() < NONCE_SIZE {
110 return Err(CryptoError::Decryption(
111 "Encrypted DEK too short".to_string(),
112 ));
113 }
114
115 let cipher = XChaCha20Poly1305::new(&self.key.into());
116
117 let (nonce_bytes, ciphertext) = encrypted_dek.split_at(NONCE_SIZE);
119 let nonce = XNonce::from_slice(nonce_bytes);
120
121 let dek = cipher
123 .decrypt(nonce, ciphertext)
124 .map_err(|e| CryptoError::Decryption(e.to_string()))?;
125
126 Ok(dek)
127 }
128}
129
130impl fmt::Debug for Kek {
131 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132 f.debug_struct("Kek")
133 .field("key", &"[REDACTED]")
134 .finish()
135 }
136}
137
138pub fn generate_salt() -> [u8; SALT_SIZE] {
140 let mut salt = [0u8; SALT_SIZE];
141 OsRng.fill_bytes(&mut salt);
142 salt
143}
144
145pub fn hash_password(password: &str) -> CryptoResult<String> {
147 let salt = SaltString::generate(&mut OsRng);
148
149 let params = Params::new(65536, 3, 4, None)
150 .map_err(|e| CryptoError::KeyDerivation(e.to_string()))?;
151
152 let argon2 = Argon2::new(
153 argon2::Algorithm::Argon2id,
154 Version::V0x13,
155 params,
156 );
157
158 let hash = argon2
159 .hash_password(password.as_bytes(), &salt)
160 .map_err(|e| CryptoError::KeyDerivation(e.to_string()))?
161 .to_string();
162
163 Ok(hash)
164}
165
166pub fn verify_password(password: &str, hash: &str) -> CryptoResult<bool> {
168 let parsed_hash = PasswordHash::new(hash)
169 .map_err(|e| CryptoError::KeyDerivation(e.to_string()))?;
170
171 let argon2 = Argon2::default();
172
173 Ok(argon2
174 .verify_password(password.as_bytes(), &parsed_hash)
175 .is_ok())
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kek_derive_from_password() {
        let password = "super_secret_password";
        let salt = generate_salt();

        let kek1 = Kek::derive_from_password(password, &salt).unwrap();
        let kek2 = Kek::derive_from_password(password, &salt).unwrap();

        // Same password + salt must be deterministic.
        assert_eq!(kek1.as_bytes(), kek2.as_bytes());

        // A different salt must yield a different key.
        let different_salt = generate_salt();
        let kek3 = Kek::derive_from_password(password, &different_salt).unwrap();
        assert_ne!(kek1.as_bytes(), kek3.as_bytes());
    }

    #[test]
    fn test_derive_rejects_wrong_salt_length() {
        assert!(Kek::derive_from_password("pw", &[]).is_err());
        assert!(Kek::derive_from_password("pw", &[0u8; SALT_SIZE - 1]).is_err());
        assert!(Kek::derive_from_password("pw", &[0u8; SALT_SIZE + 1]).is_err());
    }

    #[test]
    fn test_encrypt_decrypt_dek() {
        let password = "test_password";
        let salt = generate_salt();
        let kek = Kek::derive_from_password(password, &salt).unwrap();

        let dek = b"this_is_a_32_byte_dek_key_123456";

        let encrypted = kek.encrypt_dek(dek).unwrap();
        // nonce || ciphertext || 16-byte Poly1305 tag
        assert_eq!(encrypted.len(), NONCE_SIZE + dek.len() + 16);

        let decrypted = kek.decrypt_dek(&encrypted).unwrap();
        assert_eq!(&decrypted[..], &dek[..]);
    }

    #[test]
    fn test_encrypt_uses_fresh_nonce() {
        let kek = Kek::from_bytes([7u8; KEK_SIZE]);
        let dek = b"sample_dek_key_32_bytes_long_123";

        let first = kek.encrypt_dek(dek).unwrap();
        let second = kek.encrypt_dek(dek).unwrap();

        assert_ne!(&first[..NONCE_SIZE], &second[..NONCE_SIZE]);
        assert_ne!(first, second);
    }

    #[test]
    fn test_wrong_password() {
        let salt = generate_salt();

        let kek1 = Kek::derive_from_password("password1", &salt).unwrap();
        let kek2 = Kek::derive_from_password("password2", &salt).unwrap();

        let dek = b"sample_dek_key_32_bytes_long_123";
        let encrypted = kek1.encrypt_dek(dek).unwrap();

        let result = kek2.decrypt_dek(&encrypted);
        assert!(result.is_err());
    }

    #[test]
    fn test_decrypt_rejects_short_input() {
        let kek = Kek::from_bytes([7u8; KEK_SIZE]);

        assert!(kek.decrypt_dek(&[]).is_err());
        assert!(kek.decrypt_dek(&[0u8; NONCE_SIZE - 1]).is_err());
        // A bare nonce with no tag can never authenticate.
        assert!(kek.decrypt_dek(&[0u8; NONCE_SIZE]).is_err());
    }

    #[test]
    fn test_decrypt_rejects_tampered_ciphertext() {
        let kek = Kek::from_bytes([7u8; KEK_SIZE]);
        let mut encrypted = kek.encrypt_dek(b"sample_dek").unwrap();

        let last = encrypted.len() - 1;
        encrypted[last] ^= 0x01;

        assert!(kek.decrypt_dek(&encrypted).is_err());
    }

    #[test]
    fn test_debug_redacts_key() {
        let kek = Kek::from_bytes([0xAB; KEK_SIZE]);
        assert_eq!(format!("{:?}", kek), "Kek { key: \"[REDACTED]\" }");
    }

    #[test]
    fn test_password_hashing() {
        let password = "my_secure_password";

        let hash = hash_password(password).unwrap();
        assert!(hash.starts_with("$argon2id$"));

        assert!(verify_password(password, &hash).unwrap());

        assert!(!verify_password("wrong_password", &hash).unwrap());
    }

    #[test]
    fn test_verify_rejects_malformed_hash() {
        assert!(verify_password("pw", "not-a-phc-string").is_err());
    }

    #[test]
    fn test_salt_generation() {
        let salt1 = generate_salt();
        let salt2 = generate_salt();

        assert_ne!(salt1, salt2);
        assert_eq!(salt1.len(), SALT_SIZE);
    }
}