use std::sync::Arc;
use reqwest::header::CONTENT_TYPE;
use serde::Deserialize;
use url::Url;
use crate::authn::factor::ZeroizedString;
/// `client_assertion_type` form value announcing a JWT-bearer client assertion
/// (RFC 7523, section 2.2).
const CLIENT_ASSERTION_TYPE_JWT_BEARER: &str =
    "urn:ietf:params:oauth:client-assertion-type:jwt-bearer";
/// Builds the Microsoft identity platform v2.0 token endpoint for `tenant`:
/// `https://login.microsoftonline.com/{tenant}/oauth2/v2.0/token`.
///
/// `tenant` is appended as a single, percent-encoded path segment, so characters such as
/// `/`, `?`, `#` or `%` in a caller-supplied tenant cannot change the endpoint's path,
/// query or fragment. Ordinary tenant IDs and domain names produce the same URL as plain
/// string interpolation would.
///
/// # Panics
///
/// Only on a broken invariant: the base is a constant `https` URL, which always parses and
/// can always be a base.
fn token_endpoint_for_tenant(tenant: &str) -> Url {
    let mut endpoint = Url::parse("https://login.microsoftonline.com/")
        .expect("Azure AD authority URL is well-formed");
    endpoint
        .path_segments_mut()
        .expect("https URLs can be a base")
        .extend([tenant, "oauth2", "v2.0", "token"]);
    endpoint
}
/// Parameters for a single federated-credential token exchange.
///
/// `federated_token` is a bearer credential, so the `Debug` output redacts it. Only the
/// scopes are shown.
#[derive(Clone)]
pub struct AzureFicRequest {
    /// Federated token sent to Azure AD as the `client_assertion` form field.
    pub federated_token: String,
    /// Requested scopes; space-joined into the `scope` form field.
    pub scopes: Vec<String>,
}

impl std::fmt::Debug for AzureFicRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AzureFicRequest")
            // Never print the credential itself; it would end up in logs.
            .field("federated_token", &format_args!("<redacted>"))
            .field("scopes", &self.scopes)
            .finish()
    }
}

impl AzureFicRequest {
    /// Creates a request carrying `federated_token` and no scopes.
    pub fn new(federated_token: impl Into<String>) -> Self {
        Self {
            federated_token: federated_token.into(),
            scopes: Vec::new(),
        }
    }

    /// Appends one scope to the current list.
    pub fn with_scope(mut self, scope: impl Into<String>) -> Self {
        self.scopes.push(scope.into());
        self
    }

    /// Replaces the whole scope list with `scopes`; previously added scopes are dropped.
    pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.scopes = scopes.into_iter().map(Into::into).collect();
        self
    }
}
/// Successful result of [`AzureFicClient::acquire_token`].
// NOTE(review): the derived `Debug` prints `access_token` through `ZeroizedString`'s own
// `Debug` impl — confirm that impl does not reveal the secret.
#[derive(Debug)]
pub struct AzureFicResponse {
/// Access token returned by Azure AD, held in the crate's `ZeroizedString` wrapper
/// (presumably wiped on drop — confirm in `crate::authn::factor`). Never empty.
pub access_token: ZeroizedString,
/// `token_type` from the response body, or `"Bearer"` when the field is absent.
pub token_type: String,
/// `expires_in` from the response body, if present (OAuth 2.0 defines it as the token
/// lifetime in seconds).
pub expires_in: Option<u64>,
}
/// Errors returned by [`AzureFicClient::acquire_token`].
#[derive(Debug, thiserror::Error)]
pub enum AzureFicError {
/// Sending the request or reading the response body failed; holds the `reqwest`
/// error's display text.
#[error("Azure AD transport error: {0}")]
Transport(String),
/// The token endpoint answered with a non-2xx HTTP status.
#[error("Azure AD error [{error}] (HTTP {http_status}): {error_description}")]
AzureError {
/// HTTP status code of the response.
http_status: u16,
/// `error` code from the JSON body, or `"unknown"` when the body was not the expected JSON.
error: String,
/// `error_description` from the body (empty if absent), or text quoting the raw body
/// when it was not the expected JSON.
error_description: String,
/// `error_codes` array from the body (presumably Azure AD's numeric AADSTS codes);
/// empty if absent or if the body was not JSON.
error_codes: Vec<u64>,
/// `correlation_id` from the body, if present.
correlation_id: Option<String>,
},
/// A 2xx response whose body could not be used: it did not deserialize into the
/// expected success shape, or its `access_token` was empty.
#[error("malformed Azure AD response: {0}")]
MalformedResponse(String),
}
#[derive(Clone)]
pub struct AzureFicClient {
token_endpoint: Arc<Url>,
client_id: String,
http: reqwest::Client,
}
impl AzureFicClient {
pub fn new(tenant: impl AsRef<str>, client_id: impl Into<String>) -> Self {
Self {
token_endpoint: Arc::new(token_endpoint_for_tenant(tenant.as_ref())),
client_id: client_id.into(),
http: reqwest::Client::new(),
}
}
pub fn with_token_endpoint(mut self, endpoint: Url) -> Self {
self.token_endpoint = Arc::new(endpoint);
self
}
pub fn with_http_client(mut self, http: reqwest::Client) -> Self {
self.http = http;
self
}
pub fn token_endpoint(&self) -> &Url {
&self.token_endpoint
}
pub fn client_id(&self) -> &str {
&self.client_id
}
pub async fn acquire_token(
&self,
request: &AzureFicRequest,
) -> Result<AzureFicResponse, AzureFicError> {
let scope = request.scopes.join(" ");
let form: Vec<(&str, &str)> = vec![
("grant_type", "client_credentials"),
("client_id", &self.client_id),
("scope", scope.as_str()),
("client_assertion_type", CLIENT_ASSERTION_TYPE_JWT_BEARER),
("client_assertion", &request.federated_token),
];
let response = self
.http
.post((*self.token_endpoint).clone())
.header(CONTENT_TYPE, "application/x-www-form-urlencoded")
.form(&form)
.send()
.await
.map_err(|e| AzureFicError::Transport(e.to_string()))?;
let status = response.status();
let body = response
.text()
.await
.map_err(|e| AzureFicError::Transport(e.to_string()))?;
if !status.is_success() {
return Err(parse_azure_error(status.as_u16(), &body));
}
let parsed: AzureSuccessBody = serde_json::from_str(&body)
.map_err(|e| AzureFicError::MalformedResponse(format!("success JSON: {e}")))?;
if parsed.access_token.is_empty() {
return Err(AzureFicError::MalformedResponse(
"access_token field is empty".to_string(),
));
}
Ok(AzureFicResponse {
access_token: ZeroizedString::from(parsed.access_token),
token_type: parsed.token_type.unwrap_or_else(|| "Bearer".to_string()),
expires_in: parsed.expires_in,
})
}
}
/// Success body from the token endpoint. Only the fields this module reads are declared;
/// serde ignores any others because there is no `deny_unknown_fields`.
///
/// `Debug` is implemented by hand so the raw `access_token` is never printed.
#[derive(Deserialize)]
struct AzureSuccessBody {
    access_token: String,
    #[serde(default)]
    token_type: Option<String>,
    #[serde(default)]
    expires_in: Option<u64>,
}

