1use std::str::FromStr;
2
3use async_trait::async_trait;
4use futures::future::try_join_all;
5use rusoto_core::{request::TlsError, HttpClient, Region};
6use rusoto_credential::{CredentialsError, DefaultCredentialsProvider, StaticProvider};
7use rusoto_secretsmanager::{
8 GetSecretValueError, GetSecretValueRequest, GetSecretValueResponse, ListSecretsError,
9 ListSecretsRequest, SecretsManager, SecretsManagerClient,
10};
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13use thiserror::Error;
14
15use crate::{
16 convert::{convert_env_name, decode_env_from_json},
17 Vault, VaultConfig,
18};
19
/// Serde-deserializable configuration used to build an [`AwsVault`].
///
/// When `aws_access_key_id` is `Some`, a static credentials provider is built from
/// it together with `aws_secret_access_key`. Otherwise rusoto's
/// `DefaultCredentialsProvider` is used.
#[derive(Serialize, Deserialize)]
pub struct AwsConfig {
    /// Access key id for static credentials; `None` selects the default provider chain.
    pub aws_access_key_id: Option<String>,
    /// Secret access key paired with `aws_access_key_id`; required when a key id is set,
    /// ignored otherwise.
    pub aws_secret_access_key: Option<String>,
    /// AWS region name, parsed with rusoto's `Region::from_str` (e.g. `us-east-1`).
    pub aws_region: String,
}
26
27#[derive(Error, Debug)]
28pub enum AwsError {
29 #[error("rusoto HttpClient error")]
30 TlsError(#[source] TlsError),
31 #[error("rusoto HttpClient error")]
32 CredentialsError(#[source] CredentialsError),
33 #[error("cannot load secret from Secrets Manager")]
34 GetSecretError(#[source] rusoto_core::RusotoError<GetSecretValueError>),
35 #[error("the secret does not have string data")]
36 NoStringData(String),
37 #[error("the secret name is not valid environment variable name")]
38 InvalidSecretName(String),
39 #[error("cannot list secrets from Secrets Manager")]
40 ListSecretsError(#[source] rusoto_core::RusotoError<ListSecretsError>),
41 #[error("cannot decode secret - it is not a valid JSON object")]
42 DecodeError(#[source] serde_json::Error),
43 #[error("there are no secrets in the Secrets Manager")]
44 NoSecrets,
45}
46
47pub type Result<T, E = AwsError> = std::result::Result<T, E>;
48
/// Vault implementation backed by AWS Secrets Manager.
pub struct AwsVault {
    /// rusoto Secrets Manager client configured with the HTTP client, credentials
    /// provider and region chosen in `AwsConfig::into_vault`.
    client: SecretsManagerClient,
}
52
53impl VaultConfig for AwsConfig {
54 type Vault = AwsVault;
55
56 fn into_vault(self) -> anyhow::Result<Self::Vault> {
57 let http_client = HttpClient::new().map_err(AwsError::TlsError)?;
58 if let Some(key_id) = self.aws_access_key_id {
59 let secret = self.aws_secret_access_key.unwrap();
60 let provider = StaticProvider::new_minimal(key_id, secret);
61 Ok(Self::Vault {
62 client: SecretsManagerClient::new_with(
63 http_client,
64 provider,
65 Region::from_str(&self.aws_region)?,
66 ),
67 })
68 } else {
69 let provider = DefaultCredentialsProvider::new().map_err(AwsError::CredentialsError)?;
70 Ok(Self::Vault {
71 client: SecretsManagerClient::new_with(
72 http_client,
73 provider,
74 Region::from_str(&self.aws_region)?,
75 ),
76 })
77 }
78 }
79}
80
81#[async_trait]
82impl Vault for AwsVault {
83 async fn download_prefixed(&self, prefix: &str) -> anyhow::Result<Vec<(String, String)>> {
84 let list = self
85 .client
86 .list_secrets(ListSecretsRequest {
87 max_results: Some(100),
88 ..Default::default()
89 })
90 .await
91 .map_err(AwsError::ListSecretsError)?;
92 let results = list
93 .secret_list
94 .ok_or(AwsError::NoSecrets)?
95 .into_iter()
96 .filter(|x| {
97 x.name
98 .as_ref()
99 .map(|n| n.starts_with(prefix))
100 .unwrap_or(false)
101 })
102 .map(|s| async {
103 println!("{:?}", s);
104 let name = s.name.unwrap();
105 let secret = self
106 .client
107 .get_secret_value(GetSecretValueRequest {
108 secret_id: name.clone(),
109 version_id: None,
110 version_stage: None,
111 })
112 .await
113 .map_err(AwsError::GetSecretError)?;
114 println!("{:?}", secret);
115 let value = secret
116 .secret_string
117 .ok_or_else(|| AwsError::NoStringData(name.clone()))?;
118 let name = convert_env_name(prefix, &name)
119 .map_err(|_| AwsError::InvalidSecretName(name.clone()))?;
120 Ok::<_, AwsError>((name, value))
121 });
122 let values: Vec<_> = try_join_all(results).await?.into_iter().collect();
123 Ok(values)
124 }
125
126 async fn download_json(&self, secret_name: &str) -> anyhow::Result<Vec<(String, String)>> {
127 let secret = self
128 .client
129 .get_secret_value(GetSecretValueRequest {
130 secret_id: secret_name.to_string(),
131 version_id: None,
132 version_stage: None,
133 })
134 .await
135 .map_err(AwsError::GetSecretError)?;
136 let value = decode_secret(secret)?;
137 decode_env_from_json(secret_name, value)
138 }
139}
140
141fn decode_secret(secret: GetSecretValueResponse) -> Result<Value> {
142 secret
143 .secret_string
144 .as_ref()
145 .map(|x| serde_json::from_str(&x[..]))
146 .or_else(|| secret.secret_binary.map(|b| serde_json::from_slice(&b)))
147 .unwrap()
148 .map_err(AwsError::DecodeError)
149}