use crate::build_errors::Error as BuilderError;
use crate::credentials::CacheableResource;
use crate::credentials::user_account::UserTokenProvider;
use crate::retry::Builder as RetryTokenProviderBuilder;
use crate::token::CachedTokenProvider;
use crate::token_cache::TokenCache;
use crate::{
BuildResult, Result,
credentials::{
idtoken::{IDTokenCredentials, dynamic::IDTokenCredentialsProvider},
user_account::AuthorizedUser,
},
};
use async_trait::async_trait;
use google_cloud_gax::backoff_policy::BackoffPolicyArg;
use google_cloud_gax::error::CredentialsError;
use google_cloud_gax::retry_policy::RetryPolicyArg;
use google_cloud_gax::retry_throttler::RetryThrottlerArg;
use http::Extensions;
use serde_json::Value;
use std::sync::Arc;
/// ID-token credentials backed by a user account (`authorized_user` JSON).
///
/// `T` is the source of tokens. `Builder::build` instantiates it as a
/// `TokenCache` wrapped around a retrying `UserTokenProvider`.
///
/// The `CachedTokenProvider` bound is intentionally not placed on the struct
/// itself; it lives on the `IDTokenCredentialsProvider` impl that needs it.
#[derive(Debug)]
struct UserAccountCredentials<T> {
    /// Token source queried by `id_token()`.
    token_provider: T,
}
#[async_trait]
impl<T> IDTokenCredentialsProvider for UserAccountCredentials<T>
where
    T: CachedTokenProvider,
{
    /// Returns the ID token obtained from the underlying cached provider.
    ///
    /// The provider is queried with empty `Extensions`. Any error it reports
    /// is propagated unchanged. A `CacheableResource::NotModified` result is
    /// turned into a `CredentialsError` with the message
    /// "failed to fetch token".
    ///
    /// NOTE(review): since no cache hints are passed, `NotModified` is
    /// presumably never expected here, and the `false` flag presumably marks
    /// the error as non-transient. Confirm against `CredentialsError::from_msg`.
    async fn id_token(&self) -> Result<String> {
        let resource = self.token_provider.token(Extensions::new()).await?;
        let CacheableResource::New { data, .. } = resource else {
            return Err(CredentialsError::from_msg(false, "failed to fetch token"));
        };
        Ok(data.token)
    }
}
/// Builder for ID-token credentials backed by a user account
/// (`authorized_user` credentials JSON document).
///
/// Create it with `Builder::new`, optionally configure it with the `with_*`
/// methods, and finish with `build`. Each configuration method consumes the
/// builder and returns the updated one, so ignoring a returned `Builder`
/// silently drops the configuration. Hence `#[must_use]`.
#[must_use = "a Builder does nothing until `build()` is called"]
pub struct Builder {
    /// Raw credentials JSON; deserialized into an `AuthorizedUser` by `build`.
    authorized_user: Value,
    /// Optional token URI forwarded to `UserTokenProvider::new_id_token_provider`.
    token_uri: Option<String>,
    /// Retry configuration wrapped around the user token provider by `build`.
    retry_builder: RetryTokenProviderBuilder,
}
impl Builder {
    /// Creates a builder from an `authorized_user` credentials JSON document.
    ///
    /// The JSON is only stored here; it is parsed when [`Builder::build`] is
    /// called. The builder starts with no token URI override and a default
    /// retry configuration.
    pub fn new(authorized_user: Value) -> Self {
        let retry_builder = RetryTokenProviderBuilder::default();
        Self {
            authorized_user,
            token_uri: None,
            retry_builder,
        }
    }

    /// Sets the token URI passed to the underlying user token provider.
    ///
    /// NOTE(review): presumably this overrides the `token_uri` field of the
    /// credentials JSON. Confirm in `UserTokenProvider::new_id_token_provider`.
    pub fn with_token_uri<S: Into<String>>(self, token_uri: S) -> Self {
        Self {
            token_uri: Some(token_uri.into()),
            ..self
        }
    }

