use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use iceberg::table::Table;
use iceberg::{
Catalog, Error, ErrorKind, Namespace, NamespaceIdent, Result, TableCommit, TableCreation,
TableIdent, TableRequirement, TableUpdate,
};
use super::CatalogExt;
use crate::crypto::{EncryptedEnvelope, KeyManagementService, KmsError};
/// Catalog wrapper that forwards every operation to an inner catalog and
/// envelope-encrypts table properties through a key management service when
/// a table is created.
pub struct EncryptedCatalog {
/// Wrapped catalog that all catalog operations are delegated to.
inner: Arc<dyn CatalogExt + Send + Sync>,
/// Key management service used to encrypt and decrypt property envelopes.
kms: Arc<dyn KeyManagementService>,
/// Identifier of the KMS key used when encrypting; also the fallback key id
/// when stored properties carry no `__kms_key_id` entry.
key_id: String,
}
/// Manual `Debug` implementation: prints the key id and the KMS provider
/// name, and marks the output as non-exhaustive because the wrapped catalog
/// and the KMS handle themselves are not printed.
impl std::fmt::Debug for EncryptedCatalog {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("EncryptedCatalog");
        out.field("key_id", &self.key_id);
        out.field("kms_provider", &self.kms.provider_name());
        out.finish_non_exhaustive()
    }
}
impl EncryptedCatalog {
pub fn new(
inner: Arc<dyn CatalogExt + Send + Sync>,
kms: Arc<dyn KeyManagementService>,
key_id: String,
) -> Self {
tracing::info!(
key_id = %key_id,
provider = %kms.provider_name(),
"Created EncryptedCatalog wrapper"
);
Self { inner, kms, key_id }
}
pub fn inner(&self) -> &Arc<dyn CatalogExt + Send + Sync> {
&self.inner
}
pub fn key_id(&self) -> &str {
&self.key_id
}
async fn encrypt_properties(
&self,
table: &TableIdent,
properties: &HashMap<String, String>,
) -> Result<HashMap<String, String>> {
let plaintext = serde_json::to_vec(properties).map_err(|e| {
Error::new(
ErrorKind::Unexpected,
format!("Failed to serialize properties: {}", e),
)
})?;
let envelope = EncryptedEnvelope::encrypt(&*self.kms, &self.key_id, &plaintext)
.await
.map_err(|e| convert_kms_error(e, table))?;
let mut encrypted_props = HashMap::new();
encrypted_props.insert(
"__encrypted_properties".to_string(),
base64::Engine::encode(
&base64::engine::general_purpose::STANDARD,
&envelope.ciphertext,
),
);
encrypted_props.insert(
"__wrapped_dek".to_string(),
base64::Engine::encode(
&base64::engine::general_purpose::STANDARD,
&envelope.wrapped_dek,
),
);
encrypted_props.insert("__kms_key_id".to_string(), envelope.key_id);
tracing::debug!(
table = %table,
"Encrypted table properties"
);
Ok(encrypted_props)
}
#[allow(dead_code)] async fn decrypt_properties(
&self,
table: &TableIdent,
properties: &HashMap<String, String>,
) -> Result<HashMap<String, String>> {
let encrypted_data = match properties.get("__encrypted_properties") {
Some(data) => data,
None => {
return Ok(properties.clone());
}
};
let wrapped_dek = properties.get("__wrapped_dek").ok_or_else(|| {
Error::new(
ErrorKind::DataInvalid,
"Missing __wrapped_dek for encrypted table",
)
})?;
let key_id = properties
.get("__kms_key_id")
.cloned()
.unwrap_or_else(|| self.key_id.clone());
let ciphertext =
base64::Engine::decode(&base64::engine::general_purpose::STANDARD, encrypted_data)
.map_err(|e| {
Error::new(
ErrorKind::DataInvalid,
format!("Invalid base64 in __encrypted_properties: {}", e),
)
})?;
let wrapped_dek_bytes =
base64::Engine::decode(&base64::engine::general_purpose::STANDARD, wrapped_dek)
.map_err(|e| {
Error::new(
ErrorKind::DataInvalid,
format!("Invalid base64 in __wrapped_dek: {}", e),
)
})?;
let envelope = EncryptedEnvelope {
wrapped_dek: wrapped_dek_bytes,
ciphertext,
key_id,
};
let plaintext = envelope
.decrypt(&*self.kms)
.await
.map_err(|e| convert_kms_error(e, table))?;
let decrypted: HashMap<String, String> =
serde_json::from_slice(&plaintext).map_err(|e| {
Error::new(
ErrorKind::DataInvalid,
format!("Failed to deserialize decrypted properties: {}", e),
)
})?;
tracing::debug!(
table = %table,
"Decrypted table properties"
);
Ok(decrypted)
}
}
fn convert_kms_error(err: KmsError, table: &TableIdent) -> Error {
match err {
KmsError::KeyNotFound(msg) => Error::new(
ErrorKind::Unexpected,
format!("KMS key not found for table {}: {}", table, msg),
),
KmsError::AuthenticationFailed(msg) => Error::new(
ErrorKind::Unexpected,
format!("KMS authentication failed for table {}: {}", table, msg),
),
KmsError::RateLimited(ms) => Error::new(
ErrorKind::Unexpected,
format!("KMS rate limited for table {}: retry after {}ms", table, ms),
),
KmsError::ServiceUnavailable(msg) => Error::new(
ErrorKind::Unexpected,
format!("KMS unavailable for table {}: {}", table, msg),
),
_ => Error::new(
ErrorKind::Unexpected,
format!("KMS operation failed for table {}: {}", table, err),
),
}
}
/// `Catalog` implementation that delegates to the wrapped catalog.
///
/// Only `create_table` alters its input (it encrypts non-empty table
/// properties); every other method forwards its arguments unchanged.
#[async_trait]
impl Catalog for EncryptedCatalog {
    /// Forwards to the inner catalog.
    async fn list_namespaces(
        &self,
        parent: Option<&NamespaceIdent>,
    ) -> Result<Vec<NamespaceIdent>> {
        self.inner.list_namespaces(parent).await
    }

