// datafusion_dft — db.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use std::path::Path;
use std::sync::Arc;

use color_eyre::{Report, Result};
use datafusion::{
    catalog::{MemoryCatalogProvider, MemorySchemaProvider},
    datasource::{
        file_format::{csv::CsvFormat, json::JsonFormat, parquet::ParquetFormat, FileFormat},
        listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl},
    },
    prelude::SessionContext,
};
use log::{debug, info};
#[cfg(feature = "vortex")]
use {vortex_datafusion::VortexFormat, vortex_session::VortexSession};

use crate::config::DbConfig;
36/// Detects the file format based on file extension
37fn detect_format(extension: &str) -> Result<(Arc<dyn FileFormat>, &'static str)> {
38    match extension.to_lowercase().as_str() {
39        "parquet" => Ok((Arc::new(ParquetFormat::new()), ".parquet")),
40        "csv" => Ok((Arc::new(CsvFormat::default()), ".csv")),
41        "json" => Ok((Arc::new(JsonFormat::default()), ".json")),
42        #[cfg(feature = "vortex")]
43        "vortex" => Ok((
44            Arc::new(VortexFormat::new(VortexSession::empty())),
45            ".vortex",
46        )),
47        _ => Err(Report::msg(format!(
48            "Unsupported file extension: {}",
49            extension
50        ))),
51    }
52}
53
54pub async fn register_db(ctx: &SessionContext, db_config: &DbConfig) -> Result<()> {
55    info!("registering tables to database");
56    let tables_url = db_config.path.join("tables/")?;
57    let listing_tables_url = ListingTableUrl::parse(tables_url.clone())?;
58    let store_url = listing_tables_url.object_store();
59    let store = ctx.runtime_env().object_store(store_url)?;
60    let tables_path = object_store::path::Path::from_url_path(tables_url.path())?;
61    let catalogs = store.list_with_delimiter(Some(&tables_path)).await?;
62    for catalog in catalogs.common_prefixes {
63        let catalog_name = catalog
64            .filename()
65            .ok_or(Report::msg("missing catalog name"))?;
66        info!("...handling {catalog_name} catalog");
67        let maybe_catalog = ctx.catalog(catalog_name);
68        let catalog_provider = match maybe_catalog {
69            Some(catalog) => catalog,
70            None => {
71                info!("...catalog does not exist, createing");
72                let mem_catalog_provider = Arc::new(MemoryCatalogProvider::new());
73                ctx.register_catalog(catalog_name, mem_catalog_provider);
74                ctx.catalog(catalog_name).ok_or(Report::msg(format!(
75                    "missing catalog {catalog_name}, shouldnt be possible"
76                )))?
77            }
78        };
79        let schemas = store.list_with_delimiter(Some(&catalog)).await?;
80        for schema in schemas.common_prefixes {
81            let schema_name = schema
82                .filename()
83                .ok_or(Report::msg("missing schema name"))?;
84            info!("...handling {schema_name} schema");
85            let maybe_schema = catalog_provider.schema(schema_name);
86            let schema_provider = match maybe_schema {
87                Some(schema) => schema,
88                None => {
89                    info!("...schema does not exist, creating");
90                    let mem_schema_provider = Arc::new(MemorySchemaProvider::new());
91                    catalog_provider.register_schema(schema_name, mem_schema_provider)?;
92                    catalog_provider
93                        .schema(schema_name)
94                        .ok_or(Report::msg(format!(
95                            "missing schema {schema_name}, shouldnt be possible"
96                        )))?
97                }
98            };
99            let tables = store.list_with_delimiter(Some(&schema)).await?;
100            for table_path in tables.common_prefixes {
101                let table_name = table_path
102                    .filename()
103                    .ok_or(Report::msg("missing table name"))?;
104                info!("...handling table \"{catalog_name}.{schema_name}.{table_name}\"");
105
106                let p = tables_url
107                    .join(&format!("{catalog_name}/"))?
108                    .join(&format!("{schema_name}/"))?
109                    .join(&format!("{table_name}/"))?;
110
111                let table_url = ListingTableUrl::parse(p)?;
112                debug!("...table url: {table_url:?}");
113
114                // List files in the table directory to detect the format
115                let files = store.list_with_delimiter(Some(&table_path)).await?;
116
117                // Find the first file with an extension to determine the format
118                let extension = files
119                    .objects
120                    .iter()
121                    .find_map(|obj| {
122                        let path = obj.location.as_ref();
123                        Path::new(path).extension().and_then(|ext| ext.to_str())
124                    })
125                    .ok_or(Report::msg(format!(
126                        "No files with extensions found in table directory: {table_name}"
127                    )))?;
128
129                info!("...detected format: {extension}");
130                let (file_format, file_extension) = detect_format(extension)?;
131
132                let listing_options =
133                    ListingOptions::new(file_format).with_file_extension(file_extension);
134                // Resolve the schema
135                let resolved_schema = listing_options
136                    .infer_schema(&ctx.state(), &table_url)
137                    .await?;
138                let config = ListingTableConfig::new(table_url)
139                    .with_listing_options(listing_options)
140                    .with_schema(resolved_schema);
141                // Create a new TableProvider
142                let provider = Arc::new(ListingTable::try_new(config)?);
143                info!("...table registered");
144                schema_provider.register_table(table_name.to_string(), provider)?;
145            }
146        }
147    }
148
149    Ok(())
150}
151
#[cfg(test)]
mod test {
    use datafusion::{
        assert_batches_eq,
        dataframe::DataFrameWriteOptions,
        prelude::{SessionConfig, SessionContext},
    };

