use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;
use thiserror::Error;
/// Errors that can be reported while vending storage credentials.
///
/// Every variant carries a free-form message; the `#[error(...)]` attribute on
/// the variant defines how that message is rendered by `Display`.
#[derive(Debug, Error)]
pub enum StorageCredentialVendingError {
/// Rendered as `AWS STS error: <message>`.
#[error("AWS STS error: {0}")]
AwsStsError(String),
/// Rendered as `GCS credential error: <message>`.
#[error("GCS credential error: {0}")]
GcsError(String),
/// Rendered as `Azure credential error: <message>`.
#[error("Azure credential error: {0}")]
AzureError(String),
/// Rendered as `Unsupported storage location: <message>`.
#[error("Unsupported storage location: {0}")]
UnsupportedLocation(String),
/// Rendered as `Configuration error: <message>`.
#[error("Configuration error: {0}")]
ConfigurationError(String),
/// Rendered as `Permission denied: <message>`.
#[error("Permission denied: {0}")]
PermissionDenied(String),
}
/// A bundle of storage configuration properties tied to a location prefix.
///
/// `Debug` is deliberately not derived here: this file provides a hand-written
/// `impl std::fmt::Debug` that masks values stored under sensitive-looking keys.
// NOTE(review): `Serialize` is derived, so serialized output presumably contains
// the raw, unmasked values — confirm this is only sent to trusted consumers.
#[derive(Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct StorageCredential {
/// Storage location prefix, e.g. `s3://my-bucket/warehouse/`.
pub prefix: String,
/// Property name to property value, e.g. `s3.access-key-id` or
/// `gcs.oauth2.token` as written by the constructors in this file.
pub config: HashMap<String, String>,
}
impl std::fmt::Debug for StorageCredential {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let redacted_config: HashMap<&str, &str> = self
.config
.keys()
.map(|k| {
let v = if k.contains("secret") || k.contains("token") || k.contains("password") {
"[REDACTED]"
} else {
self.config.get(k).map(|s| s.as_str()).unwrap_or("")
};
(k.as_str(), v)
})
.collect();
f.debug_struct("StorageCredential")
.field("prefix", &self.prefix)
.field("config", &redacted_config)
.finish()
}
}
/// Constructors for [`StorageCredential`].
impl StorageCredential {
    /// Builds a credential from a location prefix and a ready-made property map.
    pub fn new(prefix: impl Into<String>, config: HashMap<String, String>) -> Self {
        let prefix = prefix.into();
        Self { prefix, config }
    }

    /// Builds an S3 credential.
    ///
    /// Always stores `s3.access-key-id` and `s3.secret-access-key`;
    /// `s3.session-token` is stored only when `session_token` is `Some`.
    pub fn s3(
        prefix: impl Into<String>,
        access_key_id: impl Into<String>,
        secret_access_key: impl Into<String>,
        session_token: Option<String>,
    ) -> Self {
        let mut properties: HashMap<String, String> = HashMap::from([
            ("s3.access-key-id".to_string(), access_key_id.into()),
            ("s3.secret-access-key".to_string(), secret_access_key.into()),
        ]);
        // `Option` iterates over zero or one item, so this adds the token
        // entry only when one was supplied.
        properties.extend(session_token.map(|token| ("s3.session-token".to_string(), token)));
        Self::new(prefix, properties)
    }

    /// Builds a GCS credential holding the token under `gcs.oauth2.token`.
    pub fn gcs(prefix: impl Into<String>, oauth2_token: impl Into<String>) -> Self {
        let properties: HashMap<String, String> =
            HashMap::from([("gcs.oauth2.token".to_string(), oauth2_token.into())]);
        Self::new(prefix, properties)
    }

