use arrow::datatypes::SchemaRef;
use async_trait::async_trait;
use dashmap::DashMap;
use datafusion_common::config::EncryptionFactoryOptions;
use datafusion_common::error::Result;
use datafusion_common::internal_datafusion_err;
use object_store::path::Path;
use parquet::encryption::decrypt::FileDecryptionProperties;
use parquet::encryption::encrypt::FileEncryptionProperties;
use std::sync::Arc;
/// Pluggable source of Parquet file encryption and decryption properties.
///
/// Implementations must be `Send + Sync + Debug + 'static`. The registry in
/// this file stores them as `Arc<dyn EncryptionFactory>` trait objects, and
/// both methods are `async` via `#[async_trait]`.
#[async_trait]
pub trait EncryptionFactory: Send + Sync + std::fmt::Debug + 'static {
/// Produces the encryption properties for the file at `file_path`.
///
/// Receives the factory options in `config` and the Arrow `schema`.
/// NOTE(review): presumably `schema` is the schema of the data being
/// written and `Ok(None)` means "do not encrypt this file" — confirm
/// against the callers.
///
/// # Errors
///
/// Whatever error the implementation reports while building the properties.
async fn get_file_encryption_properties(
&self,
config: &EncryptionFactoryOptions,
schema: &SchemaRef,
file_path: &Path,
) -> Result<Option<Arc<FileEncryptionProperties>>>;
/// Produces the decryption properties for the file at `file_path`.
///
/// Receives the factory options in `config`; unlike the encryption
/// counterpart, no schema is passed.
/// NOTE(review): presumably `Ok(None)` means the file is read without
/// decryption — confirm against the callers.
///
/// # Errors
///
/// Whatever error the implementation reports while building the properties.
async fn get_file_decryption_properties(
&self,
config: &EncryptionFactoryOptions,
file_path: &Path,
) -> Result<Option<Arc<FileDecryptionProperties>>>;
}
/// Registry of [`EncryptionFactory`] implementations keyed by a string id.
///
/// `Default` yields a registry with no factories registered.
#[derive(Clone, Debug, Default)]
pub struct EncryptionFactoryRegistry {
// Maps a factory id to the shared factory instance registered under it.
factories: DashMap<String, Arc<dyn EncryptionFactory>>,
}
impl EncryptionFactoryRegistry {
    /// Stores `factory` under `id`, replacing any factory already registered
    /// with the same id.
    ///
    /// Returns whatever the underlying map's `insert` yields: the factory
    /// that was previously registered under `id`, or `None` if there was none.
    pub fn register_factory(
        &self,
        id: &str,
        factory: Arc<dyn EncryptionFactory>,
    ) -> Option<Arc<dyn EncryptionFactory>> {
        // The map owns its keys, so take an owned copy of the id.
        let key = String::from(id);
        self.factories.insert(key, factory)
    }

    /// Looks up the factory registered under `id` and returns a new shared
    /// handle to it.
    ///
    /// # Errors
    ///
    /// Returns an internal DataFusion error naming `id` when no factory is
    /// registered under it.
    pub fn get_factory(&self, id: &str) -> Result<Arc<dyn EncryptionFactory>> {
        match self.factories.get(id) {
            // Clone the `Arc` out of the map entry so the entry guard is
            // released before the handle reaches the caller.
            Some(entry) => Ok(Arc::clone(entry.value())),
            None => Err(internal_datafusion_err!(
                "No Parquet encryption factory found for id '{id}'"
            )),
        }
    }
}