use async_trait::async_trait;
use aws_config::BehaviorVersion;
use aws_sdk_sts::Client as StsClient;
use std::collections::HashMap;
use super::provider::{
StorageCredential, StorageCredentialProvider, StorageCredentialRequest,
StorageCredentialVendingError,
};
/// Default lifetime, in seconds, requested for assumed-role sessions (one hour).
const DEFAULT_DURATION_SECONDS: i32 = 60 * 60;

/// Settings for assuming a single IAM role through AWS STS.
#[derive(Debug, Clone)]
pub struct AwsStsConfig {
    /// AWS region the STS client is pinned to.
    pub region: String,
    /// ARN of the IAM role to assume.
    pub role_arn: String,
    /// Optional external ID forwarded on the AssumeRole call.
    pub external_id: Option<String>,
    /// Requested session duration in seconds.
    pub duration_seconds: i32,
    /// Location prefixes credentials may be vended for; empty means "no restriction".
    pub allowed_prefixes: Vec<String>,
}

impl AwsStsConfig {
    /// Creates a config for `role_arn` in `region` with no external ID, the default
    /// session duration, and no allowed-prefix restriction.
    pub fn new(region: impl Into<String>, role_arn: impl Into<String>) -> Self {
        let region = region.into();
        let role_arn = role_arn.into();
        AwsStsConfig {
            region,
            role_arn,
            external_id: None,
            duration_seconds: DEFAULT_DURATION_SECONDS,
            allowed_prefixes: Vec::new(),
        }
    }

    /// Sets the external ID sent with the AssumeRole request.
    pub fn with_external_id(self, external_id: impl Into<String>) -> Self {
        Self {
            external_id: Some(external_id.into()),
            ..self
        }
    }

    /// Sets the requested session duration, in seconds.
    pub fn with_duration_seconds(self, seconds: i32) -> Self {
        Self {
            duration_seconds: seconds,
            ..self
        }
    }

    /// Appends one allowed location prefix, keeping any already configured.
    pub fn with_allowed_prefix(mut self, prefix: impl Into<String>) -> Self {
        let prefix = prefix.into();
        self.allowed_prefixes.push(prefix);
        self
    }

    /// Replaces the whole list of allowed location prefixes.
    pub fn with_allowed_prefixes(self, prefixes: Vec<String>) -> Self {
        Self {
            allowed_prefixes: prefixes,
            ..self
        }
    }
}
/// [`StorageCredentialProvider`] that vends S3 credentials by assuming a single
/// IAM role through AWS STS.
#[derive(Debug)]
pub struct AwsStsCredentialProvider {
    // Role ARN, external ID, duration and allowed prefixes used for every AssumeRole call.
    config: AwsStsConfig,
    // Client that issues the AssumeRole requests.
    sts_client: StsClient,
}
impl AwsStsCredentialProvider {
    /// Creates a provider whose STS client is built from the shared AWS config loaded
    /// via `aws_config::defaults`, with the region overridden by `config.region`.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`; the `Result` leaves room for
    /// fallible setup.
    pub async fn new(config: AwsStsConfig) -> Result<Self, StorageCredentialVendingError> {
        let aws_config = aws_config::defaults(BehaviorVersion::latest())
            .region(aws_config::Region::new(config.region.clone()))
            .load()
            .await;
        let sts_client = StsClient::new(&aws_config);
        Ok(Self { config, sts_client })
    }

    /// Creates a provider from an already-built STS client.
    pub fn with_client(config: AwsStsConfig, sts_client: StsClient) -> Self {
        Self { config, sts_client }
    }

    /// Returns `true` when no allowed prefixes are configured, or when `location`
    /// falls under at least one of them.
    ///
    /// Prefixes match only on a path-segment boundary: `s3://bucket` allows
    /// `s3://bucket` and `s3://bucket/...` but not `s3://bucket-other/...`.
    fn is_location_allowed(&self, location: &str) -> bool {
        let prefixes = &self.config.allowed_prefixes;
        prefixes.is_empty()
            || prefixes
                .iter()
                .any(|prefix| Self::prefix_matches(location, prefix))
    }

