use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
/// A source of named secrets.
///
/// The `Send + Sync` bound lets a single provider be shared between threads,
/// e.g. behind an `Arc`.
pub trait SecretsProvider: Send + Sync {
    /// Looks up `key`; `None` means this provider has no value for it.
    fn get(&self, key: &str) -> Option<String>;

    /// Looks up `key`, returning an owned copy of `default` when it is absent.
    fn get_or(&self, key: &str, default: &str) -> String {
        let found = self.get(key);
        found.unwrap_or_else(|| default.to_owned())
    }

    /// Reports whether this provider currently has a value for `key`.
    fn exists(&self, key: &str) -> bool {
        let found = self.get(key);
        found.is_some()
    }
}
/// [`SecretsProvider`] that reads secrets from the process environment,
/// using the secret key verbatim as the variable name.
#[derive(Debug, Default)]
pub struct EnvSecretsProvider;

impl SecretsProvider for EnvSecretsProvider {
    /// Returns environment variable `key`. Unset variables and values that
    /// are not valid Unicode both yield `None`.
    fn get(&self, key: &str) -> Option<String> {
        std::env::var_os(key).and_then(|raw| raw.into_string().ok())
    }
}
/// Connection and authentication settings for an external HashiCorp Vault.
///
/// `Debug` is implemented by hand (not derived) so that the bearer
/// credentials `token` and `secret_id` are redacted and cannot leak into
/// logs or panic messages.
#[derive(Clone)]
pub struct VaultConfig {
    /// Base URL of the Vault server (`VAULT_ADDR`).
    pub address: String,
    /// Pre-issued client token (`VAULT_TOKEN`). Redacted in `Debug` output.
    pub token: Option<String>,
    /// AppRole role ID (`VAULT_ROLE_ID`).
    pub role_id: Option<String>,
    /// AppRole secret ID (`VAULT_SECRET_ID`). Redacted in `Debug` output.
    pub secret_id: Option<String>,
    /// Vault role used for Kubernetes auth (`VAULT_K8S_ROLE`).
    pub k8s_role: Option<String>,
    /// Mount path of the secrets engine (`VAULT_MOUNT_PATH`, default `secret`).
    pub mount_path: String,
    /// Secret path under the mount (`VAULT_SECRET_PATH`, default `aegis`).
    pub secret_path: String,
    /// Whether the server's TLS certificate is validated
    /// (`VAULT_TLS_VERIFY`, default `true`).
    pub tls_verify: bool,
}

impl std::fmt::Debug for VaultConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Show only whether a credential is present, never its value.
        const REDACTED: &str = "<redacted>";
        f.debug_struct("VaultConfig")
            .field("address", &self.address)
            .field("token", &self.token.as_ref().map(|_| REDACTED))
            .field("role_id", &self.role_id)
            .field("secret_id", &self.secret_id.as_ref().map(|_| REDACTED))
            .field("k8s_role", &self.k8s_role)
            .field("mount_path", &self.mount_path)
            .field("secret_path", &self.secret_path)
            .field("tls_verify", &self.tls_verify)
            .finish()
    }
}

impl Default for VaultConfig {
    /// Builds a configuration from `VAULT_*` environment variables.
    ///
    /// Unset variables fall back to `http://127.0.0.1:8200`, no credentials,
    /// mount `secret`, path `aegis`, and TLS verification enabled.
    fn default() -> Self {
        Self {
            address: std::env::var("VAULT_ADDR")
                .unwrap_or_else(|_| "http://127.0.0.1:8200".to_string()),
            token: std::env::var("VAULT_TOKEN").ok(),
            role_id: std::env::var("VAULT_ROLE_ID").ok(),
            secret_id: std::env::var("VAULT_SECRET_ID").ok(),
            k8s_role: std::env::var("VAULT_K8S_ROLE").ok(),
            mount_path: std::env::var("VAULT_MOUNT_PATH")
                .unwrap_or_else(|_| "secret".to_string()),
            secret_path: std::env::var("VAULT_SECRET_PATH")
                .unwrap_or_else(|_| "aegis".to_string()),
            tls_verify: std::env::var("VAULT_TLS_VERIFY")
                .map(|v| vault_tls_verify_enabled(&v))
                .unwrap_or(true),
        }
    }
}

