1#![forbid(unsafe_code)]
2
3use std::collections::{BTreeSet, HashMap, VecDeque};
9use std::fmt;
10use std::num::NonZeroUsize;
11use std::ops::ControlFlow;
12use std::path::Path;
13use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
14use std::sync::{Arc, Mutex, RwLock};
15
16use arrow::datatypes::SchemaRef;
17use arrow::record_batch::RecordBatch;
18use arrow::util::pretty::pretty_format_batches;
19use catalog::{InMemoryCatalog, datafusion_bridge::DataFusionCatalogBridge};
20use datafusion::dataframe::DataFrame as DataFusionDataFrame;
21use datafusion::prelude::{ParquetReadOptions, SessionContext};
22use datafusion::sql::sqlparser::{ast::visit_relations, dialect::GenericDialect, parser::Parser};
23use object_store::aws::AmazonS3Builder;
24
25use krishiv_plan::optimizer::{CostModel, Optimizer};
26use krishiv_plan::{ExecutionKind, LogicalPlan, PlanNode};
27
28pub mod analyze;
29pub mod catalog;
30pub mod cep_sql;
31
32pub mod connector_table;
33pub mod create_function_ddl;
34pub mod grammar;
35pub mod incremental_view;
36pub mod introspection_sql;
37
38pub mod kafka_table;
39pub mod lakehouse;
40pub mod live_table;
41pub mod pipeline_ddl;
42pub mod pivot_sql;
43pub mod recursive_cte;
44pub mod spark_sql_ext;
46pub mod sqlstate;
47pub mod subquery;
48pub mod unnest_sql;
49
50pub mod streaming;
51pub mod streaming_tvf;
52pub mod streaming_window_plan;
53mod udf;
54mod window_functions;
55
56pub use cep_sql::{
57 MatchRecognizeStatement, execute_streaming_match_recognize, parse_match_recognize,
58};
59pub use lakehouse::{AsOfTableRef, MergeResult, MergeTargetUnsupportedError, preprocess_as_of_sql};
60
61pub use grammar::{
62 FeatureEntry, FeatureStatus, feature_matrix, features_by_status, features_for_category,
63};
64pub use sqlstate::{SqlStateError, sqlstate_for};
65pub use streaming::{ContinuousInputError, ContinuousTableInput};
66
/// Result alias used throughout the SQL layer; the error type is always [`SqlError`].
pub type SqlResult<T> = Result<T, SqlError>;

/// Boxed, pinned, `Send` stream of record batches in which every item may fail with a
/// [`SqlError`].
pub type SqlStream =
    std::pin::Pin<Box<dyn futures::stream::Stream<Item = Result<RecordBatch, SqlError>> + Send>>;
77
/// Process-wide counter that makes every ephemeral table name unique.
static EPHEMERAL_TABLE_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Returns a fresh name of the form `__{prefix}_{n}`, where `n` is taken from a
/// process-wide monotonically increasing counter.
fn next_ephemeral_name(prefix: &str) -> String {
    let sequence = EPHEMERAL_TABLE_COUNTER.fetch_add(1, Ordering::Relaxed);
    let mut name = String::from("__");
    name.push_str(prefix);
    name.push('_');
    name.push_str(&sequence.to_string());
    name
}
86
/// Whether `SqlEngine::build_local` should install the window helper UDFs on the new
/// session.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WindowFnRegistration {
    /// Install the window helper UDFs; engine construction fails if that fails.
    Register,
    /// Do not install them (used by the fallback paths of `new_with_memory_limit`).
    Skip,
}
100
/// Bounded cache of DataFusion logical plans keyed by string.
///
/// `order` mirrors the keys of `map` in insertion order (oldest first). When the cache
/// is full, inserting a new key evicts the oldest one. Lookups do NOT refresh an
/// entry's position (only re-insertion does), so eviction is FIFO-with-refresh rather
/// than a true LRU.
struct PlanCache {
    /// Cached plans. NOTE(review): keys are presumably SQL text — confirm at call sites.
    map: HashMap<String, datafusion::logical_expr::LogicalPlan>,
    /// Keys of `map`, oldest at the front.
    order: VecDeque<String>,
    /// Maximum number of plans retained.
    max: usize,
}
111
112impl PlanCache {
113 fn new(max: usize) -> Self {
114 Self {
115 map: HashMap::new(),
116 order: VecDeque::new(),
117 max,
118 }
119 }
120
121 fn get(&self, key: &str) -> Option<&datafusion::logical_expr::LogicalPlan> {
122 self.map.get(key)
123 }
124
125 fn insert(&mut self, key: String, plan: datafusion::logical_expr::LogicalPlan) {
126 if self.map.contains_key(&key) {
127 self.order.retain(|k| k != &key);
130 } else if self.map.len() >= self.max
131 && let Some(oldest) = self.order.pop_front()
132 {
133 self.map.remove(&oldest);
134 }
135 self.order.push_back(key.clone());
136 self.map.insert(key, plan);
137 }
138
139 fn clear(&mut self) {
140 self.map.clear();
141 self.order.clear();
142 }
143
144 #[cfg(test)]
145 fn is_empty(&self) -> bool {
146 self.map.is_empty()
147 }
148}
149
/// Optional overrides for reading Parquet files.
///
/// NOTE(review): `None` presumably means "use the engine default" — the consumer of
/// these options is not visible here.
#[derive(Debug, Clone, Default)]
pub struct ParquetReaderOptions {
    /// Rows per produced record batch.
    pub batch_size: Option<usize>,
}
156
/// Optional overrides for reading CSV files.
///
/// NOTE(review): `None` presumably means "use the reader default" — confirm at the
/// consumer.
#[derive(Debug, Clone, Default)]
pub struct CsvReaderOptions {
    /// Field delimiter character.
    pub delimiter: Option<char>,
    /// Whether the first line is a header row.
    pub has_header: Option<bool>,
}
165
/// Optional overrides for writing Parquet files.
///
/// NOTE(review): accepted compression names and the `None` behaviour are defined by the
/// consumer, which is not visible here.
#[derive(Debug, Clone, Default)]
pub struct ParquetWriterOptions {
    /// Compression codec name.
    pub compression: Option<String>,
    /// Upper bound on rows per row group.
    pub max_row_group_size: Option<usize>,
}
174
/// Optional overrides for writing CSV files.
///
/// NOTE(review): `None` presumably means "use the writer default" — confirm at the
/// consumer.
#[derive(Debug, Clone, Default)]
pub struct CsvWriterOptions {
    /// Field delimiter character.
    pub delimiter: Option<char>,
    /// Whether to emit a header row.
    pub has_header: Option<bool>,
}
183
/// Errors produced by the SQL layer.
///
/// Marked `#[non_exhaustive]` so new variants can be added without breaking matches
/// in downstream crates.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum SqlError {
    /// The query text was empty.
    #[error("SQL query is empty")]
    EmptyQuery,
    /// A table name was empty or whitespace-only.
    #[error("table name is empty")]
    EmptyTableName,
    /// The SQL used a feature this engine does not support.
    #[error("unsupported SQL feature: {feature}")]
    Unsupported { feature: String },
    /// A table function definition was rejected.
    #[error("invalid table function: {message}")]
    InvalidTableFunction { message: String },
    /// A DataFusion error (or another backend failure) carried as text.
    #[error("DataFusion error: {message}")]
    DataFusion { message: String },
    /// Error from the Krishiv plan optimizer, displayed unchanged.
    #[error(transparent)]
    Optimizer(#[from] krishiv_plan::optimizer::OptimizerError),
    /// The caller is not allowed to perform the operation.
    #[error("access denied: {reason}")]
    AccessDenied { reason: String },
    /// The operation with this id was cancelled.
    #[error("operation {operation_id} was cancelled")]
    OperationCancelled { operation_id: u64 },
    /// The query exceeded its time budget.
    #[error("query timed out after {timeout_ms} ms")]
    Timeout { timeout_ms: u64 },
}
216
217impl From<datafusion::error::DataFusionError> for SqlError {
218 fn from(value: datafusion::error::DataFusionError) -> Self {
219 Self::DataFusion {
220 message: value.to_string(),
221 }
222 }
223}
224
/// A SQL query paired with its Krishiv logical plan.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SqlPlan {
    /// Original SQL text.
    query: String,
    /// Krishiv logical plan derived from `query`.
    logical_plan: LogicalPlan,
}
231
232impl SqlPlan {
233 pub fn query(&self) -> &str {
235 &self.query
236 }
237
238 pub fn logical_plan(&self) -> &LogicalPlan {
240 &self.logical_plan
241 }
242}
243
/// Default capacity of the per-engine plan cache.
const PLAN_CACHE_MAX_ENTRIES: usize = 256;

/// Reads the plan-cache capacity from `KRISHIV_PLAN_CACHE_MAX_ENTRIES`.
///
/// Falls back to [`PLAN_CACHE_MAX_ENTRIES`] when the variable is unset, is not a
/// `usize`, or is zero.
fn resolve_plan_cache_max_entries() -> usize {
    let configured = std::env::var("KRISHIV_PLAN_CACHE_MAX_ENTRIES")
        .ok()
        .and_then(|raw| raw.parse::<usize>().ok());
    match configured {
        Some(entries) if entries > 0 => entries,
        _ => PLAN_CACHE_MAX_ENTRIES,
    }
}
/// Default row budget for streaming `MATCH_RECOGNIZE` evaluation.
const STREAMING_CEP_MAX_ROWS_DEFAULT: usize = 100_000;

