1use std::{collections::HashMap, sync::Arc};
9
10use fraiseql_core::security::{BaseKmsProvider, DataKeyPair, EncryptedData, KmsError, KmsResult};
11use tokio::sync::RwLock;
12
13pub mod schemas;
14
15#[cfg(test)]
16mod schema_tests;
17
18pub use schemas::{
19 EncryptionKey, ExternalAuthProviderRecord, OAuthSessionRecord, SchemaMigration,
20 SecretRotationAudit,
21};
22
23pub struct SecretManager {
25 provider: Arc<dyn BaseKmsProvider>,
27 cached_key: Arc<RwLock<Option<DataKeyPair>>>,
29 default_key_id: String,
31 context_prefix: Option<String>,
33}
34
35impl SecretManager {
36 pub fn new(provider: Arc<dyn BaseKmsProvider>, default_key_id: String) -> Self {
38 Self {
39 provider,
40 cached_key: Arc::new(RwLock::new(None)),
41 default_key_id,
42 context_prefix: None,
43 }
44 }
45
46 #[must_use]
51 pub fn with_context_prefix(mut self, prefix: String) -> Self {
52 self.context_prefix = Some(prefix);
53 self
54 }
55
56 pub async fn initialize(&self) -> KmsResult<()> {
64 let mut context = HashMap::new();
65 context.insert("purpose".to_string(), "data_encryption".to_string());
66 let context = self.build_context(context);
67
68 let data_key = self.provider.generate_data_key(&self.default_key_id, context).await?;
69
70 let mut cached = self.cached_key.write().await;
71 *cached = Some(data_key);
72
73 Ok(())
74 }
75
76 pub async fn is_initialized(&self) -> bool {
78 self.cached_key.read().await.is_some()
79 }
80
81 pub async fn rotate_cached_key(&self) -> KmsResult<()> {
89 self.initialize().await
90 }
91
92 pub async fn local_encrypt(&self, plaintext: &[u8]) -> KmsResult<Vec<u8>> {
100 let cached = self.cached_key.read().await;
101 let data_key = cached.as_ref().ok_or_else(|| KmsError::EncryptionFailed {
102 message: "SecretManager not initialized. Call initialize() at startup.".to_string(),
103 })?;
104
105 let nonce = Self::generate_nonce();
107 let ciphertext = aes_gcm_encrypt(&data_key.plaintext_key, &nonce, plaintext)?;
108
109 let mut result = nonce.to_vec();
110 result.extend_from_slice(&ciphertext);
111
112 Ok(result)
113 }
114
115 pub async fn local_decrypt(&self, encrypted: &[u8]) -> KmsResult<Vec<u8>> {
120 if encrypted.len() < 12 {
121 return Err(KmsError::DecryptionFailed {
122 message: "Encrypted data too short".to_string(),
123 });
124 }
125
126 let cached = self.cached_key.read().await;
127 let data_key = cached.as_ref().ok_or_else(|| KmsError::DecryptionFailed {
128 message: "SecretManager not initialized. Call initialize() at startup.".to_string(),
129 })?;
130
131 let nonce = &encrypted[..12];
132 let ciphertext = &encrypted[12..];
133
134 aes_gcm_decrypt(&data_key.plaintext_key, nonce, ciphertext)
135 }
136
137 pub async fn encrypt(
150 &self,
151 plaintext: &[u8],
152 key_id: Option<&str>,
153 ) -> KmsResult<EncryptedData> {
154 let key_id = key_id.unwrap_or(&self.default_key_id);
155 let mut context = HashMap::new();
156 context.insert("operation".to_string(), "encrypt".to_string());
157 let context = self.build_context(context);
158
159 self.provider.encrypt(plaintext, key_id, context).await
160 }
161
162 pub async fn decrypt(&self, encrypted: &EncryptedData) -> KmsResult<Vec<u8>> {
169 let mut context = HashMap::new();
170 context.insert("operation".to_string(), "decrypt".to_string());
171 let context = self.build_context(context);
172
173 self.provider.decrypt(encrypted, context).await
174 }
175
176 pub async fn encrypt_string(
180 &self,
181 plaintext: &str,
182 key_id: Option<&str>,
183 ) -> KmsResult<EncryptedData> {
184 let bytes = plaintext.as_bytes();
185 self.encrypt(bytes, key_id).await
186 }
187
188 pub async fn decrypt_string(&self, encrypted: &EncryptedData) -> KmsResult<String> {
190 let plaintext = self.decrypt(encrypted).await?;
191 String::from_utf8(plaintext).map_err(|e| KmsError::SerializationError {
192 message: format!("Invalid UTF-8 in decrypted data: {}", e),
193 })
194 }
195
196 fn build_context(
202 &self,
203 mut context: HashMap<String, String>,
204 ) -> Option<HashMap<String, String>> {
205 if let Some(prefix) = &self.context_prefix {
206 context.insert("service".to_string(), prefix.clone());
207 }
208
209 if context.is_empty() {
210 None
211 } else {
212 Some(context)
213 }
214 }
215
216 fn generate_nonce() -> [u8; 12] {
218 use rand::RngCore;
219 let mut nonce = [0u8; 12];
220 rand::thread_rng().fill_bytes(&mut nonce);
221 nonce
222 }
223}
224
225fn aes_gcm_encrypt(key: &[u8], nonce: &[u8], plaintext: &[u8]) -> KmsResult<Vec<u8>> {
227 use aes_gcm::{
228 Aes256Gcm, Key, Nonce,
229 aead::{Aead, KeyInit},
230 };
231
232 let key = Key::<Aes256Gcm>::from_slice(key);
233 let cipher = Aes256Gcm::new(key);
234 let nonce = Nonce::from_slice(nonce);
235
236 cipher.encrypt(nonce, plaintext).map_err(|e| KmsError::EncryptionFailed {
237 message: format!("AES-GCM encryption failed: {}", e),
238 })
239}
240
241fn aes_gcm_decrypt(key: &[u8], nonce: &[u8], ciphertext: &[u8]) -> KmsResult<Vec<u8>> {
243 use aes_gcm::{
244 Aes256Gcm, Key, Nonce,
245 aead::{Aead, KeyInit},
246 };
247
248 let key = Key::<Aes256Gcm>::from_slice(key);
249 let cipher = Aes256Gcm::new(key);
250 let nonce = Nonce::from_slice(nonce);
251
252 cipher.decrypt(nonce, ciphertext).map_err(|e| KmsError::DecryptionFailed {
253 message: format!("AES-GCM decryption failed: {}", e),
254 })
255}
256
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use fraiseql_core::security::{
        KmsError, KmsResult,
        kms::base::{KeyInfo, RotationPolicyInfo},
    };

    use super::*;

    /// In-memory stand-in for a KMS: "encryption" is plain base64 and the data key is
    /// a fixed all-zero 32-byte buffer.
    struct MockKmsProvider;

    #[async_trait::async_trait]
    impl BaseKmsProvider for MockKmsProvider {
        fn provider_name(&self) -> &'static str {
            "mock"
        }

        async fn do_encrypt(
            &self,
            plaintext: &[u8],
            _key_id: &str,
            _context: &HashMap<String, String>,
        ) -> KmsResult<(String, String)> {
            let encoded = base64_encode(plaintext);
            Ok((encoded, "mock-algorithm".to_string()))
        }

        async fn do_decrypt(
            &self,
            ciphertext: &str,
            _key_id: &str,
            _context: &HashMap<String, String>,
        ) -> KmsResult<Vec<u8>> {
            base64_decode(ciphertext)
        }

        async fn do_generate_data_key(
            &self,
            _key_id: &str,
            _context: &HashMap<String, String>,
        ) -> KmsResult<(Vec<u8>, String)> {
            // Fixed key material keeps the tests deterministic.
            let key = vec![0u8; 32];
            let encrypted = base64_encode(&key);
            Ok((key, encrypted))
        }

        async fn do_rotate_key(&self, _key_id: &str) -> KmsResult<()> {
            Ok(())
        }

        async fn do_get_key_info(&self, _key_id: &str) -> KmsResult<KeyInfo> {
            Ok(KeyInfo {
                alias: Some("mock-key".to_string()),
                created_at: 1_000_000,
            })
        }

        async fn do_get_rotation_policy(&self, _key_id: &str) -> KmsResult<RotationPolicyInfo> {
            Ok(RotationPolicyInfo {
                enabled: false,
                rotation_period_days: 0,
                last_rotation: None,
                next_rotation: None,
            })
        }
    }

    /// Encodes `data` with the standard base64 alphabet.
    fn base64_encode(data: &[u8]) -> String {
        use base64::prelude::*;
        BASE64_STANDARD.encode(data)
    }

    /// Decodes standard base64, mapping failures to `KmsError::SerializationError`.
    fn base64_decode(s: &str) -> KmsResult<Vec<u8>> {
        use base64::prelude::*;
        BASE64_STANDARD.decode(s).map_err(|e| KmsError::SerializationError {
            message: e.to_string(),
        })
    }

    /// Builds a manager over the mock provider with the key id every test uses.
    fn mock_manager() -> SecretManager {
        SecretManager::new(Arc::new(MockKmsProvider), "test-key".to_string())
    }

    #[tokio::test]
    async fn test_secret_manager_initialization() {
        let manager = mock_manager();

        assert!(!manager.is_initialized().await);
        assert!(manager.initialize().await.is_ok());
        assert!(manager.is_initialized().await);
    }

    #[tokio::test]
    async fn test_local_encrypt_decrypt_roundtrip() {
        let manager = mock_manager();
        manager.initialize().await.unwrap();

        let plaintext = b"secret data";
        let sealed = manager.local_encrypt(plaintext).await.unwrap();
        let opened = manager.local_decrypt(&sealed).await.unwrap();

        assert_eq!(opened, plaintext);
    }

    #[tokio::test]
    async fn test_local_encrypt_without_initialization() {
        let manager = mock_manager();

        let outcome = manager.local_encrypt(b"secret").await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_encrypt_decrypt_via_kms() {
        let manager = mock_manager();

        let plaintext = b"sensitive data";
        let sealed = manager.encrypt(plaintext, None).await.unwrap();
        let opened = manager.decrypt(&sealed).await.unwrap();

        assert_eq!(opened, plaintext);
    }

    #[tokio::test]
    async fn test_encrypt_string_roundtrip() {
        let manager = mock_manager();

        let plaintext = "secret string";
        let sealed = manager.encrypt_string(plaintext, None).await.unwrap();
        let opened = manager.decrypt_string(&sealed).await.unwrap();

        assert_eq!(opened, plaintext);
    }

    #[tokio::test]
    async fn test_context_prefix() {
        let manager = mock_manager().with_context_prefix("fraiseql-prod".to_string());

        assert!(manager.encrypt(b"data", None).await.is_ok());
    }
}