1#![forbid(unsafe_code)]
2
3use std::collections::{BTreeSet, HashMap, VecDeque};
9use std::fmt;
10use std::num::NonZeroUsize;
11use std::ops::ControlFlow;
12use std::path::Path;
13use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
14use std::sync::{Arc, Mutex, RwLock};
15
16use arrow::datatypes::SchemaRef;
17use arrow::record_batch::RecordBatch;
18use arrow::util::pretty::pretty_format_batches;
19use catalog::{InMemoryCatalog, datafusion_bridge::DataFusionCatalogBridge};
20use datafusion::dataframe::DataFrame as DataFusionDataFrame;
21use datafusion::prelude::{ParquetReadOptions, SessionContext};
22use datafusion::sql::sqlparser::{ast::visit_relations, dialect::GenericDialect, parser::Parser};
23use object_store::aws::AmazonS3Builder;
24
25use krishiv_plan::optimizer::{CostModel, Optimizer};
26use krishiv_plan::{ExecutionKind, LogicalPlan, PlanNode};
27
28pub mod analyze;
29pub mod catalog;
30pub mod cep_sql;
31
32pub mod connector_table;
33pub mod create_function_ddl;
34pub mod grammar;
35pub mod incremental_view;
36pub mod introspection_sql;
37
38pub mod kafka_table;
39pub mod lakehouse;
40pub mod live_table;
41pub mod pipeline_ddl;
42pub mod pivot_sql;
43pub mod recursive_cte;
44pub mod spark_sql_ext;
46pub mod sqlstate;
47pub mod subquery;
48pub mod unnest_sql;
49
50pub mod streaming;
51pub mod streaming_tvf;
52pub mod streaming_window_plan;
53mod udf;
54mod window_functions;
55
56pub use cep_sql::{
57 MatchRecognizeStatement, execute_streaming_match_recognize, parse_match_recognize,
58};
59pub use lakehouse::{AsOfTableRef, MergeResult, MergeTargetUnsupportedError, preprocess_as_of_sql};
60
61pub use grammar::{
62 FeatureEntry, FeatureStatus, feature_matrix, features_by_status, features_for_category,
63};
64pub use sqlstate::{SqlStateError, sqlstate_for};
65pub use streaming::{ContinuousInputError, ContinuousTableInput};
66
67pub type SqlResult<T> = Result<T, SqlError>;
69
70pub type SqlStream =
76 std::pin::Pin<Box<dyn futures::stream::Stream<Item = Result<RecordBatch, SqlError>> + Send>>;
77
/// Process-wide counter that feeds [`next_ephemeral_name`].
static EPHEMERAL_TABLE_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Returns `__{prefix}_{id}`, where `id` is the counter value before it is
/// incremented, so successive calls yield distinct names.
fn next_ephemeral_name(prefix: &str) -> String {
    let sequence = EPHEMERAL_TABLE_COUNTER.fetch_add(1, Ordering::Relaxed);
    ["__", prefix, "_", sequence.to_string().as_str()].concat()
}
86
/// Tells `SqlEngine::build_local` whether to register the window helper
/// functions on the freshly created session context.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WindowFnRegistration {
    /// Register the window helper functions; a failure aborts construction.
    Register,
    /// Leave the window helper functions unregistered.
    Skip,
}
100
101struct PlanCache {
107 map: HashMap<String, datafusion::logical_expr::LogicalPlan>,
108 order: VecDeque<String>,
109 max: usize,
110}
111
112impl PlanCache {
113 fn new(max: usize) -> Self {
114 Self {
115 map: HashMap::new(),
116 order: VecDeque::new(),
117 max,
118 }
119 }
120
121 fn get(&self, key: &str) -> Option<&datafusion::logical_expr::LogicalPlan> {
122 self.map.get(key)
123 }
124
125 fn insert(&mut self, key: String, plan: datafusion::logical_expr::LogicalPlan) {
126 if self.map.contains_key(&key) {
127 self.order.retain(|k| k != &key);
130 } else if self.map.len() >= self.max
131 && let Some(oldest) = self.order.pop_front()
132 {
133 self.map.remove(&oldest);
134 }
135 self.order.push_back(key.clone());
136 self.map.insert(key, plan);
137 }
138
139 fn clear(&mut self) {
140 self.map.clear();
141 self.order.clear();
142 }
143
144 #[cfg(test)]
145 fn is_empty(&self) -> bool {
146 self.map.is_empty()
147 }
148}
149
/// Optional settings for a Parquet reader; every field defaults to `None`.
#[derive(Debug, Clone, Default)]
pub struct ParquetReaderOptions {
    /// Requested batch size, or `None` when the caller does not specify one.
    pub batch_size: Option<usize>,
}
156
/// Optional settings for a CSV reader; every field defaults to `None`.
#[derive(Debug, Clone, Default)]
pub struct CsvReaderOptions {
    /// Field delimiter character, or `None` when unspecified.
    pub delimiter: Option<char>,
    /// Header-row flag, or `None` when unspecified.
    pub has_header: Option<bool>,
}
165
/// Optional settings for a Parquet writer; every field defaults to `None`.
#[derive(Debug, Clone, Default)]
pub struct ParquetWriterOptions {
    /// Compression setting given as text, or `None` when unspecified.
    pub compression: Option<String>,
    /// Maximum row-group size, or `None` when unspecified.
    pub max_row_group_size: Option<usize>,
}
174
/// Optional settings for a CSV writer; every field defaults to `None`.
#[derive(Debug, Clone, Default)]
pub struct CsvWriterOptions {
    /// Field delimiter character, or `None` when unspecified.
    pub delimiter: Option<char>,
    /// Header-row flag, or `None` when unspecified.
    pub has_header: Option<bool>,
}
183
184#[non_exhaustive]
186#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
187pub enum SqlError {
188 #[error("SQL query is empty")]
190 EmptyQuery,
191 #[error("table name is empty")]
193 EmptyTableName,
194 #[error("unsupported SQL feature: {feature}")]
196 Unsupported { feature: String },
197 #[error("invalid table function: {message}")]
199 InvalidTableFunction { message: String },
200 #[error("DataFusion error: {message}")]
202 DataFusion { message: String },
203 #[error(transparent)]
205 Optimizer(#[from] krishiv_plan::optimizer::OptimizerError),
206 #[error("access denied: {reason}")]
208 AccessDenied { reason: String },
209 #[error("operation {operation_id} was cancelled")]
211 OperationCancelled { operation_id: u64 },
212 #[error("query timed out after {timeout_ms} ms")]
214 Timeout { timeout_ms: u64 },
215}
216
217impl From<datafusion::error::DataFusionError> for SqlError {
218 fn from(value: datafusion::error::DataFusionError) -> Self {
219 Self::DataFusion {
220 message: value.to_string(),
221 }
222 }
223}
224
225#[derive(Debug, Clone, PartialEq, Eq)]
227pub struct SqlPlan {
228 query: String,
229 logical_plan: LogicalPlan,
230}
231
232impl SqlPlan {
233 pub fn query(&self) -> &str {
235 &self.query
236 }
237
238 pub fn logical_plan(&self) -> &LogicalPlan {
240 &self.logical_plan
241 }
242}
243
/// Plan-cache capacity used when the environment does not give a valid one.
const PLAN_CACHE_MAX_ENTRIES: usize = 256;

/// Reads the plan-cache capacity from `KRISHIV_PLAN_CACHE_MAX_ENTRIES`.
///
/// Falls back to [`PLAN_CACHE_MAX_ENTRIES`] when the variable is unset, does
/// not parse as `usize`, or is zero.
fn resolve_plan_cache_max_entries() -> usize {
    match std::env::var("KRISHIV_PLAN_CACHE_MAX_ENTRIES") {
        Ok(raw) => match raw.parse::<usize>() {
            Ok(entries) if entries > 0 => entries,
            _ => PLAN_CACHE_MAX_ENTRIES,
        },
        Err(_) => PLAN_CACHE_MAX_ENTRIES,
    }
}
/// Row limit used when no valid streaming `MATCH_RECOGNIZE` limit is configured.
const STREAMING_CEP_MAX_ROWS_DEFAULT: usize = 100_000;

/// Resolves the streaming `MATCH_RECOGNIZE` row limit from an optional raw
/// setting.
///
/// Surrounding whitespace is ignored, matching
/// [`resolve_query_memory_limit_bytes`], so a value such as `"500\n"` read
/// from a file or shell export is honoured rather than silently replaced.
///
/// Returns [`STREAMING_CEP_MAX_ROWS_DEFAULT`] when `raw` is `None`, does not
/// parse as `usize`, or is zero.
pub fn resolve_streaming_match_recognize_limit(raw: Option<&str>) -> usize {
    raw.and_then(|s| s.trim().parse::<usize>().ok())
        .filter(|n| *n > 0)
        .unwrap_or(STREAMING_CEP_MAX_ROWS_DEFAULT)
}
274
275pub fn streaming_match_recognize_limit_from_env() -> usize {
278 resolve_streaming_match_recognize_limit(
279 std::env::var("KRISHIV_MATCH_RECOGNIZE_STREAMING_LIMIT")
280 .ok()
281 .as_deref(),
282 )
283}
284
/// Parses an optional raw setting into a per-query memory limit in bytes.
///
/// Surrounding whitespace is ignored. Returns `None` when `raw` is `None`,
/// does not parse as `usize`, or is zero.
pub fn resolve_query_memory_limit_bytes(raw: Option<&str>) -> Option<usize> {
    match raw?.trim().parse::<usize>() {
        Ok(bytes) if bytes > 0 => Some(bytes),
        _ => None,
    }
}
292
293pub fn query_memory_limit_from_env() -> Option<usize> {
296 resolve_query_memory_limit_bytes(
297 std::env::var("KRISHIV_QUERY_MEMORY_LIMIT_BYTES")
298 .ok()
299 .as_deref(),
300 )
301}
302
/// Reads the batch size from the `KRISHIV_BATCH_SIZE` environment variable.
///
/// Returns 8192 when the variable is unset, does not parse as `usize`, or is
/// zero.
pub fn batch_size_from_env() -> usize {
    const FALLBACK_BATCH_SIZE: usize = 8192;
    match std::env::var("KRISHIV_BATCH_SIZE") {
        Ok(raw) => match raw.parse::<usize>() {
            Ok(rows) if rows > 0 => rows,
            _ => FALLBACK_BATCH_SIZE,
        },
        Err(_) => FALLBACK_BATCH_SIZE,
    }
}
313
/// Resolves the default parallelism.
///
/// Uses `KRISHIV_TARGET_PARALLELISM` when it parses as a non-zero `usize`.
/// Otherwise uses `std::thread::available_parallelism()`, and finally
/// `NonZeroUsize::MIN` if that query fails.
pub fn default_parallelism_from_env() -> NonZeroUsize {
    let configured = match std::env::var("KRISHIV_TARGET_PARALLELISM") {
        Ok(raw) => raw.parse::<usize>().ok().and_then(NonZeroUsize::new),
        Err(_) => None,
    };
    match configured {
        Some(parallelism) => parallelism,
        None => std::thread::available_parallelism().unwrap_or(NonZeroUsize::MIN),
    }
}
324
325fn build_single_node_session_config(
333 target_partitions: NonZeroUsize,
334) -> datafusion::prelude::SessionConfig {
335 let tp = target_partitions.get();
336 let batch_size = batch_size_from_env();
337 let mut config = datafusion::prelude::SessionConfig::new()
338 .with_target_partitions(tp)
339 .with_batch_size(batch_size)
340 .with_information_schema(true)
341 .set_bool(
342 "datafusion.optimizer.enable_round_robin_repartition",
343 tp > 1,
344 );
345 config.options_mut().execution.parquet.pushdown_filters = true;
346 config.options_mut().execution.parquet.enable_page_index = true;
347 config
348}
349
350#[derive(Clone)]
351pub struct SqlEngine {
352 context: SessionContext,
353 target_parallelism: NonZeroUsize,
354 krishiv_catalog: Option<Arc<RwLock<InMemoryCatalog>>>,
355 udf_registry: Option<std::sync::Arc<std::sync::RwLock<krishiv_plan::udf::UdfRegistry>>>,
356 streaming_sources: Arc<RwLock<std::collections::HashSet<String>>>,
359 streaming_registration: Arc<Mutex<()>>,
361 has_streaming_sources: Arc<AtomicBool>,
366 udf_limits: Option<krishiv_plan::udf::ResourceLimits>,
369 udf_registry_version: Arc<AtomicU64>,
373 udf_last_synced_version: Arc<AtomicU64>,
376 plan_cache: Arc<Mutex<PlanCache>>,
382 shuffle_partitions: Arc<std::sync::RwLock<Option<u32>>>,
385 table_row_counts: Arc<std::sync::RwLock<HashMap<String, u64>>>,
390 memory_limit_bytes: Option<usize>,
395 #[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
399 iceberg_catalogs: Arc<std::sync::RwLock<Vec<(Arc<catalog::unified::KrishivCatalog>, String)>>>,
400 live_table_registry: Arc<live_table::LiveTableRegistry>,
402 incremental_view_registry: Arc<incremental_view::IncrementalViewRegistry>,
404 pipeline_registry: Arc<pipeline_ddl::PipelineRegistry>,
406 operation_registry: Arc<OperationRegistry>,
408}
409
410impl fmt::Debug for SqlEngine {
411 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
412 f.debug_struct("SqlEngine")
413 .field("backend", &"datafusion")
414 .finish_non_exhaustive()
415 }
416}
417
418impl Default for SqlEngine {
419 fn default() -> Self {
420 Self::new()
421 }
422}
423
424impl SqlEngine {
425 pub fn new() -> Self {
436 Self::new_with_memory_limit(query_memory_limit_from_env())
437 }
438
439 pub fn new_with_memory_limit(memory_limit_bytes: Option<usize>) -> Self {
451 match Self::build_local(
452 None,
453 WindowFnRegistration::Register,
454 NonZeroUsize::MIN,
455 memory_limit_bytes,
456 ) {
457 Ok(engine) => engine,
458 Err(err) => {
459 tracing::warn!(
460 error = %err,
461 "SqlEngine::new: window helper UDF registration failed; \
462 window SQL functions will be unavailable, other queries are unaffected"
463 );
464 Self::build_local(
465 None,
466 WindowFnRegistration::Skip,
467 NonZeroUsize::MIN,
468 memory_limit_bytes,
469 )
470 .unwrap_or_else(|err| {
471 tracing::error!(
472 error = %err,
473 "memory-limited DataFusion runtime construction failed; \
474 falling back to an unbounded engine"
475 );
476 Self::build_local(None, WindowFnRegistration::Skip, NonZeroUsize::MIN, None)
477 .unwrap_or_else(|_| Self::build_absolute_minimal(NonZeroUsize::MIN))
478 })
479 }
480 }
481 }
482
483 pub fn try_new() -> SqlResult<Self> {
488 Self::build_local(
489 None,
490 WindowFnRegistration::Register,
491 NonZeroUsize::MIN,
492 query_memory_limit_from_env(),
493 )
494 }
495
496 pub fn with_in_memory_catalog(catalog: Arc<RwLock<InMemoryCatalog>>) -> SqlResult<Self> {
498 if krishiv_common::profile_requires_fail_closed_metadata(
499 krishiv_common::resolve_durability_profile(),
500 ) {
501 return Err(SqlError::DataFusion {
502 message: String::from(
503 "InMemoryCatalog is dev-only; configure a durable REST or file-backed \
504 catalog for production deployments",
505 ),
506 });
507 }
508 Self::build_local(
509 Some(catalog),
510 WindowFnRegistration::Register,
511 NonZeroUsize::MIN,
512 query_memory_limit_from_env(),
513 )
514 }
515
516 #[must_use]
522 pub fn with_target_parallelism(mut self, n: NonZeroUsize) -> Self {
523 self.target_parallelism = n;
524 self
525 }
526
527 pub fn target_parallelism(&self) -> NonZeroUsize {
529 self.target_parallelism
530 }
531
532 pub fn memory_limit_bytes(&self) -> Option<usize> {
534 self.memory_limit_bytes
535 }
536
537 pub fn shuffle_partitions(&self) -> Option<u32> {
539 *self
540 .shuffle_partitions
541 .read()
542 .unwrap_or_else(|e| e.into_inner())
543 }
544
545 pub fn table_row_counts(&self) -> Arc<std::sync::RwLock<HashMap<String, u64>>> {
551 Arc::clone(&self.table_row_counts)
552 }
553
554 pub fn registered_table_names(&self) -> Vec<String> {
560 let mut names = Vec::new();
561 for catalog_name in self.context.catalog_names() {
562 let Some(catalog) = self.context.catalog(&catalog_name) else {
563 continue;
564 };
565 for schema_name in catalog.schema_names() {
566 let Some(schema) = catalog.schema(&schema_name) else {
567 continue;
568 };
569 names.extend(schema.table_names());
570 }
571 }
572 names.sort();
573 names.dedup();
574 names
575 }
576
577 fn make_sql_df(&self, name: &str, dataframe: DataFusionDataFrame) -> SqlDataFrame {
580 SqlDataFrame::new(name, dataframe, self.table_row_counts())
581 .with_context(self.context.clone())
582 }
583
584 fn attach_query_metadata(&self, df: SqlDataFrame, query: &str) -> SqlDataFrame {
586 let kind = if self.is_streaming_query(query).unwrap_or(false) {
587 ExecutionKind::Streaming
588 } else {
589 ExecutionKind::Batch
590 };
591 df.with_query(query).with_execution_kind(kind)
592 }
593
594 #[must_use]
599 pub fn with_shuffle_partitions(self, n: Option<u32>) -> Self {
600 if let Ok(mut guard) = self.shuffle_partitions.write() {
601 *guard = n;
602 }
603 self
604 }
605
606 fn build_local(
616 krishiv_catalog: Option<Arc<RwLock<InMemoryCatalog>>>,
617 window_fn_registration: WindowFnRegistration,
618 target_partitions: NonZeroUsize,
619 memory_limit_bytes: Option<usize>,
620 ) -> SqlResult<Self> {
621 let streaming_sources: Arc<RwLock<std::collections::HashSet<String>>> =
625 Arc::new(RwLock::new(std::collections::HashSet::new()));
626
627 let dummy_state = datafusion::execution::session_state::SessionStateBuilder::new()
628 .with_default_features()
629 .build();
630 let mut table_factories = dummy_state.table_factories().clone();
631 crate::connector_table::register_connector_table_factories(
632 &mut table_factories,
633 streaming_sources.clone(),
634 );
635 let mut state_builder = datafusion::execution::session_state::SessionStateBuilder::new()
636 .with_default_features()
637 .with_config(build_single_node_session_config(target_partitions))
638 .with_table_factories(table_factories);
639 if let Some(limit) = memory_limit_bytes {
640 let runtime_env = datafusion::execution::runtime_env::RuntimeEnvBuilder::new()
645 .with_memory_pool(Arc::new(
646 datafusion::execution::memory_pool::FairSpillPool::new(limit),
647 ))
648 .build_arc()
649 .map_err(|e| SqlError::DataFusion {
650 message: format!(
651 "failed to build memory-limited DataFusion runtime \
652 (limit {limit} bytes): {e}"
653 ),
654 })?;
655 state_builder = state_builder.with_runtime_env(runtime_env);
656 }
657 let state = state_builder.build();
658 let context = SessionContext::new_with_state(state);
659 if let Some(catalog) = &krishiv_catalog {
660 context.register_catalog(
661 "krishiv",
662 Arc::new(DataFusionCatalogBridge::new(catalog.clone())),
663 );
664 }
665 if matches!(window_fn_registration, WindowFnRegistration::Register) {
666 window_functions::register_window_functions(&context).map_err(|e| {
667 SqlError::DataFusion {
668 message: format!("failed to register window helper UDFs: {e}"),
669 }
670 })?;
671 }
672 Ok(Self {
673 context,
674 target_parallelism: target_partitions,
675 krishiv_catalog,
676 udf_registry: None,
677 streaming_sources,
678 streaming_registration: Arc::new(Mutex::new(())),
679 has_streaming_sources: Arc::new(AtomicBool::new(false)),
680 udf_limits: None,
681 udf_registry_version: Arc::new(AtomicU64::new(0)),
682 udf_last_synced_version: Arc::new(AtomicU64::new(u64::MAX)),
683 plan_cache: Arc::new(Mutex::new(PlanCache::new(resolve_plan_cache_max_entries()))),
684 shuffle_partitions: Arc::new(std::sync::RwLock::new(None)),
685 table_row_counts: Arc::new(std::sync::RwLock::new(HashMap::new())),
686 memory_limit_bytes,
687 #[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
688 iceberg_catalogs: Arc::new(std::sync::RwLock::new(Vec::new())),
689 live_table_registry: Arc::new(live_table::LiveTableRegistry::new()),
690 incremental_view_registry: Arc::new(incremental_view::IncrementalViewRegistry::new()),
691 pipeline_registry: Arc::new(pipeline_ddl::PipelineRegistry::new()),
692 operation_registry: Arc::new(OperationRegistry::new()),
693 })
694 }
695
696 fn build_absolute_minimal(target_partitions: NonZeroUsize) -> Self {
700 let streaming_sources: Arc<RwLock<std::collections::HashSet<String>>> =
701 Arc::new(RwLock::new(std::collections::HashSet::new()));
702 let dummy_state = datafusion::execution::session_state::SessionStateBuilder::new()
703 .with_default_features()
704 .build();
705 let mut table_factories = dummy_state.table_factories().clone();
706 crate::connector_table::register_connector_table_factories(
707 &mut table_factories,
708 streaming_sources.clone(),
709 );
710 let state = datafusion::execution::session_state::SessionStateBuilder::new()
711 .with_default_features()
712 .with_config(build_single_node_session_config(target_partitions))
713 .with_table_factories(table_factories)
714 .build();
715 let context = SessionContext::new_with_state(state);
716 Self {
717 context,
718 target_parallelism: target_partitions,
719 krishiv_catalog: None,
720 udf_registry: None,
721 streaming_sources,
722 streaming_registration: Arc::new(Mutex::new(())),
723 has_streaming_sources: Arc::new(AtomicBool::new(false)),
724 udf_limits: None,
725 udf_registry_version: Arc::new(AtomicU64::new(0)),
726 udf_last_synced_version: Arc::new(AtomicU64::new(u64::MAX)),
727 plan_cache: Arc::new(Mutex::new(PlanCache::new(resolve_plan_cache_max_entries()))),
728 shuffle_partitions: Arc::new(std::sync::RwLock::new(None)),
729 table_row_counts: Arc::new(std::sync::RwLock::new(HashMap::new())),
730 memory_limit_bytes: None,
731 #[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
732 iceberg_catalogs: Arc::new(std::sync::RwLock::new(Vec::new())),
733 live_table_registry: Arc::new(live_table::LiveTableRegistry::new()),
734 incremental_view_registry: Arc::new(incremental_view::IncrementalViewRegistry::new()),
735 pipeline_registry: Arc::new(pipeline_ddl::PipelineRegistry::new()),
736 operation_registry: Arc::new(OperationRegistry::new()),
737 }
738 }
739
740 pub fn register_streaming_table(
751 &self,
752 name: &str,
753 schema: arrow::datatypes::SchemaRef,
754 ) -> SqlResult<Arc<ContinuousTableInput>> {
755 let _registration = self.lock_streaming_registration()?;
756 self.validate_new_streaming_table(name, &schema)?;
757 let (table, input) = crate::streaming::create_continuous_table(schema).map_err(|e| {
758 SqlError::DataFusion {
759 message: e.to_string(),
760 }
761 })?;
762 self.register_new_streaming_provider(name, table)?;
763 self.streaming_sources
764 .write()
765 .unwrap_or_else(|e| e.into_inner())
766 .insert(name.to_string());
767 self.has_streaming_sources.store(true, Ordering::Release);
768 self.invalidate_plan_cache();
769 Ok(input)
770 }
771
772 pub fn register_streaming_table_with_capacity(
777 &self,
778 name: &str,
779 schema: arrow::datatypes::SchemaRef,
780 capacity: usize,
781 ) -> SqlResult<Arc<ContinuousTableInput>> {
782 let _registration = self.lock_streaming_registration()?;
783 self.validate_new_streaming_table(name, &schema)?;
784 let (table, input) = crate::streaming::create_continuous_table_with_capacity(
785 schema, capacity,
786 )
787 .map_err(|e| SqlError::DataFusion {
788 message: e.to_string(),
789 })?;
790 self.register_new_streaming_provider(name, table)?;
791 self.streaming_sources
792 .write()
793 .unwrap_or_else(|e| e.into_inner())
794 .insert(name.to_string());
795 self.has_streaming_sources.store(true, Ordering::Release);
796 self.invalidate_plan_cache();
797 Ok(input)
798 }
799
800 fn lock_streaming_registration(&self) -> SqlResult<std::sync::MutexGuard<'_, ()>> {
801 self.streaming_registration
802 .lock()
803 .map_err(|error| SqlError::DataFusion {
804 message: format!("streaming table registration lock poisoned: {error}"),
805 })
806 }
807
808 fn validate_new_streaming_table(
809 &self,
810 name: &str,
811 schema: &arrow::datatypes::SchemaRef,
812 ) -> SqlResult<()> {
813 if name.trim().is_empty() {
814 return Err(SqlError::EmptyTableName);
815 }
816 if schema.fields().is_empty() {
817 return Err(SqlError::DataFusion {
818 message: "streaming table schema must contain at least one field".into(),
819 });
820 }
821 if self
822 .context
823 .table_exist(name)
824 .map_err(|error| SqlError::DataFusion {
825 message: error.to_string(),
826 })?
827 {
828 return Err(SqlError::DataFusion {
829 message: format!("table '{name}' is already registered"),
830 });
831 }
832 Ok(())
833 }
834
835 fn register_new_streaming_provider(
836 &self,
837 name: &str,
838 table: Arc<dyn datafusion::catalog::TableProvider>,
839 ) -> SqlResult<()> {
840 let previous =
841 self.context
842 .register_table(name, table)
843 .map_err(|error| SqlError::DataFusion {
844 message: error.to_string(),
845 })?;
846 if let Some(previous) = previous {
847 self.context
848 .register_table(name, previous)
849 .map_err(|error| SqlError::DataFusion {
850 message: format!(
851 "table '{name}' was concurrently registered and could not be restored: \
852 {error}"
853 ),
854 })?;
855 return Err(SqlError::DataFusion {
856 message: format!("table '{name}' was concurrently registered"),
857 });
858 }
859 Ok(())
860 }
861
862 pub fn register_kafka_source(
876 &self,
877 table_name: impl AsRef<str>,
878 schema: arrow::datatypes::SchemaRef,
879 bootstrap_servers: impl Into<String>,
880 topic: impl Into<String>,
881 group_id: impl Into<String>,
882 ) -> SqlResult<()> {
883 let table_name = table_name.as_ref();
884 if table_name.trim().is_empty() {
885 return Err(SqlError::EmptyTableName);
886 }
887 let config = krishiv_connectors::kafka::KafkaConfig {
888 bootstrap_servers: bootstrap_servers.into(),
889 topic: topic.into(),
890 group_id: group_id.into(),
891 auto_commit_interval_ms: {
892 let profile = krishiv_common::resolve_durability_profile();
893 if krishiv_common::requires_manual_kafka_commit(profile) {
894 None
895 } else {
896 Some(1_000)
897 }
898 },
899 security_protocol: None,
900 ssl_ca_location: None,
901 ssl_certificate_location: None,
902 ssl_key_location: None,
903 ssl_key_password: None,
904 sasl_username: None,
905 sasl_password: None,
906 sasl_mechanisms: None,
907 enable_idempotence: None,
908 transactional_id: None,
909 };
910 let table =
911 crate::kafka_table::create_kafka_streaming_table(schema, config).map_err(|e| {
912 SqlError::DataFusion {
913 message: e.to_string(),
914 }
915 })?;
916 if self
917 .context
918 .table_exist(table_name)
919 .map_err(SqlError::from)?
920 {
921 let _ = self
922 .context
923 .deregister_table(table_name)
924 .map_err(SqlError::from)?;
925 }
926 self.context
927 .register_table(table_name, table)
928 .map_err(|e| SqlError::DataFusion {
929 message: e.to_string(),
930 })?;
931 self.streaming_sources
932 .write()
933 .unwrap_or_else(|e| e.into_inner())
934 .insert(table_name.to_string());
935 self.has_streaming_sources.store(true, Ordering::Release);
936 self.invalidate_plan_cache();
937 Ok(())
938 }
939
940 pub async fn sql_to_kafka(
950 &self,
951 sql: impl AsRef<str>,
952 bootstrap_servers: impl Into<String>,
953 topic: impl Into<String>,
954 ) -> SqlResult<u64> {
955 use futures::StreamExt;
956 use krishiv_connectors::Sink as _;
957 use krishiv_connectors::kafka::{KafkaConfig, KafkaSink};
958
959 let config = KafkaConfig {
960 bootstrap_servers: bootstrap_servers.into(),
961 topic: topic.into(),
962 group_id: "krishiv-sql-writer".into(),
963 auto_commit_interval_ms: None,
964 security_protocol: None,
965 ssl_ca_location: None,
966 ssl_certificate_location: None,
967 ssl_key_location: None,
968 ssl_key_password: None,
969 sasl_username: None,
970 sasl_password: None,
971 sasl_mechanisms: None,
972 enable_idempotence: None,
973 transactional_id: None,
974 };
975 let mut sink = KafkaSink::new(config).map_err(|e| SqlError::DataFusion {
976 message: e.to_string(),
977 })?;
978
979 let df = self.sql(sql.as_ref()).await?;
980 let mut stream = df.execute_stream().await?;
981 let mut total_rows = 0u64;
982
983 while let Some(result) = stream.next().await {
984 let batch = result.map_err(|e| SqlError::DataFusion {
985 message: e.to_string(),
986 })?;
987 if batch.num_rows() > 0 {
988 total_rows += batch.num_rows() as u64;
989 sink.write_batch(batch)
990 .await
991 .map_err(|e| SqlError::DataFusion {
992 message: e.to_string(),
993 })?;
994 }
995 }
996 sink.flush().await.map_err(|e| SqlError::DataFusion {
997 message: e.to_string(),
998 })?;
999 Ok(total_rows)
1000 }
1001
1002 pub fn with_udf_limits(mut self, limits: krishiv_plan::udf::ResourceLimits) -> Self {
1006 self.udf_limits = Some(limits);
1007 self
1008 }
1009
1010 pub fn is_streaming_source(&self, table_name: &str) -> bool {
1012 self.streaming_sources
1013 .read()
1014 .unwrap_or_else(|e| e.into_inner())
1015 .contains(table_name)
1016 }
1017
1018 pub fn register_streaming_source_name(&self, table_name: impl Into<String>) -> SqlResult<()> {
1027 let name: String = table_name.into();
1028 if name.trim().is_empty() {
1029 return Err(SqlError::EmptyTableName);
1030 }
1031 self.streaming_sources
1032 .write()
1033 .unwrap_or_else(|e| e.into_inner())
1034 .insert(name);
1035 self.has_streaming_sources.store(true, Ordering::Release);
1036 self.invalidate_plan_cache();
1037 Ok(())
1038 }
1039
1040 pub fn deregister_streaming_source(&self, name: &str) -> SqlResult<()> {
1046 if name.trim().is_empty() {
1047 return Err(SqlError::EmptyTableName);
1048 }
1049 let _ = self
1051 .context
1052 .deregister_table(name)
1053 .map_err(SqlError::from)?;
1054 {
1055 let mut sources = self
1056 .streaming_sources
1057 .write()
1058 .unwrap_or_else(|e| e.into_inner());
1059 sources.remove(name);
1060 if sources.is_empty() {
1061 self.has_streaming_sources.store(false, Ordering::Release);
1062 }
1063 self.invalidate_plan_cache();
1067 }
1068 Ok(())
1069 }
1070
1071 pub fn live_table_registry(&self) -> &Arc<live_table::LiveTableRegistry> {
1073 &self.live_table_registry
1074 }
1075
1076 pub fn incremental_view_registry(&self) -> &Arc<incremental_view::IncrementalViewRegistry> {
1078 &self.incremental_view_registry
1079 }
1080
1081 pub fn pipeline_registry(&self) -> &Arc<pipeline_ddl::PipelineRegistry> {
1083 &self.pipeline_registry
1084 }
1085
1086 pub fn operation_registry(&self) -> &Arc<OperationRegistry> {
1088 &self.operation_registry
1089 }
1090
1091 pub fn deregister_table(&self, name: &str) -> SqlResult<()> {
1095 if name.trim().is_empty() {
1096 return Err(SqlError::EmptyTableName);
1097 }
1098 let _ = self
1099 .context
1100 .deregister_table(name)
1101 .map_err(SqlError::from)?;
1102 self.invalidate_plan_cache();
1103 Ok(())
1104 }
1105
1106 pub fn register_table_udf_fn(
1130 &self,
1131 name: impl Into<String>,
1132 schema: arrow::datatypes::Schema,
1133 f: impl Fn(
1134 &[krishiv_plan::udf::ScalarValue],
1135 ) -> Result<arrow::record_batch::RecordBatch, krishiv_plan::udf::UdfError>
1136 + Send
1137 + Sync
1138 + 'static,
1139 ) -> SqlResult<()> {
1140 let udf =
1141 create_function_ddl::ClosureTableUdf::try_new(name, schema, std::sync::Arc::new(f))
1142 .map_err(|error| SqlError::InvalidTableFunction {
1143 message: error.to_string(),
1144 })?;
1145 if let Some(registry) = &self.udf_registry {
1146 let mut guard = registry.write().map_err(|e| SqlError::DataFusion {
1147 message: e.to_string(),
1148 })?;
1149 guard.register_table(std::sync::Arc::new(udf.clone()));
1150 }
1151 udf::register_single_table_udf(&self.context, std::sync::Arc::new(udf))
1152 .map_err(SqlError::from)
1153 }
1154
1155 pub fn is_streaming_query(&self, sql: &str) -> SqlResult<bool> {
1157 if !self.has_streaming_sources.load(Ordering::Acquire) {
1160 return Ok(false);
1161 }
1162 let sources = self
1163 .streaming_sources
1164 .read()
1165 .unwrap_or_else(|e| e.into_inner());
1166 if sources.is_empty() {
1167 return Ok(false);
1168 }
1169 let dialect = GenericDialect {};
1170 let statements = Parser::parse_sql(&dialect, sql).map_err(|e| SqlError::DataFusion {
1171 message: e.to_string(),
1172 })?;
1173 for stmt in &statements {
1174 let mut is_streaming = false;
1175 let _ = visit_relations(stmt, |relation| {
1176 let full = relation.to_string();
1179 let table_name = full.split('.').next_back().unwrap_or(&full);
1180 if sources.contains(table_name) {
1181 is_streaming = true;
1182 return ControlFlow::Break(());
1183 }
1184 ControlFlow::Continue(())
1185 });
1186 if is_streaming {
1187 return Ok(true);
1188 }
1189 }
1190 Ok(false)
1191 }
1192
1193 pub fn krishiv_catalog(&self) -> Option<&Arc<RwLock<InMemoryCatalog>>> {
1195 self.krishiv_catalog.as_ref()
1196 }
1197
/// Builder-style registration of an Iceberg catalog under `catalog_name`.
///
/// A bridge for the catalog is registered on the session context, and the
/// catalog and its name are appended to the engine's Iceberg catalog list.
#[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
#[must_use]
pub fn with_iceberg_catalog(
    self,
    catalog: Arc<catalog::unified::KrishivCatalog>,
    catalog_name: impl Into<String>,
) -> Self {
    let registered_name: String = catalog_name.into();

    let bridge = catalog::iceberg_catalog_bridge::IcebergCatalogBridge::new(
        Arc::clone(&catalog),
        registered_name.clone(),
    );
    self.context
        .register_catalog(registered_name.clone(), Arc::new(bridge));

    {
        let mut known_catalogs = self
            .iceberg_catalogs
            .write()
            .unwrap_or_else(|e| e.into_inner());
        known_catalogs.push((catalog, registered_name));
    }
    self
}
1226
1227 #[must_use]
1229 pub fn with_udf_registry(
1230 mut self,
1231 registry: std::sync::Arc<std::sync::RwLock<krishiv_plan::udf::UdfRegistry>>,
1232 ) -> Self {
1233 self.udf_registry = Some(registry);
1234 self.bump_udf_version();
1236 self
1237 }
1238
1239 pub(crate) fn bump_udf_version(&self) {
1242 self.udf_registry_version.fetch_add(1, Ordering::Release);
1243 }
1244
1245 fn invalidate_plan_cache(&self) {
1250 match self.plan_cache.lock() {
1251 Ok(mut cache) => cache.clear(),
1252 Err(poisoned) => poisoned.into_inner().clear(),
1253 }
1254 }
1255
1256 pub fn clear_plan_cache(&self) {
1259 self.invalidate_plan_cache();
1260 }
1261
1262 pub async fn sync_scalar_udfs(&self) -> SqlResult<()> {
1265 let Some(registry) = &self.udf_registry else {
1266 return Ok(());
1267 };
1268 let guard = registry.read().map_err(|e| SqlError::DataFusion {
1269 message: e.to_string(),
1270 })?;
1271 let limits = self.udf_limits.clone().unwrap_or_default();
1272 udf::sync_scalar_udfs_with_limits(&self.context, &guard, limits).map_err(|e| {
1273 SqlError::DataFusion {
1274 message: e.to_string(),
1275 }
1276 })
1277 }
1278
1279 pub async fn sync_scalar_udfs_with_limits(
1284 &self,
1285 limits: krishiv_plan::udf::ResourceLimits,
1286 ) -> SqlResult<()> {
1287 self.sync_scalar_udfs_with_limits_for_profile(
1288 limits,
1289 krishiv_common::resolve_durability_profile(),
1290 )
1291 .await
1292 }
1293
1294 pub async fn sync_scalar_udfs_with_limits_for_profile(
1296 &self,
1297 limits: krishiv_plan::udf::ResourceLimits,
1298 profile: krishiv_common::DurabilityProfile,
1299 ) -> SqlResult<()> {
1300 self.sync_scalar_udfs_with_limits_for_policy(
1301 limits,
1302 krishiv_common::NativeScalarUdfPolicy::resolve(profile),
1303 )
1304 .await
1305 }
1306
1307 pub async fn sync_scalar_udfs_with_limits_for_policy(
1309 &self,
1310 limits: krishiv_plan::udf::ResourceLimits,
1311 policy: krishiv_common::NativeScalarUdfPolicy,
1312 ) -> SqlResult<()> {
1313 let Some(registry) = &self.udf_registry else {
1314 return Ok(());
1315 };
1316 let guard = registry.read().map_err(|e| SqlError::DataFusion {
1317 message: e.to_string(),
1318 })?;
1319 udf::sync_scalar_udfs_with_limits_for_policy(&self.context, &guard, limits, policy).map_err(
1320 |e| SqlError::DataFusion {
1321 message: e.to_string(),
1322 },
1323 )
1324 }
1325
1326 pub async fn sync_aggregate_udfs(&self) -> SqlResult<()> {
1328 let Some(registry) = &self.udf_registry else {
1329 return Ok(());
1330 };
1331 let guard = registry.read().map_err(|e| SqlError::DataFusion {
1332 message: e.to_string(),
1333 })?;
1334 udf::sync_aggregate_udfs(&self.context, &guard).map_err(|e| SqlError::DataFusion {
1335 message: e.to_string(),
1336 })
1337 }
1338
1339 pub async fn sync_table_udfs(&self) -> SqlResult<()> {
1341 let Some(registry) = &self.udf_registry else {
1342 return Ok(());
1343 };
1344 let guard = registry.read().map_err(|e| SqlError::DataFusion {
1345 message: e.to_string(),
1346 })?;
1347 udf::sync_table_udfs(&self.context, &guard).map_err(|e| SqlError::DataFusion {
1348 message: e.to_string(),
1349 })
1350 }
1351
1352 pub async fn sync_all_udfs(&self) -> SqlResult<()> {
1354 self.sync_scalar_udfs().await?;
1355 self.sync_aggregate_udfs().await?;
1356 self.sync_table_udfs().await?;
1357 Ok(())
1358 }
1359
1360 pub async fn register_parquet(
1362 &self,
1363 table_name: impl AsRef<str>,
1364 path: impl AsRef<Path>,
1365 ) -> SqlResult<()> {
1366 let table_name = table_name.as_ref();
1367 if table_name.trim().is_empty() {
1368 return Err(SqlError::EmptyTableName);
1369 }
1370
1371 let path = path.as_ref().to_string_lossy().into_owned();
1372
1373 if path.starts_with("s3://") {
1376 let url = url::Url::parse(&path).map_err(|e| SqlError::DataFusion {
1377 message: format!("invalid s3 url {path}: {e}"),
1378 })?;
1379 let bucket = url.host_str().unwrap_or_default();
1380 let store_url =
1381 url::Url::parse(&format!("s3://{bucket}")).map_err(|e| SqlError::DataFusion {
1382 message: format!("invalid s3 bucket url: {e}"),
1383 })?;
1384 let store = AmazonS3Builder::from_env()
1385 .with_bucket_name(bucket)
1386 .build()
1387 .map_err(|e| SqlError::DataFusion {
1388 message: format!("s3 store init: {e}"),
1389 })?;
1390 self.context
1391 .register_object_store(&store_url, Arc::new(store));
1392 }
1393
1394 if self
1395 .context
1396 .table_exist(table_name)
1397 .map_err(SqlError::from)?
1398 {
1399 let _ = self
1400 .context
1401 .deregister_table(table_name)
1402 .map_err(SqlError::from)?;
1403 }
1404 self.context
1405 .register_parquet(table_name, path, ParquetReadOptions::default())
1406 .await?;
1407 if let Ok(provider) = self.context.table_provider(table_name).await
1409 && let Some(stats) = provider.statistics()
1410 && let Some(n) = stats.num_rows.get_value()
1411 && let Ok(mut counts) = self.table_row_counts.write()
1412 {
1413 counts.insert(table_name.to_string(), *n as u64);
1414 }
1415 self.invalidate_plan_cache();
1416 Ok(())
1417 }
1418
1419 pub async fn read_parquet(&self, path: impl AsRef<Path>) -> SqlResult<SqlDataFrame> {
1421 let path = path.as_ref().to_string_lossy().into_owned();
1422 let dataframe = self
1423 .context
1424 .read_parquet(path, ParquetReadOptions::default())
1425 .await?;
1426 Ok(self.make_sql_df("parquet-read", dataframe))
1427 }
1428
1429 pub async fn register_record_batches(
1435 &self,
1436 table_name: impl AsRef<str>,
1437 batches: Vec<RecordBatch>,
1438 ) -> SqlResult<()> {
1439 use std::sync::Arc;
1440 let table_name = table_name.as_ref();
1441 if table_name.trim().is_empty() {
1442 return Err(SqlError::EmptyTableName);
1443 }
1444 if batches.is_empty() {
1445 return Ok(());
1446 }
1447 let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
1448 let schema = batches
1449 .first()
1450 .ok_or_else(|| SqlError::DataFusion {
1451 message: "empty batch list".into(),
1452 })?
1453 .schema();
1454 let mem_table =
1455 datafusion::datasource::MemTable::try_new(schema, vec![batches]).map_err(|e| {
1456 SqlError::DataFusion {
1457 message: e.to_string(),
1458 }
1459 })?;
1460 if self
1461 .context
1462 .table_exist(table_name)
1463 .map_err(SqlError::from)?
1464 {
1465 let _ = self
1466 .context
1467 .deregister_table(table_name)
1468 .map_err(SqlError::from)?;
1469 }
1470 self.context
1471 .register_table(table_name, Arc::new(mem_table))
1472 .map_err(|e| SqlError::DataFusion {
1473 message: e.to_string(),
1474 })?;
1475 if total_rows > 0
1476 && let Ok(mut counts) = self.table_row_counts.write()
1477 {
1478 counts.insert(table_name.to_string(), total_rows as u64);
1479 }
1480 self.invalidate_plan_cache();
1481 Ok(())
1482 }
1483
1484 pub async fn read_parquet_with_options(
1486 &self,
1487 path: impl AsRef<Path>,
1488 opts: &ParquetReaderOptions,
1489 ) -> SqlResult<SqlDataFrame> {
1490 let path = path.as_ref().to_string_lossy().into_owned();
1491 let mut options = datafusion::prelude::ParquetReadOptions::default();
1492 if opts.batch_size.is_some() {
1493 options = options.parquet_pruning(true);
1494 }
1495 let dataframe = self.context.read_parquet(path, options).await?;
1501 Ok(self.make_sql_df("parquet-read", dataframe))
1502 }
1503
1504 pub async fn read_csv(&self, path: impl AsRef<Path>) -> SqlResult<SqlDataFrame> {
1506 self.read_csv_with_options(path, &CsvReaderOptions::default())
1507 .await
1508 }
1509
1510 pub async fn read_csv_with_options(
1512 &self,
1513 path: impl AsRef<Path>,
1514 opts: &CsvReaderOptions,
1515 ) -> SqlResult<SqlDataFrame> {
1516 let path = path.as_ref().to_string_lossy().into_owned();
1517 let mut options = datafusion::prelude::CsvReadOptions::new();
1518 if let Some(delim) = opts.delimiter {
1519 options = options.delimiter(delim as u8);
1520 }
1521 if let Some(has_header) = opts.has_header {
1522 options = options.has_header(has_header);
1523 }
1524 let dataframe = self.context.read_csv(path, options).await?;
1525 Ok(self.make_sql_df("csv-read", dataframe))
1526 }
1527
1528 pub async fn read_json(&self, path: impl AsRef<Path>) -> SqlResult<SqlDataFrame> {
1530 let path = path.as_ref().to_string_lossy().into_owned();
1531 let dataframe = self
1532 .context
1533 .read_json(path, datafusion::prelude::JsonReadOptions::default())
1534 .await?;
1535 Ok(self.make_sql_df("json-read", dataframe))
1536 }
1537
1538 pub async fn read_delta(
1540 &self,
1541 path: impl AsRef<str>,
1542 version: Option<i64>,
1543 ) -> SqlResult<SqlDataFrame> {
1544 let path = path.as_ref();
1545 let base = path.replace(['/', '.', '-'], "_");
1546 let table = match version {
1547 Some(v) => format!("delta_{base}_v{v}"),
1548 None => format!("delta_{base}"),
1549 };
1550 lakehouse::register_delta_uri(&self.context, &table, path, version).await?;
1551 self.sql(format!("SELECT * FROM {table}")).await
1552 }
1553
1554 pub async fn read_hudi(
1556 &self,
1557 path: impl AsRef<str>,
1558 query_type: krishiv_connectors::lakehouse::HudiQueryType,
1559 begin_instant: Option<&str>,
1560 ) -> SqlResult<SqlDataFrame> {
1561 let path = path.as_ref();
1562 let table = format!("hudi_{}", path.replace(['/', '.', '-'], "_"));
1563 lakehouse::register_hudi_uri(&self.context, &table, path, query_type, begin_instant)
1564 .await?;
1565 self.sql(format!("SELECT * FROM {table}")).await
1566 }
1567
1568 pub async fn sql(&self, query: impl AsRef<str>) -> SqlResult<SqlDataFrame> {
1570 let query = query.as_ref();
1571 if query.trim().is_empty() {
1572 return Err(SqlError::EmptyQuery);
1573 }
1574
1575 {
1579 let current = self.udf_registry_version.load(Ordering::Acquire);
1580 let last = self.udf_last_synced_version.load(Ordering::Relaxed);
1581 if current != last {
1582 self.sync_all_udfs().await?;
1583 self.udf_last_synced_version
1584 .store(current, Ordering::Release);
1585 }
1586 }
1587
1588 if let Some(stmt) = introspection_sql::parse_introspection_statement(query)? {
1590 return match stmt {
1591 introspection_sql::IntrospectionStatement::Describe { table } => {
1592 let batch = introspection_sql::describe_table(&self.context, &table).await?;
1593 let describe_table_name = next_ephemeral_name("describe_result");
1594 lakehouse::register_scan_batches(
1595 &self.context,
1596 &describe_table_name,
1597 vec![batch],
1598 )
1599 .await?;
1600 let dataframe = self
1601 .context
1602 .sql(&format!("SELECT * FROM {describe_table_name}"))
1603 .await?;
1604 Ok(self.attach_query_metadata(self.make_sql_df("describe", dataframe), query))
1605 }
1606 introspection_sql::IntrospectionStatement::Explain { mode, query: inner } => {
1607 let text = introspection_sql::explain_query(&inner, mode)?;
1608 let batch = introspection_sql::explain_result_batch(&text)?;
1609 let explain_table = next_ephemeral_name("explain_result");
1610 lakehouse::register_scan_batches(&self.context, &explain_table, vec![batch])
1611 .await?;
1612 let dataframe = self
1613 .context
1614 .sql(&format!("SELECT * FROM {explain_table}"))
1615 .await?;
1616 Ok(self.attach_query_metadata(self.make_sql_df("explain", dataframe), query))
1617 }
1618 };
1619 }
1620
1621 if live_table::execute_live_table_ddl(&self.live_table_registry, query)?.is_some() {
1623 let empty = self.context.sql("SELECT 1 WHERE FALSE").await?;
1624 return Ok(self.attach_query_metadata(self.make_sql_df("live-table-ddl", empty), query));
1625 }
1626
1627 if incremental_view::execute_incremental_view_ddl(&self.incremental_view_registry, query)?
1629 .is_some()
1630 {
1631 let empty = self.context.sql("SELECT 1 WHERE FALSE").await?;
1632 return Ok(
1633 self.attach_query_metadata(self.make_sql_df("incremental-view-ddl", empty), query)
1634 );
1635 }
1636
1637 if pipeline_ddl::execute_pipeline_ddl(&self.pipeline_registry, query)?.is_some() {
1641 let empty = self.context.sql("SELECT 1 WHERE FALSE").await?;
1642 return Ok(self.attach_query_metadata(self.make_sql_df("pipeline-ddl", empty), query));
1643 }
1644
1645 let trimmed = query.trim();
1648 if trimmed
1649 .to_ascii_uppercase()
1650 .starts_with("SET SHUFFLE.PARTITIONS")
1651 {
1652 let value = trimmed.split('=').nth(1).map(|s| s.trim()).unwrap_or("");
1653 match value.parse::<u32>() {
1654 Ok(n) if n > 0 => {
1655 {
1656 let mut guard =
1657 self.shuffle_partitions
1658 .write()
1659 .map_err(|e| SqlError::DataFusion {
1660 message: e.to_string(),
1661 })?;
1662 *guard = Some(n);
1663 }
1664 let empty = self.context.sql("SELECT 1 WHERE FALSE").await?;
1665 return Ok(self.make_sql_df("set-shuffle-partitions", empty));
1666 }
1667 Ok(_) => {
1668 {
1669 let mut guard =
1670 self.shuffle_partitions
1671 .write()
1672 .map_err(|e| SqlError::DataFusion {
1673 message: e.to_string(),
1674 })?;
1675 *guard = None;
1676 }
1677 let empty = self.context.sql("SELECT 1 WHERE FALSE").await?;
1678 return Ok(self.make_sql_df("set-shuffle-partitions", empty));
1679 }
1680 Err(_) => {
1681 return Err(SqlError::DataFusion {
1682 message: format!(
1683 "invalid shuffle.partitions value '{value}'; expected a positive integer"
1684 ),
1685 });
1686 }
1687 }
1688 }
1689
1690 if create_function_ddl::is_create_function_returns_table(query) {
1695 let ddl = create_function_ddl::parse_create_function(query)
1696 .map_err(|message| SqlError::InvalidTableFunction { message })?;
1697 if ddl.language.as_deref() != Some("sql") {
1698 return Err(SqlError::Unsupported {
1699 feature: format!(
1700 "CREATE FUNCTION '{}' uses language {:?}; only LANGUAGE SQL AS '...' \
1701 table functions are executable",
1702 ddl.function_name, ddl.language
1703 ),
1704 });
1705 }
1706 let body = ddl
1707 .body
1708 .as_deref()
1709 .filter(|body| !body.trim().is_empty())
1710 .ok_or_else(|| SqlError::InvalidTableFunction {
1711 message: format!(
1712 "SQL table function '{}' requires a non-empty AS body",
1713 ddl.function_name
1714 ),
1715 })?;
1716 let fields: Vec<_> = ddl
1717 .return_columns
1718 .iter()
1719 .map(|column| {
1720 arrow::datatypes::Field::new(&column.name, column.data_type.clone(), true)
1721 })
1722 .collect();
1723 let schema = arrow::datatypes::Schema::new(fields);
1724 let udf: std::sync::Arc<dyn krishiv_plan::udf::TableUdf> = std::sync::Arc::new(
1725 create_function_ddl::SqlBodyTableUdf::try_new(
1726 &ddl.function_name,
1727 schema,
1728 body,
1729 ddl.arguments.len(),
1730 std::sync::Arc::new(self.context.clone()),
1731 )
1732 .map_err(|error| SqlError::InvalidTableFunction {
1733 message: error.to_string(),
1734 })?,
1735 );
1736 if let Some(registry) = &self.udf_registry {
1737 let mut guard = registry.write().map_err(|e| SqlError::DataFusion {
1738 message: e.to_string(),
1739 })?;
1740 guard.register_table(std::sync::Arc::clone(&udf));
1741 }
1742 udf::register_single_table_udf(&self.context, std::sync::Arc::clone(&udf))
1743 .map_err(SqlError::from)?;
1744 let empty = self.context.sql("SELECT 1 WHERE FALSE").await?;
1745 return Ok(
1746 self.attach_query_metadata(self.make_sql_df("create-function", empty), query)
1747 );
1748 }
1749
1750 if query
1751 .trim_start()
1752 .to_ascii_uppercase()
1753 .starts_with("MERGE INTO")
1754 {
1755 let batches = lakehouse::execute_merge_sql(&self.context, query).await?;
1756 let merge_table = next_ephemeral_name("merge_result");
1757 lakehouse::register_scan_batches(&self.context, &merge_table, batches).await?;
1758 let dataframe = self
1759 .context
1760 .sql(&format!("SELECT * FROM {merge_table}"))
1761 .await?;
1762 return Ok(self.attach_query_metadata(self.make_sql_df("merge", dataframe), query));
1763 }
1764
1765 #[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
1768 if trimmed.to_ascii_uppercase().starts_with("CALL SYSTEM.") {
1769 let result = self.dispatch_call_system(trimmed).await?;
1770 let call_table = next_ephemeral_name("call_result");
1771 lakehouse::register_scan_batches(&self.context, &call_table, vec![result]).await?;
1772 let dataframe = self
1773 .context
1774 .sql(&format!("SELECT * FROM {call_table}"))
1775 .await?;
1776 return Ok(self.attach_query_metadata(self.make_sql_df("call", dataframe), query));
1777 }
1778
1779 #[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
1783 if trimmed.to_ascii_uppercase().starts_with("DELETE FROM ") {
1784 if let Some((table_ref, predicate)) = parse_dml_delete(trimmed) {
1785 if let Some((iceberg_catalog, table_ident)) = self.resolve_iceberg_table(&table_ref)
1786 {
1787 use arrow::array::{ArrayRef, Int64Array};
1788 use arrow::datatypes::{DataType, Field, Schema};
1789 let (deleted, _) = krishiv_connectors::lakehouse::dml::iceberg_delete_where(
1790 iceberg_catalog,
1791 &table_ident,
1792 &predicate,
1793 &self.context,
1794 )
1795 .await
1796 .map_err(|e| SqlError::DataFusion {
1797 message: e.to_string(),
1798 })?;
1799 let schema = Arc::new(Schema::new(vec![Field::new(
1800 "deleted_rows",
1801 DataType::Int64,
1802 false,
1803 )]));
1804 let array: ArrayRef = Arc::new(Int64Array::from(vec![deleted as i64]));
1805 let batch = RecordBatch::try_new(schema, vec![array]).map_err(|e| {
1806 SqlError::DataFusion {
1807 message: e.to_string(),
1808 }
1809 })?;
1810 let res_table = next_ephemeral_name("delete_result");
1811 lakehouse::register_scan_batches(&self.context, &res_table, vec![batch])
1812 .await?;
1813 let dataframe = self
1814 .context
1815 .sql(&format!("SELECT * FROM {res_table}"))
1816 .await?;
1817 return Ok(
1818 self.attach_query_metadata(self.make_sql_df("delete", dataframe), query)
1819 );
1820 }
1821 }
1822 }
1823
1824 #[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
1826 if trimmed.to_ascii_uppercase().starts_with("UPDATE ") {
1827 if let Some(parsed) = parse_dml_update(trimmed) {
1828 if let Some((iceberg_catalog, table_ident)) =
1829 self.resolve_iceberg_table(&parsed.table_ref)
1830 {
1831 use arrow::array::{ArrayRef, Int64Array};
1832 use arrow::datatypes::{DataType, Field, Schema};
1833 let borrowed: Vec<(&str, &str)> = parsed
1834 .assignments
1835 .iter()
1836 .map(|(c, e)| (c.as_str(), e.as_str()))
1837 .collect();
1838 let pred = parsed.predicate.as_deref();
1839 let (updated, _) = krishiv_connectors::lakehouse::dml::iceberg_update_where(
1840 iceberg_catalog,
1841 &table_ident,
1842 &borrowed,
1843 pred,
1844 &self.context,
1845 )
1846 .await
1847 .map_err(|e| SqlError::DataFusion {
1848 message: e.to_string(),
1849 })?;
1850 let schema = Arc::new(Schema::new(vec![Field::new(
1851 "updated_rows",
1852 DataType::Int64,
1853 false,
1854 )]));
1855 let array: ArrayRef = Arc::new(Int64Array::from(vec![updated as i64]));
1856 let batch = RecordBatch::try_new(schema, vec![array]).map_err(|e| {
1857 SqlError::DataFusion {
1858 message: e.to_string(),
1859 }
1860 })?;
1861 let res_table = next_ephemeral_name("update_result");
1862 lakehouse::register_scan_batches(&self.context, &res_table, vec![batch])
1863 .await?;
1864 let dataframe = self
1865 .context
1866 .sql(&format!("SELECT * FROM {res_table}"))
1867 .await?;
1868 return Ok(
1869 self.attach_query_metadata(self.make_sql_df("update", dataframe), query)
1870 );
1871 }
1872 }
1873 }
1874
1875 if query.to_ascii_uppercase().contains(" MATCH_RECOGNIZE ")
1879 && let Some(stmt) = cep_sql::parse_match_recognize(query)?
1880 {
1881 let is_streaming = self.is_streaming_source(&stmt.source_table);
1882 let streaming_limit = streaming_match_recognize_limit_from_env();
1890 let source_sql = if is_streaming {
1891 format!(
1892 "SELECT * FROM {} LIMIT {}",
1893 stmt.source_table, streaming_limit
1894 )
1895 } else {
1896 format!("SELECT * FROM {}", stmt.source_table)
1897 };
1898 let source_df = self.context.sql(&source_sql).await?;
1899 let source_batches = source_df.collect().await?;
1900 if is_streaming {
1901 tracing::warn!(
1902 source = %stmt.source_table,
1903 limit = streaming_limit,
1904 collected_rows = source_batches.iter().map(|b| b.num_rows()).sum::<usize>(),
1905 "MATCH_RECOGNIZE executed against a streaming source under \
1906 bounded materialisation; results only cover the first {0} rows \
1907 of the source. Set KRISHIV_MATCH_RECOGNIZE_STREAMING_LIMIT to a \
1908 larger value if your executor has the memory budget.",
1909 streaming_limit
1910 );
1911 }
1912 let results = cep_sql::execute_match_recognize(stmt, &source_batches)?;
1913 let cep_table = next_ephemeral_name("cep_result");
1914 lakehouse::register_scan_batches(&self.context, &cep_table, results).await?;
1915 let dataframe = self
1916 .context
1917 .sql(&format!("SELECT * FROM {cep_table}"))
1918 .await?;
1919 return Ok(self.attach_query_metadata(self.make_sql_df("cep", dataframe), query));
1920 }
1921
1922 let query = &pivot_sql::rewrite_pivot_unpivot(query)?;
1925
1926 let query = &streaming_tvf::rewrite_window_tvfs(query);
1928
1929 let (rewritten, as_ofs) =
1930 lakehouse::preprocess_as_of_sql(query).unwrap_or_else(|_| (query.to_string(), vec![]));
1931 lakehouse::apply_as_of_refs(&self.context, &as_ofs).await?;
1932
1933 let can_cache = as_ofs.is_empty();
1940 let shuffle_override = self
1941 .shuffle_partitions
1942 .read()
1943 .map(|g| *g)
1944 .unwrap_or_else(|e| *e.into_inner());
1945 if can_cache {
1946 let cached_plan: Option<datafusion::logical_expr::LogicalPlan> = self
1948 .plan_cache
1949 .lock()
1950 .unwrap_or_else(|e| e.into_inner())
1951 .get(&rewritten)
1952 .cloned();
1953 if let Some(plan) = cached_plan {
1954 let dataframe = self.context.execute_logical_plan(plan).await?;
1955 return Ok(self.attach_query_metadata(
1956 self.make_sql_df("sql-query", dataframe)
1957 .with_shuffle_partitions(shuffle_override),
1958 &rewritten,
1959 ));
1960 }
1961 }
1962
1963 let dataframe = self.context.sql(&rewritten).await?;
1964
1965 if let Some(table_name) = extract_create_external_table_name(&rewritten)
1969 && !table_name.is_empty()
1970 && let Ok(provider) = self.context.table_provider(&table_name).await
1971 {
1972 let maybe_rows = provider
1973 .statistics()
1974 .and_then(|s| s.num_rows.get_value().copied());
1975 if let Some(n) = maybe_rows
1976 && let Ok(mut counts) = self.table_row_counts.write()
1977 {
1978 counts.entry(table_name).or_insert(n as u64);
1979 }
1980 }
1981
1982 if can_cache {
1984 let plan = dataframe.logical_plan().clone();
1985 match self.plan_cache.lock() {
1986 Ok(mut cache) => cache.insert(rewritten.clone(), plan),
1987 Err(poisoned) => poisoned.into_inner().insert(rewritten.clone(), plan),
1988 }
1989 }
1990
1991 Ok(self.attach_query_metadata(
1992 self.make_sql_df("sql-query", dataframe)
1993 .with_shuffle_partitions(shuffle_override),
1994 &rewritten,
1995 ))
1996 }
1997
1998 pub async fn execute_with_timeout(
2005 &self,
2006 query: impl AsRef<str> + Send,
2007 timeout_ms: u64,
2008 ) -> SqlResult<SqlDataFrame> {
2009 let timeout = std::time::Duration::from_millis(timeout_ms);
2010 tokio::time::timeout(timeout, self.sql(query))
2011 .await
2012 .map_err(|_| SqlError::Timeout { timeout_ms })?
2013 }
2014
2015 pub async fn execute_with_operation_id(
2022 &self,
2023 operation_id: u64,
2024 query: impl AsRef<str> + Send,
2025 cancelled_ids: &OperationRegistry,
2026 ) -> SqlResult<TaggedQueryResult> {
2027 if cancelled_ids.is_cancelled(operation_id) {
2028 return Err(SqlError::OperationCancelled { operation_id });
2029 }
2030 let df = self.sql(query).await?;
2031 Ok(TaggedQueryResult {
2032 operation_id,
2033 inner: df,
2034 })
2035 }
2036
/// Resolves a dotted table reference to an Iceberg catalog and table identifier.
///
/// Accepted shapes (split on `.`, at most three segments):
/// * `namespace.table` — uses the first registered catalog.
/// * `catalog.namespace.table` — uses the registered catalog whose name equals
///   `catalog`.
///
/// Returns `None` when no catalog is registered, the reference has any other
/// shape, the named catalog is unknown, or the namespace identifier cannot be
/// built. A poisoned catalog lock is read through rather than treated as an
/// error.
#[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
fn resolve_iceberg_table(
    &self,
    table_ref: &str,
) -> Option<(Arc<dyn iceberg::Catalog + Send + Sync>, iceberg::TableIdent)> {
    let segments: Vec<&str> = table_ref.splitn(3, '.').collect();
    let (catalog, namespace, table) = {
        let catalogs = self
            .iceberg_catalogs
            .read()
            .unwrap_or_else(|e| e.into_inner());
        if catalogs.is_empty() {
            return None;
        }
        match segments.as_slice() {
            &[namespace, table] => {
                let (first_catalog, _) = catalogs.first()?;
                (Arc::clone(first_catalog), namespace, table)
            }
            &[cat_name, namespace, table] => {
                let (named_catalog, _) = catalogs.iter().find(|(_, n)| n == cat_name)?;
                (Arc::clone(named_catalog), namespace, table)
            }
            _ => return None,
        }
    };
    let namespace_ident = iceberg::NamespaceIdent::from_vec(vec![namespace.to_string()]).ok()?;
    let table_ident = iceberg::TableIdent::new(namespace_ident, table.to_string());
    Some((catalog.as_iceberg(), table_ident))
}
2073
/// Executes a `CALL SYSTEM.<procedure>(<args>)` statement against the first
/// registered Iceberg catalog and returns a one-row, one-column batch with the
/// resulting count.
///
/// `stmt` must begin with `CALL SYSTEM.` (any ASCII case): the prefix is
/// sliced off by length without being re-checked here.
///
/// Supported procedures, their arguments after the table reference, and the
/// name of the result column:
/// * `EXPIRE_SNAPSHOTS(table, older_than[, retain_last])` → `expired_snapshots`
///   (`retain_last` falls back to `1` when absent or unparsable).
/// * `REMOVE_ORPHAN_FILES(table, older_than)` → `removed_files`.
/// * `COMPACT_DATA_FILES(table[, target_bytes])` → `rewritten_files`
///   (`target_bytes` falls back to 128 MiB when absent or unparsable).
///
/// # Errors
/// * `SqlError::DataFusion` when `(` is missing, no Iceberg catalog is
///   registered, a required argument is absent, or the maintenance call fails.
/// * `SqlError::Unsupported` for an unknown procedure name.
/// * Errors from `iceberg_table_ident` and `parse_call_duration`.
#[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
async fn dispatch_call_system(&self, stmt: &str) -> SqlResult<RecordBatch> {
    use arrow::array::{ArrayRef, Int64Array};
    use arrow::datatypes::{DataType, Field, Schema};

    const PREFIX: &str = "CALL SYSTEM.";

    // Every lower-level failure is reported through the same variant.
    fn wrap_err<E: ToString>(source: E) -> SqlError {
        SqlError::DataFusion {
            message: source.to_string(),
        }
    }

    // The procedure name is matched case-insensitively, while the arguments
    // are taken from the original text so their case is preserved.
    let uppercased = stmt.to_ascii_uppercase();
    let upper_tail = &uppercased[PREFIX.len()..];
    let original_tail = &stmt[PREFIX.len()..];

    let Some(open_paren) = upper_tail.find('(') else {
        return Err(SqlError::DataFusion {
            message: format!("CALL: missing '(' in: {stmt}"),
        });
    };
    let proc_name = upper_tail[..open_paren].trim();

    let args_raw = original_tail[open_paren + 1..]
        .trim_end_matches(';')
        .trim()
        .trim_end_matches(')')
        .trim();
    let args = call_args_from_str(args_raw);

    let iceberg_catalog = {
        let catalogs = self
            .iceberg_catalogs
            .read()
            .unwrap_or_else(|e| e.into_inner());
        match catalogs.first() {
            Some(entry) => entry.0.as_iceberg(),
            None => {
                return Err(SqlError::DataFusion {
                    message: "CALL system: no Iceberg catalog registered".to_string(),
                });
            }
        }
    };

    let Some(table_ref) = args.first() else {
        return Err(SqlError::DataFusion {
            message: format!("CALL {proc_name}: table reference argument is required"),
        });
    };
    let table_ident = iceberg_table_ident(table_ref)?;

    // Each arm yields the result column name together with the count.
    let (col, count): (&str, i64) = match proc_name {
        "EXPIRE_SNAPSHOTS" => {
            let dur_s = args.get(1).ok_or_else(|| SqlError::DataFusion {
                message: "CALL expire_snapshots: duration argument is required".to_string(),
            })?;
            let older_than = parse_call_duration(dur_s)?;
            let retain_last = args
                .get(2)
                .and_then(|s| s.parse::<usize>().ok())
                .unwrap_or(1);
            let expired = krishiv_connectors::lakehouse::maintenance::expire_snapshots(
                iceberg_catalog,
                &table_ident,
                older_than,
                retain_last,
            )
            .await
            .map_err(wrap_err)?;
            ("expired_snapshots", expired as i64)
        }
        "REMOVE_ORPHAN_FILES" => {
            let dur_s = args.get(1).ok_or_else(|| SqlError::DataFusion {
                message: "CALL remove_orphan_files: duration argument is required".to_string(),
            })?;
            let older_than = parse_call_duration(dur_s)?;
            let removed = krishiv_connectors::lakehouse::maintenance::remove_orphan_files(
                iceberg_catalog,
                &table_ident,
                older_than,
            )
            .await
            .map_err(wrap_err)?;
            ("removed_files", removed as i64)
        }
        "COMPACT_DATA_FILES" => {
            let target_bytes = args
                .get(1)
                .and_then(|s| s.parse::<u64>().ok())
                .unwrap_or(128 * 1024 * 1024);
            let rewritten = krishiv_connectors::lakehouse::maintenance::compact_data_files(
                iceberg_catalog,
                &table_ident,
                target_bytes,
            )
            .await
            .map_err(wrap_err)?;
            ("rewritten_files", rewritten as i64)
        }
        other => {
            return Err(SqlError::Unsupported {
                feature: format!("CALL system.{other}: unknown procedure"),
            });
        }
    };

    let schema = Arc::new(Schema::new(vec![Field::new(col, DataType::Int64, false)]));
    let array: ArrayRef = Arc::new(Int64Array::from(vec![count]));
    RecordBatch::try_new(schema, vec![array]).map_err(wrap_err)
}
2187}
2188
2189pub struct TaggedQueryResult {
2191 pub operation_id: u64,
2193 pub inner: SqlDataFrame,
2195}
2196
/// Shared, cheaply clonable registry of cancelled operations and per-operation
/// progress counters.
///
/// Clones share the same underlying state. A poisoned lock (a thread panicked
/// while holding it) is recovered rather than ignored: the protected
/// collections hold plain integers and remain usable, and silently dropping a
/// cancellation request would be worse than reading through the poison.
#[derive(Clone, Default)]
pub struct OperationRegistry {
    /// Identifiers of operations that have been cancelled.
    cancelled: Arc<std::sync::RwLock<std::collections::HashSet<u64>>>,
    /// Latest `(rows_scanned, rows_emitted)` pair reported per operation.
    progress: Arc<std::sync::RwLock<std::collections::HashMap<u64, (u64, u64)>>>,
}

