datafusion_cli/
catalog.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::any::Any;
19use std::sync::{Arc, Weak};
20
21use crate::object_storage::{get_object_store, AwsOptions, GcpOptions};
22
23use datafusion::catalog::{CatalogProvider, CatalogProviderList, SchemaProvider};
24
25use datafusion::common::plan_datafusion_err;
26use datafusion::datasource::listing::ListingTableUrl;
27use datafusion::datasource::TableProvider;
28use datafusion::error::Result;
29use datafusion::execution::context::SessionState;
30use datafusion::execution::session_state::SessionStateBuilder;
31
32use async_trait::async_trait;
33use dirs::home_dir;
34use parking_lot::RwLock;
35
/// Wraps another catalog list, automatically registering the required object stores
/// for the file locations.
///
/// Every catalog handed out by this list is wrapped in a
/// `DynamicObjectStoreCatalogProvider` that carries the same weak session-state handle.
#[derive(Debug)]
pub struct DynamicObjectStoreCatalog {
    /// The catalog list that all calls are delegated to.
    inner: Arc<dyn CatalogProviderList>,
    /// Weak handle to the shared session state; cloned into every wrapped catalog.
    state: Weak<RwLock<SessionState>>,
}
42
43impl DynamicObjectStoreCatalog {
44    pub fn new(
45        inner: Arc<dyn CatalogProviderList>,
46        state: Weak<RwLock<SessionState>>,
47    ) -> Self {
48        Self { inner, state }
49    }
50}
51
52impl CatalogProviderList for DynamicObjectStoreCatalog {
53    fn as_any(&self) -> &dyn Any {
54        self
55    }
56
57    fn register_catalog(
58        &self,
59        name: String,
60        catalog: Arc<dyn CatalogProvider>,
61    ) -> Option<Arc<dyn CatalogProvider>> {
62        self.inner.register_catalog(name, catalog)
63    }
64
65    fn catalog_names(&self) -> Vec<String> {
66        self.inner.catalog_names()
67    }
68
69    fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
70        let state = self.state.clone();
71        self.inner.catalog(name).map(|catalog| {
72            Arc::new(DynamicObjectStoreCatalogProvider::new(catalog, state)) as _
73        })
74    }
75}
76
/// Wraps another catalog provider.
///
/// Schemas returned by this provider are wrapped in a
/// `DynamicObjectStoreSchemaProvider` carrying the same weak session-state handle.
#[derive(Debug)]
struct DynamicObjectStoreCatalogProvider {
    /// The catalog provider that all calls are delegated to.
    inner: Arc<dyn CatalogProvider>,
    /// Weak handle to the shared session state; cloned into every wrapped schema.
    state: Weak<RwLock<SessionState>>,
}
83
84impl DynamicObjectStoreCatalogProvider {
85    pub fn new(
86        inner: Arc<dyn CatalogProvider>,
87        state: Weak<RwLock<SessionState>>,
88    ) -> Self {
89        Self { inner, state }
90    }
91}
92
93impl CatalogProvider for DynamicObjectStoreCatalogProvider {
94    fn as_any(&self) -> &dyn Any {
95        self
96    }
97
98    fn schema_names(&self) -> Vec<String> {
99        self.inner.schema_names()
100    }
101
102    fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
103        let state = self.state.clone();
104        self.inner.schema(name).map(|schema| {
105            Arc::new(DynamicObjectStoreSchemaProvider::new(schema, state)) as _
106        })
107    }
108
109    fn register_schema(
110        &self,
111        name: &str,
112        schema: Arc<dyn SchemaProvider>,
113    ) -> Result<Option<Arc<dyn SchemaProvider>>> {
114        self.inner.register_schema(name, schema)
115    }
116}
117
/// Wraps another schema provider. [DynamicObjectStoreSchemaProvider] is responsible for registering the required
/// object stores for the file locations.
///
/// The registration happens lazily in `table`, when the wrapped provider does
/// not return a table for the requested name.
#[derive(Debug)]
struct DynamicObjectStoreSchemaProvider {
    /// The schema provider that all calls are delegated to.
    inner: Arc<dyn SchemaProvider>,
    /// Weak handle to the shared session state, upgraded when an object store
    /// has to be created and registered.
    state: Weak<RwLock<SessionState>>,
}
125
126impl DynamicObjectStoreSchemaProvider {
127    pub fn new(
128        inner: Arc<dyn SchemaProvider>,
129        state: Weak<RwLock<SessionState>>,
130    ) -> Self {
131        Self { inner, state }
132    }
133}
134
135#[async_trait]
136impl SchemaProvider for DynamicObjectStoreSchemaProvider {
137    fn as_any(&self) -> &dyn Any {
138        self
139    }
140
141    fn table_names(&self) -> Vec<String> {
142        self.inner.table_names()
143    }
144
145    fn register_table(
146        &self,
147        name: String,
148        table: Arc<dyn TableProvider>,
149    ) -> Result<Option<Arc<dyn TableProvider>>> {
150        self.inner.register_table(name, table)
151    }
152
153    async fn table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
154        let inner_table = self.inner.table(name).await;
155        if inner_table.is_ok() {
156            if let Some(inner_table) = inner_table? {
157                return Ok(Some(inner_table));
158            }
159        }
160
161        // if the inner schema provider didn't have a table by
162        // that name, try to treat it as a listing table
163        let mut state = self
164            .state
165            .upgrade()
166            .ok_or_else(|| plan_datafusion_err!("locking error"))?
167            .read()
168            .clone();
169        let mut builder = SessionStateBuilder::from(state.clone());
170        let optimized_name = substitute_tilde(name.to_owned());
171        let table_url = ListingTableUrl::parse(optimized_name.as_str())?;
172        let scheme = table_url.scheme();
173        let url = table_url.as_ref();
174
175        // If the store is already registered for this URL then `get_store`
176        // will return `Ok` which means we don't need to register it again. However,
177        // if `get_store` returns an `Err` then it means the corresponding store is
178        // not registered yet and we need to register it
179        match state.runtime_env().object_store_registry.get_store(url) {
180            Ok(_) => { /*Nothing to do here, store for this URL is already registered*/ }
181            Err(_) => {
182                // Register the store for this URL. Here we don't have access
183                // to any command options so the only choice is to use an empty collection
184                match scheme {
185                    "s3" | "oss" | "cos" => {
186                        if let Some(table_options) = builder.table_options() {
187                            table_options.extensions.insert(AwsOptions::default())
188                        }
189                    }
190                    "gs" | "gcs" => {
191                        if let Some(table_options) = builder.table_options() {
192                            table_options.extensions.insert(GcpOptions::default())
193                        }
194                    }
195                    _ => {}
196                };
197                state = builder.build();
198                let store = get_object_store(
199                    &state,
200                    table_url.scheme(),
201                    url,
202                    &state.default_table_options(),
203                    false,
204                )
205                .await?;
206                state.runtime_env().register_object_store(url, store);
207            }
208        }
209        self.inner.table(name).await
210    }
211
212    fn deregister_table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
213        self.inner.deregister_table(name)
214    }
215
216    fn table_exist(&self, name: &str) -> bool {
217        self.inner.table_exist(name)
218    }
219}
220
221pub fn substitute_tilde(cur: String) -> String {
222    if let Some(usr_dir_path) = home_dir() {
223        if let Some(usr_dir) = usr_dir_path.to_str() {
224            if cur.starts_with('~') && !usr_dir.is_empty() {
225                return cur.replacen('~', usr_dir, 1);
226            }
227        }
228    }
229    cur
230}
#[cfg(test)]
mod tests {
    use std::{env, vec};