/// Interprets a `VAULT_TLS_VERIFY` value.
///
/// Only an explicit `false` or `0` disables verification. The comparison is
/// ASCII case-insensitive and ignores surrounding whitespace, so `FALSE` or
/// `" false "` are honoured rather than silently keeping verification on.
fn vault_tls_verify_enabled(value: &str) -> bool {
    let value = value.trim();
    !(value.eq_ignore_ascii_case("false") || value == "0")
}
/// Secrets provider backed by an external HashiCorp Vault server.
///
/// Secrets are read over Vault's HTTP API from KV-v2 style paths
/// (`v1/{mount_path}/data/{secret_path}`) and kept in an in-memory cache.
/// The synchronous [`SecretsProvider::get`] implementation only consults that
/// cache (falling back to the environment), so `authenticate` followed by
/// `load_secrets` / `read_secret` must have run for Vault values to be visible.
pub struct VaultSecretsProvider {
    config: VaultConfig,
    client: reqwest::Client,
    /// Vault client token: seeded from `VaultConfig::token`, replaced on login.
    token: RwLock<Option<String>>,
    /// Secret values fetched so far, keyed by secret name.
    cache: RwLock<HashMap<String, String>>,
}

impl VaultSecretsProvider {
    /// Creates a provider for `config`.
    ///
    /// TLS certificate validation is disabled when `config.tls_verify` is
    /// `false`.
    ///
    /// # Panics
    ///
    /// Panics if the HTTP client cannot be built.
    pub fn new(config: VaultConfig) -> Self {
        let client = reqwest::Client::builder()
            .danger_accept_invalid_certs(!config.tls_verify)
            .build()
            .expect("Failed to create HTTP client");
        let token = config.token.clone();
        Self {
            config,
            client,
            token: RwLock::new(token),
            cache: RwLock::new(HashMap::new()),
        }
    }

    /// Returns a provider configured from `VAULT_*` environment variables, or
    /// `None` when `VAULT_ADDR` is not set (no external Vault configured).
    pub fn from_env() -> Option<Self> {
        // Check first so the rest of the configuration is only read when needed.
        if std::env::var("VAULT_ADDR").is_err() {
            return None;
        }
        Some(Self::new(VaultConfig::default()))
    }

    /// Ensures the provider holds a Vault client token.
    ///
    /// An existing token (seeded from `VaultConfig::token`) is used as-is and
    /// is not validated. Otherwise AppRole login is tried when both a role ID
    /// and a secret ID are configured, then Kubernetes login when a role is set.
    ///
    /// # Errors
    ///
    /// Returns a description of the failure when login fails or when no
    /// authentication method is configured.
    pub async fn authenticate(&self) -> Result<(), String> {
        if self.token.read().is_some() {
            return Ok(());
        }
        if let (Some(role_id), Some(secret_id)) = (&self.config.role_id, &self.config.secret_id) {
            return self.auth_approle(role_id, secret_id).await;
        }
        if let Some(k8s_role) = &self.config.k8s_role {
            return self.auth_kubernetes(k8s_role).await;
        }
        Err("No authentication method configured. Set VAULT_TOKEN, VAULT_ROLE_ID/VAULT_SECRET_ID, or VAULT_K8S_ROLE".to_string())
    }

    /// Logs in with AppRole credentials and stores the returned token.
    async fn auth_approle(&self, role_id: &str, secret_id: &str) -> Result<(), String> {
        let body = serde_json::json!({
            "role_id": role_id,
            "secret_id": secret_id
        });
        self.login("approle", "AppRole", &body).await?;
        tracing::info!("Successfully authenticated with Vault using AppRole");
        Ok(())
    }

    /// Logs in with the pod's Kubernetes service-account JWT and stores the
    /// returned token.
    async fn auth_kubernetes(&self, role: &str) -> Result<(), String> {
        // NOTE(review): blocking file read inside an async fn. The file is
        // small, but consider an async read if this runs on a busy executor.
        let jwt = std::fs::read_to_string("/var/run/secrets/kubernetes.io/serviceaccount/token")
            .map_err(|e| format!("Failed to read K8s service account token: {}", e))?;
        let body = serde_json::json!({
            "role": role,
            "jwt": jwt
        });
        self.login("kubernetes", "K8s", &body).await?;
        tracing::info!("Successfully authenticated with Vault using Kubernetes");
        Ok(())
    }

