#![forbid(unsafe_code)]
use std::collections::{BTreeSet, HashMap, VecDeque};
use std::fmt;
use std::num::NonZeroUsize;
use std::ops::ControlFlow;
use std::path::Path;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use arrow::util::pretty::pretty_format_batches;
use catalog::{InMemoryCatalog, datafusion_bridge::DataFusionCatalogBridge};
use datafusion::dataframe::DataFrame as DataFusionDataFrame;
use datafusion::prelude::{ParquetReadOptions, SessionContext};
use datafusion::sql::sqlparser::{ast::visit_relations, dialect::GenericDialect, parser::Parser};
use object_store::aws::AmazonS3Builder;
use krishiv_plan::optimizer::{CostModel, Optimizer};
use krishiv_plan::{ExecutionKind, LogicalPlan, PlanNode};
pub mod analyze;
pub mod catalog;
pub mod cep_sql;
pub mod connector_table;
pub mod create_function_ddl;
pub mod grammar;
pub mod incremental_view;
pub mod introspection_sql;
pub mod kafka_table;
pub mod lakehouse;
pub mod live_table;
pub mod pipeline_ddl;
pub mod pivot_sql;
pub mod recursive_cte;
pub mod spark_sql_ext;
pub mod sqlstate;
pub mod subquery;
pub mod unnest_sql;
pub mod streaming;
pub mod streaming_tvf;
pub mod streaming_window_plan;
mod udf;
mod window_functions;
pub use cep_sql::{
MatchRecognizeStatement, execute_streaming_match_recognize, parse_match_recognize,
};
pub use lakehouse::{AsOfTableRef, MergeResult, MergeTargetUnsupportedError, preprocess_as_of_sql};
pub use grammar::{
FeatureEntry, FeatureStatus, feature_matrix, features_by_status, features_for_category,
};
pub use sqlstate::{SqlStateError, sqlstate_for};
pub use streaming::{ContinuousInputError, ContinuousTableInput};
/// Result alias used by this crate's SQL APIs; the error side is always [`SqlError`].
pub type SqlResult<T> = Result<T, SqlError>;
/// Pinned, boxed, `Send` stream whose items are `Result<RecordBatch, SqlError>`.
pub type SqlStream =
std::pin::Pin<Box<dyn futures::stream::Stream<Item = Result<RecordBatch, SqlError>> + Send>>;
/// Process-wide sequence used to keep ephemeral table names distinct.
static EPHEMERAL_TABLE_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Returns a fresh name of the form `__{prefix}_{id}`.
///
/// `id` is taken from [`EPHEMERAL_TABLE_COUNTER`], so two calls never yield
/// the same name (until the `u64` counter wraps).
fn next_ephemeral_name(prefix: &str) -> String {
    // Only the uniqueness of the fetched value matters here, so the counter
    // does not need to order any other memory accesses.
    let sequence = EPHEMERAL_TABLE_COUNTER.fetch_add(1, Ordering::Relaxed);
    format!("__{}_{}", prefix, sequence)
}
/// Tells `SqlEngine::build_local` whether to install the window helper UDFs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WindowFnRegistration {
/// Call `window_functions::register_window_functions` on the new context.
Register,
/// Leave the window helper UDFs unregistered (used by the fallback paths).
Skip,
}
struct PlanCache {
map: HashMap<String, datafusion::logical_expr::LogicalPlan>,
order: VecDeque<String>,
max: usize,
}
impl PlanCache {
fn new(max: usize) -> Self {
Self {
map: HashMap::new(),
order: VecDeque::new(),
max,
}
}
fn get(&self, key: &str) -> Option<&datafusion::logical_expr::LogicalPlan> {
self.map.get(key)
}
fn insert(&mut self, key: String, plan: datafusion::logical_expr::LogicalPlan) {
if self.map.contains_key(&key) {
self.order.retain(|k| k != &key);
} else if self.map.len() >= self.max
&& let Some(oldest) = self.order.pop_front()
{
self.map.remove(&oldest);
}
self.order.push_back(key.clone());
self.map.insert(key, plan);
}
fn clear(&mut self) {
self.map.clear();
self.order.clear();
}
#[cfg(test)]
fn is_empty(&self) -> bool {
self.map.is_empty()
}
}
/// Options accepted by `SqlEngine::read_parquet_with_options`.
#[derive(Debug, Clone, Default)]
pub struct ParquetReaderOptions {
/// Requested batch size.
// NOTE(review): `read_parquet_with_options` only checks whether this is set;
// the numeric value itself is not forwarded anywhere in the visible code.
pub batch_size: Option<usize>,
}
/// Options accepted by `SqlEngine::read_csv_with_options`.
#[derive(Debug, Clone, Default)]
pub struct CsvReaderOptions {
/// Field delimiter; `None` leaves the DataFusion `CsvReadOptions` default in place.
pub delimiter: Option<char>,
/// Whether the first row is a header; `None` leaves the DataFusion default in place.
pub has_header: Option<bool>,
}
/// Parquet writer settings. Every field is optional.
// NOTE(review): the consumer of this struct is outside this part of the file;
// field meanings below are read off the names only and should be confirmed there.
#[derive(Debug, Clone, Default)]
pub struct ParquetWriterOptions {
/// Compression setting, as a string (presumably a codec name; verify against the writer).
pub compression: Option<String>,
/// Upper bound on row-group size (presumably in rows; verify against the writer).
pub max_row_group_size: Option<usize>,
}
/// CSV writer settings. Every field is optional.
// NOTE(review): the consumer of this struct is outside this part of the file.
#[derive(Debug, Clone, Default)]
pub struct CsvWriterOptions {
/// Field delimiter to write (presumably `None` means the writer's default; confirm).
pub delimiter: Option<char>,
/// Whether to emit a header row (presumably `None` means the writer's default; confirm).
pub has_header: Option<bool>,
}
/// Error type returned by the SQL engine APIs in this crate.
///
/// The enum is `#[non_exhaustive]`, so downstream `match`es need a wildcard arm.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum SqlError {
/// The query string was empty or whitespace-only.
#[error("SQL query is empty")]
EmptyQuery,
/// A table name argument was empty or whitespace-only.
#[error("table name is empty")]
EmptyTableName,
/// The requested feature is not supported; `feature` names it.
#[error("unsupported SQL feature: {feature}")]
Unsupported { feature: String },
/// A table function definition was rejected (e.g. by `register_table_udf_fn`).
#[error("invalid table function: {message}")]
InvalidTableFunction { message: String },
/// An error from DataFusion, or another lower-level failure, flattened to its message.
#[error("DataFusion error: {message}")]
DataFusion { message: String },
/// An optimizer error, displayed transparently.
#[error(transparent)]
Optimizer(#[from] krishiv_plan::optimizer::OptimizerError),
/// The operation was refused; `reason` explains why.
#[error("access denied: {reason}")]
AccessDenied { reason: String },
/// The operation identified by `operation_id` was cancelled.
#[error("operation {operation_id} was cancelled")]
OperationCancelled { operation_id: u64 },
/// The query exceeded its time budget of `timeout_ms` milliseconds.
#[error("query timed out after {timeout_ms} ms")]
Timeout { timeout_ms: u64 },
}
impl From<datafusion::error::DataFusionError> for SqlError {
fn from(value: datafusion::error::DataFusionError) -> Self {
Self::DataFusion {
message: value.to_string(),
}
}
}
/// A SQL query string paired with a `krishiv_plan` logical plan.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SqlPlan {
// The SQL text, exposed through `SqlPlan::query`.
query: String,
// The logical plan, exposed through `SqlPlan::logical_plan`.
logical_plan: LogicalPlan,
}
impl SqlPlan {
/// Returns the SQL text stored in this plan.
pub fn query(&self) -> &str {
&self.query
}
/// Returns the logical plan stored in this plan.
pub fn logical_plan(&self) -> &LogicalPlan {
&self.logical_plan
}
}
/// Plan-cache capacity used when `KRISHIV_PLAN_CACHE_MAX_ENTRIES` is unset or invalid.
const PLAN_CACHE_MAX_ENTRIES: usize = 256;

/// Reads the plan-cache capacity from `KRISHIV_PLAN_CACHE_MAX_ENTRIES`.
///
/// Falls back to [`PLAN_CACHE_MAX_ENTRIES`] when the variable is missing, is
/// not a valid `usize`, or is `0`. Surrounding whitespace is ignored, matching
/// how `resolve_query_memory_limit_bytes` treats its input, so a value such as
/// `" 512 "` or one with a trailing newline is honoured instead of being
/// silently replaced by the default.
fn resolve_plan_cache_max_entries() -> usize {
    std::env::var("KRISHIV_PLAN_CACHE_MAX_ENTRIES")
        .ok()
        .and_then(|raw| raw.trim().parse::<usize>().ok())
        .filter(|&entries| entries > 0)
        .unwrap_or(PLAN_CACHE_MAX_ENTRIES)
}
/// Row limit used when no valid streaming MATCH_RECOGNIZE limit is configured.
const STREAMING_CEP_MAX_ROWS_DEFAULT: usize = 100_000;

