use crate::Result;
use crate::build_errors::Error as BuilderError;
use crate::constants::{JWT_BEARER_GRANT_TYPE, OAUTH2_TOKEN_SERVER_URL};
use crate::credentials::CacheableResource;
use crate::credentials::idtoken::dynamic::IDTokenCredentialsProvider;
use crate::credentials::idtoken::parse_id_token_from_str;
use crate::credentials::service_account::{ServiceAccountKey, ServiceAccountTokenGenerator};
use crate::token::{CachedTokenProvider, Token, TokenProvider};
use crate::token_cache::TokenCache;
use crate::{BuildResult, credentials::idtoken::IDTokenCredentials};
use async_trait::async_trait;
use google_cloud_gax::error::CredentialsError;
use http::Extensions;
use reqwest::Client;
use serde_json::Value;
use std::sync::Arc;
/// ID token credentials backed by a service account key.
///
/// Fetching (and caching) the ID token is delegated entirely to
/// `token_provider`; in this file it is a `TokenCache` wrapping a
/// `ServiceAccountTokenProvider`.
///
/// The `CachedTokenProvider` bound lives on the `IDTokenCredentialsProvider`
/// impl that needs it rather than on the struct definition, so the type can be
/// named and constructed without repeating the bound everywhere.
#[derive(Debug)]
struct ServiceAccountCredentials<T> {
    /// Source of ID tokens; expected to implement `CachedTokenProvider` for the
    /// credentials trait to be available.
    token_provider: T,
}
#[async_trait]
impl<T> IDTokenCredentialsProvider for ServiceAccountCredentials<T>
where
    T: CachedTokenProvider,
{
    /// Returns the ID token string obtained through the wrapped token provider.
    ///
    /// # Errors
    /// Propagates any error from the token provider. If the provider answers
    /// `NotModified`, fails with the message "failed to fetch token".
    async fn id_token(&self) -> Result<String> {
        let resource = self.token_provider.token(Extensions::new()).await?;
        let fresh = match resource {
            CacheableResource::New { data, .. } => data,
            // This type keeps no previously returned token, so a `NotModified`
            // answer leaves nothing to hand back; surface it as an error.
            CacheableResource::NotModified => {
                return Err(CredentialsError::from_msg(false, "failed to fetch token"));
            }
        };
        Ok(fresh.token)
    }
}
/// Uncached provider that obtains ID tokens by exchanging a signed
/// JWT-bearer assertion at a token endpoint.
///
/// `Builder::build` wraps this in a `TokenCache`.
///
/// NOTE(review): the derived `Debug` prints `service_account_key`; confirm that
/// `ServiceAccountKey`'s `Debug` impl redacts the private key.
#[derive(Debug)]
struct ServiceAccountTokenProvider {
    /// Key used to sign the assertion.
    service_account_key: ServiceAccountKey,
    /// Passed as the `audience` argument of the ID token generator. The builder
    /// sets it to `OAUTH2_TOKEN_SERVER_URL` (presumably the assertion's `aud`).
    audience: String,
    /// Target audience requested for the resulting ID token.
    target_audience: String,
    /// Endpoint that the signed assertion is POSTed to.
    token_server_url: String,
}
/// JSON body returned by the token endpoint on a successful exchange.
///
/// Only the `id_token` field is read; any other fields are ignored.
#[derive(serde::Deserialize)]
struct IdTokenResponse {
    /// The raw ID token string, later parsed by `parse_id_token_from_str`.
    id_token: String,
}
#[async_trait]
impl TokenProvider for ServiceAccountTokenProvider {
    /// Fetches a fresh ID token from the token endpoint.
    ///
    /// Signs an assertion with the service account key, POSTs it as a
    /// `grant_type`/`assertion` form to `token_server_url`, and parses the
    /// `id_token` field of the JSON response into a [`Token`].
    ///
    /// # Errors
    /// - the assertion cannot be generated;
    /// - the request cannot be sent ("failed to exchange id token");
    /// - the server returns a non-success status ("failed to fetch id token");
    /// - the response body cannot be read or decoded;
    /// - `parse_id_token_from_str` rejects the returned token.
    async fn token(&self) -> Result<Token> {
        let assertion = ServiceAccountTokenGenerator::new_id_token_generator(
            self.target_audience.clone(),
            self.audience.clone(),
            self.service_account_key.clone(),
        )
        .generate()?;

        let form_fields = [
            ("grant_type", JWT_BEARER_GRANT_TYPE.to_string()),
            ("assertion", assertion),
        ];
        let response = Client::new()
            .post(&self.token_server_url)
            .form(&form_fields)
            .send()
            .await
            .map_err(|e| crate::errors::from_http_error(e, "failed to exchange id token"))?;

        if response.status().is_success() {
            // The first `from_source` argument is `!e.is_decode()`: decode failures
            // are flagged differently from other body errors (presumably the
            // transient flag -- confirm against `CredentialsError::from_source`).
            let body = response
                .json::<IdTokenResponse>()
                .await
                .map_err(|e| CredentialsError::from_source(!e.is_decode(), e))?;
            return parse_id_token_from_str(body.id_token);
        }

        Err(crate::errors::from_http_response(response, "failed to fetch id token").await)
    }
}
/// Builder for ID token credentials backed by a service account key.
///
/// The key is kept as raw JSON and only deserialized when [`Builder::build`]
/// is called.
///
/// NOTE(review): this type does not derive `Debug`; the key JSON presumably
/// contains the private key, so keep it that way.
pub struct Builder {
    /// Service account key JSON, deserialized into a `ServiceAccountKey` by `build`.
    service_account_key: Value,
    /// Target audience requested for the ID token.
    target_audience: String,
    /// Endpoint the signed assertion is sent to; defaults to `OAUTH2_TOKEN_SERVER_URL`.
    token_server_url: String,
}
impl Builder {
    /// Creates a builder for service account ID token credentials.
    ///
    /// * `target_audience` - the target audience to request the ID token for.
    /// * `service_account_key` - the service account key as JSON; it is not
    ///   validated until [`Builder::build`] is called.
    ///
    /// The token endpoint defaults to `OAUTH2_TOKEN_SERVER_URL`.
    pub fn new<S: Into<String>>(target_audience: S, service_account_key: Value) -> Self {
        let target_audience: String = target_audience.into();
        let token_server_url = OAUTH2_TOKEN_SERVER_URL.to_string();
        Self {
            service_account_key,
            target_audience,
            token_server_url,
        }
    }

