Skip to main content

fraiseql_core/security/kms/
base.rs

1//! Base KMS provider trait with template method pattern.
2//!
3//! Provides public async methods with common logic and abstract hooks
4//! for provider-specific implementations.
5
6use std::{
7    collections::HashMap,
8    time::{SystemTime, UNIX_EPOCH},
9};
10
11use crate::security::kms::{
12    error::{KmsError, KmsResult},
13    models::{DataKeyPair, EncryptedData, KeyPurpose, KeyReference, RotationPolicy},
14};
15
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Falls back to `0` when the system clock reports a time before the epoch, or when the
/// seconds count does not fit in an `i64`. The latter only happens after roughly
/// 292 billion years, so in practice it cannot occur.
fn current_timestamp() -> i64 {
    let secs_since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs());

    i64::try_from(secs_since_epoch).unwrap_or(0)
}
22
23/// Abstract base class for KMS providers.
24///
25/// Implements the Template Method pattern:
26/// - Public methods (encrypt, decrypt, etc.) handle common logic
27/// - Protected abstract methods (do_encrypt, do_decrypt, etc.) are implemented by concrete
28///   providers
29#[async_trait::async_trait]
30pub trait BaseKmsProvider: Send + Sync {
31    /// Unique provider identifier (e.g., 'vault', 'aws', 'gcp').
32    fn provider_name(&self) -> &str;
33
34    // ─────────────────────────────────────────────────────────────
35    // Template Methods (public API)
36    // ─────────────────────────────────────────────────────────────
37
38    /// Encrypt data using the specified key.
39    ///
40    /// # Arguments
41    /// * `plaintext` - Data to encrypt
42    /// * `key_id` - Key identifier
43    /// * `context` - Additional authenticated data (AAD)
44    ///
45    /// # Returns
46    /// EncryptedData with ciphertext and metadata
47    ///
48    /// # Errors
49    /// Returns KmsError::EncryptionFailed if encryption fails
50    async fn encrypt(
51        &self,
52        plaintext: &[u8],
53        key_id: &str,
54        context: Option<HashMap<String, String>>,
55    ) -> KmsResult<EncryptedData> {
56        let ctx = context.unwrap_or_default();
57
58        let (ciphertext, algorithm) =
59            self.do_encrypt(plaintext, key_id, &ctx).await.map_err(|e| {
60                KmsError::EncryptionFailed {
61                    message: format!("Provider encryption failed: {}", e),
62                }
63            })?;
64
65        Ok(EncryptedData::new(
66            ciphertext,
67            KeyReference::new(
68                self.provider_name().to_string(),
69                key_id.to_string(),
70                KeyPurpose::EncryptDecrypt,
71                current_timestamp(),
72            ),
73            algorithm,
74            current_timestamp(),
75            ctx,
76        ))
77    }
78
79    /// Decrypt data.
80    ///
81    /// # Arguments
82    /// * `encrypted` - EncryptedData to decrypt
83    /// * `context` - Override context (uses encrypted.context if not provided)
84    ///
85    /// # Returns
86    /// Decrypted plaintext bytes
87    ///
88    /// # Errors
89    /// Returns KmsError::DecryptionFailed if decryption fails
90    async fn decrypt(
91        &self,
92        encrypted: &EncryptedData,
93        context: Option<HashMap<String, String>>,
94    ) -> KmsResult<Vec<u8>> {
95        let ctx = context.unwrap_or_else(|| encrypted.context.clone());
96        let key_id = &encrypted.key_reference.key_id;
97
98        self.do_decrypt(&encrypted.ciphertext, key_id, &ctx).await.map_err(|e| {
99            KmsError::DecryptionFailed {
100                message: format!("Provider decryption failed: {}", e),
101            }
102        })
103    }
104
105    /// Generate a data encryption key (envelope encryption).
106    ///
107    /// # Arguments
108    /// * `key_id` - Master key to wrap the data key
109    /// * `context` - Additional authenticated data
110    ///
111    /// # Returns
112    /// DataKeyPair with plaintext and encrypted data key
113    async fn generate_data_key(
114        &self,
115        key_id: &str,
116        context: Option<HashMap<String, String>>,
117    ) -> KmsResult<DataKeyPair> {
118        let ctx = context.unwrap_or_default();
119
120        let (plaintext_key, encrypted_key_bytes) = self
121            .do_generate_data_key(key_id, &ctx)
122            .await
123            .map_err(|e| KmsError::EncryptionFailed {
124                message: format!("Data key generation failed: {}", e),
125            })?;
126
127        let key_ref = KeyReference::new(
128            self.provider_name().to_string(),
129            key_id.to_string(),
130            KeyPurpose::EncryptDecrypt,
131            current_timestamp(),
132        );
133
134        Ok(DataKeyPair::new(
135            plaintext_key,
136            EncryptedData::new(
137                encrypted_key_bytes,
138                key_ref.clone(),
139                "data-key".to_string(),
140                current_timestamp(),
141                ctx,
142            ),
143            key_ref,
144        ))
145    }
146
147    /// Rotate the specified key.
148    ///
149    /// # Errors
150    /// Returns KmsError::RotationFailed if rotation fails
151    async fn rotate_key(&self, key_id: &str) -> KmsResult<KeyReference> {
152        self.do_rotate_key(key_id).await.map_err(|e| KmsError::RotationFailed {
153            message: format!("Provider rotation failed: {}", e),
154        })?;
155
156        self.get_key_info(key_id).await
157    }
158
159    /// Get key metadata.
160    ///
161    /// # Errors
162    /// Returns KmsError::KeyNotFound if key does not exist
163    async fn get_key_info(&self, key_id: &str) -> KmsResult<KeyReference> {
164        let info = self.do_get_key_info(key_id).await.map_err(|_e| KmsError::KeyNotFound {
165            key_id: key_id.to_string(),
166        })?;
167
168        Ok(KeyReference::new(
169            self.provider_name().to_string(),
170            key_id.to_string(),
171            KeyPurpose::EncryptDecrypt,
172            info.created_at,
173        )
174        .with_alias(info.alias.unwrap_or_default()))
175    }
176
177    /// Get key rotation policy.
178    ///
179    /// # Errors
180    /// Returns KmsError::KeyNotFound if key does not exist
181    async fn get_rotation_policy(&self, key_id: &str) -> KmsResult<RotationPolicy> {
182        let policy =
183            self.do_get_rotation_policy(key_id).await.map_err(|_e| KmsError::KeyNotFound {
184                key_id: key_id.to_string(),
185            })?;
186
187        Ok(RotationPolicy {
188            enabled:              policy.enabled,
189            rotation_period_days: policy.rotation_period_days,
190            last_rotation:        policy.last_rotation,
191            next_rotation:        policy.next_rotation,
192        })
193    }
194
195    // ─────────────────────────────────────────────────────────────
196    // Abstract Methods (provider-specific hooks)
197    // ─────────────────────────────────────────────────────────────
198
199    /// Provider-specific encryption.
200    ///
201    /// # Arguments
202    /// * `plaintext` - Data to encrypt
203    /// * `key_id` - Key identifier
204    /// * `context` - AAD context (never empty)
205    ///
206    /// # Returns
207    /// Tuple of (ciphertext, algorithm_name) on success
208    async fn do_encrypt(
209        &self,
210        plaintext: &[u8],
211        key_id: &str,
212        context: &HashMap<String, String>,
213    ) -> KmsResult<(String, String)>;
214
215    /// Provider-specific decryption.
216    ///
217    /// # Arguments
218    /// * `ciphertext` - Data to decrypt
219    /// * `key_id` - Key identifier
220    /// * `context` - AAD context (never empty)
221    ///
222    /// # Returns
223    /// Decrypted plaintext bytes
224    async fn do_decrypt(
225        &self,
226        ciphertext: &str,
227        key_id: &str,
228        context: &HashMap<String, String>,
229    ) -> KmsResult<Vec<u8>>;
230
231    /// Provider-specific data key generation.
232    ///
233    /// # Arguments
234    /// * `key_id` - Master key identifier
235    /// * `context` - AAD context (never empty)
236    ///
237    /// # Returns
238    /// Tuple of (plaintext_key, encrypted_key_hex)
239    async fn do_generate_data_key(
240        &self,
241        key_id: &str,
242        context: &HashMap<String, String>,
243    ) -> KmsResult<(Vec<u8>, String)>;
244
245    /// Provider-specific key rotation.
246    async fn do_rotate_key(&self, key_id: &str) -> KmsResult<()>;
247
248    /// Provider-specific key info retrieval.
249    ///
250    /// Returns KeyInfo struct with alias and created_at
251    async fn do_get_key_info(&self, key_id: &str) -> KmsResult<KeyInfo>;
252
253    /// Provider-specific rotation policy retrieval.
254    async fn do_get_rotation_policy(&self, key_id: &str) -> KmsResult<RotationPolicyInfo>;
255}
256
/// Key metadata returned by a provider's `do_get_key_info` hook.
///
/// Derives `PartialEq`/`Eq` so provider implementations and tests can compare results directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KeyInfo {
    /// Alias reported by the provider, or `None` when the key has no alias.
    pub alias:      Option<String>,
    /// Key creation time, used as the creation timestamp of the resulting `KeyReference`.
    /// NOTE(review): presumably Unix seconds, like `current_timestamp`; confirm per provider.
    pub created_at: i64,
}
263
/// Rotation policy returned by a provider's `do_get_rotation_policy` hook.
///
/// Each field is copied one-to-one into the public `RotationPolicy`. Derives `PartialEq`/`Eq`
/// so provider implementations and tests can compare results directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RotationPolicyInfo {
    /// Whether automatic rotation is enabled for the key.
    pub enabled:              bool,
    /// Rotation interval in days.
    pub rotation_period_days: u32,
    /// Time of the most recent rotation, if any.
    /// NOTE(review): presumably a Unix timestamp in seconds; confirm per provider.
    pub last_rotation:        Option<i64>,
    /// Time of the next scheduled rotation, if any.
    /// NOTE(review): presumably a Unix timestamp in seconds; confirm per provider.
    pub next_rotation:        Option<i64>,
}
272
#[cfg(test)]
mod tests {
    use super::*;

    // `current_timestamp` is synchronous, so a plain `#[test]` suffices; no async runtime needed.
    #[test]
    fn test_current_timestamp_is_positive() {
        let ts = current_timestamp();
        assert!(ts > 0);
    }

    #[test]
    fn test_current_timestamp_tracks_system_clock() {
        let expected_secs = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("system clock is before the Unix epoch")
            .as_secs();
        let expected = i64::try_from(expected_secs).expect("timestamp fits in i64");

        let ts = current_timestamp();

        // Allow a small window for the clock to tick between the two reads.
        assert!((ts - expected).abs() <= 2, "timestamp {ts} too far from {expected}");
    }
}