datafusion_cli/
catalog.rs1use std::any::Any;
19use std::sync::{Arc, Weak};
20
21use crate::object_storage::{get_object_store, AwsOptions, GcpOptions};
22
23use datafusion::catalog::{CatalogProvider, CatalogProviderList, SchemaProvider};
24
25use datafusion::common::plan_datafusion_err;
26use datafusion::datasource::listing::ListingTableUrl;
27use datafusion::datasource::TableProvider;
28use datafusion::error::Result;
29use datafusion::execution::context::SessionState;
30use datafusion::execution::session_state::SessionStateBuilder;
31
32use async_trait::async_trait;
33use dirs::home_dir;
34use parking_lot::RwLock;
35
36#[derive(Debug)]
38pub struct DynamicObjectStoreCatalog {
39 inner: Arc<dyn CatalogProviderList>,
40 state: Weak<RwLock<SessionState>>,
41}
42
43impl DynamicObjectStoreCatalog {
44 pub fn new(
45 inner: Arc<dyn CatalogProviderList>,
46 state: Weak<RwLock<SessionState>>,
47 ) -> Self {
48 Self { inner, state }
49 }
50}
51
52impl CatalogProviderList for DynamicObjectStoreCatalog {
53 fn as_any(&self) -> &dyn Any {
54 self
55 }
56
57 fn register_catalog(
58 &self,
59 name: String,
60 catalog: Arc<dyn CatalogProvider>,
61 ) -> Option<Arc<dyn CatalogProvider>> {
62 self.inner.register_catalog(name, catalog)
63 }
64
65 fn catalog_names(&self) -> Vec<String> {
66 self.inner.catalog_names()
67 }
68
69 fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
70 let state = self.state.clone();
71 self.inner.catalog(name).map(|catalog| {
72 Arc::new(DynamicObjectStoreCatalogProvider::new(catalog, state)) as _
73 })
74 }
75}
76
77#[derive(Debug)]
79struct DynamicObjectStoreCatalogProvider {
80 inner: Arc<dyn CatalogProvider>,
81 state: Weak<RwLock<SessionState>>,
82}
83
84impl DynamicObjectStoreCatalogProvider {
85 pub fn new(
86 inner: Arc<dyn CatalogProvider>,
87 state: Weak<RwLock<SessionState>>,
88 ) -> Self {
89 Self { inner, state }
90 }
91}
92
93impl CatalogProvider for DynamicObjectStoreCatalogProvider {
94 fn as_any(&self) -> &dyn Any {
95 self
96 }
97
98 fn schema_names(&self) -> Vec<String> {
99 self.inner.schema_names()
100 }
101
102 fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
103 let state = self.state.clone();
104 self.inner.schema(name).map(|schema| {
105 Arc::new(DynamicObjectStoreSchemaProvider::new(schema, state)) as _
106 })
107 }
108
109 fn register_schema(
110 &self,
111 name: &str,
112 schema: Arc<dyn SchemaProvider>,
113 ) -> Result<Option<Arc<dyn SchemaProvider>>> {
114 self.inner.register_schema(name, schema)
115 }
116}
117
118#[derive(Debug)]
121struct DynamicObjectStoreSchemaProvider {
122 inner: Arc<dyn SchemaProvider>,
123 state: Weak<RwLock<SessionState>>,
124}
125
126impl DynamicObjectStoreSchemaProvider {
127 pub fn new(
128 inner: Arc<dyn SchemaProvider>,
129 state: Weak<RwLock<SessionState>>,
130 ) -> Self {
131 Self { inner, state }
132 }
133}
134
135#[async_trait]
136impl SchemaProvider for DynamicObjectStoreSchemaProvider {
137 fn as_any(&self) -> &dyn Any {
138 self
139 }
140
141 fn table_names(&self) -> Vec<String> {
142 self.inner.table_names()
143 }
144
145 fn register_table(
146 &self,
147 name: String,
148 table: Arc<dyn TableProvider>,
149 ) -> Result<Option<Arc<dyn TableProvider>>> {
150 self.inner.register_table(name, table)
151 }
152
153 async fn table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
154 let inner_table = self.inner.table(name).await;
155 if inner_table.is_ok() {
156 if let Some(inner_table) = inner_table? {
157 return Ok(Some(inner_table));
158 }
159 }
160
161 let mut state = self
164 .state
165 .upgrade()
166 .ok_or_else(|| plan_datafusion_err!("locking error"))?
167 .read()
168 .clone();
169 let mut builder = SessionStateBuilder::from(state.clone());
170 let optimized_name = substitute_tilde(name.to_owned());
171 let table_url = ListingTableUrl::parse(optimized_name.as_str())?;
172 let scheme = table_url.scheme();
173 let url = table_url.as_ref();
174
175 match state.runtime_env().object_store_registry.get_store(url) {
180 Ok(_) => { }
181 Err(_) => {
182 match scheme {
185 "s3" | "oss" | "cos" => {
186 if let Some(table_options) = builder.table_options() {
187 table_options.extensions.insert(AwsOptions::default())
188 }
189 }
190 "gs" | "gcs" => {
191 if let Some(table_options) = builder.table_options() {
192 table_options.extensions.insert(GcpOptions::default())
193 }
194 }
195 _ => {}
196 };
197 state = builder.build();
198 let store = get_object_store(
199 &state,
200 table_url.scheme(),
201 url,
202 &state.default_table_options(),
203 false,
204 )
205 .await?;
206 state.runtime_env().register_object_store(url, store);
207 }
208 }
209 self.inner.table(name).await
210 }
211
212 fn deregister_table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
213 self.inner.deregister_table(name)
214 }
215
216 fn table_exist(&self, name: &str) -> bool {
217 self.inner.table_exist(name)
218 }
219}
220
/// Expands a leading `~` in `cur` to the current user's home directory.
///
/// Expansion happens only when `cur` is exactly `~` or starts with `~`
/// followed by a path separator (for example `~/data/file.parquet`). Inputs
/// such as `~other/file` or `~backup.csv` are returned unchanged, because
/// splicing the home directory in front of them would produce an unrelated
/// path.
///
/// `cur` is also returned unchanged when the home directory cannot be
/// determined, is not valid UTF-8, or is empty.
pub fn substitute_tilde(cur: String) -> String {
    let expandable = match cur.strip_prefix('~') {
        Some(rest) => rest.is_empty() || rest.starts_with(std::path::is_separator),
        None => false,
    };
    if !expandable {
        return cur;
    }

    if let Some(home_path) = home_dir() {
        if let Some(home_str) = home_path.to_str() {
            if !home_str.is_empty() {
                // The first `~` is the leading one, as checked above.
                return cur.replacen('~', home_str, 1);
            }
        }
    }
    cur
}
#[cfg(test)]
mod tests {
    use std::env;

