1use std::sync::Arc;
41
42use async_trait::async_trait;
43use datafusion::dataframe::DataFrame;
44use datafusion::execution::SendableRecordBatchStream;
45use datafusion::execution::context::SessionContext;
46use datafusion::prelude::SessionConfig;
47use tracing::error;
48
49use crate::backend::{Capability, ConnectionMetadata};
50use crate::error::{Error, Result};
51use crate::registry::ConnectorRegistry;
52use crate::response::ListSummary;
53
54pub const DEFAULT_SESSION_CAPABILITIES: &[SessionCapability] =
55 &[SessionCapability::ExecuteWithoutConnector];
56
57#[async_trait]
59pub trait QuerySession: Send + Sync + Clone {
60 fn as_session(&self) -> &SessionContext;
62
63 fn capabilities(&self) -> &[SessionCapability];
64
65 async fn sql(&self, sql: &str) -> Result<DataFrame>;
67}
68
69#[async_trait]
70impl QuerySession for SessionContext {
71 fn as_session(&self) -> &SessionContext { self }
72
73 fn capabilities(&self) -> &[SessionCapability] { DEFAULT_SESSION_CAPABILITIES }
74
75 async fn sql(&self, sql: &str) -> Result<DataFrame> {
76 SessionContext::sql(self, sql).await.map_err(Error::DataFusion)
77 }
78}
79
80#[non_exhaustive]
82#[derive(
83 Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize, utoipa::ToSchema,
84)]
85#[serde(rename_all = "snake_case")]
86pub enum SessionCapability {
87 ExecuteWithoutConnector,
90}
91
92#[derive(Clone)]
94pub struct QueryContext<S = SessionContext>
95where
96 S: QuerySession,
97{
98 session: S,
99 registry: Arc<dyn ConnectorRegistry>,
100}
101
102impl QueryContext<SessionContext> {
103 pub fn new(registry: Arc<dyn ConnectorRegistry>) -> Self {
105 let session =
106 SessionContext::new_with_config(SessionConfig::default().with_information_schema(true))
107 .enable_url_table();
108 Self { session, registry }
109 }
110}
111
112impl<S> QueryContext<S>
113where
114 S: QuerySession,
115{
116 pub fn with_session(session: S, registry: Arc<dyn ConnectorRegistry>) -> Self {
118 Self { session, registry }
119 }
120
121 pub fn session(&self) -> &SessionContext { self.session.as_session() }
123
124 pub async fn register(&self, connector_id: &str) -> Result<Vec<ConnectionMetadata>> {
129 let connector = self.registry.get(connector_id).await?;
130 connector
131 .prepare_session(self.session.as_session())
132 .await
133 .inspect_err(|error| error!(?error, connector_id, "Error preparing session"))?;
134 return self.list_registered().await;
135 }
136
137 #[allow(unused_mut)]
139 pub async fn list_catalogs(&self) -> Vec<String> {
140 let session = self.session();
142 let mut catalogs = session.catalog_names();
143
144 #[cfg(feature = "object-store")]
145 {
146 let connectors = self
147 .list_connectors()
148 .await
149 .inspect_err(|error| error!(?error, "Error listing connectors"))
150 .unwrap_or_default();
151 for connector in connectors {
152 if connector.metadata.kind == crate::ConnectionKind::ObjectStore
153 && let Some(catalog) = connector.catalog.as_ref()
154 {
155 use datafusion::execution::object_store::ObjectStoreUrl;
156
157 let Ok(url) = ObjectStoreUrl::parse(catalog) else {
158 continue;
159 };
160
161 if session.runtime_env().object_store(&url).is_ok() {
162 catalogs.push(catalog.clone());
163 }
164 }
165 }
166 }
167
168 catalogs
169 }
170
171 pub async fn list_connectors(&self) -> Result<Vec<ConnectionMetadata>> {
176 self.registry.list().await
177 }
178
179 pub async fn list_registered(&self) -> Result<Vec<ConnectionMetadata>> {
184 self.registry.registered().await
185 }
186
187 pub async fn list(&self, connector_id: &str, term: Option<&str>) -> Result<ListSummary> {
192 let connector = self
193 .registry
194 .get(connector_id)
195 .await
196 .inspect_err(|error| error!(?error, connector_id, "Error getting connection"))?;
197 if !connector.connection().has(Capability::List) {
198 tracing::error!(
199 "Connector '{connector_id}' does not support listing: {:?}",
200 connector.connection()
201 );
202 return Err(Error::UnsupportedConnector(format!(
203 "Connector does not support listing: {connector_id}"
204 )));
205 }
206 connector
207 .prepare_session(self.session.as_session())
208 .await
209 .inspect_err(|error| error!(?error, connector_id, "Error preparing session"))?;
210 connector.list(term).await
211 }
212
213 pub async fn execute_query(
218 &self,
219 connector_id: Option<&str>,
220 sql: &str,
221 ) -> Result<SendableRecordBatchStream> {
222 if let Some(connector_id) = connector_id {
223 let connector = self.registry.get(connector_id).await.inspect_err(|error| {
224 error!(?error, connector_id, "Error getting connection");
225 })?;
226 if !connector.connection().has(Capability::ExecuteSql) {
227 error!(connector_id, "Connector does not support SQL execution");
228 return Err(Error::UnsupportedConnector(
229 "Connector does not support SQL execution".into(),
230 ));
231 }
232
233 connector
234 .prepare_session(self.session.as_session())
235 .await
236 .inspect_err(|error| error!(?error, connector_id, "Error preparing session"))?;
237 } else if !self.session.capabilities().contains(&SessionCapability::ExecuteWithoutConnector)
238 {
239 return Err(Error::UnsupportedSessionAction(
240 "Query context does not support SQL without connector id".into(),
241 ));
242 }
243
244 self.session
245 .sql(sql)
246 .await
247 .inspect_err(|error| error!(?error, connector_id, "Error running sql"))?
248 .execute_stream()
249 .await
250 .map_err(Error::DataFusion)
251 }
252}