use std::collections::HashSet;
use std::path::Path;
use std::sync::Arc;
use crate::catalog::{TableProvider, TableProviderFactory};
use crate::datasource::listing::{
ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
};
use crate::execution::context::SessionState;
use arrow::datatypes::DataType;
use datafusion_common::{Result, config_datafusion_err};
use datafusion_common::{ToDFSchema, arrow_datafusion_err, plan_err};
use datafusion_expr::CreateExternalTable;
use async_trait::async_trait;
use datafusion_catalog::Session;
/// A [`TableProviderFactory`] that builds [`ListingTable`]s from
/// `CREATE EXTERNAL TABLE` commands pointing at a file or directory path.
#[derive(Debug, Default)]
pub struct ListingTableFactory {}
impl ListingTableFactory {
    /// Creates a new [`ListingTableFactory`].
    pub fn new() -> Self {
        Self {}
    }
}
#[async_trait]
impl TableProviderFactory for ListingTableFactory {
    /// Builds a [`ListingTable`] for the external table described by `cmd`.
    ///
    /// Resolves the file format from the registered factories, determines
    /// partition columns (explicit or inferred), resolves the schema
    /// (provided or inferred from files), validates ORDER BY columns, and
    /// optionally pre-warms the file statistics cache.
    ///
    /// # Errors
    /// Returns an error if `state` is not a [`SessionState`], if no
    /// `FileFormat` is registered for `cmd.file_type`, if partition columns
    /// are missing from a provided schema, or if an ORDER BY expression
    /// references a column not present in the resolved schema.
    async fn create(
        &self,
        state: &dyn Session,
        cmd: &CreateExternalTable,
    ) -> Result<Arc<dyn TableProvider>> {
        // This factory needs the concrete `SessionState` (for file format
        // factories and config options); a bare `Session` is not enough.
        let session_state = state
            .as_any()
            .downcast_ref::<SessionState>()
            .ok_or_else(|| {
                datafusion_common::internal_datafusion_err!(
                    "ListingTableFactory requires SessionState"
                )
            })?;
        // Look up the format (csv, parquet, ...) registered for the command's
        // file type and configure it from the table's `format.*` options.
        let file_format = session_state
            .get_file_format_factory(cmd.file_type.as_str())
            .ok_or(config_datafusion_err!(
                "Unable to create table with format {}! Could not find FileFormat.",
                cmd.file_type
            ))?
            .create(session_state, &cmd.options)?;

        let mut table_path =
            ListingTableUrl::parse(&cmd.location)?.with_table_ref(cmd.name.clone());
        // A collection (directory) gets no fixed extension filter; a single
        // file must be matched by its actual extension (which may be
        // non-standard, e.g. ".tbl" for CSV data).
        let file_extension = match table_path.is_collection() {
            true => "",
            false => &get_extension(cmd.location.as_str()),
        };

        let mut options = ListingOptions::new(file_format)
            .with_session_config_options(session_state.config())
            .with_file_extension(file_extension);

        // Resolve (provided_schema, partition columns). `provided_schema` is
        // `None` when the schema must be inferred from the files later on.
        let (provided_schema, table_partition_cols) = if cmd.schema.fields().is_empty() {
            // No schema given: optionally infer partition columns from the
            // directory layout (key=value style), unless explicit partition
            // columns were specified or inference is disabled by config.
            let infer_parts = session_state
                .config_options()
                .execution
                .listing_table_factory_infer_partitions;
            let part_cols = if cmd.table_partition_cols.is_empty() && infer_parts {
                options
                    .infer_partitions(session_state, &table_path)
                    .await?
                    .into_iter()
            } else {
                cmd.table_partition_cols.clone().into_iter()
            };
            (
                None,
                // Without a provided schema, partition values are typed as
                // dictionary-encoded strings.
                part_cols
                    .map(|p| {
                        (
                            p,
                            DataType::Dictionary(
                                Box::new(DataType::UInt16),
                                Box::new(DataType::Utf8),
                            ),
                        )
                    })
                    .collect::<Vec<_>>(),
            )
        } else {
            // A schema was provided: each declared partition column must
            // exist in it; its declared type is used for the partition.
            let schema = Arc::clone(cmd.schema.inner());
            let table_partition_cols = cmd
                .table_partition_cols
                .iter()
                .map(|col| {
                    schema
                        .field_with_name(col)
                        .map_err(|e| arrow_datafusion_err!(e))
                })
                .collect::<Result<Vec<_>>>()?
                .into_iter()
                .map(|f| (f.name().to_owned(), f.data_type().to_owned()))
                .collect();
            // Project the partition columns out of the file schema: their
            // values come from the paths, not from the file contents.
            let mut project_idx = Vec::new();
            for i in 0..schema.fields().len() {
                if !cmd.table_partition_cols.contains(schema.field(i).name()) {
                    project_idx.push(i);
                }
            }
            let schema = Arc::new(schema.project(&project_idx)?);
            (Some(schema), table_partition_cols)
        };

        options = options.with_table_partition_cols(table_partition_cols);
        options
            .validate_partitions(session_state, &table_path)
            .await?;

        let resolved_schema = match provided_schema {
            // No schema provided: infer it from the files at the table path.
            None => {
                // For a plain folder without an explicit glob, attach a glob
                // matching files of the expected type (including the
                // compression extension when one applies) so unrelated files
                // in the directory don't break inference.
                if table_path.is_folder() && table_path.get_glob().is_none() {
                    let glob = match options.format.compression_type() {
                        Some(compression) => {
                            match options.format.get_ext_with_compression(&compression) {
                                Ok(ext) => format!("*.{ext}"),
                                // Format doesn't support this compression:
                                // fall back to the bare file type extension.
                                Err(_) => format!("*.{}", cmd.file_type.to_lowercase()),
                            }
                        }
                        None => format!("*.{}", cmd.file_type.to_lowercase()),
                    };
                    table_path = table_path.with_glob(glob.as_ref())?;
                }
                let schema = options.infer_schema(session_state, &table_path).await?;
                let df_schema = Arc::clone(&schema).to_dfschema()?;
                // Every column referenced by the ORDER BY (file sort order)
                // clauses must exist in the inferred schema.
                let column_refs: HashSet<_> = cmd
                    .order_exprs
                    .iter()
                    .flat_map(|sort| sort.iter())
                    .flat_map(|s| s.expr.column_refs())
                    .collect();

                for column in &column_refs {
                    if !df_schema.has_column(column) {
                        return plan_err!("Column {column} is not in schema");
                    }
                }
                schema
            }
            Some(s) => s,
        };

        let config = ListingTableConfig::new(table_path)
            .with_listing_options(options.with_file_sort_order(cmd.order_exprs.clone()))
            .with_schema(resolved_schema);
        let provider = ListingTable::try_new(config)?
            .with_cache(state.runtime_env().cache_manager.get_file_statistic_cache());
        let table = provider
            .with_definition(cmd.definition.clone())
            .with_constraints(cmd.constraints.clone())
            .with_column_defaults(cmd.column_defaults.clone());
        // Pre-warm the file statistics cache by listing files now. This is
        // best-effort: a failure is logged and the table is still returned.
        if session_state.config().collect_statistics() {
            let filters = &[];
            let limit = None;
            if let Err(e) = table.list_files_for_scan(state, filters, limit).await {
                log::warn!("Failed to pre-warm statistics cache: {e}");
            }
        }
        Ok(Arc::new(table))
    }
}
/// Returns the extension of `path` prefixed with a dot (e.g. `".csv"`),
/// or an empty string when the path has no (UTF-8) extension.
fn get_extension(path: &str) -> String {
    Path::new(path)
        .extension()
        .and_then(|ext| ext.to_str())
        .map_or_else(String::new, |ext| format!(".{ext}"))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
datasource::file_format::csv::CsvFormat, execution::context::SessionContext,
test_util::parquet_test_data,
};
use datafusion_execution::cache::CacheAccessor;
use datafusion_execution::cache::cache_manager::CacheManagerConfig;
use datafusion_execution::cache::cache_unit::DefaultFileStatisticsCache;
use datafusion_execution::config::SessionConfig;
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
use glob::Pattern;
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use datafusion_common::parsers::CompressionTypeVariant;
use datafusion_common::{DFSchema, TableReference};
#[tokio::test]
async fn test_create_using_non_std_file_ext() {
    // CSV data stored with a non-standard ".tbl" extension.
    let data_file = tempfile::Builder::new()
        .prefix("foo")
        .suffix(".tbl")
        .tempfile()
        .unwrap();

    let ctx = SessionContext::new();
    let session_state = ctx.state();
    let cmd = CreateExternalTable::builder(
        TableReference::bare("foo"),
        data_file.path().to_str().unwrap().to_string(),
        "csv",
        Arc::new(DFSchema::empty()),
    )
    .with_options(HashMap::from([("format.has_header".into(), "true".into())]))
    .build();

    let provider = ListingTableFactory::new()
        .create(&session_state, &cmd)
        .await
        .unwrap();
    let table = provider.as_any().downcast_ref::<ListingTable>().unwrap();
    // The factory must pick up the file's actual extension, not assume ".csv".
    assert_eq!(".tbl", table.options().file_extension);
}
#[tokio::test]
async fn test_create_using_non_std_file_ext_csv_options() {
    // CSV data with a non-standard ".tbl" extension, plus format options.
    let data_file = tempfile::Builder::new()
        .prefix("foo")
        .suffix(".tbl")
        .tempfile()
        .unwrap();

    let ctx = SessionContext::new();
    let session_state = ctx.state();

    let table_options = HashMap::from([
        ("format.schema_infer_max_rec".to_owned(), "1000".to_owned()),
        ("format.has_header".to_owned(), "true".to_owned()),
    ]);
    let cmd = CreateExternalTable::builder(
        TableReference::bare("foo"),
        data_file.path().to_str().unwrap().to_string(),
        "csv",
        Arc::new(DFSchema::empty()),
    )
    .with_options(table_options)
    .build();

    let provider = ListingTableFactory::new()
        .create(&session_state, &cmd)
        .await
        .unwrap();
    let table = provider.as_any().downcast_ref::<ListingTable>().unwrap();

    // The "format.*" options must have been forwarded to the CSV format.
    let format = table.options().format.clone();
    let csv_format = format.as_any().downcast_ref::<CsvFormat>().unwrap();
    assert_eq!(csv_format.options().schema_infer_max_rec, Some(1000));
    // And the non-standard extension must be preserved.
    assert_eq!(".tbl", table.options().file_extension);
}
#[tokio::test]
async fn test_create_using_folder_with_compression() {
    // An (empty) directory with gzip compression configured: the table path
    // should receive a "*.csv.gz" glob and no fixed file extension.
    let dir = tempfile::tempdir().unwrap();
    let ctx = SessionContext::new();
    let session_state = ctx.state();

    let table_options = HashMap::from([
        ("format.schema_infer_max_rec".to_owned(), "1000".to_owned()),
        ("format.has_header".to_owned(), "true".to_owned()),
        ("format.compression".to_owned(), "gzip".to_owned()),
    ]);
    let cmd = CreateExternalTable::builder(
        TableReference::bare("foo"),
        dir.path().to_str().unwrap().to_string(),
        "csv",
        Arc::new(DFSchema::empty()),
    )
    .with_options(table_options)
    .build();

    let provider = ListingTableFactory::new()
        .create(&session_state, &cmd)
        .await
        .unwrap();
    let table = provider.as_any().downcast_ref::<ListingTable>().unwrap();

    // Compression option reached the CSV format.
    let format = table.options().format.clone();
    let csv_format = format.as_any().downcast_ref::<CsvFormat>().unwrap();
    assert_eq!(csv_format.options().compression, CompressionTypeVariant::GZIP);

    // Folders get no fixed extension; a glob for compressed CSV is attached.
    assert_eq!("", table.options().file_extension);
    let table_path = table.table_paths().first().unwrap();
    assert_eq!(
        table_path.get_glob().clone().unwrap(),
        Pattern::new("*.csv.gz").unwrap()
    );
}
#[tokio::test]
async fn test_create_using_folder_without_compression() {
    // An (empty) directory without compression: the attached glob should
    // match plain "*.csv" files.
    let dir = tempfile::tempdir().unwrap();
    let ctx = SessionContext::new();
    let session_state = ctx.state();

    let table_options = HashMap::from([
        ("format.schema_infer_max_rec".to_owned(), "1000".to_owned()),
        ("format.has_header".to_owned(), "true".to_owned()),
    ]);
    let cmd = CreateExternalTable::builder(
        TableReference::bare("foo"),
        dir.path().to_str().unwrap().to_string(),
        "csv",
        Arc::new(DFSchema::empty()),
    )
    .with_options(table_options)
    .build();

    let provider = ListingTableFactory::new()
        .create(&session_state, &cmd)
        .await
        .unwrap();
    let table = provider.as_any().downcast_ref::<ListingTable>().unwrap();

    assert_eq!("", table.options().file_extension);
    let table_path = table.table_paths().first().unwrap();
    assert_eq!(
        table_path.get_glob().clone().unwrap(),
        Pattern::new("*.csv").unwrap()
    );
}
#[tokio::test]
async fn test_odd_directory_names() {
    // Directory names containing dots must not be mistaken for file
    // extensions when the location is a folder.
    let dir = tempfile::tempdir().unwrap();
    let table_dir = dir.path().join("odd.v1").join("odd.v2");
    fs::create_dir_all(&table_dir).unwrap();

    let ctx = SessionContext::new();
    let session_state = ctx.state();
    let cmd = CreateExternalTable::builder(
        TableReference::bare("foo"),
        String::from(table_dir.to_str().unwrap()),
        "parquet",
        Arc::new(DFSchema::empty()),
    )
    .build();

    let provider = ListingTableFactory::new()
        .create(&session_state, &cmd)
        .await
        .unwrap();
    let table = provider.as_any().downcast_ref::<ListingTable>().unwrap();
    // No extension filter should be derived from "odd.v2".
    assert_eq!("", table.options().file_extension);
}
#[tokio::test]
async fn test_create_with_hive_partitions() {
    // Lay out a hive-partitioned directory:
    //   <tmp>/key1=value1/key2=value2/data.parquet
    let dir = tempfile::tempdir().unwrap();
    let partition_dir = dir.path().join("key1=value1").join("key2=value2");
    fs::create_dir_all(&partition_dir).unwrap();
    fs::File::create_new(partition_dir.join("data.parquet")).unwrap();

    // With the default config, partition columns are inferred from the
    // directory names as dictionary-encoded strings.
    let context = SessionContext::new();
    let state = context.state();
    let cmd = CreateExternalTable::builder(
        TableReference::bare("foo"),
        dir.path().to_str().unwrap(),
        "parquet",
        Arc::new(DFSchema::empty()),
    )
    .build();
    let provider = ListingTableFactory::new()
        .create(&state, &cmd)
        .await
        .unwrap();
    let table = provider.as_any().downcast_ref::<ListingTable>().unwrap();

    let dict_type =
        DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8));
    assert_eq!(
        vec![
            (String::from("key1"), dict_type.clone()),
            (String::from("key2"), dict_type.clone()),
        ],
        table.options().table_partition_cols
    );

    // With inference disabled, no partition columns are discovered.
    let mut cfg = SessionConfig::new();
    cfg.options_mut()
        .execution
        .listing_table_factory_infer_partitions = false;
    let context = SessionContext::new_with_config(cfg);
    let state = context.state();
    let cmd = CreateExternalTable::builder(
        TableReference::bare("foo"),
        dir.path().to_str().unwrap().to_string(),
        "parquet",
        Arc::new(DFSchema::empty()),
    )
    .build();
    let provider = ListingTableFactory::new()
        .create(&state, &cmd)
        .await
        .unwrap();
    let table = provider.as_any().downcast_ref::<ListingTable>().unwrap();
    assert!(table.options().table_partition_cols.is_empty());
}
#[tokio::test]
async fn test_statistics_cache_prewarming() {
    // Verifies that creating the table pre-warms the file statistics cache
    // if and only if `collect_statistics` is enabled in the session config.
    let factory = ListingTableFactory::new();
    let location = PathBuf::from(parquet_test_data())
        .join("alltypes_tiny_pages_plain.parquet")
        .to_string_lossy()
        .to_string();

    // Case 1: collect_statistics = true -> cache is populated during create.
    let file_statistics_cache = Arc::new(DefaultFileStatisticsCache::default());
    let cache_config = CacheManagerConfig::default()
        .with_files_statistics_cache(Some(file_statistics_cache.clone()));
    let runtime = RuntimeEnvBuilder::new()
        .with_cache_manager(cache_config)
        .build_arc()
        .unwrap();
    let mut config = SessionConfig::new();
    config.options_mut().execution.collect_statistics = true;
    let context = SessionContext::new_with_config_rt(config, runtime);
    let state = context.state();
    let name = TableReference::bare("test");
    let cmd = CreateExternalTable::builder(
        name,
        location.clone(),
        "parquet",
        Arc::new(DFSchema::empty()),
    )
    .build();
    let _table_provider = factory.create(&state, &cmd).await.unwrap();
    assert!(
        file_statistics_cache.len() > 0,
        "Statistics cache should be pre-warmed when collect_statistics is enabled"
    );

    // Case 2: collect_statistics = false -> cache must remain empty.
    // (A fresh cache and runtime are built so the two cases don't interact.)
    let file_statistics_cache = Arc::new(DefaultFileStatisticsCache::default());
    let cache_config = CacheManagerConfig::default()
        .with_files_statistics_cache(Some(file_statistics_cache.clone()));
    let runtime = RuntimeEnvBuilder::new()
        .with_cache_manager(cache_config)
        .build_arc()
        .unwrap();
    let mut config = SessionConfig::new();
    config.options_mut().execution.collect_statistics = false;
    let context = SessionContext::new_with_config_rt(config, runtime);
    let state = context.state();
    let name = TableReference::bare("test");
    let cmd = CreateExternalTable::builder(
        name,
        location,
        "parquet",
        Arc::new(DFSchema::empty()),
    )
    .build();
    let _table_provider = factory.create(&state, &cmd).await.unwrap();
    assert_eq!(
        file_statistics_cache.len(),
        0,
        "Statistics cache should not be pre-warmed when collect_statistics is disabled"
    );
}
#[tokio::test]
async fn test_create_with_invalid_session() {
    // Verifies that the factory returns a clean internal error (rather than
    // panicking) when handed a `Session` implementation that is not a
    // `SessionState` — the downcast at the top of `create` must fail.
    use async_trait::async_trait;
    use datafusion_catalog::Session;
    use datafusion_common::Result;
    use datafusion_common::config::TableOptions;
    use datafusion_execution::TaskContext;
    use datafusion_execution::config::SessionConfig;
    use datafusion_physical_expr::PhysicalExpr;
    use datafusion_physical_plan::ExecutionPlan;
    use std::any::Any;
    use std::collections::HashMap;
    use std::sync::Arc;

    // Minimal `Session` impl: only `session_id` and `as_any` do anything;
    // every other method is unreachable in this test and stubbed out.
    #[derive(Debug)]
    struct MockSession;

    #[async_trait]
    impl Session for MockSession {
        fn session_id(&self) -> &str {
            "mock_session"
        }
        fn config(&self) -> &SessionConfig {
            unimplemented!()
        }
        async fn create_physical_plan(
            &self,
            _logical_plan: &datafusion_expr::LogicalPlan,
        ) -> Result<Arc<dyn ExecutionPlan>> {
            unimplemented!()
        }
        fn create_physical_expr(
            &self,
            _expr: datafusion_expr::Expr,
            _df_schema: &DFSchema,
        ) -> Result<Arc<dyn PhysicalExpr>> {
            unimplemented!()
        }
        fn scalar_functions(
            &self,
        ) -> &HashMap<String, Arc<datafusion_expr::ScalarUDF>> {
            unimplemented!()
        }
        fn aggregate_functions(
            &self,
        ) -> &HashMap<String, Arc<datafusion_expr::AggregateUDF>> {
            unimplemented!()
        }
        fn window_functions(
            &self,
        ) -> &HashMap<String, Arc<datafusion_expr::WindowUDF>> {
            unimplemented!()
        }
        fn runtime_env(&self) -> &Arc<datafusion_execution::runtime_env::RuntimeEnv> {
            unimplemented!()
        }
        fn execution_props(
            &self,
        ) -> &datafusion_expr::execution_props::ExecutionProps {
            unimplemented!()
        }
        // Returning `self` here is what makes the factory's downcast to
        // `SessionState` fail.
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn table_options(&self) -> &TableOptions {
            unimplemented!()
        }
        fn table_options_mut(&mut self) -> &mut TableOptions {
            unimplemented!()
        }
        fn task_ctx(&self) -> Arc<TaskContext> {
            unimplemented!()
        }
    }

    let factory = ListingTableFactory::new();
    let mock_session = MockSession;
    let name = TableReference::bare("foo");
    let cmd = CreateExternalTable::builder(
        name,
        "foo.csv".to_string(),
        "csv",
        Arc::new(DFSchema::empty()),
    )
    .build();
    let result = factory.create(&mock_session, &cmd).await;
    assert!(result.is_err());
    // The error message must identify the unmet requirement.
    assert!(
        result
            .unwrap_err()
            .strip_backtrace()
            .contains("Internal error: ListingTableFactory requires SessionState")
    );
}
}