use crate::error::CoreError;
use crate::secret::SecretValue;
/// A source that can turn a secret reference string into a [`SecretValue`].
pub trait SecretProvider {
/// Resolves `reference` to a secret value, or returns a [`CoreError`] describing the failure.
fn materialize(&self, reference: &str) -> Result<SecretValue, CoreError>;
/// The scheme label of this provider. [`SchemeRouter`] compares it for equality against the
/// text that precedes `://` in a reference when choosing where to dispatch.
fn scheme(&self) -> &'static str;
}
/// Extracts the scheme portion of a reference of the form `<scheme>://<rest>`.
///
/// Returns everything before the first `://`, which may be the empty string when the
/// reference starts with the separator. Returns `None` when no `://` is present.
pub fn reference_scheme(reference: &str) -> Option<&str> {
    // `find` reports the byte offset of the first match, so the slice ends right before it.
    let separator_at = reference.find("://")?;
    Some(&reference[..separator_at])
}
/// Holds a list of boxed providers and dispatches each reference to one of them by scheme.
#[derive(Default)]
pub struct SchemeRouter {
// Providers in registration order; lookup walks this list from the back.
providers: Vec<Box<dyn SecretProvider>>,
}
impl SchemeRouter {
    /// Creates a router with no providers registered.
    pub fn new() -> Self {
        Self {
            providers: Vec::new(),
        }
    }

    /// Appends `provider` to the router and hands the router back for chaining.
    pub fn with(self, provider: Box<dyn SecretProvider>) -> Self {
        let mut router = self;
        router.providers.push(provider);
        router
    }

    /// Finds the provider whose `scheme()` equals `scheme`.
    ///
    /// The list is scanned newest-first, so when several providers share a scheme the one
    /// registered last is returned. Returns `None` when nothing matches.
    fn provider_for(&self, scheme: &str) -> Option<&dyn SecretProvider> {
        for candidate in self.providers.iter().rev() {
            if candidate.scheme() == scheme {
                return Some(candidate.as_ref());
            }
        }
        None
    }
}
impl SecretProvider for SchemeRouter {
    /// Routes `reference` to the registered provider matching its scheme.
    ///
    /// # Errors
    ///
    /// Returns `CoreError::Provider` when the reference has no `://` separator. When no
    /// registered provider matches the scheme, the call is delegated to
    /// [`UnsupportedProvider`], whose `materialize` always returns an error. Otherwise the
    /// result of the matched provider is returned unchanged.
    fn materialize(&self, reference: &str) -> Result<SecretValue, CoreError> {
        let scheme = match reference_scheme(reference) {
            Some(scheme) => scheme,
            None => {
                return Err(CoreError::Provider(format!(
                    "reference `{reference}` is malformed: expected `<scheme>://…`"
                )));
            }
        };

        self.provider_for(scheme).map_or_else(
            || UnsupportedProvider.materialize(reference),
            |provider| provider.materialize(reference),
        )
    }

    /// The router reports `"*"` as its own scheme.
    fn scheme(&self) -> &'static str {
        "*"
    }
}
/// Provider that never yields a value: its `materialize` always returns `CoreError::Provider`.
pub struct UnsupportedProvider;
impl SecretProvider for UnsupportedProvider {
    /// Always fails with `CoreError::Provider`, naming the scheme of `reference`.
    ///
    /// When the reference has no `://` separator, the whole reference is reported in place of
    /// the scheme.
    fn materialize(&self, reference: &str) -> Result<SecretValue, CoreError> {
        let scheme = match reference_scheme(reference) {
            Some(prefix) => prefix,
            None => reference,
        };
        // NOTE(review): the "supported" list is a fixed string, not derived from whatever
        // providers are actually registered — confirm it stays in sync with real backends.
        let message = format!(
            "no provider registered for reference scheme `{scheme}` (supported: azure-kv, aws-sm)"
        );
        Err(CoreError::Provider(message))
    }

    /// Reports the empty string as its scheme.
    fn scheme(&self) -> &'static str {
        ""
    }
}
/// In-memory provider with preset values that also tallies `materialize` calls per reference.
#[derive(Default)]
pub struct MockProvider {
    /// Preset secret bytes, keyed by the full reference string.
    entries: std::collections::HashMap<String, Vec<u8>>,
    /// Number of `materialize` calls seen per reference; locked because counting happens
    /// through `&self`.
    calls: std::sync::Mutex<std::collections::HashMap<String, usize>>,
}

impl MockProvider {
    /// Creates a mock with no preset entries and no recorded calls.
    pub fn new() -> Self {
        Default::default()
    }

