Skip to main content

cortexai_encryption/
key.rs

1//! Encryption key management and derivation.
2
3use crate::error::{CryptoError, CryptoResult};
4use crate::traits::KeyDerivation;
5use argon2::{Argon2, Params};
6use rand::RngCore;
7use zeroize::{Zeroize, ZeroizeOnDrop};
8
9/// A secure encryption key that is zeroed on drop.
10#[derive(Clone, Zeroize, ZeroizeOnDrop)]
11pub struct EncryptionKey {
12    bytes: Vec<u8>,
13}
14
15impl EncryptionKey {
16    /// Create a new encryption key from bytes.
17    pub fn new(bytes: Vec<u8>) -> Self {
18        Self { bytes }
19    }
20
21    /// Generate a random encryption key of the specified length.
22    pub fn generate(length: usize) -> Self {
23        let mut bytes = vec![0u8; length];
24        rand::thread_rng().fill_bytes(&mut bytes);
25        Self { bytes }
26    }
27
28    /// Create a key from a base64-encoded string.
29    pub fn from_base64(encoded: &str) -> CryptoResult<Self> {
30        use base64::{engine::general_purpose::STANDARD, Engine};
31        let bytes = STANDARD.decode(encoded)?;
32        Ok(Self { bytes })
33    }
34
35    /// Encode the key as base64.
36    pub fn to_base64(&self) -> String {
37        use base64::{engine::general_purpose::STANDARD, Engine};
38        STANDARD.encode(&self.bytes)
39    }
40
41    /// Get the key bytes.
42    pub fn as_bytes(&self) -> &[u8] {
43        &self.bytes
44    }
45
46    /// Get the key length.
47    pub fn len(&self) -> usize {
48        self.bytes.len()
49    }
50
51    /// Check if the key is empty.
52    pub fn is_empty(&self) -> bool {
53        self.bytes.is_empty()
54    }
55}
56
57impl std::fmt::Debug for EncryptionKey {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        f.debug_struct("EncryptionKey")
60            .field("len", &self.bytes.len())
61            .field("bytes", &"[REDACTED]")
62            .finish()
63    }
64}
65
66/// Argon2id key derivation function.
67///
68/// Uses Argon2id which provides resistance against both side-channel
69/// and GPU-based attacks.
70pub struct Argon2KeyDerivation {
71    params: Params,
72}
73
74impl Default for Argon2KeyDerivation {
75    fn default() -> Self {
76        Self::new()
77    }
78}
79
80impl Argon2KeyDerivation {
81    /// Create with default parameters (OWASP recommended).
82    pub fn new() -> Self {
83        // OWASP recommended: m=19456 (19 MiB), t=2, p=1
84        let params = Params::new(19456, 2, 1, Some(32)).expect("valid params");
85        Self { params }
86    }
87
88    /// Create with custom parameters.
89    pub fn with_params(memory_kib: u32, iterations: u32, parallelism: u32) -> CryptoResult<Self> {
90        let params = Params::new(memory_kib, iterations, parallelism, Some(32))
91            .map_err(|e| CryptoError::KeyDerivationFailed(e.to_string()))?;
92        Ok(Self { params })
93    }
94
95    /// Derive an encryption key from a password.
96    pub fn derive_encryption_key(
97        &self,
98        password: &[u8],
99        salt: &[u8],
100        key_length: usize,
101    ) -> CryptoResult<EncryptionKey> {
102        let key_bytes = self.derive_key(password, salt, key_length)?;
103        Ok(EncryptionKey::new(key_bytes))
104    }
105}
106
107impl KeyDerivation for Argon2KeyDerivation {
108    fn derive_key(&self, password: &[u8], salt: &[u8], key_length: usize) -> CryptoResult<Vec<u8>> {
109        let argon2 = Argon2::new(
110            argon2::Algorithm::Argon2id,
111            argon2::Version::V0x13,
112            self.params.clone(),
113        );
114
115        let mut output = vec![0u8; key_length];
116        argon2
117            .hash_password_into(password, salt, &mut output)
118            .map_err(|e| CryptoError::KeyDerivationFailed(e.to_string()))?;
119
120        Ok(output)
121    }
122
123    fn generate_salt(&self, length: usize) -> Vec<u8> {
124        let mut salt = vec![0u8; length];
125        rand::thread_rng().fill_bytes(&mut salt);
126        salt
127    }
128
129    fn algorithm(&self) -> &'static str {
130        "argon2id"
131    }
132}
133
134/// Versioned key for key rotation support.
135#[derive(Clone)]
136pub struct VersionedKey {
137    /// Key version number
138    pub version: u32,
139    /// The encryption key
140    pub key: EncryptionKey,
141    /// When this key was created (Unix timestamp)
142    pub created_at: u64,
143    /// Whether this key is active for new encryptions
144    pub active: bool,
145}
146
147impl VersionedKey {
148    /// Create a new versioned key.
149    pub fn new(version: u32, key: EncryptionKey) -> Self {
150        Self {
151            version,
152            key,
153            created_at: std::time::SystemTime::now()
154                .duration_since(std::time::UNIX_EPOCH)
155                .unwrap_or_default()
156                .as_secs(),
157            active: true,
158        }
159    }
160}
161
162impl std::fmt::Debug for VersionedKey {
163    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164        f.debug_struct("VersionedKey")
165            .field("version", &self.version)
166            .field("created_at", &self.created_at)
167            .field("active", &self.active)
168            .field("key", &"[REDACTED]")
169            .finish()
170    }
171}
172
173/// Key ring for managing multiple key versions.
174///
175/// Supports key rotation by maintaining multiple key versions.
176/// New encryptions use the active key, decryptions use the version
177/// embedded in the ciphertext.
178#[derive(Default)]
179pub struct KeyRing {
180    keys: Vec<VersionedKey>,
181}
182
183impl KeyRing {
184    /// Create an empty key ring.
185    pub fn new() -> Self {
186        Self { keys: Vec::new() }
187    }
188
189    /// Add a key to the ring.
190    pub fn add_key(&mut self, key: VersionedKey) {
191        // Deactivate any previously active keys
192        if key.active {
193            for k in &mut self.keys {
194                k.active = false;
195            }
196        }
197        self.keys.push(key);
198    }
199
200    /// Get the active key for encryption.
201    pub fn active_key(&self) -> Option<&VersionedKey> {
202        self.keys.iter().find(|k| k.active)
203    }
204
205    /// Get a key by version for decryption.
206    pub fn get_key(&self, version: u32) -> Option<&VersionedKey> {
207        self.keys.iter().find(|k| k.version == version)
208    }
209
210    /// Rotate to a new key.
211    pub fn rotate(&mut self, new_key: EncryptionKey) -> u32 {
212        let new_version = self.keys.iter().map(|k| k.version).max().unwrap_or(0) + 1;
213        self.add_key(VersionedKey::new(new_version, new_key));
214        new_version
215    }
216
217    /// Get all keys (for re-encryption during rotation).
218    pub fn all_keys(&self) -> &[VersionedKey] {
219        &self.keys
220    }
221
222    /// Number of keys in the ring.
223    pub fn len(&self) -> usize {
224        self.keys.len()
225    }
226
227    /// Check if the ring is empty.
228    pub fn is_empty(&self) -> bool {
229        self.keys.is_empty()
230    }
231}
232
233impl std::fmt::Debug for KeyRing {
234    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
235        f.debug_struct("KeyRing")
236            .field("num_keys", &self.keys.len())
237            .field("active_version", &self.active_key().map(|k| k.version))
238            .finish()
239    }
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encryption_key_generate() {
        let generated = EncryptionKey::generate(32);
        assert!(!generated.is_empty());
        assert_eq!(generated.len(), 32);
    }

    #[test]
    fn test_encryption_key_base64_roundtrip() {
        let original = EncryptionKey::generate(32);
        let restored = EncryptionKey::from_base64(&original.to_base64())
            .expect("base64 produced by to_base64 must decode");
        assert_eq!(original.as_bytes(), restored.as_bytes());
    }

    #[test]
    fn test_argon2_key_derivation() {
        let kdf = Argon2KeyDerivation::new();
        let password: &[u8] = b"test-password";
        let salt = kdf.generate_salt(16);

        let first = kdf.derive_key(password, &salt, 32).unwrap();
        let second = kdf.derive_key(password, &salt, 32).unwrap();

        // Derivation is deterministic for a fixed password and salt.
        assert_eq!(first, second);
        assert_eq!(first.len(), 32);
    }

    #[test]
    fn test_argon2_different_salts() {
        let kdf = Argon2KeyDerivation::new();
        let password: &[u8] = b"test-password";
        let salt_a = kdf.generate_salt(16);
        let salt_b = kdf.generate_salt(16);

        let key_a = kdf.derive_key(password, &salt_a, 32).unwrap();
        let key_b = kdf.derive_key(password, &salt_b, 32).unwrap();

        // Distinct salts must yield distinct keys.
        assert_ne!(key_a, key_b);
    }

    #[test]
    fn test_key_ring_rotation() {
        let mut ring = KeyRing::new();
        ring.add_key(VersionedKey::new(1, EncryptionKey::generate(32)));
        assert_eq!(ring.active_key().map(|k| k.version), Some(1));

        let rotated_to = ring.rotate(EncryptionKey::generate(32));
        assert_eq!(rotated_to, 2);
        assert_eq!(ring.active_key().map(|k| k.version), Some(2));
        assert_eq!(ring.len(), 2);

        // The superseded key stays available for decryption but is inactive.
        let retired = ring
            .get_key(1)
            .expect("version 1 must remain in the ring");
        assert!(!retired.active);
    }
}