    use super::*;

    use datafusion::catalog::SchemaProvider;
    use datafusion::prelude::SessionContext;

    /// Builds a session whose catalog list is wrapped in a
    /// `DynamicObjectStoreCatalog` and returns it together with the first
    /// schema of the first catalog, obtained through such a wrapper.
    fn setup_context() -> (SessionContext, Arc<dyn SchemaProvider>) {
        let ctx = SessionContext::new();
        ctx.register_catalog_list(Arc::new(DynamicObjectStoreCatalog::new(
            ctx.state().catalog_list().clone(),
            ctx.state_weak_ref(),
        )));

        let provider = &DynamicObjectStoreCatalog::new(
            ctx.state().catalog_list().clone(),
            ctx.state_weak_ref(),
        ) as &dyn CatalogProviderList;
        let catalog = provider
            .catalog(provider.catalog_names().first().unwrap())
            .unwrap();
        let schema = catalog
            .schema(catalog.schema_names().first().unwrap())
            .unwrap();
        (ctx, schema)
    }

    /// Looking up an `http://` location yields no table but leaves an HTTP
    /// object store registered for that domain.
    #[tokio::test]
    async fn query_http_location_test() -> Result<()> {
        let domain = "example.com";
        let location = format!("http://{domain}/file.parquet");

        let (ctx, schema) = setup_context();

        let table = schema.table(&location).await?;
        assert!(table.is_none());

        let store = ctx
            .runtime_env()
            .object_store(ListingTableUrl::parse(location)?)?;

        assert_eq!(format!("{store}"), "HttpStore");

        let expected_domain = format!("Domain(\"{domain}\")");
        assert!(format!("{store:?}").contains(&expected_domain));

        Ok(())
    }

