use aws_sdk_secretsmanager::Client;
use llm_shield_cloud::{
async_trait, CloudError, CloudSecretManager, Result, SecretCache, SecretMetadata, SecretValue,
};
use std::collections::HashMap;
use std::sync::Arc;
/// [`CloudSecretManager`] implementation backed by AWS Secrets Manager.
///
/// Bundles the SDK client, a [`SecretCache`] that is consulted before secret
/// values are fetched remotely, and the region name reported by
/// [`AwsSecretsManager::region`].
pub struct AwsSecretsManager {
/// AWS SDK client used for every Secrets Manager request.
client: Client,
/// Cache of fetched secret values; its TTL is chosen by the constructor.
cache: SecretCache,
/// Name of the AWS region this manager was initialised for.
region: String,
}
impl AwsSecretsManager {
    /// Region used when the default AWS provider chain cannot resolve one.
    const FALLBACK_REGION: &'static str = "us-east-1";

    /// Cache TTL, in seconds, used by constructors that do not take an explicit TTL.
    const DEFAULT_CACHE_TTL_SECONDS: u64 = 300;

    /// Creates a manager using the default AWS configuration chain.
    ///
    /// The region is resolved through the default region provider chain; if
    /// that chain yields nothing, [`Self::FALLBACK_REGION`] is used. The
    /// fallback is installed on the provider chain itself so that the SDK
    /// client and the value returned by [`Self::region`] always agree.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`; the `Result` is kept for interface
    /// stability.
    pub async fn new() -> Result<Self> {
        let region_provider = aws_config::meta::region::RegionProviderChain::default_provider()
            .or_else(aws_config::Region::new(Self::FALLBACK_REGION));
        let config = aws_config::from_env().region(region_provider).load().await;

        // With the fallback in the chain a region is expected to be present;
        // the default below only guards against an unexpectedly empty config.
        let region = config
            .region()
            .map(|r| r.to_string())
            .unwrap_or_else(|| Self::FALLBACK_REGION.to_string());

        let manager = Self {
            client: Client::new(&config),
            cache: SecretCache::new(Self::DEFAULT_CACHE_TTL_SECONDS),
            region,
        };
        tracing::info!(
            "Initialized AWS Secrets Manager client in region: {}",
            manager.region
        );
        Ok(manager)
    }

    /// Creates a manager pinned to `region`, using the default cache TTL.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`; the `Result` is kept for interface
    /// stability.
    pub async fn new_with_region(region: impl Into<String>) -> Result<Self> {
        let manager = Self::connect(region.into(), Self::DEFAULT_CACHE_TTL_SECONDS).await;
        tracing::info!(
            "Initialized AWS Secrets Manager client in region: {}",
            manager.region
        );
        Ok(manager)
    }

    /// Creates a manager pinned to `region` whose cache uses `cache_ttl_seconds`.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`; the `Result` is kept for interface
    /// stability.
    pub async fn new_with_cache_ttl(
        region: impl Into<String>,
        cache_ttl_seconds: u64,
    ) -> Result<Self> {
        let manager = Self::connect(region.into(), cache_ttl_seconds).await;
        tracing::info!(
            "Initialized AWS Secrets Manager client in region: {} with {}s cache TTL",
            manager.region,
            cache_ttl_seconds
        );
        Ok(manager)
    }

    /// Returns the name of the region this manager was initialised for.
    pub fn region(&self) -> &str {
        &self.region
    }

    /// Removes every entry from the secret cache.
    pub async fn clear_cache(&self) {
        self.cache.clear().await;
        tracing::debug!("Cleared AWS Secrets Manager cache");
    }

    /// Returns the number of entries reported by the secret cache.
    pub async fn cache_size(&self) -> usize {
        self.cache.len().await
    }

