Skip to main content

fraiseql_server/secrets/
mod.rs

//! Secrets management with KMS-backed encryption and database schemas.
//!
//! This module provides:
//! 1. Startup-time cached encryption (fast, for configuration)
//! 2. Per-request KMS operations (slower, for sensitive data)
//! 3. Database schema definitions for secrets management
7
8use std::{collections::HashMap, sync::Arc};
9
10use fraiseql_core::security::{BaseKmsProvider, DataKeyPair, EncryptedData, KmsError, KmsResult};
11use tokio::sync::RwLock;
12
13pub mod schemas;
14
15pub use schemas::{
16    EncryptionKey, ExternalAuthProviderRecord, OAuthSessionRecord, SchemaMigration,
17    SecretRotationAudit,
18};
19
20/// Secret manager combining cached and per-request encryption.
21pub struct SecretManager {
22    /// Primary KMS provider
23    provider:       Arc<dyn BaseKmsProvider>,
24    /// Cached data key for local encryption
25    cached_key:     Arc<RwLock<Option<DataKeyPair>>>,
26    /// Default key ID for KMS operations
27    default_key_id: String,
28    /// Context prefix for all encryption operations
29    context_prefix: Option<String>,
30}
31
32impl SecretManager {
33    /// Create a new secret manager.
34    pub fn new(provider: Arc<dyn BaseKmsProvider>, default_key_id: String) -> Self {
35        Self {
36            provider,
37            cached_key: Arc::new(RwLock::new(None)),
38            default_key_id,
39            context_prefix: None,
40        }
41    }
42
43    /// Set a context prefix (e.g., "fraiseql-prod").
44    ///
45    /// This prefix is added to all encryption contexts for additional
46    /// isolation between environments.
47    #[must_use]
48    pub fn with_context_prefix(mut self, prefix: String) -> Self {
49        self.context_prefix = Some(prefix);
50        self
51    }
52
53    /// Initialize by generating and caching a data key.
54    ///
55    /// Call this at application startup. The data key is cached in memory
56    /// for fast local encryption during the application's lifetime.
57    ///
58    /// # Errors
59    /// Returns KmsError if data key generation fails
60    pub async fn initialize(&self) -> KmsResult<()> {
61        let mut context = HashMap::new();
62        context.insert("purpose".to_string(), "data_encryption".to_string());
63        let context = self.build_context(context);
64
65        let data_key = self.provider.generate_data_key(&self.default_key_id, context).await?;
66
67        let mut cached = self.cached_key.write().await;
68        *cached = Some(data_key);
69
70        Ok(())
71    }
72
73    /// Check if a data key is cached.
74    pub async fn is_initialized(&self) -> bool {
75        self.cached_key.read().await.is_some()
76    }
77
78    /// Rotate the cached data key.
79    ///
80    /// Call this periodically to rotate keys. This regenerates the cached
81    /// data key via KMS while maintaining application uptime.
82    ///
83    /// # Errors
84    /// Returns KmsError if rotation fails
85    pub async fn rotate_cached_key(&self) -> KmsResult<()> {
86        self.initialize().await
87    }
88
89    /// Encrypt data using the cached data key (NO KMS call).
90    ///
91    /// This is fast (~microseconds) and safe for use in hot paths.
92    /// Requires `initialize()` to be called first.
93    ///
94    /// # Errors
95    /// Returns KmsError::EncryptionFailed if not initialized or encryption fails
96    pub async fn local_encrypt(&self, plaintext: &[u8]) -> KmsResult<Vec<u8>> {
97        let cached = self.cached_key.read().await;
98        let data_key = cached.as_ref().ok_or_else(|| KmsError::EncryptionFailed {
99            message: "SecretManager not initialized. Call initialize() at startup.".to_string(),
100        })?;
101
102        // Encrypt using AES-256-GCM with the cached plaintext key
103        let nonce = Self::generate_nonce();
104        let ciphertext = aes_gcm_encrypt(&data_key.plaintext_key, &nonce, plaintext)?;
105
106        let mut result = nonce.to_vec();
107        result.extend_from_slice(&ciphertext);
108
109        Ok(result)
110    }
111
112    /// Decrypt data using the cached data key (NO KMS call).
113    ///
114    /// # Errors
115    /// Returns KmsError::DecryptionFailed if not initialized or decryption fails
116    pub async fn local_decrypt(&self, encrypted: &[u8]) -> KmsResult<Vec<u8>> {
117        if encrypted.len() < 12 {
118            return Err(KmsError::DecryptionFailed {
119                message: "Encrypted data too short".to_string(),
120            });
121        }
122
123        let cached = self.cached_key.read().await;
124        let data_key = cached.as_ref().ok_or_else(|| KmsError::DecryptionFailed {
125            message: "SecretManager not initialized. Call initialize() at startup.".to_string(),
126        })?;
127
128        let nonce = &encrypted[..12];
129        let ciphertext = &encrypted[12..];
130
131        aes_gcm_decrypt(&data_key.plaintext_key, nonce, ciphertext)
132    }
133
134    /// Encrypt data using KMS (per-request operation).
135    ///
136    /// This contacts the KMS provider for each encryption, providing
137    /// per-request key isolation but with higher latency (50-200ms).
138    /// Use for secrets management, not response data.
139    ///
140    /// # Arguments
141    /// * `plaintext` - Data to encrypt
142    /// * `key_id` - KMS key identifier (or None for default)
143    ///
144    /// # Errors
145    /// Returns KmsError if encryption fails
146    pub async fn encrypt(
147        &self,
148        plaintext: &[u8],
149        key_id: Option<&str>,
150    ) -> KmsResult<EncryptedData> {
151        let key_id = key_id.unwrap_or(&self.default_key_id);
152        let mut context = HashMap::new();
153        context.insert("operation".to_string(), "encrypt".to_string());
154        let context = self.build_context(context);
155
156        self.provider.encrypt(plaintext, key_id, context).await
157    }
158
159    /// Decrypt data using KMS (per-request operation).
160    ///
161    /// Auto-detects the correct provider from EncryptedData metadata.
162    ///
163    /// # Errors
164    /// Returns KmsError if decryption fails
165    pub async fn decrypt(&self, encrypted: &EncryptedData) -> KmsResult<Vec<u8>> {
166        let mut context = HashMap::new();
167        context.insert("operation".to_string(), "decrypt".to_string());
168        let context = self.build_context(context);
169
170        self.provider.decrypt(encrypted, context).await
171    }
172
173    /// Encrypt a string field (convenience method).
174    ///
175    /// Handles UTF-8 encoding/decoding automatically.
176    pub async fn encrypt_string(
177        &self,
178        plaintext: &str,
179        key_id: Option<&str>,
180    ) -> KmsResult<EncryptedData> {
181        let bytes = plaintext.as_bytes();
182        self.encrypt(bytes, key_id).await
183    }
184
185    /// Decrypt a string field.
186    pub async fn decrypt_string(&self, encrypted: &EncryptedData) -> KmsResult<String> {
187        let plaintext = self.decrypt(encrypted).await?;
188        String::from_utf8(plaintext).map_err(|e| KmsError::SerializationError {
189            message: format!("Invalid UTF-8 in decrypted data: {}", e),
190        })
191    }
192
193    // ─────────────────────────────────────────────────────────────
194    // Private helpers
195    // ─────────────────────────────────────────────────────────────
196
197    /// Build encryption context with optional prefix.
198    fn build_context(
199        &self,
200        mut context: HashMap<String, String>,
201    ) -> Option<HashMap<String, String>> {
202        if let Some(prefix) = &self.context_prefix {
203            context.insert("service".to_string(), prefix.clone());
204        }
205
206        if context.is_empty() {
207            None
208        } else {
209            Some(context)
210        }
211    }
212
213    /// Generate a 96-bit nonce for AES-GCM.
214    fn generate_nonce() -> [u8; 12] {
215        use rand::RngCore;
216        let mut nonce = [0u8; 12];
217        rand::thread_rng().fill_bytes(&mut nonce);
218        nonce
219    }
220}
221
222/// AES-256-GCM encryption using aes-gcm.
223fn aes_gcm_encrypt(key: &[u8], nonce: &[u8], plaintext: &[u8]) -> KmsResult<Vec<u8>> {
224    use aes_gcm::{
225        Aes256Gcm, Key, Nonce,
226        aead::{Aead, KeyInit},
227    };
228
229    let key = Key::<Aes256Gcm>::from_slice(key);
230    let cipher = Aes256Gcm::new(key);
231    let nonce = Nonce::from_slice(nonce);
232
233    cipher.encrypt(nonce, plaintext).map_err(|e| KmsError::EncryptionFailed {
234        message: format!("AES-GCM encryption failed: {}", e),
235    })
236}
237
238/// AES-256-GCM decryption using aes-gcm.
239fn aes_gcm_decrypt(key: &[u8], nonce: &[u8], ciphertext: &[u8]) -> KmsResult<Vec<u8>> {
240    use aes_gcm::{
241        Aes256Gcm, Key, Nonce,
242        aead::{Aead, KeyInit},
243    };
244
245    let key = Key::<Aes256Gcm>::from_slice(key);
246    let cipher = Aes256Gcm::new(key);
247    let nonce = Nonce::from_slice(nonce);
248
249    cipher.decrypt(nonce, ciphertext).map_err(|e| KmsError::DecryptionFailed {
250        message: format!("AES-GCM decryption failed: {}", e),
251    })
252}
253
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use fraiseql_core::security::{KmsError, KmsResult};

    use super::*;

    /// Mock KMS provider for testing.
    ///
    /// "Encryption" is plain base64, and every generated data key is 32 zero
    /// bytes, so results are deterministic apart from the random local nonce.
    struct MockKmsProvider;

    #[async_trait::async_trait]
    impl BaseKmsProvider for MockKmsProvider {
        fn provider_name(&self) -> &'static str {
            "mock"
        }

        async fn do_encrypt(
            &self,
            plaintext: &[u8],
            _key_id: &str,
            _context: &HashMap<String, String>,
        ) -> KmsResult<(String, String)> {
            // Return base64-encoded plaintext as mock ciphertext
            Ok((base64_encode(plaintext), "mock-algorithm".to_string()))
        }

        async fn do_decrypt(
            &self,
            ciphertext: &str,
            _key_id: &str,
            _context: &HashMap<String, String>,
        ) -> KmsResult<Vec<u8>> {
            base64_decode(ciphertext)
        }

        async fn do_generate_data_key(
            &self,
            _key_id: &str,
            _context: &HashMap<String, String>,
        ) -> KmsResult<(Vec<u8>, String)> {
            let key = vec![0u8; 32]; // 256-bit key
            let encrypted = base64_encode(&key);
            Ok((key, encrypted))
        }

        async fn do_rotate_key(&self, _key_id: &str) -> KmsResult<()> {
            Ok(())
        }

        async fn do_get_key_info(
            &self,
            _key_id: &str,
        ) -> KmsResult<fraiseql_core::security::kms::base::KeyInfo> {
            Ok(fraiseql_core::security::kms::base::KeyInfo {
                alias:      Some("mock-key".to_string()),
                created_at: 1_000_000,
            })
        }

        async fn do_get_rotation_policy(
            &self,
            _key_id: &str,
        ) -> KmsResult<fraiseql_core::security::kms::base::RotationPolicyInfo> {
            Ok(fraiseql_core::security::kms::base::RotationPolicyInfo {
                enabled:              false,
                rotation_period_days: 0,
                last_rotation:        None,
                next_rotation:        None,
            })
        }
    }

    fn base64_encode(data: &[u8]) -> String {
        use base64::prelude::*;
        BASE64_STANDARD.encode(data)
    }

    fn base64_decode(s: &str) -> KmsResult<Vec<u8>> {
        use base64::prelude::*;
        BASE64_STANDARD.decode(s).map_err(|e| KmsError::SerializationError {
            message: e.to_string(),
        })
    }

    /// Build an uninitialized manager backed by the mock provider.
    fn mock_manager() -> SecretManager {
        SecretManager::new(Arc::new(MockKmsProvider), "test-key".to_string())
    }

    /// Build a manager backed by the mock provider with a cached data key.
    async fn initialized_manager() -> SecretManager {
        let manager = mock_manager();
        manager.initialize().await.unwrap();
        manager
    }

    #[tokio::test]
    async fn test_secret_manager_initialization() {
        let manager = mock_manager();

        assert!(!manager.is_initialized().await);
        assert!(manager.initialize().await.is_ok());
        assert!(manager.is_initialized().await);
    }

    #[tokio::test]
    async fn test_local_encrypt_decrypt_roundtrip() {
        let manager = initialized_manager().await;

        let plaintext = b"secret data";
        let encrypted = manager.local_encrypt(plaintext).await.unwrap();
        let decrypted = manager.local_decrypt(&encrypted).await.unwrap();

        assert_eq!(plaintext, &decrypted[..]);
    }

    #[tokio::test]
    async fn test_local_roundtrip_empty_plaintext() {
        let manager = initialized_manager().await;

        let encrypted = manager.local_encrypt(b"").await.unwrap();
        let decrypted = manager.local_decrypt(&encrypted).await.unwrap();

        assert!(decrypted.is_empty());
    }

    #[tokio::test]
    async fn test_local_encrypt_uses_fresh_nonce() {
        let manager = initialized_manager().await;

        let first = manager.local_encrypt(b"same input").await.unwrap();
        let second = manager.local_encrypt(b"same input").await.unwrap();

        assert_ne!(first, second, "identical plaintexts must not yield identical ciphertexts");
    }

    #[tokio::test]
    async fn test_local_decrypt_rejects_tampered_ciphertext() {
        let manager = initialized_manager().await;

        let mut encrypted = manager.local_encrypt(b"secret data").await.unwrap();
        let last = encrypted.len() - 1;
        encrypted[last] ^= 0x01;

        assert!(manager.local_decrypt(&encrypted).await.is_err());
    }

    #[tokio::test]
    async fn test_local_decrypt_rejects_short_input() {
        let manager = initialized_manager().await;

        assert!(manager.local_decrypt(&[0u8; 5]).await.is_err());
    }

    #[tokio::test]
    async fn test_local_encrypt_without_initialization() {
        let manager = mock_manager();

        let result = manager.local_encrypt(b"secret").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_local_decrypt_without_initialization() {
        let manager = mock_manager();

        assert!(manager.local_decrypt(&[0u8; 64]).await.is_err());
    }

    #[tokio::test]
    async fn test_rotate_cached_key_keeps_manager_initialized() {
        let manager = initialized_manager().await;

        assert!(manager.rotate_cached_key().await.is_ok());
        assert!(manager.is_initialized().await);
    }

    #[tokio::test]
    async fn test_encrypt_decrypt_via_kms() {
        let manager = mock_manager();

        let plaintext = b"sensitive data";
        let encrypted = manager.encrypt(plaintext, None).await.unwrap();
        let decrypted = manager.decrypt(&encrypted).await.unwrap();

        assert_eq!(plaintext, &decrypted[..]);
    }

    #[tokio::test]
    async fn test_encrypt_string_roundtrip() {
        let manager = mock_manager();

        let plaintext = "secret string";
        let encrypted = manager.encrypt_string(plaintext, None).await.unwrap();
        let decrypted = manager.decrypt_string(&encrypted).await.unwrap();

        assert_eq!(plaintext, decrypted);
    }

    #[tokio::test]
    async fn test_context_prefix() {
        let manager = mock_manager().with_context_prefix("fraiseql-prod".to_string());

        assert!(manager.encrypt(b"data", None).await.is_ok());
    }
}