use crate::auth::service_principal::parse_token_response;
use crate::auth::token::{AccessToken, CachedToken};
use crate::error::AzureError;
/// Azure Instance Metadata Service (IMDS) token endpoint.
///
/// The string deliberately ends with `resource=` so that the URL-encoded resource can be
/// appended directly (see `ManagedIdentityCredential::fetch_token`).
const IMDS_BASE: &str = concat!(
    "http://169.254.169.254",
    "/metadata/identity/oauth2/token",
    "?api-version=2018-02-01",
    "&resource="
);
/// Timeout, in seconds, configured on the IMDS HTTP client in `ManagedIdentityCredential::new`.
const IMDS_TIMEOUT_SECS: u64 = 3;
/// Credential that obtains OAuth2 access tokens from the Instance Metadata Service (IMDS)
/// token endpoint in `IMDS_BASE`.
pub struct ManagedIdentityCredential {
    /// Optional `client_id` query parameter sent to IMDS; `None` omits it.
    /// NOTE(review): presumably used to select a user-assigned identity — confirm.
    client_id: Option<String>,
    /// HTTP client used for every IMDS request.
    http: reqwest::Client,
    /// Token cache consulted and filled by `get_token` (not by `get_token_for_scope`).
    cache: CachedToken,
}
impl ManagedIdentityCredential {
    /// Creates a credential that sends no `client_id` to IMDS, leaving identity selection
    /// to IMDS itself.
    ///
    /// The HTTP client is built with an `IMDS_TIMEOUT_SECS`-second timeout. If the builder
    /// fails, this falls back to `reqwest::Client::default()`, which is *not* configured with
    /// that timeout.
    pub fn new() -> Self {
        Self {
            client_id: None,
            http: reqwest::Client::builder()
                .timeout(std::time::Duration::from_secs(IMDS_TIMEOUT_SECS))
                .build()
                .unwrap_or_default(),
            cache: CachedToken::new(),
        }
    }

    /// Creates a credential that sends `client_id` as the `client_id` query parameter on
    /// every IMDS request. All other settings match [`ManagedIdentityCredential::new`].
    pub fn with_client_id(client_id: impl Into<String>) -> Self {
        Self {
            client_id: Some(client_id.into()),
            ..Self::new()
        }
    }

    /// Returns a token for the resource `https://management.azure.com/`.
    ///
    /// If the cache yields a token, that token is returned. Otherwise a fresh token is
    /// fetched from IMDS, stored in the cache, and returned.
    ///
    /// # Errors
    /// Returns `AzureError::Auth` in three cases: the request fails, IMDS answers with a
    /// non-200 status, or the response body cannot be read. Any error from
    /// `parse_token_response` is also propagated.
    pub async fn get_token(&self) -> Result<AccessToken, AzureError> {
        if let Some(cached) = self.cache.get().await {
            return Ok(cached);
        }
        let token = self.fetch_token("https://management.azure.com/").await?;
        self.cache.set(token.clone()).await;
        Ok(token)
    }

    /// Returns a token for an OAuth2 v2-style `scope` such as
    /// `https://vault.azure.net/.default`.
    ///
    /// IMDS (api-version 2018-02-01) takes a v1 `resource`. The scope is therefore converted
    /// by removing one trailing `/.default` and ensuring the result ends with `/`. This path
    /// always contacts IMDS and does not use the token cache.
    ///
    /// # Errors
    /// Same as [`ManagedIdentityCredential::get_token`].
    pub(crate) async fn get_token_for_scope(&self, scope: &str) -> Result<AccessToken, AzureError> {
        // `strip_suffix` removes exactly one "/.default". `trim_end_matches` would strip
        // repeated occurrences.
        let resource = scope.strip_suffix("/.default").unwrap_or(scope);
        let resource = if resource.ends_with('/') {
            resource.to_owned()
        } else {
            format!("{resource}/")
        };
        self.fetch_token(&resource).await
    }

    /// Performs one IMDS token request for `resource` and parses the response body.
    ///
    /// # Errors
    /// - `AzureError::Auth` if sending the request fails.
    /// - `AzureError::Auth` if the status is not 200. The message includes the status and,
    ///   when it can be read, the response body.
    /// - `AzureError::Auth` if the body of a 200 response cannot be read.
    /// - Whatever `parse_token_response` returns for a malformed body.
    async fn fetch_token(&self, resource: &str) -> Result<AccessToken, AzureError> {
        let mut url = format!("{IMDS_BASE}{}", urlencoding::encode(resource));
        if let Some(cid) = self.client_id.as_deref() {
            url.push_str("&client_id=");
            url.push_str(&urlencoding::encode(cid));
        }
        let response = self
            .http
            .get(&url)
            // The Metadata header is sent on every IMDS request.
            .header("Metadata", "true")
            .send()
            .await
            .map_err(|e| AzureError::Auth {
                message: format!("IMDS request failed: {e}"),
            })?;
        let status = response.status().as_u16();
        if status != 200 {
            // Best effort only: the body usually explains the failure, but an unreadable
            // body must not hide the status code.
            let body = response.text().await.unwrap_or_default();
            return Err(AzureError::Auth {
                message: format!("IMDS auth failed ({status}): {body}"),
            });
        }
        // On success, a body-read failure (for example a timeout mid-body) is reported as
        // such instead of being turned into an empty string that fails to parse.
        let body = response.text().await.map_err(|e| AzureError::Auth {
            message: format!("IMDS response read failed: {e}"),
        })?;
        parse_token_response(&body)
    }
}
impl Default for ManagedIdentityCredential {
fn default() -> Self {
Self::new()
}
}