use crate::{
env::Env, ClientAssertion, ClientAssertionCredential, ClientAssertionCredentialOptions,
};
use azure_core::{
credentials::{AccessToken, Secret, TokenCredential, TokenRequestOptions},
error::ErrorKind,
http::{
headers::{FromHeaders, HeaderName, Headers, AUTHORIZATION, CONTENT_LENGTH},
request::Request,
ClientMethodOptions, Method, Pipeline, PipelineSendOptions, StatusCode, Url,
},
};
use serde::Deserialize;
use std::{borrow::Cow, convert::Infallible, fmt, sync::Arc};
/// Environment variable holding the URL of the Azure Pipelines OIDC token endpoint.
const OIDC_VARIABLE_NAME: &str = "SYSTEM_OIDCREQUESTURI";
/// Value sent as the `api-version` query parameter on OIDC token requests.
const OIDC_VERSION: &str = "7.1";
/// Header sent (with the value `Suppress`) on every OIDC token request.
const TFS_FEDAUTHREDIRECT_HEADER: HeaderName = HeaderName::from_static("x-tfs-fedauthredirect");
/// Diagnostic response headers added to the logging allow-list and reported in OIDC error
/// messages. Keep in sync with `MSEDGE_REF` and `VSS_E2EID`.
const ALLOWED_HEADERS: &[&str] = &["x-msedge-ref", "x-vss-e2eid"];
/// Authenticates an Azure Pipelines service connection.
///
/// The credential requests an OIDC token from the Azure Pipelines OIDC endpoint (named by the
/// `SYSTEM_OIDCREQUESTURI` environment variable, which Azure Pipelines sets). It then uses that
/// token as the client assertion for a [`ClientAssertionCredential`] to obtain access tokens.
#[derive(Debug)]
pub struct AzurePipelinesCredential(ClientAssertionCredential<Client>);
/// Options for constructing an [`AzurePipelinesCredential`].
#[derive(Debug, Default)]
pub struct AzurePipelinesCredentialOptions {
    /// Options for the underlying client assertion credential.
    ///
    /// Its `client_options` also configure the HTTP pipeline used to call the Azure Pipelines
    /// OIDC endpoint.
    pub credential_options: ClientAssertionCredentialOptions,

