use crate::build_errors::Error as BuilderError;
use crate::credentials::{AdcContents, CredentialsError, extract_credential_type, load_adc};
use crate::token::Token;
use crate::{BuildResult, Result};
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use serde_json::Value;
use std::collections::HashMap;
use std::future::Future;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::time::Instant;
pub mod impersonated;
pub mod mds;
pub mod service_account;
pub mod user_account;
pub mod verifier;
/// Type-erased handle to something that can produce ID tokens.
///
/// The concrete provider lives behind an `Arc<dyn …>`, so cloning this handle
/// only bumps a reference count and every clone talks to the same provider.
#[derive(Clone, Debug)]
pub struct IDTokenCredentials {
    /// The shared, dynamically dispatched provider.
    pub(crate) inner: Arc<dyn dynamic::IDTokenCredentialsProvider>,
}
impl<T> From<T> for IDTokenCredentials
where
T: IDTokenCredentialsProvider + Send + Sync + 'static,
{
fn from(value: T) -> Self {
Self {
inner: Arc::new(value),
}
}
}
impl IDTokenCredentials {
    /// Asks the wrapped provider for an ID token and returns it as a string.
    ///
    /// Any error reported by the provider is passed through unchanged.
    pub async fn id_token(&self) -> Result<String> {
        let provider = &*self.inner;
        provider.id_token().await
    }
}
/// Statically dispatched interface for types that can produce ID tokens.
///
/// The method returns `impl Future`, so this trait cannot be used as a trait
/// object; `dynamic::IDTokenCredentialsProvider` is the dyn-compatible mirror.
pub trait IDTokenCredentialsProvider: std::fmt::Debug {
    /// Produces an ID token as a string, or the error that prevented it.
    ///
    /// The returned future must be `Send`.
    fn id_token(&self) -> impl Future<Output = Result<String>> + Send;
}
/// Dyn-compatible mirror of the public provider trait.
pub(crate) mod dynamic {
    use crate::Result;

    /// Trait-object-friendly counterpart of `super::IDTokenCredentialsProvider`.
    #[async_trait::async_trait]
    pub trait IDTokenCredentialsProvider: Send + Sync + std::fmt::Debug {
        /// Produces an ID token as a string, or the error that prevented it.
        async fn id_token(&self) -> Result<String>;
    }

