use async_trait::async_trait;
use serde::Deserialize;
use std::time::Duration;
use super::{SecretProvider, SecretRef, SecretValue};
use crate::error::{AppError, AppResult};
/// Identifier returned by `SecretProvider::provider` for this backend.
const PROVIDER: &str = "azure_oauth";
/// Default AAD authority host used to build token URLs; overridable via
/// `NOETL_AZURE_AAD_HOST`.
const DEFAULT_AAD_HOST: &str = "https://login.microsoftonline.com";
/// Scope requested when neither the secret reference nor
/// `NOETL_AZURE_OAUTH_SCOPE` supplies one.
const DEFAULT_SCOPE: &str = "https://graph.microsoft.com/.default";
/// Secret provider that obtains AAD access tokens with the OAuth 2.0
/// client-credentials grant (configured by `from_env`, used by `fetch`).
///
/// NOTE(review): this type has no `Debug` impl; keep it that way (or write a
/// redacting one) because `client_secret` is held in plain text.
pub struct AzureOAuthProvider {
    /// HTTP client used for token requests (built with a 15 s timeout by
    /// `from_env`).
    http: reqwest::Client,
    /// Authority base URL the tenant path is appended to, e.g.
    /// `https://login.microsoftonline.com`.
    aad_host: String,
    /// Tenant used when a secret reference does not name one.
    tenant_id: String,
    /// Application (client) ID sent as the `client_id` form field.
    client_id: String,
    /// Client secret sent as the `client_secret` form field; never log it.
    client_secret: String,
    /// Scope used when a secret reference does not name one.
    default_scope: String,
}
/// Subset of the AAD token endpoint's JSON success response that this provider
/// reads. Other fields (e.g. `ext_expires_in`) are ignored during
/// deserialization.
///
/// NOTE(review): intentionally left without `Debug` so the access token cannot
/// be printed via `{:?}`; confirm before adding derives.
#[derive(Deserialize)]
struct TokenResponse {
    /// Token handed back to callers as the secret value.
    access_token: String,
    /// Token lifetime in seconds, converted into an absolute expiry by the
    /// caller.
    expires_in: u64,
    /// Token type reported by AAD (e.g. `Bearer`); may be absent.
    // Parsed for completeness but never read anywhere, hence the
    // `dead_code` allowance.
    #[serde(default)]
    #[allow(dead_code)]
    token_type: Option<String>,
}
/// Result of `parse_ref`: an optional tenant override and an optional scope.
/// `None` in either field means "use the provider's configured default".
#[derive(Debug, Clone, PartialEq, Eq)]
struct ParsedRef {
    /// Tenant (GUID or domain) to request the token from, if overridden.
    tenant: Option<String>,
    /// Scope to request, if specified by the reference.
    scope: Option<String>,
}
/// Parses a secret reference name into an optional tenant override and an
/// optional scope.
///
/// Accepted shapes (surrounding whitespace is ignored):
/// - `""`: no tenant, no scope; the provider defaults apply.
/// - `"<scope>"`: scope only, e.g. `https://graph.microsoft.com/.default`
///   or `api://<app-id>/.default`.
/// - `"<tenant>:<scope>"`: tenant override plus scope, e.g.
///   `contoso.onmicrosoft.com:https://vault.azure.net/.default`.
///
/// A `tenant:` prefix is recognised only when the text before the first `:`
/// is non-empty, contains no `/`, and is not `http`/`https`, and the text
/// after it is non-empty and does not start with `//`. The last rule keeps
/// URI scopes with any scheme (`api://...`, `https://...`) from being split
/// into a bogus tenant and a broken scope.
fn parse_ref(raw: &str) -> ParsedRef {
    let raw = raw.trim();
    if raw.is_empty() {
        return ParsedRef {
            tenant: None,
            scope: None,
        };
    }
    if let Some((maybe_tenant, scope)) = raw.split_once(':') {
        let looks_like_tenant = !maybe_tenant.is_empty()
            && !scope.is_empty()
            && !maybe_tenant.contains('/')
            && !maybe_tenant.eq_ignore_ascii_case("https")
            && !maybe_tenant.eq_ignore_ascii_case("http")
            // `scheme://...`: this colon belongs to a URI, not a tenant prefix.
            && !scope.starts_with("//");
        if looks_like_tenant {
            return ParsedRef {
                tenant: Some(maybe_tenant.to_string()),
                scope: Some(scope.to_string()),
            };
        }
    }
    ParsedRef {
        tenant: None,
        scope: Some(raw.to_string()),
    }
}
impl AzureOAuthProvider {
    /// Builds a provider from the process environment.
    ///
    /// Required variables: `AZURE_TENANT_ID`, `AZURE_CLIENT_ID`,
    /// `AZURE_CLIENT_SECRET`.
    ///
    /// Optional variables:
    /// - `NOETL_AZURE_OAUTH_SCOPE`: scope used when a secret reference names
    ///   none (defaults to [`DEFAULT_SCOPE`]).
    /// - `NOETL_AZURE_AAD_HOST`: authority host, e.g. for sovereign clouds
    ///   (defaults to [`DEFAULT_AAD_HOST`]).
    ///
    /// A variable that is set but empty (or whitespace only) is treated as
    /// unset. This prevents building token URLs such as `host//oauth2/...`.
    ///
    /// # Errors
    ///
    /// Returns [`AppError::Config`] when a required variable is missing, blank,
    /// or not valid Unicode, or when the HTTP client cannot be built.
    pub fn from_env() -> AppResult<Self> {
        let tenant_id = Self::required_env("AZURE_TENANT_ID")?;
        let client_id = Self::required_env("AZURE_CLIENT_ID")?;
        let client_secret = Self::required_env("AZURE_CLIENT_SECRET")?;
        let default_scope = Self::optional_env("NOETL_AZURE_OAUTH_SCOPE", DEFAULT_SCOPE);
        let aad_host = Self::optional_env("NOETL_AZURE_AAD_HOST", DEFAULT_AAD_HOST);
        let http = reqwest::Client::builder()
            .timeout(Duration::from_secs(15))
            .build()
            .map_err(|e| AppError::Config(format!("azure_oauth: http client build failed: {e}")))?;
        Ok(Self {
            http,
            aad_host,
            tenant_id,
            client_id,
            client_secret,
            default_scope,
        })
    }