    /// Presets `value` (stored as its UTF-8 bytes) for `reference`, replacing any earlier
    /// value for the same reference, and returns the mock for chaining.
    pub fn with(mut self, reference: &str, value: &str) -> Self {
        let key = reference.to_owned();
        let bytes = Vec::from(value.as_bytes());
        self.entries.insert(key, bytes);
        self
    }

    /// Returns how many times `reference` has been counted, or `0` if it never was.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn call_count(&self, reference: &str) -> usize {
        let tally = self.calls.lock().expect("mock provider mutex poisoned");
        match tally.get(reference) {
            Some(&hits) => hits,
            None => 0,
        }
    }
}
impl SecretProvider for MockProvider {
    /// Counts the call for `reference`, then returns a copy of its preset value.
    ///
    /// The call is counted even when no value is preset for the reference.
    ///
    /// # Errors
    ///
    /// Returns `CoreError::EnvRefs` when no value was preset for `reference`.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    fn materialize(&self, reference: &str) -> Result<SecretValue, CoreError> {
        {
            // Scope the guard so the lock is released before the entry lookup.
            let mut tally = self.calls.lock().expect("mock provider mutex poisoned");
            *tally.entry(reference.to_string()).or_insert(0) += 1;
        }

        // NOTE(review): the other providers in this file report failures as
        // `CoreError::Provider`; confirm `EnvRefs` is the intended variant here.
        self.entries
            .get(reference)
            .map(|bytes| SecretValue::new(bytes.clone()))
            .ok_or_else(|| {
                CoreError::EnvRefs(format!(
                    "provider has no value for reference `{reference}`"
                ))
            })
    }

    /// The mock always reports the `azure-kv` scheme.
    fn scheme(&self) -> &'static str {
        "azure-kv"
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The mock returns preset bytes, counts the call, and errors on unknown references.
    #[test]
    fn mock_materializes_and_counts() {
        let p = MockProvider::new().with("azure-kv://kv/db-url", "postgres://h/db");
        assert_eq!(
            p.materialize("azure-kv://kv/db-url").unwrap().expose(),
            b"postgres://h/db"
        );
        assert_eq!(p.call_count("azure-kv://kv/db-url"), 1);
        assert!(p.materialize("azure-kv://kv/missing").is_err());
    }

    /// `reference_scheme` yields the text before `://`, or `None` without a separator.
    #[test]
    fn reference_scheme_splits_on_separator() {
        assert_eq!(reference_scheme("azure-kv://kv/name"), Some("azure-kv"));
        assert_eq!(reference_scheme("aws-sm://arn:..."), Some("aws-sm"));
        assert_eq!(reference_scheme("no-separator"), None);
    }

    /// A registered provider receives references carrying its scheme.
    #[test]
    fn router_dispatches_by_scheme() {
        let router =
            SchemeRouter::new().with(Box::new(MockProvider::new().with("azure-kv://kv/n", "v")));
        assert_eq!(
            router.materialize("azure-kv://kv/n").unwrap().expose(),
            b"v"
        );
    }

    /// A scheme with no registered provider yields a `Provider` error naming that scheme.
    #[test]
    fn router_unknown_scheme_is_unsupported_not_silent() {
        let router = SchemeRouter::new();
        let err = router.materialize("aws-sm://kv/n").unwrap_err();
        assert!(matches!(err, CoreError::Provider(_)));
        // Every unsupported-scheme message ends with the fixed text
        // "(supported: azure-kv, aws-sm)", so a bare `contains("aws-sm")` would pass no matter
        // which scheme was reported. Match the back-ticked scheme slot instead.
        assert!(err.to_string().contains("scheme `aws-sm`"));
    }

    /// A reference without `://` is rejected with a `Provider` error.
    #[test]
    fn router_malformed_reference_errors() {
        let router = SchemeRouter::new();
        assert!(matches!(
            router.materialize("not-a-uri").unwrap_err(),
            CoreError::Provider(_)
        ));
    }

    /// Property tests over arbitrary strings.
    mod fuzz {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            /// `reference_scheme` never panics; it returns `Some` exactly when `://` occurs,
            /// and the scheme is the separator-free prefix of the input.
            #[test]
            fn reference_scheme_is_total_and_leak_free(s in ".*") {
                match reference_scheme(&s) {
                    Some(scheme) => {
                        prop_assert!(s.contains("://"));
                        prop_assert!(!scheme.contains("://"));
                        prop_assert!(s.starts_with(scheme));
                        prop_assert_eq!(scheme, s.split("://").next().unwrap());
                    }
                    None => prop_assert!(!s.contains("://")),
                }
            }

            /// A router with no providers must fail for every input.
            #[test]
            fn empty_router_never_fabricates(s in ".*") {
                let router = SchemeRouter::new();
                prop_assert!(router.materialize(&s).is_err());
            }
        }
    }
}