    use crate::{config::DbConfig, db::register_db};

    /// Builds a `SessionContext` with the information schema enabled so
    /// registered catalogs/schemas/tables can be inspected via SQL.
    fn setup() -> SessionContext {
        let config = SessionConfig::default().with_information_schema(true);
        SessionContext::new_with_config(config)
    }

    /// An empty database directory registers nothing; only the default
    /// catalog's information_schema views are visible.
    #[tokio::test]
    async fn test_register_db_no_tables() {
        let ctx = setup();
        // Temp dir is cleaned up when `dir` drops at end of test.
        let dir = tempfile::tempdir().unwrap();
        let db_path = dir.path().join("db");
        let path = format!("file://{}/", db_path.to_str().unwrap());
        let db_url = url::Url::parse(&path).unwrap();
        let config = DbConfig { path: db_url };

        register_db(&ctx, &config).await.unwrap();

        let batches = ctx
            .sql("SHOW TABLES")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

        // Only the built-in information_schema views should be present.
        let expected = [
            "+---------------+--------------------+-------------+------------+",
            "| table_catalog | table_schema       | table_name  | table_type |",
            "+---------------+--------------------+-------------+------------+",
            "| datafusion    | information_schema | tables      | VIEW       |",
            "| datafusion    | information_schema | views       | VIEW       |",
            "| datafusion    | information_schema | columns     | VIEW       |",
            "| datafusion    | information_schema | df_settings | VIEW       |",
            "| datafusion    | information_schema | schemata    | VIEW       |",
            "| datafusion    | information_schema | routines    | VIEW       |",
            "| datafusion    | information_schema | parameters  | VIEW       |",
            "+---------------+--------------------+-------------+------------+",
        ];

        assert_batches_eq!(expected, &batches);
    }

    /// One parquet table at `tables/dft/stuff/hi` registers as
    /// `dft.stuff.hi` (a new catalog and schema are created on the fly).
    #[tokio::test]
    async fn test_register_db_single_table() {
        let ctx = setup();
        let dir = tempfile::tempdir().unwrap();
        let db_path = dir.path().join("db");
        let path = format!("file://{}/", db_path.to_str().unwrap());
        let db_url = url::Url::parse(&path).unwrap();
        let config = DbConfig { path: db_url };
        // Layout: tables/<catalog>/<schema>/<table>
        let data_path = db_path.join("tables").join("dft").join("stuff").join("hi");

        // Write a trivial one-row parquet file as the table's data.
        let df = ctx.sql("SELECT 1").await.unwrap();
        let write_opts = DataFrameWriteOptions::new();

        df.write_parquet(data_path.as_path().to_str().unwrap(), write_opts, None)
            .await
            .unwrap();

        register_db(&ctx, &config).await.unwrap();

        let batches = ctx
            .sql("SELECT * FROM information_schema.tables ORDER BY table_catalog, table_schema, table_name")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

        let expected = [
            "+---------------+--------------------+-------------+------------+",
            "| table_catalog | table_schema       | table_name  | table_type |",
            "+---------------+--------------------+-------------+------------+",
            "| datafusion    | information_schema | columns     | VIEW       |",
            "| datafusion    | information_schema | df_settings | VIEW       |",
            "| datafusion    | information_schema | parameters  | VIEW       |",
            "| datafusion    | information_schema | routines    | VIEW       |",
            "| datafusion    | information_schema | schemata    | VIEW       |",
            "| datafusion    | information_schema | tables      | VIEW       |",
            "| datafusion    | information_schema | views       | VIEW       |",
            "| dft           | information_schema | columns     | VIEW       |",
            "| dft           | information_schema | df_settings | VIEW       |",
            "| dft           | information_schema | parameters  | VIEW       |",
            "| dft           | information_schema | routines    | VIEW       |",
            "| dft           | information_schema | schemata    | VIEW       |",
            "| dft           | information_schema | tables      | VIEW       |",
            "| dft           | information_schema | views       | VIEW       |",
            "| dft           | stuff              | hi          | BASE TABLE |",
            "+---------------+--------------------+-------------+------------+",
        ];

        assert_batches_eq!(expected, &batches);
    }