    /// Parses the key JSON and assembles the uncached token provider.
    ///
    /// Only keys in the default universe domain are accepted.
    fn build_token_provider(
        self,
        target_audience: String,
    ) -> BuildResult<ServiceAccountTokenProvider> {
        let Self {
            service_account_key: raw_key,
            token_server_url,
            ..
        } = self;

        let parsed_key = serde_json::from_value::<ServiceAccountKey>(raw_key)
            .map_err(BuilderError::parsing)?;

        let in_default_universe = crate::universe_domain::is_default_universe_domain(
            parsed_key.universe_domain.as_deref(),
        );
        if !in_default_universe {
            return Err(BuilderError::not_supported(
                "Service Account Credentials do not support getting an ID token in universes other than googleapis.com",
            ));
        }

        Ok(ServiceAccountTokenProvider {
            service_account_key: parsed_key,
            audience: OAUTH2_TOKEN_SERVER_URL.to_string(),
            target_audience,
            token_server_url,
        })
    }

    /// Builds the ID token credentials.
    ///
    /// # Errors
    /// - a parsing error if the key JSON does not deserialize into a
    ///   `ServiceAccountKey`;
    /// - a not-supported error if the key's universe domain is not the default one.
    pub fn build(self) -> BuildResult<IDTokenCredentials> {
        let audience = self.target_audience.clone();
        let provider = self.build_token_provider(audience)?;
        let creds = ServiceAccountCredentials {
            token_provider: TokenCache::new(provider),
        };
        Ok(IDTokenCredentials {
            inner: Arc::new(creds),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::credentials::tests::PKCS8_PK;
    use crate::{
        constants::JWT_BEARER_GRANT_TYPE, credentials::idtoken::tests::generate_test_id_token,
    };
    use httptest::{
        Expectation, Server,
        matchers::{all_of, any, contains, request, url_decoded},
        responders::*,
    };
    use serde_json::Value;
    use serde_json::json;

    type TestResult = std::result::Result<(), Box<dyn std::error::Error>>;

    impl Builder {
        /// Test-only hook: sends token requests to `url` (the local mock
        /// server) instead of the default endpoint.
        fn with_token_server_url<S: Into<String>>(mut self, url: S) -> Self {
            self.token_server_url = url.into();
            self
        }
    }

    /// Minimal service account key JSON with an empty `private_key`.
    fn get_mock_service_key() -> Value {
        json!({
            "client_email": "test-client-email",
            "private_key_id": "test-private-key-id",
            "private_key": "",
            "project_id": "test-project-id",
        })
    }

    /// The mock key with a real PKCS#8 private key filled in, so that an
    /// assertion can actually be signed.
    fn signing_service_key() -> Value {
        let mut key = get_mock_service_key();
        key["private_key"] = Value::from(PKCS8_PK.clone());
        key
    }

    #[tokio::test]
    async fn idtoken_success() -> TestResult {
        let audience = "test-audience";
        let expected = generate_test_id_token(audience);

        let server = Server::run();
        server.expect(
            Expectation::matching(all_of![
                request::method("POST"),
                request::path("/"),
                request::body(url_decoded(contains(("grant_type", JWT_BEARER_GRANT_TYPE)))),
                request::body(url_decoded(contains(("assertion", any())))),
            ])
            .respond_with(json_encoded(json!({ "id_token": expected }))),
        );

        let creds = Builder::new(audience, signing_service_key())
            .with_token_server_url(server.url("/").to_string())
            .build()?;

        assert_eq!(creds.id_token().await?, expected);
        Ok(())
    }

    #[tokio::test]
    async fn idtoken_http_error() -> TestResult {
        let server = Server::run();
        server.expect(
            Expectation::matching(all_of![request::method("POST"), request::path("/"),])
                .respond_with(status_code(501)),
        );

        let creds = Builder::new("test-audience", signing_service_key())
            .with_token_server_url(server.url("/").to_string())
            .build()?;

        let err = creds
            .id_token()
            .await
            .expect_err("a 501 response should produce an error");
        assert!(!err.is_transient());
        Ok(())
    }

    #[tokio::test]
    async fn idtoken_caching() -> TestResult {
        let audience = "test-audience";
        let expected = generate_test_id_token(audience);

        let server = Server::run();
        server.expect(
            Expectation::matching(all_of![
                request::method("POST"),
                request::path("/"),
                request::body(url_decoded(contains(("grant_type", JWT_BEARER_GRANT_TYPE)))),
                request::body(url_decoded(contains(("assertion", any())))),
            ])
            .times(1)
            .respond_with(json_encoded(json!({ "id_token": expected }))),
        );

        let creds = Builder::new(audience, signing_service_key())
            .with_token_server_url(format!("http://{}", server.addr()))
            .build()?;

        // The expectation allows exactly one request, so the second call must be
        // served from the cache.
        for _ in 0..2 {
            assert_eq!(creds.id_token().await?, expected);
        }
        Ok(())
    }

    #[tokio::test]
    async fn idtoken_builder_fails_for_non_gdu() -> TestResult {
        let mut key = signing_service_key();
        key["universe_domain"] = Value::from("non-gdu.com");

        let err = Builder::new("test-audience", key)
            .build()
            .expect_err("non-default universe domains must be rejected");
        assert!(err.is_not_supported(), "{err:?}");
        Ok(())
    }
}