use std::collections::HashMap;
use std::fmt;
use zeroize::Zeroize;
/// Owned credential bytes that are wiped when dropped and, where the OS allows,
/// pinned in RAM while alive.
///
/// Formatting with `{:?}` reports only the length, never the contents.
pub struct Secret {
    /// `true` when `lock_memory` succeeded for `bytes`, i.e. drop must unlock it.
    mlocked: bool,
    /// The raw key material, exactly as handed to `Secret::new`.
    bytes: Vec<u8>,
}
impl Secret {
    /// Takes ownership of `bytes` and tries to pin them in memory.
    ///
    /// Locking is best-effort: if it fails the secret is still created, and the
    /// outcome is recorded so that drop only unlocks what was actually locked.
    /// The `Vec` is moved rather than copied, so the heap buffer that gets locked
    /// is the same one the secret keeps.
    pub fn new(bytes: Vec<u8>) -> Self {
        Self {
            mlocked: lock_memory(&bytes),
            bytes,
        }
    }

    /// Borrows the raw key material.
    ///
    /// Avoid copying the returned slice into long-lived buffers that are not wiped.
    pub fn expose(&self) -> &[u8] {
        self.bytes.as_slice()
    }
}
impl Drop for Secret {
    /// Wipes the key material and releases any page lock taken in `Secret::new`.
    ///
    /// The order matters. The live bytes are first wiped while their pages are still
    /// pinned, so plaintext cannot reach swap after the lock is released. The pages are
    /// then unlocked over the *same* span that `lock_memory` pinned. Only afterwards is
    /// the whole `Vec`, including spare capacity, zeroized and cleared.
    ///
    /// NOTE(review): on Linux, page locks do not nest. Unlocking may therefore also
    /// unpin a page shared with another live `Secret` in the same allocation page.
    /// Confirm whether that matters for this deployment.
    fn drop(&mut self) {
        // The slice impl overwrites the bytes without truncating the Vec, so
        // `self.bytes` still describes exactly the range passed to `lock_memory`.
        self.bytes.as_mut_slice().zeroize();
        if self.mlocked {
            unlock_memory(&self.bytes);
        }
        // `Vec::zeroize` also wipes spare capacity but then clears the Vec (len 0).
        // It must run *after* the unlock. Running it first hands `unlock_memory` an
        // empty slice, which makes it return early and never call `munlock`.
        self.bytes.zeroize();
    }
}
/// Best-effort attempt to pin `buf`'s pages in RAM so the credential is not written to swap.
///
/// Returns `true` only when `mlock(2)` succeeded. In that case the same span must later
/// be released with `unlock_memory`. Empty buffers are never locked and yield `false`.
/// A failure (for example a too-small `RLIMIT_MEMLOCK` or missing privilege) is logged
/// together with the OS error and tolerated. The secret stays usable, just without
/// swap protection.
#[cfg(unix)]
fn lock_memory(buf: &[u8]) -> bool {
    if buf.is_empty() {
        return false;
    }
    // SAFETY: `buf` is a valid, initialised slice for the duration of the call, and
    // mlock only changes page residency; it neither reads nor writes the memory.
    let rc = unsafe { libc::mlock(buf.as_ptr().cast::<libc::c_void>(), buf.len()) };
    if rc == 0 {
        return true;
    }
    // Capture errno right away so the log says *why* locking failed.
    let err = std::io::Error::last_os_error();
    tracing::warn!(
        error = %err,
        "mlock of credential pages failed (continuing without swap protection)"
    );
    false
}
/// Releases a page lock previously taken by `lock_memory` over the same span.
///
/// The `munlock(2)` return code is deliberately ignored. The caller (`Secret`'s drop)
/// has no meaningful way to react to a failure.
#[cfg(unix)]
fn unlock_memory(buf: &[u8]) {
    if !buf.is_empty() {
        // SAFETY: `buf` is a valid slice for the duration of the call, and munlock only
        // changes page residency; it does not access the memory contents.
        let _ = unsafe { libc::munlock(buf.as_ptr().cast::<libc::c_void>(), buf.len()) };
    }
}
/// Fallback for targets without `mlock`: nothing is ever pinned.
///
/// Always returns `false`, so drop never attempts an unlock on these targets.
#[cfg(not(unix))]
fn lock_memory(_: &[u8]) -> bool {
    tracing::debug!("mlock not available on this platform; credential pages not pinned");
    false
}
/// Fallback for targets without `munlock`: the non-unix `lock_memory` never pins
/// anything, so there is nothing to release.
#[cfg(not(unix))]
fn unlock_memory(_: &[u8]) {
    // Intentionally empty.
}
impl fmt::Debug for Secret {
    /// Prints only the byte count. The key material itself is never formatted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let len = self.bytes.len();
        write!(f, "Secret(REDACTED, {len} bytes)")
    }
}
/// One stored credential plus its optional expiry deadline.
struct Entry {
    /// Monotonic deadline from which the credential is no longer served.
    /// `None` means it never expires.
    expires_at: Option<std::time::Instant>,
    /// The key material itself.
    secret: Secret,
}
impl Entry {
    /// Reports whether this entry's deadline has been reached.
    ///
    /// Entries without a deadline never expire. Reaching the deadline itself counts
    /// as expired (the comparison is inclusive).
    fn is_expired(&self) -> bool {
        match self.expires_at {
            Some(deadline) => deadline <= std::time::Instant::now(),
            None => false,
        }
    }
}
/// Map from lower-cased host name to the credential injected for that host.
///
/// The map sits behind an `RwLock`. Lookups share the lock, and `rotate` replaces
/// entries at runtime under the exclusive lock. `Default` yields an empty store.
#[derive(Default)]
pub struct CredentialStore {
    /// Keyed by `host.to_ascii_lowercase()`.
    entries: std::sync::RwLock<HashMap<String, Entry>>,
}
impl CredentialStore {
pub fn from_pairs(pairs: impl IntoIterator<Item = (String, Vec<u8>)>) -> Self {
let map: HashMap<String, Entry> = pairs
.into_iter()
.map(|(host, bytes)| {
(
host.to_ascii_lowercase(),
Entry {
secret: Secret::new(bytes),
expires_at: None,
},
)
})
.collect();
Self {
entries: std::sync::RwLock::new(map),
}
}
pub fn from_env() -> Self {
match std::env::var("AA_PROXY_PROVIDER_KEYS") {
Ok(val) if !val.is_empty() => {
let pairs: Vec<(String, Vec<u8>)> = val
.split(',')
.filter_map(|entry| {
let entry = entry.trim();
if entry.is_empty() {
return None;
}
match entry.split_once('=') {
Some((host, key)) if !host.trim().is_empty() && !key.is_empty() => {
Some((host.trim().to_string(), key.as_bytes().to_vec()))
}
_ => {
tracing::warn!("skipping malformed AA_PROXY_PROVIDER_KEYS entry (expected host=key)");
None
}
}
})
.collect();
let store = Self::from_pairs(pairs);
tracing::info!(hosts = store.len(), "loaded provider credentials for egress injection");
store
}
_ => Self::default(),
}
}
pub fn authorization_for(&self, host: &str) -> Option<Vec<u8>> {
let guard = self.entries.read().ok()?;
let entry = guard.get(&host.to_ascii_lowercase())?;
if entry.is_expired() {
tracing::debug!(%host, "configured provider credential has expired; not injecting");
return None;
}
let key = entry.secret.expose();
let mut buf = Vec::with_capacity(key.len() + 7);
buf.extend_from_slice(b"Bearer ");
buf.extend_from_slice(key);
Some(buf)
}
pub fn rotate(&self, host: &str, new_secret: Vec<u8>, ttl: Option<std::time::Duration>) {
let expires_at = ttl.map(|d| std::time::Instant::now() + d);
let entry = Entry {
secret: Secret::new(new_secret),
expires_at,
};
if let Ok(mut guard) = self.entries.write() {
guard.insert(host.to_ascii_lowercase(), entry);
} else {
tracing::error!(%host, "credential store lock poisoned; rotation skipped");
}
}
pub fn len(&self) -> usize {
self.entries.read().map(|g| g.len()).unwrap_or(0)
}
pub fn is_empty(&self) -> bool {
self.entries.read().map(|g| g.is_empty()).unwrap_or(true)
}
}
impl fmt::Debug for CredentialStore {
    /// Lists the configured host names and never the secrets.
    ///
    /// If the lock is poisoned, the host list is shown as empty instead of failing.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let hosts = match self.entries.read() {
            Ok(map) => map.keys().cloned().collect::<Vec<String>>(),
            Err(_) => Vec::new(),
        };
        f.debug_struct("CredentialStore")
            .field("hosts", &hosts)
            .field("secrets", &"REDACTED")
            .finish()
    }
}
// Unit tests for `Secret` and `CredentialStore`. They use only hard-coded fake keys and
// never call `from_env`, so they do not depend on the process environment.
#[cfg(test)]
mod tests {
use super::*;
// Host lookup is ASCII case-insensitive. The host is stored mixed-case and queried in
// lower and upper case. The value returned is the key prefixed with "Bearer ".
#[test]
fn authorization_for_is_case_insensitive_and_bearer_prefixed() {
let store = CredentialStore::from_pairs([("API.OpenAI.com".to_string(), b"sk-secret".to_vec())]);
assert_eq!(
store.authorization_for("api.openai.com").as_deref(),
Some(&b"Bearer sk-secret"[..])
);
assert_eq!(
store.authorization_for("API.OPENAI.COM").as_deref(),
Some(&b"Bearer sk-secret"[..])
);
}
// A host with no configured credential must never receive one.
#[test]
fn authorization_for_unknown_host_is_none() {
let store = CredentialStore::from_pairs([("api.openai.com".to_string(), b"sk-secret".to_vec())]);
assert!(store.authorization_for("evil.attacker.com").is_none());
}
// Neither `Debug` impl may print key material. The store's `Debug` must still name the
// configured host so it remains useful for diagnostics.
#[test]
fn debug_never_contains_key_material() {
let store = CredentialStore::from_pairs([("api.openai.com".to_string(), b"sk-TOPSECRET-1234".to_vec())]);
let store_dbg = format!("{store:?}");
assert!(
!store_dbg.contains("sk-TOPSECRET-1234"),
"store Debug leaked key: {store_dbg}"
);
assert!(store_dbg.contains("REDACTED"));
assert!(
store_dbg.contains("api.openai.com"),
"store Debug should still name the host"
);
let secret = Secret::new(b"sk-TOPSECRET-1234".to_vec());
let secret_dbg = format!("{secret:?}");
assert!(
!secret_dbg.contains("sk-TOPSECRET-1234"),
"secret Debug leaked key: {secret_dbg}"
);
assert!(secret_dbg.contains("REDACTED"));
}
// NOTE(review): despite its name, this test never constructs a `Secret` or runs
// `Secret::drop`; it only checks `Vec::zeroize` itself. It reads the buffer through a
// raw pointer after `zeroize()` has cleared the Vec, which relies on the allocation
// staying alive until the explicit `drop(bytes)` at the end. It does not cover the
// munlock path in `Drop`.
#[test]
fn dropping_secret_runs_zeroize_on_its_buffer() {
let mut bytes = b"sk-zeroize-me-please".to_vec();
let ptr = bytes.as_ptr();
let len = bytes.len();
bytes.zeroize();
let observed = unsafe { std::slice::from_raw_parts(ptr, len) };
assert!(
observed.iter().all(|&b| b == 0),
"zeroize left plaintext behind: {observed:?}"
);
drop(bytes);
}
// Constructs, exposes, formats and drops a `Secret` end to end. Whether `mlock`
// actually succeeded is not asserted: `lock_memory` treats failure as non-fatal and
// only logs a warning.
#[test]
fn mlocked_secret_constructs_and_exposes_without_leaking() {
let secret = Secret::new(b"sk-mlock-me".to_vec());
assert_eq!(secret.expose(), b"sk-mlock-me");
let dbg = format!("{secret:?}");
assert!(!dbg.contains("sk-mlock-me"), "Debug leaked key: {dbg}");
drop(secret);
}
// A default store has no entries and serves no credentials.
#[test]
fn empty_store_reports_empty() {
let store = CredentialStore::default();
assert!(store.is_empty());
assert_eq!(store.len(), 0);
assert!(store.authorization_for("api.openai.com").is_none());
}
// A zero TTL puts the deadline at the moment of rotation. Expiry is inclusive of the
// deadline, so the entry is already expired when it is queried.
#[test]
fn expired_entry_is_not_injected() {
let store = CredentialStore::default();
store.rotate(
"api.openai.com",
b"sk-expired".to_vec(),
Some(std::time::Duration::ZERO),
);
assert!(
store.authorization_for("api.openai.com").is_none(),
"expired credential must not be injected"
);
}
// Rotating an existing host replaces its secret in place rather than adding a second
// entry.
#[test]
fn rotate_replaces_secret_and_serves_the_new_one() {
let store = CredentialStore::from_pairs([("api.openai.com".to_string(), b"sk-old".to_vec())]);
assert_eq!(
store.authorization_for("api.openai.com").as_deref(),
Some(&b"Bearer sk-old"[..])
);
store.rotate("api.openai.com", b"sk-new".to_vec(), None);
assert_eq!(
store.authorization_for("api.openai.com").as_deref(),
Some(&b"Bearer sk-new"[..]),
"rotate must serve the new secret"
);
assert_eq!(store.len(), 1);
}
// Rotation can also add a brand-new host. The 60 s TTL keeps it valid for the test's
// duration.
#[test]
fn rotate_installs_credential_for_a_new_host() {
let store = CredentialStore::default();
store.rotate(
"api.anthropic.com",
b"sk-ant-leased".to_vec(),
Some(std::time::Duration::from_secs(60)),
);
assert_eq!(
store.authorization_for("api.anthropic.com").as_deref(),
Some(&b"Bearer sk-ant-leased"[..])
);
}
// Entries created via `from_pairs` have no deadline and are always served.
#[test]
fn non_expiring_entry_stays_valid() {
let store = CredentialStore::from_pairs([("api.openai.com".to_string(), b"sk-forever".to_vec())]);
assert!(store.authorization_for("api.openai.com").is_some());
}
}