Skip to main content

datafusion_python/
store.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use object_store::aws::{AmazonS3, AmazonS3Builder};
21use object_store::azure::{MicrosoftAzure, MicrosoftAzureBuilder};
22use object_store::gcp::{GoogleCloudStorage, GoogleCloudStorageBuilder};
23use object_store::http::{HttpBuilder, HttpStore};
24use object_store::local::LocalFileSystem;
25use pyo3::exceptions::PyValueError;
26use pyo3::prelude::*;
27use url::Url;
28
/// Tagged union of every storage context type accepted from Python.
///
/// `FromPyObject` attempts extraction against each variant in declaration
/// order, so the first context class that matches the Python object wins —
/// do not reorder variants casually.
#[derive(FromPyObject)]
pub enum StorageContexts {
    AmazonS3(PyAmazonS3Context),
    GoogleCloudStorage(PyGoogleCloudContext),
    MicrosoftAzure(PyMicrosoftAzureContext),
    LocalFileSystem(PyLocalFileSystemContext),
    HTTP(PyHttpContext),
}
37
/// Python-visible wrapper (`datafusion.store.LocalFileSystem`) around an
/// `object_store` local filesystem store.
#[pyclass(
    frozen,
    name = "LocalFileSystem",
    module = "datafusion.store",
    subclass
)]
#[derive(Debug, Clone)]
pub struct PyLocalFileSystemContext {
    // Shared handle to the underlying store; cloned cheaply via the Arc.
    pub inner: Arc<LocalFileSystem>,
}
48
49#[pymethods]
50impl PyLocalFileSystemContext {
51    #[pyo3(signature = (prefix=None))]
52    #[new]
53    fn new(prefix: Option<String>) -> Self {
54        if let Some(prefix) = prefix {
55            Self {
56                inner: Arc::new(
57                    LocalFileSystem::new_with_prefix(prefix)
58                        .expect("Could not create local LocalFileSystem"),
59                ),
60            }
61        } else {
62            Self {
63                inner: Arc::new(LocalFileSystem::new()),
64            }
65        }
66    }
67}
68
/// Python-visible wrapper (`datafusion.store.MicrosoftAzure`) around an
/// Azure Blob Storage `object_store` client.
#[pyclass(frozen, name = "MicrosoftAzure", module = "datafusion.store", subclass)]
#[derive(Debug, Clone)]
pub struct PyMicrosoftAzureContext {
    // Shared handle to the underlying Azure store.
    pub inner: Arc<MicrosoftAzure>,
    // Container the store was built for; retained so callers can derive URLs.
    pub container_name: String,
}
75
76#[pymethods]
77impl PyMicrosoftAzureContext {
78    #[allow(clippy::too_many_arguments)]
79    #[pyo3(signature = (container_name, account=None, access_key=None, bearer_token=None, client_id=None, client_secret=None, tenant_id=None, sas_query_pairs=None, use_emulator=None, allow_http=None, use_fabric_endpoint=None))]
80    #[new]
81    fn new(
82        container_name: String,
83        account: Option<String>,
84        access_key: Option<String>,
85        bearer_token: Option<String>,
86        client_id: Option<String>,
87        client_secret: Option<String>,
88        tenant_id: Option<String>,
89        sas_query_pairs: Option<Vec<(String, String)>>,
90        use_emulator: Option<bool>,
91        allow_http: Option<bool>,
92        use_fabric_endpoint: Option<bool>,
93    ) -> Self {
94        let mut builder = MicrosoftAzureBuilder::from_env().with_container_name(&container_name);
95
96        if let Some(account) = account {
97            builder = builder.with_account(account);
98        }
99
100        if let Some(access_key) = access_key {
101            builder = builder.with_access_key(access_key);
102        }
103
104        if let Some(bearer_token) = bearer_token {
105            builder = builder.with_bearer_token_authorization(bearer_token);
106        }
107
108        match (client_id, client_secret, tenant_id) {
109            (Some(client_id), Some(client_secret), Some(tenant_id)) => {
110                builder =
111                    builder.with_client_secret_authorization(client_id, client_secret, tenant_id);
112            }
113            (None, None, None) => {}
114            _ => {
115                panic!("client_id, client_secret, tenat_id must be all set or all None");
116            }
117        }
118
119        if let Some(sas_query_pairs) = sas_query_pairs {
120            builder = builder.with_sas_authorization(sas_query_pairs);
121        }
122
123        if let Some(use_emulator) = use_emulator {
124            builder = builder.with_use_emulator(use_emulator);
125        }
126
127        if let Some(allow_http) = allow_http {
128            builder = builder.with_allow_http(allow_http);
129        }
130
131        if let Some(use_fabric_endpoint) = use_fabric_endpoint {
132            builder = builder.with_use_fabric_endpoint(use_fabric_endpoint);
133        }
134
135        Self {
136            inner: Arc::new(
137                builder
138                    .build()
139                    .expect("Could not create Azure Storage context"), //TODO: change these to PyErr
140            ),
141            container_name,
142        }
143    }
144}
145
/// Python-visible wrapper (`datafusion.store.GoogleCloud`) around a
/// Google Cloud Storage `object_store` client.
#[pyclass(frozen, name = "GoogleCloud", module = "datafusion.store", subclass)]
#[derive(Debug, Clone)]
pub struct PyGoogleCloudContext {
    // Shared handle to the underlying GCS store.
    pub inner: Arc<GoogleCloudStorage>,
    // Bucket the store was built for; retained so callers can derive URLs.
    pub bucket_name: String,
}
152
153#[pymethods]
154impl PyGoogleCloudContext {
155    #[allow(clippy::too_many_arguments)]
156    #[pyo3(signature = (bucket_name, service_account_path=None))]
157    #[new]
158    fn new(bucket_name: String, service_account_path: Option<String>) -> Self {
159        let mut builder = GoogleCloudStorageBuilder::new().with_bucket_name(&bucket_name);
160
161        if let Some(credential_path) = service_account_path {
162            builder = builder.with_service_account_path(credential_path);
163        }
164
165        Self {
166            inner: Arc::new(
167                builder
168                    .build()
169                    .expect("Could not create Google Cloud Storage"),
170            ),
171            bucket_name,
172        }
173    }
174}
175
/// Python-visible wrapper (`datafusion.store.AmazonS3`) around an
/// Amazon S3 `object_store` client.
#[pyclass(frozen, name = "AmazonS3", module = "datafusion.store", subclass)]
#[derive(Debug, Clone)]
pub struct PyAmazonS3Context {
    // Shared handle to the underlying S3 store.
    pub inner: Arc<AmazonS3>,
    // Bucket the store was built for; retained so callers can derive URLs.
    pub bucket_name: String,
}
182
183#[pymethods]
184impl PyAmazonS3Context {
185    #[allow(clippy::too_many_arguments)]
186    #[pyo3(signature = (bucket_name, region=None, access_key_id=None, secret_access_key=None, session_token=None, endpoint=None, allow_http=false, imdsv1_fallback=false))]
187    #[new]
188    fn new(
189        bucket_name: String,
190        region: Option<String>,
191        access_key_id: Option<String>,
192        secret_access_key: Option<String>,
193        session_token: Option<String>,
194        endpoint: Option<String>,
195        //retry_config: RetryConfig,
196        allow_http: bool,
197        imdsv1_fallback: bool,
198    ) -> Self {
199        // start w/ the options that come directly from the environment
200        let mut builder = AmazonS3Builder::from_env();
201
202        if let Some(region) = region {
203            builder = builder.with_region(region);
204        }
205
206        if let Some(access_key_id) = access_key_id {
207            builder = builder.with_access_key_id(access_key_id);
208        };
209
210        if let Some(secret_access_key) = secret_access_key {
211            builder = builder.with_secret_access_key(secret_access_key);
212        };
213
214        if let Some(session_token) = session_token {
215            builder = builder.with_token(session_token);
216        }
217
218        if let Some(endpoint) = endpoint {
219            builder = builder.with_endpoint(endpoint);
220        };
221
222        if imdsv1_fallback {
223            builder = builder.with_imdsv1_fallback();
224        };
225
226        let store = builder
227            .with_bucket_name(bucket_name.clone())
228            //.with_retry_config(retry_config) #TODO: add later
229            .with_allow_http(allow_http)
230            .build()
231            .expect("failed to build AmazonS3");
232
233        Self {
234            inner: Arc::new(store),
235            bucket_name,
236        }
237    }
238}
239
/// Python-visible wrapper (`datafusion.store.Http`) around an HTTP(S)
/// `object_store` client.
#[pyclass(frozen, name = "Http", module = "datafusion.store", subclass)]
#[derive(Debug, Clone)]
pub struct PyHttpContext {
    // The URL string supplied by the caller, kept verbatim.
    pub url: String,
    // Shared handle to the underlying HTTP store.
    pub store: Arc<HttpStore>,
}
246
247#[pymethods]
248impl PyHttpContext {
249    #[new]
250    fn new(url: String) -> PyResult<Self> {
251        let store = match Url::parse(url.as_str()) {
252            Ok(url) => HttpBuilder::new()
253                .with_url(url.origin().ascii_serialization())
254                .build(),
255            Err(_) => HttpBuilder::new().build(),
256        }
257        .map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string())))?;
258
259        Ok(Self {
260            url,
261            store: Arc::new(store),
262        })
263    }
264}
265
266pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
267    m.add_class::<PyAmazonS3Context>()?;
268    m.add_class::<PyMicrosoftAzureContext>()?;
269    m.add_class::<PyGoogleCloudContext>()?;
270    m.add_class::<PyLocalFileSystemContext>()?;
271    m.add_class::<PyHttpContext>()?;
272    Ok(())
273}