datafusion_dft/extensions/
builder.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`DftSessionStateBuilder`] for configuring DataFusion [`SessionState`]
19
20use color_eyre::eyre;
21use datafusion::catalog::{CatalogProvider, CatalogProviderList, TableProviderFactory};
22use datafusion::catalog_common::MemoryCatalogProviderList;
23use datafusion::execution::context::SessionState;
24use datafusion::execution::runtime_env::RuntimeEnv;
25use datafusion::execution::session_state::SessionStateBuilder;
26use datafusion::prelude::SessionConfig;
27use std::collections::HashMap;
28use std::fmt::Debug;
29use std::sync::Arc;
30
31use crate::{config::ExecutionConfig, execution::AppType};
32
33use super::{enabled_extensions, Extension};
34
35/// Builds a DataFusion [`SessionState`] with any necessary configuration
36///
37/// Ideally we would use the DataFusion [`SessionStateBuilder`], but it doesn't
38/// currently have all the needed APIs. Once we have a good handle on the needed
39/// APIs we can upstream them to DataFusion.
40///
41/// List of things that would be nice to add upstream:
42/// TODO: Implement Debug for SessionStateBuilder upstream
43///  <https://github.com/apache/datafusion/issues/12555>
44/// TODO: Implement some way to get access to the current RuntimeEnv (to register object stores)
45///  <https://github.com/apache/datafusion/issues/12553>
46/// TODO: Implement a way to add just a single TableProviderFactory
47///  <https://github.com/apache/datafusion/issues/12552>
48/// TODO: Make TableFactoryProvider implement Debug
49///   <https://github.com/apache/datafusion/pull/12557>
50/// TODO: rename RuntimeEnv::new() to RuntimeEnv::try_new() as it returns a Result:
51///   <https://github.com/apache/datafusion/issues/12554>
52//#[derive(Debug)]
53pub struct DftSessionStateBuilder {
54    app_type: Option<AppType>,
55    execution_config: Option<ExecutionConfig>,
56    session_config: SessionConfig,
57    table_factories: Option<HashMap<String, Arc<dyn TableProviderFactory>>>,
58    catalog_providers: Option<HashMap<String, Arc<dyn CatalogProvider>>>,
59    runtime_env: Option<Arc<RuntimeEnv>>,
60}
61
62impl Debug for DftSessionStateBuilder {
63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64        f.debug_struct("DftSessionStateBuilder")
65            .field("session_config", &self.session_config)
66            .field(
67                "table_factories",
68                &"TODO TableFactory does not implement Debug",
69            )
70            .field("runtime_env", &self.runtime_env)
71            .finish()
72    }
73}
74
75impl Default for DftSessionStateBuilder {
76    fn default() -> Self {
77        Self::new()
78    }
79}
80
81impl DftSessionStateBuilder {
82    /// Create a new builder
83    pub fn new() -> Self {
84        let session_config = SessionConfig::default().with_information_schema(true);
85
86        Self {
87            session_config,
88            app_type: None,
89            execution_config: None,
90            table_factories: None,
91            catalog_providers: None,
92            runtime_env: None,
93        }
94    }
95
96    pub fn with_app_type(mut self, app_type: AppType) -> Self {
97        self.app_type = Some(app_type);
98        self
99    }
100
101    pub fn with_execution_config(mut self, app_type: ExecutionConfig) -> Self {
102        self.execution_config = Some(app_type);
103        self
104    }
105
106    /// Set the `batch_size` on the [`SessionConfig`]
107    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
108        self.session_config = self.session_config.with_batch_size(batch_size);
109        self
110    }
111
112    /// Add a table factory to the list of factories on this builder
113    pub fn add_table_factory(&mut self, name: &str, factory: Arc<dyn TableProviderFactory>) {
114        if self.table_factories.is_none() {
115            self.table_factories = Some(HashMap::from([(name.to_string(), factory)]));
116        } else {
117            self.table_factories
118                .as_mut()
119                .unwrap()
120                .insert(name.to_string(), factory);
121        }
122    }
123
124    /// Add a catalog provider to the list of providers on this builder
125    pub fn add_catalog_provider(&mut self, name: &str, factory: Arc<dyn CatalogProvider>) {
126        if self.catalog_providers.is_none() {
127            self.catalog_providers = Some(HashMap::from([(name.to_string(), factory)]));
128        } else {
129            self.catalog_providers
130                .as_mut()
131                .unwrap()
132                .insert(name.to_string(), factory);
133        }
134    }
135
136    /// Return the current [`RuntimeEnv`], creating a default if it doesn't exist
137    pub fn runtime_env(&mut self) -> &RuntimeEnv {
138        if self.runtime_env.is_none() {
139            self.runtime_env = Some(Arc::new(RuntimeEnv::default()));
140        }
141        self.runtime_env.as_ref().unwrap()
142    }
143
144    pub async fn register_extension(
145        &mut self,
146        config: ExecutionConfig,
147        extension: Arc<dyn Extension>,
148    ) -> color_eyre::Result<()> {
149        extension
150            .register(config, self)
151            .await
152            .map_err(|_| eyre::eyre!("E"))
153    }
154
155    pub async fn with_extensions(mut self) -> color_eyre::Result<Self> {
156        let extensions = enabled_extensions();
157
158        for extension in extensions {
159            let execution_config = self.execution_config.clone().unwrap_or_default();
160            self.register_extension(execution_config, extension).await?;
161        }
162
163        Ok(self)
164    }
165
166    /// Build the [`SessionState`] from the specified configuration
167    pub fn build(self) -> datafusion_common::Result<SessionState> {
168        let Self {
169            app_type,
170            execution_config,
171            mut session_config,
172            table_factories,
173            catalog_providers,
174            runtime_env,
175            ..
176        } = self;
177
178        let app_type = app_type.unwrap_or(AppType::Cli);
179        let execution_config = execution_config.unwrap_or_default();
180
181        match app_type {
182            AppType::Cli => {
183                session_config = session_config.with_batch_size(execution_config.cli_batch_size);
184            }
185            AppType::Tui => {
186                session_config = session_config.with_batch_size(execution_config.tui_batch_size);
187            }
188            AppType::FlightSQLServer => {
189                session_config =
190                    session_config.with_batch_size(execution_config.flightsql_server_batch_size);
191            }
192        }
193
194        let mut builder = SessionStateBuilder::new()
195            .with_default_features()
196            .with_config(session_config);
197
198        if let Some(runtime_env) = runtime_env {
199            builder = builder.with_runtime_env(runtime_env);
200        }
201        if let Some(table_factories) = table_factories {
202            builder = builder.with_table_factories(table_factories);
203        }
204
205        if let Some(catalog_providers) = catalog_providers {
206            let catalogs_list = MemoryCatalogProviderList::new();
207            for (k, v) in catalog_providers {
208                catalogs_list.register_catalog(k, v);
209            }
210            builder = builder.with_catalog_list(Arc::new(catalogs_list));
211        }
212
213        Ok(builder.build())
214    }
215}