use aws_config::BehaviorVersion;
use aws_sdk_secretsmanager::Client;
use serde::{Deserialize, Serialize};
use tracing::debug;
use super::error::{SecretsError, SecretsResult};
use super::provider::SecretProvider;
use super::types::{SecretMetadata, SecretValue};
/// Connection settings for the AWS Secrets Manager provider.
///
/// `#[serde(default)]` lets deserialized input omit any field; missing fields
/// are taken from [`AwsConfig::default`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct AwsConfig {
    /// Region the SDK client is configured for (e.g. `"us-east-1"`).
    pub region: String,
    /// Optional endpoint override (e.g. a LocalStack URL). When `None`, no
    /// override is applied to the SDK configuration.
    pub endpoint_url: Option<String>,
}
impl Default for AwsConfig {
    /// Region `us-east-1` with no endpoint override.
    fn default() -> Self {
        let region = String::from("us-east-1");
        Self {
            endpoint_url: None,
            region,
        }
    }
}
impl AwsConfig {
    /// Builds a configuration through `crate::config::env_compat::aws`.
    ///
    /// The region lookup is given `"us-east-1"` as its `get_or` fallback; the
    /// endpoint override is whatever `endpoint_url().get()` yields.
    #[must_use]
    pub fn from_env() -> Self {
        use crate::config::env_compat::aws;

        let region = aws::region().get_or("us-east-1");
        let endpoint_url = aws::endpoint_url().get();
        Self {
            region,
            endpoint_url,
        }
    }

    /// Configuration for `region` with no endpoint override.
    #[must_use]
    pub fn with_region(region: &str) -> Self {
        let region = region.to_owned();
        Self {
            endpoint_url: None,
            region,
        }
    }

    /// Configuration targeting a LocalStack-style `endpoint` in region `us-east-1`.
    #[must_use]
    pub fn for_localstack(endpoint: &str) -> Self {
        Self::with_region("us-east-1").with_endpoint(endpoint)
    }

    /// Returns `self` with its endpoint override replaced by `endpoint`.
    #[must_use]
    pub fn with_endpoint(self, endpoint: &str) -> Self {
        Self {
            endpoint_url: Some(endpoint.to_owned()),
            ..self
        }
    }
}
/// Secret provider backed by AWS Secrets Manager.
///
/// Construct it from an [`AwsConfig`] via `AwsProvider::new`.
pub struct AwsProvider {
    /// SDK client used for every request this provider issues.
    client: Client,
}
impl AwsProvider {
pub fn new(config: &AwsConfig) -> SecretsResult<Self> {
let rt = tokio::runtime::Handle::try_current()
.map_err(|e| SecretsError::ConfigError(format!("tokio runtime required: {e}")))?;
let client = rt.block_on(async { Self::create_client(config).await })?;
Ok(Self { client })
}
async fn create_client(config: &AwsConfig) -> SecretsResult<Client> {
let mut aws_config = aws_config::defaults(BehaviorVersion::latest())
.region(aws_config::Region::new(config.region.clone()));
if let Some(ref endpoint) = config.endpoint_url {
aws_config = aws_config.endpoint_url(endpoint);
}
let aws_config = aws_config.load().await;
Ok(Client::new(&aws_config))
}
pub async fn get(&self, secret_id: &str, key: Option<&str>) -> SecretsResult<SecretValue> {
let response = self
.client
.get_secret_value()
.secret_id(secret_id)
.send()
.await
.map_err(|e| {
if e.to_string().contains("ResourceNotFoundException") {
SecretsError::NotFound(format!("secret not found: {secret_id}"))
} else if e.to_string().contains("AccessDenied") {
SecretsError::AuthError(format!("access denied to secret: {secret_id}"))
} else {
SecretsError::ProviderError(format!("failed to get secret {secret_id}: {e}"))
}
})?;
let secret_data = if let Some(secret_string) = response.secret_string() {
if let Some(key) = key {
let json: serde_json::Value = serde_json::from_str(secret_string).map_err(|e| {
SecretsError::InvalidData(format!(
"secret is not JSON, cannot extract key '{key}': {e}"
))
})?;
let value = json.get(key).ok_or_else(|| {
SecretsError::NotFound(format!("key '{key}' not found in secret '{secret_id}'"))
})?;
match value {
serde_json::Value::String(s) => s.as_bytes().to_vec(),
_ => serde_json::to_vec(value).map_err(|e| {
SecretsError::InvalidData(format!("failed to serialize value: {e}"))
})?,
}
} else {
secret_string.as_bytes().to_vec()
}
} else if let Some(secret_binary) = response.secret_binary() {
secret_binary.as_ref().to_vec()
} else {
return Err(SecretsError::InvalidData(
"secret has no string or binary data".into(),
));
};
let metadata = SecretMetadata {
version: response.version_id().map(String::from),
source_path: response.arn().map(String::from),
provider: Some("aws".into()),
};
debug!(
secret_id = %secret_id,
version = ?metadata.version,
"Secret fetched from AWS Secrets Manager"
);
Ok(SecretValue::with_metadata(secret_data, metadata))
}
}
impl SecretProvider for AwsProvider {
async fn get(&self, path: &str, key: Option<&str>) -> SecretsResult<SecretValue> {
self.get(path, key).await
}
async fn health_check(&self) -> SecretsResult<()> {
self.client
.list_secrets()
.max_results(1)
.send()
.await
.map_err(|e| SecretsError::ProviderError(format!("AWS health check failed: {e}")))?;
Ok(())
}
fn name(&self) -> &'static str {
"aws"
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_aws_config_default() {
        let config = AwsConfig::default();
        assert_eq!(config.region, "us-east-1");
        assert!(config.endpoint_url.is_none());
    }

    #[test]
    fn test_aws_config_serialization() {
        let config = AwsConfig {
            region: "eu-west-1".into(),
            endpoint_url: Some("http://localhost:4566".into()),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("eu-west-1"));
        assert!(json.contains("localhost:4566"));

        // Round-trip: deserializing the output must reproduce every field.
        let back: AwsConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(back.region, config.region);
        assert_eq!(back.endpoint_url, config.endpoint_url);
    }

    #[test]
    fn test_aws_config_deserialize_fills_missing_fields() {
        // `#[serde(default)]` falls back to `AwsConfig::default()` per field.
        let empty: AwsConfig = serde_json::from_str("{}").unwrap();
        assert_eq!(empty.region, "us-east-1");
        assert!(empty.endpoint_url.is_none());

        let partial: AwsConfig = serde_json::from_str(r#"{"region":"ap-south-1"}"#).unwrap();
        assert_eq!(partial.region, "ap-south-1");
        assert!(partial.endpoint_url.is_none());
    }

    #[test]
    fn test_aws_config_constructors() {
        let regional = AwsConfig::with_region("eu-central-1");
        assert_eq!(regional.region, "eu-central-1");
        assert!(regional.endpoint_url.is_none());

        let local = AwsConfig::for_localstack("http://localhost:4566");
        assert_eq!(local.region, "us-east-1");
        assert_eq!(local.endpoint_url.as_deref(), Some("http://localhost:4566"));

        let overridden =
            AwsConfig::with_region("eu-west-2").with_endpoint("http://127.0.0.1:4566");
        assert_eq!(overridden.region, "eu-west-2");
        assert_eq!(
            overridden.endpoint_url.as_deref(),
            Some("http://127.0.0.1:4566")
        );
    }
}