use async_trait::async_trait;
use google_cloud_auth::credentials::service_account::{
AccessSpecifier, Builder as ServiceAccountBuilder,
};
use google_cloud_auth::credentials::AccessTokenCredentials;
use std::sync::Arc;
use tokio::sync::RwLock;
use super::provider::{
StorageCredential, StorageCredentialProvider, StorageCredentialRequest,
StorageCredentialVendingError,
};
/// OAuth2 scope granting read/write access to Cloud Storage.
const GCS_SCOPE: &str = "https://www.googleapis.com/auth/devstorage.read_write";
/// OAuth2 scope granting read-only access to Cloud Storage.
const GCS_SCOPE_READ_ONLY: &str = "https://www.googleapis.com/auth/devstorage.read_only";

/// Configuration for the GCS credential provider.
///
/// Start from [`GcsConfig::new`] (equivalent to `Default`) and chain the `with_*`
/// builder methods.
#[derive(Debug, Clone, Default)]
pub struct GcsConfig {
    /// Path to the service-account key JSON file used to mint access tokens.
    pub service_account_key_path: Option<String>,
    /// Location prefixes credentials may be vended for; an empty list allows every location.
    pub allowed_prefixes: Vec<String>,
    /// When `true`, the read-only scope is chosen even if write access is requested.
    pub default_read_only: bool,
    /// Explicit OAuth2 scopes overriding the read-only/read-write selection.
    /// An empty list is treated the same as `None`.
    pub custom_scopes: Option<Vec<String>>,
}

impl GcsConfig {
    /// Creates an empty configuration (same as `GcsConfig::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the path of the service-account key JSON file.
    pub fn with_service_account_key_path(mut self, path: impl Into<String>) -> Self {
        self.service_account_key_path = Some(path.into());
        self
    }

    /// Appends a single allowed location prefix to the existing list.
    pub fn with_allowed_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.allowed_prefixes.push(prefix.into());
        self
    }

    /// Replaces the allowed prefix list, discarding any previously added prefixes.
    pub fn with_allowed_prefixes(mut self, prefixes: Vec<String>) -> Self {
        self.allowed_prefixes = prefixes;
        self
    }

    /// Forces the read-only scope when `read_only` is `true`.
    pub fn with_default_read_only(mut self, read_only: bool) -> Self {
        self.default_read_only = read_only;
        self
    }

    /// Overrides the default scope selection with explicit OAuth2 scopes.
    pub fn with_custom_scopes(mut self, scopes: Vec<String>) -> Self {
        self.custom_scopes = Some(scopes);
        self
    }

    /// Returns the primary OAuth2 scope to request.
    ///
    /// If a non-empty custom scope list is configured, its first entry is returned.
    /// Otherwise the read-write scope is returned only when `write_access` is `true`
    /// and `default_read_only` is `false`; every other case yields the read-only scope.
    fn get_scope(&self, write_access: bool) -> &str {
        // An empty override is treated as "not configured". Falling back to a hard-coded
        // read-write scope here would silently bypass `default_read_only`.
        if let Some(first) = self
            .custom_scopes
            .as_deref()
            .and_then(|scopes| scopes.first())
        {
            return first.as_str();
        }
        if write_access && !self.default_read_only {
            GCS_SCOPE
        } else {
            GCS_SCOPE_READ_ONLY
        }
    }
}
/// An access token paired with the instant this provider treats as its expiry.
#[derive(Debug)]
struct CachedToken {
    /// The bearer token string.
    token: String,
    /// Expiry instant recorded when the token was cached.
    expires_at: std::time::Instant,
}
impl CachedToken {
    /// Reports whether the cached token may still be handed out.
    ///
    /// A token counts as valid only while more than 300 seconds remain before
    /// `expires_at`. This margin causes a refresh ahead of the recorded expiry.
    fn is_valid(&self) -> bool {
        const REFRESH_MARGIN: std::time::Duration = std::time::Duration::from_secs(300);
        // Saturates to zero once `expires_at` has passed, which is never above the margin.
        let remaining = self
            .expires_at
            .saturating_duration_since(std::time::Instant::now());
        remaining > REFRESH_MARGIN
    }
}
/// Vends GCS OAuth2 access tokens minted from a service-account key.
///
/// Construct with `GcsCredentialProvider::new`, which reads the key file named in the
/// config.
pub struct GcsCredentialProvider {
    /// Provider configuration (key path, allowed prefixes, scope selection).
    config: GcsConfig,
    /// Service-account credentials used to obtain access tokens.
    credentials: Arc<AccessTokenCredentials>,
    /// Last token obtained, reused by `get_token` while `CachedToken::is_valid` holds.
    cached_token: Arc<RwLock<Option<CachedToken>>>,
}
impl std::fmt::Debug for GcsCredentialProvider {
    /// Formats the provider without exposing secrets.
    ///
    /// The credentials are shown as a fixed placeholder, and the cached access token
    /// is omitted entirely.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("GcsCredentialProvider");
        out.field("config", &self.config);
        out.field("credentials", &"<AccessTokenCredentials>");
        out.finish()
    }
}
impl GcsCredentialProvider {
    /// Builds a provider from `config`, loading the service-account key from disk.
    ///
    /// Scope selection: when `config.custom_scopes` is non-empty, every configured scope
    /// is requested. Otherwise the single scope returned by `GcsConfig::get_scope(true)`
    /// is requested.
    ///
    /// # Errors
    ///
    /// * `StorageCredentialVendingError::ConfigurationError` if no key path is configured.
    /// * `StorageCredentialVendingError::GcsError` if the key file cannot be read, is not
    ///   valid JSON, or the credentials cannot be built from it.
    pub async fn new(config: GcsConfig) -> Result<Self, StorageCredentialVendingError> {
        let key_path = config.service_account_key_path.as_ref().ok_or_else(|| {
            StorageCredentialVendingError::ConfigurationError(
                "GCS credential provider requires a service account key file path".to_string(),
            )
        })?;
        // NOTE(review): this is a blocking filesystem read inside an async fn. It runs once
        // per construction; consider `tokio::fs` / `spawn_blocking` if the enabled tokio
        // features allow it.
        let key_json = std::fs::read_to_string(key_path).map_err(|e| {
            StorageCredentialVendingError::GcsError(format!(
                "Failed to read service account key from {}: {}",
                key_path, e
            ))
        })?;
        let service_account_key: serde_json::Value =
            serde_json::from_str(&key_json).map_err(|e| {
                StorageCredentialVendingError::GcsError(format!(
                    "Failed to parse service account key JSON: {}",
                    e
                ))
            })?;
        // Request every configured custom scope. Passing only the first entry would
        // silently drop the rest of the list.
        let scopes: Vec<&str> = match config.custom_scopes.as_deref() {
            Some(custom) if !custom.is_empty() => custom.iter().map(String::as_str).collect(),
            _ => vec![config.get_scope(true)],
        };
        let credentials = ServiceAccountBuilder::new(service_account_key)
            .with_access_specifier(AccessSpecifier::from_scopes(scopes))
            .build_access_token_credentials()
            .map_err(|e| {
                StorageCredentialVendingError::GcsError(format!(
                    "Failed to build service account credentials: {}",
                    e
                ))
            })?;
        Ok(Self {
            config,
            credentials: Arc::new(credentials),
            cached_token: Arc::new(RwLock::new(None)),
        })
    }

