Skip to main content

datafusion_python/
store.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use object_store::aws::{AmazonS3, AmazonS3Builder};
21use object_store::azure::{MicrosoftAzure, MicrosoftAzureBuilder};
22use object_store::gcp::{GoogleCloudStorage, GoogleCloudStorageBuilder};
23use object_store::http::{HttpBuilder, HttpStore};
24use object_store::local::LocalFileSystem;
25use pyo3::exceptions::PyValueError;
26use pyo3::prelude::*;
27use url::Url;
28
29#[derive(FromPyObject)]
30pub enum StorageContexts {
31    AmazonS3(PyAmazonS3Context),
32    GoogleCloudStorage(PyGoogleCloudContext),
33    MicrosoftAzure(PyMicrosoftAzureContext),
34    LocalFileSystem(PyLocalFileSystemContext),
35    HTTP(PyHttpContext),
36}
37
38#[pyclass(
39    from_py_object,
40    frozen,
41    name = "LocalFileSystem",
42    module = "datafusion.store",
43    subclass
44)]
45#[derive(Debug, Clone)]
46pub struct PyLocalFileSystemContext {
47    pub inner: Arc<LocalFileSystem>,
48}
49
50#[pymethods]
51impl PyLocalFileSystemContext {
52    #[pyo3(signature = (prefix=None))]
53    #[new]
54    fn new(prefix: Option<String>) -> Self {
55        if let Some(prefix) = prefix {
56            Self {
57                inner: Arc::new(
58                    LocalFileSystem::new_with_prefix(prefix)
59                        .expect("Could not create local LocalFileSystem"),
60                ),
61            }
62        } else {
63            Self {
64                inner: Arc::new(LocalFileSystem::new()),
65            }
66        }
67    }
68}
69
70#[pyclass(
71    from_py_object,
72    frozen,
73    name = "MicrosoftAzure",
74    module = "datafusion.store",
75    subclass
76)]
77#[derive(Debug, Clone)]
78pub struct PyMicrosoftAzureContext {
79    pub inner: Arc<MicrosoftAzure>,
80    pub container_name: String,
81}
82
83#[pymethods]
84impl PyMicrosoftAzureContext {
85    #[allow(clippy::too_many_arguments)]
86    #[pyo3(signature = (container_name, account=None, access_key=None, bearer_token=None, client_id=None, client_secret=None, tenant_id=None, sas_query_pairs=None, use_emulator=None, allow_http=None, use_fabric_endpoint=None))]
87    #[new]
88    fn new(
89        container_name: String,
90        account: Option<String>,
91        access_key: Option<String>,
92        bearer_token: Option<String>,
93        client_id: Option<String>,
94        client_secret: Option<String>,
95        tenant_id: Option<String>,
96        sas_query_pairs: Option<Vec<(String, String)>>,
97        use_emulator: Option<bool>,
98        allow_http: Option<bool>,
99        use_fabric_endpoint: Option<bool>,
100    ) -> Self {
101        let mut builder = MicrosoftAzureBuilder::from_env().with_container_name(&container_name);
102
103        if let Some(account) = account {
104            builder = builder.with_account(account);
105        }
106
107        if let Some(access_key) = access_key {
108            builder = builder.with_access_key(access_key);
109        }
110
111        if let Some(bearer_token) = bearer_token {
112            builder = builder.with_bearer_token_authorization(bearer_token);
113        }
114
115        match (client_id, client_secret, tenant_id) {
116            (Some(client_id), Some(client_secret), Some(tenant_id)) => {
117                builder =
118                    builder.with_client_secret_authorization(client_id, client_secret, tenant_id);
119            }
120            (None, None, None) => {}
121            _ => {
122                panic!("client_id, client_secret, tenat_id must be all set or all None");
123            }
124        }
125
126        if let Some(sas_query_pairs) = sas_query_pairs {
127            builder = builder.with_sas_authorization(sas_query_pairs);
128        }
129
130        if let Some(use_emulator) = use_emulator {
131            builder = builder.with_use_emulator(use_emulator);
132        }
133
134        if let Some(allow_http) = allow_http {
135            builder = builder.with_allow_http(allow_http);
136        }
137
138        if let Some(use_fabric_endpoint) = use_fabric_endpoint {
139            builder = builder.with_use_fabric_endpoint(use_fabric_endpoint);
140        }
141
142        Self {
143            inner: Arc::new(
144                builder
145                    .build()
146                    .expect("Could not create Azure Storage context"), //TODO: change these to PyErr
147            ),
148            container_name,
149        }
150    }
151}
152
153#[pyclass(
154    from_py_object,
155    frozen,
156    name = "GoogleCloud",
157    module = "datafusion.store",
158    subclass
159)]
160#[derive(Debug, Clone)]
161pub struct PyGoogleCloudContext {
162    pub inner: Arc<GoogleCloudStorage>,
163    pub bucket_name: String,
164}
165
166#[pymethods]
167impl PyGoogleCloudContext {
168    #[allow(clippy::too_many_arguments)]
169    #[pyo3(signature = (bucket_name, service_account_path=None))]
170    #[new]
171    fn new(bucket_name: String, service_account_path: Option<String>) -> Self {
172        let mut builder = GoogleCloudStorageBuilder::new().with_bucket_name(&bucket_name);
173
174        if let Some(credential_path) = service_account_path {
175            builder = builder.with_service_account_path(credential_path);
176        }
177
178        Self {
179            inner: Arc::new(
180                builder
181                    .build()
182                    .expect("Could not create Google Cloud Storage"),
183            ),
184            bucket_name,
185        }
186    }
187}
188
189#[pyclass(
190    from_py_object,
191    frozen,
192    name = "AmazonS3",
193    module = "datafusion.store",
194    subclass
195)]
196#[derive(Debug, Clone)]
197pub struct PyAmazonS3Context {
198    pub inner: Arc<AmazonS3>,
199    pub bucket_name: String,
200}
201
202#[pymethods]
203impl PyAmazonS3Context {
204    #[allow(clippy::too_many_arguments)]
205    #[pyo3(signature = (bucket_name, region=None, access_key_id=None, secret_access_key=None, session_token=None, endpoint=None, allow_http=false, imdsv1_fallback=false))]
206    #[new]
207    fn new(
208        bucket_name: String,
209        region: Option<String>,
210        access_key_id: Option<String>,
211        secret_access_key: Option<String>,
212        session_token: Option<String>,
213        endpoint: Option<String>,
214        //retry_config: RetryConfig,
215        allow_http: bool,
216        imdsv1_fallback: bool,
217    ) -> Self {
218        // start w/ the options that come directly from the environment
219        let mut builder = AmazonS3Builder::from_env();
220
221        if let Some(region) = region {
222            builder = builder.with_region(region);
223        }
224
225        if let Some(access_key_id) = access_key_id {
226            builder = builder.with_access_key_id(access_key_id);
227        };
228
229        if let Some(secret_access_key) = secret_access_key {
230            builder = builder.with_secret_access_key(secret_access_key);
231        };
232
233        if let Some(session_token) = session_token {
234            builder = builder.with_token(session_token);
235        }
236
237        if let Some(endpoint) = endpoint {
238            builder = builder.with_endpoint(endpoint);
239        };
240
241        if imdsv1_fallback {
242            builder = builder.with_imdsv1_fallback();
243        };
244
245        let store = builder
246            .with_bucket_name(bucket_name.clone())
247            //.with_retry_config(retry_config) #TODO: add later
248            .with_allow_http(allow_http)
249            .build()
250            .expect("failed to build AmazonS3");
251
252        Self {
253            inner: Arc::new(store),
254            bucket_name,
255        }
256    }
257}
258
259#[pyclass(
260    from_py_object,
261    frozen,
262    name = "Http",
263    module = "datafusion.store",
264    subclass
265)]
266#[derive(Debug, Clone)]
267pub struct PyHttpContext {
268    pub url: String,
269    pub store: Arc<HttpStore>,
270}
271
272#[pymethods]
273impl PyHttpContext {
274    #[new]
275    fn new(url: String) -> PyResult<Self> {
276        let store = match Url::parse(url.as_str()) {
277            Ok(url) => HttpBuilder::new()
278                .with_url(url.origin().ascii_serialization())
279                .build(),
280            Err(_) => HttpBuilder::new().build(),
281        }
282        .map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string())))?;
283
284        Ok(Self {
285            url,
286            store: Arc::new(store),
287        })
288    }
289}
290
291pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
292    m.add_class::<PyAmazonS3Context>()?;
293    m.add_class::<PyMicrosoftAzureContext>()?;
294    m.add_class::<PyGoogleCloudContext>()?;
295    m.add_class::<PyLocalFileSystemContext>()?;
296    m.add_class::<PyHttpContext>()?;
297    Ok(())
298}