    /// Two tables in the same catalog and schema both register.
    #[tokio::test]
    async fn test_register_db_multiple_tables() {
        let ctx = setup();
        let dir = tempfile::tempdir().unwrap();
        let db_path = dir.path().join("db");
        let path = format!("file://{}/", db_path.to_str().unwrap());
        let db_url = url::Url::parse(&path).unwrap();
        let config = DbConfig { path: db_url };
        let data_1_path = db_path.join("tables").join("dft").join("stuff").join("hi");
        let data_2_path = db_path.join("tables").join("dft").join("stuff").join("bye");

        let df = ctx.sql("SELECT 1").await.unwrap();
        let write_opts = DataFrameWriteOptions::new();
        // Clone: write_parquet consumes the DataFrame.
        df.clone()
            .write_parquet(data_1_path.as_path().to_str().unwrap(), write_opts, None)
            .await
            .unwrap();

        let write_opts = DataFrameWriteOptions::new();
        df.write_parquet(data_2_path.as_path().to_str().unwrap(), write_opts, None)
            .await
            .unwrap();

        register_db(&ctx, &config).await.unwrap();

        let batches = ctx
            .sql("SELECT * FROM information_schema.tables ORDER BY table_catalog, table_schema, table_name")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

        let expected = [
            "+---------------+--------------------+-------------+------------+",
            "| table_catalog | table_schema       | table_name  | table_type |",
            "+---------------+--------------------+-------------+------------+",
            "| datafusion    | information_schema | columns     | VIEW       |",
            "| datafusion    | information_schema | df_settings | VIEW       |",
            "| datafusion    | information_schema | parameters  | VIEW       |",
            "| datafusion    | information_schema | routines    | VIEW       |",
            "| datafusion    | information_schema | schemata    | VIEW       |",
            "| datafusion    | information_schema | tables      | VIEW       |",
            "| datafusion    | information_schema | views       | VIEW       |",
            "| dft           | information_schema | columns     | VIEW       |",
            "| dft           | information_schema | df_settings | VIEW       |",
            "| dft           | information_schema | parameters  | VIEW       |",
            "| dft           | information_schema | routines    | VIEW       |",
            "| dft           | information_schema | schemata    | VIEW       |",
            "| dft           | information_schema | tables      | VIEW       |",
            "| dft           | information_schema | views       | VIEW       |",
            "| dft           | stuff              | bye         | BASE TABLE |",
            "| dft           | stuff              | hi          | BASE TABLE |",
            "+---------------+--------------------+-------------+------------+",
        ];

        assert_batches_eq!(expected, &batches);
    }