    /// Reads a required environment variable.
    ///
    /// Unset, blank, and non-Unicode values are all reported as "not set".
    fn required_env(name: &str) -> AppResult<String> {
        match std::env::var(name) {
            Ok(value) if !value.trim().is_empty() => Ok(value),
            _ => Err(AppError::Config(format!(
                "azure_oauth: {name} is not set (required for the \
                 `azure_oauth` secret provider)"
            ))),
        }
    }

    /// Reads an optional environment variable.
    ///
    /// Falls back to `default` when the variable is unset, blank, or not
    /// valid Unicode.
    fn optional_env(name: &str, default: &str) -> String {
        std::env::var(name)
            .ok()
            .filter(|value| !value.trim().is_empty())
            .unwrap_or_else(|| default.to_string())
    }

    /// Returns the OAuth 2.0 v2 token endpoint for `tenant` on the configured
    /// host.
    ///
    /// Trailing `/` characters on the host (e.g. `https://login.microsoftonline.us/`)
    /// are ignored so the URL never contains `//` before the tenant segment.
    fn token_url_for(&self, tenant: &str) -> String {
        let host = self.aad_host.trim_end_matches('/');
        format!("{host}/{tenant}/oauth2/v2.0/token")
    }

    /// Builds the `application/x-www-form-urlencoded` request body for the
    /// client-credentials grant, percent-encoding every caller-supplied value.
    fn build_body(client_id: &str, client_secret: &str, scope: &str) -> String {
        format!(
            "grant_type=client_credentials\
             &client_id={cid}\
             &client_secret={cs}\
             &scope={scope}",
            cid = percent_encode(client_id),
            cs = percent_encode(client_secret),
            scope = percent_encode(scope),
        )
    }