    /// Segment-aware prefix test used by `is_location_allowed`.
    fn prefix_matches(location: &str, prefix: &str) -> bool {
        location.strip_prefix(prefix).is_some_and(|rest| {
            // An empty prefix, or one ending in '/', already sits on a boundary. Otherwise
            // the match must end exactly at the prefix or continue with a '/', so that a
            // bucket or directory name cannot be extended (`bucket` vs `bucket10`).
            prefix.is_empty() || prefix.ends_with('/') || rest.is_empty() || rest.starts_with('/')
        })
    }

    /// Normalizes a table location into a prefix that always ends with `/`.
    fn get_table_prefix(location: &str) -> String {
        if location.ends_with('/') {
            location.to_string()
        } else {
            format!("{}/", location)
        }
    }
}
#[async_trait]
impl StorageCredentialProvider for AwsStsCredentialProvider {
async fn vend_credentials(
&self,
request: &StorageCredentialRequest,
) -> Result<Vec<StorageCredential>, StorageCredentialVendingError> {
if !self.is_location_allowed(&request.table_location) {
return Ok(vec![]);
}
let session_name = request.session_name();
let mut assume_role = self
.sts_client
.assume_role()
.role_arn(&self.config.role_arn)
.role_session_name(&session_name)
.duration_seconds(self.config.duration_seconds);
if let Some(ref external_id) = self.config.external_id {
assume_role = assume_role.external_id(external_id);
}
let response = assume_role.send().await.map_err(|e| {
StorageCredentialVendingError::AwsStsError(format!(
"Failed to assume role {}: {}",
self.config.role_arn, e
))
})?;
let credentials = response.credentials.ok_or_else(|| {
StorageCredentialVendingError::AwsStsError(
"AssumeRole response did not contain credentials".to_string(),
)
})?;
let access_key_id = credentials.access_key_id;
let secret_access_key = credentials.secret_access_key;
let session_token = credentials.session_token;
let prefix = Self::get_table_prefix(&request.table_location);
let credential = StorageCredential::s3(
prefix,
access_key_id,
secret_access_key,
Some(session_token),
);
Ok(vec![credential])
}
fn supports_location(&self, location: &str) -> bool {
location.starts_with("s3://") || location.starts_with("s3a://")
}
}
/// Builder that produces per-tenant [`AwsStsConfig`]s, resolving each tenant's role ARN
/// from an explicit mapping or from a `{tenant_id}` pattern.
#[derive(Debug, Clone)]
pub struct AwsStsCredentialProviderBuilder {
    // Region copied into every built config.
    region: String,
    // Explicit tenant_id -> role ARN mappings; consulted before `role_pattern`.
    tenant_roles: HashMap<String, String>,
    // ARN template in which `{tenant_id}` is replaced by the tenant ID.
    role_pattern: Option<String>,
    // External ID copied into every built config, if set.
    external_id: Option<String>,
    // Session duration (seconds) copied into every built config.
    duration_seconds: i32,
    // Allowed location prefixes copied into every built config.
    allowed_prefixes: Vec<String>,
}
impl AwsStsCredentialProviderBuilder {
    /// Creates a builder for `region` with no tenant roles, no role pattern, no external
    /// ID, the default session duration, and no allowed prefixes.
    pub fn new(region: impl Into<String>) -> Self {
        Self {
            region: region.into(),
            tenant_roles: HashMap::new(),
            role_pattern: None,
            external_id: None,
            duration_seconds: DEFAULT_DURATION_SECONDS,
            allowed_prefixes: vec![],
        }
    }

    /// Maps `tenant_id` to an explicit role ARN. Explicit mappings take precedence over
    /// the role pattern.
    pub fn with_tenant_role(
        mut self,
        tenant_id: impl Into<String>,
        role_arn: impl Into<String>,
    ) -> Self {
        self.tenant_roles.insert(tenant_id.into(), role_arn.into());
        self
    }

    /// Sets an ARN template in which every `{tenant_id}` is replaced by the tenant ID,
    /// used for tenants without an explicit mapping.
    pub fn with_role_pattern(mut self, pattern: impl Into<String>) -> Self {
        self.role_pattern = Some(pattern.into());
        self
    }

