Skip to main content

datafusion/execution/context/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`SessionContext`] API for registering data sources and executing queries
19
20use std::collections::HashSet;
21use std::fmt::Debug;
22use std::sync::{Arc, Weak};
23use std::time::Duration;
24
25use super::options::ReadOptions;
26use crate::datasource::dynamic_file::DynamicListTableFactory;
27use crate::execution::session_state::SessionStateBuilder;
28use crate::{
29    catalog::listing_schema::ListingSchemaProvider,
30    catalog::{
31        CatalogProvider, CatalogProviderList, TableProvider, TableProviderFactory,
32    },
33    dataframe::DataFrame,
34    datasource::listing::{
35        ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
36    },
37    datasource::{MemTable, ViewTable, provider_as_source},
38    error::Result,
39    execution::{
40        FunctionRegistry,
41        options::ArrowReadOptions,
42        runtime_env::{RuntimeEnv, RuntimeEnvBuilder},
43    },
44    logical_expr::AggregateUDF,
45    logical_expr::ScalarUDF,
46    logical_expr::{
47        CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateFunction,
48        CreateMemoryTable, CreateView, DropCatalogSchema, DropFunction, DropTable,
49        DropView, Execute, LogicalPlan, LogicalPlanBuilder, Prepare, ResetVariable,
50        SetVariable, TableType, UNNAMED_TABLE,
51    },
52    physical_expr::PhysicalExpr,
53    physical_plan::ExecutionPlan,
54    variable::{VarProvider, VarType},
55};
56
57// backwards compatibility
58pub use crate::execution::session_state::SessionState;
59
60use arrow::datatypes::{Schema, SchemaRef};
61use arrow::record_batch::RecordBatch;
62use datafusion_catalog::MemoryCatalogProvider;
63use datafusion_catalog::memory::MemorySchemaProvider;
64use datafusion_catalog::{
65    DynamicFileCatalog, TableFunction, TableFunctionImpl, UrlTableFactory,
66};
67use datafusion_common::config::{ConfigField, ConfigOptions};
68use datafusion_common::metadata::ScalarAndMetadata;
69use datafusion_common::{
70    DFSchema, DataFusionError, ParamValues, SchemaReference, TableReference,
71    config::{ConfigExtension, TableOptions},
72    exec_datafusion_err, exec_err, internal_datafusion_err, not_impl_err,
73    plan_datafusion_err, plan_err,
74    tree_node::{TreeNodeRecursion, TreeNodeVisitor},
75};
76pub use datafusion_execution::TaskContext;
77use datafusion_execution::cache::cache_manager::{
78    DEFAULT_LIST_FILES_CACHE_MEMORY_LIMIT, DEFAULT_LIST_FILES_CACHE_TTL,
79    DEFAULT_METADATA_CACHE_LIMIT,
80};
81pub use datafusion_execution::config::SessionConfig;
82use datafusion_execution::disk_manager::{
83    DEFAULT_MAX_TEMP_DIRECTORY_SIZE, DiskManagerBuilder,
84};
85use datafusion_execution::registry::SerializerRegistry;
86pub use datafusion_expr::execution_props::ExecutionProps;
87#[cfg(feature = "sql")]
88use datafusion_expr::planner::RelationPlanner;
89use datafusion_expr::simplify::SimplifyContext;
90use datafusion_expr::{
91    Expr, UserDefinedLogicalNode, WindowUDF,
92    expr_rewriter::FunctionRewrite,
93    logical_plan::{DdlStatement, Statement},
94    planner::ExprPlanner,
95};
96use datafusion_optimizer::Analyzer;
97use datafusion_optimizer::analyzer::type_coercion::TypeCoercion;
98use datafusion_optimizer::simplify_expressions::ExprSimplifier;
99use datafusion_optimizer::{AnalyzerRule, OptimizerRule};
100use datafusion_session::SessionStore;
101
102use async_trait::async_trait;
103use chrono::{DateTime, Utc};
104use object_store::ObjectStore;
105use parking_lot::RwLock;
106use url::Url;
107
108mod csv;
109mod json;
110#[cfg(feature = "parquet")]
111mod parquet;
112
113#[cfg(feature = "avro")]
114mod avro;
115
/// DataFilePaths adds a method to convert strings and vectors of strings to a vector of [`ListingTableUrl`] URLs.
/// This allows methods such as [`SessionContext::read_csv`] and [`SessionContext::read_avro`]
/// to take either a single file or multiple files.
pub trait DataFilePaths {
    /// Parse to a vector of [`ListingTableUrl`] URLs.
    ///
    /// Returns an error if any path cannot be parsed as a `ListingTableUrl`.
    fn to_urls(self) -> Result<Vec<ListingTableUrl>>;
}
123
124impl DataFilePaths for &str {
125    fn to_urls(self) -> Result<Vec<ListingTableUrl>> {
126        Ok(vec![ListingTableUrl::parse(self)?])
127    }
128}
129
130impl DataFilePaths for String {
131    fn to_urls(self) -> Result<Vec<ListingTableUrl>> {
132        Ok(vec![ListingTableUrl::parse(self)?])
133    }
134}
135
136impl DataFilePaths for &String {
137    fn to_urls(self) -> Result<Vec<ListingTableUrl>> {
138        Ok(vec![ListingTableUrl::parse(self)?])
139    }
140}
141
142impl<P> DataFilePaths for Vec<P>
143where
144    P: AsRef<str>,
145{
146    fn to_urls(self) -> Result<Vec<ListingTableUrl>> {
147        self.iter()
148            .map(ListingTableUrl::parse)
149            .collect::<Result<Vec<ListingTableUrl>>>()
150    }
151}
152
153/// Main interface for executing queries with DataFusion. Maintains
154/// the state of the connection between a user and an instance of the
155/// DataFusion engine.
156///
157/// See examples below for how to use the `SessionContext` to execute queries
158/// and how to configure the session.
159///
160/// # Overview
161///
162/// [`SessionContext`] provides the following functionality:
163///
164/// * Create a [`DataFrame`] from a CSV or Parquet data source.
165/// * Register a CSV or Parquet data source as a table that can be referenced from a SQL query.
166/// * Register a custom data source that can be referenced from a SQL query.
167/// * Execution a SQL query
168///
169/// # Example: DataFrame API
170///
171/// The following example demonstrates how to use the context to execute a query against a CSV
172/// data source using the [`DataFrame`] API:
173///
174/// ```
175/// use datafusion::prelude::*;
176/// # use datafusion::functions_aggregate::expr_fn::min;
177/// # use datafusion::{error::Result, assert_batches_eq};
178/// # #[tokio::main]
179/// # async fn main() -> Result<()> {
180/// let ctx = SessionContext::new();
181/// let df = ctx
182///     .read_csv("tests/data/example.csv", CsvReadOptions::new())
183///     .await?;
184/// let df = df
185///     .filter(col("a").lt_eq(col("b")))?
186///     .aggregate(vec![col("a")], vec![min(col("b"))])?
187///     .limit(0, Some(100))?;
188/// let results = df.collect().await?;
189/// assert_batches_eq!(
190///     &[
191///         "+---+----------------+",
192///         "| a | min(?table?.b) |",
193///         "+---+----------------+",
194///         "| 1 | 2              |",
195///         "+---+----------------+",
196///     ],
197///     &results
198/// );
199/// # Ok(())
200/// # }
201/// ```
202///
203/// # Example: SQL API
204///
205/// The following example demonstrates how to execute the same query using SQL:
206///
207/// ```
208/// use datafusion::prelude::*;
209/// # use datafusion::{error::Result, assert_batches_eq};
210/// # #[tokio::main]
211/// # async fn main() -> Result<()> {
212/// let ctx = SessionContext::new();
213/// ctx.register_csv("example", "tests/data/example.csv", CsvReadOptions::new())
214///     .await?;
215/// let results = ctx
216///     .sql("SELECT a, min(b) FROM example GROUP BY a LIMIT 100")
217///     .await?
218///     .collect()
219///     .await?;
220/// assert_batches_eq!(
221///     &[
222///         "+---+----------------+",
223///         "| a | min(example.b) |",
224///         "+---+----------------+",
225///         "| 1 | 2              |",
226///         "+---+----------------+",
227///     ],
228///     &results
229/// );
230/// # Ok(())
231/// # }
232/// ```
233///
234/// # Example: Configuring `SessionContext`
235///
236/// The `SessionContext` can be configured by creating a [`SessionState`] using
237/// [`SessionStateBuilder`]:
238///
239/// ```
240/// # use std::sync::Arc;
241/// # use datafusion::prelude::*;
242/// # use datafusion::execution::SessionStateBuilder;
243/// # use datafusion_execution::runtime_env::RuntimeEnvBuilder;
244/// // Configure a 4k batch size
245/// let config = SessionConfig::new().with_batch_size(4 * 1024);
246///
247/// // configure a memory limit of 1GB with 20%  slop
248/// let runtime_env = RuntimeEnvBuilder::new()
249///     .with_memory_limit(1024 * 1024 * 1024, 0.80)
250///     .build_arc()
251///     .unwrap();
252///
253/// // Create a SessionState using the config and runtime_env
254/// let state = SessionStateBuilder::new()
255///     .with_config(config)
256///     .with_runtime_env(runtime_env)
257///     // include support for built in functions and configurations
258///     .with_default_features()
259///     .build();
260///
261/// // Create a SessionContext
262/// let ctx = SessionContext::from(state);
263/// ```
264///
265/// # Relationship between `SessionContext`, `SessionState`, and `TaskContext`
266///
267/// The state required to optimize, and evaluate queries is
268/// broken into three levels to allow tailoring
269///
270/// The objects are:
271///
272/// 1. [`SessionContext`]: Most users should use a `SessionContext`. It contains
273///    all information required to execute queries including  high level APIs such
274///    as [`SessionContext::sql`]. All queries run with the same `SessionContext`
275///    share the same configuration and resources (e.g. memory limits).
276///
277/// 2. [`SessionState`]: contains information required to plan and execute an
278///    individual query (e.g. creating a [`LogicalPlan`] or [`ExecutionPlan`]).
279///    Each query is planned and executed using its own `SessionState`, which can
280///    be created with [`SessionContext::state`]. `SessionState` allows finer
281///    grained control over query execution, for example disallowing DDL operations
282///    such as `CREATE TABLE`.
283///
284/// 3. [`TaskContext`] contains the state required for query execution (e.g.
285///    [`ExecutionPlan::execute`]). It contains a subset of information in
286///    [`SessionState`]. `TaskContext` allows executing [`ExecutionPlan`]s
287///    [`PhysicalExpr`]s without requiring a full [`SessionState`].
288///
289/// [`PhysicalExpr`]: crate::physical_expr::PhysicalExpr
#[derive(Clone)]
pub struct SessionContext {
    /// UUID for the session
    session_id: String,
    /// Wall-clock time at which this `SessionContext` was created
    session_start_time: DateTime<Utc>,
    /// Shared session state for the session; `Clone`ing the context
    /// shares this same state (it is behind an `Arc`)
    state: Arc<RwLock<SessionState>>,
}
299
300impl Default for SessionContext {
301    fn default() -> Self {
302        Self::new()
303    }
304}
305
306impl SessionContext {
307    /// Creates a new `SessionContext` using the default [`SessionConfig`].
308    pub fn new() -> Self {
309        Self::new_with_config(SessionConfig::new())
310    }
311
312    /// Finds any [`ListingSchemaProvider`]s and instructs them to reload tables from "disk"
313    pub async fn refresh_catalogs(&self) -> Result<()> {
314        let cat_names = self.catalog_names().clone();
315        for cat_name in cat_names.iter() {
316            let cat = self
317                .catalog(cat_name.as_str())
318                .ok_or_else(|| internal_datafusion_err!("Catalog not found!"))?;
319            for schema_name in cat.schema_names() {
320                let schema = cat
321                    .schema(schema_name.as_str())
322                    .ok_or_else(|| internal_datafusion_err!("Schema not found!"))?;
323                let lister = schema.as_any().downcast_ref::<ListingSchemaProvider>();
324                if let Some(lister) = lister {
325                    lister.refresh(&self.state()).await?;
326                }
327            }
328        }
329        Ok(())
330    }
331
332    /// Creates a new `SessionContext` using the provided
333    /// [`SessionConfig`] and a new [`RuntimeEnv`].
334    ///
335    /// See [`Self::new_with_config_rt`] for more details on resource
336    /// limits.
337    pub fn new_with_config(config: SessionConfig) -> Self {
338        let runtime = Arc::new(RuntimeEnv::default());
339        Self::new_with_config_rt(config, runtime)
340    }
341
342    /// Creates a new `SessionContext` using the provided
343    /// [`SessionConfig`] and a [`RuntimeEnv`].
344    ///
345    /// # Resource Limits
346    ///
347    /// By default, each new `SessionContext` creates a new
348    /// `RuntimeEnv`, and therefore will not enforce memory or disk
349    /// limits for queries run on different `SessionContext`s.
350    ///
351    /// To enforce resource limits (e.g. to limit the total amount of
352    /// memory used) across all DataFusion queries in a process,
353    /// all `SessionContext`'s should be configured with the
354    /// same `RuntimeEnv`.
355    pub fn new_with_config_rt(config: SessionConfig, runtime: Arc<RuntimeEnv>) -> Self {
356        let state = SessionStateBuilder::new()
357            .with_config(config)
358            .with_runtime_env(runtime)
359            .with_default_features()
360            .build();
361        Self::new_with_state(state)
362    }
363
364    /// Creates a new `SessionContext` using the provided [`SessionState`]
365    pub fn new_with_state(state: SessionState) -> Self {
366        Self {
367            session_id: state.session_id().to_string(),
368            session_start_time: Utc::now(),
369            state: Arc::new(RwLock::new(state)),
370        }
371    }
372
    /// Enable querying local files as tables.
    ///
    /// This feature is security sensitive and should only be enabled for
    /// systems that wish to permit direct access to the file system from SQL.
    ///
    /// When enabled, this feature permits direct access to arbitrary files via
    /// SQL like
    ///
    /// ```sql
    /// SELECT * from 'my_file.parquet'
    /// ```
    ///
    /// See [DynamicFileCatalog] for more details
    ///
    /// ```
    /// # use datafusion::prelude::*;
    /// # use datafusion::{error::Result, assert_batches_eq};
    /// # #[tokio::main]
    /// # async fn main() -> Result<()> {
    /// let ctx = SessionContext::new()
    ///   .enable_url_table(); // permit local file access
    /// let results = ctx
    ///   .sql("SELECT a, MIN(b) FROM 'tests/data/example.csv' as example GROUP BY a LIMIT 100")
    ///   .await?
    ///   .collect()
    ///   .await?;
    /// assert_batches_eq!(
    ///  &[
    ///    "+---+----------------+",
    ///    "| a | min(example.b) |",
    ///    "+---+----------------+",
    ///    "| 1 | 2              |",
    ///    "+---+----------------+",
    ///  ],
    ///  &results
    /// );
    /// # Ok(())
    /// # }
    /// ```
    pub fn enable_url_table(self) -> Self {
        // Wrap the existing catalog list in a DynamicFileCatalog so that
        // unresolved table references can fall back to URL/file lookup.
        let current_catalog_list = Arc::clone(self.state.read().catalog_list());
        let factory = Arc::new(DynamicListTableFactory::new(SessionStore::new()));
        let catalog_list = Arc::new(DynamicFileCatalog::new(
            current_catalog_list,
            Arc::clone(&factory) as Arc<dyn UrlTableFactory>,
        ));

        // Preserve the session id across the state rebuild so the returned
        // context is still identified as the same session.
        let session_id = self.session_id.clone();
        let ctx: SessionContext = self
            .into_state_builder()
            .with_session_id(session_id)
            .with_catalog_list(catalog_list)
            .build()
            .into();
        // register new state with the factory; done after the new context is
        // built because the factory needs a reference to the *new* state
        // (state_weak_ref — presumably a Weak handle to avoid a reference
        // cycle between factory and state; TODO confirm)
        factory.session_store().with_state(ctx.state_weak_ref());
        ctx
    }
431
    /// Convert the current `SessionContext` into a [`SessionStateBuilder`]
    ///
    /// This is useful to switch back to `SessionState` with custom settings such as
    /// [`Self::enable_url_table`].
    ///
    /// Avoids cloning the SessionState if possible.
    ///
    /// # Example
    /// ```
    /// # use std::sync::Arc;
    /// # use datafusion::prelude::*;
    /// # use datafusion::execution::SessionStateBuilder;
    /// # use datafusion_optimizer::push_down_filter::PushDownFilter;
    /// let my_rule = PushDownFilter {}; // pretend it is a new rule
    ///                                  // Create a new builder with a custom optimizer rule
    /// let context: SessionContext = SessionStateBuilder::new()
    ///     .with_optimizer_rule(Arc::new(my_rule))
    ///     .build()
    ///     .into();
    /// // Enable local file access and convert context back to a builder
    /// let builder = context.enable_url_table().into_state_builder();
    /// ```
    pub fn into_state_builder(self) -> SessionStateBuilder {
        // Destructure so `state` is moved out; session id / start time are
        // intentionally discarded (the builder derives fresh ones).
        let SessionContext {
            session_id: _,
            session_start_time: _,
            state,
        } = self;
        // If this context holds the only Arc to the state, take ownership of
        // the state without cloning; otherwise fall back to cloning it out
        // from under the read lock.
        let state = match Arc::try_unwrap(state) {
            Ok(rwlock) => rwlock.into_inner(),
            Err(state) => state.read().clone(),
        };
        SessionStateBuilder::from(state)
    }
466
467    /// Returns the time this `SessionContext` was created
468    pub fn session_start_time(&self) -> DateTime<Utc> {
469        self.session_start_time
470    }
471
472    /// Registers a [`FunctionFactory`] to handle `CREATE FUNCTION` statements
473    pub fn with_function_factory(
474        self,
475        function_factory: Arc<dyn FunctionFactory>,
476    ) -> Self {
477        self.state.write().set_function_factory(function_factory);
478        self
479    }
480
481    /// Adds an optimizer rule to the end of the existing rules.
482    ///
483    /// See [`SessionState`] for more control of when the rule is applied.
484    pub fn add_optimizer_rule(
485        &self,
486        optimizer_rule: Arc<dyn OptimizerRule + Send + Sync>,
487    ) {
488        self.state.write().append_optimizer_rule(optimizer_rule);
489    }
490
491    /// Removes an optimizer rule by name, returning `true` if it existed.
492    pub fn remove_optimizer_rule(&self, name: &str) -> bool {
493        self.state.write().remove_optimizer_rule(name)
494    }
495
496    /// Adds an analyzer rule to the end of the existing rules.
497    ///
498    /// See [`SessionState`] for more control of when the rule is applied.
499    pub fn add_analyzer_rule(&self, analyzer_rule: Arc<dyn AnalyzerRule + Send + Sync>) {
500        self.state.write().add_analyzer_rule(analyzer_rule);
501    }
502
503    /// Registers an [`ObjectStore`] to be used with a specific URL prefix.
504    ///
505    /// See [`RuntimeEnv::register_object_store`] for more details.
506    ///
507    /// # Example: register a local object store for the "file://" URL prefix
508    /// ```
509    /// # use std::sync::Arc;
510    /// # use datafusion::prelude::SessionContext;
511    /// # use datafusion_execution::object_store::ObjectStoreUrl;
512    /// let object_store_url = ObjectStoreUrl::parse("file://").unwrap();
513    /// let object_store = object_store::local::LocalFileSystem::new();
514    /// let ctx = SessionContext::new();
515    /// // All files with the file:// url prefix will be read from the local file system
516    /// ctx.register_object_store(object_store_url.as_ref(), Arc::new(object_store));
517    /// ```
518    pub fn register_object_store(
519        &self,
520        url: &Url,
521        object_store: Arc<dyn ObjectStore>,
522    ) -> Option<Arc<dyn ObjectStore>> {
523        self.runtime_env().register_object_store(url, object_store)
524    }
525
526    /// Deregisters an [`ObjectStore`] associated with the specific URL prefix.
527    ///
528    /// See [`RuntimeEnv::deregister_object_store`] for more details.
529    pub fn deregister_object_store(&self, url: &Url) -> Result<Arc<dyn ObjectStore>> {
530        self.runtime_env().deregister_object_store(url)
531    }
532
533    /// Registers the [`RecordBatch`] as the specified table name
534    pub fn register_batch(
535        &self,
536        table_name: &str,
537        batch: RecordBatch,
538    ) -> Result<Option<Arc<dyn TableProvider>>> {
539        let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
540        self.register_table(
541            TableReference::Bare {
542                table: table_name.into(),
543            },
544            Arc::new(table),
545        )
546    }
547
548    /// Return the [RuntimeEnv] used to run queries with this `SessionContext`
549    pub fn runtime_env(&self) -> Arc<RuntimeEnv> {
550        Arc::clone(self.state.read().runtime_env())
551    }
552
553    /// Returns an id that uniquely identifies this `SessionContext`.
554    pub fn session_id(&self) -> String {
555        self.session_id.clone()
556    }
557
558    /// Return the [`TableProviderFactory`] that is registered for the
559    /// specified file type, if any.
560    pub fn table_factory(
561        &self,
562        file_type: &str,
563    ) -> Option<Arc<dyn TableProviderFactory>> {
564        self.state.read().table_factories().get(file_type).cloned()
565    }
566
567    /// Return the `enable_ident_normalization` of this Session
568    pub fn enable_ident_normalization(&self) -> bool {
569        self.state
570            .read()
571            .config()
572            .options()
573            .sql_parser
574            .enable_ident_normalization
575    }
576
577    /// Return a copied version of config for this Session
578    pub fn copied_config(&self) -> SessionConfig {
579        self.state.read().config().clone()
580    }
581
582    /// Return a copied version of table options for this Session
583    pub fn copied_table_options(&self) -> TableOptions {
584        self.state.read().default_table_options()
585    }
586
587    /// Creates a [`DataFrame`] from SQL query text.
588    ///
589    /// Note: This API implements DDL statements such as `CREATE TABLE` and
590    /// `CREATE VIEW` and DML statements such as `INSERT INTO` with in-memory
591    /// default implementations. See [`Self::sql_with_options`].
592    ///
593    /// # Example: Running SQL queries
594    ///
595    /// See the example on [`Self`]
596    ///
597    /// # Example: Creating a Table with SQL
598    ///
599    /// ```
600    /// use datafusion::prelude::*;
601    /// # use datafusion::{error::Result, assert_batches_eq};
602    /// # #[tokio::main]
603    /// # async fn main() -> Result<()> {
604    /// let ctx = SessionContext::new();
605    /// ctx.sql("CREATE TABLE foo (x INTEGER)")
606    ///     .await?
607    ///     .collect()
608    ///     .await?;
609    /// assert!(ctx.table_exist("foo").unwrap());
610    /// # Ok(())
611    /// # }
612    /// ```
613    #[cfg(feature = "sql")]
614    pub async fn sql(&self, sql: &str) -> Result<DataFrame> {
615        self.sql_with_options(sql, SQLOptions::new()).await
616    }
617
618    /// Creates a [`DataFrame`] from SQL query text, first validating
619    /// that the queries are allowed by `options`
620    ///
621    /// # Example: Preventing Creating a Table with SQL
622    ///
623    /// If you want to avoid creating tables, or modifying data or the
624    /// session, set [`SQLOptions`] appropriately:
625    ///
626    /// ```
627    /// use datafusion::prelude::*;
628    /// # use datafusion::{error::Result};
629    /// # use datafusion::physical_plan::collect;
630    /// # #[tokio::main]
631    /// # async fn main() -> Result<()> {
632    /// let ctx = SessionContext::new();
633    /// let options = SQLOptions::new().with_allow_ddl(false);
634    /// let err = ctx
635    ///     .sql_with_options("CREATE TABLE foo (x INTEGER)", options)
636    ///     .await
637    ///     .unwrap_err();
638    /// assert!(err
639    ///     .to_string()
640    ///     .starts_with("Error during planning: DDL not supported: CreateMemoryTable"));
641    /// # Ok(())
642    /// # }
643    /// ```
644    #[cfg(feature = "sql")]
645    pub async fn sql_with_options(
646        &self,
647        sql: &str,
648        options: SQLOptions,
649    ) -> Result<DataFrame> {
650        let plan = self.state().create_logical_plan(sql).await?;
651        options.verify_plan(&plan)?;
652
653        self.execute_logical_plan(plan).await
654    }
655
656    /// Creates logical expressions from SQL query text.
657    ///
658    /// # Example: Parsing SQL queries
659    ///
660    /// ```
661    /// # use arrow::datatypes::{DataType, Field, Schema};
662    /// # use datafusion::prelude::*;
663    /// # use datafusion_common::{DFSchema, Result};
664    /// # #[tokio::main]
665    /// # async fn main() -> Result<()> {
666    /// // datafusion will parse number as i64 first.
667    /// let sql = "a > 10";
668    /// let expected = col("a").gt(lit(10 as i64));
669    /// // provide type information that `a` is an Int32
670    /// let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
671    /// let df_schema = DFSchema::try_from(schema).unwrap();
672    /// let expr = SessionContext::new().parse_sql_expr(sql, &df_schema)?;
673    /// assert_eq!(expected, expr);
674    /// # Ok(())
675    /// # }
676    /// ```
677    #[cfg(feature = "sql")]
678    pub fn parse_sql_expr(&self, sql: &str, df_schema: &DFSchema) -> Result<Expr> {
679        self.state.read().create_logical_expr(sql, df_schema)
680    }
681
    /// Execute the [`LogicalPlan`], return a [`DataFrame`]. This API
    /// is not feature limited (so all SQL such as `CREATE TABLE` and
    /// `COPY` will be run).
    ///
    /// DDL and session-statement plans are executed eagerly here (returning an
    /// empty `DataFrame`); all other plans are wrapped lazily in a `DataFrame`.
    ///
    /// If you wish to limit the type of plan that can be run from
    /// SQL, see [`Self::sql_with_options`] and
    /// [`SQLOptions::verify_plan`].
    pub async fn execute_logical_plan(&self, plan: LogicalPlan) -> Result<DataFrame> {
        match plan {
            LogicalPlan::Ddl(ddl) => {
                // Box::pin avoids allocating the stack space within this function's frame
                // for every one of these individual async functions, decreasing the risk of
                // stack overflows.
                match ddl {
                    DdlStatement::CreateExternalTable(cmd) => {
                        // Explicit cast to a boxed dyn Future so the recursive-ish
                        // async call is heap-allocated like the other arms.
                        (Box::pin(async move { self.create_external_table(&cmd).await })
                            as std::pin::Pin<Box<dyn Future<Output = _> + Send>>)
                            .await
                    }
                    DdlStatement::CreateMemoryTable(cmd) => {
                        Box::pin(self.create_memory_table(cmd)).await
                    }
                    DdlStatement::CreateView(cmd) => {
                        Box::pin(self.create_view(cmd)).await
                    }
                    DdlStatement::CreateCatalogSchema(cmd) => {
                        Box::pin(self.create_catalog_schema(cmd)).await
                    }
                    DdlStatement::CreateCatalog(cmd) => {
                        Box::pin(self.create_catalog(cmd)).await
                    }
                    DdlStatement::DropTable(cmd) => Box::pin(self.drop_table(cmd)).await,
                    DdlStatement::DropView(cmd) => Box::pin(self.drop_view(cmd)).await,
                    DdlStatement::DropCatalogSchema(cmd) => {
                        Box::pin(self.drop_schema(cmd)).await
                    }
                    DdlStatement::CreateFunction(cmd) => {
                        Box::pin(self.create_function(cmd)).await
                    }
                    DdlStatement::DropFunction(cmd) => {
                        Box::pin(self.drop_function(cmd)).await
                    }
                    // Any other DDL is passed through untouched as a DataFrame.
                    ddl => Ok(DataFrame::new(self.state(), LogicalPlan::Ddl(ddl))),
                }
            }
            // TODO what about the other statements (like TransactionStart and TransactionEnd)
            LogicalPlan::Statement(Statement::SetVariable(stmt)) => {
                self.set_variable(stmt).await?;
                self.return_empty_dataframe()
            }
            LogicalPlan::Statement(Statement::ResetVariable(stmt)) => {
                self.reset_variable(stmt).await?;
                self.return_empty_dataframe()
            }
            LogicalPlan::Statement(Statement::Prepare(Prepare {
                name,
                input,
                fields,
            })) => {
                // The number of parameters must match the specified data types length.
                // An empty `fields` list means "infer types later", so it is allowed.
                if !fields.is_empty() {
                    let param_names = input.get_parameter_names()?;
                    if param_names.len() != fields.len() {
                        return plan_err!(
                            "Prepare specifies {} data types but query has {} parameters",
                            fields.len(),
                            param_names.len()
                        );
                    }
                }
                // Store the unoptimized plan into the session state. Although storing the
                // optimized plan or the physical plan would be more efficient, doing so is
                // not currently feasible. This is because `now()` would be optimized to a
                // constant value, causing each EXECUTE to yield the same result, which is
                // incorrect behavior.
                self.state.write().store_prepared(name, fields, input)?;
                self.return_empty_dataframe()
            }
            LogicalPlan::Statement(Statement::Execute(execute)) => {
                self.execute_prepared(execute)
            }
            LogicalPlan::Statement(Statement::Deallocate(deallocate)) => {
                // DEALLOCATE removes a previously prepared statement by name.
                self.state
                    .write()
                    .remove_prepared(deallocate.name.as_str())?;
                self.return_empty_dataframe()
            }
            // Everything else (queries, DML, ...) is returned lazily.
            plan => Ok(DataFrame::new(self.state(), plan)),
        }
    }
772
773    /// Create a [`PhysicalExpr`] from an [`Expr`] after applying type
774    /// coercion and function rewrites.
775    ///
776    /// Note: The expression is not [simplified] or otherwise optimized:
777    /// `a = 1 + 2` will not be simplified to `a = 3` as this is a more involved process.
778    /// See the [expr_api] example for how to simplify expressions.
779    ///
780    /// # Example
781    /// ```
782    /// # use std::sync::Arc;
783    /// # use arrow::datatypes::{DataType, Field, Schema};
784    /// # use datafusion::prelude::*;
785    /// # use datafusion_common::DFSchema;
786    /// // a = 1 (i64)
787    /// let expr = col("a").eq(lit(1i64));
788    /// // provide type information that `a` is an Int32
789    /// let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
790    /// let df_schema = DFSchema::try_from(schema).unwrap();
791    /// // Create a PhysicalExpr. Note DataFusion automatically coerces (casts) `1i64` to `1i32`
792    /// let physical_expr = SessionContext::new()
793    ///   .create_physical_expr(expr, &df_schema).unwrap();
794    /// ```
795    /// # See Also
796    /// * [`SessionState::create_physical_expr`] for a lower level API
797    ///
798    /// [simplified]: datafusion_optimizer::simplify_expressions
799    /// [expr_api]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/query_planning/expr_api.rs
800    pub fn create_physical_expr(
801        &self,
802        expr: Expr,
803        df_schema: &DFSchema,
804    ) -> Result<Arc<dyn PhysicalExpr>> {
805        self.state.read().create_physical_expr(expr, df_schema)
806    }
807
808    // return an empty dataframe
809    fn return_empty_dataframe(&self) -> Result<DataFrame> {
810        let plan = LogicalPlanBuilder::empty(false).build()?;
811        Ok(DataFrame::new(self.state(), plan))
812    }
813
814    async fn create_external_table(
815        &self,
816        cmd: &CreateExternalTable,
817    ) -> Result<DataFrame> {
818        let exist = self.table_exist(cmd.name.clone())?;
819
820        if cmd.temporary {
821            return not_impl_err!("Temporary tables not supported");
822        }
823
824        match (cmd.if_not_exists, cmd.or_replace, exist) {
825            (true, false, true) => self.return_empty_dataframe(),
826            (false, true, true) => {
827                let result = self
828                    .find_and_deregister(cmd.name.clone(), TableType::Base)
829                    .await;
830
831                match result {
832                    Ok(true) => {
833                        let table_provider: Arc<dyn TableProvider> =
834                            self.create_custom_table(cmd).await?;
835                        self.register_table(cmd.name.clone(), table_provider)?;
836                        self.return_empty_dataframe()
837                    }
838                    Ok(false) => {
839                        let table_provider: Arc<dyn TableProvider> =
840                            self.create_custom_table(cmd).await?;
841                        self.register_table(cmd.name.clone(), table_provider)?;
842                        self.return_empty_dataframe()
843                    }
844                    Err(e) => {
845                        exec_err!("Errored while deregistering external table: {}", e)
846                    }
847                }
848            }
849            (true, true, true) => {
850                exec_err!("'IF NOT EXISTS' cannot coexist with 'REPLACE'")
851            }
852            (_, _, false) => {
853                let table_provider: Arc<dyn TableProvider> =
854                    self.create_custom_table(cmd).await?;
855                self.register_table(cmd.name.clone(), table_provider)?;
856                self.return_empty_dataframe()
857            }
858            (false, false, true) => {
859                exec_err!("External table '{}' already exists", cmd.name)
860            }
861        }
862    }
863
864    async fn create_memory_table(&self, cmd: CreateMemoryTable) -> Result<DataFrame> {
865        let CreateMemoryTable {
866            name,
867            input,
868            if_not_exists,
869            or_replace,
870            constraints,
871            column_defaults,
872            temporary,
873        } = cmd;
874
875        let input = Arc::unwrap_or_clone(input);
876        let input = self.state().optimize(&input)?;
877
878        if temporary {
879            return not_impl_err!("Temporary tables not supported");
880        }
881
882        let table = self.table(name.clone()).await;
883        match (if_not_exists, or_replace, table) {
884            (true, false, Ok(_)) => self.return_empty_dataframe(),
885            (false, true, Ok(_)) => {
886                self.deregister_table(name.clone())?;
887                let schema = Arc::clone(input.schema().inner());
888                let physical = DataFrame::new(self.state(), input);
889
890                let batches: Vec<_> = physical.collect_partitioned().await?;
891                let table = Arc::new(
892                    // pass constraints and column defaults to the mem table.
893                    MemTable::try_new(schema, batches)?
894                        .with_constraints(constraints)
895                        .with_column_defaults(column_defaults.into_iter().collect()),
896                );
897
898                self.register_table(name.clone(), table)?;
899                self.return_empty_dataframe()
900            }
901            (true, true, Ok(_)) => {
902                exec_err!("'IF NOT EXISTS' cannot coexist with 'REPLACE'")
903            }
904            (_, _, Err(_)) => {
905                let schema = Arc::clone(input.schema().inner());
906                let physical = DataFrame::new(self.state(), input);
907
908                let batches: Vec<_> = physical.collect_partitioned().await?;
909                let table = Arc::new(
910                    // pass constraints and column defaults to the mem table.
911                    MemTable::try_new(schema, batches)?
912                        .with_constraints(constraints)
913                        .with_column_defaults(column_defaults.into_iter().collect()),
914                );
915
916                self.register_table(name, table)?;
917                self.return_empty_dataframe()
918            }
919            (false, false, Ok(_)) => exec_err!("Table '{name}' already exists"),
920        }
921    }
922
923    /// Applies the `TypeCoercion` rewriter to the logical plan.
924    fn apply_type_coercion(logical_plan: LogicalPlan) -> Result<LogicalPlan> {
925        let options = ConfigOptions::default();
926        Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())]).execute_and_check(
927            logical_plan,
928            &options,
929            |_, _| {},
930        )
931    }
932
933    async fn create_view(&self, cmd: CreateView) -> Result<DataFrame> {
934        let CreateView {
935            name,
936            input,
937            or_replace,
938            definition,
939            temporary,
940        } = cmd;
941
942        let view = self.table(name.clone()).await;
943
944        if temporary {
945            return not_impl_err!("Temporary views not supported");
946        }
947
948        match (or_replace, view) {
949            (true, Ok(_)) => {
950                self.deregister_table(name.clone())?;
951                let input = Self::apply_type_coercion(input.as_ref().clone())?;
952                let table = Arc::new(ViewTable::new(input, definition));
953                self.register_table(name, table)?;
954                self.return_empty_dataframe()
955            }
956            (_, Err(_)) => {
957                let input = Self::apply_type_coercion(input.as_ref().clone())?;
958                let table = Arc::new(ViewTable::new(input, definition));
959                self.register_table(name, table)?;
960                self.return_empty_dataframe()
961            }
962            (false, Ok(_)) => exec_err!("Table '{name}' already exists"),
963        }
964    }
965
966    async fn create_catalog_schema(&self, cmd: CreateCatalogSchema) -> Result<DataFrame> {
967        let CreateCatalogSchema {
968            schema_name,
969            if_not_exists,
970            ..
971        } = cmd;
972
973        // sqlparser doesn't accept database / catalog as parameter to CREATE SCHEMA
974        // so for now, we default to default catalog
975        let tokens: Vec<&str> = schema_name.split('.').collect();
976        let (catalog, schema_name) = match tokens.len() {
977            1 => {
978                let state = self.state.read();
979                let name = &state.config().options().catalog.default_catalog;
980                let catalog = state.catalog_list().catalog(name).ok_or_else(|| {
981                    exec_datafusion_err!("Missing default catalog '{name}'")
982                })?;
983                (catalog, tokens[0])
984            }
985            2 => {
986                let name = &tokens[0];
987                let catalog = self
988                    .catalog(name)
989                    .ok_or_else(|| exec_datafusion_err!("Missing catalog '{name}'"))?;
990                (catalog, tokens[1])
991            }
992            _ => return exec_err!("Unable to parse catalog from {schema_name}"),
993        };
994        let schema = catalog.schema(schema_name);
995
996        match (if_not_exists, schema) {
997            (true, Some(_)) => self.return_empty_dataframe(),
998            (true, None) | (false, None) => {
999                let schema = Arc::new(MemorySchemaProvider::new());
1000                catalog.register_schema(schema_name, schema)?;
1001                self.return_empty_dataframe()
1002            }
1003            (false, Some(_)) => exec_err!("Schema '{schema_name}' already exists"),
1004        }
1005    }
1006
1007    async fn create_catalog(&self, cmd: CreateCatalog) -> Result<DataFrame> {
1008        let CreateCatalog {
1009            catalog_name,
1010            if_not_exists,
1011            ..
1012        } = cmd;
1013        let catalog = self.catalog(catalog_name.as_str());
1014
1015        match (if_not_exists, catalog) {
1016            (true, Some(_)) => self.return_empty_dataframe(),
1017            (true, None) | (false, None) => {
1018                let new_catalog = Arc::new(MemoryCatalogProvider::new());
1019                self.state
1020                    .write()
1021                    .catalog_list()
1022                    .register_catalog(catalog_name, new_catalog);
1023                self.return_empty_dataframe()
1024            }
1025            (false, Some(_)) => exec_err!("Catalog '{catalog_name}' already exists"),
1026        }
1027    }
1028
1029    async fn drop_table(&self, cmd: DropTable) -> Result<DataFrame> {
1030        let DropTable {
1031            name, if_exists, ..
1032        } = cmd;
1033        let result = self
1034            .find_and_deregister(name.clone(), TableType::Base)
1035            .await;
1036        match (result, if_exists) {
1037            (Ok(true), _) => self.return_empty_dataframe(),
1038            (_, true) => self.return_empty_dataframe(),
1039            (_, _) => exec_err!("Table '{name}' doesn't exist."),
1040        }
1041    }
1042
1043    async fn drop_view(&self, cmd: DropView) -> Result<DataFrame> {
1044        let DropView {
1045            name, if_exists, ..
1046        } = cmd;
1047        let result = self
1048            .find_and_deregister(name.clone(), TableType::View)
1049            .await;
1050        match (result, if_exists) {
1051            (Ok(true), _) => self.return_empty_dataframe(),
1052            (_, true) => self.return_empty_dataframe(),
1053            (_, _) => exec_err!("View '{name}' doesn't exist."),
1054        }
1055    }
1056
1057    async fn drop_schema(&self, cmd: DropCatalogSchema) -> Result<DataFrame> {
1058        let DropCatalogSchema {
1059            name,
1060            if_exists: allow_missing,
1061            cascade,
1062            schema: _,
1063        } = cmd;
1064        let catalog = {
1065            let state = self.state.read();
1066            let catalog_name = match &name {
1067                SchemaReference::Full { catalog, .. } => catalog.to_string(),
1068                SchemaReference::Bare { .. } => {
1069                    state.config_options().catalog.default_catalog.to_string()
1070                }
1071            };
1072            if let Some(catalog) = state.catalog_list().catalog(&catalog_name) {
1073                catalog
1074            } else if allow_missing {
1075                return self.return_empty_dataframe();
1076            } else {
1077                return self.schema_doesnt_exist_err(&name);
1078            }
1079        };
1080        let dereg = catalog.deregister_schema(name.schema_name(), cascade)?;
1081        match (dereg, allow_missing) {
1082            (None, true) => self.return_empty_dataframe(),
1083            (None, false) => self.schema_doesnt_exist_err(&name),
1084            (Some(_), _) => self.return_empty_dataframe(),
1085        }
1086    }
1087
    // Helper: builds the standard "schema doesn't exist" error used by
    // `drop_schema` above.
    fn schema_doesnt_exist_err(&self, schemaref: &SchemaReference) -> Result<DataFrame> {
        exec_err!("Schema '{schemaref}' doesn't exist.")
    }