    /// Builds an Azure credential holding the SAS token under
    /// `adls.sas-token.<account>`.
    pub fn azure(
        prefix: impl Into<String>,
        account: impl Into<String>,
        sas_token: impl Into<String>,
    ) -> Self {
        let account: String = account.into();
        let mut property_name = String::from("adls.sas-token.");
        property_name.push_str(&account);

        let properties: HashMap<String, String> =
            HashMap::from([(property_name, sas_token.into())]);
        Self::new(prefix, properties)
    }
}
/// Describes the table for which storage credentials are being requested.
#[derive(Debug, Clone)]
pub struct StorageCredentialRequest {
/// Identifier of the tenant making the request.
pub tenant_id: String,
/// Namespace components of the table (joined with `-` by `session_name`).
pub namespace: Vec<String>,
/// Name of the table.
pub table_name: String,
/// Storage location of the table, e.g. `s3://bucket/warehouse/prod/analytics/sales_data`.
pub table_location: String,
/// `false` when built via `read_only`, `true` when built via `with_write_access`.
// NOTE(review): presumably providers use this to scope the vended permissions — confirm.
pub write_access: bool,
}
/// Constructors and helpers for [`StorageCredentialRequest`].
impl StorageCredentialRequest {
    /// Creates a request with `write_access` set to `false`.
    pub fn read_only(
        tenant_id: impl Into<String>,
        namespace: Vec<String>,
        table_name: impl Into<String>,
        table_location: impl Into<String>,
    ) -> Self {
        let tenant_id = tenant_id.into();
        let table_name = table_name.into();
        let table_location = table_location.into();
        Self {
            tenant_id,
            namespace,
            table_name,
            table_location,
            write_access: false,
        }
    }

    /// Creates a request with `write_access` set to `true`.
    pub fn with_write_access(
        tenant_id: impl Into<String>,
        namespace: Vec<String>,
        table_name: impl Into<String>,
        table_location: impl Into<String>,
    ) -> Self {
        // Same fields as a read-only request, with only the flag flipped.
        Self {
            write_access: true,
            ..Self::read_only(tenant_id, namespace, table_name, table_location)
        }
    }

    /// Derives a session name of the form `rustberg-<tenant>-<namespace>-<table>`.
    ///
    /// Namespace components are joined with `-`. Any character other than an
    /// ASCII letter, ASCII digit, `-` or `_` is dropped, and the result is cut
    /// to at most 64 characters.
    pub fn session_name(&self) -> String {
        const MAX_LEN: usize = 64;

        let joined_namespace = self.namespace.join("-");
        let mut name = format!(
            "rustberg-{}-{}-{}",
            self.tenant_id, joined_namespace, self.table_name
        );
        name.retain(|ch| ch.is_ascii_alphanumeric() || matches!(ch, '-' | '_'));
        // Only ASCII is left after `retain`, so a byte-length cut is also a
        // character-count cut and always lands on a char boundary.
        name.truncate(MAX_LEN);
        name
    }
}
/// A source of storage credentials.
///
/// Implementors must be `Send + Sync` and implement `Debug`. The trait is
/// declared through `#[async_trait]` so that `vend_credentials` can be `async`.
#[async_trait]
pub trait StorageCredentialProvider: Send + Sync + fmt::Debug {
/// Produces credentials for the table described by `request`.
///
/// Returns the list of credentials on success (it may be empty, as with the
/// no-op provider in this file) or a [`StorageCredentialVendingError`].
async fn vend_credentials(
&self,
request: &StorageCredentialRequest,
) -> Result<Vec<StorageCredential>, StorageCredentialVendingError>;
/// Reports whether this provider handles the given storage `location`.
// NOTE(review): presumably used by callers to choose a provider per table
// location — confirm against the call sites.
fn supports_location(&self, location: &str) -> bool;
}
/// Stateless placeholder provider.
///
/// Its `StorageCredentialProvider` implementation in this file returns an
/// empty credential list and reports that no location is supported.
#[derive(Debug, Clone, Default)]
pub struct NoopCredentialProvider;

impl NoopCredentialProvider {
    /// Creates the provider; equivalent to `NoopCredentialProvider::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
/// No-op implementation: vends nothing and claims no storage location.
#[async_trait]
impl StorageCredentialProvider for NoopCredentialProvider {
    /// Always succeeds with an empty credential list; the request is ignored.
    async fn vend_credentials(
        &self,
        _request: &StorageCredentialRequest,
    ) -> Result<Vec<StorageCredential>, StorageCredentialVendingError> {
        let nothing_to_vend: Vec<StorageCredential> = Vec::new();
        Ok(nothing_to_vend)
    }