    use super::*;

    use datafusion::catalog::SchemaProvider;
    use datafusion::prelude::SessionContext;

    /// Builds a session context with a `DynamicObjectStoreCatalog` registered
    /// and returns it together with the first schema of the first catalog.
    fn setup_context() -> (SessionContext, Arc<dyn SchemaProvider>) {
        let ctx = SessionContext::new();
        ctx.register_catalog_list(Arc::new(DynamicObjectStoreCatalog::new(
            ctx.state().catalog_list().clone(),
            ctx.state_weak_ref(),
        )));

        let provider = &DynamicObjectStoreCatalog::new(
            ctx.state().catalog_list().clone(),
            ctx.state_weak_ref(),
        ) as &dyn CatalogProviderList;
        let catalog = provider
            .catalog(provider.catalog_names().first().unwrap())
            .unwrap();
        let schema = catalog
            .schema(catalog.schema_names().first().unwrap())
            .unwrap();
        (ctx, schema)
    }

    #[tokio::test]
    async fn query_http_location_test() -> Result<()> {
        // This is a unit test so not expecting a connection or a file to be
        // available
        let domain = "example.com";
        let location = format!("http://{domain}/file.parquet");

        let (ctx, schema) = setup_context();

        // That's a non registered table so expecting None here
        let table = schema.table(&location).await?;
        assert!(table.is_none());

        // It should still create an object store for the location in the SessionState
        let store = ctx
            .runtime_env()
            .object_store(ListingTableUrl::parse(location)?)?;

        assert_eq!(format!("{store}"), "HttpStore");

        // The store must be configured for this domain
        let expected_domain = format!("Domain(\"{domain}\")");
        assert!(format!("{store:?}").contains(&expected_domain));

        Ok(())
    }