/// Parses a streaming `MATCH_RECOGNIZE` row limit.
///
/// Surrounding whitespace is ignored, consistent with
/// `resolve_query_memory_limit_bytes`. A missing, unparsable or zero value yields
/// [`STREAMING_CEP_MAX_ROWS_DEFAULT`].
pub fn resolve_streaming_match_recognize_limit(raw: Option<&str>) -> usize {
    raw.and_then(|s| s.trim().parse::<usize>().ok())
        .filter(|n| *n > 0)
        .unwrap_or(STREAMING_CEP_MAX_ROWS_DEFAULT)
}
274
275pub fn streaming_match_recognize_limit_from_env() -> usize {
278 resolve_streaming_match_recognize_limit(
279 std::env::var("KRISHIV_MATCH_RECOGNIZE_STREAMING_LIMIT")
280 .ok()
281 .as_deref(),
282 )
283}
284
/// Parses a per-query memory cap in bytes.
///
/// Whitespace is trimmed. Returns `None` (no cap) when the input is missing, not a
/// `usize`, or zero.
pub fn resolve_query_memory_limit_bytes(raw: Option<&str>) -> Option<usize> {
    let bytes: usize = raw?.trim().parse().ok()?;
    (bytes > 0).then_some(bytes)
}
292
293pub fn query_memory_limit_from_env() -> Option<usize> {
296 resolve_query_memory_limit_bytes(
297 std::env::var("KRISHIV_QUERY_MEMORY_LIMIT_BYTES")
298 .ok()
299 .as_deref(),
300 )
301}
302
/// Execution batch size read from `KRISHIV_BATCH_SIZE`.
///
/// Falls back to 8192 when the variable is unset, is not a `usize`, or is zero.
pub fn batch_size_from_env() -> usize {
    const DEFAULT_BATCH_SIZE: usize = 8192;
    if let Ok(raw) = std::env::var("KRISHIV_BATCH_SIZE") {
        if let Ok(size) = raw.parse::<usize>() {
            if size > 0 {
                return size;
            }
        }
    }
    DEFAULT_BATCH_SIZE
}
313
/// Target parallelism read from `KRISHIV_TARGET_PARALLELISM`.
///
/// When the variable is unset, unparsable, or zero, this uses the machine's available
/// parallelism, or 1 if that cannot be determined.
pub fn default_parallelism_from_env() -> NonZeroUsize {
    let configured = std::env::var("KRISHIV_TARGET_PARALLELISM")
        .ok()
        .and_then(|raw| raw.parse::<usize>().ok())
        .and_then(NonZeroUsize::new);
    match configured {
        Some(parallelism) => parallelism,
        None => std::thread::available_parallelism().unwrap_or(NonZeroUsize::MIN),
    }
}
324
325fn build_single_node_session_config(
333 target_partitions: NonZeroUsize,
334) -> datafusion::prelude::SessionConfig {
335 let tp = target_partitions.get();
336 let batch_size = batch_size_from_env();
337 let mut config = datafusion::prelude::SessionConfig::new()
338 .with_target_partitions(tp)
339 .with_batch_size(batch_size)
340 .set_bool(
341 "datafusion.optimizer.enable_round_robin_repartition",
342 tp > 1,
343 );
344 config.options_mut().execution.parquet.pushdown_filters = true;
345 config.options_mut().execution.parquet.enable_page_index = true;
346 config
347}
348
/// SQL engine backed by a DataFusion [`SessionContext`].
///
/// `Clone` is derived. Most mutable state (the streaming-source set, plan cache, row
/// counts, registries, version counters) lives behind `Arc`s, so clones share it.
#[derive(Clone)]
pub struct SqlEngine {
    /// DataFusion session that plans and executes queries.
    context: SessionContext,
    /// Partition count the session was built with. `with_target_parallelism` only
    /// updates this field; it does not reconfigure `context`.
    target_parallelism: NonZeroUsize,
    /// Optional in-memory catalog, registered in `context` as `krishiv` when present.
    krishiv_catalog: Option<Arc<RwLock<InMemoryCatalog>>>,
    /// Optional UDF registry mirrored into `context` by the `sync_*_udfs` methods.
    udf_registry: Option<std::sync::Arc<std::sync::RwLock<krishiv_plan::udf::UdfRegistry>>>,
    /// Table names that `is_streaming_query` treats as streaming sources.
    streaming_sources: Arc<RwLock<std::collections::HashSet<String>>>,
    /// Serializes streaming-table registration (validate + register) across callers.
    streaming_registration: Arc<Mutex<()>>,
    /// Fast-path flag: `false` lets `is_streaming_query` skip parsing entirely.
    has_streaming_sources: Arc<AtomicBool>,
    /// Resource limits applied by `sync_scalar_udfs` (default limits when `None`).
    udf_limits: Option<krishiv_plan::udf::ResourceLimits>,
    /// Incremented by `bump_udf_version` whenever the UDF registry changes.
    udf_registry_version: Arc<AtomicU64>,
    /// Last registry version synced into `context`. Initialised to `u64::MAX` —
    /// NOTE(review): presumably a "never synced" sentinel; confirm with the sync logic.
    udf_last_synced_version: Arc<AtomicU64>,
    /// Cached DataFusion logical plans; cleared by `invalidate_plan_cache`.
    plan_cache: Arc<Mutex<PlanCache>>,
    /// Optional shuffle-partition override.
    shuffle_partitions: Arc<std::sync::RwLock<Option<u32>>>,
    /// Known row counts per table name (e.g. filled by `register_parquet` from provider
    /// statistics).
    table_row_counts: Arc<std::sync::RwLock<HashMap<String, u64>>>,
    /// Memory-pool cap requested at construction; `None` means unbounded.
    memory_limit_bytes: Option<usize>,
    /// Iceberg catalogs registered via `with_iceberg_catalog`, with their catalog names.
    #[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
    iceberg_catalogs: Arc<std::sync::RwLock<Vec<(Arc<catalog::unified::KrishivCatalog>, String)>>>,
    /// Registry of live tables.
    live_table_registry: Arc<live_table::LiveTableRegistry>,
    /// Registry of incrementally maintained views.
    incremental_view_registry: Arc<incremental_view::IncrementalViewRegistry>,
    /// Registry of pipelines defined via pipeline DDL.
    pipeline_registry: Arc<pipeline_ddl::PipelineRegistry>,
    /// Registry of running operations.
    operation_registry: Arc<OperationRegistry>,
}
408
409impl fmt::Debug for SqlEngine {
410 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
411 f.debug_struct("SqlEngine")
412 .field("backend", &"datafusion")
413 .finish_non_exhaustive()
414 }
415}
416
417impl Default for SqlEngine {
418 fn default() -> Self {
419 Self::new()
420 }
421}
422
423impl SqlEngine {
424 pub fn new() -> Self {
435 Self::new_with_memory_limit(query_memory_limit_from_env())
436 }
437
    /// Creates an engine with an optional memory cap, degrading instead of failing.
    ///
    /// Attempts, in order:
    /// 1. window helper UDFs + the requested memory limit;
    /// 2. no window UDFs + the requested memory limit (after a `warn!`);
    /// 3. no window UDFs + unbounded memory (after an `error!`);
    /// 4. `build_absolute_minimal`, if step 3 also fails.
    ///
    /// NOTE(review): `build_local` can fail either while building the memory-limited
    /// runtime or while registering window UDFs. The first warning always blames UDF
    /// registration, so a runtime failure is logged under the wrong cause before step 3.
    pub fn new_with_memory_limit(memory_limit_bytes: Option<usize>) -> Self {
        match Self::build_local(
            None,
            WindowFnRegistration::Register,
            NonZeroUsize::MIN,
            memory_limit_bytes,
        ) {
            Ok(engine) => engine,
            Err(err) => {
                tracing::warn!(
                    error = %err,
                    "SqlEngine::new: window helper UDF registration failed; \
                     window SQL functions will be unavailable, other queries are unaffected"
                );
                // Retry without window UDFs but still honouring the memory limit.
                Self::build_local(
                    None,
                    WindowFnRegistration::Skip,
                    NonZeroUsize::MIN,
                    memory_limit_bytes,
                )
                .unwrap_or_else(|err| {
                    tracing::error!(
                        error = %err,
                        "memory-limited DataFusion runtime construction failed; \
                         falling back to an unbounded engine"
                    );
                    // Last resorts: drop the memory limit, then the minimal builder.
                    Self::build_local(None, WindowFnRegistration::Skip, NonZeroUsize::MIN, None)
                        .unwrap_or_else(|_| Self::build_absolute_minimal(NonZeroUsize::MIN))
                })
            }
        }
    }
481
482 pub fn try_new() -> SqlResult<Self> {
487 Self::build_local(
488 None,
489 WindowFnRegistration::Register,
490 NonZeroUsize::MIN,
491 query_memory_limit_from_env(),
492 )
493 }
494
495 pub fn with_in_memory_catalog(catalog: Arc<RwLock<InMemoryCatalog>>) -> SqlResult<Self> {
497 if krishiv_common::profile_requires_fail_closed_metadata(
498 krishiv_common::resolve_durability_profile(),
499 ) {
500 return Err(SqlError::DataFusion {
501 message: String::from(
502 "InMemoryCatalog is dev-only; configure a durable REST or file-backed \
503 catalog for production deployments",
504 ),
505 });
506 }
507 Self::build_local(
508 Some(catalog),
509 WindowFnRegistration::Register,
510 NonZeroUsize::MIN,
511 query_memory_limit_from_env(),
512 )
513 }
514
515 #[must_use]
521 pub fn with_target_parallelism(mut self, n: NonZeroUsize) -> Self {
522 self.target_parallelism = n;
523 self
524 }
525
    /// The configured target parallelism.
    pub fn target_parallelism(&self) -> NonZeroUsize {
        self.target_parallelism
    }
530
    /// Memory-pool cap this engine was built with; `None` means unbounded.
    pub fn memory_limit_bytes(&self) -> Option<usize> {
        self.memory_limit_bytes
    }
535
536 pub fn shuffle_partitions(&self) -> Option<u32> {
538 *self
539 .shuffle_partitions
540 .read()
541 .unwrap_or_else(|e| e.into_inner())
542 }
543
    /// Shared handle to the per-table row counts.
    ///
    /// The returned `Arc` points at the engine's own map, so later updates are visible
    /// through it.
    pub fn table_row_counts(&self) -> Arc<std::sync::RwLock<HashMap<String, u64>>> {
        Arc::clone(&self.table_row_counts)
    }
552
553 pub fn registered_table_names(&self) -> Vec<String> {
559 let mut names = Vec::new();
560 for catalog_name in self.context.catalog_names() {
561 let Some(catalog) = self.context.catalog(&catalog_name) else {
562 continue;
563 };
564 for schema_name in catalog.schema_names() {
565 let Some(schema) = catalog.schema(&schema_name) else {
566 continue;
567 };
568 names.extend(schema.table_names());
569 }
570 }
571 names.sort();
572 names.dedup();
573 names
574 }
575
576 fn make_sql_df(&self, name: &str, dataframe: DataFusionDataFrame) -> SqlDataFrame {
579 SqlDataFrame::new(name, dataframe, self.table_row_counts())
580 .with_context(self.context.clone())
581 }
582
583 fn attach_query_metadata(&self, df: SqlDataFrame, query: &str) -> SqlDataFrame {
585 let kind = if self.is_streaming_query(query).unwrap_or(false) {
586 ExecutionKind::Streaming
587 } else {
588 ExecutionKind::Batch
589 };
590 df.with_query(query).with_execution_kind(kind)
591 }
592
593 #[must_use]
598 pub fn with_shuffle_partitions(self, n: Option<u32>) -> Self {
599 if let Ok(mut guard) = self.shuffle_partitions.write() {
600 *guard = n;
601 }
602 self
603 }
604
    /// Builds a single-node engine around a fresh DataFusion `SessionContext`.
    ///
    /// * `krishiv_catalog` — when `Some`, registered as the `krishiv` catalog through
    ///   `DataFusionCatalogBridge`.
    /// * `window_fn_registration` — whether to install the window helper UDFs.
    /// * `target_partitions` — passed to the session config and stored as
    ///   `target_parallelism`.
    /// * `memory_limit_bytes` — when `Some`, the runtime uses a `FairSpillPool` of that
    ///   many bytes; otherwise DataFusion's default runtime is used.
    ///
    /// # Errors
    ///
    /// `SqlError::DataFusion` if the memory-limited runtime cannot be built or the
    /// window helper UDFs fail to register.
    fn build_local(
        krishiv_catalog: Option<Arc<RwLock<InMemoryCatalog>>>,
        window_fn_registration: WindowFnRegistration,
        target_partitions: NonZeroUsize,
        memory_limit_bytes: Option<usize>,
    ) -> SqlResult<Self> {
        // Handed to the connector table factories and kept on the engine.
        // NOTE(review): presumably the factories record streaming tables here — confirm
        // in connector_table.
        let streaming_sources: Arc<RwLock<std::collections::HashSet<String>>> =
            Arc::new(RwLock::new(std::collections::HashSet::new()));

        // A throwaway default state is built only to obtain DataFusion's default table
        // factories, which are then extended with the connector factories.
        let dummy_state = datafusion::execution::session_state::SessionStateBuilder::new()
            .with_default_features()
            .build();
        let mut table_factories = dummy_state.table_factories().clone();
        crate::connector_table::register_connector_table_factories(
            &mut table_factories,
            streaming_sources.clone(),
        );
        let mut state_builder = datafusion::execution::session_state::SessionStateBuilder::new()
            .with_default_features()
            .with_config(build_single_node_session_config(target_partitions))
            .with_table_factories(table_factories);
        if let Some(limit) = memory_limit_bytes {
            let runtime_env = datafusion::execution::runtime_env::RuntimeEnvBuilder::new()
                .with_memory_pool(Arc::new(
                    datafusion::execution::memory_pool::FairSpillPool::new(limit),
                ))
                .build_arc()
                .map_err(|e| SqlError::DataFusion {
                    message: format!(
                        "failed to build memory-limited DataFusion runtime \
                         (limit {limit} bytes): {e}"
                    ),
                })?;
            state_builder = state_builder.with_runtime_env(runtime_env);
        }
        let state = state_builder.build();
        let context = SessionContext::new_with_state(state);
        if let Some(catalog) = &krishiv_catalog {
            context.register_catalog(
                "krishiv",
                Arc::new(DataFusionCatalogBridge::new(catalog.clone())),
            );
        }
        if matches!(window_fn_registration, WindowFnRegistration::Register) {
            window_functions::register_window_functions(&context).map_err(|e| {
                SqlError::DataFusion {
                    message: format!("failed to register window helper UDFs: {e}"),
                }
            })?;
        }
        Ok(Self {
            context,
            target_parallelism: target_partitions,
            krishiv_catalog,
            udf_registry: None,
            streaming_sources,
            streaming_registration: Arc::new(Mutex::new(())),
            has_streaming_sources: Arc::new(AtomicBool::new(false)),
            udf_limits: None,
            udf_registry_version: Arc::new(AtomicU64::new(0)),
            // NOTE(review): u64::MAX presumably means "never synced", so the first
            // version check triggers a sync — confirm with the sync logic.
            udf_last_synced_version: Arc::new(AtomicU64::new(u64::MAX)),
            plan_cache: Arc::new(Mutex::new(PlanCache::new(resolve_plan_cache_max_entries()))),
            shuffle_partitions: Arc::new(std::sync::RwLock::new(None)),
            table_row_counts: Arc::new(std::sync::RwLock::new(HashMap::new())),
            memory_limit_bytes,
            #[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
            iceberg_catalogs: Arc::new(std::sync::RwLock::new(Vec::new())),
            live_table_registry: Arc::new(live_table::LiveTableRegistry::new()),
            incremental_view_registry: Arc::new(incremental_view::IncrementalViewRegistry::new()),
            pipeline_registry: Arc::new(pipeline_ddl::PipelineRegistry::new()),
            operation_registry: Arc::new(OperationRegistry::new()),
        })
    }
694
    /// Infallible last-resort constructor used by `new_with_memory_limit`.
    ///
    /// Builds a session with default features, the connector table factories and the
    /// single-node config. It has no krishiv catalog, no window helper UDFs and no
    /// memory limit. NOTE(review): this appears to mirror
    /// `build_local(None, WindowFnRegistration::Skip, target_partitions, None)`; keep
    /// the two in sync when adding fields.
    fn build_absolute_minimal(target_partitions: NonZeroUsize) -> Self {
        let streaming_sources: Arc<RwLock<std::collections::HashSet<String>>> =
            Arc::new(RwLock::new(std::collections::HashSet::new()));
        // Throwaway state used only to obtain DataFusion's default table factories.
        let dummy_state = datafusion::execution::session_state::SessionStateBuilder::new()
            .with_default_features()
            .build();
        let mut table_factories = dummy_state.table_factories().clone();
        crate::connector_table::register_connector_table_factories(
            &mut table_factories,
            streaming_sources.clone(),
        );
        let state = datafusion::execution::session_state::SessionStateBuilder::new()
            .with_default_features()
            .with_config(build_single_node_session_config(target_partitions))
            .with_table_factories(table_factories)
            .build();
        let context = SessionContext::new_with_state(state);
        Self {
            context,
            target_parallelism: target_partitions,
            krishiv_catalog: None,
            udf_registry: None,
            streaming_sources,
            streaming_registration: Arc::new(Mutex::new(())),
            has_streaming_sources: Arc::new(AtomicBool::new(false)),
            udf_limits: None,
            udf_registry_version: Arc::new(AtomicU64::new(0)),
            udf_last_synced_version: Arc::new(AtomicU64::new(u64::MAX)),
            plan_cache: Arc::new(Mutex::new(PlanCache::new(resolve_plan_cache_max_entries()))),
            shuffle_partitions: Arc::new(std::sync::RwLock::new(None)),
            table_row_counts: Arc::new(std::sync::RwLock::new(HashMap::new())),
            memory_limit_bytes: None,
            #[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
            iceberg_catalogs: Arc::new(std::sync::RwLock::new(Vec::new())),
            live_table_registry: Arc::new(live_table::LiveTableRegistry::new()),
            incremental_view_registry: Arc::new(incremental_view::IncrementalViewRegistry::new()),
            pipeline_registry: Arc::new(pipeline_ddl::PipelineRegistry::new()),
            operation_registry: Arc::new(OperationRegistry::new()),
        }
    }
738
739 pub fn register_streaming_table(
750 &self,
751 name: &str,
752 schema: arrow::datatypes::SchemaRef,
753 ) -> SqlResult<Arc<ContinuousTableInput>> {
754 let _registration = self.lock_streaming_registration()?;
755 self.validate_new_streaming_table(name, &schema)?;
756 let (table, input) = crate::streaming::create_continuous_table(schema).map_err(|e| {
757 SqlError::DataFusion {
758 message: e.to_string(),
759 }
760 })?;
761 self.register_new_streaming_provider(name, table)?;
762 self.streaming_sources
763 .write()
764 .unwrap_or_else(|e| e.into_inner())
765 .insert(name.to_string());
766 self.has_streaming_sources.store(true, Ordering::Release);
767 self.invalidate_plan_cache();
768 Ok(input)
769 }
770
771 pub fn register_streaming_table_with_capacity(
776 &self,
777 name: &str,
778 schema: arrow::datatypes::SchemaRef,
779 capacity: usize,
780 ) -> SqlResult<Arc<ContinuousTableInput>> {
781 let _registration = self.lock_streaming_registration()?;
782 self.validate_new_streaming_table(name, &schema)?;
783 let (table, input) = crate::streaming::create_continuous_table_with_capacity(
784 schema, capacity,
785 )
786 .map_err(|e| SqlError::DataFusion {
787 message: e.to_string(),
788 })?;
789 self.register_new_streaming_provider(name, table)?;
790 self.streaming_sources
791 .write()
792 .unwrap_or_else(|e| e.into_inner())
793 .insert(name.to_string());
794 self.has_streaming_sources.store(true, Ordering::Release);
795 self.invalidate_plan_cache();
796 Ok(input)
797 }
798
799 fn lock_streaming_registration(&self) -> SqlResult<std::sync::MutexGuard<'_, ()>> {
800 self.streaming_registration
801 .lock()
802 .map_err(|error| SqlError::DataFusion {
803 message: format!("streaming table registration lock poisoned: {error}"),
804 })
805 }
806
807 fn validate_new_streaming_table(
808 &self,
809 name: &str,
810 schema: &arrow::datatypes::SchemaRef,
811 ) -> SqlResult<()> {
812 if name.trim().is_empty() {
813 return Err(SqlError::EmptyTableName);
814 }
815 if schema.fields().is_empty() {
816 return Err(SqlError::DataFusion {
817 message: "streaming table schema must contain at least one field".into(),
818 });
819 }
820 if self
821 .context
822 .table_exist(name)
823 .map_err(|error| SqlError::DataFusion {
824 message: error.to_string(),
825 })?
826 {
827 return Err(SqlError::DataFusion {
828 message: format!("table '{name}' is already registered"),
829 });
830 }
831 Ok(())
832 }
833
834 fn register_new_streaming_provider(
835 &self,
836 name: &str,
837 table: Arc<dyn datafusion::catalog::TableProvider>,
838 ) -> SqlResult<()> {
839 let previous =
840 self.context
841 .register_table(name, table)
842 .map_err(|error| SqlError::DataFusion {
843 message: error.to_string(),
844 })?;
845 if let Some(previous) = previous {
846 self.context
847 .register_table(name, previous)
848 .map_err(|error| SqlError::DataFusion {
849 message: format!(
850 "table '{name}' was concurrently registered and could not be restored: \
851 {error}"
852 ),
853 })?;
854 return Err(SqlError::DataFusion {
855 message: format!("table '{name}' was concurrently registered"),
856 });
857 }
858 Ok(())
859 }
860
861 pub fn register_kafka_source(
875 &self,
876 table_name: impl AsRef<str>,
877 schema: arrow::datatypes::SchemaRef,
878 bootstrap_servers: impl Into<String>,
879 topic: impl Into<String>,
880 group_id: impl Into<String>,
881 ) -> SqlResult<()> {
882 let table_name = table_name.as_ref();
883 if table_name.trim().is_empty() {
884 return Err(SqlError::EmptyTableName);
885 }
886 let config = krishiv_connectors::kafka::KafkaConfig {
887 bootstrap_servers: bootstrap_servers.into(),
888 topic: topic.into(),
889 group_id: group_id.into(),
890 auto_commit_interval_ms: {
891 let profile = krishiv_common::resolve_durability_profile();
892 if krishiv_common::requires_manual_kafka_commit(profile) {
893 None
894 } else {
895 Some(1_000)
896 }
897 },
898 security_protocol: None,
899 ssl_ca_location: None,
900 ssl_certificate_location: None,
901 ssl_key_location: None,
902 ssl_key_password: None,
903 sasl_username: None,
904 sasl_password: None,
905 sasl_mechanisms: None,
906 enable_idempotence: None,
907 transactional_id: None,
908 };
909 let table =
910 crate::kafka_table::create_kafka_streaming_table(schema, config).map_err(|e| {
911 SqlError::DataFusion {
912 message: e.to_string(),
913 }
914 })?;
915 if self
916 .context
917 .table_exist(table_name)
918 .map_err(SqlError::from)?
919 {
920 let _ = self
921 .context
922 .deregister_table(table_name)
923 .map_err(SqlError::from)?;
924 }
925 self.context
926 .register_table(table_name, table)
927 .map_err(|e| SqlError::DataFusion {
928 message: e.to_string(),
929 })?;
930 self.streaming_sources
931 .write()
932 .unwrap_or_else(|e| e.into_inner())
933 .insert(table_name.to_string());
934 self.has_streaming_sources.store(true, Ordering::Release);
935 self.invalidate_plan_cache();
936 Ok(())
937 }
938
    /// Runs `sql` and writes every non-empty result batch to `topic` through a
    /// `KafkaSink`, then flushes the sink.
    ///
    /// The sink uses the fixed group id `krishiv-sql-writer` and no auto-commit. It is
    /// created before the query runs. Returns the total number of rows handed to the
    /// sink.
    ///
    /// # Errors
    ///
    /// Returns early, without flushing, on the first sink-creation, query, stream or
    /// write error; sink errors are wrapped as `SqlError::DataFusion`.
    pub async fn sql_to_kafka(
        &self,
        sql: impl AsRef<str>,
        bootstrap_servers: impl Into<String>,
        topic: impl Into<String>,
    ) -> SqlResult<u64> {
        use futures::StreamExt;
        use krishiv_connectors::Sink as _;
        use krishiv_connectors::kafka::{KafkaConfig, KafkaSink};

        let config = KafkaConfig {
            bootstrap_servers: bootstrap_servers.into(),
            topic: topic.into(),
            group_id: "krishiv-sql-writer".into(),
            auto_commit_interval_ms: None,
            security_protocol: None,
            ssl_ca_location: None,
            ssl_certificate_location: None,
            ssl_key_location: None,
            ssl_key_password: None,
            sasl_username: None,
            sasl_password: None,
            sasl_mechanisms: None,
            enable_idempotence: None,
            transactional_id: None,
        };
        let mut sink = KafkaSink::new(config).map_err(|e| SqlError::DataFusion {
            message: e.to_string(),
        })?;

        let df = self.sql(sql.as_ref()).await?;
        let mut stream = df.execute_stream().await?;
        let mut total_rows = 0u64;

        // Stream results batch by batch rather than collecting the whole result.
        while let Some(result) = stream.next().await {
            let batch = result.map_err(|e| SqlError::DataFusion {
                message: e.to_string(),
            })?;
            if batch.num_rows() > 0 {
                total_rows += batch.num_rows() as u64;
                sink.write_batch(batch)
                    .await
                    .map_err(|e| SqlError::DataFusion {
                        message: e.to_string(),
                    })?;
            }
        }
        sink.flush().await.map_err(|e| SqlError::DataFusion {
            message: e.to_string(),
        })?;
        Ok(total_rows)
    }
1000
1001 pub fn with_udf_limits(mut self, limits: krishiv_plan::udf::ResourceLimits) -> Self {
1005 self.udf_limits = Some(limits);
1006 self
1007 }
1008
1009 pub fn is_streaming_source(&self, table_name: &str) -> bool {
1011 self.streaming_sources
1012 .read()
1013 .unwrap_or_else(|e| e.into_inner())
1014 .contains(table_name)
1015 }
1016
1017 pub fn register_streaming_source_name(&self, table_name: impl Into<String>) -> SqlResult<()> {
1026 let name: String = table_name.into();
1027 if name.trim().is_empty() {
1028 return Err(SqlError::EmptyTableName);
1029 }
1030 self.streaming_sources
1031 .write()
1032 .unwrap_or_else(|e| e.into_inner())
1033 .insert(name);
1034 self.has_streaming_sources.store(true, Ordering::Release);
1035 self.invalidate_plan_cache();
1036 Ok(())
1037 }
1038
1039 pub fn deregister_streaming_source(&self, name: &str) -> SqlResult<()> {
1045 if name.trim().is_empty() {
1046 return Err(SqlError::EmptyTableName);
1047 }
1048 let _ = self
1050 .context
1051 .deregister_table(name)
1052 .map_err(SqlError::from)?;
1053 {
1054 let mut sources = self
1055 .streaming_sources
1056 .write()
1057 .unwrap_or_else(|e| e.into_inner());
1058 sources.remove(name);
1059 if sources.is_empty() {
1060 self.has_streaming_sources.store(false, Ordering::Release);
1061 }
1062 self.invalidate_plan_cache();
1066 }
1067 Ok(())
1068 }
1069
    /// Registry of live tables, shared by all clones of this engine.
    pub fn live_table_registry(&self) -> &Arc<live_table::LiveTableRegistry> {
        &self.live_table_registry
    }
