1use std::sync::Arc;
19
20use color_eyre::{Report, Result};
21use datafusion::{
22 catalog::{MemoryCatalogProvider, MemorySchemaProvider},
23 datasource::{
24 file_format::{csv::CsvFormat, json::JsonFormat, parquet::ParquetFormat, FileFormat},
25 listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl},
26 },
27 prelude::SessionContext,
28};
29use log::{debug, info};
30use std::path::Path;
31#[cfg(feature = "vortex")]
32use {vortex_datafusion::VortexFormat, vortex_session::VortexSession};
33
34use crate::config::DbConfig;
35
36fn detect_format(extension: &str) -> Result<(Arc<dyn FileFormat>, &'static str)> {
38 match extension.to_lowercase().as_str() {
39 "parquet" => Ok((Arc::new(ParquetFormat::new()), ".parquet")),
40 "csv" => Ok((Arc::new(CsvFormat::default()), ".csv")),
41 "json" => Ok((Arc::new(JsonFormat::default()), ".json")),
42 #[cfg(feature = "vortex")]
43 "vortex" => Ok((
44 Arc::new(VortexFormat::new(VortexSession::empty())),
45 ".vortex",
46 )),
47 _ => Err(Report::msg(format!(
48 "Unsupported file extension: {}",
49 extension
50 ))),
51 }
52}
53
54pub async fn register_db(ctx: &SessionContext, db_config: &DbConfig) -> Result<()> {
55 info!("registering tables to database");
56 let tables_url = db_config.path.join("tables/")?;
57 let listing_tables_url = ListingTableUrl::parse(tables_url.clone())?;
58 let store_url = listing_tables_url.object_store();
59 let store = ctx.runtime_env().object_store(store_url)?;
60 let tables_path = object_store::path::Path::from_url_path(tables_url.path())?;
61 let catalogs = store.list_with_delimiter(Some(&tables_path)).await?;
62 for catalog in catalogs.common_prefixes {
63 let catalog_name = catalog
64 .filename()
65 .ok_or(Report::msg("missing catalog name"))?;
66 info!("...handling {catalog_name} catalog");
67 let maybe_catalog = ctx.catalog(catalog_name);
68 let catalog_provider = match maybe_catalog {
69 Some(catalog) => catalog,
70 None => {
71 info!("...catalog does not exist, createing");
72 let mem_catalog_provider = Arc::new(MemoryCatalogProvider::new());
73 ctx.register_catalog(catalog_name, mem_catalog_provider);
74 ctx.catalog(catalog_name).ok_or(Report::msg(format!(
75 "missing catalog {catalog_name}, shouldnt be possible"
76 )))?
77 }
78 };
79 let schemas = store.list_with_delimiter(Some(&catalog)).await?;
80 for schema in schemas.common_prefixes {
81 let schema_name = schema
82 .filename()
83 .ok_or(Report::msg("missing schema name"))?;
84 info!("...handling {schema_name} schema");
85 let maybe_schema = catalog_provider.schema(schema_name);
86 let schema_provider = match maybe_schema {
87 Some(schema) => schema,
88 None => {
89 info!("...schema does not exist, creating");
90 let mem_schema_provider = Arc::new(MemorySchemaProvider::new());
91 catalog_provider.register_schema(schema_name, mem_schema_provider)?;
92 catalog_provider
93 .schema(schema_name)
94 .ok_or(Report::msg(format!(
95 "missing schema {schema_name}, shouldnt be possible"
96 )))?
97 }
98 };
99 let tables = store.list_with_delimiter(Some(&schema)).await?;
100 for table_path in tables.common_prefixes {
101 let table_name = table_path
102 .filename()
103 .ok_or(Report::msg("missing table name"))?;
104 info!("...handling table \"{catalog_name}.{schema_name}.{table_name}\"");
105
106 let p = tables_url
107 .join(&format!("{catalog_name}/"))?
108 .join(&format!("{schema_name}/"))?
109 .join(&format!("{table_name}/"))?;
110
111 let table_url = ListingTableUrl::parse(p)?;
112 debug!("...table url: {table_url:?}");
113
114 let files = store.list_with_delimiter(Some(&table_path)).await?;
116
117 let extension = files
119 .objects
120 .iter()
121 .find_map(|obj| {
122 let path = obj.location.as_ref();
123 Path::new(path).extension().and_then(|ext| ext.to_str())
124 })
125 .ok_or(Report::msg(format!(
126 "No files with extensions found in table directory: {table_name}"
127 )))?;
128
129 info!("...detected format: {extension}");
130 let (file_format, file_extension) = detect_format(extension)?;
131
132 let listing_options =
133 ListingOptions::new(file_format).with_file_extension(file_extension);
134 let resolved_schema = listing_options
136 .infer_schema(&ctx.state(), &table_url)
137 .await?;
138 let config = ListingTableConfig::new(table_url)
139 .with_listing_options(listing_options)
140 .with_schema(resolved_schema);
141 let provider = Arc::new(ListingTable::try_new(config)?);
143 info!("...table registered");
144 schema_provider.register_table(table_name.to_string(), provider)?;
145 }
146 }
147 }
148
149 Ok(())
150}
151
// Integration-style tests for `register_db`: each test writes parquet data
// into a `tables/<catalog>/<schema>/<table>/` layout inside a temp dir, runs
// registration, and checks the resulting catalogs/schemas/tables through
// information_schema.
#[cfg(test)]
mod test {
    use datafusion::{
        assert_batches_eq,
        dataframe::DataFrameWriteOptions,
        prelude::{SessionConfig, SessionContext},
    };

