// fraiseql_core/security/kms/vault.rs
use std::collections::{BTreeMap, HashMap};
use std::time::Duration;

use serde_json::json;

use crate::security::kms::{
    base::{BaseKmsProvider, KeyInfo, RotationPolicyInfo},
    error::{KmsError, KmsResult},
};

/// Connection settings for a HashiCorp Vault Transit secrets engine.
///
/// `Debug` is implemented by hand so that the token never appears in logs or panic
/// messages that format the configuration.
#[derive(Clone)]
pub struct VaultConfig {
    /// Base address of the Vault server, e.g. `https://vault.example.com`.
    pub vault_addr: String,
    /// Token sent in the `X-Vault-Token` header.
    pub token: String,
    /// Mount path of the Transit engine (default `"transit"`).
    pub mount_path: String,
    /// Optional namespace sent in the `X-Vault-Namespace` header.
    pub namespace: Option<String>,
    /// Whether the server's TLS certificate should be verified (default `true`).
    pub verify_tls: bool,
    /// Per-request timeout in seconds (default `30`).
    pub timeout: u64,
}

impl std::fmt::Debug for VaultConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("VaultConfig")
            .field("vault_addr", &self.vault_addr)
            // Never print the credential itself.
            .field("token", &"<redacted>")
            .field("mount_path", &self.mount_path)
            .field("namespace", &self.namespace)
            .field("verify_tls", &self.verify_tls)
            .field("timeout", &self.timeout)
            .finish()
    }
}

impl VaultConfig {
    /// Creates a configuration for `vault_addr` authenticated with `token`.
    ///
    /// Defaults: mount path `"transit"`, no namespace, TLS verification on, 30 s timeout.
    #[must_use]
    pub fn new(vault_addr: String, token: String) -> Self {
        Self {
            vault_addr,
            token,
            mount_path: "transit".to_string(),
            namespace: None,
            verify_tls: true,
            timeout: 30,
        }
    }

    /// Sets the Transit engine mount path (leading/trailing `/` are tolerated).
    #[must_use]
    pub fn with_mount_path(mut self, mount_path: String) -> Self {
        self.mount_path = mount_path;
        self
    }

    /// Sets the Vault namespace sent with every request.
    #[must_use]
    pub fn with_namespace(mut self, namespace: String) -> Self {
        self.namespace = Some(namespace);
        self
    }

    /// Enables or disables TLS certificate verification.
    #[must_use]
    pub fn with_verify_tls(mut self, verify_tls: bool) -> Self {
        self.verify_tls = verify_tls;
        self
    }

    /// Sets the per-request timeout in seconds.
    #[must_use]
    pub fn with_timeout(mut self, timeout: u64) -> Self {
        self.timeout = timeout;
        self
    }

    /// Builds `<vault_addr>/v1/<mount_path>/<path>`.
    ///
    /// Slashes around the address and the mount path are trimmed so user-supplied values
    /// such as `https://vault/` or `/transit/` do not produce empty (`//`) path segments.
    fn api_url(&self, path: &str) -> String {
        let addr = self.vault_addr.trim_end_matches('/');
        let mount = self.mount_path.trim_matches('/');
        format!("{}/v1/{}/{}", addr, mount, path)
    }
}

