1use std::{collections::HashMap, sync::Arc};
9
10use fraiseql_core::security::{BaseKmsProvider, DataKeyPair, EncryptedData, KmsError, KmsResult};
11use tokio::sync::RwLock;
12
13pub mod schemas;
14
15pub use schemas::{
16 EncryptionKey, ExternalAuthProviderRecord, OAuthSessionRecord, SchemaMigration,
17 SecretRotationAudit,
18};
19
20pub struct SecretManager {
22 provider: Arc<dyn BaseKmsProvider>,
24 cached_key: Arc<RwLock<Option<DataKeyPair>>>,
26 default_key_id: String,
28 context_prefix: Option<String>,
30}
31
32impl SecretManager {
33 pub fn new(provider: Arc<dyn BaseKmsProvider>, default_key_id: String) -> Self {
35 Self {
36 provider,
37 cached_key: Arc::new(RwLock::new(None)),
38 default_key_id,
39 context_prefix: None,
40 }
41 }
42
43 #[must_use]
48 pub fn with_context_prefix(mut self, prefix: String) -> Self {
49 self.context_prefix = Some(prefix);
50 self
51 }
52
53 pub async fn initialize(&self) -> KmsResult<()> {
61 let mut context = HashMap::new();
62 context.insert("purpose".to_string(), "data_encryption".to_string());
63 let context = self.build_context(context);
64
65 let data_key = self.provider.generate_data_key(&self.default_key_id, context).await?;
66
67 let mut cached = self.cached_key.write().await;
68 *cached = Some(data_key);
69
70 Ok(())
71 }
72
73 pub async fn is_initialized(&self) -> bool {
75 self.cached_key.read().await.is_some()
76 }
77
78 pub async fn rotate_cached_key(&self) -> KmsResult<()> {
86 self.initialize().await
87 }
88
89 pub async fn local_encrypt(&self, plaintext: &[u8]) -> KmsResult<Vec<u8>> {
97 let cached = self.cached_key.read().await;
98 let data_key = cached.as_ref().ok_or_else(|| KmsError::EncryptionFailed {
99 message: "SecretManager not initialized. Call initialize() at startup.".to_string(),
100 })?;
101
102 let nonce = Self::generate_nonce();
104 let ciphertext = aes_gcm_encrypt(&data_key.plaintext_key, &nonce, plaintext)?;
105
106 let mut result = nonce.to_vec();
107 result.extend_from_slice(&ciphertext);
108
109 Ok(result)
110 }
111
112 pub async fn local_decrypt(&self, encrypted: &[u8]) -> KmsResult<Vec<u8>> {
117 if encrypted.len() < 12 {
118 return Err(KmsError::DecryptionFailed {
119 message: "Encrypted data too short".to_string(),
120 });
121 }
122
123 let cached = self.cached_key.read().await;
124 let data_key = cached.as_ref().ok_or_else(|| KmsError::DecryptionFailed {
125 message: "SecretManager not initialized. Call initialize() at startup.".to_string(),
126 })?;
127
128 let nonce = &encrypted[..12];
129 let ciphertext = &encrypted[12..];
130
131 aes_gcm_decrypt(&data_key.plaintext_key, nonce, ciphertext)
132 }
133
134 pub async fn encrypt(
147 &self,
148 plaintext: &[u8],
149 key_id: Option<&str>,
150 ) -> KmsResult<EncryptedData> {
151 let key_id = key_id.unwrap_or(&self.default_key_id);
152 let mut context = HashMap::new();
153 context.insert("operation".to_string(), "encrypt".to_string());
154 let context = self.build_context(context);
155
156 self.provider.encrypt(plaintext, key_id, context).await
157 }
158
159 pub async fn decrypt(&self, encrypted: &EncryptedData) -> KmsResult<Vec<u8>> {
166 let mut context = HashMap::new();
167 context.insert("operation".to_string(), "decrypt".to_string());
168 let context = self.build_context(context);
169
170 self.provider.decrypt(encrypted, context).await
171 }
172
173 pub async fn encrypt_string(
177 &self,
178 plaintext: &str,
179 key_id: Option<&str>,
180 ) -> KmsResult<EncryptedData> {
181 let bytes = plaintext.as_bytes();
182 self.encrypt(bytes, key_id).await
183 }
184
185 pub async fn decrypt_string(&self, encrypted: &EncryptedData) -> KmsResult<String> {
187 let plaintext = self.decrypt(encrypted).await?;
188 String::from_utf8(plaintext).map_err(|e| KmsError::SerializationError {
189 message: format!("Invalid UTF-8 in decrypted data: {}", e),
190 })
191 }
192
193 fn build_context(
199 &self,
200 mut context: HashMap<String, String>,
201 ) -> Option<HashMap<String, String>> {
202 if let Some(prefix) = &self.context_prefix {
203 context.insert("service".to_string(), prefix.clone());
204 }
205
206 if context.is_empty() {
207 None
208 } else {
209 Some(context)
210 }
211 }
212
213 fn generate_nonce() -> [u8; 12] {
215 use rand::RngCore;
216 let mut nonce = [0u8; 12];
217 rand::thread_rng().fill_bytes(&mut nonce);
218 nonce
219 }
220}
221
222fn aes_gcm_encrypt(key: &[u8], nonce: &[u8], plaintext: &[u8]) -> KmsResult<Vec<u8>> {
224 use aes_gcm::{
225 Aes256Gcm, Key, Nonce,
226 aead::{Aead, KeyInit},
227 };
228
229 let key = Key::<Aes256Gcm>::from_slice(key);
230 let cipher = Aes256Gcm::new(key);
231 let nonce = Nonce::from_slice(nonce);
232
233 cipher.encrypt(nonce, plaintext).map_err(|e| KmsError::EncryptionFailed {
234 message: format!("AES-GCM encryption failed: {}", e),
235 })
236}
237
238fn aes_gcm_decrypt(key: &[u8], nonce: &[u8], ciphertext: &[u8]) -> KmsResult<Vec<u8>> {
240 use aes_gcm::{
241 Aes256Gcm, Key, Nonce,
242 aead::{Aead, KeyInit},
243 };
244
245 let key = Key::<Aes256Gcm>::from_slice(key);
246 let cipher = Aes256Gcm::new(key);
247 let nonce = Nonce::from_slice(nonce);
248
249 cipher.decrypt(nonce, ciphertext).map_err(|e| KmsError::DecryptionFailed {
250 message: format!("AES-GCM decryption failed: {}", e),
251 })
252}
253
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use fraiseql_core::security::{KmsError, KmsResult};

    use super::*;

    /// Provider whose "encryption" is plain base64 and whose data key is all zeros.
    struct MockKmsProvider;

    #[async_trait::async_trait]
    impl BaseKmsProvider for MockKmsProvider {
        fn provider_name(&self) -> &'static str {
            "mock"
        }

        async fn do_encrypt(
            &self,
            plaintext: &[u8],
            _key_id: &str,
            _context: &HashMap<String, String>,
        ) -> KmsResult<(String, String)> {
            Ok((base64_encode(plaintext), "mock-algorithm".to_string()))
        }

        async fn do_decrypt(
            &self,
            ciphertext: &str,
            _key_id: &str,
            _context: &HashMap<String, String>,
        ) -> KmsResult<Vec<u8>> {
            base64_decode(ciphertext)
        }

        async fn do_generate_data_key(
            &self,
            _key_id: &str,
            _context: &HashMap<String, String>,
        ) -> KmsResult<(Vec<u8>, String)> {
            let key = vec![0u8; 32];
            let encrypted = base64_encode(&key);
            Ok((key, encrypted))
        }

        async fn do_rotate_key(&self, _key_id: &str) -> KmsResult<()> {
            Ok(())
        }

        async fn do_get_key_info(
            &self,
            _key_id: &str,
        ) -> KmsResult<fraiseql_core::security::kms::base::KeyInfo> {
            Ok(fraiseql_core::security::kms::base::KeyInfo {
                alias: Some("mock-key".to_string()),
                created_at: 1_000_000,
            })
        }

        async fn do_get_rotation_policy(
            &self,
            _key_id: &str,
        ) -> KmsResult<fraiseql_core::security::kms::base::RotationPolicyInfo> {
            Ok(fraiseql_core::security::kms::base::RotationPolicyInfo {
                enabled: false,
                rotation_period_days: 0,
                last_rotation: None,
                next_rotation: None,
            })
        }
    }

    fn base64_encode(data: &[u8]) -> String {
        use base64::prelude::*;
        BASE64_STANDARD.encode(data)
    }

    fn base64_decode(s: &str) -> KmsResult<Vec<u8>> {
        use base64::prelude::*;
        BASE64_STANDARD.decode(s).map_err(|e| KmsError::SerializationError {
            message: e.to_string(),
        })
    }

    /// Builds a manager over the mock provider with its data key already cached.
    async fn initialized_manager() -> SecretManager {
        let manager = SecretManager::new(Arc::new(MockKmsProvider), "test-key".to_string());
        manager
            .initialize()
            .await
            .expect("mock provider always yields a data key");
        manager
    }

    #[tokio::test]
    async fn test_secret_manager_initialization() {
        let provider = Arc::new(MockKmsProvider);
        let manager = SecretManager::new(provider, "test-key".to_string());

        assert!(!manager.is_initialized().await);
        assert!(manager.initialize().await.is_ok());
        assert!(manager.is_initialized().await);
    }

    #[tokio::test]
    async fn test_local_encrypt_decrypt_roundtrip() {
        let provider = Arc::new(MockKmsProvider);
        let manager = SecretManager::new(provider, "test-key".to_string());
        manager.initialize().await.unwrap();

        let plaintext = b"secret data";
        let encrypted = manager.local_encrypt(plaintext).await.unwrap();
        let decrypted = manager.local_decrypt(&encrypted).await.unwrap();

        assert_eq!(plaintext, &decrypted[..]);
    }

    #[tokio::test]
    async fn test_local_encrypt_empty_plaintext_roundtrip() {
        let manager = initialized_manager().await;

        // Smallest valid message: 12-byte nonce + 16-byte tag, no ciphertext bytes.
        let encrypted = manager.local_encrypt(b"").await.unwrap();
        assert_eq!(encrypted.len(), 12 + 16);
        assert!(manager.local_decrypt(&encrypted).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_local_encrypt_uses_fresh_nonce() {
        let manager = initialized_manager().await;

        let first = manager.local_encrypt(b"same input").await.unwrap();
        let second = manager.local_encrypt(b"same input").await.unwrap();

        assert_ne!(&first[..12], &second[..12]);
        assert_ne!(first, second);
    }

    #[tokio::test]
    async fn test_local_decrypt_rejects_tampered_ciphertext() {
        let manager = initialized_manager().await;

        let mut encrypted = manager.local_encrypt(b"secret data").await.unwrap();
        let last = encrypted.len() - 1;
        encrypted[last] ^= 0x01;

        assert!(manager.local_decrypt(&encrypted).await.is_err());
    }

    #[tokio::test]
    async fn test_local_decrypt_rejects_short_input() {
        let manager = initialized_manager().await;

        assert!(manager.local_decrypt(&[]).await.is_err());
        assert!(manager.local_decrypt(&[0u8; 5]).await.is_err());
        assert!(manager.local_decrypt(&[0u8; 12]).await.is_err());
    }

    #[tokio::test]
    async fn test_local_encrypt_without_initialization() {
        let provider = Arc::new(MockKmsProvider);
        let manager = SecretManager::new(provider, "test-key".to_string());

        let result = manager.local_encrypt(b"secret").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_local_decrypt_without_initialization() {
        let manager = SecretManager::new(Arc::new(MockKmsProvider), "test-key".to_string());

        assert!(manager.local_decrypt(&[0u8; 64]).await.is_err());
    }

    #[tokio::test]
    async fn test_rotate_cached_key_keeps_manager_initialized() {
        let manager = initialized_manager().await;

        manager.rotate_cached_key().await.unwrap();
        assert!(manager.is_initialized().await);
    }

    #[tokio::test]
    async fn test_encrypt_decrypt_via_kms() {
        let provider = Arc::new(MockKmsProvider);
        let manager = SecretManager::new(provider, "test-key".to_string());

        let plaintext = b"sensitive data";
        let encrypted = manager.encrypt(plaintext, None).await.unwrap();
        let decrypted = manager.decrypt(&encrypted).await.unwrap();

        assert_eq!(plaintext, &decrypted[..]);
    }

    #[tokio::test]
    async fn test_encrypt_string_roundtrip() {
        let provider = Arc::new(MockKmsProvider);
        let manager = SecretManager::new(provider, "test-key".to_string());

        let plaintext = "secret string";
        let encrypted = manager.encrypt_string(plaintext, None).await.unwrap();
        let decrypted = manager.decrypt_string(&encrypted).await.unwrap();

        assert_eq!(plaintext, decrypted);
    }

    #[tokio::test]
    async fn test_decrypt_string_rejects_invalid_utf8() {
        let manager = SecretManager::new(Arc::new(MockKmsProvider), "test-key".to_string());

        let encrypted = manager.encrypt(&[0xff, 0xfe], None).await.unwrap();
        let result = manager.decrypt_string(&encrypted).await;

        assert!(matches!(result, Err(KmsError::SerializationError { .. })));
    }

    #[tokio::test]
    async fn test_context_prefix() {
        let provider = Arc::new(MockKmsProvider);
        let manager = SecretManager::new(provider, "test-key".to_string())
            .with_context_prefix("fraiseql-prod".to_string());

        assert!(manager.encrypt(b"data", None).await.is_ok());
    }
}