    /// Test-only environment override used instead of the process environment.
    #[cfg(test)]
    pub(crate) env: Option<Env>,
}
impl AzurePipelinesCredential {
pub fn new<T>(
tenant_id: String,
client_id: String,
service_connection_id: &str,
system_access_token: T,
options: Option<AzurePipelinesCredentialOptions>,
) -> azure_core::Result<Arc<Self>>
where
T: Into<Secret>,
{
let system_access_token = system_access_token.into();
crate::validate_tenant_id(&tenant_id)?;
crate::validate_not_empty(&client_id, "no client ID specified")?;
crate::validate_not_empty(service_connection_id, "no service connection ID specified")?;
crate::validate_not_empty(
system_access_token.secret(),
"no system access token specified",
)?;
let mut options = options.unwrap_or_default();
options
.credential_options
.client_options
.logging
.additional_allowed_header_names
.extend(ALLOWED_HEADERS.iter().map(|&s| Cow::Borrowed(s)));
#[cfg(test)]
let env = options.env.unwrap_or_default();
#[cfg(not(test))]
let env = Env::default();
let endpoint = env
.var(OIDC_VARIABLE_NAME)
.map_err(|err| azure_core::Error::with_error(
ErrorKind::Credential,
err,
format!("no value for environment variable {OIDC_VARIABLE_NAME}. This should be set by Azure Pipelines"),
))?;
let mut endpoint: Url = endpoint.parse().map_err(|err| {
azure_core::Error::with_error(
ErrorKind::Credential,
err,
format!("invalid URL for environment variable {OIDC_VARIABLE_NAME}"),
)
})?;
endpoint
.query_pairs_mut()
.append_pair("api-version", OIDC_VERSION)
.append_pair("serviceConnectionId", service_connection_id);
let pipeline = azure_core::http::Pipeline::new(
option_env!("CARGO_PKG_NAME"),
option_env!("CARGO_PKG_VERSION"),
options.credential_options.client_options.clone(),
Vec::default(),
Vec::default(),
None,
);
let client = Client {
endpoint,
pipeline: Arc::new(pipeline),
system_access_token,
};
let credential = ClientAssertionCredential::new_exclusive(
tenant_id,
client_id,
client,
stringify!(AzurePipelinesCredential),
Some(options.credential_options),
)?;
Ok(Arc::new(Self(credential)))
}
}
#[async_trait::async_trait]
impl TokenCredential for AzurePipelinesCredential {
    /// Gets an access token for `scopes` by delegating to the wrapped
    /// [`ClientAssertionCredential`].
    async fn get_token(
        &self,
        scopes: &[&str],
        options: Option<TokenRequestOptions<'_>>,
    ) -> azure_core::Result<AccessToken> {
        let inner = &self.0;
        inner.get_token(scopes, options).await
    }
}
/// Requests OIDC tokens from the Azure Pipelines endpoint. It is the client-assertion source
/// wrapped by `AzurePipelinesCredential`'s `ClientAssertionCredential`.
#[derive(Debug)]
struct Client {
/// OIDC endpoint URL, already carrying the `api-version` and `serviceConnectionId` query pairs.
endpoint: Url,
/// HTTP pipeline used to send OIDC token requests.
pipeline: Arc<Pipeline>,
/// Pipeline system access token, sent as `Authorization: Bearer <token>` on each request.
system_access_token: Secret,
}
#[async_trait::async_trait]
impl ClientAssertion for Client {
async fn secret(&self, options: Option<ClientMethodOptions<'_>>) -> azure_core::Result<String> {
let mut req = Request::new(self.endpoint.clone(), Method::Post);
req.insert_header(
AUTHORIZATION,
String::from("Bearer ") + self.system_access_token.secret(),
);
req.insert_header(TFS_FEDAUTHREDIRECT_HEADER, "Suppress");
req.insert_header(CONTENT_LENGTH, "0");
let options = options.unwrap_or_default();
let ctx = options.context.to_borrowed();
let resp = self
.pipeline
.send(
&ctx,
&mut req,
Some(PipelineSendOptions {
skip_checks: true,
..Default::default()
}),
)
.await?;
let status = resp.status();
if status != StatusCode::Ok {
let err_headers: ErrorHeaders = resp.headers().get()?;
return Err(azure_core::Error::with_message(
ErrorKind::HttpResponse {
status,
error_code: Some(status.canonical_reason().to_string()),
raw_response: Some(Box::new(resp)),
},
format!(
"{status} response from the OIDC endpoint. Check service connection ID and pipeline configuration. {err_headers}"
),
));
}
let assertion: Assertion = resp.into_body().json()?;
Ok(assertion.oidc_token.secret().to_string())
}
}
/// Successful OIDC endpoint response body, e.g. `{"oidcToken":"..."}`.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct Assertion {
    /// The OIDC token (JSON key `oidcToken`). It is wrapped in `Secret`, presumably to keep it
    /// out of `Debug` output.
    oidc_token: Secret,
}
/// Diagnostic headers read from a failed OIDC response. Their `Display` form is appended to
/// the error message to aid troubleshooting.
#[derive(Debug)]
struct ErrorHeaders {
/// Value of the `x-msedge-ref` response header, if present.
msedge_ref: Option<String>,
/// Value of the `x-vss-e2eid` response header, if present.
vss_e2eid: Option<String>,
}
/// Name of the `x-msedge-ref` diagnostic response header. Keep in sync with `ALLOWED_HEADERS`.
const MSEDGE_REF: HeaderName = HeaderName::from_static("x-msedge-ref");
/// Name of the `x-vss-e2eid` diagnostic response header. Keep in sync with `ALLOWED_HEADERS`.
const VSS_E2EID: HeaderName = HeaderName::from_static("x-vss-e2eid");
impl FromHeaders for ErrorHeaders {
    type Error = Infallible;