/// Parses a streaming MATCH_RECOGNIZE row limit from its textual form.
///
/// Returns [`STREAMING_CEP_MAX_ROWS_DEFAULT`] when `raw` is `None`, is not a
/// valid `usize`, or is `0`. Surrounding whitespace is ignored, consistent
/// with `resolve_query_memory_limit_bytes`, so `" 42 "` resolves to `42`
/// rather than silently falling back to the default.
pub fn resolve_streaming_match_recognize_limit(raw: Option<&str>) -> usize {
    raw.and_then(|text| text.trim().parse::<usize>().ok())
        .filter(|limit| *limit > 0)
        .unwrap_or(STREAMING_CEP_MAX_ROWS_DEFAULT)
}
/// Resolves the streaming MATCH_RECOGNIZE row limit from the
/// `KRISHIV_MATCH_RECOGNIZE_STREAMING_LIMIT` environment variable, applying
/// the same fallback rules as [`resolve_streaming_match_recognize_limit`].
pub fn streaming_match_recognize_limit_from_env() -> usize {
    let configured = std::env::var("KRISHIV_MATCH_RECOGNIZE_STREAMING_LIMIT").ok();
    resolve_streaming_match_recognize_limit(configured.as_deref())
}
/// Parses a per-query memory limit, in bytes, from its textual form.
///
/// Leading and trailing whitespace is ignored. Returns `None` when `raw` is
/// `None`, is not a valid `usize`, or is `0`.
pub fn resolve_query_memory_limit_bytes(raw: Option<&str>) -> Option<usize> {
    match raw?.trim().parse::<usize>() {
        Ok(bytes) if bytes > 0 => Some(bytes),
        _ => None,
    }
}
/// Resolves the per-query memory limit from the
/// `KRISHIV_QUERY_MEMORY_LIMIT_BYTES` environment variable, applying the same
/// rules as [`resolve_query_memory_limit_bytes`].
pub fn query_memory_limit_from_env() -> Option<usize> {
    let configured = std::env::var("KRISHIV_QUERY_MEMORY_LIMIT_BYTES").ok();
    resolve_query_memory_limit_bytes(configured.as_deref())
}
/// Reads the execution batch size from `KRISHIV_BATCH_SIZE`.
///
/// Returns `8192` when the variable is missing, is not a valid `usize`, or is
/// `0`. Surrounding whitespace is ignored (as `resolve_query_memory_limit_bytes`
/// does for its input) so that a padded value is honoured instead of being
/// silently replaced by the default.
pub fn batch_size_from_env() -> usize {
    std::env::var("KRISHIV_BATCH_SIZE")
        .ok()
        .and_then(|raw| raw.trim().parse::<usize>().ok())
        .filter(|size| *size > 0)
        .unwrap_or(8192)
}
/// Reads the default target parallelism from `KRISHIV_TARGET_PARALLELISM`.
///
/// When the variable is missing, is not a valid `usize`, or is `0`, the
/// result is `std::thread::available_parallelism()`, or `1` if that query
/// fails. Surrounding whitespace is ignored (as
/// `resolve_query_memory_limit_bytes` does for its input) so that a padded
/// value is honoured instead of being silently replaced by the fallback.
pub fn default_parallelism_from_env() -> NonZeroUsize {
    std::env::var("KRISHIV_TARGET_PARALLELISM")
        .ok()
        .and_then(|raw| raw.trim().parse::<usize>().ok())
        .and_then(NonZeroUsize::new)
        .unwrap_or_else(|| std::thread::available_parallelism().unwrap_or(NonZeroUsize::MIN))
}
/// Builds the DataFusion `SessionConfig` used by the local engine.
///
/// The config uses `target_partitions` partitions, the batch size from
/// [`batch_size_from_env`], enables the information schema, turns round-robin
/// repartitioning on only when more than one partition is requested, and
/// enables Parquet filter pushdown and the Parquet page index.
fn build_single_node_session_config(
    target_partitions: NonZeroUsize,
) -> datafusion::prelude::SessionConfig {
    let partitions = target_partitions.get();
    let rows_per_batch = batch_size_from_env();
    // Repartitioning a single partition round-robin would be pure overhead.
    let round_robin = partitions > 1;

    let mut config = datafusion::prelude::SessionConfig::new()
        .with_target_partitions(partitions)
        .with_batch_size(rows_per_batch)
        .with_information_schema(true)
        .set_bool(
            "datafusion.optimizer.enable_round_robin_repartition",
            round_robin,
        );

    let parquet = &mut config.options_mut().execution.parquet;
    parquet.pushdown_filters = true;
    parquet.enable_page_index = true;

    config
}
/// DataFusion-backed SQL engine.
///
/// The type is `Clone`; most state sits behind `Arc`, so clones share the
/// streaming-source set, plan cache, row counts and registries. The plain
/// fields (`target_parallelism`, `udf_registry`, `udf_limits`,
/// `memory_limit_bytes`) are copied per clone.
#[derive(Clone)]
pub struct SqlEngine {
// DataFusion session that tables, catalogs and UDFs are registered on.
context: SessionContext,
// Value reported by `target_parallelism()`; set at build time and by `with_target_parallelism`.
target_parallelism: NonZeroUsize,
// Optional in-memory catalog; `build_local` registers it under the name "krishiv".
krishiv_catalog: Option<Arc<RwLock<InMemoryCatalog>>>,
// Optional shared UDF registry that the `sync_*_udfs` methods read from.
udf_registry: Option<std::sync::Arc<std::sync::RwLock<krishiv_plan::udf::UdfRegistry>>>,
// Names of tables treated as streaming sources (see `is_streaming_source`).
streaming_sources: Arc<RwLock<std::collections::HashSet<String>>>,
// Serialises `register_streaming_table*` so check-then-register is not interleaved.
streaming_registration: Arc<Mutex<()>>,
// Fast-path flag read by `is_streaming_query` before it takes the `streaming_sources` lock.
has_streaming_sources: Arc<AtomicBool>,
// Limits used by `sync_scalar_udfs`; the type's `Default` is used when `None`.
udf_limits: Option<krishiv_plan::udf::ResourceLimits>,
// Incremented by `bump_udf_version`.
udf_registry_version: Arc<AtomicU64>,
// Registry version that `sql` last synced against; starts at `u64::MAX`.
udf_last_synced_version: Arc<AtomicU64>,
// Cached logical plans; cleared by `invalidate_plan_cache`.
plan_cache: Arc<Mutex<PlanCache>>,
// Optional shuffle partition count (see `with_shuffle_partitions`).
shuffle_partitions: Arc<std::sync::RwLock<Option<u32>>>,
// Row counts per table name, recorded by `register_parquet` / `register_record_batches`.
table_row_counts: Arc<std::sync::RwLock<HashMap<String, u64>>>,
// Memory limit the engine was built with, if any.
memory_limit_bytes: Option<usize>,
// Iceberg catalogs added through `with_iceberg_catalog`, paired with their catalog names.
#[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
iceberg_catalogs: Arc<std::sync::RwLock<Vec<(Arc<catalog::unified::KrishivCatalog>, String)>>>,
// Registries exposed through the accessor methods of the same names.
live_table_registry: Arc<live_table::LiveTableRegistry>,
incremental_view_registry: Arc<incremental_view::IncrementalViewRegistry>,
pipeline_registry: Arc<pipeline_ddl::PipelineRegistry>,
operation_registry: Arc<OperationRegistry>,
}
/// Compact `Debug` output: only the backend name is shown, and the remaining
/// fields are elided via `finish_non_exhaustive`.
impl fmt::Debug for SqlEngine {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("SqlEngine");
        out.field("backend", &"datafusion");
        out.finish_non_exhaustive()
    }
}
/// `SqlEngine::default()` is the same as [`SqlEngine::new`].
impl Default for SqlEngine {
fn default() -> Self {
Self::new()
}
}
impl SqlEngine {
/// Creates an engine whose memory limit comes from `query_memory_limit_from_env`.
///
/// Construction itself goes through [`Self::new_with_memory_limit`].
pub fn new() -> Self {
Self::new_with_memory_limit(query_memory_limit_from_env())
}
/// Creates an engine with an optional memory limit, degrading step by step
/// instead of failing.
///
/// Attempts, in order:
/// 1. window helper UDFs registered, with `memory_limit_bytes`;
/// 2. window helper UDFs skipped, with `memory_limit_bytes` (after a warning);
/// 3. window helper UDFs skipped, no memory limit (after an error log);
/// 4. `build_absolute_minimal`.
///
/// Every attempt uses a target partition count of `NonZeroUsize::MIN`.
pub fn new_with_memory_limit(memory_limit_bytes: Option<usize>) -> Self {
    let partitions = NonZeroUsize::MIN;

    let first_failure = match Self::build_local(
        None,
        WindowFnRegistration::Register,
        partitions,
        memory_limit_bytes,
    ) {
        Ok(engine) => return engine,
        Err(err) => err,
    };
    tracing::warn!(
        error = %first_failure,
        "SqlEngine::new: window helper UDF registration failed; \
         window SQL functions will be unavailable, other queries are unaffected"
    );

    let second_failure = match Self::build_local(
        None,
        WindowFnRegistration::Skip,
        partitions,
        memory_limit_bytes,
    ) {
        Ok(engine) => return engine,
        Err(err) => err,
    };
    tracing::error!(
        error = %second_failure,
        "memory-limited DataFusion runtime construction failed; \
         falling back to an unbounded engine"
    );

    match Self::build_local(None, WindowFnRegistration::Skip, partitions, None) {
        Ok(engine) => engine,
        Err(_) => Self::build_absolute_minimal(partitions),
    }
}
/// Fallible counterpart of `new`: builds the engine with window helper UDFs
/// and the memory limit from `query_memory_limit_from_env`, and returns the
/// construction error instead of falling back to a reduced engine.
pub fn try_new() -> SqlResult<Self> {
    let memory_limit = query_memory_limit_from_env();
    Self::build_local(
        None,
        WindowFnRegistration::Register,
        NonZeroUsize::MIN,
        memory_limit,
    )
}
/// Builds an engine backed by the given in-memory catalog.
///
/// # Errors
///
/// Returns `SqlError::DataFusion` when the resolved durability profile
/// requires fail-closed metadata (the in-memory catalog is refused there),
/// or when engine construction itself fails.
pub fn with_in_memory_catalog(catalog: Arc<RwLock<InMemoryCatalog>>) -> SqlResult<Self> {
    let profile = krishiv_common::resolve_durability_profile();
    if krishiv_common::profile_requires_fail_closed_metadata(profile) {
        let message = String::from(
            "InMemoryCatalog is dev-only; configure a durable REST or file-backed \
             catalog for production deployments",
        );
        return Err(SqlError::DataFusion { message });
    }

    let memory_limit = query_memory_limit_from_env();
    Self::build_local(
        Some(catalog),
        WindowFnRegistration::Register,
        NonZeroUsize::MIN,
        memory_limit,
    )
}
/// Builder-style setter for the value returned by `target_parallelism()`.
// NOTE(review): this only updates the field; the `SessionContext` keeps the
// partition count it was built with in `build_local`.
#[must_use]
pub fn with_target_parallelism(mut self, n: NonZeroUsize) -> Self {
self.target_parallelism = n;
self
}
/// Returns the engine's configured target parallelism.
pub fn target_parallelism(&self) -> NonZeroUsize {
self.target_parallelism
}
/// Returns the memory limit, in bytes, the engine was built with (`None` if unbounded).
pub fn memory_limit_bytes(&self) -> Option<usize> {
self.memory_limit_bytes
}
/// Returns the configured shuffle partition count, if one was set.
///
/// A poisoned lock is recovered rather than treated as an error.
pub fn shuffle_partitions(&self) -> Option<u32> {
    let guard = match self.shuffle_partitions.read() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    };
    *guard
}
/// Returns a shared handle to the per-table row-count map (same `Arc`, not a copy).
pub fn table_row_counts(&self) -> Arc<std::sync::RwLock<HashMap<String, u64>>> {
Arc::clone(&self.table_row_counts)
}
pub fn registered_table_names(&self) -> Vec<String> {
let mut names = Vec::new();
for catalog_name in self.context.catalog_names() {
let Some(catalog) = self.context.catalog(&catalog_name) else {
continue;
};
for schema_name in catalog.schema_names() {
let Some(schema) = catalog.schema(&schema_name) else {
continue;
};
names.extend(schema.table_names());
}
}
names.sort();
names.dedup();
names
}
/// Wraps a DataFusion dataframe in a `SqlDataFrame` that shares this engine's
/// row-count map and carries a clone of its session context.
fn make_sql_df(&self, name: &str, dataframe: DataFusionDataFrame) -> SqlDataFrame {
    let row_counts = self.table_row_counts();
    let wrapped = SqlDataFrame::new(name, dataframe, row_counts);
    wrapped.with_context(self.context.clone())
}
/// Tags `df` with the query text and its execution kind.
///
/// The kind is `Streaming` only when `is_streaming_query` positively says so;
/// a parse failure there is treated as `Batch`.
fn attach_query_metadata(&self, df: SqlDataFrame, query: &str) -> SqlDataFrame {
    let kind = match self.is_streaming_query(query) {
        Ok(true) => ExecutionKind::Streaming,
        Ok(false) | Err(_) => ExecutionKind::Batch,
    };
    df.with_query(query).with_execution_kind(kind)
}
/// Builder-style setter for the shuffle partition count (`None` clears it).
///
/// The value lives behind a shared lock, so it is visible to every clone of
/// this engine. A poisoned lock is recovered, exactly as the
/// `shuffle_partitions()` getter does, so the requested value is never
/// silently dropped.
#[must_use]
pub fn with_shuffle_partitions(self, n: Option<u32>) -> Self {
    // The stored value is a plain `Option<u32>`; a writer that panicked
    // cannot have left it half-updated, so recovering from poison is safe.
    *self
        .shuffle_partitions
        .write()
        .unwrap_or_else(|poisoned| poisoned.into_inner()) = n;
    self
}
/// Core constructor shared by the public constructors.
///
/// Builds a DataFusion session with the single-node config, the default
/// table factories plus the connector table factories, and optionally a
/// memory-limited runtime; then registers the optional catalog and window
/// helper UDFs.
///
/// # Errors
///
/// Returns `SqlError::DataFusion` when the memory-limited runtime cannot be
/// built or when window helper UDF registration fails.
fn build_local(
krishiv_catalog: Option<Arc<RwLock<InMemoryCatalog>>>,
window_fn_registration: WindowFnRegistration,
target_partitions: NonZeroUsize,
memory_limit_bytes: Option<usize>,
) -> SqlResult<Self> {
let streaming_sources: Arc<RwLock<std::collections::HashSet<String>>> =
Arc::new(RwLock::new(std::collections::HashSet::new()));
// A throwaway default state is built only to copy its table factories,
// which are then extended with the connector factories below.
let dummy_state = datafusion::execution::session_state::SessionStateBuilder::new()
.with_default_features()
.build();
let mut table_factories = dummy_state.table_factories().clone();
// The connector factories receive a handle to the same streaming-source set
// that the engine stores.
crate::connector_table::register_connector_table_factories(
&mut table_factories,
streaming_sources.clone(),
);
let mut state_builder = datafusion::execution::session_state::SessionStateBuilder::new()
.with_default_features()
.with_config(build_single_node_session_config(target_partitions))
.with_table_factories(table_factories);
// Only install a custom runtime when a limit was requested; otherwise the
// builder's default runtime is used.
if let Some(limit) = memory_limit_bytes {
let runtime_env = datafusion::execution::runtime_env::RuntimeEnvBuilder::new()
.with_memory_pool(Arc::new(
datafusion::execution::memory_pool::FairSpillPool::new(limit),
))
.build_arc()
.map_err(|e| SqlError::DataFusion {
message: format!(
"failed to build memory-limited DataFusion runtime \
(limit {limit} bytes): {e}"
),
})?;
state_builder = state_builder.with_runtime_env(runtime_env);
}
let state = state_builder.build();
let context = SessionContext::new_with_state(state);
// Expose the in-memory catalog to SQL under the catalog name "krishiv".
if let Some(catalog) = &krishiv_catalog {
context.register_catalog(
"krishiv",
Arc::new(DataFusionCatalogBridge::new(catalog.clone())),
);
}
if matches!(window_fn_registration, WindowFnRegistration::Register) {
window_functions::register_window_functions(&context).map_err(|e| {
SqlError::DataFusion {
message: format!("failed to register window helper UDFs: {e}"),
}
})?;
}
Ok(Self {
context,
target_parallelism: target_partitions,
krishiv_catalog,
udf_registry: None,
streaming_sources,
streaming_registration: Arc::new(Mutex::new(())),
has_streaming_sources: Arc::new(AtomicBool::new(false)),
udf_limits: None,
udf_registry_version: Arc::new(AtomicU64::new(0)),
// Starts different from `udf_registry_version` (0), so the version
// comparison in `sql` triggers a UDF sync on first use.
udf_last_synced_version: Arc::new(AtomicU64::new(u64::MAX)),
plan_cache: Arc::new(Mutex::new(PlanCache::new(resolve_plan_cache_max_entries()))),
shuffle_partitions: Arc::new(std::sync::RwLock::new(None)),
table_row_counts: Arc::new(std::sync::RwLock::new(HashMap::new())),
memory_limit_bytes,
#[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
iceberg_catalogs: Arc::new(std::sync::RwLock::new(Vec::new())),
live_table_registry: Arc::new(live_table::LiveTableRegistry::new()),
incremental_view_registry: Arc::new(incremental_view::IncrementalViewRegistry::new()),
pipeline_registry: Arc::new(pipeline_ddl::PipelineRegistry::new()),
operation_registry: Arc::new(OperationRegistry::new()),
})
}
/// Infallible last-resort constructor.
///
/// Builds the same session as `build_local` but with none of its fallible
/// steps: no memory-limited runtime, no catalog, and no window helper UDFs.
fn build_absolute_minimal(target_partitions: NonZeroUsize) -> Self {
let streaming_sources: Arc<RwLock<std::collections::HashSet<String>>> =
Arc::new(RwLock::new(std::collections::HashSet::new()));
// A throwaway default state is built only to copy its table factories.
let dummy_state = datafusion::execution::session_state::SessionStateBuilder::new()
.with_default_features()
.build();
let mut table_factories = dummy_state.table_factories().clone();
crate::connector_table::register_connector_table_factories(
&mut table_factories,
streaming_sources.clone(),
);
let state = datafusion::execution::session_state::SessionStateBuilder::new()
.with_default_features()
.with_config(build_single_node_session_config(target_partitions))
.with_table_factories(table_factories)
.build();
let context = SessionContext::new_with_state(state);
Self {
context,
target_parallelism: target_partitions,
krishiv_catalog: None,
udf_registry: None,
streaming_sources,
streaming_registration: Arc::new(Mutex::new(())),
has_streaming_sources: Arc::new(AtomicBool::new(false)),
udf_limits: None,
udf_registry_version: Arc::new(AtomicU64::new(0)),
// Starts different from `udf_registry_version` (0), mirroring `build_local`.
udf_last_synced_version: Arc::new(AtomicU64::new(u64::MAX)),
plan_cache: Arc::new(Mutex::new(PlanCache::new(resolve_plan_cache_max_entries()))),
shuffle_partitions: Arc::new(std::sync::RwLock::new(None)),
table_row_counts: Arc::new(std::sync::RwLock::new(HashMap::new())),
// This path never installs a memory pool, so no limit is reported.
memory_limit_bytes: None,
#[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
iceberg_catalogs: Arc::new(std::sync::RwLock::new(Vec::new())),
live_table_registry: Arc::new(live_table::LiveTableRegistry::new()),
incremental_view_registry: Arc::new(incremental_view::IncrementalViewRegistry::new()),
pipeline_registry: Arc::new(pipeline_ddl::PipelineRegistry::new()),
operation_registry: Arc::new(OperationRegistry::new()),
}
}
/// Registers a new continuous (streaming) table and returns its input handle.
///
/// Registration is serialised through the streaming-registration lock. On
/// success the name is added to the streaming-source set and the plan cache
/// is cleared.
///
/// # Errors
///
/// Fails when the name is blank, the schema has no fields, a table with that
/// name already exists, or table creation / registration fails.
pub fn register_streaming_table(
    &self,
    name: &str,
    schema: arrow::datatypes::SchemaRef,
) -> SqlResult<Arc<ContinuousTableInput>> {
    let _serialized = self.lock_streaming_registration()?;
    self.validate_new_streaming_table(name, &schema)?;

    let (provider, input) = match crate::streaming::create_continuous_table(schema) {
        Ok(pair) => pair,
        Err(e) => {
            return Err(SqlError::DataFusion {
                message: e.to_string(),
            });
        }
    };
    self.register_new_streaming_provider(name, provider)?;

    {
        let mut sources = match self.streaming_sources.write() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        sources.insert(name.to_string());
    }
    self.has_streaming_sources.store(true, Ordering::Release);
    self.invalidate_plan_cache();
    Ok(input)
}
/// Same as `register_streaming_table`, but the continuous table is created
/// through `create_continuous_table_with_capacity` with the given `capacity`.
///
/// # Errors
///
/// Fails when the name is blank, the schema has no fields, a table with that
/// name already exists, or table creation / registration fails.
pub fn register_streaming_table_with_capacity(
    &self,
    name: &str,
    schema: arrow::datatypes::SchemaRef,
    capacity: usize,
) -> SqlResult<Arc<ContinuousTableInput>> {
    let _serialized = self.lock_streaming_registration()?;
    self.validate_new_streaming_table(name, &schema)?;

    let created = crate::streaming::create_continuous_table_with_capacity(schema, capacity);
    let (provider, input) = match created {
        Ok(pair) => pair,
        Err(e) => {
            return Err(SqlError::DataFusion {
                message: e.to_string(),
            });
        }
    };
    self.register_new_streaming_provider(name, provider)?;

    {
        let mut sources = match self.streaming_sources.write() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        sources.insert(name.to_string());
    }
    self.has_streaming_sources.store(true, Ordering::Release);
    self.invalidate_plan_cache();
    Ok(input)
}
fn lock_streaming_registration(&self) -> SqlResult<std::sync::MutexGuard<'_, ()>> {
self.streaming_registration
.lock()
.map_err(|error| SqlError::DataFusion {
message: format!("streaming table registration lock poisoned: {error}"),
})
}
/// Checks the preconditions for registering a new streaming table.
///
/// # Errors
///
/// * `SqlError::EmptyTableName` when `name` is blank;
/// * `SqlError::DataFusion` when the schema has no fields, when the existence
///   check fails, or when a table named `name` is already registered.
fn validate_new_streaming_table(
    &self,
    name: &str,
    schema: &arrow::datatypes::SchemaRef,
) -> SqlResult<()> {
    if name.trim().is_empty() {
        return Err(SqlError::EmptyTableName);
    }
    if schema.fields().is_empty() {
        return Err(SqlError::DataFusion {
            message: "streaming table schema must contain at least one field".into(),
        });
    }

    let already_registered = match self.context.table_exist(name) {
        Ok(exists) => exists,
        Err(error) => {
            return Err(SqlError::DataFusion {
                message: error.to_string(),
            });
        }
    };
    if already_registered {
        return Err(SqlError::DataFusion {
            message: format!("table '{name}' is already registered"),
        });
    }
    Ok(())
}
/// Registers `table` under `name`, refusing to replace an existing table.
///
/// If the registration turns out to have displaced another provider (one that
/// appeared after validation), the displaced provider is put back and an
/// error is returned.
///
/// # Errors
///
/// Returns `SqlError::DataFusion` when registration fails, when a provider
/// was displaced, or when the displaced provider could not be restored.
fn register_new_streaming_provider(
    &self,
    name: &str,
    table: Arc<dyn datafusion::catalog::TableProvider>,
) -> SqlResult<()> {
    let displaced = match self.context.register_table(name, table) {
        Ok(displaced) => displaced,
        Err(error) => {
            return Err(SqlError::DataFusion {
                message: error.to_string(),
            });
        }
    };

    // Nothing was registered under this name before: the common, happy path.
    let Some(original) = displaced else {
        return Ok(());
    };

    // Someone else owned the name; restore their provider before reporting.
    if let Err(error) = self.context.register_table(name, original) {
        return Err(SqlError::DataFusion {
            message: format!(
                "table '{name}' was concurrently registered and could not be restored: \
                 {error}"
            ),
        });
    }
    Err(SqlError::DataFusion {
        message: format!("table '{name}' was concurrently registered"),
    })
}
/// Registers a Kafka-backed streaming table under `table_name`.
///
/// An existing table with the same name is deregistered first. The Kafka
/// auto-commit interval is `None` when the resolved durability profile
/// requires manual commits and `Some(1_000)` otherwise; all security-related
/// settings are left unset. On success the name is added to the
/// streaming-source set and the plan cache is cleared.
///
/// # Errors
///
/// `SqlError::EmptyTableName` for a blank name; `SqlError::DataFusion` when
/// table creation, deregistration or registration fails.
pub fn register_kafka_source(
    &self,
    table_name: impl AsRef<str>,
    schema: arrow::datatypes::SchemaRef,
    bootstrap_servers: impl Into<String>,
    topic: impl Into<String>,
    group_id: impl Into<String>,
) -> SqlResult<()> {
    let table_name = table_name.as_ref();
    if table_name.trim().is_empty() {
        return Err(SqlError::EmptyTableName);
    }

    let bootstrap_servers = bootstrap_servers.into();
    let topic = topic.into();
    let group_id = group_id.into();
    let manual_commit = krishiv_common::requires_manual_kafka_commit(
        krishiv_common::resolve_durability_profile(),
    );
    let auto_commit_interval_ms = if manual_commit { None } else { Some(1_000) };

    let config = krishiv_connectors::kafka::KafkaConfig {
        bootstrap_servers,
        topic,
        group_id,
        auto_commit_interval_ms,
        security_protocol: None,
        ssl_ca_location: None,
        ssl_certificate_location: None,
        ssl_key_location: None,
        ssl_key_password: None,
        sasl_username: None,
        sasl_password: None,
        sasl_mechanisms: None,
        enable_idempotence: None,
        transactional_id: None,
    };

    let table = match crate::kafka_table::create_kafka_streaming_table(schema, config) {
        Ok(table) => table,
        Err(e) => {
            return Err(SqlError::DataFusion {
                message: e.to_string(),
            });
        }
    };

    // Replace semantics: drop whatever currently holds the name.
    let name_taken = self
        .context
        .table_exist(table_name)
        .map_err(SqlError::from)?;
    if name_taken {
        let _ = self
            .context
            .deregister_table(table_name)
            .map_err(SqlError::from)?;
    }

    if let Err(e) = self.context.register_table(table_name, table) {
        return Err(SqlError::DataFusion {
            message: e.to_string(),
        });
    }

    {
        let mut sources = match self.streaming_sources.write() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        sources.insert(table_name.to_string());
    }
    self.has_streaming_sources.store(true, Ordering::Release);
    self.invalidate_plan_cache();
    Ok(())
}
/// Runs `sql` and writes every non-empty result batch to a Kafka topic.
///
/// The sink is created with group id `"krishiv-sql-writer"`, no auto-commit
/// interval and no security settings, and is flushed once the result stream
/// ends. Returns the total number of rows handed to the sink.
///
/// # Errors
///
/// Returns an error when the sink cannot be created, when the query fails to
/// plan or execute, or when a write or the final flush fails.
pub async fn sql_to_kafka(
    &self,
    sql: impl AsRef<str>,
    bootstrap_servers: impl Into<String>,
    topic: impl Into<String>,
) -> SqlResult<u64> {
    use futures::StreamExt;
    use krishiv_connectors::Sink as _;
    use krishiv_connectors::kafka::{KafkaConfig, KafkaSink};

    let config = KafkaConfig {
        bootstrap_servers: bootstrap_servers.into(),
        topic: topic.into(),
        group_id: "krishiv-sql-writer".into(),
        auto_commit_interval_ms: None,
        security_protocol: None,
        ssl_ca_location: None,
        ssl_certificate_location: None,
        ssl_key_location: None,
        ssl_key_password: None,
        sasl_username: None,
        sasl_password: None,
        sasl_mechanisms: None,
        enable_idempotence: None,
        transactional_id: None,
    };
    let mut sink = match KafkaSink::new(config) {
        Ok(sink) => sink,
        Err(e) => {
            return Err(SqlError::DataFusion {
                message: e.to_string(),
            });
        }
    };

    let frame = self.sql(sql.as_ref()).await?;
    let mut batches = frame.execute_stream().await?;

    let mut rows_written = 0u64;
    while let Some(item) = batches.next().await {
        let batch = item.map_err(|e| SqlError::DataFusion {
            message: e.to_string(),
        })?;
        let rows = batch.num_rows();
        if rows == 0 {
            // Empty batches carry nothing to publish.
            continue;
        }
        rows_written += rows as u64;
        if let Err(e) = sink.write_batch(batch).await {
            return Err(SqlError::DataFusion {
                message: e.to_string(),
            });
        }
    }

    if let Err(e) = sink.flush().await {
        return Err(SqlError::DataFusion {
            message: e.to_string(),
        });
    }
    Ok(rows_written)
}
/// Builder-style setter for the resource limits that `sync_scalar_udfs` passes
/// along when it syncs scalar UDFs.
pub fn with_udf_limits(mut self, limits: krishiv_plan::udf::ResourceLimits) -> Self {
self.udf_limits = Some(limits);
self
}
/// Returns `true` when `table_name` is in the streaming-source set.
///
/// The comparison is an exact string match; a poisoned lock is recovered.
pub fn is_streaming_source(&self, table_name: &str) -> bool {
    let sources = match self.streaming_sources.read() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    };
    sources.contains(table_name)
}
/// Marks `table_name` as a streaming source without registering any table.
///
/// Also raises the streaming-sources flag and clears the plan cache.
///
/// # Errors
///
/// Returns `SqlError::EmptyTableName` when the name is blank.
pub fn register_streaming_source_name(&self, table_name: impl Into<String>) -> SqlResult<()> {
    let source_name: String = table_name.into();
    if source_name.trim().is_empty() {
        return Err(SqlError::EmptyTableName);
    }

    {
        let mut sources = match self.streaming_sources.write() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        sources.insert(source_name);
    }
    self.has_streaming_sources.store(true, Ordering::Release);
    self.invalidate_plan_cache();
    Ok(())
}
/// Deregisters the table `name` and removes it from the streaming-source set.
///
/// The streaming-sources flag is lowered when the set becomes empty, and the
/// plan cache is cleared while the set's write lock is still held.
///
/// # Errors
///
/// `SqlError::EmptyTableName` for a blank name; a converted DataFusion error
/// when deregistration fails.
pub fn deregister_streaming_source(&self, name: &str) -> SqlResult<()> {
    if name.trim().is_empty() {
        return Err(SqlError::EmptyTableName);
    }
    let _ = self
        .context
        .deregister_table(name)
        .map_err(SqlError::from)?;

    let mut sources = match self.streaming_sources.write() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    };
    sources.remove(name);
    let none_left = sources.is_empty();
    if none_left {
        self.has_streaming_sources.store(false, Ordering::Release);
    }
    self.invalidate_plan_cache();
    drop(sources);
    Ok(())
}
/// Returns the engine's shared live-table registry.
pub fn live_table_registry(&self) -> &Arc<live_table::LiveTableRegistry> {
&self.live_table_registry
}
/// Returns the engine's shared incremental-view registry.
pub fn incremental_view_registry(&self) -> &Arc<incremental_view::IncrementalViewRegistry> {
&self.incremental_view_registry
}
/// Returns the engine's shared pipeline registry.
pub fn pipeline_registry(&self) -> &Arc<pipeline_ddl::PipelineRegistry> {
&self.pipeline_registry
}
/// Returns the engine's shared operation registry.
pub fn operation_registry(&self) -> &Arc<OperationRegistry> {
&self.operation_registry
}
pub fn deregister_table(&self, name: &str) -> SqlResult<()> {
if name.trim().is_empty() {
return Err(SqlError::EmptyTableName);
}
let _ = self
.context
.deregister_table(name)
.map_err(SqlError::from)?;
self.invalidate_plan_cache();
Ok(())
}
/// Registers a closure-backed table UDF.
///
/// The UDF is added to the shared UDF registry when one is attached, and is
/// always registered on the session context.
///
/// # Errors
///
/// * `SqlError::InvalidTableFunction` when `ClosureTableUdf::try_new` rejects
///   the definition;
/// * `SqlError::DataFusion` when the registry lock is poisoned or the context
///   registration fails.
pub fn register_table_udf_fn(
    &self,
    name: impl Into<String>,
    schema: arrow::datatypes::Schema,
    f: impl Fn(
        &[krishiv_plan::udf::ScalarValue],
    ) -> Result<arrow::record_batch::RecordBatch, krishiv_plan::udf::UdfError>
    + Send
    + Sync
    + 'static,
) -> SqlResult<()> {
    let built =
        create_function_ddl::ClosureTableUdf::try_new(name, schema, std::sync::Arc::new(f));
    let udf = match built {
        Ok(udf) => udf,
        Err(error) => {
            return Err(SqlError::InvalidTableFunction {
                message: error.to_string(),
            });
        }
    };

    if let Some(registry) = self.udf_registry.as_ref() {
        match registry.write() {
            Ok(mut guard) => {
                guard.register_table(std::sync::Arc::new(udf.clone()));
            }
            Err(e) => {
                return Err(SqlError::DataFusion {
                    message: e.to_string(),
                });
            }
        }
    }

    udf::register_single_table_udf(&self.context, std::sync::Arc::new(udf))
        .map_err(SqlError::from)
}
/// Reports whether `sql` references at least one registered streaming source.
///
/// Returns `Ok(false)` without parsing when no streaming source is
/// registered. Otherwise the SQL is parsed with the generic dialect and each
/// referenced relation is compared, by the text after its last `.`, against
/// the streaming-source set.
///
/// # Errors
///
/// Returns `SqlError::DataFusion` when the SQL cannot be parsed.
pub fn is_streaming_query(&self, sql: &str) -> SqlResult<bool> {
    // Cheap atomic check first so batch-only engines never take the lock.
    if !self.has_streaming_sources.load(Ordering::Acquire) {
        return Ok(false);
    }
    let sources = match self.streaming_sources.read() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    };
    if sources.is_empty() {
        return Ok(false);
    }

    let statements = match Parser::parse_sql(&GenericDialect {}, sql) {
        Ok(statements) => statements,
        Err(e) => {
            return Err(SqlError::DataFusion {
                message: e.to_string(),
            });
        }
    };

    for stmt in &statements {
        let outcome = visit_relations(stmt, |relation| {
            let qualified = relation.to_string();
            // Only the trailing component (the bare table name) is compared.
            let bare = qualified.rsplit('.').next().unwrap_or(&qualified);
            if sources.contains(bare) {
                ControlFlow::Break(())
            } else {
                ControlFlow::Continue(())
            }
        });
        if outcome.is_break() {
            return Ok(true);
        }
    }
    Ok(false)
}
/// Returns the in-memory catalog the engine was built with, if any.
pub fn krishiv_catalog(&self) -> Option<&Arc<RwLock<InMemoryCatalog>>> {
self.krishiv_catalog.as_ref()
}
/// Builder-style registration of an Iceberg catalog under `catalog_name`.
///
/// A bridge over `catalog` is registered on the session context, and the
/// catalog/name pair is recorded in the engine's Iceberg catalog list
/// (recovering the lock if it is poisoned).
#[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
#[must_use]
pub fn with_iceberg_catalog(
    self,
    catalog: std::sync::Arc<catalog::unified::KrishivCatalog>,
    catalog_name: impl Into<String>,
) -> Self {
    let registered_as: String = catalog_name.into();
    let bridge = catalog::iceberg_catalog_bridge::IcebergCatalogBridge::new(
        Arc::clone(&catalog),
        registered_as.clone(),
    );
    self.context
        .register_catalog(registered_as.clone(), Arc::new(bridge));

    {
        let mut known = match self.iceberg_catalogs.write() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        known.push((catalog, registered_as));
    }
    self
}
/// Builder-style setter that attaches a shared UDF registry.
///
/// The UDF version counter is bumped so that it differs from the last-synced
/// version that `sql` compares it with.
#[must_use]
pub fn with_udf_registry(
mut self,
registry: std::sync::Arc<std::sync::RwLock<krishiv_plan::udf::UdfRegistry>>,
) -> Self {
self.udf_registry = Some(registry);
self.bump_udf_version();
self
}
/// Increments the UDF registry version counter by one.
pub(crate) fn bump_udf_version(&self) {
self.udf_registry_version.fetch_add(1, Ordering::Release);
}
/// Empties the plan cache, recovering the mutex if it is poisoned.
fn invalidate_plan_cache(&self) {
    let mut cache = self
        .plan_cache
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    cache.clear();
}
/// Public entry point for emptying the plan cache.
pub fn clear_plan_cache(&self) {
self.invalidate_plan_cache();
}
/// Syncs scalar UDFs from the attached registry into the session context.
///
/// Uses the limits set via `with_udf_limits`, or the type's default limits.
/// Does nothing when no registry is attached.
///
/// # Errors
///
/// Returns `SqlError::DataFusion` when the registry lock is poisoned or the
/// sync fails.
pub async fn sync_scalar_udfs(&self) -> SqlResult<()> {
    let registry = match self.udf_registry.as_ref() {
        Some(registry) => registry,
        None => return Ok(()),
    };
    let guard = match registry.read() {
        Ok(guard) => guard,
        Err(e) => {
            return Err(SqlError::DataFusion {
                message: e.to_string(),
            });
        }
    };
    let limits = self.udf_limits.clone().unwrap_or_default();
    match udf::sync_scalar_udfs_with_limits(&self.context, &guard, limits) {
        Ok(done) => Ok(done),
        Err(e) => Err(SqlError::DataFusion {
            message: e.to_string(),
        }),
    }
}
/// Syncs scalar UDFs with explicit `limits`, using the durability profile
/// resolved by `krishiv_common::resolve_durability_profile`.
pub async fn sync_scalar_udfs_with_limits(
&self,
limits: krishiv_plan::udf::ResourceLimits,
) -> SqlResult<()> {
self.sync_scalar_udfs_with_limits_for_profile(
limits,
krishiv_common::resolve_durability_profile(),
)
.await
}
/// Syncs scalar UDFs with explicit `limits`, using the native scalar UDF
/// policy that `NativeScalarUdfPolicy::resolve` derives from `profile`.
pub async fn sync_scalar_udfs_with_limits_for_profile(
&self,
limits: krishiv_plan::udf::ResourceLimits,
profile: krishiv_common::DurabilityProfile,
) -> SqlResult<()> {
self.sync_scalar_udfs_with_limits_for_policy(
limits,
krishiv_common::NativeScalarUdfPolicy::resolve(profile),
)
.await
}
/// Syncs scalar UDFs from the attached registry with explicit `limits` and
/// `policy`. Does nothing when no registry is attached.
///
/// # Errors
///
/// Returns `SqlError::DataFusion` when the registry lock is poisoned or the
/// sync fails.
pub async fn sync_scalar_udfs_with_limits_for_policy(
    &self,
    limits: krishiv_plan::udf::ResourceLimits,
    policy: krishiv_common::NativeScalarUdfPolicy,
) -> SqlResult<()> {
    let registry = match self.udf_registry.as_ref() {
        Some(registry) => registry,
        None => return Ok(()),
    };
    let guard = match registry.read() {
        Ok(guard) => guard,
        Err(e) => {
            return Err(SqlError::DataFusion {
                message: e.to_string(),
            });
        }
    };
    match udf::sync_scalar_udfs_with_limits_for_policy(&self.context, &guard, limits, policy) {
        Ok(done) => Ok(done),
        Err(e) => Err(SqlError::DataFusion {
            message: e.to_string(),
        }),
    }
}
/// Syncs aggregate UDFs from the attached registry into the session context.
/// Does nothing when no registry is attached.
///
/// # Errors
///
/// Returns `SqlError::DataFusion` when the registry lock is poisoned or the
/// sync fails.
pub async fn sync_aggregate_udfs(&self) -> SqlResult<()> {
    let registry = match self.udf_registry.as_ref() {
        Some(registry) => registry,
        None => return Ok(()),
    };
    let guard = match registry.read() {
        Ok(guard) => guard,
        Err(e) => {
            return Err(SqlError::DataFusion {
                message: e.to_string(),
            });
        }
    };
    match udf::sync_aggregate_udfs(&self.context, &guard) {
        Ok(done) => Ok(done),
        Err(e) => Err(SqlError::DataFusion {
            message: e.to_string(),
        }),
    }
}
/// Syncs table UDFs from the attached registry into the session context.
/// Does nothing when no registry is attached.
///
/// # Errors
///
/// Returns `SqlError::DataFusion` when the registry lock is poisoned or the
/// sync fails.
pub async fn sync_table_udfs(&self) -> SqlResult<()> {
    let registry = match self.udf_registry.as_ref() {
        Some(registry) => registry,
        None => return Ok(()),
    };
    let guard = match registry.read() {
        Ok(guard) => guard,
        Err(e) => {
            return Err(SqlError::DataFusion {
                message: e.to_string(),
            });
        }
    };
    match udf::sync_table_udfs(&self.context, &guard) {
        Ok(done) => Ok(done),
        Err(e) => Err(SqlError::DataFusion {
            message: e.to_string(),
        }),
    }
}
/// Syncs scalar, then aggregate, then table UDFs, stopping at the first failure.
pub async fn sync_all_udfs(&self) -> SqlResult<()> {
    self.sync_scalar_udfs().await?;
    self.sync_aggregate_udfs().await?;
    self.sync_table_udfs().await
}
/// Registers a Parquet file (or directory) as the table `table_name`.
///
/// For `s3://` paths an S3 object store for the URL's bucket is built from
/// the environment and registered first. An existing table with the same
/// name is replaced. After registration the table's row count is recorded
/// when the provider exposes one; when it does not, any count left over from
/// a previous registration of the same name is removed so that stale numbers
/// are never reported for the new table. The plan cache is cleared.
///
/// # Errors
///
/// `SqlError::EmptyTableName` for a blank name; `SqlError::DataFusion` for an
/// invalid S3 URL, S3 store construction failure, or any DataFusion failure.
pub async fn register_parquet(
    &self,
    table_name: impl AsRef<str>,
    path: impl AsRef<Path>,
) -> SqlResult<()> {
    let table_name = table_name.as_ref();
    if table_name.trim().is_empty() {
        return Err(SqlError::EmptyTableName);
    }
    let path = path.as_ref().to_string_lossy().into_owned();

    if path.starts_with("s3://") {
        let url = url::Url::parse(&path).map_err(|e| SqlError::DataFusion {
            message: format!("invalid s3 url {path}: {e}"),
        })?;
        let bucket = url.host_str().unwrap_or_default();
        let store_url =
            url::Url::parse(&format!("s3://{bucket}")).map_err(|e| SqlError::DataFusion {
                message: format!("invalid s3 bucket url: {e}"),
            })?;
        let store = AmazonS3Builder::from_env()
            .with_bucket_name(bucket)
            .build()
            .map_err(|e| SqlError::DataFusion {
                message: format!("s3 store init: {e}"),
            })?;
        self.context
            .register_object_store(&store_url, Arc::new(store));
    }

    // Replace semantics: drop whatever currently holds the name.
    if self
        .context
        .table_exist(table_name)
        .map_err(SqlError::from)?
    {
        let _ = self
            .context
            .deregister_table(table_name)
            .map_err(SqlError::from)?;
    }
    self.context
        .register_parquet(table_name, path, ParquetReadOptions::default())
        .await?;

    // Best effort: a missing provider or missing statistics means "unknown".
    let row_count: Option<u64> = match self.context.table_provider(table_name).await {
        Ok(provider) => provider
            .statistics()
            .and_then(|stats| stats.num_rows.get_value().map(|rows| *rows as u64)),
        Err(_) => None,
    };
    {
        // Scoped so the std lock guard is released before returning; the map
        // holds plain integers, so recovering from poison is safe.
        let mut counts = self
            .table_row_counts
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        match row_count {
            Some(rows) => {
                counts.insert(table_name.to_string(), rows);
            }
            None => {
                // The name now refers to a different table; drop the old count.
                counts.remove(table_name);
            }
        }
    }

    self.invalidate_plan_cache();
    Ok(())
}
/// Reads a Parquet path into a dataframe using default read options.
///
/// The path is converted lossily to a string before it is handed to DataFusion.
pub async fn read_parquet(&self, path: impl AsRef<Path>) -> SqlResult<SqlDataFrame> {
    let location = path.as_ref().to_string_lossy().into_owned();
    let read_options = ParquetReadOptions::default();
    let dataframe = self.context.read_parquet(location, read_options).await?;
    Ok(self.make_sql_df("parquet-read", dataframe))
}
pub async fn register_record_batches(
&self,
table_name: impl AsRef<str>,
batches: Vec<RecordBatch>,
) -> SqlResult<()> {
use std::sync::Arc;
let table_name = table_name.as_ref();
if table_name.trim().is_empty() {
return Err(SqlError::EmptyTableName);
}
if batches.is_empty() {
return Ok(());
}
let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
let schema = batches
.first()
.ok_or_else(|| SqlError::DataFusion {
message: "empty batch list".into(),
})?
.schema();
let mem_table =
datafusion::datasource::MemTable::try_new(schema, vec![batches]).map_err(|e| {
SqlError::DataFusion {
message: e.to_string(),
}
})?;
if self
.context
.table_exist(table_name)
.map_err(SqlError::from)?
{
let _ = self
.context
.deregister_table(table_name)
.map_err(SqlError::from)?;
}
self.context
.register_table(table_name, Arc::new(mem_table))
.map_err(|e| SqlError::DataFusion {
message: e.to_string(),
})?;
if total_rows > 0
&& let Ok(mut counts) = self.table_row_counts.write()
{
counts.insert(table_name.to_string(), total_rows as u64);
}
self.invalidate_plan_cache();
Ok(())
}
pub async fn read_parquet_with_options(
&self,
path: impl AsRef<Path>,
opts: &ParquetReaderOptions,
) -> SqlResult<SqlDataFrame> {
let path = path.as_ref().to_string_lossy().into_owned();
let mut options = datafusion::prelude::ParquetReadOptions::default();
if opts.batch_size.is_some() {
options = options.parquet_pruning(true);
}
let dataframe = self.context.read_parquet(path, options).await?;
Ok(self.make_sql_df("parquet-read", dataframe))
}
/// Reads a CSV path into a dataframe using default `CsvReaderOptions`.
pub async fn read_csv(&self, path: impl AsRef<Path>) -> SqlResult<SqlDataFrame> {
self.read_csv_with_options(path, &CsvReaderOptions::default())
.await
}
pub async fn read_csv_with_options(
&self,
path: impl AsRef<Path>,
opts: &CsvReaderOptions,
) -> SqlResult<SqlDataFrame> {
let path = path.as_ref().to_string_lossy().into_owned();
let mut options = datafusion::prelude::CsvReadOptions::new();
if let Some(delim) = opts.delimiter {
options = options.delimiter(delim as u8);
}
if let Some(has_header) = opts.has_header {
options = options.has_header(has_header);
}
let dataframe = self.context.read_csv(path, options).await?;
Ok(self.make_sql_df("csv-read", dataframe))
}
/// Reads a JSON path into a dataframe using default read options.
///
/// The path is converted lossily to a string before it is handed to DataFusion.
pub async fn read_json(&self, path: impl AsRef<Path>) -> SqlResult<SqlDataFrame> {
    let location = path.as_ref().to_string_lossy().into_owned();
    let read_options = datafusion::prelude::JsonReadOptions::default();
    let dataframe = self.context.read_json(location, read_options).await?;
    Ok(self.make_sql_df("json-read", dataframe))
}
/// Registers a Delta table at `path` (optionally at `version`) and returns a
/// dataframe selecting all of its rows.
///
/// The table is registered under a name derived from the path:
/// `delta_{path}` or `delta_{path}_v{version}`, with every character that is
/// not alphanumeric or `_` replaced by `_`. Replacing *all* such characters
/// (not only `/`, `.` and `-`) keeps the name a valid unquoted SQL identifier
/// for URIs such as `s3://bucket/table`, whose `:` would otherwise break the
/// generated `SELECT`. Paths made up only of alphanumerics, `_`, `/`, `.`
/// and `-` produce the same name as before.
///
/// # Errors
///
/// Propagates failures from registering the Delta URI and from running the
/// generated query.
pub async fn read_delta(
    &self,
    path: impl AsRef<str>,
    version: Option<i64>,
) -> SqlResult<SqlDataFrame> {
    let path = path.as_ref();
    let raw_name = match version {
        Some(v) => format!("delta_{path}_v{v}"),
        None => format!("delta_{path}"),
    };
    // The name is interpolated into SQL below, so it must be a bare identifier.
    let table: String = raw_name
        .chars()
        .map(|c| if c.is_alphanumeric() || c == '_' { c } else { '_' })
        .collect();
    lakehouse::register_delta_uri(&self.context, &table, path, version).await?;
    self.sql(format!("SELECT * FROM {table}")).await
}
/// Registers the Hudi table at `path` in the session context and returns a
/// frame selecting every row from it.
///
/// `query_type` and `begin_instant` are forwarded unchanged to
/// `lakehouse::register_hudi_uri`. The table is registered as `hudi_<path>`,
/// where every character of the path that is not alphabetic, an ASCII digit,
/// or `_` is replaced by `_`. Sanitising all such characters (rather than only
/// `/`, `.` and `-`) keeps the follow-up `SELECT * FROM <name>` parseable for
/// URIs that carry a scheme (`s3://bucket/tbl` contains `:`) or spaces. Paths
/// made only of letters, digits, `_`, `/`, `.` and `-` map to the same name as
/// before.
///
/// # Errors
/// Propagates registration failures and any error from running the generated
/// `SELECT`.
pub async fn read_hudi(
    &self,
    path: impl AsRef<str>,
    query_type: krishiv_connectors::lakehouse::HudiQueryType,
    begin_instant: Option<&str>,
) -> SqlResult<SqlDataFrame> {
    let path = path.as_ref();
    // The name is interpolated unquoted into SQL below, so it must consist of
    // identifier characters only. The `hudi_` prefix guarantees a valid start.
    let sanitized: String = path
        .chars()
        .map(|c| {
            if c.is_alphabetic() || c.is_ascii_digit() || c == '_' {
                c
            } else {
                '_'
            }
        })
        .collect();
    let table = format!("hudi_{sanitized}");
    lakehouse::register_hudi_uri(&self.context, &table, path, query_type, begin_instant)
        .await?;
    self.sql(format!("SELECT * FROM {table}")).await
}
/// Executes a SQL statement and returns the resulting [`SqlDataFrame`].
///
/// The statement is tried against a fixed sequence of handlers; the first
/// one that recognises it produces the result and returns early:
///
/// 1. blank input is rejected with `SqlError::EmptyQuery`;
/// 2. UDFs are re-synced when the registry version changed;
/// 3. introspection statements (`Describe` / `Explain`);
/// 4. live-table, incremental-view and pipeline DDL;
/// 5. `SET SHUFFLE.PARTITIONS = <n>`;
/// 6. `CREATE FUNCTION ... RETURNS TABLE` (SQL-bodied table functions only);
/// 7. `MERGE INTO`;
/// 8. (feature-gated) `CALL SYSTEM.`, `DELETE FROM`, `UPDATE` on Iceberg tables;
/// 9. `MATCH_RECOGNIZE`;
/// 10. everything else: PIVOT/UNPIVOT and window-TVF rewrites, AS OF
///     preprocessing, then DataFusion, with a logical-plan cache keyed by the
///     rewritten SQL text.
///
/// Handlers that materialise their own batches register them under a fresh
/// ephemeral table name and return a `SELECT *` over that table.
///
/// # Errors
/// Returns the first error raised by any handler or by DataFusion.
pub async fn sql(&self, query: impl AsRef<str>) -> SqlResult<SqlDataFrame> {
let query = query.as_ref();
if query.trim().is_empty() {
return Err(SqlError::EmptyQuery);
}
// Re-sync UDFs only when the registry version moved since the last sync.
// The version observed *before* syncing is the one recorded afterwards.
{
let current = self.udf_registry_version.load(Ordering::Acquire);
let last = self.udf_last_synced_version.load(Ordering::Relaxed);
if current != last {
self.sync_all_udfs().await?;
self.udf_last_synced_version
.store(current, Ordering::Release);
}
}
// Introspection statements: build a result batch, expose it through an
// ephemeral table, and return a frame selecting from that table.
if let Some(stmt) = introspection_sql::parse_introspection_statement(query)? {
return match stmt {
introspection_sql::IntrospectionStatement::Describe { table } => {
let batch = introspection_sql::describe_table(&self.context, &table).await?;
let describe_table_name = next_ephemeral_name("describe_result");
lakehouse::register_scan_batches(
&self.context,
&describe_table_name,
vec![batch],
)
.await?;
let dataframe = self
.context
.sql(&format!("SELECT * FROM {describe_table_name}"))
.await?;
Ok(self.attach_query_metadata(self.make_sql_df("describe", dataframe), query))
}
introspection_sql::IntrospectionStatement::Explain { mode, query: inner } => {
let text = introspection_sql::explain_query(&inner, mode)?;
let batch = introspection_sql::explain_result_batch(&text)?;
let explain_table = next_ephemeral_name("explain_result");
lakehouse::register_scan_batches(&self.context, &explain_table, vec![batch])
.await?;
let dataframe = self
.context
.sql(&format!("SELECT * FROM {explain_table}"))
.await?;
Ok(self.attach_query_metadata(self.make_sql_df("explain", dataframe), query))
}
};
}
// Registry-backed DDL families. Each handler returns `Some` when it
// consumed the statement; the caller then gets an empty result frame
// produced by `SELECT 1 WHERE FALSE`.
if live_table::execute_live_table_ddl(&self.live_table_registry, query)?.is_some() {
let empty = self.context.sql("SELECT 1 WHERE FALSE").await?;
return Ok(self.attach_query_metadata(self.make_sql_df("live-table-ddl", empty), query));
}
if incremental_view::execute_incremental_view_ddl(&self.incremental_view_registry, query)?
.is_some()
{
let empty = self.context.sql("SELECT 1 WHERE FALSE").await?;
return Ok(
self.attach_query_metadata(self.make_sql_df("incremental-view-ddl", empty), query)
);
}
if pipeline_ddl::execute_pipeline_ddl(&self.pipeline_registry, query)?.is_some() {
let empty = self.context.sql("SELECT 1 WHERE FALSE").await?;
return Ok(self.attach_query_metadata(self.make_sql_df("pipeline-ddl", empty), query));
}
// `SET SHUFFLE.PARTITIONS = <n>`: a positive value stores an override,
// zero clears it, anything unparseable is an error.
let trimmed = query.trim();
if trimmed
.to_ascii_uppercase()
.starts_with("SET SHUFFLE.PARTITIONS")
{
// The value is whatever follows the first '=' (empty when there is none).
// NOTE(review): a trailing ';' is not stripped, so "= 8;" fails to parse.
let value = trimmed.split('=').nth(1).map(|s| s.trim()).unwrap_or("");
match value.parse::<u32>() {
Ok(n) if n > 0 => {
{
let mut guard =
self.shuffle_partitions
.write()
.map_err(|e| SqlError::DataFusion {
message: e.to_string(),
})?;
*guard = Some(n);
}
let empty = self.context.sql("SELECT 1 WHERE FALSE").await?;
return Ok(self.make_sql_df("set-shuffle-partitions", empty));
}
// Remaining `Ok` case is n == 0: reset the override.
Ok(_) => {
{
let mut guard =
self.shuffle_partitions
.write()
.map_err(|e| SqlError::DataFusion {
message: e.to_string(),
})?;
*guard = None;
}
let empty = self.context.sql("SELECT 1 WHERE FALSE").await?;
return Ok(self.make_sql_df("set-shuffle-partitions", empty));
}
Err(_) => {
return Err(SqlError::DataFusion {
message: format!(
"invalid shuffle.partitions value '{value}'; expected a positive integer"
),
});
}
}
}
// `CREATE FUNCTION ... RETURNS TABLE`: only functions whose parsed language
// is exactly "sql" and that have a non-blank body are accepted.
if create_function_ddl::is_create_function_returns_table(query) {
let ddl = create_function_ddl::parse_create_function(query)
.map_err(|message| SqlError::InvalidTableFunction { message })?;
if ddl.language.as_deref() != Some("sql") {
return Err(SqlError::Unsupported {
feature: format!(
"CREATE FUNCTION '{}' uses language {:?}; only LANGUAGE SQL AS '...' \
table functions are executable",
ddl.function_name, ddl.language
),
});
}
let body = ddl
.body
.as_deref()
.filter(|body| !body.trim().is_empty())
.ok_or_else(|| SqlError::InvalidTableFunction {
message: format!(
"SQL table function '{}' requires a non-empty AS body",
ddl.function_name
),
})?;
// Every declared return column becomes a nullable Arrow field.
let fields: Vec<_> = ddl
.return_columns
.iter()
.map(|column| {
arrow::datatypes::Field::new(&column.name, column.data_type.clone(), true)
})
.collect();
let schema = arrow::datatypes::Schema::new(fields);
let udf: std::sync::Arc<dyn krishiv_plan::udf::TableUdf> = std::sync::Arc::new(
create_function_ddl::SqlBodyTableUdf::try_new(
&ddl.function_name,
schema,
body,
ddl.arguments.len(),
std::sync::Arc::new(self.context.clone()),
)
.map_err(|error| SqlError::InvalidTableFunction {
message: error.to_string(),
})?,
);
// Record the function in the shared registry (when one is attached) and
// register it with this session's context as well.
if let Some(registry) = &self.udf_registry {
let mut guard = registry.write().map_err(|e| SqlError::DataFusion {
message: e.to_string(),
})?;
guard.register_table(std::sync::Arc::clone(&udf));
}
udf::register_single_table_udf(&self.context, std::sync::Arc::clone(&udf))
.map_err(SqlError::from)?;
let empty = self.context.sql("SELECT 1 WHERE FALSE").await?;
return Ok(
self.attach_query_metadata(self.make_sql_df("create-function", empty), query)
);
}
// `MERGE INTO`: executed by the lakehouse layer; its result batches are
// exposed through an ephemeral table.
if query
.trim_start()
.to_ascii_uppercase()
.starts_with("MERGE INTO")
{
let batches = lakehouse::execute_merge_sql(&self.context, query).await?;
let merge_table = next_ephemeral_name("merge_result");
lakehouse::register_scan_batches(&self.context, &merge_table, batches).await?;
let dataframe = self
.context
.sql(&format!("SELECT * FROM {merge_table}"))
.await?;
return Ok(self.attach_query_metadata(self.make_sql_df("merge", dataframe), query));
}
// Feature-gated: `CALL SYSTEM.<procedure>(...)` maintenance procedures.
#[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
if trimmed.to_ascii_uppercase().starts_with("CALL SYSTEM.") {
let result = self.dispatch_call_system(trimmed).await?;
let call_table = next_ephemeral_name("call_result");
lakehouse::register_scan_batches(&self.context, &call_table, vec![result]).await?;
let dataframe = self
.context
.sql(&format!("SELECT * FROM {call_table}"))
.await?;
return Ok(self.attach_query_metadata(self.make_sql_df("call", dataframe), query));
}
// Feature-gated: `DELETE FROM` against a table that resolves to an Iceberg
// catalog. When parsing or resolution yields `None`, control falls through
// to the later handlers.
#[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
if trimmed.to_ascii_uppercase().starts_with("DELETE FROM ") {
if let Some((table_ref, predicate)) = parse_dml_delete(trimmed) {
if let Some((iceberg_catalog, table_ident)) = self.resolve_iceberg_table(&table_ref)
{
use arrow::array::{ArrayRef, Int64Array};
use arrow::datatypes::{DataType, Field, Schema};
let (deleted, _) = krishiv_connectors::lakehouse::dml::iceberg_delete_where(
iceberg_catalog,
&table_ident,
&predicate,
&self.context,
)
.await
.map_err(|e| SqlError::DataFusion {
message: e.to_string(),
})?;
// Result is a single-row, single-column batch: `deleted_rows`.
let schema = Arc::new(Schema::new(vec![Field::new(
"deleted_rows",
DataType::Int64,
false,
)]));
let array: ArrayRef = Arc::new(Int64Array::from(vec![deleted as i64]));
let batch = RecordBatch::try_new(schema, vec![array]).map_err(|e| {
SqlError::DataFusion {
message: e.to_string(),
}
})?;
let res_table = next_ephemeral_name("delete_result");
lakehouse::register_scan_batches(&self.context, &res_table, vec![batch])
.await?;
let dataframe = self
.context
.sql(&format!("SELECT * FROM {res_table}"))
.await?;
return Ok(
self.attach_query_metadata(self.make_sql_df("delete", dataframe), query)
);
}
}
}
// Feature-gated: `UPDATE` against a table that resolves to an Iceberg
// catalog; same fall-through behaviour as DELETE above.
#[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
if trimmed.to_ascii_uppercase().starts_with("UPDATE ") {
if let Some(parsed) = parse_dml_update(trimmed) {
if let Some((iceberg_catalog, table_ident)) =
self.resolve_iceberg_table(&parsed.table_ref)
{
use arrow::array::{ArrayRef, Int64Array};
use arrow::datatypes::{DataType, Field, Schema};
// Borrow the owned assignment pairs as `(&str, &str)` for the callee.
let borrowed: Vec<(&str, &str)> = parsed
.assignments
.iter()
.map(|(c, e)| (c.as_str(), e.as_str()))
.collect();
let pred = parsed.predicate.as_deref();
let (updated, _) = krishiv_connectors::lakehouse::dml::iceberg_update_where(
iceberg_catalog,
&table_ident,
&borrowed,
pred,
&self.context,
)
.await
.map_err(|e| SqlError::DataFusion {
message: e.to_string(),
})?;
// Result is a single-row, single-column batch: `updated_rows`.
let schema = Arc::new(Schema::new(vec![Field::new(
"updated_rows",
DataType::Int64,
false,
)]));
let array: ArrayRef = Arc::new(Int64Array::from(vec![updated as i64]));
let batch = RecordBatch::try_new(schema, vec![array]).map_err(|e| {
SqlError::DataFusion {
message: e.to_string(),
}
})?;
let res_table = next_ephemeral_name("update_result");
lakehouse::register_scan_batches(&self.context, &res_table, vec![batch])
.await?;
let dataframe = self
.context
.sql(&format!("SELECT * FROM {res_table}"))
.await?;
return Ok(
self.attach_query_metadata(self.make_sql_df("update", dataframe), query)
);
}
}
}
// MATCH_RECOGNIZE: the source table is fully collected (bounded by a LIMIT
// when the source is streaming) and the pattern is evaluated over the
// collected batches.
// NOTE(review): the keyword is only detected when surrounded by single
// spaces; a newline or tab next to it skips this handler.
if query.to_ascii_uppercase().contains(" MATCH_RECOGNIZE ")
&& let Some(stmt) = cep_sql::parse_match_recognize(query)?
{
let is_streaming = self.is_streaming_source(&stmt.source_table);
let streaming_limit = streaming_match_recognize_limit_from_env();
let source_sql = if is_streaming {
format!(
"SELECT * FROM {} LIMIT {}",
stmt.source_table, streaming_limit
)
} else {
format!("SELECT * FROM {}", stmt.source_table)
};
let source_df = self.context.sql(&source_sql).await?;
let source_batches = source_df.collect().await?;
if is_streaming {
tracing::warn!(
source = %stmt.source_table,
limit = streaming_limit,
collected_rows = source_batches.iter().map(|b| b.num_rows()).sum::<usize>(),
"MATCH_RECOGNIZE executed against a streaming source under \
bounded materialisation; results only cover the first {0} rows \
of the source. Set KRISHIV_MATCH_RECOGNIZE_STREAMING_LIMIT to a \
larger value if your executor has the memory budget.",
streaming_limit
);
}
let results = cep_sql::execute_match_recognize(stmt, &source_batches)?;
let cep_table = next_ephemeral_name("cep_result");
lakehouse::register_scan_batches(&self.context, &cep_table, results).await?;
let dataframe = self
.context
.sql(&format!("SELECT * FROM {cep_table}"))
.await?;
return Ok(self.attach_query_metadata(self.make_sql_df("cep", dataframe), query));
}
// General path: apply the textual rewrites, then hand the SQL to DataFusion.
let query = &pivot_sql::rewrite_pivot_unpivot(query)?;
let query = &streaming_tvf::rewrite_window_tvfs(query);
// A failure in AS OF preprocessing is not reported: the query is used as-is
// with no AS OF references.
let (rewritten, as_ofs) =
lakehouse::preprocess_as_of_sql(query).unwrap_or_else(|_| (query.to_string(), vec![]));
lakehouse::apply_as_of_refs(&self.context, &as_ofs).await?;
// Plans are cached only for queries without AS OF references.
let can_cache = as_ofs.is_empty();
// Read the current shuffle override; a poisoned lock still yields its value.
let shuffle_override = self
.shuffle_partitions
.read()
.map(|g| *g)
.unwrap_or_else(|e| *e.into_inner());
if can_cache {
// The cache lock is released at the end of this statement, before any await.
let cached_plan: Option<datafusion::logical_expr::LogicalPlan> = self
.plan_cache
.lock()
.unwrap_or_else(|e| e.into_inner())
.get(&rewritten)
.cloned();
if let Some(plan) = cached_plan {
let dataframe = self.context.execute_logical_plan(plan).await?;
return Ok(self.attach_query_metadata(
self.make_sql_df("sql-query", dataframe)
.with_shuffle_partitions(shuffle_override),
&rewritten,
));
}
}
let dataframe = self.context.sql(&rewritten).await?;
// For `CREATE EXTERNAL TABLE`, record the provider's row-count statistic
// (when it reports one) unless a count is already stored for that table.
if let Some(table_name) = extract_create_external_table_name(&rewritten)
&& !table_name.is_empty()
&& let Ok(provider) = self.context.table_provider(&table_name).await
{
let maybe_rows = provider
.statistics()
.and_then(|s| s.num_rows.get_value().copied());
if let Some(n) = maybe_rows
&& let Ok(mut counts) = self.table_row_counts.write()
{
counts.entry(table_name).or_insert(n as u64);
}
}
// Remember the logical plan for this SQL text; a poisoned cache lock is
// recovered rather than treated as an error.
if can_cache {
let plan = dataframe.logical_plan().clone();
match self.plan_cache.lock() {
Ok(mut cache) => cache.insert(rewritten.clone(), plan),
Err(poisoned) => poisoned.into_inner().insert(rewritten.clone(), plan),
}
}
Ok(self.attach_query_metadata(
self.make_sql_df("sql-query", dataframe)
.with_shuffle_partitions(shuffle_override),
&rewritten,
))
}
/// Runs `query` through `sql`, giving up after `timeout_ms` milliseconds.
///
/// # Errors
/// Returns `SqlError::Timeout { timeout_ms }` when the deadline elapses first;
/// otherwise returns whatever `sql` returned.
pub async fn execute_with_timeout(
    &self,
    query: impl AsRef<str> + Send,
    timeout_ms: u64,
) -> SqlResult<SqlDataFrame> {
    let deadline = std::time::Duration::from_millis(timeout_ms);
    match tokio::time::timeout(deadline, self.sql(query)).await {
        Ok(outcome) => outcome,
        Err(_) => Err(SqlError::Timeout { timeout_ms }),
    }
}
/// Runs `query` and tags the resulting frame with `operation_id`.
///
/// The cancellation registry is consulted once, before the query starts.
///
/// # Errors
/// Returns `SqlError::OperationCancelled` when `operation_id` is already
/// marked cancelled; otherwise propagates any error from `sql`.
pub async fn execute_with_operation_id(
    &self,
    operation_id: u64,
    query: impl AsRef<str> + Send,
    cancelled_ids: &OperationRegistry,
) -> SqlResult<TaggedQueryResult> {
    if cancelled_ids.is_cancelled(operation_id) {
        return Err(SqlError::OperationCancelled { operation_id });
    }
    self.sql(query).await.map(|inner| TaggedQueryResult {
        operation_id,
        inner,
    })
}
/// Resolves a dotted table reference to an Iceberg catalog and table identifier.
///
/// Accepted shapes (the reference is split on at most two dots):
/// * `namespace.table` — uses the first registered catalog;
/// * `catalog.namespace.table` — uses the registered catalog with that name.
///
/// Returns `None` when no catalog is registered, the reference has any other
/// shape, the named catalog is unknown, or the namespace cannot be built.
#[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
fn resolve_iceberg_table(
    &self,
    table_ref: &str,
) -> Option<(Arc<dyn iceberg::Catalog + Send + Sync>, iceberg::TableIdent)> {
    let segments: Vec<&str> = table_ref.splitn(3, '.').collect();
    let (catalog, namespace, table) = {
        // A poisoned lock is recovered; the guard is dropped at the end of this block.
        let catalogs = match self.iceberg_catalogs.read() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        if catalogs.is_empty() {
            return None;
        }
        match *segments.as_slice() {
            [namespace, table] => {
                let (first, _) = catalogs.first()?;
                (Arc::clone(first), namespace, table)
            }
            [wanted, namespace, table] => {
                let (named, _) = catalogs.iter().find(|(_, name)| name == wanted)?;
                (Arc::clone(named), namespace, table)
            }
            _ => return None,
        }
    };
    let namespace_ident = iceberg::NamespaceIdent::from_vec(vec![namespace.to_string()]).ok()?;
    let table_ident = iceberg::TableIdent::new(namespace_ident, table.to_string());
    Some((catalog.as_iceberg(), table_ident))
}
/// Executes a `CALL SYSTEM.<procedure>(<args>)` maintenance statement against
/// the first registered Iceberg catalog and returns a one-row, one-column
/// `Int64` batch holding the procedure's count.
///
/// Supported procedures (matched on the upper-cased name):
/// * `EXPIRE_SNAPSHOTS(table, duration[, retain_last])` -> `expired_snapshots`
/// * `REMOVE_ORPHAN_FILES(table, duration)` -> `removed_files`
/// * `COMPACT_DATA_FILES(table[, target_bytes])` -> `rewritten_files`
///
/// `stmt` is sliced at the length of `"CALL SYSTEM."` without re-checking the
/// prefix; `sql` verifies the prefix before calling this method.
///
/// # Errors
/// `SqlError::DataFusion` for a missing `(`, no registered catalog, missing
/// required arguments, or a failing procedure; `SqlError::Unsupported` for an
/// unknown procedure name. Errors from `iceberg_table_ident` and
/// `parse_call_duration` are propagated.
#[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
async fn dispatch_call_system(&self, stmt: &str) -> SqlResult<RecordBatch> {
use arrow::array::{ArrayRef, Int64Array};
use arrow::datatypes::{DataType, Field, Schema};
// ASCII upper-casing keeps byte offsets identical, so the same offsets can
// be used on both `upper` (for matching) and `stmt` (to keep argument case).
let upper = stmt.to_ascii_uppercase();
const PREFIX: &str = "CALL SYSTEM.";
let upper_after = &upper[PREFIX.len()..];
let orig_after = &stmt[PREFIX.len()..];
let paren = upper_after.find('(').ok_or_else(|| SqlError::DataFusion {
message: format!("CALL: missing '(' in: {stmt}"),
})?;
let proc_name = upper_after[..paren].trim();
// Everything after the first '(' with trailing ';' and ')' characters
// removed. Note `trim_end_matches(')')` strips every trailing ')'.
let args_raw = orig_after[paren + 1..]
.trim_end_matches(';')
.trim()
.trim_end_matches(')')
.trim();
let args = call_args_from_str(args_raw);
// Always targets the first registered catalog; the read guard is released
// at the end of this block.
let iceberg_catalog = {
let guard = self
.iceberg_catalogs
.read()
.unwrap_or_else(|e| e.into_inner());
guard
.first()
.ok_or_else(|| SqlError::DataFusion {
message: "CALL system: no Iceberg catalog registered".to_string(),
})?
.0
.as_iceberg()
};
// The first argument is the table reference for every procedure. It is
// resolved before the procedure name is validated.
let table_ref = args.first().ok_or_else(|| SqlError::DataFusion {
message: format!("CALL {proc_name}: table reference argument is required"),
})?;
let table_ident = iceberg_table_ident(table_ref)?;
let count: i64 = match proc_name {
"EXPIRE_SNAPSHOTS" => {
let dur_s = args.get(1).ok_or_else(|| SqlError::DataFusion {
message: "CALL expire_snapshots: duration argument is required".to_string(),
})?;
let older_than = parse_call_duration(dur_s)?;
// A missing or unparseable third argument falls back to 1.
let retain_last = args
.get(2)
.and_then(|s| s.parse::<usize>().ok())
.unwrap_or(1);
krishiv_connectors::lakehouse::maintenance::expire_snapshots(
iceberg_catalog,
&table_ident,
older_than,
retain_last,
)
.await
.map_err(|e| SqlError::DataFusion {
message: e.to_string(),
})? as i64
}
"REMOVE_ORPHAN_FILES" => {
let dur_s = args.get(1).ok_or_else(|| SqlError::DataFusion {
message: "CALL remove_orphan_files: duration argument is required".to_string(),
})?;
let older_than = parse_call_duration(dur_s)?;
krishiv_connectors::lakehouse::maintenance::remove_orphan_files(
iceberg_catalog,
&table_ident,
older_than,
)
.await
.map_err(|e| SqlError::DataFusion {
message: e.to_string(),
})? as i64
}
"COMPACT_DATA_FILES" => {
// A missing or unparseable second argument falls back to 128 MiB.
let target_bytes = args
.get(1)
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(128 * 1024 * 1024);
krishiv_connectors::lakehouse::maintenance::compact_data_files(
iceberg_catalog,
&table_ident,
target_bytes,
)
.await
.map_err(|e| SqlError::DataFusion {
message: e.to_string(),
})? as i64
}
other => {
return Err(SqlError::Unsupported {
feature: format!("CALL system.{other}: unknown procedure"),
});
}
};
// Output column name per procedure. The fallback arm cannot be reached
// because unknown names returned above.
let col = match proc_name {
"EXPIRE_SNAPSHOTS" => "expired_snapshots",
"REMOVE_ORPHAN_FILES" => "removed_files",
"COMPACT_DATA_FILES" => "rewritten_files",
_ => "result",
};
let schema = Arc::new(Schema::new(vec![Field::new(col, DataType::Int64, false)]));
let array: ArrayRef = Arc::new(Int64Array::from(vec![count]));
RecordBatch::try_new(schema, vec![array]).map_err(|e| SqlError::DataFusion {
message: e.to_string(),
})
}
}
/// A query result paired with the caller-supplied operation id it was
/// executed under.
pub struct TaggedQueryResult {
/// Identifier supplied by the caller when the query was submitted.
pub operation_id: u64,
/// The dataframe produced by the query.
pub inner: SqlDataFrame,
}
/// Shared bookkeeping for in-flight operations: which operation ids have been
/// cancelled and the latest progress counters reported for each id.
///
/// Cloning is cheap and every clone observes the same state, because both
/// tables live behind `Arc<RwLock<..>>`.
///
/// A poisoned lock (a thread panicked while holding it) is recovered with
/// `into_inner()` instead of being ignored, so a cancellation request or a
/// progress update is never silently dropped.
#[derive(Clone, Default)]
pub struct OperationRegistry {
    /// Ids of operations that have been asked to stop.
    cancelled: Arc<std::sync::RwLock<std::collections::HashSet<u64>>>,
    /// Latest `(rows_scanned, rows_emitted)` pair per operation id.
    progress: Arc<std::sync::RwLock<std::collections::HashMap<u64, (u64, u64)>>>,
}
impl OperationRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Marks `operation_id` as cancelled. Idempotent.
    pub fn cancel(&self, operation_id: u64) {
        self.cancelled
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .insert(operation_id);
    }

    /// Returns `true` when `operation_id` has been cancelled and not removed.
    pub fn is_cancelled(&self, operation_id: u64) -> bool {
        self.cancelled
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .contains(&operation_id)
    }

    /// Forgets both the cancellation mark and the progress entry of
    /// `operation_id`.
    pub fn remove(&self, operation_id: u64) {
        self.cancelled
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .remove(&operation_id);
        self.progress
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .remove(&operation_id);
    }

    /// Stores the latest progress counters for `operation_id`, replacing any
    /// previous value.
    pub fn update_progress(&self, operation_id: u64, rows_scanned: u64, rows_emitted: u64) {
        self.progress
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .insert(operation_id, (rows_scanned, rows_emitted));
    }

    /// Returns the last `(rows_scanned, rows_emitted)` pair recorded for
    /// `operation_id`, or `None` when nothing has been recorded.
    pub fn progress(&self, operation_id: u64) -> Option<(u64, u64)> {
        self.progress
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .get(&operation_id)
            .copied()
    }

    /// Returns a snapshot of every cancelled operation id, in unspecified order.
    pub fn cancelled_ids(&self) -> Vec<u64> {
        self.cancelled
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .iter()
            .copied()
            .collect()
    }
}
/// Returns the table name when the first statement of `query` is a
/// `CREATE EXTERNAL TABLE`.
///
/// Yields `None` when the text does not parse, contains no statement, or the
/// first statement is of any other kind.
pub(crate) fn extract_create_external_table_name(query: &str) -> Option<String> {
    use datafusion::sql::parser::{DFParser, Statement as DFStatement};
    let first = DFParser::parse_sql(query).ok()?.into_iter().next()?;
    if let DFStatement::CreateExternalTable(create) = first {
        Some(create.name.to_string())
    } else {
        None
    }
}
/// Grouping specification passed to `KrishivDataFrameOps::aggregate_grouping`.
///
/// All variants borrow their expressions for the lifetime `'a`.
pub enum GroupingMode<'a> {
/// A list of grouping sets, each itself a list of expressions.
Sets(Vec<Vec<&'a krishiv_plan::expression::Expr>>),
/// A flat expression list for the cube variant
/// (presumably SQL `CUBE` semantics — confirm against the implementor).
Cube(Vec<&'a krishiv_plan::expression::Expr>),
/// A flat expression list for the rollup variant
/// (presumably SQL `ROLLUP` semantics — confirm against the implementor).
Rollup(Vec<&'a krishiv_plan::expression::Expr>),
}
/// Async dataframe interface usable as a trait object.
///
/// Implementors must be `Send + Sync`. Every method that derives a new frame
/// takes `&self` and returns a freshly boxed `dyn KrishivDataFrameOps`, so
/// calls can be chained without knowing the concrete type. The exact semantics
/// of each operation are defined by the implementor; the comments below only
/// describe what the signatures establish.
#[async_trait::async_trait]
pub trait KrishivDataFrameOps: Send + Sync {
// --- Execution and inspection -------------------------------------------
/// Materialises the frame as a vector of record batches.
async fn collect(&self) -> SqlResult<Vec<RecordBatch>>;
/// Like `collect`, additionally returning execution statistics.
async fn collect_with_stats(&self) -> SqlResult<(Vec<RecordBatch>, SqlExecutionStats)>;
/// Returns a textual explanation of the frame's plan (fallible, async).
async fn explain(&self) -> SqlResult<String>;
/// Returns a textual rendering of the logical plan (infallible, sync).
fn explain_logical(&self) -> String;
/// Returns the frame's plan as a Krishiv `LogicalPlan`.
fn krishiv_logical_plan(&self) -> LogicalPlan;
/// Returns the SQL text associated with the frame, if any.
fn query(&self) -> Option<&str>;
/// Returns the result as a stream of record batches.
async fn execute_stream(&self) -> SqlResult<SqlStream>;
/// Returns the Arrow schema of the frame.
fn schema(&self) -> SchemaRef;
// --- Transformations: each returns a new boxed frame --------------------
async fn select(&self, columns: &[&str]) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
async fn select_exprs(
&self,
expressions: &[&krishiv_plan::expression::Expr],
) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
async fn aggregate(
&self,
group_exprs: &[&krishiv_plan::expression::Expr],
aggregate_exprs: &[&krishiv_plan::expression::Expr],
) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
async fn aggregate_grouping(
&self,
grouping: GroupingMode<'_>,
aggregate_exprs: &[&krishiv_plan::expression::Expr],
) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
/// `values` pairs each pivot value with a string
/// (presumably the output column name — confirm against the implementor).
async fn pivot(
&self,
group_exprs: &[&krishiv_plan::expression::Expr],
pivot_column: &krishiv_plan::expression::Expr,
aggregate_expr: &krishiv_plan::expression::Expr,
values: &[(krishiv_plan::expression::ScalarValue, String)],
) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
async fn unpivot(
&self,
columns: &[&str],
name_column: &str,
value_column: &str,
) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
/// Filters with a predicate given as text.
async fn filter(&self, predicate: &str) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
/// Filters with a predicate given as a Krishiv expression.
async fn filter_expr(
&self,
predicate: &krishiv_plan::expression::Expr,
) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
async fn limit(&self, n: usize) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
async fn distinct(&self) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
async fn drop_nulls(&self, columns: &[&str]) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
async fn sample(&self, fraction: f64) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
/// `columns` and `descending` are parallel slices
/// (presumably expected to have equal length — confirm against the implementor).
async fn sort(
&self,
columns: &[&str],
descending: &[bool],
) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
async fn alias(&self, alias: &str) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
async fn drop_columns(&self, columns: &[&str]) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
async fn rename_column(&self, old: &str, new: &str) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
async fn with_column(&self, name: &str, expr: &str) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
/// Exposes the implementor as `Any` so callers can downcast to the concrete type.
fn as_any(&self) -> &dyn std::any::Any;
async fn describe(&self) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
async fn fill_null(&self, column: &str, value: &str)
-> SqlResult<Box<dyn KrishivDataFrameOps>>;
// --- Binary operations: `right` is another frame behind the same trait ---
async fn join(
&self,
right: &dyn KrishivDataFrameOps,
how: &str,
left_on: &[&str],
right_on: &[&str],
) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
async fn union(
&self,
right: &dyn KrishivDataFrameOps,
) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
async fn union_distinct(
&self,
right: &dyn KrishivDataFrameOps,
) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
async fn intersect(
&self,
right: &dyn KrishivDataFrameOps,
distinct: bool,
) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
async fn except(
&self,
right: &dyn KrishivDataFrameOps,
distinct: bool,
) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
// --- Catalog side effects -----------------------------------------------
async fn register_batches(&self, name: &str, batches: Vec<RecordBatch>) -> SqlResult<()>;
async fn deregister_table(&self, name: &str) -> SqlResult<()>;
async fn create_view(&self, name: &str, replace: bool) -> SqlResult<()>;
}
/// Recursively converts a DataFusion logical plan into a flat list of Krishiv
/// plan nodes.
///
/// Returns `(nodes, root_id)`: `nodes` lists every produced node with inputs
/// before the node that consumes them, and `root_id` is the id of the node
/// representing `plan` itself.
///
/// `counter` is incremented once on entry to every call (before recursing), so
/// the numeric suffix in an id such as `scan-3` reflects pre-order visit
/// position. Calls that produce no node of their own (single-input `Union`,
/// `SubqueryAlias`) still consume a number.
///
/// `table_row_counts` supplies the estimated row count for table scans,
/// looked up by bare table name.
///
/// Every node is created with `ExecutionKind::Batch`. Plan variants without a
/// dedicated arm become a leaf `NodeOp::Other` described by the plan's display
/// string; their inputs are not visited.
fn df_plan_to_krishiv_nodes(
plan: &datafusion::logical_expr::LogicalPlan,
table_row_counts: &std::collections::HashMap<String, u64>,
counter: &mut usize,
) -> (Vec<krishiv_plan::PlanNode>, String) {
use datafusion::logical_expr::LogicalPlan as DfPlan;
use krishiv_plan::{ExecutionKind, NodeOp, PlanNode};
// Reserve this call's id number before visiting any input.
*counter += 1;
let idx = *counter;
match plan {
// Leaf: table scan with its pushed-down filters and known row count.
DfPlan::TableScan(ts) => {
let table_name = ts.table_name.table().to_string();
let row_count = table_row_counts.get(&table_name).copied();
let filters: Vec<String> = ts.filters.iter().map(|e| e.to_string()).collect();
let id = format!("scan-{idx}");
let node = PlanNode::new(&id, format!("Scan {table_name}"), ExecutionKind::Batch)
.with_op(NodeOp::Scan {
table: table_name,
filters,
})
.with_estimated_rows(row_count);
(vec![node], id)
}
DfPlan::Projection(proj) => {
let (mut nodes, input_id) =
df_plan_to_krishiv_nodes(&proj.input, table_row_counts, counter);
let id = format!("proj-{idx}");
let columns: Vec<String> = proj.expr.iter().map(|e| e.to_string()).collect();
nodes.push(
PlanNode::new(&id, "Projection", ExecutionKind::Batch)
.with_op(NodeOp::Project { columns })
.with_inputs([input_id]),
);
(nodes, id)
}
DfPlan::Filter(filter) => {
let (mut nodes, input_id) =
df_plan_to_krishiv_nodes(&filter.input, table_row_counts, counter);
let id = format!("filter-{idx}");
let predicate = filter.predicate.to_string();
nodes.push(
PlanNode::new(&id, "Filter", ExecutionKind::Batch)
.with_op(NodeOp::Filter { predicate })
.with_inputs([input_id]),
);
(nodes, id)
}
// Only the group-by expressions are recorded; aggregate expressions are not.
DfPlan::Aggregate(agg) => {
let (mut nodes, input_id) =
df_plan_to_krishiv_nodes(&agg.input, table_row_counts, counter);
let id = format!("agg-{idx}");
let group_keys: Vec<String> = agg.group_expr.iter().map(|e| e.to_string()).collect();
nodes.push(
PlanNode::new(&id, "Aggregate", ExecutionKind::Batch)
.with_op(NodeOp::Aggregate { group_keys })
.with_inputs([input_id]),
);
(nodes, id)
}
// Left subtree is visited before the right one; only the join type is kept.
DfPlan::Join(join) => {
let (mut nodes, left_id) =
df_plan_to_krishiv_nodes(&join.left, table_row_counts, counter);
let (right_nodes, right_id) =
df_plan_to_krishiv_nodes(&join.right, table_row_counts, counter);
nodes.extend(right_nodes);
let id = format!("join-{idx}");
let krishiv_join_type = match join.join_type {
datafusion::common::JoinType::Inner => krishiv_plan::JoinType::Inner,
datafusion::common::JoinType::Left => krishiv_plan::JoinType::Left,
datafusion::common::JoinType::Right => krishiv_plan::JoinType::Right,
datafusion::common::JoinType::Full => krishiv_plan::JoinType::Full,
datafusion::common::JoinType::LeftSemi => krishiv_plan::JoinType::LeftSemi,
datafusion::common::JoinType::RightSemi => krishiv_plan::JoinType::RightSemi,
datafusion::common::JoinType::LeftAnti => krishiv_plan::JoinType::LeftAnti,
datafusion::common::JoinType::RightAnti => krishiv_plan::JoinType::RightAnti,
// Mark joins are reported as the corresponding semi join.
datafusion::common::JoinType::LeftMark => krishiv_plan::JoinType::LeftSemi,
datafusion::common::JoinType::RightMark => krishiv_plan::JoinType::RightSemi,
};
nodes.push(
PlanNode::new(&id, "Join", ExecutionKind::Batch)
.with_op(NodeOp::Join {
join_type: krishiv_join_type,
})
.with_inputs([left_id, right_id]),
);
(nodes, id)
}
DfPlan::Sort(sort) => {
let (mut nodes, input_id) =
df_plan_to_krishiv_nodes(&sort.input, table_row_counts, counter);
let id = format!("sort-{idx}");
nodes.push(
PlanNode::new(&id, "Sort", ExecutionKind::Batch)
.with_op(NodeOp::Other {
description: format!(
"Sort({})",
sort.expr
.iter()
.map(|e| e.to_string())
.collect::<Vec<_>>()
.join(", ")
),
})
.with_inputs([input_id]),
);
(nodes, id)
}
// NOTE(review): the exchange is always reported as `Unpartitioned`; the
// DataFusion partitioning scheme on `repart` is not carried over.
DfPlan::Repartition(repart) => {
let (mut nodes, input_id) =
df_plan_to_krishiv_nodes(&repart.input, table_row_counts, counter);
let id = format!("exchange-{idx}");
let partitioning = krishiv_plan::Partitioning::Unpartitioned;
nodes.push(
PlanNode::new(&id, "Exchange", ExecutionKind::Batch)
.with_op(NodeOp::Exchange { partitioning })
.with_inputs([input_id]),
);
(nodes, id)
}
DfPlan::Limit(limit) => {
let (mut nodes, input_id) =
df_plan_to_krishiv_nodes(&limit.input, table_row_counts, counter);
let id = format!("limit-{idx}");
nodes.push(
PlanNode::new(&id, "Limit", ExecutionKind::Batch)
.with_op(NodeOp::Other {
description: format!(
"Limit(skip={:?}, fetch={:?})",
limit.skip.as_ref().map(|e| e.to_string()),
limit.fetch.as_ref().map(|e| e.to_string()),
),
})
.with_inputs([input_id]),
);
(nodes, id)
}
// A union with exactly one input adds no node of its own. The `else`
// branch cannot be taken because the guard guarantees one input.
DfPlan::Union(union) if union.inputs.len() == 1 => {
if let Some(input) = union.inputs.first() {
df_plan_to_krishiv_nodes(input, table_row_counts, counter)
} else {
(Vec::new(), String::new())
}
}
// General union: inputs are visited in order and all feed one node.
DfPlan::Union(union) => {
let mut all_nodes = Vec::new();
let mut input_ids = Vec::new();
for input in &union.inputs {
let (sub_nodes, sub_id) =
df_plan_to_krishiv_nodes(input, table_row_counts, counter);
all_nodes.extend(sub_nodes);
input_ids.push(sub_id);
}
let id = format!("union-{idx}");
all_nodes.push(
PlanNode::new(&id, "Union", ExecutionKind::Batch)
.with_op(NodeOp::Other {
description: "Union".to_string(),
})
.with_inputs(input_ids),
);
(all_nodes, id)
}
// Aliases are transparent: the aliased input stands in for the alias.
DfPlan::SubqueryAlias(alias) => {
df_plan_to_krishiv_nodes(&alias.input, table_row_counts, counter)
}
DfPlan::Values(_) => {
let id = format!("values-{idx}");
let node = PlanNode::new(&id, "Values", ExecutionKind::Batch).with_op(NodeOp::Other {
description: "Values".to_string(),
});
(vec![node], id)
}
// Extension nodes are emitted as leaves labelled with the plan's display text.
DfPlan::Extension(_) => {
let id = format!("ext-{idx}");
let label = plan.to_string();
let node = PlanNode::new(&id, label.clone(), ExecutionKind::Batch)
.with_op(NodeOp::Other { description: label });
(vec![node], id)
}
DfPlan::EmptyRelation(_) => {
let id = format!("empty-{idx}");
let node =
PlanNode::new(&id, "EmptyRelation", ExecutionKind::Batch).with_op(NodeOp::Other {
description: "EmptyRelation".to_string(),
});
(vec![node], id)
}
// Any other plan variant: a leaf labelled with the plan's display text.
_ => {
let id = format!("df-{idx}");
let label = plan.to_string();
let node = PlanNode::new(&id, label.clone(), ExecutionKind::Batch)
.with_op(NodeOp::Other { description: label });
(vec![node], id)
}
}
}
/// A DataFusion dataframe bundled with the metadata this crate tracks for it.
#[derive(Clone)]
pub struct SqlDataFrame {
/// Label for the frame; also used as the name of the derived Krishiv plan.
name: String,
/// SQL text attached via `with_query`, if any; exposed through `query()`.
query: Option<String>,
/// Second copy of the SQL text; `with_query` sets it to the same string as `query`.
query_text: Option<String>,
/// Execution kind reported on the derived Krishiv plan.
execution_kind: ExecutionKind,
/// The wrapped DataFusion dataframe.
dataframe: DataFusionDataFrame,
/// Optional shuffle-partition override attached to this frame.
shuffle_partitions: Option<u32>,
/// Session context associated with the frame; `new` starts with a default one.
context: SessionContext,
/// Shared table-name -> row-count map consulted when deriving the Krishiv plan.
table_row_counts: Arc<std::sync::RwLock<HashMap<String, u64>>>,
}
/// Debug output lists only `name`, `query` and `shuffle_partitions`; the
/// remaining fields are elided via `finish_non_exhaustive`.
impl fmt::Debug for SqlDataFrame {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            name,
            query,
            shuffle_partitions,
            ..
        } = self;
        let mut out = f.debug_struct("SqlDataFrame");
        out.field("name", name);
        out.field("query", query);
        out.field("shuffle_partitions", shuffle_partitions);
        out.finish_non_exhaustive()
    }
}
impl SqlDataFrame {
    /// Wraps `dataframe` under `name` with no query text, batch execution
    /// kind, no shuffle override and a default session context.
    fn new(
        name: impl Into<String>,
        dataframe: DataFusionDataFrame,
        table_row_counts: Arc<std::sync::RwLock<HashMap<String, u64>>>,
    ) -> Self {
        let name = name.into();
        let context = SessionContext::default();
        Self {
            name,
            query: None,
            query_text: None,
            execution_kind: ExecutionKind::Batch,
            dataframe,
            shuffle_partitions: None,
            context,
            table_row_counts,
        }
    }

