datafusion_cli/
catalog.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::{Arc, Weak};
20
21use crate::object_storage::{get_object_store, AwsOptions, GcpOptions};
22
23use datafusion::catalog::{CatalogProvider, CatalogProviderList, SchemaProvider};
24
25use datafusion::common::plan_datafusion_err;
26use datafusion::datasource::listing::ListingTableUrl;
27use datafusion::datasource::TableProvider;
28use datafusion::error::Result;
29use datafusion::execution::context::SessionState;
30use datafusion::execution::session_state::SessionStateBuilder;
31
32use async_trait::async_trait;
33use dirs::home_dir;
34use parking_lot::RwLock;
35
36/// Wraps another catalog, automatically register require object stores for the file locations
37#[derive(Debug)]
38pub struct DynamicObjectStoreCatalog {
39    inner: Arc<dyn CatalogProviderList>,
40    state: Weak<RwLock<SessionState>>,
41}
42
43impl DynamicObjectStoreCatalog {
44    pub fn new(
45        inner: Arc<dyn CatalogProviderList>,
46        state: Weak<RwLock<SessionState>>,
47    ) -> Self {
48        Self { inner, state }
49    }
50}
51
52impl CatalogProviderList for DynamicObjectStoreCatalog {
53    fn as_any(&self) -> &dyn Any {
54        self
55    }
56
57    fn register_catalog(
58        &self,
59        name: String,
60        catalog: Arc<dyn CatalogProvider>,
61    ) -> Option<Arc<dyn CatalogProvider>> {
62        self.inner.register_catalog(name, catalog)
63    }
64
65    fn catalog_names(&self) -> Vec<String> {
66        self.inner.catalog_names()
67    }
68
69    fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
70        let state = self.state.clone();
71        self.inner.catalog(name).map(|catalog| {
72            Arc::new(DynamicObjectStoreCatalogProvider::new(catalog, state)) as _
73        })
74    }
75}
76
77/// Wraps another catalog provider
78#[derive(Debug)]
79struct DynamicObjectStoreCatalogProvider {
80    inner: Arc<dyn CatalogProvider>,
81    state: Weak<RwLock<SessionState>>,
82}
83
84impl DynamicObjectStoreCatalogProvider {
85    pub fn new(
86        inner: Arc<dyn CatalogProvider>,
87        state: Weak<RwLock<SessionState>>,
88    ) -> Self {
89        Self { inner, state }
90    }
91}
92
93impl CatalogProvider for DynamicObjectStoreCatalogProvider {
94    fn as_any(&self) -> &dyn Any {
95        self
96    }
97
98    fn schema_names(&self) -> Vec<String> {
99        self.inner.schema_names()
100    }
101
102    fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
103        let state = self.state.clone();
104        self.inner.schema(name).map(|schema| {
105            Arc::new(DynamicObjectStoreSchemaProvider::new(schema, state)) as _
106        })
107    }
108
109    fn register_schema(
110        &self,
111        name: &str,
112        schema: Arc<dyn SchemaProvider>,
113    ) -> Result<Option<Arc<dyn SchemaProvider>>> {
114        self.inner.register_schema(name, schema)
115    }
116}
117
118/// Wraps another schema provider. [DynamicObjectStoreSchemaProvider] is responsible for registering the required
119/// object stores for the file locations.
120#[derive(Debug)]
121struct DynamicObjectStoreSchemaProvider {
122    inner: Arc<dyn SchemaProvider>,
123    state: Weak<RwLock<SessionState>>,
124}
125
126impl DynamicObjectStoreSchemaProvider {
127    pub fn new(
128        inner: Arc<dyn SchemaProvider>,
129        state: Weak<RwLock<SessionState>>,
130    ) -> Self {
131        Self { inner, state }
132    }
133}
134
135#[async_trait]
136impl SchemaProvider for DynamicObjectStoreSchemaProvider {
137    fn as_any(&self) -> &dyn Any {
138        self
139    }
140
141    fn table_names(&self) -> Vec<String> {
142        self.inner.table_names()
143    }
144
145    fn register_table(
146        &self,
147        name: String,
148        table: Arc<dyn TableProvider>,
149    ) -> Result<Option<Arc<dyn TableProvider>>> {
150        self.inner.register_table(name, table)
151    }
152
153    async fn table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
154        let inner_table = self.inner.table(name).await;
155        if inner_table.is_ok() {
156            if let Some(inner_table) = inner_table? {
157                return Ok(Some(inner_table));
158            }
159        }
160
161        // if the inner schema provider didn't have a table by
162        // that name, try to treat it as a listing table
163        let mut state = self
164            .state
165            .upgrade()
166            .ok_or_else(|| plan_datafusion_err!("locking error"))?
167            .read()
168            .clone();
169        let mut builder = SessionStateBuilder::from(state.clone());
170        let optimized_name = substitute_tilde(name.to_owned());
171        let table_url = ListingTableUrl::parse(optimized_name.as_str())?;
172        let scheme = table_url.scheme();
173        let url = table_url.as_ref();
174
175        // If the store is already registered for this URL then `get_store`
176        // will return `Ok` which means we don't need to register it again. However,
177        // if `get_store` returns an `Err` then it means the corresponding store is
178        // not registered yet and we need to register it
179        match state.runtime_env().object_store_registry.get_store(url) {
180            Ok(_) => { /*Nothing to do here, store for this URL is already registered*/ }
181            Err(_) => {
182                // Register the store for this URL. Here we don't have access
183                // to any command options so the only choice is to use an empty collection
184                match scheme {
185                    "s3" | "oss" | "cos" => {
186                        if let Some(table_options) = builder.table_options() {
187                            table_options.extensions.insert(AwsOptions::default())
188                        }
189                    }
190                    "gs" | "gcs" => {
191                        if let Some(table_options) = builder.table_options() {
192                            table_options.extensions.insert(GcpOptions::default())
193                        }
194                    }
195                    _ => {}
196                };
197                state = builder.build();
198                let store = get_object_store(
199                    &state,
200                    table_url.scheme(),
201                    url,
202                    &state.default_table_options(),
203                )
204                .await?;
205                state.runtime_env().register_object_store(url, store);
206            }
207        }
208        self.inner.table(name).await
209    }
210
211    fn deregister_table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
212        self.inner.deregister_table(name)
213    }
214
215    fn table_exist(&self, name: &str) -> bool {
216        self.inner.table_exist(name)
217    }
218}
219
220pub fn substitute_tilde(cur: String) -> String {
221    if let Some(usr_dir_path) = home_dir() {
222        if let Some(usr_dir) = usr_dir_path.to_str() {
223            if cur.starts_with('~') && !usr_dir.is_empty() {
224                return cur.replacen('~', usr_dir, 1);
225            }
226        }
227    }
228    cur
229}
#[cfg(test)]
mod tests {

