1use aes_gcm::{
7    aead::{Aead, KeyInit},
8    Aes256Gcm, Nonce, Key,
9};
10use hkdf::Hkdf;
11use sha2::Sha256;
12use base64::{Engine as _, engine::general_purpose};
13use getrandom;
14use std::collections::HashMap;
15
16use crate::{KVError, KVResult, DatabaseId};
17
18#[derive(Debug, Clone)]
20pub struct DatabaseKey {
21    pub database_id: DatabaseId,
22    pub key: Key<Aes256Gcm>,
23    pub nonce: [u8; 12], }
25
26impl DatabaseKey {
27    pub fn derive(master_key: &[u8], database_id: DatabaseId) -> KVResult<Self> {
32        if master_key.len() < 32 {
33            return Err(KVError::Encryption("Master key must be at least 32 bytes".to_string()));
34        }
35
36        let hk = Hkdf::<Sha256>::new(None, master_key);
38        let mut derived_key = [0u8; 32];
39        let info = format!("kv-database-{database_id}");
40        hk.expand(info.as_bytes(), &mut derived_key)
41            .map_err(|e| KVError::Encryption(format!("Key derivation failed: {e}")))?;
42
43        let mut nonce = [0u8; 12];
45        getrandom::fill(&mut nonce).map_err(|e| KVError::Encryption(format!("Failed to generate random nonce: {e}")))?;
46
47        let key = Key::<Aes256Gcm>::from_slice(&derived_key);
48        
49        Ok(Self {
50            database_id,
51            key: *key,
52            nonce,
53        })
54    }
55
56    pub fn encrypt(&self, data: &[u8]) -> KVResult<Vec<u8>> {
61        let cipher = Aes256Gcm::new(&self.key);
62        let nonce = Nonce::from_slice(&self.nonce);
63        
64        cipher.encrypt(nonce, data)
65            .map_err(|e| KVError::Encryption(format!("Encryption failed: {e}")))
66    }
67
68    pub fn decrypt(&self, encrypted_data: &[u8]) -> KVResult<Vec<u8>> {
73        let cipher = Aes256Gcm::new(&self.key);
74        let nonce = Nonce::from_slice(&self.nonce);
75        
76        cipher.decrypt(nonce, encrypted_data)
77            .map_err(|e| KVError::Encryption(format!("Decryption failed: {e}")))
78    }
79}
80
81pub struct KeyManager {
83    master_key: Vec<u8>,
84    database_keys: HashMap<DatabaseId, DatabaseKey>,
85}
86
87impl KeyManager {
88    pub fn new(master_key: &str) -> KVResult<Self> {
93        let master_key = if master_key.is_empty() {
94            let mut key = [0u8; 32];
96            getrandom::fill(&mut key).map_err(|e| KVError::Encryption(format!("Failed to generate random key: {e}")))?;
97            key.to_vec()
98        } else {
99            general_purpose::STANDARD.decode(master_key)
101                .map_err(|e| KVError::Encryption(format!("Invalid base64 master key: {e}")))?
102        };
103
104        Ok(Self {
105            master_key,
106            database_keys: HashMap::new(),
107        })
108    }
109
110    pub fn get_database_key(&mut self, database_id: DatabaseId) -> KVResult<&DatabaseKey> {
115        if !self.database_keys.contains_key(&database_id) {
116            let db_key = DatabaseKey::derive(&self.master_key, database_id)?;
117            self.database_keys.insert(database_id, db_key);
118        }
119        
120        self.database_keys.get(&database_id)
121            .ok_or_else(|| KVError::Encryption("Failed to get database key".to_string()))
122    }
123
124    #[must_use] 
126    pub fn master_key_base64(&self) -> String {
127        general_purpose::STANDARD.encode(&self.master_key)
128    }
129
130    pub fn rotate_master_key(&mut self) -> KVResult<String> {
135        let mut new_master_key = [0u8; 32];
137        getrandom::fill(&mut new_master_key).map_err(|e| KVError::Encryption(format!("Failed to generate random key: {e}")))?;
138        
139        self.master_key = new_master_key.to_vec();
141        
142        self.database_keys.clear();
144        
145        Ok(self.master_key_base64())
146    }
147
148    pub fn encrypt(&mut self, database_id: DatabaseId, data: &[u8]) -> KVResult<Vec<u8>> {
153        let db_key = self.get_database_key(database_id)?;
154        db_key.encrypt(data)
155    }
156
157    pub fn decrypt(&mut self, database_id: DatabaseId, encrypted_data: &[u8]) -> KVResult<Vec<u8>> {
162        let db_key = self.get_database_key(database_id)?;
163        db_key.decrypt(encrypted_data)
164    }
165}
166
167#[derive(Debug, Clone)]
169pub struct EncryptedData {
170    pub database_id: DatabaseId,
171    pub encrypted_data: Vec<u8>,
172    pub nonce: [u8; 12],
173}
174
175impl EncryptedData {
176    pub fn encrypt(data: &[u8], database_id: DatabaseId, key_manager: &mut KeyManager) -> KVResult<Self> {
181        let encrypted_data = key_manager.encrypt(database_id, data)?;
182        
183        Ok(Self {
184            database_id,
185            encrypted_data,
186            nonce: [0u8; 12], })
188    }
189
190    pub fn decrypt(&self, key_manager: &mut KeyManager) -> KVResult<Vec<u8>> {
195        key_manager.decrypt(self.database_id, &self.encrypted_data)
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed 32-byte master key for tests that need reproducible derivation.
    const MASTER_KEY: &[u8] = b"0123456789abcdef0123456789abcdef";

    #[test]
    fn test_key_derivation() {
        let db_key1 = DatabaseKey::derive(MASTER_KEY, 0).unwrap();
        let db_key2 = DatabaseKey::derive(MASTER_KEY, 1).unwrap();

        assert_ne!(db_key1.key, db_key2.key);
    }

    #[test]
    fn test_key_derivation_is_deterministic() {
        // Re-deriving must give the same key, or previously stored data becomes unreadable.
        let first = DatabaseKey::derive(MASTER_KEY, 7).unwrap();
        let second = DatabaseKey::derive(MASTER_KEY, 7).unwrap();

        assert_eq!(first.key, second.key);
    }

    #[test]
    fn test_short_master_key_rejected() {
        assert!(DatabaseKey::derive(&[0u8; 16], 0).is_err());

        // Base64 of 16 zero bytes: valid encoding, but too short to be a master key.
        let short_key = "AAAAAAAAAAAAAAAAAAAAAA==";
        let result = KeyManager::new(short_key).and_then(|mut km| km.encrypt(0, b"x"));
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_base64_master_key_rejected() {
        assert!(KeyManager::new("not valid base64!").is_err());
    }

    #[test]
    fn test_master_key_export_round_trip() {
        let key_manager = KeyManager::new("").unwrap();
        let exported = key_manager.master_key_base64();

        let restored = KeyManager::new(&exported).unwrap();
        assert_eq!(restored.master_key_base64(), exported);
    }

    #[test]
    fn test_encryption_decryption() {
        let mut key_manager = KeyManager::new("").unwrap();

        let data = b"Hello, encrypted world!";
        let encrypted = key_manager.encrypt(0, data).unwrap();
        let decrypted = key_manager.decrypt(0, &encrypted).unwrap();

        assert_eq!(data, &decrypted[..]);
    }

    #[test]
    fn test_empty_plaintext_round_trip() {
        let mut key_manager = KeyManager::new("").unwrap();

        let encrypted = key_manager.encrypt(0, b"").unwrap();
        let decrypted = key_manager.decrypt(0, &encrypted).unwrap();

        assert!(decrypted.is_empty());
    }

    #[test]
    fn test_tampered_ciphertext_rejected() {
        let mut key_manager = KeyManager::new("").unwrap();
        let mut encrypted = key_manager.encrypt(0, b"integrity matters").unwrap();

        let last = encrypted.len() - 1;
        encrypted[last] ^= 0x01;

        assert!(key_manager.decrypt(0, &encrypted).is_err());
    }

    #[test]
    fn test_wrong_key_rejected() {
        let mut owner = KeyManager::new("").unwrap();
        let mut stranger = KeyManager::new("").unwrap();
        let encrypted = owner.encrypt(0, b"secret").unwrap();

        // A different master key, or a different database under the same master key,
        // must both fail authentication.
        assert!(stranger.decrypt(0, &encrypted).is_err());
        assert!(owner.decrypt(1, &encrypted).is_err());
    }

    #[test]
    fn test_key_rotation() {
        let mut key_manager = KeyManager::new("").unwrap();
        let old_master_key = key_manager.master_key_base64();

        // Populate the cache first so the emptiness check below is meaningful.
        key_manager.encrypt(0, b"warm the cache").unwrap();
        assert!(!key_manager.database_keys.is_empty());

        let new_master_key = key_manager.rotate_master_key().unwrap();

        assert_ne!(old_master_key, new_master_key);
        assert!(key_manager.database_keys.is_empty());
    }

    #[test]
    fn test_encrypted_data_wrapper() {
        let mut key_manager = KeyManager::new("").unwrap();
        let data = b"Test data for encryption wrapper";

        let encrypted = EncryptedData::encrypt(data, 0, &mut key_manager).unwrap();
        let decrypted = encrypted.decrypt(&mut key_manager).unwrap();

        assert_eq!(data, &decrypted[..]);
    }
}