use crate::{
BatchSecrets, SaltConfig, SecretError, SecretResolver, SecretSpec, compute_secret_fingerprint,
};
use futures::future::try_join_all;
use std::collections::HashMap;
/// Configuration shared by batch secret resolution.
///
/// Carries the [`SaltConfig`] whose write salt is mixed into the fingerprint of
/// every secret whose spec has `cache_key` set.
#[derive(Debug, Clone, Default)]
pub struct BatchConfig {
    /// Salt settings used when fingerprinting cache-key secrets.
    pub salt_config: SaltConfig,
}

impl BatchConfig {
    /// Builds a configuration around an existing [`SaltConfig`].
    #[must_use]
    pub const fn new(salt_config: SaltConfig) -> Self {
        Self { salt_config }
    }

    /// Builds a configuration from an optional raw salt string.
    ///
    /// Shorthand for `BatchConfig::new(SaltConfig::new(salt))`.
    #[must_use]
    pub const fn from_salt(salt: Option<String>) -> Self {
        Self::new(SaltConfig::new(salt))
    }
}
/// Resolves secrets that may be spread across several providers.
///
/// Resolvers are registered under their `provider_name()`. Each secret handed to
/// [`BatchResolver::resolve_all`] names the provider that should resolve it.
pub struct BatchResolver<'a> {
    /// Registered resolvers, keyed by provider name.
    resolvers: HashMap<&'static str, &'a dyn SecretResolver>,
    /// Salt settings used when fingerprinting cache-key secrets.
    config: BatchConfig,
}

impl<'a> BatchResolver<'a> {
    /// Creates a resolver with no providers registered.
    #[must_use]
    pub fn new(config: BatchConfig) -> Self {
        Self {
            resolvers: HashMap::new(),
            config,
        }
    }

    /// Registers `resolver` under its `provider_name()`.
    ///
    /// If a resolver with the same provider name is already registered, it is
    /// replaced.
    pub fn add_resolver(&mut self, resolver: &'a dyn SecretResolver) {
        self.resolvers.insert(resolver.provider_name(), resolver);
    }

    /// Returns the number of distinct provider names currently registered.
    #[must_use]
    pub fn resolver_count(&self) -> usize {
        self.resolvers.len()
    }