    use super::*;

    use datafusion::catalog::SchemaProvider;
    use datafusion::prelude::SessionContext;

    /// Creates a `SessionContext` whose catalog list is wrapped in a
    /// [`DynamicObjectStoreCatalog`]. Returns the context together with the
    /// wrapped first schema of the first catalog.
    fn setup_context() -> (SessionContext, Arc<dyn SchemaProvider>) {
        let ctx = SessionContext::new();
        ctx.register_catalog_list(Arc::new(DynamicObjectStoreCatalog::new(
            ctx.state().catalog_list().clone(),
            ctx.state_weak_ref(),
        )));

        let provider = &DynamicObjectStoreCatalog::new(
            ctx.state().catalog_list().clone(),
            ctx.state_weak_ref(),
        ) as &dyn CatalogProviderList;
        let catalog = provider
            .catalog(provider.catalog_names().first().unwrap())
            .unwrap();
        let schema = catalog
            .schema(catalog.schema_names().first().unwrap())
            .unwrap();
        (ctx, schema)
    }

    #[tokio::test]
    async fn query_http_location_test() -> Result<()> {
        // This is a unit test so not expecting a connection or a file to be
        // available
        let domain = "example.com";
        let location = format!("http://{domain}/file.parquet");

        let (ctx, schema) = setup_context();

        // That's a non registered table so expecting None here
        let table = schema.table(&location).await?;
        assert!(table.is_none());

        // It should still create an object store for the location in the SessionState
        let store = ctx
            .runtime_env()
            .object_store(ListingTableUrl::parse(location)?)?;

        assert_eq!(format!("{store}"), "HttpStore");

        // The store must be configured for this domain
        let expected_domain = format!("Domain(\"{domain}\")");
        assert!(format!("{store:?}").contains(&expected_domain));

        Ok(())
    }

    #[tokio::test]
    async fn query_s3_location_test() -> Result<()> {
        let bucket = "examples3bucket";
        let location = format!("s3://{bucket}/file.parquet");

        let (ctx, schema) = setup_context();

        let table = schema.table(&location).await?;
        assert!(table.is_none());

        let store = ctx
            .runtime_env()
            .object_store(ListingTableUrl::parse(location)?)?;
        assert_eq!(format!("{store}"), format!("AmazonS3({bucket})"));

        // The store must be configured for this domain
        let expected_bucket = format!("bucket: \"{bucket}\"");
        assert!(format!("{store:?}").contains(&expected_bucket));

        Ok(())
    }

    #[tokio::test]
    async fn query_gs_location_test() -> Result<()> {
        let bucket = "examplegsbucket";
        let location = format!("gs://{bucket}/file.parquet");

        let (ctx, schema) = setup_context();

        let table = schema.table(&location).await?;
        assert!(table.is_none());

        let store = ctx
            .runtime_env()
            .object_store(ListingTableUrl::parse(location)?)?;
        assert_eq!(format!("{store}"), format!("GoogleCloudStorage({bucket})"));

        // The store must be configured for this domain
        let expected_bucket = format!("bucket_name_encoded: \"{bucket}\"");
        assert!(format!("{store:?}").contains(&expected_bucket));

        Ok(())
    }

    #[tokio::test]
    async fn query_invalid_location_test() {
        let location = "ts://file.parquet";
        let (_ctx, schema) = setup_context();

        assert!(schema.table(location).await.is_err());
    }

    #[cfg(not(target_os = "windows"))]
    #[test]
    fn test_substitute_tilde() {
        use std::env;

        // This test never runs on Windows, so the home directory always comes
        // from `HOME` and the path separator is always `/`.
        // Capture the raw variable (not `home_dir()`) so it can be restored
        // exactly, including the case where it was unset.
        let original_home = env::var_os("HOME");
        let test_home_path = "/home/user";
        env::set_var("HOME", test_home_path);

        let input = "~/Code/datafusion/benchmarks/data/tpch_sf1/part/part-0.parquet";
        let actual = substitute_tilde(input.to_string());

        // Restore `HOME` before asserting so a failure cannot leak the fake
        // value into other tests running in the same process.
        match original_home {
            Some(home) => env::set_var("HOME", home),
            None => env::remove_var("HOME"),
        }

        let expected = format!(
            "{test_home_path}/Code/datafusion/benchmarks/data/tpch_sf1/part/part-0.parquet"
        );
        assert_eq!(actual, expected);
    }
}