use std::num::NonZeroUsize;
use super::cache_item::CacheItem;
use super::config::CacheConfig;
use aws_sdk_config::error::SdkError;
use aws_sdk_secretsmanager::operation::get_secret_value::GetSecretValueError;
use aws_sdk_secretsmanager::Client as SecretsManagerClient;
use lru::LruCache;
/// In-memory LRU cache of secret strings fetched from AWS Secrets Manager.
///
/// Lookups go through [`SecretCache::get_secret_string`], which returns a
/// [`GetSecretStringBuilder`]; no request is made until its `send` is awaited.
pub struct SecretCache {
    /// Client used for `GetSecretValue` calls on cache misses and forced refreshes.
    client: SecretsManagerClient,
    /// Cache settings (capacity, item TTL, version stage).
    config: CacheConfig,
    /// Secret ID -> cached secret string, bounded by `config.max_cache_size`.
    cache: LruCache<String, CacheItem<String>>,
}

impl SecretCache {
    /// Builds a cache around `client` using the configuration from `CacheConfig::new()`.
    pub fn new(client: SecretsManagerClient) -> Self {
        Self::new_cache(client, CacheConfig::new())
    }

    /// Builds a cache around `client` using the caller-supplied `config`.
    pub fn new_with_config(client: SecretsManagerClient, config: CacheConfig) -> Self {
        Self::new_cache(client, config)
    }

    /// Shared constructor behind [`SecretCache::new`] and [`SecretCache::new_with_config`].
    ///
    /// The LRU capacity is `config.max_cache_size`; because `LruCache` needs a
    /// non-zero capacity, a configured size of 0 falls back to a capacity of 1.
    fn new_cache(client: SecretsManagerClient, config: CacheConfig) -> Self {
        // Read the capacity before `config` is moved into the struct below.
        let capacity = NonZeroUsize::new(config.max_cache_size).unwrap_or_else(|| {
            NonZeroUsize::new(1).expect("Default max_cache_size must be non-zero")
        });

        Self {
            cache: LruCache::new(capacity),
            client,
            config,
        }
    }

    /// Starts a lookup of `secret_id`; await `send` on the returned builder to run it.
    pub fn get_secret_string(&mut self, secret_id: String) -> GetSecretStringBuilder {
        GetSecretStringBuilder::new(self, secret_id)
    }
}
pub struct GetSecretStringBuilder<'a> {
secret_cache: &'a mut SecretCache,
secret_id: String,
force_refresh: bool,
}
impl<'a> GetSecretStringBuilder<'a> {
pub fn new(secret_cache: &'a mut SecretCache, secret_id: String) -> Self {
GetSecretStringBuilder {
secret_cache,
secret_id,
force_refresh: false,
}
}
pub fn force_refresh(mut self) -> Self {
self.force_refresh = true;
self
}
pub async fn send(&mut self) -> Result<String, SdkError<GetSecretValueError>> {
if !self.force_refresh {
if let Some(cache_item) = self.secret_cache.cache.get(&self.secret_id) {
if !cache_item.is_expired() {
return Ok(cache_item.value.clone());
}
}
}
match self.fetch_secret().await {
Ok(secret_value) => {
let cache_item = CacheItem::new(
secret_value.clone(),
self.secret_cache.config.cache_item_ttl,
);
self.secret_cache
.cache
.put(self.secret_id.clone(), cache_item);
Ok(secret_value)
}
Err(e) => Err(e),
}
}
async fn fetch_secret(&mut self) -> Result<String, SdkError<GetSecretValueError>> {
match self
.secret_cache
.client
.get_secret_value()
.secret_id(self.secret_id.clone())
.version_stage(self.secret_cache.config.version_stage.clone())
.send()
.await
{
Ok(resp) => return Ok(resp.secret_string.as_deref().unwrap().to_string()),
Err(e) => Err(e),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use aws_sdk_config::config::{Credentials, Region};
    use aws_sdk_secretsmanager::{Client as SecretsManagerClient, Config};

    /// Secret identifier shared by the builder tests.
    const SECRET_ID: &str = "service/secret";

    #[test]
    fn get_secret_string_builder_defaults() {
        let mut cache = SecretCache::new(get_mock_secretsmanager_client());

        let builder = GetSecretStringBuilder::new(&mut cache, SECRET_ID.to_string());

        assert_eq!(builder.secret_id, SECRET_ID);
        assert!(!builder.force_refresh);
    }

    #[test]
    fn get_secret_string_builder_force_refresh() {
        let mut cache = SecretCache::new(get_mock_secretsmanager_client());

        let builder =
            GetSecretStringBuilder::new(&mut cache, SECRET_ID.to_string()).force_refresh();

        assert_eq!(builder.secret_id, SECRET_ID);
        assert!(builder.force_refresh);
    }

    /// Builds a client with a fixed region and dummy static credentials; the
    /// tests above never send a request with it.
    fn get_mock_secretsmanager_client() -> SecretsManagerClient {
        let credentials = Credentials::new("asdf", "asdf", None, None, "test");
        let region = Region::new("ap-southeast-2");

        SecretsManagerClient::from_conf(
            Config::builder()
                .region(region)
                .credentials_provider(credentials)
                .build(),
        )
    }
}