1091
1092    async fn set_variable(&self, stmt: SetVariable) -> Result<()> {
1093        let SetVariable {
1094            variable, value, ..
1095        } = stmt;
1096
1097        // Check if this is a runtime configuration
1098        if variable.starts_with("datafusion.runtime.") {
1099            self.set_runtime_variable(&variable, &value)?;
1100        } else {
1101            let mut state = self.state.write();
1102            state.config_mut().options_mut().set(&variable, &value)?;
1103
1104            // Re-initialize any UDFs that depend on configuration
1105            // This allows both built-in and custom functions to respond to configuration changes
1106            let config_options = state.config().options();
1107
1108            // Collect updated UDFs in a separate vector
1109            let udfs_to_update: Vec<_> = state
1110                .scalar_functions()
1111                .values()
1112                .filter_map(|udf| {
1113                    udf.inner()
1114                        .with_updated_config(config_options)
1115                        .map(Arc::new)
1116                })
1117                .collect();
1118
1119            for udf in udfs_to_update {
1120                state.register_udf(udf)?;
1121            }
1122        }
1123
1124        Ok(())
1125    }
1126
1127    async fn reset_variable(&self, stmt: ResetVariable) -> Result<()> {
1128        let variable = stmt.variable;
1129        if variable.starts_with("datafusion.runtime.") {
1130            return self.reset_runtime_variable(&variable);
1131        }
1132
1133        let mut state = self.state.write();
1134        state.config_mut().options_mut().reset(&variable)?;
1135
1136        // Refresh UDFs to ensure configuration-dependent behavior updates
1137        let config_options = state.config().options();
1138        let udfs_to_update: Vec<_> = state
1139            .scalar_functions()
1140            .values()
1141            .filter_map(|udf| {
1142                udf.inner()
1143                    .with_updated_config(config_options)
1144                    .map(Arc::new)
1145            })
1146            .collect();
1147
1148        for udf in udfs_to_update {
1149            state.register_udf(udf)?;
1150        }
1151
1152        Ok(())
1153    }
1154
1155    fn set_runtime_variable(&self, variable: &str, value: &str) -> Result<()> {
1156        let key = variable.strip_prefix("datafusion.runtime.").unwrap();
1157
1158        let mut state = self.state.write();
1159
1160        let mut builder = RuntimeEnvBuilder::from_runtime_env(state.runtime_env());
1161        builder = match key {
1162            "memory_limit" => {
1163                let memory_limit = Self::parse_memory_limit(value)?;
1164                builder.with_memory_limit(memory_limit, 1.0)
1165            }
1166            "max_temp_directory_size" => {
1167                let directory_size = Self::parse_memory_limit(value)?;
1168                builder.with_max_temp_directory_size(directory_size as u64)
1169            }
1170            "temp_directory" => builder.with_temp_file_path(value),
1171            "metadata_cache_limit" => {
1172                let limit = Self::parse_memory_limit(value)?;
1173                builder.with_metadata_cache_limit(limit)
1174            }
1175            "list_files_cache_limit" => {
1176                let limit = Self::parse_memory_limit(value)?;
1177                builder.with_object_list_cache_limit(limit)
1178            }
1179            "list_files_cache_ttl" => {
1180                let duration = Self::parse_duration(value)?;
1181                builder.with_object_list_cache_ttl(Some(duration))
1182            }
1183            _ => return plan_err!("Unknown runtime configuration: {variable}"),
1184            // Remember to update `reset_runtime_variable()` when adding new options
1185        };
1186
1187        *state = SessionStateBuilder::from(state.clone())
1188            .with_runtime_env(Arc::new(builder.build()?))
1189            .build();
1190
1191        Ok(())
1192    }
1193
1194    fn reset_runtime_variable(&self, variable: &str) -> Result<()> {
1195        let key = variable.strip_prefix("datafusion.runtime.").unwrap();
1196
1197        let mut state = self.state.write();
1198
1199        let mut builder = RuntimeEnvBuilder::from_runtime_env(state.runtime_env());
1200        match key {
1201            "memory_limit" => {
1202                builder.memory_pool = None;
1203            }
1204            "max_temp_directory_size" => {
1205                builder =
1206                    builder.with_max_temp_directory_size(DEFAULT_MAX_TEMP_DIRECTORY_SIZE);
1207            }
1208            "temp_directory" => {
1209                builder.disk_manager_builder = Some(DiskManagerBuilder::default());
1210            }
1211            "metadata_cache_limit" => {
1212                builder = builder.with_metadata_cache_limit(DEFAULT_METADATA_CACHE_LIMIT);
1213            }
1214            "list_files_cache_limit" => {
1215                builder = builder
1216                    .with_object_list_cache_limit(DEFAULT_LIST_FILES_CACHE_MEMORY_LIMIT);
1217            }
1218            "list_files_cache_ttl" => {
1219                builder =
1220                    builder.with_object_list_cache_ttl(DEFAULT_LIST_FILES_CACHE_TTL);
1221            }
1222            _ => return plan_err!("Unknown runtime configuration: {variable}"),
1223        };
1224
1225        *state = SessionStateBuilder::from(state.clone())
1226            .with_runtime_env(Arc::new(builder.build()?))
1227            .build();
1228
1229        Ok(())
1230    }
1231
1232    /// Parse memory limit from string to number of bytes
1233    /// Supports formats like '1.5G', '100M', '512K'
1234    ///
1235    /// # Examples
1236    /// ```
1237    /// use datafusion::execution::context::SessionContext;
1238    ///
1239    /// assert_eq!(
1240    ///     SessionContext::parse_memory_limit("1M").unwrap(),
1241    ///     1024 * 1024
1242    /// );
1243    /// assert_eq!(
1244    ///     SessionContext::parse_memory_limit("1.5G").unwrap(),
1245    ///     (1.5 * 1024.0 * 1024.0 * 1024.0) as usize
1246    /// );
1247    /// ```
1248    pub fn parse_memory_limit(limit: &str) -> Result<usize> {
1249        let (number, unit) = limit.split_at(limit.len() - 1);
1250        let number: f64 = number.parse().map_err(|_| {
1251            plan_datafusion_err!("Failed to parse number from memory limit '{limit}'")
1252        })?;
1253
1254        match unit {
1255            "K" => Ok((number * 1024.0) as usize),
1256            "M" => Ok((number * 1024.0 * 1024.0) as usize),
1257            "G" => Ok((number * 1024.0 * 1024.0 * 1024.0) as usize),
1258            _ => plan_err!("Unsupported unit '{unit}' in memory limit '{limit}'"),
1259        }
1260    }
1261
1262    fn parse_duration(duration: &str) -> Result<Duration> {
1263        let mut minutes = None;
1264        let mut seconds = None;
1265
1266        for duration in duration.split_inclusive(&['m', 's']) {
1267            let (number, unit) = duration.split_at(duration.len() - 1);
1268            let number: u64 = number.parse().map_err(|_| {
1269                plan_datafusion_err!("Failed to parse number from duration '{duration}'")
1270            })?;
1271
1272            match unit {
1273                "m" if minutes.is_none() && seconds.is_none() => minutes = Some(number),
1274                "s" if seconds.is_none() => seconds = Some(number),
1275                _ => plan_err!(
1276                    "Invalid duration, unit must be either 'm' (minutes), or 's' (seconds), and be in the correct order"
1277                )?,
1278            }
1279        }
1280
1281        let duration = Duration::from_secs(
1282            minutes.unwrap_or_default() * 60 + seconds.unwrap_or_default(),
1283        );
1284
1285        if duration.is_zero() {
1286            return plan_err!("Duration must be greater than 0 seconds");
1287        }
1288
1289        Ok(duration)
1290    }
1291
1292    async fn create_custom_table(
1293        &self,
1294        cmd: &CreateExternalTable,
1295    ) -> Result<Arc<dyn TableProvider>> {
1296        let state = self.state.read().clone();
1297        let file_type = cmd.file_type.to_uppercase();
1298        let factory =
1299            state
1300                .table_factories()
1301                .get(file_type.as_str())
1302                .ok_or_else(|| {
1303                    exec_datafusion_err!("Unable to find factory for {}", cmd.file_type)
1304                })?;
1305        let table = (*factory).create(&state, cmd).await?;
1306        Ok(table)
1307    }
1308
1309    async fn find_and_deregister(
1310        &self,
1311        table_ref: impl Into<TableReference>,
1312        table_type: TableType,
1313    ) -> Result<bool> {
1314        let table_ref = table_ref.into();
1315        let table = table_ref.table().to_owned();
1316        let maybe_schema = {
1317            let state = self.state.read();
1318            let resolved = state.resolve_table_ref(table_ref.clone());
1319            state
1320                .catalog_list()
1321                .catalog(&resolved.catalog)
1322                .and_then(|c| c.schema(&resolved.schema))
1323        };
1324
1325        if let Some(schema) = maybe_schema
1326            && let Some(table_provider) = schema.table(&table).await?
1327            && table_provider.table_type() == table_type
1328        {
1329            schema.deregister_table(&table)?;
1330            if table_type == TableType::Base
1331                && let Some(lfc) = self.runtime_env().cache_manager.get_list_files_cache()
1332            {
1333                lfc.drop_table_entries(&Some(table_ref))?;
1334            }
1335            return Ok(true);
1336        }
1337
1338        Ok(false)
1339    }
1340
1341    async fn create_function(&self, stmt: CreateFunction) -> Result<DataFrame> {
1342        let function = {
1343            let state = self.state.read().clone();
1344            let function_factory = state.function_factory();
1345
1346            match function_factory {
1347                Some(f) => f.create(&state, stmt).await?,
1348                _ => {
1349                    return Err(DataFusionError::Configuration(
1350                        "Function factory has not been configured".to_string(),
1351                    ));
1352                }
1353            }
1354        };
1355
1356        match function {
1357            RegisterFunction::Scalar(f) => {
1358                self.state.write().register_udf(f)?;
1359            }
1360            RegisterFunction::Aggregate(f) => {
1361                self.state.write().register_udaf(f)?;
1362            }
1363            RegisterFunction::Window(f) => {
1364                self.state.write().register_udwf(f)?;
1365            }
1366            RegisterFunction::Table(name, f) => self.register_udtf(&name, f),
1367        };
1368
1369        self.return_empty_dataframe()
1370    }
1371
1372    async fn drop_function(&self, stmt: DropFunction) -> Result<DataFrame> {
1373        // we don't know function type at this point
1374        // decision has been made to drop all functions
1375        let mut dropped = false;
1376        dropped |= self.state.write().deregister_udf(&stmt.name)?.is_some();
1377        dropped |= self.state.write().deregister_udaf(&stmt.name)?.is_some();
1378        dropped |= self.state.write().deregister_udwf(&stmt.name)?.is_some();
1379        dropped |= self.state.write().deregister_udtf(&stmt.name)?.is_some();
1380
1381        // DROP FUNCTION IF EXISTS drops the specified function only if that
1382        // function exists and in this way, it avoids error. While the DROP FUNCTION
1383        // statement also performs the same function, it throws an
1384        // error if the function does not exist.
1385
1386        if !stmt.if_exists && !dropped {
1387            exec_err!("Function does not exist")
1388        } else {
1389            self.return_empty_dataframe()
1390        }
1391    }
1392
1393    fn execute_prepared(&self, execute: Execute) -> Result<DataFrame> {
1394        let Execute {
1395            name, parameters, ..
1396        } = execute;
1397        let prepared = self.state.read().get_prepared(&name).ok_or_else(|| {
1398            exec_datafusion_err!("Prepared statement '{}' does not exist", name)
1399        })?;
1400
1401        let state = self.state.read();
1402        let context = SimplifyContext::new(state.execution_props());
1403        let simplifier = ExprSimplifier::new(context);
1404
1405        // Only allow literals as parameters for now.
1406        let mut params: Vec<ScalarAndMetadata> = parameters
1407            .into_iter()
1408            .map(|e| match simplifier.simplify(e)? {
1409                Expr::Literal(scalar, metadata) => {
1410                    Ok(ScalarAndMetadata::new(scalar, metadata))
1411                }
1412                e => not_impl_err!("Unsupported parameter type: {e}"),
1413            })
1414            .collect::<Result<_>>()?;
1415
1416        // If the prepared statement provides data types, cast the params to those types.
1417        if !prepared.fields.is_empty() {
1418            if params.len() != prepared.fields.len() {
1419                return exec_err!(
1420                    "Prepared statement '{}' expects {} parameters, but {} provided",
1421                    name,
1422                    prepared.fields.len(),
1423                    params.len()
1424                );
1425            }
1426            params = params
1427                .into_iter()
1428                .zip(prepared.fields.iter())
1429                .map(|(e, dt)| -> Result<_> { e.cast_storage_to(dt.data_type()) })
1430                .collect::<Result<_>>()?;
1431        }
1432
1433        let params = ParamValues::List(params);
1434        let plan = prepared
1435            .plan
1436            .as_ref()
1437            .clone()
1438            .replace_params_with_values(&params)?;
1439        Ok(DataFrame::new(self.state(), plan))
1440    }
1441
1442    /// Registers a variable provider within this context.
1443    pub fn register_variable(
1444        &self,
1445        variable_type: VarType,
1446        provider: Arc<dyn VarProvider + Send + Sync>,
1447    ) {
1448        self.state
1449            .write()
1450            .execution_props_mut()
1451            .add_var_provider(variable_type, provider);
1452    }
1453
1454    /// Register a table UDF with this context
1455    pub fn register_udtf(&self, name: &str, fun: Arc<dyn TableFunctionImpl>) {
1456        self.state.write().register_udtf(name, fun)
1457    }
1458
1459    /// Registers a scalar UDF within this context.
1460    ///
1461    /// Note in SQL queries, function names are looked up using
1462    /// lowercase unless the query uses quotes. For example,
1463    ///
1464    /// - `SELECT MY_FUNC(x)...` will look for a function named `"my_func"`
1465    /// - `SELECT "my_FUNC"(x)` will look for a function named `"my_FUNC"`
1466    ///
1467    /// Any functions registered with the udf name or its aliases will be overwritten with this new function
1468    pub fn register_udf(&self, f: ScalarUDF) {
1469        let mut state = self.state.write();
1470        state.register_udf(Arc::new(f)).ok();
1471    }
1472
1473    /// Registers an aggregate UDF within this context.
1474    ///
1475    /// Note in SQL queries, aggregate names are looked up using
1476    /// lowercase unless the query uses quotes. For example,
1477    ///
1478    /// - `SELECT MY_UDAF(x)...` will look for an aggregate named `"my_udaf"`
1479    /// - `SELECT "my_UDAF"(x)` will look for an aggregate named `"my_UDAF"`
1480    pub fn register_udaf(&self, f: AggregateUDF) {
1481        self.state.write().register_udaf(Arc::new(f)).ok();
1482    }
1483
1484    /// Registers a window UDF within this context.
1485    ///
1486    /// Note in SQL queries, window function names are looked up using
1487    /// lowercase unless the query uses quotes. For example,
1488    ///
1489    /// - `SELECT MY_UDWF(x)...` will look for a window function named `"my_udwf"`
1490    /// - `SELECT "my_UDWF"(x)` will look for a window function named `"my_UDWF"`
1491    pub fn register_udwf(&self, f: WindowUDF) {
1492        self.state.write().register_udwf(Arc::new(f)).ok();
1493    }
1494
1495    #[cfg(feature = "sql")]
1496    /// Registers a [`RelationPlanner`] to customize SQL table-factor planning.
1497    ///
1498    /// Planners are invoked in reverse registration order, allowing newer
1499    /// planners to take precedence over existing ones.
1500    pub fn register_relation_planner(
1501        &self,
1502        planner: Arc<dyn RelationPlanner>,
1503    ) -> Result<()> {
1504        self.state.write().register_relation_planner(planner)
1505    }
1506
1507    /// Deregisters a UDF within this context.
1508    pub fn deregister_udf(&self, name: &str) {
1509        self.state.write().deregister_udf(name).ok();
1510    }
1511
1512    /// Deregisters a UDAF within this context.
1513    pub fn deregister_udaf(&self, name: &str) {
1514        self.state.write().deregister_udaf(name).ok();
1515    }
1516
1517    /// Deregisters a UDWF within this context.
1518    pub fn deregister_udwf(&self, name: &str) {
1519        self.state.write().deregister_udwf(name).ok();
1520    }
1521
1522    /// Deregisters a UDTF within this context.
1523    pub fn deregister_udtf(&self, name: &str) {
1524        self.state.write().deregister_udtf(name).ok();
1525    }
1526
1527    /// Creates a [`DataFrame`] for reading a data source.
1528    ///
1529    /// For more control such as reading multiple files, you can use
1530    /// [`read_table`](Self::read_table) with a [`ListingTable`].
1531    async fn _read_type<'a, P: DataFilePaths>(
1532        &self,
1533        table_paths: P,
1534        options: impl ReadOptions<'a>,
1535    ) -> Result<DataFrame> {
1536        let table_paths = table_paths.to_urls()?;
1537        let session_config = self.copied_config();
1538        let listing_options =
1539            options.to_listing_options(&session_config, self.copied_table_options());
1540
1541        let option_extension = listing_options.file_extension.clone();
1542
1543        if table_paths.is_empty() {
1544            return exec_err!("No table paths were provided");
1545        }
1546
1547        // check if the file extension matches the expected extension
1548        for path in &table_paths {
1549            let file_path = path.as_str();
1550            if !file_path.ends_with(option_extension.clone().as_str())
1551                && !path.is_collection()
1552            {
1553                return exec_err!(
1554                    "File path '{file_path}' does not match the expected extension '{option_extension}'"
1555                );
1556            }
1557        }
1558
1559        let resolved_schema = options
1560            .get_resolved_schema(&session_config, self.state(), table_paths[0].clone())
1561            .await?;
1562        let config = ListingTableConfig::new_with_multi_paths(table_paths)
1563            .with_listing_options(listing_options)
1564            .with_schema(resolved_schema);
1565        let provider = ListingTable::try_new(config)?;
1566        self.read_table(Arc::new(provider))
1567    }
1568
    /// Creates a [`DataFrame`] for reading an Arrow data source.
    ///
    /// For more control such as reading multiple files, you can use
    /// [`read_table`](Self::read_table) with a [`ListingTable`].
    ///
    /// For an example, see [`read_csv`](Self::read_csv)
    pub async fn read_arrow<P: DataFilePaths>(
        &self,
        table_paths: P,
        options: ArrowReadOptions<'_>,
    ) -> Result<DataFrame> {
        // Delegates to the shared helper that validates paths/extensions,
        // resolves the schema, and builds a ListingTable-backed DataFrame.
        self._read_type(table_paths, options).await
    }
