use anyhow::{anyhow, Result};
use std::collections::HashMap;
use thiserror::Error;
/// Failure modes surfaced by the secret store and its providers.
///
/// Each variant carries a single string that is interpolated into the
/// variant's display message.
#[derive(Debug, Error)]
pub enum SecretStoreError {
    /// No secret is available for the requested key; the payload is that key.
    #[error("Secret not found: {0}")]
    NotFound(String),

    /// The key named a provider that is not registered; the payload is the
    /// provider name.
    #[error("Provider not configured: {0}")]
    ProviderNotConfigured(String),

    /// A secret could not be decrypted; the payload is a free-form detail.
    #[error("Failed to decrypt secret: {0}")]
    DecryptionFailed(String),

    /// A secret was present but malformed; the payload is a free-form detail.
    #[error("Invalid secret format: {0}")]
    InvalidFormat(String),
}
/// Looks secrets up through named providers and remembers every value it has
/// already resolved.
pub struct SecretStore {
    /// Resolved secret values, indexed by the full key string that was
    /// requested (provider prefix included, when one was given).
    cache: HashMap<String, String>,
    /// Registered providers, indexed by the name used as the key prefix.
    providers: HashMap<String, Box<dyn SecretProvider>>,
}
impl SecretStore {
    /// Creates a store with an empty cache and the two built-in providers,
    /// registered under the names `env` and `file`.
    pub fn new() -> Self {
        let mut providers: HashMap<String, Box<dyn SecretProvider>> = HashMap::new();
        providers.insert("env".to_string(), Box::new(EnvSecretProvider));
        providers.insert("file".to_string(), Box::new(FileSecretProvider));
        Self {
            cache: HashMap::new(),
            providers,
        }
    }

    /// Splits a lookup key into `(provider_name, secret_key)`.
    ///
    /// Everything before the first `:` names the provider; everything after
    /// it is handed to that provider untouched, so any later colons remain
    /// part of the secret key. A key without `:` is routed to the `env`
    /// provider.
    fn split_key(key: &str) -> (&str, &str) {
        key.split_once(':').unwrap_or(("env", key))
    }

    /// Resolves `key` to its secret value, consulting the cache first.
    ///
    /// On a cache miss the key is split by [`Self::split_key`], the matching
    /// provider is queried, and the result is cached under the full `key`
    /// string exactly as it was passed in. As a consequence `"FOO"` and
    /// `"env:FOO"` occupy separate cache entries.
    ///
    /// # Errors
    ///
    /// Returns [`SecretStoreError::ProviderNotConfigured`] when no provider
    /// is registered under the key's provider name, or whatever error the
    /// provider itself returns. Failed lookups are not cached.
    pub async fn get_secret(&mut self, key: &str) -> Result<String> {
        if let Some(cached) = self.cache.get(key) {
            return Ok(cached.clone());
        }

        let (provider_name, secret_key) = Self::split_key(key);
        let provider = self
            .providers
            .get(provider_name)
            .ok_or_else(|| SecretStoreError::ProviderNotConfigured(provider_name.to_string()))?;

        let value = provider.get_secret(secret_key).await?;
        self.cache.insert(key.to_string(), value.clone());
        Ok(value)
    }

    /// Registers `provider` under `name`, replacing any provider that was
    /// already registered with that name.
    ///
    /// The cache is left untouched, so values previously resolved through a
    /// replaced provider keep being served until [`Self::clear_cache`] is
    /// called.
    pub fn add_provider(&mut self, name: String, provider: Box<dyn SecretProvider>) {
        self.providers.insert(name, provider);
    }

    /// Drops every cached value; subsequent lookups query the providers again.
    pub fn clear_cache(&mut self) {
        self.cache.clear();
    }

    /// Reports whether `key` can be resolved without fetching its value.
    ///
    /// Returns `true` when the key is already cached or when the matching
    /// provider reports the secret as present, and `false` when the key's
    /// provider name is not registered. Nothing is added to the cache.
    pub async fn has_secret(&self, key: &str) -> bool {
        if self.cache.contains_key(key) {
            return true;
        }

        let (provider_name, secret_key) = Self::split_key(key);
        match self.providers.get(provider_name) {
            Some(provider) => provider.has_secret(secret_key).await,
            None => false,
        }
    }
}
/// A backend that can be queried asynchronously for secret values.
///
/// Implementors must be `Send + Sync`.
#[async_trait::async_trait]
pub trait SecretProvider: Send + Sync {
    /// Fetches the secret identified by `key`, or an error when it cannot be
    /// produced.
    async fn get_secret(&self, key: &str) -> Result<String>;

    /// Reports whether a secret is available for `key`.
    async fn has_secret(&self, key: &str) -> bool;
}
/// Provider that reads secrets from the process environment; the secret key
/// is the environment variable name.
struct EnvSecretProvider;

#[async_trait::async_trait]
impl SecretProvider for EnvSecretProvider {
    /// Returns the value of the environment variable named `key`.
    ///
    /// # Errors
    ///
    /// * [`SecretStoreError::NotFound`] when the variable is not set.
    /// * [`SecretStoreError::InvalidFormat`] when the variable is set but its
    ///   value is not valid Unicode.
    async fn get_secret(&self, key: &str) -> Result<String> {
        std::env::var(key).map_err(|err| match err {
            std::env::VarError::NotPresent => {
                anyhow!(SecretStoreError::NotFound(key.to_string()))
            }
            // The variable exists, so reporting "not found" would be
            // misleading. The raw value is deliberately left out of the
            // message because it is the secret itself.
            std::env::VarError::NotUnicode(_) => {
                anyhow!(SecretStoreError::InvalidFormat(format!(
                    "environment variable {} is not valid Unicode",
                    key
                )))
            }
        })
    }