impl OperationRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Marks `operation_id` as cancelled.
    pub fn cancel(&self, operation_id: u64) {
        self.cancelled
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .insert(operation_id);
    }

    /// Returns `true` when `operation_id` has been cancelled and not removed.
    pub fn is_cancelled(&self, operation_id: u64) -> bool {
        self.cancelled
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .contains(&operation_id)
    }

    /// Forgets `operation_id`: clears both its cancellation mark and its
    /// recorded progress.
    pub fn remove(&self, operation_id: u64) {
        self.cancelled
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .remove(&operation_id);
        self.progress
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .remove(&operation_id);
    }

    /// Records the latest progress counters for `operation_id`, replacing any
    /// previously stored pair.
    pub fn update_progress(&self, operation_id: u64, rows_scanned: u64, rows_emitted: u64) {
        self.progress
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .insert(operation_id, (rows_scanned, rows_emitted));
    }

    /// Returns the last recorded `(rows_scanned, rows_emitted)` pair for
    /// `operation_id`, or `None` when nothing has been recorded.
    pub fn progress(&self, operation_id: u64) -> Option<(u64, u64)> {
        self.progress
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .get(&operation_id)
            .copied()
    }

    /// Returns the identifiers of all currently cancelled operations, in
    /// unspecified order.
    pub fn cancelled_ids(&self) -> Vec<u64> {
        self.cancelled
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .iter()
            .copied()
            .collect()
    }
}
2264
2265pub(crate) fn extract_create_external_table_name(query: &str) -> Option<String> {
2270 use datafusion::sql::parser::{DFParser, Statement as DFStatement};
2271 let mut stmts = DFParser::parse_sql(query).ok()?;
2272 match stmts.pop_front()? {
2273 DFStatement::CreateExternalTable(create) => Some(create.name.to_string()),
2274 _ => None,
2275 }
2276}
2277
2278pub enum GroupingMode<'a> {
2286 Sets(Vec<Vec<&'a krishiv_plan::expression::Expr>>),
2287 Cube(Vec<&'a krishiv_plan::expression::Expr>),
2288 Rollup(Vec<&'a krishiv_plan::expression::Expr>),
2289}
2290
2291#[async_trait::async_trait]
2292pub trait KrishivDataFrameOps: Send + Sync {
2293 async fn collect(&self) -> SqlResult<Vec<RecordBatch>>;
2295 async fn collect_with_stats(&self) -> SqlResult<(Vec<RecordBatch>, SqlExecutionStats)>;
2297 async fn explain(&self) -> SqlResult<String>;
2299 fn explain_logical(&self) -> String;
2301 fn krishiv_logical_plan(&self) -> LogicalPlan;
2303 fn query(&self) -> Option<&str>;
2305 async fn execute_stream(&self) -> SqlResult<SqlStream>;
2307
2308 fn schema(&self) -> SchemaRef;
2312
2313 async fn select(&self, columns: &[&str]) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
2315
2316 async fn select_exprs(
2318 &self,
2319 expressions: &[&krishiv_plan::expression::Expr],
2320 ) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
2321
2322 async fn aggregate(
2324 &self,
2325 group_exprs: &[&krishiv_plan::expression::Expr],
2326 aggregate_exprs: &[&krishiv_plan::expression::Expr],
2327 ) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
2328
2329 async fn aggregate_grouping(
2331 &self,
2332 grouping: GroupingMode<'_>,
2333 aggregate_exprs: &[&krishiv_plan::expression::Expr],
2334 ) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
2335
2336 async fn pivot(
2338 &self,
2339 group_exprs: &[&krishiv_plan::expression::Expr],
2340 pivot_column: &krishiv_plan::expression::Expr,
2341 aggregate_expr: &krishiv_plan::expression::Expr,
2342 values: &[(krishiv_plan::expression::ScalarValue, String)],
2343 ) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
2344
2345 async fn unpivot(
2347 &self,
2348 columns: &[&str],
2349 name_column: &str,
2350 value_column: &str,
2351 ) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
2352
2353 async fn filter(&self, predicate: &str) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
2355
2356 async fn filter_expr(
2358 &self,
2359 predicate: &krishiv_plan::expression::Expr,
2360 ) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
2361
2362 async fn limit(&self, n: usize) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
2364
2365 async fn distinct(&self) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
2367
2368 async fn drop_nulls(&self, columns: &[&str]) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
2370
2371 async fn sample(&self, fraction: f64) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
2373
2374 async fn sort(
2376 &self,
2377 columns: &[&str],
2378 descending: &[bool],
2379 ) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
2380
2381 async fn alias(&self, alias: &str) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
2383
2384 async fn drop_columns(&self, columns: &[&str]) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
2386
2387 async fn rename_column(&self, old: &str, new: &str) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
2389
2390 async fn with_column(&self, name: &str, expr: &str) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
2392
2393 fn as_any(&self) -> &dyn std::any::Any;
2395
2396 async fn describe(&self) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
2398
2399 async fn fill_null(&self, column: &str, value: &str)
2401 -> SqlResult<Box<dyn KrishivDataFrameOps>>;
2402
2403 async fn join(
2405 &self,
2406 right: &dyn KrishivDataFrameOps,
2407 how: &str,
2408 left_on: &[&str],
2409 right_on: &[&str],
2410 ) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
2411
2412 async fn union(
2414 &self,
2415 right: &dyn KrishivDataFrameOps,
2416 ) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
2417
2418 async fn union_distinct(
2419 &self,
2420 right: &dyn KrishivDataFrameOps,
2421 ) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
2422
2423 async fn intersect(
2424 &self,
2425 right: &dyn KrishivDataFrameOps,
2426 distinct: bool,
2427 ) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
2428
2429 async fn except(
2430 &self,
2431 right: &dyn KrishivDataFrameOps,
2432 distinct: bool,
2433 ) -> SqlResult<Box<dyn KrishivDataFrameOps>>;
2434
2435 async fn register_batches(&self, name: &str, batches: Vec<RecordBatch>) -> SqlResult<()>;
2438
2439 async fn deregister_table(&self, name: &str) -> SqlResult<()>;
2441
2442 async fn create_view(&self, name: &str, replace: bool) -> SqlResult<()>;
2445}
2446
2447fn df_plan_to_krishiv_nodes(
2455 plan: &datafusion::logical_expr::LogicalPlan,
2456 table_row_counts: &std::collections::HashMap<String, u64>,
2457 counter: &mut usize,
2458) -> (Vec<krishiv_plan::PlanNode>, String) {
2459 use datafusion::logical_expr::LogicalPlan as DfPlan;
2460 use krishiv_plan::{ExecutionKind, NodeOp, PlanNode};
2461
2462 *counter += 1;
2463 let idx = *counter;
2464
2465 match plan {
2466 DfPlan::TableScan(ts) => {
2467 let table_name = ts.table_name.table().to_string();
2468 let row_count = table_row_counts.get(&table_name).copied();
2469 let filters: Vec<String> = ts.filters.iter().map(|e| e.to_string()).collect();
2470 let id = format!("scan-{idx}");
2471 let node = PlanNode::new(&id, format!("Scan {table_name}"), ExecutionKind::Batch)
2472 .with_op(NodeOp::Scan {
2473 table: table_name,
2474 filters,
2475 })
2476 .with_estimated_rows(row_count);
2477 (vec![node], id)
2478 }
2479
2480 DfPlan::Projection(proj) => {
2481 let (mut nodes, input_id) =
2482 df_plan_to_krishiv_nodes(&proj.input, table_row_counts, counter);
2483 let id = format!("proj-{idx}");
2484 let columns: Vec<String> = proj.expr.iter().map(|e| e.to_string()).collect();
2485 nodes.push(
2486 PlanNode::new(&id, "Projection", ExecutionKind::Batch)
2487 .with_op(NodeOp::Project { columns })
2488 .with_inputs([input_id]),
2489 );
2490 (nodes, id)
2491 }
2492
2493 DfPlan::Filter(filter) => {
2494 let (mut nodes, input_id) =
2495 df_plan_to_krishiv_nodes(&filter.input, table_row_counts, counter);
2496 let id = format!("filter-{idx}");
2497 let predicate = filter.predicate.to_string();
2498 nodes.push(
2499 PlanNode::new(&id, "Filter", ExecutionKind::Batch)
2500 .with_op(NodeOp::Filter { predicate })
2501 .with_inputs([input_id]),
2502 );
2503 (nodes, id)
2504 }
2505
2506 DfPlan::Aggregate(agg) => {
2507 let (mut nodes, input_id) =
2508 df_plan_to_krishiv_nodes(&agg.input, table_row_counts, counter);
2509 let id = format!("agg-{idx}");
2510 let group_keys: Vec<String> = agg.group_expr.iter().map(|e| e.to_string()).collect();
2511 nodes.push(
2512 PlanNode::new(&id, "Aggregate", ExecutionKind::Batch)
2513 .with_op(NodeOp::Aggregate { group_keys })
2514 .with_inputs([input_id]),
2515 );
2516 (nodes, id)
2517 }
2518
2519 DfPlan::Join(join) => {
2520 let (mut nodes, left_id) =
2521 df_plan_to_krishiv_nodes(&join.left, table_row_counts, counter);
2522 let (right_nodes, right_id) =
2523 df_plan_to_krishiv_nodes(&join.right, table_row_counts, counter);
2524 nodes.extend(right_nodes);
2525 let id = format!("join-{idx}");
2526 let krishiv_join_type = match join.join_type {
2531 datafusion::common::JoinType::Inner => krishiv_plan::JoinType::Inner,
2532 datafusion::common::JoinType::Left => krishiv_plan::JoinType::Left,
2533 datafusion::common::JoinType::Right => krishiv_plan::JoinType::Right,
2534 datafusion::common::JoinType::Full => krishiv_plan::JoinType::Full,
2535 datafusion::common::JoinType::LeftSemi => krishiv_plan::JoinType::LeftSemi,
2536 datafusion::common::JoinType::RightSemi => krishiv_plan::JoinType::RightSemi,
2537 datafusion::common::JoinType::LeftAnti => krishiv_plan::JoinType::LeftAnti,
2538 datafusion::common::JoinType::RightAnti => krishiv_plan::JoinType::RightAnti,
2539 datafusion::common::JoinType::LeftMark => krishiv_plan::JoinType::LeftSemi,
2543 datafusion::common::JoinType::RightMark => krishiv_plan::JoinType::RightSemi,
2544 };
2545 nodes.push(
2546 PlanNode::new(&id, "Join", ExecutionKind::Batch)
2547 .with_op(NodeOp::Join {
2548 join_type: krishiv_join_type,
2549 })
2550 .with_inputs([left_id, right_id]),
2551 );
2552 (nodes, id)
2553 }
2554
2555 DfPlan::Sort(sort) => {
2556 let (mut nodes, input_id) =
2557 df_plan_to_krishiv_nodes(&sort.input, table_row_counts, counter);
2558 let id = format!("sort-{idx}");
2559 nodes.push(
2560 PlanNode::new(&id, "Sort", ExecutionKind::Batch)
2561 .with_op(NodeOp::Other {
2562 description: format!(
2563 "Sort({})",
2564 sort.expr
2565 .iter()
2566 .map(|e| e.to_string())
2567 .collect::<Vec<_>>()
2568 .join(", ")
2569 ),
2570 })
2571 .with_inputs([input_id]),
2572 );
2573 (nodes, id)
2574 }
2575
2576 DfPlan::Repartition(repart) => {
2577 let (mut nodes, input_id) =
2578 df_plan_to_krishiv_nodes(&repart.input, table_row_counts, counter);
2579 let id = format!("exchange-{idx}");
2580 let partitioning = krishiv_plan::Partitioning::Unpartitioned;
2581 nodes.push(
2582 PlanNode::new(&id, "Exchange", ExecutionKind::Batch)
2583 .with_op(NodeOp::Exchange { partitioning })
2584 .with_inputs([input_id]),
2585 );
2586 (nodes, id)
2587 }
2588
2589 DfPlan::Limit(limit) => {
2590 let (mut nodes, input_id) =
2591 df_plan_to_krishiv_nodes(&limit.input, table_row_counts, counter);
2592 let id = format!("limit-{idx}");
2593 nodes.push(
2594 PlanNode::new(&id, "Limit", ExecutionKind::Batch)
2595 .with_op(NodeOp::Other {
2596 description: format!(
2597 "Limit(skip={:?}, fetch={:?})",
2598 limit.skip.as_ref().map(|e| e.to_string()),
2599 limit.fetch.as_ref().map(|e| e.to_string()),
2600 ),
2601 })
2602 .with_inputs([input_id]),
2603 );
2604 (nodes, id)
2605 }
2606
2607 DfPlan::Union(union) if union.inputs.len() == 1 => {
2608 if let Some(input) = union.inputs.first() {
2609 df_plan_to_krishiv_nodes(input, table_row_counts, counter)
2610 } else {
2611 (Vec::new(), String::new())
2612 }
2613 }
2614 DfPlan::Union(union) => {
2615 let mut all_nodes = Vec::new();
2616 let mut input_ids = Vec::new();
2617 for input in &union.inputs {
2618 let (sub_nodes, sub_id) =
2619 df_plan_to_krishiv_nodes(input, table_row_counts, counter);
2620 all_nodes.extend(sub_nodes);
2621 input_ids.push(sub_id);
2622 }
2623 let id = format!("union-{idx}");
2624 all_nodes.push(
2625 PlanNode::new(&id, "Union", ExecutionKind::Batch)
2626 .with_op(NodeOp::Other {
2627 description: "Union".to_string(),
2628 })
2629 .with_inputs(input_ids),
2630 );
2631 (all_nodes, id)
2632 }
2633
2634 DfPlan::SubqueryAlias(alias) => {
2635 df_plan_to_krishiv_nodes(&alias.input, table_row_counts, counter)
2637 }
2638
2639 DfPlan::Values(_) => {
2640 let id = format!("values-{idx}");
2641 let node = PlanNode::new(&id, "Values", ExecutionKind::Batch).with_op(NodeOp::Other {
2642 description: "Values".to_string(),
2643 });
2644 (vec![node], id)
2645 }
2646
2647 DfPlan::Extension(_) => {
2648 let id = format!("ext-{idx}");
2649 let label = plan.to_string();
2650 let node = PlanNode::new(&id, label.clone(), ExecutionKind::Batch)
2651 .with_op(NodeOp::Other { description: label });
2652 (vec![node], id)
2653 }
2654
2655 DfPlan::EmptyRelation(_) => {
2656 let id = format!("empty-{idx}");
2657 let node =
2658 PlanNode::new(&id, "EmptyRelation", ExecutionKind::Batch).with_op(NodeOp::Other {
2659 description: "EmptyRelation".to_string(),
2660 });
2661 (vec![node], id)
2662 }
2663
2664 _ => {
2666 let id = format!("df-{idx}");
2667 let label = plan.to_string();
2668 let node = PlanNode::new(&id, label.clone(), ExecutionKind::Batch)
2669 .with_op(NodeOp::Other { description: label });
2670 (vec![node], id)
2671 }
2672 }
2673}
2674
/// A DataFusion-backed data frame together with the Krishiv metadata needed to
/// describe, explain and execute it.
#[derive(Clone)]
pub struct SqlDataFrame {
    /// Name used for the Krishiv logical plan; derived frames append `-<operation>`.
    name: String,
    /// SQL text this frame was created from, if any (exposed through `query()`).
    query: Option<String>,
    /// Set to the same text as `query` by `with_query`.
    // NOTE(review): no reader of this field is visible in this section — confirm it is still needed.
    query_text: Option<String>,
    /// Execution kind assigned to the Krishiv logical plan built by the inherent
    /// `krishiv_logical_plan`.
    execution_kind: ExecutionKind,
    /// The underlying DataFusion data frame.
    dataframe: DataFusionDataFrame,
    /// Optional shuffle partition count forwarded to the plan built by the trait implementation.
    shuffle_partitions: Option<u32>,
    /// Session context associated with this frame; a default context until
    /// `with_context` replaces it.
    context: SessionContext,
    /// Shared table-name → row-count map used to annotate scan nodes with row estimates.
    table_row_counts: Arc<std::sync::RwLock<HashMap<String, u64>>>,
}
2692
2693impl fmt::Debug for SqlDataFrame {
2694 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2695 f.debug_struct("SqlDataFrame")
2696 .field("name", &self.name)
2697 .field("query", &self.query)
2698 .field("shuffle_partitions", &self.shuffle_partitions)
2699 .finish_non_exhaustive()
2700 }
2701}
2702
2703impl SqlDataFrame {
2704 fn new(
2705 name: impl Into<String>,
2706 dataframe: DataFusionDataFrame,
2707 table_row_counts: Arc<std::sync::RwLock<HashMap<String, u64>>>,
2708 ) -> Self {
2709 Self {
2710 name: name.into(),
2711 query: None,
2712 query_text: None,
2713 execution_kind: ExecutionKind::Batch,
2714 dataframe,
2715 shuffle_partitions: None,
2716 context: SessionContext::default(),
2717 table_row_counts,
2718 }
2719 }
2720
2721 pub(crate) fn with_context(mut self, context: SessionContext) -> Self {
2723 self.context = context;
2724 self
2725 }
2726
2727 fn with_query(mut self, query: impl Into<String>) -> Self {
2728 let q = query.into();
2729 self.query_text = Some(q.clone());
2730 self.query = Some(q);
2731 self
2732 }
2733
2734 fn with_execution_kind(mut self, kind: ExecutionKind) -> Self {
2735 self.execution_kind = kind;
2736 self
2737 }
2738
2739 fn with_shuffle_partitions(mut self, n: Option<u32>) -> Self {
2740 self.shuffle_partitions = n;
2741 self
2742 }
2743
2744 pub fn query(&self) -> Option<&str> {
2746 self.query.as_deref()
2747 }
2748
2749 pub fn arrow_schema(&self) -> arrow::datatypes::SchemaRef {
2755 std::sync::Arc::new(self.dataframe.schema().as_arrow().clone())
2756 }
2757
2758 fn with_new_dataframe(&self, df: DataFusionDataFrame, tag: &str) -> Self {
2762 Self {
2763 name: format!("{}-{}", self.name, tag),
2764 query: None,
2765 query_text: None,
2766 execution_kind: self.execution_kind,
2767 dataframe: df,
2768 shuffle_partitions: self.shuffle_partitions,
2769 context: self.context.clone(),
2770 table_row_counts: self.table_row_counts.clone(),
2771 }
2772 }
2773
2774 pub fn krishiv_logical_plan(&self) -> LogicalPlan {
2783 let df_plan = self.dataframe.logical_plan();
2784 let counts = self
2785 .table_row_counts
2786 .read()
2787 .unwrap_or_else(|e| e.into_inner());
2788 let mut counter = 0usize;
2789 let (nodes, _root_id) = df_plan_to_krishiv_nodes(df_plan, &counts, &mut counter);
2790
2791 let mut plan = LogicalPlan::new(self.name.clone(), self.execution_kind);
2792 for node in nodes {
2793 plan = plan.with_node(node);
2794 }
2795
2796 let optimizer = krishiv_plan::optimizer::default_logical_optimizer();
2801 let fallback = plan.clone();
2802 match optimizer.optimize(plan) {
2803 Ok(result) => result.plan,
2804 Err(error) => {
2805 tracing::warn!(
2806 plan = %self.name,
2807 %error,
2808 "logical optimizer failed; using unoptimized plan"
2809 );
2810 fallback
2811 }
2812 }
2813 }
2814
2815 pub fn explain_logical(&self) -> String {
2817 self.dataframe.logical_plan().to_string()
2818 }
2819
2820 pub async fn explain(&self) -> SqlResult<String> {
2822 let batches = self
2823 .dataframe
2824 .clone()
2825 .explain(false, false)?
2826 .collect()
2827 .await?;
2828 pretty_batches(&batches)
2829 }
2830
2831 pub async fn collect(&self) -> SqlResult<Vec<RecordBatch>> {
2833 Ok(self.dataframe.clone().collect().await?)
2834 }
2835
2836 pub async fn execute_stream(&self) -> SqlResult<SqlStream> {
2838 let df_stream = self.dataframe.clone().execute_stream().await?;
2839 use futures::StreamExt;
2840 let mapped = df_stream.map(|res| {
2841 res.map_err(|e| SqlError::DataFusion {
2842 message: e.to_string(),
2843 })
2844 });
2845 Ok(Box::pin(mapped))
2846 }
2847
2848 pub async fn collect_with_stats(&self) -> SqlResult<(Vec<RecordBatch>, SqlExecutionStats)> {
2856 use datafusion::physical_plan::collect as df_collect;
2857
2858 let df = self.dataframe.clone();
2859 let task_ctx = df.task_ctx();
2860 let physical_plan = df.create_physical_plan().await?;
2861
2862 let batches = df_collect(physical_plan.clone(), task_ctx.into()).await?;
2863
2864 let mut output_rows: u64 = batches.iter().map(|b| b.num_rows() as u64).sum();
2865 let mut cpu_nanos: u64 = 0;
2866
2867 if let Some(metrics) = physical_plan.metrics() {
2868 if let Some(v) = metrics.output_rows() {
2869 output_rows = v as u64;
2870 }
2871 if let Some(t) = metrics.elapsed_compute() {
2872 cpu_nanos = t as u64;
2873 }
2874 }
2875
2876 let (spill_bytes, spill_count) = aggregate_spill_metrics(physical_plan.as_ref());
2877
2878 Ok((
2879 batches,
2880 SqlExecutionStats {
2881 output_rows,
2882 cpu_nanos,
2883 spill_bytes,
2884 spill_count,
2885 },
2886 ))
2887 }
2888}
2889
2890fn aggregate_spill_metrics(plan: &dyn datafusion::physical_plan::ExecutionPlan) -> (u64, u64) {
2897 let mut spill_bytes: u64 = 0;
2898 let mut spill_count: u64 = 0;
2899 if let Some(metrics) = plan.metrics() {
2900 if let Some(bytes) = metrics.spilled_bytes() {
2901 spill_bytes = spill_bytes.saturating_add(bytes as u64);
2902 }
2903 if let Some(count) = metrics.spill_count() {
2904 spill_count = spill_count.saturating_add(count as u64);
2905 }
2906 }
2907 for child in plan.children() {
2908 let (child_bytes, child_count) = aggregate_spill_metrics(child.as_ref());
2909 spill_bytes = spill_bytes.saturating_add(child_bytes);
2910 spill_count = spill_count.saturating_add(child_count);
2911 }
2912 (spill_bytes, spill_count)
2913}
2914
/// Execution statistics gathered by `SqlDataFrame::collect_with_stats`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct SqlExecutionStats {
    /// Rows produced: the root operator's `output_rows` metric when reported,
    /// otherwise the number of rows in the collected batches.
    pub output_rows: u64,
    /// The root operator's `elapsed_compute` metric, or `0` when it is not reported.
    pub cpu_nanos: u64,
    /// Spilled bytes summed over every operator in the physical plan.
    pub spill_bytes: u64,
    /// Spill count summed over every operator in the physical plan.
    pub spill_count: u64,
}
2925
/// Finds the byte offset of the last top-level ` AS ` separator in `expression`.
///
/// "Top-level" means outside parentheses and outside single- or double-quoted
/// text; a doubled quote character inside a quoted run is an escaped quote.
/// The keyword is matched case-insensitively and must have a single space on
/// each side. The returned offset points at the space *before* `AS`.
fn top_level_alias_index(expression: &str) -> Option<usize> {
    let bytes = expression.as_bytes();
    // Which quote character (if any) we are currently inside of.
    let mut open_quote: Option<u8> = None;
    let mut nesting = 0usize;
    let mut last_match = None;
    let mut pos = 0usize;

    while let Some(&current) = bytes.get(pos) {
        let mut step = 1usize;
        match (open_quote, current) {
            (Some(quote), byte) if byte == quote => {
                if bytes.get(pos + 1) == Some(&quote) {
                    // Doubled quote: an escaped quote character, stay inside the run.
                    step = 2;
                } else {
                    open_quote = None;
                }
            }
            // Everything else inside a quoted run is opaque.
            (Some(_), _) => {}
            (None, b'\'' | b'"') => open_quote = Some(current),
            (None, b'(') => nesting += 1,
            (None, b')') => nesting = nesting.saturating_sub(1),
            (None, b' ')
                if nesting == 0
                    && bytes
                        .get(pos..pos + 4)
                        .is_some_and(|window| window.eq_ignore_ascii_case(b" AS ")) =>
            {
                last_match = Some(pos);
                // Jump past the whole ` AS ` token, including its trailing space.
                step = 4;
            }
            (None, _) => {}
        }
        pos += step;
    }

    last_match
}
2970
2971fn parse_dataframe_expression(
2972 dataframe: &datafusion::dataframe::DataFrame,
2973 expression: &str,
2974) -> SqlResult<datafusion::logical_expr::Expr> {
2975 if let Some(index) = top_level_alias_index(expression) {
2976 let (body, alias) = expression.split_at(index);
2977 let alias = alias[4..].trim();
2978 if !alias.is_empty() {
2979 let alias = alias
2980 .strip_prefix('"')
2981 .and_then(|value| value.strip_suffix('"'))
2982 .unwrap_or(alias)
2983 .replace("\"\"", "\"");
2984 return Ok(dataframe.parse_sql_expr(body.trim())?.alias(alias));
2985 }
2986 }
2987 dataframe.parse_sql_expr(expression).map_err(Into::into)
2988}
2989
2990pub fn parse_public_expression(sql: &str) -> SqlResult<krishiv_plan::expression::Expr> {
2992 let dialect = GenericDialect {};
2993 let mut parser =
2994 Parser::new(&dialect)
2995 .try_with_sql(sql)
2996 .map_err(|error| SqlError::Unsupported {
2997 feature: format!("public expression parse: {error}"),
2998 })?;
2999 let expression = parser.parse_expr().map_err(|error| SqlError::Unsupported {
3000 feature: format!("public expression parse: {error}"),
3001 })?;
3002 sqlparser_expression_to_public(&expression)
3003}
3004
3005fn sqlparser_expression_to_public(
3006 expression: &datafusion::sql::sqlparser::ast::Expr,
3007) -> SqlResult<krishiv_plan::expression::Expr> {
3008 use datafusion::sql::sqlparser::ast::{BinaryOperator as SqlOperator, Expr as SqlExpr, Value};
3009 use krishiv_plan::expression::{BinaryOperator, Expr, ScalarValue};
3010
3011 Ok(match expression {
3012 SqlExpr::Identifier(identifier) => Expr::Column {
3013 path: vec![identifier.value.clone()],
3014 },
3015 SqlExpr::CompoundIdentifier(identifiers) => Expr::Column {
3016 path: identifiers
3017 .iter()
3018 .map(|identifier| identifier.value.clone())
3019 .collect(),
3020 },
3021 SqlExpr::Nested(expression) => sqlparser_expression_to_public(expression)?,
3022 SqlExpr::IsNull(expression) => Expr::IsNull {
3023 expression: Box::new(sqlparser_expression_to_public(expression)?),
3024 negated: false,
3025 },
3026 SqlExpr::IsNotNull(expression) => Expr::IsNull {
3027 expression: Box::new(sqlparser_expression_to_public(expression)?),
3028 negated: true,
3029 },
3030 SqlExpr::BinaryOp { left, op, right } => Expr::Binary {
3031 left: Box::new(sqlparser_expression_to_public(left)?),
3032 op: match op {
3033 SqlOperator::Eq => BinaryOperator::Eq,
3034 SqlOperator::NotEq => BinaryOperator::NotEq,
3035 SqlOperator::Gt => BinaryOperator::Gt,
3036 SqlOperator::GtEq => BinaryOperator::GtEq,
3037 SqlOperator::Lt => BinaryOperator::Lt,
3038 SqlOperator::LtEq => BinaryOperator::LtEq,
3039 SqlOperator::And => BinaryOperator::And,
3040 SqlOperator::Or => BinaryOperator::Or,
3041 SqlOperator::Plus => BinaryOperator::Plus,
3042 SqlOperator::Minus => BinaryOperator::Minus,
3043 SqlOperator::Multiply => BinaryOperator::Multiply,
3044 SqlOperator::Divide => BinaryOperator::Divide,
3045 other => {
3046 return Err(SqlError::Unsupported {
3047 feature: format!("public expression operator {other}"),
3048 });
3049 }
3050 },
3051 right: Box::new(sqlparser_expression_to_public(right)?),
3052 },
3053 SqlExpr::Value(value) => Expr::Literal {
3054 value: match &value.value {
3055 Value::Null => ScalarValue::Null,
3056 Value::Boolean(value) => ScalarValue::Boolean(*value),
3057 Value::SingleQuotedString(value) => ScalarValue::Utf8(value.clone()),
3058 Value::Number(value, _)
3059 if value.contains('.') || value.contains('e') || value.contains('E') =>
3060 {
3061 ScalarValue::float64(value.parse::<f64>().map_err(|error| {
3062 SqlError::Unsupported {
3063 feature: format!("numeric expression literal: {error}"),
3064 }
3065 })?)
3066 }
3067 Value::Number(value, _) => {
3068 ScalarValue::Int64(value.parse::<i64>().map_err(|error| {
3069 SqlError::Unsupported {
3070 feature: format!("integer expression literal: {error}"),
3071 }
3072 })?)
3073 }
3074 other => {
3075 return Err(SqlError::Unsupported {
3076 feature: format!("public expression literal {other}"),
3077 });
3078 }
3079 },
3080 },
3081 other => {
3082 return Err(SqlError::Unsupported {
3083 feature: format!("public expression node {other}"),
3084 });
3085 }
3086 })
3087}
3088
3089fn public_data_type_to_arrow(
3090 data_type: &krishiv_plan::expression::ExprDataType,
3091) -> arrow::datatypes::DataType {
3092 use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit};
3093 use krishiv_plan::expression::{ExprDataType, IntervalUnit as PublicIntervalUnit};
3094
3095 match data_type {
3096 ExprDataType::Null => DataType::Null,
3097 ExprDataType::Boolean => DataType::Boolean,
3098 ExprDataType::Int64 => DataType::Int64,
3099 ExprDataType::UInt64 => DataType::UInt64,
3100 ExprDataType::Float64 => DataType::Float64,
3101 ExprDataType::Utf8 => DataType::Utf8,
3102 ExprDataType::Binary => DataType::Binary,
3103 ExprDataType::Decimal128 { precision, scale } => DataType::Decimal128(*precision, *scale),
3104 ExprDataType::Date32 => DataType::Date32,
3105 ExprDataType::Timestamp { unit, timezone } => DataType::Timestamp(
3106 match unit {
3107 krishiv_plan::expression::TimeUnit::Second => TimeUnit::Second,
3108 krishiv_plan::expression::TimeUnit::Millisecond => TimeUnit::Millisecond,
3109 krishiv_plan::expression::TimeUnit::Microsecond => TimeUnit::Microsecond,
3110 krishiv_plan::expression::TimeUnit::Nanosecond => TimeUnit::Nanosecond,
3111 },
3112 timezone.clone().map(Into::into),
3113 ),
3114 ExprDataType::Interval { unit } => DataType::Interval(match unit {
3115 PublicIntervalUnit::YearMonth => IntervalUnit::YearMonth,
3116 PublicIntervalUnit::DayTime => IntervalUnit::DayTime,
3117 PublicIntervalUnit::MonthDayNano => IntervalUnit::MonthDayNano,
3118 }),
3119 ExprDataType::List(element) => DataType::List(Arc::new(Field::new(
3120 "item",
3121 public_data_type_to_arrow(element),
3122 true,
3123 ))),
3124 ExprDataType::Map { key, value } => DataType::Map(
3125 Arc::new(Field::new(
3126 "entries",
3127 DataType::Struct(
3128 vec![
3129 Arc::new(Field::new("key", public_data_type_to_arrow(key), false)),
3130 Arc::new(Field::new("value", public_data_type_to_arrow(value), true)),
3131 ]
3132 .into(),
3133 ),
3134 false,
3135 )),
3136 false,
3137 ),
3138 ExprDataType::Struct(fields) => DataType::Struct(
3139 fields
3140 .iter()
3141 .map(|field| {
3142 Arc::new(Field::new(
3143 &field.name,
3144 public_data_type_to_arrow(&field.data_type),
3145 field.nullable,
3146 ))
3147 })
3148 .collect::<Vec<_>>()
3149 .into(),
3150 ),
3151 ExprDataType::Variant => DataType::Utf8,
3156 }
3157}
3158
3159fn public_scalar_to_datafusion(
3160 value: &krishiv_plan::expression::ScalarValue,
3161) -> Option<datafusion::common::ScalarValue> {
3162 use datafusion::common::ScalarValue;
3163 use krishiv_plan::expression::{ScalarValue as PublicScalar, TimeUnit};
3164
3165 Some(match value {
3166 PublicScalar::Null => ScalarValue::Null,
3167 PublicScalar::Boolean(value) => ScalarValue::Boolean(Some(*value)),
3168 PublicScalar::Int64(value) => ScalarValue::Int64(Some(*value)),
3169 PublicScalar::UInt64(value) => ScalarValue::UInt64(Some(*value)),
3170 PublicScalar::Float64(bits) => ScalarValue::Float64(Some(f64::from_bits(*bits))),
3171 PublicScalar::Utf8(value) => ScalarValue::Utf8(Some(value.clone())),
3172 PublicScalar::Binary(value) => ScalarValue::Binary(Some(value.clone())),
3173 PublicScalar::Decimal128 {
3174 value,
3175 precision,
3176 scale,
3177 } => ScalarValue::Decimal128(Some(*value), *precision, *scale),
3178 PublicScalar::Date32(value) => ScalarValue::Date32(Some(*value)),
3179 PublicScalar::Timestamp {
3180 value,
3181 unit,
3182 timezone,
3183 } => {
3184 let timezone = timezone.clone().map(Into::into);
3185 match unit {
3186 TimeUnit::Second => ScalarValue::TimestampSecond(Some(*value), timezone),
3187 TimeUnit::Millisecond => ScalarValue::TimestampMillisecond(Some(*value), timezone),
3188 TimeUnit::Microsecond => ScalarValue::TimestampMicrosecond(Some(*value), timezone),
3189 TimeUnit::Nanosecond => ScalarValue::TimestampNanosecond(Some(*value), timezone),
3190 }
3191 }
3192 PublicScalar::Interval { .. } => return None,
3193 })
3194}
3195
3196fn lower_public_expression(
3202 dataframe: &datafusion::dataframe::DataFrame,
3203 expression: &krishiv_plan::expression::Expr,
3204) -> SqlResult<datafusion::logical_expr::Expr> {
3205 expression
3206 .validate()
3207 .map_err(|error| SqlError::Unsupported {
3208 feature: format!("invalid public expression: {error}"),
3209 })?;
3210 use datafusion::logical_expr::{Expr as DataFusionExpr, Operator, binary_expr, cast, try_cast};
3211 use krishiv_plan::expression::{BinaryOperator, Expr};
3212
3213 Ok(match expression {
3214 Expr::Column { path } if path.len() == 1 => {
3215 datafusion::prelude::col(path.first().map(String::as_str).unwrap_or(""))
3216 }
3217 Expr::Column { .. } => parse_dataframe_expression(dataframe, &expression.to_sql())?,
3218 Expr::Literal { value } => match public_scalar_to_datafusion(value) {
3219 Some(value) => DataFusionExpr::Literal(value, None),
3220 None => parse_dataframe_expression(dataframe, &expression.to_sql())?,
3221 },
3222 Expr::Alias { expression, name } => {
3223 lower_public_expression(dataframe, expression)?.alias(name)
3224 }
3225 Expr::Binary { left, op, right } => binary_expr(
3226 lower_public_expression(dataframe, left)?,
3227 match op {
3228 BinaryOperator::Eq => Operator::Eq,
3229 BinaryOperator::NotEq => Operator::NotEq,
3230 BinaryOperator::Gt => Operator::Gt,
3231 BinaryOperator::GtEq => Operator::GtEq,
3232 BinaryOperator::Lt => Operator::Lt,
3233 BinaryOperator::LtEq => Operator::LtEq,
3234 BinaryOperator::And => Operator::And,
3235 BinaryOperator::Or => Operator::Or,
3236 BinaryOperator::Plus => Operator::Plus,
3237 BinaryOperator::Minus => Operator::Minus,
3238 BinaryOperator::Multiply => Operator::Multiply,
3239 BinaryOperator::Divide => Operator::Divide,
3240 },
3241 lower_public_expression(dataframe, right)?,
3242 ),
3243 Expr::IsNull {
3244 expression,
3245 negated,
3246 } => {
3247 let expression = lower_public_expression(dataframe, expression)?;
3248 if *negated {
3249 expression.is_not_null()
3250 } else {
3251 expression.is_null()
3252 }
3253 }
3254 Expr::Cast {
3255 expression,
3256 data_type,
3257 safe,
3258 } => {
3259 let expression = lower_public_expression(dataframe, expression)?;
3260 let data_type = public_data_type_to_arrow(data_type);
3261 if *safe {
3262 try_cast(expression, data_type)
3263 } else {
3264 cast(expression, data_type)
3265 }
3266 }
3267 Expr::Sort { .. } => {
3268 return Err(SqlError::Unsupported {
3269 feature: "standalone sort expressions are only valid inside windows or order_by"
3270 .into(),
3271 });
3272 }
3273 Expr::Aggregate { .. }
3274 | Expr::Function { .. }
3275 | Expr::Window { .. }
3276 | Expr::RawSql { .. } => parse_dataframe_expression(dataframe, &expression.to_sql())?,
3277 })
3278}
3279
3280fn sql_dataframe<'a>(
3281 dataframe: &'a dyn KrishivDataFrameOps,
3282 operation: &str,
3283) -> SqlResult<&'a SqlDataFrame> {
3284 dataframe
3285 .as_any()
3286 .downcast_ref::<SqlDataFrame>()
3287 .ok_or_else(|| SqlError::DataFusion {
3288 message: format!("right DataFrame must be SqlDataFrame for {operation}"),
3289 })
3290}
3291
3292#[async_trait::async_trait]
3293impl KrishivDataFrameOps for SqlDataFrame {
    /// Collects all result batches by delegating to the inherent `SqlDataFrame::collect`.
    // The explicit `SqlDataFrame::` path names the inherent method rather than this trait method.
    async fn collect(&self) -> SqlResult<Vec<RecordBatch>> {
        SqlDataFrame::collect(self).await
    }
    /// Collects result batches plus execution statistics by delegating to the
    /// inherent `SqlDataFrame::collect_with_stats`.
    async fn collect_with_stats(&self) -> SqlResult<(Vec<RecordBatch>, SqlExecutionStats)> {
        SqlDataFrame::collect_with_stats(self).await
    }
    /// Returns the pretty-printed `EXPLAIN` output by delegating to the inherent
    /// `SqlDataFrame::explain`.
    async fn explain(&self) -> SqlResult<String> {
        SqlDataFrame::explain(self).await
    }
    /// Returns the DataFusion logical plan text by delegating to the inherent
    /// `SqlDataFrame::explain_logical`.
    fn explain_logical(&self) -> String {
        SqlDataFrame::explain_logical(self)
    }
3306 fn krishiv_logical_plan(&self) -> LogicalPlan {
3307 let label = self.dataframe.logical_plan().to_string();
3308 let mut plan = LogicalPlan::new(self.name.clone(), ExecutionKind::Batch).with_node(
3309 PlanNode::new("datafusion-logical", label, ExecutionKind::Batch),
3310 );
3311 if let Some(n) = self.shuffle_partitions {
3312 plan = plan.with_shuffle_partitions(Some(n));
3313 }
3314 plan
3315 }
    /// Returns the originating SQL text, if any, by delegating to the inherent
    /// `SqlDataFrame::query`.
    fn query(&self) -> Option<&str> {
        SqlDataFrame::query(self)
    }
    /// Executes the frame as a batch stream by delegating to the inherent
    /// `SqlDataFrame::execute_stream`.
    async fn execute_stream(&self) -> SqlResult<SqlStream> {
        SqlDataFrame::execute_stream(self).await
    }
3322
3323 fn schema(&self) -> SchemaRef {
3326 SchemaRef::from(self.dataframe.schema().clone())
3327 }
3328
3329 async fn select(&self, columns: &[&str]) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3330 let df = self.dataframe.clone().select_columns(columns)?;
3331 Ok(Box::new(self.with_new_dataframe(df, "select")))
3332 }
3333
3334 async fn select_exprs(
3335 &self,
3336 expressions: &[&krishiv_plan::expression::Expr],
3337 ) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3338 let expressions = expressions
3339 .iter()
3340 .map(|expression| lower_public_expression(&self.dataframe, expression))
3341 .collect::<Result<Vec<_>, _>>()?;
3342 let df = self.dataframe.clone().select(expressions)?;
3343 Ok(Box::new(self.with_new_dataframe(df, "select_exprs")))
3344 }
3345
3346 async fn aggregate(
3347 &self,
3348 group_exprs: &[&krishiv_plan::expression::Expr],
3349 aggregate_exprs: &[&krishiv_plan::expression::Expr],
3350 ) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3351 if aggregate_exprs.is_empty() {
3352 return Err(SqlError::Unsupported {
3353 feature: "aggregate requires at least one aggregate expression".into(),
3354 });
3355 }
3356 let group_exprs = group_exprs
3357 .iter()
3358 .map(|expression| lower_public_expression(&self.dataframe, expression))
3359 .collect::<Result<Vec<_>, _>>()?;
3360 let aggregate_exprs = aggregate_exprs
3361 .iter()
3362 .map(|expression| lower_public_expression(&self.dataframe, expression))
3363 .collect::<Result<Vec<_>, _>>()?;
3364 let df = self
3365 .dataframe
3366 .clone()
3367 .aggregate(group_exprs, aggregate_exprs)?;
3368 Ok(Box::new(self.with_new_dataframe(df, "aggregate")))
3369 }
3370
3371 async fn aggregate_grouping(
3372 &self,
3373 grouping: GroupingMode<'_>,
3374 aggregate_exprs: &[&krishiv_plan::expression::Expr],
3375 ) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3376 if aggregate_exprs.is_empty() {
3377 return Err(SqlError::Unsupported {
3378 feature: "grouping aggregation requires at least one aggregate expression".into(),
3379 });
3380 }
3381 let lower = |expression: &&krishiv_plan::expression::Expr| {
3382 lower_public_expression(&self.dataframe, expression)
3383 };
3384 let group = match grouping {
3385 GroupingMode::Sets(sets) => datafusion::logical_expr::grouping_set(
3386 sets.into_iter()
3387 .map(|set| set.iter().map(lower).collect::<Result<Vec<_>, _>>())
3388 .collect::<Result<Vec<_>, _>>()?,
3389 ),
3390 GroupingMode::Cube(expressions) => datafusion::logical_expr::cube(
3391 expressions
3392 .iter()
3393 .map(lower)
3394 .collect::<Result<Vec<_>, _>>()?,
3395 ),
3396 GroupingMode::Rollup(expressions) => datafusion::logical_expr::rollup(
3397 expressions
3398 .iter()
3399 .map(lower)
3400 .collect::<Result<Vec<_>, _>>()?,
3401 ),
3402 };
3403 let aggregates = aggregate_exprs
3404 .iter()
3405 .map(lower)
3406 .collect::<Result<Vec<_>, _>>()?;
3407 let df = self.dataframe.clone().aggregate(vec![group], aggregates)?;
3408 Ok(Box::new(self.with_new_dataframe(df, "aggregate_grouping")))
3409 }
3410
3411 async fn pivot(
3412 &self,
3413 group_exprs: &[&krishiv_plan::expression::Expr],
3414 pivot_column: &krishiv_plan::expression::Expr,
3415 aggregate_expr: &krishiv_plan::expression::Expr,
3416 values: &[(krishiv_plan::expression::ScalarValue, String)],
3417 ) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3418 use krishiv_plan::expression::Expr as PublicExpr;
3419 let (function, input, distinct) = match aggregate_expr {
3420 PublicExpr::Aggregate {
3421 function,
3422 expression: Some(input),
3423 distinct,
3424 } => (*function, input.as_ref(), *distinct),
3425 _ => {
3426 return Err(SqlError::Unsupported {
3427 feature: "pivot requires an aggregate expression with one input".into(),
3428 });
3429 }
3430 };
3431 if values.is_empty() {
3432 return Err(SqlError::Unsupported {
3433 feature: "pivot requires at least one value".into(),
3434 });
3435 }
3436 let group_exprs = group_exprs
3437 .iter()
3438 .map(|expression| lower_public_expression(&self.dataframe, expression))
3439 .collect::<Result<Vec<_>, _>>()?;
3440 let aggregates = values
3441 .iter()
3442 .map(|(value, alias)| {
3443 let conditional = PublicExpr::raw(format!(
3444 "CASE WHEN {} = {} THEN {} END",
3445 pivot_column.to_sql(),
3446 value.to_sql_literal(),
3447 input.to_sql()
3448 ));
3449 let aggregate = PublicExpr::Aggregate {
3450 function,
3451 expression: Some(Box::new(conditional)),
3452 distinct,
3453 }
3454 .alias(alias);
3455 lower_public_expression(&self.dataframe, &aggregate)
3456 })
3457 .collect::<Result<Vec<_>, _>>()?;
3458 let dataframe = self.dataframe.clone().aggregate(group_exprs, aggregates)?;
3459 Ok(Box::new(self.with_new_dataframe(dataframe, "pivot")))
3460 }
3461
3462 async fn unpivot(
3463 &self,
3464 columns: &[&str],
3465 name_column: &str,
3466 value_column: &str,
3467 ) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3468 if columns.is_empty() {
3469 return Err(SqlError::Unsupported {
3470 feature: "unpivot requires at least one column".into(),
3471 });
3472 }
3473 let retained = self
3474 .dataframe
3475 .schema()
3476 .fields()
3477 .iter()
3478 .map(|field| field.name().as_str())
3479 .filter(|name| !columns.contains(name))
3480 .collect::<Vec<_>>();
3481 let mut branches = Vec::with_capacity(columns.len());
3482 for column in columns {
3483 let mut expressions = retained
3484 .iter()
3485 .map(|name| datafusion::logical_expr::col(*name))
3486 .collect::<Vec<_>>();
3487 expressions
3488 .push(datafusion::logical_expr::lit((*column).to_owned()).alias(name_column));
3489 expressions.push(datafusion::logical_expr::col(*column).alias(value_column));
3490 branches.push(self.dataframe.clone().select(expressions)?);
3491 }
3492 let mut branches = branches.into_iter();
3493 let Some(mut dataframe) = branches.next() else {
3494 return Err(SqlError::Unsupported {
3495 feature: "unpivot requires at least one branch".into(),
3496 });
3497 };
3498 for branch in branches {
3499 dataframe = dataframe.union(branch)?;
3500 }
3501 Ok(Box::new(self.with_new_dataframe(dataframe, "unpivot")))
3502 }
3503
3504 async fn filter(&self, predicate: &str) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3505 let expr = self.dataframe.parse_sql_expr(predicate)?;
3506 let df = self.dataframe.clone().filter(expr)?;
3507 Ok(Box::new(self.with_new_dataframe(df, "filter")))
3508 }
3509
3510 async fn filter_expr(
3511 &self,
3512 predicate: &krishiv_plan::expression::Expr,
3513 ) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3514 let expr = lower_public_expression(&self.dataframe, predicate)?;
3515 let df = self.dataframe.clone().filter(expr)?;
3516 Ok(Box::new(self.with_new_dataframe(df, "filter_expr")))
3517 }
3518
3519 async fn limit(&self, n: usize) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3520 let df = self.dataframe.clone().limit(0, Some(n))?;
3521 Ok(Box::new(self.with_new_dataframe(df, "limit")))
3522 }
3523
3524 async fn distinct(&self) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3525 let df = self.dataframe.clone().distinct()?;
3526 Ok(Box::new(self.with_new_dataframe(df, "distinct")))
3527 }
3528
3529 async fn drop_nulls(&self, columns: &[&str]) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3530 let columns = if columns.is_empty() {
3531 self.dataframe
3532 .schema()
3533 .fields()
3534 .iter()
3535 .map(|field| field.name().as_str())
3536 .collect::<Vec<_>>()
3537 } else {
3538 columns.to_vec()
3539 };
3540 let mut predicate: Option<datafusion::logical_expr::Expr> = None;
3541 for column in columns {
3542 let next = datafusion::logical_expr::col(column).is_not_null();
3543 predicate = Some(match predicate {
3544 Some(current) => current.and(next),
3545 None => next,
3546 });
3547 }
3548 let df = match predicate {
3549 Some(predicate) => self.dataframe.clone().filter(predicate)?,
3550 None => self.dataframe.clone(),
3551 };
3552 Ok(Box::new(self.with_new_dataframe(df, "drop_nulls")))
3553 }
3554
3555 async fn sample(&self, fraction: f64) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3556 if !(0.0..=1.0).contains(&fraction) {
3557 return Err(SqlError::Unsupported {
3558 feature: "sample fraction must be between 0 and 1".into(),
3559 });
3560 }
3561 let predicate = self
3562 .dataframe
3563 .parse_sql_expr(&format!("random() < {fraction}"))?;
3564 let df = self.dataframe.clone().filter(predicate)?;
3565 Ok(Box::new(self.with_new_dataframe(df, "sample")))
3566 }
3567
3568 async fn sort(
3569 &self,
3570 columns: &[&str],
3571 descending: &[bool],
3572 ) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3573 use datafusion::logical_expr::SortExpr;
3574 let exprs: Vec<SortExpr> = columns
3575 .iter()
3576 .zip(descending.iter())
3577 .map(|(col_name, desc)| datafusion::logical_expr::col(*col_name).sort(!desc, *desc))
3578 .collect();
3579 let df = self.dataframe.clone().sort(exprs)?;
3580 Ok(Box::new(self.with_new_dataframe(df, "sort")))
3581 }
3582
3583 async fn alias(&self, alias: &str) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3584 let df = self.dataframe.clone().alias(alias)?;
3585 Ok(Box::new(self.with_new_dataframe(df, "alias")))
3586 }
3587
3588 async fn drop_columns(&self, columns: &[&str]) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3589 let df = self.dataframe.clone().drop_columns(columns)?;
3590 Ok(Box::new(self.with_new_dataframe(df, "drop")))
3591 }
3592
3593 async fn rename_column(&self, old: &str, new: &str) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3594 let df = self.dataframe.clone().with_column_renamed(old, new)?;
3595 Ok(Box::new(self.with_new_dataframe(df, "rename")))
3596 }
3597
3598 async fn with_column(&self, name: &str, expr: &str) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3599 let parsed = self.dataframe.parse_sql_expr(expr)?;
3600 let df = self.dataframe.clone().with_column(name, parsed)?;
3601 Ok(Box::new(self.with_new_dataframe(df, "with_column")))
3602 }
3603
3604 fn as_any(&self) -> &dyn std::any::Any {
3605 self
3606 }
3607
3608 async fn describe(&self) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3609 let df = self.dataframe.clone().describe().await?;
3610 Ok(Box::new(self.with_new_dataframe(df, "describe")))
3611 }
3612
3613 async fn fill_null(
3614 &self,
3615 column: &str,
3616 value: &str,
3617 ) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3618 let expr = format!("COALESCE({column}, {value})");
3619 let parsed = self.dataframe.parse_sql_expr(&expr)?;
3620 let df = self.dataframe.clone().with_column(column, parsed)?;
3621 Ok(Box::new(self.with_new_dataframe(df, "fill_null")))
3622 }
3623
3624 async fn join(
3625 &self,
3626 right: &dyn KrishivDataFrameOps,
3627 how: &str,
3628 left_on: &[&str],
3629 right_on: &[&str],
3630 ) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3631 let right_sql = right
3632 .as_any()
3633 .downcast_ref::<SqlDataFrame>()
3634 .ok_or_else(|| SqlError::DataFusion {
3635 message: "right DataFrame must be SqlDataFrame for join".into(),
3636 })?;
3637 use datafusion::common::JoinType;
3638 let join_type = match how.to_lowercase().as_str() {
3639 "inner" => JoinType::Inner,
3640 "left" => JoinType::Left,
3641 "right" => JoinType::Right,
3642 "full" | "outer" => JoinType::Full,
3643 "leftsemi" | "left_semi" => JoinType::LeftSemi,
3644 "rightsemi" | "right_semi" => JoinType::RightSemi,
3645 "leftanti" | "left_anti" => JoinType::LeftAnti,
3646 "rightanti" | "right_anti" => JoinType::RightAnti,
3647 _ => {
3648 return Err(SqlError::DataFusion {
3649 message: format!("unsupported join type: {how}"),
3650 });
3651 }
3652 };
3653 let df = self.dataframe.clone().join(
3654 right_sql.dataframe.clone(),
3655 join_type,
3656 left_on,
3657 right_on,
3658 None,
3659 )?;
3660 Ok(Box::new(self.with_new_dataframe(df, "join")))
3661 }
3662
3663 async fn union(
3664 &self,
3665 right: &dyn KrishivDataFrameOps,
3666 ) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3667 let right_sql = right
3668 .as_any()
3669 .downcast_ref::<SqlDataFrame>()
3670 .ok_or_else(|| SqlError::DataFusion {
3671 message: "right DataFrame must be SqlDataFrame for union".into(),
3672 })?;
3673 let df = self.dataframe.clone().union(right_sql.dataframe.clone())?;
3674 Ok(Box::new(self.with_new_dataframe(df, "union")))
3675 }
3676
3677 async fn union_distinct(
3678 &self,
3679 right: &dyn KrishivDataFrameOps,
3680 ) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3681 let right = sql_dataframe(right, "union_distinct")?;
3682 let df = self
3683 .dataframe
3684 .clone()
3685 .union_distinct(right.dataframe.clone())?;
3686 Ok(Box::new(self.with_new_dataframe(df, "union_distinct")))
3687 }
3688
3689 async fn intersect(
3690 &self,
3691 right: &dyn KrishivDataFrameOps,
3692 distinct: bool,
3693 ) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3694 let right = sql_dataframe(right, "intersect")?;
3695 let df = if distinct {
3696 self.dataframe
3697 .clone()
3698 .intersect_distinct(right.dataframe.clone())?
3699 } else {
3700 self.dataframe.clone().intersect(right.dataframe.clone())?
3701 };
3702 Ok(Box::new(self.with_new_dataframe(df, "intersect")))
3703 }
3704
3705 async fn except(
3706 &self,
3707 right: &dyn KrishivDataFrameOps,
3708 distinct: bool,
3709 ) -> SqlResult<Box<dyn KrishivDataFrameOps>> {
3710 let right = sql_dataframe(right, "except")?;
3711 let df = if distinct {
3712 self.dataframe
3713 .clone()
3714 .except_distinct(right.dataframe.clone())?
3715 } else {
3716 self.dataframe.clone().except(right.dataframe.clone())?
3717 };
3718 Ok(Box::new(self.with_new_dataframe(df, "except")))
3719 }
3720
3721 async fn register_batches(&self, name: &str, batches: Vec<RecordBatch>) -> SqlResult<()> {
3722 let schema = batches
3723 .first()
3724 .map(|b| b.schema())
3725 .unwrap_or_else(|| Arc::new(arrow::datatypes::Schema::empty()));
3726 let mem_table =
3727 datafusion::datasource::MemTable::try_new(schema, vec![batches]).map_err(|e| {
3728 SqlError::DataFusion {
3729 message: e.to_string(),
3730 }
3731 })?;
3732 self.context
3733 .register_table(name, Arc::new(mem_table))
3734 .map_err(SqlError::from)?;
3735 Ok(())
3736 }
3737
3738 async fn deregister_table(&self, name: &str) -> SqlResult<()> {
3739 let _ = self
3740 .context
3741 .deregister_table(name)
3742 .map_err(SqlError::from)?;
3743 Ok(())
3744 }
3745
3746 async fn create_view(&self, name: &str, replace: bool) -> SqlResult<()> {
3747 let query = self
3748 .query_text
3749 .as_deref()
3750 .ok_or_else(|| SqlError::DataFusion {
3751 message: "create_view requires an SQL query string on the DataFrame".into(),
3752 })?;
3753 let or_replace = if replace { "OR REPLACE " } else { "" };
3754 let safe_name = quote_identifier(name);
3755 let view_sql = format!("CREATE {or_replace}VIEW {safe_name} AS {query}");
3756 self.context.sql(&view_sql).await?;
3757 Ok(())
3758 }
3759}
3760
3761use krishiv_common::sql_util::quote_identifier;
3762
3763#[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
/// Splits the argument list of a `CALL`-style statement into its arguments.
///
/// Rules:
/// - Arguments are separated by commas.
/// - A single-quoted argument is returned verbatim without its quotes;
///   commas inside the quotes are preserved and a doubled quote (`''`)
///   stands for one literal quote character.
/// - Whitespace before an opening quote is not part of the argument, and
///   anything between a closing quote and the next comma is ignored.
/// - Unquoted arguments are trimmed; empty unquoted arguments are dropped,
///   while an empty quoted string (`''`) is kept.
/// - An unterminated quoted string yields whatever was read, trimmed.
fn call_args_from_str(s: &str) -> Vec<String> {
    let mut args: Vec<String> = Vec::new();
    let mut cur = String::new();
    let mut in_str = false;
    let mut after_str = false;
    let mut chars = s.chars().peekable();

    while let Some(ch) = chars.next() {
        if after_str {
            // Skip trailing junk after a closed string up to the separator.
            if ch == ',' {
                after_str = false;
            }
            continue;
        }
        if in_str {
            if ch != '\'' {
                cur.push(ch);
            } else if chars.peek() == Some(&'\'') {
                // `''` inside a string is an escaped single quote.
                chars.next();
                cur.push('\'');
            } else {
                in_str = false;
                after_str = true;
                args.push(std::mem::take(&mut cur));
            }
        } else if ch == '\'' {
            // Padding between the previous comma and the opening quote must
            // not leak into the quoted value (e.g. the space in `'a', 'b'`).
            if cur.trim().is_empty() {
                cur.clear();
            }
            in_str = true;
        } else if ch == ',' {
            let trimmed = cur.trim();
            if !trimmed.is_empty() {
                args.push(trimmed.to_string());
            }
            cur.clear();
        } else {
            cur.push(ch);
        }
    }

    let trimmed = cur.trim();
    if !trimmed.is_empty() {
        args.push(trimmed.to_string());
    }
    args
}
3808
#[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
/// Converts a dotted table reference into an `iceberg::TableIdent`.
///
/// The reference is split on `.` into at most three parts:
/// - `ns.table` uses `ns` as the namespace;
/// - `cat.ns.table` ignores the leading part and uses `ns` as the namespace.
///
/// # Errors
/// Returns `SqlError::DataFusion` when the reference has no `.` separator or
/// when the namespace identifier cannot be constructed.
fn iceberg_table_ident(table_ref: &str) -> SqlResult<iceberg::TableIdent> {
    let parts: Vec<&str> = table_ref.splitn(3, '.').collect();
    let (namespace, table) = match parts.as_slice() {
        [namespace, table] | [_, namespace, table] => (*namespace, *table),
        _ => {
            return Err(SqlError::DataFusion {
                message: format!(
                    "invalid table reference '{table_ref}': expected 'ns.table' or 'cat.ns.table'"
                ),
            });
        }
    };

    let namespace_ident = iceberg::NamespaceIdent::from_vec(vec![namespace.to_string()])
        .map_err(|error| SqlError::DataFusion {
            message: error.to_string(),
        })?;
    Ok(iceberg::TableIdent::new(namespace_ident, table.to_string()))
}
3850
#[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
/// Parses a duration written as `<integer> <unit>`, e.g. `7 days`.
///
/// Supported units (case-insensitive, trailing `s` optional): `day`, `hour`,
/// `week`, `minute` and `min`. The count and unit may be separated by any
/// whitespace.
///
/// # Errors
/// Returns `SqlError::DataFusion` when the count is not a valid integer, the
/// unit is unknown, or the resulting duration does not fit the supported
/// range.
fn parse_call_duration(s: &str) -> SqlResult<chrono::Duration> {
    const MILLIS_PER_MINUTE: i64 = 60 * 1_000;
    const MILLIS_PER_HOUR: i64 = 60 * MILLIS_PER_MINUTE;
    const MILLIS_PER_DAY: i64 = 24 * MILLIS_PER_HOUR;
    const MILLIS_PER_WEEK: i64 = 7 * MILLIS_PER_DAY;

    let s = s.trim();
    let (count, unit) = s.split_once(char::is_whitespace).unwrap_or((s, ""));
    let n: i64 = count.parse().map_err(|_| SqlError::DataFusion {
        message: format!("invalid duration value in '{s}'"),
    })?;

    let unit = unit.trim().to_ascii_lowercase();
    let millis_per_unit = match unit.trim_end_matches('s') {
        "day" => MILLIS_PER_DAY,
        "hour" => MILLIS_PER_HOUR,
        "week" => MILLIS_PER_WEEK,
        "minute" | "min" => MILLIS_PER_MINUTE,
        _ => {
            return Err(SqlError::DataFusion {
                message: format!("unknown duration unit '{unit}' in '{s}'"),
            });
        }
    };

    // The unit constructors (`Duration::days` and friends) panic when the
    // count is out of range. The count comes from user-written SQL, so
    // overflow is reported as an error instead.
    let millis = n
        .checked_mul(millis_per_unit)
        .ok_or_else(|| SqlError::DataFusion {
            message: format!("duration '{s}' is out of range"),
        })?;
    Ok(chrono::Duration::milliseconds(millis))
}
3876
#[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
/// Extracts `(table_name, predicate)` from a single `DELETE` statement.
///
/// Only the first table listed in the `FROM` clause is considered. The
/// predicate is the `WHERE` expression rendered back to SQL text, or `TRUE`
/// when the statement has no `WHERE` clause.
///
/// Returns `None` when the text does not parse, does not contain exactly one
/// statement, is not a `DELETE`, or its first relation is not a plain table.
fn parse_dml_delete(stmt: &str) -> Option<(String, String)> {
    use datafusion::sql::sqlparser::ast::{FromTable, Statement, TableFactor};
    use datafusion::sql::sqlparser::dialect::GenericDialect;
    use datafusion::sql::sqlparser::parser::Parser;

    let mut parsed = Parser::parse_sql(&GenericDialect {}, stmt)
        .ok()?
        .into_iter();
    // Exactly one statement, and it must be a DELETE.
    let (Some(Statement::Delete(delete)), None) = (parsed.next(), parsed.next()) else {
        return None;
    };

    // Both spellings (with and without the FROM keyword) carry the same list.
    let (FromTable::WithFromKeyword(tables) | FromTable::WithoutKeyword(tables)) = delete.from;
    let TableFactor::Table { name, .. } = tables.into_iter().next()?.relation else {
        return None;
    };

    let predicate = delete
        .selection
        .map_or_else(|| "TRUE".to_string(), |expr| expr.to_string());
    Some((name.to_string(), predicate))
}
3913
3914#[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
/// The pieces of an `UPDATE` statement extracted by `parse_dml_update`.
struct ParsedUpdate {
    /// Target table name, rendered as SQL text.
    table_ref: String,
    /// `(target, value)` pairs from the `SET` clause, both rendered as SQL text.
    assignments: Vec<(String, String)>,
    /// The `WHERE` expression rendered as SQL text, if the statement has one.
    predicate: Option<String>,
}
3922
#[cfg(all(feature = "iceberg-datafusion", feature = "local-catalog"))]
/// Extracts the target table, `SET` assignments and optional `WHERE`
/// predicate from a single `UPDATE` statement.
///
/// Returns `None` when the text does not parse, does not contain exactly one
/// statement, is not an `UPDATE`, targets something other than a plain table,
/// or has no assignments.
fn parse_dml_update(stmt: &str) -> Option<ParsedUpdate> {
    use datafusion::sql::sqlparser::ast::{Statement, TableFactor};
    use datafusion::sql::sqlparser::dialect::GenericDialect;
    use datafusion::sql::sqlparser::parser::Parser;

    let mut parsed = Parser::parse_sql(&GenericDialect {}, stmt)
        .ok()?
        .into_iter();
    // Exactly one statement, and it must be an UPDATE.
    let (Some(Statement::Update(update)), None) = (parsed.next(), parsed.next()) else {
        return None;
    };
    let TableFactor::Table { name, .. } = update.table.relation else {
        return None;
    };

    let assignments: Vec<(String, String)> = update
        .assignments
        .into_iter()
        .map(|assignment| (assignment.target.to_string(), assignment.value.to_string()))
        .collect();
    if assignments.is_empty() {
        return None;
    }

    let predicate = update.selection.map(|expr| expr.to_string());
    Some(ParsedUpdate {
        table_ref: name.to_string(),
        assignments,
        predicate,
    })
}
3966
3967pub fn plan_sql(query: impl Into<String>) -> SqlResult<SqlPlan> {
3969 let query = query.into();
3970 if query.trim().is_empty() {
3971 return Err(SqlError::EmptyQuery);
3972 }
3973
3974 if let Some(stmt) = cep_sql::parse_match_recognize(&query)? {
3975 let logical_plan = cep_sql::plan_match_recognize(stmt, &query);
3976 let optimized = Optimizer::default().optimize(logical_plan)?;
3977 return Ok(SqlPlan {
3978 query,
3979 logical_plan: optimized.plan,
3980 });
3981 }
3982
3983 let logical_plan =
3984 LogicalPlan::new("sql-query", ExecutionKind::Batch).with_node(PlanNode::new(
3985 "sql",
3986 format!("sql: {}", query.trim()),
3987 ExecutionKind::Batch,
3988 ));
3989
3990 let optimized = Optimizer::default().optimize(logical_plan)?;
3991 Ok(SqlPlan {
3992 query,
3993 logical_plan: optimized.plan,
3994 })
3995}
3996
3997pub fn explain_sql(query: impl Into<String>) -> SqlResult<String> {
3999 let plan = plan_sql(query)?;
4000 Ok(plan.logical_plan().describe())
4001}
4002
4003pub fn explain_sql_optimized(query: impl Into<String>, optimizer: &Optimizer) -> SqlResult<String> {
4008 let plan = plan_sql(query)?;
4009 let result = optimizer.optimize(plan.logical_plan().clone())?;
4010 let mut output = result.plan.describe();
4011 let optimizer_line = result.describe();
4012 output.push('\n');
4013 output.push_str(&optimizer_line);
4014 Ok(output)
4015}
4016
4017pub fn explain_sql_with_cost(
4019 query: impl Into<String>,
4020 cost_model: &dyn CostModel,
4021) -> SqlResult<String> {
4022 let plan = plan_sql(query)?;
4023 let cost = cost_model.estimate(plan.logical_plan());
4024 let mut output = plan.logical_plan().describe();
4025 output.push_str(&format!(
4026 "\ncost: cpu_nanos={}, memory_bytes={}, network_bytes={}",
4027 cost.cpu_nanos, cost.memory_bytes, cost.network_bytes
4028 ));
4029 Ok(output)
4030}
4031
4032pub fn referenced_table_names(query: impl AsRef<str>) -> SqlResult<Vec<String>> {
4038 let query = query.as_ref();
4039 if query.trim().is_empty() {
4040 return Err(SqlError::EmptyQuery);
4041 }
4042
4043 let statements =
4044 Parser::parse_sql(&GenericDialect {}, query).map_err(|e| SqlError::DataFusion {
4045 message: format!("SQL parse error: {e}"),
4046 })?;
4047 let mut names = BTreeSet::new();
4048 let _ = visit_relations(&statements, |relation| {
4049 names.insert(relation.to_string());
4050 ControlFlow::<()>::Continue(())
4051 });
4052 Ok(names.into_iter().collect())
4053}
4054
4055pub fn pretty_batches(batches: &[RecordBatch]) -> SqlResult<String> {
4057 Ok(pretty_format_batches(batches)
4058 .map_err(|error| SqlError::DataFusion {
4059 message: error.to_string(),
4060 })?
4061 .to_string())
4062}
4063
4064#[cfg(test)]
4065mod sql_tests;