#[cfg(feature = "credential-vendor-aws")]
pub mod aws;
#[cfg(feature = "credential-vendor-azure")]
pub mod azure;
#[cfg(feature = "credential-vendor-gcp")]
pub mod gcp;
#[cfg(any(
feature = "credential-vendor-aws",
feature = "credential-vendor-azure",
feature = "credential-vendor-gcp"
))]
pub mod cache;
use std::collections::HashMap;
use std::str::FromStr;
use async_trait::async_trait;
use lance_core::Result;
use lance_io::object_store::uri_to_url;
use lance_namespace::models::Identity;
/// Fallback credential lifetime in milliseconds (`3600 * 1000`, i.e. one hour).
///
/// `parse_duration_millis` returns this value when the requested property is
/// missing or does not parse as a `u64`.
pub const DEFAULT_CREDENTIAL_DURATION_MILLIS: u64 = 3600 * 1000;
/// Returns a log-safe, partially masked rendering of `credential`.
///
/// * An empty input yields the literal `"[empty]"`.
/// * Inputs shorter than 16 characters yield at most the first 8 characters
///   followed by `***`.
/// * Longer inputs yield the first 8 characters, `***`, and the last 4
///   characters.
///
/// Lengths are measured in characters and every slice is taken on a character
/// boundary, so input containing multi-byte UTF-8 sequences never panics.
/// For pure-ASCII input (the usual shape of access keys and tokens) the
/// result is identical to slicing by byte offset.
pub fn redact_credential(credential: &str) -> String {
    const SHOW_START: usize = 8;
    const SHOW_END: usize = 4;
    const MIN_LENGTH_FOR_BOTH_ENDS: usize = SHOW_START + SHOW_END + 4;

    if credential.is_empty() {
        return "[empty]".to_string();
    }

    // Byte offset just past the first SHOW_START characters; the whole string
    // when it has fewer characters than that.
    let prefix_end = credential
        .char_indices()
        .nth(SHOW_START)
        .map_or(credential.len(), |(idx, _)| idx);
    let prefix = &credential[..prefix_end];

    if credential.chars().count() < MIN_LENGTH_FOR_BOTH_ENDS {
        // Too short to reveal both ends without exposing most of the secret.
        return format!("{}***", prefix);
    }

    // Byte offset at which the last SHOW_END characters begin.
    let suffix_start = credential
        .char_indices()
        .rev()
        .nth(SHOW_END - 1)
        .map_or(0, |(idx, _)| idx);

    format!("{}***{}", prefix, &credential[suffix_start..])
}
/// Access level attached to vended credentials.
///
/// The textual form (`"read"`, `"write"`, `"admin"`) is produced by the
/// `Display` impl and accepted, case-insensitively, by the `FromStr` impl.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum VendedPermission {
    /// Default level; neither [`can_write`](Self::can_write) nor
    /// [`can_delete`](Self::can_delete) holds.
    #[default]
    Read,
    /// [`can_write`](Self::can_write) holds, [`can_delete`](Self::can_delete)
    /// does not.
    Write,
    /// Both [`can_write`](Self::can_write) and
    /// [`can_delete`](Self::can_delete) hold.
    Admin,
}

impl VendedPermission {
    /// Returns `true` for [`Self::Write`] and [`Self::Admin`].
    pub fn can_write(&self) -> bool {
        match self {
            Self::Read => false,
            Self::Write | Self::Admin => true,
        }
    }

    /// Returns `true` only for [`Self::Admin`].
    pub fn can_delete(&self) -> bool {
        match self {
            Self::Read | Self::Write => false,
            Self::Admin => true,
        }
    }
}

impl FromStr for VendedPermission {
    type Err = String;

    /// Parses `read`, `write` or `admin`, ignoring letter case.
    ///
    /// Any other input yields an error message that echoes the original text.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        // Canonical lowercase spellings paired with the variant they denote.
        const SPELLINGS: [(&str, VendedPermission); 3] = [
            ("read", VendedPermission::Read),
            ("write", VendedPermission::Write),
            ("admin", VendedPermission::Admin),
        ];

