use async_trait::async_trait;
use serde::Deserialize;
use std::time::{SystemTime, UNIX_EPOCH};
use crate::auth::types::{AccessToken, CachedToken};
use crate::token::{TokenError, TokenProvider};
/// Safety margin, in seconds, passed to `CachedToken::get` so that a cached
/// token close to its expiry is not handed out (a fresh one is fetched instead).
const TOKEN_EXPIRY_BUFFER_SECS: u64 = 60;
/// Base URL of the metadata server's v1 API. Overridable through
/// `MetadataServerCredential::with_custom_client` (e.g. for tests).
const METADATA_SERVER_BASE: &str = "http://metadata.google.internal/computeMetadata/v1";
/// JSON body returned by the metadata server's service-account token endpoint.
#[derive(Debug, Deserialize)]
struct TokenResponse {
/// The bearer token value handed back to callers.
access_token: String,
/// Token lifetime in seconds; `fetch_token` adds it to the current Unix time
/// to compute the absolute expiry.
expires_in: u64,
// Never read. It is still declared so that deserialization requires the field
// to be present, which is why `dead_code` is allowed here.
#[allow(dead_code)]
token_type: String,
}
/// Token provider that obtains access tokens from the metadata server's
/// service-account token endpoint and caches them between calls.
#[derive(Debug)]
pub struct MetadataServerCredential {
/// Service-account path segment used in the token URL (`"default"` unless
/// overridden by a constructor).
service_account: String,
/// Holds the most recently fetched token for reuse by `get_token`.
cache: CachedToken,
/// HTTP client used for all metadata requests.
http_client: reqwest::Client,
/// API base URL; `METADATA_SERVER_BASE` unless set via `with_custom_client`.
metadata_base_url: String,
}
impl MetadataServerCredential {
    /// Creates a credential for the `default` service account, using the
    /// standard metadata server endpoint.
    pub fn new() -> Self {
        Self::with_service_account("default")
    }

    /// Creates a credential for a specific service account (for example an
    /// email address), using the standard metadata server endpoint.
    pub fn with_service_account(service_account: impl Into<String>) -> Self {
        Self {
            service_account: service_account.into(),
            cache: CachedToken::new(),
            http_client: reqwest::Client::new(),
            metadata_base_url: METADATA_SERVER_BASE.to_string(),
        }
    }

    /// Creates a credential with an injected HTTP client and metadata base URL.
    ///
    /// Useful for pointing at a mock server in tests or for supplying a
    /// pre-configured client. A trailing `/` on `metadata_base_url` is
    /// tolerated.
    pub fn with_custom_client(
        service_account: impl Into<String>,
        http_client: reqwest::Client,
        metadata_base_url: impl Into<String>,
    ) -> Self {
        Self {
            service_account: service_account.into(),
            cache: CachedToken::new(),
            http_client,
            metadata_base_url: metadata_base_url.into(),
        }
    }

    /// Returns the service-account segment used when requesting tokens.
    pub fn service_account(&self) -> &str {
        &self.service_account
    }

    /// Builds the token endpoint URL for the configured service account.
    fn token_url(&self) -> String {
        // Strip any trailing slash from the base so a base like "http://host/"
        // does not yield "http://host//instance/...", which may not route to
        // the token endpoint.
        format!(
            "{}/instance/service-accounts/{}/token",
            self.metadata_base_url.trim_end_matches('/'),
            self.service_account
        )
    }

    /// Fetches a fresh access token from the metadata server.
    ///
    /// Sends `GET {base}/instance/service-accounts/{account}/token` with the
    /// `Metadata-Flavor: Google` header. The relative `expires_in` is converted
    /// into an absolute expiry in seconds since the Unix epoch.
    ///
    /// # Errors
    ///
    /// - [`MetadataServerError::RequestFailed`] if the request cannot be sent,
    ///   the body cannot be read, or the server answers with a non-success
    ///   status (status and body are included in the message). Also returned
    ///   if the system clock reads earlier than the Unix epoch.
    /// - [`MetadataServerError::InvalidResponse`] if the body is not a valid
    ///   token response.
    async fn fetch_token(&self) -> Result<AccessToken, MetadataServerError> {
        let url = self.token_url();
        let response = self
            .http_client
            .get(&url)
            .header("Metadata-Flavor", "Google")
            .send()
            .await
            .map_err(|e| MetadataServerError::RequestFailed {
                message: format!("HTTP request failed: {}", e),
            })?;

        let status = response.status();
        // The body is read before the status check so that error responses can
        // be surfaced verbatim in the error message.
        let response_text = response
            .text()
            .await
            .map_err(|e| MetadataServerError::RequestFailed {
                message: format!("Failed to read response body: {}", e),
            })?;
        if !status.is_success() {
            return Err(MetadataServerError::RequestFailed {
                message: format!("Metadata server returned {}: {}", status, response_text),
            });
        }

        let token_response: TokenResponse = serde_json::from_str(&response_text).map_err(|e| {
            MetadataServerError::InvalidResponse {
                message: format!("Failed to parse token response: {}", e),
            }
        })?;

        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map_err(|e| MetadataServerError::RequestFailed {
                message: format!("Failed to get current time: {}", e),
            })?
            .as_secs();

        // `expires_in` is server-supplied. Saturate instead of using `+`, because
        // an absurdly large value would otherwise panic (debug builds) or wrap
        // to an expiry in the past (release builds).
        Ok(AccessToken::new(
            token_response.access_token,
            now.saturating_add(token_response.expires_in),
        ))
    }
}
impl Default for MetadataServerCredential {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl TokenProvider for MetadataServerCredential {
    /// Returns an access token for the configured service account.
    ///
    /// A token from the cache is returned when `CachedToken::get` (given
    /// `TOKEN_EXPIRY_BUFFER_SECS`) yields one. Otherwise a new token is fetched
    /// from the metadata server, stored in the cache, and returned.
    ///
    /// The requested scopes are not forwarded to the metadata server and do
    /// not affect caching.
    ///
    /// # Errors
    ///
    /// Any failure while fetching is reported as `TokenError::RefreshFailed`,
    /// carrying the underlying error's message.
    async fn get_token(&self, _scopes: &[&str]) -> Result<String, TokenError> {
        if let Some(cached) = self.cache.get(TOKEN_EXPIRY_BUFFER_SECS).await {
            return Ok(cached);
        }

        let fresh = match self.fetch_token().await {
            Ok(access) => access,
            Err(err) => {
                return Err(TokenError::RefreshFailed {
                    message: err.to_string(),
                });
            }
        };

        // Keep a copy of the token string, because the cache takes ownership
        // of the token itself.
        let secret = fresh.token.clone();
        self.cache.set(fresh).await;
        Ok(secret)
    }