    use crate::{config::DbConfig, db::register_db};

    // Fresh SessionContext with information_schema enabled so the tests can
    // inspect registered catalogs/schemas/tables via SQL.
    fn setup() -> SessionContext {
        let config = SessionConfig::default().with_information_schema(true);
        SessionContext::new_with_config(config)
    }

    // An empty database (no `tables/` directory) registers nothing: only the
    // default catalog's information_schema views remain.
    #[tokio::test]
    async fn test_register_db_no_tables() {
        let ctx = setup();
        let dir = tempfile::tempdir().unwrap();
        let db_path = dir.path().join("db");
        // DbConfig takes a URL; trailing '/' marks the path as a directory.
        let path = format!("file://{}/", db_path.to_str().unwrap());
        let db_url = url::Url::parse(&path).unwrap();
        let config = DbConfig { path: db_url };

        register_db(&ctx, &config).await.unwrap();

        let batches = ctx
            .sql("SHOW TABLES")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

        let expected = [
            "+---------------+--------------------+-------------+------------+",
            "| table_catalog | table_schema       | table_name  | table_type |",
            "+---------------+--------------------+-------------+------------+",
            "| datafusion    | information_schema | tables      | VIEW       |",
            "| datafusion    | information_schema | views       | VIEW       |",
            "| datafusion    | information_schema | columns     | VIEW       |",
            "| datafusion    | information_schema | df_settings | VIEW       |",
            "| datafusion    | information_schema | schemata    | VIEW       |",
            "| datafusion    | information_schema | routines    | VIEW       |",
            "| datafusion    | information_schema | parameters  | VIEW       |",
            "+---------------+--------------------+-------------+------------+",
        ];

        assert_batches_eq!(expected, &batches);
    }

    // One table at tables/dft/stuff/hi becomes catalog "dft", schema
    // "stuff", base table "hi" (plus the per-catalog information_schema).
    #[tokio::test]
    async fn test_register_db_single_table() {
        let ctx = setup();
        let dir = tempfile::tempdir().unwrap();
        let db_path = dir.path().join("db");
        let path = format!("file://{}/", db_path.to_str().unwrap());
        let db_url = url::Url::parse(&path).unwrap();
        let config = DbConfig { path: db_url };
        let data_path = db_path.join("tables").join("dft").join("stuff").join("hi");

        // Write a trivial one-row parquet file as the table's data.
        let df = ctx.sql("SELECT 1").await.unwrap();
        let write_opts = DataFrameWriteOptions::new();

        df.write_parquet(data_path.as_path().to_str().unwrap(), write_opts, None)
            .await
            .unwrap();

        register_db(&ctx, &config).await.unwrap();

        // ORDER BY makes the expected table deterministic.
        let batches = ctx
            .sql("SELECT * FROM information_schema.tables ORDER BY table_catalog, table_schema, table_name")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

        let expected = [
            "+---------------+--------------------+-------------+------------+",
            "| table_catalog | table_schema       | table_name  | table_type |",
            "+---------------+--------------------+-------------+------------+",
            "| datafusion    | information_schema | columns     | VIEW       |",
            "| datafusion    | information_schema | df_settings | VIEW       |",
            "| datafusion    | information_schema | parameters  | VIEW       |",
            "| datafusion    | information_schema | routines    | VIEW       |",
            "| datafusion    | information_schema | schemata    | VIEW       |",
            "| datafusion    | information_schema | tables      | VIEW       |",
            "| datafusion    | information_schema | views       | VIEW       |",
            "| dft           | information_schema | columns     | VIEW       |",
            "| dft           | information_schema | df_settings | VIEW       |",
            "| dft           | information_schema | parameters  | VIEW       |",
            "| dft           | information_schema | routines    | VIEW       |",
            "| dft           | information_schema | schemata    | VIEW       |",
            "| dft           | information_schema | tables      | VIEW       |",
            "| dft           | information_schema | views       | VIEW       |",
            "| dft           | stuff              | hi          | BASE TABLE |",
            "+---------------+--------------------+-------------+------------+",
        ];

        assert_batches_eq!(expected, &batches);
    }

