datafusion_dft/extensions/
builder.rs1use color_eyre::eyre;
21use datafusion::catalog::{CatalogProvider, CatalogProviderList, TableProviderFactory};
22use datafusion::catalog_common::MemoryCatalogProviderList;
23use datafusion::execution::context::SessionState;
24use datafusion::execution::runtime_env::RuntimeEnv;
25use datafusion::execution::session_state::SessionStateBuilder;
26use datafusion::prelude::SessionConfig;
27use std::collections::HashMap;
28use std::fmt::Debug;
29use std::sync::Arc;
30
31use crate::{config::ExecutionConfig, execution::AppType};
32
33use super::{enabled_extensions, Extension};
34
35pub struct DftSessionStateBuilder {
54 app_type: Option<AppType>,
55 execution_config: Option<ExecutionConfig>,
56 session_config: SessionConfig,
57 table_factories: Option<HashMap<String, Arc<dyn TableProviderFactory>>>,
58 catalog_providers: Option<HashMap<String, Arc<dyn CatalogProvider>>>,
59 runtime_env: Option<Arc<RuntimeEnv>>,
60}
61
62impl Debug for DftSessionStateBuilder {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 f.debug_struct("DftSessionStateBuilder")
65 .field("session_config", &self.session_config)
66 .field(
67 "table_factories",
68 &"TODO TableFactory does not implement Debug",
69 )
70 .field("runtime_env", &self.runtime_env)
71 .finish()
72 }
73}
74
75impl Default for DftSessionStateBuilder {
76 fn default() -> Self {
77 Self::new()
78 }
79}
80
81impl DftSessionStateBuilder {
82 pub fn new() -> Self {
84 let session_config = SessionConfig::default().with_information_schema(true);
85
86 Self {
87 session_config,
88 app_type: None,
89 execution_config: None,
90 table_factories: None,
91 catalog_providers: None,
92 runtime_env: None,
93 }
94 }
95
96 pub fn with_app_type(mut self, app_type: AppType) -> Self {
97 self.app_type = Some(app_type);
98 self
99 }
100
101 pub fn with_execution_config(mut self, app_type: ExecutionConfig) -> Self {
102 self.execution_config = Some(app_type);
103 self
104 }
105
106 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
108 self.session_config = self.session_config.with_batch_size(batch_size);
109 self
110 }
111
112 pub fn add_table_factory(&mut self, name: &str, factory: Arc<dyn TableProviderFactory>) {
114 if self.table_factories.is_none() {
115 self.table_factories = Some(HashMap::from([(name.to_string(), factory)]));
116 } else {
117 self.table_factories
118 .as_mut()
119 .unwrap()
120 .insert(name.to_string(), factory);
121 }
122 }
123
124 pub fn add_catalog_provider(&mut self, name: &str, factory: Arc<dyn CatalogProvider>) {
126 if self.catalog_providers.is_none() {
127 self.catalog_providers = Some(HashMap::from([(name.to_string(), factory)]));
128 } else {
129 self.catalog_providers
130 .as_mut()
131 .unwrap()
132 .insert(name.to_string(), factory);
133 }
134 }
135
136 pub fn runtime_env(&mut self) -> &RuntimeEnv {
138 if self.runtime_env.is_none() {
139 self.runtime_env = Some(Arc::new(RuntimeEnv::default()));
140 }
141 self.runtime_env.as_ref().unwrap()
142 }
143
144 pub async fn register_extension(
145 &mut self,
146 config: ExecutionConfig,
147 extension: Arc<dyn Extension>,
148 ) -> color_eyre::Result<()> {
149 extension
150 .register(config, self)
151 .await
152 .map_err(|_| eyre::eyre!("E"))
153 }
154
155 pub async fn with_extensions(mut self) -> color_eyre::Result<Self> {
156 let extensions = enabled_extensions();
157
158 for extension in extensions {
159 let execution_config = self.execution_config.clone().unwrap_or_default();
160 self.register_extension(execution_config, extension).await?;
161 }
162
163 Ok(self)
164 }
165
166 pub fn build(self) -> datafusion_common::Result<SessionState> {
168 let Self {
169 app_type,
170 execution_config,
171 mut session_config,
172 table_factories,
173 catalog_providers,
174 runtime_env,
175 ..
176 } = self;
177
178 let app_type = app_type.unwrap_or(AppType::Cli);
179 let execution_config = execution_config.unwrap_or_default();
180
181 match app_type {
182 AppType::Cli => {
183 session_config = session_config.with_batch_size(execution_config.cli_batch_size);
184 }
185 AppType::Tui => {
186 session_config = session_config.with_batch_size(execution_config.tui_batch_size);
187 }
188 AppType::FlightSQLServer => {
189 session_config =
190 session_config.with_batch_size(execution_config.flightsql_server_batch_size);
191 }
192 }
193
194 let mut builder = SessionStateBuilder::new()
195 .with_default_features()
196 .with_config(session_config);
197
198 if let Some(runtime_env) = runtime_env {
199 builder = builder.with_runtime_env(runtime_env);
200 }
201 if let Some(table_factories) = table_factories {
202 builder = builder.with_table_factories(table_factories);
203 }
204
205 if let Some(catalog_providers) = catalog_providers {
206 let catalogs_list = MemoryCatalogProviderList::new();
207 for (k, v) in catalog_providers {
208 catalogs_list.register_catalog(k, v);
209 }
210 builder = builder.with_catalog_list(Arc::new(catalogs_list));
211 }
212
213 Ok(builder.build())
214 }
215}