86pub struct VaultKmsProvider {
93 config: VaultConfig,
94 client: reqwest::Client,
95}
96
97impl VaultKmsProvider {
98 pub fn new(config: VaultConfig) -> KmsResult<Self> {
100 let client = reqwest::Client::new();
101 Ok(Self { config, client })
102 }
103
104 fn build_headers(&self) -> reqwest::header::HeaderMap {
106 let mut headers = reqwest::header::HeaderMap::new();
107
108 headers.insert(
109 "X-Vault-Token",
110 reqwest::header::HeaderValue::from_str(&self.config.token)
111 .unwrap_or_else(|_| reqwest::header::HeaderValue::from_static("")),
112 );
113
114 if let Some(namespace) = &self.config.namespace {
115 headers.insert(
116 "X-Vault-Namespace",
117 reqwest::header::HeaderValue::from_str(namespace)
118 .unwrap_or_else(|_| reqwest::header::HeaderValue::from_static("")),
119 );
120 }
121
122 headers
123 }
124}
125
126#[async_trait::async_trait]
127impl BaseKmsProvider for VaultKmsProvider {
128 fn provider_name(&self) -> &'static str {
129 "vault"
130 }
131
132 async fn do_encrypt(
133 &self,
134 plaintext: &[u8],
135 key_id: &str,
136 context: &HashMap<String, String>,
137 ) -> KmsResult<(String, String)> {
138 let url = self.config.api_url(&format!("encrypt/{}", key_id));
139
140 let plaintext_b64 = base64_encode(plaintext);
141
142 let mut payload = json!({
143 "plaintext": plaintext_b64,
144 });
145
146 if !context.is_empty() {
148 let context_json =
149 serde_json::to_string(context).map_err(|e| KmsError::SerializationError {
150 message: e.to_string(),
151 })?;
152 let context_b64 = base64_encode(context_json.as_bytes());
153 payload["context"] = json!(context_b64);
154 }
155
156 let response = self
157 .client
158 .post(&url)
159 .headers(self.build_headers())
160 .json(&payload)
161 .timeout(std::time::Duration::from_secs(self.config.timeout))
162 .send()
163 .await
164 .map_err(|e| KmsError::ProviderConnectionError {
165 message: e.to_string(),
166 })?;
167
168 if !response.status().is_success() {
169 return Err(KmsError::EncryptionFailed {
170 message: format!("Vault returned status {}", response.status()),
171 });
172 }
173
174 let data = response.json::<serde_json::Value>().await.map_err(|e| {
175 KmsError::SerializationError {
176 message: e.to_string(),
177 }
178 })?;
179
180 let ciphertext = data["data"]["ciphertext"]
181 .as_str()
182 .ok_or_else(|| KmsError::EncryptionFailed {
183 message: "No ciphertext in Vault response".to_string(),
184 })?
185 .to_string();
186
187 Ok((ciphertext, "aes256-gcm96".to_string()))
188 }
189
190 async fn do_decrypt(
191 &self,
192 ciphertext: &str,
193 key_id: &str,
194 context: &HashMap<String, String>,
195 ) -> KmsResult<Vec<u8>> {
196 let url = self.config.api_url(&format!("decrypt/{}", key_id));
197
198 let mut payload = json!({
199 "ciphertext": ciphertext,
200 });
201
202 if !context.is_empty() {
204 let context_json =
205 serde_json::to_string(context).map_err(|e| KmsError::SerializationError {
206 message: e.to_string(),
207 })?;
208 let context_b64 = base64_encode(context_json.as_bytes());
209 payload["context"] = json!(context_b64);
210 }
211
212 let response = self
213 .client
214 .post(&url)
215 .headers(self.build_headers())
216 .json(&payload)
217 .timeout(std::time::Duration::from_secs(self.config.timeout))
218 .send()
219 .await
220 .map_err(|e| KmsError::ProviderConnectionError {
221 message: e.to_string(),
222 })?;
223
224 if !response.status().is_success() {
225 return Err(KmsError::DecryptionFailed {
226 message: format!("Vault returned status {}", response.status()),
227 });
228 }
229
230 let data = response.json::<serde_json::Value>().await.map_err(|e| {
231 KmsError::SerializationError {
232 message: e.to_string(),
233 }
234 })?;
235
236 let plaintext_b64 =
237 data["data"]["plaintext"].as_str().ok_or_else(|| KmsError::DecryptionFailed {
238 message: "No plaintext in Vault response".to_string(),
239 })?;
240
241 base64_decode(plaintext_b64).map_err(|_| KmsError::DecryptionFailed {
242 message: "Failed to decode plaintext from Vault".to_string(),
243 })
244 }
245
246 async fn do_generate_data_key(
247 &self,
248 key_id: &str,
249 context: &HashMap<String, String>,
250 ) -> KmsResult<(Vec<u8>, String)> {
251 let url = self.config.api_url(&format!("datakey/plaintext/{}", key_id));
252
253 let mut payload = json!({
254 "bits": 256, });
256
257 if !context.is_empty() {
259 let context_json =
260 serde_json::to_string(context).map_err(|e| KmsError::SerializationError {
261 message: e.to_string(),
262 })?;
263 let context_b64 = base64_encode(context_json.as_bytes());
264 payload["context"] = json!(context_b64);
265 }
266
267 let response = self
268 .client
269 .post(&url)
270 .headers(self.build_headers())
271 .json(&payload)
272 .timeout(std::time::Duration::from_secs(self.config.timeout))
273 .send()
274 .await
275 .map_err(|e| KmsError::ProviderConnectionError {
276 message: e.to_string(),
277 })?;
278
279 if !response.status().is_success() {
280 return Err(KmsError::EncryptionFailed {
281 message: format!("Vault returned status {}", response.status()),
282 });
283 }
284
285 let data = response.json::<serde_json::Value>().await.map_err(|e| {
286 KmsError::SerializationError {
287 message: e.to_string(),
288 }
289 })?;
290
291 let plaintext_b64 =
292 data["data"]["plaintext"].as_str().ok_or_else(|| KmsError::EncryptionFailed {
293 message: "No plaintext key in Vault response".to_string(),
294 })?;
295
296 let plaintext_key =
297 base64_decode(plaintext_b64).map_err(|_| KmsError::EncryptionFailed {
298 message: "Failed to decode plaintext key from Vault".to_string(),
299 })?;
300
301 let ciphertext = data["data"]["ciphertext"]
302 .as_str()
303 .ok_or_else(|| KmsError::EncryptionFailed {
304 message: "No encrypted key in Vault response".to_string(),
305 })?
306 .to_string();
307
308 Ok((plaintext_key, ciphertext))
309 }
310
311 async fn do_rotate_key(&self, key_id: &str) -> KmsResult<()> {
312 let url = self.config.api_url(&format!("keys/{}/rotate", key_id));
313
314 let response = self
315 .client
316 .post(&url)
317 .headers(self.build_headers())
318 .json(&json!({}))
319 .timeout(std::time::Duration::from_secs(self.config.timeout))
320 .send()
321 .await
322 .map_err(|e| KmsError::ProviderConnectionError {
323 message: e.to_string(),
324 })?;
325
326 if !response.status().is_success() {
327 return Err(KmsError::RotationFailed {
328 message: format!("Vault returned status {}", response.status()),
329 });
330 }
331
332 Ok(())
333 }
334
335 async fn do_get_key_info(&self, key_id: &str) -> KmsResult<KeyInfo> {
336 let url = self.config.api_url(&format!("keys/{}", key_id));
337
338 let response = self
339 .client
340 .get(&url)
341 .headers(self.build_headers())
342 .timeout(std::time::Duration::from_secs(self.config.timeout))
343 .send()
344 .await
345 .map_err(|e| KmsError::ProviderConnectionError {
346 message: e.to_string(),
347 })?;
348
349 if response.status() == 404 {
350 return Err(KmsError::KeyNotFound {
351 key_id: key_id.to_string(),
352 });
353 }
354
355 if !response.status().is_success() {
356 return Err(KmsError::ProviderConnectionError {
357 message: format!("Vault returned status {}", response.status()),
358 });
359 }
360
361 let data = response.json::<serde_json::Value>().await.map_err(|e| {
362 KmsError::SerializationError {
363 message: e.to_string(),
364 }
365 })?;
366
367 let key_data = &data["data"];
368 let alias = key_data["name"].as_str().map(|s| s.to_string());
369 let created_at = key_data["creation_time"]
370 .as_i64()
371 .unwrap_or_else(|| chrono::Utc::now().timestamp());
372
373 Ok(KeyInfo { alias, created_at })
374 }
375
376 async fn do_get_rotation_policy(&self, key_id: &str) -> KmsResult<RotationPolicyInfo> {
377 let url = self.config.api_url(&format!("keys/{}", key_id));
378
379 let response = self
380 .client
381 .get(&url)
382 .headers(self.build_headers())
383 .timeout(std::time::Duration::from_secs(self.config.timeout))
384 .send()
385 .await
386 .map_err(|e| KmsError::ProviderConnectionError {
387 message: e.to_string(),
388 })?;
389
390 if response.status() == 404 {
391 return Err(KmsError::KeyNotFound {
392 key_id: key_id.to_string(),
393 });
394 }
395
396 if !response.status().is_success() {
397 return Err(KmsError::ProviderConnectionError {
398 message: format!("Vault returned status {}", response.status()),
399 });
400 }
401
402 let _data = response.json::<serde_json::Value>().await.map_err(|e| {
403 KmsError::SerializationError {
404 message: e.to_string(),
405 }
406 })?;
407
408 Ok(RotationPolicyInfo {
411 enabled: false,
412 rotation_period_days: 0,
413 last_rotation: None,
414 next_rotation: None,
415 })
416 }
417}
418
419fn base64_encode(data: &[u8]) -> String {
421 use base64::prelude::*;
422 BASE64_STANDARD.encode(data)
423}
424
425fn base64_decode(s: &str) -> Result<Vec<u8>, base64::DecodeError> {
427 use base64::prelude::*;
428 BASE64_STANDARD.decode(s)
429}
430
#[cfg(test)]
mod tests {
    use super::*;