    // Two tables in the same schema ("stuff") both register under it.
    #[tokio::test]
    async fn test_register_db_multiple_tables() {
        let ctx = setup();
        let dir = tempfile::tempdir().unwrap();
        let db_path = dir.path().join("db");
        let path = format!("file://{}/", db_path.to_str().unwrap());
        let db_url = url::Url::parse(&path).unwrap();
        let config = DbConfig { path: db_url };
        let data_1_path = db_path.join("tables").join("dft").join("stuff").join("hi");
        let data_2_path = db_path.join("tables").join("dft").join("stuff").join("bye");

        let df = ctx.sql("SELECT 1").await.unwrap();
        let write_opts = DataFrameWriteOptions::new();
        df.clone()
            .write_parquet(data_1_path.as_path().to_str().unwrap(), write_opts, None)
            .await
            .unwrap();

        let write_opts = DataFrameWriteOptions::new();
        df.write_parquet(data_2_path.as_path().to_str().unwrap(), write_opts, None)
            .await
            .unwrap();

        register_db(&ctx, &config).await.unwrap();

        let batches = ctx
            .sql("SELECT * FROM information_schema.tables ORDER BY table_catalog, table_schema, table_name")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

        let expected = [
            "+---------------+--------------------+-------------+------------+",
            "| table_catalog | table_schema       | table_name  | table_type |",
            "+---------------+--------------------+-------------+------------+",
            "| datafusion    | information_schema | columns     | VIEW       |",
            "| datafusion    | information_schema | df_settings | VIEW       |",
            "| datafusion    | information_schema | parameters  | VIEW       |",
            "| datafusion    | information_schema | routines    | VIEW       |",
            "| datafusion    | information_schema | schemata    | VIEW       |",
            "| datafusion    | information_schema | tables      | VIEW       |",
            "| datafusion    | information_schema | views       | VIEW       |",
            "| dft           | information_schema | columns     | VIEW       |",
            "| dft           | information_schema | df_settings | VIEW       |",
            "| dft           | information_schema | parameters  | VIEW       |",
            "| dft           | information_schema | routines    | VIEW       |",
            "| dft           | information_schema | schemata    | VIEW       |",
            "| dft           | information_schema | tables      | VIEW       |",
            "| dft           | information_schema | views       | VIEW       |",
            "| dft           | stuff              | bye         | BASE TABLE |",
            "| dft           | stuff              | hi          | BASE TABLE |",
            "+---------------+--------------------+-------------+------------+",
        ];

        assert_batches_eq!(expected, &batches);
    }