    /// POSTs `body` to `v1/auth/{method}/login` and stores `auth.client_token`
    /// from the response as the active token.
    ///
    /// `label` only appears in error messages ("Vault {label} auth failed: …").
    async fn login(
        &self,
        method: &str,
        label: &str,
        body: &serde_json::Value,
    ) -> Result<(), String> {
        let url = self.api_url(&format!("auth/{}/login", method));
        let response = self
            .client
            .post(&url)
            .json(body)
            .send()
            .await
            .map_err(|e| format!("Vault {} auth failed: {}", label, e))?;
        let status = response.status();
        if !status.is_success() {
            return Err(format!("Vault {} auth failed: HTTP {}", label, status));
        }
        let data: serde_json::Value = response
            .json()
            .await
            .map_err(|e| format!("Failed to parse Vault response: {}", e))?;
        let token = data["auth"]["client_token"]
            .as_str()
            .ok_or("No token in Vault response")?
            .to_string();
        *self.token.write() = Some(token);
        Ok(())
    }

    /// Builds `{address}/v1/{path}`, tolerating a trailing `/` on the address
    /// (e.g. `VAULT_ADDR=https://vault:8200/`) instead of producing `//v1`.
    fn api_url(&self, path: &str) -> String {
        format!("{}/v1/{}", self.config.address.trim_end_matches('/'), path)
    }

    /// GETs `v1/{path}` with the current token and returns the parsed JSON body.
    ///
    /// # Errors
    ///
    /// "Not authenticated with Vault" when no token is held; otherwise
    /// transport, HTTP-status or JSON-parse failures.
    async fn fetch_kv(&self, path: &str) -> Result<serde_json::Value, String> {
        let token = self
            .token
            .read()
            .clone()
            .ok_or("Not authenticated with Vault")?;
        let url = self.api_url(path);
        let response = self
            .client
            .get(&url)
            .header("X-Vault-Token", &token)
            .send()
            .await
            .map_err(|e| format!("Vault read failed: {}", e))?;
        let status = response.status();
        if !status.is_success() {
            return Err(format!("Vault read failed: HTTP {}", status));
        }
        response
            .json::<serde_json::Value>()
            .await
            .map_err(|e| format!("Failed to parse Vault response: {}", e))
    }

    /// Returns secret `key`.
    ///
    /// On a cache miss, it reads the `value` field of
    /// `v1/{mount_path}/data/{secret_path}/{key}` and caches the result.
    /// `key` is interpolated into the request path verbatim.
    ///
    /// # Errors
    ///
    /// Fails when not authenticated, on transport/HTTP/parse errors, or when
    /// the secret has no string `value` field.
    pub async fn read_secret(&self, key: &str) -> Result<String, String> {
        if let Some(value) = self.get_cached(key) {
            return Ok(value);
        }
        let path = format!(
            "{}/data/{}/{}",
            self.config.mount_path, self.config.secret_path, key
        );
        let data = self.fetch_kv(&path).await?;
        let value = data["data"]["data"]["value"]
            .as_str()
            .ok_or_else(|| format!("Secret '{}' not found or has no 'value' field", key))?
            .to_string();
        self.cache.write().insert(key.to_string(), value.clone());
        Ok(value)
    }

    /// Loads every field stored at `v1/{mount_path}/data/{secret_path}` into
    /// the cache, one cache entry per field. Fields whose value is not a JSON
    /// string are skipped.
    ///
    /// # Errors
    ///
    /// Fails when not authenticated or on transport/HTTP/parse errors. A
    /// response without a `data.data` object is not an error; nothing is loaded.
    pub async fn load_secrets(&self) -> Result<(), String> {
        let path = format!("{}/data/{}", self.config.mount_path, self.config.secret_path);
        let data = self.fetch_kv(&path).await?;
        if let Some(secrets) = data["data"]["data"].as_object() {
            let loaded = {
                let mut cache = self.cache.write();
                let mut count = 0usize;
                for (key, value) in secrets {
                    if let Some(v) = value.as_str() {
                        cache.insert(key.clone(), v.to_string());
                        count += 1;
                    }
                }
                count
            };
            // Report what this call loaded, not the total cache size, which
            // also includes entries cached earlier by `read_secret`.
            tracing::info!("Loaded {} secrets from Vault", loaded);
        }
        Ok(())
    }

