datafusion_app/extensions/
builder.rs1use color_eyre::{eyre, Result};
21use datafusion::catalog::MemoryCatalogProviderList;
22use datafusion::catalog::{CatalogProvider, CatalogProviderList, TableProviderFactory};
23use datafusion::datasource::file_format::{
24 csv::CsvFormatFactory, json::JsonFormatFactory, parquet::ParquetFormatFactory,
25 FileFormatFactory,
26};
27use datafusion::execution::context::SessionState;
28use datafusion::execution::runtime_env::RuntimeEnv;
29use datafusion::execution::session_state::SessionStateBuilder;
30use datafusion::prelude::SessionConfig;
31use std::collections::HashMap;
32use std::fmt::Debug;
33use std::sync::Arc;
34
35use crate::config::ExecutionConfig;
36
37use super::{enabled_extensions, Extension};
38
39pub struct DftSessionStateBuilder {
58 execution_config: Option<ExecutionConfig>,
59 session_config: SessionConfig,
60 table_factories: Option<HashMap<String, Arc<dyn TableProviderFactory>>>,
61 file_format_factories: Vec<Arc<dyn FileFormatFactory>>,
62 catalog_providers: Option<HashMap<String, Arc<dyn CatalogProvider>>>,
63 runtime_env: Option<Arc<RuntimeEnv>>,
64}
65
66impl Debug for DftSessionStateBuilder {
67 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68 f.debug_struct("DftSessionStateBuilder")
69 .field("session_config", &self.session_config)
70 .field(
71 "table_factories",
72 &"TODO TableFactory does not implement Debug",
73 )
74 .field("runtime_env", &self.runtime_env)
75 .finish()
76 }
77}
78
79impl Default for DftSessionStateBuilder {
80 fn default() -> Self {
81 Self {
82 session_config: SessionConfig::default().with_information_schema(true),
83 execution_config: None,
84 table_factories: None,
85 file_format_factories: vec![
86 Arc::new(ParquetFormatFactory::new()),
87 Arc::new(CsvFormatFactory::new()),
88 Arc::new(JsonFormatFactory::new()),
89 ],
90 catalog_providers: None,
91 runtime_env: None,
92 }
93 }
94}
95
96impl DftSessionStateBuilder {
97 pub fn try_new(config: Option<ExecutionConfig>) -> Result<Self> {
99 let session_config = if let Some(cfg) = config.clone().unwrap_or_default().datafusion {
100 SessionConfig::from_string_hash_map(&cfg)?.with_information_schema(true)
101 } else {
102 SessionConfig::default().with_information_schema(true)
103 };
104
105 let builder = Self {
106 session_config,
107 execution_config: config,
108 table_factories: None,
109 file_format_factories: vec![
110 Arc::new(ParquetFormatFactory::new()),
111 Arc::new(CsvFormatFactory::new()),
112 Arc::new(JsonFormatFactory::new()),
113 ],
114
115 catalog_providers: None,
116 runtime_env: None,
117 };
118 Ok(builder)
119 }
120
121 pub fn add_table_factory(&mut self, name: &str, factory: Arc<dyn TableProviderFactory>) {
123 if self.table_factories.is_none() {
124 self.table_factories = Some(HashMap::from([(name.to_string(), factory)]));
125 } else {
126 self.table_factories
127 .as_mut()
128 .unwrap()
129 .insert(name.to_string(), factory);
130 }
131 }
132
133 pub fn add_file_format_factory(&mut self, factory: Arc<dyn FileFormatFactory>) {
135 self.file_format_factories.push(factory);
136 }
137
138 pub fn add_catalog_provider(&mut self, name: &str, factory: Arc<dyn CatalogProvider>) {
140 if self.catalog_providers.is_none() {
141 self.catalog_providers = Some(HashMap::from([(name.to_string(), factory)]));
142 } else {
143 self.catalog_providers
144 .as_mut()
145 .unwrap()
146 .insert(name.to_string(), factory);
147 }
148 }
149
150 pub fn runtime_env(&mut self) -> &RuntimeEnv {
152 if self.runtime_env.is_none() {
153 self.runtime_env = Some(Arc::new(RuntimeEnv::default()));
154 }
155 self.runtime_env.as_ref().unwrap()
156 }
157
158 pub async fn register_extension(
159 &mut self,
160 config: ExecutionConfig,
161 extension: Arc<dyn Extension>,
162 ) -> color_eyre::Result<()> {
163 extension
164 .register(config, self)
165 .await
166 .map_err(|_| eyre::eyre!("E"))
167 }
168
169 pub async fn with_extensions(mut self) -> color_eyre::Result<Self> {
170 let extensions = enabled_extensions();
171
172 for extension in extensions {
173 let execution_config = self.execution_config.clone();
174 self.register_extension(execution_config.unwrap_or_default(), extension)
175 .await?;
176 }
177
178 Ok(self)
179 }
180
181 pub fn build(self) -> datafusion::common::Result<SessionState> {
183 let Self {
184 session_config,
185 table_factories,
186 file_format_factories,
187 catalog_providers,
188 runtime_env,
189 ..
190 } = self;
191
192 let mut builder = SessionStateBuilder::new()
193 .with_default_features()
194 .with_config(session_config);
195
196 if let Some(runtime_env) = runtime_env {
197 builder = builder.with_runtime_env(runtime_env);
198 }
199 if let Some(table_factories) = table_factories {
200 builder = builder.with_table_factories(table_factories);
201 }
202 builder = builder.with_file_formats(file_format_factories);
203
204 if let Some(catalog_providers) = catalog_providers {
205 let catalogs_list = MemoryCatalogProviderList::new();
206 for (k, v) in catalog_providers {
207 catalogs_list.register_catalog(k, v);
208 }
209 builder = builder.with_catalog_list(Arc::new(catalogs_list));
210 }
211
212 Ok(builder.build())
213 }
214}