use crate::error::KvError;
/// Percent-encodes `s` so it can be embedded in a URL query string or form body.
///
/// ASCII letters, digits and the four characters `-`, `_`, `.`, `~` are copied through
/// unchanged. Every other byte of the UTF-8 input is written as `%XX` using two uppercase
/// hexadecimal digits, so multi-byte characters become one escape per byte.
fn percent_encode(s: &str) -> String {
    use std::fmt::Write;

    // Reserve twice the input length up front to limit reallocation when escapes are common.
    let mut encoded = String::with_capacity(s.len() * 2);
    for byte in s.bytes() {
        match byte {
            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => {
                encoded.push(char::from(byte));
            }
            _ => {
                write!(encoded, "%{byte:02X}").expect("writing to a String cannot fail");
            }
        }
    }
    encoded
}
/// Returns the resource URI that access tokens are requested for.
///
/// The value of the `TSAFE_AKV_RESOURCE` environment variable is used verbatim when it is
/// set and valid Unicode; otherwise the default `https://vault.azure.net` is returned.
fn vault_resource() -> String {
    match std::env::var("TSAFE_AKV_RESOURCE") {
        Ok(resource) => resource,
        Err(_) => String::from("https://vault.azure.net"),
    }
}
/// Obtains an access token, choosing the credential source from the environment.
///
/// If `AZURE_TENANT_ID`, `AZURE_CLIENT_ID` and `AZURE_CLIENT_SECRET` can all be read, the
/// token is requested through [`sp_token`] with those values. If any of the three is
/// unavailable, the request falls back to [`imds_token`].
///
/// # Errors
///
/// Propagates whatever error the selected token function returns.
pub fn acquire_token() -> Result<String, KvError> {
    let read = |name: &str| std::env::var(name).ok();
    let credentials = (
        read("AZURE_TENANT_ID"),
        read("AZURE_CLIENT_ID"),
        read("AZURE_CLIENT_SECRET"),
    );

    if let (Some(tenant), Some(client_id), Some(client_secret)) = credentials {
        sp_token(&tenant, &client_id, &client_secret)
    } else {
        imds_token()
    }
}
/// Requests a token with the client-credentials grant for the given service principal.
///
/// Posts a form-encoded body to the `login.microsoftonline.com` v2.0 token endpoint of
/// `tenant`, asking for the scope `<vault resource>/.default`. The client id, client secret
/// and scope are percent-encoded before being placed in the body; `tenant` is inserted into
/// the URL path as given.
///
/// # Errors
///
/// * [`KvError::Transport`] if the request fails or the response body cannot be read as JSON.
/// * [`KvError::Auth`] if the JSON response has no string `access_token` field.
fn sp_token(tenant: &str, client_id: &str, client_secret: &str) -> Result<String, KvError> {
    use std::time::Duration;

    /// Wraps any displayable failure into a transport error.
    fn transport<E: std::fmt::Display>(e: E) -> KvError {
        KvError::Transport(e.to_string())
    }

    let token_endpoint = format!("https://login.microsoftonline.com/{tenant}/oauth2/v2.0/token");
    let scope = format!("{}/.default", vault_resource());
    let form = format!(
        "grant_type=client_credentials&client_id={}&client_secret={}&scope={}",
        percent_encode(client_id),
        percent_encode(client_secret),
        percent_encode(&scope)
    );

    // 10 s to establish the connection, 30 s for the request as a whole.
    let http = ureq::AgentBuilder::new()
        .timeout_connect(Duration::from_secs(10))
        .timeout(Duration::from_secs(30))
        .build();

    let response = http
        .post(&token_endpoint)
        .set("Content-Type", "application/x-www-form-urlencoded")
        .send_string(&form)
        .map_err(transport)?;
    let payload: serde_json::Value = response.into_json().map_err(transport)?;

    extract_token(&payload)
}
/// Requests a token from the instance metadata endpoint at `169.254.169.254`.
///
/// Sends a GET to the `/metadata/identity/oauth2/token` path (API version `2018-02-01`)
/// with the percent-encoded vault resource as the `resource` query parameter and the
/// `Metadata: true` header.
///
/// # Errors
///
/// * [`KvError::Auth`] if the HTTP call itself returns an error; the message names the
///   service principal variables that could have been set instead.
/// * [`KvError::Transport`] if the response body cannot be read as JSON.
/// * [`KvError::Auth`] if the JSON response has no string `access_token` field.
fn imds_token() -> Result<String, KvError> {
    use std::time::Duration;

    let encoded_resource = percent_encode(&vault_resource());
    let endpoint = format!(
        "http://169.254.169.254/metadata/identity/oauth2/token\
         ?api-version=2018-02-01&resource={encoded_resource}"
    );

    // 10 s to establish the connection, 30 s for the request as a whole.
    let http = ureq::AgentBuilder::new()
        .timeout_connect(Duration::from_secs(10))
        .timeout(Duration::from_secs(30))
        .build();

    let response = match http.get(&endpoint).set("Metadata", "true").call() {
        Ok(response) => response,
        Err(e) => {
            return Err(KvError::Auth(format!(
                "IMDS unreachable and no service principal vars set \
                 (AZURE_TENANT_ID / AZURE_CLIENT_ID / AZURE_CLIENT_SECRET): {e}"
            )));
        }
    };
    let payload: serde_json::Value = response
        .into_json()
        .map_err(|e| KvError::Transport(e.to_string()))?;

    extract_token(&payload)
}
/// Pulls the `access_token` string out of a token endpoint's JSON response.
///
/// # Errors
///
/// Returns [`KvError::Auth`] when `access_token` is absent or is not a JSON string.
fn extract_token(resp: &serde_json::Value) -> Result<String, KvError> {
    match resp["access_token"].as_str() {
        Some(token) => Ok(token.to_owned()),
        None => Err(KvError::Auth("token response missing 'access_token' field".into())),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Letters, digits and `-`, `_`, `~`, `.` must come back unchanged.
    #[test]
    fn percent_encode_leaves_unreserved_chars() {
        let unreserved = "abc-_~.XYZ123";
        assert_eq!(percent_encode(unreserved), unreserved);
    }

    /// Spaces, form delimiters and punctuation must be escaped as uppercase `%XX`.
    #[test]
    fn percent_encode_encodes_special_chars() {
        let cases = [
            ("hello world", "hello%20world"),
            ("a=b&c=d", "a%3Db%26c%3Dd"),
            ("p@ssw0rd!", "p%40ssw0rd%21"),
        ];
        for &(input, expected) in &cases {
            assert_eq!(percent_encode(input), expected);
        }
    }

    /// A string `access_token` is returned as-is.
    #[test]
    fn extract_token_success() {
        let body = serde_json::json!({"access_token": "tok-abc", "token_type": "Bearer"});
        let token = extract_token(&body).expect("response contains a string access_token");
        assert_eq!(token, "tok-abc");
    }

    /// A response without `access_token` is reported as an auth error.
    #[test]
    fn extract_token_missing_field_returns_auth_error() {
        let body = serde_json::json!({"token_type": "Bearer"});
        assert!(matches!(extract_token(&body), Err(KvError::Auth(_))));
    }

    /// A non-string `access_token` is treated the same as a missing one.
    #[test]
    fn extract_token_rejects_non_string_access_token() {
        let body = serde_json::json!({"access_token": 42, "token_type": "Bearer"});
        assert!(matches!(extract_token(&body), Err(KvError::Auth(_))));
    }
}