1use std::sync::Arc;
19
20use pyo3::prelude::*;
21
22use object_store::aws::{AmazonS3, AmazonS3Builder};
23use object_store::azure::{MicrosoftAzure, MicrosoftAzureBuilder};
24use object_store::gcp::{GoogleCloudStorage, GoogleCloudStorageBuilder};
25use object_store::http::{HttpBuilder, HttpStore};
26use object_store::local::LocalFileSystem;
27use pyo3::exceptions::PyValueError;
28use url::Url;
29
30#[derive(FromPyObject)]
31pub enum StorageContexts {
32 AmazonS3(PyAmazonS3Context),
33 GoogleCloudStorage(PyGoogleCloudContext),
34 MicrosoftAzure(PyMicrosoftAzureContext),
35 LocalFileSystem(PyLocalFileSystemContext),
36 HTTP(PyHttpContext),
37}
38
39#[pyclass(
40 frozen,
41 name = "LocalFileSystem",
42 module = "datafusion.store",
43 subclass
44)]
45#[derive(Debug, Clone)]
46pub struct PyLocalFileSystemContext {
47 pub inner: Arc<LocalFileSystem>,
48}
49
50#[pymethods]
51impl PyLocalFileSystemContext {
52 #[pyo3(signature = (prefix=None))]
53 #[new]
54 fn new(prefix: Option<String>) -> Self {
55 if let Some(prefix) = prefix {
56 Self {
57 inner: Arc::new(
58 LocalFileSystem::new_with_prefix(prefix)
59 .expect("Could not create local LocalFileSystem"),
60 ),
61 }
62 } else {
63 Self {
64 inner: Arc::new(LocalFileSystem::new()),
65 }
66 }
67 }
68}
69
70#[pyclass(frozen, name = "MicrosoftAzure", module = "datafusion.store", subclass)]
71#[derive(Debug, Clone)]
72pub struct PyMicrosoftAzureContext {
73 pub inner: Arc<MicrosoftAzure>,
74 pub container_name: String,
75}
76
77#[pymethods]
78impl PyMicrosoftAzureContext {
79 #[allow(clippy::too_many_arguments)]
80 #[pyo3(signature = (container_name, account=None, access_key=None, bearer_token=None, client_id=None, client_secret=None, tenant_id=None, sas_query_pairs=None, use_emulator=None, allow_http=None))]
81 #[new]
82 fn new(
83 container_name: String,
84 account: Option<String>,
85 access_key: Option<String>,
86 bearer_token: Option<String>,
87 client_id: Option<String>,
88 client_secret: Option<String>,
89 tenant_id: Option<String>,
90 sas_query_pairs: Option<Vec<(String, String)>>,
91 use_emulator: Option<bool>,
92 allow_http: Option<bool>,
93 ) -> Self {
94 let mut builder = MicrosoftAzureBuilder::from_env().with_container_name(&container_name);
95
96 if let Some(account) = account {
97 builder = builder.with_account(account);
98 }
99
100 if let Some(access_key) = access_key {
101 builder = builder.with_access_key(access_key);
102 }
103
104 if let Some(bearer_token) = bearer_token {
105 builder = builder.with_bearer_token_authorization(bearer_token);
106 }
107
108 match (client_id, client_secret, tenant_id) {
109 (Some(client_id), Some(client_secret), Some(tenant_id)) => {
110 builder =
111 builder.with_client_secret_authorization(client_id, client_secret, tenant_id);
112 }
113 (None, None, None) => {}
114 _ => {
115 panic!("client_id, client_secret, tenat_id must be all set or all None");
116 }
117 }
118
119 if let Some(sas_query_pairs) = sas_query_pairs {
120 builder = builder.with_sas_authorization(sas_query_pairs);
121 }
122
123 if let Some(use_emulator) = use_emulator {
124 builder = builder.with_use_emulator(use_emulator);
125 }
126
127 if let Some(allow_http) = allow_http {
128 builder = builder.with_allow_http(allow_http);
129 }
130
131 Self {
132 inner: Arc::new(
133 builder
134 .build()
135 .expect("Could not create Azure Storage context"), ),
137 container_name,
138 }
139 }
140}
141
142#[pyclass(frozen, name = "GoogleCloud", module = "datafusion.store", subclass)]
143#[derive(Debug, Clone)]
144pub struct PyGoogleCloudContext {
145 pub inner: Arc<GoogleCloudStorage>,
146 pub bucket_name: String,
147}
148
149#[pymethods]
150impl PyGoogleCloudContext {
151 #[allow(clippy::too_many_arguments)]
152 #[pyo3(signature = (bucket_name, service_account_path=None))]
153 #[new]
154 fn new(bucket_name: String, service_account_path: Option<String>) -> Self {
155 let mut builder = GoogleCloudStorageBuilder::new().with_bucket_name(&bucket_name);
156
157 if let Some(credential_path) = service_account_path {
158 builder = builder.with_service_account_path(credential_path);
159 }
160
161 Self {
162 inner: Arc::new(
163 builder
164 .build()
165 .expect("Could not create Google Cloud Storage"),
166 ),
167 bucket_name,
168 }
169 }
170}
171
172#[pyclass(frozen, name = "AmazonS3", module = "datafusion.store", subclass)]
173#[derive(Debug, Clone)]
174pub struct PyAmazonS3Context {
175 pub inner: Arc<AmazonS3>,
176 pub bucket_name: String,
177}
178
179#[pymethods]
180impl PyAmazonS3Context {
181 #[allow(clippy::too_many_arguments)]
182 #[pyo3(signature = (bucket_name, region=None, access_key_id=None, secret_access_key=None, endpoint=None, allow_http=false, imdsv1_fallback=false))]
183 #[new]
184 fn new(
185 bucket_name: String,
186 region: Option<String>,
187 access_key_id: Option<String>,
188 secret_access_key: Option<String>,
189 endpoint: Option<String>,
190 allow_http: bool,
192 imdsv1_fallback: bool,
193 ) -> Self {
194 let mut builder = AmazonS3Builder::from_env();
196
197 if let Some(region) = region {
198 builder = builder.with_region(region);
199 }
200
201 if let Some(access_key_id) = access_key_id {
202 builder = builder.with_access_key_id(access_key_id);
203 };
204
205 if let Some(secret_access_key) = secret_access_key {
206 builder = builder.with_secret_access_key(secret_access_key);
207 };
208
209 if let Some(endpoint) = endpoint {
210 builder = builder.with_endpoint(endpoint);
211 };
212
213 if imdsv1_fallback {
214 builder = builder.with_imdsv1_fallback();
215 };
216
217 let store = builder
218 .with_bucket_name(bucket_name.clone())
219 .with_allow_http(allow_http)
221 .build()
222 .expect("failed to build AmazonS3");
223
224 Self {
225 inner: Arc::new(store),
226 bucket_name,
227 }
228 }
229}
230
231#[pyclass(frozen, name = "Http", module = "datafusion.store", subclass)]
232#[derive(Debug, Clone)]
233pub struct PyHttpContext {
234 pub url: String,
235 pub store: Arc<HttpStore>,
236}
237
238#[pymethods]
239impl PyHttpContext {
240 #[new]
241 fn new(url: String) -> PyResult<Self> {
242 let store = match Url::parse(url.as_str()) {
243 Ok(url) => HttpBuilder::new()
244 .with_url(url.origin().ascii_serialization())
245 .build(),
246 Err(_) => HttpBuilder::new().build(),
247 }
248 .map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string())))?;
249
250 Ok(Self {
251 url,
252 store: Arc::new(store),
253 })
254 }
255}
256
257pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
258 m.add_class::<PyAmazonS3Context>()?;
259 m.add_class::<PyMicrosoftAzureContext>()?;
260 m.add_class::<PyGoogleCloudContext>()?;
261 m.add_class::<PyLocalFileSystemContext>()?;
262 m.add_class::<PyHttpContext>()?;
263 Ok(())
264}