Skip to main content

datafusion_app/extensions/
builder.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! [`DftSessionStateBuilder`] for configuring DataFusion [`SessionState`]
19
20use color_eyre::{eyre, Result};
21use datafusion::catalog::MemoryCatalogProviderList;
22use datafusion::catalog::{CatalogProvider, CatalogProviderList, TableProviderFactory};
23use datafusion::datasource::file_format::{
24    csv::CsvFormatFactory, json::JsonFormatFactory, parquet::ParquetFormatFactory,
25    FileFormatFactory,
26};
27use datafusion::execution::context::SessionState;
28use datafusion::execution::runtime_env::RuntimeEnv;
29use datafusion::execution::session_state::SessionStateBuilder;
30use datafusion::prelude::SessionConfig;
31use std::collections::HashMap;
32use std::fmt::Debug;
33use std::sync::Arc;
34
35use crate::config::ExecutionConfig;
36
37use super::{enabled_extensions, Extension};
38
/// Builds a DataFusion [`SessionState`] with any necessary configuration
///
/// Ideally we would use the DataFusion [`SessionStateBuilder`], but it doesn't
/// currently have all the needed APIs. Once we have a good handle on the needed
/// APIs we can upstream them to DataFusion.
///
/// List of things that would be nice to add upstream:
/// TODO: Implement Debug for SessionStateBuilder upstream
///  <https://github.com/apache/datafusion/issues/12555>
/// TODO: Implement some way to get access to the current RuntimeEnv (to register object stores)
///  <https://github.com/apache/datafusion/issues/12553>
/// TODO: Implement a way to add just a single TableProviderFactory
///  <https://github.com/apache/datafusion/issues/12552>
/// TODO: Make TableProviderFactory implement Debug
///   <https://github.com/apache/datafusion/pull/12557>
/// TODO: rename RuntimeEnv::new() to RuntimeEnv::try_new() as it returns a Result:
///   <https://github.com/apache/datafusion/issues/12554>
// `Debug` is implemented manually for this type instead of being derived.
pub struct DftSessionStateBuilder {
    /// Execution configuration the builder was created with, if any. A clone
    /// of it (or the default configuration when `None`) is handed to every
    /// extension that gets registered.
    execution_config: Option<ExecutionConfig>,
    /// DataFusion session configuration applied to the built [`SessionState`].
    session_config: SessionConfig,
    /// Table provider factories keyed by name; `None` until the first one is
    /// added.
    table_factories: Option<HashMap<String, Arc<dyn TableProviderFactory>>>,
    /// File format factories registered on the built [`SessionState`].
    file_format_factories: Vec<Arc<dyn FileFormatFactory>>,
    /// Catalog providers keyed by name; `None` until the first one is added.
    catalog_providers: Option<HashMap<String, Arc<dyn CatalogProvider>>>,
    /// Runtime environment to build with; only passed on to the
    /// [`SessionStateBuilder`] when it is set.
    runtime_env: Option<Arc<RuntimeEnv>>,
}
65
66impl Debug for DftSessionStateBuilder {
67    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68        f.debug_struct("DftSessionStateBuilder")
69            .field("session_config", &self.session_config)
70            .field(
71                "table_factories",
72                &"TODO TableFactory does not implement Debug",
73            )
74            .field("runtime_env", &self.runtime_env)
75            .finish()
76    }
77}
78
79impl Default for DftSessionStateBuilder {
80    fn default() -> Self {
81        Self {
82            session_config: SessionConfig::default().with_information_schema(true),
83            execution_config: None,
84            table_factories: None,
85            file_format_factories: vec![
86                Arc::new(ParquetFormatFactory::new()),
87                Arc::new(CsvFormatFactory::new()),
88                Arc::new(JsonFormatFactory::new()),
89            ],
90            catalog_providers: None,
91            runtime_env: None,
92        }
93    }
94}
95
96impl DftSessionStateBuilder {
97    /// Create a new builder
98    pub fn try_new(config: Option<ExecutionConfig>) -> Result<Self> {
99        let session_config = if let Some(cfg) = config.clone().unwrap_or_default().datafusion {
100            SessionConfig::from_string_hash_map(&cfg)?.with_information_schema(true)
101        } else {
102            SessionConfig::default().with_information_schema(true)
103        };
104
105        let builder = Self {
106            session_config,
107            execution_config: config,
108            table_factories: None,
109            file_format_factories: vec![
110                Arc::new(ParquetFormatFactory::new()),
111                Arc::new(CsvFormatFactory::new()),
112                Arc::new(JsonFormatFactory::new()),
113            ],
114
115            catalog_providers: None,
116            runtime_env: None,
117        };
118        Ok(builder)
119    }
120
121    /// Add a table factory to the list of factories on this builder
122    pub fn add_table_factory(&mut self, name: &str, factory: Arc<dyn TableProviderFactory>) {
123        if self.table_factories.is_none() {
124            self.table_factories = Some(HashMap::from([(name.to_string(), factory)]));
125        } else {
126            self.table_factories
127                .as_mut()
128                .unwrap()
129                .insert(name.to_string(), factory);
130        }
131    }
132
133    /// Add a file format factory to the list of file format factories on this builder
134    pub fn add_file_format_factory(&mut self, factory: Arc<dyn FileFormatFactory>) {
135        self.file_format_factories.push(factory);
136    }
137
138    /// Add a catalog provider to the list of providers on this builder
139    pub fn add_catalog_provider(&mut self, name: &str, factory: Arc<dyn CatalogProvider>) {
140        if self.catalog_providers.is_none() {
141            self.catalog_providers = Some(HashMap::from([(name.to_string(), factory)]));
142        } else {
143            self.catalog_providers
144                .as_mut()
145                .unwrap()
146                .insert(name.to_string(), factory);
147        }
148    }
149
150    /// Return the current [`RuntimeEnv`], creating a default if it doesn't exist
151    pub fn runtime_env(&mut self) -> &RuntimeEnv {
152        if self.runtime_env.is_none() {
153            self.runtime_env = Some(Arc::new(RuntimeEnv::default()));
154        }
155        self.runtime_env.as_ref().unwrap()
156    }
157
158    pub async fn register_extension(
159        &mut self,
160        config: ExecutionConfig,
161        extension: Arc<dyn Extension>,
162    ) -> color_eyre::Result<()> {
163        extension
164            .register(config, self)
165            .await
166            .map_err(|_| eyre::eyre!("E"))
167    }
168
169    pub async fn with_extensions(mut self) -> color_eyre::Result<Self> {
170        let extensions = enabled_extensions();
171
172        for extension in extensions {
173            let execution_config = self.execution_config.clone();
174            self.register_extension(execution_config.unwrap_or_default(), extension)
175                .await?;
176        }
177
178        Ok(self)
179    }
180
181    /// Build the [`SessionState`] from the specified configuration
182    pub fn build(self) -> datafusion::common::Result<SessionState> {
183        let Self {
184            session_config,
185            table_factories,
186            file_format_factories,
187            catalog_providers,
188            runtime_env,
189            ..
190        } = self;
191
192        let mut builder = SessionStateBuilder::new()
193            .with_default_features()
194            .with_config(session_config);
195
196        if let Some(runtime_env) = runtime_env {
197            builder = builder.with_runtime_env(runtime_env);
198        }
199        if let Some(table_factories) = table_factories {
200            builder = builder.with_table_factories(table_factories);
201        }
202        builder = builder.with_file_formats(file_format_factories);
203
204        if let Some(catalog_providers) = catalog_providers {
205            let catalogs_list = MemoryCatalogProviderList::new();
206            for (k, v) in catalog_providers {
207                catalogs_list.register_catalog(k, v);
208            }
209            builder = builder.with_catalog_list(Arc::new(catalogs_list));
210        }
211
212        Ok(builder.build())
213    }
214}