    /// Returns `true` if `location` starts with any configured allowed prefix, or if no
    /// prefixes are configured at all.
    ///
    /// Matching is a plain string-prefix test. Configure prefixes with a trailing `/`
    /// (e.g. `gs://bucket/`); otherwise `gs://bucket` would also match
    /// `gs://bucket-other/...`.
    fn is_location_allowed(&self, location: &str) -> bool {
        if self.config.allowed_prefixes.is_empty() {
            return true;
        }
        self.config
            .allowed_prefixes
            .iter()
            .any(|prefix| location.starts_with(prefix))
    }

    /// Returns `location` with exactly one trailing `/` appended if it lacks one.
    fn get_table_prefix(location: &str) -> String {
        if location.ends_with('/') {
            location.to_string()
        } else {
            format!("{}/", location)
        }
    }

    /// Returns an access token, reusing the cached one while it is still valid.
    ///
    /// # Errors
    ///
    /// `StorageCredentialVendingError::GcsError` if a new token cannot be obtained.
    async fn get_token(&self) -> Result<String, StorageCredentialVendingError> {
        // Fast path: read lock only, released at the end of this scope.
        {
            let cached = self.cached_token.read().await;
            if let Some(ref token) = *cached {
                if token.is_valid() {
                    return Ok(token.token.clone());
                }
            }
        }
        // NOTE(review): concurrent callers that all miss the cache will each request a
        // token here; the last writer wins.
        let token = self.credentials.access_token().await.map_err(|e| {
            StorageCredentialVendingError::GcsError(format!("Failed to obtain token: {}", e))
        })?;
        let access_token = token.token;
        // NOTE(review): assumes every returned token is fresh with a 3600 s lifetime. If the
        // credentials object hands back an already-cached token, the real expiry may be
        // earlier than this. Confirm, or use the token's actual expiry if the API exposes it.
        let expires_at = std::time::Instant::now() + std::time::Duration::from_secs(3600);
        {
            let mut cached = self.cached_token.write().await;
            *cached = Some(CachedToken {
                token: access_token.clone(),
                expires_at,
            });
        }
        Ok(access_token)
    }
}
#[async_trait]
impl StorageCredentialProvider for GcsCredentialProvider {
    /// Returns one GCS credential for `request.table_location`.
    ///
    /// Returns an empty list when the location is outside the allowed prefixes. The
    /// credential's prefix is the table location with a trailing `/`.
    ///
    /// NOTE(review): the token is this provider's own service-account access token. The
    /// prefix presumably only labels where it applies and does not narrow what the token
    /// can reach; confirm this is acceptable for callers.
    async fn vend_credentials(
        &self,
        request: &StorageCredentialRequest,
    ) -> Result<Vec<StorageCredential>, StorageCredentialVendingError> {
        let location = &request.table_location;
        if self.is_location_allowed(location) {
            let token = self.get_token().await?;
            Ok(vec![StorageCredential::gcs(
                Self::get_table_prefix(location),
                token,
            )])
        } else {
            Ok(Vec::new())
        }
    }