    /// Always `false`: this provider supports no location.
    fn supports_location(&self, _location: &str) -> bool {
        false
    }
}
// Unit tests for the credential constructors, the session-name helper, the
// no-op provider, and the redacting `Debug` implementation.
#[cfg(test)]
mod tests {
use super::*;
// The S3 constructor stores the prefix and all three properties under the
// `s3.*` keys.
#[test]
fn test_storage_credential_s3() {
let cred = StorageCredential::s3(
"s3://my-bucket/warehouse/",
"AKIAIOSFODNN7EXAMPLE",
"wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
Some("token123".to_string()),
);
assert_eq!(cred.prefix, "s3://my-bucket/warehouse/");
assert_eq!(
cred.config.get("s3.access-key-id").unwrap(),
"AKIAIOSFODNN7EXAMPLE"
);
assert_eq!(
cred.config.get("s3.secret-access-key").unwrap(),
"wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
);
assert_eq!(cred.config.get("s3.session-token").unwrap(), "token123");
}
// The GCS constructor stores the token under `gcs.oauth2.token`.
#[test]
fn test_storage_credential_gcs() {
let cred = StorageCredential::gcs("gs://my-bucket/warehouse/", "ya29.example-token");
assert_eq!(cred.prefix, "gs://my-bucket/warehouse/");
assert_eq!(
cred.config.get("gcs.oauth2.token").unwrap(),
"ya29.example-token"
);
}
// The Azure constructor embeds the account name in the property key.
#[test]
fn test_storage_credential_azure() {
let cred = StorageCredential::azure(
"abfss://container@account.dfs.core.windows.net/warehouse/",
"account",
"?sv=2021-06-08&ss=bfqt&srt=sco&sp=rwdlacupitfx",
);
assert_eq!(
cred.prefix,
"abfss://container@account.dfs.core.windows.net/warehouse/"
);
assert_eq!(
cred.config.get("adls.sas-token.account").unwrap(),
"?sv=2021-06-08&ss=bfqt&srt=sco&sp=rwdlacupitfx"
);
}
// The session name is built from tenant, namespace and table, and never
// exceeds 64 characters.
#[test]
fn test_credential_request_session_name() {
let request = StorageCredentialRequest::read_only(
"tenant-123",
vec!["prod".to_string(), "analytics".to_string()],
"sales_data",
"s3://bucket/warehouse/prod/analytics/sales_data",
);
let session_name = request.session_name();
assert!(session_name.starts_with("rustberg-tenant-123-prod-analytics-sales_data"));
assert!(session_name.len() <= 64);
}
// The no-op provider vends nothing and supports no location.
#[tokio::test]
async fn test_noop_provider() {
let provider = NoopCredentialProvider::new();
let request = StorageCredentialRequest::read_only(
"tenant-1",
vec!["ns".to_string()],
"table",
"s3://bucket/ns/table",
);
let credentials = provider.vend_credentials(&request).await.unwrap();
assert!(credentials.is_empty());
assert!(!provider.supports_location("s3://bucket/"));
}
// Debug output keeps the access key id visible but masks the secret key
// and the session token.
#[test]
fn test_credential_debug_redacts_secrets() {
let cred = StorageCredential::s3(
"s3://bucket/",
"AKIAIOSFODNN7EXAMPLE",
"wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
Some("session-token-value".to_string()),
);
let debug_output = format!("{:?}", cred);
assert!(
debug_output.contains("AKIAIOSFODNN7EXAMPLE"),
"access-key-id should be visible in debug output"
);
assert!(
!debug_output.contains("wJalrXUtnFEMI"),
"secret-access-key must be redacted"
);
assert!(
debug_output.contains("[REDACTED]"),
"redacted placeholder must appear"
);
assert!(
!debug_output.contains("session-token-value"),
"session-token must be redacted"
);
}
// Debug output masks the GCS OAuth2 token (its key contains "token").
#[test]
fn test_credential_debug_redacts_gcs_token() {
let cred = StorageCredential::gcs("gs://bucket/", "ya29.super-secret-token");
let debug_output = format!("{:?}", cred);
assert!(
!debug_output.contains("ya29.super-secret-token"),
"GCS OAuth2 token must be redacted"
);
assert!(debug_output.contains("[REDACTED]"));
}
// Debug output masks the Azure SAS token (its key contains "token").
#[test]
fn test_credential_debug_redacts_azure_sas() {
let cred = StorageCredential::azure(
"abfss://container@account.dfs.core.windows.net/",
"account",
"?sv=2021-06-08&ss=bfqt&srt=sco&sp=rwdlacupitfx",
);
let debug_output = format!("{:?}", cred);
assert!(
!debug_output.contains("sv=2021-06-08"),
"Azure SAS token must be redacted"
);
assert!(debug_output.contains("[REDACTED]"));
}
}