    /// Returns a clone of the cached value for `key`, without any network I/O.
    pub fn get_cached(&self, key: &str) -> Option<String> {
        self.cache.read().get(key).cloned()
    }
}
impl SecretsProvider for VaultSecretsProvider {
    /// Serves values already cached from Vault. Keys that are not cached fall
    /// back to the environment variable of the same name.
    fn get(&self, key: &str) -> Option<String> {
        self.get_cached(key).or_else(|| std::env::var(key).ok())
    }
}
/// Ordered chain of [`SecretsProvider`]s consulted front to back.
pub struct SecretsManager {
    providers: Vec<Arc<dyn SecretsProvider>>,
}

impl SecretsManager {
    /// Builds a manager that consults `providers` in the given order.
    pub fn new(providers: Vec<Arc<dyn SecretsProvider>>) -> Self {
        SecretsManager { providers }
    }

    /// Manager backed solely by environment variables.
    pub fn env_only() -> Self {
        let chain: Vec<Arc<dyn SecretsProvider>> = vec![Arc::new(EnvSecretsProvider)];
        Self::new(chain)
    }

    /// Manager that consults `vault` first and environment variables second.
    pub fn with_vault_fallback(vault: VaultSecretsProvider) -> Self {
        let chain: Vec<Arc<dyn SecretsProvider>> =
            vec![Arc::new(vault), Arc::new(EnvSecretsProvider)];
        Self::new(chain)
    }
}
impl SecretsProvider for SecretsManager {
    /// Asks each provider in chain order and returns the first value found,
    /// or `None` when no provider has `key`.
    fn get(&self, key: &str) -> Option<String> {
        self.providers
            .iter()
            .find_map(|provider| provider.get(key))
    }
}
pub struct AegisVaultProvider {
vault: std::sync::Arc<aegis_vault::AegisVault>,
}
impl AegisVaultProvider {
pub fn new(vault: std::sync::Arc<aegis_vault::AegisVault>) -> Self {
Self { vault }
}
}
impl SecretsProvider for AegisVaultProvider {
fn get(&self, key: &str) -> Option<String> {
self.vault.get(key, "secrets_manager").ok()
}
}
/// Assembles the default secrets provider chain, in lookup priority order:
///
/// 1. the built-in vault, when one is supplied and it is not sealed;
/// 2. an external Vault, when `VAULT_ADDR` is set and authentication succeeds
///    (a failing initial `load_secrets` is only logged; the provider is still
///    added);
/// 3. process environment variables, always last.
pub async fn init_secrets_manager(
    built_in_vault: Option<std::sync::Arc<aegis_vault::AegisVault>>,
) -> SecretsManager {
    let mut chain: Vec<Arc<dyn SecretsProvider>> = Vec::new();

    match built_in_vault {
        Some(vault) if vault.is_sealed() => {
            tracing::info!("Built-in vault is sealed; skipping as secrets provider");
        }
        Some(vault) => {
            tracing::info!("Secrets provider chain: built-in vault (active)");
            chain.push(Arc::new(AegisVaultProvider::new(vault)));
        }
        None => {}
    }

    if let Some(external) = VaultSecretsProvider::from_env() {
        match external.authenticate().await {
            Err(e) => {
                tracing::warn!("External Vault authentication failed: {}.", e);
            }
            Ok(()) => {
                if let Err(e) = external.load_secrets().await {
                    tracing::warn!("Failed to load secrets from external Vault: {}.", e);
                }
                tracing::info!("Secrets provider chain: +external Vault");
                chain.push(Arc::new(external));
            }
        }
    }

    chain.push(Arc::new(EnvSecretsProvider));
    tracing::info!(
        "Secrets provider chain: +environment variables ({} providers total)",
        chain.len()
    );
    SecretsManager::new(chain)
}
pub mod keys {
    //! Names of well-known secrets.
    //!
    //! When looked up through `EnvSecretsProvider`, a name is used verbatim as
    //! the environment variable name.

    // --- Administrator account -------------------------------------------

    /// Key for the administrator username.
    pub const ADMIN_USERNAME: &str = "AEGIS_ADMIN_USERNAME";
    /// Key for the administrator password.
    pub const ADMIN_PASSWORD: &str = "AEGIS_ADMIN_PASSWORD";
    /// Key for the administrator e-mail address.
    pub const ADMIN_EMAIL: &str = "AEGIS_ADMIN_EMAIL";

    // --- Server TLS -------------------------------------------------------

    /// Key for the server TLS certificate location.
    pub const TLS_CERT_PATH: &str = "AEGIS_TLS_CERT";
    /// Key for the server TLS private key location.
    pub const TLS_KEY_PATH: &str = "AEGIS_TLS_KEY";

