cortexai_encryption/
key.rs1use crate::error::{CryptoError, CryptoResult};
4use crate::traits::KeyDerivation;
5use argon2::{Argon2, Params};
6use rand::RngCore;
7use zeroize::{Zeroize, ZeroizeOnDrop};
8
9#[derive(Clone, Zeroize, ZeroizeOnDrop)]
11pub struct EncryptionKey {
12 bytes: Vec<u8>,
13}
14
15impl EncryptionKey {
16 pub fn new(bytes: Vec<u8>) -> Self {
18 Self { bytes }
19 }
20
21 pub fn generate(length: usize) -> Self {
23 let mut bytes = vec![0u8; length];
24 rand::thread_rng().fill_bytes(&mut bytes);
25 Self { bytes }
26 }
27
28 pub fn from_base64(encoded: &str) -> CryptoResult<Self> {
30 use base64::{engine::general_purpose::STANDARD, Engine};
31 let bytes = STANDARD.decode(encoded)?;
32 Ok(Self { bytes })
33 }
34
35 pub fn to_base64(&self) -> String {
37 use base64::{engine::general_purpose::STANDARD, Engine};
38 STANDARD.encode(&self.bytes)
39 }
40
41 pub fn as_bytes(&self) -> &[u8] {
43 &self.bytes
44 }
45
46 pub fn len(&self) -> usize {
48 self.bytes.len()
49 }
50
51 pub fn is_empty(&self) -> bool {
53 self.bytes.is_empty()
54 }
55}
56
57impl std::fmt::Debug for EncryptionKey {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 f.debug_struct("EncryptionKey")
60 .field("len", &self.bytes.len())
61 .field("bytes", &"[REDACTED]")
62 .finish()
63 }
64}
65
66pub struct Argon2KeyDerivation {
71 params: Params,
72}
73
74impl Default for Argon2KeyDerivation {
75 fn default() -> Self {
76 Self::new()
77 }
78}
79
80impl Argon2KeyDerivation {
81 pub fn new() -> Self {
83 let params = Params::new(19456, 2, 1, Some(32)).expect("valid params");
85 Self { params }
86 }
87
88 pub fn with_params(memory_kib: u32, iterations: u32, parallelism: u32) -> CryptoResult<Self> {
90 let params = Params::new(memory_kib, iterations, parallelism, Some(32))
91 .map_err(|e| CryptoError::KeyDerivationFailed(e.to_string()))?;
92 Ok(Self { params })
93 }
94
95 pub fn derive_encryption_key(
97 &self,
98 password: &[u8],
99 salt: &[u8],
100 key_length: usize,
101 ) -> CryptoResult<EncryptionKey> {
102 let key_bytes = self.derive_key(password, salt, key_length)?;
103 Ok(EncryptionKey::new(key_bytes))
104 }
105}
106
107impl KeyDerivation for Argon2KeyDerivation {
108 fn derive_key(&self, password: &[u8], salt: &[u8], key_length: usize) -> CryptoResult<Vec<u8>> {
109 let argon2 = Argon2::new(
110 argon2::Algorithm::Argon2id,
111 argon2::Version::V0x13,
112 self.params.clone(),
113 );
114
115 let mut output = vec![0u8; key_length];
116 argon2
117 .hash_password_into(password, salt, &mut output)
118 .map_err(|e| CryptoError::KeyDerivationFailed(e.to_string()))?;
119
120 Ok(output)
121 }
122
123 fn generate_salt(&self, length: usize) -> Vec<u8> {
124 let mut salt = vec![0u8; length];
125 rand::thread_rng().fill_bytes(&mut salt);
126 salt
127 }
128
129 fn algorithm(&self) -> &'static str {
130 "argon2id"
131 }
132}
133
134#[derive(Clone)]
136pub struct VersionedKey {
137 pub version: u32,
139 pub key: EncryptionKey,
141 pub created_at: u64,
143 pub active: bool,
145}
146
147impl VersionedKey {
148 pub fn new(version: u32, key: EncryptionKey) -> Self {
150 Self {
151 version,
152 key,
153 created_at: std::time::SystemTime::now()
154 .duration_since(std::time::UNIX_EPOCH)
155 .unwrap_or_default()
156 .as_secs(),
157 active: true,
158 }
159 }
160}
161
162impl std::fmt::Debug for VersionedKey {
163 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164 f.debug_struct("VersionedKey")
165 .field("version", &self.version)
166 .field("created_at", &self.created_at)
167 .field("active", &self.active)
168 .field("key", &"[REDACTED]")
169 .finish()
170 }
171}
172
173#[derive(Default)]
179pub struct KeyRing {
180 keys: Vec<VersionedKey>,
181}
182
183impl KeyRing {
184 pub fn new() -> Self {
186 Self { keys: Vec::new() }
187 }
188
189 pub fn add_key(&mut self, key: VersionedKey) {
191 if key.active {
193 for k in &mut self.keys {
194 k.active = false;
195 }
196 }
197 self.keys.push(key);
198 }
199
200 pub fn active_key(&self) -> Option<&VersionedKey> {
202 self.keys.iter().find(|k| k.active)
203 }
204
205 pub fn get_key(&self, version: u32) -> Option<&VersionedKey> {
207 self.keys.iter().find(|k| k.version == version)
208 }
209
210 pub fn rotate(&mut self, new_key: EncryptionKey) -> u32 {
212 let new_version = self.keys.iter().map(|k| k.version).max().unwrap_or(0) + 1;
213 self.add_key(VersionedKey::new(new_version, new_key));
214 new_version
215 }
216
217 pub fn all_keys(&self) -> &[VersionedKey] {
219 &self.keys
220 }
221
222 pub fn len(&self) -> usize {
224 self.keys.len()
225 }
226
227 pub fn is_empty(&self) -> bool {
229 self.keys.is_empty()
230 }
231}
232
233impl std::fmt::Debug for KeyRing {
234 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
235 f.debug_struct("KeyRing")
236 .field("num_keys", &self.keys.len())
237 .field("active_version", &self.active_key().map(|k| k.version))
238 .finish()
239 }
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    /// A generated key has the requested length and is not empty.
    #[test]
    fn test_encryption_key_generate() {
        let generated = EncryptionKey::generate(32);
        assert_eq!(generated.len(), 32);
        assert!(!generated.is_empty());
    }

    /// Encoding to Base64 and decoding again yields the same bytes.
    #[test]
    fn test_encryption_key_base64_roundtrip() {
        let original = EncryptionKey::generate(32);
        let restored = EncryptionKey::from_base64(&original.to_base64()).unwrap();
        assert_eq!(original.as_bytes(), restored.as_bytes());
    }

    /// The same password and salt always derive the same 32-byte key.
    #[test]
    fn test_argon2_key_derivation() {
        let kdf = Argon2KeyDerivation::new();
        let password = b"test-password";
        let salt = kdf.generate_salt(16);

        let first = kdf.derive_key(password, &salt, 32).unwrap();
        let second = kdf.derive_key(password, &salt, 32).unwrap();

        assert_eq!(first, second);
        assert_eq!(first.len(), 32);
    }

    /// Two independently generated salts derive different keys.
    #[test]
    fn test_argon2_different_salts() {
        let kdf = Argon2KeyDerivation::new();
        let password = b"test-password";
        let salt_a = kdf.generate_salt(16);
        let salt_b = kdf.generate_salt(16);

        let key_a = kdf.derive_key(password, &salt_a, 32).unwrap();
        let key_b = kdf.derive_key(password, &salt_b, 32).unwrap();

        assert_ne!(key_a, key_b);
    }

    /// Rotation bumps the version, activates the new key, and keeps the old
    /// key available but inactive.
    #[test]
    fn test_key_ring_rotation() {
        let mut ring = KeyRing::new();

        ring.add_key(VersionedKey::new(1, EncryptionKey::generate(32)));
        assert_eq!(ring.active_key().unwrap().version, 1);

        let rotated_to = ring.rotate(EncryptionKey::generate(32));

        assert_eq!(rotated_to, 2);
        assert_eq!(ring.active_key().unwrap().version, 2);
        assert_eq!(ring.len(), 2);

        let first = ring.get_key(1);
        assert!(first.is_some());
        assert!(!first.unwrap().active);
    }
}