use std::sync::Arc;
use chrono::{DateTime, Utc};
use reqwest::header::{AUTHORIZATION, CONTENT_TYPE};
use serde::{Deserialize, Serialize};
use url::Url;
use crate::authn::factor::ZeroizedString;
/// Value of the `grant_type` form field sent to STS: the OAuth token-exchange grant.
const STS_GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:token-exchange";
/// Value of the `subject_token_type` form field: the subject token is a JWT.
const STS_SUBJECT_TOKEN_TYPE_JWT: &str = "urn:ietf:params:oauth:token-type:jwt";
/// Value of the `requested_token_type` form field: ask STS for an access token.
const STS_REQUESTED_TOKEN_TYPE_ACCESS: &str = "urn:ietf:params:oauth:token-type:access_token";
fn default_sts_endpoint() -> Url {
Url::parse("https://sts.googleapis.com/v1/token").expect("default GCP STS endpoint is valid")
}
/// Public GCP IAM Credentials base URL used by `GcpServiceAccountImpersonator::new`.
///
/// The trailing `/` is kept so that relative paths are appended to it by `Url::join`.
///
/// # Panics
/// Only if the hard-coded literal below were not a valid absolute URL.
fn default_iam_credentials_base() -> Url {
    const BASE: &str = "https://iamcredentials.googleapis.com/";
    let parsed = Url::parse(BASE);
    parsed.expect("default GCP IAM Credentials endpoint is valid")
}
/// Audience that identifies a GCP Workload Identity Pool provider; it is sent as
/// the `audience` field of the STS token exchange.
#[derive(Debug, Clone)]
pub struct WorkloadIdentityPoolProvider {
    /// Fully-qualified `//iam.googleapis.com/...` audience string.
    audience: String,
}

impl WorkloadIdentityPoolProvider {
    /// Builds the canonical provider audience
    /// `//iam.googleapis.com/projects/<number>/locations/global/workloadIdentityPools/<pool>/providers/<provider>`.
    ///
    /// The pool and provider IDs are inserted verbatim: no validation or escaping.
    pub fn new(
        project_number: u64,
        pool_id: impl AsRef<str>,
        provider_id: impl AsRef<str>,
    ) -> Self {
        let pool = pool_id.as_ref();
        let provider = provider_id.as_ref();
        let audience = format!(
            "//iam.googleapis.com/projects/{}/locations/global/workloadIdentityPools/{}/providers/{}",
            project_number, pool, provider,
        );
        Self { audience }
    }

    /// Wraps an already-formatted audience string as-is, without validation.
    pub fn from_audience(audience: impl Into<String>) -> Self {
        let audience: String = audience.into();
        Self { audience }
    }