1074
    /// Registry of incremental views, shared by all clones of this engine.
    pub fn incremental_view_registry(&self) -> &Arc<incremental_view::IncrementalViewRegistry> {
        &self.incremental_view_registry
    }
1079
    /// Registry of DDL-defined pipelines, shared by all clones of this engine.
    pub fn pipeline_registry(&self) -> &Arc<pipeline_ddl::PipelineRegistry> {
        &self.pipeline_registry
    }
1084
    /// Registry of running operations, shared by all clones of this engine.
    pub fn operation_registry(&self) -> &Arc<OperationRegistry> {
        &self.operation_registry
    }
1089
1090 pub fn deregister_table(&self, name: &str) -> SqlResult<()> {
1094 if name.trim().is_empty() {
1095 return Err(SqlError::EmptyTableName);
1096 }
1097 let _ = self
1098 .context
1099 .deregister_table(name)
1100 .map_err(SqlError::from)?;
1101 self.invalidate_plan_cache();
1102 Ok(())
1103 }
1104
1105 pub fn register_table_udf_fn(
1129 &self,
1130 name: impl Into<String>,
1131 schema: arrow::datatypes::Schema,
1132 f: impl Fn(
1133 &[krishiv_plan::udf::ScalarValue],
1134 ) -> Result<arrow::record_batch::RecordBatch, krishiv_plan::udf::UdfError>
1135 + Send
1136 + Sync
1137 + 'static,
1138 ) -> SqlResult<()> {
1139 let udf =
1140 create_function_ddl::ClosureTableUdf::try_new(name, schema, std::sync::Arc::new(f))
1141 .map_err(|error| SqlError::InvalidTableFunction {
1142 message: error.to_string(),
1143 })?;
1144 if let Some(registry) = &self.udf_registry {
1145 let mut guard = registry.write().map_err(|e| SqlError::DataFusion {
1146 message: e.to_string(),
1147 })?;
1148 guard.register_table(std::sync::Arc::new(udf.clone()));
1149 }
1150 udf::register_single_table_udf(&self.context, std::sync::Arc::new(udf))
1151 .map_err(SqlError::from)
1152 }
1153
1154 pub fn is_streaming_query(&self, sql: &str) -> SqlResult<bool> {
1156 if !self.has_streaming_sources.load(Ordering::Acquire) {
1159 return Ok(false);
1160 }
1161 let sources = self
1162 .streaming_sources
1163 .read()
1164 .unwrap_or_else(|e| e.into_inner());
1165 if sources.is_empty() {
1166 return Ok(false);
1167 }
1168 let dialect = GenericDialect {};
1169 let statements = Parser::parse_sql(&dialect, sql).map_err(|e| SqlError::DataFusion {
1170 message: e.to_string(),
1171 })?;
1172 for stmt in &statements {
1173 let mut is_streaming = false;
1174 let _ = visit_relations(stmt, |relation| {
1175 let full = relation.to_string();
1178 let table_name = full.split('.').next_back().unwrap_or(&full);
1179 if sources.contains(table_name) {
1180 is_streaming = true;
1181 return ControlFlow::Break(());
1182 }
1183 ControlFlow::Continue(())
1184 });
1185 if is_streaming {
1186 return Ok(true);
1187 }
1188 }
1189 Ok(false)
1190 }
1191
    /// The in-memory catalog exposed as `krishiv`, if the engine was built with one.
    pub fn krishiv_catalog(&self) -> Option<&Arc<RwLock<InMemoryCatalog>>> {
        self.krishiv_catalog.as_ref()
    }
1196
1197 #[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
1206 #[must_use]
1207 pub fn with_iceberg_catalog(
1208 self,
1209 catalog: std::sync::Arc<catalog::unified::KrishivCatalog>,
1210 catalog_name: impl Into<String>,
1211 ) -> Self {
1212 let catalog_name = catalog_name.into();
1213 let bridge = catalog::iceberg_catalog_bridge::IcebergCatalogBridge::new(
1214 Arc::clone(&catalog),
1215 catalog_name.clone(),
1216 );
1217 self.context
1218 .register_catalog(catalog_name.clone(), Arc::new(bridge));
1219 self.iceberg_catalogs
1220 .write()
1221 .unwrap_or_else(|e| e.into_inner())
1222 .push((catalog, catalog_name));
1223 self
1224 }
1225
1226 #[must_use]
1228 pub fn with_udf_registry(
1229 mut self,
1230 registry: std::sync::Arc<std::sync::RwLock<krishiv_plan::udf::UdfRegistry>>,
1231 ) -> Self {
1232 self.udf_registry = Some(registry);
1233 self.bump_udf_version();
1235 self
1236 }
1237
    /// Increments the shared UDF-registry version counter (release ordering), marking
    /// the registry as changed since the last sync.
    pub(crate) fn bump_udf_version(&self) {
        self.udf_registry_version.fetch_add(1, Ordering::Release);
    }
1243
1244 fn invalidate_plan_cache(&self) {
1249 match self.plan_cache.lock() {
1250 Ok(mut cache) => cache.clear(),
1251 Err(poisoned) => poisoned.into_inner().clear(),
1252 }
1253 }
1254
    /// Public entry point for dropping all cached plans.
    pub fn clear_plan_cache(&self) {
        self.invalidate_plan_cache();
    }