    /// Replaces the session context associated with this frame.
    pub(crate) fn with_context(self, context: SessionContext) -> Self {
        let mut updated = self;
        updated.context = context;
        updated
    }

    /// Attaches the SQL text to both `query` and `query_text`.
    fn with_query(self, query: impl Into<String>) -> Self {
        let text: String = query.into();
        let mut updated = self;
        updated.query = Some(text.clone());
        updated.query_text = Some(text);
        updated
    }

    /// Overrides the execution kind.
    fn with_execution_kind(self, kind: ExecutionKind) -> Self {
        let mut updated = self;
        updated.execution_kind = kind;
        updated
    }

    /// Sets (or clears, with `None`) the shuffle-partition override.
    fn with_shuffle_partitions(self, n: Option<u32>) -> Self {
        let mut updated = self;
        updated.shuffle_partitions = n;
        updated
    }

    /// Returns the SQL text attached to this frame, if any.
    pub fn query(&self) -> Option<&str> {
        self.query.as_ref().map(String::as_str)
    }

    /// Returns a freshly allocated copy of the frame's Arrow schema.
    pub fn arrow_schema(&self) -> arrow::datatypes::SchemaRef {
        let schema = self.dataframe.schema().as_arrow().clone();
        Arc::new(schema)
    }

    /// Derives a frame around `df` named `<name>-<tag>`. Execution kind,
    /// shuffle override, context and row counts are inherited; query text
    /// is dropped.
    fn with_new_dataframe(&self, df: DataFusionDataFrame, tag: &str) -> Self {
        let derived_name = format!("{}-{tag}", self.name);
        Self {
            name: derived_name,
            query: None,
            query_text: None,
            execution_kind: self.execution_kind,
            dataframe: df,
            shuffle_partitions: self.shuffle_partitions,
            context: self.context.clone(),
            table_row_counts: Arc::clone(&self.table_row_counts),
        }
    }