    /// Accepts `gs://` and `gcs://` URIs.
    fn supports_location(&self, location: &str) -> bool {
        ["gs://", "gcs://"]
            .iter()
            .any(|&scheme| location.starts_with(scheme))
    }
}
// Unit tests for the config builder, scope selection, prefix helpers and the GCS
// `StorageCredential` constructor. None of them build a real `GcsCredentialProvider`,
// because `GcsCredentialProvider::new` needs a service-account key file on disk.
#[cfg(test)]
mod tests {
    use super::*;
    // Builder methods store exactly what they were given; `with_allowed_prefix` appends.
    #[test]
    fn test_config_builder() {
        let config = GcsConfig::new()
            .with_service_account_key_path("/path/to/key.json")
            .with_allowed_prefix("gs://bucket1/")
            .with_allowed_prefix("gs://bucket2/")
            .with_default_read_only(true);
        assert_eq!(
            config.service_account_key_path,
            Some("/path/to/key.json".to_string())
        );
        assert_eq!(config.allowed_prefixes.len(), 2);
        assert!(config.default_read_only);
    }
    // Covers the read/write default, `default_read_only`, and the custom-scope override.
    #[test]
    fn test_get_scope() {
        let config = GcsConfig::new();
        assert_eq!(config.get_scope(true), GCS_SCOPE);
        assert_eq!(config.get_scope(false), GCS_SCOPE_READ_ONLY);
        let config_read_only = GcsConfig::new().with_default_read_only(true);
        assert_eq!(config_read_only.get_scope(true), GCS_SCOPE_READ_ONLY);
        let config_custom = GcsConfig::new().with_custom_scopes(vec!["custom-scope".to_string()]);
        assert_eq!(config_custom.get_scope(true), "custom-scope");
    }
    // A trailing '/' is added only when missing.
    #[test]
    fn test_table_prefix() {
        assert_eq!(
            GcsCredentialProvider::get_table_prefix("gs://bucket/warehouse/ns/table"),
            "gs://bucket/warehouse/ns/table/"
        );
        assert_eq!(
            GcsCredentialProvider::get_table_prefix("gs://bucket/warehouse/ns/table/"),
            "gs://bucket/warehouse/ns/table/"
        );
    }
    // NOTE(review): this exercises `MockGcsProvider::is_location_allowed`, a copy of the
    // real method, so a change to `GcsCredentialProvider::is_location_allowed` would not
    // be caught here.
    #[test]
    fn test_location_allowed() {
        let config = GcsConfig::new().with_allowed_prefix("gs://allowed-bucket/");
        let provider = MockGcsProvider { config };
        assert!(provider.is_location_allowed("gs://allowed-bucket/data/table"));
        assert!(!provider.is_location_allowed("gs://other-bucket/data/table"));
        let config_no_restrictions = GcsConfig::new();
        let provider_no_restrictions = MockGcsProvider {
            config: config_no_restrictions,
        };
        assert!(provider_no_restrictions.is_location_allowed("gs://any-bucket/"));
    }
    // Stand-in holding only a config, since the real provider needs credentials to build.
    struct MockGcsProvider {
        config: GcsConfig,
    }
    impl MockGcsProvider {
        // Duplicate of `GcsCredentialProvider::is_location_allowed`; keep the two in sync.
        fn is_location_allowed(&self, location: &str) -> bool {
            if self.config.allowed_prefixes.is_empty() {
                return true;
            }
            self.config
                .allowed_prefixes
                .iter()
                .any(|prefix| location.starts_with(prefix))
        }
    }
    // `StorageCredential::gcs` keeps the prefix verbatim and exposes the token under the
    // `gcs.oauth2.token` config key.
    #[test]
    fn test_storage_credential_gcs() {
        let cred = StorageCredential::gcs("gs://my-bucket/warehouse/", "ya29.example-token");
        assert_eq!(cred.prefix, "gs://my-bucket/warehouse/");
        assert_eq!(
            cred.config.get("gcs.oauth2.token").unwrap(),
            "ya29.example-token"
        );
    }
}