        let lowered = s.to_lowercase();
        SPELLINGS
            .iter()
            .find(|(spelling, _)| lowered.as_str() == *spelling)
            .map(|&(_, permission)| permission)
            .ok_or_else(|| {
                format!(
                    "Invalid permission '{}'. Must be one of: read, write, admin",
                    s
                )
            })
    }
}

impl std::fmt::Display for VendedPermission {
    /// Writes the lowercase name of the variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Read => "read",
            Self::Write => "write",
            Self::Admin => "admin",
        };
        f.write_str(label)
    }
}
/// Prefix of the fully qualified property names quoted in this module's error
/// messages (e.g. `credential_vendor.aws_role_arn`).
///
/// NOTE(review): the lookups in this file use the short keys below without
/// this prefix, so callers presumably strip it first — confirm.
pub const PROPERTY_PREFIX: &str = "credential_vendor.";
/// Short key read by `has_credential_vendor_config`; vending counts as
/// configured only when the value equals `"true"` (ASCII case-insensitive).
pub const ENABLED: &str = "enabled";
/// Short key read by `parse_permission`; expected values are `read`, `write`
/// or `admin`.
pub const PERMISSION: &str = "permission";
/// Short key read by `create_credential_vendor_for_location`; the vendor is
/// wrapped in a cache unless the value equals `"false"` (ASCII
/// case-insensitive).
pub const CACHE_ENABLED: &str = "cache_enabled";
/// Not referenced in this file; presumably the key of a salt used when hashing
/// API keys — TODO confirm against its users.
pub const API_KEY_SALT: &str = "api_key_salt";
/// Not referenced in this file; presumably the prefix of keys that hold API-key
/// hashes — TODO confirm against its users.
pub const API_KEY_HASH_PREFIX: &str = "api_key_hash.";
/// Short property keys read by `create_aws_vendor`.
#[cfg(feature = "credential-vendor-aws")]
pub mod aws_props {
/// Required: `create_aws_vendor` returns an invalid-input error when absent.
pub const ROLE_ARN: &str = "aws_role_arn";
/// Optional: forwarded to the vendor config only when present.
pub const EXTERNAL_ID: &str = "aws_external_id";
/// Optional: forwarded to the vendor config only when present.
pub const REGION: &str = "aws_region";
/// Optional: forwarded to the vendor config only when present.
pub const ROLE_SESSION_NAME: &str = "aws_role_session_name";
/// Optional credential lifetime in milliseconds; parsed by
/// `parse_duration_millis`, which falls back to the default when the value is
/// absent or not a valid `u64`.
pub const DURATION_MILLIS: &str = "aws_duration_millis";
}
/// Short property keys read by `create_gcp_vendor`; each one is optional and
/// forwarded to the vendor config only when present.
#[cfg(feature = "credential-vendor-gcp")]
pub mod gcp_props {
/// Passed to the config's `with_service_account`.
pub const SERVICE_ACCOUNT: &str = "gcp_service_account";
/// Passed to the config's `with_workload_identity_provider`.
pub const WORKLOAD_IDENTITY_PROVIDER: &str = "gcp_workload_identity_provider";
/// Passed to the config's `with_impersonation_service_account`.
pub const IMPERSONATION_SERVICE_ACCOUNT: &str = "gcp_impersonation_service_account";
}
/// Short property keys read by `create_azure_vendor`.
#[cfg(feature = "credential-vendor-azure")]
pub mod azure_props {
/// Optional: forwarded to the vendor config only when present.
pub const TENANT_ID: &str = "azure_tenant_id";
/// Required: `create_azure_vendor` returns an invalid-input error when absent.
pub const ACCOUNT_NAME: &str = "azure_account_name";
/// Optional credential lifetime in milliseconds; parsed by
/// `parse_duration_millis`, which falls back to the default when the value is
/// absent or not a valid `u64`.
pub const DURATION_MILLIS: &str = "azure_duration_millis";
/// Optional: forwarded to the vendor config only when present.
pub const FEDERATED_CLIENT_ID: &str = "azure_federated_client_id";
}
/// Credentials handed out by a vendor, together with their expiry time.
///
/// `Debug` is implemented by hand so that the contents of `storage_options`
/// never appear in formatted output.
#[derive(Clone)]
pub struct VendedCredentials {
    /// Key/value storage options carrying the credential material.
    pub storage_options: HashMap<String, String>,
    /// Expiry instant in milliseconds since the Unix epoch.
    pub expires_at_millis: u64,
}

