Skip to main content

firecloud_crypto/
kek.rs

1//! Key Encryption Key (KEK) management
2//!
3//! Provides functionality for:
4//! - Deriving KEK from user password using Argon2id
5//! - Encrypting/decrypting Data Encryption Keys (DEK)
6//! - Secure key storage and handling
7
8use crate::{CryptoError, CryptoResult};
9use argon2::{
10    password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
11    Argon2, Params, Version,
12};
13use chacha20poly1305::{
14    aead::{Aead, KeyInit, OsRng},
15    XChaCha20Poly1305, XNonce,
16};
17use rand::RngCore;
18use std::fmt;
19use zeroize::{Zeroize, ZeroizeOnDrop};
20
/// Length of a [`Kek`] in bytes: 256 bits.
pub const KEK_SIZE: usize = 256 / 8;

/// Length of the password-hashing salt in bytes: 128 bits.
pub const SALT_SIZE: usize = 128 / 8;

/// Length of an XChaCha20-Poly1305 nonce in bytes: 192 bits.
pub const NONCE_SIZE: usize = 192 / 8;
29
30/// Key Encryption Key derived from user password
31#[derive(Clone, Zeroize, ZeroizeOnDrop)]
32pub struct Kek {
33    key: [u8; KEK_SIZE],
34}
35
36impl Kek {
37    /// Derive a KEK from a password using Argon2id
38    ///
39    /// Uses strong parameters:
40    /// - Memory: 64 MB
41    /// - Iterations: 3
42    /// - Parallelism: 4 threads
43    pub fn derive_from_password(password: &str, salt: &[u8]) -> CryptoResult<Self> {
44        if salt.len() != SALT_SIZE {
45            return Err(CryptoError::InvalidKeySize {
46                expected: SALT_SIZE,
47                actual: salt.len(),
48            });
49        }
50
51        // Configure Argon2id with strong parameters
52        // Memory: 64 MB, Iterations: 3, Parallelism: 4
53        let params = Params::new(65536, 3, 4, Some(KEK_SIZE))
54            .map_err(|e| CryptoError::KeyDerivation(e.to_string()))?;
55
56        let argon2 = Argon2::new(
57            argon2::Algorithm::Argon2id,
58            Version::V0x13,
59            params,
60        );
61
62        // Derive key from password
63        let mut key = [0u8; KEK_SIZE];
64        argon2
65            .hash_password_into(password.as_bytes(), salt, &mut key)
66            .map_err(|e| CryptoError::KeyDerivation(e.to_string()))?;
67
68        Ok(Self { key })
69    }
70
71    /// Create a KEK directly from raw key bytes (for testing)
72    pub fn from_bytes(bytes: [u8; KEK_SIZE]) -> Self {
73        Self { key: bytes }
74    }
75
76    /// Get the raw key bytes (use with caution)
77    pub fn as_bytes(&self) -> &[u8; KEK_SIZE] {
78        &self.key
79    }
80
81    /// Encrypt a Data Encryption Key (DEK) with this KEK
82    ///
83    /// Returns: (nonce || ciphertext)
84    pub fn encrypt_dek(&self, dek: &[u8]) -> CryptoResult<Vec<u8>> {
85        let cipher = XChaCha20Poly1305::new(&self.key.into());
86
87        // Generate random nonce
88        let mut nonce_bytes = [0u8; NONCE_SIZE];
89        OsRng.fill_bytes(&mut nonce_bytes);
90        let nonce = XNonce::from_slice(&nonce_bytes);
91
92        // Encrypt the DEK
93        let ciphertext = cipher
94            .encrypt(nonce, dek)
95            .map_err(|e| CryptoError::Encryption(e.to_string()))?;
96
97        // Prepend nonce to ciphertext
98        let mut result = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
99        result.extend_from_slice(&nonce_bytes);
100        result.extend_from_slice(&ciphertext);
101
102        Ok(result)
103    }
104
105    /// Decrypt a Data Encryption Key (DEK) with this KEK
106    ///
107    /// Input: (nonce || ciphertext)
108    pub fn decrypt_dek(&self, encrypted_dek: &[u8]) -> CryptoResult<Vec<u8>> {
109        if encrypted_dek.len() < NONCE_SIZE {
110            return Err(CryptoError::Decryption(
111                "Encrypted DEK too short".to_string(),
112            ));
113        }
114
115        let cipher = XChaCha20Poly1305::new(&self.key.into());
116
117        // Extract nonce and ciphertext
118        let (nonce_bytes, ciphertext) = encrypted_dek.split_at(NONCE_SIZE);
119        let nonce = XNonce::from_slice(nonce_bytes);
120
121        // Decrypt
122        let dek = cipher
123            .decrypt(nonce, ciphertext)
124            .map_err(|e| CryptoError::Decryption(e.to_string()))?;
125
126        Ok(dek)
127    }
128}
129
130impl fmt::Debug for Kek {
131    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132        f.debug_struct("Kek")
133            .field("key", &"[REDACTED]")
134            .finish()
135    }
136}
137
138/// Generate a random salt for password hashing
139pub fn generate_salt() -> [u8; SALT_SIZE] {
140    let mut salt = [0u8; SALT_SIZE];
141    OsRng.fill_bytes(&mut salt);
142    salt
143}
144
145/// Hash a password for storage (verification only, not for encryption)
146pub fn hash_password(password: &str) -> CryptoResult<String> {
147    let salt = SaltString::generate(&mut OsRng);
148    
149    let params = Params::new(65536, 3, 4, None)
150        .map_err(|e| CryptoError::KeyDerivation(e.to_string()))?;
151    
152    let argon2 = Argon2::new(
153        argon2::Algorithm::Argon2id,
154        Version::V0x13,
155        params,
156    );
157
158    let hash = argon2
159        .hash_password(password.as_bytes(), &salt)
160        .map_err(|e| CryptoError::KeyDerivation(e.to_string()))?
161        .to_string();
162
163    Ok(hash)
164}
165
166/// Verify a password against a stored hash
167pub fn verify_password(password: &str, hash: &str) -> CryptoResult<bool> {
168    let parsed_hash = PasswordHash::new(hash)
169        .map_err(|e| CryptoError::KeyDerivation(e.to_string()))?;
170
171    let argon2 = Argon2::default();
172
173    Ok(argon2
174        .verify_password(password.as_bytes(), &parsed_hash)
175        .is_ok())
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    /// Length of the Poly1305 tag appended to every ciphertext.
    const TAG_LEN: usize = 16;

    /// Fixed KEK for tests that do not exercise Argon2, avoiding its 64 MiB cost.
    fn fixed_kek() -> Kek {
        Kek::from_bytes([0xAB; KEK_SIZE])
    }

    #[test]
    fn test_kek_derive_from_password() {
        let password = "super_secret_password";
        let salt = generate_salt();

        let kek1 = Kek::derive_from_password(password, &salt).unwrap();
        let kek2 = Kek::derive_from_password(password, &salt).unwrap();

        // Same password + salt should produce same KEK
        assert_eq!(kek1.as_bytes(), kek2.as_bytes());

        // Different salt should produce different KEK
        let different_salt = generate_salt();
        let kek3 = Kek::derive_from_password(password, &different_salt).unwrap();
        assert_ne!(kek1.as_bytes(), kek3.as_bytes());
    }

    #[test]
    fn test_invalid_salt_length_rejected() {
        let result = Kek::derive_from_password("pw", &[0u8; SALT_SIZE - 1]);
        assert!(matches!(
            result,
            Err(CryptoError::InvalidKeySize { expected, actual, .. })
                if expected == SALT_SIZE && actual == SALT_SIZE - 1
        ));
    }

    #[test]
    fn test_encrypt_decrypt_dek() {
        let password = "test_password";
        let salt = generate_salt();
        let kek = Kek::derive_from_password(password, &salt).unwrap();

        // Sample DEK (32 bytes)
        let dek = b"this_is_a_32_byte_dek_key_123456";

        // Encrypt
        let encrypted = kek.encrypt_dek(dek).unwrap();
        assert!(encrypted.len() > dek.len()); // Should be larger (nonce + tag)

        // Decrypt
        let decrypted = kek.decrypt_dek(&encrypted).unwrap();
        assert_eq!(&decrypted[..], &dek[..]);
    }

    #[test]
    fn test_encrypted_dek_layout_and_fresh_nonce() {
        let kek = fixed_kek();
        let dek = [0x11u8; 32];

        let first = kek.encrypt_dek(&dek).unwrap();
        let second = kek.encrypt_dek(&dek).unwrap();

        // nonce || ciphertext || tag
        assert_eq!(first.len(), NONCE_SIZE + dek.len() + TAG_LEN);
        // A fresh random nonce per call (collision is astronomically unlikely).
        assert_ne!(&first[..NONCE_SIZE], &second[..NONCE_SIZE]);
        assert_eq!(kek.decrypt_dek(&second).unwrap(), dek.to_vec());
    }

    #[test]
    fn test_empty_dek_roundtrip() {
        let kek = fixed_kek();
        let encrypted = kek.encrypt_dek(&[]).unwrap();
        assert_eq!(encrypted.len(), NONCE_SIZE + TAG_LEN);
        assert!(kek.decrypt_dek(&encrypted).unwrap().is_empty());
    }

    #[test]
    fn test_wrong_password() {
        let salt = generate_salt();

        let kek1 = Kek::derive_from_password("password1", &salt).unwrap();
        let kek2 = Kek::derive_from_password("password2", &salt).unwrap();

        let dek = b"sample_dek_key_32_bytes_long_123";
        let encrypted = kek1.encrypt_dek(dek).unwrap();

        // Trying to decrypt with wrong KEK should fail
        let result = kek2.decrypt_dek(&encrypted);
        assert!(result.is_err());
    }

    #[test]
    fn test_tampered_ciphertext_rejected() {
        let kek = fixed_kek();
        let mut encrypted = kek.encrypt_dek(b"sample_dek_key_32_bytes_long_123").unwrap();

        // Flip one bit of the tag; authentication must fail.
        let last = encrypted.len() - 1;
        encrypted[last] ^= 0x01;
        assert!(kek.decrypt_dek(&encrypted).is_err());
    }

    #[test]
    fn test_truncated_input_rejected() {
        let kek = fixed_kek();
        let encrypted = kek.encrypt_dek(b"dek").unwrap();

        assert!(kek.decrypt_dek(&[]).is_err());
        assert!(kek.decrypt_dek(&encrypted[..NONCE_SIZE]).is_err());
        assert!(kek.decrypt_dek(&encrypted[..encrypted.len() - 1]).is_err());
    }

    #[test]
    fn test_password_hashing() {
        let password = "my_secure_password";

        let hash = hash_password(password).unwrap();
        assert!(hash.starts_with("$argon2id$"));

        // Verify correct password
        assert!(verify_password(password, &hash).unwrap());

        // Verify wrong password
        assert!(!verify_password("wrong_password", &hash).unwrap());
    }

    #[test]
    fn test_verify_password_malformed_hash_is_error() {
        assert!(verify_password("pw", "not-a-phc-hash").is_err());
    }

    #[test]
    fn test_salt_generation() {
        let salt1 = generate_salt();
        let salt2 = generate_salt();

        // Should generate different salts
        assert_ne!(salt1, salt2);
        assert_eq!(salt1.len(), SALT_SIZE);
    }

    #[test]
    fn test_debug_redacts_key() {
        let rendered = format!("{:?}", fixed_kek());
        assert!(rendered.contains("[REDACTED]"));
        // 0xAB == 171; the raw byte values must never appear.
        assert!(!rendered.contains("171"));
    }
}