1#![allow(clippy::cognitive_complexity, clippy::too_many_lines)]
5
6use async_trait::async_trait;
7use aws_sdk_secretsmanager::Client;
8use aws_smithy_http_client::{Builder as SmithyHttpClientBuilder, tls};
9use cuenv_secrets::{SecretError, SecretResolver, SecretSpec, SecureSecret};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use tokio::process::Command;
13
/// Configuration for a single AWS Secrets Manager lookup.
///
/// Serialized with camelCase keys, e.g.
/// `{"secretId": "my-secret", "versionStage": "AWSCURRENT", "jsonKey": "password"}`.
/// Optional fields that are `None` are omitted when serializing.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct AwsSecretConfig {
    /// Secret name or ARN, sent to Secrets Manager as the secret ID.
    pub secret_id: String,

    /// Specific version ID to request, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub version_id: Option<String>,

    /// Staging label to request (for example `AWSCURRENT`), if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub version_stage: Option<String>,

    /// When set, the secret string is parsed as JSON and only the value stored under
    /// this key is returned.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub json_key: Option<String>,
}
33
34impl AwsSecretConfig {
35 #[must_use]
37 pub fn new(secret_id: impl Into<String>) -> Self {
38 Self {
39 secret_id: secret_id.into(),
40 version_id: None,
41 version_stage: None,
42 json_key: None,
43 }
44 }
45}
46
/// Secret resolver backed by AWS Secrets Manager.
///
/// Lookups go through the AWS SDK when `http_client` is populated, and through the
/// `aws` CLI otherwise.
pub struct AwsResolver {
    /// SDK client; `Some` only when `AwsResolver::new` found both static-credential
    /// environment variables.
    http_client: Option<Client>,
}
59
60impl std::fmt::Debug for AwsResolver {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 f.debug_struct("AwsResolver")
63 .field("mode", &if self.can_use_http() { "http" } else { "cli" })
64 .finish()
65 }
66}
67
68impl AwsResolver {
69 pub async fn new() -> Result<Self, SecretError> {
77 let http_client = if Self::http_credentials_available() {
78 let http_client = SmithyHttpClientBuilder::new()
80 .tls_provider(tls::Provider::Rustls(
81 tls::rustls_provider::CryptoMode::Ring,
82 ))
83 .build_https();
84 let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
85 .http_client(http_client)
86 .load()
87 .await;
88 Some(Client::new(&config))
89 } else {
90 None
91 };
92
93 Ok(Self { http_client })
94 }
95
96 fn http_credentials_available() -> bool {
98 std::env::var("AWS_ACCESS_KEY_ID").is_ok() && std::env::var("AWS_SECRET_ACCESS_KEY").is_ok()
99 }
100
101 const fn can_use_http(&self) -> bool {
103 self.http_client.is_some()
104 }
105
106 async fn resolve_http(
108 &self,
109 name: &str,
110 config: &AwsSecretConfig,
111 ) -> Result<String, SecretError> {
112 let client = self
113 .http_client
114 .as_ref()
115 .ok_or_else(|| SecretError::ResolutionFailed {
116 name: name.to_string(),
117 message: "HTTP client not available".to_string(),
118 })?;
119
120 let mut request = client.get_secret_value().secret_id(&config.secret_id);
121
122 if let Some(version_id) = &config.version_id {
123 request = request.version_id(version_id);
124 }
125
126 if let Some(version_stage) = &config.version_stage {
127 request = request.version_stage(version_stage);
128 }
129
130 let response = request
131 .send()
132 .await
133 .map_err(|e| SecretError::ResolutionFailed {
134 name: name.to_string(),
135 message: format!("AWS Secrets Manager error: {e}"),
136 })?;
137
138 let secret_string =
139 response
140 .secret_string()
141 .ok_or_else(|| SecretError::ResolutionFailed {
142 name: name.to_string(),
143 message: "Secret has no string value (may be binary)".to_string(),
144 })?;
145
146 Self::extract_json_key(name, secret_string, config.json_key.as_ref())
147 }
148
149 async fn resolve_cli(
151 &self,
152 name: &str,
153 config: &AwsSecretConfig,
154 ) -> Result<String, SecretError> {
155 let mut args = vec![
156 "secretsmanager".to_string(),
157 "get-secret-value".to_string(),
158 "--secret-id".to_string(),
159 config.secret_id.clone(),
160 "--query".to_string(),
161 "SecretString".to_string(),
162 "--output".to_string(),
163 "text".to_string(),
164 ];
165
166 if let Some(version_id) = &config.version_id {
167 args.push("--version-id".to_string());
168 args.push(version_id.clone());
169 }
170
171 if let Some(version_stage) = &config.version_stage {
172 args.push("--version-stage".to_string());
173 args.push(version_stage.clone());
174 }
175
176 let output = Command::new("aws")
177 .args(&args)
178 .output()
179 .await
180 .map_err(|e| SecretError::ResolutionFailed {
181 name: name.to_string(),
182 message: format!("Failed to execute aws CLI: {e}"),
183 })?;
184
185 if !output.status.success() {
186 let stderr = String::from_utf8_lossy(&output.stderr);
187 return Err(SecretError::ResolutionFailed {
188 name: name.to_string(),
189 message: format!("aws CLI failed: {stderr}"),
190 });
191 }
192
193 let secret_string = String::from_utf8_lossy(&output.stdout).trim().to_string();
194 Self::extract_json_key(name, &secret_string, config.json_key.as_ref())
195 }
196
197 fn extract_json_key(
199 name: &str,
200 secret_string: &str,
201 json_key: Option<&String>,
202 ) -> Result<String, SecretError> {
203 if let Some(key) = json_key {
204 let parsed: serde_json::Value =
205 serde_json::from_str(secret_string).map_err(|e| SecretError::ResolutionFailed {
206 name: name.to_string(),
207 message: format!("Secret is not valid JSON: {e}"),
208 })?;
209
210 let value = parsed
211 .get(key)
212 .ok_or_else(|| SecretError::ResolutionFailed {
213 name: name.to_string(),
214 message: format!("JSON key '{key}' not found in secret"),
215 })?;
216
217 return match value {
218 serde_json::Value::String(s) => Ok(s.clone()),
219 other => Ok(other.to_string()),
220 };
221 }
222
223 Ok(secret_string.to_string())
224 }
225
226 async fn resolve_with_config(
228 &self,
229 name: &str,
230 config: &AwsSecretConfig,
231 ) -> Result<String, SecretError> {
232 if self.http_client.is_some() {
234 return self.resolve_http(name, config).await;
235 }
236
237 self.resolve_cli(name, config).await
239 }
240
241 async fn resolve_batch_http(
243 &self,
244 secrets: &HashMap<String, SecretSpec>,
245 ) -> Result<HashMap<String, SecureSecret>, SecretError> {
246 use futures::future::try_join_all;
247
248 let client = self
249 .http_client
250 .as_ref()
251 .ok_or_else(|| SecretError::ResolutionFailed {
252 name: "batch".to_string(),
253 message: "HTTP client not available".to_string(),
254 })?;
255
256 let mut id_to_names: HashMap<String, Vec<(String, AwsSecretConfig)>> = HashMap::new();
259 for (name, spec) in secrets {
260 let config = serde_json::from_str::<AwsSecretConfig>(&spec.source)
261 .unwrap_or_else(|_| AwsSecretConfig::new(&spec.source));
262 id_to_names
263 .entry(config.secret_id.clone())
264 .or_default()
265 .push((name.clone(), config));
266 }
267
268 let secret_ids: Vec<String> = id_to_names.keys().cloned().collect();
270
271 let mut all_values: HashMap<String, String> = HashMap::new();
273
274 for chunk in secret_ids.chunks(20) {
275 let response = client
276 .batch_get_secret_value()
277 .set_secret_id_list(Some(chunk.to_vec()))
278 .send()
279 .await
280 .map_err(|e| SecretError::ResolutionFailed {
281 name: "batch".to_string(),
282 message: format!("AWS BatchGetSecretValue failed: {e}"),
283 })?;
284
285 for sv in response.secret_values() {
287 if let Some(secret_string) = sv.secret_string() {
288 if let Some(secret_name) = sv.name() {
290 all_values.insert(secret_name.to_string(), secret_string.to_string());
291 }
292 if let Some(arn) = sv.arn() {
293 all_values.insert(arn.to_string(), secret_string.to_string());
294 }
295 }
296 }
297
298 for err in response.errors() {
300 tracing::warn!(
301 secret_id = ?err.secret_id(),
302 error_code = ?err.error_code(),
303 message = ?err.message(),
304 "Failed to retrieve secret in batch"
305 );
306 }
307 }
308
309 let extract_futures: Vec<_> = secrets
311 .iter()
312 .map(|(name, spec)| {
313 let name = name.clone();
314 let all_values = &all_values;
315 async move {
316 let config = serde_json::from_str::<AwsSecretConfig>(&spec.source)
317 .unwrap_or_else(|_| AwsSecretConfig::new(&spec.source));
318
319 let secret_string = all_values.get(&config.secret_id).ok_or_else(|| {
321 SecretError::ResolutionFailed {
322 name: name.clone(),
323 message: format!(
324 "Secret '{}' not found in batch response",
325 config.secret_id
326 ),
327 }
328 })?;
329
330 let value =
332 Self::extract_json_key(&name, secret_string, config.json_key.as_ref())?;
333 Ok::<_, SecretError>((name, SecureSecret::new(value)))
334 }
335 })
336 .collect();
337
338 try_join_all(extract_futures)
339 .await
340 .map(|v| v.into_iter().collect())
341 }
342
343 async fn resolve_batch_cli(
345 &self,
346 secrets: &HashMap<String, SecretSpec>,
347 ) -> Result<HashMap<String, SecureSecret>, SecretError> {
348 use futures::future::try_join_all;
349
350 let futures: Vec<_> = secrets
351 .iter()
352 .map(|(name, spec)| {
353 let name = name.clone();
354 let spec = spec.clone();
355 async move {
356 let value = self.resolve(&name, &spec).await?;
357 Ok::<_, SecretError>((name, SecureSecret::new(value)))
358 }
359 })
360 .collect();
361
362 try_join_all(futures).await.map(|v| v.into_iter().collect())
363 }
364}
365
366#[async_trait]
367impl SecretResolver for AwsResolver {
368 fn provider_name(&self) -> &'static str {
369 "aws"
370 }
371
372 fn supports_native_batch(&self) -> bool {
373 true
375 }
376
377 async fn resolve(&self, name: &str, spec: &SecretSpec) -> Result<String, SecretError> {
378 if let Ok(config) = serde_json::from_str::<AwsSecretConfig>(&spec.source) {
380 return self.resolve_with_config(name, &config).await;
381 }
382
383 let config = AwsSecretConfig::new(&spec.source);
385 self.resolve_with_config(name, &config).await
386 }
387
388 async fn resolve_batch(
389 &self,
390 secrets: &HashMap<String, SecretSpec>,
391 ) -> Result<HashMap<String, SecureSecret>, SecretError> {
392 if secrets.is_empty() {
393 return Ok(HashMap::new());
394 }
395
396 if self.http_client.is_some() {
398 return self.resolve_batch_http(secrets).await;
399 }
400
401 self.resolve_batch_cli(secrets).await
403 }
404}
405
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when both static-credential env vars are set, i.e. when
    /// `AwsResolver::new` builds an HTTP client instead of using the CLI.
    fn env_has_static_credentials() -> bool {
        std::env::var("AWS_ACCESS_KEY_ID").is_ok()
            && std::env::var("AWS_SECRET_ACCESS_KEY").is_ok()
    }

    /// Builds a CLI-mode resolver.
    ///
    /// Returns `None`, so the caller can skip, when the environment would put the
    /// resolver into HTTP mode.
    async fn cli_resolver() -> Option<AwsResolver> {
        if env_has_static_credentials() {
            return None;
        }
        Some(
            AwsResolver::new()
                .await
                .expect("resolver construction without credentials should succeed"),
        )
    }

    /// Calls `extract_json_key` with a key and a fixed secret name.
    fn extract(secret: &str, key: &str) -> Result<String, SecretError> {
        AwsResolver::extract_json_key("test", secret, Some(&key.to_string()))
    }

    /// Unwraps the message of a `ResolutionFailed` error, panicking on anything else.
    fn failure_message(result: Result<String, SecretError>) -> String {
        match result {
            Err(SecretError::ResolutionFailed { message, .. }) => message,
            other => panic!("Expected ResolutionFailed error, got {other:?}"),
        }
    }

    #[test]
    fn test_aws_config_serialization() {
        let config = AwsSecretConfig {
            secret_id: "my-secret".to_string(),
            version_id: Some("v1".to_string()),
            version_stage: None,
            json_key: Some("password".to_string()),
        };

        let json = serde_json::to_string(&config).unwrap();
        let parsed: AwsSecretConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(config, parsed);
    }

    #[test]
    fn test_simple_config() {
        let config = AwsSecretConfig::new("arn:aws:secretsmanager:us-east-1:123456:secret:test");
        assert_eq!(
            config.secret_id,
            "arn:aws:secretsmanager:us-east-1:123456:secret:test"
        );
        assert!(config.version_id.is_none());
        assert!(config.json_key.is_none());
    }

    #[test]
    fn test_aws_config_new_with_string_slice() {
        let config = AwsSecretConfig::new("my-secret");
        assert_eq!(config.secret_id, "my-secret");
        assert!(config.version_id.is_none());
        assert!(config.version_stage.is_none());
        assert!(config.json_key.is_none());
    }

    #[test]
    fn test_aws_config_full_serialization() {
        let config = AwsSecretConfig {
            secret_id: "my-secret".to_string(),
            version_id: Some("abc123".to_string()),
            version_stage: Some("AWSCURRENT".to_string()),
            json_key: Some("api_key".to_string()),
        };

        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"secretId\":\"my-secret\""));
        assert!(json.contains("\"versionId\":\"abc123\""));
        assert!(json.contains("\"versionStage\":\"AWSCURRENT\""));
        assert!(json.contains("\"jsonKey\":\"api_key\""));
    }

    #[test]
    fn test_aws_config_minimal_serialization() {
        // Every `None` field is skipped, so only the secret ID remains.
        let json = serde_json::to_string(&AwsSecretConfig::new("simple-secret")).unwrap();
        assert_eq!(json, r#"{"secretId":"simple-secret"}"#);
    }

    #[test]
    fn test_extract_json_key_string_value() {
        let secret = r#"{"username": "admin", "password": "secret123"}"#;
        assert_eq!(extract(secret, "password").unwrap(), "secret123");
    }

    #[test]
    fn test_extract_json_key_number_value() {
        let secret = r#"{"port": 5432, "host": "localhost"}"#;
        assert_eq!(extract(secret, "port").unwrap(), "5432");
    }

    #[test]
    fn test_extract_json_key_boolean_value() {
        let secret = r#"{"enabled": true, "debug": false}"#;
        assert_eq!(extract(secret, "enabled").unwrap(), "true");
    }

    #[test]
    fn test_extract_json_key_no_key_returns_full_secret() {
        let secret = r#"{"username": "admin"}"#;
        let result = AwsResolver::extract_json_key("test", secret, None);
        assert_eq!(result.unwrap(), secret);
    }

    #[test]
    fn test_extract_json_key_plain_string_no_key() {
        let result = AwsResolver::extract_json_key("test", "plain-text-secret", None);
        assert_eq!(result.unwrap(), "plain-text-secret");
    }

    #[test]
    fn test_extract_json_key_no_key_preserves_whitespace() {
        let secret = "  padded value\n";
        let result = AwsResolver::extract_json_key("test", secret, None);
        assert_eq!(result.unwrap(), secret);
    }

    #[test]
    fn test_extract_json_key_missing_key_error() {
        let message = failure_message(extract(r#"{"username": "admin"}"#, "nonexistent"));
        assert!(message.contains("JSON key 'nonexistent' not found"));
    }

    #[test]
    fn test_extract_json_key_non_object_json_error() {
        let message = failure_message(extract("[1, 2, 3]", "key"));
        assert!(message.contains("JSON key 'key' not found"));
    }

    #[test]
    fn test_extract_json_key_invalid_json_error() {
        let message = failure_message(extract("not-valid-json", "key"));
        assert!(message.contains("Secret is not valid JSON"));
    }

    #[test]
    fn test_extract_json_key_nested_object() {
        let secret = r#"{"database": {"host": "localhost", "port": 5432}}"#;
        assert_eq!(
            extract(secret, "database").unwrap(),
            r#"{"host":"localhost","port":5432}"#
        );
    }

    #[test]
    fn test_extract_json_key_array_value() {
        let secret = r#"{"hosts": ["host1", "host2", "host3"]}"#;
        assert_eq!(
            extract(secret, "hosts").unwrap(),
            r#"["host1","host2","host3"]"#
        );
    }

    #[test]
    fn test_extract_json_key_null_value() {
        assert_eq!(extract(r#"{"value": null}"#, "value").unwrap(), "null");
    }

    #[test]
    fn test_aws_config_clone() {
        let config = AwsSecretConfig {
            secret_id: "my-secret".to_string(),
            version_id: Some("v1".to_string()),
            version_stage: Some("AWSCURRENT".to_string()),
            json_key: Some("key".to_string()),
        };
        let cloned = config.clone();
        assert_eq!(config, cloned);
    }

    #[test]
    fn test_aws_config_debug() {
        let config = AwsSecretConfig::new("test-secret");
        let debug_str = format!("{config:?}");
        assert!(debug_str.contains("AwsSecretConfig"));
        assert!(debug_str.contains("test-secret"));
    }

    #[test]
    fn test_aws_config_equality() {
        let config1 = AwsSecretConfig::new("secret-1");
        let config2 = AwsSecretConfig::new("secret-1");
        let config3 = AwsSecretConfig::new("secret-2");

        assert_eq!(config1, config2);
        assert_ne!(config1, config3);
    }

    #[test]
    fn test_aws_config_with_version_id_equality() {
        let mut config1 = AwsSecretConfig::new("secret");
        config1.version_id = Some("v1".to_string());
        let mut config2 = AwsSecretConfig::new("secret");
        config2.version_id = Some("v1".to_string());
        let mut config3 = AwsSecretConfig::new("secret");
        config3.version_id = Some("v2".to_string());

        assert_eq!(config1, config2);
        assert_ne!(config1, config3);
    }

    #[test]
    fn test_aws_config_deserialization_from_json() {
        let json = r#"{"secretId": "my-secret", "versionId": "abc", "jsonKey": "password"}"#;
        let config: AwsSecretConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.secret_id, "my-secret");
        assert_eq!(config.version_id, Some("abc".to_string()));
        assert_eq!(config.json_key, Some("password".to_string()));
        assert!(config.version_stage.is_none());
    }

    #[test]
    fn test_aws_config_deserialization_minimal() {
        let json = r#"{"secretId": "just-the-id"}"#;
        let config: AwsSecretConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.secret_id, "just-the-id");
        assert!(config.version_id.is_none());
        assert!(config.version_stage.is_none());
        assert!(config.json_key.is_none());
    }

    #[test]
    fn test_aws_config_deserialization_missing_secret_id() {
        let json = r#"{"versionId": "v1"}"#;
        assert!(serde_json::from_str::<AwsSecretConfig>(json).is_err());
    }

    #[test]
    fn test_aws_config_with_arn() {
        let arn = "arn:aws:secretsmanager:us-west-2:123456789012:secret:my-secret-abc123";
        let config = AwsSecretConfig::new(arn);
        assert_eq!(config.secret_id, arn);
    }

    #[test]
    fn test_aws_config_roundtrip() {
        let original = AwsSecretConfig {
            secret_id: "test-secret".to_string(),
            version_id: Some("v1".to_string()),
            version_stage: Some("AWSPREVIOUS".to_string()),
            json_key: Some("key".to_string()),
        };
        let json = serde_json::to_string(&original).unwrap();
        let parsed: AwsSecretConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(original, parsed);
    }

    #[test]
    fn test_extract_json_key_empty_string_value() {
        assert_eq!(extract(r#"{"key": ""}"#, "key").unwrap(), "");
    }

    #[test]
    fn test_extract_json_key_special_characters() {
        let secret = r#"{"key": "value with \"quotes\" and \n newlines"}"#;
        assert_eq!(
            extract(secret, "key").unwrap(),
            "value with \"quotes\" and \n newlines"
        );
    }

    #[test]
    fn test_extract_json_key_unicode() {
        let secret = r#"{"密码": "秘密值"}"#;
        assert_eq!(extract(secret, "密码").unwrap(), "秘密值");
    }

    #[test]
    fn test_extract_json_key_numeric_string() {
        assert_eq!(extract(r#"{"key": "12345"}"#, "key").unwrap(), "12345");
    }

    #[test]
    fn test_extract_json_key_float_value() {
        assert_eq!(extract(r#"{"rate": 3.14159}"#, "rate").unwrap(), "3.14159");
    }

    #[tokio::test]
    async fn test_resolver_new_without_credentials() {
        if env_has_static_credentials() {
            return;
        }
        let resolver = AwsResolver::new().await;
        assert!(resolver.is_ok());
        assert!(!resolver.unwrap().can_use_http());
    }

    #[tokio::test]
    async fn test_resolver_provider_name() {
        let Some(resolver) = cli_resolver().await else {
            return;
        };
        assert_eq!(resolver.provider_name(), "aws");
    }

    #[tokio::test]
    async fn test_resolver_supports_native_batch() {
        let Some(resolver) = cli_resolver().await else {
            return;
        };
        assert!(resolver.supports_native_batch());
    }

    #[tokio::test]
    async fn test_resolver_debug_output() {
        let Some(resolver) = cli_resolver().await else {
            return;
        };
        assert_eq!(format!("{resolver:?}"), r#"AwsResolver { mode: "cli" }"#);
    }

    #[tokio::test]
    async fn test_resolve_batch_empty() {
        let Some(resolver) = cli_resolver().await else {
            return;
        };
        let empty: HashMap<String, SecretSpec> = HashMap::new();
        let result = resolver.resolve_batch(&empty).await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[test]
    fn test_http_credentials_available_logic() {
        assert_eq!(
            AwsResolver::http_credentials_available(),
            env_has_static_credentials()
        );
    }
}