    /// Converts a token response's `expires_in` (seconds) into an absolute
    /// expiry relative to `now`.
    ///
    /// The value comes from the network, so it is clamped to a 100-year
    /// ceiling. `chrono::Duration::seconds` panics above roughly
    /// `i64::MAX / 1000`, and adding to a `DateTime` panics past the
    /// representable range; clamping avoids both.
    ///
    /// If the addition would still overflow (only possible for a `now` far in
    /// the future), `now` is returned. The token is then treated as already
    /// expired rather than cached indefinitely.
    fn compute_expires_at(
        expires_in_secs: u64,
        now: chrono::DateTime<chrono::Utc>,
    ) -> chrono::DateTime<chrono::Utc> {
        const MAX_LIFETIME_SECS: i64 = 100 * 365 * 24 * 60 * 60;
        let secs = i64::try_from(expires_in_secs)
            .unwrap_or(i64::MAX)
            .min(MAX_LIFETIME_SECS);
        now.checked_add_signed(chrono::Duration::seconds(secs))
            .unwrap_or(now)
    }
}
#[async_trait]
impl SecretProvider for AzureOAuthProvider {
    /// Registry identifier for this backend; always [`PROVIDER`].
    fn provider(&self) -> &'static str {
        PROVIDER
    }

    /// Requests an access token from AAD using the client-credentials grant.
    ///
    /// `secret.name` is interpreted by `parse_ref`. Any tenant or scope it does
    /// not specify falls back to the provider's configured `tenant_id` and
    /// `default_scope`.
    ///
    /// On success the token becomes `SecretValue::value`, and `expires_at` is
    /// derived from the response's `expires_in`.
    ///
    /// # Errors
    ///
    /// Returns `AppError::Internal` if:
    /// - the POST cannot be sent;
    /// - the body cannot be read;
    /// - AAD answers with a non-success status (the response body is included
    ///   in the message);
    /// - the success body is not a valid token response.
    async fn fetch(&self, secret: &SecretRef) -> AppResult<SecretValue> {
        let target = parse_ref(&secret.name);
        let tenant = target.tenant.as_deref().unwrap_or(&self.tenant_id);
        let scope = target.scope.as_deref().unwrap_or(&self.default_scope);

        let url = self.token_url_for(tenant);
        let form = Self::build_body(&self.client_id, &self.client_secret, scope);

        let request = self
            .http
            .post(&url)
            .header("content-type", "application/x-www-form-urlencoded")
            .header("accept", "application/json")
            .body(form);
        let response = request
            .send()
            .await
            .map_err(|e| AppError::Internal(format!("azure_oauth: POST {url} failed: {e}")))?;

        // The status must be captured before `text()` consumes the response.
        let status = response.status();
        let text = response
            .text()
            .await
            .map_err(|e| AppError::Internal(format!("azure_oauth: read body failed: {e}")))?;
        if !status.is_success() {
            return Err(AppError::Internal(format!(
                "azure_oauth: AAD returned HTTP {status}: {text}"
            )));
        }

        let token = serde_json::from_str::<TokenResponse>(&text).map_err(|e| {
            AppError::Internal(format!("azure_oauth: parse AAD response: {e}"))
        })?;
        let expires_at = Self::compute_expires_at(token.expires_in, chrono::Utc::now());
        Ok(SecretValue {
            value: token.access_token,
            version: None,
            expires_at: Some(expires_at),
        })
    }
}
/// Percent-encodes `s` for use as a value in an
/// `application/x-www-form-urlencoded` body.
///
/// Bytes in the RFC 3986 unreserved set (`A-Z a-z 0-9 - _ . ~`) are copied
/// through unchanged. Every other byte of the UTF-8 encoding becomes `%XX`
/// with uppercase hex digits, so a space becomes `%20` rather than `+`.
fn percent_encode(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";
    s.bytes()
        .fold(String::with_capacity(s.len()), |mut encoded, byte| {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
                encoded.push(char::from(byte));
            } else {
                encoded.push('%');
                encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
                encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
            }
            encoded
        })
}
/// Unit tests for reference parsing, URL/body construction, expiry arithmetic,
/// and token-response deserialization. None of them perform network I/O.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_ref_empty() {
        let p = parse_ref("");
        assert_eq!(
            p,
            ParsedRef {
                tenant: None,
                scope: None,
            }
        );
    }

    #[test]
    fn parse_ref_whitespace_only_is_empty() {
        assert_eq!(
            parse_ref("   \t "),
            ParsedRef {
                tenant: None,
                scope: None,
            }
        );
    }

    #[test]
    fn parse_ref_trims_surrounding_whitespace() {
        let p = parse_ref("  https://graph.microsoft.com/.default  ");
        assert_eq!(p.tenant, None);
        assert_eq!(p.scope.as_deref(), Some("https://graph.microsoft.com/.default"));
    }

    #[test]
    fn parse_ref_bare_scope_url() {
        let p = parse_ref("https://graph.microsoft.com/.default");
        assert_eq!(
            p,
            ParsedRef {
                tenant: None,
                scope: Some("https://graph.microsoft.com/.default".to_string()),
            }
        );
    }

    #[test]
    fn parse_ref_tenant_prefix_splits() {
        let p = parse_ref("contoso.onmicrosoft.com:https://vault.azure.net/.default");
        assert_eq!(
            p,
            ParsedRef {
                tenant: Some("contoso.onmicrosoft.com".to_string()),
                scope: Some("https://vault.azure.net/.default".to_string()),
            }
        );
    }

    #[test]
    fn parse_ref_guid_tenant_prefix() {
        let p = parse_ref("11111111-2222-3333-4444-555555555555:https://api.fabrikam.com/.default");
        assert_eq!(
            p.tenant.as_deref(),
            Some("11111111-2222-3333-4444-555555555555")
        );
        assert_eq!(
            p.scope.as_deref(),
            Some("https://api.fabrikam.com/.default")
        );
    }

    #[test]
    fn parse_ref_empty_half_keeps_whole_string_as_scope() {
        // Neither an empty tenant nor an empty scope forms a `tenant:scope` pair.
        let p = parse_ref(":scope");
        assert_eq!(p.tenant, None);
        assert_eq!(p.scope.as_deref(), Some(":scope"));
        let p = parse_ref("tenant:");
        assert_eq!(p.tenant, None);
        assert_eq!(p.scope.as_deref(), Some("tenant:"));
    }

    /// Provider with fixed credentials; the HTTP client is never used by these
    /// tests.
    fn test_provider() -> AzureOAuthProvider {
        AzureOAuthProvider {
            http: reqwest::Client::new(),
            aad_host: DEFAULT_AAD_HOST.to_string(),
            tenant_id: "default-tenant".to_string(),
            client_id: "ID".to_string(),
            client_secret: "SECRET".to_string(),
            default_scope: DEFAULT_SCOPE.to_string(),
        }
    }

    #[test]
    fn token_url_for_uses_tenant() {
        let p = test_provider();
        assert_eq!(
            p.token_url_for("contoso.onmicrosoft.com"),
            "https://login.microsoftonline.com/contoso.onmicrosoft.com/oauth2/v2.0/token"
        );
    }

    #[test]
    fn token_url_for_honours_sovereign_host() {
        let mut p = test_provider();
        p.aad_host = "https://login.microsoftonline.us".to_string();
        assert_eq!(
            p.token_url_for("11111111-2222-3333-4444-555555555555"),
            "https://login.microsoftonline.us/11111111-2222-3333-4444-555555555555/oauth2/v2.0/token"
        );
    }

    #[test]
    fn build_body_form_urlencoded_shape() {
        let body = AzureOAuthProvider::build_body(
            "client-app-id",
            "very/secret*value!",
            "https://graph.microsoft.com/.default",
        );
        assert!(body.contains("grant_type=client_credentials"));
        assert!(body.contains("client_id=client-app-id"));
        assert!(body.contains("client_secret=very%2Fsecret%2Avalue%21"));
        assert!(body.contains("scope=https%3A%2F%2Fgraph.microsoft.com%2F.default"));
    }

    #[test]
    fn build_body_exact_field_layout() {
        // Pins field order and `&` separators, not just field presence.
        let body = AzureOAuthProvider::build_body("ID", "SECRET", "s");
        assert_eq!(
            body,
            "grant_type=client_credentials&client_id=ID&client_secret=SECRET&scope=s"
        );
    }

    #[test]
    fn compute_expires_at_adds_seconds() {
        let now = chrono::DateTime::parse_from_rfc3339("2026-06-07T03:00:00Z")
            .unwrap()
            .with_timezone(&chrono::Utc);
        let at = AzureOAuthProvider::compute_expires_at(3600, now);
        assert_eq!(at - now, chrono::Duration::seconds(3600));
    }

    #[test]
    fn compute_expires_at_handles_zero() {
        let now = chrono::DateTime::parse_from_rfc3339("2026-06-07T03:00:00Z")
            .unwrap()
            .with_timezone(&chrono::Utc);
        let at = AzureOAuthProvider::compute_expires_at(0, now);
        assert_eq!(at, now);
    }

    #[test]
    fn response_parses_aad_token() {
        let body = r#"{
            "token_type": "Bearer",
            "expires_in": 3599,
            "ext_expires_in": 3599,
            "access_token": "eyJ0eXAiOi.AAD.token"
        }"#;
        let parsed: TokenResponse = serde_json::from_str(body).unwrap();
        assert_eq!(parsed.access_token, "eyJ0eXAiOi.AAD.token");
        assert_eq!(parsed.expires_in, 3599);
    }

    #[test]
    fn response_parses_minimal_shape_without_token_type() {
        let body = r#"{
            "expires_in": 3600,
            "access_token": "minimal-token"
        }"#;
        let parsed: TokenResponse = serde_json::from_str(body).unwrap();
        assert_eq!(parsed.access_token, "minimal-token");
        assert_eq!(parsed.expires_in, 3600);
    }

    #[test]
    fn response_parse_fails_on_missing_token() {
        let body = r#"{ "expires_in": 3600 }"#;
        let result: Result<TokenResponse, _> = serde_json::from_str(body);
        assert!(result.is_err());
    }

    #[test]
    fn percent_encode_preserves_unreserved() {
        assert_eq!(percent_encode("ABCabc123-_.~"), "ABCabc123-_.~");
    }

    #[test]
    fn percent_encode_escapes_specials() {
        assert_eq!(percent_encode("https://x/.default"), "https%3A%2F%2Fx%2F.default");
    }

    #[test]
    fn percent_encode_escapes_space_and_utf8_bytes() {
        assert_eq!(percent_encode(""), "");
        assert_eq!(percent_encode("a b+c"), "a%20b%2Bc");
        assert_eq!(percent_encode("é"), "%C3%A9");
    }
}