1582
1583    /// Creates an empty DataFrame.
1584    pub fn read_empty(&self) -> Result<DataFrame> {
1585        Ok(DataFrame::new(
1586            self.state(),
1587            LogicalPlanBuilder::empty(true).build()?,
1588        ))
1589    }
1590
1591    /// Creates a [`DataFrame`] for a [`TableProvider`] such as a
1592    /// [`ListingTable`] or a custom user defined provider.
1593    pub fn read_table(&self, provider: Arc<dyn TableProvider>) -> Result<DataFrame> {
1594        Ok(DataFrame::new(
1595            self.state(),
1596            LogicalPlanBuilder::scan(UNNAMED_TABLE, provider_as_source(provider), None)?
1597                .build()?,
1598        ))
1599    }
1600
1601    /// Creates a [`DataFrame`] for reading a [`RecordBatch`]
1602    pub fn read_batch(&self, batch: RecordBatch) -> Result<DataFrame> {
1603        let provider = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
1604        Ok(DataFrame::new(
1605            self.state(),
1606            LogicalPlanBuilder::scan(
1607                UNNAMED_TABLE,
1608                provider_as_source(Arc::new(provider)),
1609                None,
1610            )?
1611            .build()?,
1612        ))
1613    }
1614    /// Create a [`DataFrame`] for reading a [`Vec[`RecordBatch`]`]
1615    pub fn read_batches(
1616        &self,
1617        batches: impl IntoIterator<Item = RecordBatch>,
1618    ) -> Result<DataFrame> {
1619        // check schema uniqueness
1620        let mut batches = batches.into_iter().peekable();
1621        let schema = if let Some(batch) = batches.peek() {
1622            batch.schema()
1623        } else {
1624            Arc::new(Schema::empty())
1625        };
1626        let provider = MemTable::try_new(schema, vec![batches.collect()])?;
1627        Ok(DataFrame::new(
1628            self.state(),
1629            LogicalPlanBuilder::scan(
1630                UNNAMED_TABLE,
1631                provider_as_source(Arc::new(provider)),
1632                None,
1633            )?
1634            .build()?,
1635        ))
1636    }
1637    /// Registers a [`ListingTable`] that can assemble multiple files
1638    /// from locations in an [`ObjectStore`] instance into a single
1639    /// table.
1640    ///
1641    /// This method is `async` because it might need to resolve the schema.
1642    ///
1643    /// [`ObjectStore`]: object_store::ObjectStore
1644    pub async fn register_listing_table(
1645        &self,
1646        table_ref: impl Into<TableReference>,
1647        table_path: impl AsRef<str>,
1648        options: ListingOptions,
1649        provided_schema: Option<SchemaRef>,
1650        sql_definition: Option<String>,
1651    ) -> Result<()> {
1652        let table_path = ListingTableUrl::parse(table_path)?;
1653        let resolved_schema = match provided_schema {
1654            Some(s) => s,
1655            None => options.infer_schema(&self.state(), &table_path).await?,
1656        };
1657        let config = ListingTableConfig::new(table_path)
1658            .with_listing_options(options)
1659            .with_schema(resolved_schema);
1660        let table = ListingTable::try_new(config)?.with_definition(sql_definition);
1661        self.register_table(table_ref, Arc::new(table))?;
1662        Ok(())
1663    }
1664
1665    fn register_type_check<P: DataFilePaths>(
1666        &self,
1667        table_paths: P,
1668        extension: impl AsRef<str>,
1669    ) -> Result<()> {
1670        let table_paths = table_paths.to_urls()?;
1671        if table_paths.is_empty() {
1672            return exec_err!("No table paths were provided");
1673        }
1674
1675        // check if the file extension matches the expected extension
1676        let extension = extension.as_ref();
1677        for path in &table_paths {
1678            let file_path = path.as_str();
1679            if !file_path.ends_with(extension) && !path.is_collection() {
1680                return exec_err!(
1681                    "File path '{file_path}' does not match the expected extension '{extension}'"
1682                );
1683            }
1684        }
1685        Ok(())
1686    }
1687
1688    /// Registers an Arrow file as a table that can be referenced from
1689    /// SQL statements executed against this context.
1690    pub async fn register_arrow(
1691        &self,
1692        name: &str,
1693        table_path: &str,
1694        options: ArrowReadOptions<'_>,
1695    ) -> Result<()> {
1696        let listing_options = options
1697            .to_listing_options(&self.copied_config(), self.copied_table_options());
1698
1699        self.register_listing_table(
1700            name,
1701            table_path,
1702            listing_options,
1703            options.schema.map(|s| Arc::new(s.to_owned())),
1704            None,
1705        )
1706        .await?;
1707        Ok(())
1708    }
1709
1710    /// Registers a named catalog using a custom `CatalogProvider` so that
1711    /// it can be referenced from SQL statements executed against this
1712    /// context.
1713    ///
1714    /// Returns the [`CatalogProvider`] previously registered for this
1715    /// name, if any
1716    pub fn register_catalog(
1717        &self,
1718        name: impl Into<String>,
1719        catalog: Arc<dyn CatalogProvider>,
1720    ) -> Option<Arc<dyn CatalogProvider>> {
1721        let name = name.into();
1722        self.state
1723            .read()
1724            .catalog_list()
1725            .register_catalog(name, catalog)
1726    }
1727
1728    /// Retrieves the list of available catalog names.
1729    pub fn catalog_names(&self) -> Vec<String> {
1730        self.state.read().catalog_list().catalog_names()
1731    }
1732
1733    /// Retrieves a [`CatalogProvider`] instance by name
1734    pub fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
1735        self.state.read().catalog_list().catalog(name)
1736    }
1737
1738    /// Registers a [`TableProvider`] as a table that can be
1739    /// referenced from SQL statements executed against this context.
1740    ///
1741    /// If a table of the same name was already registered, returns "Table
1742    /// already exists" error.
1743    pub fn register_table(
1744        &self,
1745        table_ref: impl Into<TableReference>,
1746        provider: Arc<dyn TableProvider>,
1747    ) -> Result<Option<Arc<dyn TableProvider>>> {
1748        let table_ref: TableReference = table_ref.into();
1749        let table = table_ref.table().to_owned();
1750        self.state
1751            .read()
1752            .schema_for_ref(table_ref)?
1753            .register_table(table, provider)
1754    }
1755
1756    /// Deregisters the given table.
1757    ///
1758    /// Returns the registered provider, if any
1759    pub fn deregister_table(
1760        &self,
1761        table_ref: impl Into<TableReference>,
1762    ) -> Result<Option<Arc<dyn TableProvider>>> {
1763        let table_ref = table_ref.into();
1764        let table = table_ref.table().to_owned();
1765        self.state
1766            .read()
1767            .schema_for_ref(table_ref)?
1768            .deregister_table(&table)
1769    }
1770
1771    /// Return `true` if the specified table exists in the schema provider.
1772    pub fn table_exist(&self, table_ref: impl Into<TableReference>) -> Result<bool> {
1773        let table_ref: TableReference = table_ref.into();
1774        let table = table_ref.table();
1775        let table_ref = table_ref.clone();
1776        Ok(self
1777            .state
1778            .read()
1779            .schema_for_ref(table_ref)?
1780            .table_exist(table))
1781    }
1782
1783    /// Retrieves a [`DataFrame`] representing a table previously
1784    /// registered by calling the [`register_table`] function.
1785    ///
1786    /// Returns an error if no table has been registered with the
1787    /// provided reference.
1788    ///
1789    /// [`register_table`]: SessionContext::register_table
1790    pub async fn table(&self, table_ref: impl Into<TableReference>) -> Result<DataFrame> {
1791        let table_ref: TableReference = table_ref.into();
1792        let provider = self.table_provider(table_ref.clone()).await?;
1793        let plan = LogicalPlanBuilder::scan(
1794            table_ref,
1795            provider_as_source(Arc::clone(&provider)),
1796            None,
1797        )?
1798        .build()?;
1799        Ok(DataFrame::new(self.state(), plan))
1800    }
1801
1802    /// Retrieves a [`TableFunction`] reference by name.
1803    ///
1804    /// Returns an error if no table function has been registered with the provided name.
1805    ///
1806    /// [`register_udtf`]: SessionContext::register_udtf
1807    pub fn table_function(&self, name: &str) -> Result<Arc<TableFunction>> {
1808        self.state
1809            .read()
1810            .table_functions()
1811            .get(name)
1812            .cloned()
1813            .ok_or_else(|| plan_datafusion_err!("Table function '{name}' not found"))
1814    }
1815
1816    /// Return a [`TableProvider`] for the specified table.
1817    pub async fn table_provider(
1818        &self,
1819        table_ref: impl Into<TableReference>,
1820    ) -> Result<Arc<dyn TableProvider>> {
1821        let table_ref = table_ref.into();
1822        let table = table_ref.table().to_string();
1823        let schema = self.state.read().schema_for_ref(table_ref)?;
1824        match schema.table(&table).await? {
1825            Some(ref provider) => Ok(Arc::clone(provider)),
1826            _ => plan_err!("No table named '{table}'"),
1827        }
1828    }
1829
1830    /// Get a new TaskContext to run in this session
1831    pub fn task_ctx(&self) -> Arc<TaskContext> {
1832        Arc::new(TaskContext::from(self))
1833    }
1834
1835    /// Return a new  [`SessionState`] suitable for executing a single query.
1836    ///
1837    /// Notes:
1838    ///
1839    /// 1. `query_execution_start_time` is set to the current time for the
1840    ///    returned state.
1841    ///
1842    /// 2. The returned state is not shared with the current session state
1843    ///    and this changes to the returned `SessionState` such as changing
1844    ///    [`ConfigOptions`] will not be reflected in this `SessionContext`.
1845    ///
1846    /// [`ConfigOptions`]: crate::config::ConfigOptions
1847    pub fn state(&self) -> SessionState {
1848        let mut state = self.state.read().clone();
1849        state.mark_start_execution();
1850        state
1851    }
1852
    /// Get reference to [`SessionState`]
    ///
    /// Unlike [`Self::state`], this hands out the shared lock itself, so the
    /// caller observes (and can mutate) this context's live state rather
    /// than a snapshot.
    pub fn state_ref(&self) -> Arc<RwLock<SessionState>> {
        Arc::clone(&self.state)
    }
