use std::sync::Arc;
use pyo3::prelude::*;
use object_store::aws::{AmazonS3, AmazonS3Builder};
use object_store::azure::{MicrosoftAzure, MicrosoftAzureBuilder};
use object_store::gcp::{GoogleCloudStorage, GoogleCloudStorageBuilder};
use object_store::local::LocalFileSystem;
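
/// Storage backends accepted from Python. `#[derive(FromPyObject)]` tries to
/// extract each variant in declaration order and uses the first that matches.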
#[derive(FromPyObject)]
pub enum StorageContexts {
AmazonS3(PyAmazonS3Context),
GoogleCloudStorage(PyGoogleCloudContext),
MicrosoftAzure(PyMicrosoftAzureContext),
LocalFileSystem(PyLocalFileSystemContext),
}
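
/// Wraps an `object_store` [`LocalFileSystem`], exposed to Python as
/// `datafusion.store.LocalFileSystem`.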
#[pyclass(name = "LocalFileSystem", module = "datafusion.store", subclass)]
#[derive(Debug, Clone)]
pub struct PyLocalFileSystemContext {
pub inner: Arc<LocalFileSystem>,
}
#[pymethods]
impl PyLocalFileSystemContext {
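    /// Create a local filesystem store, optionally rooted at `prefix`.
    /// Panics if `prefix` is set but cannot be used as a root directory.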
#[pyo3(signature = (prefix=None))]
#[new]
fn new(prefix: Option<String>) -> Self {
if let Some(prefix) = prefix {
Self {
inner: Arc::new(
LocalFileSystem::new_with_prefix(prefix)
.expect("Could not create local LocalFileSystem"),
),
}
} else {
Self {
inner: Arc::new(LocalFileSystem::new()),
}
}
}
}
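
/// Wraps an `object_store` [`MicrosoftAzure`] store for one container,
/// exposed to Python as `datafusion.store.MicrosoftAzure`.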
#[pyclass(name = "MicrosoftAzure", module = "datafusion.store", subclass)]
#[derive(Debug, Clone)]
pub struct PyMicrosoftAzureContext {
pub inner: Arc<MicrosoftAzure>,
pub container_name: String,
}
#[pymethods]
impl PyMicrosoftAzureContext {
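    /// Build an Azure Blob Storage context for `container_name`. Explicit
    /// arguments take precedence over environment configuration.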
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (container_name, account=None, access_key=None, bearer_token=None, client_id=None, client_secret=None, tenant_id=None, sas_query_pairs=None, use_emulator=None, allow_http=None))]
#[new]
fn new(
container_name: String,
account: Option<String>,
access_key: Option<String>,
bearer_token: Option<String>,
client_id: Option<String>,
client_secret: Option<String>,
tenant_id: Option<String>,
sas_query_pairs: Option<Vec<(String, String)>>,
use_emulator: Option<bool>,
allow_http: Option<bool>,
) -> Self {
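        // Seed the builder from AZURE_* environment variables, then layer any
        // explicit arguments on top.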
let mut builder = MicrosoftAzureBuilder::from_env().with_container_name(&container_name);
if let Some(account) = account {
builder = builder.with_account(account);
}
if let Some(access_key) = access_key {
builder = builder.with_access_key(access_key);
}
if let Some(bearer_token) = bearer_token {
builder = builder.with_bearer_token_authorization(bearer_token);
}
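        // Client-secret (service principal) authorization needs all three
        // values; a partial set is rejected rather than silently ignored.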
match (client_id, client_secret, tenant_id) {
(Some(client_id), Some(client_secret), Some(tenant_id)) => {
builder =
builder.with_client_secret_authorization(client_id, client_secret, tenant_id);
}
(None, None, None) => {}
_ => {
panic!("client_id, client_secret, tenat_id must be all set or all None");
}
}
if let Some(sas_query_pairs) = sas_query_pairs {
builder = builder.with_sas_authorization(sas_query_pairs);
}
if let Some(use_emulator) = use_emulator {
builder = builder.with_use_emulator(use_emulator);
}
if let Some(allow_http) = allow_http {
builder = builder.with_allow_http(allow_http);
}
Self {
inner: Arc::new(
builder
.build()
.expect("Could not create Azure Storage context"), ),
container_name,
}
}
}
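
/// Wraps an `object_store` [`GoogleCloudStorage`] store for one bucket,
/// exposed to Python as `datafusion.store.GoogleCloud`.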
#[pyclass(name = "GoogleCloud", module = "datafusion.store", subclass)]
#[derive(Debug, Clone)]
pub struct PyGoogleCloudContext {
pub inner: Arc<GoogleCloudStorage>,
pub bucket_name: String,
}
#[pymethods]
impl PyGoogleCloudContext {
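    /// Build a Google Cloud Storage context for `bucket_name`, optionally
    /// authenticating with the service account key file at
    /// `service_account_path`.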
#[pyo3(signature = (bucket_name, service_account_path=None))]
#[new]
fn new(bucket_name: String, service_account_path: Option<String>) -> Self {
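        // Without an explicit key file the builder falls back to its default
        // credential discovery when the store is built.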
let mut builder = GoogleCloudStorageBuilder::new().with_bucket_name(&bucket_name);
if let Some(credential_path) = service_account_path {
builder = builder.with_service_account_path(credential_path);
}
Self {
inner: Arc::new(
builder
.build()
.expect("Could not create Google Cloud Storage"),
),
bucket_name,
}
}
}
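
/// Wraps an `object_store` [`AmazonS3`] store for one bucket, exposed to
/// Python as `datafusion.store.AmazonS3`.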
#[pyclass(name = "AmazonS3", module = "datafusion.store", subclass)]
#[derive(Debug, Clone)]
pub struct PyAmazonS3Context {
pub inner: Arc<AmazonS3>,
pub bucket_name: String,
}
#[pymethods]
impl PyAmazonS3Context {
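    /// Build an Amazon S3 context for `bucket_name`. Explicit arguments take
    /// precedence over environment configuration.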
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (bucket_name, region=None, access_key_id=None, secret_access_key=None, endpoint=None, allow_http=false, imdsv1_fallback=false))]
#[new]
fn new(
bucket_name: String,
region: Option<String>,
access_key_id: Option<String>,
secret_access_key: Option<String>,
endpoint: Option<String>,
allow_http: bool,
imdsv1_fallback: bool,
) -> Self {
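        // Seed the builder from AWS_* environment variables (e.g.
        // AWS_ACCESS_KEY_ID, AWS_DEFAULT_REGION), then layer explicit
        // arguments on top.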
let mut builder = AmazonS3Builder::from_env();
if let Some(region) = region {
builder = builder.with_region(region);
}
if let Some(access_key_id) = access_key_id {
builder = builder.with_access_key_id(access_key_id);
        }
if let Some(secret_access_key) = secret_access_key {
builder = builder.with_secret_access_key(secret_access_key);
        }
if let Some(endpoint) = endpoint {
builder = builder.with_endpoint(endpoint);
        }
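        // Allow instance-credential requests to fall back to the older IMDSv1
        // protocol when IMDSv2 is unavailable.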
if imdsv1_fallback {
builder = builder.with_imdsv1_fallback();
        }
let store = builder
.with_bucket_name(bucket_name.clone())
.with_allow_http(allow_http)
.build()
.expect("failed to build AmazonS3");
Self {
inner: Arc::new(store),
bucket_name,
}
}
}
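
/// Register the storage context classes on the given Python module.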
pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
m.add_class::<PyAmazonS3Context>()?;
m.add_class::<PyMicrosoftAzureContext>()?;
m.add_class::<PyGoogleCloudContext>()?;
m.add_class::<PyLocalFileSystemContext>()?;
Ok(())
}