use std::sync::Arc;
use datafusion::catalog::{CatalogProvider, MemoryCatalogProvider};
use datafusion::execution::context::SessionContext;
use datafusion::prelude::*;
use crate::control::planner::converter::PlanConverter;
use crate::control::security::credential::CredentialStore;
use super::catalog::NodeDbSchemaProvider;
/// SQL planning front end: pairs a DataFusion session with the converter
/// that turns optimized logical plans into physical tasks.
pub struct QueryContext {
    /// DataFusion session used to parse and optimize incoming SQL.
    session: SessionContext,
    /// Converts an optimized logical plan into a list of physical tasks.
    converter: PlanConverter,
}
impl QueryContext {
pub fn new() -> Self {
let config = SessionConfig::new()
.with_information_schema(false)
.with_default_catalog_and_schema("nodedb", "public");
let session = SessionContext::new_with_config(config);
register_udfs(&session);
Self {
session,
converter: PlanConverter::new(),
}
}
pub fn with_catalog(credentials: Arc<CredentialStore>, tenant_id: u32) -> Self {
let config = SessionConfig::new()
.with_information_schema(false)
.with_default_catalog_and_schema("nodedb", "public");
let session = SessionContext::new_with_config(config);
register_udfs(&session);
let schema_provider = Arc::new(NodeDbSchemaProvider::new(
Arc::clone(&credentials),
tenant_id,
));
let catalog = MemoryCatalogProvider::new();
catalog
.register_schema("public", schema_provider)
.expect("register schema");
session.register_catalog("nodedb", Arc::new(catalog));
Self {
session,
converter: PlanConverter::with_credentials(credentials),
}
}
pub async fn sql_to_logical(
&self,
sql: &str,
) -> crate::Result<datafusion::logical_expr::LogicalPlan> {
let df = self
.session
.sql(sql)
.await
.map_err(|e| crate::Error::PlanError {
detail: format!("SQL parse: {e}"),
})?;
let plan = df
.into_optimized_plan()
.map_err(|e| crate::Error::PlanError {
detail: format!("optimization: {e}"),
})?;
Ok(plan)
}
pub async fn plan_sql(
&self,
sql: &str,
tenant_id: crate::types::TenantId,
) -> crate::Result<Vec<super::physical::PhysicalTask>> {
let logical = self.sql_to_logical(sql).await?;
self.converter.convert(&logical, tenant_id)
}
pub async fn plan_sql_with_rls(
&self,
sql: &str,
tenant_id: crate::types::TenantId,
auth: &crate::control::security::auth_context::AuthContext,
rls_store: &crate::control::security::rls::RlsPolicyStore,
) -> crate::Result<Vec<super::physical::PhysicalTask>> {
let logical = self.sql_to_logical(sql).await?;
let mut tasks = self.converter.convert(&logical, tenant_id)?;
super::rls_injection::inject_rls(&mut tasks, rls_store, auth)?;
Ok(tasks)
}
pub fn session(&self) -> &SessionContext {
&self.session
}
pub fn register_udf(&self, udf: datafusion::logical_expr::ScalarUDF) {
self.session.register_udf(udf);
}
}
impl Default for QueryContext {
fn default() -> Self {
Self::new()
}
}
/// Registers the built-in scalar UDFs on `session`.
///
/// The functions from `super::udf` and `super::udf::spatial` are wrapped with
/// `ScalarUDF::new_from_impl` and registered one by one; registration of the
/// timeseries functions is delegated to
/// `nodedb_query::ts_udfs::register_timeseries_udfs`.
fn register_udfs(session: &SessionContext) {
    use datafusion::logical_expr::ScalarUDF;

    use super::udf::spatial::{
        GeoDistance, StContains, StDistance, StDwithin, StIntersects, StWithin,
    };
    use super::udf::{
        Bm25Score, DocArrayContains, DocExists, DocGet, RrfScore, TextMatch, VectorDistance,
    };

    // Single registration point so each entry below is one short line.
    let register = |udf: ScalarUDF| session.register_udf(udf);

    // Functions imported from `super::udf`.
    register(ScalarUDF::new_from_impl(DocGet::new()));
    register(ScalarUDF::new_from_impl(DocExists::new()));
    register(ScalarUDF::new_from_impl(DocArrayContains::new()));
    register(ScalarUDF::new_from_impl(VectorDistance::new()));
    register(ScalarUDF::new_from_impl(RrfScore::new()));
    register(ScalarUDF::new_from_impl(Bm25Score::new()));
    register(ScalarUDF::new_from_impl(TextMatch::new()));

    // Functions imported from `super::udf::spatial`.
    register(ScalarUDF::new_from_impl(StDwithin::new()));
    register(ScalarUDF::new_from_impl(StContains::new()));
    register(ScalarUDF::new_from_impl(StIntersects::new()));
    register(ScalarUDF::new_from_impl(StWithin::new()));
    register(ScalarUDF::new_from_impl(StDistance::new()));
    register(ScalarUDF::new_from_impl(GeoDistance::new()));

    // Timeseries functions are registered by their own crate.
    nodedb_query::ts_udfs::register_timeseries_udfs(session);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A plain `SELECT` against a table created in the session should plan
    /// successfully.
    #[tokio::test]
    async fn parse_simple_select() {
        let ctx = QueryContext::new();

        // Create the table the query will read from.
        let create_users =
            "CREATE TABLE users (id INT, name VARCHAR, email VARCHAR) AS VALUES (1, 'alice', 'a@b.com')";
        ctx.session().sql(create_users).await.unwrap();

        let query = "SELECT id, name FROM users WHERE id = 1";
        let plan = ctx.sql_to_logical(query).await;
        assert!(plan.is_ok(), "failed: {:?}", plan.err());
    }

    /// Malformed SQL must be reported as an error instead of panicking.
    #[tokio::test]
    async fn invalid_sql_returns_error() {
        let ctx = QueryContext::new();
        assert!(ctx
            .sql_to_logical("SELECTT * FROMM nowhere")
            .await
            .is_err());
    }
}