1use async_trait::async_trait;
4use aws_sdk_secretsmanager::Client;
5use cuenv_secrets::{SecretError, SecretResolver, SecretSpec, SecureSecret};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use tokio::process::Command;
9
/// Per-secret lookup options for AWS Secrets Manager.
///
/// A secret spec's `source` is first tried as JSON in this shape; field names are
/// camelCase on the wire (`secretId`, `versionId`, `versionStage`, `jsonKey`), and
/// unset optional fields are omitted when serializing.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct AwsSecretConfig {
    /// Secret name or ARN, sent as the secret ID (`--secret-id` in CLI mode).
    pub secret_id: String,

    /// Optional version ID to request (`--version-id` in CLI mode).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub version_id: Option<String>,

    /// Optional version stage label to request (`--version-stage` in CLI mode).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub version_stage: Option<String>,

    /// Optional top-level key to extract when the secret string is a JSON object.
    /// String values are returned unquoted; any other JSON value is returned as its
    /// JSON text.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub json_key: Option<String>,
}
29
30impl AwsSecretConfig {
31 #[must_use]
33 pub fn new(secret_id: impl Into<String>) -> Self {
34 Self {
35 secret_id: secret_id.into(),
36 version_id: None,
37 version_stage: None,
38 json_key: None,
39 }
40 }
41}
42
/// Resolves secrets from AWS Secrets Manager.
///
/// Requests go through the AWS SDK when an SDK client was built at construction
/// time, and through the `aws` command-line tool otherwise.
pub struct AwsResolver {
    /// SDK client; `None` means every request shells out to the `aws` CLI.
    http_client: Option<Client>,
}
55
56impl std::fmt::Debug for AwsResolver {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 f.debug_struct("AwsResolver")
59 .field("mode", &if self.can_use_http() { "http" } else { "cli" })
60 .finish()
61 }
62}
63
64impl AwsResolver {
65 pub async fn new() -> Result<Self, SecretError> {
73 let http_client = if Self::http_credentials_available() {
74 let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
75 .load()
76 .await;
77 Some(Client::new(&config))
78 } else {
79 None
80 };
81
82 Ok(Self { http_client })
83 }
84
85 fn http_credentials_available() -> bool {
87 std::env::var("AWS_ACCESS_KEY_ID").is_ok() && std::env::var("AWS_SECRET_ACCESS_KEY").is_ok()
88 }
89
90 fn can_use_http(&self) -> bool {
92 self.http_client.is_some()
93 }
94
95 async fn resolve_http(
97 &self,
98 name: &str,
99 config: &AwsSecretConfig,
100 ) -> Result<String, SecretError> {
101 let client = self
102 .http_client
103 .as_ref()
104 .ok_or_else(|| SecretError::ResolutionFailed {
105 name: name.to_string(),
106 message: "HTTP client not available".to_string(),
107 })?;
108
109 let mut request = client.get_secret_value().secret_id(&config.secret_id);
110
111 if let Some(version_id) = &config.version_id {
112 request = request.version_id(version_id);
113 }
114
115 if let Some(version_stage) = &config.version_stage {
116 request = request.version_stage(version_stage);
117 }
118
119 let response = request
120 .send()
121 .await
122 .map_err(|e| SecretError::ResolutionFailed {
123 name: name.to_string(),
124 message: format!("AWS Secrets Manager error: {e}"),
125 })?;
126
127 let secret_string =
128 response
129 .secret_string()
130 .ok_or_else(|| SecretError::ResolutionFailed {
131 name: name.to_string(),
132 message: "Secret has no string value (may be binary)".to_string(),
133 })?;
134
135 Self::extract_json_key(name, secret_string, config.json_key.as_ref())
136 }
137
138 async fn resolve_cli(
140 &self,
141 name: &str,
142 config: &AwsSecretConfig,
143 ) -> Result<String, SecretError> {
144 let mut args = vec![
145 "secretsmanager".to_string(),
146 "get-secret-value".to_string(),
147 "--secret-id".to_string(),
148 config.secret_id.clone(),
149 "--query".to_string(),
150 "SecretString".to_string(),
151 "--output".to_string(),
152 "text".to_string(),
153 ];
154
155 if let Some(version_id) = &config.version_id {
156 args.push("--version-id".to_string());
157 args.push(version_id.clone());
158 }
159
160 if let Some(version_stage) = &config.version_stage {
161 args.push("--version-stage".to_string());
162 args.push(version_stage.clone());
163 }
164
165 let output = Command::new("aws")
166 .args(&args)
167 .output()
168 .await
169 .map_err(|e| SecretError::ResolutionFailed {
170 name: name.to_string(),
171 message: format!("Failed to execute aws CLI: {e}"),
172 })?;
173
174 if !output.status.success() {
175 let stderr = String::from_utf8_lossy(&output.stderr);
176 return Err(SecretError::ResolutionFailed {
177 name: name.to_string(),
178 message: format!("aws CLI failed: {stderr}"),
179 });
180 }
181
182 let secret_string = String::from_utf8_lossy(&output.stdout).trim().to_string();
183 Self::extract_json_key(name, &secret_string, config.json_key.as_ref())
184 }
185
186 fn extract_json_key(
188 name: &str,
189 secret_string: &str,
190 json_key: Option<&String>,
191 ) -> Result<String, SecretError> {
192 if let Some(key) = json_key {
193 let parsed: serde_json::Value =
194 serde_json::from_str(secret_string).map_err(|e| SecretError::ResolutionFailed {
195 name: name.to_string(),
196 message: format!("Secret is not valid JSON: {e}"),
197 })?;
198
199 let value = parsed
200 .get(key)
201 .ok_or_else(|| SecretError::ResolutionFailed {
202 name: name.to_string(),
203 message: format!("JSON key '{key}' not found in secret"),
204 })?;
205
206 return match value {
207 serde_json::Value::String(s) => Ok(s.clone()),
208 other => Ok(other.to_string()),
209 };
210 }
211
212 Ok(secret_string.to_string())
213 }
214
215 async fn resolve_with_config(
217 &self,
218 name: &str,
219 config: &AwsSecretConfig,
220 ) -> Result<String, SecretError> {
221 if self.http_client.is_some() {
223 return self.resolve_http(name, config).await;
224 }
225
226 self.resolve_cli(name, config).await
228 }
229
230 async fn resolve_batch_http(
232 &self,
233 secrets: &HashMap<String, SecretSpec>,
234 ) -> Result<HashMap<String, SecureSecret>, SecretError> {
235 use futures::future::try_join_all;
236
237 let client = self
238 .http_client
239 .as_ref()
240 .ok_or_else(|| SecretError::ResolutionFailed {
241 name: "batch".to_string(),
242 message: "HTTP client not available".to_string(),
243 })?;
244
245 let mut id_to_names: HashMap<String, Vec<(String, AwsSecretConfig)>> = HashMap::new();
248 for (name, spec) in secrets {
249 let config = serde_json::from_str::<AwsSecretConfig>(&spec.source)
250 .unwrap_or_else(|_| AwsSecretConfig::new(&spec.source));
251 id_to_names
252 .entry(config.secret_id.clone())
253 .or_default()
254 .push((name.clone(), config));
255 }
256
257 let secret_ids: Vec<String> = id_to_names.keys().cloned().collect();
259
260 let mut all_values: HashMap<String, String> = HashMap::new();
262
263 for chunk in secret_ids.chunks(20) {
264 let response = client
265 .batch_get_secret_value()
266 .set_secret_id_list(Some(chunk.to_vec()))
267 .send()
268 .await
269 .map_err(|e| SecretError::ResolutionFailed {
270 name: "batch".to_string(),
271 message: format!("AWS BatchGetSecretValue failed: {e}"),
272 })?;
273
274 for sv in response.secret_values() {
276 if let Some(secret_string) = sv.secret_string() {
277 if let Some(secret_name) = sv.name() {
279 all_values.insert(secret_name.to_string(), secret_string.to_string());
280 }
281 if let Some(arn) = sv.arn() {
282 all_values.insert(arn.to_string(), secret_string.to_string());
283 }
284 }
285 }
286
287 for err in response.errors() {
289 tracing::warn!(
290 secret_id = ?err.secret_id(),
291 error_code = ?err.error_code(),
292 message = ?err.message(),
293 "Failed to retrieve secret in batch"
294 );
295 }
296 }
297
298 let extract_futures: Vec<_> = secrets
300 .iter()
301 .map(|(name, spec)| {
302 let name = name.clone();
303 let all_values = &all_values;
304 async move {
305 let config = serde_json::from_str::<AwsSecretConfig>(&spec.source)
306 .unwrap_or_else(|_| AwsSecretConfig::new(&spec.source));
307
308 let secret_string = all_values.get(&config.secret_id).ok_or_else(|| {
310 SecretError::ResolutionFailed {
311 name: name.clone(),
312 message: format!(
313 "Secret '{}' not found in batch response",
314 config.secret_id
315 ),
316 }
317 })?;
318
319 let value =
321 Self::extract_json_key(&name, secret_string, config.json_key.as_ref())?;
322 Ok::<_, SecretError>((name, SecureSecret::new(value)))
323 }
324 })
325 .collect();
326
327 try_join_all(extract_futures)
328 .await
329 .map(|v| v.into_iter().collect())
330 }
331
332 async fn resolve_batch_cli(
334 &self,
335 secrets: &HashMap<String, SecretSpec>,
336 ) -> Result<HashMap<String, SecureSecret>, SecretError> {
337 use futures::future::try_join_all;
338
339 let futures: Vec<_> = secrets
340 .iter()
341 .map(|(name, spec)| {
342 let name = name.clone();
343 let spec = spec.clone();
344 async move {
345 let value = self.resolve(&name, &spec).await?;
346 Ok::<_, SecretError>((name, SecureSecret::new(value)))
347 }
348 })
349 .collect();
350
351 try_join_all(futures).await.map(|v| v.into_iter().collect())
352 }
353}
354
355#[async_trait]
356impl SecretResolver for AwsResolver {
357 fn provider_name(&self) -> &'static str {
358 "aws"
359 }
360
361 fn supports_native_batch(&self) -> bool {
362 true
364 }
365
366 async fn resolve(&self, name: &str, spec: &SecretSpec) -> Result<String, SecretError> {
367 if let Ok(config) = serde_json::from_str::<AwsSecretConfig>(&spec.source) {
369 return self.resolve_with_config(name, &config).await;
370 }
371
372 let config = AwsSecretConfig::new(&spec.source);
374 self.resolve_with_config(name, &config).await
375 }
376
377 async fn resolve_batch(
378 &self,
379 secrets: &HashMap<String, SecretSpec>,
380 ) -> Result<HashMap<String, SecureSecret>, SecretError> {
381 if secrets.is_empty() {
382 return Ok(HashMap::new());
383 }
384
385 if self.http_client.is_some() {
387 return self.resolve_batch_http(secrets).await;
388 }
389
390 self.resolve_batch_cli(secrets).await
392 }
393}
394
#[cfg(test)]
mod tests {
    use super::*;

    /// ARN used by the plain-ID construction test.
    const SAMPLE_ARN: &str = "arn:aws:secretsmanager:us-east-1:123456:secret:test";

    /// A config with some optional fields set survives a JSON round trip unchanged.
    #[test]
    fn test_aws_config_serialization() {
        let original = AwsSecretConfig {
            version_id: Some("v1".into()),
            json_key: Some("password".into()),
            ..AwsSecretConfig::new("my-secret")
        };

        let encoded = serde_json::to_string(&original).expect("config serializes");
        let decoded: AwsSecretConfig =
            serde_json::from_str(&encoded).expect("config deserializes");
        assert_eq!(decoded, original);
    }

    /// `new` stores the ID verbatim and leaves the optional fields unset.
    #[test]
    fn test_simple_config() {
        let config = AwsSecretConfig::new(SAMPLE_ARN);
        assert_eq!(config.secret_id, SAMPLE_ARN);
        assert_eq!(config.version_id, None);
        assert_eq!(config.json_key, None);
    }

    /// Only checks that probing the environment does not panic; the answer
    /// depends on the host.
    #[test]
    fn test_http_credentials_check() {
        let _detected: bool = AwsResolver::http_credentials_available();
    }
}