datafusion_python/
store.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use pyo3::prelude::*;

use object_store::aws::{AmazonS3, AmazonS3Builder};
use object_store::azure::{MicrosoftAzure, MicrosoftAzureBuilder};
use object_store::gcp::{GoogleCloudStorage, GoogleCloudStorageBuilder};
use object_store::http::{HttpBuilder, HttpStore};
use object_store::local::LocalFileSystem;
use pyo3::exceptions::PyValueError;
use url::Url;

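/// Storage backends that can be passed in from Python when registering an
/// object store. The derived `FromPyObject` implementation tries each variant
/// in order when converting the Python argument.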
#[derive(FromPyObject)]
pub enum StorageContexts {
    AmazonS3(PyAmazonS3Context),
    GoogleCloudStorage(PyGoogleCloudContext),
    MicrosoftAzure(PyMicrosoftAzureContext),
    LocalFileSystem(PyLocalFileSystemContext),
    HTTP(PyHttpContext),
}

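/// Wraps a [`LocalFileSystem`] store so it can be constructed from Python.
/// If `prefix` is given, paths are resolved relative to that directory;
/// otherwise no prefix is applied.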
39#[pyclass(name = "LocalFileSystem", module = "datafusion.store", subclass)]
40#[derive(Debug, Clone)]
41pub struct PyLocalFileSystemContext {
42    pub inner: Arc<LocalFileSystem>,
43}
44
45#[pymethods]
46impl PyLocalFileSystemContext {
47    #[pyo3(signature = (prefix=None))]
48    #[new]
49    fn new(prefix: Option<String>) -> Self {
50        if let Some(prefix) = prefix {
51            Self {
52                inner: Arc::new(
53                    LocalFileSystem::new_with_prefix(prefix)
54                        .expect("Could not create local LocalFileSystem"),
55                ),
56            }
57        } else {
58            Self {
59                inner: Arc::new(LocalFileSystem::new()),
60            }
61        }
62    }
63}
64
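/// Wraps a [`MicrosoftAzure`] store scoped to a single container.
/// Configuration starts from the environment (`MicrosoftAzureBuilder::from_env`)
/// and is then overridden by any explicitly supplied arguments; `client_id`,
/// `client_secret`, and `tenant_id` must either all be set or all be `None`.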
65#[pyclass(name = "MicrosoftAzure", module = "datafusion.store", subclass)]
66#[derive(Debug, Clone)]
67pub struct PyMicrosoftAzureContext {
68    pub inner: Arc<MicrosoftAzure>,
69    pub container_name: String,
70}
71
72#[pymethods]
73impl PyMicrosoftAzureContext {
74    #[allow(clippy::too_many_arguments)]
75    #[pyo3(signature = (container_name, account=None, access_key=None, bearer_token=None, client_id=None, client_secret=None, tenant_id=None, sas_query_pairs=None, use_emulator=None, allow_http=None))]
76    #[new]
77    fn new(
78        container_name: String,
79        account: Option<String>,
80        access_key: Option<String>,
81        bearer_token: Option<String>,
82        client_id: Option<String>,
83        client_secret: Option<String>,
84        tenant_id: Option<String>,
85        sas_query_pairs: Option<Vec<(String, String)>>,
86        use_emulator: Option<bool>,
87        allow_http: Option<bool>,
88    ) -> Self {
89        let mut builder = MicrosoftAzureBuilder::from_env().with_container_name(&container_name);
90
91        if let Some(account) = account {
92            builder = builder.with_account(account);
93        }
94
95        if let Some(access_key) = access_key {
96            builder = builder.with_access_key(access_key);
97        }
98
99        if let Some(bearer_token) = bearer_token {
100            builder = builder.with_bearer_token_authorization(bearer_token);
101        }
102
103        match (client_id, client_secret, tenant_id) {
104            (Some(client_id), Some(client_secret), Some(tenant_id)) => {
105                builder =
106                    builder.with_client_secret_authorization(client_id, client_secret, tenant_id);
107            }
108            (None, None, None) => {}
109            _ => {
110                panic!("client_id, client_secret, tenat_id must be all set or all None");
111            }
112        }
113
114        if let Some(sas_query_pairs) = sas_query_pairs {
115            builder = builder.with_sas_authorization(sas_query_pairs);
116        }
117
118        if let Some(use_emulator) = use_emulator {
119            builder = builder.with_use_emulator(use_emulator);
120        }
121
122        if let Some(allow_http) = allow_http {
123            builder = builder.with_allow_http(allow_http);
124        }
125
126        Self {
127            inner: Arc::new(
128                builder
129                    .build()
130                    .expect("Could not create Azure Storage context"), //TODO: change these to PyErr
131            ),
132            container_name,
133        }
134    }
135}
136
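/// Wraps a [`GoogleCloudStorage`] store scoped to a single bucket, optionally
/// authenticated with a service account JSON key file.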
137#[pyclass(name = "GoogleCloud", module = "datafusion.store", subclass)]
138#[derive(Debug, Clone)]
139pub struct PyGoogleCloudContext {
140    pub inner: Arc<GoogleCloudStorage>,
141    pub bucket_name: String,
142}
143
144#[pymethods]
145impl PyGoogleCloudContext {
146    #[allow(clippy::too_many_arguments)]
147    #[pyo3(signature = (bucket_name, service_account_path=None))]
148    #[new]
149    fn new(bucket_name: String, service_account_path: Option<String>) -> Self {
150        let mut builder = GoogleCloudStorageBuilder::new().with_bucket_name(&bucket_name);
151
152        if let Some(credential_path) = service_account_path {
153            builder = builder.with_service_account_path(credential_path);
154        }
155
156        Self {
157            inner: Arc::new(
158                builder
159                    .build()
160                    .expect("Could not create Google Cloud Storage"),
161            ),
162            bucket_name,
163        }
164    }
165}
166
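/// Wraps an [`AmazonS3`] store scoped to a single bucket. Region, credentials,
/// and endpoint default to whatever `AmazonS3Builder::from_env` picks up from
/// the environment and can be overridden per argument.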
167#[pyclass(name = "AmazonS3", module = "datafusion.store", subclass)]
168#[derive(Debug, Clone)]
169pub struct PyAmazonS3Context {
170    pub inner: Arc<AmazonS3>,
171    pub bucket_name: String,
172}
173
174#[pymethods]
175impl PyAmazonS3Context {
176    #[allow(clippy::too_many_arguments)]
177    #[pyo3(signature = (bucket_name, region=None, access_key_id=None, secret_access_key=None, endpoint=None, allow_http=false, imdsv1_fallback=false))]
178    #[new]
179    fn new(
180        bucket_name: String,
181        region: Option<String>,
182        access_key_id: Option<String>,
183        secret_access_key: Option<String>,
184        endpoint: Option<String>,
185        //retry_config: RetryConfig,
186        allow_http: bool,
187        imdsv1_fallback: bool,
188    ) -> Self {
189        // start w/ the options that come directly from the environment
190        let mut builder = AmazonS3Builder::from_env();
191
192        if let Some(region) = region {
193            builder = builder.with_region(region);
194        }
195
196        if let Some(access_key_id) = access_key_id {
197            builder = builder.with_access_key_id(access_key_id);
198        };
199
200        if let Some(secret_access_key) = secret_access_key {
201            builder = builder.with_secret_access_key(secret_access_key);
202        };
203
204        if let Some(endpoint) = endpoint {
205            builder = builder.with_endpoint(endpoint);
206        };
207
208        if imdsv1_fallback {
209            builder = builder.with_imdsv1_fallback();
210        };
211
212        let store = builder
213            .with_bucket_name(bucket_name.clone())
214            //.with_retry_config(retry_config) #TODO: add later
215            .with_allow_http(allow_http)
216            .build()
217            .expect("failed to build AmazonS3");
218
219        Self {
220            inner: Arc::new(store),
221            bucket_name,
222        }
223    }
224}
225
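/// Wraps an [`HttpStore`] built from the origin (scheme, host, and port) of
/// `url`; the full URL string is kept alongside the store.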
226#[pyclass(name = "Http", module = "datafusion.store", subclass)]
227#[derive(Debug, Clone)]
228pub struct PyHttpContext {
229    pub url: String,
230    pub store: Arc<HttpStore>,
231}
232
233#[pymethods]
234impl PyHttpContext {
235    #[new]
236    fn new(url: String) -> PyResult<Self> {
237        let store = match Url::parse(url.as_str()) {
238            Ok(url) => HttpBuilder::new()
239                .with_url(url.origin().ascii_serialization())
240                .build(),
241            Err(_) => HttpBuilder::new().build(),
242        }
243        .map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string())))?;
244
245        Ok(Self {
246            url,
247            store: Arc::new(store),
248        })
249    }
250}
251
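/// Registers the storage context classes with the Python module.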
pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyAmazonS3Context>()?;
    m.add_class::<PyMicrosoftAzureContext>()?;
    m.add_class::<PyGoogleCloudContext>()?;
    m.add_class::<PyLocalFileSystemContext>()?;
    m.add_class::<PyHttpContext>()?;
    Ok(())
}