use crate::{
catalog::{CatalogList, MemoryCatalogList},
datasource::{
listing::{ListingOptions, ListingTable},
listing_table_factory::ListingTableFactory,
provider::TableProviderFactory,
},
datasource::{MemTable, ViewTable},
logical_expr::{PlanType, ToStringifiedPlan},
optimizer::optimizer::Optimizer,
physical_optimizer::optimizer::{PhysicalOptimizer, PhysicalOptimizerRule},
};
use datafusion_common::alias::AliasGenerator;
use datafusion_execution::registry::SerializerRegistry;
use datafusion_expr::{
logical_plan::{DdlStatement, Statement},
DescribeTable, StringifiedPlan, UserDefinedLogicalNode, WindowUDF,
};
pub use datafusion_physical_expr::execution_props::ExecutionProps;
use datafusion_physical_expr::var_provider::is_system_variables;
use parking_lot::RwLock;
use std::collections::hash_map::Entry;
use std::string::String;
use std::sync::Arc;
use std::{
collections::{HashMap, HashSet},
fmt::Debug,
};
use std::{ops::ControlFlow, sync::Weak};
use arrow::record_batch::RecordBatch;
use arrow::{
array::StringBuilder,
datatypes::{DataType, Field, Schema, SchemaRef},
};
use crate::catalog::{
schema::{MemorySchemaProvider, SchemaProvider},
{CatalogProvider, MemoryCatalogProvider},
};
use crate::dataframe::DataFrame;
use crate::datasource::{
listing::{ListingTableConfig, ListingTableUrl},
provider_as_source, TableProvider,
};
use crate::error::{DataFusionError, Result};
use crate::logical_expr::{
CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateMemoryTable,
CreateView, DropCatalogSchema, DropTable, DropView, Explain, LogicalPlan,
LogicalPlanBuilder, SetVariable, TableSource, TableType, UNNAMED_TABLE,
};
use crate::optimizer::OptimizerRule;
use datafusion_sql::{
parser::{CopyToSource, CopyToStatement},
planner::ParserOptions,
ResolvedTableReference, TableReference,
};
use sqlparser::dialect::dialect_from_str;
use crate::config::ConfigOptions;
use crate::datasource::physical_plan::{plan_to_csv, plan_to_json, plan_to_parquet};
use crate::execution::{runtime_env::RuntimeEnv, FunctionRegistry};
use crate::physical_plan::udaf::AggregateUDF;
use crate::physical_plan::udf::ScalarUDF;
use crate::physical_plan::ExecutionPlan;
use crate::physical_planner::DefaultPhysicalPlanner;
use crate::physical_planner::PhysicalPlanner;
use crate::variable::{VarProvider, VarType};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use datafusion_common::{OwnedTableReference, SchemaReference};
use datafusion_sql::{
parser::DFParser,
planner::{ContextProvider, SqlToRel},
};
use parquet::file::properties::WriterProperties;
use url::Url;
use crate::catalog::information_schema::{InformationSchemaProvider, INFORMATION_SCHEMA};
use crate::catalog::listing_schema::ListingSchemaProvider;
use crate::datasource::object_store::ObjectStoreUrl;
use datafusion_optimizer::{
analyzer::{Analyzer, AnalyzerRule},
OptimizerConfig,
};
use datafusion_sql::planner::object_name_to_table_reference;
use uuid::Uuid;
use crate::execution::options::ArrowReadOptions;
pub use datafusion_execution::config::SessionConfig;
pub use datafusion_execution::TaskContext;
use super::options::{
AvroReadOptions, CsvReadOptions, NdJsonReadOptions, ParquetReadOptions, ReadOptions,
};
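/// A type that can be converted into one or more [`ListingTableUrl`]s for
/// reading data files, such as a single path (`&str`, `String`, `&String`) or
/// a `Vec` of paths.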
pub trait DataFilePaths {
fn to_urls(self) -> Result<Vec<ListingTableUrl>>;
}
impl DataFilePaths for &str {
fn to_urls(self) -> Result<Vec<ListingTableUrl>> {
Ok(vec![ListingTableUrl::parse(self)?])
}
}
impl DataFilePaths for String {
fn to_urls(self) -> Result<Vec<ListingTableUrl>> {
Ok(vec![ListingTableUrl::parse(self)?])
}
}
impl DataFilePaths for &String {
fn to_urls(self) -> Result<Vec<ListingTableUrl>> {
Ok(vec![ListingTableUrl::parse(self)?])
}
}
impl<P> DataFilePaths for Vec<P>
where
P: AsRef<str>,
{
fn to_urls(self) -> Result<Vec<ListingTableUrl>> {
self.iter()
.map(ListingTableUrl::parse)
.collect::<Result<Vec<ListingTableUrl>>>()
}
}
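/// Main interface for executing queries with DataFusion. A `SessionContext`
/// owns a shared [`SessionState`] (configuration, catalogs, registered tables
/// and UDFs) and produces [`DataFrame`]s from SQL text or [`LogicalPlan`]s.
///
/// # Example
///
/// A minimal sketch; `example.csv` is a hypothetical file path.
///
/// ```no_run
/// use datafusion::prelude::*;
/// # use datafusion::error::Result;
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let ctx = SessionContext::new();
/// ctx.register_csv("example", "example.csv", CsvReadOptions::new()).await?;
/// let df = ctx.sql("SELECT a, MIN(b) FROM example GROUP BY a").await?;
/// let results = df.collect().await?;
/// # Ok(())
/// # }
/// ```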
#[derive(Clone)]
pub struct SessionContext {
session_id: String,
session_start_time: DateTime<Utc>,
state: Arc<RwLock<SessionState>>,
}
impl Default for SessionContext {
fn default() -> Self {
Self::new()
}
}
impl SessionContext {
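    /// Creates a new `SessionContext` using the default [`SessionConfig`].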
pub fn new() -> Self {
Self::with_config(SessionConfig::new())
}
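    /// Walks every registered catalog and schema and asks any
    /// [`ListingSchemaProvider`] to re-scan its location so that newly added
    /// files become visible as tables.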
pub async fn refresh_catalogs(&self) -> Result<()> {
let cat_names = self.catalog_names().clone();
for cat_name in cat_names.iter() {
let cat = self.catalog(cat_name.as_str()).ok_or_else(|| {
DataFusionError::Internal("Catalog not found!".to_string())
})?;
for schema_name in cat.schema_names() {
let schema = cat.schema(schema_name.as_str()).ok_or_else(|| {
DataFusionError::Internal("Schema not found!".to_string())
})?;
let lister = schema.as_any().downcast_ref::<ListingSchemaProvider>();
if let Some(lister) = lister {
lister.refresh(&self.state()).await?;
}
}
}
Ok(())
}
pub fn with_config(config: SessionConfig) -> Self {
let runtime = Arc::new(RuntimeEnv::default());
Self::with_config_rt(config, runtime)
}
pub fn with_config_rt(config: SessionConfig, runtime: Arc<RuntimeEnv>) -> Self {
let state = SessionState::with_config_rt(config, runtime);
Self::with_state(state)
}
pub fn with_state(state: SessionState) -> Self {
Self {
session_id: state.session_id.clone(),
session_start_time: Utc::now(),
state: Arc::new(RwLock::new(state)),
}
}
pub fn session_start_time(&self) -> DateTime<Utc> {
self.session_start_time
}
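    /// Registers a single in-memory [`RecordBatch`] as a table named
    /// `table_name`, returning any provider previously registered under that
    /// name.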
pub fn register_batch(
&self,
table_name: &str,
batch: RecordBatch,
) -> Result<Option<Arc<dyn TableProvider>>> {
let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
self.register_table(
TableReference::Bare {
table: table_name.into(),
},
Arc::new(table),
)
}
pub fn runtime_env(&self) -> Arc<RuntimeEnv> {
self.state.read().runtime_env.clone()
}
pub fn session_id(&self) -> String {
self.session_id.clone()
}
pub fn table_factory(
&self,
file_type: &str,
) -> Option<Arc<dyn TableProviderFactory>> {
self.state.read().table_factories().get(file_type).cloned()
}
pub fn enable_ident_normalization(&self) -> bool {
self.state
.read()
.config
.options()
.sql_parser
.enable_ident_normalization
}
pub fn copied_config(&self) -> SessionConfig {
self.state.read().config.clone()
}
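    /// Creates a [`DataFrame`] from SQL text. DDL statements (such as
    /// `CREATE TABLE`) are executed eagerly; queries are planned lazily.
    ///
    /// # Example
    ///
    /// A minimal sketch; `t.parquet` is a hypothetical file path.
    ///
    /// ```no_run
    /// # use datafusion::prelude::*;
    /// # use datafusion::error::Result;
    /// # #[tokio::main]
    /// # async fn main() -> Result<()> {
    /// let ctx = SessionContext::new();
    /// ctx.register_parquet("t", "t.parquet", ParquetReadOptions::default())
    ///     .await?;
    /// let df = ctx.sql("SELECT COUNT(*) FROM t").await?;
    /// df.show().await?;
    /// # Ok(())
    /// # }
    /// ```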
pub async fn sql(&self, sql: &str) -> Result<DataFrame> {
let plan = self.state().create_logical_plan(sql).await?;
self.execute_logical_plan(plan).await
}
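    /// Executes the given [`LogicalPlan`]: DDL, `SET`, and `DESCRIBE`
    /// statements take effect immediately, while any other plan is wrapped in
    /// a lazily evaluated [`DataFrame`].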
pub async fn execute_logical_plan(&self, plan: LogicalPlan) -> Result<DataFrame> {
match plan {
LogicalPlan::Ddl(ddl) => match ddl {
DdlStatement::CreateExternalTable(cmd) => {
self.create_external_table(&cmd).await
}
DdlStatement::CreateMemoryTable(cmd) => {
self.create_memory_table(cmd).await
}
DdlStatement::CreateView(cmd) => self.create_view(cmd).await,
DdlStatement::CreateCatalogSchema(cmd) => {
self.create_catalog_schema(cmd).await
}
DdlStatement::CreateCatalog(cmd) => self.create_catalog(cmd).await,
DdlStatement::DropTable(cmd) => self.drop_table(cmd).await,
DdlStatement::DropView(cmd) => self.drop_view(cmd).await,
DdlStatement::DropCatalogSchema(cmd) => self.drop_schema(cmd).await,
},
LogicalPlan::Statement(Statement::SetVariable(stmt)) => {
self.set_variable(stmt).await
}
LogicalPlan::DescribeTable(DescribeTable { schema, .. }) => {
self.return_describe_table_dataframe(schema).await
}
plan => Ok(DataFrame::new(self.state(), plan)),
}
}
fn return_empty_dataframe(&self) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::empty(false).build()?;
Ok(DataFrame::new(self.state(), plan))
}
async fn return_describe_table_record_batch(
&self,
schema: Arc<Schema>,
) -> Result<RecordBatch> {
let record_batch_schema = Arc::new(Schema::new(vec![
Field::new("column_name", DataType::Utf8, false),
Field::new("data_type", DataType::Utf8, false),
Field::new("is_nullable", DataType::Utf8, false),
]));
let mut column_names = StringBuilder::new();
let mut data_types = StringBuilder::new();
let mut is_nullables = StringBuilder::new();
        for field in schema.fields().iter() {
column_names.append_value(field.name());
let data_type = field.data_type();
data_types.append_value(format!("{data_type:?}"));
let nullable_str = if field.is_nullable() { "YES" } else { "NO" };
is_nullables.append_value(nullable_str);
}
let record_batch = RecordBatch::try_new(
record_batch_schema,
vec![
Arc::new(column_names.finish()),
Arc::new(data_types.finish()),
Arc::new(is_nullables.finish()),
],
)?;
Ok(record_batch)
}
async fn return_describe_table_dataframe(
&self,
schema: Arc<Schema>,
) -> Result<DataFrame> {
let record_batch = self.return_describe_table_record_batch(schema).await?;
self.read_batch(record_batch)
}
async fn create_external_table(
&self,
cmd: &CreateExternalTable,
) -> Result<DataFrame> {
let exist = self.table_exist(&cmd.name)?;
if exist {
match cmd.if_not_exists {
true => return self.return_empty_dataframe(),
false => {
return Err(DataFusionError::Execution(format!(
"Table '{}' already exists",
cmd.name
)));
}
}
}
let table_provider: Arc<dyn TableProvider> =
self.create_custom_table(cmd).await?;
self.register_table(&cmd.name, table_provider)?;
self.return_empty_dataframe()
}
async fn create_memory_table(&self, cmd: CreateMemoryTable) -> Result<DataFrame> {
let CreateMemoryTable {
name,
input,
if_not_exists,
or_replace,
primary_key,
} = cmd;
if !primary_key.is_empty() {
Err(DataFusionError::Execution(
"Primary keys on MemoryTables are not currently supported!".to_string(),
))?;
}
let input = Arc::try_unwrap(input).unwrap_or_else(|e| e.as_ref().clone());
let input = self.state().optimize(&input)?;
let table = self.table(&name).await;
match (if_not_exists, or_replace, table) {
(true, false, Ok(_)) => self.return_empty_dataframe(),
(false, true, Ok(_)) => {
self.deregister_table(&name)?;
let schema = Arc::new(input.schema().as_ref().into());
let physical = DataFrame::new(self.state(), input);
let batches: Vec<_> = physical.collect_partitioned().await?;
let table = Arc::new(MemTable::try_new(schema, batches)?);
self.register_table(&name, table)?;
self.return_empty_dataframe()
}
(true, true, Ok(_)) => Err(DataFusionError::Execution(
"'IF NOT EXISTS' cannot coexist with 'REPLACE'".to_string(),
)),
(_, _, Err(_)) => {
let schema = Arc::new(input.schema().as_ref().into());
let physical = DataFrame::new(self.state(), input);
let batches: Vec<_> = physical.collect_partitioned().await?;
let table = Arc::new(MemTable::try_new(schema, batches)?);
self.register_table(&name, table)?;
self.return_empty_dataframe()
}
(false, false, Ok(_)) => Err(DataFusionError::Execution(format!(
"Table '{name}' already exists"
))),
}
}
async fn create_view(&self, cmd: CreateView) -> Result<DataFrame> {
let CreateView {
name,
input,
or_replace,
definition,
} = cmd;
let view = self.table(&name).await;
match (or_replace, view) {
(true, Ok(_)) => {
self.deregister_table(&name)?;
let table = Arc::new(ViewTable::try_new((*input).clone(), definition)?);
self.register_table(&name, table)?;
self.return_empty_dataframe()
}
(_, Err(_)) => {
let table = Arc::new(ViewTable::try_new((*input).clone(), definition)?);
self.register_table(&name, table)?;
self.return_empty_dataframe()
}
(false, Ok(_)) => Err(DataFusionError::Execution(format!(
"Table '{name}' already exists"
))),
}
}
async fn create_catalog_schema(&self, cmd: CreateCatalogSchema) -> Result<DataFrame> {
let CreateCatalogSchema {
schema_name,
if_not_exists,
..
} = cmd;
let tokens: Vec<&str> = schema_name.split('.').collect();
let (catalog, schema_name) = match tokens.len() {
1 => {
let state = self.state.read();
let name = &state.config.options().catalog.default_catalog;
let catalog = state.catalog_list.catalog(name).ok_or_else(|| {
DataFusionError::Execution(format!(
"Missing default catalog '{name}'"
))
})?;
(catalog, tokens[0])
}
2 => {
let name = &tokens[0];
let catalog = self.catalog(name).ok_or_else(|| {
DataFusionError::Execution(format!("Missing catalog '{name}'"))
})?;
(catalog, tokens[1])
}
_ => {
return Err(DataFusionError::Execution(format!(
"Unable to parse catalog from {schema_name}"
)))
}
};
let schema = catalog.schema(schema_name);
match (if_not_exists, schema) {
(true, Some(_)) => self.return_empty_dataframe(),
(true, None) | (false, None) => {
let schema = Arc::new(MemorySchemaProvider::new());
catalog.register_schema(schema_name, schema)?;
self.return_empty_dataframe()
}
(false, Some(_)) => Err(DataFusionError::Execution(format!(
"Schema '{schema_name}' already exists"
))),
}
}
async fn create_catalog(&self, cmd: CreateCatalog) -> Result<DataFrame> {
let CreateCatalog {
catalog_name,
if_not_exists,
..
} = cmd;
let catalog = self.catalog(catalog_name.as_str());
match (if_not_exists, catalog) {
(true, Some(_)) => self.return_empty_dataframe(),
(true, None) | (false, None) => {
let new_catalog = Arc::new(MemoryCatalogProvider::new());
self.state
.write()
.catalog_list
.register_catalog(catalog_name, new_catalog);
self.return_empty_dataframe()
}
(false, Some(_)) => Err(DataFusionError::Execution(format!(
"Catalog '{catalog_name}' already exists"
))),
}
}
async fn drop_table(&self, cmd: DropTable) -> Result<DataFrame> {
let DropTable {
name, if_exists, ..
} = cmd;
let result = self.find_and_deregister(&name, TableType::Base).await;
match (result, if_exists) {
(Ok(true), _) => self.return_empty_dataframe(),
(_, true) => self.return_empty_dataframe(),
(_, _) => Err(DataFusionError::Execution(format!(
"Table '{name}' doesn't exist."
))),
}
}
async fn drop_view(&self, cmd: DropView) -> Result<DataFrame> {
let DropView {
name, if_exists, ..
} = cmd;
let result = self.find_and_deregister(&name, TableType::View).await;
match (result, if_exists) {
(Ok(true), _) => self.return_empty_dataframe(),
(_, true) => self.return_empty_dataframe(),
(_, _) => Err(DataFusionError::Execution(format!(
"View '{name}' doesn't exist."
))),
}
}
async fn drop_schema(&self, cmd: DropCatalogSchema) -> Result<DataFrame> {
let DropCatalogSchema {
name,
if_exists: allow_missing,
cascade,
schema: _,
} = cmd;
let catalog = {
let state = self.state.read();
let catalog_name = match &name {
SchemaReference::Full { catalog, .. } => catalog.to_string(),
SchemaReference::Bare { .. } => {
state.config_options().catalog.default_catalog.to_string()
}
};
if let Some(catalog) = state.catalog_list.catalog(&catalog_name) {
catalog
} else if allow_missing {
return self.return_empty_dataframe();
} else {
return self.schema_doesnt_exist_err(name);
}
};
let dereg = catalog.deregister_schema(name.schema_name(), cascade)?;
match (dereg, allow_missing) {
(None, true) => self.return_empty_dataframe(),
(None, false) => self.schema_doesnt_exist_err(name),
(Some(_), _) => self.return_empty_dataframe(),
}
}
fn schema_doesnt_exist_err(
&self,
schemaref: SchemaReference<'_>,
) -> Result<DataFrame> {
Err(DataFusionError::Execution(format!(
"Schema '{schemaref}' doesn't exist."
)))
}
async fn set_variable(&self, stmt: SetVariable) -> Result<DataFrame> {
let SetVariable {
variable, value, ..
} = stmt;
let mut state = self.state.write();
state.config.options_mut().set(&variable, &value)?;
drop(state);
self.return_empty_dataframe()
}
async fn create_custom_table(
&self,
cmd: &CreateExternalTable,
) -> Result<Arc<dyn TableProvider>> {
let state = self.state.read().clone();
let file_type = cmd.file_type.to_uppercase();
        let factory = state
            .table_factories
            .get(file_type.as_str())
            .ok_or_else(|| {
                DataFusionError::Execution(format!(
                    "Unable to find factory for {}",
                    cmd.file_type
                ))
            })?;
        let table = factory.create(&state, cmd).await?;
Ok(table)
}
async fn find_and_deregister<'a>(
&self,
table_ref: impl Into<TableReference<'a>>,
table_type: TableType,
) -> Result<bool> {
let table_ref = table_ref.into();
let table = table_ref.table().to_owned();
let maybe_schema = {
let state = self.state.read();
let resolved = state.resolve_table_ref(table_ref);
state
.catalog_list
.catalog(&resolved.catalog)
.and_then(|c| c.schema(&resolved.schema))
};
if let Some(schema) = maybe_schema {
if let Some(table_provider) = schema.table(&table).await {
if table_provider.table_type() == table_type {
schema.deregister_table(&table)?;
return Ok(true);
}
}
}
Ok(false)
}
pub fn register_variable(
&self,
variable_type: VarType,
provider: Arc<dyn VarProvider + Send + Sync>,
) {
self.state
.write()
.execution_props
.add_var_provider(variable_type, provider);
}
pub fn register_udf(&self, f: ScalarUDF) {
self.state
.write()
.scalar_functions
.insert(f.name.clone(), Arc::new(f));
}
pub fn register_udaf(&self, f: AggregateUDF) {
self.state
.write()
.aggregate_functions
.insert(f.name.clone(), Arc::new(f));
}
pub fn register_udwf(&self, f: WindowUDF) {
self.state
.write()
.window_functions
.insert(f.name.clone(), Arc::new(f));
}
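    /// Shared helper for the `read_*` methods: resolves the paths and schema,
    /// builds a [`ListingTable`], and wraps it in a [`DataFrame`].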
async fn _read_type<'a, P: DataFilePaths>(
&self,
table_paths: P,
options: impl ReadOptions<'a>,
) -> Result<DataFrame> {
let table_paths = table_paths.to_urls()?;
let session_config = self.copied_config();
let listing_options = options.to_listing_options(&session_config);
let resolved_schema = options
.get_resolved_schema(&session_config, self.state(), table_paths[0].clone())
.await?;
let config = ListingTableConfig::new_with_multi_paths(table_paths)
.with_listing_options(listing_options)
.with_schema(resolved_schema);
let provider = ListingTable::try_new(config)?;
self.read_table(Arc::new(provider))
}
pub async fn read_avro<P: DataFilePaths>(
&self,
table_paths: P,
options: AvroReadOptions<'_>,
) -> Result<DataFrame> {
self._read_type(table_paths, options).await
}
pub async fn read_json<P: DataFilePaths>(
&self,
table_paths: P,
options: NdJsonReadOptions<'_>,
) -> Result<DataFrame> {
self._read_type(table_paths, options).await
}
pub async fn read_arrow<P: DataFilePaths>(
&self,
table_paths: P,
options: ArrowReadOptions<'_>,
) -> Result<DataFrame> {
self._read_type(table_paths, options).await
}
pub fn read_empty(&self) -> Result<DataFrame> {
Ok(DataFrame::new(
self.state(),
LogicalPlanBuilder::empty(true).build()?,
))
}
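    /// Creates a [`DataFrame`] for reading the CSV file(s) at `table_paths`.
    ///
    /// # Example
    ///
    /// A minimal sketch; `data.csv` is a hypothetical file path.
    ///
    /// ```no_run
    /// # use datafusion::prelude::*;
    /// # use datafusion::error::Result;
    /// # #[tokio::main]
    /// # async fn main() -> Result<()> {
    /// let ctx = SessionContext::new();
    /// let df = ctx.read_csv("data.csv", CsvReadOptions::new()).await?;
    /// let df = df.filter(col("a").gt(lit(1)))?;
    /// df.show().await?;
    /// # Ok(())
    /// # }
    /// ```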
pub async fn read_csv<P: DataFilePaths>(
&self,
table_paths: P,
options: CsvReadOptions<'_>,
) -> Result<DataFrame> {
self._read_type(table_paths, options).await
}
pub async fn read_parquet<P: DataFilePaths>(
&self,
table_paths: P,
options: ParquetReadOptions<'_>,
) -> Result<DataFrame> {
self._read_type(table_paths, options).await
}
pub fn read_table(&self, provider: Arc<dyn TableProvider>) -> Result<DataFrame> {
Ok(DataFrame::new(
self.state(),
LogicalPlanBuilder::scan(UNNAMED_TABLE, provider_as_source(provider), None)?
.build()?,
))
}
pub fn read_batch(&self, batch: RecordBatch) -> Result<DataFrame> {
let provider = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
Ok(DataFrame::new(
self.state(),
LogicalPlanBuilder::scan(
UNNAMED_TABLE,
provider_as_source(Arc::new(provider)),
None,
)?
.build()?,
))
}
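    /// Registers a [`ListingTable`] built from `table_path` and `options`
    /// under `name`, inferring the schema from the files unless
    /// `provided_schema` is given. Schema inference is rejected for infinite
    /// (streaming) sources.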
pub async fn register_listing_table(
&self,
name: &str,
table_path: impl AsRef<str>,
options: ListingOptions,
provided_schema: Option<SchemaRef>,
sql_definition: Option<String>,
) -> Result<()> {
let table_path = ListingTableUrl::parse(table_path)?;
let resolved_schema = match (provided_schema, options.infinite_source) {
(Some(s), _) => s,
(None, false) => options.infer_schema(&self.state(), &table_path).await?,
(None, true) => {
return Err(DataFusionError::Plan(
"Schema inference for infinite data sources is not supported."
.to_string(),
))
}
};
let config = ListingTableConfig::new(table_path)
.with_listing_options(options)
.with_schema(resolved_schema);
let table = ListingTable::try_new(config)?.with_definition(sql_definition);
self.register_table(
TableReference::Bare { table: name.into() },
Arc::new(table),
)?;
Ok(())
}
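    /// Registers the CSV file(s) at `table_path` as a table named `name`, so
    /// they can be referenced from SQL statements executed against this
    /// context.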
pub async fn register_csv(
&self,
name: &str,
table_path: &str,
options: CsvReadOptions<'_>,
) -> Result<()> {
let listing_options = options.to_listing_options(&self.copied_config());
self.register_listing_table(
name,
table_path,
listing_options,
options.schema.map(|s| Arc::new(s.to_owned())),
None,
)
.await?;
Ok(())
}
pub async fn register_json(
&self,
name: &str,
table_path: &str,
options: NdJsonReadOptions<'_>,
) -> Result<()> {
let listing_options = options.to_listing_options(&self.copied_config());
self.register_listing_table(
name,
table_path,
listing_options,
options.schema.map(|s| Arc::new(s.to_owned())),
None,
)
.await?;
Ok(())
}
pub async fn register_parquet(
&self,
name: &str,
table_path: &str,
options: ParquetReadOptions<'_>,
) -> Result<()> {
let listing_options = options.to_listing_options(&self.state.read().config);
self.register_listing_table(name, table_path, listing_options, None, None)
.await?;
Ok(())
}
pub async fn register_avro(
&self,
name: &str,
table_path: &str,
options: AvroReadOptions<'_>,
) -> Result<()> {
let listing_options = options.to_listing_options(&self.copied_config());
self.register_listing_table(
name,
table_path,
listing_options,
options.schema.map(|s| Arc::new(s.to_owned())),
None,
)
.await?;
Ok(())
}
pub async fn register_arrow(
&self,
name: &str,
table_path: &str,
options: ArrowReadOptions<'_>,
) -> Result<()> {
let listing_options = options.to_listing_options(&self.copied_config());
self.register_listing_table(
name,
table_path,
listing_options,
options.schema.map(|s| Arc::new(s.to_owned())),
None,
)
.await?;
Ok(())
}
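    /// Registers a [`CatalogProvider`] under `name`, returning the catalog
    /// previously registered under that name, if any.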
pub fn register_catalog(
&self,
name: impl Into<String>,
catalog: Arc<dyn CatalogProvider>,
) -> Option<Arc<dyn CatalogProvider>> {
let name = name.into();
self.state
.read()
.catalog_list
.register_catalog(name, catalog)
}
pub fn catalog_names(&self) -> Vec<String> {
self.state.read().catalog_list.catalog_names()
}
pub fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
self.state.read().catalog_list.catalog(name)
}
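    /// Registers a [`TableProvider`] under the given table reference,
    /// returning the provider previously registered under that name, if any.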
pub fn register_table<'a>(
&'a self,
table_ref: impl Into<TableReference<'a>>,
provider: Arc<dyn TableProvider>,
) -> Result<Option<Arc<dyn TableProvider>>> {
let table_ref = table_ref.into();
let table = table_ref.table().to_owned();
self.state
.read()
.schema_for_ref(table_ref)?
.register_table(table, provider)
}
pub fn deregister_table<'a>(
&'a self,
table_ref: impl Into<TableReference<'a>>,
) -> Result<Option<Arc<dyn TableProvider>>> {
let table_ref = table_ref.into();
let table = table_ref.table().to_owned();
self.state
.read()
.schema_for_ref(table_ref)?
.deregister_table(&table)
}
pub fn table_exist<'a>(
&'a self,
table_ref: impl Into<TableReference<'a>>,
) -> Result<bool> {
let table_ref = table_ref.into();
let table = table_ref.table().to_owned();
Ok(self
.state
.read()
.schema_for_ref(table_ref)?
.table_exist(&table))
}
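    /// Returns a [`DataFrame`] that scans the registered table `table_ref`,
    /// or an error if no such table exists.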
pub async fn table<'a>(
&self,
table_ref: impl Into<TableReference<'a>>,
) -> Result<DataFrame> {
let table_ref = table_ref.into();
let provider = self.table_provider(table_ref.to_owned_reference()).await?;
let plan = LogicalPlanBuilder::scan(
table_ref.to_owned_reference(),
provider_as_source(Arc::clone(&provider)),
None,
)?
.build()?;
Ok(DataFrame::new(self.state(), plan))
}
pub async fn table_provider<'a>(
&self,
table_ref: impl Into<TableReference<'a>>,
) -> Result<Arc<dyn TableProvider>> {
let table_ref = table_ref.into();
let table = table_ref.table().to_string();
let schema = self.state.read().schema_for_ref(table_ref)?;
match schema.table(&table).await {
Some(ref provider) => Ok(Arc::clone(provider)),
_ => Err(DataFusionError::Plan(format!("No table named '{table}'"))),
}
}
#[deprecated(
since = "23.0.0",
note = "Please use the catalog provider interface (`SessionContext::catalog`) to examine available catalogs, schemas, and tables"
)]
pub fn tables(&self) -> Result<HashSet<String>> {
Ok(self
.state
.read()
.schema_for_ref(TableReference::Bare { table: "".into() })?
.table_names()
.iter()
.cloned()
.collect())
}
#[deprecated(
since = "23.0.0",
note = "Use SessionState::optimize to ensure a consistent state for planning and execution"
)]
pub fn optimize(&self, plan: &LogicalPlan) -> Result<LogicalPlan> {
self.state.read().optimize(plan)
}
#[deprecated(
since = "23.0.0",
note = "Use SessionState::create_physical_plan or DataFrame::create_physical_plan to ensure a consistent state for planning and execution"
)]
pub async fn create_physical_plan(
&self,
logical_plan: &LogicalPlan,
) -> Result<Arc<dyn ExecutionPlan>> {
self.state().create_physical_plan(logical_plan).await
}
pub async fn write_csv(
&self,
plan: Arc<dyn ExecutionPlan>,
path: impl AsRef<str>,
) -> Result<()> {
plan_to_csv(self.task_ctx(), plan, path).await
}
pub async fn write_json(
&self,
plan: Arc<dyn ExecutionPlan>,
path: impl AsRef<str>,
) -> Result<()> {
plan_to_json(self.task_ctx(), plan, path).await
}
pub async fn write_parquet(
&self,
plan: Arc<dyn ExecutionPlan>,
path: impl AsRef<str>,
writer_properties: Option<WriterProperties>,
) -> Result<()> {
plan_to_parquet(self.task_ctx(), plan, path, writer_properties).await
}
pub fn task_ctx(&self) -> Arc<TaskContext> {
Arc::new(TaskContext::from(self))
}
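    /// Returns a snapshot (clone) of the current [`SessionState`], refreshing
    /// its per-execution properties (such as the query start time) so the
    /// snapshot can be used to plan and execute a new query.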
pub fn state(&self) -> SessionState {
let mut state = self.state.read().clone();
state.execution_props.start_execution();
state
}
pub fn state_weak_ref(&self) -> Weak<RwLock<SessionState>> {
Arc::downgrade(&self.state)
}
pub fn register_catalog_list(&mut self, catalog_list: Arc<dyn CatalogList>) {
self.state.write().catalog_list = catalog_list;
}
}
impl FunctionRegistry for SessionContext {
fn udfs(&self) -> HashSet<String> {
self.state.read().udfs()
}
fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> {
self.state.read().udf(name)
}
fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
self.state.read().udaf(name)
}
fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
self.state.read().udwf(name)
}
}
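/// Converts an optimized [`LogicalPlan`] into an [`ExecutionPlan`]; implement
/// this trait to customize DataFusion's physical planning.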
#[async_trait]
pub trait QueryPlanner {
async fn create_physical_plan(
&self,
logical_plan: &LogicalPlan,
session_state: &SessionState,
) -> Result<Arc<dyn ExecutionPlan>>;
}
struct DefaultQueryPlanner {}
#[async_trait]
impl QueryPlanner for DefaultQueryPlanner {
async fn create_physical_plan(
&self,
logical_plan: &LogicalPlan,
session_state: &SessionState,
) -> Result<Arc<dyn ExecutionPlan>> {
let planner = DefaultPhysicalPlanner::default();
planner
.create_physical_plan(logical_plan, session_state)
.await
}
}
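/// Per-session planning and execution state: the analyzer, logical and
/// physical optimizer rules, the catalog list, registered UDFs/UDAFs/UDWFs,
/// table factories, session configuration, and the shared [`RuntimeEnv`].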
#[derive(Clone)]
pub struct SessionState {
session_id: String,
analyzer: Analyzer,
optimizer: Optimizer,
physical_optimizers: PhysicalOptimizer,
query_planner: Arc<dyn QueryPlanner + Send + Sync>,
catalog_list: Arc<dyn CatalogList>,
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
window_functions: HashMap<String, Arc<WindowUDF>>,
serializer_registry: Arc<dyn SerializerRegistry>,
config: SessionConfig,
execution_props: ExecutionProps,
table_factories: HashMap<String, Arc<dyn TableProviderFactory>>,
runtime_env: Arc<RuntimeEnv>,
}
impl Debug for SessionState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SessionState")
.field("session_id", &self.session_id)
.finish()
}
}
#[deprecated(
since = "23.0.0",
note = "See SessionContext::with_config() or SessionState::with_config_rt"
)]
pub fn default_session_builder(config: SessionConfig) -> SessionState {
SessionState::with_config_rt(config, Arc::new(RuntimeEnv::default()))
}
impl SessionState {
pub fn with_config_rt(config: SessionConfig, runtime: Arc<RuntimeEnv>) -> Self {
let catalog_list = Arc::new(MemoryCatalogList::new()) as Arc<dyn CatalogList>;
Self::with_config_rt_and_catalog_list(config, runtime, catalog_list)
}
pub fn with_config_rt_and_catalog_list(
config: SessionConfig,
runtime: Arc<RuntimeEnv>,
catalog_list: Arc<dyn CatalogList>,
) -> Self {
let session_id = Uuid::new_v4().to_string();
let mut table_factories: HashMap<String, Arc<dyn TableProviderFactory>> =
HashMap::new();
table_factories.insert("PARQUET".into(), Arc::new(ListingTableFactory::new()));
table_factories.insert("CSV".into(), Arc::new(ListingTableFactory::new()));
table_factories.insert("JSON".into(), Arc::new(ListingTableFactory::new()));
table_factories.insert("NDJSON".into(), Arc::new(ListingTableFactory::new()));
table_factories.insert("AVRO".into(), Arc::new(ListingTableFactory::new()));
table_factories.insert("ARROW".into(), Arc::new(ListingTableFactory::new()));
if config.create_default_catalog_and_schema() {
let default_catalog = MemoryCatalogProvider::new();
default_catalog
.register_schema(
&config.options().catalog.default_schema,
Arc::new(MemorySchemaProvider::new()),
)
.expect("memory catalog provider can register schema");
Self::register_default_schema(
&config,
&table_factories,
&runtime,
&default_catalog,
);
catalog_list.register_catalog(
config.options().catalog.default_catalog.clone(),
Arc::new(default_catalog),
);
}
SessionState {
session_id,
analyzer: Analyzer::new(),
optimizer: Optimizer::new(),
physical_optimizers: PhysicalOptimizer::new(),
query_planner: Arc::new(DefaultQueryPlanner {}),
catalog_list,
scalar_functions: HashMap::new(),
aggregate_functions: HashMap::new(),
window_functions: HashMap::new(),
serializer_registry: Arc::new(EmptySerializerRegistry),
config,
execution_props: ExecutionProps::new(),
runtime_env: runtime,
table_factories,
}
}
fn register_default_schema(
config: &SessionConfig,
table_factories: &HashMap<String, Arc<dyn TableProviderFactory>>,
runtime: &Arc<RuntimeEnv>,
default_catalog: &MemoryCatalogProvider,
) {
let url = config.options().catalog.location.as_ref();
let format = config.options().catalog.format.as_ref();
let (url, format) = match (url, format) {
(Some(url), Some(format)) => (url, format),
_ => return,
};
let url = url.to_string();
let format = format.to_string();
let has_header = config.options().catalog.has_header;
let url = Url::parse(url.as_str()).expect("Invalid default catalog location!");
let authority = match url.host_str() {
Some(host) => format!("{}://{}", url.scheme(), host),
None => format!("{}://", url.scheme()),
};
let path = &url.as_str()[authority.len()..];
let path = object_store::path::Path::parse(path).expect("Can't parse path");
let store = ObjectStoreUrl::parse(authority.as_str())
.expect("Invalid default catalog url");
let store = match runtime.object_store(store) {
Ok(store) => store,
_ => return,
};
let factory = match table_factories.get(format.as_str()) {
Some(factory) => factory,
_ => return,
};
let schema = ListingSchemaProvider::new(
authority,
path,
factory.clone(),
store,
format,
has_header,
);
let _ = default_catalog
.register_schema("default", Arc::new(schema))
.expect("Failed to register default schema");
}
fn resolve_table_ref<'a>(
&'a self,
table_ref: impl Into<TableReference<'a>>,
) -> ResolvedTableReference<'a> {
let catalog = &self.config_options().catalog;
table_ref
.into()
.resolve(&catalog.default_catalog, &catalog.default_schema)
}
pub(crate) fn schema_for_ref<'a>(
&'a self,
table_ref: impl Into<TableReference<'a>>,
) -> Result<Arc<dyn SchemaProvider>> {
let resolved_ref = self.resolve_table_ref(table_ref);
if self.config.information_schema() && resolved_ref.schema == INFORMATION_SCHEMA {
return Ok(Arc::new(InformationSchemaProvider::new(
self.catalog_list.clone(),
)));
}
self.catalog_list
.catalog(&resolved_ref.catalog)
.ok_or_else(|| {
DataFusionError::Plan(format!(
"failed to resolve catalog: {}",
resolved_ref.catalog
))
})?
.schema(&resolved_ref.schema)
.ok_or_else(|| {
DataFusionError::Plan(format!(
"failed to resolve schema: {}",
resolved_ref.schema
))
})
}
pub fn with_session_id(mut self, session_id: String) -> Self {
self.session_id = session_id;
self
}
pub fn with_query_planner(
mut self,
query_planner: Arc<dyn QueryPlanner + Send + Sync>,
) -> Self {
self.query_planner = query_planner;
self
}
pub fn with_analyzer_rules(
mut self,
rules: Vec<Arc<dyn AnalyzerRule + Send + Sync>>,
) -> Self {
self.analyzer = Analyzer::with_rules(rules);
self
}
pub fn with_optimizer_rules(
mut self,
rules: Vec<Arc<dyn OptimizerRule + Send + Sync>>,
) -> Self {
self.optimizer = Optimizer::with_rules(rules);
self
}
pub fn with_physical_optimizer_rules(
mut self,
physical_optimizers: Vec<Arc<dyn PhysicalOptimizerRule + Send + Sync>>,
) -> Self {
self.physical_optimizers = PhysicalOptimizer::with_rules(physical_optimizers);
self
}
pub fn add_analyzer_rule(
mut self,
analyzer_rule: Arc<dyn AnalyzerRule + Send + Sync>,
) -> Self {
self.analyzer.rules.push(analyzer_rule);
self
}
pub fn add_optimizer_rule(
mut self,
optimizer_rule: Arc<dyn OptimizerRule + Send + Sync>,
) -> Self {
self.optimizer.rules.push(optimizer_rule);
self
}
pub fn add_physical_optimizer_rule(
mut self,
optimizer_rule: Arc<dyn PhysicalOptimizerRule + Send + Sync>,
) -> Self {
self.physical_optimizers.rules.push(optimizer_rule);
self
}
pub fn with_serializer_registry(
mut self,
registry: Arc<dyn SerializerRegistry>,
) -> Self {
self.serializer_registry = registry;
self
}
pub fn table_factories(&self) -> &HashMap<String, Arc<dyn TableProviderFactory>> {
&self.table_factories
}
pub fn table_factories_mut(
&mut self,
) -> &mut HashMap<String, Arc<dyn TableProviderFactory>> {
&mut self.table_factories
}
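    /// Parses `sql` with the given SQL `dialect` into a single statement,
    /// returning an error for an unknown dialect or when zero or more than
    /// one statement is supplied.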
pub fn sql_to_statement(
&self,
sql: &str,
dialect: &str,
) -> Result<datafusion_sql::parser::Statement> {
let dialect = dialect_from_str(dialect).ok_or_else(|| {
DataFusionError::Plan(format!(
"Unsupported SQL dialect: {dialect}. Available dialects: \
Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \
MsSQL, ClickHouse, BigQuery, Ansi."
))
})?;
let mut statements = DFParser::parse_sql_with_dialect(sql, dialect.as_ref())?;
if statements.len() > 1 {
return Err(DataFusionError::NotImplemented(
"The context currently only supports a single SQL statement".to_string(),
));
}
let statement = statements.pop_front().ok_or_else(|| {
DataFusionError::NotImplemented(
"The context requires a statement!".to_string(),
)
})?;
Ok(statement)
}
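    /// Collects every table reference mentioned by `statement`, adding the
    /// `information_schema` tables when that feature is enabled, so the
    /// planner can resolve all required providers up front.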
pub fn resolve_table_references(
&self,
statement: &datafusion_sql::parser::Statement,
) -> Result<Vec<OwnedTableReference>> {
use crate::catalog::information_schema::INFORMATION_SCHEMA_TABLES;
use datafusion_sql::parser::Statement as DFStatement;
use sqlparser::ast::*;
let mut relations = hashbrown::HashSet::with_capacity(10);
struct RelationVisitor<'a>(&'a mut hashbrown::HashSet<ObjectName>);
impl<'a> RelationVisitor<'a> {
fn insert(&mut self, relation: &ObjectName) {
self.0.get_or_insert_with(relation, |_| relation.clone());
}
}
impl<'a> Visitor for RelationVisitor<'a> {
type Break = ();
fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<()> {
self.insert(relation);
ControlFlow::Continue(())
}
fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<()> {
if let Statement::ShowCreate {
obj_type: ShowCreateObject::Table | ShowCreateObject::View,
obj_name,
} = statement
{
self.insert(obj_name)
}
ControlFlow::Continue(())
}
}
let mut visitor = RelationVisitor(&mut relations);
match statement {
DFStatement::Statement(s) => {
let _ = s.as_ref().visit(&mut visitor);
}
DFStatement::CreateExternalTable(table) => {
visitor
.0
.insert(ObjectName(vec![Ident::from(table.name.as_str())]));
}
DFStatement::DescribeTableStmt(table) => visitor.insert(&table.table_name),
DFStatement::CopyTo(CopyToStatement {
source,
target: _,
options: _,
}) => match source {
CopyToSource::Relation(table_name) => {
visitor.insert(table_name);
}
CopyToSource::Query(query) => {
                    let _ = query.visit(&mut visitor);
}
},
}
if self.config.information_schema() {
for s in INFORMATION_SCHEMA_TABLES {
relations.insert(ObjectName(vec![
Ident::new(INFORMATION_SCHEMA),
Ident::new(*s),
]));
}
}
let enable_ident_normalization =
self.config.options().sql_parser.enable_ident_normalization;
relations
.into_iter()
.map(|x| object_name_to_table_reference(x, enable_ident_normalization))
.collect::<Result<_>>()
}
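    /// Converts a parsed statement into a [`LogicalPlan`], first resolving a
    /// [`TableSource`] for every table referenced by the statement.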
pub async fn statement_to_plan(
&self,
statement: datafusion_sql::parser::Statement,
) -> Result<LogicalPlan> {
let references = self.resolve_table_references(&statement)?;
let mut provider = SessionContextProvider {
state: self,
tables: HashMap::with_capacity(references.len()),
};
let enable_ident_normalization =
self.config.options().sql_parser.enable_ident_normalization;
let parse_float_as_decimal =
self.config.options().sql_parser.parse_float_as_decimal;
for reference in references {
let table = reference.table();
let resolved = self.resolve_table_ref(&reference);
if let Entry::Vacant(v) = provider.tables.entry(resolved.to_string()) {
if let Ok(schema) = self.schema_for_ref(resolved) {
if let Some(table) = schema.table(table).await {
v.insert(provider_as_source(table));
}
}
}
}
let query = SqlToRel::new_with_options(
&provider,
ParserOptions {
parse_float_as_decimal,
enable_ident_normalization,
},
);
query.statement_to_plan(statement)
}
pub async fn create_logical_plan(&self, sql: &str) -> Result<LogicalPlan> {
let dialect = self.config.options().sql_parser.dialect.as_str();
let statement = self.sql_to_statement(sql, dialect)?;
let plan = self.statement_to_plan(statement).await?;
Ok(plan)
}
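    /// Runs the analyzer and logical optimizer on `plan`. For `EXPLAIN` plans
    /// the intermediate plan produced by each rule is captured, and a rule
    /// failure is recorded in the explain output instead of aborting it.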
pub fn optimize(&self, plan: &LogicalPlan) -> Result<LogicalPlan> {
if let LogicalPlan::Explain(e) = plan {
let mut stringified_plans = e.stringified_plans.clone();
let analyzed_plan = match self.analyzer.execute_and_check(
e.plan.as_ref(),
self.options(),
|analyzed_plan, analyzer| {
let analyzer_name = analyzer.name().to_string();
let plan_type = PlanType::AnalyzedLogicalPlan { analyzer_name };
stringified_plans.push(analyzed_plan.to_stringified(plan_type));
},
) {
Ok(plan) => plan,
Err(DataFusionError::Context(analyzer_name, err)) => {
let plan_type = PlanType::AnalyzedLogicalPlan { analyzer_name };
stringified_plans
.push(StringifiedPlan::new(plan_type, err.to_string()));
return Ok(LogicalPlan::Explain(Explain {
verbose: e.verbose,
plan: e.plan.clone(),
stringified_plans,
schema: e.schema.clone(),
logical_optimization_succeeded: false,
}));
}
Err(e) => return Err(e),
};
stringified_plans
.push(analyzed_plan.to_stringified(PlanType::FinalAnalyzedLogicalPlan));
let (plan, logical_optimization_succeeded) = match self.optimizer.optimize(
&analyzed_plan,
self,
|optimized_plan, optimizer| {
let optimizer_name = optimizer.name().to_string();
let plan_type = PlanType::OptimizedLogicalPlan { optimizer_name };
stringified_plans.push(optimized_plan.to_stringified(plan_type));
},
) {
Ok(plan) => (Arc::new(plan), true),
Err(DataFusionError::Context(optimizer_name, err)) => {
let plan_type = PlanType::OptimizedLogicalPlan { optimizer_name };
stringified_plans
.push(StringifiedPlan::new(plan_type, err.to_string()));
(e.plan.clone(), false)
}
Err(e) => return Err(e),
};
Ok(LogicalPlan::Explain(Explain {
verbose: e.verbose,
plan,
stringified_plans,
schema: e.schema.clone(),
logical_optimization_succeeded,
}))
} else {
let analyzed_plan =
self.analyzer
.execute_and_check(plan, self.options(), |_, _| {})?;
self.optimizer.optimize(&analyzed_plan, self, |_, _| {})
}
}
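    /// Optimizes the logical plan and delegates to the configured
    /// [`QueryPlanner`] to produce an [`ExecutionPlan`].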
pub async fn create_physical_plan(
&self,
logical_plan: &LogicalPlan,
) -> Result<Arc<dyn ExecutionPlan>> {
let logical_plan = self.optimize(logical_plan)?;
self.query_planner
.create_physical_plan(&logical_plan, self)
.await
}
pub fn session_id(&self) -> &str {
&self.session_id
}
pub fn runtime_env(&self) -> &Arc<RuntimeEnv> {
&self.runtime_env
}
pub fn execution_props(&self) -> &ExecutionProps {
&self.execution_props
}
pub fn config(&self) -> &SessionConfig {
&self.config
}
pub fn physical_optimizers(&self) -> &[Arc<dyn PhysicalOptimizerRule + Send + Sync>] {
&self.physical_optimizers.rules
}
pub fn config_options(&self) -> &ConfigOptions {
self.config.options()
}
pub fn task_ctx(&self) -> Arc<TaskContext> {
Arc::new(TaskContext::from(self))
}
pub fn catalog_list(&self) -> Arc<dyn CatalogList> {
self.catalog_list.clone()
}
pub fn scalar_functions(&self) -> &HashMap<String, Arc<ScalarUDF>> {
&self.scalar_functions
}
pub fn aggregate_functions(&self) -> &HashMap<String, Arc<AggregateUDF>> {
&self.aggregate_functions
}
pub fn window_functions(&self) -> &HashMap<String, Arc<WindowUDF>> {
&self.window_functions
}
pub fn serializer_registry(&self) -> Arc<dyn SerializerRegistry> {
self.serializer_registry.clone()
}
pub fn version(&self) -> &str {
env!("CARGO_PKG_VERSION")
}
}
struct SessionContextProvider<'a> {
state: &'a SessionState,
tables: HashMap<String, Arc<dyn TableSource>>,
}
impl<'a> ContextProvider for SessionContextProvider<'a> {
fn get_table_provider(&self, name: TableReference) -> Result<Arc<dyn TableSource>> {
let name = self.state.resolve_table_ref(name).to_string();
self.tables
.get(&name)
.cloned()
.ok_or_else(|| DataFusionError::Plan(format!("table '{name}' not found")))
}
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
self.state.scalar_functions().get(name).cloned()
}
fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
self.state.aggregate_functions().get(name).cloned()
}
fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
self.state.window_functions().get(name).cloned()
}
fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType> {
if variable_names.is_empty() {
return None;
}
let provider_type = if is_system_variables(variable_names) {
VarType::System
} else {
VarType::UserDefined
};
self.state
.execution_props
.var_providers
.as_ref()
.and_then(|provider| provider.get(&provider_type)?.get_type(variable_names))
}
fn options(&self) -> &ConfigOptions {
self.state.config_options()
}
}
impl FunctionRegistry for SessionState {
fn udfs(&self) -> HashSet<String> {
self.scalar_functions.keys().cloned().collect()
}
fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> {
let result = self.scalar_functions.get(name);
result.cloned().ok_or_else(|| {
DataFusionError::Plan(format!(
"There is no UDF named \"{name}\" in the registry"
))
})
}
fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
let result = self.aggregate_functions.get(name);
result.cloned().ok_or_else(|| {
DataFusionError::Plan(format!(
"There is no UDAF named \"{name}\" in the registry"
))
})
}
fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
let result = self.window_functions.get(name);
result.cloned().ok_or_else(|| {
DataFusionError::Plan(format!(
"There is no UDWF named \"{name}\" in the registry"
))
})
}
}
impl OptimizerConfig for SessionState {
fn query_execution_start_time(&self) -> DateTime<Utc> {
self.execution_props.query_execution_start_time
}
fn alias_generator(&self) -> Arc<AliasGenerator> {
self.execution_props.alias_generator.clone()
}
fn options(&self) -> &ConfigOptions {
self.config_options()
}
}
impl From<&SessionContext> for TaskContext {
fn from(session: &SessionContext) -> Self {
TaskContext::from(&*session.state.read())
}
}
impl From<&SessionState> for TaskContext {
fn from(state: &SessionState) -> Self {
let task_id = None;
TaskContext::new(
task_id,
state.session_id.clone(),
state.config.clone(),
state.scalar_functions.clone(),
state.aggregate_functions.clone(),
state.window_functions.clone(),
state.runtime_env.clone(),
)
}
}
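/// A [`SerializerRegistry`] that returns `NotImplemented` for every
/// user-defined logical plan node; used as the default registry.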
pub struct EmptySerializerRegistry;
impl SerializerRegistry for EmptySerializerRegistry {
fn serialize_logical_plan(
&self,
node: &dyn UserDefinedLogicalNode,
) -> Result<Vec<u8>> {
Err(DataFusionError::NotImplemented(format!(
"Serializing user defined logical plan node `{}` is not supported",
node.name()
)))
}
fn deserialize_logical_plan(
&self,
name: &str,
_bytes: &[u8],
) -> Result<Arc<dyn UserDefinedLogicalNode>> {
Err(DataFusionError::NotImplemented(format!(
"Deserializing user defined logical plan node `{name}` is not supported"
)))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::assert_batches_eq;
use crate::execution::context::QueryPlanner;
use crate::execution::memory_pool::MemoryConsumer;
use crate::execution::runtime_env::RuntimeConfig;
use crate::physical_plan::expressions::AvgAccumulator;
use crate::test;
use crate::test_util::parquet_test_data;
use crate::variable::VarType;
use arrow::array::ArrayRef;
use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
use datafusion_expr::{create_udaf, create_udf, Expr, Volatility};
use datafusion_physical_expr::functions::make_scalar_function;
use std::fs::File;
use std::path::PathBuf;
use std::sync::Weak;
use std::{env, io::prelude::*};
use tempfile::TempDir;
#[tokio::test]
async fn shared_memory_and_disk_manager() {
let ctx1 = SessionContext::new();
let memory_pool = ctx1.runtime_env().memory_pool.clone();
let mut reservation = MemoryConsumer::new("test").register(&memory_pool);
reservation.grow(100);
let disk_manager = ctx1.runtime_env().disk_manager.clone();
let ctx2 =
SessionContext::with_config_rt(SessionConfig::new(), ctx1.runtime_env());
assert_eq!(ctx1.runtime_env().memory_pool.reserved(), 100);
assert_eq!(ctx2.runtime_env().memory_pool.reserved(), 100);
drop(reservation);
assert_eq!(ctx1.runtime_env().memory_pool.reserved(), 0);
assert_eq!(ctx2.runtime_env().memory_pool.reserved(), 0);
assert!(std::ptr::eq(
Arc::as_ptr(&disk_manager),
Arc::as_ptr(&ctx1.runtime_env().disk_manager)
));
assert!(std::ptr::eq(
Arc::as_ptr(&disk_manager),
Arc::as_ptr(&ctx2.runtime_env().disk_manager)
));
}
#[tokio::test]
async fn create_variable_expr() -> Result<()> {
let tmp_dir = TempDir::new()?;
let partition_count = 4;
let ctx = create_ctx(&tmp_dir, partition_count).await?;
let variable_provider = test::variable::SystemVar::new();
ctx.register_variable(VarType::System, Arc::new(variable_provider));
let variable_provider = test::variable::UserDefinedVar::new();
ctx.register_variable(VarType::UserDefined, Arc::new(variable_provider));
let provider = test::create_table_dual();
ctx.register_table("dual", provider)?;
let results =
plan_and_collect(&ctx, "SELECT @@version, @name, @integer + 1 FROM dual")
.await?;
let expected = vec![
"+----------------------+------------------------+---------------------+",
"| @@version | @name | @integer + Int64(1) |",
"+----------------------+------------------------+---------------------+",
"| system-var-@@version | user-defined-var-@name | 42 |",
"+----------------------+------------------------+---------------------+",
];
assert_batches_eq!(expected, &results);
Ok(())
}
#[tokio::test]
async fn create_variable_err() -> Result<()> {
let ctx = SessionContext::new();
let err = plan_and_collect(&ctx, "SElECT @= X#=?!~ 5")
.await
.unwrap_err();
assert_eq!(
err.to_string(),
"Error during planning: variable [\"@\"] has no type information"
);
Ok(())
}
#[tokio::test]
async fn register_deregister() -> Result<()> {
let tmp_dir = TempDir::new()?;
let partition_count = 4;
let ctx = create_ctx(&tmp_dir, partition_count).await?;
let provider = test::create_table_dual();
ctx.register_table("dual", provider)?;
assert!(ctx.deregister_table("dual")?.is_some());
assert!(ctx.deregister_table("dual")?.is_none());
Ok(())
}
#[tokio::test]
async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> {
let ctx = SessionContext::new();
ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
.unwrap();
let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0]));
let myfunc = make_scalar_function(myfunc);
ctx.register_udf(create_udf(
"MY_FUNC",
vec![DataType::Int32],
Arc::new(DataType::Int32),
Volatility::Immutable,
myfunc,
));
let err = plan_and_collect(&ctx, "SELECT MY_FUNC(i) FROM t")
.await
.unwrap_err();
assert!(err
.to_string()
.contains("Error during planning: Invalid function \'my_func\'"));
let result = plan_and_collect(&ctx, "SELECT \"MY_FUNC\"(i) FROM t").await?;
let expected = vec![
"+--------------+",
"| MY_FUNC(t.i) |",
"+--------------+",
"| 1 |",
"+--------------+",
];
assert_batches_eq!(expected, &result);
Ok(())
}
#[tokio::test]
async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> {
let ctx = SessionContext::new();
ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
.unwrap();
let my_avg = create_udaf(
"MY_AVG",
DataType::Float64,
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(|_| {
Ok(Box::new(AvgAccumulator::try_new(
&DataType::Float64,
&DataType::Float64,
)?))
}),
Arc::new(vec![DataType::UInt64, DataType::Float64]),
);
ctx.register_udaf(my_avg);
let err = plan_and_collect(&ctx, "SELECT MY_AVG(i) FROM t")
.await
.unwrap_err();
assert!(err
.to_string()
.contains("Error during planning: Invalid function \'my_avg\'"));
let result = plan_and_collect(&ctx, "SELECT \"MY_AVG\"(i) FROM t").await?;
let expected = vec![
"+-------------+",
"| MY_AVG(t.i) |",
"+-------------+",
"| 1.0 |",
"+-------------+",
];
assert_batches_eq!(expected, &result);
Ok(())
}
#[tokio::test]
async fn query_csv_with_custom_partition_extension() -> Result<()> {
let tmp_dir = TempDir::new()?;
let file_extension = ".tst";
let ctx = SessionContext::new();
let schema = populate_csv_partitions(&tmp_dir, 2, file_extension)?;
ctx.register_csv(
"test",
tmp_dir.path().to_str().unwrap(),
CsvReadOptions::new()
.schema(&schema)
.file_extension(file_extension),
)
.await?;
let results =
plan_and_collect(&ctx, "SELECT SUM(c1), SUM(c2), COUNT(*) FROM test").await?;
assert_eq!(results.len(), 1);
let expected = vec![
"+--------------+--------------+-----------------+",
"| SUM(test.c1) | SUM(test.c2) | COUNT(UInt8(1)) |",
"+--------------+--------------+-----------------+",
"| 10 | 110 | 20 |",
"+--------------+--------------+-----------------+",
];
assert_batches_eq!(expected, &results);
Ok(())
}
#[tokio::test]
async fn send_context_to_threads() -> Result<()> {
let tmp_dir = TempDir::new()?;
let partition_count = 4;
let ctx = Arc::new(create_ctx(&tmp_dir, partition_count).await?);
let threads: Vec<_> = (0..2)
.map(|_| ctx.clone())
.map(|ctx| {
tokio::spawn(async move {
ctx.sql("SELECT c1, c2 FROM test WHERE c1 > 0 AND c1 < 3")
.await
})
})
.collect();
for handle in threads {
handle.await.unwrap().unwrap();
}
Ok(())
}
#[tokio::test]
async fn with_listing_schema_provider() -> Result<()> {
let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
let path = path.join("tests/tpch-csv");
let url = format!("file://{}", path.display());
let rt_cfg = RuntimeConfig::new();
let runtime = Arc::new(RuntimeEnv::new(rt_cfg).unwrap());
let cfg = SessionConfig::new()
.set_str("datafusion.catalog.location", url.as_str())
.set_str("datafusion.catalog.format", "CSV")
.set_str("datafusion.catalog.has_header", "true");
let session_state = SessionState::with_config_rt(cfg, runtime);
let ctx = SessionContext::with_state(session_state);
ctx.refresh_catalogs().await?;
let result =
plan_and_collect(&ctx, "select c_name from default.customer limit 3;")
.await?;
let actual = arrow::util::pretty::pretty_format_batches(&result)
.unwrap()
.to_string();
let expected = r#"+--------------------+
| c_name             |
+--------------------+
| Customer#000000002 |
| Customer#000000003 |
| Customer#000000004 |
+--------------------+"#;
assert_eq!(actual, expected);
Ok(())
}
#[tokio::test]
async fn custom_query_planner() -> Result<()> {
let runtime = Arc::new(RuntimeEnv::default());
let session_state = SessionState::with_config_rt(SessionConfig::new(), runtime)
.with_query_planner(Arc::new(MyQueryPlanner {}));
let ctx = SessionContext::with_state(session_state);
let df = ctx.sql("SELECT 1").await?;
df.collect().await.expect_err("query not supported");
Ok(())
}
#[tokio::test]
async fn disabled_default_catalog_and_schema() -> Result<()> {
let ctx = SessionContext::with_config(
SessionConfig::new().with_create_default_catalog_and_schema(false),
);
assert!(matches!(
ctx.register_table("test", test::table_with_sequence(1, 1)?),
Err(DataFusionError::Plan(_))
));
assert!(matches!(
ctx.sql("select * from datafusion.public.test").await,
Err(DataFusionError::Plan(_))
));
Ok(())
}
#[tokio::test]
async fn custom_catalog_and_schema() {
let config = SessionConfig::new()
.with_create_default_catalog_and_schema(true)
.with_default_catalog_and_schema("my_catalog", "my_schema");
catalog_and_schema_test(config).await;
}
#[tokio::test]
async fn custom_catalog_and_schema_no_default() {
let config = SessionConfig::new()
.with_create_default_catalog_and_schema(false)
.with_default_catalog_and_schema("my_catalog", "my_schema");
catalog_and_schema_test(config).await;
}
#[tokio::test]
async fn custom_catalog_and_schema_and_information_schema() {
let config = SessionConfig::new()
.with_create_default_catalog_and_schema(true)
.with_information_schema(true)
.with_default_catalog_and_schema("my_catalog", "my_schema");
catalog_and_schema_test(config).await;
}
async fn catalog_and_schema_test(config: SessionConfig) {
let ctx = SessionContext::with_config(config);
let catalog = MemoryCatalogProvider::new();
let schema = MemorySchemaProvider::new();
schema
.register_table("test".to_owned(), test::table_with_sequence(1, 1).unwrap())
.unwrap();
catalog
.register_schema("my_schema", Arc::new(schema))
.unwrap();
ctx.register_catalog("my_catalog", Arc::new(catalog));
for table_ref in &["my_catalog.my_schema.test", "my_schema.test", "test"] {
let result = plan_and_collect(
&ctx,
&format!("SELECT COUNT(*) AS count FROM {table_ref}"),
)
.await
.unwrap();
let expected = vec![
"+-------+",
"| count |",
"+-------+",
"| 1 |",
"+-------+",
];
assert_batches_eq!(expected, &result);
}
}
#[tokio::test]
async fn cross_catalog_access() -> Result<()> {
let ctx = SessionContext::new();
let catalog_a = MemoryCatalogProvider::new();
let schema_a = MemorySchemaProvider::new();
schema_a
.register_table("table_a".to_owned(), test::table_with_sequence(1, 1)?)?;
catalog_a.register_schema("schema_a", Arc::new(schema_a))?;
ctx.register_catalog("catalog_a", Arc::new(catalog_a));
let catalog_b = MemoryCatalogProvider::new();
let schema_b = MemorySchemaProvider::new();
schema_b
.register_table("table_b".to_owned(), test::table_with_sequence(1, 2)?)?;
catalog_b.register_schema("schema_b", Arc::new(schema_b))?;
ctx.register_catalog("catalog_b", Arc::new(catalog_b));
let result = plan_and_collect(
&ctx,
"SELECT cat, SUM(i) AS total FROM (
SELECT i, 'a' AS cat FROM catalog_a.schema_a.table_a
UNION ALL
SELECT i, 'b' AS cat FROM catalog_b.schema_b.table_b
) AS all
GROUP BY cat
ORDER BY cat
",
)
.await?;
let expected = vec![
"+-----+-------+",
"| cat | total |",
"+-----+-------+",
"| a | 1 |",
"| b | 3 |",
"+-----+-------+",
];
assert_batches_eq!(expected, &result);
Ok(())
}
#[tokio::test]
async fn catalogs_not_leaked() {
let ctx = SessionContext::with_config(
SessionConfig::new().with_information_schema(true),
);
let catalog = Arc::new(MemoryCatalogProvider::new());
let catalog_weak = Arc::downgrade(&catalog);
ctx.register_catalog("my_catalog", catalog);
let catalog_list_weak = {
let state = ctx.state.read();
Arc::downgrade(&state.catalog_list)
};
drop(ctx);
assert_eq!(Weak::strong_count(&catalog_list_weak), 0);
assert_eq!(Weak::strong_count(&catalog_weak), 0);
}
#[tokio::test]
async fn sql_create_schema() -> Result<()> {
let ctx = SessionContext::with_config(
SessionConfig::new().with_information_schema(true),
);
ctx.sql("CREATE SCHEMA abc").await?.collect().await?;
ctx.sql("CREATE TABLE abc.y AS VALUES (1,2,3)")
.await?
.collect()
.await?;
let results = ctx.sql("SELECT * FROM information_schema.tables WHERE table_schema='abc' AND table_name = 'y'").await.unwrap().collect().await.unwrap();
assert_eq!(results[0].num_rows(), 1);
Ok(())
}
#[tokio::test]
async fn sql_create_catalog() -> Result<()> {
let ctx = SessionContext::with_config(
SessionConfig::new().with_information_schema(true),
);
ctx.sql("CREATE DATABASE test").await?.collect().await?;
ctx.sql("CREATE SCHEMA test.abc").await?.collect().await?;
ctx.sql("CREATE TABLE test.abc.y AS VALUES (1,2,3)")
.await?
.collect()
.await?;
let results = ctx.sql("SELECT * FROM information_schema.tables WHERE table_catalog='test' AND table_schema='abc' AND table_name = 'y'").await.unwrap().collect().await.unwrap();
assert_eq!(results[0].num_rows(), 1);
Ok(())
}
#[tokio::test]
async fn read_with_glob_path() -> Result<()> {
let ctx = SessionContext::new();
let df = ctx
.read_parquet(
format!("{}/alltypes_plain*.parquet", parquet_test_data()),
ParquetReadOptions::default(),
)
.await?;
let results = df.collect().await?;
let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum();
assert_eq!(total_rows, 10);
Ok(())
}
#[tokio::test]
async fn read_with_glob_path_issue_2465() -> Result<()> {
let ctx = SessionContext::new();
let df = ctx
.read_parquet(
format!("{}/..//*/alltypes_plain*.parquet", parquet_test_data()),
ParquetReadOptions::default(),
)
.await?;
let results = df.collect().await?;
let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum();
assert_eq!(total_rows, 10);
Ok(())
}
#[tokio::test]
async fn read_from_registered_table_with_glob_path() -> Result<()> {
let ctx = SessionContext::new();
ctx.register_parquet(
"test",
&format!("{}/alltypes_plain*.parquet", parquet_test_data()),
ParquetReadOptions::default(),
)
.await?;
let df = ctx.sql("SELECT * FROM test").await?;
let results = df.collect().await?;
let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum();
assert_eq!(total_rows, 10);
Ok(())
}
#[tokio::test]
async fn unsupported_sql_returns_error() -> Result<()> {
let ctx = SessionContext::new();
ctx.register_table("test", test::table_with_sequence(1, 1).unwrap())
.unwrap();
let state = ctx.state();
let sql = "create view test_view as select * from test";
let plan = state.create_logical_plan(sql).await;
let physical_plan = state.create_physical_plan(&plan.unwrap()).await;
assert!(physical_plan.is_err());
assert_eq!(
format!("{}", physical_plan.unwrap_err()),
"This feature is not implemented: Unsupported logical plan: CreateView"
);
let sql = "drop view test_view";
let plan = state.create_logical_plan(sql).await;
let physical_plan = state.create_physical_plan(&plan.unwrap()).await;
assert!(physical_plan.is_err());
assert_eq!(
format!("{}", physical_plan.unwrap_err()),
"This feature is not implemented: Unsupported logical plan: DropView"
);
let sql = "drop table test";
let plan = state.create_logical_plan(sql).await;
let physical_plan = state.create_physical_plan(&plan.unwrap()).await;
assert!(physical_plan.is_err());
assert_eq!(
format!("{}", physical_plan.unwrap_err()),
"This feature is not implemented: Unsupported logical plan: DropTable"
);
Ok(())
}
struct MyPhysicalPlanner {}
#[async_trait]
impl PhysicalPlanner for MyPhysicalPlanner {
async fn create_physical_plan(
&self,
_logical_plan: &LogicalPlan,
_session_state: &SessionState,
) -> Result<Arc<dyn ExecutionPlan>> {
Err(DataFusionError::NotImplemented(
"query not supported".to_string(),
))
}
fn create_physical_expr(
&self,
_expr: &Expr,
_input_dfschema: &crate::common::DFSchema,
_input_schema: &Schema,
_session_state: &SessionState,
) -> Result<Arc<dyn crate::physical_plan::PhysicalExpr>> {
unimplemented!()
}
}
struct MyQueryPlanner {}
#[async_trait]
impl QueryPlanner for MyQueryPlanner {
async fn create_physical_plan(
&self,
logical_plan: &LogicalPlan,
session_state: &SessionState,
) -> Result<Arc<dyn ExecutionPlan>> {
let physical_planner = MyPhysicalPlanner {};
physical_planner
.create_physical_plan(logical_plan, session_state)
.await
}
}
async fn plan_and_collect(
ctx: &SessionContext,
sql: &str,
) -> Result<Vec<RecordBatch>> {
ctx.sql(sql).await?.collect().await
}
fn populate_csv_partitions(
tmp_dir: &TempDir,
partition_count: usize,
file_extension: &str,
) -> Result<SchemaRef> {
let schema = Arc::new(Schema::new(vec![
Field::new("c1", DataType::UInt32, false),
Field::new("c2", DataType::UInt64, false),
Field::new("c3", DataType::Boolean, false),
]));
for partition in 0..partition_count {
let filename = format!("partition-{partition}.{file_extension}");
let file_path = tmp_dir.path().join(filename);
let mut file = File::create(file_path)?;
for i in 0..=10 {
let data = format!("{},{},{}\n", partition, i, i % 2 == 0);
file.write_all(data.as_bytes())?;
}
}
Ok(schema)
}
async fn create_ctx(
tmp_dir: &TempDir,
partition_count: usize,
) -> Result<SessionContext> {
let ctx =
SessionContext::with_config(SessionConfig::new().with_target_partitions(8));
let schema = populate_csv_partitions(tmp_dir, partition_count, ".csv")?;
ctx.register_csv(
"test",
tmp_dir.path().to_str().unwrap(),
CsvReadOptions::new().schema(&schema),
)
.await?;
Ok(ctx)
}
#[async_trait]
trait CallReadTrait {
async fn call_read_csv(&self) -> DataFrame;
async fn call_read_avro(&self) -> DataFrame;
async fn call_read_parquet(&self) -> DataFrame;
}
struct CallRead {}
#[async_trait]
impl CallReadTrait for CallRead {
async fn call_read_csv(&self) -> DataFrame {
let ctx = SessionContext::new();
ctx.read_csv("dummy", CsvReadOptions::new()).await.unwrap()
}
async fn call_read_avro(&self) -> DataFrame {
let ctx = SessionContext::new();
ctx.read_avro("dummy", AvroReadOptions::default())
.await
.unwrap()
}
async fn call_read_parquet(&self) -> DataFrame {
let ctx = SessionContext::new();
ctx.read_parquet("dummy", ParquetReadOptions::default())
.await
.unwrap()
}
}
}