use std::collections::{hash_map::Entry, HashMap, HashSet};
use std::fmt::Debug;
use std::ops::ControlFlow;
use std::sync::{Arc, Weak};
use super::options::ReadOptions;
use crate::{
catalog::information_schema::{InformationSchemaProvider, INFORMATION_SCHEMA},
catalog::listing_schema::ListingSchemaProvider,
catalog::schema::{MemorySchemaProvider, SchemaProvider},
catalog::{
CatalogProvider, CatalogProviderList, MemoryCatalogProvider,
MemoryCatalogProviderList,
},
config::ConfigOptions,
dataframe::DataFrame,
datasource::{
cte_worktable::CteWorkTable,
function::{TableFunction, TableFunctionImpl},
listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl},
object_store::ObjectStoreUrl,
provider::{DefaultTableFactory, TableProviderFactory},
},
datasource::{provider_as_source, MemTable, TableProvider, ViewTable},
error::{DataFusionError, Result},
execution::{options::ArrowReadOptions, runtime_env::RuntimeEnv, FunctionRegistry},
logical_expr::AggregateUDF,
logical_expr::ScalarUDF,
logical_expr::{
CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateFunction,
CreateMemoryTable, CreateView, DropCatalogSchema, DropFunction, DropTable,
DropView, Explain, LogicalPlan, LogicalPlanBuilder, PlanType, SetVariable,
TableSource, TableType, ToStringifiedPlan, UNNAMED_TABLE,
},
optimizer::analyzer::{Analyzer, AnalyzerRule},
optimizer::optimizer::{Optimizer, OptimizerConfig, OptimizerRule},
physical_optimizer::optimizer::{PhysicalOptimizer, PhysicalOptimizerRule},
physical_plan::ExecutionPlan,
physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner},
variable::{VarProvider, VarType},
};
#[cfg(feature = "array_expressions")]
use crate::functions_array;
use crate::{functions, functions_aggregate};
use arrow::datatypes::{DataType, SchemaRef};
use arrow::record_batch::RecordBatch;
use arrow_schema::Schema;
use datafusion_common::{
alias::AliasGenerator,
config::{ConfigExtension, TableOptions},
exec_err, not_impl_err, plan_datafusion_err, plan_err,
tree_node::{TreeNodeRecursion, TreeNodeVisitor},
DFSchema, SchemaReference, TableReference,
};
use datafusion_execution::registry::SerializerRegistry;
use datafusion_expr::{
logical_plan::{DdlStatement, Statement},
var_provider::is_system_variables,
Expr, ExprSchemable, StringifiedPlan, UserDefinedLogicalNode, WindowUDF,
};
use datafusion_sql::{
parser::{CopyToSource, CopyToStatement, DFParser},
planner::{object_name_to_table_reference, ContextProvider, ParserOptions, SqlToRel},
ResolvedTableReference,
};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use datafusion_common::tree_node::TreeNode;
use parking_lot::RwLock;
use sqlparser::dialect::dialect_from_str;
use url::Url;
use uuid::Uuid;
use crate::physical_expr::PhysicalExpr;
pub use datafusion_execution::config::SessionConfig;
pub use datafusion_execution::TaskContext;
pub use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::expr_rewriter::FunctionRewrite;
use datafusion_expr::simplify::SimplifyInfo;
use datafusion_optimizer::simplify_expressions::ExprSimplifier;
use datafusion_physical_expr::create_physical_expr;
mod avro;
mod csv;
mod json;
#[cfg(feature = "parquet")]
mod parquet;
/// Types that can be converted into one or more [`ListingTableUrl`]s,
/// i.e. anything accepted as the "paths" argument of the `read_*` methods.
pub trait DataFilePaths {
/// Converts `self` into a vector of [`ListingTableUrl`]s.
fn to_urls(self) -> Result<Vec<ListingTableUrl>>;
}
impl DataFilePaths for &str {
fn to_urls(self) -> Result<Vec<ListingTableUrl>> {
Ok(vec![ListingTableUrl::parse(self)?])
}
}
impl DataFilePaths for String {
fn to_urls(self) -> Result<Vec<ListingTableUrl>> {
Ok(vec![ListingTableUrl::parse(self)?])
}
}
impl DataFilePaths for &String {
fn to_urls(self) -> Result<Vec<ListingTableUrl>> {
Ok(vec![ListingTableUrl::parse(self)?])
}
}
// Any vector of string-like values: each element is parsed independently and
// the first parse error aborts the whole conversion.
impl<P> DataFilePaths for Vec<P>
where
P: AsRef<str>,
{
fn to_urls(self) -> Result<Vec<ListingTableUrl>> {
self.iter()
.map(ListingTableUrl::parse)
.collect::<Result<Vec<ListingTableUrl>>>()
}
}
/// Main entry point for executing queries in DataFusion. Cloning is cheap:
/// clones share the same underlying [`SessionState`] via `Arc`.
#[derive(Clone)]
pub struct SessionContext {
// Unique id for this session (copied from the state's id at creation)
session_id: String,
// Wall-clock time when this context was created
session_start_time: DateTime<Utc>,
// Shared, lock-protected session state (config, catalogs, functions, ...)
state: Arc<RwLock<SessionState>>,
}
impl Default for SessionContext {
// Equivalent to `SessionContext::new()`: default config and runtime.
fn default() -> Self {
Self::new()
}
}
impl SessionContext {
/// Creates a new `SessionContext` with a default [`SessionConfig`].
pub fn new() -> Self {
Self::new_with_config(SessionConfig::new())
}
/// Re-scans every registered [`ListingSchemaProvider`] so that files added
/// to (or removed from) their backing locations become visible as tables.
/// Schemas of other provider types are left untouched.
///
/// # Errors
/// Returns an internal error if a catalog or schema disappears while
/// iterating, or propagates any error raised by the provider refresh.
pub async fn refresh_catalogs(&self) -> Result<()> {
    // `catalog_names()` already returns an owned Vec; the previous `.clone()`
    // copied it a second time for no benefit.
    let cat_names = self.catalog_names();
    for cat_name in cat_names.iter() {
        let cat = self.catalog(cat_name.as_str()).ok_or_else(|| {
            DataFusionError::Internal("Catalog not found!".to_string())
        })?;
        for schema_name in cat.schema_names() {
            let schema = cat.schema(schema_name.as_str()).ok_or_else(|| {
                DataFusionError::Internal("Schema not found!".to_string())
            })?;
            // Only `ListingSchemaProvider` knows how to refresh itself.
            let lister = schema.as_any().downcast_ref::<ListingSchemaProvider>();
            if let Some(lister) = lister {
                lister.refresh(&self.state()).await?;
            }
        }
    }
    Ok(())
}
/// Creates a new `SessionContext` with the provided [`SessionConfig`] and a
/// default [`RuntimeEnv`].
pub fn new_with_config(config: SessionConfig) -> Self {
    let runtime = Arc::new(RuntimeEnv::default());
    Self::new_with_config_rt(config, runtime)
}
/// Deprecated alias of [`SessionContext::new_with_config`].
#[deprecated(since = "32.0.0", note = "Use SessionContext::new_with_config")]
pub fn with_config(config: SessionConfig) -> Self {
    Self::new_with_config(config)
}
/// Creates a new `SessionContext` with the provided [`SessionConfig`] and
/// [`RuntimeEnv`].
pub fn new_with_config_rt(config: SessionConfig, runtime: Arc<RuntimeEnv>) -> Self {
    let state = SessionState::new_with_config_rt(config, runtime);
    Self::new_with_state(state)
}
/// Deprecated alias of [`SessionContext::new_with_config_rt`].
// Fix: the note previously directed callers to `SessionState::new_with_config_rt`,
// but the drop-in replacement for this *context* constructor is
// `SessionContext::new_with_config_rt` (matching the sibling `with_config` note).
#[deprecated(since = "32.0.0", note = "Use SessionContext::new_with_config_rt")]
pub fn with_config_rt(config: SessionConfig, runtime: Arc<RuntimeEnv>) -> Self {
    Self::new_with_config_rt(config, runtime)
}
/// Creates a new `SessionContext` from an existing [`SessionState`],
/// adopting that state's session id.
pub fn new_with_state(state: SessionState) -> Self {
    Self {
        session_id: state.session_id.clone(),
        session_start_time: Utc::now(),
        state: Arc::new(RwLock::new(state)),
    }
}
/// Deprecated alias of [`SessionContext::new_with_state`].
// Fix: the note previously pointed at `SessionState::new_with_state`, a
// method that does not exist; the replacement lives on `SessionContext`.
#[deprecated(since = "32.0.0", note = "Use SessionContext::new_with_state")]
pub fn with_state(state: SessionState) -> Self {
    Self::new_with_state(state)
}
/// Returns the wall-clock time at which this session was created.
pub fn session_start_time(&self) -> DateTime<Utc> {
    self.session_start_time
}
/// Installs a [`FunctionFactory`] that will service `CREATE FUNCTION`
/// statements, returning `self` for call chaining.
pub fn with_function_factory(
    self,
    function_factory: Arc<dyn FunctionFactory>,
) -> Self {
    self.state.write().set_function_factory(function_factory);
    self
}
/// Registers a single [`RecordBatch`] as the table `table_name`, returning
/// any provider previously registered under that (bare) name.
pub fn register_batch(
&self,
table_name: &str,
batch: RecordBatch,
) -> Result<Option<Arc<dyn TableProvider>>> {
let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
self.register_table(
TableReference::Bare {
table: table_name.into(),
},
Arc::new(table),
)
}
/// Returns the [`RuntimeEnv`] shared by this session.
pub fn runtime_env(&self) -> Arc<RuntimeEnv> {
self.state.read().runtime_env.clone()
}
/// Returns this session's unique id.
pub fn session_id(&self) -> String {
self.session_id.clone()
}
/// Returns the [`TableProviderFactory`] registered for `file_type`, if any.
pub fn table_factory(
&self,
file_type: &str,
) -> Option<Arc<dyn TableProviderFactory>> {
self.state.read().table_factories().get(file_type).cloned()
}
/// Returns whether the SQL parser normalizes (lower-cases) identifiers.
pub fn enable_ident_normalization(&self) -> bool {
self.state
.read()
.config
.options()
.sql_parser
.enable_ident_normalization
}
/// Returns a snapshot copy of this session's [`SessionConfig`].
pub fn copied_config(&self) -> SessionConfig {
self.state.read().config.clone()
}
/// Returns a snapshot copy of this session's default [`TableOptions`].
pub fn copied_table_options(&self) -> TableOptions {
self.state.read().default_table_options()
}
/// Plans and executes `sql` with the default (permissive) [`SQLOptions`].
/// DDL/DML statements take effect immediately; queries return a lazy
/// [`DataFrame`].
pub async fn sql(&self, sql: &str) -> Result<DataFrame> {
self.sql_with_options(sql, SQLOptions::new()).await
}
/// Like [`Self::sql`], but first verifies the produced plan against
/// `options` (e.g. to reject DDL in untrusted input).
pub async fn sql_with_options(
&self,
sql: &str,
options: SQLOptions,
) -> Result<DataFrame> {
let plan = self.state().create_logical_plan(sql).await?;
options.verify_plan(&plan)?;
self.execute_logical_plan(plan).await
}
/// Executes a [`LogicalPlan`]: DDL and `SET` statements mutate the session
/// eagerly and return an empty `DataFrame`; every other plan is wrapped in a
/// lazily-evaluated [`DataFrame`] without running it.
pub async fn execute_logical_plan(&self, plan: LogicalPlan) -> Result<DataFrame> {
match plan {
LogicalPlan::Ddl(ddl) => {
// Each handler is boxed so all match arms produce the same future
// type; the explicit cast on the first arm pins that common type.
match ddl {
DdlStatement::CreateExternalTable(cmd) => {
Box::pin(async move { self.create_external_table(&cmd).await })
as std::pin::Pin<Box<dyn futures::Future<Output = _> + Send>>
}
DdlStatement::CreateMemoryTable(cmd) => {
Box::pin(self.create_memory_table(cmd))
}
DdlStatement::CreateView(cmd) => Box::pin(self.create_view(cmd)),
DdlStatement::CreateCatalogSchema(cmd) => {
Box::pin(self.create_catalog_schema(cmd))
}
DdlStatement::CreateCatalog(cmd) => {
Box::pin(self.create_catalog(cmd))
}
DdlStatement::DropTable(cmd) => Box::pin(self.drop_table(cmd)),
DdlStatement::DropView(cmd) => Box::pin(self.drop_view(cmd)),
DdlStatement::DropCatalogSchema(cmd) => {
Box::pin(self.drop_schema(cmd))
}
DdlStatement::CreateFunction(cmd) => {
Box::pin(self.create_function(cmd))
}
DdlStatement::DropFunction(cmd) => Box::pin(self.drop_function(cmd)),
}
.await
}
LogicalPlan::Statement(Statement::SetVariable(stmt)) => {
self.set_variable(stmt).await
}
// Non-DDL plans are not executed here; evaluation is deferred.
plan => Ok(DataFrame::new(self.state(), plan)),
}
}
/// Converts a logical [`Expr`] against `df_schema` into a [`PhysicalExpr`]
/// using this session's current state.
pub fn create_physical_expr(
&self,
expr: Expr,
df_schema: &DFSchema,
) -> Result<Arc<dyn PhysicalExpr>> {
self.state.read().create_physical_expr(expr, df_schema)
}
// Builds the zero-row DataFrame returned by successful DDL statements.
fn return_empty_dataframe(&self) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::empty(false).build()?;
Ok(DataFrame::new(self.state(), plan))
}
/// Handles `CREATE EXTERNAL TABLE`: honours `IF NOT EXISTS`, otherwise
/// builds the provider via the registered table factory and registers it.
async fn create_external_table(
&self,
cmd: &CreateExternalTable,
) -> Result<DataFrame> {
let exist = self.table_exist(cmd.name.clone())?;
if exist {
match cmd.if_not_exists {
// IF NOT EXISTS: silently succeed without replacing the table.
true => return self.return_empty_dataframe(),
false => {
return exec_err!("Table '{}' already exists", cmd.name);
}
}
}
let table_provider: Arc<dyn TableProvider> =
self.create_custom_table(cmd).await?;
self.register_table(cmd.name.clone(), table_provider)?;
self.return_empty_dataframe()
}
/// Handles `CREATE TABLE ... AS`: materializes the input plan into a
/// [`MemTable`], honouring `IF NOT EXISTS` / `OR REPLACE` semantics.
async fn create_memory_table(&self, cmd: CreateMemoryTable) -> Result<DataFrame> {
let CreateMemoryTable {
name,
input,
if_not_exists,
or_replace,
constraints,
column_defaults,
} = cmd;
// Take ownership of the input plan, cloning only if it is shared.
let input = Arc::try_unwrap(input).unwrap_or_else(|e| e.as_ref().clone());
let input = self.state().optimize(&input)?;
let table = self.table(name.clone()).await;
match (if_not_exists, or_replace, table) {
// Table exists and IF NOT EXISTS: no-op.
(true, false, Ok(_)) => self.return_empty_dataframe(),
// Table exists and OR REPLACE: drop it and re-create from the input.
(false, true, Ok(_)) => {
self.deregister_table(name.clone())?;
let schema = Arc::new(input.schema().as_ref().into());
let physical = DataFrame::new(self.state(), input);
let batches: Vec<_> = physical.collect_partitioned().await?;
let table = Arc::new(
MemTable::try_new(schema, batches)?
.with_constraints(constraints)
.with_column_defaults(column_defaults.into_iter().collect()),
);
self.register_table(name.clone(), table)?;
self.return_empty_dataframe()
}
(true, true, Ok(_)) => {
exec_err!("'IF NOT EXISTS' cannot coexist with 'REPLACE'")
}
// Table does not exist yet: materialize and register it.
(_, _, Err(_)) => {
let df_schema = input.schema();
let schema = Arc::new(df_schema.as_ref().into());
let physical = DataFrame::new(self.state(), input);
let batches: Vec<_> = physical.collect_partitioned().await?;
let table = Arc::new(
MemTable::try_new(schema, batches)?
.with_constraints(constraints)
.with_column_defaults(column_defaults.into_iter().collect()),
);
self.register_table(name, table)?;
self.return_empty_dataframe()
}
(false, false, Ok(_)) => exec_err!("Table '{name}' already exists"),
}
}
/// Handles `CREATE VIEW`: registers a [`ViewTable`] over the input plan,
/// replacing an existing view only when `OR REPLACE` was specified.
async fn create_view(&self, cmd: CreateView) -> Result<DataFrame> {
let CreateView {
name,
input,
or_replace,
definition,
} = cmd;
let view = self.table(name.clone()).await;
match (or_replace, view) {
// OR REPLACE and the view exists: drop then re-register.
(true, Ok(_)) => {
self.deregister_table(name.clone())?;
let table = Arc::new(ViewTable::try_new((*input).clone(), definition)?);
self.register_table(name, table)?;
self.return_empty_dataframe()
}
// View does not exist: register it regardless of OR REPLACE.
(_, Err(_)) => {
let table = Arc::new(ViewTable::try_new((*input).clone(), definition)?);
self.register_table(name, table)?;
self.return_empty_dataframe()
}
(false, Ok(_)) => exec_err!("Table '{name}' already exists"),
}
}
/// Handles `CREATE SCHEMA`: resolves an optional `catalog.` prefix (falling
/// back to the session's default catalog) and registers an in-memory schema.
async fn create_catalog_schema(&self, cmd: CreateCatalogSchema) -> Result<DataFrame> {
let CreateCatalogSchema {
schema_name,
if_not_exists,
..
} = cmd;
// `schema_name` may be bare (`s`) or qualified (`cat.s`).
let tokens: Vec<&str> = schema_name.split('.').collect();
let (catalog, schema_name) = match tokens.len() {
1 => {
let state = self.state.read();
let name = &state.config.options().catalog.default_catalog;
let catalog = state.catalog_list.catalog(name).ok_or_else(|| {
DataFusionError::Execution(format!(
"Missing default catalog '{name}'"
))
})?;
(catalog, tokens[0])
}
2 => {
let name = &tokens[0];
let catalog = self.catalog(name).ok_or_else(|| {
DataFusionError::Execution(format!("Missing catalog '{name}'"))
})?;
(catalog, tokens[1])
}
_ => return exec_err!("Unable to parse catalog from {schema_name}"),
};
let schema = catalog.schema(schema_name);
match (if_not_exists, schema) {
// IF NOT EXISTS and already present: no-op.
(true, Some(_)) => self.return_empty_dataframe(),
(true, None) | (false, None) => {
let schema = Arc::new(MemorySchemaProvider::new());
catalog.register_schema(schema_name, schema)?;
self.return_empty_dataframe()
}
(false, Some(_)) => exec_err!("Schema '{schema_name}' already exists"),
}
}
/// Handles `CREATE DATABASE`: registers a new in-memory catalog, honouring
/// `IF NOT EXISTS`.
async fn create_catalog(&self, cmd: CreateCatalog) -> Result<DataFrame> {
let CreateCatalog {
catalog_name,
if_not_exists,
..
} = cmd;
let catalog = self.catalog(catalog_name.as_str());
match (if_not_exists, catalog) {
(true, Some(_)) => self.return_empty_dataframe(),
(true, None) | (false, None) => {
let new_catalog = Arc::new(MemoryCatalogProvider::new());
self.state
.write()
.catalog_list
.register_catalog(catalog_name, new_catalog);
self.return_empty_dataframe()
}
(false, Some(_)) => exec_err!("Catalog '{catalog_name}' already exists"),
}
}
/// Handles `DROP TABLE`: deregisters only providers of [`TableType::Base`];
/// `IF EXISTS` suppresses the missing-table error.
async fn drop_table(&self, cmd: DropTable) -> Result<DataFrame> {
let DropTable {
name, if_exists, ..
} = cmd;
let result = self
.find_and_deregister(name.clone(), TableType::Base)
.await;
match (result, if_exists) {
(Ok(true), _) => self.return_empty_dataframe(),
(_, true) => self.return_empty_dataframe(),
(_, _) => exec_err!("Table '{name}' doesn't exist."),
}
}
/// Handles `DROP VIEW`: same shape as `drop_table` but matches only
/// [`TableType::View`] providers.
async fn drop_view(&self, cmd: DropView) -> Result<DataFrame> {
let DropView {
name, if_exists, ..
} = cmd;
let result = self
.find_and_deregister(name.clone(), TableType::View)
.await;
match (result, if_exists) {
(Ok(true), _) => self.return_empty_dataframe(),
(_, true) => self.return_empty_dataframe(),
(_, _) => exec_err!("View '{name}' doesn't exist."),
}
}
/// Handles `DROP SCHEMA`: resolves the owning catalog (default catalog for
/// bare references), then deregisters the schema, cascading if requested.
async fn drop_schema(&self, cmd: DropCatalogSchema) -> Result<DataFrame> {
let DropCatalogSchema {
name,
if_exists: allow_missing,
cascade,
schema: _,
} = cmd;
// Scope the read lock: it must be released before the early returns below
// call methods that re-acquire the state lock.
let catalog = {
let state = self.state.read();
let catalog_name = match &name {
SchemaReference::Full { catalog, .. } => catalog.to_string(),
SchemaReference::Bare { .. } => {
state.config_options().catalog.default_catalog.to_string()
}
};
if let Some(catalog) = state.catalog_list.catalog(&catalog_name) {
catalog
} else if allow_missing {
return self.return_empty_dataframe();
} else {
return self.schema_doesnt_exist_err(name);
}
};
let dereg = catalog.deregister_schema(name.schema_name(), cascade)?;
match (dereg, allow_missing) {
(None, true) => self.return_empty_dataframe(),
(None, false) => self.schema_doesnt_exist_err(name),
(Some(_), _) => self.return_empty_dataframe(),
}
}
// Shared "schema not found" error for the DROP SCHEMA paths above.
fn schema_doesnt_exist_err(&self, schemaref: SchemaReference) -> Result<DataFrame> {
exec_err!("Schema '{schemaref}' doesn't exist.")
}
/// Handles `SET <variable> = <value>` by writing into the mutable session
/// config options.
async fn set_variable(&self, stmt: SetVariable) -> Result<DataFrame> {
let SetVariable {
variable, value, ..
} = stmt;
let mut state = self.state.write();
state.config.options_mut().set(&variable, &value)?;
// Release the write lock before building the result DataFrame, which
// reads the state again.
drop(state);
self.return_empty_dataframe()
}
/// Builds a [`TableProvider`] for `CREATE EXTERNAL TABLE` by dispatching on
/// the (upper-cased) file type to a registered [`TableProviderFactory`].
async fn create_custom_table(
&self,
cmd: &CreateExternalTable,
) -> Result<Arc<dyn TableProvider>> {
// Clone the state so the factory call below does not hold the lock
// across an await point.
let state = self.state.read().clone();
let file_type = cmd.file_type.to_uppercase();
let factory =
&state
.table_factories
.get(file_type.as_str())
.ok_or_else(|| {
DataFusionError::Execution(format!(
"Unable to find factory for {}",
cmd.file_type
))
})?;
let table = (*factory).create(&state, cmd).await?;
Ok(table)
}
/// Deregisters `table_ref` only if it exists AND its provider matches
/// `table_type` (so DROP TABLE cannot remove a view and vice versa).
/// Returns `Ok(true)` when something was removed.
// NOTE(review): the `'a` lifetime parameter is unused and could be removed.
async fn find_and_deregister<'a>(
&self,
table_ref: impl Into<TableReference>,
table_type: TableType,
) -> Result<bool> {
let table_ref = table_ref.into();
let table = table_ref.table().to_owned();
// Resolve the schema under the read lock, then drop the lock before the
// async `schema.table()` lookup.
let maybe_schema = {
let state = self.state.read();
let resolved = state.resolve_table_ref(table_ref);
state
.catalog_list
.catalog(&resolved.catalog)
.and_then(|c| c.schema(&resolved.schema))
};
if let Some(schema) = maybe_schema {
if let Some(table_provider) = schema.table(&table).await? {
if table_provider.table_type() == table_type {
schema.deregister_table(&table)?;
return Ok(true);
}
}
}
Ok(false)
}
/// Handles `CREATE FUNCTION` by delegating to the configured
/// [`FunctionFactory`], then registering whatever kind of function it
/// produced. Errors if no factory was configured.
async fn create_function(&self, stmt: CreateFunction) -> Result<DataFrame> {
let function = {
// Clone the state so the factory call does not hold the lock across
// an await point.
let state = self.state.read().clone();
let function_factory = &state.function_factory;
match function_factory {
Some(f) => f.create(&state, stmt).await?,
_ => Err(DataFusionError::Configuration(
"Function factory has not been configured".into(),
))?,
}
};
match function {
RegisterFunction::Scalar(f) => {
self.state.write().register_udf(f)?;
}
RegisterFunction::Aggregate(f) => {
self.state.write().register_udaf(f)?;
}
RegisterFunction::Window(f) => {
self.state.write().register_udwf(f)?;
}
RegisterFunction::Table(name, f) => self.register_udtf(&name, f),
};
self.return_empty_dataframe()
}
/// Handles `DROP FUNCTION`: tries all three function registries (scalar,
/// aggregate, window); errors only when nothing was dropped and `IF EXISTS`
/// was not given.
async fn drop_function(&self, stmt: DropFunction) -> Result<DataFrame> {
let mut dropped = false;
// `|=` is non-short-circuiting, so every registry is consulted.
dropped |= self.state.write().deregister_udf(&stmt.name)?.is_some();
dropped |= self.state.write().deregister_udaf(&stmt.name)?.is_some();
dropped |= self.state.write().deregister_udwf(&stmt.name)?.is_some();
if !stmt.if_exists && !dropped {
exec_err!("Function does not exist")
} else {
self.return_empty_dataframe()
}
}
/// Registers a [`VarProvider`] so variables of the given [`VarType`] can be
/// referenced from SQL.
pub fn register_variable(
    &self,
    variable_type: VarType,
    provider: Arc<dyn VarProvider + Send + Sync>,
) {
    let mut state = self.state.write();
    state
        .execution_props
        .add_var_provider(variable_type, provider);
}
/// Registers a user-defined table function under `name`.
pub fn register_udtf(&self, name: &str, fun: Arc<dyn TableFunctionImpl>) {
    let udtf = Arc::new(TableFunction::new(name.to_owned(), fun));
    self.state
        .write()
        .table_functions
        .insert(name.to_owned(), udtf);
}
/// Registers a scalar UDF, silently replacing any previous function with
/// the same name.
pub fn register_udf(&self, f: ScalarUDF) {
    self.state.write().register_udf(Arc::new(f)).ok();
}
/// Registers an aggregate UDF, silently replacing any previous function
/// with the same name.
pub fn register_udaf(&self, f: AggregateUDF) {
    self.state.write().register_udaf(Arc::new(f)).ok();
}
/// Registers a window UDF, silently replacing any previous function with
/// the same name.
pub fn register_udwf(&self, f: WindowUDF) {
    self.state.write().register_udwf(Arc::new(f)).ok();
}
/// Removes the scalar UDF `name`, if registered; missing names are ignored.
pub fn deregister_udf(&self, name: &str) {
    self.state.write().deregister_udf(name).ok();
}
/// Removes the aggregate UDF `name`, if registered; missing names are ignored.
pub fn deregister_udaf(&self, name: &str) {
    self.state.write().deregister_udaf(name).ok();
}
/// Removes the window UDF `name`, if registered; missing names are ignored.
pub fn deregister_udwf(&self, name: &str) {
    self.state.write().deregister_udwf(name).ok();
}
/// Shared implementation behind the `read_*` methods: validates the paths,
/// resolves the schema from the first path, and wraps everything in a
/// [`ListingTable`]-backed [`DataFrame`].
///
/// # Errors
/// Fails when no paths are given, when a non-collection path does not carry
/// the expected file extension, or when schema resolution fails.
async fn _read_type<'a, P: DataFilePaths>(
    &self,
    table_paths: P,
    options: impl ReadOptions<'a>,
) -> Result<DataFrame> {
    let table_paths = table_paths.to_urls()?;
    let session_config = self.copied_config();
    let listing_options =
        options.to_listing_options(&session_config, self.copied_table_options());
    let option_extension = listing_options.file_extension.clone();
    if table_paths.is_empty() {
        return exec_err!("No table paths were provided");
    }
    // Check that the file extension matches for every concrete file path;
    // collections (e.g. directories) are exempt. `ends_with` only borrows,
    // so the previous per-iteration `option_extension.clone()` was wasted work.
    for path in &table_paths {
        let file_path = path.as_str();
        if !file_path.ends_with(option_extension.as_str()) && !path.is_collection() {
            return exec_err!(
                "File path '{file_path}' does not match the expected extension '{option_extension}'"
            );
        }
    }
    // Schema is inferred from (or provided for) the first path only.
    let resolved_schema = options
        .get_resolved_schema(&session_config, self.state(), table_paths[0].clone())
        .await?;
    let config = ListingTableConfig::new_with_multi_paths(table_paths)
        .with_listing_options(listing_options)
        .with_schema(resolved_schema);
    let provider = ListingTable::try_new(config)?;
    self.read_table(Arc::new(provider))
}
/// Creates a [`DataFrame`] for reading Arrow IPC files from `table_paths`.
pub async fn read_arrow<P: DataFilePaths>(
&self,
table_paths: P,
options: ArrowReadOptions<'_>,
) -> Result<DataFrame> {
self._read_type(table_paths, options).await
}
/// Creates a [`DataFrame`] over an empty relation with a single placeholder
/// row (`EmptyRelation` with `produce_one_row = true`).
pub fn read_empty(&self) -> Result<DataFrame> {
Ok(DataFrame::new(
self.state(),
LogicalPlanBuilder::empty(true).build()?,
))
}
/// Creates a [`DataFrame`] scanning the given [`TableProvider`] without
/// registering it in any catalog.
pub fn read_table(&self, provider: Arc<dyn TableProvider>) -> Result<DataFrame> {
Ok(DataFrame::new(
self.state(),
LogicalPlanBuilder::scan(UNNAMED_TABLE, provider_as_source(provider), None)?
.build()?,
))
}
/// Creates a [`DataFrame`] over a single in-memory [`RecordBatch`].
pub fn read_batch(&self, batch: RecordBatch) -> Result<DataFrame> {
let provider = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
Ok(DataFrame::new(
self.state(),
LogicalPlanBuilder::scan(
UNNAMED_TABLE,
provider_as_source(Arc::new(provider)),
None,
)?
.build()?,
))
}
/// Creates a [`DataFrame`] over a sequence of in-memory [`RecordBatch`]es,
/// all stored as a single partition. An empty iterator yields a DataFrame
/// with an empty schema.
pub fn read_batches(
    &self,
    batches: impl IntoIterator<Item = RecordBatch>,
) -> Result<DataFrame> {
    let mut batches = batches.into_iter().peekable();
    // Derive the schema from the first batch, if any. `RecordBatch::schema`
    // already returns an owned `SchemaRef`, so the former `.clone()` on it
    // was a redundant second Arc bump.
    let schema = if let Some(batch) = batches.peek() {
        batch.schema()
    } else {
        Arc::new(Schema::empty())
    };
    let provider = MemTable::try_new(schema, vec![batches.collect()])?;
    Ok(DataFrame::new(
        self.state(),
        LogicalPlanBuilder::scan(
            UNNAMED_TABLE,
            provider_as_source(Arc::new(provider)),
            None,
        )?
        .build()?,
    ))
}
/// Registers a [`ListingTable`] named `name` over `table_path`. When
/// `provided_schema` is `None` the schema is inferred from the files.
/// `sql_definition` optionally records the originating SQL text.
pub async fn register_listing_table(
&self,
name: &str,
table_path: impl AsRef<str>,
options: ListingOptions,
provided_schema: Option<SchemaRef>,
sql_definition: Option<String>,
) -> Result<()> {
let table_path = ListingTableUrl::parse(table_path)?;
let resolved_schema = match provided_schema {
Some(s) => s,
None => options.infer_schema(&self.state(), &table_path).await?,
};
let config = ListingTableConfig::new(table_path)
.with_listing_options(options)
.with_schema(resolved_schema);
let table = ListingTable::try_new(config)?.with_definition(sql_definition);
self.register_table(
TableReference::Bare { table: name.into() },
Arc::new(table),
)?;
Ok(())
}
/// Registers the Arrow IPC file(s) at `table_path` as a table named `name`.
pub async fn register_arrow(
&self,
name: &str,
table_path: &str,
options: ArrowReadOptions<'_>,
) -> Result<()> {
let listing_options = options
.to_listing_options(&self.copied_config(), self.copied_table_options());
self.register_listing_table(
name,
table_path,
listing_options,
options.schema.map(|s| Arc::new(s.to_owned())),
None,
)
.await?;
Ok(())
}
/// Registers a [`CatalogProvider`] under `name`, returning any catalog
/// previously registered under that name.
pub fn register_catalog(
&self,
name: impl Into<String>,
catalog: Arc<dyn CatalogProvider>,
) -> Option<Arc<dyn CatalogProvider>> {
let name = name.into();
// A read lock suffices: the catalog list itself has interior mutability.
self.state
.read()
.catalog_list
.register_catalog(name, catalog)
}
/// Returns the names of all registered catalogs.
pub fn catalog_names(&self) -> Vec<String> {
self.state.read().catalog_list.catalog_names()
}
/// Returns the catalog registered under `name`, if any.
pub fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
self.state.read().catalog_list.catalog(name)
}
/// Registers a [`TableProvider`] under `table_ref` (resolved against the
/// default catalog/schema), returning any previously registered provider.
pub fn register_table(
&self,
table_ref: impl Into<TableReference>,
provider: Arc<dyn TableProvider>,
) -> Result<Option<Arc<dyn TableProvider>>> {
let table_ref: TableReference = table_ref.into();
let table = table_ref.table().to_owned();
self.state
.read()
.schema_for_ref(table_ref)?
.register_table(table, provider)
}
/// Removes the table `table_ref`, returning the provider that was
/// registered, if any.
pub fn deregister_table(
&self,
table_ref: impl Into<TableReference>,
) -> Result<Option<Arc<dyn TableProvider>>> {
let table_ref = table_ref.into();
let table = table_ref.table().to_owned();
self.state
.read()
.schema_for_ref(table_ref)?
.deregister_table(&table)
}
/// Returns `true` if the table `table_ref` is registered in this session,
/// resolving against the default catalog/schema as needed.
///
/// # Errors
/// Fails if the referenced catalog or schema cannot be resolved.
pub fn table_exist(&self, table_ref: impl Into<TableReference>) -> Result<bool> {
    let table_ref: TableReference = table_ref.into();
    // Own just the table name; the previous version cloned the entire
    // `TableReference` only to work around borrowing `table()` from it.
    let table = table_ref.table().to_owned();
    Ok(self
        .state
        .read()
        .schema_for_ref(table_ref)?
        .table_exist(&table))
}
/// Returns a [`DataFrame`] scanning the registered table `table_ref`.
// NOTE(review): the `'a` lifetime parameter is unused and could be removed.
pub async fn table<'a>(
&self,
table_ref: impl Into<TableReference>,
) -> Result<DataFrame> {
let table_ref: TableReference = table_ref.into();
let provider = self.table_provider(table_ref.clone()).await?;
let plan = LogicalPlanBuilder::scan(
table_ref,
provider_as_source(Arc::clone(&provider)),
None,
)?
.build()?;
Ok(DataFrame::new(self.state(), plan))
}
/// Returns the [`TableProvider`] registered as `table_ref`, or a plan error
/// if no such table exists.
// NOTE(review): the `'a` lifetime parameter is unused and could be removed.
pub async fn table_provider<'a>(
&self,
table_ref: impl Into<TableReference>,
) -> Result<Arc<dyn TableProvider>> {
let table_ref = table_ref.into();
let table = table_ref.table().to_string();
let schema = self.state.read().schema_for_ref(table_ref)?;
match schema.table(&table).await? {
Some(ref provider) => Ok(Arc::clone(provider)),
_ => plan_err!("No table named '{table}'"),
}
}
/// Builds a [`TaskContext`] snapshot for executing physical plans.
pub fn task_ctx(&self) -> Arc<TaskContext> {
Arc::new(TaskContext::from(self))
}
/// Returns a snapshot clone of the [`SessionState`], marking the start of
/// execution (e.g. fixing `now()` for this query).
pub fn state(&self) -> SessionState {
let mut state = self.state.read().clone();
state.execution_props.start_execution();
state
}
/// Returns a weak reference to the shared state, useful to avoid reference
/// cycles from long-lived components back to the session.
pub fn state_weak_ref(&self) -> Weak<RwLock<SessionState>> {
Arc::downgrade(&self.state)
}
/// Replaces the entire catalog list (all catalogs) for this session.
pub fn register_catalog_list(&mut self, catalog_list: Arc<dyn CatalogProviderList>) {
self.state.write().catalog_list = catalog_list;
}
/// Installs a [`ConfigExtension`] into the session's table options.
pub fn register_table_options_extension<T: ConfigExtension>(&self, extension: T) {
self.state
.write()
.table_option_namespace
.extensions
.insert(extension)
}
}
// `FunctionRegistry` is implemented by delegating every call to the inner
// `SessionState` under the appropriate read/write lock.
impl FunctionRegistry for SessionContext {
fn udfs(&self) -> HashSet<String> {
self.state.read().udfs()
}
fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> {
self.state.read().udf(name)
}
fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
self.state.read().udaf(name)
}
fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
self.state.read().udwf(name)
}
fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
self.state.write().register_udf(udf)
}
fn register_udaf(
&mut self,
udaf: Arc<AggregateUDF>,
) -> Result<Option<Arc<AggregateUDF>>> {
self.state.write().register_udaf(udaf)
}
fn register_udwf(&mut self, udwf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
self.state.write().register_udwf(udwf)
}
fn register_function_rewrite(
&mut self,
rewrite: Arc<dyn FunctionRewrite + Send + Sync>,
) -> Result<()> {
self.state.write().register_function_rewrite(rewrite)
}
}
/// Pluggable strategy for turning a [`LogicalPlan`] into an
/// [`ExecutionPlan`]; override via `SessionState::with_query_planner`.
#[async_trait]
pub trait QueryPlanner {
/// Plans `logical_plan` into an executable physical plan.
async fn create_physical_plan(
&self,
logical_plan: &LogicalPlan,
session_state: &SessionState,
) -> Result<Arc<dyn ExecutionPlan>>;
}
// Default planner: thin wrapper around `DefaultPhysicalPlanner`.
struct DefaultQueryPlanner {}
#[async_trait]
impl QueryPlanner for DefaultQueryPlanner {
async fn create_physical_plan(
&self,
logical_plan: &LogicalPlan,
session_state: &SessionState,
) -> Result<Arc<dyn ExecutionPlan>> {
let planner = DefaultPhysicalPlanner::default();
planner
.create_physical_plan(logical_plan, session_state)
.await
}
}
/// Hook invoked for `CREATE FUNCTION` statements; implementations translate
/// the statement into a concrete function to register.
#[async_trait]
pub trait FunctionFactory: Sync + Send {
/// Builds the function described by `statement`.
async fn create(
&self,
state: &SessionState,
statement: CreateFunction,
) -> Result<RegisterFunction>;
}
/// The kind of function a [`FunctionFactory`] produced, determining which
/// registry it is installed into.
pub enum RegisterFunction {
Scalar(Arc<ScalarUDF>),
Aggregate(Arc<AggregateUDF>),
Window(Arc<WindowUDF>),
Table(String, Arc<dyn TableFunctionImpl>),
}
/// All per-session state: configuration, catalogs, registered functions,
/// planners and optimizer rule sets. Cloning produces an independent
/// snapshot (shared `Arc` internals aside).
#[derive(Clone)]
pub struct SessionState {
// Unique id for this session
session_id: String,
// Logical-plan analyzer rules (run before optimization)
analyzer: Analyzer,
// Logical-plan optimizer rules
optimizer: Optimizer,
// Physical-plan optimizer rules
physical_optimizers: PhysicalOptimizer,
// Strategy used to turn logical plans into physical plans
query_planner: Arc<dyn QueryPlanner + Send + Sync>,
// All registered catalogs
catalog_list: Arc<dyn CatalogProviderList>,
// User-defined table functions, by name
table_functions: HashMap<String, Arc<TableFunction>>,
// Scalar functions, by name
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
// Aggregate functions, by name
aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
// Window functions, by name
window_functions: HashMap<String, Arc<WindowUDF>>,
// Serializer for user-defined logical plan nodes
serializer_registry: Arc<dyn SerializerRegistry>,
// Session configuration
config: SessionConfig,
// Default options applied when creating tables
table_option_namespace: TableOptions,
// Per-execution properties (query start time, var providers, ...)
execution_props: ExecutionProps,
// Factories for `CREATE EXTERNAL TABLE`, keyed by upper-cased file type
table_factories: HashMap<String, Arc<dyn TableProviderFactory>>,
// Shared runtime environment (object stores, memory/disk managers)
runtime_env: Arc<RuntimeEnv>,
// Optional hook for `CREATE FUNCTION`
function_factory: Option<Arc<dyn FunctionFactory>>,
}
impl Debug for SessionState {
// Intentionally terse: only the session id, since most fields are large
// or not Debug.
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SessionState")
.field("session_id", &self.session_id)
.finish()
}
}
impl SessionState {
/// Creates a new `SessionState` with the given config and runtime, backed
/// by a fresh in-memory catalog list.
pub fn new_with_config_rt(config: SessionConfig, runtime: Arc<RuntimeEnv>) -> Self {
let catalog_list =
Arc::new(MemoryCatalogProviderList::new()) as Arc<dyn CatalogProviderList>;
Self::new_with_config_rt_and_catalog_list(config, runtime, catalog_list)
}
/// Deprecated alias of [`SessionState::new_with_config_rt`].
#[deprecated(since = "32.0.0", note = "Use SessionState::new_with_config_rt")]
pub fn with_config_rt(config: SessionConfig, runtime: Arc<RuntimeEnv>) -> Self {
Self::new_with_config_rt(config, runtime)
}
/// Creates a new `SessionState` with an explicit catalog list: installs the
/// built-in table factories, optionally creates the default catalog/schema,
/// and registers all built-in functions.
pub fn new_with_config_rt_and_catalog_list(
config: SessionConfig,
runtime: Arc<RuntimeEnv>,
catalog_list: Arc<dyn CatalogProviderList>,
) -> Self {
let session_id = Uuid::new_v4().to_string();
// Built-in factories for `CREATE EXTERNAL TABLE ... STORED AS <type>`;
// keys are upper-cased file types.
let mut table_factories: HashMap<String, Arc<dyn TableProviderFactory>> =
HashMap::new();
#[cfg(feature = "parquet")]
table_factories.insert("PARQUET".into(), Arc::new(DefaultTableFactory::new()));
table_factories.insert("CSV".into(), Arc::new(DefaultTableFactory::new()));
table_factories.insert("JSON".into(), Arc::new(DefaultTableFactory::new()));
table_factories.insert("NDJSON".into(), Arc::new(DefaultTableFactory::new()));
table_factories.insert("AVRO".into(), Arc::new(DefaultTableFactory::new()));
table_factories.insert("ARROW".into(), Arc::new(DefaultTableFactory::new()));
if config.create_default_catalog_and_schema() {
let default_catalog = MemoryCatalogProvider::new();
default_catalog
.register_schema(
&config.options().catalog.default_schema,
Arc::new(MemorySchemaProvider::new()),
)
.expect("memory catalog provider can register schema");
// Optionally attach a `ListingSchemaProvider` at the configured
// catalog location (see `register_default_schema`).
Self::register_default_schema(
&config,
&table_factories,
&runtime,
&default_catalog,
);
catalog_list.register_catalog(
config.options().catalog.default_catalog.clone(),
Arc::new(default_catalog),
);
}
let mut new_self = SessionState {
session_id,
analyzer: Analyzer::new(),
optimizer: Optimizer::new(),
physical_optimizers: PhysicalOptimizer::new(),
query_planner: Arc::new(DefaultQueryPlanner {}),
catalog_list,
table_functions: HashMap::new(),
scalar_functions: HashMap::new(),
aggregate_functions: HashMap::new(),
window_functions: HashMap::new(),
serializer_registry: Arc::new(EmptySerializerRegistry),
// Note: built from `config.options()` *before* `config` is moved
// into the struct on the next line.
table_option_namespace: TableOptions::default_from_session_config(
config.options(),
),
config,
execution_props: ExecutionProps::new(),
runtime_env: runtime,
table_factories,
function_factory: None,
};
// Install all built-in scalar / array / aggregate functions.
functions::register_all(&mut new_self)
.expect("can not register built in functions");
#[cfg(feature = "array_expressions")]
functions_array::register_all(&mut new_self)
.expect("can not register array expressions");
functions_aggregate::register_all(&mut new_self)
.expect("can not register aggregate functions");
new_self
}
/// Deprecated alias of [`SessionState::new_with_config_rt_and_catalog_list`].
#[deprecated(
since = "32.0.0",
note = "Use SessionState::new_with_config_rt_and_catalog_list"
)]
pub fn with_config_rt_and_catalog_list(
config: SessionConfig,
runtime: Arc<RuntimeEnv>,
catalog_list: Arc<dyn CatalogProviderList>,
) -> Self {
Self::new_with_config_rt_and_catalog_list(config, runtime, catalog_list)
}
// If the config specifies both a catalog `location` and `format`, registers
// a `ListingSchemaProvider` named "default" that discovers tables at that
// location. Silently returns when location/format/store/factory is missing.
fn register_default_schema(
config: &SessionConfig,
table_factories: &HashMap<String, Arc<dyn TableProviderFactory>>,
runtime: &Arc<RuntimeEnv>,
default_catalog: &MemoryCatalogProvider,
) {
let url = config.options().catalog.location.as_ref();
let format = config.options().catalog.format.as_ref();
let (url, format) = match (url, format) {
(Some(url), Some(format)) => (url, format),
// Nothing configured: keep only the plain in-memory default schema.
_ => return,
};
let url = url.to_string();
let format = format.to_string();
let has_header = config.options().catalog.has_header;
let url = Url::parse(url.as_str()).expect("Invalid default catalog location!");
// Split the URL into "scheme://host" (object store authority) and path.
let authority = match url.host_str() {
Some(host) => format!("{}://{}", url.scheme(), host),
None => format!("{}://", url.scheme()),
};
let path = &url.as_str()[authority.len()..];
let path = object_store::path::Path::parse(path).expect("Can't parse path");
let store = ObjectStoreUrl::parse(authority.as_str())
.expect("Invalid default catalog url");
let store = match runtime.object_store(store) {
Ok(store) => store,
_ => return,
};
let factory = match table_factories.get(format.as_str()) {
Some(factory) => factory,
_ => return,
};
let schema = ListingSchemaProvider::new(
authority,
path,
factory.clone(),
store,
format,
has_header,
);
// Replaces the plain "default" schema registered earlier; the returned
// previous provider is intentionally discarded.
let _ = default_catalog
.register_schema("default", Arc::new(schema))
.expect("Failed to register default schema");
}
// Fully qualifies a table reference using the session's default catalog and
// schema for any missing parts.
fn resolve_table_ref(
&self,
table_ref: impl Into<TableReference>,
) -> ResolvedTableReference {
let catalog = &self.config_options().catalog;
table_ref
.into()
.resolve(&catalog.default_catalog, &catalog.default_schema)
}
/// Resolves `table_ref` to its [`SchemaProvider`]. References to the
/// `information_schema` (when enabled) get a synthesized provider backed by
/// the current catalog list.
pub(crate) fn schema_for_ref(
&self,
table_ref: impl Into<TableReference>,
) -> Result<Arc<dyn SchemaProvider>> {
let resolved_ref = self.resolve_table_ref(table_ref);
if self.config.information_schema() && *resolved_ref.schema == *INFORMATION_SCHEMA
{
return Ok(Arc::new(InformationSchemaProvider::new(
self.catalog_list.clone(),
)));
}
self.catalog_list
.catalog(&resolved_ref.catalog)
.ok_or_else(|| {
plan_datafusion_err!(
"failed to resolve catalog: {}",
resolved_ref.catalog
)
})?
.schema(&resolved_ref.schema)
.ok_or_else(|| {
plan_datafusion_err!("failed to resolve schema: {}", resolved_ref.schema)
})
}
/// Builder-style setter: replaces the session id and returns `self`.
pub fn with_session_id(mut self, session_id: String) -> Self {
self.session_id = session_id;
self
}
/// Builder-style setter: replaces the [`QueryPlanner`] and returns `self`.
pub fn with_query_planner(
mut self,
query_planner: Arc<dyn QueryPlanner + Send + Sync>,
) -> Self {
self.query_planner = query_planner;
self
}
/// Builder-style setter: replaces the analyzer with one built from
/// `rules` (discarding previously configured analyzer rules).
pub fn with_analyzer_rules(
mut self,
rules: Vec<Arc<dyn AnalyzerRule + Send + Sync>>,
) -> Self {
self.analyzer = Analyzer::with_rules(rules);
self
}
/// Builder-style setter: replaces the logical optimizer with one built
/// from `rules` (discarding previously configured optimizer rules).
pub fn with_optimizer_rules(
mut self,
rules: Vec<Arc<dyn OptimizerRule + Send + Sync>>,
) -> Self {
self.optimizer = Optimizer::with_rules(rules);
self
}
/// Builder-style setter: replaces the physical optimizer with one built
/// from `physical_optimizers`.
pub fn with_physical_optimizer_rules(
mut self,
physical_optimizers: Vec<Arc<dyn PhysicalOptimizerRule + Send + Sync>>,
) -> Self {
self.physical_optimizers = PhysicalOptimizer::with_rules(physical_optimizers);
self
}
/// Builder-style: appends a single analyzer rule after the existing ones.
pub fn add_analyzer_rule(
mut self,
analyzer_rule: Arc<dyn AnalyzerRule + Send + Sync>,
) -> Self {
self.analyzer.rules.push(analyzer_rule);
self
}
/// Builder-style: appends a single logical optimizer rule after the
/// existing ones.
pub fn add_optimizer_rule(
mut self,
optimizer_rule: Arc<dyn OptimizerRule + Send + Sync>,
) -> Self {
self.optimizer.rules.push(optimizer_rule);
self
}
/// Builder-style: appends a single physical optimizer rule after the
/// existing ones.
pub fn add_physical_optimizer_rule(
mut self,
physical_optimizer_rule: Arc<dyn PhysicalOptimizerRule + Send + Sync>,
) -> Self {
self.physical_optimizers.rules.push(physical_optimizer_rule);
self
}
/// Builder-style: registers a [`ConfigExtension`] in the session's
/// table-options namespace.
pub fn add_table_options_extension<T: ConfigExtension>(
mut self,
extension: T,
) -> Self {
self.table_option_namespace.extensions.insert(extension);
self
}
/// Builder-style setter: installs a [`FunctionFactory`] and returns `self`.
pub fn with_function_factory(
mut self,
function_factory: Arc<dyn FunctionFactory>,
) -> Self {
self.function_factory = Some(function_factory);
self
}
/// In-place equivalent of [`Self::with_function_factory`].
pub fn set_function_factory(&mut self, function_factory: Arc<dyn FunctionFactory>) {
self.function_factory = Some(function_factory);
}
/// Builder-style setter: replaces the [`SerializerRegistry`] and
/// returns `self`.
pub fn with_serializer_registry(
mut self,
registry: Arc<dyn SerializerRegistry>,
) -> Self {
self.serializer_registry = registry;
self
}
/// Returns the registered table factories, keyed by format name.
pub fn table_factories(&self) -> &HashMap<String, Arc<dyn TableProviderFactory>> {
&self.table_factories
}
/// Mutable access to the registered table factories.
pub fn table_factories_mut(
&mut self,
) -> &mut HashMap<String, Arc<dyn TableProviderFactory>> {
&mut self.table_factories
}
/// Parses `sql` into a single [`datafusion_sql::parser::Statement`]
/// using the SQL dialect named by `dialect`.
///
/// Returns a planning error for an unknown dialect name, and a
/// not-implemented error unless exactly one statement was parsed.
pub fn sql_to_statement(
    &self,
    sql: &str,
    dialect: &str,
) -> Result<datafusion_sql::parser::Statement> {
    let dialect = dialect_from_str(dialect).ok_or_else(|| {
        plan_datafusion_err!(
            "Unsupported SQL dialect: {dialect}. Available dialects: \
             Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \
             MsSQL, ClickHouse, BigQuery, Ansi."
        )
    })?;
    let mut parsed = DFParser::parse_sql_with_dialect(sql, dialect.as_ref())?;
    if parsed.len() > 1 {
        return not_impl_err!(
            "The context currently only supports a single SQL statement"
        );
    }
    match parsed.pop_front() {
        Some(statement) => Ok(statement),
        None => Err(DataFusionError::NotImplemented(
            "The context requires a statement!".to_string(),
        )),
    }
}
/// Walks `statement` and returns every table reference it mentions:
/// relations, `SHOW CREATE TABLE/VIEW` targets, external-table names
/// and `COPY` sources. When the information schema is enabled, the
/// `information_schema` table names are appended as well so the planner
/// can resolve them.
pub fn resolve_table_references(
    &self,
    statement: &datafusion_sql::parser::Statement,
) -> Result<Vec<TableReference>> {
    use crate::catalog::information_schema::INFORMATION_SCHEMA_TABLES;
    use datafusion_sql::parser::Statement as DFStatement;
    use sqlparser::ast::*;
    let mut relations = hashbrown::HashSet::with_capacity(10);
    // Collects the distinct `ObjectName`s encountered while walking the AST.
    struct RelationVisitor<'a>(&'a mut hashbrown::HashSet<ObjectName>);
    impl<'a> RelationVisitor<'a> {
        /// Inserts `relation` if not already present (cloning lazily).
        fn insert(&mut self, relation: &ObjectName) {
            self.0.get_or_insert_with(relation, |_| relation.clone());
        }
    }
    impl<'a> Visitor for RelationVisitor<'a> {
        type Break = ();
        fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<()> {
            self.insert(relation);
            ControlFlow::Continue(())
        }
        fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<()> {
            // `SHOW CREATE TABLE/VIEW` names its target outside of a
            // relation position, so it must be collected explicitly.
            if let Statement::ShowCreate {
                obj_type: ShowCreateObject::Table | ShowCreateObject::View,
                obj_name,
            } = statement
            {
                self.insert(obj_name)
            }
            ControlFlow::Continue(())
        }
    }
    let mut visitor = RelationVisitor(&mut relations);
    // Recursive helper: DataFusion-specific statements wrap sqlparser
    // statements (or name tables directly) and need bespoke handling.
    fn visit_statement(statement: &DFStatement, visitor: &mut RelationVisitor<'_>) {
        match statement {
            DFStatement::Statement(s) => {
                let _ = s.as_ref().visit(visitor);
            }
            DFStatement::CreateExternalTable(table) => {
                visitor
                    .0
                    .insert(ObjectName(vec![Ident::from(table.name.as_str())]));
            }
            DFStatement::CopyTo(CopyToStatement { source, .. }) => match source {
                CopyToSource::Relation(table_name) => {
                    visitor.insert(table_name);
                }
                CopyToSource::Query(query) => {
                    // The visitor never breaks, so the `#[must_use]`
                    // `ControlFlow` result is deliberately discarded,
                    // consistent with the `Statement` arm above.
                    let _ = query.visit(visitor);
                }
            },
            DFStatement::Explain(explain) => {
                visit_statement(&explain.statement, visitor)
            }
        }
    }
    visit_statement(statement, &mut visitor);
    if self.config.information_schema() {
        for s in INFORMATION_SCHEMA_TABLES {
            relations.insert(ObjectName(vec![
                Ident::new(INFORMATION_SCHEMA),
                Ident::new(*s),
            ]));
        }
    }
    let enable_ident_normalization =
        self.config.options().sql_parser.enable_ident_normalization;
    relations
        .into_iter()
        .map(|x| object_name_to_table_reference(x, enable_ident_normalization))
        .collect::<Result<_>>()
}
/// Converts a parsed SQL statement into a [`LogicalPlan`].
///
/// All table references in the statement are resolved up front (the
/// async part) and cached in a `SessionContextProvider`, which is then
/// handed to the synchronous SQL planner.
pub async fn statement_to_plan(
&self,
statement: datafusion_sql::parser::Statement,
) -> Result<LogicalPlan> {
let references = self.resolve_table_references(&statement)?;
let mut provider = SessionContextProvider {
state: self,
tables: HashMap::with_capacity(references.len()),
};
let enable_ident_normalization =
self.config.options().sql_parser.enable_ident_normalization;
let parse_float_as_decimal =
self.config.options().sql_parser.parse_float_as_decimal;
for reference in references {
let resolved = &self.resolve_table_ref(reference);
if let Entry::Vacant(v) = provider.tables.entry(resolved.to_string()) {
// Unresolvable schemas are skipped here; the planner reports an
// error later only if the statement actually needs the table.
if let Ok(schema) = self.schema_for_ref(resolved.clone()) {
if let Some(table) = schema.table(&resolved.table).await? {
v.insert(provider_as_source(table));
}
}
}
}
let query = SqlToRel::new_with_options(
&provider,
ParserOptions {
parse_float_as_decimal,
enable_ident_normalization,
},
);
query.statement_to_plan(statement)
}
/// Parses and plans `sql` into a [`LogicalPlan`], using the SQL dialect
/// configured in the session's parser options.
pub async fn create_logical_plan(&self, sql: &str) -> Result<LogicalPlan> {
    let dialect = self.config.options().sql_parser.dialect.as_str();
    let statement = self.sql_to_statement(sql, dialect)?;
    self.statement_to_plan(statement).await
}
/// Runs the analyzer and then the logical optimizer on `plan`.
///
/// For an `EXPLAIN` plan, the wrapped plan is optimized instead and the
/// intermediate plan after every analyzer/optimizer rule is captured in
/// `stringified_plans`; a failing rule is embedded in the EXPLAIN output
/// (with `logical_optimization_succeeded: false`) rather than returned
/// as an error.
pub fn optimize(&self, plan: &LogicalPlan) -> Result<LogicalPlan> {
if let LogicalPlan::Explain(e) = plan {
let mut stringified_plans = e.stringified_plans.clone();
// The observer closure records the plan after each analyzer rule.
let analyzer_result = self.analyzer.execute_and_check(
e.plan.as_ref().clone(),
self.options(),
|analyzed_plan, analyzer| {
let analyzer_name = analyzer.name().to_string();
let plan_type = PlanType::AnalyzedLogicalPlan { analyzer_name };
stringified_plans.push(analyzed_plan.to_stringified(plan_type));
},
);
let analyzed_plan = match analyzer_result {
Ok(plan) => plan,
// `Context` carries the failing rule's name: record the error
// under that rule and return the unoptimized EXPLAIN.
Err(DataFusionError::Context(analyzer_name, err)) => {
let plan_type = PlanType::AnalyzedLogicalPlan { analyzer_name };
stringified_plans
.push(StringifiedPlan::new(plan_type, err.to_string()));
return Ok(LogicalPlan::Explain(Explain {
verbose: e.verbose,
plan: e.plan.clone(),
stringified_plans,
schema: e.schema.clone(),
logical_optimization_succeeded: false,
}));
}
Err(e) => return Err(e),
};
stringified_plans
.push(analyzed_plan.to_stringified(PlanType::FinalAnalyzedLogicalPlan));
// Likewise capture the plan after each optimizer rule.
let optimized_plan = self.optimizer.optimize(
analyzed_plan,
self,
|optimized_plan, optimizer| {
let optimizer_name = optimizer.name().to_string();
let plan_type = PlanType::OptimizedLogicalPlan { optimizer_name };
stringified_plans.push(optimized_plan.to_stringified(plan_type));
},
);
let (plan, logical_optimization_succeeded) = match optimized_plan {
Ok(plan) => (Arc::new(plan), true),
// An optimizer failure also becomes part of the output, falling
// back to the original (pre-analysis) input plan.
Err(DataFusionError::Context(optimizer_name, err)) => {
let plan_type = PlanType::OptimizedLogicalPlan { optimizer_name };
stringified_plans
.push(StringifiedPlan::new(plan_type, err.to_string()));
(e.plan.clone(), false)
}
Err(e) => return Err(e),
};
Ok(LogicalPlan::Explain(Explain {
verbose: e.verbose,
plan,
stringified_plans,
schema: e.schema.clone(),
logical_optimization_succeeded,
}))
} else {
// Non-EXPLAIN plans: analyze then optimize, propagating any error.
let analyzed_plan = self.analyzer.execute_and_check(
plan.clone(),
self.options(),
|_, _| {},
)?;
self.optimizer.optimize(analyzed_plan, self, |_, _| {})
}
}
/// Optimizes `logical_plan` and hands the result to the session's
/// query planner to produce an [`ExecutionPlan`].
pub async fn create_physical_plan(
&self,
logical_plan: &LogicalPlan,
) -> Result<Arc<dyn ExecutionPlan>> {
let logical_plan = self.optimize(logical_plan)?;
self.query_planner
.create_physical_plan(&logical_plan, self)
.await
}
/// Lowers a logical [`Expr`] against `df_schema` into a physical
/// expression: coerces types, applies the analyzer's function rewrites
/// bottom-up, then plans the physical expression.
pub fn create_physical_expr(
&self,
expr: Expr,
df_schema: &DFSchema,
) -> Result<Arc<dyn PhysicalExpr>> {
let simplifier =
ExprSimplifier::new(SessionSimplifyProvider::new(self, df_schema));
// Coerce first so the rewrites below see type-resolved expressions.
let mut expr = simplifier.coerce(expr, df_schema)?;
let config_options = self.config_options();
for rewrite in self.analyzer.function_rewrites() {
expr = expr
.transform_up(|expr| rewrite.rewrite(expr, df_schema, config_options))?
.data;
}
create_physical_expr(&expr, df_schema, self.execution_props())
}
/// Returns this session's id string.
pub fn session_id(&self) -> &str {
&self.session_id
}
/// Returns the session's [`RuntimeEnv`].
pub fn runtime_env(&self) -> &Arc<RuntimeEnv> {
&self.runtime_env
}
/// Returns the session's [`ExecutionProps`].
pub fn execution_props(&self) -> &ExecutionProps {
&self.execution_props
}
/// Returns the session's [`SessionConfig`].
pub fn config(&self) -> &SessionConfig {
&self.config
}
/// Mutable access to the session's [`SessionConfig`].
pub fn config_mut(&mut self) -> &mut SessionConfig {
&mut self.config
}
/// Returns the configured physical optimizer rules.
pub fn physical_optimizers(&self) -> &[Arc<dyn PhysicalOptimizerRule + Send + Sync>] {
&self.physical_optimizers.rules
}
/// Returns the [`ConfigOptions`] backing the session config.
pub fn config_options(&self) -> &ConfigOptions {
self.config.options()
}
/// Returns [`TableOptions`] built by combining the session-level table
/// option namespace with the current session config.
pub fn default_table_options(&self) -> TableOptions {
self.table_option_namespace
.combine_with_session_config(self.config_options())
}
/// Creates a new [`TaskContext`] snapshot of this state.
pub fn task_ctx(&self) -> Arc<TaskContext> {
Arc::new(TaskContext::from(self))
}
/// Returns a shared handle to the session's catalog list.
pub fn catalog_list(&self) -> Arc<dyn CatalogProviderList> {
self.catalog_list.clone()
}
/// Returns the registered scalar UDFs, keyed by name (and alias).
pub fn scalar_functions(&self) -> &HashMap<String, Arc<ScalarUDF>> {
&self.scalar_functions
}
/// Returns the registered aggregate UDFs, keyed by name (and alias).
pub fn aggregate_functions(&self) -> &HashMap<String, Arc<AggregateUDF>> {
&self.aggregate_functions
}
/// Returns the registered window UDFs, keyed by name (and alias).
pub fn window_functions(&self) -> &HashMap<String, Arc<WindowUDF>> {
&self.window_functions
}
/// Returns a shared handle to the session's [`SerializerRegistry`].
pub fn serializer_registry(&self) -> Arc<dyn SerializerRegistry> {
self.serializer_registry.clone()
}
/// Returns the version of this crate, captured at compile time.
pub fn version(&self) -> &str {
env!("CARGO_PKG_VERSION")
}
}
// Adapter exposing a `SessionState` + schema pair as `SimplifyInfo`
// for expression simplification/coercion.
struct SessionSimplifyProvider<'a> {
state: &'a SessionState,
df_schema: &'a DFSchema,
}
impl<'a> SessionSimplifyProvider<'a> {
fn new(state: &'a SessionState, df_schema: &'a DFSchema) -> Self {
Self { state, df_schema }
}
}
impl<'a> SimplifyInfo for SessionSimplifyProvider<'a> {
// True if `expr` evaluates to `DataType::Boolean` under the schema.
fn is_boolean_type(&self, expr: &Expr) -> Result<bool> {
Ok(expr.get_type(self.df_schema)? == DataType::Boolean)
}
fn nullable(&self, expr: &Expr) -> Result<bool> {
expr.nullable(self.df_schema)
}
fn execution_props(&self) -> &ExecutionProps {
self.state.execution_props()
}
fn get_data_type(&self, expr: &Expr) -> Result<DataType> {
expr.get_type(self.df_schema)
}
}
// `ContextProvider` used during SQL planning; `tables` holds sources
// pre-resolved (asynchronously) by `statement_to_plan`, keyed by
// fully-resolved table reference strings.
struct SessionContextProvider<'a> {
state: &'a SessionState,
tables: HashMap<String, Arc<dyn TableSource>>,
}
impl<'a> ContextProvider for SessionContextProvider<'a> {
// Looks up a table source pre-resolved by `statement_to_plan`;
// lookups use the fully-resolved reference string as the key.
fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn TableSource>> {
let name = self.state.resolve_table_ref(name).to_string();
self.tables
.get(&name)
.cloned()
.ok_or_else(|| plan_datafusion_err!("table '{name}' not found"))
}
// Instantiates a registered table function with `args` and wraps the
// resulting provider as a planning source.
fn get_table_function_source(
&self,
name: &str,
args: Vec<Expr>,
) -> Result<Arc<dyn TableSource>> {
let tbl_func = self
.state
.table_functions
.get(name)
.cloned()
.ok_or_else(|| plan_datafusion_err!("table function '{name}' not found"))?;
let provider = tbl_func.create_table_provider(&args)?;
Ok(provider_as_source(provider))
}
// Creates the work table used to plan a recursive CTE.
fn create_cte_work_table(
&self,
name: &str,
schema: SchemaRef,
) -> Result<Arc<dyn TableSource>> {
let table = Arc::new(CteWorkTable::new(name, schema));
Ok(provider_as_source(table))
}
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
self.state.scalar_functions().get(name).cloned()
}
fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
self.state.aggregate_functions().get(name).cloned()
}
fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
self.state.window_functions().get(name).cloned()
}
// Resolves the type of a session variable (`@@system` vs `@user`)
// via the registered `VarProvider`s, if any.
fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType> {
if variable_names.is_empty() {
return None;
}
let provider_type = if is_system_variables(variable_names) {
VarType::System
} else {
VarType::UserDefined
};
self.state
.execution_props
.var_providers
.as_ref()
.and_then(|provider| provider.get(&provider_type)?.get_type(variable_names))
}
fn options(&self) -> &ConfigOptions {
self.state.config_options()
}
fn udfs_names(&self) -> Vec<String> {
self.state.scalar_functions().keys().cloned().collect()
}
fn udafs_names(&self) -> Vec<String> {
self.state.aggregate_functions().keys().cloned().collect()
}
fn udwfs_names(&self) -> Vec<String> {
self.state.window_functions().keys().cloned().collect()
}
}
impl FunctionRegistry for SessionState {
    /// Names of all registered scalar UDFs (including aliases).
    fn udfs(&self) -> HashSet<String> {
        self.scalar_functions.keys().cloned().collect()
    }

    fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> {
        self.scalar_functions.get(name).cloned().ok_or_else(|| {
            plan_datafusion_err!("There is no UDF named \"{name}\" in the registry")
        })
    }

    fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
        self.aggregate_functions.get(name).cloned().ok_or_else(|| {
            plan_datafusion_err!("There is no UDAF named \"{name}\" in the registry")
        })
    }

    fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
        self.window_functions.get(name).cloned().ok_or_else(|| {
            plan_datafusion_err!("There is no UDWF named \"{name}\" in the registry")
        })
    }

    /// Registers `udf` under each of its aliases and then its primary
    /// name, returning whatever was previously stored under that name.
    fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
        for alias in udf.aliases() {
            self.scalar_functions.insert(alias.clone(), Arc::clone(&udf));
        }
        Ok(self.scalar_functions.insert(udf.name().into(), udf))
    }

    fn register_udaf(
        &mut self,
        udaf: Arc<AggregateUDF>,
    ) -> Result<Option<Arc<AggregateUDF>>> {
        for alias in udaf.aliases() {
            self.aggregate_functions
                .insert(alias.clone(), Arc::clone(&udaf));
        }
        Ok(self.aggregate_functions.insert(udaf.name().into(), udaf))
    }

    fn register_udwf(&mut self, udwf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
        for alias in udwf.aliases() {
            self.window_functions.insert(alias.clone(), Arc::clone(&udwf));
        }
        Ok(self.window_functions.insert(udwf.name().into(), udwf))
    }

    /// Removes the UDF registered under `name`, along with all of its
    /// aliases, returning the removed UDF if it was present.
    fn deregister_udf(&mut self, name: &str) -> Result<Option<Arc<ScalarUDF>>> {
        let removed = self.scalar_functions.remove(name);
        if let Some(udf) = removed.as_ref() {
            for alias in udf.aliases() {
                self.scalar_functions.remove(alias);
            }
        }
        Ok(removed)
    }

    fn deregister_udaf(&mut self, name: &str) -> Result<Option<Arc<AggregateUDF>>> {
        let removed = self.aggregate_functions.remove(name);
        if let Some(udaf) = removed.as_ref() {
            for alias in udaf.aliases() {
                self.aggregate_functions.remove(alias);
            }
        }
        Ok(removed)
    }

    fn deregister_udwf(&mut self, name: &str) -> Result<Option<Arc<WindowUDF>>> {
        let removed = self.window_functions.remove(name);
        if let Some(udwf) = removed.as_ref() {
            for alias in udwf.aliases() {
                self.window_functions.remove(alias);
            }
        }
        Ok(removed)
    }

    fn register_function_rewrite(
        &mut self,
        rewrite: Arc<dyn FunctionRewrite + Send + Sync>,
    ) -> Result<()> {
        self.analyzer.add_function_rewrite(rewrite);
        Ok(())
    }
}
impl OptimizerConfig for SessionState {
// Start time recorded in this session's `execution_props`.
fn query_execution_start_time(&self) -> DateTime<Utc> {
self.execution_props.query_execution_start_time
}
fn alias_generator(&self) -> Arc<AliasGenerator> {
self.execution_props.alias_generator.clone()
}
fn options(&self) -> &ConfigOptions {
self.config_options()
}
}
impl From<&SessionContext> for TaskContext {
/// Builds a task context from the context's current (locked) state.
fn from(session: &SessionContext) -> Self {
TaskContext::from(&*session.state.read())
}
}
impl From<&SessionState> for TaskContext {
/// Snapshots the state's config, function registries and runtime into
/// a task context; no task id is assigned.
fn from(state: &SessionState) -> Self {
let task_id = None;
TaskContext::new(
task_id,
state.session_id.clone(),
state.config.clone(),
state.scalar_functions.clone(),
state.aggregate_functions.clone(),
state.window_functions.clone(),
state.runtime_env.clone(),
)
}
}
/// Default [`SerializerRegistry`] that supports no user-defined logical
/// plan nodes: both directions return "not implemented" errors.
pub struct EmptySerializerRegistry;
impl SerializerRegistry for EmptySerializerRegistry {
// Always fails: custom node serialization is not supported here.
fn serialize_logical_plan(
&self,
node: &dyn UserDefinedLogicalNode,
) -> Result<Vec<u8>> {
not_impl_err!(
"Serializing user defined logical plan node `{}` is not supported",
node.name()
)
}
// Always fails: custom node deserialization is not supported here.
fn deserialize_logical_plan(
&self,
name: &str,
_bytes: &[u8],
) -> Result<Arc<dyn UserDefinedLogicalNode>> {
not_impl_err!(
"Deserializing user defined logical plan node `{name}` is not supported"
)
}
}
/// Controls which categories of SQL operations are permitted when a
/// plan is checked with [`SQLOptions::verify_plan`].
#[derive(Clone, Debug, Copy)]
pub struct SQLOptions {
// Whether DDL logical plans are allowed.
allow_ddl: bool,
// Whether DML logical plans (including COPY) are allowed.
allow_dml: bool,
// Whether statement logical plans are allowed.
allow_statements: bool,
}
impl Default for SQLOptions {
// By default every operation category is permitted.
fn default() -> Self {
Self {
allow_ddl: true,
allow_dml: true,
allow_statements: true,
}
}
}
impl SQLOptions {
    /// Returns options with all operation categories permitted.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets whether DDL plans are permitted.
    pub fn with_allow_ddl(mut self, allow: bool) -> Self {
        self.allow_ddl = allow;
        self
    }