    /// Forwards to the inner catalog; namespace properties are not encrypted.
    async fn create_namespace(
        &self,
        namespace: &NamespaceIdent,
        properties: HashMap<String, String>,
    ) -> Result<Namespace> {
        self.inner.create_namespace(namespace, properties).await
    }

    /// Forwards to the inner catalog.
    async fn get_namespace(&self, namespace: &NamespaceIdent) -> Result<Namespace> {
        self.inner.get_namespace(namespace).await
    }

    /// Forwards to the inner catalog.
    async fn namespace_exists(&self, namespace: &NamespaceIdent) -> Result<bool> {
        self.inner.namespace_exists(namespace).await
    }

    /// Forwards to the inner catalog; namespace properties are not encrypted.
    async fn update_namespace(
        &self,
        namespace: &NamespaceIdent,
        properties: HashMap<String, String>,
    ) -> Result<()> {
        self.inner.update_namespace(namespace, properties).await
    }

    /// Forwards to the inner catalog.
    async fn drop_namespace(&self, namespace: &NamespaceIdent) -> Result<()> {
        self.inner.drop_namespace(namespace).await
    }

    /// Forwards to the inner catalog.
    async fn list_tables(&self, namespace: &NamespaceIdent) -> Result<Vec<TableIdent>> {
        self.inner.list_tables(namespace).await
    }

    /// Creates a table in the inner catalog, encrypting its properties first.
    ///
    /// When `creation.properties` is empty the request is forwarded as-is.
    /// Otherwise the properties are replaced by the encrypted envelope
    /// produced by `encrypt_properties`; all other fields are passed through.
    async fn create_table(
        &self,
        namespace: &NamespaceIdent,
        creation: TableCreation,
    ) -> Result<Table> {
        // Nothing to protect: hand the request straight to the inner catalog.
        if creation.properties.is_empty() {
            return self.inner.create_table(namespace, creation).await;
        }

        let table_ident = TableIdent::new(namespace.clone(), creation.name.clone());
        let properties = self
            .encrypt_properties(&table_ident, &creation.properties)
            .await?;
        let encrypted_creation = TableCreation {
            properties,
            ..creation
        };
        self.inner.create_table(namespace, encrypted_creation).await
    }

    /// Returns the table exactly as loaded by the inner catalog.
    ///
    /// NOTE(review): no decryption happens here, so a table created through
    /// `create_table` is returned with its properties still in encrypted
    /// envelope form — confirm whether callers are expected to decrypt.
    async fn load_table(&self, table: &TableIdent) -> Result<Table> {
        self.inner.load_table(table).await
    }

    /// Forwards to the inner catalog.
    async fn drop_table(&self, table: &TableIdent) -> Result<()> {
        self.inner.drop_table(table).await
    }

    /// Forwards to the inner catalog.
    async fn table_exists(&self, table: &TableIdent) -> Result<bool> {
        self.inner.table_exists(table).await
    }

    /// Forwards to the inner catalog.
    async fn rename_table(&self, src: &TableIdent, dest: &TableIdent) -> Result<()> {
        self.inner.rename_table(src, dest).await
    }

    /// Forwards the commit to the inner catalog unchanged.
    ///
    /// NOTE(review): the commit is not inspected, so any property changes it
    /// carries are not encrypted by this wrapper — confirm this is intended.
    async fn update_table(&self, commit: TableCommit) -> Result<Table> {
        self.inner.update_table(commit).await
    }