    /// Loads the environment configuration with an explicit region override
    /// and assembles the manager; shared by the region-pinned constructors.
    async fn connect(region: String, cache_ttl_seconds: u64) -> Self {
        let config = aws_config::from_env()
            .region(aws_config::Region::new(region.clone()))
            .load()
            .await;
        Self {
            client: Client::new(&config),
            cache: SecretCache::new(cache_ttl_seconds),
            region,
        }
    }
}
#[async_trait]
impl CloudSecretManager for AwsSecretsManager {
    /// Returns the value of the secret `name`, consulting the cache first.
    ///
    /// On a cache miss the value is fetched with `GetSecretValue`, stored in
    /// the cache and returned. String payloads take precedence over binary
    /// payloads when the response carries both.
    ///
    /// # Errors
    ///
    /// * A fetch error built by `CloudError::secret_fetch` when the request fails.
    /// * `CloudError::SecretFormat` when the response has neither a string nor
    ///   a binary payload.
    async fn get_secret(&self, name: &str) -> Result<SecretValue> {
        if let Some(cached) = self.cache.get(name).await {
            tracing::debug!("Cache hit for secret: {}", name);
            return Ok(cached);
        }

        tracing::debug!("Fetching secret from AWS Secrets Manager: {}", name);
        let response = self
            .client
            .get_secret_value()
            .secret_id(name)
            .send()
            .await
            .map_err(|e| CloudError::secret_fetch(name, e.to_string()))?;

        let value = if let Some(secret_string) = response.secret_string() {
            SecretValue::from_string(secret_string.to_string())
        } else if let Some(secret_binary) = response.secret_binary() {
            SecretValue::from_bytes(secret_binary.clone().into_inner())
        } else {
            return Err(CloudError::SecretFormat {
                name: name.to_string(),
                reason: "Secret has no string or binary value".to_string(),
            });
        };

        self.cache.set(name.to_string(), value.clone()).await;
        tracing::info!("Successfully fetched secret: {}", name);
        Ok(value)
    }

    /// Lists the names of all secrets, following pagination tokens until the
    /// service stops returning one. Entries without a name are skipped.
    ///
    /// # Errors
    ///
    /// `CloudError::SecretList` when any page request fails.
    async fn list_secrets(&self) -> Result<Vec<String>> {
        tracing::debug!("Listing secrets from AWS Secrets Manager");
        let mut secret_names = Vec::new();
        let mut next_token: Option<String> = None;

        loop {
            let mut request = self.client.list_secrets();
            // `take` leaves `None` behind, making it explicit that the token
            // is consumed here and re-populated from the response below.
            if let Some(token) = next_token.take() {
                request = request.next_token(token);
            }

            let response = request
                .send()
                .await
                .map_err(|e| CloudError::SecretList(e.to_string()))?;

            secret_names.extend(
                response
                    .secret_list()
                    .iter()
                    .filter_map(|secret| secret.name().map(String::from)),
            );

            next_token = response.next_token().map(String::from);
            if next_token.is_none() {
                break;
            }
        }

        tracing::info!("Listed {} secrets", secret_names.len());
        Ok(secret_names)
    }

    /// Creates the secret `name` with `value` as its string payload.
    ///
    /// Any cache entry for `name` is dropped afterwards so a later
    /// [`get_secret`](CloudSecretManager::get_secret) cannot return a value
    /// cached before this call, matching `update_secret` and `delete_secret`.
    ///
    /// # Errors
    ///
    /// A creation error built by `CloudError::secret_create` when the request fails.
    async fn create_secret(&self, name: &str, value: &SecretValue) -> Result<()> {
        tracing::debug!("Creating secret in AWS Secrets Manager: {}", name);
        // NOTE(review): the payload is always sent via `as_string()`, whereas
        // `get_secret` can return binary values — confirm binary secrets
        // survive this conversion.
        self.client
            .create_secret()
            .name(name)
            .secret_string(value.as_string())
            .send()
            .await
            .map_err(|e| CloudError::secret_create(name, e.to_string()))?;

        self.cache.invalidate(name).await;
        tracing::info!("Successfully created secret: {}", name);
        Ok(())
    }

    /// Replaces the string payload of the secret `name` with `value` and
    /// invalidates its cache entry.
    ///
    /// # Errors
    ///
    /// An update error built by `CloudError::secret_update` when the request fails.
    async fn update_secret(&self, name: &str, value: &SecretValue) -> Result<()> {
        tracing::debug!("Updating secret in AWS Secrets Manager: {}", name);
        // NOTE(review): same string-only conversion as `create_secret` —
        // confirm behaviour for binary secrets.
        self.client
            .update_secret()
            .secret_id(name)
            .secret_string(value.as_string())
            .send()
            .await
            .map_err(|e| CloudError::secret_update(name, e.to_string()))?;

        self.cache.invalidate(name).await;
        tracing::info!("Successfully updated secret: {}", name);
        Ok(())
    }