    /// Every thread-safe implementation of the public trait is automatically
    /// usable through the dyn-compatible trait.
    #[async_trait::async_trait]
    impl<T> IDTokenCredentialsProvider for T
    where
        T: super::IDTokenCredentialsProvider + Send + Sync,
    {
        async fn id_token(&self) -> Result<String> {
            // Name the public trait explicitly so the call cannot be mistaken
            // for a recursive call into this blanket impl.
            <T as super::IDTokenCredentialsProvider>::id_token(self).await
        }
    }
}
/// Builder for [`IDTokenCredentials`] whose source is chosen by `load_adc()`.
pub struct Builder {
    /// Audience passed on to the concrete credential builder.
    target_audience: String,
    /// Whether the email option is requested; only the `mds` and impersonated
    /// builders act on it.
    include_email: bool,
}
impl Builder {
pub fn new<S: Into<String>>(target_audience: S) -> Self {
Self {
target_audience: target_audience.into(),
include_email: false,
}
}
pub fn with_include_email(mut self) -> Self {
self.include_email = true;
self
}
pub fn build(self) -> BuildResult<IDTokenCredentials> {
let json_data = match load_adc()? {
AdcContents::Contents(contents) => {
Some(serde_json::from_str(&contents).map_err(BuilderError::parsing)?)
}
AdcContents::FallbackToMds => None,
};
build_id_token_credentials(self.target_audience, self.include_email, json_data)
}
}
/// The concrete builder selected for a given credential source.
///
/// Keeping selection separate from `build()` lets tests inspect the configured
/// builder before any credentials are created.
enum IDTokenBuilder {
    /// Chosen when no ADC JSON is available.
    Mds(mds::Builder),
    /// Chosen for `"service_account"` ADC JSON.
    ServiceAccount(service_account::Builder),
    /// Chosen for `"impersonated_service_account"` ADC JSON.
    Impersonated(impersonated::Builder),
}
/// Selects the builder matching `json` and immediately builds credentials from it.
///
/// `json` is the parsed ADC file, or `None` when there is no ADC file.
///
/// # Errors
/// Returns the selection error from `build_id_token_credentials_internal`, or
/// whatever the selected builder's `build()` reports.
fn build_id_token_credentials(
    audience: String,
    include_email: bool,
    json: Option<Value>,
) -> BuildResult<IDTokenCredentials> {
    match build_id_token_credentials_internal(audience, include_email, json)? {
        IDTokenBuilder::Mds(selected) => selected.build(),
        IDTokenBuilder::ServiceAccount(selected) => selected.build(),
        IDTokenBuilder::Impersonated(selected) => selected.build(),
    }
}
/// Chooses and configures the builder for the given credential source without building it.
///
/// * `None` selects the `mds` builder; `include_email` picks `Format::Full`,
///   otherwise `Format::Standard`.
/// * `"service_account"` selects the service account builder (`include_email` is not used).
/// * `"impersonated_service_account"` selects the impersonated builder and
///   enables `include_email` on it when requested.
///
/// # Errors
/// * not-supported for `"authorized_user"` and `"external_account"`;
/// * unknown-type for any other credential type;
/// * whatever `extract_credential_type` reports for malformed JSON.
fn build_id_token_credentials_internal(
    audience: String,
    include_email: bool,
    json: Option<Value>,
) -> BuildResult<IDTokenBuilder> {
    let Some(json) = json else {
        let format = match include_email {
            true => mds::Format::Full,
            false => mds::Format::Standard,
        };
        let builder = mds::Builder::new(audience).with_format(format);
        return Ok(IDTokenBuilder::Mds(builder));
    };

    let cred_type = extract_credential_type(&json)?;
    match cred_type {
        "service_account" => {
            let builder = service_account::Builder::new(audience, json);
            Ok(IDTokenBuilder::ServiceAccount(builder))
        }
        "impersonated_service_account" => {
            let mut builder = impersonated::Builder::new(audience, json);
            if include_email {
                builder = builder.with_include_email();
            }
            Ok(IDTokenBuilder::Impersonated(builder))
        }
        "authorized_user" => {
            let hint = format!("{cred_type}, use idtoken::user_account::Builder directly.");
            Err(BuilderError::not_supported(hint))
        }
        "external_account" => Err(BuilderError::not_supported(cred_type)),
        _ => Err(BuilderError::unknown_type(cred_type)),
    }
}
/// Wraps a raw JWT ID token in a [`Token`], using the current system time as
/// the reference for computing its expiration.
pub(crate) fn parse_id_token_from_str(token: String) -> Result<Token> {
    let now = SystemTime::now();
    parse_id_token_from_str_impl(token, now)
}
fn parse_id_token_from_str_impl(token: String, now: SystemTime) -> Result<Token> {
let parts: Vec<&str> = token.split('.').collect();
if parts.len() != 3 {
return Err(CredentialsError::from_msg(false, "invalid JWT token"));
}
let payload = URL_SAFE_NO_PAD
.decode(parts[1])
.map_err(|e| CredentialsError::from_source(false, e))?;
let claims: HashMap<String, Value> =
serde_json::from_slice(&payload).map_err(|e| CredentialsError::from_source(false, e))?;
let expires_at = claims["exp"]
.as_u64()
.and_then(|exp| instant_from_epoch_seconds(exp, now));
Ok(Token {
token,
token_type: "Bearer".to_string(),
expires_at,
metadata: None,
})
}
/// Converts an absolute expiration time (`secs` since the Unix epoch) into an
/// `Instant`, using `now` as the wall-clock reference point.
///
/// * If `secs` is in the future relative to `now`, the result is
///   `Instant::now()` plus the remaining lifetime.
/// * If `secs` is in the past, the result is `Instant::now()` minus the
///   overdue time, or `Instant::now()` itself when that earlier instant cannot
///   be represented. Either way the result never lies in the future, so an
///   expired token is not reported as still valid.
///
/// Returns `None` when `now` is before the Unix epoch or when the expiration
/// is too far in the future to be represented as an `Instant`.
fn instant_from_epoch_seconds(secs: u64, now: SystemTime) -> Option<Instant> {
    let since_epoch = now.duration_since(UNIX_EPOCH).ok()?;
    let expiry = Duration::from_secs(secs);
    let current = Instant::now();
    match expiry.checked_sub(since_epoch) {
        // Checked addition: a huge `exp` must not panic on overflow.
        Some(remaining) => current.checked_add(remaining),
        None => {
            // `checked_sub` failed, so `expiry < since_epoch` and this cannot underflow.
            let overdue = since_epoch - expiry;
            Some(current.checked_sub(overdue).unwrap_or(current))
        }
    }
}
#[cfg(test)]
pub(crate) mod tests {
use super::*;
use crate::credentials::service_account::jws::JwsHeader;
use mds::Format;
use p256::ecdsa::signature::Signer;
use p256::ecdsa::{Signature, SigningKey};
use rsa::Pkcs1v15Sign;
use rsa::sha2::{Digest, Sha256};
use serde_json::json;
use serial_test::parallel;
use std::collections::HashMap;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
/// Result type returned by the tests in this module.
type TestResult = anyhow::Result<()>;
/// Lifetime given to generated test tokens: one hour.
const DEFAULT_TEST_TOKEN_EXPIRATION: Duration = Duration::from_secs(60 * 60);
/// Key id placed in the `kid` header of generated test tokens.
pub(crate) const TEST_KEY_ID: &str = "test-key-id";
/// Generates a signed test ID token for `audience` carrying only the default claims.
pub(crate) fn generate_test_id_token<S: Into<String>>(audience: S) -> String {
    let no_extra_claims = HashMap::new();
    generate_test_id_token_with_claims(audience, no_extra_claims)
}
/// Generates a signed test ID token for `audience`, issued at the current
/// system time, with `claims_to_add` merged over the default claims.
pub(crate) fn generate_test_id_token_with_claims<S: Into<String>>(
    audience: S,
    claims_to_add: HashMap<&str, Value>,
) -> String {
    let issued_at = SystemTime::now();
    generate_test_id_token_impl(audience.into(), claims_to_add, issued_at)
}
/// Builds an RS256-signed JWT for tests.
///
/// The default claims are `aud`, `iss`, `iat` (taken from `now`) and `exp`
/// (`now` plus `DEFAULT_TEST_TOKEN_EXPIRATION`). Entries in `claims_to_add`
/// are applied afterwards, so they override defaults with the same key.
///
/// Panics if `now` is before the Unix epoch or if encoding or signing fails.
fn generate_test_id_token_impl(
    audience: String,
    claims_to_add: HashMap<&str, Value>,
    now: SystemTime,
) -> String {
    let issued_at = now.duration_since(UNIX_EPOCH).unwrap();
    let expires_at = issued_at + DEFAULT_TEST_TOKEN_EXPIRATION;

    let mut claims: HashMap<&str, Value> = HashMap::from([
        ("aud", Value::String(audience)),
        ("iss", "accounts.google.com".into()),
        ("exp", expires_at.as_secs().into()),
        ("iat", issued_at.as_secs().into()),
    ]);
    claims.extend(claims_to_add);

    let header = JwsHeader {
        alg: "RS256",
        typ: "JWT",
        kid: Some(TEST_KEY_ID.to_string()),
    };
    let encoded_header = header.encode().unwrap();
    let encoded_claims = URL_SAFE_NO_PAD.encode(serde_json::to_string(&claims).unwrap());
    let signing_input = format!("{encoded_header}.{encoded_claims}");

    // PKCS#1 v1.5 signs the SHA-256 digest of `header.claims`.
    let private_key = crate::credentials::tests::RSA_PRIVATE_KEY.clone();
    let digest = Sha256::digest(signing_input.as_bytes());
    let signature = private_key
        .sign(Pkcs1v15Sign::new::<Sha256>(), &digest)
        .expect("Failed to sign");
    let encoded_signature = URL_SAFE_NO_PAD.encode(signature);

    format!("{signing_input}.{encoded_signature}")
}
/// Builds an ES256-signed JWT for tests, issued at the current system time.
///
/// The claims are `aud`, `iss`, `iat` and `exp` (`iat` plus
/// `DEFAULT_TEST_TOKEN_EXPIRATION`).
///
/// Panics if the system clock is before the Unix epoch or if encoding fails.
pub(crate) fn generate_test_id_token_es256(audience: &str) -> String {
    let issued_at = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
    let expires_at = issued_at + DEFAULT_TEST_TOKEN_EXPIRATION;

    let claims: HashMap<&str, Value> = HashMap::from([
        ("aud", Value::String(audience.to_string())),
        ("iss", "accounts.google.com".into()),
        ("exp", expires_at.as_secs().into()),
        ("iat", issued_at.as_secs().into()),
    ]);

    let header = JwsHeader {
        alg: "ES256",
        typ: "JWT",
        kid: Some(TEST_KEY_ID.to_string()),
    };
    let encoded_header = header.encode().unwrap();
    let encoded_claims = URL_SAFE_NO_PAD.encode(serde_json::to_string(&claims).unwrap());
    let signing_input = format!("{encoded_header}.{encoded_claims}");

    let secret = crate::credentials::tests::ES256_PRIVATE_KEY.clone();
    let signer = SigningKey::from(secret);
    let signature: Signature = signer.sign(signing_input.as_bytes());
    let encoded_signature = URL_SAFE_NO_PAD.encode(signature.to_bytes());

    format!("{signing_input}.{encoded_signature}")
}
/// Parsing a freshly generated token keeps the raw token and reports an
/// expiration that matches the generated lifetime.
#[tokio::test(start_paused = true)]
#[parallel]
async fn test_parse_id_token() -> TestResult {
    let now = SystemTime::now();
    let id_token =
        generate_test_id_token_impl("https://example.com".to_string(), HashMap::new(), now);

    let token = parse_id_token_from_str_impl(id_token.clone(), now)?;
    assert_eq!(token.token, id_token);
    assert!(token.expires_at.is_some(), "{token:?}");

    // The tokio clock is paused, so `Instant::now()` here equals the one used
    // while parsing.
    let expiration = token.expires_at.unwrap().duration_since(Instant::now());

    // The `exp` claim has whole-second resolution; add back the sub-second
    // part of `now` that was truncated when the token was generated.
    let since_epoch = now.duration_since(UNIX_EPOCH).unwrap();
    let rounding = Duration::new(0, since_epoch.subsec_nanos());
    assert_eq!(expiration + rounding, DEFAULT_TEST_TOKEN_EXPIRATION);
    Ok(())
}
/// `authorized_user` ADC JSON is rejected with a not-supported error that
/// points at the dedicated user-account builder.
#[tokio::test]
#[parallel]
async fn test_build_id_token_credentials_authorized_user_not_supported() -> TestResult {
    let adc = serde_json::json!({
        "type": "authorized_user",
        "client_id": "test_client_id",
        "client_secret": "test_client_secret",
        "refresh_token": "test_refresh_token",
    });

    let result = build_id_token_credentials("test_audience".to_string(), false, Some(adc));
    assert!(result.is_err(), "{result:?}");

    let err = result.unwrap_err();
    assert!(err.is_not_supported());
    let message = err.to_string();
    assert!(message.contains("authorized_user, use idtoken::user_account::Builder directly."));
    Ok(())
}
/// `external_account` ADC JSON is rejected with a not-supported error naming the type.
#[tokio::test]
#[parallel]
async fn test_build_id_token_credentials_external_account_not_supported() -> TestResult {
    let adc = serde_json::json!({
        "type": "external_account",
        "audience": "//iam.googleapis.com/projects/123/locations/global/workloadIdentityPools/my-pool/providers/my-provider",
        "subject_token_type": "urn:ietf:params:oauth:token-type:jwt",
        "token_url": "https://sts.googleapis.com/v1/token",
        "credential_source": {
            "file": "/path/to/file",
            "format": {
                "type": "text"
            }
        }
    });

    let result = build_id_token_credentials("test_audience".to_string(), false, Some(adc));
    assert!(result.is_err(), "{result:?}");

    let err = result.unwrap_err();
    assert!(err.is_not_supported());
    let message = err.to_string();
    assert!(message.contains("external_account"));
    Ok(())
}
/// An unrecognized credential type yields an unknown-type error naming the type.
#[tokio::test]
#[parallel]
async fn test_build_id_token_credentials_unknown_type() -> TestResult {
    let adc = serde_json::json!({
        "type": "unknown_credential_type",
    });

    let result = build_id_token_credentials("test_audience".to_string(), false, Some(adc));
    assert!(result.is_err(), "{result:?}");

    let err = result.unwrap_err();
    assert!(err.is_unknown_type());
    let message = err.to_string();
    assert!(message.contains("unknown_credential_type"));
    Ok(())
}
/// Without ADC JSON the `mds` builder is selected, and `include_email`
/// switches its format between `Full` and `Standard`.
#[tokio::test]
#[parallel]
async fn test_build_id_token_include_email_mds() -> TestResult {
    let audience = "test_audience".to_string();

    let with_email = build_id_token_credentials_internal(audience.clone(), true, None)?;
    assert!(matches!(with_email, IDTokenBuilder::Mds(_)));
    if let IDTokenBuilder::Mds(mds_builder) = with_email {
        assert!(matches!(mds_builder.format, Some(Format::Full)));
    }

    let without_email = build_id_token_credentials_internal(audience, false, None)?;
    assert!(matches!(without_email, IDTokenBuilder::Mds(_)));
    if let IDTokenBuilder::Mds(mds_builder) = without_email {
        assert!(matches!(mds_builder.format, Some(Format::Standard)));
    }

    Ok(())
}
/// For impersonated service account JSON the impersonated builder is selected,
/// and `include_email` is set on it only when requested.
#[tokio::test]
#[parallel]
async fn test_build_id_token_include_email_impersonated() -> TestResult {
    let audience = "test_audience".to_string();
    let adc = json!({
        "type": "impersonated_service_account",
        "source_credentials": {
            "type": "service_account",
            "project_id": "test-project",
            "private_key_id": "test-key-id",
            "private_key": "-----BEGIN PRIVATE KEY-----\n-----END PRIVATE KEY-----",
            "client_email": "source@test-project.iam.gserviceaccount.com",
            "client_id": "test-client-id",
            "auth_uri": "https://accounts.google.com/o/oauth2/auth",
            "token_uri": "https://oauth2.googleapis.com/token",
            "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
            "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/source%40test-project.iam.gserviceaccount.com"
        },
        "service_account_impersonation_url": "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/target@test-project.iam.gserviceaccount.com:generateIdToken"
    });

    let with_email =
        build_id_token_credentials_internal(audience.clone(), true, Some(adc.clone()))?;
    assert!(matches!(with_email, IDTokenBuilder::Impersonated(_)));
    if let IDTokenBuilder::Impersonated(impersonated_builder) = with_email {
        assert_eq!(impersonated_builder.include_email, Some(true));
    }

    let without_email = build_id_token_credentials_internal(audience, false, Some(adc))?;
    assert!(matches!(without_email, IDTokenBuilder::Impersonated(_)));
    if let IDTokenBuilder::Impersonated(impersonated_builder) = without_email {
        assert_eq!(impersonated_builder.include_email, None);
    }

    Ok(())
}
}