    /// Configuration pointing at a fixed test address with default settings.
    fn example_config() -> VaultConfig {
        VaultConfig::new("https://vault.example.com".to_string(), "token123".to_string())
    }

    #[test]
    fn test_vault_config_api_url() {
        let config = example_config();
        assert_eq!(
            config.api_url("encrypt/my-key"),
            "https://vault.example.com/v1/transit/encrypt/my-key"
        );
    }

    #[test]
    fn test_vault_config_custom_mount_path() {
        let config = example_config().with_mount_path("custom-transit".to_string());

        assert_eq!(
            config.api_url("encrypt/my-key"),
            "https://vault.example.com/v1/custom-transit/encrypt/my-key"
        );
    }

    #[test]
    fn test_vault_config_trailing_slash_in_addr() {
        let config =
            VaultConfig::new("https://vault.example.com/".to_string(), "token123".to_string());

        assert_eq!(
            config.api_url("encrypt/my-key"),
            "https://vault.example.com/v1/transit/encrypt/my-key"
        );
    }

    #[test]
    fn test_vault_config_defaults() {
        let config = example_config();

        assert_eq!(config.mount_path, "transit");
        assert_eq!(config.namespace, None);
        assert!(config.verify_tls);
        assert_eq!(config.timeout, 30);
    }

    #[test]
    fn test_vault_config_builders() {
        let config = example_config()
            .with_namespace("team-a".to_string())
            .with_verify_tls(false)
            .with_timeout(5);

        assert_eq!(config.namespace.as_deref(), Some("team-a"));
        assert!(!config.verify_tls);
        assert_eq!(config.timeout, 5);
    }

    #[test]
    fn test_base64_roundtrip() {
        let data = b"hello world";
        let encoded = base64_encode(data);
        assert_eq!(encoded, "aGVsbG8gd29ybGQ=");
        let decoded = base64_decode(&encoded).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_base64_empty() {
        assert_eq!(base64_encode(b""), "");
        assert!(base64_decode("").unwrap().is_empty());
    }
}