1260
1261 pub async fn sync_scalar_udfs(&self) -> SqlResult<()> {
1264 let Some(registry) = &self.udf_registry else {
1265 return Ok(());
1266 };
1267 let guard = registry.read().map_err(|e| SqlError::DataFusion {
1268 message: e.to_string(),
1269 })?;
1270 let limits = self.udf_limits.clone().unwrap_or_default();
1271 udf::sync_scalar_udfs_with_limits(&self.context, &guard, limits).map_err(|e| {
1272 SqlError::DataFusion {
1273 message: e.to_string(),
1274 }
1275 })
1276 }
1277
1278 pub async fn sync_scalar_udfs_with_limits(
1283 &self,
1284 limits: krishiv_plan::udf::ResourceLimits,
1285 ) -> SqlResult<()> {
1286 self.sync_scalar_udfs_with_limits_for_profile(
1287 limits,
1288 krishiv_common::resolve_durability_profile(),
1289 )
1290 .await
1291 }
1292
1293 pub async fn sync_scalar_udfs_with_limits_for_profile(
1295 &self,
1296 limits: krishiv_plan::udf::ResourceLimits,
1297 profile: krishiv_common::DurabilityProfile,
1298 ) -> SqlResult<()> {
1299 self.sync_scalar_udfs_with_limits_for_policy(
1300 limits,
1301 krishiv_common::NativeScalarUdfPolicy::resolve(profile),
1302 )
1303 .await
1304 }
1305
1306 pub async fn sync_scalar_udfs_with_limits_for_policy(
1308 &self,
1309 limits: krishiv_plan::udf::ResourceLimits,
1310 policy: krishiv_common::NativeScalarUdfPolicy,
1311 ) -> SqlResult<()> {
1312 let Some(registry) = &self.udf_registry else {
1313 return Ok(());
1314 };
1315 let guard = registry.read().map_err(|e| SqlError::DataFusion {
1316 message: e.to_string(),
1317 })?;
1318 udf::sync_scalar_udfs_with_limits_for_policy(&self.context, &guard, limits, policy).map_err(
1319 |e| SqlError::DataFusion {
1320 message: e.to_string(),
1321 },
1322 )
1323 }
1324
1325 pub async fn sync_aggregate_udfs(&self) -> SqlResult<()> {
1327 let Some(registry) = &self.udf_registry else {
1328 return Ok(());
1329 };
1330 let guard = registry.read().map_err(|e| SqlError::DataFusion {
1331 message: e.to_string(),
1332 })?;
1333 udf::sync_aggregate_udfs(&self.context, &guard).map_err(|e| SqlError::DataFusion {
1334 message: e.to_string(),
1335 })
1336 }
1337
1338 pub async fn sync_table_udfs(&self) -> SqlResult<()> {
1340 let Some(registry) = &self.udf_registry else {
1341 return Ok(());
1342 };
1343 let guard = registry.read().map_err(|e| SqlError::DataFusion {
1344 message: e.to_string(),
1345 })?;
1346 udf::sync_table_udfs(&self.context, &guard).map_err(|e| SqlError::DataFusion {
1347 message: e.to_string(),
1348 })
1349 }
1350
1351 pub async fn sync_all_udfs(&self) -> SqlResult<()> {
1353 self.sync_scalar_udfs().await?;
1354 self.sync_aggregate_udfs().await?;
1355 self.sync_table_udfs().await?;
1356 Ok(())
1357 }
1358
1359 pub async fn register_parquet(
1361 &self,
1362 table_name: impl AsRef<str>,
1363 path: impl AsRef<Path>,
1364 ) -> SqlResult<()> {
1365 let table_name = table_name.as_ref();
1366 if table_name.trim().is_empty() {
1367 return Err(SqlError::EmptyTableName);
1368 }
1369
1370 let path = path.as_ref().to_string_lossy().into_owned();
1371
1372 if path.starts_with("s3://") {
1375 let url = url::Url::parse(&path).map_err(|e| SqlError::DataFusion {
1376 message: format!("invalid s3 url {path}: {e}"),
1377 })?;
1378 let bucket = url.host_str().unwrap_or_default();
1379 let store_url =
1380 url::Url::parse(&format!("s3://{bucket}")).map_err(|e| SqlError::DataFusion {
1381 message: format!("invalid s3 bucket url: {e}"),
1382 })?;
1383 let store = AmazonS3Builder::from_env()
1384 .with_bucket_name(bucket)
1385 .build()
1386 .map_err(|e| SqlError::DataFusion {
1387 message: format!("s3 store init: {e}"),
1388 })?;
1389 self.context
1390 .register_object_store(&store_url, Arc::new(store));
1391 }
1392
1393 if self
1394 .context
1395 .table_exist(table_name)
1396 .map_err(SqlError::from)?
1397 {
1398 let _ = self
1399 .context
1400 .deregister_table(table_name)
1401 .map_err(SqlError::from)?;
1402 }
1403 self.context
1404 .register_parquet(table_name, path, ParquetReadOptions::default())
1405 .await?;
1406 if let Ok(provider) = self.context.table_provider(table_name).await
1408 && let Some(stats) = provider.statistics()
1409 && let Some(n) = stats.num_rows.get_value()
1410 && let Ok(mut counts) = self.table_row_counts.write()
1411 {
1412 counts.insert(table_name.to_string(), *n as u64);
1413 }
1414 self.invalidate_plan_cache();
1415 Ok(())
1416 }
1417
1418 pub async fn read_parquet(&self, path: impl AsRef<Path>) -> SqlResult<SqlDataFrame> {
1420 let path = path.as_ref().to_string_lossy().into_owned();
1421 let dataframe = self
1422 .context
1423 .read_parquet(path, ParquetReadOptions::default())
1424 .await?;
1425 Ok(self.make_sql_df("parquet-read", dataframe))
1426 }
1427
1428 pub async fn register_record_batches(
1434 &self,
1435 table_name: impl AsRef<str>,
1436 batches: Vec<RecordBatch>,
1437 ) -> SqlResult<()> {
1438 use std::sync::Arc;
1439 let table_name = table_name.as_ref();
1440 if table_name.trim().is_empty() {
1441 return Err(SqlError::EmptyTableName);
1442 }
1443 if batches.is_empty() {
1444 return Ok(());
1445 }
1446 let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
1447 let schema = batches
1448 .first()
1449 .ok_or_else(|| SqlError::DataFusion {
1450 message: "empty batch list".into(),
1451 })?
1452 .schema();
1453 let mem_table =
1454 datafusion::datasource::MemTable::try_new(schema, vec![batches]).map_err(|e| {
1455 SqlError::DataFusion {
1456 message: e.to_string(),
1457 }
1458 })?;
1459 if self
1460 .context
1461 .table_exist(table_name)
1462 .map_err(SqlError::from)?
1463 {
1464 let _ = self
1465 .context
1466 .deregister_table(table_name)
1467 .map_err(SqlError::from)?;
1468 }
1469 self.context
1470 .register_table(table_name, Arc::new(mem_table))
1471 .map_err(|e| SqlError::DataFusion {
1472 message: e.to_string(),
1473 })?;
1474 if total_rows > 0
1475 && let Ok(mut counts) = self.table_row_counts.write()
1476 {
1477 counts.insert(table_name.to_string(), total_rows as u64);
1478 }
1479 self.invalidate_plan_cache();
1480 Ok(())
1481 }
1482
1483 pub async fn read_parquet_with_options(
1485 &self,
1486 path: impl AsRef<Path>,
1487 opts: &ParquetReaderOptions,
1488 ) -> SqlResult<SqlDataFrame> {
1489 let path = path.as_ref().to_string_lossy().into_owned();
1490 let mut options = datafusion::prelude::ParquetReadOptions::default();
1491 if opts.batch_size.is_some() {
1492 options = options.parquet_pruning(true);
1493 }
1494 let dataframe = self.context.read_parquet(path, options).await?;
1500 Ok(self.make_sql_df("parquet-read", dataframe))
1501 }
1502
1503 pub async fn read_csv(&self, path: impl AsRef<Path>) -> SqlResult<SqlDataFrame> {
1505 self.read_csv_with_options(path, &CsvReaderOptions::default())
1506 .await
1507 }
1508
1509 pub async fn read_csv_with_options(
1511 &self,
1512 path: impl AsRef<Path>,
1513 opts: &CsvReaderOptions,
1514 ) -> SqlResult<SqlDataFrame> {
1515 let path = path.as_ref().to_string_lossy().into_owned();
1516 let mut options = datafusion::prelude::CsvReadOptions::new();
1517 if let Some(delim) = opts.delimiter {
1518 options = options.delimiter(delim as u8);
1519 }
1520 if let Some(has_header) = opts.has_header {
1521 options = options.has_header(has_header);
1522 }
1523 let dataframe = self.context.read_csv(path, options).await?;
1524 Ok(self.make_sql_df("csv-read", dataframe))
1525 }
1526
1527 pub async fn read_json(&self, path: impl AsRef<Path>) -> SqlResult<SqlDataFrame> {
1529 let path = path.as_ref().to_string_lossy().into_owned();
1530 let dataframe = self
1531 .context
1532 .read_json(path, datafusion::prelude::JsonReadOptions::default())
1533 .await?;
1534 Ok(self.make_sql_df("json-read", dataframe))
1535 }
1536
1537 pub async fn read_delta(
1539 &self,
1540 path: impl AsRef<str>,
1541 version: Option<i64>,
1542 ) -> SqlResult<SqlDataFrame> {
1543 let path = path.as_ref();
1544 let base = path.replace(['/', '.', '-'], "_");
1545 let table = match version {
1546 Some(v) => format!("delta_{base}_v{v}"),
1547 None => format!("delta_{base}"),
1548 };
1549 lakehouse::register_delta_uri(&self.context, &table, path, version).await?;
1550 self.sql(format!("SELECT * FROM {table}")).await
1551 }
1552
1553 pub async fn read_hudi(
1555 &self,
1556 path: impl AsRef<str>,
1557 query_type: krishiv_connectors::lakehouse::HudiQueryType,
1558 begin_instant: Option<&str>,
1559 ) -> SqlResult<SqlDataFrame> {
1560 let path = path.as_ref();
1561 let table = format!("hudi_{}", path.replace(['/', '.', '-'], "_"));
1562 lakehouse::register_hudi_uri(&self.context, &table, path, query_type, begin_instant)
1563 .await?;
1564 self.sql(format!("SELECT * FROM {table}")).await
1565 }
1566
1567 pub async fn sql(&self, query: impl AsRef<str>) -> SqlResult<SqlDataFrame> {
1569 let query = query.as_ref();
1570 if query.trim().is_empty() {
1571 return Err(SqlError::EmptyQuery);
1572 }
1573
1574 {
1578 let current = self.udf_registry_version.load(Ordering::Acquire);
1579 let last = self.udf_last_synced_version.load(Ordering::Relaxed);
1580 if current != last {
1581 self.sync_all_udfs().await?;
1582 self.udf_last_synced_version
1583 .store(current, Ordering::Release);
1584 }
1585 }
1586
1587 if let Some(stmt) = introspection_sql::parse_introspection_statement(query)? {
1589 return match stmt {
1590 introspection_sql::IntrospectionStatement::Describe { table } => {
1591 let batch = introspection_sql::describe_table(&self.context, &table).await?;
1592 let describe_table_name = next_ephemeral_name("describe_result");
1593 lakehouse::register_scan_batches(
1594 &self.context,
1595 &describe_table_name,
1596 vec![batch],
1597 )
1598 .await?;
1599 let dataframe = self
1600 .context
1601 .sql(&format!("SELECT * FROM {describe_table_name}"))
1602 .await?;
1603 Ok(self.attach_query_metadata(self.make_sql_df("describe", dataframe), query))
1604 }
1605 introspection_sql::IntrospectionStatement::Explain { mode, query: inner } => {
1606 let text = introspection_sql::explain_query(&inner, mode)?;
1607 let batch = introspection_sql::explain_result_batch(&text)?;
1608 let explain_table = next_ephemeral_name("explain_result");
1609 lakehouse::register_scan_batches(&self.context, &explain_table, vec![batch])
1610 .await?;
1611 let dataframe = self
1612 .context
1613 .sql(&format!("SELECT * FROM {explain_table}"))
1614 .await?;
1615 Ok(self.attach_query_metadata(self.make_sql_df("explain", dataframe), query))
1616 }
1617 };
1618 }
1619
1620 if live_table::execute_live_table_ddl(&self.live_table_registry, query)?.is_some() {
1622 let empty = self.context.sql("SELECT 1 WHERE FALSE").await?;
1623 return Ok(self.attach_query_metadata(self.make_sql_df("live-table-ddl", empty), query));
1624 }
1625
1626 if incremental_view::execute_incremental_view_ddl(&self.incremental_view_registry, query)?
1628 .is_some()
1629 {
1630 let empty = self.context.sql("SELECT 1 WHERE FALSE").await?;
1631 return Ok(
1632 self.attach_query_metadata(self.make_sql_df("incremental-view-ddl", empty), query)
1633 );
1634 }
1635
1636 if pipeline_ddl::execute_pipeline_ddl(&self.pipeline_registry, query)?.is_some() {
1640 let empty = self.context.sql("SELECT 1 WHERE FALSE").await?;
1641 return Ok(self.attach_query_metadata(self.make_sql_df("pipeline-ddl", empty), query));
1642 }
1643
1644 let trimmed = query.trim();
1647 if trimmed
1648 .to_ascii_uppercase()
1649 .starts_with("SET SHUFFLE.PARTITIONS")
1650 {
1651 let value = trimmed.split('=').nth(1).map(|s| s.trim()).unwrap_or("");
1652 match value.parse::<u32>() {
1653 Ok(n) if n > 0 => {
1654 {
1655 let mut guard =
1656 self.shuffle_partitions
1657 .write()
1658 .map_err(|e| SqlError::DataFusion {
1659 message: e.to_string(),
1660 })?;
1661 *guard = Some(n);
1662 }
1663 let empty = self.context.sql("SELECT 1 WHERE FALSE").await?;
1664 return Ok(self.make_sql_df("set-shuffle-partitions", empty));
1665 }
1666 Ok(_) => {
1667 {
1668 let mut guard =
1669 self.shuffle_partitions
1670 .write()
1671 .map_err(|e| SqlError::DataFusion {
1672 message: e.to_string(),
1673 })?;
1674 *guard = None;
1675 }
1676 let empty = self.context.sql("SELECT 1 WHERE FALSE").await?;
1677 return Ok(self.make_sql_df("set-shuffle-partitions", empty));
1678 }
1679 Err(_) => {
1680 return Err(SqlError::DataFusion {
1681 message: format!(
1682 "invalid shuffle.partitions value '{value}'; expected a positive integer"
1683 ),
1684 });
1685 }
1686 }
1687 }
1688
1689 if create_function_ddl::is_create_function_returns_table(query) {
1694 let ddl = create_function_ddl::parse_create_function(query)
1695 .map_err(|message| SqlError::InvalidTableFunction { message })?;
1696 if ddl.language.as_deref() != Some("sql") {
1697 return Err(SqlError::Unsupported {
1698 feature: format!(
1699 "CREATE FUNCTION '{}' uses language {:?}; only LANGUAGE SQL AS '...' \
1700 table functions are executable",
1701 ddl.function_name, ddl.language
1702 ),
1703 });
1704 }
1705 let body = ddl
1706 .body
1707 .as_deref()
1708 .filter(|body| !body.trim().is_empty())
1709 .ok_or_else(|| SqlError::InvalidTableFunction {
1710 message: format!(
1711 "SQL table function '{}' requires a non-empty AS body",
1712 ddl.function_name
1713 ),
1714 })?;
1715 let fields: Vec<_> = ddl
1716 .return_columns
1717 .iter()
1718 .map(|column| {
1719 arrow::datatypes::Field::new(&column.name, column.data_type.clone(), true)
1720 })
1721 .collect();
1722 let schema = arrow::datatypes::Schema::new(fields);
1723 let udf: std::sync::Arc<dyn krishiv_plan::udf::TableUdf> = std::sync::Arc::new(
1724 create_function_ddl::SqlBodyTableUdf::try_new(
1725 &ddl.function_name,
1726 schema,
1727 body,
1728 ddl.arguments.len(),
1729 std::sync::Arc::new(self.context.clone()),
1730 )
1731 .map_err(|error| SqlError::InvalidTableFunction {
1732 message: error.to_string(),
1733 })?,
1734 );
1735 if let Some(registry) = &self.udf_registry {
1736 let mut guard = registry.write().map_err(|e| SqlError::DataFusion {
1737 message: e.to_string(),
1738 })?;
1739 guard.register_table(std::sync::Arc::clone(&udf));
1740 }
1741 udf::register_single_table_udf(&self.context, std::sync::Arc::clone(&udf))
1742 .map_err(SqlError::from)?;
1743 let empty = self.context.sql("SELECT 1 WHERE FALSE").await?;
1744 return Ok(
1745 self.attach_query_metadata(self.make_sql_df("create-function", empty), query)
1746 );
1747 }
1748
1749 if query
1750 .trim_start()
1751 .to_ascii_uppercase()
1752 .starts_with("MERGE INTO")
1753 {
1754 let batches = lakehouse::execute_merge_sql(&self.context, query).await?;
1755 let merge_table = next_ephemeral_name("merge_result");
1756 lakehouse::register_scan_batches(&self.context, &merge_table, batches).await?;
1757 let dataframe = self
1758 .context
1759 .sql(&format!("SELECT * FROM {merge_table}"))
1760 .await?;
1761 return Ok(self.attach_query_metadata(self.make_sql_df("merge", dataframe), query));
1762 }
1763
1764 #[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
1767 if trimmed.to_ascii_uppercase().starts_with("CALL SYSTEM.") {
1768 let result = self.dispatch_call_system(trimmed).await?;
1769 let call_table = next_ephemeral_name("call_result");
1770 lakehouse::register_scan_batches(&self.context, &call_table, vec![result]).await?;
1771 let dataframe = self
1772 .context
1773 .sql(&format!("SELECT * FROM {call_table}"))
1774 .await?;
1775 return Ok(self.attach_query_metadata(self.make_sql_df("call", dataframe), query));
1776 }
1777
1778 #[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
1782 if trimmed.to_ascii_uppercase().starts_with("DELETE FROM ") {
1783 if let Some((table_ref, predicate)) = parse_dml_delete(trimmed) {
1784 if let Some((iceberg_catalog, table_ident)) = self.resolve_iceberg_table(&table_ref)
1785 {
1786 use arrow::array::{ArrayRef, Int64Array};
1787 use arrow::datatypes::{DataType, Field, Schema};
1788 let (deleted, _) = krishiv_connectors::lakehouse::dml::iceberg_delete_where(
1789 iceberg_catalog,
1790 &table_ident,
1791 &predicate,
1792 &self.context,
1793 )
1794 .await
1795 .map_err(|e| SqlError::DataFusion {
1796 message: e.to_string(),
1797 })?;
1798 let schema = Arc::new(Schema::new(vec![Field::new(
1799 "deleted_rows",
1800 DataType::Int64,
1801 false,
1802 )]));
1803 let array: ArrayRef = Arc::new(Int64Array::from(vec![deleted as i64]));
1804 let batch = RecordBatch::try_new(schema, vec![array]).map_err(|e| {
1805 SqlError::DataFusion {
1806 message: e.to_string(),
1807 }
1808 })?;
1809 let res_table = next_ephemeral_name("delete_result");
1810 lakehouse::register_scan_batches(&self.context, &res_table, vec![batch])
1811 .await?;
1812 let dataframe = self
1813 .context
1814 .sql(&format!("SELECT * FROM {res_table}"))
1815 .await?;
1816 return Ok(
1817 self.attach_query_metadata(self.make_sql_df("delete", dataframe), query)
1818 );
1819 }
1820 }
1821 }
1822
1823 #[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
1825 if trimmed.to_ascii_uppercase().starts_with("UPDATE ") {
1826 if let Some(parsed) = parse_dml_update(trimmed) {
1827 if let Some((iceberg_catalog, table_ident)) =
1828 self.resolve_iceberg_table(&parsed.table_ref)
1829 {
1830 use arrow::array::{ArrayRef, Int64Array};
1831 use arrow::datatypes::{DataType, Field, Schema};
1832 let borrowed: Vec<(&str, &str)> = parsed
1833 .assignments
1834 .iter()
1835 .map(|(c, e)| (c.as_str(), e.as_str()))
1836 .collect();
1837 let pred = parsed.predicate.as_deref();
1838 let (updated, _) = krishiv_connectors::lakehouse::dml::iceberg_update_where(
1839 iceberg_catalog,
1840 &table_ident,
1841 &borrowed,
1842 pred,
1843 &self.context,
1844 )
1845 .await
1846 .map_err(|e| SqlError::DataFusion {
1847 message: e.to_string(),
1848 })?;
1849 let schema = Arc::new(Schema::new(vec![Field::new(
1850 "updated_rows",
1851 DataType::Int64,
1852 false,
1853 )]));
1854 let array: ArrayRef = Arc::new(Int64Array::from(vec![updated as i64]));
1855 let batch = RecordBatch::try_new(schema, vec![array]).map_err(|e| {
1856 SqlError::DataFusion {
1857 message: e.to_string(),
1858 }
1859 })?;
1860 let res_table = next_ephemeral_name("update_result");
1861 lakehouse::register_scan_batches(&self.context, &res_table, vec![batch])
1862 .await?;
1863 let dataframe = self
1864 .context
1865 .sql(&format!("SELECT * FROM {res_table}"))
1866 .await?;
1867 return Ok(
1868 self.attach_query_metadata(self.make_sql_df("update", dataframe), query)
1869 );
1870 }
1871 }
1872 }
1873
1874 if query.to_ascii_uppercase().contains(" MATCH_RECOGNIZE ")
1878 && let Some(stmt) = cep_sql::parse_match_recognize(query)?
1879 {
1880 let is_streaming = self.is_streaming_source(&stmt.source_table);
1881 let streaming_limit = streaming_match_recognize_limit_from_env();
1889 let source_sql = if is_streaming {
1890 format!(
1891 "SELECT * FROM {} LIMIT {}",
1892 stmt.source_table, streaming_limit
1893 )
1894 } else {
1895 format!("SELECT * FROM {}", stmt.source_table)
1896 };
1897 let source_df = self.context.sql(&source_sql).await?;
1898 let source_batches = source_df.collect().await?;
1899 if is_streaming {
1900 tracing::warn!(
1901 source = %stmt.source_table,
1902 limit = streaming_limit,
1903 collected_rows = source_batches.iter().map(|b| b.num_rows()).sum::<usize>(),
1904 "MATCH_RECOGNIZE executed against a streaming source under \
1905 bounded materialisation; results only cover the first {0} rows \
1906 of the source. Set KRISHIV_MATCH_RECOGNIZE_STREAMING_LIMIT to a \
1907 larger value if your executor has the memory budget.",
1908 streaming_limit
1909 );
1910 }
1911 let results = cep_sql::execute_match_recognize(stmt, &source_batches)?;
1912 let cep_table = next_ephemeral_name("cep_result");
1913 lakehouse::register_scan_batches(&self.context, &cep_table, results).await?;
1914 let dataframe = self
1915 .context
1916 .sql(&format!("SELECT * FROM {cep_table}"))
1917 .await?;
1918 return Ok(self.attach_query_metadata(self.make_sql_df("cep", dataframe), query));
1919 }
1920
1921 let query = &streaming_tvf::rewrite_window_tvfs(query);
1923
1924 let (rewritten, as_ofs) =
1925 lakehouse::preprocess_as_of_sql(query).unwrap_or_else(|_| (query.to_string(), vec![]));
1926 lakehouse::apply_as_of_refs(&self.context, &as_ofs).await?;
1927
1928 let can_cache = as_ofs.is_empty();
1935 let shuffle_override = self
1936 .shuffle_partitions
1937 .read()
1938 .map(|g| *g)
1939 .unwrap_or_else(|e| *e.into_inner());
1940 if can_cache {
1941 let cached_plan: Option<datafusion::logical_expr::LogicalPlan> = self
1943 .plan_cache
1944 .lock()
1945 .unwrap_or_else(|e| e.into_inner())
1946 .get(&rewritten)
1947 .cloned();
1948 if let Some(plan) = cached_plan {
1949 let dataframe = self.context.execute_logical_plan(plan).await?;
1950 return Ok(self.attach_query_metadata(
1951 self.make_sql_df("sql-query", dataframe)
1952 .with_shuffle_partitions(shuffle_override),
1953 &rewritten,
1954 ));
1955 }
1956 }
1957
1958 let dataframe = self.context.sql(&rewritten).await?;
1959
1960 if let Some(table_name) = extract_create_external_table_name(&rewritten)
1964 && !table_name.is_empty()
1965 && let Ok(provider) = self.context.table_provider(&table_name).await
1966 {
1967 let maybe_rows = provider
1968 .statistics()
1969 .and_then(|s| s.num_rows.get_value().copied());
1970 if let Some(n) = maybe_rows
1971 && let Ok(mut counts) = self.table_row_counts.write()
1972 {
1973 counts.entry(table_name).or_insert(n as u64);
1974 }
1975 }
1976
1977 if can_cache {
1979 let plan = dataframe.logical_plan().clone();
1980 match self.plan_cache.lock() {
1981 Ok(mut cache) => cache.insert(rewritten.clone(), plan),
1982 Err(poisoned) => poisoned.into_inner().insert(rewritten.clone(), plan),
1983 }
1984 }
1985
1986 Ok(self.attach_query_metadata(
1987 self.make_sql_df("sql-query", dataframe)
1988 .with_shuffle_partitions(shuffle_override),
1989 &rewritten,
1990 ))
1991 }
1992
1993 pub async fn execute_with_timeout(
2000 &self,
2001 query: impl AsRef<str> + Send,
2002 timeout_ms: u64,
2003 ) -> SqlResult<SqlDataFrame> {
2004 let timeout = std::time::Duration::from_millis(timeout_ms);
2005 tokio::time::timeout(timeout, self.sql(query))
2006 .await
2007 .map_err(|_| SqlError::Timeout { timeout_ms })?
2008 }
2009
2010 pub async fn execute_with_operation_id(
2017 &self,
2018 operation_id: u64,
2019 query: impl AsRef<str> + Send,
2020 cancelled_ids: &OperationRegistry,
2021 ) -> SqlResult<TaggedQueryResult> {
2022 if cancelled_ids.is_cancelled(operation_id) {
2023 return Err(SqlError::OperationCancelled { operation_id });
2024 }
2025 let df = self.sql(query).await?;
2026 Ok(TaggedQueryResult {
2027 operation_id,
2028 inner: df,
2029 })
2030 }
2031
2032 #[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
2038 fn resolve_iceberg_table(
2039 &self,
2040 table_ref: &str,
2041 ) -> Option<(Arc<dyn iceberg::Catalog + Send + Sync>, iceberg::TableIdent)> {
2042 let parts: Vec<&str> = table_ref.splitn(3, '.').collect();
2043 let (catalog_arc, ns_str, table_str) = {
2044 let guard = self
2045 .iceberg_catalogs
2046 .read()
2047 .unwrap_or_else(|e| e.into_inner());
2048 if guard.is_empty() {
2049 return None;
2050 }
2051 match parts.len() {
2052 2 => {
2053 let (cat, _) = guard.first()?;
2054 (Arc::clone(cat), *parts.first()?, *parts.get(1)?)
2055 }
2056 3 => {
2057 let cat_name = parts.first().copied()?;
2058 let (cat, _) = guard.iter().find(|(_, n)| n == cat_name)?;
2059 (Arc::clone(cat), *parts.get(1)?, *parts.get(2)?)
2060 }
2061 _ => return None,
2062 }
2063 };
2064 let ns = iceberg::NamespaceIdent::from_vec(vec![ns_str.to_string()]).ok()?;
2065 let ident = iceberg::TableIdent::new(ns, table_str.to_string());
2066 Some((catalog_arc.as_iceberg(), ident))
2067 }
2068
2069 #[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
2072 async fn dispatch_call_system(&self, stmt: &str) -> SqlResult<RecordBatch> {
2073 use arrow::array::{ArrayRef, Int64Array};
2074 use arrow::datatypes::{DataType, Field, Schema};
2075
2076 let upper = stmt.to_ascii_uppercase();
2077 const PREFIX: &str = "CALL SYSTEM.";
2078 let upper_after = &upper[PREFIX.len()..];
2079 let orig_after = &stmt[PREFIX.len()..];
2080
2081 let paren = upper_after.find('(').ok_or_else(|| SqlError::DataFusion {
2082 message: format!("CALL: missing '(' in: {stmt}"),
2083 })?;
2084 let proc_name = upper_after[..paren].trim();
2085
2086 let args_raw = orig_after[paren + 1..]
2087 .trim_end_matches(';')
2088 .trim()
2089 .trim_end_matches(')')
2090 .trim();
2091 let args = call_args_from_str(args_raw);
2092
2093 let iceberg_catalog = {
2094 let guard = self
2095 .iceberg_catalogs
2096 .read()
2097 .unwrap_or_else(|e| e.into_inner());
2098 guard
2099 .first()
2100 .ok_or_else(|| SqlError::DataFusion {
2101 message: "CALL system: no Iceberg catalog registered".to_string(),
2102 })?
2103 .0
2104 .as_iceberg()
2105 };
2106
2107 let table_ref = args.first().ok_or_else(|| SqlError::DataFusion {
2108 message: format!("CALL {proc_name}: table reference argument is required"),
2109 })?;
2110 let table_ident = iceberg_table_ident(table_ref)?;
2111
2112 let count: i64 = match proc_name {
2113 "EXPIRE_SNAPSHOTS" => {
2114 let dur_s = args.get(1).ok_or_else(|| SqlError::DataFusion {
2115 message: "CALL expire_snapshots: duration argument is required".to_string(),
2116 })?;
2117 let older_than = parse_call_duration(dur_s)?;
2118 let retain_last = args
2119 .get(2)
2120 .and_then(|s| s.parse::<usize>().ok())
2121 .unwrap_or(1);
2122 krishiv_connectors::lakehouse::maintenance::expire_snapshots(
2123 iceberg_catalog,
2124 &table_ident,
2125 older_than,
2126 retain_last,
2127 )
2128 .await
2129 .map_err(|e| SqlError::DataFusion {
2130 message: e.to_string(),
2131 })? as i64
2132 }
2133 "REMOVE_ORPHAN_FILES" => {
2134 let dur_s = args.get(1).ok_or_else(|| SqlError::DataFusion {
2135 message: "CALL remove_orphan_files: duration argument is required".to_string(),
2136 })?;
2137 let older_than = parse_call_duration(dur_s)?;
2138 krishiv_connectors::lakehouse::maintenance::remove_orphan_files(
2139 iceberg_catalog,
2140 &table_ident,
2141 older_than,
2142 )
2143 .await
2144 .map_err(|e| SqlError::DataFusion {
2145 message: e.to_string(),
2146 })? as i64
2147 }
2148 "COMPACT_DATA_FILES" => {
2149 let target_bytes = args
2150 .get(1)
2151 .and_then(|s| s.parse::<u64>().ok())
2152 .unwrap_or(128 * 1024 * 1024);
2153 krishiv_connectors::lakehouse::maintenance::compact_data_files(
2154 iceberg_catalog,
2155 &table_ident,
2156 target_bytes,
2157 )
2158 .await
2159 .map_err(|e| SqlError::DataFusion {
2160 message: e.to_string(),
2161 })? as i64
2162 }
2163 other => {
2164 return Err(SqlError::Unsupported {
2165 feature: format!("CALL system.{other}: unknown procedure"),
2166 });
2167 }
2168 };
2169
2170 let col = match proc_name {
2171 "EXPIRE_SNAPSHOTS" => "expired_snapshots",
2172 "REMOVE_ORPHAN_FILES" => "removed_files",
2173 "COMPACT_DATA_FILES" => "rewritten_files",
2174 _ => "result",
2175 };
2176 let schema = Arc::new(Schema::new(vec![Field::new(col, DataType::Int64, false)]));
2177 let array: ArrayRef = Arc::new(Int64Array::from(vec![count]));
2178 RecordBatch::try_new(schema, vec![array]).map_err(|e| SqlError::DataFusion {
2179 message: e.to_string(),
2180 })
2181 }
2182}
2183
/// A query result paired with the operation id it was run under, as returned
/// by `execute_with_operation_id`.
pub struct TaggedQueryResult {
    /// The caller-supplied id passed to `execute_with_operation_id`.
    pub operation_id: u64,
    /// The data frame returned by `sql` for the tagged query.
    pub inner: SqlDataFrame,
}
2191
/// Cloneable registry of cancelled operation ids and per-operation progress
/// counters.
///
/// Both maps sit behind `Arc`, so clones share state. Every method recovers
/// from a poisoned lock (via `PoisonError::into_inner`, as elsewhere in this
/// file) instead of silently skipping the update. Within this type the
/// collections are only touched by single insert/remove/lookup calls on plain
/// integers, so a poisoned guard still protects consistent data. Without this
/// recovery, one panic while a lock was held would permanently disable
/// cancellation.
#[derive(Clone, Default)]
pub struct OperationRegistry {
    /// Ids marked cancelled and not yet removed.
    cancelled: Arc<std::sync::RwLock<std::collections::HashSet<u64>>>,
    /// `operation_id -> (rows_scanned, rows_emitted)`.
    progress: Arc<std::sync::RwLock<std::collections::HashMap<u64, (u64, u64)>>>,
}

