use datafusion::catalog::{CatalogProvider, SchemaProvider};
use datafusion::error::Result;
use datafusion::execution::context::{SessionConfig, SessionContext};
use std::sync::Arc;
use crate::LanceCatalogProvider;
use crate::catalog::LanceCatalogProviderList;
use crate::namespace_level::NamespaceLevel;
#[derive(Clone, Debug, Default)]
pub struct SessionBuilder {
root: Option<NamespaceLevel>,
catalogs: Vec<(String, NamespaceLevel)>,
config: Option<SessionConfig>,
default_catalog: Option<String>,
default_catalog_provider: Option<Arc<dyn CatalogProvider>>,
default_schema: Option<String>,
default_schema_provider: Option<Arc<dyn SchemaProvider>>,
}
impl SessionBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn with_root(mut self, ns: NamespaceLevel) -> Self {
self.root = Some(ns);
self
}
pub fn add_catalog(mut self, name: &str, ns: NamespaceLevel) -> Self {
self.catalogs.push((name.to_string(), ns));
self
}
pub fn with_config(mut self, config: SessionConfig) -> Self {
self.config = Some(config);
self
}
pub fn with_default_catalog(
mut self,
name: &str,
catalog_provider: Option<Arc<dyn CatalogProvider>>,
) -> Self {
self.default_catalog = Some(name.to_string());
self.default_catalog_provider = catalog_provider;
self
}
pub fn with_default_schema(
mut self,
name: &str,
schema_provider: Option<Arc<dyn SchemaProvider>>,
) -> Self {
self.default_schema = Some(name.to_string());
self.default_schema_provider = schema_provider;
self
}
pub async fn build(self) -> Result<SessionContext> {
self.check_params_valid()?;
let config = self.config.unwrap_or_default();
let options = config.options();
let default_catalog = self
.default_catalog
.unwrap_or_else(|| options.catalog.default_catalog.clone());
let default_schema = self
.default_schema
.unwrap_or_else(|| options.catalog.default_schema.clone());
let ctx = SessionContext::new_with_config(
config
.with_default_catalog_and_schema(default_catalog.as_str(), default_schema.as_str()),
);
if let Some(root) = self.root {
let catalog_list = Arc::new(LanceCatalogProviderList::try_new(root).await?);
ctx.register_catalog_list(catalog_list);
}
for (catalog_name, namespace) in self.catalogs {
ctx.register_catalog(
catalog_name,
Arc::new(LanceCatalogProvider::try_new(namespace).await?),
);
}
if let Some(catalog_provider) = self.default_catalog_provider {
if let Some(schema_provider) = self.default_schema_provider {
catalog_provider.register_schema(default_schema.as_str(), schema_provider)?;
}
ctx.register_catalog(default_catalog.as_str(), catalog_provider);
}
Ok(ctx)
}
fn check_params_valid(&self) -> Result<()> {
if let (None, Some(schema)) = (&self.default_catalog, &self.default_schema) {
return Err(datafusion::error::DataFusionError::Internal(format!(
"Default SchemaProvider {} must be used together with a default CatalogProvider",
schema
)));
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::SessionBuilder;
use std::sync::Arc;
use arrow_array::{Int64Array, RecordBatch};
use datafusion::catalog::SchemaProvider;
use datafusion::catalog::memory::{MemoryCatalogProvider, MemorySchemaProvider};
use datafusion::common::record_batch;
use datafusion::datasource::MemTable;
use datafusion::error::Result;
#[tokio::test]
async fn default_catalog_and_schema_are_used_for_sql_queries() -> Result<()> {
let batch = record_batch!(
("order_id", Int32, vec![101, 102, 103]),
("customer_id", Int32, vec![1, 2, 3]),
("amount", Int32, vec![100, 200, 300])
)?;
let schema = batch.schema();
let table = Arc::new(MemTable::try_new(schema, vec![vec![batch]])?);
let sales_schema = Arc::new(MemorySchemaProvider::new());
let retail_catalog = Arc::new(MemoryCatalogProvider::new());
sales_schema.register_table("orders".to_string(), table)?;
let ctx = SessionBuilder::new()
.with_default_catalog("retail", Some(retail_catalog))
.with_default_schema("sales", Some(sales_schema))
.build()
.await?;
let extract_count = |batches: &[RecordBatch]| -> i64 {
let batch = &batches[0];
let array = batch
.column(0)
.as_any()
.downcast_ref::<Int64Array>()
.expect("COUNT should return Int64Array");
assert_eq!(array.len(), 1);
array.value(0)
};
let df_with_schema = ctx.sql("SELECT COUNT(*) AS c FROM sales.orders").await?;
let batches_with_schema = df_with_schema.collect().await?;
let df_without_schema = ctx.sql("SELECT COUNT(*) AS c FROM orders").await?;
let batches_without_schema = df_without_schema.collect().await?;
let count_with_schema = extract_count(&batches_with_schema);
let count_without_schema = extract_count(&batches_without_schema);
assert_eq!(count_with_schema, 3);
assert_eq!(count_without_schema, 3);
assert_eq!(count_with_schema, count_without_schema);
Ok(())
}
}