    /// Names of the headers this type reads; the same list is added to the logging allow-list.
    fn header_names() -> &'static [&'static str] {
        ALLOWED_HEADERS
    }

    /// Reads both diagnostic headers. An absent header yields `None` for its field, so this
    /// never fails and always returns `Some`.
    fn from_headers(headers: &Headers) -> Result<Option<Self>, Self::Error> {
        let msedge_ref = headers.get_optional_string(&MSEDGE_REF);
        let vss_e2eid = headers.get_optional_string(&VSS_E2EID);
        let parsed = Self {
            msedge_ref,
            vss_e2eid,
        };
        Ok(Some(parsed))
    }
}
impl fmt::Display for ErrorHeaders {
    /// Renders the present headers as a debug-style struct named `Headers` and omits absent
    /// ones, e.g. `Headers { x-msedge-ref: "foo", x-vss-e2eid: "bar" }`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let entries = [(MSEDGE_REF, &self.msedge_ref), (VSS_E2EID, &self.vss_e2eid)];
        let mut out = f.debug_struct("Headers");
        for (name, value) in &entries {
            if let Some(value) = value {
                out.field(name.as_str(), value);
            }
        }
        out.finish()
    }
}
// Unit tests for `AzurePipelinesCredential`. HTTP traffic goes to `MockHttpClient`, and the OIDC
// endpoint comes from an injected `Env` rather than the process environment.
#[cfg(test)]
mod tests {
use super::*;
use crate::env::Env;
use azure_core::{
http::{AsyncRawResponse, ClientOptions, RawResponse, Transport},
Bytes,
};
use azure_core_test::http::MockHttpClient;
use futures::FutureExt as _;
// Each failing call leaves at least one argument empty or invalid. The final call supplies every
// argument plus the OIDC environment variable and must succeed.
#[test]
fn param_errors() {
assert!(AzurePipelinesCredential::new("".into(), "".into(), "", "", None).is_err());
assert!(AzurePipelinesCredential::new("_".into(), "".into(), "", "", None).is_err());
assert!(AzurePipelinesCredential::new("a".into(), "".into(), "", "", None).is_err());
assert!(AzurePipelinesCredential::new("a".into(), "b".into(), "", "", None).is_err());
assert!(AzurePipelinesCredential::new("a".into(), "b".into(), "c", "", None).is_err());
let options = AzurePipelinesCredentialOptions {
env: Some(Env::from(
&[(OIDC_VARIABLE_NAME, "http://localhost/get_token")][..],
)),
..Default::default()
};
assert!(
AzurePipelinesCredential::new("a".into(), "b".into(), "c", "d", Some(options)).is_ok()
);
}
// A 403 from the OIDC endpoint must surface as an `ErrorKind::Credential` error. Its message must
// include the diagnostic headers, and it must wrap an `ErrorKind::HttpResponse` error carrying the
// status, its canonical reason as the error code, and the raw response.
#[tokio::test]
async fn error_response() {
let expected_status = StatusCode::Forbidden;
let body = Bytes::from_static(b"content");
let mut headers = Headers::new();
headers.insert(MSEDGE_REF, "foo");
headers.insert(VSS_E2EID, "bar");
let expected_response =
RawResponse::from_bytes(expected_status, headers.clone(), body.clone());
let headers_for_mock = headers.clone();
let body_for_mock = body.clone();
// The mock also verifies that the request URL carries the expected query parameters.
let mock_client = MockHttpClient::new(move |req| {
assert_eq!(
req.url().as_str(),
"http://localhost/get_token?api-version=7.1&serviceConnectionId=c"
);
let headers = headers_for_mock.clone();
let body = body_for_mock.clone();
async move { Ok(AsyncRawResponse::from_bytes(expected_status, headers, body)) }.boxed()
});
let options = AzurePipelinesCredentialOptions {
credential_options: ClientAssertionCredentialOptions {
client_options: ClientOptions {
transport: Some(Transport::new(Arc::new(mock_client))),
..Default::default()
},
},
env: Some(Env::from(
&[(OIDC_VARIABLE_NAME, "http://localhost/get_token")][..],
)),
};
let err = AzurePipelinesCredential::new("a".into(), "b".into(), "c", "d", Some(options))
.expect("credential")
.get_token(&["default"], None)
.await
.expect_err("expected error");
assert!(matches!(err.kind(), ErrorKind::Credential));
assert_eq!(
r#"AzurePipelinesCredential authentication failed. 403 response from the OIDC endpoint. Check service connection ID and pipeline configuration. Headers { x-msedge-ref: "foo", x-vss-e2eid: "bar" }
To troubleshoot, visit https://aka.ms/azsdk/rust/identity/troubleshoot#apc"#,
err.to_string(),
);
match err
.downcast_ref::<azure_core::Error>()
.expect("returned error should wrap an azure_core::Error")
.kind()
{
ErrorKind::HttpResponse {
error_code: Some(reason),
raw_response: Some(response),
status,
..
} => {
assert_eq!(status.canonical_reason(), reason.as_str());
assert_eq!(&expected_response, response.as_ref());
assert_eq!(expected_status, *status);
}
err => panic!("unexpected {:?}", err),
};
}
// Happy path. The mock answers the OIDC request (after checking its Authorization and
// x-tfs-fedauthredirect headers) with OIDC token "baz". It then answers the Entra token request
// with access token "qux", which the credential must return.
#[tokio::test]
async fn mock_request() {
let mock_client = MockHttpClient::new(|req| {
async move {
if req.url().as_str()
== "http://localhost/get_token?api-version=7.1&serviceConnectionId=c"
{
assert!(matches!(
req.headers().get_str(&AUTHORIZATION),
Ok(value) if value == "Bearer d",
));
assert!(matches!(
req.headers().get_str(&TFS_FEDAUTHREDIRECT_HEADER),
Ok(value) if value == "Suppress",
));
let mut headers = Headers::new();
headers.insert(MSEDGE_REF, "foo");
headers.insert(VSS_E2EID, "bar");
return Ok(AsyncRawResponse::from_bytes(
StatusCode::Ok,
headers,
Bytes::from_static(br#"{"oidcToken":"baz"}"#),
));
}
if req.url().as_str() == "https://login.microsoftonline.com/a/oauth2/v2.0/token" {
return Ok(AsyncRawResponse::from_bytes(
StatusCode::Ok,
Headers::new(),
Bytes::from_static(
br#"{"token_type":"test","expires_in":0,"ext_expires_in":0,"access_token":"qux"}"#,
),
));
}
panic!("not supported")
}.boxed()
});
let options = AzurePipelinesCredentialOptions {
credential_options: ClientAssertionCredentialOptions {
client_options: ClientOptions {
transport: Some(Transport::new(Arc::new(mock_client))),
..Default::default()
},
},
env: Some(Env::from(
&[(OIDC_VARIABLE_NAME, "http://localhost/get_token")][..],
)),
};
let credential =
AzurePipelinesCredential::new("a".into(), "b".into(), "c", "d", Some(options))
.expect("valid AzurePipelinesCredential");
let secret = credential
.get_token(&["default"], None)
.await
.expect("valid response");
assert_eq!(secret.token.secret(), "qux");
}
}