1857
    /// Get weak reference to [`SessionState`]
    ///
    /// A weak reference does not keep the state alive; upgrade it to check
    /// whether the owning context still exists.
    pub fn state_weak_ref(&self) -> Weak<RwLock<SessionState>> {
        Arc::downgrade(&self.state)
    }
1862
1863    /// Register [`CatalogProviderList`] in [`SessionState`]
1864    pub fn register_catalog_list(&self, catalog_list: Arc<dyn CatalogProviderList>) {
1865        self.state.write().register_catalog_list(catalog_list)
1866    }
1867
1868    /// Registers a [`ConfigExtension`] as a table option extension that can be
1869    /// referenced from SQL statements executed against this context.
1870    pub fn register_table_options_extension<T: ConfigExtension>(&self, extension: T) {
1871        self.state
1872            .write()
1873            .register_table_options_extension(extension)
1874    }
1875}
1876
1877impl FunctionRegistry for SessionContext {
1878    fn udfs(&self) -> HashSet<String> {
1879        self.state.read().udfs()
1880    }
1881
1882    fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> {
1883        self.state.read().udf(name)
1884    }
1885
1886    fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
1887        self.state.read().udaf(name)
1888    }
1889
1890    fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
1891        self.state.read().udwf(name)
1892    }
1893
1894    fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
1895        self.state.write().register_udf(udf)
1896    }
1897
1898    fn register_udaf(
1899        &mut self,
1900        udaf: Arc<AggregateUDF>,
1901    ) -> Result<Option<Arc<AggregateUDF>>> {
1902        self.state.write().register_udaf(udaf)
1903    }
1904
1905    fn register_udwf(&mut self, udwf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
1906        self.state.write().register_udwf(udwf)
1907    }
1908
1909    fn register_function_rewrite(
1910        &mut self,
1911        rewrite: Arc<dyn FunctionRewrite + Send + Sync>,
1912    ) -> Result<()> {
1913        self.state.write().register_function_rewrite(rewrite)
1914    }
1915
1916    fn expr_planners(&self) -> Vec<Arc<dyn ExprPlanner>> {
1917        self.state.read().expr_planners().to_vec()
1918    }
1919
1920    fn register_expr_planner(
1921        &mut self,
1922        expr_planner: Arc<dyn ExprPlanner>,
1923    ) -> Result<()> {
1924        self.state.write().register_expr_planner(expr_planner)
1925    }
1926
1927    fn udafs(&self) -> HashSet<String> {
1928        self.state.read().udafs()
1929    }
1930
1931    fn udwfs(&self) -> HashSet<String> {
1932        self.state.read().udwfs()
1933    }
1934}
1935
impl datafusion_execution::TaskContextProvider for SessionContext {
    fn task_ctx(&self) -> Arc<TaskContext> {
        // Fully-qualified call to the inherent method of the same name;
        // being explicit avoids any ambiguity with this trait method.
        SessionContext::task_ctx(self)
    }
}
1941
1942/// Create a new task context instance from SessionContext
1943impl From<&SessionContext> for TaskContext {
1944    fn from(session: &SessionContext) -> Self {
1945        TaskContext::from(&*session.state.read())
1946    }
1947}
1948
impl From<SessionState> for SessionContext {
    /// Build a [`SessionContext`] that wraps the given, fully-built state.
    fn from(state: SessionState) -> Self {
        Self::new_with_state(state)
    }
}
1954
impl From<SessionContext> for SessionStateBuilder {
    /// Consume the context and turn its session state into a builder for
    /// further customization.
    fn from(session: SessionContext) -> Self {
        session.into_state_builder()
    }
}
1960
/// A planner used to add extensions to DataFusion logical and physical plans.
#[async_trait]
pub trait QueryPlanner: Debug {
    /// Given a [`LogicalPlan`], create an [`ExecutionPlan`] suitable for execution
    ///
    /// # Errors
    ///
    /// Returns an error if this planner cannot convert the logical plan
    /// into a physical plan.
    async fn create_physical_plan(
        &self,
        logical_plan: &LogicalPlan,
        session_state: &SessionState,
    ) -> Result<Arc<dyn ExecutionPlan>>;
}
1971
1972/// Interface for handling `CREATE FUNCTION` statements and interacting with
1973/// [SessionState] to create and register functions ([`ScalarUDF`],
1974/// [`AggregateUDF`], [`WindowUDF`], and [`TableFunctionImpl`]) dynamically.
1975///
1976/// Implement this trait to create user-defined functions in a custom way, such
1977/// as loading from external libraries or defining them programmatically.
1978/// DataFusion will parse `CREATE FUNCTION` statements into [`CreateFunction`]
1979/// structs and pass them to the [`create`](Self::create) method.
1980///
1981/// Note there is no default implementation of this trait provided in DataFusion,
1982/// because the implementation and requirements vary widely. Please see
1983/// [function_factory example] for a reference implementation.
1984///
1985/// [function_factory example]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/builtin_functions/function_factory.rs
1986///
1987/// # Examples of syntax that can be supported
1988///
1989/// ```sql
1990/// CREATE FUNCTION f1(BIGINT)
1991///   RETURNS BIGINT
1992///   RETURN $1 + 1;
1993/// ```
1994/// or
1995/// ```sql
1996/// CREATE FUNCTION to_miles(DOUBLE)
1997/// RETURNS DOUBLE
1998/// LANGUAGE PYTHON
1999/// AS '
2000/// import pyarrow.compute as pc
2001///
/// conversion_rate_multiplier = 0.62137119
///
/// def to_miles(km_data):
///     return pc.multiply(km_data, conversion_rate_multiplier)
2006/// '
2007/// ```
2008
#[async_trait]
pub trait FunctionFactory: Debug + Sync + Send {
    /// Creates a new dynamic function from the SQL in the [CreateFunction] statement
    ///
    /// # Errors
    ///
    /// Returns an error if the statement cannot be turned into a
    /// registrable function.
    async fn create(
        &self,
        state: &SessionState,
        statement: CreateFunction,
    ) -> Result<RegisterFunction>;
}
2018
/// The result of processing a [`CreateFunction`] statement with [`FunctionFactory`].
///
/// Each variant tells DataFusion which registry the newly created
/// function should be added to.
#[derive(Debug, Clone)]
pub enum RegisterFunction {
    /// Scalar user defined function
    Scalar(Arc<ScalarUDF>),
    /// Aggregate user defined function
    Aggregate(Arc<AggregateUDF>),
    /// Window user defined function
    Window(Arc<WindowUDF>),
    /// Table user defined function, with the name to register it under
    Table(String, Arc<dyn TableFunctionImpl>),
}
2031
/// Default implementation of [SerializerRegistry] that throws unimplemented error
/// for all requests.
///
/// Use this when a session never needs to (de)serialize user-defined
/// logical plan nodes.
#[derive(Debug)]
pub struct EmptySerializerRegistry;

