use std::path::{Path, PathBuf};
use serde::Deserialize;
use crate::errors::{SafeError, SafeResult};
/// Authentication settings for a HashiCorp Vault (`hcp`) pull source.
///
/// Internally tagged by the `method` key, whose value is the lowercase variant
/// name: `token` or `approle`. Example (YAML):
///
/// ```yaml
/// auth:
///   method: approle
///   role_id: my-role
///   secret_id: ${MY_SECRET_ID}
/// ```
#[derive(Debug, Clone, Deserialize, PartialEq)]
#[serde(tag = "method", rename_all = "lowercase")]
pub enum VaultAuthConfig {
    /// Static token authentication (`method: token`).
    Token {
        /// The token itself; may be omitted, in which case it is `None`.
        /// NOTE(review): how a `None` token is resolved (e.g. from the
        /// environment) is decided by the caller, not here — confirm.
        #[serde(default)]
        token: Option<String>,
    },
    /// AppRole authentication (`method: approle`). Both fields are required;
    /// `${VAR}` placeholders in them are substituted by
    /// [`VaultAuthConfig::expand_env_vars`].
    Approle { role_id: String, secret_id: String },
}
impl VaultAuthConfig {
    /// Consumes the config and returns it with `${VAR}` placeholders expanded.
    ///
    /// Only the AppRole credentials (`role_id` and `secret_id`) go through
    /// [`expand_env_var_str`]; a `Token` config is handed back untouched.
    pub fn expand_env_vars(self) -> Self {
        if let VaultAuthConfig::Approle { role_id, secret_id } = self {
            let expanded_role = expand_env_var_str(&role_id);
            let expanded_secret = expand_env_var_str(&secret_id);
            return VaultAuthConfig::Approle {
                role_id: expanded_role,
                secret_id: expanded_secret,
            };
        }
        // Anything that is not AppRole (i.e. `Token`) passes through as-is.
        self
    }
}
/// Expands `${VAR}` placeholders in `s` using the current process environment.
///
/// Each `${NAME}` whose variable is set (and valid Unicode) is replaced by its
/// value. Placeholders naming unset or non-Unicode variables are left verbatim,
/// so unresolved references remain visible. An unterminated `${` (no closing
/// `}`) and everything after it are copied through unchanged, exactly once.
/// Substituted values are not re-scanned, so a value that itself contains
/// `${...}` is inserted literally.
pub fn expand_env_var_str(s: &str) -> String {
    // Fast path: nothing that looks like a placeholder.
    if !s.contains("${") {
        return s.to_string();
    }
    let mut result = String::with_capacity(s.len());
    let mut rest = s;
    while let Some(start) = rest.find("${") {
        // Copy the literal text that precedes the placeholder.
        result.push_str(&rest[..start]);
        // `${` is two ASCII bytes, so `start + 2` is a valid char boundary.
        let after_open = &rest[start + 2..];
        match after_open.find('}') {
            Some(end) => {
                let var_name = &after_open[..end];
                match std::env::var(var_name) {
                    Ok(val) => result.push_str(&val),
                    // Unset or non-Unicode variable: keep the placeholder verbatim.
                    Err(_) => {
                        result.push_str("${");
                        result.push_str(var_name);
                        result.push('}');
                    }
                }
                // Resume after the closing `}`; the substituted value is not re-scanned.
                rest = &after_open[end + 1..];
            }
            None => {
                // Unterminated `${`: rewind to it so the whole tail (including the
                // `${`) is appended exactly once by the final push below.
                rest = &rest[start..];
                break;
            }
        }
    }
    result.push_str(rest);
    result
}
/// Top-level pull configuration, as deserialized by [`load`] from a
/// `.tsafe.yml` or `.tsafe.json` file (see [`find_config`]).
#[derive(Debug, Deserialize)]
pub struct PullConfig {
    /// The configured pull sources, in the order they appear in the file.
    pub pulls: Vec<PullSource>,
}
/// One entry of [`PullConfig::pulls`]: a single secret provider to pull from.
///
/// Internally tagged by the `source` key. The tag values are `akv`, `hcp`,
/// `op`, `aws`, `ssm`, `gcp`, `bw` and `kp` (the per-variant `rename`s, which
/// match [`PullSource::provider_type`]). Every variant carries optional `name`
/// and `ns` fields (read via [`PullSource::name`] / [`PullSource::ns`]) and an
/// `overwrite` flag. Fields marked plain `#[serde(default)]` may be omitted and
/// then fall back to `None` / `false`.
/// NOTE(review): how `ns`, `prefix`/`path` and `overwrite` are applied is
/// decided by the code that consumes this config, not in this module.
#[derive(Debug, Deserialize)]
#[serde(tag = "source")]
pub enum PullSource {
    /// Azure Key Vault (`source: akv`). `vault_url` is required,
    /// e.g. `https://myvault.vault.azure.net`.
    #[serde(rename = "akv")]
    AzureKeyVault {
        #[serde(default)]
        name: Option<String>,
        #[serde(default)]
        ns: Option<String>,
        vault_url: String,
        #[serde(default)]
        prefix: Option<String>,
        #[serde(default)]
        overwrite: bool,
    },
    /// HashiCorp Vault (`source: hcp`).
    ///
    /// `addr` defaults to `http://127.0.0.1:8200` and `mount` to `secret`
    /// when omitted; `auth` and `vault_namespace` default to `None`.
    #[serde(rename = "hcp")]
    HashiCorpVault {
        #[serde(default)]
        name: Option<String>,
        #[serde(default)]
        ns: Option<String>,
        #[serde(default = "default_hcp_addr")]
        addr: String,
        #[serde(default = "default_mount")]
        mount: String,
        #[serde(default)]
        prefix: Option<String>,
        #[serde(default)]
        overwrite: bool,
        /// Optional explicit auth method (`method: token` / `method: approle`).
        #[serde(default)]
        auth: Option<VaultAuthConfig>,
        #[serde(default)]
        vault_namespace: Option<String>,
    },
    /// 1Password (`source: op`). `item` is required; `op_vault` is optional.
    #[serde(rename = "op")]
    OnePassword {
        #[serde(default)]
        name: Option<String>,
        #[serde(default)]
        ns: Option<String>,
        item: String,
        #[serde(default)]
        op_vault: Option<String>,
        #[serde(default)]
        overwrite: bool,
    },
    /// AWS (`source: aws`); all fields optional.
    /// NOTE(review): presumably AWS Secrets Manager, given the separate `ssm`
    /// variant — confirm against the provider implementation.
    #[serde(rename = "aws")]
    Aws {
        #[serde(default)]
        name: Option<String>,
        #[serde(default)]
        ns: Option<String>,
        #[serde(default)]
        region: Option<String>,
        #[serde(default)]
        prefix: Option<String>,
        #[serde(default)]
        overwrite: bool,
    },
    /// AWS SSM Parameter Store (`source: ssm`); all fields optional.
    #[serde(rename = "ssm")]
    SsmParameterStore {
        #[serde(default)]
        name: Option<String>,
        #[serde(default)]
        ns: Option<String>,
        #[serde(default)]
        region: Option<String>,
        #[serde(default)]
        path: Option<String>,
        #[serde(default)]
        overwrite: bool,
    },
    /// Google Cloud (`source: gcp`); all fields optional.
    #[serde(rename = "gcp")]
    Gcp {
        #[serde(default)]
        name: Option<String>,
        #[serde(default)]
        ns: Option<String>,
        #[serde(default)]
        project: Option<String>,
        #[serde(default)]
        prefix: Option<String>,
        #[serde(default)]
        overwrite: bool,
    },
    /// Bitwarden (`source: bw`); all fields optional.
    /// NOTE(review): `password_env` presumably names an environment variable
    /// holding the password rather than the password itself — confirm.
    #[serde(rename = "bw")]
    Bitwarden {
        #[serde(default)]
        name: Option<String>,
        #[serde(default)]
        ns: Option<String>,
        #[serde(default)]
        api_url: Option<String>,
        #[serde(default)]
        identity_url: Option<String>,
        #[serde(default)]
        client_id: Option<String>,
        #[serde(default)]
        client_secret: Option<String>,
        #[serde(default)]
        folder: Option<String>,
        #[serde(default)]
        password_env: Option<String>,
        #[serde(default)]
        overwrite: bool,
    },
    /// KeePass (`source: kp`). `path` is required; everything else is optional
    /// (`recursive` is tri-state: `None` when omitted).
    #[serde(rename = "kp")]
    Keepass {
        #[serde(default)]
        name: Option<String>,
        path: String,
        #[serde(default)]
        password_env: Option<String>,
        #[serde(default)]
        keyfile_path: Option<String>,
        #[serde(default)]
        group: Option<String>,
        #[serde(default)]
        recursive: Option<bool>,
        #[serde(default)]
        ns: Option<String>,
        #[serde(default)]
        overwrite: bool,
    },
}
impl PullSource {
pub fn name(&self) -> Option<&str> {
match self {
PullSource::AzureKeyVault { name, .. }
| PullSource::HashiCorpVault { name, .. }
| PullSource::OnePassword { name, .. }
| PullSource::Aws { name, .. }
| PullSource::SsmParameterStore { name, .. }
| PullSource::Gcp { name, .. }
| PullSource::Bitwarden { name, .. } => name.as_deref(),
PullSource::Keepass { name, .. } => name.as_deref(),
}
}
pub fn ns(&self) -> Option<&str> {
match self {
PullSource::AzureKeyVault { ns, .. }
| PullSource::HashiCorpVault { ns, .. }
| PullSource::OnePassword { ns, .. }
| PullSource::Aws { ns, .. }
| PullSource::SsmParameterStore { ns, .. }
| PullSource::Gcp { ns, .. }
| PullSource::Bitwarden { ns, .. } => ns.as_deref(),
PullSource::Keepass { ns, .. } => ns.as_deref(),
}
}
pub fn provider_type(&self) -> &'static str {
match self {
PullSource::AzureKeyVault { .. } => "akv",
PullSource::HashiCorpVault { .. } => "hcp",
PullSource::OnePassword { .. } => "op",
PullSource::Aws { .. } => "aws",
PullSource::SsmParameterStore { .. } => "ssm",
PullSource::Gcp { .. } => "gcp",
PullSource::Bitwarden { .. } => "bw",
PullSource::Keepass { .. } => "kp",
}
}
}
/// Serde default for the `hcp` source's `addr` field: a Vault listening on
/// localhost port 8200 over plain HTTP.
fn default_hcp_addr() -> String {
    String::from("http://127.0.0.1:8200")
}
/// Serde default for the `hcp` source's `mount` field.
fn default_mount() -> String {
    String::from("secret")
}
/// Searches `start` and each of its ancestors for a pull-config file.
///
/// In every directory `.tsafe.yml` is checked before `.tsafe.json`, and the
/// first match in the nearest directory wins. Returns `None` when no ancestor
/// contains one.
///
/// Only regular files (or symlinks to them) count. A *directory* that happens
/// to be called `.tsafe.yml` / `.tsafe.json` is skipped, so the search keeps
/// walking upward instead of returning a path that [`load`] cannot read.
///
/// The walk is lexical (`Path::ancestors`): a relative `start` is not resolved
/// against the current directory, so the search ends at the top of that
/// relative path.
pub fn find_config(start: &Path) -> Option<PathBuf> {
    // Order matters: YAML takes precedence over JSON within one directory.
    const CANDIDATES: [&str; 2] = [".tsafe.yml", ".tsafe.json"];
    start.ancestors().find_map(|dir| {
        CANDIDATES
            .iter()
            .map(|file_name| dir.join(file_name))
            .find(|candidate| candidate.is_file())
    })
}
/// Reads and parses the pull config at `path`.
///
/// A file whose extension is exactly `json` is parsed as JSON. Any other
/// extension, or none, is parsed as YAML.
///
/// # Errors
///
/// Propagates the I/O error from reading the file. A parse failure becomes
/// `SafeError::InvalidVault` with a reason naming the format that failed.
pub fn load(path: &Path) -> SafeResult<PullConfig> {
    let content = std::fs::read_to_string(path)?;
    match path.extension().and_then(|ext| ext.to_str()) {
        Some("json") => serde_json::from_str(&content).map_err(|e| SafeError::InvalidVault {
            reason: format!("invalid pull config JSON: {e}"),
        }),
        _ => serde_yaml::from_str(&content).map_err(|e| SafeError::InvalidVault {
            reason: format!("invalid pull config YAML: {e}"),
        }),
    }
}
// Unit tests for pull-config parsing, config discovery and `${VAR}` expansion.
// They rely on the `tempfile` (temporary directories) and `temp_env`
// (scoped environment variables) crates.
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    // Parses a YAML file listing five providers and checks the `akv` entry's fields.
    #[test]
    fn parse_yaml_config() {
        let yaml = r#"
pulls:
- source: akv
vault_url: https://myvault.vault.azure.net
prefix: MYAPP_
overwrite: true
- source: hcp
addr: http://vault:8200
mount: secret
prefix: myapp/
- source: op
item: Database Credentials
op_vault: Infrastructure
- source: aws
region: us-east-1
prefix: myapp/
- source: gcp
project: my-gcp-project
prefix: myapp-
"#;
        let cfg: PullConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(cfg.pulls.len(), 5);
        match &cfg.pulls[0] {
            PullSource::AzureKeyVault {
                vault_url,
                prefix,
                overwrite,
                ..
            } => {
                assert_eq!(vault_url, "https://myvault.vault.azure.net");
                assert_eq!(prefix.as_deref(), Some("MYAPP_"));
                assert!(overwrite);
            }
            other => panic!("expected AzureKeyVault, got {other:?}"),
        }
    }

    // A minimal JSON config with a single `op` entry deserializes.
    #[test]
    fn parse_json_config() {
        let json = r#"{"pulls": [{"source": "op", "item": "Test"}]}"#;
        let cfg: PullConfig = serde_json::from_str(json).unwrap();
        assert_eq!(cfg.pulls.len(), 1);
    }

    // `find_config` finds a `.tsafe.yml` three directories above the start dir.
    #[test]
    fn find_config_walks_up() {
        let dir = tempdir().unwrap();
        let child = dir.path().join("a/b/c");
        std::fs::create_dir_all(&child).unwrap();
        let cfg_path = dir.path().join(".tsafe.yml");
        std::fs::write(&cfg_path, "pulls: []").unwrap();
        let found = find_config(&child).unwrap();
        assert_eq!(found, cfg_path);
    }

    // An empty temp dir yields no config.
    // NOTE(review): this assumes no ancestor of the system temp dir contains a
    // `.tsafe.yml` / `.tsafe.json`, since the search walks all the way up.
    #[test]
    fn find_config_returns_none() {
        let dir = tempdir().unwrap();
        assert!(find_config(dir.path()).is_none());
    }

    // `name`/`ns` are read when present and are `None` when absent;
    // `provider_type` reports the `source` tag.
    #[test]
    fn parse_name_and_ns_fields() {
        let yaml = r#"
pulls:
- source: akv
name: prod-akv
ns: prod
vault_url: https://prod.vault.azure.net
- source: aws
name: staging-aws
ns: staging
region: us-east-1
- source: gcp
project: my-project
"#;
        let cfg: PullConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(cfg.pulls.len(), 3);
        assert_eq!(cfg.pulls[0].name(), Some("prod-akv"));
        assert_eq!(cfg.pulls[0].ns(), Some("prod"));
        assert_eq!(cfg.pulls[0].provider_type(), "akv");
        assert_eq!(cfg.pulls[1].name(), Some("staging-aws"));
        assert_eq!(cfg.pulls[1].ns(), Some("staging"));
        assert_eq!(cfg.pulls[1].provider_type(), "aws");
        assert_eq!(cfg.pulls[2].name(), None);
        assert_eq!(cfg.pulls[2].ns(), None);
        assert_eq!(cfg.pulls[2].provider_type(), "gcp");
    }

    // Every listed provider kind defaults `name` and `ns` to `None` when omitted.
    #[test]
    fn name_and_ns_default_to_none() {
        let yaml = r#"
pulls:
- source: akv
vault_url: https://myvault.vault.azure.net
- source: hcp
addr: http://vault:8200
mount: secret
- source: op
item: MyItem
- source: aws
region: us-east-1
- source: ssm
region: us-east-1
- source: gcp
project: my-project
"#;
        let cfg: PullConfig = serde_yaml::from_str(yaml).unwrap();
        for source in &cfg.pulls {
            assert_eq!(
                source.name(),
                None,
                "expected no name for {:?}",
                source.provider_type()
            );
            assert_eq!(
                source.ns(),
                None,
                "expected no ns for {:?}",
                source.provider_type()
            );
        }
    }

    // `auth: { method: token, token: ... }` deserializes into `VaultAuthConfig::Token`.
    #[test]
    fn parse_hcp_token_auth_from_yaml() {
        let yaml = r#"
pulls:
- source: hcp
addr: https://vault.example.com:8200
auth:
method: token
token: hvs.my-static-token
"#;
        let cfg: PullConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(cfg.pulls.len(), 1);
        match &cfg.pulls[0] {
            PullSource::HashiCorpVault { auth, .. } => {
                assert!(
                    matches!(
                        auth,
                        Some(VaultAuthConfig::Token {
                            token: Some(t)
                        }) if t == "hvs.my-static-token"
                    ),
                    "expected Token auth with static token, got {auth:?}"
                );
            }
            other => panic!("expected HashiCorpVault, got {other:?}"),
        }
    }

    // `auth: { method: approle, ... }` deserializes into `VaultAuthConfig::Approle`.
    #[test]
    fn parse_hcp_approle_auth_from_yaml() {
        let yaml = r#"
pulls:
- source: hcp
addr: https://vault.example.com:8200
auth:
method: approle
role_id: my-role-123
secret_id: my-secret-456
"#;
        let cfg: PullConfig = serde_yaml::from_str(yaml).unwrap();
        match &cfg.pulls[0] {
            PullSource::HashiCorpVault { auth, .. } => {
                assert!(
                    matches!(
                        auth,
                        Some(VaultAuthConfig::Approle { role_id, secret_id })
                        if role_id == "my-role-123" && secret_id == "my-secret-456"
                    ),
                    "expected AppRole auth, got {auth:?}"
                );
            }
            other => panic!("expected HashiCorpVault, got {other:?}"),
        }
    }

    // `vault_namespace` is read from the `hcp` entry.
    #[test]
    fn parse_hcp_vault_namespace_from_yaml() {
        let yaml = r#"
pulls:
- source: hcp
addr: https://vault.example.com:8200
vault_namespace: team-alpha
"#;
        let cfg: PullConfig = serde_yaml::from_str(yaml).unwrap();
        match &cfg.pulls[0] {
            PullSource::HashiCorpVault {
                vault_namespace, ..
            } => {
                assert_eq!(vault_namespace.as_deref(), Some("team-alpha"));
            }
            other => panic!("expected HashiCorpVault, got {other:?}"),
        }
    }

    // Omitting `auth` and `vault_namespace` leaves both as `None`.
    #[test]
    fn parse_hcp_defaults_auth_and_namespace_to_none() {
        let yaml = r#"
pulls:
- source: hcp
addr: http://127.0.0.1:8200
"#;
        let cfg: PullConfig = serde_yaml::from_str(yaml).unwrap();
        match &cfg.pulls[0] {
            PullSource::HashiCorpVault {
                auth,
                vault_namespace,
                ..
            } => {
                assert!(auth.is_none(), "expected auth=None, got {auth:?}");
                assert!(
                    vault_namespace.is_none(),
                    "expected vault_namespace=None, got {vault_namespace:?}"
                );
            }
            other => panic!("expected HashiCorpVault, got {other:?}"),
        }
    }

    // A set variable is substituted for its `${NAME}` placeholder.
    #[test]
    fn expand_env_var_str_replaces_placeholder() {
        temp_env::with_var("TEST_SECRET_ID", Some("s-abc-123"), || {
            let result = expand_env_var_str("${TEST_SECRET_ID}");
            assert_eq!(result, "s-abc-123");
        });
    }

    // Strings without `${` are returned unchanged.
    #[test]
    fn expand_env_var_str_no_placeholder_passthrough() {
        let result = expand_env_var_str("plain-secret-id");
        assert_eq!(result, "plain-secret-id");
    }

    // A placeholder for an unset variable is left verbatim.
    #[test]
    fn expand_env_var_str_unknown_var_left_as_is() {
        temp_env::with_var("VAULT_UNKNOWN_9999", None::<&str>, || {
            let result = expand_env_var_str("${VAULT_UNKNOWN_9999}");
            assert_eq!(result, "${VAULT_UNKNOWN_9999}");
        });
    }

    // `expand_env_vars` expands AppRole fields and leaves literal ones intact.
    #[test]
    fn vault_auth_config_expand_env_vars_in_approle() {
        temp_env::with_var("MY_SECRET_ID", Some("expanded-sid"), || {
            let auth = VaultAuthConfig::Approle {
                role_id: "static-role".into(),
                secret_id: "${MY_SECRET_ID}".into(),
            };
            let expanded = auth.expand_env_vars();
            assert!(
                matches!(
                    expanded,
                    VaultAuthConfig::Approle { ref role_id, ref secret_id }
                    if role_id == "static-role" && secret_id == "expanded-sid"
                ),
                "expected expanded secret_id, got {expanded:?}"
            );
        });
    }

    // `expand_env_vars` returns a `Token` config unchanged.
    #[test]
    fn vault_auth_config_expand_env_vars_token_unchanged() {
        let auth = VaultAuthConfig::Token {
            token: Some("hvs.static".into()),
        };
        let expanded = auth.expand_env_vars();
        assert!(
            matches!(expanded, VaultAuthConfig::Token { token: Some(ref t) } if t == "hvs.static"),
            "expected token unchanged, got {expanded:?}"
        );
    }
}