    /// Two tables in different schemas of the same catalog both register.
    #[tokio::test]
    async fn test_register_db_multiple_schemas() {
        let ctx = setup();
        let dir = tempfile::tempdir().unwrap();
        let db_path = dir.path().join("db");
        let path = format!("file://{}/", db_path.to_str().unwrap());
        let db_url = url::Url::parse(&path).unwrap();
        let config = DbConfig { path: db_url };
        let data_1_path = db_path.join("tables").join("dft").join("stuff").join("hi");
        let data_2_path = db_path
            .join("tables")
            .join("dft")
            .join("things")
            .join("bye");

        let df = ctx.sql("SELECT 1").await.unwrap();
        let write_opts = DataFrameWriteOptions::new();
        // Clone: write_parquet consumes the DataFrame.
        df.clone()
            .write_parquet(data_1_path.as_path().to_str().unwrap(), write_opts, None)
            .await
            .unwrap();

        let write_opts = DataFrameWriteOptions::new();
        df.write_parquet(data_2_path.as_path().to_str().unwrap(), write_opts, None)
            .await
            .unwrap();

        register_db(&ctx, &config).await.unwrap();

        let batches = ctx
            .sql("SELECT * FROM information_schema.tables ORDER BY table_catalog, table_schema, table_name")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

        let expected = [
            "+---------------+--------------------+-------------+------------+",
            "| table_catalog | table_schema       | table_name  | table_type |",
            "+---------------+--------------------+-------------+------------+",
            "| datafusion    | information_schema | columns     | VIEW       |",
            "| datafusion    | information_schema | df_settings | VIEW       |",
            "| datafusion    | information_schema | parameters  | VIEW       |",
            "| datafusion    | information_schema | routines    | VIEW       |",
            "| datafusion    | information_schema | schemata    | VIEW       |",
            "| datafusion    | information_schema | tables      | VIEW       |",
            "| datafusion    | information_schema | views       | VIEW       |",
            "| dft           | information_schema | columns     | VIEW       |",
            "| dft           | information_schema | df_settings | VIEW       |",
            "| dft           | information_schema | parameters  | VIEW       |",
            "| dft           | information_schema | routines    | VIEW       |",
            "| dft           | information_schema | schemata    | VIEW       |",
            "| dft           | information_schema | tables      | VIEW       |",
            "| dft           | information_schema | views       | VIEW       |",
            "| dft           | stuff              | hi          | BASE TABLE |",
            "| dft           | things             | bye         | BASE TABLE |",
            "+---------------+--------------------+-------------+------------+",
        ];

        assert_batches_eq!(expected, &batches);
    }

    /// Tables in two different catalogs each get their own catalog
    /// provider (each with its own information_schema views).
    #[tokio::test]
    async fn test_register_db_multiple_catalogs() {
        let ctx = setup();
        let dir = tempfile::tempdir().unwrap();
        let db_path = dir.path().join("db");
        let path = format!("file://{}/", db_path.to_str().unwrap());
        let db_url = url::Url::parse(&path).unwrap();
        let config = DbConfig { path: db_url };
        let data_1_path = db_path.join("tables").join("dft2").join("stuff").join("hi");
        let data_2_path = db_path
            .join("tables")
            .join("dft")
            .join("things")
            .join("bye");

        let df = ctx.sql("SELECT 1").await.unwrap();
        let write_opts = DataFrameWriteOptions::new();
        // Clone: write_parquet consumes the DataFrame.
        df.clone()
            .write_parquet(data_1_path.as_path().to_str().unwrap(), write_opts, None)
            .await
            .unwrap();

        let write_opts = DataFrameWriteOptions::new();
        df.write_parquet(data_2_path.as_path().to_str().unwrap(), write_opts, None)
            .await
            .unwrap();

        register_db(&ctx, &config).await.unwrap();

        let batches = ctx
            .sql("SELECT * FROM information_schema.tables ORDER BY table_catalog, table_schema, table_name")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

        let expected = [
            "+---------------+--------------------+-------------+------------+",
            "| table_catalog | table_schema       | table_name  | table_type |",
            "+---------------+--------------------+-------------+------------+",
            "| datafusion    | information_schema | columns     | VIEW       |",
            "| datafusion    | information_schema | df_settings | VIEW       |",
            "| datafusion    | information_schema | parameters  | VIEW       |",
            "| datafusion    | information_schema | routines    | VIEW       |",
            "| datafusion    | information_schema | schemata    | VIEW       |",
            "| datafusion    | information_schema | tables      | VIEW       |",
            "| datafusion    | information_schema | views       | VIEW       |",
            "| dft           | information_schema | columns     | VIEW       |",
            "| dft           | information_schema | df_settings | VIEW       |",
            "| dft           | information_schema | parameters  | VIEW       |",
            "| dft           | information_schema | routines    | VIEW       |",
            "| dft           | information_schema | schemata    | VIEW       |",
            "| dft           | information_schema | tables      | VIEW       |",
            "| dft           | information_schema | views       | VIEW       |",
            "| dft           | things             | bye         | BASE TABLE |",
            "| dft2          | information_schema | columns     | VIEW       |",
            "| dft2          | information_schema | df_settings | VIEW       |",
            "| dft2          | information_schema | parameters  | VIEW       |",
            "| dft2          | information_schema | routines    | VIEW       |",
            "| dft2          | information_schema | schemata    | VIEW       |",
            "| dft2          | information_schema | tables      | VIEW       |",
            "| dft2          | information_schema | views       | VIEW       |",
            "| dft2          | stuff              | hi          | BASE TABLE |",
            "+---------------+--------------------+-------------+------------+",
        ];

        assert_batches_eq!(expected, &batches);
    }
}