tsafe-azure 1.0.3

Azure Key Vault integration for tsafe — optional secret pull.
Documentation
use crate::error::KvError;

/// Percent-encode `s` for use as an `application/x-www-form-urlencoded` value
/// (and likewise as a URL query or path component).
///
/// Bytes in the RFC 3986 "unreserved" set — ASCII letters, ASCII digits, and
/// `-` `_` `.` `~` — are copied through unchanged. Every other byte, including
/// each byte of a multi-byte UTF-8 sequence, becomes an uppercase `%XX` escape.
fn percent_encode(s: &str) -> String {
    use std::fmt::Write;

    let mut encoded = String::with_capacity(s.len() * 2);
    for byte in s.bytes() {
        let unreserved = matches!(
            byte,
            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~'
        );
        if unreserved {
            encoded.push(char::from(byte));
        } else {
            // Formatting into a `String` is infallible.
            let _ = write!(encoded, "%{byte:02X}");
        }
    }
    encoded
}

/// The Key Vault token resource/scope base. Defaults to the public cloud
/// (`https://vault.azure.net`).
///
/// Override with `TSAFE_AKV_RESOURCE` for sovereign clouds:
///   - Azure Government: `https://vault.usgovcloudapi.net`
///   - Azure China:      `https://vault.azure.cn`
///
/// Surrounding whitespace in the override is trimmed. An empty or
/// whitespace-only value is treated as unset; otherwise a blank variable
/// would produce an empty resource and a scope of just `/.default`.
/// A value that is not valid Unicode also falls back to the default.
fn vault_resource() -> String {
    const DEFAULT_RESOURCE: &str = "https://vault.azure.net";

    std::env::var("TSAFE_AKV_RESOURCE")
        .ok()
        .map(|value| value.trim().to_owned())
        .filter(|value| !value.is_empty())
        .unwrap_or_else(|| DEFAULT_RESOURCE.to_owned())
}

/// Acquire a bearer token for the Key Vault resource.
///
/// Credential selection (first match wins):
/// 1. Service principal (client-credentials flow) — used only when
///    `AZURE_TENANT_ID`, `AZURE_CLIENT_ID`, and `AZURE_CLIENT_SECRET` are all
///    set to valid Unicode values.
/// 2. Managed identity via IMDS — the fallback in every other case, including
///    when only some of the service-principal variables are set. Works
///    automatically inside Azure VMs / ACI.
///
/// # Errors
/// Returns whatever [`KvError`] the selected credential flow produces.
pub fn acquire_token() -> Result<String, KvError> {
    let env = |name: &str| std::env::var(name).ok();

    if let (Some(tenant), Some(client_id), Some(client_secret)) = (
        env("AZURE_TENANT_ID"),
        env("AZURE_CLIENT_ID"),
        env("AZURE_CLIENT_SECRET"),
    ) {
        return sp_token(&tenant, &client_id, &client_secret);
    }

    imds_token()
}

/// Client-credentials flow for a service principal.
fn sp_token(tenant: &str, client_id: &str, client_secret: &str) -> Result<String, KvError> {
    let url = format!("https://login.microsoftonline.com/{tenant}/oauth2/v2.0/token");
    let scope = format!("{}/.default", vault_resource());
    let body = format!(
        "grant_type=client_credentials&client_id={}&client_secret={}\
         &scope={}",
        percent_encode(client_id),
        percent_encode(client_secret),
        percent_encode(&scope)
    );

    let agent = ureq::AgentBuilder::new()
        .timeout_connect(std::time::Duration::from_secs(10))
        .timeout(std::time::Duration::from_secs(30))
        .build();
    let resp: serde_json::Value = agent
        .post(&url)
        .set("Content-Type", "application/x-www-form-urlencoded")
        .send_string(&body)
        .map_err(|e| KvError::Transport(e.to_string()))?
        .into_json()
        .map_err(|e| KvError::Transport(e.to_string()))?;

    extract_token(&resp)
}