    /// Resolves every secret in `secrets` through its named provider.
    ///
    /// `secrets` maps a secret name to its spec and the name of the provider
    /// that should resolve it. Secrets are grouped per provider. Each provider
    /// receives a single `resolve_batch` call, and those calls are awaited
    /// concurrently. Secrets whose spec has `cache_key` set are stored with a
    /// salted fingerprint.
    ///
    /// # Errors
    ///
    /// - [`SecretError::MissingSalt`] if any spec has `cache_key` set and no salt
    ///   is configured.
    /// - [`SecretError::UnsupportedResolver`] if a referenced provider is not
    ///   registered. This is checked before any provider is called.
    /// - Any error returned by a provider's `resolve_batch`.
    pub async fn resolve_all(
        &self,
        secrets: &HashMap<String, (SecretSpec, &'static str)>,
    ) -> Result<BatchSecrets, SecretError> {
        let needs_salt = secrets.values().any(|(spec, _)| spec.cache_key);
        if needs_salt && !self.config.salt_config.has_salt() {
            return Err(SecretError::MissingSalt);
        }

        // Group the requested secrets so each provider gets one batch call.
        let mut by_provider: HashMap<&'static str, HashMap<String, SecretSpec>> = HashMap::new();
        for (name, (spec, provider)) in secrets {
            by_provider
                .entry(*provider)
                .or_default()
                .insert(name.clone(), spec.clone());
        }

        // Look up every referenced provider before calling any of them. A request
        // naming an unregistered provider then fails up front. It no longer
        // performs (and discards) work against the registered providers, and its
        // error no longer depends on which future happens to finish first.
        let mut jobs = Vec::with_capacity(by_provider.len());
        for (provider, provider_secrets) in by_provider {
            let resolver = self.resolvers.get(provider).copied().ok_or_else(|| {
                SecretError::UnsupportedResolver {
                    resolver: provider.to_string(),
                }
            })?;
            jobs.push((resolver, provider_secrets));
        }

        let provider_results = try_join_all(jobs.into_iter().map(
            |(resolver, provider_secrets)| async move {
                let resolved = resolver.resolve_batch(&provider_secrets).await?;
                Ok::<_, SecretError>(resolved)
            },
        ))
        .await?;

        let mut batch = BatchSecrets::with_capacity(secrets.len());
        for resolved in provider_results {
            for (name, secure_value) in resolved {
                // Only secrets flagged `cache_key` get a fingerprint. The salt is
                // guaranteed present here by the `MissingSalt` check above.
                let fingerprint = match secrets.get(&name) {
                    Some((spec, _)) if spec.cache_key => {
                        if secure_value.len() < 4 {
                            tracing::warn!(
                                secret = %name,
                                len = secure_value.len(),
                                "Secret is too short for safe cache key inclusion"
                            );
                        }
                        Some(compute_secret_fingerprint(
                            &name,
                            secure_value.expose(),
                            self.config.salt_config.write_salt().unwrap_or(""),
                        ))
                    }
                    _ => None,
                };
                batch.insert(name, secure_value, fingerprint);
            }
        }
        Ok(batch)
    }
}
/// Resolves `secrets` through a single `resolver` and packages the results.
///
/// Each secret whose spec has `cache_key` set is stored with a fingerprint
/// computed from its name, its value, and the salt in `salt_config`. All other
/// secrets are stored without one.
///
/// # Errors
///
/// - [`SecretError::MissingSalt`] if any spec has `cache_key` set and
///   `salt_config` has no salt.
/// - Any error returned by `resolver.resolve_batch`.
// The signature is fixed to the default-hasher `HashMap` to keep the public API
// simple. Supporting a custom `BuildHasher` would need an extra generic parameter.
#[allow(clippy::implicit_hasher)]
pub async fn resolve_batch<R: SecretResolver>(
    resolver: &R,
    secrets: &HashMap<String, SecretSpec>,
    salt_config: &SaltConfig,
) -> Result<BatchSecrets, SecretError> {
    let any_cache_key = secrets.values().any(|spec| spec.cache_key);
    if any_cache_key && !salt_config.has_salt() {
        return Err(SecretError::MissingSalt);
    }

    let resolved = resolver.resolve_batch(secrets).await?;
    let mut out = BatchSecrets::with_capacity(secrets.len());
    for (name, value) in resolved {
        let fingerprint = match secrets.get(&name) {
            Some(spec) if spec.cache_key => {
                if value.len() < 4 {
                    tracing::warn!(
                        secret = %name,
                        len = value.len(),
                        "Secret is too short for safe cache key inclusion"
                    );
                }
                Some(compute_secret_fingerprint(
                    &name,
                    value.expose(),
                    salt_config.write_salt().unwrap_or(""),
                ))
            }
            _ => None,
        };
        out.insert(name, value, fingerprint);
    }
    Ok(out)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EnvSecretResolver;

    /// Sets an environment variable for the life of the guard and removes it on
    /// drop. Cleanup therefore also happens when a test panics part-way through,
    /// instead of leaking the variable into later tests.
    struct EnvVarGuard {
        key: &'static str,
    }

    impl EnvVarGuard {
        fn set(key: &'static str, value: &str) -> Self {
            // SAFETY: `set_var` requires that nothing reads or writes the
            // environment concurrently through non-`std::env` means. Each test
            // uses a variable name unique to that test, and this module touches
            // the environment only through `std::env`.
            // NOTE(review): tests run on parallel threads. If code under test
            // reads the environment via foreign calls, serialise these tests.
            #[allow(unsafe_code)]
            unsafe {
                std::env::set_var(key, value);
            }
            Self { key }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: same reasoning as in `EnvVarGuard::set`.
            #[allow(unsafe_code)]
            unsafe {
                std::env::remove_var(self.key);
            }
        }
    }

    #[test]
    fn test_batch_config_default() {
        let config = BatchConfig::default();
        assert!(!config.salt_config.has_salt());
    }

    #[test]
    fn test_batch_config_new() {
        let salt_config = SaltConfig::new(Some("test-salt".to_string()));
        let config = BatchConfig::new(salt_config);
        assert!(config.salt_config.has_salt());
        assert_eq!(config.salt_config.write_salt(), Some("test-salt"));
    }

    #[test]
    fn test_batch_config_from_salt() {
        let config = BatchConfig::from_salt(Some("my-salt".to_string()));
        assert!(config.salt_config.has_salt());
        assert_eq!(config.salt_config.write_salt(), Some("my-salt"));
    }

    #[test]
    fn test_batch_config_from_salt_none() {
        let config = BatchConfig::from_salt(None);
        assert!(!config.salt_config.has_salt());
    }

    #[test]
    fn test_batch_config_clone() {
        let config = BatchConfig::from_salt(Some("cloned-salt".to_string()));
        let cloned = config.clone();
        assert!(cloned.salt_config.has_salt());
    }

    #[test]
    fn test_batch_config_debug() {
        let config = BatchConfig::default();
        let debug_str = format!("{:?}", config);
        assert!(debug_str.contains("BatchConfig"));
    }

    #[test]
    fn test_batch_resolver_new() {
        let config = BatchConfig::default();
        let resolver = BatchResolver::new(config);
        assert_eq!(resolver.resolver_count(), 0);
    }

    #[test]
    fn test_batch_resolver_add_resolver() {
        let config = BatchConfig::default();
        let mut resolver = BatchResolver::new(config);
        let env_resolver = EnvSecretResolver::new();
        resolver.add_resolver(&env_resolver);
        assert_eq!(resolver.resolver_count(), 1);
    }

    #[test]
    fn test_batch_resolver_add_multiple_resolvers() {
        let config = BatchConfig::default();
        let mut resolver = BatchResolver::new(config);
        let env_resolver = EnvSecretResolver::new();
        resolver.add_resolver(&env_resolver);
        // Same provider name registered twice: the second replaces the first.
        resolver.add_resolver(&env_resolver);
        assert_eq!(resolver.resolver_count(), 1);
    }

    #[tokio::test]
    async fn test_batch_resolver_resolve_all_empty() {
        let config = BatchConfig::default();
        let resolver = BatchResolver::new(config);
        let secrets: HashMap<String, (SecretSpec, &'static str)> = HashMap::new();
        let result = resolver.resolve_all(&secrets).await.unwrap();
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn test_batch_resolver_missing_provider() {
        let config = BatchConfig::default();
        let resolver = BatchResolver::new(config);
        let mut secrets = HashMap::new();
        secrets.insert(
            "TEST".to_string(),
            (SecretSpec::new("test"), "unknown_provider"),
        );
        let result = resolver.resolve_all(&secrets).await;
        assert!(matches!(
            result,
            Err(SecretError::UnsupportedResolver { .. })
        ));
    }

    #[tokio::test]
    async fn test_batch_resolver_unknown_provider_alongside_registered_one() {
        let _var = EnvVarGuard::set("BATCH_MIXED_PROVIDER", "mixed_value");
        let env_resolver = EnvSecretResolver::new();
        let mut resolver = BatchResolver::new(BatchConfig::default());
        resolver.add_resolver(&env_resolver);
        let mut secrets = HashMap::new();
        secrets.insert(
            "known".to_string(),
            (
                SecretSpec::new("BATCH_MIXED_PROVIDER"),
                env_resolver.provider_name(),
            ),
        );
        secrets.insert(
            "unknown".to_string(),
            (SecretSpec::new("irrelevant"), "unknown_provider"),
        );
        let result = resolver.resolve_all(&secrets).await;
        assert!(matches!(
            result,
            Err(SecretError::UnsupportedResolver { .. })
        ));
    }

    #[tokio::test]
    async fn test_batch_resolver_missing_salt_for_cache_key() {
        let config = BatchConfig::default();
        let resolver = BatchResolver::new(config);
        let mut secrets = HashMap::new();
        secrets.insert(
            "TEST".to_string(),
            (SecretSpec::with_cache_key("test"), "env"),
        );
        let result = resolver.resolve_all(&secrets).await;
        assert!(matches!(result, Err(SecretError::MissingSalt)));
    }

    #[tokio::test]
    async fn test_resolve_batch_empty() {
        let resolver = EnvSecretResolver::new();
        let secrets = HashMap::new();
        let salt = SaltConfig::default();
        let result = resolve_batch(&resolver, &secrets, &salt).await.unwrap();
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn test_resolve_batch_missing_salt() {
        let resolver = EnvSecretResolver::new();
        let mut secrets = HashMap::new();
        secrets.insert("TEST".to_string(), SecretSpec::with_cache_key("TEST_VAR"));
        let salt = SaltConfig::default();
        let result = resolve_batch(&resolver, &secrets, &salt).await;
        assert!(matches!(result, Err(SecretError::MissingSalt)));
    }

    #[tokio::test]
    async fn test_resolve_batch_no_cache_key_no_salt_ok() {
        let resolver = EnvSecretResolver::new();
        let mut secrets = HashMap::new();
        secrets.insert("TEST".to_string(), SecretSpec::new("NONEXISTENT_VAR"));
        let salt = SaltConfig::default();
        let result = resolve_batch(&resolver, &secrets, &salt).await;
        assert!(
            !matches!(result, Err(SecretError::MissingSalt)),
            "Should not require salt for non-cache-key secrets"
        );
    }

    #[tokio::test]
    async fn test_resolve_batch_with_salt_and_cache_key() {
        let _var = EnvVarGuard::set("BATCH_TEST_SECRET", "test_value");
        let resolver = EnvSecretResolver::new();
        let mut secrets = HashMap::new();
        secrets.insert(
            "my_secret".to_string(),
            SecretSpec::with_cache_key("BATCH_TEST_SECRET"),
        );
        let salt = SaltConfig::new(Some("test-salt".to_string()));
        let result = resolve_batch(&resolver, &secrets, &salt).await.unwrap();
        assert_eq!(result.len(), 1);
    }

    #[tokio::test]
    async fn test_resolve_batch_without_cache_key() {
        let _var = EnvVarGuard::set("BATCH_TEST_NO_CACHE", "another_value");
        let resolver = EnvSecretResolver::new();
        let mut secrets = HashMap::new();
        secrets.insert(
            "my_secret".to_string(),
            SecretSpec::new("BATCH_TEST_NO_CACHE"),
        );
        let salt = SaltConfig::default();
        let result = resolve_batch(&resolver, &secrets, &salt).await.unwrap();
        assert_eq!(result.len(), 1);
    }

    #[tokio::test]
    async fn test_resolve_batch_multiple_secrets() {
        let _var1 = EnvVarGuard::set("BATCH_MULTI_1", "value1");
        let _var2 = EnvVarGuard::set("BATCH_MULTI_2", "value2");
        let resolver = EnvSecretResolver::new();
        let mut secrets = HashMap::new();
        secrets.insert("secret1".to_string(), SecretSpec::new("BATCH_MULTI_1"));
        secrets.insert("secret2".to_string(), SecretSpec::new("BATCH_MULTI_2"));
        let salt = SaltConfig::default();
        let result = resolve_batch(&resolver, &secrets, &salt).await.unwrap();
        assert_eq!(result.len(), 2);
    }
}