stately_arrow/
context.rs

//! Query context and session abstractions for `DataFusion` integration.
2//!
3//! This module provides the core abstractions for executing queries against registered connectors:
4//!
5//! - [`QuerySession`]: An abstraction over `DataFusion`'s `SessionContext` that allows custom
6//!   implementations to control session behavior, capabilities, and SQL execution. The default
7//!   implementation wraps `SessionContext` directly.
8//!
9//! - [`QueryContext`]: The high-level interface combining a session and a [`ConnectorRegistry`]. It
10//!   handles connector registration, catalog discovery, and query execution.
11//!
12//! - [`SessionCapability`]: Describes what actions a session supports (e.g., executing queries
13//!   without specifying a connector).
14//!
15//! # Custom Sessions
16//!
//! Implement [`QuerySession`] to customize `DataFusion` behavior:
18//!
19//! ```ignore
20//! use async_trait::async_trait;
21//! use datafusion::execution::context::SessionContext;
22//! use stately_arrow::{QuerySession, SessionCapability, Result};
23//!
24//! #[derive(Clone)]
25//! pub struct MySession {
26//!     inner: SessionContext,
27//! }
28//!
29//! #[async_trait]
30//! impl QuerySession for MySession {
31//!     fn as_session(&self) -> &SessionContext { &self.inner }
32//!     fn capabilities(&self) -> &[SessionCapability] { &[] }
33//!     async fn sql(&self, sql: &str) -> Result<DataFrame> {
34//!         // Custom SQL handling
35//!     }
36//! }
37//! ```
38//!
39//! [`ConnectorRegistry`]: crate::ConnectorRegistry
40use std::sync::Arc;
41
42use async_trait::async_trait;
43use datafusion::dataframe::DataFrame;
44use datafusion::execution::SendableRecordBatchStream;
45use datafusion::execution::context::SessionContext;
46use datafusion::prelude::SessionConfig;
47use tracing::error;
48
49use crate::backend::{Capability, ConnectionMetadata};
50use crate::error::{Error, Result};
51use crate::registry::ConnectorRegistry;
52use crate::response::ListSummary;
53
54pub const DEFAULT_SESSION_CAPABILITIES: &[SessionCapability] =
55    &[SessionCapability::ExecuteWithoutConnector];
56
57/// Abstraction over a query-capable `DataFusion` session.
58#[async_trait]
59pub trait QuerySession: Send + Sync + Clone {
60    /// Access the underlying `SessionContext` for registration or low-level control.
61    fn as_session(&self) -> &SessionContext;
62
63    fn capabilities(&self) -> &[SessionCapability];
64
65    /// Execute SQL and return a `DataFrame` for streaming.
66    async fn sql(&self, sql: &str) -> Result<DataFrame>;
67}
68
69#[async_trait]
70impl QuerySession for SessionContext {
71    fn as_session(&self) -> &SessionContext { self }
72
73    fn capabilities(&self) -> &[SessionCapability] { DEFAULT_SESSION_CAPABILITIES }
74
75    async fn sql(&self, sql: &str) -> Result<DataFrame> {
76        SessionContext::sql(self, sql).await.map_err(Error::DataFusion)
77    }
78}
79
80/// Session capabilities a `QuerySession` can expose to the `QueryContext`.
81#[non_exhaustive]
82#[derive(
83    Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize, utoipa::ToSchema,
84)]
85#[serde(rename_all = "snake_case")]
86pub enum SessionCapability {
87    /// Query session context supports executing ad-hoc SQL queries through `DataFusion` without
88    /// providing a specific connector ID.
89    ExecuteWithoutConnector,
90}
91
92/// Query context for interactive data exploration.
93#[derive(Clone)]
94pub struct QueryContext<S = SessionContext>
95where
96    S: QuerySession,
97{
98    session:  S,
99    registry: Arc<dyn ConnectorRegistry>,
100}
101
102impl QueryContext<SessionContext> {
103    /// Create a new query context backed by the provided connector registry.
104    pub fn new(registry: Arc<dyn ConnectorRegistry>) -> Self {
105        let session =
106            SessionContext::new_with_config(SessionConfig::default().with_information_schema(true))
107                .enable_url_table();
108        Self { session, registry }
109    }
110}
111
112impl<S> QueryContext<S>
113where
114    S: QuerySession,
115{
116    /// Construct a query context from a custom session implementation.
117    pub fn with_session(session: S, registry: Arc<dyn ConnectorRegistry>) -> Self {
118        Self { session, registry }
119    }
120
121    /// Access the underlying `DataFusion` session.
122    pub fn session(&self) -> &SessionContext { self.session.as_session() }
123
124    ///  Register a connector to be queried
125    ///
126    /// # Errors
127    /// - If an error occurs while preparing the session.
128    pub async fn register(&self, connector_id: &str) -> Result<Vec<ConnectionMetadata>> {
129        let connector = self.registry.get(connector_id).await?;
130        connector
131            .prepare_session(self.session.as_session())
132            .await
133            .inspect_err(|error| error!(?error, connector_id, "Error preparing session"))?;
134        return self.list_registered().await;
135    }
136
137    /// List catalogs exposed by this connector.
138    #[allow(unused_mut)]
139    pub async fn list_catalogs(&self) -> Vec<String> {
140        // First check for any object stores
141        let session = self.session();
142        let mut catalogs = session.catalog_names();
143
144        #[cfg(feature = "object-store")]
145        {
146            let connectors = self
147                .list_connectors()
148                .await
149                .inspect_err(|error| error!(?error, "Error listing connectors"))
150                .unwrap_or_default();
151            for connector in connectors {
152                if connector.metadata.kind == crate::ConnectionKind::ObjectStore
153                    && let Some(catalog) = connector.catalog.as_ref()
154                {
155                    use datafusion::execution::object_store::ObjectStoreUrl;
156
157                    let Ok(url) = ObjectStoreUrl::parse(catalog) else {
158                        continue;
159                    };
160
161                    if session.runtime_env().object_store(&url).is_ok() {
162                        catalogs.push(catalog.clone());
163                    }
164                }
165            }
166        }
167
168        catalogs
169    }
170
171    /// List available connectors.
172    ///
173    /// # Errors
174    /// - If an error occurs while listing connectors.
175    pub async fn list_connectors(&self) -> Result<Vec<ConnectionMetadata>> {
176        self.registry.list().await
177    }
178
179    /// List available connectors.
180    ///
181    /// # Errors
182    /// - If an error occurs while listing connectors.
183    pub async fn list_registered(&self) -> Result<Vec<ConnectionMetadata>> {
184        self.registry.registered().await
185    }
186
187    /// List databases, or tables/files for a connector, if supported.
188    ///
189    /// # Errors
190    /// - If an error occurs while listing.
191    pub async fn list(&self, connector_id: &str, term: Option<&str>) -> Result<ListSummary> {
192        let connector = self
193            .registry
194            .get(connector_id)
195            .await
196            .inspect_err(|error| error!(?error, connector_id, "Error getting connection"))?;
197        if !connector.connection().has(Capability::List) {
198            tracing::error!(
199                "Connector '{connector_id}' does not support listing: {:?}",
200                connector.connection()
201            );
202            return Err(Error::UnsupportedConnector(format!(
203                "Connector does not support listing: {connector_id}"
204            )));
205        }
206        connector
207            .prepare_session(self.session.as_session())
208            .await
209            .inspect_err(|error| error!(?error, connector_id, "Error preparing session"))?;
210        connector.list(term).await
211    }
212
213    /// Execute a SQL query through the provided connector.
214    ///
215    /// # Errors
216    /// - If an error occurs while executing the query.
217    pub async fn execute_query(
218        &self,
219        connector_id: Option<&str>,
220        sql: &str,
221    ) -> Result<SendableRecordBatchStream> {
222        if let Some(connector_id) = connector_id {
223            let connector = self.registry.get(connector_id).await.inspect_err(|error| {
224                error!(?error, connector_id, "Error getting connection");
225            })?;
226            if !connector.connection().has(Capability::ExecuteSql) {
227                error!(connector_id, "Connector does not support SQL execution");
228                return Err(Error::UnsupportedConnector(
229                    "Connector does not support SQL execution".into(),
230                ));
231            }
232
233            connector
234                .prepare_session(self.session.as_session())
235                .await
236                .inspect_err(|error| error!(?error, connector_id, "Error preparing session"))?;
237        } else if !self.session.capabilities().contains(&SessionCapability::ExecuteWithoutConnector)
238        {
239            return Err(Error::UnsupportedSessionAction(
240                "Query context does not support SQL without connector id".into(),
241            ));
242        }
243
244        self.session
245            .sql(sql)
246            .await
247            .inspect_err(|error| error!(?error, connector_id, "Error running sql"))?
248            .execute_stream()
249            .await
250            .map_err(Error::DataFusion)
251    }
252}