Skip to main content

lance_namespace_datafusion/
session_builder.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use datafusion::catalog::{CatalogProvider, SchemaProvider};
5use datafusion::error::Result;
6use datafusion::execution::context::{SessionConfig, SessionContext};
7use std::sync::Arc;
8
9use crate::LanceCatalogProvider;
10use crate::catalog::LanceCatalogProviderList;
11use crate::namespace_level::NamespaceLevel;
12
/// Builder for configuring a `SessionContext` with Lance namespaces.
///
/// Every setting is optional. The builder is consumed by
/// [`SessionBuilder::build`], which creates the `SessionContext` and
/// registers the configured catalog list, catalogs and default providers.
#[derive(Clone, Debug, Default)]
pub struct SessionBuilder {
    /// Optional root namespace exposed via a dynamic
    /// `LanceCatalogProviderList`, which `build` registers as the
    /// session's catalog list.
    root: Option<NamespaceLevel>,
    /// Explicit catalogs to register by name, as `(catalog name, namespace)`
    /// pairs in the order they were added.
    catalogs: Vec<(String, NamespaceLevel)>,
    /// Optional DataFusion session configuration. `build` falls back to
    /// `SessionConfig::default()` when this is unset.
    config: Option<SessionConfig>,
    /// Optional default catalog name.
    /// When set, it overrides the default catalog name carried by `config`.
    default_catalog: Option<String>,
    /// Optional default catalog provider. When set, `build` registers it
    /// under the default catalog name.
    default_catalog_provider: Option<Arc<dyn CatalogProvider>>,
    /// Optional default schema name.
    /// When set, it overrides the default schema name carried by `config`.
    default_schema: Option<String>,
    /// Optional default schema provider. `build` registers it on
    /// `default_catalog_provider` under the default schema name.
    default_schema_provider: Option<Arc<dyn SchemaProvider>>,
}
34
35impl SessionBuilder {
36    /// Create a new builder with no namespaces or configuration.
37    pub fn new() -> Self {
38        Self::default()
39    }
40
41    /// Attach a root `LanceNamespace` that is exposed as a dynamic
42    /// catalog list via `LanceCatalogProviderList`.
43    pub fn with_root(mut self, ns: NamespaceLevel) -> Self {
44        self.root = Some(ns);
45        self
46    }
47
48    /// Register an additional catalog backed by the given namespace.
49    ///
50    /// The catalog is identified by `name` and can later be combined
51    /// with schemas via `SessionBuilder::add_schema` using the same
52    /// namespace.
53    pub fn add_catalog(mut self, name: &str, ns: NamespaceLevel) -> Self {
54        self.catalogs.push((name.to_string(), ns));
55        self
56    }
57
58    /// Provide an explicit `SessionConfig` for the underlying
59    /// `SessionContext`.
60    pub fn with_config(mut self, config: SessionConfig) -> Self {
61        self.config = Some(config);
62        self
63    }
64
65    /// Override the default catalog name used by the session.
66    pub fn with_default_catalog(
67        mut self,
68        name: &str,
69        catalog_provider: Option<Arc<dyn CatalogProvider>>,
70    ) -> Self {
71        self.default_catalog = Some(name.to_string());
72        self.default_catalog_provider = catalog_provider;
73        self
74    }
75
76    /// Override the default schema name used by the session.
77    pub fn with_default_schema(
78        mut self,
79        name: &str,
80        schema_provider: Option<Arc<dyn SchemaProvider>>,
81    ) -> Self {
82        self.default_schema = Some(name.to_string());
83        self.default_schema_provider = schema_provider;
84        self
85    }
86
87    /// Build a `SessionContext` with all configured namespaces.
88    pub async fn build(self) -> Result<SessionContext> {
89        self.check_params_valid()?;
90        let config = self.config.unwrap_or_default();
91        let options = config.options();
92        let default_catalog = self
93            .default_catalog
94            .unwrap_or_else(|| options.catalog.default_catalog.clone());
95        let default_schema = self
96            .default_schema
97            .unwrap_or_else(|| options.catalog.default_schema.clone());
98
99        let ctx = SessionContext::new_with_config(
100            config
101                .with_default_catalog_and_schema(default_catalog.as_str(), default_schema.as_str()),
102        );
103
104        if let Some(root) = self.root {
105            let catalog_list = Arc::new(LanceCatalogProviderList::try_new(root).await?);
106            ctx.register_catalog_list(catalog_list);
107        }
108
109        for (catalog_name, namespace) in self.catalogs {
110            ctx.register_catalog(
111                catalog_name,
112                Arc::new(LanceCatalogProvider::try_new(namespace).await?),
113            );
114        }
115        if let Some(catalog_provider) = self.default_catalog_provider {
116            if let Some(schema_provider) = self.default_schema_provider {
117                catalog_provider.register_schema(default_schema.as_str(), schema_provider)?;
118            }
119            ctx.register_catalog(default_catalog.as_str(), catalog_provider);
120        }
121
122        Ok(ctx)
123    }
124
125    fn check_params_valid(&self) -> Result<()> {
126        if let (None, Some(schema)) = (&self.default_catalog, &self.default_schema) {
127            return Err(datafusion::error::DataFusionError::Internal(format!(
128                "Default SchemaProvider {} must be used together with a default CatalogProvider",
129                schema
130            )));
131        }
132        Ok(())
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::SessionBuilder;
    use std::sync::Arc;

    use arrow_array::{Int64Array, RecordBatch};
    use datafusion::catalog::SchemaProvider;
    use datafusion::catalog::memory::{MemoryCatalogProvider, MemorySchemaProvider};
    use datafusion::common::record_batch;
    use datafusion::datasource::MemTable;
    use datafusion::error::Result;

    /// Reads the lone `COUNT(*)` value from the first batch of a query result.
    fn single_count(batches: &[RecordBatch]) -> i64 {
        let counts = batches[0]
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .expect("COUNT should return Int64Array");
        assert_eq!(counts.len(), 1);
        counts.value(0)
    }

    #[tokio::test]
    async fn default_catalog_and_schema_are_used_for_sql_queries() -> Result<()> {
        // Three-row in-memory `orders` table, built the same way as tests/sql.rs.
        let orders_batch = record_batch!(
            ("order_id", Int32, vec![101, 102, 103]),
            ("customer_id", Int32, vec![1, 2, 3]),
            ("amount", Int32, vec![100, 200, 300])
        )?;
        let orders_schema = orders_batch.schema();
        let orders_table = Arc::new(MemTable::try_new(orders_schema, vec![vec![orders_batch]])?);

        // DataFusion's in-memory providers stand in for the default catalog and schema.
        let sales_schema = Arc::new(MemorySchemaProvider::new());
        let retail_catalog = Arc::new(MemoryCatalogProvider::new());
        sales_schema.register_table("orders".to_string(), orders_table)?;

        // The session should resolve unqualified names against `retail.sales`.
        let ctx = SessionBuilder::new()
            .with_default_catalog("retail", Some(retail_catalog))
            .with_default_schema("sales", Some(sales_schema))
            .build()
            .await?;

        // Schema-qualified table reference.
        let qualified_batches = ctx
            .sql("SELECT COUNT(*) AS c FROM sales.orders")
            .await?
            .collect()
            .await?;

        // Bare table reference, resolved through the default catalog and schema.
        let unqualified_batches = ctx
            .sql("SELECT COUNT(*) AS c FROM orders")
            .await?
            .collect()
            .await?;

        let count_with_schema = single_count(&qualified_batches);
        let count_without_schema = single_count(&unqualified_batches);

        assert_eq!(count_with_schema, 3);
        assert_eq!(count_without_schema, 3);
        assert_eq!(count_with_schema, count_without_schema);

        Ok(())
    }
}