    /// Sets whether DML plans (including COPY) are permitted.
    pub fn with_allow_dml(mut self, allow: bool) -> Self {
        self.allow_dml = allow;
        self
    }

    /// Sets whether statement plans are permitted.
    pub fn with_allow_statements(mut self, allow: bool) -> Self {
        self.allow_statements = allow;
        self
    }

    /// Returns an error if `plan` (or any subquery within it) contains
    /// an operation these options do not permit.
    pub fn verify_plan(&self, plan: &LogicalPlan) -> Result<()> {
        let mut visitor = BadPlanVisitor::new(self);
        plan.visit_with_subqueries(&mut visitor)?;
        Ok(())
    }
}
// Tree visitor that rejects plan nodes disallowed by `SQLOptions`.
struct BadPlanVisitor<'a> {
options: &'a SQLOptions,
}
impl<'a> BadPlanVisitor<'a> {
fn new(options: &'a SQLOptions) -> Self {
Self { options }
}
}
impl<'a> TreeNodeVisitor for BadPlanVisitor<'a> {
type Node = LogicalPlan;
// Errors on the first disallowed node; otherwise continues the walk.
fn f_down(&mut self, node: &Self::Node) -> Result<TreeNodeRecursion> {
match node {
LogicalPlan::Ddl(ddl) if !self.options.allow_ddl => {
plan_err!("DDL not supported: {}", ddl.name())
}
LogicalPlan::Dml(dml) if !self.options.allow_dml => {
plan_err!("DML not supported: {}", dml.op)
}
// COPY is gated by the DML permission as well.
LogicalPlan::Copy(_) if !self.options.allow_dml => {
plan_err!("DML not supported: COPY")
}
LogicalPlan::Statement(stmt) if !self.options.allow_statements => {
plan_err!("Statement not supported: {}", stmt.name())
}
_ => Ok(TreeNodeRecursion::Continue),
}
}
}
#[cfg(test)]
mod tests {
use std::env;
use std::path::PathBuf;
use super::{super::options::CsvReadOptions, *};
use crate::assert_batches_eq;
use crate::execution::memory_pool::MemoryConsumer;
use crate::execution::runtime_env::RuntimeConfig;
use crate::test;
use crate::test_util::{plan_and_collect, populate_csv_partitions};
use datafusion_common_runtime::SpawnedTask;
use async_trait::async_trait;
use tempfile::TempDir;
// Two contexts built on the same RuntimeEnv must share the memory pool
// and disk manager instances.
#[tokio::test]
async fn shared_memory_and_disk_manager() {
let ctx1 = SessionContext::new();
let memory_pool = ctx1.runtime_env().memory_pool.clone();
let mut reservation = MemoryConsumer::new("test").register(&memory_pool);
reservation.grow(100);
let disk_manager = ctx1.runtime_env().disk_manager.clone();
let ctx2 =
SessionContext::new_with_config_rt(SessionConfig::new(), ctx1.runtime_env());
assert_eq!(ctx1.runtime_env().memory_pool.reserved(), 100);
assert_eq!(ctx2.runtime_env().memory_pool.reserved(), 100);
drop(reservation);
assert_eq!(ctx1.runtime_env().memory_pool.reserved(), 0);
assert_eq!(ctx2.runtime_env().memory_pool.reserved(), 0);
assert!(std::ptr::eq(
Arc::as_ptr(&disk_manager),
Arc::as_ptr(&ctx1.runtime_env().disk_manager)
));
assert!(std::ptr::eq(
Arc::as_ptr(&disk_manager),
Arc::as_ptr(&ctx2.runtime_env().disk_manager)
));
}
// System (@@) and user-defined (@) variables resolve via registered
// VarProviders during planning.
#[tokio::test]
async fn create_variable_expr() -> Result<()> {
let tmp_dir = TempDir::new()?;
let partition_count = 4;
let ctx = create_ctx(&tmp_dir, partition_count).await?;
let variable_provider = test::variable::SystemVar::new();
ctx.register_variable(VarType::System, Arc::new(variable_provider));
let variable_provider = test::variable::UserDefinedVar::new();
ctx.register_variable(VarType::UserDefined, Arc::new(variable_provider));
let provider = test::create_table_dual();
ctx.register_table("dual", provider)?;
let results =
plan_and_collect(&ctx, "SELECT @@version, @name, @integer + 1 FROM dual")
.await?;
let expected = [
"+----------------------+------------------------+---------------------+",
"| @@version | @name | @integer + Int64(1) |",
"+----------------------+------------------------+---------------------+",
"| system-var-@@version | user-defined-var-@name | 42 |",
"+----------------------+------------------------+---------------------+",
];
assert_batches_eq!(expected, &results);
Ok(())
}
// Referencing a variable with no registered provider is a planning error.
#[tokio::test]
async fn create_variable_err() -> Result<()> {
let ctx = SessionContext::new();
let err = plan_and_collect(&ctx, "SElECT @= X3").await.unwrap_err();
assert_eq!(
err.strip_backtrace(),
"Error during planning: variable [\"@=\"] has no type information"
);
Ok(())
}
// Deregistering returns Some for a registered table, None afterwards.
#[tokio::test]
async fn register_deregister() -> Result<()> {
let tmp_dir = TempDir::new()?;
let partition_count = 4;
let ctx = create_ctx(&tmp_dir, partition_count).await?;
let provider = test::create_table_dual();
ctx.register_table("dual", provider)?;
assert!(ctx.deregister_table("dual")?.is_some());
assert!(ctx.deregister_table("dual")?.is_none());
Ok(())
}
// A context wrapped in Arc can run queries concurrently from spawned tasks.
#[tokio::test]
async fn send_context_to_threads() -> Result<()> {
let tmp_dir = TempDir::new()?;
let partition_count = 4;
let ctx = Arc::new(create_ctx(&tmp_dir, partition_count).await?);
let threads: Vec<_> = (0..2)
.map(|_| ctx.clone())
.map(|ctx| {
SpawnedTask::spawn(async move {
ctx.sql("SELECT c1, c2 FROM test WHERE c1 > 0 AND c1 < 3")
.await
})
})
.collect();
for handle in threads {
handle.join().await.unwrap().unwrap();
}
Ok(())
}
// A catalog location + format in the config yields a queryable
// "default" listing schema after refresh_catalogs().
#[tokio::test]
async fn with_listing_schema_provider() -> Result<()> {
let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
let path = path.join("tests/tpch-csv");
let url = format!("file://{}", path.display());
let rt_cfg = RuntimeConfig::new();
let runtime = Arc::new(RuntimeEnv::new(rt_cfg).unwrap());
let cfg = SessionConfig::new()
.set_str("datafusion.catalog.location", url.as_str())
.set_str("datafusion.catalog.format", "CSV")
.set_str("datafusion.catalog.has_header", "true");
let session_state = SessionState::new_with_config_rt(cfg, runtime);
let ctx = SessionContext::new_with_state(session_state);
ctx.refresh_catalogs().await?;
let result =
plan_and_collect(&ctx, "select c_name from default.customer limit 3;")
.await?;
let actual = arrow::util::pretty::pretty_format_batches(&result)
.unwrap()
.to_string();
let expected = r#"+--------------------+
| c_name |
+--------------------+
| Customer#000000002 |
| Customer#000000003 |
| Customer#000000004 |
+--------------------+"#;
assert_eq!(actual, expected);
Ok(())
}
// A custom QueryPlanner installed on the state is used for physical planning.
#[tokio::test]
async fn custom_query_planner() -> Result<()> {
let runtime = Arc::new(RuntimeEnv::default());
let session_state =
SessionState::new_with_config_rt(SessionConfig::new(), runtime)
.with_query_planner(Arc::new(MyQueryPlanner {}));
let ctx = SessionContext::new_with_state(session_state);
let df = ctx.sql("SELECT 1").await?;
df.collect().await.expect_err("query not supported");
Ok(())
}
// Without a default catalog/schema, registration and resolution fail.
#[tokio::test]
async fn disabled_default_catalog_and_schema() -> Result<()> {
let ctx = SessionContext::new_with_config(
SessionConfig::new().with_create_default_catalog_and_schema(false),
);
assert!(matches!(
ctx.register_table("test", test::table_with_sequence(1, 1)?),
Err(DataFusionError::Plan(_))
));
assert!(matches!(
ctx.sql("select * from datafusion.public.test").await,
Err(DataFusionError::Plan(_))
));
Ok(())
}
#[tokio::test]
async fn custom_catalog_and_schema() {
let config = SessionConfig::new()
.with_create_default_catalog_and_schema(true)
.with_default_catalog_and_schema("my_catalog", "my_schema");
catalog_and_schema_test(config).await;
}
#[tokio::test]
async fn custom_catalog_and_schema_no_default() {
let config = SessionConfig::new()
.with_create_default_catalog_and_schema(false)
.with_default_catalog_and_schema("my_catalog", "my_schema");
catalog_and_schema_test(config).await;
}
#[tokio::test]
async fn custom_catalog_and_schema_and_information_schema() {
let config = SessionConfig::new()
.with_create_default_catalog_and_schema(true)
.with_information_schema(true)
.with_default_catalog_and_schema("my_catalog", "my_schema");
catalog_and_schema_test(config).await;
}
// Shared body for the three tests above: a registered table must be
// reachable by fully-, partially- and un-qualified names.
async fn catalog_and_schema_test(config: SessionConfig) {
let ctx = SessionContext::new_with_config(config);
let catalog = MemoryCatalogProvider::new();
let schema = MemorySchemaProvider::new();
schema
.register_table("test".to_owned(), test::table_with_sequence(1, 1).unwrap())
.unwrap();
catalog
.register_schema("my_schema", Arc::new(schema))
.unwrap();
ctx.register_catalog("my_catalog", Arc::new(catalog));
for table_ref in &["my_catalog.my_schema.test", "my_schema.test", "test"] {
let result = plan_and_collect(
&ctx,
&format!("SELECT COUNT(*) AS count FROM {table_ref}"),
)
.await
.unwrap();
let expected = [
"+-------+",
"| count |",
"+-------+",
"| 1 |",
"+-------+",
];
assert_batches_eq!(expected, &result);
}
}
// A single query may read from tables in different catalogs.
#[tokio::test]
async fn cross_catalog_access() -> Result<()> {
let ctx = SessionContext::new();
let catalog_a = MemoryCatalogProvider::new();
let schema_a = MemorySchemaProvider::new();
schema_a
.register_table("table_a".to_owned(), test::table_with_sequence(1, 1)?)?;
catalog_a.register_schema("schema_a", Arc::new(schema_a))?;
ctx.register_catalog("catalog_a", Arc::new(catalog_a));
let catalog_b = MemoryCatalogProvider::new();
let schema_b = MemorySchemaProvider::new();
schema_b
.register_table("table_b".to_owned(), test::table_with_sequence(1, 2)?)?;
catalog_b.register_schema("schema_b", Arc::new(schema_b))?;
ctx.register_catalog("catalog_b", Arc::new(catalog_b));
let result = plan_and_collect(
&ctx,
"SELECT cat, SUM(i) AS total FROM (
SELECT i, 'a' AS cat FROM catalog_a.schema_a.table_a
UNION ALL
SELECT i, 'b' AS cat FROM catalog_b.schema_b.table_b
) AS all
GROUP BY cat
ORDER BY cat
",
)
.await?;
let expected = [
"+-----+-------+",
"| cat | total |",
"+-----+-------+",
"| a | 1 |",
"| b | 3 |",
"+-----+-------+",
];
assert_batches_eq!(expected, &result);
Ok(())
}
// Dropping the context must drop all strong references to the catalog
// list and registered catalogs (no Arc reference cycles).
#[tokio::test]
async fn catalogs_not_leaked() {
let ctx = SessionContext::new_with_config(
SessionConfig::new().with_information_schema(true),
);
let catalog = Arc::new(MemoryCatalogProvider::new());
let catalog_weak = Arc::downgrade(&catalog);
ctx.register_catalog("my_catalog", catalog);
let catalog_list_weak = {
let state = ctx.state.read();
Arc::downgrade(&state.catalog_list)
};
drop(ctx);
assert_eq!(Weak::strong_count(&catalog_list_weak), 0);
assert_eq!(Weak::strong_count(&catalog_weak), 0);
}
// CREATE SCHEMA then CREATE TABLE in it is visible via information_schema.
#[tokio::test]
async fn sql_create_schema() -> Result<()> {
let ctx = SessionContext::new_with_config(
SessionConfig::new().with_information_schema(true),
);
ctx.sql("CREATE SCHEMA abc").await?.collect().await?;
ctx.sql("CREATE TABLE abc.y AS VALUES (1,2,3)")
.await?
.collect()
.await?;
let results = ctx.sql("SELECT * FROM information_schema.tables WHERE table_schema='abc' AND table_name = 'y'").await.unwrap().collect().await.unwrap();
assert_eq!(results[0].num_rows(), 1);
Ok(())
}
// CREATE DATABASE / SCHEMA / TABLE chain is visible via information_schema.
#[tokio::test]
async fn sql_create_catalog() -> Result<()> {
let ctx = SessionContext::new_with_config(
SessionConfig::new().with_information_schema(true),
);
ctx.sql("CREATE DATABASE test").await?.collect().await?;
ctx.sql("CREATE SCHEMA test.abc").await?.collect().await?;
ctx.sql("CREATE TABLE test.abc.y AS VALUES (1,2,3)")
.await?
.collect()
.await?;
let results = ctx.sql("SELECT * FROM information_schema.tables WHERE table_catalog='test' AND table_schema='abc' AND table_name = 'y'").await.unwrap().collect().await.unwrap();
assert_eq!(results[0].num_rows(), 1);
Ok(())
}
// Physical planner stub that rejects every plan (used by MyQueryPlanner).
struct MyPhysicalPlanner {}
#[async_trait]
impl PhysicalPlanner for MyPhysicalPlanner {
async fn create_physical_plan(
&self,
_logical_plan: &LogicalPlan,
_session_state: &SessionState,
) -> Result<Arc<dyn ExecutionPlan>> {
not_impl_err!("query not supported")
}
fn create_physical_expr(
&self,
_expr: &Expr,
_input_dfschema: &crate::common::DFSchema,
_session_state: &SessionState,
) -> Result<Arc<dyn crate::physical_plan::PhysicalExpr>> {
unimplemented!()
}
}
// Query planner that delegates to the always-failing MyPhysicalPlanner.
struct MyQueryPlanner {}
#[async_trait]
impl QueryPlanner for MyQueryPlanner {
async fn create_physical_plan(
&self,
logical_plan: &LogicalPlan,
session_state: &SessionState,
) -> Result<Arc<dyn ExecutionPlan>> {
let physical_planner = MyPhysicalPlanner {};
physical_planner
.create_physical_plan(logical_plan, session_state)
.await
}
}
// Builds a context with a partitioned CSV table named "test".
async fn create_ctx(
tmp_dir: &TempDir,
partition_count: usize,
) -> Result<SessionContext> {
let ctx = SessionContext::new_with_config(
SessionConfig::new().with_target_partitions(8),
);
let schema = populate_csv_partitions(tmp_dir, partition_count, ".csv")?;
ctx.register_csv(
"test",
tmp_dir.path().to_str().unwrap(),
CsvReadOptions::new().schema(&schema),
)
.await?;
Ok(ctx)
}
}