use crate::credentials::{CacheableResource, Credentials};
use crate::signer::{Result, SigningError, dynamic::SigningProvider};
use google_cloud_gax::backoff_policy::{BackoffPolicy, BackoffPolicyArg};
use google_cloud_gax::exponential_backoff::ExponentialBackoff;
use google_cloud_gax::retry_loop_internal::retry_loop;
use google_cloud_gax::retry_policy::{Aip194Strict, RetryPolicyArg, RetryPolicyExt};
use google_cloud_gax::retry_throttler::{
AdaptiveThrottler, RetryThrottlerArg, SharedRetryThrottler,
};
use http::{Extensions, HeaderMap};
use reqwest::Client;
use std::sync::Arc;
#[derive(Debug)]
pub(crate) struct IamSigner {
client_email: String,
inner: Credentials,
endpoint: String,
client: Client,
}
#[derive(Debug, Clone, serde::Serialize)]
struct SignBlobRequest {
payload: String,
}
#[derive(Debug, serde::Deserialize)]
struct SignBlobResponse {
#[serde(rename = "signedBlob")]
signed_blob: String,
}
impl IamSigner {
pub(crate) fn new(client_email: String, inner: Credentials, endpoint: Option<String>) -> Self {
Self {
client_email,
inner,
endpoint: endpoint.unwrap_or("https://iamcredentials.googleapis.com".to_string()),
client: Client::new(),
}
}
}
#[async_trait::async_trait]
impl SigningProvider for IamSigner {
async fn client_email(&self) -> Result<String> {
Ok(self.client_email.clone())
}
async fn sign(&self, content: &[u8]) -> Result<bytes::Bytes> {
use base64::{Engine, prelude::BASE64_STANDARD};
let payload = BASE64_STANDARD.encode(content);
let body = SignBlobRequest { payload };
let client_email = self.client_email.clone();
let url = format!(
"{}/v1/projects/-/serviceAccounts/{client_email}:signBlob",
self.endpoint
);
let response =
sign_blob_call_with_retry(self.inner.clone(), self.client.clone(), url, body).await?;
if !response.status().is_success() {
let err_text = response.text().await.map_err(SigningError::transport)?;
return Err(SigningError::transport(format!("err status: {err_text:?}")));
}
let res = response
.json::<SignBlobResponse>()
.await
.map_err(SigningError::transport)?;
let signature = BASE64_STANDARD
.decode(res.signed_blob)
.map_err(SigningError::transport)?;
Ok(bytes::Bytes::from(signature))
}
}
async fn sign_blob_call_with_retry(
credentials: Credentials,
client: Client,
url: String,
body: SignBlobRequest,
) -> Result<reqwest::Response> {
let sleep = async |d| tokio::time::sleep(d).await;
let retry_policy: RetryPolicyArg = Aip194Strict.with_attempt_limit(3).into();
let backoff_policy: BackoffPolicyArg = ExponentialBackoff::default().into();
let backoff_policy: Arc<dyn BackoffPolicy> = backoff_policy.into();
let retry_throttler: RetryThrottlerArg = AdaptiveThrottler::default().into();
let retry_throttler: SharedRetryThrottler = retry_throttler.into();
retry_loop(
async move |_| {
let source_headers = credentials
.headers(Extensions::new())
.await
.map_err(google_cloud_gax::error::Error::authentication)?;
sign_blob_call(&client, &url, source_headers, body.clone()).await
},
sleep,
true, retry_throttler,
retry_policy.into(),
backoff_policy,
)
.await
.map_err(SigningError::transport)
}
async fn sign_blob_call(
client: &Client,
url: &str,
source_headers: CacheableResource<HeaderMap>,
body: SignBlobRequest,
) -> google_cloud_gax::Result<reqwest::Response> {
let source_headers = match source_headers {
CacheableResource::New { data, .. } => data,
CacheableResource::NotModified => {
unreachable!("requested source credentials without a caching etag")
}
};
let response = client
.post(url)
.header("Content-Type", "application/json")
.headers(source_headers.clone())
.json(&body)
.send()
.await
.map_err(google_cloud_gax::error::Error::io)?;
let status = response.status();
if !status.is_success() {
let err_headers = response.headers().clone();
let err_payload = response
.bytes()
.await
.map_err(|e| google_cloud_gax::error::Error::transport(err_headers.clone(), e))?;
return Err(google_cloud_gax::error::Error::http(
status.as_u16(),
err_headers,
err_payload,
));
}
Ok(response)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::credentials::{Credentials, CredentialsProvider, EntityTag};
use crate::errors::CredentialsError;
use base64::{Engine, prelude::BASE64_STANDARD};
use http::header::{HeaderName, HeaderValue};
use http::{Extensions, HeaderMap};
use httptest::cycle;
use httptest::matchers::{all_of, contains, eq, json_decoded, request};
use httptest::responders::{json_encoded, status_code};
use httptest::{Expectation, Server};
use serde_json::json;
type TestResult = anyhow::Result<()>;
mockall::mock! {
#[derive(Debug)]
Credentials {}
impl CredentialsProvider for Credentials {
async fn headers(&self, extensions: Extensions) -> std::result::Result<CacheableResource<HeaderMap>, CredentialsError>;
async fn universe_domain(&self) -> Option<String>;
}
}
#[tokio::test]
async fn test_iam_sign() -> TestResult {
let server = Server::run();
let payload = BASE64_STANDARD.encode("test");
let signed_blob = BASE64_STANDARD.encode("signed_blob");
server.expect(
Expectation::matching(all_of![
request::method_path(
"POST",
"/v1/projects/-/serviceAccounts/test@example.com:signBlob"
),
request::headers(contains(("authorization", "Bearer test-value"))),
request::body(json_decoded(eq(json!({
"payload": payload,
}))))
])
.respond_with(json_encoded(json!({
"signedBlob": signed_blob,
}))),
);
let endpoint = server.url("").to_string().trim_end_matches('/').to_string();
let mut mock = MockCredentials::new();
mock.expect_headers().return_once(|_extensions| {
let headers = HeaderMap::from_iter([(
HeaderName::from_static("authorization"),
HeaderValue::from_static("Bearer test-value"),
)]);
Ok(CacheableResource::New {
entity_tag: EntityTag::default(),
data: headers,
})
});
let creds = Credentials::from(mock);
let signer = IamSigner::new("test@example.com".to_string(), creds, Some(endpoint));
let signature = signer.sign(b"test").await.unwrap();
assert_eq!(signature.as_ref(), b"signed_blob");
Ok(())
}
#[tokio::test]
async fn test_iam_client_email() -> TestResult {
let mock = MockCredentials::new();
let creds = Credentials::from(mock);
let signer = IamSigner::new("test@example.com".to_string(), creds, None);
let client_email = signer.client_email().await.unwrap();
assert_eq!(client_email, "test@example.com");
Ok(())
}
#[tokio::test]
async fn test_iam_sign_api_error() -> TestResult {
let server = Server::run();
server.expect(
Expectation::matching(all_of![request::method_path(
"POST",
"/v1/projects/-/serviceAccounts/test@example.com:signBlob"
),])
.respond_with(status_code(500)),
);
let endpoint = server.url("").to_string().trim_end_matches('/').to_string();
let mut mock = MockCredentials::new();
mock.expect_headers().return_once(|_extensions| {
Ok(CacheableResource::New {
entity_tag: EntityTag::default(),
data: HeaderMap::new(),
})
});
let creds = Credentials::from(mock);
let signer = IamSigner::new("test@example.com".to_string(), creds, Some(endpoint));
let err = signer.sign(b"test").await.unwrap_err();
assert!(err.is_transport());
Ok(())
}
#[tokio::test]
async fn test_iam_sign_retry() -> TestResult {
let server = Server::run();
let signed_blob = BASE64_STANDARD.encode("signed_blob");
let invalid_res = http::Response::builder()
.version(http::Version::HTTP_3) .status(204)
.body(Vec::new())
.unwrap();
server.expect(
Expectation::matching(all_of![request::method_path(
"POST",
"/v1/projects/-/serviceAccounts/test@example.com:signBlob"
),])
.times(3)
.respond_with(cycle![
invalid_res, status_code(503).body("try-again"),
json_encoded(json!({
"signedBlob": signed_blob,
}))
]),
);
let endpoint = server.url("").to_string().trim_end_matches('/').to_string();
let mut mock = MockCredentials::new();
mock.expect_headers().returning(|_extensions| {
Ok(CacheableResource::New {
entity_tag: EntityTag::default(),
data: HeaderMap::new(),
})
});
let creds = Credentials::from(mock);
let signer = IamSigner::new("test@example.com".to_string(), creds, Some(endpoint));
let signature = signer.sign(b"test").await.unwrap();
assert_eq!(signature.as_ref(), b"signed_blob");
Ok(())
}
}