use crate::credentials::Credentials;
use crate::mds::client::Client as MDSClient;
use crate::signer::{Result, SigningError, dynamic::SigningProvider};
use std::sync::OnceLock;
/// Signer that looks up its client email through an MDS client and delegates
/// the actual signing to an IAM signer built from that email.
#[derive(Clone, Debug)]
pub(crate) struct MDSSigner {
    /// Client used to fetch the client email (`client.email().send()`).
    client: MDSClient,
    /// Optional endpoint handed to the IAM signer in place of its default;
    /// `None` unless set via `with_iam_endpoint_override`.
    iam_endpoint_override: Option<String>,
    /// Write-once cache for the client email; empty until the first successful fetch.
    client_email: OnceLock<String>,
    /// Credentials cloned into the IAM signer on every `sign` call.
    inner: Credentials,
}
impl MDSSigner {
    /// Creates a signer from an MDS `client` and the `inner` credentials.
    ///
    /// The client email cache starts empty and no IAM endpoint override is set.
    pub(crate) fn new(client: MDSClient, inner: Credentials) -> Self {
        let client_email = OnceLock::new();
        Self {
            client,
            iam_endpoint_override: None,
            client_email,
            inner,
        }
    }

    /// Consumes the signer and returns it with `endpoint` stored as the IAM
    /// endpoint override; all other fields are carried over unchanged.
    pub(crate) fn with_iam_endpoint_override(self, endpoint: &str) -> Self {
        Self {
            iam_endpoint_override: Some(String::from(endpoint)),
            ..self
        }
    }
}
#[async_trait::async_trait]
impl SigningProvider for MDSSigner {
    /// Returns the client email, fetching it through the MDS client on first use.
    ///
    /// A successful fetch is stored in the `client_email` cache and later calls
    /// return the cached value without another request.
    ///
    /// # Errors
    ///
    /// Propagates the error from `fetch_client_email`. Nothing is cached on
    /// failure, so a later call fetches again.
    async fn client_email(&self) -> Result<String> {
        if let Some(cached) = self.client_email.get() {
            return Ok(cached.clone());
        }
        let fetched = self.fetch_client_email().await?;
        // Another caller may have filled the cache while this fetch was in
        // flight. `get_or_init` keeps whichever value was stored first, so every
        // caller returns the same email and no `unwrap` is needed.
        Ok(self.client_email.get_or_init(|| fetched).clone())
    }

    /// Signs `content` by delegating to an IAM signer.
    ///
    /// The IAM signer is built on each call from the resolved client email, a
    /// clone of the inner credentials, and the optional IAM endpoint override.
    ///
    /// # Errors
    ///
    /// Returns an error if the client email cannot be resolved or if the IAM
    /// signer's `sign` call fails.
    async fn sign(&self, content: &[u8]) -> Result<bytes::Bytes> {
        let client_email = self.client_email().await?;
        let signer = crate::signer::iam::IamSigner::new(
            client_email,
            self.inner.clone(),
            self.iam_endpoint_override.clone(),
        );
        signer.sign(content).await
    }
}
impl MDSSigner {
    /// Requests the client email through the MDS client.
    ///
    /// Any failure from the request is converted with `SigningError::transport`.
    async fn fetch_client_email(&self) -> Result<String> {
        let response = self.client.email().send().await;
        response.map_err(SigningError::transport)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::credentials::tests::MockCredentials;
    use crate::credentials::{CacheableResource, Credentials, EntityTag};
    use crate::mds::MDS_DEFAULT_URI;
    use base64::{Engine, prelude::BASE64_STANDARD};
    use http::HeaderMap;
    use http::header::{HeaderName, HeaderValue};
    use httptest::matchers::{all_of, contains, request};
    use httptest::responders::{json_encoded, status_code};
    use httptest::{Expectation, Server};
    use serde_json::json;
    use serial_test::serial;

    type TestResult = anyhow::Result<()>;

    /// Registers on `server` an expectation for exactly one request to
    /// `{MDS_DEFAULT_URI}/email`, answered with status 200 and the body
    /// `test-client-email`.
    fn expect_single_email_lookup(server: &Server) {
        let email_path = format!("{MDS_DEFAULT_URI}/email");
        server.expect(
            Expectation::matching(all_of![request::path(email_path),])
                .times(1)
                .respond_with(status_code(200).body("test-client-email")),
        );
    }

    /// Calls `client_email` twice and checks both results. The server
    /// expectation is registered with `times(1)`, so the second call is
    /// expected to come from the signer's cache.
    #[ignore = "TODO(#5249) - disabled because it was flaky"]
    #[tokio::test]
    async fn test_fetch_client_email_and_cache() -> TestResult {
        let server = Server::run();
        expect_single_email_lookup(&server);

        let creds = Credentials::from(MockCredentials::new());
        let mds_client = MDSClient::new(Some(format!("http://{}", server.addr())));
        let signer = MDSSigner::new(mds_client, creds);

        let first = signer.client_email().await?;
        assert_eq!(first, "test-client-email");

        let second = signer.client_email().await?;
        assert_eq!(second, "test-client-email");

        Ok(())
    }

    /// Resolves the email, then signs a payload. The fake server answers the
    /// `signBlob` POST (which must carry the mock credentials' authorization
    /// header) with a base64-encoded blob, and the test expects `sign` to
    /// return the decoded bytes.
    #[tokio::test]
    #[serial]
    async fn test_sign() -> TestResult {
        let server = Server::run();
        expect_single_email_lookup(&server);
        server.expect(
            Expectation::matching(all_of![
                request::method_path(
                    "POST",
                    "/v1/projects/-/serviceAccounts/test-client-email:signBlob"
                ),
                request::headers(contains(("authorization", "Bearer test-value"))),
            ])
            .respond_with(json_encoded(json!({
                "signedBlob": BASE64_STANDARD.encode("signed_blob"),
            }))),
        );

        // The mock credentials hand out a single authorization header.
        let mut mock = MockCredentials::new();
        mock.expect_headers().return_once(|_extensions| {
            let mut auth_headers = HeaderMap::new();
            auth_headers.insert(
                HeaderName::from_static("authorization"),
                HeaderValue::from_static("Bearer test-value"),
            );
            Ok(CacheableResource::New {
                entity_tag: EntityTag::default(),
                data: auth_headers,
            })
        });
        let creds = Credentials::from(mock);

        // The same fake server plays both the MDS and the IAM endpoint.
        let endpoint = server.url("").to_string().trim_end_matches('/').to_string();
        let mds_client = MDSClient::new(Some(endpoint.clone()));
        let signer = MDSSigner::new(mds_client, creds).with_iam_endpoint_override(&endpoint);

        let resolved_email = signer.client_email().await?;
        assert_eq!(resolved_email, "test-client-email");

        let signed = signer.sign(b"test").await?;
        assert_eq!(signed.as_ref(), b"signed_blob");

        Ok(())
    }
}