use async_trait::async_trait;
#[derive(Debug, thiserror::Error)]
pub enum SecretsError {
#[error("secret not found: {0}")]
NotFound(String),
#[error("invalid secret reference: {0}")]
Invalid(String),
#[error("secrets backend error: {0}")]
Backend(#[source] Box<dyn std::error::Error + Send + Sync>),
}
/// Turns a secret *reference* (for example `env://DATABASE_URL`) into the
/// secret's value.
///
/// The `'static + Send + Sync` supertraits allow implementations to be boxed
/// as `dyn SecretsResolver` and shared between threads.
#[async_trait]
pub trait SecretsResolver: 'static + Send + Sync {
    /// Resolves `reference` to the secret value it names.
    ///
    /// # Errors
    ///
    /// Returns a [`SecretsError`] when the reference cannot be resolved; which
    /// variant is produced for which situation is implementation-specific.
    async fn resolve(&self, reference: &str) -> Result<String, SecretsError>;
}
/// Resolver that treats every reference as the secret value itself.
///
/// The input string is returned unchanged and resolution never fails, which
/// makes this a convenient fallback (it is the default of
/// `ChainSecretsResolver::standard`).
#[derive(Debug, Clone, Copy, Default)]
pub struct LiteralSecretsResolver;

#[async_trait]
impl SecretsResolver for LiteralSecretsResolver {
    async fn resolve(&self, reference: &str) -> Result<String, SecretsError> {
        Ok(reference.to_owned())
    }
}
/// Resolves `env://NAME` references by reading the process environment
/// variable `NAME`.
///
/// Resolution fails with:
/// - [`SecretsError::Invalid`] if the reference lacks the `env://` prefix,
///   the variable name is empty or contains `=` or NUL, or the variable is
///   set but its value is not valid Unicode;
/// - [`SecretsError::NotFound`] if the variable is not set.
#[derive(Debug, Clone, Copy, Default)]
pub struct EnvSecretsResolver;

#[async_trait]
impl SecretsResolver for EnvSecretsResolver {
    async fn resolve(&self, reference: &str) -> Result<String, SecretsError> {
        let Some(var) = reference.strip_prefix("env://") else {
            return Err(SecretsError::Invalid(format!(
                "expected env:// prefix, got `{reference}`"
            )));
        };
        if var.is_empty() {
            return Err(SecretsError::Invalid(format!(
                "env:// reference has empty variable name: `{reference}`"
            )));
        }
        // Names containing `=` or NUL cannot exist in the environment; std may
        // report them as unset, which would surface as a misleading NotFound.
        if var.contains(|c: char| c == '=' || c == '\0') {
            return Err(SecretsError::Invalid(format!(
                "env:// variable name must not contain `=` or NUL: `{reference}`"
            )));
        }
        std::env::var(var).map_err(|err| match err {
            std::env::VarError::NotPresent => SecretsError::NotFound(var.to_owned()),
            // The variable exists, so "not found" would be wrong. The raw value
            // is deliberately not included: it is the secret.
            std::env::VarError::NotUnicode(_) => SecretsError::Invalid(format!(
                "environment variable `{var}` is set but is not valid Unicode"
            )),
        })
    }
}
/// Dispatches secret references to resolvers by scheme prefix.
///
/// Registered `(scheme, resolver)` pairs are consulted in the order they were
/// pushed, and the first scheme that is a prefix of the reference handles it.
/// References matching no scheme go to the default resolver.
///
/// Because matching is first-wins, a more specific scheme pushed *after* a
/// broader one it starts with (e.g. `vault://kv/` after `vault://`) is never
/// consulted. An empty scheme matches every reference.
pub struct ChainSecretsResolver {
    /// Scheme prefixes and their resolvers, in registration order.
    matchers: Vec<(String, Box<dyn SecretsResolver>)>,
    /// Fallback used when no scheme prefix matches.
    default: Box<dyn SecretsResolver>,
}

impl ChainSecretsResolver {
    /// Creates a chain with no scheme matchers; every reference goes to `default`.
    #[must_use]
    pub fn new(default: impl SecretsResolver) -> Self {
        let fallback: Box<dyn SecretsResolver> = Box::new(default);
        Self {
            default: fallback,
            matchers: Vec::new(),
        }
    }

