lance_namespace_datafusion/
session_builder.rs1use datafusion::catalog::{CatalogProvider, SchemaProvider};
5use datafusion::error::Result;
6use datafusion::execution::context::{SessionConfig, SessionContext};
7use std::sync::Arc;
8
9use crate::LanceCatalogProvider;
10use crate::catalog::LanceCatalogProviderList;
11use crate::namespace_level::NamespaceLevel;
12
13#[derive(Clone, Debug, Default)]
15pub struct SessionBuilder {
16 root: Option<NamespaceLevel>,
19 catalogs: Vec<(String, NamespaceLevel)>,
21 config: Option<SessionConfig>,
23 default_catalog: Option<String>,
26 default_catalog_provider: Option<Arc<dyn CatalogProvider>>,
28 default_schema: Option<String>,
31 default_schema_provider: Option<Arc<dyn SchemaProvider>>,
33}
34
35impl SessionBuilder {
36 pub fn new() -> Self {
38 Self::default()
39 }
40
41 pub fn with_root(mut self, ns: NamespaceLevel) -> Self {
44 self.root = Some(ns);
45 self
46 }
47
48 pub fn add_catalog(mut self, name: &str, ns: NamespaceLevel) -> Self {
54 self.catalogs.push((name.to_string(), ns));
55 self
56 }
57
58 pub fn with_config(mut self, config: SessionConfig) -> Self {
61 self.config = Some(config);
62 self
63 }
64
65 pub fn with_default_catalog(
67 mut self,
68 name: &str,
69 catalog_provider: Option<Arc<dyn CatalogProvider>>,
70 ) -> Self {
71 self.default_catalog = Some(name.to_string());
72 self.default_catalog_provider = catalog_provider;
73 self
74 }
75
76 pub fn with_default_schema(
78 mut self,
79 name: &str,
80 schema_provider: Option<Arc<dyn SchemaProvider>>,
81 ) -> Self {
82 self.default_schema = Some(name.to_string());
83 self.default_schema_provider = schema_provider;
84 self
85 }
86
87 pub async fn build(self) -> Result<SessionContext> {
89 self.check_params_valid()?;
90 let config = self.config.unwrap_or_default();
91 let options = config.options();
92 let default_catalog = self
93 .default_catalog
94 .unwrap_or_else(|| options.catalog.default_catalog.clone());
95 let default_schema = self
96 .default_schema
97 .unwrap_or_else(|| options.catalog.default_schema.clone());
98
99 let ctx = SessionContext::new_with_config(
100 config
101 .with_default_catalog_and_schema(default_catalog.as_str(), default_schema.as_str()),
102 );
103
104 if let Some(root) = self.root {
105 let catalog_list = Arc::new(LanceCatalogProviderList::try_new(root).await?);
106 ctx.register_catalog_list(catalog_list);
107 }
108
109 for (catalog_name, namespace) in self.catalogs {
110 ctx.register_catalog(
111 catalog_name,
112 Arc::new(LanceCatalogProvider::try_new(namespace).await?),
113 );
114 }
115 if let Some(catalog_provider) = self.default_catalog_provider {
116 if let Some(schema_provider) = self.default_schema_provider {
117 catalog_provider.register_schema(default_schema.as_str(), schema_provider)?;
118 }
119 ctx.register_catalog(default_catalog.as_str(), catalog_provider);
120 }
121
122 Ok(ctx)
123 }
124
125 fn check_params_valid(&self) -> Result<()> {
126 if let (None, Some(schema)) = (&self.default_catalog, &self.default_schema) {
127 return Err(datafusion::error::DataFusionError::Internal(format!(
128 "Default SchemaProvider {} must be used together with a default CatalogProvider",
129 schema
130 )));
131 }
132 Ok(())
133 }
134}
135
#[cfg(test)]
mod tests {
    use super::SessionBuilder;
    use std::sync::Arc;

    use arrow_array::{Int64Array, RecordBatch};
    use datafusion::catalog::SchemaProvider;
    use datafusion::catalog::memory::{MemoryCatalogProvider, MemorySchemaProvider};
    use datafusion::common::record_batch;
    use datafusion::datasource::MemTable;
    use datafusion::error::Result;

    /// Pulls the single `COUNT(*)` value out of the first column of the first
    /// batch, asserting that the column holds exactly one row.
    fn single_count(batches: &[RecordBatch]) -> i64 {
        let counts = batches[0]
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .expect("COUNT should return Int64Array");
        assert_eq!(counts.len(), 1);
        counts.value(0)
    }

    /// A table registered in the default schema of the default catalog must be
    /// reachable both as `schema.table` and by its bare table name.
    #[tokio::test]
    async fn default_catalog_and_schema_are_used_for_sql_queries() -> Result<()> {
        // Three-row in-memory `orders` table.
        let orders = record_batch!(
            ("order_id", Int32, vec![101, 102, 103]),
            ("customer_id", Int32, vec![1, 2, 3]),
            ("amount", Int32, vec![100, 200, 300])
        )?;
        let orders_table = Arc::new(MemTable::try_new(orders.schema(), vec![vec![orders]])?);

        // `retail` catalog with a `sales` schema that owns the table.
        let retail = Arc::new(MemoryCatalogProvider::new());
        let sales = Arc::new(MemorySchemaProvider::new());
        sales.register_table(String::from("orders"), orders_table)?;

        let ctx = SessionBuilder::new()
            .with_default_catalog("retail", Some(retail))
            .with_default_schema("sales", Some(sales))
            .build()
            .await?;

        // Run both the schema-qualified and the bare-name query before
        // inspecting either result.
        let qualified = ctx
            .sql("SELECT COUNT(*) AS c FROM sales.orders")
            .await?
            .collect()
            .await?;
        let unqualified = ctx
            .sql("SELECT COUNT(*) AS c FROM orders")
            .await?
            .collect()
            .await?;

        let qualified_count = single_count(&qualified);
        let unqualified_count = single_count(&unqualified);

        assert_eq!(qualified_count, 3);
        assert_eq!(unqualified_count, 3);
        assert_eq!(qualified_count, unqualified_count);

        Ok(())
    }
}