use std::sync::Arc;
use async_trait::async_trait;
use datafusion::dataframe::DataFrame;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::execution::context::SessionContext;
use datafusion::prelude::SessionConfig;
use tracing::error;
use crate::backend::{Capability, ConnectionMetadata};
use crate::error::{Error, Result};
use crate::registry::ConnectorRegistry;
use crate::response::ListSummary;
/// Capabilities reported by the stock [`SessionContext`] implementation of
/// [`QuerySession`]: it may execute SQL without a connector id.
pub const DEFAULT_SESSION_CAPABILITIES: &[SessionCapability] =
&[SessionCapability::ExecuteWithoutConnector];
/// Abstraction over the session a [`QueryContext`] runs SQL against.
///
/// Implementors must be cheaply shareable (`Send + Sync + Clone`) and expose
/// the underlying DataFusion [`SessionContext`].
#[async_trait]
pub trait QuerySession: Send + Sync + Clone {
/// Borrows the underlying DataFusion session.
fn as_session(&self) -> &SessionContext;
/// Lists the capabilities this session advertises.
fn capabilities(&self) -> &[SessionCapability];
/// Plans `sql` and returns the resulting [`DataFrame`].
async fn sql(&self, sql: &str) -> Result<DataFrame>;
}
#[async_trait]
impl QuerySession for SessionContext {
fn as_session(&self) -> &SessionContext { self }
fn capabilities(&self) -> &[SessionCapability] { DEFAULT_SESSION_CAPABILITIES }
async fn sql(&self, sql: &str) -> Result<DataFrame> {
SessionContext::sql(self, sql).await.map_err(Error::DataFusion)
}
}
/// Optional behaviours a [`QuerySession`] may advertise.
///
/// Variants serialize in `snake_case`. The enum is `#[non_exhaustive]`, so
/// further capabilities may be added.
#[non_exhaustive]
#[derive(
Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize, utoipa::ToSchema,
)]
#[serde(rename_all = "snake_case")]
pub enum SessionCapability {
/// The session may execute SQL when no connector id is supplied;
/// `QueryContext::execute_query` rejects such calls when this is absent.
ExecuteWithoutConnector,
}
/// Pairs a query session with the registry used to look up connectors.
///
/// The session type defaults to DataFusion's [`SessionContext`].
#[derive(Clone)]
pub struct QueryContext<S: QuerySession = SessionContext> {
    /// Session that SQL is planned and executed against.
    session: S,
    /// Source of connectors, looked up by id.
    registry: Arc<dyn ConnectorRegistry>,
}
impl QueryContext<SessionContext> {
    /// Creates a context around a fresh [`SessionContext`] configured with the
    /// information schema turned on and `enable_url_table` applied.
    pub fn new(registry: Arc<dyn ConnectorRegistry>) -> Self {
        let config = SessionConfig::default().with_information_schema(true);
        let session = SessionContext::new_with_config(config).enable_url_table();
        Self { session, registry }
    }
}
impl<S> QueryContext<S>
where
S: QuerySession,
{
pub fn with_session(session: S, registry: Arc<dyn ConnectorRegistry>) -> Self {
Self { session, registry }
}
pub fn session(&self) -> &SessionContext { self.session.as_session() }
pub async fn register(&self, connector_id: &str) -> Result<Vec<ConnectionMetadata>> {
let connector = self.registry.get(connector_id).await?;
connector
.prepare_session(self.session.as_session())
.await
.inspect_err(|error| error!(?error, connector_id, "Error preparing session"))?;
return self.list_registered().await;
}
#[allow(unused_mut)]
pub async fn list_catalogs(&self) -> Vec<String> {
let session = self.session();
let mut catalogs = session.catalog_names();
#[cfg(feature = "object-store")]
{
let connectors = self
.list_connectors()
.await
.inspect_err(|error| error!(?error, "Error listing connectors"))
.unwrap_or_default();
for connector in connectors {
if connector.metadata.kind == crate::ConnectionKind::ObjectStore
&& let Some(catalog) = connector.catalog.as_ref()
{
use datafusion::execution::object_store::ObjectStoreUrl;
let Ok(url) = ObjectStoreUrl::parse(catalog) else {
continue;
};
if session.runtime_env().object_store(&url).is_ok() {
catalogs.push(catalog.clone());
}
}
}
}
catalogs
}
pub async fn list_connectors(&self) -> Result<Vec<ConnectionMetadata>> {
self.registry.list().await
}
pub async fn list_registered(&self) -> Result<Vec<ConnectionMetadata>> {
self.registry.registered().await
}
pub async fn list(&self, connector_id: &str, term: Option<&str>) -> Result<ListSummary> {
let connector = self
.registry
.get(connector_id)
.await
.inspect_err(|error| error!(?error, connector_id, "Error getting connection"))?;
if !connector.connection().has(Capability::List) {
tracing::error!(
"Connector '{connector_id}' does not support listing: {:?}",
connector.connection()
);
return Err(Error::UnsupportedConnector(format!(
"Connector does not support listing: {connector_id}"
)));
}
connector
.prepare_session(self.session.as_session())
.await
.inspect_err(|error| error!(?error, connector_id, "Error preparing session"))?;
connector.list(term).await
}
pub async fn execute_query(
&self,
connector_id: Option<&str>,
sql: &str,
) -> Result<SendableRecordBatchStream> {
if let Some(connector_id) = connector_id {
let connector = self.registry.get(connector_id).await.inspect_err(|error| {
error!(?error, connector_id, "Error getting connection");
})?;
if !connector.connection().has(Capability::ExecuteSql) {
error!(connector_id, "Connector does not support SQL execution");
return Err(Error::UnsupportedConnector(
"Connector does not support SQL execution".into(),
));
}
connector
.prepare_session(self.session.as_session())
.await
.inspect_err(|error| error!(?error, connector_id, "Error preparing session"))?;
} else if !self.session.capabilities().contains(&SessionCapability::ExecuteWithoutConnector)
{
return Err(Error::UnsupportedSessionAction(
"Query context does not support SQL without connector id".into(),
));
}
self.session
.sql(sql)
.await
.inspect_err(|error| error!(?error, connector_id, "Error running sql"))?
.execute_stream()
.await
.map_err(Error::DataFusion)
}
}