    // --- Cluster TLS ------------------------------------------------------

    /// Key for the cluster CA certificate location.
    pub const CLUSTER_CA_CERT_PATH: &str = "AEGIS_CLUSTER_CA_CERT";
    /// Key for the cluster client certificate location.
    pub const CLUSTER_CLIENT_CERT_PATH: &str = "AEGIS_CLUSTER_CLIENT_CERT";
    /// Key for the cluster client private key location.
    pub const CLUSTER_CLIENT_KEY_PATH: &str = "AEGIS_CLUSTER_CLIENT_KEY";

    // --- Cryptographic material and integrations -------------------------

    /// Key for the encryption key.
    pub const ENCRYPTION_KEY: &str = "AEGIS_ENCRYPTION_KEY";
    /// Key for the JWT signing secret.
    pub const JWT_SECRET: &str = "AEGIS_JWT_SECRET";
    /// Key for the LDAP bind password.
    pub const LDAP_BIND_PASSWORD: &str = "AEGIS_LDAP_BIND_PASSWORD";
    /// Key for the OAuth client secret.
    pub const OAUTH_CLIENT_SECRET: &str = "AEGIS_OAUTH_CLIENT_SECRET";
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variable `VaultConfig::default` reads.
    ///
    /// A test that checks the built-in defaults must clear all of them, not
    /// just `VAULT_ADDR`. Otherwise a value exported in the developer's or
    /// CI's shell leaks into the assertions.
    const VAULT_ENV_VARS: [&str; 8] = [
        "VAULT_ADDR",
        "VAULT_TOKEN",
        "VAULT_ROLE_ID",
        "VAULT_SECRET_ID",
        "VAULT_K8S_ROLE",
        "VAULT_MOUNT_PATH",
        "VAULT_SECRET_PATH",
        "VAULT_TLS_VERIFY",
    ];

    /// Provider that answers every lookup with a fixed value.
    struct FixedProvider(&'static str);

    impl SecretsProvider for FixedProvider {
        fn get(&self, _key: &str) -> Option<String> {
            Some(self.0.to_string())
        }
    }

    #[test]
    fn test_env_provider() {
        std::env::set_var("TEST_SECRET_KEY", "test_value");
        let provider = EnvSecretsProvider;
        assert_eq!(
            provider.get("TEST_SECRET_KEY"),
            Some("test_value".to_string())
        );
        assert_eq!(provider.get("NONEXISTENT_KEY"), None);
        std::env::remove_var("TEST_SECRET_KEY");
    }

    #[test]
    fn test_secrets_manager_fallback() {
        std::env::set_var("TEST_FALLBACK_KEY", "fallback_value");
        let manager = SecretsManager::env_only();
        assert_eq!(
            manager.get("TEST_FALLBACK_KEY"),
            Some("fallback_value".to_string())
        );
        std::env::remove_var("TEST_FALLBACK_KEY");
    }

    #[test]
    fn test_secrets_manager_first_provider_wins() {
        let providers: Vec<std::sync::Arc<dyn SecretsProvider>> = vec![
            std::sync::Arc::new(FixedProvider("first")),
            std::sync::Arc::new(FixedProvider("second")),
        ];
        let manager = SecretsManager::new(providers);
        assert_eq!(manager.get("ANY_KEY"), Some("first".to_string()));
        assert!(manager.exists("ANY_KEY"));

        let empty = SecretsManager::new(Vec::new());
        assert_eq!(empty.get("ANY_KEY"), None);
    }

    #[test]
    fn test_get_or_default() {
        let provider = EnvSecretsProvider;
        assert_eq!(provider.get_or("NONEXISTENT", "default"), "default");
        assert!(!provider.exists("NONEXISTENT"));
    }

    #[test]
    fn test_vault_config_from_env() {
        for &var in VAULT_ENV_VARS.iter() {
            std::env::remove_var(var);
        }
        let config = VaultConfig::default();
        assert_eq!(config.address, "http://127.0.0.1:8200");
        assert!(config.token.is_none());
        assert!(config.role_id.is_none());
        assert!(config.secret_id.is_none());
        assert!(config.k8s_role.is_none());
        assert_eq!(config.mount_path, "secret");
        assert_eq!(config.secret_path, "aegis");
        assert!(config.tls_verify);
    }
}