    /// Converts the DataFusion logical plan into a Krishiv `LogicalPlan` and
    /// runs the default logical optimizer over it.
    ///
    /// When the optimizer fails, a warning is logged and the unoptimized plan
    /// is returned instead.
    pub fn krishiv_logical_plan(&self) -> LogicalPlan {
        // A poisoned row-count lock is recovered; the guard lives until return.
        let row_counts = self
            .table_row_counts
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let mut next_id = 0usize;
        let (nodes, _root_id) =
            df_plan_to_krishiv_nodes(self.dataframe.logical_plan(), &row_counts, &mut next_id);
        let unoptimized = nodes.into_iter().fold(
            LogicalPlan::new(self.name.clone(), self.execution_kind),
            |plan, node| plan.with_node(node),
        );
        let fallback = unoptimized.clone();
        krishiv_plan::optimizer::default_logical_optimizer()
            .optimize(unoptimized)
            .map(|result| result.plan)
            .unwrap_or_else(|error| {
                tracing::warn!(
                    plan = %self.name,
                    %error,
                    "logical optimizer failed; using unoptimized plan"
                );
                fallback
            })
    }

    /// Renders the DataFusion logical plan as text.
    pub fn explain_logical(&self) -> String {
        let plan = self.dataframe.logical_plan();
        plan.to_string()
    }

