Skip to main content

fraiseql_core/security/kms/
vault.rs

1//! HashiCorp Vault Transit secrets engine provider.
2
3use std::collections::HashMap;
4
5use serde_json::json;
6
7use crate::security::kms::{
8    base::{BaseKmsProvider, KeyInfo, RotationPolicyInfo},
9    error::{KmsError, KmsResult},
10};
11
/// Configuration for Vault KMS provider.
///
/// # Security Considerations
/// Token Handling:
/// - The Vault token is stored in memory for the provider's lifetime
/// - The `Debug` representation redacts the token so it cannot leak into logs
/// - For production deployments, consider:
///   1. Using short-lived tokens with automatic renewal
///   2. Vault Agent with auto-auth for token management
///   3. AppRole authentication with response wrapping
///   4. Kubernetes auth method in K8s environments
#[derive(Clone)]
pub struct VaultConfig {
    /// Vault server address (e.g., `https://vault.example.com`)
    pub vault_addr: String,
    /// Vault authentication token
    pub token:      String,
    /// Transit mount path (default: "transit")
    pub mount_path: String,
    /// Optional Vault namespace
    pub namespace:  Option<String>,
    /// Verify TLS certificates (default: true)
    pub verify_tls: bool,
    /// Request timeout in seconds (default: 30)
    pub timeout:    u64,
}

/// Manual `Debug` so the authentication token is never printed.
impl std::fmt::Debug for VaultConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("VaultConfig")
            .field("vault_addr", &self.vault_addr)
            .field("token", &"<redacted>")
            .field("mount_path", &self.mount_path)
            .field("namespace", &self.namespace)
            .field("verify_tls", &self.verify_tls)
            .field("timeout", &self.timeout)
            .finish()
    }
}

impl VaultConfig {
    /// Create a new Vault configuration.
    ///
    /// Defaults: mount path `"transit"`, no namespace, TLS verification on,
    /// 30 second request timeout.
    pub fn new(vault_addr: String, token: String) -> Self {
        Self {
            vault_addr,
            token,
            mount_path: "transit".to_string(),
            namespace: None,
            verify_tls: true,
            timeout: 30,
        }
    }

    /// Set the transit mount path.
    ///
    /// Leading/trailing slashes are tolerated; they are stripped when URLs are built.
    #[must_use]
    pub fn with_mount_path(mut self, mount_path: String) -> Self {
        self.mount_path = mount_path;
        self
    }

    /// Set the Vault namespace.
    #[must_use]
    pub fn with_namespace(mut self, namespace: String) -> Self {
        self.namespace = Some(namespace);
        self
    }

    /// Set TLS verification.
    #[must_use]
    pub fn with_verify_tls(mut self, verify_tls: bool) -> Self {
        self.verify_tls = verify_tls;
        self
    }

    /// Set request timeout in seconds.
    #[must_use]
    pub fn with_timeout(mut self, timeout: u64) -> Self {
        self.timeout = timeout;
        self
    }

