1pub mod plan_node;
8pub mod planner;
9
10use std::collections::HashMap;
11use std::collections::hash_map::Entry;
12use std::sync::Arc;
13
14use async_trait::async_trait;
15use datafusion::arrow::datatypes::{DataType, SchemaRef};
16use datafusion::catalog::cte_worktable::CteWorkTable;
17use datafusion::common::file_options::file_type::FileType;
18use datafusion::common::plan_datafusion_err;
19use datafusion::config::ConfigOptions;
20use datafusion::datasource::file_format::format_as_file_type;
21use datafusion::datasource::provider_as_source;
22use datafusion::error::Result;
23use datafusion::execution::SessionState;
24use datafusion::execution::context::QueryPlanner;
25use datafusion::logical_expr::planner::{ExprPlanner, TypePlanner};
26use datafusion::logical_expr::var_provider::is_system_variables;
27use datafusion::logical_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF};
28use datafusion::optimizer::AnalyzerRule;
29use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
30use datafusion::physical_plan::ExecutionPlan;
31use datafusion::physical_planner::{DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner};
32use datafusion::prelude::{DataFrame, Expr, SQLOptions, SessionContext};
33use datafusion::sql::parser::Statement;
34use datafusion::sql::planner::{ContextProvider, NullOrdering, ParserOptions, SqlToRel};
35use datafusion::sql::{ResolvedTableReference, TableReference};
36use datafusion::variable::VarType;
37
38use self::planner::ClickHouseExtensionPlanner;
39use crate::analyzer::function_pushdown::ClickHouseFunctionPushdown;
40use crate::udfs::apply::{CLICKHOUSE_APPLY_ALIASES, clickhouse_apply_udf};
41use crate::udfs::clickhouse::{CLICKHOUSE_UDF_ALIASES, clickhouse_udf};
42use crate::udfs::placeholder::PlaceholderUDF;
43
44pub fn prepare_session_context(
55 ctx: SessionContext,
56 extension_planners: Option<Vec<Arc<dyn ExtensionPlanner + Send + Sync>>>,
57) -> SessionContext {
58 #[cfg(feature = "federation")]
59 use crate::federation::FederatedContext as _;
60
61 #[cfg(feature = "federation")]
64 let ctx = ctx.federate();
65 let state = ctx.state();
67 let config = state.config().clone();
68 let config = config.set_str("datafusion.sql_parser.dialect", "ClickHouse");
77 let state_builder = if state
79 .analyzer()
80 .rules
81 .iter()
82 .any(|rule| rule.name() == ClickHouseFunctionPushdown.name())
83 {
84 ctx.into_state_builder()
85 } else {
86 let analyzer_rules = configure_analyzer_rules(&state);
87 ctx.into_state_builder().with_analyzer_rules(analyzer_rules)
88 };
89 let ctx = SessionContext::new_with_state(
91 state_builder
92 .with_config(config)
93 .with_query_planner(Arc::new(ClickHouseQueryPlanner::new_with_planners(
94 extension_planners.unwrap_or_default(),
95 )))
96 .build(),
97 );
98 ctx.register_udf(clickhouse_udf());
99 ctx.register_udf(clickhouse_apply_udf());
100 ctx
101}
102
103pub fn configure_analyzer_rules(state: &SessionState) -> Vec<Arc<dyn AnalyzerRule + Send + Sync>> {
105 let mut analyzer_rules = state.analyzer().rules.clone();
107
108 let type_coercion = TypeCoercion::default();
111 let pos = analyzer_rules.iter().position(|x| x.name() == type_coercion.name()).unwrap_or(0);
112
113 let pushdown_rule = Arc::new(ClickHouseFunctionPushdown);
114 analyzer_rules.insert(pos, pushdown_rule);
115 analyzer_rules
116}
117
/// [`QueryPlanner`] that creates physical plans with a [`DefaultPhysicalPlanner`]
/// configured with this planner's list of extension planners.
#[derive(Clone)]
pub struct ClickHouseQueryPlanner {
    /// Extension planners handed to `DefaultPhysicalPlanner::with_extension_planners`
    /// on every `create_physical_plan` call, in insertion order.
    planners: Vec<Arc<dyn ExtensionPlanner + Send + Sync>>,
}
123
124impl std::fmt::Debug for ClickHouseQueryPlanner {
125 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126 f.debug_struct("ClickHouseQueryPlanner").finish()
127 }
128}
129
130impl Default for ClickHouseQueryPlanner {
131 fn default() -> Self { Self::new() }
132}
133
134impl ClickHouseQueryPlanner {
135 pub fn new() -> Self {
138 let planners = vec![
139 #[cfg(feature = "federation")]
140 Arc::new(datafusion_federation::FederatedPlanner::new()),
141 Arc::new(ClickHouseExtensionPlanner {}) as Arc<dyn ExtensionPlanner + Send + Sync>,
142 ];
143 ClickHouseQueryPlanner { planners }
144 }
145
146 pub fn new_with_planners(planners: Vec<Arc<dyn ExtensionPlanner + Send + Sync>>) -> Self {
150 let mut this = Self::new();
151 this.planners.extend(planners);
152 this
153 }
154
155 #[must_use]
157 pub fn with_planner(mut self, planner: Arc<dyn ExtensionPlanner + Send + Sync>) -> Self {
158 self.planners.push(planner);
159 self
160 }
161}
162
163#[async_trait]
164impl QueryPlanner for ClickHouseQueryPlanner {
165 async fn create_physical_plan(
166 &self,
167 logical_plan: &LogicalPlan,
168 session_state: &SessionState,
169 ) -> Result<Arc<dyn ExecutionPlan>> {
170 let planner = DefaultPhysicalPlanner::with_extension_planners(self.planners.clone());
172 planner.create_physical_plan(logical_plan, session_state).await
173 }
174}
175
/// Wrapper around a [`SessionContext`] that parses SQL with the ClickHouse dialect and
/// plans it through a [`ClickHouseContextProvider`].
///
/// Dereferences to the wrapped [`SessionContext`].
#[derive(Clone)]
pub struct ClickHouseSessionContext {
    /// Wrapped context. `new` passes it through `prepare_session_context`;
    /// `with_session_transform` may replace it afterwards.
    inner: SessionContext,
    /// Optional expression planner added to the context provider for each planned statement.
    expr_planner: Option<Arc<dyn ExprPlanner>>,
}
188
189impl ClickHouseSessionContext {
190 pub fn new(
196 ctx: SessionContext,
197 extension_planners: Option<Vec<Arc<dyn ExtensionPlanner + Send + Sync>>>,
198 ) -> Self {
199 Self { inner: prepare_session_context(ctx, extension_planners), expr_planner: None }
200 }
201
202 #[must_use]
203 pub fn with_expr_planner(mut self, expr_planner: Arc<dyn ExprPlanner>) -> Self {
204 self.expr_planner = Some(expr_planner);
205 self
206 }
207
208 #[must_use]
227 pub fn with_session_transform<F>(mut self, transform: F) -> Self
228 where
229 F: FnOnce(SessionContext) -> SessionContext,
230 {
231 self.inner = transform(self.inner);
232 self
233 }
234
235 pub fn session_context(&self) -> &SessionContext { &self.inner }
237
238 pub fn into_session_context(self) -> SessionContext { self.inner }
243
244 pub async fn sql(&self, sql: &str) -> Result<DataFrame> {
249 self.sql_with_options(sql, SQLOptions::new()).await
250 }
251
252 pub async fn sql_with_options(&self, sql: &str, options: SQLOptions) -> Result<DataFrame> {
257 let state = self.inner.state();
258 let statement = state.sql_to_statement(sql, "ClickHouse")?;
259 let plan = self.statement_to_plan(&state, statement).await?;
260 options.verify_plan(&plan)?;
261 self.execute_logical_plan(plan).await
262 }
263
264 pub async fn statement_to_plan(
272 &self,
273 state: &SessionState,
274 statement: Statement,
275 ) -> Result<LogicalPlan> {
276 let references = state.resolve_table_references(&statement)?;
277
278 let provider =
279 ClickHouseContextProvider::new(state.clone(), HashMap::with_capacity(references.len()));
280
281 let mut provider = if let Some(planner) = self.expr_planner.as_ref() {
282 provider.with_expr_planner(Arc::clone(planner))
283 } else {
284 provider
285 };
286
287 for reference in references {
288 let catalog = &state.config_options().catalog;
291 let resolved = reference.resolve(&catalog.default_catalog, &catalog.default_schema);
292 if let Entry::Vacant(v) = provider.tables.entry(resolved) {
293 let resolved = v.key();
294 if let Ok(schema) = provider.state.schema_for_ref(resolved.clone())
295 && let Some(table) = schema.table(&resolved.table).await?
296 {
297 let _ = v.insert(provider_as_source(table));
298 }
299 }
300 }
301
302 SqlToRel::new_with_options(&provider, Self::get_parser_options(&self.state()))
303 .statement_to_plan(statement)
304 }
305
306 fn get_parser_options(state: &SessionState) -> ParserOptions {
307 let sql_parser_options = &state.config().options().sql_parser;
308 ParserOptions {
309 parse_float_as_decimal: sql_parser_options.parse_float_as_decimal,
310 enable_ident_normalization: sql_parser_options.enable_ident_normalization,
311 enable_options_value_normalization: sql_parser_options
312 .enable_options_value_normalization,
313 support_varchar_with_length: sql_parser_options.support_varchar_with_length,
314 map_string_types_to_utf8view: sql_parser_options.map_string_types_to_utf8view,
315 collect_spans: sql_parser_options.collect_spans,
316 default_null_ordering: NullOrdering::NullsMax,
317 }
318 }
319}
320
321impl From<SessionContext> for ClickHouseSessionContext {
322 fn from(inner: SessionContext) -> Self { Self::new(inner, None) }
323}
324
325impl From<&SessionContext> for ClickHouseSessionContext {
326 fn from(inner: &SessionContext) -> Self { Self::new(inner.clone(), None) }
327}
328
329impl std::ops::Deref for ClickHouseSessionContext {
330 type Target = SessionContext;
331
332 fn deref(&self) -> &Self::Target { &self.inner }
333}
334
/// [`ContextProvider`] backed by a [`SessionState`] snapshot and a map of pre-resolved
/// table sources. [`ClickHouseSessionContext::statement_to_plan`] builds one per statement.
pub struct ClickHouseContextProvider {
    /// Session snapshot used for function, table-function, variable, option, and
    /// file-format lookups.
    state: SessionState,
    /// Table sources resolved ahead of planning; `get_table_source` only consults this map.
    tables: HashMap<ResolvedTableReference, Arc<dyn TableSource>>,
    /// Expression planners returned from `get_expr_planners`.
    expr_planners: Vec<Arc<dyn ExprPlanner>>,
    /// Optional type planner returned from `get_type_planner`.
    type_planner: Option<Arc<dyn TypePlanner>>,
}
345
346impl ClickHouseContextProvider {
347 pub fn new(
348 state: SessionState,
349 tables: HashMap<ResolvedTableReference, Arc<dyn TableSource>>,
350 ) -> Self {
351 Self { state, tables, expr_planners: vec![], type_planner: None }
352 }
353
354 #[must_use]
355 pub fn with_expr_planner(mut self, planner: Arc<dyn ExprPlanner>) -> Self {
356 self.expr_planners.push(planner);
357 self
358 }
359
360 #[must_use]
361 pub fn with_type_planner(mut self, type_planner: Arc<dyn TypePlanner>) -> Self {
362 self.type_planner = Some(type_planner);
363 self
364 }
365
366 fn resolve_table_ref(&self, table_ref: impl Into<TableReference>) -> ResolvedTableReference {
369 let catalog = &self.state.config_options().catalog;
370 table_ref.into().resolve(&catalog.default_catalog, &catalog.default_schema)
371 }
372}
373
374impl ContextProvider for ClickHouseContextProvider {
375 fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
376 if CLICKHOUSE_UDF_ALIASES.contains(&name) {
378 return Some(Arc::new(clickhouse_udf()));
379 }
380
381 if CLICKHOUSE_APPLY_ALIASES.contains(&name) {
383 return Some(Arc::new(clickhouse_apply_udf()));
384 }
385
386 if let Some(func) = self.state.scalar_functions().get(name) {
388 return Some(Arc::clone(func));
389 }
390
391 if self.state.aggregate_functions().contains_key(name) {
394 return None;
395 }
396 if self.state.window_functions().contains_key(name) {
397 return None;
398 }
399
400 Some(Arc::new(ScalarUDF::new_from_impl(PlaceholderUDF::new(name))))
402 }
403
404 fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] { &self.expr_planners }
405
406 fn get_type_planner(&self) -> Option<Arc<dyn TypePlanner>> {
407 if let Some(type_planner) = &self.type_planner {
408 Some(Arc::clone(type_planner))
409 } else {
410 None
411 }
412 }
413
414 fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn TableSource>> {
415 let name = self.resolve_table_ref(name);
416 self.tables
417 .get(&name)
418 .cloned()
419 .ok_or_else(|| plan_datafusion_err!("table '{name}' not found"))
420 }
421
422 fn get_table_function_source(
423 &self,
424 name: &str,
425 args: Vec<Expr>,
426 ) -> Result<Arc<dyn TableSource>> {
427 let tbl_func = self
428 .state
429 .table_functions()
430 .get(name)
431 .cloned()
432 .ok_or_else(|| plan_datafusion_err!("table function '{name}' not found"))?;
433 let provider = tbl_func.create_table_provider(&args)?;
434
435 Ok(provider_as_source(provider))
436 }
437
438 fn create_cte_work_table(&self, name: &str, schema: SchemaRef) -> Result<Arc<dyn TableSource>> {
442 let table = Arc::new(CteWorkTable::new(name, schema));
443 Ok(provider_as_source(table))
444 }
445
446 fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
447 self.state.aggregate_functions().get(name).cloned()
448 }
449
450 fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
451 self.state.window_functions().get(name).cloned()
452 }
453
454 fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType> {
455 if variable_names.is_empty() {
456 return None;
457 }
458
459 let provider_type = if is_system_variables(variable_names) {
460 VarType::System
461 } else {
462 VarType::UserDefined
463 };
464
465 self.state
466 .execution_props()
467 .var_providers
468 .as_ref()
469 .and_then(|provider| provider.get(&provider_type)?.get_type(variable_names))
470 }
471
472 fn options(&self) -> &ConfigOptions { self.state.config_options() }
473
474 fn udf_names(&self) -> Vec<String> { self.state.scalar_functions().keys().cloned().collect() }
476
477 fn udaf_names(&self) -> Vec<String> {
478 self.state.aggregate_functions().keys().cloned().collect()
479 }
480
481 fn udwf_names(&self) -> Vec<String> { self.state.window_functions().keys().cloned().collect() }
482
483 fn get_file_type(&self, ext: &str) -> Result<Arc<dyn FileType>> {
484 self.state
485 .get_file_format_factory(ext)
486 .ok_or(plan_datafusion_err!("There is no registered file format with ext {ext}"))
487 .map(|file_type| format_as_file_type(file_type))
488 }
489}
490
#[cfg(all(test, feature = "test-utils"))]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::common::DFSchema;
    use datafusion::logical_expr::planner::{
        ExprPlanner, PlannerResult, RawBinaryExpr, TypePlanner,
    };
    use datafusion::prelude::{SessionContext, lit};
    use datafusion::sql::TableReference;
    use datafusion::sql::sqlparser::ast;

    use super::*;

    /// Type planner stub that maps every SQL type to `Utf8`.
    #[derive(Debug)]
    struct MockTypePlanner;

    impl TypePlanner for MockTypePlanner {
        fn plan_type(&self, _expr: &ast::DataType) -> Result<Option<DataType>> {
            Ok(Some(DataType::Utf8))
        }
    }

    /// Expression planner stub that hands every binary expression back unchanged.
    #[derive(Debug)]
    struct MockExprPlanner;

    impl ExprPlanner for MockExprPlanner {
        fn plan_binary_op(
            &self,
            expr: RawBinaryExpr,
            _schema: &DFSchema,
        ) -> Result<PlannerResult<RawBinaryExpr>> {
            Ok(PlannerResult::Original(expr))
        }
    }

    /// Provider over a default `SessionContext` with no pre-resolved tables.
    fn fresh_provider() -> ClickHouseContextProvider {
        ClickHouseContextProvider::new(SessionContext::new().state(), HashMap::new())
    }

    #[test]
    fn test_with_expr_planner() {
        let provider = fresh_provider();
        assert!(provider.expr_planners.is_empty());

        let planner: Arc<dyn ExprPlanner> = Arc::new(MockExprPlanner);
        let provider = provider.with_expr_planner(planner);

        assert_eq!(provider.expr_planners.len(), 1);
        assert_eq!(provider.get_expr_planners().len(), 1);
    }

    #[test]
    fn test_with_type_planner() {
        let provider = fresh_provider();
        assert!(provider.type_planner.is_none());

        let planner: Arc<dyn TypePlanner> = Arc::new(MockTypePlanner);
        let provider = provider.with_type_planner(planner);

        assert!(provider.type_planner.is_some());
    }

    #[test]
    fn test_get_type_planner() {
        let provider = fresh_provider();
        assert!(provider.get_type_planner().is_none());

        let planner: Arc<dyn TypePlanner> = Arc::new(MockTypePlanner);
        let provider = provider.with_type_planner(planner);

        assert!(provider.get_type_planner().is_some());
    }

    #[test]
    fn test_get_table_function_source_not_found() {
        let outcome =
            fresh_provider().get_table_function_source("nonexistent_function", vec![lit("test")]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_create_cte_work_table() {
        let provider = fresh_provider();
        let fields = vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
        ];
        let schema = Arc::new(Schema::new(fields));

        let source = provider
            .create_cte_work_table("test_cte", Arc::clone(&schema))
            .expect("creating a CTE work table should succeed");
        assert_eq!(source.schema(), schema);
    }

    #[test]
    fn test_get_variable_type_empty() {
        assert!(fresh_provider().get_variable_type(&[]).is_none());
    }

    #[test]
    fn test_get_variable_type_system_variables() {
        // A default context is expected to resolve no system variables.
        let names = ["@@version".to_string()];
        assert!(fresh_provider().get_variable_type(&names).is_none());
    }

    #[test]
    fn test_get_variable_type_user_defined() {
        // A default context is expected to resolve no user-defined variables.
        let names = ["user_var".to_string()];
        assert!(fresh_provider().get_variable_type(&names).is_none());
    }

    #[test]
    fn test_get_file_type_unknown_extension() {
        assert!(fresh_provider().get_file_type("unknown_ext").is_err());
    }

    #[test]
    fn test_get_file_type_known_extension() {
        assert!(fresh_provider().get_file_type("csv").is_ok());
    }

    #[test]
    fn test_get_function_meta_clickhouse_udf() {
        let udf = fresh_provider()
            .get_function_meta("clickhouse")
            .expect("the clickhouse UDF should resolve");
        assert_eq!(udf.name(), "clickhouse");
    }

    #[test]
    fn test_get_function_meta_placeholder_udf() {
        let udf = fresh_provider()
            .get_function_meta("unknown_function")
            .expect("unknown names should resolve to a placeholder");
        assert_eq!(udf.name(), "unknown_function");
    }

    #[test]
    fn test_get_function_meta_aggregate_function() {
        assert!(fresh_provider().get_function_meta("sum").is_none());
    }

    #[test]
    fn test_get_function_meta_window_function() {
        assert!(fresh_provider().get_function_meta("row_number").is_none());
    }

    #[test]
    fn test_get_table_source_not_found() {
        let missing = TableReference::bare("nonexistent_table");
        assert!(fresh_provider().get_table_source(missing).is_err());
    }

    #[test]
    fn test_resolve_table_ref() {
        let provider = fresh_provider();

        let bare = provider.resolve_table_ref(TableReference::bare("test_table"));
        assert_eq!(bare.table.as_ref(), "test_table");

        let partial =
            provider.resolve_table_ref(TableReference::partial("test_schema", "test_table"));
        assert_eq!(partial.schema.as_ref(), "test_schema");
        assert_eq!(partial.table.as_ref(), "test_table");

        let full = provider.resolve_table_ref(TableReference::full(
            "test_catalog",
            "test_schema",
            "test_table",
        ));
        assert_eq!(
            (full.catalog.as_ref(), full.schema.as_ref(), full.table.as_ref()),
            ("test_catalog", "test_schema", "test_table")
        );
    }

    #[test]
    fn test_udf_names() {
        assert!(!fresh_provider().udf_names().is_empty());
    }

    #[test]
    fn test_udaf_names() {
        let names = fresh_provider().udaf_names();
        assert!(!names.is_empty());
        for expected in ["sum", "count"] {
            assert!(names.iter().any(|name| name == expected), "missing aggregate `{expected}`");
        }
    }

    #[test]
    fn test_udwf_names() {
        let names = fresh_provider().udwf_names();
        assert!(!names.is_empty());
        assert!(names.iter().any(|name| name == "row_number"));
    }

    #[test]
    fn test_options() {
        let provider = fresh_provider();
        let catalog = &provider.options().catalog;
        assert!(!catalog.default_catalog.is_empty());
        assert!(!catalog.default_schema.is_empty());
    }

    #[test]
    fn test_get_aggregate_meta() {
        let provider = fresh_provider();

        let sum = provider.get_aggregate_meta("sum").expect("`sum` should be registered");
        assert_eq!(sum.name().to_lowercase(), "sum");

        assert!(provider.get_aggregate_meta("unknown_aggregate").is_none());
    }

    #[test]
    fn test_get_window_meta() {
        let provider = fresh_provider();

        let row_number =
            provider.get_window_meta("row_number").expect("`row_number` should be registered");
        assert_eq!(row_number.name().to_lowercase(), "row_number");

        assert!(provider.get_window_meta("unknown_window").is_none());
    }
}