    /// Runs a non-verbose, non-analyze `EXPLAIN` and pretty-prints the result.
    pub async fn explain(&self) -> SqlResult<String> {
        let explained = self.dataframe.clone().explain(false, false)?;
        let batches = explained.collect().await?;
        pretty_batches(&batches)
    }

    /// Executes the frame and gathers every record batch.
    pub async fn collect(&self) -> SqlResult<Vec<RecordBatch>> {
        let frame = self.dataframe.clone();
        frame.collect().await.map_err(Into::into)
    }

    /// Executes the frame as a stream, mapping DataFusion errors to
    /// `SqlError::DataFusion`.
    pub async fn execute_stream(&self) -> SqlResult<SqlStream> {
        use futures::StreamExt;
        let batches = self.dataframe.clone().execute_stream().await?;
        let converted = batches.map(|item| {
            item.map_err(|source| SqlError::DataFusion {
                message: source.to_string(),
            })
        });
        Ok(Box::pin(converted))
    }

    /// Executes the frame and returns the batches together with execution
    /// statistics.
    ///
    /// `output_rows` and `cpu_nanos` come from the root operator's metrics when
    /// it reports them; otherwise `output_rows` is the number of rows actually
    /// collected and `cpu_nanos` is zero. Spill figures are summed over the
    /// whole physical plan.
    pub async fn collect_with_stats(&self) -> SqlResult<(Vec<RecordBatch>, SqlExecutionStats)> {
        use datafusion::physical_plan::collect as df_collect;
        let frame = self.dataframe.clone();
        // The task context must be taken before the frame is consumed below.
        let task_ctx = frame.task_ctx();
        let physical_plan = frame.create_physical_plan().await?;
        let batches = df_collect(physical_plan.clone(), task_ctx.into()).await?;
        let counted_rows: u64 = batches.iter().map(|batch| batch.num_rows() as u64).sum();
        let (output_rows, cpu_nanos) = match physical_plan.metrics() {
            Some(metrics) => (
                metrics.output_rows().map_or(counted_rows, |rows| rows as u64),
                metrics.elapsed_compute().map_or(0, |nanos| nanos as u64),
            ),
            None => (counted_rows, 0),
        };
        let (spill_bytes, spill_count) = aggregate_spill_metrics(physical_plan.as_ref());
        let stats = SqlExecutionStats {
            output_rows,
            cpu_nanos,
            spill_bytes,
            spill_count,
        };
        Ok((batches, stats))
    }
}
/// Sums spilled bytes and spill counts over `plan` and all of its descendants.
///
/// Returns `(spill_bytes, spill_count)`. Operators that report no metrics
/// contribute zero; additions saturate instead of overflowing.
fn aggregate_spill_metrics(plan: &dyn datafusion::physical_plan::ExecutionPlan) -> (u64, u64) {
    let own: (u64, u64) = match plan.metrics() {
        Some(metrics) => (
            metrics.spilled_bytes().map_or(0, |bytes| bytes as u64),
            metrics.spill_count().map_or(0, |count| count as u64),
        ),
        None => (0, 0),
    };
    plan.children()
        .into_iter()
        .map(|child| aggregate_spill_metrics(child.as_ref()))
        .fold(own, |(bytes, count), (child_bytes, child_count)| {
            (
                bytes.saturating_add(child_bytes),
                count.saturating_add(child_count),
            )
        })
}
/// Execution statistics returned alongside the batches by
/// `SqlDataFrame::collect_with_stats`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct SqlExecutionStats {
/// Rows produced: the root operator's `output_rows` metric when reported,
/// otherwise the row count of the collected batches.
pub output_rows: u64,
/// The root operator's `elapsed_compute` metric; `0` when not reported.
pub cpu_nanos: u64,
/// Spilled bytes summed over every node of the physical plan.
pub spill_bytes: u64,
/// Spill count summed over every node of the physical plan.
pub spill_count: u64,
}
/// Finds the byte offset of the last top-level `" AS "` keyword in
/// `expression`, matched case-insensitively.
///
/// "Top-level" means outside parentheses and outside single- or
/// double-quoted sections; a doubled quote inside a quoted section is an
/// escaped quote and does not end it. The returned offset points at the
/// space that precedes `AS`. Only plain spaces delimit the keyword.
fn top_level_alias_index(expression: &str) -> Option<usize> {
    const KEYWORD: &[u8] = b" AS ";

    let bytes = expression.as_bytes();
    let mut paren_depth = 0usize;
    let mut in_single = false;
    let mut in_double = false;
    let mut last_match = None;
    let mut cursor = 0usize;

    while let Some(&current) = bytes.get(cursor) {
        let following = bytes.get(cursor + 1).copied();
        let unquoted = !in_single && !in_double;
        // Each arm yields how many bytes to advance past.
        let advance = match current {
            b'\'' if !in_double => {
                if in_single && following == Some(b'\'') {
                    2 // escaped quote: stay inside the literal
                } else {
                    in_single = !in_single;
                    1
                }
            }
            b'"' if !in_single => {
                if in_double && following == Some(b'"') {
                    2 // escaped quote: stay inside the identifier
                } else {
                    in_double = !in_double;
                    1
                }
            }
            b'(' if unquoted => {
                paren_depth += 1;
                1
            }
            b')' if unquoted => {
                paren_depth = paren_depth.saturating_sub(1);
                1
            }
            b' ' if paren_depth == 0
                && unquoted
                && bytes
                    .get(cursor..cursor + KEYWORD.len())
                    .is_some_and(|window| window.eq_ignore_ascii_case(KEYWORD)) =>
            {
                last_match = Some(cursor);
                KEYWORD.len()
            }
            _ => 1,
        };
        cursor += advance;
    }
    last_match
}
/// Parses a SQL expression against `dataframe`, honouring a trailing
/// top-level `AS alias`.
///
/// When `top_level_alias_index` finds an alias keyword and the text after it
/// is non-empty, the part before the keyword is parsed and aliased. One pair
/// of surrounding double quotes is stripped from the alias and doubled
/// quotes are collapsed. Otherwise the whole string is handed to
/// `parse_sql_expr` unchanged.
fn parse_dataframe_expression(
    dataframe: &datafusion::dataframe::DataFrame,
    expression: &str,
) -> SqlResult<datafusion::logical_expr::Expr> {
    let aliased = top_level_alias_index(expression).and_then(|index| {
        let (body, tail) = expression.split_at(index);
        // `tail` starts with the 4-byte " AS " keyword found by the scanner.
        let alias = tail[4..].trim();
        (!alias.is_empty()).then_some((body, alias))
    });

    match aliased {
        Some((body, alias)) => {
            let unquoted = alias
                .strip_prefix('"')
                .and_then(|rest| rest.strip_suffix('"'))
                .unwrap_or(alias)
                .replace("\"\"", "\"");
            let parsed = dataframe.parse_sql_expr(body.trim())?;
            Ok(parsed.alias(unquoted))
        }
        None => dataframe.parse_sql_expr(expression).map_err(Into::into),
    }
}
pub fn parse_public_expression(sql: &str) -> SqlResult<krishiv_plan::expression::Expr> {
let dialect = GenericDialect {};
let mut parser =
Parser::new(&dialect)
.try_with_sql(sql)
.map_err(|error| SqlError::Unsupported {
feature: format!("public expression parse: {error}"),
})?;
let expression = parser.parse_expr().map_err(|error| SqlError::Unsupported {
feature: format!("public expression parse: {error}"),
})?;
sqlparser_expression_to_public(&expression)
}
/// Converts a parsed `sqlparser` expression into the public expression tree.
///
/// Supported nodes: identifiers (plain and compound), parenthesised
/// expressions, `IS [NOT] NULL`, the comparison / boolean / arithmetic binary
/// operators listed below, and NULL / boolean / single-quoted string /
/// numeric literals. Numeric literals containing `.`, `e` or `E` become
/// float64; all others must parse as `i64`.
///
/// # Errors
/// Anything else yields `SqlError::Unsupported` naming the offending
/// operator, literal or node.
fn sqlparser_expression_to_public(
    expression: &datafusion::sql::sqlparser::ast::Expr,
) -> SqlResult<krishiv_plan::expression::Expr> {
    use datafusion::sql::sqlparser::ast::{BinaryOperator as SqlOperator, Expr as SqlExpr, Value};
    use krishiv_plan::expression::{BinaryOperator, Expr, ScalarValue};

    let unsupported = |feature: String| SqlError::Unsupported { feature };

    let converted = match expression {
        SqlExpr::Identifier(identifier) => Expr::Column {
            path: vec![identifier.value.clone()],
        },
        SqlExpr::CompoundIdentifier(parts) => Expr::Column {
            path: parts.iter().map(|part| part.value.clone()).collect(),
        },
        // Parentheses carry no meaning of their own in the public tree.
        SqlExpr::Nested(inner) => return sqlparser_expression_to_public(inner),
        SqlExpr::IsNull(inner) | SqlExpr::IsNotNull(inner) => Expr::IsNull {
            expression: Box::new(sqlparser_expression_to_public(inner)?),
            negated: matches!(expression, SqlExpr::IsNotNull(_)),
        },
        SqlExpr::BinaryOp { left, op, right } => {
            // Operands and operator are converted left-to-right so the first
            // problem encountered in source order is the one reported.
            let left = Box::new(sqlparser_expression_to_public(left)?);
            let op = match op {
                SqlOperator::Eq => BinaryOperator::Eq,
                SqlOperator::NotEq => BinaryOperator::NotEq,
                SqlOperator::Gt => BinaryOperator::Gt,
                SqlOperator::GtEq => BinaryOperator::GtEq,
                SqlOperator::Lt => BinaryOperator::Lt,
                SqlOperator::LtEq => BinaryOperator::LtEq,
                SqlOperator::And => BinaryOperator::And,
                SqlOperator::Or => BinaryOperator::Or,
                SqlOperator::Plus => BinaryOperator::Plus,
                SqlOperator::Minus => BinaryOperator::Minus,
                SqlOperator::Multiply => BinaryOperator::Multiply,
                SqlOperator::Divide => BinaryOperator::Divide,
                other => {
                    return Err(unsupported(format!("public expression operator {other}")));
                }
            };
            let right = Box::new(sqlparser_expression_to_public(right)?);
            Expr::Binary { left, op, right }
        }
        SqlExpr::Value(literal) => {
            let value = match &literal.value {
                Value::Null => ScalarValue::Null,
                Value::Boolean(flag) => ScalarValue::Boolean(*flag),
                Value::SingleQuotedString(text) => ScalarValue::Utf8(text.clone()),
                Value::Number(text, _) => {
                    let looks_like_float = text.contains(['.', 'e', 'E']);
                    if looks_like_float {
                        let parsed = text.parse::<f64>().map_err(|error| {
                            unsupported(format!("numeric expression literal: {error}"))
                        })?;
                        ScalarValue::float64(parsed)
                    } else {
                        let parsed = text.parse::<i64>().map_err(|error| {
                            unsupported(format!("integer expression literal: {error}"))
                        })?;
                        ScalarValue::Int64(parsed)
                    }
                }
                other => {
                    return Err(unsupported(format!("public expression literal {other}")));
                }
            };
            Expr::Literal { value }
        }
        other => {
            return Err(unsupported(format!("public expression node {other}")));
        }
    };
    Ok(converted)
}
/// Maps a public expression data type onto the corresponding Arrow type.
///
/// Nested types recurse. List elements are named `item` and are nullable;
/// maps use a non-nullable `entries` struct with a non-nullable `key` and a
/// nullable `value` and are marked unsorted; `Variant` is represented as
/// `Utf8`.
fn public_data_type_to_arrow(
    data_type: &krishiv_plan::expression::ExprDataType,
) -> arrow::datatypes::DataType {
    use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit};
    use krishiv_plan::expression::{
        ExprDataType, IntervalUnit as PublicIntervalUnit, TimeUnit as PublicTimeUnit,
    };

    match data_type {
        ExprDataType::Null => DataType::Null,
        ExprDataType::Boolean => DataType::Boolean,
        ExprDataType::Int64 => DataType::Int64,
        ExprDataType::UInt64 => DataType::UInt64,
        ExprDataType::Float64 => DataType::Float64,
        ExprDataType::Utf8 => DataType::Utf8,
        ExprDataType::Binary => DataType::Binary,
        ExprDataType::Decimal128 { precision, scale } => DataType::Decimal128(*precision, *scale),
        ExprDataType::Date32 => DataType::Date32,
        ExprDataType::Timestamp { unit, timezone } => {
            let arrow_unit = match unit {
                PublicTimeUnit::Second => TimeUnit::Second,
                PublicTimeUnit::Millisecond => TimeUnit::Millisecond,
                PublicTimeUnit::Microsecond => TimeUnit::Microsecond,
                PublicTimeUnit::Nanosecond => TimeUnit::Nanosecond,
            };
            DataType::Timestamp(arrow_unit, timezone.clone().map(Into::into))
        }
        ExprDataType::Interval { unit } => {
            let arrow_unit = match unit {
                PublicIntervalUnit::YearMonth => IntervalUnit::YearMonth,
                PublicIntervalUnit::DayTime => IntervalUnit::DayTime,
                PublicIntervalUnit::MonthDayNano => IntervalUnit::MonthDayNano,
            };
            DataType::Interval(arrow_unit)
        }
        ExprDataType::List(element) => {
            let item = Field::new("item", public_data_type_to_arrow(element), true);
            DataType::List(Arc::new(item))
        }
        ExprDataType::Map { key, value } => {
            let key_field = Arc::new(Field::new("key", public_data_type_to_arrow(key), false));
            let value_field = Arc::new(Field::new("value", public_data_type_to_arrow(value), true));
            let entries = Field::new(
                "entries",
                DataType::Struct(vec![key_field, value_field].into()),
                false,
            );
            DataType::Map(Arc::new(entries), false)
        }
        ExprDataType::Struct(fields) => {
            let members = fields
                .iter()
                .map(|field| {
                    Arc::new(Field::new(
                        &field.name,
                        public_data_type_to_arrow(&field.data_type),
                        field.nullable,
                    ))
                })
                .collect::<Vec<_>>();
            DataType::Struct(members.into())
        }
        ExprDataType::Variant => DataType::Utf8,
    }
}
/// Converts a public scalar literal into a DataFusion `ScalarValue`.
///
/// Returns `None` for interval literals, which have no direct mapping here;
/// `lower_public_expression` falls back to SQL text for those.
/// `Float64` is stored as raw bits on the public side and is decoded with
/// `f64::from_bits`.
fn public_scalar_to_datafusion(
    value: &krishiv_plan::expression::ScalarValue,
) -> Option<datafusion::common::ScalarValue> {
    use datafusion::common::ScalarValue;
    use krishiv_plan::expression::{ScalarValue as PublicScalar, TimeUnit};

    let converted = match value {
        PublicScalar::Interval { .. } => return None,
        PublicScalar::Null => ScalarValue::Null,
        PublicScalar::Boolean(flag) => ScalarValue::Boolean(Some(*flag)),
        PublicScalar::Int64(number) => ScalarValue::Int64(Some(*number)),
        PublicScalar::UInt64(number) => ScalarValue::UInt64(Some(*number)),
        PublicScalar::Float64(bits) => ScalarValue::Float64(Some(f64::from_bits(*bits))),
        PublicScalar::Utf8(text) => ScalarValue::Utf8(Some(text.clone())),
        PublicScalar::Binary(bytes) => ScalarValue::Binary(Some(bytes.clone())),
        PublicScalar::Decimal128 {
            value,
            precision,
            scale,
        } => ScalarValue::Decimal128(Some(*value), *precision, *scale),
        PublicScalar::Date32(days) => ScalarValue::Date32(Some(*days)),
        PublicScalar::Timestamp {
            value,
            unit,
            timezone,
        } => {
            let instant = Some(*value);
            let zone = timezone.clone().map(Into::into);
            match unit {
                TimeUnit::Second => ScalarValue::TimestampSecond(instant, zone),
                TimeUnit::Millisecond => ScalarValue::TimestampMillisecond(instant, zone),
                TimeUnit::Microsecond => ScalarValue::TimestampMicrosecond(instant, zone),
                TimeUnit::Nanosecond => ScalarValue::TimestampNanosecond(instant, zone),
            }
        }
    };
    Some(converted)
}
/// Lowers a public expression into a DataFusion logical expression bound to
/// `dataframe`.
///
/// The expression is validated first. Single-segment columns, literals with
/// a direct scalar mapping, aliases, binary operators, null checks and casts
/// are built structurally. Multi-segment columns, interval literals,
/// aggregates, functions, windows and raw SQL are rendered with `to_sql()`
/// and parsed against the dataframe instead.
///
/// # Errors
/// `SqlError::Unsupported` for invalid expressions and for standalone sort
/// expressions; parse errors from the SQL fallback are propagated.
fn lower_public_expression(
    dataframe: &datafusion::dataframe::DataFrame,
    expression: &krishiv_plan::expression::Expr,
) -> SqlResult<datafusion::logical_expr::Expr> {
    use datafusion::logical_expr::{Expr as DataFusionExpr, Operator, binary_expr, cast, try_cast};
    use krishiv_plan::expression::{BinaryOperator, Expr};

    if let Err(error) = expression.validate() {
        return Err(SqlError::Unsupported {
            feature: format!("invalid public expression: {error}"),
        });
    }

    // Fallback: render the whole node as SQL and let DataFusion parse it.
    let via_sql = || parse_dataframe_expression(dataframe, &expression.to_sql());

    let lowered = match expression {
        Expr::Column { path } => match path.as_slice() {
            [name] => datafusion::prelude::col(name.as_str()),
            _ => via_sql()?,
        },
        Expr::Literal { value } => match public_scalar_to_datafusion(value) {
            Some(scalar) => DataFusionExpr::Literal(scalar, None),
            None => via_sql()?,
        },
        Expr::Alias {
            expression: inner,
            name,
        } => lower_public_expression(dataframe, inner)?.alias(name),
        Expr::Binary { left, op, right } => {
            let lhs = lower_public_expression(dataframe, left)?;
            let operator = match op {
                BinaryOperator::Eq => Operator::Eq,
                BinaryOperator::NotEq => Operator::NotEq,
                BinaryOperator::Gt => Operator::Gt,
                BinaryOperator::GtEq => Operator::GtEq,
                BinaryOperator::Lt => Operator::Lt,
                BinaryOperator::LtEq => Operator::LtEq,
                BinaryOperator::And => Operator::And,
                BinaryOperator::Or => Operator::Or,
                BinaryOperator::Plus => Operator::Plus,
                BinaryOperator::Minus => Operator::Minus,
                BinaryOperator::Multiply => Operator::Multiply,
                BinaryOperator::Divide => Operator::Divide,
            };
            let rhs = lower_public_expression(dataframe, right)?;
            binary_expr(lhs, operator, rhs)
        }
        Expr::IsNull {
            expression: inner,
            negated,
        } => {
            let operand = lower_public_expression(dataframe, inner)?;
            if *negated {
                operand.is_not_null()
            } else {
                operand.is_null()
            }
        }
        Expr::Cast {
            expression: inner,
            data_type,
            safe,
        } => {
            let operand = lower_public_expression(dataframe, inner)?;
            let target = public_data_type_to_arrow(data_type);
            // A "safe" cast maps to DataFusion's TRY_CAST.
            if *safe {
                try_cast(operand, target)
            } else {
                cast(operand, target)
            }
        }
        Expr::Sort { .. } => {
            return Err(SqlError::Unsupported {
                feature: "standalone sort expressions are only valid inside windows or order_by"
                    .into(),
            });
        }
        Expr::Aggregate { .. }
        | Expr::Function { .. }
        | Expr::Window { .. }
        | Expr::RawSql { .. } => via_sql()?,
    };
    Ok(lowered)
}
/// Downcasts a dynamic dataframe handle to the concrete [`SqlDataFrame`].
///
/// `operation` is only used to name the calling operation in the error.
///
/// # Errors
/// `SqlError::DataFusion` when `dataframe` is a different implementation.
fn sql_dataframe<'a>(
    dataframe: &'a dyn KrishivDataFrameOps,
    operation: &str,
) -> SqlResult<&'a SqlDataFrame> {
    match dataframe.as_any().downcast_ref::<SqlDataFrame>() {
        Some(concrete) => Ok(concrete),
        None => Err(SqlError::DataFusion {
            message: format!("right DataFrame must be SqlDataFrame for {operation}"),
        }),
    }
}
#[async_trait::async_trait]
impl KrishivDataFrameOps for SqlDataFrame {
/// Trait entry point; forwards to the inherent `SqlDataFrame::collect`.
async fn collect(&self) -> SqlResult<Vec<RecordBatch>> {
SqlDataFrame::collect(self).await
}
/// Trait entry point; forwards to the inherent
/// `SqlDataFrame::collect_with_stats`.
async fn collect_with_stats(&self) -> SqlResult<(Vec<RecordBatch>, SqlExecutionStats)> {
SqlDataFrame::collect_with_stats(self).await
}
/// Trait entry point; forwards to the inherent `SqlDataFrame::explain`.
async fn explain(&self) -> SqlResult<String> {
SqlDataFrame::explain(self).await
}
/// Trait entry point; forwards to the inherent
/// `SqlDataFrame::explain_logical`.
fn explain_logical(&self) -> String {
SqlDataFrame::explain_logical(self)
}
/// Wraps the DataFusion logical plan in a single-node batch
/// [`LogicalPlan`].
///
/// The node is labelled `datafusion-logical` and carries the textual plan;
/// the configured shuffle partition count is attached when one is set.
fn krishiv_logical_plan(&self) -> LogicalPlan {
    let node = PlanNode::new(
        "datafusion-logical",
        self.dataframe.logical_plan().to_string(),
        ExecutionKind::Batch,
    );
    let plan = LogicalPlan::new(self.name.clone(), ExecutionKind::Batch).with_node(node);
    match self.shuffle_partitions {
        Some(partitions) => plan.with_shuffle_partitions(Some(partitions)),
        None => plan,
    }
}
/// Trait entry point; forwards to the inherent `SqlDataFrame::query`.
fn query(&self) -> Option<&str> {
SqlDataFrame::query(self)
}
/// Trait entry point; forwards to the inherent
/// `SqlDataFrame::execute_stream`.
async fn execute_stream(&self) -> SqlResult<SqlStream> {
SqlDataFrame::execute_stream(self).await
}
/// Returns the dataframe's schema converted into an Arrow `SchemaRef`
/// (the DataFusion schema is cloned for the conversion).
fn schema(&self) -> SchemaRef {
SchemaRef::from(self.dataframe.schema().clone())
}
async fn select(&self, columns: &[&str]) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
let df = self.dataframe.clone().select_columns(columns)?;
Ok(Box::new(self.with_new_dataframe(df, "select")))
}
/// Projects the dataframe onto the given public expressions.
///
/// Each expression is lowered with `lower_public_expression`; the first
/// lowering failure aborts the projection.
async fn select_exprs(
    &self,
    expressions: &[&krishiv_plan::expression::Expr],
) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
    let mut lowered = Vec::with_capacity(expressions.len());
    for expression in expressions {
        lowered.push(lower_public_expression(&self.dataframe, expression)?);
    }
    let projected = self.dataframe.clone().select(lowered)?;
    Ok(Box::new(self.with_new_dataframe(projected, "select_exprs")))
}
/// Groups by `group_exprs` and computes `aggregate_exprs`.
///
/// # Errors
/// `SqlError::Unsupported` when no aggregate expression is supplied;
/// lowering and planning errors are propagated. Group expressions are
/// lowered before aggregate expressions.
async fn aggregate(
    &self,
    group_exprs: &[&krishiv_plan::expression::Expr],
    aggregate_exprs: &[&krishiv_plan::expression::Expr],
) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
    if aggregate_exprs.is_empty() {
        return Err(SqlError::Unsupported {
            feature: "aggregate requires at least one aggregate expression".into(),
        });
    }

    let mut grouping = Vec::with_capacity(group_exprs.len());
    for expression in group_exprs {
        grouping.push(lower_public_expression(&self.dataframe, expression)?);
    }
    let mut measures = Vec::with_capacity(aggregate_exprs.len());
    for expression in aggregate_exprs {
        measures.push(lower_public_expression(&self.dataframe, expression)?);
    }

    let aggregated = self.dataframe.clone().aggregate(grouping, measures)?;
    Ok(Box::new(self.with_new_dataframe(aggregated, "aggregate")))
}
/// Aggregates with an advanced grouping clause (grouping sets, cube or
/// rollup), selected by `grouping`.
///
/// # Errors
/// `SqlError::Unsupported` when `aggregate_exprs` is empty; lowering and
/// planning errors are propagated. The grouping expressions are lowered
/// before the aggregate expressions.
async fn aggregate_grouping(
&self,
grouping: GroupingMode<'_>,
aggregate_exprs: &[&krishiv_plan::expression::Expr],
) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
if aggregate_exprs.is_empty() {
return Err(SqlError::Unsupported {
feature: "grouping aggregation requires at least one aggregate expression".into(),
});
}
// Shared lowering step, reused for every grouping list and for the aggregates.
let lower = |expression: &&krishiv_plan::expression::Expr| {
lower_public_expression(&self.dataframe, expression)
};
// Build the single DataFusion grouping expression for the requested mode.
let group = match grouping {
GroupingMode::Sets(sets) => datafusion::logical_expr::grouping_set(
sets.into_iter()
.map(|set| set.iter().map(lower).collect::<Result<Vec<_>, _>>())
.collect::<Result<Vec<_>, _>>()?,
),
GroupingMode::Cube(expressions) => datafusion::logical_expr::cube(
expressions
.iter()
.map(lower)
.collect::<Result<Vec<_>, _>>()?,
),
GroupingMode::Rollup(expressions) => datafusion::logical_expr::rollup(
expressions
.iter()
.map(lower)
.collect::<Result<Vec<_>, _>>()?,
),
};
let aggregates = aggregate_exprs
.iter()
.map(lower)
.collect::<Result<Vec<_>, _>>()?;
// The grouping construct is passed as the only group-by expression.
let df = self.dataframe.clone().aggregate(vec![group], aggregates)?;
Ok(Box::new(self.with_new_dataframe(df, "aggregate_grouping")))
}
/// Pivots `pivot_column` into one aggregate column per entry of `values`.
///
/// For every `(value, alias)` pair the aggregate is re-applied to
/// `CASE WHEN <pivot_column> = <value> THEN <input> END` and named `alias`;
/// the rows are grouped by `group_exprs`.
///
/// # Errors
/// `SqlError::Unsupported` when `aggregate_expr` is not an aggregate with
/// exactly one input, or when `values` is empty (checked in that order).
/// Lowering and planning errors are propagated.
async fn pivot(
    &self,
    group_exprs: &[&krishiv_plan::expression::Expr],
    pivot_column: &krishiv_plan::expression::Expr,
    aggregate_expr: &krishiv_plan::expression::Expr,
    values: &[(krishiv_plan::expression::ScalarValue, String)],
) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
    use krishiv_plan::expression::Expr as PublicExpr;

    let PublicExpr::Aggregate {
        function,
        expression: Some(input),
        distinct,
    } = aggregate_expr
    else {
        return Err(SqlError::Unsupported {
            feature: "pivot requires an aggregate expression with one input".into(),
        });
    };
    if values.is_empty() {
        return Err(SqlError::Unsupported {
            feature: "pivot requires at least one value".into(),
        });
    }

    let mut grouping = Vec::with_capacity(group_exprs.len());
    for expression in group_exprs {
        grouping.push(lower_public_expression(&self.dataframe, expression)?);
    }

    let mut aggregates = Vec::with_capacity(values.len());
    for (value, alias) in values {
        // Only rows matching this pivot value feed the aggregate; all other
        // rows yield NULL from the CASE.
        let conditional = PublicExpr::raw(format!(
            "CASE WHEN {} = {} THEN {} END",
            pivot_column.to_sql(),
            value.to_sql_literal(),
            input.to_sql()
        ));
        let aggregate = PublicExpr::Aggregate {
            function: *function,
            expression: Some(Box::new(conditional)),
            distinct: *distinct,
        };
        let aliased = aggregate.alias(alias);
        aggregates.push(lower_public_expression(&self.dataframe, &aliased)?);
    }

    let pivoted = self.dataframe.clone().aggregate(grouping, aggregates)?;
    Ok(Box::new(self.with_new_dataframe(pivoted, "pivot")))
}
/// Unpivots `columns` into `(name_column, value_column)` rows.
///
/// One projection is built per unpivoted column: every schema column not
/// listed in `columns`, plus the column's name as a literal (`name_column`)
/// and its value (`value_column`). The projections are then combined with
/// DataFusion's `union`.
///
/// Retained columns come straight from the schema, so they are referenced
/// with `ident`, which uses the name verbatim. `col` would apply SQL
/// identifier normalization (lowercasing, splitting on dots) and fail to
/// resolve mixed-case or dotted field names.
///
/// # Errors
/// `SqlError::Unsupported` when `columns` is empty; projection and union
/// failures are propagated.
async fn unpivot(
    &self,
    columns: &[&str],
    name_column: &str,
    value_column: &str,
) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
    use datafusion::logical_expr::{col, ident, lit};

    if columns.is_empty() {
        return Err(SqlError::Unsupported {
            feature: "unpivot requires at least one column".into(),
        });
    }

    let retained = self
        .dataframe
        .schema()
        .fields()
        .iter()
        .map(|field| field.name().as_str())
        .filter(|name| !columns.contains(name))
        .map(|name| ident(name))
        .collect::<Vec<_>>();

    // Build every branch before combining them, so projection errors are
    // reported ahead of union errors.
    let mut branches = Vec::with_capacity(columns.len());
    for column in columns {
        let mut projection = retained.clone();
        projection.push(lit((*column).to_owned()).alias(name_column));
        // NOTE(review): caller-supplied names keep `col` semantics so that
        // qualified references continue to work — confirm whether exact
        // matching is wanted here as well.
        projection.push(col(*column).alias(value_column));
        branches.push(self.dataframe.clone().select(projection)?);
    }

    let mut branches = branches.into_iter();
    let Some(first) = branches.next() else {
        return Err(SqlError::Unsupported {
            feature: "unpivot requires at least one branch".into(),
        });
    };
    let unpivoted = branches.try_fold(first, |combined, branch| combined.union(branch))?;
    Ok(Box::new(self.with_new_dataframe(unpivoted, "unpivot")))
}
async fn filter(&self, predicate: &str) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
let expr = self.dataframe.parse_sql_expr(predicate)?;
let df = self.dataframe.clone().filter(expr)?;
Ok(Box::new(self.with_new_dataframe(df, "filter")))
}
async fn filter_expr(
&self,
predicate: &krishiv_plan::expression::Expr,
) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
let expr = lower_public_expression(&self.dataframe, predicate)?;
let df = self.dataframe.clone().filter(expr)?;
Ok(Box::new(self.with_new_dataframe(df, "filter_expr")))
}
async fn limit(&self, n: usize) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
let df = self.dataframe.clone().limit(0, Some(n))?;
Ok(Box::new(self.with_new_dataframe(df, "limit")))
}
async fn distinct(&self) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
let df = self.dataframe.clone().distinct()?;
Ok(Box::new(self.with_new_dataframe(df, "distinct")))
}
/// Drops every row that has a NULL in any of `columns`.
///
/// An empty `columns` slice means "all columns of the schema". If that also
/// yields no columns, the dataframe is returned unfiltered.
///
/// Schema-derived names are referenced with `ident`, which uses the name
/// verbatim. `col` would apply SQL identifier normalization (lowercasing,
/// splitting on dots) and fail to resolve mixed-case or dotted field names.
/// Caller-supplied names keep `col` semantics.
async fn drop_nulls(&self, columns: &[&str]) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
    use datafusion::logical_expr::{col, ident};

    let targets: Vec<datafusion::logical_expr::Expr> = if columns.is_empty() {
        self.dataframe
            .schema()
            .fields()
            .iter()
            .map(|field| ident(field.name().as_str()))
            .collect()
    } else {
        columns.iter().map(|column| col(*column)).collect()
    };

    // AND together one `IS NOT NULL` check per target column.
    let predicate = targets
        .into_iter()
        .map(|column| column.is_not_null())
        .reduce(|all, next| all.and(next));

    let df = match predicate {
        Some(predicate) => self.dataframe.clone().filter(predicate)?,
        None => self.dataframe.clone(),
    };
    Ok(Box::new(self.with_new_dataframe(df, "drop_nulls")))
}
/// Keeps each row with probability `fraction` by filtering on
/// `random() < fraction`.
///
/// # Errors
/// `SqlError::Unsupported` when `fraction` lies outside `0.0..=1.0`
/// (NaN is rejected as well, since it is not contained in the range).
async fn sample(&self, fraction: f64) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
    let in_range = (0.0..=1.0).contains(&fraction);
    if !in_range {
        return Err(SqlError::Unsupported {
            feature: "sample fraction must be between 0 and 1".into(),
        });
    }
    let condition_sql = format!("random() < {fraction}");
    let condition = self.dataframe.parse_sql_expr(&condition_sql)?;
    let sampled = self.dataframe.clone().filter(condition)?;
    Ok(Box::new(self.with_new_dataframe(sampled, "sample")))
}
/// Sorts by `columns`, with `descending[i]` giving the direction of
/// `columns[i]`.
///
/// Every entry of `columns` takes part in the sort. A missing direction flag
/// (when `descending` is shorter than `columns`) defaults to ascending, so no
/// sort key is silently dropped; surplus flags are ignored. Descending keys
/// place NULLs first, ascending keys place them last.
async fn sort(
    &self,
    columns: &[&str],
    descending: &[bool],
) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
    use datafusion::logical_expr::SortExpr;
    let exprs: Vec<SortExpr> = columns
        .iter()
        .enumerate()
        .map(|(position, col_name)| {
            let desc = descending.get(position).copied().unwrap_or(false);
            datafusion::logical_expr::col(*col_name).sort(!desc, desc)
        })
        .collect();
    let df = self.dataframe.clone().sort(exprs)?;
    Ok(Box::new(self.with_new_dataframe(df, "sort")))
}
async fn alias(&self, alias: &str) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
let df = self.dataframe.clone().alias(alias)?;
Ok(Box::new(self.with_new_dataframe(df, "alias")))
}
async fn drop_columns(&self, columns: &[&str]) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
let df = self.dataframe.clone().drop_columns(columns)?;
Ok(Box::new(self.with_new_dataframe(df, "drop")))
}
async fn rename_column(&self, old: &str, new: &str) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
let df = self.dataframe.clone().with_column_renamed(old, new)?;
Ok(Box::new(self.with_new_dataframe(df, "rename")))
}
async fn with_column(&self, name: &str, expr: &str) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
let parsed = self.dataframe.parse_sql_expr(expr)?;
let df = self.dataframe.clone().with_column(name, parsed)?;
Ok(Box::new(self.with_new_dataframe(df, "with_column")))
}
/// Exposes `self` as `Any` so callers holding a trait object can downcast
/// back to `SqlDataFrame` (as `sql_dataframe` does).
fn as_any(&self) -> &dyn std::any::Any {
self
}
async fn describe(&self) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
let df = self.dataframe.clone().describe().await?;
Ok(Box::new(self.with_new_dataframe(df, "describe")))
}
/// Replaces NULLs in `column` with `value` by rewriting the column as
/// `COALESCE(column, value)`.
///
/// Both `column` and `value` are spliced verbatim into SQL text, so `value`
/// must already be a valid SQL expression or literal.
async fn fill_null(
&self,
column: &str,
value: &str,
) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
// NOTE(review): `column` is interpolated unquoted; names that need quoting
// (mixed case, spaces, dots) will presumably fail to parse or resolve —
// confirm whether callers are expected to pre-quote.
let expr = format!("COALESCE({column}, {value})");
let parsed = self.dataframe.parse_sql_expr(&expr)?;
// Writing back under the same name replaces the original column.
let df = self.dataframe.clone().with_column(column, parsed)?;
Ok(Box::new(self.with_new_dataframe(df, "fill_null")))
}
async fn join(
&self,
right: &dyn KrishivDataFrameOps,
how: &str,
left_on: &[&str],
right_on: &[&str],
) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
let right_sql = right
.as_any()
.downcast_ref::<SqlDataFrame>()
.ok_or_else(|| SqlError::DataFusion {
message: "right DataFrame must be SqlDataFrame for join".into(),
})?;
use datafusion::common::JoinType;
let join_type = match how.to_lowercase().as_str() {
"inner" => JoinType::Inner,
"left" => JoinType::Left,
"right" => JoinType::Right,
"full" | "outer" => JoinType::Full,
"leftsemi" | "left_semi" => JoinType::LeftSemi,
"rightsemi" | "right_semi" => JoinType::RightSemi,
"leftanti" | "left_anti" => JoinType::LeftAnti,
"rightanti" | "right_anti" => JoinType::RightAnti,
_ => {
return Err(SqlError::DataFusion {
message: format!("unsupported join type: {how}"),
});
}
};
let df = self.dataframe.clone().join(
right_sql.dataframe.clone(),
join_type,
left_on,
right_on,
None,
)?;
Ok(Box::new(self.with_new_dataframe(df, "join")))
}
async fn union(
&self,
right: &dyn KrishivDataFrameOps,
) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
let right_sql = right
.as_any()
.downcast_ref::<SqlDataFrame>()
.ok_or_else(|| SqlError::DataFusion {
message: "right DataFrame must be SqlDataFrame for union".into(),
})?;
let df = self.dataframe.clone().union(right_sql.dataframe.clone())?;
Ok(Box::new(self.with_new_dataframe(df, "union")))
}
async fn union_distinct(
&self,
right: &dyn KrishivDataFrameOps,
) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
let right = sql_dataframe(right, "union_distinct")?;
let df = self
.dataframe
.clone()
.union_distinct(right.dataframe.clone())?;
Ok(Box::new(self.with_new_dataframe(df, "union_distinct")))
}
/// Intersects this dataframe with `right`.
///
/// `distinct` selects DataFusion's `intersect_distinct`; otherwise
/// `intersect` is used.
///
/// # Errors
/// `SqlError::DataFusion` when `right` is not a `SqlDataFrame`; planning
/// errors from DataFusion are propagated.
async fn intersect(
    &self,
    right: &dyn KrishivDataFrameOps,
    distinct: bool,
) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
    let other = sql_dataframe(right, "intersect")?.dataframe.clone();
    let base = self.dataframe.clone();
    let combined = if distinct {
        base.intersect_distinct(other)
    } else {
        base.intersect(other)
    };
    Ok(Box::new(self.with_new_dataframe(combined?, "intersect")))
}
/// Removes from this dataframe the rows that also appear in `right`.
///
/// `distinct` selects DataFusion's `except_distinct`; otherwise `except`
/// is used.
///
/// # Errors
/// `SqlError::DataFusion` when `right` is not a `SqlDataFrame`; planning
/// errors from DataFusion are propagated.
async fn except(
    &self,
    right: &dyn KrishivDataFrameOps,
    distinct: bool,
) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
    let other = sql_dataframe(right, "except")?.dataframe.clone();
    let base = self.dataframe.clone();
    let remaining = if distinct {
        base.except_distinct(other)
    } else {
        base.except(other)
    };
    Ok(Box::new(self.with_new_dataframe(remaining?, "except")))
}
async fn register_batches(&self, name: &str, batches: Vec<RecordBatch>) -> SqlResult<()> {
let schema = batches
.first()
.map(|b| b.schema())
.unwrap_or_else(|| Arc::new(arrow::datatypes::Schema::empty()));
let mem_table =
datafusion::datasource::MemTable::try_new(schema, vec![batches]).map_err(|e| {
SqlError::DataFusion {
message: e.to_string(),
}
})?;
self.context
.register_table(name, Arc::new(mem_table))
.map_err(SqlError::from)?;
Ok(())
}
/// Removes table `name` from this dataframe's session context.
///
/// The provider returned by DataFusion (if any) is discarded.
async fn deregister_table(&self, name: &str) -> SqlResult<()> {
    self.context
        .deregister_table(name)
        .map(|_removed| ())
        .map_err(SqlError::from)
}
/// Creates a view called `name` over the SQL text this dataframe was built
/// from.
///
/// The statement is `CREATE [OR REPLACE] VIEW <name> AS <query>`, with the
/// name passed through `quote_identifier`, and is run on the session
/// context.
///
/// # Errors
/// `SqlError::DataFusion` when the dataframe carries no query text;
/// errors from running the statement are propagated.
async fn create_view(&self, name: &str, replace: bool) -> SqlResult<()> {
    let Some(query) = self.query_text.as_deref() else {
        return Err(SqlError::DataFusion {
            message: "create_view requires an SQL query string on the DataFrame".into(),
        });
    };
    let statement = format!(
        "CREATE {}VIEW {} AS {query}",
        if replace { "OR REPLACE " } else { "" },
        quote_identifier(name),
    );
    self.context.sql(&statement).await?;
    Ok(())
}
}
use krishiv_common::sql_util::quote_identifier;
#[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
/// Splits the argument text of a `CALL`-style statement into individual
/// arguments.
///
/// Arguments are separated by commas. A single-quoted argument is returned
/// without its quotes and may contain commas; a doubled quote (`''`) inside
/// it stands for one literal quote. Whitespace before the opening quote is
/// dropped rather than leaking into the argument, and anything between the
/// closing quote and the next comma is ignored. Bare arguments are trimmed,
/// and empty bare arguments are skipped. An unterminated quoted argument is
/// returned trimmed, like a bare one.
fn call_args_from_str(s: &str) -> Vec<String> {
    let mut args: Vec<String> = Vec::new();
    let mut cur = String::new();
    let mut in_str = false;
    let mut after_str = false;
    let mut chars = s.chars().peekable();

    while let Some(ch) = chars.next() {
        if after_str {
            // Skip trailing junk until the separator that ends this argument.
            after_str = ch != ',';
            continue;
        }
        if in_str {
            if ch != '\'' {
                cur.push(ch);
            } else if chars.peek() == Some(&'\'') {
                // `''` is an escaped quote, not the end of the string.
                chars.next();
                cur.push('\'');
            } else {
                in_str = false;
                after_str = true;
                args.push(std::mem::take(&mut cur));
            }
            continue;
        }
        match ch {
            '\'' => {
                // Padding between the separator and the quote is not part of
                // the argument.
                if cur.trim().is_empty() {
                    cur.clear();
                }
                in_str = true;
            }
            ',' => {
                let bare = cur.trim();
                if !bare.is_empty() {
                    args.push(bare.to_string());
                }
                cur.clear();
            }
            _ => cur.push(ch),
        }
    }

    let tail = cur.trim();
    if !tail.is_empty() {
        args.push(tail.to_string());
    }
    args
}
/// Builds an Iceberg table identifier from `ns.table` or `cat.ns.table`.
///
/// The reference is split on `.` into at most three parts. With three parts
/// the leading catalog segment is ignored. Any dots beyond the second stay
/// in the table name.
///
/// # Errors
/// `SqlError::DataFusion` when the reference has no `.` at all or when the
/// namespace is rejected by Iceberg.
#[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
fn iceberg_table_ident(table_ref: &str) -> SqlResult<iceberg::TableIdent> {
    let parts: Vec<&str> = table_ref.splitn(3, '.').collect();
    let (namespace, table) = match parts.as_slice() {
        [namespace, table] | [_, namespace, table] => (*namespace, *table),
        _ => {
            return Err(SqlError::DataFusion {
                message: format!(
                    "invalid table reference '{table_ref}': expected 'ns.table' or 'cat.ns.table'"
                ),
            });
        }
    };
    let namespace = iceberg::NamespaceIdent::from_vec(vec![namespace.to_string()]).map_err(
        |error| SqlError::DataFusion {
            message: error.to_string(),
        },
    )?;
    Ok(iceberg::TableIdent::new(namespace, table.to_string()))
}
/// Parses a duration written as `<integer> <unit>`, e.g. `7 days`.
///
/// The unit is case-insensitive, trailing `s` characters are ignored, and
/// the accepted units are `day`, `hour`, `week` and `minute` / `min`. The
/// number and the unit are separated by a single space.
///
/// # Errors
/// `SqlError::DataFusion` when the number does not parse as `i64`, the unit
/// is unknown, or the value is too large to represent. The fallible `try_*`
/// constructors are used so an oversized value is reported as an error
/// instead of panicking.
#[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
fn parse_call_duration(s: &str) -> SqlResult<chrono::Duration> {
    let s = s.trim();
    let mut it = s.splitn(2, ' ');
    let n: i64 = it
        .next()
        .and_then(|v| v.parse().ok())
        .ok_or_else(|| SqlError::DataFusion {
            message: format!("invalid duration value in '{s}'"),
        })?;
    let unit = it.next().unwrap_or("").trim().to_ascii_lowercase();
    let duration = match unit.trim_end_matches('s') {
        "day" => chrono::Duration::try_days(n),
        "hour" => chrono::Duration::try_hours(n),
        "week" => chrono::Duration::try_weeks(n),
        "minute" | "min" => chrono::Duration::try_minutes(n),
        _ => {
            return Err(SqlError::DataFusion {
                message: format!("unknown duration unit '{unit}' in '{s}'"),
            });
        }
    };
    duration.ok_or_else(|| SqlError::DataFusion {
        message: format!("duration value {n} is out of range in '{s}'"),
    })
}
/// Extracts `(table name, predicate)` from a single `DELETE` statement.
///
/// Returns `None` when the text does not parse, is not exactly one
/// statement, is not a `DELETE`, or its first `FROM` item is not a plain
/// table. A missing `WHERE` clause yields the predicate `TRUE`.
#[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
fn parse_dml_delete(stmt: &str) -> Option<(String, String)> {
    use datafusion::sql::sqlparser::ast::{FromTable, Statement, TableFactor};
    use datafusion::sql::sqlparser::dialect::GenericDialect;
    use datafusion::sql::sqlparser::parser::Parser;

    let mut statements = Parser::parse_sql(&GenericDialect {}, stmt)
        .ok()?
        .into_iter();
    // Exactly one statement, and it must be a DELETE.
    let (Some(Statement::Delete(delete)), None) = (statements.next(), statements.next()) else {
        return None;
    };

    let (FromTable::WithFromKeyword(tables) | FromTable::WithoutKeyword(tables)) = delete.from;
    let TableFactor::Table { name, .. } = tables.into_iter().next()?.relation else {
        return None;
    };

    let predicate = delete
        .selection
        .map_or_else(|| "TRUE".to_string(), |condition| condition.to_string());
    Some((name.to_string(), predicate))
}
/// Pieces of a single `UPDATE` statement, as extracted by
/// `parse_dml_update`.
#[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
struct ParsedUpdate {
// Target table name, rendered from the parsed statement.
table_ref: String,
// `(target, value)` pairs from the SET clause, both rendered as SQL text.
assignments: Vec<(String, String)>,
// WHERE clause rendered as SQL text; `None` when the statement has none.
predicate: Option<String>,
}
/// Extracts the target table, assignments and predicate from a single
/// `UPDATE` statement.
///
/// Returns `None` when the text does not parse, is not exactly one
/// statement, is not an `UPDATE`, targets something other than a plain
/// table, or has no assignments.
#[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
fn parse_dml_update(stmt: &str) -> Option<ParsedUpdate> {
    use datafusion::sql::sqlparser::ast::{Statement, TableFactor};
    use datafusion::sql::sqlparser::dialect::GenericDialect;
    use datafusion::sql::sqlparser::parser::Parser;

    let mut statements = Parser::parse_sql(&GenericDialect {}, stmt)
        .ok()?
        .into_iter();
    // Exactly one statement, and it must be an UPDATE.
    let (Some(Statement::Update(update)), None) = (statements.next(), statements.next()) else {
        return None;
    };

    let TableFactor::Table { name, .. } = update.table.relation else {
        return None;
    };
    let table_ref = name.to_string();

    let mut assignments = Vec::new();
    for assignment in update.assignments {
        assignments.push((assignment.target.to_string(), assignment.value.to_string()));
    }
    if assignments.is_empty() {
        return None;
    }

    let predicate = update.selection.map(|condition| condition.to_string());
    Some(ParsedUpdate {
        table_ref,
        assignments,
        predicate,
    })
}
/// Builds an optimized [`SqlPlan`] for `query`.
///
/// `MATCH_RECOGNIZE` statements are planned by `cep_sql`; any other query
/// becomes a single batch `sql` node labelled with the trimmed query text.
/// Either way the plan is run through the default [`Optimizer`].
///
/// # Errors
/// `SqlError::EmptyQuery` for blank input; `MATCH_RECOGNIZE` parse errors
/// and optimizer errors are propagated.
pub fn plan_sql(query: impl Into<String>) -> SqlResult<SqlPlan> {
    let query = query.into();
    if query.trim().is_empty() {
        return Err(SqlError::EmptyQuery);
    }

    let unoptimized = match cep_sql::parse_match_recognize(&query)? {
        Some(statement) => cep_sql::plan_match_recognize(statement, &query),
        None => {
            let node = PlanNode::new(
                "sql",
                format!("sql: {}", query.trim()),
                ExecutionKind::Batch,
            );
            LogicalPlan::new("sql-query", ExecutionKind::Batch).with_node(node)
        }
    };

    let optimized = Optimizer::default().optimize(unoptimized)?;
    Ok(SqlPlan {
        query,
        logical_plan: optimized.plan,
    })
}
/// Plans `query` with [`plan_sql`] and returns the textual description of
/// its logical plan.
pub fn explain_sql(query: impl Into<String>) -> SqlResult<String> {
    plan_sql(query).map(|planned| planned.logical_plan().describe())
}
/// Plans `query`, re-optimizes the plan with the supplied `optimizer`, and
/// returns the optimized plan description followed, on a new line, by the
/// optimizer result's own description.
pub fn explain_sql_optimized(query: impl Into<String>, optimizer: &Optimizer) -> SqlResult<String> {
    let planned = plan_sql(query)?;
    let outcome = optimizer.optimize(planned.logical_plan().clone())?;

    let mut report = outcome.plan.describe();
    let optimizer_summary = outcome.describe();
    report.push('\n');
    report.push_str(&optimizer_summary);
    Ok(report)
}
/// Plans `query` and returns its logical plan description followed by a
/// `cost:` line with the CPU, memory and network estimates produced by
/// `cost_model`.
pub fn explain_sql_with_cost(
    query: impl Into<String>,
    cost_model: &dyn CostModel,
) -> SqlResult<String> {
    let planned = plan_sql(query)?;
    let estimate = cost_model.estimate(planned.logical_plan());

    let mut report = planned.logical_plan().describe();
    let cost_line = format!(
        "\ncost: cpu_nanos={}, memory_bytes={}, network_bytes={}",
        estimate.cpu_nanos, estimate.memory_bytes, estimate.network_bytes
    );
    report.push_str(&cost_line);
    Ok(report)
}
/// Lists the relation names referenced by `query`, deduplicated and in
/// sorted order.
///
/// The query is parsed with the generic SQL dialect and every relation
/// visited by `visit_relations` is recorded by its display form.
///
/// # Errors
/// `SqlError::EmptyQuery` for blank input, `SqlError::DataFusion` when the
/// query does not parse.
pub fn referenced_table_names(query: impl AsRef<str>) -> SqlResult<Vec<String>> {
    let query = query.as_ref();
    if query.trim().is_empty() {
        return Err(SqlError::EmptyQuery);
    }

    let statements = match Parser::parse_sql(&GenericDialect {}, query) {
        Ok(statements) => statements,
        Err(error) => {
            return Err(SqlError::DataFusion {
                message: format!("SQL parse error: {error}"),
            });
        }
    };

    // A BTreeSet gives both deduplication and a stable, sorted order.
    let mut unique = BTreeSet::new();
    // The visitor never breaks, so its ControlFlow result carries no information.
    let _ = visit_relations(&statements, |relation| {
        unique.insert(relation.to_string());
        ControlFlow::<()>::Continue(())
    });
    Ok(unique.into_iter().collect())
}
pub fn pretty_batches(batches: &[RecordBatch]) -> SqlResult<String> {
Ok(pretty_format_batches(batches)
.map_err(|error| SqlError::DataFusion {
message: error.to_string(),
})?
.to_string())
}
#[cfg(test)]
mod sql_tests;