// datafusion_dft/extensions/builder.rs
use color_eyre::eyre;
use datafusion::catalog::{CatalogProvider, CatalogProviderList, TableProviderFactory};
use datafusion::catalog_common::MemoryCatalogProviderList;
use datafusion::execution::context::SessionState;
use datafusion::execution::runtime_env::RuntimeEnv;
use datafusion::execution::session_state::SessionStateBuilder;
use datafusion::prelude::SessionConfig;
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Arc;
use crate::{config::ExecutionConfig, execution::AppType};
use super::{enabled_extensions, Extension};
/// Builder that collects the pieces needed to construct a DataFusion `SessionState`.
///
/// Optional fields start out as `None` in `new()` and are only applied by `build()` when set.
pub struct DftSessionStateBuilder {
/// Which application flavour the state is built for; `build()` falls back to `AppType::Cli`
/// when unset and uses it to pick the batch size from the execution config.
app_type: Option<AppType>,
/// Execution settings handed to extensions and used by `build()` as the source of the
/// per-app batch size; the `Default` value is used when unset.
execution_config: Option<ExecutionConfig>,
/// Session configuration; `new()` initialises it with the information schema enabled.
session_config: SessionConfig,
/// Table provider factories keyed by name, forwarded to the session state builder when set.
table_factories: Option<HashMap<String, Arc<dyn TableProviderFactory>>>,
/// Catalog providers keyed by name; `build()` registers them in a fresh
/// `MemoryCatalogProviderList` when set.
catalog_providers: Option<HashMap<String, Arc<dyn CatalogProvider>>>,
/// Runtime environment; created on first call to `runtime_env()` and forwarded to the
/// session state builder when set.
runtime_env: Option<Arc<RuntimeEnv>>,
}
impl Debug for DftSessionStateBuilder {
    /// Writes a diagnostic view of the builder.
    ///
    /// Only the session config and runtime environment are rendered from their real values;
    /// the table factories are represented by a fixed placeholder string.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut rendered = f.debug_struct("DftSessionStateBuilder");
        rendered.field("session_config", &self.session_config);
        rendered.field(
            "table_factories",
            &"TODO TableFactory does not implement Debug",
        );
        rendered.field("runtime_env", &self.runtime_env);
        rendered.finish()
    }
}
impl Default for DftSessionStateBuilder {
fn default() -> Self {
Self::new()
}
}
impl DftSessionStateBuilder {
    /// Creates a builder whose [`SessionConfig`] has the information schema enabled and whose
    /// optional settings (app type, execution config, factories, catalogs, runtime) are unset.
    pub fn new() -> Self {
        Self {
            app_type: None,
            execution_config: None,
            session_config: SessionConfig::default().with_information_schema(true),
            table_factories: None,
            catalog_providers: None,
            runtime_env: None,
        }
    }

    /// Sets the application type used by [`Self::build`] to choose the batch size.
    pub fn with_app_type(mut self, app_type: AppType) -> Self {
        self.app_type = Some(app_type);
        self
    }

    /// Sets the execution configuration that is passed to extensions and consulted by
    /// [`Self::build`] for the per-app batch size.
    pub fn with_execution_config(mut self, execution_config: ExecutionConfig) -> Self {
        self.execution_config = Some(execution_config);
        self
    }

    /// Sets the batch size on the underlying [`SessionConfig`].
    ///
    /// NOTE(review): [`Self::build`] unconditionally applies the batch size taken from the
    /// execution config for the selected app type, so a value set here does not survive
    /// `build()`.
    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.session_config = self.session_config.with_batch_size(batch_size);
        self
    }

    /// Registers a table provider factory under `name`, replacing any factory previously
    /// added under the same name.
    pub fn add_table_factory(&mut self, name: &str, factory: Arc<dyn TableProviderFactory>) {
        self.table_factories
            .get_or_insert_with(HashMap::new)
            .insert(name.to_string(), factory);
    }

    /// Registers a catalog provider under `name`, replacing any provider previously added
    /// under the same name.
    pub fn add_catalog_provider(&mut self, name: &str, provider: Arc<dyn CatalogProvider>) {
        self.catalog_providers
            .get_or_insert_with(HashMap::new)
            .insert(name.to_string(), provider);
    }

    /// Returns the runtime environment, creating a default one on first use.
    pub fn runtime_env(&mut self) -> &RuntimeEnv {
        // `&mut Arc<RuntimeEnv>` deref-coerces to `&RuntimeEnv` at the return site.
        self.runtime_env
            .get_or_insert_with(|| Arc::new(RuntimeEnv::default()))
    }

    /// Lets `extension` register itself on this builder using `config`.
    ///
    /// # Errors
    ///
    /// Returns an error describing the underlying failure if the extension's `register`
    /// call fails.
    pub async fn register_extension(
        &mut self,
        config: ExecutionConfig,
        extension: Arc<dyn Extension>,
    ) -> color_eyre::Result<()> {
        extension
            .register(config, self)
            .await
            // Keep the cause in the message instead of discarding it.
            .map_err(|e| eyre::eyre!("Failed to register extension: {}", e))
    }

    /// Registers every extension returned by `enabled_extensions()` on this builder.
    ///
    /// Each extension receives a clone of the configured execution config, or the default
    /// config when none was set.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first registration error.
    pub async fn with_extensions(mut self) -> color_eyre::Result<Self> {
        for extension in enabled_extensions() {
            let execution_config = self.execution_config.clone().unwrap_or_default();
            self.register_extension(execution_config, extension).await?;
        }
        Ok(self)
    }

    /// Consumes the builder and produces the [`SessionState`].
    ///
    /// The batch size is taken from the execution config field matching the app type
    /// (defaulting to [`AppType::Cli`] and the default execution config when unset). The
    /// runtime environment, table factories and catalog providers are applied only if they
    /// were set.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`.
    pub fn build(self) -> datafusion_common::Result<SessionState> {
        let Self {
            app_type,
            execution_config,
            session_config,
            table_factories,
            catalog_providers,
            runtime_env,
        } = self;

        let execution_config = execution_config.unwrap_or_default();
        let batch_size = match app_type.unwrap_or(AppType::Cli) {
            AppType::Cli => execution_config.cli_batch_size,
            AppType::Tui => execution_config.tui_batch_size,
            AppType::FlightSQLServer => execution_config.flightsql_server_batch_size,
        };
        let session_config = session_config.with_batch_size(batch_size);

        let mut builder = SessionStateBuilder::new()
            .with_default_features()
            .with_config(session_config);

        if let Some(runtime_env) = runtime_env {
            builder = builder.with_runtime_env(runtime_env);
        }

        if let Some(table_factories) = table_factories {
            builder = builder.with_table_factories(table_factories);
        }

        if let Some(catalog_providers) = catalog_providers {
            let catalog_list = MemoryCatalogProviderList::new();
            for (name, provider) in catalog_providers {
                catalog_list.register_catalog(name, provider);
            }
            builder = builder.with_catalog_list(Arc::new(catalog_list));
        }

        Ok(builder.build())
    }
}