    /// Looking up an `s3://` location yields no table but leaves an S3 object
    /// store registered for the bucket. Skipped unless all AWS variables
    /// listed below are set.
    #[tokio::test]
    async fn query_s3_location_test() -> Result<()> {
        let required_envs = [
            "AWS_ENDPOINT",
            "AWS_ACCESS_KEY_ID",
            "AWS_SECRET_ACCESS_KEY",
            "AWS_ALLOW_HTTP",
        ];
        if required_envs.iter().any(|name| env::var(name).is_err()) {
            eprintln!("aws envs not set, skipping s3 test");
            return Ok(());
        }

        let bucket = "examples3bucket";
        let location = format!("s3://{bucket}/file.parquet");

        let (ctx, schema) = setup_context();

        let table = schema.table(&location).await?;
        assert!(table.is_none());

        let store = ctx
            .runtime_env()
            .object_store(ListingTableUrl::parse(location)?)?;
        assert_eq!(format!("{store}"), format!("AmazonS3({bucket})"));

        let expected_bucket = format!("bucket: \"{bucket}\"");
        assert!(format!("{store:?}").contains(&expected_bucket));

        Ok(())
    }

    /// Looking up a `gs://` location yields no table but leaves a GCS object
    /// store registered for the bucket.
    #[tokio::test]
    async fn query_gs_location_test() -> Result<()> {
        let bucket = "examplegsbucket";
        let location = format!("gs://{bucket}/file.parquet");

        let (ctx, schema) = setup_context();

        let table = schema.table(&location).await?;
        assert!(table.is_none());

        let store = ctx
            .runtime_env()
            .object_store(ListingTableUrl::parse(location)?)?;
        assert_eq!(format!("{store}"), format!("GoogleCloudStorage({bucket})"));

        let expected_bucket = format!("bucket_name_encoded: \"{bucket}\"");
        assert!(format!("{store:?}").contains(&expected_bucket));

        Ok(())
    }

    /// A location with the `ts` scheme makes the lookup fail.
    #[tokio::test]
    async fn query_invalid_location_test() {
        let location = "ts://file.parquet";
        let (_ctx, schema) = setup_context();

        assert!(schema.table(location).await.is_err());
    }

    /// `~/...` is expanded to `$HOME/...`.
    ///
    /// Only compiled on non-Windows targets, so the Unix variable name and a
    /// Unix-style home path are used directly.
    #[cfg(not(target_os = "windows"))]
    #[test]
    fn test_substitute_tilde() {
        use std::path::PathBuf;

        const HOME_VAR: &str = "HOME";
        let test_home_path = "/home/user";

        // Remember the raw variable (it may be unset or non-UTF-8) so it can
        // be put back exactly as it was.
        let original_home = env::var_os(HOME_VAR);
        env::set_var(HOME_VAR, test_home_path);

        let input = "~/Code/datafusion/benchmarks/data/tpch_sf1/part/part-0.parquet";
        let expected = PathBuf::from(test_home_path)
            .join("Code")
            .join("datafusion")
            .join("benchmarks")
            .join("data")
            .join("tpch_sf1")
            .join("part")
            .join("part-0.parquet")
            .to_string_lossy()
            .to_string();
        let actual = substitute_tilde(input.to_string());

        // Restore the environment before asserting so a failing assertion
        // does not leave the modified variable behind.
        match original_home {
            Some(home) => env::set_var(HOME_VAR, home),
            None => env::remove_var(HOME_VAR),
        }

        assert_eq!(actual, expected);
    }
}