    // Two tables in distinct schemas ("stuff" and "things") of one catalog.
    #[tokio::test]
    async fn test_register_db_multiple_schemas() {
        let ctx = setup();
        let dir = tempfile::tempdir().unwrap();
        let db_path = dir.path().join("db");
        let path = format!("file://{}/", db_path.to_str().unwrap());
        let db_url = url::Url::parse(&path).unwrap();
        let config = DbConfig { path: db_url };
        let data_1_path = db_path.join("tables").join("dft").join("stuff").join("hi");
        let data_2_path = db_path
            .join("tables")
            .join("dft")
            .join("things")
            .join("bye");

        let df = ctx.sql("SELECT 1").await.unwrap();
        let write_opts = DataFrameWriteOptions::new();
        df.clone()
            .write_parquet(data_1_path.as_path().to_str().unwrap(), write_opts, None)
            .await
            .unwrap();

        let write_opts = DataFrameWriteOptions::new();
        df.write_parquet(data_2_path.as_path().to_str().unwrap(), write_opts, None)
            .await
            .unwrap();

        register_db(&ctx, &config).await.unwrap();

        let batches = ctx
            .sql("SELECT * FROM information_schema.tables ORDER BY table_catalog, table_schema, table_name")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

        let expected = [
            "+---------------+--------------------+-------------+------------+",
            "| table_catalog | table_schema       | table_name  | table_type |",
            "+---------------+--------------------+-------------+------------+",
            "| datafusion    | information_schema | columns     | VIEW       |",
            "| datafusion    | information_schema | df_settings | VIEW       |",
            "| datafusion    | information_schema | parameters  | VIEW       |",
            "| datafusion    | information_schema | routines    | VIEW       |",
            "| datafusion    | information_schema | schemata    | VIEW       |",
            "| datafusion    | information_schema | tables      | VIEW       |",
            "| datafusion    | information_schema | views       | VIEW       |",
            "| dft           | information_schema | columns     | VIEW       |",
            "| dft           | information_schema | df_settings | VIEW       |",
            "| dft           | information_schema | parameters  | VIEW       |",
            "| dft           | information_schema | routines    | VIEW       |",
            "| dft           | information_schema | schemata    | VIEW       |",
            "| dft           | information_schema | tables      | VIEW       |",
            "| dft           | information_schema | views       | VIEW       |",
            "| dft           | stuff              | hi          | BASE TABLE |",
            "| dft           | things             | bye         | BASE TABLE |",
            "+---------------+--------------------+-------------+------------+",
        ];

        assert_batches_eq!(expected, &batches);
    }

    // Tables under distinct top-level directories ("dft" and "dft2") create
    // two separate in-memory catalogs, each with its own information_schema.
    #[tokio::test]
    async fn test_register_db_multiple_catalogs() {
        let ctx = setup();
        let dir = tempfile::tempdir().unwrap();
        let db_path = dir.path().join("db");
        let path = format!("file://{}/", db_path.to_str().unwrap());
        let db_url = url::Url::parse(&path).unwrap();
        let config = DbConfig { path: db_url };
        let data_1_path = db_path.join("tables").join("dft2").join("stuff").join("hi");
        let data_2_path = db_path
            .join("tables")
            .join("dft")
            .join("things")
            .join("bye");

        let df = ctx.sql("SELECT 1").await.unwrap();
        let write_opts = DataFrameWriteOptions::new();
        df.clone()
            .write_parquet(data_1_path.as_path().to_str().unwrap(), write_opts, None)
            .await
            .unwrap();

        let write_opts = DataFrameWriteOptions::new();
        df.write_parquet(data_2_path.as_path().to_str().unwrap(), write_opts, None)
            .await
            .unwrap();

        register_db(&ctx, &config).await.unwrap();

        let batches = ctx
            .sql("SELECT * FROM information_schema.tables ORDER BY table_catalog, table_schema, table_name")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

        let expected = [
            "+---------------+--------------------+-------------+------------+",
            "| table_catalog | table_schema       | table_name  | table_type |",
            "+---------------+--------------------+-------------+------------+",
            "| datafusion    | information_schema | columns     | VIEW       |",
            "| datafusion    | information_schema | df_settings | VIEW       |",
            "| datafusion    | information_schema | parameters  | VIEW       |",
            "| datafusion    | information_schema | routines    | VIEW       |",
            "| datafusion    | information_schema | schemata    | VIEW       |",
            "| datafusion    | information_schema | tables      | VIEW       |",
            "| datafusion    | information_schema | views       | VIEW       |",
            "| dft           | information_schema | columns     | VIEW       |",
            "| dft           | information_schema | df_settings | VIEW       |",
            "| dft           | information_schema | parameters  | VIEW       |",
            "| dft           | information_schema | routines    | VIEW       |",
            "| dft           | information_schema | schemata    | VIEW       |",
            "| dft           | information_schema | tables      | VIEW       |",
            "| dft           | information_schema | views       | VIEW       |",
            "| dft           | things             | bye         | BASE TABLE |",
            "| dft2          | information_schema | columns     | VIEW       |",
            "| dft2          | information_schema | df_settings | VIEW       |",
            "| dft2          | information_schema | parameters  | VIEW       |",
            "| dft2          | information_schema | routines    | VIEW       |",
            "| dft2          | information_schema | schemata    | VIEW       |",
            "| dft2          | information_schema | tables      | VIEW       |",
            "| dft2          | information_schema | views       | VIEW       |",
            "| dft2          | stuff              | hi          | BASE TABLE |",
            "+---------------+--------------------+-------------+------------+",
        ];

        assert_batches_eq!(expected, &batches);
    }
}