    /// Borrows the audience string.
    pub fn as_str(&self) -> &str {
        self.audience.as_str()
    }
}
/// Federated access token obtained from the GCP STS token exchange.
///
/// NOTE(review): `Debug` is derived, so whether `access_token` is redacted in
/// `{:?}` output depends on `ZeroizedString`'s own `Debug` impl — confirm.
#[derive(Debug)]
pub struct GcpFederatedToken {
    /// STS access token; `GcpStsClient::exchange_token` rejects an empty one.
    pub access_token: ZeroizedString,
    /// `expires_in` from the STS response, if present (presumably seconds — confirm).
    pub expires_in: Option<u64>,
    /// `token_type` from the STS response; `exchange_token` fills in `"Bearer"`
    /// when the field is absent.
    pub token_type: String,
}
#[derive(Clone)]
pub struct GcpStsClient {
endpoint: Arc<Url>,
http: reqwest::Client,
}
impl Default for GcpStsClient {
fn default() -> Self {
Self::new()
}
}
impl GcpStsClient {
pub fn new() -> Self {
Self {
endpoint: Arc::new(default_sts_endpoint()),
http: reqwest::Client::new(),
}
}
pub fn with_endpoint(mut self, endpoint: Url) -> Self {
self.endpoint = Arc::new(endpoint);
self
}
pub fn with_http_client(mut self, http: reqwest::Client) -> Self {
self.http = http;
self
}
pub fn endpoint(&self) -> &Url {
&self.endpoint
}
pub async fn exchange_token(
&self,
web_identity_token: &str,
audience: &WorkloadIdentityPoolProvider,
scopes: &[&str],
) -> Result<GcpFederatedToken, GcpError> {
let scope = scopes.join(" ");
let form = [
("grant_type", STS_GRANT_TYPE),
("audience", audience.as_str()),
("scope", scope.as_str()),
("requested_token_type", STS_REQUESTED_TOKEN_TYPE_ACCESS),
("subject_token", web_identity_token),
("subject_token_type", STS_SUBJECT_TOKEN_TYPE_JWT),
];
let response = self
.http
.post((*self.endpoint).clone())
.header(CONTENT_TYPE, "application/x-www-form-urlencoded")
.form(&form)
.send()
.await
.map_err(|e| GcpError::Transport(e.to_string()))?;
let status = response.status();
if !status.is_success() {
let body = response.text().await.unwrap_or_default();
return Err(GcpError::Sts {
http_status: status.as_u16(),
body,
});
}
let parsed: GcpStsResponseBody = response
.json()
.await
.map_err(|e| GcpError::MalformedResponse(format!("STS JSON: {e}")))?;
if parsed.access_token.is_empty() {
return Err(GcpError::MalformedResponse(
"STS access_token field is empty".to_string(),
));
}
Ok(GcpFederatedToken {
access_token: ZeroizedString::from(parsed.access_token),
expires_in: parsed.expires_in,
token_type: parsed.token_type.unwrap_or_else(|| "Bearer".to_string()),
})
}
}
/// JSON body returned by the STS token endpoint.
///
/// `Debug` is implemented by hand (not derived) so that the bearer
/// `access_token` never shows up in `{:?}` / `dbg!` / assertion output.
#[derive(Deserialize)]
struct GcpStsResponseBody {
    /// Federated access token (secret).
    access_token: String,
    /// `expires_in` field, `None` when absent.
    #[serde(default)]
    expires_in: Option<u64>,
    /// `token_type` field, `None` when absent.
    #[serde(default)]
    token_type: Option<String>,
}