    /// Sets the external ID copied into every built config.
    pub fn with_external_id(mut self, external_id: impl Into<String>) -> Self {
        self.external_id = Some(external_id.into());
        self
    }

    /// Sets the session duration, in seconds, copied into every built config.
    pub fn with_duration_seconds(mut self, seconds: i32) -> Self {
        self.duration_seconds = seconds;
        self
    }

    /// Appends one allowed location prefix.
    pub fn with_allowed_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.allowed_prefixes.push(prefix.into());
        self
    }

    /// Replaces all allowed location prefixes (mirrors
    /// [`AwsStsConfig::with_allowed_prefixes`]).
    pub fn with_allowed_prefixes(mut self, prefixes: Vec<String>) -> Self {
        self.allowed_prefixes = prefixes;
        self
    }

    /// Resolves the role ARN for `tenant_id`.
    ///
    /// Resolution order:
    /// - An explicit mapping for `tenant_id`, if one exists, is always used.
    /// - Otherwise, if a role pattern is configured, `{tenant_id}` is substituted.
    ///
    /// Substitution only happens for tenant IDs that are non-empty and consist of
    /// characters allowed in IAM role names (ASCII alphanumerics and `+=,.@_-`). Other
    /// IDs (e.g. containing `/` or `:`) could alter the structure of the generated ARN,
    /// so they yield `None`.
    pub fn get_role_arn(&self, tenant_id: &str) -> Option<String> {
        if let Some(role) = self.tenant_roles.get(tenant_id) {
            return Some(role.clone());
        }
        let pattern = self.role_pattern.as_ref()?;
        Self::is_pattern_safe_tenant_id(tenant_id)
            .then(|| pattern.replace("{tenant_id}", tenant_id))
    }

    /// Returns `true` if `tenant_id` is non-empty and uses only IAM role-name characters.
    fn is_pattern_safe_tenant_id(tenant_id: &str) -> bool {
        !tenant_id.is_empty()
            && tenant_id
                .chars()
                .all(|c| c.is_ascii_alphanumeric() || "+=,.@_-".contains(c))
    }

    /// Builds an [`AwsStsConfig`] for `tenant_id` from this builder's shared settings.
    ///
    /// # Errors
    ///
    /// Returns [`StorageCredentialVendingError::ConfigurationError`] if no role ARN can be
    /// resolved for `tenant_id` (see [`Self::get_role_arn`]).
    pub fn build_config(
        &self,
        tenant_id: &str,
    ) -> Result<AwsStsConfig, StorageCredentialVendingError> {
        let role_arn = self.get_role_arn(tenant_id).ok_or_else(|| {
            StorageCredentialVendingError::ConfigurationError(format!(
                "No role ARN configured for tenant: {}",
                tenant_id
            ))
        })?;
        let mut config = AwsStsConfig::new(self.region.clone(), role_arn)
            .with_duration_seconds(self.duration_seconds)
            .with_allowed_prefixes(self.allowed_prefixes.clone());
        if let Some(ref external_id) = self.external_id {
            config = config.with_external_id(external_id.clone());
        }
        Ok(config)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_builder() {
        let config = AwsStsConfig::new("us-west-2", "arn:aws:iam::123456789012:role/TestRole")
            .with_external_id("ext-123")
            .with_duration_seconds(1800)
            .with_allowed_prefix("s3://bucket1/")
            .with_allowed_prefix("s3://bucket2/");
        assert_eq!(config.region, "us-west-2");
        assert_eq!(config.role_arn, "arn:aws:iam::123456789012:role/TestRole");
        assert_eq!(config.external_id, Some("ext-123".to_string()));
        assert_eq!(config.duration_seconds, 1800);
        assert_eq!(config.allowed_prefixes.len(), 2);
    }

    /// `AwsStsConfig::new` starts with no external ID, the default duration and no prefixes.
    #[test]
    fn test_config_defaults() {
        let config = AwsStsConfig::new("us-east-1", "arn:aws:iam::123:role/Test");
        assert_eq!(config.role_arn, "arn:aws:iam::123:role/Test");
        assert_eq!(config.external_id, None);
        assert_eq!(config.duration_seconds, DEFAULT_DURATION_SECONDS);
        assert!(config.allowed_prefixes.is_empty());
    }

    #[test]
    fn test_config_struct_literal_with_prefixes() {
        let config_with_prefixes = AwsStsConfig {
            region: "us-east-1".to_string(),
            role_arn: "arn:aws:iam::123:role/Test".to_string(),
            external_id: None,
            duration_seconds: 3600,
            allowed_prefixes: vec!["s3://allowed-bucket/".to_string()],
        };
        assert_eq!(config_with_prefixes.allowed_prefixes.len(), 1);
        assert!(config_with_prefixes.allowed_prefixes[0].starts_with("s3://"));
    }

    #[test]
    fn test_provider_builder_explicit_mapping() {
        let builder = AwsStsCredentialProviderBuilder::new("us-east-1")
            .with_tenant_role("tenant-a", "arn:aws:iam::111:role/TenantA")
            .with_tenant_role("tenant-b", "arn:aws:iam::222:role/TenantB");
        assert_eq!(
            builder.get_role_arn("tenant-a"),
            Some("arn:aws:iam::111:role/TenantA".to_string())
        );
        assert_eq!(
            builder.get_role_arn("tenant-b"),
            Some("arn:aws:iam::222:role/TenantB".to_string())
        );
        assert_eq!(builder.get_role_arn("tenant-c"), None);
    }

    #[test]
    fn test_provider_builder_pattern() {
        let builder = AwsStsCredentialProviderBuilder::new("us-east-1")
            .with_role_pattern("arn:aws:iam::123456789012:role/iceberg-{tenant_id}-access");
        assert_eq!(
            builder.get_role_arn("tenant-123"),
            Some("arn:aws:iam::123456789012:role/iceberg-tenant-123-access".to_string())
        );
    }

    #[test]
    fn test_provider_builder_explicit_over_pattern() {
        let builder = AwsStsCredentialProviderBuilder::new("us-east-1")
            .with_tenant_role("special-tenant", "arn:aws:iam::999:role/SpecialRole")
            .with_role_pattern("arn:aws:iam::123:role/{tenant_id}");
        assert_eq!(
            builder.get_role_arn("special-tenant"),
            Some("arn:aws:iam::999:role/SpecialRole".to_string())
        );
        assert_eq!(
            builder.get_role_arn("other-tenant"),
            Some("arn:aws:iam::123:role/other-tenant".to_string())
        );
    }

    /// Shared builder settings are copied into the config built for a mapped tenant.
    #[test]
    fn test_build_config_propagates_settings() {
        let builder = AwsStsCredentialProviderBuilder::new("eu-west-1")
            .with_tenant_role("tenant-a", "arn:aws:iam::111:role/TenantA")
            .with_external_id("ext-xyz")
            .with_duration_seconds(900)
            .with_allowed_prefix("s3://bucket-a/");
        let Ok(config) = builder.build_config("tenant-a") else {
            panic!("tenant-a has an explicit mapping and must build");
        };
        assert_eq!(config.region, "eu-west-1");
        assert_eq!(config.role_arn, "arn:aws:iam::111:role/TenantA");
        assert_eq!(config.external_id, Some("ext-xyz".to_string()));
        assert_eq!(config.duration_seconds, 900);
        assert_eq!(config.allowed_prefixes, vec!["s3://bucket-a/".to_string()]);
    }

    /// A tenant with neither a mapping nor a pattern is reported as a configuration error.
    #[test]
    fn test_build_config_unknown_tenant_is_configuration_error() {
        let builder = AwsStsCredentialProviderBuilder::new("us-east-1");
        let err = builder.build_config("nobody").unwrap_err();
        assert!(matches!(
            err,
            StorageCredentialVendingError::ConfigurationError(_)
        ));
    }

    #[test]
    fn test_table_prefix() {
        assert_eq!(
            AwsStsCredentialProvider::get_table_prefix("s3://bucket/warehouse/ns/table"),
            "s3://bucket/warehouse/ns/table/"
        );
        assert_eq!(
            AwsStsCredentialProvider::get_table_prefix("s3://bucket/warehouse/ns/table/"),
            "s3://bucket/warehouse/ns/table/"
        );
    }
}