impl std::fmt::Debug for VendedCredentials {
    /// Prints the number of storage options instead of their contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let redacted_summary = format!("[{} keys redacted]", self.storage_options.len());
        let mut out = f.debug_struct("VendedCredentials");
        out.field("storage_options", &redacted_summary);
        out.field("expires_at_millis", &self.expires_at_millis);
        out.finish()
    }
}

impl VendedCredentials {
    /// Bundles `storage_options` with their expiry time (milliseconds since
    /// the Unix epoch).
    pub fn new(storage_options: HashMap<String, String>, expires_at_millis: u64) -> Self {
        Self {
            storage_options,
            expires_at_millis,
        }
    }

    /// Returns `true` once the system clock has reached `expires_at_millis`.
    ///
    /// # Panics
    ///
    /// Panics if the system clock reports a time before the Unix epoch.
    pub fn is_expired(&self) -> bool {
        use std::time::{SystemTime, UNIX_EPOCH};

        let since_epoch = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("time went backwards");
        self.expires_at_millis <= since_epoch.as_millis() as u64
    }
}
/// A source of [`VendedCredentials`] for table locations.
///
/// Implementors must be `Send + Sync` and implement `Debug`.
#[async_trait]
pub trait CredentialVendor: Send + Sync + std::fmt::Debug {
/// Produces credentials for `table_location`.
///
/// NOTE(review): `identity` is optional; presumably implementations use it
/// to tailor the credentials to the caller — confirm against the concrete
/// vendors.
async fn vend_credentials(
&self,
table_location: &str,
identity: Option<&Identity>,
) -> Result<VendedCredentials>;
/// Returns a static name for the backing provider.
///
/// NOTE(review): presumably one of the names produced by
/// `detect_provider_from_uri` (`"aws"`, `"gcp"`, `"azure"`) — confirm.
fn provider_name(&self) -> &'static str;
/// Returns the permission level associated with this vendor.
fn permission(&self) -> VendedPermission;
}
/// Maps a storage URI to the name of the cloud provider that serves it.
///
/// Returns `"aws"` for `s3`, `"gcp"` for `gs`, `"azure"` for `az` and `abfss`,
/// and `"unknown"` for every other scheme or when `uri_to_url` cannot turn
/// `uri` into a URL. The tests in this file show that upper-case schemes such
/// as `S3://` are recognised too.
pub fn detect_provider_from_uri(uri: &str) -> &'static str {
    match uri_to_url(uri) {
        Ok(parsed) => match parsed.scheme() {
            "s3" => "aws",
            "gs" => "gcp",
            "az" | "abfss" => "azure",
            _ => "unknown",
        },
        Err(_) => "unknown",
    }
}
/// Reports whether credential vending is switched on in `properties`.
///
/// Returns `true` only when the `ENABLED` key is present and its value equals
/// `"true"` ignoring ASCII case; a missing key or any other value yields
/// `false`.
pub fn has_credential_vendor_config(properties: &HashMap<String, String>) -> bool {
    matches!(
        properties.get(ENABLED),
        Some(flag) if flag.eq_ignore_ascii_case("true")
    )
}
/// Builds the credential vendor matching the provider of `table_location`.
///
/// The provider is derived with `detect_provider_from_uri`. A vendor is only
/// created when the provider is recognised and the corresponding
/// `credential-vendor-*` feature is compiled in; otherwise `Ok(None)` is
/// returned. Unless the `CACHE_ENABLED` property equals `"false"` (ASCII
/// case-insensitive), the vendor is wrapped in `cache::CachingCredentialVendor`.
///
/// # Errors
///
/// Propagates any error from the provider-specific constructor.
// `properties` is unused when no vendor feature is compiled in.
#[allow(unused_variables)]
pub async fn create_credential_vendor_for_location(
    table_location: &str,
    properties: &HashMap<String, String>,
) -> Result<Option<Box<dyn CredentialVendor>>> {
    let base_vendor: Option<Box<dyn CredentialVendor>> =
        match detect_provider_from_uri(table_location) {
            #[cfg(feature = "credential-vendor-aws")]
            "aws" => create_aws_vendor(properties).await?,
            #[cfg(feature = "credential-vendor-gcp")]
            "gcp" => create_gcp_vendor(properties).await?,
            #[cfg(feature = "credential-vendor-azure")]
            "azure" => create_azure_vendor(properties)?,
            _ => None,
        };

    #[cfg(any(
        feature = "credential-vendor-aws",
        feature = "credential-vendor-azure",
        feature = "credential-vendor-gcp"
    ))]
    if let Some(inner) = base_vendor {
        // Caching is opt-out: only an explicit "false" disables it.
        let caching_disabled = matches!(
            properties.get(CACHE_ENABLED),
            Some(flag) if flag.eq_ignore_ascii_case("false")
        );
        let vendor: Box<dyn CredentialVendor> = if caching_disabled {
            inner
        } else {
            Box::new(cache::CachingCredentialVendor::new(inner))
        };
        return Ok(Some(vendor));
    }

    // Without any vendor feature the binding above is otherwise never read.
    #[cfg(not(any(
        feature = "credential-vendor-aws",
        feature = "credential-vendor-azure",
        feature = "credential-vendor-gcp"
    )))]
    let _ = base_vendor;

    Ok(None)
}
/// Reads the `PERMISSION` property as a [`VendedPermission`].
///
/// A missing key or a value that fails to parse yields the default
/// permission rather than an error.
#[cfg(any(
    test,
    feature = "credential-vendor-aws",
    feature = "credential-vendor-azure",
    feature = "credential-vendor-gcp"
))]
fn parse_permission(properties: &HashMap<String, String>) -> VendedPermission {
    match properties
        .get(PERMISSION)
        .map(|raw| raw.parse::<VendedPermission>())
    {
        Some(Ok(permission)) => permission,
        _ => VendedPermission::default(),
    }
}
/// Reads the property `key` as a duration in milliseconds.
///
/// A missing key or a value that is not a valid `u64` (for example negative,
/// empty, or non-numeric text) yields `DEFAULT_CREDENTIAL_DURATION_MILLIS`.
#[cfg(any(
    test,
    feature = "credential-vendor-aws",
    feature = "credential-vendor-azure",
    feature = "credential-vendor-gcp"
))]
fn parse_duration_millis(properties: &HashMap<String, String>, key: &str) -> u64 {
    match properties.get(key).map(|raw| raw.parse::<u64>()) {
        Some(Ok(millis)) => millis,
        _ => DEFAULT_CREDENTIAL_DURATION_MILLIS,
    }
}
/// Builds the AWS credential vendor from `properties`.
///
/// `aws_props::ROLE_ARN` is mandatory. The duration and permission fall back
/// to their defaults when absent or invalid, and the external id, region and
/// role session name are applied only when present.
///
/// # Errors
///
/// Returns an invalid-input error when the role ARN is missing, and
/// propagates any error from `AwsCredentialVendor::new`.
#[cfg(feature = "credential-vendor-aws")]
async fn create_aws_vendor(
    properties: &HashMap<String, String>,
) -> Result<Option<Box<dyn CredentialVendor>>> {
    use aws::{AwsCredentialVendor, AwsCredentialVendorConfig};
    use lance_namespace::error::NamespaceError;

    let Some(role_arn) = properties.get(aws_props::ROLE_ARN) else {
        return Err(lance_core::Error::from(NamespaceError::InvalidInput {
            message: "AWS credential vending requires 'credential_vendor.aws_role_arn' to be set"
                .to_string(),
        }));
    };

    let lifetime_millis = parse_duration_millis(properties, aws_props::DURATION_MILLIS);
    let access_level = parse_permission(properties);

    let config = AwsCredentialVendorConfig::new(role_arn)
        .with_duration_millis(lifetime_millis)
        .with_permission(access_level);

    // Optional settings are layered on only when the property is present.
    let config = match properties.get(aws_props::EXTERNAL_ID) {
        Some(external_id) => config.with_external_id(external_id),
        None => config,
    };
    let config = match properties.get(aws_props::REGION) {
        Some(region) => config.with_region(region),
        None => config,
    };
    let config = match properties.get(aws_props::ROLE_SESSION_NAME) {
        Some(session_name) => config.with_role_session_name(session_name),
        None => config,
    };

    let built = AwsCredentialVendor::new(config).await?;
    Ok(Some(Box::new(built)))
}
/// Builds the GCP credential vendor from `properties`.
///
/// The permission falls back to its default when absent or invalid. The
/// service account, workload identity provider and impersonation service
/// account are applied only when present.
///
/// # Errors
///
/// Propagates any error from `GcpCredentialVendor::new`.
#[cfg(feature = "credential-vendor-gcp")]
async fn create_gcp_vendor(
    properties: &HashMap<String, String>,
) -> Result<Option<Box<dyn CredentialVendor>>> {
    use gcp::{GcpCredentialVendor, GcpCredentialVendorConfig};

    let access_level = parse_permission(properties);
    let config = GcpCredentialVendorConfig::new().with_permission(access_level);

    // Optional settings are layered on only when the property is present.
    let config = match properties.get(gcp_props::SERVICE_ACCOUNT) {
        Some(account) => config.with_service_account(account),
        None => config,
    };
    let config = match properties.get(gcp_props::WORKLOAD_IDENTITY_PROVIDER) {
        Some(identity_provider) => config.with_workload_identity_provider(identity_provider),
        None => config,
    };
    let config = match properties.get(gcp_props::IMPERSONATION_SERVICE_ACCOUNT) {
        Some(impersonated) => config.with_impersonation_service_account(impersonated),
        None => config,
    };

    let built = GcpCredentialVendor::new(config)?;
    Ok(Some(Box::new(built)))
}
/// Builds the Azure credential vendor from `properties`.
///
/// `azure_props::ACCOUNT_NAME` is mandatory. The duration and permission fall
/// back to their defaults when absent or invalid, and the tenant id and
/// federated client id are applied only when present.
///
/// # Errors
///
/// Returns an invalid-input error when the account name is missing.
#[cfg(feature = "credential-vendor-azure")]
fn create_azure_vendor(
    properties: &HashMap<String, String>,
) -> Result<Option<Box<dyn CredentialVendor>>> {
    use azure::{AzureCredentialVendor, AzureCredentialVendorConfig};
    use lance_namespace::error::NamespaceError;

    let Some(account_name) = properties.get(azure_props::ACCOUNT_NAME) else {
        return Err(lance_core::Error::from(NamespaceError::InvalidInput {
            message:
                "Azure credential vending requires 'credential_vendor.azure_account_name' to be set"
                    .to_string(),
        }));
    };

    let lifetime_millis = parse_duration_millis(properties, azure_props::DURATION_MILLIS);
    let access_level = parse_permission(properties);

    let config = AzureCredentialVendorConfig::new()
        .with_account_name(account_name)
        .with_duration_millis(lifetime_millis)
        .with_permission(access_level);

    // Optional settings are layered on only when the property is present.
    let config = match properties.get(azure_props::TENANT_ID) {
        Some(tenant_id) => config.with_tenant_id(tenant_id),
        None => config,
    };
    let config = match properties.get(azure_props::FEDERATED_CLIENT_ID) {
        Some(client_id) => config.with_federated_client_id(client_id),
        None => config,
    };

    let built = AzureCredentialVendor::new(config);
    Ok(Some(Box::new(built)))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_detect_provider_from_uri() {
    // (uri, expected provider name)
    let cases = [
        ("s3://bucket/path", "aws"),
        ("S3://bucket/path", "aws"),
        ("gs://bucket/path", "gcp"),
        ("GS://bucket/path", "gcp"),
        ("az://container/path", "azure"),
        (
            "az://container@account.blob.core.windows.net/path",
            "azure",
        ),
        (
            "abfss://container@account.dfs.core.windows.net/path",
            "azure",
        ),
        // Local paths and unsupported schemes are not attributed to a provider.
        ("/local/path", "unknown"),
        ("file:///local/path", "unknown"),
        ("memory://test", "unknown"),
        ("s3a://bucket/path", "unknown"),
        (
            "wasbs://container@account.blob.core.windows.net/path",
            "unknown",
        ),
    ];

    for &(uri, expected) in cases.iter() {
        assert_eq!(detect_provider_from_uri(uri), expected, "uri: {}", uri);
    }
}
#[test]
fn test_vended_permission_from_str() {
    // Accepted spellings, in mixed letter case.
    let accepted = [
        ("read", VendedPermission::Read),
        ("READ", VendedPermission::Read),
        ("write", VendedPermission::Write),
        ("WRITE", VendedPermission::Write),
        ("admin", VendedPermission::Admin),
        ("Admin", VendedPermission::Admin),
    ];
    for &(text, expected) in accepted.iter() {
        assert_eq!(
            text.parse::<VendedPermission>().unwrap(),
            expected,
            "input: {:?}",
            text
        );
    }

    // Anything else is rejected with a descriptive message.
    for &rejected in ["invalid", "", "readwrite"].iter() {
        let err = rejected.parse::<VendedPermission>().unwrap_err();
        assert!(err.contains("Invalid permission"), "input: {:?}", rejected);
    }

    // The message echoes the offending input.
    let err = "invalid".parse::<VendedPermission>().unwrap_err();
    assert!(err.contains("invalid"));
}
#[test]
fn test_vended_permission_display() {
    // Each variant renders as its lowercase name.
    let rendered = [
        (VendedPermission::Read, "read"),
        (VendedPermission::Write, "write"),
        (VendedPermission::Admin, "admin"),
    ];
    for &(permission, expected) in rendered.iter() {
        assert_eq!(permission.to_string(), expected);
    }
}
#[test]
fn test_parse_permission_with_invalid_values() {
    // Unparseable values fall back to the default permission.
    for &raw in ["invalid", ""].iter() {
        let mut props = HashMap::new();
        props.insert(PERMISSION.to_string(), raw.to_string());
        assert_eq!(
            parse_permission(&props),
            VendedPermission::Read,
            "value: {:?}",
            raw
        );
    }

    // So does a missing key.
    let no_props: HashMap<String, String> = HashMap::new();
    assert_eq!(parse_permission(&no_props), VendedPermission::Read);
}
#[test]
fn test_parse_duration_millis_with_invalid_values() {
    const TEST_KEY: &str = "test_duration_millis";

    // Builds a property map holding a single duration value.
    let props_with = |raw: &str| -> HashMap<String, String> {
        let mut props = HashMap::new();
        props.insert(TEST_KEY.to_string(), raw.to_string());
        props
    };

    // Non-numeric, negative and empty values all fall back to the default.
    for &rejected in ["not_a_number", "-1000", ""].iter() {
        assert_eq!(
            parse_duration_millis(&props_with(rejected), TEST_KEY),
            DEFAULT_CREDENTIAL_DURATION_MILLIS,
            "value: {:?}",
            rejected
        );
    }

    // A missing key falls back as well.
    let no_props: HashMap<String, String> = HashMap::new();
    assert_eq!(
        parse_duration_millis(&no_props, TEST_KEY),
        DEFAULT_CREDENTIAL_DURATION_MILLIS
    );

    // A valid value is returned unchanged.
    assert_eq!(
        parse_duration_millis(&props_with("7200000"), TEST_KEY),
        7200000
    );
}
#[test]
fn test_has_credential_vendor_config() {
    // (value stored under ENABLED, expected result)
    let expectations = [
        ("true", true),
        ("TRUE", true),
        ("false", false),
        ("yes", false),
    ];
    for &(raw, expected) in expectations.iter() {
        let mut props = HashMap::new();
        props.insert(ENABLED.to_string(), raw.to_string());
        assert_eq!(
            has_credential_vendor_config(&props),
            expected,
            "value: {:?}",
            raw
        );
    }

    // A missing key means vending is not configured.
    let no_props: HashMap<String, String> = HashMap::new();
    assert!(!has_credential_vendor_config(&no_props));
}
#[test]
fn test_vended_credentials_debug_redacts_secrets() {
    let secrets = [
        ("aws_access_key_id", "AKIAIOSFODNN7EXAMPLE"),
        (
            "aws_secret_access_key",
            "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
        ),
        ("aws_session_token", "FwoGZXIvYXdzE..."),
    ];
    let storage_options: HashMap<String, String> = secrets
        .iter()
        .map(|&(key, value)| (key.to_string(), value.to_string()))
        .collect();

    let rendered = format!("{:?}", VendedCredentials::new(storage_options, 1234567890));

    // No fragment of any secret value may leak into the debug output.
    for &leaked in ["AKIAIOSFODNN7EXAMPLE", "wJalrXUtnFEMI", "FwoGZXIvYXdzE"].iter() {
        assert!(!rendered.contains(leaked), "leaked: {}", leaked);
    }

    // The summary and the expiry timestamp are still shown.
    for &required in ["redacted", "3 keys", "1234567890"].iter() {
        assert!(rendered.contains(required), "missing: {}", required);
    }
}
#[test]
fn test_vended_credentials_is_expired() {
    // Current wall-clock time in milliseconds since the Unix epoch.
    let now_millis = || {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_millis() as u64
    };

    // One second in the past: already expired.
    let stale = VendedCredentials::new(HashMap::new(), now_millis() - 1000);
    assert!(stale.is_expired());

    // One hour in the future: still valid.
    let fresh = VendedCredentials::new(HashMap::new(), now_millis() + 3600000);
    assert!(!fresh.is_expired());
}
#[test]
fn test_redact_credential() {
    // (input, expected redaction). The original list asserted the AWS access
    // key case twice; each case now appears once.
    let cases = [
        // Long enough to show both ends.
        ("AKIAIOSFODNN7EXAMPLE", "AKIAIOSF***MPLE"),
        ("1234567890123456", "12345678***3456"),
        // Below the 16-character threshold: prefix only.
        ("short1234567", "short123***"),
        ("short123", "short123***"),
        ("tiny", "tiny***"),
        ("ab", "ab***"),
        ("a", "a***"),
        ("", "[empty]"),
        // Realistic GCP access token and Azure SAS token shapes.
        (
            "ya29.a0AfH6SMBx1234567890abcdefghijklmnopqrstuvwxyz",
            "ya29.a0A***wxyz",
        ),
        (
            "sv=2021-06-08&ss=b&srt=sco&sp=rwdlacuiytfx&se=2024-12-31",
            "sv=2021-***2-31",
        ),
    ];

    for &(input, expected) in cases.iter() {
        assert_eq!(redact_credential(input), expected, "input: {:?}", input);
    }
}
}