impl std::fmt::Debug for GcpStsResponseBody {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GcpStsResponseBody")
            .field("access_token", &"<redacted>")
            .field("expires_in", &self.expires_in)
            .field("token_type", &self.token_type)
            .finish()
    }
}
/// Service-account access token minted by IAM Credentials `generateAccessToken`.
///
/// NOTE(review): `Debug` is derived; redaction of `access_token` depends on
/// `ZeroizedString`'s `Debug` impl — confirm.
#[derive(Debug)]
pub struct GcpServiceAccountToken {
    /// The access token; `generate_access_token` rejects an empty one.
    pub access_token: ZeroizedString,
    /// Expiry parsed from the RFC 3339 `expireTime` field, converted to UTC.
    pub expire_time: DateTime<Utc>,
}
#[derive(Clone)]
pub struct GcpServiceAccountImpersonator {
endpoint_base: Arc<Url>,
http: reqwest::Client,
}
impl Default for GcpServiceAccountImpersonator {
fn default() -> Self {
Self::new()
}
}
impl GcpServiceAccountImpersonator {
pub fn new() -> Self {
Self {
endpoint_base: Arc::new(default_iam_credentials_base()),
http: reqwest::Client::new(),
}
}
pub fn with_endpoint_base(mut self, base: Url) -> Self {
self.endpoint_base = Arc::new(base);
self
}
pub fn with_http_client(mut self, http: reqwest::Client) -> Self {
self.http = http;
self
}
pub fn endpoint_base(&self) -> &Url {
&self.endpoint_base
}
pub async fn generate_access_token(
&self,
federated_token: &str,
service_account_email: &str,
scopes: &[&str],
lifetime_seconds: Option<u32>,
) -> Result<GcpServiceAccountToken, GcpError> {
let path = format!(
"v1/projects/-/serviceAccounts/{}:generateAccessToken",
service_account_email
);
let url = self
.endpoint_base
.join(&path)
.map_err(|e| GcpError::MalformedResponse(format!("URL join: {e}")))?;
let body = GenerateAccessTokenRequest {
scope: scopes.iter().map(|s| s.to_string()).collect(),
lifetime: lifetime_seconds.map(|secs| format!("{secs}s")),
};
let response = self
.http
.post(url)
.header(AUTHORIZATION, format!("Bearer {federated_token}"))
.json(&body)
.send()
.await
.map_err(|e| GcpError::Transport(e.to_string()))?;
let status = response.status();
if !status.is_success() {
let body = response.text().await.unwrap_or_default();
return Err(GcpError::IamCredentials {
http_status: status.as_u16(),
body,
});
}
let parsed: GenerateAccessTokenResponse = response
.json()
.await
.map_err(|e| GcpError::MalformedResponse(format!("IAM Credentials JSON: {e}")))?;
if parsed.access_token.is_empty() {
return Err(GcpError::MalformedResponse(
"IAM Credentials accessToken field is empty".to_string(),
));
}
let expire_time = DateTime::parse_from_rfc3339(&parsed.expire_time)
.map_err(|e| GcpError::MalformedResponse(format!("expireTime not RFC3339: {e}")))?
.with_timezone(&Utc);
Ok(GcpServiceAccountToken {
access_token: ZeroizedString::from(parsed.access_token),
expire_time,
})
}
}
/// JSON request body for IAM Credentials `generateAccessToken`.
#[derive(Debug, Serialize)]
struct GenerateAccessTokenRequest {
    /// OAuth scopes requested for the service-account token.
    scope: Vec<String>,
    /// Requested lifetime as a `"<seconds>s"` string; left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    lifetime: Option<String>,
}
/// JSON response body from IAM Credentials `generateAccessToken`. The
/// `accessToken` / `expireTime` keys map to these fields through camelCase renaming.
///
/// `Debug` is implemented by hand (not derived) so that the access token never
/// shows up in `{:?}` / `dbg!` / assertion output.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct GenerateAccessTokenResponse {
    /// Service-account access token (secret).
    access_token: String,
    /// Expiry timestamp as sent by the API; the caller parses it as RFC 3339.
    expire_time: String,
}

impl std::fmt::Debug for GenerateAccessTokenResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GenerateAccessTokenResponse")
            .field("access_token", &"<redacted>")
            .field("expire_time", &self.expire_time)
            .finish()
    }
}
/// Errors returned by the GCP STS and IAM Credentials clients in this module.
#[derive(Debug, thiserror::Error)]
pub enum GcpError {
    /// Sending the HTTP request failed; holds the stringified client error.
    #[error("GCP transport error: {0}")]
    Transport(String),
    /// STS answered with a non-success HTTP status.
    #[error("GCP STS error (HTTP {http_status}): {body}")]
    Sts {
        /// HTTP status code returned by STS.
        http_status: u16,
        /// Raw response body text (empty if it could not be read).
        body: String,
    },
    /// IAM Credentials answered with a non-success HTTP status.
    #[error("GCP IAM Credentials error (HTTP {http_status}): {body}")]
    IamCredentials {
        /// HTTP status code returned by IAM Credentials.
        http_status: u16,
        /// Raw response body text (empty if it could not be read).
        body: String,
    },
    /// A response could not be decoded or failed validation (e.g. an empty token
    /// or a non-RFC 3339 `expireTime`). Also used when the IAM Credentials request
    /// URL cannot be built from the configured base.
    #[error("malformed GCP response: {0}")]
    MalformedResponse(String),
}
#[cfg(test)]
mod tests;