    /// Sets the retry policy used when fetching tokens.
    pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
        let policy: RetryPolicyArg = v.into();
        self.retry_builder = self.retry_builder.with_retry_policy(policy);
        self
    }

    /// Sets the backoff policy applied between token-fetch retries.
    pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
        let policy: BackoffPolicyArg = v.into();
        self.retry_builder = self.retry_builder.with_backoff_policy(policy);
        self
    }

    /// Sets the retry throttler used when fetching tokens.
    pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
        let throttler: RetryThrottlerArg = v.into();
        self.retry_builder = self.retry_builder.with_retry_throttler(throttler);
        self
    }

    /// Parses the stored JSON and creates the (non-retrying, non-caching)
    /// user token provider.
    ///
    /// Deserialization failures are mapped through `BuilderError::parsing`.
    fn build_token_provider(&self) -> BuildResult<UserTokenProvider> {
        let parsed: AuthorizedUser = serde_json::from_value(self.authorized_user.clone())
            .map_err(BuilderError::parsing)?;
        let token_uri = self.token_uri.clone();
        Ok(UserTokenProvider::new_id_token_provider(parsed, token_uri))
    }

    /// Builds the ID-token credentials.
    ///
    /// The resulting credentials fetch tokens through a `TokenCache` wrapped
    /// around the configured retry layer, which in turn wraps the user token
    /// provider.
    ///
    /// # Errors
    ///
    /// Returns a parsing error (`BuilderError::parsing`) if the credentials
    /// JSON cannot be deserialized into an `AuthorizedUser`.
    pub fn build(self) -> BuildResult<IDTokenCredentials> {
        let user_provider = self.build_token_provider()?;
        let token_provider = TokenCache::new(self.retry_builder.build(user_provider));
        let inner = Arc::new(UserAccountCredentials { token_provider });
        Ok(IDTokenCredentials { inner })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::credentials::idtoken::tests::generate_test_id_token;
    use crate::credentials::tests::{
        find_source_error, get_mock_auth_retry_policy, get_mock_backoff_policy,
        get_mock_retry_throttler,
    };
    use crate::credentials::user_account::{
        Oauth2RefreshRequest, Oauth2RefreshResponse, RefreshGrantType,
    };
    use http::StatusCode;
    use httptest::cycle;
    use httptest::matchers::{all_of, json_decoded, request};
    use httptest::responders::{json_encoded, status_code};
    use httptest::{Expectation, Server};
    use std::error::Error;

    type TestResult = anyhow::Result<()>;

    /// Path of the mock refresh-token endpoint.
    const TOKEN_PATH: &str = "/token";

    /// Returns an `authorized_user` credentials JSON document whose
    /// `token_uri` points at `token_uri`.
    fn authorized_user_json(token_uri: String) -> Value {
        serde_json::json!({
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
            "token_uri": token_uri,
        })
    }

    /// Returns credentials JSON targeting the token endpoint of `server`.
    fn authorized_user_for(server: &Server) -> Value {
        authorized_user_json(server.url(TOKEN_PATH).to_string())
    }

    /// True when the refresh request carries the test credentials and the
    /// refresh-token grant type.
    fn check_request(req: &Oauth2RefreshRequest) -> bool {
        [
            req.client_id == "test-client-id",
            req.client_secret == "test-client-secret",
            req.refresh_token == "test-refresh-token",
            req.grant_type == RefreshGrantType::RefreshToken,
        ]
        .iter()
        .all(|&ok| ok)
    }

    /// Builds the refresh-endpoint response used by the mock server.
    fn refresh_response(
        id_token: Option<String>,
        refresh_token: Option<String>,
    ) -> Oauth2RefreshResponse {
        Oauth2RefreshResponse {
            access_token: "test-access-token".to_string(),
            id_token,
            expires_in: Some(3600),
            refresh_token,
            scope: None,
            token_type: "Bearer".to_string(),
        }
    }

    /// Registers an expectation for a well-formed refresh request that is
    /// answered with `response` encoded as JSON.
    fn expect_refresh(server: &Server, response: Oauth2RefreshResponse) {
        server.expect(
            Expectation::matching(all_of![
                request::path(TOKEN_PATH),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| check_request(req)))
            ])
            .respond_with(json_encoded(response)),
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn id_token_success() -> TestResult {
        let token = generate_test_id_token("test-audience");
        let server = Server::run();
        expect_refresh(
            &server,
            refresh_response(Some(token.clone()), Some("test-refresh-token".to_string())),
        );

        let creds = Builder::new(authorized_user_for(&server)).build()?;
        assert_eq!(creds.id_token().await?, token);
        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn id_token_missing_id_token_in_response() -> TestResult {
        let server = Server::run();
        expect_refresh(
            &server,
            refresh_response(None, Some("test-refresh-token".to_string())),
        );

        let creds = Builder::new(authorized_user_for(&server)).build()?;
        let err = creds.id_token().await.unwrap_err();
        assert!(!err.is_transient());
        let message = err.source().unwrap().to_string();
        assert!(message.contains("can obtain an id token only when authenticated through gcloud"));
        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn id_token_builder_malformed_authorized_json_nonretryable() -> TestResult {
        // Unlike `authorized_user_json`, this document omits `client_id`.
        let malformed = serde_json::json!({
            "client_secret": "test-client-secret",
            "refresh_token": "test-refresh-token",
            "type": "authorized_user",
        });
        let e = Builder::new(malformed).build().unwrap_err();
        assert!(e.is_parsing(), "{e}");
        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn id_token_nonretryable_error() -> TestResult {
        let server = Server::run();
        server.expect(
            Expectation::matching(request::path(TOKEN_PATH)).respond_with(status_code(401)),
        );

        let creds = Builder::new(authorized_user_for(&server)).build()?;
        let err = creds.id_token().await.unwrap_err();
        assert!(!err.is_transient());
        let http_error = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(http_error, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
            "{err:?}"
        );
        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_user_account_id_token_retries_on_transient_failures() -> TestResult {
        let server = Server::run();
        server.expect(
            Expectation::matching(request::path(TOKEN_PATH))
                .times(3)
                .respond_with(status_code(503)),
        );

        let credentials = Builder::new(authorized_user_for(&server))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build()?;
        let err = credentials.id_token().await.unwrap_err();
        assert!(err.is_transient(), "{err:?}");
        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_user_account_id_token_retries_for_success() -> TestResult {
        let response = refresh_response(
            Some("test-id-token".to_string()),
            Some("test-refresh-token".to_string()),
        );
        let server = Server::run();
        // Two transient failures followed by a successful refresh.
        server.expect(
            Expectation::matching(request::path(TOKEN_PATH))
                .times(3)
                .respond_with(cycle![
                    status_code(503).body("try-again"),
                    status_code(503).body("try-again"),
                    status_code(200)
                        .append_header("Content-Type", "application/json")
                        .body(serde_json::to_string(&response).unwrap()),
                ]),
        );

        let credentials = Builder::new(authorized_user_for(&server))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build()?;
        let id_token = credentials.id_token().await.unwrap();
        assert_eq!(id_token, "test-id-token");
        Ok(())
    }

    #[tokio::test]
    async fn idtoken_caching() -> TestResult {
        let token = generate_test_id_token("test-audience");
        let server = Server::run();
        // The endpoint may be hit only once; the second `id_token()` call
        // must be served from the cache.
        server.expect(
            Expectation::matching(all_of![
                request::path(TOKEN_PATH),
                request::body(json_decoded(|req: &Oauth2RefreshRequest| check_request(req)))
            ])
            .times(1)
            .respond_with(json_encoded(refresh_response(Some(token.clone()), None))),
        );

        let creds = Builder::new(authorized_user_for(&server)).build()?;
        for _ in 0..2 {
            assert_eq!(creds.id_token().await?, token);
        }
        Ok(())
    }
}