1use std::collections::HashSet;
21use std::fmt::Debug;
22use std::sync::{Arc, Weak};
23
24use super::options::ReadOptions;
25use crate::datasource::dynamic_file::DynamicListTableFactory;
26use crate::execution::session_state::SessionStateBuilder;
27use crate::{
28 catalog::listing_schema::ListingSchemaProvider,
29 catalog::{
30 CatalogProvider, CatalogProviderList, TableProvider, TableProviderFactory,
31 },
32 dataframe::DataFrame,
33 datasource::listing::{
34 ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
35 },
36 datasource::{provider_as_source, MemTable, ViewTable},
37 error::{DataFusionError, Result},
38 execution::{
39 options::ArrowReadOptions,
40 runtime_env::{RuntimeEnv, RuntimeEnvBuilder},
41 FunctionRegistry,
42 },
43 logical_expr::AggregateUDF,
44 logical_expr::ScalarUDF,
45 logical_expr::{
46 CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateFunction,
47 CreateMemoryTable, CreateView, DropCatalogSchema, DropFunction, DropTable,
48 DropView, Execute, LogicalPlan, LogicalPlanBuilder, Prepare, SetVariable,
49 TableType, UNNAMED_TABLE,
50 },
51 physical_expr::PhysicalExpr,
52 physical_plan::ExecutionPlan,
53 variable::{VarProvider, VarType},
54};
55
56pub use crate::execution::session_state::SessionState;
58
59use arrow::datatypes::{Schema, SchemaRef};
60use arrow::record_batch::RecordBatch;
61use datafusion_catalog::memory::MemorySchemaProvider;
62use datafusion_catalog::MemoryCatalogProvider;
63use datafusion_catalog::{
64 DynamicFileCatalog, TableFunction, TableFunctionImpl, UrlTableFactory,
65};
66use datafusion_common::config::ConfigOptions;
67use datafusion_common::{
68 config::{ConfigExtension, TableOptions},
69 exec_datafusion_err, exec_err, not_impl_err, plan_datafusion_err, plan_err,
70 tree_node::{TreeNodeRecursion, TreeNodeVisitor},
71 DFSchema, ParamValues, ScalarValue, SchemaReference, TableReference,
72};
73pub use datafusion_execution::config::SessionConfig;
74use datafusion_execution::registry::SerializerRegistry;
75pub use datafusion_execution::TaskContext;
76pub use datafusion_expr::execution_props::ExecutionProps;
77use datafusion_expr::{
78 expr_rewriter::FunctionRewrite,
79 logical_plan::{DdlStatement, Statement},
80 planner::ExprPlanner,
81 Expr, UserDefinedLogicalNode, WindowUDF,
82};
83use datafusion_optimizer::analyzer::type_coercion::TypeCoercion;
84use datafusion_optimizer::Analyzer;
85use datafusion_optimizer::{AnalyzerRule, OptimizerRule};
86use datafusion_session::SessionStore;
87
88use async_trait::async_trait;
89use chrono::{DateTime, Utc};
90use object_store::ObjectStore;
91use parking_lot::RwLock;
92use url::Url;
93
94mod csv;
95mod json;
96#[cfg(feature = "parquet")]
97mod parquet;
98
99#[cfg(feature = "avro")]
100mod avro;
101
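/// A type that can be converted into one or more [`ListingTableUrl`]s.
///
/// Implemented for single paths (`&str`, `String`, `&String`) and for
/// `Vec<P>` where `P: AsRef<str>`, so the `read_*` methods on
/// [`SessionContext`] accept either a single path or a list of paths.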
102pub trait DataFilePaths {
106 fn to_urls(self) -> Result<Vec<ListingTableUrl>>;
108}
109
110impl DataFilePaths for &str {
111 fn to_urls(self) -> Result<Vec<ListingTableUrl>> {
112 Ok(vec![ListingTableUrl::parse(self)?])
113 }
114}
115
116impl DataFilePaths for String {
117 fn to_urls(self) -> Result<Vec<ListingTableUrl>> {
118 Ok(vec![ListingTableUrl::parse(self)?])
119 }
120}
121
122impl DataFilePaths for &String {
123 fn to_urls(self) -> Result<Vec<ListingTableUrl>> {
124 Ok(vec![ListingTableUrl::parse(self)?])
125 }
126}
127
128impl<P> DataFilePaths for Vec<P>
129where
130 P: AsRef<str>,
131{
132 fn to_urls(self) -> Result<Vec<ListingTableUrl>> {
133 self.iter()
134 .map(ListingTableUrl::parse)
135 .collect::<Result<Vec<ListingTableUrl>>>()
136 }
137}
138
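/// Main interface for executing queries against DataFusion.
///
/// A `SessionContext` tracks the catalogs, tables, functions, and
/// configuration registered for a session and produces [`DataFrame`]s and
/// [`SessionState`] snapshots for planning and execution. Cloning a
/// `SessionContext` is cheap: clones share the same underlying state.
///
/// A minimal usage sketch (the CSV path and column name below are illustrative):
///
/// ```no_run
/// use datafusion::prelude::*;
/// # use datafusion::error::Result;
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let ctx = SessionContext::new();
/// ctx.register_csv("example", "tests/data/example.csv", CsvReadOptions::new())
///     .await?;
/// let df = ctx.sql("SELECT MIN(c1) FROM example").await?;
/// df.show().await?;
/// # Ok(())
/// # }
/// ```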
139#[derive(Clone)]
275pub struct SessionContext {
276 session_id: String,
278 session_start_time: DateTime<Utc>,
280 state: Arc<RwLock<SessionState>>,
282}
283
284impl Default for SessionContext {
285 fn default() -> Self {
286 Self::new()
287 }
288}
289
290impl SessionContext {
291 pub fn new() -> Self {
293 Self::new_with_config(SessionConfig::new())
294 }
295
296 pub async fn refresh_catalogs(&self) -> Result<()> {
298 let cat_names = self.catalog_names().clone();
299 for cat_name in cat_names.iter() {
300 let cat = self.catalog(cat_name.as_str()).ok_or_else(|| {
301 DataFusionError::Internal("Catalog not found!".to_string())
302 })?;
303 for schema_name in cat.schema_names() {
304 let schema = cat.schema(schema_name.as_str()).ok_or_else(|| {
305 DataFusionError::Internal("Schema not found!".to_string())
306 })?;
307 let lister = schema.as_any().downcast_ref::<ListingSchemaProvider>();
308 if let Some(lister) = lister {
309 lister.refresh(&self.state()).await?;
310 }
311 }
312 }
313 Ok(())
314 }
315
316 pub fn new_with_config(config: SessionConfig) -> Self {
322 let runtime = Arc::new(RuntimeEnv::default());
323 Self::new_with_config_rt(config, runtime)
324 }
325
326 pub fn new_with_config_rt(config: SessionConfig, runtime: Arc<RuntimeEnv>) -> Self {
340 let state = SessionStateBuilder::new()
341 .with_config(config)
342 .with_runtime_env(runtime)
343 .with_default_features()
344 .build();
345 Self::new_with_state(state)
346 }
347
348 pub fn new_with_state(state: SessionState) -> Self {
350 Self {
351 session_id: state.session_id().to_string(),
352 session_start_time: Utc::now(),
353 state: Arc::new(RwLock::new(state)),
354 }
355 }
356
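    /// Enables querying files directly by their URL or path from SQL, e.g.
    /// `SELECT * FROM 'data/example.csv'`, by wrapping the current catalog
    /// list in a [`DynamicFileCatalog`]. Consumes `self` and returns a new
    /// context that keeps the same session id.
    ///
    /// A small sketch (the file path is illustrative):
    ///
    /// ```no_run
    /// # use datafusion::prelude::*;
    /// # use datafusion::error::Result;
    /// # #[tokio::main]
    /// # async fn main() -> Result<()> {
    /// let ctx = SessionContext::new().enable_url_table();
    /// let df = ctx.sql("SELECT * FROM 'data/example.csv'").await?;
    /// # Ok(())
    /// # }
    /// ```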
357 pub fn enable_url_table(self) -> Self {
397 let current_catalog_list = Arc::clone(self.state.read().catalog_list());
398 let factory = Arc::new(DynamicListTableFactory::new(SessionStore::new()));
399 let catalog_list = Arc::new(DynamicFileCatalog::new(
400 current_catalog_list,
401 Arc::clone(&factory) as Arc<dyn UrlTableFactory>,
402 ));
403
404 let session_id = self.session_id.clone();
405 let ctx: SessionContext = self
406 .into_state_builder()
407 .with_session_id(session_id)
408 .with_catalog_list(catalog_list)
409 .build()
410 .into();
411 factory.session_store().with_state(ctx.state_weak_ref());
413 ctx
414 }
415
416 pub fn into_state_builder(self) -> SessionStateBuilder {
441 let SessionContext {
442 session_id: _,
443 session_start_time: _,
444 state,
445 } = self;
446 let state = match Arc::try_unwrap(state) {
447 Ok(rwlock) => rwlock.into_inner(),
448 Err(state) => state.read().clone(),
449 };
450 SessionStateBuilder::from(state)
451 }
452
453 pub fn session_start_time(&self) -> DateTime<Utc> {
455 self.session_start_time
456 }
457
458 pub fn with_function_factory(
460 self,
461 function_factory: Arc<dyn FunctionFactory>,
462 ) -> Self {
463 self.state.write().set_function_factory(function_factory);
464 self
465 }
466
467 pub fn add_optimizer_rule(
471 &self,
472 optimizer_rule: Arc<dyn OptimizerRule + Send + Sync>,
473 ) {
474 self.state.write().append_optimizer_rule(optimizer_rule);
475 }
476
477 pub fn add_analyzer_rule(&self, analyzer_rule: Arc<dyn AnalyzerRule + Send + Sync>) {
481 self.state.write().add_analyzer_rule(analyzer_rule);
482 }
483
484 pub fn register_object_store(
500 &self,
501 url: &Url,
502 object_store: Arc<dyn ObjectStore>,
503 ) -> Option<Arc<dyn ObjectStore>> {
504 self.runtime_env().register_object_store(url, object_store)
505 }
506
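    /// Registers a single [`RecordBatch`] as the table `table_name`, backed by
    /// an in-memory [`MemTable`]. Returns the provider previously registered
    /// under that name, if any.
    ///
    /// ```
    /// # use datafusion::prelude::*;
    /// # use datafusion::error::Result;
    /// # use datafusion::arrow::array::{ArrayRef, Int32Array};
    /// # use datafusion::arrow::record_batch::RecordBatch;
    /// # use std::sync::Arc;
    /// # #[tokio::main]
    /// # async fn main() -> Result<()> {
    /// // "t" and column "a" are illustrative names
    /// let batch = RecordBatch::try_from_iter(vec![(
    ///     "a",
    ///     Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
    /// )])?;
    /// let ctx = SessionContext::new();
    /// ctx.register_batch("t", batch)?;
    /// let df = ctx.sql("SELECT a FROM t").await?;
    /// df.show().await?;
    /// # Ok(())
    /// # }
    /// ```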
507 pub fn register_batch(
509 &self,
510 table_name: &str,
511 batch: RecordBatch,
512 ) -> Result<Option<Arc<dyn TableProvider>>> {
513 let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
514 self.register_table(
515 TableReference::Bare {
516 table: table_name.into(),
517 },
518 Arc::new(table),
519 )
520 }
521
522 pub fn runtime_env(&self) -> Arc<RuntimeEnv> {
524 Arc::clone(self.state.read().runtime_env())
525 }
526
527 pub fn session_id(&self) -> String {
529 self.session_id.clone()
530 }
531
532 pub fn table_factory(
535 &self,
536 file_type: &str,
537 ) -> Option<Arc<dyn TableProviderFactory>> {
538 self.state.read().table_factories().get(file_type).cloned()
539 }
540
541 pub fn enable_ident_normalization(&self) -> bool {
543 self.state
544 .read()
545 .config()
546 .options()
547 .sql_parser
548 .enable_ident_normalization
549 }
550
551 pub fn copied_config(&self) -> SessionConfig {
553 self.state.read().config().clone()
554 }
555
556 pub fn copied_table_options(&self) -> TableOptions {
558 self.state.read().default_table_options()
559 }
560
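    /// Creates a [`DataFrame`] from a SQL query string.
    ///
    /// The statement is parsed and planned eagerly; for queries, execution is
    /// deferred until the returned [`DataFrame`] is collected. DDL such as
    /// `CREATE TABLE` takes effect immediately (see
    /// [`Self::execute_logical_plan`]). Use [`Self::sql_with_options`] to
    /// restrict which kinds of statements are accepted.
    ///
    /// ```
    /// # use datafusion::prelude::*;
    /// # use datafusion::error::Result;
    /// # #[tokio::main]
    /// # async fn main() -> Result<()> {
    /// let ctx = SessionContext::new();
    /// let df = ctx.sql("SELECT 1 + 1 AS sum").await?;
    /// df.show().await?;
    /// # Ok(())
    /// # }
    /// ```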
561 pub async fn sql(&self, sql: &str) -> Result<DataFrame> {
589 self.sql_with_options(sql, SQLOptions::new()).await
590 }
591
592 pub async fn sql_with_options(
619 &self,
620 sql: &str,
621 options: SQLOptions,
622 ) -> Result<DataFrame> {
623 let plan = self.state().create_logical_plan(sql).await?;
624 options.verify_plan(&plan)?;
625
626 self.execute_logical_plan(plan).await
627 }
628
629 pub fn parse_sql_expr(&self, sql: &str, df_schema: &DFSchema) -> Result<Expr> {
652 self.state.read().create_logical_expr(sql, df_schema)
653 }
654
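    /// Executes a [`LogicalPlan`] built elsewhere, for example by
    /// [`Self::sql`] or [`LogicalPlanBuilder`].
    ///
    /// DDL (`CREATE`/`DROP` ...), `SET`, and `DEALLOCATE` statements take
    /// effect immediately and return an empty [`DataFrame`]; `EXECUTE` returns
    /// a [`DataFrame`] for the prepared statement with its parameters bound;
    /// any other plan is wrapped in a [`DataFrame`] without being executed.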
655 pub async fn execute_logical_plan(&self, plan: LogicalPlan) -> Result<DataFrame> {
663 match plan {
664 LogicalPlan::Ddl(ddl) => {
665 match ddl {
669 DdlStatement::CreateExternalTable(cmd) => {
670 (Box::pin(async move { self.create_external_table(&cmd).await })
671 as std::pin::Pin<Box<dyn futures::Future<Output = _> + Send>>)
672 .await
673 }
674 DdlStatement::CreateMemoryTable(cmd) => {
675 Box::pin(self.create_memory_table(cmd)).await
676 }
677 DdlStatement::CreateView(cmd) => {
678 Box::pin(self.create_view(cmd)).await
679 }
680 DdlStatement::CreateCatalogSchema(cmd) => {
681 Box::pin(self.create_catalog_schema(cmd)).await
682 }
683 DdlStatement::CreateCatalog(cmd) => {
684 Box::pin(self.create_catalog(cmd)).await
685 }
686 DdlStatement::DropTable(cmd) => Box::pin(self.drop_table(cmd)).await,
687 DdlStatement::DropView(cmd) => Box::pin(self.drop_view(cmd)).await,
688 DdlStatement::DropCatalogSchema(cmd) => {
689 Box::pin(self.drop_schema(cmd)).await
690 }
691 DdlStatement::CreateFunction(cmd) => {
692 Box::pin(self.create_function(cmd)).await
693 }
694 DdlStatement::DropFunction(cmd) => {
695 Box::pin(self.drop_function(cmd)).await
696 }
697 ddl => Ok(DataFrame::new(self.state(), LogicalPlan::Ddl(ddl))),
698 }
699 }
700 LogicalPlan::Statement(Statement::SetVariable(stmt)) => {
702 self.set_variable(stmt).await
703 }
704 LogicalPlan::Statement(Statement::Prepare(Prepare {
705 name,
706 input,
707 data_types,
708 })) => {
709 if !data_types.is_empty() {
711 let param_names = input.get_parameter_names()?;
712 if param_names.len() != data_types.len() {
713 return plan_err!(
714 "Prepare specifies {} data types but query has {} parameters",
715 data_types.len(),
716 param_names.len()
717 );
718 }
719 }
720 self.state.write().store_prepared(name, data_types, input)?;
726 self.return_empty_dataframe()
727 }
728 LogicalPlan::Statement(Statement::Execute(execute)) => {
729 self.execute_prepared(execute)
730 }
731 LogicalPlan::Statement(Statement::Deallocate(deallocate)) => {
732 self.state
733 .write()
734 .remove_prepared(deallocate.name.as_str())?;
735 self.return_empty_dataframe()
736 }
737 plan => Ok(DataFrame::new(self.state(), plan)),
738 }
739 }
740
741 pub fn create_physical_expr(
769 &self,
770 expr: Expr,
771 df_schema: &DFSchema,
772 ) -> Result<Arc<dyn PhysicalExpr>> {
773 self.state.read().create_physical_expr(expr, df_schema)
774 }
775
776 fn return_empty_dataframe(&self) -> Result<DataFrame> {
778 let plan = LogicalPlanBuilder::empty(false).build()?;
779 Ok(DataFrame::new(self.state(), plan))
780 }
781
782 async fn create_external_table(
783 &self,
784 cmd: &CreateExternalTable,
785 ) -> Result<DataFrame> {
786 let exist = self.table_exist(cmd.name.clone())?;
787
788 if cmd.temporary {
789 return not_impl_err!("Temporary tables not supported");
790 }
791
792 if exist {
793 match cmd.if_not_exists {
794 true => return self.return_empty_dataframe(),
795 false => {
796 return exec_err!("Table '{}' already exists", cmd.name);
797 }
798 }
799 }
800
801 let table_provider: Arc<dyn TableProvider> =
802 self.create_custom_table(cmd).await?;
803 self.register_table(cmd.name.clone(), table_provider)?;
804 self.return_empty_dataframe()
805 }
806
807 async fn create_memory_table(&self, cmd: CreateMemoryTable) -> Result<DataFrame> {
808 let CreateMemoryTable {
809 name,
810 input,
811 if_not_exists,
812 or_replace,
813 constraints,
814 column_defaults,
815 temporary,
816 } = cmd;
817
818 let input = Arc::unwrap_or_clone(input);
819 let input = self.state().optimize(&input)?;
820
821 if temporary {
822 return not_impl_err!("Temporary tables not supported");
823 }
824
825 let table = self.table(name.clone()).await;
826 match (if_not_exists, or_replace, table) {
827 (true, false, Ok(_)) => self.return_empty_dataframe(),
828 (false, true, Ok(_)) => {
829 self.deregister_table(name.clone())?;
830 let schema = Arc::new(input.schema().as_ref().into());
831 let physical = DataFrame::new(self.state(), input);
832
833 let batches: Vec<_> = physical.collect_partitioned().await?;
834 let table = Arc::new(
835 MemTable::try_new(schema, batches)?
837 .with_constraints(constraints)
838 .with_column_defaults(column_defaults.into_iter().collect()),
839 );
840
841 self.register_table(name.clone(), table)?;
842 self.return_empty_dataframe()
843 }
844 (true, true, Ok(_)) => {
845 exec_err!("'IF NOT EXISTS' cannot coexist with 'REPLACE'")
846 }
847 (_, _, Err(_)) => {
848 let df_schema = input.schema();
849 let schema = Arc::new(df_schema.as_ref().into());
850 let physical = DataFrame::new(self.state(), input);
851
852 let batches: Vec<_> = physical.collect_partitioned().await?;
853 let table = Arc::new(
854 MemTable::try_new(schema, batches)?
856 .with_constraints(constraints)
857 .with_column_defaults(column_defaults.into_iter().collect()),
858 );
859
860 self.register_table(name, table)?;
861 self.return_empty_dataframe()
862 }
863 (false, false, Ok(_)) => exec_err!("Table '{name}' already exists"),
864 }
865 }
866
867 fn apply_type_coercion(logical_plan: LogicalPlan) -> Result<LogicalPlan> {
869 let options = ConfigOptions::default();
870 Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())]).execute_and_check(
871 logical_plan,
872 &options,
873 |_, _| {},
874 )
875 }
876
877 async fn create_view(&self, cmd: CreateView) -> Result<DataFrame> {
878 let CreateView {
879 name,
880 input,
881 or_replace,
882 definition,
883 temporary,
884 } = cmd;
885
886 let view = self.table(name.clone()).await;
887
888 if temporary {
889 return not_impl_err!("Temporary views not supported");
890 }
891
892 match (or_replace, view) {
893 (true, Ok(_)) => {
894 self.deregister_table(name.clone())?;
895 let input = Self::apply_type_coercion(input.as_ref().clone())?;
896 let table = Arc::new(ViewTable::new(input, definition));
897 self.register_table(name, table)?;
898 self.return_empty_dataframe()
899 }
900 (_, Err(_)) => {
901 let input = Self::apply_type_coercion(input.as_ref().clone())?;
902 let table = Arc::new(ViewTable::new(input, definition));
903 self.register_table(name, table)?;
904 self.return_empty_dataframe()
905 }
906 (false, Ok(_)) => exec_err!("Table '{name}' already exists"),
907 }
908 }
909
910 async fn create_catalog_schema(&self, cmd: CreateCatalogSchema) -> Result<DataFrame> {
911 let CreateCatalogSchema {
912 schema_name,
913 if_not_exists,
914 ..
915 } = cmd;
916
917 let tokens: Vec<&str> = schema_name.split('.').collect();
920 let (catalog, schema_name) = match tokens.len() {
921 1 => {
922 let state = self.state.read();
923 let name = &state.config().options().catalog.default_catalog;
924 let catalog = state.catalog_list().catalog(name).ok_or_else(|| {
925 DataFusionError::Execution(format!(
926 "Missing default catalog '{name}'"
927 ))
928 })?;
929 (catalog, tokens[0])
930 }
931 2 => {
932 let name = &tokens[0];
933 let catalog = self.catalog(name).ok_or_else(|| {
934 DataFusionError::Execution(format!("Missing catalog '{name}'"))
935 })?;
936 (catalog, tokens[1])
937 }
938 _ => return exec_err!("Unable to parse catalog from {schema_name}"),
939 };
940 let schema = catalog.schema(schema_name);
941
942 match (if_not_exists, schema) {
943 (true, Some(_)) => self.return_empty_dataframe(),
944 (true, None) | (false, None) => {
945 let schema = Arc::new(MemorySchemaProvider::new());
946 catalog.register_schema(schema_name, schema)?;
947 self.return_empty_dataframe()
948 }
949 (false, Some(_)) => exec_err!("Schema '{schema_name}' already exists"),
950 }
951 }
952
953 async fn create_catalog(&self, cmd: CreateCatalog) -> Result<DataFrame> {
954 let CreateCatalog {
955 catalog_name,
956 if_not_exists,
957 ..
958 } = cmd;
959 let catalog = self.catalog(catalog_name.as_str());
960
961 match (if_not_exists, catalog) {
962 (true, Some(_)) => self.return_empty_dataframe(),
963 (true, None) | (false, None) => {
964 let new_catalog = Arc::new(MemoryCatalogProvider::new());
965 self.state
966 .write()
967 .catalog_list()
968 .register_catalog(catalog_name, new_catalog);
969 self.return_empty_dataframe()
970 }
971 (false, Some(_)) => exec_err!("Catalog '{catalog_name}' already exists"),
972 }
973 }
974
975 async fn drop_table(&self, cmd: DropTable) -> Result<DataFrame> {
976 let DropTable {
977 name, if_exists, ..
978 } = cmd;
979 let result = self
980 .find_and_deregister(name.clone(), TableType::Base)
981 .await;
982 match (result, if_exists) {
983 (Ok(true), _) => self.return_empty_dataframe(),
984 (_, true) => self.return_empty_dataframe(),
985 (_, _) => exec_err!("Table '{name}' doesn't exist."),
986 }
987 }
988
989 async fn drop_view(&self, cmd: DropView) -> Result<DataFrame> {
990 let DropView {
991 name, if_exists, ..
992 } = cmd;
993 let result = self
994 .find_and_deregister(name.clone(), TableType::View)
995 .await;
996 match (result, if_exists) {
997 (Ok(true), _) => self.return_empty_dataframe(),
998 (_, true) => self.return_empty_dataframe(),
999 (_, _) => exec_err!("View '{name}' doesn't exist."),
1000 }
1001 }
1002
1003 async fn drop_schema(&self, cmd: DropCatalogSchema) -> Result<DataFrame> {
1004 let DropCatalogSchema {
1005 name,
1006 if_exists: allow_missing,
1007 cascade,
1008 schema: _,
1009 } = cmd;
1010 let catalog = {
1011 let state = self.state.read();
1012 let catalog_name = match &name {
1013 SchemaReference::Full { catalog, .. } => catalog.to_string(),
1014 SchemaReference::Bare { .. } => {
1015 state.config_options().catalog.default_catalog.to_string()
1016 }
1017 };
1018 if let Some(catalog) = state.catalog_list().catalog(&catalog_name) {
1019 catalog
1020 } else if allow_missing {
1021 return self.return_empty_dataframe();
1022 } else {
1023 return self.schema_doesnt_exist_err(name);
1024 }
1025 };
1026 let dereg = catalog.deregister_schema(name.schema_name(), cascade)?;
1027 match (dereg, allow_missing) {
1028 (None, true) => self.return_empty_dataframe(),
1029 (None, false) => self.schema_doesnt_exist_err(name),
1030 (Some(_), _) => self.return_empty_dataframe(),
1031 }
1032 }
1033
1034 fn schema_doesnt_exist_err(&self, schemaref: SchemaReference) -> Result<DataFrame> {
1035 exec_err!("Schema '{schemaref}' doesn't exist.")
1036 }
1037
1038 async fn set_variable(&self, stmt: SetVariable) -> Result<DataFrame> {
1039 let SetVariable {
1040 variable, value, ..
1041 } = stmt;
1042
1043 if variable.starts_with("datafusion.runtime.") {
1045 self.set_runtime_variable(&variable, &value)?;
1046 } else {
1047 let mut state = self.state.write();
1048 state.config_mut().options_mut().set(&variable, &value)?;
1049 drop(state);
1050 }
1051
1052 self.return_empty_dataframe()
1053 }
1054
1055 fn set_runtime_variable(&self, variable: &str, value: &str) -> Result<()> {
1056 let key = variable.strip_prefix("datafusion.runtime.").unwrap();
1057
1058 match key {
1059 "memory_limit" => {
1060 let memory_limit = Self::parse_memory_limit(value)?;
1061
1062 let mut state = self.state.write();
1063 let mut builder =
1064 RuntimeEnvBuilder::from_runtime_env(state.runtime_env());
1065 builder = builder.with_memory_limit(memory_limit, 1.0);
1066 *state = SessionStateBuilder::from(state.clone())
1067 .with_runtime_env(Arc::new(builder.build()?))
1068 .build();
1069 }
1070 _ => {
1071 return Err(DataFusionError::Plan(format!(
1072 "Unknown runtime configuration: {variable}"
1073 )))
1074 }
1075 }
1076 Ok(())
1077 }
1078
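    /// Parses a human-readable memory limit such as `"1.5G"`, `"512M"`, or
    /// `"64K"` into a number of bytes. Only the `K`, `M`, and `G` suffixes
    /// (powers of 1024) are accepted.
    ///
    /// ```
    /// # use datafusion::prelude::SessionContext;
    /// assert_eq!(SessionContext::parse_memory_limit("1M").unwrap(), 1024 * 1024);
    /// assert_eq!(SessionContext::parse_memory_limit("2K").unwrap(), 2 * 1024);
    /// assert!(SessionContext::parse_memory_limit("100").is_err());
    /// ```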
1079 pub fn parse_memory_limit(limit: &str) -> Result<usize> {
1090 let (number, unit) = limit.split_at(limit.len() - 1);
1091 let number: f64 = number.parse().map_err(|_| {
1092 DataFusionError::Plan(format!(
1093 "Failed to parse number from memory limit '{limit}'"
1094 ))
1095 })?;
1096
1097 match unit {
1098 "K" => Ok((number * 1024.0) as usize),
1099 "M" => Ok((number * 1024.0 * 1024.0) as usize),
1100 "G" => Ok((number * 1024.0 * 1024.0 * 1024.0) as usize),
1101 _ => Err(DataFusionError::Plan(format!(
1102 "Unsupported unit '{unit}' in memory limit '{limit}'"
1103 ))),
1104 }
1105 }
1106
1107 async fn create_custom_table(
1108 &self,
1109 cmd: &CreateExternalTable,
1110 ) -> Result<Arc<dyn TableProvider>> {
1111 let state = self.state.read().clone();
1112 let file_type = cmd.file_type.to_uppercase();
1113 let factory =
1114 state
1115 .table_factories()
1116 .get(file_type.as_str())
1117 .ok_or_else(|| {
1118 DataFusionError::Execution(format!(
1119 "Unable to find factory for {}",
1120 cmd.file_type
1121 ))
1122 })?;
1123 let table = (*factory).create(&state, cmd).await?;
1124 Ok(table)
1125 }
1126
1127 async fn find_and_deregister(
1128 &self,
1129 table_ref: impl Into<TableReference>,
1130 table_type: TableType,
1131 ) -> Result<bool> {
1132 let table_ref = table_ref.into();
1133 let table = table_ref.table().to_owned();
1134 let maybe_schema = {
1135 let state = self.state.read();
1136 let resolved = state.resolve_table_ref(table_ref);
1137 state
1138 .catalog_list()
1139 .catalog(&resolved.catalog)
1140 .and_then(|c| c.schema(&resolved.schema))
1141 };
1142
1143 if let Some(schema) = maybe_schema {
1144 if let Some(table_provider) = schema.table(&table).await? {
1145 if table_provider.table_type() == table_type {
1146 schema.deregister_table(&table)?;
1147 return Ok(true);
1148 }
1149 }
1150 }
1151
1152 Ok(false)
1153 }
1154
1155 async fn create_function(&self, stmt: CreateFunction) -> Result<DataFrame> {
1156 let function = {
1157 let state = self.state.read().clone();
1158 let function_factory = state.function_factory();
1159
1160 match function_factory {
1161 Some(f) => f.create(&state, stmt).await?,
1162 _ => Err(DataFusionError::Configuration(
1163 "Function factory has not been configured".into(),
1164 ))?,
1165 }
1166 };
1167
1168 match function {
1169 RegisterFunction::Scalar(f) => {
1170 self.state.write().register_udf(f)?;
1171 }
1172 RegisterFunction::Aggregate(f) => {
1173 self.state.write().register_udaf(f)?;
1174 }
1175 RegisterFunction::Window(f) => {
1176 self.state.write().register_udwf(f)?;
1177 }
1178 RegisterFunction::Table(name, f) => self.register_udtf(&name, f),
1179 };
1180
1181 self.return_empty_dataframe()
1182 }
1183
1184 async fn drop_function(&self, stmt: DropFunction) -> Result<DataFrame> {
1185 let mut dropped = false;
1188 dropped |= self.state.write().deregister_udf(&stmt.name)?.is_some();
1189 dropped |= self.state.write().deregister_udaf(&stmt.name)?.is_some();
1190 dropped |= self.state.write().deregister_udwf(&stmt.name)?.is_some();
1191 dropped |= self.state.write().deregister_udtf(&stmt.name)?.is_some();
1192
1193 if !stmt.if_exists && !dropped {
1199 exec_err!("Function does not exist")
1200 } else {
1201 self.return_empty_dataframe()
1202 }
1203 }
1204
1205 fn execute_prepared(&self, execute: Execute) -> Result<DataFrame> {
1206 let Execute {
1207 name, parameters, ..
1208 } = execute;
1209 let prepared = self.state.read().get_prepared(&name).ok_or_else(|| {
1210 exec_datafusion_err!("Prepared statement '{}' does not exist", name)
1211 })?;
1212
1213 let mut params: Vec<ScalarValue> = parameters
1215 .into_iter()
1216 .map(|e| match e {
1217 Expr::Literal(scalar, _) => Ok(scalar),
1218 _ => not_impl_err!("Unsupported parameter type: {}", e),
1219 })
1220 .collect::<Result<_>>()?;
1221
1222 if !prepared.data_types.is_empty() {
1224 if params.len() != prepared.data_types.len() {
1225 return exec_err!(
1226 "Prepared statement '{}' expects {} parameters, but {} provided",
1227 name,
1228 prepared.data_types.len(),
1229 params.len()
1230 );
1231 }
1232 params = params
1233 .into_iter()
1234 .zip(prepared.data_types.iter())
1235 .map(|(e, dt)| e.cast_to(dt))
1236 .collect::<Result<_>>()?;
1237 }
1238
1239 let params = ParamValues::List(params);
1240 let plan = prepared
1241 .plan
1242 .as_ref()
1243 .clone()
1244            .replace_params_with_values(&params)?;
1245 Ok(DataFrame::new(self.state(), plan))
1246 }
1247
1248 pub fn register_variable(
1250 &self,
1251 variable_type: VarType,
1252 provider: Arc<dyn VarProvider + Send + Sync>,
1253 ) {
1254 self.state
1255 .write()
1256 .execution_props_mut()
1257 .add_var_provider(variable_type, provider);
1258 }
1259
1260 pub fn register_udtf(&self, name: &str, fun: Arc<dyn TableFunctionImpl>) {
1262 self.state.write().register_udtf(name, fun)
1263 }
1264
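    /// Registers a scalar user-defined function (UDF) so it can be called from
    /// SQL and the DataFrame API. Any previously registered UDF with the same
    /// name is replaced; errors from the underlying registry are discarded.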
1265 pub fn register_udf(&self, f: ScalarUDF) {
1275 let mut state = self.state.write();
1276 state.register_udf(Arc::new(f)).ok();
1277 }
1278
1279 pub fn register_udaf(&self, f: AggregateUDF) {
1287 self.state.write().register_udaf(Arc::new(f)).ok();
1288 }
1289
1290 pub fn register_udwf(&self, f: WindowUDF) {
1298 self.state.write().register_udwf(Arc::new(f)).ok();
1299 }
1300
1301 pub fn deregister_udf(&self, name: &str) {
1303 self.state.write().deregister_udf(name).ok();
1304 }
1305
1306 pub fn deregister_udaf(&self, name: &str) {
1308 self.state.write().deregister_udaf(name).ok();
1309 }
1310
1311 pub fn deregister_udwf(&self, name: &str) {
1313 self.state.write().deregister_udwf(name).ok();
1314 }
1315
1316 pub fn deregister_udtf(&self, name: &str) {
1318 self.state.write().deregister_udtf(name).ok();
1319 }
1320
1321 async fn _read_type<'a, P: DataFilePaths>(
1326 &self,
1327 table_paths: P,
1328 options: impl ReadOptions<'a>,
1329 ) -> Result<DataFrame> {
1330 let table_paths = table_paths.to_urls()?;
1331 let session_config = self.copied_config();
1332 let listing_options =
1333 options.to_listing_options(&session_config, self.copied_table_options());
1334
1335 let option_extension = listing_options.file_extension.clone();
1336
1337 if table_paths.is_empty() {
1338 return exec_err!("No table paths were provided");
1339 }
1340
1341 for path in &table_paths {
1343 let file_path = path.as_str();
1344 if !file_path.ends_with(option_extension.clone().as_str())
1345 && !path.is_collection()
1346 {
1347 return exec_err!(
1348 "File path '{file_path}' does not match the expected extension '{option_extension}'"
1349 );
1350 }
1351 }
1352
1353 let resolved_schema = options
1354 .get_resolved_schema(&session_config, self.state(), table_paths[0].clone())
1355 .await?;
1356 let config = ListingTableConfig::new_with_multi_paths(table_paths)
1357 .with_listing_options(listing_options)
1358 .with_schema(resolved_schema);
1359 let provider = ListingTable::try_new(config)?;
1360 self.read_table(Arc::new(provider))
1361 }
1362
1363 pub async fn read_arrow<P: DataFilePaths>(
1370 &self,
1371 table_paths: P,
1372 options: ArrowReadOptions<'_>,
1373 ) -> Result<DataFrame> {
1374 self._read_type(table_paths, options).await
1375 }
1376
1377 pub fn read_empty(&self) -> Result<DataFrame> {
1379 Ok(DataFrame::new(
1380 self.state(),
1381 LogicalPlanBuilder::empty(true).build()?,
1382 ))
1383 }
1384
1385 pub fn read_table(&self, provider: Arc<dyn TableProvider>) -> Result<DataFrame> {
1388 Ok(DataFrame::new(
1389 self.state(),
1390 LogicalPlanBuilder::scan(UNNAMED_TABLE, provider_as_source(provider), None)?
1391 .build()?,
1392 ))
1393 }
1394
1395 pub fn read_batch(&self, batch: RecordBatch) -> Result<DataFrame> {
1397 let provider = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
1398 Ok(DataFrame::new(
1399 self.state(),
1400 LogicalPlanBuilder::scan(
1401 UNNAMED_TABLE,
1402 provider_as_source(Arc::new(provider)),
1403 None,
1404 )?
1405 .build()?,
1406 ))
1407 }
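
    /// Creates a [`DataFrame`] from an iterator of [`RecordBatch`]es, backed
    /// by a single in-memory [`MemTable`] partition. The schema is taken from
    /// the first batch; an empty iterator produces an empty schema.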
1408 pub fn read_batches(
1410 &self,
1411 batches: impl IntoIterator<Item = RecordBatch>,
1412 ) -> Result<DataFrame> {
1413 let mut batches = batches.into_iter().peekable();
1415 let schema = if let Some(batch) = batches.peek() {
1416 batch.schema()
1417 } else {
1418 Arc::new(Schema::empty())
1419 };
1420 let provider = MemTable::try_new(schema, vec![batches.collect()])?;
1421 Ok(DataFrame::new(
1422 self.state(),
1423 LogicalPlanBuilder::scan(
1424 UNNAMED_TABLE,
1425 provider_as_source(Arc::new(provider)),
1426 None,
1427 )?
1428 .build()?,
1429 ))
1430 }
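
    /// Registers a [`ListingTable`] that scans the file or directory at
    /// `table_path` using the given [`ListingOptions`].
    ///
    /// If `provided_schema` is `None`, the schema is inferred from the files;
    /// `sql_definition` optionally records the SQL statement that defined the
    /// table.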
1431 pub async fn register_listing_table(
1439 &self,
1440 table_ref: impl Into<TableReference>,
1441 table_path: impl AsRef<str>,
1442 options: ListingOptions,
1443 provided_schema: Option<SchemaRef>,
1444 sql_definition: Option<String>,
1445 ) -> Result<()> {
1446 let table_path = ListingTableUrl::parse(table_path)?;
1447 let resolved_schema = match provided_schema {
1448 Some(s) => s,
1449 None => options.infer_schema(&self.state(), &table_path).await?,
1450 };
1451 let config = ListingTableConfig::new(table_path)
1452 .with_listing_options(options)
1453 .with_schema(resolved_schema);
1454 let table = ListingTable::try_new(config)?.with_definition(sql_definition);
1455 self.register_table(table_ref, Arc::new(table))?;
1456 Ok(())
1457 }
1458
1459 fn register_type_check<P: DataFilePaths>(
1460 &self,
1461 table_paths: P,
1462 extension: impl AsRef<str>,
1463 ) -> Result<()> {
1464 let table_paths = table_paths.to_urls()?;
1465 if table_paths.is_empty() {
1466 return exec_err!("No table paths were provided");
1467 }
1468
1469 let extension = extension.as_ref();
1471 for path in &table_paths {
1472 let file_path = path.as_str();
1473 if !file_path.ends_with(extension) && !path.is_collection() {
1474 return exec_err!(
1475 "File path '{file_path}' does not match the expected extension '{extension}'"
1476 );
1477 }
1478 }
1479 Ok(())
1480 }
1481
1482 pub async fn register_arrow(
1485 &self,
1486 name: &str,
1487 table_path: &str,
1488 options: ArrowReadOptions<'_>,
1489 ) -> Result<()> {
1490 let listing_options = options
1491 .to_listing_options(&self.copied_config(), self.copied_table_options());
1492
1493 self.register_listing_table(
1494 name,
1495 table_path,
1496 listing_options,
1497 options.schema.map(|s| Arc::new(s.to_owned())),
1498 None,
1499 )
1500 .await?;
1501 Ok(())
1502 }
1503
1504 pub fn register_catalog(
1511 &self,
1512 name: impl Into<String>,
1513 catalog: Arc<dyn CatalogProvider>,
1514 ) -> Option<Arc<dyn CatalogProvider>> {
1515 let name = name.into();
1516 self.state
1517 .read()
1518 .catalog_list()
1519 .register_catalog(name, catalog)
1520 }
1521
1522 pub fn catalog_names(&self) -> Vec<String> {
1524 self.state.read().catalog_list().catalog_names()
1525 }
1526
1527 pub fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
1529 self.state.read().catalog_list().catalog(name)
1530 }
1531
1532 pub fn register_table(
1538 &self,
1539 table_ref: impl Into<TableReference>,
1540 provider: Arc<dyn TableProvider>,
1541 ) -> Result<Option<Arc<dyn TableProvider>>> {
1542 let table_ref: TableReference = table_ref.into();
1543 let table = table_ref.table().to_owned();
1544 self.state
1545 .read()
1546 .schema_for_ref(table_ref)?
1547 .register_table(table, provider)
1548 }
1549
1550 pub fn deregister_table(
1554 &self,
1555 table_ref: impl Into<TableReference>,
1556 ) -> Result<Option<Arc<dyn TableProvider>>> {
1557 let table_ref = table_ref.into();
1558 let table = table_ref.table().to_owned();
1559 self.state
1560 .read()
1561 .schema_for_ref(table_ref)?
1562 .deregister_table(&table)
1563 }
1564
1565 pub fn table_exist(&self, table_ref: impl Into<TableReference>) -> Result<bool> {
1567 let table_ref: TableReference = table_ref.into();
1568 let table = table_ref.table();
1569 let table_ref = table_ref.clone();
1570 Ok(self
1571 .state
1572 .read()
1573 .schema_for_ref(table_ref)?
1574 .table_exist(table))
1575 }
1576
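    /// Returns a [`DataFrame`] that scans the table or view registered under
    /// `table_ref`, or an error if no such table exists.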
1577 pub async fn table(&self, table_ref: impl Into<TableReference>) -> Result<DataFrame> {
1585 let table_ref: TableReference = table_ref.into();
1586 let provider = self.table_provider(table_ref.clone()).await?;
1587 let plan = LogicalPlanBuilder::scan(
1588 table_ref,
1589 provider_as_source(Arc::clone(&provider)),
1590 None,
1591 )?
1592 .build()?;
1593 Ok(DataFrame::new(self.state(), plan))
1594 }
1595
1596 pub fn table_function(&self, name: &str) -> Result<Arc<TableFunction>> {
1602 self.state
1603 .read()
1604 .table_functions()
1605 .get(name)
1606 .cloned()
1607 .ok_or_else(|| plan_datafusion_err!("Table function '{name}' not found"))
1608 }
1609
1610 pub async fn table_provider(
1612 &self,
1613 table_ref: impl Into<TableReference>,
1614 ) -> Result<Arc<dyn TableProvider>> {
1615 let table_ref = table_ref.into();
1616 let table = table_ref.table().to_string();
1617 let schema = self.state.read().schema_for_ref(table_ref)?;
1618 match schema.table(&table).await? {
1619 Some(ref provider) => Ok(Arc::clone(provider)),
1620 _ => plan_err!("No table named '{table}'"),
1621 }
1622 }
1623
1624 pub fn task_ctx(&self) -> Arc<TaskContext> {
1626 Arc::new(TaskContext::from(self))
1627 }
1628
1629 pub fn state(&self) -> SessionState {
1642 let mut state = self.state.read().clone();
1643 state.execution_props_mut().start_execution();
1644 state
1645 }
1646
1647 pub fn state_ref(&self) -> Arc<RwLock<SessionState>> {
1649 Arc::clone(&self.state)
1650 }
1651
1652 pub fn state_weak_ref(&self) -> Weak<RwLock<SessionState>> {
1654 Arc::downgrade(&self.state)
1655 }
1656
1657 pub fn register_catalog_list(&self, catalog_list: Arc<dyn CatalogProviderList>) {
1659 self.state.write().register_catalog_list(catalog_list)
1660 }
1661
1662 pub fn register_table_options_extension<T: ConfigExtension>(&self, extension: T) {
1665 self.state
1666 .write()
1667 .register_table_options_extension(extension)
1668 }
1669}
1670
1671impl FunctionRegistry for SessionContext {
1672 fn udfs(&self) -> HashSet<String> {
1673 self.state.read().udfs()
1674 }
1675
1676 fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> {
1677 self.state.read().udf(name)
1678 }
1679
1680 fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
1681 self.state.read().udaf(name)
1682 }
1683
1684 fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
1685 self.state.read().udwf(name)
1686 }
1687
1688 fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
1689 self.state.write().register_udf(udf)
1690 }
1691
1692 fn register_udaf(
1693 &mut self,
1694 udaf: Arc<AggregateUDF>,
1695 ) -> Result<Option<Arc<AggregateUDF>>> {
1696 self.state.write().register_udaf(udaf)
1697 }
1698
1699 fn register_udwf(&mut self, udwf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
1700 self.state.write().register_udwf(udwf)
1701 }
1702
1703 fn register_function_rewrite(
1704 &mut self,
1705 rewrite: Arc<dyn FunctionRewrite + Send + Sync>,
1706 ) -> Result<()> {
1707 self.state.write().register_function_rewrite(rewrite)
1708 }
1709
1710 fn expr_planners(&self) -> Vec<Arc<dyn ExprPlanner>> {
1711 self.state.read().expr_planners().to_vec()
1712 }
1713
1714 fn register_expr_planner(
1715 &mut self,
1716 expr_planner: Arc<dyn ExprPlanner>,
1717 ) -> Result<()> {
1718 self.state.write().register_expr_planner(expr_planner)
1719 }
1720}
1721
1722impl From<&SessionContext> for TaskContext {
1724 fn from(session: &SessionContext) -> Self {
1725 TaskContext::from(&*session.state.read())
1726 }
1727}
1728
1729impl From<SessionState> for SessionContext {
1730 fn from(state: SessionState) -> Self {
1731 Self::new_with_state(state)
1732 }
1733}
1734
1735impl From<SessionContext> for SessionStateBuilder {
1736 fn from(session: SessionContext) -> Self {
1737 session.into_state_builder()
1738 }
1739}
1740
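/// Translates an optimized [`LogicalPlan`] into an [`ExecutionPlan`].
///
/// Implement this trait and install it with
/// [`SessionStateBuilder::with_query_planner`] to customize how physical
/// plans are produced for a session (see the `custom_query_planner` test
/// below for a sketch).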
1741#[async_trait]
1743pub trait QueryPlanner: Debug {
1744 async fn create_physical_plan(
1746 &self,
1747 logical_plan: &LogicalPlan,
1748 session_state: &SessionState,
1749 ) -> Result<Arc<dyn ExecutionPlan>>;
1750}
1751
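/// Turns a `CREATE FUNCTION` statement into a [`RegisterFunction`].
///
/// A session has no factory by default, so `CREATE FUNCTION` fails with a
/// configuration error until one is installed via
/// [`SessionContext::with_function_factory`].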
1752#[async_trait]
1756pub trait FunctionFactory: Debug + Sync + Send {
1757 async fn create(
1759 &self,
1760 state: &SessionState,
1761 statement: CreateFunction,
1762 ) -> Result<RegisterFunction>;
1763}
1764
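/// The kind of user-defined function produced by a [`FunctionFactory`].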
1765pub enum RegisterFunction {
1767 Scalar(Arc<ScalarUDF>),
1769 Aggregate(Arc<AggregateUDF>),
1771 Window(Arc<WindowUDF>),
1773 Table(String, Arc<dyn TableFunctionImpl>),
1775}
1776
1777#[derive(Debug)]
1780pub struct EmptySerializerRegistry;
1781
1782impl SerializerRegistry for EmptySerializerRegistry {
1783 fn serialize_logical_plan(
1784 &self,
1785 node: &dyn UserDefinedLogicalNode,
1786 ) -> Result<Vec<u8>> {
1787 not_impl_err!(
1788 "Serializing user defined logical plan node `{}` is not supported",
1789 node.name()
1790 )
1791 }
1792
1793 fn deserialize_logical_plan(
1794 &self,
1795 name: &str,
1796 _bytes: &[u8],
1797 ) -> Result<Arc<dyn UserDefinedLogicalNode>> {
1798 not_impl_err!(
1799 "Deserializing user defined logical plan node `{name}` is not supported"
1800 )
1801 }
1802}
1803
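/// Controls which categories of SQL statements are accepted by
/// [`SessionContext::sql_with_options`]. By default DDL, DML, and other
/// statements are all allowed.
///
/// ```
/// # use datafusion::prelude::*;
/// # use datafusion::execution::context::SQLOptions;
/// # #[tokio::main]
/// # async fn main() {
/// // Reject DDL so the SQL interface is effectively read-only
/// // ("t" is an illustrative table name).
/// let opts = SQLOptions::new().with_allow_ddl(false);
/// let ctx = SessionContext::new();
/// let result = ctx.sql_with_options("CREATE TABLE t (x INT)", opts).await;
/// assert!(result.is_err());
/// # }
/// ```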
1804#[derive(Clone, Debug, Copy)]
1808pub struct SQLOptions {
1809 allow_ddl: bool,
1811 allow_dml: bool,
1813 allow_statements: bool,
1815}
1816
1817impl Default for SQLOptions {
1818 fn default() -> Self {
1819 Self {
1820 allow_ddl: true,
1821 allow_dml: true,
1822 allow_statements: true,
1823 }
1824 }
1825}
1826
1827impl SQLOptions {
1828 pub fn new() -> Self {
1830 Default::default()
1831 }
1832
1833 pub fn with_allow_ddl(mut self, allow: bool) -> Self {
1835 self.allow_ddl = allow;
1836 self
1837 }
1838
1839 pub fn with_allow_dml(mut self, allow: bool) -> Self {
1841 self.allow_dml = allow;
1842 self
1843 }
1844
1845 pub fn with_allow_statements(mut self, allow: bool) -> Self {
1847 self.allow_statements = allow;
1848 self
1849 }
1850
1851 pub fn verify_plan(&self, plan: &LogicalPlan) -> Result<()> {
1854 plan.visit_with_subqueries(&mut BadPlanVisitor::new(self))?;
1855 Ok(())
1856 }
1857}
1858
1859struct BadPlanVisitor<'a> {
1860 options: &'a SQLOptions,
1861}

1862impl<'a> BadPlanVisitor<'a> {
1863 fn new(options: &'a SQLOptions) -> Self {
1864 Self { options }
1865 }
1866}
1867
1868impl<'n> TreeNodeVisitor<'n> for BadPlanVisitor<'_> {
1869 type Node = LogicalPlan;
1870
1871 fn f_down(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
1872 match node {
1873 LogicalPlan::Ddl(ddl) if !self.options.allow_ddl => {
1874 plan_err!("DDL not supported: {}", ddl.name())
1875 }
1876 LogicalPlan::Dml(dml) if !self.options.allow_dml => {
1877 plan_err!("DML not supported: {}", dml.op)
1878 }
1879 LogicalPlan::Copy(_) if !self.options.allow_dml => {
1880 plan_err!("DML not supported: COPY")
1881 }
1882 LogicalPlan::Statement(stmt) if !self.options.allow_statements => {
1883 plan_err!("Statement not supported: {}", stmt.name())
1884 }
1885 _ => Ok(TreeNodeRecursion::Continue),
1886 }
1887 }
1888}
1889
1890#[cfg(test)]
1891mod tests {
1892 use super::{super::options::CsvReadOptions, *};
1893 use crate::execution::memory_pool::MemoryConsumer;
1894 use crate::test;
1895 use crate::test_util::{plan_and_collect, populate_csv_partitions};
1896 use arrow::datatypes::{DataType, TimeUnit};
1897 use std::error::Error;
1898 use std::path::PathBuf;
1899
1900 use datafusion_common::test_util::batches_to_string;
1901 use datafusion_common_runtime::SpawnedTask;
1902 use insta::{allow_duplicates, assert_snapshot};
1903
1904 use crate::catalog::SchemaProvider;
1905 use crate::execution::session_state::SessionStateBuilder;
1906 use crate::physical_planner::PhysicalPlanner;
1907 use async_trait::async_trait;
1908 use datafusion_expr::planner::TypePlanner;
1909 use sqlparser::ast;
1910 use tempfile::TempDir;
1911
1912 #[tokio::test]
1913 async fn shared_memory_and_disk_manager() {
1914 let ctx1 = SessionContext::new();
1917
1918 let memory_pool = ctx1.runtime_env().memory_pool.clone();
1920
1921 let mut reservation = MemoryConsumer::new("test").register(&memory_pool);
1922 reservation.grow(100);
1923
1924 let disk_manager = ctx1.runtime_env().disk_manager.clone();
1925
1926 let ctx2 =
1927 SessionContext::new_with_config_rt(SessionConfig::new(), ctx1.runtime_env());
1928
1929 assert_eq!(ctx1.runtime_env().memory_pool.reserved(), 100);
1930 assert_eq!(ctx2.runtime_env().memory_pool.reserved(), 100);
1931
1932 drop(reservation);
1933
1934 assert_eq!(ctx1.runtime_env().memory_pool.reserved(), 0);
1935 assert_eq!(ctx2.runtime_env().memory_pool.reserved(), 0);
1936
1937 assert!(std::ptr::eq(
1938 Arc::as_ptr(&disk_manager),
1939 Arc::as_ptr(&ctx1.runtime_env().disk_manager)
1940 ));
1941 assert!(std::ptr::eq(
1942 Arc::as_ptr(&disk_manager),
1943 Arc::as_ptr(&ctx2.runtime_env().disk_manager)
1944 ));
1945 }
1946
1947 #[tokio::test]
1948 async fn create_variable_expr() -> Result<()> {
1949 let tmp_dir = TempDir::new()?;
1950 let partition_count = 4;
1951 let ctx = create_ctx(&tmp_dir, partition_count).await?;
1952
1953 let variable_provider = test::variable::SystemVar::new();
1954 ctx.register_variable(VarType::System, Arc::new(variable_provider));
1955 let variable_provider = test::variable::UserDefinedVar::new();
1956 ctx.register_variable(VarType::UserDefined, Arc::new(variable_provider));
1957
1958 let provider = test::create_table_dual();
1959 ctx.register_table("dual", provider)?;
1960
1961 let results =
1962 plan_and_collect(&ctx, "SELECT @@version, @name, @integer + 1 FROM dual")
1963 .await?;
1964
1965 assert_snapshot!(batches_to_string(&results), @r"
1966 +----------------------+------------------------+---------------------+
1967 | @@version | @name | @integer + Int64(1) |
1968 +----------------------+------------------------+---------------------+
1969 | system-var-@@version | user-defined-var-@name | 42 |
1970 +----------------------+------------------------+---------------------+
1971 ");
1972
1973 Ok(())
1974 }
1975
1976 #[tokio::test]
1977 async fn create_variable_err() -> Result<()> {
1978 let ctx = SessionContext::new();
1979
1980 let err = plan_and_collect(&ctx, "SElECT @= X3").await.unwrap_err();
1981 assert_eq!(
1982 err.strip_backtrace(),
1983 "Error during planning: variable [\"@=\"] has no type information"
1984 );
1985 Ok(())
1986 }
1987
1988 #[tokio::test]
1989 async fn register_deregister() -> Result<()> {
1990 let tmp_dir = TempDir::new()?;
1991 let partition_count = 4;
1992 let ctx = create_ctx(&tmp_dir, partition_count).await?;
1993
1994 let provider = test::create_table_dual();
1995 ctx.register_table("dual", provider)?;
1996
1997 assert!(ctx.deregister_table("dual")?.is_some());
1998 assert!(ctx.deregister_table("dual")?.is_none());
1999
2000 Ok(())
2001 }
2002
2003 #[tokio::test]
2004 async fn send_context_to_threads() -> Result<()> {
2005 let tmp_dir = TempDir::new()?;
2008 let partition_count = 4;
2009 let ctx = Arc::new(create_ctx(&tmp_dir, partition_count).await?);
2010
2011 let threads: Vec<_> = (0..2)
2012 .map(|_| ctx.clone())
2013 .map(|ctx| {
2014 SpawnedTask::spawn(async move {
2015 ctx.sql("SELECT c1, c2 FROM test WHERE c1 > 0 AND c1 < 3")
2017 .await
2018 })
2019 })
2020 .collect();
2021
2022 for handle in threads {
2023 handle.join().await.unwrap().unwrap();
2024 }
2025 Ok(())
2026 }
2027
2028 #[tokio::test]
2029 async fn with_listing_schema_provider() -> Result<()> {
2030 let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
2031 let path = path.join("tests/tpch-csv");
2032 let url = format!("file://{}", path.display());
2033
2034 let cfg = SessionConfig::new()
2035 .set_str("datafusion.catalog.location", url.as_str())
2036 .set_str("datafusion.catalog.format", "CSV")
2037 .set_str("datafusion.catalog.has_header", "true");
2038 let session_state = SessionStateBuilder::new()
2039 .with_config(cfg)
2040 .with_default_features()
2041 .build();
2042 let ctx = SessionContext::new_with_state(session_state);
2043 ctx.refresh_catalogs().await?;
2044
2045 let result =
2046 plan_and_collect(&ctx, "select c_name from default.customer limit 3;")
2047 .await?;
2048
2049 let actual = arrow::util::pretty::pretty_format_batches(&result)
2050 .unwrap()
2051 .to_string();
2052 assert_snapshot!(actual, @r"
2053 +--------------------+
2054 | c_name |
2055 +--------------------+
2056 | Customer#000000002 |
2057 | Customer#000000003 |
2058 | Customer#000000004 |
2059 +--------------------+
2060 ");
2061
2062 Ok(())
2063 }
2064
2065 #[tokio::test]
2066 async fn test_dynamic_file_query() -> Result<()> {
2067 let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
2068 let path = path.join("tests/tpch-csv/customer.csv");
2069 let url = format!("file://{}", path.display());
2070 let cfg = SessionConfig::new();
2071 let session_state = SessionStateBuilder::new()
2072 .with_default_features()
2073 .with_config(cfg)
2074 .build();
2075 let ctx = SessionContext::new_with_state(session_state).enable_url_table();
2076 let result = plan_and_collect(
2077 &ctx,
2078 format!("select c_name from '{}' limit 3;", &url).as_str(),
2079 )
2080 .await?;
2081
2082 let actual = arrow::util::pretty::pretty_format_batches(&result)
2083 .unwrap()
2084 .to_string();
2085 assert_snapshot!(actual, @r"
2086 +--------------------+
2087 | c_name |
2088 +--------------------+
2089 | Customer#000000002 |
2090 | Customer#000000003 |
2091 | Customer#000000004 |
2092 +--------------------+
2093 ");
2094
2095 Ok(())
2096 }
2097
2098 #[tokio::test]
2099 async fn custom_query_planner() -> Result<()> {
2100 let runtime = Arc::new(RuntimeEnv::default());
2101 let session_state = SessionStateBuilder::new()
2102 .with_config(SessionConfig::new())
2103 .with_runtime_env(runtime)
2104 .with_default_features()
2105 .with_query_planner(Arc::new(MyQueryPlanner {}))
2106 .build();
2107 let ctx = SessionContext::new_with_state(session_state);
2108
2109 let df = ctx.sql("SELECT 1").await?;
2110 df.collect().await.expect_err("query not supported");
2111 Ok(())
2112 }
2113
2114 #[tokio::test]
2115 async fn disabled_default_catalog_and_schema() -> Result<()> {
2116 let ctx = SessionContext::new_with_config(
2117 SessionConfig::new().with_create_default_catalog_and_schema(false),
2118 );
2119
2120 assert!(matches!(
2121 ctx.register_table("test", test::table_with_sequence(1, 1)?),
2122 Err(DataFusionError::Plan(_))
2123 ));
2124
2125 let err = ctx
2126 .sql("select * from datafusion.public.test")
2127 .await
2128 .unwrap_err();
2129 let err = err
2130 .source()
2131 .and_then(|err| err.downcast_ref::<DataFusionError>())
2132 .unwrap();
2133
2134 assert!(matches!(err, &DataFusionError::Plan(_)));
2135
2136 Ok(())
2137 }
2138
2139 #[tokio::test]
2140 async fn custom_catalog_and_schema() {
2141 let config = SessionConfig::new()
2142 .with_create_default_catalog_and_schema(true)
2143 .with_default_catalog_and_schema("my_catalog", "my_schema");
2144 catalog_and_schema_test(config).await;
2145 }
2146
2147 #[tokio::test]
2148 async fn custom_catalog_and_schema_no_default() {
2149 let config = SessionConfig::new()
2150 .with_create_default_catalog_and_schema(false)
2151 .with_default_catalog_and_schema("my_catalog", "my_schema");
2152 catalog_and_schema_test(config).await;
2153 }
2154
2155 #[tokio::test]
2156 async fn custom_catalog_and_schema_and_information_schema() {
2157 let config = SessionConfig::new()
2158 .with_create_default_catalog_and_schema(true)
2159 .with_information_schema(true)
2160 .with_default_catalog_and_schema("my_catalog", "my_schema");
2161 catalog_and_schema_test(config).await;
2162 }
2163
2164 async fn catalog_and_schema_test(config: SessionConfig) {
2165 let ctx = SessionContext::new_with_config(config);
2166 let catalog = MemoryCatalogProvider::new();
2167 let schema = MemorySchemaProvider::new();
2168 schema
2169 .register_table("test".to_owned(), test::table_with_sequence(1, 1).unwrap())
2170 .unwrap();
2171 catalog
2172 .register_schema("my_schema", Arc::new(schema))
2173 .unwrap();
2174 ctx.register_catalog("my_catalog", Arc::new(catalog));
2175
2176 let mut results = Vec::new();
2177
2178 for table_ref in &["my_catalog.my_schema.test", "my_schema.test", "test"] {
2179 let result = plan_and_collect(
2180 &ctx,
2181 &format!("SELECT COUNT(*) AS count FROM {table_ref}"),
2182 )
2183 .await
2184 .unwrap();
2185
2186 results.push(result);
2187 }
2188 allow_duplicates! {
2189 for result in &results {
2190 assert_snapshot!(batches_to_string(result), @r"
2191 +-------+
2192 | count |
2193 +-------+
2194 | 1 |
2195 +-------+
2196 ");
2197 }
2198 }
2199 }
2200
2201 #[tokio::test]
2202 async fn cross_catalog_access() -> Result<()> {
2203 let ctx = SessionContext::new();
2204
2205 let catalog_a = MemoryCatalogProvider::new();
2206 let schema_a = MemorySchemaProvider::new();
2207 schema_a
2208 .register_table("table_a".to_owned(), test::table_with_sequence(1, 1)?)?;
2209 catalog_a.register_schema("schema_a", Arc::new(schema_a))?;
2210 ctx.register_catalog("catalog_a", Arc::new(catalog_a));
2211
2212 let catalog_b = MemoryCatalogProvider::new();
2213 let schema_b = MemorySchemaProvider::new();
2214 schema_b
2215 .register_table("table_b".to_owned(), test::table_with_sequence(1, 2)?)?;
2216 catalog_b.register_schema("schema_b", Arc::new(schema_b))?;
2217 ctx.register_catalog("catalog_b", Arc::new(catalog_b));
2218
2219 let result = plan_and_collect(
2220 &ctx,
2221 "SELECT cat, SUM(i) AS total FROM (
2222 SELECT i, 'a' AS cat FROM catalog_a.schema_a.table_a
2223 UNION ALL
2224 SELECT i, 'b' AS cat FROM catalog_b.schema_b.table_b
2225 ) AS all
2226 GROUP BY cat
2227 ORDER BY cat
2228 ",
2229 )
2230 .await?;
2231
2232 assert_snapshot!(batches_to_string(&result), @r"
2233 +-----+-------+
2234 | cat | total |
2235 +-----+-------+
2236 | a | 1 |
2237 | b | 3 |
2238 +-----+-------+
2239 ");
2240
2241 Ok(())
2242 }
2243
2244 #[tokio::test]
2245 async fn catalogs_not_leaked() {
2246 let ctx = SessionContext::new_with_config(
2248 SessionConfig::new().with_information_schema(true),
2249 );
2250
2251 let catalog = Arc::new(MemoryCatalogProvider::new());
2253 let catalog_weak = Arc::downgrade(&catalog);
2254 ctx.register_catalog("my_catalog", catalog);
2255
2256 let catalog_list_weak = {
2257 let state = ctx.state.read();
2258 Arc::downgrade(state.catalog_list())
2259 };
2260
2261 drop(ctx);
2262
2263 assert_eq!(Weak::strong_count(&catalog_list_weak), 0);
2264 assert_eq!(Weak::strong_count(&catalog_weak), 0);
2265 }
2266
2267 #[tokio::test]
2268 async fn sql_create_schema() -> Result<()> {
2269 let ctx = SessionContext::new_with_config(
2271 SessionConfig::new().with_information_schema(true),
2272 );
2273
2274 ctx.sql("CREATE SCHEMA abc").await?.collect().await?;
2276
2277 ctx.sql("CREATE TABLE abc.y AS VALUES (1,2,3)")
2279 .await?
2280 .collect()
2281 .await?;
2282
2283 let results = ctx.sql("SELECT * FROM information_schema.tables WHERE table_schema='abc' AND table_name = 'y'").await.unwrap().collect().await.unwrap();
2285
2286 assert_eq!(results[0].num_rows(), 1);
2287 Ok(())
2288 }
2289
2290 #[tokio::test]
2291 async fn sql_create_catalog() -> Result<()> {
2292 let ctx = SessionContext::new_with_config(
2294 SessionConfig::new().with_information_schema(true),
2295 );
2296
2297 ctx.sql("CREATE DATABASE test").await?.collect().await?;
2299
2300 ctx.sql("CREATE SCHEMA test.abc").await?.collect().await?;
2302
2303 ctx.sql("CREATE TABLE test.abc.y AS VALUES (1,2,3)")
2305 .await?
2306 .collect()
2307 .await?;
2308
2309 let results = ctx.sql("SELECT * FROM information_schema.tables WHERE table_catalog='test' AND table_schema='abc' AND table_name = 'y'").await.unwrap().collect().await.unwrap();
2311
2312 assert_eq!(results[0].num_rows(), 1);
2313 Ok(())
2314 }
2315
2316 #[tokio::test]
2317 async fn custom_type_planner() -> Result<()> {
2318 let state = SessionStateBuilder::new()
2319 .with_default_features()
2320 .with_type_planner(Arc::new(MyTypePlanner {}))
2321 .build();
2322 let ctx = SessionContext::new_with_state(state);
2323 let result = ctx
2324 .sql("SELECT DATETIME '2021-01-01 00:00:00'")
2325 .await?
2326 .collect()
2327 .await?;
2328 assert_snapshot!(batches_to_string(&result), @r#"
2329 +-----------------------------+
2330 | Utf8("2021-01-01 00:00:00") |
2331 +-----------------------------+
2332 | 2021-01-01T00:00:00 |
2333 +-----------------------------+
2334 "#);
2335 Ok(())
2336 }

2337    #[test]
2338 fn preserve_session_context_id() -> Result<()> {
2339 let ctx = SessionContext::new();
2340 assert_eq!(ctx.session_id(), ctx.enable_url_table().session_id());
2345 Ok(())
2346 }
2347
2348 struct MyPhysicalPlanner {}
2349
2350 #[async_trait]
2351 impl PhysicalPlanner for MyPhysicalPlanner {
2352 async fn create_physical_plan(
2353 &self,
2354 _logical_plan: &LogicalPlan,
2355 _session_state: &SessionState,
2356 ) -> Result<Arc<dyn ExecutionPlan>> {
2357 not_impl_err!("query not supported")
2358 }
2359
2360 fn create_physical_expr(
2361 &self,
2362 _expr: &Expr,
2363 _input_dfschema: &DFSchema,
2364 _session_state: &SessionState,
2365 ) -> Result<Arc<dyn PhysicalExpr>> {
2366 unimplemented!()
2367 }
2368 }
2369
2370 #[derive(Debug)]
2371 struct MyQueryPlanner {}
2372
2373 #[async_trait]
2374 impl QueryPlanner for MyQueryPlanner {
2375 async fn create_physical_plan(
2376 &self,
2377 logical_plan: &LogicalPlan,
2378 session_state: &SessionState,
2379 ) -> Result<Arc<dyn ExecutionPlan>> {
2380 let physical_planner = MyPhysicalPlanner {};
2381 physical_planner
2382 .create_physical_plan(logical_plan, session_state)
2383 .await
2384 }
2385 }
2386
2387 async fn create_ctx(
2389 tmp_dir: &TempDir,
2390 partition_count: usize,
2391 ) -> Result<SessionContext> {
2392 let ctx = SessionContext::new_with_config(
2393 SessionConfig::new().with_target_partitions(8),
2394 );
2395
2396 let schema = populate_csv_partitions(tmp_dir, partition_count, ".csv")?;
2397
2398 ctx.register_csv(
2400 "test",
2401 tmp_dir.path().to_str().unwrap(),
2402 CsvReadOptions::new().schema(&schema),
2403 )
2404 .await?;
2405
2406 Ok(ctx)
2407 }
2408
2409 #[derive(Debug)]
2410 struct MyTypePlanner {}
2411
2412 impl TypePlanner for MyTypePlanner {
2413 fn plan_type(&self, sql_type: &ast::DataType) -> Result<Option<DataType>> {
2414 match sql_type {
2415 ast::DataType::Datetime(precision) => {
2416 let precision = match precision {
2417 Some(0) => TimeUnit::Second,
2418 Some(3) => TimeUnit::Millisecond,
2419 Some(6) => TimeUnit::Microsecond,
2420 None | Some(9) => TimeUnit::Nanosecond,
2421 _ => unreachable!(),
2422 };
2423 Ok(Some(DataType::Timestamp(precision, None)))
2424 }
2425 _ => Ok(None),
2426 }
2427 }
2428 }
2429}