use std::collections::HashMap;
use std::sync::RwLock;
use crate::secrets::{
CredentialLocation, CredentialMapping, DecryptedSecret, SecretError, SecretsStore,
};
#[derive(Debug, Clone, thiserror::Error)]
pub enum InjectionError {
#[error("Secret not found: {0}")]
SecretNotFound(String),
#[error("Secret access denied: {0}")]
AccessDenied(String),
#[error("Secret has expired: {0}")]
SecretExpired(String),
#[error("Decryption failed: {0}")]
DecryptionFailed(String),
#[error("No matching credential for host: {0}")]
NoMatchingCredential(String),
}
impl From<SecretError> for InjectionError {
fn from(e: SecretError) -> Self {
match e {
SecretError::NotFound(name) => InjectionError::SecretNotFound(name),
SecretError::Expired => InjectionError::SecretExpired("unknown".to_string()),
SecretError::AccessDenied => InjectionError::AccessDenied("unknown".to_string()),
SecretError::DecryptionFailed(msg) => InjectionError::DecryptionFailed(msg),
_ => InjectionError::DecryptionFailed(e.to_string()),
}
}
}
pub struct SharedCredentialRegistry {
mappings: RwLock<Vec<CredentialMapping>>,
}
impl SharedCredentialRegistry {
pub fn new() -> Self {
Self {
mappings: RwLock::new(Vec::new()),
}
}
pub fn add_mappings(&self, mappings: impl IntoIterator<Item = CredentialMapping>) {
match self.mappings.write() {
Ok(mut guard) => {
guard.extend(mappings);
}
Err(poisoned) => {
tracing::warn!(
"SharedCredentialRegistry RwLock poisoned during add_mappings; recovering"
);
let mut guard = poisoned.into_inner();
guard.extend(mappings);
}
}
}
pub fn remove_mappings_for_secrets(&self, secret_names: &[String]) {
let mut guard = match self.mappings.write() {
Ok(guard) => guard,
Err(poisoned) => {
tracing::warn!(
"SharedCredentialRegistry RwLock poisoned during remove_mappings_for_secrets; recovering"
);
poisoned.into_inner()
}
};
guard.retain(|m| !secret_names.contains(&m.secret_name));
}
pub fn has_credentials_for_host(&self, host: &str) -> bool {
let guard = match self.mappings.read() {
Ok(guard) => guard,
Err(poisoned) => {
tracing::warn!(
"SharedCredentialRegistry RwLock poisoned during has_credentials_for_host; recovering"
);
poisoned.into_inner()
}
};
guard.iter().any(|mapping| {
mapping
.host_patterns
.iter()
.any(|pattern| host_matches_pattern(host, pattern))
})
}
pub fn find_for_host(&self, host: &str) -> Vec<CredentialMapping> {
let guard = match self.mappings.read() {
Ok(guard) => guard,
Err(poisoned) => {
tracing::warn!(
"SharedCredentialRegistry RwLock poisoned during find_for_host; recovering"
);
poisoned.into_inner()
}
};
guard
.iter()
.filter(|mapping| {
mapping
.host_patterns
.iter()
.any(|pattern| host_matches_pattern(host, pattern))
})
.cloned()
.collect()
}
}
impl Default for SharedCredentialRegistry {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug)]
pub struct InjectedCredentials {
pub headers: HashMap<String, String>,
pub query_params: HashMap<String, String>,
}
impl InjectedCredentials {
pub fn empty() -> Self {
Self {
headers: HashMap::new(),
query_params: HashMap::new(),
}
}
pub fn is_empty(&self) -> bool {
self.headers.is_empty() && self.query_params.is_empty()
}
}
pub struct CredentialInjector {
mappings: HashMap<String, CredentialMapping>,
allowed_secrets: Vec<String>,
}
impl CredentialInjector {
pub fn new(mappings: HashMap<String, CredentialMapping>, allowed_secrets: Vec<String>) -> Self {
Self {
mappings,
allowed_secrets,
}
}
pub fn find_credentials_for_host(&self, host: &str) -> Vec<&CredentialMapping> {
self.mappings
.values()
.filter(|mapping| {
mapping
.host_patterns
.iter()
.any(|pattern| host_matches_pattern(host, pattern))
})
.collect()
}
pub async fn inject(
&self,
user_id: &str,
host: &str,
store: &dyn SecretsStore,
) -> Result<InjectedCredentials, InjectionError> {
let matching_mappings = self.find_credentials_for_host(host);
if matching_mappings.is_empty() {
return Ok(InjectedCredentials::empty());
}
let mut result = InjectedCredentials::empty();
for mapping in matching_mappings {
if !self.is_secret_allowed(&mapping.secret_name) {
return Err(InjectionError::AccessDenied(mapping.secret_name.clone()));
}
let secret = store
.get_decrypted(user_id, &mapping.secret_name)
.await
.map_err(|e| match e {
SecretError::NotFound(name) => InjectionError::SecretNotFound(name),
SecretError::Expired => {
InjectionError::SecretExpired(mapping.secret_name.clone())
}
_ => InjectionError::DecryptionFailed(e.to_string()),
})?;
inject_credential(&mut result, &mapping.location, &secret);
}
Ok(result)
}
fn is_secret_allowed(&self, name: &str) -> bool {
let name_lower = name.to_lowercase();
for pattern in &self.allowed_secrets {
let pattern_lower = pattern.to_lowercase();
if pattern_lower == name_lower {
return true;
}
if let Some(prefix) = pattern_lower.strip_suffix('*')
&& name_lower.starts_with(prefix)
{
return true;
}
}
false
}
}
pub(crate) fn inject_credential(
result: &mut InjectedCredentials,
location: &CredentialLocation,
secret: &DecryptedSecret,
) {
match location {
CredentialLocation::AuthorizationBearer => {
result.headers.insert(
"Authorization".to_string(),
format!("Bearer {}", secret.expose()),
);
}
CredentialLocation::AuthorizationBasic { username } => {
let credentials = format!("{}:{}", username, secret.expose());
let encoded = base64_encode(credentials.as_bytes());
result
.headers
.insert("Authorization".to_string(), format!("Basic {}", encoded));
}
CredentialLocation::Header { name, prefix } => {
let value = match prefix {
Some(p) => format!("{}{}", p, secret.expose()),
None => secret.expose().to_string(),
};
result.headers.insert(name.clone(), value);
}
CredentialLocation::QueryParam { name } => {
result
.query_params
.insert(name.clone(), secret.expose().to_string());
}
CredentialLocation::UrlPath { .. } => {
}
}
}
pub(crate) fn host_matches_pattern(host: &str, pattern: &str) -> bool {
if pattern == host {
return true;
}
if let Some(suffix) = pattern.strip_prefix("*.")
&& host.ends_with(suffix)
&& host.len() > suffix.len()
{
let prefix = &host[..host.len() - suffix.len()];
if prefix.ends_with('.') || prefix.is_empty() {
return true;
}
}
false
}
fn base64_encode(input: &[u8]) -> String {
const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
let mut result = String::new();
let mut i = 0;
while i < input.len() {
let b0 = input[i];
let b1 = if i + 1 < input.len() { input[i + 1] } else { 0 };
let b2 = if i + 2 < input.len() { input[i + 2] } else { 0 };
result.push(ALPHABET[(b0 >> 2) as usize] as char);
result.push(ALPHABET[(((b0 & 0x03) << 4) | (b1 >> 4)) as usize] as char);
if i + 1 < input.len() {
result.push(ALPHABET[(((b1 & 0x0f) << 2) | (b2 >> 6)) as usize] as char);
} else {
result.push('=');
}
if i + 2 < input.len() {
result.push(ALPHABET[(b2 & 0x3f) as usize] as char);
} else {
result.push('=');
}
i += 3;
}
result
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use crate::secrets::{
CreateSecretParams, CredentialLocation, CredentialMapping, InMemorySecretsStore,
SecretsStore,
};
use crate::testing::credentials::{TEST_OPENAI_API_KEY, test_secrets_store};
use crate::tools::wasm::credential_injector::{
CredentialInjector, base64_encode, host_matches_pattern,
};
fn test_store() -> InMemorySecretsStore {
test_secrets_store()
}
#[test]
fn test_host_matches_exact() {
assert!(host_matches_pattern("api.openai.com", "api.openai.com"));
assert!(!host_matches_pattern("api.openai.com", "other.com"));
}
#[test]
fn test_host_matches_wildcard() {
assert!(host_matches_pattern("api.example.com", "*.example.com"));
assert!(host_matches_pattern("sub.api.example.com", "*.example.com"));
assert!(!host_matches_pattern("example.com", "*.example.com"));
}
#[test]
fn test_base64_encode() {
assert_eq!(base64_encode(b"hello"), "aGVsbG8=");
assert_eq!(base64_encode(b"user:pass"), "dXNlcjpwYXNz");
}
#[tokio::test]
async fn test_inject_bearer() {
let store = test_store();
store
.create(
"user1",
CreateSecretParams::new("openai_key", TEST_OPENAI_API_KEY),
)
.await
.unwrap();
let mut mappings = HashMap::new();
mappings.insert(
"openai".to_string(),
CredentialMapping {
secret_name: "openai_key".to_string(),
location: CredentialLocation::AuthorizationBearer,
host_patterns: vec!["api.openai.com".to_string()],
},
);
let injector = CredentialInjector::new(mappings, vec!["openai_key".to_string()]);
let result = injector
.inject("user1", "api.openai.com", &store)
.await
.unwrap();
assert_eq!(
result.headers.get("Authorization"),
Some(&format!("Bearer {TEST_OPENAI_API_KEY}"))
);
}
#[tokio::test]
async fn test_inject_custom_header() {
let store = test_store();
store
.create("user1", CreateSecretParams::new("api_key", "secret123"))
.await
.unwrap();
let mut mappings = HashMap::new();
mappings.insert(
"custom".to_string(),
CredentialMapping {
secret_name: "api_key".to_string(),
location: CredentialLocation::Header {
name: "X-API-Key".to_string(),
prefix: None,
},
host_patterns: vec!["*.example.com".to_string()],
},
);
let injector = CredentialInjector::new(mappings, vec!["api_key".to_string()]);
let result = injector
.inject("user1", "api.example.com", &store)
.await
.unwrap();
assert_eq!(
result.headers.get("X-API-Key"),
Some(&"secret123".to_string())
);
}
#[tokio::test]
async fn test_inject_basic_auth() {
let store = test_store();
store
.create("user1", CreateSecretParams::new("password", "mypassword"))
.await
.unwrap();
let mut mappings = HashMap::new();
mappings.insert(
"basic".to_string(),
CredentialMapping {
secret_name: "password".to_string(),
location: CredentialLocation::AuthorizationBasic {
username: "myuser".to_string(),
},
host_patterns: vec!["api.service.com".to_string()],
},
);
let injector = CredentialInjector::new(mappings, vec!["password".to_string()]);
let result = injector
.inject("user1", "api.service.com", &store)
.await
.unwrap();
let expected = format!("Basic {}", base64_encode(b"myuser:mypassword"));
assert_eq!(result.headers.get("Authorization"), Some(&expected));
}
#[tokio::test]
async fn test_no_credentials_for_host() {
let store = test_store();
let injector = CredentialInjector::new(HashMap::new(), vec![]);
let result = injector
.inject("user1", "unknown.com", &store)
.await
.unwrap();
assert!(result.is_empty());
}
#[tokio::test]
async fn test_access_denied_for_secret() {
let store = test_store();
store
.create("user1", CreateSecretParams::new("secret_key", "value"))
.await
.unwrap();
let mut mappings = HashMap::new();
mappings.insert(
"test".to_string(),
CredentialMapping {
secret_name: "secret_key".to_string(),
location: CredentialLocation::AuthorizationBearer,
host_patterns: vec!["api.test.com".to_string()],
},
);
let injector = CredentialInjector::new(mappings, vec![]);
let result = injector.inject("user1", "api.test.com", &store).await;
assert!(result.is_err());
}
use crate::tools::wasm::credential_injector::SharedCredentialRegistry;
#[test]
fn test_shared_registry_empty() {
let registry = SharedCredentialRegistry::new();
assert!(!registry.has_credentials_for_host("api.example.com"));
assert!(registry.find_for_host("api.example.com").is_empty());
}
#[test]
fn test_shared_registry_add_and_find() {
let registry = SharedCredentialRegistry::new();
registry.add_mappings(vec![
CredentialMapping::bearer("openai_key", "api.openai.com"),
CredentialMapping::header("github_token", "X-GitHub-Token", "*.github.com"),
]);
assert!(registry.has_credentials_for_host("api.openai.com"));
assert!(!registry.has_credentials_for_host("api.anthropic.com"));
let found = registry.find_for_host("api.openai.com");
assert_eq!(found.len(), 1);
assert_eq!(found[0].secret_name, "openai_key");
}
#[test]
fn test_shared_registry_wildcard_host() {
let registry = SharedCredentialRegistry::new();
registry.add_mappings(vec![CredentialMapping::bearer("gh_token", "*.github.com")]);
assert!(registry.has_credentials_for_host("api.github.com"));
assert!(registry.has_credentials_for_host("uploads.github.com"));
assert!(!registry.has_credentials_for_host("github.com"));
}
#[test]
fn test_shared_registry_multiple_adds() {
let registry = SharedCredentialRegistry::new();
registry.add_mappings(vec![CredentialMapping::bearer("key1", "api.example.com")]);
registry.add_mappings(vec![CredentialMapping::bearer("key2", "api.example.com")]);
let found = registry.find_for_host("api.example.com");
assert_eq!(found.len(), 2);
}
#[test]
fn test_shared_registry_remove_mappings_for_secrets() {
let registry = SharedCredentialRegistry::new();
registry.add_mappings(vec![
CredentialMapping::bearer("openai_key", "api.openai.com"),
CredentialMapping::bearer("gh_token", "*.github.com"),
CredentialMapping::header("openai_org", "OpenAI-Organization", "api.openai.com"),
]);
assert_eq!(registry.find_for_host("api.openai.com").len(), 2);
assert!(registry.has_credentials_for_host("api.github.com"));
registry.remove_mappings_for_secrets(&["openai_key".to_string(), "openai_org".to_string()]);
assert!(registry.find_for_host("api.openai.com").is_empty());
assert!(registry.has_credentials_for_host("api.github.com"));
}
#[test]
fn test_shared_registry_remove_nonexistent_is_noop() {
let registry = SharedCredentialRegistry::new();
registry.add_mappings(vec![CredentialMapping::bearer("key1", "api.example.com")]);
registry.remove_mappings_for_secrets(&["nonexistent".to_string()]);
assert_eq!(registry.find_for_host("api.example.com").len(), 1);
}
#[test]
fn test_shared_registry_thread_safety() {
use std::sync::Arc;
use std::thread;
let registry = Arc::new(SharedCredentialRegistry::new());
let handles: Vec<_> = (0..4)
.map(|i| {
let r = Arc::clone(®istry);
thread::spawn(move || {
r.add_mappings(vec![CredentialMapping::bearer(
format!("key_{}", i),
"api.example.com",
)]);
})
})
.collect();
for h in handles {
h.join().unwrap();
}
let found = registry.find_for_host("api.example.com");
assert_eq!(found.len(), 4);
}
}