datafusion/execution/context/mod.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! [`SessionContext`] API for registering data sources and executing queries

use std::collections::HashSet;
use std::fmt::Debug;
use std::sync::{Arc, Weak};

use super::options::ReadOptions;
use crate::datasource::dynamic_file::DynamicListTableFactory;
use crate::execution::session_state::SessionStateBuilder;
use crate::{
    catalog::listing_schema::ListingSchemaProvider,
    catalog::{
        CatalogProvider, CatalogProviderList, TableProvider, TableProviderFactory,
    },
    dataframe::DataFrame,
    datasource::listing::{
        ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
    },
    datasource::{provider_as_source, MemTable, ViewTable},
    error::{DataFusionError, Result},
    execution::{
        options::ArrowReadOptions,
        runtime_env::{RuntimeEnv, RuntimeEnvBuilder},
        FunctionRegistry,
    },
    logical_expr::AggregateUDF,
    logical_expr::ScalarUDF,
    logical_expr::{
        CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateFunction,
        CreateMemoryTable, CreateView, DropCatalogSchema, DropFunction, DropTable,
        DropView, Execute, LogicalPlan, LogicalPlanBuilder, Prepare, SetVariable,
        TableType, UNNAMED_TABLE,
    },
    physical_expr::PhysicalExpr,
    physical_plan::ExecutionPlan,
    variable::{VarProvider, VarType},
};

// backwards compatibility
pub use crate::execution::session_state::SessionState;

use arrow::datatypes::{Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion_catalog::memory::MemorySchemaProvider;
use datafusion_catalog::MemoryCatalogProvider;
use datafusion_catalog::{
    DynamicFileCatalog, TableFunction, TableFunctionImpl, UrlTableFactory,
};
use datafusion_common::config::ConfigOptions;
use datafusion_common::{
    config::{ConfigExtension, TableOptions},
    exec_datafusion_err, exec_err, not_impl_err, plan_datafusion_err, plan_err,
    tree_node::{TreeNodeRecursion, TreeNodeVisitor},
    DFSchema, ParamValues, ScalarValue, SchemaReference, TableReference,
};
pub use datafusion_execution::config::SessionConfig;
use datafusion_execution::registry::SerializerRegistry;
pub use datafusion_execution::TaskContext;
pub use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::{
    expr_rewriter::FunctionRewrite,
    logical_plan::{DdlStatement, Statement},
    planner::ExprPlanner,
    Expr, UserDefinedLogicalNode, WindowUDF,
};
use datafusion_optimizer::analyzer::type_coercion::TypeCoercion;
use datafusion_optimizer::Analyzer;
use datafusion_optimizer::{AnalyzerRule, OptimizerRule};
use datafusion_session::SessionStore;

use async_trait::async_trait;
use chrono::{DateTime, Utc};
use object_store::ObjectStore;
use parking_lot::RwLock;
use url::Url;

mod csv;
mod json;
#[cfg(feature = "parquet")]
mod parquet;

#[cfg(feature = "avro")]
mod avro;
/// DataFilePaths adds a method to convert strings and vectors of strings to vectors of [`ListingTableUrl`] URLs.
/// This allows methods such as [`SessionContext::read_csv`] and [`SessionContext::read_avro`]
/// to take either a single file or multiple files.
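///
/// # Example
///
/// A minimal sketch of the conversion (using a test data file that ships with
/// this crate):
///
/// ```
/// # use datafusion::error::Result;
/// # use datafusion::execution::context::DataFilePaths;
/// # fn main() -> Result<()> {
/// // A single path and a vector of paths parse to the same URL type
/// let single = "tests/data/example.csv".to_urls()?;
/// let many = vec!["tests/data/example.csv"].to_urls()?;
/// assert_eq!(single, many);
/// # Ok(())
/// # }
/// ```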
pub trait DataFilePaths {
    /// Parse to a vector of [`ListingTableUrl`] URLs.
    fn to_urls(self) -> Result<Vec<ListingTableUrl>>;
}

impl DataFilePaths for &str {
    fn to_urls(self) -> Result<Vec<ListingTableUrl>> {
        Ok(vec![ListingTableUrl::parse(self)?])
    }
}

impl DataFilePaths for String {
    fn to_urls(self) -> Result<Vec<ListingTableUrl>> {
        Ok(vec![ListingTableUrl::parse(self)?])
    }
}

impl DataFilePaths for &String {
    fn to_urls(self) -> Result<Vec<ListingTableUrl>> {
        Ok(vec![ListingTableUrl::parse(self)?])
    }
}

impl<P> DataFilePaths for Vec<P>
where
    P: AsRef<str>,
{
    fn to_urls(self) -> Result<Vec<ListingTableUrl>> {
        self.iter()
            .map(ListingTableUrl::parse)
            .collect::<Result<Vec<ListingTableUrl>>>()
    }
}

/// Main interface for executing queries with DataFusion. Maintains
/// the state of the connection between a user and an instance of the
/// DataFusion engine.
///
/// See examples below for how to use the `SessionContext` to execute queries
/// and how to configure the session.
///
/// # Overview
///
/// [`SessionContext`] provides the following functionality:
///
/// * Create a [`DataFrame`] from a CSV or Parquet data source.
/// * Register a CSV or Parquet data source as a table that can be referenced from a SQL query.
/// * Register a custom data source that can be referenced from a SQL query.
/// * Execute a SQL query.
///
/// # Example: DataFrame API
///
/// The following example demonstrates how to use the context to execute a query against a CSV
/// data source using the [`DataFrame`] API:
///
/// ```
/// use datafusion::prelude::*;
/// # use datafusion::functions_aggregate::expr_fn::min;
/// # use datafusion::{error::Result, assert_batches_eq};
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let ctx = SessionContext::new();
/// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?;
/// let df = df.filter(col("a").lt_eq(col("b")))?
///            .aggregate(vec![col("a")], vec![min(col("b"))])?
///            .limit(0, Some(100))?;
/// let results = df
///   .collect()
///   .await?;
/// assert_batches_eq!(
///  &[
///    "+---+----------------+",
///    "| a | min(?table?.b) |",
///    "+---+----------------+",
///    "| 1 | 2              |",
///    "+---+----------------+",
///  ],
///  &results
/// );
/// # Ok(())
/// # }
/// ```
///
/// # Example: SQL API
///
/// The following example demonstrates how to execute the same query using SQL:
///
/// ```
/// use datafusion::prelude::*;
/// # use datafusion::{error::Result, assert_batches_eq};
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let ctx = SessionContext::new();
/// ctx.register_csv("example", "tests/data/example.csv", CsvReadOptions::new()).await?;
/// let results = ctx
///   .sql("SELECT a, min(b) FROM example GROUP BY a LIMIT 100")
///   .await?
///   .collect()
///   .await?;
/// assert_batches_eq!(
///  &[
///    "+---+----------------+",
///    "| a | min(example.b) |",
///    "+---+----------------+",
///    "| 1 | 2              |",
///    "+---+----------------+",
///  ],
///  &results
/// );
/// # Ok(())
/// # }
/// ```
///
/// # Example: Configuring `SessionContext`
///
/// The `SessionContext` can be configured by creating a [`SessionState`] using
/// [`SessionStateBuilder`]:
///
/// ```
/// # use std::sync::Arc;
/// # use datafusion::prelude::*;
/// # use datafusion::execution::SessionStateBuilder;
/// # use datafusion_execution::runtime_env::RuntimeEnvBuilder;
/// // Configure a 4k batch size
/// let config = SessionConfig::new().with_batch_size(4 * 1024);
///
/// // Configure a memory limit of 1GB with 20% slop
/// let runtime_env = RuntimeEnvBuilder::new()
///     .with_memory_limit(1024 * 1024 * 1024, 0.80)
///     .build_arc()
///     .unwrap();
///
/// // Create a SessionState using the config and runtime_env
/// let state = SessionStateBuilder::new()
///   .with_config(config)
///   .with_runtime_env(runtime_env)
///   // include support for built in functions and configurations
///   .with_default_features()
///   .build();
///
/// // Create a SessionContext
/// let ctx = SessionContext::from(state);
/// ```
///
/// # Relationship between `SessionContext`, `SessionState`, and `TaskContext`
///
/// The state required to optimize and evaluate queries is
/// broken into three levels to allow tailoring.
///
/// The objects are:
///
/// 1. [`SessionContext`]: Most users should use a `SessionContext`. It contains
///    all information required to execute queries including high-level APIs such
///    as [`SessionContext::sql`]. All queries run with the same `SessionContext`
///    share the same configuration and resources (e.g. memory limits).
///
/// 2. [`SessionState`]: contains information required to plan and execute an
///    individual query (e.g. creating a [`LogicalPlan`] or [`ExecutionPlan`]).
///    Each query is planned and executed using its own `SessionState`, which can
///    be created with [`SessionContext::state`]. `SessionState` allows finer
///    grained control over query execution, for example disallowing DDL operations
///    such as `CREATE TABLE`.
///
/// 3. [`TaskContext`] contains the state required for query execution (e.g.
///    [`ExecutionPlan::execute`]). It contains a subset of the information in
///    [`SessionState`]. `TaskContext` allows executing [`ExecutionPlan`]s and
///    [`PhysicalExpr`]s without requiring a full [`SessionState`].
///
/// [`PhysicalExpr`]: crate::physical_expr::PhysicalExpr
#[derive(Clone)]
pub struct SessionContext {
    /// UUID for the session
    session_id: String,
    /// Session start time
    session_start_time: DateTime<Utc>,
    /// Shared session state for the session
    state: Arc<RwLock<SessionState>>,
}

impl Default for SessionContext {
    fn default() -> Self {
        Self::new()
    }
}

impl SessionContext {
    /// Creates a new `SessionContext` using the default [`SessionConfig`].
    pub fn new() -> Self {
        Self::new_with_config(SessionConfig::new())
    }

    /// Finds any [`ListingSchemaProvider`]s and instructs them to reload tables from "disk"
    pub async fn refresh_catalogs(&self) -> Result<()> {
        let cat_names = self.catalog_names().clone();
        for cat_name in cat_names.iter() {
            let cat = self.catalog(cat_name.as_str()).ok_or_else(|| {
                DataFusionError::Internal("Catalog not found!".to_string())
            })?;
            for schema_name in cat.schema_names() {
                let schema = cat.schema(schema_name.as_str()).ok_or_else(|| {
                    DataFusionError::Internal("Schema not found!".to_string())
                })?;
                let lister = schema.as_any().downcast_ref::<ListingSchemaProvider>();
                if let Some(lister) = lister {
                    lister.refresh(&self.state()).await?;
                }
            }
        }
        Ok(())
    }

    /// Creates a new `SessionContext` using the provided
    /// [`SessionConfig`] and a new [`RuntimeEnv`].
    ///
    /// See [`Self::new_with_config_rt`] for more details on resource
    /// limits.
    pub fn new_with_config(config: SessionConfig) -> Self {
        let runtime = Arc::new(RuntimeEnv::default());
        Self::new_with_config_rt(config, runtime)
    }

    /// Creates a new `SessionContext` using the provided
    /// [`SessionConfig`] and a [`RuntimeEnv`].
    ///
    /// # Resource Limits
    ///
    /// By default, each new `SessionContext` creates a new
    /// `RuntimeEnv`, and therefore will not enforce memory or disk
    /// limits for queries run on different `SessionContext`s.
    ///
    /// To enforce resource limits (e.g. to limit the total amount of
    /// memory used) across all DataFusion queries in a process,
    /// all `SessionContext`'s should be configured with the
    /// same `RuntimeEnv`.
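    ///
    /// # Example: sharing one `RuntimeEnv`
    ///
    /// A minimal sketch of sharing a single `RuntimeEnv` (and therefore a
    /// single memory pool and disk manager) between two contexts:
    ///
    /// ```
    /// # use std::sync::Arc;
    /// # use datafusion::prelude::*;
    /// # use datafusion::execution::runtime_env::RuntimeEnv;
    /// let runtime = Arc::new(RuntimeEnv::default());
    /// // Queries on both contexts draw from the same resource limits
    /// let ctx1 = SessionContext::new_with_config_rt(SessionConfig::new(), Arc::clone(&runtime));
    /// let ctx2 = SessionContext::new_with_config_rt(SessionConfig::new(), runtime);
    /// ```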
    pub fn new_with_config_rt(config: SessionConfig, runtime: Arc<RuntimeEnv>) -> Self {
        let state = SessionStateBuilder::new()
            .with_config(config)
            .with_runtime_env(runtime)
            .with_default_features()
            .build();
        Self::new_with_state(state)
    }

    /// Creates a new `SessionContext` using the provided [`SessionState`]
    pub fn new_with_state(state: SessionState) -> Self {
        Self {
            session_id: state.session_id().to_string(),
            session_start_time: Utc::now(),
            state: Arc::new(RwLock::new(state)),
        }
    }

    /// Enable querying local files as tables.
    ///
    /// This feature is security sensitive and should only be enabled for
    /// systems that wish to permit direct access to the file system from SQL.
    ///
    /// When enabled, this feature permits direct access to arbitrary files via
    /// SQL like
    ///
    /// ```sql
    /// SELECT * from 'my_file.parquet'
    /// ```
    ///
    /// See [DynamicFileCatalog] for more details
    ///
    /// ```
    /// # use datafusion::prelude::*;
    /// # use datafusion::{error::Result, assert_batches_eq};
    /// # #[tokio::main]
    /// # async fn main() -> Result<()> {
    /// let ctx = SessionContext::new()
    ///   .enable_url_table(); // permit local file access
    /// let results = ctx
    ///   .sql("SELECT a, MIN(b) FROM 'tests/data/example.csv' as example GROUP BY a LIMIT 100")
    ///   .await?
    ///   .collect()
    ///   .await?;
    /// assert_batches_eq!(
    ///  &[
    ///    "+---+----------------+",
    ///    "| a | min(example.b) |",
    ///    "+---+----------------+",
    ///    "| 1 | 2              |",
    ///    "+---+----------------+",
    ///  ],
    ///  &results
    /// );
    /// # Ok(())
    /// # }
    /// ```
    pub fn enable_url_table(self) -> Self {
        let current_catalog_list = Arc::clone(self.state.read().catalog_list());
        let factory = Arc::new(DynamicListTableFactory::new(SessionStore::new()));
        let catalog_list = Arc::new(DynamicFileCatalog::new(
            current_catalog_list,
            Arc::clone(&factory) as Arc<dyn UrlTableFactory>,
        ));

        let session_id = self.session_id.clone();
        let ctx: SessionContext = self
            .into_state_builder()
            .with_session_id(session_id)
            .with_catalog_list(catalog_list)
            .build()
            .into();
        // register new state with the factory
        factory.session_store().with_state(ctx.state_weak_ref());
        ctx
    }

    /// Convert the current `SessionContext` into a [`SessionStateBuilder`]
    ///
    /// This is useful for further customizing the session state after methods
    /// such as [`Self::enable_url_table`] have been applied.
    ///
    /// Avoids cloning the `SessionState` if possible.
    ///
    /// # Example
    /// ```
    /// # use std::sync::Arc;
    /// # use datafusion::prelude::*;
    /// # use datafusion::execution::SessionStateBuilder;
    /// # use datafusion_optimizer::push_down_filter::PushDownFilter;
    /// let my_rule = PushDownFilter{}; // pretend it is a new rule
    /// // Create a new builder with a custom optimizer rule
    /// let context: SessionContext = SessionStateBuilder::new()
    ///   .with_optimizer_rule(Arc::new(my_rule))
    ///   .build()
    ///   .into();
    /// // Enable local file access and convert context back to a builder
    /// let builder = context
    ///   .enable_url_table()
    ///   .into_state_builder();
    /// ```
    pub fn into_state_builder(self) -> SessionStateBuilder {
        let SessionContext {
            session_id: _,
            session_start_time: _,
            state,
        } = self;
        let state = match Arc::try_unwrap(state) {
            Ok(rwlock) => rwlock.into_inner(),
            Err(state) => state.read().clone(),
        };
        SessionStateBuilder::from(state)
    }

    /// Returns the time this `SessionContext` was created
    pub fn session_start_time(&self) -> DateTime<Utc> {
        self.session_start_time
    }

    /// Registers a [`FunctionFactory`] to handle `CREATE FUNCTION` statements
    pub fn with_function_factory(
        self,
        function_factory: Arc<dyn FunctionFactory>,
    ) -> Self {
        self.state.write().set_function_factory(function_factory);
        self
    }

    /// Adds an optimizer rule to the end of the existing rules.
    ///
    /// See [`SessionState`] for more control of when the rule is applied.
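    ///
    /// A minimal sketch, reusing the built-in `PushDownFilter` rule as a
    /// stand-in for a custom rule:
    ///
    /// ```
    /// # use std::sync::Arc;
    /// # use datafusion::prelude::SessionContext;
    /// # use datafusion_optimizer::push_down_filter::PushDownFilter;
    /// let ctx = SessionContext::new();
    /// let my_rule = PushDownFilter {}; // pretend it is a new rule
    /// // The rule runs after all existing optimizer rules
    /// ctx.add_optimizer_rule(Arc::new(my_rule));
    /// ```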
    pub fn add_optimizer_rule(
        &self,
        optimizer_rule: Arc<dyn OptimizerRule + Send + Sync>,
    ) {
        self.state.write().append_optimizer_rule(optimizer_rule);
    }

    /// Adds an analyzer rule to the end of the existing rules.
    ///
    /// See [`SessionState`] for more control of when the rule is applied.
    pub fn add_analyzer_rule(&self, analyzer_rule: Arc<dyn AnalyzerRule + Send + Sync>) {
        self.state.write().add_analyzer_rule(analyzer_rule);
    }

    /// Registers an [`ObjectStore`] to be used with a specific URL prefix.
    ///
    /// See [`RuntimeEnv::register_object_store`] for more details.
    ///
    /// # Example: register a local object store for the "file://" URL prefix
    /// ```
    /// # use std::sync::Arc;
    /// # use datafusion::prelude::SessionContext;
    /// # use datafusion_execution::object_store::ObjectStoreUrl;
    /// let object_store_url = ObjectStoreUrl::parse("file://").unwrap();
    /// let object_store = object_store::local::LocalFileSystem::new();
    /// let ctx = SessionContext::new();
    /// // All files with the file:// url prefix will be read from the local file system
    /// ctx.register_object_store(object_store_url.as_ref(), Arc::new(object_store));
    /// ```
    pub fn register_object_store(
        &self,
        url: &Url,
        object_store: Arc<dyn ObjectStore>,
    ) -> Option<Arc<dyn ObjectStore>> {
        self.runtime_env().register_object_store(url, object_store)
    }

    /// Registers the [`RecordBatch`] as the specified table name
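    ///
    /// A minimal sketch (the batch's schema becomes the table's schema):
    ///
    /// ```
    /// # use std::sync::Arc;
    /// # use arrow::array::Int32Array;
    /// # use arrow::datatypes::{DataType, Field, Schema};
    /// # use arrow::record_batch::RecordBatch;
    /// # use datafusion::error::Result;
    /// # use datafusion::prelude::SessionContext;
    /// # fn main() -> Result<()> {
    /// let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    /// let batch =
    ///     RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![1, 2, 3]))])?;
    /// let ctx = SessionContext::new();
    /// ctx.register_batch("my_table", batch)?;
    /// assert!(ctx.table_exist("my_table")?);
    /// # Ok(())
    /// # }
    /// ```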
    pub fn register_batch(
        &self,
        table_name: &str,
        batch: RecordBatch,
    ) -> Result<Option<Arc<dyn TableProvider>>> {
        let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
        self.register_table(
            TableReference::Bare {
                table: table_name.into(),
            },
            Arc::new(table),
        )
    }

    /// Return the [RuntimeEnv] used to run queries with this `SessionContext`
    pub fn runtime_env(&self) -> Arc<RuntimeEnv> {
        Arc::clone(self.state.read().runtime_env())
    }

    /// Returns an id that uniquely identifies this `SessionContext`.
    pub fn session_id(&self) -> String {
        self.session_id.clone()
    }

    /// Return the [`TableProviderFactory`] that is registered for the
    /// specified file type, if any.
    pub fn table_factory(
        &self,
        file_type: &str,
    ) -> Option<Arc<dyn TableProviderFactory>> {
        self.state.read().table_factories().get(file_type).cloned()
    }

    /// Return the `enable_ident_normalization` of this Session
    pub fn enable_ident_normalization(&self) -> bool {
        self.state
            .read()
            .config()
            .options()
            .sql_parser
            .enable_ident_normalization
    }

    /// Return a copied version of config for this Session
    pub fn copied_config(&self) -> SessionConfig {
        self.state.read().config().clone()
    }

    /// Return a copied version of table options for this Session
    pub fn copied_table_options(&self) -> TableOptions {
        self.state.read().default_table_options()
    }

    /// Creates a [`DataFrame`] from SQL query text.
    ///
    /// Note: This API implements DDL statements such as `CREATE TABLE` and
    /// `CREATE VIEW` and DML statements such as `INSERT INTO` with in-memory
    /// default implementations. See [`Self::sql_with_options`].
    ///
    /// # Example: Running SQL queries
    ///
    /// See the example on [`Self`]
    ///
    /// # Example: Creating a Table with SQL
    ///
    /// ```
    /// use datafusion::prelude::*;
    /// # use datafusion::{error::Result, assert_batches_eq};
    /// # #[tokio::main]
    /// # async fn main() -> Result<()> {
    /// let ctx = SessionContext::new();
    /// ctx
    ///   .sql("CREATE TABLE foo (x INTEGER)")
    ///   .await?
    ///   .collect()
    ///   .await?;
    /// assert!(ctx.table_exist("foo").unwrap());
    /// # Ok(())
    /// # }
    /// ```
    pub async fn sql(&self, sql: &str) -> Result<DataFrame> {
        self.sql_with_options(sql, SQLOptions::new()).await
    }

    /// Creates a [`DataFrame`] from SQL query text, first validating
    /// that the queries are allowed by `options`
    ///
    /// # Example: Preventing Creating a Table with SQL
    ///
    /// If you want to avoid creating tables, or modifying data or the
    /// session, set [`SQLOptions`] appropriately:
    ///
    /// ```
    /// use datafusion::prelude::*;
    /// # use datafusion::{error::Result};
    /// # use datafusion::physical_plan::collect;
    /// # #[tokio::main]
    /// # async fn main() -> Result<()> {
    /// let ctx = SessionContext::new();
    /// let options = SQLOptions::new()
    ///   .with_allow_ddl(false);
    /// let err = ctx.sql_with_options("CREATE TABLE foo (x INTEGER)", options)
    ///   .await
    ///   .unwrap_err();
    /// assert!(
    ///   err.to_string().starts_with("Error during planning: DDL not supported: CreateMemoryTable")
    /// );
    /// # Ok(())
    /// # }
    /// ```
    pub async fn sql_with_options(
        &self,
        sql: &str,
        options: SQLOptions,
    ) -> Result<DataFrame> {
        let plan = self.state().create_logical_plan(sql).await?;
        options.verify_plan(&plan)?;

        self.execute_logical_plan(plan).await
    }

    /// Creates logical expressions from SQL query text.
    ///
    /// # Example: Parsing SQL queries
    ///
    /// ```
    /// # use arrow::datatypes::{DataType, Field, Schema};
    /// # use datafusion::prelude::*;
    /// # use datafusion_common::{DFSchema, Result};
    /// # #[tokio::main]
    /// # async fn main() -> Result<()> {
    /// // DataFusion will parse numbers as i64 first.
    /// let sql = "a > 10";
    /// let expected = col("a").gt(lit(10 as i64));
    /// // provide type information that `a` is an Int32
    /// let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
    /// let df_schema = DFSchema::try_from(schema).unwrap();
    /// let expr = SessionContext::new()
    ///  .parse_sql_expr(sql, &df_schema)?;
    /// assert_eq!(expected, expr);
    /// # Ok(())
    /// # }
    /// ```
    pub fn parse_sql_expr(&self, sql: &str, df_schema: &DFSchema) -> Result<Expr> {
        self.state.read().create_logical_expr(sql, df_schema)
    }

    /// Execute the [`LogicalPlan`], return a [`DataFrame`]. This API
    /// is not feature limited (so all SQL such as `CREATE TABLE` and
    /// `COPY` will be run).
    ///
    /// If you wish to limit the type of plan that can be run from
    /// SQL, see [`Self::sql_with_options`] and
    /// [`SQLOptions::verify_plan`].
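    ///
    /// # Example
    ///
    /// A minimal sketch, creating the plan from SQL and executing it as a
    /// separate step:
    ///
    /// ```
    /// # use datafusion::prelude::*;
    /// # use datafusion::error::Result;
    /// # #[tokio::main]
    /// # async fn main() -> Result<()> {
    /// let ctx = SessionContext::new();
    /// let plan = ctx.state().create_logical_plan("SELECT 1").await?;
    /// let results = ctx.execute_logical_plan(plan).await?.collect().await?;
    /// # Ok(())
    /// # }
    /// ```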
    pub async fn execute_logical_plan(&self, plan: LogicalPlan) -> Result<DataFrame> {
        match plan {
            LogicalPlan::Ddl(ddl) => {
                // Box::pin avoids allocating the stack space within this function's frame
                // for every one of these individual async functions, decreasing the risk of
                // stack overflows.
                match ddl {
                    DdlStatement::CreateExternalTable(cmd) => {
                        (Box::pin(async move { self.create_external_table(&cmd).await })
                            as std::pin::Pin<Box<dyn futures::Future<Output = _> + Send>>)
                            .await
                    }
                    DdlStatement::CreateMemoryTable(cmd) => {
                        Box::pin(self.create_memory_table(cmd)).await
                    }
                    DdlStatement::CreateView(cmd) => {
                        Box::pin(self.create_view(cmd)).await
                    }
                    DdlStatement::CreateCatalogSchema(cmd) => {
                        Box::pin(self.create_catalog_schema(cmd)).await
                    }
                    DdlStatement::CreateCatalog(cmd) => {
                        Box::pin(self.create_catalog(cmd)).await
                    }
                    DdlStatement::DropTable(cmd) => Box::pin(self.drop_table(cmd)).await,
                    DdlStatement::DropView(cmd) => Box::pin(self.drop_view(cmd)).await,
                    DdlStatement::DropCatalogSchema(cmd) => {
                        Box::pin(self.drop_schema(cmd)).await
                    }
                    DdlStatement::CreateFunction(cmd) => {
                        Box::pin(self.create_function(cmd)).await
                    }
                    DdlStatement::DropFunction(cmd) => {
                        Box::pin(self.drop_function(cmd)).await
                    }
                    ddl => Ok(DataFrame::new(self.state(), LogicalPlan::Ddl(ddl))),
                }
            }
            // TODO what about the other statements (like TransactionStart and TransactionEnd)
            LogicalPlan::Statement(Statement::SetVariable(stmt)) => {
                self.set_variable(stmt).await
            }
            LogicalPlan::Statement(Statement::Prepare(Prepare {
                name,
                input,
                data_types,
            })) => {
                // The number of parameters must match the number of specified data types.
                if !data_types.is_empty() {
                    let param_names = input.get_parameter_names()?;
                    if param_names.len() != data_types.len() {
                        return plan_err!(
                            "Prepare specifies {} data types but query has {} parameters",
                            data_types.len(),
                            param_names.len()
                        );
                    }
                }
                // Store the unoptimized plan into the session state. Although storing the
                // optimized plan or the physical plan would be more efficient, doing so is
                // not currently feasible. This is because `now()` would be optimized to a
                // constant value, causing each EXECUTE to yield the same result, which is
                // incorrect behavior.
                self.state.write().store_prepared(name, data_types, input)?;
                self.return_empty_dataframe()
            }
            LogicalPlan::Statement(Statement::Execute(execute)) => {
                self.execute_prepared(execute)
            }
            LogicalPlan::Statement(Statement::Deallocate(deallocate)) => {
                self.state
                    .write()
                    .remove_prepared(deallocate.name.as_str())?;
                self.return_empty_dataframe()
            }
            plan => Ok(DataFrame::new(self.state(), plan)),
        }
    }

    /// Create a [`PhysicalExpr`] from an [`Expr`] after applying type
    /// coercion and function rewrites.
    ///
    /// Note: The expression is not [simplified] or otherwise optimized:
    /// `a = 1 + 2` will not be simplified to `a = 3` as this is a more involved process.
    /// See the [expr_api] example for how to simplify expressions.
    ///
    /// # Example
    /// ```
    /// # use std::sync::Arc;
    /// # use arrow::datatypes::{DataType, Field, Schema};
    /// # use datafusion::prelude::*;
    /// # use datafusion_common::DFSchema;
    /// // a = 1 (i64)
    /// let expr = col("a").eq(lit(1i64));
    /// // provide type information that `a` is an Int32
    /// let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
    /// let df_schema = DFSchema::try_from(schema).unwrap();
    /// // Create a PhysicalExpr. Note DataFusion automatically coerces (casts) `1i64` to `1i32`
    /// let physical_expr = SessionContext::new()
    ///   .create_physical_expr(expr, &df_schema).unwrap();
    /// ```
    /// # See Also
    /// * [`SessionState::create_physical_expr`] for a lower level API
    ///
    /// [simplified]: datafusion_optimizer::simplify_expressions
    /// [expr_api]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/expr_api.rs
    pub fn create_physical_expr(
        &self,
        expr: Expr,
        df_schema: &DFSchema,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        self.state.read().create_physical_expr(expr, df_schema)
    }

    // return an empty dataframe
    fn return_empty_dataframe(&self) -> Result<DataFrame> {
        let plan = LogicalPlanBuilder::empty(false).build()?;
        Ok(DataFrame::new(self.state(), plan))
    }

    async fn create_external_table(
        &self,
        cmd: &CreateExternalTable,
    ) -> Result<DataFrame> {
        let exist = self.table_exist(cmd.name.clone())?;

        if cmd.temporary {
            return not_impl_err!("Temporary tables not supported");
        }

        if exist {
            match cmd.if_not_exists {
                true => return self.return_empty_dataframe(),
                false => {
                    return exec_err!("Table '{}' already exists", cmd.name);
                }
            }
        }

        let table_provider: Arc<dyn TableProvider> =
            self.create_custom_table(cmd).await?;
        self.register_table(cmd.name.clone(), table_provider)?;
        self.return_empty_dataframe()
    }

    async fn create_memory_table(&self, cmd: CreateMemoryTable) -> Result<DataFrame> {
        let CreateMemoryTable {
            name,
            input,
            if_not_exists,
            or_replace,
            constraints,
            column_defaults,
            temporary,
        } = cmd;

        let input = Arc::unwrap_or_clone(input);
        let input = self.state().optimize(&input)?;

        if temporary {
            return not_impl_err!("Temporary tables not supported");
        }

        let table = self.table(name.clone()).await;
        match (if_not_exists, or_replace, table) {
            (true, false, Ok(_)) => self.return_empty_dataframe(),
            (false, true, Ok(_)) => {
                self.deregister_table(name.clone())?;
                let schema = Arc::new(input.schema().as_ref().into());
                let physical = DataFrame::new(self.state(), input);

                let batches: Vec<_> = physical.collect_partitioned().await?;
                let table = Arc::new(
                    // pass constraints and column defaults to the mem table.
                    MemTable::try_new(schema, batches)?
                        .with_constraints(constraints)
                        .with_column_defaults(column_defaults.into_iter().collect()),
                );

                self.register_table(name.clone(), table)?;
                self.return_empty_dataframe()
            }
            (true, true, Ok(_)) => {
                exec_err!("'IF NOT EXISTS' cannot coexist with 'REPLACE'")
            }
            (_, _, Err(_)) => {
                let df_schema = input.schema();
                let schema = Arc::new(df_schema.as_ref().into());
                let physical = DataFrame::new(self.state(), input);

                let batches: Vec<_> = physical.collect_partitioned().await?;
                let table = Arc::new(
                    // pass constraints and column defaults to the mem table.
                    MemTable::try_new(schema, batches)?
                        .with_constraints(constraints)
                        .with_column_defaults(column_defaults.into_iter().collect()),
                );

                self.register_table(name, table)?;
                self.return_empty_dataframe()
            }
            (false, false, Ok(_)) => exec_err!("Table '{name}' already exists"),
        }
    }

    /// Applies the `TypeCoercion` rewriter to the logical plan.
    fn apply_type_coercion(logical_plan: LogicalPlan) -> Result<LogicalPlan> {
        let options = ConfigOptions::default();
        Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())]).execute_and_check(
            logical_plan,
            &options,
            |_, _| {},
        )
    }

    async fn create_view(&self, cmd: CreateView) -> Result<DataFrame> {
        let CreateView {
            name,
            input,
            or_replace,
            definition,
            temporary,
        } = cmd;

        let view = self.table(name.clone()).await;

        if temporary {
            return not_impl_err!("Temporary views not supported");
        }

        match (or_replace, view) {
            (true, Ok(_)) => {
                self.deregister_table(name.clone())?;
                let input = Self::apply_type_coercion(input.as_ref().clone())?;
                let table = Arc::new(ViewTable::new(input, definition));
                self.register_table(name, table)?;
                self.return_empty_dataframe()
            }
            (_, Err(_)) => {
                let input = Self::apply_type_coercion(input.as_ref().clone())?;
                let table = Arc::new(ViewTable::new(input, definition));
                self.register_table(name, table)?;
                self.return_empty_dataframe()
            }
            (false, Ok(_)) => exec_err!("Table '{name}' already exists"),
        }
    }

    async fn create_catalog_schema(&self, cmd: CreateCatalogSchema) -> Result<DataFrame> {
        let CreateCatalogSchema {
            schema_name,
            if_not_exists,
            ..
        } = cmd;

        // sqlparser doesn't accept a database / catalog as a parameter to
        // CREATE SCHEMA, so for now we default to the default catalog
        let tokens: Vec<&str> = schema_name.split('.').collect();
        let (catalog, schema_name) = match tokens.len() {
            1 => {
                let state = self.state.read();
                let name = &state.config().options().catalog.default_catalog;
                let catalog = state.catalog_list().catalog(name).ok_or_else(|| {
                    DataFusionError::Execution(format!(
                        "Missing default catalog '{name}'"
                    ))
                })?;
                (catalog, tokens[0])
            }
            2 => {
                let name = &tokens[0];
                let catalog = self.catalog(name).ok_or_else(|| {
                    DataFusionError::Execution(format!("Missing catalog '{name}'"))
                })?;
                (catalog, tokens[1])
            }
            _ => return exec_err!("Unable to parse catalog from {schema_name}"),
        };
        let schema = catalog.schema(schema_name);

        match (if_not_exists, schema) {
            (true, Some(_)) => self.return_empty_dataframe(),
            (true, None) | (false, None) => {
                let schema = Arc::new(MemorySchemaProvider::new());
                catalog.register_schema(schema_name, schema)?;
                self.return_empty_dataframe()
            }
            (false, Some(_)) => exec_err!("Schema '{schema_name}' already exists"),
        }
    }

    async fn create_catalog(&self, cmd: CreateCatalog) -> Result<DataFrame> {
        let CreateCatalog {
            catalog_name,
            if_not_exists,
            ..
        } = cmd;
        let catalog = self.catalog(catalog_name.as_str());

        match (if_not_exists, catalog) {
            (true, Some(_)) => self.return_empty_dataframe(),
            (true, None) | (false, None) => {
                let new_catalog = Arc::new(MemoryCatalogProvider::new());
                self.state
                    .write()
                    .catalog_list()
                    .register_catalog(catalog_name, new_catalog);
                self.return_empty_dataframe()
            }
            (false, Some(_)) => exec_err!("Catalog '{catalog_name}' already exists"),
        }
    }

    async fn drop_table(&self, cmd: DropTable) -> Result<DataFrame> {
        let DropTable {
            name, if_exists, ..
        } = cmd;
        let result = self
            .find_and_deregister(name.clone(), TableType::Base)
            .await;
        match (result, if_exists) {
            (Ok(true), _) => self.return_empty_dataframe(),
            (_, true) => self.return_empty_dataframe(),
            (_, _) => exec_err!("Table '{name}' doesn't exist."),
        }
    }

    async fn drop_view(&self, cmd: DropView) -> Result<DataFrame> {
        let DropView {
            name, if_exists, ..
        } = cmd;
        let result = self
            .find_and_deregister(name.clone(), TableType::View)
            .await;
        match (result, if_exists) {
            (Ok(true), _) => self.return_empty_dataframe(),
            (_, true) => self.return_empty_dataframe(),
            (_, _) => exec_err!("View '{name}' doesn't exist."),
        }
    }

    async fn drop_schema(&self, cmd: DropCatalogSchema) -> Result<DataFrame> {
        let DropCatalogSchema {
            name,
            if_exists: allow_missing,
            cascade,
            schema: _,
        } = cmd;
        let catalog = {
            let state = self.state.read();
            let catalog_name = match &name {
                SchemaReference::Full { catalog, .. } => catalog.to_string(),
                SchemaReference::Bare { .. } => {
                    state.config_options().catalog.default_catalog.to_string()
                }
            };
            if let Some(catalog) = state.catalog_list().catalog(&catalog_name) {
                catalog
            } else if allow_missing {
                return self.return_empty_dataframe();
            } else {
                return self.schema_doesnt_exist_err(name);
            }
        };
        let dereg = catalog.deregister_schema(name.schema_name(), cascade)?;
        match (dereg, allow_missing) {
            (None, true) => self.return_empty_dataframe(),
            (None, false) => self.schema_doesnt_exist_err(name),
            (Some(_), _) => self.return_empty_dataframe(),
        }
    }

    fn schema_doesnt_exist_err(&self, schemaref: SchemaReference) -> Result<DataFrame> {
        exec_err!("Schema '{schemaref}' doesn't exist.")
    }

    async fn set_variable(&self, stmt: SetVariable) -> Result<DataFrame> {
        let SetVariable {
            variable, value, ..
        } = stmt;

        // Check if this is a runtime configuration
        if variable.starts_with("datafusion.runtime.") {
            self.set_runtime_variable(&variable, &value)?;
        } else {
            let mut state = self.state.write();
            state.config_mut().options_mut().set(&variable, &value)?;
            drop(state);
        }

        self.return_empty_dataframe()
    }

    fn set_runtime_variable(&self, variable: &str, value: &str) -> Result<()> {
        let key = variable.strip_prefix("datafusion.runtime.").unwrap();

        match key {
            "memory_limit" => {
                let memory_limit = Self::parse_memory_limit(value)?;

                let mut state = self.state.write();
                let mut builder =
                    RuntimeEnvBuilder::from_runtime_env(state.runtime_env());
                builder = builder.with_memory_limit(memory_limit, 1.0);
                *state = SessionStateBuilder::from(state.clone())
                    .with_runtime_env(Arc::new(builder.build()?))
                    .build();
            }
            _ => {
                return Err(DataFusionError::Plan(format!(
                    "Unknown runtime configuration: {variable}"
                )))
            }
        }
        Ok(())
    }

    /// Parse memory limit from string to number of bytes
    /// Supports formats like '1.5G', '100M', '512K'
    ///
    /// # Examples
    /// ```
    /// use datafusion::execution::context::SessionContext;
    ///
    /// assert_eq!(SessionContext::parse_memory_limit("1M").unwrap(), 1024 * 1024);
    /// assert_eq!(SessionContext::parse_memory_limit("1.5G").unwrap(), (1.5 * 1024.0 * 1024.0 * 1024.0) as usize);
    /// ```
    pub fn parse_memory_limit(limit: &str) -> Result<usize> {
        let (number, unit) = limit.split_at(limit.len() - 1);
        let number: f64 = number.parse().map_err(|_| {
            DataFusionError::Plan(format!(
                "Failed to parse number from memory limit '{limit}'"
            ))
        })?;

        match unit {
            "K" => Ok((number * 1024.0) as usize),
            "M" => Ok((number * 1024.0 * 1024.0) as usize),
            "G" => Ok((number * 1024.0 * 1024.0 * 1024.0) as usize),
            _ => Err(DataFusionError::Plan(format!(
                "Unsupported unit '{unit}' in memory limit '{limit}'"
            ))),
        }
    }

    async fn create_custom_table(
        &self,
        cmd: &CreateExternalTable,
    ) -> Result<Arc<dyn TableProvider>> {
        let state = self.state.read().clone();
        let file_type = cmd.file_type.to_uppercase();
        let factory =
            state
                .table_factories()
                .get(file_type.as_str())
                .ok_or_else(|| {
                    DataFusionError::Execution(format!(
                        "Unable to find factory for {}",
                        cmd.file_type
                    ))
                })?;
        let table = (*factory).create(&state, cmd).await?;
        Ok(table)
    }

    async fn find_and_deregister(
        &self,
        table_ref: impl Into<TableReference>,
        table_type: TableType,
    ) -> Result<bool> {
        let table_ref = table_ref.into();
        let table = table_ref.table().to_owned();
        let maybe_schema = {
            let state = self.state.read();
            let resolved = state.resolve_table_ref(table_ref);
            state
                .catalog_list()
                .catalog(&resolved.catalog)
                .and_then(|c| c.schema(&resolved.schema))
        };

        if let Some(schema) = maybe_schema {
            if let Some(table_provider) = schema.table(&table).await? {
                if table_provider.table_type() == table_type {
                    schema.deregister_table(&table)?;
                    return Ok(true);
                }
            }
        }

        Ok(false)
    }

    async fn create_function(&self, stmt: CreateFunction) -> Result<DataFrame> {
        let function = {
            let state = self.state.read().clone();
            let function_factory = state.function_factory();

            match function_factory {
                Some(f) => f.create(&state, stmt).await?,
                _ => Err(DataFusionError::Configuration(
                    "Function factory has not been configured".into(),
                ))?,
            }
        };

        match function {
            RegisterFunction::Scalar(f) => {
                self.state.write().register_udf(f)?;
            }
            RegisterFunction::Aggregate(f) => {
                self.state.write().register_udaf(f)?;
            }
            RegisterFunction::Window(f) => {
                self.state.write().register_udwf(f)?;
            }
            RegisterFunction::Table(name, f) => self.register_udtf(&name, f),
        };

        self.return_empty_dataframe()
    }

    async fn drop_function(&self, stmt: DropFunction) -> Result<DataFrame> {
        // We don't know the function type at this point, so attempt to drop
        // all function kinds registered under this name.
        let mut dropped = false;
        dropped |= self.state.write().deregister_udf(&stmt.name)?.is_some();
        dropped |= self.state.write().deregister_udaf(&stmt.name)?.is_some();
        dropped |= self.state.write().deregister_udwf(&stmt.name)?.is_some();
        dropped |= self.state.write().deregister_udtf(&stmt.name)?.is_some();

        // DROP FUNCTION IF EXISTS drops the specified function only if it
        // exists, and succeeds otherwise. A plain DROP FUNCTION instead
        // returns an error if the function does not exist.

        if !stmt.if_exists && !dropped {
            exec_err!("Function does not exist")
        } else {
            self.return_empty_dataframe()
        }
    }

    fn execute_prepared(&self, execute: Execute) -> Result<DataFrame> {
        let Execute {
            name, parameters, ..
        } = execute;
        let prepared = self.state.read().get_prepared(&name).ok_or_else(|| {
            exec_datafusion_err!("Prepared statement '{}' does not exist", name)
        })?;

        // Only allow literals as parameters for now.
        let mut params: Vec<ScalarValue> = parameters
            .into_iter()
            .map(|e| match e {
                Expr::Literal(scalar, _) => Ok(scalar),
                _ => not_impl_err!("Unsupported parameter type: {}", e),
            })
            .collect::<Result<_>>()?;

        // If the prepared statement provides data types, cast the params to those types.
        if !prepared.data_types.is_empty() {
            if params.len() != prepared.data_types.len() {
                return exec_err!(
                    "Prepared statement '{}' expects {} parameters, but {} provided",
                    name,
                    prepared.data_types.len(),
                    params.len()
                );
            }
            params = params
                .into_iter()
                .zip(prepared.data_types.iter())
                .map(|(e, dt)| e.cast_to(dt))
                .collect::<Result<_>>()?;
        }

        let params = ParamValues::List(params);
        let plan = prepared
            .plan
            .as_ref()
            .clone()
            .replace_params_with_values(&params)?;
        Ok(DataFrame::new(self.state(), plan))
    }

    /// Registers a variable provider within this context.
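    ///
    /// A minimal sketch with a provider that returns a constant value for
    /// `@var` style variables (`MyVarProvider` is a hypothetical example type):
    ///
    /// ```
    /// # use std::sync::Arc;
    /// # use arrow::datatypes::DataType;
    /// # use datafusion::error::Result;
    /// # use datafusion::prelude::SessionContext;
    /// # use datafusion::scalar::ScalarValue;
    /// # use datafusion::variable::{VarProvider, VarType};
    /// #[derive(Debug)]
    /// struct MyVarProvider;
    ///
    /// impl VarProvider for MyVarProvider {
    ///     fn get_value(&self, _var_names: Vec<String>) -> Result<ScalarValue> {
    ///         Ok(ScalarValue::from("hello"))
    ///     }
    ///     fn get_type(&self, _var_names: &[String]) -> Option<DataType> {
    ///         Some(DataType::Utf8)
    ///     }
    /// }
    ///
    /// let ctx = SessionContext::new();
    /// ctx.register_variable(VarType::UserDefined, Arc::new(MyVarProvider));
    /// ```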
    pub fn register_variable(
        &self,
        variable_type: VarType,
        provider: Arc<dyn VarProvider + Send + Sync>,
    ) {
        self.state
            .write()
            .execution_props_mut()
            .add_var_provider(variable_type, provider);
    }

    /// Register a table UDF with this context
    pub fn register_udtf(&self, name: &str, fun: Arc<dyn TableFunctionImpl>) {
        self.state.write().register_udtf(name, fun)
    }

    /// Registers a scalar UDF within this context.
    ///
    /// Note in SQL queries, function names are looked up using
    /// lowercase unless the query uses quotes. For example,
    ///
    /// - `SELECT MY_FUNC(x)...` will look for a function named `"my_func"`
    /// - `SELECT "my_FUNC"(x)` will look for a function named `"my_FUNC"`
    ///
    /// Any function registered under the UDF's name or its aliases will be overwritten by this new function.
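    ///
    /// A minimal sketch, using the [`create_udf`] helper (assuming the current
    /// `create_udf` signature) to build a UDF that returns its argument
    /// unchanged:
    ///
    /// ```
    /// # use std::sync::Arc;
    /// # use arrow::datatypes::DataType;
    /// # use datafusion::logical_expr::{create_udf, ColumnarValue, Volatility};
    /// # use datafusion::prelude::SessionContext;
    /// let identity = create_udf(
    ///     "my_identity",
    ///     vec![DataType::Int32],
    ///     DataType::Int32,
    ///     Volatility::Immutable,
    ///     Arc::new(|args: &[ColumnarValue]| Ok(args[0].clone())),
    /// );
    /// let ctx = SessionContext::new();
    /// ctx.register_udf(identity);
    /// ```
    ///
    /// [`create_udf`]: crate::logical_expr::create_udf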
    pub fn register_udf(&self, f: ScalarUDF) {
        let mut state = self.state.write();
        state.register_udf(Arc::new(f)).ok();
    }

    /// Registers an aggregate UDF within this context.
    ///
    /// Note in SQL queries, aggregate names are looked up using
    /// lowercase unless the query uses quotes. For example,
    ///
    /// - `SELECT MY_UDAF(x)...` will look for an aggregate named `"my_udaf"`
    /// - `SELECT "my_UDAF"(x)` will look for an aggregate named `"my_UDAF"`
    pub fn register_udaf(&self, f: AggregateUDF) {
        self.state.write().register_udaf(Arc::new(f)).ok();
    }

    /// Registers a window UDF within this context.
    ///
    /// Note in SQL queries, window function names are looked up using
    /// lowercase unless the query uses quotes. For example,
    ///
    /// - `SELECT MY_UDWF(x)...` will look for a window function named `"my_udwf"`
    /// - `SELECT "my_UDWF"(x)` will look for a window function named `"my_UDWF"`
    pub fn register_udwf(&self, f: WindowUDF) {
        self.state.write().register_udwf(Arc::new(f)).ok();
    }

    /// Deregisters a UDF within this context.
    pub fn deregister_udf(&self, name: &str) {
        self.state.write().deregister_udf(name).ok();
    }

    /// Deregisters a UDAF within this context.
    pub fn deregister_udaf(&self, name: &str) {
        self.state.write().deregister_udaf(name).ok();
    }

    /// Deregisters a UDWF within this context.
    pub fn deregister_udwf(&self, name: &str) {
        self.state.write().deregister_udwf(name).ok();
    }

    /// Deregisters a UDTF within this context.
    pub fn deregister_udtf(&self, name: &str) {
        self.state.write().deregister_udtf(name).ok();
    }

    /// Creates a [`DataFrame`] for reading a data source.
    ///
    /// For more control such as reading multiple files, you can use
    /// [`read_table`](Self::read_table) with a [`ListingTable`].
    async fn _read_type<'a, P: DataFilePaths>(
        &self,
        table_paths: P,
        options: impl ReadOptions<'a>,
    ) -> Result<DataFrame> {
        let table_paths = table_paths.to_urls()?;
        let session_config = self.copied_config();
        let listing_options =
            options.to_listing_options(&session_config, self.copied_table_options());

        let option_extension = listing_options.file_extension.clone();

        if table_paths.is_empty() {
            return exec_err!("No table paths were provided");
        }

        // check if the file extension matches the expected extension
        for path in &table_paths {
            let file_path = path.as_str();
            if !file_path.ends_with(option_extension.clone().as_str())
                && !path.is_collection()
            {
                return exec_err!(
                    "File path '{file_path}' does not match the expected extension '{option_extension}'"
                );
            }
        }

        let resolved_schema = options
            .get_resolved_schema(&session_config, self.state(), table_paths[0].clone())
            .await?;
        let config = ListingTableConfig::new_with_multi_paths(table_paths)
            .with_listing_options(listing_options)
            .with_schema(resolved_schema);
        let provider = ListingTable::try_new(config)?;
        self.read_table(Arc::new(provider))
    }

    /// Creates a [`DataFrame`] for reading an Arrow data source.
    ///
    /// For more control such as reading multiple files, you can use
    /// [`read_table`](Self::read_table) with a [`ListingTable`].
    ///
    /// For an example, see [`read_csv`](Self::read_csv)
    pub async fn read_arrow<P: DataFilePaths>(
        &self,
        table_paths: P,
        options: ArrowReadOptions<'_>,
    ) -> Result<DataFrame> {
        self._read_type(table_paths, options).await
    }

    /// Creates an empty DataFrame.
    pub fn read_empty(&self) -> Result<DataFrame> {
        Ok(DataFrame::new(
            self.state(),
            LogicalPlanBuilder::empty(true).build()?,
        ))
    }

    /// Creates a [`DataFrame`] for a [`TableProvider`] such as a
    /// [`ListingTable`] or a custom user defined provider.
    pub fn read_table(&self, provider: Arc<dyn TableProvider>) -> Result<DataFrame> {
        Ok(DataFrame::new(
            self.state(),
            LogicalPlanBuilder::scan(UNNAMED_TABLE, provider_as_source(provider), None)?
                .build()?,
        ))
    }

    /// Creates a [`DataFrame`] for reading a [`RecordBatch`]
    pub fn read_batch(&self, batch: RecordBatch) -> Result<DataFrame> {
        let provider = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
        Ok(DataFrame::new(
            self.state(),
            LogicalPlanBuilder::scan(
                UNNAMED_TABLE,
                provider_as_source(Arc::new(provider)),
                None,
            )?
            .build()?,
        ))
    }

    /// Create a [`DataFrame`] for reading a `Vec` of [`RecordBatch`]es
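    ///
    /// A minimal sketch (all batches must share the same schema):
    ///
    /// ```
    /// # use std::sync::Arc;
    /// # use arrow::array::Int32Array;
    /// # use arrow::datatypes::{DataType, Field, Schema};
    /// # use arrow::record_batch::RecordBatch;
    /// # use datafusion::error::Result;
    /// # use datafusion::prelude::SessionContext;
    /// # fn main() -> Result<()> {
    /// let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    /// let batch =
    ///     RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![1, 2]))])?;
    /// let ctx = SessionContext::new();
    /// let df = ctx.read_batches(vec![batch.clone(), batch])?;
    /// # Ok(())
    /// # }
    /// ```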
    pub fn read_batches(
        &self,
        batches: impl IntoIterator<Item = RecordBatch>,
    ) -> Result<DataFrame> {
        // check schema uniqueness
        let mut batches = batches.into_iter().peekable();
        let schema = if let Some(batch) = batches.peek() {
            batch.schema()
        } else {
            Arc::new(Schema::empty())
        };
        let provider = MemTable::try_new(schema, vec![batches.collect()])?;
        Ok(DataFrame::new(
            self.state(),
            LogicalPlanBuilder::scan(
                UNNAMED_TABLE,
                provider_as_source(Arc::new(provider)),
                None,
            )?
            .build()?,
        ))
    }

1431    /// Registers a [`ListingTable`] that can assemble multiple files
1432    /// from locations in an [`ObjectStore`] instance into a single
1433    /// table.
1434    ///
1435    /// This method is `async` because it might need to resolve the schema.
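    ///
    /// # Example
    ///
    /// A sketch, assuming CSV files live under the hypothetical path `data/`:
    ///
    /// ```no_run
    /// use std::sync::Arc;
    /// use datafusion::datasource::file_format::csv::CsvFormat;
    /// use datafusion::datasource::listing::ListingOptions;
    /// use datafusion::prelude::*;
    ///
    /// # #[tokio::main]
    /// # async fn main() -> datafusion::error::Result<()> {
    /// let ctx = SessionContext::new();
    /// let options =
    ///     ListingOptions::new(Arc::new(CsvFormat::default())).with_file_extension(".csv");
    /// // infer the schema (`None`) and provide no SQL definition (`None`)
    /// ctx.register_listing_table("my_table", "data/", options, None, None)
    ///     .await?;
    /// # Ok(())
    /// # }
    /// ```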
1436    ///
1437    /// [`ObjectStore`]: object_store::ObjectStore
1438    pub async fn register_listing_table(
1439        &self,
1440        table_ref: impl Into<TableReference>,
1441        table_path: impl AsRef<str>,
1442        options: ListingOptions,
1443        provided_schema: Option<SchemaRef>,
1444        sql_definition: Option<String>,
1445    ) -> Result<()> {
1446        let table_path = ListingTableUrl::parse(table_path)?;
1447        let resolved_schema = match provided_schema {
1448            Some(s) => s,
1449            None => options.infer_schema(&self.state(), &table_path).await?,
1450        };
1451        let config = ListingTableConfig::new(table_path)
1452            .with_listing_options(options)
1453            .with_schema(resolved_schema);
1454        let table = ListingTable::try_new(config)?.with_definition(sql_definition);
1455        self.register_table(table_ref, Arc::new(table))?;
1456        Ok(())
1457    }
1458
1459    fn register_type_check<P: DataFilePaths>(
1460        &self,
1461        table_paths: P,
1462        extension: impl AsRef<str>,
1463    ) -> Result<()> {
1464        let table_paths = table_paths.to_urls()?;
1465        if table_paths.is_empty() {
1466            return exec_err!("No table paths were provided");
1467        }
1468
1469        // check if the file extension matches the expected extension
1470        let extension = extension.as_ref();
1471        for path in &table_paths {
1472            let file_path = path.as_str();
1473            if !file_path.ends_with(extension) && !path.is_collection() {
1474                return exec_err!(
1475                    "File path '{file_path}' does not match the expected extension '{extension}'"
1476                );
1477            }
1478        }
1479        Ok(())
1480    }
1481
1482    /// Registers an Arrow file as a table that can be referenced from
1483    /// SQL statements executed against this context.
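    ///
    /// A sketch, assuming a hypothetical file `data.arrow` exists:
    ///
    /// ```no_run
    /// use datafusion::execution::options::ArrowReadOptions;
    /// use datafusion::prelude::*;
    ///
    /// # #[tokio::main]
    /// # async fn main() -> datafusion::error::Result<()> {
    /// let ctx = SessionContext::new();
    /// ctx.register_arrow("my_table", "data.arrow", ArrowReadOptions::default())
    ///     .await?;
    /// # Ok(())
    /// # }
    /// ```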
1484    pub async fn register_arrow(
1485        &self,
1486        name: &str,
1487        table_path: &str,
1488        options: ArrowReadOptions<'_>,
1489    ) -> Result<()> {
1490        let listing_options = options
1491            .to_listing_options(&self.copied_config(), self.copied_table_options());
1492
1493        self.register_listing_table(
1494            name,
1495            table_path,
1496            listing_options,
1497            options.schema.map(|s| Arc::new(s.to_owned())),
1498            None,
1499        )
1500        .await?;
1501        Ok(())
1502    }
1503
1504    /// Registers a named catalog using a custom `CatalogProvider` so that
1505    /// it can be referenced from SQL statements executed against this
1506    /// context.
1507    ///
1508    /// Returns the [`CatalogProvider`] previously registered for this
1509    /// name, if any
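    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the `MemoryCatalogProvider` re-export at
    /// `datafusion::catalog`:
    ///
    /// ```
    /// use std::sync::Arc;
    /// use datafusion::catalog::MemoryCatalogProvider;
    /// use datafusion::prelude::*;
    ///
    /// let ctx = SessionContext::new();
    /// let previous =
    ///     ctx.register_catalog("my_catalog", Arc::new(MemoryCatalogProvider::new()));
    /// // nothing was registered under this name before
    /// assert!(previous.is_none());
    /// ```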
1510    pub fn register_catalog(
1511        &self,
1512        name: impl Into<String>,
1513        catalog: Arc<dyn CatalogProvider>,
1514    ) -> Option<Arc<dyn CatalogProvider>> {
1515        let name = name.into();
1516        self.state
1517            .read()
1518            .catalog_list()
1519            .register_catalog(name, catalog)
1520    }
1521
1522    /// Retrieves the list of available catalog names.
1523    pub fn catalog_names(&self) -> Vec<String> {
1524        self.state.read().catalog_list().catalog_names()
1525    }
1526
1527    /// Retrieves a [`CatalogProvider`] instance by name
1528    pub fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
1529        self.state.read().catalog_list().catalog(name)
1530    }
1531
1532    /// Registers a [`TableProvider`] as a table that can be
1533    /// referenced from SQL statements executed against this context.
1534    ///
1535    /// If a table with the same name was already registered, returns a
1536    /// "Table already exists" error.
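    ///
    /// # Example
    ///
    /// A sketch registering an in-memory table; schema and data are
    /// illustrative assumptions:
    ///
    /// ```
    /// use std::sync::Arc;
    /// use arrow::array::Int32Array;
    /// use arrow::datatypes::{DataType, Field, Schema};
    /// use arrow::record_batch::RecordBatch;
    /// use datafusion::datasource::MemTable;
    /// use datafusion::prelude::*;
    ///
    /// let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    /// let batch =
    ///     RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(Int32Array::from(vec![1]))])?;
    /// let provider = MemTable::try_new(schema, vec![vec![batch]])?;
    ///
    /// let ctx = SessionContext::new();
    /// ctx.register_table("my_table", Arc::new(provider))?;
    /// # Ok::<(), datafusion::error::DataFusionError>(())
    /// ```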
1537    pub fn register_table(
1538        &self,
1539        table_ref: impl Into<TableReference>,
1540        provider: Arc<dyn TableProvider>,
1541    ) -> Result<Option<Arc<dyn TableProvider>>> {
1542        let table_ref: TableReference = table_ref.into();
1543        let table = table_ref.table().to_owned();
1544        self.state
1545            .read()
1546            .schema_for_ref(table_ref)?
1547            .register_table(table, provider)
1548    }
1549
1550    /// Deregisters the given table.
1551    ///
1552    /// Returns the registered provider, if any
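    ///
    /// A minimal sketch:
    ///
    /// ```
    /// use datafusion::prelude::*;
    ///
    /// let ctx = SessionContext::new();
    /// // nothing is registered under this name, so `None` is returned
    /// assert!(ctx.deregister_table("my_table")?.is_none());
    /// # Ok::<(), datafusion::error::DataFusionError>(())
    /// ```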
1553    pub fn deregister_table(
1554        &self,
1555        table_ref: impl Into<TableReference>,
1556    ) -> Result<Option<Arc<dyn TableProvider>>> {
1557        let table_ref = table_ref.into();
1558        let table = table_ref.table().to_owned();
1559        self.state
1560            .read()
1561            .schema_for_ref(table_ref)?
1562            .deregister_table(&table)
1563    }
1564
1565    /// Return `true` if the specified table exists in the schema provider.
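    ///
    /// A minimal sketch:
    ///
    /// ```
    /// use datafusion::prelude::*;
    ///
    /// let ctx = SessionContext::new();
    /// // no table has been registered yet
    /// assert!(!ctx.table_exist("my_table")?);
    /// # Ok::<(), datafusion::error::DataFusionError>(())
    /// ```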
1566    pub fn table_exist(&self, table_ref: impl Into<TableReference>) -> Result<bool> {
1567        let table_ref: TableReference = table_ref.into();
1568        let table = table_ref.table();
1569        let table_ref = table_ref.clone();
1570        Ok(self
1571            .state
1572            .read()
1573            .schema_for_ref(table_ref)?
1574            .table_exist(table))
1575    }
1576
1577    /// Retrieves a [`DataFrame`] representing a table previously
1578    /// registered by calling the [`register_table`] function.
1579    ///
1580    /// Returns an error if no table has been registered with the
1581    /// provided reference.
1582    ///
1583    /// [`register_table`]: SessionContext::register_table
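    ///
    /// # Example
    ///
    /// A sketch, assuming a table named `my_table` was registered earlier
    /// (hence `no_run`):
    ///
    /// ```no_run
    /// use datafusion::prelude::*;
    ///
    /// # #[tokio::main]
    /// # async fn main() -> datafusion::error::Result<()> {
    /// let ctx = SessionContext::new();
    /// let df = ctx.table("my_table").await?;
    /// df.show().await?;
    /// # Ok(())
    /// # }
    /// ```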
1584    pub async fn table(&self, table_ref: impl Into<TableReference>) -> Result<DataFrame> {
1585        let table_ref: TableReference = table_ref.into();
1586        let provider = self.table_provider(table_ref.clone()).await?;
1587        let plan = LogicalPlanBuilder::scan(
1588            table_ref,
1589            provider_as_source(Arc::clone(&provider)),
1590            None,
1591        )?
1592        .build()?;
1593        Ok(DataFrame::new(self.state(), plan))
1594    }
1595
1596    /// Retrieves a [`TableFunction`] reference by name.
1597    ///
1598    /// Returns an error if no table function has been registered with the provided name; see [`register_udtf`] to register one.
1599    ///
1600    /// [`register_udtf`]: SessionContext::register_udtf
1601    pub fn table_function(&self, name: &str) -> Result<Arc<TableFunction>> {
1602        self.state
1603            .read()
1604            .table_functions()
1605            .get(name)
1606            .cloned()
1607            .ok_or_else(|| plan_datafusion_err!("Table function '{name}' not found"))
1608    }
1609
1610    /// Return a [`TableProvider`] for the specified table.
1611    pub async fn table_provider(
1612        &self,
1613        table_ref: impl Into<TableReference>,
1614    ) -> Result<Arc<dyn TableProvider>> {
1615        let table_ref = table_ref.into();
1616        let table = table_ref.table().to_string();
1617        let schema = self.state.read().schema_for_ref(table_ref)?;
1618        match schema.table(&table).await? {
1619            Some(ref provider) => Ok(Arc::clone(provider)),
1620            _ => plan_err!("No table named '{table}'"),
1621        }
1622    }
1623
1624    /// Get a new [`TaskContext`] to run in this session
1625    pub fn task_ctx(&self) -> Arc<TaskContext> {
1626        Arc::new(TaskContext::from(self))
1627    }
1628
1629    /// Return a new [`SessionState`] suitable for executing a single query.
1630    ///
1631    /// Notes:
1632    ///
1633    /// 1. `query_execution_start_time` is set to the current time for the
1634    ///    returned state.
1635    ///
1636    /// 2. The returned state is not shared with the current session state,
1637    ///    so changes to the returned `SessionState`, such as modifying
1638    ///    [`ConfigOptions`], will not be reflected in this `SessionContext`.
1639    ///
1640    /// [`ConfigOptions`]: crate::config::ConfigOptions
1641    pub fn state(&self) -> SessionState {
1642        let mut state = self.state.read().clone();
1643        state.execution_props_mut().start_execution();
1644        state
1645    }
1646
1647    /// Get reference to [`SessionState`]
1648    pub fn state_ref(&self) -> Arc<RwLock<SessionState>> {
1649        Arc::clone(&self.state)
1650    }
1651
1652    /// Get weak reference to [`SessionState`]
1653    pub fn state_weak_ref(&self) -> Weak<RwLock<SessionState>> {
1654        Arc::downgrade(&self.state)
1655    }
1656
1657    /// Register [`CatalogProviderList`] in [`SessionState`]
1658    pub fn register_catalog_list(&self, catalog_list: Arc<dyn CatalogProviderList>) {
1659        self.state.write().register_catalog_list(catalog_list)
1660    }
1661
1662    /// Registers a [`ConfigExtension`] as a table option extension that can be
1663    /// referenced from SQL statements executed against this context.
1664    pub fn register_table_options_extension<T: ConfigExtension>(&self, extension: T) {
1665        self.state
1666            .write()
1667            .register_table_options_extension(extension)
1668    }
1669}
1670
1671impl FunctionRegistry for SessionContext {
1672    fn udfs(&self) -> HashSet<String> {
1673        self.state.read().udfs()
1674    }
1675
1676    fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> {
1677        self.state.read().udf(name)
1678    }
1679
1680    fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
1681        self.state.read().udaf(name)
1682    }
1683
1684    fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
1685        self.state.read().udwf(name)
1686    }
1687
1688    fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
1689        self.state.write().register_udf(udf)
1690    }
1691
1692    fn register_udaf(
1693        &mut self,
1694        udaf: Arc<AggregateUDF>,
1695    ) -> Result<Option<Arc<AggregateUDF>>> {
1696        self.state.write().register_udaf(udaf)
1697    }
1698
1699    fn register_udwf(&mut self, udwf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
1700        self.state.write().register_udwf(udwf)
1701    }
1702
1703    fn register_function_rewrite(
1704        &mut self,
1705        rewrite: Arc<dyn FunctionRewrite + Send + Sync>,
1706    ) -> Result<()> {
1707        self.state.write().register_function_rewrite(rewrite)
1708    }
1709
1710    fn expr_planners(&self) -> Vec<Arc<dyn ExprPlanner>> {
1711        self.state.read().expr_planners().to_vec()
1712    }
1713
1714    fn register_expr_planner(
1715        &mut self,
1716        expr_planner: Arc<dyn ExprPlanner>,
1717    ) -> Result<()> {
1718        self.state.write().register_expr_planner(expr_planner)
1719    }
1720}
1721
1722/// Create a new task context instance from SessionContext
1723impl From<&SessionContext> for TaskContext {
1724    fn from(session: &SessionContext) -> Self {
1725        TaskContext::from(&*session.state.read())
1726    }
1727}
1728
1729impl From<SessionState> for SessionContext {
1730    fn from(state: SessionState) -> Self {
1731        Self::new_with_state(state)
1732    }
1733}
1734
1735impl From<SessionContext> for SessionStateBuilder {
1736    fn from(session: SessionContext) -> Self {
1737        session.into_state_builder()
1738    }
1739}
1740
1741/// A planner used to add extensions to DataFusion logical and physical plans.
1742#[async_trait]
1743pub trait QueryPlanner: Debug {
1744    /// Given a `LogicalPlan`, create an [`ExecutionPlan`] suitable for execution
1745    async fn create_physical_plan(
1746        &self,
1747        logical_plan: &LogicalPlan,
1748        session_state: &SessionState,
1749    ) -> Result<Arc<dyn ExecutionPlan>>;
1750}
1751
1752/// A pluggable interface to handle `CREATE FUNCTION` statements
1753/// and interact with [SessionState] to register new UDFs, UDAFs, or UDWFs.
1755#[async_trait]
1756pub trait FunctionFactory: Debug + Sync + Send {
1757    /// Handles creation of the user defined function specified in a [CreateFunction] statement
1758    async fn create(
1759        &self,
1760        state: &SessionState,
1761        statement: CreateFunction,
1762    ) -> Result<RegisterFunction>;
1763}
1764
1765/// Type of function to create
1766pub enum RegisterFunction {
1767    /// Scalar user defined function
1768    Scalar(Arc<ScalarUDF>),
1769    /// Aggregate user defined function
1770    Aggregate(Arc<AggregateUDF>),
1771    /// Window user defined function
1772    Window(Arc<WindowUDF>),
1773    /// Table user defined function
1774    Table(String, Arc<dyn TableFunctionImpl>),
1775}
1776
1777/// Default implementation of [SerializerRegistry] that returns a "not implemented"
1778/// error for all requests.
1779#[derive(Debug)]
1780pub struct EmptySerializerRegistry;
1781
1782impl SerializerRegistry for EmptySerializerRegistry {
1783    fn serialize_logical_plan(
1784        &self,
1785        node: &dyn UserDefinedLogicalNode,
1786    ) -> Result<Vec<u8>> {
1787        not_impl_err!(
1788            "Serializing user defined logical plan node `{}` is not supported",
1789            node.name()
1790        )
1791    }
1792
1793    fn deserialize_logical_plan(
1794        &self,
1795        name: &str,
1796        _bytes: &[u8],
1797    ) -> Result<Arc<dyn UserDefinedLogicalNode>> {
1798        not_impl_err!(
1799            "Deserializing user defined logical plan node `{name}` is not supported"
1800        )
1801    }
1802}
1803
1804/// Describes which SQL statements can be run.
1805///
1806/// See [`SessionContext::sql_with_options`] for more details.
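///
/// # Example
///
/// A minimal sketch that rejects DDL:
///
/// ```
/// use datafusion::execution::context::SQLOptions;
/// use datafusion::prelude::*;
///
/// # #[tokio::main]
/// # async fn main() {
/// let ctx = SessionContext::new();
/// let options = SQLOptions::new().with_allow_ddl(false);
/// // `CREATE TABLE` is DDL, so planning it fails with these options
/// let result = ctx.sql_with_options("CREATE TABLE t (a INT)", options).await;
/// assert!(result.is_err());
/// # }
/// ```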
1807#[derive(Clone, Debug, Copy)]
1808pub struct SQLOptions {
1809    /// See [`Self::with_allow_ddl`]
1810    allow_ddl: bool,
1811    /// See [`Self::with_allow_dml`]
1812    allow_dml: bool,
1813    /// See [`Self::with_allow_statements`]
1814    allow_statements: bool,
1815}
1816
1817impl Default for SQLOptions {
1818    fn default() -> Self {
1819        Self {
1820            allow_ddl: true,
1821            allow_dml: true,
1822            allow_statements: true,
1823        }
1824    }
1825}
1826
1827impl SQLOptions {
1828    /// Create a new `SQLOptions` with default values
1829    pub fn new() -> Self {
1830        Default::default()
1831    }
1832
1833    /// Should DDL (Data Definition Language) commands (e.g. `CREATE TABLE`) be run? Defaults to `true`.
1834    pub fn with_allow_ddl(mut self, allow: bool) -> Self {
1835        self.allow_ddl = allow;
1836        self
1837    }
1838
1839    /// Should DML (Data Manipulation Language) commands (e.g. `INSERT` and `COPY`) be run? Defaults to `true`.
1840    pub fn with_allow_dml(mut self, allow: bool) -> Self {
1841        self.allow_dml = allow;
1842        self
1843    }
1844
1845    /// Should statements (e.g. `SET VARIABLE` and `BEGIN TRANSACTION`) be run? Defaults to `true`.
1846    pub fn with_allow_statements(mut self, allow: bool) -> Self {
1847        self.allow_statements = allow;
1848        self
1849    }
1850
1851    /// Return an error if the [`LogicalPlan`] has any nodes that are
1852    /// incompatible with this [`SQLOptions`].
1853    pub fn verify_plan(&self, plan: &LogicalPlan) -> Result<()> {
1854        plan.visit_with_subqueries(&mut BadPlanVisitor::new(self))?;
1855        Ok(())
1856    }
1857}
1858
1859struct BadPlanVisitor<'a> {
1860    options: &'a SQLOptions,
1861}
1862impl<'a> BadPlanVisitor<'a> {
1863    fn new(options: &'a SQLOptions) -> Self {
1864        Self { options }
1865    }
1866}
1867
1868impl<'n> TreeNodeVisitor<'n> for BadPlanVisitor<'_> {
1869    type Node = LogicalPlan;
1870
1871    fn f_down(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
1872        match node {
1873            LogicalPlan::Ddl(ddl) if !self.options.allow_ddl => {
1874                plan_err!("DDL not supported: {}", ddl.name())
1875            }
1876            LogicalPlan::Dml(dml) if !self.options.allow_dml => {
1877                plan_err!("DML not supported: {}", dml.op)
1878            }
1879            LogicalPlan::Copy(_) if !self.options.allow_dml => {
1880                plan_err!("DML not supported: COPY")
1881            }
1882            LogicalPlan::Statement(stmt) if !self.options.allow_statements => {
1883                plan_err!("Statement not supported: {}", stmt.name())
1884            }
1885            _ => Ok(TreeNodeRecursion::Continue),
1886        }
1887    }
1888}
1889
1890#[cfg(test)]
1891mod tests {
1892    use super::{super::options::CsvReadOptions, *};
1893    use crate::execution::memory_pool::MemoryConsumer;
1894    use crate::test;
1895    use crate::test_util::{plan_and_collect, populate_csv_partitions};
1896    use arrow::datatypes::{DataType, TimeUnit};
1897    use std::error::Error;
1898    use std::path::PathBuf;
1899
1900    use datafusion_common::test_util::batches_to_string;
1901    use datafusion_common_runtime::SpawnedTask;
1902    use insta::{allow_duplicates, assert_snapshot};
1903
1904    use crate::catalog::SchemaProvider;
1905    use crate::execution::session_state::SessionStateBuilder;
1906    use crate::physical_planner::PhysicalPlanner;
1907    use async_trait::async_trait;
1908    use datafusion_expr::planner::TypePlanner;
1909    use sqlparser::ast;
1910    use tempfile::TempDir;
1911
1912    #[tokio::test]
1913    async fn shared_memory_and_disk_manager() {
1914        // Demonstrate the ability to share DiskManager and
1915        // MemoryPool between two different executions.
1916        let ctx1 = SessionContext::new();
1917
1918        // configure with same memory / disk manager
1919        let memory_pool = ctx1.runtime_env().memory_pool.clone();
1920
1921        let mut reservation = MemoryConsumer::new("test").register(&memory_pool);
1922        reservation.grow(100);
1923
1924        let disk_manager = ctx1.runtime_env().disk_manager.clone();
1925
1926        let ctx2 =
1927            SessionContext::new_with_config_rt(SessionConfig::new(), ctx1.runtime_env());
1928
1929        assert_eq!(ctx1.runtime_env().memory_pool.reserved(), 100);
1930        assert_eq!(ctx2.runtime_env().memory_pool.reserved(), 100);
1931
1932        drop(reservation);
1933
1934        assert_eq!(ctx1.runtime_env().memory_pool.reserved(), 0);
1935        assert_eq!(ctx2.runtime_env().memory_pool.reserved(), 0);
1936
1937        assert!(std::ptr::eq(
1938            Arc::as_ptr(&disk_manager),
1939            Arc::as_ptr(&ctx1.runtime_env().disk_manager)
1940        ));
1941        assert!(std::ptr::eq(
1942            Arc::as_ptr(&disk_manager),
1943            Arc::as_ptr(&ctx2.runtime_env().disk_manager)
1944        ));
1945    }
1946
1947    #[tokio::test]
1948    async fn create_variable_expr() -> Result<()> {
1949        let tmp_dir = TempDir::new()?;
1950        let partition_count = 4;
1951        let ctx = create_ctx(&tmp_dir, partition_count).await?;
1952
1953        let variable_provider = test::variable::SystemVar::new();
1954        ctx.register_variable(VarType::System, Arc::new(variable_provider));
1955        let variable_provider = test::variable::UserDefinedVar::new();
1956        ctx.register_variable(VarType::UserDefined, Arc::new(variable_provider));
1957
1958        let provider = test::create_table_dual();
1959        ctx.register_table("dual", provider)?;
1960
1961        let results =
1962            plan_and_collect(&ctx, "SELECT @@version, @name, @integer + 1 FROM dual")
1963                .await?;
1964
1965        assert_snapshot!(batches_to_string(&results), @r"
1966        +----------------------+------------------------+---------------------+
1967        | @@version            | @name                  | @integer + Int64(1) |
1968        +----------------------+------------------------+---------------------+
1969        | system-var-@@version | user-defined-var-@name | 42                  |
1970        +----------------------+------------------------+---------------------+
1971        ");
1972
1973        Ok(())
1974    }
1975
1976    #[tokio::test]
1977    async fn create_variable_err() -> Result<()> {
1978        let ctx = SessionContext::new();
1979
1980        let err = plan_and_collect(&ctx, "SElECT @=   X3").await.unwrap_err();
1981        assert_eq!(
1982            err.strip_backtrace(),
1983            "Error during planning: variable [\"@=\"] has no type information"
1984        );
1985        Ok(())
1986    }
1987
1988    #[tokio::test]
1989    async fn register_deregister() -> Result<()> {
1990        let tmp_dir = TempDir::new()?;
1991        let partition_count = 4;
1992        let ctx = create_ctx(&tmp_dir, partition_count).await?;
1993
1994        let provider = test::create_table_dual();
1995        ctx.register_table("dual", provider)?;
1996
1997        assert!(ctx.deregister_table("dual")?.is_some());
1998        assert!(ctx.deregister_table("dual")?.is_none());
1999
2000        Ok(())
2001    }
2002
2003    #[tokio::test]
2004    async fn send_context_to_threads() -> Result<()> {
2005        // ensure SessionContexts can be used in a multi-threaded
2006        // environment. The use case is concurrent planning.
2007        let tmp_dir = TempDir::new()?;
2008        let partition_count = 4;
2009        let ctx = Arc::new(create_ctx(&tmp_dir, partition_count).await?);
2010
2011        let threads: Vec<_> = (0..2)
2012            .map(|_| ctx.clone())
2013            .map(|ctx| {
2014                SpawnedTask::spawn(async move {
2015                    // Ensure we can create logical plan code on a separate thread.
2016                    ctx.sql("SELECT c1, c2 FROM test WHERE c1 > 0 AND c1 < 3")
2017                        .await
2018                })
2019            })
2020            .collect();
2021
2022        for handle in threads {
2023            handle.join().await.unwrap().unwrap();
2024        }
2025        Ok(())
2026    }
2027
2028    #[tokio::test]
2029    async fn with_listing_schema_provider() -> Result<()> {
2030        let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
2031        let path = path.join("tests/tpch-csv");
2032        let url = format!("file://{}", path.display());
2033
2034        let cfg = SessionConfig::new()
2035            .set_str("datafusion.catalog.location", url.as_str())
2036            .set_str("datafusion.catalog.format", "CSV")
2037            .set_str("datafusion.catalog.has_header", "true");
2038        let session_state = SessionStateBuilder::new()
2039            .with_config(cfg)
2040            .with_default_features()
2041            .build();
2042        let ctx = SessionContext::new_with_state(session_state);
2043        ctx.refresh_catalogs().await?;
2044
2045        let result =
2046            plan_and_collect(&ctx, "select c_name from default.customer limit 3;")
2047                .await?;
2048
2049        let actual = arrow::util::pretty::pretty_format_batches(&result)
2050            .unwrap()
2051            .to_string();
2052        assert_snapshot!(actual, @r"
2053        +--------------------+
2054        | c_name             |
2055        +--------------------+
2056        | Customer#000000002 |
2057        | Customer#000000003 |
2058        | Customer#000000004 |
2059        +--------------------+
2060        ");
2061
2062        Ok(())
2063    }
2064
2065    #[tokio::test]
2066    async fn test_dynamic_file_query() -> Result<()> {
2067        let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
2068        let path = path.join("tests/tpch-csv/customer.csv");
2069        let url = format!("file://{}", path.display());
2070        let cfg = SessionConfig::new();
2071        let session_state = SessionStateBuilder::new()
2072            .with_default_features()
2073            .with_config(cfg)
2074            .build();
2075        let ctx = SessionContext::new_with_state(session_state).enable_url_table();
2076        let result = plan_and_collect(
2077            &ctx,
2078            format!("select c_name from '{}' limit 3;", &url).as_str(),
2079        )
2080        .await?;
2081
2082        let actual = arrow::util::pretty::pretty_format_batches(&result)
2083            .unwrap()
2084            .to_string();
2085        assert_snapshot!(actual, @r"
2086        +--------------------+
2087        | c_name             |
2088        +--------------------+
2089        | Customer#000000002 |
2090        | Customer#000000003 |
2091        | Customer#000000004 |
2092        +--------------------+
2093        ");
2094
2095        Ok(())
2096    }
2097
2098    #[tokio::test]
2099    async fn custom_query_planner() -> Result<()> {
2100        let runtime = Arc::new(RuntimeEnv::default());
2101        let session_state = SessionStateBuilder::new()
2102            .with_config(SessionConfig::new())
2103            .with_runtime_env(runtime)
2104            .with_default_features()
2105            .with_query_planner(Arc::new(MyQueryPlanner {}))
2106            .build();
2107        let ctx = SessionContext::new_with_state(session_state);
2108
2109        let df = ctx.sql("SELECT 1").await?;
2110        df.collect().await.expect_err("query not supported");
2111        Ok(())
2112    }
2113
2114    #[tokio::test]
2115    async fn disabled_default_catalog_and_schema() -> Result<()> {
2116        let ctx = SessionContext::new_with_config(
2117            SessionConfig::new().with_create_default_catalog_and_schema(false),
2118        );
2119
2120        assert!(matches!(
2121            ctx.register_table("test", test::table_with_sequence(1, 1)?),
2122            Err(DataFusionError::Plan(_))
2123        ));
2124
2125        let err = ctx
2126            .sql("select * from datafusion.public.test")
2127            .await
2128            .unwrap_err();
2129        let err = err
2130            .source()
2131            .and_then(|err| err.downcast_ref::<DataFusionError>())
2132            .unwrap();
2133
2134        assert!(matches!(err, &DataFusionError::Plan(_)));
2135
2136        Ok(())
2137    }
2138
2139    #[tokio::test]
2140    async fn custom_catalog_and_schema() {
2141        let config = SessionConfig::new()
2142            .with_create_default_catalog_and_schema(true)
2143            .with_default_catalog_and_schema("my_catalog", "my_schema");
2144        catalog_and_schema_test(config).await;
2145    }
2146
2147    #[tokio::test]
2148    async fn custom_catalog_and_schema_no_default() {
2149        let config = SessionConfig::new()
2150            .with_create_default_catalog_and_schema(false)
2151            .with_default_catalog_and_schema("my_catalog", "my_schema");
2152        catalog_and_schema_test(config).await;
2153    }
2154
2155    #[tokio::test]
2156    async fn custom_catalog_and_schema_and_information_schema() {
2157        let config = SessionConfig::new()
2158            .with_create_default_catalog_and_schema(true)
2159            .with_information_schema(true)
2160            .with_default_catalog_and_schema("my_catalog", "my_schema");
2161        catalog_and_schema_test(config).await;
2162    }
2163
2164    async fn catalog_and_schema_test(config: SessionConfig) {
2165        let ctx = SessionContext::new_with_config(config);
2166        let catalog = MemoryCatalogProvider::new();
2167        let schema = MemorySchemaProvider::new();
2168        schema
2169            .register_table("test".to_owned(), test::table_with_sequence(1, 1).unwrap())
2170            .unwrap();
2171        catalog
2172            .register_schema("my_schema", Arc::new(schema))
2173            .unwrap();
2174        ctx.register_catalog("my_catalog", Arc::new(catalog));
2175
2176        let mut results = Vec::new();
2177
2178        for table_ref in &["my_catalog.my_schema.test", "my_schema.test", "test"] {
2179            let result = plan_and_collect(
2180                &ctx,
2181                &format!("SELECT COUNT(*) AS count FROM {table_ref}"),
2182            )
2183            .await
2184            .unwrap();
2185
2186            results.push(result);
2187        }
2188        allow_duplicates! {
2189            for result in &results {
2190                assert_snapshot!(batches_to_string(result), @r"
2191                +-------+
2192                | count |
2193                +-------+
2194                | 1     |
2195                +-------+
2196                ");
2197            }
2198        }
2199    }
2200
2201    #[tokio::test]
2202    async fn cross_catalog_access() -> Result<()> {
2203        let ctx = SessionContext::new();
2204
2205        let catalog_a = MemoryCatalogProvider::new();
2206        let schema_a = MemorySchemaProvider::new();
2207        schema_a
2208            .register_table("table_a".to_owned(), test::table_with_sequence(1, 1)?)?;
2209        catalog_a.register_schema("schema_a", Arc::new(schema_a))?;
2210        ctx.register_catalog("catalog_a", Arc::new(catalog_a));
2211
2212        let catalog_b = MemoryCatalogProvider::new();
2213        let schema_b = MemorySchemaProvider::new();
2214        schema_b
2215            .register_table("table_b".to_owned(), test::table_with_sequence(1, 2)?)?;
2216        catalog_b.register_schema("schema_b", Arc::new(schema_b))?;
2217        ctx.register_catalog("catalog_b", Arc::new(catalog_b));
2218
2219        let result = plan_and_collect(
2220            &ctx,
2221            "SELECT cat, SUM(i) AS total FROM (
2222                    SELECT i, 'a' AS cat FROM catalog_a.schema_a.table_a
2223                    UNION ALL
2224                    SELECT i, 'b' AS cat FROM catalog_b.schema_b.table_b
2225                ) AS all
2226                GROUP BY cat
2227                ORDER BY cat
2228                ",
2229        )
2230        .await?;
2231
2232        assert_snapshot!(batches_to_string(&result), @r"
2233        +-----+-------+
2234        | cat | total |
2235        +-----+-------+
2236        | a   | 1     |
2237        | b   | 3     |
2238        +-----+-------+
2239        ");
2240
2241        Ok(())
2242    }
2243
2244    #[tokio::test]
2245    async fn catalogs_not_leaked() {
2246        // the information schema used to introduce cyclic Arcs
2247        let ctx = SessionContext::new_with_config(
2248            SessionConfig::new().with_information_schema(true),
2249        );
2250
2251        // register a single catalog
2252        let catalog = Arc::new(MemoryCatalogProvider::new());
2253        let catalog_weak = Arc::downgrade(&catalog);
2254        ctx.register_catalog("my_catalog", catalog);
2255
2256        let catalog_list_weak = {
2257            let state = ctx.state.read();
2258            Arc::downgrade(state.catalog_list())
2259        };
2260
2261        drop(ctx);
2262
2263        assert_eq!(Weak::strong_count(&catalog_list_weak), 0);
2264        assert_eq!(Weak::strong_count(&catalog_weak), 0);
2265    }
2266
2267    #[tokio::test]
2268    async fn sql_create_schema() -> Result<()> {
2269        // the information schema used to introduce cyclic Arcs
2270        let ctx = SessionContext::new_with_config(
2271            SessionConfig::new().with_information_schema(true),
2272        );
2273
2274        // Create schema
2275        ctx.sql("CREATE SCHEMA abc").await?.collect().await?;
2276
2277        // Add table to schema
2278        ctx.sql("CREATE TABLE abc.y AS VALUES (1,2,3)")
2279            .await?
2280            .collect()
2281            .await?;
2282
2283        // Check table exists in schema
2284        let results = ctx.sql("SELECT * FROM information_schema.tables WHERE table_schema='abc' AND table_name = 'y'").await.unwrap().collect().await.unwrap();
2285
2286        assert_eq!(results[0].num_rows(), 1);
2287        Ok(())
2288    }
2289
2290    #[tokio::test]
2291    async fn sql_create_catalog() -> Result<()> {
2292        // the information schema used to introduce cyclic Arcs
2293        let ctx = SessionContext::new_with_config(
2294            SessionConfig::new().with_information_schema(true),
2295        );
2296
2297        // Create catalog
2298        ctx.sql("CREATE DATABASE test").await?.collect().await?;
2299
2300        // Create schema
2301        ctx.sql("CREATE SCHEMA test.abc").await?.collect().await?;
2302
2303        // Add table to schema
2304        ctx.sql("CREATE TABLE test.abc.y AS VALUES (1,2,3)")
2305            .await?
2306            .collect()
2307            .await?;
2308
2309        // Check table exists in schema
2310        let results = ctx.sql("SELECT * FROM information_schema.tables WHERE table_catalog='test' AND table_schema='abc' AND table_name = 'y'").await.unwrap().collect().await.unwrap();
2311
2312        assert_eq!(results[0].num_rows(), 1);
2313        Ok(())
2314    }
2315
2316    #[tokio::test]
2317    async fn custom_type_planner() -> Result<()> {
2318        let state = SessionStateBuilder::new()
2319            .with_default_features()
2320            .with_type_planner(Arc::new(MyTypePlanner {}))
2321            .build();
2322        let ctx = SessionContext::new_with_state(state);
2323        let result = ctx
2324            .sql("SELECT DATETIME '2021-01-01 00:00:00'")
2325            .await?
2326            .collect()
2327            .await?;
2328        assert_snapshot!(batches_to_string(&result), @r#"
2329        +-----------------------------+
2330        | Utf8("2021-01-01 00:00:00") |
2331        +-----------------------------+
2332        | 2021-01-01T00:00:00         |
2333        +-----------------------------+
2334        "#);
2335        Ok(())
2336    }
2337    #[test]
2338    fn preserve_session_context_id() -> Result<()> {
2339        let ctx = SessionContext::new();
2340        // It makes sense to preserve the session id in this case, as
2341        // `enable_url_table()` can be seen as an additional configuration
2342        // option on ctx.
2343        // Some systems, like DataFusion Ballista, rely on a stable session_id.
2344        assert_eq!(ctx.session_id(), ctx.enable_url_table().session_id());
2345        Ok(())
2346    }
2347
2348    struct MyPhysicalPlanner {}
2349
2350    #[async_trait]
2351    impl PhysicalPlanner for MyPhysicalPlanner {
2352        async fn create_physical_plan(
2353            &self,
2354            _logical_plan: &LogicalPlan,
2355            _session_state: &SessionState,
2356        ) -> Result<Arc<dyn ExecutionPlan>> {
2357            not_impl_err!("query not supported")
2358        }
2359
2360        fn create_physical_expr(
2361            &self,
2362            _expr: &Expr,
2363            _input_dfschema: &DFSchema,
2364            _session_state: &SessionState,
2365        ) -> Result<Arc<dyn PhysicalExpr>> {
2366            unimplemented!()
2367        }
2368    }
2369
2370    #[derive(Debug)]
2371    struct MyQueryPlanner {}
2372
2373    #[async_trait]
2374    impl QueryPlanner for MyQueryPlanner {
2375        async fn create_physical_plan(
2376            &self,
2377            logical_plan: &LogicalPlan,
2378            session_state: &SessionState,
2379        ) -> Result<Arc<dyn ExecutionPlan>> {
2380            let physical_planner = MyPhysicalPlanner {};
2381            physical_planner
2382                .create_physical_plan(logical_plan, session_state)
2383                .await
2384        }
2385    }
2386
2387    /// Generate a partitioned CSV file and register it with an execution context
2388    async fn create_ctx(
2389        tmp_dir: &TempDir,
2390        partition_count: usize,
2391    ) -> Result<SessionContext> {
2392        let ctx = SessionContext::new_with_config(
2393            SessionConfig::new().with_target_partitions(8),
2394        );
2395
2396        let schema = populate_csv_partitions(tmp_dir, partition_count, ".csv")?;
2397
2398        // register csv file with the execution context
2399        ctx.register_csv(
2400            "test",
2401            tmp_dir.path().to_str().unwrap(),
2402            CsvReadOptions::new().schema(&schema),
2403        )
2404        .await?;
2405
2406        Ok(ctx)
2407    }
2408
2409    #[derive(Debug)]
2410    struct MyTypePlanner {}
2411
2412    impl TypePlanner for MyTypePlanner {
2413        fn plan_type(&self, sql_type: &ast::DataType) -> Result<Option<DataType>> {
2414            match sql_type {
2415                ast::DataType::Datetime(precision) => {
2416                    let precision = match precision {
2417                        Some(0) => TimeUnit::Second,
2418                        Some(3) => TimeUnit::Millisecond,
2419                        Some(6) => TimeUnit::Microsecond,
2420                        None | Some(9) => TimeUnit::Nanosecond,
2421                        _ => unreachable!(),
2422                    };
2423                    Ok(Some(DataType::Timestamp(precision, None)))
2424                }
2425                _ => Ok(None),
2426            }
2427        }
2428    }
2429}