    /// Requests deletion of the secret `name` without forcing immediate
    /// removal, then invalidates its cache entry.
    ///
    /// # Errors
    ///
    /// A deletion error built by `CloudError::secret_delete` when the request fails.
    async fn delete_secret(&self, name: &str) -> Result<()> {
        tracing::debug!("Deleting secret from AWS Secrets Manager: {}", name);
        self.client
            .delete_secret()
            .secret_id(name)
            // Explicitly keep the recovery window the log message below refers to.
            .force_delete_without_recovery(false)
            .send()
            .await
            .map_err(|e| CloudError::secret_delete(name, e.to_string()))?;

        self.cache.invalidate(name).await;
        tracing::info!("Successfully deleted secret (30-day recovery): {}", name);
        Ok(())
    }

    /// Fetches metadata for the secret `name` via `DescribeSecret`.
    ///
    /// * `created_at` falls back to the current time when the service omits
    ///   the creation date or it cannot be converted.
    /// * `updated_at` falls back to `created_at`.
    /// * `version` is the id of the version staged as `AWSCURRENT`, if any.
    /// * Tags lacking either a key or a value are skipped.
    ///
    /// # Errors
    ///
    /// A fetch error built by `CloudError::secret_fetch` when the request fails.
    async fn get_secret_metadata(&self, name: &str) -> Result<SecretMetadata> {
        tracing::debug!("Fetching secret metadata from AWS Secrets Manager: {}", name);
        let response = self
            .client
            .describe_secret()
            .secret_id(name)
            .send()
            .await
            .map_err(|e| CloudError::secret_fetch(name, e.to_string()))?;

        // NOTE(review): substituting "now" for a missing creation date hides
        // the absence from callers — confirm this is the intended contract.
        let created_at = response
            .created_date()
            .and_then(|dt| chrono::DateTime::from_timestamp(dt.secs(), dt.subsec_nanos()))
            .unwrap_or_else(chrono::Utc::now);

        let updated_at = response
            .last_changed_date()
            .and_then(|dt| chrono::DateTime::from_timestamp(dt.secs(), dt.subsec_nanos()))
            .unwrap_or(created_at);

        let tags: HashMap<String, String> = response
            .tags()
            .iter()
            .filter_map(|tag| Some((tag.key()?.to_string(), tag.value()?.to_string())))
            .collect();

        // Compare as `&str` so no temporary `String` is allocated per entry.
        let version = response.version_ids_to_stages().and_then(|versions| {
            versions
                .iter()
                .find(|(_, stages)| stages.iter().any(|stage| stage.as_str() == "AWSCURRENT"))
                .map(|(version_id, _)| version_id.clone())
        });

        Ok(SecretMetadata {
            name: name.to_string(),
            created_at,
            updated_at,
            tags,
            version,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sanity check on a region string literal.
    ///
    /// NOTE(review): this does not construct an `AwsSecretsManager`; it only
    /// compares an owned string with the literal it was built from.
    #[test]
    fn test_aws_secrets_manager_region() {
        let expected_region = String::from("us-west-2");
        assert_eq!(expected_region, "us-west-2");
    }

    /// Round-trips a value through `SecretCache`: set, get, clear, len.
    #[tokio::test]
    async fn test_cache_operations() {
        let secret_cache = SecretCache::new(300);
        let stored = SecretValue::from_string("test-value".to_string());

        secret_cache
            .set("test-key".to_string(), stored.clone())
            .await;

        // The entry must be present and carry the stored payload.
        let lookup = secret_cache.get("test-key").await;
        assert!(lookup.is_some());
        let fetched = lookup.unwrap();
        assert_eq!(fetched.as_string(), "test-value");

        // Clearing must leave the cache empty.
        secret_cache.clear().await;
        let remaining = secret_cache.len().await;
        assert_eq!(remaining, 0);
    }
}