impl OperationRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Marks `operation_id` as cancelled.
    pub fn cancel(&self, operation_id: u64) {
        self.cancelled
            .write()
            .unwrap_or_else(|e| e.into_inner())
            .insert(operation_id);
    }

    /// Returns whether `operation_id` is currently marked cancelled.
    pub fn is_cancelled(&self, operation_id: u64) -> bool {
        self.cancelled
            .read()
            .unwrap_or_else(|e| e.into_inner())
            .contains(&operation_id)
    }

    /// Forgets `operation_id`: clears its cancellation mark and its progress.
    pub fn remove(&self, operation_id: u64) {
        self.cancelled
            .write()
            .unwrap_or_else(|e| e.into_inner())
            .remove(&operation_id);
        self.progress
            .write()
            .unwrap_or_else(|e| e.into_inner())
            .remove(&operation_id);
    }

    /// Records the latest `(rows_scanned, rows_emitted)` for `operation_id`,
    /// overwriting any earlier value.
    pub fn update_progress(&self, operation_id: u64, rows_scanned: u64, rows_emitted: u64) {
        self.progress
            .write()
            .unwrap_or_else(|e| e.into_inner())
            .insert(operation_id, (rows_scanned, rows_emitted));
    }

    /// Returns the last recorded `(rows_scanned, rows_emitted)` for
    /// `operation_id`, if any.
    pub fn progress(&self, operation_id: u64) -> Option<(u64, u64)> {
        self.progress
            .read()
            .unwrap_or_else(|e| e.into_inner())
            .get(&operation_id)
            .copied()
    }

    /// Returns all currently cancelled ids in ascending order.
    pub fn cancelled_ids(&self) -> Vec<u64> {
        let mut ids: Vec<u64> = self
            .cancelled
            .read()
            .unwrap_or_else(|e| e.into_inner())
            .iter()
            .copied()
            .collect();
        // Sorted so callers and tests see a deterministic order.
        ids.sort_unstable();
        ids
    }
}
2259
2260pub(crate) fn extract_create_external_table_name(query: &str) -> Option<String> {
2265 use datafusion::sql::parser::{DFParser, Statement as DFStatement};
2266 let mut stmts = DFParser::parse_sql(query).ok()?;
2267 match stmts.pop_front()? {
2268 DFStatement::CreateExternalTable(create) => Some(create.name.to_string()),
2269 _ => None,
2270 }
2271}
2272
/// Grouping strategy passed to `KrishivDataFrameOps::aggregate_grouping`.
///
/// The expressions are borrowed for `'a`. The variant names suggest SQL
/// `GROUPING SETS` / `CUBE` / `ROLLUP` semantics; confirm against the
/// `aggregate_grouping` implementations.
pub enum GroupingMode<'a> {
    /// Explicit grouping sets: each inner `Vec` is one set of grouping
    /// expressions.
    Sets(Vec<Vec<&'a krishiv_plan::expression::Expr>>),
    /// Cube over the listed expressions.
    Cube(Vec<&'a krishiv_plan::expression::Expr>),
    /// Rollup over the listed expressions (order presumably significant).
    Rollup(Vec<&'a krishiv_plan::expression::Expr>),
}
2285
/// Object-safe data-frame interface implemented by Krishiv frame types.
///
/// Transformations return a new boxed frame. The `&self` receivers suggest
/// frames are not modified in place, but confirm this per implementation.
/// All methods report failures as `SqlError` via `SqlResult`.
#[async_trait::async_trait]
pub trait KrishivDataFrameOps: Send + Sync {
    /// Returns the frame's result as a vector of record batches.
    async fn collect(&self) -> SqlResult<Vec<RecordBatch>>;
    /// Like `collect`, additionally returning `SqlExecutionStats` for the run.
    async fn collect_with_stats(&self) -> SqlResult<(Vec<RecordBatch>, SqlExecutionStats)>;
    /// Returns a textual plan explanation (asynchronous variant).
    async fn explain(&self) -> SqlResult<String>;
    /// Returns a textual rendering of the logical plan (synchronous, infallible).
    fn explain_logical(&self) -> String;
    /// Returns the frame's plan as a Krishiv `LogicalPlan`.
    fn krishiv_logical_plan(&self) -> LogicalPlan;
    /// Returns the SQL text associated with this frame, if any.
    fn query(&self) -> Option<&str>;
    /// Returns the results as a `SqlStream` of record batches.
    async fn execute_stream(&self) -> SqlResult<SqlStream>;

    /// Returns the output schema.
    fn schema(&self) -> SchemaRef;

    /// Projects the columns named in `columns`.
    async fn select(&self, columns: &[&str]) -> SqlResult<Box<dyn KrishivDataFrameOps>>;

    /// Projects the given Krishiv expressions.
    async fn select_exprs(
        &self,
        expressions: &[&krishiv_plan::expression::Expr],
    ) -> SqlResult<Box<dyn KrishivDataFrameOps>>;

    /// Groups by `group_exprs` and evaluates `aggregate_exprs` per group.
    async fn aggregate(
        &self,
        group_exprs: &[&krishiv_plan::expression::Expr],
        aggregate_exprs: &[&krishiv_plan::expression::Expr],
    ) -> SqlResult<Box<dyn KrishivDataFrameOps>>;

    /// Aggregates `aggregate_exprs` using a `GroupingMode` (sets / cube / rollup).
    async fn aggregate_grouping(
        &self,
        grouping: GroupingMode<'_>,
        aggregate_exprs: &[&krishiv_plan::expression::Expr],
    ) -> SqlResult<Box<dyn KrishivDataFrameOps>>;

    /// Pivots `pivot_column` per `group_exprs` group, aggregating with
    /// `aggregate_expr`. Each `values` entry pairs a pivot value with what is
    /// presumably its output column name; confirm in the implementations.
    async fn pivot(
        &self,
        group_exprs: &[&krishiv_plan::expression::Expr],
        pivot_column: &krishiv_plan::expression::Expr,
        aggregate_expr: &krishiv_plan::expression::Expr,
        values: &[(krishiv_plan::expression::ScalarValue, String)],
    ) -> SqlResult<Box<dyn KrishivDataFrameOps>>;

    /// Unpivots `columns` into `name_column` / `value_column` pairs.
    async fn unpivot(
        &self,
        columns: &[&str],
        name_column: &str,
        value_column: &str,
    ) -> SqlResult<Box<dyn KrishivDataFrameOps>>;

    /// Filters rows by a predicate given as SQL expression text.
    async fn filter(&self, predicate: &str) -> SqlResult<Box<dyn KrishivDataFrameOps>>;

    /// Filters rows by a Krishiv expression.
    async fn filter_expr(
        &self,
        predicate: &krishiv_plan::expression::Expr,
    ) -> SqlResult<Box<dyn KrishivDataFrameOps>>;

    /// Keeps at most `n` rows.
    async fn limit(&self, n: usize) -> SqlResult<Box<dyn KrishivDataFrameOps>>;

    /// Removes duplicate rows.
    async fn distinct(&self) -> SqlResult<Box<dyn KrishivDataFrameOps>>;

    /// Drops rows containing nulls in `columns`. Behaviour for an empty slice
    /// is implementation-defined; TODO confirm.
    async fn drop_nulls(&self, columns: &[&str]) -> SqlResult<Box<dyn KrishivDataFrameOps>>;

    /// Samples rows with the given `fraction` (presumably in `[0, 1]`).
    async fn sample(&self, fraction: f64) -> SqlResult<Box<dyn KrishivDataFrameOps>>;

    /// Sorts by `columns`; `descending[i]` presumably selects the direction
    /// for `columns[i]`.
    async fn sort(
        &self,
        columns: &[&str],
        descending: &[bool],
    ) -> SqlResult<Box<dyn KrishivDataFrameOps>>;

    /// Re-labels the frame under the relation alias `alias`.
    async fn alias(&self, alias: &str) -> SqlResult<Box<dyn KrishivDataFrameOps>>;

    /// Removes the named `columns`.
    async fn drop_columns(&self, columns: &[&str]) -> SqlResult<Box<dyn KrishivDataFrameOps>>;

    /// Renames column `old` to `new`.
    async fn rename_column(&self, old: &str, new: &str) -> SqlResult<Box<dyn KrishivDataFrameOps>>;

    /// Adds (or presumably replaces) column `name`, computed from the SQL
    /// expression text `expr`.
    async fn with_column(&self, name: &str, expr: &str) -> SqlResult<Box<dyn KrishivDataFrameOps>>;

    /// Upcast used to downcast to the concrete frame type.
    fn as_any(&self) -> &dyn std::any::Any;

    /// Returns a frame of summary statistics for this frame.
    async fn describe(&self) -> SqlResult<Box<dyn KrishivDataFrameOps>>;

    /// Replaces nulls in `column` with `value` (given as text).
    async fn fill_null(&self, column: &str, value: &str)
        -> SqlResult<Box<dyn KrishivDataFrameOps>>;

    /// Joins with `right`. `how` names the join type as text, and
    /// `left_on[i]` is presumably matched against `right_on[i]`.
    async fn join(
        &self,
        right: &dyn KrishivDataFrameOps,
        how: &str,
        left_on: &[&str],
        right_on: &[&str],
    ) -> SqlResult<Box<dyn KrishivDataFrameOps>>;

    /// Set union with `right`. Contrast with `union_distinct`; presumably
    /// keeps duplicates.
    async fn union(
        &self,
        right: &dyn KrishivDataFrameOps,
    ) -> SqlResult<Box<dyn KrishivDataFrameOps>>;

    /// Set union with `right`, removing duplicates.
    async fn union_distinct(
        &self,
        right: &dyn KrishivDataFrameOps,
    ) -> SqlResult<Box<dyn KrishivDataFrameOps>>;

    /// Set intersection with `right`; `distinct` selects duplicate handling.
    async fn intersect(
        &self,
        right: &dyn KrishivDataFrameOps,
        distinct: bool,
    ) -> SqlResult<Box<dyn KrishivDataFrameOps>>;

    /// Set difference (`self` minus `right`); `distinct` selects duplicate
    /// handling.
    async fn except(
        &self,
        right: &dyn KrishivDataFrameOps,
        distinct: bool,
    ) -> SqlResult<Box<dyn KrishivDataFrameOps>>;

    /// Registers `batches` as table `name` in the frame's session.
    async fn register_batches(&self, name: &str, batches: Vec<RecordBatch>) -> SqlResult<()>;

    /// Removes table `name` from the frame's session.
    async fn deregister_table(&self, name: &str) -> SqlResult<()>;

    /// Registers this frame as view `name`; `replace` presumably allows
    /// overwriting an existing view.
    async fn create_view(&self, name: &str, replace: bool) -> SqlResult<()>;
}
2441
2442fn df_plan_to_krishiv_nodes(
2450 plan: &datafusion::logical_expr::LogicalPlan,
2451 table_row_counts: &std::collections::HashMap<String, u64>,
2452 counter: &mut usize,
2453) -> (Vec<krishiv_plan::PlanNode>, String) {
2454 use datafusion::logical_expr::LogicalPlan as DfPlan;
2455 use krishiv_plan::{ExecutionKind, NodeOp, PlanNode};
2456
2457 *counter += 1;
2458 let idx = *counter;
2459
2460 match plan {
2461 DfPlan::TableScan(ts) => {
2462 let table_name = ts.table_name.table().to_string();
2463 let row_count = table_row_counts.get(&table_name).copied();
2464 let filters: Vec<String> = ts.filters.iter().map(|e| e.to_string()).collect();
2465 let id = format!("scan-{idx}");
2466 let node = PlanNode::new(&id, format!("Scan {table_name}"), ExecutionKind::Batch)
2467 .with_op(NodeOp::Scan {
2468 table: table_name,
2469 filters,
2470 })
2471 .with_estimated_rows(row_count);
2472 (vec![node], id)
2473 }
2474
2475 DfPlan::Projection(proj) => {
2476 let (mut nodes, input_id) =
2477 df_plan_to_krishiv_nodes(&proj.input, table_row_counts, counter);
2478 let id = format!("proj-{idx}");
2479 let columns: Vec<String> = proj.expr.iter().map(|e| e.to_string()).collect();
2480 nodes.push(
2481 PlanNode::new(&id, "Projection", ExecutionKind::Batch)
2482 .with_op(NodeOp::Project { columns })
2483 .with_inputs([input_id]),
2484 );
2485 (nodes, id)
2486 }
2487
2488 DfPlan::Filter(filter) => {
2489 let (mut nodes, input_id) =
2490 df_plan_to_krishiv_nodes(&filter.input, table_row_counts, counter);
2491 let id = format!("filter-{idx}");
2492 let predicate = filter.predicate.to_string();
2493 nodes.push(
2494 PlanNode::new(&id, "Filter", ExecutionKind::Batch)
2495 .with_op(NodeOp::Filter { predicate })
2496 .with_inputs([input_id]),
2497 );
2498 (nodes, id)
2499 }
2500
2501 DfPlan::Aggregate(agg) => {
2502 let (mut nodes, input_id) =
2503 df_plan_to_krishiv_nodes(&agg.input, table_row_counts, counter);
2504 let id = format!("agg-{idx}");
2505 let group_keys: Vec<String> = agg.group_expr.iter().map(|e| e.to_string()).collect();
2506 nodes.push(
2507 PlanNode::new(&id, "Aggregate", ExecutionKind::Batch)
2508 .with_op(NodeOp::Aggregate { group_keys })
2509 .with_inputs([input_id]),
2510 );
2511 (nodes, id)
2512 }
2513
2514 DfPlan::Join(join) => {
2515 let (mut nodes, left_id) =
2516 df_plan_to_krishiv_nodes(&join.left, table_row_counts, counter);
2517 let (right_nodes, right_id) =
2518 df_plan_to_krishiv_nodes(&join.right, table_row_counts, counter);
2519 nodes.extend(right_nodes);
2520 let id = format!("join-{idx}");
2521 let krishiv_join_type = match join.join_type {
2526 datafusion::common::JoinType::Inner => krishiv_plan::JoinType::Inner,
2527 datafusion::common::JoinType::Left => krishiv_plan::JoinType::Left,
2528 datafusion::common::JoinType::Right => krishiv_plan::JoinType::Right,
2529 datafusion::common::JoinType::Full => krishiv_plan::JoinType::Full,
2530 datafusion::common::JoinType::LeftSemi => krishiv_plan::JoinType::LeftSemi,
2531 datafusion::common::JoinType::RightSemi => krishiv_plan::JoinType::RightSemi,
2532 datafusion::common::JoinType::LeftAnti => krishiv_plan::JoinType::LeftAnti,
2533 datafusion::common::JoinType::RightAnti => krishiv_plan::JoinType::RightAnti,
2534 datafusion::common::JoinType::LeftMark => krishiv_plan::JoinType::LeftSemi,
2538 datafusion::common::JoinType::RightMark => krishiv_plan::JoinType::RightSemi,
2539 };
2540 nodes.push(
2541 PlanNode::new(&id, "Join", ExecutionKind::Batch)
2542 .with_op(NodeOp::Join {
2543 join_type: krishiv_join_type,
2544 })
2545 .with_inputs([left_id, right_id]),
2546 );
2547 (nodes, id)
2548 }
2549
2550 DfPlan::Sort(sort) => {
2551 let (mut nodes, input_id) =
2552 df_plan_to_krishiv_nodes(&sort.input, table_row_counts, counter);
2553 let id = format!("sort-{idx}");
2554 nodes.push(
2555 PlanNode::new(&id, "Sort", ExecutionKind::Batch)
2556 .with_op(NodeOp::Other {
2557 description: format!(
2558 "Sort({})",
2559 sort.expr
2560 .iter()
2561 .map(|e| e.to_string())
2562 .collect::<Vec<_>>()
2563 .join(", ")
2564 ),
2565 })
2566 .with_inputs([input_id]),
2567 );
2568 (nodes, id)
2569 }
2570
2571 DfPlan::Repartition(repart) => {
2572 let (mut nodes, input_id) =
2573 df_plan_to_krishiv_nodes(&repart.input, table_row_counts, counter);
2574 let id = format!("exchange-{idx}");
2575 let partitioning = krishiv_plan::Partitioning::Unpartitioned;
2576 nodes.push(
2577 PlanNode::new(&id, "Exchange", ExecutionKind::Batch)
2578 .with_op(NodeOp::Exchange { partitioning })
2579 .with_inputs([input_id]),
2580 );
2581 (nodes, id)
2582 }
2583
2584 DfPlan::Limit(limit) => {
2585 let (mut nodes, input_id) =
2586 df_plan_to_krishiv_nodes(&limit.input, table_row_counts, counter);
2587 let id = format!("limit-{idx}");
2588 nodes.push(
2589 PlanNode::new(&id, "Limit", ExecutionKind::Batch)
2590 .with_op(NodeOp::Other {
2591 description: format!(
2592 "Limit(skip={:?}, fetch={:?})",
2593 limit.skip.as_ref().map(|e| e.to_string()),
2594 limit.fetch.as_ref().map(|e| e.to_string()),
2595 ),
2596 })
2597 .with_inputs([input_id]),
2598 );
2599 (nodes, id)
2600 }
2601
2602 DfPlan::Union(union) if union.inputs.len() == 1 => {
2603 if let Some(input) = union.inputs.first() {
2604 df_plan_to_krishiv_nodes(input, table_row_counts, counter)
2605 } else {
2606 (Vec::new(), String::new())
2607 }
2608 }
2609 DfPlan::Union(union) => {
2610 let mut all_nodes = Vec::new();
2611 let mut input_ids = Vec::new();
2612 for input in &union.inputs {
2613 let (sub_nodes, sub_id) =
2614 df_plan_to_krishiv_nodes(input, table_row_counts, counter);
2615 all_nodes.extend(sub_nodes);
2616 input_ids.push(sub_id);
2617 }
2618 let id = format!("union-{idx}");
2619 all_nodes.push(
2620 PlanNode::new(&id, "Union", ExecutionKind::Batch)
2621 .with_op(NodeOp::Other {
2622 description: "Union".to_string(),
2623 })
2624 .with_inputs(input_ids),
2625 );
2626 (all_nodes, id)
2627 }
2628
2629 DfPlan::SubqueryAlias(alias) => {
2630 df_plan_to_krishiv_nodes(&alias.input, table_row_counts, counter)
2632 }
2633
2634 DfPlan::Values(_) => {
2635 let id = format!("values-{idx}");
2636 let node = PlanNode::new(&id, "Values", ExecutionKind::Batch).with_op(NodeOp::Other {
2637 description: "Values".to_string(),
2638 });
2639 (vec![node], id)
2640 }
2641
2642 DfPlan::Extension(_) => {
2643 let id = format!("ext-{idx}");
2644 let label = plan.to_string();
2645 let node = PlanNode::new(&id, label.clone(), ExecutionKind::Batch)
2646 .with_op(NodeOp::Other { description: label });
2647 (vec![node], id)
2648 }
2649
2650 DfPlan::EmptyRelation(_) => {
2651 let id = format!("empty-{idx}");
2652 let node =
2653 PlanNode::new(&id, "EmptyRelation", ExecutionKind::Batch).with_op(NodeOp::Other {
2654 description: "EmptyRelation".to_string(),
2655 });
2656 (vec![node], id)
2657 }
2658
2659 _ => {
2661 let id = format!("df-{idx}");
2662 let label = plan.to_string();
2663 let node = PlanNode::new(&id, label.clone(), ExecutionKind::Batch)
2664 .with_op(NodeOp::Other { description: label });
2665 (vec![node], id)
2666 }
2667 }
2668}
2669
2670#[derive(Clone)]
2672pub struct SqlDataFrame {
2673 name: String,
2674 query: Option<String>,
2675 query_text: Option<String>,
2677 execution_kind: ExecutionKind,
2678 dataframe: DataFusionDataFrame,
2679 shuffle_partitions: Option<u32>,
2680 context: SessionContext,
2682 table_row_counts: Arc<std::sync::RwLock<HashMap<String, u64>>>,
2686}
2687
2688impl fmt::Debug for SqlDataFrame {
2689 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2690 f.debug_struct("SqlDataFrame")
2691 .field("name", &self.name)
2692 .field("query", &self.query)
2693 .field("shuffle_partitions", &self.shuffle_partitions)
2694 .finish_non_exhaustive()
2695 }
2696}
2697
2698impl SqlDataFrame {
2699 fn new(
2700 name: impl Into<String>,
2701 dataframe: DataFusionDataFrame,
2702 table_row_counts: Arc<std::sync::RwLock<HashMap<String, u64>>>,
2703 ) -> Self {
2704 Self {
2705 name: name.into(),
2706 query: None,
2707 query_text: None,
2708 execution_kind: ExecutionKind::Batch,
2709 dataframe,
2710 shuffle_partitions: None,
2711 context: SessionContext::default(),
2712 table_row_counts,
2713 }
2714 }
2715
2716 pub(crate) fn with_context(mut self, context: SessionContext) -> Self {
2718 self.context = context;
2719 self
2720 }
2721
2722 fn with_query(mut self, query: impl Into<String>) -> Self {
2723 let q = query.into();
2724 self.query_text = Some(q.clone());
2725 self.query = Some(q);
2726 self
2727 }
2728
2729 fn with_execution_kind(mut self, kind: ExecutionKind) -> Self {
2730 self.execution_kind = kind;
2731 self
2732 }
2733
2734 fn with_shuffle_partitions(mut self, n: Option<u32>) -> Self {
2735 self.shuffle_partitions = n;
2736 self
2737 }
2738
2739 pub fn query(&self) -> Option<&str> {
2741 self.query.as_deref()
2742 }
2743
2744 fn with_new_dataframe(&self, df: DataFusionDataFrame, tag: &str) -> Self {
2748 Self {
2749 name: format!("{}-{}", self.name, tag),
2750 query: None,
2751 query_text: None,
2752 execution_kind: self.execution_kind,
2753 dataframe: df,
2754 shuffle_partitions: self.shuffle_partitions,
2755 context: self.context.clone(),
2756 table_row_counts: self.table_row_counts.clone(),
2757 }
2758 }
2759
2760 pub fn krishiv_logical_plan(&self) -> LogicalPlan {
2769 let df_plan = self.dataframe.logical_plan();
2770 let counts = self
2771 .table_row_counts
2772 .read()
2773 .unwrap_or_else(|e| e.into_inner());
2774 let mut counter = 0usize;
2775 let (nodes, _root_id) = df_plan_to_krishiv_nodes(df_plan, &counts, &mut counter);
2776
2777 let mut plan = LogicalPlan::new(self.name.clone(), self.execution_kind);
2778 for node in nodes {
2779 plan = plan.with_node(node);
2780 }
2781
2782 let optimizer = krishiv_plan::optimizer::default_logical_optimizer();
2787 let fallback = plan.clone();
2788 match optimizer.optimize(plan) {
2789 Ok(result) => result.plan,
2790 Err(error) => {
2791 tracing::warn!(
2792 plan = %self.name,
2793 %error,
2794 "logical optimizer failed; using unoptimized plan"
2795 );
2796 fallback
2797 }
2798 }
2799 }
2800
2801 pub fn explain_logical(&self) -> String {
2803 self.dataframe.logical_plan().to_string()
2804 }
2805
2806 pub async fn explain(&self) -> SqlResult<String> {
2808 let batches = self
2809 .dataframe
2810 .clone()
2811 .explain(false, false)?
2812 .collect()
2813 .await?;
2814 pretty_batches(&batches)
2815 }
2816
2817 pub async fn collect(&self) -> SqlResult<Vec<RecordBatch>> {
2819 Ok(self.dataframe.clone().collect().await?)
2820 }
2821
2822 pub async fn execute_stream(&self) -> SqlResult<SqlStream> {
2824 let df_stream = self.dataframe.clone().execute_stream().await?;
2825 use futures::StreamExt;
2826 let mapped = df_stream.map(|res| {
2827 res.map_err(|e| SqlError::DataFusion {
2828 message: e.to_string(),
2829 })
2830 });
2831 Ok(Box::pin(mapped))
2832 }
2833
2834 pub async fn collect_with_stats(&self) -> SqlResult<(Vec<RecordBatch>, SqlExecutionStats)> {
2842 use datafusion::physical_plan::collect as df_collect;
2843
2844 let df = self.dataframe.clone();
2845 let task_ctx = df.task_ctx();
2846 let physical_plan = df.create_physical_plan().await?;
2847
2848 let batches = df_collect(physical_plan.clone(), task_ctx.into()).await?;
2849
2850 let mut output_rows: u64 = batches.iter().map(|b| b.num_rows() as u64).sum();
2851 let mut cpu_nanos: u64 = 0;
2852
2853 if let Some(metrics) = physical_plan.metrics() {
2854 if let Some(v) = metrics.output_rows() {
2855 output_rows = v as u64;
2856 }
2857 if let Some(t) = metrics.elapsed_compute() {
2858 cpu_nanos = t as u64;
2859 }
2860 }
2861
2862 let (spill_bytes, spill_count) = aggregate_spill_metrics(physical_plan.as_ref());
2863
2864 Ok((
2865 batches,
2866 SqlExecutionStats {
2867 output_rows,
2868 cpu_nanos,
2869 spill_bytes,
2870 spill_count,
2871 },
2872 ))
2873 }
2874}
2875
2876fn aggregate_spill_metrics(plan: &dyn datafusion::physical_plan::ExecutionPlan) -> (u64, u64) {
2883 let mut spill_bytes: u64 = 0;
2884 let mut spill_count: u64 = 0;
2885 if let Some(metrics) = plan.metrics() {
2886 if let Some(bytes) = metrics.spilled_bytes() {
2887 spill_bytes = spill_bytes.saturating_add(bytes as u64);
2888 }
2889 if let Some(count) = metrics.spill_count() {
2890 spill_count = spill_count.saturating_add(count as u64);
2891 }
2892 }
2893 for child in plan.children() {
2894 let (child_bytes, child_count) = aggregate_spill_metrics(child.as_ref());
2895 spill_bytes = spill_bytes.saturating_add(child_bytes);
2896 spill_count = spill_count.saturating_add(child_count);
2897 }
2898 (spill_bytes, spill_count)
2899}
2900
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct SqlExecutionStats {
    /// Rows produced by the query.
    pub output_rows: u64,
    /// Compute time in nanoseconds as reported by the executor.
    pub cpu_nanos: u64,
    /// Total bytes spilled to disk across the executed plan.
    pub spill_bytes: u64,
    /// Total number of spill events across the executed plan.
    pub spill_count: u64,
}
2911
/// Returns the byte offset of the last top-level `AS` alias separator in
/// `expression`. The offset points at the whitespace byte immediately before
/// the `AS` keyword.
///
/// A separator is a 4-byte window: one ASCII whitespace byte, `AS` in any case,
/// and another ASCII whitespace byte. Spaces, tabs and newlines therefore all
/// work (`expr\nAS name`). Windows inside parentheses (e.g. `CAST(x AS INT)`),
/// single-quoted literals and double-quoted identifiers are ignored. Doubled
/// quotes (`''`, `""`) are treated as escapes.
///
/// Returns `None` when there is no top-level alias.
fn top_level_alias_index(expression: &str) -> Option<usize> {
    let bytes = expression.as_bytes();
    let mut depth = 0usize;
    let mut single_quoted = false;
    let mut double_quoted = false;
    let mut candidate = None;
    let mut index = 0usize;
    while let Some(&byte) = bytes.get(index) {
        let in_quotes = single_quoted || double_quoted;
        match byte {
            b'\'' if !double_quoted => {
                // `''` inside a string literal is an escaped quote.
                if single_quoted && bytes.get(index + 1) == Some(&b'\'') {
                    index += 2;
                    continue;
                }
                single_quoted = !single_quoted;
            }
            b'"' if !single_quoted => {
                // `""` inside a quoted identifier is an escaped quote.
                if double_quoted && bytes.get(index + 1) == Some(&b'"') {
                    index += 2;
                    continue;
                }
                double_quoted = !double_quoted;
            }
            b'(' if !in_quotes => depth += 1,
            b')' if !in_quotes => depth = depth.saturating_sub(1),
            _ if depth == 0 && !in_quotes && is_alias_separator(bytes, index) => {
                // Keep the last match; skip the rest of the 4-byte window.
                candidate = Some(index);
                index += 3;
            }
            _ => {}
        }
        index += 1;
    }
    candidate
}

/// Returns `true` when `bytes[index..index + 4]` is `<ws>AS<ws>`, where the
/// whitespace is ASCII and `AS` is matched case-insensitively.
fn is_alias_separator(bytes: &[u8], index: usize) -> bool {
    match bytes.get(index..index + 4) {
        Some([before, a, s, after]) => {
            before.is_ascii_whitespace()
                && a.eq_ignore_ascii_case(&b'a')
                && s.eq_ignore_ascii_case(&b's')
                && after.is_ascii_whitespace()
        }
        _ => false,
    }
}
2956
2957fn parse_dataframe_expression(
2958 dataframe: &datafusion::dataframe::DataFrame,
2959 expression: &str,
2960) -> SqlResult<datafusion::logical_expr::Expr> {
2961 if let Some(index) = top_level_alias_index(expression) {
2962 let (body, alias) = expression.split_at(index);
2963 let alias = alias[4..].trim();
2964 if !alias.is_empty() {
2965 let alias = alias
2966 .strip_prefix('"')
2967 .and_then(|value| value.strip_suffix('"'))
2968 .unwrap_or(alias)
2969 .replace("\"\"", "\"");
2970 return Ok(dataframe.parse_sql_expr(body.trim())?.alias(alias));
2971 }
2972 }
2973 dataframe.parse_sql_expr(expression).map_err(Into::into)
2974}
2975
2976pub fn parse_public_expression(sql: &str) -> SqlResult<krishiv_plan::expression::Expr> {
2978 let dialect = GenericDialect {};
2979 let mut parser =
2980 Parser::new(&dialect)
2981 .try_with_sql(sql)
2982 .map_err(|error| SqlError::Unsupported {
2983 feature: format!("public expression parse: {error}"),
2984 })?;
2985 let expression = parser.parse_expr().map_err(|error| SqlError::Unsupported {
2986 feature: format!("public expression parse: {error}"),
2987 })?;
2988 sqlparser_expression_to_public(&expression)
2989}
2990
2991fn sqlparser_expression_to_public(
2992 expression: &datafusion::sql::sqlparser::ast::Expr,
2993) -> SqlResult<krishiv_plan::expression::Expr> {
2994 use datafusion::sql::sqlparser::ast::{BinaryOperator as SqlOperator, Expr as SqlExpr, Value};
2995 use krishiv_plan::expression::{BinaryOperator, Expr, ScalarValue};
2996
2997 Ok(match expression {
2998 SqlExpr::Identifier(identifier) => Expr::Column {
2999 path: vec![identifier.value.clone()],
3000 },
3001 SqlExpr::CompoundIdentifier(identifiers) => Expr::Column {
3002 path: identifiers
3003 .iter()
3004 .map(|identifier| identifier.value.clone())
3005 .collect(),
3006 },
3007 SqlExpr::Nested(expression) => sqlparser_expression_to_public(expression)?,
3008 SqlExpr::IsNull(expression) => Expr::IsNull {
3009 expression: Box::new(sqlparser_expression_to_public(expression)?),
3010 negated: false,
3011 },
3012 SqlExpr::IsNotNull(expression) => Expr::IsNull {
3013 expression: Box::new(sqlparser_expression_to_public(expression)?),
3014 negated: true,
3015 },
3016 SqlExpr::BinaryOp { left, op, right } => Expr::Binary {
3017 left: Box::new(sqlparser_expression_to_public(left)?),
3018 op: match op {
3019 SqlOperator::Eq => BinaryOperator::Eq,
3020 SqlOperator::NotEq => BinaryOperator::NotEq,
3021 SqlOperator::Gt => BinaryOperator::Gt,
3022 SqlOperator::GtEq => BinaryOperator::GtEq,
3023 SqlOperator::Lt => BinaryOperator::Lt,
3024 SqlOperator::LtEq => BinaryOperator::LtEq,
3025 SqlOperator::And => BinaryOperator::And,
3026 SqlOperator::Or => BinaryOperator::Or,
3027 SqlOperator::Plus => BinaryOperator::Plus,
3028 SqlOperator::Minus => BinaryOperator::Minus,
3029 SqlOperator::Multiply => BinaryOperator::Multiply,
3030 SqlOperator::Divide => BinaryOperator::Divide,
3031 other => {
3032 return Err(SqlError::Unsupported {
3033 feature: format!("public expression operator {other}"),
3034 });
3035 }
3036 },
3037 right: Box::new(sqlparser_expression_to_public(right)?),
3038 },
3039 SqlExpr::Value(value) => Expr::Literal {
3040 value: match &value.value {
3041 Value::Null => ScalarValue::Null,
3042 Value::Boolean(value) => ScalarValue::Boolean(*value),
3043 Value::SingleQuotedString(value) => ScalarValue::Utf8(value.clone()),
3044 Value::Number(value, _)
3045 if value.contains('.') || value.contains('e') || value.contains('E') =>
3046 {
3047 ScalarValue::float64(value.parse::<f64>().map_err(|error| {
3048 SqlError::Unsupported {
3049 feature: format!("numeric expression literal: {error}"),
3050 }
3051 })?)
3052 }
3053 Value::Number(value, _) => {
3054 ScalarValue::Int64(value.parse::<i64>().map_err(|error| {
3055 SqlError::Unsupported {
3056 feature: format!("integer expression literal: {error}"),
3057 }
3058 })?)
3059 }
3060 other => {
3061 return Err(SqlError::Unsupported {
3062 feature: format!("public expression literal {other}"),
3063 });
3064 }
3065 },
3066 },
3067 other => {
3068 return Err(SqlError::Unsupported {
3069 feature: format!("public expression node {other}"),
3070 });
3071 }
3072 })
3073}
3074
3075fn public_data_type_to_arrow(
3076 data_type: &krishiv_plan::expression::ExprDataType,
3077) -> arrow::datatypes::DataType {
3078 use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit};
3079 use krishiv_plan::expression::{ExprDataType, IntervalUnit as PublicIntervalUnit};
3080
3081 match data_type {
3082 ExprDataType::Null => DataType::Null,
3083 ExprDataType::Boolean => DataType::Boolean,
3084 ExprDataType::Int64 => DataType::Int64,
3085 ExprDataType::UInt64 => DataType::UInt64,
3086 ExprDataType::Float64 => DataType::Float64,
3087 ExprDataType::Utf8 => DataType::Utf8,
3088 ExprDataType::Binary => DataType::Binary,
3089 ExprDataType::Decimal128 { precision, scale } => DataType::Decimal128(*precision, *scale),
3090 ExprDataType::Date32 => DataType::Date32,
3091 ExprDataType::Timestamp { unit, timezone } => DataType::Timestamp(
3092 match unit {
3093 krishiv_plan::expression::TimeUnit::Second => TimeUnit::Second,
3094 krishiv_plan::expression::TimeUnit::Millisecond => TimeUnit::Millisecond,
3095 krishiv_plan::expression::TimeUnit::Microsecond => TimeUnit::Microsecond,
3096 krishiv_plan::expression::TimeUnit::Nanosecond => TimeUnit::Nanosecond,
3097 },
3098 timezone.clone().map(Into::into),
3099 ),
3100 ExprDataType::Interval { unit } => DataType::Interval(match unit {
3101 PublicIntervalUnit::YearMonth => IntervalUnit::YearMonth,
3102 PublicIntervalUnit::DayTime => IntervalUnit::DayTime,
3103 PublicIntervalUnit::MonthDayNano => IntervalUnit::MonthDayNano,
3104 }),
3105 ExprDataType::List(element) => DataType::List(Arc::new(Field::new(
3106 "item",
3107 public_data_type_to_arrow(element),
3108 true,
3109 ))),
3110 ExprDataType::Map { key, value } => DataType::Map(
3111 Arc::new(Field::new(
3112 "entries",
3113 DataType::Struct(
3114 vec![
3115 Arc::new(Field::new("key", public_data_type_to_arrow(key), false)),
3116 Arc::new(Field::new("value", public_data_type_to_arrow(value), true)),
3117 ]
3118 .into(),
3119 ),
3120 false,
3121 )),
3122 false,
3123 ),
3124 ExprDataType::Struct(fields) => DataType::Struct(
3125 fields
3126 .iter()
3127 .map(|field| {
3128 Arc::new(Field::new(
3129 &field.name,
3130 public_data_type_to_arrow(&field.data_type),
3131 field.nullable,
3132 ))
3133 })
3134 .collect::<Vec<_>>()
3135 .into(),
3136 ),
3137 ExprDataType::Variant => DataType::Utf8,
3142 }
3143}
3144
3145fn public_scalar_to_datafusion(
3146 value: &krishiv_plan::expression::ScalarValue,
3147) -> Option<datafusion::common::ScalarValue> {
3148 use datafusion::common::ScalarValue;
3149 use krishiv_plan::expression::{ScalarValue as PublicScalar, TimeUnit};
3150
3151 Some(match value {
3152 PublicScalar::Null => ScalarValue::Null,
3153 PublicScalar::Boolean(value) => ScalarValue::Boolean(Some(*value)),
3154 PublicScalar::Int64(value) => ScalarValue::Int64(Some(*value)),
3155 PublicScalar::UInt64(value) => ScalarValue::UInt64(Some(*value)),
3156 PublicScalar::Float64(bits) => ScalarValue::Float64(Some(f64::from_bits(*bits))),
3157 PublicScalar::Utf8(value) => ScalarValue::Utf8(Some(value.clone())),
3158 PublicScalar::Binary(value) => ScalarValue::Binary(Some(value.clone())),
3159 PublicScalar::Decimal128 {
3160 value,
3161 precision,
3162 scale,
3163 } => ScalarValue::Decimal128(Some(*value), *precision, *scale),
3164 PublicScalar::Date32(value) => ScalarValue::Date32(Some(*value)),
3165 PublicScalar::Timestamp {
3166 value,
3167 unit,
3168 timezone,
3169 } => {
3170 let timezone = timezone.clone().map(Into::into);
3171 match unit {
3172 TimeUnit::Second => ScalarValue::TimestampSecond(Some(*value), timezone),
3173 TimeUnit::Millisecond => ScalarValue::TimestampMillisecond(Some(*value), timezone),
3174 TimeUnit::Microsecond => ScalarValue::TimestampMicrosecond(Some(*value), timezone),
3175 TimeUnit::Nanosecond => ScalarValue::TimestampNanosecond(Some(*value), timezone),
3176 }
3177 }
3178 PublicScalar::Interval { .. } => return None,
3179 })
3180}
3181
3182fn lower_public_expression(
3188 dataframe: &datafusion::dataframe::DataFrame,
3189 expression: &krishiv_plan::expression::Expr,
3190) -> SqlResult<datafusion::logical_expr::Expr> {
3191 expression
3192 .validate()
3193 .map_err(|error| SqlError::Unsupported {
3194 feature: format!("invalid public expression: {error}"),
3195 })?;
3196 use datafusion::logical_expr::{Expr as DataFusionExpr, Operator, binary_expr, cast, try_cast};
3197 use krishiv_plan::expression::{BinaryOperator, Expr};
3198
3199 Ok(match expression {
3200 Expr::Column { path } if path.len() == 1 => {
3201 datafusion::prelude::col(path.first().map(String::as_str).unwrap_or(""))
3202 }
3203 Expr::Column { .. } => parse_dataframe_expression(dataframe, &expression.to_sql())?,
3204 Expr::Literal { value } => match public_scalar_to_datafusion(value) {
3205 Some(value) => DataFusionExpr::Literal(value, None),
3206 None => parse_dataframe_expression(dataframe, &expression.to_sql())?,
3207 },
3208 Expr::Alias { expression, name } => {
3209 lower_public_expression(dataframe, expression)?.alias(name)
3210 }
3211 Expr::Binary { left, op, right } => binary_expr(
3212 lower_public_expression(dataframe, left)?,
3213 match op {
3214 BinaryOperator::Eq => Operator::Eq,
3215 BinaryOperator::NotEq => Operator::NotEq,
3216 BinaryOperator::Gt => Operator::Gt,
3217 BinaryOperator::GtEq => Operator::GtEq,
3218 BinaryOperator::Lt => Operator::Lt,
3219 BinaryOperator::LtEq => Operator::LtEq,
3220 BinaryOperator::And => Operator::And,
3221 BinaryOperator::Or => Operator::Or,
3222 BinaryOperator::Plus => Operator::Plus,
3223 BinaryOperator::Minus => Operator::Minus,
3224 BinaryOperator::Multiply => Operator::Multiply,
3225 BinaryOperator::Divide => Operator::Divide,
3226 },
3227 lower_public_expression(dataframe, right)?,
3228 ),
3229 Expr::IsNull {
3230 expression,
3231 negated,
3232 } => {
3233 let expression = lower_public_expression(dataframe, expression)?;
3234 if *negated {
3235 expression.is_not_null()
3236 } else {
3237 expression.is_null()
3238 }
3239 }
3240 Expr::Cast {
3241 expression,
3242 data_type,
3243 safe,
3244 } => {
3245 let expression = lower_public_expression(dataframe, expression)?;
3246 let data_type = public_data_type_to_arrow(data_type);
3247 if *safe {
3248 try_cast(expression, data_type)
3249 } else {
3250 cast(expression, data_type)
3251 }
3252 }
3253 Expr::Sort { .. } => {
3254 return Err(SqlError::Unsupported {
3255 feature: "standalone sort expressions are only valid inside windows or order_by"
3256 .into(),
3257 });
3258 }
3259 Expr::Aggregate { .. }
3260 | Expr::Function { .. }
3261 | Expr::Window { .. }
3262 | Expr::RawSql { .. } => parse_dataframe_expression(dataframe, &expression.to_sql())?,
3263 })
3264}
3265
3266fn sql_dataframe<'a>(
3267 dataframe: &'a dyn KrishivDataFrameOps,
3268 operation: &str,
3269) -> SqlResult<&'a SqlDataFrame> {
3270 dataframe
3271 .as_any()
3272 .downcast_ref::<SqlDataFrame>()
3273 .ok_or_else(|| SqlError::DataFusion {
3274 message: format!("right DataFrame must be SqlDataFrame for {operation}"),
3275 })
3276}
3277
3278#[async_trait::async_trait]
3279impl KrishivDataFrameOps for SqlDataFrame {
3280 async fn collect(&self) -> SqlResult<Vec<RecordBatch>> {
3281 SqlDataFrame::collect(self).await
3282 }
3283 async fn collect_with_stats(&self) -> SqlResult<(Vec<RecordBatch>, SqlExecutionStats)> {
3284 SqlDataFrame::collect_with_stats(self).await
3285 }
3286 async fn explain(&self) -> SqlResult<String> {
3287 SqlDataFrame::explain(self).await
3288 }
3289 fn explain_logical(&self) -> String {
3290 SqlDataFrame::explain_logical(self)
3291 }
3292 fn krishiv_logical_plan(&self) -> LogicalPlan {
3293 let label = self.dataframe.logical_plan().to_string();
3294 let mut plan = LogicalPlan::new(self.name.clone(), ExecutionKind::Batch).with_node(
3295 PlanNode::new("datafusion-logical", label, ExecutionKind::Batch),
3296 );
3297 if let Some(n) = self.shuffle_partitions {
3298 plan = plan.with_shuffle_partitions(Some(n));
3299 }
3300 plan
3301 }
3302 fn query(&self) -> Option<&str> {
3303 SqlDataFrame::query(self)
3304 }
3305 async fn execute_stream(&self) -> SqlResult<SqlStream> {
3306 SqlDataFrame::execute_stream(self).await
3307 }
3308
3309 fn schema(&self) -> SchemaRef {
3312 SchemaRef::from(self.dataframe.schema().clone())
3313 }
3314
3315 async fn select(&self, columns: &[&str]) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3316 let df = self.dataframe.clone().select_columns(columns)?;
3317 Ok(Box::new(self.with_new_dataframe(df, "select")))
3318 }
3319
3320 async fn select_exprs(
3321 &self,
3322 expressions: &[&krishiv_plan::expression::Expr],
3323 ) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3324 let expressions = expressions
3325 .iter()
3326 .map(|expression| lower_public_expression(&self.dataframe, expression))
3327 .collect::<Result<Vec<_>, _>>()?;
3328 let df = self.dataframe.clone().select(expressions)?;
3329 Ok(Box::new(self.with_new_dataframe(df, "select_exprs")))
3330 }
3331
3332 async fn aggregate(
3333 &self,
3334 group_exprs: &[&krishiv_plan::expression::Expr],
3335 aggregate_exprs: &[&krishiv_plan::expression::Expr],
3336 ) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3337 if aggregate_exprs.is_empty() {
3338 return Err(SqlError::Unsupported {
3339 feature: "aggregate requires at least one aggregate expression".into(),
3340 });
3341 }
3342 let group_exprs = group_exprs
3343 .iter()
3344 .map(|expression| lower_public_expression(&self.dataframe, expression))
3345 .collect::<Result<Vec<_>, _>>()?;
3346 let aggregate_exprs = aggregate_exprs
3347 .iter()
3348 .map(|expression| lower_public_expression(&self.dataframe, expression))
3349 .collect::<Result<Vec<_>, _>>()?;
3350 let df = self
3351 .dataframe
3352 .clone()
3353 .aggregate(group_exprs, aggregate_exprs)?;
3354 Ok(Box::new(self.with_new_dataframe(df, "aggregate")))
3355 }
3356
3357 async fn aggregate_grouping(
3358 &self,
3359 grouping: GroupingMode<'_>,
3360 aggregate_exprs: &[&krishiv_plan::expression::Expr],
3361 ) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3362 if aggregate_exprs.is_empty() {
3363 return Err(SqlError::Unsupported {
3364 feature: "grouping aggregation requires at least one aggregate expression".into(),
3365 });
3366 }
3367 let lower = |expression: &&krishiv_plan::expression::Expr| {
3368 lower_public_expression(&self.dataframe, expression)
3369 };
3370 let group = match grouping {
3371 GroupingMode::Sets(sets) => datafusion::logical_expr::grouping_set(
3372 sets.into_iter()
3373 .map(|set| set.iter().map(lower).collect::<Result<Vec<_>, _>>())
3374 .collect::<Result<Vec<_>, _>>()?,
3375 ),
3376 GroupingMode::Cube(expressions) => datafusion::logical_expr::cube(
3377 expressions
3378 .iter()
3379 .map(lower)
3380 .collect::<Result<Vec<_>, _>>()?,
3381 ),
3382 GroupingMode::Rollup(expressions) => datafusion::logical_expr::rollup(
3383 expressions
3384 .iter()
3385 .map(lower)
3386 .collect::<Result<Vec<_>, _>>()?,
3387 ),
3388 };
3389 let aggregates = aggregate_exprs
3390 .iter()
3391 .map(lower)
3392 .collect::<Result<Vec<_>, _>>()?;
3393 let df = self.dataframe.clone().aggregate(vec![group], aggregates)?;
3394 Ok(Box::new(self.with_new_dataframe(df, "aggregate_grouping")))
3395 }
3396
3397 async fn pivot(
3398 &self,
3399 group_exprs: &[&krishiv_plan::expression::Expr],
3400 pivot_column: &krishiv_plan::expression::Expr,
3401 aggregate_expr: &krishiv_plan::expression::Expr,
3402 values: &[(krishiv_plan::expression::ScalarValue, String)],
3403 ) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3404 use krishiv_plan::expression::Expr as PublicExpr;
3405 let (function, input, distinct) = match aggregate_expr {
3406 PublicExpr::Aggregate {
3407 function,
3408 expression: Some(input),
3409 distinct,
3410 } => (*function, input.as_ref(), *distinct),
3411 _ => {
3412 return Err(SqlError::Unsupported {
3413 feature: "pivot requires an aggregate expression with one input".into(),
3414 });
3415 }
3416 };
3417 if values.is_empty() {
3418 return Err(SqlError::Unsupported {
3419 feature: "pivot requires at least one value".into(),
3420 });
3421 }
3422 let group_exprs = group_exprs
3423 .iter()
3424 .map(|expression| lower_public_expression(&self.dataframe, expression))
3425 .collect::<Result<Vec<_>, _>>()?;
3426 let aggregates = values
3427 .iter()
3428 .map(|(value, alias)| {
3429 let conditional = PublicExpr::raw(format!(
3430 "CASE WHEN {} = {} THEN {} END",
3431 pivot_column.to_sql(),
3432 value.to_sql_literal(),
3433 input.to_sql()
3434 ));
3435 let aggregate = PublicExpr::Aggregate {
3436 function,
3437 expression: Some(Box::new(conditional)),
3438 distinct,
3439 }
3440 .alias(alias);
3441 lower_public_expression(&self.dataframe, &aggregate)
3442 })
3443 .collect::<Result<Vec<_>, _>>()?;
3444 let dataframe = self.dataframe.clone().aggregate(group_exprs, aggregates)?;
3445 Ok(Box::new(self.with_new_dataframe(dataframe, "pivot")))
3446 }
3447
3448 async fn unpivot(
3449 &self,
3450 columns: &[&str],
3451 name_column: &str,
3452 value_column: &str,
3453 ) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3454 if columns.is_empty() {
3455 return Err(SqlError::Unsupported {
3456 feature: "unpivot requires at least one column".into(),
3457 });
3458 }
3459 let retained = self
3460 .dataframe
3461 .schema()
3462 .fields()
3463 .iter()
3464 .map(|field| field.name().as_str())
3465 .filter(|name| !columns.contains(name))
3466 .collect::<Vec<_>>();
3467 let mut branches = Vec::with_capacity(columns.len());
3468 for column in columns {
3469 let mut expressions = retained
3470 .iter()
3471 .map(|name| datafusion::logical_expr::col(*name))
3472 .collect::<Vec<_>>();
3473 expressions
3474 .push(datafusion::logical_expr::lit((*column).to_owned()).alias(name_column));
3475 expressions.push(datafusion::logical_expr::col(*column).alias(value_column));
3476 branches.push(self.dataframe.clone().select(expressions)?);
3477 }
3478 let mut branches = branches.into_iter();
3479 let Some(mut dataframe) = branches.next() else {
3480 return Err(SqlError::Unsupported {
3481 feature: "unpivot requires at least one branch".into(),
3482 });
3483 };
3484 for branch in branches {
3485 dataframe = dataframe.union(branch)?;
3486 }
3487 Ok(Box::new(self.with_new_dataframe(dataframe, "unpivot")))
3488 }
3489
3490 async fn filter(&self, predicate: &str) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3491 let expr = self.dataframe.parse_sql_expr(predicate)?;
3492 let df = self.dataframe.clone().filter(expr)?;
3493 Ok(Box::new(self.with_new_dataframe(df, "filter")))
3494 }
3495
3496 async fn filter_expr(
3497 &self,
3498 predicate: &krishiv_plan::expression::Expr,
3499 ) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3500 let expr = lower_public_expression(&self.dataframe, predicate)?;
3501 let df = self.dataframe.clone().filter(expr)?;
3502 Ok(Box::new(self.with_new_dataframe(df, "filter_expr")))
3503 }
3504
3505 async fn limit(&self, n: usize) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3506 let df = self.dataframe.clone().limit(0, Some(n))?;
3507 Ok(Box::new(self.with_new_dataframe(df, "limit")))
3508 }
3509
3510 async fn distinct(&self) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3511 let df = self.dataframe.clone().distinct()?;
3512 Ok(Box::new(self.with_new_dataframe(df, "distinct")))
3513 }
3514
3515 async fn drop_nulls(&self, columns: &[&str]) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3516 let columns = if columns.is_empty() {
3517 self.dataframe
3518 .schema()
3519 .fields()
3520 .iter()
3521 .map(|field| field.name().as_str())
3522 .collect::<Vec<_>>()
3523 } else {
3524 columns.to_vec()
3525 };
3526 let mut predicate: Option<datafusion::logical_expr::Expr> = None;
3527 for column in columns {
3528 let next = datafusion::logical_expr::col(column).is_not_null();
3529 predicate = Some(match predicate {
3530 Some(current) => current.and(next),
3531 None => next,
3532 });
3533 }
3534 let df = match predicate {
3535 Some(predicate) => self.dataframe.clone().filter(predicate)?,
3536 None => self.dataframe.clone(),
3537 };
3538 Ok(Box::new(self.with_new_dataframe(df, "drop_nulls")))
3539 }
3540
3541 async fn sample(&self, fraction: f64) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3542 if !(0.0..=1.0).contains(&fraction) {
3543 return Err(SqlError::Unsupported {
3544 feature: "sample fraction must be between 0 and 1".into(),
3545 });
3546 }
3547 let predicate = self
3548 .dataframe
3549 .parse_sql_expr(&format!("random() < {fraction}"))?;
3550 let df = self.dataframe.clone().filter(predicate)?;
3551 Ok(Box::new(self.with_new_dataframe(df, "sample")))
3552 }
3553
3554 async fn sort(
3555 &self,
3556 columns: &[&str],
3557 descending: &[bool],
3558 ) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3559 use datafusion::logical_expr::SortExpr;
3560 let exprs: Vec<SortExpr> = columns
3561 .iter()
3562 .zip(descending.iter())
3563 .map(|(col_name, desc)| datafusion::logical_expr::col(*col_name).sort(!desc, *desc))
3564 .collect();
3565 let df = self.dataframe.clone().sort(exprs)?;
3566 Ok(Box::new(self.with_new_dataframe(df, "sort")))
3567 }
3568
3569 async fn alias(&self, alias: &str) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3570 let df = self.dataframe.clone().alias(alias)?;
3571 Ok(Box::new(self.with_new_dataframe(df, "alias")))
3572 }
3573
3574 async fn drop_columns(&self, columns: &[&str]) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3575 let df = self.dataframe.clone().drop_columns(columns)?;
3576 Ok(Box::new(self.with_new_dataframe(df, "drop")))
3577 }
3578
3579 async fn rename_column(&self, old: &str, new: &str) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3580 let df = self.dataframe.clone().with_column_renamed(old, new)?;
3581 Ok(Box::new(self.with_new_dataframe(df, "rename")))
3582 }
3583
3584 async fn with_column(&self, name: &str, expr: &str) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3585 let parsed = self.dataframe.parse_sql_expr(expr)?;
3586 let df = self.dataframe.clone().with_column(name, parsed)?;
3587 Ok(Box::new(self.with_new_dataframe(df, "with_column")))
3588 }
3589
3590 fn as_any(&self) -> &dyn std::any::Any {
3591 self
3592 }
3593
3594 async fn describe(&self) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3595 let df = self.dataframe.clone().describe().await?;
3596 Ok(Box::new(self.with_new_dataframe(df, "describe")))
3597 }
3598
3599 async fn fill_null(
3600 &self,
3601 column: &str,
3602 value: &str,
3603 ) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3604 let expr = format!("COALESCE({column}, {value})");
3605 let parsed = self.dataframe.parse_sql_expr(&expr)?;
3606 let df = self.dataframe.clone().with_column(column, parsed)?;
3607 Ok(Box::new(self.with_new_dataframe(df, "fill_null")))
3608 }
3609
3610 async fn join(
3611 &self,
3612 right: &dyn KrishivDataFrameOps,
3613 how: &str,
3614 left_on: &[&str],
3615 right_on: &[&str],
3616 ) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3617 let right_sql = right
3618 .as_any()
3619 .downcast_ref::<SqlDataFrame>()
3620 .ok_or_else(|| SqlError::DataFusion {
3621 message: "right DataFrame must be SqlDataFrame for join".into(),
3622 })?;
3623 use datafusion::common::JoinType;
3624 let join_type = match how.to_lowercase().as_str() {
3625 "inner" => JoinType::Inner,
3626 "left" => JoinType::Left,
3627 "right" => JoinType::Right,
3628 "full" | "outer" => JoinType::Full,
3629 "leftsemi" | "left_semi" => JoinType::LeftSemi,
3630 "rightsemi" | "right_semi" => JoinType::RightSemi,
3631 "leftanti" | "left_anti" => JoinType::LeftAnti,
3632 "rightanti" | "right_anti" => JoinType::RightAnti,
3633 _ => {
3634 return Err(SqlError::DataFusion {
3635 message: format!("unsupported join type: {how}"),
3636 });
3637 }
3638 };
3639 let df = self.dataframe.clone().join(
3640 right_sql.dataframe.clone(),
3641 join_type,
3642 left_on,
3643 right_on,
3644 None,
3645 )?;
3646 Ok(Box::new(self.with_new_dataframe(df, "join")))
3647 }
3648
3649 async fn union(
3650 &self,
3651 right: &dyn KrishivDataFrameOps,
3652 ) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3653 let right_sql = right
3654 .as_any()
3655 .downcast_ref::<SqlDataFrame>()
3656 .ok_or_else(|| SqlError::DataFusion {
3657 message: "right DataFrame must be SqlDataFrame for union".into(),
3658 })?;
3659 let df = self.dataframe.clone().union(right_sql.dataframe.clone())?;
3660 Ok(Box::new(self.with_new_dataframe(df, "union")))
3661 }
3662
3663 async fn union_distinct(
3664 &self,
3665 right: &dyn KrishivDataFrameOps,
3666 ) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3667 let right = sql_dataframe(right, "union_distinct")?;
3668 let df = self
3669 .dataframe
3670 .clone()
3671 .union_distinct(right.dataframe.clone())?;
3672 Ok(Box::new(self.with_new_dataframe(df, "union_distinct")))
3673 }
3674
3675 async fn intersect(
3676 &self,
3677 right: &dyn KrishivDataFrameOps,
3678 distinct: bool,
3679 ) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3680 let right = sql_dataframe(right, "intersect")?;
3681 let df = if distinct {
3682 self.dataframe
3683 .clone()
3684 .intersect_distinct(right.dataframe.clone())?
3685 } else {
3686 self.dataframe.clone().intersect(right.dataframe.clone())?
3687 };
3688 Ok(Box::new(self.with_new_dataframe(df, "intersect")))
3689 }
3690
3691 async fn except(
3692 &self,
3693 right: &dyn KrishivDataFrameOps,
3694 distinct: bool,
3695 ) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3696 let right = sql_dataframe(right, "except")?;
3697 let df = if distinct {
3698 self.dataframe
3699 .clone()
3700 .except_distinct(right.dataframe.clone())?
3701 } else {
3702 self.dataframe.clone().except(right.dataframe.clone())?
3703 };
3704 Ok(Box::new(self.with_new_dataframe(df, "except")))
3705 }
3706
3707 async fn register_batches(&self, name: &str, batches: Vec<RecordBatch>) -> SqlResult<()> {
3708 let schema = batches
3709 .first()
3710 .map(|b| b.schema())
3711 .unwrap_or_else(|| Arc::new(arrow::datatypes::Schema::empty()));
3712 let mem_table =
3713 datafusion::datasource::MemTable::try_new(schema, vec![batches]).map_err(|e| {
3714 SqlError::DataFusion {
3715 message: e.to_string(),
3716 }
3717 })?;
3718 self.context
3719 .register_table(name, Arc::new(mem_table))
3720 .map_err(SqlError::from)?;
3721 Ok(())
3722 }
3723
3724 async fn deregister_table(&self, name: &str) -> SqlResult<()> {
3725 let _ = self
3726 .context
3727 .deregister_table(name)
3728 .map_err(SqlError::from)?;
3729 Ok(())
3730 }
3731
3732 async fn create_view(&self, name: &str, replace: bool) -> SqlResult<()> {
3733 let query = self
3734 .query_text
3735 .as_deref()
3736 .ok_or_else(|| SqlError::DataFusion {
3737 message: "create_view requires an SQL query string on the DataFrame".into(),
3738 })?;
3739 let or_replace = if replace { "OR REPLACE " } else { "" };
3740 let safe_name = quote_identifier(name);
3741 let view_sql = format!("CREATE {or_replace}VIEW {safe_name} AS {query}");
3742 self.context.sql(&view_sql).await?;
3743 Ok(())
3744 }
3745}
3746
3747use krishiv_common::sql_util::quote_identifier;
3748
3749#[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
/// Splits a comma-separated SQL argument list (e.g. `'db.tbl', '7 days'`) into
/// individual argument strings.
///
/// Rules:
/// - Single-quoted arguments are returned without their quotes, with inner
///   whitespace preserved. A doubled quote (`''`) inside a string is an
///   escaped literal `'`.
/// - Whitespace between a separator and an opening quote is discarded, so it
///   does not leak into the quoted value.
/// - Anything after a closing quote up to the next comma is ignored.
/// - Unquoted arguments are trimmed, and empty unquoted arguments are dropped.
///   An empty quoted argument (`''`) is kept as `""`.
/// - An unterminated quoted string yields its (trimmed) contents.
fn call_args_from_str(s: &str) -> Vec<String> {
    let mut args: Vec<String> = Vec::new();
    let mut cur = String::new();
    let mut in_str = false;
    let mut after_str = false;
    let mut chars = s.chars().peekable();
    while let Some(ch) = chars.next() {
        if after_str {
            if ch == ',' {
                after_str = false;
            }
            continue;
        }
        if in_str {
            if ch == '\'' {
                if chars.peek() == Some(&'\'') {
                    // `''` inside a string literal is an escaped single quote.
                    chars.next();
                    cur.push('\'');
                } else {
                    in_str = false;
                    after_str = true;
                    args.push(std::mem::take(&mut cur));
                }
            } else {
                cur.push(ch);
            }
        } else if ch == '\'' {
            // Drop separator whitespace so `'a', 'b'` yields "b", not " b".
            if cur.trim().is_empty() {
                cur.clear();
            }
            in_str = true;
        } else if ch == ',' {
            let t = cur.trim();
            if !t.is_empty() {
                args.push(t.to_string());
            }
            cur.clear();
        } else {
            cur.push(ch);
        }
    }
    let t = cur.trim();
    if !t.is_empty() {
        args.push(t.to_string());
    }
    args
}
3794
3795#[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
3802fn iceberg_table_ident(table_ref: &str) -> SqlResult<iceberg::TableIdent> {
3803 let parts: Vec<&str> = table_ref.splitn(3, '.').collect();
3804 match parts.len() {
3805 2 => {
3806 let ns = iceberg::NamespaceIdent::from_vec(vec![
3807 parts.first().copied().unwrap_or("").to_string(),
3808 ])
3809 .map_err(|e| SqlError::DataFusion {
3810 message: e.to_string(),
3811 })?;
3812 Ok(iceberg::TableIdent::new(
3813 ns,
3814 parts.get(1).copied().unwrap_or("").to_string(),
3815 ))
3816 }
3817 3 => {
3818 let ns = iceberg::NamespaceIdent::from_vec(vec![
3819 parts.get(1).copied().unwrap_or("").to_string(),
3820 ])
3821 .map_err(|e| SqlError::DataFusion {
3822 message: e.to_string(),
3823 })?;
3824 Ok(iceberg::TableIdent::new(
3825 ns,
3826 parts.get(2).copied().unwrap_or("").to_string(),
3827 ))
3828 }
3829 _ => Err(SqlError::DataFusion {
3830 message: format!(
3831 "invalid table reference '{table_ref}': expected 'ns.table' or 'cat.ns.table'"
3832 ),
3833 }),
3834 }
3835}
3836
3837#[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
3842fn parse_call_duration(s: &str) -> SqlResult<chrono::Duration> {
3843 let s = s.trim();
3844 let mut it = s.splitn(2, ' ');
3845 let n: i64 = it
3846 .next()
3847 .and_then(|v| v.parse().ok())
3848 .ok_or_else(|| SqlError::DataFusion {
3849 message: format!("invalid duration value in '{s}'"),
3850 })?;
3851 let unit = it.next().unwrap_or("").trim().to_ascii_lowercase();
3852 match unit.trim_end_matches('s') {
3853 "day" => Ok(chrono::Duration::days(n)),
3854 "hour" => Ok(chrono::Duration::hours(n)),
3855 "week" => Ok(chrono::Duration::weeks(n)),
3856 "minute" | "min" => Ok(chrono::Duration::minutes(n)),
3857 _ => Err(SqlError::DataFusion {
3858 message: format!("unknown duration unit '{unit}' in '{s}'"),
3859 }),
3860 }
3861}
3862
3863#[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
3871fn parse_dml_delete(stmt: &str) -> Option<(String, String)> {
3872 use datafusion::sql::sqlparser::ast::{FromTable, Statement, TableFactor};
3873 use datafusion::sql::sqlparser::dialect::GenericDialect;
3874 use datafusion::sql::sqlparser::parser::Parser;
3875
3876 let mut stmts = Parser::parse_sql(&GenericDialect {}, stmt).ok()?;
3877 if stmts.len() != 1 {
3878 return None;
3879 }
3880 let Statement::Delete(delete) = stmts.remove(0) else {
3881 return None;
3882 };
3883 let tables = match delete.from {
3886 FromTable::WithFromKeyword(tables) | FromTable::WithoutKeyword(tables) => tables,
3887 };
3888 let first_from = tables.into_iter().next()?;
3889 let table_name = match first_from.relation {
3890 TableFactor::Table { name, .. } => name.to_string(),
3891 _ => return None,
3892 };
3893 let predicate = delete
3894 .selection
3895 .map(|e| e.to_string())
3896 .unwrap_or_else(|| "TRUE".to_string());
3897 Some((table_name, predicate))
3898}
3899
3900#[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
/// Components of a single-table `UPDATE` statement, as extracted by `parse_dml_update`.
struct ParsedUpdate {
    /// Target table name, rendered by the SQL parser's `ObjectName` display.
    table_ref: String,
    /// `(target, value)` pairs from the `SET` clause, each rendered back to SQL text.
    assignments: Vec<(String, String)>,
    /// `WHERE` clause rendered back to SQL text; `None` when the statement has no `WHERE`.
    predicate: Option<String>,
}
3908
3909#[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
3915fn parse_dml_update(stmt: &str) -> Option<ParsedUpdate> {
3916 use datafusion::sql::sqlparser::ast::{Statement, TableFactor};
3917 use datafusion::sql::sqlparser::dialect::GenericDialect;
3918 use datafusion::sql::sqlparser::parser::Parser;
3919
3920 let mut stmts = Parser::parse_sql(&GenericDialect {}, stmt).ok()?;
3921 if stmts.len() != 1 {
3922 return None;
3923 }
3924 let Statement::Update(update) = stmts.remove(0) else {
3926 return None;
3927 };
3928 let table_name = match update.table.relation {
3929 TableFactor::Table { name, .. } => name.to_string(),
3930 _ => return None,
3931 };
3932 let parsed_assignments: Vec<(String, String)> = update
3934 .assignments
3935 .into_iter()
3936 .map(|a| {
3937 let col = a.target.to_string();
3939 let val = a.value.to_string();
3940 (col, val)
3941 })
3942 .collect();
3943 if parsed_assignments.is_empty() {
3944 return None;
3945 }
3946 Some(ParsedUpdate {
3947 table_ref: table_name,
3948 assignments: parsed_assignments,
3949 predicate: update.selection.map(|e| e.to_string()),
3950 })
3951}
3952
3953pub fn plan_sql(query: impl Into<String>) -> SqlResult<SqlPlan> {
3955 let query = query.into();
3956 if query.trim().is_empty() {
3957 return Err(SqlError::EmptyQuery);
3958 }
3959
3960 if let Some(stmt) = cep_sql::parse_match_recognize(&query)? {
3961 let logical_plan = cep_sql::plan_match_recognize(stmt, &query);
3962 let optimized = Optimizer::default().optimize(logical_plan)?;
3963 return Ok(SqlPlan {
3964 query,
3965 logical_plan: optimized.plan,
3966 });
3967 }
3968
3969 let logical_plan =
3970 LogicalPlan::new("sql-query", ExecutionKind::Batch).with_node(PlanNode::new(
3971 "sql",
3972 format!("sql: {}", query.trim()),
3973 ExecutionKind::Batch,
3974 ));
3975
3976 let optimized = Optimizer::default().optimize(logical_plan)?;
3977 Ok(SqlPlan {
3978 query,
3979 logical_plan: optimized.plan,
3980 })
3981}
3982
3983pub fn explain_sql(query: impl Into<String>) -> SqlResult<String> {
3985 let plan = plan_sql(query)?;
3986 Ok(plan.logical_plan().describe())
3987}
3988
3989pub fn explain_sql_optimized(query: impl Into<String>, optimizer: &Optimizer) -> SqlResult<String> {
3994 let plan = plan_sql(query)?;
3995 let result = optimizer.optimize(plan.logical_plan().clone())?;
3996 let mut output = result.plan.describe();
3997 let optimizer_line = result.describe();
3998 output.push('\n');
3999 output.push_str(&optimizer_line);
4000 Ok(output)
4001}
4002
4003pub fn explain_sql_with_cost(
4005 query: impl Into<String>,
4006 cost_model: &dyn CostModel,
4007) -> SqlResult<String> {
4008 let plan = plan_sql(query)?;
4009 let cost = cost_model.estimate(plan.logical_plan());
4010 let mut output = plan.logical_plan().describe();
4011 output.push_str(&format!(
4012 "\ncost: cpu_nanos={}, memory_bytes={}, network_bytes={}",
4013 cost.cpu_nanos, cost.memory_bytes, cost.network_bytes
4014 ));
4015 Ok(output)
4016}
4017
4018pub fn referenced_table_names(query: impl AsRef<str>) -> SqlResult<Vec<String>> {
4024 let query = query.as_ref();
4025 if query.trim().is_empty() {
4026 return Err(SqlError::EmptyQuery);
4027 }
4028
4029 let statements =
4030 Parser::parse_sql(&GenericDialect {}, query).map_err(|e| SqlError::DataFusion {
4031 message: format!("SQL parse error: {e}"),
4032 })?;
4033 let mut names = BTreeSet::new();
4034 let _ = visit_relations(&statements, |relation| {
4035 names.insert(relation.to_string());
4036 ControlFlow::<()>::Continue(())
4037 });
4038 Ok(names.into_iter().collect())
4039}
4040
4041pub fn pretty_batches(batches: &[RecordBatch]) -> SqlResult<String> {
4043 Ok(pretty_format_batches(batches)
4044 .map_err(|error| SqlError::DataFusion {
4045 message: error.to_string(),
4046 })?
4047 .to_string())
4048}
4049
4050#[cfg(test)]
4051mod sql_tests;