    /// Registers `resolver` for references beginning with `scheme`
    /// (e.g. `"env://"`) and returns the updated chain.
    #[must_use]
    pub fn push(mut self, scheme: impl Into<String>, resolver: impl SecretsResolver) -> Self {
        let prefix: String = scheme.into();
        let handler: Box<dyn SecretsResolver> = Box::new(resolver);
        self.matchers.push((prefix, handler));
        self
    }

    /// Builds the stock chain: `env://` references are read from the process
    /// environment and anything else is returned verbatim.
    #[must_use]
    pub fn standard() -> Self {
        let chain = Self::new(LiteralSecretsResolver);
        chain.push("env://", EnvSecretsResolver)
    }
}

#[async_trait]
impl SecretsResolver for ChainSecretsResolver {
    async fn resolve(&self, reference: &str) -> Result<String, SecretsError> {
        let handler = self
            .matchers
            .iter()
            .find(|(prefix, _)| reference.starts_with(prefix.as_str()))
            .map_or(&self.default, |(_, matched)| matched);
        handler.resolve(reference).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the name and value of some environment variable that is set to a
    /// non-empty, valid-UTF-8 value, or `None` if the environment has none.
    ///
    /// Used instead of hard-coding `PATH`/`USER` so the tests stay meaningful in
    /// stripped-down environments, and so they can compare against the real value.
    fn any_set_env_var() -> Option<(String, String)> {
        std::env::vars_os().find_map(|(name, value)| {
            let name = name.into_string().ok()?;
            let value = value.into_string().ok()?;
            let usable = !name.is_empty()
                && !name.contains(|c: char| c == '=' || c == '\0')
                && !value.is_empty();
            usable.then_some((name, value))
        })
    }

    /// Test double that always yields a fixed value, to observe chain dispatch.
    struct Fixed(&'static str);

    #[async_trait]
    impl SecretsResolver for Fixed {
        async fn resolve(&self, _reference: &str) -> Result<String, SecretsError> {
            Ok(self.0.to_owned())
        }
    }

    #[tokio::test]
    async fn literal_resolver_passes_through() {
        let r = LiteralSecretsResolver;
        let v = r.resolve("postgres://u:p@h/db").await.unwrap();
        assert_eq!(v, "postgres://u:p@h/db");
    }

    #[tokio::test]
    async fn env_resolver_reads_named_env_var() {
        let Some((key, expected)) = any_set_env_var() else {
            eprintln!("skipping: no non-empty UTF-8 environment variable is set");
            return;
        };
        let r = EnvSecretsResolver;
        let v = r.resolve(&format!("env://{key}")).await.unwrap();
        assert_eq!(v, expected, "env://{key} should resolve to the variable's value");
    }

    #[tokio::test]
    async fn env_resolver_rejects_missing_prefix() {
        let r = EnvSecretsResolver;
        let err = r.resolve("FOO").await.unwrap_err();
        assert!(matches!(err, SecretsError::Invalid(_)), "got {err:?}");
    }

    #[tokio::test]
    async fn env_resolver_rejects_empty_var_name() {
        let r = EnvSecretsResolver;
        let err = r.resolve("env://").await.unwrap_err();
        assert!(matches!(err, SecretsError::Invalid(_)), "got {err:?}");
    }

    #[tokio::test]
    async fn env_resolver_returns_not_found_for_unset_var() {
        let r = EnvSecretsResolver;
        let err = r
            .resolve("env://RUSTANGO_TENANCY_DEFINITELY_NOT_SET_xyzzy")
            .await
            .unwrap_err();
        assert!(matches!(err, SecretsError::NotFound(_)), "got {err:?}");
    }

    #[tokio::test]
    async fn chain_dispatches_by_scheme_prefix() {
        let chain = ChainSecretsResolver::standard();
        if let Some((key, expected)) = any_set_env_var() {
            let v = chain.resolve(&format!("env://{key}")).await.unwrap();
            assert_eq!(v, expected);
        }
        // Unknown schemes fall through to the literal default.
        let v = chain.resolve("postgres://u:p@h/db").await.unwrap();
        assert_eq!(v, "postgres://u:p@h/db");
    }

    #[tokio::test]
    async fn chain_uses_first_matching_scheme_then_default() {
        let chain = ChainSecretsResolver::new(Fixed("default"))
            .push("vault://", Fixed("vault"))
            .push("vault://kv/", Fixed("shadowed"));
        // First registered matching prefix wins, even if a later one is longer.
        assert_eq!(chain.resolve("vault://kv/db").await.unwrap(), "vault");
        assert_eq!(chain.resolve("file:///etc/secret").await.unwrap(), "default");
    }
}