1use aes_gcm::{
7 aead::{Aead, KeyInit},
8 Aes256Gcm, Nonce, Key,
9};
10use hkdf::Hkdf;
11use sha2::Sha256;
12use base64::{Engine as _, engine::general_purpose};
13use getrandom;
14use std::collections::HashMap;
15
16use crate::{KVError, KVResult, DatabaseId};
17
18#[derive(Debug, Clone)]
20pub struct DatabaseKey {
21 pub database_id: DatabaseId,
22 pub key: Key<Aes256Gcm>,
23 pub nonce: [u8; 12], }
25
26impl DatabaseKey {
27 pub fn derive(master_key: &[u8], database_id: DatabaseId) -> KVResult<Self> {
32 if master_key.len() < 32 {
33 return Err(KVError::Encryption("Master key must be at least 32 bytes".to_string()));
34 }
35
36 let hk = Hkdf::<Sha256>::new(None, master_key);
38 let mut derived_key = [0u8; 32];
39 let info = format!("kv-database-{database_id}");
40 hk.expand(info.as_bytes(), &mut derived_key)
41 .map_err(|e| KVError::Encryption(format!("Key derivation failed: {e}")))?;
42
43 let mut nonce = [0u8; 12];
45 getrandom::fill(&mut nonce).map_err(|e| KVError::Encryption(format!("Failed to generate random nonce: {e}")))?;
46
47 let key = Key::<Aes256Gcm>::from_slice(&derived_key);
48
49 Ok(Self {
50 database_id,
51 key: *key,
52 nonce,
53 })
54 }
55
56 pub fn encrypt(&self, data: &[u8]) -> KVResult<Vec<u8>> {
61 let cipher = Aes256Gcm::new(&self.key);
62 let nonce = Nonce::from_slice(&self.nonce);
63
64 cipher.encrypt(nonce, data)
65 .map_err(|e| KVError::Encryption(format!("Encryption failed: {e}")))
66 }
67
68 pub fn decrypt(&self, encrypted_data: &[u8]) -> KVResult<Vec<u8>> {
73 let cipher = Aes256Gcm::new(&self.key);
74 let nonce = Nonce::from_slice(&self.nonce);
75
76 cipher.decrypt(nonce, encrypted_data)
77 .map_err(|e| KVError::Encryption(format!("Decryption failed: {e}")))
78 }
79}
80
81pub struct KeyManager {
83 master_key: Vec<u8>,
84 database_keys: HashMap<DatabaseId, DatabaseKey>,
85}
86
87impl KeyManager {
88 pub fn new(master_key: &str) -> KVResult<Self> {
93 let master_key = if master_key.is_empty() {
94 let mut key = [0u8; 32];
96 getrandom::fill(&mut key).map_err(|e| KVError::Encryption(format!("Failed to generate random key: {e}")))?;
97 key.to_vec()
98 } else {
99 general_purpose::STANDARD.decode(master_key)
101 .map_err(|e| KVError::Encryption(format!("Invalid base64 master key: {e}")))?
102 };
103
104 Ok(Self {
105 master_key,
106 database_keys: HashMap::new(),
107 })
108 }
109
110 pub fn get_database_key(&mut self, database_id: DatabaseId) -> KVResult<&DatabaseKey> {
115 if !self.database_keys.contains_key(&database_id) {
116 let db_key = DatabaseKey::derive(&self.master_key, database_id)?;
117 self.database_keys.insert(database_id, db_key);
118 }
119
120 self.database_keys.get(&database_id)
121 .ok_or_else(|| KVError::Encryption("Failed to get database key".to_string()))
122 }
123
124 #[must_use]
126 pub fn master_key_base64(&self) -> String {
127 general_purpose::STANDARD.encode(&self.master_key)
128 }
129
130 pub fn rotate_master_key(&mut self) -> KVResult<String> {
135 let mut new_master_key = [0u8; 32];
137 getrandom::fill(&mut new_master_key).map_err(|e| KVError::Encryption(format!("Failed to generate random key: {e}")))?;
138
139 self.master_key = new_master_key.to_vec();
141
142 self.database_keys.clear();
144
145 Ok(self.master_key_base64())
146 }
147
148 pub fn encrypt(&mut self, database_id: DatabaseId, data: &[u8]) -> KVResult<Vec<u8>> {
153 let db_key = self.get_database_key(database_id)?;
154 db_key.encrypt(data)
155 }
156
157 pub fn decrypt(&mut self, database_id: DatabaseId, encrypted_data: &[u8]) -> KVResult<Vec<u8>> {
162 let db_key = self.get_database_key(database_id)?;
163 db_key.decrypt(encrypted_data)
164 }
165}
166
167#[derive(Debug, Clone)]
169pub struct EncryptedData {
170 pub database_id: DatabaseId,
171 pub encrypted_data: Vec<u8>,
172 pub nonce: [u8; 12],
173}
174
175impl EncryptedData {
176 pub fn encrypt(data: &[u8], database_id: DatabaseId, key_manager: &mut KeyManager) -> KVResult<Self> {
181 let encrypted_data = key_manager.encrypt(database_id, data)?;
182
183 Ok(Self {
184 database_id,
185 encrypted_data,
186 nonce: [0u8; 12], })
188 }
189
190 pub fn decrypt(&self, key_manager: &mut KeyManager) -> KVResult<Vec<u8>> {
195 key_manager.decrypt(self.database_id, &self.encrypted_data)
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// Different database ids must yield different keys from one master key.
    #[test]
    fn test_key_derivation() {
        let master = b"0123456789abcdef0123456789abcdef";

        let first = DatabaseKey::derive(master, 0).unwrap();
        let second = DatabaseKey::derive(master, 1).unwrap();

        assert_ne!(first.key, second.key);
    }

    /// A payload encrypted through the manager decrypts to the original bytes.
    #[test]
    fn test_encryption_decryption() {
        let mut manager = KeyManager::new("").unwrap();
        let plaintext = b"Hello, encrypted world!";

        let ciphertext = manager.encrypt(0, plaintext).unwrap();
        let recovered = manager.decrypt(0, &ciphertext).unwrap();

        assert_eq!(&plaintext[..], recovered.as_slice());
    }

    /// Rotation produces a new master key and empties the derived-key cache.
    #[test]
    fn test_key_rotation() {
        let mut manager = KeyManager::new("").unwrap();
        let before = manager.master_key_base64();

        let after = manager.rotate_master_key().unwrap();

        assert_ne!(before, after);
        assert!(manager.database_keys.is_empty());
    }

    /// The `EncryptedData` wrapper round-trips a payload.
    #[test]
    fn test_encrypted_data_wrapper() {
        let mut manager = KeyManager::new("").unwrap();
        let plaintext = b"Test data for encryption wrapper";

        let wrapped = EncryptedData::encrypt(plaintext, 0, &mut manager).unwrap();
        let recovered = wrapped.decrypt(&mut manager).unwrap();

        assert_eq!(&plaintext[..], recovered.as_slice());
    }
}