    /// Returns `true` exactly when [`Self::get_secret`] would succeed for
    /// `key`, i.e. the variable is set and holds valid Unicode.
    async fn has_secret(&self, key: &str) -> bool {
        std::env::var(key).is_ok()
    }
}
/// Provider that reads secrets from files; the secret key is the file path.
struct FileSecretProvider;

#[async_trait::async_trait]
impl SecretProvider for FileSecretProvider {
    /// Reads the file at `path` as UTF-8 text and returns its contents with
    /// leading and trailing whitespace removed.
    ///
    /// # Errors
    ///
    /// Returns an error naming `path` and the underlying I/O failure when the
    /// file cannot be read as a string.
    async fn get_secret(&self, path: &str) -> Result<String> {
        tokio::fs::read_to_string(path)
            .await
            .map(|contents| contents.trim().to_string())
            .map_err(|e| anyhow!("Failed to read secret file {}: {}", path, e))
    }

    /// Reports whether `path` exists and is not a directory.
    ///
    /// This only inspects metadata; it does not prove that the file is
    /// readable or contains valid UTF-8.
    async fn has_secret(&self, path: &str) -> bool {
        // A directory has metadata but can never be read as a secret, so
        // existence alone is not enough to answer `true`.
        tokio::fs::metadata(path)
            .await
            .map(|meta| !meta.is_dir())
            .unwrap_or(false)
    }
}
/// In-memory secret provider for tests, backed by a plain map from key to
/// value.
#[cfg(test)]
#[derive(Debug, Default)]
pub struct MockSecretProvider {
    /// Secrets this mock will serve, indexed by key.
    secrets: HashMap<String, String>,
}

#[cfg(test)]
impl MockSecretProvider {
    /// Creates a mock that holds no secrets; equivalent to `Default::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `value` under `key`, overwriting any value already stored there.
    pub fn add_secret(&mut self, key: String, value: String) {
        self.secrets.insert(key, value);
    }
}
#[cfg(test)]
#[async_trait::async_trait]
impl SecretProvider for MockSecretProvider {
    /// Returns a copy of the value stored under `key`, or
    /// [`SecretStoreError::NotFound`] when the mock holds no such key.
    async fn get_secret(&self, key: &str) -> Result<String> {
        match self.secrets.get(key) {
            Some(stored) => Ok(stored.clone()),
            None => Err(anyhow!(SecretStoreError::NotFound(key.to_string()))),
        }
    }

    /// Reports whether the mock holds a value for `key`.
    async fn has_secret(&self, key: &str) -> bool {
        self.secrets.contains_key(key)
    }
}
impl Default for SecretStore {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The env provider returns a set variable's value and reports presence
    /// and absence.
    #[tokio::test]
    #[serial_test::serial]
    async fn test_env_secret_provider() {
        std::env::set_var("TEST_SECRET", "secret_value");
        let env_provider = EnvSecretProvider;

        let fetched = env_provider.get_secret("TEST_SECRET").await.unwrap();
        assert_eq!(fetched, "secret_value");

        assert!(env_provider.has_secret("TEST_SECRET").await);
        assert!(!env_provider.has_secret("NONEXISTENT").await);
    }

    /// A cached value survives a change in the environment until the cache is
    /// cleared.
    #[tokio::test]
    #[serial_test::serial]
    async fn test_secret_store_cache() {
        let mut secret_store = SecretStore::new();
        std::env::set_var("CACHED_SECRET", "cached_value");

        // First lookup goes to the provider and populates the cache.
        let first = secret_store.get_secret("env:CACHED_SECRET").await.unwrap();
        assert_eq!(first, "cached_value");

        // The environment changes, but the cached value is still served.
        std::env::set_var("CACHED_SECRET", "new_value");
        let second = secret_store.get_secret("env:CACHED_SECRET").await.unwrap();
        assert_eq!(second, "cached_value");

        // Clearing the cache forces a fresh read.
        secret_store.clear_cache();
        let third = secret_store.get_secret("env:CACHED_SECRET").await.unwrap();
        assert_eq!(third, "new_value");
    }

    /// A provider registered under a custom name is reachable through the
    /// `name:key` syntax.
    #[tokio::test]
    async fn test_mock_secret_provider() {
        let mut mock_provider = MockSecretProvider::new();
        mock_provider.add_secret("test_key".to_string(), "test_value".to_string());

        let mut secret_store = SecretStore::new();
        secret_store.add_provider("mock".to_string(), Box::new(mock_provider));

        let fetched = secret_store.get_secret("mock:test_key").await.unwrap();
        assert_eq!(fetched, "test_value");
        assert!(secret_store.has_secret("mock:test_key").await);
    }

    /// A key without a provider prefix is resolved through the env provider.
    #[tokio::test]
    #[serial_test::serial]
    async fn test_default_provider() {
        let mut secret_store = SecretStore::new();
        std::env::set_var("DEFAULT_SECRET", "default_value");

        let fetched = secret_store.get_secret("DEFAULT_SECRET").await.unwrap();
        assert_eq!(fetched, "default_value");
    }
}