use crate::Result;
use crate::credentials::CacheableResource;
use crate::errors::CredentialsError;
use crate::mds::client::Client as MDSClient;
use crate::retry::{Builder as RetryTokenProviderBuilder, TokenProviderWithRetry};
use crate::token::{CachedTokenProvider, Token, TokenProvider};
use crate::token_cache::TokenCache;
use crate::{
BuildResult,
credentials::idtoken::dynamic::IDTokenCredentialsProvider,
credentials::idtoken::{IDTokenCredentials, parse_id_token_from_str},
};
use async_trait::async_trait;
use google_cloud_gax::backoff_policy::BackoffPolicyArg;
use google_cloud_gax::retry_policy::RetryPolicyArg;
use google_cloud_gax::retry_throttler::RetryThrottlerArg;
use http::Extensions;
use std::sync::Arc;
/// ID token credentials that hand out tokens obtained from a cached token provider.
///
/// The struct itself places no trait bound on `T`; the `CachedTokenProvider`
/// requirement is expressed on the `IDTokenCredentialsProvider` impl, which is the
/// only place that needs it.
#[derive(Debug)]
pub(crate) struct MDSCredentials<T> {
    /// Source of tokens. In this module it is a `TokenCache` wrapping a retrying
    /// metadata-server token provider (see `Builder::build`).
    token_provider: T,
}
#[async_trait]
impl<T> IDTokenCredentialsProvider for MDSCredentials<T>
where
    T: CachedTokenProvider,
{
    /// Returns the ID token string obtained from the wrapped cached provider.
    ///
    /// The provider is queried with a fresh, empty `Extensions`. Errors from the
    /// provider are propagated unchanged. If the provider answers with
    /// `CacheableResource::NotModified` instead of a new token, this returns a
    /// `CredentialsError` built with `from_msg(false, ...)` (presumably marking it as
    /// non-transient — confirm against `CredentialsError::from_msg`).
    async fn id_token(&self) -> Result<String> {
        let resource = self.token_provider.token(Extensions::new()).await?;
        let CacheableResource::New { data, .. } = resource else {
            return Err(CredentialsError::from_msg(false, "failed to fetch token"));
        };
        Ok(data.token)
    }
}
/// The `format` requested for an ID token from the metadata server.
///
/// The selected value is sent as the `format` query parameter of the identity
/// request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Format {
    /// Sent as `format=standard`.
    Standard,
    /// Sent as `format=full`.
    Full,
    /// Any other value; the contained string is sent verbatim.
    UnknownValue(String),
}
impl Format {
fn as_str(&self) -> &str {
match self {
Format::Standard => "standard",
Format::Full => "full",
Format::UnknownValue(value) => value.as_str(),
}
}
}
/// Builder for ID token credentials backed by the metadata server (MDS).
///
/// Start with [`Builder::new`], adjust with the `with_*` methods, and finish with
/// [`Builder::build`].
pub struct Builder {
    /// Audience the ID token is requested for (sent as the `audience` query parameter).
    target_audience: String,
    /// Optional override of the metadata server endpoint; `None` lets the MDS client
    /// pick its default.
    endpoint: Option<String>,
    /// Optional token format, sent as the `format` query parameter when set.
    pub(crate) format: Option<Format>,
    /// `"TRUE"` or `"FALSE"` once `with_licenses` has been called; otherwise `None`.
    licenses: Option<String>,
    /// Retry, backoff and throttling configuration applied to the token provider.
    retry_builder: RetryTokenProviderBuilder,
}
impl Builder {
    /// Creates a builder that requests ID tokens for `target_audience`.
    ///
    /// Endpoint, format and licenses start unset; retry behavior starts from
    /// `RetryTokenProviderBuilder::default()`.
    pub fn new<S: Into<String>>(target_audience: S) -> Self {
        Self {
            target_audience: target_audience.into(),
            endpoint: None,
            format: None,
            licenses: None,
            retry_builder: RetryTokenProviderBuilder::default(),
        }
    }

    /// Overrides the metadata server endpoint (e.g. `http://host:port`).
    ///
    /// When not called, `None` is handed to the MDS client. The env-var test in this
    /// file suggests the client then honors `GCE_METADATA_HOST`.
    pub fn with_endpoint<S: Into<String>>(self, endpoint: S) -> Self {
        Self {
            endpoint: Some(endpoint.into()),
            ..self
        }
    }