    /// Clears the cached token via `CachedToken::clear_sync`, so that a
    /// subsequent `get_token` call fetches a new one.
    fn on_token_rejected(&self) {
        self.cache.clear_sync();
    }
}
/// Errors that can occur while obtaining a token from the metadata server.
#[derive(Debug, thiserror::Error)]
pub enum MetadataServerError {
    /// The request could not be completed. This covers transport failures,
    /// unreadable bodies, non-success HTTP statuses, and a system clock that
    /// reads earlier than the Unix epoch.
    #[error("Metadata server request failed: {message}")]
    RequestFailed { message: String },

    /// The server answered successfully, but the body was not a valid token
    /// response.
    #[error("Invalid metadata server response: {message}")]
    InvalidResponse { message: String },
}
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Token endpoint path for `service_account`, relative to the server root.
    fn token_path(service_account: &str) -> String {
        format!("/instance/service-accounts/{}/token", service_account)
    }

    /// A successful token-endpoint response carrying `token` and `expires_in`.
    fn token_response(token: &str, expires_in: u64) -> ResponseTemplate {
        ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "access_token": token,
            "expires_in": expires_in,
            "token_type": "Bearer"
        }))
    }

    /// Mounts a mock for `service_account`'s token endpoint. The mock only
    /// matches requests that carry the `Metadata-Flavor: Google` header, and
    /// it must be hit exactly `expected_calls` times (checked when the server
    /// is dropped).
    async fn mount_token_mock(
        server: &MockServer,
        service_account: &str,
        token: &str,
        expires_in: u64,
        expected_calls: u64,
    ) {
        Mock::given(method("GET"))
            .and(path(token_path(service_account)))
            .and(header("Metadata-Flavor", "Google"))
            .respond_with(token_response(token, expires_in))
            .expect(expected_calls)
            .mount(server)
            .await;
    }

    /// Mounts a mock that answers any GET on `service_account`'s token path
    /// with `status` and a plain-text `body`. No header check or call count
    /// is applied.
    async fn mount_raw_mock(server: &MockServer, service_account: &str, status: u16, body: &str) {
        Mock::given(method("GET"))
            .and(path(token_path(service_account)))
            .respond_with(ResponseTemplate::new(status).set_body_string(body))
            .mount(server)
            .await;
    }

    /// Builds a credential for `service_account` that talks to `server`.
    fn credential_for(server: &MockServer, service_account: &str) -> MetadataServerCredential {
        MetadataServerCredential::with_custom_client(
            service_account,
            reqwest::Client::new(),
            server.uri(),
        )
    }

    #[test]
    fn test_new_default_service_account() {
        let cred = MetadataServerCredential::new();
        assert_eq!(cred.service_account(), "default");
    }

    #[test]
    fn test_with_service_account() {
        let cred = MetadataServerCredential::with_service_account(
            "my-sa@my-project.iam.gserviceaccount.com",
        );
        assert_eq!(
            cred.service_account(),
            "my-sa@my-project.iam.gserviceaccount.com"
        );
    }

    #[test]
    fn test_default_impl() {
        let cred = MetadataServerCredential::default();
        assert_eq!(cred.service_account(), "default");
    }

    #[test]
    fn test_token_url() {
        let cred = MetadataServerCredential::new();
        assert_eq!(
            cred.token_url(),
            "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token"
        );
        let cred = MetadataServerCredential::with_service_account("custom@example.com");
        assert_eq!(
            cred.token_url(),
            "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/custom@example.com/token"
        );
    }

    #[tokio::test]
    async fn test_get_token_success() {
        let mock_server = MockServer::start().await;
        mount_token_mock(&mock_server, "default", "test-access-token-12345", 3600, 1).await;
        let cred = credential_for(&mock_server, "default");

        let token = cred.get_token(&["scope1"]).await.unwrap();
        assert_eq!(token, "test-access-token-12345");
    }

    #[tokio::test]
    async fn test_get_token_caching() {
        let mock_server = MockServer::start().await;
        // Exactly one fetch: the second call must be served from the cache.
        mount_token_mock(&mock_server, "default", "cached-token", 3600, 1).await;
        let cred = credential_for(&mock_server, "default");

        let token1 = cred.get_token(&["scope1"]).await.unwrap();
        assert_eq!(token1, "cached-token");
        let token2 = cred.get_token(&["scope1"]).await.unwrap();
        assert_eq!(token2, "cached-token");
    }

    #[tokio::test]
    async fn test_get_token_custom_service_account() {
        let service_account = "custom@project.iam.gserviceaccount.com";
        let mock_server = MockServer::start().await;
        mount_token_mock(&mock_server, service_account, "custom-sa-token", 3600, 1).await;
        let cred = credential_for(&mock_server, service_account);

        let token = cred.get_token(&[]).await.unwrap();
        assert_eq!(token, "custom-sa-token");
    }

    #[tokio::test]
    async fn test_get_token_server_error() {
        let mock_server = MockServer::start().await;
        mount_raw_mock(&mock_server, "default", 500, "Internal Server Error").await;
        let cred = credential_for(&mock_server, "default");

        let err = cred.get_token(&[]).await.unwrap_err();
        assert!(matches!(err, TokenError::RefreshFailed { .. }));
        assert!(err.to_string().contains("500"));
    }

    #[tokio::test]
    async fn test_get_token_invalid_json() {
        let mock_server = MockServer::start().await;
        mount_raw_mock(&mock_server, "default", 200, "not valid json").await;
        let cred = credential_for(&mock_server, "default");

        let err = cred.get_token(&[]).await.unwrap_err();
        assert!(matches!(err, TokenError::RefreshFailed { .. }));
    }

    #[tokio::test]
    async fn test_get_token_not_found() {
        let mock_server = MockServer::start().await;
        mount_raw_mock(&mock_server, "nonexistent", 404, "Not Found").await;
        let cred = credential_for(&mock_server, "nonexistent");

        let err = cred.get_token(&[]).await.unwrap_err();
        assert!(matches!(err, TokenError::RefreshFailed { .. }));
        assert!(err.to_string().contains("404"));
    }

    #[tokio::test]
    async fn test_on_token_rejected_clears_cache() {
        let mock_server = MockServer::start().await;
        // Two fetches: rejecting the token must force a refetch.
        mount_token_mock(&mock_server, "default", "new-token", 3600, 2).await;
        let cred = credential_for(&mock_server, "default");

        let token1 = cred.get_token(&[]).await.unwrap();
        assert_eq!(token1, "new-token");
        cred.on_token_rejected();
        // NOTE(review): this delay allows for `clear_sync` clearing the cache
        // asynchronously. Confirm against `CachedToken` whether it is needed.
        tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
        let token2 = cred.get_token(&[]).await.unwrap();
        assert_eq!(token2, "new-token");
    }

    #[tokio::test]
    async fn test_token_expiry_buffer() {
        // A lifetime shorter than the refresh buffer means the cached token is
        // never reusable, so both calls must reach the server.
        let lifetime_secs: u64 = 30;
        assert!(lifetime_secs < TOKEN_EXPIRY_BUFFER_SECS);

        let mock_server = MockServer::start().await;
        mount_token_mock(&mock_server, "default", "short-lived-token", lifetime_secs, 2).await;
        let cred = credential_for(&mock_server, "default");

        let token1 = cred.get_token(&[]).await.unwrap();
        assert_eq!(token1, "short-lived-token");
        let token2 = cred.get_token(&[]).await.unwrap();
        assert_eq!(token2, "short-lived-token");
    }

    #[tokio::test]
    async fn test_scopes_are_ignored() {
        let mock_server = MockServer::start().await;
        mount_token_mock(&mock_server, "default", "same-token", 3600, 1).await;
        let cred = credential_for(&mock_server, "default");

        let token = cred
            .get_token(&[
                "https://www.googleapis.com/auth/cloud-platform",
                "https://www.googleapis.com/auth/compute",
            ])
            .await
            .unwrap();
        assert_eq!(token, "same-token");
    }

    #[test]
    fn test_error_display() {
        let err = MetadataServerError::RequestFailed {
            message: "connection refused".to_string(),
        };
        assert!(err.to_string().contains("connection refused"));
        assert!(err.to_string().contains("request failed"));
        let err = MetadataServerError::InvalidResponse {
            message: "invalid JSON".to_string(),
        };
        assert!(err.to_string().contains("invalid JSON"));
        assert!(err.to_string().contains("Invalid metadata server response"));
    }
}