    /// Forwards to the inner catalog.
    async fn register_table(&self, table: &TableIdent, metadata_location: String) -> Result<Table> {
        self.inner.register_table(table, metadata_location).await
    }
}
/// `CatalogExt` implementation: both operations are forwarded to the wrapped
/// catalog with their arguments unchanged.
#[async_trait]
impl CatalogExt for EncryptedCatalog {
    /// Forwards the commit (requirements and updates) to the inner catalog.
    async fn commit_table(
        &self,
        table_ident: &TableIdent,
        requirements: Vec<TableRequirement>,
        updates: Vec<TableUpdate>,
    ) -> Result<Table> {
        let delegate = &self.inner;
        delegate
            .commit_table(table_ident, requirements, updates)
            .await
    }

    /// Forwards the metadata-location update to the inner catalog.
    async fn update_table_metadata_location(
        &self,
        table_ident: &TableIdent,
        new_metadata_location: String,
    ) -> Result<Table> {
        let delegate = &self.inner;
        delegate
            .update_table_metadata_location(table_ident, new_metadata_location)
            .await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::catalog::ExtendedCatalog;
    use crate::crypto::EnvKeyProvider;
    use crate::utils::temp_path;
    use iceberg::{memory::MemoryCatalogBuilder, CatalogBuilder};
    use std::collections::HashMap;

    /// Builds an in-memory catalog whose `warehouse` property is `temp_path()`.
    async fn create_test_catalog() -> Arc<dyn CatalogExt + Send + Sync> {
        let mut props = HashMap::new();
        props.insert("warehouse".to_string(), temp_path());
        let memory_catalog = MemoryCatalogBuilder::default()
            .load("memory", props)
            .await
            .unwrap();
        Arc::new(ExtendedCatalog::new(memory_catalog))
    }

    /// Builds an `EncryptedCatalog` over the in-memory catalog, backed by an
    /// `EnvKeyProvider` holding an all-zero 32-byte key.
    async fn create_test_wrapper(key_id: &str) -> EncryptedCatalog {
        let kms: Arc<dyn KeyManagementService> =
            Arc::new(EnvKeyProvider::with_key(vec![0u8; 32]).unwrap());
        EncryptedCatalog::new(create_test_catalog().await, kms, key_id.to_string())
    }

    /// Table identifier shared by the property tests.
    fn test_table() -> TableIdent {
        TableIdent::new(
            NamespaceIdent::from_vec(vec!["test_ns".to_string()]).unwrap(),
            "test_table".to_string(),
        )
    }

    // Uses `#[tokio::test]` like the other async tests instead of building a
    // runtime by hand.
    #[tokio::test]
    async fn test_encrypted_catalog_debug() {
        let encrypted = create_test_wrapper("test").await;
        let debug = format!("{:?}", encrypted);
        assert!(debug.contains("EncryptedCatalog"));
        assert!(debug.contains("key_id"));
        assert!(debug.contains("kms_provider"));
    }

    #[tokio::test]
    async fn test_encrypt_decrypt_properties() {
        let wrapper = create_test_wrapper("test-key").await;
        let table = test_table();

        let mut original_props = HashMap::new();
        original_props.insert("key1".to_string(), "value1".to_string());
        original_props.insert("secret".to_string(), "sensitive_data".to_string());

        let encrypted = wrapper
            .encrypt_properties(&table, &original_props)
            .await
            .unwrap();
        assert!(encrypted.contains_key("__encrypted_properties"));
        assert!(encrypted.contains_key("__wrapped_dek"));
        assert!(encrypted.contains_key("__kms_key_id"));
        // Plaintext keys must not leak into the encrypted map.
        assert!(!encrypted.contains_key("key1"));
        assert!(!encrypted.contains_key("secret"));

        let decrypted = wrapper
            .decrypt_properties(&table, &encrypted)
            .await
            .unwrap();
        assert_eq!(decrypted, original_props);
    }

    #[tokio::test]
    async fn test_unencrypted_properties_passthrough() {
        let wrapper = create_test_wrapper("test-key").await;
        let table = test_table();

        let mut props = HashMap::new();
        props.insert("normal_key".to_string(), "normal_value".to_string());

        let result = wrapper.decrypt_properties(&table, &props).await.unwrap();
        assert_eq!(result, props);
    }

    #[tokio::test]
    async fn test_decrypt_missing_wrapped_dek_fails() {
        let wrapper = create_test_wrapper("test-key").await;
        let table = test_table();

        // The encrypted marker is present but the wrapped key is not.
        let mut props = HashMap::new();
        props.insert("__encrypted_properties".to_string(), "AAAA".to_string());

        let result = wrapper.decrypt_properties(&table, &props).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_decrypt_invalid_base64_fails() {
        let wrapper = create_test_wrapper("test-key").await;
        let table = test_table();

        // '!' and ' ' are outside the standard base64 alphabet.
        let mut props = HashMap::new();
        props.insert(
            "__encrypted_properties".to_string(),
            "not base64!!".to_string(),
        );
        props.insert("__wrapped_dek".to_string(), "AAAA".to_string());

        let result = wrapper.decrypt_properties(&table, &props).await;
        assert!(result.is_err());
    }
}