    #[tokio::test]
    async fn query_s3_location_test() -> Result<()> {
        let aws_envs = vec![
            "AWS_ENDPOINT",
            "AWS_ACCESS_KEY_ID",
            "AWS_SECRET_ACCESS_KEY",
            "AWS_ALLOW_HTTP",
        ];
        for aws_env in aws_envs {
            if env::var(aws_env).is_err() {
                // Terminate the line so the notice does not run into the
                // test harness output that follows.
                eprintln!("aws envs not set, skipping s3 test");
                return Ok(());
            }
        }

        let bucket = "examples3bucket";
        let location = format!("s3://{bucket}/file.parquet");

        let (ctx, schema) = setup_context();

        let table = schema.table(&location).await?;
        assert!(table.is_none());

        let store = ctx
            .runtime_env()
            .object_store(ListingTableUrl::parse(location)?)?;
        assert_eq!(format!("{store}"), format!("AmazonS3({bucket})"));

        // The store must be configured for this bucket
        let expected_bucket = format!("bucket: \"{bucket}\"");
        assert!(format!("{store:?}").contains(&expected_bucket));

        Ok(())
    }

    #[tokio::test]
    async fn query_gs_location_test() -> Result<()> {
        let bucket = "examplegsbucket";
        let location = format!("gs://{bucket}/file.parquet");

        let (ctx, schema) = setup_context();

        let table = schema.table(&location).await?;
        assert!(table.is_none());

        let store = ctx
            .runtime_env()
            .object_store(ListingTableUrl::parse(location)?)?;
        assert_eq!(format!("{store}"), format!("GoogleCloudStorage({bucket})"));

        // The store must be configured for this bucket
        let expected_bucket = format!("bucket_name_encoded: \"{bucket}\"");
        assert!(format!("{store:?}").contains(&expected_bucket));

        Ok(())
    }

    #[tokio::test]
    async fn query_invalid_location_test() {
        let location = "ts://file.parquet";
        let (_ctx, schema) = setup_context();

        assert!(schema.table(location).await.is_err());
    }

    #[cfg(not(target_os = "windows"))]
    #[test]
    fn test_substitute_tilde() {
        use std::path::PathBuf;

        // The environment variable this test overrides to steer `home_dir()`.
        let home_var = if cfg!(windows) { "USERPROFILE" } else { "HOME" };
        let test_home_path = if cfg!(windows) {
            "C:\\Users\\user"
        } else {
            "/home/user"
        };

        let original_home = home_dir();
        env::set_var(home_var, test_home_path);

        let input = "~/Code/datafusion/benchmarks/data/tpch_sf1/part/part-0.parquet";
        let expected = PathBuf::from(test_home_path)
            .join("Code")
            .join("datafusion")
            .join("benchmarks")
            .join("data")
            .join("tpch_sf1")
            .join("part")
            .join("part-0.parquet")
            .to_string_lossy()
            .to_string();
        let actual = substitute_tilde(input.to_string());

        // Restore the environment before asserting, so a failing assertion
        // cannot leak the fake home directory into other tests. The saved path
        // is passed as-is, which also avoids panicking on a non-UTF-8 home.
        match original_home {
            Some(home_path) => env::set_var(home_var, home_path),
            None => env::remove_var(home_var),
        }

        assert_eq!(actual, expected);
    }
}