// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use object_store::aws::{AmazonS3, AmazonS3Builder};
use object_store::azure::{MicrosoftAzure, MicrosoftAzureBuilder};
use object_store::gcp::{GoogleCloudStorage, GoogleCloudStorageBuilder};
use object_store::http::{HttpBuilder, HttpStore};
use object_store::local::LocalFileSystem;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use url::Url;
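
/// Union of the storage context classes defined in this module. Deriving
/// `FromPyObject` lets Rust code accept any one of these Python objects where
/// an object store context is expected.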
#[derive(FromPyObject)]
pub enum StorageContexts {
    AmazonS3(PyAmazonS3Context),
    GoogleCloudStorage(PyGoogleCloudContext),
    MicrosoftAzure(PyMicrosoftAzureContext),
    LocalFileSystem(PyLocalFileSystemContext),
    HTTP(PyHttpContext),
}
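
/// Local filesystem object store exposed to Python as
/// `datafusion.store.LocalFileSystem`. An optional `prefix` roots the store at
/// that directory; otherwise the store is created with no prefix.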
#[pyclass(frozen, name = "LocalFileSystem", module = "datafusion.store", subclass)]
#[derive(Debug, Clone)]
pub struct PyLocalFileSystemContext {
    pub inner: Arc<LocalFileSystem>,
}
#[pymethods]
impl PyLocalFileSystemContext {
    #[pyo3(signature = (prefix=None))]
    #[new]
    fn new(prefix: Option<String>) -> Self {
        if let Some(prefix) = prefix {
            Self {
                inner: Arc::new(
                    LocalFileSystem::new_with_prefix(prefix)
                        .expect("Could not create LocalFileSystem with the given prefix"),
                ),
            }
        } else {
            Self {
                inner: Arc::new(LocalFileSystem::new()),
            }
        }
    }
}
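
/// Azure Blob Storage context exposed to Python as
/// `datafusion.store.MicrosoftAzure`. Settings are read from the environment
/// first (`MicrosoftAzureBuilder::from_env`) and then overridden by any
/// constructor arguments that are provided.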
#[pyclass(frozen, name = "MicrosoftAzure", module = "datafusion.store", subclass)]
#[derive(Debug, Clone)]
pub struct PyMicrosoftAzureContext {
    pub inner: Arc<MicrosoftAzure>,
    pub container_name: String,
}
#[pymethods]
impl PyMicrosoftAzureContext {
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (container_name, account=None, access_key=None, bearer_token=None, client_id=None, client_secret=None, tenant_id=None, sas_query_pairs=None, use_emulator=None, allow_http=None))]
    #[new]
    fn new(
        container_name: String,
        account: Option<String>,
        access_key: Option<String>,
        bearer_token: Option<String>,
        client_id: Option<String>,
        client_secret: Option<String>,
        tenant_id: Option<String>,
        sas_query_pairs: Option<Vec<(String, String)>>,
        use_emulator: Option<bool>,
        allow_http: Option<bool>,
    ) -> Self {
        let mut builder = MicrosoftAzureBuilder::from_env().with_container_name(&container_name);

        if let Some(account) = account {
            builder = builder.with_account(account);
        }

        if let Some(access_key) = access_key {
            builder = builder.with_access_key(access_key);
        }

        if let Some(bearer_token) = bearer_token {
            builder = builder.with_bearer_token_authorization(bearer_token);
        }

        // Client-secret (service principal) authorization requires all three values.
        match (client_id, client_secret, tenant_id) {
            (Some(client_id), Some(client_secret), Some(tenant_id)) => {
                builder =
                    builder.with_client_secret_authorization(client_id, client_secret, tenant_id);
            }
            (None, None, None) => {}
            _ => {
                panic!("client_id, client_secret, and tenant_id must all be set or all be None");
            }
        }

        if let Some(sas_query_pairs) = sas_query_pairs {
            builder = builder.with_sas_authorization(sas_query_pairs);
        }

        if let Some(use_emulator) = use_emulator {
            builder = builder.with_use_emulator(use_emulator);
        }

        if let Some(allow_http) = allow_http {
            builder = builder.with_allow_http(allow_http);
        }

        Self {
            inner: Arc::new(
                builder
                    .build()
                    .expect("Could not create Azure Storage context"), // TODO: change these to PyErr
            ),
            container_name,
        }
    }
}
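
/// Google Cloud Storage context exposed to Python as
/// `datafusion.store.GoogleCloud`. If `service_account_path` is provided it is
/// forwarded to `GoogleCloudStorageBuilder::with_service_account_path`.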
#[pyclass(frozen, name = "GoogleCloud", module = "datafusion.store", subclass)]
#[derive(Debug, Clone)]
pub struct PyGoogleCloudContext {
    pub inner: Arc<GoogleCloudStorage>,
    pub bucket_name: String,
}
#[pymethods]
impl PyGoogleCloudContext {
    #[pyo3(signature = (bucket_name, service_account_path=None))]
    #[new]
    fn new(bucket_name: String, service_account_path: Option<String>) -> Self {
        let mut builder = GoogleCloudStorageBuilder::new().with_bucket_name(&bucket_name);

        if let Some(credential_path) = service_account_path {
            builder = builder.with_service_account_path(credential_path);
        }

        Self {
            inner: Arc::new(
                builder
                    .build()
                    .expect("Could not create Google Cloud Storage context"),
            ),
            bucket_name,
        }
    }
}
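
/// Amazon S3 context exposed to Python as `datafusion.store.AmazonS3`.
/// Configuration starts from the environment (`AmazonS3Builder::from_env`) and
/// is then overridden by any constructor arguments.
///
/// Illustrative Python construction, assuming the class is importable from the
/// `datafusion.store` module declared above (registration with a
/// `SessionContext` happens outside this file):
///
/// ```python
/// from datafusion.store import AmazonS3
///
/// # Credentials omitted here are picked up from the environment on the Rust
/// # side via AmazonS3Builder::from_env.
/// store = AmazonS3(bucket_name="my-bucket", region="us-east-1")
/// ```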
#[pyclass(frozen, name = "AmazonS3", module = "datafusion.store", subclass)]
#[derive(Debug, Clone)]
pub struct PyAmazonS3Context {
    pub inner: Arc<AmazonS3>,
    pub bucket_name: String,
}
#[pymethods]
impl PyAmazonS3Context {
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (bucket_name, region=None, access_key_id=None, secret_access_key=None, endpoint=None, allow_http=false, imdsv1_fallback=false))]
    #[new]
    fn new(
        bucket_name: String,
        region: Option<String>,
        access_key_id: Option<String>,
        secret_access_key: Option<String>,
        endpoint: Option<String>,
        //retry_config: RetryConfig,
        allow_http: bool,
        imdsv1_fallback: bool,
    ) -> Self {
        // Start with the options that come directly from the environment.
        let mut builder = AmazonS3Builder::from_env();

        if let Some(region) = region {
            builder = builder.with_region(region);
        }

        if let Some(access_key_id) = access_key_id {
            builder = builder.with_access_key_id(access_key_id);
        }

        if let Some(secret_access_key) = secret_access_key {
            builder = builder.with_secret_access_key(secret_access_key);
        }

        if let Some(endpoint) = endpoint {
            builder = builder.with_endpoint(endpoint);
        }

        if imdsv1_fallback {
            builder = builder.with_imdsv1_fallback();
        }

        let store = builder
            .with_bucket_name(bucket_name.clone())
            //.with_retry_config(retry_config) // TODO: add later
            .with_allow_http(allow_http)
            .build()
            .expect("failed to build AmazonS3");

        Self {
            inner: Arc::new(store),
            bucket_name,
        }
    }
}
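
/// HTTP(S) object store context exposed to Python as `datafusion.store.Http`.
/// The store is built from the origin of the supplied URL; a URL that fails to
/// parse or build surfaces as a Python `ValueError`.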
#[pyclass(frozen, name = "Http", module = "datafusion.store", subclass)]
#[derive(Debug, Clone)]
pub struct PyHttpContext {
    pub url: String,
    pub store: Arc<HttpStore>,
}
#[pymethods]
impl PyHttpContext {
    #[new]
    fn new(url: String) -> PyResult<Self> {
        let store = match Url::parse(url.as_str()) {
            Ok(url) => HttpBuilder::new()
                .with_url(url.origin().ascii_serialization())
                .build(),
            Err(_) => HttpBuilder::new().build(),
        }
        .map_err(|e| PyValueError::new_err(format!("Error: {}", e)))?;

        Ok(Self {
            url,
            store: Arc::new(store),
        })
    }
}
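
/// Registers the storage context classes with the given Python module.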
pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyAmazonS3Context>()?;
    m.add_class::<PyMicrosoftAzureContext>()?;
    m.add_class::<PyGoogleCloudContext>()?;
    m.add_class::<PyLocalFileSystemContext>()?;
    m.add_class::<PyHttpContext>()?;
    Ok(())
}