impl std::fmt::Debug for AzureSuccessBody {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AzureSuccessBody")
            // The access token is a secret; keep it out of any debug/log output.
            .field("access_token", &format_args!("<redacted>"))
            .field("token_type", &self.token_type)
            .field("expires_in", &self.expires_in)
            .finish()
    }
}
/// Error body from the token endpoint on a non-2xx response. Only the fields this module
/// reads are declared. `error` has no default, so a body without it fails to deserialize
/// and `parse_azure_error` falls back to quoting the raw text.
#[derive(Debug, Deserialize)]
struct AzureErrorBody {
/// OAuth error code.
error: String,
/// Human-readable description, if provided.
#[serde(default)]
error_description: Option<String>,
/// Numeric error codes, if provided (presumably AADSTS codes — confirm).
#[serde(default)]
error_codes: Vec<u64>,
/// Correlation ID, if provided.
#[serde(default)]
correlation_id: Option<String>,
}
/// Converts a non-2xx token-endpoint response into [`AzureFicError::AzureError`].
///
/// If `body` deserializes as [`AzureErrorBody`], its fields are copied across, and a missing
/// `error_description` becomes an empty string. Otherwise `error` is `"unknown"` and the raw
/// body is quoted in `error_description`. The quote is capped at a fixed byte budget (cut on
/// a UTF-8 boundary) so that a large non-JSON response, e.g. an HTML page from an
/// intermediary, cannot make the error message and any logs it reaches arbitrarily large.
/// Bodies within the budget are quoted in full, exactly as before.
fn parse_azure_error(http_status: u16, body: &str) -> AzureFicError {
    // Upper bound, in bytes, on how much of a non-JSON body is quoted.
    const MAX_QUOTED_BODY_BYTES: usize = 1024;

    match serde_json::from_str::<AzureErrorBody>(body) {
        Ok(parsed) => AzureFicError::AzureError {
            http_status,
            error: parsed.error,
            error_description: parsed.error_description.unwrap_or_default(),
            error_codes: parsed.error_codes,
            correlation_id: parsed.correlation_id,
        },
        Err(_) => {
            let quoted = truncate_to_char_boundary(body, MAX_QUOTED_BODY_BYTES);
            let error_description = if quoted.len() == body.len() {
                format!("non-JSON error body: {body}")
            } else {
                format!(
                    "non-JSON error body: {quoted}... (truncated, {} bytes total)",
                    body.len()
                )
            };
            AzureFicError::AzureError {
                http_status,
                error: "unknown".to_string(),
                error_description,
                error_codes: Vec::new(),
                correlation_id: None,
            }
        }
    }
}

/// Returns the longest prefix of `s` that is at most `max_bytes` long and ends on a UTF-8
/// character boundary (so slicing never panics on multi-byte characters).
fn truncate_to_char_boundary(s: &str, max_bytes: usize) -> &str {
    if s.len() <= max_bytes {
        return s;
    }
    let mut end = max_bytes;
    // Index 0 is always a char boundary, so this loop terminates.
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    &s[..end]
}
// Unit tests are compiled only under `cfg(test)` and live in the out-of-line `tests`
// module file declared here.
#[cfg(test)]
mod tests;