impl SerializerRegistry for EmptySerializerRegistry {
    /// Always returns a "not implemented" error naming the node.
    fn serialize_logical_plan(
        &self,
        node: &dyn UserDefinedLogicalNode,
    ) -> Result<Vec<u8>> {
        not_impl_err!(
            "Serializing user defined logical plan node `{}` is not supported",
            node.name()
        )
    }

    /// Always returns a "not implemented" error naming the node.
    fn deserialize_logical_plan(
        &self,
        name: &str,
        _bytes: &[u8],
    ) -> Result<Arc<dyn UserDefinedLogicalNode>> {
        not_impl_err!(
            "Deserializing user defined logical plan node `{name}` is not supported"
        )
    }
}
2058
/// Describes which SQL statements can be run.
///
/// All categories are allowed by default; use the `with_*` builder
/// methods to restrict them.
///
/// See [`SessionContext::sql_with_options`] for more details.
#[derive(Clone, Debug, Copy)]
pub struct SQLOptions {
    /// See [`Self::with_allow_ddl`]
    allow_ddl: bool,
    /// See [`Self::with_allow_dml`]
    allow_dml: bool,
    /// See [`Self::with_allow_statements`]
    allow_statements: bool,
}
2071
impl Default for SQLOptions {
    /// All statement categories (DDL, DML, statements) are permitted by default.
    fn default() -> Self {
        Self {
            allow_ddl: true,
            allow_dml: true,
            allow_statements: true,
        }
    }
}
2081
2082impl SQLOptions {
2083    /// Create a new `SQLOptions` with default values
2084    pub fn new() -> Self {
2085        Default::default()
2086    }
2087
2088    /// Should DDL data definition commands  (e.g. `CREATE TABLE`) be run? Defaults to `true`.
2089    pub fn with_allow_ddl(mut self, allow: bool) -> Self {
2090        self.allow_ddl = allow;
2091        self
2092    }
2093
2094    /// Should DML data modification commands (e.g. `INSERT` and `COPY`) be run? Defaults to `true`
2095    pub fn with_allow_dml(mut self, allow: bool) -> Self {
2096        self.allow_dml = allow;
2097        self
2098    }
2099
2100    /// Should Statements such as (e.g. `SET VARIABLE and `BEGIN TRANSACTION` ...`) be run?. Defaults to `true`
2101    pub fn with_allow_statements(mut self, allow: bool) -> Self {
2102        self.allow_statements = allow;
2103        self
2104    }
2105
2106    /// Return an error if the [`LogicalPlan`] has any nodes that are
2107    /// incompatible with this [`SQLOptions`].
2108    pub fn verify_plan(&self, plan: &LogicalPlan) -> Result<()> {
2109        plan.visit_with_subqueries(&mut BadPlanVisitor::new(self))?;
2110        Ok(())
2111    }
2112}
2113
/// Plan visitor that rejects nodes disallowed by the supplied [`SQLOptions`].
/// Used by [`SQLOptions::verify_plan`].
struct BadPlanVisitor<'a> {
    options: &'a SQLOptions,
}
impl<'a> BadPlanVisitor<'a> {
    fn new(options: &'a SQLOptions) -> Self {
        Self { options }
    }
}
2122
2123impl<'n> TreeNodeVisitor<'n> for BadPlanVisitor<'_> {
2124    type Node = LogicalPlan;
2125
2126    fn f_down(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
2127        match node {
2128            LogicalPlan::Ddl(ddl) if !self.options.allow_ddl => {
2129                plan_err!("DDL not supported: {}", ddl.name())
2130            }
2131            LogicalPlan::Dml(dml) if !self.options.allow_dml => {
2132                plan_err!("DML not supported: {}", dml.op)
2133            }
2134            LogicalPlan::Copy(_) if !self.options.allow_dml => {
2135                plan_err!("DML not supported: COPY")
2136            }
2137            LogicalPlan::Statement(stmt) if !self.options.allow_statements => {
2138                plan_err!("Statement not supported: {}", stmt.name())
2139            }
2140            _ => Ok(TreeNodeRecursion::Continue),
2141        }
2142    }
2143}
2144
2145#[cfg(test)]
2146mod tests {
2147    use super::{super::options::CsvReadOptions, *};
2148    use crate::execution::memory_pool::MemoryConsumer;
2149    use crate::test;
2150    use crate::test_util::{plan_and_collect, populate_csv_partitions};
2151    use arrow::datatypes::{DataType, TimeUnit};
2152    use datafusion_common::DataFusionError;
2153    use std::error::Error;
2154    use std::path::PathBuf;
2155
2156    use datafusion_common::test_util::batches_to_string;
2157    use datafusion_common_runtime::SpawnedTask;
2158    use insta::{allow_duplicates, assert_snapshot};
2159
2160    use crate::catalog::SchemaProvider;
2161    use crate::execution::session_state::SessionStateBuilder;
2162    use crate::physical_planner::PhysicalPlanner;
2163    use async_trait::async_trait;
2164    use datafusion_expr::planner::TypePlanner;
2165    use sqlparser::ast;
2166    use tempfile::TempDir;
2167
    #[tokio::test]
    async fn shared_memory_and_disk_manager() {
        // Demonstrate the ability to share DiskManager and
        // MemoryPool between two different executions.
        let ctx1 = SessionContext::new();

        // configure with same memory / disk manager
        let memory_pool = ctx1.runtime_env().memory_pool.clone();

        let mut reservation = MemoryConsumer::new("test").register(&memory_pool);
        reservation.grow(100);

        let disk_manager = ctx1.runtime_env().disk_manager.clone();

        let ctx2 =
            SessionContext::new_with_config_rt(SessionConfig::new(), ctx1.runtime_env());

        // Both contexts observe the reservation through the shared pool...
        assert_eq!(ctx1.runtime_env().memory_pool.reserved(), 100);
        assert_eq!(ctx2.runtime_env().memory_pool.reserved(), 100);

        drop(reservation);

        // ...and both observe its release.
        assert_eq!(ctx1.runtime_env().memory_pool.reserved(), 0);
        assert_eq!(ctx2.runtime_env().memory_pool.reserved(), 0);

        // The disk manager Arc is pointer-identical across both contexts.
        assert!(std::ptr::eq(
            Arc::as_ptr(&disk_manager),
            Arc::as_ptr(&ctx1.runtime_env().disk_manager)
        ));
        assert!(std::ptr::eq(
            Arc::as_ptr(&disk_manager),
            Arc::as_ptr(&ctx2.runtime_env().disk_manager)
        ));
    }
2202
    #[tokio::test]
    async fn create_variable_expr() -> Result<()> {
        // System (`@@name`) and user-defined (`@name`) variable providers
        // should both be resolvable from SQL.
        let tmp_dir = TempDir::new()?;
        let partition_count = 4;
        let ctx = create_ctx(&tmp_dir, partition_count).await?;

        let variable_provider = test::variable::SystemVar::new();
        ctx.register_variable(VarType::System, Arc::new(variable_provider));
        let variable_provider = test::variable::UserDefinedVar::new();
        ctx.register_variable(VarType::UserDefined, Arc::new(variable_provider));

        let provider = test::create_table_dual();
        ctx.register_table("dual", provider)?;

        let results =
            plan_and_collect(&ctx, "SELECT @@version, @name, @integer + 1 FROM dual")
                .await?;

        assert_snapshot!(batches_to_string(&results), @r"
        +----------------------+------------------------+---------------------+
        | @@version            | @name                  | @integer + Int64(1) |
        +----------------------+------------------------+---------------------+
        | system-var-@@version | user-defined-var-@name | 42                  |
        +----------------------+------------------------+---------------------+
        ");

        Ok(())
    }
2231
    #[tokio::test]
    async fn create_variable_err() -> Result<()> {
        // Referencing a variable with no registered provider / type info
        // should fail during planning, not execution.
        let ctx = SessionContext::new();

        let err = plan_and_collect(&ctx, "SElECT @=   X3").await.unwrap_err();
        assert_eq!(
            err.strip_backtrace(),
            "Error during planning: variable [\"@=\"] has no type information"
        );
        Ok(())
    }
2243
    #[tokio::test]
    async fn register_deregister() -> Result<()> {
        let tmp_dir = TempDir::new()?;
        let partition_count = 4;
        let ctx = create_ctx(&tmp_dir, partition_count).await?;

        let provider = test::create_table_dual();
        ctx.register_table("dual", provider)?;

        // First deregistration returns the provider; the second finds nothing.
        assert!(ctx.deregister_table("dual")?.is_some());
        assert!(ctx.deregister_table("dual")?.is_none());

        Ok(())
    }
2258
    #[tokio::test]
    async fn send_context_to_threads() -> Result<()> {
        // ensure SessionContexts can be used in a multi-threaded
        // environment. Use case is for concurrent planning.
        let tmp_dir = TempDir::new()?;
        let partition_count = 4;
        let ctx = Arc::new(create_ctx(&tmp_dir, partition_count).await?);

        let threads: Vec<_> = (0..2)
            .map(|_| ctx.clone())
            .map(|ctx| {
                SpawnedTask::spawn(async move {
                    // Ensure we can create logical plan code on a separate thread.
                    ctx.sql("SELECT c1, c2 FROM test WHERE c1 > 0 AND c1 < 3")
                        .await
                })
            })
            .collect();

        // Both tasks must plan successfully.
        for handle in threads {
            handle.join().await.unwrap().unwrap();
        }
        Ok(())
    }
2283
    #[tokio::test]
    async fn with_listing_schema_provider() -> Result<()> {
        // Configure the catalog to auto-discover CSV tables under a
        // directory, then query one of the discovered tables.
        let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        let path = path.join("tests/tpch-csv");
        let url = format!("file://{}", path.display());

        let cfg = SessionConfig::new()
            .set_str("datafusion.catalog.location", url.as_str())
            .set_str("datafusion.catalog.format", "CSV")
            .set_str("datafusion.catalog.has_header", "true");
        let session_state = SessionStateBuilder::new()
            .with_config(cfg)
            .with_default_features()
            .build();
        let ctx = SessionContext::new_with_state(session_state);
        // refresh_catalogs performs the directory scan that registers tables
        ctx.refresh_catalogs().await?;

        let result =
            plan_and_collect(&ctx, "select c_name from default.customer limit 3;")
                .await?;

        let actual = arrow::util::pretty::pretty_format_batches(&result)
            .unwrap()
            .to_string();
        assert_snapshot!(actual, @r"
        +--------------------+
        | c_name             |
        +--------------------+
        | Customer#000000002 |
        | Customer#000000003 |
        | Customer#000000004 |
        +--------------------+
        ");

        Ok(())
    }
2320
    #[tokio::test]
    async fn test_dynamic_file_query() -> Result<()> {
        // With `enable_url_table`, a file URL can be queried directly as a
        // table name without prior registration.
        let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        let path = path.join("tests/tpch-csv/customer.csv");
        let url = format!("file://{}", path.display());
        let cfg = SessionConfig::new();
        let session_state = SessionStateBuilder::new()
            .with_default_features()
            .with_config(cfg)
            .build();
        let ctx = SessionContext::new_with_state(session_state).enable_url_table();
        let result = plan_and_collect(
            &ctx,
            format!("select c_name from '{}' limit 3;", &url).as_str(),
        )
        .await?;

        let actual = arrow::util::pretty::pretty_format_batches(&result)
            .unwrap()
            .to_string();
        assert_snapshot!(actual, @r"
        +--------------------+
        | c_name             |
        +--------------------+
        | Customer#000000002 |
        | Customer#000000003 |
        | Customer#000000004 |
        +--------------------+
        ");

        Ok(())
    }
2353
    #[tokio::test]
    async fn custom_query_planner() -> Result<()> {
        // A custom QueryPlanner installed via the builder should be used for
        // physical planning; MyQueryPlanner rejects the query here.
        let runtime = Arc::new(RuntimeEnv::default());
        let session_state = SessionStateBuilder::new()
            .with_config(SessionConfig::new())
            .with_runtime_env(runtime)
            .with_default_features()
            .with_query_planner(Arc::new(MyQueryPlanner {}))
            .build();
        let ctx = SessionContext::new_with_state(session_state);

        let df = ctx.sql("SELECT 1").await?;
        df.collect().await.expect_err("query not supported");
        Ok(())
    }
2369
    #[tokio::test]
    async fn disabled_default_catalog_and_schema() -> Result<()> {
        // Without a default catalog/schema, both registration and queries
        // against the default namespace should fail with a plan error.
        let ctx = SessionContext::new_with_config(
            SessionConfig::new().with_create_default_catalog_and_schema(false),
        );

        assert!(matches!(
            ctx.register_table("test", test::table_with_sequence(1, 1)?),
            Err(DataFusionError::Plan(_))
        ));

        let err = ctx
            .sql("select * from datafusion.public.test")
            .await
            .unwrap_err();
        // The planning failure is wrapped; unwrap the source error to check it.
        let err = err
            .source()
            .and_then(|err| err.downcast_ref::<DataFusionError>())
            .unwrap();

        assert!(matches!(err, &DataFusionError::Plan(_)));

        Ok(())
    }
2394
    // The three tests below exercise the same scenario (querying a table via
    // custom default catalog/schema names) under different config flags.

    #[tokio::test]
    async fn custom_catalog_and_schema() {
        let config = SessionConfig::new()
            .with_create_default_catalog_and_schema(true)
            .with_default_catalog_and_schema("my_catalog", "my_schema");
        catalog_and_schema_test(config).await;
    }

    #[tokio::test]
    async fn custom_catalog_and_schema_no_default() {
        let config = SessionConfig::new()
            .with_create_default_catalog_and_schema(false)
            .with_default_catalog_and_schema("my_catalog", "my_schema");
        catalog_and_schema_test(config).await;
    }

    #[tokio::test]
    async fn custom_catalog_and_schema_and_information_schema() {
        let config = SessionConfig::new()
            .with_create_default_catalog_and_schema(true)
            .with_information_schema(true)
            .with_default_catalog_and_schema("my_catalog", "my_schema");
        catalog_and_schema_test(config).await;
    }
2419
    // Shared driver: registers `my_catalog.my_schema.test` and verifies the
    // table resolves via fully-qualified, schema-qualified, and bare names.
    async fn catalog_and_schema_test(config: SessionConfig) {
        let ctx = SessionContext::new_with_config(config);
        let catalog = MemoryCatalogProvider::new();
        let schema = MemorySchemaProvider::new();
        schema
            .register_table("test".to_owned(), test::table_with_sequence(1, 1).unwrap())
            .unwrap();
        catalog
            .register_schema("my_schema", Arc::new(schema))
            .unwrap();
        ctx.register_catalog("my_catalog", Arc::new(catalog));

        let mut results = Vec::new();

        for table_ref in &["my_catalog.my_schema.test", "my_schema.test", "test"] {
            let result = plan_and_collect(
                &ctx,
                &format!("SELECT COUNT(*) AS count FROM {table_ref}"),
            )
            .await
            .unwrap();

            results.push(result);
        }
        allow_duplicates! {
            for result in &results {
                assert_snapshot!(batches_to_string(result), @r"
                +-------+
                | count |
                +-------+
                | 1     |
                +-------+
                ");
            }
        }
    }
2456
2457    #[tokio::test]
2458    async fn cross_catalog_access() -> Result<()> {
2459        let ctx = SessionContext::new();
2460
2461        let catalog_a = MemoryCatalogProvider::new();
2462        let schema_a = MemorySchemaProvider::new();
2463        schema_a
2464            .register_table("table_a".to_owned(), test::table_with_sequence(1, 1)?)?;
2465        catalog_a.register_schema("schema_a", Arc::new(schema_a))?;
2466        ctx.register_catalog("catalog_a", Arc::new(catalog_a));
2467
2468        let catalog_b = MemoryCatalogProvider::new();
2469        let schema_b = MemorySchemaProvider::new();
2470        schema_b
2471            .register_table("table_b".to_owned(), test::table_with_sequence(1, 2)?)?;
2472        catalog_b.register_schema("schema_b", Arc::new(schema_b))?;
2473        ctx.register_catalog("catalog_b", Arc::new(catalog_b));
2474
2475        let result = plan_and_collect(
2476            &ctx,
2477            "SELECT cat, SUM(i) AS total FROM (
2478                    SELECT i, 'a' AS cat FROM catalog_a.schema_a.table_a
2479                    UNION ALL
2480                    SELECT i, 'b' AS cat FROM catalog_b.schema_b.table_b
2481                ) AS all
2482                GROUP BY cat
2483                ORDER BY cat
2484                ",
2485        )
2486        .await?;
2487
2488        assert_snapshot!(batches_to_string(&result), @r"
2489        +-----+-------+
2490        | cat | total |
2491        +-----+-------+
2492        | a   | 1     |
2493        | b   | 3     |
2494        +-----+-------+
2495        ");
2496
2497        Ok(())
2498    }
2499
2500    #[tokio::test]
2501    async fn catalogs_not_leaked() {
2502        // the information schema used to introduce cyclic Arcs
2503        let ctx = SessionContext::new_with_config(
2504            SessionConfig::new().with_information_schema(true),
2505        );
2506
2507        // register a single catalog
2508        let catalog = Arc::new(MemoryCatalogProvider::new());
2509        let catalog_weak = Arc::downgrade(&catalog);
2510        ctx.register_catalog("my_catalog", catalog);
2511
2512        let catalog_list_weak = {
2513            let state = ctx.state.read();
2514            Arc::downgrade(state.catalog_list())
2515        };
2516
2517        drop(ctx);
2518
2519        assert_eq!(Weak::strong_count(&catalog_list_weak), 0);
2520        assert_eq!(Weak::strong_count(&catalog_weak), 0);
2521    }
2522
2523    #[tokio::test]
2524    async fn sql_create_schema() -> Result<()> {
2525        // the information schema used to introduce cyclic Arcs
2526        let ctx = SessionContext::new_with_config(
2527            SessionConfig::new().with_information_schema(true),
2528        );
2529
2530        // Create schema
2531        ctx.sql("CREATE SCHEMA abc").await?.collect().await?;
2532
2533        // Add table to schema
2534        ctx.sql("CREATE TABLE abc.y AS VALUES (1,2,3)")
2535            .await?
2536            .collect()
2537            .await?;
2538
2539        // Check table exists in schema
2540        let results = ctx.sql("SELECT * FROM information_schema.tables WHERE table_schema='abc' AND table_name = 'y'").await.unwrap().collect().await.unwrap();
2541
2542        assert_eq!(results[0].num_rows(), 1);
2543        Ok(())
2544    }
2545
2546    #[tokio::test]
2547    async fn sql_create_catalog() -> Result<()> {
2548        // the information schema used to introduce cyclic Arcs
2549        let ctx = SessionContext::new_with_config(
2550            SessionConfig::new().with_information_schema(true),
2551        );
2552
2553        // Create catalog
2554        ctx.sql("CREATE DATABASE test").await?.collect().await?;
2555
2556        // Create schema
2557        ctx.sql("CREATE SCHEMA test.abc").await?.collect().await?;
2558
2559        // Add table to schema
2560        ctx.sql("CREATE TABLE test.abc.y AS VALUES (1,2,3)")
2561            .await?
2562            .collect()
2563            .await?;
2564
2565        // Check table exists in schema
2566        let results = ctx.sql("SELECT * FROM information_schema.tables WHERE table_catalog='test' AND table_schema='abc' AND table_name = 'y'").await.unwrap().collect().await.unwrap();
2567
2568        assert_eq!(results[0].num_rows(), 1);
2569        Ok(())
2570    }
2571
2572    #[tokio::test]
2573    async fn custom_type_planner() -> Result<()> {
2574        let state = SessionStateBuilder::new()
2575            .with_default_features()
2576            .with_type_planner(Arc::new(MyTypePlanner {}))
2577            .build();
2578        let ctx = SessionContext::new_with_state(state);
2579        let result = ctx
2580            .sql("SELECT DATETIME '2021-01-01 00:00:00'")
2581            .await?
2582            .collect()
2583            .await?;
2584        assert_snapshot!(batches_to_string(&result), @r#"
2585        +-----------------------------+
2586        | Utf8("2021-01-01 00:00:00") |
2587        +-----------------------------+
2588        | 2021-01-01T00:00:00         |
2589        +-----------------------------+
2590        "#);
2591        Ok(())
2592    }
2593    #[test]
2594    fn preserve_session_context_id() -> Result<()> {
2595        let ctx = SessionContext::new();
2596        // it does make sense to preserve session id in this case
2597        // as  `enable_url_table()` can be seen as additional configuration
2598        // option on ctx.
2599        // some systems like datafusion ballista relies on stable session_id
2600        assert_eq!(ctx.session_id(), ctx.enable_url_table().session_id());
2601        Ok(())
2602    }
2603
    /// Test-only physical planner whose planning always fails; used to check
    /// that a custom planner installed on a context is actually invoked.
    struct MyPhysicalPlanner {}

    #[async_trait]
    impl PhysicalPlanner for MyPhysicalPlanner {
        async fn create_physical_plan(
            &self,
            _logical_plan: &LogicalPlan,
            _session_state: &SessionState,
        ) -> Result<Arc<dyn ExecutionPlan>> {
            // Deliberately rejects every plan with a "not implemented" error.
            not_impl_err!("query not supported")
        }

        fn create_physical_expr(
            &self,
            _expr: &Expr,
            _input_dfschema: &DFSchema,
            _session_state: &SessionState,
        ) -> Result<Arc<dyn PhysicalExpr>> {
            // Not needed for the planning path exercised by the tests.
            unimplemented!()
        }
    }
2625
2626    #[derive(Debug)]
2627    struct MyQueryPlanner {}
2628
2629    #[async_trait]
2630    impl QueryPlanner for MyQueryPlanner {
2631        async fn create_physical_plan(
2632            &self,
2633            logical_plan: &LogicalPlan,
2634            session_state: &SessionState,
2635        ) -> Result<Arc<dyn ExecutionPlan>> {
2636            let physical_planner = MyPhysicalPlanner {};
2637            physical_planner
2638                .create_physical_plan(logical_plan, session_state)
2639                .await
2640        }
2641    }
2642
2643    /// Generate a partitioned CSV file and register it with an execution context
2644    async fn create_ctx(
2645        tmp_dir: &TempDir,
2646        partition_count: usize,
2647    ) -> Result<SessionContext> {
2648        let ctx = SessionContext::new_with_config(
2649            SessionConfig::new().with_target_partitions(8),
2650        );
2651
2652        let schema = populate_csv_partitions(tmp_dir, partition_count, ".csv")?;
2653
2654        // register csv file with the execution context
2655        ctx.register_csv(
2656            "test",
2657            tmp_dir.path().to_str().unwrap(),
2658            CsvReadOptions::new().schema(&schema),
2659        )
2660        .await?;
2661
2662        Ok(ctx)
2663    }
2664
2665    #[derive(Debug)]
2666    struct MyTypePlanner {}
2667
2668    impl TypePlanner for MyTypePlanner {
2669        fn plan_type(&self, sql_type: &ast::DataType) -> Result<Option<DataType>> {
2670            match sql_type {
2671                ast::DataType::Datetime(precision) => {
2672                    let precision = match precision {
2673                        Some(0) => TimeUnit::Second,
2674                        Some(3) => TimeUnit::Millisecond,
2675                        Some(6) => TimeUnit::Microsecond,
2676                        None | Some(9) => TimeUnit::Nanosecond,
2677                        _ => unreachable!(),
2678                    };
2679                    Ok(Some(DataType::Timestamp(precision, None)))
2680                }
2681                _ => Ok(None),
2682            }
2683        }
2684    }
2685
    /// Verifies that `remove_optimizer_rule` actually removes a rule: the
    /// optimized plan for `1 + 1` changes once `simplify_expressions` is gone,
    /// and removing it a second time returns `false`.
    #[tokio::test]
    async fn remove_optimizer_rule() -> Result<()> {
        // Snapshot the current optimizer rule names as a set.
        let get_optimizer_rules = |ctx: &SessionContext| {
            ctx.state()
                .optimizer()
                .rules
                .iter()
                .map(|r| r.name().to_owned())
                .collect::<HashSet<_>>()
        };

        let ctx = SessionContext::new();
        assert!(get_optimizer_rules(&ctx).contains("simplify_expressions"));

        // default plan: the rule folds 1 + 1 into the constant 2
        let plan = ctx
            .sql("select 1 + 1")
            .await?
            .into_optimized_plan()?
            .to_string();
        assert_snapshot!(plan, @r"
        Projection: Int64(2) AS Int64(1) + Int64(1)
          EmptyRelation: rows=1
        ");

        // first removal succeeds and the rule disappears from the set
        assert!(ctx.remove_optimizer_rule("simplify_expressions"));
        assert!(!get_optimizer_rules(&ctx).contains("simplify_expressions"));

        // plan without the simplify_expressions rule: 1 + 1 stays unfolded
        let plan = ctx
            .sql("select 1 + 1")
            .await?
            .into_optimized_plan()?
            .to_string();
        assert_snapshot!(plan, @r"
        Projection: Int64(1) + Int64(1)
          EmptyRelation: rows=1
        ");

        // attempting to remove a non-existing rule returns false
        assert!(!ctx.remove_optimizer_rule("simplify_expressions"));

        Ok(())
    }
2730
2731    #[test]
2732    fn test_parse_duration() {
2733        // Valid durations
2734        for (duration, want) in [
2735            ("1s", Duration::from_secs(1)),
2736            ("1m", Duration::from_secs(60)),
2737            ("1m0s", Duration::from_secs(60)),
2738            ("1m1s", Duration::from_secs(61)),
2739        ] {
2740            let have = SessionContext::parse_duration(duration).unwrap();
2741            assert_eq!(want, have);
2742        }
2743
2744        // Invalid durations
2745        for duration in ["0s", "0m", "1s0m", "1s1m"] {
2746            let have = SessionContext::parse_duration(duration);
2747            assert!(have.is_err());
2748        }
2749    }
2750}