/// Managed identity via the Instance Metadata Service (IMDS).
/// Only available inside Azure-hosted compute.
///
/// The `resource` parameter defaults to `https://vault.azure.net` (public cloud).
/// Set `TSAFE_AKV_RESOURCE` for sovereign clouds:
///   - Azure Government: `https://vault.usgovcloudapi.net`
///   - Azure China:      `https://vault.azure.cn`
fn imds_token() -> Result<String, KvError> {
    let resource = vault_resource();
    let url = format!(
        "http://169.254.169.254/metadata/identity/oauth2/token\
         ?api-version=2018-02-01&resource={}",
        percent_encode(&resource)
    );
    let agent = ureq::AgentBuilder::new()
        .timeout_connect(std::time::Duration::from_secs(10))
        .timeout(std::time::Duration::from_secs(30))
        .build();
    let resp: serde_json::Value = agent
        .get(&url)
        .set("Metadata", "true")
        .call()
        .map_err(|e| {
            KvError::Auth(format!(
                "IMDS unreachable and no service principal vars set \
             (AZURE_TENANT_ID / AZURE_CLIENT_ID / AZURE_CLIENT_SECRET): {e}"
            ))
        })?
        .into_json()
        .map_err(|e| KvError::Transport(e.to_string()))?;

    extract_token(&resp)
}

/// Pull the bearer token out of an OAuth2 token-endpoint JSON response.
///
/// # Errors
/// Returns [`KvError::Auth`] if `access_token` is absent, is not a JSON
/// string, or is an empty string. An empty token would only fail later, with
/// a less helpful 401 from Key Vault. A response that is not a JSON object is
/// treated as missing the field.
fn extract_token(resp: &serde_json::Value) -> Result<String, KvError> {
    match resp.get("access_token").and_then(serde_json::Value::as_str) {
        Some(token) if !token.is_empty() => Ok(token.to_owned()),
        Some(_) => Err(KvError::Auth(
            "token response has an empty 'access_token' field".into(),
        )),
        None => Err(KvError::Auth(
            "token response missing 'access_token' field".into(),
        )),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn percent_encode_leaves_unreserved_chars() {
        assert_eq!(percent_encode("abc-_~.XYZ123"), "abc-_~.XYZ123");
    }

    #[test]
    fn percent_encode_encodes_special_chars() {
        assert_eq!(percent_encode("hello world"), "hello%20world");
        assert_eq!(percent_encode("a=b&c=d"), "a%3Db%26c%3Dd");
        assert_eq!(percent_encode("p@ssw0rd!"), "p%40ssw0rd%21");
    }

    #[test]
    fn percent_encode_empty_input_is_empty() {
        assert_eq!(percent_encode(""), "");
    }

    #[test]
    fn percent_encode_escapes_each_utf8_byte() {
        // U+00E9 ("é") is the two UTF-8 bytes 0xC3 0xA9.
        assert_eq!(percent_encode("é"), "%C3%A9");
    }

    #[test]
    fn percent_encode_escapes_scope_url() {
        assert_eq!(
            percent_encode("https://vault.azure.net/.default"),
            "https%3A%2F%2Fvault.azure.net%2F.default"
        );
    }

    #[test]
    fn extract_token_success() {
        let resp = serde_json::json!({"access_token": "tok-abc", "token_type": "Bearer"});
        assert_eq!(extract_token(&resp).unwrap(), "tok-abc");
    }

    #[test]
    fn extract_token_missing_field_returns_auth_error() {
        let resp = serde_json::json!({"token_type": "Bearer"});
        let err = extract_token(&resp).unwrap_err();
        assert!(matches!(err, KvError::Auth(_)));
    }

    #[test]
    fn extract_token_rejects_non_string_access_token() {
        let resp = serde_json::json!({"access_token": 42, "token_type": "Bearer"});
        let err = extract_token(&resp).unwrap_err();
        assert!(matches!(err, KvError::Auth(_)));
    }

    #[test]
    fn extract_token_non_object_response_returns_auth_error() {
        let responses = [
            serde_json::json!(null),
            serde_json::json!(["tok-abc"]),
            serde_json::json!("tok-abc"),
        ];
        for resp in &responses {
            let err = extract_token(resp).unwrap_err();
            assert!(matches!(err, KvError::Auth(_)));
        }
    }
}