    /// Selects the token [`Format`], sent as the `format` query parameter.
    pub fn with_format(self, format: Format) -> Self {
        Self {
            format: Some(format),
            ..self
        }
    }

    /// Requests that license information be included (`licenses=TRUE`) or excluded
    /// (`licenses=FALSE`).
    pub fn with_licenses(self, licenses: bool) -> Self {
        let flag = if licenses { "TRUE" } else { "FALSE" };
        Self {
            licenses: Some(flag.to_string()),
            ..self
        }
    }

    /// Sets the retry policy used when fetching tokens.
    pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
        self.retry_builder = self.retry_builder.with_retry_policy(v.into());
        self
    }

    /// Sets the backoff policy used between retry attempts.
    pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
        self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
        self
    }

    /// Sets the retry throttler used when fetching tokens.
    pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
        self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
        self
    }

    /// Assembles the MDS token provider and wraps it with the configured retry logic.
    fn build_token_provider(self) -> TokenProviderWithRetry<MDSTokenProvider> {
        let Self {
            target_audience,
            endpoint,
            format,
            licenses,
            retry_builder,
        } = self;
        let client = MDSClient::new(endpoint);
        retry_builder.build(MDSTokenProvider {
            client,
            format,
            licenses,
            target_audience,
        })
    }

    /// Builds [`IDTokenCredentials`] that fetch ID tokens from the metadata server.
    ///
    /// Tokens are cached through `TokenCache`. The current implementation always
    /// returns `Ok`.
    pub fn build(self) -> BuildResult<IDTokenCredentials> {
        let token_provider = TokenCache::new(self.build_token_provider());
        let creds = MDSCredentials { token_provider };
        Ok(IDTokenCredentials {
            inner: Arc::new(creds),
        })
    }
}
/// Token provider that requests ID tokens from the metadata server identity endpoint.
///
/// Field order is kept stable because it determines the derived `Debug` output.
#[derive(Clone, Debug)]
struct MDSTokenProvider {
    /// HTTP client for the metadata server.
    client: MDSClient,
    /// Optional `format` query parameter value.
    format: Option<Format>,
    /// Optional `licenses` query parameter value (`"TRUE"` / `"FALSE"`).
    licenses: Option<String>,
    /// Audience the ID token is requested for.
    target_audience: String,
}
#[async_trait]
impl TokenProvider for MDSTokenProvider {
    /// Requests an ID token from the metadata server and parses it into a [`Token`].
    ///
    /// The request carries the configured audience plus the optional `format` and
    /// `licenses` parameters.
    ///
    /// # Errors
    /// Propagates failures from the MDS client request and from
    /// `parse_id_token_from_str`.
    async fn token(&self) -> Result<Token> {
        // Only the wire string of the format is needed, so borrow instead of cloning
        // the whole enum first.
        let format = self.format.as_ref().map(|f| f.as_str().to_string());
        // The audience is only borrowed by the client; no clone required.
        let token = self
            .client
            .id_token(&self.target_audience, format, self.licenses.clone())
            .send()
            .await?;
        parse_id_token_from_str(token)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::credentials::idtoken::tests::generate_test_id_token;
    use crate::credentials::tests::{
        find_source_error, get_mock_auth_retry_policy, get_mock_backoff_policy,
        get_mock_retry_throttler,
    };
    use crate::mds::{GCE_METADATA_HOST_ENV_VAR, MDS_DEFAULT_URI};
    use httptest::cycle;
    use httptest::matchers::{all_of, contains, request, url_decoded};
    use httptest::responders::status_code;
    use httptest::{Expectation, Server};
    use reqwest::StatusCode;
    use scoped_env::ScopedEnv;
    use serial_test::{parallel, serial};
    use test_case::test_case;

    type TestResult = anyhow::Result<()>;

    // A 401 is not retried even though the retry policy allows 3 attempts.
    #[tokio::test]
    #[parallel]
    async fn test_mds_does_not_retry_on_non_transient_failures() -> TestResult {
        let server = Server::run();
        let audience = "test-audience";
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/identity")),
                request::query(url_decoded(contains(("audience", audience)))),
            ])
            .times(1)
            .respond_with(status_code(401)),
        );
        let creds = Builder::new(audience)
            .with_endpoint(format!("http://{}", server.addr()))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build()?;
        let err = creds.id_token().await.unwrap_err();
        let source = find_source_error::<google_cloud_gax::error::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.http_status_code() == Some(StatusCode::UNAUTHORIZED.into())),
            "{err:?}"
        );
        Ok(())
    }

    // Two 503 responses are retried and the third (200) response succeeds.
    #[tokio::test]
    #[parallel]
    async fn test_mds_retries_for_success() -> TestResult {
        let server = Server::run();
        let audience = "test-audience";
        let token_string = generate_test_id_token(audience);
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/identity")),
                request::query(url_decoded(contains(("audience", audience)))),
            ])
            .times(3)
            .respond_with(cycle![
                status_code(503).body("try-again"),
                status_code(503).body("try-again"),
                status_code(200).body(token_string.clone()),
            ]),
        );
        let creds = Builder::new(audience)
            .with_endpoint(format!("http://{}", server.addr()))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build()?;
        let id_token = creds.id_token().await?;
        assert_eq!(id_token, token_string);
        Ok(())
    }

    // Every `Format` variant and the licenses flag are forwarded as query parameters.
    #[tokio::test]
    #[test_case(Format::Standard)]
    #[test_case(Format::Full)]
    #[test_case(Format::UnknownValue("minimal".to_string()))]
    #[parallel]
    async fn test_idtoken_builder_build(format: Format) -> TestResult {
        let server = Server::run();
        let audience = "test-audience";
        let token_string = generate_test_id_token(audience);
        let format_str = format.as_str().to_string();
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/identity")),
                request::query(url_decoded(contains(("audience", audience)))),
                request::query(url_decoded(contains(("format", format_str)))),
                request::query(url_decoded(contains(("licenses", "TRUE"))))
            ])
            .respond_with(status_code(200).body(token_string.clone())),
        );
        let creds = Builder::new(audience)
            .with_endpoint(format!("http://{}", server.addr()))
            .with_format(format)
            .with_licenses(true)
            .build()?;
        let id_token = creds.id_token().await?;
        assert_eq!(id_token, token_string);
        Ok(())
    }

    // Without an explicit endpoint, the metadata host comes from the environment.
    #[tokio::test]
    #[serial]
    async fn test_idtoken_builder_build_with_env_var() -> TestResult {
        let server = Server::run();
        let audience = "test-audience";
        let token_string = generate_test_id_token(audience);
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/identity")),
                request::query(url_decoded(contains(("audience", audience))))
            ])
            .respond_with(status_code(200).body(token_string.clone())),
        );
        let addr = server.addr().to_string();
        // The guard restores the variable's previous state when it is dropped at the
        // end of this test (including on early return or panic), so no explicit
        // `ScopedEnv::remove` is needed.
        let _e = ScopedEnv::set(GCE_METADATA_HOST_ENV_VAR, &addr);
        let creds = Builder::new(audience).build()?;
        let id_token = creds.id_token().await?;
        assert_eq!(id_token, token_string);
        Ok(())
    }

    // HTTP errors from the metadata server surface the original status code.
    #[tokio::test]
    #[parallel]
    async fn test_idtoken_provider_http_error() -> TestResult {
        let server = Server::run();
        let audience = "test-audience";
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/identity")),
                request::query(url_decoded(contains(("audience", audience))))
            ])
            .respond_with(status_code(503)),
        );
        let creds = Builder::new(audience)
            .with_endpoint(format!("http://{}", server.addr()))
            .build()?;
        let err = creds.id_token().await.unwrap_err();
        let source = find_source_error::<google_cloud_gax::error::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.http_status_code() == Some(StatusCode::SERVICE_UNAVAILABLE.into())),
            "{err:?}"
        );
        Ok(())
    }

    // A second `id_token` call is served from the cache (server expects one request).
    #[tokio::test]
    #[parallel]
    async fn test_idtoken_caching() -> TestResult {
        let server = Server::run();
        let audience = "test-audience";
        let token_string = generate_test_id_token(audience);
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/identity")),
                request::query(url_decoded(contains(("audience", audience))))
            ])
            .times(1)
            .respond_with(status_code(200).body(token_string.clone())),
        );
        let creds = Builder::new(audience)
            .with_endpoint(format!("http://{}", server.addr()))
            .build()?;
        let id_token = creds.id_token().await?;
        assert_eq!(id_token, token_string);
        let id_token = creds.id_token().await?;
        assert_eq!(id_token, token_string);
        Ok(())
    }
}