    /// Build the full API URL `{vault_addr}/v1/{mount_path}/{path}`.
    ///
    /// Surrounding slashes on each component are trimmed so that values such as
    /// `https://vault/` or `/transit/` do not produce empty path segments.
    fn api_url(&self, path: &str) -> String {
        let addr = self.vault_addr.trim_end_matches('/');
        let mount = self.mount_path.trim_matches('/');
        let path = path.trim_start_matches('/');
        format!("{}/v1/{}/{}", addr, mount, path)
    }
}
85
86/// HashiCorp Vault Transit secrets engine provider.
87///
88/// Uses Vault's Transit secrets engine for encryption/decryption operations.
89/// Supports envelope encryption via data key generation.
90///
91/// All operations use authenticated encryption (AES-256-GCM).
92pub struct VaultKmsProvider {
93    config: VaultConfig,
94    client: reqwest::Client,
95}
96
97impl VaultKmsProvider {
98    /// Create a new Vault KMS provider.
99    pub fn new(config: VaultConfig) -> KmsResult<Self> {
100        let client = reqwest::Client::new();
101        Ok(Self { config, client })
102    }
103
104    /// Build a request with Vault headers.
105    fn build_headers(&self) -> reqwest::header::HeaderMap {
106        let mut headers = reqwest::header::HeaderMap::new();
107
108        headers.insert(
109            "X-Vault-Token",
110            reqwest::header::HeaderValue::from_str(&self.config.token)
111                .unwrap_or_else(|_| reqwest::header::HeaderValue::from_static("")),
112        );
113
114        if let Some(namespace) = &self.config.namespace {
115            headers.insert(
116                "X-Vault-Namespace",
117                reqwest::header::HeaderValue::from_str(namespace)
118                    .unwrap_or_else(|_| reqwest::header::HeaderValue::from_static("")),
119            );
120        }
121
122        headers
123    }
124}
125
126#[async_trait::async_trait]
127impl BaseKmsProvider for VaultKmsProvider {
128    fn provider_name(&self) -> &'static str {
129        "vault"
130    }
131
132    async fn do_encrypt(
133        &self,
134        plaintext: &[u8],
135        key_id: &str,
136        context: &HashMap<String, String>,
137    ) -> KmsResult<(String, String)> {
138        let url = self.config.api_url(&format!("encrypt/{}", key_id));
139
140        let plaintext_b64 = base64_encode(plaintext);
141
142        let mut payload = json!({
143            "plaintext": plaintext_b64,
144        });
145
146        // Add context if provided (used for key derivation)
147        if !context.is_empty() {
148            let context_json =
149                serde_json::to_string(context).map_err(|e| KmsError::SerializationError {
150                    message: e.to_string(),
151                })?;
152            let context_b64 = base64_encode(context_json.as_bytes());
153            payload["context"] = json!(context_b64);
154        }
155
156        let response = self
157            .client
158            .post(&url)
159            .headers(self.build_headers())
160            .json(&payload)
161            .timeout(std::time::Duration::from_secs(self.config.timeout))
162            .send()
163            .await
164            .map_err(|e| KmsError::ProviderConnectionError {
165                message: e.to_string(),
166            })?;
167
168        if !response.status().is_success() {
169            return Err(KmsError::EncryptionFailed {
170                message: format!("Vault returned status {}", response.status()),
171            });
172        }
173
174        let data = response.json::<serde_json::Value>().await.map_err(|e| {
175            KmsError::SerializationError {
176                message: e.to_string(),
177            }
178        })?;
179
180        let ciphertext = data["data"]["ciphertext"]
181            .as_str()
182            .ok_or_else(|| KmsError::EncryptionFailed {
183                message: "No ciphertext in Vault response".to_string(),
184            })?
185            .to_string();
186
187        Ok((ciphertext, "aes256-gcm96".to_string()))
188    }
189
190    async fn do_decrypt(
191        &self,
192        ciphertext: &str,
193        key_id: &str,
194        context: &HashMap<String, String>,
195    ) -> KmsResult<Vec<u8>> {
196        let url = self.config.api_url(&format!("decrypt/{}", key_id));
197
198        let mut payload = json!({
199            "ciphertext": ciphertext,
200        });
201
202        // Add context if provided
203        if !context.is_empty() {
204            let context_json =
205                serde_json::to_string(context).map_err(|e| KmsError::SerializationError {
206                    message: e.to_string(),
207                })?;
208            let context_b64 = base64_encode(context_json.as_bytes());
209            payload["context"] = json!(context_b64);
210        }
211
212        let response = self
213            .client
214            .post(&url)
215            .headers(self.build_headers())
216            .json(&payload)
217            .timeout(std::time::Duration::from_secs(self.config.timeout))
218            .send()
219            .await
220            .map_err(|e| KmsError::ProviderConnectionError {
221                message: e.to_string(),
222            })?;
223
224        if !response.status().is_success() {
225            return Err(KmsError::DecryptionFailed {
226                message: format!("Vault returned status {}", response.status()),
227            });
228        }
229
230        let data = response.json::<serde_json::Value>().await.map_err(|e| {
231            KmsError::SerializationError {
232                message: e.to_string(),
233            }
234        })?;
235
236        let plaintext_b64 =
237            data["data"]["plaintext"].as_str().ok_or_else(|| KmsError::DecryptionFailed {
238                message: "No plaintext in Vault response".to_string(),
239            })?;
240
241        base64_decode(plaintext_b64).map_err(|_| KmsError::DecryptionFailed {
242            message: "Failed to decode plaintext from Vault".to_string(),
243        })
244    }
245
246    async fn do_generate_data_key(
247        &self,
248        key_id: &str,
249        context: &HashMap<String, String>,
250    ) -> KmsResult<(Vec<u8>, String)> {
251        let url = self.config.api_url(&format!("datakey/plaintext/{}", key_id));
252
253        let mut payload = json!({
254            "bits": 256,  // AES-256
255        });
256
257        // Add context if provided
258        if !context.is_empty() {
259            let context_json =
260                serde_json::to_string(context).map_err(|e| KmsError::SerializationError {
261                    message: e.to_string(),
262                })?;
263            let context_b64 = base64_encode(context_json.as_bytes());
264            payload["context"] = json!(context_b64);
265        }
266
267        let response = self
268            .client
269            .post(&url)
270            .headers(self.build_headers())
271            .json(&payload)
272            .timeout(std::time::Duration::from_secs(self.config.timeout))
273            .send()
274            .await
275            .map_err(|e| KmsError::ProviderConnectionError {
276                message: e.to_string(),
277            })?;
278
279        if !response.status().is_success() {
280            return Err(KmsError::EncryptionFailed {
281                message: format!("Vault returned status {}", response.status()),
282            });
283        }
284
285        let data = response.json::<serde_json::Value>().await.map_err(|e| {
286            KmsError::SerializationError {
287                message: e.to_string(),
288            }
289        })?;
290
291        let plaintext_b64 =
292            data["data"]["plaintext"].as_str().ok_or_else(|| KmsError::EncryptionFailed {
293                message: "No plaintext key in Vault response".to_string(),
294            })?;
295
296        let plaintext_key =
297            base64_decode(plaintext_b64).map_err(|_| KmsError::EncryptionFailed {
298                message: "Failed to decode plaintext key from Vault".to_string(),
299            })?;
300
301        let ciphertext = data["data"]["ciphertext"]
302            .as_str()
303            .ok_or_else(|| KmsError::EncryptionFailed {
304                message: "No encrypted key in Vault response".to_string(),
305            })?
306            .to_string();
307
308        Ok((plaintext_key, ciphertext))
309    }
310
311    async fn do_rotate_key(&self, key_id: &str) -> KmsResult<()> {
312        let url = self.config.api_url(&format!("keys/{}/rotate", key_id));
313
314        let response = self
315            .client
316            .post(&url)
317            .headers(self.build_headers())
318            .json(&json!({}))
319            .timeout(std::time::Duration::from_secs(self.config.timeout))
320            .send()
321            .await
322            .map_err(|e| KmsError::ProviderConnectionError {
323                message: e.to_string(),
324            })?;
325
326        if !response.status().is_success() {
327            return Err(KmsError::RotationFailed {
328                message: format!("Vault returned status {}", response.status()),
329            });
330        }
331
332        Ok(())
333    }
334
335    async fn do_get_key_info(&self, key_id: &str) -> KmsResult<KeyInfo> {
336        let url = self.config.api_url(&format!("keys/{}", key_id));
337
338        let response = self
339            .client
340            .get(&url)
341            .headers(self.build_headers())
342            .timeout(std::time::Duration::from_secs(self.config.timeout))
343            .send()
344            .await
345            .map_err(|e| KmsError::ProviderConnectionError {
346                message: e.to_string(),
347            })?;
348
349        if response.status() == 404 {
350            return Err(KmsError::KeyNotFound {
351                key_id: key_id.to_string(),
352            });
353        }
354
355        if !response.status().is_success() {
356            return Err(KmsError::ProviderConnectionError {
357                message: format!("Vault returned status {}", response.status()),
358            });
359        }
360
361        let data = response.json::<serde_json::Value>().await.map_err(|e| {
362            KmsError::SerializationError {
363                message: e.to_string(),
364            }
365        })?;
366
367        let key_data = &data["data"];
368        let alias = key_data["name"].as_str().map(|s| s.to_string());
369        let created_at = key_data["creation_time"]
370            .as_i64()
371            .unwrap_or_else(|| chrono::Utc::now().timestamp());
372
373        Ok(KeyInfo { alias, created_at })
374    }
375
376    async fn do_get_rotation_policy(&self, key_id: &str) -> KmsResult<RotationPolicyInfo> {
377        let url = self.config.api_url(&format!("keys/{}", key_id));
378
379        let response = self
380            .client
381            .get(&url)
382            .headers(self.build_headers())
383            .timeout(std::time::Duration::from_secs(self.config.timeout))
384            .send()
385            .await
386            .map_err(|e| KmsError::ProviderConnectionError {
387                message: e.to_string(),
388            })?;
389
390        if response.status() == 404 {
391            return Err(KmsError::KeyNotFound {
392                key_id: key_id.to_string(),
393            });
394        }
395
396        if !response.status().is_success() {
397            return Err(KmsError::ProviderConnectionError {
398                message: format!("Vault returned status {}", response.status()),
399            });
400        }
401
402        let _data = response.json::<serde_json::Value>().await.map_err(|e| {
403            KmsError::SerializationError {
404                message: e.to_string(),
405            }
406        })?;
407
408        // Vault doesn't have explicit rotation policies in transit engine
409        // Return disabled by default
410        Ok(RotationPolicyInfo {
411            enabled:              false,
412            rotation_period_days: 0,
413            last_rotation:        None,
414            next_rotation:        None,
415        })
416    }
417}
418
419/// Encode bytes as base64.
420fn base64_encode(data: &[u8]) -> String {
421    use base64::prelude::*;
422    BASE64_STANDARD.encode(data)
423}
424
425/// Decode base64 to bytes.
426fn base64_decode(s: &str) -> Result<Vec<u8>, base64::DecodeError> {
427    use base64::prelude::*;
428    BASE64_STANDARD.decode(s)
429}
430
#[cfg(test)]
mod tests {
    use super::*;

    /// Configuration for a fake Vault server with a dummy token.
    fn example_config() -> VaultConfig {
        VaultConfig::new("https://vault.example.com".to_string(), "token123".to_string())
    }

    #[test]
    fn test_vault_config_api_url() {
        let url = example_config().api_url("encrypt/my-key");
        assert_eq!(url, "https://vault.example.com/v1/transit/encrypt/my-key");
    }

    #[test]
    fn test_vault_config_custom_mount_path() {
        let url = example_config()
            .with_mount_path("custom-transit".to_string())
            .api_url("encrypt/my-key");
        assert_eq!(url, "https://vault.example.com/v1/custom-transit/encrypt/my-key");
    }

    #[test]
    fn test_base64_roundtrip() {
        let original: &[u8] = b"hello world";
        let roundtripped =
            base64_decode(&base64_encode(original)).expect("freshly encoded data must decode");
        assert_eq!(roundtripped, original);
    }
}